mirror of
https://github.com/intel/llvm.git
synced 2026-02-08 17:07:06 +08:00
[mlir][MemRef] Move transform related functions in Transforms.h
NFC
This commit is contained in:
@@ -19,11 +19,6 @@ namespace mlir {
|
||||
|
||||
class AffineDialect;
|
||||
class ModuleOp;
|
||||
class RewriterBase;
|
||||
|
||||
namespace arith {
|
||||
class WideIntEmulationConverter;
|
||||
} // namespace arith
|
||||
|
||||
namespace func {
|
||||
class FuncDialect;
|
||||
@@ -36,82 +31,6 @@ class VectorDialect;
|
||||
} // namespace vector
|
||||
|
||||
namespace memref {
|
||||
class AllocOp;
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Patterns
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
/// Collects a set of patterns to rewrite ops within the memref dialect.
|
||||
void populateExpandOpsPatterns(RewritePatternSet &patterns);
|
||||
|
||||
/// Appends patterns for folding memref aliasing ops into consumer load/store
|
||||
/// ops into `patterns`.
|
||||
void populateFoldMemRefAliasOpPatterns(RewritePatternSet &patterns);
|
||||
|
||||
/// Appends patterns that resolve `memref.dim` operations with values that are
|
||||
/// defined by operations that implement the
|
||||
/// `ReifyRankedShapeTypeShapeOpInterface`, in terms of shapes of its input
|
||||
/// operands.
|
||||
void populateResolveRankedShapeTypeResultDimsPatterns(
|
||||
RewritePatternSet &patterns);
|
||||
|
||||
/// Appends patterns that resolve `memref.dim` operations with values that are
|
||||
/// defined by operations that implement the `InferShapedTypeOpInterface`, in
|
||||
/// terms of shapes of its input operands.
|
||||
void populateResolveShapedTypeResultDimsPatterns(RewritePatternSet &patterns);
|
||||
|
||||
/// Appends patterns for expanding memref operations that modify the metadata
|
||||
/// (sizes, offset, strides) of a memref into easier to analyze constructs.
|
||||
void populateExpandStridedMetadataPatterns(RewritePatternSet &patterns);
|
||||
|
||||
/// Appends patterns for emulating wide integer memref operations with ops over
|
||||
/// narrower integer types.
|
||||
void populateMemRefWideIntEmulationPatterns(
|
||||
arith::WideIntEmulationConverter &typeConverter,
|
||||
RewritePatternSet &patterns);
|
||||
|
||||
/// Appends type converions for emulating wide integer memref operations with
|
||||
/// ops over narrowe integer types.
|
||||
void populateMemRefWideIntEmulationConversions(
|
||||
arith::WideIntEmulationConverter &typeConverter);
|
||||
|
||||
/// Transformation to do multi-buffering/array expansion to remove dependencies
|
||||
/// on the temporary allocation between consecutive loop iterations.
|
||||
/// It returns the new allocation if the original allocation was multi-buffered
|
||||
/// and returns failure() otherwise.
|
||||
/// When `skipOverrideAnalysis`, the pass will apply the transformation
|
||||
/// without checking thwt the buffer is overrided at the beginning of each
|
||||
/// iteration. This implies that user knows that there is no data carried across
|
||||
/// loop iterations. Example:
|
||||
/// ```
|
||||
/// %0 = memref.alloc() : memref<4x128xf32>
|
||||
/// scf.for %iv = %c1 to %c1024 step %c3 {
|
||||
/// memref.copy %1, %0 : memref<4x128xf32> to memref<4x128xf32>
|
||||
/// "some_use"(%0) : (memref<4x128xf32>) -> ()
|
||||
/// }
|
||||
/// ```
|
||||
/// into:
|
||||
/// ```
|
||||
/// %0 = memref.alloc() : memref<5x4x128xf32>
|
||||
/// scf.for %iv = %c1 to %c1024 step %c3 {
|
||||
/// %s = arith.subi %iv, %c1 : index
|
||||
/// %d = arith.divsi %s, %c3 : index
|
||||
/// %i = arith.remsi %d, %c5 : index
|
||||
/// %sv = memref.subview %0[%i, 0, 0] [1, 4, 128] [1, 1, 1] :
|
||||
/// memref<5x4x128xf32> to memref<4x128xf32, strided<[128, 1], offset: ?>>
|
||||
/// memref.copy %1, %sv : memref<4x128xf32> to memref<4x128xf32, strided<...>>
|
||||
/// "some_use"(%sv) : (memref<4x128xf32, strided<...>) -> ()
|
||||
/// }
|
||||
/// ```
|
||||
FailureOr<memref::AllocOp> multiBuffer(RewriterBase &rewriter,
|
||||
memref::AllocOp allocOp,
|
||||
unsigned multiplier,
|
||||
bool skipOverrideAnalysis = false);
|
||||
/// Call into `multiBuffer` with locally constructed IRRewriter.
|
||||
FailureOr<memref::AllocOp> multiBuffer(memref::AllocOp allocOp,
|
||||
unsigned multiplier,
|
||||
bool skipOverrideAnalysis = false);
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Passes
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
@@ -16,8 +16,89 @@
|
||||
|
||||
namespace mlir {
|
||||
class RewritePatternSet;
|
||||
class RewriterBase;
|
||||
|
||||
namespace arith {
|
||||
class WideIntEmulationConverter;
|
||||
} // namespace arith
|
||||
|
||||
namespace memref {
|
||||
class AllocOp;
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Patterns
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
/// Collects a set of patterns to rewrite ops within the memref dialect.
|
||||
void populateExpandOpsPatterns(RewritePatternSet &patterns);
|
||||
|
||||
/// Appends patterns for folding memref aliasing ops into consumer load/store
|
||||
/// ops into `patterns`.
|
||||
void populateFoldMemRefAliasOpPatterns(RewritePatternSet &patterns);
|
||||
|
||||
/// Appends patterns that resolve `memref.dim` operations with values that are
|
||||
/// defined by operations that implement the
|
||||
/// `ReifyRankedShapeTypeShapeOpInterface`, in terms of shapes of its input
|
||||
/// operands.
|
||||
void populateResolveRankedShapeTypeResultDimsPatterns(
|
||||
RewritePatternSet &patterns);
|
||||
|
||||
/// Appends patterns that resolve `memref.dim` operations with values that are
|
||||
/// defined by operations that implement the `InferShapedTypeOpInterface`, in
|
||||
/// terms of shapes of its input operands.
|
||||
void populateResolveShapedTypeResultDimsPatterns(RewritePatternSet &patterns);
|
||||
|
||||
/// Appends patterns for expanding memref operations that modify the metadata
|
||||
/// (sizes, offset, strides) of a memref into easier to analyze constructs.
|
||||
void populateExpandStridedMetadataPatterns(RewritePatternSet &patterns);
|
||||
|
||||
/// Appends patterns for emulating wide integer memref operations with ops over
|
||||
/// narrower integer types.
|
||||
void populateMemRefWideIntEmulationPatterns(
|
||||
arith::WideIntEmulationConverter &typeConverter,
|
||||
RewritePatternSet &patterns);
|
||||
|
||||
/// Appends type converions for emulating wide integer memref operations with
|
||||
/// ops over narrowe integer types.
|
||||
void populateMemRefWideIntEmulationConversions(
|
||||
arith::WideIntEmulationConverter &typeConverter);
|
||||
|
||||
/// Transformation to do multi-buffering/array expansion to remove dependencies
|
||||
/// on the temporary allocation between consecutive loop iterations.
|
||||
/// It returns the new allocation if the original allocation was multi-buffered
|
||||
/// and returns failure() otherwise.
|
||||
/// When `skipOverrideAnalysis`, the pass will apply the transformation
|
||||
/// without checking thwt the buffer is overrided at the beginning of each
|
||||
/// iteration. This implies that user knows that there is no data carried across
|
||||
/// loop iterations. Example:
|
||||
/// ```
|
||||
/// %0 = memref.alloc() : memref<4x128xf32>
|
||||
/// scf.for %iv = %c1 to %c1024 step %c3 {
|
||||
/// memref.copy %1, %0 : memref<4x128xf32> to memref<4x128xf32>
|
||||
/// "some_use"(%0) : (memref<4x128xf32>) -> ()
|
||||
/// }
|
||||
/// ```
|
||||
/// into:
|
||||
/// ```
|
||||
/// %0 = memref.alloc() : memref<5x4x128xf32>
|
||||
/// scf.for %iv = %c1 to %c1024 step %c3 {
|
||||
/// %s = arith.subi %iv, %c1 : index
|
||||
/// %d = arith.divsi %s, %c3 : index
|
||||
/// %i = arith.remsi %d, %c5 : index
|
||||
/// %sv = memref.subview %0[%i, 0, 0] [1, 4, 128] [1, 1, 1] :
|
||||
/// memref<5x4x128xf32> to memref<4x128xf32, strided<[128, 1], offset: ?>>
|
||||
/// memref.copy %1, %sv : memref<4x128xf32> to memref<4x128xf32, strided<...>>
|
||||
/// "some_use"(%sv) : (memref<4x128xf32, strided<...>) -> ()
|
||||
/// }
|
||||
/// ```
|
||||
FailureOr<memref::AllocOp> multiBuffer(RewriterBase &rewriter,
|
||||
memref::AllocOp allocOp,
|
||||
unsigned multiplier,
|
||||
bool skipOverrideAnalysis = false);
|
||||
/// Call into `multiBuffer` with locally constructed IRRewriter.
|
||||
FailureOr<memref::AllocOp> multiBuffer(memref::AllocOp allocOp,
|
||||
unsigned multiplier,
|
||||
bool skipOverrideAnalysis = false);
|
||||
|
||||
/// Appends patterns for extracting address computations from the instructions
|
||||
/// with memory accesses such that these memory accesses use only a base
|
||||
/// pointer.
|
||||
|
||||
@@ -19,7 +19,7 @@
|
||||
#include "mlir/Dialect/Linalg/IR/Linalg.h"
|
||||
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
|
||||
#include "mlir/Dialect/Linalg/Utils/Utils.h"
|
||||
#include "mlir/Dialect/MemRef/Transforms/Passes.h"
|
||||
#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
|
||||
#include "mlir/Dialect/Tensor/Utils/Utils.h"
|
||||
|
||||
@@ -11,6 +11,7 @@
|
||||
#include "mlir/Dialect/Arith/Transforms/WideIntEmulationConverter.h"
|
||||
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
||||
#include "mlir/Dialect/MemRef/Transforms/Passes.h"
|
||||
#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
|
||||
#include "mlir/Dialect/Vector/IR/VectorOps.h"
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
#include "llvm/Support/FormatVariadic.h"
|
||||
|
||||
@@ -17,6 +17,7 @@
|
||||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||
#include "mlir/Dialect/Arith/Transforms/Passes.h"
|
||||
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
||||
#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
|
||||
#include "mlir/IR/TypeUtilities.h"
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
|
||||
|
||||
@@ -17,6 +17,7 @@
|
||||
#include "mlir/Dialect/Arith/Utils/Utils.h"
|
||||
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
||||
#include "mlir/Dialect/MemRef/Transforms/Passes.h"
|
||||
#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
|
||||
#include "mlir/IR/AffineMap.h"
|
||||
#include "mlir/IR/BuiltinTypes.h"
|
||||
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
||||
|
||||
@@ -18,6 +18,7 @@
|
||||
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
|
||||
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
||||
#include "mlir/Dialect/MemRef/Transforms/Passes.h"
|
||||
#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
|
||||
#include "mlir/Dialect/Utils/IndexingUtils.h"
|
||||
#include "mlir/Dialect/Vector/IR/VectorOps.h"
|
||||
#include "mlir/IR/AffineMap.h"
|
||||
|
||||
@@ -14,6 +14,7 @@
|
||||
#include "mlir/Dialect/Arith/Utils/Utils.h"
|
||||
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
||||
#include "mlir/Dialect/MemRef/Transforms/Passes.h"
|
||||
#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
|
||||
#include "mlir/IR/AffineExpr.h"
|
||||
#include "mlir/IR/BuiltinAttributes.h"
|
||||
#include "mlir/IR/Dominance.h"
|
||||
|
||||
@@ -17,6 +17,7 @@
|
||||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||
#include "mlir/Dialect/Arith/Utils/Utils.h"
|
||||
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
||||
#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "mlir/Interfaces/InferTypeOpInterface.h"
|
||||
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
||||
|
||||
@@ -8,6 +8,7 @@
|
||||
|
||||
#include "mlir/Dialect/Affine/IR/AffineOps.h"
|
||||
#include "mlir/Dialect/MemRef/Transforms/Passes.h"
|
||||
#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
|
||||
|
||||
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
|
||||
Reference in New Issue
Block a user