[mlir][python] Hook up PyRegionList.__iter__ to PyRegionIterator
authorRahul Kayaith <rkayaith@gmail.com>
Thu, 25 May 2023 02:05:06 +0000 (22:05 -0400)
committerRahul Kayaith <rkayaith@gmail.com>
Thu, 25 May 2023 02:16:58 +0000 (22:16 -0400)
This fixes a -Wunused-member-function warning, at the moment
`PyRegionIterator` is never constructed by anything (the only use was
removed in D111697), and iterating over region lists is just falling
back to a generic python iterator object.

Reviewed By: stellaraccident

Differential Revision: https://reviews.llvm.org/D150244

mlir/lib/Bindings/Python/IRCore.cpp
mlir/test/python/ir/operation.py

index 7013cca..a6bd4d8 100644 (file)
@@ -295,6 +295,11 @@ class PyRegionList {
 public:
   PyRegionList(PyOperationRef operation) : operation(std::move(operation)) {}
 
+  PyRegionIterator dunderIter() {
+    operation->checkValid();
+    return PyRegionIterator(operation);
+  }
+
   intptr_t dunderLen() {
     operation->checkValid();
     return mlirOperationGetNumRegions(operation->get());
@@ -312,6 +317,7 @@ public:
   static void bind(py::module &m) {
     py::class_<PyRegionList>(m, "RegionSequence", py::module_local())
         .def("__len__", &PyRegionList::dunderLen)
+        .def("__iter__", &PyRegionList::dunderIter)
         .def("__getitem__", &PyRegionList::dunderGetItem);
   }
 
index ea84d11..22a8089 100644 (file)
@@ -48,11 +48,9 @@ def testTraverseOpRegionBlockIterators():
   # CHECK: .verify = True
   print(f".verify = {module.operation.verify()}")
 
-  # Get the regions and blocks from the default collections.
-  default_regions = list(op.regions)
-  default_blocks = list(default_regions[0])
+  # Get the blocks from the default collection.
+  default_blocks = list(regions[0])
   # They should compare equal regardless of how obtained.
-  assert default_regions == regions
   assert default_blocks == blocks
 
   # Should be able to get the operations from either the named collection
@@ -79,6 +77,13 @@ def testTraverseOpRegionBlockIterators():
   # CHECK:           OP 1: func.return
   walk_operations("", op)
 
+  # CHECK:    Region iter: <mlir.{{.+}}.RegionIterator
+  # CHECK:     Block iter: <mlir.{{.+}}.BlockIterator
+  # CHECK: Operation iter: <mlir.{{.+}}.OperationIterator
+  print("   Region iter:", iter(op.regions))
+  print("    Block iter:", iter(op.regions[0]))
+  print("Operation iter:", iter(op.regions[0].blocks[0]))
+
 
 # Verify index based traversal of the op/region/block hierarchy.
 # CHECK-LABEL: TEST: testTraverseOpRegionBlockIndices