mirror of
https://github.com/intel/llvm.git
synced 2026-01-15 12:25:46 +08:00
[mlir] Python: Parse ModuleOp from file path (#126572)
For extremely large models, it may be inefficient to load the model into memory in Python prior to passing it to the MLIR C APIs for deserialization. This change adds an API to parse a ModuleOp directly from a file path. Re-lands [4e14b8a](4e14b8afb4).
This commit is contained in:
@@ -309,6 +309,10 @@ MLIR_CAPI_EXPORTED MlirModule mlirModuleCreateEmpty(MlirLocation location);
|
||||
MLIR_CAPI_EXPORTED MlirModule mlirModuleCreateParse(MlirContext context,
|
||||
MlirStringRef module);
|
||||
|
||||
/// Parses a module from file and transfers ownership to the caller.
|
||||
MLIR_CAPI_EXPORTED MlirModule
|
||||
mlirModuleCreateParseFromFile(MlirContext context, MlirStringRef fileName);
|
||||
|
||||
/// Gets the context that a module was created with.
|
||||
MLIR_CAPI_EXPORTED MlirContext mlirModuleGetContext(MlirModule module);
|
||||
|
||||
|
||||
@@ -299,7 +299,7 @@ struct PyAttrBuilderMap {
|
||||
return *builder;
|
||||
}
|
||||
static void dunderSetItemNamed(const std::string &attributeKind,
|
||||
nb::callable func, bool replace) {
|
||||
nb::callable func, bool replace) {
|
||||
PyGlobals::get().registerAttributeBuilder(attributeKind, std::move(func),
|
||||
replace);
|
||||
}
|
||||
@@ -3049,6 +3049,18 @@ void mlir::python::populateIRCore(nb::module_ &m) {
|
||||
},
|
||||
nb::arg("asm"), nb::arg("context").none() = nb::none(),
|
||||
kModuleParseDocstring)
|
||||
.def_static(
|
||||
"parseFile",
|
||||
[](const std::string &path, DefaultingPyMlirContext context) {
|
||||
PyMlirContext::ErrorCapture errors(context->getRef());
|
||||
MlirModule module = mlirModuleCreateParseFromFile(
|
||||
context->get(), toMlirStringRef(path));
|
||||
if (mlirModuleIsNull(module))
|
||||
throw MLIRError("Unable to parse module assembly", errors.take());
|
||||
return PyModule::forModule(module).releaseObject();
|
||||
},
|
||||
nb::arg("path"), nb::arg("context").none() = nb::none(),
|
||||
kModuleParseDocstring)
|
||||
.def_static(
|
||||
"create",
|
||||
[](DefaultingPyLocation loc) {
|
||||
|
||||
@@ -22,6 +22,7 @@
|
||||
#include "mlir/IR/Location.h"
|
||||
#include "mlir/IR/Operation.h"
|
||||
#include "mlir/IR/OperationSupport.h"
|
||||
#include "mlir/IR/OwningOpRef.h"
|
||||
#include "mlir/IR/Types.h"
|
||||
#include "mlir/IR/Value.h"
|
||||
#include "mlir/IR/Verifier.h"
|
||||
@@ -328,6 +329,15 @@ MlirModule mlirModuleCreateParse(MlirContext context, MlirStringRef module) {
|
||||
return MlirModule{owning.release().getOperation()};
|
||||
}
|
||||
|
||||
MlirModule mlirModuleCreateParseFromFile(MlirContext context,
|
||||
MlirStringRef fileName) {
|
||||
OwningOpRef<ModuleOp> owning =
|
||||
parseSourceFile<ModuleOp>(unwrap(fileName), unwrap(context));
|
||||
if (!owning)
|
||||
return MlirModule{nullptr};
|
||||
return MlirModule{owning.release().getOperation()};
|
||||
}
|
||||
|
||||
MlirContext mlirModuleGetContext(MlirModule module) {
|
||||
return wrap(unwrap(module).getContext());
|
||||
}
|
||||
|
||||
@@ -46,6 +46,7 @@ import abc
|
||||
import collections
|
||||
from collections.abc import Callable, Sequence
|
||||
import io
|
||||
from pathlib import Path
|
||||
from typing import Any, ClassVar, TypeVar, overload
|
||||
|
||||
__all__ = [
|
||||
@@ -2129,6 +2130,15 @@ class Module:
|
||||
|
||||
Returns a new MlirModule or raises an MLIRError if the parsing fails.
|
||||
|
||||
See also: https://mlir.llvm.org/docs/LangRef/
|
||||
"""
|
||||
@staticmethod
|
||||
def parseFile(path: str, context: Context | None = None) -> Module:
|
||||
"""
|
||||
Parses a module's assembly format from file.
|
||||
|
||||
Returns a new MlirModule or raises an MLIRError if the parsing fails.
|
||||
|
||||
See also: https://mlir.llvm.org/docs/LangRef/
|
||||
"""
|
||||
def _CAPICreate(self) -> Any: ...
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
# RUN: %PYTHON %s | FileCheck %s
|
||||
|
||||
import gc
|
||||
from tempfile import NamedTemporaryFile
|
||||
from mlir.ir import *
|
||||
|
||||
|
||||
@@ -27,6 +28,24 @@ def testParseSuccess():
|
||||
print(str(module))
|
||||
|
||||
|
||||
# Verify successful parse from file.
|
||||
# CHECK-LABEL: TEST: testParseFromFileSuccess
|
||||
# CHECK: module @successfulParse
|
||||
@run
|
||||
def testParseFromFileSuccess():
|
||||
ctx = Context()
|
||||
with NamedTemporaryFile(mode="w") as tmp_file:
|
||||
tmp_file.write(r"""module @successfulParse {}""")
|
||||
tmp_file.flush()
|
||||
module = Module.parseFile(tmp_file.name, ctx)
|
||||
assert module.context is ctx
|
||||
print("CLEAR CONTEXT")
|
||||
ctx = None # Ensure that module captures the context.
|
||||
gc.collect()
|
||||
module.operation.verify()
|
||||
print(str(module))
|
||||
|
||||
|
||||
# Verify parse error.
|
||||
# CHECK-LABEL: TEST: testParseError
|
||||
# CHECK: testParseError: <
|
||||
|
||||
Reference in New Issue
Block a user