[TableGen] Emit verification code for op results
authorLei Zhang <antiagainst@google.com>
Wed, 6 Mar 2019 20:50:01 +0000 (12:50 -0800)
committerjpienaar <jpienaar@google.com>
Sat, 30 Mar 2019 00:02:26 +0000 (17:02 -0700)
They can be verified using the same logic as operands.

PiperOrigin-RevId: 237101461

mlir/include/mlir/TableGen/Operator.h
mlir/test/mlir-tblgen/op-operand.td [new file with mode: 0644]
mlir/test/mlir-tblgen/op-result.td
mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp

index 366e83d..6c66b76 100644 (file)
@@ -68,6 +68,10 @@ public:
   // Returns the number of results this op produces.
   int getNumResults() const;
 
+  // Returns the op result at the given `index`.
+  Value &getResult(int index) { return results[index]; }
+  const Value &getResult(int index) const { return results[index]; }
+
   // Returns the `index`-th result's type.
   Type getResultType(int index) const;
   // Returns the `index`-th result's name.
diff --git a/mlir/test/mlir-tblgen/op-operand.td b/mlir/test/mlir-tblgen/op-operand.td
new file mode 100644 (file)
index 0000000..e3709a8
--- /dev/null
@@ -0,0 +1,12 @@
+// RUN: mlir-tblgen -gen-op-definitions -I %S/../../include %s | FileCheck %s
+
+include "mlir/IR/OpBase.td"
+
+def OneOperandOp : Op<"one_operand_op", []> {
+  let arguments = (ins I32:$input);
+}
+
+// CHECK-LABEL: class OneOperandOp
+// CHECK:      bool verify() const {
+// CHECK:        if (!((this->getInstruction()->getOperand(0)->getType().isInteger(32)))) {
+// CHECK-NEXT:     return emitOpError("operand #0 must be 32-bit integer");
index 2236320..8ca903c 100644 (file)
@@ -2,6 +2,16 @@
 
 include "mlir/IR/OpBase.td"
 
+def OneResultOp : Op<"one_result_op", []> {
+  let results = (outs I32:$result);
+}
+
+// CHECK-LABEL: class OneResultOp
+// CHECK:      bool verify() const {
+// CHECK:        if (!((this->getInstruction()->getResult(0)->getType().isInteger(32)))) {
+// CHECK-NEXT:     return emitOpError("result #0 must be 32-bit integer");
+
+
 def SameTypeOp : Op<"same_type_op", [SameValueType]> {
   let arguments = (ins I32:$x);
   let results = (outs I32:$y);
index 30773dc..fa7dc96 100644 (file)
@@ -470,7 +470,7 @@ void OpEmitter::emitVerifier() {
   auto valueInit = def.getValueInit("verifier");
   CodeInit *codeInit = dyn_cast<CodeInit>(valueInit);
   bool hasCustomVerify = codeInit && !codeInit->getValue().empty();
-  if (!hasCustomVerify && op.getNumArgs() == 0)
+  if (!hasCustomVerify && op.getNumArgs() == 0 && op.getNumResults() == 0)
     return;
 
   OUT(2) << "bool verify() const {\n";
@@ -515,28 +515,38 @@ void OpEmitter::emitVerifier() {
       OUT(4) << "}\n";
   }
 
-  int opIndex = 0;
-  for (const auto &operand : op.getOperands()) {
-    // TODO: Handle variadic operand verification.
-    if (operand.type.isVariadic())
-      continue;
+  // Emits verification code for an operand or result.
+  auto verifyValue = [this](const tblgen::Value &value, int index,
+                            bool isOperand) -> void {
+    // TODO: Handle variadic operand/result verification.
+    if (value.type.isVariadic())
+      return;
 
     // TODO: Commonality between matchers could be extracted to have a more
     // concise code.
-    if (operand.hasPredicate()) {
-      auto description = operand.type.getDescription();
+    if (value.hasPredicate()) {
+      auto description = value.type.getDescription();
       OUT(4) << "if (!("
-             << formatv(operand.type.getConditionTemplate(),
-                        "this->getInstruction()->getOperand(" + Twine(opIndex) +
-                            ")->getType()")
+             << formatv(value.type.getConditionTemplate(),
+                        "this->getInstruction()->get" +
+                            Twine(isOperand ? "Operand" : "Result") + "(" +
+                            Twine(index) + ")->getType()")
              << ")) {\n";
-      OUT(6) << "return emitOpError(\"operand #" + Twine(opIndex)
+      OUT(6) << "return emitOpError(\"" << (isOperand ? "operand" : "result")
+             << " #" << index
              << (description.empty() ? " type precondition failed"
                                      : " must be " + Twine(description))
              << "\");";
       OUT(4) << "}\n";
     }
-    ++opIndex;
+  };
+
+  for (unsigned i = 0, e = op.getNumOperands(); i < e; ++i) {
+    verifyValue(op.getOperand(i), i, /*isOperand=*/true);
+  }
+
+  for (unsigned i = 0, e = op.getNumResults(); i < e; ++i) {
+    verifyValue(op.getResult(i), i, /*isOperand=*/false);
   }
 
   for (auto &trait : op.getTraits()) {