[mlir][ods] Look through OpVariable for type constraint
authorJacques Pienaar <jpienaar@google.com>
Thu, 18 Jun 2020 19:51:51 +0000 (12:51 -0700)
committerJacques Pienaar <jpienaar@google.com>
Thu, 18 Jun 2020 19:51:51 +0000 (12:51 -0700)
If one uses an OpVariable (such as via Res) then the result type constraint
should be returned.

Differential Revision: https://reviews.llvm.org/D82119

mlir/lib/TableGen/Constraint.cpp
mlir/test/mlir-tblgen/op-decl.td

index 98bb7d6..b8e1b9f 100644 (file)
@@ -17,21 +17,28 @@ using namespace mlir::tblgen;
 
 Constraint::Constraint(const llvm::Record *record)
     : def(record), kind(CK_Uncategorized) {
-  if (record->isSubClassOf("TypeConstraint")) {
+  // Look through OpVariable's to their constraint.
+  if (def->isSubClassOf("OpVariable"))
+    def = def->getValueAsDef("constraint");
+  if (def->isSubClassOf("TypeConstraint")) {
     kind = CK_Type;
-  } else if (record->isSubClassOf("AttrConstraint")) {
+  } else if (def->isSubClassOf("AttrConstraint")) {
     kind = CK_Attr;
-  } else if (record->isSubClassOf("RegionConstraint")) {
+  } else if (def->isSubClassOf("RegionConstraint")) {
     kind = CK_Region;
-  } else if (record->isSubClassOf("SuccessorConstraint")) {
+  } else if (def->isSubClassOf("SuccessorConstraint")) {
     kind = CK_Successor;
   } else {
-    assert(record->isSubClassOf("Constraint"));
+    assert(def->isSubClassOf("Constraint"));
   }
 }
 
 Constraint::Constraint(Kind kind, const llvm::Record *record)
-    : def(record), kind(kind) {}
+    : def(record), kind(kind) {
+  // Look through OpVariable's to their constraint.
+  if (def->isSubClassOf("OpVariable"))
+    def = def->getValueAsDef("constraint");
+}
 
 Pred Constraint::getPredicate() const {
   auto *val = def->getValue("predicate");
index c0297da..f5bf03e 100644 (file)
@@ -155,7 +155,7 @@ def NS_EOp : NS_Op<"op_with_optionals", []> {
 def NS_FOp : NS_Op<"op_with_all_types_constraint",
     [AllTypesMatch<["a", "b"]>]> {
   let arguments = (ins AnyType:$a);
-  let results = (outs AnyType:$b);
+  let results = (outs Res<AnyType, "output b", []>:$b);
 }
 
 // CHECK-LABEL: class FOp :