diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp index 621c095021c7..1225c26486a3 100644 --- a/mlir/lib/Bindings/Python/IRCore.cpp +++ b/mlir/lib/Bindings/Python/IRCore.cpp @@ -1075,6 +1075,21 @@ py::object PyOperation::createFromCapsule(py::object capsule) { .releaseObject(); } +static void maybeInsertOperation(PyOperationRef &op, + const py::object &maybeIp) { + // InsertPoint active? + if (!maybeIp.is(py::cast(false))) { + PyInsertionPoint *ip; + if (maybeIp.is_none()) { + ip = PyThreadContextEntry::getDefaultInsertionPoint(); + } else { + ip = py::cast(maybeIp); + } + if (ip) + ip->insert(*op.get()); + } +} + py::object PyOperation::create( const std::string &name, llvm::Optional> results, llvm::Optional> operands, @@ -1192,22 +1207,20 @@ py::object PyOperation::create( MlirOperation operation = mlirOperationCreate(&state); PyOperationRef created = PyOperation::createDetached(location->getContext(), operation); - - // InsertPoint active? - if (!maybeIp.is(py::cast(false))) { - PyInsertionPoint *ip; - if (maybeIp.is_none()) { - ip = PyThreadContextEntry::getDefaultInsertionPoint(); - } else { - ip = py::cast(maybeIp); - } - if (ip) - ip->insert(*created.get()); - } + maybeInsertOperation(created, maybeIp); return created->createOpView(); } +py::object PyOperation::clone(const py::object &maybeIp) { + MlirOperation clonedOperation = mlirOperationClone(operation); + PyOperationRef cloned = + PyOperation::createDetached(getContext(), clonedOperation); + maybeInsertOperation(cloned, maybeIp); + + return cloned->createOpView(); +} + py::object PyOperation::createOpView() { checkValid(); MlirIdentifier ident = mlirOperationGetName(get()); @@ -2616,6 +2629,7 @@ void mlir::python::populateIRCore(py::module &m) { return py::none(); }) .def("erase", &PyOperation::erase) + .def("clone", &PyOperation::clone, py::arg("ip") = py::none()) .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyOperation::getCapsule) .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyOperation::createFromCapsule) diff --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h index b1424a994d85..2046ce0c1655 100644 --- a/mlir/lib/Bindings/Python/IRModule.h +++ b/mlir/lib/Bindings/Python/IRModule.h @@ -575,6 +575,9 @@ public: /// parent context's live operations map, and sets the valid bit false. void erase(); + /// Clones this operation. + pybind11::object clone(const pybind11::object &ip); + private: PyOperation(PyMlirContextRef contextRef, MlirOperation operation); static PyOperationRef createInstance(PyMlirContextRef contextRef, diff --git a/mlir/test/python/ir/operation.py b/mlir/test/python/ir/operation.py index 8dca68385947..7e23268c2d8a 100644 --- a/mlir/test/python/ir/operation.py +++ b/mlir/test/python/ir/operation.py @@ -767,6 +767,26 @@ def testOperationErase(): Operation.create("custom.op2") +# CHECK-LABEL: TEST: testOperationClone +@run +def testOperationClone(): + ctx = Context() + ctx.allow_unregistered_dialects = True + with Location.unknown(ctx): + m = Module.create() + with InsertionPoint(m.body): + op = Operation.create("custom.op1") + + # CHECK: "custom.op1" + print(m) + + clone = op.operation.clone() + op.operation.erase() + + # CHECK: "custom.op1" + print(m) + + # CHECK-LABEL: TEST: testOperationLoc @run def testOperationLoc():