mirror of
https://github.com/intel/llvm.git
synced 2026-01-22 07:01:03 +08:00
[flang] Take result length into account in ApplyElementwise folding
ApplyElementwise on character operation was always creating a result ArrayConstructor with the length of the left operand. This is not correct for concatenation and SetLength operations. Compute and thread the length to the spot creating the ArrayConstructor so that the length is correct for those character operations. Differential Revision: https://reviews.llvm.org/D108711
This commit is contained in:
@@ -898,12 +898,24 @@ Expr<RESULT> MapOperation(FoldingContext &context,
|
||||
context, std::move(result), AsConstantExtents(context, shape));
|
||||
}
|
||||
|
||||
template <typename RESULT, typename A>
|
||||
ArrayConstructor<RESULT> ArrayConstructorFromMold(
|
||||
const A &prototype, std::optional<Expr<SubscriptInteger>> &&length) {
|
||||
if constexpr (RESULT::category == TypeCategory::Character) {
|
||||
return ArrayConstructor<RESULT>{
|
||||
std::move(length.value()), ArrayConstructorValues<RESULT>{}};
|
||||
} else {
|
||||
return ArrayConstructor<RESULT>{prototype};
|
||||
}
|
||||
}
|
||||
|
||||
// array * array case
|
||||
template <typename RESULT, typename LEFT, typename RIGHT>
|
||||
Expr<RESULT> MapOperation(FoldingContext &context,
|
||||
std::function<Expr<RESULT>(Expr<LEFT> &&, Expr<RIGHT> &&)> &&f,
|
||||
const Shape &shape, Expr<LEFT> &&leftValues, Expr<RIGHT> &&rightValues) {
|
||||
ArrayConstructor<RESULT> result{leftValues};
|
||||
const Shape &shape, std::optional<Expr<SubscriptInteger>> &&length,
|
||||
Expr<LEFT> &&leftValues, Expr<RIGHT> &&rightValues) {
|
||||
auto result{ArrayConstructorFromMold<RESULT>(leftValues, std::move(length))};
|
||||
auto &leftArrConst{std::get<ArrayConstructor<LEFT>>(leftValues.u)};
|
||||
if constexpr (common::HasMember<RIGHT, AllIntrinsicCategoryTypes>) {
|
||||
std::visit(
|
||||
@@ -942,9 +954,9 @@ Expr<RESULT> MapOperation(FoldingContext &context,
|
||||
template <typename RESULT, typename LEFT, typename RIGHT>
|
||||
Expr<RESULT> MapOperation(FoldingContext &context,
|
||||
std::function<Expr<RESULT>(Expr<LEFT> &&, Expr<RIGHT> &&)> &&f,
|
||||
const Shape &shape, Expr<LEFT> &&leftValues,
|
||||
const Expr<RIGHT> &rightScalar) {
|
||||
ArrayConstructor<RESULT> result{leftValues};
|
||||
const Shape &shape, std::optional<Expr<SubscriptInteger>> &&length,
|
||||
Expr<LEFT> &&leftValues, const Expr<RIGHT> &rightScalar) {
|
||||
auto result{ArrayConstructorFromMold<RESULT>(leftValues, std::move(length))};
|
||||
auto &leftArrConst{std::get<ArrayConstructor<LEFT>>(leftValues.u)};
|
||||
for (auto &leftValue : leftArrConst) {
|
||||
auto &leftScalar{std::get<Expr<LEFT>>(leftValue.u)};
|
||||
@@ -959,9 +971,9 @@ Expr<RESULT> MapOperation(FoldingContext &context,
|
||||
template <typename RESULT, typename LEFT, typename RIGHT>
|
||||
Expr<RESULT> MapOperation(FoldingContext &context,
|
||||
std::function<Expr<RESULT>(Expr<LEFT> &&, Expr<RIGHT> &&)> &&f,
|
||||
const Shape &shape, const Expr<LEFT> &leftScalar,
|
||||
Expr<RIGHT> &&rightValues) {
|
||||
ArrayConstructor<RESULT> result{leftScalar};
|
||||
const Shape &shape, std::optional<Expr<SubscriptInteger>> &&length,
|
||||
const Expr<LEFT> &leftScalar, Expr<RIGHT> &&rightValues) {
|
||||
auto result{ArrayConstructorFromMold<RESULT>(leftScalar, std::move(length))};
|
||||
if constexpr (common::HasMember<RIGHT, AllIntrinsicCategoryTypes>) {
|
||||
std::visit(
|
||||
[&](auto &&kindExpr) {
|
||||
@@ -987,6 +999,15 @@ Expr<RESULT> MapOperation(FoldingContext &context,
|
||||
context, std::move(result), AsConstantExtents(context, shape));
|
||||
}
|
||||
|
||||
template <typename DERIVED, typename RESULT, typename LEFT, typename RIGHT>
|
||||
std::optional<Expr<SubscriptInteger>> ComputeResultLength(
|
||||
Operation<DERIVED, RESULT, LEFT, RIGHT> &operation) {
|
||||
if constexpr (RESULT::category == TypeCategory::Character) {
|
||||
return Expr<RESULT>{operation.derived()}.LEN();
|
||||
}
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
// ApplyElementwise() recursively folds the operand expression(s) of an
|
||||
// operation, then attempts to apply the operation to the (corresponding)
|
||||
// scalar element(s) of those operands. Returns std::nullopt for scalars
|
||||
@@ -1024,6 +1045,7 @@ auto ApplyElementwise(FoldingContext &context,
|
||||
Operation<DERIVED, RESULT, LEFT, RIGHT> &operation,
|
||||
std::function<Expr<RESULT>(Expr<LEFT> &&, Expr<RIGHT> &&)> &&f)
|
||||
-> std::optional<Expr<RESULT>> {
|
||||
auto resultLength{ComputeResultLength(operation)};
|
||||
auto &leftExpr{operation.left()};
|
||||
leftExpr = Fold(context, std::move(leftExpr));
|
||||
auto &rightExpr{operation.right()};
|
||||
@@ -1038,25 +1060,26 @@ auto ApplyElementwise(FoldingContext &context,
|
||||
CheckConformanceFlags::EitherScalarExpandable)
|
||||
.value_or(false /*fail if not known now to conform*/)) {
|
||||
return MapOperation(context, std::move(f), *leftShape,
|
||||
std::move(*left), std::move(*right));
|
||||
std::move(resultLength), std::move(*left),
|
||||
std::move(*right));
|
||||
} else {
|
||||
return std::nullopt;
|
||||
}
|
||||
return MapOperation(context, std::move(f), *leftShape,
|
||||
std::move(*left), std::move(*right));
|
||||
std::move(resultLength), std::move(*left), std::move(*right));
|
||||
}
|
||||
}
|
||||
} else if (IsExpandableScalar(rightExpr)) {
|
||||
return MapOperation(
|
||||
context, std::move(f), *leftShape, std::move(*left), rightExpr);
|
||||
return MapOperation(context, std::move(f), *leftShape,
|
||||
std::move(resultLength), std::move(*left), rightExpr);
|
||||
}
|
||||
}
|
||||
}
|
||||
} else if (rightExpr.Rank() > 0 && IsExpandableScalar(leftExpr)) {
|
||||
if (std::optional<Shape> shape{GetShape(context, rightExpr)}) {
|
||||
if (auto right{AsFlatArrayConstructor(rightExpr)}) {
|
||||
return MapOperation(
|
||||
context, std::move(f), *shape, leftExpr, std::move(*right));
|
||||
return MapOperation(context, std::move(f), *shape,
|
||||
std::move(resultLength), leftExpr, std::move(*right));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
23
flang/test/Evaluate/folding22.f90
Normal file
23
flang/test/Evaluate/folding22.f90
Normal file
@@ -0,0 +1,23 @@
|
||||
! RUN: %S/test_folding.sh %s %t %flang_fc1
|
||||
! REQUIRES: shell
|
||||
|
||||
! Test character concatenation folding
|
||||
|
||||
logical, parameter :: test_scalar_scalar = ('ab' // 'cde').eq.('abcde')
|
||||
|
||||
character(2), parameter :: scalar_array(2) = ['1','2'] // 'a'
|
||||
logical, parameter :: test_scalar_array = all(scalar_array.eq.(['1a', '2a']))
|
||||
|
||||
character(2), parameter :: array_scalar(2) = '1' // ['a', 'b']
|
||||
logical, parameter :: test_array_scalar = all(array_scalar.eq.(['1a', '1b']))
|
||||
|
||||
character(2), parameter :: array_array(2) = ['1','2'] // ['a', 'b']
|
||||
logical, parameter :: test_array_array = all(array_array.eq.(['1a', '2b']))
|
||||
|
||||
|
||||
character(1), parameter :: input(2) = ['x', 'y']
|
||||
character(*), parameter :: zero_sized(*) = input(2:1:1) // 'abcde'
|
||||
logical, parameter :: test_zero_sized = len(zero_sized).eq.6
|
||||
|
||||
end
|
||||
|
||||
Reference in New Issue
Block a user