[mlir][linalg][transform][python] Drop _get_op_result... from mix-ins. (#65726)

`_get_op_result_or_value` was used in mix-ins to unify the handling of
op results and values. However, that function is now called in the
generated constructors, such that doing so in the mix-ins is not
necessary anymore.
This commit is contained in:
Ingo Müller
2023-09-14 17:24:16 +02:00
committed by GitHub
parent ddc3346a6b
commit 360c629024

View File

@@ -4,7 +4,6 @@
try:
from ..ir import *
from ._ods_common import get_op_result_or_value as _get_op_result_or_value
from ..dialects import pdl, transform
except ImportError as e:
raise RuntimeError("Error loading imports from extension module") from e
@@ -101,7 +100,7 @@ def _dispatch_mixed_values(
static_values.append(size)
else:
static_values.append(ShapedType.get_dynamic_size())
dynamic_values.append(_get_op_result_or_value(size))
dynamic_values.append(size)
static_values = DenseI64ArrayAttr.get(static_values)
return (dynamic_values, packed_values, static_values)
@@ -204,9 +203,7 @@ class DecomposeOp:
"""Specialization for DecomposeOp class."""
def __init__(self, target: Union[Operation, Value], *, loc=None, ip=None):
super().__init__(
pdl.OperationType.get(), _get_op_result_or_value(target), loc=loc, ip=ip
)
super().__init__(pdl.OperationType.get(), target, loc=loc, ip=ip)
class FuseIntoContainingOp:
@@ -277,9 +274,7 @@ class GeneralizeOp:
"""Specialization for GeneralizeOp class."""
def __init__(self, target: Union[Operation, Value], *, loc=None, ip=None):
super().__init__(
pdl.OperationType.get(), _get_op_result_or_value(target), loc=loc, ip=ip
)
super().__init__(pdl.OperationType.get(), target, loc=loc, ip=ip)
class InterchangeOp:
@@ -296,7 +291,7 @@ class InterchangeOp:
pdl_operation_type = pdl.OperationType.get()
super().__init__(
pdl_operation_type,
_get_op_result_or_value(target),
target,
iterator_interchange=iterator_interchange,
loc=loc,
ip=ip,
@@ -415,7 +410,7 @@ class MatchOp:
loc=None,
ip=None,
):
...
...
@overload
@classmethod
@@ -428,7 +423,7 @@ class MatchOp:
loc=None,
ip=None,
):
...
...
@classmethod
def match_op_names(
@@ -441,20 +436,20 @@ class MatchOp:
ip=None,
):
if isinstance(result_type_or_target, Type):
result_type = result_type_or_target
target = target_or_names
names = names_or_none
result_type = result_type_or_target
target = target_or_names
names = names_or_none
else:
result_type = transform.AnyOpType.get()
target = result_type_or_target
names = target_or_names
result_type = transform.AnyOpType.get()
target = result_type_or_target
names = target_or_names
if isinstance(names, str):
names = [names]
names = [names]
return cls(
result_type,
_get_op_result_or_value(target),
target,
ops=ArrayAttr.get(list(map(lambda s: StringAttr.get(s), names))),
loc=loc,
ip=ip,
@@ -479,7 +474,7 @@ class MultiTileSizesOp:
result_type,
result_type,
result_type,
_get_op_result_or_value(target),
target,
dimension=dimension,
target_size=target_size,
divisor=divisor,
@@ -530,9 +525,7 @@ class ScalarizeOp:
def __init__(self, target: Union[Operation, Value], *, loc=None, ip=None):
pdl_operation_type = pdl.OperationType.get()
super().__init__(
pdl_operation_type, _get_op_result_or_value(target), loc=loc, ip=ip
)
super().__init__(pdl_operation_type, target, loc=loc, ip=ip)
class SplitOp:
@@ -552,9 +545,7 @@ class SplitOp:
dynamic_split_point = None
else:
static_split_point = ShapedType.get_dynamic_size()
dynamic_split_point = _get_op_result_or_value(split_point)
target = _get_op_result_or_value(target)
dynamic_split_point = split_point
super().__init__(
target.type,
@@ -626,8 +617,6 @@ class TileOp:
)
target = target_or_none
target = _get_op_result_or_value(target)
super().__init__(
target.type,
loop_types,
@@ -750,7 +739,7 @@ class VectorizeOp:
pdl_operation_type = pdl.OperationType.get()
super().__init__(
pdl_operation_type,
_get_op_result_or_value(target),
target,
disable_multi_reduction_to_contract_patterns=disable_multi_reduction_to_contract_patterns,
disable_transfer_permutation_map_lowering_patterns=disable_transfer_permutation_map_lowering_patterns,
vectorize_nd_extract=vectorize_nd_extract,