[mlir] Add min/max operations to Standard.

[RFC: Add min/max ops](https://llvm.discourse.group/t/rfc-add-min-max-operations/4353)

I was following the naming style for Arith dialect in
https://reviews.llvm.org/D110200,
i.e. similar to DivSIOp and DivUIOp I defined MaxSIOp, MaxUIOp.

When Arith PR is landed, I will migrate these ops as well.

Differential Revision: https://reviews.llvm.org/D110540
This commit is contained in:
Alexander Belyaev
2021-09-28 09:22:39 +02:00
parent 20c0280733
commit 9fb57c8c1d
5 changed files with 334 additions and 32 deletions

View File

@@ -344,33 +344,6 @@ possible to store the predicate as string attribute, it would have rendered
impossible to implement switching logic based on the comparison kind and made
attribute validity checks (one out of ten possible kinds) more complex.
### 'select' operation to implement min/max
Although `min` and `max` operations are likely to occur as a result of
transforming affine loops in ML functions, we did not make them first-class
operations. Instead, we provide the `select` operation that can be combined with
`cmpi` to implement the minimum and maximum computation. Although they now
require two operations, they are likely to be emitted automatically during the
transformation inside MLIR. On the other hand, there are multiple benefits of
introducing `select`: standalone min/max would concern themselves with the
signedness of the comparison, already taken into account by `cmpi`; `select` can
support floats transparently if used after a float-comparison operation; the
lower-level targets provide `select`-like instructions making the translation
trivial.
This operation could have been implemented with additional control flow: `%r =
select %cond, %t, %f` is equivalent to
```mlir
^bb0:
cond_br %cond, ^bb1(%t), ^bb1(%f)
^bb1(%r):
```
However, this control flow granularity is not available in the ML functions
where min/max, and thus `select`, are likely to appear. In addition, simpler
control flow may be beneficial for optimization in general.
### Regions
#### Attributes of type 'Block'

View File

@@ -1247,6 +1247,152 @@ def IndexCastOp : ArithmeticCastOp<"index_cast"> {
let hasCanonicalizer = 1;
}
//===----------------------------------------------------------------------===//
// MaxFOp
//===----------------------------------------------------------------------===//
def MaxFOp : FloatBinaryOp<"maxf"> {
let summary = "floating-point maximum operation";
let description = [{
Syntax:
```
operation ::= ssa-id `=` `maxf` ssa-use `,` ssa-use `:` type
```
Returns the maximum of the two arguments, treating -0.0 as less than +0.0.
If one of the arguments is NaN, then the result is also NaN.
Example:
```mlir
// Scalar floating-point maximum.
%a = maxf %b, %c : f64
```
}];
}
//===----------------------------------------------------------------------===//
// MaxSIOp
//===----------------------------------------------------------------------===//
def MaxSIOp : IntBinaryOp<"maxsi"> {
let summary = "signed integer maximum operation";
let description = [{
Syntax:
```
operation ::= ssa-id `=` `maxsi` ssa-use `,` ssa-use `:` type
```
Returns the larger of %a and %b comparing the values as signed integers.
Example:
```mlir
// Scalar signed integer maximum.
%a = maxsi %b, %c : i64
```
}];
}
//===----------------------------------------------------------------------===//
// MaxUIOp
//===----------------------------------------------------------------------===//
def MaxUIOp : IntBinaryOp<"maxui"> {
let summary = "unsigned integer maximum operation";
let description = [{
Syntax:
```
operation ::= ssa-id `=` `maxui` ssa-use `,` ssa-use `:` type
```
Returns the larger of %a and %b comparing the values as unsigned integers.
Example:
```mlir
// Scalar unsigned integer maximum.
%a = maxui %b, %c : i64
```
}];
}
//===----------------------------------------------------------------------===//
// MinFOp
//===----------------------------------------------------------------------===//
def MinFOp : FloatBinaryOp<"minf"> {
let summary = "floating-point minimum operation";
let description = [{
Syntax:
```
operation ::= ssa-id `=` `minf` ssa-use `,` ssa-use `:` type
```
Returns the minimum of the two arguments, treating -0.0 as less than +0.0.
If one of the arguments is NaN, then the result is also NaN.
Example:
```mlir
// Scalar floating-point minimum.
%a = minf %b, %c : f64
```
}];
}
//===----------------------------------------------------------------------===//
// MinSIOp
//===----------------------------------------------------------------------===//
def MinSIOp : IntBinaryOp<"minsi"> {
let summary = "signed integer minimum operation";
let description = [{
Syntax:
```
operation ::= ssa-id `=` `minsi` ssa-use `,` ssa-use `:` type
```
Returns the smaller of %a and %b comparing the values as signed integers.
Example:
```mlir
// Scalar signed integer minimum.
%a = minsi %b, %c : i64
```
}];
}
//===----------------------------------------------------------------------===//
// MinUIOp
//===----------------------------------------------------------------------===//
def MinUIOp : IntBinaryOp<"minui"> {
let summary = "unsigned integer minimum operation";
let description = [{
Syntax:
```
operation ::= ssa-id `=` `minui` ssa-use `,` ssa-use `:` type
```
Returns the smaller of %a and %b comparing the values as unsigned integers.
Example:
```mlir
// Scalar unsigned integer minimum.
%a = minui %b, %c : i64
```
}];
}
//===----------------------------------------------------------------------===//
// MulFOp
//===----------------------------------------------------------------------===//

View File

@@ -215,6 +215,55 @@ public:
}
};
static Type getElementTypeOrSelf(Type type) {
if (auto st = type.dyn_cast<ShapedType>())
return st.getElementType();
return type;
}
template <typename OpTy, CmpFPredicate pred>
struct MaxMinFOpConverter : public OpRewritePattern<OpTy> {
public:
using OpRewritePattern<OpTy>::OpRewritePattern;
LogicalResult matchAndRewrite(OpTy op,
PatternRewriter &rewriter) const final {
Value lhs = op.lhs();
Value rhs = op.rhs();
Location loc = op.getLoc();
Value cmp = rewriter.create<CmpFOp>(loc, pred, lhs, rhs);
Value select = rewriter.create<SelectOp>(loc, cmp, lhs, rhs);
auto floatType = getElementTypeOrSelf(lhs.getType()).cast<FloatType>();
Value isNaN = rewriter.create<CmpFOp>(loc, CmpFPredicate::UNO, lhs, rhs);
Value nan = rewriter.create<ConstantFloatOp>(
loc, APFloat::getQNaN(floatType.getFloatSemantics()), floatType);
if (VectorType vectorType = lhs.getType().dyn_cast<VectorType>())
nan = rewriter.create<SplatOp>(loc, vectorType, nan);
rewriter.replaceOpWithNewOp<SelectOp>(op, isNaN, nan, select);
return success();
}
};
template <typename OpTy, CmpIPredicate pred>
struct MaxMinIOpConverter : public OpRewritePattern<OpTy> {
public:
using OpRewritePattern<OpTy>::OpRewritePattern;
LogicalResult matchAndRewrite(OpTy op,
PatternRewriter &rewriter) const final {
Value lhs = op.lhs();
Value rhs = op.rhs();
Location loc = op.getLoc();
Value cmp = rewriter.create<CmpIOp>(loc, pred, lhs, rhs);
rewriter.replaceOpWithNewOp<SelectOp>(op, cmp, lhs, rhs);
return success();
}
};
struct StdExpandOpsPass : public StdExpandOpsBase<StdExpandOpsPass> {
void runOnFunction() override {
MLIRContext &ctx = getContext();
@@ -232,8 +281,18 @@ struct StdExpandOpsPass : public StdExpandOpsBase<StdExpandOpsPass> {
target.addDynamicallyLegalOp<memref::ReshapeOp>([](memref::ReshapeOp op) {
return !op.shape().getType().cast<MemRefType>().hasStaticShape();
});
target.addIllegalOp<SignedCeilDivIOp>();
target.addIllegalOp<SignedFloorDivIOp>();
// clang-format off
target.addIllegalOp<
MaxFOp,
MaxSIOp,
MaxUIOp,
MinFOp,
MinSIOp,
MinUIOp,
SignedCeilDivIOp,
SignedFloorDivIOp
>();
// clang-format on
if (failed(
applyPartialConversion(getFunction(), target, std::move(patterns))))
signalPassFailure();
@@ -243,9 +302,20 @@ struct StdExpandOpsPass : public StdExpandOpsBase<StdExpandOpsPass> {
} // namespace
void mlir::populateStdExpandOpsPatterns(RewritePatternSet &patterns) {
patterns.add<AtomicRMWOpConverter, MemRefReshapeOpConverter,
SignedCeilDivIOpConverter, SignedFloorDivIOpConverter>(
patterns.getContext());
// clang-format off
patterns.add<
AtomicRMWOpConverter,
MaxMinFOpConverter<MaxFOp, CmpFPredicate::OGT>,
MaxMinFOpConverter<MinFOp, CmpFPredicate::OLT>,
MaxMinIOpConverter<MaxSIOp, CmpIPredicate::sgt>,
MaxMinIOpConverter<MaxUIOp, CmpIPredicate::ugt>,
MaxMinIOpConverter<MinSIOp, CmpIPredicate::slt>,
MaxMinIOpConverter<MinUIOp, CmpIPredicate::ult>,
MemRefReshapeOpConverter,
SignedCeilDivIOpConverter,
SignedFloorDivIOpConverter
>(patterns.getContext());
// clang-format on
}
std::unique_ptr<Pass> mlir::createStdExpandOpsPass() {

View File

@@ -109,3 +109,92 @@ func @memref_reshape(%input: memref<*xf32>,
// CHECK-SAME: to offset: [0], sizes: {{\[}}[[SIZE_0]], [[SIZE_1]], 8],
// CHECK-SAME: strides: {{\[}}[[STRIDE_0]], [[STRIDE_1]], [[C1]]]
// CHECK-SAME: : memref<*xf32> to memref<?x?x8xf32>
// -----
// CHECK-LABEL: func @maxf
func @maxf(%a: f32, %b: f32) -> f32 {
%result = maxf(%a, %b): (f32, f32) -> f32
return %result : f32
}
// CHECK-SAME: %[[LHS:.*]]: f32, %[[RHS:.*]]: f32)
// CHECK-NEXT: %[[CMP:.*]] = cmpf ogt, %[[LHS]], %[[RHS]] : f32
// CHECK-NEXT: %[[SELECT:.*]] = select %[[CMP]], %[[LHS]], %[[RHS]] : f32
// CHECK-NEXT: %[[IS_NAN:.*]] = cmpf uno, %[[LHS]], %[[RHS]] : f32
// CHECK-NEXT: %[[NAN:.*]] = constant 0x7FC00000 : f32
// CHECK-NEXT: %[[RESULT:.*]] = select %[[IS_NAN]], %[[NAN]], %[[SELECT]] : f32
// CHECK-NEXT: return %[[RESULT]] : f32
// -----
// CHECK-LABEL: func @maxf_vector
func @maxf_vector(%a: vector<4xf16>, %b: vector<4xf16>) -> vector<4xf16> {
%result = maxf(%a, %b): (vector<4xf16>, vector<4xf16>) -> vector<4xf16>
return %result : vector<4xf16>
}
// CHECK-SAME: %[[LHS:.*]]: vector<4xf16>, %[[RHS:.*]]: vector<4xf16>)
// CHECK-NEXT: %[[CMP:.*]] = cmpf ogt, %[[LHS]], %[[RHS]] : vector<4xf16>
// CHECK-NEXT: %[[SELECT:.*]] = select %[[CMP]], %[[LHS]], %[[RHS]]
// CHECK-NEXT: %[[IS_NAN:.*]] = cmpf uno, %[[LHS]], %[[RHS]] : vector<4xf16>
// CHECK-NEXT: %[[NAN:.*]] = constant 0x7E00 : f16
// CHECK-NEXT: %[[SPLAT_NAN:.*]] = splat %[[NAN]] : vector<4xf16>
// CHECK-NEXT: %[[RESULT:.*]] = select %[[IS_NAN]], %[[SPLAT_NAN]], %[[SELECT]]
// CHECK-NEXT: return %[[RESULT]] : vector<4xf16>
// -----
// CHECK-LABEL: func @minf
func @minf(%a: f32, %b: f32) -> f32 {
%result = minf(%a, %b): (f32, f32) -> f32
return %result : f32
}
// CHECK-SAME: %[[LHS:.*]]: f32, %[[RHS:.*]]: f32)
// CHECK-NEXT: %[[CMP:.*]] = cmpf olt, %[[LHS]], %[[RHS]] : f32
// CHECK-NEXT: %[[SELECT:.*]] = select %[[CMP]], %[[LHS]], %[[RHS]] : f32
// CHECK-NEXT: %[[IS_NAN:.*]] = cmpf uno, %[[LHS]], %[[RHS]] : f32
// CHECK-NEXT: %[[NAN:.*]] = constant 0x7FC00000 : f32
// CHECK-NEXT: %[[RESULT:.*]] = select %[[IS_NAN]], %[[NAN]], %[[SELECT]] : f32
// CHECK-NEXT: return %[[RESULT]] : f32
// -----
// CHECK-LABEL: func @maxsi
func @maxsi(%a: i32, %b: i32) -> i32 {
%result = maxsi(%a, %b): (i32, i32) -> i32
return %result : i32
}
// CHECK-SAME: %[[LHS:.*]]: i32, %[[RHS:.*]]: i32)
// CHECK-NEXT: %[[CMP:.*]] = cmpi sgt, %[[LHS]], %[[RHS]] : i32
// -----
// CHECK-LABEL: func @minsi
func @minsi(%a: i32, %b: i32) -> i32 {
%result = minsi(%a, %b): (i32, i32) -> i32
return %result : i32
}
// CHECK-SAME: %[[LHS:.*]]: i32, %[[RHS:.*]]: i32)
// CHECK-NEXT: %[[CMP:.*]] = cmpi slt, %[[LHS]], %[[RHS]] : i32
// -----
// CHECK-LABEL: func @maxui
func @maxui(%a: i32, %b: i32) -> i32 {
%result = maxui(%a, %b): (i32, i32) -> i32
return %result : i32
}
// CHECK-SAME: %[[LHS:.*]]: i32, %[[RHS:.*]]: i32)
// CHECK-NEXT: %[[CMP:.*]] = cmpi ugt, %[[LHS]], %[[RHS]] : i32
// -----
// CHECK-LABEL: func @minui
func @minui(%a: i32, %b: i32) -> i32 {
%result = minui(%a, %b): (i32, i32) -> i32
return %result : i32
}
// CHECK-SAME: %[[LHS:.*]]: i32, %[[RHS:.*]]: i32)
// CHECK-NEXT: %[[CMP:.*]] = cmpi ult, %[[LHS]], %[[RHS]] : i32

View File

@@ -86,3 +86,27 @@ func @bitcast(%arg : f32) -> i32 {
%res = bitcast %arg : f32 to i32
return %res : i32
}
// CHECK-LABEL: func @maximum
func @maximum(%v1: vector<4xf32>, %v2: vector<4xf32>,
%f1: f32, %f2: f32,
%i1: i32, %i2: i32) {
%max_vector = maxf(%v1, %v2)
: (vector<4xf32>, vector<4xf32>) -> vector<4xf32>
%max_float = maxf(%f1, %f2) : (f32, f32) -> f32
%max_signed = maxsi(%i1, %i2) : (i32, i32) -> i32
%max_unsigned = maxui(%i1, %i2) : (i32, i32) -> i32
return
}
// CHECK-LABEL: func @minimum
func @minimum(%v1: vector<4xf32>, %v2: vector<4xf32>,
%f1: f32, %f2: f32,
%i1: i32, %i2: i32) {
%min_vector = minf(%v1, %v2)
: (vector<4xf32>, vector<4xf32>) -> vector<4xf32>
%min_float = minf(%f1, %f2) : (f32, f32) -> f32
%min_signed = minsi(%i1, %i2) : (i32, i32) -> i32
%min_unsigned = minui(%i1, %i2) : (i32, i32) -> i32
return
}