From 360c6290240910991bcd660297b3e615cb8f3216 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ingo=20M=C3=BCller?= Date: Thu, 14 Sep 2023 17:24:16 +0200 Subject: [PATCH] [mlir][linalg][transform][python] Drop _get_op_result... from mix-ins. (#65726) `_get_op_result_or_value` was used in mix-ins to unify the handling of op results and values. However, that function is now called in the generated constructors, such that doing so in the mix-ins is not necessary anymore. --- .../dialects/_structured_transform_ops_ext.py | 47 +++++++------------ 1 file changed, 18 insertions(+), 29 deletions(-) diff --git a/mlir/python/mlir/dialects/_structured_transform_ops_ext.py b/mlir/python/mlir/dialects/_structured_transform_ops_ext.py index f368e56f9818..212fbc5badcb 100644 --- a/mlir/python/mlir/dialects/_structured_transform_ops_ext.py +++ b/mlir/python/mlir/dialects/_structured_transform_ops_ext.py @@ -4,7 +4,6 @@ try: from ..ir import * - from ._ods_common import get_op_result_or_value as _get_op_result_or_value from ..dialects import pdl, transform except ImportError as e: raise RuntimeError("Error loading imports from extension module") from e @@ -101,7 +100,7 @@ def _dispatch_mixed_values( static_values.append(size) else: static_values.append(ShapedType.get_dynamic_size()) - dynamic_values.append(_get_op_result_or_value(size)) + dynamic_values.append(size) static_values = DenseI64ArrayAttr.get(static_values) return (dynamic_values, packed_values, static_values) @@ -204,9 +203,7 @@ class DecomposeOp: """Specialization for DecomposeOp class.""" def __init__(self, target: Union[Operation, Value], *, loc=None, ip=None): - super().__init__( - pdl.OperationType.get(), _get_op_result_or_value(target), loc=loc, ip=ip - ) + super().__init__(pdl.OperationType.get(), target, loc=loc, ip=ip) class FuseIntoContainingOp: @@ -277,9 +274,7 @@ class GeneralizeOp: """Specialization for GeneralizeOp class.""" def __init__(self, target: Union[Operation, Value], *, loc=None, ip=None): - super().__init__( - pdl.OperationType.get(), _get_op_result_or_value(target), loc=loc, ip=ip - ) + super().__init__(pdl.OperationType.get(), target, loc=loc, ip=ip) class InterchangeOp: @@ -296,7 +291,7 @@ class InterchangeOp: pdl_operation_type = pdl.OperationType.get() super().__init__( pdl_operation_type, - _get_op_result_or_value(target), + target, iterator_interchange=iterator_interchange, loc=loc, ip=ip, @@ -415,7 +410,7 @@ class MatchOp: loc=None, ip=None, ): - ... + ... @overload @classmethod @@ -428,7 +423,7 @@ class MatchOp: loc=None, ip=None, ): - ... + ... @classmethod def match_op_names( @@ -441,20 +436,20 @@ class MatchOp: ip=None, ): if isinstance(result_type_or_target, Type): - result_type = result_type_or_target - target = target_or_names - names = names_or_none + result_type = result_type_or_target + target = target_or_names + names = names_or_none else: - result_type = transform.AnyOpType.get() - target = result_type_or_target - names = target_or_names + result_type = transform.AnyOpType.get() + target = result_type_or_target + names = target_or_names if isinstance(names, str): - names = [names] + names = [names] return cls( result_type, - _get_op_result_or_value(target), + target, ops=ArrayAttr.get(list(map(lambda s: StringAttr.get(s), names))), loc=loc, ip=ip, @@ -479,7 +474,7 @@ class MultiTileSizesOp: result_type, result_type, result_type, - _get_op_result_or_value(target), + target, dimension=dimension, target_size=target_size, divisor=divisor, @@ -530,9 +525,7 @@ class ScalarizeOp: def __init__(self, target: Union[Operation, Value], *, loc=None, ip=None): pdl_operation_type = pdl.OperationType.get() - super().__init__( - pdl_operation_type, _get_op_result_or_value(target), loc=loc, ip=ip - ) + super().__init__(pdl_operation_type, target, loc=loc, ip=ip) class SplitOp: @@ -552,9 +545,7 @@ class SplitOp: dynamic_split_point = None else: static_split_point = ShapedType.get_dynamic_size() - dynamic_split_point = _get_op_result_or_value(split_point) - - target = _get_op_result_or_value(target) + dynamic_split_point = split_point super().__init__( target.type, @@ -626,8 +617,6 @@ class TileOp: ) target = target_or_none - target = _get_op_result_or_value(target) - super().__init__( target.type, loop_types, @@ -750,7 +739,7 @@ class VectorizeOp: pdl_operation_type = pdl.OperationType.get() super().__init__( pdl_operation_type, - _get_op_result_or_value(target), + target, disable_multi_reduction_to_contract_patterns=disable_multi_reduction_to_contract_patterns, disable_transfer_permutation_map_lowering_patterns=disable_transfer_permutation_map_lowering_patterns, vectorize_nd_extract=vectorize_nd_extract,