[mlir][python] Fix issues with block argument slices
authorAlex Zinenko <zinenko@google.com>
Thu, 21 Jul 2022 14:00:37 +0000 (14:00 +0000)
committerAlex Zinenko <zinenko@google.com>
Thu, 21 Jul 2022 14:41:12 +0000 (14:41 +0000)
The type extraction helper function for block argument and op result
list objects was ignoring the slice entirely. So was the slice addition.
Both are caused by a misleading naming convention to implement slices
via CRTP. Make the convention more explicit and hide the helper
functions so users have harder time calling them directly.

Closes #56540.

Reviewed By: stellaraccident

Differential Revision: https://reviews.llvm.org/D130271

mlir/lib/Bindings/Python/IRAffine.cpp
mlir/lib/Bindings/Python/IRCore.cpp
mlir/lib/Bindings/Python/PybindUtils.h
mlir/test/python/ir/operation.py

index 0da936e..fc7133b 100644 (file)
@@ -385,9 +385,13 @@ public:
                   step),
         affineMap(map) {}
 
-  intptr_t getNumElements() { return mlirAffineMapGetNumResults(affineMap); }
+private:
+  /// Give the parent CRTP class access to hook implementations below.
+  friend class Sliceable<PyAffineMapExprList, PyAffineExpr>;
+
+  intptr_t getRawNumElements() { return mlirAffineMapGetNumResults(affineMap); }
 
-  PyAffineExpr getElement(intptr_t pos) {
+  PyAffineExpr getRawElement(intptr_t pos) {
     return PyAffineExpr(affineMap.getContext(),
                         mlirAffineMapGetResult(affineMap, pos));
   }
@@ -397,7 +401,6 @@ public:
     return PyAffineMapExprList(affineMap, startIndex, length, step);
   }
 
-private:
   PyAffineMap affineMap;
 };
 } // namespace
@@ -460,9 +463,13 @@ public:
                   step),
         set(set) {}
 
-  intptr_t getNumElements() { return mlirIntegerSetGetNumConstraints(set); }
+private:
+  /// Give the parent CRTP class access to hook implementations below.
+  friend class Sliceable<PyIntegerSetConstraintList, PyIntegerSetConstraint>;
+
+  intptr_t getRawNumElements() { return mlirIntegerSetGetNumConstraints(set); }
 
-  PyIntegerSetConstraint getElement(intptr_t pos) {
+  PyIntegerSetConstraint getRawElement(intptr_t pos) {
     return PyIntegerSetConstraint(set, pos);
   }
 
@@ -471,7 +478,6 @@ public:
     return PyIntegerSetConstraintList(set, startIndex, length, step);
   }
 
-private:
   PyIntegerSet set;
 };
 } // namespace
index 9738351..fea26ec 100644 (file)
@@ -1968,8 +1968,8 @@ template <typename Container>
 static std::vector<PyType> getValueTypes(Container &container,
                                          PyMlirContextRef &context) {
   std::vector<PyType> result;
-  result.reserve(container.getNumElements());
-  for (int i = 0, e = container.getNumElements(); i < e; ++i) {
+  result.reserve(container.size());
+  for (int i = 0, e = container.size(); i < e; ++i) {
     result.push_back(
         PyType(context, mlirValueGetType(container.getElement(i).get())));
   }
@@ -1993,14 +1993,24 @@ public:
                   step),
         operation(std::move(operation)), block(block) {}
 
+  static void bindDerived(ClassTy &c) {
+    c.def_property_readonly("types", [](PyBlockArgumentList &self) {
+      return getValueTypes(self, self.operation->getContext());
+    });
+  }
+
+private:
+  /// Give the parent CRTP class access to hook implementations below.
+  friend class Sliceable<PyBlockArgumentList, PyBlockArgument>;
+
   /// Returns the number of arguments in the list.
-  intptr_t getNumElements() {
+  intptr_t getRawNumElements() {
     operation->checkValid();
     return mlirBlockGetNumArguments(block);
   }
 
-  /// Returns `pos`-the element in the list. Asserts on out-of-bounds.
-  PyBlockArgument getElement(intptr_t pos) {
+  /// Returns `pos`-the element in the list.
+  PyBlockArgument getRawElement(intptr_t pos) {
     MlirValue argument = mlirBlockGetArgument(block, pos);
     return PyBlockArgument(operation, argument);
   }
@@ -2011,13 +2021,6 @@ public:
     return PyBlockArgumentList(operation, block, startIndex, length, step);
   }
 
-  static void bindDerived(ClassTy &c) {
-    c.def_property_readonly("types", [](PyBlockArgumentList &self) {
-      return getValueTypes(self, self.operation->getContext());
-    });
-  }
-
-private:
   PyOperationRef operation;
   MlirBlock block;
 };
@@ -2038,12 +2041,25 @@ public:
                   step),
         operation(operation) {}
 
