From d747a170a47dce64aefb906211e8aaed0c6bd6f6 Mon Sep 17 00:00:00 2001 From: John Demme Date: Tue, 9 Aug 2022 19:37:04 -0700 Subject: [PATCH] [MLIR] [Python] Fix `Value.owner` to handle BlockArgs Previously, calling `Value.owner()` would C++ assert in debug builds if `Value` was a block argument. Additionally, the behavior was just wrong in release builds. This patch adds support for BlockArg Values. --- mlir/lib/Bindings/Python/IRCore.cpp | 21 ++++++++++++++++----- mlir/test/python/ir/value.py | 15 +++++++++++++++ 2 files changed, 31 insertions(+), 5 deletions(-) diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp index beb0c6c..db199b3 100644 --- a/mlir/lib/Bindings/Python/IRCore.cpp +++ b/mlir/lib/Bindings/Python/IRCore.cpp @@ -3118,11 +3118,22 @@ void mlir::python::populateIRCore(py::module &m) { .def_property_readonly( "owner", [](PyValue &self) { - assert(mlirOperationEqual(self.getParentOperation()->get(), - mlirOpResultGetOwner(self.get())) && - "expected the owner of the value in Python to match that in " - "the IR"); - return self.getParentOperation().getObject(); + MlirValue v = self.get(); + if (mlirValueIsAOpResult(v)) { + assert( + mlirOperationEqual(self.getParentOperation()->get(), + mlirOpResultGetOwner(self.get())) && + "expected the owner of the value in Python to match that in " + "the IR"); + return self.getParentOperation().getObject(); + } + + if (mlirValueIsABlockArgument(v)) { + MlirBlock block = mlirBlockArgumentGetOwner(self.get()); + return py::cast(PyBlock(self.getParentOperation(), block)); + } + + assert(false && "Value must be a block argument or an op result"); }) .def("__eq__", [](PyValue &self, PyValue &other) { diff --git a/mlir/test/python/ir/value.py b/mlir/test/python/ir/value.py index 4eebc53..262896e 100644 --- a/mlir/test/python/ir/value.py +++ b/mlir/test/python/ir/value.py @@ -37,6 +37,21 @@ def testOpResultOwner(): assert op.result.owner == op +# CHECK-LABEL: TEST: testBlockArgOwner +@run +def testBlockArgOwner(): + ctx = Context() + ctx.allow_unregistered_dialects = True + module = Module.parse( + r""" + func.func @foo(%arg0: f32) { + return + }""", ctx) + func = module.body.operations[0] + block = func.regions[0].blocks[0] + assert block.arguments[0].owner == block + + # CHECK-LABEL: TEST: testValueIsInstance @run def testValueIsInstance(): -- 2.7.4