[mlir] Use PyValue instead of PyOpResult in Python operand container
authorAlex Zinenko <zinenko@google.com>
Fri, 6 Nov 2020 10:00:13 +0000 (11:00 +0100)
committerAlex Zinenko <zinenko@google.com>
Fri, 6 Nov 2020 18:02:35 +0000 (19:02 +0100)
The PyOpOperands container was erroneously constructing objects for
individual operands as PyOpResult. Operands in fact are just values,
which may or may not be results of another operation. The code would
eventually crash if the operand was a block argument. Add a test that
exercises the behavior that previously led to crashes.

Reviewed By: stellaraccident

Differential Revision: https://reviews.llvm.org/D90917

mlir/lib/Bindings/Python/IRModules.cpp
mlir/test/Bindings/Python/ir_operation.py

index 5a3e966..cf71cc3 100644 (file)
@@ -1228,13 +1228,12 @@ public:
   }
 
   /// Returns `index`-th element in the result list.
-  PyOpResult dunderGetItem(intptr_t index) {
+  PyValue dunderGetItem(intptr_t index) {
     if (index < 0 || index >= dunderLen()) {
       throw SetPyError(PyExc_IndexError,
                        "attempt to access out of bounds region");
     }
-    PyValue value(operation, mlirOperationGetOperand(operation->get(), index));
-    return PyOpResult(value);
+    return PyValue(operation, mlirOperationGetOperand(operation->get(), index));
   }
 
   /// Defines a Python class in the bindings.
index d5c5b3f..0ce7cee 100644 (file)
@@ -132,6 +132,29 @@ def testBlockArgumentList():
 run(testBlockArgumentList)
 
 
+# CHECK-LABEL: TEST: testOperationOperands
+def testOperationOperands():
+  with Context() as ctx:
+    ctx.allow_unregistered_dialects = True
+    module = Module.parse(r"""
+      func @f1(%arg0: i32) {
+        %0 = "test.producer"() : () -> i64
+        "test.consumer"(%arg0, %0) : (i32, i64) -> ()
+        return
+      }""")
+    func = module.body.operations[0]
+    entry_block = func.regions[0].blocks[0]
+    consumer = entry_block.operations[1]
+    assert len(consumer.operands) == 2
+    # CHECK: Operand 0, type i32
+    # CHECK: Operand 1, type i64
+    for i, operand in enumerate(consumer.operands):
+      print(f"Operand {i}, type {operand.type}")
+
+
+run(testOperationOperands)
+
+
 # CHECK-LABEL: TEST: testDetachedOperation
 def testDetachedOperation():
   ctx = Context()