mirror of
https://github.com/intel/llvm.git
synced 2026-01-26 12:26:52 +08:00
[mlir][memref] Fold extract_strided_metadata(cast(x)) into extract_strided_metadata(x) (#164585)
This commit is contained in:
@@ -1437,6 +1437,13 @@ ExtractStridedMetadataOp::fold(FoldAdaptor adaptor,
|
||||
atLeastOneReplacement |= replaceConstantUsesOf(
|
||||
builder, getLoc(), getStrides(), getConstifiedMixedStrides());
|
||||
|
||||
// extract_strided_metadata(cast(x)) -> extract_strided_metadata(x).
|
||||
if (auto prev = getSource().getDefiningOp<CastOp>())
|
||||
if (isa<MemRefType>(prev.getSource().getType())) {
|
||||
getSourceMutable().assign(prev.getSource());
|
||||
atLeastOneReplacement = true;
|
||||
}
|
||||
|
||||
return success(atLeastOneReplacement);
|
||||
}
|
||||
|
||||
|
||||
@@ -1033,91 +1033,6 @@ class ExtractStridedMetadataOpReinterpretCastFolder
|
||||
}
|
||||
};
|
||||
|
||||
/// Replace `base, offset, sizes, strides =
|
||||
/// extract_strided_metadata(
|
||||
/// cast(src) to dstTy)`
|
||||
/// With
|
||||
/// ```
|
||||
/// base, ... = extract_strided_metadata(src)
|
||||
/// offset = !dstTy.srcOffset.isDynamic()
|
||||
/// ? dstTy.srcOffset
|
||||
/// : extract_strided_metadata(src).offset
|
||||
/// sizes = for each srcSize in dstTy.srcSizes:
|
||||
/// !srcSize.isDynamic()
|
||||
/// ? srcSize
|
||||
/// : extract_strided_metadata(src).sizes[i]
|
||||
/// strides = for each srcStride in dstTy.srcStrides:
|
||||
/// !srcStride.isDynamic()
|
||||
/// ? srcStride
|
||||
/// : extract_strided_metadata(src).strides[i]
|
||||
/// ```
|
||||
///
|
||||
/// In other words, consume the `cast` and apply its effects
|
||||
/// on the offset, sizes, and strides or compute them directly from `src`.
|
||||
// Rewrite pattern for an extract_strided_metadata whose operand is produced
// by a memref.cast: a new extract_strided_metadata is created on the cast's
// source, and the original op's results are replaced.
class ExtractStridedMetadataOpCastFolder
|
||||
: public OpRewritePattern<memref::ExtractStridedMetadataOp> {
|
||||
using OpRewritePattern::OpRewritePattern;
|
||||
|
||||
// Returns failure when the operand is not defined by a memref.cast, or when
// return-type inference rejects the cast's source as an operand of
// extract_strided_metadata. Otherwise replaces the op and returns success.
LogicalResult
|
||||
matchAndRewrite(memref::ExtractStridedMetadataOp extractStridedMetadataOp,
|
||||
PatternRewriter &rewriter) const override {
|
||||
Value source = extractStridedMetadataOp.getSource();
|
||||
auto castOp = source.getDefiningOp<memref::CastOp>();
|
||||
if (!castOp)
|
||||
return failure();
|
||||
|
||||
Location loc = extractStridedMetadataOp.getLoc();
|
||||
// Check if the source is suitable for extract_strided_metadata.
|
||||
SmallVector<Type> inferredReturnTypes;
|
||||
if (failed(extractStridedMetadataOp.inferReturnTypes(
|
||||
rewriter.getContext(), loc, {castOp.getSource()},
|
||||
/*attributes=*/{}, /*properties=*/nullptr, /*regions=*/{},
|
||||
inferredReturnTypes)))
|
||||
return rewriter.notifyMatchFailure(castOp,
|
||||
"cast source's type is incompatible");
|
||||
|
||||
// `source` is the result of the cast, so this is the cast's destination
// type; its static sizes/strides/offset are what gets propagated below.
auto memrefType = cast<MemRefType>(source.getType());
|
||||
unsigned rank = memrefType.getRank();
|
||||
SmallVector<OpFoldResult> results;
|
||||
// Result layout: [base_buffer, offset, sizes[0..rank), strides[0..rank)].
results.resize_for_overwrite(rank * 2 + 2);
|
||||
|
||||
// New extract_strided_metadata that reads directly from the cast's source.
auto newExtractStridedMetadata = memref::ExtractStridedMetadataOp::create(
|
||||
rewriter, loc, castOp.getSource());
|
||||
|
||||
// Register the base_buffer.
|
||||
results[0] = newExtractStridedMetadata.getBaseBuffer();
|
||||
|
||||
// Picks the index attribute for `constant` when it is static; otherwise
// falls back to `ofr`, the matching result of the new op.
auto getConstantOrValue = [&rewriter](int64_t constant,
|
||||
OpFoldResult ofr) -> OpFoldResult {
|
||||
return ShapedType::isStatic(constant)
|
||||
? OpFoldResult(rewriter.getIndexAttr(constant))
|
||||
: ofr;
|
||||
};
|
||||
|
||||
auto [sourceStrides, sourceOffset] = memrefType.getStridesAndOffset();
|
||||
assert(sourceStrides.size() == rank && "unexpected number of strides");
|
||||
|
||||
// Register the new offset.
|
||||
results[1] =
|
||||
getConstantOrValue(sourceOffset, newExtractStridedMetadata.getOffset());
|
||||
|
||||
const unsigned sizeStartIdx = 2;
|
||||
const unsigned strideStartIdx = sizeStartIdx + rank;
|
||||
ArrayRef<int64_t> sourceSizes = memrefType.getShape();
|
||||
|
||||
SmallVector<OpFoldResult> sizes = newExtractStridedMetadata.getSizes();
|
||||
SmallVector<OpFoldResult> strides = newExtractStridedMetadata.getStrides();
|
||||
// Register each size and stride, static value first, dynamic result second.
for (unsigned i = 0; i < rank; ++i) {
|
||||
results[sizeStartIdx + i] = getConstantOrValue(sourceSizes[i], sizes[i]);
|
||||
results[strideStartIdx + i] =
|
||||
getConstantOrValue(sourceStrides[i], strides[i]);
|
||||
}
|
||||
rewriter.replaceOp(extractStridedMetadataOp,
|
||||
getValueOrCreateConstantIndexOp(rewriter, loc, results));
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
/// Replace `base, offset, sizes, strides = extract_strided_metadata(
|
||||
/// memory_space_cast(src) to dstTy)`
|
||||
/// with
|
||||
@@ -1209,7 +1124,6 @@ void memref::populateExpandStridedMetadataPatterns(
|
||||
RewriteExtractAlignedPointerAsIndexOfViewLikeOp,
|
||||
ExtractStridedMetadataOpReinterpretCastFolder,
|
||||
ExtractStridedMetadataOpSubviewFolder,
|
||||
ExtractStridedMetadataOpCastFolder,
|
||||
ExtractStridedMetadataOpMemorySpaceCastFolder,
|
||||
ExtractStridedMetadataOpAssumeAlignmentFolder,
|
||||
ExtractStridedMetadataOpExtractStridedMetadataFolder>(
|
||||
@@ -1226,7 +1140,6 @@ void memref::populateResolveExtractStridedMetadataPatterns(
|
||||
ExtractStridedMetadataOpSubviewFolder,
|
||||
RewriteExtractAlignedPointerAsIndexOfViewLikeOp,
|
||||
ExtractStridedMetadataOpReinterpretCastFolder,
|
||||
ExtractStridedMetadataOpCastFolder,
|
||||
ExtractStridedMetadataOpMemorySpaceCastFolder,
|
||||
ExtractStridedMetadataOpAssumeAlignmentFolder,
|
||||
ExtractStridedMetadataOpExtractStridedMetadataFolder>(
|
||||
|
||||
@@ -901,6 +901,132 @@ func.func @scope_merge_without_terminator() {
|
||||
|
||||
// -----
|
||||
|
||||
// Check that we simplify extract_strided_metadata of cast
|
||||
// when the source of the cast is compatible with what
|
||||
// `extract_strided_metadata`s accept.
|
||||
//
|
||||
// When we apply the transformation the resulting offset, sizes and strides
|
||||
// should come straight from the inputs of the cast.
|
||||
// Additionally the folder on extract_strided_metadata should propagate the
|
||||
// static information.
|
||||
//
|
||||
// CHECK-LABEL: func @extract_strided_metadata_of_cast
|
||||
// CHECK-SAME: %[[ARG:.*]]: memref<3x?xi32, strided<[4, ?], offset: ?>>)
|
||||
//
|
||||
// CHECK-DAG: %[[C3:.*]] = arith.constant 3 : index
|
||||
// CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index
|
||||
// CHECK: %[[BASE:.*]], %[[DYN_OFFSET:.*]], %[[DYN_SIZES:.*]]:2, %[[DYN_STRIDES:.*]]:2 = memref.extract_strided_metadata %[[ARG]]
|
||||
//
|
||||
// CHECK: return %[[BASE]], %[[DYN_OFFSET]], %[[C3]], %[[DYN_SIZES]]#1, %[[C4]], %[[DYN_STRIDES]]#1
|
||||
func.func @extract_strided_metadata_of_cast(
|
||||
%arg : memref<3x?xi32, strided<[4, ?], offset:?>>)
|
||||
-> (memref<i32>, index,
|
||||
index, index,
|
||||
index, index) {
|
||||
|
||||
%cast =
|
||||
memref.cast %arg :
|
||||
memref<3x?xi32, strided<[4, ?], offset: ?>> to
|
||||
memref<?x?xi32, strided<[?, ?], offset: ?>>
|
||||
|
||||
%base, %base_offset, %sizes:2, %strides:2 =
|
||||
memref.extract_strided_metadata %cast:memref<?x?xi32, strided<[?, ?], offset: ?>>
|
||||
-> memref<i32>, index,
|
||||
index, index,
|
||||
index, index
|
||||
|
||||
return %base, %base_offset,
|
||||
%sizes#0, %sizes#1,
|
||||
%strides#0, %strides#1 :
|
||||
memref<i32>, index,
|
||||
index, index,
|
||||
index, index
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Check that we simplify extract_strided_metadata of cast
|
||||
// when the source of the cast is compatible with what
|
||||
// `extract_strided_metadata`s accept.
|
||||
//
|
||||
// Same as extract_strided_metadata_of_cast but with constant sizes and strides
|
||||
// in the destination type.
|
||||
//
|
||||
// CHECK-LABEL: func @extract_strided_metadata_of_cast_w_csts
|
||||
// CHECK-SAME: %[[ARG:.*]]: memref<?x?xi32, strided<[?, ?], offset: ?>>)
|
||||
//
|
||||
// CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index
|
||||
// CHECK-DAG: %[[C18:.*]] = arith.constant 18 : index
|
||||
// CHECK-DAG: %[[C25:.*]] = arith.constant 25 : index
|
||||
// CHECK: %[[BASE:.*]], %[[DYN_OFFSET:.*]], %[[DYN_SIZES:.*]]:2, %[[DYN_STRIDES:.*]]:2 = memref.extract_strided_metadata %[[ARG]]
|
||||
//
|
||||
// CHECK: return %[[BASE]], %[[C25]], %[[C4]], %[[DYN_SIZES]]#1, %[[DYN_STRIDES]]#0, %[[C18]]
|
||||
func.func @extract_strided_metadata_of_cast_w_csts(
|
||||
%arg : memref<?x?xi32, strided<[?, ?], offset:?>>)
|
||||
-> (memref<i32>, index,
|
||||
index, index,
|
||||
index, index) {
|
||||
|
||||
%cast =
|
||||
memref.cast %arg :
|
||||
memref<?x?xi32, strided<[?, ?], offset: ?>> to
|
||||
memref<4x?xi32, strided<[?, 18], offset: 25>>
|
||||
|
||||
%base, %base_offset, %sizes:2, %strides:2 =
|
||||
memref.extract_strided_metadata %cast:memref<4x?xi32, strided<[?, 18], offset: 25>>
|
||||
-> memref<i32>, index,
|
||||
index, index,
|
||||
index, index
|
||||
|
||||
return %base, %base_offset,
|
||||
%sizes#0, %sizes#1,
|
||||
%strides#0, %strides#1 :
|
||||
memref<i32>, index,
|
||||
index, index,
|
||||
index, index
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Check that we don't simplify extract_strided_metadata of
|
||||
// cast when the source of the cast is unranked.
|
||||
// Unranked memrefs cannot feed into extract_strided_metadata operations.
|
||||
// Note: Technically we could still fold the sizes and strides.
|
||||
//
|
||||
// CHECK-LABEL: func @extract_strided_metadata_of_cast_unranked
|
||||
// CHECK-SAME: %[[ARG:.*]]: memref<*xi32>)
|
||||
//
|
||||
// CHECK: %[[CAST:.*]] = memref.cast %[[ARG]] :
|
||||
// CHECK: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:2, %[[STRIDES:.*]]:2 = memref.extract_strided_metadata %[[CAST]]
|
||||
//
|
||||
// CHECK: return %[[BASE]], %[[OFFSET]], %[[SIZES]]#0, %[[SIZES]]#1, %[[STRIDES]]#0, %[[STRIDES]]#1
|
||||
func.func @extract_strided_metadata_of_cast_unranked(
|
||||
%arg : memref<*xi32>)
|
||||
-> (memref<i32>, index,
|
||||
index, index,
|
||||
index, index) {
|
||||
|
||||
%cast =
|
||||
memref.cast %arg :
|
||||
memref<*xi32> to
|
||||
memref<?x?xi32, strided<[?, ?], offset: ?>>
|
||||
|
||||
%base, %base_offset, %sizes:2, %strides:2 =
|
||||
memref.extract_strided_metadata %cast:memref<?x?xi32, strided<[?, ?], offset: ?>>
|
||||
-> memref<i32>, index,
|
||||
index, index,
|
||||
index, index
|
||||
|
||||
return %base, %base_offset,
|
||||
%sizes#0, %sizes#1,
|
||||
%strides#0, %strides#1 :
|
||||
memref<i32>, index,
|
||||
index, index,
|
||||
index, index
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @reinterpret_noop
|
||||
// CHECK-SAME: (%[[ARG:.*]]: memref<2x3x4xf32>)
|
||||
// CHECK-NEXT: return %[[ARG]]
|
||||
|
||||
@@ -1376,133 +1376,6 @@ func.func @extract_strided_metadata_of_get_global_with_offset()
|
||||
memref<i32>, index, index, index, index, index
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Check that we simplify extract_strided_metadata of cast
|
||||
// when the source of the cast is compatible with what
|
||||
// `extract_strided_metadata`s accept.
|
||||
//
|
||||
// When we apply the transformation the resulting offset, sizes and strides
|
||||
// should come straight from the inputs of the cast.
|
||||
// Additionally the folder on extract_strided_metadata should propagate the
|
||||
// static information.
|
||||
//
|
||||
// CHECK-LABEL: func @extract_strided_metadata_of_cast
|
||||
// CHECK-SAME: %[[ARG:.*]]: memref<3x?xi32, strided<[4, ?], offset: ?>>)
|
||||
//
|
||||
// CHECK-DAG: %[[C3:.*]] = arith.constant 3 : index
|
||||
// CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index
|
||||
// CHECK: %[[BASE:.*]], %[[DYN_OFFSET:.*]], %[[DYN_SIZES:.*]]:2, %[[DYN_STRIDES:.*]]:2 = memref.extract_strided_metadata %[[ARG]]
|
||||
//
|
||||
// CHECK: return %[[BASE]], %[[DYN_OFFSET]], %[[C3]], %[[DYN_SIZES]]#1, %[[C4]], %[[DYN_STRIDES]]#1
|
||||
func.func @extract_strided_metadata_of_cast(
|
||||
%arg : memref<3x?xi32, strided<[4, ?], offset:?>>)
|
||||
-> (memref<i32>, index,
|
||||
index, index,
|
||||
index, index) {
|
||||
|
||||
%cast =
|
||||
memref.cast %arg :
|
||||
memref<3x?xi32, strided<[4, ?], offset: ?>> to
|
||||
memref<?x?xi32, strided<[?, ?], offset: ?>>
|
||||
|
||||
%base, %base_offset, %sizes:2, %strides:2 =
|
||||
memref.extract_strided_metadata %cast:memref<?x?xi32, strided<[?, ?], offset: ?>>
|
||||
-> memref<i32>, index,
|
||||
index, index,
|
||||
index, index
|
||||
|
||||
return %base, %base_offset,
|
||||
%sizes#0, %sizes#1,
|
||||
%strides#0, %strides#1 :
|
||||
memref<i32>, index,
|
||||
index, index,
|
||||
index, index
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Check that we simplify extract_strided_metadata of cast
|
||||
// when the source of the cast is compatible with what
|
||||
// `extract_strided_metadata`s accept.
|
||||
//
|
||||
// Same as extract_strided_metadata_of_cast but with constant sizes and strides
|
||||
// in the destination type.
|
||||
//
|
||||
// CHECK-LABEL: func @extract_strided_metadata_of_cast_w_csts
|
||||
// CHECK-SAME: %[[ARG:.*]]: memref<?x?xi32, strided<[?, ?], offset: ?>>)
|
||||
//
|
||||
// CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index
|
||||
// CHECK-DAG: %[[C18:.*]] = arith.constant 18 : index
|
||||
// CHECK-DAG: %[[C25:.*]] = arith.constant 25 : index
|
||||
// CHECK: %[[BASE:.*]], %[[DYN_OFFSET:.*]], %[[DYN_SIZES:.*]]:2, %[[DYN_STRIDES:.*]]:2 = memref.extract_strided_metadata %[[ARG]]
|
||||
//
|
||||
// CHECK: return %[[BASE]], %[[C25]], %[[C4]], %[[DYN_SIZES]]#1, %[[DYN_STRIDES]]#0, %[[C18]]
|
||||
func.func @extract_strided_metadata_of_cast_w_csts(
|
||||
%arg : memref<?x?xi32, strided<[?, ?], offset:?>>)
|
||||
-> (memref<i32>, index,
|
||||
index, index,
|
||||
index, index) {
|
||||
|
||||
%cast =
|
||||
memref.cast %arg :
|
||||
memref<?x?xi32, strided<[?, ?], offset: ?>> to
|
||||
memref<4x?xi32, strided<[?, 18], offset: 25>>
|
||||
|
||||
%base, %base_offset, %sizes:2, %strides:2 =
|
||||
memref.extract_strided_metadata %cast:memref<4x?xi32, strided<[?, 18], offset: 25>>
|
||||
-> memref<i32>, index,
|
||||
index, index,
|
||||
index, index
|
||||
|
||||
return %base, %base_offset,
|
||||
%sizes#0, %sizes#1,
|
||||
%strides#0, %strides#1 :
|
||||
memref<i32>, index,
|
||||
index, index,
|
||||
index, index
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Check that we don't simplify extract_strided_metadata of
|
||||
// cast when the source of the cast is unranked.
|
||||
// Unranked memrefs cannot feed into extract_strided_metadata operations.
|
||||
// Note: Technically we could still fold the sizes and strides.
|
||||
//
|
||||
// CHECK-LABEL: func @extract_strided_metadata_of_cast_unranked
|
||||
// CHECK-SAME: %[[ARG:.*]]: memref<*xi32>)
|
||||
//
|
||||
// CHECK: %[[CAST:.*]] = memref.cast %[[ARG]] :
|
||||
// CHECK: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:2, %[[STRIDES:.*]]:2 = memref.extract_strided_metadata %[[CAST]]
|
||||
//
|
||||
// CHECK: return %[[BASE]], %[[OFFSET]], %[[SIZES]]#0, %[[SIZES]]#1, %[[STRIDES]]#0, %[[STRIDES]]#1
|
||||
func.func @extract_strided_metadata_of_cast_unranked(
|
||||
%arg : memref<*xi32>)
|
||||
-> (memref<i32>, index,
|
||||
index, index,
|
||||
index, index) {
|
||||
|
||||
%cast =
|
||||
memref.cast %arg :
|
||||
memref<*xi32> to
|
||||
memref<?x?xi32, strided<[?, ?], offset: ?>>
|
||||
|
||||
%base, %base_offset, %sizes:2, %strides:2 =
|
||||
memref.extract_strided_metadata %cast:memref<?x?xi32, strided<[?, ?], offset: ?>>
|
||||
-> memref<i32>, index,
|
||||
index, index,
|
||||
index, index
|
||||
|
||||
return %base, %base_offset,
|
||||
%sizes#0, %sizes#1,
|
||||
%strides#0, %strides#1 :
|
||||
memref<i32>, index,
|
||||
index, index,
|
||||
index, index
|
||||
}
|
||||
|
||||
|
||||
// -----
|
||||
|
||||
memref.global "private" @dynamicShmem : memref<0xf16,3>
|
||||
|
||||
Reference in New Issue
Block a user