mirror of
https://github.com/intel/llvm.git
synced 2026-01-26 03:56:16 +08:00
[mlir][python] enable memref.subview (#79393)
This commit is contained in:
@@ -2,16 +2,30 @@
|
||||
# See https://llvm.org/LICENSE.txt for license information.
|
||||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
|
||||
# Provide a convenient name for sub-packages to resolve the main C-extension
|
||||
# with a relative import.
|
||||
from .._mlir_libs import _mlir as _cext
|
||||
from typing import (
|
||||
List as _List,
|
||||
Optional as _Optional,
|
||||
Sequence as _Sequence,
|
||||
Tuple as _Tuple,
|
||||
Type as _Type,
|
||||
TypeVar as _TypeVar,
|
||||
Union as _Union,
|
||||
)
|
||||
|
||||
from .._mlir_libs import _mlir as _cext
|
||||
from ..ir import (
|
||||
ArrayAttr,
|
||||
Attribute,
|
||||
BoolAttr,
|
||||
DenseI64ArrayAttr,
|
||||
IntegerAttr,
|
||||
IntegerType,
|
||||
OpView,
|
||||
Operation,
|
||||
ShapedType,
|
||||
Value,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"equally_sized_accessor",
|
||||
"get_default_loc_context",
|
||||
@@ -138,3 +152,157 @@ SubClassValueT = _Type[_U]
|
||||
ResultValueTypeTuple = _cext.ir.Operation, _cext.ir.OpView, _cext.ir.Value
|
||||
ResultValueT = _Union[ResultValueTypeTuple]
|
||||
VariadicResultValueT = _Union[ResultValueT, _Sequence[ResultValueT]]
|
||||
|
||||
StaticIntLike = _Union[int, IntegerAttr]
|
||||
ValueLike = _Union[Operation, OpView, Value]
|
||||
MixedInt = _Union[StaticIntLike, ValueLike]
|
||||
|
||||
IntOrAttrList = _Sequence[_Union[IntegerAttr, int]]
|
||||
OptionalIntList = _Optional[_Union[ArrayAttr, IntOrAttrList]]
|
||||
|
||||
BoolOrAttrList = _Sequence[_Union[BoolAttr, bool]]
|
||||
OptionalBoolList = _Optional[_Union[ArrayAttr, BoolOrAttrList]]
|
||||
|
||||
MixedValues = _Union[_Sequence[_Union[StaticIntLike, ValueLike]], ArrayAttr, ValueLike]
|
||||
|
||||
DynamicIndexList = _Sequence[_Union[MixedInt, _Sequence[MixedInt]]]
|
||||
|
||||
|
||||
def _dispatch_dynamic_index_list(
|
||||
indices: _Union[DynamicIndexList, ArrayAttr],
|
||||
) -> _Tuple[_List[ValueLike], _Union[_List[int], ArrayAttr], _List[bool]]:
|
||||
"""Dispatches a list of indices to the appropriate form.
|
||||
|
||||
This is similar to the custom `DynamicIndexList` directive upstream:
|
||||
provided indices may be in the form of dynamic SSA values or static values,
|
||||
and they may be scalable (i.e., as a singleton list) or not. This function
|
||||
dispatches each index into its respective form. It also extracts the SSA
|
||||
values and static indices from various similar structures, respectively.
|
||||
"""
|
||||
dynamic_indices = []
|
||||
static_indices = [ShapedType.get_dynamic_size()] * len(indices)
|
||||
scalable_indices = [False] * len(indices)
|
||||
|
||||
# ArrayAttr: Extract index values.
|
||||
if isinstance(indices, ArrayAttr):
|
||||
indices = [idx for idx in indices]
|
||||
|
||||
def process_nonscalable_index(i, index):
|
||||
"""Processes any form of non-scalable index.
|
||||
|
||||
Returns False if the given index was scalable and thus remains
|
||||
unprocessed; True otherwise.
|
||||
"""
|
||||
if isinstance(index, int):
|
||||
static_indices[i] = index
|
||||
elif isinstance(index, IntegerAttr):
|
||||
static_indices[i] = index.value # pytype: disable=attribute-error
|
||||
elif isinstance(index, (Operation, Value, OpView)):
|
||||
dynamic_indices.append(index)
|
||||
else:
|
||||
return False
|
||||
return True
|
||||
|
||||
# Process each index at a time.
|
||||
for i, index in enumerate(indices):
|
||||
if not process_nonscalable_index(i, index):
|
||||
# If it wasn't processed, it must be a scalable index, which is
|
||||
# provided as a _Sequence of one value, so extract and process that.
|
||||
scalable_indices[i] = True
|
||||
assert len(index) == 1
|
||||
ret = process_nonscalable_index(i, index[0])
|
||||
assert ret
|
||||
|
||||
return dynamic_indices, static_indices, scalable_indices
|
||||
|
||||
|
||||
# Dispatches `MixedValues` that all represents integers in various forms into
|
||||
# the following three categories:
|
||||
# - `dynamic_values`: a list of `Value`s, potentially from op results;
|
||||
# - `packed_values`: a value handle, potentially from an op result, associated
|
||||
# to one or more payload operations of integer type;
|
||||
# - `static_values`: an `ArrayAttr` of `i64`s with static values, from Python
|
||||
# `int`s, from `IntegerAttr`s, or from an `ArrayAttr`.
|
||||
# The input is in the form for `packed_values`, only that result is set and the
|
||||
# other two are empty. Otherwise, the input can be a mix of the other two forms,
|
||||
# and for each dynamic value, a special value is added to the `static_values`.
|
||||
def _dispatch_mixed_values(
|
||||
values: MixedValues,
|
||||
) -> _Tuple[_List[Value], _Union[Operation, Value, OpView], DenseI64ArrayAttr]:
|
||||
dynamic_values = []
|
||||
packed_values = None
|
||||
static_values = None
|
||||
if isinstance(values, ArrayAttr):
|
||||
static_values = values
|
||||
elif isinstance(values, (Operation, Value, OpView)):
|
||||
packed_values = values
|
||||
else:
|
||||
static_values = []
|
||||
for size in values or []:
|
||||
if isinstance(size, int):
|
||||
static_values.append(size)
|
||||
else:
|
||||
static_values.append(ShapedType.get_dynamic_size())
|
||||
dynamic_values.append(size)
|
||||
static_values = DenseI64ArrayAttr.get(static_values)
|
||||
|
||||
return (dynamic_values, packed_values, static_values)
|
||||
|
||||
|
||||
def _get_value_or_attribute_value(
|
||||
value_or_attr: _Union[any, Attribute, ArrayAttr]
|
||||
) -> any:
|
||||
if isinstance(value_or_attr, Attribute) and hasattr(value_or_attr, "value"):
|
||||
return value_or_attr.value
|
||||
if isinstance(value_or_attr, ArrayAttr):
|
||||
return _get_value_list(value_or_attr)
|
||||
return value_or_attr
|
||||
|
||||
|
||||
def _get_value_list(
|
||||
sequence_or_array_attr: _Union[_Sequence[any], ArrayAttr]
|
||||
) -> _Sequence[any]:
|
||||
return [_get_value_or_attribute_value(v) for v in sequence_or_array_attr]
|
||||
|
||||
|
||||
def _get_int_array_attr(
|
||||
values: _Optional[_Union[ArrayAttr, IntOrAttrList]]
|
||||
) -> ArrayAttr:
|
||||
if values is None:
|
||||
return None
|
||||
|
||||
# Turn into a Python list of Python ints.
|
||||
values = _get_value_list(values)
|
||||
|
||||
# Make an ArrayAttr of IntegerAttrs out of it.
|
||||
return ArrayAttr.get(
|
||||
[IntegerAttr.get(IntegerType.get_signless(64), v) for v in values]
|
||||
)
|
||||
|
||||
|
||||
def _get_int_array_array_attr(
|
||||
values: _Optional[_Union[ArrayAttr, _Sequence[_Union[ArrayAttr, IntOrAttrList]]]]
|
||||
) -> ArrayAttr:
|
||||
"""Creates an ArrayAttr of ArrayAttrs of IntegerAttrs.
|
||||
|
||||
The input has to be a collection of a collection of integers, where any
|
||||
Python _Sequence and ArrayAttr are admissible collections and Python ints and
|
||||
any IntegerAttr are admissible integers. Both levels of collections are
|
||||
turned into ArrayAttr; the inner level is turned into IntegerAttrs of i64s.
|
||||
If the input is None, an empty ArrayAttr is returned.
|
||||
"""
|
||||
if values is None:
|
||||
return None
|
||||
|
||||
# Make sure the outer level is a list.
|
||||
values = _get_value_list(values)
|
||||
|
||||
# The inner level is now either invalid or a mixed sequence of ArrayAttrs and
|
||||
# Sequences. Make sure the nested values are all lists.
|
||||
values = [_get_value_list(nested) for nested in values]
|
||||
|
||||
# Turn each nested list into an ArrayAttr.
|
||||
values = [_get_int_array_attr(nested) for nested in values]
|
||||
|
||||
# Turn the outer list into an ArrayAttr.
|
||||
return ArrayAttr.get(values)
|
||||
|
||||
@@ -1,5 +1,135 @@
|
||||
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
# See https://llvm.org/LICENSE.txt for license information.
|
||||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
import operator
|
||||
from itertools import accumulate
|
||||
from typing import Optional
|
||||
|
||||
from ._memref_ops_gen import *
|
||||
from ._ods_common import _dispatch_mixed_values, MixedValues
|
||||
from .arith import ConstantOp, _is_integer_like_type
|
||||
from ..ir import Value, MemRefType, StridedLayoutAttr, ShapedType
|
||||
|
||||
|
||||
def _is_constant_int_like(i):
|
||||
return (
|
||||
isinstance(i, Value)
|
||||
and isinstance(i.owner.opview, ConstantOp)
|
||||
and _is_integer_like_type(i.type)
|
||||
)
|
||||
|
||||
|
||||
def _is_static_int_like(i):
|
||||
return (
|
||||
isinstance(i, int) and not ShapedType.is_dynamic_size(i)
|
||||
) or _is_constant_int_like(i)
|
||||
|
||||
|
||||
def _infer_memref_subview_result_type(
|
||||
source_memref_type, offsets, static_sizes, static_strides
|
||||
):
|
||||
source_strides, source_offset = source_memref_type.get_strides_and_offset()
|
||||
# "canonicalize" from tuple|list -> list
|
||||
offsets, static_sizes, static_strides, source_strides = map(
|
||||
list, (offsets, static_sizes, static_strides, source_strides)
|
||||
)
|
||||
|
||||
if not all(
|
||||
all(_is_static_int_like(i) for i in s)
|
||||
for s in [
|
||||
static_sizes,
|
||||
static_strides,
|
||||
source_strides,
|
||||
]
|
||||
):
|
||||
raise ValueError(
|
||||
"Only inferring from python or mlir integer constant is supported."
|
||||
)
|
||||
|
||||
for s in [offsets, static_sizes, static_strides]:
|
||||
for idx, i in enumerate(s):
|
||||
if _is_constant_int_like(i):
|
||||
s[idx] = i.owner.opview.literal_value
|
||||
|
||||
if any(not _is_static_int_like(i) for i in offsets + [source_offset]):
|
||||
target_offset = ShapedType.get_dynamic_size()
|
||||
else:
|
||||
target_offset = source_offset
|
||||
for offset, target_stride in zip(offsets, source_strides):
|
||||
target_offset += offset * target_stride
|
||||
|
||||
target_strides = []
|
||||
for source_stride, static_stride in zip(source_strides, static_strides):
|
||||
target_strides.append(source_stride * static_stride)
|
||||
|
||||
# If default striding then no need to complicate things for downstream ops (e.g., expand_shape).
|
||||
default_strides = list(accumulate(static_sizes[1:][::-1], operator.mul))[::-1] + [1]
|
||||
if target_strides == default_strides and target_offset == 0:
|
||||
layout = None
|
||||
else:
|
||||
layout = StridedLayoutAttr.get(target_offset, target_strides)
|
||||
return (
|
||||
offsets,
|
||||
static_sizes,
|
||||
static_strides,
|
||||
MemRefType.get(
|
||||
static_sizes,
|
||||
source_memref_type.element_type,
|
||||
layout,
|
||||
source_memref_type.memory_space,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
_generated_subview = subview
|
||||
|
||||
|
||||
def subview(
|
||||
source: Value,
|
||||
offsets: MixedValues,
|
||||
sizes: MixedValues,
|
||||
strides: MixedValues,
|
||||
*,
|
||||
result_type: Optional[MemRefType] = None,
|
||||
loc=None,
|
||||
ip=None,
|
||||
):
|
||||
if offsets is None:
|
||||
offsets = []
|
||||
if sizes is None:
|
||||
sizes = []
|
||||
if strides is None:
|
||||
strides = []
|
||||
source_strides, source_offset = source.type.get_strides_and_offset()
|
||||
if result_type is None and all(
|
||||
all(_is_static_int_like(i) for i in s) for s in [sizes, strides, source_strides]
|
||||
):
|
||||
# If any are arith.constant results then this will canonicalize to python int
|
||||
# (which can then be used to fully specify the subview).
|
||||
(
|
||||
offsets,
|
||||
sizes,
|
||||
strides,
|
||||
result_type,
|
||||
) = _infer_memref_subview_result_type(source.type, offsets, sizes, strides)
|
||||
elif result_type is None:
|
||||
raise ValueError(
|
||||
"mixed static/dynamic offset/sizes/strides requires explicit result type."
|
||||
)
|
||||
|
||||
offsets, _packed_offsets, static_offsets = _dispatch_mixed_values(offsets)
|
||||
sizes, _packed_sizes, static_sizes = _dispatch_mixed_values(sizes)
|
||||
strides, _packed_strides, static_strides = _dispatch_mixed_values(strides)
|
||||
|
||||
return _generated_subview(
|
||||
result_type,
|
||||
source,
|
||||
offsets,
|
||||
sizes,
|
||||
strides,
|
||||
static_offsets,
|
||||
static_sizes,
|
||||
static_strides,
|
||||
loc=loc,
|
||||
ip=ip,
|
||||
)
|
||||
|
||||
@@ -9,163 +9,24 @@ from .._structured_transform_enum_gen import *
|
||||
try:
|
||||
from ...ir import *
|
||||
from ...dialects import transform
|
||||
from .._ods_common import _cext as _ods_cext
|
||||
from .._ods_common import (
|
||||
DynamicIndexList,
|
||||
IntOrAttrList,
|
||||
MixedValues,
|
||||
OptionalBoolList,
|
||||
OptionalIntList,
|
||||
_cext as _ods_cext,
|
||||
_dispatch_dynamic_index_list,
|
||||
_dispatch_mixed_values,
|
||||
_get_int_array_array_attr,
|
||||
_get_int_array_attr,
|
||||
_get_value_list,
|
||||
_get_value_or_attribute_value,
|
||||
)
|
||||
except ImportError as e:
|
||||
raise RuntimeError("Error loading imports from extension module") from e
|
||||
|
||||
from typing import List, Optional, Sequence, Tuple, Union, overload
|
||||
|
||||
StaticIntLike = Union[int, IntegerAttr]
|
||||
ValueLike = Union[Operation, OpView, Value]
|
||||
MixedInt = Union[StaticIntLike, ValueLike]
|
||||
|
||||
IntOrAttrList = Sequence[Union[IntegerAttr, int]]
|
||||
OptionalIntList = Optional[Union[ArrayAttr, IntOrAttrList]]
|
||||
|
||||
BoolOrAttrList = Sequence[Union[BoolAttr, bool]]
|
||||
OptionalBoolList = Optional[Union[ArrayAttr, BoolOrAttrList]]
|
||||
|
||||
MixedValues = Union[Sequence[Union[StaticIntLike, ValueLike]], ArrayAttr, ValueLike]
|
||||
|
||||
DynamicIndexList = Sequence[Union[MixedInt, Sequence[MixedInt]]]
|
||||
|
||||
|
||||
def _dispatch_dynamic_index_list(
|
||||
indices: Union[DynamicIndexList, ArrayAttr],
|
||||
) -> Tuple[List[ValueLike], Union[List[int], ArrayAttr], List[bool]]:
|
||||
"""Dispatches a list of indices to the appropriate form.
|
||||
|
||||
This is similar to the custom `DynamicIndexList` directive upstream:
|
||||
provided indices may be in the form of dynamic SSA values or static values,
|
||||
and they may be scalable (i.e., as a singleton list) or not. This function
|
||||
dispatches each index into its respective form. It also extracts the SSA
|
||||
values and static indices from various similar structures, respectively.
|
||||
"""
|
||||
dynamic_indices = []
|
||||
static_indices = [ShapedType.get_dynamic_size()] * len(indices)
|
||||
scalable_indices = [False] * len(indices)
|
||||
|
||||
# ArrayAttr: Extract index values.
|
||||
if isinstance(indices, ArrayAttr):
|
||||
indices = [idx for idx in indices]
|
||||
|
||||
def process_nonscalable_index(i, index):
|
||||
"""Processes any form of non-scalable index.
|
||||
|
||||
Returns False if the given index was scalable and thus remains
|
||||
unprocessed; True otherwise.
|
||||
"""
|
||||
if isinstance(index, int):
|
||||
static_indices[i] = index
|
||||
elif isinstance(index, IntegerAttr):
|
||||
static_indices[i] = index.value # pytype: disable=attribute-error
|
||||
elif isinstance(index, (Operation, Value, OpView)):
|
||||
dynamic_indices.append(index)
|
||||
else:
|
||||
return False
|
||||
return True
|
||||
|
||||
# Process each index at a time.
|
||||
for i, index in enumerate(indices):
|
||||
if not process_nonscalable_index(i, index):
|
||||
# If it wasn't processed, it must be a scalable index, which is
|
||||
# provided as a Sequence of one value, so extract and process that.
|
||||
scalable_indices[i] = True
|
||||
assert len(index) == 1
|
||||
ret = process_nonscalable_index(i, index[0])
|
||||
assert ret
|
||||
|
||||
return dynamic_indices, static_indices, scalable_indices
|
||||
|
||||
|
||||
# Dispatches `MixedValues` that all represents integers in various forms into
|
||||
# the following three categories:
|
||||
# - `dynamic_values`: a list of `Value`s, potentially from op results;
|
||||
# - `packed_values`: a value handle, potentially from an op result, associated
|
||||
# to one or more payload operations of integer type;
|
||||
# - `static_values`: an `ArrayAttr` of `i64`s with static values, from Python
|
||||
# `int`s, from `IntegerAttr`s, or from an `ArrayAttr`.
|
||||
# The input is in the form for `packed_values`, only that result is set and the
|
||||
# other two are empty. Otherwise, the input can be a mix of the other two forms,
|
||||
# and for each dynamic value, a special value is added to the `static_values`.
|
||||
def _dispatch_mixed_values(
|
||||
values: MixedValues,
|
||||
) -> Tuple[List[Value], Union[Operation, Value, OpView], DenseI64ArrayAttr]:
|
||||
dynamic_values = []
|
||||
packed_values = None
|
||||
static_values = None
|
||||
if isinstance(values, ArrayAttr):
|
||||
static_values = values
|
||||
elif isinstance(values, (Operation, Value, OpView)):
|
||||
packed_values = values
|
||||
else:
|
||||
static_values = []
|
||||
for size in values or []:
|
||||
if isinstance(size, int):
|
||||
static_values.append(size)
|
||||
else:
|
||||
static_values.append(ShapedType.get_dynamic_size())
|
||||
dynamic_values.append(size)
|
||||
static_values = DenseI64ArrayAttr.get(static_values)
|
||||
|
||||
return (dynamic_values, packed_values, static_values)
|
||||
|
||||
|
||||
def _get_value_or_attribute_value(
|
||||
value_or_attr: Union[any, Attribute, ArrayAttr]
|
||||
) -> any:
|
||||
if isinstance(value_or_attr, Attribute) and hasattr(value_or_attr, "value"):
|
||||
return value_or_attr.value
|
||||
if isinstance(value_or_attr, ArrayAttr):
|
||||
return _get_value_list(value_or_attr)
|
||||
return value_or_attr
|
||||
|
||||
|
||||
def _get_value_list(
|
||||
sequence_or_array_attr: Union[Sequence[any], ArrayAttr]
|
||||
) -> Sequence[any]:
|
||||
return [_get_value_or_attribute_value(v) for v in sequence_or_array_attr]
|
||||
|
||||
|
||||
def _get_int_array_attr(values: Optional[Union[ArrayAttr, IntOrAttrList]]) -> ArrayAttr:
|
||||
if values is None:
|
||||
return None
|
||||
|
||||
# Turn into a Python list of Python ints.
|
||||
values = _get_value_list(values)
|
||||
|
||||
# Make an ArrayAttr of IntegerAttrs out of it.
|
||||
return ArrayAttr.get(
|
||||
[IntegerAttr.get(IntegerType.get_signless(64), v) for v in values]
|
||||
)
|
||||
|
||||
|
||||
def _get_int_array_array_attr(
|
||||
values: Optional[Union[ArrayAttr, Sequence[Union[ArrayAttr, IntOrAttrList]]]]
|
||||
) -> ArrayAttr:
|
||||
"""Creates an ArrayAttr of ArrayAttrs of IntegerAttrs.
|
||||
|
||||
The input has to be a collection of collection of integers, where any
|
||||
Python Sequence and ArrayAttr are admissible collections and Python ints and
|
||||
any IntegerAttr are admissible integers. Both levels of collections are
|
||||
turned into ArrayAttr; the inner level is turned into IntegerAttrs of i64s.
|
||||
If the input is None, an empty ArrayAttr is returned.
|
||||
"""
|
||||
if values is None:
|
||||
return None
|
||||
|
||||
# Make sure the outer level is a list.
|
||||
values = _get_value_list(values)
|
||||
|
||||
# The inner level is now either invalid or a mixed sequence of ArrayAttrs and
|
||||
# Sequences. Make sure the nested values are all lists.
|
||||
values = [_get_value_list(nested) for nested in values]
|
||||
|
||||
# Turn each nested list into an ArrayAttr.
|
||||
values = [_get_int_array_attr(nested) for nested in values]
|
||||
|
||||
# Turn the outer list into an ArrayAttr.
|
||||
return ArrayAttr.get(values)
|
||||
from typing import List, Optional, Sequence, Union, overload
|
||||
|
||||
|
||||
@_ods_cext.register_operation(_Dialect, replace=True)
|
||||
|
||||
Reference in New Issue
Block a user