[mlir][ods] Handle DeclareOpInterfaceMethods in formatgen
authorJacques Pienaar <jpienaar@google.com>
Tue, 4 Jan 2022 16:28:59 +0000 (08:28 -0800)
committerJacques Pienaar <jpienaar@google.com>
Tue, 4 Jan 2022 16:28:59 +0000 (08:28 -0800)
Previously it would not consider ops with
DeclareOpInterfaceMethods<InferTypeOpInterface> as having the
InferTypeOpInterface interfaces added. The OpInterface nested inside
DeclareOpInterfaceMethods is not retained so that one could query it, so
check for the the C++ class directly (a bit raw/low level - will be
addressed in follow up).

Differential Revision: https://reviews.llvm.org/D116572

mlir/test/lib/Dialect/Test/TestDialect.cpp
mlir/test/lib/Dialect/Test/TestOps.td
mlir/test/mlir-tblgen/op-format.mlir
mlir/tools/mlir-tblgen/OpFormatGen.cpp

index aee0bdb..4418178 100644 (file)
@@ -264,6 +264,15 @@ Operation *TestDialect::materializeConstant(OpBuilder &builder, Attribute value,
   return builder.create<TestOpConstant>(loc, type, value);
 }
 
+::mlir::LogicalResult FormatInferType2Op::inferReturnTypes(
+    ::mlir::MLIRContext *context, ::llvm::Optional<::mlir::Location> location,
+    ::mlir::ValueRange operands, ::mlir::DictionaryAttr attributes,
+    ::mlir::RegionRange regions,
+    ::llvm::SmallVectorImpl<::mlir::Type> &inferredReturnTypes) {
+  inferredReturnTypes.assign({::mlir::IntegerType::get(context, 16)});
+  return ::mlir::success();
+}
+
 void *TestDialect::getRegisteredInterfaceForOp(TypeID typeID,
                                                OperationName opName) {
   if (opName.getIdentifier() == "test.unregistered_side_effect_op" &&
index 39f0b0b..6fad11b 100644 (file)
@@ -2139,6 +2139,12 @@ def FormatInferTypeOp : TEST_Op<"format_infer_type", [InferTypeOpInterface]> {
    }];
 }
 
+// Check that formatget supports DeclareOpInterfaceMethods.
+def FormatInferType2Op : TEST_Op<"format_infer_type2", [DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
+  let results = (outs AnyType);
+  let assemblyFormat = "attr-dict";
+}
+
 // Base class for testing mixing allOperandTypes, allOperands, and
 // inferResultTypes.
 class FormatInferAllTypesBaseOp<string mnemonic, list<OpTrait> traits = []>
index 152cd0a..77afc41 100644 (file)
@@ -409,7 +409,10 @@ test.format_infer_variadic_type_from_non_variadic %i64, %i64 : i64
 //===----------------------------------------------------------------------===//
 
 // CHECK: test.format_infer_type
-%ignored_res7 = test.format_infer_type
+%ignored_res7a = test.format_infer_type
+
+// CHECK: test.format_infer_type2
+%ignored_res7b = test.format_infer_type2
 
 // CHECK: test.format_infer_type_all_operands_and_types(%[[I64]], %[[I32]]) : i64, i32
 %ignored_res8:2 = test.format_infer_type_all_operands_and_types(%i64, %i32) : i64, i32
index 02d0e81..b521803 100644 (file)
@@ -2345,9 +2345,16 @@ LogicalResult FormatParser::parse() {
       handleSameTypesConstraint(variableTyResolver, /*includeResults=*/true);
     } else if (def.isSubClassOf("TypesMatchWith")) {
       handleTypesMatchConstraint(variableTyResolver, def);
-    } else if (def.getName() == "InferTypeOpInterface" &&
-               !op.allResultTypesKnown()) {
-      canInferResultTypes = true;
+    } else if (!op.allResultTypesKnown()) {
+      // This doesn't check the name directly to handle
+      //    DeclareOpInterfaceMethods<InferTypeOpInterface>
+      // and the like.
+      // TODO: Add hasCppInterface check.
+      if (auto name = def.getValueAsOptionalString("cppClassName")) {
+        if (*name == "InferTypeOpInterface" &&
+            def.getValueAsString("cppNamespace") == "::mlir")
+          canInferResultTypes = true;
+      }
     }
   }