mirror of
https://github.com/intel/llvm.git
synced 2026-01-25 09:57:08 +08:00
[mlir][python] meta region_op (#75673)
This commit is contained in:
@@ -21,7 +21,6 @@ declare_mlir_python_sources(MLIRPythonSources.Core.Python
|
||||
_mlir_libs/__init__.py
|
||||
ir.py
|
||||
passmanager.py
|
||||
extras/types.py
|
||||
dialects/_ods_common.py
|
||||
|
||||
# The main _mlir module has submodules: include stubs from each.
|
||||
@@ -30,6 +29,14 @@ declare_mlir_python_sources(MLIRPythonSources.Core.Python
|
||||
_mlir_libs/_mlir/passmanager.pyi
|
||||
)
|
||||
|
||||
declare_mlir_python_sources(MLIRPythonSources.Core.Python.Extras
|
||||
ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
|
||||
ADD_TO_PARENT MLIRPythonSources.Core.Python
|
||||
SOURCES
|
||||
extras/types.py
|
||||
extras/meta.py
|
||||
)
|
||||
|
||||
declare_mlir_python_sources(MLIRPythonSources.ExecutionEngine
|
||||
ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
|
||||
ADD_TO_PARENT MLIRPythonSources
|
||||
|
||||
@@ -11,6 +11,8 @@ try:
|
||||
from ._ods_common import (
|
||||
get_default_loc_context as _get_default_loc_context,
|
||||
_cext as _ods_cext,
|
||||
get_op_result_or_op_results as _get_op_result_or_op_results,
|
||||
SubClassValueT as _SubClassValueT,
|
||||
)
|
||||
|
||||
from typing import Any, List, Union
|
||||
@@ -75,3 +77,9 @@ class ConstantOp(ConstantOp):
|
||||
return FloatAttr(self.value).value
|
||||
else:
|
||||
raise ValueError("only integer and float constants have literal values")
|
||||
|
||||
|
||||
def constant(
|
||||
result: Type, value: Union[int, float, Attribute], *, loc=None, ip=None
|
||||
) -> _SubClassValueT:
|
||||
return _get_op_result_or_op_results(ConstantOp(result, value, loc=loc, ip=ip))
|
||||
|
||||
@@ -2,8 +2,11 @@
|
||||
# See https://llvm.org/LICENSE.txt for license information.
|
||||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
|
||||
from typing import Dict, Optional
|
||||
|
||||
from ._builtin_ops_gen import *
|
||||
from ._builtin_ops_gen import _Dialect
|
||||
from ..extras.meta import region_op
|
||||
|
||||
try:
|
||||
from ..ir import *
|
||||
@@ -23,3 +26,23 @@ class ModuleOp(ModuleOp):
|
||||
@property
|
||||
def body(self):
|
||||
return self.regions[0].blocks[0]
|
||||
|
||||
|
||||
@region_op
|
||||
def module(
|
||||
*,
|
||||
sym_name=None,
|
||||
sym_visibility=None,
|
||||
attrs: Optional[Dict[str, Attribute]] = None,
|
||||
loc=None,
|
||||
ip=None,
|
||||
):
|
||||
mod = ModuleOp.__base__(
|
||||
sym_name=sym_name, sym_visibility=sym_visibility, loc=loc, ip=ip
|
||||
)
|
||||
if attrs is None:
|
||||
attrs = {}
|
||||
for attr_name, attr in attrs.items():
|
||||
mod.operation.attributes[attr_name] = attr
|
||||
|
||||
return mod
|
||||
|
||||
@@ -243,6 +243,9 @@ class FuncOp(FuncOp):
|
||||
return decorator
|
||||
|
||||
|
||||
func = FuncOp.from_py_func
|
||||
|
||||
|
||||
@_ods_cext.register_operation(_Dialect, replace=True)
|
||||
class CallOp(CallOp):
|
||||
"""Specialization for the call op class."""
|
||||
|
||||
@@ -5,6 +5,7 @@
|
||||
from ._pdl_ops_gen import *
|
||||
from ._pdl_ops_gen import _Dialect
|
||||
from .._mlir_libs._mlirDialectsPDL import *
|
||||
from .._mlir_libs._mlirDialectsPDL import OperationType
|
||||
|
||||
|
||||
try:
|
||||
@@ -13,7 +14,7 @@ try:
|
||||
except ImportError as e:
|
||||
raise RuntimeError("Error loading imports from extension module") from e
|
||||
|
||||
from typing import Union, Optional, Sequence, Mapping
|
||||
from typing import Union, Optional, Sequence, Mapping, NewType
|
||||
from ._ods_common import (
|
||||
get_op_result_or_value as _get_value,
|
||||
get_op_results_or_values as _get_values,
|
||||
@@ -220,3 +221,10 @@ class TypesOp(TypesOp):
|
||||
constantTypes = []
|
||||
result = pdl.RangeType.get(pdl.TypeType.get())
|
||||
super().__init__(result, constantTypes=constantTypes, loc=loc, ip=ip)
|
||||
|
||||
|
||||
OperationTypeT = NewType("OperationType", OperationType)
|
||||
|
||||
|
||||
def op_t() -> OperationTypeT:
|
||||
return OperationTypeT(OperationType.get())
|
||||
|
||||
@@ -120,7 +120,7 @@ def for_(
|
||||
params = [start, stop, step]
|
||||
for i, p in enumerate(params):
|
||||
if isinstance(p, int):
|
||||
p = constant(IntegerAttr.get(IndexType.get(), p))
|
||||
p = constant(IndexType.get(), p)
|
||||
elif isinstance(p, float):
|
||||
raise ValueError(f"{p=} must be int.")
|
||||
params[i] = p
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
|
||||
from ._tensor_ops_gen import *
|
||||
from ._tensor_ops_gen import _Dialect
|
||||
from ..extras.meta import region_op
|
||||
|
||||
try:
|
||||
from ..ir import *
|
||||
@@ -40,3 +41,9 @@ class EmptyOp(EmptyOp):
|
||||
dynamic_sizes.append(s)
|
||||
result_type = RankedTensorType.get(static_sizes, element_type)
|
||||
super().__init__(result_type, dynamic_sizes, loc=loc, ip=ip)
|
||||
|
||||
|
||||
generate = region_op(
|
||||
lambda result, dynamic_extents: GenerateOp(result, dynamic_extents),
|
||||
terminator=lambda args: YieldOp(args[0]),
|
||||
)
|
||||
|
||||
@@ -18,7 +18,7 @@ try:
|
||||
except ImportError as e:
|
||||
raise RuntimeError("Error loading imports from extension module") from e
|
||||
|
||||
from typing import Optional, Sequence, Union
|
||||
from typing import Optional, Sequence, Union, NewType
|
||||
|
||||
|
||||
@_ods_cext.register_operation(_Dialect, replace=True)
|
||||
@@ -175,7 +175,7 @@ class NamedSequenceOp(NamedSequenceOp):
|
||||
result_types: Sequence[Type],
|
||||
sym_visibility=None,
|
||||
arg_attrs=None,
|
||||
res_attrs=None
|
||||
res_attrs=None,
|
||||
):
|
||||
function_type = FunctionType.get(input_types, result_types)
|
||||
super().__init__(
|
||||
@@ -183,7 +183,7 @@ class NamedSequenceOp(NamedSequenceOp):
|
||||
function_type=TypeAttr.get(function_type),
|
||||
sym_visibility=sym_visibility,
|
||||
arg_attrs=arg_attrs,
|
||||
res_attrs=res_attrs
|
||||
res_attrs=res_attrs,
|
||||
)
|
||||
self.regions[0].blocks.append(*input_types)
|
||||
|
||||
@@ -212,3 +212,10 @@ class YieldOp(YieldOp):
|
||||
if operands is None:
|
||||
operands = []
|
||||
super().__init__(_get_op_results_or_values(operands), loc=loc, ip=ip)
|
||||
|
||||
|
||||
AnyOpTypeT = NewType("AnyOpType", AnyOpType)
|
||||
|
||||
|
||||
def any_op_t() -> AnyOpTypeT:
|
||||
return AnyOpTypeT(AnyOpType.get())
|
||||
|
||||
@@ -4,8 +4,16 @@
|
||||
|
||||
from typing import Callable, Optional, Sequence, Union
|
||||
|
||||
from ....extras.meta import region_op
|
||||
from .... import ir
|
||||
from .. import AnyOpType, OperationType, NamedSequenceOp, YieldOp
|
||||
from .. import (
|
||||
AnyOpType,
|
||||
OperationType,
|
||||
NamedSequenceOp,
|
||||
YieldOp,
|
||||
SequenceOp,
|
||||
ApplyPatternsOp,
|
||||
)
|
||||
from .. import structured
|
||||
|
||||
|
||||
@@ -147,3 +155,8 @@ def insert_transform_script(
|
||||
|
||||
if dump_script:
|
||||
print(named_sequence_op)
|
||||
|
||||
|
||||
sequence = region_op(SequenceOp.__base__, terminator=YieldOp)
|
||||
named_sequence = region_op(NamedSequenceOp, terminator=YieldOp)
|
||||
apply_patterns = region_op(ApplyPatternsOp)
|
||||
|
||||
83
mlir/python/mlir/extras/meta.py
Normal file
83
mlir/python/mlir/extras/meta.py
Normal file
@@ -0,0 +1,83 @@
|
||||
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
# See https://llvm.org/LICENSE.txt for license information.
|
||||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
|
||||
import inspect
|
||||
from functools import wraps
|
||||
|
||||
from ..dialects._ods_common import get_op_result_or_op_results
|
||||
from ..ir import Type, InsertionPoint
|
||||
|
||||
|
||||
def op_region_builder(op, op_region, terminator=None):
|
||||
def builder_wrapper(body_builder):
|
||||
# Add a block with block args having types determined by type hints on the wrapped function.
|
||||
if len(op_region.blocks) == 0:
|
||||
sig = inspect.signature(body_builder)
|
||||
types = [p.annotation for p in sig.parameters.values()]
|
||||
if not (
|
||||
len(types) == len(sig.parameters)
|
||||
and all(isinstance(t, Type) for t in types)
|
||||
):
|
||||
raise ValueError(
|
||||
f"for {body_builder=} either missing a type annotation or type annotation isn't a mlir type: {sig}"
|
||||
)
|
||||
|
||||
op_region.blocks.append(*types)
|
||||
|
||||
with InsertionPoint(op_region.blocks[0]):
|
||||
results = body_builder(*list(op_region.blocks[0].arguments))
|
||||
|
||||
with InsertionPoint(list(op_region.blocks)[-1]):
|
||||
if terminator is not None:
|
||||
res = []
|
||||
if isinstance(results, (tuple, list)):
|
||||
res.extend(results)
|
||||
elif results is not None:
|
||||
res.append(results)
|
||||
terminator(res)
|
||||
|
||||
return get_op_result_or_op_results(op)
|
||||
|
||||
return builder_wrapper
|
||||
|
||||
|
||||
def region_op(op_constructor, terminator=None):
|
||||
"""Decorator to define an MLIR Op specified as a python function.
|
||||
|
||||
Requires that an `mlir.ir.InsertionPoint` and `mlir.ir.Location` are
|
||||
active for the current thread (i.e. established in a `with` block).
|
||||
|
||||
Supports "naked" usage i.e., no parens if no args need to be passed to the Op constructor.
|
||||
|
||||
When applied as a decorator to a Python function, an entry block will
|
||||
be constructed for the Op with types as specified **as type hints on the args of the function**.
|
||||
The block arguments will be passed positionally to the Python function.
|
||||
|
||||
If a terminator is specified then the return from the decorated function will be passed
|
||||
to the terminator as the last statement in the entry block. Note, the API for the terminator
|
||||
is a (possibly empty) list; terminator accepting single values should be wrapped in a
|
||||
`lambda args: term(args[0])`
|
||||
|
||||
The identifier (name) of the function will become:
|
||||
1. A single value result if the Op returns a single value;
|
||||
2. An OpResultList (as a list) if the Op returns multiple values;
|
||||
3. The Operation if the Op returns no results.
|
||||
|
||||
See examples in tensor.py and transform.extras.
|
||||
"""
|
||||
|
||||
def op_decorator(*args, **kwargs):
|
||||
op = op_constructor(*args, **kwargs)
|
||||
op_region = op.regions[0]
|
||||
|
||||
return op_region_builder(op, op_region, terminator)
|
||||
|
||||
@wraps(op_decorator)
|
||||
def maybe_no_args(*args, **kwargs):
|
||||
if len(args) == 1 and len(kwargs) == 0 and callable(args[0]):
|
||||
return op_decorator()(args[0])
|
||||
else:
|
||||
return op_decorator(*args, **kwargs)
|
||||
|
||||
return maybe_no_args
|
||||
Reference in New Issue
Block a user