[mlir][linalg] Update OpDSL to use the newly introduced min and max ops.

Implement min and max using the newly introduced std operations instead of relying on compare and select.

Reviewed By: dcaballe

Differential Revision: https://reviews.llvm.org/D111170
This commit is contained in:
Tobias Gysi
2021-10-06 06:45:42 +00:00
parent 24af1ba605
commit a744c7e962
5 changed files with 23 additions and 52 deletions

View File

@@ -319,20 +319,16 @@ class _BodyBuilder:
def _eval_max(self, lhs: Value, rhs: Value) -> Value:
if _is_floating_point_type(lhs.type):
ogt_attr = IntegerAttr.get(IntegerType.get_signless(64), 2)
return _emit_cmpf_and_select(lhs, rhs, ogt_attr)
return std.MaxFOp(lhs.type, lhs, rhs).result
if _is_integer_type(lhs.type) or _is_index_type(lhs.type):
sgt_attr = IntegerAttr.get(IntegerType.get_signless(64), 4)
return _emit_cmpi_and_select(lhs, rhs, sgt_attr)
return std.MaxSIOp(lhs.type, lhs, rhs).result
raise NotImplementedError("Unsupported 'max' operand: {lhs}")
def _eval_min(self, lhs: Value, rhs: Value) -> Value:
if _is_floating_point_type(lhs.type):
olt_attr = IntegerAttr.get(IntegerType.get_signless(64), 4)
return _emit_cmpf_and_select(lhs, rhs, olt_attr)
return std.MinFOp(lhs.type, lhs, rhs).result
if _is_integer_type(lhs.type) or _is_index_type(lhs.type):
slt_attr = IntegerAttr.get(IntegerType.get_signless(64), 2)
return _emit_cmpi_and_select(lhs, rhs, slt_attr)
return std.MinSIOp(lhs.type, lhs, rhs).result
raise NotImplementedError("Unsupported 'min' operand: {lhs}")
@@ -413,13 +409,3 @@ def _get_floating_point_width(t: Type) -> int:
if BF16Type.isinstance(t):
return 16
raise NotImplementedError(f"Unhandled floating point type switch {t}")
def _emit_cmpf_and_select(lhs: Value, rhs: Value, pred: IntegerAttr) -> Value:
cond = std.CmpFOp(IntegerType.get_signless(1), pred, lhs, rhs).result
return std.SelectOp(lhs.type, cond, lhs, rhs).result
def _emit_cmpi_and_select(lhs: Value, rhs: Value, pred: IntegerAttr) -> Value:
cond = std.CmpIOp(IntegerType.get_signless(1), pred, lhs, rhs).result
return std.SelectOp(lhs.type, cond, lhs, rhs).result