mirror of
https://github.com/intel/llvm.git
synced 2026-01-25 10:55:58 +08:00
[mlir] switch the transform loop extension to use types
Add types to the Loop (SCF) extension of the transform dialect. See https://discourse.llvm.org/t/rfc-type-system-for-the-transform-dialect/65702 Reviewed By: nicolasvasilache Differential Revision: https://reviews.llvm.org/D135587
This commit is contained in:
@@ -5,7 +5,6 @@
|
||||
try:
|
||||
from ..ir import *
|
||||
from ._ods_common import get_op_result_or_value as _get_op_result_or_value
|
||||
from ..dialects import pdl
|
||||
except ImportError as e:
|
||||
raise RuntimeError("Error loading imports from extension module") from e
|
||||
|
||||
@@ -28,13 +27,14 @@ class GetParentForOp:
|
||||
"""Extension for GetParentForOp."""
|
||||
|
||||
def __init__(self,
|
||||
result_type: Type,
|
||||
target: Union[Operation, Value],
|
||||
*,
|
||||
num_loops: int = 1,
|
||||
ip=None,
|
||||
loc=None):
|
||||
super().__init__(
|
||||
pdl.OperationType.get(),
|
||||
result_type,
|
||||
_get_op_result_or_value(target),
|
||||
num_loops=_get_int64_attr(num_loops, default_value=1),
|
||||
ip=ip,
|
||||
@@ -45,13 +45,14 @@ class LoopOutlineOp:
|
||||
"""Extension for LoopOutlineOp."""
|
||||
|
||||
def __init__(self,
|
||||
result_type: Type,
|
||||
target: Union[Operation, Value],
|
||||
*,
|
||||
func_name: Union[str, StringAttr],
|
||||
ip=None,
|
||||
loc=None):
|
||||
super().__init__(
|
||||
pdl.OperationType.get(),
|
||||
result_type,
|
||||
_get_op_result_or_value(target),
|
||||
func_name=(func_name if isinstance(func_name, StringAttr) else
|
||||
StringAttr.get(func_name)),
|
||||
@@ -63,13 +64,14 @@ class LoopPeelOp:
|
||||
"""Extension for LoopPeelOp."""
|
||||
|
||||
def __init__(self,
|
||||
result_type: Type,
|
||||
target: Union[Operation, Value],
|
||||
*,
|
||||
fail_if_already_divisible: Union[bool, BoolAttr] = False,
|
||||
ip=None,
|
||||
loc=None):
|
||||
super().__init__(
|
||||
pdl.OperationType.get(),
|
||||
result_type,
|
||||
_get_op_result_or_value(target),
|
||||
fail_if_already_divisible=(fail_if_already_divisible if isinstance(
|
||||
fail_if_already_divisible, BoolAttr) else
|
||||
@@ -82,6 +84,7 @@ class LoopPipelineOp:
|
||||
"""Extension for LoopPipelineOp."""
|
||||
|
||||
def __init__(self,
|
||||
result_type: Type,
|
||||
target: Union[Operation, Value],
|
||||
*,
|
||||
iteration_interval: Optional[Union[int, IntegerAttr]] = None,
|
||||
@@ -89,7 +92,7 @@ class LoopPipelineOp:
|
||||
ip=None,
|
||||
loc=None):
|
||||
super().__init__(
|
||||
pdl.OperationType.get(),
|
||||
result_type,
|
||||
_get_op_result_or_value(target),
|
||||
iteration_interval=_get_int64_attr(iteration_interval, default_value=1),
|
||||
read_latency=_get_int64_attr(read_latency, default_value=10),
|
||||
|
||||
Reference in New Issue
Block a user