[MLIR][TableGen] Fix ambiguous build methods when inferring result types.
authorRahul Joshi <jurahul@google.com>
Fri, 7 Aug 2020 21:02:19 +0000 (14:02 -0700)
committerRahul Joshi <jurahul@google.com>
Mon, 10 Aug 2020 17:05:06 +0000 (10:05 -0700)
- Fix ODS framework to suppress build methods that infer result types and are
  ambiguous with collective variants. This applies to operations with a single variadic
  inputs whose result types can be inferred.
- Extended OpBuildGenTest to test these kinds of ops.

Differential Revision: https://reviews.llvm.org/D85060

mlir/include/mlir/TableGen/Operator.h
mlir/lib/TableGen/Operator.cpp
mlir/test/lib/Dialect/Test/TestOps.td
mlir/test/mlir-tblgen/op-result.td
mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
mlir/unittests/TableGen/OpBuildGen.cpp

index 29d4caa..d7fac87 100644 (file)
@@ -151,6 +151,17 @@ public:
   // Returns the total number of arguments.
   int getNumArgs() const { return arguments.size(); }
 
+  // Returns true of the operation has a single variadic arg.
+  bool hasSingleVariadicArg() const;
+
+  // Returns true if the operation has a single variadic result.
+  bool hasSingleVariadicResult() const {
+    return getNumResults() == 1 && getResult(0).isVariadic();
+  }
+
+  // Returns true of the operation has no variadic regions.
+  bool hasNoVariadicRegions() const { return getNumVariadicRegions() == 0; }
+
   using arg_iterator = const Argument *;
   using arg_range = llvm::iterator_range<arg_iterator>;
 
index 3dd9245..9d39956 100644 (file)
@@ -134,6 +134,11 @@ unsigned tblgen::Operator::getNumVariableLengthOperands() const {
   });
 }
 
+bool tblgen::Operator::hasSingleVariadicArg() const {
+  return getNumArgs() == 1 && getArg(0).is<tblgen::NamedTypeConstraint *>() &&
+         getOperand(0).isVariadic();
+}
+
 tblgen::Operator::arg_iterator tblgen::Operator::arg_begin() const {
   return arguments.begin();
 }
index 742033b..c1bc754 100644 (file)
@@ -1526,4 +1526,31 @@ def TableGenBuildOp3 : TEST_Op<"tblgen_build_3", [SameVariadicResultSize]> {
   let results = (outs Variadic<AnyType>:$resultA, Variadic<AnyType>:$resultB);
 }
 
+// Single variadic arg, non variadic results, with SameOperandsAndResultType.
+// Tests suppression of ambiguious build methods for operations with
+// SameOperandsAndResultType trait.
+def TableGenBuildOp4 : TEST_Op<"tblgen_build_4", [SameOperandsAndResultType]> {
+  let arguments = (ins Variadic<AnyType>:$inputs);
+  let results = (outs AnyType:$result);
+}
+
+// Single variadic arg with SameOperandsAndResultType and InferTypeOpInterface.
+// Tests suppression of ambiguious build methods for operations with
+// SameOperandsAndResultType and InferTypeOpInterface.
+def TableGenBuildOp5 : TEST_Op<"tblgen_build_5",
+      [SameOperandsAndResultType, InferTypeOpInterface]> {
+  let arguments = (ins Variadic<AnyType>:$inputs);
+  let results = (outs AnyType:$result);
+
+  let extraClassDeclaration = [{
+    static LogicalResult inferReturnTypes(MLIRContext *, 
+          Optional<Location> location, ValueRange operands,
+          DictionaryAttr attributes, RegionRange regions,
+          SmallVectorImpl<Type> &inferredReturnTypes) {
+      inferredReturnTypes.assign({operands[0].getType()});
+      return success();
+    }
+   }];
+}
+
 #endif // TEST_OPS
index 4b091e4..bdb0765 100644 (file)
@@ -110,8 +110,8 @@ def OpK : NS_Op<"only_input_is_variadic_with_same_value_type_op", [SameOperandsA
   let results = (outs AnyTensor:$result);
 }
 
