Expose MlirOperationClone in Python bindings.

Expose MlirOperationClone in Python bindings.

Reviewed By: ftynse

Differential Revision: https://reviews.llvm.org/D122526
This commit is contained in:
Dominik Grewe
2022-03-28 15:45:40 +02:00
committed by Alex Zinenko
parent 58d0da885e
commit 774818c09c
3 changed files with 49 additions and 12 deletions

View File

@@ -1075,6 +1075,21 @@ py::object PyOperation::createFromCapsule(py::object capsule) {
.releaseObject();
}
static void maybeInsertOperation(PyOperationRef &op,
const py::object &maybeIp) {
// InsertPoint active?
if (!maybeIp.is(py::cast(false))) {
PyInsertionPoint *ip;
if (maybeIp.is_none()) {
ip = PyThreadContextEntry::getDefaultInsertionPoint();
} else {
ip = py::cast<PyInsertionPoint *>(maybeIp);
}
if (ip)
ip->insert(*op.get());
}
}
py::object PyOperation::create(
const std::string &name, llvm::Optional<std::vector<PyType *>> results,
llvm::Optional<std::vector<PyValue *>> operands,
@@ -1192,22 +1207,20 @@ py::object PyOperation::create(
MlirOperation operation = mlirOperationCreate(&state);
PyOperationRef created =
PyOperation::createDetached(location->getContext(), operation);
// InsertPoint active?
if (!maybeIp.is(py::cast(false))) {
PyInsertionPoint *ip;
if (maybeIp.is_none()) {
ip = PyThreadContextEntry::getDefaultInsertionPoint();
} else {
ip = py::cast<PyInsertionPoint *>(maybeIp);
}
if (ip)
ip->insert(*created.get());
}
maybeInsertOperation(created, maybeIp);
return created->createOpView();
}
py::object PyOperation::clone(const py::object &maybeIp) {
MlirOperation clonedOperation = mlirOperationClone(operation);
PyOperationRef cloned =
PyOperation::createDetached(getContext(), clonedOperation);
maybeInsertOperation(cloned, maybeIp);
return cloned->createOpView();
}
py::object PyOperation::createOpView() {
checkValid();
MlirIdentifier ident = mlirOperationGetName(get());
@@ -2616,6 +2629,7 @@ void mlir::python::populateIRCore(py::module &m) {
return py::none();
})
.def("erase", &PyOperation::erase)
.def("clone", &PyOperation::clone, py::arg("ip") = py::none())
.def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
&PyOperation::getCapsule)
.def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyOperation::createFromCapsule)

View File

@@ -575,6 +575,9 @@ public:
/// parent context's live operations map, and sets the valid bit false.
void erase();
/// Clones this operation.
pybind11::object clone(const pybind11::object &ip);
private:
PyOperation(PyMlirContextRef contextRef, MlirOperation operation);
static PyOperationRef createInstance(PyMlirContextRef contextRef,

View File

@@ -767,6 +767,26 @@ def testOperationErase():
Operation.create("custom.op2")
# CHECK-LABEL: TEST: testOperationClone
@run
def testOperationClone():
ctx = Context()
ctx.allow_unregistered_dialects = True
with Location.unknown(ctx):
m = Module.create()
with InsertionPoint(m.body):
op = Operation.create("custom.op1")
# CHECK: "custom.op1"
print(m)
clone = op.operation.clone()
op.operation.erase()
# CHECK: "custom.op1"
print(m)
# CHECK-LABEL: TEST: testOperationLoc
@run
def testOperationLoc():