mirror of
https://github.com/intel/llvm.git
synced 2026-01-25 01:07:04 +08:00
JitRunner: support entry functions returning void
JitRunner can use as entry points functions that produce either a single '!llvm.f32' value or a list of memrefs. Memref support is legacy and was introduced before MLIR could lower memref allocation and deallocation to malloc/free calls so as to allocate the memory externally, and is likely to be dropped in the future since it unconditionally runs affine+standard-to-llvm lowering on the module instead of accepting the LLVM dialect. CUDA runner relies on memref-based flow in the runner without actually returning anything. Introduce a runner flow to use functions that return void as entry points. PiperOrigin-RevId: 264381686
This commit is contained in:
committed by
A. Unique TensorFlower
parent
0f974817b5
commit
0d82a292b0
@@ -70,7 +70,7 @@ static llvm::cl::opt<std::string>
|
||||
static llvm::cl::opt<std::string> mainFuncType(
|
||||
"entry-point-result",
|
||||
llvm::cl::desc("Textual description of the function type to be called"),
|
||||
llvm::cl::value_desc("f32 or memrefs"), llvm::cl::init("memrefs"));
|
||||
llvm::cl::value_desc("f32 | memrefs | void"), llvm::cl::init("memrefs"));
|
||||
|
||||
static llvm::cl::OptionCategory optFlags("opt-like flags");
|
||||
|
||||
@@ -166,6 +166,37 @@ static LogicalResult convertAffineStandardToLLVMIR(ModuleOp module) {
|
||||
return manager.run(module);
|
||||
}
|
||||
|
||||
// JIT-compile the given module and run "entryPoint" with "args" as arguments.
|
||||
static Error
|
||||
compileAndExecute(ModuleOp module, StringRef entryPoint,
|
||||
std::function<llvm::Error(llvm::Module *)> transformer,
|
||||
void **args) {
|
||||
SmallVector<StringRef, 4> libs(clSharedLibs.begin(), clSharedLibs.end());
|
||||
auto expectedEngine =
|
||||
mlir::ExecutionEngine::create(module, transformer, libs);
|
||||
if (!expectedEngine)
|
||||
return expectedEngine.takeError();
|
||||
|
||||
auto engine = std::move(*expectedEngine);
|
||||
auto expectedFPtr = engine->lookup(entryPoint);
|
||||
if (!expectedFPtr)
|
||||
return expectedFPtr.takeError();
|
||||
void (*fptr)(void **) = *expectedFPtr;
|
||||
(*fptr)(args);
|
||||
|
||||
return Error::success();
|
||||
}
|
||||
|
||||
static Error compileAndExecuteVoidFunction(
|
||||
ModuleOp module, StringRef entryPoint,
|
||||
std::function<llvm::Error(llvm::Module *)> transformer) {
|
||||
FuncOp mainFunction = module.lookupSymbol<FuncOp>(entryPoint);
|
||||
if (!mainFunction || mainFunction.getBlocks().empty())
|
||||
return make_string_error("entry point not found");
|
||||
void *empty = nullptr;
|
||||
return compileAndExecute(module, entryPoint, transformer, &empty);
|
||||
}
|
||||
|
||||
static Error compileAndExecuteFunctionWithMemRefs(
|
||||
ModuleOp module, StringRef entryPoint,
|
||||
std::function<llvm::Error(llvm::Module *)> transformer) {
|
||||
@@ -191,21 +222,12 @@ static Error compileAndExecuteFunctionWithMemRefs(
|
||||
if (failed(convertAffineStandardToLLVMIR(module)))
|
||||
return make_string_error("conversion to the LLVM IR dialect failed");
|
||||
|
||||
SmallVector<StringRef, 4> libs(clSharedLibs.begin(), clSharedLibs.end());
|
||||
auto expectedEngine =
|
||||
mlir::ExecutionEngine::create(module, transformer, libs);
|
||||
if (!expectedEngine)
|
||||
return expectedEngine.takeError();
|
||||
if (auto error = compileAndExecute(module, entryPoint, transformer,
|
||||
expectedArguments->data()))
|
||||
return error;
|
||||
|
||||
auto engine = std::move(*expectedEngine);
|
||||
auto expectedFPtr = engine->lookup(entryPoint);
|
||||
if (!expectedFPtr)
|
||||
return expectedFPtr.takeError();
|
||||
void (*fptr)(void **) = *expectedFPtr;
|
||||
(*fptr)(expectedArguments->data());
|
||||
printMemRefArguments(argTypes, resTypes, *expectedArguments);
|
||||
freeMemRefArguments(*expectedArguments);
|
||||
|
||||
return Error::success();
|
||||
}
|
||||
|
||||
@@ -230,24 +252,14 @@ static Error compileAndExecuteSingleFloatReturnFunction(
|
||||
if (llvmTy != llvmTy->getFloatTy(llvmTy->getContext()))
|
||||
return make_string_error("only single llvm.f32 function result supported");
|
||||
|
||||
SmallVector<StringRef, 4> libs(clSharedLibs.begin(), clSharedLibs.end());
|
||||
auto expectedEngine =
|
||||
mlir::ExecutionEngine::create(module, transformer, libs);
|
||||
if (!expectedEngine)
|
||||
return expectedEngine.takeError();
|
||||
|
||||
auto engine = std::move(*expectedEngine);
|
||||
auto expectedFPtr = engine->lookup(entryPoint);
|
||||
if (!expectedFPtr)
|
||||
return expectedFPtr.takeError();
|
||||
void (*fptr)(void **) = *expectedFPtr;
|
||||
|
||||
float res;
|
||||
struct {
|
||||
void *data;
|
||||
} data;
|
||||
data.data = &res;
|
||||
(*fptr)((void **)&data);
|
||||
if (auto error =
|
||||
compileAndExecute(module, entryPoint, transformer, (void **)&data))
|
||||
return error;
|
||||
|
||||
// Intentional printing of the output so we can test.
|
||||
llvm::outs() << res;
|
||||
@@ -320,11 +332,18 @@ int mlir::JitRunnerMain(
|
||||
|
||||
auto transformer = mlir::makeLLVMPassesTransformer(
|
||||
passes, optLevel, /*targetMachine=*/tmOrError->get(), optPosition);
|
||||
auto error = mainFuncType.getValue() == "f32"
|
||||
? compileAndExecuteSingleFloatReturnFunction(
|
||||
m.get(), mainFuncName.getValue(), transformer)
|
||||
: compileAndExecuteFunctionWithMemRefs(
|
||||
m.get(), mainFuncName.getValue(), transformer);
|
||||
|
||||
Error error = make_string_error("unsupported function type");
|
||||
if (mainFuncType.getValue() == "f32")
|
||||
error = compileAndExecuteSingleFloatReturnFunction(
|
||||
m.get(), mainFuncName.getValue(), transformer);
|
||||
else if (mainFuncType.getValue() == "memrefs")
|
||||
error = compileAndExecuteFunctionWithMemRefs(
|
||||
m.get(), mainFuncName.getValue(), transformer);
|
||||
else if (mainFuncType.getValue() == "void")
|
||||
error = compileAndExecuteVoidFunction(m.get(), mainFuncName.getValue(),
|
||||
transformer);
|
||||
|
||||
int exitCode = EXIT_SUCCESS;
|
||||
llvm::handleAllErrors(std::move(error),
|
||||
[&exitCode](const llvm::ErrorInfoBase &info) {
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
// RUN: mlir-cuda-runner %s --shared-libs=%cuda_wrapper_library_dir/libcuda-runtime-wrappers%shlibext | FileCheck %s
|
||||
// RUN: mlir-cuda-runner %s --shared-libs=%cuda_wrapper_library_dir/libcuda-runtime-wrappers%shlibext --entry-point-result=void | FileCheck %s
|
||||
|
||||
func @other_func(%arg0 : f32, %arg1 : memref<?xf32>) {
|
||||
%cst = constant 1 : index
|
||||
|
||||
Reference in New Issue
Block a user