[mlir][Arith] ValueBoundsOpInterface: Support arith.select (#86383)

This commit adds a `ValueBoundsOpInterface` implementation for
`arith.select`. The implementation is almost identical to `scf.if`
(#85895), but there is one special case: if the condition is a shaped
value, the selection is applied element-wise and the result shape can be
inferred from either operand.
This commit is contained in:
Matthias Springer
2024-04-05 13:39:14 +09:00
committed by GitHub
parent 49f0b536fd
commit 62b58d3418
2 changed files with 101 additions and 0 deletions

View File

@@ -66,6 +66,75 @@ struct MulIOpInterface
}
};
struct SelectOpInterface
: public ValueBoundsOpInterface::ExternalModel<SelectOpInterface,
SelectOp> {
static void populateBounds(SelectOp selectOp, std::optional<int64_t> dim,
ValueBoundsConstraintSet &cstr) {
Value value = selectOp.getResult();
Value condition = selectOp.getCondition();
Value trueValue = selectOp.getTrueValue();
Value falseValue = selectOp.getFalseValue();
if (isa<ShapedType>(condition.getType())) {
// If the condition is a shaped type, the condition is applied
// element-wise. All three operands must have the same shape.
cstr.bound(value)[*dim] == cstr.getExpr(trueValue, dim);
cstr.bound(value)[*dim] == cstr.getExpr(falseValue, dim);
cstr.bound(value)[*dim] == cstr.getExpr(condition, dim);
return;
}
// Populate constraints for the true/false values (and all values on the
// backward slice, as long as the current stop condition is not satisfied).
cstr.populateConstraints(trueValue, dim);
cstr.populateConstraints(falseValue, dim);
auto boundsBuilder = cstr.bound(value);
if (dim)
boundsBuilder[*dim];
// Compare yielded values.
// If trueValue <= falseValue:
// * result <= falseValue
// * result >= trueValue
if (cstr.compare(trueValue, dim,
ValueBoundsConstraintSet::ComparisonOperator::LE,
falseValue, dim)) {
if (dim) {
cstr.bound(value)[*dim] >= cstr.getExpr(trueValue, dim);
cstr.bound(value)[*dim] <= cstr.getExpr(falseValue, dim);
} else {
cstr.bound(value) >= trueValue;
cstr.bound(value) <= falseValue;
}
}
// If falseValue <= trueValue:
// * result <= trueValue
// * result >= falseValue
if (cstr.compare(falseValue, dim,
ValueBoundsConstraintSet::ComparisonOperator::LE,
trueValue, dim)) {
if (dim) {
cstr.bound(value)[*dim] >= cstr.getExpr(falseValue, dim);
cstr.bound(value)[*dim] <= cstr.getExpr(trueValue, dim);
} else {
cstr.bound(value) >= falseValue;
cstr.bound(value) <= trueValue;
}
}
}
void populateBoundsForIndexValue(Operation *op, Value value,
ValueBoundsConstraintSet &cstr) const {
populateBounds(cast<SelectOp>(op), /*dim=*/std::nullopt, cstr);
}
void populateBoundsForShapedValueDim(Operation *op, Value value, int64_t dim,
ValueBoundsConstraintSet &cstr) const {
populateBounds(cast<SelectOp>(op), dim, cstr);
}
};
} // namespace
} // namespace arith
} // namespace mlir
@@ -77,5 +146,6 @@ void mlir::arith::registerValueBoundsOpInterfaceExternalModels(
arith::ConstantOp::attachInterface<arith::ConstantOpInterface>(*ctx);
arith::SubIOp::attachInterface<arith::SubIOpInterface>(*ctx);
arith::MulIOp::attachInterface<arith::MulIOpInterface>(*ctx);
arith::SelectOp::attachInterface<arith::SelectOpInterface>(*ctx);
});
}

View File

@@ -74,3 +74,34 @@ func.func @arith_const() -> index {
%0 = "test.reify_bound"(%c5) : (index) -> (index)
return %0 : index
}
// -----
// CHECK-LABEL: func @arith_select(
func.func @arith_select(%c: i1) -> (index, index) {
// CHECK: arith.constant 5 : index
%c5 = arith.constant 5 : index
// CHECK: arith.constant 9 : index
%c9 = arith.constant 9 : index
%r = arith.select %c, %c5, %c9 : index
// CHECK: %[[c5:.*]] = arith.constant 5 : index
// CHECK: %[[c10:.*]] = arith.constant 10 : index
%0 = "test.reify_bound"(%r) {type = "LB"} : (index) -> (index)
%1 = "test.reify_bound"(%r) {type = "UB"} : (index) -> (index)
// CHECK: return %[[c5]], %[[c10]]
return %0, %1 : index, index
}
// -----
// CHECK-LABEL: func @arith_select_elementwise(
// CHECK-SAME: %[[a:.*]]: tensor<?xf32>, %[[b:.*]]: tensor<?xf32>, %[[c:.*]]: tensor<?xi1>)
func.func @arith_select_elementwise(%a: tensor<?xf32>, %b: tensor<?xf32>, %c: tensor<?xi1>) -> index {
%r = arith.select %c, %a, %b : tensor<?xi1>, tensor<?xf32>
// CHECK: %[[c0:.*]] = arith.constant 0 : index
// CHECK: %[[dim:.*]] = tensor.dim %[[a]], %[[c0]]
%0 = "test.reify_bound"(%r) {type = "EQ", dim = 0}
: (tensor<?xf32>) -> (index)
// CHECK: return %[[dim]]
return %0 : index
}