[mlir][MemRef] Move transform related functions in Transforms.h

NFC
This commit is contained in:
Quentin Colombet
2023-03-28 15:18:09 +02:00
parent 1fa6fc3e0f
commit faafd26c4d
10 changed files with 89 additions and 82 deletions

View File

@@ -19,11 +19,6 @@ namespace mlir {
class AffineDialect;
class ModuleOp;
class RewriterBase;
namespace arith {
class WideIntEmulationConverter;
} // namespace arith
namespace func {
class FuncDialect;
@@ -36,82 +31,6 @@ class VectorDialect;
} // namespace vector
namespace memref {
class AllocOp;
//===----------------------------------------------------------------------===//
// Patterns
//===----------------------------------------------------------------------===//
/// Collects a set of patterns to rewrite ops within the memref dialect.
void populateExpandOpsPatterns(RewritePatternSet &patterns);
/// Appends patterns for folding memref aliasing ops into consumer load/store
/// ops into `patterns`.
void populateFoldMemRefAliasOpPatterns(RewritePatternSet &patterns);
/// Appends patterns that resolve `memref.dim` operations with values that are
/// defined by operations that implement the
/// `ReifyRankedShapeTypeShapeOpInterface`, in terms of shapes of its input
/// operands.
void populateResolveRankedShapeTypeResultDimsPatterns(
RewritePatternSet &patterns);
/// Appends patterns that resolve `memref.dim` operations with values that are
/// defined by operations that implement the `InferShapedTypeOpInterface`, in
/// terms of shapes of its input operands.
void populateResolveShapedTypeResultDimsPatterns(RewritePatternSet &patterns);
/// Appends patterns for expanding memref operations that modify the metadata
/// (sizes, offset, strides) of a memref into easier to analyze constructs.
void populateExpandStridedMetadataPatterns(RewritePatternSet &patterns);
/// Appends patterns for emulating wide integer memref operations with ops over
/// narrower integer types.
void populateMemRefWideIntEmulationPatterns(
arith::WideIntEmulationConverter &typeConverter,
RewritePatternSet &patterns);
/// Appends type converions for emulating wide integer memref operations with
/// ops over narrowe integer types.
void populateMemRefWideIntEmulationConversions(
arith::WideIntEmulationConverter &typeConverter);
/// Transformation to do multi-buffering/array expansion to remove dependencies
/// on the temporary allocation between consecutive loop iterations.
/// It returns the new allocation if the original allocation was multi-buffered
/// and returns failure() otherwise.
/// When `skipOverrideAnalysis`, the pass will apply the transformation
/// without checking thwt the buffer is overrided at the beginning of each
/// iteration. This implies that user knows that there is no data carried across
/// loop iterations. Example:
/// ```
/// %0 = memref.alloc() : memref<4x128xf32>
/// scf.for %iv = %c1 to %c1024 step %c3 {
/// memref.copy %1, %0 : memref<4x128xf32> to memref<4x128xf32>
/// "some_use"(%0) : (memref<4x128xf32>) -> ()
/// }
/// ```
/// into:
/// ```
/// %0 = memref.alloc() : memref<5x4x128xf32>
/// scf.for %iv = %c1 to %c1024 step %c3 {
/// %s = arith.subi %iv, %c1 : index
/// %d = arith.divsi %s, %c3 : index
/// %i = arith.remsi %d, %c5 : index
/// %sv = memref.subview %0[%i, 0, 0] [1, 4, 128] [1, 1, 1] :
/// memref<5x4x128xf32> to memref<4x128xf32, strided<[128, 1], offset: ?>>
/// memref.copy %1, %sv : memref<4x128xf32> to memref<4x128xf32, strided<...>>
/// "some_use"(%sv) : (memref<4x128xf32, strided<...>) -> ()
/// }
/// ```
FailureOr<memref::AllocOp> multiBuffer(RewriterBase &rewriter,
memref::AllocOp allocOp,
unsigned multiplier,
bool skipOverrideAnalysis = false);
/// Call into `multiBuffer` with locally constructed IRRewriter.
FailureOr<memref::AllocOp> multiBuffer(memref::AllocOp allocOp,
unsigned multiplier,
bool skipOverrideAnalysis = false);
//===----------------------------------------------------------------------===//
// Passes
//===----------------------------------------------------------------------===//

View File

@@ -16,8 +16,89 @@
namespace mlir {
class RewritePatternSet;
class RewriterBase;
namespace arith {
class WideIntEmulationConverter;
} // namespace arith
namespace memref {
class AllocOp;
//===----------------------------------------------------------------------===//
// Patterns
//===----------------------------------------------------------------------===//
/// Collects a set of patterns to rewrite ops within the memref dialect.
void populateExpandOpsPatterns(RewritePatternSet &patterns);
/// Appends patterns for folding memref aliasing ops into consumer load/store
/// ops into `patterns`.
void populateFoldMemRefAliasOpPatterns(RewritePatternSet &patterns);
/// Appends patterns that resolve `memref.dim` operations with values that are
/// defined by operations that implement the
/// `ReifyRankedShapeTypeShapeOpInterface`, in terms of shapes of its input
/// operands.
void populateResolveRankedShapeTypeResultDimsPatterns(
RewritePatternSet &patterns);
/// Appends patterns that resolve `memref.dim` operations with values that are
/// defined by operations that implement the `InferShapedTypeOpInterface`, in
/// terms of shapes of its input operands.
void populateResolveShapedTypeResultDimsPatterns(RewritePatternSet &patterns);
/// Appends patterns for expanding memref operations that modify the metadata
/// (sizes, offset, strides) of a memref into easier to analyze constructs.
void populateExpandStridedMetadataPatterns(RewritePatternSet &patterns);
/// Appends patterns for emulating wide integer memref operations with ops over
/// narrower integer types.
void populateMemRefWideIntEmulationPatterns(
arith::WideIntEmulationConverter &typeConverter,
RewritePatternSet &patterns);
/// Appends type converions for emulating wide integer memref operations with
/// ops over narrowe integer types.
void populateMemRefWideIntEmulationConversions(
arith::WideIntEmulationConverter &typeConverter);
/// Transformation to do multi-buffering/array expansion to remove dependencies
/// on the temporary allocation between consecutive loop iterations.
/// It returns the new allocation if the original allocation was multi-buffered
/// and returns failure() otherwise.
/// When `skipOverrideAnalysis`, the pass will apply the transformation
/// without checking thwt the buffer is overrided at the beginning of each
/// iteration. This implies that user knows that there is no data carried across
/// loop iterations. Example:
/// ```
/// %0 = memref.alloc() : memref<4x128xf32>
/// scf.for %iv = %c1 to %c1024 step %c3 {
/// memref.copy %1, %0 : memref<4x128xf32> to memref<4x128xf32>
/// "some_use"(%0) : (memref<4x128xf32>) -> ()
/// }
/// ```
/// into:
/// ```
/// %0 = memref.alloc() : memref<5x4x128xf32>
/// scf.for %iv = %c1 to %c1024 step %c3 {
/// %s = arith.subi %iv, %c1 : index
/// %d = arith.divsi %s, %c3 : index
/// %i = arith.remsi %d, %c5 : index
/// %sv = memref.subview %0[%i, 0, 0] [1, 4, 128] [1, 1, 1] :
/// memref<5x4x128xf32> to memref<4x128xf32, strided<[128, 1], offset: ?>>
/// memref.copy %1, %sv : memref<4x128xf32> to memref<4x128xf32, strided<...>>
/// "some_use"(%sv) : (memref<4x128xf32, strided<...>) -> ()
/// }
/// ```
FailureOr<memref::AllocOp> multiBuffer(RewriterBase &rewriter,
memref::AllocOp allocOp,
unsigned multiplier,
bool skipOverrideAnalysis = false);
/// Call into `multiBuffer` with locally constructed IRRewriter.
FailureOr<memref::AllocOp> multiBuffer(memref::AllocOp allocOp,
unsigned multiplier,
bool skipOverrideAnalysis = false);
/// Appends patterns for extracting address computations from the instructions
/// with memory accesses such that these memory accesses use only a base
/// pointer.

View File

@@ -19,7 +19,7 @@
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/MemRef/Transforms/Passes.h"
#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
#include "mlir/Dialect/Tensor/Utils/Utils.h"

View File

@@ -11,6 +11,7 @@
#include "mlir/Dialect/Arith/Transforms/WideIntEmulationConverter.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/MemRef/Transforms/Passes.h"
#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/Support/FormatVariadic.h"

View File

@@ -17,6 +17,7 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Transforms/Passes.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Transforms/DialectConversion.h"

View File

@@ -17,6 +17,7 @@
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/MemRef/Transforms/Passes.h"
#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

View File

@@ -18,6 +18,7 @@
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/MemRef/Transforms/Passes.h"
#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/AffineMap.h"

View File

@@ -14,6 +14,7 @@
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/MemRef/Transforms/Passes.h"
#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/Dominance.h"

View File

@@ -17,6 +17,7 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

View File

@@ -8,6 +8,7 @@
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/MemRef/Transforms/Passes.h"
#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Pass/Pass.h"