[mlir][python] enable memref.subview (#79393)

This commit is contained in:
Maksim Levental
2024-01-30 16:21:56 -06:00
committed by GitHub
parent a356e6ccad
commit 404af14f92
7 changed files with 516 additions and 159 deletions

View File

@@ -408,6 +408,12 @@ MLIR_CAPI_EXPORTED MlirAffineMap mlirMemRefTypeGetAffineMap(MlirType type);
/// Returns the memory space of the given MemRef type.
MLIR_CAPI_EXPORTED MlirAttribute mlirMemRefTypeGetMemorySpace(MlirType type);
/// Returns the strides and the offset of the MemRef if the layout map is in
/// strided form. Both strides and offset are out params. strides must point to
/// pre-allocated memory of length equal to the rank of the memref.
/// The returned logical result is a failure when the strides and offset could
/// not be extracted; in that case nothing is copied into strides.
MLIR_CAPI_EXPORTED MlirLogicalResult mlirMemRefTypeGetStridesAndOffset(
MlirType type, int64_t *strides, int64_t *offset);
/// Returns the memory space of the given Unranked MemRef type.
MLIR_CAPI_EXPORTED MlirAttribute
mlirUnrankedMemrefGetMemorySpace(MlirType type);

View File

@@ -12,6 +12,8 @@
#include "mlir-c/BuiltinAttributes.h"
#include "mlir-c/BuiltinTypes.h"
#include "mlir-c/Support.h"
#include <optional>
namespace py = pybind11;
@@ -618,6 +620,18 @@ public:
return mlirMemRefTypeGetLayout(self);
},
"The layout of the MemRef type.")
.def(
    "get_strides_and_offset",
    [](PyMemRefType &self) -> std::pair<std::vector<int64_t>, int64_t> {
      // One slot per dimension; the C API writes the strides in place.
      std::vector<int64_t> strides(mlirShapedTypeGetRank(self));
      // Initialised so that it never holds an indeterminate value.
      int64_t offset = 0;
      if (mlirLogicalResultIsFailure(mlirMemRefTypeGetStridesAndOffset(
              self, strides.data(), &offset)))
        throw std::runtime_error(
            "Failed to extract strides and offset from memref.");
      // Move the buffer into the returned pair instead of copying it.
      return {std::move(strides), offset};
    },
    "The strides and offset of the MemRef type.")
