[MLIR][python bindings] Fix inferReturnTypes + AttrSizedOperandSegments for optional...
authormax <maksim.levental@gmail.com>
Fri, 26 May 2023 19:39:03 +0000 (14:39 -0500)
committermax <maksim.levental@gmail.com>
Fri, 26 May 2023 19:50:51 +0000 (14:50 -0500)
Right now `inferTypeOpInterface.inferReturnTypes` fails because there's a cast in there to `py::sequence` which throws a `TypeError` when it tries to cast the `None`s. Note `None`s are inserted into `operands` for omitted operands passed to the generated builder:

```
    operands.append(_get_op_result_or_value(start) if start is not None else None)
    operands.append(_get_op_result_or_value(stop) if stop is not None else None)
    operands.append(_get_op_result_or_value(step) if step is not None else None)
```

Note also that skipping appending to the list operands doesn't work either because [[ https://github.com/llvm/llvm-project/blob/27c37327da67020f938aabf0f6405f57d688441e/mlir/lib/Bindings/Python/IRCore.cpp#L1585 | build generic ]] checks against the number of operand segments expected.

Currently the only way around is to handroll through `ir.Operation.create`.

Reviewed By: rkayaith

Differential Revision: https://reviews.llvm.org/D151409

mlir/lib/Bindings/Python/IRInterfaces.cpp
mlir/test/python/dialects/python_test.py
mlir/test/python/python_test_ops.td

index 25fcacc..dd41900 100644 (file)
@@ -53,6 +53,9 @@ llvm::SmallVector<MlirValue> wrapOperands(std::optional<py::list> operandList) {
   // Note: as the list may contain other lists this may not be final size.
   mlirOperands.reserve(operandList->size());
   for (const auto &&it : llvm::enumerate(*operandList)) {
+    if (it.value().is_none())
+      continue;
+
     PyValue *val;
     try {
       val = py::cast<PyValue *>(it.value());
index 5346955..37e508f 100644 (file)
@@ -4,6 +4,7 @@ from mlir.ir import *
 import mlir.dialects.func as func
 import mlir.dialects.python_test as test
 import mlir.dialects.tensor as tensor
+import mlir.dialects.arith as arith
 
 
 def run(f):
@@ -467,3 +468,22 @@ def testCustomTypeTypeCaster():
         print(d.type)
         # CHECK: TestIntegerRankedTensorType(tensor<10x10xi5>)
         print(repr(d.type))
+
+
+# CHECK-LABEL: TEST: testInferTypeOpInterface
+@run
+def testInferTypeOpInterface():
+    with Context() as ctx, Location.unknown(ctx):
+        test.register_python_test_dialect(ctx)
+        module = Module.create()
+        with InsertionPoint(module.body):
+            i64 = IntegerType.get_signless(64)
+            zero = arith.ConstantOp(i64, 0)
+
+            one_operand = test.InferResultsVariadicInputsOp(single=zero, doubled=None)
+            # CHECK: i32
+            print(one_operand.result.type)
+
+            two_operands = test.InferResultsVariadicInputsOp(single=zero, doubled=zero)
+            # CHECK: f32
+            print(two_operands.result.type)
index 21bb95d..2fc78cb 100644 (file)
@@ -101,6 +101,31 @@ def InferResultsOp : TestOp<"infer_results_op", [InferTypeOpInterface]> {
   }];
 }
 
+def I32OrF32 : TypeConstraint<Or<[I32.predicate, F32.predicate]>,
+                                 "i32 or f32">;
+
+def InferResultsVariadicInputsOp : TestOp<"infer_results_variadic_inputs_op",
+    [InferTypeOpInterface, AttrSizedOperandSegments]> {
+  let arguments = (ins Optional<I64>:$single, Optional<I64>:$doubled);
+  let results = (outs I32OrF32:$res);
+
+  let extraClassDeclaration = [{
+    static ::mlir::LogicalResult inferReturnTypes(
+      ::mlir::MLIRContext *context, ::std::optional<::mlir::Location> location,
+      ::mlir::ValueRange operands, ::mlir::DictionaryAttr attributes,
+      ::mlir::OpaqueProperties,
+      ::mlir::RegionRange regions,
+      ::llvm::SmallVectorImpl<::mlir::Type> &inferredReturnTypes) {
+      ::mlir::Builder b(context);
+      if (operands.size() == 1)
+          inferredReturnTypes.push_back(b.getI32Type());
+      else if (operands.size() == 2)
+          inferredReturnTypes.push_back(b.getF32Type());
+      return ::mlir::success();
+    }
+  }];
+}
+
 // If all result types are buildable, the InferTypeOpInterface is implied and is
 // autogenerated by C++ ODS.
 def InferResultsImpliedOp : TestOp<"infer_results_implied_op"> {