mirror of
https://github.com/intel/llvm.git
synced 2026-01-20 01:58:44 +08:00
[mlir][linalg] Update OpDSL to use the newly introduced min and max ops.
Implement min and max using the newly introduced std operations instead of relying on compare and select. Reviewed By: dcaballe Differential Revision: https://reviews.llvm.org/D111170
This commit is contained in:
@@ -319,20 +319,16 @@ class _BodyBuilder:
|
||||
|
||||
def _eval_max(self, lhs: Value, rhs: Value) -> Value:
|
||||
if _is_floating_point_type(lhs.type):
|
||||
ogt_attr = IntegerAttr.get(IntegerType.get_signless(64), 2)
|
||||
return _emit_cmpf_and_select(lhs, rhs, ogt_attr)
|
||||
return std.MaxFOp(lhs.type, lhs, rhs).result
|
||||
if _is_integer_type(lhs.type) or _is_index_type(lhs.type):
|
||||
sgt_attr = IntegerAttr.get(IntegerType.get_signless(64), 4)
|
||||
return _emit_cmpi_and_select(lhs, rhs, sgt_attr)
|
||||
return std.MaxSIOp(lhs.type, lhs, rhs).result
|
||||
raise NotImplementedError("Unsupported 'max' operand: {lhs}")
|
||||
|
||||
def _eval_min(self, lhs: Value, rhs: Value) -> Value:
|
||||
if _is_floating_point_type(lhs.type):
|
||||
olt_attr = IntegerAttr.get(IntegerType.get_signless(64), 4)
|
||||
return _emit_cmpf_and_select(lhs, rhs, olt_attr)
|
||||
return std.MinFOp(lhs.type, lhs, rhs).result
|
||||
if _is_integer_type(lhs.type) or _is_index_type(lhs.type):
|
||||
slt_attr = IntegerAttr.get(IntegerType.get_signless(64), 2)
|
||||
return _emit_cmpi_and_select(lhs, rhs, slt_attr)
|
||||
return std.MinSIOp(lhs.type, lhs, rhs).result
|
||||
raise NotImplementedError("Unsupported 'min' operand: {lhs}")
|
||||
|
||||
|
||||
@@ -413,13 +409,3 @@ def _get_floating_point_width(t: Type) -> int:
|
||||
if BF16Type.isinstance(t):
|
||||
return 16
|
||||
raise NotImplementedError(f"Unhandled floating point type switch {t}")
|
||||
|
||||
|
||||
def _emit_cmpf_and_select(lhs: Value, rhs: Value, pred: IntegerAttr) -> Value:
|
||||
cond = std.CmpFOp(IntegerType.get_signless(1), pred, lhs, rhs).result
|
||||
return std.SelectOp(lhs.type, cond, lhs, rhs).result
|
||||
|
||||
|
||||
def _emit_cmpi_and_select(lhs: Value, rhs: Value, pred: IntegerAttr) -> Value:
|
||||
cond = std.CmpIOp(IntegerType.get_signless(1), pred, lhs, rhs).result
|
||||
return std.SelectOp(lhs.type, cond, lhs, rhs).result
|
||||
|
||||
Reference in New Issue
Block a user