-  intptr_t getNumElements() {
+  void dunderSetItem(intptr_t index, PyValue value) {
+    index = wrapIndex(index);
+    mlirOperationSetOperand(operation->get(), index, value.get());
+  }
+
+  static void bindDerived(ClassTy &c) {
+    c.def("__setitem__", &PyOpOperandList::dunderSetItem);
+  }
+
+private:
+  /// Give the parent CRTP class access to hook implementations below.
+  friend class Sliceable<PyOpOperandList, PyValue>;
+
+  intptr_t getRawNumElements() {
     operation->checkValid();
     return mlirOperationGetNumOperands(operation->get());
   }
 
-  PyValue getElement(intptr_t pos) {
+  PyValue getRawElement(intptr_t pos) {
     MlirValue operand = mlirOperationGetOperand(operation->get(), pos);
     MlirOperation owner;
     if (mlirValueIsAOpResult(operand))
@@ -2061,16 +2077,6 @@ public:
     return PyOpOperandList(operation, startIndex, length, step);
   }
 
-  void dunderSetItem(intptr_t index, PyValue value) {
-    index = wrapIndex(index);
-    mlirOperationSetOperand(operation->get(), index, value.get());
-  }
-
-  static void bindDerived(ClassTy &c) {
-    c.def("__setitem__", &PyOpOperandList::dunderSetItem);
-  }
-
-private:
   PyOperationRef operation;
 };
 
@@ -2090,12 +2096,22 @@ public:
                   step),
         operation(operation) {}
 
-  intptr_t getNumElements() {
+  static void bindDerived(ClassTy &c) {
+    c.def_property_readonly("types", [](PyOpResultList &self) {
+      return getValueTypes(self, self.operation->getContext());
+    });
+  }
+
+private:
+  /// Give the parent CRTP class access to hook implementations below.
+  friend class Sliceable<PyOpResultList, PyOpResult>;
+
+  intptr_t getRawNumElements() {
     operation->checkValid();
     return mlirOperationGetNumResults(operation->get());
   }
 
-  PyOpResult getElement(intptr_t index) {
+  PyOpResult getRawElement(intptr_t index) {
     PyValue value(operation, mlirOperationGetResult(operation->get(), index));
     return PyOpResult(value);
   }
@@ -2104,13 +2120,6 @@ public:
     return PyOpResultList(operation, startIndex, length, step);
   }
 
-  static void bindDerived(ClassTy &c) {
-    c.def_property_readonly("types", [](PyOpResultList &self) {
-      return getValueTypes(self, self.operation->getContext());
-    });
-  }
-
-private:
   PyOperationRef operation;
 };
 
index e791ba8..5356cbd 100644 (file)
@@ -199,15 +199,17 @@ private:
 /// A derived class must provide the following:
 ///   - a `static const char *pyClassName ` field containing the name of the
 ///     Python class to bind;
-///   - an instance method `intptr_t getNumElements()` that returns the number
+///   - an instance method `intptr_t getRawNumElements()` that returns the
+///   number
 ///     of elements in the backing container (NOT that of the slice);
-///   - an instance method `ElementTy getElement(intptr_t)` that returns a
-///     single element at the given index.
+///   - an instance method `ElementTy getRawElement(intptr_t)` that returns a
+///     single element at the given linear index (NOT slice index);
 ///   - an instance method `Derived slice(intptr_t, intptr_t, intptr_t)` that
 ///     constructs a new instance of the derived pseudo-container with the
 ///     given slice parameters (to be forwarded to the Sliceable constructor).
 ///
