[mlir][py] Fix infer return type invocation for variadics
authorJacques Pienaar <jpienaar@google.com>
Thu, 2 Feb 2023 20:23:46 +0000 (12:23 -0800)
committerJacques Pienaar <jpienaar@google.com>
Tue, 7 Feb 2023 01:01:53 +0000 (17:01 -0800)
Previously we only allowed the flattened list passed in, but the same
input provided here as to buildGeneric so flatten accordingly. We have
less info here than in buildGeneric so the error is more generic if
unpacking fails.

Differential Revision: https://reviews.llvm.org/D143240

mlir/lib/Bindings/Python/IRInterfaces.cpp
mlir/lib/Interfaces/ViewLikeInterface.cpp
mlir/python/mlir/_mlir_libs/_mlir/ir.pyi
mlir/test/python/dialects/tensor.py

index fed8a50..b917bf0 100644 (file)
@@ -12,6 +12,7 @@
 #include "IRModule.h"
 #include "mlir-c/BuiltinAttributes.h"
 #include "mlir-c/Interfaces.h"
+#include "llvm/ADT/STLExtras.h"
 
 namespace py = pybind11;
 
@@ -183,9 +184,9 @@ public:
   }
 
   /// Given the arguments required to build an operation, attempts to infer its
-  /// return types. Throws value_error on faliure.
+  /// return types. Throws value_error on failure.
   std::vector<PyType>
-  inferReturnTypes(std::optional<std::vector<PyValue>> operands,
+  inferReturnTypes(std::optional<py::list> operandList,
                    std::optional<PyAttribute> attributes,
                    std::optional<std::vector<PyRegion>> regions,
                    DefaultingPyMlirContext context,
@@ -193,10 +194,45 @@ public:
     llvm::SmallVector<MlirValue> mlirOperands;
     llvm::SmallVector<MlirRegion> mlirRegions;
 
-    if (operands) {
-      mlirOperands.reserve(operands->size());
-      for (PyValue &value : *operands) {
-        mlirOperands.push_back(value);
+    if (operandList && !operandList->empty()) {
+      // Note: as the list may contain other lists this may not be final size.
+      mlirOperands.reserve(operandList->size());
+      for (const auto& it : llvm::enumerate(*operandList)) {
+        PyValue* val;
+        try {
+          val = py::cast<PyValue *>(it.value());
+          if (!val)
+            throw py::cast_error();
+          mlirOperands.push_back(val->get());
+          continue;
+        } catch (py::cast_error &err) {
+        }
+
+        try {
+          auto vals = py::cast<py::sequence>(it.value());
+          for (py::object v : vals) {
+            try {
+              val = py::cast<PyValue *>(v);
+              if (!val)
+                throw py::cast_error();
+              mlirOperands.push_back(val->get());
+            } catch (py::cast_error &err) {
+              throw py::value_error(
+                  (llvm::Twine("Operand ") + llvm::Twine(it.index()) +
+                   " must be a Value or Sequence of Values (" + err.what() +
+                   ")")
+                      .str());
+            }
+          }
+          continue;
+        } catch (py::cast_error &err) {
+          throw py::value_error(
+              (llvm::Twine("Operand ") + llvm::Twine(it.index()) +
+               " must be a Value or Sequence of Values (" + err.what() + ")")
+                  .str());
+        }
+
+        throw py::cast_error();
       }
     }
 
index d3c8fde..9d30f27 100644 (file)
@@ -24,8 +24,8 @@ LogicalResult mlir::verifyListOfOperandsOrIntegers(Operation *op,
                                                    ValueRange values) {
   // Check static and dynamic offsets/sizes/strides does not overflow type.
   if (staticVals.size() != numElements)
-    return op->emitError("expected ")
-           << numElements << " " << name << " values";
+    return op->emitError("expected ") << numElements << " " << name
+                                      << " values, got " << staticVals.size();
   unsigned expectedNumDynamicEntries =
       llvm::count_if(staticVals, [&](int64_t staticVal) {
         return ShapedType::isDynamic(staticVal);
index 505946c..63a3125 100644 (file)
@@ -667,7 +667,7 @@ class IndexType(Type):
 
 class InferTypeOpInterface:
     def __init__(self, object: object, context: Optional[Context] = None) -> None: ...
-    def inferReturnTypes(self, operands: Optional[List[Value]] = None, attributes: Optional[Attribute] = None, regions: Optional[List[Region]] = None, context: Optional[Context] = None, loc: Optional[Location] = None) -> List[Type]: ...
+    def inferReturnTypes(self, operands: Optional[List] = None, attributes: Optional[Attribute] = None, regions: Optional[List[Region]] = None, context: Optional[Context] = None, loc: Optional[Location] = None) -> List[Type]: ...
     @property
     def operation(self) -> Operation: ...
     @property
index f7f73a1..d8ea426 100644 (file)
@@ -74,3 +74,30 @@ def testEmptyOp():
         return tensor.EmptyOp([], f32)
 
   print(module)
+
+
+# CHECK-LABEL: TEST: testInferTypesInsertSlice
+@run
+def testInferTypesInsertSlice():
+  with Context() as ctx, Location.unknown():
+    module = Module.create()
+    f32Type = F32Type.get()
+    indexType = IndexType.get()
+    with InsertionPoint(module.body):
+
+      @func.FuncOp.from_py_func(
+          RankedTensorType.get((1, 1), f32Type),
+          RankedTensorType.get((1, 1), f32Type))
+      # CHECK: func @f
+      # CHECK:      tensor.insert_slice %arg0 into %arg1[0, 0] [1, 1] [0, 0] :
+      # CHECK-SAME:   tensor<1x1xf32> into tensor<1x1xf32>
+      def f(source, dest):
+        c0 = arith.ConstantOp(indexType, 0)
+        c1 = arith.ConstantOp(indexType, 1)
+        d0 = tensor.InsertSliceOp(source, dest, [], [], [],
+                                  DenseI64ArrayAttr.get([0, 0]),
+                                  DenseI64ArrayAttr.get([1, 1]),
+                                  DenseI64ArrayAttr.get([0, 0]))
+        return [d0.result]
+
+  print(module)