Mirror of https://github.com/intel/llvm.git
[NFC][Py Reformat] Reformat python files in mlir subdir
This is an ongoing series of commits that are reformatting our Python code. Reformatting is done with `black`. If you end up having problems merging this commit because you have made changes to a Python file, the best way to handle that is to run `git checkout --ours <yourfile>` and then reformat it with `black`. If you run into any problems, post to Discourse about it and we will try to help.

RFC thread: https://discourse.llvm.org/t/rfc-document-and-standardize-python-code-style

Differential Revision: https://reviews.llvm.org/D150782
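As an aside for readers resolving such conflicts: the reformat itself is mechanical and reproducible. A minimal sketch using black's Python API (`black.format_str` is part of black's public API; the sample input reuses one of the pre-reformat lines shown in the diff below):

    import black

    # A pre-reformat backslash continuation, as seen in the old lines below;
    # black rejoins it into its canonical parenthesized/one-line style.
    src = 'assert os.path.exists(c_runner_utils),\\\n    f"{c_runner_utils} does not exist."\n'
    print(black.format_str(src, mode=black.Mode()))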
@@ -25,7 +25,7 @@ from common import setup_passes
def matmul_dsl(
A=dsl.TensorDef(dsl.T, dsl.S.M, dsl.S.K),
B=dsl.TensorDef(dsl.T, dsl.S.K, dsl.S.N),
C=dsl.TensorDef(dsl.T, dsl.S.M, dsl.S.N, output=True)
C=dsl.TensorDef(dsl.T, dsl.S.M, dsl.S.N, output=True),
):
"""Helper function for mlir sparse matrix multiplication benchmark."""
C[dsl.D.m, dsl.D.n] += A[dsl.D.m, dsl.D.k] * B[dsl.D.k, dsl.D.n]
@@ -43,6 +43,7 @@ def benchmark_sparse_mlir_multiplication():
param2_type = ir.RankedTensorType.get([1500, 2000], f64)
result_type = ir.RankedTensorType.get([1000, 2000], f64)
with ir.InsertionPoint(module.body):

@func.FuncOp.from_py_func(param1_type, param2_type, result_type)
def sparse_kernel(x, y, z):
return matmul_dsl(x, y, outs=[z])
@@ -51,37 +52,34 @@ def benchmark_sparse_mlir_multiplication():
with ir.Context(), ir.Location.unknown():
kernel_func = get_kernel_func_from_module(module)
timer_func = emit_timer_func()
wrapped_func = emit_benchmark_wrapped_main_func(
kernel_func,
timer_func
)
wrapped_func = emit_benchmark_wrapped_main_func(kernel_func, timer_func)
main_module_with_benchmark = ir.Module.parse(
str(timer_func) + str(wrapped_func) + str(kernel_func)
)
setup_passes(main_module_with_benchmark)
c_runner_utils = os.getenv("MLIR_C_RUNNER_UTILS", "")
assert os.path.exists(c_runner_utils),\
f"{c_runner_utils} does not exist." \
f" Please pass a valid value for" \
assert os.path.exists(c_runner_utils), (
f"{c_runner_utils} does not exist."
f" Please pass a valid value for"
f" MLIR_C_RUNNER_UTILS environment variable."
)
runner_utils = os.getenv("MLIR_RUNNER_UTILS", "")
assert os.path.exists(runner_utils),\
f"{runner_utils} does not exist." \
f" Please pass a valid value for MLIR_RUNNER_UTILS" \
assert os.path.exists(runner_utils), (
f"{runner_utils} does not exist."
f" Please pass a valid value for MLIR_RUNNER_UTILS"
f" environment variable."
)

engine = ExecutionEngine(
main_module_with_benchmark,
3,
shared_libs=[c_runner_utils, runner_utils]
shared_libs=[c_runner_utils, runner_utils],
)
return engine.invoke

def runner(engine_invoke):
compiled_program_args = []
for argument_type in [
result_type, param1_type, param2_type, result_type
]:
for argument_type in [result_type, param1_type, param2_type, result_type]:
argument_type_str = str(argument_type)
dimensions_str = re.sub("<|>|tensor", "", argument_type_str)
dimensions = [int(dim) for dim in dimensions_str.split("x")[:-1]]
@@ -111,6 +109,7 @@ def benchmark_np_matrix_multiplication():
benchmark, we don't have any `compiler` function returned. We just return
the `runner` function.
"""

def runner():
argument1 = np.random.uniform(low=0.0, high=100.0, size=(1000, 1500))
argument2 = np.random.uniform(low=0.0, high=100.0, size=(1500, 2000))
@@ -10,8 +10,7 @@ from mlir.passmanager import PassManager


def setup_passes(mlir_module):
"""Setup pass pipeline parameters for benchmark functions.
"""
"""Setup pass pipeline parameters for benchmark functions."""
opt = (
"parallelization-strategy=none"
" vectorization-strategy=none vl=1 enable-simd-index32=False"
@@ -43,12 +42,15 @@ def get_kernel_func_from_module(module: ir.Module) -> func.FuncOp:
This function only works for a module with one region, one block, and one
operation.
"""
assert len(module.operation.regions) == 1, \
"Expected kernel module to have only one region"
assert len(module.operation.regions[0].blocks) == 1, \
"Expected kernel module to have only one block"
assert len(module.operation.regions[0].blocks[0].operations) == 1, \
"Expected kernel module to have only one operation"
assert (
len(module.operation.regions) == 1
), "Expected kernel module to have only one region"
assert (
len(module.operation.regions[0].blocks) == 1
), "Expected kernel module to have only one block"
assert (
len(module.operation.regions[0].blocks[0].operations) == 1
), "Expected kernel module to have only one operation"
return module.operation.regions[0].blocks[0].operations[0]
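For context, a minimal usage sketch of the helper in the hunk above (assumes the mlir Python bindings are importable; the parsed kernel is a made-up example, and the call pattern mirrors the benchmark file earlier in this diff):

    from mlir import ir

    with ir.Context(), ir.Location.unknown():
        module = ir.Module.parse(
            """
            func.func @kernel(%arg0: i32) -> i32 {
              return %arg0 : i32
            }
            """
        )
        # One region, one block, one operation, as the asserts above require.
        kernel_func = get_kernel_func_from_module(module)
        print(kernel_func)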
@@ -57,8 +59,7 @@ def emit_timer_func() -> func.FuncOp:
used, the `MLIR_RUNNER_UTILS` and `MLIR_C_RUNNER_UTILS` must be included.
"""
i64_type = ir.IntegerType.get_signless(64)
nanoTime = func.FuncOp(
"nanoTime", ([], [i64_type]), visibility="private")
nanoTime = func.FuncOp("nanoTime", ([], [i64_type]), visibility="private")
nanoTime.attributes["llvm.emit_c_interface"] = ir.UnitAttr.get()
return nanoTime
@@ -76,9 +77,8 @@ def emit_benchmark_wrapped_main_func(kernel_func, timer_func):
wrapped_func = func.FuncOp(
# Same signature and an extra buffer of indices to save timings.
"main",
(kernel_func.arguments.types + [memref_of_i64_type],
kernel_func.type.results),
visibility="public"
(kernel_func.arguments.types + [memref_of_i64_type], kernel_func.type.results),
visibility="public",
)
wrapped_func.attributes["llvm.emit_c_interface"] = ir.UnitAttr.get()
@@ -88,13 +88,13 @@ def emit_benchmark_wrapped_main_func(kernel_func, timer_func):
zero = arith.ConstantOp.create_index(0)
n_iterations = memref.DimOp(ir.IndexType.get(), timer_buffer, zero)
one = arith.ConstantOp.create_index(1)
iter_args = list(wrapped_func.arguments[-num_results - 1:-1])
iter_args = list(wrapped_func.arguments[-num_results - 1 : -1])
loop = scf.ForOp(zero, n_iterations, one, iter_args)
with ir.InsertionPoint(loop.body):
start = func.CallOp(timer_func, [])
call = func.CallOp(
kernel_func,
wrapped_func.arguments[:-num_results - 1] + loop.inner_iter_args
wrapped_func.arguments[: -num_results - 1] + loop.inner_iter_args,
)
end = func.CallOp(timer_func, [])
time_taken = arith.SubIOp(end, start)
@@ -1 +1 @@
config.suffixes.add('.c')
config.suffixes.add(".c")
@@ -16,52 +16,55 @@ from lit.llvm.subst import FindTool
# Configuration file for the 'lit' test runner.

# name: The name of this test suite.
config.name = 'STANDALONE'
config.name = "STANDALONE"

config.test_format = lit.formats.ShTest(not llvm_config.use_lit_shell)

# suffixes: A list of file extensions to treat as test files.
config.suffixes = ['.mlir']
config.suffixes = [".mlir"]

# test_source_root: The root path where tests are located.
config.test_source_root = os.path.dirname(__file__)

# test_exec_root: The root path where tests should be run.
config.test_exec_root = os.path.join(config.standalone_obj_root, 'test')
config.test_exec_root = os.path.join(config.standalone_obj_root, "test")

config.substitutions.append(('%PATH%', config.environment['PATH']))
config.substitutions.append(('%shlibext', config.llvm_shlib_ext))
config.substitutions.append(("%PATH%", config.environment["PATH"]))
config.substitutions.append(("%shlibext", config.llvm_shlib_ext))

llvm_config.with_system_environment(
['HOME', 'INCLUDE', 'LIB', 'TMP', 'TEMP'])
llvm_config.with_system_environment(["HOME", "INCLUDE", "LIB", "TMP", "TEMP"])

llvm_config.use_default_substitutions()

# excludes: A list of directories to exclude from the testsuite. The 'Inputs'
# subdirectories contain auxiliary inputs for various tests in their parent
# directories.
config.excludes = ['Inputs', 'Examples', 'CMakeLists.txt', 'README.txt', 'LICENSE.txt']
config.excludes = ["Inputs", "Examples", "CMakeLists.txt", "README.txt", "LICENSE.txt"]

# test_exec_root: The root path where tests should be run.
config.test_exec_root = os.path.join(config.standalone_obj_root, 'test')
config.standalone_tools_dir = os.path.join(config.standalone_obj_root, 'bin')
config.standalone_libs_dir = os.path.join(config.standalone_obj_root, 'lib')
config.test_exec_root = os.path.join(config.standalone_obj_root, "test")
config.standalone_tools_dir = os.path.join(config.standalone_obj_root, "bin")
config.standalone_libs_dir = os.path.join(config.standalone_obj_root, "lib")

config.substitutions.append(('%standalone_libs', config.standalone_libs_dir))
config.substitutions.append(("%standalone_libs", config.standalone_libs_dir))

# Tweak the PATH to include the tools dir.
llvm_config.with_environment('PATH', config.llvm_tools_dir, append_path=True)
llvm_config.with_environment("PATH", config.llvm_tools_dir, append_path=True)

tool_dirs = [config.standalone_tools_dir, config.llvm_tools_dir]
tools = [
'mlir-opt',
'standalone-capi-test',
'standalone-opt',
'standalone-translate',
"mlir-opt",
"standalone-capi-test",
"standalone-opt",
"standalone-translate",
]

llvm_config.add_tool_substitutions(tools, tool_dirs)

llvm_config.with_environment('PYTHONPATH', [
os.path.join(config.mlir_obj_dir, 'python_packages', 'standalone'),
], append_path=True)
llvm_config.with_environment(
"PYTHONPATH",
[
os.path.join(config.mlir_obj_dir, "python_packages", "standalone"),
],
append_path=True,
)
@@ -1,4 +1,4 @@
config.suffixes.add('.py')
config.suffixes.add(".py")

if not config.enable_bindings_python:
config.unsupported = True
config.unsupported = True
@@ -1,17 +1,16 @@
# RUN: %python %s | FileCheck %s

from mlir_standalone.ir import *
from mlir_standalone.dialects import (
builtin as builtin_d,
standalone as standalone_d
)
from mlir_standalone.dialects import builtin as builtin_d, standalone as standalone_d

with Context():
standalone_d.register_dialect()
module = Module.parse("""
standalone_d.register_dialect()
module = Module.parse(
"""
%0 = arith.constant 2 : i32
%1 = standalone.foo %0 : i32
""")
# CHECK: %[[C:.*]] = arith.constant 2 : i32
# CHECK: standalone.foo %[[C]] : i32
print(str(module))
"""
)
# CHECK: %[[C:.*]] = arith.constant 2 : i32
# CHECK: standalone.foo %[[C]] : i32
print(str(module))
@@ -10,26 +10,26 @@ _this_dir = os.path.dirname(__file__)


def get_lib_dirs() -> Sequence[str]:
"""Gets the lib directory for linking to shared libraries.
"""Gets the lib directory for linking to shared libraries.

On some platforms, the package may need to be built specially to export
development libraries.
"""
return [_this_dir]
On some platforms, the package may need to be built specially to export
development libraries.
"""
return [_this_dir]


def get_include_dirs() -> Sequence[str]:
"""Gets the include directory for compiling against exported C libraries.
"""Gets the include directory for compiling against exported C libraries.

Depending on how the package was built, development C libraries may or may
not be present.
"""
return [os.path.join(_this_dir, "include")]
Depending on how the package was built, development C libraries may or may
not be present.
"""
return [os.path.join(_this_dir, "include")]


# Perform Python level site initialization. This involves:
# 1. Attempting to load initializer modules, specific to the distribution.
# 2. Defining the concrete mlir.ir.Context that does site specific
# 2. Defining the concrete mlir.ir.Context that does site specific
# initialization.
#
# Aside from just being far more convenient to do this at the Python level,
@@ -38,91 +38,106 @@ def get_include_dirs() -> Sequence[str]:
# in the scope of the base class __init__).
#
# For #1, we:
# a. Probe for modules named '_mlirRegisterEverything' and
# '_site_initialize_{i}', where 'i' is a number starting at zero and
# a. Probe for modules named '_mlirRegisterEverything' and
# '_site_initialize_{i}', where 'i' is a number starting at zero and
# proceeding so long as a module with the name is found.
# b. If the module has a 'register_dialects' attribute, it will be called
# immediately with a DialectRegistry to populate.
# c. If the module has a 'context_init_hook', it will be added to a list
# of callbacks that are invoked as the last step of Context
# of callbacks that are invoked as the last step of Context
# initialization (and passed the Context under construction).
#
# This facility allows downstreams to customize Context creation to their
# needs.
def _site_initialize():
import importlib
import itertools
import logging
from ._mlir import ir
logger = logging.getLogger(__name__)
registry = ir.DialectRegistry()
post_init_hooks = []
import importlib
import itertools
import logging
from ._mlir import ir

def process_initializer_module(module_name):
try:
m = importlib.import_module(f".{module_name}", __name__)
except ModuleNotFoundError:
return False
except ImportError:
message = (f"Error importing mlir initializer {module_name}. This may "
"happen in unclean incremental builds but is likely a real bug if "
"encountered otherwise and the MLIR Python API may not function.")
logger.warning(message, exc_info=True)
logger = logging.getLogger(__name__)
registry = ir.DialectRegistry()
post_init_hooks = []

logger.debug("Initializing MLIR with module: %s", module_name)
if hasattr(m, "register_dialects"):
logger.debug("Registering dialects from initializer %r", m)
m.register_dialects(registry)
if hasattr(m, "context_init_hook"):
logger.debug("Adding context init hook from %r", m)
post_init_hooks.append(m.context_init_hook)
return True
def process_initializer_module(module_name):
try:
m = importlib.import_module(f".{module_name}", __name__)
except ModuleNotFoundError:
return False
except ImportError:
message = (
f"Error importing mlir initializer {module_name}. This may "
"happen in unclean incremental builds but is likely a real bug if "
"encountered otherwise and the MLIR Python API may not function."
)
logger.warning(message, exc_info=True)

logger.debug("Initializing MLIR with module: %s", module_name)
if hasattr(m, "register_dialects"):
logger.debug("Registering dialects from initializer %r", m)
m.register_dialects(registry)
if hasattr(m, "context_init_hook"):
logger.debug("Adding context init hook from %r", m)
post_init_hooks.append(m.context_init_hook)
return True

# If _mlirRegisterEverything is built, then include it as an initializer
# module.
process_initializer_module("_mlirRegisterEverything")
# If _mlirRegisterEverything is built, then include it as an initializer
# module.
process_initializer_module("_mlirRegisterEverything")

# Load all _site_initialize_{i} modules, where 'i' is a number starting
# at 0.
for i in itertools.count():
module_name = f"_site_initialize_{i}"
if not process_initializer_module(module_name):
break
# Load all _site_initialize_{i} modules, where 'i' is a number starting
# at 0.
for i in itertools.count():
module_name = f"_site_initialize_{i}"
if not process_initializer_module(module_name):
break

class Context(ir._BaseContext):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.append_dialect_registry(registry)
for hook in post_init_hooks:
hook(self)
# TODO: There is some debate about whether we should eagerly load
# all dialects. It is being done here in order to preserve existing
# behavior. See: https://github.com/llvm/llvm-project/issues/56037
self.load_all_available_dialects()
ir.Context = Context
class Context(ir._BaseContext):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.append_dialect_registry(registry)
for hook in post_init_hooks:
hook(self)
# TODO: There is some debate about whether we should eagerly load
# all dialects. It is being done here in order to preserve existing
# behavior. See: https://github.com/llvm/llvm-project/issues/56037
self.load_all_available_dialects()

class MLIRError(Exception):
"""
An exception with diagnostic information. Has the following fields:
message: str
error_diagnostics: List[ir.DiagnosticInfo]
"""
def __init__(self, message, error_diagnostics):
self.message = message
self.error_diagnostics = error_diagnostics
super().__init__(message, error_diagnostics)
ir.Context = Context

def __str__(self):
s = self.message
if self.error_diagnostics:
s += ':'
for diag in self.error_diagnostics:
s += "\nerror: " + str(diag.location)[4:-1] + ": " + diag.message.replace('\n', '\n ')
for note in diag.notes:
s += "\n note: " + str(note.location)[4:-1] + ": " + note.message.replace('\n', '\n ')
return s
ir.MLIRError = MLIRError
class MLIRError(Exception):
"""
An exception with diagnostic information. Has the following fields:
message: str
error_diagnostics: List[ir.DiagnosticInfo]
"""

def __init__(self, message, error_diagnostics):
self.message = message
self.error_diagnostics = error_diagnostics
super().__init__(message, error_diagnostics)

def __str__(self):
s = self.message
if self.error_diagnostics:
s += ":"
for diag in self.error_diagnostics:
s += (
"\nerror: "
+ str(diag.location)[4:-1]
+ ": "
+ diag.message.replace("\n", "\n ")
)
for note in diag.notes:
s += (
"\n note: "
+ str(note.location)[4:-1]
+ ": "
+ note.message.replace("\n", "\n ")
)
return s

ir.MLIRError = MLIRError


_site_initialize()
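To make the initializer-module protocol described in the comments concrete, here is a hedged sketch of a downstream `_site_initialize_0.py` (the module name follows the probing scheme above; the hook bodies are illustrative, not a specific real downstream):

    # Hypothetical _site_initialize_0.py placed next to this __init__.py.

    def register_dialects(registry):
        # Step (b): called immediately with the DialectRegistry to populate.
        # The actual registration call is site-specific, so it is elided here.
        pass

    def context_init_hook(context):
        # Step (c): runs as the last step of Context initialization and is
        # passed the Context under construction.
        context.allow_unregistered_dialects = False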
@@ -3,72 +3,67 @@
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

try:
from ..ir import *
from ._ods_common import get_default_loc_context as _get_default_loc_context
from ..ir import *
from ._ods_common import get_default_loc_context as _get_default_loc_context

from typing import Any, List, Union
from typing import Any, List, Union
except ImportError as e:
raise RuntimeError("Error loading imports from extension module") from e
raise RuntimeError("Error loading imports from extension module") from e


def _isa(obj: Any, cls: type):
try:
cls(obj)
except ValueError:
return False
return True
try:
cls(obj)
except ValueError:
return False
return True


def _is_any_of(obj: Any, classes: List[type]):
return any(_isa(obj, cls) for cls in classes)
return any(_isa(obj, cls) for cls in classes)


def _is_integer_like_type(type: Type):
return _is_any_of(type, [IntegerType, IndexType])
return _is_any_of(type, [IntegerType, IndexType])


def _is_float_type(type: Type):
return _is_any_of(type, [BF16Type, F16Type, F32Type, F64Type])
return _is_any_of(type, [BF16Type, F16Type, F32Type, F64Type])


class ConstantOp:
"""Specialization for the constant op class."""
"""Specialization for the constant op class."""

def __init__(self,
result: Type,
value: Union[int, float, Attribute],
*,
loc=None,
ip=None):
if isinstance(value, int):
super().__init__(IntegerAttr.get(result, value), loc=loc, ip=ip)
elif isinstance(value, float):
super().__init__(FloatAttr.get(result, value), loc=loc, ip=ip)
else:
super().__init__(value, loc=loc, ip=ip)
def __init__(
self, result: Type, value: Union[int, float, Attribute], *, loc=None, ip=None
):
if isinstance(value, int):
super().__init__(IntegerAttr.get(result, value), loc=loc, ip=ip)
elif isinstance(value, float):
super().__init__(FloatAttr.get(result, value), loc=loc, ip=ip)
else:
super().__init__(value, loc=loc, ip=ip)

@classmethod
def create_index(cls, value: int, *, loc=None, ip=None):
"""Create an index-typed constant."""
return cls(
IndexType.get(context=_get_default_loc_context(loc)),
value,
loc=loc,
ip=ip)
@classmethod
def create_index(cls, value: int, *, loc=None, ip=None):
"""Create an index-typed constant."""
return cls(
IndexType.get(context=_get_default_loc_context(loc)), value, loc=loc, ip=ip
)

@property
def type(self):
return self.results[0].type
@property
def type(self):
return self.results[0].type

@property
def value(self):
return Attribute(self.operation.attributes["value"])
@property
def value(self):
return Attribute(self.operation.attributes["value"])

@property
def literal_value(self) -> Union[int, float]:
if _is_integer_like_type(self.type):
return IntegerAttr(self.value).value
elif _is_float_type(self.type):
return FloatAttr(self.value).value
else:
raise ValueError("only integer and float constants have literal values")
@property
def literal_value(self) -> Union[int, float]:
if _is_integer_like_type(self.type):
return IntegerAttr(self.value).value
elif _is_float_type(self.type):
return FloatAttr(self.value).value
else:
raise ValueError("only integer and float constants have literal values")
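A small usage sketch for the specialization above, mirroring the `arith.ConstantOp.create_index` calls earlier in this diff (assumes the bindings are importable; the module is not verified here):

    from mlir import ir
    from mlir.dialects import arith

    with ir.Context(), ir.Location.unknown():
        module = ir.Module.create()
        with ir.InsertionPoint(module.body):
            zero = arith.ConstantOp.create_index(0)  # index-typed constant
            # Plain integer constant via the int branch of __init__.
            two = arith.ConstantOp(ir.IntegerType.get_signless(32), 2)
            print(two.literal_value)  # -> 2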
@@ -3,36 +3,39 @@
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

try:
from typing import Sequence, Union
from ..ir import *
from ._ods_common import get_default_loc_context
from typing import Sequence, Union
from ..ir import *
from ._ods_common import get_default_loc_context

from typing import Any, List, Union
from typing import Any, List, Union
except ImportError as e:
raise RuntimeError("Error loading imports from extension module") from e
raise RuntimeError("Error loading imports from extension module") from e


class AllocTensorOp:
"""Extends the bufferization.alloc_tensor op."""
"""Extends the bufferization.alloc_tensor op."""

def __init__(self,
tensor_type: Type,
dynamic_sizes: Sequence[Value],
copy: Value,
size_hint: Value,
escape: BoolAttr,
*,
loc=None,
ip=None):
"""Constructs an `alloc_tensor` with static and/or dynamic sizes."""
context = get_default_loc_context(loc)
attributes = {}
if escape:
attributes["escape"] = escape
op = self.build_generic(
results=[tensor_type],
operands=[dynamic_sizes, copy, size_hint],
attributes=attributes,
loc=loc,
ip=ip)
OpView.__init__(self, op)
def __init__(
self,
tensor_type: Type,
dynamic_sizes: Sequence[Value],
copy: Value,
size_hint: Value,
escape: BoolAttr,
*,
loc=None,
ip=None
):
"""Constructs an `alloc_tensor` with static and/or dynamic sizes."""
context = get_default_loc_context(loc)
attributes = {}
if escape:
attributes["escape"] = escape
op = self.build_generic(
results=[tensor_type],
operands=[dynamic_sizes, copy, size_hint],
attributes=attributes,
loc=loc,
ip=ip,
)
OpView.__init__(self, op)
@@ -3,18 +3,18 @@
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

try:
from ..ir import *
from ..ir import *
except ImportError as e:
raise RuntimeError("Error loading imports from extension module") from e
raise RuntimeError("Error loading imports from extension module") from e


class ModuleOp:
"""Specialization for the module op class."""
"""Specialization for the module op class."""

def __init__(self, *, loc=None, ip=None):
super().__init__(self.build_generic(results=[], operands=[], loc=loc,
ip=ip))
body = self.regions[0].blocks.append()
def __init__(self, *, loc=None, ip=None):
super().__init__(self.build_generic(results=[], operands=[], loc=loc, ip=ip))
body = self.regions[0].blocks.append()

@property
def body(self):
return self.regions[0].blocks[0]
@property
def body(self):
return self.regions[0].blocks[0]
@@ -3,298 +3,317 @@
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

try:
from ..ir import *
from ._ods_common import get_default_loc_context as _get_default_loc_context
from ..ir import *
from ._ods_common import get_default_loc_context as _get_default_loc_context

import inspect
import inspect

from typing import Any, List, Optional, Sequence, Union
from typing import Any, List, Optional, Sequence, Union
except ImportError as e:
raise RuntimeError("Error loading imports from extension module") from e
raise RuntimeError("Error loading imports from extension module") from e

ARGUMENT_ATTRIBUTE_NAME = "arg_attrs"
RESULT_ATTRIBUTE_NAME = "res_attrs"


class ConstantOp:
"""Specialization for the constant op class."""
"""Specialization for the constant op class."""

def __init__(self, result: Type, value: Attribute, *, loc=None, ip=None):
super().__init__(result, value, loc=loc, ip=ip)
def __init__(self, result: Type, value: Attribute, *, loc=None, ip=None):
super().__init__(result, value, loc=loc, ip=ip)

@property
def type(self):
return self.results[0].type
@property
def type(self):
return self.results[0].type


class FuncOp:
"""Specialization for the func op class."""
"""Specialization for the func op class."""

def __init__(self,
name,
type,
*,
visibility=None,
body_builder=None,
loc=None,
ip=None):
"""
Create a FuncOp with the provided `name`, `type`, and `visibility`.
- `name` is a string representing the function name.
- `type` is either a FunctionType or a pair of lists describing inputs and
results.
- `visibility` is a string matching `public`, `private`, or `nested`. None
implies private visibility.
- `body_builder` is an optional callback, when provided a new entry block
is created and the callback is invoked with the new op as argument within
an InsertionPoint context already set for the block. The callback is
expected to insert a terminator in the block.
"""
sym_name = StringAttr.get(str(name))
def __init__(
self, name, type, *, visibility=None, body_builder=None, loc=None, ip=None
):
"""
Create a FuncOp with the provided `name`, `type`, and `visibility`.
- `name` is a string representing the function name.
- `type` is either a FunctionType or a pair of lists describing inputs and
results.
- `visibility` is a string matching `public`, `private`, or `nested`. None
implies private visibility.
- `body_builder` is an optional callback, when provided a new entry block
is created and the callback is invoked with the new op as argument within
an InsertionPoint context already set for the block. The callback is
expected to insert a terminator in the block.
"""
sym_name = StringAttr.get(str(name))

# If the type is passed as a tuple, build a FunctionType on the fly.
if isinstance(type, tuple):
type = FunctionType.get(inputs=type[0], results=type[1])
# If the type is passed as a tuple, build a FunctionType on the fly.
if isinstance(type, tuple):
type = FunctionType.get(inputs=type[0], results=type[1])

type = TypeAttr.get(type)
sym_visibility = StringAttr.get(
str(visibility)) if visibility is not None else None
super().__init__(sym_name, type, sym_visibility=sym_visibility, loc=loc, ip=ip)
if body_builder:
entry_block = self.add_entry_block()
with InsertionPoint(entry_block):
body_builder(self)
type = TypeAttr.get(type)
sym_visibility = (
StringAttr.get(str(visibility)) if visibility is not None else None
)
super().__init__(sym_name, type, sym_visibility=sym_visibility, loc=loc, ip=ip)
if body_builder:
entry_block = self.add_entry_block()
with InsertionPoint(entry_block):
body_builder(self)

@property
def is_external(self):
return len(self.regions[0].blocks) == 0
@property
def is_external(self):
return len(self.regions[0].blocks) == 0

@property
def body(self):
return self.regions[0]
@property
def body(self):
return self.regions[0]

@property
def type(self):
return FunctionType(TypeAttr(self.attributes["function_type"]).value)
@property
def type(self):
return FunctionType(TypeAttr(self.attributes["function_type"]).value)

@property
def visibility(self):
return self.attributes["sym_visibility"]
@property
def visibility(self):
return self.attributes["sym_visibility"]

@property
def name(self) -> StringAttr:
return StringAttr(self.attributes["sym_name"])
@property
def name(self) -> StringAttr:
return StringAttr(self.attributes["sym_name"])

@property
def entry_block(self):
if self.is_external:
raise IndexError('External function does not have a body')
return self.regions[0].blocks[0]
@property
def entry_block(self):
if self.is_external:
raise IndexError("External function does not have a body")
return self.regions[0].blocks[0]

def add_entry_block(self, arg_locs: Optional[Sequence[Location]] = None):
"""
Add an entry block to the function body using the function signature to
infer block arguments.
Returns the newly created block
"""
if not self.is_external:
raise IndexError('The function already has an entry block!')
self.body.blocks.append(*self.type.inputs, arg_locs=arg_locs)
return self.body.blocks[0]
def add_entry_block(self, arg_locs: Optional[Sequence[Location]] = None):
"""
Add an entry block to the function body using the function signature to
infer block arguments.
Returns the newly created block
"""
if not self.is_external:
raise IndexError("The function already has an entry block!")
self.body.blocks.append(*self.type.inputs, arg_locs=arg_locs)
return self.body.blocks[0]

@property
def arg_attrs(self):
return ArrayAttr(self.attributes[ARGUMENT_ATTRIBUTE_NAME])
@property
def arg_attrs(self):
return ArrayAttr(self.attributes[ARGUMENT_ATTRIBUTE_NAME])

@arg_attrs.setter
def arg_attrs(self, attribute: Union[ArrayAttr, list]):
if isinstance(attribute, ArrayAttr):
self.attributes[ARGUMENT_ATTRIBUTE_NAME] = attribute
else:
self.attributes[ARGUMENT_ATTRIBUTE_NAME] = ArrayAttr.get(
attribute, context=self.context)

@property
def arguments(self):
return self.entry_block.arguments

@property
def result_attrs(self):
return self.attributes[RESULT_ATTRIBUTE_NAME]

@result_attrs.setter
def result_attrs(self, attribute: ArrayAttr):
self.attributes[RESULT_ATTRIBUTE_NAME] = attribute

@classmethod
def from_py_func(FuncOp,
*inputs: Type,
results: Optional[Sequence[Type]] = None,
name: Optional[str] = None):
"""Decorator to define an MLIR FuncOp specified as a python function.

Requires that an `mlir.ir.InsertionPoint` and `mlir.ir.Location` are
active for the current thread (i.e. established in a `with` block).

When applied as a decorator to a Python function, an entry block will
be constructed for the FuncOp with types as specified in `*inputs`. The
block arguments will be passed positionally to the Python function. In
addition, if the Python function accepts keyword arguments generally or
has a corresponding keyword argument, the following will be passed:
* `func_op`: The `func` op being defined.

By default, the function name will be the Python function `__name__`. This
can be overridden by passing the `name` argument to the decorator.

If `results` is not specified, then the decorator will implicitly
insert a `ReturnOp` with the `Value`'s returned from the decorated
function. It will also set the `FuncOp` type with the actual return
value types. If `results` is specified, then the decorated function
must return `None` and no implicit `ReturnOp` is added (nor are the result
types updated). The implicit behavior is intended for simple, single-block
cases, and users should specify result types explicitly for any complicated
cases.

The decorated function can further be called from Python and will insert
a `CallOp` at the then-current insertion point, returning either None
(if no return values), a unary Value (for one result), or a list of Values.
This mechanism cannot be used to emit recursive calls (by construction).
"""

def decorator(f):
from . import func
# Introspect the callable for optional features.
sig = inspect.signature(f)
has_arg_func_op = False
for param in sig.parameters.values():
if param.kind == param.VAR_KEYWORD:
has_arg_func_op = True
if param.name == "func_op" and (param.kind
== param.POSITIONAL_OR_KEYWORD or
param.kind == param.KEYWORD_ONLY):
has_arg_func_op = True

# Emit the FuncOp.
implicit_return = results is None
symbol_name = name or f.__name__
function_type = FunctionType.get(
inputs=inputs, results=[] if implicit_return else results)
func_op = FuncOp(name=symbol_name, type=function_type)
with InsertionPoint(func_op.add_entry_block()):
func_args = func_op.entry_block.arguments
func_kwargs = {}
if has_arg_func_op:
func_kwargs["func_op"] = func_op
return_values = f(*func_args, **func_kwargs)
if not implicit_return:
return_types = list(results)
assert return_values is None, (
"Capturing a python function with explicit `results=` "
"requires that the wrapped function returns None.")
@arg_attrs.setter
def arg_attrs(self, attribute: Union[ArrayAttr, list]):
if isinstance(attribute, ArrayAttr):
self.attributes[ARGUMENT_ATTRIBUTE_NAME] = attribute
else:
# Coerce return values, add ReturnOp and rewrite func type.
if return_values is None:
return_values = []
elif isinstance(return_values, tuple):
return_values = list(return_values)
elif isinstance(return_values, Value):
# Returning a single value is fine, coerce it into a list.
return_values = [return_values]
elif isinstance(return_values, OpView):
# Returning a single operation is fine, coerce its results into a list.
return_values = return_values.operation.results
elif isinstance(return_values, Operation):
# Returning a single operation is fine, coerce its results into a list.
return_values = return_values.results
else:
return_values = list(return_values)
func.ReturnOp(return_values)
# Recompute the function type.
return_types = [v.type for v in return_values]
function_type = FunctionType.get(inputs=inputs, results=return_types)
func_op.attributes["function_type"] = TypeAttr.get(function_type)
self.attributes[ARGUMENT_ATTRIBUTE_NAME] = ArrayAttr.get(
attribute, context=self.context
)

def emit_call_op(*call_args):
call_op = func.CallOp(return_types, FlatSymbolRefAttr.get(symbol_name),
call_args)
if return_types is None:
return None
elif len(return_types) == 1:
return call_op.result
else:
return call_op.results
@property
def arguments(self):
return self.entry_block.arguments

wrapped = emit_call_op
wrapped.__name__ = f.__name__
wrapped.func_op = func_op
return wrapped
@property
def result_attrs(self):
return self.attributes[RESULT_ATTRIBUTE_NAME]

@result_attrs.setter
def result_attrs(self, attribute: ArrayAttr):
self.attributes[RESULT_ATTRIBUTE_NAME] = attribute

@classmethod
def from_py_func(
FuncOp,
*inputs: Type,
results: Optional[Sequence[Type]] = None,
name: Optional[str] = None,
):
"""Decorator to define an MLIR FuncOp specified as a python function.

Requires that an `mlir.ir.InsertionPoint` and `mlir.ir.Location` are
active for the current thread (i.e. established in a `with` block).

When applied as a decorator to a Python function, an entry block will
be constructed for the FuncOp with types as specified in `*inputs`. The
block arguments will be passed positionally to the Python function. In
addition, if the Python function accepts keyword arguments generally or
has a corresponding keyword argument, the following will be passed:
* `func_op`: The `func` op being defined.

By default, the function name will be the Python function `__name__`. This
can be overridden by passing the `name` argument to the decorator.

If `results` is not specified, then the decorator will implicitly
insert a `ReturnOp` with the `Value`'s returned from the decorated
function. It will also set the `FuncOp` type with the actual return
value types. If `results` is specified, then the decorated function
must return `None` and no implicit `ReturnOp` is added (nor are the result
types updated). The implicit behavior is intended for simple, single-block
cases, and users should specify result types explicitly for any complicated
cases.

The decorated function can further be called from Python and will insert
a `CallOp` at the then-current insertion point, returning either None
(if no return values), a unary Value (for one result), or a list of Values.
This mechanism cannot be used to emit recursive calls (by construction).
"""

def decorator(f):
from . import func

# Introspect the callable for optional features.
sig = inspect.signature(f)
has_arg_func_op = False
for param in sig.parameters.values():
if param.kind == param.VAR_KEYWORD:
has_arg_func_op = True
if param.name == "func_op" and (
param.kind == param.POSITIONAL_OR_KEYWORD
or param.kind == param.KEYWORD_ONLY
):
has_arg_func_op = True

# Emit the FuncOp.
implicit_return = results is None
symbol_name = name or f.__name__
function_type = FunctionType.get(
inputs=inputs, results=[] if implicit_return else results
)
func_op = FuncOp(name=symbol_name, type=function_type)
with InsertionPoint(func_op.add_entry_block()):
func_args = func_op.entry_block.arguments
func_kwargs = {}
if has_arg_func_op:
func_kwargs["func_op"] = func_op
return_values = f(*func_args, **func_kwargs)
if not implicit_return:
return_types = list(results)
assert return_values is None, (
"Capturing a python function with explicit `results=` "
"requires that the wrapped function returns None."
)
else:
# Coerce return values, add ReturnOp and rewrite func type.
if return_values is None:
return_values = []
elif isinstance(return_values, tuple):
return_values = list(return_values)
elif isinstance(return_values, Value):
# Returning a single value is fine, coerce it into a list.
return_values = [return_values]
elif isinstance(return_values, OpView):
# Returning a single operation is fine, coerce its results into a list.
return_values = return_values.operation.results
elif isinstance(return_values, Operation):
# Returning a single operation is fine, coerce its results into a list.
return_values = return_values.results
else:
return_values = list(return_values)
func.ReturnOp(return_values)
# Recompute the function type.
return_types = [v.type for v in return_values]
function_type = FunctionType.get(
inputs=inputs, results=return_types
)
func_op.attributes["function_type"] = TypeAttr.get(function_type)

def emit_call_op(*call_args):
call_op = func.CallOp(
return_types, FlatSymbolRefAttr.get(symbol_name), call_args
)
if return_types is None:
return None
elif len(return_types) == 1:
return call_op.result
else:
return call_op.results

wrapped = emit_call_op
wrapped.__name__ = f.__name__
wrapped.func_op = func_op
return wrapped

return decorator

return decorator

class CallOp:
"""Specialization for the call op class."""
"""Specialization for the call op class."""

def __init__(self,
calleeOrResults: Union[FuncOp, List[Type]],
argumentsOrCallee: Union[List, FlatSymbolRefAttr, str],
arguments: Optional[List] = None,
*,
loc=None,
ip=None):
"""Creates a call operation.
def __init__(
self,
calleeOrResults: Union[FuncOp, List[Type]],
argumentsOrCallee: Union[List, FlatSymbolRefAttr, str],
arguments: Optional[List] = None,
*,
loc=None,
ip=None,
):
"""Creates a call operation.

The constructor accepts three different forms:
The constructor accepts three different forms:

1. A function op to be called followed by a list of arguments.
2. A list of result types, followed by the name of the function to be
called as a string, followed by a list of arguments.
3. A list of result types, followed by the name of the function to be
called as a symbol reference attribute, followed by a list of arguments.
1. A function op to be called followed by a list of arguments.
2. A list of result types, followed by the name of the function to be
called as a string, followed by a list of arguments.
3. A list of result types, followed by the name of the function to be
called as a symbol reference attribute, followed by a list of arguments.

For example
For example

f = func.FuncOp("foo", ...)
func.CallOp(f, [args])
func.CallOp([result_types], "foo", [args])
f = func.FuncOp("foo", ...)
func.CallOp(f, [args])
func.CallOp([result_types], "foo", [args])

In all cases, the location and insertion point may be specified as keyword
arguments if not provided by the surrounding context managers.
"""
In all cases, the location and insertion point may be specified as keyword
arguments if not provided by the surrounding context managers.
"""

# TODO: consider supporting constructor "overloads", e.g., through a custom
# or pybind-provided metaclass.
if isinstance(calleeOrResults, FuncOp):
if not isinstance(argumentsOrCallee, list):
raise ValueError(
"when constructing a call to a function, expected " +
"the second argument to be a list of call arguments, " +
f"got {type(argumentsOrCallee)}")
if arguments is not None:
raise ValueError("unexpected third argument when constructing a call " +
"to a function")
# TODO: consider supporting constructor "overloads", e.g., through a custom
# or pybind-provided metaclass.
if isinstance(calleeOrResults, FuncOp):
if not isinstance(argumentsOrCallee, list):
raise ValueError(
"when constructing a call to a function, expected "
+ "the second argument to be a list of call arguments, "
+ f"got {type(argumentsOrCallee)}"
)
if arguments is not None:
raise ValueError(
"unexpected third argument when constructing a call "
+ "to a function"
)

super().__init__(
calleeOrResults.type.results,
FlatSymbolRefAttr.get(
calleeOrResults.name.value,
context=_get_default_loc_context(loc)),
argumentsOrCallee,
loc=loc,
ip=ip)
return
super().__init__(
calleeOrResults.type.results,
FlatSymbolRefAttr.get(
calleeOrResults.name.value, context=_get_default_loc_context(loc)
),
argumentsOrCallee,
loc=loc,
ip=ip,
)
return

if isinstance(argumentsOrCallee, list):
raise ValueError("when constructing a call to a function by name, " +
"expected the second argument to be a string or a " +
f"FlatSymbolRefAttr, got {type(argumentsOrCallee)}")
if isinstance(argumentsOrCallee, list):
raise ValueError(
"when constructing a call to a function by name, "
+ "expected the second argument to be a string or a "
+ f"FlatSymbolRefAttr, got {type(argumentsOrCallee)}"
)

if isinstance(argumentsOrCallee, FlatSymbolRefAttr):
super().__init__(
calleeOrResults, argumentsOrCallee, arguments, loc=loc, ip=ip)
elif isinstance(argumentsOrCallee, str):
super().__init__(
calleeOrResults,
FlatSymbolRefAttr.get(
argumentsOrCallee, context=_get_default_loc_context(loc)),
arguments,
loc=loc,
ip=ip)
if isinstance(argumentsOrCallee, FlatSymbolRefAttr):
super().__init__(
calleeOrResults, argumentsOrCallee, arguments, loc=loc, ip=ip
)
elif isinstance(argumentsOrCallee, str):
super().__init__(
calleeOrResults,
FlatSymbolRefAttr.get(
argumentsOrCallee, context=_get_default_loc_context(loc)
),
arguments,
loc=loc,
ip=ip,
)
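The `from_py_func` and `CallOp` docstrings above are easiest to read next to a usage sketch; this one mirrors the decorator use in the benchmark file at the top of this diff (the function names are hypothetical):

    from mlir import ir
    from mlir.dialects import func

    with ir.Context(), ir.Location.unknown():
        module = ir.Module.create()
        with ir.InsertionPoint(module.body):
            i32 = ir.IntegerType.get_signless(32)

            @func.FuncOp.from_py_func(i32)
            def identity(x):
                return x

            @func.FuncOp.from_py_func(i32)
            def caller(x):
                # Calling the decorated function emits a func.CallOp at the
                # current insertion point (form 1 of the CallOp constructor).
                return identity(x)

        print(module)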
@@ -3,39 +3,45 @@
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

try:
from typing import Optional, Sequence, Union
from ..ir import *
from ._ods_common import get_default_loc_context
from .._mlir_libs._mlirDialectsLinalg import fill_builtin_region
from typing import Optional, Sequence, Union
from ..ir import *
from ._ods_common import get_default_loc_context
from .._mlir_libs._mlirDialectsLinalg import fill_builtin_region
except ImportError as e:
raise RuntimeError("Error loading imports from extension module") from e
raise RuntimeError("Error loading imports from extension module") from e

from ._ods_common import get_op_result_or_value as _get_op_result_or_value


def isa(cls: Type, ty: Type):
try:
cls(ty)
return True
except ValueError:
return False
try:
cls(ty)
return True
except ValueError:
return False


class StructuredOpMixin:
"""All structured ops use the same mixin class."""
"""All structured ops use the same mixin class."""

def __init__(self, inputs, outputs=(), results=(), loc=None, ip=None):
super().__init__(
self.build_generic(results=list(results),
operands=[list(inputs), list(outputs)],
loc=loc,
ip=ip))
def __init__(self, inputs, outputs=(), results=(), loc=None, ip=None):
super().__init__(
self.build_generic(
results=list(results),
operands=[list(inputs), list(outputs)],
loc=loc,
ip=ip,
)
)


def select_opview_mixin(parent_opview_cls):
# TODO: This shouldn't be a heuristic: we should have a way to annotate
# the OpView to note that it is a structured op.
if ("__init__" not in parent_opview_cls.__dict__ and
hasattr(parent_opview_cls, "inputs") and
hasattr(parent_opview_cls, "outputs") and
hasattr(parent_opview_cls, "result_tensors")):
return StructuredOpMixin
# TODO: This shouldn't be a heuristic: we should have a way to annotate
# the OpView to note that it is a structured op.
if (
"__init__" not in parent_opview_cls.__dict__
and hasattr(parent_opview_cls, "inputs")
and hasattr(parent_opview_cls, "outputs")
and hasattr(parent_opview_cls, "result_tensors")
):
return StructuredOpMixin
@@ -3,125 +3,130 @@
|
||||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
|
||||
try:
|
||||
from ..ir import *
|
||||
from ._ods_common import get_op_result_or_value as _get_op_result_or_value
|
||||
from ..ir import *
|
||||
from ._ods_common import get_op_result_or_value as _get_op_result_or_value
|
||||
except ImportError as e:
|
||||
raise RuntimeError("Error loading imports from extension module") from e
|
||||
raise RuntimeError("Error loading imports from extension module") from e
|
||||
|
||||
from typing import Optional, Union
|
||||
|
||||
|
||||
class GetParentForOp:
|
||||
"""Extension for GetParentForOp."""
|
||||
"""Extension for GetParentForOp."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
result_type: Type,
|
||||
target: Union[Operation, Value],
|
||||
*,
|
||||
num_loops: Optional[int] = None,
|
||||
ip=None,
|
||||
loc=None,
|
||||
):
|
||||
if num_loops is None:
|
||||
num_loops = 1
|
||||
super().__init__(
|
||||
result_type,
|
||||
_get_op_result_or_value(target),
|
||||
num_loops=num_loops,
|
||||
ip=ip,
|
||||
loc=loc,
|
||||
)
|
||||
def __init__(
|
||||
self,
|
||||
result_type: Type,
|
||||
target: Union[Operation, Value],
|
||||
*,
|
||||
num_loops: Optional[int] = None,
|
||||
ip=None,
|
||||
loc=None,
|
||||
):
|
||||
if num_loops is None:
|
||||
num_loops = 1
|
||||
super().__init__(
|
||||
result_type,
|
||||
_get_op_result_or_value(target),
|
||||
num_loops=num_loops,
|
||||
ip=ip,
|
||||
loc=loc,
|
||||
)
|
||||
|
||||
|
||||
class LoopOutlineOp:
|
||||
"""Extension for LoopOutlineOp."""
|
||||
"""Extension for LoopOutlineOp."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
function_type: Type,
|
||||
call_type: Type,
|
||||
target: Union[Operation, Value],
|
||||
*,
|
||||
func_name: Union[str, StringAttr],
|
||||
ip=None,
|
||||
loc=None,
|
||||
):
|
||||
super().__init__(
|
||||
function_type,
|
||||
call_type,
|
||||
_get_op_result_or_value(target),
|
||||
func_name=(func_name if isinstance(func_name, StringAttr) else
|
||||
StringAttr.get(func_name)),
|
||||
ip=ip,
|
||||
loc=loc,
|
||||
)
|
||||
def __init__(
|
||||
self,
|
||||
function_type: Type,
|
||||
call_type: Type,
|
||||
target: Union[Operation, Value],
|
||||
*,
|
||||
func_name: Union[str, StringAttr],
|
||||
ip=None,
|
||||
loc=None,
|
||||
):
|
||||
super().__init__(
|
||||
function_type,
|
||||
call_type,
|
||||
_get_op_result_or_value(target),
|
||||
func_name=(
|
||||
func_name
|
||||
if isinstance(func_name, StringAttr)
|
||||
else StringAttr.get(func_name)
|
||||
),
|
||||
ip=ip,
|
||||
loc=loc,
|
||||
)
|
||||
|
||||
|
||||
class LoopPeelOp:
|
||||
"""Extension for LoopPeelOp."""
|
||||
"""Extension for LoopPeelOp."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
result_type: Type,
|
||||
target: Union[Operation, Value],
|
||||
*,
|
||||
fail_if_already_divisible: Union[bool, BoolAttr] = False,
|
||||
ip=None,
|
||||
loc=None,
|
||||
):
|
||||
super().__init__(
|
||||
result_type,
|
||||
_get_op_result_or_value(target),
|
||||
fail_if_already_divisible=(fail_if_already_divisible if isinstance(
|
||||
fail_if_already_divisible, BoolAttr) else
|
||||
BoolAttr.get(fail_if_already_divisible)),
|
||||
ip=ip,
|
||||
loc=loc,
|
||||
)
|
||||
def __init__(
|
||||
self,
|
||||
result_type: Type,
|
||||
target: Union[Operation, Value],
|
||||
*,
|
||||
fail_if_already_divisible: Union[bool, BoolAttr] = False,
|
||||
ip=None,
|
||||
loc=None,
|
||||
):
|
||||
super().__init__(
|
||||
result_type,
|
||||
_get_op_result_or_value(target),
|
||||
fail_if_already_divisible=(
|
||||
fail_if_already_divisible
|
||||
if isinstance(fail_if_already_divisible, BoolAttr)
|
||||
else BoolAttr.get(fail_if_already_divisible)
|
||||
),
|
||||
ip=ip,
|
||||
loc=loc,
|
||||
)


class LoopPipelineOp:
    """Extension for LoopPipelineOp."""

    def __init__(
        self,
        result_type: Type,
        target: Union[Operation, Value],
        *,
        iteration_interval: Optional[Union[int, IntegerAttr]] = None,
        read_latency: Optional[Union[int, IntegerAttr]] = None,
        ip=None,
        loc=None,
    ):
        if iteration_interval is None:
            iteration_interval = 1
        if read_latency is None:
            read_latency = 10
        super().__init__(
            result_type,
            _get_op_result_or_value(target),
            iteration_interval=iteration_interval,
            read_latency=read_latency,
            ip=ip,
            loc=loc,
        )


class LoopUnrollOp:
    """Extension for LoopUnrollOp."""

    def __init__(
        self,
        target: Union[Operation, Value],
        *,
        factor: Union[int, IntegerAttr],
        ip=None,
        loc=None,
    ):
        super().__init__(
            _get_op_result_or_value(target),
            factor=factor,
            ip=ip,
            loc=loc,
        )

@@ -3,34 +3,34 @@
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

try:
    from ..ir import *
    from ._ods_common import get_op_result_or_value as _get_op_result_or_value
    from ._ods_common import get_op_results_or_values as _get_op_results_or_values
except ImportError as e:
    raise RuntimeError("Error loading imports from extension module") from e

from typing import Optional, Sequence, Union


class LoadOp:
    """Specialization for the MemRef load operation."""

    def __init__(
        self,
        memref: Union[Operation, OpView, Value],
        indices: Optional[Union[Operation, OpView, Sequence[Value]]] = None,
        *,
        loc=None,
        ip=None
    ):
        """Creates a memref load operation.

        Args:
          memref: the buffer to load from.
          indices: the list of subscripts, may be empty for zero-dimensional
            buffers.
          loc: user-visible location of the operation.
          ip: insertion point.
        """
        indices_resolved = [] if indices is None else _get_op_results_or_values(indices)
        super().__init__(memref, indices_resolved, loc=loc, ip=ip)
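
A minimal usage sketch (not part of the diff; `buf`, `i`, and `j` are hypothetical SSA values, a memref and two index-typed subscripts, assumed to exist under an active Context and Location):

    value = LoadOp(buf, [i, j]).result  # loads buf[i, j]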

@@ -3,11 +3,11 @@
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

try:
    from typing import Union
    from ..ir import *
    from ._ods_common import get_default_loc_context as _get_default_loc_context
except ImportError as e:
    raise RuntimeError("Error loading imports from extension module") from e

from ._ml_program_ops_gen import *

@@ -17,100 +17,97 @@ RESULT_ATTRIBUTE_NAME = "res_attrs"


class FuncOp:
    """Specialization for the func op class."""

    def __init__(
        self, name, type, *, visibility=None, body_builder=None, loc=None, ip=None
    ):
        """
        Create a FuncOp with the provided `name`, `type`, and `visibility`.
        - `name` is a string representing the function name.
        - `type` is either a FunctionType or a pair of lists describing inputs and
          results.
        - `visibility` is a string matching `public`, `private`, or `nested`. None
          implies private visibility.
        - `body_builder` is an optional callback; when provided, a new entry block
          is created and the callback is invoked with the new op as argument within
          an InsertionPoint context already set for the block. The callback is
          expected to insert a terminator in the block.
        """
        sym_name = StringAttr.get(str(name))

        # If the type is passed as a tuple, build a FunctionType on the fly.
        if isinstance(type, tuple):
            type = FunctionType.get(inputs=type[0], results=type[1])

        type = TypeAttr.get(type)
        sym_visibility = (
            StringAttr.get(str(visibility)) if visibility is not None else None
        )
        super().__init__(sym_name, type, sym_visibility=sym_visibility, loc=loc, ip=ip)
        if body_builder:
            entry_block = self.add_entry_block()
            with InsertionPoint(entry_block):
                body_builder(self)

    @property
    def is_external(self):
        return len(self.regions[0].blocks) == 0

    @property
    def body(self):
        return self.regions[0]

    @property
    def type(self):
        return FunctionType(TypeAttr(self.attributes["function_type"]).value)

    @property
    def visibility(self):
        return self.attributes["sym_visibility"]

    @property
    def name(self) -> StringAttr:
        return StringAttr(self.attributes["sym_name"])

    @property
    def entry_block(self):
        if self.is_external:
            raise IndexError("External function does not have a body")
        return self.regions[0].blocks[0]

    def add_entry_block(self):
        """
        Add an entry block to the function body using the function signature to
        infer block arguments.
        Returns the newly created block.
        """
        if not self.is_external:
            raise IndexError("The function already has an entry block!")
        self.body.blocks.append(*self.type.inputs)
        return self.body.blocks[0]

    @property
    def arg_attrs(self):
        return ArrayAttr(self.attributes[ARGUMENT_ATTRIBUTE_NAME])

    @arg_attrs.setter
    def arg_attrs(self, attribute: Union[ArrayAttr, list]):
        if isinstance(attribute, ArrayAttr):
            self.attributes[ARGUMENT_ATTRIBUTE_NAME] = attribute
        else:
            self.attributes[ARGUMENT_ATTRIBUTE_NAME] = ArrayAttr.get(
                attribute, context=self.context
            )

    @property
    def arguments(self):
        return self.entry_block.arguments

    @property
    def result_attrs(self):
        return self.attributes[RESULT_ATTRIBUTE_NAME]

    @result_attrs.setter
    def result_attrs(self, attribute: ArrayAttr):
        self.attributes[RESULT_ATTRIBUTE_NAME] = attribute
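
A hedged construction sketch (hypothetical; assumes an active Context/Location, an insertion point inside a module, and a ReturnOp terminator available from the same dialect):

    f32 = F32Type.get()

    def build_body(op):
        # The callback must insert the terminator itself.
        ReturnOp(list(op.arguments))

    func = FuncOp("identity", ([f32], [f32]), visibility="public", body_builder=build_body)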

@@ -18,144 +18,152 @@ __all__ = [


def extend_opview_class(ext_module):
    """Decorator to extend an OpView class from an extension module.

    Extension modules can expose various entry-points:
      Stand-alone class with the same name as a parent OpView class (i.e.
      "ReturnOp"). A name-based match is attempted first before falling back
      to the mechanism below.

      def select_opview_mixin(parent_opview_cls):
        If defined, allows an appropriate mixin class to be selected dynamically
        based on the parent OpView class. Should return NotImplemented if a
        decision is not made.

    Args:
      ext_module: A module from which to locate extensions. Can be None if not
        available.

    Returns:
      A decorator that takes an OpView subclass and further extends it as
      needed.
    """

    def class_decorator(parent_opview_cls: type):
        if ext_module is None:
            return parent_opview_cls
        mixin_cls = NotImplemented
        # First try to resolve by name.
        try:
            mixin_cls = getattr(ext_module, parent_opview_cls.__name__)
        except AttributeError:
            # Fall back to a select_opview_mixin hook.
            try:
                select_mixin = getattr(ext_module, "select_opview_mixin")
            except AttributeError:
                pass
            else:
                mixin_cls = select_mixin(parent_opview_cls)

        if mixin_cls is NotImplemented or mixin_cls is None:
            return parent_opview_cls

        # Have a mixin_cls. Create an appropriate subclass.
        try:

            class LocalOpView(mixin_cls, parent_opview_cls):
                pass

        except TypeError as e:
            raise TypeError(
                f"Could not mixin {mixin_cls} into {parent_opview_cls}"
            ) from e
        LocalOpView.__name__ = parent_opview_cls.__name__
        LocalOpView.__qualname__ = parent_opview_cls.__qualname__
        return LocalOpView

    return class_decorator
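
A sketch of the name-based entry-point (module and class names are hypothetical):

    import types

    ext = types.ModuleType("my_ext")

    class ConstantOp:  # matched against a generated OpView also named ConstantOp
        def describe(self):
            return "constant op mixin"

    ext.ConstantOp = ConstantOp
    decorate = extend_opview_class(ext)
    # decorate(GeneratedConstantOpView) would return a subclass carrying `describe`.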


def segmented_accessor(elements, raw_segments, idx):
    """
    Returns a slice of elements corresponding to the idx-th segment.

      elements: a sliceable container (operands or results).
      raw_segments: an mlir.ir.Attribute, of DenseI32Array subclass containing
        sizes of the segments.
      idx: index of the segment.
    """
    segments = _cext.ir.DenseI32ArrayAttr(raw_segments)
    start = sum(segments[i] for i in range(idx))
    end = start + segments[idx]
    return elements[start:end]
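
A worked example of the arithmetic (assumed segment sizes):

    # segments = [2, 0, 3], idx = 2  ->  start = 2 + 0 = 2, end = 5,
    # so the call returns elements[2:5].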


def equally_sized_accessor(
    elements, n_variadic, n_preceding_simple, n_preceding_variadic
):
    """
    Returns a starting position and a number of elements per variadic group
    assuming equally-sized groups and the given numbers of preceding groups.

      elements: a sequential container.
      n_variadic: the number of variadic groups in the container.
      n_preceding_simple: the number of non-variadic groups preceding the current
        group.
      n_preceding_variadic: the number of variadic groups preceding the current
        group.
    """

    total_variadic_length = len(elements) - n_variadic + 1
    # This should be enforced by the C++-side trait verifier.
    assert total_variadic_length % n_variadic == 0

    elements_per_group = total_variadic_length // n_variadic
    start = n_preceding_simple + n_preceding_variadic * elements_per_group
    return start, elements_per_group
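
A worked example (assumed inputs):

    # len(elements) = 7, n_variadic = 2  ->  total_variadic_length = 6,
    # elements_per_group = 3; with n_preceding_simple = 1 and
    # n_preceding_variadic = 1, start = 1 + 1 * 3 = 4, returning (4, 3).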


def get_default_loc_context(location=None):
    """
    Returns a context in which the defaulted location is created. If the location
    is None, takes the current location from the stack, raises ValueError if there
    is no location on the stack.
    """
    if location is None:
        # Location.current raises ValueError if there is no current location.
        return _cext.ir.Location.current.context
    return location.context


def get_op_result_or_value(
    arg: _Union[
        _cext.ir.OpView, _cext.ir.Operation, _cext.ir.Value, _cext.ir.OpResultList
    ]
) -> _cext.ir.Value:
    """Returns the given value or the single result of the given op.

    This is useful to implement op constructors so that they can take other ops as
    arguments instead of requiring the caller to extract results for every op.
    Raises ValueError if provided with an op that doesn't have a single result.
    """
    if isinstance(arg, _cext.ir.OpView):
        return arg.operation.result
    elif isinstance(arg, _cext.ir.Operation):
        return arg.result
    elif isinstance(arg, _cext.ir.OpResultList):
        return arg[0]
    else:
        assert isinstance(arg, _cext.ir.Value)
        return arg
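
In short (hypothetical single-result op `op` and value `v`):

    # get_op_result_or_value(op) -> op.result
    # get_op_result_or_value(v)  -> v (passed through unchanged)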


def get_op_results_or_values(
    arg: _Union[
        _cext.ir.OpView,
        _cext.ir.Operation,
        _Sequence[_Union[_cext.ir.OpView, _cext.ir.Operation, _cext.ir.Value]],
    ]
) -> _Union[_Sequence[_cext.ir.Value], _cext.ir.OpResultList]:
    """Returns the given sequence of values or the results of the given op.

    This is useful to implement op constructors so that they can take other ops as
    lists of arguments instead of requiring the caller to extract results for
    every op.
    """
    if isinstance(arg, _cext.ir.OpView):
        return arg.operation.results
    elif isinstance(arg, _cext.ir.Operation):
        return arg.results
    else:
        return [get_op_result_or_value(element) for element in arg]
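
Mixed sequences are flattened element-wise (hypothetical value `v` and single-result op `op`):

    # get_op_results_or_values([v, op]) -> [v, op.result]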

@@ -3,10 +3,10 @@
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

try:
    from ..ir import *
    from ..dialects import pdl
except ImportError as e:
    raise RuntimeError("Error loading imports from extension module") from e

from typing import Union, Optional, Sequence, Mapping
from ._ods_common import (
@@ -16,264 +16,256 @@ from ._ods_common import (


class ApplyNativeConstraintOp:
    """Specialization for PDL apply native constraint op class."""

    def __init__(
        self,
        name: Union[str, StringAttr],
        args: Optional[Sequence[Union[OpView, Operation, Value]]] = None,
        *,
        loc=None,
        ip=None,
    ):
        if args is None:
            args = []
        args = _get_values(args)
        super().__init__(name, args, loc=loc, ip=ip)


class ApplyNativeRewriteOp:
    """Specialization for PDL apply native rewrite op class."""

    def __init__(
        self,
        results: Sequence[Type],
        name: Union[str, StringAttr],
        args: Optional[Sequence[Union[OpView, Operation, Value]]] = None,
        *,
        loc=None,
        ip=None,
    ):
        if args is None:
            args = []
        args = _get_values(args)
        super().__init__(results, name, args, loc=loc, ip=ip)


class AttributeOp:
    """Specialization for PDL attribute op class."""

    def __init__(
        self,
        valueType: Optional[Union[OpView, Operation, Value]] = None,
        value: Optional[Attribute] = None,
        *,
        loc=None,
        ip=None,
    ):
        valueType = valueType if valueType is None else _get_value(valueType)
        result = pdl.AttributeType.get()
        super().__init__(result, valueType=valueType, value=value, loc=loc, ip=ip)


class EraseOp:
    """Specialization for PDL erase op class."""

    def __init__(
        self,
        operation: Optional[Union[OpView, Operation, Value]] = None,
        *,
        loc=None,
        ip=None,
    ):
        operation = _get_value(operation)
        super().__init__(operation, loc=loc, ip=ip)


class OperandOp:
    """Specialization for PDL operand op class."""

    def __init__(
        self,
        type: Optional[Union[OpView, Operation, Value]] = None,
        *,
        loc=None,
        ip=None,
    ):
        type = type if type is None else _get_value(type)
        result = pdl.ValueType.get()
        super().__init__(result, valueType=type, loc=loc, ip=ip)


class OperandsOp:
    """Specialization for PDL operands op class."""

    def __init__(
        self,
        types: Optional[Union[OpView, Operation, Value]] = None,
        *,
        loc=None,
        ip=None,
    ):
        types = types if types is None else _get_value(types)
        result = pdl.RangeType.get(pdl.ValueType.get())
        super().__init__(result, valueType=types, loc=loc, ip=ip)


class OperationOp:
    """Specialization for PDL operation op class."""

    def __init__(
        self,
        name: Optional[Union[str, StringAttr]] = None,
        args: Optional[Sequence[Union[OpView, Operation, Value]]] = None,
        attributes: Optional[Mapping[str, Union[OpView, Operation, Value]]] = None,
        types: Optional[Sequence[Union[OpView, Operation, Value]]] = None,
        *,
        loc=None,
        ip=None,
    ):
        if types is None:
            types = []
        if attributes is None:
            attributes = {}
        if args is None:
            args = []
        args = _get_values(args)
        attrNames = []
        attrValues = []
        for attrName, attrValue in attributes.items():
            attrNames.append(StringAttr.get(attrName))
            attrValues.append(_get_value(attrValue))
        attrNames = ArrayAttr.get(attrNames)
        types = _get_values(types)
        result = pdl.OperationType.get()
        super().__init__(
            result, args, attrValues, attrNames, types, opName=name, loc=loc, ip=ip
        )


class PatternOp:
    """Specialization for PDL pattern op class."""

    def __init__(
        self,
        benefit: Union[IntegerAttr, int],
        name: Optional[Union[StringAttr, str]] = None,
        *,
        loc=None,
        ip=None,
    ):
        """Creates a PDL `pattern` operation."""
        super().__init__(benefit, sym_name=name, loc=loc, ip=ip)
        self.regions[0].blocks.append()

    @property
    def body(self):
        """Return the body (block) of the pattern."""
        return self.regions[0].blocks[0]


class ReplaceOp:
    """Specialization for PDL replace op class."""

    def __init__(
        self,
        op: Union[OpView, Operation, Value],
        *,
        with_op: Optional[Union[OpView, Operation, Value]] = None,
        with_values: Optional[Sequence[Union[OpView, Operation, Value]]] = None,
        loc=None,
        ip=None,
    ):
        if with_values is None:
            with_values = []
        op = _get_value(op)
        with_op = with_op if with_op is None else _get_value(with_op)
        with_values = _get_values(with_values)
        super().__init__(op, with_values, replOperation=with_op, loc=loc, ip=ip)


class ResultOp:
    """Specialization for PDL result op class."""

    def __init__(
        self,
        parent: Union[OpView, Operation, Value],
        index: Union[IntegerAttr, int],
        *,
        loc=None,
        ip=None,
    ):
        parent = _get_value(parent)
        result = pdl.ValueType.get()
        super().__init__(result, parent, index, loc=loc, ip=ip)


class ResultsOp:
    """Specialization for PDL results op class."""

    def __init__(
        self,
        result: Type,
        parent: Union[OpView, Operation, Value],
        index: Optional[Union[IntegerAttr, int]] = None,
        *,
        loc=None,
        ip=None,
    ):
        parent = _get_value(parent)
        super().__init__(result, parent, index=index, loc=loc, ip=ip)


class RewriteOp:
    """Specialization for PDL rewrite op class."""

    def __init__(
        self,
        root: Optional[Union[OpView, Operation, Value]] = None,
        name: Optional[Union[StringAttr, str]] = None,
        args: Optional[Sequence[Union[OpView, Operation, Value]]] = None,
        *,
        loc=None,
        ip=None,
    ):
        if args is None:
            args = []
        root = root if root is None else _get_value(root)
        args = _get_values(args)
        super().__init__(args, root=root, name=name, loc=loc, ip=ip)

    def add_body(self):
        """Add body (block) to the rewrite."""
        self.regions[0].blocks.append()
        return self.body

    @property
    def body(self):
        """Return the body (block) of the rewrite."""
        return self.regions[0].blocks[0]


class TypeOp:
    """Specialization for PDL type op class."""

    def __init__(
        self, constantType: Optional[Union[TypeAttr, Type]] = None, *, loc=None, ip=None
    ):
        result = pdl.TypeType.get()
        super().__init__(result, constantType=constantType, loc=loc, ip=ip)


class TypesOp:
    """Specialization for PDL types op class."""

    def __init__(
        self,
        constantTypes: Optional[Sequence[Union[TypeAttr, Type]]] = None,
        *,
        loc=None,
        ip=None,
    ):
        if constantTypes is None:
            constantTypes = []
        result = pdl.RangeType.get(pdl.TypeType.get())
        super().__init__(result, constantTypes=constantTypes, loc=loc, ip=ip)
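
A hedged sketch of how these PDL builders compose into a pattern (the op name "my.op" and rewriter name "my.rewriter" are illustrative; assumes an insertion point inside a module):

    pattern = PatternOp(1, "example")
    with InsertionPoint(pattern.body):
        ty = TypeOp()
        operand = OperandOp(ty)
        root = OperationOp(name="my.op", args=[operand], types=[ty])
        RewriteOp(root, name="my.rewriter")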

@@ -3,105 +3,104 @@
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

try:
    from ..ir import *
except ImportError as e:
    raise RuntimeError("Error loading imports from extension module") from e

from typing import Any, Optional, Sequence, Union
from ._ods_common import (
    get_op_result_or_value as _get_op_result_or_value,
    get_op_results_or_values as _get_op_results_or_values,
)


class ForOp:
    """Specialization for the SCF for op class."""

    def __init__(
        self,
        lower_bound,
        upper_bound,
        step,
        iter_args: Optional[Union[Operation, OpView, Sequence[Value]]] = None,
        *,
        loc=None,
        ip=None
    ):
        """Creates an SCF `for` operation.

        - `lower_bound` is the value to use as lower bound of the loop.
        - `upper_bound` is the value to use as upper bound of the loop.
        - `step` is the value to use as loop step.
        - `iter_args` is a list of additional loop-carried arguments or an operation
          producing them as results.
        """
        if iter_args is None:
            iter_args = []
        iter_args = _get_op_results_or_values(iter_args)

        results = [arg.type for arg in iter_args]
        super().__init__(
            self.build_generic(
                regions=1,
                results=results,
                operands=[
                    _get_op_result_or_value(o) for o in [lower_bound, upper_bound, step]
                ]
                + list(iter_args),
                loc=loc,
                ip=ip,
            )
        )
        self.regions[0].blocks.append(IndexType.get(), *results)

    @property
    def body(self):
        """Returns the body (block) of the loop."""
        return self.regions[0].blocks[0]

    @property
    def induction_variable(self):
        """Returns the induction variable of the loop."""
        return self.body.arguments[0]

    @property
    def inner_iter_args(self):
        """Returns the loop-carried arguments usable within the loop.

        To obtain the loop-carried operands, use `iter_args`.
        """
        return self.body.arguments[1:]
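
A hedged usage sketch (hypothetical index-typed values lb, ub, step and a loop-carried value acc0):

    loop = ForOp(lb, ub, step, iter_args=[acc0])
    with InsertionPoint(loop.body):
        i = loop.induction_variable
        acc = loop.inner_iter_args[0]
        # ... compute the next accumulator and finish with an scf.yield ...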


class IfOp:
    """Specialization for the SCF if op class."""

    def __init__(self, cond, results_=[], *, hasElse=False, loc=None, ip=None):
        """Creates an SCF `if` operation.

        - `cond` is an MLIR value of 'i1' type that determines which region of code will be executed.
        - `hasElse` determines whether the if operation has the else branch.
        """
        operands = []
        operands.append(cond)
        results = []
        results.extend(results_)
        super().__init__(
            self.build_generic(
                regions=2, results=results, operands=operands, loc=loc, ip=ip
            )
        )
        self.regions[0].blocks.append(*[])
        if hasElse:
            self.regions[1].blocks.append(*[])

    @property
    def then_block(self):
        """Returns the then block of the if operation."""
        return self.regions[0].blocks[0]

    @property
    def else_block(self):
        """Returns the else block of the if operation."""
        return self.regions[1].blocks[0]
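
A hedged usage sketch (hypothetical i1-typed value cond):

    if_op = IfOp(cond, hasElse=True)
    with InsertionPoint(if_op.then_block):
        pass  # then-region ops, terminated by scf.yield
    with InsertionPoint(if_op.else_block):
        pass  # else-region ops, terminated by scf.yield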

@@ -3,11 +3,11 @@
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

try:
    from ..ir import *
    from ._ods_common import get_op_result_or_value as _get_op_result_or_value
    from ..dialects import pdl, transform
except ImportError as e:
    raise RuntimeError("Error loading imports from extension module") from e

from typing import List, Optional, Sequence, Union, overload

@@ -16,312 +16,315 @@ OptionalIntList = Optional[Union[ArrayAttr, IntOrAttrList]]


def _get_int_int_array_attr(
    values: Optional[Union[ArrayAttr, Sequence[Union[ArrayAttr, IntOrAttrList]]]]
) -> ArrayAttr:
    """Creates an array attribute containing array attributes of integers.

    If the operand is already an array attribute, forwards it. Otherwise treats
    the operand as a list of attributes or integers, potentially interspersed, to
    create a new array-of-array attribute. Expects the thread-local MLIR context
    to have been set by the context manager.
    """
    if values is None:
        return ArrayAttr.get([])
    if isinstance(values, ArrayAttr):
        return values
    if isinstance(values, list):
        values = [
            ArrayAttr.get(
                [IntegerAttr.get(IntegerType.get_signless(64), v) for v in value]
            )
            for value in values
        ]

    return ArrayAttr.get(values)
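
A worked example (illustrative input):

    # _get_int_int_array_attr([[0, 1], [2]]) builds [[0, 1], [2]] as an
    # ArrayAttr of ArrayAttr of 64-bit signless IntegerAttr.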


class DecomposeOp:
    """Specialization for DecomposeOp class."""

    def __init__(self, target: Union[Operation, Value], *, loc=None, ip=None):
        super().__init__(
            pdl.OperationType.get(), _get_op_result_or_value(target), loc=loc, ip=ip
        )


class GeneralizeOp:
    """Specialization for GeneralizeOp class."""

    def __init__(self, target: Union[Operation, Value], *, loc=None, ip=None):
        super().__init__(
            pdl.OperationType.get(), _get_op_result_or_value(target), loc=loc, ip=ip
        )


class InterchangeOp:
    """Specialization for InterchangeOp class."""

    def __init__(
        self,
        target: Union[Operation, Value],
        *,
        iterator_interchange: OptionalIntList = None,
        loc=None,
        ip=None,
    ):
        pdl_operation_type = pdl.OperationType.get()
        super().__init__(
            pdl_operation_type,
            _get_op_result_or_value(target),
            iterator_interchange=iterator_interchange,
            loc=loc,
            ip=ip,
        )


class MatchOp:
    """Specialization for MatchOp class."""

    @classmethod
    def match_op_names(
        MatchOp,
        target: Union[Operation, Value],
        names: Sequence[str],
        loc=None,
        ip=None,
    ):
        pdl_operation_type = pdl.OperationType.get()
        return MatchOp(
            pdl_operation_type,
            _get_op_result_or_value(target),
            ops=ArrayAttr.get(list(map(lambda s: StringAttr.get(s), names))),
            loc=loc,
            ip=ip,
        )


class MultiTileSizesOp:
    """Specialization for MultiTileSizesOp class."""

    def __init__(
        self,
        result_type: Type,
        target: Union[Operation, Value],
        *,
        dimension: Union[int, IntegerAttr],
        target_size: Union[int, IntegerAttr],
        divisor: Optional[Union[int, IntegerAttr]] = None,
        loc=None,
        ip=None,
    ):
        if divisor is None:
            divisor = 1
        super().__init__(
            result_type,
            result_type,
            result_type,
            _get_op_result_or_value(target),
            dimension=dimension,
            target_size=target_size,
            divisor=divisor,
            loc=loc,
            ip=ip,
        )


class PadOp:
    """Specialization for PadOp class."""

    def __init__(
        self,
        target: Union[Operation, Value],
        *,
        padding_values: Optional[Union[ArrayAttr, Sequence[Attribute]]] = None,
        padding_dimensions: OptionalIntList = None,
        pack_paddings: OptionalIntList = None,
        transpose_paddings: Optional[
            Union[ArrayAttr, Sequence[Union[ArrayAttr, IntOrAttrList]]]
        ] = None,
        loc=None,
        ip=None,
    ):
        if transpose_paddings is None:
            transpose_paddings = []
        if pack_paddings is None:
            pack_paddings = []
        if padding_dimensions is None:
            padding_dimensions = []
        if padding_values is None:
            padding_values = []
        pdl_operation_type = pdl.OperationType.get()
        transpose_paddings_attr = _get_int_int_array_attr(transpose_paddings)
        super().__init__(
            pdl_operation_type,
            _get_op_result_or_value(target),
            padding_values=padding_values,
            padding_dimensions=padding_dimensions,
            pack_paddings=pack_paddings,
            transpose_paddings=transpose_paddings_attr,
            loc=loc,
            ip=ip,
        )


class ScalarizeOp:
    """Specialization for ScalarizeOp class."""

    def __init__(self, target: Union[Operation, Value], *, loc=None, ip=None):
        pdl_operation_type = pdl.OperationType.get()
        super().__init__(
            pdl_operation_type, _get_op_result_or_value(target), loc=loc, ip=ip
        )


class SplitOp:
    """Specialization for SplitOp class."""

    def __init__(
        self,
        target: Union[Operation, Value],
        dimension: Union[int, Attribute],
        split_point: Union[int, Operation, Value, Attribute],
        *,
        loc=None,
        ip=None,
    ):
        if isinstance(split_point, int):
            static_split_point = split_point
            dynamic_split_point = None
        else:
            static_split_point = ShapedType.get_dynamic_size()
            dynamic_split_point = _get_op_result_or_value(split_point)

        target = _get_op_result_or_value(target)

        super().__init__(
            target.type,
            target.type,
            target,
            dimension=dimension,
            static_split_point=static_split_point,
            dynamic_split_point=dynamic_split_point,
            loc=loc,
            ip=ip,
        )
|
||||
|
||||
|
||||
class TileOp:
    """Specialization for TileOp class."""

    @overload
    def __init__(
        self,
        loop_types: Union[Type, List[Type]],
        target: Union[Operation, Value],
        *,
        sizes: Optional[
            Union[Sequence[Union[int, IntegerAttr, Operation, Value]], ArrayAttr]
        ] = None,
        interchange: OptionalIntList = None,
        loc=None,
        ip=None,
    ):
        ...

    @overload
    def __init__(
        self,
        target: Union[Operation, Value, OpView],
        *,
        sizes: Optional[
            Union[Sequence[Union[int, IntegerAttr, Operation, Value]], ArrayAttr]
        ] = None,
        interchange: OptionalIntList = None,
        loc=None,
        ip=None,
    ):
        ...

    def __init__(
        self,
        loop_types_or_target: Union[Type, List[Type], Operation, Value],
        target_or_none: Optional[Union[Operation, Value, OpView]] = None,
        *,
        sizes: Optional[
            Union[Sequence[Union[int, IntegerAttr, Operation, Value]], ArrayAttr]
        ] = None,
        interchange: OptionalIntList = None,
        loc=None,
        ip=None,
    ):
        if interchange is None:
            interchange = []
        if sizes is None:
            sizes = []

        static_sizes = []
        dynamic_sizes = []
        if isinstance(sizes, ArrayAttr):
            sizes_attr = sizes
        else:
            for size in sizes:
                if isinstance(size, int):
                    static_sizes.append(size)
                else:
                    static_sizes.append(ShapedType.get_dynamic_size())
                    dynamic_sizes.append(_get_op_result_or_value(size))
            sizes_attr = DenseI64ArrayAttr.get(static_sizes)

        num_loops = sum(v if v == 0 else 1 for v in self.__extract_values(sizes_attr))

        if isinstance(loop_types_or_target, (Operation, Value, OpView)):
            loop_types = [transform.AnyOpType.get()] * num_loops
            target = loop_types_or_target
            assert target_or_none is None, "Cannot construct TileOp with two targets."
        else:
            loop_types = (
                ([loop_types_or_target] * num_loops)
                if isinstance(loop_types_or_target, Type)
                else loop_types_or_target
            )
            target = target_or_none

        target = _get_op_result_or_value(target)

        super().__init__(
            target.type,
            loop_types,
            target,
            dynamic_sizes=dynamic_sizes,
            static_sizes=sizes_attr,
            interchange=interchange,
            loc=loc,
            ip=ip,
        )

    def __extract_values(self, attr: Optional[DenseI64ArrayAttr]) -> List[int]:
        if not attr:
            return []
        return [element for element in attr]
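
# Editor's sketch, not part of the commit: TileOp encodes mixed static/dynamic
# tile sizes by writing a sentinel into the static array for every dynamic
# size, and a size of 0 means "do not tile this dimension", so it produces no
# loop. A minimal plain-Python sketch of that convention; DYNAMIC stands in
# for ShapedType.get_dynamic_size() and is an assumption of this sketch.
DYNAMIC = -(2**63)  # assumed stand-in for ShapedType.get_dynamic_size()


def split_sizes(sizes):
    """Split mixed sizes into a static array plus the dynamic operands."""
    static_sizes = []
    dynamic_sizes = []
    for size in sizes:
        if isinstance(size, int):
            static_sizes.append(size)
        else:
            static_sizes.append(DYNAMIC)  # placeholder slot in the static array
            dynamic_sizes.append(size)  # SSA value that supplies the size
    return static_sizes, dynamic_sizes


static, dynamic = split_sizes([4, 0, "%size"])  # "%size" models a dynamic Value
num_loops = sum(v if v == 0 else 1 for v in static)  # 0 tiles nothing: no loop
assert static == [4, 0, DYNAMIC] and dynamic == ["%size"] and num_loops == 2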


class VectorizeOp:
    """Specialization for VectorizeOp class."""

    def __init__(
        self,
        target: Union[Operation, Value],
        *,
        vectorize_padding: Union[bool, BoolAttr] = False,
        loc=None,
        ip=None,
    ):
        pdl_operation_type = pdl.OperationType.get()
        if isinstance(vectorize_padding, bool):
            vectorize_padding = UnitAttr.get()
        super().__init__(
            pdl_operation_type,
            _get_op_result_or_value(target),
            vectorize_padding=vectorize_padding,
            loc=loc,
            ip=ip,
        )

@@ -3,40 +3,42 @@
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

try:
    from ..ir import *
except ImportError as e:
    raise RuntimeError("Error loading imports from extension module") from e

from typing import Any, Optional, Sequence, Union
from ._ods_common import (
    get_op_result_or_value as _get_op_result_or_value,
    get_op_results_or_values as _get_op_results_or_values,
)


class EmptyOp:
    """Extends the tensor.empty op."""

    def __init__(
        self,
        sizes: Sequence[Union[int, Value]],
        element_type: Type,
        *,
        loc=None,
        ip=None
    ):
        """Constructs an `empty` with mixed static/dynamic sizes."""
        # TODO: Refactor the EmptyOp to take an element type attribute and
        # then use normal result type inference, unifying the Python and C++ side
        # with a standard mechanism (versus stashing that in builders).
        dynamic_sizes = []
        static_sizes = []
        for s in sizes:
            if isinstance(s, int):
                static_sizes.append(s)
            else:
                static_sizes.append(ShapedType.get_dynamic_size())
                dynamic_sizes.append(s)
        result_type = RankedTensorType.get(static_sizes, element_type)
        op = self.build_generic(
            results=[result_type], operands=dynamic_sizes, attributes={}, loc=loc, ip=ip
        )
        OpView.__init__(self, op)
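
# Editor's sketch of driving the constructor above, not from the commit. It
# assumes the mlir Python bindings are importable; the name `dyn` and the
# constant 8 are hypothetical.
from mlir.ir import Context, Location, Module, InsertionPoint, F32Type
from mlir.dialects import arith, tensor

with Context(), Location.unknown():
    module = Module.create()
    with InsertionPoint(module.body):
        dyn = arith.ConstantOp.create_index(8)  # supplies one dynamic extent
        # Static 4 plus dynamic `dyn` should yield tensor<4x?xf32>, with the
        # dynamic extent passed as an operand.
        empty = tensor.EmptyOp([4, dyn.result], F32Type.get())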

@@ -3,144 +3,131 @@
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

try:
    from ..ir import *
    from ._ods_common import (
        get_op_result_or_value as _get_op_result_or_value,
        get_op_results_or_values as _get_op_results_or_values,
    )
except ImportError as e:
    raise RuntimeError("Error loading imports from extension module") from e

from typing import Optional, Sequence, Union


class CastOp:
    def __init__(
        self, result_type: Type, target: Union[Operation, Value], *, loc=None, ip=None
    ):
        super().__init__(result_type, _get_op_result_or_value(target), loc=loc, ip=ip)


class GetClosestIsolatedParentOp:
    def __init__(
        self, result_type: Type, target: Union[Operation, Value], *, loc=None, ip=None
    ):
        super().__init__(result_type, _get_op_result_or_value(target), loc=loc, ip=ip)


class MergeHandlesOp:
    def __init__(
        self,
        handles: Sequence[Union[Operation, Value]],
        *,
        deduplicate: bool = False,
        loc=None,
        ip=None,
    ):
        super().__init__(
            [_get_op_result_or_value(h) for h in handles],
            deduplicate=deduplicate,
            loc=loc,
            ip=ip,
        )


class ReplicateOp:
    def __init__(
        self,
        pattern: Union[Operation, Value],
        handles: Sequence[Union[Operation, Value]],
        *,
        loc=None,
        ip=None,
    ):
        super().__init__(
            [_get_op_result_or_value(h).type for h in handles],
            _get_op_result_or_value(pattern),
            [_get_op_result_or_value(h) for h in handles],
            loc=loc,
            ip=ip,
        )


class SequenceOp:
    def __init__(
        self,
        failure_propagation_mode,
        results: Sequence[Type],
        target: Union[Operation, Value, Type],
        extra_bindings: Optional[
            Union[Sequence[Value], Sequence[Type], Operation, OpView]
        ] = None,
    ):
        root = (
            _get_op_result_or_value(target)
            if isinstance(target, (Operation, Value))
            else None
        )
        root_type = root.type if not isinstance(target, Type) else target
        if not isinstance(failure_propagation_mode, Attribute):
            failure_propagation_mode_attr = IntegerAttr.get(
                IntegerType.get_signless(32), failure_propagation_mode._as_int()
            )
        else:
            failure_propagation_mode_attr = failure_propagation_mode

        if extra_bindings is None:
            extra_bindings = []
        if isinstance(extra_bindings, (Operation, OpView)):
            extra_bindings = _get_op_results_or_values(extra_bindings)

        extra_binding_types = []
        if len(extra_bindings) != 0:
            if isinstance(extra_bindings[0], Type):
                extra_binding_types = extra_bindings
                extra_bindings = []
            else:
                extra_binding_types = [v.type for v in extra_bindings]

        super().__init__(
            results_=results,
            failure_propagation_mode=failure_propagation_mode_attr,
            root=root,
            extra_bindings=extra_bindings,
        )
        self.regions[0].blocks.append(*tuple([root_type] + extra_binding_types))

    @property
    def body(self) -> Block:
        return self.regions[0].blocks[0]

    @property
    def bodyTarget(self) -> Value:
        return self.body.arguments[0]

    @property
    def bodyExtraArgs(self) -> BlockArgumentList:
        return self.body.arguments[1:]


class YieldOp:
    def __init__(
        self,
        operands: Optional[Union[Operation, Sequence[Value]]] = None,
        *,
        loc=None,
        ip=None,
    ):
        if operands is None:
            operands = []
        super().__init__(_get_op_results_or_values(operands), loc=loc, ip=ip)
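
# One subtlety in SequenceOp above, noted editorially: extra_bindings may be
# SSA values, bare types, or an operation whose results are taken, and the
# constructor normalizes all of these into an operand list plus block-argument
# types. A plain-Python sketch of that normalization; the FakeValue stand-in
# (and using str to model ir.Type) are assumptions of this sketch.
from collections import namedtuple

FakeValue = namedtuple("FakeValue", ["type"])  # stand-in for ir.Value


def normalize_extra_bindings(extra_bindings):
    if extra_bindings is None:
        extra_bindings = []
    extra_binding_types = []
    if len(extra_bindings) != 0:
        if isinstance(extra_bindings[0], str):  # str models ir.Type here
            # Types only: they become block arguments, with no SSA operands.
            extra_binding_types = list(extra_bindings)
            extra_bindings = []
        else:
            # Values: take their types for the block arguments.
            extra_binding_types = [v.type for v in extra_bindings]
    return extra_bindings, extra_binding_types


assert normalize_extra_bindings(["!transform.any_op"]) == ([], ["!transform.any_op"])
vals = [FakeValue("!transform.any_op")]
assert normalize_extra_bindings(vals) == (vals, ["!transform.any_op"])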

@@ -31,61 +31,60 @@ from .lang.yaml_helper import *


def create_arg_parser() -> argparse.ArgumentParser:
    p = argparse.ArgumentParser(description="Dump an oplib in various formats")
    p.add_argument(
        "modules", metavar="M", type=str, nargs="*", help="Op module to dump"
    )
    p.add_argument(
        "--file", metavar="F", type=str, nargs="*", help="Python op file to dump"
    )
    p.add_argument(
        "--format",
        type=str,
        dest="format",
        default="yaml",
        choices=("yaml", "repr"),
        help="Format in which to dump",
    )
    return p


def load_module_from_file(module_name, file_path):
    spec = importlib.util.spec_from_file_location(module_name, file_path)
    m = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(m)
    return m


def main(args):
    # Load all configs.
    configs = []
    modules = []
    for module_name in args.modules:
        modules.append(
            importlib.import_module(module_name, package="mlir.dialects.linalg.opdsl")
        )
    for i, file_path in enumerate(args.file or []):
        modules.append(load_module_from_file(f"_mlir_eval_oplib{i}", file_path))
    for m in modules:
        for attr_name, value in m.__dict__.items():
            # TODO: This class layering is awkward.
            if isinstance(value, DefinedOpCallable):
                try:
                    linalg_config = LinalgOpConfig.from_linalg_op_def(value.op_def)
                except Exception as e:
                    raise ValueError(
                        f"Could not create LinalgOpConfig from {value.op_def}"
                    ) from e
                configs.extend(linalg_config)

    # Print.
    if args.format == "yaml":
        print(yaml_dump_all(configs))
    elif args.format == "repr":
        for config in configs:
            print(repr(config))


if __name__ == "__main__":
    main(create_arg_parser().parse_args())
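
# Editor's usage note, not in the commit: the driver can also be exercised
# programmatically through the parser above; "ops.py" is a hypothetical file
# containing opdsl op definitions.
args = create_arg_parser().parse_args(["--file", "ops.py", "--format", "yaml"])
main(args)  # prints the YAML op configs for everything defined in ops.py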

@@ -66,201 +66,201 @@ __all__ = [


class AffineBuildState:
    """Internal state for the AffineExprDef._create impls.

    Note that a "local" AffineBuildState can be created relative to a "global"
    AffineBuildState. In that case, any affine expressions built will inherit
    symbol and dim bindings from the global state and will update both as new
    ones are discovered. This allows for building expressions across contexts
    which share a common symbol and dim space.
    """

    def __init__(
        self,
        *,
        global_state: "AffineBuildState" = None,
        allow_new_symbols: bool = True,
        allow_new_dims: bool = True,
    ):
        if not global_state:
            self.all_symbols = dict()  # type: Dict[str, int]
            self.all_dims = dict()  # type: Dict[str, int]
        else:
            # Alias the global dict.
            self.all_symbols = global_state.all_symbols
            self.all_dims = global_state.all_dims

        # Map of symbols and dims in the current build.
        self.local_symbols = dict()  # type: Dict[str, int]
        self.local_dims = dict()  # type: Dict[str, int]
        self.allow_new_symbols = allow_new_symbols
        self.allow_new_dims = allow_new_dims

    def get_dim(self, dimname: str) -> int:
        """Gets the dim position given a name."""
        pos = self.all_dims.get(dimname)
        if pos is None:
            if not self.allow_new_dims:
                raise ValueError(
                    f"New dimensions not allowed in the current affine expression: "
                    f"Requested '{dimname}', Available: {self.all_dims}"
                )
            pos = len(self.all_dims)
            self.all_dims[dimname] = pos
            self.local_dims[dimname] = pos
        return pos

    def get_symbol(self, symname: str) -> int:
        """Gets a symbol position given a name."""
        pos = self.all_symbols.get(symname)
        if pos is None:
            if not self.allow_new_symbols:
                raise ValueError(
                    f"New symbols not allowed in the current affine expression: "
                    f"Requested '{symname}', Available: {self.all_symbols}"
                )
            pos = len(self.all_symbols)
            self.all_symbols[symname] = pos
            self.local_symbols[symname] = pos
        return pos

    @property
    def local_dim_count(self) -> int:
        return len(self.local_dims)

    @property
    def local_symbol_count(self) -> int:
        return len(self.local_symbols)

    @property
    def dim_count(self) -> int:
        return len(self.all_dims)

    @property
    def symbol_count(self) -> int:
        return len(self.all_symbols)

    def __repr__(self):
        lines = [f"AffineBuildState<"]
        lines.append(f"  symbols={self.local_symbols}")
        lines.append(f"  dims={self.local_dims}>")
        return "\n".join(lines)


class AffineExprDef:
    """Base class for an affine expression being defined."""

    def build(self, state: Optional[AffineBuildState] = None) -> _ir.AffineExpr:
        """Builds the corresponding _ir.AffineExpr from the definitions."""
        state = AffineBuildState() if state is None else state
        expr = self._create(state)
        return expr

    def _create(self, state: AffineBuildState) -> _ir.AffineExpr:
        raise NotImplementedError()

    @staticmethod
    def coerce_from(py_value):
        if isinstance(py_value, int):
            return AffineConstantExpr(py_value)
        assert isinstance(py_value, AffineExprDef)
        return py_value

    def visit_affine_exprs(self, callback):
        """Visits all AffineExprDefs including self."""
        callback(self)

    def __add__(lhs, rhs):
        rhs = AffineExprDef.coerce_from(rhs)
        return AffineBinaryExprDef(_ir.AffineAddExpr, lhs, rhs)

    def __mul__(lhs, rhs):
        rhs = AffineExprDef.coerce_from(rhs)
        return AffineBinaryExprDef(_ir.AffineMulExpr, lhs, rhs)

    def __mod__(lhs, rhs):
        rhs = AffineExprDef.coerce_from(rhs)
        return AffineBinaryExprDef(_ir.AffineModExpr, lhs, rhs)

    def __floordiv__(lhs, rhs):
        rhs = AffineExprDef.coerce_from(rhs)
        return AffineBinaryExprDef(_ir.AffineFloorDivExpr, lhs, rhs)

    def __truediv__(lhs, rhs):
        # TODO: Not really a ceil div - taking liberties for the DSL.
        rhs = AffineExprDef.coerce_from(rhs)
        return AffineBinaryExprDef(_ir.AffineCeilDivExpr, lhs, rhs)


class AffineConstantExpr(AffineExprDef):
    """An affine constant being defined."""

    def __init__(self, value: int):
        assert isinstance(value, int)
        self.value = value

    def _create(self, state: AffineBuildState) -> _ir.AffineExpr:
        return _ir.AffineConstantExpr.get(self.value)

    def __repr__(self):
        return f"Const({self.value})"


class AffineBinaryExprDef(AffineExprDef):
    """An affine binary expression being defined."""

    def __init__(self, ir_ctor, lhs: AffineExprDef, rhs: AffineExprDef):
        self.ir_ctor = ir_ctor
        self.lhs = lhs
        self.rhs = rhs

    def _create(self, state: AffineBuildState) -> _ir.AffineExpr:
        return self.ir_ctor.get(self.lhs._create(state), self.rhs._create(state))

    def visit_affine_exprs(self, callback):
        """Visits all AffineExprDefs including self."""
        super().visit_affine_exprs(callback)
        self.lhs.visit_affine_exprs(callback)
        self.rhs.visit_affine_exprs(callback)

    def __repr__(self):
        return f"{self.ir_ctor.__name__}({repr(self.lhs)}, {repr(self.rhs)})"


class DimDef(AffineExprDef):
    """Represents a named dimension."""

    ALL_DIMS = dict()  # type: Dict[str, "DimDef"]

    def __new__(cls, dimname: str):
        existing = cls.ALL_DIMS.get(dimname)
        if existing is not None:
            return existing
        new = super().__new__(cls)
        new.dimname = dimname
        cls.ALL_DIMS[dimname] = new
        return new

    def __repr__(self):
        return f"Dim({self.dimname})"

    def _create(self, state: AffineBuildState) -> _ir.AffineExpr:
        pos = state.get_dim(self.dimname)
        return _ir.AffineDimExpr.get(position=pos)

    @classmethod
    def create_expando(cls):
        """Create an expando class that creates unique symbols based on attr access."""

        class ExpandoDims:
            def __getattr__(self, n):
                return cls(n)

        return ExpandoDims()


class SymbolDef(AffineExprDef):
    """Represents a named symbol.

    >>> s1 = SymbolDef("s1")
    >>> s1
@@ -270,36 +270,35 @@ class SymbolDef(AffineExprDef):
    False
    >>> s1 is SymbolDef("s1")
    True
    """

    ALL_SYMBOLS = dict()  # type: Dict[str, "SymbolDef"]

    def __new__(cls, symname: str):
        existing = cls.ALL_SYMBOLS.get(symname)
        if existing is not None:
            return existing
        new = super().__new__(cls)
        new.symname = symname
        cls.ALL_SYMBOLS[symname] = new
        return new

    def __repr__(self):
        return f"Symbol({self.symname})"

    def _create(self, state: AffineBuildState) -> _ir.AffineExpr:
        pos = state.get_symbol(self.symname)
        return _ir.AffineSymbolExpr.get(position=pos)

    @classmethod
    def create_expando(cls):
        """Create an expando class that creates unique symbols based on attr access."""

        class ExpandoSymbols:
            def __getattr__(self, n):
                return cls(n)

        return ExpandoSymbols()


# Global accessor for on-demand dims and symbols.
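
# Editor's sketch of how the definitions above compose, not part of the
# commit: dims and symbols are interned by name, the operators build an
# AffineBinaryExprDef tree, and build() assigns positions through a shared
# AffineBuildState.
m, k = DimDef("m"), DimDef("k")
sz = SymbolDef("sz")

expr = m * sz + k  # AffineAddExpr(AffineMulExpr(Dim(m), Symbol(sz)), Dim(k))
assert DimDef("m") is m  # interning: the same name yields the same object

# Building requires an active MLIR context; positions are handed out in
# first-use order by the shared state.
# with _ir.Context():
#     state = AffineBuildState()
#     expr.build(state=state)
#     # state.all_dims == {"m": 0, "k": 1}; state.all_symbols == {"sz": 0}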
File diff suppressed because it is too large
@@ -21,422 +21,468 @@ __all__ = ["LinalgStructuredOpConfig", "LinalgOpConfig", "OperandDefConfig"]
|
||||
|
||||
|
||||
def _serialize_affine_map(affine_map: _ir.AffineMap) -> str:
|
||||
with affine_map.context:
|
||||
# Affine map printing/parsing is via an AffineMap attr.
|
||||
attr = _ir.AffineMapAttr.get(affine_map)
|
||||
return str(attr)
|
||||
with affine_map.context:
|
||||
# Affine map printing/parsing is via an AffineMap attr.
|
||||
attr = _ir.AffineMapAttr.get(affine_map)
|
||||
return str(attr)
|
||||
|
||||
|
||||
class TensorUseConfig:
|
||||
"""Wrapper around a TensorUse with additional context-bound state."""
|
||||
"""Wrapper around a TensorUse with additional context-bound state."""
|
||||
|
||||
def __init__(self, tensor_use: TensorUse, indexing_map: _ir.AffineMap):
|
||||
self.tensor_use = tensor_use
|
||||
self.indexing_map = indexing_map
|
||||
def __init__(self, tensor_use: TensorUse, indexing_map: _ir.AffineMap):
|
||||
self.tensor_use = tensor_use
|
||||
self.indexing_map = indexing_map
|
||||
|
||||
def __repr__(self):
|
||||
return f"Use({self.tensor_use}, indexing_map={self.indexing_map})"
|
||||
def __repr__(self):
|
||||
return f"Use({self.tensor_use}, indexing_map={self.indexing_map})"
|
||||
|
||||
|
||||
class OperandDefConfig(YAMLObject):
|
||||
"""Wrapper containing an operand definition with additional state."""
|
||||
yaml_tag = "!LinalgOperandDefConfig"
|
||||
"""Wrapper containing an operand definition with additional state."""
|
||||
|
||||
def __init__(self,
|
||||
operand_def: OperandDef,
|
||||
shape_map: Optional[_ir.AffineMap] = None,
|
||||
index_attr_map: Optional[_ir.AffineMap] = None):
|
||||
self.operand_def = operand_def
|
||||
self.shape_map = shape_map # type: Optional[_ir.AffineMap]
|
||||
self.index_attr_map = index_attr_map # type: Optional[_ir.AffineMap]
|
||||
self.indexing_map = None # type: Optional[_ir.AffineMap]
|
||||
yaml_tag = "!LinalgOperandDefConfig"
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return self.operand_def.name
|
||||
def __init__(
|
||||
self,
|
||||
operand_def: OperandDef,
|
||||
shape_map: Optional[_ir.AffineMap] = None,
|
||||
index_attr_map: Optional[_ir.AffineMap] = None,
|
||||
):
|
||||
self.operand_def = operand_def
|
||||
self.shape_map = shape_map # type: Optional[_ir.AffineMap]
|
||||
self.index_attr_map = index_attr_map # type: Optional[_ir.AffineMap]
|
||||
self.indexing_map = None # type: Optional[_ir.AffineMap]
|
||||
|
||||
@property
|
||||
def kind(self) -> OperandKind:
|
||||
return self.operand_def.kind
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return self.operand_def.name
|
||||
|
||||
@property
|
||||
def type_var(self) -> TypeVar:
|
||||
return self.operand_def.type_var
|
||||
@property
|
||||
def kind(self) -> OperandKind:
|
||||
return self.operand_def.kind
|
||||
|
||||
def to_yaml_custom_dict(self):
|
||||
self_dict = dict(name=self.name, kind=self.operand_def.kind.name.lower())
|
||||
if self.type_var:
|
||||
self_dict["type_var"] = self.type_var.name
|
||||
if self.shape_map:
|
||||
self_dict["shape_map"] = _serialize_affine_map(self.shape_map)
|
||||
if self.index_attr_map:
|
||||
self_dict["index_attr_map"] = _serialize_affine_map(self.index_attr_map)
|
||||
if self.operand_def.default_indices:
|
||||
self_dict["default_indices"] = self.operand_def.default_indices
|
||||
if self.operand_def.default_fn:
|
||||
self_dict["default_fn"] = self.operand_def.default_fn
|
||||
return self_dict
|
||||
@property
|
||||
def type_var(self) -> TypeVar:
|
||||
return self.operand_def.type_var
|
||||
|
||||
def __repr__(self):
|
||||
return (f"OperandDefConfig({self.operand_def}, "
|
||||
def to_yaml_custom_dict(self):
|
||||
self_dict = dict(name=self.name, kind=self.operand_def.kind.name.lower())
|
||||
if self.type_var:
|
||||
self_dict["type_var"] = self.type_var.name
|
||||
if self.shape_map:
|
||||
self_dict["shape_map"] = _serialize_affine_map(self.shape_map)
|
||||
if self.index_attr_map:
|
||||
self_dict["index_attr_map"] = _serialize_affine_map(self.index_attr_map)
|
||||
if self.operand_def.default_indices:
|
||||
self_dict["default_indices"] = self.operand_def.default_indices
|
||||
if self.operand_def.default_fn:
|
||||
self_dict["default_fn"] = self.operand_def.default_fn
|
||||
return self_dict
|
||||
|
||||
def __repr__(self):
|
||||
return (
|
||||
f"OperandDefConfig({self.operand_def}, "
|
||||
f"shape_map={self.shape_map}, "
|
||||
f"index_attr_map={self.index_attr_map}, "
|
||||
f"indexing_map={self.indexing_map})")
|
||||
f"indexing_map={self.indexing_map})"
|
||||
)
|
||||
|
||||
|
||||
class LinalgIndexingMapsConfig(YAMLObject):
|
||||
"""Abstracts the style of indexing maps that the op exports.
|
||||
"""Abstracts the style of indexing maps that the op exports.
|
||||
|
||||
Presently only static (tied to the op name) indexing maps are supported. In
|
||||
the future, it is expected that we will have additional variants:
|
||||
- Dynamic based on attributes
|
||||
- Dynamic based on operands
|
||||
Each is expected to require a different variant of specification.
|
||||
"""
|
||||
yaml_tag = "!LinalgIndexingMapsConfig"
|
||||
Presently only static (tied to the op name) indexing maps are supported. In
|
||||
the future, it is expected that we will have additional variants:
|
||||
- Dynamic based on attributes
|
||||
- Dynamic based on operands
|
||||
Each is expected to require a different variant of specification.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
static_indexing_maps: Optional[Sequence[_ir.AffineMap]] = None):
|
||||
self.static_indexing_maps = static_indexing_maps
|
||||
yaml_tag = "!LinalgIndexingMapsConfig"
|
||||
|
||||
def to_yaml_custom_dict(self):
|
||||
if self.static_indexing_maps is not None:
|
||||
return dict(static_indexing_maps=[
|
||||
_serialize_affine_map(m) for m in self.static_indexing_maps
|
||||
])
|
||||
raise ValueError(
|
||||
f"LinalgIndexingMapsConfig must have one type of indexing map"
|
||||
f"(got none)")
|
||||
def __init__(self, static_indexing_maps: Optional[Sequence[_ir.AffineMap]] = None):
|
||||
self.static_indexing_maps = static_indexing_maps
|
||||
|
||||
def to_yaml_custom_dict(self):
|
||||
if self.static_indexing_maps is not None:
|
||||
return dict(
|
||||
static_indexing_maps=[
|
||||
_serialize_affine_map(m) for m in self.static_indexing_maps
|
||||
]
|
||||
)
|
||||
raise ValueError(
|
||||
f"LinalgIndexingMapsConfig must have one type of indexing map" f"(got none)"
|
||||
)
|
||||
|
||||
|
||||
class LinalgStructuredOpConfig(YAMLObject):
|
||||
"""Configuration for metadata sufficient to construct a linalg named op."""
|
||||
"""Configuration for metadata sufficient to construct a linalg named op."""
|
||||
|
||||
yaml_tag = "!LinalgStructuredOpConfig"
|
||||
yaml_tag = "!LinalgStructuredOpConfig"
|
||||
|
||||
def __init__(self,
|
||||
comprehension: Comprehension,
|
||||
domain: Sequence[DimDef],
|
||||
registered_operands: Sequence[OperandDef],
|
||||
context: Optional[_ir.Context] = None):
|
||||
self.context = context if context is not None else _ir.Context()
|
||||
self.affine_state = AffineBuildState()
|
||||
self.writes = list() # type: List[Tuple[TensorUse, TensorExpression]]
|
||||
self.operands = dict() # type: Dict[OperandDef, OperandDefConfig]
|
||||
self.uses = dict() # type: Dict[TensorUse, TensorUseConfig]
|
||||
def __init__(
|
||||
self,
|
||||
comprehension: Comprehension,
|
||||
domain: Sequence[DimDef],
|
||||
registered_operands: Sequence[OperandDef],
|
||||
context: Optional[_ir.Context] = None,
|
||||
):
|
||||
self.context = context if context is not None else _ir.Context()
|
||||
self.affine_state = AffineBuildState()
|
||||
self.writes = list() # type: List[Tuple[TensorUse, TensorExpression]]
|
||||
self.operands = dict() # type: Dict[OperandDef, OperandDefConfig]
|
||||
self.uses = dict() # type: Dict[TensorUse, TensorUseConfig]
|
||||
|
||||
# Compute the ordered set of writes and collect the tensor, capture, dims,
|
||||
# and index uses.
|
||||
collected_tensor_uses = set()
|
||||
collected_scalar_uses = set()
|
||||
collected_dim_uses = set()
|
||||
collected_indices = set()
|
||||
for write_use, read_use in zip(comprehension.definitions,
|
||||
comprehension.values):
|
||||
self.writes.append((write_use, read_use))
|
||||
# Compute the ordered set of writes and collect the tensor, capture, dims,
|
||||
# and index uses.
|
||||
collected_tensor_uses = set()
|
||||
collected_scalar_uses = set()
|
||||
collected_dim_uses = set()
|
||||
collected_indices = set()
|
||||
for write_use, read_use in zip(comprehension.definitions, comprehension.values):
|
||||
self.writes.append((write_use, read_use))
|
||||
|
||||
for write_use, read_use in self.writes:
|
||||
collected_tensor_uses.add(write_use)
|
||||
read_use.collect_tensor_uses(collected_tensor_uses)
|
||||
read_use.collect_scalar_uses(collected_scalar_uses)
|
||||
read_use.collect_dim_uses(collected_dim_uses)
|
||||
write_use.collect_dim_uses(collected_dim_uses)
|
||||
read_use.collect_indices(collected_indices)
|
||||
for write_use, read_use in self.writes:
|
||||
collected_tensor_uses.add(write_use)
|
||||
read_use.collect_tensor_uses(collected_tensor_uses)
|
||||
read_use.collect_scalar_uses(collected_scalar_uses)
|
||||
read_use.collect_dim_uses(collected_dim_uses)
|
||||
write_use.collect_dim_uses(collected_dim_uses)
|
||||
read_use.collect_indices(collected_indices)
|
||||
|
||||
# Set domain to the sorted list of uses if no domain annotation is given.
|
||||
if not domain:
|
||||
domain = sorted(collected_dim_uses, key=lambda dim: dim.dimname)
|
||||
# Set domain to the sorted list of uses if no domain annotation is given.
|
||||
if not domain:
|
||||
domain = sorted(collected_dim_uses, key=lambda dim: dim.dimname)
|
||||
|
||||
# Verify the domain dimensions match the used dimensions.
|
||||
if (len(domain) != len(collected_dim_uses) or
|
||||
any(dim not in collected_dim_uses for dim in domain)):
|
||||
raise ValueError(f"Expected the annotated domain dimensions {domain} to "
|
||||
f"match the set of dimension used by the tensor "
|
||||
f"comprehension {collected_dim_uses}")
|
||||
# Verify the domain dimensions match the used dimensions.
|
||||
if len(domain) != len(collected_dim_uses) or any(
|
||||
dim not in collected_dim_uses for dim in domain
|
||||
):
|
||||
raise ValueError(
|
||||
f"Expected the annotated domain dimensions {domain} to "
|
||||
f"match the set of dimension used by the tensor "
|
||||
f"comprehension {collected_dim_uses}"
|
||||
)
|
||||
|
||||
# Instantiate the dimensions in the given order.
|
||||
with self.context:
|
||||
local_state = AffineBuildState(
|
||||
global_state=self.affine_state, allow_new_symbols=False)
|
||||
for dim in domain:
|
||||
dim.build(state=local_state)
|
||||
# Instantiate the dimensions in the given order.
|
||||
with self.context:
|
||||
local_state = AffineBuildState(
|
||||
global_state=self.affine_state, allow_new_symbols=False
|
||||
)
|
||||
for dim in domain:
|
||||
dim.build(state=local_state)
|
||||
|
||||
# Collect all attribute definitions.
|
||||
collected_attr_defs = list()
|
||||
for operand in registered_operands:
|
||||
if operand.is_attribute():
|
||||
collected_attr_defs.append(operand)
|
||||
# Collect all attribute definitions.
|
||||
collected_attr_defs = list()
|
||||
for operand in registered_operands:
|
||||
if operand.is_attribute():
|
||||
collected_attr_defs.append(operand)
|
||||
|
||||
# Collect all tensors with manual indexing annotation.
|
||||
collected_index_defs = list()
|
||||
for operand in registered_operands:
|
||||
if operand.index_dims:
|
||||
if any(dim not in collected_dim_uses for dim in operand.index_dims):
|
||||
raise ValueError(f"Expected all index dims {operand.index_dims} of "
|
||||
f"operand {operand.name} to have uses.")
|
||||
collected_index_defs.append(operand)
|
||||
# Collect all tensors with manual indexing annotation.
|
||||
collected_index_defs = list()
|
||||
for operand in registered_operands:
|
||||
if operand.index_dims:
|
||||
if any(dim not in collected_dim_uses for dim in operand.index_dims):
|
||||
raise ValueError(
|
||||
f"Expected all index dims {operand.index_dims} of "
|
||||
f"operand {operand.name} to have uses."
|
||||
)
|
||||
collected_index_defs.append(operand)
|
||||
|
||||
# Collect the operand definitions of all tensor/scalar uses, attributes, and
|
||||
# shape-only tensors.
|
||||
all_operand_defs = list()
|
||||
for use in collected_tensor_uses:
|
||||
all_operand_defs.append(use.operand_def)
|
||||
for use in collected_scalar_uses:
|
||||
all_operand_defs.append(use.operand_def)
|
||||
for definition in collected_attr_defs:
|
||||
all_operand_defs.append(definition)
|
||||
for definition in collected_index_defs:
|
||||
all_operand_defs.append(definition)
|
||||
# Collect the operand definitions of all tensor/scalar uses, attributes, and
|
||||
# shape-only tensors.
|
||||
all_operand_defs = list()
|
||||
for use in collected_tensor_uses:
|
||||
all_operand_defs.append(use.operand_def)
|
||||
for use in collected_scalar_uses:
|
||||
all_operand_defs.append(use.operand_def)
|
||||
for definition in collected_attr_defs:
|
||||
all_operand_defs.append(definition)
|
||||
for definition in collected_index_defs:
|
||||
all_operand_defs.append(definition)
|
||||
|
||||
# Add all operands in registration order to ensure the symbols are
|
||||
# registered in the order they appear.
|
||||
all_operand_defs = sorted(
|
||||
all_operand_defs, key=lambda operand_def: operand_def.registered_index)
|
||||
for operand_def in all_operand_defs:
|
||||
self.add_operand(operand_def)
|
||||
# Add all operands in registration order to ensure the symbols are
|
||||
# registered in the order they appear.
|
||||
all_operand_defs = sorted(
|
||||
all_operand_defs, key=lambda operand_def: operand_def.registered_index
|
||||
)
|
||||
for operand_def in all_operand_defs:
|
||||
self.add_operand(operand_def)
|
||||
|
||||
# Add all shape-only tensor index_dim annotations and all tensor uses.
|
||||
for definition in collected_index_defs:
|
||||
self.add_indexed_operand(definition)
|
||||
for use in collected_tensor_uses:
|
||||
self.add_tensor_use(use)
|
||||
# Add all shape-only tensor index_dim annotations and all tensor uses.
|
||||
for definition in collected_index_defs:
|
||||
self.add_indexed_operand(definition)
|
||||
for use in collected_tensor_uses:
|
||||
self.add_tensor_use(use)
|
||||
|
||||
# Normalize all shape and indexing maps now that full count of dims and
|
||||
# symbols are known.
|
||||
for cuse in self.uses.values():
|
||||
cuse.indexing_map = self._normalize_affine_map(cuse.indexing_map)
|
||||
for definition in collected_index_defs:
|
||||
self.operands[definition].indexing_map = self._normalize_affine_map(
|
||||
self.operands[definition].indexing_map)
|
||||
for operand_config in self.operands.values():
|
||||
if operand_config.shape_map:
|
||||
operand_config.shape_map = self._normalize_affine_map(
|
||||
operand_config.shape_map, with_dims=False)
|
||||
if operand_config.index_attr_map:
|
||||
operand_config.index_attr_map = self._normalize_affine_map(
|
||||
operand_config.index_attr_map, with_dims=False)
|
||||
# Normalize all shape and indexing maps now that full count of dims and
|
||||
# symbols are known.
|
||||
for cuse in self.uses.values():
|
||||
cuse.indexing_map = self._normalize_affine_map(cuse.indexing_map)
|
||||
for definition in collected_index_defs:
|
||||
self.operands[definition].indexing_map = self._normalize_affine_map(
|
||||
self.operands[definition].indexing_map
|
||||
)
|
||||
for operand_config in self.operands.values():
|
||||
if operand_config.shape_map:
|
||||
operand_config.shape_map = self._normalize_affine_map(
|
||||
operand_config.shape_map, with_dims=False
|
||||
)
|
||||
if operand_config.index_attr_map:
|
||||
operand_config.index_attr_map = self._normalize_affine_map(
|
||||
operand_config.index_attr_map, with_dims=False
|
||||
)
|
||||
|
||||
# Now for each write use, propagate the indexing maps from the use to the
|
||||
# tensor, ensuring that there are not conflicts.
|
||||
for write_use, _ in self.writes:
|
||||
write_tensor_config = self.operands[write_use.operand_def]
|
||||
if write_tensor_config.indexing_map:
|
||||
raise ValueError(
|
||||
f"Unexpected multi-write to a single tensor: {write_tensor_config}")
|
||||
write_tensor_config.indexing_map = self.uses[write_use].indexing_map
|
||||
# Now for each write use, propagate the indexing maps from the use to the
|
||||
# tensor, ensuring that there are not conflicts.
|
||||
for write_use, _ in self.writes:
|
||||
write_tensor_config = self.operands[write_use.operand_def]
|
||||
if write_tensor_config.indexing_map:
|
||||
raise ValueError(
|
||||
f"Unexpected multi-write to a single tensor: {write_tensor_config}"
|
||||
)
|
||||
write_tensor_config.indexing_map = self.uses[write_use].indexing_map
|
||||
|
||||
# For each read use, propagate the indexing maps from the use to the
|
||||
# tensor, ensuring that there are not conflicts.
|
||||
for _, read_expr in self.writes:
|
||||
read_uses = set() # type: Set[TensorUse]
|
||||
read_expr.collect_tensor_uses(read_uses)
|
||||
for read_use in read_uses:
|
||||
read_operand_config = self.operands[read_use.operand_def]
|
||||
if (read_operand_config.indexing_map and
|
||||
read_operand_config.indexing_map !=
|
||||
self.uses[read_use].indexing_map):
|
||||
raise ValueError(
|
||||
f"Unexpected multi-read of a tensor with different accesses:"
|
||||
f"{read_operand_config} vs {read_use}")
|
||||
read_operand_config.indexing_map = self.uses[read_use].indexing_map
|
||||
# For each read use, propagate the indexing maps from the use to the
|
||||
# tensor, ensuring that there are not conflicts.
|
||||
for _, read_expr in self.writes:
|
||||
read_uses = set() # type: Set[TensorUse]
|
||||
read_expr.collect_tensor_uses(read_uses)
|
||||
for read_use in read_uses:
|
||||
read_operand_config = self.operands[read_use.operand_def]
|
||||
if (
|
||||
read_operand_config.indexing_map
|
||||
and read_operand_config.indexing_map
|
||||
!= self.uses[read_use].indexing_map
|
||||
):
|
||||
raise ValueError(
|
||||
f"Unexpected multi-read of a tensor with different accesses:"
|
||||
f"{read_operand_config} vs {read_use}"
|
||||
)
|
||||
read_operand_config.indexing_map = self.uses[read_use].indexing_map
|
||||
|
||||
# Set the indexing map of all scalar uses to the empty map.
|
||||
for operand_config in self.operands.values():
|
||||
if operand_config.operand_def.kind == OperandKind.SCALAR:
|
||||
operand_config.indexing_map = self._get_scalar_map()
|
||||
# Set the indexing map of all scalar uses to the empty map.
|
||||
for operand_config in self.operands.values():
|
||||
if operand_config.operand_def.kind == OperandKind.SCALAR:
|
||||
operand_config.indexing_map = self._get_scalar_map()
|
||||
|
||||
# Check all registered tensor and scalar operands have an indexing map.
|
||||
for operand in registered_operands:
|
||||
if operand.is_attribute():
|
||||
continue
|
||||
if not (operand in self.operands and self.operands[operand].indexing_map):
|
||||
raise ValueError(f"Failed to compute an indexing map for operand "
|
||||
f"{operand.name}")
|
||||
# Check all registered tensor and scalar operands have an indexing map.
|
||||
for operand in registered_operands:
|
||||
if operand.is_attribute():
|
||||
continue
|
||||
if not (operand in self.operands and self.operands[operand].indexing_map):
|
||||
raise ValueError(
|
||||
f"Failed to compute an indexing map for operand " f"{operand.name}"
|
||||
)
|
||||
|
||||
# Collect reduction dims and ensure all the same.
|
||||
all_reduction_dims = set(comprehension.all_reduction_dims)
|
||||
if len(all_reduction_dims) != 1:
|
||||
raise ValueError(
|
||||
f"All writes within a generic must have the same reduction "
|
||||
f"dims. Got: {all_reduction_dims}")
|
||||
self.reduction_dims = next(iter(all_reduction_dims))
|
||||
# Collect reduction dims and ensure all the same.
|
||||
all_reduction_dims = set(comprehension.all_reduction_dims)
|
||||
if len(all_reduction_dims) != 1:
|
||||
raise ValueError(
|
||||
f"All writes within a generic must have the same reduction "
|
||||
f"dims. Got: {all_reduction_dims}"
|
||||
)
|
||||
self.reduction_dims = next(iter(all_reduction_dims))
|
||||
|
||||
# Check the index dimension exists and resolve.
|
||||
for index in collected_indices:
|
||||
if index.dim_def.dimname not in self.affine_state.all_dims:
|
||||
raise ValueError(
|
||||
f"The dimension {index.dim_def.dimname} is not part of the "
|
||||
f"iteration domain {self.affine_state.all_dims}")
|
||||
index.resolve_dimension_name(self.affine_state)
|
||||
# Check the index dimension exists and resolve.
|
||||
for index in collected_indices:
|
||||
if index.dim_def.dimname not in self.affine_state.all_dims:
|
||||
raise ValueError(
|
||||
f"The dimension {index.dim_def.dimname} is not part of the "
|
||||
f"iteration domain {self.affine_state.all_dims}"
|
||||
)
|
||||
index.resolve_dimension_name(self.affine_state)
|
||||
|
||||
# Generate the scalar assignments (used to build a body).
|
||||
self.assignments = [
|
||||
ScalarAssign(write_use.tensor_name, read_expr.to_scalar_expression())
|
||||
for write_use, read_expr in self.writes
|
||||
]
|
||||
# Generate the scalar assignments (used to build a body).
|
||||
self.assignments = [
|
||||
ScalarAssign(write_use.tensor_name, read_expr.to_scalar_expression())
|
||||
for write_use, read_expr in self.writes
|
||||
]
|
||||
|
||||
@property
|
||||
def ordered_operands(self) -> Sequence[OperandDefConfig]:
|
||||
return sorted(
|
||||
self.operands.values(),
|
||||
key=lambda operand: operand.operand_def.registered_index)
|
||||
@property
|
||||
def ordered_operands(self) -> Sequence[OperandDefConfig]:
|
||||
return sorted(
|
||||
self.operands.values(),
|
||||
key=lambda operand: operand.operand_def.registered_index,
|
||||
)
|
||||
|
||||
@property
|
||||
def ordered_dims(self) -> Sequence[Tuple[str, int]]:
|
||||
"""Gets the ordered list of dim bindings (symbolic name, position).
|
||||
@property
|
||||
def ordered_dims(self) -> Sequence[Tuple[str, int]]:
|
||||
"""Gets the ordered list of dim bindings (symbolic name, position).
|
||||
|
||||
TODO: The original parser relies on parse ordering to arrive at the
|
||||
iterator types, but that ordering is not defined on the Python side, so
|
||||
this may be ambiguous.
|
||||
"""
|
||||
return list(self.affine_state.all_dims.items())
|
||||
TODO: The original parser relies on parse ordering to arrive at the
|
||||
iterator types, but that ordering is not defined on the Python side, so
|
||||
this may be ambiguous.
|
||||
"""
|
||||
return list(self.affine_state.all_dims.items())
|
||||
|
||||
@property
|
||||
def indexing_maps(self) -> Sequence[_ir.AffineMap]:
|
||||
return [o.indexing_map for o in self.ordered_operands if o.indexing_map]
|
||||
@property
|
||||
def indexing_maps(self) -> Sequence[_ir.AffineMap]:
|
||||
return [o.indexing_map for o in self.ordered_operands if o.indexing_map]
|
||||
|
||||
@property
|
||||
def iterator_types(self) -> Sequence[str]:
|
||||
@property
|
||||
def iterator_types(self) -> Sequence[str]:
|
||||
def get_type(symbolic_name, position):
|
||||
for reduction_dim_expr in self.reduction_dims:
|
||||
if reduction_dim_expr.dimname == symbolic_name:
|
||||
return "reduction"
|
||||
return "parallel"
|
||||
|
||||
def get_type(symbolic_name, position):
|
||||
for reduction_dim_expr in self.reduction_dims:
|
||||
if reduction_dim_expr.dimname == symbolic_name:
|
||||
return "reduction"
|
||||
return "parallel"
|
||||
return [get_type(*dim) for dim in self.ordered_dims]
|
||||
|
||||
return [get_type(*dim) for dim in self.ordered_dims]

    def add_operand(self, operand_def: OperandDef):
        if operand_def in self.operands:
            return
        if not (operand_def.is_tensor() or operand_def.kind == OperandKind.INDEX_ATTR):
            self.operands[operand_def] = OperandDefConfig(operand_def)
            return
        with self.context:
            local_state = AffineBuildState(
                global_state=self.affine_state, allow_new_dims=False
            )
            exprs = []
            for expr in operand_def.size_exprs:
                exprs.append(expr.build(state=local_state))
            assert local_state.local_dim_count == 0
            affine_map = _ir.AffineMap.get(
                dim_count=0, symbol_count=local_state.symbol_count, exprs=exprs
            )
            if operand_def.kind == OperandKind.INDEX_ATTR:
                self.operands[operand_def] = OperandDefConfig(
                    operand_def, index_attr_map=affine_map
                )
            else:
                self.operands[operand_def] = OperandDefConfig(
                    operand_def, shape_map=affine_map
                )

    def add_indexed_operand(self, operand_def: OperandDef):
        with self.context:
            local_state = AffineBuildState(
                global_state=self.affine_state, allow_new_symbols=False
            )
            exprs = []
            for expr in operand_def.index_dims:
                exprs.append(expr.build(state=local_state))
            self.operands[operand_def].indexing_map = _ir.AffineMap.get(
                dim_count=local_state.dim_count,
                symbol_count=local_state.symbol_count,
                exprs=exprs,
            )

    def add_tensor_use(self, tensor_use: TensorUse):
        if tensor_use in self.uses:
            return
        with self.context:
            local_state = AffineBuildState(
                global_state=self.affine_state, allow_new_symbols=False
            )
            exprs = []
            for expr in tensor_use.indices:
                exprs.append(expr.build(state=local_state))
            indexing_map = _ir.AffineMap.get(
                dim_count=local_state.dim_count,
                symbol_count=local_state.symbol_count,
                exprs=exprs,
            )

            use_config = TensorUseConfig(tensor_use, indexing_map)
            self.uses[tensor_use] = use_config

    def _get_scalar_map(self) -> _ir.AffineMap:
        """Create an empty affine map used to index a scalar."""
        with self.context:
            return _ir.AffineMap.get(
                dim_count=self.affine_state.dim_count,
                symbol_count=self.affine_state.symbol_count,
                exprs=list(),
            )

    def _normalize_affine_map(
        self, affine_map: _ir.AffineMap, with_dims: bool = True
    ) -> _ir.AffineMap:
        """Normalizes an indexing map to have the max known symbols and dims."""
        with self.context:
            return _ir.AffineMap.get(
                dim_count=self.affine_state.dim_count if with_dims else 0,
                symbol_count=self.affine_state.symbol_count,
                exprs=list(affine_map.results),
            )

    def to_yaml_custom_dict(self):
        self_dict = dict(args=self.ordered_operands)
        # TODO: Refactor the hierarchy internally when supporting more
        # than static (preserving this serialized form).
        self_dict["indexing_maps"] = LinalgIndexingMapsConfig(
            static_indexing_maps=self.indexing_maps
        )
        self_dict["iterator_types"] = self.iterator_types
        self_dict["assignments"] = self.assignments
        return self_dict

    def __repr__(self):
        lines = [f"LinalgGenericOpConfig(reduction_dims={self.reduction_dims},"]
        lines.append("operands=[")
        for def_config in self.ordered_operands:
            lines.append(f"  {repr(def_config)}")
        lines.append("], indexing_maps=[")
        for m in self.indexing_maps:
            lines.append(f"  {repr(m)}")
        lines.append(f"], iterator_types=[")
        for t in self.iterator_types:
            lines.append(f"  {t}")
        lines.append("])")
        return "\n".join(lines)


class LinalgOpConfig(YAMLObject):
    """Container for any supported linalg op type.

    This includes the concrete type by name for ease of parsing by systems
    that ignore tags.
    """

    yaml_tag = "!LinalgOpConfig"

    def __init__(
        self,
        metadata: OpMetadataDef,
        *,
        structured_op: Optional[LinalgStructuredOpConfig] = None,
    ):
        self.metadata = metadata
        self.structured_op = structured_op

    def to_yaml_custom_dict(self):
        self_dict = dict(
            metadata=self.metadata,
        )
        if self.structured_op:
            self_dict["structured_op"] = self.structured_op
        return self_dict

    @staticmethod
    def from_linalg_op_def(
        op_def: LinalgOpDef, context: Optional[_ir.Context] = None
    ) -> Sequence["LinalgOpConfig"]:
        """Expands a LinalgOpDef into corresponding Linalg configured ops."""
        # TODO: Many LinalgOpDef patterns need to expand to multiple generics.
        assert len(op_def.comprehensions) == 1, "Only one comprehension supported"
        return [
            LinalgOpConfig(
                op_def.metadata,
                structured_op=LinalgStructuredOpConfig(
                    op_def.comprehensions[0],
                    op_def.domain,
                    op_def.registered_operands.values(),
                    context,
                ),
            ),
        ]

    def __repr__(self):
        return (
            f"LinalgOpConfig(metadata={self.metadata},\n"
            f"structured_op={self.structured_op})"
        )
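
As a usage sketch (hedged: the `op_def` value is assumed to come from a @linalg_structured_op definition elsewhere in this package; it is not constructed here):

# Hypothetical driver code, not part of this file.
configs = LinalgOpConfig.from_linalg_op_def(op_def, context=None)
assert len(configs) == 1  # one comprehension -> one configured op
print(configs[0].as_linalg_yaml())  # the YAMLObject mixin provides the YAML form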

@@ -10,160 +10,192 @@ import inspect
import threading

from ..... import ir
from ...._ods_common import (
    get_op_result_or_value as _get_op_result_or_value,
    get_op_results_or_values as _get_op_results_or_values,
)
from .comprehension import *
from .config import *
from .emitter import *

_CONTEXT = threading.local()

StructuredOpOuts = Union[
    ir.Operation,
    ir.OpView,
    ir.OpResultList,
    Sequence[Union[ir.Value, ir.Operation, ir.OpView]],
]


@contextmanager
def bind_op_def(op_def: LinalgOpDef):
    if hasattr(_CONTEXT, "current_op_def"):
        raise ValueError("Cannot recursively define an operation")
    _CONTEXT.current_op_def = op_def
    try:
        yield op_def
    finally:
        del _CONTEXT.current_op_def


def current_op_def() -> LinalgOpDef:
    try:
        return _CONTEXT.current_op_def
    except AttributeError:
        raise ValueError(
            "Attempt to access the current op definition being defined "
            "but none is set. Did you mean to call this in an op definition?"
        )
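
The bind_op_def/current_op_def pair implements a common thread-local "current object" pattern. A self-contained sketch of the same idea with generic names (not this module's types):

import threading
from contextlib import contextmanager

_STATE = threading.local()

@contextmanager
def bind(value):
    if hasattr(_STATE, "current"):
        raise ValueError("Cannot recursively bind")
    _STATE.current = value
    try:
        yield value
    finally:
        del _STATE.current

with bind("op-under-construction") as v:
    assert _STATE.current is v  # visible only inside the with-block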


def _prepare_structured_op_outs(outs: StructuredOpOuts) -> ValueList:
    if isinstance(outs, (ir.Operation, ir.OpView)):
        return _get_op_results_or_values(outs)
    elif isinstance(outs, ir.OpResultList):
        return outs

    return [_get_op_result_or_value(o) for o in outs]


class DefinedOpCallable:
    """Callable that wraps any defined op function."""

    def __init__(self, op_name: str, op_def: LinalgOpDef):
        self.op_name = op_name
        self.op_def = op_def

    def __call__(
        self,
        *ins: Union[ir.Operation, ir.OpView, ir.Value],
        outs: StructuredOpOuts,
        **kwargs,
    ):
        """Emits the corresponding op definition as IR.

        Most arguments are passed through to the underlying emitter. The following
        keyword argument is interpreted here:
          emit_generic: Emits a generic form as appropriate (default False). If
            False, a named form is emitted (which must have been built in to the
            compiler).
        """
        emit_generic = kwargs.pop("emit_generic", False)
        if not isinstance(emit_generic, bool):
            raise ValueError(
                f"The named argument 'emit_generic' needs to be "
                f"of type bool but got {type(emit_generic)}"
            )

        op_configs = LinalgOpConfig.from_linalg_op_def(
            self.op_def, context=ir.Context.current
        )

        if len(op_configs) != 1:
            # TODO: Support composite ops.
            raise NotImplementedError(
                f"Emission of composite linalg ops not supported: {op_configs}"
            )

        ctx = ir.Context.current
        linalgDialect = ctx.get_dialect_descriptor("linalg")
        fully_qualified_name = "linalg." + self.op_name
        emit_generic = emit_generic or not ctx.is_registered_operation(
            fully_qualified_name
        )

        op_config = op_configs[0]
        out_values = _prepare_structured_op_outs(outs)
        in_values = [_get_op_result_or_value(i) for i in ins]
        if op_config.structured_op:
            if emit_generic:
                return emit_generic_structured_op(
                    op_config.structured_op, *in_values, outs=out_values, **kwargs
                )
            else:
                return emit_named_structured_op(
                    op_config.structured_op,
                    self.op_name,
                    self.op_def.metadata.cpp_class_name,
                    *in_values,
                    outs=out_values,
                    **kwargs,
                )

        raise NotImplementedError(
            f"Emission of linalg op type not supported: {op_config}"
        )
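
In use, a DefinedOpCallable is what the decorator below returns. A hedged call-site sketch, where `some_op` and the ir.Value operands `lhs`, `rhs`, `init` are assumed to already exist:

# Hypothetical call site, inside an ir.Context and insertion point.
result = some_op(lhs, rhs, outs=[init], emit_generic=True)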


def linalg_structured_op(
    dsl_func=None, *, op_name=None, op_class_name=None
) -> DefinedOpCallable:
    if dsl_func is None:
        # Curry the keyword args in for delayed application.
        return functools.partial(
            linalg_structured_op, op_name=op_name, op_class_name=op_class_name
        )
    # Determine default names by introspecting the function.
    if op_name is None:
        op_name = dsl_func.__name__
    if op_class_name is None:
        # Camel case it.
        op_class_name = f"{''.join(x.title() for x in op_name.split('_'))}Op"

    op_def = LinalgOpDef(
        name=op_name, cpp_class_name=op_class_name, doc=inspect.getdoc(dsl_func)
    )

    # Extract arguments and TensorDefs from the signature.
    dsl_func_args = list()
    sig = inspect.signature(dsl_func)
    for param_name, param in sig.parameters.items():
        param_default = param.default
        if isinstance(
            param_default,
            (
                TensorDef,
                ScalarDef,
                IndexAttrDef,
                UnaryFnAttrDef,
                BinaryFnAttrDef,
                TypeFnAttrDef,
            ),
        ):
            op_def.add_operand(param_name, param_default.operand_def)
        else:
            raise ValueError(
                f"@linalg_structured_op function parameters must be defaulted as "
                f"TensorDef(...), ScalarDef(...), or IndexAttrDef(...): "
                f"Found {param_name}: {param_default}"
            )
        dsl_func_args.append(param_default)

    # Invoke the DSL func to finish populating the op definition.
    with bind_op_def(op_def):
        dsl_func(*dsl_func_args)

    # TODO: The returned callable should be an IR emitter but that is not
    # upstreamed yet.
    return DefinedOpCallable(op_name, op_def)
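
The default C++ class name is derived by camel-casing the op name; the expression above can be checked standalone:

op_name = "batch_matmul"
op_class_name = f"{''.join(x.title() for x in op_name.split('_'))}Op"
assert op_class_name == "BatchMatmulOp"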


def domain(*dimensions: DimDef):
    if any(not isinstance(d, DimDef) for d in dimensions):
        raise ValueError(f"Expected dimensions of type DimDef but got {dimensions}")
    current_op_def().domain.extend(dimensions)


def implements(*interfaces: OpInterfaceDef):
    if any(not isinstance(intr, OpInterfaceDef) for intr in interfaces):
        raise ValueError(
            f"Expected interfaces of type OpInterfaceDef but got {interfaces}"
        )
    current_op_def().metadata.implements.extend(interfaces)


def defines(*definitions: OpDefinitionDef):
    if any(not isinstance(defi, OpDefinitionDef) for defi in definitions):
        raise ValueError(
            f"Expected definitions of type OpDefinitionDef but got {definitions}"
        )
    current_op_def().metadata.defines.extend(definitions)

@@ -30,123 +30,137 @@ __all__ = [


class ScalarFn:
    """A type of ScalarExpression that applies a function."""

    def __init__(
        self,
        kind: "FunctionKind",
        fn_name: Optional[str],
        attr_name: Optional[str],
        type_var: Optional["TypeVar"],
        operands: Sequence["ScalarExpression"],
    ):
        if bool(fn_name) + bool(attr_name) != 1:
            raise ValueError("One of 'fn_name', 'attr_name' must be specified")
        self.kind = kind
        self.fn_name = fn_name
        self.attr_name = attr_name
        self.type_var = type_var
        self.operands = operands

    def expr(self) -> "ScalarExpression":
        return ScalarExpression(scalar_fn=self)

    def __repr__(self):
        name = self.fn_name if self.fn_name else self.attr_name
        return (
            f"ScalarFn<{self.kind.name}.{name}>(type_var={self.type_var}, "
            f"operands=[{', '.join(repr(o) for o in self.operands)}])"
        )


class ScalarArg:
    """A type of ScalarExpression that references a named argument."""

    def __init__(self, arg: str):
        self.arg = arg

    def expr(self) -> "ScalarExpression":
        return ScalarExpression(scalar_arg=self)

    def __repr__(self):
        return f"ScalarArg({self.arg})"


class ScalarConst:
    """A type of ScalarExpression representing a constant."""

    def __init__(self, value: str):
        self.value = value

    def expr(self) -> "ScalarExpression":
        return ScalarExpression(scalar_const=self)

    def __repr__(self):
        return f"ScalarConst({self.value})"


class ScalarIndex:
    """A type of ScalarExpression accessing an iteration index."""

    def __init__(self, dim: int):
        self.dim = dim

    def expr(self) -> "ScalarExpression":
        return ScalarExpression(scalar_index=self)

    def __repr__(self):
        return f"ScalarIndex({self.dim})"


class ScalarExpression(YAMLObject):
    """An expression on scalar values.

    Can be one of:
      - ScalarFn
      - ScalarArg
      - ScalarConst
      - ScalarIndex
    """

    yaml_tag = "!ScalarExpression"

    def __init__(
        self,
        scalar_fn: Optional[ScalarFn] = None,
        scalar_arg: Optional[ScalarArg] = None,
        scalar_const: Optional[ScalarConst] = None,
        scalar_index: Optional[ScalarIndex] = None,
    ):
        if (
            bool(scalar_fn) + bool(scalar_arg) + bool(scalar_const) + bool(scalar_index)
        ) != 1:
            raise ValueError(
                "One of 'scalar_fn', 'scalar_arg', 'scalar_const', or "
                "'scalar_index' must be specified"
            )
        self.scalar_fn = scalar_fn
        self.scalar_arg = scalar_arg
        self.scalar_const = scalar_const
        self.scalar_index = scalar_index

    def to_yaml_custom_dict(self):
        if self.scalar_fn:
            scalar_fn_dict = dict(kind=self.scalar_fn.kind.name.lower())
            if self.scalar_fn.fn_name:
                scalar_fn_dict["fn_name"] = self.scalar_fn.fn_name
            if self.scalar_fn.attr_name:
                scalar_fn_dict["attr_name"] = self.scalar_fn.attr_name
            if self.scalar_fn.type_var:
                scalar_fn_dict["type_var"] = self.scalar_fn.type_var.name
            scalar_fn_dict["operands"] = list(self.scalar_fn.operands)
            return dict(scalar_fn=scalar_fn_dict)
        elif self.scalar_arg:
            return dict(scalar_arg=self.scalar_arg.arg)
        elif self.scalar_const:
            return dict(scalar_const=self.scalar_const.value)
        elif self.scalar_index:
            return dict(scalar_index=self.scalar_index.dim)
        else:
            raise ValueError(f"Unexpected ScalarExpression type: {self}")
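
In this module's context, the variants compose as follows (a small sketch using only the classes defined above):

expr = ScalarArg("lhs").expr()        # ScalarExpression wrapping a named argument
print(expr.to_yaml_custom_dict())     # {'scalar_arg': 'lhs'}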


class ScalarAssign(YAMLObject):
    """An assignment to a named argument (LHS of a comprehension)."""

    yaml_tag = "!ScalarAssign"

    def __init__(self, arg: str, value: ScalarExpression):
        self.arg = arg
        self.value = value

    def to_yaml_custom_dict(self):
        return dict(arg=self.arg, value=self.value)

    def __repr__(self):
        return f"ScalarAssign({self.arg}, {self.value})"

@@ -21,13 +21,11 @@ from typing import Dict
__all__ = [
    "TypeVar",
    "TV",
    # Predefined types.
    "I32",
    "I64",
    "F32",
    "F64",
    # TypeVar aliases.
    "T",
    "U",
@@ -36,34 +34,34 @@ __all__ = [


class TypeVar:
    """A replaceable type variable.

    Type variables are uniqued by name.
    """

    ALL_TYPEVARS = dict()  # type: Dict[str, "TypeVar"]

    def __new__(cls, name: str):
        existing = cls.ALL_TYPEVARS.get(name)
        if existing is not None:
            return existing
        new = super().__new__(cls)
        new.name = name
        cls.ALL_TYPEVARS[name] = new
        return new

    def __repr__(self):
        return f"TypeVar({self.name})"

    @classmethod
    def create_expando(cls):
        """Create an expando class that creates unique type vars on attr access."""

        class ExpandoTypeVars:
            def __getattr__(self, n):
                return cls(n)

        return ExpandoTypeVars()


# Expando access via TV.foo
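
Uniquing plus the expando makes type variables behave like interned symbols; a quick check in this module's context:

TV = TypeVar.create_expando()
assert TV.Foo is TypeVar("Foo")       # same interned instance
assert repr(TV.Foo) == "TypeVar(Foo)"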

@@ -6,11 +6,12 @@
import sys

try:
    import yaml
except ModuleNotFoundError as e:
    raise ModuleNotFoundError(
        f"This tool requires PyYAML but it was not installed. "
        f"Recommend: {sys.executable} -m pip install PyYAML"
    ) from e

__all__ = [
    "yaml_dump",
@@ -20,35 +21,33 @@ __all__ = [


class YAMLObject(yaml.YAMLObject):
    @classmethod
    def to_yaml(cls, dumper, self):
        """Default to a custom dictionary mapping."""
        return dumper.represent_mapping(cls.yaml_tag, self.to_yaml_custom_dict())

    def to_yaml_custom_dict(self):
        raise NotImplementedError()

    def as_linalg_yaml(self):
        return yaml_dump(self)


def multiline_str_representer(dumper, data):
    if len(data.splitlines()) > 1:
        return dumper.represent_scalar("tag:yaml.org,2002:str", data, style="|")
    else:
        return dumper.represent_scalar("tag:yaml.org,2002:str", data)


yaml.add_representer(str, multiline_str_representer)


def yaml_dump(data, sort_keys=False, **kwargs):
    return yaml.dump(data, sort_keys=sort_keys, **kwargs)


def yaml_dump_all(data, sort_keys=False, explicit_start=True, **kwargs):
    return yaml.dump_all(
        data, sort_keys=sort_keys, explicit_start=explicit_start, **kwargs
    )
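
The custom string representer switches multi-line strings to YAML block style; a standalone check with stock PyYAML:

import yaml  # PyYAML, as imported at the top of this file

print(yaml.dump({"doc": "line one\nline two"}))
# After the add_representer call above, this renders as:
# doc: |-
#   line one
#   line two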

@@ -5,6 +5,8 @@
from ._python_test_ops_gen import *
from .._mlir_libs._mlirPythonTest import TestAttr, TestType, TestTensorValue, TestTensorType


def register_python_test_dialect(context, load=True):
    from .._mlir_libs import _mlirPythonTest

    _mlirPythonTest.register_python_test_dialect(context, load)

@@ -6,16 +6,18 @@ from enum import Enum


class FailurePropagationMode(Enum):
    """Propagation mode for silenceable errors."""

    PROPAGATE = 1
    SUPPRESS = 2

    def _as_int(self):
        if self is FailurePropagationMode.PROPAGATE:
            return 1

        assert self is FailurePropagationMode.SUPPRESS
        return 2


from .._transform_ops_gen import *
from ..._mlir_libs._mlirDialectsTransform import *
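
For reference, the helper simply mirrors the enum values, e.g.:

assert FailurePropagationMode.SUPPRESS._as_int() == 2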

@@ -7,37 +7,37 @@ from ._mlir_libs import _mlirExecutionEngine as _execution_engine
import ctypes

__all__ = [
    "ExecutionEngine",
]


class ExecutionEngine(_execution_engine.ExecutionEngine):
    def lookup(self, name):
        """Look up a function emitted with the `llvm.emit_c_interface`
        attribute and return a ctypes callable.
        Raises a RuntimeError if the function isn't found.
        """
        func = self.raw_lookup("_mlir_ciface_" + name)
        if not func:
            raise RuntimeError("Unknown function " + name)
        prototype = ctypes.CFUNCTYPE(None, ctypes.c_void_p)
        return prototype(func)

    def invoke(self, name, *ctypes_args):
        """Invoke a function with the list of ctypes arguments.
        All arguments must be pointers.
        Raises a RuntimeError if the function isn't found.
        """
        func = self.lookup(name)
        packed_args = (ctypes.c_void_p * len(ctypes_args))()
        for argNum in range(len(ctypes_args)):
            packed_args[argNum] = ctypes.cast(ctypes_args[argNum], ctypes.c_void_p)
        func(packed_args)

    def register_runtime(self, name, ctypes_callback):
        """Register a runtime function available to the jitted code
        under the provided `name`. The `ctypes_callback` must be a
        `CFuncType` that outlives the execution engine.
        """
        callback = ctypes.cast(ctypes_callback, ctypes.c_void_p)
        self.raw_register_runtime("_mlir_ciface_" + name, callback)
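
A hedged sketch of driving invoke(): the `engine` and the jitted `add_one` function are assumed to exist and to have been emitted with llvm.emit_c_interface; only ctypes pointers may be passed.

import ctypes

res = ctypes.c_int32(41)
# invoke() prefixes the name with "_mlir_ciface_" and packs the pointers.
engine.invoke("add_one", ctypes.pointer(res))
print(res.value)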

@@ -8,124 +8,123 @@ from ._mlir_libs._mlir.ir import _GlobalDebug


# Convenience decorator for registering user-friendly Attribute builders.
def register_attribute_builder(kind):
    def decorator_builder(func):
        AttrBuilder.insert(kind, func)
        return func

    return decorator_builder


@register_attribute_builder("BoolAttr")
def _boolAttr(x, context):
    return BoolAttr.get(x, context=context)


@register_attribute_builder("IndexAttr")
def _indexAttr(x, context):
    return IntegerAttr.get(IndexType.get(context=context), x)


@register_attribute_builder("I16Attr")
def _i16Attr(x, context):
    return IntegerAttr.get(IntegerType.get_signless(16, context=context), x)


@register_attribute_builder("I32Attr")
def _i32Attr(x, context):
    return IntegerAttr.get(IntegerType.get_signless(32, context=context), x)


@register_attribute_builder("I64Attr")
def _i64Attr(x, context):
    return IntegerAttr.get(IntegerType.get_signless(64, context=context), x)


@register_attribute_builder("SI16Attr")
def _si16Attr(x, context):
    return IntegerAttr.get(IntegerType.get_signed(16, context=context), x)


@register_attribute_builder("SI32Attr")
def _si32Attr(x, context):
    return IntegerAttr.get(IntegerType.get_signed(32, context=context), x)


@register_attribute_builder("F32Attr")
def _f32Attr(x, context):
    return FloatAttr.get_f32(x, context=context)


@register_attribute_builder("F64Attr")
def _f64Attr(x, context):
    return FloatAttr.get_f64(x, context=context)


@register_attribute_builder("StrAttr")
def _stringAttr(x, context):
    return StringAttr.get(x, context=context)


@register_attribute_builder("SymbolNameAttr")
def _symbolNameAttr(x, context):
    return StringAttr.get(x, context=context)


@register_attribute_builder("SymbolRefAttr")
def _symbolRefAttr(x, context):
    return FlatSymbolRefAttr.get(x, context=context)


@register_attribute_builder("ArrayAttr")
def _arrayAttr(x, context):
    return ArrayAttr.get(x, context=context)


@register_attribute_builder("I32ArrayAttr")
def _i32ArrayAttr(x, context):
    return ArrayAttr.get([_i32Attr(v, context) for v in x])


@register_attribute_builder("I64ArrayAttr")
def _i64ArrayAttr(x, context):
    return ArrayAttr.get([_i64Attr(v, context) for v in x])


@register_attribute_builder("F32ArrayAttr")
def _f32ArrayAttr(x, context):
    return ArrayAttr.get([_f32Attr(v, context) for v in x])


@register_attribute_builder("F64ArrayAttr")
def _f64ArrayAttr(x, context):
    return ArrayAttr.get([_f64Attr(v, context) for v in x])


@register_attribute_builder("DenseI64ArrayAttr")
def _denseI64ArrayAttr(x, context):
    return DenseI64ArrayAttr.get(x, context=context)


@register_attribute_builder("TypeAttr")
def _typeAttr(x, context):
    return TypeAttr.get(x, context=context)


@register_attribute_builder("TypeArrayAttr")
def _typeArrayAttr(x, context):
    return _arrayAttr([TypeAttr.get(t, context=context) for t in x], context)


try:
    import numpy as np

    @register_attribute_builder("IndexElementsAttr")
    def _indexElementsAttr(x, context):
        return DenseElementsAttr.get(
            np.array(x, dtype=np.int64),
            type=IndexType.get(context=context),
            context=context,
        )

except ImportError:
    pass
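
Registering a builder of your own follows the same pattern; a hedged sketch ("MyFlagAttr" is a made-up kind name, not one registered upstream):

@register_attribute_builder("MyFlagAttr")
def _myFlagAttr(x, context):
    return BoolAttr.get(bool(x), context=context)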

@@ -9,131 +9,134 @@ import ctypes


class C128(ctypes.Structure):
    """A ctype representation for MLIR's Double Complex."""

    _fields_ = [("real", ctypes.c_double), ("imag", ctypes.c_double)]


class C64(ctypes.Structure):
    """A ctype representation for MLIR's Float Complex."""

    _fields_ = [("real", ctypes.c_float), ("imag", ctypes.c_float)]


class F16(ctypes.Structure):
    """A ctype representation for MLIR's Float16."""

    _fields_ = [("f16", ctypes.c_int16)]


# https://stackoverflow.com/questions/26921836/correct-way-to-test-for-numpy-dtype
def as_ctype(dtp):
    """Converts dtype to ctype."""
    if dtp == np.dtype(np.complex128):
        return C128
    if dtp == np.dtype(np.complex64):
        return C64
    if dtp == np.dtype(np.float16):
        return F16
    return np.ctypeslib.as_ctypes_type(dtp)


def to_numpy(array):
    """Converts ctypes array back to numpy dtype array."""
    if array.dtype == C128:
        return array.view("complex128")
    if array.dtype == C64:
        return array.view("complex64")
    if array.dtype == F16:
        return array.view("float16")
    return array


def make_nd_memref_descriptor(rank, dtype):
    class MemRefDescriptor(ctypes.Structure):
        """Builds an empty descriptor for the given rank/dtype, where rank>0."""

        _fields_ = [
            ("allocated", ctypes.c_longlong),
            ("aligned", ctypes.POINTER(dtype)),
            ("offset", ctypes.c_longlong),
            ("shape", ctypes.c_longlong * rank),
            ("strides", ctypes.c_longlong * rank),
        ]

    return MemRefDescriptor


def make_zero_d_memref_descriptor(dtype):
    class MemRefDescriptor(ctypes.Structure):
        """Builds an empty descriptor for the given dtype, where rank=0."""

        _fields_ = [
            ("allocated", ctypes.c_longlong),
            ("aligned", ctypes.POINTER(dtype)),
            ("offset", ctypes.c_longlong),
        ]

    return MemRefDescriptor


class UnrankedMemRefDescriptor(ctypes.Structure):
    """Creates a ctype struct for memref descriptor"""

    _fields_ = [("rank", ctypes.c_longlong), ("descriptor", ctypes.c_void_p)]


def get_ranked_memref_descriptor(nparray):
    """Returns a ranked memref descriptor for the given numpy array."""
    ctp = as_ctype(nparray.dtype)
    if nparray.ndim == 0:
        x = make_zero_d_memref_descriptor(ctp)()
        x.allocated = nparray.ctypes.data
        x.aligned = nparray.ctypes.data_as(ctypes.POINTER(ctp))
        x.offset = ctypes.c_longlong(0)
        return x

    x = make_nd_memref_descriptor(nparray.ndim, ctp)()
    x.allocated = nparray.ctypes.data
    x.aligned = nparray.ctypes.data_as(ctypes.POINTER(ctp))
    x.offset = ctypes.c_longlong(0)
    x.shape = nparray.ctypes.shape

    # Numpy uses byte quantities to express strides, MLIR OTOH uses the
    # torch abstraction which specifies strides in terms of elements.
    strides_ctype_t = ctypes.c_longlong * nparray.ndim
    x.strides = strides_ctype_t(*[x // nparray.itemsize for x in nparray.strides])
    return x


def get_unranked_memref_descriptor(nparray):
    """Returns a generic/unranked memref descriptor for the given numpy array."""
    d = UnrankedMemRefDescriptor()
    d.rank = nparray.ndim
    x = get_ranked_memref_descriptor(nparray)
    d.descriptor = ctypes.cast(ctypes.pointer(x), ctypes.c_void_p)
    return d


def unranked_memref_to_numpy(unranked_memref, np_dtype):
    """Converts unranked memrefs to numpy arrays."""
    ctp = as_ctype(np_dtype)
    descriptor = make_nd_memref_descriptor(unranked_memref[0].rank, ctp)
    val = ctypes.cast(unranked_memref[0].descriptor, ctypes.POINTER(descriptor))
    np_arr = np.ctypeslib.as_array(val[0].aligned, shape=val[0].shape)
    strided_arr = np.lib.stride_tricks.as_strided(
        np_arr,
        np.ctypeslib.as_array(val[0].shape),
        np.ctypeslib.as_array(val[0].strides) * np_arr.itemsize,
    )
    return to_numpy(strided_arr)


def ranked_memref_to_numpy(ranked_memref):
    """Converts ranked memrefs to numpy arrays."""
    np_arr = np.ctypeslib.as_array(
        ranked_memref[0].aligned, shape=ranked_memref[0].shape
    )
    strided_arr = np.lib.stride_tricks.as_strided(
        np_arr,
        np.ctypeslib.as_array(ranked_memref[0].shape),
        np.ctypeslib.as_array(ranked_memref[0].strides) * np_arr.itemsize,
    )
    return to_numpy(strided_arr)
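
A round-trip sketch using only the helpers above plus numpy/ctypes (assumes a C-contiguous array):

import ctypes
import numpy as np

a = np.arange(6, dtype=np.float32).reshape(2, 3)
desc = get_ranked_memref_descriptor(a)           # strides stored in elements
back = ranked_memref_to_numpy(ctypes.pointer(desc))
assert np.array_equal(a, back)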

@@ -1 +1 @@
config.suffixes.add(".c")

@@ -1,2 +1,2 @@
if not config.run_cuda_tests:
    config.unsupported = True

@@ -1,2 +1,2 @@
if not config.run_rocm_tests:
    config.unsupported = True

@@ -1,5 +1,3 @@
# Requires native execution.
if "host-supports-jit" not in config.available_features:
    config.unsupported = True

@@ -1,5 +1,3 @@
# Requires native execution.
if "host-supports-jit" not in config.available_features:
    config.unsupported = True

@@ -1,2 +1,2 @@
if not config.build_examples:
    config.unsupported = True

@@ -1,13 +1,12 @@
# Disable with sanitizers for now, this requires some more setup apparently.
for san in ["asan", "msan", "ubsan"]:
    if san in config.available_features:
        config.unsupported = True

config.substitutions.append(("%cmake_exe", config.host_cmake))
config.substitutions.append(("%cmake_generator", config.host_cmake_generator))
config.substitutions.append(("%host_cxx", config.host_cxx))
config.substitutions.append(("%host_cc", config.host_cc))
config.substitutions.append(("%enable_libcxx", config.enable_libcxx))
config.substitutions.append(("%mlir_cmake_dir", config.mlir_cmake_dir))
config.substitutions.append(("%llvm_use_linker", config.llvm_use_linker))
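
Each pair registers a lit substitution that expands inside test RUN lines; a hypothetical consumer:

# RUN: %cmake_exe -G "%cmake_generator" -DMLIR_DIR=%mlir_cmake_dir %s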

@@ -1,5 +1,5 @@
import sys

# Windows does not have aligned_alloc
if sys.platform == "win32":
    config.unsupported = True

@@ -1,4 +1,4 @@
import platform

if platform.machine() != "x86_64":
    config.unsupported = True

@@ -1,18 +1,22 @@
import sys

lli_cmd = "lli"
if config.riscv_emulator_lli_executable:
    lli_cmd = config.riscv_emulator_lli_executable

config.substitutions.append(
    (
        "%mlir_native_utils_lib_dir",
        config.riscv_emulator_utils_lib_dir or config.mlir_lib_dir,
    )
)

if config.riscv_vector_emulator_executable:
    # Run test in qemu emulator.
    emulation_cmd = config.riscv_vector_emulator_executable
    if config.riscv_vector_emulator_options:
        emulation_cmd = emulation_cmd + " " + config.riscv_vector_emulator_options
    emulation_cmd = emulation_cmd + " " + lli_cmd + " --march=riscv64 -mattr=+v "
    config.substitutions.append(("%lli", emulation_cmd))
else:
    config.substitutions.append(("%lli", lli_cmd))

@@ -2,14 +2,16 @@ import sys
from lit.llvm import llvm_config

# FIXME: %mlir_native_utils_lib_dir is set incorrectly on Windows
if sys.platform == "win32":
    config.unsupported = True

# ArmSVE tests must be enabled via build flag.
if config.mlir_run_arm_sve_tests:
    config.substitutions.append(("%ENABLE_VLA", "true"))
    config.substitutions.append(
        ("%VLA_ARCH_ATTR_OPTIONS", '--march=aarch64 --mattr="+sve"')
    )
else:
    config.substitutions.append(("%ENABLE_VLA", "false"))
    config.substitutions.append(("%VLA_ARCH_ATTR_OPTIONS", ""))
config.substitutions.append(("%mlir_native_utils_lib_dir", config.mlir_lib_dir))

@@ -1,2 +1,2 @@
if not config.enable_cuda_runner or not config.mlir_run_cuda_sm80_tests:
    config.unsupported = True

@@ -1,5 +1,5 @@
# Disable ASAN's leak detection for python OpsDSL tests.
config.environment["ASAN_OPTIONS"] = "detect_leaks=0"
# Only run when python bindings are enabled.
if not config.enable_bindings_python:
    config.unsupported = True
|
||||
|
||||
@@ -18,42 +18,45 @@ _SCRIPT_PATH = os.path.dirname(os.path.abspath(__file__))
sys.path.append(_SCRIPT_PATH)
from tools import sparse_compiler


@dsl.linalg_structured_op
def sddmm_dsl(
    A=dsl.TensorDef(dsl.T, dsl.S.M, dsl.S.K),
    B=dsl.TensorDef(dsl.T, dsl.S.K, dsl.S.N),
    S=dsl.TensorDef(dsl.T, dsl.S.M, dsl.S.N),
    C=dsl.TensorDef(dsl.T, dsl.S.M, dsl.S.N, output=True)):
  C[dsl.D.m,
    dsl.D.n] += S[dsl.D.m, dsl.D.n] * A[dsl.D.m, dsl.D.k] * B[dsl.D.k, dsl.D.n]
    C=dsl.TensorDef(dsl.T, dsl.S.M, dsl.S.N, output=True),
):
    C[dsl.D.m, dsl.D.n] += (
        S[dsl.D.m, dsl.D.n] * A[dsl.D.m, dsl.D.k] * B[dsl.D.k, dsl.D.n]
    )


def build_SDDMM(attr: st.EncodingAttr):
  """Build SDDMM kernel.
    """Build SDDMM kernel.

  This method generates a linalg op for matrix multiplication using
  just the Python API. Effectively, a generic linalg op is constructed
  that computes C(i,j) += S(i,j) SUM_k A(i,k) B(k,j) for sparse S.
  """
  module = ir.Module.create()
  f64 = ir.F64Type.get()
  a = ir.RankedTensorType.get([8, 8], f64)
  b = ir.RankedTensorType.get([8, 8], f64)
  c = ir.RankedTensorType.get([8, 8], f64)
  s = ir.RankedTensorType.get([8, 8], f64, attr)
  arguments = [a, b, s, c]
  with ir.InsertionPoint(module.body):
    This method generates a linalg op for matrix multiplication using
    just the Python API. Effectively, a generic linalg op is constructed
    that computes C(i,j) += S(i,j) SUM_k A(i,k) B(k,j) for sparse S.
    """
    module = ir.Module.create()
    f64 = ir.F64Type.get()
    a = ir.RankedTensorType.get([8, 8], f64)
    b = ir.RankedTensorType.get([8, 8], f64)
    c = ir.RankedTensorType.get([8, 8], f64)
    s = ir.RankedTensorType.get([8, 8], f64, attr)
    arguments = [a, b, s, c]
    with ir.InsertionPoint(module.body):

    @func.FuncOp.from_py_func(*arguments)
    def sddmm(*args):
      return sddmm_dsl(args[0], args[1], args[2], outs=[args[3]])
        @func.FuncOp.from_py_func(*arguments)
        def sddmm(*args):
            return sddmm_dsl(args[0], args[1], args[2], outs=[args[3]])

  return module
    return module


def boilerplate(attr: st.EncodingAttr):
  """Returns boilerplate code for main driver."""
  return f"""
    """Returns boilerplate code for main driver."""
    return f"""
func.func @main(%a: tensor<8x8xf64>,
                %b: tensor<8x8xf64>,
                %c: tensor<8x8xf64>) -> tensor<8x8xf64> attributes {{ llvm.emit_c_interface }} {{
@@ -69,92 +72,100 @@ func.func @main(%a: tensor<8x8xf64>,


def build_compile_and_run_SDDMMM(attr: st.EncodingAttr, compiler):
  # Build.
  module = build_SDDMM(attr)
  func = str(module.operation.regions[0].blocks[0].operations[0].operation)
  module = ir.Module.parse(func + boilerplate(attr))
    # Build.
    module = build_SDDMM(attr)
    func = str(module.operation.regions[0].blocks[0].operations[0].operation)
    module = ir.Module.parse(func + boilerplate(attr))

  # Compile.
  engine = compiler.compile_and_jit(module)
    # Compile.
    engine = compiler.compile_and_jit(module)

  # Set up numpy input and buffer for output.
  a = np.array([[1.1, 2.1, 3.1, 4.1, 5.1, 6.1, 7.1, 8.1],
                [1.2, 2.2, 3.2, 4.2, 5.2, 6.2, 7.2, 8.2],
                [1.3, 2.3, 3.3, 4.3, 5.3, 6.3, 7.3, 8.3],
                [1.4, 2.4, 3.4, 4.4, 5.4, 6.4, 7.4, 8.4],
                [1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.5],
                [1.6, 2.6, 3.6, 4.6, 5.6, 6.6, 7.6, 8.6],
                [1.7, 2.7, 3.7, 4.7, 5.7, 6.7, 7.7, 8.7],
                [1.8, 2.8, 3.8, 4.8, 5.8, 6.8, 7.8, 8.8]], np.float64)
  b = np.ones((8, 8), np.float64)
  c = np.zeros((8, 8), np.float64)
    # Set up numpy input and buffer for output.
    a = np.array(
        [
            [1.1, 2.1, 3.1, 4.1, 5.1, 6.1, 7.1, 8.1],
            [1.2, 2.2, 3.2, 4.2, 5.2, 6.2, 7.2, 8.2],
            [1.3, 2.3, 3.3, 4.3, 5.3, 6.3, 7.3, 8.3],
            [1.4, 2.4, 3.4, 4.4, 5.4, 6.4, 7.4, 8.4],
            [1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.5],
            [1.6, 2.6, 3.6, 4.6, 5.6, 6.6, 7.6, 8.6],
            [1.7, 2.7, 3.7, 4.7, 5.7, 6.7, 7.7, 8.7],
            [1.8, 2.8, 3.8, 4.8, 5.8, 6.8, 7.8, 8.8],
        ],
        np.float64,
    )
    b = np.ones((8, 8), np.float64)
    c = np.zeros((8, 8), np.float64)

  mem_a = ctypes.pointer(ctypes.pointer(rt.get_ranked_memref_descriptor(a)))
  mem_b = ctypes.pointer(ctypes.pointer(rt.get_ranked_memref_descriptor(b)))
  mem_c = ctypes.pointer(ctypes.pointer(rt.get_ranked_memref_descriptor(c)))
    mem_a = ctypes.pointer(ctypes.pointer(rt.get_ranked_memref_descriptor(a)))
    mem_b = ctypes.pointer(ctypes.pointer(rt.get_ranked_memref_descriptor(b)))
    mem_c = ctypes.pointer(ctypes.pointer(rt.get_ranked_memref_descriptor(c)))

  # Allocate a MemRefDescriptor to receive the output tensor.
  # The buffer itself is allocated inside the MLIR code generation.
  ref_out = rt.make_nd_memref_descriptor(2, ctypes.c_double)()
  mem_out = ctypes.pointer(ctypes.pointer(ref_out))
    # Allocate a MemRefDescriptor to receive the output tensor.
    # The buffer itself is allocated inside the MLIR code generation.
    ref_out = rt.make_nd_memref_descriptor(2, ctypes.c_double)()
    mem_out = ctypes.pointer(ctypes.pointer(ref_out))

  # Invoke the kernel and get numpy output.
  # Built-in bufferization uses in-out buffers.
  # TODO: replace with inplace comprehensive bufferization.
  engine.invoke('main', mem_out, mem_a, mem_b, mem_c)
    # Invoke the kernel and get numpy output.
    # Built-in bufferization uses in-out buffers.
    # TODO: replace with inplace comprehensive bufferization.
    engine.invoke("main", mem_out, mem_a, mem_b, mem_c)

  # Sanity check on computed result. Only a few elements
  # are sampled from the full dense matrix multiplication.
  full_matmul = np.matmul(a, b)
  expected = np.zeros((8, 8), np.float64)
  expected[0, 0] = 1.0 * full_matmul[0, 0]
  expected[0, 2] = 2.0 * full_matmul[0, 2]
  expected[4, 1] = 3.0 * full_matmul[4, 1]
  c = rt.ranked_memref_to_numpy(mem_out[0])
  if np.allclose(c, expected):
    pass
  else:
    quit(f'FAILURE')
    # Sanity check on computed result. Only a few elements
    # are sampled from the full dense matrix multiplication.
    full_matmul = np.matmul(a, b)
    expected = np.zeros((8, 8), np.float64)
    expected[0, 0] = 1.0 * full_matmul[0, 0]
    expected[0, 2] = 2.0 * full_matmul[0, 2]
    expected[4, 1] = 3.0 * full_matmul[4, 1]
    c = rt.ranked_memref_to_numpy(mem_out[0])
    if np.allclose(c, expected):
        pass
    else:
        quit(f"FAILURE")


def main():
  support_lib = os.getenv('SUPPORT_LIB')
  assert support_lib is not None, 'SUPPORT_LIB is undefined'
  if not os.path.exists(support_lib):
    raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT),
                            support_lib)
    support_lib = os.getenv("SUPPORT_LIB")
    assert support_lib is not None, "SUPPORT_LIB is undefined"
    if not os.path.exists(support_lib):
        raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT), support_lib)

  # CHECK-LABEL: TEST: testSDDMMM
  print('\nTEST: testSDDMMM')
  with ir.Context() as ctx, ir.Location.unknown():
    count = 0
    # Loop over various ways to compile and annotate the SDDMM kernel with
    # a *single* sparse tensor. Note that we deliberately do not exhaustively
    # search the full state space to reduce runtime of the test. It is
    # straightforward to adapt the code below to explore more combinations.
    levels = [[st.DimLevelType.dense, st.DimLevelType.dense],
              [st.DimLevelType.dense, st.DimLevelType.compressed],
              [st.DimLevelType.compressed, st.DimLevelType.dense],
              [st.DimLevelType.compressed, st.DimLevelType.compressed]]
    orderings = [
        ir.AffineMap.get_permutation([0, 1]),
        ir.AffineMap.get_permutation([1, 0])
    ]
    for level in levels:
      for ordering in orderings:
        for pwidth in [32]:
          for iwidth in [32]:
            for e in [True]:
              attr = st.EncodingAttr.get(level, ordering, None, pwidth,
                                         iwidth)
              opt = (f'parallelization-strategy=none')
              compiler = sparse_compiler.SparseCompiler(
                  options=opt, opt_level=0, shared_libs=[support_lib])
              build_compile_and_run_SDDMMM(attr, compiler)
              count = count + 1
    # CHECK: Passed 8 tests
    print('Passed ', count, 'tests')
    # CHECK-LABEL: TEST: testSDDMMM
    print("\nTEST: testSDDMMM")
    with ir.Context() as ctx, ir.Location.unknown():
        count = 0
        # Loop over various ways to compile and annotate the SDDMM kernel with
        # a *single* sparse tensor. Note that we deliberately do not exhaustively
        # search the full state space to reduce runtime of the test. It is
        # straightforward to adapt the code below to explore more combinations.
        levels = [
            [st.DimLevelType.dense, st.DimLevelType.dense],
            [st.DimLevelType.dense, st.DimLevelType.compressed],
            [st.DimLevelType.compressed, st.DimLevelType.dense],
            [st.DimLevelType.compressed, st.DimLevelType.compressed],
        ]
        orderings = [
            ir.AffineMap.get_permutation([0, 1]),
            ir.AffineMap.get_permutation([1, 0]),
        ]
        for level in levels:
            for ordering in orderings:
                for pwidth in [32]:
                    for iwidth in [32]:
                        for e in [True]:
                            attr = st.EncodingAttr.get(
                                level, ordering, None, pwidth, iwidth
                            )
                            opt = f"parallelization-strategy=none"
                            compiler = sparse_compiler.SparseCompiler(
                                options=opt, opt_level=0, shared_libs=[support_lib]
                            )
                            build_compile_and_run_SDDMMM(attr, compiler)
                            count = count + 1
        # CHECK: Passed 8 tests
        print("Passed ", count, "tests")


if __name__ == '__main__':
  main()
if __name__ == "__main__":
    main()

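Aside for readers following the kernel rather than the reformatting: the SDDMM semantics checked above can be sketched in plain NumPy. The sampling matrix s below is a hypothetical reconstruction; its nonzeros (1.0 at (0, 0), 2.0 at (0, 2), 3.0 at (4, 1)) are inferred from the expected[] assignments in the test, since the actual S lives in the boilerplate string this hunk does not show.

import numpy as np

a = np.arange(1.0, 65.0).reshape(8, 8)  # any dense input works for the sketch
b = np.ones((8, 8), np.float64)
full_matmul = a @ b

s = np.zeros((8, 8), np.float64)  # hypothetical sampling matrix
s[0, 0], s[0, 2], s[4, 1] = 1.0, 2.0, 3.0

# SDDMM: scale the dense product elementwise by the sparse sampler, so only
# the sampled positions of A @ B are ever needed.
c = s * full_matmul
assert np.allclose(c[0, 2], 2.0 * full_matmul[0, 2])
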
@@ -18,45 +18,47 @@ _SCRIPT_PATH = os.path.dirname(os.path.abspath(__file__))
sys.path.append(_SCRIPT_PATH)
from tools import sparse_compiler


@dsl.linalg_structured_op
def matmul_dsl(
    A=dsl.TensorDef(dsl.T, dsl.S.M, dsl.S.K),
    B=dsl.TensorDef(dsl.T, dsl.S.K, dsl.S.N),
    C=dsl.TensorDef(dsl.T, dsl.S.M, dsl.S.N, output=True)):
  C[dsl.D.m, dsl.D.n] += A[dsl.D.m, dsl.D.k] * B[dsl.D.k, dsl.D.n]
    C=dsl.TensorDef(dsl.T, dsl.S.M, dsl.S.N, output=True),
):
    C[dsl.D.m, dsl.D.n] += A[dsl.D.m, dsl.D.k] * B[dsl.D.k, dsl.D.n]


def build_SpMM(attr: st.EncodingAttr):
  """Build SpMM kernel.
    """Build SpMM kernel.

  This method generates a linalg op for matrix multiplication using
  just the Python API. Effectively, a generic linalg op is constructed
  that computes C(i,j) += A(i,k) * B(k,j) for annotated matrix A.
  """
  module = ir.Module.create()
  f64 = ir.F64Type.get()
  a = ir.RankedTensorType.get([3, 4], f64, attr)
  b = ir.RankedTensorType.get([4, 2], f64)
  c = ir.RankedTensorType.get([3, 2], f64)
  arguments = [a, b, c]
  with ir.InsertionPoint(module.body):
    This method generates a linalg op for matrix multiplication using
    just the Python API. Effectively, a generic linalg op is constructed
    that computes C(i,j) += A(i,k) * B(k,j) for annotated matrix A.
    """
    module = ir.Module.create()
    f64 = ir.F64Type.get()
    a = ir.RankedTensorType.get([3, 4], f64, attr)
    b = ir.RankedTensorType.get([4, 2], f64)
    c = ir.RankedTensorType.get([3, 2], f64)
    arguments = [a, b, c]
    with ir.InsertionPoint(module.body):

    @func.FuncOp.from_py_func(*arguments)
    def spMxM(*args):
      return matmul_dsl(args[0], args[1], outs=[args[2]])
        @func.FuncOp.from_py_func(*arguments)
        def spMxM(*args):
            return matmul_dsl(args[0], args[1], outs=[args[2]])

  return module
    return module


def boilerplate(attr: st.EncodingAttr):
  """Returns boilerplate main method.
    """Returns boilerplate main method.

  This method sets up a boilerplate main method that takes three tensors
  (a, b, c), converts the first tensor a into a sparse tensor, and then
  calls the sparse kernel for matrix multiplication. For convenience,
  this part is purely done as string input.
  """
  return f"""
    This method sets up a boilerplate main method that takes three tensors
    (a, b, c), converts the first tensor a into a sparse tensor, and then
    calls the sparse kernel for matrix multiplication. For convenience,
    this part is purely done as string input.
    """
    return f"""
func.func @main(%ad: tensor<3x4xf64>, %b: tensor<4x2xf64>, %c: tensor<3x2xf64>) -> tensor<3x2xf64>
  attributes {{ llvm.emit_c_interface }} {{
  %a = sparse_tensor.convert %ad : tensor<3x4xf64> to tensor<3x4xf64, {attr}>
@@ -69,82 +71,87 @@ func.func @main(%ad: tensor<3x4xf64>, %b: tensor<4x2xf64>, %c: tensor<3x2xf64>)


def build_compile_and_run_SpMM(attr: st.EncodingAttr, compiler):
  # Build.
  module = build_SpMM(attr)
  func = str(module.operation.regions[0].blocks[0].operations[0].operation)
  module = ir.Module.parse(func + boilerplate(attr))
    # Build.
    module = build_SpMM(attr)
    func = str(module.operation.regions[0].blocks[0].operations[0].operation)
    module = ir.Module.parse(func + boilerplate(attr))

  # Compile.
  engine = compiler.compile_and_jit(module)
    # Compile.
    engine = compiler.compile_and_jit(module)

  # Set up numpy input and buffer for output.
  a = np.array(
      [[1.1, 0.0, 0.0, 1.4], [0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 3.3, 0.0]],
      np.float64)
  b = np.array([[1.0, 2.0], [4.0, 3.0], [5.0, 6.0], [8.0, 7.0]], np.float64)
  c = np.zeros((3, 2), np.float64)
    # Set up numpy input and buffer for output.
    a = np.array(
        [[1.1, 0.0, 0.0, 1.4], [0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 3.3, 0.0]], np.float64
    )
    b = np.array([[1.0, 2.0], [4.0, 3.0], [5.0, 6.0], [8.0, 7.0]], np.float64)
    c = np.zeros((3, 2), np.float64)

  mem_a = ctypes.pointer(ctypes.pointer(rt.get_ranked_memref_descriptor(a)))
  mem_b = ctypes.pointer(ctypes.pointer(rt.get_ranked_memref_descriptor(b)))
  mem_c = ctypes.pointer(ctypes.pointer(rt.get_ranked_memref_descriptor(c)))
  # Allocate a MemRefDescriptor to receive the output tensor.
  # The buffer itself is allocated inside the MLIR code generation.
  ref_out = rt.make_nd_memref_descriptor(2, ctypes.c_double)()
  mem_out = ctypes.pointer(ctypes.pointer(ref_out))
    mem_a = ctypes.pointer(ctypes.pointer(rt.get_ranked_memref_descriptor(a)))
    mem_b = ctypes.pointer(ctypes.pointer(rt.get_ranked_memref_descriptor(b)))
    mem_c = ctypes.pointer(ctypes.pointer(rt.get_ranked_memref_descriptor(c)))
    # Allocate a MemRefDescriptor to receive the output tensor.
    # The buffer itself is allocated inside the MLIR code generation.
    ref_out = rt.make_nd_memref_descriptor(2, ctypes.c_double)()
    mem_out = ctypes.pointer(ctypes.pointer(ref_out))

  # Invoke the kernel and get numpy output.
  # Built-in bufferization uses in-out buffers.
  # TODO: replace with inplace comprehensive bufferization.
  engine.invoke('main', mem_out, mem_a, mem_b, mem_c)
    # Invoke the kernel and get numpy output.
    # Built-in bufferization uses in-out buffers.
    # TODO: replace with inplace comprehensive bufferization.
    engine.invoke("main", mem_out, mem_a, mem_b, mem_c)

  # Sanity check on computed result.
  expected = np.matmul(a, b);
  c = rt.ranked_memref_to_numpy(mem_out[0])
  if np.allclose(c, expected):
    pass
  else:
    quit(f'FAILURE')
    # Sanity check on computed result.
    expected = np.matmul(a, b)
    c = rt.ranked_memref_to_numpy(mem_out[0])
    if np.allclose(c, expected):
        pass
    else:
        quit(f"FAILURE")


def main():
  support_lib = os.getenv('SUPPORT_LIB')
  assert support_lib is not None, 'SUPPORT_LIB is undefined'
  if not os.path.exists(support_lib):
    raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT), support_lib)
    support_lib = os.getenv("SUPPORT_LIB")
    assert support_lib is not None, "SUPPORT_LIB is undefined"
    if not os.path.exists(support_lib):
        raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT), support_lib)

  # CHECK-LABEL: TEST: testSpMM
  print('\nTEST: testSpMM')
  with ir.Context() as ctx, ir.Location.unknown():
    count = 0
    # Loop over various ways to compile and annotate the SpMM kernel with
    # a *single* sparse tensor. Note that we deliberately do not exhaustively
    # search the full state space to reduce runtime of the test. It is
    # straightforward to adapt the code below to explore more combinations.
    # CHECK-LABEL: TEST: testSpMM
    print("\nTEST: testSpMM")
    with ir.Context() as ctx, ir.Location.unknown():
        count = 0
        # Loop over various ways to compile and annotate the SpMM kernel with
        # a *single* sparse tensor. Note that we deliberately do not exhaustively
        # search the full state space to reduce runtime of the test. It is
        # straightforward to adapt the code below to explore more combinations.

    vl = 1
    e = False
    opt = (f'parallelization-strategy=none')
    levels = [[st.DimLevelType.dense, st.DimLevelType.dense],
              [st.DimLevelType.dense, st.DimLevelType.compressed],
              [st.DimLevelType.compressed, st.DimLevelType.dense],
              [st.DimLevelType.compressed, st.DimLevelType.compressed]]
    orderings = [
        ir.AffineMap.get_permutation([0, 1]),
        ir.AffineMap.get_permutation([1, 0])
    ]
    bitwidths = [0]
    compiler = sparse_compiler.SparseCompiler(
        options=opt, opt_level=0, shared_libs=[support_lib])
    for level in levels:
      for ordering in orderings:
        for pwidth in bitwidths:
          for iwidth in bitwidths:
            attr = st.EncodingAttr.get(level, ordering, None, pwidth, iwidth)
            build_compile_and_run_SpMM(attr, compiler)
            count = count + 1
    # CHECK: Passed 8 tests
    print('Passed ', count, 'tests')
        vl = 1
        e = False
        opt = f"parallelization-strategy=none"
        levels = [
            [st.DimLevelType.dense, st.DimLevelType.dense],
            [st.DimLevelType.dense, st.DimLevelType.compressed],
            [st.DimLevelType.compressed, st.DimLevelType.dense],
            [st.DimLevelType.compressed, st.DimLevelType.compressed],
        ]
        orderings = [
            ir.AffineMap.get_permutation([0, 1]),
            ir.AffineMap.get_permutation([1, 0]),
        ]
        bitwidths = [0]
        compiler = sparse_compiler.SparseCompiler(
            options=opt, opt_level=0, shared_libs=[support_lib]
        )
        for level in levels:
            for ordering in orderings:
                for pwidth in bitwidths:
                    for iwidth in bitwidths:
                        attr = st.EncodingAttr.get(
                            level, ordering, None, pwidth, iwidth
                        )
                        build_compile_and_run_SpMM(attr, compiler)
                        count = count + 1
        # CHECK: Passed 8 tests
        print("Passed ", count, "tests")


if __name__ == '__main__':
  main()
if __name__ == "__main__":
    main()

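A quick cross-check of the `CHECK: Passed 8 tests` expectation in both drivers above: the loop nest enumerates 4 level combinations x 2 orderings x 1 pointer width x 1 index width. A standalone sketch of the same count, with plain strings standing in for the DimLevelType values:

import itertools

levels = list(itertools.product(["dense", "compressed"], repeat=2))  # 4 combos
orderings = ["identity", "transposed"]  # the two 2-D permutations
bitwidths = [0]

count = sum(1 for _ in itertools.product(levels, orderings, bitwidths, bitwidths))
print(count)  # 8, matching the FileCheck line
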
@@ -57,49 +57,52 @@ func.func @main(%ad: tensor<3x4xf64>, %bd: tensor<3x4xf64>) -> tensor<3x4xf64, #


def _run_test(support_lib, kernel):
  """Compiles, runs and checks results."""
  compiler = sparse_compiler.SparseCompiler(
      options='', opt_level=2, shared_libs=[support_lib])
  module = ir.Module.parse(kernel)
  engine = compiler.compile_and_jit(module)
    """Compiles, runs and checks results."""
    compiler = sparse_compiler.SparseCompiler(
        options="", opt_level=2, shared_libs=[support_lib]
    )
    module = ir.Module.parse(kernel)
    engine = compiler.compile_and_jit(module)

  # Set up numpy inputs and buffer for output.
  a = np.array(
      [[1.1, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 6.6, 0.0]],
      np.float64)
  b = np.array(
      [[1.1, 0.0, 0.0, 2.8], [0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0]],
      np.float64)
    # Set up numpy inputs and buffer for output.
    a = np.array(
        [[1.1, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 6.6, 0.0]], np.float64
    )
    b = np.array(
        [[1.1, 0.0, 0.0, 2.8], [0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0]], np.float64
    )

  mem_a = ctypes.pointer(ctypes.pointer(rt.get_ranked_memref_descriptor(a)))
  mem_b = ctypes.pointer(ctypes.pointer(rt.get_ranked_memref_descriptor(b)))
    mem_a = ctypes.pointer(ctypes.pointer(rt.get_ranked_memref_descriptor(a)))
    mem_b = ctypes.pointer(ctypes.pointer(rt.get_ranked_memref_descriptor(b)))

  # The sparse tensor output is a pointer to pointer of char.
  out = ctypes.c_char(0)
  mem_out = ctypes.pointer(ctypes.pointer(out))
    # The sparse tensor output is a pointer to pointer of char.
    out = ctypes.c_char(0)
    mem_out = ctypes.pointer(ctypes.pointer(out))

  # Invoke the kernel.
  engine.invoke('main', mem_a, mem_b, mem_out)
    # Invoke the kernel.
    engine.invoke("main", mem_a, mem_b, mem_out)

  # Retrieve and check the result.
  rank, nse, shape, values, indices = test_tools.sparse_tensor_to_coo_tensor(
      support_lib, mem_out[0], np.float64)
    # Retrieve and check the result.
    rank, nse, shape, values, indices = test_tools.sparse_tensor_to_coo_tensor(
        support_lib, mem_out[0], np.float64
    )

  # CHECK: PASSED
  if np.allclose(values, [2.2, 2.8, 6.6]) and np.allclose(
      indices, [[0, 0], [0, 3], [2, 2]]):
    print('PASSED')
  else:
    quit('FAILURE')
    # CHECK: PASSED
    if np.allclose(values, [2.2, 2.8, 6.6]) and np.allclose(
        indices, [[0, 0], [0, 3], [2, 2]]
    ):
        print("PASSED")
    else:
        quit("FAILURE")


def test_elementwise_add():
  # Obtain path to runtime support library.
  support_lib = os.getenv('SUPPORT_LIB')
  assert support_lib is not None, 'SUPPORT_LIB is undefined'
  assert os.path.exists(support_lib), f'{support_lib} does not exist'
  with ir.Context() as ctx, ir.Location.unknown():
    _run_test(support_lib, _KERNEL_STR)
    # Obtain path to runtime support library.
    support_lib = os.getenv("SUPPORT_LIB")
    assert support_lib is not None, "SUPPORT_LIB is undefined"
    assert os.path.exists(support_lib), f"{support_lib} does not exist"
    with ir.Context() as ctx, ir.Location.unknown():
        _run_test(support_lib, _KERNEL_STR)


test_elementwise_add()

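The checked values follow directly from elementwise addition of the two inputs; a NumPy sketch of the same computation and the COO triples it yields:

import numpy as np

a = np.array([[1.1, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 6.6, 0.0]])
b = np.array([[1.1, 0.0, 0.0, 2.8], [0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0]])

c = a + b
indices = np.argwhere(c != 0.0)  # [[0, 0], [0, 3], [2, 2]]
values = c[c != 0.0]             # [2.2, 2.8, 6.6]
print(indices.tolist(), values.tolist())
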
@@ -18,8 +18,8 @@ from tools import sparse_compiler

# TODO: move more into actual IR building.
def boilerplate(attr: st.EncodingAttr):
  """Returns boilerplate main method."""
  return f"""
    """Returns boilerplate main method."""
    return f"""
func.func @main(%p : !llvm.ptr<i8>) -> () attributes {{ llvm.emit_c_interface }} {{
  %d = arith.constant sparse<[[0, 0], [1, 1], [0, 9], [9, 0], [4, 4]],
                             [1.0, 2.0, 3.0, 4.0, 5.0]> : tensor<10x10xf64>
@@ -31,13 +31,13 @@ func.func @main(%p : !llvm.ptr<i8>) -> () attributes {{ llvm.emit_c_interface }}


def expected():
  """Returns expected contents of output.
    """Returns expected contents of output.

  Regardless of the dimension ordering, compression, and bitwidths that are
  used in the sparse tensor, the output is always lexicographically sorted
  by natural index order.
  """
  return f"""; extended FROSTT format
    Regardless of the dimension ordering, compression, and bitwidths that are
    used in the sparse tensor, the output is always lexicographically sorted
    by natural index order.
    """
    return f"""; extended FROSTT format
2 5
10 10
1 1 1
@@ -49,53 +49,55 @@ def expected():


def build_compile_and_run_output(attr: st.EncodingAttr, compiler):
  # Build and Compile.
  module = ir.Module.parse(boilerplate(attr))
  engine = compiler.compile_and_jit(module)
    # Build and Compile.
    module = ir.Module.parse(boilerplate(attr))
    engine = compiler.compile_and_jit(module)

  # Invoke the kernel and compare output.
  with tempfile.TemporaryDirectory() as test_dir:
    out = os.path.join(test_dir, 'out.tns')
    buf = out.encode('utf-8')
    mem_a = ctypes.pointer(ctypes.pointer(ctypes.create_string_buffer(buf)))
    engine.invoke('main', mem_a)
    # Invoke the kernel and compare output.
    with tempfile.TemporaryDirectory() as test_dir:
        out = os.path.join(test_dir, "out.tns")
        buf = out.encode("utf-8")
        mem_a = ctypes.pointer(ctypes.pointer(ctypes.create_string_buffer(buf)))
        engine.invoke("main", mem_a)

    actual = open(out).read()
    if actual != expected():
      quit('FAILURE')
        actual = open(out).read()
        if actual != expected():
            quit("FAILURE")


def main():
  support_lib = os.getenv('SUPPORT_LIB')
  assert support_lib is not None, 'SUPPORT_LIB is undefined'
  if not os.path.exists(support_lib):
    raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT),
                            support_lib)
    support_lib = os.getenv("SUPPORT_LIB")
    assert support_lib is not None, "SUPPORT_LIB is undefined"
    if not os.path.exists(support_lib):
        raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT), support_lib)

  # CHECK-LABEL: TEST: test_output
  print('\nTEST: test_output')
  count = 0
  with ir.Context() as ctx, ir.Location.unknown():
    # Loop over various sparse types: CSR, DCSR, CSC, DCSC.
    levels = [[st.DimLevelType.dense, st.DimLevelType.compressed],
              [st.DimLevelType.compressed, st.DimLevelType.compressed]]
    orderings = [
        ir.AffineMap.get_permutation([0, 1]),
        ir.AffineMap.get_permutation([1, 0])
    ]
    bitwidths = [8, 16, 32, 64]
    compiler = sparse_compiler.SparseCompiler(
        options='', opt_level=2, shared_libs=[support_lib])
    for level in levels:
      for ordering in orderings:
        for bwidth in bitwidths:
          attr = st.EncodingAttr.get(level, ordering, None, bwidth, bwidth)
          build_compile_and_run_output(attr, compiler)
          count = count + 1
    # CHECK-LABEL: TEST: test_output
    print("\nTEST: test_output")
    count = 0
    with ir.Context() as ctx, ir.Location.unknown():
        # Loop over various sparse types: CSR, DCSR, CSC, DCSC.
        levels = [
            [st.DimLevelType.dense, st.DimLevelType.compressed],
            [st.DimLevelType.compressed, st.DimLevelType.compressed],
        ]
        orderings = [
            ir.AffineMap.get_permutation([0, 1]),
            ir.AffineMap.get_permutation([1, 0]),
        ]
        bitwidths = [8, 16, 32, 64]
        compiler = sparse_compiler.SparseCompiler(
            options="", opt_level=2, shared_libs=[support_lib]
        )
        for level in levels:
            for ordering in orderings:
                for bwidth in bitwidths:
                    attr = st.EncodingAttr.get(level, ordering, None, bwidth, bwidth)
                    build_compile_and_run_output(attr, compiler)
                    count = count + 1

  # CHECK: Passed 16 tests
  print('Passed', count, 'tests')
    # CHECK: Passed 16 tests
    print("Passed", count, "tests")


if __name__ == '__main__':
  main()
if __name__ == "__main__":
    main()

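The docstring's guarantee above — output sorted lexicographically by natural index order regardless of encoding — can be sketched with the very entries of the arith.constant in this test, which are written in scrambled order:

# COO entries as listed in the sparse constant, in their original order.
entries = [((0, 0), 1.0), ((1, 1), 2.0), ((0, 9), 3.0), ((9, 0), 4.0), ((4, 4), 5.0)]

# Whatever the dim ordering or compression, the written .tns output lists
# them sorted lexicographically by index tuple:
for index, value in sorted(entries):
    print(index, value)  # (0, 0), then (0, 9), (1, 1), (4, 4), (9, 0)
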
@@ -28,216 +28,241 @@ from tools import sparse_compiler
# TODO: move this boilerplate to its own module, so it can be used by
# other tests and programs.
class TypeConverter:
  """Converter between NumPy types and MLIR types."""
    """Converter between NumPy types and MLIR types."""

  def __init__(self, context: ir.Context):
    # Note 1: these are numpy "scalar types" (i.e., the values of
    # np.sctypeDict) not numpy "dtypes" (i.e., the np.dtype class).
    #
    # Note 2: we must construct the MLIR types in the same context as the
    # types that'll be passed to irtype_to_sctype() or irtype_to_dtype();
    # otherwise, those methods will raise a KeyError.
    types_list = [
        (np.float64, ir.F64Type.get(context=context)),
        (np.float32, ir.F32Type.get(context=context)),
        (np.int64, ir.IntegerType.get_signless(64, context=context)),
        (np.int32, ir.IntegerType.get_signless(32, context=context)),
        (np.int16, ir.IntegerType.get_signless(16, context=context)),
        (np.int8, ir.IntegerType.get_signless(8, context=context)),
    ]
    self._sc2ir = dict(types_list)
    self._ir2sc = dict(( (ir,sc) for sc,ir in types_list ))
    def __init__(self, context: ir.Context):
        # Note 1: these are numpy "scalar types" (i.e., the values of
        # np.sctypeDict) not numpy "dtypes" (i.e., the np.dtype class).
        #
        # Note 2: we must construct the MLIR types in the same context as the
        # types that'll be passed to irtype_to_sctype() or irtype_to_dtype();
        # otherwise, those methods will raise a KeyError.
        types_list = [
            (np.float64, ir.F64Type.get(context=context)),
            (np.float32, ir.F32Type.get(context=context)),
            (np.int64, ir.IntegerType.get_signless(64, context=context)),
            (np.int32, ir.IntegerType.get_signless(32, context=context)),
            (np.int16, ir.IntegerType.get_signless(16, context=context)),
            (np.int8, ir.IntegerType.get_signless(8, context=context)),
        ]
        self._sc2ir = dict(types_list)
        self._ir2sc = dict(((ir, sc) for sc, ir in types_list))

  def dtype_to_irtype(self, dtype: np.dtype) -> ir.Type:
    """Returns the MLIR equivalent of a NumPy dtype."""
    try:
      return self.sctype_to_irtype(dtype.type)
    except KeyError as e:
      raise KeyError(f'Unknown dtype: {dtype}') from e
    def dtype_to_irtype(self, dtype: np.dtype) -> ir.Type:
        """Returns the MLIR equivalent of a NumPy dtype."""
        try:
            return self.sctype_to_irtype(dtype.type)
        except KeyError as e:
            raise KeyError(f"Unknown dtype: {dtype}") from e

  def sctype_to_irtype(self, sctype) -> ir.Type:
    """Returns the MLIR equivalent of a NumPy scalar type."""
    if sctype in self._sc2ir:
      return self._sc2ir[sctype]
    else:
      raise KeyError(f'Unknown sctype: {sctype}')
    def sctype_to_irtype(self, sctype) -> ir.Type:
        """Returns the MLIR equivalent of a NumPy scalar type."""
        if sctype in self._sc2ir:
            return self._sc2ir[sctype]
        else:
            raise KeyError(f"Unknown sctype: {sctype}")

  def irtype_to_dtype(self, tp: ir.Type) -> np.dtype:
    """Returns the NumPy dtype equivalent of an MLIR type."""
    return np.dtype(self.irtype_to_sctype(tp))
    def irtype_to_dtype(self, tp: ir.Type) -> np.dtype:
        """Returns the NumPy dtype equivalent of an MLIR type."""
        return np.dtype(self.irtype_to_sctype(tp))

  def irtype_to_sctype(self, tp: ir.Type):
    """Returns the NumPy scalar-type equivalent of an MLIR type."""
    if tp in self._ir2sc:
      return self._ir2sc[tp]
    else:
      raise KeyError(f'Unknown ir.Type: {tp}')
    def irtype_to_sctype(self, tp: ir.Type):
        """Returns the NumPy scalar-type equivalent of an MLIR type."""
        if tp in self._ir2sc:
            return self._ir2sc[tp]
        else:
            raise KeyError(f"Unknown ir.Type: {tp}")

    def get_RankedTensorType_of_nparray(
        self, nparray: np.ndarray
    ) -> ir.RankedTensorType:
        """Returns the ir.RankedTensorType of a NumPy array. Note that NumPy
        arrays can only be converted to/from dense tensors, not sparse tensors."""
        # TODO: handle strides as well?
        return ir.RankedTensorType.get(
            nparray.shape, self.dtype_to_irtype(nparray.dtype)
        )

  def get_RankedTensorType_of_nparray(self, nparray: np.ndarray) -> ir.RankedTensorType:
    """Returns the ir.RankedTensorType of a NumPy array. Note that NumPy
    arrays can only be converted to/from dense tensors, not sparse tensors."""
    # TODO: handle strides as well?
    return ir.RankedTensorType.get(nparray.shape,
                                   self.dtype_to_irtype(nparray.dtype))

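A short usage sketch of the converter above, assuming an active ir.Context and the module's existing np/ir imports:

with ir.Context() as ctx:
    tyconv = TypeConverter(ctx)
    f64 = tyconv.dtype_to_irtype(np.dtype(np.float64))  # an ir.F64Type
    assert tyconv.irtype_to_dtype(f64) == np.dtype(np.float64)
    # NumPy arrays map to dense ranked tensor types only.
    tensor_ty = tyconv.get_RankedTensorType_of_nparray(np.zeros((2, 2)))
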
# ===----------------------------------------------------------------------=== #


class StressTest:
  def __init__(self, tyconv: TypeConverter):
    self._tyconv = tyconv
    self._roundtripTp = None
    self._module = None
    self._engine = None
    def __init__(self, tyconv: TypeConverter):
        self._tyconv = tyconv
        self._roundtripTp = None
        self._module = None
        self._engine = None

  def _assertEqualsRoundtripTp(self, tp: ir.RankedTensorType):
    assert self._roundtripTp is not None, \
        'StressTest: uninitialized roundtrip type'
    if tp != self._roundtripTp:
      raise AssertionError(
          f"Type is not equal to the roundtrip type.\n"
          f"\tExpected: {self._roundtripTp}\n"
          f"\tFound: {tp}\n")
    def _assertEqualsRoundtripTp(self, tp: ir.RankedTensorType):
        assert self._roundtripTp is not None, "StressTest: uninitialized roundtrip type"
        if tp != self._roundtripTp:
            raise AssertionError(
                f"Type is not equal to the roundtrip type.\n"
                f"\tExpected: {self._roundtripTp}\n"
                f"\tFound: {tp}\n"
            )

  def build(self, types: List[ir.Type]):
    """Builds the ir.Module. The module has only the @main function,
    which will convert the input through the list of types and then back
    to the initial type. The roundtrip type must be a dense tensor."""
    assert self._module is None, 'StressTest: must not call build() repeatedly'
    self._module = ir.Module.create()
    with ir.InsertionPoint(self._module.body):
      tp0 = types.pop(0)
      self._roundtripTp = tp0
      # TODO: assert dense? assert element type is recognised by the TypeConverter?
      types.append(tp0)
      funcTp = ir.FunctionType.get(inputs=[tp0], results=[tp0])
      funcOp = func.FuncOp(name='main', type=funcTp)
      funcOp.attributes['llvm.emit_c_interface'] = ir.UnitAttr.get()
      with ir.InsertionPoint(funcOp.add_entry_block()):
        arg0 = funcOp.entry_block.arguments[0]
        self._assertEqualsRoundtripTp(arg0.type)
        v = st.ConvertOp(types.pop(0), arg0)
        for tp in types:
          w = st.ConvertOp(tp, v)
          # Release intermediate tensors before they fall out of scope.
          bufferization.DeallocTensorOp(v.result)
          v = w
        self._assertEqualsRoundtripTp(v.result.type)
        func.ReturnOp(v)
    return self
    def build(self, types: List[ir.Type]):
        """Builds the ir.Module. The module has only the @main function,
        which will convert the input through the list of types and then back
        to the initial type. The roundtrip type must be a dense tensor."""
        assert self._module is None, "StressTest: must not call build() repeatedly"
        self._module = ir.Module.create()
        with ir.InsertionPoint(self._module.body):
            tp0 = types.pop(0)
            self._roundtripTp = tp0
            # TODO: assert dense? assert element type is recognised by the TypeConverter?
            types.append(tp0)
            funcTp = ir.FunctionType.get(inputs=[tp0], results=[tp0])
            funcOp = func.FuncOp(name="main", type=funcTp)
            funcOp.attributes["llvm.emit_c_interface"] = ir.UnitAttr.get()
            with ir.InsertionPoint(funcOp.add_entry_block()):
                arg0 = funcOp.entry_block.arguments[0]
                self._assertEqualsRoundtripTp(arg0.type)
                v = st.ConvertOp(types.pop(0), arg0)
                for tp in types:
                    w = st.ConvertOp(tp, v)
                    # Release intermediate tensors before they fall out of scope.
                    bufferization.DeallocTensorOp(v.result)
                    v = w
                self._assertEqualsRoundtripTp(v.result.type)
                func.ReturnOp(v)
        return self

  def writeTo(self, filename):
    """Write the ir.Module to the given file. If the file already exists,
    then raises an error. If the filename is None, then this is a no-op."""
    assert self._module is not None, \
        'StressTest: must call build() before writeTo()'
    if filename is None:
      # Silent no-op, for convenience.
      return self
    if os.path.exists(filename):
      raise FileExistsError(errno.EEXIST, os.strerror(errno.EEXIST), filename)
    with open(filename, 'w') as f:
      f.write(str(self._module))
    return self
    def writeTo(self, filename):
        """Write the ir.Module to the given file. If the file already exists,
        then raises an error. If the filename is None, then this is a no-op."""
        assert (
            self._module is not None
        ), "StressTest: must call build() before writeTo()"
        if filename is None:
            # Silent no-op, for convenience.
            return self
        if os.path.exists(filename):
            raise FileExistsError(errno.EEXIST, os.strerror(errno.EEXIST), filename)
        with open(filename, "w") as f:
            f.write(str(self._module))
        return self

  def compile(self, compiler):
    """Compile the ir.Module."""
    assert self._module is not None, \
        'StressTest: must call build() before compile()'
    assert self._engine is None, \
        'StressTest: must not call compile() repeatedly'
    self._engine = compiler.compile_and_jit(self._module)
    return self
    def compile(self, compiler):
        """Compile the ir.Module."""
        assert (
            self._module is not None
        ), "StressTest: must call build() before compile()"
        assert self._engine is None, "StressTest: must not call compile() repeatedly"
        self._engine = compiler.compile_and_jit(self._module)
        return self

    def run(self, np_arg0: np.ndarray) -> np.ndarray:
        """Runs the test on the given numpy array, and returns the resulting
        numpy array."""
        assert self._engine is not None, "StressTest: must call compile() before run()"
        self._assertEqualsRoundtripTp(
            self._tyconv.get_RankedTensorType_of_nparray(np_arg0)
        )
        np_out = np.zeros(np_arg0.shape, dtype=np_arg0.dtype)
        self._assertEqualsRoundtripTp(
            self._tyconv.get_RankedTensorType_of_nparray(np_out)
        )
        mem_arg0 = ctypes.pointer(
            ctypes.pointer(rt.get_ranked_memref_descriptor(np_arg0))
        )
        mem_out = ctypes.pointer(
            ctypes.pointer(rt.get_ranked_memref_descriptor(np_out))
        )
        self._engine.invoke("main", mem_out, mem_arg0)
        return rt.ranked_memref_to_numpy(mem_out[0])

  def run(self, np_arg0: np.ndarray) -> np.ndarray:
    """Runs the test on the given numpy array, and returns the resulting
    numpy array."""
    assert self._engine is not None, \
        'StressTest: must call compile() before run()'
    self._assertEqualsRoundtripTp(
        self._tyconv.get_RankedTensorType_of_nparray(np_arg0))
    np_out = np.zeros(np_arg0.shape, dtype=np_arg0.dtype)
    self._assertEqualsRoundtripTp(
        self._tyconv.get_RankedTensorType_of_nparray(np_out))
    mem_arg0 = ctypes.pointer(ctypes.pointer(rt.get_ranked_memref_descriptor(np_arg0)))
    mem_out = ctypes.pointer(ctypes.pointer(rt.get_ranked_memref_descriptor(np_out)))
    self._engine.invoke('main', mem_out, mem_arg0)
    return rt.ranked_memref_to_numpy(mem_out[0])

# ===----------------------------------------------------------------------=== #


def main():
  """
  USAGE: python3 test_stress.py [raw_module.mlir [compiled_module.mlir]]
    """
    USAGE: python3 test_stress.py [raw_module.mlir [compiled_module.mlir]]

  The environment variable SUPPORT_LIB must be set to point to the
  libmlir_c_runner_utils shared library. There are two optional
  arguments, for debugging purposes. The first argument specifies where
  to write out the raw/generated ir.Module. The second argument specifies
  where to write out the compiled version of that ir.Module.
  """
  support_lib = os.getenv('SUPPORT_LIB')
  assert support_lib is not None, 'SUPPORT_LIB is undefined'
  if not os.path.exists(support_lib):
    raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT), support_lib)
    The environment variable SUPPORT_LIB must be set to point to the
    libmlir_c_runner_utils shared library. There are two optional
    arguments, for debugging purposes. The first argument specifies where
    to write out the raw/generated ir.Module. The second argument specifies
    where to write out the compiled version of that ir.Module.
    """
    support_lib = os.getenv("SUPPORT_LIB")
    assert support_lib is not None, "SUPPORT_LIB is undefined"
    if not os.path.exists(support_lib):
        raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT), support_lib)

  # CHECK-LABEL: TEST: test_stress
  print("\nTEST: test_stress")
  with ir.Context() as ctx, ir.Location.unknown():
    # Disable direct sparse2sparse conversion, because it doubles the time!
    # TODO: While direct s2s is far too slow for per-commit testing,
    # we should have some framework ensure that we run this test with
    # `s2s=0` on a regular basis, to ensure that it does continue to work.
    # TODO: be sure to test s2s=0 together with singletons.
    s2s = 1
    sparsification_options = (
        f'parallelization-strategy=none '
        f's2s-strategy={s2s}')
    compiler = sparse_compiler.SparseCompiler(
        options=sparsification_options, opt_level=0, shared_libs=[support_lib])
    f64 = ir.F64Type.get()
    # Be careful about increasing this because
    # len(types) = 1 + len(level_choices)^rank * rank! * len(bitwidths)^2
    shape = range(2, 3)
    rank = len(shape)
    # All combinations.
    # TODO: add singleton here too; which requires updating how `np_arg0`
    # is initialized below.
    levels = list(itertools.product(*itertools.repeat(
        [st.DimLevelType.dense, st.DimLevelType.compressed], rank)))
    # All permutations.
    orderings = list(map(ir.AffineMap.get_permutation,
                         itertools.permutations(range(rank))))
    bitwidths = [0]
    # The first type must be a dense tensor for numpy conversion to work.
    types = [ir.RankedTensorType.get(shape, f64)]
    for level in levels:
      for ordering in orderings:
        for pwidth in bitwidths:
          for iwidth in bitwidths:
            attr = st.EncodingAttr.get(level, ordering, None, pwidth, iwidth)
            types.append(ir.RankedTensorType.get(shape, f64, attr))
    #
    # For exhaustiveness we should have one or more StressTest, such
    # that their paths cover all 2*n*(n-1) directed pairwise combinations
    # of the `types` set. However, since n is already superexponential,
    # such exhaustiveness would be prohibitive for a test that runs on
    # every commit. So for now we'll just pick one particular path that
    # at least hits all n elements of the `types` set.
    #
    tyconv = TypeConverter(ctx)
    size = 1
    for d in shape:
      size *= d
    np_arg0 = np.arange(size, dtype=tyconv.irtype_to_dtype(f64)).reshape(*shape)
    np_out = (
        StressTest(tyconv).build(types).writeTo(
            sys.argv[1] if len(sys.argv) > 1 else None).compile(compiler)
        .writeTo(sys.argv[2] if len(sys.argv) > 2 else None).run(np_arg0))
    # CHECK: Passed
    if np.allclose(np_out, np_arg0):
      print('Passed')
    else:
      sys.exit('FAILURE')
    # CHECK-LABEL: TEST: test_stress
    print("\nTEST: test_stress")
    with ir.Context() as ctx, ir.Location.unknown():
        # Disable direct sparse2sparse conversion, because it doubles the time!
        # TODO: While direct s2s is far too slow for per-commit testing,
        # we should have some framework ensure that we run this test with
        # `s2s=0` on a regular basis, to ensure that it does continue to work.
        # TODO: be sure to test s2s=0 together with singletons.
        s2s = 1
        sparsification_options = f"parallelization-strategy=none " f"s2s-strategy={s2s}"
        compiler = sparse_compiler.SparseCompiler(
            options=sparsification_options, opt_level=0, shared_libs=[support_lib]
        )
        f64 = ir.F64Type.get()
        # Be careful about increasing this because
        # len(types) = 1 + len(level_choices)^rank * rank! * len(bitwidths)^2
        shape = range(2, 3)
        rank = len(shape)
        # All combinations.
        # TODO: add singleton here too; which requires updating how `np_arg0`
        # is initialized below.
        levels = list(
            itertools.product(
                *itertools.repeat(
                    [st.DimLevelType.dense, st.DimLevelType.compressed], rank
                )
            )
        )
        # All permutations.
        orderings = list(
            map(ir.AffineMap.get_permutation, itertools.permutations(range(rank)))
        )
        bitwidths = [0]
        # The first type must be a dense tensor for numpy conversion to work.
        types = [ir.RankedTensorType.get(shape, f64)]
        for level in levels:
            for ordering in orderings:
                for pwidth in bitwidths:
                    for iwidth in bitwidths:
                        attr = st.EncodingAttr.get(
                            level, ordering, None, pwidth, iwidth
                        )
                        types.append(ir.RankedTensorType.get(shape, f64, attr))
        #
        # For exhaustiveness we should have one or more StressTest, such
        # that their paths cover all 2*n*(n-1) directed pairwise combinations
        # of the `types` set. However, since n is already superexponential,
        # such exhaustiveness would be prohibitive for a test that runs on
        # every commit. So for now we'll just pick one particular path that
        # at least hits all n elements of the `types` set.
        #
        tyconv = TypeConverter(ctx)
        size = 1
        for d in shape:
            size *= d
        np_arg0 = np.arange(size, dtype=tyconv.irtype_to_dtype(f64)).reshape(*shape)
        np_out = (
            StressTest(tyconv)
            .build(types)
            .writeTo(sys.argv[1] if len(sys.argv) > 1 else None)
            .compile(compiler)
            .writeTo(sys.argv[2] if len(sys.argv) > 2 else None)
            .run(np_arg0)
        )
        # CHECK: Passed
        if np.allclose(np_out, np_arg0):
            print("Passed")
        else:
            sys.exit("FAILURE")

if __name__ == '__main__':
  main()

if __name__ == "__main__":
    main()

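The size formula in the comment above can be checked by hand for this configuration: shape = range(2, 3) has one dimension, so rank = 1, there are 2 level choices per dimension, and bitwidths = [0], giving len(types) = 1 + 2^1 * 1! * 1^2 = 3. A standalone sketch of the count:

import math

rank = len(range(2, 3))  # 1
n_level_choices = 2      # dense or compressed per dimension
n_bitwidths = 1          # bitwidths = [0]

n_types = 1 + n_level_choices**rank * math.factorial(rank) * n_bitwidths**2
print(n_types)  # 3: the dense roundtrip type plus two annotated variants
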
@@ -11,65 +11,71 @@ import numpy as np

@functools.lru_cache()
def _get_c_shared_lib(lib_name: str):
  """Loads and returns the requested C shared library.
    """Loads and returns the requested C shared library.

  Args:
    lib_name: A string representing the C shared library.
    Args:
      lib_name: A string representing the C shared library.

  Returns:
    The C shared library.
    Returns:
      The C shared library.

  Raises:
    OSError: If there is any problem in loading the shared library.
    ValueError: If the shared library doesn't contain the needed routine.
  """
  # This raises OSError exception if there is any problem in loading the shared
  # library.
  c_lib = ctypes.CDLL(lib_name)
    Raises:
      OSError: If there is any problem in loading the shared library.
      ValueError: If the shared library doesn't contain the needed routine.
    """
    # This raises OSError exception if there is any problem in loading the shared
    # library.
    c_lib = ctypes.CDLL(lib_name)

  try:
    c_lib.convertFromMLIRSparseTensorF64.restype = ctypes.c_void_p
  except Exception as e:
    raise ValueError('Missing function convertFromMLIRSparseTensorF64 from '
                     f'the C shared library: {e} ') from e
    try:
        c_lib.convertFromMLIRSparseTensorF64.restype = ctypes.c_void_p
    except Exception as e:
        raise ValueError(
            "Missing function convertFromMLIRSparseTensorF64 from "
            f"the C shared library: {e} "
        ) from e

  return c_lib
    return c_lib


def sparse_tensor_to_coo_tensor(support_lib, sparse, dtype):
  """Converts a sparse tensor to COO-flavored format.
    """Converts a sparse tensor to COO-flavored format.

  Args:
    support_lib: A string for the supporting C shared library.
    sparse: A ctypes.pointer to the sparse tensor descriptor.
    dtype: The numpy data type for the tensor elements.
    Args:
      support_lib: A string for the supporting C shared library.
      sparse: A ctypes.pointer to the sparse tensor descriptor.
      dtype: The numpy data type for the tensor elements.

  Returns:
    A tuple that contains the following values:
    rank: An integer for the rank of the tensor.
    nse: An integer for the number of non-zero values in the tensor.
    shape: A 1D numpy array of integers, for the shape of the tensor.
    values: A 1D numpy array, for the non-zero values in the tensor.
    indices: A 2D numpy array of integers, representing the indices for the
      non-zero values in the tensor.
    Returns:
      A tuple that contains the following values:
      rank: An integer for the rank of the tensor.
      nse: An integer for the number of non-zero values in the tensor.
      shape: A 1D numpy array of integers, for the shape of the tensor.
      values: A 1D numpy array, for the non-zero values in the tensor.
      indices: A 2D numpy array of integers, representing the indices for the
        non-zero values in the tensor.

  Raises:
    OSError: If there is any problem in loading the shared library.
    ValueError: If the shared library doesn't contain the needed routine.
  """
  c_lib = _get_c_shared_lib(support_lib)
    Raises:
      OSError: If there is any problem in loading the shared library.
      ValueError: If the shared library doesn't contain the needed routine.
    """
    c_lib = _get_c_shared_lib(support_lib)

  rank = ctypes.c_ulonglong(0)
  nse = ctypes.c_ulonglong(0)
  shape = ctypes.POINTER(ctypes.c_ulonglong)()
  values = ctypes.POINTER(np.ctypeslib.as_ctypes_type(dtype))()
  indices = ctypes.POINTER(ctypes.c_ulonglong)()
  c_lib.convertFromMLIRSparseTensorF64(sparse, ctypes.byref(rank),
                                       ctypes.byref(nse), ctypes.byref(shape),
                                       ctypes.byref(values),
                                       ctypes.byref(indices))
  # Convert the returned values to the corresponding numpy types.
  shape = np.ctypeslib.as_array(shape, shape=[rank.value])
  values = np.ctypeslib.as_array(values, shape=[nse.value])
  indices = np.ctypeslib.as_array(indices, shape=[nse.value, rank.value])
  return rank, nse, shape, values, indices
    rank = ctypes.c_ulonglong(0)
    nse = ctypes.c_ulonglong(0)
    shape = ctypes.POINTER(ctypes.c_ulonglong)()
    values = ctypes.POINTER(np.ctypeslib.as_ctypes_type(dtype))()
    indices = ctypes.POINTER(ctypes.c_ulonglong)()
    c_lib.convertFromMLIRSparseTensorF64(
        sparse,
        ctypes.byref(rank),
        ctypes.byref(nse),
        ctypes.byref(shape),
        ctypes.byref(values),
        ctypes.byref(indices),
    )
    # Convert the returned values to the corresponding numpy types.
    shape = np.ctypeslib.as_array(shape, shape=[rank.value])
    values = np.ctypeslib.as_array(values, shape=[nse.value])
    indices = np.ctypeslib.as_array(indices, shape=[nse.value, rank.value])
    return rank, nse, shape, values, indices

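The COO tuple returned above is straightforward to consume from NumPy; for example, a dense reconstruction sketch (hypothetical values, assuming rank and nse have already been unwrapped from their ctypes wrappers):

import numpy as np

# Hypothetical COO output for a 3x4 tensor with three nonzeros.
shape = np.array([3, 4])
values = np.array([2.2, 2.8, 6.6])
indices = np.array([[0, 0], [0, 3], [2, 2]])

# Scatter the nonzero values back into a dense array.
dense = np.zeros(shape, np.float64)
dense[tuple(indices.T)] = values
print(dense)
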
@@ -9,30 +9,31 @@ from mlir import ir
from mlir import passmanager
from typing import Sequence


class SparseCompiler:
  """Sparse compiler class for compiling and building MLIR modules."""
    """Sparse compiler class for compiling and building MLIR modules."""

  def __init__(self, options: str, opt_level: int, shared_libs: Sequence[str]):
    pipeline = f'builtin.module(sparse-compiler{{{options} reassociate-fp-reductions=1 enable-index-optimizations=1}})'
    self.pipeline = pipeline
    self.opt_level = opt_level
    self.shared_libs = shared_libs
    def __init__(self, options: str, opt_level: int, shared_libs: Sequence[str]):
        pipeline = f"builtin.module(sparse-compiler{{{options} reassociate-fp-reductions=1 enable-index-optimizations=1}})"
        self.pipeline = pipeline
        self.opt_level = opt_level
        self.shared_libs = shared_libs

  def __call__(self, module: ir.Module):
    """Convenience application method."""
    self.compile(module)
    def __call__(self, module: ir.Module):
        """Convenience application method."""
        self.compile(module)

  def compile(self, module: ir.Module):
    """Compiles the module by invoking the sparse compiler pipeline."""
    passmanager.PassManager.parse(self.pipeline).run(module.operation)
    def compile(self, module: ir.Module):
        """Compiles the module by invoking the sparse compiler pipeline."""
        passmanager.PassManager.parse(self.pipeline).run(module.operation)

  def jit(self, module: ir.Module) -> execution_engine.ExecutionEngine:
    """Wraps the module in a JIT execution engine."""
    return execution_engine.ExecutionEngine(
        module, opt_level=self.opt_level, shared_libs=self.shared_libs)
    def jit(self, module: ir.Module) -> execution_engine.ExecutionEngine:
        """Wraps the module in a JIT execution engine."""
        return execution_engine.ExecutionEngine(
            module, opt_level=self.opt_level, shared_libs=self.shared_libs
        )

  def compile_and_jit(self,
                      module: ir.Module) -> execution_engine.ExecutionEngine:
    """Compiles and jits the module."""
    self.compile(module)
    return self.jit(module)
    def compile_and_jit(self, module: ir.Module) -> execution_engine.ExecutionEngine:
        """Compiles and jits the module."""
        self.compile(module)
        return self.jit(module)

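Putting the class together, the call sequence used throughout these tests looks like the following sketch; support_lib and kernel_source are placeholders for the SUPPORT_LIB path and the MLIR text being tested:

compiler = sparse_compiler.SparseCompiler(
    options="parallelization-strategy=none", opt_level=2, shared_libs=[support_lib]
)
module = ir.Module.parse(kernel_source)   # kernel_source: placeholder MLIR text
engine = compiler.compile_and_jit(module)
engine.invoke("main", *ctypes_args)       # ctypes_args: placeholder arguments
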
@@ -1,5 +1,5 @@
|
||||
# Disable ASAN's leak detection for python taco tests.
|
||||
config.environment['ASAN_OPTIONS'] = 'detect_leaks=0'
|
||||
config.environment["ASAN_OPTIONS"] = "detect_leaks=0"
|
||||
# Only run when python bindings are enabled.
|
||||
if not config.enable_bindings_python:
|
||||
config.unsupported = True
|
||||
config.unsupported = True
|
||||
|
||||
@@ -46,10 +46,10 @@ A[i, j] = B[i, k, l] * D[l, j] * C[k, j]

# Perform the MTTKRP computation and write the result to file.
with tempfile.TemporaryDirectory() as test_dir:
golden_file = os.path.join(_SCRIPT_PATH, "data/gold_A.tns")
out_file = os.path.join(test_dir, "A.tns")
pt.write(out_file, A)
#
# CHECK: Compare result True
#
print(f"Compare result {utils.compare_sparse_tns(golden_file, out_file)}")
golden_file = os.path.join(_SCRIPT_PATH, "data/gold_A.tns")
out_file = os.path.join(test_dir, "A.tns")
pt.write(out_file, A)
#
# CHECK: Compare result True
#
print(f"Compare result {utils.compare_sparse_tns(golden_file, out_file)}")
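Editor's note: this write-then-compare shape repeats across the taco tests below. Reduced to its skeleton it is just a scratch directory plus a golden file; the helper name and the byte-for-byte comparison here are illustrative, not part of the tests:

import os
import tempfile

def check_golden(write_fn, golden_file: str, name: str) -> bool:
    # Write into a scratch directory, then compare against the checked-in file.
    with tempfile.TemporaryDirectory() as test_dir:
        out_file = os.path.join(test_dir, name)
        write_fn(out_file)
        with open(out_file) as out, open(golden_file) as gold:
            return out.read() == gold.read()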
@@ -46,13 +46,13 @@ expected = """; extended FROSTT format

# Force evaluation of the kernels by writing out X and Y.
with tempfile.TemporaryDirectory() as test_dir:
x_file = os.path.join(test_dir, "X.tns")
y_file = os.path.join(test_dir, "Y.tns")
pt.write(x_file, X)
pt.write(y_file, Y)
#
# CHECK: Compare result True True
#
x_data = utils.file_as_string(x_file)
y_data = utils.file_as_string(y_file)
print(f"Compare result {x_data == expected} {y_data == expected}")
x_file = os.path.join(test_dir, "X.tns")
y_file = os.path.join(test_dir, "Y.tns")
pt.write(x_file, X)
pt.write(y_file, Y)
#
# CHECK: Compare result True True
#
x_data = utils.file_as_string(x_file)
y_data = utils.file_as_string(y_file)
print(f"Compare result {x_data == expected} {y_data == expected}")
@@ -26,10 +26,10 @@ C[i, j] = A[i, k] * B[k, j]

# Force evaluation of the kernel by writing out C.
with tempfile.TemporaryDirectory() as test_dir:
golden_file = os.path.join(_SCRIPT_PATH, "data/gold_C.tns")
out_file = os.path.join(test_dir, "C.tns")
pt.write(out_file, C)
#
# CHECK: Compare result True
#
print(f"Compare result {utils.compare_sparse_tns(golden_file, out_file)}")
golden_file = os.path.join(_SCRIPT_PATH, "data/gold_C.tns")
out_file = os.path.join(test_dir, "C.tns")
pt.write(out_file, C)
#
# CHECK: Compare result True
#
print(f"Compare result {utils.compare_sparse_tns(golden_file, out_file)}")
@@ -47,10 +47,10 @@ y[i] = A[i, j] * x[j] + z[i]

# Perform the SpMV computation and write the result to file.
with tempfile.TemporaryDirectory() as test_dir:
golden_file = os.path.join(_SCRIPT_PATH, "data/gold_y.tns")
out_file = os.path.join(test_dir, "y.tns")
pt.write(out_file, y)
#
# CHECK: Compare result True
#
print(f"Compare result {utils.compare_sparse_tns(golden_file, out_file)}")
golden_file = os.path.join(_SCRIPT_PATH, "data/gold_y.tns")
out_file = os.path.join(test_dir, "y.tns")
pt.write(out_file, y)
#
# CHECK: Compare result True
#
print(f"Compare result {utils.compare_sparse_tns(golden_file, out_file)}")
@@ -18,11 +18,10 @@ i, j, k, l, m = pt.get_index_vars(5)
alpha = pt.tensor(42.0)

# Set up some sparse tensors with different dim annotations and ordering.
S = pt.tensor([8, 8, 8],
pt.format([pt.compressed, pt.dense, pt.compressed], [1, 0, 2]))
X = pt.tensor([8, 8, 8],
pt.format([pt.compressed, pt.compressed, pt.compressed],
[1, 0, 2]))
S = pt.tensor([8, 8, 8], pt.format([pt.compressed, pt.dense, pt.compressed], [1, 0, 2]))
X = pt.tensor(
[8, 8, 8], pt.format([pt.compressed, pt.compressed, pt.compressed], [1, 0, 2])
)
S.insert([0, 0, 0], 2.0)
S.insert([1, 1, 1], 3.0)
S.insert([4, 4, 4], 4.0)

@@ -32,16 +31,14 @@ X[i, j, k] = alpha[0] * S[i, j, k]

# Set up tensors with a dense last dimension. This results in a full
# enveloping storage of all last "rows" with one or more nonzeros.
T = pt.tensor([1, 2, 3, 4, 5],
pt.format([
pt.compressed, pt.compressed, pt.compressed, pt.compressed,
pt.dense
]))
Y = pt.tensor([1, 2, 3, 4, 5],
pt.format([
pt.compressed, pt.compressed, pt.compressed, pt.compressed,
pt.dense
]))
T = pt.tensor(
[1, 2, 3, 4, 5],
pt.format([pt.compressed, pt.compressed, pt.compressed, pt.compressed, pt.dense]),
)
Y = pt.tensor(
[1, 2, 3, 4, 5],
pt.format([pt.compressed, pt.compressed, pt.compressed, pt.compressed, pt.dense]),
)
T.insert([0, 1, 2, 3, 4], -2.0)

Y[i, j, k, l, m] = alpha[0] * T[i, j, k, l, m]
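Editor's note: a compact illustration of what the format annotations above mean, sketched against the mlir_pytaco_api surface used by these tests (sizes and values are arbitrary):

from tools import mlir_pytaco_api as pt

# [1, 0, 2] is the storage ordering: dimension 1 is stored outermost, so the
# pt.dense annotation applies to the level where dimension 0 ends up.
fmt = pt.format([pt.compressed, pt.dense, pt.compressed], [1, 0, 2])
S = pt.tensor([8, 8, 8], fmt)
S.insert([0, 0, 0], 2.0)  # entries can be inserted before evaluation forces storage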
@@ -85,18 +82,18 @@ z_expected = """; extended FROSTT format

# Force evaluation of the kernel by writing out X.
with tempfile.TemporaryDirectory() as test_dir:
x_file = os.path.join(test_dir, 'X.tns')
pt.write(x_file, X)
y_file = os.path.join(test_dir, 'Y.tns')
pt.write(y_file, Y)
z_file = os.path.join(test_dir, 'Z.tns')
pt.write(z_file, Z)
#
# CHECK: Compare result True True True
#
x_data = utils.file_as_string(x_file)
y_data = utils.file_as_string(y_file)
z_data = utils.file_as_string(z_file)
print(
f'Compare result {x_data == x_expected} {y_data == y_expected} {z_data == z_expected}'
)
x_file = os.path.join(test_dir, "X.tns")
pt.write(x_file, X)
y_file = os.path.join(test_dir, "Y.tns")
pt.write(y_file, Y)
z_file = os.path.join(test_dir, "Z.tns")
pt.write(z_file, Z)
#
# CHECK: Compare result True True True
#
x_data = utils.file_as_string(x_file)
y_data = utils.file_as_string(y_file)
z_data = utils.file_as_string(z_file)
print(
f"Compare result {x_data == x_expected} {y_data == y_expected} {z_data == z_expected}"
)
@@ -12,7 +12,7 @@ compressed = pt.compressed

i, j = pt.get_index_vars(2)
A = pt.tensor([2, 3])
S = pt.tensor(3) # S is a scalar tensor.
S = pt.tensor(3)  # S is a scalar tensor.
B = pt.tensor([2, 3], compressed)
A.insert([0, 1], 10)
A.insert([1, 2], 40)

@@ -26,11 +26,11 @@ passed += np.array_equal(values, [30.0, 120.0])

# Sum all the values in A.
S[0] = A[i, j]
passed += (S.get_scalar_value() == 50.0)
passed += S.get_scalar_value() == 50.0

indices, values = S.get_coordinates_and_values()
passed += (len(indices)==0)
passed += (values == 50.0)
passed += len(indices) == 0
passed += values == 50.0

# CHECK: Number of passed: 5
print("Number of passed:", passed)
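Editor's note: the reduction semantics exercised here are worth spelling out. An index variable that appears on the right-hand side but not on the left is summed over; a plain numpy rendering of the `S[0] = A[i, j]` line, using the values inserted above (illustrative only):

import numpy as np

A = np.zeros((2, 3))
A[0, 1] = 10
A[1, 2] = 40
s = A.sum()  # both i and j are reduced away, leaving the scalar
assert s == 50.0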
@@ -12,20 +12,20 @@ compressed = pt.compressed
passed = 0
all_types = [pt.complex64, pt.complex128]
for t in all_types:
i, j = pt.get_index_vars(2)
A = pt.tensor([2, 3], dtype=t)
B = pt.tensor([2, 3], dtype=t)
C = pt.tensor([2, 3], compressed, dtype=t)
A.insert([0, 1], 10 + 20j)
A.insert([1, 2], 40 + 0.5j)
B.insert([0, 0], 20)
B.insert([1, 2], 30 + 15j)
C[i, j] = A[i, j] + B[i, j]
i, j = pt.get_index_vars(2)
A = pt.tensor([2, 3], dtype=t)
B = pt.tensor([2, 3], dtype=t)
C = pt.tensor([2, 3], compressed, dtype=t)
A.insert([0, 1], 10 + 20j)
A.insert([1, 2], 40 + 0.5j)
B.insert([0, 0], 20)
B.insert([1, 2], 30 + 15j)
C[i, j] = A[i, j] + B[i, j]

indices, values = C.get_coordinates_and_values()
passed += isinstance(values[0], t.value)
passed += np.array_equal(indices, [[0, 0], [0, 1], [1, 2]])
passed += np.allclose(values, [20, 10 + 20j, 70 + 15.5j])
indices, values = C.get_coordinates_and_values()
passed += isinstance(values[0], t.value)
passed += np.array_equal(indices, [[0, 0], [0, 1], [1, 2]])
passed += np.allclose(values, [20, 10 + 20j, 70 + 15.5j])

# CHECK: Number of passed: 6
print("Number of passed:", passed)
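Editor's note: these tests count successes by adding booleans to an integer, which works because bool is a subclass of int. Black merely drops the redundant parentheses; a quick standalone check:

passed = 0
passed += 1 == 1    # += binds looser than ==, so no parentheses are needed
passed += (2 == 2)  # identical meaning with the old-style parentheses
print("Number of passed:", passed)  # -> Number of passed: 2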
@@ -12,24 +12,22 @@ compressed = pt.compressed
dense = pt.dense

passed = 0
all_types = [
pt.int8, pt.int16, pt.int32, pt.int64, pt.float16, pt.float32, pt.float64
]
all_types = [pt.int8, pt.int16, pt.int32, pt.int64, pt.float16, pt.float32, pt.float64]
for t in all_types:
i, j = pt.get_index_vars(2)
A = pt.tensor([2, 3], dtype=t)
B = pt.tensor([2, 3], dtype=t)
C = pt.tensor([2, 3], compressed, dtype=t)
A.insert([0, 1], 10)
A.insert([1, 2], 40)
B.insert([0, 0], 20)
B.insert([1, 2], 30)
C[i, j] = A[i, j] + B[i, j]
i, j = pt.get_index_vars(2)
A = pt.tensor([2, 3], dtype=t)
B = pt.tensor([2, 3], dtype=t)
C = pt.tensor([2, 3], compressed, dtype=t)
A.insert([0, 1], 10)
A.insert([1, 2], 40)
B.insert([0, 0], 20)
B.insert([1, 2], 30)
C[i, j] = A[i, j] + B[i, j]

indices, values = C.get_coordinates_and_values()
passed += isinstance(values[0], t.value)
passed += np.array_equal(indices, [[0, 0], [0, 1], [1, 2]])
passed += np.allclose(values, [20.0, 10.0, 70.0])
indices, values = C.get_coordinates_and_values()
passed += isinstance(values[0], t.value)
passed += np.array_equal(indices, [[0, 0], [0, 1], [1, 2]])
passed += np.allclose(values, [20.0, 10.0, 70.0])

# CHECK: Number of passed: 21
print("Number of passed:", passed)
@@ -10,8 +10,8 @@ from tools import mlir_pytaco_api as pt

i, j = pt.get_index_vars(2)
# Both tensors are true dense tensors.
A = pt.from_array(np.full([2,3], 1, dtype=np.float64))
B = pt.from_array(np.full([2,3], 2, dtype=np.float64))
A = pt.from_array(np.full([2, 3], 1, dtype=np.float64))
B = pt.from_array(np.full([2, 3], 2, dtype=np.float64))
# Define the result tensor as a true dense tensor. The parameter is_dense=True
# is an MLIR-PyTACO extension.
C = pt.tensor([2, 3], dtype=pt.float64, is_dense=True)
File diff suppressed because it is too large
@@ -31,50 +31,52 @@ _MTX_FILENAME_SUFFIX = ".mtx"
_TNS_FILENAME_SUFFIX = ".tns"


def read(filename: str, fmt: Format,
dtype: DType = DType(Type.FLOAT32)) -> Tensor:
"""Inputs a tensor from a given file.
def read(filename: str, fmt: Format, dtype: DType = DType(Type.FLOAT32)) -> Tensor:
"""Inputs a tensor from a given file.

The name suffix of the file specifies the format of the input tensor. We
currently only support the .mtx format for sparse tensors.
The name suffix of the file specifies the format of the input tensor. We
currently only support the .mtx format for sparse tensors.

Args:
filename: A string input filename.
fmt: The storage format of the tensor.
dtype: The data type, defaults to float32.
Args:
filename: A string input filename.
fmt: The storage format of the tensor.
dtype: The data type, defaults to float32.

Raises:
ValueError: If filename doesn't end with .mtx or .tns, or fmt is not an
instance of Format or fmt is not a sparse tensor.
"""
if (not isinstance(filename, str) or
(not filename.endswith(_MTX_FILENAME_SUFFIX) and
not filename.endswith(_TNS_FILENAME_SUFFIX))):
raise ValueError("Expected string filename ends with "
f"{_MTX_FILENAME_SUFFIX} or {_TNS_FILENAME_SUFFIX}: "
f"{filename}.")
Raises:
ValueError: If filename doesn't end with .mtx or .tns, or fmt is not an
instance of Format or fmt is not a sparse tensor.
"""
if not isinstance(filename, str) or (
not filename.endswith(_MTX_FILENAME_SUFFIX)
and not filename.endswith(_TNS_FILENAME_SUFFIX)
):
raise ValueError(
"Expected string filename ends with "
f"{_MTX_FILENAME_SUFFIX} or {_TNS_FILENAME_SUFFIX}: "
f"{filename}."
)

return Tensor.from_file(filename, fmt, dtype)
return Tensor.from_file(filename, fmt, dtype)


def write(filename: str, tensor: Tensor) -> None:
"""Outputs a tensor to a given file.
"""Outputs a tensor to a given file.

The name suffix of the file specifies the format of the output. We currently
only support .tns format.
The name suffix of the file specifies the format of the output. We currently
only support .tns format.

Args:
filename: A string output filename.
tensor: The tensor to output.
Args:
filename: A string output filename.
tensor: The tensor to output.

Raises:
ValueError: If filename doesn't end with .tns or tensor is not a Tensor.
"""
if (not isinstance(filename, str) or
not filename.endswith(_TNS_FILENAME_SUFFIX)):
raise ValueError("Expected string filename ends with"
f" {_TNS_FILENAME_SUFFIX}: {filename}.")
if not isinstance(tensor, Tensor):
raise ValueError(f"Expected a Tensor object: {tensor}.")
Raises:
ValueError: If filename doesn't end with .tns or tensor is not a Tensor.
"""
if not isinstance(filename, str) or not filename.endswith(_TNS_FILENAME_SUFFIX):
raise ValueError(
"Expected string filename ends with" f" {_TNS_FILENAME_SUFFIX}: {filename}."
)
if not isinstance(tensor, Tensor):
raise ValueError(f"Expected a Tensor object: {tensor}.")

tensor.to_file(filename)
tensor.to_file(filename)
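Editor's note: a hedged usage sketch of the two entry points above; the import path mirrors the tests elsewhere in this commit, and the file names are placeholders:

from tools import mlir_pytaco_api as pt  # assumed to re-export read/write

csr = pt.format([pt.dense, pt.compressed])
A = pt.read("data/A.mtx", csr)  # the suffix picks the reader; only .mtx/.tns pass validation
pt.write("A_out.tns", A)        # only .tns is accepted; anything else raises ValueError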
@@ -36,190 +36,234 @@ _ENTRY_NAME = "main"

@functools.lru_cache()
def _get_support_lib_name() -> str:
"""Gets the string name for the supporting C shared library."""
return os.getenv(_SUPPORTLIB_ENV_VAR, _DEFAULT_SUPPORTLIB)
"""Gets the string name for the supporting C shared library."""
return os.getenv(_SUPPORTLIB_ENV_VAR, _DEFAULT_SUPPORTLIB)


@functools.lru_cache()
def _get_sparse_compiler() -> mlir_sparse_compiler.SparseCompiler:
"""Gets the MLIR sparse compiler with default setting."""
return mlir_sparse_compiler.SparseCompiler(
options="", opt_level=_OPT_LEVEL, shared_libs=[_get_support_lib_name()])
"""Gets the MLIR sparse compiler with default setting."""
return mlir_sparse_compiler.SparseCompiler(
options="", opt_level=_OPT_LEVEL, shared_libs=[_get_support_lib_name()]
)


def _record_support_funcs(
ty: np.dtype, to_func: _SupportFunc, from_func: _SupportFunc,
ty_to_funcs: Dict[np.dtype, Tuple[_SupportFunc, _SupportFunc]]) -> None:
"""Records the two supporting functions for a given data type."""
to_func.restype = ctypes.c_void_p
from_func.restype = ctypes.c_void_p
ty_to_funcs[ty] = (to_func, from_func)
ty: np.dtype,
to_func: _SupportFunc,
from_func: _SupportFunc,
ty_to_funcs: Dict[np.dtype, Tuple[_SupportFunc, _SupportFunc]],
) -> None:
"""Records the two supporting functions for a given data type."""
to_func.restype = ctypes.c_void_p
from_func.restype = ctypes.c_void_p
ty_to_funcs[ty] = (to_func, from_func)


@functools.lru_cache()
def _get_support_func_locator() -> _SupportFuncLocator:
"""Constructs a function to locate the supporting functions for a data type.
"""Constructs a function to locate the supporting functions for a data type.

Loads the supporting C shared library with the needed routines. Constructs a
dictionary from the supported data types to the routines for the data types,
and then a function to look up the dictionary for a given data type.
Loads the supporting C shared library with the needed routines. Constructs a
dictionary from the supported data types to the routines for the data types,
and then a function to look up the dictionary for a given data type.

The name of the supporting C shared library is either provided by an
environment variable or a default value.
The name of the supporting C shared library is either provided by an
environment variable or a default value.

Returns:
The function to look up the supporting functions for a given data type.
Returns:
The function to look up the supporting functions for a given data type.

Raises:
OSError: If there is any problem in loading the shared library.
ValueError: If the shared library doesn't contain the needed routines.
"""
# This raises OSError exception if there is any problem in loading the shared
# library.
c_lib = ctypes.CDLL(_get_support_lib_name())
Raises:
OSError: If there is any problem in loading the shared library.
ValueError: If the shared library doesn't contain the needed routines.
"""
# This raises OSError exception if there is any problem in loading the shared
# library.
c_lib = ctypes.CDLL(_get_support_lib_name())

type_to_funcs = {}
try:
support_types = [(np.int8, c_lib.convertToMLIRSparseTensorI8,
c_lib.convertFromMLIRSparseTensorI8),
(np.int16, c_lib.convertToMLIRSparseTensorI16,
c_lib.convertFromMLIRSparseTensorI16),
(np.int32, c_lib.convertToMLIRSparseTensorI32,
c_lib.convertFromMLIRSparseTensorI32),
(np.int64, c_lib.convertToMLIRSparseTensorI64,
c_lib.convertFromMLIRSparseTensorI64),
(np.float16, c_lib.convertToMLIRSparseTensorF16,
c_lib.convertFromMLIRSparseTensorF16),
(np.float32, c_lib.convertToMLIRSparseTensorF32,
c_lib.convertFromMLIRSparseTensorF32),
(np.float64, c_lib.convertToMLIRSparseTensorF64,
c_lib.convertFromMLIRSparseTensorF64),
(np.complex64, c_lib.convertToMLIRSparseTensorC32,
c_lib.convertFromMLIRSparseTensorC32),
(np.complex128, c_lib.convertToMLIRSparseTensorC64,
c_lib.convertFromMLIRSparseTensorC64)]
except Exception as e:
raise ValueError(f"Missing supporting function: {e}") from e
for i, info in enumerate(support_types):
_record_support_funcs(info[0], info[1], info[2], type_to_funcs)
type_to_funcs = {}
try:
support_types = [
(
np.int8,
c_lib.convertToMLIRSparseTensorI8,
c_lib.convertFromMLIRSparseTensorI8,
),
(
np.int16,
c_lib.convertToMLIRSparseTensorI16,
c_lib.convertFromMLIRSparseTensorI16,
),
(
np.int32,
c_lib.convertToMLIRSparseTensorI32,
c_lib.convertFromMLIRSparseTensorI32,
),
(
np.int64,
c_lib.convertToMLIRSparseTensorI64,
c_lib.convertFromMLIRSparseTensorI64,
),
(
np.float16,
c_lib.convertToMLIRSparseTensorF16,
c_lib.convertFromMLIRSparseTensorF16,
),
(
np.float32,
c_lib.convertToMLIRSparseTensorF32,
c_lib.convertFromMLIRSparseTensorF32,
),
(
np.float64,
c_lib.convertToMLIRSparseTensorF64,
c_lib.convertFromMLIRSparseTensorF64,
),
(
np.complex64,
c_lib.convertToMLIRSparseTensorC32,
c_lib.convertFromMLIRSparseTensorC32,
),
(
np.complex128,
c_lib.convertToMLIRSparseTensorC64,
c_lib.convertFromMLIRSparseTensorC64,
),
]
except Exception as e:
raise ValueError(f"Missing supporting function: {e}") from e
for i, info in enumerate(support_types):
_record_support_funcs(info[0], info[1], info[2], type_to_funcs)

def get_support_funcs(ty: np.dtype):
funcs = type_to_funcs[ty]
assert funcs is not None
return funcs
def get_support_funcs(ty: np.dtype):
funcs = type_to_funcs[ty]
assert funcs is not None
return funcs

return get_support_funcs
return get_support_funcs


def sparse_tensor_to_coo_tensor(
sparse_tensor: ctypes.c_void_p,
dtype: np.dtype,
) -> Tuple[int, int, np.ndarray, np.ndarray, np.ndarray]:
"""Converts an MLIR sparse tensor to a COO-flavored format tensor.
"""Converts an MLIR sparse tensor to a COO-flavored format tensor.

Args:
sparse_tensor: A ctypes.c_void_p to the MLIR sparse tensor descriptor.
dtype: The numpy data type for the tensor elements.
Args:
sparse_tensor: A ctypes.c_void_p to the MLIR sparse tensor descriptor.
dtype: The numpy data type for the tensor elements.

Returns:
A tuple that contains the following values for the COO-flavored format
tensor:
rank: An integer for the rank of the tensor.
nse: An integer for the number of non-zero values in the tensor.
shape: A 1D numpy array of integers, for the shape of the tensor.
values: A 1D numpy array, for the non-zero values in the tensor.
indices: A 2D numpy array of integers, representing the indices for the
non-zero values in the tensor.
Returns:
A tuple that contains the following values for the COO-flavored format
tensor:
rank: An integer for the rank of the tensor.
nse: An integer for the number of non-zero values in the tensor.
shape: A 1D numpy array of integers, for the shape of the tensor.
values: A 1D numpy array, for the non-zero values in the tensor.
indices: A 2D numpy array of integers, representing the indices for the
non-zero values in the tensor.

Raises:
OSError: If there is any problem in loading the shared library.
ValueError: If the shared library doesn't contain the needed routines.
"""
convert_from = _get_support_func_locator()(dtype)[1]
rank = ctypes.c_ulonglong(0)
nse = ctypes.c_ulonglong(0)
shape = ctypes.POINTER(ctypes.c_ulonglong)()
Raises:
OSError: If there is any problem in loading the shared library.
ValueError: If the shared library doesn't contain the needed routines.
"""
convert_from = _get_support_func_locator()(dtype)[1]
rank = ctypes.c_ulonglong(0)
nse = ctypes.c_ulonglong(0)
shape = ctypes.POINTER(ctypes.c_ulonglong)()

values = ctypes.POINTER(runtime.as_ctype(np.dtype(dtype)))()
indices = ctypes.POINTER(ctypes.c_ulonglong)()
convert_from(sparse_tensor, ctypes.byref(rank), ctypes.byref(nse),
ctypes.byref(shape), ctypes.byref(values), ctypes.byref(indices))
values = ctypes.POINTER(runtime.as_ctype(np.dtype(dtype)))()
indices = ctypes.POINTER(ctypes.c_ulonglong)()
convert_from(
sparse_tensor,
ctypes.byref(rank),
ctypes.byref(nse),
ctypes.byref(shape),
ctypes.byref(values),
ctypes.byref(indices),
)

# Convert the returned values to the corresponding numpy types.
shape = np.ctypeslib.as_array(shape, shape=[rank.value])
values = runtime.to_numpy(np.ctypeslib.as_array(values, shape=[nse.value]))
indices = np.ctypeslib.as_array(indices, shape=[nse.value, rank.value])
return rank.value, nse.value, shape, values, indices
# Convert the returned values to the corresponding numpy types.
shape = np.ctypeslib.as_array(shape, shape=[rank.value])
values = runtime.to_numpy(np.ctypeslib.as_array(values, shape=[nse.value]))
indices = np.ctypeslib.as_array(indices, shape=[nse.value, rank.value])
return rank.value, nse.value, shape, values, indices


def coo_tensor_to_sparse_tensor(np_shape: np.ndarray, np_values: np.ndarray,
np_indices: np.ndarray, np_perm: np.ndarray,
np_sparse: np.ndarray) -> int:
"""Converts a COO-flavored format sparse tensor to an MLIR sparse tensor.
def coo_tensor_to_sparse_tensor(
np_shape: np.ndarray,
np_values: np.ndarray,
np_indices: np.ndarray,
np_perm: np.ndarray,
np_sparse: np.ndarray,
) -> int:
"""Converts a COO-flavored format sparse tensor to an MLIR sparse tensor.

Args:
np_shape: A 1D numpy array of integers, for the shape of the tensor.
np_values: A 1D numpy array, for the non-zero values in the tensor.
np_indices: A 2D numpy array of integers, representing the indices for the
non-zero values in the tensor.
np_perm: A 1D numpy array of integers, representing the storage ordering
for the dimensions.
np_sparse: A 1D numpy array of uint8, representing the sparsity values
for the dimensions.
Args:
np_shape: A 1D numpy array of integers, for the shape of the tensor.
np_values: A 1D numpy array, for the non-zero values in the tensor.
np_indices: A 2D numpy array of integers, representing the indices for the
non-zero values in the tensor.
np_perm: A 1D numpy array of integers, representing the storage ordering
for the dimensions.
np_sparse: A 1D numpy array of uint8, representing the sparsity values
for the dimensions.

Returns:
An integer for the non-null ctypes.c_void_p to the MLIR sparse tensor
descriptor.
Returns:
An integer for the non-null ctypes.c_void_p to the MLIR sparse tensor
descriptor.

Raises:
OSError: If there is any problem in loading the shared library.
ValueError: If the shared library doesn't contain the needed routines.
"""
Raises:
OSError: If there is any problem in loading the shared library.
ValueError: If the shared library doesn't contain the needed routines.
"""

r = len(np_shape)
rank = ctypes.c_ulonglong(r)
nse = ctypes.c_ulonglong(len(np_values))
shape = np_shape.ctypes.data_as(ctypes.POINTER(ctypes.c_ulonglong))
values = np_values.ctypes.data_as(
ctypes.POINTER(runtime.as_ctype(np.dtype(np_values.dtype))))
indices = np_indices.ctypes.data_as(ctypes.POINTER(ctypes.c_ulonglong))
r = len(np_shape)
rank = ctypes.c_ulonglong(r)
nse = ctypes.c_ulonglong(len(np_values))
shape = np_shape.ctypes.data_as(ctypes.POINTER(ctypes.c_ulonglong))
values = np_values.ctypes.data_as(
ctypes.POINTER(runtime.as_ctype(np.dtype(np_values.dtype)))
)
indices = np_indices.ctypes.data_as(ctypes.POINTER(ctypes.c_ulonglong))

perm = np_perm.ctypes.data_as(ctypes.POINTER(ctypes.c_ulonglong))
sparse = np_sparse.ctypes.data_as(ctypes.POINTER(ctypes.c_uint8))
perm = np_perm.ctypes.data_as(ctypes.POINTER(ctypes.c_ulonglong))
sparse = np_sparse.ctypes.data_as(ctypes.POINTER(ctypes.c_uint8))

convert_to = _get_support_func_locator()(np_values.dtype.type)[0]
ptr = convert_to(rank, nse, shape, values, indices, perm, sparse)
assert ptr is not None, "Problem with calling convertToMLIRSparseTensorF64"
return ptr
convert_to = _get_support_func_locator()(np_values.dtype.type)[0]
ptr = convert_to(rank, nse, shape, values, indices, perm, sparse)
assert ptr is not None, "Problem with calling convertToMLIRSparseTensorF64"
return ptr


def compile_and_build_engine(
module: ir.Module) -> execution_engine.ExecutionEngine:
"""Compiles an MLIR module and builds a JIT execution engine.
def compile_and_build_engine(module: ir.Module) -> execution_engine.ExecutionEngine:
"""Compiles an MLIR module and builds a JIT execution engine.

Args:
module: The MLIR module.
Args:
module: The MLIR module.

Returns:
A JIT execution engine for the MLIR module.
Returns:
A JIT execution engine for the MLIR module.

"""
return _get_sparse_compiler().compile_and_jit(module)
"""
return _get_sparse_compiler().compile_and_jit(module)


class _SparseTensorDescriptor(ctypes.Structure):
"""A C structure for an MLIR sparse tensor."""
_fields_ = [
# A pointer for the MLIR sparse tensor storage.
("storage", ctypes.POINTER(ctypes.c_ulonglong)),
# An MLIR MemRef descriptor for the shape of the sparse tensor.
("shape", runtime.make_nd_memref_descriptor(1, ctypes.c_ulonglong)),
]
"""A C structure for an MLIR sparse tensor."""

_fields_ = [
# A pointer for the MLIR sparse tensor storage.
("storage", ctypes.POINTER(ctypes.c_ulonglong)),
# An MLIR MemRef descriptor for the shape of the sparse tensor.
("shape", runtime.make_nd_memref_descriptor(1, ctypes.c_ulonglong)),
]


def _output_one_dim(dim: int, rank: int, shape: str, type: str) -> str:
"""Produces the MLIR text code to output the size for the given dimension."""
return f"""
"""Produces the MLIR text code to output the size for the given dimension."""
return f"""
%c{dim} = arith.constant {dim} : index
%d{dim} = tensor.dim %t, %c{dim} : tensor<{shape}x{type}, #enc>
memref.store %d{dim}, %b[%c{dim}] : memref<{rank}xindex>
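Editor's note: to make the data contract of the two COO converters above concrete, here is an illustrative, numpy-only sketch of the arrays they exchange; the dtypes match the c_ulonglong/c_uint8 pointers in the code, and the tensor itself is made up:

import numpy as np

# Inputs in the form coo_tensor_to_sparse_tensor expects.
np_shape = np.array([2, 3], dtype=np.uint64)          # tensor is 2x3
np_values = np.array([10.0, 40.0], dtype=np.float64)  # two stored entries
np_indices = np.array([[0, 1], [1, 2]], dtype=np.uint64)
np_perm = np.array([0, 1], dtype=np.uint64)           # identity storage ordering
np_sparse = np.array([0, 1], dtype=np.uint8)          # dense outer, compressed inner

# What sparse_tensor_to_coo_tensor hands back can be scattered into a dense array.
dense = np.zeros(np_shape, dtype=np_values.dtype)
dense[tuple(np_indices.T)] = np_values
assert dense[0, 1] == 10.0 and dense[1, 2] == 40.0

The restype assignments in _record_support_funcs matter for the same reason as in any ctypes binding: foreign routines that return pointers must have their result widened from the ctypes default of c_int to c_void_p before being called.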
@@ -233,26 +277,29 @@ def _output_one_dim(dim: int, rank: int, shape: str, type: str) -> str:
# (2) Use scf.for instead of an unrolled loop to write out the dimension sizes
# when tensor.dim supports non-constant dimension value.
def _get_create_sparse_tensor_kernel(
sparsity_codes: Sequence[sparse_tensor.DimLevelType], type: str) -> str:
"""Creates an MLIR text kernel to construct a sparse tensor from a file.
sparsity_codes: Sequence[sparse_tensor.DimLevelType], type: str
) -> str:
"""Creates an MLIR text kernel to construct a sparse tensor from a file.

The kernel returns a _SparseTensorDescriptor structure.
"""
rank = len(sparsity_codes)
The kernel returns a _SparseTensorDescriptor structure.
"""
rank = len(sparsity_codes)

# Use ? to represent a dimension in the dynamic shape string representation.
shape = "x".join(map(lambda d: "?", range(rank)))
# Use ? to represent a dimension in the dynamic shape string representation.
shape = "x".join(map(lambda d: "?", range(rank)))

# Convert the encoded sparsity values to a string representation.
sparsity = ", ".join(
map(lambda s: '"compressed"' if s.value else '"dense"', sparsity_codes))
# Convert the encoded sparsity values to a string representation.
sparsity = ", ".join(
map(lambda s: '"compressed"' if s.value else '"dense"', sparsity_codes)
)

# Get the MLIR text code to write the dimension sizes to the output buffer.
output_dims = "\n".join(
map(lambda d: _output_one_dim(d, rank, shape, type), range(rank)))
# Get the MLIR text code to write the dimension sizes to the output buffer.
output_dims = "\n".join(
map(lambda d: _output_one_dim(d, rank, shape, type), range(rank))
)

# Return the MLIR text kernel.
return f"""
# Return the MLIR text kernel.
return f"""
!Ptr = !llvm.ptr<i8>
#enc = #sparse_tensor.encoding<{{
lvlTypes = [ {sparsity} ]
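Editor's note: the two string-stitching steps are easy to check in isolation. A standalone rendering, with booleans standing in for the DimLevelType codes:

rank = 3
sparsity_codes = [True, False, True]  # stand-ins for DimLevelType values

shape = "x".join("?" for _ in range(rank))  # -> '?x?x?'
sparsity = ", ".join('"compressed"' if s else '"dense"' for s in sparsity_codes)
print(f"tensor<{shape}xf64, #enc>")  # -> tensor<?x?x?xf64, #enc>
print(f"lvlTypes = [ {sparsity} ]")  # -> lvlTypes = [ "compressed", "dense", "compressed" ]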
@@ -266,69 +313,69 @@ attributes {{ llvm.emit_c_interface }} {{
}}"""


def create_sparse_tensor(filename: str,
sparsity: Sequence[sparse_tensor.DimLevelType],
type: str) -> Tuple[ctypes.c_void_p, np.ndarray]:
"""Creates an MLIR sparse tensor from the input file.
def create_sparse_tensor(
filename: str, sparsity: Sequence[sparse_tensor.DimLevelType], type: str
) -> Tuple[ctypes.c_void_p, np.ndarray]:
"""Creates an MLIR sparse tensor from the input file.

Args:
filename: A string for the name of the file that contains the tensor data in
a COO-flavored format.
sparsity: A sequence of DimLevelType values, one for each dimension of the
tensor.
Args:
filename: A string for the name of the file that contains the tensor data in
a COO-flavored format.
sparsity: A sequence of DimLevelType values, one for each dimension of the
tensor.

Returns:
A Tuple containing the following values:
storage: A ctypes.c_void_p for the MLIR sparse tensor storage.
shape: A 1D numpy array of integers, for the shape of the tensor.
Returns:
A Tuple containing the following values:
storage: A ctypes.c_void_p for the MLIR sparse tensor storage.
shape: A 1D numpy array of integers, for the shape of the tensor.

Raises:
OSError: If there is any problem in loading the supporting C shared library.
ValueError: If the shared library doesn't contain the needed routine.
"""
with ir.Context() as ctx, ir.Location.unknown():
module = _get_create_sparse_tensor_kernel(sparsity, type)
module = ir.Module.parse(module)
engine = compile_and_build_engine(module)
Raises:
OSError: If there is any problem in loading the supporting C shared library.
ValueError: If the shared library doesn't contain the needed routine.
"""
with ir.Context() as ctx, ir.Location.unknown():
module = _get_create_sparse_tensor_kernel(sparsity, type)
module = ir.Module.parse(module)
engine = compile_and_build_engine(module)

# A sparse tensor descriptor to receive the kernel result.
c_tensor_desc = _SparseTensorDescriptor()
# Convert the filename to a byte stream.
c_filename = ctypes.c_char_p(bytes(filename, "utf-8"))
# A sparse tensor descriptor to receive the kernel result.
c_tensor_desc = _SparseTensorDescriptor()
# Convert the filename to a byte stream.
c_filename = ctypes.c_char_p(bytes(filename, "utf-8"))

arg_pointers = [
ctypes.byref(ctypes.pointer(c_tensor_desc)),
ctypes.byref(c_filename)
]
arg_pointers = [
ctypes.byref(ctypes.pointer(c_tensor_desc)),
ctypes.byref(c_filename),
]

# Invoke the execution engine to run the module and return the result.
engine.invoke(_ENTRY_NAME, *arg_pointers)
shape = runtime.ranked_memref_to_numpy(ctypes.pointer(c_tensor_desc.shape))
return c_tensor_desc.storage, shape
# Invoke the execution engine to run the module and return the result.
engine.invoke(_ENTRY_NAME, *arg_pointers)
shape = runtime.ranked_memref_to_numpy(ctypes.pointer(c_tensor_desc.shape))
return c_tensor_desc.storage, shape


# TODO: With better support from MLIR, we may improve the current implementation
# by using Python code to generate the kernel instead of doing MLIR text code
# stitching.
def _get_output_sparse_tensor_kernel(
sparsity_codes: Sequence[sparse_tensor.DimLevelType],
type: str) -> str:
"""Creates an MLIR text kernel to output a sparse tensor to a file.
sparsity_codes: Sequence[sparse_tensor.DimLevelType], type: str
) -> str:
"""Creates an MLIR text kernel to output a sparse tensor to a file.

The kernel returns void.
"""
rank = len(sparsity_codes)
The kernel returns void.
"""
rank = len(sparsity_codes)

# Use ? to represent a dimension in the dynamic shape string representation.
shape = "x".join(map(lambda d: "?", range(rank)))
# Use ? to represent a dimension in the dynamic shape string representation.
shape = "x".join(map(lambda d: "?", range(rank)))

# Convert the encoded sparsity values to a string representation.
sparsity = ", ".join(
map(lambda s: '"compressed"'
if s.value else '"dense"', sparsity_codes))
# Convert the encoded sparsity values to a string representation.
sparsity = ", ".join(
map(lambda s: '"compressed"' if s.value else '"dense"', sparsity_codes)
)

# Return the MLIR text kernel.
return f"""
# Return the MLIR text kernel.
return f"""
!Ptr = !llvm.ptr<i8>
#enc = #sparse_tensor.encoding<{{
lvlTypes = [ {sparsity} ]
@@ -340,35 +387,38 @@ attributes {{ llvm.emit_c_interface }} {{
}}"""


def output_sparse_tensor(tensor: ctypes.c_void_p, filename: str,
sparsity: Sequence[sparse_tensor.DimLevelType],
type: str) -> None:
"""Outputs an MLIR sparse tensor to the given file.
def output_sparse_tensor(
tensor: ctypes.c_void_p,
filename: str,
sparsity: Sequence[sparse_tensor.DimLevelType],
type: str,
) -> None:
"""Outputs an MLIR sparse tensor to the given file.

Args:
tensor: A C pointer to the MLIR sparse tensor.
filename: A string for the name of the file that contains the tensor data in
a COO-flavored format.
sparsity: A sequence of DimLevelType values, one for each dimension of the
tensor.
type: The MLIR string for the data type.
Args:
tensor: A C pointer to the MLIR sparse tensor.
filename: A string for the name of the file that contains the tensor data in
a COO-flavored format.
sparsity: A sequence of DimLevelType values, one for each dimension of the
tensor.
type: The MLIR string for the data type.

Raises:
OSError: If there is any problem in loading the supporting C shared library.
ValueError: If the shared library doesn't contain the needed routine.
"""
with ir.Context() as ctx, ir.Location.unknown():
module = _get_output_sparse_tensor_kernel(sparsity, type)
module = ir.Module.parse(module)
engine = compile_and_build_engine(module)
Raises:
OSError: If there is any problem in loading the supporting C shared library.
ValueError: If the shared library doesn't contain the needed routine.
"""
with ir.Context() as ctx, ir.Location.unknown():
module = _get_output_sparse_tensor_kernel(sparsity, type)
module = ir.Module.parse(module)
engine = compile_and_build_engine(module)

# Convert the filename to a byte stream.
c_filename = ctypes.c_char_p(bytes(filename, "utf-8"))
# Convert the filename to a byte stream.
c_filename = ctypes.c_char_p(bytes(filename, "utf-8"))

arg_pointers = [
ctypes.byref(ctypes.cast(tensor, ctypes.c_void_p)),
ctypes.byref(c_filename)
]
arg_pointers = [
ctypes.byref(ctypes.cast(tensor, ctypes.c_void_p)),
ctypes.byref(c_filename),
]

# Invoke the execution engine to run the module and return the result.
engine.invoke(_ENTRY_NAME, *arg_pointers)
# Invoke the execution engine to run the module and return the result.
engine.invoke(_ENTRY_NAME, *arg_pointers)
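Editor's note: a hypothetical round trip through the two helpers above. It assumes the MLIR Python bindings plus the runtime support library are available, that the input file exists, and that DimLevelType.compressed is spelled as in this revision of the bindings:

from mlir.dialects import sparse_tensor

sparsity = [sparse_tensor.DimLevelType.compressed] * 2
storage, shape = create_sparse_tensor("data/A.tns", sparsity, "f64")  # read and pack
output_sparse_tensor(storage, "A_copy.tns", sparsity, "f64")          # unpack and write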
@@ -13,29 +13,29 @@ from typing import Sequence


class SparseCompiler:
"""Sparse compiler class for compiling and building MLIR modules."""
"""Sparse compiler class for compiling and building MLIR modules."""

def __init__(self, options: str, opt_level: int, shared_libs: Sequence[str]):
pipeline = f'builtin.module(sparse-compiler{{{options} reassociate-fp-reductions=1 enable-index-optimizations=1}})'
self.pipeline = pipeline
self.opt_level = opt_level
self.shared_libs = shared_libs
def __init__(self, options: str, opt_level: int, shared_libs: Sequence[str]):
pipeline = f"builtin.module(sparse-compiler{{{options} reassociate-fp-reductions=1 enable-index-optimizations=1}})"
self.pipeline = pipeline
self.opt_level = opt_level
self.shared_libs = shared_libs

def __call__(self, module: ir.Module):
"""Convenience application method."""
self.compile(module)
def __call__(self, module: ir.Module):
"""Convenience application method."""
self.compile(module)

def compile(self, module: ir.Module):
"""Compiles the module by invoking the sparse compiler pipeline."""
passmanager.PassManager.parse(self.pipeline).run(module.operation)
def compile(self, module: ir.Module):
"""Compiles the module by invoking the sparse compiler pipeline."""
passmanager.PassManager.parse(self.pipeline).run(module.operation)

def jit(self, module: ir.Module) -> execution_engine.ExecutionEngine:
"""Wraps the module in a JIT execution engine."""
return execution_engine.ExecutionEngine(
module, opt_level=self.opt_level, shared_libs=self.shared_libs)
def jit(self, module: ir.Module) -> execution_engine.ExecutionEngine:
"""Wraps the module in a JIT execution engine."""
return execution_engine.ExecutionEngine(
module, opt_level=self.opt_level, shared_libs=self.shared_libs
)

def compile_and_jit(self,
module: ir.Module) -> execution_engine.ExecutionEngine:
"""Compiles and jits the module."""
self.compile(module)
return self.jit(module)
def compile_and_jit(self, module: ir.Module) -> execution_engine.ExecutionEngine:
"""Compiles and jits the module."""
self.compile(module)
return self.jit(module)
@@ -8,38 +8,40 @@ import numpy as np


def compare_sparse_tns(expected: str, actual: str, rtol: float = 0.0001) -> bool:
"""Compares sparse tensor actual output file with expected output file.
"""Compares sparse tensor actual output file with expected output file.

This routine assumes the input files are in FROSTT format. See
http://frostt.io/tensors/file-formats.html for FROSTT (.tns) format.
This routine assumes the input files are in FROSTT format. See
http://frostt.io/tensors/file-formats.html for FROSTT (.tns) format.

It also assumes the first line in the output file is a comment line.
It also assumes the first line in the output file is a comment line.

"""
with open(actual, "r") as actual_f:
with open(expected, "r") as expected_f:
# Skip the first comment line.
_ = actual_f.readline()
_ = expected_f.readline()
"""
with open(actual, "r") as actual_f:
with open(expected, "r") as expected_f:
# Skip the first comment line.
_ = actual_f.readline()
_ = expected_f.readline()

# Compare the two lines of meta data
if (actual_f.readline() != expected_f.readline() or
actual_f.readline() != expected_f.readline()):
return False
# Compare the two lines of meta data
if (
actual_f.readline() != expected_f.readline()
or actual_f.readline() != expected_f.readline()
):
return False

actual_data = np.loadtxt(actual, np.float64, skiprows=3)
expected_data = np.loadtxt(expected, np.float64, skiprows=3)
return np.allclose(actual_data, expected_data, rtol=rtol)
actual_data = np.loadtxt(actual, np.float64, skiprows=3)
expected_data = np.loadtxt(expected, np.float64, skiprows=3)
return np.allclose(actual_data, expected_data, rtol=rtol)


def file_as_string(file: str) -> str:
"""Returns contents of file as string."""
with open(file, "r") as f:
return f.read()
"""Returns contents of file as string."""
with open(file, "r") as f:
return f.read()


def run_test(f):
"""Prints the test name and runs the test."""
print(f.__name__)
f()
return f
"""Prints the test name and runs the test."""
print(f.__name__)
f()
return f
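Editor's note: the run_test decorator runs eagerly, printing the function name (which the CHECK-LABEL lines below match) and calling the function at definition time. A tiny self-contained sketch of the same shape:

def run_test(f):
    # Announce the test, execute it, hand the function back unchanged.
    print(f.__name__)
    f()
    return f

@run_test
def test_example():
    print("Number of passed:", 1)
# Output:
# test_example
# Number of passed: 1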
@@ -18,509 +18,630 @@ _DENSE = mlir_pytaco.ModeFormat.DENSE
|
||||
|
||||
|
||||
def _init_3d(T, I, J, K):
|
||||
for i in range(I):
|
||||
for j in range(J):
|
||||
for k in range(K):
|
||||
T.insert([i, j, k], i + j + k + 1)
|
||||
for i in range(I):
|
||||
for j in range(J):
|
||||
for k in range(K):
|
||||
T.insert([i, j, k], i + j + k + 1)
|
||||
|
||||
|
||||
def _init_2d(T, I, J):
|
||||
for i in range(I):
|
||||
for j in range(J):
|
||||
T.insert([i, j], i + j + 1)
|
||||
for i in range(I):
|
||||
for j in range(J):
|
||||
T.insert([i, j], i + j + 1)
|
||||
|
||||
|
||||
def _init_1d_with_value(T, I, v):
|
||||
for i in range(I):
|
||||
T.insert([i], v)
|
||||
for i in range(I):
|
||||
T.insert([i], v)
|
||||
|
||||
|
||||
def test_expect_error(name, code, error):
|
||||
"""Executes the code then verifies the expected error message."""
|
||||
try:
|
||||
exec(code)
|
||||
except ValueError as e:
|
||||
passed = "passed" if (str(e).startswith(error)) else "failed"
|
||||
print(f"test_{name}: {passed}")
|
||||
"""Executes the code then verifies the expected error message."""
|
||||
try:
|
||||
exec(code)
|
||||
except ValueError as e:
|
||||
passed = "passed" if (str(e).startswith(error)) else "failed"
|
||||
print(f"test_{name}: {passed}")
|
||||
|
||||
|
||||
# CHECK-LABEL: test_tensor_dtype
|
||||
@testing_utils.run_test
|
||||
def test_tensor_dtype():
|
||||
passed = mlir_pytaco.DType(mlir_pytaco.Type.INT16).is_int()
|
||||
passed += mlir_pytaco.DType(mlir_pytaco.Type.INT32).is_int()
|
||||
passed += mlir_pytaco.DType(mlir_pytaco.Type.INT64).is_int()
|
||||
passed += mlir_pytaco.DType(mlir_pytaco.Type.FLOAT32).is_float()
|
||||
passed += mlir_pytaco.DType(mlir_pytaco.Type.FLOAT64).is_float()
|
||||
# CHECK: Number of passed: 5
|
||||
print("Number of passed:", passed)
|
||||
passed = mlir_pytaco.DType(mlir_pytaco.Type.INT16).is_int()
|
||||
passed += mlir_pytaco.DType(mlir_pytaco.Type.INT32).is_int()
|
||||
passed += mlir_pytaco.DType(mlir_pytaco.Type.INT64).is_int()
|
||||
passed += mlir_pytaco.DType(mlir_pytaco.Type.FLOAT32).is_float()
|
||||
passed += mlir_pytaco.DType(mlir_pytaco.Type.FLOAT64).is_float()
|
||||
# CHECK: Number of passed: 5
|
||||
print("Number of passed:", passed)
|
||||
|
||||
|
||||
# CHECK: test_mode_ordering_not_int: passed
|
||||
test_expect_error("mode_ordering_not_int",
|
||||
"m = mlir_pytaco.ModeOrdering(['x'])",
|
||||
"Ordering must be a list of integers")
|
||||
test_expect_error(
|
||||
"mode_ordering_not_int",
|
||||
"m = mlir_pytaco.ModeOrdering(['x'])",
|
||||
"Ordering must be a list of integers",
|
||||
)
|
||||
|
||||
# CHECK: test_mode_ordering_not_permutation: passed
|
||||
test_expect_error("mode_ordering_not_permutation",
|
||||
"m = mlir_pytaco.ModeOrdering([2, 1])", "Invalid ordering")
|
||||
test_expect_error(
|
||||
"mode_ordering_not_permutation",
|
||||
"m = mlir_pytaco.ModeOrdering([2, 1])",
|
||||
"Invalid ordering",
|
||||
)
|
||||
|
||||
# CHECK: test_mode_format_invalid: passed
|
||||
test_expect_error("mode_format_invalid",
|
||||
"m = mlir_pytaco.ModeFormatPack(['y'])",
|
||||
"Formats must be a list of ModeFormat")
|
||||
test_expect_error(
|
||||
"mode_format_invalid",
|
||||
"m = mlir_pytaco.ModeFormatPack(['y'])",
|
||||
"Formats must be a list of ModeFormat",
|
||||
)
|
||||
|
||||
# CHECK: test_expect_mode_format_pack: passed
|
||||
test_expect_error("expect_mode_format_pack", ("""
|
||||
test_expect_error(
|
||||
"expect_mode_format_pack",
|
||||
(
|
||||
"""
|
||||
mode_ordering = mlir_pytaco.ModeOrdering([0, 1, 2])
|
||||
f = mlir_pytaco.Format(["x"], mode_ordering)
|
||||
"""), "Expected a list of ModeFormat")
|
||||
"""
|
||||
),
|
||||
"Expected a list of ModeFormat",
|
||||
)
|
||||
|
||||
# CHECK: test_expect_mode_ordering: passed
|
||||
test_expect_error("expect_mode_ordering", ("""
|
||||
test_expect_error(
|
||||
"expect_mode_ordering",
|
||||
(
|
||||
"""
|
||||
mode_format_pack = mlir_pytaco.ModeFormatPack([_COMPRESSED, _COMPRESSED])
|
||||
f = mlir_pytaco.Format(mode_format_pack, "x")
|
||||
"""), "Expected ModeOrdering")
|
||||
"""
|
||||
),
|
||||
"Expected ModeOrdering",
|
||||
)
|
||||
|
||||
# CHECK: test_inconsistent_mode_format_pack_and_mode_ordering: passed
|
||||
test_expect_error("inconsistent_mode_format_pack_and_mode_ordering", ("""
|
||||
test_expect_error(
|
||||
"inconsistent_mode_format_pack_and_mode_ordering",
|
||||
(
|
||||
"""
|
||||
mode_format_pack = mlir_pytaco.ModeFormatPack([_COMPRESSED, _COMPRESSED])
|
||||
mode_ordering = mlir_pytaco.ModeOrdering([0, 1, 2])
|
||||
f = mlir_pytaco.Format(mode_format_pack, mode_ordering)
|
||||
"""), "Inconsistent ModeFormatPack and ModeOrdering")
|
||||
"""
|
||||
),
|
||||
"Inconsistent ModeFormatPack and ModeOrdering",
|
||||
)
|
||||
|
||||
|
||||
# CHECK-LABEL: test_format_default_ordering
|
||||
@testing_utils.run_test
|
||||
def test_format_default_ordering():
|
||||
f = mlir_pytaco.Format([_COMPRESSED, _COMPRESSED])
|
||||
passed = 0
|
||||
passed += np.array_equal(f.ordering.ordering, [0, 1])
|
||||
# CHECK: Number of passed: 1
|
||||
print("Number of passed:", passed)
|
||||
f = mlir_pytaco.Format([_COMPRESSED, _COMPRESSED])
|
||||
passed = 0
|
||||
passed += np.array_equal(f.ordering.ordering, [0, 1])
|
||||
# CHECK: Number of passed: 1
|
||||
print("Number of passed:", passed)
|
||||
|
||||
|
||||
# CHECK-LABEL: test_format_explicit_ordering
|
||||
@testing_utils.run_test
|
||||
def test_format_explicit_ordering():
|
||||
f = mlir_pytaco.Format([_COMPRESSED, _DENSE], [1, 0])
|
||||
passed = 0
|
||||
passed += np.array_equal(f.ordering.ordering, [1, 0])
|
||||
# CHECK: Number of passed: 1
|
||||
print("Number of passed:", passed)
|
||||
f = mlir_pytaco.Format([_COMPRESSED, _DENSE], [1, 0])
|
||||
passed = 0
|
||||
passed += np.array_equal(f.ordering.ordering, [1, 0])
|
||||
# CHECK: Number of passed: 1
|
||||
print("Number of passed:", passed)
|
||||
|
||||
|
||||
# CHECK-LABEL: test_index_var
|
||||
@testing_utils.run_test
|
||||
def test_index_var():
|
||||
i = mlir_pytaco.IndexVar()
|
||||
j = mlir_pytaco.IndexVar()
|
||||
passed = (i.name != j.name)
|
||||
i = mlir_pytaco.IndexVar()
|
||||
j = mlir_pytaco.IndexVar()
|
||||
passed = i.name != j.name
|
||||
|
||||
vars = mlir_pytaco.get_index_vars(10)
|
||||
passed += (len(vars) == 10)
|
||||
passed += (all([isinstance(e, mlir_pytaco.IndexVar) for e in vars]))
|
||||
vars = mlir_pytaco.get_index_vars(10)
|
||||
passed += len(vars) == 10
|
||||
passed += all([isinstance(e, mlir_pytaco.IndexVar) for e in vars])
|
||||
|
||||
# CHECK: Number of passed: 3
|
||||
print("Number of passed:", passed)
|
||||
# CHECK: Number of passed: 3
|
||||
print("Number of passed:", passed)
|
||||
|
||||
|
||||
# CHECK: test_tensor_invalid_first_argument: passed
|
||||
test_expect_error("tensor_invalid_first_argument",
|
||||
"t = mlir_pytaco.Tensor('f')", "Invalid first argument")
|
||||
test_expect_error(
|
||||
"tensor_invalid_first_argument",
|
||||
"t = mlir_pytaco.Tensor('f')",
|
||||
"Invalid first argument",
|
||||
)
|
||||
|
||||
# CHECK: test_tensor_inconsistent_shape_and_format: passed
|
||||
test_expect_error("tensor_inconsistent_shape_and_format", ("""
|
||||
test_expect_error(
|
||||
"tensor_inconsistent_shape_and_format",
|
||||
(
|
||||
"""
|
||||
mode_format_pack = mlir_pytaco.ModeFormatPack([_COMPRESSED, _COMPRESSED])
|
||||
mode_ordering = mlir_pytaco.ModeOrdering([0, 1])
|
||||
f = mlir_pytaco.Format(mode_format_pack, mode_ordering)
|
||||
t = mlir_pytaco.Tensor([3], f)
|
||||
"""), "Inconsistent shape and format")
|
||||
"""
|
||||
),
|
||||
"Inconsistent shape and format",
|
||||
)
|
||||
|
||||
# CHECK: test_tensor_invalid_format: passed
|
||||
test_expect_error("tensor_invalid_format", "t = mlir_pytaco.Tensor([3], 'f')",
|
||||
"Invalid format argument")
|
||||
test_expect_error(
|
||||
"tensor_invalid_format",
|
||||
"t = mlir_pytaco.Tensor([3], 'f')",
|
||||
"Invalid format argument",
|
||||
)
|
||||
|
||||
# CHECK: test_tensor_insert_nonlist_coordinate: passed
|
||||
test_expect_error("tensor_insert_nonlist_coordinate", ("""
|
||||
test_expect_error(
|
||||
"tensor_insert_nonlist_coordinate",
|
||||
(
|
||||
"""
|
||||
t = mlir_pytaco.Tensor([3])
|
||||
t.insert(1, 0)
|
||||
"""), "Non list coordinate detected")
|
||||
"""
|
||||
),
|
||||
"Non list coordinate detected",
|
||||
)
|
||||
|
||||
# CHECK: test_tensor_insert_too_much_coordinate: passed
|
||||
test_expect_error("tensor_insert_too_much_coordinate", ("""
|
||||
test_expect_error(
|
||||
"tensor_insert_too_much_coordinate",
|
||||
(
|
||||
"""
|
||||
t = mlir_pytaco.Tensor([3])
|
||||
t.insert([0, 0], 0)
|
||||
"""), "Invalid coordinate")
|
||||
"""
|
||||
),
|
||||
"Invalid coordinate",
|
||||
)
|
||||
|
||||
# CHECK: test_tensor_insert_coordinate_outof_range: passed
|
||||
test_expect_error("tensor_insert_coordinate_outof_range", ("""
|
||||
test_expect_error(
|
||||
"tensor_insert_coordinate_outof_range",
|
||||
(
|
||||
"""
|
||||
t = mlir_pytaco.Tensor([1, 1])
|
||||
t.insert([1, 0], 0)
|
||||
"""), "Invalid coordinate")
|
||||
"""
|
||||
),
|
||||
"Invalid coordinate",
|
||||
)
|
||||
|
||||
# CHECK: test_tensor_insert_coordinate_nonint: passed
|
||||
test_expect_error("tensor_insert_coordinate_nonint", ("""
|
||||
test_expect_error(
|
||||
"tensor_insert_coordinate_nonint",
|
||||
(
|
||||
"""
|
||||
t = mlir_pytaco.Tensor([1, 1])
|
||||
t.insert([0, "xy"], 0)
|
||||
"""), "Non integer coordinate detected")
|
||||
"""
|
||||
),
|
||||
"Non integer coordinate detected",
|
||||
)
|
||||
|
||||
# CHECK: test_tensor_insert_invalid_value: passed
|
||||
test_expect_error("tensor_insert_invalid_value", ("""
|
||||
test_expect_error(
|
||||
"tensor_insert_invalid_value",
|
||||
(
|
||||
"""
|
||||
t = mlir_pytaco.Tensor([1, 1])
|
||||
t.insert([0, 0], "x")
|
||||
"""), "Value is neither int nor float")
|
||||
"""
|
||||
),
|
||||
"Value is neither int nor float",
|
||||
)
|
||||
|
||||
# CHECK: test_access_non_index_var_index: passed
|
||||
test_expect_error("access_non_index_var_index", ("""
|
||||
test_expect_error(
|
||||
"access_non_index_var_index",
|
||||
(
|
||||
"""
|
||||
t = mlir_pytaco.Tensor([5, 6])
|
||||
i = mlir_pytaco.IndexVar()
|
||||
a = mlir_pytaco.Access(t, (i, "j"))
|
||||
"""), "Indices contain non IndexVar")
|
||||
"""
|
||||
),
|
||||
"Indices contain non IndexVar",
|
||||
)
|
||||
|
||||
# CHECK: test_access_inconsistent_rank_indices: passed
|
||||
test_expect_error("access_inconsistent_rank_indices", ("""
|
||||
test_expect_error(
|
||||
"access_inconsistent_rank_indices",
|
||||
(
|
||||
"""
|
||||
t = mlir_pytaco.Tensor([5, 6])
|
||||
i = mlir_pytaco.IndexVar()
|
||||
a = mlir_pytaco.Access(t, (i,))
|
||||
"""), "Invalid indices for rank")
|
||||
"""
|
||||
),
|
||||
"Invalid indices for rank",
|
||||
)
|
||||
|
||||
# CHECK: test_access_invalid_indices_for_rank: passed
|
||||
test_expect_error("access_invalid_indices_for_rank", ("""
|
||||
test_expect_error(
|
||||
"access_invalid_indices_for_rank",
|
||||
(
|
||||
"""
|
||||
t = mlir_pytaco.Tensor([5, 6])
|
||||
i, j, k = mlir_pytaco.get_index_vars(3)
|
||||
a = mlir_pytaco.Access(t, (i,j, k))
"""), "Invalid indices for rank")
"""
    ),
    "Invalid indices for rank",
)

# CHECK: test_invalid_indices: passed
test_expect_error("invalid_indices", ("""
test_expect_error(
    "invalid_indices",
    (
        """
i, j = mlir_pytaco.get_index_vars(2)
A = mlir_pytaco.Tensor([2, 3])
B = mlir_pytaco.Tensor([2, 3])
C = mlir_pytaco.Tensor([2, 3], _DENSE)
C[i, j] = A[1, j] + B[i, j]
"""), "Expected IndexVars")
"""
    ),
    "Expected IndexVars",
)

# CHECK: test_inconsistent_rank_indices: passed
test_expect_error("inconsistent_rank_indices", ("""
test_expect_error(
    "inconsistent_rank_indices",
    (
        """
i, j = mlir_pytaco.get_index_vars(2)
A = mlir_pytaco.Tensor([2, 3])
C = mlir_pytaco.Tensor([2, 3], _DENSE)
C[i, j] = A[i]
"""), "Invalid indices for rank")
"""
    ),
    "Invalid indices for rank",
)

# CHECK: test_destination_index_not_used_in_source: passed
test_expect_error("destination_index_not_used_in_source", ("""
test_expect_error(
    "destination_index_not_used_in_source",
    (
        """
i, j = mlir_pytaco.get_index_vars(2)
A = mlir_pytaco.Tensor([3])
C = mlir_pytaco.Tensor([3], _DENSE)
C[j] = A[i]
C.evaluate()
"""), "Destination IndexVar not used in the source expression")
"""
    ),
    "Destination IndexVar not used in the source expression",
)

# CHECK: test_destination_dim_not_consistent_with_source: passed
test_expect_error("destination_dim_not_consistent_with_source", ("""
test_expect_error(
    "destination_dim_not_consistent_with_source",
    (
        """
i = mlir_pytaco.IndexVar()
A = mlir_pytaco.Tensor([3])
C = mlir_pytaco.Tensor([5], _DENSE)
C[i] = A[i]
C.evaluate()
"""), "Inconsistent destination dimension for IndexVar")
"""
    ),
    "Inconsistent destination dimension for IndexVar",
)

# CHECK: test_inconsistent_source_dim: passed
test_expect_error("inconsistent_source_dim", ("""
test_expect_error(
    "inconsistent_source_dim",
    (
        """
i = mlir_pytaco.IndexVar()
A = mlir_pytaco.Tensor([3])
B = mlir_pytaco.Tensor([5])
C = mlir_pytaco.Tensor([3], _DENSE)
C[i] = A[i] + B[i]
C.evaluate()
"""), "Inconsistent source dimension for IndexVar")
"""
    ),
    "Inconsistent source dimension for IndexVar",
)

# CHECK: test_index_var_outside_domain: passed
test_expect_error("index_var_outside_domain", ("""
test_expect_error(
    "index_var_outside_domain",
    (
        """
i, j = mlir_pytaco.get_index_vars(2)
A = mlir_pytaco.Tensor([3])
B = mlir_pytaco.Tensor([3])
B[i] = A[i] + j
B.evaluate()
"""), "IndexVar is not part of the iteration domain")
"""
    ),
    "IndexVar is not part of the iteration domain",
)


# CHECK-LABEL: test_tensor_all_dense_sparse
@testing_utils.run_test
def test_tensor_all_dense_sparse():
  a = mlir_pytaco.Tensor([4], [_DENSE])
  passed = (not a.is_dense())
  passed += (a.order == 1)
  passed += (a.shape[0] == 4)
  # CHECK: Number of passed: 3
  print("Number of passed:", passed)
    a = mlir_pytaco.Tensor([4], [_DENSE])
    passed = not a.is_dense()
    passed += a.order == 1
    passed += a.shape[0] == 4
    # CHECK: Number of passed: 3
    print("Number of passed:", passed)


# CHECK-LABEL: test_tensor_true_dense
@testing_utils.run_test
def test_tensor_true_dense():
  a = mlir_pytaco.Tensor.from_array(np.random.uniform(size=5))
  passed = a.is_dense()
  passed += (a.order == 1)
  passed += (a.shape[0] == 5)
  # CHECK: Number of passed: 3
  print("Number of passed:", passed)
    a = mlir_pytaco.Tensor.from_array(np.random.uniform(size=5))
    passed = a.is_dense()
    passed += a.order == 1
    passed += a.shape[0] == 5
    # CHECK: Number of passed: 3
    print("Number of passed:", passed)


# CHECK-LABEL: test_tensor_copy
@testing_utils.run_test
def test_tensor_copy():
  i, j = mlir_pytaco.get_index_vars(2)
  I = 2
  J = 3
  A = mlir_pytaco.Tensor([I, J])
  A.insert([0, 1], 5.0)
  A.insert([1, 2], 6.0)
  B = mlir_pytaco.Tensor([I, J])
  B[i, j] = A[i, j]
  passed = (B._assignment is not None)
  passed += (B._engine is None)
  try:
    i, j = mlir_pytaco.get_index_vars(2)
    I = 2
    J = 3
    A = mlir_pytaco.Tensor([I, J])
    A.insert([0, 1], 5.0)
    A.insert([1, 2], 6.0)
    B = mlir_pytaco.Tensor([I, J])
    B[i, j] = A[i, j]
    passed = B._assignment is not None
    passed += B._engine is None
    try:
        B.compute()
    except ValueError as e:
        passed += str(e).startswith("Need to invoke compile")
    B.compile()
    passed += B._engine is not None
    B.compute()
  except ValueError as e:
    passed += (str(e).startswith("Need to invoke compile"))
  B.compile()
  passed += (B._engine is not None)
  B.compute()
  passed += (B._assignment is None)
  passed += (B._engine is None)
  indices, values = B.get_coordinates_and_values()
  passed += np.array_equal(indices, [[0, 1], [1, 2]])
  passed += np.allclose(values, [5.0, 6.0])
  # No temporary tensor is used.
  passed += (B._stats.get_total() == 0)
  # CHECK: Number of passed: 9
  print("Number of passed:", passed)
    passed += B._assignment is None
    passed += B._engine is None
    indices, values = B.get_coordinates_and_values()
    passed += np.array_equal(indices, [[0, 1], [1, 2]])
    passed += np.allclose(values, [5.0, 6.0])
    # No temporary tensor is used.
    passed += B._stats.get_total() == 0
    # CHECK: Number of passed: 9
    print("Number of passed:", passed)


# CHECK-LABEL: test_tensor_trivial_reduction
@testing_utils.run_test
def test_tensor_trivial_reduction():
  i, j = mlir_pytaco.get_index_vars(2)
  I = 2
  J = 3
  A = mlir_pytaco.Tensor([I, J])
  A.insert([0, 1], 5.0)
  A.insert([0, 2], 3.0)
  A.insert([1, 2], 6.0)
  B = mlir_pytaco.Tensor([I])
  B[i] = A[i, j]
  indices, values = B.get_coordinates_and_values()
  passed = np.array_equal(indices, [[0], [1]])
  passed += np.allclose(values, [8.0, 6.0])
  # No temporary tensor is used.
  passed += (B._stats.get_total() == 0)
    i, j = mlir_pytaco.get_index_vars(2)
    I = 2
    J = 3
    A = mlir_pytaco.Tensor([I, J])
    A.insert([0, 1], 5.0)
    A.insert([0, 2], 3.0)
    A.insert([1, 2], 6.0)
    B = mlir_pytaco.Tensor([I])
    B[i] = A[i, j]
    indices, values = B.get_coordinates_and_values()
    passed = np.array_equal(indices, [[0], [1]])
    passed += np.allclose(values, [8.0, 6.0])
    # No temporary tensor is used.
    passed += B._stats.get_total() == 0

  # CHECK: Number of passed: 3
  print("Number of passed:", passed)
    # CHECK: Number of passed: 3
    print("Number of passed:", passed)


# CHECK-LABEL: test_binary_add
@testing_utils.run_test
def test_binary_add():
  i = mlir_pytaco.IndexVar()
  A = mlir_pytaco.Tensor([4])
  B = mlir_pytaco.Tensor([4])
  C = mlir_pytaco.Tensor([4])
  A.insert([1], 10)
  A.insert([2], 1)
  B.insert([3], 20)
  B.insert([2], 2)
  C[i] = A[i] + B[i]
  indices, values = C.get_coordinates_and_values()
  passed = np.array_equal(indices, [[1], [2], [3]])
  passed += np.array_equal(values, [10., 3., 20.])
  # No temporary tensor is used.
  passed += (C._stats.get_total() == 0)
  # CHECK: Number of passed: 3
  print("Number of passed:", passed)
    i = mlir_pytaco.IndexVar()
    A = mlir_pytaco.Tensor([4])
    B = mlir_pytaco.Tensor([4])
    C = mlir_pytaco.Tensor([4])
    A.insert([1], 10)
    A.insert([2], 1)
    B.insert([3], 20)
    B.insert([2], 2)
    C[i] = A[i] + B[i]
    indices, values = C.get_coordinates_and_values()
    passed = np.array_equal(indices, [[1], [2], [3]])
    passed += np.array_equal(values, [10.0, 3.0, 20.0])
    # No temporary tensor is used.
    passed += C._stats.get_total() == 0
    # CHECK: Number of passed: 3
    print("Number of passed:", passed)


# CHECK-LABEL: test_binary_add_sub
@testing_utils.run_test
def test_binary_add_sub():
  i = mlir_pytaco.IndexVar()
  j = mlir_pytaco.IndexVar()
  A = mlir_pytaco.Tensor([2, 3])
  B = mlir_pytaco.Tensor([2, 3])
  C = mlir_pytaco.Tensor([2, 3])
  D = mlir_pytaco.Tensor([2, 3])
  A.insert([0, 1], 10)
  A.insert([1, 2], 40)
  B.insert([0, 0], 20)
  B.insert([1, 2], 30)
  C.insert([0, 1], 5)
  C.insert([1, 2], 7)
  D[i, j] = A[i, j] + B[i, j] - C[i, j]
  indices, values = D.get_coordinates_and_values()
  passed = np.array_equal(indices, [[0, 0], [0, 1], [1, 2]])
  passed += np.array_equal(values, [20., 5., 63.])
  # No temporary tensor is used.
  passed += (D._stats.get_total() == 0)
  # CHECK: Number of passed: 3
  print("Number of passed:", passed)
    i = mlir_pytaco.IndexVar()
    j = mlir_pytaco.IndexVar()
    A = mlir_pytaco.Tensor([2, 3])
    B = mlir_pytaco.Tensor([2, 3])
    C = mlir_pytaco.Tensor([2, 3])
    D = mlir_pytaco.Tensor([2, 3])
    A.insert([0, 1], 10)
    A.insert([1, 2], 40)
    B.insert([0, 0], 20)
    B.insert([1, 2], 30)
    C.insert([0, 1], 5)
    C.insert([1, 2], 7)
    D[i, j] = A[i, j] + B[i, j] - C[i, j]
    indices, values = D.get_coordinates_and_values()
    passed = np.array_equal(indices, [[0, 0], [0, 1], [1, 2]])
    passed += np.array_equal(values, [20.0, 5.0, 63.0])
    # No temporary tensor is used.
    passed += D._stats.get_total() == 0
    # CHECK: Number of passed: 3
    print("Number of passed:", passed)


# CHECK-LABEL: test_binary_mul_add
@testing_utils.run_test
def test_binary_mul_add():
  i = mlir_pytaco.IndexVar()
  j = mlir_pytaco.IndexVar()
  A = mlir_pytaco.Tensor([2, 3])
  B = mlir_pytaco.Tensor([2, 3])
  C = mlir_pytaco.Tensor([2, 3])
  D = mlir_pytaco.Tensor([2, 3])
  A.insert([0, 1], 10)
  A.insert([1, 2], 40)
  B.insert([0, 0], 20)
  B.insert([1, 2], 30)
  C.insert([0, 1], 5)
  C.insert([1, 2], 7)
  D[i, j] = A[i, j] * C[i, j] + B[i, j]
  indices, values = D.get_coordinates_and_values()
  passed = np.array_equal(indices, [[0, 0], [0, 1], [1, 2]])
  passed += np.array_equal(values, [20., 50., 310.])
  # No temporary tensor is used.
  passed += (D._stats.get_total() == 0)
  # CHECK: Number of passed: 3
  print("Number of passed:", passed)
    i = mlir_pytaco.IndexVar()
    j = mlir_pytaco.IndexVar()
    A = mlir_pytaco.Tensor([2, 3])
    B = mlir_pytaco.Tensor([2, 3])
    C = mlir_pytaco.Tensor([2, 3])
    D = mlir_pytaco.Tensor([2, 3])
    A.insert([0, 1], 10)
    A.insert([1, 2], 40)
    B.insert([0, 0], 20)
    B.insert([1, 2], 30)
    C.insert([0, 1], 5)
    C.insert([1, 2], 7)
    D[i, j] = A[i, j] * C[i, j] + B[i, j]
    indices, values = D.get_coordinates_and_values()
    passed = np.array_equal(indices, [[0, 0], [0, 1], [1, 2]])
    passed += np.array_equal(values, [20.0, 50.0, 310.0])
    # No temporary tensor is used.
    passed += D._stats.get_total() == 0
    # CHECK: Number of passed: 3
    print("Number of passed:", passed)


# CHECK-LABEL: test_binary_add_reduce_at_root
@testing_utils.run_test
def test_binary_add_reduce_at_root():
  i = mlir_pytaco.IndexVar()
  j = mlir_pytaco.IndexVar()
  A = mlir_pytaco.Tensor([2, 3])
  B = mlir_pytaco.Tensor([2, 3])
  C = mlir_pytaco.Tensor([2], _DENSE)
  A.insert([0, 1], 10)
  A.insert([1, 2], 40)
  B.insert([0, 0], 20)
  B.insert([1, 2], 30)
  C[i] = A[i, j] + B[i, j]
  indices, values = C.get_coordinates_and_values()
  passed = np.array_equal(indices, [[0], [1]])
  passed += np.array_equal(values, [30., 70.])
  # No temporary tensor is used.
  passed += (C._stats.get_total() == 0)
  # CHECK: Number of passed: 3
  print("Number of passed:", passed)
    i = mlir_pytaco.IndexVar()
    j = mlir_pytaco.IndexVar()
    A = mlir_pytaco.Tensor([2, 3])
    B = mlir_pytaco.Tensor([2, 3])
    C = mlir_pytaco.Tensor([2], _DENSE)
    A.insert([0, 1], 10)
    A.insert([1, 2], 40)
    B.insert([0, 0], 20)
    B.insert([1, 2], 30)
    C[i] = A[i, j] + B[i, j]
    indices, values = C.get_coordinates_and_values()
    passed = np.array_equal(indices, [[0], [1]])
    passed += np.array_equal(values, [30.0, 70.0])
    # No temporary tensor is used.
    passed += C._stats.get_total() == 0
    # CHECK: Number of passed: 3
    print("Number of passed:", passed)


# CHECK-LABEL: test_binary_add_reduce_at_child
@testing_utils.run_test
def test_binary_add_reduce_at_child():
  i = mlir_pytaco.IndexVar()
  j = mlir_pytaco.IndexVar()
  I = 2
  J = 3
  A = mlir_pytaco.Tensor([I, J])
  B = mlir_pytaco.Tensor([J])
  C = mlir_pytaco.Tensor([I])
  D = mlir_pytaco.Tensor([I], _DENSE)
    i = mlir_pytaco.IndexVar()
    j = mlir_pytaco.IndexVar()
    I = 2
    J = 3
    A = mlir_pytaco.Tensor([I, J])
    B = mlir_pytaco.Tensor([J])
    C = mlir_pytaco.Tensor([I])
    D = mlir_pytaco.Tensor([I], _DENSE)

  _init_2d(A, I, J)
  _init_1d_with_value(C, I, 2)
  _init_1d_with_value(B, J, 1)
    _init_2d(A, I, J)
    _init_1d_with_value(C, I, 2)
    _init_1d_with_value(B, J, 1)

  D[i] = A[i, j] * B[j] + C[i]
  indices, values = D.get_coordinates_and_values()
  passed = np.array_equal(indices, [[0], [1]])
  passed += np.array_equal(values, [8., 11.])
    D[i] = A[i, j] * B[j] + C[i]
    indices, values = D.get_coordinates_and_values()
    passed = np.array_equal(indices, [[0], [1]])
    passed += np.array_equal(values, [8.0, 11.0])

  # The expression is implemented as:
  # temp0[i] = A[i, j] * B[j]
  # D[i] = temp0[i] + C[i]
  # Check the temporary tensor introduced by the implementation.
  stats = D._stats
  passed += (stats.get_total() == 1)
  passed += (stats.get_formats(0) == (_COMPRESSED,))
  passed += (stats.get_dimensions(0) == (I,))
  # CHECK: Number of passed: 5
  print("Number of passed:", passed)
    # The expression is implemented as:
    # temp0[i] = A[i, j] * B[j]
    # D[i] = temp0[i] + C[i]
    # Check the temporary tensor introduced by the implementation.
    stats = D._stats
    passed += stats.get_total() == 1
    passed += stats.get_formats(0) == (_COMPRESSED,)
    passed += stats.get_dimensions(0) == (I,)
    # CHECK: Number of passed: 5
    print("Number of passed:", passed)


# CHECK-LABEL: test_binary_add_reduce_3d_1
@testing_utils.run_test
def test_binary_add_reduce_3d_1():
  i, j, k, l = mlir_pytaco.get_index_vars(4)
  I = 2
  J = 3
  K = 4
  L = 5
  A = mlir_pytaco.Tensor([I, J, K])
  B = mlir_pytaco.Tensor([I, J, L])
  C = mlir_pytaco.Tensor([K])
  D = mlir_pytaco.Tensor([L])
  E = mlir_pytaco.Tensor([I], _DENSE)
    i, j, k, l = mlir_pytaco.get_index_vars(4)
    I = 2
    J = 3
    K = 4
    L = 5
    A = mlir_pytaco.Tensor([I, J, K])
    B = mlir_pytaco.Tensor([I, J, L])
    C = mlir_pytaco.Tensor([K])
    D = mlir_pytaco.Tensor([L])
    E = mlir_pytaco.Tensor([I], _DENSE)

  _init_3d(A, I, J, K)
  _init_3d(B, I, J, L)
  _init_1d_with_value(C, K, 1)
  _init_1d_with_value(D, L, 2)
    _init_3d(A, I, J, K)
    _init_3d(B, I, J, L)
    _init_1d_with_value(C, K, 1)
    _init_1d_with_value(D, L, 2)

  E[i] = A[i, j, k] * C[k] + B[i, j, l] * D[l]
  indices, values = E.get_coordinates_and_values()
  passed = np.array_equal(indices, [[0], [1]])
  passed += np.array_equal(values, [162., 204.])
    E[i] = A[i, j, k] * C[k] + B[i, j, l] * D[l]
    indices, values = E.get_coordinates_and_values()
    passed = np.array_equal(indices, [[0], [1]])
    passed += np.array_equal(values, [162.0, 204.0])

  # The expression is implemented as:
  # temp0[i, j] = A[i, j, k] * C[k]
  # temp1[i, j] = B[i, j, l] * D[l]
  # E[i] = temp0[i, j] + temp1[i, j]
  # Check the two temporary tensors introduced by the implementation.
  stats = E._stats
  passed += (stats.get_total() == 2)
  passed += (stats.get_formats(0) == (_COMPRESSED, _COMPRESSED))
  passed += (stats.get_dimensions(0) == (I, J))
  passed += (stats.get_formats(1) == (_COMPRESSED, _COMPRESSED))
  passed += (stats.get_dimensions(1) == (I, J))
  # CHECK: Number of passed: 7
  print("Number of passed:", passed)
    # The expression is implemented as:
    # temp0[i, j] = A[i, j, k] * C[k]
    # temp1[i, j] = B[i, j, l] * D[l]
    # E[i] = temp0[i, j] + temp1[i, j]
    # Check the two temporary tensors introduced by the implementation.
    stats = E._stats
    passed += stats.get_total() == 2
    passed += stats.get_formats(0) == (_COMPRESSED, _COMPRESSED)
    passed += stats.get_dimensions(0) == (I, J)
    passed += stats.get_formats(1) == (_COMPRESSED, _COMPRESSED)
    passed += stats.get_dimensions(1) == (I, J)
    # CHECK: Number of passed: 7
    print("Number of passed:", passed)


# CHECK-LABEL: test_binary_add_reduce_3d_2
@testing_utils.run_test
def test_binary_add_reduce_3d_2():
  i, j, k, l = mlir_pytaco.get_index_vars(4)
  I = 2
  J = 3
  K = 4
  L = 5
  A = mlir_pytaco.Tensor([I, J, K], [_COMPRESSED, _COMPRESSED, _DENSE])
  B = mlir_pytaco.Tensor([I, L, K], [_DENSE, _COMPRESSED, _COMPRESSED])
  C = mlir_pytaco.Tensor([J, K], [_COMPRESSED, _COMPRESSED])
  D = mlir_pytaco.Tensor([L])
  E = mlir_pytaco.Tensor([I], _DENSE)
    i, j, k, l = mlir_pytaco.get_index_vars(4)
    I = 2
    J = 3
    K = 4
    L = 5
    A = mlir_pytaco.Tensor([I, J, K], [_COMPRESSED, _COMPRESSED, _DENSE])
    B = mlir_pytaco.Tensor([I, L, K], [_DENSE, _COMPRESSED, _COMPRESSED])
    C = mlir_pytaco.Tensor([J, K], [_COMPRESSED, _COMPRESSED])
    D = mlir_pytaco.Tensor([L])
    E = mlir_pytaco.Tensor([I], _DENSE)

  _init_3d(A, I, J, K)
  _init_3d(B, I, L, K)
  _init_2d(C, J, K)
  _init_1d_with_value(D, L, 2)
    _init_3d(A, I, J, K)
    _init_3d(B, I, L, K)
    _init_2d(C, J, K)
    _init_1d_with_value(D, L, 2)

  E[i] = A[i, j, k] + C[j, k] + B[i, l, k] * D[l]
  indices, values = E.get_coordinates_and_values()
  passed = np.array_equal(indices, [[0], [1]])
  passed += np.array_equal(values, [264., 316.])
    E[i] = A[i, j, k] + C[j, k] + B[i, l, k] * D[l]
    indices, values = E.get_coordinates_and_values()
    passed = np.array_equal(indices, [[0], [1]])
    passed += np.array_equal(values, [264.0, 316.0])

  # The expression is implemented as:
  # temp0[i, k] = A[i, j, k] + C[j, k]
  # temp1[i, k] = B[i, l, k] * D[l]
  # E[i] = temp0[i, k] + temp1[i, k]
  # Check the two temporary tensors introduced by the implementation.
  stats = E._stats
  passed += (stats.get_total() == 2)
  passed += (stats.get_formats(0) == (_COMPRESSED, _DENSE))
  passed += (stats.get_dimensions(0) == (I, K))
  passed += (stats.get_formats(1) == (_DENSE, _COMPRESSED))
  passed += (stats.get_dimensions(1) == (I, K))
  # CHECK: Number of passed: 7
  print("Number of passed:", passed)
    # The expression is implemented as:
    # temp0[i, k] = A[i, j, k] + C[j, k]
    # temp1[i, k] = B[i, l, k] * D[l]
    # E[i] = temp0[i, k] + temp1[i, k]
    # Check the two temporary tensors introduced by the implementation.
    stats = E._stats
    passed += stats.get_total() == 2
    passed += stats.get_formats(0) == (_COMPRESSED, _DENSE)
    passed += stats.get_dimensions(0) == (I, K)
    passed += stats.get_formats(1) == (_DENSE, _COMPRESSED)
    passed += stats.get_dimensions(1) == (I, K)
    # CHECK: Number of passed: 7
    print("Number of passed:", passed)

@@ -32,21 +32,21 @@ _MTX_DATA = """%%MatrixMarket matrix coordinate real general
# CHECK-LABEL: test_read_mtx_matrix_general
@testing_utils.run_test
def test_read_mtx_matrix_general():
  with tempfile.TemporaryDirectory() as test_dir:
    file_name = os.path.join(test_dir, "data.mtx")
    with open(file_name, "w") as file:
      file.write(_MTX_DATA)
    a = mlir_pytaco_io.read(file_name, _FORMAT)
  passed = 0
  # The value of a is stored as an MLIR sparse tensor.
  passed += (not a.is_unpacked())
  a.unpack()
  passed += (a.is_unpacked())
  coords, values = a.get_coordinates_and_values()
  passed += np.array_equal(coords, [[0, 1], [2, 0], [2, 1]])
  passed += np.allclose(values, [2.0, 3.0, 4.0])
  # CHECK: 4
  print(passed)
    with tempfile.TemporaryDirectory() as test_dir:
        file_name = os.path.join(test_dir, "data.mtx")
        with open(file_name, "w") as file:
            file.write(_MTX_DATA)
        a = mlir_pytaco_io.read(file_name, _FORMAT)
    passed = 0
    # The value of a is stored as an MLIR sparse tensor.
    passed += not a.is_unpacked()
    a.unpack()
    passed += a.is_unpacked()
    coords, values = a.get_coordinates_and_values()
    passed += np.array_equal(coords, [[0, 1], [2, 0], [2, 1]])
    passed += np.allclose(values, [2.0, 3.0, 4.0])
    # CHECK: 4
    print(passed)


_TNS_DATA = """2 3
@@ -60,57 +60,57 @@ _TNS_DATA = """2 3
# CHECK-LABEL: test_read_tns
@testing_utils.run_test
def test_read_tns():
  with tempfile.TemporaryDirectory() as test_dir:
    file_name = os.path.join(test_dir, "data.tns")
    with open(file_name, "w") as file:
      file.write(_TNS_DATA)
    a = mlir_pytaco_io.read(file_name, _FORMAT)
  passed = 0
  # The value of a is stored as an MLIR sparse tensor.
  passed += (not a.is_unpacked())
  a.unpack()
  passed += (a.is_unpacked())
  coords, values = a.get_coordinates_and_values()
  passed += np.array_equal(coords, [[0, 1], [2, 0], [2, 1]])
  passed += np.allclose(values, [2.0, 3.0, 4.0])
  # CHECK: 4
  print(passed)
    with tempfile.TemporaryDirectory() as test_dir:
        file_name = os.path.join(test_dir, "data.tns")
        with open(file_name, "w") as file:
            file.write(_TNS_DATA)
        a = mlir_pytaco_io.read(file_name, _FORMAT)
    passed = 0
    # The value of a is stored as an MLIR sparse tensor.
    passed += not a.is_unpacked()
    a.unpack()
    passed += a.is_unpacked()
    coords, values = a.get_coordinates_and_values()
    passed += np.array_equal(coords, [[0, 1], [2, 0], [2, 1]])
    passed += np.allclose(values, [2.0, 3.0, 4.0])
    # CHECK: 4
    print(passed)


# CHECK-LABEL: test_write_unpacked_tns
@testing_utils.run_test
def test_write_unpacked_tns():
  a = mlir_pytaco.Tensor([2, 3])
  a.insert([0, 1], 10)
  a.insert([1, 2], 40)
  a.insert([0, 0], 20)
  with tempfile.TemporaryDirectory() as test_dir:
    file_name = os.path.join(test_dir, "data.tns")
    try:
      mlir_pytaco_io.write(file_name, a)
    except ValueError as e:
      # CHECK: Writing unpacked sparse tensors to file is not supported
      print(e)
    a = mlir_pytaco.Tensor([2, 3])
    a.insert([0, 1], 10)
    a.insert([1, 2], 40)
    a.insert([0, 0], 20)
    with tempfile.TemporaryDirectory() as test_dir:
        file_name = os.path.join(test_dir, "data.tns")
        try:
            mlir_pytaco_io.write(file_name, a)
        except ValueError as e:
            # CHECK: Writing unpacked sparse tensors to file is not supported
            print(e)


# CHECK-LABEL: test_write_packed_tns
@testing_utils.run_test
def test_write_packed_tns():
  a = mlir_pytaco.Tensor([2, 3])
  a.insert([0, 1], 10)
  a.insert([1, 2], 40)
  a.insert([0, 0], 20)
  b = mlir_pytaco.Tensor([2, 3])
  i, j = mlir_pytaco.get_index_vars(2)
  b[i, j] = a[i, j] + a[i, j]
  with tempfile.TemporaryDirectory() as test_dir:
    file_name = os.path.join(test_dir, "data.tns")
    mlir_pytaco_io.write(file_name, b)
    with open(file_name, "r") as file:
      lines = file.readlines()
  passed = 0
  # Skip the comment line in the output.
  if lines[1:] == ["2 3\n", "2 3\n", "1 1 40\n", "1 2 20\n", "2 3 80\n"]:
    passed = 1
  # CHECK: 1
  print(passed)
    a = mlir_pytaco.Tensor([2, 3])
    a.insert([0, 1], 10)
    a.insert([1, 2], 40)
    a.insert([0, 0], 20)
    b = mlir_pytaco.Tensor([2, 3])
    i, j = mlir_pytaco.get_index_vars(2)
    b[i, j] = a[i, j] + a[i, j]
    with tempfile.TemporaryDirectory() as test_dir:
        file_name = os.path.join(test_dir, "data.tns")
        mlir_pytaco_io.write(file_name, b)
        with open(file_name, "r") as file:
            lines = file.readlines()
    passed = 0
    # Skip the comment line in the output.
    if lines[1:] == ["2 3\n", "2 3\n", "1 1 40\n", "1 2 20\n", "2 3 80\n"]:
        passed = 1
    # CHECK: 1
    print(passed)

@@ -20,79 +20,93 @@ _DENSE = mlir_pytaco.ModeFormat.DENSE


def _to_string(s: Sequence[int]) -> str:
  """Converts a sequence of integers to a space-separated value string."""
  return " ".join(map(lambda e: str(e), s))
    """Converts a sequence of integers to a space-separated value string."""
    return " ".join(map(lambda e: str(e), s))


def _add_one(s: Sequence[int]) -> Sequence[int]:
  """Adds one to each element in the sequence of integers."""
  return [i + 1 for i in s]
    """Adds one to each element in the sequence of integers."""
    return [i + 1 for i in s]


@dataclasses.dataclass(frozen=True)
class _SparseTensorCOO:
  """Values for a COO-flavored format sparse tensor.
    """Values for a COO-flavored format sparse tensor.

  Attributes:
    rank: An integer rank for the tensor.
    nse: An integer for the number of non-zero values.
    shape: A sequence of integers for the dimension sizes.
    values: A sequence of floats for the non-zero values of the tensor.
    indices: A sequence of coordinates; each coordinate is a sequence of integers.
  """
  rank: int
  nse: int
  shape: Sequence[int]
  values: Sequence[float]
  indices: Sequence[Sequence[int]]
    Attributes:
      rank: An integer rank for the tensor.
      nse: An integer for the number of non-zero values.
      shape: A sequence of integers for the dimension sizes.
      values: A sequence of floats for the non-zero values of the tensor.
      indices: A sequence of coordinates; each coordinate is a sequence of integers.
    """

    rank: int
    nse: int
    shape: Sequence[int]
    values: Sequence[float]
    indices: Sequence[Sequence[int]]


def _coo_values_to_tns_format(t: _SparseTensorCOO) -> str:
  """Converts a sparse tensor COO-flavored values to TNS text format."""
  # The coo_value_str contains one line for each (coordinate value) pair.
  # Indices are 1-based in TNS text format but 0-based in MLIR.
  coo_value_str = "\n".join(
      map(lambda i: _to_string(_add_one(t.indices[i])) + " " + str(t.values[i]),
          range(t.nse)))
    """Converts a sparse tensor COO-flavored values to TNS text format."""
    # The coo_value_str contains one line for each (coordinate value) pair.
    # Indices are 1-based in TNS text format but 0-based in MLIR.
    coo_value_str = "\n".join(
        map(
            lambda i: _to_string(_add_one(t.indices[i])) + " " + str(t.values[i]),
            range(t.nse),
        )
    )

  # Returns the TNS text format representation for the tensor.
  return f"""{t.rank} {t.nse}
    # Returns the TNS text format representation for the tensor.
    return f"""{t.rank} {t.nse}
{_to_string(t.shape)}
{coo_value_str}
"""


def _implement_read_tns_test(
    t: _SparseTensorCOO,
    sparsity_codes: Sequence[sparse_tensor.DimLevelType]) -> int:
  tns_data = _coo_values_to_tns_format(t)
    t: _SparseTensorCOO, sparsity_codes: Sequence[sparse_tensor.DimLevelType]
) -> int:
    tns_data = _coo_values_to_tns_format(t)

  # Write sparse tensor data to a file.
  with tempfile.TemporaryDirectory() as test_dir:
    file_name = os.path.join(test_dir, "data.tns")
    with open(file_name, "w") as file:
      file.write(tns_data)
    # Write sparse tensor data to a file.
    with tempfile.TemporaryDirectory() as test_dir:
        file_name = os.path.join(test_dir, "data.tns")
        with open(file_name, "w") as file:
            file.write(tns_data)

    # Read the data from the file and construct an MLIR sparse tensor.
    sparse_tensor, o_shape = pytaco_utils.create_sparse_tensor(
        file_name, sparsity_codes, "f64")
        # Read the data from the file and construct an MLIR sparse tensor.
        sparse_tensor, o_shape = pytaco_utils.create_sparse_tensor(
            file_name, sparsity_codes, "f64"
        )

    passed = 0
        passed = 0

    # Verify the output shape for the tensor.
    if np.array_equal(o_shape, t.shape):
      passed += 1
        # Verify the output shape for the tensor.
        if np.array_equal(o_shape, t.shape):
            passed += 1

    # Use the output MLIR sparse tensor pointer to retrieve the COO-flavored
    # values and verify the values.
    o_rank, o_nse, o_shape, o_values, o_indices = (
        pytaco_utils.sparse_tensor_to_coo_tensor(sparse_tensor, np.float64))
    if o_rank == t.rank and o_nse == t.nse and np.array_equal(
        o_shape, t.shape) and np.allclose(o_values, t.values) and np.array_equal(
            o_indices, t.indices):
      passed += 1
        # Use the output MLIR sparse tensor pointer to retrieve the COO-flavored
        # values and verify the values.
        (
            o_rank,
            o_nse,
            o_shape,
            o_values,
            o_indices,
        ) = pytaco_utils.sparse_tensor_to_coo_tensor(sparse_tensor, np.float64)
        if (
            o_rank == t.rank
            and o_nse == t.nse
            and np.array_equal(o_shape, t.shape)
            and np.allclose(o_values, t.values)
            and np.array_equal(o_indices, t.indices)
        ):
            passed += 1

  return passed
    return passed


# A 2D sparse tensor data in COO-flavored format.

@@ -5,11 +5,11 @@ if not config.mlir_run_amx_tests:
    config.unsupported = True

# No JIT on win32.
if sys.platform == 'win32':
if sys.platform == "win32":
    config.unsupported = True

if config.intel_sde_executable:
    # Run test in emulator (Intel SDE): AMX needs Sapphire Rapids CPU.
    config.substitutions.append(('%lli', config.intel_sde_executable + ' -spr -- lli'))
    config.substitutions.append(("%lli", config.intel_sde_executable + " -spr -- lli"))
else:
    config.substitutions.append(('%lli', 'lli'))
    config.substitutions.append(("%lli", "lli"))

@@ -5,5 +5,5 @@ if not config.mlir_run_arm_sme_tests:
    config.unsupported = True

# No JIT on win32.
if sys.platform == 'win32':
if sys.platform == "win32":
    config.unsupported = True

@@ -5,5 +5,5 @@ if not config.mlir_run_arm_sve_tests:
    config.unsupported = True

# No JIT on win32.
if sys.platform == 'win32':
if sys.platform == "win32":
    config.unsupported = True

@@ -5,11 +5,11 @@ if not config.mlir_run_x86vector_tests:
    config.unsupported = True

# No JIT on win32.
if sys.platform == 'win32':
if sys.platform == "win32":
    config.unsupported = True

if config.intel_sde_executable:
    # Run test in emulator (Intel SDE).
    config.substitutions.append(('%lli', config.intel_sde_executable + ' -tgl -- lli'))
    config.substitutions.append(("%lli", config.intel_sde_executable + " -tgl -- lli"))
else:
    config.substitutions.append(('%lli', 'lli'))
    config.substitutions.append(("%lli", "lli"))

@@ -1,2 +1,2 @@
if not config.enable_cuda_runner:
  config.unsupported = True
    config.unsupported = True

@@ -2,4 +2,4 @@ import sys

# TensorCore tests must be enabled via build flag.
if not config.mlir_run_cuda_tensor_core_tests:
  config.unsupported = True
    config.unsupported = True

@@ -1,2 +1,2 @@
if not config.enable_cuda_runner:
  config.unsupported = True
    config.unsupported = True

@@ -1,4 +1,4 @@
if not config.enable_rocm_runner or not config.rocm_test_chipset:
  config.unsupported = True
    config.unsupported = True

config.substitutions.append(('%chip', config.rocm_test_chipset))
config.substitutions.append(("%chip", config.rocm_test_chipset))

@@ -3,8 +3,9 @@ from lit.llvm import llvm_config
if not config.mlir_include_integration_tests:
    config.unsupported = True


def configure_aarch64_lli_cmd():
    lli_cmd = 'lli'
    lli_cmd = "lli"

    # NOTE: If the SVE tests are disabled and the SME tests are enabled to run
    # under emulation, the SVE specific RUN lines in the SparseTensor tests
@@ -12,8 +13,12 @@ def configure_aarch64_lli_cmd():
    if not (config.mlir_run_arm_sve_tests or config.mlir_run_arm_sme_tests):
        return lli_cmd

    config.substitutions.append(('%mlir_native_utils_lib_dir',
                                 config.arm_emulator_utils_lib_dir or config.mlir_lib_dir))
    config.substitutions.append(
        (
            "%mlir_native_utils_lib_dir",
            config.arm_emulator_utils_lib_dir or config.mlir_lib_dir,
        )
    )

    if config.arm_emulator_executable:
        if config.arm_emulator_lli_executable:
@@ -23,16 +28,22 @@ def configure_aarch64_lli_cmd():
        # when running under an emulator. If the user didn't specify an lli
        # executable, use absolute path %llvm_tools_dir/lli.
        lli_cmd = llvm_config.use_llvm_tool(
            'lli', search_env='LLI', required=True,
            search_paths=[config.llvm_tools_dir], use_installed=False
            "lli",
            search_env="LLI",
            required=True,
            search_paths=[config.llvm_tools_dir],
            use_installed=False,
        )

    # Run test in emulator (qemu or armie)
    emulation_cmd = f'{config.arm_emulator_executable} {config.arm_emulator_options}'
    lli_cmd = f'{emulation_cmd} {lli_cmd}'
    emulation_cmd = (
        f"{config.arm_emulator_executable} {config.arm_emulator_options}"
    )
    lli_cmd = f"{emulation_cmd} {lli_cmd}"

    return lli_cmd


aarch64_lli_cmd = configure_aarch64_lli_cmd()

# Configure the following AArch64 substitutions:
@@ -52,5 +63,5 @@ aarch64_lli_cmd = configure_aarch64_lli_cmd()
# could be used in the SparseTensor tests where necessary, but the meaning
# conveyed by the substitution name would be a misnomer if the host target
# is not AArch64 and MLIR_RUN_ARM_SVE_TESTS=OFF.
config.substitutions.append(('%lli_aarch64_cmd', aarch64_lli_cmd))
config.substitutions.append(('%lli_host_or_aarch64_cmd', aarch64_lli_cmd))
config.substitutions.append(("%lli_aarch64_cmd", aarch64_lli_cmd))
config.substitutions.append(("%lli_host_or_aarch64_cmd", aarch64_lli_cmd))

@@ -8,43 +8,43 @@ import subprocess
import lit.formats

# name: The name of this test suite.
config.name = 'MLIR-Unit'
config.name = "MLIR-Unit"

# suffixes: A list of file extensions to treat as test files.
config.suffixes = []

# test_source_root: The root path where tests are located.
# test_exec_root: The root path where tests should be run.
config.test_exec_root = os.path.join(config.mlir_obj_root, 'unittests')
config.test_exec_root = os.path.join(config.mlir_obj_root, "unittests")
config.test_source_root = config.test_exec_root

# testFormat: The test format to use to interpret tests.
config.test_format = lit.formats.GoogleTest(config.llvm_build_mode, 'Tests')
config.test_format = lit.formats.GoogleTest(config.llvm_build_mode, "Tests")

# Propagate the temp directory. Windows requires this because it uses \Windows\
# if none of these are present.
if 'TMP' in os.environ:
    config.environment['TMP'] = os.environ['TMP']
if 'TEMP' in os.environ:
    config.environment['TEMP'] = os.environ['TEMP']
if "TMP" in os.environ:
    config.environment["TMP"] = os.environ["TMP"]
if "TEMP" in os.environ:
    config.environment["TEMP"] = os.environ["TEMP"]

# Propagate HOME as it can be used to override incorrect homedir in passwd
# that causes the tests to fail.
if 'HOME' in os.environ:
    config.environment['HOME'] = os.environ['HOME']
if "HOME" in os.environ:
    config.environment["HOME"] = os.environ["HOME"]

# Propagate sanitizer options.
for var in [
    'ASAN_SYMBOLIZER_PATH',
    'HWASAN_SYMBOLIZER_PATH',
    'MSAN_SYMBOLIZER_PATH',
    'TSAN_SYMBOLIZER_PATH',
    'UBSAN_SYMBOLIZER_PATH',
    'ASAN_OPTIONS',
    'HWASAN_OPTIONS',
    'MSAN_OPTIONS',
    'TSAN_OPTIONS',
    'UBSAN_OPTIONS',
    "ASAN_SYMBOLIZER_PATH",
    "HWASAN_SYMBOLIZER_PATH",
    "MSAN_SYMBOLIZER_PATH",
    "TSAN_SYMBOLIZER_PATH",
    "UBSAN_SYMBOLIZER_PATH",
    "ASAN_OPTIONS",
    "HWASAN_OPTIONS",
    "MSAN_OPTIONS",
    "TSAN_OPTIONS",
    "UBSAN_OPTIONS",
]:
    if var in os.environ:
        config.environment[var] = os.environ[var]

@@ -1 +1 @@
config.suffixes.remove('.td')
config.suffixes.remove(".td")

@@ -1 +1 @@
config.suffixes.remove('.td')
config.suffixes.remove(".td")

@@ -1 +1 @@
config.suffixes.remove('.pdll')
config.suffixes.remove(".pdll")

@@ -1 +1 @@
config.suffixes.remove('.pdll')
config.suffixes.remove(".pdll")

@@ -16,21 +16,32 @@ from lit.llvm.subst import FindTool
# Configuration file for the 'lit' test runner.

# name: The name of this test suite.
config.name = 'MLIR'
config.name = "MLIR"

config.test_format = lit.formats.ShTest(not llvm_config.use_lit_shell)

# suffixes: A list of file extensions to treat as test files.
config.suffixes = ['.td', '.mlir', '.toy', '.ll', '.tc', '.py', '.yaml', '.test', '.pdll', '.c']
config.suffixes = [
    ".td",
    ".mlir",
    ".toy",
    ".ll",
    ".tc",
    ".py",
    ".yaml",
    ".test",
    ".pdll",
    ".c",
]

# test_source_root: The root path where tests are located.
config.test_source_root = os.path.dirname(__file__)

# test_exec_root: The root path where tests should be run.
config.test_exec_root = os.path.join(config.mlir_obj_root, 'test')
config.test_exec_root = os.path.join(config.mlir_obj_root, "test")

config.substitutions.append(('%PATH%', config.environment['PATH']))
config.substitutions.append(('%shlibext', config.llvm_shlib_ext))
config.substitutions.append(("%PATH%", config.environment["PATH"]))
config.substitutions.append(("%shlibext", config.llvm_shlib_ext))
config.substitutions.append(("%mlir_src_root", config.mlir_src_root))
config.substitutions.append(("%host_cxx", config.host_cxx))
config.substitutions.append(("%host_cc", config.host_cc))
@@ -40,94 +51,109 @@ config.substitutions.append(("%host_cc", config.host_cc))
# substitution of the same name and the found path.
# Correctly handles the platforms shared library directory and naming conventions.
def add_runtime(name):
    path = ''
    for prefix in ['', 'lib']:
        path = os.path.join(config.llvm_shlib_dir, f'{prefix}{name}{config.llvm_shlib_ext}')
    path = ""
    for prefix in ["", "lib"]:
        path = os.path.join(
            config.llvm_shlib_dir, f"{prefix}{name}{config.llvm_shlib_ext}"
        )
        if os.path.isfile(path):
            break
    return ToolSubst(f'%{name}', path)
    return ToolSubst(f"%{name}", path)


llvm_config.with_system_environment(
    ['HOME', 'INCLUDE', 'LIB', 'TMP', 'TEMP'])
llvm_config.with_system_environment(["HOME", "INCLUDE", "LIB", "TMP", "TEMP"])

llvm_config.use_default_substitutions()

# excludes: A list of directories to exclude from the testsuite. The 'Inputs'
# subdirectories contain auxiliary inputs for various tests in their parent
# directories.
config.excludes = ['Inputs', 'CMakeLists.txt', 'README.txt', 'LICENSE.txt',
                   'lit.cfg.py', 'lit.site.cfg.py']
config.excludes = [
    "Inputs",
    "CMakeLists.txt",
    "README.txt",
    "LICENSE.txt",
    "lit.cfg.py",
    "lit.site.cfg.py",
]

# Tweak the PATH to include the tools dir.
llvm_config.with_environment('PATH', config.mlir_tools_dir, append_path=True)
llvm_config.with_environment('PATH', config.llvm_tools_dir, append_path=True)
llvm_config.with_environment("PATH", config.mlir_tools_dir, append_path=True)
llvm_config.with_environment("PATH", config.llvm_tools_dir, append_path=True)

tool_dirs = [config.mlir_tools_dir, config.llvm_tools_dir]
tools = [
    'mlir-tblgen',
    'mlir-translate',
    'mlir-lsp-server',
    'mlir-capi-execution-engine-test',
    'mlir-capi-ir-test',
    'mlir-capi-llvm-test',
    'mlir-capi-pass-test',
    'mlir-capi-pdl-test',
    'mlir-capi-quant-test',
    'mlir-capi-sparse-tensor-test',
    'mlir-capi-transform-test',
    'mlir-cpu-runner',
    add_runtime('mlir_runner_utils'),
    add_runtime('mlir_c_runner_utils'),
    add_runtime('mlir_async_runtime'),
    'mlir-linalg-ods-yaml-gen',
    'mlir-reduce',
    'mlir-pdll',
    'not',
    "mlir-tblgen",
    "mlir-translate",
    "mlir-lsp-server",
    "mlir-capi-execution-engine-test",
    "mlir-capi-ir-test",
    "mlir-capi-llvm-test",
    "mlir-capi-pass-test",
    "mlir-capi-pdl-test",
    "mlir-capi-quant-test",
    "mlir-capi-sparse-tensor-test",
    "mlir-capi-transform-test",
    "mlir-cpu-runner",
    add_runtime("mlir_runner_utils"),
    add_runtime("mlir_c_runner_utils"),
    add_runtime("mlir_async_runtime"),
    "mlir-linalg-ods-yaml-gen",
    "mlir-reduce",
    "mlir-pdll",
    "not",
]

if config.enable_spirv_cpu_runner:
    tools.extend(['mlir-spirv-cpu-runner', add_runtime('mlir_test_spirv_cpu_runner_c_wrappers')])
    tools.extend(
        ["mlir-spirv-cpu-runner", add_runtime("mlir_test_spirv_cpu_runner_c_wrappers")]
    )

if config.enable_vulkan_runner:
    tools.extend([add_runtime('vulkan-runtime-wrappers')])
    tools.extend([add_runtime("vulkan-runtime-wrappers")])

if config.enable_rocm_runner:
    tools.extend([add_runtime('mlir_rocm_runtime')])
    tools.extend([add_runtime("mlir_rocm_runtime")])

if config.enable_cuda_runner:
    tools.extend([add_runtime('mlir_cuda_runtime')])
    tools.extend([add_runtime("mlir_cuda_runtime")])

# The following tools are optional
tools.extend([
    ToolSubst('toyc-ch1', unresolved='ignore'),
    ToolSubst('toyc-ch2', unresolved='ignore'),
    ToolSubst('toyc-ch3', unresolved='ignore'),
    ToolSubst('toyc-ch4', unresolved='ignore'),
    ToolSubst('toyc-ch5', unresolved='ignore'),
    ToolSubst('toyc-ch6', unresolved='ignore'),
    ToolSubst('toyc-ch7', unresolved='ignore'),
    ToolSubst('%mlir_lib_dir', config.mlir_lib_dir, unresolved='ignore'),
    ToolSubst('%mlir_src_dir', config.mlir_src_root, unresolved='ignore'),
])
tools.extend(
    [
        ToolSubst("toyc-ch1", unresolved="ignore"),
        ToolSubst("toyc-ch2", unresolved="ignore"),
        ToolSubst("toyc-ch3", unresolved="ignore"),
        ToolSubst("toyc-ch4", unresolved="ignore"),
        ToolSubst("toyc-ch5", unresolved="ignore"),
        ToolSubst("toyc-ch6", unresolved="ignore"),
        ToolSubst("toyc-ch7", unresolved="ignore"),
        ToolSubst("%mlir_lib_dir", config.mlir_lib_dir, unresolved="ignore"),
        ToolSubst("%mlir_src_dir", config.mlir_src_root, unresolved="ignore"),
    ]
)

python_executable = config.python_executable
# Python configuration with sanitizer requires some magic preloading. This will only work on clang/linux.
# TODO: detect Darwin/Windows situation (or mark these tests as unsupported on these platforms).
if "asan" in config.available_features and "Linux" in config.host_os:
  python_executable = f"LD_PRELOAD=$({config.host_cxx} -print-file-name=libclang_rt.asan-{config.host_arch}.so) {config.python_executable}"
    python_executable = f"LD_PRELOAD=$({config.host_cxx} -print-file-name=libclang_rt.asan-{config.host_arch}.so) {config.python_executable}"
# On Windows the path to python could contains spaces in which case it needs to be provided in quotes.
# This is the equivalent of how %python is setup in llvm/utils/lit/lit/llvm/config.py.
elif "Windows" in config.host_os:
  python_executable = '"%s"' % (python_executable)
tools.extend([
    ToolSubst('%PYTHON', python_executable, unresolved='ignore'),
])
    python_executable = '"%s"' % (python_executable)
tools.extend(
    [
        ToolSubst("%PYTHON", python_executable, unresolved="ignore"),
    ]
)

if "MLIR_OPT_CHECK_IR_ROUNDTRIP" in os.environ:
  tools.extend([
      ToolSubst('mlir-opt', 'mlir-opt --verify-roundtrip', unresolved='fatal'),
  ])
    tools.extend(
        [
            ToolSubst("mlir-opt", "mlir-opt --verify-roundtrip", unresolved="fatal"),
        ]
    )

llvm_config.add_tool_substitutions(tools, tool_dirs)

@@ -135,40 +161,48 @@ llvm_config.add_tool_substitutions(tools, tool_dirs)
# FileCheck -enable-var-scope is enabled by default in MLIR test
# This option avoids to accidentally reuse variable across -LABEL match,
# it can be explicitly opted-in by prefixing the variable name with $
config.environment['FILECHECK_OPTS'] = "-enable-var-scope --allow-unused-prefixes=false"
config.environment["FILECHECK_OPTS"] = "-enable-var-scope --allow-unused-prefixes=false"

# Add the python path for both the source and binary tree.
# Note that presently, the python sources come from the source tree and the
# binaries come from the build tree. This should be unified to the build tree
# by copying/linking sources to build.
if config.enable_bindings_python:
    llvm_config.with_environment('PYTHONPATH', [
        os.path.join(config.mlir_obj_root, 'python_packages', 'mlir_core'),
        os.path.join(config.mlir_obj_root, 'python_packages', 'mlir_test'),
    ], append_path=True)
    llvm_config.with_environment(
        "PYTHONPATH",
        [
            os.path.join(config.mlir_obj_root, "python_packages", "mlir_core"),
            os.path.join(config.mlir_obj_root, "python_packages", "mlir_test"),
        ],
        append_path=True,
    )

if config.enable_assertions:
    config.available_features.add('asserts')
    config.available_features.add("asserts")
else:
    config.available_features.add('noasserts')
    config.available_features.add("noasserts")


def have_host_jit_feature_support(feature_name):
  mlir_cpu_runner_exe = lit.util.which('mlir-cpu-runner', config.mlir_tools_dir)
    mlir_cpu_runner_exe = lit.util.which("mlir-cpu-runner", config.mlir_tools_dir)

  if not mlir_cpu_runner_exe:
    return False
    if not mlir_cpu_runner_exe:
        return False

  try:
    mlir_cpu_runner_cmd = subprocess.Popen(
        [mlir_cpu_runner_exe, '--host-supports-' + feature_name], stdout=subprocess.PIPE)
  except OSError:
    print('could not exec mlir-cpu-runner')
    return False
    try:
        mlir_cpu_runner_cmd = subprocess.Popen(
            [mlir_cpu_runner_exe, "--host-supports-" + feature_name],
            stdout=subprocess.PIPE,
        )
    except OSError:
        print("could not exec mlir-cpu-runner")
        return False

  mlir_cpu_runner_out = mlir_cpu_runner_cmd.stdout.read().decode('ascii')
  mlir_cpu_runner_cmd.wait()
    mlir_cpu_runner_out = mlir_cpu_runner_cmd.stdout.read().decode("ascii")
    mlir_cpu_runner_cmd.wait()

  return 'true' in mlir_cpu_runner_out
    return "true" in mlir_cpu_runner_out

if have_host_jit_feature_support('jit'):
  config.available_features.add('host-supports-jit')

if have_host_jit_feature_support("jit"):
    config.available_features.add("host-supports-jit")

@@ -1,12 +1,11 @@
|
||||
import sys
|
||||
|
||||
# MSAN does not work with JIT.
|
||||
if 'msan' in config.available_features:
|
||||
config.unsupported = True
|
||||
|
||||
# Requires native execution.
|
||||
if 'host-supports-jit' not in config.available_features:
|
||||
if "msan" in config.available_features:
|
||||
config.unsupported = True
|
||||
|
||||
config.available_features.add(
|
||||
config.root.native_target.lower() + '-native-target')
|
||||
# Requires native execution.
|
||||
if "host-supports-jit" not in config.available_features:
|
||||
config.unsupported = True
|
||||
|
||||
config.available_features.add(config.root.native_target.lower() + "-native-target")
|
||||
|
||||
@@ -1 +1 @@
|
||||
config.excludes = ['include']
|
||||
config.excludes = ["include"]
|
||||
|
||||
@@ -1,2 +1,2 @@
|
||||
config.suffixes = ['.pdll', '.mlir']
|
||||
config.excludes = ['include']
|
||||
config.suffixes = [".pdll", ".mlir"]
|
||||
config.excludes = ["include"]
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
import sys
|
||||
|
||||
if not config.enable_spirv_cpu_runner:
|
||||
config.unsupported = True
|
||||
config.unsupported = True
|
||||
|
||||
@@ -1,2 +1,2 @@
|
||||
if not config.enable_vulkan_runner:
|
||||
config.unsupported = True
|
||||
config.unsupported = True
|
||||
|
||||
@@ -14,5 +14,6 @@ expected_lib_name = "MLIRPythonCAPI"
|
||||
all_libs = os.listdir(get_lib_dirs()[0])
|
||||
found_lib = False
|
||||
for file_name in all_libs:
|
||||
if expected_lib_name in file_name: found_lib = True
|
||||
if expected_lib_name in file_name:
|
||||
found_lib = True
|
||||
assert found_lib, f"Did not find '{expected_lib_name}' lib in {all_libs}"
|
||||
|
||||
@@ -4,16 +4,18 @@ from mlir.ir import *
|
||||
import mlir.dialects.func as func
|
||||
import mlir.dialects.arith as arith
|
||||
|
||||
|
||||
def run(f):
|
||||
print("\nTEST:", f.__name__)
|
||||
f()
|
||||
print("\nTEST:", f.__name__)
|
||||
f()
|
||||
|
||||
|
||||
# CHECK-LABEL: TEST: testConstantOp
|
||||
@run
|
||||
def testConstantOps():
|
||||
with Context() as ctx, Location.unknown():
|
||||
module = Module.create()
|
||||
with InsertionPoint(module.body):
|
||||
arith.ConstantOp(value=42.42, result=F32Type.get())
|
||||
# CHECK: %cst = arith.constant 4.242000e+01 : f32
|
||||
print(module)
|
||||
with Context() as ctx, Location.unknown():
|
||||
module = Module.create()
|
||||
with InsertionPoint(module.body):
|
||||
arith.ConstantOp(value=42.42, result=F32Type.get())
|
||||
# CHECK: %cst = arith.constant 4.242000e+01 : f32
|
||||
print(module)
|
||||
|
||||
@@ -5,14 +5,17 @@ import mlir.dialects.async_dialect
|
||||
import mlir.dialects.async_dialect.passes
|
||||
from mlir.passmanager import *
|
||||
|
||||
|
||||
def run(f):
|
||||
print("\nTEST:", f.__name__)
|
||||
f()
|
||||
print("\nTEST:", f.__name__)
|
||||
f()
|
||||
|
||||
|
||||
def testAsyncPass():
|
||||
with Context() as context:
|
||||
PassManager.parse('any(async-to-async-runtime)')
|
||||
print('SUCCESS')
|
||||
with Context() as context:
|
||||
PassManager.parse("any(async-to-async-runtime)")
|
||||
print("SUCCESS")
|
||||
|
||||
|
||||
# CHECK-LABEL: testAsyncPass
|
||||
# CHECK: SUCCESS
|
||||
|
||||
@@ -7,232 +7,242 @@ import numpy as np
|
||||
|
||||
|
||||
def run(f):
|
||||
print("\nTEST:", f.__name__)
|
||||
f()
|
||||
return f
|
||||
print("\nTEST:", f.__name__)
|
||||
f()
|
||||
return f
|
||||
|
||||
|
||||
# CHECK-LABEL: TEST: testFromPyFunc
|
||||
@run
|
||||
def testFromPyFunc():
|
||||
with Context() as ctx, Location.unknown() as loc:
|
||||
ctx.allow_unregistered_dialects = True
|
||||
m = builtin.ModuleOp()
|
||||
f32 = F32Type.get()
|
||||
f64 = F64Type.get()
|
||||
with InsertionPoint(m.body):
|
||||
# CHECK-LABEL: func @unary_return(%arg0: f64) -> f64
|
||||
# CHECK: return %arg0 : f64
|
||||
@func.FuncOp.from_py_func(f64)
|
||||
def unary_return(a):
|
||||
return a
|
||||
with Context() as ctx, Location.unknown() as loc:
|
||||
ctx.allow_unregistered_dialects = True
|
||||
m = builtin.ModuleOp()
|
||||
f32 = F32Type.get()
|
||||
f64 = F64Type.get()
|
||||
with InsertionPoint(m.body):
|
||||
# CHECK-LABEL: func @unary_return(%arg0: f64) -> f64
|
||||
# CHECK: return %arg0 : f64
|
||||
@func.FuncOp.from_py_func(f64)
|
||||
def unary_return(a):
|
||||
return a
|
||||
|
||||
# CHECK-LABEL: func @binary_return(%arg0: f32, %arg1: f64) -> (f32, f64)
|
||||
# CHECK: return %arg0, %arg1 : f32, f64
|
||||
@func.FuncOp.from_py_func(f32, f64)
|
||||
def binary_return(a, b):
|
||||
return a, b
|
||||
# CHECK-LABEL: func @binary_return(%arg0: f32, %arg1: f64) -> (f32, f64)
|
||||
# CHECK: return %arg0, %arg1 : f32, f64
|
||||
@func.FuncOp.from_py_func(f32, f64)
|
||||
def binary_return(a, b):
|
||||
return a, b
|
||||
|
||||
# CHECK-LABEL: func @none_return(%arg0: f32, %arg1: f64)
|
||||
# CHECK: return
|
||||
@func.FuncOp.from_py_func(f32, f64)
|
||||
def none_return(a, b):
|
||||
pass
|
||||
# CHECK-LABEL: func @none_return(%arg0: f32, %arg1: f64)
|
||||
# CHECK: return
|
||||
@func.FuncOp.from_py_func(f32, f64)
|
||||
def none_return(a, b):
|
||||
pass
|
||||
|
||||
# CHECK-LABEL: func @call_unary
|
||||
# CHECK: %0 = call @unary_return(%arg0) : (f64) -> f64
|
||||
# CHECK: return %0 : f64
|
||||
@func.FuncOp.from_py_func(f64)
|
||||
def call_unary(a):
|
||||
return unary_return(a)
|
||||
# CHECK-LABEL: func @call_unary
|
||||
# CHECK: %0 = call @unary_return(%arg0) : (f64) -> f64
|
||||
# CHECK: return %0 : f64
|
||||
@func.FuncOp.from_py_func(f64)
|
||||
def call_unary(a):
|
||||
return unary_return(a)
|
||||
|
||||
# CHECK-LABEL: func @call_binary
|
||||
# CHECK: %0:2 = call @binary_return(%arg0, %arg1) : (f32, f64) -> (f32, f64)
|
||||
# CHECK: return %0#0, %0#1 : f32, f64
|
||||
@func.FuncOp.from_py_func(f32, f64)
|
||||
def call_binary(a, b):
|
||||
return binary_return(a, b)
|
||||
# CHECK-LABEL: func @call_binary
|
||||
# CHECK: %0:2 = call @binary_return(%arg0, %arg1) : (f32, f64) -> (f32, f64)
|
||||
# CHECK: return %0#0, %0#1 : f32, f64
|
||||
@func.FuncOp.from_py_func(f32, f64)
|
||||
def call_binary(a, b):
|
||||
return binary_return(a, b)
|
||||
|
||||
# We expect coercion of a single result operation to a returned value.
|
||||
# CHECK-LABEL: func @single_result_op
|
||||
# CHECK: %0 = "custom.op1"() : () -> f32
|
||||
# CHECK: return %0 : f32
|
||||
@func.FuncOp.from_py_func()
|
||||
def single_result_op():
|
||||
return Operation.create("custom.op1", results=[f32])
|
||||
# We expect coercion of a single result operation to a returned value.
|
||||
# CHECK-LABEL: func @single_result_op
|
||||
# CHECK: %0 = "custom.op1"() : () -> f32
|
||||
# CHECK: return %0 : f32
|
||||
@func.FuncOp.from_py_func()
|
||||
def single_result_op():
|
||||
return Operation.create("custom.op1", results=[f32])
|
||||
|
||||
# CHECK-LABEL: func @call_none
|
||||
# CHECK: call @none_return(%arg0, %arg1) : (f32, f64) -> ()
|
||||
# CHECK: return
|
||||
@func.FuncOp.from_py_func(f32, f64)
|
||||
def call_none(a, b):
|
||||
return none_return(a, b)
|
||||
# CHECK-LABEL: func @call_none
|
||||
# CHECK: call @none_return(%arg0, %arg1) : (f32, f64) -> ()
|
||||
# CHECK: return
|
||||
@func.FuncOp.from_py_func(f32, f64)
|
||||
def call_none(a, b):
|
||||
return none_return(a, b)

            ## Variants and optional feature tests.
            # CHECK-LABEL: func @from_name_arg
            @func.FuncOp.from_py_func(f32, f64, name="from_name_arg")
            def explicit_name(a, b):
                return b

            @func.FuncOp.from_py_func(f32, f64)
            def positional_func_op(a, b, func_op):
                assert isinstance(func_op, func.FuncOp)
                return b

            @func.FuncOp.from_py_func(f32, f64)
            def kw_func_op(a, b=None, func_op=None):
                assert isinstance(func_op, func.FuncOp)
                return b

            @func.FuncOp.from_py_func(f32, f64)
            def kwargs_func_op(a, b=None, **kwargs):
                assert isinstance(kwargs["func_op"], func.FuncOp)
                return b

            # CHECK-LABEL: func @explicit_results(%arg0: f32, %arg1: f64) -> f64
            # CHECK: return %arg1 : f64
            @func.FuncOp.from_py_func(f32, f64, results=[f64])
            def explicit_results(a, b):
                func.ReturnOp([b])

    print(m)


# CHECK-LABEL: TEST: testFromPyFuncErrors
@run
def testFromPyFuncErrors():
    with Context() as ctx, Location.unknown() as loc:
        m = builtin.ModuleOp()
        f32 = F32Type.get()
        f64 = F64Type.get()
        with InsertionPoint(m.body):
            try:

                @func.FuncOp.from_py_func(f64, results=[f64])
                def unary_return(a):
                    return a

            except AssertionError as e:
                # CHECK: Capturing a python function with explicit `results=` requires that the wrapped function returns None.
                print(e)
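            # With explicit `results=`, the wrapped body is expected to emit
            # its own terminator (see `explicit_results` above) and return
            # None; returning a value triggers the assertion printed here.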


# CHECK-LABEL: TEST: testBuildFuncOp
@run
def testBuildFuncOp():
    ctx = Context()
    with Location.unknown(ctx) as loc:
        m = builtin.ModuleOp()

        f32 = F32Type.get()
        tensor_type = RankedTensorType.get((2, 3, 4), f32)
        with InsertionPoint.at_block_begin(m.body):
            f = func.FuncOp(name="some_func",
                            type=FunctionType.get(
                                inputs=[tensor_type, tensor_type],
                                results=[tensor_type]),
                            visibility="nested")
            f = func.FuncOp(
                name="some_func",
                type=FunctionType.get(
                    inputs=[tensor_type, tensor_type], results=[tensor_type]
                ),
                visibility="nested",
            )
            # CHECK: Name is: "some_func"
            print("Name is: ", f.name)

            # CHECK: Type is: (tensor<2x3x4xf32>, tensor<2x3x4xf32>) -> tensor<2x3x4xf32>
            print("Type is: ", f.type)

            # CHECK: Visibility is: "nested"
            print("Visibility is: ", f.visibility)

            try:
                entry_block = f.entry_block
            except IndexError as e:
                # CHECK: External function does not have a body
                print(e)

            with InsertionPoint(f.add_entry_block()):
                func.ReturnOp([f.entry_block.arguments[0]])
                pass

            try:
                f.add_entry_block()
            except IndexError as e:
                # CHECK: The function already has an entry block!
                print(e)

            # Try the callback builder and passing type as tuple.
            f = func.FuncOp(name="some_other_func",
                            type=([tensor_type, tensor_type], [tensor_type]),
                            visibility="nested",
                            body_builder=lambda f: func.ReturnOp(
                                [f.entry_block.arguments[0]]))
            f = func.FuncOp(
                name="some_other_func",
                type=([tensor_type, tensor_type], [tensor_type]),
                visibility="nested",
                body_builder=lambda f: func.ReturnOp([f.entry_block.arguments[0]]),
            )
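            # `body_builder` is called with the newly created FuncOp inside an
            # insertion point at its entry block, so the lambda can emit the
            # `func.return` terminator directly.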

        # CHECK: module {
        # CHECK: func nested @some_func(%arg0: tensor<2x3x4xf32>, %arg1: tensor<2x3x4xf32>) -> tensor<2x3x4xf32> {
        # CHECK: return %arg0 : tensor<2x3x4xf32>
        # CHECK: }
        # CHECK: func nested @some_other_func(%arg0: tensor<2x3x4xf32>, %arg1: tensor<2x3x4xf32>) -> tensor<2x3x4xf32> {
        # CHECK: return %arg0 : tensor<2x3x4xf32>
        # CHECK: }
        print(m)


# CHECK-LABEL: TEST: testFuncArgumentAccess
@run
def testFuncArgumentAccess():
    with Context() as ctx, Location.unknown():
        ctx.allow_unregistered_dialects = True
        module = Module.create()
        f32 = F32Type.get()
        f64 = F64Type.get()
        with InsertionPoint(module.body):
            f = func.FuncOp("some_func", ([f32, f32], [f32, f32]))
            with InsertionPoint(f.add_entry_block()):
                func.ReturnOp(f.arguments)
            f.arg_attrs = ArrayAttr.get([
                DictAttr.get({
                    "custom_dialect.foo": StringAttr.get("bar"),
                    "custom_dialect.baz": UnitAttr.get()
                }),
                DictAttr.get({"custom_dialect.qux": ArrayAttr.get([])})
            ])
            f.result_attrs = ArrayAttr.get([
                DictAttr.get({"custom_dialect.res1": FloatAttr.get(f32, 42.0)}),
                DictAttr.get({"custom_dialect.res2": FloatAttr.get(f64, 256.0)})
            ])
            f.arg_attrs = ArrayAttr.get(
                [
                    DictAttr.get(
                        {
                            "custom_dialect.foo": StringAttr.get("bar"),
                            "custom_dialect.baz": UnitAttr.get(),
                        }
                    ),
                    DictAttr.get({"custom_dialect.qux": ArrayAttr.get([])}),
                ]
            )
            f.result_attrs = ArrayAttr.get(
                [
                    DictAttr.get({"custom_dialect.res1": FloatAttr.get(f32, 42.0)}),
                    DictAttr.get({"custom_dialect.res2": FloatAttr.get(f64, 256.0)}),
                ]
            )
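            # One DictAttr per argument/result; these land on the underlying
            # `func.func` operation as its `arg_attrs`/`res_attrs` attributes.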

            other = func.FuncOp("other_func", ([f32, f32], []))
            with InsertionPoint(other.add_entry_block()):
                func.ReturnOp([])
            other.arg_attrs = [
                DictAttr.get({"custom_dialect.foo": StringAttr.get("qux")}),
                DictAttr.get()
            ]
            other.arg_attrs = [
                DictAttr.get({"custom_dialect.foo": StringAttr.get("qux")}),
                DictAttr.get(),
            ]
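            # A plain Python list of DictAttr is assigned here (rather than an
            # ArrayAttr); the setter appears to accept either form.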

        # CHECK: [{custom_dialect.baz, custom_dialect.foo = "bar"}, {custom_dialect.qux = []}]
        print(f.arg_attrs)

        # CHECK: [{custom_dialect.res1 = 4.200000e+01 : f32}, {custom_dialect.res2 = 2.560000e+02 : f64}]
        print(f.result_attrs)

        # CHECK: func @some_func(
        # CHECK: %[[ARG0:.*]]: f32 {custom_dialect.baz, custom_dialect.foo = "bar"},
        # CHECK: %[[ARG1:.*]]: f32 {custom_dialect.qux = []}) ->
        # CHECK: f32 {custom_dialect.res1 = 4.200000e+01 : f32},
        # CHECK: f32 {custom_dialect.res2 = 2.560000e+02 : f64})
        # CHECK: return %[[ARG0]], %[[ARG1]] : f32, f32
        #
        # CHECK: func @other_func(
        # CHECK: %{{.*}}: f32 {custom_dialect.foo = "qux"},
        # CHECK: %{{.*}}: f32)
        print(module)


# CHECK-LABEL: testDenseElementsAttr
@run
def testDenseElementsAttr():
    with Context(), Location.unknown():
        values = np.arange(4, dtype=np.int32)
        i32 = IntegerType.get_signless(32)
        print(DenseElementsAttr.get(values, type=i32))
        # CHECK{LITERAL}: dense<[0, 1, 2, 3]> : tensor<4xi32>
        print(DenseElementsAttr.get(values, type=i32, shape=(2, 2)))
        # CHECK{LITERAL}: dense<[[0, 1], [2, 3]]> : tensor<2x2xi32>
        print(DenseElementsAttr.get(values, type=VectorType.get((2, 2), i32)))
        # CHECK{LITERAL}: dense<[[0, 1], [2, 3]]> : vector<2x2xi32>
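        # `type=` selects the element type (the shape is taken from the
        # array), `shape=` reinterprets the flat data, and passing a full
        # ShapedType such as a VectorType controls both at once.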


@@ -9,24 +9,24 @@ import mlir.dialects.complex as mlir_complex


def run(f):
    print("\nTEST:", f.__name__)
    f()


# CHECK-LABEL: TEST: testComplexOps
@run
def testComplexOps():
    with Context() as ctx, Location.unknown():
        module = Module.create()
        with InsertionPoint(module.body):

            @func.FuncOp.from_py_func(ComplexType.get(F32Type.get()))
            def emit_add(arg):
                return mlir_complex.AddOp(arg, arg)

        # CHECK-LABEL: func @emit_add(
        # CHECK-SAME: %[[ARG:.*]]: complex<f32>) -> complex<f32> {
        # CHECK: %[[RES:.*]] = complex.add %[[ARG]], %[[ARG]] : complex<f32>
        # CHECK: return %[[RES]] : complex<f32>
        # CHECK: }
        print(module)
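        # The single-result `complex.add` returned from the Python body is
        # coerced into the `func.return` operand, as with `single_result_op`
        # in the func tests above.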


@@ -7,13 +7,13 @@ from mlir.dialects import func


def constructAndPrintInModule(f):
    print("\nTEST:", f.__name__)
    with Context(), Location.unknown():
        module = Module.create()
        with InsertionPoint(module.body):
            f()
        print(module)
    return f
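
# Used as a decorator: each test body runs inside a fresh Context and Module,
# and the resulting module is printed for FileCheck to match.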


# CHECK-LABEL: TEST: testConstantOp
@@ -21,21 +21,21 @@ def constructAndPrintInModule(f):

@constructAndPrintInModule
def testConstantOp():
    c1 = arith.ConstantOp(IntegerType.get_signless(32), 42)
    c2 = arith.ConstantOp(IntegerType.get_signless(64), 100)
    c3 = arith.ConstantOp(F32Type.get(), 3.14)
    c4 = arith.ConstantOp(F64Type.get(), 1.23)
    # CHECK: 42
    print(c1.literal_value)

    # CHECK: 100
    print(c2.literal_value)

    # CHECK: 3.140000104904175
    print(c3.literal_value)

    # CHECK: 1.23
    print(c4.literal_value)
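    # 3.14 has no exact f32 representation; the nearest single-precision
    # value reads back as 3.140000104904175.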


# CHECK: = arith.constant 42 : i32
@@ -47,17 +47,17 @@ def testConstantOp():

# CHECK-LABEL: TEST: testVectorConstantOp
@constructAndPrintInModule
def testVectorConstantOp():
    int_type = IntegerType.get_signless(32)
    vec_type = VectorType.get([2, 2], int_type)
    c1 = arith.ConstantOp(
        vec_type,
        DenseElementsAttr.get_splat(vec_type, IntegerAttr.get(int_type, 42)))
    c1 = arith.ConstantOp(
        vec_type, DenseElementsAttr.get_splat(vec_type, IntegerAttr.get(int_type, 42))
    )
    try:
        print(c1.literal_value)
    except ValueError as e:
        assert "only integer and float constants have literal values" in str(e)
    else:
        assert False
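    # `literal_value` is only defined for scalar integer/float constants, so
    # the splat vector constant above raises ValueError, as asserted.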


# CHECK: = arith.constant dense<42> : vector<2x2xi32>
@@ -66,9 +66,9 @@ def testVectorConstantOp():

# CHECK-LABEL: TEST: testConstantIndexOp
@constructAndPrintInModule
def testConstantIndexOp():
    c1 = arith.ConstantOp.create_index(10)
    # CHECK: 10
    print(c1.literal_value)


# CHECK: = arith.constant 10 : index
@@ -77,18 +77,18 @@ def testConstantIndexOp():

# CHECK-LABEL: TEST: testFunctionCalls
@constructAndPrintInModule
def testFunctionCalls():
    foo = func.FuncOp("foo", ([], []))
    foo.sym_visibility = StringAttr.get("private")
    bar = func.FuncOp("bar", ([], [IndexType.get()]))
    bar.sym_visibility = StringAttr.get("private")
    qux = func.FuncOp("qux", ([], [F32Type.get()]))
    qux.sym_visibility = StringAttr.get("private")

    with InsertionPoint(func.FuncOp("caller", ([], [])).add_entry_block()):
        func.CallOp(foo, [])
        func.CallOp([IndexType.get()], "bar", [])
        func.CallOp([F32Type.get()], FlatSymbolRefAttr.get("qux"), [])
        func.ReturnOp([])
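    # `func.CallOp` resolves the callee from a FuncOp directly, from a plain
    # symbol name, or from a FlatSymbolRefAttr; the latter two forms also take
    # the result types explicitly.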


# CHECK: func private @foo()

Some files were not shown because too many files have changed in this diff