-/// The getNumElements() and getElement(intptr_t) callbacks must not throw.
+/// The getRawNumElements() and getRawElement(intptr_t) callbacks must not
+/// throw.
 ///
 /// A derived class may additionally define:
 ///   - a `static void bindDerived(ClassTy &)` method to bind additional methods
@@ -217,8 +219,8 @@ class Sliceable {
 protected:
   using ClassTy = pybind11::class_<Derived>;
 
-  // Transforms `index` into a legal value to access the underlying sequence.
-  // Returns <0 on failure.
+  /// Transforms `index` into a legal value to access the underlying sequence.
+  /// Returns <0 on failure.
   intptr_t wrapIndex(intptr_t index) {
     if (index < 0)
       index = length + index;
@@ -227,6 +229,15 @@ protected:
     return index;
   }
 
+  /// Computes the linear index given the current slice properties.
+  intptr_t linearizeIndex(intptr_t index) {
+    intptr_t linearIndex = index * step + startIndex;
+    assert(linearIndex >= 0 &&
+           linearIndex < static_cast<Derived *>(this)->getRawNumElements() &&
+           "linear index out of bounds, the slice is ill-formed");
+    return linearIndex;
+  }
+
   /// Returns the element at the given slice index. Supports negative indices
   /// by taking elements in inverse order. Returns a nullptr object if out
   /// of bounds.
@@ -238,13 +249,8 @@ protected:
       return {};
     }
 
-    // Compute the linear index given the current slice properties.
-    int linearIndex = index * step + startIndex;
-    assert(linearIndex >= 0 &&
-           linearIndex < static_cast<Derived *>(this)->getNumElements() &&
-           "linear index out of bounds, the slice is ill-formed");
     return pybind11::cast(
-        static_cast<Derived *>(this)->getElement(linearIndex));
+        static_cast<Derived *>(this)->getRawElement(linearizeIndex(index)));
   }
 
   /// Returns a new instance of the pseudo-container restricted to the given
@@ -266,6 +272,21 @@ public:
     assert(length >= 0 && "expected non-negative slice length");
   }
 
+  /// Returns the `index`-th element in the slice, supports negative indices.
+  /// Throws if the index is out of bounds.
+  ElementTy getElement(intptr_t index) {
+    // Negative indices mean we count from the end.
+    index = wrapIndex(index);
+    if (index < 0) {
+      throw pybind11::index_error("index out of range");
+    }
+
+    return static_cast<Derived *>(this)->getRawElement(linearizeIndex(index));
+  }
+
+  /// Returns the size of slice.
+  intptr_t size() { return length; }
+
   /// Returns a new vector (mapped to Python list) containing elements from two
   /// slices. The new vector is necessary because slices may not be contiguous
   /// or even come from the same original sequence.
@@ -276,7 +297,7 @@ public:
       elements.push_back(static_cast<Derived *>(this)->getElement(i));
     }
     for (intptr_t i = 0; i < other.length; ++i) {
-      elements.push_back(static_cast<Derived *>(this)->getElement(i));
+      elements.push_back(static_cast<Derived *>(&other)->getElement(i));
     }
     return elements;
   }
index b7b47f8..2d70b88 100644 (file)
@@ -185,6 +185,19 @@ def testBlockArgumentList():
     for t in entry_block.arguments.types:
       print("Type: ", t)
 
+    # Check that slicing and type access compose.
+    # CHECK: Sliced type: i16
+    # CHECK: Sliced type: i24
+    for t in entry_block.arguments[1:].types:
+      print("Sliced type: ", t)
+
+    # Check that slice addition works as expected.
+    # CHECK: Argument 2, type i24
+    # CHECK: Argument 0, type i8
+    restructured = entry_block.arguments[-1:] + entry_block.arguments[:1]
+    for arg in restructured:
+      print(f"Argument {arg.arg_number}, type {arg.type}")
+
 
 # CHECK-LABEL: TEST: testOperationOperands
 @run