mirror of
https://github.com/intel/llvm.git
synced 2026-01-25 10:55:58 +08:00
[mlir][python] simplify extensions (#69642)
https://github.com/llvm/llvm-project/pull/68853 enabled a lot of nice cleanup. Note, I made sure each of the touched extensions had tests.
This commit is contained in:
@@ -3,48 +3,3 @@
|
||||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
|
||||
from ._affine_ops_gen import *
|
||||
from ._affine_ops_gen import _Dialect
|
||||
|
||||
try:
|
||||
from ..ir import *
|
||||
from ._ods_common import (
|
||||
get_op_result_or_value as _get_op_result_or_value,
|
||||
get_op_results_or_values as _get_op_results_or_values,
|
||||
_cext as _ods_cext,
|
||||
)
|
||||
except ImportError as e:
|
||||
raise RuntimeError("Error loading imports from extension module") from e
|
||||
|
||||
from typing import Optional, Sequence, Union
|
||||
|
||||
|
||||
@_ods_cext.register_operation(_Dialect, replace=True)
|
||||
class AffineStoreOp(AffineStoreOp):
|
||||
"""Specialization for the Affine store operation."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
value: Union[Operation, OpView, Value],
|
||||
memref: Union[Operation, OpView, Value],
|
||||
map: AffineMap = None,
|
||||
*,
|
||||
map_operands=None,
|
||||
loc=None,
|
||||
ip=None,
|
||||
):
|
||||
"""Creates an affine store operation.
|
||||
|
||||
- `value`: the value to store into the memref.
|
||||
- `memref`: the buffer to store into.
|
||||
- `map`: the affine map that maps the map_operands to the index of the
|
||||
memref.
|
||||
- `map_operands`: the list of arguments to substitute the dimensions,
|
||||
then symbols in the affine map, in increasing order.
|
||||
"""
|
||||
map = map if map is not None else []
|
||||
map_operands = map_operands if map_operands is not None else []
|
||||
indicies = [_get_op_result_or_value(op) for op in map_operands]
|
||||
_ods_successors = None
|
||||
super().__init__(
|
||||
value, memref, indicies, AffineMapAttr.get(map), loc=loc, ip=ip
|
||||
)
|
||||
|
||||
@@ -3,40 +3,4 @@
|
||||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
|
||||
from ._bufferization_ops_gen import *
|
||||
from ._bufferization_ops_gen import _Dialect
|
||||
from ._bufferization_enum_gen import *
|
||||
|
||||
try:
|
||||
from typing import Sequence, Union
|
||||
from ..ir import *
|
||||
from ._ods_common import get_default_loc_context, _cext as _ods_cext
|
||||
|
||||
from typing import Any, List, Union
|
||||
except ImportError as e:
|
||||
raise RuntimeError("Error loading imports from extension module") from e
|
||||
|
||||
|
||||
@_ods_cext.register_operation(_Dialect, replace=True)
|
||||
class AllocTensorOp(AllocTensorOp):
|
||||
"""Extends the bufferization.alloc_tensor op."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
tensor_type: Type,
|
||||
dynamic_sizes: Sequence[Value],
|
||||
copy: Value,
|
||||
size_hint: Value,
|
||||
escape: BoolAttr,
|
||||
*,
|
||||
loc=None,
|
||||
ip=None,
|
||||
):
|
||||
"""Constructs an `alloc_tensor` with static and/or dynamic sizes."""
|
||||
super().__init__(
|
||||
tensor_type,
|
||||
dynamic_sizes,
|
||||
copy=copy,
|
||||
size_hint=size_hint,
|
||||
loc=loc,
|
||||
ip=ip,
|
||||
)
|
||||
|
||||
@@ -26,9 +26,6 @@ RESULT_ATTRIBUTE_NAME = "res_attrs"
|
||||
class ConstantOp(ConstantOp):
|
||||
"""Specialization for the constant op class."""
|
||||
|
||||
def __init__(self, result: Type, value: Attribute, *, loc=None, ip=None):
|
||||
super().__init__(result, value, loc=loc, ip=ip)
|
||||
|
||||
@property
|
||||
def type(self):
|
||||
return self.results[0].type
|
||||
|
||||
@@ -3,41 +3,3 @@
|
||||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
|
||||
from ._memref_ops_gen import *
|
||||
from ._memref_ops_gen import _Dialect
|
||||
|
||||
try:
|
||||
from ..ir import *
|
||||
from ._ods_common import (
|
||||
get_op_result_or_value as _get_op_result_or_value,
|
||||
get_op_results_or_values as _get_op_results_or_values,
|
||||
_cext as _ods_cext,
|
||||
)
|
||||
except ImportError as e:
|
||||
raise RuntimeError("Error loading imports from extension module") from e
|
||||
|
||||
from typing import Optional, Sequence, Union
|
||||
|
||||
|
||||
@_ods_cext.register_operation(_Dialect, replace=True)
|
||||
class LoadOp(LoadOp):
|
||||
"""Specialization for the MemRef load operation."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
memref: Union[Operation, OpView, Value],
|
||||
indices: Optional[Union[Operation, OpView, Sequence[Value]]] = None,
|
||||
*,
|
||||
loc=None,
|
||||
ip=None,
|
||||
):
|
||||
"""Creates a memref load operation.
|
||||
|
||||
Args:
|
||||
memref: the buffer to load from.
|
||||
indices: the list of subscripts, may be empty for zero-dimensional
|
||||
buffers.
|
||||
loc: user-visible location of the operation.
|
||||
ip: insertion point.
|
||||
"""
|
||||
indices_resolved = [] if indices is None else _get_op_results_or_values(indices)
|
||||
super().__init__(memref, indices_resolved, loc=loc, ip=ip)
|
||||
|
||||
@@ -21,43 +21,6 @@ from ._ods_common import (
|
||||
)
|
||||
|
||||
|
||||
@_ods_cext.register_operation(_Dialect, replace=True)
|
||||
class ApplyNativeConstraintOp(ApplyNativeConstraintOp):
|
||||
"""Specialization for PDL apply native constraint op class."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
name: Union[str, StringAttr],
|
||||
args: Optional[Sequence[Union[OpView, Operation, Value]]] = None,
|
||||
*,
|
||||
loc=None,
|
||||
ip=None,
|
||||
):
|
||||
if args is None:
|
||||
args = []
|
||||
args = _get_values(args)
|
||||
super().__init__(name, args, loc=loc, ip=ip)
|
||||
|
||||
|
||||
@_ods_cext.register_operation(_Dialect, replace=True)
|
||||
class ApplyNativeRewriteOp(ApplyNativeRewriteOp):
|
||||
"""Specialization for PDL apply native rewrite op class."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
results: Sequence[Type],
|
||||
name: Union[str, StringAttr],
|
||||
args: Optional[Sequence[Union[OpView, Operation, Value]]] = None,
|
||||
*,
|
||||
loc=None,
|
||||
ip=None,
|
||||
):
|
||||
if args is None:
|
||||
args = []
|
||||
args = _get_values(args)
|
||||
super().__init__(results, name, args, loc=loc, ip=ip)
|
||||
|
||||
|
||||
@_ods_cext.register_operation(_Dialect, replace=True)
|
||||
class AttributeOp(AttributeOp):
|
||||
"""Specialization for PDL attribute op class."""
|
||||
@@ -75,21 +38,6 @@ class AttributeOp(AttributeOp):
|
||||
super().__init__(result, valueType=valueType, value=value, loc=loc, ip=ip)
|
||||
|
||||
|
||||
@_ods_cext.register_operation(_Dialect, replace=True)
|
||||
class EraseOp(EraseOp):
|
||||
"""Specialization for PDL erase op class."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
operation: Optional[Union[OpView, Operation, Value]] = None,
|
||||
*,
|
||||
loc=None,
|
||||
ip=None,
|
||||
):
|
||||
operation = _get_value(operation)
|
||||
super().__init__(operation, loc=loc, ip=ip)
|
||||
|
||||
|
||||
@_ods_cext.register_operation(_Dialect, replace=True)
|
||||
class OperandOp(OperandOp):
|
||||
"""Specialization for PDL operand op class."""
|
||||
@@ -216,23 +164,6 @@ class ResultOp(ResultOp):
|
||||
super().__init__(result, parent, index, loc=loc, ip=ip)
|
||||
|
||||
|
||||
@_ods_cext.register_operation(_Dialect, replace=True)
|
||||
class ResultsOp(ResultsOp):
|
||||
"""Specialization for PDL results op class."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
result: Type,
|
||||
parent: Union[OpView, Operation, Value],
|
||||
index: Optional[Union[IntegerAttr, int]] = None,
|
||||
*,
|
||||
loc=None,
|
||||
ip=None,
|
||||
):
|
||||
parent = _get_value(parent)
|
||||
super().__init__(result, parent, index=index, loc=loc, ip=ip)
|
||||
|
||||
|
||||
@_ods_cext.register_operation(_Dialect, replace=True)
|
||||
class RewriteOp(RewriteOp):
|
||||
"""Specialization for PDL rewrite op class."""
|
||||
|
||||
@@ -20,11 +20,8 @@ except ImportError as e:
|
||||
from typing import Optional, Sequence, Union
|
||||
|
||||
|
||||
_ForOp = ForOp
|
||||
|
||||
|
||||
@_ods_cext.register_operation(_Dialect, replace=True)
|
||||
class ForOp(_ForOp):
|
||||
class ForOp(ForOp):
|
||||
"""Specialization for the SCF for op class."""
|
||||
|
||||
def __init__(
|
||||
@@ -50,17 +47,8 @@ class ForOp(_ForOp):
|
||||
iter_args = _get_op_results_or_values(iter_args)
|
||||
|
||||
results = [arg.type for arg in iter_args]
|
||||
super(_ForOp, self).__init__(
|
||||
self.build_generic(
|
||||
regions=1,
|
||||
results=results,
|
||||
operands=[
|
||||
_get_op_result_or_value(o) for o in [lower_bound, upper_bound, step]
|
||||
]
|
||||
+ list(iter_args),
|
||||
loc=loc,
|
||||
ip=ip,
|
||||
)
|
||||
super().__init__(
|
||||
results, lower_bound, upper_bound, step, iter_args, loc=loc, ip=ip
|
||||
)
|
||||
self.regions[0].blocks.append(self.operands[0].type, *results)
|
||||
|
||||
@@ -83,28 +71,23 @@ class ForOp(_ForOp):
|
||||
return self.body.arguments[1:]
|
||||
|
||||
|
||||
_IfOp = IfOp
|
||||
|
||||
|
||||
@_ods_cext.register_operation(_Dialect, replace=True)
|
||||
class IfOp(_IfOp):
|
||||
class IfOp(IfOp):
|
||||
"""Specialization for the SCF if op class."""
|
||||
|
||||
def __init__(self, cond, results_=[], *, hasElse=False, loc=None, ip=None):
|
||||
def __init__(self, cond, results_=None, *, hasElse=False, loc=None, ip=None):
|
||||
"""Creates an SCF `if` operation.
|
||||
|
||||
- `cond` is a MLIR value of 'i1' type to determine which regions of code will be executed.
|
||||
- `hasElse` determines whether the if operation has the else branch.
|
||||
"""
|
||||
if results_ is None:
|
||||
results_ = []
|
||||
operands = []
|
||||
operands.append(cond)
|
||||
results = []
|
||||
results.extend(results_)
|
||||
super(_IfOp, self).__init__(
|
||||
self.build_generic(
|
||||
regions=2, results=results, operands=operands, loc=loc, ip=ip
|
||||
)
|
||||
)
|
||||
super().__init__(results, cond)
|
||||
self.regions[0].blocks.append(*[])
|
||||
if hasElse:
|
||||
self.regions[1].blocks.append(*[])
|
||||
|
||||
Reference in New Issue
Block a user