[mlir][python] simplify extensions (#69642)

https://github.com/llvm/llvm-project/pull/68853 enabled a lot of nice
cleanup. Note, I made sure each of the touched extensions had tests.
This commit is contained in:
Maksim Levental
2023-10-19 18:07:06 -05:00
committed by GitHub
parent dda3ed9091
commit dd473f1dd1
8 changed files with 13 additions and 217 deletions

View File

@@ -3,48 +3,3 @@
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
from ._affine_ops_gen import *
from ._affine_ops_gen import _Dialect
try:
from ..ir import *
from ._ods_common import (
get_op_result_or_value as _get_op_result_or_value,
get_op_results_or_values as _get_op_results_or_values,
_cext as _ods_cext,
)
except ImportError as e:
raise RuntimeError("Error loading imports from extension module") from e
from typing import Optional, Sequence, Union
@_ods_cext.register_operation(_Dialect, replace=True)
class AffineStoreOp(AffineStoreOp):
"""Specialization for the Affine store operation."""
def __init__(
self,
value: Union[Operation, OpView, Value],
memref: Union[Operation, OpView, Value],
map: AffineMap = None,
*,
map_operands=None,
loc=None,
ip=None,
):
"""Creates an affine store operation.
- `value`: the value to store into the memref.
- `memref`: the buffer to store into.
- `map`: the affine map that maps the map_operands to the index of the
memref.
- `map_operands`: the list of arguments to substitute the dimensions,
then symbols in the affine map, in increasing order.
"""
map = map if map is not None else []
map_operands = map_operands if map_operands is not None else []
indicies = [_get_op_result_or_value(op) for op in map_operands]
_ods_successors = None
super().__init__(
value, memref, indicies, AffineMapAttr.get(map), loc=loc, ip=ip
)

View File

@@ -3,40 +3,4 @@
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
from ._bufferization_ops_gen import *
from ._bufferization_ops_gen import _Dialect
from ._bufferization_enum_gen import *
try:
from typing import Sequence, Union
from ..ir import *
from ._ods_common import get_default_loc_context, _cext as _ods_cext
from typing import Any, List, Union
except ImportError as e:
raise RuntimeError("Error loading imports from extension module") from e
@_ods_cext.register_operation(_Dialect, replace=True)
class AllocTensorOp(AllocTensorOp):
"""Extends the bufferization.alloc_tensor op."""
def __init__(
self,
tensor_type: Type,
dynamic_sizes: Sequence[Value],
copy: Value,
size_hint: Value,
escape: BoolAttr,
*,
loc=None,
ip=None,
):
"""Constructs an `alloc_tensor` with static and/or dynamic sizes."""
super().__init__(
tensor_type,
dynamic_sizes,
copy=copy,
size_hint=size_hint,
loc=loc,
ip=ip,
)

View File

@@ -26,9 +26,6 @@ RESULT_ATTRIBUTE_NAME = "res_attrs"
class ConstantOp(ConstantOp):
"""Specialization for the constant op class."""
def __init__(self, result: Type, value: Attribute, *, loc=None, ip=None):
super().__init__(result, value, loc=loc, ip=ip)
@property
def type(self):
return self.results[0].type

View File

@@ -3,41 +3,3 @@
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
from ._memref_ops_gen import *
from ._memref_ops_gen import _Dialect
try:
from ..ir import *
from ._ods_common import (
get_op_result_or_value as _get_op_result_or_value,
get_op_results_or_values as _get_op_results_or_values,
_cext as _ods_cext,
)
except ImportError as e:
raise RuntimeError("Error loading imports from extension module") from e
from typing import Optional, Sequence, Union
@_ods_cext.register_operation(_Dialect, replace=True)
class LoadOp(LoadOp):
"""Specialization for the MemRef load operation."""
def __init__(
self,
memref: Union[Operation, OpView, Value],
indices: Optional[Union[Operation, OpView, Sequence[Value]]] = None,
*,
loc=None,
ip=None,
):
"""Creates a memref load operation.
Args:
memref: the buffer to load from.
indices: the list of subscripts, may be empty for zero-dimensional
buffers.
loc: user-visible location of the operation.
ip: insertion point.
"""
indices_resolved = [] if indices is None else _get_op_results_or_values(indices)
super().__init__(memref, indices_resolved, loc=loc, ip=ip)

View File

@@ -21,43 +21,6 @@ from ._ods_common import (
)
@_ods_cext.register_operation(_Dialect, replace=True)
class ApplyNativeConstraintOp(ApplyNativeConstraintOp):
"""Specialization for PDL apply native constraint op class."""
def __init__(
self,
name: Union[str, StringAttr],
args: Optional[Sequence[Union[OpView, Operation, Value]]] = None,
*,
loc=None,
ip=None,
):
if args is None:
args = []
args = _get_values(args)
super().__init__(name, args, loc=loc, ip=ip)
@_ods_cext.register_operation(_Dialect, replace=True)
class ApplyNativeRewriteOp(ApplyNativeRewriteOp):
"""Specialization for PDL apply native rewrite op class."""
def __init__(
self,
results: Sequence[Type],
name: Union[str, StringAttr],
args: Optional[Sequence[Union[OpView, Operation, Value]]] = None,
*,
loc=None,
ip=None,
):
if args is None:
args = []
args = _get_values(args)
super().__init__(results, name, args, loc=loc, ip=ip)
@_ods_cext.register_operation(_Dialect, replace=True)
class AttributeOp(AttributeOp):
"""Specialization for PDL attribute op class."""
@@ -75,21 +38,6 @@ class AttributeOp(AttributeOp):
super().__init__(result, valueType=valueType, value=value, loc=loc, ip=ip)
@_ods_cext.register_operation(_Dialect, replace=True)
class EraseOp(EraseOp):
"""Specialization for PDL erase op class."""
def __init__(
self,
operation: Optional[Union[OpView, Operation, Value]] = None,
*,
loc=None,
ip=None,
):
operation = _get_value(operation)
super().__init__(operation, loc=loc, ip=ip)
@_ods_cext.register_operation(_Dialect, replace=True)
class OperandOp(OperandOp):
"""Specialization for PDL operand op class."""
@@ -216,23 +164,6 @@ class ResultOp(ResultOp):
super().__init__(result, parent, index, loc=loc, ip=ip)
@_ods_cext.register_operation(_Dialect, replace=True)
class ResultsOp(ResultsOp):
"""Specialization for PDL results op class."""
def __init__(
self,
result: Type,
parent: Union[OpView, Operation, Value],
index: Optional[Union[IntegerAttr, int]] = None,
*,
loc=None,
ip=None,
):
parent = _get_value(parent)
super().__init__(result, parent, index=index, loc=loc, ip=ip)
@_ods_cext.register_operation(_Dialect, replace=True)
class RewriteOp(RewriteOp):
"""Specialization for PDL rewrite op class."""

View File

@@ -20,11 +20,8 @@ except ImportError as e:
from typing import Optional, Sequence, Union
_ForOp = ForOp
@_ods_cext.register_operation(_Dialect, replace=True)
class ForOp(_ForOp):
class ForOp(ForOp):
"""Specialization for the SCF for op class."""
def __init__(
@@ -50,17 +47,8 @@ class ForOp(_ForOp):
iter_args = _get_op_results_or_values(iter_args)
results = [arg.type for arg in iter_args]
super(_ForOp, self).__init__(
self.build_generic(
regions=1,
results=results,
operands=[
_get_op_result_or_value(o) for o in [lower_bound, upper_bound, step]
]
+ list(iter_args),
loc=loc,
ip=ip,
)
super().__init__(
results, lower_bound, upper_bound, step, iter_args, loc=loc, ip=ip
)
self.regions[0].blocks.append(self.operands[0].type, *results)
@@ -83,28 +71,23 @@ class ForOp(_ForOp):
return self.body.arguments[1:]
_IfOp = IfOp
@_ods_cext.register_operation(_Dialect, replace=True)
class IfOp(_IfOp):
class IfOp(IfOp):
"""Specialization for the SCF if op class."""
def __init__(self, cond, results_=[], *, hasElse=False, loc=None, ip=None):
def __init__(self, cond, results_=None, *, hasElse=False, loc=None, ip=None):
"""Creates an SCF `if` operation.
- `cond` is a MLIR value of 'i1' type to determine which regions of code will be executed.
- `hasElse` determines whether the if operation has the else branch.
"""
if results_ is None:
results_ = []
operands = []
operands.append(cond)
results = []
results.extend(results_)
super(_IfOp, self).__init__(
self.build_generic(
regions=2, results=results, operands=operands, loc=loc, ip=ip
)
)
super().__init__(results, cond)
self.regions[0].blocks.append(*[])
if hasElse:
self.regions[1].blocks.append(*[])