[NFC][Py Reformat] Reformat python files in mlir subdir

This commit is part of an ongoing series that is reformatting our
Python code.

Reformatting is done with `black`.

If you have problems merging this commit because you have made
changes to a Python file, the best way to handle that is to run
`git checkout --ours <yourfile>` and then reformat the file
with `black`.
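
As an illustration (not taken from this commit), here is a minimal sketch of
the kind of change black makes; the function below is hypothetical:

    # Before reformatting (hypothetical example):
    def make_config(name,
                    values,
                    verbose=False):
        return {'name': name,
                'values': values,
                'verbose': verbose}

    # After running black (hypothetical example): because everything fits
    # within black's default 88-column line limit, the wrapped signature and
    # dict literal are collapsed onto single lines, and string quotes are
    # normalized to double quotes.
    def make_config(name, values, verbose=False):
        return {"name": name, "values": values, "verbose": verbose}

Signatures and calls that do not fit on one line are wrapped instead, with a
trailing comma added when each argument is placed on its own line; that
pattern appears throughout the diff below.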

If you run into any problems, post about it on Discourse and
we will try to help.

RFC Thread below:

https://discourse.llvm.org/t/rfc-document-and-standardize-python-code-style

Differential Revision: https://reviews.llvm.org/D150782
Author: Tobias Hieta
Date: 2023-05-17 16:53:39 +02:00
Parent: 9652eba3cf
Commit: f9008e6366
163 changed files with 16286 additions and 15027 deletions


@@ -25,7 +25,7 @@ from common import setup_passes
def matmul_dsl(
A=dsl.TensorDef(dsl.T, dsl.S.M, dsl.S.K),
B=dsl.TensorDef(dsl.T, dsl.S.K, dsl.S.N),
C=dsl.TensorDef(dsl.T, dsl.S.M, dsl.S.N, output=True)
C=dsl.TensorDef(dsl.T, dsl.S.M, dsl.S.N, output=True),
):
"""Helper function for mlir sparse matrix multiplication benchmark."""
C[dsl.D.m, dsl.D.n] += A[dsl.D.m, dsl.D.k] * B[dsl.D.k, dsl.D.n]
@@ -43,6 +43,7 @@ def benchmark_sparse_mlir_multiplication():
param2_type = ir.RankedTensorType.get([1500, 2000], f64)
result_type = ir.RankedTensorType.get([1000, 2000], f64)
with ir.InsertionPoint(module.body):
@func.FuncOp.from_py_func(param1_type, param2_type, result_type)
def sparse_kernel(x, y, z):
return matmul_dsl(x, y, outs=[z])
@@ -51,37 +52,34 @@ def benchmark_sparse_mlir_multiplication():
with ir.Context(), ir.Location.unknown():
kernel_func = get_kernel_func_from_module(module)
timer_func = emit_timer_func()
wrapped_func = emit_benchmark_wrapped_main_func(
kernel_func,
timer_func
)
wrapped_func = emit_benchmark_wrapped_main_func(kernel_func, timer_func)
main_module_with_benchmark = ir.Module.parse(
str(timer_func) + str(wrapped_func) + str(kernel_func)
)
setup_passes(main_module_with_benchmark)
c_runner_utils = os.getenv("MLIR_C_RUNNER_UTILS", "")
assert os.path.exists(c_runner_utils),\
f"{c_runner_utils} does not exist." \
f" Please pass a valid value for" \
assert os.path.exists(c_runner_utils), (
f"{c_runner_utils} does not exist."
f" Please pass a valid value for"
f" MLIR_C_RUNNER_UTILS environment variable."
)
runner_utils = os.getenv("MLIR_RUNNER_UTILS", "")
assert os.path.exists(runner_utils),\
f"{runner_utils} does not exist." \
f" Please pass a valid value for MLIR_RUNNER_UTILS" \
assert os.path.exists(runner_utils), (
f"{runner_utils} does not exist."
f" Please pass a valid value for MLIR_RUNNER_UTILS"
f" environment variable."
)
engine = ExecutionEngine(
main_module_with_benchmark,
3,
shared_libs=[c_runner_utils, runner_utils]
shared_libs=[c_runner_utils, runner_utils],
)
return engine.invoke
def runner(engine_invoke):
compiled_program_args = []
for argument_type in [
result_type, param1_type, param2_type, result_type
]:
for argument_type in [result_type, param1_type, param2_type, result_type]:
argument_type_str = str(argument_type)
dimensions_str = re.sub("<|>|tensor", "", argument_type_str)
dimensions = [int(dim) for dim in dimensions_str.split("x")[:-1]]
@@ -111,6 +109,7 @@ def benchmark_np_matrix_multiplication():
benchmark, we don't have any `compiler` function returned. We just return
the `runner` function.
"""
def runner():
argument1 = np.random.uniform(low=0.0, high=100.0, size=(1000, 1500))
argument2 = np.random.uniform(low=0.0, high=100.0, size=(1500, 2000))


@@ -10,8 +10,7 @@ from mlir.passmanager import PassManager
def setup_passes(mlir_module):
"""Setup pass pipeline parameters for benchmark functions.
"""
"""Setup pass pipeline parameters for benchmark functions."""
opt = (
"parallelization-strategy=none"
" vectorization-strategy=none vl=1 enable-simd-index32=False"
@@ -43,12 +42,15 @@ def get_kernel_func_from_module(module: ir.Module) -> func.FuncOp:
This function only works for a module with one region, one block, and one
operation.
"""
assert len(module.operation.regions) == 1, \
"Expected kernel module to have only one region"
assert len(module.operation.regions[0].blocks) == 1, \
"Expected kernel module to have only one block"
assert len(module.operation.regions[0].blocks[0].operations) == 1, \
"Expected kernel module to have only one operation"
assert (
len(module.operation.regions) == 1
), "Expected kernel module to have only one region"
assert (
len(module.operation.regions[0].blocks) == 1
), "Expected kernel module to have only one block"
assert (
len(module.operation.regions[0].blocks[0].operations) == 1
), "Expected kernel module to have only one operation"
return module.operation.regions[0].blocks[0].operations[0]
@@ -57,8 +59,7 @@ def emit_timer_func() -> func.FuncOp:
used, the `MLIR_RUNNER_UTILS` and `MLIR_C_RUNNER_UTILS` must be included.
"""
i64_type = ir.IntegerType.get_signless(64)
nanoTime = func.FuncOp(
"nanoTime", ([], [i64_type]), visibility="private")
nanoTime = func.FuncOp("nanoTime", ([], [i64_type]), visibility="private")
nanoTime.attributes["llvm.emit_c_interface"] = ir.UnitAttr.get()
return nanoTime
@@ -76,9 +77,8 @@ def emit_benchmark_wrapped_main_func(kernel_func, timer_func):
wrapped_func = func.FuncOp(
# Same signature and an extra buffer of indices to save timings.
"main",
(kernel_func.arguments.types + [memref_of_i64_type],
kernel_func.type.results),
visibility="public"
(kernel_func.arguments.types + [memref_of_i64_type], kernel_func.type.results),
visibility="public",
)
wrapped_func.attributes["llvm.emit_c_interface"] = ir.UnitAttr.get()
@@ -88,13 +88,13 @@ def emit_benchmark_wrapped_main_func(kernel_func, timer_func):
zero = arith.ConstantOp.create_index(0)
n_iterations = memref.DimOp(ir.IndexType.get(), timer_buffer, zero)
one = arith.ConstantOp.create_index(1)
iter_args = list(wrapped_func.arguments[-num_results - 1:-1])
iter_args = list(wrapped_func.arguments[-num_results - 1 : -1])
loop = scf.ForOp(zero, n_iterations, one, iter_args)
with ir.InsertionPoint(loop.body):
start = func.CallOp(timer_func, [])
call = func.CallOp(
kernel_func,
wrapped_func.arguments[:-num_results - 1] + loop.inner_iter_args
wrapped_func.arguments[: -num_results - 1] + loop.inner_iter_args,
)
end = func.CallOp(timer_func, [])
time_taken = arith.SubIOp(end, start)


@@ -1 +1 @@
config.suffixes.add('.c')
config.suffixes.add(".c")


@@ -16,52 +16,55 @@ from lit.llvm.subst import FindTool
# Configuration file for the 'lit' test runner.
# name: The name of this test suite.
config.name = 'STANDALONE'
config.name = "STANDALONE"
config.test_format = lit.formats.ShTest(not llvm_config.use_lit_shell)
# suffixes: A list of file extensions to treat as test files.
config.suffixes = ['.mlir']
config.suffixes = [".mlir"]
# test_source_root: The root path where tests are located.
config.test_source_root = os.path.dirname(__file__)
# test_exec_root: The root path where tests should be run.
config.test_exec_root = os.path.join(config.standalone_obj_root, 'test')
config.test_exec_root = os.path.join(config.standalone_obj_root, "test")
config.substitutions.append(('%PATH%', config.environment['PATH']))
config.substitutions.append(('%shlibext', config.llvm_shlib_ext))
config.substitutions.append(("%PATH%", config.environment["PATH"]))
config.substitutions.append(("%shlibext", config.llvm_shlib_ext))
llvm_config.with_system_environment(
['HOME', 'INCLUDE', 'LIB', 'TMP', 'TEMP'])
llvm_config.with_system_environment(["HOME", "INCLUDE", "LIB", "TMP", "TEMP"])
llvm_config.use_default_substitutions()
# excludes: A list of directories to exclude from the testsuite. The 'Inputs'
# subdirectories contain auxiliary inputs for various tests in their parent
# directories.
config.excludes = ['Inputs', 'Examples', 'CMakeLists.txt', 'README.txt', 'LICENSE.txt']
config.excludes = ["Inputs", "Examples", "CMakeLists.txt", "README.txt", "LICENSE.txt"]
# test_exec_root: The root path where tests should be run.
config.test_exec_root = os.path.join(config.standalone_obj_root, 'test')
config.standalone_tools_dir = os.path.join(config.standalone_obj_root, 'bin')
config.standalone_libs_dir = os.path.join(config.standalone_obj_root, 'lib')
config.test_exec_root = os.path.join(config.standalone_obj_root, "test")
config.standalone_tools_dir = os.path.join(config.standalone_obj_root, "bin")
config.standalone_libs_dir = os.path.join(config.standalone_obj_root, "lib")
config.substitutions.append(('%standalone_libs', config.standalone_libs_dir))
config.substitutions.append(("%standalone_libs", config.standalone_libs_dir))
# Tweak the PATH to include the tools dir.
llvm_config.with_environment('PATH', config.llvm_tools_dir, append_path=True)
llvm_config.with_environment("PATH", config.llvm_tools_dir, append_path=True)
tool_dirs = [config.standalone_tools_dir, config.llvm_tools_dir]
tools = [
'mlir-opt',
'standalone-capi-test',
'standalone-opt',
'standalone-translate',
"mlir-opt",
"standalone-capi-test",
"standalone-opt",
"standalone-translate",
]
llvm_config.add_tool_substitutions(tools, tool_dirs)
llvm_config.with_environment('PYTHONPATH', [
os.path.join(config.mlir_obj_dir, 'python_packages', 'standalone'),
], append_path=True)
llvm_config.with_environment(
"PYTHONPATH",
[
os.path.join(config.mlir_obj_dir, "python_packages", "standalone"),
],
append_path=True,
)


@@ -1,4 +1,4 @@
config.suffixes.add('.py')
config.suffixes.add(".py")
if not config.enable_bindings_python:
config.unsupported = True
config.unsupported = True


@@ -1,17 +1,16 @@
# RUN: %python %s | FileCheck %s
from mlir_standalone.ir import *
from mlir_standalone.dialects import (
builtin as builtin_d,
standalone as standalone_d
)
from mlir_standalone.dialects import builtin as builtin_d, standalone as standalone_d
with Context():
standalone_d.register_dialect()
module = Module.parse("""
standalone_d.register_dialect()
module = Module.parse(
"""
%0 = arith.constant 2 : i32
%1 = standalone.foo %0 : i32
""")
# CHECK: %[[C:.*]] = arith.constant 2 : i32
# CHECK: standalone.foo %[[C]] : i32
print(str(module))
"""
)
# CHECK: %[[C:.*]] = arith.constant 2 : i32
# CHECK: standalone.foo %[[C]] : i32
print(str(module))


@@ -10,26 +10,26 @@ _this_dir = os.path.dirname(__file__)
def get_lib_dirs() -> Sequence[str]:
"""Gets the lib directory for linking to shared libraries.
"""Gets the lib directory for linking to shared libraries.
On some platforms, the package may need to be built specially to export
development libraries.
"""
return [_this_dir]
On some platforms, the package may need to be built specially to export
development libraries.
"""
return [_this_dir]
def get_include_dirs() -> Sequence[str]:
"""Gets the include directory for compiling against exported C libraries.
"""Gets the include directory for compiling against exported C libraries.
Depending on how the package was build, development C libraries may or may
not be present.
"""
return [os.path.join(_this_dir, "include")]
Depending on how the package was build, development C libraries may or may
not be present.
"""
return [os.path.join(_this_dir, "include")]
# Perform Python level site initialization. This involves:
# 1. Attempting to load initializer modules, specific to the distribution.
# 2. Defining the concrete mlir.ir.Context that does site specific
# 2. Defining the concrete mlir.ir.Context that does site specific
# initialization.
#
# Aside from just being far more convenient to do this at the Python level,
@@ -38,91 +38,106 @@ def get_include_dirs() -> Sequence[str]:
# in the scope of the base class __init__).
#
# For #1, we:
# a. Probe for modules named '_mlirRegisterEverything' and
# '_site_initialize_{i}', where 'i' is a number starting at zero and
# a. Probe for modules named '_mlirRegisterEverything' and
# '_site_initialize_{i}', where 'i' is a number starting at zero and
# proceeding so long as a module with the name is found.
# b. If the module has a 'register_dialects' attribute, it will be called
# immediately with a DialectRegistry to populate.
# c. If the module has a 'context_init_hook', it will be added to a list
# of callbacks that are invoked as the last step of Context
# of callbacks that are invoked as the last step of Context
# initialization (and passed the Context under construction).
#
# This facility allows downstreams to customize Context creation to their
# needs.
def _site_initialize():
import importlib
import itertools
import logging
from ._mlir import ir
logger = logging.getLogger(__name__)
registry = ir.DialectRegistry()
post_init_hooks = []
import importlib
import itertools
import logging
from ._mlir import ir
def process_initializer_module(module_name):
try:
m = importlib.import_module(f".{module_name}", __name__)
except ModuleNotFoundError:
return False
except ImportError:
message = (f"Error importing mlir initializer {module_name}. This may "
"happen in unclean incremental builds but is likely a real bug if "
"encountered otherwise and the MLIR Python API may not function.")
logger.warning(message, exc_info=True)
logger = logging.getLogger(__name__)
registry = ir.DialectRegistry()
post_init_hooks = []
logger.debug("Initializing MLIR with module: %s", module_name)
if hasattr(m, "register_dialects"):
logger.debug("Registering dialects from initializer %r", m)
m.register_dialects(registry)
if hasattr(m, "context_init_hook"):
logger.debug("Adding context init hook from %r", m)
post_init_hooks.append(m.context_init_hook)
return True
def process_initializer_module(module_name):
try:
m = importlib.import_module(f".{module_name}", __name__)
except ModuleNotFoundError:
return False
except ImportError:
message = (
f"Error importing mlir initializer {module_name}. This may "
"happen in unclean incremental builds but is likely a real bug if "
"encountered otherwise and the MLIR Python API may not function."
)
logger.warning(message, exc_info=True)
logger.debug("Initializing MLIR with module: %s", module_name)
if hasattr(m, "register_dialects"):
logger.debug("Registering dialects from initializer %r", m)
m.register_dialects(registry)
if hasattr(m, "context_init_hook"):
logger.debug("Adding context init hook from %r", m)
post_init_hooks.append(m.context_init_hook)
return True
# If _mlirRegisterEverything is built, then include it as an initializer
# module.
process_initializer_module("_mlirRegisterEverything")
# If _mlirRegisterEverything is built, then include it as an initializer
# module.
process_initializer_module("_mlirRegisterEverything")
# Load all _site_initialize_{i} modules, where 'i' is a number starting
# at 0.
for i in itertools.count():
module_name = f"_site_initialize_{i}"
if not process_initializer_module(module_name):
break
# Load all _site_initialize_{i} modules, where 'i' is a number starting
# at 0.
for i in itertools.count():
module_name = f"_site_initialize_{i}"
if not process_initializer_module(module_name):
break
class Context(ir._BaseContext):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.append_dialect_registry(registry)
for hook in post_init_hooks:
hook(self)
# TODO: There is some debate about whether we should eagerly load
# all dialects. It is being done here in order to preserve existing
# behavior. See: https://github.com/llvm/llvm-project/issues/56037
self.load_all_available_dialects()
ir.Context = Context
class Context(ir._BaseContext):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.append_dialect_registry(registry)
for hook in post_init_hooks:
hook(self)
# TODO: There is some debate about whether we should eagerly load
# all dialects. It is being done here in order to preserve existing
# behavior. See: https://github.com/llvm/llvm-project/issues/56037
self.load_all_available_dialects()
class MLIRError(Exception):
"""
An exception with diagnostic information. Has the following fields:
message: str
error_diagnostics: List[ir.DiagnosticInfo]
"""
def __init__(self, message, error_diagnostics):
self.message = message
self.error_diagnostics = error_diagnostics
super().__init__(message, error_diagnostics)
ir.Context = Context
def __str__(self):
s = self.message
if self.error_diagnostics:
s += ':'
for diag in self.error_diagnostics:
s += "\nerror: " + str(diag.location)[4:-1] + ": " + diag.message.replace('\n', '\n ')
for note in diag.notes:
s += "\n note: " + str(note.location)[4:-1] + ": " + note.message.replace('\n', '\n ')
return s
ir.MLIRError = MLIRError
class MLIRError(Exception):
"""
An exception with diagnostic information. Has the following fields:
message: str
error_diagnostics: List[ir.DiagnosticInfo]
"""
def __init__(self, message, error_diagnostics):
self.message = message
self.error_diagnostics = error_diagnostics
super().__init__(message, error_diagnostics)
def __str__(self):
s = self.message
if self.error_diagnostics:
s += ":"
for diag in self.error_diagnostics:
s += (
"\nerror: "
+ str(diag.location)[4:-1]
+ ": "
+ diag.message.replace("\n", "\n ")
)
for note in diag.notes:
s += (
"\n note: "
+ str(note.location)[4:-1]
+ ": "
+ note.message.replace("\n", "\n ")
)
return s
ir.MLIRError = MLIRError
_site_initialize()


@@ -3,72 +3,67 @@
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
try:
from ..ir import *
from ._ods_common import get_default_loc_context as _get_default_loc_context
from ..ir import *
from ._ods_common import get_default_loc_context as _get_default_loc_context
from typing import Any, List, Union
from typing import Any, List, Union
except ImportError as e:
raise RuntimeError("Error loading imports from extension module") from e
raise RuntimeError("Error loading imports from extension module") from e
def _isa(obj: Any, cls: type):
try:
cls(obj)
except ValueError:
return False
return True
try:
cls(obj)
except ValueError:
return False
return True
def _is_any_of(obj: Any, classes: List[type]):
return any(_isa(obj, cls) for cls in classes)
return any(_isa(obj, cls) for cls in classes)
def _is_integer_like_type(type: Type):
return _is_any_of(type, [IntegerType, IndexType])
return _is_any_of(type, [IntegerType, IndexType])
def _is_float_type(type: Type):
return _is_any_of(type, [BF16Type, F16Type, F32Type, F64Type])
return _is_any_of(type, [BF16Type, F16Type, F32Type, F64Type])
class ConstantOp:
"""Specialization for the constant op class."""
"""Specialization for the constant op class."""
def __init__(self,
result: Type,
value: Union[int, float, Attribute],
*,
loc=None,
ip=None):
if isinstance(value, int):
super().__init__(IntegerAttr.get(result, value), loc=loc, ip=ip)
elif isinstance(value, float):
super().__init__(FloatAttr.get(result, value), loc=loc, ip=ip)
else:
super().__init__(value, loc=loc, ip=ip)
def __init__(
self, result: Type, value: Union[int, float, Attribute], *, loc=None, ip=None
):
if isinstance(value, int):
super().__init__(IntegerAttr.get(result, value), loc=loc, ip=ip)
elif isinstance(value, float):
super().__init__(FloatAttr.get(result, value), loc=loc, ip=ip)
else:
super().__init__(value, loc=loc, ip=ip)
@classmethod
def create_index(cls, value: int, *, loc=None, ip=None):
"""Create an index-typed constant."""
return cls(
IndexType.get(context=_get_default_loc_context(loc)),
value,
loc=loc,
ip=ip)
@classmethod
def create_index(cls, value: int, *, loc=None, ip=None):
"""Create an index-typed constant."""
return cls(
IndexType.get(context=_get_default_loc_context(loc)), value, loc=loc, ip=ip
)
@property
def type(self):
return self.results[0].type
@property
def type(self):
return self.results[0].type
@property
def value(self):
return Attribute(self.operation.attributes["value"])
@property
def value(self):
return Attribute(self.operation.attributes["value"])
@property
def literal_value(self) -> Union[int, float]:
if _is_integer_like_type(self.type):
return IntegerAttr(self.value).value
elif _is_float_type(self.type):
return FloatAttr(self.value).value
else:
raise ValueError("only integer and float constants have literal values")
@property
def literal_value(self) -> Union[int, float]:
if _is_integer_like_type(self.type):
return IntegerAttr(self.value).value
elif _is_float_type(self.type):
return FloatAttr(self.value).value
else:
raise ValueError("only integer and float constants have literal values")


@@ -3,36 +3,39 @@
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
try:
from typing import Sequence, Union
from ..ir import *
from ._ods_common import get_default_loc_context
from typing import Sequence, Union
from ..ir import *
from ._ods_common import get_default_loc_context
from typing import Any, List, Union
from typing import Any, List, Union
except ImportError as e:
raise RuntimeError("Error loading imports from extension module") from e
raise RuntimeError("Error loading imports from extension module") from e
class AllocTensorOp:
"""Extends the bufferization.alloc_tensor op."""
"""Extends the bufferization.alloc_tensor op."""
def __init__(self,
tensor_type: Type,
dynamic_sizes: Sequence[Value],
copy: Value,
size_hint: Value,
escape: BoolAttr,
*,
loc=None,
ip=None):
"""Constructs an `alloc_tensor` with static and/or dynamic sizes."""
context = get_default_loc_context(loc)
attributes = {}
if escape:
attributes["escape"] = escape
op = self.build_generic(
results=[tensor_type],
operands=[dynamic_sizes, copy, size_hint],
attributes=attributes,
loc=loc,
ip=ip)
OpView.__init__(self, op)
def __init__(
self,
tensor_type: Type,
dynamic_sizes: Sequence[Value],
copy: Value,
size_hint: Value,
escape: BoolAttr,
*,
loc=None,
ip=None
):
"""Constructs an `alloc_tensor` with static and/or dynamic sizes."""
context = get_default_loc_context(loc)
attributes = {}
if escape:
attributes["escape"] = escape
op = self.build_generic(
results=[tensor_type],
operands=[dynamic_sizes, copy, size_hint],
attributes=attributes,
loc=loc,
ip=ip,
)
OpView.__init__(self, op)


@@ -3,18 +3,18 @@
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
try:
from ..ir import *
from ..ir import *
except ImportError as e:
raise RuntimeError("Error loading imports from extension module") from e
raise RuntimeError("Error loading imports from extension module") from e
class ModuleOp:
"""Specialization for the module op class."""
"""Specialization for the module op class."""
def __init__(self, *, loc=None, ip=None):
super().__init__(self.build_generic(results=[], operands=[], loc=loc,
ip=ip))
body = self.regions[0].blocks.append()
def __init__(self, *, loc=None, ip=None):
super().__init__(self.build_generic(results=[], operands=[], loc=loc, ip=ip))
body = self.regions[0].blocks.append()
@property
def body(self):
return self.regions[0].blocks[0]
@property
def body(self):
return self.regions[0].blocks[0]


@@ -3,298 +3,317 @@
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
try:
from ..ir import *
from ._ods_common import get_default_loc_context as _get_default_loc_context
from ..ir import *
from ._ods_common import get_default_loc_context as _get_default_loc_context
import inspect
import inspect
from typing import Any, List, Optional, Sequence, Union
from typing import Any, List, Optional, Sequence, Union
except ImportError as e:
raise RuntimeError("Error loading imports from extension module") from e
raise RuntimeError("Error loading imports from extension module") from e
ARGUMENT_ATTRIBUTE_NAME = "arg_attrs"
RESULT_ATTRIBUTE_NAME = "res_attrs"
class ConstantOp:
"""Specialization for the constant op class."""
"""Specialization for the constant op class."""
def __init__(self, result: Type, value: Attribute, *, loc=None, ip=None):
super().__init__(result, value, loc=loc, ip=ip)
def __init__(self, result: Type, value: Attribute, *, loc=None, ip=None):
super().__init__(result, value, loc=loc, ip=ip)
@property
def type(self):
return self.results[0].type
@property
def type(self):
return self.results[0].type
class FuncOp:
"""Specialization for the func op class."""
"""Specialization for the func op class."""
def __init__(self,
name,
type,
*,
visibility=None,
body_builder=None,
loc=None,
ip=None):
"""
Create a FuncOp with the provided `name`, `type`, and `visibility`.
- `name` is a string representing the function name.
- `type` is either a FunctionType or a pair of list describing inputs and
results.
- `visibility` is a string matching `public`, `private`, or `nested`. None
implies private visibility.
- `body_builder` is an optional callback, when provided a new entry block
is created and the callback is invoked with the new op as argument within
an InsertionPoint context already set for the block. The callback is
expected to insert a terminator in the block.
"""
sym_name = StringAttr.get(str(name))
def __init__(
self, name, type, *, visibility=None, body_builder=None, loc=None, ip=None
):
"""
Create a FuncOp with the provided `name`, `type`, and `visibility`.
- `name` is a string representing the function name.
- `type` is either a FunctionType or a pair of list describing inputs and
results.
- `visibility` is a string matching `public`, `private`, or `nested`. None
implies private visibility.
- `body_builder` is an optional callback, when provided a new entry block
is created and the callback is invoked with the new op as argument within
an InsertionPoint context already set for the block. The callback is
expected to insert a terminator in the block.
"""
sym_name = StringAttr.get(str(name))
# If the type is passed as a tuple, build a FunctionType on the fly.
if isinstance(type, tuple):
type = FunctionType.get(inputs=type[0], results=type[1])
# If the type is passed as a tuple, build a FunctionType on the fly.
if isinstance(type, tuple):
type = FunctionType.get(inputs=type[0], results=type[1])
type = TypeAttr.get(type)
sym_visibility = StringAttr.get(
str(visibility)) if visibility is not None else None
super().__init__(sym_name, type, sym_visibility=sym_visibility, loc=loc, ip=ip)
if body_builder:
entry_block = self.add_entry_block()
with InsertionPoint(entry_block):
body_builder(self)
type = TypeAttr.get(type)
sym_visibility = (
StringAttr.get(str(visibility)) if visibility is not None else None
)
super().__init__(sym_name, type, sym_visibility=sym_visibility, loc=loc, ip=ip)
if body_builder:
entry_block = self.add_entry_block()
with InsertionPoint(entry_block):
body_builder(self)
@property
def is_external(self):
return len(self.regions[0].blocks) == 0
@property
def is_external(self):
return len(self.regions[0].blocks) == 0
@property
def body(self):
return self.regions[0]
@property
def body(self):
return self.regions[0]
@property
def type(self):
return FunctionType(TypeAttr(self.attributes["function_type"]).value)
@property
def type(self):
return FunctionType(TypeAttr(self.attributes["function_type"]).value)
@property
def visibility(self):
return self.attributes["sym_visibility"]
@property
def visibility(self):
return self.attributes["sym_visibility"]
@property
def name(self) -> StringAttr:
return StringAttr(self.attributes["sym_name"])
@property
def name(self) -> StringAttr:
return StringAttr(self.attributes["sym_name"])
@property
def entry_block(self):
if self.is_external:
raise IndexError('External function does not have a body')
return self.regions[0].blocks[0]
@property
def entry_block(self):
if self.is_external:
raise IndexError("External function does not have a body")
return self.regions[0].blocks[0]
def add_entry_block(self, arg_locs: Optional[Sequence[Location]] = None):
"""
Add an entry block to the function body using the function signature to
infer block arguments.
Returns the newly created block
"""
if not self.is_external:
raise IndexError('The function already has an entry block!')
self.body.blocks.append(*self.type.inputs, arg_locs=arg_locs)
return self.body.blocks[0]
def add_entry_block(self, arg_locs: Optional[Sequence[Location]] = None):
"""
Add an entry block to the function body using the function signature to
infer block arguments.
Returns the newly created block
"""
if not self.is_external:
raise IndexError("The function already has an entry block!")
self.body.blocks.append(*self.type.inputs, arg_locs=arg_locs)
return self.body.blocks[0]
@property
def arg_attrs(self):
return ArrayAttr(self.attributes[ARGUMENT_ATTRIBUTE_NAME])
@property
def arg_attrs(self):
return ArrayAttr(self.attributes[ARGUMENT_ATTRIBUTE_NAME])
@arg_attrs.setter
def arg_attrs(self, attribute: Union[ArrayAttr, list]):
if isinstance(attribute, ArrayAttr):
self.attributes[ARGUMENT_ATTRIBUTE_NAME] = attribute
else:
self.attributes[ARGUMENT_ATTRIBUTE_NAME] = ArrayAttr.get(
attribute, context=self.context)
@property
def arguments(self):
return self.entry_block.arguments
@property
def result_attrs(self):
return self.attributes[RESULT_ATTRIBUTE_NAME]
@result_attrs.setter
def result_attrs(self, attribute: ArrayAttr):
self.attributes[RESULT_ATTRIBUTE_NAME] = attribute
@classmethod
def from_py_func(FuncOp,
*inputs: Type,
results: Optional[Sequence[Type]] = None,
name: Optional[str] = None):
"""Decorator to define an MLIR FuncOp specified as a python function.
Requires that an `mlir.ir.InsertionPoint` and `mlir.ir.Location` are
active for the current thread (i.e. established in a `with` block).
When applied as a decorator to a Python function, an entry block will
be constructed for the FuncOp with types as specified in `*inputs`. The
block arguments will be passed positionally to the Python function. In
addition, if the Python function accepts keyword arguments generally or
has a corresponding keyword argument, the following will be passed:
* `func_op`: The `func` op being defined.
By default, the function name will be the Python function `__name__`. This
can be overriden by passing the `name` argument to the decorator.
If `results` is not specified, then the decorator will implicitly
insert a `ReturnOp` with the `Value`'s returned from the decorated
function. It will also set the `FuncOp` type with the actual return
value types. If `results` is specified, then the decorated function
must return `None` and no implicit `ReturnOp` is added (nor are the result
types updated). The implicit behavior is intended for simple, single-block
cases, and users should specify result types explicitly for any complicated
cases.
The decorated function can further be called from Python and will insert
a `CallOp` at the then-current insertion point, returning either None (
if no return values), a unary Value (for one result), or a list of Values).
This mechanism cannot be used to emit recursive calls (by construction).
"""
def decorator(f):
from . import func
# Introspect the callable for optional features.
sig = inspect.signature(f)
has_arg_func_op = False
for param in sig.parameters.values():
if param.kind == param.VAR_KEYWORD:
has_arg_func_op = True
if param.name == "func_op" and (param.kind
== param.POSITIONAL_OR_KEYWORD or
param.kind == param.KEYWORD_ONLY):
has_arg_func_op = True
# Emit the FuncOp.
implicit_return = results is None
symbol_name = name or f.__name__
function_type = FunctionType.get(
inputs=inputs, results=[] if implicit_return else results)
func_op = FuncOp(name=symbol_name, type=function_type)
with InsertionPoint(func_op.add_entry_block()):
func_args = func_op.entry_block.arguments
func_kwargs = {}
if has_arg_func_op:
func_kwargs["func_op"] = func_op
return_values = f(*func_args, **func_kwargs)
if not implicit_return:
return_types = list(results)
assert return_values is None, (
"Capturing a python function with explicit `results=` "
"requires that the wrapped function returns None.")
@arg_attrs.setter
def arg_attrs(self, attribute: Union[ArrayAttr, list]):
if isinstance(attribute, ArrayAttr):
self.attributes[ARGUMENT_ATTRIBUTE_NAME] = attribute
else:
# Coerce return values, add ReturnOp and rewrite func type.
if return_values is None:
return_values = []
elif isinstance(return_values, tuple):
return_values = list(return_values)
elif isinstance(return_values, Value):
# Returning a single value is fine, coerce it into a list.
return_values = [return_values]
elif isinstance(return_values, OpView):
# Returning a single operation is fine, coerce its results a list.
return_values = return_values.operation.results
elif isinstance(return_values, Operation):
# Returning a single operation is fine, coerce its results a list.
return_values = return_values.results
else:
return_values = list(return_values)
func.ReturnOp(return_values)
# Recompute the function type.
return_types = [v.type for v in return_values]
function_type = FunctionType.get(inputs=inputs, results=return_types)
func_op.attributes["function_type"] = TypeAttr.get(function_type)
self.attributes[ARGUMENT_ATTRIBUTE_NAME] = ArrayAttr.get(
attribute, context=self.context
)
def emit_call_op(*call_args):
call_op = func.CallOp(return_types, FlatSymbolRefAttr.get(symbol_name),
call_args)
if return_types is None:
return None
elif len(return_types) == 1:
return call_op.result
else:
return call_op.results
@property
def arguments(self):
return self.entry_block.arguments
wrapped = emit_call_op
wrapped.__name__ = f.__name__
wrapped.func_op = func_op
return wrapped
@property
def result_attrs(self):
return self.attributes[RESULT_ATTRIBUTE_NAME]
@result_attrs.setter
def result_attrs(self, attribute: ArrayAttr):
self.attributes[RESULT_ATTRIBUTE_NAME] = attribute
@classmethod
def from_py_func(
FuncOp,
*inputs: Type,
results: Optional[Sequence[Type]] = None,
name: Optional[str] = None,
):
"""Decorator to define an MLIR FuncOp specified as a python function.
Requires that an `mlir.ir.InsertionPoint` and `mlir.ir.Location` are
active for the current thread (i.e. established in a `with` block).
When applied as a decorator to a Python function, an entry block will
be constructed for the FuncOp with types as specified in `*inputs`. The
block arguments will be passed positionally to the Python function. In
addition, if the Python function accepts keyword arguments generally or
has a corresponding keyword argument, the following will be passed:
* `func_op`: The `func` op being defined.
By default, the function name will be the Python function `__name__`. This
can be overriden by passing the `name` argument to the decorator.
If `results` is not specified, then the decorator will implicitly
insert a `ReturnOp` with the `Value`'s returned from the decorated
function. It will also set the `FuncOp` type with the actual return
value types. If `results` is specified, then the decorated function
must return `None` and no implicit `ReturnOp` is added (nor are the result
types updated). The implicit behavior is intended for simple, single-block
cases, and users should specify result types explicitly for any complicated
cases.
The decorated function can further be called from Python and will insert
a `CallOp` at the then-current insertion point, returning either None (
if no return values), a unary Value (for one result), or a list of Values).
This mechanism cannot be used to emit recursive calls (by construction).
"""
def decorator(f):
from . import func
# Introspect the callable for optional features.
sig = inspect.signature(f)
has_arg_func_op = False
for param in sig.parameters.values():
if param.kind == param.VAR_KEYWORD:
has_arg_func_op = True
if param.name == "func_op" and (
param.kind == param.POSITIONAL_OR_KEYWORD
or param.kind == param.KEYWORD_ONLY
):
has_arg_func_op = True
# Emit the FuncOp.
implicit_return = results is None
symbol_name = name or f.__name__
function_type = FunctionType.get(
inputs=inputs, results=[] if implicit_return else results
)
func_op = FuncOp(name=symbol_name, type=function_type)
with InsertionPoint(func_op.add_entry_block()):
func_args = func_op.entry_block.arguments
func_kwargs = {}
if has_arg_func_op:
func_kwargs["func_op"] = func_op
return_values = f(*func_args, **func_kwargs)
if not implicit_return:
return_types = list(results)
assert return_values is None, (
"Capturing a python function with explicit `results=` "
"requires that the wrapped function returns None."
)
else:
# Coerce return values, add ReturnOp and rewrite func type.
if return_values is None:
return_values = []
elif isinstance(return_values, tuple):
return_values = list(return_values)
elif isinstance(return_values, Value):
# Returning a single value is fine, coerce it into a list.
return_values = [return_values]
elif isinstance(return_values, OpView):
# Returning a single operation is fine, coerce its results a list.
return_values = return_values.operation.results
elif isinstance(return_values, Operation):
# Returning a single operation is fine, coerce its results a list.
return_values = return_values.results
else:
return_values = list(return_values)
func.ReturnOp(return_values)
# Recompute the function type.
return_types = [v.type for v in return_values]
function_type = FunctionType.get(
inputs=inputs, results=return_types
)
func_op.attributes["function_type"] = TypeAttr.get(function_type)
def emit_call_op(*call_args):
call_op = func.CallOp(
return_types, FlatSymbolRefAttr.get(symbol_name), call_args
)
if return_types is None:
return None
elif len(return_types) == 1:
return call_op.result
else:
return call_op.results
wrapped = emit_call_op
wrapped.__name__ = f.__name__
wrapped.func_op = func_op
return wrapped
return decorator
return decorator
class CallOp:
"""Specialization for the call op class."""
"""Specialization for the call op class."""
def __init__(self,
calleeOrResults: Union[FuncOp, List[Type]],
argumentsOrCallee: Union[List, FlatSymbolRefAttr, str],
arguments: Optional[List] = None,
*,
loc=None,
ip=None):
"""Creates an call operation.
def __init__(
self,
calleeOrResults: Union[FuncOp, List[Type]],
argumentsOrCallee: Union[List, FlatSymbolRefAttr, str],
arguments: Optional[List] = None,
*,
loc=None,
ip=None,
):
"""Creates an call operation.
The constructor accepts three different forms:
The constructor accepts three different forms:
1. A function op to be called followed by a list of arguments.
2. A list of result types, followed by the name of the function to be
called as string, following by a list of arguments.
3. A list of result types, followed by the name of the function to be
called as symbol reference attribute, followed by a list of arguments.
1. A function op to be called followed by a list of arguments.
2. A list of result types, followed by the name of the function to be
called as string, following by a list of arguments.
3. A list of result types, followed by the name of the function to be
called as symbol reference attribute, followed by a list of arguments.
For example
For example
f = func.FuncOp("foo", ...)
func.CallOp(f, [args])
func.CallOp([result_types], "foo", [args])
f = func.FuncOp("foo", ...)
func.CallOp(f, [args])
func.CallOp([result_types], "foo", [args])
In all cases, the location and insertion point may be specified as keyword
arguments if not provided by the surrounding context managers.
"""
In all cases, the location and insertion point may be specified as keyword
arguments if not provided by the surrounding context managers.
"""
# TODO: consider supporting constructor "overloads", e.g., through a custom
# or pybind-provided metaclass.
if isinstance(calleeOrResults, FuncOp):
if not isinstance(argumentsOrCallee, list):
raise ValueError(
"when constructing a call to a function, expected " +
"the second argument to be a list of call arguments, " +
f"got {type(argumentsOrCallee)}")
if arguments is not None:
raise ValueError("unexpected third argument when constructing a call" +
"to a function")
# TODO: consider supporting constructor "overloads", e.g., through a custom
# or pybind-provided metaclass.
if isinstance(calleeOrResults, FuncOp):
if not isinstance(argumentsOrCallee, list):
raise ValueError(
"when constructing a call to a function, expected "
+ "the second argument to be a list of call arguments, "
+ f"got {type(argumentsOrCallee)}"
)
if arguments is not None:
raise ValueError(
"unexpected third argument when constructing a call"
+ "to a function"
)
super().__init__(
calleeOrResults.type.results,
FlatSymbolRefAttr.get(
calleeOrResults.name.value,
context=_get_default_loc_context(loc)),
argumentsOrCallee,
loc=loc,
ip=ip)
return
super().__init__(
calleeOrResults.type.results,
FlatSymbolRefAttr.get(
calleeOrResults.name.value, context=_get_default_loc_context(loc)
),
argumentsOrCallee,
loc=loc,
ip=ip,
)
return
if isinstance(argumentsOrCallee, list):
raise ValueError("when constructing a call to a function by name, " +
"expected the second argument to be a string or a " +
f"FlatSymbolRefAttr, got {type(argumentsOrCallee)}")
if isinstance(argumentsOrCallee, list):
raise ValueError(
"when constructing a call to a function by name, "
+ "expected the second argument to be a string or a "
+ f"FlatSymbolRefAttr, got {type(argumentsOrCallee)}"
)
if isinstance(argumentsOrCallee, FlatSymbolRefAttr):
super().__init__(
calleeOrResults, argumentsOrCallee, arguments, loc=loc, ip=ip)
elif isinstance(argumentsOrCallee, str):
super().__init__(
calleeOrResults,
FlatSymbolRefAttr.get(
argumentsOrCallee, context=_get_default_loc_context(loc)),
arguments,
loc=loc,
ip=ip)
if isinstance(argumentsOrCallee, FlatSymbolRefAttr):
super().__init__(
calleeOrResults, argumentsOrCallee, arguments, loc=loc, ip=ip
)
elif isinstance(argumentsOrCallee, str):
super().__init__(
calleeOrResults,
FlatSymbolRefAttr.get(
argumentsOrCallee, context=_get_default_loc_context(loc)
),
arguments,
loc=loc,
ip=ip,
)

View File

@@ -3,39 +3,45 @@
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
try:
from typing import Optional, Sequence, Union
from ..ir import *
from ._ods_common import get_default_loc_context
from .._mlir_libs._mlirDialectsLinalg import fill_builtin_region
from typing import Optional, Sequence, Union
from ..ir import *
from ._ods_common import get_default_loc_context
from .._mlir_libs._mlirDialectsLinalg import fill_builtin_region
except ImportError as e:
raise RuntimeError("Error loading imports from extension module") from e
raise RuntimeError("Error loading imports from extension module") from e
from ._ods_common import get_op_result_or_value as _get_op_result_or_value
def isa(cls: Type, ty: Type):
try:
cls(ty)
return True
except ValueError:
return False
try:
cls(ty)
return True
except ValueError:
return False
class StructuredOpMixin:
"""All structured ops use the same mixin class."""
"""All structured ops use the same mixin class."""
def __init__(self, inputs, outputs=(), results=(), loc=None, ip=None):
super().__init__(
self.build_generic(results=list(results),
operands=[list(inputs), list(outputs)],
loc=loc,
ip=ip))
def __init__(self, inputs, outputs=(), results=(), loc=None, ip=None):
super().__init__(
self.build_generic(
results=list(results),
operands=[list(inputs), list(outputs)],
loc=loc,
ip=ip,
)
)
def select_opview_mixin(parent_opview_cls):
# TODO: This shouldn't be a heuristic: we should have a way to annotate
# the OpView to note that it is a structured op.
if ("__init__" not in parent_opview_cls.__dict__ and
hasattr(parent_opview_cls, "inputs") and
hasattr(parent_opview_cls, "outputs") and
hasattr(parent_opview_cls, "result_tensors")):
return StructuredOpMixin
# TODO: This shouldn't be a heuristic: we should have a way to annotate
# the OpView to note that it is a structured op.
if (
"__init__" not in parent_opview_cls.__dict__
and hasattr(parent_opview_cls, "inputs")
and hasattr(parent_opview_cls, "outputs")
and hasattr(parent_opview_cls, "result_tensors")
):
return StructuredOpMixin


@@ -3,125 +3,130 @@
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
try:
from ..ir import *
from ._ods_common import get_op_result_or_value as _get_op_result_or_value
from ..ir import *
from ._ods_common import get_op_result_or_value as _get_op_result_or_value
except ImportError as e:
raise RuntimeError("Error loading imports from extension module") from e
raise RuntimeError("Error loading imports from extension module") from e
from typing import Optional, Union
class GetParentForOp:
"""Extension for GetParentForOp."""
"""Extension for GetParentForOp."""
def __init__(
self,
result_type: Type,
target: Union[Operation, Value],
*,
num_loops: Optional[int] = None,
ip=None,
loc=None,
):
if num_loops is None:
num_loops = 1
super().__init__(
result_type,
_get_op_result_or_value(target),
num_loops=num_loops,
ip=ip,
loc=loc,
)
def __init__(
self,
result_type: Type,
target: Union[Operation, Value],
*,
num_loops: Optional[int] = None,
ip=None,
loc=None,
):
if num_loops is None:
num_loops = 1
super().__init__(
result_type,
_get_op_result_or_value(target),
num_loops=num_loops,
ip=ip,
loc=loc,
)
class LoopOutlineOp:
"""Extension for LoopOutlineOp."""
"""Extension for LoopOutlineOp."""
def __init__(
self,
function_type: Type,
call_type: Type,
target: Union[Operation, Value],
*,
func_name: Union[str, StringAttr],
ip=None,
loc=None,
):
super().__init__(
function_type,
call_type,
_get_op_result_or_value(target),
func_name=(func_name if isinstance(func_name, StringAttr) else
StringAttr.get(func_name)),
ip=ip,
loc=loc,
)
def __init__(
self,
function_type: Type,
call_type: Type,
target: Union[Operation, Value],
*,
func_name: Union[str, StringAttr],
ip=None,
loc=None,
):
super().__init__(
function_type,
call_type,
_get_op_result_or_value(target),
func_name=(
func_name
if isinstance(func_name, StringAttr)
else StringAttr.get(func_name)
),
ip=ip,
loc=loc,
)
class LoopPeelOp:
"""Extension for LoopPeelOp."""
"""Extension for LoopPeelOp."""
def __init__(
self,
result_type: Type,
target: Union[Operation, Value],
*,
fail_if_already_divisible: Union[bool, BoolAttr] = False,
ip=None,
loc=None,
):
super().__init__(
result_type,
_get_op_result_or_value(target),
fail_if_already_divisible=(fail_if_already_divisible if isinstance(
fail_if_already_divisible, BoolAttr) else
BoolAttr.get(fail_if_already_divisible)),
ip=ip,
loc=loc,
)
def __init__(
self,
result_type: Type,
target: Union[Operation, Value],
*,
fail_if_already_divisible: Union[bool, BoolAttr] = False,
ip=None,
loc=None,
):
super().__init__(
result_type,
_get_op_result_or_value(target),
fail_if_already_divisible=(
fail_if_already_divisible
if isinstance(fail_if_already_divisible, BoolAttr)
else BoolAttr.get(fail_if_already_divisible)
),
ip=ip,
loc=loc,
)
class LoopPipelineOp:
"""Extension for LoopPipelineOp."""
"""Extension for LoopPipelineOp."""
def __init__(
self,
result_type: Type,
target: Union[Operation, Value],
*,
iteration_interval: Optional[Union[int, IntegerAttr]] = None,
read_latency: Optional[Union[int, IntegerAttr]] = None,
ip=None,
loc=None,
):
if iteration_interval is None:
iteration_interval = 1
if read_latency is None:
read_latency = 10
super().__init__(
result_type,
_get_op_result_or_value(target),
iteration_interval=iteration_interval,
read_latency=read_latency,
ip=ip,
loc=loc,
)
def __init__(
self,
result_type: Type,
target: Union[Operation, Value],
*,
iteration_interval: Optional[Union[int, IntegerAttr]] = None,
read_latency: Optional[Union[int, IntegerAttr]] = None,
ip=None,
loc=None,
):
if iteration_interval is None:
iteration_interval = 1
if read_latency is None:
read_latency = 10
super().__init__(
result_type,
_get_op_result_or_value(target),
iteration_interval=iteration_interval,
read_latency=read_latency,
ip=ip,
loc=loc,
)
class LoopUnrollOp:
"""Extension for LoopUnrollOp."""
"""Extension for LoopUnrollOp."""
def __init__(
self,
target: Union[Operation, Value],
*,
factor: Union[int, IntegerAttr],
ip=None,
loc=None,
):
super().__init__(
_get_op_result_or_value(target),
factor=factor,
ip=ip,
loc=loc,
)
def __init__(
self,
target: Union[Operation, Value],
*,
factor: Union[int, IntegerAttr],
ip=None,
loc=None,
):
super().__init__(
_get_op_result_or_value(target),
factor=factor,
ip=ip,
loc=loc,
)


@@ -3,34 +3,34 @@
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
try:
from ..ir import *
from ._ods_common import get_op_result_or_value as _get_op_result_or_value
from ._ods_common import get_op_results_or_values as _get_op_results_or_values
from ..ir import *
from ._ods_common import get_op_result_or_value as _get_op_result_or_value
from ._ods_common import get_op_results_or_values as _get_op_results_or_values
except ImportError as e:
raise RuntimeError("Error loading imports from extension module") from e
raise RuntimeError("Error loading imports from extension module") from e
from typing import Optional, Sequence, Union
class LoadOp:
"""Specialization for the MemRef load operation."""
"""Specialization for the MemRef load operation."""
def __init__(self,
memref: Union[Operation, OpView, Value],
indices: Optional[Union[Operation, OpView,
Sequence[Value]]] = None,
*,
loc=None,
ip=None):
"""Creates a memref load operation.
def __init__(
self,
memref: Union[Operation, OpView, Value],
indices: Optional[Union[Operation, OpView, Sequence[Value]]] = None,
*,
loc=None,
ip=None
):
"""Creates a memref load operation.
Args:
memref: the buffer to load from.
indices: the list of subscripts, may be empty for zero-dimensional
buffers.
loc: user-visible location of the operation.
ip: insertion point.
"""
indices_resolved = [] if indices is None else _get_op_results_or_values(
indices)
super().__init__(memref, indices_resolved, loc=loc, ip=ip)
Args:
memref: the buffer to load from.
indices: the list of subscripts, may be empty for zero-dimensional
buffers.
loc: user-visible location of the operation.
ip: insertion point.
"""
indices_resolved = [] if indices is None else _get_op_results_or_values(indices)
super().__init__(memref, indices_resolved, loc=loc, ip=ip)


@@ -3,11 +3,11 @@
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
try:
from typing import Union
from ..ir import *
from ._ods_common import get_default_loc_context as _get_default_loc_context
from typing import Union
from ..ir import *
from ._ods_common import get_default_loc_context as _get_default_loc_context
except ImportError as e:
raise RuntimeError("Error loading imports from extension module") from e
raise RuntimeError("Error loading imports from extension module") from e
from ._ml_program_ops_gen import *
@@ -17,100 +17,97 @@ RESULT_ATTRIBUTE_NAME = "res_attrs"
class FuncOp:
"""Specialization for the func op class."""
"""Specialization for the func op class."""
def __init__(self,
name,
type,
*,
visibility=None,
body_builder=None,
loc=None,
ip=None):
"""
Create a FuncOp with the provided `name`, `type`, and `visibility`.
- `name` is a string representing the function name.
- `type` is either a FunctionType or a pair of list describing inputs and
results.
- `visibility` is a string matching `public`, `private`, or `nested`. None
implies private visibility.
- `body_builder` is an optional callback, when provided a new entry block
is created and the callback is invoked with the new op as argument within
an InsertionPoint context already set for the block. The callback is
expected to insert a terminator in the block.
"""
sym_name = StringAttr.get(str(name))
def __init__(
self, name, type, *, visibility=None, body_builder=None, loc=None, ip=None
):
"""
Create a FuncOp with the provided `name`, `type`, and `visibility`.
- `name` is a string representing the function name.
- `type` is either a FunctionType or a pair of list describing inputs and
results.
- `visibility` is a string matching `public`, `private`, or `nested`. None
implies private visibility.
- `body_builder` is an optional callback, when provided a new entry block
is created and the callback is invoked with the new op as argument within
an InsertionPoint context already set for the block. The callback is
expected to insert a terminator in the block.
"""
sym_name = StringAttr.get(str(name))
# If the type is passed as a tuple, build a FunctionType on the fly.
if isinstance(type, tuple):
type = FunctionType.get(inputs=type[0], results=type[1])
# If the type is passed as a tuple, build a FunctionType on the fly.
if isinstance(type, tuple):
type = FunctionType.get(inputs=type[0], results=type[1])
type = TypeAttr.get(type)
sym_visibility = StringAttr.get(
str(visibility)) if visibility is not None else None
super().__init__(sym_name, type, sym_visibility=sym_visibility, loc=loc, ip=ip)
if body_builder:
entry_block = self.add_entry_block()
with InsertionPoint(entry_block):
body_builder(self)
type = TypeAttr.get(type)
sym_visibility = (
StringAttr.get(str(visibility)) if visibility is not None else None
)
super().__init__(sym_name, type, sym_visibility=sym_visibility, loc=loc, ip=ip)
if body_builder:
entry_block = self.add_entry_block()
with InsertionPoint(entry_block):
body_builder(self)
@property
def is_external(self):
return len(self.regions[0].blocks) == 0
@property
def is_external(self):
return len(self.regions[0].blocks) == 0
@property
def body(self):
return self.regions[0]
@property
def body(self):
return self.regions[0]
@property
def type(self):
return FunctionType(TypeAttr(self.attributes["function_type"]).value)
@property
def type(self):
return FunctionType(TypeAttr(self.attributes["function_type"]).value)
@property
def visibility(self):
return self.attributes["sym_visibility"]
@property
def visibility(self):
return self.attributes["sym_visibility"]
@property
def name(self) -> StringAttr:
return StringAttr(self.attributes["sym_name"])
@property
def name(self) -> StringAttr:
return StringAttr(self.attributes["sym_name"])
@property
def entry_block(self):
if self.is_external:
raise IndexError('External function does not have a body')
return self.regions[0].blocks[0]
@property
def entry_block(self):
if self.is_external:
raise IndexError("External function does not have a body")
return self.regions[0].blocks[0]
def add_entry_block(self):
"""
Add an entry block to the function body using the function signature to
infer block arguments.
Returns the newly created block
"""
if not self.is_external:
raise IndexError('The function already has an entry block!')
self.body.blocks.append(*self.type.inputs)
return self.body.blocks[0]
def add_entry_block(self):
"""
Add an entry block to the function body using the function signature to
infer block arguments.
Returns the newly created block
"""
if not self.is_external:
raise IndexError("The function already has an entry block!")
self.body.blocks.append(*self.type.inputs)
return self.body.blocks[0]
@property
def arg_attrs(self):
return ArrayAttr(self.attributes[ARGUMENT_ATTRIBUTE_NAME])
@property
def arg_attrs(self):
return ArrayAttr(self.attributes[ARGUMENT_ATTRIBUTE_NAME])
@arg_attrs.setter
def arg_attrs(self, attribute: Union[ArrayAttr, list]):
if isinstance(attribute, ArrayAttr):
self.attributes[ARGUMENT_ATTRIBUTE_NAME] = attribute
else:
self.attributes[ARGUMENT_ATTRIBUTE_NAME] = ArrayAttr.get(
attribute, context=self.context)
@arg_attrs.setter
def arg_attrs(self, attribute: Union[ArrayAttr, list]):
if isinstance(attribute, ArrayAttr):
self.attributes[ARGUMENT_ATTRIBUTE_NAME] = attribute
else:
self.attributes[ARGUMENT_ATTRIBUTE_NAME] = ArrayAttr.get(
attribute, context=self.context
)
@property
def arguments(self):
return self.entry_block.arguments
@property
def arguments(self):
return self.entry_block.arguments
@property
def result_attrs(self):
return self.attributes[RESULT_ATTRIBUTE_NAME]
@property
def result_attrs(self):
return self.attributes[RESULT_ATTRIBUTE_NAME]
@result_attrs.setter
def result_attrs(self, attribute: ArrayAttr):
self.attributes[RESULT_ATTRIBUTE_NAME] = attribute
@result_attrs.setter
def result_attrs(self, attribute: ArrayAttr):
self.attributes[RESULT_ATTRIBUTE_NAME] = attribute


@@ -18,144 +18,152 @@ __all__ = [
def extend_opview_class(ext_module):
"""Decorator to extend an OpView class from an extension module.
"""Decorator to extend an OpView class from an extension module.
Extension modules can expose various entry-points:
Stand-alone class with the same name as a parent OpView class (i.e.
"ReturnOp"). A name-based match is attempted first before falling back
to a below mechanism.
Extension modules can expose various entry-points:
Stand-alone class with the same name as a parent OpView class (i.e.
"ReturnOp"). A name-based match is attempted first before falling back
to a below mechanism.
def select_opview_mixin(parent_opview_cls):
If defined, allows an appropriate mixin class to be selected dynamically
based on the parent OpView class. Should return NotImplemented if a
decision is not made.
def select_opview_mixin(parent_opview_cls):
If defined, allows an appropriate mixin class to be selected dynamically
based on the parent OpView class. Should return NotImplemented if a
decision is not made.
Args:
ext_module: A module from which to locate extensions. Can be None if not
available.
Args:
ext_module: A module from which to locate extensions. Can be None if not
available.
Returns:
A decorator that takes an OpView subclass and further extends it as
needed.
"""
Returns:
A decorator that takes an OpView subclass and further extends it as
needed.
"""
def class_decorator(parent_opview_cls: type):
if ext_module is None:
return parent_opview_cls
mixin_cls = NotImplemented
# First try to resolve by name.
try:
mixin_cls = getattr(ext_module, parent_opview_cls.__name__)
except AttributeError:
# Fall back to a select_opview_mixin hook.
try:
select_mixin = getattr(ext_module, "select_opview_mixin")
except AttributeError:
pass
else:
mixin_cls = select_mixin(parent_opview_cls)
def class_decorator(parent_opview_cls: type):
if ext_module is None:
return parent_opview_cls
mixin_cls = NotImplemented
# First try to resolve by name.
try:
mixin_cls = getattr(ext_module, parent_opview_cls.__name__)
except AttributeError:
# Fall back to a select_opview_mixin hook.
try:
select_mixin = getattr(ext_module, "select_opview_mixin")
except AttributeError:
pass
else:
mixin_cls = select_mixin(parent_opview_cls)
if mixin_cls is NotImplemented or mixin_cls is None:
return parent_opview_cls
if mixin_cls is NotImplemented or mixin_cls is None:
return parent_opview_cls
# Have a mixin_cls. Create an appropriate subclass.
try:
# Have a mixin_cls. Create an appropriate subclass.
try:
class LocalOpView(mixin_cls, parent_opview_cls):
pass
except TypeError as e:
raise TypeError(
f"Could not mixin {mixin_cls} into {parent_opview_cls}") from e
LocalOpView.__name__ = parent_opview_cls.__name__
LocalOpView.__qualname__ = parent_opview_cls.__qualname__
return LocalOpView
class LocalOpView(mixin_cls, parent_opview_cls):
pass
return class_decorator
except TypeError as e:
raise TypeError(
f"Could not mixin {mixin_cls} into {parent_opview_cls}"
) from e
LocalOpView.__name__ = parent_opview_cls.__name__
LocalOpView.__qualname__ = parent_opview_cls.__qualname__
return LocalOpView
return class_decorator
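A sketch of what an extension module consumed by extend_opview_class might look like (hypothetical names, not part of the diff): a class whose name matches the generated OpView is picked up first, and select_opview_mixin is the dynamic fallback.

class ReturnOp:
    """Hypothetical mixin merged by name into the generated ReturnOp OpView."""

    @property
    def has_operands(self):
        return len(self.operands) > 0


def select_opview_mixin(parent_opview_cls):
    # Decline everything not matched by name; returning NotImplemented means
    # the parent OpView class is used unchanged.
    return NotImplemented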
def segmented_accessor(elements, raw_segments, idx):
"""
Returns a slice of elements corresponding to the idx-th segment.
"""
Returns a slice of elements corresponding to the idx-th segment.
elements: a sliceable container (operands or results).
raw_segments: an mlir.ir.Attribute, of DenseI32Array subclass containing
sizes of the segments.
idx: index of the segment.
"""
segments = _cext.ir.DenseI32ArrayAttr(raw_segments)
start = sum(segments[i] for i in range(idx))
end = start + segments[idx]
return elements[start:end]
elements: a sliceable container (operands or results).
raw_segments: an mlir.ir.Attribute, of DenseI32Array subclass containing
sizes of the segments.
idx: index of the segment.
"""
segments = _cext.ir.DenseI32ArrayAttr(raw_segments)
start = sum(segments[i] for i in range(idx))
end = start + segments[idx]
return elements[start:end]
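A plain-Python illustration of the slicing arithmetic above (no MLIR context needed; the segment sizes are made up for the example):

elements = ["a", "b", "c", "d", "e", "f"]
segments = [2, 1, 3]          # sizes of three segments
idx = 2                       # ask for the third segment
start = sum(segments[i] for i in range(idx))   # 2 + 1 = 3
end = start + segments[idx]                    # 3 + 3 = 6
assert elements[start:end] == ["d", "e", "f"]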
def equally_sized_accessor(elements, n_variadic, n_preceding_simple,
n_preceding_variadic):
"""
Returns a starting position and a number of elements per variadic group
assuming equally-sized groups and the given numbers of preceding groups.
def equally_sized_accessor(
elements, n_variadic, n_preceding_simple, n_preceding_variadic
):
"""
Returns a starting position and a number of elements per variadic group
assuming equally-sized groups and the given numbers of preceding groups.
elements: a sequential container.
n_variadic: the number of variadic groups in the container.
n_preceding_simple: the number of non-variadic groups preceding the current
group.
n_preceding_variadic: the number of variadic groups preceding the current
group.
"""
elements: a sequential container.
n_variadic: the number of variadic groups in the container.
n_preceding_simple: the number of non-variadic groups preceding the current
group.
n_preceding_variadic: the number of variadic groups preceding the current
group.
"""
total_variadic_length = len(elements) - n_variadic + 1
# This should be enforced by the C++-side trait verifier.
assert total_variadic_length % n_variadic == 0
total_variadic_length = len(elements) - n_variadic + 1
# This should be enforced by the C++-side trait verifier.
assert total_variadic_length % n_variadic == 0
elements_per_group = total_variadic_length // n_variadic
start = n_preceding_simple + n_preceding_variadic * elements_per_group
return start, elements_per_group
elements_per_group = total_variadic_length // n_variadic
start = n_preceding_simple + n_preceding_variadic * elements_per_group
return start, elements_per_group
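Evaluating the formula above on made-up numbers: with 7 elements, 2 variadic groups, 1 preceding simple group and 1 preceding variadic group, the helper reports where the current group starts and how long it is.

elements = list(range(7))
n_variadic, n_preceding_simple, n_preceding_variadic = 2, 1, 1
total_variadic_length = len(elements) - n_variadic + 1    # 7 - 2 + 1 = 6
elements_per_group = total_variadic_length // n_variadic  # 3
start = n_preceding_simple + n_preceding_variadic * elements_per_group  # 1 + 3 = 4
assert (start, elements_per_group) == (4, 3)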
def get_default_loc_context(location=None):
"""
Returns a context in which the defaulted location is created. If the location
is None, takes the current location from the stack, raises ValueError if there
is no location on the stack.
"""
if location is None:
# Location.current raises ValueError if there is no current location.
return _cext.ir.Location.current.context
return location.context
"""
Returns a context in which the defaulted location is created. If the location
is None, takes the current location from the stack, raises ValueError if there
is no location on the stack.
"""
if location is None:
# Location.current raises ValueError if there is no current location.
return _cext.ir.Location.current.context
return location.context
def get_op_result_or_value(
arg: _Union[_cext.ir.OpView, _cext.ir.Operation, _cext.ir.Value, _cext.ir.OpResultList]
arg: _Union[
_cext.ir.OpView, _cext.ir.Operation, _cext.ir.Value, _cext.ir.OpResultList
]
) -> _cext.ir.Value:
"""Returns the given value or the single result of the given op.
"""Returns the given value or the single result of the given op.
This is useful to implement op constructors so that they can take other ops as
arguments instead of requiring the caller to extract results for every op.
Raises ValueError if provided with an op that doesn't have a single result.
"""
if isinstance(arg, _cext.ir.OpView):
return arg.operation.result
elif isinstance(arg, _cext.ir.Operation):
return arg.result
elif isinstance(arg, _cext.ir.OpResultList):
return arg[0]
else:
assert isinstance(arg, _cext.ir.Value)
return arg
This is useful to implement op constructors so that they can take other ops as
arguments instead of requiring the caller to extract results for every op.
Raises ValueError if provided with an op that doesn't have a single result.
"""
if isinstance(arg, _cext.ir.OpView):
return arg.operation.result
elif isinstance(arg, _cext.ir.Operation):
return arg.result
elif isinstance(arg, _cext.ir.OpResultList):
return arg[0]
else:
assert isinstance(arg, _cext.ir.Value)
return arg
def get_op_results_or_values(
arg: _Union[_cext.ir.OpView, _cext.ir.Operation,
_Sequence[_Union[_cext.ir.OpView, _cext.ir.Operation, _cext.ir.Value]]]
arg: _Union[
_cext.ir.OpView,
_cext.ir.Operation,
_Sequence[_Union[_cext.ir.OpView, _cext.ir.Operation, _cext.ir.Value]],
]
) -> _Union[_Sequence[_cext.ir.Value], _cext.ir.OpResultList]:
"""Returns the given sequence of values or the results of the given op.
"""Returns the given sequence of values or the results of the given op.
This is useful to implement op constructors so that they can take other ops as
lists of arguments instead of requiring the caller to extract results for
every op.
"""
if isinstance(arg, _cext.ir.OpView):
return arg.operation.results
elif isinstance(arg, _cext.ir.Operation):
return arg.results
else:
return [get_op_result_or_value(element) for element in arg]
This is useful to implement op constructors so that they can take other ops as
lists of arguments instead of requiring the caller to extract results for
every op.
"""
if isinstance(arg, _cext.ir.OpView):
return arg.operation.results
elif isinstance(arg, _cext.ir.Operation):
return arg.results
else:
return [get_op_result_or_value(element) for element in arg]
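A hedged usage sketch (not part of the diff) of the two coercion helpers, assuming the bindings and the arith dialect are available; the constant exists only to have an op to coerce:

from mlir.ir import Context, Location, Module, InsertionPoint, F32Type, FloatAttr
from mlir.dialects import arith
from mlir.dialects._ods_common import (
    get_op_result_or_value,
    get_op_results_or_values,
)

with Context(), Location.unknown():
    module = Module.create()
    with InsertionPoint(module.body):
        f32 = F32Type.get()
        cst = arith.ConstantOp(f32, FloatAttr.get(f32, 1.0))
        single = get_op_result_or_value(cst)             # OpView -> its only result
        many = get_op_results_or_values([cst, single])   # mixed list -> list of Values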

@@ -3,10 +3,10 @@
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
try:
from ..ir import *
from ..dialects import pdl
from ..ir import *
from ..dialects import pdl
except ImportError as e:
raise RuntimeError("Error loading imports from extension module") from e
raise RuntimeError("Error loading imports from extension module") from e
from typing import Union, Optional, Sequence, Mapping
from ._ods_common import (
@@ -16,264 +16,256 @@ from ._ods_common import (
class ApplyNativeConstraintOp:
"""Specialization for PDL apply native constraint op class."""
"""Specialization for PDL apply native constraint op class."""
def __init__(
self,
name: Union[str, StringAttr],
args: Optional[Sequence[Union[OpView, Operation, Value]]] = None,
*,
loc=None,
ip=None,
):
if args is None:
args = []
args = _get_values(args)
super().__init__(name, args, loc=loc, ip=ip)
def __init__(
self,
name: Union[str, StringAttr],
args: Optional[Sequence[Union[OpView, Operation, Value]]] = None,
*,
loc=None,
ip=None,
):
if args is None:
args = []
args = _get_values(args)
super().__init__(name, args, loc=loc, ip=ip)
class ApplyNativeRewriteOp:
"""Specialization for PDL apply native rewrite op class."""
"""Specialization for PDL apply native rewrite op class."""
def __init__(
self,
results: Sequence[Type],
name: Union[str, StringAttr],
args: Optional[Sequence[Union[OpView, Operation, Value]]] = None,
*,
loc=None,
ip=None,
):
if args is None:
args = []
args = _get_values(args)
super().__init__(results, name, args, loc=loc, ip=ip)
def __init__(
self,
results: Sequence[Type],
name: Union[str, StringAttr],
args: Optional[Sequence[Union[OpView, Operation, Value]]] = None,
*,
loc=None,
ip=None,
):
if args is None:
args = []
args = _get_values(args)
super().__init__(results, name, args, loc=loc, ip=ip)
class AttributeOp:
"""Specialization for PDL attribute op class."""
"""Specialization for PDL attribute op class."""
def __init__(
self,
valueType: Optional[Union[OpView, Operation, Value]] = None,
value: Optional[Attribute] = None,
*,
loc=None,
ip=None,
):
valueType = valueType if valueType is None else _get_value(valueType)
result = pdl.AttributeType.get()
super().__init__(result, valueType=valueType, value=value, loc=loc, ip=ip)
def __init__(
self,
valueType: Optional[Union[OpView, Operation, Value]] = None,
value: Optional[Attribute] = None,
*,
loc=None,
ip=None,
):
valueType = valueType if valueType is None else _get_value(valueType)
result = pdl.AttributeType.get()
super().__init__(result, valueType=valueType, value=value, loc=loc, ip=ip)
class EraseOp:
"""Specialization for PDL erase op class."""
"""Specialization for PDL erase op class."""
def __init__(
self,
operation: Optional[Union[OpView, Operation, Value]] = None,
*,
loc=None,
ip=None,
):
operation = _get_value(operation)
super().__init__(operation, loc=loc, ip=ip)
def __init__(
self,
operation: Optional[Union[OpView, Operation, Value]] = None,
*,
loc=None,
ip=None,
):
operation = _get_value(operation)
super().__init__(operation, loc=loc, ip=ip)
class OperandOp:
"""Specialization for PDL operand op class."""
"""Specialization for PDL operand op class."""
def __init__(
self,
type: Optional[Union[OpView, Operation, Value]] = None,
*,
loc=None,
ip=None,
):
type = type if type is None else _get_value(type)
result = pdl.ValueType.get()
super().__init__(result, valueType=type, loc=loc, ip=ip)
def __init__(
self,
type: Optional[Union[OpView, Operation, Value]] = None,
*,
loc=None,
ip=None,
):
type = type if type is None else _get_value(type)
result = pdl.ValueType.get()
super().__init__(result, valueType=type, loc=loc, ip=ip)
class OperandsOp:
"""Specialization for PDL operands op class."""
"""Specialization for PDL operands op class."""
def __init__(
self,
types: Optional[Union[OpView, Operation, Value]] = None,
*,
loc=None,
ip=None,
):
types = types if types is None else _get_value(types)
result = pdl.RangeType.get(pdl.ValueType.get())
super().__init__(result, valueType=types, loc=loc, ip=ip)
def __init__(
self,
types: Optional[Union[OpView, Operation, Value]] = None,
*,
loc=None,
ip=None,
):
types = types if types is None else _get_value(types)
result = pdl.RangeType.get(pdl.ValueType.get())
super().__init__(result, valueType=types, loc=loc, ip=ip)
class OperationOp:
"""Specialization for PDL operand op class."""
"""Specialization for PDL operand op class."""
def __init__(
self,
name: Optional[Union[str, StringAttr]] = None,
args: Optional[Sequence[Union[OpView, Operation, Value]]] = None,
attributes: Optional[Mapping[str, Union[OpView, Operation,
Value]]] = None,
types: Optional[Sequence[Union[OpView, Operation, Value]]] = None,
*,
loc=None,
ip=None,
):
if types is None:
types = []
if attributes is None:
attributes = {}
if args is None:
args = []
args = _get_values(args)
attrNames = []
attrValues = []
for attrName, attrValue in attributes.items():
attrNames.append(StringAttr.get(attrName))
attrValues.append(_get_value(attrValue))
attrNames = ArrayAttr.get(attrNames)
types = _get_values(types)
result = pdl.OperationType.get()
super().__init__(result,
args,
attrValues,
attrNames,
types,
opName=name,
loc=loc,
ip=ip)
def __init__(
self,
name: Optional[Union[str, StringAttr]] = None,
args: Optional[Sequence[Union[OpView, Operation, Value]]] = None,
attributes: Optional[Mapping[str, Union[OpView, Operation, Value]]] = None,
types: Optional[Sequence[Union[OpView, Operation, Value]]] = None,
*,
loc=None,
ip=None,
):
if types is None:
types = []
if attributes is None:
attributes = {}
if args is None:
args = []
args = _get_values(args)
attrNames = []
attrValues = []
for attrName, attrValue in attributes.items():
attrNames.append(StringAttr.get(attrName))
attrValues.append(_get_value(attrValue))
attrNames = ArrayAttr.get(attrNames)
types = _get_values(types)
result = pdl.OperationType.get()
super().__init__(
result, args, attrValues, attrNames, types, opName=name, loc=loc, ip=ip
)
class PatternOp:
"""Specialization for PDL pattern op class."""
"""Specialization for PDL pattern op class."""
def __init__(
self,
benefit: Union[IntegerAttr, int],
name: Optional[Union[StringAttr, str]] = None,
*,
loc=None,
ip=None,
):
"""Creates an PDL `pattern` operation."""
super().__init__(benefit, sym_name=name, loc=loc, ip=ip)
self.regions[0].blocks.append()
def __init__(
self,
benefit: Union[IntegerAttr, int],
name: Optional[Union[StringAttr, str]] = None,
*,
loc=None,
ip=None,
):
"""Creates an PDL `pattern` operation."""
super().__init__(benefit, sym_name=name, loc=loc, ip=ip)
self.regions[0].blocks.append()
@property
def body(self):
"""Return the body (block) of the pattern."""
return self.regions[0].blocks[0]
@property
def body(self):
"""Return the body (block) of the pattern."""
return self.regions[0].blocks[0]
class ReplaceOp:
"""Specialization for PDL replace op class."""
"""Specialization for PDL replace op class."""
def __init__(
self,
op: Union[OpView, Operation, Value],
*,
with_op: Optional[Union[OpView, Operation, Value]] = None,
with_values: Optional[Sequence[Union[OpView, Operation, Value]]] = None,
loc=None,
ip=None,
):
if with_values is None:
with_values = []
op = _get_value(op)
with_op = with_op if with_op is None else _get_value(with_op)
with_values = _get_values(with_values)
super().__init__(op, with_values, replOperation=with_op, loc=loc, ip=ip)
def __init__(
self,
op: Union[OpView, Operation, Value],
*,
with_op: Optional[Union[OpView, Operation, Value]] = None,
with_values: Optional[Sequence[Union[OpView, Operation, Value]]] = None,
loc=None,
ip=None,
):
if with_values is None:
with_values = []
op = _get_value(op)
with_op = with_op if with_op is None else _get_value(with_op)
with_values = _get_values(with_values)
super().__init__(op, with_values, replOperation=with_op, loc=loc, ip=ip)
class ResultOp:
"""Specialization for PDL result op class."""
"""Specialization for PDL result op class."""
def __init__(
self,
parent: Union[OpView, Operation, Value],
index: Union[IntegerAttr, int],
*,
loc=None,
ip=None,
):
parent = _get_value(parent)
result = pdl.ValueType.get()
super().__init__(result, parent, index, loc=loc, ip=ip)
def __init__(
self,
parent: Union[OpView, Operation, Value],
index: Union[IntegerAttr, int],
*,
loc=None,
ip=None,
):
parent = _get_value(parent)
result = pdl.ValueType.get()
super().__init__(result, parent, index, loc=loc, ip=ip)
class ResultsOp:
"""Specialization for PDL results op class."""
"""Specialization for PDL results op class."""
def __init__(
self,
result: Type,
parent: Union[OpView, Operation, Value],
index: Optional[Union[IntegerAttr, int]] = None,
*,
loc=None,
ip=None,
):
parent = _get_value(parent)
super().__init__(result, parent, index=index, loc=loc, ip=ip)
def __init__(
self,
result: Type,
parent: Union[OpView, Operation, Value],
index: Optional[Union[IntegerAttr, int]] = None,
*,
loc=None,
ip=None,
):
parent = _get_value(parent)
super().__init__(result, parent, index=index, loc=loc, ip=ip)
class RewriteOp:
"""Specialization for PDL rewrite op class."""
"""Specialization for PDL rewrite op class."""
def __init__(
self,
root: Optional[Union[OpView, Operation, Value]] = None,
name: Optional[Union[StringAttr, str]] = None,
args: Optional[Sequence[Union[OpView, Operation, Value]]] = None,
*,
loc=None,
ip=None,
):
if args is None:
args = []
root = root if root is None else _get_value(root)
args = _get_values(args)
super().__init__(args, root=root, name=name, loc=loc, ip=ip)
def __init__(
self,
root: Optional[Union[OpView, Operation, Value]] = None,
name: Optional[Union[StringAttr, str]] = None,
args: Optional[Sequence[Union[OpView, Operation, Value]]] = None,
*,
loc=None,
ip=None,
):
if args is None:
args = []
root = root if root is None else _get_value(root)
args = _get_values(args)
super().__init__(args, root=root, name=name, loc=loc, ip=ip)
def add_body(self):
"""Add body (block) to the rewrite."""
self.regions[0].blocks.append()
return self.body
def add_body(self):
"""Add body (block) to the rewrite."""
self.regions[0].blocks.append()
return self.body
@property
def body(self):
"""Return the body (block) of the rewrite."""
return self.regions[0].blocks[0]
@property
def body(self):
"""Return the body (block) of the rewrite."""
return self.regions[0].blocks[0]
class TypeOp:
"""Specialization for PDL type op class."""
"""Specialization for PDL type op class."""
def __init__(self,
constantType: Optional[Union[TypeAttr, Type]] = None,
*,
loc=None,
ip=None):
result = pdl.TypeType.get()
super().__init__(result, constantType=constantType, loc=loc, ip=ip)
def __init__(
self, constantType: Optional[Union[TypeAttr, Type]] = None, *, loc=None, ip=None
):
result = pdl.TypeType.get()
super().__init__(result, constantType=constantType, loc=loc, ip=ip)
class TypesOp:
"""Specialization for PDL types op class."""
"""Specialization for PDL types op class."""
def __init__(
self,
constantTypes: Optional[Sequence[Union[TypeAttr, Type]]] = None,
*,
loc=None,
ip=None,
):
if constantTypes is None:
constantTypes = []
result = pdl.RangeType.get(pdl.TypeType.get())
super().__init__(result, constantTypes=constantTypes, loc=loc, ip=ip)
def __init__(
self,
constantTypes: Optional[Sequence[Union[TypeAttr, Type]]] = None,
*,
loc=None,
ip=None,
):
if constantTypes is None:
constantTypes = []
result = pdl.RangeType.get(pdl.TypeType.get())
super().__init__(result, constantTypes=constantTypes, loc=loc, ip=ip)
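A hedged sketch (not part of the diff) of a tiny match-and-erase pattern built with these specializations; the op name "test.foo" is a placeholder, and the snippet assumes the pdl dialect is available in the context:

from mlir.ir import Context, Location, Module, InsertionPoint
from mlir.dialects import pdl

with Context(), Location.unknown():
    module = Module.create()
    with InsertionPoint(module.body):
        pattern = pdl.PatternOp(1, "erase_foo")
        with InsertionPoint(pattern.body):
            root = pdl.OperationOp(name="test.foo")   # match any "test.foo" op
            rewrite = pdl.RewriteOp(root)
            with InsertionPoint(rewrite.add_body()):
                pdl.EraseOp(root)                     # rewrite: erase the match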

@@ -3,105 +3,104 @@
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
try:
from ..ir import *
from ..ir import *
except ImportError as e:
raise RuntimeError("Error loading imports from extension module") from e
raise RuntimeError("Error loading imports from extension module") from e
from typing import Any, Optional, Sequence, Union
from ._ods_common import get_op_result_or_value as _get_op_result_or_value, get_op_results_or_values as _get_op_results_or_values
from ._ods_common import (
get_op_result_or_value as _get_op_result_or_value,
get_op_results_or_values as _get_op_results_or_values,
)
class ForOp:
"""Specialization for the SCF for op class."""
"""Specialization for the SCF for op class."""
def __init__(self,
lower_bound,
upper_bound,
step,
iter_args: Optional[Union[Operation, OpView,
Sequence[Value]]] = None,
*,
loc=None,
ip=None):
"""Creates an SCF `for` operation.
def __init__(
self,
lower_bound,
upper_bound,
step,
iter_args: Optional[Union[Operation, OpView, Sequence[Value]]] = None,
*,
loc=None,
ip=None
):
"""Creates an SCF `for` operation.
- `lower_bound` is the value to use as lower bound of the loop.
- `upper_bound` is the value to use as upper bound of the loop.
- `step` is the value to use as loop step.
- `iter_args` is a list of additional loop-carried arguments or an operation
producing them as results.
"""
if iter_args is None:
iter_args = []
iter_args = _get_op_results_or_values(iter_args)
- `lower_bound` is the value to use as lower bound of the loop.
- `upper_bound` is the value to use as upper bound of the loop.
- `step` is the value to use as loop step.
- `iter_args` is a list of additional loop-carried arguments or an operation
producing them as results.
"""
if iter_args is None:
iter_args = []
iter_args = _get_op_results_or_values(iter_args)
results = [arg.type for arg in iter_args]
super().__init__(
self.build_generic(
regions=1,
results=results,
operands=[
_get_op_result_or_value(o)
for o in [lower_bound, upper_bound, step]
] + list(iter_args),
loc=loc,
ip=ip))
self.regions[0].blocks.append(IndexType.get(), *results)
results = [arg.type for arg in iter_args]
super().__init__(
self.build_generic(
regions=1,
results=results,
operands=[
_get_op_result_or_value(o) for o in [lower_bound, upper_bound, step]
]
+ list(iter_args),
loc=loc,
ip=ip,
)
)
self.regions[0].blocks.append(IndexType.get(), *results)
@property
def body(self):
"""Returns the body (block) of the loop."""
return self.regions[0].blocks[0]
@property
def body(self):
"""Returns the body (block) of the loop."""
return self.regions[0].blocks[0]
@property
def induction_variable(self):
"""Returns the induction variable of the loop."""
return self.body.arguments[0]
@property
def induction_variable(self):
"""Returns the induction variable of the loop."""
return self.body.arguments[0]
@property
def inner_iter_args(self):
"""Returns the loop-carried arguments usable within the loop.
@property
def inner_iter_args(self):
"""Returns the loop-carried arguments usable within the loop.
To obtain the loop-carried operands, use `iter_args`.
"""
return self.body.arguments[1:]
To obtain the loop-carried operands, use `iter_args`.
"""
return self.body.arguments[1:]
class IfOp:
"""Specialization for the SCF if op class."""
"""Specialization for the SCF if op class."""
def __init__(self,
cond,
results_=[],
*,
hasElse=False,
loc=None,
ip=None):
"""Creates an SCF `if` operation.
def __init__(self, cond, results_=[], *, hasElse=False, loc=None, ip=None):
"""Creates an SCF `if` operation.
- `cond` is an MLIR value of 'i1' type to determine which regions of code will be executed.
- `hasElse` determines whether the if operation has the else branch.
"""
operands = []
operands.append(cond)
results = []
results.extend(results_)
super().__init__(
self.build_generic(
regions=2,
results=results,
operands=operands,
loc=loc,
ip=ip))
self.regions[0].blocks.append(*[])
if hasElse:
self.regions[1].blocks.append(*[])
- `cond` is an MLIR value of 'i1' type to determine which regions of code will be executed.
- `hasElse` determines whether the if operation has the else branch.
"""
operands = []
operands.append(cond)
results = []
results.extend(results_)
super().__init__(
self.build_generic(
regions=2, results=results, operands=operands, loc=loc, ip=ip
)
)
self.regions[0].blocks.append(*[])
if hasElse:
self.regions[1].blocks.append(*[])
@property
def then_block(self):
"""Returns the then block of the if operation."""
return self.regions[0].blocks[0]
@property
def then_block(self):
"""Returns the then block of the if operation."""
return self.regions[0].blocks[0]
@property
def else_block(self):
"""Returns the else block of the if operation."""
return self.regions[1].blocks[0]
@property
def else_block(self):
"""Returns the else block of the if operation."""
return self.regions[1].blocks[0]
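A hedged usage sketch (not part of the diff) of the ForOp helper, assuming the bindings with the arith and scf dialects; the loop body simply yields:

from mlir.ir import Context, Location, Module, InsertionPoint
from mlir.dialects import arith, scf

with Context(), Location.unknown():
    module = Module.create()
    with InsertionPoint(module.body):
        lb = arith.ConstantOp.create_index(0)
        ub = arith.ConstantOp.create_index(8)
        step = arith.ConstantOp.create_index(1)
        loop = scf.ForOp(lb, ub, step)      # no iter_args, so the loop has no results
        with InsertionPoint(loop.body):
            scf.YieldOp([])                 # empty yield terminates the body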

@@ -3,11 +3,11 @@
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
try:
from ..ir import *
from ._ods_common import get_op_result_or_value as _get_op_result_or_value
from ..dialects import pdl, transform
from ..ir import *
from ._ods_common import get_op_result_or_value as _get_op_result_or_value
from ..dialects import pdl, transform
except ImportError as e:
raise RuntimeError("Error loading imports from extension module") from e
raise RuntimeError("Error loading imports from extension module") from e
from typing import List, Optional, Sequence, Union, overload
@@ -16,312 +16,315 @@ OptionalIntList = Optional[Union[ArrayAttr, IntOrAttrList]]
def _get_int_int_array_attr(
values: Optional[Union[ArrayAttr, Sequence[Union[ArrayAttr,
IntOrAttrList]]]]
values: Optional[Union[ArrayAttr, Sequence[Union[ArrayAttr, IntOrAttrList]]]]
) -> ArrayAttr:
"""Creates an array attribute containing array attributes of integers.
"""Creates an array attribute containing array attributes of integers.
If the operand is already an array attribute, forwards it. Otherwise treats
the operand as a list of attributes or integers, potentially interpserced, to
create a new array-of-array attribute. Expects the thread-local MLIR context
to have been set by the context manager.
"""
if values is None:
return ArrayAttr.get([])
if isinstance(values, ArrayAttr):
return values
if isinstance(values, list):
values = [
ArrayAttr.get(
[IntegerAttr.get(IntegerType.get_signless(64), v)
for v in value])
for value in values
]
if values is None:
return ArrayAttr.get([])
if isinstance(values, ArrayAttr):
return values
if isinstance(values, list):
values = [
ArrayAttr.get(
[IntegerAttr.get(IntegerType.get_signless(64), v) for v in value]
)
for value in values
]
return ArrayAttr.get(values)
return ArrayAttr.get(values)
class DecomposeOp:
"""Specialization for DecomposeOp class."""
"""Specialization for DecomposeOp class."""
def __init__(self, target: Union[Operation, Value], *, loc=None, ip=None):
super().__init__(pdl.OperationType.get(),
_get_op_result_or_value(target),
loc=loc,
ip=ip)
def __init__(self, target: Union[Operation, Value], *, loc=None, ip=None):
super().__init__(
pdl.OperationType.get(), _get_op_result_or_value(target), loc=loc, ip=ip
)
class GeneralizeOp:
"""Specialization for GeneralizeOp class."""
"""Specialization for GeneralizeOp class."""
def __init__(self, target: Union[Operation, Value], *, loc=None, ip=None):
super().__init__(pdl.OperationType.get(),
_get_op_result_or_value(target),
loc=loc,
ip=ip)
def __init__(self, target: Union[Operation, Value], *, loc=None, ip=None):
super().__init__(
pdl.OperationType.get(), _get_op_result_or_value(target), loc=loc, ip=ip
)
class InterchangeOp:
"""Specialization for InterchangeOp class."""
"""Specialization for InterchangeOp class."""
def __init__(
self,
target: Union[Operation, Value],
*,
iterator_interchange: OptionalIntList = None,
loc=None,
ip=None,
):
pdl_operation_type = pdl.OperationType.get()
super().__init__(
pdl_operation_type,
_get_op_result_or_value(target),
iterator_interchange=iterator_interchange,
loc=loc,
ip=ip,
)
def __init__(
self,
target: Union[Operation, Value],
*,
iterator_interchange: OptionalIntList = None,
loc=None,
ip=None,
):
pdl_operation_type = pdl.OperationType.get()
super().__init__(
pdl_operation_type,
_get_op_result_or_value(target),
iterator_interchange=iterator_interchange,
loc=loc,
ip=ip,
)
class MatchOp:
"""Specialization for MatchOp class."""
"""Specialization for MatchOp class."""
@classmethod
def match_op_names(
MatchOp,
target: Union[Operation, Value],
names: Sequence[str],
loc=None,
ip=None,
):
pdl_operation_type = pdl.OperationType.get()
return MatchOp(
pdl_operation_type,
_get_op_result_or_value(target),
ops=ArrayAttr.get(list(map(lambda s: StringAttr.get(s), names))),
loc=loc,
ip=ip,
)
@classmethod
def match_op_names(
MatchOp,
target: Union[Operation, Value],
names: Sequence[str],
loc=None,
ip=None,
):
pdl_operation_type = pdl.OperationType.get()
return MatchOp(
pdl_operation_type,
_get_op_result_or_value(target),
ops=ArrayAttr.get(list(map(lambda s: StringAttr.get(s), names))),
loc=loc,
ip=ip,
)
class MultiTileSizesOp:
"""Specialization for MultitileSizesOp class."""
"""Specialization for MultitileSizesOp class."""
def __init__(
self,
result_type: Type,
target: Union[Operation, Value],
*,
dimension: Union[int, IntegerAttr],
target_size: Union[int, IntegerAttr],
divisor: Optional[Optional[Union[int, IntegerAttr]]] = None,
loc=None,
ip=None,
):
if divisor is None:
divisor = 1
super().__init__(
result_type,
result_type,
result_type,
_get_op_result_or_value(target),
dimension=dimension,
target_size=target_size,
divisor=divisor,
loc=loc,
ip=ip,
)
def __init__(
self,
result_type: Type,
target: Union[Operation, Value],
*,
dimension: Union[int, IntegerAttr],
target_size: Union[int, IntegerAttr],
divisor: Optional[Optional[Union[int, IntegerAttr]]] = None,
loc=None,
ip=None,
):
if divisor is None:
divisor = 1
super().__init__(
result_type,
result_type,
result_type,
_get_op_result_or_value(target),
dimension=dimension,
target_size=target_size,
divisor=divisor,
loc=loc,
ip=ip,
)
class PadOp:
"""Specialization for PadOp class."""
"""Specialization for PadOp class."""
def __init__(
self,
target: Union[Operation, Value],
*,
padding_values: Optional[Optional[Union[ArrayAttr,
Sequence[Attribute]]]] = None,
padding_dimensions: OptionalIntList = None,
pack_paddings: OptionalIntList = None,
transpose_paddings: Optional[Union[ArrayAttr, Sequence[Union[
ArrayAttr, IntOrAttrList]]]] = None,
loc=None,
ip=None,
):
if transpose_paddings is None:
transpose_paddings = []
if pack_paddings is None:
pack_paddings = []
if padding_dimensions is None:
padding_dimensions = []
if padding_values is None:
padding_values = []
pdl_operation_type = pdl.OperationType.get()
transpose_paddings_attr = _get_int_int_array_attr(transpose_paddings)
super().__init__(
pdl_operation_type,
_get_op_result_or_value(target),
padding_values=padding_values,
padding_dimensions=padding_dimensions,
pack_paddings=pack_paddings,
transpose_paddings=transpose_paddings_attr,
loc=loc,
ip=ip,
)
def __init__(
self,
target: Union[Operation, Value],
*,
padding_values: Optional[
Optional[Union[ArrayAttr, Sequence[Attribute]]]
] = None,
padding_dimensions: OptionalIntList = None,
pack_paddings: OptionalIntList = None,
transpose_paddings: Optional[
Union[ArrayAttr, Sequence[Union[ArrayAttr, IntOrAttrList]]]
] = None,
loc=None,
ip=None,
):
if transpose_paddings is None:
transpose_paddings = []
if pack_paddings is None:
pack_paddings = []
if padding_dimensions is None:
padding_dimensions = []
if padding_values is None:
padding_values = []
pdl_operation_type = pdl.OperationType.get()
transpose_paddings_attr = _get_int_int_array_attr(transpose_paddings)
super().__init__(
pdl_operation_type,
_get_op_result_or_value(target),
padding_values=padding_values,
padding_dimensions=padding_dimensions,
pack_paddings=pack_paddings,
transpose_paddings=transpose_paddings_attr,
loc=loc,
ip=ip,
)
class ScalarizeOp:
"""Specialization for ScalarizeOp class."""
"""Specialization for ScalarizeOp class."""
def __init__(self, target: Union[Operation, Value], *, loc=None, ip=None):
pdl_operation_type = pdl.OperationType.get()
super().__init__(pdl_operation_type,
_get_op_result_or_value(target),
loc=loc,
ip=ip)
def __init__(self, target: Union[Operation, Value], *, loc=None, ip=None):
pdl_operation_type = pdl.OperationType.get()
super().__init__(
pdl_operation_type, _get_op_result_or_value(target), loc=loc, ip=ip
)
class SplitOp:
"""Specialization for SplitOp class."""
"""Specialization for SplitOp class."""
def __init__(
self,
target: Union[Operation, Value],
dimension: Union[int, Attribute],
split_point: Union[int, Operation, Value, Attribute],
*,
loc=None,
ip=None,
):
if isinstance(split_point, int):
static_split_point = split_point
dynamic_split_point = None
else:
static_split_point = ShapedType.get_dynamic_size()
dynamic_split_point = _get_op_result_or_value(split_point)
def __init__(
self,
target: Union[Operation, Value],
dimension: Union[int, Attribute],
split_point: Union[int, Operation, Value, Attribute],
*,
loc=None,
ip=None,
):
if isinstance(split_point, int):
static_split_point = split_point
dynamic_split_point = None
else:
static_split_point = ShapedType.get_dynamic_size()
dynamic_split_point = _get_op_result_or_value(split_point)
target = _get_op_result_or_value(target)
target = _get_op_result_or_value(target)
super().__init__(
target.type,
target.type,
target,
dimension=dimension,
static_split_point=static_split_point,
dynamic_split_point=dynamic_split_point,
loc=loc,
ip=ip,
)
super().__init__(
target.type,
target.type,
target,
dimension=dimension,
static_split_point=static_split_point,
dynamic_split_point=dynamic_split_point,
loc=loc,
ip=ip,
)
class TileOp:
"""Specialization for TileOp class."""
"""Specialization for TileOp class."""
@overload
def __init__(
self,
loop_types: Union[Type, List[Type]],
target: Union[Operation, Value],
*,
sizes: Optional[Union[Sequence[Union[int, IntegerAttr, Operation, Value]],
ArrayAttr]] = None,
interchange: OptionalIntList = None,
loc=None,
ip=None,
):
...
@overload
def __init__(
self,
loop_types: Union[Type, List[Type]],
target: Union[Operation, Value],
*,
sizes: Optional[
Union[Sequence[Union[int, IntegerAttr, Operation, Value]], ArrayAttr]
] = None,
interchange: OptionalIntList = None,
loc=None,
ip=None,
):
...
@overload
def __init__(
self,
target: Union[Operation, Value, OpView],
*,
sizes: Optional[Union[Sequence[Union[int, IntegerAttr, Operation, Value]],
ArrayAttr]] = None,
interchange: OptionalIntList = None,
loc=None,
ip=None,
):
...
@overload
def __init__(
self,
target: Union[Operation, Value, OpView],
*,
sizes: Optional[
Union[Sequence[Union[int, IntegerAttr, Operation, Value]], ArrayAttr]
] = None,
interchange: OptionalIntList = None,
loc=None,
ip=None,
):
...
def __init__(
self,
loop_types_or_target: Union[Type, List[Type], Operation, Value],
target_or_none: Optional[Union[Operation, Value, OpView]] = None,
*,
sizes: Optional[Union[Sequence[Union[int, IntegerAttr, Operation, Value]],
ArrayAttr]] = None,
interchange: OptionalIntList = None,
loc=None,
ip=None,
):
if interchange is None:
interchange = []
if sizes is None:
sizes = []
def __init__(
self,
loop_types_or_target: Union[Type, List[Type], Operation, Value],
target_or_none: Optional[Union[Operation, Value, OpView]] = None,
*,
sizes: Optional[
Union[Sequence[Union[int, IntegerAttr, Operation, Value]], ArrayAttr]
] = None,
interchange: OptionalIntList = None,
loc=None,
ip=None,
):
if interchange is None:
interchange = []
if sizes is None:
sizes = []
static_sizes = []
dynamic_sizes = []
if isinstance(sizes, ArrayAttr):
sizes_attr = sizes
else:
for size in sizes:
if isinstance(size, int):
static_sizes.append(size)
static_sizes = []
dynamic_sizes = []
if isinstance(sizes, ArrayAttr):
sizes_attr = sizes
else:
static_sizes.append(ShapedType.get_dynamic_size())
dynamic_sizes.append(_get_op_result_or_value(size))
sizes_attr = DenseI64ArrayAttr.get(static_sizes)
for size in sizes:
if isinstance(size, int):
static_sizes.append(size)
else:
static_sizes.append(ShapedType.get_dynamic_size())
dynamic_sizes.append(_get_op_result_or_value(size))
sizes_attr = DenseI64ArrayAttr.get(static_sizes)
num_loops = sum(
v if v == 0 else 1 for v in self.__extract_values(sizes_attr))
num_loops = sum(v if v == 0 else 1 for v in self.__extract_values(sizes_attr))
if isinstance(loop_types_or_target, (Operation, Value, OpView)):
loop_types = [transform.AnyOpType.get()] * num_loops
target = loop_types_or_target
assert target_or_none is None, "Cannot construct TileOp with two targets."
else:
loop_types = (([loop_types_or_target] * num_loops) if isinstance(
loop_types_or_target, Type) else loop_types_or_target)
target = target_or_none
if isinstance(loop_types_or_target, (Operation, Value, OpView)):
loop_types = [transform.AnyOpType.get()] * num_loops
target = loop_types_or_target
assert target_or_none is None, "Cannot construct TileOp with two targets."
else:
loop_types = (
([loop_types_or_target] * num_loops)
if isinstance(loop_types_or_target, Type)
else loop_types_or_target
)
target = target_or_none
target = _get_op_result_or_value(target)
target = _get_op_result_or_value(target)
super().__init__(
target.type,
loop_types,
target,
dynamic_sizes=dynamic_sizes,
static_sizes=sizes_attr,
interchange=interchange,
loc=loc,
ip=ip,
)
super().__init__(
target.type,
loop_types,
target,
dynamic_sizes=dynamic_sizes,
static_sizes=sizes_attr,
interchange=interchange,
loc=loc,
ip=ip,
)
def __extract_values(self, attr: Optional[DenseI64ArrayAttr]) -> List[int]:
if not attr:
return []
return [element for element in attr]
def __extract_values(self, attr: Optional[DenseI64ArrayAttr]) -> List[int]:
if not attr:
return []
return [element for element in attr]
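To make the size handling above concrete, a small plain-Python illustration (not part of the diff) of how a mixed `sizes` list is split; "size_handle" stands in for a hypothetical dynamic transform handle, and the sentinel value below is only an assumed stand-in for ShapedType.get_dynamic_size():

sizes = [4, 0, "size_handle"]
DYNAMIC = -(2**63)             # assumed stand-in for ShapedType.get_dynamic_size()
static_sizes, dynamic_sizes = [], []
for size in sizes:
    if isinstance(size, int):
        static_sizes.append(size)
    else:
        static_sizes.append(DYNAMIC)
        dynamic_sizes.append(size)
num_loops = sum(v if v == 0 else 1 for v in static_sizes)
assert static_sizes == [4, 0, DYNAMIC]
assert dynamic_sizes == ["size_handle"]
assert num_loops == 2          # a zero size means "do not tile", so no loop for it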
class VectorizeOp:
"""Specialization for VectorizeOp class."""
"""Specialization for VectorizeOp class."""
def __init__(
self,
target: Union[Operation, Value],
*,
vectorize_padding: Union[bool, BoolAttr] = False,
loc=None,
ip=None,
):
pdl_operation_type = pdl.OperationType.get()
if isinstance(vectorize_padding, bool):
vectorize_padding = UnitAttr.get()
super().__init__(
pdl_operation_type,
_get_op_result_or_value(target),
vectorize_padding=vectorize_padding,
loc=loc,
ip=ip,
)
def __init__(
self,
target: Union[Operation, Value],
*,
vectorize_padding: Union[bool, BoolAttr] = False,
loc=None,
ip=None,
):
pdl_operation_type = pdl.OperationType.get()
if isinstance(vectorize_padding, bool):
vectorize_padding = UnitAttr.get()
super().__init__(
pdl_operation_type,
_get_op_result_or_value(target),
vectorize_padding=vectorize_padding,
loc=loc,
ip=ip,
)

@@ -3,40 +3,42 @@
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
try:
from ..ir import *
from ..ir import *
except ImportError as e:
raise RuntimeError("Error loading imports from extension module") from e
raise RuntimeError("Error loading imports from extension module") from e
from typing import Any, Optional, Sequence, Union
from ._ods_common import get_op_result_or_value as _get_op_result_or_value, get_op_results_or_values as _get_op_results_or_values
from ._ods_common import (
get_op_result_or_value as _get_op_result_or_value,
get_op_results_or_values as _get_op_results_or_values,
)
class EmptyOp:
"""Extends the tensor.empty op."""
"""Extends the tensor.empty op."""
def __init__(self,
sizes: Sequence[Union[int, Value]],
element_type: Type,
*,
loc=None,
ip=None):
"""Constructs an `empty` with mixed static/dynamic sizes."""
# TODO: Refactor the EmptyOp to take an element type attribute and
# then use normal result type inference, unifying the Python and C++ side
# with a standard mechanism (versus stashing that in builders).
dynamic_sizes = []
static_sizes = []
for s in sizes:
if isinstance(s, int):
static_sizes.append(s)
else:
static_sizes.append(ShapedType.get_dynamic_size())
dynamic_sizes.append(s)
result_type = RankedTensorType.get(static_sizes, element_type)
op = self.build_generic(
results=[result_type],
operands=dynamic_sizes,
attributes={},
loc=loc,
ip=ip)
OpView.__init__(self, op)
def __init__(
self,
sizes: Sequence[Union[int, Value]],
element_type: Type,
*,
loc=None,
ip=None
):
"""Constructs an `empty` with mixed static/dynamic sizes."""
# TODO: Refactor the EmptyOp to take an element type attribute and
# then use normal result type inference, unifying the Python and C++ side
# with a standard mechanism (versus stashing that in builders).
dynamic_sizes = []
static_sizes = []
for s in sizes:
if isinstance(s, int):
static_sizes.append(s)
else:
static_sizes.append(ShapedType.get_dynamic_size())
dynamic_sizes.append(s)
result_type = RankedTensorType.get(static_sizes, element_type)
op = self.build_generic(
results=[result_type], operands=dynamic_sizes, attributes={}, loc=loc, ip=ip
)
OpView.__init__(self, op)
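A hedged usage sketch (not part of the diff): a tensor.empty with one static and one dynamic extent, assuming the bindings with the arith and tensor dialects:

from mlir.ir import Context, Location, Module, InsertionPoint, F32Type
from mlir.dialects import arith, tensor

with Context(), Location.unknown():
    module = Module.create()
    with InsertionPoint(module.body):
        dyn = arith.ConstantOp.create_index(16)
        # Roughly: tensor.empty(%dyn) : tensor<4x?xf32>
        empty = tensor.EmptyOp([4, dyn.result], F32Type.get())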

@@ -3,144 +3,131 @@
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
try:
from ..ir import *
from ._ods_common import (
get_op_result_or_value as _get_op_result_or_value,
get_op_results_or_values as _get_op_results_or_values,
)
from ..ir import *
from ._ods_common import (
get_op_result_or_value as _get_op_result_or_value,
get_op_results_or_values as _get_op_results_or_values,
)
except ImportError as e:
raise RuntimeError("Error loading imports from extension module") from e
raise RuntimeError("Error loading imports from extension module") from e
from typing import Optional, Sequence, Union
class CastOp:
def __init__(self,
result_type: Type,
target: Union[Operation, Value],
*,
loc=None,
ip=None):
super().__init__(result_type,
_get_op_result_or_value(target),
loc=loc,
ip=ip)
def __init__(
self, result_type: Type, target: Union[Operation, Value], *, loc=None, ip=None
):
super().__init__(result_type, _get_op_result_or_value(target), loc=loc, ip=ip)
class GetClosestIsolatedParentOp:
def __init__(self,
result_type: Type,
target: Union[Operation, Value],
*,
loc=None,
ip=None):
super().__init__(result_type,
_get_op_result_or_value(target),
loc=loc,
ip=ip)
def __init__(
self, result_type: Type, target: Union[Operation, Value], *, loc=None, ip=None
):
super().__init__(result_type, _get_op_result_or_value(target), loc=loc, ip=ip)
class MergeHandlesOp:
def __init__(
self,
handles: Sequence[Union[Operation, Value]],
*,
deduplicate: bool = False,
loc=None,
ip=None,
):
super().__init__(
[_get_op_result_or_value(h) for h in handles],
deduplicate=deduplicate,
loc=loc,
ip=ip,
)
def __init__(
self,
handles: Sequence[Union[Operation, Value]],
*,
deduplicate: bool = False,
loc=None,
ip=None,
):
super().__init__(
[_get_op_result_or_value(h) for h in handles],
deduplicate=deduplicate,
loc=loc,
ip=ip,
)
class ReplicateOp:
def __init__(
self,
pattern: Union[Operation, Value],
handles: Sequence[Union[Operation, Value]],
*,
loc=None,
ip=None,
):
super().__init__(
[_get_op_result_or_value(h).type for h in handles],
_get_op_result_or_value(pattern),
[_get_op_result_or_value(h) for h in handles],
loc=loc,
ip=ip,
)
def __init__(
self,
pattern: Union[Operation, Value],
handles: Sequence[Union[Operation, Value]],
*,
loc=None,
ip=None,
):
super().__init__(
[_get_op_result_or_value(h).type for h in handles],
_get_op_result_or_value(pattern),
[_get_op_result_or_value(h) for h in handles],
loc=loc,
ip=ip,
)
class SequenceOp:
def __init__(
self,
failure_propagation_mode,
results: Sequence[Type],
target: Union[Operation, Value, Type],
extra_bindings: Optional[
Union[Sequence[Value], Sequence[Type], Operation, OpView]
] = None,
):
root = (
_get_op_result_or_value(target)
if isinstance(target, (Operation, Value))
else None
)
root_type = root.type if not isinstance(target, Type) else target
if not isinstance(failure_propagation_mode, Attribute):
failure_propagation_mode_attr = IntegerAttr.get(
IntegerType.get_signless(32), failure_propagation_mode._as_int()
)
else:
failure_propagation_mode_attr = failure_propagation_mode
def __init__(
self,
failure_propagation_mode,
results: Sequence[Type],
target: Union[Operation, Value, Type],
extra_bindings: Optional[Union[Sequence[Value], Sequence[Type], Operation,
OpView]] = None,
):
root = (_get_op_result_or_value(target) if isinstance(
target, (Operation, Value)) else None)
root_type = root.type if not isinstance(target, Type) else target
if not isinstance(failure_propagation_mode, Attribute):
failure_propagation_mode_attr = IntegerAttr.get(
IntegerType.get_signless(32), failure_propagation_mode._as_int())
else:
failure_propagation_mode_attr = failure_propagation_mode
if extra_bindings is None:
extra_bindings = []
if isinstance(extra_bindings, (Operation, OpView)):
extra_bindings = _get_op_results_or_values(extra_bindings)
if extra_bindings is None:
extra_bindings = []
if isinstance(extra_bindings, (Operation, OpView)):
extra_bindings = _get_op_results_or_values(extra_bindings)
extra_binding_types = []
if len(extra_bindings) != 0:
if isinstance(extra_bindings[0], Type):
extra_binding_types = extra_bindings
extra_bindings = []
else:
extra_binding_types = [v.type for v in extra_bindings]
extra_binding_types = []
if len(extra_bindings) != 0:
if isinstance(extra_bindings[0], Type):
extra_binding_types = extra_bindings
extra_bindings = []
else:
extra_binding_types = [v.type for v in extra_bindings]
super().__init__(
results_=results,
failure_propagation_mode=failure_propagation_mode_attr,
root=root,
extra_bindings=extra_bindings,
)
self.regions[0].blocks.append(*tuple([root_type] + extra_binding_types))
super().__init__(
results_=results,
failure_propagation_mode=failure_propagation_mode_attr,
root=root,
extra_bindings=extra_bindings,
)
self.regions[0].blocks.append(*tuple([root_type] + extra_binding_types))
@property
def body(self) -> Block:
return self.regions[0].blocks[0]
@property
def body(self) -> Block:
return self.regions[0].blocks[0]
@property
def bodyTarget(self) -> Value:
return self.body.arguments[0]
@property
def bodyTarget(self) -> Value:
return self.body.arguments[0]
@property
def bodyExtraArgs(self) -> BlockArgumentList:
return self.body.arguments[1:]
@property
def bodyExtraArgs(self) -> BlockArgumentList:
return self.body.arguments[1:]
class YieldOp:
def __init__(
self,
operands: Optional[Union[Operation, Sequence[Value]]] = None,
*,
loc=None,
ip=None,
):
if operands is None:
operands = []
super().__init__(_get_op_results_or_values(operands), loc=loc, ip=ip)
def __init__(
self,
operands: Optional[Union[Operation, Sequence[Value]]] = None,
*,
loc=None,
ip=None,
):
if operands is None:
operands = []
super().__init__(_get_op_results_or_values(operands), loc=loc, ip=ip)
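A hedged sketch (not part of the diff) of opening a top-level transform sequence with these helpers. The failure propagation mode is passed as a raw i32 attribute to sidestep the enum wrapper; that the value 1 denotes "propagate" is an assumption here.

from mlir.ir import Context, Location, Module, InsertionPoint, IntegerAttr, IntegerType
from mlir.dialects import transform

with Context(), Location.unknown():
    module = Module.create()
    with InsertionPoint(module.body):
        seq = transform.SequenceOp(
            IntegerAttr.get(IntegerType.get_signless(32), 1),  # assumed "propagate"
            [],                           # the sequence returns no results
            transform.AnyOpType.get(),    # type of the body's root block argument
        )
        with InsertionPoint(seq.body):
            transform.YieldOp()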

@@ -31,61 +31,60 @@ from .lang.yaml_helper import *
def create_arg_parser() -> argparse.ArgumentParser:
p = argparse.ArgumentParser(description="Dump an oplib in various formats")
p.add_argument("modules",
metavar="M",
type=str,
nargs="*",
help="Op module to dump")
p.add_argument("--file",
metavar="F",
type=str,
nargs="*",
help="Python op file to dump")
p.add_argument("--format",
type=str,
dest="format",
default="yaml",
choices=("yaml", "repr"),
help="Format in which to dump")
return p
p = argparse.ArgumentParser(description="Dump an oplib in various formats")
p.add_argument(
"modules", metavar="M", type=str, nargs="*", help="Op module to dump"
)
p.add_argument(
"--file", metavar="F", type=str, nargs="*", help="Python op file to dump"
)
p.add_argument(
"--format",
type=str,
dest="format",
default="yaml",
choices=("yaml", "repr"),
help="Format in which to dump",
)
return p
def load_module_from_file(module_name, file_path):
spec = importlib.util.spec_from_file_location(module_name, file_path)
m = importlib.util.module_from_spec(spec)
spec.loader.exec_module(m)
return m
spec = importlib.util.spec_from_file_location(module_name, file_path)
m = importlib.util.module_from_spec(spec)
spec.loader.exec_module(m)
return m
def main(args):
# Load all configs.
configs = []
modules = []
for module_name in args.modules:
modules.append(
importlib.import_module(module_name,
package="mlir.dialects.linalg.opdsl"))
for i, file_path in enumerate(args.file or []):
modules.append(load_module_from_file(f"_mlir_eval_oplib{i}", file_path))
for m in modules:
for attr_name, value in m.__dict__.items():
# TODO: This class layering is awkward.
if isinstance(value, DefinedOpCallable):
try:
linalg_config = LinalgOpConfig.from_linalg_op_def(value.op_def)
except Exception as e:
raise ValueError(
f"Could not create LinalgOpConfig from {value.op_def}") from e
configs.extend(linalg_config)
# Load all configs.
configs = []
modules = []
for module_name in args.modules:
modules.append(
importlib.import_module(module_name, package="mlir.dialects.linalg.opdsl")
)
for i, file_path in enumerate(args.file or []):
modules.append(load_module_from_file(f"_mlir_eval_oplib{i}", file_path))
for m in modules:
for attr_name, value in m.__dict__.items():
# TODO: This class layering is awkward.
if isinstance(value, DefinedOpCallable):
try:
linalg_config = LinalgOpConfig.from_linalg_op_def(value.op_def)
except Exception as e:
raise ValueError(
f"Could not create LinalgOpConfig from {value.op_def}"
) from e
configs.extend(linalg_config)
# Print.
if args.format == "yaml":
print(yaml_dump_all(configs))
elif args.format == "repr":
for config in configs:
print(repr(config))
# Print.
if args.format == "yaml":
print(yaml_dump_all(configs))
elif args.format == "repr":
for config in configs:
print(repr(config))
if __name__ == "__main__":
main(create_arg_parser().parse_args())
main(create_arg_parser().parse_args())

@@ -66,201 +66,201 @@ __all__ = [
class AffineBuildState:
"""Internal state for the AffineExprDef._create impls.
"""Internal state for the AffineExprDef._create impls.
Note that a "local" AffineBuildState can be created relative to a "global"
AffineBuildState. In that case, any affine expressions built will inherit
symbol and dim bindings from the global state and will update both as new
ones are discovered. This allows for building expressions across contexts
which share a common symbol and dim space.
"""
Note that a "local" AffineBuildState can be created relative to a "global"
AffineBuildState. In that case, any affine expressions built will inherit
symbol and dim bindings from the global state and will update both as new
ones are discovered. This allows for building expressions across contexts
which share a common symbol and dim space.
"""
def __init__(self,
*,
global_state: "AffineBuildState" = None,
allow_new_symbols: bool = True,
allow_new_dims: bool = True):
if not global_state:
self.all_symbols = dict() # type: Dict[str, int]
self.all_dims = dict() # type: Dict[str, int]
else:
# Alias the global dict.
self.all_symbols = global_state.all_symbols
self.all_dims = global_state.all_dims
def __init__(
self,
*,
global_state: "AffineBuildState" = None,
allow_new_symbols: bool = True,
allow_new_dims: bool = True,
):
if not global_state:
self.all_symbols = dict() # type: Dict[str, int]
self.all_dims = dict() # type: Dict[str, int]
else:
# Alias the global dict.
self.all_symbols = global_state.all_symbols
self.all_dims = global_state.all_dims
# Map of symbols and dims in the current build.
self.local_symbols = dict() # type: Dict[str, int]
self.local_dims = dict() # type: Dict[str, int]
self.allow_new_symbols = allow_new_symbols
self.allow_new_dims = allow_new_dims
# Map of symbols and dims in the current build.
self.local_symbols = dict() # type: Dict[str, int]
self.local_dims = dict() # type: Dict[str, int]
self.allow_new_symbols = allow_new_symbols
self.allow_new_dims = allow_new_dims
def get_dim(self, dimname: str) -> int:
"""Gets the dim position given a name."""
pos = self.all_dims.get(dimname)
if pos is None:
if not self.allow_new_dims:
raise ValueError(
f"New dimensions not allowed in the current affine expression: "
f"Requested '{dimname}', Availble: {self.all_dims}")
pos = len(self.all_dims)
self.all_dims[dimname] = pos
self.local_dims[dimname] = pos
return pos
def get_dim(self, dimname: str) -> int:
"""Gets the dim position given a name."""
pos = self.all_dims.get(dimname)
if pos is None:
if not self.allow_new_dims:
raise ValueError(
f"New dimensions not allowed in the current affine expression: "
f"Requested '{dimname}', Availble: {self.all_dims}"
)
pos = len(self.all_dims)
self.all_dims[dimname] = pos
self.local_dims[dimname] = pos
return pos
def get_symbol(self, symname: str) -> int:
"""Geta a symbol position given a name."""
pos = self.all_symbols.get(symname)
if pos is None:
if not self.allow_new_symbols:
raise ValueError(
f"New symbols not allowed in the current affine expression: "
f"Requested '{symname}', Availble: {self.all_symbols}")
pos = len(self.all_symbols)
self.all_symbols[symname] = pos
self.local_symbols[symname] = pos
return pos
def get_symbol(self, symname: str) -> int:
"""Geta a symbol position given a name."""
pos = self.all_symbols.get(symname)
if pos is None:
if not self.allow_new_symbols:
raise ValueError(
f"New symbols not allowed in the current affine expression: "
f"Requested '{symname}', Availble: {self.all_symbols}"
)
pos = len(self.all_symbols)
self.all_symbols[symname] = pos
self.local_symbols[symname] = pos
return pos
@property
def local_dim_count(self) -> int:
return len(self.local_dims)
@property
def local_dim_count(self) -> int:
return len(self.local_dims)
@property
def local_symbol_count(self) -> int:
return len(self.local_symbols)
@property
def local_symbol_count(self) -> int:
return len(self.local_symbols)
@property
def dim_count(self) -> int:
return len(self.all_dims)
@property
def dim_count(self) -> int:
return len(self.all_dims)
@property
def symbol_count(self) -> int:
return len(self.all_symbols)
@property
def symbol_count(self) -> int:
return len(self.all_symbols)
def __repr__(self):
lines = [f"AffineBuildState<"]
lines.append(f" symbols={self.local_symbols}")
lines.append(f" dims={self.local_dims}>")
return "\n".join(lines)
def __repr__(self):
lines = [f"AffineBuildState<"]
lines.append(f" symbols={self.local_symbols}")
lines.append(f" dims={self.local_dims}>")
return "\n".join(lines)
class AffineExprDef:
"""Base class for an affine expression being defined."""
"""Base class for an affine expression being defined."""
def build(self, state: Optional[AffineBuildState] = None) -> _ir.AffineExpr:
"""Builds the corresponding _ir.AffineExpr from the definitions.
"""
state = AffineBuildState() if state is None else state
expr = self._create(state)
return expr
def build(self, state: Optional[AffineBuildState] = None) -> _ir.AffineExpr:
"""Builds the corresponding _ir.AffineExpr from the definitions."""
state = AffineBuildState() if state is None else state
expr = self._create(state)
return expr
def _create(self, state: AffineBuildState) -> _ir.AffineExpr:
raise NotImplementedError()
def _create(self, state: AffineBuildState) -> _ir.AffineExpr:
raise NotImplementedError()
@staticmethod
def coerce_from(py_value):
if isinstance(py_value, int):
return AffineConstantExpr(py_value)
assert isinstance(py_value, AffineExprDef)
return py_value
@staticmethod
def coerce_from(py_value):
if isinstance(py_value, int):
return AffineConstantExpr(py_value)
assert isinstance(py_value, AffineExprDef)
return py_value
def visit_affine_exprs(self, callback):
"""Visits all AffineExprDefs including self."""
callback(self)
def visit_affine_exprs(self, callback):
"""Visits all AffineExprDefs including self."""
callback(self)
def __add__(lhs, rhs):
rhs = AffineExprDef.coerce_from(rhs)
return AffineBinaryExprDef(_ir.AffineAddExpr, lhs, rhs)
def __add__(lhs, rhs):
rhs = AffineExprDef.coerce_from(rhs)
return AffineBinaryExprDef(_ir.AffineAddExpr, lhs, rhs)
def __mul__(lhs, rhs):
rhs = AffineExprDef.coerce_from(rhs)
return AffineBinaryExprDef(_ir.AffineMulExpr, lhs, rhs)
def __mul__(lhs, rhs):
rhs = AffineExprDef.coerce_from(rhs)
return AffineBinaryExprDef(_ir.AffineMulExpr, lhs, rhs)
def __mod__(lhs, rhs):
rhs = AffineExprDef.coerce_from(rhs)
return AffineBinaryExprDef(_ir.AffineModExpr, lhs, rhs)
def __mod__(lhs, rhs):
rhs = AffineExprDef.coerce_from(rhs)
return AffineBinaryExprDef(_ir.AffineModExpr, lhs, rhs)
def __floordiv__(lhs, rhs):
rhs = AffineExprDef.coerce_from(rhs)
return AffineBinaryExprDef(_ir.AffineFloorDivExpr, lhs, rhs)
def __floordiv__(lhs, rhs):
rhs = AffineExprDef.coerce_from(rhs)
return AffineBinaryExprDef(_ir.AffineFloorDivExpr, lhs, rhs)
def __truediv__(lhs, rhs):
# TODO: Not really a ceil div - taking liberties for the DSL.
rhs = AffineExprDef.coerce_from(rhs)
return AffineBinaryExprDef(_ir.AffineCeilDivExpr, lhs, rhs)
def __truediv__(lhs, rhs):
# TODO: Not really a ceil div - taking liberties for the DSL.
rhs = AffineExprDef.coerce_from(rhs)
return AffineBinaryExprDef(_ir.AffineCeilDivExpr, lhs, rhs)
class AffineConstantExpr(AffineExprDef):
"""An affine constant being defined."""
"""An affine constant being defined."""
def __init__(self, value: int):
assert isinstance(value, int)
self.value = value
def __init__(self, value: int):
assert isinstance(value, int)
self.value = value
def _create(self, state: AffineBuildState) -> _ir.AffineExpr:
return _ir.AffineConstantExpr.get(self.value)
def _create(self, state: AffineBuildState) -> _ir.AffineExpr:
return _ir.AffineConstantExpr.get(self.value)
def __repr__(self):
return f"Const({self.value})"
def __repr__(self):
return f"Const({self.value})"
class AffineBinaryExprDef(AffineExprDef):
"""An affine binary expression being defined."""
"""An affine binary expression being defined."""
def __init__(self, ir_ctor, lhs: AffineExprDef, rhs: AffineExprDef):
self.ir_ctor = ir_ctor
self.lhs = lhs
self.rhs = rhs
def __init__(self, ir_ctor, lhs: AffineExprDef, rhs: AffineExprDef):
self.ir_ctor = ir_ctor
self.lhs = lhs
self.rhs = rhs
def _create(self, state: AffineBuildState) -> _ir.AffineExpr:
return self.ir_ctor.get(self.lhs._create(state), self.rhs._create(state))
def _create(self, state: AffineBuildState) -> _ir.AffineExpr:
return self.ir_ctor.get(self.lhs._create(state), self.rhs._create(state))
def visit_affine_exprs(self, callback):
"""Visits all AffineExprDefs including self."""
super().visit_affine_exprs(callback)
self.lhs.visit_affine_exprs(callback)
self.rhs.visit_affine_exprs(callback)
def visit_affine_exprs(self, callback):
"""Visits all AffineExprDefs including self."""
super().visit_affine_exprs(callback)
self.lhs.visit_affine_exprs(callback)
self.rhs.visit_affine_exprs(callback)
def __repr__(self):
return f"{self.ir_ctor.__name__}({repr(self.lhs)}, {repr(self.rhs)})"
def __repr__(self):
return f"{self.ir_ctor.__name__}({repr(self.lhs)}, {repr(self.rhs)})"
class DimDef(AffineExprDef):
"""Represents a named dimension.
"""Represents a named dimension."""
"""
ALL_DIMS = dict() # type: Dict[str, "DimDef"]
ALL_DIMS = dict() # type: Dict[str, "DimDef"]
def __new__(cls, dimname: str):
existing = cls.ALL_DIMS.get(dimname)
if existing is not None:
return existing
new = super().__new__(cls)
new.dimname = dimname
cls.ALL_DIMS[dimname] = new
return new
def __new__(cls, dimname: str):
existing = cls.ALL_DIMS.get(dimname)
if existing is not None:
return existing
new = super().__new__(cls)
new.dimname = dimname
cls.ALL_DIMS[dimname] = new
return new
def __repr__(self):
return f"Dim({self.dimname})"
def __repr__(self):
return f"Dim({self.dimname})"
def _create(self, state: AffineBuildState) -> _ir.AffineExpr:
pos = state.get_dim(self.dimname)
return _ir.AffineDimExpr.get(position=pos)
def _create(self, state: AffineBuildState) -> _ir.AffineExpr:
pos = state.get_dim(self.dimname)
return _ir.AffineDimExpr.get(position=pos)
@classmethod
def create_expando(cls):
"""Create an expando class that creates unique symbols based on attr access.
"""
@classmethod
def create_expando(cls):
"""Create an expando class that creates unique symbols based on attr access."""
class ExpandoDims:
class ExpandoDims:
def __getattr__(self, n):
return cls(n)
def __getattr__(self, n):
return cls(n)
return ExpandoDims()
return ExpandoDims()
class SymbolDef(AffineExprDef):
"""Represents a named symbol.
"""Represents a named symbol.
>>> s1 = SymbolDef("s1")
>>> s1
@@ -270,36 +270,35 @@ class SymbolDef(AffineExprDef):
False
>>> s1 is SymbolDef("s1")
True
"""
ALL_SYMBOLS = dict() # type: Dict[str, "SymbolDef"]
def __new__(cls, symname: str):
existing = cls.ALL_SYMBOLS.get(symname)
if existing is not None:
return existing
new = super().__new__(cls)
new.symname = symname
cls.ALL_SYMBOLS[symname] = new
return new
def __repr__(self):
return f"Symbol({self.symname})"
def _create(self, state: AffineBuildState) -> _ir.AffineExpr:
pos = state.get_symbol(self.symname)
return _ir.AffineSymbolExpr.get(position=pos)
@classmethod
def create_expando(cls):
"""Create an expando class that creates unique symbols based on attr access.
"""
class ExpandoSymbols:
ALL_SYMBOLS = dict() # type: Dict[str, "SymbolDef"]
def __getattr__(self, n):
return cls(n)
def __new__(cls, symname: str):
existing = cls.ALL_SYMBOLS.get(symname)
if existing is not None:
return existing
new = super().__new__(cls)
new.symname = symname
cls.ALL_SYMBOLS[symname] = new
return new
return ExpandoSymbols()
def __repr__(self):
return f"Symbol({self.symname})"
def _create(self, state: AffineBuildState) -> _ir.AffineExpr:
pos = state.get_symbol(self.symname)
return _ir.AffineSymbolExpr.get(position=pos)
@classmethod
def create_expando(cls):
"""Create an expando class that creates unique symbols based on attr access."""
class ExpandoSymbols:
def __getattr__(self, n):
return cls(n)
return ExpandoSymbols()
# Global accessor for on-demand dims and symbols.
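# Illustrative sketch (not part of the diff): how the expression DSL above is
# typically combined. It assumes the module-level expandos that follow this
# comment in the real file, i.e. D = DimDef.create_expando() and
# S = SymbolDef.create_expando().
D = DimDef.create_expando()
S = SymbolDef.create_expando()
expr = D.m * S.s0 + D.n
print(expr)  # AffineAddExpr(AffineMulExpr(Dim(m), Symbol(s0)), Dim(n))
expr.visit_affine_exprs(print)  # visits the add, the mul, then each leaf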


@@ -21,422 +21,468 @@ __all__ = ["LinalgStructuredOpConfig", "LinalgOpConfig", "OperandDefConfig"]
def _serialize_affine_map(affine_map: _ir.AffineMap) -> str:
with affine_map.context:
# Affine map printing/parsing is via an AffineMap attr.
attr = _ir.AffineMapAttr.get(affine_map)
return str(attr)
with affine_map.context:
# Affine map printing/parsing is via an AffineMap attr.
attr = _ir.AffineMapAttr.get(affine_map)
return str(attr)
class TensorUseConfig:
"""Wrapper around a TensorUse with additional context-bound state."""
"""Wrapper around a TensorUse with additional context-bound state."""
def __init__(self, tensor_use: TensorUse, indexing_map: _ir.AffineMap):
self.tensor_use = tensor_use
self.indexing_map = indexing_map
def __init__(self, tensor_use: TensorUse, indexing_map: _ir.AffineMap):
self.tensor_use = tensor_use
self.indexing_map = indexing_map
def __repr__(self):
return f"Use({self.tensor_use}, indexing_map={self.indexing_map})"
def __repr__(self):
return f"Use({self.tensor_use}, indexing_map={self.indexing_map})"
class OperandDefConfig(YAMLObject):
"""Wrapper containing an operand definition with additional state."""
yaml_tag = "!LinalgOperandDefConfig"
"""Wrapper containing an operand definition with additional state."""
def __init__(self,
operand_def: OperandDef,
shape_map: Optional[_ir.AffineMap] = None,
index_attr_map: Optional[_ir.AffineMap] = None):
self.operand_def = operand_def
self.shape_map = shape_map # type: Optional[_ir.AffineMap]
self.index_attr_map = index_attr_map # type: Optional[_ir.AffineMap]
self.indexing_map = None # type: Optional[_ir.AffineMap]
yaml_tag = "!LinalgOperandDefConfig"
@property
def name(self) -> str:
return self.operand_def.name
def __init__(
self,
operand_def: OperandDef,
shape_map: Optional[_ir.AffineMap] = None,
index_attr_map: Optional[_ir.AffineMap] = None,
):
self.operand_def = operand_def
self.shape_map = shape_map # type: Optional[_ir.AffineMap]
self.index_attr_map = index_attr_map # type: Optional[_ir.AffineMap]
self.indexing_map = None # type: Optional[_ir.AffineMap]
@property
def kind(self) -> OperandKind:
return self.operand_def.kind
@property
def name(self) -> str:
return self.operand_def.name
@property
def type_var(self) -> TypeVar:
return self.operand_def.type_var
@property
def kind(self) -> OperandKind:
return self.operand_def.kind
def to_yaml_custom_dict(self):
self_dict = dict(name=self.name, kind=self.operand_def.kind.name.lower())
if self.type_var:
self_dict["type_var"] = self.type_var.name
if self.shape_map:
self_dict["shape_map"] = _serialize_affine_map(self.shape_map)
if self.index_attr_map:
self_dict["index_attr_map"] = _serialize_affine_map(self.index_attr_map)
if self.operand_def.default_indices:
self_dict["default_indices"] = self.operand_def.default_indices
if self.operand_def.default_fn:
self_dict["default_fn"] = self.operand_def.default_fn
return self_dict
@property
def type_var(self) -> TypeVar:
return self.operand_def.type_var
def __repr__(self):
return (f"OperandDefConfig({self.operand_def}, "
def to_yaml_custom_dict(self):
self_dict = dict(name=self.name, kind=self.operand_def.kind.name.lower())
if self.type_var:
self_dict["type_var"] = self.type_var.name
if self.shape_map:
self_dict["shape_map"] = _serialize_affine_map(self.shape_map)
if self.index_attr_map:
self_dict["index_attr_map"] = _serialize_affine_map(self.index_attr_map)
if self.operand_def.default_indices:
self_dict["default_indices"] = self.operand_def.default_indices
if self.operand_def.default_fn:
self_dict["default_fn"] = self.operand_def.default_fn
return self_dict
def __repr__(self):
return (
f"OperandDefConfig({self.operand_def}, "
f"shape_map={self.shape_map}, "
f"index_attr_map={self.index_attr_map}, "
f"indexing_map={self.indexing_map})")
f"indexing_map={self.indexing_map})"
)
class LinalgIndexingMapsConfig(YAMLObject):
"""Abstracts the style of indexing maps that the op exports.
"""Abstracts the style of indexing maps that the op exports.
Presently only static (tied to the op name) indexing maps are supported. In
the future, it is expected that we will have additional variants:
- Dynamic based on attributes
- Dynamic based on operands
Each is expected to require a different variant of specification.
"""
yaml_tag = "!LinalgIndexingMapsConfig"
Presently only static (tied to the op name) indexing maps are supported. In
the future, it is expected that we will have additional variants:
- Dynamic based on attributes
- Dynamic based on operands
Each is expected to require a different variant of specification.
"""
def __init__(self,
static_indexing_maps: Optional[Sequence[_ir.AffineMap]] = None):
self.static_indexing_maps = static_indexing_maps
yaml_tag = "!LinalgIndexingMapsConfig"
def to_yaml_custom_dict(self):
if self.static_indexing_maps is not None:
return dict(static_indexing_maps=[
_serialize_affine_map(m) for m in self.static_indexing_maps
])
raise ValueError(
f"LinalgIndexingMapsConfig must have one type of indexing map"
f"(got none)")
def __init__(self, static_indexing_maps: Optional[Sequence[_ir.AffineMap]] = None):
self.static_indexing_maps = static_indexing_maps
def to_yaml_custom_dict(self):
if self.static_indexing_maps is not None:
return dict(
static_indexing_maps=[
_serialize_affine_map(m) for m in self.static_indexing_maps
]
)
raise ValueError(
f"LinalgIndexingMapsConfig must have one type of indexing map" f"(got none)"
)
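# Illustrative sketch (assumption, not part of the diff): serializing a config
# that carries a single 2-d identity map; the printed form is only approximate.
with _ir.Context():
    d0 = _ir.AffineDimExpr.get(position=0)
    d1 = _ir.AffineDimExpr.get(position=1)
    identity = _ir.AffineMap.get(dim_count=2, symbol_count=0, exprs=[d0, d1])
    config = LinalgIndexingMapsConfig(static_indexing_maps=[identity])
    print(config.to_yaml_custom_dict())
    # roughly: {'static_indexing_maps': ['affine_map<(d0, d1) -> (d0, d1)>']}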
class LinalgStructuredOpConfig(YAMLObject):
"""Configuration for metadata sufficient to construct a linalg named op."""
"""Configuration for metadata sufficient to construct a linalg named op."""
yaml_tag = "!LinalgStructuredOpConfig"
yaml_tag = "!LinalgStructuredOpConfig"
def __init__(self,
comprehension: Comprehension,
domain: Sequence[DimDef],
registered_operands: Sequence[OperandDef],
context: Optional[_ir.Context] = None):
self.context = context if context is not None else _ir.Context()
self.affine_state = AffineBuildState()
self.writes = list() # type: List[Tuple[TensorUse, TensorExpression]]
self.operands = dict() # type: Dict[OperandDef, OperandDefConfig]
self.uses = dict() # type: Dict[TensorUse, TensorUseConfig]
def __init__(
self,
comprehension: Comprehension,
domain: Sequence[DimDef],
registered_operands: Sequence[OperandDef],
context: Optional[_ir.Context] = None,
):
self.context = context if context is not None else _ir.Context()
self.affine_state = AffineBuildState()
self.writes = list() # type: List[Tuple[TensorUse, TensorExpression]]
self.operands = dict() # type: Dict[OperandDef, OperandDefConfig]
self.uses = dict() # type: Dict[TensorUse, TensorUseConfig]
# Compute the ordered set of writes and collect the tensor, capture, dims,
# and index uses.
collected_tensor_uses = set()
collected_scalar_uses = set()
collected_dim_uses = set()
collected_indices = set()
for write_use, read_use in zip(comprehension.definitions,
comprehension.values):
self.writes.append((write_use, read_use))
# Compute the ordered set of writes and collect the tensor, capture, dims,
# and index uses.
collected_tensor_uses = set()
collected_scalar_uses = set()
collected_dim_uses = set()
collected_indices = set()
for write_use, read_use in zip(comprehension.definitions, comprehension.values):
self.writes.append((write_use, read_use))
for write_use, read_use in self.writes:
collected_tensor_uses.add(write_use)
read_use.collect_tensor_uses(collected_tensor_uses)
read_use.collect_scalar_uses(collected_scalar_uses)
read_use.collect_dim_uses(collected_dim_uses)
write_use.collect_dim_uses(collected_dim_uses)
read_use.collect_indices(collected_indices)
for write_use, read_use in self.writes:
collected_tensor_uses.add(write_use)
read_use.collect_tensor_uses(collected_tensor_uses)
read_use.collect_scalar_uses(collected_scalar_uses)
read_use.collect_dim_uses(collected_dim_uses)
write_use.collect_dim_uses(collected_dim_uses)
read_use.collect_indices(collected_indices)
# Set domain to the sorted list of uses if no domain annotation is given.
if not domain:
domain = sorted(collected_dim_uses, key=lambda dim: dim.dimname)
# Set domain to the sorted list of uses if no domain annotation is given.
if not domain:
domain = sorted(collected_dim_uses, key=lambda dim: dim.dimname)
# Verify the domain dimensions match the used dimensions.
if (len(domain) != len(collected_dim_uses) or
any(dim not in collected_dim_uses for dim in domain)):
raise ValueError(f"Expected the annotated domain dimensions {domain} to "
f"match the set of dimension used by the tensor "
f"comprehension {collected_dim_uses}")
# Verify the domain dimensions match the used dimensions.
if len(domain) != len(collected_dim_uses) or any(
dim not in collected_dim_uses for dim in domain
):
raise ValueError(
f"Expected the annotated domain dimensions {domain} to "
f"match the set of dimension used by the tensor "
f"comprehension {collected_dim_uses}"
)
# Instantiate the dimensions in the given order.
with self.context:
local_state = AffineBuildState(
global_state=self.affine_state, allow_new_symbols=False)
for dim in domain:
dim.build(state=local_state)
# Instantiate the dimensions in the given order.
with self.context:
local_state = AffineBuildState(
global_state=self.affine_state, allow_new_symbols=False
)
for dim in domain:
dim.build(state=local_state)
# Collect all attribute definitions.
collected_attr_defs = list()
for operand in registered_operands:
if operand.is_attribute():
collected_attr_defs.append(operand)
# Collect all attribute definitions.
collected_attr_defs = list()
for operand in registered_operands:
if operand.is_attribute():
collected_attr_defs.append(operand)
# Collect all tensors with manual indexing annotation.
collected_index_defs = list()
for operand in registered_operands:
if operand.index_dims:
if any(dim not in collected_dim_uses for dim in operand.index_dims):
raise ValueError(f"Expected all index dims {operand.index_dims} of "
f"operand {operand.name} to have uses.")
collected_index_defs.append(operand)
# Collect all tensors with manual indexing annotation.
collected_index_defs = list()
for operand in registered_operands:
if operand.index_dims:
if any(dim not in collected_dim_uses for dim in operand.index_dims):
raise ValueError(
f"Expected all index dims {operand.index_dims} of "
f"operand {operand.name} to have uses."
)
collected_index_defs.append(operand)
# Collect the operand definitions of all tensor/scalar uses, attributes, and
# shape-only tensors.
all_operand_defs = list()
for use in collected_tensor_uses:
all_operand_defs.append(use.operand_def)
for use in collected_scalar_uses:
all_operand_defs.append(use.operand_def)
for definition in collected_attr_defs:
all_operand_defs.append(definition)
for definition in collected_index_defs:
all_operand_defs.append(definition)
# Collect the operand definitions of all tensor/scalar uses, attributes, and
# shape-only tensors.
all_operand_defs = list()
for use in collected_tensor_uses:
all_operand_defs.append(use.operand_def)
for use in collected_scalar_uses:
all_operand_defs.append(use.operand_def)
for definition in collected_attr_defs:
all_operand_defs.append(definition)
for definition in collected_index_defs:
all_operand_defs.append(definition)
# Add all operands in registration order to ensure the symbols are
# registered in the order they appear.
all_operand_defs = sorted(
all_operand_defs, key=lambda operand_def: operand_def.registered_index)
for operand_def in all_operand_defs:
self.add_operand(operand_def)
# Add all operands in registration order to ensure the symbols are
# registered in the order they appear.
all_operand_defs = sorted(
all_operand_defs, key=lambda operand_def: operand_def.registered_index
)
for operand_def in all_operand_defs:
self.add_operand(operand_def)
# Add all shape-only tensor index_dim annotations and all tensor uses.
for definition in collected_index_defs:
self.add_indexed_operand(definition)
for use in collected_tensor_uses:
self.add_tensor_use(use)
# Add all shape-only tensor index_dim annotations and all tensor uses.
for definition in collected_index_defs:
self.add_indexed_operand(definition)
for use in collected_tensor_uses:
self.add_tensor_use(use)
# Normalize all shape and indexing maps now that full count of dims and
# symbols are known.
for cuse in self.uses.values():
cuse.indexing_map = self._normalize_affine_map(cuse.indexing_map)
for definition in collected_index_defs:
self.operands[definition].indexing_map = self._normalize_affine_map(
self.operands[definition].indexing_map)
for operand_config in self.operands.values():
if operand_config.shape_map:
operand_config.shape_map = self._normalize_affine_map(
operand_config.shape_map, with_dims=False)
if operand_config.index_attr_map:
operand_config.index_attr_map = self._normalize_affine_map(
operand_config.index_attr_map, with_dims=False)
# Normalize all shape and indexing maps now that full count of dims and
# symbols are known.
for cuse in self.uses.values():
cuse.indexing_map = self._normalize_affine_map(cuse.indexing_map)
for definition in collected_index_defs:
self.operands[definition].indexing_map = self._normalize_affine_map(
self.operands[definition].indexing_map
)
for operand_config in self.operands.values():
if operand_config.shape_map:
operand_config.shape_map = self._normalize_affine_map(
operand_config.shape_map, with_dims=False
)
if operand_config.index_attr_map:
operand_config.index_attr_map = self._normalize_affine_map(
operand_config.index_attr_map, with_dims=False
)
# Now for each write use, propagate the indexing maps from the use to the
# tensor, ensuring that there are not conflicts.
for write_use, _ in self.writes:
write_tensor_config = self.operands[write_use.operand_def]
if write_tensor_config.indexing_map:
raise ValueError(
f"Unexpected multi-write to a single tensor: {write_tensor_config}")
write_tensor_config.indexing_map = self.uses[write_use].indexing_map
# Now for each write use, propagate the indexing maps from the use to the
# tensor, ensuring that there are not conflicts.
for write_use, _ in self.writes:
write_tensor_config = self.operands[write_use.operand_def]
if write_tensor_config.indexing_map:
raise ValueError(
f"Unexpected multi-write to a single tensor: {write_tensor_config}"
)
write_tensor_config.indexing_map = self.uses[write_use].indexing_map
# For each read use, propagate the indexing maps from the use to the
# tensor, ensuring that there are not conflicts.
for _, read_expr in self.writes:
read_uses = set() # type: Set[TensorUse]
read_expr.collect_tensor_uses(read_uses)
for read_use in read_uses:
read_operand_config = self.operands[read_use.operand_def]
if (read_operand_config.indexing_map and
read_operand_config.indexing_map !=
self.uses[read_use].indexing_map):
raise ValueError(
f"Unexpected multi-read of a tensor with different accesses:"
f"{read_operand_config} vs {read_use}")
read_operand_config.indexing_map = self.uses[read_use].indexing_map
# For each read use, propagate the indexing maps from the use to the
# tensor, ensuring that there are not conflicts.
for _, read_expr in self.writes:
read_uses = set() # type: Set[TensorUse]
read_expr.collect_tensor_uses(read_uses)
for read_use in read_uses:
read_operand_config = self.operands[read_use.operand_def]
if (
read_operand_config.indexing_map
and read_operand_config.indexing_map
!= self.uses[read_use].indexing_map
):
raise ValueError(
f"Unexpected multi-read of a tensor with different accesses:"
f"{read_operand_config} vs {read_use}"
)
read_operand_config.indexing_map = self.uses[read_use].indexing_map
# Set the indexing map of all scalar uses to the empty map.
for operand_config in self.operands.values():
if operand_config.operand_def.kind == OperandKind.SCALAR:
operand_config.indexing_map = self._get_scalar_map()
# Set the indexing map of all scalar uses to the empty map.
for operand_config in self.operands.values():
if operand_config.operand_def.kind == OperandKind.SCALAR:
operand_config.indexing_map = self._get_scalar_map()
# Check all registered tensor and scalar operands have an indexing map.
for operand in registered_operands:
if operand.is_attribute():
continue
if not (operand in self.operands and self.operands[operand].indexing_map):
raise ValueError(f"Failed to compute an indexing map for operand "
f"{operand.name}")
# Check all registered tensor and scalar operands have an indexing map.
for operand in registered_operands:
if operand.is_attribute():
continue
if not (operand in self.operands and self.operands[operand].indexing_map):
raise ValueError(
f"Failed to compute an indexing map for operand " f"{operand.name}"
)
# Collect reduction dims and ensure all the same.
all_reduction_dims = set(comprehension.all_reduction_dims)
if len(all_reduction_dims) != 1:
raise ValueError(
f"All writes within a generic must have the same reduction "
f"dims. Got: {all_reduction_dims}")
self.reduction_dims = next(iter(all_reduction_dims))
# Collect reduction dims and ensure all the same.
all_reduction_dims = set(comprehension.all_reduction_dims)
if len(all_reduction_dims) != 1:
raise ValueError(
f"All writes within a generic must have the same reduction "
f"dims. Got: {all_reduction_dims}"
)
self.reduction_dims = next(iter(all_reduction_dims))
# Check the index dimension exists and resolve.
for index in collected_indices:
if index.dim_def.dimname not in self.affine_state.all_dims:
raise ValueError(
f"The dimension {index.dim_def.dimname} is not part of the "
f"iteration domain {self.affine_state.all_dims}")
index.resolve_dimension_name(self.affine_state)
# Check the index dimension exists and resolve.
for index in collected_indices:
if index.dim_def.dimname not in self.affine_state.all_dims:
raise ValueError(
f"The dimension {index.dim_def.dimname} is not part of the "
f"iteration domain {self.affine_state.all_dims}"
)
index.resolve_dimension_name(self.affine_state)
# Generate the scalar assignments (used to build a body).
self.assignments = [
ScalarAssign(write_use.tensor_name, read_expr.to_scalar_expression())
for write_use, read_expr in self.writes
]
# Generate the scalar assignments (used to build a body).
self.assignments = [
ScalarAssign(write_use.tensor_name, read_expr.to_scalar_expression())
for write_use, read_expr in self.writes
]
@property
def ordered_operands(self) -> Sequence[OperandDefConfig]:
return sorted(
self.operands.values(),
key=lambda operand: operand.operand_def.registered_index)
@property
def ordered_operands(self) -> Sequence[OperandDefConfig]:
return sorted(
self.operands.values(),
key=lambda operand: operand.operand_def.registered_index,
)
@property
def ordered_dims(self) -> Sequence[Tuple[str, int]]:
"""Gets the ordered list of dim bindings (symbolic name, position).
@property
def ordered_dims(self) -> Sequence[Tuple[str, int]]:
"""Gets the ordered list of dim bindings (symbolic name, position).
TODO: The original parser relies on parse ordering to arrive at the
iterator types, but that ordering is not defined on the Python side, so
this may be ambiguous.
"""
return list(self.affine_state.all_dims.items())
TODO: The original parser relies on parse ordering to arrive at the
iterator types, but that ordering is not defined on the Python side, so
this may be ambiguous.
"""
return list(self.affine_state.all_dims.items())
@property
def indexing_maps(self) -> Sequence[_ir.AffineMap]:
return [o.indexing_map for o in self.ordered_operands if o.indexing_map]
@property
def indexing_maps(self) -> Sequence[_ir.AffineMap]:
return [o.indexing_map for o in self.ordered_operands if o.indexing_map]
@property
def iterator_types(self) -> Sequence[str]:
@property
def iterator_types(self) -> Sequence[str]:
def get_type(symbolic_name, position):
for reduction_dim_expr in self.reduction_dims:
if reduction_dim_expr.dimname == symbolic_name:
return "reduction"
return "parallel"
def get_type(symbolic_name, position):
for reduction_dim_expr in self.reduction_dims:
if reduction_dim_expr.dimname == symbolic_name:
return "reduction"
return "parallel"
return [get_type(*dim) for dim in self.ordered_dims]
return [get_type(*dim) for dim in self.ordered_dims]
def add_operand(self, operand_def: OperandDef):
if operand_def in self.operands:
return
if not (operand_def.is_tensor() or operand_def.kind == OperandKind.INDEX_ATTR):
self.operands[operand_def] = OperandDefConfig(operand_def)
return
with self.context:
local_state = AffineBuildState(
global_state=self.affine_state, allow_new_dims=False
)
exprs = []
for expr in operand_def.size_exprs:
exprs.append(expr.build(state=local_state))
assert local_state.local_dim_count == 0
affine_map = _ir.AffineMap.get(
dim_count=0, symbol_count=local_state.symbol_count, exprs=exprs
)
if operand_def.kind == OperandKind.INDEX_ATTR:
self.operands[operand_def] = OperandDefConfig(
operand_def, index_attr_map=affine_map
)
else:
self.operands[operand_def] = OperandDefConfig(
operand_def, shape_map=affine_map
)
def add_operand(self, operand_def: OperandDef):
if operand_def in self.operands:
return
if not (operand_def.is_tensor() or
operand_def.kind == OperandKind.INDEX_ATTR):
self.operands[operand_def] = OperandDefConfig(operand_def)
return
with self.context:
local_state = AffineBuildState(
global_state=self.affine_state, allow_new_dims=False)
exprs = []
for expr in operand_def.size_exprs:
exprs.append(expr.build(state=local_state))
assert local_state.local_dim_count == 0
affine_map = _ir.AffineMap.get(
dim_count=0, symbol_count=local_state.symbol_count, exprs=exprs)
if operand_def.kind == OperandKind.INDEX_ATTR:
self.operands[operand_def] = OperandDefConfig(
operand_def, index_attr_map=affine_map)
else:
self.operands[operand_def] = OperandDefConfig(
operand_def, shape_map=affine_map)
def add_indexed_operand(self, operand_def: OperandDef):
with self.context:
local_state = AffineBuildState(
global_state=self.affine_state, allow_new_symbols=False
)
exprs = []
for expr in operand_def.index_dims:
exprs.append(expr.build(state=local_state))
self.operands[operand_def].indexing_map = _ir.AffineMap.get(
dim_count=local_state.dim_count,
symbol_count=local_state.symbol_count,
exprs=exprs,
)
def add_indexed_operand(self, operand_def: OperandDef):
with self.context:
local_state = AffineBuildState(
global_state=self.affine_state, allow_new_symbols=False)
exprs = []
for expr in operand_def.index_dims:
exprs.append(expr.build(state=local_state))
self.operands[operand_def].indexing_map = _ir.AffineMap.get(
dim_count=local_state.dim_count,
symbol_count=local_state.symbol_count,
exprs=exprs)
def add_tensor_use(self, tensor_use: TensorUse):
if tensor_use in self.uses:
return
with self.context:
local_state = AffineBuildState(
global_state=self.affine_state, allow_new_symbols=False
)
exprs = []
for expr in tensor_use.indices:
exprs.append(expr.build(state=local_state))
indexing_map = _ir.AffineMap.get(
dim_count=local_state.dim_count,
symbol_count=local_state.symbol_count,
exprs=exprs,
)
def add_tensor_use(self, tensor_use: TensorUse):
if tensor_use in self.uses:
return
with self.context:
local_state = AffineBuildState(
global_state=self.affine_state, allow_new_symbols=False)
exprs = []
for expr in tensor_use.indices:
exprs.append(expr.build(state=local_state))
indexing_map = _ir.AffineMap.get(
dim_count=local_state.dim_count,
symbol_count=local_state.symbol_count,
exprs=exprs)
use_config = TensorUseConfig(tensor_use, indexing_map)
self.uses[tensor_use] = use_config
use_config = TensorUseConfig(tensor_use, indexing_map)
self.uses[tensor_use] = use_config
def _get_scalar_map(self) -> _ir.AffineMap:
"""Create an empty affine map used to index a scalar."""
with self.context:
return _ir.AffineMap.get(
dim_count=self.affine_state.dim_count,
symbol_count=self.affine_state.symbol_count,
exprs=list(),
)
def _get_scalar_map(self) -> _ir.AffineMap:
"""Create an empty affine map used to index a scalar."""
with self.context:
return _ir.AffineMap.get(
dim_count=self.affine_state.dim_count,
symbol_count=self.affine_state.symbol_count,
exprs=list())
def _normalize_affine_map(
self, affine_map: _ir.AffineMap, with_dims: bool = True
) -> _ir.AffineMap:
"""Normalizes an indexing map to have the max known symbols and dims."""
with self.context:
return _ir.AffineMap.get(
dim_count=self.affine_state.dim_count if with_dims else 0,
symbol_count=self.affine_state.symbol_count,
exprs=list(affine_map.results),
)
def _normalize_affine_map(self,
affine_map: _ir.AffineMap,
with_dims: bool = True) -> _ir.AffineMap:
"""Normalizes an indexing map to have the max known symbols and dims."""
with self.context:
return _ir.AffineMap.get(
dim_count=self.affine_state.dim_count if with_dims else 0,
symbol_count=self.affine_state.symbol_count,
exprs=list(affine_map.results))
def to_yaml_custom_dict(self):
self_dict = dict(args=self.ordered_operands)
# TODO: Refactor the hierarchy internally when supporting more
# than static (preserving this serialized form).
self_dict["indexing_maps"] = LinalgIndexingMapsConfig(
static_indexing_maps=self.indexing_maps
)
self_dict["iterator_types"] = self.iterator_types
self_dict["assignments"] = self.assignments
return self_dict
def to_yaml_custom_dict(self):
self_dict = dict(args=self.ordered_operands)
# TODO: Refactor the hierarchy internally when supporting more
# than static (preserving this serialized form).
self_dict["indexing_maps"] = LinalgIndexingMapsConfig(
static_indexing_maps=self.indexing_maps)
self_dict["iterator_types"] = self.iterator_types
self_dict["assignments"] = self.assignments
return self_dict
def __repr__(self):
lines = [f"LinalgGenericOpConfig(reduction_dims={self.reduction_dims},"]
lines.append("operands=[")
for def_config in self.ordered_operands:
lines.append(f" {repr(def_config)}")
lines.append("], indexing_maps=[")
for m in self.indexing_maps:
lines.append(f" {repr(m)}")
lines.append(f"], iterator_types=[")
for t in self.iterator_types:
lines.append(f" {t}")
lines.append("])")
return "\n".join(lines)
def __repr__(self):
lines = [f"LinalgGenericOpConfig(reduction_dims={self.reduction_dims},"]
lines.append("operands=[")
for def_config in self.ordered_operands:
lines.append(f" {repr(def_config)}")
lines.append("], indexing_maps=[")
for m in self.indexing_maps:
lines.append(f" {repr(m)}")
lines.append(f"], iterator_types=[")
for t in self.iterator_types:
lines.append(f" {t}")
lines.append("])")
return "\n".join(lines)
class LinalgOpConfig(YAMLObject):
"""Container for any supported linalg op type.
"""Container for any supported linalg op type.
This includes the concrete type by name for ease of parsing by systems
that ignore tags.
"""
yaml_tag = "!LinalgOpConfig"
This includes the concrete type by name for ease of parsing by systems
that ignore tags.
"""
def __init__(self,
metadata: OpMetadataDef,
*,
structured_op: Optional[LinalgStructuredOpConfig] = None):
self.metadata = metadata
self.structured_op = structured_op
yaml_tag = "!LinalgOpConfig"
def to_yaml_custom_dict(self):
self_dict = dict(metadata=self.metadata,)
if self.structured_op:
self_dict["structured_op"] = self.structured_op
return self_dict
def __init__(
self,
metadata: OpMetadataDef,
*,
structured_op: Optional[LinalgStructuredOpConfig] = None,
):
self.metadata = metadata
self.structured_op = structured_op
@staticmethod
def from_linalg_op_def(
op_def: LinalgOpDef,
context: Optional[_ir.Context] = None) -> Sequence["LinalgOpConfig"]:
"""Expands a LinalgOpDef into corresponding Linalg configured ops."""
# TODO: Many LinalgOpDef patterns need to expand to multiple generics.
assert len(op_def.comprehensions) == 1, "Only one comprehension supported"
return [
LinalgOpConfig(
op_def.metadata,
structured_op=LinalgStructuredOpConfig(
op_def.comprehensions[0], op_def.domain,
op_def.registered_operands.values(), context)),
]
def to_yaml_custom_dict(self):
self_dict = dict(
metadata=self.metadata,
)
if self.structured_op:
self_dict["structured_op"] = self.structured_op
return self_dict
def __repr__(self):
return (f"LinalgOpConfig(metadata={self.metadata},\n"
f"structured_op={self.structured_op})")
@staticmethod
def from_linalg_op_def(
op_def: LinalgOpDef, context: Optional[_ir.Context] = None
) -> Sequence["LinalgOpConfig"]:
"""Expands a LinalgOpDef into corresponding Linalg configured ops."""
# TODO: Many LinalgOpDef patterns need to expand to multiple generics.
assert len(op_def.comprehensions) == 1, "Only one comprehension supported"
return [
LinalgOpConfig(
op_def.metadata,
structured_op=LinalgStructuredOpConfig(
op_def.comprehensions[0],
op_def.domain,
op_def.registered_operands.values(),
context,
),
),
]
def __repr__(self):
return (
f"LinalgOpConfig(metadata={self.metadata},\n"
f"structured_op={self.structured_op})"
)
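# Illustrative sketch (assumption): producing the YAML form consumed by the op
# generator. `my_op` stands for any DefinedOpCallable built with the
# @linalg_structured_op decorator defined in the dsl module further down;
# as_linalg_yaml comes from yaml_helper.
configs = LinalgOpConfig.from_linalg_op_def(my_op.op_def)
print(configs[0].as_linalg_yaml())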


@@ -10,160 +10,192 @@ import inspect
import threading
from ..... import ir
from ...._ods_common import get_op_result_or_value as _get_op_result_or_value, get_op_results_or_values as _get_op_results_or_values
from ...._ods_common import (
get_op_result_or_value as _get_op_result_or_value,
get_op_results_or_values as _get_op_results_or_values,
)
from .comprehension import *
from .config import *
from .emitter import *
_CONTEXT = threading.local()
StructuredOpOuts = Union[ir.Operation, ir.OpView, ir.OpResultList,
Sequence[Union[ir.Value, ir.Operation, ir.OpView]]]
StructuredOpOuts = Union[
ir.Operation,
ir.OpView,
ir.OpResultList,
Sequence[Union[ir.Value, ir.Operation, ir.OpView]],
]
@contextmanager
def bind_op_def(op_def: LinalgOpDef):
if hasattr(_CONTEXT, "current_op_def"):
raise ValueError("Cannot recursively define an operation")
_CONTEXT.current_op_def = op_def
try:
yield op_def
finally:
del _CONTEXT.current_op_def
if hasattr(_CONTEXT, "current_op_def"):
raise ValueError("Cannot recursively define an operation")
_CONTEXT.current_op_def = op_def
try:
yield op_def
finally:
del _CONTEXT.current_op_def
def current_op_def() -> LinalgOpDef:
try:
return _CONTEXT.current_op_def
except AttributeError:
raise ValueError(
"Attempt to access the current op definition being defined "
"but none is set. Did you mean to call this in an op definition?")
try:
return _CONTEXT.current_op_def
except AttributeError:
raise ValueError(
"Attempt to access the current op definition being defined "
"but none is set. Did you mean to call this in an op definition?"
)
def _prepare_structured_op_outs(outs: StructuredOpOuts) -> ValueList:
if isinstance(outs, (ir.Operation, ir.OpView)):
return _get_op_results_or_values(outs)
elif isinstance(outs, ir.OpResultList):
return outs
if isinstance(outs, (ir.Operation, ir.OpView)):
return _get_op_results_or_values(outs)
elif isinstance(outs, ir.OpResultList):
return outs
return [_get_op_result_or_value(o) for o in outs]
return [_get_op_result_or_value(o) for o in outs]
class DefinedOpCallable:
"""Callable that wraps any defined op function."""
"""Callable that wraps any defined op function."""
def __init__(self, op_name: str, op_def: LinalgOpDef):
self.op_name = op_name
self.op_def = op_def
def __init__(self, op_name: str, op_def: LinalgOpDef):
self.op_name = op_name
self.op_def = op_def
def __call__(self, *ins: Union[ir.Operation, ir.OpView, ir.Value],
outs: StructuredOpOuts, **kwargs):
"""Emits the corresponding op definition as IR.
def __call__(
self,
*ins: Union[ir.Operation, ir.OpView, ir.Value],
outs: StructuredOpOuts,
**kwargs,
):
"""Emits the corresponding op definition as IR.
Most arguments are passed through to the underlying emitter. The following
keyword argument is interpreted here:
emit_generic: Emits a generic form as appropriate (default False). If
False, a named form is emitted (which must have been built in to the
compiler).
"""
emit_generic = kwargs.pop("emit_generic", False)
if not isinstance(emit_generic, bool):
raise ValueError(f"The named argument 'emit_generic' needs to be "
f" of type bool but got {type(emit_generic)}")
Most arguments are passed through to the underlying emitter. The following
keyword argument is interpreted here:
emit_generic: Emits a generic form as appropriate (default False). If
False, a named form is emitted (which must have been built in to the
compiler).
"""
emit_generic = kwargs.pop("emit_generic", False)
if not isinstance(emit_generic, bool):
raise ValueError(
f"The named argument 'emit_generic' needs to be "
f" of type bool but got {type(emit_generic)}"
)
op_configs = LinalgOpConfig.from_linalg_op_def(
self.op_def, context=ir.Context.current)
op_configs = LinalgOpConfig.from_linalg_op_def(
self.op_def, context=ir.Context.current
)
if len(op_configs) != 1:
# TODO: Support composite ops.
raise NotImplementedError(
f"Emission of composite linalg ops not supported: {op_configs}")
if len(op_configs) != 1:
# TODO: Support composite ops.
raise NotImplementedError(
f"Emission of composite linalg ops not supported: {op_configs}"
)
ctx = ir.Context.current
linalgDialect = ctx.get_dialect_descriptor("linalg")
fully_qualified_name = "linalg." + self.op_name
emit_generic = (
emit_generic or not ctx.is_registered_operation(fully_qualified_name))
ctx = ir.Context.current
linalgDialect = ctx.get_dialect_descriptor("linalg")
fully_qualified_name = "linalg." + self.op_name
emit_generic = emit_generic or not ctx.is_registered_operation(
fully_qualified_name
)
op_config = op_configs[0]
out_values = _prepare_structured_op_outs(outs)
in_values = [_get_op_result_or_value(i) for i in ins]
if op_config.structured_op:
if emit_generic:
return emit_generic_structured_op(
op_config.structured_op, *in_values, outs=out_values, **kwargs)
else:
return emit_named_structured_op(
op_config.structured_op,
self.op_name,
self.op_def.metadata.cpp_class_name,
*in_values,
outs=out_values,
**kwargs)
op_config = op_configs[0]
out_values = _prepare_structured_op_outs(outs)
in_values = [_get_op_result_or_value(i) for i in ins]
if op_config.structured_op:
if emit_generic:
return emit_generic_structured_op(
op_config.structured_op, *in_values, outs=out_values, **kwargs
)
else:
return emit_named_structured_op(
op_config.structured_op,
self.op_name,
self.op_def.metadata.cpp_class_name,
*in_values,
outs=out_values,
**kwargs,
)
raise NotImplementedError(
f"Emission of linalg op type not supported: {op_config}")
raise NotImplementedError(
f"Emission of linalg op type not supported: {op_config}"
)
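# Illustrative sketch (assumption): emitting a defined op inside an active
# ir.InsertionPoint/Location. `my_op` is a DefinedOpCallable; x, y, z are SSA
# tensor values whose types match the op definition.
result = my_op(x, y, outs=[z], emit_generic=True)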
def linalg_structured_op(dsl_func=None,
*,
op_name=None,
op_class_name=None) -> DefinedOpCallable:
if dsl_func is None:
# Curry the keyword args in for delayed application.
return functools.partial(
linalg_structured_op, op_name=op_name, op_class_name=op_class_name)
# Determine default names by introspecting the function.
if op_name is None:
op_name = dsl_func.__name__
if op_class_name is None:
# Camel case it.
op_class_name = f"{''.join(x.title() for x in op_name.split('_'))}Op"
def linalg_structured_op(
dsl_func=None, *, op_name=None, op_class_name=None
) -> DefinedOpCallable:
if dsl_func is None:
# Curry the keyword args in for delayed application.
return functools.partial(
linalg_structured_op, op_name=op_name, op_class_name=op_class_name
)
# Determine default names by introspecting the function.
if op_name is None:
op_name = dsl_func.__name__
if op_class_name is None:
# Camel case it.
op_class_name = f"{''.join(x.title() for x in op_name.split('_'))}Op"
op_def = LinalgOpDef(
name=op_name, cpp_class_name=op_class_name, doc=inspect.getdoc(dsl_func))
op_def = LinalgOpDef(
name=op_name, cpp_class_name=op_class_name, doc=inspect.getdoc(dsl_func)
)
# Extract arguments and TensorDefs from the signature.
dsl_func_args = list()
sig = inspect.signature(dsl_func)
for param_name, param in sig.parameters.items():
param_default = param.default
if isinstance(param_default,
(TensorDef, ScalarDef, IndexAttrDef, UnaryFnAttrDef,
BinaryFnAttrDef, TypeFnAttrDef)):
op_def.add_operand(param_name, param_default.operand_def)
else:
raise ValueError(
f"@linalg_structured_op function parameters must be defaulted as "
f"TensorDef(...), ScalarDef(...), or IndexAttrDef(...): "
f"Found {param_name}: {param_default}")
dsl_func_args.append(param_default)
# Extract arguments and TensorDefs from the signature.
dsl_func_args = list()
sig = inspect.signature(dsl_func)
for param_name, param in sig.parameters.items():
param_default = param.default
if isinstance(
param_default,
(
TensorDef,
ScalarDef,
IndexAttrDef,
UnaryFnAttrDef,
BinaryFnAttrDef,
TypeFnAttrDef,
),
):
op_def.add_operand(param_name, param_default.operand_def)
else:
raise ValueError(
f"@linalg_structured_op function parameters must be defaulted as "
f"TensorDef(...), ScalarDef(...), or IndexAttrDef(...): "
f"Found {param_name}: {param_default}"
)
dsl_func_args.append(param_default)
# Invoke the DSL func to finish populating the op definition.
with bind_op_def(op_def):
dsl_func(*dsl_func_args)
# Invoke the DSL func to finish populating the op definition.
with bind_op_def(op_def):
dsl_func(*dsl_func_args)
# TODO: The returned callable should be an IR emitter but that is not
# upstreamed yet.
return DefinedOpCallable(op_name, op_def)
# TODO: The returned callable should be an IR emitter but that is not
# upstreamed yet.
return DefinedOpCallable(op_name, op_def)
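# Illustrative sketch (assumption, not part of the diff): a minimal rank-1
# pointwise definition using the decorator. TensorDef, T, S, and D are the DSL
# symbols exported by the comprehension/types/affine modules.
@linalg_structured_op
def elemwise_add_example(
    lhs=TensorDef(T, S.N),
    rhs=TensorDef(T, S.N),
    out=TensorDef(T, S.N, output=True),
):
    out[D.n] = lhs[D.n] + rhs[D.n]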
def domain(*dimensions: DimDef):
if any(not isinstance(d, DimDef) for d in dimensions):
raise ValueError(f"Expected dimensions of type DimDef but got {dimensions}")
current_op_def().domain.extend(dimensions)
if any(not isinstance(d, DimDef) for d in dimensions):
raise ValueError(f"Expected dimensions of type DimDef but got {dimensions}")
current_op_def().domain.extend(dimensions)
def implements(*interfaces: OpInterfaceDef):
if any(not isinstance(intr, OpInterfaceDef) for intr in interfaces):
raise ValueError(
f"Expected interfaces of type OpInterfaceDef but got {interfaces}")
current_op_def().metadata.implements.extend(interfaces)
if any(not isinstance(intr, OpInterfaceDef) for intr in interfaces):
raise ValueError(
f"Expected interfaces of type OpInterfaceDef but got {interfaces}"
)
current_op_def().metadata.implements.extend(interfaces)
def defines(*definitions: OpDefinitionDef):
if any(not isinstance(defi, OpDefinitionDef) for defi in definitions):
raise ValueError(
f"Expected definitions of type OpDefinitionDef but got {definitions}")
current_op_def().metadata.defines.extend(definitions)
if any(not isinstance(defi, OpDefinitionDef) for defi in definitions):
raise ValueError(
f"Expected definitions of type OpDefinitionDef but got {definitions}"
)
current_op_def().metadata.defines.extend(definitions)


@@ -30,123 +30,137 @@ __all__ = [
class ScalarFn:
"""A type of ScalarExpression that applies a function."""
"""A type of ScalarExpression that applies a function."""
def __init__(self, kind: "FunctionKind", fn_name: Optional[str],
attr_name: Optional[str], type_var: Optional["TypeVar"],
operands: Sequence["ScalarExpression"]):
if bool(fn_name) + bool(attr_name) != 1:
raise ValueError("One of 'fn_name', 'attr_name' must be specified")
self.kind = kind
self.fn_name = fn_name
self.attr_name = attr_name
self.type_var = type_var
self.operands = operands
def __init__(
self,
kind: "FunctionKind",
fn_name: Optional[str],
attr_name: Optional[str],
type_var: Optional["TypeVar"],
operands: Sequence["ScalarExpression"],
):
if bool(fn_name) + bool(attr_name) != 1:
raise ValueError("One of 'fn_name', 'attr_name' must be specified")
self.kind = kind
self.fn_name = fn_name
self.attr_name = attr_name
self.type_var = type_var
self.operands = operands
def expr(self) -> "ScalarExpression":
return ScalarExpression(scalar_fn=self)
def expr(self) -> "ScalarExpression":
return ScalarExpression(scalar_fn=self)
def __repr__(self):
name = self.fn_name if self.fn_name else self.attr_name
return (f"ScalarFn<{self.kind.name}.{name}>(type_var={self.type_var}, "
f"operands=[{', '.join(self.operands)}])")
def __repr__(self):
name = self.fn_name if self.fn_name else self.attr_name
return (
f"ScalarFn<{self.kind.name}.{name}>(type_var={self.type_var}, "
f"operands=[{', '.join(self.operands)}])"
)
class ScalarArg:
"""A type of ScalarExpression that references a named argument."""
"""A type of ScalarExpression that references a named argument."""
def __init__(self, arg: str):
self.arg = arg
def __init__(self, arg: str):
self.arg = arg
def expr(self) -> "ScalarExpression":
return ScalarExpression(scalar_arg=self)
def expr(self) -> "ScalarExpression":
return ScalarExpression(scalar_arg=self)
def __repr__(self):
return f"(ScalarArg({self.arg})"
def __repr__(self):
return f"(ScalarArg({self.arg})"
class ScalarConst:
"""A type of ScalarExpression representing a constant."""
"""A type of ScalarExpression representing a constant."""
def __init__(self, value: str):
self.value = value
def __init__(self, value: str):
self.value = value
def expr(self) -> "ScalarExpression":
return ScalarExpression(scalar_const=self)
def expr(self) -> "ScalarExpression":
return ScalarExpression(scalar_const=self)
def __repr__(self):
return f"(ScalarConst({self.value})"
def __repr__(self):
return f"(ScalarConst({self.value})"
class ScalarIndex:
"""A type of ScalarExpression accessing an iteration index."""
"""A type of ScalarExpression accessing an iteration index."""
def __init__(self, dim: int):
self.dim = dim
def __init__(self, dim: int):
self.dim = dim
def expr(self) -> "ScalarExpression":
return ScalarExpression(scalar_index=self)
def expr(self) -> "ScalarExpression":
return ScalarExpression(scalar_index=self)
def __repr__(self):
return f"(ScalarIndex({self.dim})"
def __repr__(self):
return f"(ScalarIndex({self.dim})"
class ScalarExpression(YAMLObject):
"""An expression on scalar values.
"""An expression on scalar values.
Can be one of:
- ScalarFn
- ScalarArg
- ScalarConst
- ScalarIndex
"""
yaml_tag = "!ScalarExpression"
Can be one of:
- ScalarFn
- ScalarArg
- ScalarConst
- ScalarIndex
"""
def __init__(self,
scalar_fn: Optional[ScalarFn] = None,
scalar_arg: Optional[ScalarArg] = None,
scalar_const: Optional[ScalarConst] = None,
scalar_index: Optional[ScalarIndex] = None):
if (bool(scalar_fn) + bool(scalar_arg) + bool(scalar_const) +
bool(scalar_index)) != 1:
raise ValueError("One of 'scalar_fn', 'scalar_arg', 'scalar_const', or "
"'scalar_index' must be specified")
self.scalar_fn = scalar_fn
self.scalar_arg = scalar_arg
self.scalar_const = scalar_const
self.scalar_index = scalar_index
yaml_tag = "!ScalarExpression"
def to_yaml_custom_dict(self):
if self.scalar_fn:
scalar_fn_dict = dict(kind=self.scalar_fn.kind.name.lower())
if self.scalar_fn.fn_name:
scalar_fn_dict["fn_name"] = self.scalar_fn.fn_name
if self.scalar_fn.attr_name:
scalar_fn_dict["attr_name"] = self.scalar_fn.attr_name
if self.scalar_fn.type_var:
scalar_fn_dict["type_var"] = self.scalar_fn.type_var.name
scalar_fn_dict["operands"] = list(self.scalar_fn.operands)
return dict(scalar_fn=scalar_fn_dict)
elif self.scalar_arg:
return dict(scalar_arg=self.scalar_arg.arg)
elif self.scalar_const:
return dict(scalar_const=self.scalar_const.value)
elif self.scalar_index:
return dict(scalar_index=self.scalar_index.dim)
else:
raise ValueError(f"Unexpected ScalarExpression type: {self}")
def __init__(
self,
scalar_fn: Optional[ScalarFn] = None,
scalar_arg: Optional[ScalarArg] = None,
scalar_const: Optional[ScalarConst] = None,
scalar_index: Optional[ScalarIndex] = None,
):
if (
bool(scalar_fn) + bool(scalar_arg) + bool(scalar_const) + bool(scalar_index)
) != 1:
raise ValueError(
"One of 'scalar_fn', 'scalar_arg', 'scalar_const', or "
"'scalar_index' must be specified"
)
self.scalar_fn = scalar_fn
self.scalar_arg = scalar_arg
self.scalar_const = scalar_const
self.scalar_index = scalar_index
def to_yaml_custom_dict(self):
if self.scalar_fn:
scalar_fn_dict = dict(kind=self.scalar_fn.kind.name.lower())
if self.scalar_fn.fn_name:
scalar_fn_dict["fn_name"] = self.scalar_fn.fn_name
if self.scalar_fn.attr_name:
scalar_fn_dict["attr_name"] = self.scalar_fn.attr_name
if self.scalar_fn.type_var:
scalar_fn_dict["type_var"] = self.scalar_fn.type_var.name
scalar_fn_dict["operands"] = list(self.scalar_fn.operands)
return dict(scalar_fn=scalar_fn_dict)
elif self.scalar_arg:
return dict(scalar_arg=self.scalar_arg.arg)
elif self.scalar_const:
return dict(scalar_const=self.scalar_const.value)
elif self.scalar_index:
return dict(scalar_index=self.scalar_index.dim)
else:
raise ValueError(f"Unexpected ScalarExpression type: {self}")
class ScalarAssign(YAMLObject):
"""An assignment to a named argument (LHS of a comprehension)."""
yaml_tag = "!ScalarAssign"
"""An assignment to a named argument (LHS of a comprehension)."""
def __init__(self, arg: str, value: ScalarExpression):
self.arg = arg
self.value = value
yaml_tag = "!ScalarAssign"
def to_yaml_custom_dict(self):
return dict(arg=self.arg, value=self.value)
def __init__(self, arg: str, value: ScalarExpression):
self.arg = arg
self.value = value
def __repr__(self):
return f"ScalarAssign({self.arg}, {self.value})"
def to_yaml_custom_dict(self):
return dict(arg=self.arg, value=self.value)
def __repr__(self):
return f"ScalarAssign({self.arg}, {self.value})"


@@ -21,13 +21,11 @@ from typing import Dict
__all__ = [
"TypeVar",
"TV",
# Predefined types.
"I32",
"I64",
"F32",
"F64",
# TypeVar aliases.
"T",
"U",
@@ -36,34 +34,34 @@ __all__ = [
class TypeVar:
"""A replaceable type variable.
"""A replaceable type variable.
Type variables are uniqued by name.
"""
ALL_TYPEVARS = dict() # type: Dict[str, "TypeVar"]
Type variables are uniqued by name.
"""
def __new__(cls, name: str):
existing = cls.ALL_TYPEVARS.get(name)
if existing is not None:
return existing
new = super().__new__(cls)
new.name = name
cls.ALL_TYPEVARS[name] = new
return new
ALL_TYPEVARS = dict() # type: Dict[str, "TypeVar"]
def __repr__(self):
return f"TypeVar({self.name})"
def __new__(cls, name: str):
existing = cls.ALL_TYPEVARS.get(name)
if existing is not None:
return existing
new = super().__new__(cls)
new.name = name
cls.ALL_TYPEVARS[name] = new
return new
@classmethod
def create_expando(cls):
"""Create an expando class that creates unique type vars on attr access."""
def __repr__(self):
return f"TypeVar({self.name})"
class ExpandoTypeVars:
@classmethod
def create_expando(cls):
"""Create an expando class that creates unique type vars on attr access."""
def __getattr__(self, n):
return cls(n)
class ExpandoTypeVars:
def __getattr__(self, n):
return cls(n)
return ExpandoTypeVars()
return ExpandoTypeVars()
# Expando access via TV.foo
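# Illustrative sketch (assumption): type variables are uniqued by name and the
# expando gives attribute-style access; the real module binds it as `TV`.
TV = TypeVar.create_expando()
assert TV.T is TypeVar("T")
print(TV.T)  # TypeVar(T)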


@@ -6,11 +6,12 @@
import sys
try:
import yaml
import yaml
except ModuleNotFoundError as e:
raise ModuleNotFoundError(
f"This tool requires PyYAML but it was not installed. "
f"Recommend: {sys.executable} -m pip install PyYAML") from e
raise ModuleNotFoundError(
f"This tool requires PyYAML but it was not installed. "
f"Recommend: {sys.executable} -m pip install PyYAML"
) from e
__all__ = [
"yaml_dump",
@@ -20,35 +21,33 @@ __all__ = [
class YAMLObject(yaml.YAMLObject):
@classmethod
def to_yaml(cls, dumper, self):
"""Default to a custom dictionary mapping."""
return dumper.represent_mapping(cls.yaml_tag, self.to_yaml_custom_dict())
@classmethod
def to_yaml(cls, dumper, self):
"""Default to a custom dictionary mapping."""
return dumper.represent_mapping(cls.yaml_tag, self.to_yaml_custom_dict())
def to_yaml_custom_dict(self):
raise NotImplementedError()
def to_yaml_custom_dict(self):
raise NotImplementedError()
def as_linalg_yaml(self):
return yaml_dump(self)
def as_linalg_yaml(self):
return yaml_dump(self)
def multiline_str_representer(dumper, data):
if len(data.splitlines()) > 1:
return dumper.represent_scalar('tag:yaml.org,2002:str', data, style='|')
else:
return dumper.represent_scalar('tag:yaml.org,2002:str', data)
if len(data.splitlines()) > 1:
return dumper.represent_scalar("tag:yaml.org,2002:str", data, style="|")
else:
return dumper.represent_scalar("tag:yaml.org,2002:str", data)
yaml.add_representer(str, multiline_str_representer)
def yaml_dump(data, sort_keys=False, **kwargs):
return yaml.dump(data, sort_keys=sort_keys, **kwargs)
return yaml.dump(data, sort_keys=sort_keys, **kwargs)
def yaml_dump_all(data, sort_keys=False, explicit_start=True, **kwargs):
return yaml.dump_all(data,
sort_keys=sort_keys,
explicit_start=explicit_start,
**kwargs)
return yaml.dump_all(
data, sort_keys=sort_keys, explicit_start=explicit_start, **kwargs
)
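# Illustrative sketch (assumption, not part of the diff): the minimal contract a
# serializable config object has to satisfy.
class ExampleConfig(YAMLObject):
    yaml_tag = "!ExampleConfig"

    def to_yaml_custom_dict(self):
        return dict(name="example", rank=2)

print(yaml_dump(ExampleConfig()))
# roughly:
# !ExampleConfig
# name: example
# rank: 2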


@@ -5,6 +5,8 @@
from ._python_test_ops_gen import *
from .._mlir_libs._mlirPythonTest import TestAttr, TestType, TestTensorValue, TestTensorType
def register_python_test_dialect(context, load=True):
from .._mlir_libs import _mlirPythonTest
_mlirPythonTest.register_python_test_dialect(context, load)
from .._mlir_libs import _mlirPythonTest
_mlirPythonTest.register_python_test_dialect(context, load)


@@ -6,16 +6,18 @@ from enum import Enum
class FailurePropagationMode(Enum):
"""Propagation mode for silenceable errors."""
PROPAGATE = 1
SUPPRESS = 2
"""Propagation mode for silenceable errors."""
def _as_int(self):
if self is FailurePropagationMode.PROPAGATE:
return 1
PROPAGATE = 1
SUPPRESS = 2
def _as_int(self):
if self is FailurePropagationMode.PROPAGATE:
return 1
assert self is FailurePropagationMode.SUPPRESS
return 2
assert self is FailurePropagationMode.SUPPRESS
return 2
from .._transform_ops_gen import *
from ..._mlir_libs._mlirDialectsTransform import *


@@ -7,37 +7,37 @@ from ._mlir_libs import _mlirExecutionEngine as _execution_engine
import ctypes
__all__ = [
"ExecutionEngine",
"ExecutionEngine",
]
class ExecutionEngine(_execution_engine.ExecutionEngine):
def lookup(self, name):
"""Lookup a function emitted with the `llvm.emit_c_interface`
attribute and returns a ctype callable.
Raise a RuntimeError if the function isn't found.
"""
func = self.raw_lookup("_mlir_ciface_" + name)
if not func:
raise RuntimeError("Unknown function " + name)
prototype = ctypes.CFUNCTYPE(None, ctypes.c_void_p)
return prototype(func)
def lookup(self, name):
"""Lookup a function emitted with the `llvm.emit_c_interface`
attribute and returns a ctype callable.
Raise a RuntimeError if the function isn't found.
"""
func = self.raw_lookup("_mlir_ciface_" + name)
if not func:
raise RuntimeError("Unknown function " + name)
prototype = ctypes.CFUNCTYPE(None, ctypes.c_void_p)
return prototype(func)
def invoke(self, name, *ctypes_args):
"""Invoke a function with the list of ctypes arguments.
All arguments must be pointers.
Raise a RuntimeError if the function isn't found.
"""
func = self.lookup(name)
packed_args = (ctypes.c_void_p * len(ctypes_args))()
for argNum in range(len(ctypes_args)):
packed_args[argNum] = ctypes.cast(ctypes_args[argNum], ctypes.c_void_p)
func(packed_args)
def invoke(self, name, *ctypes_args):
"""Invoke a function with the list of ctypes arguments.
All arguments must be pointers.
Raise a RuntimeError if the function isn't found.
"""
func = self.lookup(name)
packed_args = (ctypes.c_void_p * len(ctypes_args))()
for argNum in range(len(ctypes_args)):
packed_args[argNum] = ctypes.cast(ctypes_args[argNum], ctypes.c_void_p)
func(packed_args)
def register_runtime(self, name, ctypes_callback):
"""Register a runtime function available to the jitted code
under the provided `name`. The `ctypes_callback` must be a
`CFuncType` that outlives the execution engine.
"""
callback = ctypes.cast(ctypes_callback, ctypes.c_void_p)
self.raw_register_runtime("_mlir_ciface_" + name, callback)
def register_runtime(self, name, ctypes_callback):
"""Register a runtime function available to the jitted code
under the provided `name`. The `ctypes_callback` must be a
`CFuncType` that outlives the execution engine.
"""
callback = ctypes.cast(ctypes_callback, ctypes.c_void_p)
self.raw_register_runtime("_mlir_ciface_" + name, callback)


@@ -8,124 +8,123 @@ from ._mlir_libs._mlir.ir import _GlobalDebug
# Convenience decorator for registering user-friendly Attribute builders.
def register_attribute_builder(kind):
def decorator_builder(func):
AttrBuilder.insert(kind, func)
return func
def decorator_builder(func):
AttrBuilder.insert(kind, func)
return func
return decorator_builder
return decorator_builder
@register_attribute_builder("BoolAttr")
def _boolAttr(x, context):
return BoolAttr.get(x, context=context)
return BoolAttr.get(x, context=context)
@register_attribute_builder("IndexAttr")
def _indexAttr(x, context):
return IntegerAttr.get(IndexType.get(context=context), x)
return IntegerAttr.get(IndexType.get(context=context), x)
@register_attribute_builder("I16Attr")
def _i16Attr(x, context):
return IntegerAttr.get(IntegerType.get_signless(16, context=context), x)
return IntegerAttr.get(IntegerType.get_signless(16, context=context), x)
@register_attribute_builder("I32Attr")
def _i32Attr(x, context):
return IntegerAttr.get(IntegerType.get_signless(32, context=context), x)
return IntegerAttr.get(IntegerType.get_signless(32, context=context), x)
@register_attribute_builder("I64Attr")
def _i64Attr(x, context):
return IntegerAttr.get(IntegerType.get_signless(64, context=context), x)
return IntegerAttr.get(IntegerType.get_signless(64, context=context), x)
@register_attribute_builder("SI16Attr")
def _si16Attr(x, context):
return IntegerAttr.get(IntegerType.get_signed(16, context=context), x)
return IntegerAttr.get(IntegerType.get_signed(16, context=context), x)
@register_attribute_builder("SI32Attr")
def _si32Attr(x, context):
return IntegerAttr.get(IntegerType.get_signed(32, context=context), x)
return IntegerAttr.get(IntegerType.get_signed(32, context=context), x)
@register_attribute_builder("F32Attr")
def _f32Attr(x, context):
return FloatAttr.get_f32(x, context=context)
return FloatAttr.get_f32(x, context=context)
@register_attribute_builder("F64Attr")
def _f64Attr(x, context):
return FloatAttr.get_f64(x, context=context)
return FloatAttr.get_f64(x, context=context)
@register_attribute_builder("StrAttr")
def _stringAttr(x, context):
return StringAttr.get(x, context=context)
return StringAttr.get(x, context=context)
@register_attribute_builder("SymbolNameAttr")
def _symbolNameAttr(x, context):
return StringAttr.get(x, context=context)
return StringAttr.get(x, context=context)
@register_attribute_builder("SymbolRefAttr")
def _symbolRefAttr(x, context):
return FlatSymbolRefAttr.get(x, context=context)
return FlatSymbolRefAttr.get(x, context=context)
@register_attribute_builder("ArrayAttr")
def _arrayAttr(x, context):
return ArrayAttr.get(x, context=context)
return ArrayAttr.get(x, context=context)
@register_attribute_builder("I32ArrayAttr")
def _i32ArrayAttr(x, context):
return ArrayAttr.get([_i32Attr(v, context) for v in x])
return ArrayAttr.get([_i32Attr(v, context) for v in x])
@register_attribute_builder("I64ArrayAttr")
def _i64ArrayAttr(x, context):
return ArrayAttr.get([_i64Attr(v, context) for v in x])
return ArrayAttr.get([_i64Attr(v, context) for v in x])
@register_attribute_builder("F32ArrayAttr")
def _f32ArrayAttr(x, context):
return ArrayAttr.get([_f32Attr(v, context) for v in x])
return ArrayAttr.get([_f32Attr(v, context) for v in x])
@register_attribute_builder("F64ArrayAttr")
def _f64ArrayAttr(x, context):
return ArrayAttr.get([_f64Attr(v, context) for v in x])
return ArrayAttr.get([_f64Attr(v, context) for v in x])
@register_attribute_builder("DenseI64ArrayAttr")
def _denseI64ArrayAttr(x, context):
return DenseI64ArrayAttr.get(x, context=context)
return DenseI64ArrayAttr.get(x, context=context)
@register_attribute_builder("TypeAttr")
def _typeAttr(x, context):
return TypeAttr.get(x, context=context)
return TypeAttr.get(x, context=context)
@register_attribute_builder("TypeArrayAttr")
def _typeArrayAttr(x, context):
return _arrayAttr([TypeAttr.get(t, context=context) for t in x], context)
return _arrayAttr([TypeAttr.get(t, context=context) for t in x], context)
try:
import numpy as np
import numpy as np
@register_attribute_builder("IndexElementsAttr")
def _indexElementsAttr(x, context):
return DenseElementsAttr.get(
np.array(x, dtype=np.int64),
type=IndexType.get(context=context),
context=context,
)
@register_attribute_builder("IndexElementsAttr")
def _indexElementsAttr(x, context):
return DenseElementsAttr.get(
np.array(x, dtype=np.int64),
type=IndexType.get(context=context),
context=context,
)
except ImportError:
pass
pass
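The same registration pattern extends to user code; a minimal sketch, assuming an "SI64Attr" kind (illustrative, not added by this commit):
@register_attribute_builder("SI64Attr")  # illustrative kind name, for demonstration only
def _si64Attr(x, context):
    return IntegerAttr.get(IntegerType.get_signed(64, context=context), x)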

View File

@@ -9,131 +9,134 @@ import ctypes
class C128(ctypes.Structure):
"""A ctype representation for MLIR's Double Complex."""
_fields_ = [("real", ctypes.c_double), ("imag", ctypes.c_double)]
"""A ctype representation for MLIR's Double Complex."""
_fields_ = [("real", ctypes.c_double), ("imag", ctypes.c_double)]
class C64(ctypes.Structure):
"""A ctype representation for MLIR's Float Complex."""
_fields_ = [("real", ctypes.c_float), ("imag", ctypes.c_float)]
"""A ctype representation for MLIR's Float Complex."""
_fields_ = [("real", ctypes.c_float), ("imag", ctypes.c_float)]
class F16(ctypes.Structure):
"""A ctype representation for MLIR's Float16."""
_fields_ = [("f16", ctypes.c_int16)]
"""A ctype representation for MLIR's Float16."""
_fields_ = [("f16", ctypes.c_int16)]
# https://stackoverflow.com/questions/26921836/correct-way-to-test-for-numpy-dtype
def as_ctype(dtp):
"""Converts dtype to ctype."""
if dtp == np.dtype(np.complex128):
return C128
if dtp == np.dtype(np.complex64):
return C64
if dtp == np.dtype(np.float16):
return F16
return np.ctypeslib.as_ctypes_type(dtp)
"""Converts dtype to ctype."""
if dtp == np.dtype(np.complex128):
return C128
if dtp == np.dtype(np.complex64):
return C64
if dtp == np.dtype(np.float16):
return F16
return np.ctypeslib.as_ctypes_type(dtp)
def to_numpy(array):
"""Converts ctypes array back to numpy dtype array."""
if array.dtype == C128:
return array.view("complex128")
if array.dtype == C64:
return array.view("complex64")
if array.dtype == F16:
return array.view("float16")
return array
"""Converts ctypes array back to numpy dtype array."""
if array.dtype == C128:
return array.view("complex128")
if array.dtype == C64:
return array.view("complex64")
if array.dtype == F16:
return array.view("float16")
return array
def make_nd_memref_descriptor(rank, dtype):
class MemRefDescriptor(ctypes.Structure):
"""Builds an empty descriptor for the given rank/dtype, where rank>0."""
class MemRefDescriptor(ctypes.Structure):
"""Builds an empty descriptor for the given rank/dtype, where rank>0."""
_fields_ = [
("allocated", ctypes.c_longlong),
("aligned", ctypes.POINTER(dtype)),
("offset", ctypes.c_longlong),
("shape", ctypes.c_longlong * rank),
("strides", ctypes.c_longlong * rank),
]
_fields_ = [
("allocated", ctypes.c_longlong),
("aligned", ctypes.POINTER(dtype)),
("offset", ctypes.c_longlong),
("shape", ctypes.c_longlong * rank),
("strides", ctypes.c_longlong * rank),
]
return MemRefDescriptor
return MemRefDescriptor
def make_zero_d_memref_descriptor(dtype):
class MemRefDescriptor(ctypes.Structure):
"""Builds an empty descriptor for the given dtype, where rank=0."""
class MemRefDescriptor(ctypes.Structure):
"""Builds an empty descriptor for the given dtype, where rank=0."""
_fields_ = [
("allocated", ctypes.c_longlong),
("aligned", ctypes.POINTER(dtype)),
("offset", ctypes.c_longlong),
]
_fields_ = [
("allocated", ctypes.c_longlong),
("aligned", ctypes.POINTER(dtype)),
("offset", ctypes.c_longlong),
]
return MemRefDescriptor
return MemRefDescriptor
class UnrankedMemRefDescriptor(ctypes.Structure):
"""Creates a ctype struct for memref descriptor"""
_fields_ = [("rank", ctypes.c_longlong), ("descriptor", ctypes.c_void_p)]
"""Creates a ctype struct for memref descriptor"""
_fields_ = [("rank", ctypes.c_longlong), ("descriptor", ctypes.c_void_p)]
def get_ranked_memref_descriptor(nparray):
"""Returns a ranked memref descriptor for the given numpy array."""
ctp = as_ctype(nparray.dtype)
if nparray.ndim == 0:
x = make_zero_d_memref_descriptor(ctp)()
"""Returns a ranked memref descriptor for the given numpy array."""
ctp = as_ctype(nparray.dtype)
if nparray.ndim == 0:
x = make_zero_d_memref_descriptor(ctp)()
x.allocated = nparray.ctypes.data
x.aligned = nparray.ctypes.data_as(ctypes.POINTER(ctp))
x.offset = ctypes.c_longlong(0)
return x
x = make_nd_memref_descriptor(nparray.ndim, ctp)()
x.allocated = nparray.ctypes.data
x.aligned = nparray.ctypes.data_as(ctypes.POINTER(ctp))
x.offset = ctypes.c_longlong(0)
x.shape = nparray.ctypes.shape
# Numpy uses byte quantities to express strides, MLIR OTOH uses the
# torch abstraction which specifies strides in terms of elements.
strides_ctype_t = ctypes.c_longlong * nparray.ndim
x.strides = strides_ctype_t(*[x // nparray.itemsize for x in nparray.strides])
return x
x = make_nd_memref_descriptor(nparray.ndim, ctp)()
x.allocated = nparray.ctypes.data
x.aligned = nparray.ctypes.data_as(ctypes.POINTER(ctp))
x.offset = ctypes.c_longlong(0)
x.shape = nparray.ctypes.shape
# Numpy uses byte quantities to express strides, MLIR OTOH uses the
# torch abstraction which specifies strides in terms of elements.
strides_ctype_t = ctypes.c_longlong * nparray.ndim
x.strides = strides_ctype_t(*[x // nparray.itemsize for x in nparray.strides])
return x
def get_unranked_memref_descriptor(nparray):
"""Returns a generic/unranked memref descriptor for the given numpy array."""
d = UnrankedMemRefDescriptor()
d.rank = nparray.ndim
x = get_ranked_memref_descriptor(nparray)
d.descriptor = ctypes.cast(ctypes.pointer(x), ctypes.c_void_p)
return d
"""Returns a generic/unranked memref descriptor for the given numpy array."""
d = UnrankedMemRefDescriptor()
d.rank = nparray.ndim
x = get_ranked_memref_descriptor(nparray)
d.descriptor = ctypes.cast(ctypes.pointer(x), ctypes.c_void_p)
return d
def unranked_memref_to_numpy(unranked_memref, np_dtype):
"""Converts unranked memrefs to numpy arrays."""
ctp = as_ctype(np_dtype)
descriptor = make_nd_memref_descriptor(unranked_memref[0].rank, ctp)
val = ctypes.cast(unranked_memref[0].descriptor, ctypes.POINTER(descriptor))
np_arr = np.ctypeslib.as_array(val[0].aligned, shape=val[0].shape)
strided_arr = np.lib.stride_tricks.as_strided(
np_arr,
np.ctypeslib.as_array(val[0].shape),
np.ctypeslib.as_array(val[0].strides) * np_arr.itemsize,
)
return to_numpy(strided_arr)
"""Converts unranked memrefs to numpy arrays."""
ctp = as_ctype(np_dtype)
descriptor = make_nd_memref_descriptor(unranked_memref[0].rank, ctp)
val = ctypes.cast(unranked_memref[0].descriptor, ctypes.POINTER(descriptor))
np_arr = np.ctypeslib.as_array(val[0].aligned, shape=val[0].shape)
strided_arr = np.lib.stride_tricks.as_strided(
np_arr,
np.ctypeslib.as_array(val[0].shape),
np.ctypeslib.as_array(val[0].strides) * np_arr.itemsize,
)
return to_numpy(strided_arr)
def ranked_memref_to_numpy(ranked_memref):
"""Converts ranked memrefs to numpy arrays."""
np_arr = np.ctypeslib.as_array(
ranked_memref[0].aligned, shape=ranked_memref[0].shape)
strided_arr = np.lib.stride_tricks.as_strided(
np_arr,
np.ctypeslib.as_array(ranked_memref[0].shape),
np.ctypeslib.as_array(ranked_memref[0].strides) * np_arr.itemsize,
)
return to_numpy(strided_arr)
"""Converts ranked memrefs to numpy arrays."""
np_arr = np.ctypeslib.as_array(
ranked_memref[0].aligned, shape=ranked_memref[0].shape
)
strided_arr = np.lib.stride_tricks.as_strided(
np_arr,
np.ctypeslib.as_array(ranked_memref[0].shape),
np.ctypeslib.as_array(ranked_memref[0].strides) * np_arr.itemsize,
)
return to_numpy(strided_arr)
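A hedged, self-contained round trip through the helpers above (names as defined in this module; only ctypes and NumPy are needed, no JIT):
import ctypes
import numpy as np
a = np.arange(6, dtype=np.float64).reshape(2, 3)
# Ranked: build a descriptor viewing `a`, wrap it the way ExecutionEngine.invoke()
# expects, then recover a NumPy view of the same buffer.
mem_a = ctypes.pointer(ctypes.pointer(get_ranked_memref_descriptor(a)))
b = ranked_memref_to_numpy(mem_a[0])
assert np.array_equal(a, b)
# Unranked descriptors carry only the rank and an opaque pointer, which is why
# unranked_memref_to_numpy() additionally needs the element dtype.
d = get_unranked_memref_descriptor(a)
assert d.rank == 2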

View File

@@ -1 +1 @@
config.suffixes.add('.c')
config.suffixes.add(".c")

View File

@@ -1,2 +1,2 @@
if not config.run_cuda_tests:
config.unsupported = True
config.unsupported = True

View File

@@ -1,2 +1,2 @@
if not config.run_rocm_tests:
config.unsupported = True
config.unsupported = True

View File

@@ -1,5 +1,3 @@
# Requires native execution.
if 'host-supports-jit' not in config.available_features:
if "host-supports-jit" not in config.available_features:
config.unsupported = True

View File

@@ -1,5 +1,3 @@
# Requires native execution.
if 'host-supports-jit' not in config.available_features:
if "host-supports-jit" not in config.available_features:
config.unsupported = True

View File

@@ -1,2 +1,2 @@
if not config.build_examples:
config.unsupported = True
config.unsupported = True

View File

@@ -1,13 +1,12 @@
# Disable with sanitizers for now, this require some more setup apparently.
for san in ['asan', 'msan', 'ubsan']:
if (san in config.available_features):
config.unsupported = True
for san in ["asan", "msan", "ubsan"]:
if san in config.available_features:
config.unsupported = True
config.substitutions.append(("%cmake_exe", config.host_cmake))
config.substitutions.append(("%cmake_generator", config.host_cmake_generator))
config.substitutions.append(("%host_cxx", config.host_cxx))
config.substitutions.append(("%host_cc", config.host_cc))
config.substitutions.append(("%enable_libcxx", config.enable_libcxx))
config.substitutions.append(
("%mlir_cmake_dir", config.mlir_cmake_dir))
config.substitutions.append(("%mlir_cmake_dir", config.mlir_cmake_dir))
config.substitutions.append(("%llvm_use_linker", config.llvm_use_linker))

View File

@@ -1,5 +1,5 @@
import sys
# Windows does not have aligned_alloc
if sys.platform == 'win32':
if sys.platform == "win32":
config.unsupported = True

View File

@@ -1,4 +1,4 @@
import platform
if platform.machine() != 'x86_64':
if platform.machine() != "x86_64":
config.unsupported = True

View File

@@ -1,18 +1,22 @@
import sys
lli_cmd = 'lli'
lli_cmd = "lli"
if config.riscv_emulator_lli_executable:
lli_cmd = config.riscv_emulator_lli_executable
config.substitutions.append(('%mlir_native_utils_lib_dir',
config.riscv_emulator_utils_lib_dir or config.mlir_lib_dir))
config.substitutions.append(
(
"%mlir_native_utils_lib_dir",
config.riscv_emulator_utils_lib_dir or config.mlir_lib_dir,
)
)
if config.riscv_vector_emulator_executable:
# Run test in qemu emulator.
emulation_cmd = config.riscv_vector_emulator_executable
if config.riscv_vector_emulator_options:
emulation_cmd = emulation_cmd + ' ' + config.riscv_vector_emulator_options
emulation_cmd = emulation_cmd + ' ' + lli_cmd + ' --march=riscv64 -mattr=+v '
config.substitutions.append(('%lli', emulation_cmd))
emulation_cmd = emulation_cmd + " " + config.riscv_vector_emulator_options
emulation_cmd = emulation_cmd + " " + lli_cmd + " --march=riscv64 -mattr=+v "
config.substitutions.append(("%lli", emulation_cmd))
else:
config.substitutions.append(('%lli', lli_cmd))
config.substitutions.append(("%lli", lli_cmd))

View File

@@ -2,14 +2,16 @@ import sys
from lit.llvm import llvm_config
# FIXME: %mlir_native_utils_lib_dir is set incorrectly on Windows
if sys.platform == 'win32':
if sys.platform == "win32":
config.unsupported = True
# ArmSVE tests must be enabled via build flag.
if config.mlir_run_arm_sve_tests:
config.substitutions.append(('%ENABLE_VLA', 'true'))
config.substitutions.append(('%VLA_ARCH_ATTR_OPTIONS', '--march=aarch64 --mattr="+sve"'))
config.substitutions.append(("%ENABLE_VLA", "true"))
config.substitutions.append(
("%VLA_ARCH_ATTR_OPTIONS", '--march=aarch64 --mattr="+sve"')
)
else:
config.substitutions.append(('%ENABLE_VLA', 'false'))
config.substitutions.append(('%VLA_ARCH_ATTR_OPTIONS', ''))
config.substitutions.append(('%mlir_native_utils_lib_dir', config.mlir_lib_dir))
config.substitutions.append(("%ENABLE_VLA", "false"))
config.substitutions.append(("%VLA_ARCH_ATTR_OPTIONS", ""))
config.substitutions.append(("%mlir_native_utils_lib_dir", config.mlir_lib_dir))

View File

@@ -1,2 +1,2 @@
if not config.enable_cuda_runner or not config.mlir_run_cuda_sm80_tests:
config.unsupported = True
config.unsupported = True

View File

@@ -1,5 +1,5 @@
# Disable ASAN's leak detection for python OpsDSL tests.
config.environment['ASAN_OPTIONS'] = 'detect_leaks=0'
config.environment["ASAN_OPTIONS"] = "detect_leaks=0"
# Only run when python bindings are enabled.
if not config.enable_bindings_python:
config.unsupported = True
config.unsupported = True

View File

@@ -18,42 +18,45 @@ _SCRIPT_PATH = os.path.dirname(os.path.abspath(__file__))
sys.path.append(_SCRIPT_PATH)
from tools import sparse_compiler
@dsl.linalg_structured_op
def sddmm_dsl(
A=dsl.TensorDef(dsl.T, dsl.S.M, dsl.S.K),
B=dsl.TensorDef(dsl.T, dsl.S.K, dsl.S.N),
S=dsl.TensorDef(dsl.T, dsl.S.M, dsl.S.N),
C=dsl.TensorDef(dsl.T, dsl.S.M, dsl.S.N, output=True)):
C[dsl.D.m,
dsl.D.n] += S[dsl.D.m, dsl.D.n] * A[dsl.D.m, dsl.D.k] * B[dsl.D.k, dsl.D.n]
C=dsl.TensorDef(dsl.T, dsl.S.M, dsl.S.N, output=True),
):
C[dsl.D.m, dsl.D.n] += (
S[dsl.D.m, dsl.D.n] * A[dsl.D.m, dsl.D.k] * B[dsl.D.k, dsl.D.n]
)
def build_SDDMM(attr: st.EncodingAttr):
"""Build SDDMM kernel.
"""Build SDDMM kernel.
This method generates a linalg op for matrix multiplication using
just the Python API. Effectively, a generic linalg op is constructed
that computes C(i,j) += S(i,j) SUM_k A(i,k) B(k,j) for sparse S.
"""
module = ir.Module.create()
f64 = ir.F64Type.get()
a = ir.RankedTensorType.get([8, 8], f64)
b = ir.RankedTensorType.get([8, 8], f64)
c = ir.RankedTensorType.get([8, 8], f64)
s = ir.RankedTensorType.get([8, 8], f64, attr)
arguments = [a, b, s, c]
with ir.InsertionPoint(module.body):
This method generates a linalg op for matrix multiplication using
just the Python API. Effectively, a generic linalg op is constructed
that computes C(i,j) += S(i,j) SUM_k A(i,k) B(k,j) for sparse S.
"""
module = ir.Module.create()
f64 = ir.F64Type.get()
a = ir.RankedTensorType.get([8, 8], f64)
b = ir.RankedTensorType.get([8, 8], f64)
c = ir.RankedTensorType.get([8, 8], f64)
s = ir.RankedTensorType.get([8, 8], f64, attr)
arguments = [a, b, s, c]
with ir.InsertionPoint(module.body):
@func.FuncOp.from_py_func(*arguments)
def sddmm(*args):
return sddmm_dsl(args[0], args[1], args[2], outs=[args[3]])
@func.FuncOp.from_py_func(*arguments)
def sddmm(*args):
return sddmm_dsl(args[0], args[1], args[2], outs=[args[3]])
return module
return module
def boilerplate(attr: st.EncodingAttr):
"""Returns boilerplate code for main driver."""
return f"""
"""Returns boilerplate code for main driver."""
return f"""
func.func @main(%a: tensor<8x8xf64>,
%b: tensor<8x8xf64>,
%c: tensor<8x8xf64>) -> tensor<8x8xf64> attributes {{ llvm.emit_c_interface }} {{
@@ -69,92 +72,100 @@ func.func @main(%a: tensor<8x8xf64>,
def build_compile_and_run_SDDMMM(attr: st.EncodingAttr, compiler):
# Build.
module = build_SDDMM(attr)
func = str(module.operation.regions[0].blocks[0].operations[0].operation)
module = ir.Module.parse(func + boilerplate(attr))
# Build.
module = build_SDDMM(attr)
func = str(module.operation.regions[0].blocks[0].operations[0].operation)
module = ir.Module.parse(func + boilerplate(attr))
# Compile.
engine = compiler.compile_and_jit(module)
# Compile.
engine = compiler.compile_and_jit(module)
# Set up numpy input and buffer for output.
a = np.array([[1.1, 2.1, 3.1, 4.1, 5.1, 6.1, 7.1, 8.1],
[1.2, 2.2, 3.2, 4.2, 5.2, 6.2, 7.2, 8.2],
[1.3, 2.3, 3.3, 4.3, 5.3, 6.3, 7.3, 8.3],
[1.4, 2.4, 3.4, 4.4, 5.4, 6.4, 7.4, 8.4],
[1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.5],
[1.6, 2.6, 3.6, 4.6, 5.6, 6.6, 7.6, 8.6],
[1.7, 2.7, 3.7, 4.7, 5.7, 6.7, 7.7, 8.7],
[1.8, 2.8, 3.8, 4.8, 5.8, 6.8, 7.8, 8.8]], np.float64)
b = np.ones((8, 8), np.float64)
c = np.zeros((8, 8), np.float64)
# Set up numpy input and buffer for output.
a = np.array(
[
[1.1, 2.1, 3.1, 4.1, 5.1, 6.1, 7.1, 8.1],
[1.2, 2.2, 3.2, 4.2, 5.2, 6.2, 7.2, 8.2],
[1.3, 2.3, 3.3, 4.3, 5.3, 6.3, 7.3, 8.3],
[1.4, 2.4, 3.4, 4.4, 5.4, 6.4, 7.4, 8.4],
[1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.5],
[1.6, 2.6, 3.6, 4.6, 5.6, 6.6, 7.6, 8.6],
[1.7, 2.7, 3.7, 4.7, 5.7, 6.7, 7.7, 8.7],
[1.8, 2.8, 3.8, 4.8, 5.8, 6.8, 7.8, 8.8],
],
np.float64,
)
b = np.ones((8, 8), np.float64)
c = np.zeros((8, 8), np.float64)
mem_a = ctypes.pointer(ctypes.pointer(rt.get_ranked_memref_descriptor(a)))
mem_b = ctypes.pointer(ctypes.pointer(rt.get_ranked_memref_descriptor(b)))
mem_c = ctypes.pointer(ctypes.pointer(rt.get_ranked_memref_descriptor(c)))
mem_a = ctypes.pointer(ctypes.pointer(rt.get_ranked_memref_descriptor(a)))
mem_b = ctypes.pointer(ctypes.pointer(rt.get_ranked_memref_descriptor(b)))
mem_c = ctypes.pointer(ctypes.pointer(rt.get_ranked_memref_descriptor(c)))
# Allocate a MemRefDescriptor to receive the output tensor.
# The buffer itself is allocated inside the MLIR code generation.
ref_out = rt.make_nd_memref_descriptor(2, ctypes.c_double)()
mem_out = ctypes.pointer(ctypes.pointer(ref_out))
# Allocate a MemRefDescriptor to receive the output tensor.
# The buffer itself is allocated inside the MLIR code generation.
ref_out = rt.make_nd_memref_descriptor(2, ctypes.c_double)()
mem_out = ctypes.pointer(ctypes.pointer(ref_out))
# Invoke the kernel and get numpy output.
# Built-in bufferization uses in-out buffers.
# TODO: replace with inplace comprehensive bufferization.
engine.invoke('main', mem_out, mem_a, mem_b, mem_c)
# Invoke the kernel and get numpy output.
# Built-in bufferization uses in-out buffers.
# TODO: replace with inplace comprehensive bufferization.
engine.invoke("main", mem_out, mem_a, mem_b, mem_c)
# Sanity check on computed result. Only a few elements
# are sampled from the full dense matrix multiplication.
full_matmul = np.matmul(a, b)
expected = np.zeros((8, 8), np.float64)
expected[0, 0] = 1.0 * full_matmul[0, 0]
expected[0, 2] = 2.0 * full_matmul[0, 2]
expected[4, 1] = 3.0 * full_matmul[4, 1]
c = rt.ranked_memref_to_numpy(mem_out[0])
if np.allclose(c, expected):
pass
else:
quit(f'FAILURE')
# Sanity check on computed result. Only a few elements
# are sampled from the full dense matrix multiplication.
full_matmul = np.matmul(a, b)
expected = np.zeros((8, 8), np.float64)
expected[0, 0] = 1.0 * full_matmul[0, 0]
expected[0, 2] = 2.0 * full_matmul[0, 2]
expected[4, 1] = 3.0 * full_matmul[4, 1]
c = rt.ranked_memref_to_numpy(mem_out[0])
if np.allclose(c, expected):
pass
else:
quit(f"FAILURE")
def main():
support_lib = os.getenv('SUPPORT_LIB')
assert support_lib is not None, 'SUPPORT_LIB is undefined'
if not os.path.exists(support_lib):
raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT),
support_lib)
support_lib = os.getenv("SUPPORT_LIB")
assert support_lib is not None, "SUPPORT_LIB is undefined"
if not os.path.exists(support_lib):
raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT), support_lib)
# CHECK-LABEL: TEST: testSDDMMM
print('\nTEST: testSDDMMM')
with ir.Context() as ctx, ir.Location.unknown():
count = 0
# Loop over various ways to compile and annotate the SDDMM kernel with
# a *single* sparse tensor. Note that we deliberately do not exhaustively
# search the full state space to reduce runtime of the test. It is
# straightforward to adapt the code below to explore more combinations.
levels = [[st.DimLevelType.dense, st.DimLevelType.dense],
[st.DimLevelType.dense, st.DimLevelType.compressed],
[st.DimLevelType.compressed, st.DimLevelType.dense],
[st.DimLevelType.compressed, st.DimLevelType.compressed]]
orderings = [
ir.AffineMap.get_permutation([0, 1]),
ir.AffineMap.get_permutation([1, 0])
]
for level in levels:
for ordering in orderings:
for pwidth in [32]:
for iwidth in [32]:
for e in [True]:
attr = st.EncodingAttr.get(level, ordering, None, pwidth,
iwidth)
opt = (f'parallelization-strategy=none')
compiler = sparse_compiler.SparseCompiler(
options=opt, opt_level=0, shared_libs=[support_lib])
build_compile_and_run_SDDMMM(attr, compiler)
count = count + 1
# CHECK: Passed 8 tests
print('Passed ', count, 'tests')
# CHECK-LABEL: TEST: testSDDMMM
print("\nTEST: testSDDMMM")
with ir.Context() as ctx, ir.Location.unknown():
count = 0
# Loop over various ways to compile and annotate the SDDMM kernel with
# a *single* sparse tensor. Note that we deliberately do not exhaustively
# search the full state space to reduce runtime of the test. It is
# straightforward to adapt the code below to explore more combinations.
levels = [
[st.DimLevelType.dense, st.DimLevelType.dense],
[st.DimLevelType.dense, st.DimLevelType.compressed],
[st.DimLevelType.compressed, st.DimLevelType.dense],
[st.DimLevelType.compressed, st.DimLevelType.compressed],
]
orderings = [
ir.AffineMap.get_permutation([0, 1]),
ir.AffineMap.get_permutation([1, 0]),
]
for level in levels:
for ordering in orderings:
for pwidth in [32]:
for iwidth in [32]:
for e in [True]:
attr = st.EncodingAttr.get(
level, ordering, None, pwidth, iwidth
)
opt = f"parallelization-strategy=none"
compiler = sparse_compiler.SparseCompiler(
options=opt, opt_level=0, shared_libs=[support_lib]
)
build_compile_and_run_SDDMMM(attr, compiler)
count = count + 1
# CHECK: Passed 8 tests
print("Passed ", count, "tests")
if __name__ == '__main__':
main()
if __name__ == "__main__":
main()

View File

@@ -18,45 +18,47 @@ _SCRIPT_PATH = os.path.dirname(os.path.abspath(__file__))
sys.path.append(_SCRIPT_PATH)
from tools import sparse_compiler
@dsl.linalg_structured_op
def matmul_dsl(
A=dsl.TensorDef(dsl.T, dsl.S.M, dsl.S.K),
B=dsl.TensorDef(dsl.T, dsl.S.K, dsl.S.N),
C=dsl.TensorDef(dsl.T, dsl.S.M, dsl.S.N, output=True)):
C[dsl.D.m, dsl.D.n] += A[dsl.D.m, dsl.D.k] * B[dsl.D.k, dsl.D.n]
C=dsl.TensorDef(dsl.T, dsl.S.M, dsl.S.N, output=True),
):
C[dsl.D.m, dsl.D.n] += A[dsl.D.m, dsl.D.k] * B[dsl.D.k, dsl.D.n]
def build_SpMM(attr: st.EncodingAttr):
"""Build SpMM kernel.
"""Build SpMM kernel.
This method generates a linalg op for matrix multiplication using
just the Python API. Effectively, a generic linalg op is constructed
that computes C(i,j) += A(i,k) * B(k,j) for annotated matrix A.
"""
module = ir.Module.create()
f64 = ir.F64Type.get()
a = ir.RankedTensorType.get([3, 4], f64, attr)
b = ir.RankedTensorType.get([4, 2], f64)
c = ir.RankedTensorType.get([3, 2], f64)
arguments = [a, b, c]
with ir.InsertionPoint(module.body):
This method generates a linalg op for matrix multiplication using
just the Python API. Effectively, a generic linalg op is constructed
that computes C(i,j) += A(i,k) * B(k,j) for annotated matrix A.
"""
module = ir.Module.create()
f64 = ir.F64Type.get()
a = ir.RankedTensorType.get([3, 4], f64, attr)
b = ir.RankedTensorType.get([4, 2], f64)
c = ir.RankedTensorType.get([3, 2], f64)
arguments = [a, b, c]
with ir.InsertionPoint(module.body):
@func.FuncOp.from_py_func(*arguments)
def spMxM(*args):
return matmul_dsl(args[0], args[1], outs=[args[2]])
@func.FuncOp.from_py_func(*arguments)
def spMxM(*args):
return matmul_dsl(args[0], args[1], outs=[args[2]])
return module
return module
def boilerplate(attr: st.EncodingAttr):
"""Returns boilerplate main method.
"""Returns boilerplate main method.
This method sets up a boilerplate main method that takes three tensors
(a, b, c), converts the first tensor a into a sparse tensor, and then
calls the sparse kernel for matrix multiplication. For convenience,
this part is purely done as string input.
"""
return f"""
This method sets up a boilerplate main method that takes three tensors
(a, b, c), converts the first tensor a into a sparse tensor, and then
calls the sparse kernel for matrix multiplication. For convenience,
this part is purely done as string input.
"""
return f"""
func.func @main(%ad: tensor<3x4xf64>, %b: tensor<4x2xf64>, %c: tensor<3x2xf64>) -> tensor<3x2xf64>
attributes {{ llvm.emit_c_interface }} {{
%a = sparse_tensor.convert %ad : tensor<3x4xf64> to tensor<3x4xf64, {attr}>
@@ -69,82 +71,87 @@ func.func @main(%ad: tensor<3x4xf64>, %b: tensor<4x2xf64>, %c: tensor<3x2xf64>)
def build_compile_and_run_SpMM(attr: st.EncodingAttr, compiler):
# Build.
module = build_SpMM(attr)
func = str(module.operation.regions[0].blocks[0].operations[0].operation)
module = ir.Module.parse(func + boilerplate(attr))
# Build.
module = build_SpMM(attr)
func = str(module.operation.regions[0].blocks[0].operations[0].operation)
module = ir.Module.parse(func + boilerplate(attr))
# Compile.
engine = compiler.compile_and_jit(module)
# Compile.
engine = compiler.compile_and_jit(module)
# Set up numpy input and buffer for output.
a = np.array(
[[1.1, 0.0, 0.0, 1.4], [0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 3.3, 0.0]],
np.float64)
b = np.array([[1.0, 2.0], [4.0, 3.0], [5.0, 6.0], [8.0, 7.0]], np.float64)
c = np.zeros((3, 2), np.float64)
# Set up numpy input and buffer for output.
a = np.array(
[[1.1, 0.0, 0.0, 1.4], [0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 3.3, 0.0]], np.float64
)
b = np.array([[1.0, 2.0], [4.0, 3.0], [5.0, 6.0], [8.0, 7.0]], np.float64)
c = np.zeros((3, 2), np.float64)
mem_a = ctypes.pointer(ctypes.pointer(rt.get_ranked_memref_descriptor(a)))
mem_b = ctypes.pointer(ctypes.pointer(rt.get_ranked_memref_descriptor(b)))
mem_c = ctypes.pointer(ctypes.pointer(rt.get_ranked_memref_descriptor(c)))
# Allocate a MemRefDescriptor to receive the output tensor.
# The buffer itself is allocated inside the MLIR code generation.
ref_out = rt.make_nd_memref_descriptor(2, ctypes.c_double)()
mem_out = ctypes.pointer(ctypes.pointer(ref_out))
mem_a = ctypes.pointer(ctypes.pointer(rt.get_ranked_memref_descriptor(a)))
mem_b = ctypes.pointer(ctypes.pointer(rt.get_ranked_memref_descriptor(b)))
mem_c = ctypes.pointer(ctypes.pointer(rt.get_ranked_memref_descriptor(c)))
# Allocate a MemRefDescriptor to receive the output tensor.
# The buffer itself is allocated inside the MLIR code generation.
ref_out = rt.make_nd_memref_descriptor(2, ctypes.c_double)()
mem_out = ctypes.pointer(ctypes.pointer(ref_out))
# Invoke the kernel and get numpy output.
# Built-in bufferization uses in-out buffers.
# TODO: replace with inplace comprehensive bufferization.
engine.invoke('main', mem_out, mem_a, mem_b, mem_c)
# Invoke the kernel and get numpy output.
# Built-in bufferization uses in-out buffers.
# TODO: replace with inplace comprehensive bufferization.
engine.invoke("main", mem_out, mem_a, mem_b, mem_c)
# Sanity check on computed result.
expected = np.matmul(a, b);
c = rt.ranked_memref_to_numpy(mem_out[0])
if np.allclose(c, expected):
pass
else:
quit(f'FAILURE')
# Sanity check on computed result.
expected = np.matmul(a, b)
c = rt.ranked_memref_to_numpy(mem_out[0])
if np.allclose(c, expected):
pass
else:
quit(f"FAILURE")
def main():
support_lib = os.getenv('SUPPORT_LIB')
assert support_lib is not None, 'SUPPORT_LIB is undefined'
if not os.path.exists(support_lib):
raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT), support_lib)
support_lib = os.getenv("SUPPORT_LIB")
assert support_lib is not None, "SUPPORT_LIB is undefined"
if not os.path.exists(support_lib):
raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT), support_lib)
# CHECK-LABEL: TEST: testSpMM
print('\nTEST: testSpMM')
with ir.Context() as ctx, ir.Location.unknown():
count = 0
# Loop over various ways to compile and annotate the SpMM kernel with
# a *single* sparse tensor. Note that we deliberately do not exhaustively
# search the full state space to reduce runtime of the test. It is
# straightforward to adapt the code below to explore more combinations.
# CHECK-LABEL: TEST: testSpMM
print("\nTEST: testSpMM")
with ir.Context() as ctx, ir.Location.unknown():
count = 0
# Loop over various ways to compile and annotate the SpMM kernel with
# a *single* sparse tensor. Note that we deliberately do not exhaustively
# search the full state space to reduce runtime of the test. It is
# straightforward to adapt the code below to explore more combinations.
vl = 1
e = False
opt = (f'parallelization-strategy=none')
levels = [[st.DimLevelType.dense, st.DimLevelType.dense],
[st.DimLevelType.dense, st.DimLevelType.compressed],
[st.DimLevelType.compressed, st.DimLevelType.dense],
[st.DimLevelType.compressed, st.DimLevelType.compressed]]
orderings = [
ir.AffineMap.get_permutation([0, 1]),
ir.AffineMap.get_permutation([1, 0])
]
bitwidths = [0]
compiler = sparse_compiler.SparseCompiler(
options=opt, opt_level=0, shared_libs=[support_lib])
for level in levels:
for ordering in orderings:
for pwidth in bitwidths:
for iwidth in bitwidths:
attr = st.EncodingAttr.get(level, ordering, None, pwidth, iwidth)
build_compile_and_run_SpMM(attr, compiler)
count = count + 1
# CHECK: Passed 8 tests
print('Passed ', count, 'tests')
vl = 1
e = False
opt = f"parallelization-strategy=none"
levels = [
[st.DimLevelType.dense, st.DimLevelType.dense],
[st.DimLevelType.dense, st.DimLevelType.compressed],
[st.DimLevelType.compressed, st.DimLevelType.dense],
[st.DimLevelType.compressed, st.DimLevelType.compressed],
]
orderings = [
ir.AffineMap.get_permutation([0, 1]),
ir.AffineMap.get_permutation([1, 0]),
]
bitwidths = [0]
compiler = sparse_compiler.SparseCompiler(
options=opt, opt_level=0, shared_libs=[support_lib]
)
for level in levels:
for ordering in orderings:
for pwidth in bitwidths:
for iwidth in bitwidths:
attr = st.EncodingAttr.get(
level, ordering, None, pwidth, iwidth
)
build_compile_and_run_SpMM(attr, compiler)
count = count + 1
# CHECK: Passed 8 tests
print("Passed ", count, "tests")
if __name__ == '__main__':
main()
if __name__ == "__main__":
main()

View File

@@ -57,49 +57,52 @@ func.func @main(%ad: tensor<3x4xf64>, %bd: tensor<3x4xf64>) -> tensor<3x4xf64, #
def _run_test(support_lib, kernel):
"""Compiles, runs and checks results."""
compiler = sparse_compiler.SparseCompiler(
options='', opt_level=2, shared_libs=[support_lib])
module = ir.Module.parse(kernel)
engine = compiler.compile_and_jit(module)
"""Compiles, runs and checks results."""
compiler = sparse_compiler.SparseCompiler(
options="", opt_level=2, shared_libs=[support_lib]
)
module = ir.Module.parse(kernel)
engine = compiler.compile_and_jit(module)
# Set up numpy inputs and buffer for output.
a = np.array(
[[1.1, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 6.6, 0.0]],
np.float64)
b = np.array(
[[1.1, 0.0, 0.0, 2.8], [0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0]],
np.float64)
# Set up numpy inputs and buffer for output.
a = np.array(
[[1.1, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 6.6, 0.0]], np.float64
)
b = np.array(
[[1.1, 0.0, 0.0, 2.8], [0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0]], np.float64
)
mem_a = ctypes.pointer(ctypes.pointer(rt.get_ranked_memref_descriptor(a)))
mem_b = ctypes.pointer(ctypes.pointer(rt.get_ranked_memref_descriptor(b)))
mem_a = ctypes.pointer(ctypes.pointer(rt.get_ranked_memref_descriptor(a)))
mem_b = ctypes.pointer(ctypes.pointer(rt.get_ranked_memref_descriptor(b)))
# The sparse tensor output is a pointer to pointer of char.
out = ctypes.c_char(0)
mem_out = ctypes.pointer(ctypes.pointer(out))
# The sparse tensor output is a pointer to pointer of char.
out = ctypes.c_char(0)
mem_out = ctypes.pointer(ctypes.pointer(out))
# Invoke the kernel.
engine.invoke('main', mem_a, mem_b, mem_out)
# Invoke the kernel.
engine.invoke("main", mem_a, mem_b, mem_out)
# Retrieve and check the result.
rank, nse, shape, values, indices = test_tools.sparse_tensor_to_coo_tensor(
support_lib, mem_out[0], np.float64)
# Retrieve and check the result.
rank, nse, shape, values, indices = test_tools.sparse_tensor_to_coo_tensor(
support_lib, mem_out[0], np.float64
)
# CHECK: PASSED
if np.allclose(values, [2.2, 2.8, 6.6]) and np.allclose(
indices, [[0, 0], [0, 3], [2, 2]]):
print('PASSED')
else:
quit('FAILURE')
# CHECK: PASSED
if np.allclose(values, [2.2, 2.8, 6.6]) and np.allclose(
indices, [[0, 0], [0, 3], [2, 2]]
):
print("PASSED")
else:
quit("FAILURE")
def test_elementwise_add():
# Obtain path to runtime support library.
support_lib = os.getenv('SUPPORT_LIB')
assert support_lib is not None, 'SUPPORT_LIB is undefined'
assert os.path.exists(support_lib), f'{support_lib} does not exist'
with ir.Context() as ctx, ir.Location.unknown():
_run_test(support_lib, _KERNEL_STR)
# Obtain path to runtime support library.
support_lib = os.getenv("SUPPORT_LIB")
assert support_lib is not None, "SUPPORT_LIB is undefined"
assert os.path.exists(support_lib), f"{support_lib} does not exist"
with ir.Context() as ctx, ir.Location.unknown():
_run_test(support_lib, _KERNEL_STR)
test_elementwise_add()

View File

@@ -18,8 +18,8 @@ from tools import sparse_compiler
# TODO: move more into actual IR building.
def boilerplate(attr: st.EncodingAttr):
"""Returns boilerplate main method."""
return f"""
"""Returns boilerplate main method."""
return f"""
func.func @main(%p : !llvm.ptr<i8>) -> () attributes {{ llvm.emit_c_interface }} {{
%d = arith.constant sparse<[[0, 0], [1, 1], [0, 9], [9, 0], [4, 4]],
[1.0, 2.0, 3.0, 4.0, 5.0]> : tensor<10x10xf64>
@@ -31,13 +31,13 @@ func.func @main(%p : !llvm.ptr<i8>) -> () attributes {{ llvm.emit_c_interface }}
def expected():
"""Returns expected contents of output.
"""Returns expected contents of output.
Regardless of the dimension ordering, compression, and bitwidths that are
used in the sparse tensor, the output is always lexicographically sorted
by natural index order.
"""
return f"""; extended FROSTT format
Regardless of the dimension ordering, compression, and bitwidths that are
used in the sparse tensor, the output is always lexicographically sorted
by natural index order.
"""
return f"""; extended FROSTT format
2 5
10 10
1 1 1
@@ -49,53 +49,55 @@ def expected():
def build_compile_and_run_output(attr: st.EncodingAttr, compiler):
# Build and Compile.
module = ir.Module.parse(boilerplate(attr))
engine = compiler.compile_and_jit(module)
# Build and Compile.
module = ir.Module.parse(boilerplate(attr))
engine = compiler.compile_and_jit(module)
# Invoke the kernel and compare output.
with tempfile.TemporaryDirectory() as test_dir:
out = os.path.join(test_dir, 'out.tns')
buf = out.encode('utf-8')
mem_a = ctypes.pointer(ctypes.pointer(ctypes.create_string_buffer(buf)))
engine.invoke('main', mem_a)
# Invoke the kernel and compare output.
with tempfile.TemporaryDirectory() as test_dir:
out = os.path.join(test_dir, "out.tns")
buf = out.encode("utf-8")
mem_a = ctypes.pointer(ctypes.pointer(ctypes.create_string_buffer(buf)))
engine.invoke("main", mem_a)
actual = open(out).read()
if actual != expected():
quit('FAILURE')
actual = open(out).read()
if actual != expected():
quit("FAILURE")
def main():
support_lib = os.getenv('SUPPORT_LIB')
assert support_lib is not None, 'SUPPORT_LIB is undefined'
if not os.path.exists(support_lib):
raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT),
support_lib)
support_lib = os.getenv("SUPPORT_LIB")
assert support_lib is not None, "SUPPORT_LIB is undefined"
if not os.path.exists(support_lib):
raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT), support_lib)
# CHECK-LABEL: TEST: test_output
print('\nTEST: test_output')
count = 0
with ir.Context() as ctx, ir.Location.unknown():
# Loop over various sparse types: CSR, DCSR, CSC, DCSC.
levels = [[st.DimLevelType.dense, st.DimLevelType.compressed],
[st.DimLevelType.compressed, st.DimLevelType.compressed]]
orderings = [
ir.AffineMap.get_permutation([0, 1]),
ir.AffineMap.get_permutation([1, 0])
]
bitwidths = [8, 16, 32, 64]
compiler = sparse_compiler.SparseCompiler(
options='', opt_level=2, shared_libs=[support_lib])
for level in levels:
for ordering in orderings:
for bwidth in bitwidths:
attr = st.EncodingAttr.get(level, ordering, None, bwidth, bwidth)
build_compile_and_run_output(attr, compiler)
count = count + 1
# CHECK-LABEL: TEST: test_output
print("\nTEST: test_output")
count = 0
with ir.Context() as ctx, ir.Location.unknown():
# Loop over various sparse types: CSR, DCSR, CSC, DCSC.
levels = [
[st.DimLevelType.dense, st.DimLevelType.compressed],
[st.DimLevelType.compressed, st.DimLevelType.compressed],
]
orderings = [
ir.AffineMap.get_permutation([0, 1]),
ir.AffineMap.get_permutation([1, 0]),
]
bitwidths = [8, 16, 32, 64]
compiler = sparse_compiler.SparseCompiler(
options="", opt_level=2, shared_libs=[support_lib]
)
for level in levels:
for ordering in orderings:
for bwidth in bitwidths:
attr = st.EncodingAttr.get(level, ordering, None, bwidth, bwidth)
build_compile_and_run_output(attr, compiler)
count = count + 1
# CHECK: Passed 16 tests
print('Passed', count, 'tests')
# CHECK: Passed 16 tests
print("Passed", count, "tests")
if __name__ == '__main__':
main()
if __name__ == "__main__":
main()

View File

@@ -28,216 +28,241 @@ from tools import sparse_compiler
# TODO: move this boilerplate to its own module, so it can be used by
# other tests and programs.
class TypeConverter:
"""Converter between NumPy types and MLIR types."""
"""Converter between NumPy types and MLIR types."""
def __init__(self, context: ir.Context):
# Note 1: these are numpy "scalar types" (i.e., the values of
# np.sctypeDict) not numpy "dtypes" (i.e., the np.dtype class).
#
# Note 2: we must construct the MLIR types in the same context as the
# types that'll be passed to irtype_to_sctype() or irtype_to_dtype();
# otherwise, those methods will raise a KeyError.
types_list = [
(np.float64, ir.F64Type.get(context=context)),
(np.float32, ir.F32Type.get(context=context)),
(np.int64, ir.IntegerType.get_signless(64, context=context)),
(np.int32, ir.IntegerType.get_signless(32, context=context)),
(np.int16, ir.IntegerType.get_signless(16, context=context)),
(np.int8, ir.IntegerType.get_signless(8, context=context)),
]
self._sc2ir = dict(types_list)
self._ir2sc = dict(( (ir,sc) for sc,ir in types_list ))
def __init__(self, context: ir.Context):
# Note 1: these are numpy "scalar types" (i.e., the values of
# np.sctypeDict) not numpy "dtypes" (i.e., the np.dtype class).
#
# Note 2: we must construct the MLIR types in the same context as the
# types that'll be passed to irtype_to_sctype() or irtype_to_dtype();
# otherwise, those methods will raise a KeyError.
types_list = [
(np.float64, ir.F64Type.get(context=context)),
(np.float32, ir.F32Type.get(context=context)),
(np.int64, ir.IntegerType.get_signless(64, context=context)),
(np.int32, ir.IntegerType.get_signless(32, context=context)),
(np.int16, ir.IntegerType.get_signless(16, context=context)),
(np.int8, ir.IntegerType.get_signless(8, context=context)),
]
self._sc2ir = dict(types_list)
self._ir2sc = dict(((ir, sc) for sc, ir in types_list))
def dtype_to_irtype(self, dtype: np.dtype) -> ir.Type:
"""Returns the MLIR equivalent of a NumPy dtype."""
try:
return self.sctype_to_irtype(dtype.type)
except KeyError as e:
raise KeyError(f'Unknown dtype: {dtype}') from e
def dtype_to_irtype(self, dtype: np.dtype) -> ir.Type:
"""Returns the MLIR equivalent of a NumPy dtype."""
try:
return self.sctype_to_irtype(dtype.type)
except KeyError as e:
raise KeyError(f"Unknown dtype: {dtype}") from e
def sctype_to_irtype(self, sctype) -> ir.Type:
"""Returns the MLIR equivalent of a NumPy scalar type."""
if sctype in self._sc2ir:
return self._sc2ir[sctype]
else:
raise KeyError(f'Unknown sctype: {sctype}')
def sctype_to_irtype(self, sctype) -> ir.Type:
"""Returns the MLIR equivalent of a NumPy scalar type."""
if sctype in self._sc2ir:
return self._sc2ir[sctype]
else:
raise KeyError(f"Unknown sctype: {sctype}")
def irtype_to_dtype(self, tp: ir.Type) -> np.dtype:
"""Returns the NumPy dtype equivalent of an MLIR type."""
return np.dtype(self.irtype_to_sctype(tp))
def irtype_to_dtype(self, tp: ir.Type) -> np.dtype:
"""Returns the NumPy dtype equivalent of an MLIR type."""
return np.dtype(self.irtype_to_sctype(tp))
def irtype_to_sctype(self, tp: ir.Type):
"""Returns the NumPy scalar-type equivalent of an MLIR type."""
if tp in self._ir2sc:
return self._ir2sc[tp]
else:
raise KeyError(f'Unknown ir.Type: {tp}')
def irtype_to_sctype(self, tp: ir.Type):
"""Returns the NumPy scalar-type equivalent of an MLIR type."""
if tp in self._ir2sc:
return self._ir2sc[tp]
else:
raise KeyError(f"Unknown ir.Type: {tp}")
def get_RankedTensorType_of_nparray(
self, nparray: np.ndarray
) -> ir.RankedTensorType:
"""Returns the ir.RankedTensorType of a NumPy array. Note that NumPy
arrays can only be converted to/from dense tensors, not sparse tensors."""
# TODO: handle strides as well?
return ir.RankedTensorType.get(
nparray.shape, self.dtype_to_irtype(nparray.dtype)
)
def get_RankedTensorType_of_nparray(self, nparray: np.ndarray) -> ir.RankedTensorType:
"""Returns the ir.RankedTensorType of a NumPy array. Note that NumPy
arrays can only be converted to/from dense tensors, not sparse tensors."""
# TODO: handle strides as well?
return ir.RankedTensorType.get(nparray.shape,
self.dtype_to_irtype(nparray.dtype))
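A hedged usage sketch of the converter above (the dtype and shape are illustrative; the converter must be used inside the same ir.Context it was constructed with):
import numpy as np
from mlir import ir
with ir.Context() as ctx, ir.Location.unknown():
    tyconv = TypeConverter(ctx)
    f64 = tyconv.dtype_to_irtype(np.dtype(np.float64))   # ir.F64Type
    assert tyconv.irtype_to_dtype(f64) == np.dtype(np.float64)
    tt = tyconv.get_RankedTensorType_of_nparray(np.zeros((3, 4), np.float64))
    # str(tt) == "tensor<3x4xf64>"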
# ===----------------------------------------------------------------------=== #
class StressTest:
def __init__(self, tyconv: TypeConverter):
self._tyconv = tyconv
self._roundtripTp = None
self._module = None
self._engine = None
def __init__(self, tyconv: TypeConverter):
self._tyconv = tyconv
self._roundtripTp = None
self._module = None
self._engine = None
def _assertEqualsRoundtripTp(self, tp: ir.RankedTensorType):
assert self._roundtripTp is not None, \
'StressTest: uninitialized roundtrip type'
if tp != self._roundtripTp:
raise AssertionError(
f"Type is not equal to the roundtrip type.\n"
f"\tExpected: {self._roundtripTp}\n"
f"\tFound: {tp}\n")
def _assertEqualsRoundtripTp(self, tp: ir.RankedTensorType):
assert self._roundtripTp is not None, "StressTest: uninitialized roundtrip type"
if tp != self._roundtripTp:
raise AssertionError(
f"Type is not equal to the roundtrip type.\n"
f"\tExpected: {self._roundtripTp}\n"
f"\tFound: {tp}\n"
)
def build(self, types: List[ir.Type]):
"""Builds the ir.Module. The module has only the @main function,
which will convert the input through the list of types and then back
to the initial type. The roundtrip type must be a dense tensor."""
assert self._module is None, 'StressTest: must not call build() repeatedly'
self._module = ir.Module.create()
with ir.InsertionPoint(self._module.body):
tp0 = types.pop(0)
self._roundtripTp = tp0
# TODO: assert dense? assert element type is recognised by the TypeConverter?
types.append(tp0)
funcTp = ir.FunctionType.get(inputs=[tp0], results=[tp0])
funcOp = func.FuncOp(name='main', type=funcTp)
funcOp.attributes['llvm.emit_c_interface'] = ir.UnitAttr.get()
with ir.InsertionPoint(funcOp.add_entry_block()):
arg0 = funcOp.entry_block.arguments[0]
self._assertEqualsRoundtripTp(arg0.type)
v = st.ConvertOp(types.pop(0), arg0)
for tp in types:
w = st.ConvertOp(tp, v)
# Release intermediate tensors before they fall out of scope.
bufferization.DeallocTensorOp(v.result)
v = w
self._assertEqualsRoundtripTp(v.result.type)
func.ReturnOp(v)
return self
def build(self, types: List[ir.Type]):
"""Builds the ir.Module. The module has only the @main function,
which will convert the input through the list of types and then back
to the initial type. The roundtrip type must be a dense tensor."""
assert self._module is None, "StressTest: must not call build() repeatedly"
self._module = ir.Module.create()
with ir.InsertionPoint(self._module.body):
tp0 = types.pop(0)
self._roundtripTp = tp0
# TODO: assert dense? assert element type is recognised by the TypeConverter?
types.append(tp0)
funcTp = ir.FunctionType.get(inputs=[tp0], results=[tp0])
funcOp = func.FuncOp(name="main", type=funcTp)
funcOp.attributes["llvm.emit_c_interface"] = ir.UnitAttr.get()
with ir.InsertionPoint(funcOp.add_entry_block()):
arg0 = funcOp.entry_block.arguments[0]
self._assertEqualsRoundtripTp(arg0.type)
v = st.ConvertOp(types.pop(0), arg0)
for tp in types:
w = st.ConvertOp(tp, v)
# Release intermediate tensors before they fall out of scope.
bufferization.DeallocTensorOp(v.result)
v = w
self._assertEqualsRoundtripTp(v.result.type)
func.ReturnOp(v)
return self
def writeTo(self, filename):
"""Write the ir.Module to the given file. If the file already exists,
then an error is raised. If the filename is None, this is a no-op."""
assert self._module is not None, \
'StressTest: must call build() before writeTo()'
if filename is None:
# Silent no-op, for convenience.
return self
if os.path.exists(filename):
raise FileExistsError(errno.EEXIST, os.strerror(errno.EEXIST), filename)
with open(filename, 'w') as f:
f.write(str(self._module))
return self
def writeTo(self, filename):
"""Write the ir.Module to the given file. If the file already exists,
then an error is raised. If the filename is None, this is a no-op."""
assert (
self._module is not None
), "StressTest: must call build() before writeTo()"
if filename is None:
# Silent no-op, for convenience.
return self
if os.path.exists(filename):
raise FileExistsError(errno.EEXIST, os.strerror(errno.EEXIST), filename)
with open(filename, "w") as f:
f.write(str(self._module))
return self
def compile(self, compiler):
"""Compile the ir.Module."""
assert self._module is not None, \
'StressTest: must call build() before compile()'
assert self._engine is None, \
'StressTest: must not call compile() repeatedly'
self._engine = compiler.compile_and_jit(self._module)
return self
def compile(self, compiler):
"""Compile the ir.Module."""
assert (
self._module is not None
), "StressTest: must call build() before compile()"
assert self._engine is None, "StressTest: must not call compile() repeatedly"
self._engine = compiler.compile_and_jit(self._module)
return self
def run(self, np_arg0: np.ndarray) -> np.ndarray:
"""Runs the test on the given numpy array, and returns the resulting
numpy array."""
assert self._engine is not None, "StressTest: must call compile() before run()"
self._assertEqualsRoundtripTp(
self._tyconv.get_RankedTensorType_of_nparray(np_arg0)
)
np_out = np.zeros(np_arg0.shape, dtype=np_arg0.dtype)
self._assertEqualsRoundtripTp(
self._tyconv.get_RankedTensorType_of_nparray(np_out)
)
mem_arg0 = ctypes.pointer(
ctypes.pointer(rt.get_ranked_memref_descriptor(np_arg0))
)
mem_out = ctypes.pointer(
ctypes.pointer(rt.get_ranked_memref_descriptor(np_out))
)
self._engine.invoke("main", mem_out, mem_arg0)
return rt.ranked_memref_to_numpy(mem_out[0])
def run(self, np_arg0: np.ndarray) -> np.ndarray:
"""Runs the test on the given numpy array, and returns the resulting
numpy array."""
assert self._engine is not None, \
'StressTest: must call compile() before run()'
self._assertEqualsRoundtripTp(
self._tyconv.get_RankedTensorType_of_nparray(np_arg0))
np_out = np.zeros(np_arg0.shape, dtype=np_arg0.dtype)
self._assertEqualsRoundtripTp(
self._tyconv.get_RankedTensorType_of_nparray(np_out))
mem_arg0 = ctypes.pointer(ctypes.pointer(rt.get_ranked_memref_descriptor(np_arg0)))
mem_out = ctypes.pointer(ctypes.pointer(rt.get_ranked_memref_descriptor(np_out)))
self._engine.invoke('main', mem_out, mem_arg0)
return rt.ranked_memref_to_numpy(mem_out[0])
# ===----------------------------------------------------------------------=== #
def main():
"""
USAGE: python3 test_stress.py [raw_module.mlir [compiled_module.mlir]]
"""
USAGE: python3 test_stress.py [raw_module.mlir [compiled_module.mlir]]
The environment variable SUPPORT_LIB must be set to point to the
libmlir_c_runner_utils shared library. There are two optional
arguments, for debugging purposes. The first argument specifies where
to write out the raw/generated ir.Module. The second argument specifies
where to write out the compiled version of that ir.Module.
"""
support_lib = os.getenv('SUPPORT_LIB')
assert support_lib is not None, 'SUPPORT_LIB is undefined'
if not os.path.exists(support_lib):
raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT), support_lib)
The environment variable SUPPORT_LIB must be set to point to the
libmlir_c_runner_utils shared library. There are two optional
arguments, for debugging purposes. The first argument specifies where
to write out the raw/generated ir.Module. The second argument specifies
where to write out the compiled version of that ir.Module.
"""
support_lib = os.getenv("SUPPORT_LIB")
assert support_lib is not None, "SUPPORT_LIB is undefined"
if not os.path.exists(support_lib):
raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT), support_lib)
# CHECK-LABEL: TEST: test_stress
print("\nTEST: test_stress")
with ir.Context() as ctx, ir.Location.unknown():
# Disable direct sparse2sparse conversion, because it doubles the time!
# TODO: While direct s2s is far too slow for per-commit testing,
# we should have some framework ensure that we run this test with
# `s2s=0` on a regular basis, to ensure that it does continue to work.
# TODO: be sure to test s2s=0 together with singletons.
s2s = 1
sparsification_options = (
f'parallelization-strategy=none '
f's2s-strategy={s2s}')
compiler = sparse_compiler.SparseCompiler(
options=sparsification_options, opt_level=0, shared_libs=[support_lib])
f64 = ir.F64Type.get()
# Be careful about increasing this because
# len(types) = 1 + len(level_choices)^rank * rank! * len(bitwidths)^2
shape = range(2, 3)
rank = len(shape)
# All combinations.
# TODO: add singleton here too; which requires updating how `np_arg0`
# is initialized below.
levels = list(itertools.product(*itertools.repeat(
[st.DimLevelType.dense, st.DimLevelType.compressed], rank)))
# All permutations.
orderings = list(map(ir.AffineMap.get_permutation,
itertools.permutations(range(rank))))
bitwidths = [0]
# The first type must be a dense tensor for numpy conversion to work.
types = [ir.RankedTensorType.get(shape, f64)]
for level in levels:
for ordering in orderings:
for pwidth in bitwidths:
for iwidth in bitwidths:
attr = st.EncodingAttr.get(level, ordering, None, pwidth, iwidth)
types.append(ir.RankedTensorType.get(shape, f64, attr))
#
# For exhaustiveness we should have one or more StressTest, such
# that their paths cover all 2*n*(n-1) directed pairwise combinations
# of the `types` set. However, since n is already superexponential,
# such exhaustiveness would be prohibitive for a test that runs on
# every commit. So for now we'll just pick one particular path that
# at least hits all n elements of the `types` set.
#
tyconv = TypeConverter(ctx)
size = 1
for d in shape:
size *= d
np_arg0 = np.arange(size, dtype=tyconv.irtype_to_dtype(f64)).reshape(*shape)
np_out = (
StressTest(tyconv).build(types).writeTo(
sys.argv[1] if len(sys.argv) > 1 else None).compile(compiler)
.writeTo(sys.argv[2] if len(sys.argv) > 2 else None).run(np_arg0))
# CHECK: Passed
if np.allclose(np_out, np_arg0):
print('Passed')
else:
sys.exit('FAILURE')
# CHECK-LABEL: TEST: test_stress
print("\nTEST: test_stress")
with ir.Context() as ctx, ir.Location.unknown():
# Disable direct sparse2sparse conversion, because it doubles the time!
# TODO: While direct s2s is far too slow for per-commit testing,
# we should have some framework ensure that we run this test with
# `s2s=0` on a regular basis, to ensure that it does continue to work.
# TODO: be sure to test s2s=0 together with singletons.
s2s = 1
sparsification_options = f"parallelization-strategy=none " f"s2s-strategy={s2s}"
compiler = sparse_compiler.SparseCompiler(
options=sparsification_options, opt_level=0, shared_libs=[support_lib]
)
f64 = ir.F64Type.get()
# Be careful about increasing this because
# len(types) = 1 + len(level_choices)^rank * rank! * len(bitwidths)^2
shape = range(2, 3)
rank = len(shape)
# All combinations.
# TODO: add singleton here too; which requires updating how `np_arg0`
# is initialized below.
levels = list(
itertools.product(
*itertools.repeat(
[st.DimLevelType.dense, st.DimLevelType.compressed], rank
)
)
)
# All permutations.
orderings = list(
map(ir.AffineMap.get_permutation, itertools.permutations(range(rank)))
)
bitwidths = [0]
# The first type must be a dense tensor for numpy conversion to work.
types = [ir.RankedTensorType.get(shape, f64)]
for level in levels:
for ordering in orderings:
for pwidth in bitwidths:
for iwidth in bitwidths:
attr = st.EncodingAttr.get(
level, ordering, None, pwidth, iwidth
)
types.append(ir.RankedTensorType.get(shape, f64, attr))
#
# For exhaustiveness we should have one or more StressTest, such
# that their paths cover all 2*n*(n-1) directed pairwise combinations
# of the `types` set. However, since n is already superexponential,
# such exhaustiveness would be prohibitive for a test that runs on
# every commit. So for now we'll just pick one particular path that
# at least hits all n elements of the `types` set.
#
tyconv = TypeConverter(ctx)
size = 1
for d in shape:
size *= d
np_arg0 = np.arange(size, dtype=tyconv.irtype_to_dtype(f64)).reshape(*shape)
np_out = (
StressTest(tyconv)
.build(types)
.writeTo(sys.argv[1] if len(sys.argv) > 1 else None)
.compile(compiler)
.writeTo(sys.argv[2] if len(sys.argv) > 2 else None)
.run(np_arg0)
)
# CHECK: Passed
if np.allclose(np_out, np_arg0):
print("Passed")
else:
sys.exit("FAILURE")
if __name__ == '__main__':
main()
if __name__ == "__main__":
main()

View File

@@ -11,65 +11,71 @@ import numpy as np
@functools.lru_cache()
def _get_c_shared_lib(lib_name: str):
"""Loads and returns the requested C shared library.
"""Loads and returns the requested C shared library.
Args:
lib_name: A string representing the C shared library.
Args:
lib_name: A string representing the C shared library.
Returns:
The C shared library.
Returns:
The C shared library.
Raises:
OSError: If there is any problem in loading the shared library.
ValueError: If the shared library doesn't contain the needed routine.
"""
# This raises OSError exception if there is any problem in loading the shared
# library.
c_lib = ctypes.CDLL(lib_name)
Raises:
OSError: If there is any problem in loading the shared library.
ValueError: If the shared library doesn't contain the needed routine.
"""
# This raises OSError exception if there is any problem in loading the shared
# library.
c_lib = ctypes.CDLL(lib_name)
try:
c_lib.convertFromMLIRSparseTensorF64.restype = ctypes.c_void_p
except Exception as e:
raise ValueError('Missing function convertFromMLIRSparseTensorF64 from '
f'the C shared library: {e} ') from e
try:
c_lib.convertFromMLIRSparseTensorF64.restype = ctypes.c_void_p
except Exception as e:
raise ValueError(
"Missing function convertFromMLIRSparseTensorF64 from "
f"the C shared library: {e} "
) from e
return c_lib
return c_lib
def sparse_tensor_to_coo_tensor(support_lib, sparse, dtype):
"""Converts a sparse tensor to COO-flavored format.
"""Converts a sparse tensor to COO-flavored format.
Args:
support_lib: A string for the supporting C shared library.
sparse: A ctypes.pointer to the sparse tensor descriptor.
dtype: The numpy data type for the tensor elements.
Args:
support_lib: A string for the supporting C shared library.
sparse: A ctypes.pointer to the sparse tensor descriptor.
dtype: The numpy data type for the tensor elements.
Returns:
A tuple that contains the following values:
rank: An integer for the rank of the tensor.
nse: An integer for the number of non-zero values in the tensor.
shape: A 1D numpy array of integers, for the shape of the tensor.
values: A 1D numpy array, for the non-zero values in the tensor.
indices: A 2D numpy array of integers, representing the indices for the
non-zero values in the tensor.
Returns:
A tuple that contains the following values:
rank: An integer for the rank of the tensor.
nse: An integer for the number of non-zero values in the tensor.
shape: A 1D numpy array of integers, for the shape of the tensor.
values: A 1D numpy array, for the non-zero values in the tensor.
indices: A 2D numpy array of integers, representing the indices for the
non-zero values in the tensor.
Raises:
OSError: If there is any problem in loading the shared library.
ValueError: If the shared library doesn't contain the needed routine.
"""
c_lib = _get_c_shared_lib(support_lib)
Raises:
OSError: If there is any problem in loading the shared library.
ValueError: If the shared library doesn't contain the needed routine.
"""
c_lib = _get_c_shared_lib(support_lib)
rank = ctypes.c_ulonglong(0)
nse = ctypes.c_ulonglong(0)
shape = ctypes.POINTER(ctypes.c_ulonglong)()
values = ctypes.POINTER(np.ctypeslib.as_ctypes_type(dtype))()
indices = ctypes.POINTER(ctypes.c_ulonglong)()
c_lib.convertFromMLIRSparseTensorF64(sparse, ctypes.byref(rank),
ctypes.byref(nse), ctypes.byref(shape),
ctypes.byref(values),
ctypes.byref(indices))
# Convert the returned values to the corresponding numpy types.
shape = np.ctypeslib.as_array(shape, shape=[rank.value])
values = np.ctypeslib.as_array(values, shape=[nse.value])
indices = np.ctypeslib.as_array(indices, shape=[nse.value, rank.value])
return rank, nse, shape, values, indices
rank = ctypes.c_ulonglong(0)
nse = ctypes.c_ulonglong(0)
shape = ctypes.POINTER(ctypes.c_ulonglong)()
values = ctypes.POINTER(np.ctypeslib.as_ctypes_type(dtype))()
indices = ctypes.POINTER(ctypes.c_ulonglong)()
c_lib.convertFromMLIRSparseTensorF64(
sparse,
ctypes.byref(rank),
ctypes.byref(nse),
ctypes.byref(shape),
ctypes.byref(values),
ctypes.byref(indices),
)
# Convert the returned values to the corresponding numpy types.
shape = np.ctypeslib.as_array(shape, shape=[rank.value])
values = np.ctypeslib.as_array(values, shape=[nse.value])
indices = np.ctypeslib.as_array(indices, shape=[nse.value, rank.value])
return rank, nse, shape, values, indices
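The converter above returns its results through ctypes output parameters and then views the C buffers as numpy arrays. A minimal, self-contained sketch of the np.ctypeslib.as_array mechanism it relies on (not part of this commit and independent of the MLIR runtime; the small buffer is only illustrative):

import ctypes
import numpy as np

# A tiny C array of unsigned 64-bit integers, mirroring the `shape` buffer
# filled in by convertFromMLIRSparseTensorF64.
buf = (ctypes.c_ulonglong * 3)(2, 3, 4)
ptr = ctypes.cast(buf, ctypes.POINTER(ctypes.c_ulonglong))
# View the C memory as a numpy array without copying.
shape = np.ctypeslib.as_array(ptr, shape=[3])
print(shape)  # [2 3 4]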

View File

@@ -9,30 +9,31 @@ from mlir import ir
from mlir import passmanager
from typing import Sequence
class SparseCompiler:
"""Sparse compiler class for compiling and building MLIR modules."""
"""Sparse compiler class for compiling and building MLIR modules."""
def __init__(self, options: str, opt_level: int, shared_libs: Sequence[str]):
pipeline = f'builtin.module(sparse-compiler{{{options} reassociate-fp-reductions=1 enable-index-optimizations=1}})'
self.pipeline = pipeline
self.opt_level = opt_level
self.shared_libs = shared_libs
def __init__(self, options: str, opt_level: int, shared_libs: Sequence[str]):
pipeline = f"builtin.module(sparse-compiler{{{options} reassociate-fp-reductions=1 enable-index-optimizations=1}})"
self.pipeline = pipeline
self.opt_level = opt_level
self.shared_libs = shared_libs
def __call__(self, module: ir.Module):
"""Convenience application method."""
self.compile(module)
def __call__(self, module: ir.Module):
"""Convenience application method."""
self.compile(module)
def compile(self, module: ir.Module):
"""Compiles the module by invoking the sparse copmiler pipeline."""
passmanager.PassManager.parse(self.pipeline).run(module.operation)
def compile(self, module: ir.Module):
"""Compiles the module by invoking the sparse copmiler pipeline."""
passmanager.PassManager.parse(self.pipeline).run(module.operation)
def jit(self, module: ir.Module) -> execution_engine.ExecutionEngine:
"""Wraps the module in a JIT execution engine."""
return execution_engine.ExecutionEngine(
module, opt_level=self.opt_level, shared_libs=self.shared_libs)
def jit(self, module: ir.Module) -> execution_engine.ExecutionEngine:
"""Wraps the module in a JIT execution engine."""
return execution_engine.ExecutionEngine(
module, opt_level=self.opt_level, shared_libs=self.shared_libs
)
def compile_and_jit(self,
module: ir.Module) -> execution_engine.ExecutionEngine:
"""Compiles and jits the module."""
self.compile(module)
return self.jit(module)
def compile_and_jit(self, module: ir.Module) -> execution_engine.ExecutionEngine:
"""Compiles and jits the module."""
self.compile(module)
return self.jit(module)
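As a usage sketch (not part of this commit; it assumes the MLIR Python bindings were built with the sparse runtime, and the shared-library path is a placeholder), the class is driven roughly as follows:

from mlir import ir

compiler = SparseCompiler(
    options="", opt_level=2, shared_libs=["libmlir_c_runner_utils.so"]
)
with ir.Context(), ir.Location.unknown():
    module = ir.Module.parse("module {}")
    engine = compiler.compile_and_jit(module)  # run the pipeline, then JIT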

View File

@@ -1,5 +1,5 @@
# Disable ASAN's leak detection for python taco tests.
config.environment['ASAN_OPTIONS'] = 'detect_leaks=0'
config.environment["ASAN_OPTIONS"] = "detect_leaks=0"
# Only run when python bindings are enabled.
if not config.enable_bindings_python:
config.unsupported = True
config.unsupported = True

View File

@@ -46,10 +46,10 @@ A[i, j] = B[i, k, l] * D[l, j] * C[k, j]
# Perform the MTTKRP computation and write the result to file.
with tempfile.TemporaryDirectory() as test_dir:
golden_file = os.path.join(_SCRIPT_PATH, "data/gold_A.tns")
out_file = os.path.join(test_dir, "A.tns")
pt.write(out_file, A)
#
# CHECK: Compare result True
#
print(f"Compare result {utils.compare_sparse_tns(golden_file, out_file)}")
golden_file = os.path.join(_SCRIPT_PATH, "data/gold_A.tns")
out_file = os.path.join(test_dir, "A.tns")
pt.write(out_file, A)
#
# CHECK: Compare result True
#
print(f"Compare result {utils.compare_sparse_tns(golden_file, out_file)}")

View File

@@ -46,13 +46,13 @@ expected = """; extended FROSTT format
# Force evaluation of the kernels by writing out X and Y.
with tempfile.TemporaryDirectory() as test_dir:
x_file = os.path.join(test_dir, "X.tns")
y_file = os.path.join(test_dir, "Y.tns")
pt.write(x_file, X)
pt.write(y_file, Y)
#
# CHECK: Compare result True True
#
x_data = utils.file_as_string(x_file)
y_data = utils.file_as_string(y_file)
print(f"Compare result {x_data == expected} {y_data == expected}")
x_file = os.path.join(test_dir, "X.tns")
y_file = os.path.join(test_dir, "Y.tns")
pt.write(x_file, X)
pt.write(y_file, Y)
#
# CHECK: Compare result True True
#
x_data = utils.file_as_string(x_file)
y_data = utils.file_as_string(y_file)
print(f"Compare result {x_data == expected} {y_data == expected}")

View File

@@ -26,10 +26,10 @@ C[i, j] = A[i, k] * B[k, j]
# Force evaluation of the kernel by writing out C.
with tempfile.TemporaryDirectory() as test_dir:
golden_file = os.path.join(_SCRIPT_PATH, "data/gold_C.tns")
out_file = os.path.join(test_dir, "C.tns")
pt.write(out_file, C)
#
# CHECK: Compare result True
#
print(f"Compare result {utils.compare_sparse_tns(golden_file, out_file)}")
golden_file = os.path.join(_SCRIPT_PATH, "data/gold_C.tns")
out_file = os.path.join(test_dir, "C.tns")
pt.write(out_file, C)
#
# CHECK: Compare result True
#
print(f"Compare result {utils.compare_sparse_tns(golden_file, out_file)}")

View File

@@ -47,10 +47,10 @@ y[i] = A[i, j] * x[j] + z[i]
# Perform the SpMV computation and write the result to file
with tempfile.TemporaryDirectory() as test_dir:
golden_file = os.path.join(_SCRIPT_PATH, "data/gold_y.tns")
out_file = os.path.join(test_dir, "y.tns")
pt.write(out_file, y)
#
# CHECK: Compare result True
#
print(f"Compare result {utils.compare_sparse_tns(golden_file, out_file)}")
golden_file = os.path.join(_SCRIPT_PATH, "data/gold_y.tns")
out_file = os.path.join(test_dir, "y.tns")
pt.write(out_file, y)
#
# CHECK: Compare result True
#
print(f"Compare result {utils.compare_sparse_tns(golden_file, out_file)}")

View File

@@ -18,11 +18,10 @@ i, j, k, l, m = pt.get_index_vars(5)
alpha = pt.tensor(42.0)
# Set up some sparse tensors with different dim annotations and ordering.
S = pt.tensor([8, 8, 8],
pt.format([pt.compressed, pt.dense, pt.compressed], [1, 0, 2]))
X = pt.tensor([8, 8, 8],
pt.format([pt.compressed, pt.compressed, pt.compressed],
[1, 0, 2]))
S = pt.tensor([8, 8, 8], pt.format([pt.compressed, pt.dense, pt.compressed], [1, 0, 2]))
X = pt.tensor(
[8, 8, 8], pt.format([pt.compressed, pt.compressed, pt.compressed], [1, 0, 2])
)
S.insert([0, 0, 0], 2.0)
S.insert([1, 1, 1], 3.0)
S.insert([4, 4, 4], 4.0)
@@ -32,16 +31,14 @@ X[i, j, k] = alpha[0] * S[i, j, k]
# Set up tensors with a dense last dimension. This results in a full
# enveloping storage of all last "rows" with one or more nonzeros.
T = pt.tensor([1, 2, 3, 4, 5],
pt.format([
pt.compressed, pt.compressed, pt.compressed, pt.compressed,
pt.dense
]))
Y = pt.tensor([1, 2, 3, 4, 5],
pt.format([
pt.compressed, pt.compressed, pt.compressed, pt.compressed,
pt.dense
]))
T = pt.tensor(
[1, 2, 3, 4, 5],
pt.format([pt.compressed, pt.compressed, pt.compressed, pt.compressed, pt.dense]),
)
Y = pt.tensor(
[1, 2, 3, 4, 5],
pt.format([pt.compressed, pt.compressed, pt.compressed, pt.compressed, pt.dense]),
)
T.insert([0, 1, 2, 3, 4], -2.0)
Y[i, j, k, l, m] = alpha[0] * T[i, j, k, l, m]
@@ -85,18 +82,18 @@ z_expected = """; extended FROSTT format
# Force evaluation of the kernel by writing out X.
with tempfile.TemporaryDirectory() as test_dir:
x_file = os.path.join(test_dir, 'X.tns')
pt.write(x_file, X)
y_file = os.path.join(test_dir, 'Y.tns')
pt.write(y_file, Y)
z_file = os.path.join(test_dir, 'Z.tns')
pt.write(z_file, Z)
#
# CHECK: Compare result True True True
#
x_data = utils.file_as_string(x_file)
y_data = utils.file_as_string(y_file)
z_data = utils.file_as_string(z_file)
print(
f'Compare result {x_data == x_expected} {y_data == y_expected} {z_data == z_expected}'
)
x_file = os.path.join(test_dir, "X.tns")
pt.write(x_file, X)
y_file = os.path.join(test_dir, "Y.tns")
pt.write(y_file, Y)
z_file = os.path.join(test_dir, "Z.tns")
pt.write(z_file, Z)
#
# CHECK: Compare result True True True
#
x_data = utils.file_as_string(x_file)
y_data = utils.file_as_string(y_file)
z_data = utils.file_as_string(z_file)
print(
f"Compare result {x_data == x_expected} {y_data == y_expected} {z_data == z_expected}"
)

View File

@@ -12,7 +12,7 @@ compressed = pt.compressed
i, j = pt.get_index_vars(2)
A = pt.tensor([2, 3])
S = pt.tensor(3) # S is a scalar tensor.
S = pt.tensor(3) # S is a scalar tensor.
B = pt.tensor([2, 3], compressed)
A.insert([0, 1], 10)
A.insert([1, 2], 40)
@@ -26,11 +26,11 @@ passed += np.array_equal(values, [30.0, 120.0])
# Sum all the values in A.
S[0] = A[i, j]
passed += (S.get_scalar_value() == 50.0)
passed += S.get_scalar_value() == 50.0
indices, values = S.get_coordinates_and_values()
passed += (len(indices)==0)
passed += (values == 50.0)
passed += len(indices) == 0
passed += values == 50.0
# CHECK: Number of passed: 5
print("Number of passed:", passed)

View File

@@ -12,20 +12,20 @@ compressed = pt.compressed
passed = 0
all_types = [pt.complex64, pt.complex128]
for t in all_types:
i, j = pt.get_index_vars(2)
A = pt.tensor([2, 3], dtype=t)
B = pt.tensor([2, 3], dtype=t)
C = pt.tensor([2, 3], compressed, dtype=t)
A.insert([0, 1], 10 + 20j)
A.insert([1, 2], 40 + 0.5j)
B.insert([0, 0], 20)
B.insert([1, 2], 30 + 15j)
C[i, j] = A[i, j] + B[i, j]
i, j = pt.get_index_vars(2)
A = pt.tensor([2, 3], dtype=t)
B = pt.tensor([2, 3], dtype=t)
C = pt.tensor([2, 3], compressed, dtype=t)
A.insert([0, 1], 10 + 20j)
A.insert([1, 2], 40 + 0.5j)
B.insert([0, 0], 20)
B.insert([1, 2], 30 + 15j)
C[i, j] = A[i, j] + B[i, j]
indices, values = C.get_coordinates_and_values()
passed += isinstance(values[0], t.value)
passed += np.array_equal(indices, [[0, 0], [0, 1], [1, 2]])
passed += np.allclose(values, [20, 10 + 20j, 70 + 15.5j])
indices, values = C.get_coordinates_and_values()
passed += isinstance(values[0], t.value)
passed += np.array_equal(indices, [[0, 0], [0, 1], [1, 2]])
passed += np.allclose(values, [20, 10 + 20j, 70 + 15.5j])
# CHECK: Number of passed: 6
print("Number of passed:", passed)

View File

@@ -12,24 +12,22 @@ compressed = pt.compressed
dense = pt.dense
passed = 0
all_types = [
pt.int8, pt.int16, pt.int32, pt.int64, pt.float16, pt.float32, pt.float64
]
all_types = [pt.int8, pt.int16, pt.int32, pt.int64, pt.float16, pt.float32, pt.float64]
for t in all_types:
i, j = pt.get_index_vars(2)
A = pt.tensor([2, 3], dtype=t)
B = pt.tensor([2, 3], dtype=t)
C = pt.tensor([2, 3], compressed, dtype=t)
A.insert([0, 1], 10)
A.insert([1, 2], 40)
B.insert([0, 0], 20)
B.insert([1, 2], 30)
C[i, j] = A[i, j] + B[i, j]
i, j = pt.get_index_vars(2)
A = pt.tensor([2, 3], dtype=t)
B = pt.tensor([2, 3], dtype=t)
C = pt.tensor([2, 3], compressed, dtype=t)
A.insert([0, 1], 10)
A.insert([1, 2], 40)
B.insert([0, 0], 20)
B.insert([1, 2], 30)
C[i, j] = A[i, j] + B[i, j]
indices, values = C.get_coordinates_and_values()
passed += isinstance(values[0], t.value)
passed += np.array_equal(indices, [[0, 0], [0, 1], [1, 2]])
passed += np.allclose(values, [20.0, 10.0, 70.0])
indices, values = C.get_coordinates_and_values()
passed += isinstance(values[0], t.value)
passed += np.array_equal(indices, [[0, 0], [0, 1], [1, 2]])
passed += np.allclose(values, [20.0, 10.0, 70.0])
# CHECK: Number of passed: 21
print("Number of passed:", passed)

View File

@@ -10,8 +10,8 @@ from tools import mlir_pytaco_api as pt
i, j = pt.get_index_vars(2)
# Both tensors are true dense tensors.
A = pt.from_array(np.full([2,3], 1, dtype=np.float64))
B = pt.from_array(np.full([2,3], 2, dtype=np.float64))
A = pt.from_array(np.full([2, 3], 1, dtype=np.float64))
B = pt.from_array(np.full([2, 3], 2, dtype=np.float64))
# Define the result tensor as a true dense tensor. The parameter is_dense=True
# is an MLIR-PyTACO extension.
C = pt.tensor([2, 3], dtype=pt.float64, is_dense=True)

View File

@@ -31,50 +31,52 @@ _MTX_FILENAME_SUFFIX = ".mtx"
_TNS_FILENAME_SUFFIX = ".tns"
def read(filename: str, fmt: Format,
dtype: DType = DType(Type.FLOAT32)) -> Tensor:
"""Inputs a tensor from a given file.
def read(filename: str, fmt: Format, dtype: DType = DType(Type.FLOAT32)) -> Tensor:
"""Inputs a tensor from a given file.
The name suffix of the file specifies the format of the input tensor. We
currently only support the .mtx format for sparse tensors.
The name suffix of the file specifies the format of the input tensor. We
currently only support the .mtx format for sparse tensors.
Args:
filename: A string input filename.
fmt: The storage format of the tensor.
dtype: The data type, default to float32.
Args:
filename: A string input filename.
fmt: The storage format of the tensor.
dtype: The data type, default to float32.
Raises:
ValueError: If filename doesn't end with .mtx or .tns, or fmt is not an
instance of Format or fmt is not a sparse tensor.
"""
if (not isinstance(filename, str) or
(not filename.endswith(_MTX_FILENAME_SUFFIX) and
not filename.endswith(_TNS_FILENAME_SUFFIX))):
raise ValueError("Expected string filename ends with "
f"{_MTX_FILENAME_SUFFIX} or {_TNS_FILENAME_SUFFIX}: "
f"{filename}.")
Raises:
ValueError: If filename doesn't end with .mtx or .tns, or fmt is not an
instance of Format or fmt is not a sparse tensor.
"""
if not isinstance(filename, str) or (
not filename.endswith(_MTX_FILENAME_SUFFIX)
and not filename.endswith(_TNS_FILENAME_SUFFIX)
):
raise ValueError(
"Expected string filename ends with "
f"{_MTX_FILENAME_SUFFIX} or {_TNS_FILENAME_SUFFIX}: "
f"{filename}."
)
return Tensor.from_file(filename, fmt, dtype)
return Tensor.from_file(filename, fmt, dtype)
def write(filename: str, tensor: Tensor) -> None:
"""Outputs a tensor to a given file.
"""Outputs a tensor to a given file.
The name suffix of the file specifies the format of the output. We currently
only support .tns format.
The name suffix of the file specifies the format of the output. We currently
only support .tns format.
Args:
filename: A string output filename.
tensor: The tensor to output.
Args:
filename: A string output filename.
tensor: The tensor to output.
Raises:
ValueError: If filename doesn't end with .tns or tensor is not a Tensor.
"""
if (not isinstance(filename, str) or
not filename.endswith(_TNS_FILENAME_SUFFIX)):
raise ValueError("Expected string filename ends with"
f" {_TNS_FILENAME_SUFFIX}: {filename}.")
if not isinstance(tensor, Tensor):
raise ValueError(f"Expected a Tensor object: {tensor}.")
Raises:
ValueError: If filename doesn't end with .tns or tensor is not a Tensor.
"""
if not isinstance(filename, str) or not filename.endswith(_TNS_FILENAME_SUFFIX):
raise ValueError(
"Expected string filename ends with" f" {_TNS_FILENAME_SUFFIX}: {filename}."
)
if not isinstance(tensor, Tensor):
raise ValueError(f"Expected a Tensor object: {tensor}.")
tensor.to_file(filename)
tensor.to_file(filename)
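A round-trip sketch of the two entry points (not from the commit; the module alias and the availability of read/write there follow the other PyTACO tests in this diff and are an assumption here, and the file paths are placeholders):

from tools import mlir_pytaco_api as pt

fmt = pt.format([pt.compressed, pt.compressed])
A = pt.read("data/A.mtx", fmt)   # .mtx or .tns input; fmt must describe a sparse tensor
pt.write("/tmp/A.tns", A)        # only .tns output is supported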

View File

@@ -36,190 +36,234 @@ _ENTRY_NAME = "main"
@functools.lru_cache()
def _get_support_lib_name() -> str:
"""Gets the string name for the supporting C shared library."""
return os.getenv(_SUPPORTLIB_ENV_VAR, _DEFAULT_SUPPORTLIB)
"""Gets the string name for the supporting C shared library."""
return os.getenv(_SUPPORTLIB_ENV_VAR, _DEFAULT_SUPPORTLIB)
@functools.lru_cache()
def _get_sparse_compiler() -> mlir_sparse_compiler.SparseCompiler:
"""Gets the MLIR sparse compiler with default setting."""
return mlir_sparse_compiler.SparseCompiler(
options="", opt_level=_OPT_LEVEL, shared_libs=[_get_support_lib_name()])
"""Gets the MLIR sparse compiler with default setting."""
return mlir_sparse_compiler.SparseCompiler(
options="", opt_level=_OPT_LEVEL, shared_libs=[_get_support_lib_name()]
)
def _record_support_funcs(
ty: np.dtype, to_func: _SupportFunc, from_func: _SupportFunc,
ty_to_funcs: Dict[np.dtype, Tuple[_SupportFunc, _SupportFunc]]) -> None:
"""Records the two supporting functions for a given data type."""
to_func.restype = ctypes.c_void_p
from_func.restype = ctypes.c_void_p
ty_to_funcs[ty] = (to_func, from_func)
ty: np.dtype,
to_func: _SupportFunc,
from_func: _SupportFunc,
ty_to_funcs: Dict[np.dtype, Tuple[_SupportFunc, _SupportFunc]],
) -> None:
"""Records the two supporting functions for a given data type."""
to_func.restype = ctypes.c_void_p
from_func.restype = ctypes.c_void_p
ty_to_funcs[ty] = (to_func, from_func)
@functools.lru_cache()
def _get_support_func_locator() -> _SupportFuncLocator:
"""Constructs a function to locate the supporting functions for a data type.
"""Constructs a function to locate the supporting functions for a data type.
Loads the supporting C shared library with the needed routines. Constructs a
dictionary from the supported data types to the routines for the data types,
and then a function to look up the dictionary for a given data type.
Loads the supporting C shared library with the needed routines. Constructs a
dictionary from the supported data types to the routines for the data types,
and then a function to look up the dictionary for a given data type.
The name of the supporting C shared library is either provided by an
environment variable or a default value.
The name of the supporting C shared library is either provided by an
environment variable or a default value.
Returns:
The function to look up the supporting functions for a given data type.
Returns:
The function to look up the supporting functions for a given data type.
Raises:
OSError: If there is any problem in loading the shared library.
ValueError: If the shared library doesn't contain the needed routines.
"""
# This raises an OSError exception if there is any problem in loading the shared
# library.
c_lib = ctypes.CDLL(_get_support_lib_name())
Raises:
OSError: If there is any problem in loading the shared library.
ValueError: If the shared library doesn't contain the needed routines.
"""
# This raises an OSError exception if there is any problem in loading the shared
# library.
c_lib = ctypes.CDLL(_get_support_lib_name())
type_to_funcs = {}
try:
support_types = [(np.int8, c_lib.convertToMLIRSparseTensorI8,
c_lib.convertFromMLIRSparseTensorI8),
(np.int16, c_lib.convertToMLIRSparseTensorI16,
c_lib.convertFromMLIRSparseTensorI16),
(np.int32, c_lib.convertToMLIRSparseTensorI32,
c_lib.convertFromMLIRSparseTensorI32),
(np.int64, c_lib.convertToMLIRSparseTensorI64,
c_lib.convertFromMLIRSparseTensorI64),
(np.float16, c_lib.convertToMLIRSparseTensorF16,
c_lib.convertFromMLIRSparseTensorF16),
(np.float32, c_lib.convertToMLIRSparseTensorF32,
c_lib.convertFromMLIRSparseTensorF32),
(np.float64, c_lib.convertToMLIRSparseTensorF64,
c_lib.convertFromMLIRSparseTensorF64),
(np.complex64, c_lib.convertToMLIRSparseTensorC32,
c_lib.convertFromMLIRSparseTensorC32),
(np.complex128, c_lib.convertToMLIRSparseTensorC64,
c_lib.convertFromMLIRSparseTensorC64)]
except Exception as e:
raise ValueError(f"Missing supporting function: {e}") from e
for i, info in enumerate(support_types):
_record_support_funcs(info[0], info[1], info[2], type_to_funcs)
type_to_funcs = {}
try:
support_types = [
(
np.int8,
c_lib.convertToMLIRSparseTensorI8,
c_lib.convertFromMLIRSparseTensorI8,
),
(
np.int16,
c_lib.convertToMLIRSparseTensorI16,
c_lib.convertFromMLIRSparseTensorI16,
),
(
np.int32,
c_lib.convertToMLIRSparseTensorI32,
c_lib.convertFromMLIRSparseTensorI32,
),
(
np.int64,
c_lib.convertToMLIRSparseTensorI64,
c_lib.convertFromMLIRSparseTensorI64,
),
(
np.float16,
c_lib.convertToMLIRSparseTensorF16,
c_lib.convertFromMLIRSparseTensorF16,
),
(
np.float32,
c_lib.convertToMLIRSparseTensorF32,
c_lib.convertFromMLIRSparseTensorF32,
),
(
np.float64,
c_lib.convertToMLIRSparseTensorF64,
c_lib.convertFromMLIRSparseTensorF64,
),
(
np.complex64,
c_lib.convertToMLIRSparseTensorC32,
c_lib.convertFromMLIRSparseTensorC32,
),
(
np.complex128,
c_lib.convertToMLIRSparseTensorC64,
c_lib.convertFromMLIRSparseTensorC64,
),
]
except Exception as e:
raise ValueError(f"Missing supporting function: {e}") from e
for i, info in enumerate(support_types):
_record_support_funcs(info[0], info[1], info[2], type_to_funcs)
def get_support_funcs(ty: np.dtype):
funcs = type_to_funcs[ty]
assert funcs is not None
return funcs
def get_support_funcs(ty: np.dtype):
funcs = type_to_funcs[ty]
assert funcs is not None
return funcs
return get_support_funcs
return get_support_funcs
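For illustration (not part of the commit), the cached locator resolves the per-dtype conversion routines like this:

import numpy as np

# Look up the (convertToMLIRSparseTensor*, convertFromMLIRSparseTensor*) pair
# that was registered above for 64-bit floats.
to_f64, from_f64 = _get_support_func_locator()(np.float64)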
def sparse_tensor_to_coo_tensor(
sparse_tensor: ctypes.c_void_p,
dtype: np.dtype,
) -> Tuple[int, int, np.ndarray, np.ndarray, np.ndarray]:
"""Converts an MLIR sparse tensor to a COO-flavored format tensor.
"""Converts an MLIR sparse tensor to a COO-flavored format tensor.
Args:
sparse_tensor: A ctypes.c_void_p to the MLIR sparse tensor descriptor.
dtype: The numpy data type for the tensor elements.
Args:
sparse_tensor: A ctypes.c_void_p to the MLIR sparse tensor descriptor.
dtype: The numpy data type for the tensor elements.
Returns:
A tuple that contains the following values for the COO-flavored format
tensor:
rank: An integer for the rank of the tensor.
nse: An integer for the number of non-zero values in the tensor.
shape: A 1D numpy array of integers, for the shape of the tensor.
values: A 1D numpy array, for the non-zero values in the tensor.
indices: A 2D numpy array of integers, representing the indices for the
non-zero values in the tensor.
Returns:
A tuple that contains the following values for the COO-flavored format
tensor:
rank: An integer for the rank of the tensor.
nse: An integer for the number of non-zero values in the tensor.
shape: A 1D numpy array of integers, for the shape of the tensor.
values: A 1D numpy array, for the non-zero values in the tensor.
indices: A 2D numpy array of integers, representing the indices for the
non-zero values in the tensor.
Raises:
OSError: If there is any problem in loading the shared library.
ValueError: If the shared library doesn't contain the needed routines.
"""
convert_from = _get_support_func_locator()(dtype)[1]
rank = ctypes.c_ulonglong(0)
nse = ctypes.c_ulonglong(0)
shape = ctypes.POINTER(ctypes.c_ulonglong)()
Raises:
OSError: If there is any problem in loading the shared library.
ValueError: If the shared library doesn't contain the needed routines.
"""
convert_from = _get_support_func_locator()(dtype)[1]
rank = ctypes.c_ulonglong(0)
nse = ctypes.c_ulonglong(0)
shape = ctypes.POINTER(ctypes.c_ulonglong)()
values = ctypes.POINTER(runtime.as_ctype(np.dtype(dtype)))()
indices = ctypes.POINTER(ctypes.c_ulonglong)()
convert_from(sparse_tensor, ctypes.byref(rank), ctypes.byref(nse),
ctypes.byref(shape), ctypes.byref(values), ctypes.byref(indices))
values = ctypes.POINTER(runtime.as_ctype(np.dtype(dtype)))()
indices = ctypes.POINTER(ctypes.c_ulonglong)()
convert_from(
sparse_tensor,
ctypes.byref(rank),
ctypes.byref(nse),
ctypes.byref(shape),
ctypes.byref(values),
ctypes.byref(indices),
)
# Convert the returned values to the corresponding numpy types.
shape = np.ctypeslib.as_array(shape, shape=[rank.value])
values = runtime.to_numpy(np.ctypeslib.as_array(values, shape=[nse.value]))
indices = np.ctypeslib.as_array(indices, shape=[nse.value, rank.value])
return rank.value, nse.value, shape, values, indices
# Convert the returned values to the corresponding numpy types.
shape = np.ctypeslib.as_array(shape, shape=[rank.value])
values = runtime.to_numpy(np.ctypeslib.as_array(values, shape=[nse.value]))
indices = np.ctypeslib.as_array(indices, shape=[nse.value, rank.value])
return rank.value, nse.value, shape, values, indices
def coo_tensor_to_sparse_tensor(np_shape: np.ndarray, np_values: np.ndarray,
np_indices: np.ndarray, np_perm: np.ndarray,
np_sparse: np.ndarray) -> int:
"""Converts a COO-flavored format sparse tensor to an MLIR sparse tensor.
def coo_tensor_to_sparse_tensor(
np_shape: np.ndarray,
np_values: np.ndarray,
np_indices: np.ndarray,
np_perm: np.ndarray,
np_sparse: np.ndarray,
) -> int:
"""Converts a COO-flavored format sparse tensor to an MLIR sparse tensor.
Args:
np_shape: A 1D numpy array of integers, for the shape of the tensor.
np_values: A 1D numpy array, for the non-zero values in the tensor.
np_indices: A 2D numpy array of integers, representing the indices for the
non-zero values in the tensor.
np_perm: A 1D numpy array of integers, representing the storage ordering
for the dimensions.
np_sparse: A 1D numpy array of uint8, representing the sparsity values
for the dimensions.
Args:
np_shape: A 1D numpy array of integers, for the shape of the tensor.
np_values: A 1D numpy array, for the non-zero values in the tensor.
np_indices: A 2D numpy array of integers, representing the indices for the
non-zero values in the tensor.
np_perm: A 1D numpy array of integers, representing the storage ordering
for the dimensions.
np_sparse: A 1D numpy array of uint8, representing the sparsity values
for the dimensions.
Returns:
An integer for the non-null ctypes.c_void_p to the MLIR sparse tensor
descriptor.
Returns:
An integer for the non-null ctypes.c_void_p to the MLIR sparse tensor
descriptor.
Raises:
OSError: If there is any problem in loading the shared library.
ValueError: If the shared library doesn't contain the needed routines.
"""
Raises:
OSError: If there is any problem in loading the shared library.
ValueError: If the shared library doesn't contain the needed routines.
"""
r = len(np_shape)
rank = ctypes.c_ulonglong(r)
nse = ctypes.c_ulonglong(len(np_values))
shape = np_shape.ctypes.data_as(ctypes.POINTER(ctypes.c_ulonglong))
values = np_values.ctypes.data_as(
ctypes.POINTER(runtime.as_ctype(np.dtype(np_values.dtype))))
indices = np_indices.ctypes.data_as(ctypes.POINTER(ctypes.c_ulonglong))
r = len(np_shape)
rank = ctypes.c_ulonglong(r)
nse = ctypes.c_ulonglong(len(np_values))
shape = np_shape.ctypes.data_as(ctypes.POINTER(ctypes.c_ulonglong))
values = np_values.ctypes.data_as(
ctypes.POINTER(runtime.as_ctype(np.dtype(np_values.dtype)))
)
indices = np_indices.ctypes.data_as(ctypes.POINTER(ctypes.c_ulonglong))
perm = np_perm.ctypes.data_as(ctypes.POINTER(ctypes.c_ulonglong))
sparse = np_sparse.ctypes.data_as(ctypes.POINTER(ctypes.c_uint8))
perm = np_perm.ctypes.data_as(ctypes.POINTER(ctypes.c_ulonglong))
sparse = np_sparse.ctypes.data_as(ctypes.POINTER(ctypes.c_uint8))
convert_to = _get_support_func_locator()(np_values.dtype.type)[0]
ptr = convert_to(rank, nse, shape, values, indices, perm, sparse)
assert ptr is not None, "Problem with calling convertToMLIRSparseTensorF64"
return ptr
convert_to = _get_support_func_locator()(np_values.dtype.type)[0]
ptr = convert_to(rank, nse, shape, values, indices, perm, sparse)
assert ptr is not None, "Problem with calling convertToMLIRSparseTensorF64"
return ptr
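A minimal sketch of the expected argument layout (values chosen only for illustration; it assumes the supporting C shared library is available):

import numpy as np

# A 2x2 tensor with nonzeros at (0, 0) and (1, 1), identity storage ordering,
# and both dimensions marked sparse.
np_shape = np.array([2, 2], dtype=np.uint64)
np_values = np.array([1.0, 2.0], dtype=np.float64)
np_indices = np.array([[0, 0], [1, 1]], dtype=np.uint64)
np_perm = np.array([0, 1], dtype=np.uint64)
np_sparse = np.array([1, 1], dtype=np.uint8)
ptr = coo_tensor_to_sparse_tensor(np_shape, np_values, np_indices, np_perm, np_sparse)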
def compile_and_build_engine(
module: ir.Module) -> execution_engine.ExecutionEngine:
"""Compiles an MLIR module and builds a JIT execution engine.
def compile_and_build_engine(module: ir.Module) -> execution_engine.ExecutionEngine:
"""Compiles an MLIR module and builds a JIT execution engine.
Args:
module: The MLIR module.
Args:
module: The MLIR module.
Returns:
A JIT execution engine for the MLIR module.
Returns:
A JIT execution engine for the MLIR module.
"""
return _get_sparse_compiler().compile_and_jit(module)
"""
return _get_sparse_compiler().compile_and_jit(module)
class _SparseTensorDescriptor(ctypes.Structure):
"""A C structure for an MLIR sparse tensor."""
_fields_ = [
# A pointer for the MLIR sparse tensor storage.
("storage", ctypes.POINTER(ctypes.c_ulonglong)),
# An MLIR MemRef descriptor for the shape of the sparse tensor.
("shape", runtime.make_nd_memref_descriptor(1, ctypes.c_ulonglong)),
]
"""A C structure for an MLIR sparse tensor."""
_fields_ = [
# A pointer for the MLIR sparse tensor storage.
("storage", ctypes.POINTER(ctypes.c_ulonglong)),
# An MLIR MemRef descriptor for the shape of the sparse tensor.
("shape", runtime.make_nd_memref_descriptor(1, ctypes.c_ulonglong)),
]
def _output_one_dim(dim: int, rank: int, shape: str, type: str) -> str:
"""Produces the MLIR text code to output the size for the given dimension."""
return f"""
"""Produces the MLIR text code to output the size for the given dimension."""
return f"""
%c{dim} = arith.constant {dim} : index
%d{dim} = tensor.dim %t, %c{dim} : tensor<{shape}x{type}, #enc>
memref.store %d{dim}, %b[%c{dim}] : memref<{rank}xindex>
@@ -233,26 +277,29 @@ def _output_one_dim(dim: int, rank: int, shape: str, type: str) -> str:
# (2) Use scf.for instead of an unrolled loop to write out the dimension sizes
# when tensor.dim supports non-constant dimension value.
def _get_create_sparse_tensor_kernel(
sparsity_codes: Sequence[sparse_tensor.DimLevelType], type: str) -> str:
"""Creates an MLIR text kernel to contruct a sparse tensor from a file.
sparsity_codes: Sequence[sparse_tensor.DimLevelType], type: str
) -> str:
"""Creates an MLIR text kernel to contruct a sparse tensor from a file.
The kernel returns a _SparseTensorDescriptor structure.
"""
rank = len(sparsity_codes)
The kernel returns a _SparseTensorDescriptor structure.
"""
rank = len(sparsity_codes)
# Use ? to represent a dimension in the dynamic shape string representation.
shape = "x".join(map(lambda d: "?", range(rank)))
# Use ? to represent a dimension in the dynamic shape string representation.
shape = "x".join(map(lambda d: "?", range(rank)))
# Convert the encoded sparsity values to a string representation.
sparsity = ", ".join(
map(lambda s: '"compressed"' if s.value else '"dense"', sparsity_codes))
# Convert the encoded sparsity values to a string representation.
sparsity = ", ".join(
map(lambda s: '"compressed"' if s.value else '"dense"', sparsity_codes)
)
# Get the MLIR text code to write the dimension sizes to the output buffer.
output_dims = "\n".join(
map(lambda d: _output_one_dim(d, rank, shape, type), range(rank)))
# Get the MLIR text code to write the dimension sizes to the output buffer.
output_dims = "\n".join(
map(lambda d: _output_one_dim(d, rank, shape, type), range(rank))
)
# Return the MLIR text kernel.
return f"""
# Return the MLIR text kernel.
return f"""
!Ptr = !llvm.ptr<i8>
#enc = #sparse_tensor.encoding<{{
lvlTypes = [ {sparsity} ]
@@ -266,69 +313,69 @@ attributes {{ llvm.emit_c_interface }} {{
}}"""
def create_sparse_tensor(filename: str,
sparsity: Sequence[sparse_tensor.DimLevelType],
type: str) -> Tuple[ctypes.c_void_p, np.ndarray]:
"""Creates an MLIR sparse tensor from the input file.
def create_sparse_tensor(
filename: str, sparsity: Sequence[sparse_tensor.DimLevelType], type: str
) -> Tuple[ctypes.c_void_p, np.ndarray]:
"""Creates an MLIR sparse tensor from the input file.
Args:
filename: A string for the name of the file that contains the tensor data in
a COO-flavored format.
sparsity: A sequence of DimLevelType values, one for each dimension of the
tensor.
type: The MLIR string for the data type.
Args:
filename: A string for the name of the file that contains the tensor data in
a COO-flavored format.
sparsity: A sequence of DimLevelType values, one for each dimension of the
tensor.
type: The MLIR string for the data type.
Returns:
A Tuple containing the following values:
storage: A ctypes.c_void_p for the MLIR sparse tensor storage.
shape: A 1D numpy array of integers, for the shape of the tensor.
Returns:
A Tuple containing the following values:
storage: A ctypes.c_void_p for the MLIR sparse tensor storage.
shape: A 1D numpy array of integers, for the shape of the tensor.
Raises:
OSError: If there is any problem in loading the supporting C shared library.
ValueError: If the shared library doesn't contain the needed routine.
"""
with ir.Context() as ctx, ir.Location.unknown():
module = _get_create_sparse_tensor_kernel(sparsity, type)
module = ir.Module.parse(module)
engine = compile_and_build_engine(module)
Raises:
OSError: If there is any problem in loading the supporting C shared library.
ValueError: If the shared library doesn't contain the needed routine.
"""
with ir.Context() as ctx, ir.Location.unknown():
module = _get_create_sparse_tensor_kernel(sparsity, type)
module = ir.Module.parse(module)
engine = compile_and_build_engine(module)
# A sparse tensor descriptor to receive the kernel result.
c_tensor_desc = _SparseTensorDescriptor()
# Convert the filename to a byte stream.
c_filename = ctypes.c_char_p(bytes(filename, "utf-8"))
# A sparse tensor descriptor to receive the kernel result.
c_tensor_desc = _SparseTensorDescriptor()
# Convert the filename to a byte stream.
c_filename = ctypes.c_char_p(bytes(filename, "utf-8"))
arg_pointers = [
ctypes.byref(ctypes.pointer(c_tensor_desc)),
ctypes.byref(c_filename)
]
arg_pointers = [
ctypes.byref(ctypes.pointer(c_tensor_desc)),
ctypes.byref(c_filename),
]
# Invoke the execution engine to run the module and return the result.
engine.invoke(_ENTRY_NAME, *arg_pointers)
shape = runtime.ranked_memref_to_numpy(ctypes.pointer(c_tensor_desc.shape))
return c_tensor_desc.storage, shape
# Invoke the execution engine to run the module and return the result.
engine.invoke(_ENTRY_NAME, *arg_pointers)
shape = runtime.ranked_memref_to_numpy(ctypes.pointer(c_tensor_desc.shape))
return c_tensor_desc.storage, shape
# TODO: With better support from MLIR, we may improve the current implementation
# by using Python code to generate the kernel instead of doing MLIR text code
# stitching.
def _get_output_sparse_tensor_kernel(
sparsity_codes: Sequence[sparse_tensor.DimLevelType],
type: str) -> str:
"""Creates an MLIR text kernel to output a sparse tensor to a file.
sparsity_codes: Sequence[sparse_tensor.DimLevelType], type: str
) -> str:
"""Creates an MLIR text kernel to output a sparse tensor to a file.
The kernel returns void.
"""
rank = len(sparsity_codes)
The kernel returns void.
"""
rank = len(sparsity_codes)
# Use ? to represent a dimension in the dynamic shape string representation.
shape = "x".join(map(lambda d: "?", range(rank)))
# Use ? to represent a dimension in the dynamic shape string representation.
shape = "x".join(map(lambda d: "?", range(rank)))
# Convert the encoded sparsity values to a string representation.
sparsity = ", ".join(
map(lambda s: '"compressed"'
if s.value else '"dense"', sparsity_codes))
# Convert the encoded sparsity values to a string representation.
sparsity = ", ".join(
map(lambda s: '"compressed"' if s.value else '"dense"', sparsity_codes)
)
# Return the MLIR text kernel.
return f"""
# Return the MLIR text kernel.
return f"""
!Ptr = !llvm.ptr<i8>
#enc = #sparse_tensor.encoding<{{
lvlTypes = [ {sparsity} ]
@@ -340,35 +387,38 @@ attributes {{ llvm.emit_c_interface }} {{
}}"""
def output_sparse_tensor(tensor: ctypes.c_void_p, filename: str,
sparsity: Sequence[sparse_tensor.DimLevelType],
type: str) -> None:
"""Outputs an MLIR sparse tensor to the given file.
def output_sparse_tensor(
tensor: ctypes.c_void_p,
filename: str,
sparsity: Sequence[sparse_tensor.DimLevelType],
type: str,
) -> None:
"""Outputs an MLIR sparse tensor to the given file.
Args:
tensor: A C pointer to the MLIR sparse tensor.
filename: A string for the name of the file that contains the tensor data in
a COO-flavored format.
sparsity: A sequence of DimLevelType values, one for each dimension of the
tensor.
type: The MLIR string for the data type.
Args:
tensor: A C pointer to the MLIR sparse tensor.
filename: A string for the name of the file that contains the tensor data in
a COO-flavored format.
sparsity: A sequence of DimLevelType values, one for each dimension of the
tensor.
type: The MLIR string for the data type.
Raises:
OSError: If there is any problem in loading the supporting C shared library.
ValueError: If the shared library doesn't contain the needed routine.
"""
with ir.Context() as ctx, ir.Location.unknown():
module = _get_output_sparse_tensor_kernel(sparsity, type)
module = ir.Module.parse(module)
engine = compile_and_build_engine(module)
Raises:
OSError: If there is any problem in loading the supporting C shared library.
ValueError: If the shared library doesn't contain the needed routine.
"""
with ir.Context() as ctx, ir.Location.unknown():
module = _get_output_sparse_tensor_kernel(sparsity, type)
module = ir.Module.parse(module)
engine = compile_and_build_engine(module)
# Convert the filename to a byte stream.
c_filename = ctypes.c_char_p(bytes(filename, "utf-8"))
# Convert the filename to a byte stream.
c_filename = ctypes.c_char_p(bytes(filename, "utf-8"))
arg_pointers = [
ctypes.byref(ctypes.cast(tensor, ctypes.c_void_p)),
ctypes.byref(c_filename)
]
arg_pointers = [
ctypes.byref(ctypes.cast(tensor, ctypes.c_void_p)),
ctypes.byref(c_filename),
]
# Invoke the execution engine to run the module and return the result.
engine.invoke(_ENTRY_NAME, *arg_pointers)
# Invoke the execution engine to run the module and return the result.
engine.invoke(_ENTRY_NAME, *arg_pointers)
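Putting the two kernels together (a sketch only; the file paths are placeholders and `sparsity` is assumed to be a per-dimension sequence of sparse_tensor.DimLevelType values as documented above):

# Read a tensor from a FROSTT file into MLIR sparse storage, then write it back.
# `sparsity` is an assumed, pre-built sequence of DimLevelType values.
storage, shape = create_sparse_tensor("data/A.tns", sparsity, "f64")
output_sparse_tensor(storage, "/tmp/A_roundtrip.tns", sparsity, "f64")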

View File

@@ -13,29 +13,29 @@ from typing import Sequence
class SparseCompiler:
"""Sparse compiler class for compiling and building MLIR modules."""
"""Sparse compiler class for compiling and building MLIR modules."""
def __init__(self, options: str, opt_level: int, shared_libs: Sequence[str]):
pipeline = f'builtin.module(sparse-compiler{{{options} reassociate-fp-reductions=1 enable-index-optimizations=1}})'
self.pipeline = pipeline
self.opt_level = opt_level
self.shared_libs = shared_libs
def __init__(self, options: str, opt_level: int, shared_libs: Sequence[str]):
pipeline = f"builtin.module(sparse-compiler{{{options} reassociate-fp-reductions=1 enable-index-optimizations=1}})"
self.pipeline = pipeline
self.opt_level = opt_level
self.shared_libs = shared_libs
def __call__(self, module: ir.Module):
"""Convenience application method."""
self.compile(module)
def __call__(self, module: ir.Module):
"""Convenience application method."""
self.compile(module)
def compile(self, module: ir.Module):
"""Compiles the module by invoking the sparse copmiler pipeline."""
passmanager.PassManager.parse(self.pipeline).run(module.operation)
def compile(self, module: ir.Module):
"""Compiles the module by invoking the sparse copmiler pipeline."""
passmanager.PassManager.parse(self.pipeline).run(module.operation)
def jit(self, module: ir.Module) -> execution_engine.ExecutionEngine:
"""Wraps the module in a JIT execution engine."""
return execution_engine.ExecutionEngine(
module, opt_level=self.opt_level, shared_libs=self.shared_libs)
def jit(self, module: ir.Module) -> execution_engine.ExecutionEngine:
"""Wraps the module in a JIT execution engine."""
return execution_engine.ExecutionEngine(
module, opt_level=self.opt_level, shared_libs=self.shared_libs
)
def compile_and_jit(self,
module: ir.Module) -> execution_engine.ExecutionEngine:
"""Compiles and jits the module."""
self.compile(module)
return self.jit(module)
def compile_and_jit(self, module: ir.Module) -> execution_engine.ExecutionEngine:
"""Compiles and jits the module."""
self.compile(module)
return self.jit(module)

View File

@@ -8,38 +8,40 @@ import numpy as np
def compare_sparse_tns(expected: str, actual: str, rtol: float = 0.0001) -> bool:
"""Compares sparse tensor actual output file with expected output file.
"""Compares sparse tensor actual output file with expected output file.
This routine assumes the input files are in FROSTT format. See
http://frostt.io/tensors/file-formats.html for FROSTT (.tns) format.
This routine assumes the input files are in FROSTT format. See
http://frostt.io/tensors/file-formats.html for FROSTT (.tns) format.
It also assumes the first line in the output file is a comment line.
It also assumes the first line in the output file is a comment line.
"""
with open(actual, "r") as actual_f:
with open(expected, "r") as expected_f:
# Skip the first comment line.
_ = actual_f.readline()
_ = expected_f.readline()
"""
with open(actual, "r") as actual_f:
with open(expected, "r") as expected_f:
# Skip the first comment line.
_ = actual_f.readline()
_ = expected_f.readline()
# Compare the two lines of meta data
if (actual_f.readline() != expected_f.readline() or
actual_f.readline() != expected_f.readline()):
return False
# Compare the two lines of meta data
if (
actual_f.readline() != expected_f.readline()
or actual_f.readline() != expected_f.readline()
):
return False
actual_data = np.loadtxt(actual, np.float64, skiprows=3)
expected_data = np.loadtxt(expected, np.float64, skiprows=3)
return np.allclose(actual_data, expected_data, rtol=rtol)
actual_data = np.loadtxt(actual, np.float64, skiprows=3)
expected_data = np.loadtxt(expected, np.float64, skiprows=3)
return np.allclose(actual_data, expected_data, rtol=rtol)
def file_as_string(file: str) -> str:
"""Returns contents of file as string."""
with open(file, "r") as f:
return f.read()
"""Returns contents of file as string."""
with open(file, "r") as f:
return f.read()
def run_test(f):
"""Prints the test name and runs the test."""
print(f.__name__)
f()
return f
"""Prints the test name and runs the test."""
print(f.__name__)
f()
return f

View File

@@ -18,509 +18,630 @@ _DENSE = mlir_pytaco.ModeFormat.DENSE
def _init_3d(T, I, J, K):
for i in range(I):
for j in range(J):
for k in range(K):
T.insert([i, j, k], i + j + k + 1)
for i in range(I):
for j in range(J):
for k in range(K):
T.insert([i, j, k], i + j + k + 1)
def _init_2d(T, I, J):
for i in range(I):
for j in range(J):
T.insert([i, j], i + j + 1)
for i in range(I):
for j in range(J):
T.insert([i, j], i + j + 1)
def _init_1d_with_value(T, I, v):
for i in range(I):
T.insert([i], v)
for i in range(I):
T.insert([i], v)
def test_expect_error(name, code, error):
"""Executes the code then verifies the expected error message."""
try:
exec(code)
except ValueError as e:
passed = "passed" if (str(e).startswith(error)) else "failed"
print(f"test_{name}: {passed}")
"""Executes the code then verifies the expected error message."""
try:
exec(code)
except ValueError as e:
passed = "passed" if (str(e).startswith(error)) else "failed"
print(f"test_{name}: {passed}")
# CHECK-LABEL: test_tensor_dtype
@testing_utils.run_test
def test_tensor_dtype():
passed = mlir_pytaco.DType(mlir_pytaco.Type.INT16).is_int()
passed += mlir_pytaco.DType(mlir_pytaco.Type.INT32).is_int()
passed += mlir_pytaco.DType(mlir_pytaco.Type.INT64).is_int()
passed += mlir_pytaco.DType(mlir_pytaco.Type.FLOAT32).is_float()
passed += mlir_pytaco.DType(mlir_pytaco.Type.FLOAT64).is_float()
# CHECK: Number of passed: 5
print("Number of passed:", passed)
passed = mlir_pytaco.DType(mlir_pytaco.Type.INT16).is_int()
passed += mlir_pytaco.DType(mlir_pytaco.Type.INT32).is_int()
passed += mlir_pytaco.DType(mlir_pytaco.Type.INT64).is_int()
passed += mlir_pytaco.DType(mlir_pytaco.Type.FLOAT32).is_float()
passed += mlir_pytaco.DType(mlir_pytaco.Type.FLOAT64).is_float()
# CHECK: Number of passed: 5
print("Number of passed:", passed)
# CHECK: test_mode_ordering_not_int: passed
test_expect_error("mode_ordering_not_int",
"m = mlir_pytaco.ModeOrdering(['x'])",
"Ordering must be a list of integers")
test_expect_error(
"mode_ordering_not_int",
"m = mlir_pytaco.ModeOrdering(['x'])",
"Ordering must be a list of integers",
)
# CHECK: test_mode_ordering_not_permutation: passed
test_expect_error("mode_ordering_not_permutation",
"m = mlir_pytaco.ModeOrdering([2, 1])", "Invalid ordering")
test_expect_error(
"mode_ordering_not_permutation",
"m = mlir_pytaco.ModeOrdering([2, 1])",
"Invalid ordering",
)
# CHECK: test_mode_format_invalid: passed
test_expect_error("mode_format_invalid",
"m = mlir_pytaco.ModeFormatPack(['y'])",
"Formats must be a list of ModeFormat")
test_expect_error(
"mode_format_invalid",
"m = mlir_pytaco.ModeFormatPack(['y'])",
"Formats must be a list of ModeFormat",
)
# CHECK: test_expect_mode_format_pack: passed
test_expect_error("expect_mode_format_pack", ("""
test_expect_error(
"expect_mode_format_pack",
(
"""
mode_ordering = mlir_pytaco.ModeOrdering([0, 1, 2])
f = mlir_pytaco.Format(["x"], mode_ordering)
"""), "Expected a list of ModeFormat")
"""
),
"Expected a list of ModeFormat",
)
# CHECK: test_expect_mode_ordering: passed
test_expect_error("expect_mode_ordering", ("""
test_expect_error(
"expect_mode_ordering",
(
"""
mode_format_pack = mlir_pytaco.ModeFormatPack([_COMPRESSED, _COMPRESSED])
f = mlir_pytaco.Format(mode_format_pack, "x")
"""), "Expected ModeOrdering")
"""
),
"Expected ModeOrdering",
)
# CHECK: test_inconsistent_mode_format_pack_and_mode_ordering: passed
test_expect_error("inconsistent_mode_format_pack_and_mode_ordering", ("""
test_expect_error(
"inconsistent_mode_format_pack_and_mode_ordering",
(
"""
mode_format_pack = mlir_pytaco.ModeFormatPack([_COMPRESSED, _COMPRESSED])
mode_ordering = mlir_pytaco.ModeOrdering([0, 1, 2])
f = mlir_pytaco.Format(mode_format_pack, mode_ordering)
"""), "Inconsistent ModeFormatPack and ModeOrdering")
"""
),
"Inconsistent ModeFormatPack and ModeOrdering",
)
# CHECK-LABEL: test_format_default_ordering
@testing_utils.run_test
def test_format_default_ordering():
f = mlir_pytaco.Format([_COMPRESSED, _COMPRESSED])
passed = 0
passed += np.array_equal(f.ordering.ordering, [0, 1])
# CHECK: Number of passed: 1
print("Number of passed:", passed)
f = mlir_pytaco.Format([_COMPRESSED, _COMPRESSED])
passed = 0
passed += np.array_equal(f.ordering.ordering, [0, 1])
# CHECK: Number of passed: 1
print("Number of passed:", passed)
# CHECK-LABEL: test_format_explicit_ordering
@testing_utils.run_test
def test_format_explicit_ordering():
f = mlir_pytaco.Format([_COMPRESSED, _DENSE], [1, 0])
passed = 0
passed += np.array_equal(f.ordering.ordering, [1, 0])
# CHECK: Number of passed: 1
print("Number of passed:", passed)
f = mlir_pytaco.Format([_COMPRESSED, _DENSE], [1, 0])
passed = 0
passed += np.array_equal(f.ordering.ordering, [1, 0])
# CHECK: Number of passed: 1
print("Number of passed:", passed)
# CHECK-LABEL: test_index_var
@testing_utils.run_test
def test_index_var():
i = mlir_pytaco.IndexVar()
j = mlir_pytaco.IndexVar()
passed = (i.name != j.name)
i = mlir_pytaco.IndexVar()
j = mlir_pytaco.IndexVar()
passed = i.name != j.name
vars = mlir_pytaco.get_index_vars(10)
passed += (len(vars) == 10)
passed += (all([isinstance(e, mlir_pytaco.IndexVar) for e in vars]))
vars = mlir_pytaco.get_index_vars(10)
passed += len(vars) == 10
passed += all([isinstance(e, mlir_pytaco.IndexVar) for e in vars])
# CHECK: Number of passed: 3
print("Number of passed:", passed)
# CHECK: Number of passed: 3
print("Number of passed:", passed)
# CHECK: test_tensor_invalid_first_argument: passed
test_expect_error("tensor_invalid_first_argument",
"t = mlir_pytaco.Tensor('f')", "Invalid first argument")
test_expect_error(
"tensor_invalid_first_argument",
"t = mlir_pytaco.Tensor('f')",
"Invalid first argument",
)
# CHECK: test_tensor_inconsistent_shape_and_format: passed
test_expect_error("tensor_inconsistent_shape_and_format", ("""
test_expect_error(
"tensor_inconsistent_shape_and_format",
(
"""
mode_format_pack = mlir_pytaco.ModeFormatPack([_COMPRESSED, _COMPRESSED])
mode_ordering = mlir_pytaco.ModeOrdering([0, 1])
f = mlir_pytaco.Format(mode_format_pack, mode_ordering)
t = mlir_pytaco.Tensor([3], f)
"""), "Inconsistent shape and format")
"""
),
"Inconsistent shape and format",
)
# CHECK: test_tensor_invalid_format: passed
test_expect_error("tensor_invalid_format", "t = mlir_pytaco.Tensor([3], 'f')",
"Invalid format argument")
test_expect_error(
"tensor_invalid_format",
"t = mlir_pytaco.Tensor([3], 'f')",
"Invalid format argument",
)
# CHECK: test_tensor_insert_nonlist_coordinate: passed
test_expect_error("tensor_insert_nonlist_coordinate", ("""
test_expect_error(
"tensor_insert_nonlist_coordinate",
(
"""
t = mlir_pytaco.Tensor([3])
t.insert(1, 0)
"""), "Non list coordinate detected")
"""
),
"Non list coordinate detected",
)
# CHECK: test_tensor_insert_too_much_coordinate: passed
test_expect_error("tensor_insert_too_much_coordinate", ("""
test_expect_error(
"tensor_insert_too_much_coordinate",
(
"""
t = mlir_pytaco.Tensor([3])
t.insert([0, 0], 0)
"""), "Invalid coordinate")
"""
),
"Invalid coordinate",
)
# CHECK: test_tensor_insert_coordinate_outof_range: passed
test_expect_error("tensor_insert_coordinate_outof_range", ("""
test_expect_error(
"tensor_insert_coordinate_outof_range",
(
"""
t = mlir_pytaco.Tensor([1, 1])
t.insert([1, 0], 0)
"""), "Invalid coordinate")
"""
),
"Invalid coordinate",
)
# CHECK: test_tensor_insert_coordinate_nonint: passed
test_expect_error("tensor_insert_coordinate_nonint", ("""
test_expect_error(
"tensor_insert_coordinate_nonint",
(
"""
t = mlir_pytaco.Tensor([1, 1])
t.insert([0, "xy"], 0)
"""), "Non integer coordinate detected")
"""
),
"Non integer coordinate detected",
)
# CHECK: test_tensor_insert_invalid_value: passed
test_expect_error("tensor_insert_invalid_value", ("""
test_expect_error(
"tensor_insert_invalid_value",
(
"""
t = mlir_pytaco.Tensor([1, 1])
t.insert([0, 0], "x")
"""), "Value is neither int nor float")
"""
),
"Value is neither int nor float",
)
# CHECK: test_access_non_index_var_index: passed
test_expect_error("access_non_index_var_index", ("""
test_expect_error(
"access_non_index_var_index",
(
"""
t = mlir_pytaco.Tensor([5, 6])
i = mlir_pytaco.IndexVar()
a = mlir_pytaco.Access(t, (i, "j"))
"""), "Indices contain non IndexVar")
"""
),
"Indices contain non IndexVar",
)
# CHECK: test_access_inconsistent_rank_indices: passed
test_expect_error("access_inconsistent_rank_indices", ("""
test_expect_error(
"access_inconsistent_rank_indices",
(
"""
t = mlir_pytaco.Tensor([5, 6])
i = mlir_pytaco.IndexVar()
a = mlir_pytaco.Access(t, (i,))
"""), "Invalid indices for rank")
"""
),
"Invalid indices for rank",
)
# CHECK: test_access_invalid_indices_for_rank: passed
test_expect_error("access_invalid_indices_for_rank", ("""
test_expect_error(
"access_invalid_indices_for_rank",
(
"""
t = mlir_pytaco.Tensor([5, 6])
i, j, k = mlir_pytaco.get_index_vars(3)
a = mlir_pytaco.Access(t, (i,j, k))
"""), "Invalid indices for rank")
"""
),
"Invalid indices for rank",
)
# CHECK: test_invalid_indices: passed
test_expect_error("invalid_indices", ("""
test_expect_error(
"invalid_indices",
(
"""
i, j = mlir_pytaco.get_index_vars(2)
A = mlir_pytaco.Tensor([2, 3])
B = mlir_pytaco.Tensor([2, 3])
C = mlir_pytaco.Tensor([2, 3], _DENSE)
C[i, j] = A[1, j] + B[i, j]
"""), "Expected IndexVars")
"""
),
"Expected IndexVars",
)
# CHECK: test_inconsistent_rank_indices: passed
test_expect_error("inconsistent_rank_indices", ("""
test_expect_error(
"inconsistent_rank_indices",
(
"""
i, j = mlir_pytaco.get_index_vars(2)
A = mlir_pytaco.Tensor([2, 3])
C = mlir_pytaco.Tensor([2, 3], _DENSE)
C[i, j] = A[i]
"""), "Invalid indices for rank")
"""
),
"Invalid indices for rank",
)
# CHECK: test_destination_index_not_used_in_source: passed
test_expect_error("destination_index_not_used_in_source", ("""
test_expect_error(
"destination_index_not_used_in_source",
(
"""
i, j = mlir_pytaco.get_index_vars(2)
A = mlir_pytaco.Tensor([3])
C = mlir_pytaco.Tensor([3], _DENSE)
C[j] = A[i]
C.evaluate()
"""), "Destination IndexVar not used in the source expression")
"""
),
"Destination IndexVar not used in the source expression",
)
# CHECK: test_destination_dim_not_consistent_with_source: passed
test_expect_error("destination_dim_not_consistent_with_source", ("""
test_expect_error(
"destination_dim_not_consistent_with_source",
(
"""
i = mlir_pytaco.IndexVar()
A = mlir_pytaco.Tensor([3])
C = mlir_pytaco.Tensor([5], _DENSE)
C[i] = A[i]
C.evaluate()
"""), "Inconsistent destination dimension for IndexVar")
"""
),
"Inconsistent destination dimension for IndexVar",
)
# CHECK: test_inconsistent_source_dim: passed
test_expect_error("inconsistent_source_dim", ("""
test_expect_error(
"inconsistent_source_dim",
(
"""
i = mlir_pytaco.IndexVar()
A = mlir_pytaco.Tensor([3])
B = mlir_pytaco.Tensor([5])
C = mlir_pytaco.Tensor([3], _DENSE)
C[i] = A[i] + B[i]
C.evaluate()
"""), "Inconsistent source dimension for IndexVar")
"""
),
"Inconsistent source dimension for IndexVar",
)
# CHECK: test_index_var_outside_domain: passed
test_expect_error("index_var_outside_domain", ("""
test_expect_error(
"index_var_outside_domain",
(
"""
i, j = mlir_pytaco.get_index_vars(2)
A = mlir_pytaco.Tensor([3])
B = mlir_pytaco.Tensor([3])
B[i] = A[i] + j
B.evaluate()
"""), "IndexVar is not part of the iteration domain")
"""
),
"IndexVar is not part of the iteration domain",
)
# CHECK-LABEL: test_tensor_all_dense_sparse
@testing_utils.run_test
def test_tensor_all_dense_sparse():
a = mlir_pytaco.Tensor([4], [_DENSE])
passed = (not a.is_dense())
passed += (a.order == 1)
passed += (a.shape[0] == 4)
# CHECK: Number of passed: 3
print("Number of passed:", passed)
a = mlir_pytaco.Tensor([4], [_DENSE])
passed = not a.is_dense()
passed += a.order == 1
passed += a.shape[0] == 4
# CHECK: Number of passed: 3
print("Number of passed:", passed)
# CHECK-LABEL: test_tensor_true_dense
@testing_utils.run_test
def test_tensor_true_dense():
a = mlir_pytaco.Tensor.from_array(np.random.uniform(size=5))
passed = a.is_dense()
passed += (a.order == 1)
passed += (a.shape[0] == 5)
# CHECK: Number of passed: 3
print("Number of passed:", passed)
a = mlir_pytaco.Tensor.from_array(np.random.uniform(size=5))
passed = a.is_dense()
passed += a.order == 1
passed += a.shape[0] == 5
# CHECK: Number of passed: 3
print("Number of passed:", passed)
# CHECK-LABEL: test_tensor_copy
@testing_utils.run_test
def test_tensor_copy():
i, j = mlir_pytaco.get_index_vars(2)
I = 2
J = 3
A = mlir_pytaco.Tensor([I, J])
A.insert([0, 1], 5.0)
A.insert([1, 2], 6.0)
B = mlir_pytaco.Tensor([I, J])
B[i, j] = A[i, j]
passed = (B._assignment is not None)
passed += (B._engine is None)
try:
i, j = mlir_pytaco.get_index_vars(2)
I = 2
J = 3
A = mlir_pytaco.Tensor([I, J])
A.insert([0, 1], 5.0)
A.insert([1, 2], 6.0)
B = mlir_pytaco.Tensor([I, J])
B[i, j] = A[i, j]
passed = B._assignment is not None
passed += B._engine is None
try:
B.compute()
except ValueError as e:
passed += str(e).startswith("Need to invoke compile")
B.compile()
passed += B._engine is not None
B.compute()
except ValueError as e:
passed += (str(e).startswith("Need to invoke compile"))
B.compile()
passed += (B._engine is not None)
B.compute()
passed += (B._assignment is None)
passed += (B._engine is None)
indices, values = B.get_coordinates_and_values()
passed += np.array_equal(indices, [[0, 1], [1, 2]])
passed += np.allclose(values, [5.0, 6.0])
# No temporary tensor is used.
passed += (B._stats.get_total() == 0)
# CHECK: Number of passed: 9
print("Number of passed:", passed)
passed += B._assignment is None
passed += B._engine is None
indices, values = B.get_coordinates_and_values()
passed += np.array_equal(indices, [[0, 1], [1, 2]])
passed += np.allclose(values, [5.0, 6.0])
# No temporary tensor is used.
passed += B._stats.get_total() == 0
# CHECK: Number of passed: 9
print("Number of passed:", passed)
# CHECK-LABEL: test_tensor_trivial_reduction
@testing_utils.run_test
def test_tensor_trivial_reduction():
i, j = mlir_pytaco.get_index_vars(2)
I = 2
J = 3
A = mlir_pytaco.Tensor([I, J])
A.insert([0, 1], 5.0)
A.insert([0, 2], 3.0)
A.insert([1, 2], 6.0)
B = mlir_pytaco.Tensor([I])
B[i] = A[i, j]
indices, values = B.get_coordinates_and_values()
passed = np.array_equal(indices, [[0], [1]])
passed += np.allclose(values, [8.0, 6.0])
# No temporary tensor is used.
passed += (B._stats.get_total() == 0)
i, j = mlir_pytaco.get_index_vars(2)
I = 2
J = 3
A = mlir_pytaco.Tensor([I, J])
A.insert([0, 1], 5.0)
A.insert([0, 2], 3.0)
A.insert([1, 2], 6.0)
B = mlir_pytaco.Tensor([I])
B[i] = A[i, j]
indices, values = B.get_coordinates_and_values()
passed = np.array_equal(indices, [[0], [1]])
passed += np.allclose(values, [8.0, 6.0])
# No temporary tensor is used.
passed += B._stats.get_total() == 0
# CHECK: Number of passed: 3
print("Number of passed:", passed)
# CHECK: Number of passed: 3
print("Number of passed:", passed)
# CHECK-LABEL: test_binary_add
@testing_utils.run_test
def test_binary_add():
i = mlir_pytaco.IndexVar()
A = mlir_pytaco.Tensor([4])
B = mlir_pytaco.Tensor([4])
C = mlir_pytaco.Tensor([4])
A.insert([1], 10)
A.insert([2], 1)
B.insert([3], 20)
B.insert([2], 2)
C[i] = A[i] + B[i]
indices, values = C.get_coordinates_and_values()
passed = np.array_equal(indices, [[1], [2], [3]])
passed += np.array_equal(values, [10., 3., 20.])
# No temporary tensor is used.
passed += (C._stats.get_total() == 0)
# CHECK: Number of passed: 3
print("Number of passed:", passed)
i = mlir_pytaco.IndexVar()
A = mlir_pytaco.Tensor([4])
B = mlir_pytaco.Tensor([4])
C = mlir_pytaco.Tensor([4])
A.insert([1], 10)
A.insert([2], 1)
B.insert([3], 20)
B.insert([2], 2)
C[i] = A[i] + B[i]
indices, values = C.get_coordinates_and_values()
passed = np.array_equal(indices, [[1], [2], [3]])
passed += np.array_equal(values, [10.0, 3.0, 20.0])
# No temporary tensor is used.
passed += C._stats.get_total() == 0
# CHECK: Number of passed: 3
print("Number of passed:", passed)
# CHECK-LABEL: test_binary_add_sub
@testing_utils.run_test
def test_binary_add_sub():
i = mlir_pytaco.IndexVar()
j = mlir_pytaco.IndexVar()
A = mlir_pytaco.Tensor([2, 3])
B = mlir_pytaco.Tensor([2, 3])
C = mlir_pytaco.Tensor([2, 3])
D = mlir_pytaco.Tensor([2, 3])
A.insert([0, 1], 10)
A.insert([1, 2], 40)
B.insert([0, 0], 20)
B.insert([1, 2], 30)
C.insert([0, 1], 5)
C.insert([1, 2], 7)
D[i, j] = A[i, j] + B[i, j] - C[i, j]
indices, values = D.get_coordinates_and_values()
passed = np.array_equal(indices, [[0, 0], [0, 1], [1, 2]])
passed += np.array_equal(values, [20., 5., 63.])
# No temporary tensor is used.
passed += (D._stats.get_total() == 0)
# CHECK: Number of passed: 3
print("Number of passed:", passed)
i = mlir_pytaco.IndexVar()
j = mlir_pytaco.IndexVar()
A = mlir_pytaco.Tensor([2, 3])
B = mlir_pytaco.Tensor([2, 3])
C = mlir_pytaco.Tensor([2, 3])
D = mlir_pytaco.Tensor([2, 3])
A.insert([0, 1], 10)
A.insert([1, 2], 40)
B.insert([0, 0], 20)
B.insert([1, 2], 30)
C.insert([0, 1], 5)
C.insert([1, 2], 7)
D[i, j] = A[i, j] + B[i, j] - C[i, j]
indices, values = D.get_coordinates_and_values()
passed = np.array_equal(indices, [[0, 0], [0, 1], [1, 2]])
passed += np.array_equal(values, [20.0, 5.0, 63.0])
# No temporary tensor is used.
passed += D._stats.get_total() == 0
# CHECK: Number of passed: 3
print("Number of passed:", passed)
# CHECK-LABEL: test_binary_mul_add
@testing_utils.run_test
def test_binary_mul_add():
i = mlir_pytaco.IndexVar()
j = mlir_pytaco.IndexVar()
A = mlir_pytaco.Tensor([2, 3])
B = mlir_pytaco.Tensor([2, 3])
C = mlir_pytaco.Tensor([2, 3])
D = mlir_pytaco.Tensor([2, 3])
A.insert([0, 1], 10)
A.insert([1, 2], 40)
B.insert([0, 0], 20)
B.insert([1, 2], 30)
C.insert([0, 1], 5)
C.insert([1, 2], 7)
D[i, j] = A[i, j] * C[i, j] + B[i, j]
indices, values = D.get_coordinates_and_values()
passed = np.array_equal(indices, [[0, 0], [0, 1], [1, 2]])
passed += np.array_equal(values, [20., 50., 310.])
# No temporary tensor is used.
passed += (D._stats.get_total() == 0)
# CHECK: Number of passed: 3
print("Number of passed:", passed)
i = mlir_pytaco.IndexVar()
j = mlir_pytaco.IndexVar()
A = mlir_pytaco.Tensor([2, 3])
B = mlir_pytaco.Tensor([2, 3])
C = mlir_pytaco.Tensor([2, 3])
D = mlir_pytaco.Tensor([2, 3])
A.insert([0, 1], 10)
A.insert([1, 2], 40)
B.insert([0, 0], 20)
B.insert([1, 2], 30)
C.insert([0, 1], 5)
C.insert([1, 2], 7)
D[i, j] = A[i, j] * C[i, j] + B[i, j]
indices, values = D.get_coordinates_and_values()
passed = np.array_equal(indices, [[0, 0], [0, 1], [1, 2]])
passed += np.array_equal(values, [20.0, 50.0, 310.0])
# No temporary tensor is used.
passed += D._stats.get_total() == 0
# CHECK: Number of passed: 3
print("Number of passed:", passed)
# CHECK-LABEL: test_binary_add_reduce_at_root
@testing_utils.run_test
def test_binary_add_reduce_at_root():
i = mlir_pytaco.IndexVar()
j = mlir_pytaco.IndexVar()
A = mlir_pytaco.Tensor([2, 3])
B = mlir_pytaco.Tensor([2, 3])
C = mlir_pytaco.Tensor([2], _DENSE)
A.insert([0, 1], 10)
A.insert([1, 2], 40)
B.insert([0, 0], 20)
B.insert([1, 2], 30)
C[i] = A[i, j] + B[i, j]
indices, values = C.get_coordinates_and_values()
passed = np.array_equal(indices, [[0], [1]])
passed += np.array_equal(values, [30., 70.])
# No temporary tensor is used.
passed += (C._stats.get_total() == 0)
# CHECK: Number of passed: 3
print("Number of passed:", passed)
i = mlir_pytaco.IndexVar()
j = mlir_pytaco.IndexVar()
A = mlir_pytaco.Tensor([2, 3])
B = mlir_pytaco.Tensor([2, 3])
C = mlir_pytaco.Tensor([2], _DENSE)
A.insert([0, 1], 10)
A.insert([1, 2], 40)
B.insert([0, 0], 20)
B.insert([1, 2], 30)
C[i] = A[i, j] + B[i, j]
indices, values = C.get_coordinates_and_values()
passed = np.array_equal(indices, [[0], [1]])
passed += np.array_equal(values, [30.0, 70.0])
# No temporary tensor is used.
passed += C._stats.get_total() == 0
# CHECK: Number of passed: 3
print("Number of passed:", passed)
# CHECK-LABEL: test_binary_add_reduce_at_child
@testing_utils.run_test
def test_binary_add_reduce_at_child():
i = mlir_pytaco.IndexVar()
j = mlir_pytaco.IndexVar()
I = 2
J = 3
A = mlir_pytaco.Tensor([I, J])
B = mlir_pytaco.Tensor([J])
C = mlir_pytaco.Tensor([I])
D = mlir_pytaco.Tensor([I], _DENSE)
i = mlir_pytaco.IndexVar()
j = mlir_pytaco.IndexVar()
I = 2
J = 3
A = mlir_pytaco.Tensor([I, J])
B = mlir_pytaco.Tensor([J])
C = mlir_pytaco.Tensor([I])
D = mlir_pytaco.Tensor([I], _DENSE)
_init_2d(A, I, J)
_init_1d_with_value(C, I, 2)
_init_1d_with_value(B, J, 1)
_init_2d(A, I, J)
_init_1d_with_value(C, I, 2)
_init_1d_with_value(B, J, 1)
D[i] = A[i, j] * B[j] + C[i]
indices, values = D.get_coordinates_and_values()
passed = np.array_equal(indices, [[0], [1]])
passed += np.array_equal(values, [8., 11.])
D[i] = A[i, j] * B[j] + C[i]
indices, values = D.get_coordinates_and_values()
passed = np.array_equal(indices, [[0], [1]])
passed += np.array_equal(values, [8.0, 11.0])
# The expression is implemented as:
# temp0[i] = A[i, j] * B[j]
# D[i] = temp0[i] + C[i]
# Check the temporary tensor introduced by the implementation.
stats = D._stats
passed += (stats.get_total() == 1)
passed += (stats.get_formats(0) == (_COMPRESSED,))
passed += (stats.get_dimensions(0) == (I,))
# CHECK: Number of passed: 5
print("Number of passed:", passed)
# The expression is implemented as:
# temp0[i] = A[i, j] * B[j]
# D[i] = temp0[i] + C[i]
# Check the temporary tensor introduced by the implementation.
stats = D._stats
passed += stats.get_total() == 1
passed += stats.get_formats(0) == (_COMPRESSED,)
passed += stats.get_dimensions(0) == (I,)
# CHECK: Number of passed: 5
print("Number of passed:", passed)
# CHECK-LABEL: test_binary_add_reduce_3d_1
@testing_utils.run_test
def test_binary_add_reduce_3d_1():
i, j, k, l = mlir_pytaco.get_index_vars(4)
I = 2
J = 3
K = 4
L = 5
A = mlir_pytaco.Tensor([I, J, K])
B = mlir_pytaco.Tensor([I, J, L])
C = mlir_pytaco.Tensor([K])
D = mlir_pytaco.Tensor([L])
E = mlir_pytaco.Tensor([I], _DENSE)
i, j, k, l = mlir_pytaco.get_index_vars(4)
I = 2
J = 3
K = 4
L = 5
A = mlir_pytaco.Tensor([I, J, K])
B = mlir_pytaco.Tensor([I, J, L])
C = mlir_pytaco.Tensor([K])
D = mlir_pytaco.Tensor([L])
E = mlir_pytaco.Tensor([I], _DENSE)
_init_3d(A, I, J, K)
_init_3d(B, I, J, L)
_init_1d_with_value(C, K, 1)
_init_1d_with_value(D, L, 2)
_init_3d(A, I, J, K)
_init_3d(B, I, J, L)
_init_1d_with_value(C, K, 1)
_init_1d_with_value(D, L, 2)
E[i] = A[i, j, k] * C[k] + B[i, j, l] * D[l]
indices, values = E.get_coordinates_and_values()
passed = np.array_equal(indices, [[0], [1]])
passed += np.array_equal(values, [162., 204.])
E[i] = A[i, j, k] * C[k] + B[i, j, l] * D[l]
indices, values = E.get_coordinates_and_values()
passed = np.array_equal(indices, [[0], [1]])
passed += np.array_equal(values, [162.0, 204.0])
# The expression is implemented as:
# temp0[i, j] = A[i, j, k] * C[k]
# temp1[i, j] = B[i, j, l] * D[l]
# E[i] = temp0[i, j] + temp1[i, j]
# Check the two temporary tensors introduced by the implementation.
stats = E._stats
passed += (stats.get_total() == 2)
passed += (stats.get_formats(0) == (_COMPRESSED, _COMPRESSED))
passed += (stats.get_dimensions(0) == (I, J))
passed += (stats.get_formats(1) == (_COMPRESSED, _COMPRESSED))
passed += (stats.get_dimensions(1) == (I, J))
# CHECK: Number of passed: 7
print("Number of passed:", passed)
# The expression is implemented as:
# temp0[i, j] = A[i, j, k] * C[k]
# temp1[i, j] = B[i, j, l] * D[l]
# E[i] = temp0[i, j] + temp1[i, j]
# Check the two temporary tensors introduced by the implementation.
stats = E._stats
passed += stats.get_total() == 2
passed += stats.get_formats(0) == (_COMPRESSED, _COMPRESSED)
passed += stats.get_dimensions(0) == (I, J)
passed += stats.get_formats(1) == (_COMPRESSED, _COMPRESSED)
passed += stats.get_dimensions(1) == (I, J)
# CHECK: Number of passed: 7
print("Number of passed:", passed)
# CHECK-LABEL: test_binary_add_reduce_3d_2
@testing_utils.run_test
def test_binary_add_reduce_3d_2():
i, j, k, l = mlir_pytaco.get_index_vars(4)
I = 2
J = 3
K = 4
L = 5
A = mlir_pytaco.Tensor([I, J, K], [_COMPRESSED, _COMPRESSED, _DENSE])
B = mlir_pytaco.Tensor([I, L, K], [_DENSE, _COMPRESSED, _COMPRESSED])
C = mlir_pytaco.Tensor([J, K], [_COMPRESSED, _COMPRESSED])
D = mlir_pytaco.Tensor([L])
E = mlir_pytaco.Tensor([I], _DENSE)
i, j, k, l = mlir_pytaco.get_index_vars(4)
I = 2
J = 3
K = 4
L = 5
A = mlir_pytaco.Tensor([I, J, K], [_COMPRESSED, _COMPRESSED, _DENSE])
B = mlir_pytaco.Tensor([I, L, K], [_DENSE, _COMPRESSED, _COMPRESSED])
C = mlir_pytaco.Tensor([J, K], [_COMPRESSED, _COMPRESSED])
D = mlir_pytaco.Tensor([L])
E = mlir_pytaco.Tensor([I], _DENSE)
_init_3d(A, I, J, K)
_init_3d(B, I, L, K)
_init_2d(C, J, K)
_init_1d_with_value(D, L, 2)
_init_3d(A, I, J, K)
_init_3d(B, I, L, K)
_init_2d(C, J, K)
_init_1d_with_value(D, L, 2)
E[i] = A[i, j, k] + C[j, k] + B[i, l, k] * D[l]
indices, values = E.get_coordinates_and_values()
passed = np.array_equal(indices, [[0], [1]])
passed += np.array_equal(values, [264., 316.])
E[i] = A[i, j, k] + C[j, k] + B[i, l, k] * D[l]
indices, values = E.get_coordinates_and_values()
passed = np.array_equal(indices, [[0], [1]])
passed += np.array_equal(values, [264.0, 316.0])
# The expression is implemented as:
# temp0[i, k] = A[i, j, k] + C[j, k]
# temp1[i, k] = B[i, l, k] * D[l]
# E[i] = temp0[i, k] + temp1[i, k]
# Check the two temporary tensors introduced by the implementation.
stats = E._stats
passed += (stats.get_total() == 2)
passed += (stats.get_formats(0) == (_COMPRESSED, _DENSE))
passed += (stats.get_dimensions(0) == (I, K))
passed += (stats.get_formats(1) == (_DENSE, _COMPRESSED))
passed += (stats.get_dimensions(1) == (I, K))
# CHECK: Number of passed: 7
print("Number of passed:", passed)
# The expression is implemented as:
# temp0[i, k] = A[i, j, k] + C[j, k]
# temp1[i, k] = B[i, l, k] * D[l]
# E[i] = temp0[i, k] + temp1[i, k]
# Check the two temporary tensors introduced by the implementation.
stats = E._stats
passed += stats.get_total() == 2
passed += stats.get_formats(0) == (_COMPRESSED, _DENSE)
passed += stats.get_dimensions(0) == (I, K)
passed += stats.get_formats(1) == (_DENSE, _COMPRESSED)
passed += stats.get_dimensions(1) == (I, K)
# CHECK: Number of passed: 7
print("Number of passed:", passed)

View File

@@ -32,21 +32,21 @@ _MTX_DATA = """%%MatrixMarket matrix coordinate real general
# CHECK-LABEL: test_read_mtx_matrix_general
@testing_utils.run_test
def test_read_mtx_matrix_general():
with tempfile.TemporaryDirectory() as test_dir:
file_name = os.path.join(test_dir, "data.mtx")
with open(file_name, "w") as file:
file.write(_MTX_DATA)
a = mlir_pytaco_io.read(file_name, _FORMAT)
passed = 0
# The value of a is stored as an MLIR sparse tensor.
passed += (not a.is_unpacked())
a.unpack()
passed += (a.is_unpacked())
coords, values = a.get_coordinates_and_values()
passed += np.array_equal(coords, [[0, 1], [2, 0], [2, 1]])
passed += np.allclose(values, [2.0, 3.0, 4.0])
# CHECK: 4
print(passed)
with tempfile.TemporaryDirectory() as test_dir:
file_name = os.path.join(test_dir, "data.mtx")
with open(file_name, "w") as file:
file.write(_MTX_DATA)
a = mlir_pytaco_io.read(file_name, _FORMAT)
passed = 0
# The value of a is stored as an MLIR sparse tensor.
passed += not a.is_unpacked()
a.unpack()
passed += a.is_unpacked()
coords, values = a.get_coordinates_and_values()
passed += np.array_equal(coords, [[0, 1], [2, 0], [2, 1]])
passed += np.allclose(values, [2.0, 3.0, 4.0])
# CHECK: 4
print(passed)
_TNS_DATA = """2 3
@@ -60,57 +60,57 @@ _TNS_DATA = """2 3
# CHECK-LABEL: test_read_tns
@testing_utils.run_test
def test_read_tns():
with tempfile.TemporaryDirectory() as test_dir:
file_name = os.path.join(test_dir, "data.tns")
with open(file_name, "w") as file:
file.write(_TNS_DATA)
a = mlir_pytaco_io.read(file_name, _FORMAT)
passed = 0
# The value of a is stored as an MLIR sparse tensor.
passed += (not a.is_unpacked())
a.unpack()
passed += (a.is_unpacked())
coords, values = a.get_coordinates_and_values()
passed += np.array_equal(coords, [[0, 1], [2, 0], [2, 1]])
passed += np.allclose(values, [2.0, 3.0, 4.0])
# CHECK: 4
print(passed)
with tempfile.TemporaryDirectory() as test_dir:
file_name = os.path.join(test_dir, "data.tns")
with open(file_name, "w") as file:
file.write(_TNS_DATA)
a = mlir_pytaco_io.read(file_name, _FORMAT)
passed = 0
# The value of a is stored as an MLIR sparse tensor.
passed += not a.is_unpacked()
a.unpack()
passed += a.is_unpacked()
coords, values = a.get_coordinates_and_values()
passed += np.array_equal(coords, [[0, 1], [2, 0], [2, 1]])
passed += np.allclose(values, [2.0, 3.0, 4.0])
# CHECK: 4
print(passed)
# CHECK-LABEL: test_write_unpacked_tns
@testing_utils.run_test
def test_write_unpacked_tns():
a = mlir_pytaco.Tensor([2, 3])
a.insert([0, 1], 10)
a.insert([1, 2], 40)
a.insert([0, 0], 20)
with tempfile.TemporaryDirectory() as test_dir:
file_name = os.path.join(test_dir, "data.tns")
try:
mlir_pytaco_io.write(file_name, a)
except ValueError as e:
# CHECK: Writing unpacked sparse tensors to file is not supported
print(e)
a = mlir_pytaco.Tensor([2, 3])
a.insert([0, 1], 10)
a.insert([1, 2], 40)
a.insert([0, 0], 20)
with tempfile.TemporaryDirectory() as test_dir:
file_name = os.path.join(test_dir, "data.tns")
try:
mlir_pytaco_io.write(file_name, a)
except ValueError as e:
# CHECK: Writing unpacked sparse tensors to file is not supported
print(e)
# CHECK-LABEL: test_write_packed_tns
@testing_utils.run_test
def test_write_packed_tns():
a = mlir_pytaco.Tensor([2, 3])
a.insert([0, 1], 10)
a.insert([1, 2], 40)
a.insert([0, 0], 20)
b = mlir_pytaco.Tensor([2, 3])
i, j = mlir_pytaco.get_index_vars(2)
b[i, j] = a[i, j] + a[i, j]
with tempfile.TemporaryDirectory() as test_dir:
file_name = os.path.join(test_dir, "data.tns")
mlir_pytaco_io.write(file_name, b)
with open(file_name, "r") as file:
lines = file.readlines()
passed = 0
# Skip the comment line in the output.
if lines[1:] == ["2 3\n", "2 3\n", "1 1 40\n", "1 2 20\n", "2 3 80\n"]:
passed = 1
# CHECK: 1
print(passed)
a = mlir_pytaco.Tensor([2, 3])
a.insert([0, 1], 10)
a.insert([1, 2], 40)
a.insert([0, 0], 20)
b = mlir_pytaco.Tensor([2, 3])
i, j = mlir_pytaco.get_index_vars(2)
b[i, j] = a[i, j] + a[i, j]
with tempfile.TemporaryDirectory() as test_dir:
file_name = os.path.join(test_dir, "data.tns")
mlir_pytaco_io.write(file_name, b)
with open(file_name, "r") as file:
lines = file.readlines()
passed = 0
# Skip the comment line in the output.
if lines[1:] == ["2 3\n", "2 3\n", "1 1 40\n", "1 2 20\n", "2 3 80\n"]:
passed = 1
# CHECK: 1
print(passed)
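For reference, a short breakdown of the expected file content checked above; the values follow directly from b[i, j] = a[i, j] + a[i, j] and the 1-based TNS coordinates.

expected = ["2 3\n", "2 3\n", "1 1 40\n", "1 2 20\n", "2 3 80\n"]
# First "2 3": rank and number of stored entries; second "2 3": the shape.
# Each remaining line is a 1-based "row col value" triple, with every value
# doubled by the addition: 20 -> 40, 10 -> 20, 40 -> 80.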

View File

@@ -20,79 +20,93 @@ _DENSE = mlir_pytaco.ModeFormat.DENSE
def _to_string(s: Sequence[int]) -> str:
"""Converts a sequence of integer to a space separated value string."""
return " ".join(map(lambda e: str(e), s))
"""Converts a sequence of integer to a space separated value string."""
return " ".join(map(lambda e: str(e), s))
def _add_one(s: Sequence[int]) -> Sequence[int]:
"""Adds one to each element in the sequence of integer."""
return [i + 1 for i in s]
"""Adds one to each element in the sequence of integer."""
return [i + 1 for i in s]
@dataclasses.dataclass(frozen=True)
class _SparseTensorCOO:
"""Values for a COO-flavored format sparse tensor.
"""Values for a COO-flavored format sparse tensor.
Attributes:
rank: An integer rank for the tensor.
nse: An integer for the number of non-zero values.
shape: A sequence of integer for the dimension size.
values: A sequence of float for the non-zero values of the tensor.
indices: A sequence of coordinate, each coordinate is a sequence of integer.
"""
rank: int
nse: int
shape: Sequence[int]
values: Sequence[float]
indices: Sequence[Sequence[int]]
Attributes:
rank: An integer rank for the tensor.
nse: An integer for the number of non-zero values.
shape: A sequence of integer for the dimension size.
values: A sequence of float for the non-zero values of the tensor.
indices: A sequence of coordinate, each coordinate is a sequence of integer.
"""
rank: int
nse: int
shape: Sequence[int]
values: Sequence[float]
indices: Sequence[Sequence[int]]
def _coo_values_to_tns_format(t: _SparseTensorCOO) -> str:
"""Converts a sparse tensor COO-flavored values to TNS text format."""
# The coo_value_str contains one line for each (coordinate value) pair.
# Indices are 1-based in TNS text format but 0-based in MLIR.
coo_value_str = "\n".join(
map(lambda i: _to_string(_add_one(t.indices[i])) + " " + str(t.values[i]),
range(t.nse)))
"""Converts a sparse tensor COO-flavored values to TNS text format."""
# The coo_value_str contains one line for each (coordinate value) pair.
# Indices are 1-based in TNS text format but 0-based in MLIR.
coo_value_str = "\n".join(
map(
lambda i: _to_string(_add_one(t.indices[i])) + " " + str(t.values[i]),
range(t.nse),
)
)
# Returns the TNS text format representation for the tensor.
return f"""{t.rank} {t.nse}
# Returns the TNS text format representation for the tensor.
return f"""{t.rank} {t.nse}
{_to_string(t.shape)}
{coo_value_str}
"""
def _implement_read_tns_test(
t: _SparseTensorCOO,
sparsity_codes: Sequence[sparse_tensor.DimLevelType]) -> int:
tns_data = _coo_values_to_tns_format(t)
t: _SparseTensorCOO, sparsity_codes: Sequence[sparse_tensor.DimLevelType]
) -> int:
tns_data = _coo_values_to_tns_format(t)
# Write sparse tensor data to a file.
with tempfile.TemporaryDirectory() as test_dir:
file_name = os.path.join(test_dir, "data.tns")
with open(file_name, "w") as file:
file.write(tns_data)
# Write sparse tensor data to a file.
with tempfile.TemporaryDirectory() as test_dir:
file_name = os.path.join(test_dir, "data.tns")
with open(file_name, "w") as file:
file.write(tns_data)
# Read the data from the file and construct an MLIR sparse tensor.
sparse_tensor, o_shape = pytaco_utils.create_sparse_tensor(
file_name, sparsity_codes, "f64")
# Read the data from the file and construct an MLIR sparse tensor.
sparse_tensor, o_shape = pytaco_utils.create_sparse_tensor(
file_name, sparsity_codes, "f64"
)
passed = 0
passed = 0
# Verify the output shape for the tensor.
if np.array_equal(o_shape, t.shape):
passed += 1
# Verify the output shape for the tensor.
if np.array_equal(o_shape, t.shape):
passed += 1
# Use the output MLIR sparse tensor pointer to retrieve the COO-flavored
# values and verify the values.
o_rank, o_nse, o_shape, o_values, o_indices = (
pytaco_utils.sparse_tensor_to_coo_tensor(sparse_tensor, np.float64))
if o_rank == t.rank and o_nse == t.nse and np.array_equal(
o_shape, t.shape) and np.allclose(o_values, t.values) and np.array_equal(
o_indices, t.indices):
passed += 1
# Use the output MLIR sparse tensor pointer to retrieve the COO-flavored
# values and verify the values.
(
o_rank,
o_nse,
o_shape,
o_values,
o_indices,
) = pytaco_utils.sparse_tensor_to_coo_tensor(sparse_tensor, np.float64)
if (
o_rank == t.rank
and o_nse == t.nse
and np.array_equal(o_shape, t.shape)
and np.allclose(o_values, t.values)
and np.array_equal(o_indices, t.indices)
):
passed += 1
return passed
return passed
# A 2D sparse tensor data in COO-flavored format.

View File

@@ -5,11 +5,11 @@ if not config.mlir_run_amx_tests:
config.unsupported = True
# No JIT on win32.
if sys.platform == 'win32':
if sys.platform == "win32":
config.unsupported = True
if config.intel_sde_executable:
# Run test in emulator (Intel SDE): AMX needs Sapphire Rapids CPU.
config.substitutions.append(('%lli', config.intel_sde_executable + ' -spr -- lli'))
config.substitutions.append(("%lli", config.intel_sde_executable + " -spr -- lli"))
else:
config.substitutions.append(('%lli', 'lli'))
config.substitutions.append(("%lli", "lli"))

View File

@@ -5,5 +5,5 @@ if not config.mlir_run_arm_sme_tests:
config.unsupported = True
# No JIT on win32.
if sys.platform == 'win32':
if sys.platform == "win32":
config.unsupported = True

View File

@@ -5,5 +5,5 @@ if not config.mlir_run_arm_sve_tests:
config.unsupported = True
# No JIT on win32.
if sys.platform == 'win32':
if sys.platform == "win32":
config.unsupported = True

View File

@@ -5,11 +5,11 @@ if not config.mlir_run_x86vector_tests:
config.unsupported = True
# No JIT on win32.
if sys.platform == 'win32':
if sys.platform == "win32":
config.unsupported = True
if config.intel_sde_executable:
# Run test in emulator (Intel SDE).
config.substitutions.append(('%lli', config.intel_sde_executable + ' -tgl -- lli'))
config.substitutions.append(("%lli", config.intel_sde_executable + " -tgl -- lli"))
else:
config.substitutions.append(('%lli', 'lli'))
config.substitutions.append(("%lli", "lli"))

View File

@@ -1,2 +1,2 @@
if not config.enable_cuda_runner:
config.unsupported = True
config.unsupported = True

View File

@@ -2,4 +2,4 @@ import sys
# TensorCore tests must be enabled via build flag.
if not config.mlir_run_cuda_tensor_core_tests:
config.unsupported = True
config.unsupported = True

View File

@@ -1,2 +1,2 @@
if not config.enable_cuda_runner:
config.unsupported = True
config.unsupported = True

View File

@@ -1,4 +1,4 @@
if not config.enable_rocm_runner or not config.rocm_test_chipset:
config.unsupported = True
config.unsupported = True
config.substitutions.append(('%chip', config.rocm_test_chipset))
config.substitutions.append(("%chip", config.rocm_test_chipset))

View File

@@ -3,8 +3,9 @@ from lit.llvm import llvm_config
if not config.mlir_include_integration_tests:
config.unsupported = True
def configure_aarch64_lli_cmd():
lli_cmd = 'lli'
lli_cmd = "lli"
# NOTE: If the SVE tests are disabled and the SME tests are enabled to run
# under emulation, the SVE specific RUN lines in the SparseTensor tests
@@ -12,8 +13,12 @@ def configure_aarch64_lli_cmd():
if not (config.mlir_run_arm_sve_tests or config.mlir_run_arm_sme_tests):
return lli_cmd
config.substitutions.append(('%mlir_native_utils_lib_dir',
config.arm_emulator_utils_lib_dir or config.mlir_lib_dir))
config.substitutions.append(
(
"%mlir_native_utils_lib_dir",
config.arm_emulator_utils_lib_dir or config.mlir_lib_dir,
)
)
if config.arm_emulator_executable:
if config.arm_emulator_lli_executable:
@@ -23,16 +28,22 @@ def configure_aarch64_lli_cmd():
# when running under an emulator. If the user didn't specify an lli
# executable, use absolute path %llvm_tools_dir/lli.
lli_cmd = llvm_config.use_llvm_tool(
'lli', search_env='LLI', required=True,
search_paths=[config.llvm_tools_dir], use_installed=False
"lli",
search_env="LLI",
required=True,
search_paths=[config.llvm_tools_dir],
use_installed=False,
)
# Run test in emulator (qemu or armie)
emulation_cmd = f'{config.arm_emulator_executable} {config.arm_emulator_options}'
lli_cmd = f'{emulation_cmd} {lli_cmd}'
emulation_cmd = (
f"{config.arm_emulator_executable} {config.arm_emulator_options}"
)
lli_cmd = f"{emulation_cmd} {lli_cmd}"
return lli_cmd
aarch64_lli_cmd = configure_aarch64_lli_cmd()
# Configure the following AArch64 substitutions:
@@ -52,5 +63,5 @@ aarch64_lli_cmd = configure_aarch64_lli_cmd()
# could be used in the SparseTensor tests where necessary, but the meaning
# conveyed by the substitution name would be a misnomer if the host target
# is not AArch64 and MLIR_RUN_ARM_SVE_TESTS=OFF.
config.substitutions.append(('%lli_aarch64_cmd', aarch64_lli_cmd))
config.substitutions.append(('%lli_host_or_aarch64_cmd', aarch64_lli_cmd))
config.substitutions.append(("%lli_aarch64_cmd", aarch64_lli_cmd))
config.substitutions.append(("%lli_host_or_aarch64_cmd", aarch64_lli_cmd))

View File

@@ -8,43 +8,43 @@ import subprocess
import lit.formats
# name: The name of this test suite.
config.name = 'MLIR-Unit'
config.name = "MLIR-Unit"
# suffixes: A list of file extensions to treat as test files.
config.suffixes = []
# test_source_root: The root path where tests are located.
# test_exec_root: The root path where tests should be run.
config.test_exec_root = os.path.join(config.mlir_obj_root, 'unittests')
config.test_exec_root = os.path.join(config.mlir_obj_root, "unittests")
config.test_source_root = config.test_exec_root
# testFormat: The test format to use to interpret tests.
config.test_format = lit.formats.GoogleTest(config.llvm_build_mode, 'Tests')
config.test_format = lit.formats.GoogleTest(config.llvm_build_mode, "Tests")
# Propagate the temp directory. Windows requires this because it uses \Windows\
# if none of these are present.
if 'TMP' in os.environ:
config.environment['TMP'] = os.environ['TMP']
if 'TEMP' in os.environ:
config.environment['TEMP'] = os.environ['TEMP']
if "TMP" in os.environ:
config.environment["TMP"] = os.environ["TMP"]
if "TEMP" in os.environ:
config.environment["TEMP"] = os.environ["TEMP"]
# Propagate HOME as it can be used to override incorrect homedir in passwd
# that causes the tests to fail.
if 'HOME' in os.environ:
config.environment['HOME'] = os.environ['HOME']
if "HOME" in os.environ:
config.environment["HOME"] = os.environ["HOME"]
# Propagate sanitizer options.
for var in [
'ASAN_SYMBOLIZER_PATH',
'HWASAN_SYMBOLIZER_PATH',
'MSAN_SYMBOLIZER_PATH',
'TSAN_SYMBOLIZER_PATH',
'UBSAN_SYMBOLIZER_PATH',
'ASAN_OPTIONS',
'HWASAN_OPTIONS',
'MSAN_OPTIONS',
'TSAN_OPTIONS',
'UBSAN_OPTIONS',
"ASAN_SYMBOLIZER_PATH",
"HWASAN_SYMBOLIZER_PATH",
"MSAN_SYMBOLIZER_PATH",
"TSAN_SYMBOLIZER_PATH",
"UBSAN_SYMBOLIZER_PATH",
"ASAN_OPTIONS",
"HWASAN_OPTIONS",
"MSAN_OPTIONS",
"TSAN_OPTIONS",
"UBSAN_OPTIONS",
]:
if var in os.environ:
config.environment[var] = os.environ[var]

View File

@@ -1 +1 @@
config.suffixes.remove('.td')
config.suffixes.remove(".td")

View File

@@ -1 +1 @@
config.suffixes.remove('.td')
config.suffixes.remove(".td")

View File

@@ -1 +1 @@
config.suffixes.remove('.pdll')
config.suffixes.remove(".pdll")

View File

@@ -1 +1 @@
config.suffixes.remove('.pdll')
config.suffixes.remove(".pdll")

View File

@@ -16,21 +16,32 @@ from lit.llvm.subst import FindTool
# Configuration file for the 'lit' test runner.
# name: The name of this test suite.
config.name = 'MLIR'
config.name = "MLIR"
config.test_format = lit.formats.ShTest(not llvm_config.use_lit_shell)
# suffixes: A list of file extensions to treat as test files.
config.suffixes = ['.td', '.mlir', '.toy', '.ll', '.tc', '.py', '.yaml', '.test', '.pdll', '.c']
config.suffixes = [
".td",
".mlir",
".toy",
".ll",
".tc",
".py",
".yaml",
".test",
".pdll",
".c",
]
# test_source_root: The root path where tests are located.
config.test_source_root = os.path.dirname(__file__)
# test_exec_root: The root path where tests should be run.
config.test_exec_root = os.path.join(config.mlir_obj_root, 'test')
config.test_exec_root = os.path.join(config.mlir_obj_root, "test")
config.substitutions.append(('%PATH%', config.environment['PATH']))
config.substitutions.append(('%shlibext', config.llvm_shlib_ext))
config.substitutions.append(("%PATH%", config.environment["PATH"]))
config.substitutions.append(("%shlibext", config.llvm_shlib_ext))
config.substitutions.append(("%mlir_src_root", config.mlir_src_root))
config.substitutions.append(("%host_cxx", config.host_cxx))
config.substitutions.append(("%host_cc", config.host_cc))
@@ -40,94 +51,109 @@ config.substitutions.append(("%host_cc", config.host_cc))
# substitution of the same name and the found path.
# Correctly handles the platforms shared library directory and naming conventions.
def add_runtime(name):
path = ''
for prefix in ['', 'lib']:
path = os.path.join(config.llvm_shlib_dir, f'{prefix}{name}{config.llvm_shlib_ext}')
path = ""
for prefix in ["", "lib"]:
path = os.path.join(
config.llvm_shlib_dir, f"{prefix}{name}{config.llvm_shlib_ext}"
)
if os.path.isfile(path):
break
return ToolSubst(f'%{name}', path)
return ToolSubst(f"%{name}", path)
llvm_config.with_system_environment(
['HOME', 'INCLUDE', 'LIB', 'TMP', 'TEMP'])
llvm_config.with_system_environment(["HOME", "INCLUDE", "LIB", "TMP", "TEMP"])
llvm_config.use_default_substitutions()
# excludes: A list of directories to exclude from the testsuite. The 'Inputs'
# subdirectories contain auxiliary inputs for various tests in their parent
# directories.
config.excludes = ['Inputs', 'CMakeLists.txt', 'README.txt', 'LICENSE.txt',
'lit.cfg.py', 'lit.site.cfg.py']
config.excludes = [
"Inputs",
"CMakeLists.txt",
"README.txt",
"LICENSE.txt",
"lit.cfg.py",
"lit.site.cfg.py",
]
# Tweak the PATH to include the tools dir.
llvm_config.with_environment('PATH', config.mlir_tools_dir, append_path=True)
llvm_config.with_environment('PATH', config.llvm_tools_dir, append_path=True)
llvm_config.with_environment("PATH", config.mlir_tools_dir, append_path=True)
llvm_config.with_environment("PATH", config.llvm_tools_dir, append_path=True)
tool_dirs = [config.mlir_tools_dir, config.llvm_tools_dir]
tools = [
'mlir-tblgen',
'mlir-translate',
'mlir-lsp-server',
'mlir-capi-execution-engine-test',
'mlir-capi-ir-test',
'mlir-capi-llvm-test',
'mlir-capi-pass-test',
'mlir-capi-pdl-test',
'mlir-capi-quant-test',
'mlir-capi-sparse-tensor-test',
'mlir-capi-transform-test',
'mlir-cpu-runner',
add_runtime('mlir_runner_utils'),
add_runtime('mlir_c_runner_utils'),
add_runtime('mlir_async_runtime'),
'mlir-linalg-ods-yaml-gen',
'mlir-reduce',
'mlir-pdll',
'not',
"mlir-tblgen",
"mlir-translate",
"mlir-lsp-server",
"mlir-capi-execution-engine-test",
"mlir-capi-ir-test",
"mlir-capi-llvm-test",
"mlir-capi-pass-test",
"mlir-capi-pdl-test",
"mlir-capi-quant-test",
"mlir-capi-sparse-tensor-test",
"mlir-capi-transform-test",
"mlir-cpu-runner",
add_runtime("mlir_runner_utils"),
add_runtime("mlir_c_runner_utils"),
add_runtime("mlir_async_runtime"),
"mlir-linalg-ods-yaml-gen",
"mlir-reduce",
"mlir-pdll",
"not",
]
if config.enable_spirv_cpu_runner:
tools.extend(['mlir-spirv-cpu-runner', add_runtime('mlir_test_spirv_cpu_runner_c_wrappers')])
tools.extend(
["mlir-spirv-cpu-runner", add_runtime("mlir_test_spirv_cpu_runner_c_wrappers")]
)
if config.enable_vulkan_runner:
tools.extend([add_runtime('vulkan-runtime-wrappers')])
tools.extend([add_runtime("vulkan-runtime-wrappers")])
if config.enable_rocm_runner:
tools.extend([add_runtime('mlir_rocm_runtime')])
tools.extend([add_runtime("mlir_rocm_runtime")])
if config.enable_cuda_runner:
tools.extend([add_runtime('mlir_cuda_runtime')])
tools.extend([add_runtime("mlir_cuda_runtime")])
# The following tools are optional
tools.extend([
ToolSubst('toyc-ch1', unresolved='ignore'),
ToolSubst('toyc-ch2', unresolved='ignore'),
ToolSubst('toyc-ch3', unresolved='ignore'),
ToolSubst('toyc-ch4', unresolved='ignore'),
ToolSubst('toyc-ch5', unresolved='ignore'),
ToolSubst('toyc-ch6', unresolved='ignore'),
ToolSubst('toyc-ch7', unresolved='ignore'),
ToolSubst('%mlir_lib_dir', config.mlir_lib_dir, unresolved='ignore'),
ToolSubst('%mlir_src_dir', config.mlir_src_root, unresolved='ignore'),
])
tools.extend(
[
ToolSubst("toyc-ch1", unresolved="ignore"),
ToolSubst("toyc-ch2", unresolved="ignore"),
ToolSubst("toyc-ch3", unresolved="ignore"),
ToolSubst("toyc-ch4", unresolved="ignore"),
ToolSubst("toyc-ch5", unresolved="ignore"),
ToolSubst("toyc-ch6", unresolved="ignore"),
ToolSubst("toyc-ch7", unresolved="ignore"),
ToolSubst("%mlir_lib_dir", config.mlir_lib_dir, unresolved="ignore"),
ToolSubst("%mlir_src_dir", config.mlir_src_root, unresolved="ignore"),
]
)
python_executable = config.python_executable
# Python configuration with sanitizer requires some magic preloading. This will only work on clang/linux.
# TODO: detect Darwin/Windows situation (or mark these tests as unsupported on these platforms).
if "asan" in config.available_features and "Linux" in config.host_os:
python_executable = f"LD_PRELOAD=$({config.host_cxx} -print-file-name=libclang_rt.asan-{config.host_arch}.so) {config.python_executable}"
python_executable = f"LD_PRELOAD=$({config.host_cxx} -print-file-name=libclang_rt.asan-{config.host_arch}.so) {config.python_executable}"
# On Windows the path to python could contains spaces in which case it needs to be provided in quotes.
# This is the equivalent of how %python is setup in llvm/utils/lit/lit/llvm/config.py.
elif "Windows" in config.host_os:
python_executable = '"%s"' % (python_executable)
tools.extend([
ToolSubst('%PYTHON', python_executable, unresolved='ignore'),
])
python_executable = '"%s"' % (python_executable)
tools.extend(
[
ToolSubst("%PYTHON", python_executable, unresolved="ignore"),
]
)
if "MLIR_OPT_CHECK_IR_ROUNDTRIP" in os.environ:
tools.extend([
ToolSubst('mlir-opt', 'mlir-opt --verify-roundtrip', unresolved='fatal'),
])
tools.extend(
[
ToolSubst("mlir-opt", "mlir-opt --verify-roundtrip", unresolved="fatal"),
]
)
llvm_config.add_tool_substitutions(tools, tool_dirs)
@@ -135,40 +161,48 @@ llvm_config.add_tool_substitutions(tools, tool_dirs)
# FileCheck -enable-var-scope is enabled by default in MLIR test
# This option avoids to accidentally reuse variable across -LABEL match,
# it can be explicitly opted-in by prefixing the variable name with $
config.environment['FILECHECK_OPTS'] = "-enable-var-scope --allow-unused-prefixes=false"
config.environment["FILECHECK_OPTS"] = "-enable-var-scope --allow-unused-prefixes=false"
# Add the python path for both the source and binary tree.
# Note that presently, the python sources come from the source tree and the
# binaries come from the build tree. This should be unified to the build tree
# by copying/linking sources to build.
if config.enable_bindings_python:
llvm_config.with_environment('PYTHONPATH', [
os.path.join(config.mlir_obj_root, 'python_packages', 'mlir_core'),
os.path.join(config.mlir_obj_root, 'python_packages', 'mlir_test'),
], append_path=True)
llvm_config.with_environment(
"PYTHONPATH",
[
os.path.join(config.mlir_obj_root, "python_packages", "mlir_core"),
os.path.join(config.mlir_obj_root, "python_packages", "mlir_test"),
],
append_path=True,
)
if config.enable_assertions:
config.available_features.add('asserts')
config.available_features.add("asserts")
else:
config.available_features.add('noasserts')
config.available_features.add("noasserts")
def have_host_jit_feature_support(feature_name):
mlir_cpu_runner_exe = lit.util.which('mlir-cpu-runner', config.mlir_tools_dir)
mlir_cpu_runner_exe = lit.util.which("mlir-cpu-runner", config.mlir_tools_dir)
if not mlir_cpu_runner_exe:
return False
if not mlir_cpu_runner_exe:
return False
try:
mlir_cpu_runner_cmd = subprocess.Popen(
[mlir_cpu_runner_exe, '--host-supports-' + feature_name], stdout=subprocess.PIPE)
except OSError:
print('could not exec mlir-cpu-runner')
return False
try:
mlir_cpu_runner_cmd = subprocess.Popen(
[mlir_cpu_runner_exe, "--host-supports-" + feature_name],
stdout=subprocess.PIPE,
)
except OSError:
print("could not exec mlir-cpu-runner")
return False
mlir_cpu_runner_out = mlir_cpu_runner_cmd.stdout.read().decode('ascii')
mlir_cpu_runner_cmd.wait()
mlir_cpu_runner_out = mlir_cpu_runner_cmd.stdout.read().decode("ascii")
mlir_cpu_runner_cmd.wait()
return 'true' in mlir_cpu_runner_out
return "true" in mlir_cpu_runner_out
if have_host_jit_feature_support('jit'):
config.available_features.add('host-supports-jit')
if have_host_jit_feature_support("jit"):
config.available_features.add("host-supports-jit")

View File

@@ -1,12 +1,11 @@
import sys
# MSAN does not work with JIT.
if 'msan' in config.available_features:
config.unsupported = True
# Requires native execution.
if 'host-supports-jit' not in config.available_features:
if "msan" in config.available_features:
config.unsupported = True
config.available_features.add(
config.root.native_target.lower() + '-native-target')
# Requires native execution.
if "host-supports-jit" not in config.available_features:
config.unsupported = True
config.available_features.add(config.root.native_target.lower() + "-native-target")

View File

@@ -1 +1 @@
config.excludes = ['include']
config.excludes = ["include"]

View File

@@ -1,2 +1,2 @@
config.suffixes = ['.pdll', '.mlir']
config.excludes = ['include']
config.suffixes = [".pdll", ".mlir"]
config.excludes = ["include"]

View File

@@ -1,4 +1,4 @@
import sys
if not config.enable_spirv_cpu_runner:
config.unsupported = True
config.unsupported = True

View File

@@ -1,2 +1,2 @@
if not config.enable_vulkan_runner:
config.unsupported = True
config.unsupported = True

View File

@@ -14,5 +14,6 @@ expected_lib_name = "MLIRPythonCAPI"
all_libs = os.listdir(get_lib_dirs()[0])
found_lib = False
for file_name in all_libs:
if expected_lib_name in file_name: found_lib = True
if expected_lib_name in file_name:
found_lib = True
assert found_lib, f"Did not find '{expected_lib_name}' lib in {all_libs}"

View File

@@ -4,16 +4,18 @@ from mlir.ir import *
import mlir.dialects.func as func
import mlir.dialects.arith as arith
def run(f):
print("\nTEST:", f.__name__)
f()
print("\nTEST:", f.__name__)
f()
# CHECK-LABEL: TEST: testConstantOp
@run
def testConstantOps():
with Context() as ctx, Location.unknown():
module = Module.create()
with InsertionPoint(module.body):
arith.ConstantOp(value=42.42, result=F32Type.get())
# CHECK: %cst = arith.constant 4.242000e+01 : f32
print(module)
with Context() as ctx, Location.unknown():
module = Module.create()
with InsertionPoint(module.body):
arith.ConstantOp(value=42.42, result=F32Type.get())
# CHECK: %cst = arith.constant 4.242000e+01 : f32
print(module)

View File

@@ -5,14 +5,17 @@ import mlir.dialects.async_dialect
import mlir.dialects.async_dialect.passes
from mlir.passmanager import *
def run(f):
print("\nTEST:", f.__name__)
f()
print("\nTEST:", f.__name__)
f()
def testAsyncPass():
with Context() as context:
PassManager.parse('any(async-to-async-runtime)')
print('SUCCESS')
with Context() as context:
PassManager.parse("any(async-to-async-runtime)")
print("SUCCESS")
# CHECK-LABEL: testAsyncPass
# CHECK: SUCCESS

View File

@@ -7,232 +7,242 @@ import numpy as np
def run(f):
print("\nTEST:", f.__name__)
f()
return f
print("\nTEST:", f.__name__)
f()
return f
# CHECK-LABEL: TEST: testFromPyFunc
@run
def testFromPyFunc():
with Context() as ctx, Location.unknown() as loc:
ctx.allow_unregistered_dialects = True
m = builtin.ModuleOp()
f32 = F32Type.get()
f64 = F64Type.get()
with InsertionPoint(m.body):
# CHECK-LABEL: func @unary_return(%arg0: f64) -> f64
# CHECK: return %arg0 : f64
@func.FuncOp.from_py_func(f64)
def unary_return(a):
return a
with Context() as ctx, Location.unknown() as loc:
ctx.allow_unregistered_dialects = True
m = builtin.ModuleOp()
f32 = F32Type.get()
f64 = F64Type.get()
with InsertionPoint(m.body):
# CHECK-LABEL: func @unary_return(%arg0: f64) -> f64
# CHECK: return %arg0 : f64
@func.FuncOp.from_py_func(f64)
def unary_return(a):
return a
# CHECK-LABEL: func @binary_return(%arg0: f32, %arg1: f64) -> (f32, f64)
# CHECK: return %arg0, %arg1 : f32, f64
@func.FuncOp.from_py_func(f32, f64)
def binary_return(a, b):
return a, b
# CHECK-LABEL: func @binary_return(%arg0: f32, %arg1: f64) -> (f32, f64)
# CHECK: return %arg0, %arg1 : f32, f64
@func.FuncOp.from_py_func(f32, f64)
def binary_return(a, b):
return a, b
# CHECK-LABEL: func @none_return(%arg0: f32, %arg1: f64)
# CHECK: return
@func.FuncOp.from_py_func(f32, f64)
def none_return(a, b):
pass
# CHECK-LABEL: func @none_return(%arg0: f32, %arg1: f64)
# CHECK: return
@func.FuncOp.from_py_func(f32, f64)
def none_return(a, b):
pass
# CHECK-LABEL: func @call_unary
# CHECK: %0 = call @unary_return(%arg0) : (f64) -> f64
# CHECK: return %0 : f64
@func.FuncOp.from_py_func(f64)
def call_unary(a):
return unary_return(a)
# CHECK-LABEL: func @call_unary
# CHECK: %0 = call @unary_return(%arg0) : (f64) -> f64
# CHECK: return %0 : f64
@func.FuncOp.from_py_func(f64)
def call_unary(a):
return unary_return(a)
# CHECK-LABEL: func @call_binary
# CHECK: %0:2 = call @binary_return(%arg0, %arg1) : (f32, f64) -> (f32, f64)
# CHECK: return %0#0, %0#1 : f32, f64
@func.FuncOp.from_py_func(f32, f64)
def call_binary(a, b):
return binary_return(a, b)
# CHECK-LABEL: func @call_binary
# CHECK: %0:2 = call @binary_return(%arg0, %arg1) : (f32, f64) -> (f32, f64)
# CHECK: return %0#0, %0#1 : f32, f64
@func.FuncOp.from_py_func(f32, f64)
def call_binary(a, b):
return binary_return(a, b)
# We expect coercion of a single result operation to a returned value.
# CHECK-LABEL: func @single_result_op
# CHECK: %0 = "custom.op1"() : () -> f32
# CHECK: return %0 : f32
@func.FuncOp.from_py_func()
def single_result_op():
return Operation.create("custom.op1", results=[f32])
# We expect coercion of a single result operation to a returned value.
# CHECK-LABEL: func @single_result_op
# CHECK: %0 = "custom.op1"() : () -> f32
# CHECK: return %0 : f32
@func.FuncOp.from_py_func()
def single_result_op():
return Operation.create("custom.op1", results=[f32])
# CHECK-LABEL: func @call_none
# CHECK: call @none_return(%arg0, %arg1) : (f32, f64) -> ()
# CHECK: return
@func.FuncOp.from_py_func(f32, f64)
def call_none(a, b):
return none_return(a, b)
# CHECK-LABEL: func @call_none
# CHECK: call @none_return(%arg0, %arg1) : (f32, f64) -> ()
# CHECK: return
@func.FuncOp.from_py_func(f32, f64)
def call_none(a, b):
return none_return(a, b)
## Variants and optional feature tests.
# CHECK-LABEL: func @from_name_arg
@func.FuncOp.from_py_func(f32, f64, name="from_name_arg")
def explicit_name(a, b):
return b
## Variants and optional feature tests.
# CHECK-LABEL: func @from_name_arg
@func.FuncOp.from_py_func(f32, f64, name="from_name_arg")
def explicit_name(a, b):
return b
@func.FuncOp.from_py_func(f32, f64)
def positional_func_op(a, b, func_op):
assert isinstance(func_op, func.FuncOp)
return b
@func.FuncOp.from_py_func(f32, f64)
def positional_func_op(a, b, func_op):
assert isinstance(func_op, func.FuncOp)
return b
@func.FuncOp.from_py_func(f32, f64)
def kw_func_op(a, b=None, func_op=None):
assert isinstance(func_op, func.FuncOp)
return b
@func.FuncOp.from_py_func(f32, f64)
def kw_func_op(a, b=None, func_op=None):
assert isinstance(func_op, func.FuncOp)
return b
@func.FuncOp.from_py_func(f32, f64)
def kwargs_func_op(a, b=None, **kwargs):
assert isinstance(kwargs["func_op"], func.FuncOp)
return b
@func.FuncOp.from_py_func(f32, f64)
def kwargs_func_op(a, b=None, **kwargs):
assert isinstance(kwargs["func_op"], func.FuncOp)
return b
# CHECK-LABEL: func @explicit_results(%arg0: f32, %arg1: f64) -> f64
# CHECK: return %arg1 : f64
@func.FuncOp.from_py_func(f32, f64, results=[f64])
def explicit_results(a, b):
func.ReturnOp([b])
# CHECK-LABEL: func @explicit_results(%arg0: f32, %arg1: f64) -> f64
# CHECK: return %arg1 : f64
@func.FuncOp.from_py_func(f32, f64, results=[f64])
def explicit_results(a, b):
func.ReturnOp([b])
print(m)
print(m)
# CHECK-LABEL: TEST: testFromPyFuncErrors
@run
def testFromPyFuncErrors():
with Context() as ctx, Location.unknown() as loc:
m = builtin.ModuleOp()
f32 = F32Type.get()
f64 = F64Type.get()
with InsertionPoint(m.body):
try:
with Context() as ctx, Location.unknown() as loc:
m = builtin.ModuleOp()
f32 = F32Type.get()
f64 = F64Type.get()
with InsertionPoint(m.body):
try:
@func.FuncOp.from_py_func(f64, results=[f64])
def unary_return(a):
return a
except AssertionError as e:
# CHECK: Capturing a python function with explicit `results=` requires that the wrapped function returns None.
print(e)
@func.FuncOp.from_py_func(f64, results=[f64])
def unary_return(a):
return a
except AssertionError as e:
# CHECK: Capturing a python function with explicit `results=` requires that the wrapped function returns None.
print(e)
# CHECK-LABEL: TEST: testBuildFuncOp
@run
def testBuildFuncOp():
ctx = Context()
with Location.unknown(ctx) as loc:
m = builtin.ModuleOp()
ctx = Context()
with Location.unknown(ctx) as loc:
m = builtin.ModuleOp()
f32 = F32Type.get()
tensor_type = RankedTensorType.get((2, 3, 4), f32)
with InsertionPoint.at_block_begin(m.body):
f = func.FuncOp(name="some_func",
type=FunctionType.get(
inputs=[tensor_type, tensor_type],
results=[tensor_type]),
visibility="nested")
# CHECK: Name is: "some_func"
print("Name is: ", f.name)
f32 = F32Type.get()
tensor_type = RankedTensorType.get((2, 3, 4), f32)
with InsertionPoint.at_block_begin(m.body):
f = func.FuncOp(
name="some_func",
type=FunctionType.get(
inputs=[tensor_type, tensor_type], results=[tensor_type]
),
visibility="nested",
)
# CHECK: Name is: "some_func"
print("Name is: ", f.name)
# CHECK: Type is: (tensor<2x3x4xf32>, tensor<2x3x4xf32>) -> tensor<2x3x4xf32>
print("Type is: ", f.type)
# CHECK: Type is: (tensor<2x3x4xf32>, tensor<2x3x4xf32>) -> tensor<2x3x4xf32>
print("Type is: ", f.type)
# CHECK: Visibility is: "nested"
print("Visibility is: ", f.visibility)
# CHECK: Visibility is: "nested"
print("Visibility is: ", f.visibility)
try:
entry_block = f.entry_block
except IndexError as e:
# CHECK: External function does not have a body
print(e)
try:
entry_block = f.entry_block
except IndexError as e:
# CHECK: External function does not have a body
print(e)
with InsertionPoint(f.add_entry_block()):
func.ReturnOp([f.entry_block.arguments[0]])
pass
with InsertionPoint(f.add_entry_block()):
func.ReturnOp([f.entry_block.arguments[0]])
pass
try:
f.add_entry_block()
except IndexError as e:
# CHECK: The function already has an entry block!
print(e)
try:
f.add_entry_block()
except IndexError as e:
# CHECK: The function already has an entry block!
print(e)
# Try the callback builder and passing type as tuple.
f = func.FuncOp(name="some_other_func",
type=([tensor_type, tensor_type], [tensor_type]),
visibility="nested",
body_builder=lambda f: func.ReturnOp(
[f.entry_block.arguments[0]]))
# Try the callback builder and passing type as tuple.
f = func.FuncOp(
name="some_other_func",
type=([tensor_type, tensor_type], [tensor_type]),
visibility="nested",
body_builder=lambda f: func.ReturnOp([f.entry_block.arguments[0]]),
)
# CHECK: module {
# CHECK: func nested @some_func(%arg0: tensor<2x3x4xf32>, %arg1: tensor<2x3x4xf32>) -> tensor<2x3x4xf32> {
# CHECK: return %arg0 : tensor<2x3x4xf32>
# CHECK: }
# CHECK: func nested @some_other_func(%arg0: tensor<2x3x4xf32>, %arg1: tensor<2x3x4xf32>) -> tensor<2x3x4xf32> {
# CHECK: return %arg0 : tensor<2x3x4xf32>
# CHECK: }
print(m)
# CHECK: module {
# CHECK: func nested @some_func(%arg0: tensor<2x3x4xf32>, %arg1: tensor<2x3x4xf32>) -> tensor<2x3x4xf32> {
# CHECK: return %arg0 : tensor<2x3x4xf32>
# CHECK: }
# CHECK: func nested @some_other_func(%arg0: tensor<2x3x4xf32>, %arg1: tensor<2x3x4xf32>) -> tensor<2x3x4xf32> {
# CHECK: return %arg0 : tensor<2x3x4xf32>
# CHECK: }
print(m)
# CHECK-LABEL: TEST: testFuncArgumentAccess
@run
def testFuncArgumentAccess():
with Context() as ctx, Location.unknown():
ctx.allow_unregistered_dialects = True
module = Module.create()
f32 = F32Type.get()
f64 = F64Type.get()
with InsertionPoint(module.body):
f = func.FuncOp("some_func", ([f32, f32], [f32, f32]))
with InsertionPoint(f.add_entry_block()):
func.ReturnOp(f.arguments)
f.arg_attrs = ArrayAttr.get([
DictAttr.get({
"custom_dialect.foo": StringAttr.get("bar"),
"custom_dialect.baz": UnitAttr.get()
}),
DictAttr.get({"custom_dialect.qux": ArrayAttr.get([])})
])
f.result_attrs = ArrayAttr.get([
DictAttr.get({"custom_dialect.res1": FloatAttr.get(f32, 42.0)}),
DictAttr.get({"custom_dialect.res2": FloatAttr.get(f64, 256.0)})
])
with Context() as ctx, Location.unknown():
ctx.allow_unregistered_dialects = True
module = Module.create()
f32 = F32Type.get()
f64 = F64Type.get()
with InsertionPoint(module.body):
f = func.FuncOp("some_func", ([f32, f32], [f32, f32]))
with InsertionPoint(f.add_entry_block()):
func.ReturnOp(f.arguments)
f.arg_attrs = ArrayAttr.get(
[
DictAttr.get(
{
"custom_dialect.foo": StringAttr.get("bar"),
"custom_dialect.baz": UnitAttr.get(),
}
),
DictAttr.get({"custom_dialect.qux": ArrayAttr.get([])}),
]
)
f.result_attrs = ArrayAttr.get(
[
DictAttr.get({"custom_dialect.res1": FloatAttr.get(f32, 42.0)}),
DictAttr.get({"custom_dialect.res2": FloatAttr.get(f64, 256.0)}),
]
)
other = func.FuncOp("other_func", ([f32, f32], []))
with InsertionPoint(other.add_entry_block()):
func.ReturnOp([])
other.arg_attrs = [
DictAttr.get({"custom_dialect.foo": StringAttr.get("qux")}),
DictAttr.get()
]
other = func.FuncOp("other_func", ([f32, f32], []))
with InsertionPoint(other.add_entry_block()):
func.ReturnOp([])
other.arg_attrs = [
DictAttr.get({"custom_dialect.foo": StringAttr.get("qux")}),
DictAttr.get(),
]
# CHECK: [{custom_dialect.baz, custom_dialect.foo = "bar"}, {custom_dialect.qux = []}]
print(f.arg_attrs)
# CHECK: [{custom_dialect.baz, custom_dialect.foo = "bar"}, {custom_dialect.qux = []}]
print(f.arg_attrs)
# CHECK: [{custom_dialect.res1 = 4.200000e+01 : f32}, {custom_dialect.res2 = 2.560000e+02 : f64}]
print(f.result_attrs)
# CHECK: [{custom_dialect.res1 = 4.200000e+01 : f32}, {custom_dialect.res2 = 2.560000e+02 : f64}]
print(f.result_attrs)
# CHECK: func @some_func(
# CHECK: %[[ARG0:.*]]: f32 {custom_dialect.baz, custom_dialect.foo = "bar"},
# CHECK: %[[ARG1:.*]]: f32 {custom_dialect.qux = []}) ->
# CHECK: f32 {custom_dialect.res1 = 4.200000e+01 : f32},
# CHECK: f32 {custom_dialect.res2 = 2.560000e+02 : f64})
# CHECK: return %[[ARG0]], %[[ARG1]] : f32, f32
#
# CHECK: func @other_func(
# CHECK: %{{.*}}: f32 {custom_dialect.foo = "qux"},
# CHECK: %{{.*}}: f32)
print(module)
# CHECK: func @some_func(
# CHECK: %[[ARG0:.*]]: f32 {custom_dialect.baz, custom_dialect.foo = "bar"},
# CHECK: %[[ARG1:.*]]: f32 {custom_dialect.qux = []}) ->
# CHECK: f32 {custom_dialect.res1 = 4.200000e+01 : f32},
# CHECK: f32 {custom_dialect.res2 = 2.560000e+02 : f64})
# CHECK: return %[[ARG0]], %[[ARG1]] : f32, f32
#
# CHECK: func @other_func(
# CHECK: %{{.*}}: f32 {custom_dialect.foo = "qux"},
# CHECK: %{{.*}}: f32)
print(module)
# CHECK-LABEL: testDenseElementsAttr
@run
def testDenseElementsAttr():
with Context(), Location.unknown():
values = np.arange(4, dtype=np.int32)
i32 = IntegerType.get_signless(32)
print(DenseElementsAttr.get(values, type=i32))
# CHECK{LITERAL}: dense<[0, 1, 2, 3]> : tensor<4xi32>
print(DenseElementsAttr.get(values, type=i32, shape=(2, 2)))
# CHECK{LITERAL}: dense<[[0, 1], [2, 3]]> : tensor<2x2xi32>
print(DenseElementsAttr.get(values, type=VectorType.get((2, 2), i32)))
# CHECK{LITERAL}: dense<[[0, 1], [2, 3]]> : vector<2x2xi32>
with Context(), Location.unknown():
values = np.arange(4, dtype=np.int32)
i32 = IntegerType.get_signless(32)
print(DenseElementsAttr.get(values, type=i32))
# CHECK{LITERAL}: dense<[0, 1, 2, 3]> : tensor<4xi32>
print(DenseElementsAttr.get(values, type=i32, shape=(2, 2)))
# CHECK{LITERAL}: dense<[[0, 1], [2, 3]]> : tensor<2x2xi32>
print(DenseElementsAttr.get(values, type=VectorType.get((2, 2), i32)))
# CHECK{LITERAL}: dense<[[0, 1], [2, 3]]> : vector<2x2xi32>

View File

@@ -9,24 +9,24 @@ import mlir.dialects.complex as mlir_complex
def run(f):
print("\nTEST:", f.__name__)
f()
print("\nTEST:", f.__name__)
f()
# CHECK-LABEL: TEST: testComplexOps
@run
def testComplexOps():
with Context() as ctx, Location.unknown():
module = Module.create()
with InsertionPoint(module.body):
with Context() as ctx, Location.unknown():
module = Module.create()
with InsertionPoint(module.body):
@func.FuncOp.from_py_func(ComplexType.get(F32Type.get()))
def emit_add(arg):
return mlir_complex.AddOp(arg, arg)
@func.FuncOp.from_py_func(ComplexType.get(F32Type.get()))
def emit_add(arg):
return mlir_complex.AddOp(arg, arg)
# CHECK-LABEL: func @emit_add(
# CHECK-SAME: %[[ARG:.*]]: complex<f32>) -> complex<f32> {
# CHECK: %[[RES:.*]] = complex.add %[[ARG]], %[[ARG]] : complex<f32>
# CHECK: return %[[RES]] : complex<f32>
# CHECK: }
print(module)
# CHECK-LABEL: func @emit_add(
# CHECK-SAME: %[[ARG:.*]]: complex<f32>) -> complex<f32> {
# CHECK: %[[RES:.*]] = complex.add %[[ARG]], %[[ARG]] : complex<f32>
# CHECK: return %[[RES]] : complex<f32>
# CHECK: }
print(module)

View File

@@ -7,13 +7,13 @@ from mlir.dialects import func
def constructAndPrintInModule(f):
print("\nTEST:", f.__name__)
with Context(), Location.unknown():
module = Module.create()
with InsertionPoint(module.body):
f()
print(module)
return f
print("\nTEST:", f.__name__)
with Context(), Location.unknown():
module = Module.create()
with InsertionPoint(module.body):
f()
print(module)
return f
# CHECK-LABEL: TEST: testConstantOp
@@ -21,21 +21,21 @@ def constructAndPrintInModule(f):
@constructAndPrintInModule
def testConstantOp():
c1 = arith.ConstantOp(IntegerType.get_signless(32), 42)
c2 = arith.ConstantOp(IntegerType.get_signless(64), 100)
c3 = arith.ConstantOp(F32Type.get(), 3.14)
c4 = arith.ConstantOp(F64Type.get(), 1.23)
# CHECK: 42
print(c1.literal_value)
c1 = arith.ConstantOp(IntegerType.get_signless(32), 42)
c2 = arith.ConstantOp(IntegerType.get_signless(64), 100)
c3 = arith.ConstantOp(F32Type.get(), 3.14)
c4 = arith.ConstantOp(F64Type.get(), 1.23)
# CHECK: 42
print(c1.literal_value)
# CHECK: 100
print(c2.literal_value)
# CHECK: 100
print(c2.literal_value)
# CHECK: 3.140000104904175
print(c3.literal_value)
# CHECK: 3.140000104904175
print(c3.literal_value)
# CHECK: 1.23
print(c4.literal_value)
# CHECK: 1.23
print(c4.literal_value)
# CHECK: = arith.constant 42 : i32
@@ -47,17 +47,17 @@ def testConstantOp():
# CHECK-LABEL: TEST: testVectorConstantOp
@constructAndPrintInModule
def testVectorConstantOp():
int_type = IntegerType.get_signless(32)
vec_type = VectorType.get([2, 2], int_type)
c1 = arith.ConstantOp(
vec_type,
DenseElementsAttr.get_splat(vec_type, IntegerAttr.get(int_type, 42)))
try:
print(c1.literal_value)
except ValueError as e:
assert "only integer and float constants have literal values" in str(e)
else:
assert False
int_type = IntegerType.get_signless(32)
vec_type = VectorType.get([2, 2], int_type)
c1 = arith.ConstantOp(
vec_type, DenseElementsAttr.get_splat(vec_type, IntegerAttr.get(int_type, 42))
)
try:
print(c1.literal_value)
except ValueError as e:
assert "only integer and float constants have literal values" in str(e)
else:
assert False
# CHECK: = arith.constant dense<42> : vector<2x2xi32>
@@ -66,9 +66,9 @@ def testVectorConstantOp():
# CHECK-LABEL: TEST: testConstantIndexOp
@constructAndPrintInModule
def testConstantIndexOp():
c1 = arith.ConstantOp.create_index(10)
# CHECK: 10
print(c1.literal_value)
c1 = arith.ConstantOp.create_index(10)
# CHECK: 10
print(c1.literal_value)
# CHECK: = arith.constant 10 : index
@@ -77,18 +77,18 @@ def testConstantIndexOp():
# CHECK-LABEL: TEST: testFunctionCalls
@constructAndPrintInModule
def testFunctionCalls():
foo = func.FuncOp("foo", ([], []))
foo.sym_visibility = StringAttr.get("private")
bar = func.FuncOp("bar", ([], [IndexType.get()]))
bar.sym_visibility = StringAttr.get("private")
qux = func.FuncOp("qux", ([], [F32Type.get()]))
qux.sym_visibility = StringAttr.get("private")
foo = func.FuncOp("foo", ([], []))
foo.sym_visibility = StringAttr.get("private")
bar = func.FuncOp("bar", ([], [IndexType.get()]))
bar.sym_visibility = StringAttr.get("private")
qux = func.FuncOp("qux", ([], [F32Type.get()]))
qux.sym_visibility = StringAttr.get("private")
with InsertionPoint(func.FuncOp("caller", ([], [])).add_entry_block()):
func.CallOp(foo, [])
func.CallOp([IndexType.get()], "bar", [])
func.CallOp([F32Type.get()], FlatSymbolRefAttr.get("qux"), [])
func.ReturnOp([])
with InsertionPoint(func.FuncOp("caller", ([], [])).add_entry_block()):
func.CallOp(foo, [])
func.CallOp([IndexType.get()], "bar", [])
func.CallOp([F32Type.get()], FlatSymbolRefAttr.get("qux"), [])
func.ReturnOp([])
# CHECK: func private @foo()

Some files were not shown because too many files have changed in this diff.