Mirror of https://github.com/intel/llvm.git (synced 2026-01-20 10:18:14 +08:00)
[mlir][TilingInterface] Add pattern to tile using TilingInterface and implement TilingInterface for Linalg ops.
This patch adds support for tiling operations that implement the TilingInterface.

- It separates the loop constructs used to iterate over tiles from the implementation of the tiling itself. For example, the use of destructive updates is tied to the use of scf.for for iterating over tiles that are tensors.
- To test the transformation, the TilingInterface is implemented for Linalg ops. Separating the looping constructs from the implementation of the tile code generation greatly simplifies the latter.
- The implementation of the TilingInterface for LinalgOp is kept as an external model for now, until this approach can be fully fleshed out to replace the existing tiling + fusion approaches in Linalg.

Differential Revision: https://reviews.llvm.org/D127133
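As a quick orientation before the diff, here is a minimal sketch of how a pass might drive the new tiling path, modeled on the test pass added in this patch. The wrapper function name `tileMatmulLikeOps` is illustrative only, and the Linalg external models are assumed to have been registered on the context (see `registerTilingInterfaceExternalModels`).

```cpp
// Hedged sketch (not part of this patch's diff): tile TilingInterface ops in a
// function using the new scf.for-based pattern.
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/SCF/TileUsingInterface.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

using namespace mlir;

// Illustrative helper; in this patch the equivalent logic lives in the
// TestTilingInterface test pass.
static LogicalResult tileMatmulLikeOps(func::FuncOp funcOp) {
  MLIRContext *context = funcOp.getContext();
  // Tile the first two loops by 10 and 20; remaining loops keep tile size 0,
  // i.e. they are left untiled by convention.
  scf::SCFTilingOptions options;
  options.setTileSizes({10, 20});
  RewritePatternSet patterns(context);
  patterns.add<scf::TileUsingSCFForOp>(context, options);
  return applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
}
```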
@@ -0,0 +1,20 @@
//===- TilingInterfaceImpl.h - Implementation of TilingInterface ----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_DIALECT_LINALG_TILINGINTERFACEIMPL_H
#define MLIR_DIALECT_LINALG_TILINGINTERFACEIMPL_H

namespace mlir {
class DialectRegistry;

namespace linalg {
void registerTilingInterfaceExternalModels(DialectRegistry &registry);
} // namespace linalg
} // namespace mlir

#endif // MLIR_DIALECT_LINALG_TILINGINTERFACEIMPL_H
@@ -164,11 +164,11 @@ bool isFusableInto(const LinalgDependenceGraph &graph, LinalgOp consumer,
SmallVector<Value> computeTileOffsets(OpBuilder &b, Location loc,
                                      ValueRange ivs, ValueRange tileSizes);

-/// Compute tile sizes, given a list of loop `ivs`, `tileSizes` and dimension
+/// Compute tile sizes, given a list of `tileSizes` and dimension
/// sizes (`sizeBounds`). In case a tile size is zero (i.e., no tiling), the
/// corresponding result size is the corresponding value from `sizeBounds`.
/// Note: The returned tile sizes are closed intervals.
-SmallVector<Value> computeTileSizes(OpBuilder &b, Location loc, ValueRange ivs,
+SmallVector<Value> computeTileSizes(OpBuilder &b, Location loc,
                                    ValueRange tileSizes,
                                    ArrayRef<Value> sizeBounds);

mlir/include/mlir/Dialect/SCF/TileUsingInterface.h (new file, 87 lines)
@@ -0,0 +1,87 @@
//===- TileUsingInterface.h - Tiling ops using TilingInterface --*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_DIALECT_SCF_TILEUSINGINTERFACE_H
#define MLIR_DIALECT_SCF_TILEUSINGINTERFACE_H

#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Interfaces/TilingInterface.h"

namespace mlir {
class Operation;
class PatternRewriter;
class TilingInterface;
} // namespace mlir

namespace mlir {
namespace scf {

using SCFTileSizeComputationFunction =
    std::function<SmallVector<Value, 4>(OpBuilder &, Operation *)>;

/// Options to use to control tiling.
struct SCFTilingOptions {
  /// Computation function that returns the tile sizes for each operation.
  /// Delayed construction of constant tile sizes should occur to interoperate
  /// with folding.
  SCFTileSizeComputationFunction tileSizeComputationFunction = nullptr;

  SCFTilingOptions &
  setTileSizeComputationFunction(SCFTileSizeComputationFunction fun) {
    tileSizeComputationFunction = std::move(fun);
    return *this;
  }
  /// Set the `tileSizeComputationFunction` to return the values `ts`. The
  /// values must not fold away when tiling. Otherwise, use a more robust
  /// `tileSizeComputationFunction`.
  SCFTilingOptions &setTileSizes(const SmallVector<Value, 4> &ts) {
    tileSizeComputationFunction = [=](OpBuilder &, Operation *) { return ts; };
    return *this;
  }
  /// Convenience function to set the `tileSizeComputationFunction` to a
  /// function that computes tile sizes at the point they are needed. Allows
  /// proper interaction with folding.
  SCFTilingOptions &setTileSizes(ArrayRef<int64_t> ts);
};

struct SCFTilingResult {
  Operation *tiledOp;
  SmallVector<scf::ForOp> loops;
};

/// Pattern to tile an op that implements the `TilingInterface` using
/// `scf.for` for iterating over the tiles.
struct TileUsingSCFForOp : public OpInterfaceRewritePattern<TilingInterface> {
  /// Construct a generic pattern applied to all TilingInterface ops.
  TileUsingSCFForOp(MLIRContext *context, SCFTilingOptions options,
                    PatternBenefit benefit = 1);

  /// Construct a generic pattern applied to `opName`.
  TileUsingSCFForOp(StringRef opName, MLIRContext *context,
                    SCFTilingOptions options, PatternBenefit benefit = 1);

  /// `matchAndRewrite` implementation that returns the significant transformed
  /// pieces of IR.
  FailureOr<SCFTilingResult>
  returningMatchAndRewrite(TilingInterface op, PatternRewriter &rewriter) const;

  LogicalResult matchAndRewrite(TilingInterface op,
                                PatternRewriter &rewriter) const override {
    return returningMatchAndRewrite(op, rewriter);
  }

private:
  /// Options to control tiling.
  SCFTilingOptions options;
};

} // namespace scf
} // namespace mlir

#endif // MLIR_DIALECT_SCF_TILEUSINGINTERFACE_H
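For completeness, a hedged sketch of the alternative to the convenience setter above: supplying an `SCFTileSizeComputationFunction` directly so the tile-size constants are only created when the pattern fires, which lets them participate in folding. The sizes 10 and 20 are illustrative.

```cpp
// Illustrative only: build the tile-size constants lazily inside the callback.
scf::SCFTilingOptions options;
options.setTileSizeComputationFunction(
    [](OpBuilder &b, Operation *op) -> SmallVector<Value, 4> {
      Location loc = op->getLoc();
      // Create the tile-size constants at the current insertion point.
      return {b.create<arith::ConstantIndexOp>(loc, 10),
              b.create<arith::ConstantIndexOp>(loc, 20)};
    });
```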
@@ -13,6 +13,7 @@
#ifndef MLIR_DIALECT_SCF_UTILS_UTILS_H_
#define MLIR_DIALECT_SCF_UTILS_UTILS_H_

#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Support/LogicalResult.h"
@@ -32,12 +33,6 @@ class CallOp;
class FuncOp;
} // namespace func

namespace scf {
class IfOp;
class ForOp;
class ParallelOp;
} // namespace scf

/// Replace the `loop` with `newIterOperands` added as new initialization
/// values. `newYieldValuesFn` is a callback that can be used to specify
/// the additional values to be yielded by the loop. The number of
@@ -57,6 +52,25 @@ scf::ForOp replaceLoopWithNewYields(OpBuilder &builder, scf::ForOp loop,
                                    ValueRange newIterOperands,
                                    const NewYieldValueFn &newYieldValuesFn);

/// Update a perfectly nested loop nest to yield new values from the innermost
/// loop and propagate them up through the loop nest. This function
/// - Expects `loopNest` to be a perfectly nested loop nest with the outermost
///   loop first and the innermost loop last.
/// - `newIterOperands` are the initialization values to be used for the
///   outermost loop.
/// - `newYieldValueFn` is the callback that generates the new values to be
///   yielded from within the innermost loop.
/// - The original loops are not erased, but are left in a "no-op" state where
///   the body of the loop just yields the basic block arguments that correspond
///   to the initialization values of a loop. The original loops are dead after
///   this method.
/// - All uses of the `newIterOperands` within the generated new loop
///   are replaced with the corresponding `BlockArgument` in the loop body.
SmallVector<scf::ForOp>
replaceLoopNestWithNewYields(OpBuilder &builder, ArrayRef<scf::ForOp> loopNest,
                             ValueRange newIterOperands,
                             NewYieldValueFn newYieldValueFn);

/// Outline a region with a single block into a new FuncOp.
/// Assumes the FuncOp result types is the type of the yielded operands of the
/// single block. This constraint makes it easy to determine the result.
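A hedged sketch of how a caller can use `replaceLoopNestWithNewYields`; this mirrors how the new tiling pattern threads `tensor.insert_slice` results out of the generated nest. The names `builder`, `loops`, and `init` are assumed to exist in the surrounding code.

```cpp
// Illustrative only: add one extra iter_arg/yield chain through an existing
// perfectly nested loop nest `loops` (outermost loop first), seeded with
// `init` as the initialization value of the outermost loop.
NewYieldValueFn yieldFn =
    [&](OpBuilder &b, Location loc,
        ArrayRef<BlockArgument> newBBArgs) -> SmallVector<Value> {
  // Value to yield from the innermost loop; here simply forward the new
  // block argument unchanged.
  return {newBBArgs[0]};
};
SmallVector<Value> newInits = {init};
SmallVector<scf::ForOp> newLoops =
    replaceLoopNestWithNewYields(builder, loops, newInits, yieldFn);
// The original loops in `loops` are now dead; erase them and use the loops in
// `newLoops`, whose extra results carry the new values.
```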
@@ -98,6 +98,28 @@ def TilingInterface : OpInterface<"TilingInterface"> {
        /*defaultImplementation=*/[{
          return {};
        }]
      >,
      InterfaceMethod<
        /*desc=*/[{
          Method to return the position of the result tile computed by the tiled operation.

          Specifies what tile of the result of the original tensor is computed
          by the tiled implementation. Expects the same `offsets` and `sizes` as
          used to obtain the tiled implementation of the operation.
        }],
        /*retType=*/"LogicalResult",
        /*methodName=*/"getResultTilePosition",
        /*args=*/(ins
            "OpBuilder &":$b,
            "unsigned":$resultNumber,
            "ArrayRef<OpFoldResult> ":$offsets,
            "ArrayRef<OpFoldResult> ":$sizes,
            "SmallVector<OpFoldResult> &":$resultOffsets,
            "SmallVector<OpFoldResult> &":$resultSizes),
        /*methodBody=*/"",
        /*defaultImplementation=*/[{
          return failure();
        }]
      >
  ];
}

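To make the contract of `getResultTilePosition` concrete, here is a hedged sketch of a trivial implementation for an op whose result tile coincides with the iteration-space tile; the Linalg external model later in this patch implements the general, indexing-map-aware version.

```cpp
// Hypothetical implementation: the result tile is exactly the tile of the
// iteration space, so the offsets/sizes can be forwarded unchanged.
LogicalResult getResultTilePosition(OpBuilder &b, unsigned resultNumber,
                                    ArrayRef<OpFoldResult> offsets,
                                    ArrayRef<OpFoldResult> sizes,
                                    SmallVector<OpFoldResult> &resultOffsets,
                                    SmallVector<OpFoldResult> &resultSizes) {
  resultOffsets.assign(offsets.begin(), offsets.end());
  resultSizes.assign(sizes.begin(), sizes.end());
  return success();
}
```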
@@ -24,6 +24,7 @@ add_mlir_dialect_library(MLIRLinalgTransforms
  SparseTensorRewriting.cpp
  SplitReduction.cpp
  Tiling.cpp
+ TilingInterfaceImpl.cpp
  Transforms.cpp
  Vectorization.cpp

@@ -320,8 +320,7 @@ static LogicalResult tilePadOp(RewriterBase &builder, tensor::PadOp op,
        // Compute offsets and sizes of ExtractSliceOp.
        SmallVector<Value> offsets =
            computeTileOffsets(b, loc, localIvs, tileSizes);
-       SmallVector<Value> sizes =
-           computeTileSizes(b, loc, localIvs, tileSizes, allDims);
+       SmallVector<Value> sizes = computeTileSizes(b, loc, tileSizes, allDims);
        // Create ExtractSliceOp: Extract a tile from the tensor::PadOp.
        // Note: The tensor::PadOp is located outside of the loop nest. It is
        // later moved inside by ExtractSliceOfPadTensorSwapPattern.

mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp (new file, 156 lines)
@@ -0,0 +1,156 @@
//===- TilingInterfaceImpl.cpp - Implementation of TilingInterface -------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Linalg/Transforms/TilingInterfaceImpl.h"

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Arithmetic/Utils/Utils.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Interfaces/TilingInterface.h"

using namespace mlir;
using namespace mlir::linalg;

namespace {

/// External model implementation of TilingInterface for LinalgOps. An external
/// model implementation is used for now till the use of `TilingInterface` is
/// on par with the current Linalg tiling + fusion patterns. Once it is, it may
/// be possible to move this into the op definition (though there are
/// advantages to leaving it as an external model).
template <typename LinalgOpTy>
struct LinalgOpTilingInterface
    : public TilingInterface::ExternalModel<LinalgOpTilingInterface<LinalgOpTy>,
                                            LinalgOpTy> {

  /// Return the destination operands.
  SmallVector<Value> getDestinationOperands(Operation *op, OpBuilder &b) const {
    return llvm::cast<LinalgOp>(op).getOutputOperands();
  }

  /// Return the loop iterator type.
  SmallVector<StringRef> getLoopIteratorTypes(Operation *op) const {
    LinalgOpTy concreteOp = cast<LinalgOpTy>(op);
    return llvm::to_vector(
        llvm::map_range(concreteOp.iterator_types(), [](Attribute strAttr) {
          return strAttr.cast<StringAttr>().getValue();
        }));
  }

  /// Return the iteration domain range.
  SmallVector<Range> getIterationDomain(Operation *op, OpBuilder &b) const {
    Location loc = op->getLoc();
    LinalgOp linalgOp = cast<LinalgOp>(op);
    auto allShapesSizes = linalgOp.createFlatListOfOperandDims(b, loc);
    AffineMap map = linalgOp.getShapesToLoopsMap();
    Value zero = b.create<arith::ConstantIndexOp>(loc, 0);
    Value one = b.create<arith::ConstantIndexOp>(loc, 1);
    return llvm::to_vector(llvm::map_range(
        applyMapToValues(b, loc, map, allShapesSizes), [&](Value v) {
          return Range{zero, v, one};
        }));
  }

  // Instantiate the tiled implementation of the operation.
  SmallVector<Operation *>
  getTiledImplementation(Operation *op, OpBuilder &b, ValueRange dest,
                         ArrayRef<OpFoldResult> offsets,
                         ArrayRef<OpFoldResult> sizes,
                         bool tileDestOperands) const {
    // Leave the `sizeBounds` value empty. That is only needed when the `sizes`
    // specified could lead to out of bounds accesses.
    Location loc = op->getLoc();
    LinalgOp linalgOp = cast<LinalgOp>(op);
    SmallVector<Value> valuesToTile = linalgOp.getInputAndOutputOperands();
    SmallVector<Value, 4> tiledOperands = makeTiledShapes(
        b, loc, linalgOp, valuesToTile,
        getValueOrCreateConstantIndexOp(b, loc, offsets),
        getValueOrCreateConstantIndexOp(b, loc, sizes), {}, true);

    SmallVector<Type> resultTensorTypes = llvm::to_vector(llvm::map_range(
        linalgOp.getOutputTensorOperands(), [&](OpOperand *opOperand) {
          return tiledOperands[opOperand->getOperandNumber()].getType();
        }));

    Operation *tiledOp =
        linalgOp.clone(b, loc, resultTensorTypes, tiledOperands);

    return {tiledOp};
  }

  // Return the details of the output tile generated by the tiled
  // implementation.
  LogicalResult
  getResultTilePosition(Operation *op, OpBuilder &b, unsigned resultNumber,
                        ArrayRef<OpFoldResult> offsets,
                        ArrayRef<OpFoldResult> sizes,
                        SmallVector<OpFoldResult> &resultOffsets,
                        SmallVector<OpFoldResult> &resultSizes) const {
    Location loc = op->getLoc();
    LinalgOp linalgOp = cast<LinalgOp>(op);

    AffineExpr d0;
    bindDims(b.getContext(), d0);

    auto fullyComposeAffineMapAndOperands = [](OpBuilder &builder, Location loc,
                                               AffineExpr expr,
                                               ValueRange operands) -> Value {
      AffineMap map = AffineMap::inferFromExprList({expr}).front();
      SmallVector<Value> normalizedOperands(operands.begin(), operands.end());
      mlir::fullyComposeAffineMapAndOperands(&map, &normalizedOperands);
      canonicalizeMapAndOperands(&map, &normalizedOperands);
      return builder.createOrFold<AffineApplyOp>(loc, map, normalizedOperands);
    };

    SmallVector<Value> sizeVals =
        getValueOrCreateConstantIndexOp(b, loc, sizes);
    SmallVector<Value> subShapeSizes =
        llvm::to_vector(llvm::map_range(sizeVals, [&](Value v) {
          return fullyComposeAffineMapAndOperands(b, loc, d0 - 1, v);
        }));
    OpOperand *outOperand = linalgOp.getOutputOperand(resultNumber);
    Value sliceOpResult =
        makeTiledShape(b, loc, outOperand->get(), sizeVals,
                       linalgOp.getTiedIndexingMap(outOperand),
                       getValueOrCreateConstantIndexOp(b, loc, offsets),
                       /*ubs*/ {}, subShapeSizes, true);
    auto sliceOp = sliceOpResult.getDefiningOp<tensor::ExtractSliceOp>();
    if (!sliceOp)
      return failure();
    resultOffsets = sliceOp.getMixedOffsets();
    resultSizes = sliceOp.getMixedSizes();
    return success();
  }
};

} // namespace

template <typename OpType> static void registerOne(MLIRContext *ctx) {
  OpType::template attachInterface<LinalgOpTilingInterface<OpType>>(*ctx);
}

/// Variadic helper function.
template <typename... OpTypes> static void registerAll(MLIRContext *ctx) {
  // FIXME: In c++17 this can be simplified by using 'fold expressions'.
  (void)std::initializer_list<int>{0, (registerOne<OpTypes>(ctx), 0)...};
}

#define GET_OP_LIST

void mlir::linalg::registerTilingInterfaceExternalModels(
    DialectRegistry &registry) {
  registry.addExtension(+[](MLIRContext *ctx, linalg::LinalgDialect *dialect) {
    registerOne<linalg::GenericOp>(ctx);
    registerAll<
#include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"
        >(ctx);
  });
}
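As the FIXME in `registerAll` above notes, the `std::initializer_list` expansion is only needed pre-C++17; a sketch of the fold-expression form it alludes to:

```cpp
// C++17 variant (illustrative, not part of this patch): fold over the comma
// operator to expand registerOne<OpTy>(ctx) once per op type.
template <typename... OpTypes>
static void registerAll(MLIRContext *ctx) {
  (registerOne<OpTypes>(ctx), ...);
}
```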
@@ -893,7 +893,7 @@ SmallVector<Value> computeTileOffsets(OpBuilder &b, Location loc,
  return offsets;
}

-SmallVector<Value> computeTileSizes(OpBuilder &b, Location loc, ValueRange ivs,
+SmallVector<Value> computeTileSizes(OpBuilder &b, Location loc,
                                    ValueRange tileSizes,
                                    ArrayRef<Value> sizeBounds) {
  SmallVector<Value> sizes;
@@ -923,7 +923,7 @@ SmallVector<Value, 4> makeTiledShapes(OpBuilder &b, Location loc,
  // that define tile subshapes.
  SmallVector<Value> lbs = computeTileOffsets(b, loc, ivs, tileSizes);
  SmallVector<Value> subShapeSizes =
-     computeTileSizes(b, loc, ivs, tileSizes, sizeBounds);
+     computeTileSizes(b, loc, tileSizes, sizeBounds);

  assert(static_cast<int64_t>(valuesToTile.size()) ==
             linalgOp.getNumInputsAndOutputs() &&

@@ -10,6 +10,7 @@ add_mlir_dialect_library(MLIRSCFTransforms
  ParallelLoopFusion.cpp
  ParallelLoopTiling.cpp
  StructuralTypeConversions.cpp
+ TileUsingInterface.cpp

  ADDITIONAL_HEADER_DIRS
  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/SCF

mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp (new file, 249 lines)
@@ -0,0 +1,249 @@
//===- Tiling.cpp - Implementation of tiling using TilingInterface -------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements the tiling using TilingInterface.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/SCF/TileUsingInterface.h"

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/SCF/Utils/Utils.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Interfaces/TilingInterface.h"
#include "llvm/Support/Debug.h"

#define DEBUG_TYPE "tile-using-interface"

using namespace mlir;

scf::SCFTilingOptions &
scf::SCFTilingOptions::setTileSizes(ArrayRef<int64_t> ts) {
  assert(!tileSizeComputationFunction && "tile sizes already set");
  SmallVector<int64_t, 4> tileSizes(ts.begin(), ts.end());
  tileSizeComputationFunction = [tileSizes](OpBuilder &b, Operation *op) {
    OpBuilder::InsertionGuard guard(b);
    b.setInsertionPointToStart(
        &op->getParentOfType<func::FuncOp>().getBody().front());
    return llvm::to_vector<4>(map_range(tileSizes, [&](int64_t s) {
      Value v = b.create<arith::ConstantIndexOp>(op->getLoc(), s);
      return v;
    }));
  };
  return *this;
}

/// Generate an empty loop nest that represents the tiled loop nest shell.
/// - `loopRanges` specifies the lb, ub and step of the untiled iteration space.
/// - `tileSizeVals` are the tile sizes to use. Zero represents an untiled loop.
/// - In `offsets` and `sizes` return the multi-dimensional offset and size of
///   the tile processed within the innermost loop.
static SmallVector<scf::ForOp>
generateTileLoopNest(OpBuilder &builder, Location loc,
                     ArrayRef<Range> loopRanges, ArrayRef<Value> tileSizeVals,
                     SmallVector<OpFoldResult> &offsets,
                     SmallVector<OpFoldResult> &sizes) {
  assert(!loopRanges.empty() && "expected at least one loop range");
  assert(loopRanges.size() == tileSizeVals.size() &&
         "expected as many tile sizes as loop ranges");
  OpBuilder::InsertionGuard guard(builder);
  SmallVector<scf::ForOp> loops;
  offsets.resize(loopRanges.size());
  sizes.resize(loopRanges.size());

  // The tile size to use (to avoid out of bounds access) is the minimum of
  // `tileSize` and `ub - iv`, where `iv` is the induction variable
  // of the tiled loop.
  AffineExpr s0, s1, d0;
  bindDims(builder.getContext(), d0);
  bindSymbols(builder.getContext(), s0, s1);
  AffineMap minMap = AffineMap::get(1, 2, {s0, s1 - d0}, builder.getContext());

  for (auto loopRange : llvm::enumerate(loopRanges)) {
    // No loops if tile size is zero. Set offset and size to the loop
    // offset and size.
    if (matchPattern(tileSizeVals[loopRange.index()], m_Zero())) {
      offsets[loopRange.index()] = loopRange.value().offset;
      sizes[loopRange.index()] = loopRange.value().size;
      continue;
    }

    auto loop = builder.create<scf::ForOp>(
        loc, loopRange.value().offset, loopRange.value().size,
        tileSizeVals[loopRange.index()], ValueRange{},
        [&](OpBuilder &bodyBuilder, Location bodyLoc, Value iv,
            ValueRange /*iterArgs*/) {
          Value boundedTileSize = builder.create<AffineMinOp>(
              bodyLoc, minMap,
              ValueRange{iv, tileSizeVals[loopRange.index()],
                         loopRange.value().size});
          sizes[loopRange.index()] = boundedTileSize;
          builder.create<scf::YieldOp>(loc);
        });
    offsets[loopRange.index()] = loop.getInductionVar();
    loops.push_back(loop);
    builder.setInsertionPoint(loop.getBody()->getTerminator());
  }
  return loops;
}

scf::TileUsingSCFForOp::TileUsingSCFForOp(MLIRContext *context,
                                          scf::SCFTilingOptions options,
                                          PatternBenefit benefit)
    : OpInterfaceRewritePattern<TilingInterface>(context, benefit),
      options(std::move(options)) {}

scf::TileUsingSCFForOp::TileUsingSCFForOp(StringRef opName,
                                          MLIRContext *context,
                                          scf::SCFTilingOptions options,
                                          PatternBenefit benefit)
    : OpInterfaceRewritePattern<TilingInterface>(context, benefit),
      options(std::move(options)) {}

FailureOr<scf::SCFTilingResult>
scf::TileUsingSCFForOp::returningMatchAndRewrite(
    TilingInterface op, PatternRewriter &rewriter) const {
  OpBuilder::InsertionGuard guard(rewriter);
  rewriter.setInsertionPointAfter(op);

  if (!options.tileSizeComputationFunction) {
    return rewriter.notifyMatchFailure(
        op, "missing tile size computation function");
  }

  // 1. Get the range of the loops that are represented by the operation.
  SmallVector<Range> iterationDomain = op.getIterationDomain(rewriter);
  size_t numLoops = iterationDomain.size();
  if (numLoops == 0) {
    return rewriter.notifyMatchFailure(
        op, "unable to tile op with no iteration domain");
  }

  // 2. Materialize the tile sizes. Enforce the convention that "tiling by zero"
  // skips tiling a particular dimension. This convention is significantly
  // simpler to handle instead of adjusting affine maps to account for missing
  // dimensions.
  SmallVector<Value, 4> tileSizeVector =
      options.tileSizeComputationFunction(rewriter, op);
  if (tileSizeVector.size() < iterationDomain.size()) {
    auto zero = rewriter.create<arith::ConstantIndexOp>(op.getLoc(), 0);
    tileSizeVector.append(numLoops - tileSizeVector.size(), zero);
  }

  scf::SCFTilingResult tilingResult;
  SmallVector<OpFoldResult> offsets, sizes;
  {
    // 3. Materialize an empty loop nest that iterates over the tiles. These
    // loops for now do not return any values even if the original operation has
    // results.
    tilingResult.loops = generateTileLoopNest(
        rewriter, op.getLoc(), iterationDomain, tileSizeVector, offsets, sizes);

    LLVM_DEBUG({
      if (!tilingResult.loops.empty()) {
        llvm::errs() << "LoopNest shell :\n";
        tilingResult.loops.front().dump();
        llvm::errs() << "\n";
      }
    });

    // 4. Generate the tiled implementation within the innermost loop.
    if (!tilingResult.loops.empty())
      rewriter.setInsertionPoint(
          tilingResult.loops.back().getBody()->getTerminator());
    SmallVector<Operation *> tiledImplementation = op.getTiledImplementation(
        rewriter, op.getDestinationOperands(rewriter), offsets, sizes, true);
    if (tiledImplementation.size() != 1) {
      return rewriter.notifyMatchFailure(
          op, "expected tiled implementation to return a single op");
    }
    tilingResult.tiledOp = tiledImplementation[0];

    LLVM_DEBUG({
      if (!tilingResult.loops.empty()) {
        llvm::errs() << "After tiled implementation :\n";
        tilingResult.loops.front().dump();
        llvm::errs() << "\n";
      }
    });
  }

  if (op->getNumResults() == 0) {
    rewriter.eraseOp(op);
    return tilingResult;
  }

  // 5. If the original operation has results, modify the loop nest to yield
  // the replacement values.
  SmallVector<Value> replacements;
  if (tilingResult.loops.empty()) {
    // 5a. If there were no loops, the tiled implementation results are the
    // replacements.
    rewriter.replaceOp(op, tilingResult.tiledOp->getResults());
    return tilingResult;
  }

  // 5b. `scf.for` with tensor semantics requires the loop nest to yield the
  // replacement values using destructive updates. Use the `TilingInterface`
  // to get the position of the result tiles and use that to generate the
  // destructive update pattern, i.e.,
  //
  // ```mlir
  // scf.for %iv0 = ... {
  //   %0 = tiled_op
  // }
  // ```
  //
  // is transformed to
  //
  // ```mlir
  // %result = scf.for %iv0 = ... iter_args(%arg = %init) -> .. {
  //   %0 = tiled_op
  //   %1 = tensor.insert_slice %0 into %arg[..] [..] [..]
  //   scf.yield %1
  // }
  // ```
  NewYieldValueFn yieldValueFn =
      [&](OpBuilder &b, Location loc,
          ArrayRef<BlockArgument> newBBArgs) -> SmallVector<Value> {
    SmallVector<Value> yieldedValues;
    Attribute one = b.getIndexAttr(1);
    for (auto resultNum : llvm::seq<unsigned>(0, op->getNumResults())) {
      SmallVector<OpFoldResult> resultTileOffsets, resultTileSizes;
      if (failed(op.getResultTilePosition(b, resultNum, offsets, sizes,
                                          resultTileOffsets,
                                          resultTileSizes))) {
        op.emitOpError("unable to get position of result ")
            << resultNum << " of the tiled implementation";
        return {};
      }
      SmallVector<OpFoldResult> resultTileStrides(resultTileOffsets.size(),
                                                  one);
      Value yieldedValue = b.create<tensor::InsertSliceOp>(
          op->getLoc(), tilingResult.tiledOp->getResult(resultNum),
          newBBArgs[resultNum], resultTileOffsets, resultTileSizes,
          resultTileStrides);
      yieldedValues.push_back(yieldedValue);
    }
    return yieldedValues;
  };
  SmallVector<scf::ForOp> newLoops = replaceLoopNestWithNewYields(
      rewriter, tilingResult.loops, op.getDestinationOperands(rewriter),
      yieldValueFn);
  for (auto loop : llvm::enumerate(tilingResult.loops)) {
    rewriter.eraseOp(loop.value());
    tilingResult.loops[loop.index()] = newLoops[loop.index()];
  }
  rewriter.replaceOp(op, tilingResult.loops.front().getResults());
  return tilingResult;
}
@@ -23,6 +23,7 @@
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/SmallVector.h"

using namespace mlir;

@@ -101,6 +102,31 @@ mlir::replaceLoopWithNewYields(OpBuilder &builder, scf::ForOp loop,
  return newLoop;
}

SmallVector<scf::ForOp> mlir::replaceLoopNestWithNewYields(
    OpBuilder &builder, ArrayRef<scf::ForOp> loopNest,
    ValueRange newIterOperands, NewYieldValueFn newYieldValueFn) {
  if (loopNest.empty())
    return {};
  SmallVector<scf::ForOp> newLoopNest(loopNest.size());

  newLoopNest.back() = replaceLoopWithNewYields(
      builder, loopNest.back(), newIterOperands, newYieldValueFn);

  for (unsigned loopDepth :
       llvm::reverse(llvm::seq<unsigned>(0, loopNest.size() - 1))) {
    NewYieldValueFn fn = [&](OpBuilder &innerBuilder, Location loc,
                             ArrayRef<BlockArgument> innerNewBBArgs) {
      SmallVector<Value> newYields(
          newLoopNest[loopDepth + 1]->getResults().take_back(
              newIterOperands.size()));
      return newYields;
    };
    newLoopNest[loopDepth] = replaceLoopWithNewYields(
        builder, loopNest[loopDepth], newIterOperands, fn);
  }
  return newLoopNest;
}

/// Outline a region with a single block into a new FuncOp.
/// Assumes the FuncOp result types is the type of the yielded operands of the
/// single block. This constraint makes it easy to determine the result.

mlir/test/Interfaces/TilingInterface/tile-using-interface.mlir (new file, 194 lines)
@@ -0,0 +1,194 @@
// RUN: mlir-opt -test-tiling-interface -split-input-file %s | FileCheck %s

func.func @simple_matmul(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x?xf32>,
    %arg2 : tensor<?x?xf32>) -> tensor<?x?xf32> {
  %0 = linalg.matmul {__internal_linalg_transform__ = "simple_gemm"}
      ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
      outs(%arg2 : tensor<?x?xf32>) -> tensor<?x?xf32>
  return %0 : tensor<?x?xf32>
}
// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0)[s0, s1] -> (10, -d0 + s1)>
// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0)[s0, s1] -> (20, -d0 + s1)>
// CHECK: func.func @simple_matmul(
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<?x?xf32>
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<?x?xf32>
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: tensor<?x?xf32>
// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
// CHECK-DAG: %[[C10:.+]] = arith.constant 10 : index
// CHECK-DAG: %[[C20:.+]] = arith.constant 20 : index
// CHECK-DAG: %[[M:.+]] = tensor.dim %[[ARG0]], %[[C0]]
// CHECK-DAG: %[[K:.+]] = tensor.dim %[[ARG0]], %[[C1]]
// CHECK-DAG: %[[N:.+]] = tensor.dim %[[ARG1]], %[[C1]]
// CHECK: %[[OUTER:[a-zA-Z0-9]+]] = scf.for %[[IV0:[a-zA-Z0-9]+]] = %[[C0]] to %[[M]] step %[[C10]]
// CHECK-SAME: iter_args(%[[INIT0:.+]] = %[[ARG2]])
// CHECK: %[[TS_Y:.+]] = affine.min #[[MAP0]](%[[IV0]])[%[[C10]], %[[M]]]
// CHECK: %[[INNER:[a-zA-Z0-9]+]] = scf.for %[[IV1:[a-zA-Z0-9]+]] = %[[C0]] to %[[N]] step %[[C20]]
// CHECK-SAME: iter_args(%[[INIT1:.+]] = %[[INIT0]])
// CHECK: %[[TS_X:.+]] = affine.min #[[MAP1]](%[[IV1]])[%[[C20]], %[[N]]]
// CHECK-DAG: %[[LHS_TILE:.+]] = tensor.extract_slice %[[ARG0]]
// CHECK-SAME: [%[[IV0]], 0] [%[[TS_Y]], %[[K]]] [1, 1]
// CHECK-DAG: %[[RHS_TILE:.+]] = tensor.extract_slice %[[ARG1]]
// CHECK-SAME: [0, %[[IV1]]] [%[[K]], %[[TS_X]]] [1, 1]
// CHECK-DAG: %[[INIT_TILE:.+]] = tensor.extract_slice %[[INIT1]]
// CHECK-SAME: [%[[IV0]], %[[IV1]]] [%[[TS_Y]], %[[TS_X]]] [1, 1]
// CHECK: %[[GEMM_TILE:.+]] = linalg.matmul
// CHECK-SAME: ins(%[[LHS_TILE]], %[[RHS_TILE]] :
// CHECK-SAME: outs(%[[INIT_TILE]] :
// CHECK: %[[UPDATE:.+]] = tensor.insert_slice %[[GEMM_TILE]] into %[[INIT1]]
// CHECK-SAME: [%[[IV0]], %[[IV1]]] [%[[TS_Y]], %[[TS_X]]] [1, 1]
// CHECK: scf.yield %[[UPDATE]]
// CHECK: scf.yield %[[INNER]]
// CHECK: return %[[OUTER]]

// -----

func.func @simple_matmul_memref(%arg0 : memref<?x?xf32>, %arg1 : memref<?x?xf32>,
    %arg2 : memref<?x?xf32>) {
  linalg.matmul {__internal_linalg_transform__ = "simple_gemm_memref"}
      ins(%arg0, %arg1 : memref<?x?xf32>, memref<?x?xf32>)
      outs(%arg2 : memref<?x?xf32>)
  return
}
// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0)[s0, s1] -> (10, -d0 + s1)>
// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0)[s0, s1] -> (20, -d0 + s1)>
// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0)[s0, s1] -> (30, -d0 + s1)>
// CHECK: func.func @simple_matmul_memref(
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: memref<?x?xf32>
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: memref<?x?xf32>
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: memref<?x?xf32>
// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
// CHECK-DAG: %[[C10:.+]] = arith.constant 10 : index
// CHECK-DAG: %[[C20:.+]] = arith.constant 20 : index
// CHECK-DAG: %[[C30:.+]] = arith.constant 30 : index
// CHECK-DAG: %[[M:.+]] = memref.dim %[[ARG0]], %[[C0]]
// CHECK-DAG: %[[K:.+]] = memref.dim %[[ARG0]], %[[C1]]
// CHECK-DAG: %[[N:.+]] = memref.dim %[[ARG1]], %[[C1]]
// CHECK: scf.for %[[IV0:[a-zA-Z0-9]+]] = %[[C0]] to %[[M]] step %[[C10]]
// CHECK: %[[TS_M:.+]] = affine.min #[[MAP0]](%[[IV0]])[%[[C10]], %[[M]]]
// CHECK: scf.for %[[IV1:[a-zA-Z0-9]+]] = %[[C0]] to %[[N]] step %[[C20]]
// CHECK: %[[TS_N:.+]] = affine.min #[[MAP1]](%[[IV1]])[%[[C20]], %[[N]]]
// CHECK: scf.for %[[IV2:[a-zA-Z0-9]+]] = %[[C0]] to %[[K]] step %[[C30]]
// CHECK: %[[TS_K:.+]] = affine.min #[[MAP2]](%[[IV2]])[%[[C30]], %[[K]]]
// CHECK-DAG: %[[LHS_TILE:.+]] = memref.subview %[[ARG0]]
// CHECK-SAME: [%[[IV0]], %[[IV2]]] [%[[TS_M]], %[[TS_K]]] [1, 1]
// CHECK-DAG: %[[RHS_TILE:.+]] = memref.subview %[[ARG1]]
// CHECK-SAME: [%[[IV2]], %[[IV1]]] [%[[TS_K]], %[[TS_N]]] [1, 1]
// CHECK-DAG: %[[OUT_TILE:.+]] = memref.subview %[[ARG2]]
// CHECK-SAME: [%[[IV0]], %[[IV1]]] [%[[TS_M]], %[[TS_N]]] [1, 1]
// CHECK: linalg.matmul
// CHECK-SAME: ins(%[[LHS_TILE]], %[[RHS_TILE]] :
// CHECK-SAME: outs(%[[OUT_TILE]] :

// -----

#map0 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
#map1 = affine_map<(d0, d1, d2) -> (d0, d2, d1)>
#map2 = affine_map<(d0, d1, d2) -> (d2, d0, d1)>
func.func @multi_result(%arg0 : tensor<128x200x300xf32>) -> (tensor<128x300x200xf32>, tensor<300x128x200xf32>) {
  %init0 = linalg.init_tensor [128, 300, 200] : tensor<128x300x200xf32>
  %init1 = linalg.init_tensor [300, 128, 200] : tensor<300x128x200xf32>
  %0:2 = linalg.generic {
      indexing_maps = [#map0, #map1, #map2],
      iterator_types = ["parallel", "parallel", "parallel"]}
      {__internal_linalg_transform__ = "parallel_generic_transpose"}
      ins(%arg0 : tensor<128x200x300xf32>)
      outs(%init0, %init1 : tensor<128x300x200xf32>, tensor<300x128x200xf32>) {
    ^bb0(%b0 : f32, %b1 : f32, %b2 : f32):
      linalg.yield %b0, %b0 : f32, f32
  } -> (tensor<128x300x200xf32>, tensor<300x128x200xf32>)
  return %0#0, %0#1 : tensor<128x300x200xf32>, tensor<300x128x200xf32>
}
// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0)[s0, s1] -> (10, -d0 + s1)>
// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0)[s0, s1] -> (20, -d0 + s1)>
// CHECK: func.func @multi_result(
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<128x200x300xf32>)
// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
// CHECK-DAG: %[[C10:.+]] = arith.constant 10 : index
// CHECK-DAG: %[[C20:.+]] = arith.constant 20 : index
// CHECK-DAG: %[[C128:.+]] = arith.constant 128 : index
// CHECK-DAG: %[[C300:.+]] = arith.constant 300 : index
// CHECK-DAG: %[[INIT0:.+]] = linalg.init_tensor [128, 300, 200]
// CHECK-DAG: %[[INIT1:.+]] = linalg.init_tensor [300, 128, 200]
// CHECK: %[[OUTER:[a-zA-Z0-9]+]]:2 = scf.for %[[IV0:[a-zA-Z0-9]+]] = %[[C0]] to %[[C128]] step %[[C10]]
// CHECK-SAME: iter_args(%[[ARG1:[a-zA-Z0-9]+]] = %[[INIT0]], %[[ARG2:[a-zA-Z0-9]+]] = %[[INIT1]])
// CHECK: %[[TS_Y:.+]] = affine.min #[[MAP0]](%[[IV0]])[%[[C10]], %[[C128]]]
// CHECK: %[[INNER:[a-zA-Z0-9]+]]:2 = scf.for %[[IV1:[a-zA-Z0-9]+]] = %[[C0]] to %[[C300]] step %[[C20]]
// CHECK-SAME: iter_args(%[[ARG3:[a-zA-Z0-9]+]] = %[[ARG1]], %[[ARG4:[a-zA-Z0-9]+]] = %[[ARG2]])
// CHECK: %[[TS_X:.+]] = affine.min #[[MAP1]](%[[IV1]])[%[[C20]], %[[C300]]]
// CHECK-DAG: %[[ARG_TILE:.+]] = tensor.extract_slice %[[ARG0]]
// CHECK-SAME: [%[[IV0]], 0, %[[IV1]]] [%[[TS_Y]], 200, %[[TS_X]]] [1, 1, 1]
// CHECK-DAG: %[[INIT0_TILE:.+]] = tensor.extract_slice %[[ARG3]]
// CHECK-SAME: [%[[IV0]], %[[IV1]], 0] [%[[TS_Y]], %[[TS_X]], 200] [1, 1, 1]
// CHECK-DAG: %[[INIT1_TILE:.+]] = tensor.extract_slice %[[ARG4]]
// CHECK-SAME: [%[[IV1]], %[[IV0]], 0] [%[[TS_X]], %[[TS_Y]], 200] [1, 1, 1]
// CHECK: %[[RESULT_TILE:.+]]:2 = linalg.generic
// CHECK-SAME: ins(%[[ARG_TILE]] :
// CHECK-SAME: outs(%[[INIT0_TILE]], %[[INIT1_TILE]] :
// CHECK: %[[UPDATE0:.+]] = tensor.insert_slice %[[RESULT_TILE]]#0 into %[[ARG3]]
// CHECK-SAME: [%[[IV0]], %[[IV1]], 0] [%[[TS_Y]], %[[TS_X]], 200] [1, 1, 1]
// CHECK: %[[UPDATE1:.+]] = tensor.insert_slice %[[RESULT_TILE]]#1 into %[[ARG4]]
// CHECK-SAME: [%[[IV1]], %[[IV0]], 0] [%[[TS_X]], %[[TS_Y]], 200] [1, 1, 1]
// CHECK: scf.yield %[[UPDATE0]], %[[UPDATE1]]
// CHECK: scf.yield %[[INNER]]#0, %[[INNER]]#1
// CHECK: return %[[OUTER]]#0, %[[OUTER]]#1

// -----

func.func @conv2D(%arg0 : tensor<?x?x?x?xf32>, %arg1 : tensor<?x?x?x?xf32>,
    %arg2 : tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32> {
  %0 = linalg.conv_2d_nhwc_hwcf {
      strides = dense<[2, 3]> : tensor<2xi64>,
      dilation = dense<[4, 5]> : tensor<2xi64>,
      __internal_linalg_transform__ = "simple_conv"}
      ins(%arg0, %arg1 : tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>)
      outs(%arg2 : tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
  return %0 : tensor<?x?x?x?xf32>
}
// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0)[s0, s1] -> (10, -d0 + s1)>
// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0)[s0, s1] -> (20, -d0 + s1)>
// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0)[s0, s1] -> (30, -d0 + s1)>
// CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0)[s0] -> (d0 + s0 * 2 - 2)>
// CHECK-DAG: #[[MAP4:.+]] = affine_map<(d0)[s0] -> (d0 + s0 * 3 - 3)>
// CHECK: func.func @conv2D(
// CHECK-SAME: %[[INPUT:[a-zA-Z0-9]+]]: tensor<?x?x?x?xf32>
// CHECK-SAME: %[[FILTER:[a-zA-Z0-9]+]]: tensor<?x?x?x?xf32>
// CHECK-SAME: %[[INIT:[a-zA-Z0-9]+]]: tensor<?x?x?x?xf32>
// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
// CHECK-DAG: %[[C2:.+]] = arith.constant 2 : index
// CHECK-DAG: %[[C3:.+]] = arith.constant 3 : index
// CHECK-DAG: %[[C10:.+]] = arith.constant 10 : index
// CHECK-DAG: %[[C20:.+]] = arith.constant 20 : index
// CHECK-DAG: %[[C30:.+]] = arith.constant 30 : index
// CHECK-DAG: %[[N:.+]] = tensor.dim %[[INPUT]], %[[C0]]
// CHECK-DAG: %[[C:.+]] = tensor.dim %[[INPUT]], %[[C3]]
// CHECK-DAG: %[[P:.+]] = tensor.dim %[[FILTER]], %[[C0]]
// CHECK-DAG: %[[Q:.+]] = tensor.dim %[[FILTER]], %[[C1]]
// CHECK-DAG: %[[F:.+]] = tensor.dim %[[FILTER]], %[[C3]]
// CHECK-DAG: %[[R:.+]] = tensor.dim %[[INIT]], %[[C1]]
// CHECK-DAG: %[[S:.+]] = tensor.dim %[[INIT]], %[[C2]]
// CHECK: scf.for %[[IV0:[a-zA-Z0-9]+]] = %[[C0]] to %[[P]] step %[[C10]]
// CHECK-SAME: iter_args(%[[INIT0:.+]] = %[[INIT]])
// CHECK: %[[TS_P:.+]] = affine.min #[[MAP0]](%[[IV0]])[%[[C10]], %[[P]]]
// CHECK: scf.for %[[IV1:[a-zA-Z0-9]+]] = %[[C0]] to %[[Q]] step %[[C20]]
// CHECK-SAME: iter_args(%[[INIT1:.+]] = %[[INIT0]])
// CHECK: %[[TS_Q:.+]] = affine.min #[[MAP1]](%[[IV1]])[%[[C20]], %[[Q]]]
// CHECK: scf.for %[[IV2:[a-zA-Z0-9]+]] = %[[C0]] to %[[C]] step %[[C30]]
// CHECK-SAME: iter_args(%[[INIT2:.+]] = %[[INIT1]])
// CHECK-DAG: %[[TS_C:.+]] = affine.min #[[MAP2]](%[[IV2]])[%[[C30]], %[[C]]]
// CHECK-DAG: %[[TS_H:.+]] = affine.apply #[[MAP3]](%[[TS_P]])[%[[R]]]
// CHECK-DAG: %[[TS_W:.+]] = affine.apply #[[MAP4]](%[[TS_Q]])[%[[S]]]
// CHECK-DAG: %[[INPUT_TILE:.+]] = tensor.extract_slice %[[INPUT]]
// CHECK-SAME: [0, %[[IV0]], %[[IV1]], %[[IV2]]] [%[[N]], %[[TS_H]], %[[TS_W]], %[[TS_C]]]
// CHECK-DAG: %[[FILTER_TILE:.+]] = tensor.extract_slice %[[FILTER]]
// CHECK-SAME: [%[[IV0]], %[[IV1]], %[[IV2]], 0] [%[[TS_P]], %[[TS_Q]], %[[TS_C]], %[[F]]]
// CHECK-DAG: %[[INIT_TILE:.+]] = tensor.extract_slice %[[INIT2]]
// CHECK-SAME: [0, 0, 0, 0] [%[[N]], %[[R]], %[[S]], %[[F]]]
// CHECK: %[[CONV_TILE:.+]] = linalg.conv_2d_nhwc_hwcf
// CHECK-SAME: dilation = dense<[4, 5]> : tensor<2xi64>, strides = dense<[2, 3]> : tensor<2xi64>
// CHECK-SAME: ins(%[[INPUT_TILE]], %[[FILTER_TILE]] :
// CHECK-SAME: outs(%[[INIT_TILE]] :
// CHECK: tensor.insert_slice %[[CONV_TILE]] into %[[INIT2]]
// CHECK-SAME: [0, 0, 0, 0] [%[[N]], %[[R]], %[[S]], %[[F]]]
@@ -1,6 +1,7 @@
add_subdirectory(Analysis)
add_subdirectory(Conversion)
add_subdirectory(Dialect)
+add_subdirectory(Interfaces)
add_subdirectory(IR)
add_subdirectory(Pass)
add_subdirectory(Reducer)

mlir/test/lib/Interfaces/CMakeLists.txt (new file, 1 line)
@@ -0,0 +1 @@
add_subdirectory(TilingInterface)
mlir/test/lib/Interfaces/TilingInterface/CMakeLists.txt (new file, 15 lines)
@@ -0,0 +1,15 @@
add_mlir_library(MLIRTilingInterfaceTestPasses
  TestTilingInterface.cpp

  EXCLUDE_FROM_LIBMLIR

  LINK_LIBS PUBLIC
  MLIRAffine
  MLIRArithmetic
  MLIRLinalg
  MLIRLinalgTransforms
  MLIRMemRef
  MLIRSCF
  MLIRSCFTransforms
  MLIRTensor
  )
mlir/test/lib/Interfaces/TilingInterface/TestTilingInterface.cpp (new file, 126 lines)
@@ -0,0 +1,126 @@
//===- TestTilingInterface.cpp - Test tiling using `TilingInterface` -----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements a pass for testing tiling operations using
// `TilingInterface`.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Linalg/Transforms/TilingInterfaceImpl.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/SCF/TileUsingInterface.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Interfaces/TilingInterface.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/TypeSwitch.h"

using namespace mlir;

namespace {

/// Construct a generic pattern applied to all TilingInterface ops that verify
/// `filter`.
struct TestTileUsingSCFForOpWithFilter : public scf::TileUsingSCFForOp {
  TestTileUsingSCFForOpWithFilter(MLIRContext *context,
                                  scf::SCFTilingOptions options,
                                  linalg::LinalgTransformationFilter filter =
                                      linalg::LinalgTransformationFilter(),
                                  PatternBenefit benefit = 1)
      : scf::TileUsingSCFForOp(context, options, benefit), filter(filter) {}

  /// Construct a generic pattern applied to `opName`.
  TestTileUsingSCFForOpWithFilter(StringRef opName, MLIRContext *context,
                                  scf::SCFTilingOptions options,
                                  linalg::LinalgTransformationFilter filter =
                                      linalg::LinalgTransformationFilter(),
                                  PatternBenefit benefit = 1)
      : scf::TileUsingSCFForOp(context, options, benefit), filter(filter) {}

  LogicalResult matchAndRewrite(TilingInterface op,
                                PatternRewriter &rewriter) const override {
    if (failed(filter.checkAndNotify(rewriter, op)))
      return failure();

    FailureOr<scf::SCFTilingResult> tilingResult =
        returningMatchAndRewrite(op, rewriter);
    if (failed(tilingResult)) {
      return failure();
    }
    filter.replaceLinalgTransformationFilter(rewriter, tilingResult->tiledOp);
    return success();
  }

private:
  linalg::LinalgTransformationFilter filter;
};

struct TestTilingInterfacePass
    : public PassWrapper<TestTilingInterfacePass, OperationPass<func::FuncOp>> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestTilingInterfacePass)

  TestTilingInterfacePass() = default;
  TestTilingInterfacePass(const TestTilingInterfacePass &pass)
      : PassWrapper(pass) {}
  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<AffineDialect, memref::MemRefDialect, scf::SCFDialect,
                    tensor::TensorDialect>();
    linalg::registerTilingInterfaceExternalModels(registry);
  }
  StringRef getArgument() const final { return "test-tiling-interface"; }
  StringRef getDescription() const final {
    return "Test tiling using TilingInterface";
  }

  void runOnOperation() override;
};
} // namespace

static void addTestPatterns(MLIRContext *context, RewritePatternSet &patterns) {
  auto addPatternForTiling = [&](ArrayRef<int64_t> tileSizes,
                                 StringRef filterName) {
    scf::SCFTilingOptions tilingOptions;
    tilingOptions.setTileSizes(tileSizes);
    linalg::LinalgTransformationFilter filter(
        StringAttr::get(context, filterName),
        StringAttr::get(context, "tiled"));
    patterns.add<TestTileUsingSCFForOpWithFilter>(context, tilingOptions,
                                                  filter);
  };
  // 1. Tiling M and N dims of `linalg.matmul` on tensors.
  addPatternForTiling({10, 20}, "simple_gemm");
  // 2. Tiling M, N and K of `linalg.matmul` on buffers.
  addPatternForTiling({10, 20, 30}, "simple_gemm_memref");
  // 3. Tiling 3D parallel generic op which implements a transpose
  addPatternForTiling({10, 0, 20}, "parallel_generic_transpose");
  // 4. Tiling 2D conv op.
  addPatternForTiling({0, 0, 0, 0, 10, 20, 30}, "simple_conv");
}

void TestTilingInterfacePass::runOnOperation() {
  MLIRContext *context = &getContext();

  RewritePatternSet tilingPatterns(context);
  addTestPatterns(context, tilingPatterns);
  if (failed(applyPatternsAndFoldGreedily(getOperation(),
                                          std::move(tilingPatterns))))
    return signalPassFailure();
}

namespace mlir {
namespace test {
void registerTestTilingInterface() {
  PassRegistration<TestTilingInterfacePass>();
}
} // namespace test
} // namespace mlir
@@ -33,6 +33,7 @@ if(MLIR_INCLUDE_TESTS)
    MLIRTestRewrite
    MLIRTestTransformDialect
    MLIRTestTransforms
+   MLIRTilingInterfaceTestPasses
    MLIRVectorTestPasses
  )
endif()

@@ -111,6 +111,7 @@ void registerTestRecursiveTypesPass();
void registerTestSCFUtilsPass();
void registerTestSliceAnalysisPass();
void registerTestTensorTransforms();
+void registerTestTilingInterface();
void registerTestTransformDialectInterpreterPass();
void registerTestVectorLowerings();
} // namespace test
@@ -206,6 +207,7 @@ void registerTestPasses() {
  mlir::test::registerTestSCFUtilsPass();
  mlir::test::registerTestSliceAnalysisPass();
  mlir::test::registerTestTensorTransforms();
+ mlir::test::registerTestTilingInterface();
  mlir::test::registerTestTransformDialectInterpreterPass();
  mlir::test::registerTestVectorLowerings();
}

@@ -1864,6 +1864,7 @@ cc_library(
        "include/mlir/Dialect/SCF/BufferizableOpInterfaceImpl.h",
        "include/mlir/Dialect/SCF/Passes.h",
        "include/mlir/Dialect/SCF/Patterns.h",
+       "include/mlir/Dialect/SCF/TileUsingInterface.h",
        "include/mlir/Dialect/SCF/Transforms.h",
    ],
    includes = ["include"],
@@ -1883,6 +1884,7 @@
        ":SCFUtils",
        ":Support",
        ":TensorDialect",
+       ":TilingInterface",
        ":Transforms",
        "//llvm:Support",
    ],
@@ -2645,6 +2647,7 @@ cc_library(
        exclude = [
            "include/mlir/Dialect/SCF/BufferizableOpInterfaceImpl.h",
            "include/mlir/Dialect/SCF/Patterns.h",
+           "include/mlir/Dialect/SCF/TileUsingInterface.h",
            "include/mlir/Dialect/SCF/Transforms.h",
        ],
    ),
@@ -6313,6 +6316,7 @@ cc_binary(
        "//mlir/test:TestSPIRV",
        "//mlir/test:TestShapeDialect",
        "//mlir/test:TestTensor",
+       "//mlir/test:TestTilingInterface",
        "//mlir/test:TestTosaDialect",
        "//mlir/test:TestTransformDialect",
        "//mlir/test:TestTransforms",
@@ -7492,6 +7496,7 @@ cc_library(
        ":TensorTilingInterfaceImpl",
        ":TensorTransforms",
        ":TensorUtils",
+       ":TilingInterface",
        ":TransformUtils",
        ":Transforms",
        ":VectorDialect",

@@ -293,6 +293,28 @@ cc_library(
    ],
)

+cc_library(
+    name = "TestTilingInterface",
+    srcs = glob(["lib/Interfaces/TilingInterface/*.cpp"]),
+    includes = ["lib/Dialect/Test"],
+    deps = [
+        "//llvm:Support",
+        "//mlir:Affine",
+        "//mlir:ArithmeticDialect",
+        "//mlir:FuncDialect",
+        "//mlir:IR",
+        "//mlir:LinalgDialect",
+        "//mlir:LinalgTransforms",
+        "//mlir:MemRefDialect",
+        "//mlir:Pass",
+        "//mlir:SCFDialect",
+        "//mlir:SCFTransforms",
+        "//mlir:TensorDialect",
+        "//mlir:TilingInterface",
+        "//mlir:Transforms",
+    ],
+)
+
cc_library(
    name = "TestPass",
    srcs = glob(["lib/Pass/*.cpp"]),
