Now that a view op has graduated to the std dialect, we can update Linalg to use it and remove ops that have become obsolete. As a byproduct, the linalg buffer and associated ops can also disappear. PiperOrigin-RevId: 279073591
//===- LowerToLLVMDialect.cpp - conversion from Linalg to LLVM dialect ----===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//    http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================

#include "mlir/Conversion/LoopToStandard/ConvertLoopToStandard.h"
|
|
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
|
|
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
|
|
#include "mlir/Conversion/VectorToLLVM/VectorToLLVM.h"
|
|
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
|
|
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
|
|
#include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
|
|
#include "mlir/Dialect/Linalg/Passes.h"
|
|
#include "mlir/Dialect/Linalg/Utils/Intrinsics.h"
|
|
#include "mlir/EDSC/Builders.h"
|
|
#include "mlir/EDSC/Intrinsics.h"
|
|
#include "mlir/IR/AffineExpr.h"
|
|
#include "mlir/IR/AffineMap.h"
|
|
#include "mlir/IR/Attributes.h"
|
|
#include "mlir/IR/Builders.h"
|
|
#include "mlir/IR/MLIRContext.h"
|
|
#include "mlir/IR/Module.h"
|
|
#include "mlir/IR/Operation.h"
|
|
#include "mlir/IR/PatternMatch.h"
|
|
#include "mlir/IR/StandardTypes.h"
|
|
#include "mlir/IR/Types.h"
|
|
#include "mlir/Pass/Pass.h"
|
|
#include "mlir/Pass/PassManager.h"
|
|
#include "mlir/Support/LogicalResult.h"
|
|
#include "mlir/Transforms/DialectConversion.h"
|
|
#include "mlir/Transforms/LowerAffine.h"
|
|
#include "mlir/Transforms/Passes.h"
|
|
|
|
#include "llvm/ADT/SetVector.h"
|
|
#include "llvm/IR/DerivedTypes.h"
|
|
#include "llvm/IR/Module.h"
|
|
#include "llvm/IR/Type.h"
|
|
#include "llvm/Support/Allocator.h"
|
|
#include "llvm/Support/ErrorHandling.h"
|
|
|
|
using namespace mlir;
using namespace mlir::edsc;
using namespace mlir::edsc::intrinsics;
using namespace mlir::LLVM;
using namespace mlir::linalg;
using namespace mlir::linalg::intrinsics;

using add = ValueBuilder<mlir::LLVM::AddOp>;
using addi = ValueBuilder<mlir::AddIOp>;
using bitcast = ValueBuilder<mlir::LLVM::BitcastOp>;
using cmpi = ValueBuilder<mlir::CmpIOp>;
using constant = ValueBuilder<mlir::LLVM::ConstantOp>;
using extractvalue = ValueBuilder<mlir::LLVM::ExtractValueOp>;
using gep = ValueBuilder<mlir::LLVM::GEPOp>;
using insertvalue = ValueBuilder<mlir::LLVM::InsertValueOp>;
using llvm_call = OperationBuilder<mlir::LLVM::CallOp>;
using llvm_icmp = ValueBuilder<LLVM::ICmpOp>;
using llvm_load = ValueBuilder<LLVM::LoadOp>;
using llvm_store = OperationBuilder<LLVM::StoreOp>;
using llvm_select = ValueBuilder<LLVM::SelectOp>;
using mul = ValueBuilder<mlir::LLVM::MulOp>;
using ptrtoint = ValueBuilder<mlir::LLVM::PtrToIntOp>;
using sub = ValueBuilder<mlir::LLVM::SubOp>;
using llvm_undef = ValueBuilder<mlir::LLVM::UndefOp>;
using urem = ValueBuilder<mlir::LLVM::URemOp>;
using llvm_alloca = ValueBuilder<LLVM::AllocaOp>;
using llvm_return = OperationBuilder<LLVM::ReturnOp>;
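
// Note: the EDSC ValueBuilder/OperationBuilder aliases above give the
// conversion patterns below a terse, expression-like way of emitting LLVM
// dialect ops. Roughly, within an edsc::ScopedContext a call such as
// add(x, y) creates an llvm.add at the current insertion point and yields
// its result value, which is how they are used throughout this file.
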
template <typename T>
static LLVMType getPtrToElementType(T containerType,
                                    LLVMTypeConverter &lowering) {
  return lowering.convertType(containerType.getElementType())
      .template cast<LLVMType>()
      .getPointerTo();
}

// Convert the given type to the LLVM IR Dialect type. The following
// conversions are supported:
//   - an Index type is converted into an LLVM integer type with pointer
//     bitwidth (analogous to intptr_t in C);
//   - an Integer type is converted into an LLVM integer type of the same
//     width;
//   - an F32 type is converted into an LLVM float type;
//   - a Buffer, Range or View is converted into an LLVM structure type
//     containing the respective dynamic values.
static Type convertLinalgType(Type t, LLVMTypeConverter &lowering) {
  auto *context = t.getContext();
  auto int64Ty = lowering.convertType(IntegerType::get(64, context))
                     .cast<LLVM::LLVMType>();

  // A buffer descriptor contains the pointer to a flat region of storage and
  // the size of the region.
  //
  // template <typename Elem, size_t Rank>
  // struct {
  //   void *baseAlloc;
  //   Elem *ptr;
  //   int64_t size;
  // };
  if (auto bufferType = t.dyn_cast<BufferType>()) {
    auto voidPtrTy = LLVMType::getInt8Ty(lowering.getDialect()).getPointerTo();
    auto ptrTy = getPtrToElementType(bufferType, lowering);
    return LLVMType::getStructTy(voidPtrTy, ptrTy, int64Ty);
  }

  // A range descriptor contains the range bounds and the step as 64-bit
  // integers.
  //
  // struct {
  //   int64_t min;
  //   int64_t max;
  //   int64_t step;
  // };
  if (t.isa<RangeType>())
    return LLVMType::getStructTy(int64Ty, int64Ty, int64Ty);

  return Type();
}

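// Positions of the fields inside the LLVM struct descriptors manipulated
// below via llvm.extractvalue / llvm.insertvalue. The buffer descriptor is
// the { i8* baseAlloc, elemTy* ptr, i64 size } struct built in
// convertLinalgType above; the view descriptor is the lowered form of a
// strided memref, assumed here to be laid out as
// { elemTy* ptr, i64 offset, i64 sizes[Rank], i64 strides[Rank] }.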
static constexpr int kBasePtrPosInBuffer = 0;
static constexpr int kPtrPosInBuffer = 1;
static constexpr int kSizePosInBuffer = 2;
static constexpr int kPtrPosInView = 0;
static constexpr int kOffsetPosInView = 1;
static constexpr int kSizePosInView = 2;
static constexpr int kStridePosInView = 3;

namespace {
/// Factor out the common information for all view conversions:
///   1. common types (in the standard and LLVM dialects)
///   2. `pos` method
///   3. view descriptor construction `desc`.
class BaseViewConversionHelper {
public:
  BaseViewConversionHelper(Location loc, MemRefType memRefType,
                           ConversionPatternRewriter &rewriter,
                           LLVMTypeConverter &lowering)
      : zeroDMemRef(memRefType.getRank() == 0),
        elementTy(getPtrToElementType(memRefType, lowering)),
        int64Ty(
            lowering.convertType(rewriter.getIntegerType(64)).cast<LLVMType>()),
        desc(nullptr), rewriter(rewriter) {
    assert(isStrided(memRefType) && "expected strided memref type");
    viewDescriptorTy = lowering.convertType(memRefType).cast<LLVMType>();
    desc = rewriter.create<LLVM::UndefOp>(loc, viewDescriptorTy);
  }

  ArrayAttr pos(ArrayRef<int64_t> values) const {
    return rewriter.getI64ArrayAttr(values);
  }

  bool zeroDMemRef;
  LLVMType elementTy, int64Ty, viewDescriptorTy;
  Value *desc;
  ConversionPatternRewriter &rewriter;
};
} // namespace

// RangeOp creates a new range descriptor.
class RangeOpConversion : public LLVMOpLowering {
public:
  explicit RangeOpConversion(MLIRContext *context, LLVMTypeConverter &lowering_)
      : LLVMOpLowering(RangeOp::getOperationName(), context, lowering_) {}

  PatternMatchResult
  matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto rangeOp = cast<RangeOp>(op);
    auto rangeDescriptorTy =
        convertLinalgType(rangeOp.getResult()->getType(), lowering);

    edsc::ScopedContext context(rewriter, op->getLoc());

    // Fill in an aggregate value of the descriptor.
    RangeOpOperandAdaptor adaptor(operands);
    Value *desc = llvm_undef(rangeDescriptorTy);
    desc = insertvalue(desc, adaptor.min(), rewriter.getI64ArrayAttr(0));
    desc = insertvalue(desc, adaptor.max(), rewriter.getI64ArrayAttr(1));
    desc = insertvalue(desc, adaptor.step(), rewriter.getI64ArrayAttr(2));
    rewriter.replaceOp(op, desc);
    return matchSuccess();
  }
};
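
// For example, the pattern above rewrites a linalg.range with operands
// (min, max, step) into an LLVM::UndefOp of the { i64, i64, i64 } range
// descriptor type followed by three LLVM::InsertValueOps that fill in the
// three fields.
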
/// Conversion pattern that transforms a linalg.slice op into:
///   1. An "undef" value for the resulting view descriptor.
///   2. Updates to the view descriptor to introduce the data ptr, offset,
///      size and stride corresponding to the region of memory within the
///      bounds of the parent view.
/// The linalg.slice op is replaced by the updated view descriptor.
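/// For example, slicing a 2-d base view with a range along dimension 0 and a
/// single index along dimension 1 produces a 1-d view: the new offset is
/// base_offset + min0 * stride0 + index1 * stride1, the new size is clamped
/// to max(0, min(max0, size0) - min0), the new stride is stride0 * step0, and
/// the indexed dimension is projected away.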
class SliceOpConversion : public LLVMOpLowering {
public:
  explicit SliceOpConversion(MLIRContext *context, LLVMTypeConverter &lowering_)
      : LLVMOpLowering(SliceOp::getOperationName(), context, lowering_) {}

  PatternMatchResult
  matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
                  ConversionPatternRewriter &rewriter) const override {
    SliceOpOperandAdaptor adaptor(operands);
    Value *baseDesc = adaptor.view();

    auto sliceOp = cast<SliceOp>(op);
    auto memRefType = sliceOp.getBaseViewType();

    BaseViewConversionHelper helper(op->getLoc(), sliceOp.getViewType(),
                                    rewriter, lowering);
    LLVMType elementTy = helper.elementTy, int64Ty = helper.int64Ty;
    Value *desc = helper.desc;

    edsc::ScopedContext context(rewriter, op->getLoc());

    // TODO(ntv): extract sizes and emit asserts.
    SmallVector<Value *, 4> strides(memRefType.getRank());
    for (int i = 0, e = memRefType.getRank(); i < e; ++i)
      strides[i] =
          extractvalue(int64Ty, baseDesc, helper.pos({kStridePosInView, i}));

    // Compute base offset.
    Value *baseOffset =
        extractvalue(int64Ty, baseDesc, helper.pos(kOffsetPosInView));
    for (int i = 0, e = memRefType.getRank(); i < e; ++i) {
      Value *indexing = adaptor.indexings()[i];
      Value *min = indexing;
      if (sliceOp.indexing(i)->getType().isa<RangeType>())
        min = extractvalue(int64Ty, indexing, helper.pos(0));
      baseOffset = add(baseOffset, mul(min, strides[i]));
    }

    // Insert base pointer.
    auto ptrPos = helper.pos(kPtrPosInView);
    desc = insertvalue(desc, extractvalue(elementTy, baseDesc, ptrPos), ptrPos);

    // Insert base offset.
    desc = insertvalue(desc, baseOffset, helper.pos(kOffsetPosInView));

    // Corner case, no sizes or strides: early return the descriptor.
    if (helper.zeroDMemRef)
      return rewriter.replaceOp(op, desc), matchSuccess();

    Value *zero =
        constant(int64Ty, rewriter.getIntegerAttr(rewriter.getIndexType(), 0));
    // Compute and insert view sizes (max - min along the range) and strides.
    // Skip the non-range operands as they will be projected away from the
    // view.
    int numNewDims = 0;
    for (auto en : llvm::enumerate(sliceOp.indexings())) {
      Value *indexing = en.value();
      if (indexing->getType().isa<RangeType>()) {
        int rank = en.index();
        Value *rangeDescriptor = adaptor.indexings()[rank];
        Value *min = extractvalue(int64Ty, rangeDescriptor, helper.pos(0));
        Value *max = extractvalue(int64Ty, rangeDescriptor, helper.pos(1));
        Value *step = extractvalue(int64Ty, rangeDescriptor, helper.pos(2));
        Value *baseSize =
            extractvalue(int64Ty, baseDesc, helper.pos({kSizePosInView, rank}));
        // Bound upper by base view upper bound.
        max = llvm_select(llvm_icmp(ICmpPredicate::slt, max, baseSize), max,
                          baseSize);
        Value *size = sub(max, min);
        // Bound lower by zero.
        size =
            llvm_select(llvm_icmp(ICmpPredicate::slt, size, zero), zero, size);
        Value *stride = mul(strides[rank], step);
        desc =
            insertvalue(desc, size, helper.pos({kSizePosInView, numNewDims}));
        desc = insertvalue(desc, stride,
                           helper.pos({kStridePosInView, numNewDims}));
        ++numNewDims;
      }
    }

    rewriter.replaceOp(op, desc);
    return matchSuccess();
  }
};

/// Conversion pattern that transforms a linalg.transpose op into:
///   1. An "undef" value for the resulting view descriptor.
///   2. Updates to the view descriptor to introduce the data ptr, offset,
///      size and stride. Size and stride are permutations of the original
///      values.
/// The linalg.transpose op is replaced by the updated view descriptor.
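/// For example, with permutation (i, j) -> (j, i) the size and stride of
/// source dimension 0 are written into position 1 of the new descriptor and
/// vice versa; the data pointer and offset are copied through unchanged.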
class TransposeOpConversion : public LLVMOpLowering {
public:
  explicit TransposeOpConversion(MLIRContext *context,
                                 LLVMTypeConverter &lowering_)
      : LLVMOpLowering(TransposeOp::getOperationName(), context, lowering_) {}

  PatternMatchResult
  matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
                  ConversionPatternRewriter &rewriter) const override {
    TransposeOpOperandAdaptor adaptor(operands);
    Value *baseDesc = adaptor.view();

    auto transposeOp = cast<TransposeOp>(op);
    // No permutation, early exit.
    if (transposeOp.permutation().isIdentity())
      return rewriter.replaceOp(op, baseDesc), matchSuccess();

    // Initialize the common conversion boilerplate (new view descriptor).
    BaseViewConversionHelper helper(op->getLoc(), transposeOp.getViewType(),
                                    rewriter, lowering);
    LLVMType elementTy = helper.elementTy, int64Ty = helper.int64Ty;
    Value *desc = helper.desc;

    edsc::ScopedContext context(rewriter, op->getLoc());
    // Copy the base pointer from the old descriptor to the new one.
    ArrayAttr ptrPos = helper.pos(kPtrPosInView);
    desc = insertvalue(desc, extractvalue(elementTy, baseDesc, ptrPos), ptrPos);

    // Copy the offset from the old descriptor to the new one.
    ArrayAttr offPos = helper.pos(kOffsetPosInView);
    desc = insertvalue(desc, extractvalue(int64Ty, baseDesc, offPos), offPos);

    // Iterate over the dimensions and apply size/stride permutation.
    for (auto en : llvm::enumerate(transposeOp.permutation().getResults())) {
      int sourcePos = en.index();
      int targetPos = en.value().cast<AffineDimExpr>().getPosition();
      Value *size = extractvalue(int64Ty, baseDesc,
                                 helper.pos({kSizePosInView, sourcePos}));
      desc = insertvalue(desc, size, helper.pos({kSizePosInView, targetPos}));
      Value *stride = extractvalue(int64Ty, baseDesc,
                                   helper.pos({kStridePosInView, sourcePos}));
      desc =
          insertvalue(desc, stride, helper.pos({kStridePosInView, targetPos}));
    }

    rewriter.replaceOp(op, desc);
    return matchSuccess();
  }
};

// YieldOp produces an LLVM::ReturnOp.
class YieldOpConversion : public LLVMOpLowering {
public:
  explicit YieldOpConversion(MLIRContext *context, LLVMTypeConverter &lowering_)
      : LLVMOpLowering(YieldOp::getOperationName(), context, lowering_) {}

  PatternMatchResult
  matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
                  ConversionPatternRewriter &rewriter) const override {
    rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(op, operands);
    return matchSuccess();
  }
};

// Get a SymbolRefAttr containing the library function name for the LinalgOp.
// If the library function does not exist, insert a declaration.
template <typename LinalgOp>
static SymbolRefAttr getLibraryCallSymbolRef(Operation *op,
                                             PatternRewriter &rewriter) {
  auto linalgOp = cast<LinalgOp>(op);
  auto fnName = linalgOp.getLibraryCallName();
  if (fnName.empty()) {
    op->emitWarning("No library call defined for: ") << *op;
    return {};
  }

  // fnName is a dynamic std::string, unique it via a SymbolRefAttr.
  SymbolRefAttr fnNameAttr = rewriter.getSymbolRefAttr(fnName);
  auto module = op->getParentOfType<ModuleOp>();
  if (module.lookupSymbol(fnName)) {
    return fnNameAttr;
  }

  SmallVector<Type, 4> inputTypes(op->getOperandTypes());
  assert(op->getNumResults() == 0 &&
         "Library call for linalg operation can be generated only for ops "
         "that have void return types");
  auto libFnType = FunctionType::get(inputTypes, {}, rewriter.getContext());

  OpBuilder::InsertionGuard guard(rewriter);
  // Insert before module terminator.
  rewriter.setInsertionPoint(module.getBody(),
                             std::prev(module.getBody()->end()));
  rewriter.create<FuncOp>(op->getLoc(), fnNameAttr.getValue(), libFnType,
                          ArrayRef<NamedAttribute>{});
  return fnNameAttr;
}
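
// In other words, the first lowered linalg op with a given library call name
// triggers the creation of a matching external func declaration (taking the
// op's operand types and returning nothing) just before the module
// terminator; later ops with the same name reuse the existing symbol.
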
namespace {
// The conversion class from Linalg to LLVMIR.
class LinalgTypeConverter : public LLVMTypeConverter {
  using LLVMTypeConverter::LLVMTypeConverter;

public:
  Type convertType(Type t) override {
    if (auto result = LLVMTypeConverter::convertType(t))
      return result;
    return convertLinalgType(t, *this);
  }
};
} // end anonymous namespace

// LinalgOpConversion<LinalgOp> creates a new call to the
// `LinalgOp::getLibraryCallName()` function.
// The implementation of the function can be either in the same module or in an
// externally linked library.
template <typename LinalgOp>
class LinalgOpConversion : public OpRewritePattern<LinalgOp> {
public:
  using OpRewritePattern<LinalgOp>::OpRewritePattern;

  PatternMatchResult matchAndRewrite(LinalgOp op,
                                     PatternRewriter &rewriter) const override {
    auto libraryCallName = getLibraryCallSymbolRef<LinalgOp>(op, rewriter);
    if (!libraryCallName)
      return this->matchFailure();

    SmallVector<Value *, 4> operands(op.getOperands().begin(),
                                     op.getOperands().end());
    rewriter.replaceOpWithNewOp<mlir::CallOp>(op, libraryCallName.getValue(),
                                              ArrayRef<Type>{}, operands);
    return this->matchSuccess();
  }
};

/// Conversion pattern specialization for CopyOp. This kicks in when both input
/// and output permutations are left unspecified or are the identity.
template <> class LinalgOpConversion<CopyOp> : public OpRewritePattern<CopyOp> {
public:
  using OpRewritePattern<CopyOp>::OpRewritePattern;

  PatternMatchResult matchAndRewrite(CopyOp op,
                                     PatternRewriter &rewriter) const override {
    auto inputPerm = op.inputPermutation();
    if (inputPerm.hasValue() && !inputPerm->isIdentity())
      return matchFailure();
    auto outputPerm = op.outputPermutation();
    if (outputPerm.hasValue() && !outputPerm->isIdentity())
      return matchFailure();

    auto libraryCallName = getLibraryCallSymbolRef<CopyOp>(op, rewriter);
    if (!libraryCallName)
      return matchFailure();

    SmallVector<Value *, 4> operands(op.getOperands().begin(),
                                     op.getOperands().end());
    rewriter.replaceOpWithNewOp<mlir::CallOp>(op, libraryCallName.getValue(),
                                              ArrayRef<Type>{}, operands);
    return matchSuccess();
  }
};

/// A non-conversion rewrite pattern kicks in to convert CopyOp with
/// permutations into a sequence of TransposeOp and permutation-free CopyOp.
/// This interplays with TransposeOpConversion and LinalgOpConversion<CopyOp>
/// to create a path to the LLVM dialect.
class CopyTransposeConversion : public OpRewritePattern<CopyOp> {
public:
  using OpRewritePattern<CopyOp>::OpRewritePattern;

  PatternMatchResult matchAndRewrite(CopyOp op,
                                     PatternRewriter &rewriter) const override {
    Value *in = op.input(), *out = op.output();

    // If either inputPerm or outputPerm is a non-identity, insert transposes.
    auto inputPerm = op.inputPermutation();
    if (inputPerm.hasValue() && !inputPerm->isIdentity())
      in = rewriter.create<linalg::TransposeOp>(op.getLoc(), in,
                                                AffineMapAttr::get(*inputPerm));
    auto outputPerm = op.outputPermutation();
    if (outputPerm.hasValue() && !outputPerm->isIdentity())
      out = rewriter.create<linalg::TransposeOp>(
          op.getLoc(), out, AffineMapAttr::get(*outputPerm));

    // If nothing was transposed, fail and let the conversion kick in.
    if (in == op.input() && out == op.output())
      return matchFailure();

    rewriter.replaceOpWithNewOp<CopyOp>(op, in, out);
    return matchSuccess();
  }
};
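
// For example, a linalg.copy whose inputPermutation is (i, j) -> (j, i) is
// rewritten here into a linalg.transpose of the input followed by a
// permutation-free linalg.copy; the transpose is then lowered by
// TransposeOpConversion and the copy by LinalgOpConversion<CopyOp>.
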
/// A non-conversion rewrite pattern kicks in to convert SubViewOp into
/// RangeOps and SliceOps.
class SubViewOpConversion : public OpRewritePattern<SubViewOp> {
public:
  using OpRewritePattern<SubViewOp>::OpRewritePattern;

  PatternMatchResult matchAndRewrite(SubViewOp op,
                                     PatternRewriter &rewriter) const override {
    auto *view = op.getView();
    SmallVector<Value *, 8> ranges;
    for (auto sliceRange : op.getRanges())
      ranges.push_back(rewriter.create<RangeOp>(
          op.getLoc(), sliceRange.min, sliceRange.max, sliceRange.step));
    rewriter.replaceOpWithNewOp<SliceOp>(op, view, ranges);
    return matchSuccess();
  }
};

/// Populate the given list with patterns that convert from Linalg to Standard.
static void
populateLinalgToStandardConversionPatterns(OwningRewritePatternList &patterns,
                                           MLIRContext *ctx) {
  // TODO(ntv) ConvOp conversion needs to export a descriptor with relevant
  // attribute values such as kernel striding and dilation.
  patterns.insert<CopyTransposeConversion, LinalgOpConversion<CopyOp>,
                  LinalgOpConversion<DotOp>, LinalgOpConversion<FillOp>,
                  LinalgOpConversion<MatvecOp>, LinalgOpConversion<MatmulOp>,
                  LinalgOpConversion<ConvOp>, LinalgOpConversion<GenericOp>,
                  SubViewOpConversion>(ctx);
}

/// Populate the given list with patterns that convert from Linalg to LLVM.
static void
populateLinalgToLLVMConversionPatterns(LinalgTypeConverter &converter,
                                       OwningRewritePatternList &patterns,
                                       MLIRContext *ctx) {
  patterns.insert<RangeOpConversion, SliceOpConversion, TransposeOpConversion,
                  YieldOpConversion>(ctx, converter);
}

namespace {
struct LowerLinalgToLLVMPass : public ModulePass<LowerLinalgToLLVMPass> {
  void runOnModule() override;
};
} // namespace

void LowerLinalgToLLVMPass::runOnModule() {
  auto module = getModule();

  // Convert to the LLVM IR dialect using the converter defined above.
  OwningRewritePatternList patterns;
  LinalgTypeConverter converter(&getContext());
  populateAffineToStdConversionPatterns(patterns, &getContext());
  populateLoopToStdConversionPatterns(patterns, &getContext());
  populateStdToLLVMConversionPatterns(converter, patterns);
  populateVectorToLLVMConversionPatterns(converter, patterns);
  populateLinalgToStandardConversionPatterns(patterns, &getContext());
  populateLinalgToLLVMConversionPatterns(converter, patterns, &getContext());

  ConversionTarget target(getContext());
  target.addLegalDialect<LLVM::LLVMDialect>();
  target.addDynamicallyLegalOp<FuncOp>(
      [&](FuncOp op) { return converter.isSignatureLegal(op.getType()); });
  target.addLegalOp<ModuleOp, ModuleTerminatorOp>();
  if (failed(applyFullConversion(module, target, patterns, &converter)))
    signalPassFailure();
}

std::unique_ptr<OpPassBase<ModuleOp>>
mlir::linalg::createLowerLinalgToLLVMPass() {
  return std::make_unique<LowerLinalgToLLVMPass>();
}

static PassRegistration<LowerLinalgToLLVMPass>
    pass("convert-linalg-to-llvm",
         "Lower the operations from the linalg dialect into the LLVM dialect");
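
// Usage sketch (assuming a standard mlir-opt build from this revision): the
// registered pass can be invoked on a module containing linalg ops with
//   mlir-opt input.mlir -convert-linalg-to-llvm
// which runs the full conversion above and leaves only ops that are legal
// for the LLVM conversion target defined in runOnModule.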