Use llvm.func to define functions with wrapped LLVM IR function type

This function-like operation allows one to define functions that have wrapped
LLVM IR function type, in particular variadic functions. The operation was
added in parallel to the existing lowering flow, this commit only switches the
flow to use it.

Using a custom function type makes the LLVM IR dialect type system more
consistent and avoids complex conversion rules for functions that previously
had to use the built-in function type instead of a wrapped LLVM IR dialect type
and perform conversions during the analysis.

PiperOrigin-RevId: 273910855
This commit is contained in:
Alex Zinenko
2019-10-10 01:33:33 -07:00
committed by A. Unique TensorFlower
parent 309b4556d0
commit 5e7959a353
29 changed files with 324 additions and 307 deletions

View File

@@ -39,38 +39,6 @@
namespace mlir {
namespace LLVM {
// Convert an MLIR function type to LLVM IR. Arguments of the function must of
// MLIR LLVM IR dialect types. Use `loc` as a location when reporting errors.
// Return nullptr on errors.
static llvm::FunctionType *convertFunctionType(llvm::LLVMContext &llvmContext,
FunctionType type, Location loc,
bool isVarArgs) {
assert(type && "expected non-null type");
if (type.getNumResults() > 1)
return emitError(loc, "LLVM functions can only have 0 or 1 result"),
nullptr;
SmallVector<llvm::Type *, 8> argTypes;
argTypes.reserve(type.getNumInputs());
for (auto t : type.getInputs()) {
auto wrappedLLVMType = t.dyn_cast<LLVM::LLVMType>();
if (!wrappedLLVMType)
return emitError(loc, "non-LLVM function argument type"), nullptr;
argTypes.push_back(wrappedLLVMType.getUnderlyingType());
}
if (type.getNumResults() == 0)
return llvm::FunctionType::get(llvm::Type::getVoidTy(llvmContext), argTypes,
isVarArgs);
auto wrappedResultType = type.getResult(0).dyn_cast<LLVM::LLVMType>();
if (!wrappedResultType)
return emitError(loc, "non-LLVM function result"), nullptr;
return llvm::FunctionType::get(wrappedResultType.getUnderlyingType(),
argTypes, isVarArgs);
}
// Create an LLVM IR constant of `llvmType` from the MLIR attribute `attr`.
// This currently supports integer, floating point, splat and dense element
// attributes and combinations thereof. In case of error, report it to `loc`
@@ -362,7 +330,7 @@ static Value *getPHISourceValue(Block *current, Block *pred,
: terminator.getSuccessorOperand(1, index);
}
void ModuleTranslation::connectPHINodes(FuncOp func) {
void ModuleTranslation::connectPHINodes(LLVMFuncOp func) {
// Skip the first block, it cannot be branched to and its arguments correspond
// to the arguments of the LLVM function.
for (auto it = std::next(func.begin()), eit = func.end(); it != eit; ++it) {
@@ -393,7 +361,7 @@ static void topologicalSortImpl(llvm::SetVector<Block *> &blocks, Block *b) {
}
// Sort function blocks topologically.
static llvm::SetVector<Block *> topologicalSort(FuncOp f) {
static llvm::SetVector<Block *> topologicalSort(LLVMFuncOp f) {
// For each blocks that has not been visited yet (i.e. that has no
// predecessors), add it to the list and traverse its successors in DFS
// preorder.
@@ -407,7 +375,7 @@ static llvm::SetVector<Block *> topologicalSort(FuncOp f) {
return blocks;
}
LogicalResult ModuleTranslation::convertOneFunction(FuncOp func) {
LogicalResult ModuleTranslation::convertOneFunction(LLVMFuncOp func) {
// Clear the block and value mappings, they are only relevant within one
// function.
blockMapping.clear();
@@ -460,24 +428,17 @@ LogicalResult ModuleTranslation::convertOneFunction(FuncOp func) {
LogicalResult ModuleTranslation::convertFunctions() {
// Declare all functions first because there may be function calls that form a
// call graph with cycles.
for (FuncOp function : mlirModule.getOps<FuncOp>()) {
mlir::BoolAttr isVarArgsAttr =
function.getAttrOfType<BoolAttr>("std.varargs");
bool isVarArgs = isVarArgsAttr && isVarArgsAttr.getValue();
llvm::FunctionType *functionType =
convertFunctionType(llvmModule->getContext(), function.getType(),
function.getLoc(), isVarArgs);
if (!functionType)
return failure();
llvm::FunctionCallee llvmFuncCst =
llvmModule->getOrInsertFunction(function.getName(), functionType);
for (auto function : mlirModule.getOps<LLVMFuncOp>()) {
llvm::FunctionCallee llvmFuncCst = llvmModule->getOrInsertFunction(
function.getName(),
llvm::cast<llvm::FunctionType>(function.getType().getUnderlyingType()));
assert(isa<llvm::Function>(llvmFuncCst.getCallee()));
functionMapping[function.getName()] =
cast<llvm::Function>(llvmFuncCst.getCallee());
}
// Convert functions.
for (FuncOp function : mlirModule.getOps<FuncOp>()) {
for (auto function : mlirModule.getOps<LLVMFuncOp>()) {
// Ignore external functions.
if (function.isExternal())
continue;