[mlir] switch the transform loop extension to use types

Add types to the Loop (SCF) extension of the transform dialect.

See https://discourse.llvm.org/t/rfc-type-system-for-the-transform-dialect/65702

Reviewed By: nicolasvasilache

Differential Revision: https://reviews.llvm.org/D135587
This commit is contained in:
Alex Zinenko
2022-10-10 14:38:31 +00:00
parent 3e1f6d02f7
commit 59bb8af4c3
13 changed files with 84 additions and 69 deletions

View File

@@ -5,7 +5,6 @@
try:
from ..ir import *
from ._ods_common import get_op_result_or_value as _get_op_result_or_value
from ..dialects import pdl
except ImportError as e:
raise RuntimeError("Error loading imports from extension module") from e
@@ -28,13 +27,14 @@ class GetParentForOp:
"""Extension for GetParentForOp."""
def __init__(self,
result_type: Type,
target: Union[Operation, Value],
*,
num_loops: int = 1,
ip=None,
loc=None):
super().__init__(
pdl.OperationType.get(),
result_type,
_get_op_result_or_value(target),
num_loops=_get_int64_attr(num_loops, default_value=1),
ip=ip,
@@ -45,13 +45,14 @@ class LoopOutlineOp:
"""Extension for LoopOutlineOp."""
def __init__(self,
result_type: Type,
target: Union[Operation, Value],
*,
func_name: Union[str, StringAttr],
ip=None,
loc=None):
super().__init__(
pdl.OperationType.get(),
result_type,
_get_op_result_or_value(target),
func_name=(func_name if isinstance(func_name, StringAttr) else
StringAttr.get(func_name)),
@@ -63,13 +64,14 @@ class LoopPeelOp:
"""Extension for LoopPeelOp."""
def __init__(self,
result_type: Type,
target: Union[Operation, Value],
*,
fail_if_already_divisible: Union[bool, BoolAttr] = False,
ip=None,
loc=None):
super().__init__(
pdl.OperationType.get(),
result_type,
_get_op_result_or_value(target),
fail_if_already_divisible=(fail_if_already_divisible if isinstance(
fail_if_already_divisible, BoolAttr) else
@@ -82,6 +84,7 @@ class LoopPipelineOp:
"""Extension for LoopPipelineOp."""
def __init__(self,
result_type: Type,
target: Union[Operation, Value],
*,
iteration_interval: Optional[Union[int, IntegerAttr]] = None,
@@ -89,7 +92,7 @@ class LoopPipelineOp:
ip=None,
loc=None):
super().__init__(
pdl.OperationType.get(),
result_type,
_get_op_result_or_value(target),
iteration_interval=_get_int64_attr(iteration_interval, default_value=1),
read_latency=_get_int64_attr(read_latency, default_value=10),