diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp index 94947b760251..c06a48ee4b87 100644 --- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp +++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp @@ -1437,6 +1437,13 @@ ExtractStridedMetadataOp::fold(FoldAdaptor adaptor, atLeastOneReplacement |= replaceConstantUsesOf( builder, getLoc(), getStrides(), getConstifiedMixedStrides()); + // extract_strided_metadata(cast(x)) -> extract_strided_metadata(x). + if (auto prev = getSource().getDefiningOp()) + if (isa(prev.getSource().getType())) { + getSourceMutable().assign(prev.getSource()); + atLeastOneReplacement = true; + } + return success(atLeastOneReplacement); } diff --git a/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp b/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp index d35566a9c0d2..bd02516d5b52 100644 --- a/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp +++ b/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp @@ -1033,91 +1033,6 @@ class ExtractStridedMetadataOpReinterpretCastFolder } }; -/// Replace `base, offset, sizes, strides = -/// extract_strided_metadata( -/// cast(src) to dstTy)` -/// With -/// ``` -/// base, ... = extract_strided_metadata(src) -/// offset = !dstTy.srcOffset.isDynamic() -/// ? dstTy.srcOffset -/// : extract_strided_metadata(src).offset -/// sizes = for each srcSize in dstTy.srcSizes: -/// !srcSize.isDynamic() -/// ? srcSize -// : extract_strided_metadata(src).sizes[i] -/// strides = for each srcStride in dstTy.srcStrides: -/// !srcStrides.isDynamic() -/// ? srcStrides -/// : extract_strided_metadata(src).strides[i] -/// ``` -/// -/// In other words, consume the `cast` and apply its effects -/// on the offset, sizes, and strides or compute them directly from `src`. -class ExtractStridedMetadataOpCastFolder - : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult - matchAndRewrite(memref::ExtractStridedMetadataOp extractStridedMetadataOp, - PatternRewriter &rewriter) const override { - Value source = extractStridedMetadataOp.getSource(); - auto castOp = source.getDefiningOp(); - if (!castOp) - return failure(); - - Location loc = extractStridedMetadataOp.getLoc(); - // Check if the source is suitable for extract_strided_metadata. - SmallVector inferredReturnTypes; - if (failed(extractStridedMetadataOp.inferReturnTypes( - rewriter.getContext(), loc, {castOp.getSource()}, - /*attributes=*/{}, /*properties=*/nullptr, /*regions=*/{}, - inferredReturnTypes))) - return rewriter.notifyMatchFailure(castOp, - "cast source's type is incompatible"); - - auto memrefType = cast(source.getType()); - unsigned rank = memrefType.getRank(); - SmallVector results; - results.resize_for_overwrite(rank * 2 + 2); - - auto newExtractStridedMetadata = memref::ExtractStridedMetadataOp::create( - rewriter, loc, castOp.getSource()); - - // Register the base_buffer. - results[0] = newExtractStridedMetadata.getBaseBuffer(); - - auto getConstantOrValue = [&rewriter](int64_t constant, - OpFoldResult ofr) -> OpFoldResult { - return ShapedType::isStatic(constant) - ? OpFoldResult(rewriter.getIndexAttr(constant)) - : ofr; - }; - - auto [sourceStrides, sourceOffset] = memrefType.getStridesAndOffset(); - assert(sourceStrides.size() == rank && "unexpected number of strides"); - - // Register the new offset. - results[1] = - getConstantOrValue(sourceOffset, newExtractStridedMetadata.getOffset()); - - const unsigned sizeStartIdx = 2; - const unsigned strideStartIdx = sizeStartIdx + rank; - ArrayRef sourceSizes = memrefType.getShape(); - - SmallVector sizes = newExtractStridedMetadata.getSizes(); - SmallVector strides = newExtractStridedMetadata.getStrides(); - for (unsigned i = 0; i < rank; ++i) { - results[sizeStartIdx + i] = getConstantOrValue(sourceSizes[i], sizes[i]); - results[strideStartIdx + i] = - getConstantOrValue(sourceStrides[i], strides[i]); - } - rewriter.replaceOp(extractStridedMetadataOp, - getValueOrCreateConstantIndexOp(rewriter, loc, results)); - return success(); - } -}; - /// Replace `base, offset, sizes, strides = extract_strided_metadata( /// memory_space_cast(src) to dstTy)` /// with @@ -1209,7 +1124,6 @@ void memref::populateExpandStridedMetadataPatterns( RewriteExtractAlignedPointerAsIndexOfViewLikeOp, ExtractStridedMetadataOpReinterpretCastFolder, ExtractStridedMetadataOpSubviewFolder, - ExtractStridedMetadataOpCastFolder, ExtractStridedMetadataOpMemorySpaceCastFolder, ExtractStridedMetadataOpAssumeAlignmentFolder, ExtractStridedMetadataOpExtractStridedMetadataFolder>( @@ -1226,7 +1140,6 @@ void memref::populateResolveExtractStridedMetadataPatterns( ExtractStridedMetadataOpSubviewFolder, RewriteExtractAlignedPointerAsIndexOfViewLikeOp, ExtractStridedMetadataOpReinterpretCastFolder, - ExtractStridedMetadataOpCastFolder, ExtractStridedMetadataOpMemorySpaceCastFolder, ExtractStridedMetadataOpAssumeAlignmentFolder, ExtractStridedMetadataOpExtractStridedMetadataFolder>( diff --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir index 7160b52af635..313090272ef9 100644 --- a/mlir/test/Dialect/MemRef/canonicalize.mlir +++ b/mlir/test/Dialect/MemRef/canonicalize.mlir @@ -901,6 +901,132 @@ func.func @scope_merge_without_terminator() { // ----- +// Check that we simplify extract_strided_metadata of cast +// when the source of the cast is compatible with what +// `extract_strided_metadata`s accept. +// +// When we apply the transformation the resulting offset, sizes and strides +// should come straight from the inputs of the cast. +// Additionally the folder on extract_strided_metadata should propagate the +// static information. +// +// CHECK-LABEL: func @extract_strided_metadata_of_cast +// CHECK-SAME: %[[ARG:.*]]: memref<3x?xi32, strided<[4, ?], offset: ?>>) +// +// CHECK-DAG: %[[C3:.*]] = arith.constant 3 : index +// CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index +// CHECK: %[[BASE:.*]], %[[DYN_OFFSET:.*]], %[[DYN_SIZES:.*]]:2, %[[DYN_STRIDES:.*]]:2 = memref.extract_strided_metadata %[[ARG]] +// +// CHECK: return %[[BASE]], %[[DYN_OFFSET]], %[[C3]], %[[DYN_SIZES]]#1, %[[C4]], %[[DYN_STRIDES]]#1 +func.func @extract_strided_metadata_of_cast( + %arg : memref<3x?xi32, strided<[4, ?], offset:?>>) + -> (memref, index, + index, index, + index, index) { + + %cast = + memref.cast %arg : + memref<3x?xi32, strided<[4, ?], offset: ?>> to + memref> + + %base, %base_offset, %sizes:2, %strides:2 = + memref.extract_strided_metadata %cast:memref> + -> memref, index, + index, index, + index, index + + return %base, %base_offset, + %sizes#0, %sizes#1, + %strides#0, %strides#1 : + memref, index, + index, index, + index, index +} + +// ----- + +// Check that we simplify extract_strided_metadata of cast +// when the source of the cast is compatible with what +// `extract_strided_metadata`s accept. +// +// Same as extract_strided_metadata_of_cast but with constant sizes and strides +// in the destination type. +// +// CHECK-LABEL: func @extract_strided_metadata_of_cast_w_csts +// CHECK-SAME: %[[ARG:.*]]: memref>) +// +// CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index +// CHECK-DAG: %[[C18:.*]] = arith.constant 18 : index +// CHECK-DAG: %[[C25:.*]] = arith.constant 25 : index +// CHECK: %[[BASE:.*]], %[[DYN_OFFSET:.*]], %[[DYN_SIZES:.*]]:2, %[[DYN_STRIDES:.*]]:2 = memref.extract_strided_metadata %[[ARG]] +// +// CHECK: return %[[BASE]], %[[C25]], %[[C4]], %[[DYN_SIZES]]#1, %[[DYN_STRIDES]]#0, %[[C18]] +func.func @extract_strided_metadata_of_cast_w_csts( + %arg : memref>) + -> (memref, index, + index, index, + index, index) { + + %cast = + memref.cast %arg : + memref> to + memref<4x?xi32, strided<[?, 18], offset: 25>> + + %base, %base_offset, %sizes:2, %strides:2 = + memref.extract_strided_metadata %cast:memref<4x?xi32, strided<[?, 18], offset: 25>> + -> memref, index, + index, index, + index, index + + return %base, %base_offset, + %sizes#0, %sizes#1, + %strides#0, %strides#1 : + memref, index, + index, index, + index, index +} + +// ----- + +// Check that we don't simplify extract_strided_metadata of +// cast when the source of the cast is unranked. +// Unranked memrefs cannot feed into extract_strided_metadata operations. +// Note: Technically we could still fold the sizes and strides. +// +// CHECK-LABEL: func @extract_strided_metadata_of_cast_unranked +// CHECK-SAME: %[[ARG:.*]]: memref<*xi32>) +// +// CHECK: %[[CAST:.*]] = memref.cast %[[ARG]] : +// CHECK: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:2, %[[STRIDES:.*]]:2 = memref.extract_strided_metadata %[[CAST]] +// +// CHECK: return %[[BASE]], %[[OFFSET]], %[[SIZES]]#0, %[[SIZES]]#1, %[[STRIDES]]#0, %[[STRIDES]]#1 +func.func @extract_strided_metadata_of_cast_unranked( + %arg : memref<*xi32>) + -> (memref, index, + index, index, + index, index) { + + %cast = + memref.cast %arg : + memref<*xi32> to + memref> + + %base, %base_offset, %sizes:2, %strides:2 = + memref.extract_strided_metadata %cast:memref> + -> memref, index, + index, index, + index, index + + return %base, %base_offset, + %sizes#0, %sizes#1, + %strides#0, %strides#1 : + memref, index, + index, index, + index, index +} + +// ----- + // CHECK-LABEL: func @reinterpret_noop // CHECK-SAME: (%[[ARG:.*]]: memref<2x3x4xf32>) // CHECK-NEXT: return %[[ARG]] diff --git a/mlir/test/Dialect/MemRef/expand-strided-metadata.mlir b/mlir/test/Dialect/MemRef/expand-strided-metadata.mlir index 1e6b0111fa4c..18cdfb73f6ba 100644 --- a/mlir/test/Dialect/MemRef/expand-strided-metadata.mlir +++ b/mlir/test/Dialect/MemRef/expand-strided-metadata.mlir @@ -1376,133 +1376,6 @@ func.func @extract_strided_metadata_of_get_global_with_offset() memref, index, index, index, index, index } -// ----- - -// Check that we simplify extract_strided_metadata of cast -// when the source of the cast is compatible with what -// `extract_strided_metadata`s accept. -// -// When we apply the transformation the resulting offset, sizes and strides -// should come straight from the inputs of the cast. -// Additionally the folder on extract_strided_metadata should propagate the -// static information. -// -// CHECK-LABEL: func @extract_strided_metadata_of_cast -// CHECK-SAME: %[[ARG:.*]]: memref<3x?xi32, strided<[4, ?], offset: ?>>) -// -// CHECK-DAG: %[[C3:.*]] = arith.constant 3 : index -// CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index -// CHECK: %[[BASE:.*]], %[[DYN_OFFSET:.*]], %[[DYN_SIZES:.*]]:2, %[[DYN_STRIDES:.*]]:2 = memref.extract_strided_metadata %[[ARG]] -// -// CHECK: return %[[BASE]], %[[DYN_OFFSET]], %[[C3]], %[[DYN_SIZES]]#1, %[[C4]], %[[DYN_STRIDES]]#1 -func.func @extract_strided_metadata_of_cast( - %arg : memref<3x?xi32, strided<[4, ?], offset:?>>) - -> (memref, index, - index, index, - index, index) { - - %cast = - memref.cast %arg : - memref<3x?xi32, strided<[4, ?], offset: ?>> to - memref> - - %base, %base_offset, %sizes:2, %strides:2 = - memref.extract_strided_metadata %cast:memref> - -> memref, index, - index, index, - index, index - - return %base, %base_offset, - %sizes#0, %sizes#1, - %strides#0, %strides#1 : - memref, index, - index, index, - index, index -} - -// ----- - -// Check that we simplify extract_strided_metadata of cast -// when the source of the cast is compatible with what -// `extract_strided_metadata`s accept. -// -// Same as extract_strided_metadata_of_cast but with constant sizes and strides -// in the destination type. -// -// CHECK-LABEL: func @extract_strided_metadata_of_cast_w_csts -// CHECK-SAME: %[[ARG:.*]]: memref>) -// -// CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index -// CHECK-DAG: %[[C18:.*]] = arith.constant 18 : index -// CHECK-DAG: %[[C25:.*]] = arith.constant 25 : index -// CHECK: %[[BASE:.*]], %[[DYN_OFFSET:.*]], %[[DYN_SIZES:.*]]:2, %[[DYN_STRIDES:.*]]:2 = memref.extract_strided_metadata %[[ARG]] -// -// CHECK: return %[[BASE]], %[[C25]], %[[C4]], %[[DYN_SIZES]]#1, %[[DYN_STRIDES]]#0, %[[C18]] -func.func @extract_strided_metadata_of_cast_w_csts( - %arg : memref>) - -> (memref, index, - index, index, - index, index) { - - %cast = - memref.cast %arg : - memref> to - memref<4x?xi32, strided<[?, 18], offset: 25>> - - %base, %base_offset, %sizes:2, %strides:2 = - memref.extract_strided_metadata %cast:memref<4x?xi32, strided<[?, 18], offset: 25>> - -> memref, index, - index, index, - index, index - - return %base, %base_offset, - %sizes#0, %sizes#1, - %strides#0, %strides#1 : - memref, index, - index, index, - index, index -} - -// ----- - -// Check that we don't simplify extract_strided_metadata of -// cast when the source of the cast is unranked. -// Unranked memrefs cannot feed into extract_strided_metadata operations. -// Note: Technically we could still fold the sizes and strides. -// -// CHECK-LABEL: func @extract_strided_metadata_of_cast_unranked -// CHECK-SAME: %[[ARG:.*]]: memref<*xi32>) -// -// CHECK: %[[CAST:.*]] = memref.cast %[[ARG]] : -// CHECK: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:2, %[[STRIDES:.*]]:2 = memref.extract_strided_metadata %[[CAST]] -// -// CHECK: return %[[BASE]], %[[OFFSET]], %[[SIZES]]#0, %[[SIZES]]#1, %[[STRIDES]]#0, %[[STRIDES]]#1 -func.func @extract_strided_metadata_of_cast_unranked( - %arg : memref<*xi32>) - -> (memref, index, - index, index, - index, index) { - - %cast = - memref.cast %arg : - memref<*xi32> to - memref> - - %base, %base_offset, %sizes:2, %strides:2 = - memref.extract_strided_metadata %cast:memref> - -> memref, index, - index, index, - index, index - - return %base, %base_offset, - %sizes#0, %sizes#1, - %strides#0, %strides#1 : - memref, index, - index, index, - index, index -} - - // ----- memref.global "private" @dynamicShmem : memref<0xf16,3>