[mlir][python] meta region_op (#75673)

This commit is contained in:
Maksim Levental
2023-12-21 11:20:29 -06:00
committed by GitHub
parent 11140cc238
commit 537b2aa264
14 changed files with 429 additions and 13 deletions

View File

@@ -21,7 +21,6 @@ declare_mlir_python_sources(MLIRPythonSources.Core.Python
_mlir_libs/__init__.py
ir.py
passmanager.py
extras/types.py
dialects/_ods_common.py
# The main _mlir module has submodules: include stubs from each.
@@ -30,6 +29,14 @@ declare_mlir_python_sources(MLIRPythonSources.Core.Python
_mlir_libs/_mlir/passmanager.pyi
)
declare_mlir_python_sources(MLIRPythonSources.Core.Python.Extras
ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
ADD_TO_PARENT MLIRPythonSources.Core.Python
SOURCES
extras/types.py
extras/meta.py
)
declare_mlir_python_sources(MLIRPythonSources.ExecutionEngine
ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
ADD_TO_PARENT MLIRPythonSources

View File

@@ -11,6 +11,8 @@ try:
from ._ods_common import (
get_default_loc_context as _get_default_loc_context,
_cext as _ods_cext,
get_op_result_or_op_results as _get_op_result_or_op_results,
SubClassValueT as _SubClassValueT,
)
from typing import Any, List, Union
@@ -75,3 +77,9 @@ class ConstantOp(ConstantOp):
return FloatAttr(self.value).value
else:
raise ValueError("only integer and float constants have literal values")
def constant(
result: Type, value: Union[int, float, Attribute], *, loc=None, ip=None
) -> _SubClassValueT:
return _get_op_result_or_op_results(ConstantOp(result, value, loc=loc, ip=ip))

View File

@@ -2,8 +2,11 @@
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
from typing import Dict, Optional
from ._builtin_ops_gen import *
from ._builtin_ops_gen import _Dialect
from ..extras.meta import region_op
try:
from ..ir import *
@@ -23,3 +26,23 @@ class ModuleOp(ModuleOp):
@property
def body(self):
return self.regions[0].blocks[0]
@region_op
def module(
*,
sym_name=None,
sym_visibility=None,
attrs: Optional[Dict[str, Attribute]] = None,
loc=None,
ip=None,
):
mod = ModuleOp.__base__(
sym_name=sym_name, sym_visibility=sym_visibility, loc=loc, ip=ip
)
if attrs is None:
attrs = {}
for attr_name, attr in attrs.items():
mod.operation.attributes[attr_name] = attr
return mod

View File

@@ -243,6 +243,9 @@ class FuncOp(FuncOp):
return decorator
func = FuncOp.from_py_func
@_ods_cext.register_operation(_Dialect, replace=True)
class CallOp(CallOp):
"""Specialization for the call op class."""

View File

@@ -5,6 +5,7 @@
from ._pdl_ops_gen import *
from ._pdl_ops_gen import _Dialect
from .._mlir_libs._mlirDialectsPDL import *
from .._mlir_libs._mlirDialectsPDL import OperationType
try:
@@ -13,7 +14,7 @@ try:
except ImportError as e:
raise RuntimeError("Error loading imports from extension module") from e
from typing import Union, Optional, Sequence, Mapping
from typing import Union, Optional, Sequence, Mapping, NewType
from ._ods_common import (
get_op_result_or_value as _get_value,
get_op_results_or_values as _get_values,
@@ -220,3 +221,10 @@ class TypesOp(TypesOp):
constantTypes = []
result = pdl.RangeType.get(pdl.TypeType.get())
super().__init__(result, constantTypes=constantTypes, loc=loc, ip=ip)
OperationTypeT = NewType("OperationType", OperationType)
def op_t() -> OperationTypeT:
return OperationTypeT(OperationType.get())

View File

@@ -120,7 +120,7 @@ def for_(
params = [start, stop, step]
for i, p in enumerate(params):
if isinstance(p, int):
p = constant(IntegerAttr.get(IndexType.get(), p))
p = constant(IndexType.get(), p)
elif isinstance(p, float):
raise ValueError(f"{p=} must be int.")
params[i] = p

View File

@@ -4,6 +4,7 @@
from ._tensor_ops_gen import *
from ._tensor_ops_gen import _Dialect
from ..extras.meta import region_op
try:
from ..ir import *
@@ -40,3 +41,9 @@ class EmptyOp(EmptyOp):
dynamic_sizes.append(s)
result_type = RankedTensorType.get(static_sizes, element_type)
super().__init__(result_type, dynamic_sizes, loc=loc, ip=ip)
generate = region_op(
lambda result, dynamic_extents: GenerateOp(result, dynamic_extents),
terminator=lambda args: YieldOp(args[0]),
)

View File

@@ -18,7 +18,7 @@ try:
except ImportError as e:
raise RuntimeError("Error loading imports from extension module") from e
from typing import Optional, Sequence, Union
from typing import Optional, Sequence, Union, NewType
@_ods_cext.register_operation(_Dialect, replace=True)
@@ -175,7 +175,7 @@ class NamedSequenceOp(NamedSequenceOp):
result_types: Sequence[Type],
sym_visibility=None,
arg_attrs=None,
res_attrs=None
res_attrs=None,
):
function_type = FunctionType.get(input_types, result_types)
super().__init__(
@@ -183,7 +183,7 @@ class NamedSequenceOp(NamedSequenceOp):
function_type=TypeAttr.get(function_type),
sym_visibility=sym_visibility,
arg_attrs=arg_attrs,
res_attrs=res_attrs
res_attrs=res_attrs,
)
self.regions[0].blocks.append(*input_types)
@@ -212,3 +212,10 @@ class YieldOp(YieldOp):
if operands is None:
operands = []
super().__init__(_get_op_results_or_values(operands), loc=loc, ip=ip)
AnyOpTypeT = NewType("AnyOpType", AnyOpType)
def any_op_t() -> AnyOpTypeT:
return AnyOpTypeT(AnyOpType.get())

View File

@@ -4,8 +4,16 @@
from typing import Callable, Optional, Sequence, Union
from ....extras.meta import region_op
from .... import ir
from .. import AnyOpType, OperationType, NamedSequenceOp, YieldOp
from .. import (
AnyOpType,
OperationType,
NamedSequenceOp,
YieldOp,
SequenceOp,
ApplyPatternsOp,
)
from .. import structured
@@ -147,3 +155,8 @@ def insert_transform_script(
if dump_script:
print(named_sequence_op)
sequence = region_op(SequenceOp.__base__, terminator=YieldOp)
named_sequence = region_op(NamedSequenceOp, terminator=YieldOp)
apply_patterns = region_op(ApplyPatternsOp)

View File

@@ -0,0 +1,83 @@
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
import inspect
from functools import wraps
from ..dialects._ods_common import get_op_result_or_op_results
from ..ir import Type, InsertionPoint
def op_region_builder(op, op_region, terminator=None):
def builder_wrapper(body_builder):
# Add a block with block args having types determined by type hints on the wrapped function.
if len(op_region.blocks) == 0:
sig = inspect.signature(body_builder)
types = [p.annotation for p in sig.parameters.values()]
if not (
len(types) == len(sig.parameters)
and all(isinstance(t, Type) for t in types)
):
raise ValueError(
f"for {body_builder=} either missing a type annotation or type annotation isn't a mlir type: {sig}"
)
op_region.blocks.append(*types)
with InsertionPoint(op_region.blocks[0]):
results = body_builder(*list(op_region.blocks[0].arguments))
with InsertionPoint(list(op_region.blocks)[-1]):
if terminator is not None:
res = []
if isinstance(results, (tuple, list)):
res.extend(results)
elif results is not None:
res.append(results)
terminator(res)
return get_op_result_or_op_results(op)
return builder_wrapper
def region_op(op_constructor, terminator=None):
"""Decorator to define an MLIR Op specified as a python function.
Requires that an `mlir.ir.InsertionPoint` and `mlir.ir.Location` are
active for the current thread (i.e. established in a `with` block).
Supports "naked" usage i.e., no parens if no args need to be passed to the Op constructor.
When applied as a decorator to a Python function, an entry block will
be constructed for the Op with types as specified **as type hints on the args of the function**.
The block arguments will be passed positionally to the Python function.
If a terminator is specified then the return from the decorated function will be passed
to the terminator as the last statement in the entry block. Note, the API for the terminator
is a (possibly empty) list; terminator accepting single values should be wrapped in a
`lambda args: term(args[0])`
The identifier (name) of the function will become:
1. A single value result if the Op returns a single value;
2. An OpResultList (as a list) if the Op returns multiple values;
3. The Operation if the Op returns no results.
See examples in tensor.py and transform.extras.
"""
def op_decorator(*args, **kwargs):
op = op_constructor(*args, **kwargs)
op_region = op.regions[0]
return op_region_builder(op, op_region, terminator)
@wraps(op_decorator)
def maybe_no_args(*args, **kwargs):
if len(args) == 1 and len(kwargs) == 0 and callable(args[0]):
return op_decorator()(args[0])
else:
return op_decorator(*args, **kwargs)
return maybe_no_args