[mlir][arith] Add narrowing patterns for max*i and min*i

Reviewed By: antiagainst

Differential Revision: https://reviews.llvm.org/D149583
This commit is contained in:
Jakub Kuderski
2023-05-02 10:48:01 -04:00
parent 48f18ecd82
commit 9701c5abd6
2 changed files with 153 additions and 1 deletions

View File

@@ -367,6 +367,29 @@ struct DivUIPattern final : BinaryOpNarrowingPattern<arith::DivUIOp> {
}
};
//===----------------------------------------------------------------------===//
// Min/Max Patterns
//===----------------------------------------------------------------------===//
template <typename MinMaxOp, ExtensionKind Kind>
struct MinMaxPattern final : BinaryOpNarrowingPattern<MinMaxOp> {
using BinaryOpNarrowingPattern<MinMaxOp>::BinaryOpNarrowingPattern;
bool isSupported(ExtensionOp ext) const override {
return ext.getKind() == Kind;
}
// Min/max returns one of the arguments and does not require any extra result
// bits.
unsigned getResultBitsProduced(unsigned operandBits) const override {
return operandBits;
}
};
using MaxSIPattern = MinMaxPattern<arith::MaxSIOp, ExtensionKind::Sign>;
using MaxUIPattern = MinMaxPattern<arith::MaxUIOp, ExtensionKind::Zero>;
using MinSIPattern = MinMaxPattern<arith::MinSIOp, ExtensionKind::Sign>;
using MinUIPattern = MinMaxPattern<arith::MinUIOp, ExtensionKind::Zero>;
//===----------------------------------------------------------------------===//
// *IToFPOp Patterns
//===----------------------------------------------------------------------===//
@@ -690,7 +713,8 @@ void populateArithIntNarrowingPatterns(
patterns.getContext(), options, PatternBenefit(2));
patterns.add<AddIPattern, SubIPattern, MulIPattern, DivSIPattern,
DivUIPattern, SIToFPPattern, UIToFPPattern>(
DivUIPattern, MaxSIPattern, MaxUIPattern, MinSIPattern,
MinUIPattern, SIToFPPattern, UIToFPPattern>(
patterns.getContext(), options);
}

View File