-// CHECK-LABEL: OpK::build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::ValueRange input)
-// CHECK: odsState.addTypes({input.front().getType()});
+// CHECK-LABEL: OpK::build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes )
+// CHECK: odsState.addTypes({operands[0].getType()});
 
 // Test with inferred shapes and interleaved with operands/attributes.
 //
index 989008d..9f00b80 100644 (file)
@@ -232,6 +232,10 @@ private:
   // operand's type as all results' types.
   void genUseOperandAsResultTypeCollectiveParamBuilder();
 
+  // Returns true if the inferred collective param build method should be
+  // generated.
+  bool shouldGenerateInferredTypeCollectiveParamBuilder();
+
   // Generates the build() method that takes aggregate operands/attributes
   // parameters. This build() method uses inferred types as result types.
   // Requires: The type needs to be inferable via InferTypeOpInterface.
@@ -984,40 +988,37 @@ void OpEmitter::genSeparateArgParamBuilder() {
   //          result
   //
   // In that case, skip generating such ambiguous build methods here.
-  bool hasSingleVariadicResult =
-      op.getNumResults() == 1 && op.getResult(0).isVariadic();
-
-  bool hasSingleVariadicArg =
-      op.getNumArgs() == 1 &&
-      op.getArg(0).is<tblgen::NamedTypeConstraint *>() &&
-      op.getOperand(0).isVariadic();
-  bool hasNoVariadicRegions = op.getNumVariadicRegions() == 0;
-
   for (auto attrType : attrBuilderType) {
     // Case 3b above.
-    if (!(hasNoVariadicRegions && hasSingleVariadicArg &&
-          hasSingleVariadicResult))
+    if (!(op.hasNoVariadicRegions() && op.hasSingleVariadicArg() &&
+          op.hasSingleVariadicResult()))
       emit(attrType, TypeParamKind::Separate, /*inferType=*/false);
-    if (canInferType(op))
-      emit(attrType, TypeParamKind::None, /*inferType=*/true);
+    if (canInferType(op)) {
+      // When inferType = true, the generated build method does not have
+      // result types. If the op has a single variadic arg, then this build
+      // method will be ambiguious with the collective inferred build method
+      // generated in `genInferredTypeCollectiveParamBuilder`. If we are going
+      // to generate that collective inferred method, suppress generating the
+      // ambiguious build method here.
+      bool buildMethodAmbiguious =
+          op.hasSingleVariadicArg() &&
+          shouldGenerateInferredTypeCollectiveParamBuilder();
+      if (!buildMethodAmbiguious)
+        emit(attrType, TypeParamKind::None, /*inferType=*/true);
+    }
     // The separate arg + collective param kind method will be:
     // (a) Same as the separate arg + separate param kind method if there is
     //     only one variadic result.
     // (b) Ambiguous with the collective params method under conditions in (3a)
     //     above.
     // In either case, skip generating such build method.
-    if (!hasSingleVariadicResult &&
-        !(hasNoVariadicRegions && hasSingleVariadicArg))
+    if (!op.hasSingleVariadicResult() &&
+        !(op.hasNoVariadicRegions() && op.hasSingleVariadicArg()))
       emit(attrType, TypeParamKind::Collective, /*inferType=*/false);
   }
 }
 
 void OpEmitter::genUseOperandAsResultTypeCollectiveParamBuilder() {
-  // If this op has a variadic result, we cannot generate this builder because
-  // we don't know how many results to create.
-  if (op.getNumVariableLengthResults() != 0)
-    return;
-
   int numResults = op.getNumResults();
 
   // Signature
@@ -1055,6 +1056,10 @@ void OpEmitter::genUseOperandAsResultTypeCollectiveParamBuilder() {
        << llvm::join(resultTypes, ", ") << "});\n\n";
 }
 
+bool OpEmitter::shouldGenerateInferredTypeCollectiveParamBuilder() {
+  return canInferType(op) && op.getNumSuccessors() == 0;
+}
+
 void OpEmitter::genInferredTypeCollectiveParamBuilder() {
   // TODO: Expand to support regions.
   std::string params =
@@ -1209,8 +1214,21 @@ void OpEmitter::genBuilder() {
   //    to facilitate different call patterns.
   if (op.getNumVariableLengthResults() == 0) {
     if (op.getTrait("OpTrait::SameOperandsAndResultType")) {
-      genUseOperandAsResultTypeSeparateParamBuilder();
-      genUseOperandAsResultTypeCollectiveParamBuilder();
+      // If the operation has a single variadic input, then the build method
+      // generated by `genUseOperandAsResultTypeSeparateParamBuilder` will be
+      // ambiguious with the one generated by
+      // `genUseOperandAsResultTypeCollectiveParamBuilder` (they both will have
+      // a single `ValueRange` argument for operands, and the collective one
+      // will have a `ArrayRef<NamedAttribute>` argument initalized to empty).
+      // Suppress such ambiguious build method.
+      if (!op.hasSingleVariadicArg())
+        genUseOperandAsResultTypeSeparateParamBuilder();
+
+      // The build method generated by the inferred type collective param
+      // builder and one generated here have the same arguments and hence
+      // generating both will be ambiguious. Enable just one of them.
+      if (!shouldGenerateInferredTypeCollectiveParamBuilder())
+        genUseOperandAsResultTypeCollectiveParamBuilder();
     }
     if (op.getTrait("OpTrait::FirstAttrDerivedResultType"))
       genUseAttrAsResultTypeBuilder();
@@ -1269,7 +1287,7 @@ void OpEmitter::genCollectiveParamBuilder() {
 
   // Generate builder that infers type too.
   // TODO: Expand to handle regions and successors.
-  if (canInferType(op) && op.getNumSuccessors() == 0)
+  if (shouldGenerateInferredTypeCollectiveParamBuilder())
     genInferredTypeCollectiveParamBuilder();
 }
 
index e90f96b..3e3256e 100644 (file)
@@ -63,6 +63,28 @@ protected:
     concreteOp.erase();
   }
 
+  // Helper method to test ops with inferred result types and single variadic
+  // input.
+  template <typename OpTy>
+  void testSingleVariadicInputInferredType() {
+    // Test separate arg, separate param build method.
+    auto op = builder.create<OpTy>(loc, i32Ty, ArrayRef<Value>{cstI32, cstI32});
+    verifyOp(std::move(op), {i32Ty}, {cstI32, cstI32}, noAttrs);
+
+    // Test collective params build method.
+    op = builder.create<OpTy>(loc, ArrayRef<Type>{i32Ty},
+                              ArrayRef<Value>{cstI32, cstI32});
+    verifyOp(std::move(op), {i32Ty}, {cstI32, cstI32}, noAttrs);
+
+    // Test build method with no result types, default value of attributes.
+    op = builder.create<OpTy>(loc, ArrayRef<Value>{cstI32, cstI32});
+    verifyOp(std::move(op), {i32Ty}, {cstI32, cstI32}, noAttrs);
+
+    // Test build method with no result types and supplied attributes.
+    op = builder.create<OpTy>(loc, ArrayRef<Value>{cstI32, cstI32}, attrs);
+    verifyOp(std::move(op), {i32Ty}, {cstI32, cstI32}, attrs);
+  }
+
 protected:
   MLIRContext ctx;
   OpBuilder builder;
@@ -178,4 +200,19 @@ TEST_F(OpBuildGenTest,
   verifyOp(std::move(op), {i32Ty, f32Ty}, {cstI32}, attrs);
 }
 
+// The next 2 tests test supression of ambiguious build methods for ops that
+// have a single variadic input, and single non-variadic result, and which
+// support the SameOperandsAndResultType trait and and optionally the
+// InferOpTypeInterface interface. For such ops, the ODS framework generates
+// build methods with no result types as they are inferred from the input types.
+TEST_F(OpBuildGenTest, BuildMethodsSameOperandsAndResultTypeSuppression) {
+  testSingleVariadicInputInferredType<TableGenBuildOp4>();
+}
+
+TEST_F(
+    OpBuildGenTest,
+    BuildMethodsSameOperandsAndResultTypeAndInferOpTypeInterfaceSuppression) {
+  testSingleVariadicInputInferredType<TableGenBuildOp5>();
+}
+
 } // namespace mlir