[mlir][python] enable memref.subview (#79393)
@@ -408,6 +408,12 @@ MLIR_CAPI_EXPORTED MlirAffineMap mlirMemRefTypeGetAffineMap(MlirType type);
/// Returns the memory space of the given MemRef type.
MLIR_CAPI_EXPORTED MlirAttribute mlirMemRefTypeGetMemorySpace(MlirType type);

/// Returns the strides of the MemRef if the layout map is in strided form.
/// Both strides and offset are out params. strides must point to pre-allocated
/// memory of length equal to the rank of the memref.
MLIR_CAPI_EXPORTED MlirLogicalResult mlirMemRefTypeGetStridesAndOffset(
    MlirType type, int64_t *strides, int64_t *offset);

/// Returns the memory space of the given Unranked MemRef type.
MLIR_CAPI_EXPORTED MlirAttribute
mlirUnrankedMemrefGetMemorySpace(MlirType type);

@@ -12,6 +12,8 @@
#include "mlir-c/BuiltinAttributes.h"
#include "mlir-c/BuiltinTypes.h"
#include "mlir-c/Support.h"

#include <optional>

namespace py = pybind11;

@@ -618,6 +620,18 @@ public:
            return mlirMemRefTypeGetLayout(self);
          },
          "The layout of the MemRef type.")
      .def(
          "get_strides_and_offset",
          [](PyMemRefType &self) -> std::pair<std::vector<int64_t>, int64_t> {
            std::vector<int64_t> strides(mlirShapedTypeGetRank(self));
            int64_t offset;
            if (mlirLogicalResultIsFailure(mlirMemRefTypeGetStridesAndOffset(
                    self, strides.data(), &offset)))
              throw std::runtime_error(
                  "Failed to extract strides and offset from memref.");
            return {strides, offset};
          },
          "The strides and offset of the MemRef type.")
      .def_property_readonly(
          "affine_map",
          [](PyMemRefType &self) -> PyAffineMap {
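
A minimal usage sketch of the new binding from Python (assuming the MLIR Python bindings are built and importable as `mlir`):

from mlir.ir import Context, Type, MemRefType

with Context():
    # Cast a parsed strided memref type and query the newly bound method.
    t = MemRefType(Type.parse("memref<2x4xf32, strided<[4, 1], offset: 8>>"))
    strides, offset = t.get_strides_and_offset()
    print(strides, offset)  # expected: [4, 1] 8
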
@@ -9,12 +9,16 @@
#include "mlir-c/BuiltinTypes.h"
#include "mlir-c/AffineMap.h"
#include "mlir-c/IR.h"
#include "mlir-c/Support.h"
#include "mlir/CAPI/AffineMap.h"
#include "mlir/CAPI/IR.h"
#include "mlir/CAPI/Support.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Types.h"
#include "mlir/Support/LogicalResult.h"

#include <algorithm>

using namespace mlir;

@@ -426,6 +430,18 @@ MlirAttribute mlirMemRefTypeGetMemorySpace(MlirType type) {
  return wrap(llvm::cast<MemRefType>(unwrap(type)).getMemorySpace());
}

MlirLogicalResult mlirMemRefTypeGetStridesAndOffset(MlirType type,
                                                    int64_t *strides,
                                                    int64_t *offset) {
  MemRefType memrefType = llvm::cast<MemRefType>(unwrap(type));
  SmallVector<int64_t> strides_;
  if (failed(getStridesAndOffset(memrefType, strides_, *offset)))
    return mlirLogicalResultFailure();

  (void)std::copy(strides_.begin(), strides_.end(), strides);
  return mlirLogicalResultSuccess();
}

MlirTypeID mlirUnrankedMemRefTypeGetTypeID() {
  return wrap(UnrankedMemRefType::getTypeID());
}

@@ -2,16 +2,30 @@
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

# Provide a convenient name for sub-packages to resolve the main C-extension
# with a relative import.
from .._mlir_libs import _mlir as _cext
from typing import (
    List as _List,
    Optional as _Optional,
    Sequence as _Sequence,
    Tuple as _Tuple,
    Type as _Type,
    TypeVar as _TypeVar,
    Union as _Union,
)

from .._mlir_libs import _mlir as _cext
from ..ir import (
    ArrayAttr,
    Attribute,
    BoolAttr,
    DenseI64ArrayAttr,
    IntegerAttr,
    IntegerType,
    OpView,
    Operation,
    ShapedType,
    Value,
)

__all__ = [
    "equally_sized_accessor",
    "get_default_loc_context",
@@ -138,3 +152,157 @@ SubClassValueT = _Type[_U]
ResultValueTypeTuple = _cext.ir.Operation, _cext.ir.OpView, _cext.ir.Value
ResultValueT = _Union[ResultValueTypeTuple]
VariadicResultValueT = _Union[ResultValueT, _Sequence[ResultValueT]]

StaticIntLike = _Union[int, IntegerAttr]
ValueLike = _Union[Operation, OpView, Value]
MixedInt = _Union[StaticIntLike, ValueLike]

IntOrAttrList = _Sequence[_Union[IntegerAttr, int]]
OptionalIntList = _Optional[_Union[ArrayAttr, IntOrAttrList]]

BoolOrAttrList = _Sequence[_Union[BoolAttr, bool]]
OptionalBoolList = _Optional[_Union[ArrayAttr, BoolOrAttrList]]

MixedValues = _Union[_Sequence[_Union[StaticIntLike, ValueLike]], ArrayAttr, ValueLike]

DynamicIndexList = _Sequence[_Union[MixedInt, _Sequence[MixedInt]]]


def _dispatch_dynamic_index_list(
    indices: _Union[DynamicIndexList, ArrayAttr],
) -> _Tuple[_List[ValueLike], _Union[_List[int], ArrayAttr], _List[bool]]:
    """Dispatches a list of indices to the appropriate form.

    This is similar to the custom `DynamicIndexList` directive upstream:
    provided indices may be in the form of dynamic SSA values or static values,
    and they may be scalable (i.e., as a singleton list) or not. This function
    dispatches each index into its respective form. It also extracts the SSA
    values and static indices from various similar structures, respectively.
    """
    dynamic_indices = []
    static_indices = [ShapedType.get_dynamic_size()] * len(indices)
    scalable_indices = [False] * len(indices)

    # ArrayAttr: Extract index values.
    if isinstance(indices, ArrayAttr):
        indices = [idx for idx in indices]

    def process_nonscalable_index(i, index):
        """Processes any form of non-scalable index.

        Returns False if the given index was scalable and thus remains
        unprocessed; True otherwise.
        """
        if isinstance(index, int):
            static_indices[i] = index
        elif isinstance(index, IntegerAttr):
            static_indices[i] = index.value  # pytype: disable=attribute-error
        elif isinstance(index, (Operation, Value, OpView)):
            dynamic_indices.append(index)
        else:
            return False
        return True

    # Process each index at a time.
    for i, index in enumerate(indices):
        if not process_nonscalable_index(i, index):
            # If it wasn't processed, it must be a scalable index, which is
            # provided as a _Sequence of one value, so extract and process that.
            scalable_indices[i] = True
            assert len(index) == 1
            ret = process_nonscalable_index(i, index[0])
            assert ret

    return dynamic_indices, static_indices, scalable_indices
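
A small sketch of the dispatch in action (assuming the bindings are importable as `mlir`; `_dispatch_dynamic_index_list` is a private helper, used here purely for illustration):

from mlir.ir import Context, Location, Module, InsertionPoint, ShapedType
from mlir.dialects import arith
from mlir.dialects._ods_common import _dispatch_dynamic_index_list
import mlir.extras.types as T

with Context(), Location.unknown():
    module = Module.create()
    with InsertionPoint(module.body):
        v = arith.constant(T.index(), 5)  # a dynamic (SSA) index
        dyn, static, scalable = _dispatch_dynamic_index_list([2, [4], v])
        # dyn contains only the SSA value: [v]
        assert static == [2, 4, ShapedType.get_dynamic_size()]
        assert scalable == [False, True, False]  # [4] was a singleton list, i.e. scalable
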

# Dispatches `MixedValues` that all represent integers in various forms into
# the following three categories:
#   - `dynamic_values`: a list of `Value`s, potentially from op results;
#   - `packed_values`: a value handle, potentially from an op result, associated
#     to one or more payload operations of integer type;
#   - `static_values`: an `ArrayAttr` of `i64`s with static values, from Python
#     `int`s, from `IntegerAttr`s, or from an `ArrayAttr`.
# If the input is in the `packed_values` form, only that result is set and the
# other two are empty. Otherwise, the input can be a mix of the other two forms,
# and for each dynamic value, a special value is added to the `static_values`.
def _dispatch_mixed_values(
    values: MixedValues,
) -> _Tuple[_List[Value], _Union[Operation, Value, OpView], DenseI64ArrayAttr]:
    dynamic_values = []
    packed_values = None
    static_values = None
    if isinstance(values, ArrayAttr):
        static_values = values
    elif isinstance(values, (Operation, Value, OpView)):
        packed_values = values
    else:
        static_values = []
        for size in values or []:
            if isinstance(size, int):
                static_values.append(size)
            else:
                static_values.append(ShapedType.get_dynamic_size())
                dynamic_values.append(size)
        static_values = DenseI64ArrayAttr.get(static_values)

    return (dynamic_values, packed_values, static_values)
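
For example, an all-static input produces an empty `dynamic_values` list, no `packed_values`, and a `DenseI64ArrayAttr` of the static values (a sketch; an active `Context` is required for attribute creation, and the helper is private):

from mlir.ir import Context
from mlir.dialects._ods_common import _dispatch_mixed_values

with Context():
    dynamic, packed, static = _dispatch_mixed_values([2, 7])
    assert dynamic == [] and packed is None
    print(static)  # prints the attribute, e.g. array<i64: 2, 7>
    # A mix containing SSA values would populate `dynamic` and mark the
    # corresponding slots in `static` with ShapedType.get_dynamic_size().
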

def _get_value_or_attribute_value(
    value_or_attr: _Union[any, Attribute, ArrayAttr]
) -> any:
    if isinstance(value_or_attr, Attribute) and hasattr(value_or_attr, "value"):
        return value_or_attr.value
    if isinstance(value_or_attr, ArrayAttr):
        return _get_value_list(value_or_attr)
    return value_or_attr


def _get_value_list(
    sequence_or_array_attr: _Union[_Sequence[any], ArrayAttr]
) -> _Sequence[any]:
    return [_get_value_or_attribute_value(v) for v in sequence_or_array_attr]


def _get_int_array_attr(
    values: _Optional[_Union[ArrayAttr, IntOrAttrList]]
) -> ArrayAttr:
    if values is None:
        return None

    # Turn into a Python list of Python ints.
    values = _get_value_list(values)

    # Make an ArrayAttr of IntegerAttrs out of it.
    return ArrayAttr.get(
        [IntegerAttr.get(IntegerType.get_signless(64), v) for v in values]
    )


def _get_int_array_array_attr(
    values: _Optional[_Union[ArrayAttr, _Sequence[_Union[ArrayAttr, IntOrAttrList]]]]
) -> ArrayAttr:
    """Creates an ArrayAttr of ArrayAttrs of IntegerAttrs.

    The input has to be a collection of collections of integers, where any
    Python _Sequence and ArrayAttr are admissible collections and Python ints and
    any IntegerAttr are admissible integers. Both levels of collections are
    turned into ArrayAttr; the inner level is turned into IntegerAttrs of i64s.
    If the input is None, an empty ArrayAttr is returned.
    """
    if values is None:
        return None

    # Make sure the outer level is a list.
    values = _get_value_list(values)

    # The inner level is now either invalid or a mixed sequence of ArrayAttrs and
    # Sequences. Make sure the nested values are all lists.
    values = [_get_value_list(nested) for nested in values]

    # Turn each nested list into an ArrayAttr.
    values = [_get_int_array_attr(nested) for nested in values]

    # Turn the outer list into an ArrayAttr.
    return ArrayAttr.get(values)
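
A short sketch of the nested conversion (illustrative only; `_get_int_array_array_attr` is a private helper and needs an active `Context`):

from mlir.ir import Context
from mlir.dialects._ods_common import _get_int_array_array_attr

with Context():
    attr = _get_int_array_array_attr([[1, 2], [3]])
    # attr is an ArrayAttr of ArrayAttrs whose elements are i64 IntegerAttrs.
    print(attr)
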
@@ -1,5 +1,135 @@
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
import operator
from itertools import accumulate
from typing import Optional

from ._memref_ops_gen import *
from ._ods_common import _dispatch_mixed_values, MixedValues
from .arith import ConstantOp, _is_integer_like_type
from ..ir import Value, MemRefType, StridedLayoutAttr, ShapedType


def _is_constant_int_like(i):
    return (
        isinstance(i, Value)
        and isinstance(i.owner.opview, ConstantOp)
        and _is_integer_like_type(i.type)
    )


def _is_static_int_like(i):
    return (
        isinstance(i, int) and not ShapedType.is_dynamic_size(i)
    ) or _is_constant_int_like(i)


def _infer_memref_subview_result_type(
    source_memref_type, offsets, static_sizes, static_strides
):
    source_strides, source_offset = source_memref_type.get_strides_and_offset()
    # "canonicalize" from tuple|list -> list
    offsets, static_sizes, static_strides, source_strides = map(
        list, (offsets, static_sizes, static_strides, source_strides)
    )

    if not all(
        all(_is_static_int_like(i) for i in s)
        for s in [
            static_sizes,
            static_strides,
            source_strides,
        ]
    ):
        raise ValueError(
            "Only inferring from python or mlir integer constant is supported."
        )

    for s in [offsets, static_sizes, static_strides]:
        for idx, i in enumerate(s):
            if _is_constant_int_like(i):
                s[idx] = i.owner.opview.literal_value

    if any(not _is_static_int_like(i) for i in offsets + [source_offset]):
        target_offset = ShapedType.get_dynamic_size()
    else:
        target_offset = source_offset
        for offset, target_stride in zip(offsets, source_strides):
            target_offset += offset * target_stride

    target_strides = []
    for source_stride, static_stride in zip(source_strides, static_strides):
        target_strides.append(source_stride * static_stride)

    # If default striding then no need to complicate things for downstream ops (e.g., expand_shape).
    default_strides = list(accumulate(static_sizes[1:][::-1], operator.mul))[::-1] + [1]
    if target_strides == default_strides and target_offset == 0:
        layout = None
    else:
        layout = StridedLayoutAttr.get(target_offset, target_strides)
    return (
        offsets,
        static_sizes,
        static_strides,
        MemRefType.get(
            static_sizes,
            source_memref_type.element_type,
            layout,
            source_memref_type.memory_space,
        ),
    )
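
As a worked example (consistent with the first test case further below): for a source memref<10x10xi32>, get_strides_and_offset() returns strides [10, 1] and offset 0. With offsets=[1, 1], static_sizes=[3, 3], and static_strides=[1, 1], the inference computes

    target_offset  = 0 + 1*10 + 1*1 = 11
    target_strides = [10*1, 1*1] = [10, 1]

Since [10, 1] differs from the default row-major strides [3, 1] for a 3x3 shape, the result carries an explicit layout: memref<3x3xi32, strided<[10, 1], offset: 11>>.
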

_generated_subview = subview


def subview(
    source: Value,
    offsets: MixedValues,
    sizes: MixedValues,
    strides: MixedValues,
    *,
    result_type: Optional[MemRefType] = None,
    loc=None,
    ip=None,
):
    if offsets is None:
        offsets = []
    if sizes is None:
        sizes = []
    if strides is None:
        strides = []
    source_strides, source_offset = source.type.get_strides_and_offset()
    if result_type is None and all(
        all(_is_static_int_like(i) for i in s) for s in [sizes, strides, source_strides]
    ):
        # If any are arith.constant results then this will canonicalize to python int
        # (which can then be used to fully specify the subview).
        (
            offsets,
            sizes,
            strides,
            result_type,
        ) = _infer_memref_subview_result_type(source.type, offsets, sizes, strides)
    elif result_type is None:
        raise ValueError(
            "mixed static/dynamic offset/sizes/strides requires explicit result type."
        )

    offsets, _packed_offsets, static_offsets = _dispatch_mixed_values(offsets)
    sizes, _packed_sizes, static_sizes = _dispatch_mixed_values(sizes)
    strides, _packed_strides, static_strides = _dispatch_mixed_values(strides)

    return _generated_subview(
        result_type,
        source,
        offsets,
        sizes,
        strides,
        static_offsets,
        static_sizes,
        static_strides,
        loc=loc,
        ip=ip,
    )
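
A minimal usage sketch of the wrapper, mirroring the first test case below (assuming the bindings are importable as `mlir`):

from mlir.ir import Context, Location, Module, InsertionPoint
import mlir.dialects.memref as memref
import mlir.extras.types as T

with Context(), Location.unknown():
    module = Module.create()
    with InsertionPoint(module.body):
        x = memref.alloc(T.memref(10, 10, T.i32()), [], [])
        # Fully static offsets/sizes/strides, so the result type is inferred.
        y = memref.subview(x, [1, 1], [3, 3], [1, 1])
        print(y.type)  # memref<3x3xi32, strided<[10, 1], offset: 11>>
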
@@ -9,163 +9,24 @@ from .._structured_transform_enum_gen import *
try:
    from ...ir import *
    from ...dialects import transform
    from .._ods_common import _cext as _ods_cext
    from .._ods_common import (
        DynamicIndexList,
        IntOrAttrList,
        MixedValues,
        OptionalBoolList,
        OptionalIntList,
        _cext as _ods_cext,
        _dispatch_dynamic_index_list,
        _dispatch_mixed_values,
        _get_int_array_array_attr,
        _get_int_array_attr,
        _get_value_list,
        _get_value_or_attribute_value,
    )
except ImportError as e:
    raise RuntimeError("Error loading imports from extension module") from e

from typing import List, Optional, Sequence, Tuple, Union, overload

StaticIntLike = Union[int, IntegerAttr]
ValueLike = Union[Operation, OpView, Value]
MixedInt = Union[StaticIntLike, ValueLike]

IntOrAttrList = Sequence[Union[IntegerAttr, int]]
OptionalIntList = Optional[Union[ArrayAttr, IntOrAttrList]]

BoolOrAttrList = Sequence[Union[BoolAttr, bool]]
OptionalBoolList = Optional[Union[ArrayAttr, BoolOrAttrList]]

MixedValues = Union[Sequence[Union[StaticIntLike, ValueLike]], ArrayAttr, ValueLike]

DynamicIndexList = Sequence[Union[MixedInt, Sequence[MixedInt]]]


def _dispatch_dynamic_index_list(
    indices: Union[DynamicIndexList, ArrayAttr],
) -> Tuple[List[ValueLike], Union[List[int], ArrayAttr], List[bool]]:
    """Dispatches a list of indices to the appropriate form.

    This is similar to the custom `DynamicIndexList` directive upstream:
    provided indices may be in the form of dynamic SSA values or static values,
    and they may be scalable (i.e., as a singleton list) or not. This function
    dispatches each index into its respective form. It also extracts the SSA
    values and static indices from various similar structures, respectively.
    """
    dynamic_indices = []
    static_indices = [ShapedType.get_dynamic_size()] * len(indices)
    scalable_indices = [False] * len(indices)

    # ArrayAttr: Extract index values.
    if isinstance(indices, ArrayAttr):
        indices = [idx for idx in indices]

    def process_nonscalable_index(i, index):
        """Processes any form of non-scalable index.

        Returns False if the given index was scalable and thus remains
        unprocessed; True otherwise.
        """
        if isinstance(index, int):
            static_indices[i] = index
        elif isinstance(index, IntegerAttr):
            static_indices[i] = index.value  # pytype: disable=attribute-error
        elif isinstance(index, (Operation, Value, OpView)):
            dynamic_indices.append(index)
        else:
            return False
        return True

    # Process each index at a time.
    for i, index in enumerate(indices):
        if not process_nonscalable_index(i, index):
            # If it wasn't processed, it must be a scalable index, which is
            # provided as a Sequence of one value, so extract and process that.
            scalable_indices[i] = True
            assert len(index) == 1
            ret = process_nonscalable_index(i, index[0])
            assert ret

    return dynamic_indices, static_indices, scalable_indices


# Dispatches `MixedValues` that all represents integers in various forms into
# the following three categories:
#   - `dynamic_values`: a list of `Value`s, potentially from op results;
#   - `packed_values`: a value handle, potentially from an op result, associated
#     to one or more payload operations of integer type;
#   - `static_values`: an `ArrayAttr` of `i64`s with static values, from Python
#     `int`s, from `IntegerAttr`s, or from an `ArrayAttr`.
# The input is in the form for `packed_values`, only that result is set and the
# other two are empty. Otherwise, the input can be a mix of the other two forms,
# and for each dynamic value, a special value is added to the `static_values`.
def _dispatch_mixed_values(
    values: MixedValues,
) -> Tuple[List[Value], Union[Operation, Value, OpView], DenseI64ArrayAttr]:
    dynamic_values = []
    packed_values = None
    static_values = None
    if isinstance(values, ArrayAttr):
        static_values = values
    elif isinstance(values, (Operation, Value, OpView)):
        packed_values = values
    else:
        static_values = []
        for size in values or []:
            if isinstance(size, int):
                static_values.append(size)
            else:
                static_values.append(ShapedType.get_dynamic_size())
                dynamic_values.append(size)
        static_values = DenseI64ArrayAttr.get(static_values)

    return (dynamic_values, packed_values, static_values)


def _get_value_or_attribute_value(
    value_or_attr: Union[any, Attribute, ArrayAttr]
) -> any:
    if isinstance(value_or_attr, Attribute) and hasattr(value_or_attr, "value"):
        return value_or_attr.value
    if isinstance(value_or_attr, ArrayAttr):
        return _get_value_list(value_or_attr)
    return value_or_attr


def _get_value_list(
    sequence_or_array_attr: Union[Sequence[any], ArrayAttr]
) -> Sequence[any]:
    return [_get_value_or_attribute_value(v) for v in sequence_or_array_attr]


def _get_int_array_attr(values: Optional[Union[ArrayAttr, IntOrAttrList]]) -> ArrayAttr:
    if values is None:
        return None

    # Turn into a Python list of Python ints.
    values = _get_value_list(values)

    # Make an ArrayAttr of IntegerAttrs out of it.
    return ArrayAttr.get(
        [IntegerAttr.get(IntegerType.get_signless(64), v) for v in values]
    )


def _get_int_array_array_attr(
    values: Optional[Union[ArrayAttr, Sequence[Union[ArrayAttr, IntOrAttrList]]]]
) -> ArrayAttr:
    """Creates an ArrayAttr of ArrayAttrs of IntegerAttrs.

    The input has to be a collection of collection of integers, where any
    Python Sequence and ArrayAttr are admissible collections and Python ints and
    any IntegerAttr are admissible integers. Both levels of collections are
    turned into ArrayAttr; the inner level is turned into IntegerAttrs of i64s.
    If the input is None, an empty ArrayAttr is returned.
    """
    if values is None:
        return None

    # Make sure the outer level is a list.
    values = _get_value_list(values)

    # The inner level is now either invalid or a mixed sequence of ArrayAttrs and
    # Sequences. Make sure the nested values are all lists.
    values = [_get_value_list(nested) for nested in values]

    # Turn each nested list into an ArrayAttr.
    values = [_get_int_array_attr(nested) for nested in values]

    # Turn the outer list into an ArrayAttr.
    return ArrayAttr.get(values)
from typing import List, Optional, Sequence, Union, overload


@_ods_cext.register_operation(_Dialect, replace=True)

@@ -1,9 +1,10 @@
# RUN: %PYTHON %s | FileCheck %s

from mlir.ir import *
import mlir.dialects.func as func
import mlir.dialects.arith as arith
import mlir.dialects.memref as memref
import mlir.extras.types as T
from mlir.dialects.memref import _infer_memref_subview_result_type
from mlir.ir import *


def run(f):

@@ -88,3 +89,164 @@ def testMemRefAttr():
        memref.global_("objFifo_in0", T.memref(16, T.i32()))
    # CHECK: memref.global @objFifo_in0 : memref<16xi32>
    print(module)


# CHECK-LABEL: TEST: testSubViewOpInferReturnTypeSemantics
@run
def testSubViewOpInferReturnTypeSemantics():
    with Context() as ctx, Location.unknown(ctx):
        module = Module.create()
        with InsertionPoint(module.body):
            x = memref.alloc(T.memref(10, 10, T.i32()), [], [])
            # CHECK: %[[ALLOC:.*]] = memref.alloc() : memref<10x10xi32>
            print(x.owner)

            y = memref.subview(x, [1, 1], [3, 3], [1, 1])
            assert y.owner.verify()
            # CHECK: %{{.*}} = memref.subview %[[ALLOC]][1, 1] [3, 3] [1, 1] : memref<10x10xi32> to memref<3x3xi32, strided<[10, 1], offset: 11>>
            print(y.owner)

            z = memref.subview(
                x,
                [arith.constant(T.index(), 1), 1],
                [3, 3],
                [1, 1],
            )
            # CHECK: %{{.*}} = memref.subview %[[ALLOC]][1, 1] [3, 3] [1, 1] : memref<10x10xi32> to memref<3x3xi32, strided<[10, 1], offset: 11>>
            print(z.owner)

            z = memref.subview(
                x,
                [arith.constant(T.index(), 3), arith.constant(T.index(), 4)],
                [3, 3],
                [1, 1],
            )
            # CHECK: %{{.*}} = memref.subview %[[ALLOC]][3, 4] [3, 3] [1, 1] : memref<10x10xi32> to memref<3x3xi32, strided<[10, 1], offset: 34>>
            print(z.owner)

            s = arith.addi(arith.constant(T.index(), 3), arith.constant(T.index(), 4))
            z = memref.subview(
                x,
                [s, 0],
                [3, 3],
                [1, 1],
            )
            # CHECK: {{.*}} = memref.subview %[[ALLOC]][%0, 0] [3, 3] [1, 1] : memref<10x10xi32> to memref<3x3xi32, strided<[10, 1], offset: ?>>
            print(z)

            try:
                _infer_memref_subview_result_type(
                    x.type,
                    [arith.constant(T.index(), 3), arith.constant(T.index(), 4)],
                    [ShapedType.get_dynamic_size(), 3],
                    [1, 1],
                )
            except ValueError as e:
                # CHECK: Only inferring from python or mlir integer constant is supported
                print(e)

            try:
                memref.subview(
                    x,
                    [arith.constant(T.index(), 3), arith.constant(T.index(), 4)],
                    [ShapedType.get_dynamic_size(), 3],
                    [1, 1],
                )
            except ValueError as e:
                # CHECK: mixed static/dynamic offset/sizes/strides requires explicit result type
                print(e)

            layout = StridedLayoutAttr.get(ShapedType.get_dynamic_size(), [10, 1])
            x = memref.alloc(
                T.memref(
                    10,
                    10,
                    T.i32(),
                    layout=layout,
                ),
                [],
                [arith.constant(T.index(), 42)],
            )
            # CHECK: %[[DYNAMICALLOC:.*]] = memref.alloc()[%c42] : memref<10x10xi32, strided<[10, 1], offset: ?>>
            print(x.owner)
            y = memref.subview(
                x,
                [1, 1],
                [3, 3],
                [1, 1],
                result_type=T.memref(3, 3, T.i32(), layout=layout),
            )
            # CHECK: %{{.*}} = memref.subview %[[DYNAMICALLOC]][1, 1] [3, 3] [1, 1] : memref<10x10xi32, strided<[10, 1], offset: ?>> to memref<3x3xi32, strided<[10, 1], offset: ?>>
            print(y.owner)


# CHECK-LABEL: TEST: testSubViewOpInferReturnTypeExtensiveSlicing
@run
def testSubViewOpInferReturnTypeExtensiveSlicing():
    def check_strides_offset(memref, np_view):
        layout = memref.type.layout
        dtype_size_in_bytes = np_view.dtype.itemsize
        golden_strides = (np.array(np_view.strides) // dtype_size_in_bytes).tolist()
        golden_offset = (
            np_view.ctypes.data - np_view.base.ctypes.data
        ) // dtype_size_in_bytes

        assert (layout.strides, layout.offset) == (golden_strides, golden_offset)

    with Context() as ctx, Location.unknown(ctx):
        module = Module.create()
        with InsertionPoint(module.body):
            shape = (10, 22, 333, 4444)
            golden_mem = np.zeros(shape, dtype=np.int32)
            mem1 = memref.alloc(T.memref(*shape, T.i32()), [], [])

            # fmt: off
            check_strides_offset(memref.subview(mem1, (1, 0, 0, 0), (1, 22, 333, 4444), (1, 1, 1, 1)), golden_mem[1:2, ...])
            check_strides_offset(memref.subview(mem1, (0, 1, 0, 0), (10, 1, 333, 4444), (1, 1, 1, 1)), golden_mem[:, 1:2])
            check_strides_offset(memref.subview(mem1, (0, 0, 1, 0), (10, 22, 1, 4444), (1, 1, 1, 1)), golden_mem[:, :, 1:2])
            check_strides_offset(memref.subview(mem1, (0, 0, 0, 1), (10, 22, 333, 1), (1, 1, 1, 1)), golden_mem[:, :, :, 1:2])
            check_strides_offset(memref.subview(mem1, (0, 1, 0, 1), (10, 1, 333, 1), (1, 1, 1, 1)), golden_mem[:, 1:2, :, 1:2])
            check_strides_offset(memref.subview(mem1, (1, 0, 0, 1), (1, 22, 333, 1), (1, 1, 1, 1)), golden_mem[1:2, :, :, 1:2])
            check_strides_offset(memref.subview(mem1, (1, 1, 0, 0), (1, 1, 333, 4444), (1, 1, 1, 1)), golden_mem[1:2, 1:2, :, :])
            check_strides_offset(memref.subview(mem1, (0, 0, 1, 1), (10, 22, 1, 1), (1, 1, 1, 1)), golden_mem[:, :, 1:2, 1:2])
            check_strides_offset(memref.subview(mem1, (0, 1, 1, 0), (10, 1, 1, 4444), (1, 1, 1, 1)), golden_mem[:, 1:2, 1:2, :])
            check_strides_offset(memref.subview(mem1, (1, 0, 1, 0), (1, 22, 1, 4444), (1, 1, 1, 1)), golden_mem[1:2, :, 1:2, :])
            check_strides_offset(memref.subview(mem1, (1, 1, 0, 1), (1, 1, 333, 1), (1, 1, 1, 1)), golden_mem[1:2, 1:2, :, 1:2])
            check_strides_offset(memref.subview(mem1, (1, 0, 1, 1), (1, 22, 1, 1), (1, 1, 1, 1)), golden_mem[1:2, :, 1:2, 1:2])
            check_strides_offset(memref.subview(mem1, (0, 1, 1, 1), (10, 1, 1, 1), (1, 1, 1, 1)), golden_mem[:, 1:2, 1:2, 1:2])
            check_strides_offset(memref.subview(mem1, (1, 1, 1, 0), (1, 1, 1, 4444), (1, 1, 1, 1)), golden_mem[1:2, 1:2, 1:2, :])
            # fmt: on

            # Default strides and offset mean no StridedLayoutAttr, which means an AffineMap layout.
            assert memref.subview(
                mem1, (0, 0, 0, 0), (10, 22, 333, 4444), (1, 1, 1, 1)
            ).type.layout == AffineMapAttr.get(
                AffineMap.get(
                    4,
                    0,
                    [
                        AffineDimExpr.get(0),
                        AffineDimExpr.get(1),
                        AffineDimExpr.get(2),
                        AffineDimExpr.get(3),
                    ],
                )
            )

            shape = (7, 22, 333, 4444)
            golden_mem = np.zeros(shape, dtype=np.int32)
            mem2 = memref.alloc(T.memref(*shape, T.i32()), [], [])
            # fmt: off
            check_strides_offset(memref.subview(mem2, (0, 0, 0, 0), (7, 11, 333, 4444), (1, 2, 1, 1)), golden_mem[:, 0:22:2])
            check_strides_offset(memref.subview(mem2, (0, 0, 0, 0), (7, 11, 11, 4444), (1, 2, 30, 1)), golden_mem[:, 0:22:2, 0:330:30])
            check_strides_offset(memref.subview(mem2, (0, 0, 0, 0), (7, 11, 11, 11), (1, 2, 30, 400)), golden_mem[:, 0:22:2, 0:330:30, 0:4400:400])
            check_strides_offset(memref.subview(mem2, (0, 0, 100, 1000), (7, 22, 20, 20), (1, 1, 5, 50)), golden_mem[:, :, 100:200:5, 1000:2000:50])
            # fmt: on

            shape = (8, 8)
            golden_mem = np.zeros(shape, dtype=np.int32)
            # fmt: off
            mem3 = memref.alloc(T.memref(*shape, T.i32()), [], [])
            check_strides_offset(memref.subview(mem3, (0, 0), (4, 4), (1, 1)), golden_mem[0:4, 0:4])
            check_strides_offset(memref.subview(mem3, (4, 4), (4, 4), (1, 1)), golden_mem[4:8, 4:8])
            # fmt: on