.def_property_readonly(
"affine_map",
[](PyMemRefType &self) -> PyAffineMap {

View File

@@ -9,12 +9,16 @@
#include "mlir-c/BuiltinTypes.h"
#include "mlir-c/AffineMap.h"
#include "mlir-c/IR.h"
#include "mlir-c/Support.h"
#include "mlir/CAPI/AffineMap.h"
#include "mlir/CAPI/IR.h"
#include "mlir/CAPI/Support.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Types.h"
#include "mlir/Support/LogicalResult.h"
#include <algorithm>
using namespace mlir;
@@ -426,6 +430,18 @@ MlirAttribute mlirMemRefTypeGetMemorySpace(MlirType type) {
return wrap(llvm::cast<MemRefType>(unwrap(type)).getMemorySpace());
}
/// Writes the strides of the given MemRef type into `strides` and its offset
/// into `*offset`. Returns failure, without touching `strides`, when
/// `getStridesAndOffset` fails for the type.
MlirLogicalResult mlirMemRefTypeGetStridesAndOffset(MlirType type,
                                                    int64_t *strides,
                                                    int64_t *offset) {
  auto memrefType = llvm::cast<MemRefType>(unwrap(type));
  SmallVector<int64_t> extractedStrides;
  if (failed(getStridesAndOffset(memrefType, extractedStrides, *offset)))
    return mlirLogicalResultFailure();

  // Hand the strides back through the caller-provided buffer.
  int64_t *out = strides;
  for (int64_t stride : extractedStrides)
    *out++ = stride;
  return mlirLogicalResultSuccess();
}
MlirTypeID mlirUnrankedMemRefTypeGetTypeID() {
return wrap(UnrankedMemRefType::getTypeID());
}

View File

@@ -2,16 +2,30 @@
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
# Provide a convenient name for sub-packages to resolve the main C-extension
# with a relative import.
from .._mlir_libs import _mlir as _cext
from typing import (
List as _List,
Optional as _Optional,
Sequence as _Sequence,
Tuple as _Tuple,
Type as _Type,
TypeVar as _TypeVar,
Union as _Union,
)
from .._mlir_libs import _mlir as _cext
from ..ir import (
ArrayAttr,
Attribute,
BoolAttr,
DenseI64ArrayAttr,
IntegerAttr,
IntegerType,
OpView,
Operation,
ShapedType,
Value,
)
__all__ = [
"equally_sized_accessor",
"get_default_loc_context",
@@ -138,3 +152,157 @@ SubClassValueT = _Type[_U]
ResultValueTypeTuple = _cext.ir.Operation, _cext.ir.OpView, _cext.ir.Value
ResultValueT = _Union[ResultValueTypeTuple]
VariadicResultValueT = _Union[ResultValueT, _Sequence[ResultValueT]]
# An integer known when the op is built: a Python int or an IntegerAttr.
StaticIntLike = _Union[int, IntegerAttr]
# Objects the dispatch helpers below treat as dynamic (SSA) values.
ValueLike = _Union[Operation, OpView, Value]
# A single integer given either statically or as a dynamic value.
MixedInt = _Union[StaticIntLike, ValueLike]
# Sequences of integers / booleans, each element given as a Python value or as
# the corresponding attribute. The `Optional*` forms additionally admit an
# ArrayAttr or None.
IntOrAttrList = _Sequence[_Union[IntegerAttr, int]]
OptionalIntList = _Optional[_Union[ArrayAttr, IntOrAttrList]]
BoolOrAttrList = _Sequence[_Union[BoolAttr, bool]]
OptionalBoolList = _Optional[_Union[ArrayAttr, BoolOrAttrList]]
# Input accepted by `_dispatch_mixed_values`: a sequence mixing static and
# dynamic integers, an ArrayAttr, or a single value-like object.
MixedValues = _Union[_Sequence[_Union[StaticIntLike, ValueLike]], ArrayAttr, ValueLike]
# Input accepted by `_dispatch_dynamic_index_list`: each entry is a MixedInt or
# a sequence of MixedInt (a one-element sequence marks a scalable index).
DynamicIndexList = _Sequence[_Union[MixedInt, _Sequence[MixedInt]]]
def _dispatch_dynamic_index_list(
    indices: _Union[DynamicIndexList, ArrayAttr],
) -> _Tuple[_List[ValueLike], _Union[_List[int], ArrayAttr], _List[bool]]:
    """Splits a mixed index list into dynamic, static and scalable parts.

    Mirrors the custom `DynamicIndexList` directive upstream. Every entry of
    `indices` is either a static integer (Python `int` or `IntegerAttr`), a
    dynamic SSA value (`Operation`, `OpView` or `Value`), or a one-element
    sequence wrapping one of those, which marks the index as scalable.

    Returns a tuple `(dynamic, static, scalable)`:
      * `dynamic`: the SSA values, in the order they were encountered;
      * `static`: one entry per index, holding the static value or the
        dynamic-size sentinel for indices given as SSA values;
      * `scalable`: one bool per index, True for scalable indices.
    """
    count = len(indices)
    ssa_operands = []
    static_slots = [ShapedType.get_dynamic_size()] * count
    scalable_flags = [False] * count

    # An ArrayAttr is unpacked into a plain list of its elements.
    if isinstance(indices, ArrayAttr):
        indices = list(indices)

    def classify(position, entry):
        """Records a non-scalable entry; returns False if it is not one."""
        if isinstance(entry, int):
            static_slots[position] = entry
            return True
        if isinstance(entry, IntegerAttr):
            static_slots[position] = entry.value  # pytype: disable=attribute-error
            return True
        if isinstance(entry, (Operation, Value, OpView)):
            ssa_operands.append(entry)
            return True
        return False

    for position, entry in enumerate(indices):
        if classify(position, entry):
            continue
        # Anything else must be a scalable index, i.e. a sequence holding
        # exactly one value that is itself classifiable.
        scalable_flags[position] = True
        assert len(entry) == 1
        handled = classify(position, entry[0])
        assert handled

    return ssa_operands, static_slots, scalable_flags
# Dispatches `MixedValues` that all represent integers in various forms into
# the following three categories:
#   - `dynamic_values`: a list of `Value`s, potentially from op results;
#   - `packed_values`: a value handle, potentially from an op result, associated
#                      to one or more payload operations of integer type;
#   - `static_values`: an attribute with static values, from Python `int`s,
#                      from `IntegerAttr`s, or from an `ArrayAttr` (which is
#                      passed through unchanged).
# If the input is in the form for `packed_values`, only that result is set and
# the other two are empty. Otherwise, the input can be a mix of the other two
# forms, and for each dynamic value, a special value is added to the
# `static_values`.
def _dispatch_mixed_values(
    values: MixedValues,
) -> _Tuple[_List[Value], _Union[Operation, Value, OpView], DenseI64ArrayAttr]:
    dynamic_values = []
    packed_values = None
    static_values = None
    if isinstance(values, ArrayAttr):
        static_values = values
    elif isinstance(values, (Operation, Value, OpView)):
        packed_values = values
    else:
        static_values = []
        # `None` is treated like an empty sequence.
        for size in values or []:
            if isinstance(size, int):
                static_values.append(size)
            elif isinstance(size, IntegerAttr):
                # An IntegerAttr is a static value, not an SSA operand; record
                # its integer payload instead of treating it as dynamic.
                static_values.append(size.value)
            else:
                static_values.append(ShapedType.get_dynamic_size())
                dynamic_values.append(size)
        static_values = DenseI64ArrayAttr.get(static_values)

    return (dynamic_values, packed_values, static_values)
def _get_value_or_attribute_value(
    value_or_attr: _Union[any, Attribute, ArrayAttr]
) -> any:
    """Unwraps an attribute into a plain Python value where possible.

    Attributes exposing a `value` member yield that member, an `ArrayAttr` is
    unwrapped element by element into a list, and every other object is
    returned unchanged.
    """
    has_payload = isinstance(value_or_attr, Attribute) and hasattr(
        value_or_attr, "value"
    )
    if has_payload:
        unwrapped = value_or_attr.value
    elif isinstance(value_or_attr, ArrayAttr):
        unwrapped = _get_value_list(value_or_attr)
    else:
        unwrapped = value_or_attr
    return unwrapped
def _get_value_list(
    sequence_or_array_attr: _Union[_Sequence[any], ArrayAttr]
) -> _Sequence[any]:
    """Returns a list with every element of the input unwrapped.

    Each element is passed through `_get_value_or_attribute_value`.
    """
    unwrapped = []
    for element in sequence_or_array_attr:
        unwrapped.append(_get_value_or_attribute_value(element))
    return unwrapped
def _get_int_array_attr(
    values: _Optional[_Union[ArrayAttr, IntOrAttrList]]
) -> _Optional[ArrayAttr]:
    """Creates an ArrayAttr of signless 64-bit IntegerAttrs.

    `values` may be an ArrayAttr or a sequence whose elements are Python ints
    or IntegerAttrs. Returns None (not an empty attribute) when `values` is
    None.
    """
    if values is None:
        return None

    # Turn into a Python list of Python ints.
    values = _get_value_list(values)

    # Make an ArrayAttr of IntegerAttrs out of it; the element type is the same
    # for every entry, so it is built once.
    i64 = IntegerType.get_signless(64)
    return ArrayAttr.get([IntegerAttr.get(i64, v) for v in values])
def _get_int_array_array_attr(
    values: _Optional[_Union[ArrayAttr, _Sequence[_Union[ArrayAttr, IntOrAttrList]]]]
) -> _Optional[ArrayAttr]:
    """Creates an ArrayAttr of ArrayAttrs of IntegerAttrs.

    The input has to be a collection of collections of integers, where any
    Python _Sequence and ArrayAttr are admissible collections and Python ints
    and any IntegerAttr are admissible integers. Both levels of collections are
    turned into ArrayAttr; the inner level is turned into IntegerAttrs of i64s.
    If the input is None, None is returned.
    """
    if values is None:
        return None

    # Make sure the outer level is a list.
    values = _get_value_list(values)

    # The inner level is now either invalid or a mixed sequence of ArrayAttrs and
    # Sequences. Make sure the nested values are all lists.
    values = [_get_value_list(nested) for nested in values]

    # Turn each nested list into an ArrayAttr.
    values = [_get_int_array_attr(nested) for nested in values]

    # Turn the outer list into an ArrayAttr.
    return ArrayAttr.get(values)

View File

@@ -1,5 +1,135 @@
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
import operator
from itertools import accumulate
from typing import Optional
from ._memref_ops_gen import *
from ._ods_common import _dispatch_mixed_values, MixedValues
from .arith import ConstantOp, _is_integer_like_type
from ..ir import Value, MemRefType, StridedLayoutAttr, ShapedType
def _is_constant_int_like(i):
    """Returns True if `i` is the integer-like result of an `arith` ConstantOp.

    Anything that is not a `Value`, and any `Value` not produced by a
    `ConstantOp` (for example a block argument), yields False.
    """
    if not isinstance(i, Value):
        return False
    # The owner of a block argument is a Block, which has no `opview`; look the
    # attribute up defensively instead of raising AttributeError.
    owner_view = getattr(i.owner, "opview", None)
    return isinstance(owner_view, ConstantOp) and _is_integer_like_type(i.type)
def _is_static_int_like(i):
    """Returns True for integers whose value is known at build time.

    That is either a Python int other than the dynamic-size sentinel, or a
    value accepted by `_is_constant_int_like`.
    """
    if isinstance(i, int) and not ShapedType.is_dynamic_size(i):
        return True
    return _is_constant_int_like(i)
def _infer_memref_subview_result_type(
    source_memref_type, offsets, static_sizes, static_strides
):
    """Infers the result type of a subview with static sizes and strides.

    Sizes, strides and the strides of the source type must all be static
    (Python ints or integer `arith` constants); otherwise ValueError is raised.
    Offsets may be dynamic, in which case the result offset is dynamic too.

    Returns `(offsets, static_sizes, static_strides, result_type)` where the
    first three are lists in which constant-op results have been replaced by
    their literal Python values.
    """
    source_strides, source_offset = source_memref_type.get_strides_and_offset()
    # Work on private list copies so that tuples are accepted and the caller's
    # sequences are never modified.
    offsets = list(offsets)
    static_sizes = list(static_sizes)
    static_strides = list(static_strides)
    source_strides = list(source_strides)

    for candidates in (static_sizes, static_strides, source_strides):
        if not all(_is_static_int_like(c) for c in candidates):
            raise ValueError(
                "Only inferring from python or mlir integer constant is supported."
            )

    # Replace constant-op results by the literal integer they carry.
    for operands in (offsets, static_sizes, static_strides):
        for position, operand in enumerate(operands):
            if _is_constant_int_like(operand):
                operands[position] = operand.owner.opview.literal_value

    if all(_is_static_int_like(o) for o in offsets + [source_offset]):
        # Linearize the per-dimension offsets with the source strides.
        target_offset = source_offset + sum(
            offset * stride for offset, stride in zip(offsets, source_strides)
        )
    else:
        target_offset = ShapedType.get_dynamic_size()

    target_strides = [
        source_stride * step
        for source_stride, step in zip(source_strides, static_strides)
    ]

    # If default striding then no need to complicate things for downstream ops
    # (e.g., expand_shape): running products of the trailing sizes, innermost
    # stride being 1.
    trailing_products = list(accumulate(static_sizes[1:][::-1], operator.mul))
    default_strides = trailing_products[::-1] + [1]
    uses_default_layout = target_strides == default_strides and target_offset == 0
    if uses_default_layout:
        layout = None
    else:
        layout = StridedLayoutAttr.get(target_offset, target_strides)

    result_type = MemRefType.get(
        static_sizes,
        source_memref_type.element_type,
        layout,
        source_memref_type.memory_space,
    )
    return (offsets, static_sizes, static_strides, result_type)
# Keep a handle on the generated builder before it is shadowed by the wrapper
# defined below.
_generated_subview = subview


def subview(
    source: Value,
    offsets: MixedValues,
    sizes: MixedValues,
    strides: MixedValues,
    *,
    result_type: Optional[MemRefType] = None,
    loc=None,
    ip=None,
):
    """Builds a `memref.subview`, inferring the result type when possible.

    Args:
      source: The memref value to take the view of.
      offsets, sizes, strides: Per-dimension operands; `None` is treated as an
        empty list.
      result_type: Explicit result type. When omitted it is inferred, which
        requires sizes, strides and the source strides to be static.
      loc, ip: Forwarded to the generated builder.

    Raises:
      ValueError: If `result_type` is omitted and sizes, strides or the source
        strides are not all static.
    """
    if offsets is None:
        offsets = []
    if sizes is None:
        sizes = []
    if strides is None:
        strides = []

    if result_type is None:
        # The source layout is only needed for inference; querying it lazily
        # avoids failing on sources whose strides cannot be extracted when the
        # caller already supplied the result type.
        source_strides, _ = source.type.get_strides_and_offset()
        if not all(
            all(_is_static_int_like(i) for i in s)
            for s in [sizes, strides, source_strides]
        ):
            raise ValueError(
                "mixed static/dynamic offset/sizes/strides requires explicit result type."
            )
        # If any are arith.constant results then this will canonicalize to python int
        # (which can then be used to fully specify the subview).
        (
            offsets,
            sizes,
            strides,
            result_type,
        ) = _infer_memref_subview_result_type(source.type, offsets, sizes, strides)

    offsets, _packed_offsets, static_offsets = _dispatch_mixed_values(offsets)
    sizes, _packed_sizes, static_sizes = _dispatch_mixed_values(sizes)
    strides, _packed_strides, static_strides = _dispatch_mixed_values(strides)

    return _generated_subview(
        result_type,
        source,
        offsets,
        sizes,
        strides,
        static_offsets,
        static_sizes,
        static_strides,
        loc=loc,
        ip=ip,
    )

View File

@@ -9,163 +9,24 @@ from .._structured_transform_enum_gen import *
try:
from ...ir import *
from ...dialects import transform
from .._ods_common import _cext as _ods_cext
from .._ods_common import (
DynamicIndexList,
IntOrAttrList,
MixedValues,
OptionalBoolList,
OptionalIntList,
_cext as _ods_cext,
_dispatch_dynamic_index_list,
_dispatch_mixed_values,
_get_int_array_array_attr,
_get_int_array_attr,
_get_value_list,
_get_value_or_attribute_value,
)
except ImportError as e:
raise RuntimeError("Error loading imports from extension module") from e
from typing import List, Optional, Sequence, Tuple, Union, overload
StaticIntLike = Union[int, IntegerAttr]
ValueLike = Union[Operation, OpView, Value]
MixedInt = Union[StaticIntLike, ValueLike]
IntOrAttrList = Sequence[Union[IntegerAttr, int]]
OptionalIntList = Optional[Union[ArrayAttr, IntOrAttrList]]
BoolOrAttrList = Sequence[Union[BoolAttr, bool]]
OptionalBoolList = Optional[Union[ArrayAttr, BoolOrAttrList]]
MixedValues = Union[Sequence[Union[StaticIntLike, ValueLike]], ArrayAttr, ValueLike]
DynamicIndexList = Sequence[Union[MixedInt, Sequence[MixedInt]]]
def _dispatch_dynamic_index_list(
indices: Union[DynamicIndexList, ArrayAttr],
) -> Tuple[List[ValueLike], Union[List[int], ArrayAttr], List[bool]]:
"""Dispatches a list of indices to the appropriate form.
This is similar to the custom `DynamicIndexList` directive upstream:
provided indices may be in the form of dynamic SSA values or static values,
and they may be scalable (i.e., as a singleton list) or not. This function
dispatches each index into its respective form. It also extracts the SSA
values and static indices from various similar structures, respectively.
"""
dynamic_indices = []
static_indices = [ShapedType.get_dynamic_size()] * len(indices)
scalable_indices = [False] * len(indices)
# ArrayAttr: Extract index values.
if isinstance(indices, ArrayAttr):
indices = [idx for idx in indices]
def process_nonscalable_index(i, index):
"""Processes any form of non-scalable index.
Returns False if the given index was scalable and thus remains
unprocessed; True otherwise.
"""
if isinstance(index, int):
static_indices[i] = index
elif isinstance(index, IntegerAttr):
static_indices[i] = index.value # pytype: disable=attribute-error
elif isinstance(index, (Operation, Value, OpView)):
dynamic_indices.append(index)
else:
return False
return True
# Process each index at a time.
for i, index in enumerate(indices):
if not process_nonscalable_index(i, index):
# If it wasn't processed, it must be a scalable index, which is
# provided as a Sequence of one value, so extract and process that.
scalable_indices[i] = True
assert len(index) == 1
ret = process_nonscalable_index(i, index[0])
assert ret
return dynamic_indices, static_indices, scalable_indices
# Dispatches `MixedValues` that all represents integers in various forms into
# the following three categories:
# - `dynamic_values`: a list of `Value`s, potentially from op results;
# - `packed_values`: a value handle, potentially from an op result, associated
# to one or more payload operations of integer type;
# - `static_values`: an `ArrayAttr` of `i64`s with static values, from Python
# `int`s, from `IntegerAttr`s, or from an `ArrayAttr`.
# The input is in the form for `packed_values`, only that result is set and the
# other two are empty. Otherwise, the input can be a mix of the other two forms,
# and for each dynamic value, a special value is added to the `static_values`.
def _dispatch_mixed_values(
values: MixedValues,
) -> Tuple[List[Value], Union[Operation, Value, OpView], DenseI64ArrayAttr]:
dynamic_values = []
packed_values = None
static_values = None
if isinstance(values, ArrayAttr):
static_values = values
elif isinstance(values, (Operation, Value, OpView)):
packed_values = values
else:
static_values = []
for size in values or []:
if isinstance(size, int):
static_values.append(size)
else:
static_values.append(ShapedType.get_dynamic_size())
dynamic_values.append(size)
static_values = DenseI64ArrayAttr.get(static_values)
return (dynamic_values, packed_values, static_values)
def _get_value_or_attribute_value(
value_or_attr: Union[any, Attribute, ArrayAttr]
) -> any:
if isinstance(value_or_attr, Attribute) and hasattr(value_or_attr, "value"):
return value_or_attr.value
if isinstance(value_or_attr, ArrayAttr):
return _get_value_list(value_or_attr)
return value_or_attr
def _get_value_list(
sequence_or_array_attr: Union[Sequence[any], ArrayAttr]
) -> Sequence[any]:
return [_get_value_or_attribute_value(v) for v in sequence_or_array_attr]
def _get_int_array_attr(values: Optional[Union[ArrayAttr, IntOrAttrList]]) -> ArrayAttr:
if values is None:
return None
# Turn into a Python list of Python ints.
values = _get_value_list(values)
# Make an ArrayAttr of IntegerAttrs out of it.
return ArrayAttr.get(
[IntegerAttr.get(IntegerType.get_signless(64), v) for v in values]
)
def _get_int_array_array_attr(
values: Optional[Union[ArrayAttr, Sequence[Union[ArrayAttr, IntOrAttrList]]]]
) -> ArrayAttr:
"""Creates an ArrayAttr of ArrayAttrs of IntegerAttrs.
The input has to be a collection of collection of integers, where any
Python Sequence and ArrayAttr are admissible collections and Python ints and
any IntegerAttr are admissible integers. Both levels of collections are
turned into ArrayAttr; the inner level is turned into IntegerAttrs of i64s.
If the input is None, an empty ArrayAttr is returned.
"""
if values is None:
return None
# Make sure the outer level is a list.
values = _get_value_list(values)
# The inner level is now either invalid or a mixed sequence of ArrayAttrs and
# Sequences. Make sure the nested values are all lists.
values = [_get_value_list(nested) for nested in values]
# Turn each nested list into an ArrayAttr.
values = [_get_int_array_attr(nested) for nested in values]
# Turn the outer list into an ArrayAttr.
return ArrayAttr.get(values)
from typing import List, Optional, Sequence, Union, overload
@_ods_cext.register_operation(_Dialect, replace=True)

View File

@@ -1,9 +1,10 @@
# RUN: %PYTHON %s | FileCheck %s
from mlir.ir import *
import mlir.dialects.func as func
import mlir.dialects.arith as arith
import mlir.dialects.memref as memref
import mlir.extras.types as T
from mlir.dialects.memref import _infer_memref_subview_result_type
from mlir.ir import *
def run(f):
@@ -88,3 +89,164 @@ def testMemRefAttr():
memref.global_("objFifo_in0", T.memref(16, T.i32()))
# CHECK: memref.global @objFifo_in0 : memref<16xi32>
print(module)
# CHECK-LABEL: TEST: testSubViewOpInferReturnTypeSemantics
@run
def testSubViewOpInferReturnTypeSemantics():
# Covers result-type inference of memref.subview for static, constant-valued
# and dynamic offsets, the two ValueError paths, and an explicit result type.
with Context() as ctx, Location.unknown(ctx):
module = Module.create()
with InsertionPoint(module.body):
x = memref.alloc(T.memref(10, 10, T.i32()), [], [])
# CHECK: %[[ALLOC:.*]] = memref.alloc() : memref<10x10xi32>
print(x.owner)
# Fully static operands given as Python ints.
y = memref.subview(x, [1, 1], [3, 3], [1, 1])
assert y.owner.verify()
# CHECK: %{{.*}} = memref.subview %[[ALLOC]][1, 1] [3, 3] [1, 1] : memref<10x10xi32> to memref<3x3xi32, strided<[10, 1], offset: 11>>
print(y.owner)
# An offset given as an arith.constant result is printed as a static offset.
z = memref.subview(
x,
[arith.constant(T.index(), 1), 1],
[3, 3],
[1, 1],
)
# CHECK: %{{.*}} = memref.subview %[[ALLOC]][1, 1] [3, 3] [1, 1] : memref<10x10xi32> to memref<3x3xi32, strided<[10, 1], offset: 11>>
print(z.owner)
z = memref.subview(
x,
[arith.constant(T.index(), 3), arith.constant(T.index(), 4)],
[3, 3],
[1, 1],
)
# CHECK: %{{.*}} = memref.subview %[[ALLOC]][3, 4] [3, 3] [1, 1] : memref<10x10xi32> to memref<3x3xi32, strided<[10, 1], offset: 34>>
print(z.owner)
# An offset produced by a non-constant op yields a dynamic result offset.
s = arith.addi(arith.constant(T.index(), 3), arith.constant(T.index(), 4))
z = memref.subview(
x,
[s, 0],
[3, 3],
[1, 1],
)
# CHECK: {{.*}} = memref.subview %[[ALLOC]][%0, 0] [3, 3] [1, 1] : memref<10x10xi32> to memref<3x3xi32, strided<[10, 1], offset: ?>>
print(z)
# A dynamic size is rejected by the inference helper itself.
try:
_infer_memref_subview_result_type(
x.type,
[arith.constant(T.index(), 3), arith.constant(T.index(), 4)],
[ShapedType.get_dynamic_size(), 3],
[1, 1],
)
except ValueError as e:
# CHECK: Only inferring from python or mlir integer constant is supported
print(e)
# The public wrapper rejects the same input when no result type is given.
try:
memref.subview(
x,
[arith.constant(T.index(), 3), arith.constant(T.index(), 4)],
[ShapedType.get_dynamic_size(), 3],
[1, 1],
)
except ValueError as e:
# CHECK: mixed static/dynamic offset/sizes/strides requires explicit result type
print(e)
# Source with a dynamic layout offset; the result type is passed explicitly.
layout = StridedLayoutAttr.get(ShapedType.get_dynamic_size(), [10, 1])
x = memref.alloc(
T.memref(
10,
10,
T.i32(),
layout=layout,
),
[],
[arith.constant(T.index(), 42)],
)
# CHECK: %[[DYNAMICALLOC:.*]] = memref.alloc()[%c42] : memref<10x10xi32, strided<[10, 1], offset: ?>>
print(x.owner)
y = memref.subview(
x,
[1, 1],
[3, 3],
[1, 1],
result_type=T.memref(3, 3, T.i32(), layout=layout),
)
# CHECK: %{{.*}} = memref.subview %[[DYNAMICALLOC]][1, 1] [3, 3] [1, 1] : memref<10x10xi32, strided<[10, 1], offset: ?>> to memref<3x3xi32, strided<[10, 1], offset: ?>>
print(y.owner)
# CHECK-LABEL: TEST: testSubViewOpInferReturnTypeExtensiveSlicing
@run
def testSubViewOpInferReturnTypeExtensiveSlicing():
    """Compares inferred subview layouts against equivalent NumPy slices."""
    # NOTE(review): no numpy import is visible in this file's import block, so
    # `np` is imported here explicitly rather than relying on another module
    # happening to export it.
    import numpy as np

    def check_strides_offset(view, np_view):
        """Asserts that `view` has the same element strides/offset as `np_view`."""
        layout = view.type.layout
        dtype_size_in_bytes = np_view.dtype.itemsize
        # NumPy reports strides and addresses in bytes; convert to elements.
        golden_strides = (np.array(np_view.strides) // dtype_size_in_bytes).tolist()
        golden_offset = (
            np_view.ctypes.data - np_view.base.ctypes.data
        ) // dtype_size_in_bytes

        assert (layout.strides, layout.offset) == (golden_strides, golden_offset)

    with Context() as ctx, Location.unknown(ctx):
        module = Module.create()
        with InsertionPoint(module.body):
            shape = (10, 22, 333, 4444)
            golden_mem = np.zeros(shape, dtype=np.int32)
            mem1 = memref.alloc(T.memref(*shape, T.i32()), [], [])

            # Unit-stride slices that drop one to three leading elements.
            # fmt: off
            check_strides_offset(memref.subview(mem1, (1, 0, 0, 0), (1, 22, 333, 4444), (1, 1, 1, 1)), golden_mem[1:2, ...])
            check_strides_offset(memref.subview(mem1, (0, 1, 0, 0), (10, 1, 333, 4444), (1, 1, 1, 1)), golden_mem[:, 1:2])
            check_strides_offset(memref.subview(mem1, (0, 0, 1, 0), (10, 22, 1, 4444), (1, 1, 1, 1)), golden_mem[:, :, 1:2])
            check_strides_offset(memref.subview(mem1, (0, 0, 0, 1), (10, 22, 333, 1), (1, 1, 1, 1)), golden_mem[:, :, :, 1:2])
            check_strides_offset(memref.subview(mem1, (0, 1, 0, 1), (10, 1, 333, 1), (1, 1, 1, 1)), golden_mem[:, 1:2, :, 1:2])
            check_strides_offset(memref.subview(mem1, (1, 0, 0, 1), (1, 22, 333, 1), (1, 1, 1, 1)), golden_mem[1:2, :, :, 1:2])
            check_strides_offset(memref.subview(mem1, (1, 1, 0, 0), (1, 1, 333, 4444), (1, 1, 1, 1)), golden_mem[1:2, 1:2, :, :])
            check_strides_offset(memref.subview(mem1, (0, 0, 1, 1), (10, 22, 1, 1), (1, 1, 1, 1)), golden_mem[:, :, 1:2, 1:2])
            check_strides_offset(memref.subview(mem1, (0, 1, 1, 0), (10, 1, 1, 4444), (1, 1, 1, 1)), golden_mem[:, 1:2, 1:2, :])
            check_strides_offset(memref.subview(mem1, (1, 0, 1, 0), (1, 22, 1, 4444), (1, 1, 1, 1)), golden_mem[1:2, :, 1:2, :])
            check_strides_offset(memref.subview(mem1, (1, 1, 0, 1), (1, 1, 333, 1), (1, 1, 1, 1)), golden_mem[1:2, 1:2, :, 1:2])
            check_strides_offset(memref.subview(mem1, (1, 0, 1, 1), (1, 22, 1, 1), (1, 1, 1, 1)), golden_mem[1:2, :, 1:2, 1:2])
            check_strides_offset(memref.subview(mem1, (0, 1, 1, 1), (10, 1, 1, 1), (1, 1, 1, 1)), golden_mem[:, 1:2, 1:2, 1:2])
            check_strides_offset(memref.subview(mem1, (1, 1, 1, 0), (1, 1, 1, 4444), (1, 1, 1, 1)), golden_mem[1:2, 1:2, 1:2, :])
            # fmt: on

            # default strides and offset means no stridedlayout attribute means affinemap layout
            assert memref.subview(
                mem1, (0, 0, 0, 0), (10, 22, 333, 4444), (1, 1, 1, 1)
            ).type.layout == AffineMapAttr.get(
                AffineMap.get(
                    4,
                    0,
                    [
                        AffineDimExpr.get(0),
                        AffineDimExpr.get(1),
                        AffineDimExpr.get(2),
                        AffineDimExpr.get(3),
                    ],
                )
            )

            shape = (7, 22, 333, 4444)
            golden_mem = np.zeros(shape, dtype=np.int32)
            mem2 = memref.alloc(T.memref(*shape, T.i32()), [], [])
            # Non-unit strides, with and without non-zero offsets.
            # fmt: off
            check_strides_offset(memref.subview(mem2, (0, 0, 0, 0), (7, 11, 333, 4444), (1, 2, 1, 1)), golden_mem[:, 0:22:2])
            check_strides_offset(memref.subview(mem2, (0, 0, 0, 0), (7, 11, 11, 4444), (1, 2, 30, 1)), golden_mem[:, 0:22:2, 0:330:30])
            check_strides_offset(memref.subview(mem2, (0, 0, 0, 0), (7, 11, 11, 11), (1, 2, 30, 400)), golden_mem[:, 0:22:2, 0:330:30, 0:4400:400])
            check_strides_offset(memref.subview(mem2, (0, 0, 100, 1000), (7, 22, 20, 20), (1, 1, 5, 50)), golden_mem[:, :, 100:200:5, 1000:2000:50])
            # fmt: on

            shape = (8, 8)
            golden_mem = np.zeros(shape, dtype=np.int32)
            # fmt: off
            mem3 = memref.alloc(T.memref(*shape, T.i32()), [], [])
            check_strides_offset(memref.subview(mem3, (0, 0), (4, 4), (1, 1)), golden_mem[0:4, 0:4])
            check_strides_offset(memref.subview(mem3, (4, 4), (4, 4), (1, 1)), golden_mem[4:8, 4:8])
            # fmt: on