From 206e55cc1653795166b0aebc25390bcc46f452db Mon Sep 17 00:00:00 2001 From: River Riddle Date: Tue, 2 Jul 2019 10:49:17 -0700 Subject: [PATCH] NFC: Refactor Module to be value typed. As with Functions, Module will soon become an operation, which are value-typed. This eases the transition from Module to ModuleOp. A new class, OwningModuleRef is provided to allow for owning a reference to a Module, and will auto-delete the held module on destruction. PiperOrigin-RevId: 256196193 --- mlir/bindings/python/pybind.cpp | 10 +- .../Linalg/Linalg1/include/linalg1/Common.h | 4 +- .../include/linalg1/ConvertToLLVMDialect.h | 2 +- .../Linalg1/lib/ConvertToLLVMDialect.cpp | 4 +- mlir/examples/Linalg/Linalg2/Example.cpp | 10 +- mlir/examples/Linalg/Linalg3/Conversion.cpp | 10 +- mlir/examples/Linalg/Linalg3/Example.cpp | 12 +- mlir/examples/Linalg/Linalg3/Execution.cpp | 10 +- .../include/linalg3/ConvertToLLVMDialect.h | 2 +- .../Linalg3/lib/ConvertToLLVMDialect.cpp | 2 +- mlir/examples/Linalg/Linalg4/Example.cpp | 14 +- mlir/examples/toy/Ch2/include/toy/MLIRGen.h | 5 +- mlir/examples/toy/Ch2/mlir/MLIRGen.cpp | 14 +- mlir/examples/toy/Ch2/toyc.cpp | 4 +- mlir/examples/toy/Ch3/include/toy/MLIRGen.h | 5 +- mlir/examples/toy/Ch3/mlir/MLIRGen.cpp | 10 +- mlir/examples/toy/Ch3/toyc.cpp | 4 +- mlir/examples/toy/Ch4/include/toy/MLIRGen.h | 5 +- mlir/examples/toy/Ch4/mlir/MLIRGen.cpp | 10 +- .../toy/Ch4/mlir/ShapeInferencePass.cpp | 2 +- mlir/examples/toy/Ch4/toyc.cpp | 8 +- mlir/examples/toy/Ch5/include/toy/MLIRGen.h | 5 +- mlir/examples/toy/Ch5/mlir/LateLowering.cpp | 6 +- mlir/examples/toy/Ch5/mlir/MLIRGen.cpp | 10 +- .../toy/Ch5/mlir/ShapeInferencePass.cpp | 4 +- mlir/examples/toy/Ch5/toyc.cpp | 16 +- mlir/g3doc/WritingAPass.md | 16 +- .../ConvertStandardToLLVMPass.h | 2 +- .../mlir/ExecutionEngine/ExecutionEngine.h | 2 +- mlir/include/mlir/IR/Builders.h | 4 +- mlir/include/mlir/IR/Function.h | 10 +- mlir/include/mlir/IR/Module.h | 145 ++++++++++++++---- mlir/include/mlir/IR/SymbolTable.h | 2 +- mlir/include/mlir/Parser.h | 10 +- mlir/include/mlir/Pass/AnalysisManager.h | 4 +- mlir/include/mlir/Pass/Pass.h | 7 +- mlir/include/mlir/Pass/PassManager.h | 2 +- mlir/include/mlir/Target/LLVMIR.h | 2 +- .../mlir/Target/LLVMIR/ModuleTranslation.h | 14 +- mlir/include/mlir/Target/NVVMIR.h | 3 +- .../mlir/Transforms/DialectConversion.h | 2 +- mlir/include/mlir/Translation.h | 6 +- .../GPUToCUDA/ConvertKernelFuncToCubin.cpp | 5 +- .../ConvertLaunchFuncToCudaCalls.cpp | 8 +- .../GPUToCUDA/GenerateCubinAccessors.cpp | 2 +- .../StandardToLLVM/ConvertStandardToLLVM.cpp | 16 +- mlir/lib/ExecutionEngine/ExecutionEngine.cpp | 4 +- mlir/lib/GPU/IR/GPUDialect.cpp | 4 +- mlir/lib/GPU/Transforms/KernelOutlining.cpp | 2 +- mlir/lib/IR/AsmPrinter.cpp | 16 +- mlir/lib/IR/Builders.cpp | 4 +- mlir/lib/IR/Function.cpp | 16 +- mlir/lib/IR/SymbolTable.cpp | 4 +- .../Linalg/Transforms/LowerToLLVMDialect.cpp | 22 +-- mlir/lib/Parser/Parser.cpp | 27 ++-- mlir/lib/Pass/IRPrinting.cpp | 6 +- mlir/lib/Pass/Pass.cpp | 6 +- mlir/lib/Pass/PassDetail.h | 2 +- .../SPIRV/Serialization/ConvertFromBinary.cpp | 10 +- .../SPIRV/Serialization/ConvertToBinary.cpp | 6 +- mlir/lib/StandardOps/Ops.cpp | 4 +- mlir/lib/Support/MlirOptMain.cpp | 4 +- mlir/lib/Support/TranslateClParser.cpp | 8 +- mlir/lib/Target/LLVMIR/ConvertToLLVMIR.cpp | 6 +- mlir/lib/Target/LLVMIR/ConvertToNVVMIR.cpp | 9 +- mlir/lib/Target/LLVMIR/ModuleTranslation.cpp | 8 +- mlir/lib/Transforms/DialectConversion.cpp | 2 +- mlir/test/EDSC/builder-api-test.cpp | 4 +- .../mlir-cpu-runner/mlir-cpu-runner-lib.cpp | 16 +- mlir/unittests/Pass/AnalysisManagerTest.cpp | 16 +- 70 files changed, 373 insertions(+), 283 deletions(-) diff --git a/mlir/bindings/python/pybind.cpp b/mlir/bindings/python/pybind.cpp index cdf4a7fe89cc..f730f8e48bbf 100644 --- a/mlir/bindings/python/pybind.cpp +++ b/mlir/bindings/python/pybind.cpp @@ -146,8 +146,8 @@ struct PythonFunction { /// Trivial C++ wrappers make use of the EDSC C API. struct PythonMLIRModule { PythonMLIRModule() - : mlirContext(), module(new mlir::Module(&mlirContext)), - moduleManager(module.get()) {} + : mlirContext(), module(mlir::Module::create(&mlirContext)), + moduleManager(*module) {} PythonType makeScalarType(const std::string &mlirElemType, unsigned bitwidth) { @@ -197,12 +197,12 @@ struct PythonMLIRModule { manager.addPass(mlir::createCSEPass()); manager.addPass(mlir::createLowerAffinePass()); manager.addPass(mlir::createConvertToLLVMIRPass()); - if (failed(manager.run(module.get()))) { + if (failed(manager.run(*module))) { llvm::errs() << "conversion to the LLVM IR dialect failed\n"; return; } - auto created = mlir::ExecutionEngine::create(module.get()); + auto created = mlir::ExecutionEngine::create(*module); llvm::handleAllErrors(created.takeError(), [](const llvm::ErrorInfoBase &b) { b.log(llvm::errs()); @@ -235,7 +235,7 @@ struct PythonMLIRModule { private: mlir::MLIRContext mlirContext; // One single module in a python-exposed MLIRContext for now. - std::unique_ptr module; + mlir::OwningModuleRef module; mlir::ModuleManager moduleManager; std::unique_ptr engine; }; diff --git a/mlir/examples/Linalg/Linalg1/include/linalg1/Common.h b/mlir/examples/Linalg/Linalg1/include/linalg1/Common.h index 1f129c6b2830..38b304e31f1d 100644 --- a/mlir/examples/Linalg/Linalg1/include/linalg1/Common.h +++ b/mlir/examples/Linalg/Linalg1/include/linalg1/Common.h @@ -57,7 +57,7 @@ inline mlir::MemRefType floatMemRefType(mlir::MLIRContext *context, } /// A basic function builder -inline mlir::Function makeFunction(mlir::Module &module, llvm::StringRef name, +inline mlir::Function makeFunction(mlir::Module module, llvm::StringRef name, llvm::ArrayRef types, llvm::ArrayRef resultTypes) { auto *context = module.getContext(); @@ -92,7 +92,7 @@ inline void cleanupAndPrintFunction(mlir::Function f) { } }; auto pm = cleanupPassManager(); - check(f.getModule()->verify()); + check(f.getModule().verify()); check(pm->run(f.getModule())); if (printToOuts) f.print(llvm::outs()); diff --git a/mlir/examples/Linalg/Linalg1/include/linalg1/ConvertToLLVMDialect.h b/mlir/examples/Linalg/Linalg1/include/linalg1/ConvertToLLVMDialect.h index e341705cb3be..fe77f4eef0fa 100644 --- a/mlir/examples/Linalg/Linalg1/include/linalg1/ConvertToLLVMDialect.h +++ b/mlir/examples/Linalg/Linalg1/include/linalg1/ConvertToLLVMDialect.h @@ -51,7 +51,7 @@ void populateLinalg1ToLLVMConversionPatterns( /// Convert the Linalg dialect types and RangeOp, ViewOp and SliceOp operations /// to the LLVM IR dialect types and operations in the given `module`. This is /// the main entry point to the conversion. -void convertToLLVM(mlir::Module &module); +void convertToLLVM(mlir::Module module); } // end namespace linalg #endif // LINALG1_CONVERTTOLLVMDIALECT_H_ diff --git a/mlir/examples/Linalg/Linalg1/lib/ConvertToLLVMDialect.cpp b/mlir/examples/Linalg/Linalg1/lib/ConvertToLLVMDialect.cpp index 0033107db23d..5d20063946a7 100644 --- a/mlir/examples/Linalg/Linalg1/lib/ConvertToLLVMDialect.cpp +++ b/mlir/examples/Linalg/Linalg1/lib/ConvertToLLVMDialect.cpp @@ -406,11 +406,11 @@ struct LinalgTypeConverter : public LLVMTypeConverter { }; } // end anonymous namespace -void linalg::convertToLLVM(mlir::Module &module) { +void linalg::convertToLLVM(mlir::Module module) { // Remove affine constructs if any by using an existing pass. PassManager pm; pm.addPass(createLowerAffinePass()); - auto rr = pm.run(&module); + auto rr = pm.run(module); (void)rr; assert(succeeded(rr) && "affine loop lowering failed"); diff --git a/mlir/examples/Linalg/Linalg2/Example.cpp b/mlir/examples/Linalg/Linalg2/Example.cpp index 9534711f1f4f..cb93b96cc589 100644 --- a/mlir/examples/Linalg/Linalg2/Example.cpp +++ b/mlir/examples/Linalg/Linalg2/Example.cpp @@ -34,10 +34,10 @@ using namespace linalg::intrinsics; TEST_FUNC(linalg_ops) { MLIRContext context; - Module module(&context); + OwningModuleRef module = Module::create(&context); auto indexType = mlir::IndexType::get(&context); - mlir::Function f = - makeFunction(module, "linalg_ops", {indexType, indexType, indexType}, {}); + mlir::Function f = makeFunction(*module, "linalg_ops", + {indexType, indexType, indexType}, {}); OpBuilder builder(f.getBody()); ScopedContext scope(builder, f.getLoc()); @@ -73,9 +73,9 @@ TEST_FUNC(linalg_ops) { TEST_FUNC(linalg_ops_folded_slices) { MLIRContext context; - Module module(&context); + OwningModuleRef module = Module::create(&context); auto indexType = mlir::IndexType::get(&context); - mlir::Function f = makeFunction(module, "linalg_ops_folded_slices", + mlir::Function f = makeFunction(*module, "linalg_ops_folded_slices", {indexType, indexType, indexType}, {}); OpBuilder builder(f.getBody()); diff --git a/mlir/examples/Linalg/Linalg3/Conversion.cpp b/mlir/examples/Linalg/Linalg3/Conversion.cpp index 23d1cfef5dc2..6bd428fe8a5e 100644 --- a/mlir/examples/Linalg/Linalg3/Conversion.cpp +++ b/mlir/examples/Linalg/Linalg3/Conversion.cpp @@ -37,7 +37,7 @@ using namespace linalg; using namespace linalg::common; using namespace linalg::intrinsics; -Function makeFunctionWithAMatmulOp(Module &module, StringRef name) { +Function makeFunctionWithAMatmulOp(Module module, StringRef name) { MLIRContext *context = module.getContext(); auto dynamic2DMemRefType = floatMemRefType<2>(context); mlir::Function f = linalg::common::makeFunction( @@ -66,11 +66,11 @@ Function makeFunctionWithAMatmulOp(Module &module, StringRef name) { TEST_FUNC(foo) { MLIRContext context; - Module module(&context); - mlir::Function f = makeFunctionWithAMatmulOp(module, "matmul_as_loops"); + OwningModuleRef module = Module::create(&context); + mlir::Function f = makeFunctionWithAMatmulOp(*module, "matmul_as_loops"); lowerToLoops(f); - convertLinalg3ToLLVM(module); + convertLinalg3ToLLVM(*module); // clang-format off // CHECK: {{.*}} = llvm.extractvalue {{.*}}[1] : !llvm<"{ float*, i64, [2 x i64], [2 x i64] }"> @@ -104,7 +104,7 @@ TEST_FUNC(foo) { // CHECK-NEXT: {{.*}} = llvm.getelementptr {{.*}}[{{.*}}] : (!llvm<"float*">, !llvm.i64) -> !llvm<"float*"> // CHECK-NEXT: llvm.store {{.*}}, {{.*}} : !llvm<"float*"> // clang-format on - module.print(llvm::outs()); + module->print(llvm::outs()); } int main() { diff --git a/mlir/examples/Linalg/Linalg3/Example.cpp b/mlir/examples/Linalg/Linalg3/Example.cpp index 8b04344b19eb..4ac6a009d622 100644 --- a/mlir/examples/Linalg/Linalg3/Example.cpp +++ b/mlir/examples/Linalg/Linalg3/Example.cpp @@ -34,7 +34,7 @@ using namespace linalg; using namespace linalg::common; using namespace linalg::intrinsics; -Function makeFunctionWithAMatmulOp(Module &module, StringRef name) { +Function makeFunctionWithAMatmulOp(Module module, StringRef name) { MLIRContext *context = module.getContext(); auto dynamic2DMemRefType = floatMemRefType<2>(context); mlir::Function f = linalg::common::makeFunction( @@ -63,7 +63,7 @@ Function makeFunctionWithAMatmulOp(Module &module, StringRef name) { TEST_FUNC(matmul_as_matvec) { MLIRContext context; - Module module(&context); + Module module = Module::create(&context); mlir::Function f = makeFunctionWithAMatmulOp(module, "matmul_as_matvec"); lowerToFinerGrainedTensorContraction(f); composeSliceOps(f); @@ -81,7 +81,7 @@ TEST_FUNC(matmul_as_matvec) { TEST_FUNC(matmul_as_dot) { MLIRContext context; - Module module(&context); + Module module = Module::create(&context); mlir::Function f = makeFunctionWithAMatmulOp(module, "matmul_as_dot"); lowerToFinerGrainedTensorContraction(f); lowerToFinerGrainedTensorContraction(f); @@ -102,7 +102,7 @@ TEST_FUNC(matmul_as_dot) { TEST_FUNC(matmul_as_loops) { MLIRContext context; - Module module(&context); + Module module = Module::create(&context); mlir::Function f = makeFunctionWithAMatmulOp(module, "matmul_as_loops"); lowerToLoops(f); composeSliceOps(f); @@ -134,7 +134,7 @@ TEST_FUNC(matmul_as_loops) { TEST_FUNC(matmul_as_matvec_as_loops) { MLIRContext context; - Module module(&context); + Module module = Module::create(&context); mlir::Function f = makeFunctionWithAMatmulOp(module, "matmul_as_matvec_as_loops"); lowerToFinerGrainedTensorContraction(f); @@ -165,7 +165,7 @@ TEST_FUNC(matmul_as_matvec_as_loops) { TEST_FUNC(matmul_as_matvec_as_affine) { MLIRContext context; - Module module(&context); + Module module = Module::create(&context); mlir::Function f = makeFunctionWithAMatmulOp(module, "matmul_as_matvec_as_affine"); lowerToFinerGrainedTensorContraction(f); diff --git a/mlir/examples/Linalg/Linalg3/Execution.cpp b/mlir/examples/Linalg/Linalg3/Execution.cpp index 94b233a56b01..4b1078791b74 100644 --- a/mlir/examples/Linalg/Linalg3/Execution.cpp +++ b/mlir/examples/Linalg/Linalg3/Execution.cpp @@ -37,7 +37,7 @@ using namespace linalg; using namespace linalg::common; using namespace linalg::intrinsics; -Function makeFunctionWithAMatmulOp(Module &module, StringRef name) { +Function makeFunctionWithAMatmulOp(Module module, StringRef name) { MLIRContext *context = module.getContext(); auto dynamic2DMemRefType = floatMemRefType<2>(context); mlir::Function f = linalg::common::makeFunction( @@ -109,14 +109,14 @@ TEST_FUNC(execution) { // linalg.matmul operation and lower it all the way down to the LLVM IR // dialect through partial conversions. MLIRContext context; - Module module(&context); - mlir::Function f = makeFunctionWithAMatmulOp(module, "matmul_as_loops"); + OwningModuleRef module = Module::create(&context); + mlir::Function f = makeFunctionWithAMatmulOp(*module, "matmul_as_loops"); lowerToLoops(f); - convertLinalg3ToLLVM(module); + convertLinalg3ToLLVM(*module); // Create an MLIR execution engine. The execution engine eagerly JIT-compiles // the module. - auto maybeEngine = mlir::ExecutionEngine::create(&module); + auto maybeEngine = mlir::ExecutionEngine::create(*module); assert(maybeEngine && "failed to construct an execution engine"); auto &engine = maybeEngine.get(); diff --git a/mlir/examples/Linalg/Linalg3/include/linalg3/ConvertToLLVMDialect.h b/mlir/examples/Linalg/Linalg3/include/linalg3/ConvertToLLVMDialect.h index 8f122e05d2f6..a1854ae77f99 100644 --- a/mlir/examples/Linalg/Linalg3/include/linalg3/ConvertToLLVMDialect.h +++ b/mlir/examples/Linalg/Linalg3/include/linalg3/ConvertToLLVMDialect.h @@ -23,7 +23,7 @@ class Module; } // end namespace mlir namespace linalg { -void convertLinalg3ToLLVM(mlir::Module &module); +void convertLinalg3ToLLVM(mlir::Module module); } // end namespace linalg #endif // LINALG3_CONVERTTOLLVMDIALECT_H_ diff --git a/mlir/examples/Linalg/Linalg3/lib/ConvertToLLVMDialect.cpp b/mlir/examples/Linalg/Linalg3/lib/ConvertToLLVMDialect.cpp index 96b0f371ef1a..a01a7fd79a39 100644 --- a/mlir/examples/Linalg/Linalg3/lib/ConvertToLLVMDialect.cpp +++ b/mlir/examples/Linalg/Linalg3/lib/ConvertToLLVMDialect.cpp @@ -146,7 +146,7 @@ static void populateLinalg3ToLLVMConversionPatterns( context); } -void linalg::convertLinalg3ToLLVM(Module &module) { +void linalg::convertLinalg3ToLLVM(Module module) { // Remove affine constructs. for (auto func : module) { auto rr = lowerAffineConstructs(func); diff --git a/mlir/examples/Linalg/Linalg4/Example.cpp b/mlir/examples/Linalg/Linalg4/Example.cpp index 873e57e78f33..d8ad7c693118 100644 --- a/mlir/examples/Linalg/Linalg4/Example.cpp +++ b/mlir/examples/Linalg/Linalg4/Example.cpp @@ -34,7 +34,7 @@ using namespace linalg; using namespace linalg::common; using namespace linalg::intrinsics; -Function makeFunctionWithAMatmulOp(Module &module, StringRef name) { +Function makeFunctionWithAMatmulOp(Module module, StringRef name) { MLIRContext *context = module.getContext(); auto dynamic2DMemRefType = floatMemRefType<2>(context); mlir::Function f = linalg::common::makeFunction( @@ -64,8 +64,8 @@ Function makeFunctionWithAMatmulOp(Module &module, StringRef name) { TEST_FUNC(matmul_tiled_loops) { MLIRContext context; - Module module(&context); - mlir::Function f = makeFunctionWithAMatmulOp(module, "matmul_tiled_loops"); + OwningModuleRef module = Module::create(&context); + mlir::Function f = makeFunctionWithAMatmulOp(*module, "matmul_tiled_loops"); lowerToTiledLoops(f, {8, 9}); PassManager pm; pm.addPass(createLowerLinalgLoadStorePass()); @@ -95,8 +95,8 @@ TEST_FUNC(matmul_tiled_loops) { TEST_FUNC(matmul_tiled_views) { MLIRContext context; - Module module(&context); - mlir::Function f = makeFunctionWithAMatmulOp(module, "matmul_tiled_views"); + OwningModuleRef module = Module::create(&context); + mlir::Function f = makeFunctionWithAMatmulOp(*module, "matmul_tiled_views"); OpBuilder b(f.getBody()); lowerToTiledViews(f, {b.create(f.getLoc(), 8), b.create(f.getLoc(), 9)}); @@ -124,9 +124,9 @@ TEST_FUNC(matmul_tiled_views) { TEST_FUNC(matmul_tiled_views_as_loops) { MLIRContext context; - Module module(&context); + OwningModuleRef module = Module::create(&context); mlir::Function f = - makeFunctionWithAMatmulOp(module, "matmul_tiled_views_as_loops"); + makeFunctionWithAMatmulOp(*module, "matmul_tiled_views_as_loops"); OpBuilder b(f.getBody()); lowerToTiledViews(f, {b.create(f.getLoc(), 8), b.create(f.getLoc(), 9)}); diff --git a/mlir/examples/toy/Ch2/include/toy/MLIRGen.h b/mlir/examples/toy/Ch2/include/toy/MLIRGen.h index 21637bc19af4..287f432c8479 100644 --- a/mlir/examples/toy/Ch2/include/toy/MLIRGen.h +++ b/mlir/examples/toy/Ch2/include/toy/MLIRGen.h @@ -27,7 +27,7 @@ namespace mlir { class MLIRContext; -class Module; +class OwningModuleRef; } // namespace mlir namespace toy { @@ -35,8 +35,7 @@ class ModuleAST; /// Emit IR for the given Toy moduleAST, returns a newly created MLIR module /// or nullptr on failure. -std::unique_ptr mlirGen(mlir::MLIRContext &context, - ModuleAST &moduleAST); +mlir::OwningModuleRef mlirGen(mlir::MLIRContext &context, ModuleAST &moduleAST); } // namespace toy #endif // MLIR_TUTORIAL_TOY_MLIRGEN_H_ diff --git a/mlir/examples/toy/Ch2/mlir/MLIRGen.cpp b/mlir/examples/toy/Ch2/mlir/MLIRGen.cpp index 73789fa41a4f..e3f1b19d25fa 100644 --- a/mlir/examples/toy/Ch2/mlir/MLIRGen.cpp +++ b/mlir/examples/toy/Ch2/mlir/MLIRGen.cpp @@ -66,22 +66,22 @@ public: /// Public API: convert the AST for a Toy module (source file) to an MLIR /// Module. - std::unique_ptr mlirGen(ModuleAST &moduleAST) { + mlir::Module mlirGen(ModuleAST &moduleAST) { // We create an empty MLIR module and codegen functions one at a time and // add them to the module. - theModule = make_unique(&context); + theModule = mlir::Module::create(&context); for (FunctionAST &F : moduleAST) { auto func = mlirGen(F); if (!func) return nullptr; - theModule->push_back(func); + theModule.push_back(func); } // FIXME: (in the next chapter...) without registering a dialect in MLIR, // this won't do much, but it should at least check some structural // properties. - if (failed(theModule->verify())) { + if (failed(theModule.verify())) { emitError(mlir::UnknownLoc::get(&context), "Module verification error"); return nullptr; } @@ -96,7 +96,7 @@ private: mlir::MLIRContext &context; /// A "module" matches a source file: it contains a list of functions. - std::unique_ptr theModule; + mlir::Module theModule; /// The builder is a helper class to create IR inside a function. It is /// re-initialized every time we enter a function and kept around as a @@ -500,8 +500,8 @@ private: namespace toy { // The public API for codegen. -std::unique_ptr mlirGen(mlir::MLIRContext &context, - ModuleAST &moduleAST) { +mlir::OwningModuleRef mlirGen(mlir::MLIRContext &context, + ModuleAST &moduleAST) { return MLIRGenImpl(context).mlirGen(moduleAST); } diff --git a/mlir/examples/toy/Ch2/toyc.cpp b/mlir/examples/toy/Ch2/toyc.cpp index 984676452fb4..b541486ab57b 100644 --- a/mlir/examples/toy/Ch2/toyc.cpp +++ b/mlir/examples/toy/Ch2/toyc.cpp @@ -75,7 +75,7 @@ std::unique_ptr parseInputFile(llvm::StringRef filename) { int dumpMLIR() { mlir::MLIRContext context; - std::unique_ptr module; + mlir::OwningModuleRef module; if (inputType == InputType::MLIR || llvm::StringRef(inputFilename).endswith(".mlir")) { llvm::ErrorOr> fileOrErr = @@ -86,7 +86,7 @@ int dumpMLIR() { } llvm::SourceMgr sourceMgr; sourceMgr.AddNewSourceBuffer(std::move(*fileOrErr), llvm::SMLoc()); - module.reset(mlir::parseSourceFile(sourceMgr, &context)); + module = mlir::parseSourceFile(sourceMgr, &context); if (!module) { llvm::errs() << "Error can't load file " << inputFilename << "\n"; return 3; diff --git a/mlir/examples/toy/Ch3/include/toy/MLIRGen.h b/mlir/examples/toy/Ch3/include/toy/MLIRGen.h index 21637bc19af4..287f432c8479 100644 --- a/mlir/examples/toy/Ch3/include/toy/MLIRGen.h +++ b/mlir/examples/toy/Ch3/include/toy/MLIRGen.h @@ -27,7 +27,7 @@ namespace mlir { class MLIRContext; -class Module; +class OwningModuleRef; } // namespace mlir namespace toy { @@ -35,8 +35,7 @@ class ModuleAST; /// Emit IR for the given Toy moduleAST, returns a newly created MLIR module /// or nullptr on failure. -std::unique_ptr mlirGen(mlir::MLIRContext &context, - ModuleAST &moduleAST); +mlir::OwningModuleRef mlirGen(mlir::MLIRContext &context, ModuleAST &moduleAST); } // namespace toy #endif // MLIR_TUTORIAL_TOY_MLIRGEN_H_ diff --git a/mlir/examples/toy/Ch3/mlir/MLIRGen.cpp b/mlir/examples/toy/Ch3/mlir/MLIRGen.cpp index 23cb85309c29..5d2e3af5ef24 100644 --- a/mlir/examples/toy/Ch3/mlir/MLIRGen.cpp +++ b/mlir/examples/toy/Ch3/mlir/MLIRGen.cpp @@ -67,10 +67,10 @@ public: /// Public API: convert the AST for a Toy module (source file) to an MLIR /// Module. - std::unique_ptr mlirGen(ModuleAST &moduleAST) { + mlir::OwningModuleRef mlirGen(ModuleAST &moduleAST) { // We create an empty MLIR module and codegen functions one at a time and // add them to the module. - theModule = make_unique(&context); + theModule = mlir::Module::create(&context); for (FunctionAST &F : moduleAST) { auto func = mlirGen(F); @@ -97,7 +97,7 @@ private: mlir::MLIRContext &context; /// A "module" matches a source file: it contains a list of functions. - std::unique_ptr theModule; + mlir::OwningModuleRef theModule; /// The builder is a helper class to create IR inside a function. It is /// re-initialized every time we enter a function and kept around as a @@ -469,8 +469,8 @@ private: namespace toy { // The public API for codegen. -std::unique_ptr mlirGen(mlir::MLIRContext &context, - ModuleAST &moduleAST) { +mlir::OwningModuleRef mlirGen(mlir::MLIRContext &context, + ModuleAST &moduleAST) { return MLIRGenImpl(context).mlirGen(moduleAST); } diff --git a/mlir/examples/toy/Ch3/toyc.cpp b/mlir/examples/toy/Ch3/toyc.cpp index 3d18417dc8aa..864dc420d485 100644 --- a/mlir/examples/toy/Ch3/toyc.cpp +++ b/mlir/examples/toy/Ch3/toyc.cpp @@ -79,7 +79,7 @@ int dumpMLIR() { mlir::registerDialect(); mlir::MLIRContext context; - std::unique_ptr module; + mlir::OwningModuleRef module; if (inputType == InputType::MLIR || llvm::StringRef(inputFilename).endswith(".mlir")) { llvm::ErrorOr> fileOrErr = @@ -90,7 +90,7 @@ int dumpMLIR() { } llvm::SourceMgr sourceMgr; sourceMgr.AddNewSourceBuffer(std::move(*fileOrErr), llvm::SMLoc()); - module.reset(mlir::parseSourceFile(sourceMgr, &context)); + module = mlir::parseSourceFile(sourceMgr, &context); if (!module) { llvm::errs() << "Error can't load file " << inputFilename << "\n"; return 3; diff --git a/mlir/examples/toy/Ch4/include/toy/MLIRGen.h b/mlir/examples/toy/Ch4/include/toy/MLIRGen.h index 21637bc19af4..287f432c8479 100644 --- a/mlir/examples/toy/Ch4/include/toy/MLIRGen.h +++ b/mlir/examples/toy/Ch4/include/toy/MLIRGen.h @@ -27,7 +27,7 @@ namespace mlir { class MLIRContext; -class Module; +class OwningModuleRef; } // namespace mlir namespace toy { @@ -35,8 +35,7 @@ class ModuleAST; /// Emit IR for the given Toy moduleAST, returns a newly created MLIR module /// or nullptr on failure. -std::unique_ptr mlirGen(mlir::MLIRContext &context, - ModuleAST &moduleAST); +mlir::OwningModuleRef mlirGen(mlir::MLIRContext &context, ModuleAST &moduleAST); } // namespace toy #endif // MLIR_TUTORIAL_TOY_MLIRGEN_H_ diff --git a/mlir/examples/toy/Ch4/mlir/MLIRGen.cpp b/mlir/examples/toy/Ch4/mlir/MLIRGen.cpp index f2132c29c336..a48212424af0 100644 --- a/mlir/examples/toy/Ch4/mlir/MLIRGen.cpp +++ b/mlir/examples/toy/Ch4/mlir/MLIRGen.cpp @@ -67,10 +67,10 @@ public: /// Public API: convert the AST for a Toy module (source file) to an MLIR /// Module. - std::unique_ptr mlirGen(ModuleAST &moduleAST) { + mlir::OwningModuleRef mlirGen(ModuleAST &moduleAST) { // We create an empty MLIR module and codegen functions one at a time and // add them to the module. - theModule = make_unique(&context); + theModule = mlir::Module::create(&context); for (FunctionAST &F : moduleAST) { auto func = mlirGen(F); @@ -97,7 +97,7 @@ private: mlir::MLIRContext &context; /// A "module" matches a source file: it contains a list of functions. - std::unique_ptr theModule; + mlir::OwningModuleRef theModule; /// The builder is a helper class to create IR inside a function. It is /// re-initialized every time we enter a function and kept around as a @@ -469,8 +469,8 @@ private: namespace toy { // The public API for codegen. -std::unique_ptr mlirGen(mlir::MLIRContext &context, - ModuleAST &moduleAST) { +mlir::OwningModuleRef mlirGen(mlir::MLIRContext &context, + ModuleAST &moduleAST) { return MLIRGenImpl(context).mlirGen(moduleAST); } diff --git a/mlir/examples/toy/Ch4/mlir/ShapeInferencePass.cpp b/mlir/examples/toy/Ch4/mlir/ShapeInferencePass.cpp index f237fd9fb53b..3650e5f022c0 100644 --- a/mlir/examples/toy/Ch4/mlir/ShapeInferencePass.cpp +++ b/mlir/examples/toy/Ch4/mlir/ShapeInferencePass.cpp @@ -119,7 +119,7 @@ public: }; void runOnModule() override { - auto &module = getModule(); + auto module = getModule(); auto main = module.getNamedFunction("main"); if (!main) { emitError(mlir::UnknownLoc::get(module.getContext()), diff --git a/mlir/examples/toy/Ch4/toyc.cpp b/mlir/examples/toy/Ch4/toyc.cpp index 77c039d548ef..8dbc6f8c954f 100644 --- a/mlir/examples/toy/Ch4/toyc.cpp +++ b/mlir/examples/toy/Ch4/toyc.cpp @@ -78,7 +78,7 @@ std::unique_ptr parseInputFile(llvm::StringRef filename) { return parser.ParseModule(); } -mlir::LogicalResult optimize(mlir::Module &module) { +mlir::LogicalResult optimize(mlir::Module module) { mlir::PassManager pm; pm.addPass(mlir::createCanonicalizerPass()); pm.addPass(createShapeInferencePass()); @@ -86,7 +86,7 @@ mlir::LogicalResult optimize(mlir::Module &module) { // Apply any generic pass manager command line options. applyPassManagerCLOptions(pm); - return pm.run(&module); + return pm.run(module); } int dumpMLIR() { @@ -97,7 +97,7 @@ int dumpMLIR() { mlir::registerPassManagerCLOptions(); mlir::MLIRContext context; - std::unique_ptr module; + mlir::OwningModuleRef module; if (inputType == InputType::MLIR || llvm::StringRef(inputFilename).endswith(".mlir")) { llvm::ErrorOr> fileOrErr = @@ -108,7 +108,7 @@ int dumpMLIR() { } llvm::SourceMgr sourceMgr; sourceMgr.AddNewSourceBuffer(std::move(*fileOrErr), llvm::SMLoc()); - module.reset(mlir::parseSourceFile(sourceMgr, &context)); + module = mlir::parseSourceFile(sourceMgr, &context); if (!module) { llvm::errs() << "Error can't load file " << inputFilename << "\n"; return 3; diff --git a/mlir/examples/toy/Ch5/include/toy/MLIRGen.h b/mlir/examples/toy/Ch5/include/toy/MLIRGen.h index 21637bc19af4..287f432c8479 100644 --- a/mlir/examples/toy/Ch5/include/toy/MLIRGen.h +++ b/mlir/examples/toy/Ch5/include/toy/MLIRGen.h @@ -27,7 +27,7 @@ namespace mlir { class MLIRContext; -class Module; +class OwningModuleRef; } // namespace mlir namespace toy { @@ -35,8 +35,7 @@ class ModuleAST; /// Emit IR for the given Toy moduleAST, returns a newly created MLIR module /// or nullptr on failure. -std::unique_ptr mlirGen(mlir::MLIRContext &context, - ModuleAST &moduleAST); +mlir::OwningModuleRef mlirGen(mlir::MLIRContext &context, ModuleAST &moduleAST); } // namespace toy #endif // MLIR_TUTORIAL_TOY_MLIRGEN_H_ diff --git a/mlir/examples/toy/Ch5/mlir/LateLowering.cpp b/mlir/examples/toy/Ch5/mlir/LateLowering.cpp index 60a8b5a3b9a8..1c2208062668 100644 --- a/mlir/examples/toy/Ch5/mlir/LateLowering.cpp +++ b/mlir/examples/toy/Ch5/mlir/LateLowering.cpp @@ -136,7 +136,7 @@ public: PatternMatchResult matchAndRewrite(Operation *op, ArrayRef operands, PatternRewriter &rewriter) const override { // Get or create the declaration of the printf function in the module. - Function printfFunc = getPrintf(*op->getFunction().getModule()); + Function printfFunc = getPrintf(op->getFunction().getModule()); auto print = cast(op); auto loc = print.getLoc(); @@ -205,13 +205,13 @@ private: /// Return the prototype declaration for printf in the module, create it if /// necessary. - Function getPrintf(Module &module) const { + Function getPrintf(Module module) const { auto printfFunc = module.getNamedFunction("printf"); if (printfFunc) return printfFunc; // Create a function declaration for printf, signature is `i32 (i8*, ...)` - Builder builder(&module); + Builder builder(module); auto *dialect = module.getContext()->getRegisteredDialect(); diff --git a/mlir/examples/toy/Ch5/mlir/MLIRGen.cpp b/mlir/examples/toy/Ch5/mlir/MLIRGen.cpp index 9ebfeb438ca2..e0e88836dbfc 100644 --- a/mlir/examples/toy/Ch5/mlir/MLIRGen.cpp +++ b/mlir/examples/toy/Ch5/mlir/MLIRGen.cpp @@ -67,10 +67,10 @@ public: /// Public API: convert the AST for a Toy module (source file) to an MLIR /// Module. - std::unique_ptr mlirGen(ModuleAST &moduleAST) { + mlir::OwningModuleRef mlirGen(ModuleAST &moduleAST) { // We create an empty MLIR module and codegen functions one at a time and // add them to the module. - theModule = make_unique(&context); + theModule = mlir::Module::create(&context); for (FunctionAST &F : moduleAST) { auto func = mlirGen(F); @@ -97,7 +97,7 @@ private: mlir::MLIRContext &context; /// A "module" matches a source file: it contains a list of functions. - std::unique_ptr theModule; + mlir::OwningModuleRef theModule; /// The builder is a helper class to create IR inside a function. It is /// re-initialized every time we enter a function and kept around as a @@ -469,8 +469,8 @@ private: namespace toy { // The public API for codegen. -std::unique_ptr mlirGen(mlir::MLIRContext &context, - ModuleAST &moduleAST) { +mlir::OwningModuleRef mlirGen(mlir::MLIRContext &context, + ModuleAST &moduleAST) { return MLIRGenImpl(context).mlirGen(moduleAST); } diff --git a/mlir/examples/toy/Ch5/mlir/ShapeInferencePass.cpp b/mlir/examples/toy/Ch5/mlir/ShapeInferencePass.cpp index 0abcb4bb8506..971cf0ac7ab9 100644 --- a/mlir/examples/toy/Ch5/mlir/ShapeInferencePass.cpp +++ b/mlir/examples/toy/Ch5/mlir/ShapeInferencePass.cpp @@ -119,8 +119,8 @@ public: }; void runOnModule() override { - auto &module = getModule(); - mlir::ModuleManager moduleManager(&module); + auto module = getModule(); + mlir::ModuleManager moduleManager(module); auto main = moduleManager.getNamedFunction("main"); if (!main) { emitError(mlir::UnknownLoc::get(module.getContext()), diff --git a/mlir/examples/toy/Ch5/toyc.cpp b/mlir/examples/toy/Ch5/toyc.cpp index 9637d7257a26..b5bdde82c742 100644 --- a/mlir/examples/toy/Ch5/toyc.cpp +++ b/mlir/examples/toy/Ch5/toyc.cpp @@ -101,7 +101,7 @@ std::unique_ptr parseInputFile(llvm::StringRef filename) { return parser.ParseModule(); } -mlir::LogicalResult optimize(mlir::Module &module) { +mlir::LogicalResult optimize(mlir::Module module) { mlir::PassManager pm; pm.addPass(mlir::createCanonicalizerPass()); pm.addPass(createShapeInferencePass()); @@ -111,10 +111,10 @@ mlir::LogicalResult optimize(mlir::Module &module) { // Apply any generic pass manager command line options. applyPassManagerCLOptions(pm); - return pm.run(&module); + return pm.run(module); } -mlir::LogicalResult lowerDialect(mlir::Module &module, bool OnlyLinalg) { +mlir::LogicalResult lowerDialect(mlir::Module module, bool OnlyLinalg) { mlir::PassManager pm; pm.addPass(createEarlyLoweringPass()); pm.addPass(mlir::createCanonicalizerPass()); @@ -127,14 +127,14 @@ mlir::LogicalResult lowerDialect(mlir::Module &module, bool OnlyLinalg) { // Apply any generic pass manager command line options. applyPassManagerCLOptions(pm); - return pm.run(&module); + return pm.run(module); } -std::unique_ptr loadFileAndProcessModule( +mlir::OwningModuleRef loadFileAndProcessModule( mlir::MLIRContext &context, bool EnableLinalgLowering = false, bool EnableLLVMLowering = false, bool EnableOpt = false) { - std::unique_ptr module; + mlir::OwningModuleRef module; if (inputType == InputType::MLIR || llvm::StringRef(inputFilename).endswith(".mlir")) { llvm::ErrorOr> fileOrErr = @@ -145,7 +145,7 @@ std::unique_ptr loadFileAndProcessModule( } llvm::SourceMgr sourceMgr; sourceMgr.AddNewSourceBuffer(std::move(*fileOrErr), llvm::SMLoc()); - module.reset(mlir::parseSourceFile(sourceMgr, &context)); + module = mlir::parseSourceFile(sourceMgr, &context); if (!module) { llvm::errs() << "Error can't load file " << inputFilename << "\n"; return nullptr; @@ -252,7 +252,7 @@ int runJit() { // the module. auto optPipeline = mlir::makeOptimizingTransformer( /* optLevel=*/EnableOpt ? 3 : 0, /* sizeLevel=*/0); - auto maybeEngine = mlir::ExecutionEngine::create(module.get(), optPipeline); + auto maybeEngine = mlir::ExecutionEngine::create(*module, optPipeline); assert(maybeEngine && "failed to construct an execution engine"); auto &engine = maybeEngine.get(); diff --git a/mlir/g3doc/WritingAPass.md b/mlir/g3doc/WritingAPass.md index 5b53a8a13661..691fe7ca09ae 100644 --- a/mlir/g3doc/WritingAPass.md +++ b/mlir/g3doc/WritingAPass.md @@ -54,10 +54,10 @@ namespace { struct MyFunctionPass : public FunctionPass { void runOnFunction() override { // Get the current function being operated on. - Function *f = getFunction(); + Function f = getFunction(); // Operate on the operations within the function. - f->walk([](Operation *inst) { + f.walk([](Operation *inst) { .... }); } @@ -94,10 +94,10 @@ namespace { struct MyModulePass : public ModulePass { void runOnModule() override { // Get the current module being operated on. - Module *m = getModule(); + Module m = getModule(); // Operate on the functions within the module. - for (auto &func : *m) { + for (auto func : m) { .... } } @@ -149,7 +149,7 @@ struct MyFunctionAnalysis { /// An interesting module analysis. struct MyModuleAnalysis { // Compute this analysis with the provided module. - MyModuleAnalysis(Module *module); + MyModuleAnalysis(Module module); }; void MyFunctionPass::runOnFunction() { @@ -181,7 +181,7 @@ void MyModulePass::runOnModule() { // Query MyFunctionAnalysis for a child function of the current module. It // will be computed if it doesn't exist. - auto *fn = &*getModule().begin(); + auto fn = *getModule().begin(); MyFunctionAnalysis &myAnalysis = getFunctionAnalysis(fn); } ``` @@ -255,7 +255,7 @@ pm.addPass(new MyFunctionPass3()); pm.addPass(new MyModulePass2()); // Run the pass manager on a module. -Module *m = ...; +Module m = ...; if (failed(pm.run(m))) ... // One of the passes signaled a failure. ``` @@ -384,7 +384,7 @@ unsigned domInfoCount; pm.addInstrumentation(new DominanceCounterInstrumentation(domInfoCount)); // Run the pass manager on a module. -Module *m = ...; +Module m = ...; if (failed(pm.run(m))) ... diff --git a/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h b/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h index f7f4caa35841..a949a877fc5c 100644 --- a/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h +++ b/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h @@ -48,7 +48,7 @@ namespace LLVM { /// support different values coming from the same predecessor. If a block has /// another block as a successor more than once with different values, insert /// a new dummy block for LLVM PHI nodes to tell the sources apart. -void ensureDistinctSuccessors(Module *m); +void ensureDistinctSuccessors(Module m); } // namespace LLVM } // namespace mlir diff --git a/mlir/include/mlir/ExecutionEngine/ExecutionEngine.h b/mlir/include/mlir/ExecutionEngine/ExecutionEngine.h index 2ee29454fb43..cbbfffd13d61 100644 --- a/mlir/include/mlir/ExecutionEngine/ExecutionEngine.h +++ b/mlir/include/mlir/ExecutionEngine/ExecutionEngine.h @@ -62,7 +62,7 @@ public: /// If `sharedLibPaths` are provided, the underlying JIT-compilation will open /// and link the shared libraries for symbol resolution. static llvm::Expected> - create(Module *m, std::function transformer = {}, + create(Module m, std::function transformer = {}, ArrayRef sharedLibPaths = {}); /// Looks up a packed-argument function with the given name and returns a diff --git a/mlir/include/mlir/IR/Builders.h b/mlir/include/mlir/IR/Builders.h index e5c8c035c462..b954d46b152d 100644 --- a/mlir/include/mlir/IR/Builders.h +++ b/mlir/include/mlir/IR/Builders.h @@ -57,12 +57,12 @@ class UnitAttr; class Builder { public: explicit Builder(MLIRContext *context) : context(context) {} - explicit Builder(Module *module); + explicit Builder(Module module); MLIRContext *getContext() const { return context; } Identifier getIdentifier(StringRef str); - Module *createModule(); + Module createModule(); // Locations. Location getUnknownLoc(); diff --git a/mlir/include/mlir/IR/Function.h b/mlir/include/mlir/IR/Function.h index 8c66dea7e416..0fee88c92bb2 100644 --- a/mlir/include/mlir/IR/Function.h +++ b/mlir/include/mlir/IR/Function.h @@ -34,10 +34,12 @@ class MLIRContext; class Module; namespace detail { +class ModuleStorage; + /// This class represents all of the internal state of a Function. This allows /// for the Function class to be value typed. class FunctionStorage - : public llvm::ilist_node_with_parent { + : public llvm::ilist_node_with_parent { FunctionStorage(Location location, StringRef name, FunctionType type, ArrayRef attrs = {}); FunctionStorage(Location location, StringRef name, FunctionType type, @@ -47,7 +49,7 @@ class FunctionStorage Identifier name; /// The module this function is embedded into. - Module *module = nullptr; + ModuleStorage *module = nullptr; /// The source location the function was defined or derived from. Location location; @@ -116,7 +118,7 @@ public: } MLIRContext *getContext(); - Module *getModule() { return impl->module; } + Module getModule(); /// Add an entry block to an empty function, and set up the block arguments /// to match the signature of the function. @@ -541,7 +543,7 @@ struct ilist_traits<::mlir::detail::FunctionStorage> function_iterator first, function_iterator last); private: - mlir::Module *getContainingModule(); + mlir::detail::ModuleStorage *getContainingModule(); }; // Functions hash just like pointers. diff --git a/mlir/include/mlir/IR/Module.h b/mlir/include/mlir/IR/Module.h index d8a47891ace5..a77653b925e2 100644 --- a/mlir/include/mlir/IR/Module.h +++ b/mlir/include/mlir/IR/Module.h @@ -27,12 +27,46 @@ #include "llvm/ADT/ilist.h" namespace mlir { +class Module; + +namespace detail { +class ModuleStorage { + explicit ModuleStorage(MLIRContext *context) : context(context) {} + + /// getSublistAccess() - Returns pointer to member of function list + static llvm::iplist ModuleStorage::* + getSublistAccess(FunctionStorage *) { + return &ModuleStorage::functions; + } + + /// The context attached to this module. + MLIRContext *context; + + /// This is the actual list of functions the module contains. + llvm::iplist functions; + + friend Module; + friend struct llvm::ilist_traits; + friend FunctionStorage; + friend Function; +}; +} // end namespace detail class Module { public: - explicit Module(MLIRContext *context) : context(context) {} + Module(detail::ModuleStorage *impl = nullptr) : impl(impl) {} - MLIRContext *getContext() { return context; } + /// Construct a new module object with the given context. + static Module create(MLIRContext *context) { + return new detail::ModuleStorage(context); + } + + MLIRContext *getContext() { return impl->context; } + + /// Allow converting a Module to bool for null checks. + operator bool() const { return impl; } + bool operator==(Module other) const { return impl == other.impl; } + bool operator!=(Module other) const { return !(*this == other); } /// An iterator class used to iterate over the held functions. class iterator : public llvm::mapped_iterator< @@ -56,14 +90,14 @@ public: llvm::iterator_range getFunctions() { return {begin(), end()}; } // Iteration over the functions in the module. - iterator begin() { return functions.begin(); } - iterator end() { return functions.end(); } - Function front() { return &functions.front(); } - Function back() { return &functions.back(); } + iterator begin() { return impl->functions.begin(); } + iterator end() { return impl->functions.end(); } + Function front() { return &impl->functions.front(); } + Function back() { return &impl->functions.back(); } - void push_back(Function fn) { functions.push_back(fn.impl); } + void push_back(Function fn) { impl->functions.push_back(fn.impl); } void insert(iterator insertPt, Function fn) { - functions.insert(insertPt.getCurrent(), fn.impl); + impl->functions.insert(insertPt.getCurrent(), fn.impl); } // Interfaces for working with the symbol table. @@ -79,6 +113,7 @@ public: /// name exists. Function names never include the @ on them. Note: This /// performs a linear scan of held symbols. Function getNamedFunction(Identifier name) { + auto &functions = impl->functions; auto it = llvm::find_if(functions, [name](detail::FunctionStorage &fn) { return Function(&fn).getName() == name; }); @@ -93,22 +128,27 @@ public: void print(raw_ostream &os); void dump(); + /// Erase the current module. + void erase() { + assert(impl && "expected valid module"); + delete impl; + } + + /// Methods for supporting PointerLikeTypeTraits. + const void *getAsOpaquePointer() const { + return static_cast(impl); + } + static Module getFromOpaquePointer(const void *pointer) { + return reinterpret_cast( + const_cast(pointer)); + } + private: - friend struct llvm::ilist_traits; friend detail::FunctionStorage; friend Function; - /// getSublistAccess() - Returns pointer to member of function list - static llvm::iplist Module::* - getSublistAccess(detail::FunctionStorage *) { - return &Module::functions; - } - - /// The context attached to this module. - MLIRContext *context; - - /// This is the actual list of functions the module contains. - llvm::iplist functions; + /// The internal impl storage object. + detail::ModuleStorage *impl = nullptr; }; /// A class used to manage the symbols held by a module. This class handles @@ -116,7 +156,7 @@ private: /// efficent named lookup to held symbols. class ModuleManager { public: - ModuleManager(Module *module) : module(module), symbolTable(module) {} + ModuleManager(Module module) : module(module), symbolTable(module) {} /// Look up a symbol with the specified name, returning null if no such /// name exists. Names must never include the @ on them. @@ -127,11 +167,11 @@ public: /// Insert a new symbol into the module, auto-renaming it as necessary. void insert(Function function) { symbolTable.insert(function); - module->push_back(function); + module.push_back(function); } void insert(Module::iterator insertPt, Function function) { symbolTable.insert(function); - module->insert(insertPt, function); + module.insert(insertPt, function); } /// Remove the given symbol from the module symbol table and then erase it. @@ -141,16 +181,53 @@ public: } /// Return the internally held module. - Module *getModule() const { return module; } + Module getModule() const { return module; } /// Return the context of the internal module. - MLIRContext *getContext() const { return module->getContext(); } + MLIRContext *getContext() const { return getModule().getContext(); } private: - Module *module; + Module module; SymbolTable symbolTable; }; +/// This class acts as an owning reference to a Module, and will automatically +/// destory the held Module if valid. +class OwningModuleRef { +public: + OwningModuleRef(std::nullptr_t = nullptr) {} + OwningModuleRef(Module module) : module(module) {} + OwningModuleRef(OwningModuleRef &&other) : module(other.release()) {} + ~OwningModuleRef() { + if (module) + module.erase(); + } + + // Assign from another module reference. + OwningModuleRef &operator=(OwningModuleRef &&other) { + if (module) + module.erase(); + module = other.release(); + return *this; + } + + /// Allow accessing the internal module. + Module get() const { return module; } + Module operator*() const { return module; } + Module *operator->() { return &module; } + explicit operator bool() const { return module; } + + /// Release the referenced module. + Module release() { + Module released; + std::swap(released, module); + return released; + } + +private: + Module module; +}; + //===--------------------------------------------------------------------===// // Module Operation. //===--------------------------------------------------------------------===// @@ -196,4 +273,20 @@ public: } // end namespace mlir +namespace llvm { + +/// Allow stealing the low bits of ModuleStorage. +template <> struct PointerLikeTypeTraits { +public: + static inline void *getAsVoidPointer(mlir::Module I) { + return const_cast(I.getAsOpaquePointer()); + } + static inline mlir::Module getFromVoidPointer(void *P) { + return mlir::Module::getFromOpaquePointer(P); + } + enum { NumLowBitsAvailable = 3 }; +}; + +} // end namespace llvm + #endif // MLIR_IR_MODULE_H diff --git a/mlir/include/mlir/IR/SymbolTable.h b/mlir/include/mlir/IR/SymbolTable.h index a351f66eb2e7..10d3a5e8acaa 100644 --- a/mlir/include/mlir/IR/SymbolTable.h +++ b/mlir/include/mlir/IR/SymbolTable.h @@ -31,7 +31,7 @@ class MLIRContext; class SymbolTable { public: /// Build a symbol table with the symbols within the given module. - SymbolTable(Module *module); + SymbolTable(Module module); /// Look up a symbol with the specified name, returning null if no such /// name exists. Names never include the @ on them. diff --git a/mlir/include/mlir/Parser.h b/mlir/include/mlir/Parser.h index a2673caac3f5..f347ff1360fe 100644 --- a/mlir/include/mlir/Parser.h +++ b/mlir/include/mlir/Parser.h @@ -37,24 +37,24 @@ class Type; /// This parses the file specified by the indicated SourceMgr and returns an /// MLIR module if it was valid. If not, the error message is emitted through /// the error handler registered in the context, and a null pointer is returned. -Module *parseSourceFile(const llvm::SourceMgr &sourceMgr, MLIRContext *context); +Module parseSourceFile(const llvm::SourceMgr &sourceMgr, MLIRContext *context); /// This parses the file specified by the indicated filename and returns an /// MLIR module if it was valid. If not, the error message is emitted through /// the error handler registered in the context, and a null pointer is returned. -Module *parseSourceFile(llvm::StringRef filename, MLIRContext *context); +Module parseSourceFile(llvm::StringRef filename, MLIRContext *context); /// This parses the file specified by the indicated filename using the provided /// SourceMgr and returns an MLIR module if it was valid. If not, the error /// message is emitted through the error handler registered in the context, and /// a null pointer is returned. -Module *parseSourceFile(llvm::StringRef filename, llvm::SourceMgr &sourceMgr, - MLIRContext *context); +Module parseSourceFile(llvm::StringRef filename, llvm::SourceMgr &sourceMgr, + MLIRContext *context); /// This parses the module string to a MLIR module if it was valid. If not, the /// error message is emitted through the error handler registered in the /// context, and a null pointer is returned. -Module *parseSourceString(llvm::StringRef moduleStr, MLIRContext *context); +Module parseSourceString(llvm::StringRef moduleStr, MLIRContext *context); /// This parses a single MLIR type to an MLIR context if it was valid. If not, /// an error message is emitted through a new SourceMgrDiagnosticHandler diff --git a/mlir/include/mlir/Pass/AnalysisManager.h b/mlir/include/mlir/Pass/AnalysisManager.h index c44f88f6763a..18ba7a826ccc 100644 --- a/mlir/include/mlir/Pass/AnalysisManager.h +++ b/mlir/include/mlir/Pass/AnalysisManager.h @@ -223,7 +223,7 @@ private: /// An analysis manager for a specific module instance. class ModuleAnalysisManager { public: - ModuleAnalysisManager(Module *module, PassInstrumentor *passInstrumentor) + ModuleAnalysisManager(Module module, PassInstrumentor *passInstrumentor) : moduleAnalyses(module), passInstrumentor(passInstrumentor) {} ModuleAnalysisManager(const ModuleAnalysisManager &) = delete; ModuleAnalysisManager &operator=(const ModuleAnalysisManager &) = delete; @@ -273,7 +273,7 @@ private: functionAnalyses; /// The analyses for the owning module. - detail::AnalysisMap moduleAnalyses; + detail::AnalysisMap moduleAnalyses; /// An optional instrumentation object. PassInstrumentor *passInstrumentor; diff --git a/mlir/include/mlir/Pass/Pass.h b/mlir/include/mlir/Pass/Pass.h index 41d20ccdd63d..6ee78c52a25b 100644 --- a/mlir/include/mlir/Pass/Pass.h +++ b/mlir/include/mlir/Pass/Pass.h @@ -138,8 +138,7 @@ private: /// Pass to transform a module. Derived passes should not inherit from this /// class directly, and instead should use the CRTP ModulePass class. class ModulePassBase : public Pass { - using PassStateT = - detail::PassExecutionState; + using PassStateT = detail::PassExecutionState; public: static bool classof(const Pass *pass) { @@ -153,7 +152,7 @@ protected: virtual void runOnModule() = 0; /// Return the current module being transformed. - Module &getModule() { return *getPassState().irAndPassFailed.getPointer(); } + Module getModule() { return getPassState().irAndPassFailed.getPointer(); } /// Return the MLIR context for the current module being transformed. MLIRContext &getContext() { return *getModule().getContext(); } @@ -172,7 +171,7 @@ protected: private: /// Forwarding function to execute this pass. LLVM_NODISCARD - LogicalResult run(Module *module, ModuleAnalysisManager &mam); + LogicalResult run(Module module, ModuleAnalysisManager &mam); /// The current execution state for the pass. llvm::Optional passState; diff --git a/mlir/include/mlir/Pass/PassManager.h b/mlir/include/mlir/Pass/PassManager.h index 32d5fd66a633..3277d36c58dd 100644 --- a/mlir/include/mlir/Pass/PassManager.h +++ b/mlir/include/mlir/Pass/PassManager.h @@ -60,7 +60,7 @@ public: /// Run the passes within this manager on the provided module. LLVM_NODISCARD - LogicalResult run(Module *module); + LogicalResult run(Module module); //===--------------------------------------------------------------------===// // Pipeline Building diff --git a/mlir/include/mlir/Target/LLVMIR.h b/mlir/include/mlir/Target/LLVMIR.h index a4f45ff23983..e227c3bd783b 100644 --- a/mlir/include/mlir/Target/LLVMIR.h +++ b/mlir/include/mlir/Target/LLVMIR.h @@ -38,7 +38,7 @@ class Module; /// from the registered LLVM IR dialect. In case of error, report it /// to the error handler registered with the MLIR context, if any (obtained from /// the MLIR module), and return `nullptr`. -std::unique_ptr translateModuleToLLVMIR(Module &m); +std::unique_ptr translateModuleToLLVMIR(Module m); } // namespace mlir diff --git a/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h b/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h index 493b0caaa289..d76776e70c61 100644 --- a/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h +++ b/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h @@ -24,7 +24,7 @@ #define MLIR_TARGET_LLVMIR_MODULETRANSLATION_H #include "mlir/IR/Block.h" -#include "mlir/IR/Function.h" +#include "mlir/IR/Module.h" #include "mlir/IR/Value.h" #include "llvm/IR/BasicBlock.h" @@ -48,7 +48,7 @@ namespace LLVM { class ModuleTranslation { public: template - static std::unique_ptr translateModule(Module &m) { + static std::unique_ptr translateModule(Module m) { auto llvmModule = prepareLLVMModule(m); T translator(m); @@ -63,17 +63,17 @@ protected: // Translate the given MLIR module expressed in MLIR LLVM IR dialect into an // LLVM IR module. The MLIR LLVM IR dialect holds a pointer to an // LLVMContext, the LLVM IR module will be created in that context. - explicit ModuleTranslation(Module &module) : mlirModule(module) {} + explicit ModuleTranslation(Module module) : mlirModule(module) {} virtual ~ModuleTranslation() {} virtual bool convertOperation(Operation &op, llvm::IRBuilder<> &builder); - static std::unique_ptr prepareLLVMModule(Module &m); + static std::unique_ptr prepareLLVMModule(Module m); private: bool convertFunctions(); - bool convertOneFunction(Function &func); - void connectPHINodes(Function &func); + bool convertOneFunction(Function func); + void connectPHINodes(Function func); bool convertBlock(Block &bb, bool ignoreArguments); template @@ -83,7 +83,7 @@ private: Location loc); // Original and translated module. - Module &mlirModule; + Module mlirModule; std::unique_ptr llvmModule; protected: diff --git a/mlir/include/mlir/Target/NVVMIR.h b/mlir/include/mlir/Target/NVVMIR.h index 27f964e131cb..ba46947e7b94 100644 --- a/mlir/include/mlir/Target/NVVMIR.h +++ b/mlir/include/mlir/Target/NVVMIR.h @@ -30,7 +30,6 @@ class Module; } // namespace llvm namespace mlir { - class Module; /// Convert the given MLIR module into NVVM IR. This conversion requires the @@ -38,7 +37,7 @@ class Module; /// from the registered LLVM IR dialect. In case of error, report it /// to the error handler registered with the MLIR context, if any (obtained from /// the MLIR module), and return `nullptr`. -std::unique_ptr translateModuleToNVVMIR(Module &m); +std::unique_ptr translateModuleToNVVMIR(Module m); } // namespace mlir diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h index c8ede78ec204..0101673b500a 100644 --- a/mlir/include/mlir/Transforms/DialectConversion.h +++ b/mlir/include/mlir/Transforms/DialectConversion.h @@ -339,7 +339,7 @@ private: /// conversion object. This function returns failure if a type conversion /// failed. LLVM_NODISCARD LogicalResult applyConversionPatterns( - Module &module, ConversionTarget &target, TypeConverter &converter, + Module module, ConversionTarget &target, TypeConverter &converter, OwningRewritePatternList &&patterns); /// Convert the given functions with the provided conversion patterns. This diff --git a/mlir/include/mlir/Translation.h b/mlir/include/mlir/Translation.h index 78518e9f39b1..63360830c547 100644 --- a/mlir/include/mlir/Translation.h +++ b/mlir/include/mlir/Translation.h @@ -27,17 +27,17 @@ namespace mlir { class MLIRContext; class Module; +class OwningModuleRef; /// Interface of the function that translates a file to MLIR. The /// implementation should create a new MLIR Module in the given context and /// return a pointer to it, or a nullptr in case of any error. using TranslateToMLIRFunction = - std::function(llvm::StringRef, MLIRContext *)>; + std::function; /// Interface of the function that translates MLIR to a different format and /// outputs the result to a file. The implementation should return "true" on /// error and "false" otherwise. It is allowed to modify the module. -using TranslateFromMLIRFunction = - std::function; +using TranslateFromMLIRFunction = std::function; /// Use Translate[To|From]MLIRRegistration as a global initialiser that /// registers a function and associates it with name. This requires that a diff --git a/mlir/lib/Conversion/GPUToCUDA/ConvertKernelFuncToCubin.cpp b/mlir/lib/Conversion/GPUToCUDA/ConvertKernelFuncToCubin.cpp index 022d8c70cc61..246bd549f43f 100644 --- a/mlir/lib/Conversion/GPUToCUDA/ConvertKernelFuncToCubin.cpp +++ b/mlir/lib/Conversion/GPUToCUDA/ConvertKernelFuncToCubin.cpp @@ -139,7 +139,7 @@ LogicalResult GpuKernelToCubinPass::translateGpuKernelToCubinAnnotation(Function &function) { Builder builder(function.getContext()); - std::unique_ptr module(builder.createModule()); + OwningModuleRef module = builder.createModule(); // TODO(herhut): Also handle called functions. module->push_back(function.clone()); @@ -147,8 +147,9 @@ GpuKernelToCubinPass::translateGpuKernelToCubinAnnotation(Function &function) { auto llvmModule = translateModuleToNVVMIR(*module); auto cubin = convertModuleToCubin(*llvmModule, function); - if (!cubin) + if (!cubin) { return function.emitError("Translation to CUDA binary failed."); + } function.setAttr(kCubinAnnotation, builder.getStringAttr({cubin->data(), cubin->size()})); diff --git a/mlir/lib/Conversion/GPUToCUDA/ConvertLaunchFuncToCudaCalls.cpp b/mlir/lib/Conversion/GPUToCUDA/ConvertLaunchFuncToCudaCalls.cpp index f9d5899456a1..0759324c77c4 100644 --- a/mlir/lib/Conversion/GPUToCUDA/ConvertLaunchFuncToCudaCalls.cpp +++ b/mlir/lib/Conversion/GPUToCUDA/ConvertLaunchFuncToCudaCalls.cpp @@ -152,8 +152,8 @@ private: // The types in comments give the actual types expected/returned but the API // uses void pointers. This is fine as they have the same linkage in C. void GpuLaunchFuncToCudaCallsPass::declareCudaFunctions(Location loc) { - Module &module = getModule(); - Builder builder(&module); + Module module = getModule(); + Builder builder(module); if (!module.getNamedFunction(cuModuleLoadName)) { module.push_back( Function::create(loc, cuModuleLoadName, @@ -343,7 +343,7 @@ void GpuLaunchFuncToCudaCallsPass::translateGpuLaunchCalls( ArrayRef{cuModule, data.getResult(0)}); // Get the function from the module. The name corresponds to the name of // the kernel function. - auto cuModuleRef = + auto cuOwningModuleRef = builder.create(loc, getPointerType(), cuModule); auto kernelName = generateKernelNameConstant(kernelFunction, loc, builder); auto cuFunction = allocatePointer(builder, loc); @@ -352,7 +352,7 @@ void GpuLaunchFuncToCudaCallsPass::translateGpuLaunchCalls( builder.create( loc, ArrayRef{getCUResultType()}, builder.getFunctionAttr(cuModuleGetFunction), - ArrayRef{cuFunction, cuModuleRef, kernelName}); + ArrayRef{cuFunction, cuOwningModuleRef, kernelName}); // Grab the global stream needed for execution. Function cuGetStreamHelper = getModule().getNamedFunction(cuGetStreamHelperName); diff --git a/mlir/lib/Conversion/GPUToCUDA/GenerateCubinAccessors.cpp b/mlir/lib/Conversion/GPUToCUDA/GenerateCubinAccessors.cpp index 97790a5afce1..550491b3cd0e 100644 --- a/mlir/lib/Conversion/GPUToCUDA/GenerateCubinAccessors.cpp +++ b/mlir/lib/Conversion/GPUToCUDA/GenerateCubinAccessors.cpp @@ -115,7 +115,7 @@ public: void runOnModule() override { llvmDialect = getModule().getContext()->getRegisteredDialect(); - auto &module = getModule(); + auto module = getModule(); Builder builder(&getContext()); auto functions = module.getFunctions(); diff --git a/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp index e849f6fd023b..a0b911e8b79f 100644 --- a/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp +++ b/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp @@ -442,13 +442,13 @@ struct AllocOpLowering : public LLVMLegalizationPattern { // Insert the `malloc` declaration if it is not already present. Function mallocFunc = - op->getFunction().getModule()->getNamedFunction("malloc"); + op->getFunction().getModule().getNamedFunction("malloc"); if (!mallocFunc) { auto mallocType = rewriter.getFunctionType(getIndexType(), getVoidPtrType()); mallocFunc = Function::create(rewriter.getUnknownLoc(), "malloc", mallocType); - op->getFunction().getModule()->push_back(mallocFunc); + op->getFunction().getModule().push_back(mallocFunc); } // Allocate the underlying buffer and store a pointer to it in the MemRef @@ -503,11 +503,11 @@ struct DeallocOpLowering : public LLVMLegalizationPattern { OperandAdaptor transformed(operands); // Insert the `free` declaration if it is not already present. - Function freeFunc = op->getFunction().getModule()->getNamedFunction("free"); + Function freeFunc = op->getFunction().getModule().getNamedFunction("free"); if (!freeFunc) { auto freeType = rewriter.getFunctionType(getVoidPtrType(), {}); freeFunc = Function::create(rewriter.getUnknownLoc(), "free", freeType); - op->getFunction().getModule()->push_back(freeFunc); + op->getFunction().getModule().push_back(freeFunc); } auto type = transformed.memref()->getType().cast(); @@ -936,8 +936,8 @@ static void ensureDistinctSuccessors(Block &bb) { } } -void mlir::LLVM::ensureDistinctSuccessors(Module *m) { - for (auto f : *m) { +void mlir::LLVM::ensureDistinctSuccessors(Module m) { + for (auto f : m) { for (auto &bb : f.getBlocks()) { ::ensureDistinctSuccessors(bb); } @@ -1010,8 +1010,8 @@ namespace { struct LLVMLoweringPass : public ModulePass { // Run the dialect converter on the module. void runOnModule() override { - Module &m = getModule(); - LLVM::ensureDistinctSuccessors(&m); + Module m = getModule(); + LLVM::ensureDistinctSuccessors(m); LLVMTypeConverter converter(&getContext()); OwningRewritePatternList patterns; diff --git a/mlir/lib/ExecutionEngine/ExecutionEngine.cpp b/mlir/lib/ExecutionEngine/ExecutionEngine.cpp index 5a1151203d07..ea9788ca4755 100644 --- a/mlir/lib/ExecutionEngine/ExecutionEngine.cpp +++ b/mlir/lib/ExecutionEngine/ExecutionEngine.cpp @@ -322,7 +322,7 @@ void packFunctionArguments(llvm::Module *module) { ExecutionEngine::~ExecutionEngine() = default; Expected> -ExecutionEngine::create(Module *m, +ExecutionEngine::create(Module m, std::function transformer, ArrayRef sharedLibPaths) { auto engine = llvm::make_unique(); @@ -330,7 +330,7 @@ ExecutionEngine::create(Module *m, if (!expectedJIT) return expectedJIT.takeError(); - auto llvmModule = translateModuleToLLVMIR(*m); + auto llvmModule = translateModuleToLLVMIR(m); if (!llvmModule) return make_string_error("could not convert to LLVM IR"); // FIXME: the triple should be passed to the translation or dialect conversion diff --git a/mlir/lib/GPU/IR/GPUDialect.cpp b/mlir/lib/GPU/IR/GPUDialect.cpp index eadd5d6bfb9d..6cf57b42f457 100644 --- a/mlir/lib/GPU/IR/GPUDialect.cpp +++ b/mlir/lib/GPU/IR/GPUDialect.cpp @@ -426,8 +426,8 @@ LogicalResult LaunchFuncOp::verify() { return emitOpError("attribute 'kernel' must be a function"); } - auto *module = getOperation()->getFunction().getModule(); - Function kernelFunc = module->getNamedFunction(kernel()); + auto module = getOperation()->getFunction().getModule(); + Function kernelFunc = module.getNamedFunction(kernel()); if (!kernelFunc) return emitError() << "kernel function '" << kernelAttr << "' is undefined"; diff --git a/mlir/lib/GPU/Transforms/KernelOutlining.cpp b/mlir/lib/GPU/Transforms/KernelOutlining.cpp index f93febcf5da8..6cb920a10b3c 100644 --- a/mlir/lib/GPU/Transforms/KernelOutlining.cpp +++ b/mlir/lib/GPU/Transforms/KernelOutlining.cpp @@ -97,7 +97,7 @@ namespace { class GpuKernelOutliningPass : public ModulePass { public: void runOnModule() override { - ModuleManager moduleManager(&getModule()); + ModuleManager moduleManager(getModule()); for (auto func : getModule()) { func.walk([&](mlir::gpu::LaunchOp op) { Function outlinedFunc = outlineKernelFunc(op); diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp index 52f54fe1a4a8..12e24df7f140 100644 --- a/mlir/lib/IR/AsmPrinter.cpp +++ b/mlir/lib/IR/AsmPrinter.cpp @@ -91,7 +91,7 @@ public: explicit ModuleState(MLIRContext *context) : context(context) {} // Initializes module state, populating affine map state. - void initialize(Module *module); + void initialize(Module module); Twine getAttributeAlias(Attribute attr) const { auto alias = attrToAlias.find(attr); @@ -301,12 +301,12 @@ void ModuleState::initializeSymbolAliases() { } // Initializes module state, populating affine map and integer set state. -void ModuleState::initialize(Module *module) { +void ModuleState::initialize(Module module) { // Initialize the symbol aliases. initializeSymbolAliases(); // Walk the module and visit each operation. - for (auto fn : *module) { + for (auto fn : module) { visitType(fn.getType()); for (auto attr : fn.getAttrs()) ModuleState::visitAttribute(attr.second); @@ -331,7 +331,7 @@ public: interleave(c.begin(), c.end(), each_fn, [&]() { os << ", "; }); } - void print(Module *module); + void print(Module module); /// Print the given attribute. If 'mayElideType' is true, some attributes are /// printed without the type when the type matches the default used in the @@ -451,13 +451,13 @@ void ModulePrinter::printLocationInternal(LocationAttr loc, bool pretty) { } } -void ModulePrinter::print(Module *module) { +void ModulePrinter::print(Module module) { // Output the aliases at the top level. state.printAttributeAliases(os); state.printTypeAliases(os); // Print the module. - for (auto fn : *module) + for (auto fn : module) print(fn); } @@ -1784,8 +1784,8 @@ void Function::dump() { print(llvm::errs()); } void Module::print(raw_ostream &os) { ModuleState state(getContext()); - state.initialize(this); - ModulePrinter(os, state).print(this); + state.initialize(*this); + ModulePrinter(os, state).print(*this); } void Module::dump() { print(llvm::errs()); } diff --git a/mlir/lib/IR/Builders.cpp b/mlir/lib/IR/Builders.cpp index 89df64260d39..6d0df6ded8ed 100644 --- a/mlir/lib/IR/Builders.cpp +++ b/mlir/lib/IR/Builders.cpp @@ -26,13 +26,13 @@ #include "mlir/Support/Functional.h" using namespace mlir; -Builder::Builder(Module *module) : context(module->getContext()) {} +Builder::Builder(Module module) : context(module.getContext()) {} Identifier Builder::getIdentifier(StringRef str) { return Identifier::get(str, context); } -Module *Builder::createModule() { return new Module(context); } +Module Builder::createModule() { return Module::create(context); } //===----------------------------------------------------------------------===// // Locations. diff --git a/mlir/lib/IR/Function.cpp b/mlir/lib/IR/Function.cpp index 9c665e7b8363..77425c7b7e39 100644 --- a/mlir/lib/IR/Function.cpp +++ b/mlir/lib/IR/Function.cpp @@ -43,12 +43,14 @@ FunctionStorage::FunctionStorage(Location location, StringRef name, type(type), attrs(attrs), argAttrs(argAttrs), body(this) {} MLIRContext *Function::getContext() { return getType().getContext(); } +Module Function::getModule() { return impl->module; } -Module *llvm::ilist_traits::getContainingModule() { - size_t Offset( - size_t(&((Module *)nullptr->*Module::getSublistAccess(nullptr)))); +ModuleStorage *llvm::ilist_traits::getContainingModule() { + size_t Offset(size_t( + &((ModuleStorage *)nullptr->*ModuleStorage::getSublistAccess(nullptr)))); iplist *Anchor(static_cast *>(this)); - return reinterpret_cast(reinterpret_cast(Anchor) - Offset); + return reinterpret_cast(reinterpret_cast(Anchor) - + Offset); } /// This is a trait method invoked when a Function is added to a Module. We @@ -74,7 +76,7 @@ void llvm::ilist_traits::transferNodesFromList( function_iterator last) { // If we are transferring functions within the same module, the Module // pointer doesn't need to be updated. - Module *curParent = getContainingModule(); + ModuleStorage *curParent = getContainingModule(); if (curParent == otherList.getContainingModule()) return; @@ -87,8 +89,8 @@ void llvm::ilist_traits::transferNodesFromList( /// Unlink this function from its Module and delete it. void Function::erase() { - if (auto *module = getModule()) - getModule()->functions.erase(impl); + if (auto module = getModule()) + module.impl->functions.erase(impl); else delete impl; } diff --git a/mlir/lib/IR/SymbolTable.cpp b/mlir/lib/IR/SymbolTable.cpp index dafbd48f513e..02721b512eb2 100644 --- a/mlir/lib/IR/SymbolTable.cpp +++ b/mlir/lib/IR/SymbolTable.cpp @@ -21,8 +21,8 @@ using namespace mlir; /// Build a symbol table with the symbols within the given module. -SymbolTable::SymbolTable(Module *module) : context(module->getContext()) { - for (auto func : *module) { +SymbolTable::SymbolTable(Module module) : context(module.getContext()) { + for (auto func : module) { auto inserted = symbolTable.insert({func.getName(), func}); (void)inserted; assert(inserted.second && diff --git a/mlir/lib/Linalg/Transforms/LowerToLLVMDialect.cpp b/mlir/lib/Linalg/Transforms/LowerToLLVMDialect.cpp index 5fe4f07613ab..7c5157012b8d 100644 --- a/mlir/lib/Linalg/Transforms/LowerToLLVMDialect.cpp +++ b/mlir/lib/Linalg/Transforms/LowerToLLVMDialect.cpp @@ -170,13 +170,13 @@ public: LLVM::LLVMType::getInt8Ty(lowering.getDialect()).getPointerTo(); auto int64Ty = lowering.convertType(rewriter.getIntegerType(64)); // Insert the `malloc` declaration if it is not already present. - auto *module = op->getFunction().getModule(); - Function mallocFunc = module->getNamedFunction("malloc"); + auto module = op->getFunction().getModule(); + Function mallocFunc = module.getNamedFunction("malloc"); if (!mallocFunc) { auto mallocType = rewriter.getFunctionType(int64Ty, voidPtrTy); mallocFunc = Function::create(rewriter.getUnknownLoc(), "malloc", mallocType); - module->push_back(mallocFunc); + module.push_back(mallocFunc); } // Get MLIR types for injecting element pointer. @@ -231,12 +231,12 @@ public: auto voidPtrTy = LLVM::LLVMType::getInt8Ty(lowering.getDialect()).getPointerTo(); // Insert the `free` declaration if it is not already present. - auto *module = op->getFunction().getModule(); - Function freeFunc = module->getNamedFunction("free"); + auto module = op->getFunction().getModule(); + Function freeFunc = module.getNamedFunction("free"); if (!freeFunc) { auto freeType = rewriter.getFunctionType(voidPtrTy, {}); freeFunc = Function::create(rewriter.getUnknownLoc(), "free", freeType); - module->push_back(freeFunc); + module.push_back(freeFunc); } // Get MLIR types for extracting element pointer. @@ -576,7 +576,7 @@ public: static Function getLLVMLibraryCallImplDefinition(Function libFn) { auto implFnName = (libFn.getName().str() + "_impl"); auto module = libFn.getModule(); - if (auto f = module->getNamedFunction(implFnName)) { + if (auto f = module.getNamedFunction(implFnName)) { return f; } SmallVector fnArgTypes; @@ -590,7 +590,7 @@ static Function getLLVMLibraryCallImplDefinition(Function libFn) { // Insert the implementation function definition. auto implFnDefn = Function::create(libFn.getLoc(), implFnName, implFnType); - module->push_back(implFnDefn); + module.push_back(implFnDefn); return implFnDefn; } @@ -603,7 +603,7 @@ static Function getLLVMLibraryCallDeclaration(Operation *op, assert(isa(op)); auto fnName = LinalgOp::getLibraryCallName(); auto module = op->getFunction().getModule(); - if (auto f = module->getNamedFunction(fnName)) { + if (auto f = module.getNamedFunction(fnName)) { return f; } @@ -620,7 +620,7 @@ static Function getLLVMLibraryCallDeclaration(Operation *op, "have void return types"); auto libFnType = FunctionType::get(inputTypes, {}, op->getContext()); auto libFn = Function::create(op->getLoc(), fnName, libFnType); - module->push_back(libFn); + module.push_back(libFn); // Return after creating the function definition. The body will be created // later. return libFn; @@ -802,7 +802,7 @@ static void lowerLinalgForToCFG(Function &f) { } void LowerLinalgToLLVMPass::runOnModule() { - auto &module = getModule(); + auto module = getModule(); for (auto f : module.getFunctions()) { lowerLinalgSubViewOps(f); diff --git a/mlir/lib/Parser/Parser.cpp b/mlir/lib/Parser/Parser.cpp index 677b6eb3500c..5422b8f0f925 100644 --- a/mlir/lib/Parser/Parser.cpp +++ b/mlir/lib/Parser/Parser.cpp @@ -3857,7 +3857,7 @@ class ModuleParser : public Parser { public: explicit ModuleParser(ParserState &state) : Parser(state) {} - ParseResult parseModule(Module *module); + ParseResult parseModule(Module module); private: /// Parse an attribute alias declaration. @@ -3875,7 +3875,7 @@ private: StringRef &name, FunctionType &type, SmallVectorImpl> &argNames, SmallVectorImpl> &argAttrs); - ParseResult parseFunc(Module *module); + ParseResult parseFunc(Module module); }; } // end anonymous namespace @@ -4039,7 +4039,7 @@ ParseResult ModuleParser::parseFunctionSignature( /// function-body ::= `{` block+ `}` /// function-attributes ::= `attributes` attribute-dict /// -ParseResult ModuleParser::parseFunc(Module *module) { +ParseResult ModuleParser::parseFunc(Module module) { consumeToken(); StringRef name; @@ -4061,7 +4061,7 @@ ParseResult ModuleParser::parseFunc(Module *module) { // Okay, the function signature was parsed correctly, create the function now. auto function = Function::create(getEncodedSourceLocation(loc), name, type, attrs); - module->push_back(function); + module.push_back(function); // Parse an optional trailing location. if (parseOptionalTrailingLocation(function)) @@ -4097,7 +4097,7 @@ ParseResult ModuleParser::parseFunc(Module *module) { } /// This is the top-level module parser. -ParseResult ModuleParser::parseModule(Module *module) { +ParseResult ModuleParser::parseModule(Module module) { while (1) { switch (getToken().getKind()) { default: @@ -4139,16 +4139,15 @@ ParseResult ModuleParser::parseModule(Module *module) { /// This parses the file specified by the indicated SourceMgr and returns an /// MLIR module if it was valid. If not, it emits diagnostics and returns /// null. -Module *mlir::parseSourceFile(const llvm::SourceMgr &sourceMgr, - MLIRContext *context) { +Module mlir::parseSourceFile(const llvm::SourceMgr &sourceMgr, + MLIRContext *context) { // This is the result module we are parsing into. - std::unique_ptr module(new Module(context)); + OwningModuleRef module(Module::create(context)); ParserState state(sourceMgr, context); - if (ModuleParser(state).parseModule(module.get())) { + if (ModuleParser(state).parseModule(*module)) return nullptr; - } // Make sure the parse module has no other structural problems detected by // the verifier. @@ -4161,7 +4160,7 @@ Module *mlir::parseSourceFile(const llvm::SourceMgr &sourceMgr, /// This parses the file specified by the indicated filename and returns an /// MLIR module if it was valid. If not, the error message is emitted through /// the error handler registered in the context, and a null pointer is returned. -Module *mlir::parseSourceFile(StringRef filename, MLIRContext *context) { +Module mlir::parseSourceFile(StringRef filename, MLIRContext *context) { llvm::SourceMgr sourceMgr; return parseSourceFile(filename, sourceMgr, context); } @@ -4170,8 +4169,8 @@ Module *mlir::parseSourceFile(StringRef filename, MLIRContext *context) { /// SourceMgr and returns an MLIR module if it was valid. If not, the error /// message is emitted through the error handler registered in the context, and /// a null pointer is returned. -Module *mlir::parseSourceFile(StringRef filename, llvm::SourceMgr &sourceMgr, - MLIRContext *context) { +Module mlir::parseSourceFile(StringRef filename, llvm::SourceMgr &sourceMgr, + MLIRContext *context) { if (sourceMgr.getNumBuffers() != 0) { // TODO(b/136086478): Extend to support multiple buffers. emitError(mlir::UnknownLoc::get(context), @@ -4192,7 +4191,7 @@ Module *mlir::parseSourceFile(StringRef filename, llvm::SourceMgr &sourceMgr, /// This parses the program string to a MLIR module if it was valid. If not, /// it emits diagnostics and returns null. -Module *mlir::parseSourceString(StringRef moduleStr, MLIRContext *context) { +Module mlir::parseSourceString(StringRef moduleStr, MLIRContext *context) { auto memBuffer = MemoryBuffer::getMemBuffer(moduleStr); if (!memBuffer) return nullptr; diff --git a/mlir/lib/Pass/IRPrinting.cpp b/mlir/lib/Pass/IRPrinting.cpp index 057f26552074..aef16ff231a5 100644 --- a/mlir/lib/Pass/IRPrinting.cpp +++ b/mlir/lib/Pass/IRPrinting.cpp @@ -66,7 +66,7 @@ static void printIR(const llvm::Any &ir, bool printModuleScope, // Print the function name and a newline before the Module. out << " (function: " << function.getName() << ")\n"; - function.getModule()->print(out); + function.getModule().print(out); return; } @@ -80,8 +80,8 @@ static void printIR(const llvm::Any &ir, bool printModuleScope, } // Print the given module. - assert(llvm::any_isa(ir) && "unexpected IR unit"); - llvm::any_cast(ir)->print(out); + assert(llvm::any_isa(ir) && "unexpected IR unit"); + llvm::any_cast(ir).print(out); } /// Instrumentation hooks. diff --git a/mlir/lib/Pass/Pass.cpp b/mlir/lib/Pass/Pass.cpp index 27ec74c23c2a..feaf2bba4686 100644 --- a/mlir/lib/Pass/Pass.cpp +++ b/mlir/lib/Pass/Pass.cpp @@ -75,7 +75,7 @@ LogicalResult FunctionPassBase::run(Function fn, FunctionAnalysisManager &fam) { } /// Forwarding function to execute this pass. -LogicalResult ModulePassBase::run(Module *module, ModuleAnalysisManager &mam) { +LogicalResult ModulePassBase::run(Module module, ModuleAnalysisManager &mam) { // Initialize the pass state. passState.emplace(module, mam); @@ -124,7 +124,7 @@ LogicalResult detail::FunctionPassExecutor::run(Function function, } /// Run all of the passes in this manager over the current module. -LogicalResult detail::ModulePassExecutor::run(Module *module, +LogicalResult detail::ModulePassExecutor::run(Module module, ModuleAnalysisManager &mam) { // Run each of the held passes. for (auto &pass : passes) @@ -261,7 +261,7 @@ PassManager::PassManager(bool verifyPasses) PassManager::~PassManager() {} /// Run the passes within this manager on the provided module. -LogicalResult PassManager::run(Module *module) { +LogicalResult PassManager::run(Module module) { ModuleAnalysisManager mam(module, instrumentor.get()); return mpe->run(module, mam); } diff --git a/mlir/lib/Pass/PassDetail.h b/mlir/lib/Pass/PassDetail.h index d2563fb62cd9..b0cd22820a3b 100644 --- a/mlir/lib/Pass/PassDetail.h +++ b/mlir/lib/Pass/PassDetail.h @@ -76,7 +76,7 @@ public: ModulePassExecutor &operator=(const ModulePassExecutor &) = delete; /// Run the executor on the given module. - LogicalResult run(Module *module, ModuleAnalysisManager &mam); + LogicalResult run(Module module, ModuleAnalysisManager &mam); /// Add a pass to the current executor. This takes ownership over the provided /// pass pointer. diff --git a/mlir/lib/SPIRV/Serialization/ConvertFromBinary.cpp b/mlir/lib/SPIRV/Serialization/ConvertFromBinary.cpp index 543b7300af09..688d1f952643 100644 --- a/mlir/lib/SPIRV/Serialization/ConvertFromBinary.cpp +++ b/mlir/lib/SPIRV/Serialization/ConvertFromBinary.cpp @@ -34,10 +34,10 @@ using namespace mlir; // Adds a one-block function named as `spirv_module` to `module` and returns the // block. The created block will be terminated by `std.return`. -Block *createOneBlockFunction(Builder builder, Module *module) { +Block *createOneBlockFunction(Builder builder, Module module) { auto fnType = builder.getFunctionType(/*inputs=*/{}, /*results=*/{}); auto fn = Function::create(builder.getUnknownLoc(), "spirv_module", fnType); - module->push_back(fn); + module.push_back(fn); auto *block = new Block(); fn.push_back(block); @@ -51,8 +51,8 @@ Block *createOneBlockFunction(Builder builder, Module *module) { // Deserializes the SPIR-V binary module stored in the file named as // `inputFilename` and returns a module containing the SPIR-V module. -std::unique_ptr deserializeModule(llvm::StringRef inputFilename, - MLIRContext *context) { +OwningModuleRef deserializeModule(llvm::StringRef inputFilename, + MLIRContext *context) { Builder builder(context); std::string errorMessage; @@ -83,7 +83,7 @@ std::unique_ptr deserializeModule(llvm::StringRef inputFilename, // converted SPIR-V ModuleOp inside a MLIR module. This should be changed to // return the SPIR-V ModuleOp directly after module and function are migrated // to be general ops. - std::unique_ptr module(builder.createModule()); + OwningModuleRef module(builder.createModule()); Block *block = createOneBlockFunction(builder, module.get()); block->push_front(spirvModule->getOperation()); diff --git a/mlir/lib/SPIRV/Serialization/ConvertToBinary.cpp b/mlir/lib/SPIRV/Serialization/ConvertToBinary.cpp index 33572d5adbe5..7cf9bdb1caf7 100644 --- a/mlir/lib/SPIRV/Serialization/ConvertToBinary.cpp +++ b/mlir/lib/SPIRV/Serialization/ConvertToBinary.cpp @@ -31,7 +31,7 @@ using namespace mlir; -LogicalResult serializeModule(Module *module, StringRef outputFilename) { +LogicalResult serializeModule(Module module, StringRef outputFilename) { if (!module) return failure(); @@ -45,7 +45,7 @@ LogicalResult serializeModule(Module *module, StringRef outputFilename) { // wrapping the SPIR-V ModuleOp inside a MLIR module. This should be changed // to take in the SPIR-V ModuleOp directly after module and function are // migrated to be general ops. - for (auto fn : *module) { + for (auto fn : module) { fn.walk([&](spirv::ModuleOp spirvModule) { if (done) { spirvModule.emitError("found more than one 'spv.module' op"); @@ -73,6 +73,6 @@ LogicalResult serializeModule(Module *module, StringRef outputFilename) { static TranslateFromMLIRRegistration registration("serialize-spirv", - [](Module *module, StringRef outputFilename) { + [](Module module, StringRef outputFilename) { return failed(serializeModule(module, outputFilename)); }); diff --git a/mlir/lib/StandardOps/Ops.cpp b/mlir/lib/StandardOps/Ops.cpp index 92a2389eea47..a527e8526871 100644 --- a/mlir/lib/StandardOps/Ops.cpp +++ b/mlir/lib/StandardOps/Ops.cpp @@ -440,7 +440,7 @@ static LogicalResult verify(CallOp op) { auto fnAttr = op.getAttrOfType("callee"); if (!fnAttr) return op.emitOpError("requires a 'callee' function attribute"); - auto fn = op.getOperation()->getFunction().getModule()->getNamedFunction( + auto fn = op.getOperation()->getFunction().getModule().getNamedFunction( fnAttr.getValue()); if (!fn) return op.emitOpError() << "'" << fnAttr.getValue() @@ -1107,7 +1107,7 @@ static LogicalResult verify(ConstantOp &op) { return op.emitOpError("requires 'value' to be a function reference"); // Try to find the referenced function. - auto fn = op.getOperation()->getFunction().getModule()->getNamedFunction( + auto fn = op.getOperation()->getFunction().getModule().getNamedFunction( fnAttr.getValue()); if (!fn) return op.emitOpError("reference to undefined function 'bar'"); diff --git a/mlir/lib/Support/MlirOptMain.cpp b/mlir/lib/Support/MlirOptMain.cpp index 15b148a31ca9..d2957c17480a 100644 --- a/mlir/lib/Support/MlirOptMain.cpp +++ b/mlir/lib/Support/MlirOptMain.cpp @@ -50,7 +50,7 @@ static LogicalResult performActions(raw_ostream &os, bool verifyDiagnostics, bool verifyPasses, SourceMgr &sourceMgr, MLIRContext *context, const std::vector &passList) { - std::unique_ptr module(parseSourceFile(sourceMgr, context)); + OwningModuleRef module(parseSourceFile(sourceMgr, context)); if (!module) return failure(); @@ -63,7 +63,7 @@ performActions(raw_ostream &os, bool verifyDiagnostics, bool verifyPasses, applyPassManagerCLOptions(pm); // Run the pipeline. - if (failed(pm.run(module.get()))) + if (failed(pm.run(*module))) return failure(); // Print the output. diff --git a/mlir/lib/Support/TranslateClParser.cpp b/mlir/lib/Support/TranslateClParser.cpp index eb18acb56de2..dcb55d6a15f4 100644 --- a/mlir/lib/Support/TranslateClParser.cpp +++ b/mlir/lib/Support/TranslateClParser.cpp @@ -37,7 +37,7 @@ using namespace mlir; // Storage for the translation function wrappers that survive the parser. static llvm::SmallVector wrapperStorage; -static LogicalResult printMLIROutput(Module &module, +static LogicalResult printMLIROutput(Module module, llvm::StringRef outputFilename) { if (failed(module.verify())) return failure(); @@ -62,7 +62,7 @@ TranslationParser::TranslationParser(llvm::cl::Option &opt) TranslateFunction wrapper = [function](StringRef inputFilename, StringRef outputFilename, MLIRContext *context) { - std::unique_ptr module = function(inputFilename, context); + OwningModuleRef module = function(inputFilename, context); if (!module) return failure(); return printMLIROutput(*module, outputFilename); @@ -79,8 +79,8 @@ TranslationParser::TranslationParser(llvm::cl::Option &opt) MLIRContext *context) { llvm::SourceMgr sourceMgr; SourceMgrDiagnosticHandler sourceMgrHandler(sourceMgr, context); - auto module = std::unique_ptr( - parseSourceFile(inputFilename, sourceMgr, context)); + auto module = + OwningModuleRef(parseSourceFile(inputFilename, sourceMgr, context)); if (!module) return failure(); return failure(function(module.get(), outputFilename)); diff --git a/mlir/lib/Target/LLVMIR/ConvertToLLVMIR.cpp b/mlir/lib/Target/LLVMIR/ConvertToLLVMIR.cpp index 34dd374e9d7b..49431d463ba1 100644 --- a/mlir/lib/Target/LLVMIR/ConvertToLLVMIR.cpp +++ b/mlir/lib/Target/LLVMIR/ConvertToLLVMIR.cpp @@ -31,16 +31,16 @@ using namespace mlir; -std::unique_ptr mlir::translateModuleToLLVMIR(Module &m) { +std::unique_ptr mlir::translateModuleToLLVMIR(Module m) { return LLVM::ModuleTranslation::translateModule<>(m); } static TranslateFromMLIRRegistration registration( - "mlir-to-llvmir", [](Module *module, llvm::StringRef outputFilename) { + "mlir-to-llvmir", [](Module module, llvm::StringRef outputFilename) { if (!module) return true; - auto llvmModule = LLVM::ModuleTranslation::translateModule<>(*module); + auto llvmModule = LLVM::ModuleTranslation::translateModule<>(module); if (!llvmModule) return true; diff --git a/mlir/lib/Target/LLVMIR/ConvertToNVVMIR.cpp b/mlir/lib/Target/LLVMIR/ConvertToNVVMIR.cpp index 1e8409246efc..bcff6e473285 100644 --- a/mlir/lib/Target/LLVMIR/ConvertToNVVMIR.cpp +++ b/mlir/lib/Target/LLVMIR/ConvertToNVVMIR.cpp @@ -47,8 +47,7 @@ static llvm::Value *createIntrinsicCall(llvm::IRBuilder<> &builder, class ModuleTranslation : public LLVM::ModuleTranslation { public: - explicit ModuleTranslation(Module &module) - : LLVM::ModuleTranslation(module) {} + explicit ModuleTranslation(Module module) : LLVM::ModuleTranslation(module) {} ~ModuleTranslation() override {} protected: @@ -62,7 +61,7 @@ protected: }; } // namespace -std::unique_ptr mlir::translateModuleToNVVMIR(Module &m) { +std::unique_ptr mlir::translateModuleToNVVMIR(Module m) { ModuleTranslation translation(m); auto llvmModule = LLVM::ModuleTranslation::translateModule(m); @@ -91,11 +90,11 @@ std::unique_ptr mlir::translateModuleToNVVMIR(Module &m) { static TranslateFromMLIRRegistration registration("mlir-to-nvvmir", - [](Module *module, llvm::StringRef outputFilename) { + [](Module module, llvm::StringRef outputFilename) { if (!module) return true; - auto llvmModule = mlir::translateModuleToNVVMIR(*module); + auto llvmModule = mlir::translateModuleToNVVMIR(module); if (!llvmModule) return true; diff --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp index 4a68ac71ee0b..df2bd57d5f0f 100644 --- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp @@ -275,7 +275,7 @@ static Value *getPHISourceValue(Block *current, Block *pred, : terminator.getSuccessorOperand(1, index); } -void ModuleTranslation::connectPHINodes(Function &func) { +void ModuleTranslation::connectPHINodes(Function func) { // Skip the first block, it cannot be branched to and its arguments correspond // to the arguments of the LLVM function. for (auto it = std::next(func.begin()), eit = func.end(); it != eit; ++it) { @@ -306,7 +306,7 @@ static void topologicalSortImpl(llvm::SetVector &blocks, Block *b) { } // Sort function blocks topologically. -static llvm::SetVector topologicalSort(Function &f) { +static llvm::SetVector topologicalSort(Function f) { // For each blocks that has not been visited yet (i.e. that has no // predecessors), add it to the list and traverse its successors in DFS // preorder. @@ -320,7 +320,7 @@ static llvm::SetVector topologicalSort(Function &f) { return blocks; } -bool ModuleTranslation::convertOneFunction(Function &func) { +bool ModuleTranslation::convertOneFunction(Function func) { // Clear the block and value mappings, they are only relevant within one // function. blockMapping.clear(); @@ -404,7 +404,7 @@ bool ModuleTranslation::convertFunctions() { return false; } -std::unique_ptr ModuleTranslation::prepareLLVMModule(Module &m) { +std::unique_ptr ModuleTranslation::prepareLLVMModule(Module m) { auto *dialect = m.getContext()->getRegisteredDialect(); assert(dialect && "LLVM dialect must be registered"); diff --git a/mlir/lib/Transforms/DialectConversion.cpp b/mlir/lib/Transforms/DialectConversion.cpp index 9916c9ee593a..9375a7b8445e 100644 --- a/mlir/lib/Transforms/DialectConversion.cpp +++ b/mlir/lib/Transforms/DialectConversion.cpp @@ -1128,7 +1128,7 @@ auto ConversionTarget::getOpAction(OperationName op) const /// conversion object. If conversion fails for specific functions, those /// functions remains unmodified. LogicalResult -mlir::applyConversionPatterns(Module &module, ConversionTarget &target, +mlir::applyConversionPatterns(Module module, ConversionTarget &target, TypeConverter &converter, OwningRewritePatternList &&patterns) { SmallVector allFunctions(module.getFunctions()); diff --git a/mlir/test/EDSC/builder-api-test.cpp b/mlir/test/EDSC/builder-api-test.cpp index a88312dba9bd..788857c1a10e 100644 --- a/mlir/test/EDSC/builder-api-test.cpp +++ b/mlir/test/EDSC/builder-api-test.cpp @@ -555,8 +555,8 @@ TEST_FUNC(vectorize_2d) { makeFunction("vectorize_2d", {}, {memrefType, memrefType, memrefType}); mlir::Function f = owningF; - mlir::Module module(&globalContext()); - module.push_back(f); + mlir::OwningModuleRef module = Module::create(&globalContext()); + module->push_back(f); OpBuilder builder(f.getBody()); ScopedContext scope(builder, f.getLoc()); diff --git a/mlir/tools/mlir-cpu-runner/mlir-cpu-runner-lib.cpp b/mlir/tools/mlir-cpu-runner/mlir-cpu-runner-lib.cpp index 1ac6c4026302..1b1fbccdf629 100644 --- a/mlir/tools/mlir-cpu-runner/mlir-cpu-runner-lib.cpp +++ b/mlir/tools/mlir-cpu-runner/mlir-cpu-runner-lib.cpp @@ -89,8 +89,8 @@ static llvm::cl::list llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated, llvm::cl::cat(clOptionsCategory)); -static std::unique_ptr parseMLIRInput(StringRef inputFilename, - MLIRContext *context) { +static OwningModuleRef parseMLIRInput(StringRef inputFilename, + MLIRContext *context) { // Set up the input file. std::string errorMessage; auto file = openInputFile(inputFilename, &errorMessage); @@ -101,7 +101,7 @@ static std::unique_ptr parseMLIRInput(StringRef inputFilename, llvm::SourceMgr sourceMgr; sourceMgr.AddNewSourceBuffer(std::move(file), llvm::SMLoc()); - return std::unique_ptr(parseSourceFile(sourceMgr, context)); + return OwningModuleRef(parseSourceFile(sourceMgr, context)); } // Initialize the relevant subsystems of LLVM. @@ -151,7 +151,7 @@ static void printMemRefArguments(ArrayRef argTypes, // - canonicalization // - affine to standard lowering // - standard to llvm lowering -static LogicalResult convertAffineStandardToLLVMIR(Module *module) { +static LogicalResult convertAffineStandardToLLVMIR(Module module) { PassManager manager; manager.addPass(mlir::createCanonicalizerPass()); manager.addPass(mlir::createCSEPass()); @@ -161,9 +161,9 @@ static LogicalResult convertAffineStandardToLLVMIR(Module *module) { } static Error compileAndExecuteFunctionWithMemRefs( - Module *module, StringRef entryPoint, + Module module, StringRef entryPoint, std::function transformer) { - Function mainFunction = module->getNamedFunction(entryPoint); + Function mainFunction = module.getNamedFunction(entryPoint); if (!mainFunction || mainFunction.getBlocks().empty()) { return make_string_error("entry point not found"); } @@ -204,9 +204,9 @@ static Error compileAndExecuteFunctionWithMemRefs( } static Error compileAndExecuteSingleFloatReturnFunction( - Module *module, StringRef entryPoint, + Module module, StringRef entryPoint, std::function transformer) { - Function mainFunction = module->getNamedFunction(entryPoint); + Function mainFunction = module.getNamedFunction(entryPoint); if (!mainFunction || mainFunction.isExternal()) { return make_string_error("entry point not found"); } diff --git a/mlir/unittests/Pass/AnalysisManagerTest.cpp b/mlir/unittests/Pass/AnalysisManagerTest.cpp index d2a82374124a..0464498b3614 100644 --- a/mlir/unittests/Pass/AnalysisManagerTest.cpp +++ b/mlir/unittests/Pass/AnalysisManagerTest.cpp @@ -26,19 +26,19 @@ namespace { /// Minimal class definitions for two analyses. struct MyAnalysis { MyAnalysis(Function) {} - MyAnalysis(Module *) {} + MyAnalysis(Module) {} }; struct OtherAnalysis { OtherAnalysis(Function) {} - OtherAnalysis(Module *) {} + OtherAnalysis(Module) {} }; TEST(AnalysisManagerTest, FineGrainModuleAnalysisPreservation) { MLIRContext context; // Test fine grain invalidation of the module analysis manager. - std::unique_ptr module(new Module(&context)); - ModuleAnalysisManager mam(&*module, /*passInstrumentor=*/nullptr); + OwningModuleRef module(Module::create(&context)); + ModuleAnalysisManager mam(*module, /*passInstrumentor=*/nullptr); // Query two different analyses, but only preserve one before invalidating. mam.getAnalysis(); @@ -58,14 +58,14 @@ TEST(AnalysisManagerTest, FineGrainFunctionAnalysisPreservation) { Builder builder(&context); // Create a function and a module. - std::unique_ptr module(new Module(&context)); + OwningModuleRef module(Module::create(&context)); Function func1 = Function::create(builder.getUnknownLoc(), "foo", builder.getFunctionType(llvm::None, llvm::None)); module->push_back(func1); // Test fine grain invalidation of the function analysis manager. - ModuleAnalysisManager mam(&*module, /*passInstrumentor=*/nullptr); + ModuleAnalysisManager mam(*module, /*passInstrumentor=*/nullptr); FunctionAnalysisManager fam = mam.slice(func1); // Query two different analyses, but only preserve one before invalidating. @@ -86,7 +86,7 @@ TEST(AnalysisManagerTest, FineGrainChildFunctionAnalysisPreservation) { Builder builder(&context); // Create a function and a module. - std::unique_ptr module(new Module(&context)); + OwningModuleRef module(Module::create(&context)); Function func1 = Function::create(builder.getUnknownLoc(), "foo", builder.getFunctionType(llvm::None, llvm::None)); @@ -94,7 +94,7 @@ TEST(AnalysisManagerTest, FineGrainChildFunctionAnalysisPreservation) { // Test fine grain invalidation of a function analysis from within a module // analysis manager. - ModuleAnalysisManager mam(&*module, /*passInstrumentor=*/nullptr); + ModuleAnalysisManager mam(*module, /*passInstrumentor=*/nullptr); // Query two different analyses, but only preserve one before invalidating. mam.getFunctionAnalysis(func1);