//===- LowerToLLVMDialect.cpp - conversion from Linalg to LLVM dialect ----===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================

#include "mlir/EDSC/Builders.h"
#include "mlir/EDSC/Intrinsics.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Module.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/StandardTypes.h"
#include "mlir/IR/Types.h"
#include "mlir/LLVMIR/LLVMDialect.h"
#include "mlir/LLVMIR/LLVMLowering.h"
#include "mlir/LLVMIR/Transforms.h"
#include "mlir/Linalg/IR/LinalgOps.h"
#include "mlir/Linalg/IR/LinalgTypes.h"
#include "mlir/Linalg/Passes.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/Passes.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/Type.h"
#include "llvm/Support/Allocator.h"
#include "llvm/Support/ErrorHandling.h"

using namespace mlir;
using namespace mlir::edsc;
using namespace mlir::edsc::intrinsics;
using namespace mlir::LLVM;
using namespace mlir::linalg;

using undef = ValueBuilder<mlir::LLVM::UndefOp>;
using insertvalue = ValueBuilder<mlir::LLVM::InsertValueOp>;
using extractvalue = ValueBuilder<mlir::LLVM::ExtractValueOp>;
using constant = ValueBuilder<mlir::LLVM::ConstantOp>;
using add = ValueBuilder<mlir::LLVM::AddOp>;
using sub = ValueBuilder<mlir::LLVM::SubOp>;
using mul = ValueBuilder<mlir::LLVM::MulOp>;
using bitcast = ValueBuilder<mlir::LLVM::BitcastOp>;
using call = OperationBuilder<mlir::LLVM::CallOp>;
using gep = ValueBuilder<mlir::LLVM::GEPOp>;
using llvm_load = ValueBuilder<LLVM::LoadOp>;
using llvm_store = OperationBuilder<LLVM::StoreOp>;
using llvm_select = ValueBuilder<LLVM::SelectOp>;
using icmp = ValueBuilder<LLVM::ICmpOp>;

template <typename T>
static LLVMType getPtrToElementType(T containerType, LLVMLowering &lowering) {
  return lowering.convertType(containerType.getElementType())
      .template cast<LLVMType>()
      .getPointerTo();
}

// Convert the given type to the LLVM IR Dialect type. The following
// conversions are supported:
//   - an Index type is converted into an LLVM integer type with pointer
//     bitwidth (analogous to intptr_t in C);
//   - an Integer type is converted into an LLVM integer type of the same
//     width;
//   - an F32 type is converted into an LLVM float type;
//   - a Buffer, Range or View is converted into an LLVM structure type
//     containing the respective dynamic values.
static Type convertLinalgType(Type t, LLVMLowering &lowering) {
  auto *context = t.getContext();
  auto int64Ty = lowering.convertType(IntegerType::get(64, context))
                     .cast<LLVM::LLVMType>();

  // A buffer descriptor contains the pointer to a flat region of storage and
  // the size of the region.
  //
  // template <typename Elem, size_t Rank>
  // struct {
  //   Elem *ptr;
  //   int64_t size;
  // };
  if (auto bufferType = t.dyn_cast<BufferType>()) {
    auto ptrTy = getPtrToElementType(bufferType, lowering);
    return LLVMType::getStructTy(ptrTy, int64Ty);
  }

  // Range descriptor contains the range bounds and the step as 64-bit integers.
  //
  // struct {
  //   int64_t min;
  //   int64_t max;
  //   int64_t step;
  // };
  if (t.isa<RangeType>())
    return LLVMType::getStructTy(int64Ty, int64Ty, int64Ty);

  // View descriptor contains the pointer to the data buffer, followed by a
  // 64-bit integer containing the distance between the beginning of the buffer
  // and the first element to be accessed through the view, followed by two
  // arrays, each containing as many 64-bit integers as the rank of the View.
  // The first array represents the size, in number of original elements, of the
  // view along the given dimension. When taking the view, the size is the
  // difference between the upper and the lower bound of the range. The second
  // array represents the "stride" (in tensor abstraction sense), i.e. the
  // number of consecutive elements of the underlying buffer that separate two
  // consecutive elements addressable through the view along the given
  // dimension. When taking the view, the strides are constructed as products
  // of the original sizes along the trailing dimensions, multiplied by the view
  // step. For example, a view of an MxN memref with ranges {0:M:1}, {0:N:1},
  // i.e. the view of a complete memref, will have strides N and 1. A view with
  // ranges {0:M:2}, {0:N:3} will have strides 2*N and 3.
  //
  // template <typename Elem, size_t Rank>
  // struct {
  //   Elem *ptr;
  //   int64_t offset;
  //   int64_t sizes[Rank];
  //   int64_t strides[Rank];
  // };
  if (auto viewType = t.dyn_cast<ViewType>()) {
    auto ptrTy = getPtrToElementType(viewType, lowering);
    auto arrayTy = LLVMType::getArrayTy(int64Ty, viewType.getRank());
    return LLVMType::getStructTy(ptrTy, int64Ty, arrayTy, arrayTy);
  }

  return Type();
}

// Create an array attribute containing integer attributes with values provided
// in `position`.
static ArrayAttr positionAttr(Builder &builder, ArrayRef<int> position) {
  SmallVector<Attribute, 4> attrs;
  attrs.reserve(position.size());
  for (auto p : position)
    attrs.push_back(builder.getI64IntegerAttr(p));
  return builder.getArrayAttr(attrs);
}

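// For example, positionAttr(rewriter, {3, 1}) builds the i64 array attribute
// [3, 1], which the insertvalue/extractvalue builders below take as an
// aggregate position.
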
// BufferAllocOp creates a new `!linalg.buffer` value.
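// The allocation size is elementSize * size bytes; e.g. (illustrative) for
// vector<4xf32> elements, elementSize = 4 * ceil(32 / 8) = 16 bytes.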
class BufferAllocOpConversion : public LLVMOpLowering {
public:
  explicit BufferAllocOpConversion(MLIRContext *context,
                                   LLVMLowering &lowering_)
      : LLVMOpLowering(BufferAllocOp::getOperationName(), context, lowering_) {}

  void rewrite(Operation *op, ArrayRef<Value *> operands,
               PatternRewriter &rewriter) const override {
    auto indexType = IndexType::get(op->getContext());
    auto voidPtrTy =
        LLVM::LLVMType::getInt8Ty(lowering.getDialect()).getPointerTo();
    auto int64Ty = lowering.convertType(operands[0]->getType());
    // Insert the `malloc` declaration if it is not already present.
    auto *module = op->getFunction()->getModule();
    Function *mallocFunc = module->getNamedFunction("malloc");
    if (!mallocFunc) {
      auto mallocType = rewriter.getFunctionType(int64Ty, voidPtrTy);
      mallocFunc = new Function(rewriter.getUnknownLoc(), "malloc", mallocType);
      module->getFunctions().push_back(mallocFunc);
    }

    // Get MLIR types for injecting element pointer.
    auto allocOp = cast<BufferAllocOp>(op);
    auto elementType = allocOp.getElementType();
    uint64_t elementSize = 0;
    if (auto vectorType = elementType.dyn_cast<VectorType>())
      elementSize = vectorType.getNumElements() *
                    llvm::divideCeil(vectorType.getElementTypeBitWidth(), 8);
    else
      elementSize = llvm::divideCeil(elementType.getIntOrFloatBitWidth(), 8);
    auto elementPtrType = getPtrToElementType(
        allocOp.getResult()->getType().cast<BufferType>(), lowering);
    auto bufferDescriptorType =
        convertLinalgType(allocOp.getResult()->getType(), lowering);

    // Emit IR for creating a new buffer descriptor with an underlying malloc.
    edsc::ScopedContext context(rewriter, op->getLoc());
    Value *size = operands[0];
    Value *allocSize =
        mul(size, constant(int64Ty, IntegerAttr::get(indexType, elementSize)));
    Value *allocated =
        call(voidPtrTy, rewriter.getFunctionAttr(mallocFunc), allocSize)
            .getOperation()
            ->getResult(0);
    allocated = bitcast(elementPtrType, allocated);
    Value *desc = undef(bufferDescriptorType);
    desc = insertvalue(bufferDescriptorType, desc, allocated,
                       positionAttr(rewriter, 0));
    desc = insertvalue(bufferDescriptorType, desc, size,
                       positionAttr(rewriter, 1));
    rewriter.replaceOp(op, desc);
  }
};

// BufferDeallocOp creates no value.
class BufferDeallocOpConversion : public LLVMOpLowering {
public:
  explicit BufferDeallocOpConversion(MLIRContext *context,
                                     LLVMLowering &lowering_)
      : LLVMOpLowering(BufferDeallocOp::getOperationName(), context,
                       lowering_) {}

  void rewrite(Operation *op, ArrayRef<Value *> operands,
               PatternRewriter &rewriter) const override {
    auto voidPtrTy =
        LLVM::LLVMType::getInt8Ty(lowering.getDialect()).getPointerTo();
    // Insert the `free` declaration if it is not already present.
    auto *module = op->getFunction()->getModule();
    Function *freeFunc = module->getNamedFunction("free");
    if (!freeFunc) {
      auto freeType = rewriter.getFunctionType(voidPtrTy, {});
      freeFunc = new Function(rewriter.getUnknownLoc(), "free", freeType);
      module->getFunctions().push_back(freeFunc);
    }

    // Get MLIR types for extracting element pointer.
    auto deallocOp = cast<BufferDeallocOp>(op);
    auto elementPtrTy = getPtrToElementType(
        deallocOp.getOperand()->getType().cast<BufferType>(), lowering);

    // Emit MLIR for buffer_dealloc.
    edsc::ScopedContext context(rewriter, op->getLoc());
    Value *casted = bitcast(voidPtrTy, extractvalue(elementPtrTy, operands[0],
                                                    positionAttr(rewriter, 0)));
    call(ArrayRef<Type>(), rewriter.getFunctionAttr(freeFunc), casted);
    rewriter.replaceOp(op, llvm::None);
  }
};

// BufferSizeOp creates a new `index` value.
class BufferSizeOpConversion : public LLVMOpLowering {
public:
  BufferSizeOpConversion(MLIRContext *context, LLVMLowering &lowering_)
      : LLVMOpLowering(BufferSizeOp::getOperationName(), context, lowering_) {}

  void rewrite(Operation *op, ArrayRef<Value *> operands,
               PatternRewriter &rewriter) const override {
    auto int64Ty = lowering.convertType(operands[0]->getType());
    edsc::ScopedContext context(rewriter, op->getLoc());
    rewriter.replaceOp(
        op, {extractvalue(int64Ty, operands[0], positionAttr(rewriter, 1))});
  }
};

// DimOp creates a new `index` value.
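// The requested dimension is read from the view descriptor's sizes array,
// i.e. from position {2, d} of the struct laid out in convertLinalgType.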
class DimOpConversion : public LLVMOpLowering {
public:
  explicit DimOpConversion(MLIRContext *context, LLVMLowering &lowering_)
      : LLVMOpLowering(linalg::DimOp::getOperationName(), context, lowering_) {}

  void rewrite(Operation *op, ArrayRef<Value *> operands,
               PatternRewriter &rewriter) const override {
    auto dimOp = cast<linalg::DimOp>(op);
    auto indexTy = lowering.convertType(rewriter.getIndexType());
    edsc::ScopedContext context(rewriter, op->getLoc());
    rewriter.replaceOp(
        op,
        {extractvalue(
            indexTy, operands[0],
            positionAttr(rewriter, {2, static_cast<int>(dimOp.getIndex())}))});
  }
};

namespace {
// Common functionality for Linalg LoadOp and StoreOp conversion to the
// LLVM IR Dialect.
template <typename Op> class LoadStoreOpConversion : public LLVMOpLowering {
public:
  explicit LoadStoreOpConversion(MLIRContext *context, LLVMLowering &lowering_)
      : LLVMOpLowering(Op::getOperationName(), context, lowering_) {}
  using Base = LoadStoreOpConversion<Op>;

  // Compute the pointer to an element of the buffer underlying the view given
  // current view indices. Use the base offset and strides stored in the view
  // descriptor to emit IR iteratively computing the actual offset, followed by
  // a getelementptr. This must be called under an edsc::ScopedContext.
  Value *obtainDataPtr(Operation *op, Value *viewDescriptor,
                       ArrayRef<Value *> indices,
                       PatternRewriter &rewriter) const {
    auto loadOp = cast<Op>(op);
    auto elementTy = getPtrToElementType(loadOp.getViewType(), lowering);
    auto int64Ty = lowering.convertType(rewriter.getIntegerType(64));
    auto pos = [&rewriter](ArrayRef<int> values) {
      return positionAttr(rewriter, values);
    };

    // Linearize subscripts as:
    //   base_offset + SUM_i index_i * stride_i.
    Value *base = extractvalue(elementTy, viewDescriptor, pos(0));
    Value *offset = extractvalue(int64Ty, viewDescriptor, pos(1));
    for (int i = 0, e = loadOp.getRank(); i < e; ++i) {
      Value *stride = extractvalue(int64Ty, viewDescriptor, pos({3, i}));
      Value *additionalOffset = mul(indices[i], stride);
      offset = add(offset, additionalOffset);
    }
    return gep(elementTy, base, offset);
  }
};
} // namespace

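// For instance, for a rank-2 view with base offset `off` and strides
// `[s0, s1]` (illustrative names), obtainDataPtr computes the address of
// element (i, j) as `base + off + i * s0 + j * s1`.
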
// A load is converted into the actual address computation, getelementptr and
// an LLVM IR load.
class LoadOpConversion : public LoadStoreOpConversion<linalg::LoadOp> {
  using Base::Base;
  void rewrite(Operation *op, ArrayRef<Value *> operands,
               PatternRewriter &rewriter) const override {
    edsc::ScopedContext edscContext(rewriter, op->getLoc());
    auto elementTy = lowering.convertType(*op->result_type_begin());
    Value *viewDescriptor = operands[0];
    ArrayRef<Value *> indices = operands.drop_front();
    auto ptr = obtainDataPtr(op, viewDescriptor, indices, rewriter);
    rewriter.replaceOp(op, {llvm_load(elementTy, ptr)});
  }
};

// RangeOp creates a new range descriptor.
class RangeOpConversion : public LLVMOpLowering {
public:
  explicit RangeOpConversion(MLIRContext *context, LLVMLowering &lowering_)
      : LLVMOpLowering(RangeOp::getOperationName(), context, lowering_) {}

  void rewrite(Operation *op, ArrayRef<Value *> operands,
               PatternRewriter &rewriter) const override {
    auto rangeOp = cast<RangeOp>(op);
    auto rangeDescriptorTy =
        convertLinalgType(rangeOp.getResult()->getType(), lowering);

    edsc::ScopedContext context(rewriter, op->getLoc());

    // Fill in an aggregate value of the descriptor.
    Value *desc = undef(rangeDescriptorTy);
    desc = insertvalue(rangeDescriptorTy, desc, operands[0],
                       positionAttr(rewriter, 0));
    desc = insertvalue(rangeDescriptorTy, desc, operands[1],
                       positionAttr(rewriter, 1));
    desc = insertvalue(rangeDescriptorTy, desc, operands[2],
                       positionAttr(rewriter, 2));
    rewriter.replaceOp(op, desc);
  }
};

// RangeIntersectOp creates a new range descriptor.
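// The result takes the larger of the two mins, the smaller of the two maxes,
// and the product of the steps; e.g. intersecting {0:10:1} with {4:20:1}
// yields {4:10:1} (steps are assumed to be one, see the TODO below).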
class RangeIntersectOpConversion : public LLVMOpLowering {
public:
  explicit RangeIntersectOpConversion(MLIRContext *context,
                                      LLVMLowering &lowering_)
      : LLVMOpLowering(RangeIntersectOp::getOperationName(), context,
                       lowering_) {}

  void rewrite(Operation *op, ArrayRef<Value *> operands,
               PatternRewriter &rewriter) const override {
    auto rangeIntersectOp = cast<RangeIntersectOp>(op);
    auto rangeDescriptorTy =
        convertLinalgType(rangeIntersectOp.getResult()->getType(), lowering);
    auto int64Ty = lowering.convertType(rewriter.getIntegerType(64));
    auto int1Ty = lowering.convertType(rewriter.getIntegerType(1));

    edsc::ScopedContext context(rewriter, op->getLoc());
    auto min1 = extractvalue(int64Ty, operands[0], positionAttr(rewriter, 0));
    auto min2 = extractvalue(int64Ty, operands[1], positionAttr(rewriter, 0));
    auto max1 = extractvalue(int64Ty, operands[0], positionAttr(rewriter, 1));
    auto max2 = extractvalue(int64Ty, operands[1], positionAttr(rewriter, 1));
    auto step1 = extractvalue(int64Ty, operands[0], positionAttr(rewriter, 2));
    auto step2 = extractvalue(int64Ty, operands[1], positionAttr(rewriter, 2));

    // Fill in an aggregate value of the descriptor.
    auto SLE =
        rewriter.getI64IntegerAttr(static_cast<int64_t>(CmpIPredicate::SLE));
    auto SGE =
        rewriter.getI64IntegerAttr(static_cast<int64_t>(CmpIPredicate::SGE));
    Value *desc = undef(rangeDescriptorTy);
    desc = insertvalue(
        rangeDescriptorTy, desc,
        llvm_select(int64Ty, icmp(int1Ty, SGE, min1, min2), min1, min2),
        positionAttr(rewriter, 0));
    desc = insertvalue(
        rangeDescriptorTy, desc,
        llvm_select(int64Ty, icmp(int1Ty, SLE, max1, max2), max1, max2),
        positionAttr(rewriter, 1));
    // TODO(ntv): this assumes both steps are one for now. Enforce and extend.
    desc = insertvalue(rangeDescriptorTy, desc, mul(step1, step2),
                       positionAttr(rewriter, 2));
    rewriter.replaceOp(op, desc);
  }
};

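// SliceOp creates a new view descriptor from an existing view, given one
// indexing (a range or a single index) per dimension of the base view;
// dimensions indexed by a single value are projected away. For example
// (illustrative), slicing a 2-D view with ({2:6:1}, 3) yields a 1-D view whose
// base offset grows by 2 * stride0 + 3 * stride1, with size 6 - 2 = 4 and
// stride stride0 * 1.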
class SliceOpConversion : public LLVMOpLowering {
public:
  explicit SliceOpConversion(MLIRContext *context, LLVMLowering &lowering_)
      : LLVMOpLowering(SliceOp::getOperationName(), context, lowering_) {}

  void rewrite(Operation *op, ArrayRef<Value *> operands,
               PatternRewriter &rewriter) const override {
    auto sliceOp = cast<SliceOp>(op);
    auto viewDescriptorTy = convertLinalgType(sliceOp.getViewType(), lowering);
    auto viewType = sliceOp.getBaseViewType();
    auto int64Ty = lowering.convertType(rewriter.getIntegerType(64));

    // Helper function to create an integer array attribute out of a list of
    // values.
    auto pos = [&rewriter](ArrayRef<int> values) {
      return positionAttr(rewriter, values);
    };
    // Helper function to obtain the ptr of the given `view`.
    auto getViewPtr = [pos, &rewriter, this](ViewType type,
                                             Value *view) -> Value * {
      auto elementPtrTy = getPtrToElementType(type, lowering);
      return extractvalue(elementPtrTy, view, pos(0));
    };

    edsc::ScopedContext context(rewriter, op->getLoc());
    // Declare the view descriptor and insert data ptr.
    Value *desc = undef(viewDescriptorTy);
    desc = insertvalue(viewDescriptorTy, desc,
                       getViewPtr(viewType, operands[0]), pos(0));

    // TODO(ntv): extract sizes and emit asserts.
    SmallVector<Value *, 4> strides(viewType.getRank());
    for (int dim = 0, e = viewType.getRank(); dim < e; ++dim) {
      strides[dim] = extractvalue(int64Ty, operands[0], pos({3, dim}));
    }

    // Compute and insert base offset.
    Value *baseOffset = extractvalue(int64Ty, operands[0], pos(1));
    for (int j = 0, e = viewType.getRank(); j < e; ++j) {
      Value *indexing = operands[1 + j];
      Value *min =
          sliceOp.getIndexing(j)->getType().isa<RangeType>()
              ? static_cast<Value *>(extractvalue(int64Ty, indexing, pos(0)))
              : indexing;
      Value *product = mul(min, strides[j]);
      baseOffset = add(baseOffset, product);
    }
    desc = insertvalue(viewDescriptorTy, desc, baseOffset, pos(1));

    // Compute and insert view sizes (max - min along the range). Skip the
    // non-range operands as they will be projected away from the view.
    int i = 0;
    for (Value *index : sliceOp.getIndexings()) {
      if (!index->getType().isa<RangeType>())
        continue;

      Value *rangeDescriptor = operands[1 + i];
      Value *min = extractvalue(int64Ty, rangeDescriptor, pos(0));
      Value *max = extractvalue(int64Ty, rangeDescriptor, pos(1));
      Value *size = sub(max, min);

      desc = insertvalue(viewDescriptorTy, desc, size, pos({2, i}));
      ++i;
    }

    // Compute and insert view strides. Step over the strides that correspond
    // to non-range operands as they are projected away from the view.
    i = 0;
    for (int j = 0, e = strides.size(); j < e; ++j) {
      if (!sliceOp.getIndexing(j)->getType().isa<RangeType>())
        continue;
      Value *step = extractvalue(int64Ty, operands[1 + j], pos(2));
      Value *stride = mul(strides[j], step);
      desc = insertvalue(viewDescriptorTy, desc, stride, pos({3, i}));
      ++i;
    }

    rewriter.replaceOp(op, desc);
  }
};

// A store is converted into the actual address computation, getelementptr and
// an LLVM IR store.
class StoreOpConversion : public LoadStoreOpConversion<linalg::StoreOp> {
  using Base::Base;
  void rewrite(Operation *op, ArrayRef<Value *> operands,
               PatternRewriter &rewriter) const override {
    edsc::ScopedContext edscContext(rewriter, op->getLoc());
    Value *data = operands[0];
    Value *viewDescriptor = operands[1];
    ArrayRef<Value *> indices = operands.drop_front(2);
    Value *ptr = obtainDataPtr(op, viewDescriptor, indices, rewriter);
    llvm_store(data, ptr);
    rewriter.replaceOp(op, llvm::None);
  }
};

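// ViewOp creates a new view descriptor over a buffer, given one range per
// dimension. For example (illustrative), viewing a buffer through the ranges
// {0:M:1} and {0:N:1} produces sizes [M, N] and strides [N, 1]: the running
// stride starts at 1 and is multiplied by the max of each range while walking
// the indexings from the innermost dimension outwards.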
class ViewOpConversion : public LLVMOpLowering {
public:
  explicit ViewOpConversion(MLIRContext *context, LLVMLowering &lowering_)
      : LLVMOpLowering(ViewOp::getOperationName(), context, lowering_) {}

  void rewrite(Operation *op, ArrayRef<Value *> operands,
               PatternRewriter &rewriter) const override {
    auto viewOp = cast<ViewOp>(op);
    auto viewDescriptorTy = convertLinalgType(viewOp.getViewType(), lowering);
    auto elementTy = getPtrToElementType(viewOp.getViewType(), lowering);
    auto int64Ty = lowering.convertType(rewriter.getIntegerType(64));

    auto pos = [&rewriter](ArrayRef<int> values) {
      return positionAttr(rewriter, values);
    };

    // First operand to `view` is the buffer descriptor.
    Value *bufferDescriptor = operands[0];

    // Declare the descriptor of the view.
    edsc::ScopedContext context(rewriter, op->getLoc());
    Value *desc = undef(viewDescriptorTy);

    // Copy the buffer pointer from the old descriptor to the new one.
    Value *buffer = extractvalue(elementTy, bufferDescriptor, pos(0));
    desc = insertvalue(viewDescriptorTy, desc, buffer, pos(0));

    // Zero base offset.
    auto indexTy = rewriter.getIndexType();
    Value *baseOffset = constant(int64Ty, IntegerAttr::get(indexTy, 0));
    desc = insertvalue(viewDescriptorTy, desc, baseOffset, pos(1));

    // Compute and insert view sizes (max - min along the range) and strides.
    int numIndexings = llvm::size(viewOp.getIndexings());
    Value *runningStride = constant(int64Ty, IntegerAttr::get(indexTy, 1));
    for (int i = numIndexings - 1; i >= 0; --i) {
      // Update stride.
      Value *rangeDescriptor = operands[1 + i];
      Value *step = extractvalue(int64Ty, rangeDescriptor, pos(2));
      Value *stride = mul(runningStride, step);
      desc = insertvalue(viewDescriptorTy, desc, stride, pos({3, i}));
      // Update size.
      Value *min = extractvalue(int64Ty, rangeDescriptor, pos(0));
      Value *max = extractvalue(int64Ty, rangeDescriptor, pos(1));
      Value *size = sub(max, min);
      desc = insertvalue(viewDescriptorTy, desc, size, pos({2, i}));
      // Update stride for the next dimension.
      if (i > 0)
        runningStride = mul(runningStride, max);
    }

    rewriter.replaceOp(op, desc);
  }
};

// DotOp is lowered to a call into an external library function.
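// The enclosing module is expected to already declare that function (named
// "linalg_dot", see libraryFunctionName below); the converted operands, i.e.
// the view descriptors, are forwarded to the call unchanged.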
class DotOpConversion : public LLVMOpLowering {
public:
  explicit DotOpConversion(MLIRContext *context, LLVMLowering &lowering_)
      : LLVMOpLowering(DotOp::getOperationName(), context, lowering_) {}

  static StringRef libraryFunctionName() { return "linalg_dot"; }

  void rewrite(Operation *op, ArrayRef<Value *> operands,
               PatternRewriter &rewriter) const override {
    auto *f =
        op->getFunction()->getModule()->getNamedFunction(libraryFunctionName());
    if (!f) {
      op->emitError("Could not find function: " + libraryFunctionName() +
                    " in lowering to LLVM");
      return;
    }

    auto fAttr = rewriter.getFunctionAttr(f);
    auto named = rewriter.getNamedAttr("callee", fAttr);
    rewriter.replaceOpWithNewOp<LLVM::CallOp>(op, operands,
                                              ArrayRef<NamedAttribute>{named});
  }
};

namespace {
// The conversion class from Linalg to LLVMIR.
class Lowering : public LLVMLowering {
protected:
  void initAdditionalConverters(OwningRewritePatternList &patterns) override {
    RewriteListBuilder<BufferAllocOpConversion, BufferDeallocOpConversion,
                       BufferSizeOpConversion, DimOpConversion, DotOpConversion,
                       LoadOpConversion, RangeOpConversion,
                       RangeIntersectOpConversion, SliceOpConversion,
                       StoreOpConversion,
                       ViewOpConversion>::build(patterns,
                                                llvmDialect->getContext(),
                                                *this);
  }

  Type convertAdditionalType(Type t) override {
    return convertLinalgType(t, *this);
  }
};
} // end anonymous namespace

namespace {
struct LowerLinalgToLLVMPass : public ModulePass<LowerLinalgToLLVMPass> {
  void runOnModule();
};
} // namespace

void LowerLinalgToLLVMPass::runOnModule() {
  auto &module = getModule();

  // First lower affine constructs left in the IR to the standard dialect,
  // using a nested pass manager.
  PassManager pm;
  pm.addPass(createLowerAffinePass());
  if (failed(pm.run(&module)))
    signalPassFailure();

  // Convert to the LLVM IR dialect using the converter defined above.
  Lowering lowering;
  if (failed(applyConverter(module, lowering)))
    signalPassFailure();
}

ModulePassBase *mlir::linalg::createLowerLinalgToLLVMPass() {
  return new LowerLinalgToLLVMPass();
}

static PassRegistration<LowerLinalgToLLVMPass>
    pass("linalg-lower-to-llvm-dialect",
         "Lower the operations from the linalg dialect into the LLVM dialect");
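
// The registration above exposes this pass under the
// -linalg-lower-to-llvm-dialect flag; a typical (illustrative) invocation is
//   mlir-opt input.mlir -linalg-lower-to-llvm-dialect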