mirror of
https://github.com/intel/llvm.git
synced 2026-01-23 07:58:23 +08:00
[mlir] Add min/max operations to Standard.
[RFC: Add min/max ops](https://llvm.discourse.group/t/rfc-add-min-max-operations/4353) I followed the naming style of the Arith dialect in https://reviews.llvm.org/D110200, i.e. similar to DivSIOp and DivUIOp, I defined MaxSIOp and MaxUIOp. Once the Arith PR has landed, I will migrate these ops as well. Differential Revision: https://reviews.llvm.org/D110540
This commit is contained in:
@@ -344,33 +344,6 @@ possible to store the predicate as string attribute, it would have rendered
|
||||
impossible to implement switching logic based on the comparison kind and made
|
||||
attribute validity checks (one out of ten possible kinds) more complex.
|
||||
|
||||
### 'select' operation to implement min/max
|
||||
|
||||
Although `min` and `max` operations are likely to occur as a result of
|
||||
transforming affine loops in ML functions, we did not make them first-class
|
||||
operations. Instead, we provide the `select` operation that can be combined with
|
||||
`cmpi` to implement the minimum and maximum computation. Although they now
|
||||
require two operations, they are likely to be emitted automatically during the
|
||||
transformation inside MLIR. On the other hand, there are multiple benefits of
|
||||
introducing `select`: standalone min/max would concern themselves with the
|
||||
signedness of the comparison, already taken into account by `cmpi`; `select` can
|
||||
support floats transparently if used after a float-comparison operation; the
|
||||
lower-level targets provide `select`-like instructions making the translation
|
||||
trivial.
|
||||
|
||||
This operation could have been implemented with additional control flow: `%r =
|
||||
select %cond, %t, %f` is equivalent to
|
||||
|
||||
```mlir
|
||||
^bb0:
|
||||
cond_br %cond, ^bb1(%t), ^bb1(%f)
|
||||
^bb1(%r):
|
||||
```
|
||||
|
||||
However, this control flow granularity is not available in the ML functions
|
||||
where min/max, and thus `select`, are likely to appear. In addition, simpler
|
||||
control flow may be beneficial for optimization in general.
|
||||
|
||||
### Regions
|
||||
|
||||
#### Attributes of type 'Block'
|
||||
|
||||
@@ -1247,6 +1247,152 @@ def IndexCastOp : ArithmeticCastOp<"index_cast"> {
|
||||
let hasCanonicalizer = 1;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// MaxFOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
def MaxFOp : FloatBinaryOp<"maxf"> {
|
||||
let summary = "floating-point maximum operation";
|
||||
let description = [{
|
||||
Syntax:
|
||||
|
||||
```
|
||||
operation ::= ssa-id `=` `maxf` ssa-use `,` ssa-use `:` type
|
||||
```
|
||||
|
||||
Returns the maximum of the two arguments, treating -0.0 as less than +0.0.
|
||||
If one of the arguments is NaN, then the result is also NaN.
|
||||
|
||||
Example:
|
||||
|
||||
```mlir
|
||||
// Scalar floating-point maximum.
|
||||
%a = maxf %b, %c : f64
|
||||
```
|
||||
}];
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// MaxSIOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
def MaxSIOp : IntBinaryOp<"maxsi"> {
|
||||
let summary = "signed integer maximum operation";
|
||||
let description = [{
|
||||
Syntax:
|
||||
|
||||
```
|
||||
operation ::= ssa-id `=` `maxsi` ssa-use `,` ssa-use `:` type
|
||||
```
|
||||
|
||||
Returns the larger of the two operands, comparing the values as signed integers.
|
||||
|
||||
Example:
|
||||
|
||||
```mlir
|
||||
// Scalar signed integer maximum.
|
||||
%a = maxsi %b, %c : i64
|
||||
```
|
||||
}];
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// MaxUIOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
def MaxUIOp : IntBinaryOp<"maxui"> {
|
||||
let summary = "unsigned integer maximum operation";
|
||||
let description = [{
|
||||
Syntax:
|
||||
|
||||
```
|
||||
operation ::= ssa-id `=` `maxui` ssa-use `,` ssa-use `:` type
|
||||
```
|
||||
|
||||
Returns the larger of the two operands, comparing the values as unsigned integers.
|
||||
|
||||
Example:
|
||||
|
||||
```mlir
|
||||
// Scalar unsigned integer maximum.
|
||||
%a = maxui %b, %c : i64
|
||||
```
|
||||
}];
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// MinFOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
def MinFOp : FloatBinaryOp<"minf"> {
|
||||
let summary = "floating-point minimum operation";
|
||||
let description = [{
|
||||
Syntax:
|
||||
|
||||
```
|
||||
operation ::= ssa-id `=` `minf` ssa-use `,` ssa-use `:` type
|
||||
```
|
||||
|
||||
Returns the minimum of the two arguments, treating -0.0 as less than +0.0.
|
||||
If one of the arguments is NaN, then the result is also NaN.
|
||||
|
||||
Example:
|
||||
|
||||
```mlir
|
||||
// Scalar floating-point minimum.
|
||||
%a = minf %b, %c : f64
|
||||
```
|
||||
}];
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// MinSIOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
def MinSIOp : IntBinaryOp<"minsi"> {
|
||||
let summary = "signed integer minimum operation";
|
||||
let description = [{
|
||||
Syntax:
|
||||
|
||||
```
|
||||
operation ::= ssa-id `=` `minsi` ssa-use `,` ssa-use `:` type
|
||||
```
|
||||
|
||||
Returns the smaller of the two operands, comparing the values as signed integers.
|
||||
|
||||
Example:
|
||||
|
||||
```mlir
|
||||
// Scalar signed integer minimum.
|
||||
%a = minsi %b, %c : i64
|
||||
```
|
||||
}];
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// MinUIOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
def MinUIOp : IntBinaryOp<"minui"> {
|
||||
let summary = "unsigned integer minimum operation";
|
||||
let description = [{
|
||||
Syntax:
|
||||
|
||||
```
|
||||
operation ::= ssa-id `=` `minui` ssa-use `,` ssa-use `:` type
|
||||
```
|
||||
|
||||
Returns the smaller of the two operands, comparing the values as unsigned integers.
|
||||
|
||||
Example:
|
||||
|
||||
```mlir
|
||||
// Scalar unsigned integer minimum.
|
||||
%a = minui %b, %c : i64
|
||||
```
|
||||
}];
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// MulFOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
@@ -215,6 +215,55 @@ public:
|
||||
}
|
||||
};
|
||||
|
||||
// Returns the element type of `type` when it is a ShapedType; otherwise
// returns `type` itself unchanged.
static Type getElementTypeOrSelf(Type type) {
|
||||
// dyn_cast yields a null (false) value when `type` is not a ShapedType.
if (auto st = type.dyn_cast<ShapedType>())
|
||||
return st.getElementType();
|
||||
return type;
|
||||
}
|
||||
|
||||
// Rewrite pattern that expands a floating-point max/min op `OpTy` into
// compare + select ops, with an extra select that forces a NaN result
// whenever the operands compare unordered.
//
// Template parameters:
//   OpTy - the op being expanded; it must expose lhs() and rhs().
//   pred - the CmpFPredicate used to choose between lhs and rhs.
//
// Emitted sequence (in creation order):
//   cmp    = cmpf pred, lhs, rhs
//   select = select cmp, lhs, rhs
//   isNaN  = cmpf uno, lhs, rhs
//   nan    = constant quiet-NaN (splatted when lhs is a vector)
//   result = select isNaN, nan, select
//
// NOTE(review): the first cmp/select pair picks rhs whenever `pred` is false,
// which includes the case where lhs and rhs compare equal. If +0.0 and -0.0
// compare equal under cmpf (as they do in IEEE-754), max(+0.0, -0.0) would
// yield -0.0 - confirm whether the op is expected to order -0.0 below +0.0.
template <typename OpTy, CmpFPredicate pred>
|
||||
struct MaxMinFOpConverter : public OpRewritePattern<OpTy> {
|
||||
public:
|
||||
// Inherit the base-class constructors.
using OpRewritePattern<OpTy>::OpRewritePattern;
|
||||
|
||||
// Replaces `op` with the compare/select sequence described above and
// always returns success().
LogicalResult matchAndRewrite(OpTy op,
|
||||
PatternRewriter &rewriter) const final {
|
||||
Value lhs = op.lhs();
|
||||
Value rhs = op.rhs();
|
||||
|
||||
Location loc = op.getLoc();
|
||||
// Candidate result ignoring NaNs: lhs when `pred` holds, rhs otherwise.
Value cmp = rewriter.create<CmpFOp>(loc, pred, lhs, rhs);
|
||||
Value select = rewriter.create<SelectOp>(loc, cmp, lhs, rhs);
|
||||
|
||||
// The element type is needed to build a NaN constant with matching float
// semantics. NOTE(review): cast<FloatType>() presumably asserts if the
// element type is not a float - confirm OpTy guarantees float operands.
auto floatType = getElementTypeOrSelf(lhs.getType()).cast<FloatType>();
|
||||
// UNO is the "unordered" predicate; used here as the NaN-operand check.
Value isNaN = rewriter.create<CmpFOp>(loc, CmpFPredicate::UNO, lhs, rhs);
|
||||
|
||||
// Scalar quiet-NaN constant of the element type.
Value nan = rewriter.create<ConstantFloatOp>(
|
||||
loc, APFloat::getQNaN(floatType.getFloatSemantics()), floatType);
|
||||
// For vector operands, broadcast the scalar NaN to the vector type.
// NOTE(review): only VectorType is splatted here, while
// getElementTypeOrSelf accepts any ShapedType; if other shaped operands
// can reach this pattern, `nan` would stay scalar - confirm.
if (VectorType vectorType = lhs.getType().dyn_cast<VectorType>())
|
||||
nan = rewriter.create<SplatOp>(loc, vectorType, nan);
|
||||
|
||||
// Final result: NaN where the operands are unordered, else the candidate.
rewriter.replaceOpWithNewOp<SelectOp>(op, isNaN, nan, select);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
// Rewrite pattern that expands an integer max/min op `OpTy` into a
// `cmpi pred, lhs, rhs` followed by `select cmp, lhs, rhs`.
//
// Template parameters:
//   OpTy - the op being expanded; it must expose lhs() and rhs().
//   pred - the CmpIPredicate used for the comparison; the signedness of the
//          comparison is carried entirely by this predicate.
template <typename OpTy, CmpIPredicate pred>
|
||||
struct MaxMinIOpConverter : public OpRewritePattern<OpTy> {
|
||||
public:
|
||||
// Inherit the base-class constructors.
using OpRewritePattern<OpTy>::OpRewritePattern;
|
||||
// Replaces `op` with the select of its operands and always returns
// success().
LogicalResult matchAndRewrite(OpTy op,
|
||||
PatternRewriter &rewriter) const final {
|
||||
Value lhs = op.lhs();
|
||||
Value rhs = op.rhs();
|
||||
|
||||
Location loc = op.getLoc();
|
||||
// lhs is selected when `pred` holds, rhs otherwise.
Value cmp = rewriter.create<CmpIOp>(loc, pred, lhs, rhs);
|
||||
rewriter.replaceOpWithNewOp<SelectOp>(op, cmp, lhs, rhs);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
struct StdExpandOpsPass : public StdExpandOpsBase<StdExpandOpsPass> {
|
||||
void runOnFunction() override {
|
||||
MLIRContext &ctx = getContext();
|
||||
@@ -232,8 +281,18 @@ struct StdExpandOpsPass : public StdExpandOpsBase<StdExpandOpsPass> {
|
||||
target.addDynamicallyLegalOp<memref::ReshapeOp>([](memref::ReshapeOp op) {
|
||||
return !op.shape().getType().cast<MemRefType>().hasStaticShape();
|
||||
});
|
||||
target.addIllegalOp<SignedCeilDivIOp>();
|
||||
target.addIllegalOp<SignedFloorDivIOp>();
|
||||
// clang-format off
|
||||
target.addIllegalOp<
|
||||
MaxFOp,
|
||||
MaxSIOp,
|
||||
MaxUIOp,
|
||||
MinFOp,
|
||||
MinSIOp,
|
||||
MinUIOp,
|
||||
SignedCeilDivIOp,
|
||||
SignedFloorDivIOp
|
||||
>();
|
||||
// clang-format on
|
||||
if (failed(
|
||||
applyPartialConversion(getFunction(), target, std::move(patterns))))
|
||||
signalPassFailure();
|
||||
@@ -243,9 +302,20 @@ struct StdExpandOpsPass : public StdExpandOpsBase<StdExpandOpsPass> {
|
||||
} // namespace
|
||||
|
||||
void mlir::populateStdExpandOpsPatterns(RewritePatternSet &patterns) {
|
||||
patterns.add<AtomicRMWOpConverter, MemRefReshapeOpConverter,
|
||||
SignedCeilDivIOpConverter, SignedFloorDivIOpConverter>(
|
||||
patterns.getContext());
|
||||
// clang-format off
|
||||
patterns.add<
|
||||
AtomicRMWOpConverter,
|
||||
MaxMinFOpConverter<MaxFOp, CmpFPredicate::OGT>,
|
||||
MaxMinFOpConverter<MinFOp, CmpFPredicate::OLT>,
|
||||
MaxMinIOpConverter<MaxSIOp, CmpIPredicate::sgt>,
|
||||
MaxMinIOpConverter<MaxUIOp, CmpIPredicate::ugt>,
|
||||
MaxMinIOpConverter<MinSIOp, CmpIPredicate::slt>,
|
||||
MaxMinIOpConverter<MinUIOp, CmpIPredicate::ult>,
|
||||
MemRefReshapeOpConverter,
|
||||
SignedCeilDivIOpConverter,
|
||||
SignedFloorDivIOpConverter
|
||||
>(patterns.getContext());
|
||||
// clang-format on
|
||||
}
|
||||
|
||||
std::unique_ptr<Pass> mlir::createStdExpandOpsPass() {
|
||||
|
||||
@@ -109,3 +109,92 @@ func @memref_reshape(%input: memref<*xf32>,
|
||||
// CHECK-SAME: to offset: [0], sizes: {{\[}}[[SIZE_0]], [[SIZE_1]], 8],
|
||||
// CHECK-SAME: strides: {{\[}}[[STRIDE_0]], [[STRIDE_1]], [[C1]]]
|
||||
// CHECK-SAME: : memref<*xf32> to memref<?x?x8xf32>
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @maxf
|
||||
func @maxf(%a: f32, %b: f32) -> f32 {
|
||||
%result = maxf(%a, %b): (f32, f32) -> f32
|
||||
return %result : f32
|
||||
}
|
||||
// CHECK-SAME: %[[LHS:.*]]: f32, %[[RHS:.*]]: f32)
|
||||
// CHECK-NEXT: %[[CMP:.*]] = cmpf ogt, %[[LHS]], %[[RHS]] : f32
|
||||
// CHECK-NEXT: %[[SELECT:.*]] = select %[[CMP]], %[[LHS]], %[[RHS]] : f32
|
||||
// CHECK-NEXT: %[[IS_NAN:.*]] = cmpf uno, %[[LHS]], %[[RHS]] : f32
|
||||
// CHECK-NEXT: %[[NAN:.*]] = constant 0x7FC00000 : f32
|
||||
// CHECK-NEXT: %[[RESULT:.*]] = select %[[IS_NAN]], %[[NAN]], %[[SELECT]] : f32
|
||||
// CHECK-NEXT: return %[[RESULT]] : f32
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @maxf_vector
|
||||
func @maxf_vector(%a: vector<4xf16>, %b: vector<4xf16>) -> vector<4xf16> {
|
||||
%result = maxf(%a, %b): (vector<4xf16>, vector<4xf16>) -> vector<4xf16>
|
||||
return %result : vector<4xf16>
|
||||
}
|
||||
// CHECK-SAME: %[[LHS:.*]]: vector<4xf16>, %[[RHS:.*]]: vector<4xf16>)
|
||||
// CHECK-NEXT: %[[CMP:.*]] = cmpf ogt, %[[LHS]], %[[RHS]] : vector<4xf16>
|
||||
// CHECK-NEXT: %[[SELECT:.*]] = select %[[CMP]], %[[LHS]], %[[RHS]]
|
||||
// CHECK-NEXT: %[[IS_NAN:.*]] = cmpf uno, %[[LHS]], %[[RHS]] : vector<4xf16>
|
||||
// CHECK-NEXT: %[[NAN:.*]] = constant 0x7E00 : f16
|
||||
// CHECK-NEXT: %[[SPLAT_NAN:.*]] = splat %[[NAN]] : vector<4xf16>
|
||||
// CHECK-NEXT: %[[RESULT:.*]] = select %[[IS_NAN]], %[[SPLAT_NAN]], %[[SELECT]]
|
||||
// CHECK-NEXT: return %[[RESULT]] : vector<4xf16>
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @minf
|
||||
func @minf(%a: f32, %b: f32) -> f32 {
|
||||
%result = minf(%a, %b): (f32, f32) -> f32
|
||||
return %result : f32
|
||||
}
|
||||
// CHECK-SAME: %[[LHS:.*]]: f32, %[[RHS:.*]]: f32)
|
||||
// CHECK-NEXT: %[[CMP:.*]] = cmpf olt, %[[LHS]], %[[RHS]] : f32
|
||||
// CHECK-NEXT: %[[SELECT:.*]] = select %[[CMP]], %[[LHS]], %[[RHS]] : f32
|
||||
// CHECK-NEXT: %[[IS_NAN:.*]] = cmpf uno, %[[LHS]], %[[RHS]] : f32
|
||||
// CHECK-NEXT: %[[NAN:.*]] = constant 0x7FC00000 : f32
|
||||
// CHECK-NEXT: %[[RESULT:.*]] = select %[[IS_NAN]], %[[NAN]], %[[SELECT]] : f32
|
||||
// CHECK-NEXT: return %[[RESULT]] : f32
|
||||
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @maxsi
|
||||
func @maxsi(%a: i32, %b: i32) -> i32 {
|
||||
%result = maxsi(%a, %b): (i32, i32) -> i32
|
||||
return %result : i32
|
||||
}
|
||||
// CHECK-SAME: %[[LHS:.*]]: i32, %[[RHS:.*]]: i32)
|
||||
// CHECK-NEXT: %[[CMP:.*]] = cmpi sgt, %[[LHS]], %[[RHS]] : i32
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @minsi
|
||||
func @minsi(%a: i32, %b: i32) -> i32 {
|
||||
%result = minsi(%a, %b): (i32, i32) -> i32
|
||||
return %result : i32
|
||||
}
|
||||
// CHECK-SAME: %[[LHS:.*]]: i32, %[[RHS:.*]]: i32)
|
||||
// CHECK-NEXT: %[[CMP:.*]] = cmpi slt, %[[LHS]], %[[RHS]] : i32
|
||||
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @maxui
|
||||
func @maxui(%a: i32, %b: i32) -> i32 {
|
||||
%result = maxui(%a, %b): (i32, i32) -> i32
|
||||
return %result : i32
|
||||
}
|
||||
// CHECK-SAME: %[[LHS:.*]]: i32, %[[RHS:.*]]: i32)
|
||||
// CHECK-NEXT: %[[CMP:.*]] = cmpi ugt, %[[LHS]], %[[RHS]] : i32
|
||||
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @minui
|
||||
func @minui(%a: i32, %b: i32) -> i32 {
|
||||
%result = minui(%a, %b): (i32, i32) -> i32
|
||||
return %result : i32
|
||||
}
|
||||
// CHECK-SAME: %[[LHS:.*]]: i32, %[[RHS:.*]]: i32)
|
||||
// CHECK-NEXT: %[[CMP:.*]] = cmpi ult, %[[LHS]], %[[RHS]] : i32
|
||||
|
||||
@@ -86,3 +86,27 @@ func @bitcast(%arg : f32) -> i32 {
|
||||
%res = bitcast %arg : f32 to i32
|
||||
return %res : i32
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @maximum
|
||||
func @maximum(%v1: vector<4xf32>, %v2: vector<4xf32>,
|
||||
%f1: f32, %f2: f32,
|
||||
%i1: i32, %i2: i32) {
|
||||
%max_vector = maxf(%v1, %v2)
|
||||
: (vector<4xf32>, vector<4xf32>) -> vector<4xf32>
|
||||
%max_float = maxf(%f1, %f2) : (f32, f32) -> f32
|
||||
%max_signed = maxsi(%i1, %i2) : (i32, i32) -> i32
|
||||
%max_unsigned = maxui(%i1, %i2) : (i32, i32) -> i32
|
||||
return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @minimum
|
||||
func @minimum(%v1: vector<4xf32>, %v2: vector<4xf32>,
|
||||
%f1: f32, %f2: f32,
|
||||
%i1: i32, %i2: i32) {
|
||||
%min_vector = minf(%v1, %v2)
|
||||
: (vector<4xf32>, vector<4xf32>) -> vector<4xf32>
|
||||
%min_float = minf(%f1, %f2) : (f32, f32) -> f32
|
||||
%min_signed = minsi(%i1, %i2) : (i32, i32) -> i32
|
||||
%min_unsigned = minui(%i1, %i2) : (i32, i32) -> i32
|
||||
return
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user