@@ -473,6 +473,134 @@ func.func @uitofp_extsi_i16(%a: i16) -> f16 {
return %f : f16
}
//===----------------------------------------------------------------------===//
// arith.maxsi
//===----------------------------------------------------------------------===//
// CHECK-LABEL: func.func @maxsi_extsi_i8
// CHECK-SAME: (%[[LHS:.+]]: i8, %[[RHS:.+]]: i8)
// CHECK-NEXT: %[[MAX:.+]] = arith.maxsi %[[LHS]], %[[RHS]] : i8
// CHECK-NEXT: %[[RET:.+]] = arith.extsi %[[MAX]] : i8 to i32
// CHECK-NEXT: return %[[RET]] : i32
func.func @maxsi_extsi_i8(%lhs: i8, %rhs: i8) -> i32 {
%a = arith.extsi %lhs : i8 to i32
%b = arith.extsi %rhs : i8 to i32
%r = arith.maxsi %a, %b : i32
return %r : i32
}
// This patterns should only apply to `arith.maxsi` ops with sign-extended
// arguments.
//
// CHECK-LABEL: func.func @maxsi_extui_i8
// CHECK-SAME: (%[[ARG0:.+]]: i8, %[[ARG1:.+]]: i8)
// CHECK-NEXT: %[[EXT0:.+]] = arith.extui %[[ARG0]] : i8 to i32
// CHECK-NEXT: %[[EXT1:.+]] = arith.extui %[[ARG1]] : i8 to i32
// CHECK-NEXT: %[[MAX:.+]] = arith.maxsi %[[EXT0]], %[[EXT1]] : i32
// CHECK-NEXT: return %[[MAX]] : i32
func.func @maxsi_extui_i8(%lhs: i8, %rhs: i8) -> i32 {
%a = arith.extui %lhs : i8 to i32
%b = arith.extui %rhs : i8 to i32
%r = arith.maxsi %a, %b : i32
return %r : i32
}
//===----------------------------------------------------------------------===//
// arith.maxui
//===----------------------------------------------------------------------===//
// CHECK-LABEL: func.func @maxui_extui_i8
// CHECK-SAME: (%[[LHS:.+]]: i8, %[[RHS:.+]]: i8)
// CHECK-NEXT: %[[MAX:.+]] = arith.maxui %[[LHS]], %[[RHS]] : i8
// CHECK-NEXT: %[[RET:.+]] = arith.extui %[[MAX]] : i8 to i32
// CHECK-NEXT: return %[[RET]] : i32
func.func @maxui_extui_i8(%lhs: i8, %rhs: i8) -> i32 {
%a = arith.extui %lhs : i8 to i32
%b = arith.extui %rhs : i8 to i32
%r = arith.maxui %a, %b : i32
return %r : i32
}
// This patterns should only apply to `arith.maxsi` ops with zero-extended
// arguments.
//
// CHECK-LABEL: func.func @maxui_extsi_i8
// CHECK-SAME: (%[[ARG0:.+]]: i8, %[[ARG1:.+]]: i8)
// CHECK-NEXT: %[[EXT0:.+]] = arith.extsi %[[ARG0]] : i8 to i32
// CHECK-NEXT: %[[EXT1:.+]] = arith.extsi %[[ARG1]] : i8 to i32
// CHECK-NEXT: %[[MAX:.+]] = arith.maxui %[[EXT0]], %[[EXT1]] : i32
// CHECK-NEXT: return %[[MAX]] : i32
func.func @maxui_extsi_i8(%lhs: i8, %rhs: i8) -> i32 {
%a = arith.extsi %lhs : i8 to i32
%b = arith.extsi %rhs : i8 to i32
%r = arith.maxui %a, %b : i32
return %r : i32
}
//===----------------------------------------------------------------------===//
// arith.minsi
//===----------------------------------------------------------------------===//
// CHECK-LABEL: func.func @minsi_extsi_i8
// CHECK-SAME: (%[[LHS:.+]]: i8, %[[RHS:.+]]: i8)
// CHECK-NEXT: %[[min:.+]] = arith.minsi %[[LHS]], %[[RHS]] : i8
// CHECK-NEXT: %[[RET:.+]] = arith.extsi %[[min]] : i8 to i32
// CHECK-NEXT: return %[[RET]] : i32
func.func @minsi_extsi_i8(%lhs: i8, %rhs: i8) -> i32 {
%a = arith.extsi %lhs : i8 to i32
%b = arith.extsi %rhs : i8 to i32
%r = arith.minsi %a, %b : i32
return %r : i32
}
// This patterns should only apply to `arith.minsi` ops with sign-extended
// arguments.
//
// CHECK-LABEL: func.func @minsi_extui_i8
// CHECK-SAME: (%[[ARG0:.+]]: i8, %[[ARG1:.+]]: i8)
// CHECK-NEXT: %[[EXT0:.+]] = arith.extui %[[ARG0]] : i8 to i32
// CHECK-NEXT: %[[EXT1:.+]] = arith.extui %[[ARG1]] : i8 to i32
// CHECK-NEXT: %[[min:.+]] = arith.minsi %[[EXT0]], %[[EXT1]] : i32
// CHECK-NEXT: return %[[min]] : i32
func.func @minsi_extui_i8(%lhs: i8, %rhs: i8) -> i32 {
%a = arith.extui %lhs : i8 to i32
%b = arith.extui %rhs : i8 to i32
%r = arith.minsi %a, %b : i32
return %r : i32
}
//===----------------------------------------------------------------------===//
// arith.minui
//===----------------------------------------------------------------------===//
// CHECK-LABEL: func.func @minui_extui_i8
// CHECK-SAME: (%[[LHS:.+]]: i8, %[[RHS:.+]]: i8)
// CHECK-NEXT: %[[min:.+]] = arith.minui %[[LHS]], %[[RHS]] : i8
// CHECK-NEXT: %[[RET:.+]] = arith.extui %[[min]] : i8 to i32
// CHECK-NEXT: return %[[RET]] : i32
func.func @minui_extui_i8(%lhs: i8, %rhs: i8) -> i32 {
%a = arith.extui %lhs : i8 to i32
%b = arith.extui %rhs : i8 to i32
%r = arith.minui %a, %b : i32
return %r : i32
}
// This patterns should only apply to `arith.minsi` ops with zero-extended
// arguments.
//
// CHECK-LABEL: func.func @minui_extsi_i8
// CHECK-SAME: (%[[ARG0:.+]]: i8, %[[ARG1:.+]]: i8)
// CHECK-NEXT: %[[EXT0:.+]] = arith.extsi %[[ARG0]] : i8 to i32
// CHECK-NEXT: %[[EXT1:.+]] = arith.extsi %[[ARG1]] : i8 to i32
// CHECK-NEXT: %[[min:.+]] = arith.minui %[[EXT0]], %[[EXT1]] : i32
// CHECK-NEXT: return %[[min]] : i32
func.func @minui_extsi_i8(%lhs: i8, %rhs: i8) -> i32 {
%a = arith.extsi %lhs : i8 to i32
%b = arith.extsi %rhs : i8 to i32
%r = arith.minui %a, %b : i32
return %r : i32
}
//===----------------------------------------------------------------------===//
// Commute Extension over Vector Ops
//===----------------------------------------------------------------------===//