mirror of
https://github.com/intel/llvm.git
synced 2026-02-02 18:18:09 +08:00
365 lines
14 KiB
C++
365 lines
14 KiB
C++
//===- DialectConversion.cpp - MLIR dialect conversion generic pass -------===//
|
|
//
|
|
// Copyright 2019 The MLIR Authors.
|
|
//
|
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
// you may not use this file except in compliance with the License.
|
|
// You may obtain a copy of the License at
|
|
//
|
|
// http://www.apache.org/licenses/LICENSE-2.0
|
|
//
|
|
// Unless required by applicable law or agreed to in writing, software
|
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
// See the License for the specific language governing permissions and
|
|
// limitations under the License.
|
|
// =============================================================================
|
|
//
|
|
// This file implements a generic pass for converting between MLIR dialects.
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#include "mlir/Transforms/DialectConversion.h"
|
|
#include "mlir/IR/BlockAndValueMapping.h"
|
|
#include "mlir/IR/Builders.h"
|
|
#include "mlir/IR/Function.h"
|
|
#include "mlir/IR/Module.h"
|
|
#include "mlir/Transforms/Utils.h"
|
|
|
|
using namespace mlir;
|
|
|
|
namespace mlir {
|
|
namespace impl {
|
|
// Implementation detail class of the DialectConversion pass. Performs
|
|
// function-by-function conversions by creating new functions, filling them in
|
|
// with converted blocks, updating the function attributes, and replacing the
|
|
// old functions with the new ones in the module.
|
|
class FunctionConversion {
|
|
public:
|
|
// Entry point. Uses hooks defined in `conversion` to obtain the list of
|
|
// conversion patterns and to convert function and block argument types.
|
|
// Converts the `module` in-place by replacing all existing functions with the
|
|
// converted ones.
|
|
static LogicalResult convert(DialectConversion *conversion, Module *module);
|
|
|
|
private:
|
|
// Constructs a FunctionConversion by storing the hooks.
|
|
explicit FunctionConversion(DialectConversion *conversion)
|
|
: dialectConversion(conversion) {}
|
|
|
|
// Utility that looks up a list of value in the value remapping table. Returns
|
|
// an empty vector if one of the values is not mapped yet.
|
|
SmallVector<Value *, 4> lookupValues(Operation::operand_range operands);
|
|
|
|
// Converts the given function to the dialect using hooks defined in
|
|
// `dialectConversion`. Returns the converted function or `nullptr` on error.
|
|
Function *convertFunction(Function *f);
|
|
|
|
// Converts the given region starting from the entry block and following the
|
|
// block successors. Returns the converted region or `nullptr` on error.
|
|
template <typename RegionParent>
|
|
std::unique_ptr<Region> convertRegion(MLIRContext *context, Region *region,
|
|
RegionParent *parent);
|
|
|
|
// Converts an operation with successors. Extracts the converted operands
|
|
// from `valueRemapping` and the converted blocks from `blockRemapping`, and
|
|
// passes them to `converter->rewriteTerminator` function defined in the
|
|
// pattern, together with `builder`.
|
|
LogicalResult convertOpWithSuccessors(DialectOpConversion *converter,
|
|
Operation *op, FuncBuilder &builder);
|
|
|
|
// Converts an operation without successors. Extracts the converted operands
|
|
// from `valueRemapping` and passes them to the `converter->rewrite` function
|
|
// defined in the pattern, together with `builder`.
|
|
LogicalResult convertOp(DialectOpConversion *converter, Operation *op,
|
|
FuncBuilder &builder);
|
|
|
|
// Converts a block by traversing its operations sequentially, looking for
|
|
// the first pattern match and dispatching the operation conversion to
|
|
// either `convertOp` or `convertOpWithSuccessors` depending on the presence
|
|
// of successors. If there is no match, clones the operation.
|
|
//
|
|
// After converting operations, traverses the successor blocks unless they
|
|
// have been visited already as indicated in `visitedBlocks`.
|
|
LogicalResult convertBlock(Block *block, FuncBuilder &builder,
|
|
llvm::DenseSet<Block *> &visitedBlocks);
|
|
|
|
// Converts the module as follows.
|
|
// 1. Call `convertFunction` on each function of the module and collect the
|
|
// mapping between old and new functions.
|
|
// 2. Remap all function attributes in the new functions to point to the new
|
|
// functions instead of the old ones.
|
|
// 3. Replace old functions with the new in the module.
|
|
LogicalResult run(Module *m);
|
|
|
|
// Pointer to a specific dialect pass.
|
|
DialectConversion *dialectConversion;
|
|
|
|
// Set of known conversion patterns.
|
|
llvm::DenseSet<DialectOpConversion *> conversions;
|
|
|
|
// Mapping between values(blocks) in the original function and in the new
|
|
// function.
|
|
BlockAndValueMapping mapping;
|
|
};
|
|
} // end namespace impl
|
|
} // end namespace mlir
|
|
|
|
// Remaps every operand through `mapping`, preserving operand order. An empty
// vector is returned as soon as one operand has no mapping yet, which callers
// use to detect ops converted before the ops defining their operands.
SmallVector<Value *, 4>
impl::FunctionConversion::lookupValues(Operation::operand_range operands) {
  SmallVector<Value *, 4> result;
  result.reserve(llvm::size(operands));
  for (auto *original : operands) {
    if (auto *converted = mapping.lookupOrNull(original)) {
      result.push_back(converted);
      continue;
    }
    // Unmapped operand: signal failure with an empty result.
    return {};
  }
  return result;
}
|
|
|
|
// Converts a terminator-like op that has successors. The op's operand list is
// split into the leading "own" operands and, after them, one contiguous group
// of forwarded operands per successor (in successor order). All of them are
// remapped through `mapping`, the successors are replaced by their converted
// blocks, and the pieces are handed to `converter->rewriteTerminator`.
LogicalResult impl::FunctionConversion::convertOpWithSuccessors(
    DialectOpConversion *converter, Operation *op, FuncBuilder &builder) {
  SmallVector<Value *, 4> operands = lookupValues(op->getOperands());
  assert((!operands.empty() || op->getNumOperands() == 0) &&
         "converting op before ops defining its operands");

  unsigned numSuccessors = op->getNumSuccessors();

  // Count the operands forwarded to successors; they occupy the tail of the
  // operand list, so this locates where they begin.
  unsigned totalSuccessorOperands = 0;
  for (unsigned succIdx = 0; succIdx < numSuccessors; ++succIdx)
    totalSuccessorOperands += op->getNumSuccessorOperands(succIdx);
  unsigned successorOperandsBegin =
      op->getNumOperands() - totalSuccessorOperands;

  SmallVector<Block *, 2> destinations;
  destinations.reserve(numSuccessors);
  SmallVector<ArrayRef<Value *>, 2> operandsPerDestination;

  // Walk the tail of the converted operand list, slicing off one group per
  // successor.
  Value **cursor = operands.data() + successorOperandsBegin;
  for (unsigned succIdx = 0; succIdx < numSuccessors; ++succIdx) {
    Block *newSuccessor = mapping.lookupOrNull(op->getSuccessor(succIdx));
    assert(newSuccessor && "block was not remapped");
    destinations.push_back(newSuccessor);

    unsigned count = op->getNumSuccessorOperands(succIdx);
    operandsPerDestination.push_back(llvm::makeArrayRef(cursor, count));
    cursor += count;
  }

  // `operandsPerDestination` points into `operands`, which stays alive until
  // the end of this function, i.e. past the rewrite call.
  ArrayRef<Value *> nonSuccessorOperands =
      llvm::makeArrayRef(operands.data(), successorOperandsBegin);
  converter->rewriteTerminator(op, nonSuccessorOperands, destinations,
                               operandsPerDestination, builder);
  return success();
}
|
|
|
|
// Converts an op without successors: remaps its operands, asks the pattern to
// rewrite it, and records each returned value as the replacement of the
// corresponding original result. Fails if the pattern returns a different
// number of values than the op has results.
LogicalResult
impl::FunctionConversion::convertOp(DialectOpConversion *converter,
                                    Operation *op, FuncBuilder &builder) {
  SmallVector<Value *, 4> newOperands = lookupValues(op->getOperands());
  assert((!newOperands.empty() || op->getNumOperands() == 0) &&
         "converting op before ops defining its operands");

  auto newResults = converter->rewrite(op, newOperands, builder);
  bool sameArity = newResults.size() == op->getNumResults();
  if (!sameArity)
    return (op->emitError("rewriting produced a different number of results"),
            failure());

  // Later users of the original results are rewired through `mapping`.
  for (unsigned idx = 0, count = newResults.size(); idx != count; ++idx)
    mapping.map(op->getResult(idx), newResults[idx]);
  return success();
}
|
|
|
|
// Converts `block` into its already-created counterpart (looked up in
// `mapping`), then recursively converts successor blocks that have not been
// visited yet (DFS preorder, so a block's dominators are converted first).
//
// Each op is rewritten by the first matching pattern in `conversions`. Ops with
// successors go through `convertOpWithSuccessors`, others through
// `convertOp`. Ops with no matching pattern are cloned, and their regions are
// converted recursively. Returns failure if any op, nested region or successor
// fails to convert.
LogicalResult
impl::FunctionConversion::convertBlock(Block *block, FuncBuilder &builder,
                                       llvm::DenseSet<Block *> &visitedBlocks) {
  // First, add the current block to the list of visited blocks.
  visitedBlocks.insert(block);
  // Set up the builder to insert into the converted block.
  builder.setInsertionPointToStart(mapping.lookupOrNull(block));

  // Iterate over ops and convert them.
  for (Operation &op : *block) {
    // Find the first matching conversion and apply it.
    // NOTE(review): `conversions` is a DenseSet, so "first" depends on its
    // unspecified iteration order when several patterns match.
    bool converted = false;
    for (auto *conversion : conversions) {
      if (!conversion->match(&op))
        continue;

      if (op.getNumSuccessors() != 0) {
        if (failed(convertOpWithSuccessors(conversion, &op, builder)))
          return failure();
      } else if (failed(convertOp(conversion, &op, builder))) {
        return failure();
      }
      converted = true;
      break;
    }
    if (converted)
      continue;

    // No conversion is provided for the op: clone it and convert its
    // regions, if any.
    auto *newOp = builder.cloneWithoutRegions(op, mapping);
    for (unsigned i = 0, e = op.getNumRegions(); i < e; ++i) {
      auto newRegion = convertRegion(op.getContext(), &op.getRegion(i), &op);
      // convertRegion returns nullptr on error; propagate the failure instead
      // of dereferencing a null region.
      if (!newRegion)
        return failure();
      newOp->getRegion(i).takeBody(*newRegion);
    }
  }

  // Recurse to children unless they have been already visited.
  for (Block *succ : block->getSuccessors()) {
    if (visitedBlocks.count(succ) != 0)
      continue;
    if (failed(convertBlock(succ, builder, visitedBlocks)))
      return failure();
  }
  return success();
}
|
|
|
|
template <typename RegionParent>
|
|
std::unique_ptr<Region>
|
|
impl::FunctionConversion::convertRegion(MLIRContext *context, Region *region,
|
|
RegionParent *parent) {
|
|
assert(region && "expected a region");
|
|
auto newRegion = llvm::make_unique<Region>(parent);
|
|
if (region->empty())
|
|
return newRegion;
|
|
|
|
auto emitError = [context](llvm::Twine f) -> std::unique_ptr<Region> {
|
|
context->emitError(UnknownLoc::get(context), f.str());
|
|
return nullptr;
|
|
};
|
|
|
|
// Create new blocks and convert their arguments.
|
|
for (Block &block : *region) {
|
|
auto *newBlock = new Block;
|
|
newRegion->push_back(newBlock);
|
|
mapping.map(&block, newBlock);
|
|
for (auto *arg : block.getArguments()) {
|
|
auto convertedType = dialectConversion->convertType(arg->getType());
|
|
if (!convertedType)
|
|
return emitError("could not convert block argument type");
|
|
newBlock->addArgument(convertedType);
|
|
mapping.map(arg, *newBlock->args_rbegin());
|
|
}
|
|
}
|
|
|
|
// Start a DFS-order traversal of the CFG to make sure defs are converted
|
|
// before uses in dominated blocks.
|
|
llvm::DenseSet<Block *> visitedBlocks;
|
|
FuncBuilder builder(&newRegion->front());
|
|
if (failed(convertBlock(®ion->front(), builder, visitedBlocks)))
|
|
return nullptr;
|
|
|
|
// If some blocks are not reachable through successor chains, they should have
|
|
// been removed by the DCE before this.
|
|
|
|
if (visitedBlocks.size() != std::distance(region->begin(), region->end()))
|
|
return emitError("unreachable blocks were not converted");
|
|
return newRegion;
|
|
}
|
|
|
|
// Builds a new function whose signature is converted by
// `dialectConversion->convertFunctionSignatureType`. It has the same location,
// name and attributes as `f`, and its body is converted by `convertRegion`.
// Returns the new function, with ownership passed to the caller, or nullptr
// (after emitting an error) if the signature or the body cannot be converted.
Function *impl::FunctionConversion::convertFunction(Function *f) {
  assert(f && "expected function");
  MLIRContext *context = f->getContext();
  // Reports `message` at an unknown location and returns nullptr. The
  // parameter is named so that it does not shadow the function `f` being
  // converted; Twine is taken by const reference per LLVM's Twine rules.
  auto emitError = [context](const llvm::Twine &message) -> Function * {
    context->emitError(UnknownLoc::get(context), message.str());
    return nullptr;
  };

  // Create a new function with argument types and result types converted. Wrap
  // it into a unique_ptr to make sure it is cleaned up in case of error.
  SmallVector<NamedAttributeList, 4> newFunctionArgAttrs;
  Type newFunctionType = dialectConversion->convertFunctionSignatureType(
      f->getType(), f->getAllArgAttrs(), newFunctionArgAttrs);
  if (!newFunctionType)
    return emitError("could not convert function type");
  auto newFunction = llvm::make_unique<Function>(
      f->getLoc(), f->getName().strref(), newFunctionType.cast<FunctionType>(),
      f->getAttrs(), newFunctionArgAttrs);

  // Return early if the function has no blocks: there is no body to convert.
  if (f->getBlocks().empty())
    return newFunction.release();

  auto newBody = convertRegion(context, &f->getBody(), f);
  if (!newBody)
    return emitError("could not convert function body");
  newFunction->getBody().takeBody(*newBody);

  // Ownership passes to the caller, which must insert the function into a
  // module or delete it.
  return newFunction.release();
}
|
|
|
|
// Static entry point: builds a fresh converter bound to `conversion`, so that
// no state (patterns, value mapping) is shared between calls, and runs it on
// `module`.
LogicalResult impl::FunctionConversion::convert(DialectConversion *conversion,
                                                Module *module) {
  FunctionConversion converter(conversion);
  return converter.run(module);
}
|
|
|
|
// Converts every function of `module`. Converted functions are kept outside
// the module until all of them succeed. Then function attributes inside them
// are remapped from old to new functions, the originals are erased, and the
// converted functions are appended to the module. On failure the module is
// left unchanged and any already-converted functions are destroyed.
LogicalResult impl::FunctionConversion::run(Module *module) {
  if (!module)
    return failure();

  MLIRContext *context = module->getContext();
  conversions = dialectConversion->initConverters(context);

  // Snapshot the original functions first. Converted functions are not added
  // to the module yet, so they are not converted again.
  SmallVector<Function *, 0> originalFuncs;
  originalFuncs.reserve(module->getFunctions().size());
  for (auto &func : *module)
    originalFuncs.push_back(&func);

  // Own the converted functions until they are handed over to the module, so
  // they are released if the conversion of a later function fails (previously
  // they leaked on that path).
  SmallVector<std::unique_ptr<Function>, 0> convertedFuncs;
  convertedFuncs.reserve(originalFuncs.size());
  DenseMap<Attribute, FunctionAttr> functionAttrRemapping;
  for (auto *func : originalFuncs) {
    std::unique_ptr<Function> converted(convertFunction(func));
    if (!converted)
      return failure();

    auto origFuncAttr = FunctionAttr::get(func, context);
    auto convertedFuncAttr = FunctionAttr::get(converted.get(), context);
    functionAttrRemapping.insert({origFuncAttr, convertedFuncAttr});
    convertedFuncs.push_back(std::move(converted));
  }

  // Remap function attributes in the converted functions (they are not yet in
  // the module). Original functions will disappear anyway so there is no
  // need to remap attributes in them. Iterating `convertedFuncs` visits the
  // same functions as the map's values, in deterministic order.
  for (auto &func : convertedFuncs)
    remapFunctionAttrs(*func, functionAttrRemapping);

  // Remove original functions from the module, then insert converted
  // functions. The order is important to avoid name collisions.
  for (auto *func : originalFuncs)
    func->erase();
  for (auto &func : convertedFuncs)
    module->getFunctions().push_back(func.release());

  return success();
}
|
|
|
|
// Create a function type with arguments and results converted, and argument
|
|
// attributes passed through.
|
|
FunctionType DialectConversion::convertFunctionSignatureType(
|
|
FunctionType type, ArrayRef<NamedAttributeList> argAttrs,
|
|
SmallVectorImpl<NamedAttributeList> &convertedArgAttrs) {
|
|
SmallVector<Type, 8> arguments;
|
|
SmallVector<Type, 4> results;
|
|
|
|
arguments.reserve(type.getNumInputs());
|
|
for (auto t : type.getInputs())
|
|
arguments.push_back(convertType(t));
|
|
|
|
results.reserve(type.getNumResults());
|
|
for (auto t : type.getResults())
|
|
results.push_back(convertType(t));
|
|
|
|
// Note this will cause an extra allocation only if we need
|
|
// to grow the caller-provided resulting attribute vector.
|
|
convertedArgAttrs.reserve(arguments.size());
|
|
for (auto attr : argAttrs)
|
|
convertedArgAttrs.push_back(attr);
|
|
|
|
return FunctionType::get(arguments, results, type.getContext());
|
|
}
|
|
|
|
// Converts all functions of `m` in place using the hooks (patterns and type
// conversions) of this DialectConversion. The work is done by
// impl::FunctionConversion.
LogicalResult DialectConversion::convert(Module *m) {
  return impl::FunctionConversion::convert(/*conversion=*/this, /*module=*/m);
}
|