[mlir][python] enable memref.subview (#79393)

This commit is contained in:
Maksim Levental
2024-01-30 16:21:56 -06:00
committed by GitHub
parent a356e6ccad
commit 404af14f92
7 changed files with 516 additions and 159 deletions

View File

@@ -2,16 +2,30 @@
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
# Provide a convenient name for sub-packages to resolve the main C-extension
# with a relative import.
from .._mlir_libs import _mlir as _cext
from typing import (
List as _List,
Optional as _Optional,
Sequence as _Sequence,
Tuple as _Tuple,
Type as _Type,
TypeVar as _TypeVar,
Union as _Union,
)
from .._mlir_libs import _mlir as _cext
from ..ir import (
ArrayAttr,
Attribute,
BoolAttr,
DenseI64ArrayAttr,
IntegerAttr,
IntegerType,
OpView,
Operation,
ShapedType,
Value,
)
__all__ = [
"equally_sized_accessor",
"get_default_loc_context",
@@ -138,3 +152,157 @@ SubClassValueT = _Type[_U]
ResultValueTypeTuple = _cext.ir.Operation, _cext.ir.OpView, _cext.ir.Value
ResultValueT = _Union[ResultValueTypeTuple]
VariadicResultValueT = _Union[ResultValueT, _Sequence[ResultValueT]]
StaticIntLike = _Union[int, IntegerAttr]
ValueLike = _Union[Operation, OpView, Value]
MixedInt = _Union[StaticIntLike, ValueLike]
IntOrAttrList = _Sequence[_Union[IntegerAttr, int]]
OptionalIntList = _Optional[_Union[ArrayAttr, IntOrAttrList]]
BoolOrAttrList = _Sequence[_Union[BoolAttr, bool]]
OptionalBoolList = _Optional[_Union[ArrayAttr, BoolOrAttrList]]
MixedValues = _Union[_Sequence[_Union[StaticIntLike, ValueLike]], ArrayAttr, ValueLike]
DynamicIndexList = _Sequence[_Union[MixedInt, _Sequence[MixedInt]]]
def _dispatch_dynamic_index_list(
indices: _Union[DynamicIndexList, ArrayAttr],
) -> _Tuple[_List[ValueLike], _Union[_List[int], ArrayAttr], _List[bool]]:
"""Dispatches a list of indices to the appropriate form.
This is similar to the custom `DynamicIndexList` directive upstream:
provided indices may be in the form of dynamic SSA values or static values,
and they may be scalable (i.e., as a singleton list) or not. This function
dispatches each index into its respective form. It also extracts the SSA
values and static indices from various similar structures, respectively.
"""
dynamic_indices = []
static_indices = [ShapedType.get_dynamic_size()] * len(indices)
scalable_indices = [False] * len(indices)
# ArrayAttr: Extract index values.
if isinstance(indices, ArrayAttr):
indices = [idx for idx in indices]
def process_nonscalable_index(i, index):
"""Processes any form of non-scalable index.
Returns False if the given index was scalable and thus remains
unprocessed; True otherwise.
"""
if isinstance(index, int):
static_indices[i] = index
elif isinstance(index, IntegerAttr):
static_indices[i] = index.value # pytype: disable=attribute-error
elif isinstance(index, (Operation, Value, OpView)):
dynamic_indices.append(index)
else:
return False
return True
# Process each index at a time.
for i, index in enumerate(indices):
if not process_nonscalable_index(i, index):
# If it wasn't processed, it must be a scalable index, which is
# provided as a _Sequence of one value, so extract and process that.
scalable_indices[i] = True
assert len(index) == 1
ret = process_nonscalable_index(i, index[0])
assert ret
return dynamic_indices, static_indices, scalable_indices
# Dispatches `MixedValues` that all represents integers in various forms into
# the following three categories:
# - `dynamic_values`: a list of `Value`s, potentially from op results;
# - `packed_values`: a value handle, potentially from an op result, associated
# to one or more payload operations of integer type;
# - `static_values`: an `ArrayAttr` of `i64`s with static values, from Python
# `int`s, from `IntegerAttr`s, or from an `ArrayAttr`.
# The input is in the form for `packed_values`, only that result is set and the
# other two are empty. Otherwise, the input can be a mix of the other two forms,
# and for each dynamic value, a special value is added to the `static_values`.
def _dispatch_mixed_values(
values: MixedValues,
) -> _Tuple[_List[Value], _Union[Operation, Value, OpView], DenseI64ArrayAttr]:
dynamic_values = []
packed_values = None
static_values = None
if isinstance(values, ArrayAttr):
static_values = values
elif isinstance(values, (Operation, Value, OpView)):
packed_values = values
else:
static_values = []
for size in values or []:
if isinstance(size, int):
static_values.append(size)
else:
static_values.append(ShapedType.get_dynamic_size())
dynamic_values.append(size)
static_values = DenseI64ArrayAttr.get(static_values)
return (dynamic_values, packed_values, static_values)
def _get_value_or_attribute_value(
value_or_attr: _Union[any, Attribute, ArrayAttr]
) -> any:
if isinstance(value_or_attr, Attribute) and hasattr(value_or_attr, "value"):
return value_or_attr.value
if isinstance(value_or_attr, ArrayAttr):
return _get_value_list(value_or_attr)
return value_or_attr
def _get_value_list(
sequence_or_array_attr: _Union[_Sequence[any], ArrayAttr]
) -> _Sequence[any]:
return [_get_value_or_attribute_value(v) for v in sequence_or_array_attr]
def _get_int_array_attr(
values: _Optional[_Union[ArrayAttr, IntOrAttrList]]
) -> ArrayAttr:
if values is None:
return None
# Turn into a Python list of Python ints.
values = _get_value_list(values)
# Make an ArrayAttr of IntegerAttrs out of it.
return ArrayAttr.get(
[IntegerAttr.get(IntegerType.get_signless(64), v) for v in values]
)
def _get_int_array_array_attr(
values: _Optional[_Union[ArrayAttr, _Sequence[_Union[ArrayAttr, IntOrAttrList]]]]
) -> ArrayAttr:
"""Creates an ArrayAttr of ArrayAttrs of IntegerAttrs.
The input has to be a collection of a collection of integers, where any
Python _Sequence and ArrayAttr are admissible collections and Python ints and
any IntegerAttr are admissible integers. Both levels of collections are
turned into ArrayAttr; the inner level is turned into IntegerAttrs of i64s.
If the input is None, an empty ArrayAttr is returned.
"""
if values is None:
return None
# Make sure the outer level is a list.
values = _get_value_list(values)
# The inner level is now either invalid or a mixed sequence of ArrayAttrs and
# Sequences. Make sure the nested values are all lists.
values = [_get_value_list(nested) for nested in values]
# Turn each nested list into an ArrayAttr.
values = [_get_int_array_attr(nested) for nested in values]
# Turn the outer list into an ArrayAttr.
return ArrayAttr.get(values)

View File

@@ -1,5 +1,135 @@
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
import operator
from itertools import accumulate
from typing import Optional
from ._memref_ops_gen import *
from ._ods_common import _dispatch_mixed_values, MixedValues
from .arith import ConstantOp, _is_integer_like_type
from ..ir import Value, MemRefType, StridedLayoutAttr, ShapedType
def _is_constant_int_like(i):
return (
isinstance(i, Value)
and isinstance(i.owner.opview, ConstantOp)
and _is_integer_like_type(i.type)
)
def _is_static_int_like(i):
return (
isinstance(i, int) and not ShapedType.is_dynamic_size(i)
) or _is_constant_int_like(i)
def _infer_memref_subview_result_type(
source_memref_type, offsets, static_sizes, static_strides
):
source_strides, source_offset = source_memref_type.get_strides_and_offset()
# "canonicalize" from tuple|list -> list
offsets, static_sizes, static_strides, source_strides = map(
list, (offsets, static_sizes, static_strides, source_strides)
)
if not all(
all(_is_static_int_like(i) for i in s)
for s in [
static_sizes,
static_strides,
source_strides,
]
):
raise ValueError(
"Only inferring from python or mlir integer constant is supported."
)
for s in [offsets, static_sizes, static_strides]:
for idx, i in enumerate(s):
if _is_constant_int_like(i):
s[idx] = i.owner.opview.literal_value
if any(not _is_static_int_like(i) for i in offsets + [source_offset]):
target_offset = ShapedType.get_dynamic_size()
else:
target_offset = source_offset
for offset, target_stride in zip(offsets, source_strides):
target_offset += offset * target_stride
target_strides = []
for source_stride, static_stride in zip(source_strides, static_strides):
target_strides.append(source_stride * static_stride)
# If default striding then no need to complicate things for downstream ops (e.g., expand_shape).
default_strides = list(accumulate(static_sizes[1:][::-1], operator.mul))[::-1] + [1]
if target_strides == default_strides and target_offset == 0:
layout = None
else:
layout = StridedLayoutAttr.get(target_offset, target_strides)
return (
offsets,
static_sizes,
static_strides,
MemRefType.get(
static_sizes,
source_memref_type.element_type,
layout,
source_memref_type.memory_space,
),
)
_generated_subview = subview
def subview(
source: Value,
offsets: MixedValues,
sizes: MixedValues,
strides: MixedValues,
*,
result_type: Optional[MemRefType] = None,
loc=None,
ip=None,
):
if offsets is None:
offsets = []
if sizes is None:
sizes = []
if strides is None:
strides = []
source_strides, source_offset = source.type.get_strides_and_offset()
if result_type is None and all(
all(_is_static_int_like(i) for i in s) for s in [sizes, strides, source_strides]
):
# If any are arith.constant results then this will canonicalize to python int
# (which can then be used to fully specify the subview).
(
offsets,
sizes,
strides,
result_type,
) = _infer_memref_subview_result_type(source.type, offsets, sizes, strides)
elif result_type is None:
raise ValueError(
"mixed static/dynamic offset/sizes/strides requires explicit result type."
)
offsets, _packed_offsets, static_offsets = _dispatch_mixed_values(offsets)
sizes, _packed_sizes, static_sizes = _dispatch_mixed_values(sizes)
strides, _packed_strides, static_strides = _dispatch_mixed_values(strides)
return _generated_subview(
result_type,
source,
offsets,
sizes,
strides,
static_offsets,
static_sizes,
static_strides,
loc=loc,
ip=ip,
)

View File

@@ -9,163 +9,24 @@ from .._structured_transform_enum_gen import *
try:
from ...ir import *
from ...dialects import transform
from .._ods_common import _cext as _ods_cext
from .._ods_common import (
DynamicIndexList,
IntOrAttrList,
MixedValues,
OptionalBoolList,
OptionalIntList,
_cext as _ods_cext,
_dispatch_dynamic_index_list,
_dispatch_mixed_values,
_get_int_array_array_attr,
_get_int_array_attr,
_get_value_list,
_get_value_or_attribute_value,
)
except ImportError as e:
raise RuntimeError("Error loading imports from extension module") from e
from typing import List, Optional, Sequence, Tuple, Union, overload
StaticIntLike = Union[int, IntegerAttr]
ValueLike = Union[Operation, OpView, Value]
MixedInt = Union[StaticIntLike, ValueLike]
IntOrAttrList = Sequence[Union[IntegerAttr, int]]
OptionalIntList = Optional[Union[ArrayAttr, IntOrAttrList]]
BoolOrAttrList = Sequence[Union[BoolAttr, bool]]
OptionalBoolList = Optional[Union[ArrayAttr, BoolOrAttrList]]
MixedValues = Union[Sequence[Union[StaticIntLike, ValueLike]], ArrayAttr, ValueLike]
DynamicIndexList = Sequence[Union[MixedInt, Sequence[MixedInt]]]
def _dispatch_dynamic_index_list(
indices: Union[DynamicIndexList, ArrayAttr],
) -> Tuple[List[ValueLike], Union[List[int], ArrayAttr], List[bool]]:
"""Dispatches a list of indices to the appropriate form.
This is similar to the custom `DynamicIndexList` directive upstream:
provided indices may be in the form of dynamic SSA values or static values,
and they may be scalable (i.e., as a singleton list) or not. This function
dispatches each index into its respective form. It also extracts the SSA
values and static indices from various similar structures, respectively.
"""
dynamic_indices = []
static_indices = [ShapedType.get_dynamic_size()] * len(indices)
scalable_indices = [False] * len(indices)
# ArrayAttr: Extract index values.
if isinstance(indices, ArrayAttr):
indices = [idx for idx in indices]
def process_nonscalable_index(i, index):
"""Processes any form of non-scalable index.
Returns False if the given index was scalable and thus remains
unprocessed; True otherwise.
"""
if isinstance(index, int):
static_indices[i] = index
elif isinstance(index, IntegerAttr):
static_indices[i] = index.value # pytype: disable=attribute-error
elif isinstance(index, (Operation, Value, OpView)):
dynamic_indices.append(index)
else:
return False
return True
# Process each index at a time.
for i, index in enumerate(indices):
if not process_nonscalable_index(i, index):
# If it wasn't processed, it must be a scalable index, which is
# provided as a Sequence of one value, so extract and process that.
scalable_indices[i] = True
assert len(index) == 1
ret = process_nonscalable_index(i, index[0])
assert ret
return dynamic_indices, static_indices, scalable_indices
# Dispatches `MixedValues` that all represents integers in various forms into
# the following three categories:
# - `dynamic_values`: a list of `Value`s, potentially from op results;
# - `packed_values`: a value handle, potentially from an op result, associated
# to one or more payload operations of integer type;
# - `static_values`: an `ArrayAttr` of `i64`s with static values, from Python
# `int`s, from `IntegerAttr`s, or from an `ArrayAttr`.
# The input is in the form for `packed_values`, only that result is set and the
# other two are empty. Otherwise, the input can be a mix of the other two forms,
# and for each dynamic value, a special value is added to the `static_values`.
def _dispatch_mixed_values(
values: MixedValues,
) -> Tuple[List[Value], Union[Operation, Value, OpView], DenseI64ArrayAttr]:
dynamic_values = []
packed_values = None
static_values = None
if isinstance(values, ArrayAttr):
static_values = values
elif isinstance(values, (Operation, Value, OpView)):
packed_values = values
else:
static_values = []
for size in values or []:
if isinstance(size, int):
static_values.append(size)
else:
static_values.append(ShapedType.get_dynamic_size())
dynamic_values.append(size)
static_values = DenseI64ArrayAttr.get(static_values)
return (dynamic_values, packed_values, static_values)
def _get_value_or_attribute_value(
value_or_attr: Union[any, Attribute, ArrayAttr]
) -> any:
if isinstance(value_or_attr, Attribute) and hasattr(value_or_attr, "value"):
return value_or_attr.value
if isinstance(value_or_attr, ArrayAttr):
return _get_value_list(value_or_attr)
return value_or_attr
def _get_value_list(
sequence_or_array_attr: Union[Sequence[any], ArrayAttr]
) -> Sequence[any]:
return [_get_value_or_attribute_value(v) for v in sequence_or_array_attr]
def _get_int_array_attr(values: Optional[Union[ArrayAttr, IntOrAttrList]]) -> ArrayAttr:
if values is None:
return None
# Turn into a Python list of Python ints.
values = _get_value_list(values)
# Make an ArrayAttr of IntegerAttrs out of it.
return ArrayAttr.get(
[IntegerAttr.get(IntegerType.get_signless(64), v) for v in values]
)
def _get_int_array_array_attr(
values: Optional[Union[ArrayAttr, Sequence[Union[ArrayAttr, IntOrAttrList]]]]
) -> ArrayAttr:
"""Creates an ArrayAttr of ArrayAttrs of IntegerAttrs.
The input has to be a collection of collection of integers, where any
Python Sequence and ArrayAttr are admissible collections and Python ints and
any IntegerAttr are admissible integers. Both levels of collections are
turned into ArrayAttr; the inner level is turned into IntegerAttrs of i64s.
If the input is None, an empty ArrayAttr is returned.
"""
if values is None:
return None
# Make sure the outer level is a list.
values = _get_value_list(values)
# The inner level is now either invalid or a mixed sequence of ArrayAttrs and
# Sequences. Make sure the nested values are all lists.
values = [_get_value_list(nested) for nested in values]
# Turn each nested list into an ArrayAttr.
values = [_get_int_array_attr(nested) for nested in values]
# Turn the outer list into an ArrayAttr.
return ArrayAttr.get(values)
from typing import List, Optional, Sequence, Union, overload
@_ods_cext.register_operation(_Dialect, replace=True)