JitRunner: support entry functions returning void

JitRunner can use as entry points functions that produce either a single
'!llvm.f32' value or a list of memrefs.  Memref support is legacy and was
introduced before MLIR could lower memref allocation and deallocation to
malloc/free calls so as to allocate the memory externally, and is likely to be
dropped in the future since it unconditionally runs affine+standard-to-llvm
lowering on the module instead of accepting the LLVM dialect.  CUDA runner
relies on memref-based flow in the runner without actually returning anything.
Introduce a runner flow to use functions that return void as entry points.

PiperOrigin-RevId: 264381686
This commit is contained in:
Alex Zinenko
2019-08-20 07:45:47 -07:00
committed by A. Unique TensorFlower
parent 0f974817b5
commit 0d82a292b0
2 changed files with 51 additions and 32 deletions

View File

@@ -70,7 +70,7 @@ static llvm::cl::opt<std::string>
static llvm::cl::opt<std::string> mainFuncType(
"entry-point-result",
llvm::cl::desc("Textual description of the function type to be called"),
llvm::cl::value_desc("f32 or memrefs"), llvm::cl::init("memrefs"));
llvm::cl::value_desc("f32 | memrefs | void"), llvm::cl::init("memrefs"));
static llvm::cl::OptionCategory optFlags("opt-like flags");
@@ -166,6 +166,37 @@ static LogicalResult convertAffineStandardToLLVMIR(ModuleOp module) {
return manager.run(module);
}
// JIT-compile the given module and run "entryPoint" with "args" as arguments.
static Error
compileAndExecute(ModuleOp module, StringRef entryPoint,
std::function<llvm::Error(llvm::Module *)> transformer,
void **args) {
SmallVector<StringRef, 4> libs(clSharedLibs.begin(), clSharedLibs.end());
auto expectedEngine =
mlir::ExecutionEngine::create(module, transformer, libs);
if (!expectedEngine)
return expectedEngine.takeError();
auto engine = std::move(*expectedEngine);
auto expectedFPtr = engine->lookup(entryPoint);
if (!expectedFPtr)
return expectedFPtr.takeError();
void (*fptr)(void **) = *expectedFPtr;
(*fptr)(args);
return Error::success();
}
static Error compileAndExecuteVoidFunction(
ModuleOp module, StringRef entryPoint,
std::function<llvm::Error(llvm::Module *)> transformer) {
FuncOp mainFunction = module.lookupSymbol<FuncOp>(entryPoint);
if (!mainFunction || mainFunction.getBlocks().empty())
return make_string_error("entry point not found");
void *empty = nullptr;
return compileAndExecute(module, entryPoint, transformer, &empty);
}
static Error compileAndExecuteFunctionWithMemRefs(
ModuleOp module, StringRef entryPoint,
std::function<llvm::Error(llvm::Module *)> transformer) {
@@ -191,21 +222,12 @@ static Error compileAndExecuteFunctionWithMemRefs(
if (failed(convertAffineStandardToLLVMIR(module)))
return make_string_error("conversion to the LLVM IR dialect failed");
SmallVector<StringRef, 4> libs(clSharedLibs.begin(), clSharedLibs.end());
auto expectedEngine =
mlir::ExecutionEngine::create(module, transformer, libs);
if (!expectedEngine)
return expectedEngine.takeError();
if (auto error = compileAndExecute(module, entryPoint, transformer,
expectedArguments->data()))
return error;
auto engine = std::move(*expectedEngine);
auto expectedFPtr = engine->lookup(entryPoint);
if (!expectedFPtr)
return expectedFPtr.takeError();
void (*fptr)(void **) = *expectedFPtr;
(*fptr)(expectedArguments->data());
printMemRefArguments(argTypes, resTypes, *expectedArguments);
freeMemRefArguments(*expectedArguments);
return Error::success();
}
@@ -230,24 +252,14 @@ static Error compileAndExecuteSingleFloatReturnFunction(
if (llvmTy != llvmTy->getFloatTy(llvmTy->getContext()))
return make_string_error("only single llvm.f32 function result supported");
SmallVector<StringRef, 4> libs(clSharedLibs.begin(), clSharedLibs.end());
auto expectedEngine =
mlir::ExecutionEngine::create(module, transformer, libs);
if (!expectedEngine)
return expectedEngine.takeError();
auto engine = std::move(*expectedEngine);
auto expectedFPtr = engine->lookup(entryPoint);
if (!expectedFPtr)
return expectedFPtr.takeError();
void (*fptr)(void **) = *expectedFPtr;
float res;
struct {
void *data;
} data;
data.data = &res;
(*fptr)((void **)&data);
if (auto error =
compileAndExecute(module, entryPoint, transformer, (void **)&data))
return error;
// Intentional printing of the output so we can test.
llvm::outs() << res;
@@ -320,11 +332,18 @@ int mlir::JitRunnerMain(
auto transformer = mlir::makeLLVMPassesTransformer(
passes, optLevel, /*targetMachine=*/tmOrError->get(), optPosition);
auto error = mainFuncType.getValue() == "f32"
? compileAndExecuteSingleFloatReturnFunction(
m.get(), mainFuncName.getValue(), transformer)
: compileAndExecuteFunctionWithMemRefs(
m.get(), mainFuncName.getValue(), transformer);
Error error = make_string_error("unsupported function type");
if (mainFuncType.getValue() == "f32")
error = compileAndExecuteSingleFloatReturnFunction(
m.get(), mainFuncName.getValue(), transformer);
else if (mainFuncType.getValue() == "memrefs")
error = compileAndExecuteFunctionWithMemRefs(
m.get(), mainFuncName.getValue(), transformer);
else if (mainFuncType.getValue() == "void")
error = compileAndExecuteVoidFunction(m.get(), mainFuncName.getValue(),
transformer);
int exitCode = EXIT_SUCCESS;
llvm::handleAllErrors(std::move(error),
[&exitCode](const llvm::ErrorInfoBase &info) {

View File

@@ -1,4 +1,4 @@
// RUN: mlir-cuda-runner %s --shared-libs=%cuda_wrapper_library_dir/libcuda-runtime-wrappers%shlibext | FileCheck %s
// RUN: mlir-cuda-runner %s --shared-libs=%cuda_wrapper_library_dir/libcuda-runtime-wrappers%shlibext --entry-point-result=void | FileCheck %s
func @other_func(%arg0 : f32, %arg1 : memref<?xf32>) {
%cst = constant 1 : index