[flang] Take result length into account in ApplyElementwise folding

ApplyElementwise on character operation was always creating a result
ArrayConstructor with the length of the left operand. This is not
correct for concatenation and SetLength operations.

Compute and thread the length to the spot creating the ArrayConstructor
so that the length is correct for those character operations.

Differential Revision: https://reviews.llvm.org/D108711
This commit is contained in:
Jean Perier
2021-08-26 09:44:24 +02:00
parent fdefde4965
commit 9016b2a1ca
2 changed files with 60 additions and 14 deletions

View File

@@ -898,12 +898,24 @@ Expr<RESULT> MapOperation(FoldingContext &context,
context, std::move(result), AsConstantExtents(context, shape));
}
template <typename RESULT, typename A>
ArrayConstructor<RESULT> ArrayConstructorFromMold(
const A &prototype, std::optional<Expr<SubscriptInteger>> &&length) {
if constexpr (RESULT::category == TypeCategory::Character) {
return ArrayConstructor<RESULT>{
std::move(length.value()), ArrayConstructorValues<RESULT>{}};
} else {
return ArrayConstructor<RESULT>{prototype};
}
}
// array * array case
template <typename RESULT, typename LEFT, typename RIGHT>
Expr<RESULT> MapOperation(FoldingContext &context,
std::function<Expr<RESULT>(Expr<LEFT> &&, Expr<RIGHT> &&)> &&f,
const Shape &shape, Expr<LEFT> &&leftValues, Expr<RIGHT> &&rightValues) {
ArrayConstructor<RESULT> result{leftValues};
const Shape &shape, std::optional<Expr<SubscriptInteger>> &&length,
Expr<LEFT> &&leftValues, Expr<RIGHT> &&rightValues) {
auto result{ArrayConstructorFromMold<RESULT>(leftValues, std::move(length))};
auto &leftArrConst{std::get<ArrayConstructor<LEFT>>(leftValues.u)};
if constexpr (common::HasMember<RIGHT, AllIntrinsicCategoryTypes>) {
std::visit(
@@ -942,9 +954,9 @@ Expr<RESULT> MapOperation(FoldingContext &context,
template <typename RESULT, typename LEFT, typename RIGHT>
Expr<RESULT> MapOperation(FoldingContext &context,
std::function<Expr<RESULT>(Expr<LEFT> &&, Expr<RIGHT> &&)> &&f,
const Shape &shape, Expr<LEFT> &&leftValues,
const Expr<RIGHT> &rightScalar) {
ArrayConstructor<RESULT> result{leftValues};
const Shape &shape, std::optional<Expr<SubscriptInteger>> &&length,
Expr<LEFT> &&leftValues, const Expr<RIGHT> &rightScalar) {
auto result{ArrayConstructorFromMold<RESULT>(leftValues, std::move(length))};
auto &leftArrConst{std::get<ArrayConstructor<LEFT>>(leftValues.u)};
for (auto &leftValue : leftArrConst) {
auto &leftScalar{std::get<Expr<LEFT>>(leftValue.u)};
@@ -959,9 +971,9 @@ Expr<RESULT> MapOperation(FoldingContext &context,
template <typename RESULT, typename LEFT, typename RIGHT>
Expr<RESULT> MapOperation(FoldingContext &context,
std::function<Expr<RESULT>(Expr<LEFT> &&, Expr<RIGHT> &&)> &&f,
const Shape &shape, const Expr<LEFT> &leftScalar,
Expr<RIGHT> &&rightValues) {
ArrayConstructor<RESULT> result{leftScalar};
const Shape &shape, std::optional<Expr<SubscriptInteger>> &&length,
const Expr<LEFT> &leftScalar, Expr<RIGHT> &&rightValues) {
auto result{ArrayConstructorFromMold<RESULT>(leftScalar, std::move(length))};
if constexpr (common::HasMember<RIGHT, AllIntrinsicCategoryTypes>) {
std::visit(
[&](auto &&kindExpr) {
@@ -987,6 +999,15 @@ Expr<RESULT> MapOperation(FoldingContext &context,
context, std::move(result), AsConstantExtents(context, shape));
}
template <typename DERIVED, typename RESULT, typename LEFT, typename RIGHT>
std::optional<Expr<SubscriptInteger>> ComputeResultLength(
Operation<DERIVED, RESULT, LEFT, RIGHT> &operation) {
if constexpr (RESULT::category == TypeCategory::Character) {
return Expr<RESULT>{operation.derived()}.LEN();
}
return std::nullopt;
}
// ApplyElementwise() recursively folds the operand expression(s) of an
// operation, then attempts to apply the operation to the (corresponding)
// scalar element(s) of those operands. Returns std::nullopt for scalars
@@ -1024,6 +1045,7 @@ auto ApplyElementwise(FoldingContext &context,
Operation<DERIVED, RESULT, LEFT, RIGHT> &operation,
std::function<Expr<RESULT>(Expr<LEFT> &&, Expr<RIGHT> &&)> &&f)
-> std::optional<Expr<RESULT>> {
auto resultLength{ComputeResultLength(operation)};
auto &leftExpr{operation.left()};
leftExpr = Fold(context, std::move(leftExpr));
auto &rightExpr{operation.right()};
@@ -1038,25 +1060,26 @@ auto ApplyElementwise(FoldingContext &context,
CheckConformanceFlags::EitherScalarExpandable)
.value_or(false /*fail if not known now to conform*/)) {
return MapOperation(context, std::move(f), *leftShape,
std::move(*left), std::move(*right));
std::move(resultLength), std::move(*left),
std::move(*right));
} else {
return std::nullopt;
}
return MapOperation(context, std::move(f), *leftShape,
std::move(*left), std::move(*right));
std::move(resultLength), std::move(*left), std::move(*right));
}
}
} else if (IsExpandableScalar(rightExpr)) {
return MapOperation(
context, std::move(f), *leftShape, std::move(*left), rightExpr);
return MapOperation(context, std::move(f), *leftShape,
std::move(resultLength), std::move(*left), rightExpr);
}
}
}
} else if (rightExpr.Rank() > 0 && IsExpandableScalar(leftExpr)) {
if (std::optional<Shape> shape{GetShape(context, rightExpr)}) {
if (auto right{AsFlatArrayConstructor(rightExpr)}) {
return MapOperation(
context, std::move(f), *shape, leftExpr, std::move(*right));
return MapOperation(context, std::move(f), *shape,
std::move(resultLength), leftExpr, std::move(*right));
}
}
}

View File

@@ -0,0 +1,23 @@
! RUN: %S/test_folding.sh %s %t %flang_fc1
! REQUIRES: shell
! Test character concatenation folding
logical, parameter :: test_scalar_scalar = ('ab' // 'cde').eq.('abcde')
character(2), parameter :: scalar_array(2) = ['1','2'] // 'a'
logical, parameter :: test_scalar_array = all(scalar_array.eq.(['1a', '2a']))
character(2), parameter :: array_scalar(2) = '1' // ['a', 'b']
logical, parameter :: test_array_scalar = all(array_scalar.eq.(['1a', '1b']))
character(2), parameter :: array_array(2) = ['1','2'] // ['a', 'b']
logical, parameter :: test_array_array = all(array_array.eq.(['1a', '2b']))
character(1), parameter :: input(2) = ['x', 'y']
character(*), parameter :: zero_sized(*) = input(2:1:1) // 'abcde'
logical, parameter :: test_zero_sized = len(zero_sized).eq.6
end