From d0d26ee78cde3402fbc1fe445bcbcfc7606fbcd1 Mon Sep 17 00:00:00 2001 From: Rahul Kayaith Date: Wed, 24 May 2023 22:05:06 -0400 Subject: [PATCH] [mlir][python] Hook up PyRegionList.__iter__ to PyRegionIterator This fixes a -Wunused-member-function warning, at the moment `PyRegionIterator` is never constructed by anything (the only use was removed in D111697), and iterating over region lists is just falling back to a generic python iterator object. Reviewed By: stellaraccident Differential Revision: https://reviews.llvm.org/D150244 --- mlir/lib/Bindings/Python/IRCore.cpp | 6 ++++++ mlir/test/python/ir/operation.py | 13 +++++++++---- 2 files changed, 15 insertions(+), 4 deletions(-) diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp index 7013cca..a6bd4d8 100644 --- a/mlir/lib/Bindings/Python/IRCore.cpp +++ b/mlir/lib/Bindings/Python/IRCore.cpp @@ -295,6 +295,11 @@ class PyRegionList { public: PyRegionList(PyOperationRef operation) : operation(std::move(operation)) {} + PyRegionIterator dunderIter() { + operation->checkValid(); + return PyRegionIterator(operation); + } + intptr_t dunderLen() { operation->checkValid(); return mlirOperationGetNumRegions(operation->get()); @@ -312,6 +317,7 @@ public: static void bind(py::module &m) { py::class_(m, "RegionSequence", py::module_local()) .def("__len__", &PyRegionList::dunderLen) + .def("__iter__", &PyRegionList::dunderIter) .def("__getitem__", &PyRegionList::dunderGetItem); } diff --git a/mlir/test/python/ir/operation.py b/mlir/test/python/ir/operation.py index ea84d11..22a8089 100644 --- a/mlir/test/python/ir/operation.py +++ b/mlir/test/python/ir/operation.py @@ -48,11 +48,9 @@ def testTraverseOpRegionBlockIterators(): # CHECK: .verify = True print(f".verify = {module.operation.verify()}") - # Get the regions and blocks from the default collections. - default_regions = list(op.regions) - default_blocks = list(default_regions[0]) + # Get the blocks from the default collection. + default_blocks = list(regions[0]) # They should compare equal regardless of how obtained. - assert default_regions == regions assert default_blocks == blocks # Should be able to get the operations from either the named collection @@ -79,6 +77,13 @@ def testTraverseOpRegionBlockIterators(): # CHECK: OP 1: func.return walk_operations("", op) + # CHECK: Region iter: