mirror of
https://github.com/intel/llvm.git
synced 2026-01-26 03:56:16 +08:00
[mlir][linalg][transform][python] Drop _get_op_result... from mix-ins. (#65726)
`_get_op_result_or_value` was used in mix-ins to unify the handling of op results and values. However, that function is now called in the generated constructors, such that doing so in the mix-ins is not necessary anymore.
This commit is contained in:
@@ -4,7 +4,6 @@
|
||||
|
||||
try:
|
||||
from ..ir import *
|
||||
from ._ods_common import get_op_result_or_value as _get_op_result_or_value
|
||||
from ..dialects import pdl, transform
|
||||
except ImportError as e:
|
||||
raise RuntimeError("Error loading imports from extension module") from e
|
||||
@@ -101,7 +100,7 @@ def _dispatch_mixed_values(
|
||||
static_values.append(size)
|
||||
else:
|
||||
static_values.append(ShapedType.get_dynamic_size())
|
||||
dynamic_values.append(_get_op_result_or_value(size))
|
||||
dynamic_values.append(size)
|
||||
static_values = DenseI64ArrayAttr.get(static_values)
|
||||
|
||||
return (dynamic_values, packed_values, static_values)
|
||||
@@ -204,9 +203,7 @@ class DecomposeOp:
|
||||
"""Specialization for DecomposeOp class."""
|
||||
|
||||
def __init__(self, target: Union[Operation, Value], *, loc=None, ip=None):
|
||||
super().__init__(
|
||||
pdl.OperationType.get(), _get_op_result_or_value(target), loc=loc, ip=ip
|
||||
)
|
||||
super().__init__(pdl.OperationType.get(), target, loc=loc, ip=ip)
|
||||
|
||||
|
||||
class FuseIntoContainingOp:
|
||||
@@ -277,9 +274,7 @@ class GeneralizeOp:
|
||||
"""Specialization for GeneralizeOp class."""
|
||||
|
||||
def __init__(self, target: Union[Operation, Value], *, loc=None, ip=None):
|
||||
super().__init__(
|
||||
pdl.OperationType.get(), _get_op_result_or_value(target), loc=loc, ip=ip
|
||||
)
|
||||
super().__init__(pdl.OperationType.get(), target, loc=loc, ip=ip)
|
||||
|
||||
|
||||
class InterchangeOp:
|
||||
@@ -296,7 +291,7 @@ class InterchangeOp:
|
||||
pdl_operation_type = pdl.OperationType.get()
|
||||
super().__init__(
|
||||
pdl_operation_type,
|
||||
_get_op_result_or_value(target),
|
||||
target,
|
||||
iterator_interchange=iterator_interchange,
|
||||
loc=loc,
|
||||
ip=ip,
|
||||
@@ -415,7 +410,7 @@ class MatchOp:
|
||||
loc=None,
|
||||
ip=None,
|
||||
):
|
||||
...
|
||||
...
|
||||
|
||||
@overload
|
||||
@classmethod
|
||||
@@ -428,7 +423,7 @@ class MatchOp:
|
||||
loc=None,
|
||||
ip=None,
|
||||
):
|
||||
...
|
||||
...
|
||||
|
||||
@classmethod
|
||||
def match_op_names(
|
||||
@@ -441,20 +436,20 @@ class MatchOp:
|
||||
ip=None,
|
||||
):
|
||||
if isinstance(result_type_or_target, Type):
|
||||
result_type = result_type_or_target
|
||||
target = target_or_names
|
||||
names = names_or_none
|
||||
result_type = result_type_or_target
|
||||
target = target_or_names
|
||||
names = names_or_none
|
||||
else:
|
||||
result_type = transform.AnyOpType.get()
|
||||
target = result_type_or_target
|
||||
names = target_or_names
|
||||
result_type = transform.AnyOpType.get()
|
||||
target = result_type_or_target
|
||||
names = target_or_names
|
||||
|
||||
if isinstance(names, str):
|
||||
names = [names]
|
||||
names = [names]
|
||||
|
||||
return cls(
|
||||
result_type,
|
||||
_get_op_result_or_value(target),
|
||||
target,
|
||||
ops=ArrayAttr.get(list(map(lambda s: StringAttr.get(s), names))),
|
||||
loc=loc,
|
||||
ip=ip,
|
||||
@@ -479,7 +474,7 @@ class MultiTileSizesOp:
|
||||
result_type,
|
||||
result_type,
|
||||
result_type,
|
||||
_get_op_result_or_value(target),
|
||||
target,
|
||||
dimension=dimension,
|
||||
target_size=target_size,
|
||||
divisor=divisor,
|
||||
@@ -530,9 +525,7 @@ class ScalarizeOp:
|
||||
|
||||
def __init__(self, target: Union[Operation, Value], *, loc=None, ip=None):
|
||||
pdl_operation_type = pdl.OperationType.get()
|
||||
super().__init__(
|
||||
pdl_operation_type, _get_op_result_or_value(target), loc=loc, ip=ip
|
||||
)
|
||||
super().__init__(pdl_operation_type, target, loc=loc, ip=ip)
|
||||
|
||||
|
||||
class SplitOp:
|
||||
@@ -552,9 +545,7 @@ class SplitOp:
|
||||
dynamic_split_point = None
|
||||
else:
|
||||
static_split_point = ShapedType.get_dynamic_size()
|
||||
dynamic_split_point = _get_op_result_or_value(split_point)
|
||||
|
||||
target = _get_op_result_or_value(target)
|
||||
dynamic_split_point = split_point
|
||||
|
||||
super().__init__(
|
||||
target.type,
|
||||
@@ -626,8 +617,6 @@ class TileOp:
|
||||
)
|
||||
target = target_or_none
|
||||
|
||||
target = _get_op_result_or_value(target)
|
||||
|
||||
super().__init__(
|
||||
target.type,
|
||||
loop_types,
|
||||
@@ -750,7 +739,7 @@ class VectorizeOp:
|
||||
pdl_operation_type = pdl.OperationType.get()
|
||||
super().__init__(
|
||||
pdl_operation_type,
|
||||
_get_op_result_or_value(target),
|
||||
target,
|
||||
disable_multi_reduction_to_contract_patterns=disable_multi_reduction_to_contract_patterns,
|
||||
disable_transfer_permutation_map_lowering_patterns=disable_transfer_permutation_map_lowering_patterns,
|
||||
vectorize_nd_extract=vectorize_nd_extract,
|
||||
|
||||
Reference in New Issue
Block a user