[MLIR][Python] restore APIs in terms of Mlir* types (#160203)

https://github.com/llvm/llvm-project/pull/157930 changed a few APIs from
`Mlir*` to `Py*` and broke users that were using them (see
https://github.com/llvm/llvm-project/pull/160183#issuecomment-3321383969).
This PR restores those APIs.
This commit is contained in:
Maksim Levental
2025-09-22 18:00:57 -04:00
committed by GitHub
parent 42b195e1bf
commit 4a9df48cf8
3 changed files with 94 additions and 0 deletions

View File

@@ -4283,6 +4283,33 @@ void mlir::python::populateIRCore(nb::module_ &m) {
mlirValueReplaceAllUsesOfWith(self.get(), with.get());
},
kValueReplaceAllUsesWithDocstring)
.def(
"replace_all_uses_except",
[](MlirValue self, MlirValue with, PyOperation &exception) {
MlirOperation exceptedUser = exception.get();
mlirValueReplaceAllUsesExcept(self, with, 1, &exceptedUser);
},
nb::arg("with_"), nb::arg("exceptions"),
nb::sig("def replace_all_uses_except(self, with_: Value, exceptions: "
"Operation) -> None"),
kValueReplaceAllUsesExceptDocstring)
.def(
"replace_all_uses_except",
[](MlirValue self, MlirValue with, nb::list exceptions) {
// Convert Python list to a SmallVector of MlirOperations
llvm::SmallVector<MlirOperation> exceptionOps;
for (nb::handle exception : exceptions) {
exceptionOps.push_back(nb::cast<PyOperation &>(exception).get());
}
mlirValueReplaceAllUsesExcept(
self, with, static_cast<intptr_t>(exceptionOps.size()),
exceptionOps.data());
},
nb::arg("with_"), nb::arg("exceptions"),
nb::sig("def replace_all_uses_except(self, with_: Value, exceptions: "
"Sequence[Operation]) -> None"),
kValueReplaceAllUsesExceptDocstring)
.def(
"replace_all_uses_except",
[](PyValue &self, PyValue &with, PyOperation &exception) {

View File

@@ -898,6 +898,18 @@ public:
},
nb::arg("elements"), nb::arg("context") = nb::none(),
"Create a tuple type");
c.def_static(
"get_tuple",
[](std::vector<MlirType> elements, DefaultingPyMlirContext context) {
MlirType t = mlirTupleTypeGet(context->get(), elements.size(),
elements.data());
return PyTupleType(context->getRef(), t);
},
nb::arg("elements"), nb::arg("context") = nb::none(),
// clang-format off
nb::sig("def get_tuple(elements: Sequence[Type], context: mlir.ir.Context | None = None) -> TupleType"),
// clang-format on
"Create a tuple type");
c.def(
"get_type",
[](PyTupleType &self, intptr_t pos) {
@@ -944,6 +956,20 @@ public:
},
nb::arg("inputs"), nb::arg("results"), nb::arg("context") = nb::none(),
"Gets a FunctionType from a list of input and result types");
c.def_static(
"get",
[](std::vector<MlirType> inputs, std::vector<MlirType> results,
DefaultingPyMlirContext context) {
MlirType t =
mlirFunctionTypeGet(context->get(), inputs.size(), inputs.data(),
results.size(), results.data());
return PyFunctionType(context->getRef(), t);
},
nb::arg("inputs"), nb::arg("results"), nb::arg("context") = nb::none(),
// clang-format off
nb::sig("def get(inputs: Sequence[Type], results: Sequence[Type], context: mlir.ir.Context | None = None) -> FunctionType"),
// clang-format on
"Gets a FunctionType from a list of input and result types");
c.def_prop_ro(
"inputs",
[](PyFunctionType &self) {

View File

@@ -83,6 +83,16 @@ void mlir::python::populateRewriteSubmodule(nb::module_ &m) {
//----------------------------------------------------------------------------
#if MLIR_ENABLE_PDL_IN_PATTERNMATCH
nb::class_<PyPDLPatternModule>(m, "PDLModule")
.def(
"__init__",
[](PyPDLPatternModule &self, MlirModule module) {
new (&self)
PyPDLPatternModule(mlirPDLPatternModuleFromModule(module));
},
// clang-format off
nb::sig("def __init__(self, module: " MAKE_MLIR_PYTHON_QUALNAME("ir.Module") ") -> None"),
// clang-format on
"module"_a, "Create a PDL module from the given module.")
.def(
"__init__",
[](PyPDLPatternModule &self, PyModule &module) {
@@ -117,6 +127,22 @@ void mlir::python::populateRewriteSubmodule(nb::module_ &m) {
// clang-format on
"Applys the given patterns to the given module greedily while folding "
"results.")
.def(
"apply_patterns_and_fold_greedily",
[](PyModule &module, MlirFrozenRewritePatternSet set) {
auto status =
mlirApplyPatternsAndFoldGreedily(module.get(), set, {});
if (mlirLogicalResultIsFailure(status))
throw std::runtime_error(
"pattern application failed to converge");
},
"module"_a, "set"_a,
// clang-format off
nb::sig("def apply_patterns_and_fold_greedily(module: " MAKE_MLIR_PYTHON_QUALNAME("ir.Module") ", set: FrozenRewritePatternSet) -> None"),
// clang-format on
"Applys the given patterns to the given module greedily while "
"folding "
"results.")
.def(
"apply_patterns_and_fold_greedily",
[](PyOperationBase &op, PyFrozenRewritePatternSet &set) {
@@ -131,5 +157,20 @@ void mlir::python::populateRewriteSubmodule(nb::module_ &m) {
nb::sig("def apply_patterns_and_fold_greedily(op: " MAKE_MLIR_PYTHON_QUALNAME("ir._OperationBase") ", set: FrozenRewritePatternSet) -> None"),
// clang-format on
"Applys the given patterns to the given op greedily while folding "
"results.")
.def(
"apply_patterns_and_fold_greedily",
[](PyOperationBase &op, MlirFrozenRewritePatternSet set) {
auto status = mlirApplyPatternsAndFoldGreedilyWithOp(
op.getOperation(), set, {});
if (mlirLogicalResultIsFailure(status))
throw std::runtime_error(
"pattern application failed to converge");
},
"op"_a, "set"_a,
// clang-format off
nb::sig("def apply_patterns_and_fold_greedily(op: " MAKE_MLIR_PYTHON_QUALNAME("ir._OperationBase") ", set: FrozenRewritePatternSet) -> None"),
// clang-format on
"Applys the given patterns to the given op greedily while folding "
"results.");
}