[mlir][python] Hook up PyRegionList.__iter__ to PyRegionIterator

This fixes a -Wunused-member-function warning, at the moment
`PyRegionIterator` is never constructed by anything (the only use was
removed in D111697), and iterating over region lists is just falling
back to a generic python iterator object.

Reviewed By: stellaraccident

Differential Revision: https://reviews.llvm.org/D150244
This commit is contained in:
Rahul Kayaith
2023-05-24 22:05:06 -04:00
parent 693a1b7024
commit d0d26ee78c
2 changed files with 15 additions and 4 deletions

View File

@@ -295,6 +295,11 @@ class PyRegionList {
public:
PyRegionList(PyOperationRef operation) : operation(std::move(operation)) {}
PyRegionIterator dunderIter() {
operation->checkValid();
return PyRegionIterator(operation);
}
intptr_t dunderLen() {
operation->checkValid();
return mlirOperationGetNumRegions(operation->get());
@@ -312,6 +317,7 @@ public:
static void bind(py::module &m) {
py::class_<PyRegionList>(m, "RegionSequence", py::module_local())
.def("__len__", &PyRegionList::dunderLen)
.def("__iter__", &PyRegionList::dunderIter)
.def("__getitem__", &PyRegionList::dunderGetItem);
}

View File

@@ -48,11 +48,9 @@ def testTraverseOpRegionBlockIterators():
# CHECK: .verify = True
print(f".verify = {module.operation.verify()}")
# Get the regions and blocks from the default collections.
default_regions = list(op.regions)
default_blocks = list(default_regions[0])
# Get the blocks from the default collection.
default_blocks = list(regions[0])
# They should compare equal regardless of how obtained.
assert default_regions == regions
assert default_blocks == blocks
# Should be able to get the operations from either the named collection
@@ -79,6 +77,13 @@ def testTraverseOpRegionBlockIterators():
# CHECK: OP 1: func.return
walk_operations("", op)
# CHECK: Region iter: <mlir.{{.+}}.RegionIterator
# CHECK: Block iter: <mlir.{{.+}}.BlockIterator
# CHECK: Operation iter: <mlir.{{.+}}.OperationIterator
print(" Region iter:", iter(op.regions))
print(" Block iter:", iter(op.regions[0]))
print("Operation iter:", iter(op.regions[0].blocks[0]))
# Verify index based traversal of the op/region/block hierarchy.
# CHECK-LABEL: TEST: testTraverseOpRegionBlockIndices