[ODS] Support variadic operand/result verification
authorLei Zhang <antiagainst@google.com>
Sun, 9 Jun 2019 14:00:09 +0000 (07:00 -0700)
committerMehdi Amini <joker.eph@gmail.com>
Sun, 9 Jun 2019 23:24:29 +0000 (16:24 -0700)
This CL enables verification code generation for variadic operands and results.
In verify(), we use fallback getter methods to access all the dynamic values
belonging to one static variadic operand/result to reuse the value range
calculation there.

PiperOrigin-RevId: 252288219

13 files changed:
mlir/include/mlir/IR/OpBase.td
mlir/include/mlir/LLVMIR/LLVMOps.td
mlir/include/mlir/TableGen/Operator.h
mlir/lib/StandardOps/Ops.cpp
mlir/lib/TableGen/Operator.cpp
mlir/test/IR/invalid-ops.mlir
mlir/test/IR/operand.mlir [new file with mode: 0644]
mlir/test/IR/result.mlir [new file with mode: 0644]
mlir/test/TestDialect/TestOps.td
mlir/test/mlir-tblgen/op-operand.td
mlir/test/mlir-tblgen/op-result.td
mlir/test/mlir-tblgen/predicate.td
mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp

index ac8c652..0d32213 100644 (file)
@@ -255,9 +255,7 @@ class TypeAlias<Type t, string description = t.description> :
 // class is used for supporting variadic operands/results. An op can declare no
 // more than one variadic operand/result, and that operand/result must be the
 // last one in the operand/result list.
-class Variadic<Type type, string descr = "">
-    // TODO(b/132908002): support variadic type conditions
-    : TypeConstraint<CPred<"true">, descr> {
+class Variadic<Type type> : TypeConstraint<type.predicate, type.description> {
   Type baseType = type;
 }
 
@@ -907,6 +905,9 @@ def Terminator       : NativeOpTrait<"IsTerminator">;
 def FirstAttrDerivedResultType :
   GenInternalOpTrait<"FirstAttrDerivedResultType">;
 
+// TODO(antiagainst): Turn the following into normal traits and generate
+// verification for them.
+
 // All variadic operands of the op have the same number of values.
 // A variadic operand contains an array of values whose array size is only
 // known at runtime. This trait requires all variadic operands of an op
index a207e94..e9f235a 100644 (file)
@@ -203,7 +203,9 @@ def LLVM_PtrToIntOp
 // Call-related operations.
 def LLVM_CallOp : LLVM_Op<"call">,
                   Arguments<(ins OptionalAttr<FunctionAttr>:$callee,
-                             Variadic<LLVM_Type>)>,
+                             // TODO(b/133216756): fix test failure and
+                             // change to LLVM_Type
+                             Variadic<AnyType>)>,
                   Results<(outs Variadic<LLVM_Type>)>,
                   LLVM_TwoBuilders<LLVM_OneResultOpBuilder,
                                    LLVM_ZeroResultOpBuilder> {
index 6cc6bbc..4551790 100644 (file)
@@ -69,11 +69,12 @@ public:
   std::string getQualCppClassName() const;
 
   using value_iterator = NamedTypeConstraint *;
+  using value_range = llvm::iterator_range<value_iterator>;
 
   // Op result iterators.
   value_iterator result_begin();
   value_iterator result_end();
-  llvm::iterator_range<value_iterator> getResults();
+  value_range getResults();
 
   // Returns the number of results this op produces.
   int getNumResults() const;
@@ -110,7 +111,7 @@ public:
   // Op operand iterators.
   value_iterator operand_begin();
   value_iterator operand_end();
-  llvm::iterator_range<value_iterator> getOperands();
+  value_range getOperands();
 
   int getNumOperands() const { return operands.size(); }
   NamedTypeConstraint &getOperand(int index) { return operands[index]; }
index 4b2940a..1ef3fcd 100644 (file)
@@ -1595,12 +1595,6 @@ static LogicalResult verify(ExtractElementOp op) {
   if (op.getType() != aggregateType.getElementType())
     return op.emitOpError("result type must match element type of aggregate");
 
-  // TODO(b/132908002) This should be covered by the op specification in
-  // tablegen, but for some reason it's not.
-  for (auto *idx : op.getIndices())
-    if (!idx->getType().isIndex())
-      return op.emitOpError("index to extract_element must have 'index' type");
-
   // Verify the # indices match if we have a ranked type.
   if (aggregateType.hasRank() &&
       aggregateType.getRank() != op.getNumOperands() - 1)
index 3c269ba..a5dd9c2 100644 (file)
@@ -95,7 +95,7 @@ auto tblgen::Operator::result_begin() -> value_iterator {
 
 auto tblgen::Operator::result_end() -> value_iterator { return results.end(); }
 
-auto tblgen::Operator::getResults() -> llvm::iterator_range<value_iterator> {
+auto tblgen::Operator::getResults() -> value_range {
   return {result_begin(), result_end()};
 }
 
@@ -205,7 +205,7 @@ auto tblgen::Operator::operand_begin() -> value_iterator {
 auto tblgen::Operator::operand_end() -> value_iterator {
   return operands.end();
 }
-auto tblgen::Operator::getOperands() -> llvm::iterator_range<value_iterator> {
+auto tblgen::Operator::getOperands() -> value_range {
   return {operand_begin(), operand_end()};
 }
 
index 562c2ce..6baa104 100644 (file)
@@ -639,7 +639,7 @@ func @extract_element_no_indices(%v : vector<3xf32>) {
 // -----
 
 func @extract_element_invalid_index_type(%v : vector<3xf32>, %i : i32) {
-  // expected-error@+1 {{index to extract_element must have 'index' type}}
+  // expected-error@+1 {{operand #1 must be index}}
   %0 = "std.extract_element"(%v, %i) : (vector<3xf32>, i32) -> f32
   return
 }
diff --git a/mlir/test/IR/operand.mlir b/mlir/test/IR/operand.mlir
new file mode 100644 (file)
index 0000000..0d7939f
--- /dev/null
@@ -0,0 +1,35 @@
+// RUN: mlir-test-opt %s -split-input-file -verify | FileCheck %s
+
+//===----------------------------------------------------------------------===//
+// Test mixed normal and variadic operands
+//===----------------------------------------------------------------------===//
+
+func @correct_variadic_operand(%arg0: tensor<f32>, %arg1: f32) {
+  // CHECK: mixed_normal_variadic_operand
+  "test.mixed_normal_variadic_operand"(%arg0, %arg0, %arg0, %arg0, %arg0) : (tensor<f32>, tensor<f32>, tensor<f32>, tensor<f32>, tensor<f32>) -> ()
+  return
+}
+
+// -----
+
+func @error_in_first_variadic_operand(%arg0: tensor<f32>, %arg1: f32) {
+  // expected-error @+1 {{operand #0 must be tensor of any type}}
+  "test.mixed_normal_variadic_operand"(%arg0, %arg1, %arg0, %arg0, %arg0) : (tensor<f32>, f32, tensor<f32>, tensor<f32>, tensor<f32>) -> ()
+  return
+}
+
+// -----
+
+func @error_in_normal_operand(%arg0: tensor<f32>, %arg1: f32) {
+  // expected-error @+1 {{operand #1 must be tensor of any type}}
+  "test.mixed_normal_variadic_operand"(%arg0, %arg0, %arg1, %arg0, %arg0) : (tensor<f32>, tensor<f32>, f32, tensor<f32>, tensor<f32>) -> ()
+  return
+}
+
+// -----
+
+func @error_in_second_variadic_operand(%arg0: tensor<f32>, %arg1: f32) {
+  // expected-error @+1 {{operand #2 must be tensor of any type}}
+  "test.mixed_normal_variadic_operand"(%arg0, %arg0, %arg0, %arg1, %arg0) : (tensor<f32>, tensor<f32>, tensor<f32>, f32, tensor<f32>) -> ()
+  return
+}
diff --git a/mlir/test/IR/result.mlir b/mlir/test/IR/result.mlir
new file mode 100644 (file)
index 0000000..fc5d597
--- /dev/null
@@ -0,0 +1,36 @@
+// RUN: mlir-test-opt %s -split-input-file -verify | FileCheck %s
+
+//===----------------------------------------------------------------------===//
+// Test mixed normal and variadic results
+//===----------------------------------------------------------------------===//
+
+func @correct_variadic_result() -> tensor<f32> {
+  // CHECK: mixed_normal_variadic_result
+  %0:5 = "test.mixed_normal_variadic_result"() : () -> (tensor<f32>, tensor<f32>, tensor<f32>, tensor<f32>, tensor<f32>)
+  return %0#4 : tensor<f32>
+}
+
+// -----
+
+func @error_in_first_variadic_result() -> tensor<f32> {
+  // expected-error @+1 {{result #0 must be tensor of any type}}
+  %0:5 = "test.mixed_normal_variadic_result"() : () -> (tensor<f32>, f32, tensor<f32>, tensor<f32>, tensor<f32>)
+  return %0#4 : tensor<f32>
+}
+
+// -----
+
+func @error_in_normal_result() -> tensor<f32> {
+  // expected-error @+1 {{result #1 must be tensor of any type}}
+  %0:5 = "test.mixed_normal_variadic_result"() : () -> (tensor<f32>, tensor<f32>, f32, tensor<f32>, tensor<f32>)
+  return %0#4 : tensor<f32>
+}
+
+// -----
+
+func @error_in_second_variadic_result() -> tensor<f32> {
+  // expected-error @+1 {{result #2 must be tensor of any type}}
+  %0:5 = "test.mixed_normal_variadic_result"() : () -> (tensor<f32>, tensor<f32>, tensor<f32>, f32, tensor<f32>)
+  return %0#4 : tensor<f32>
+}
+
index 845b08d..10c144f 100644 (file)
@@ -60,6 +60,31 @@ def NestedTupleOp : TEST_Op<"nested_tuple_32_bit"> {
   let results = (outs NestedTupleOf<[I32, F32]>);
 }
 
+//===----------------------------------------------------------------------===//
+// Test Operands
+//===----------------------------------------------------------------------===//
+
+def MixedNormalVariadicOperandOp : TEST_Op<
+    "mixed_normal_variadic_operand", [SameVariadicOperandSize]> {
+  let arguments = (ins
+    Variadic<AnyTensor>:$input1,
+    AnyTensor:$input2,
+    Variadic<AnyTensor>:$input3
+  );
+}
+
+//===----------------------------------------------------------------------===//
+// Test Results
+//===----------------------------------------------------------------------===//
+
+def MixedNormalVariadicResults : TEST_Op<
+    "mixed_normal_variadic_result", [SameVariadicResultSize]> {
+  let results = (outs
+    Variadic<AnyTensor>:$output1,
+    AnyTensor:$output2,
+    Variadic<AnyTensor>:$output3
+  );
+}
 
 //===----------------------------------------------------------------------===//
 // Test Attributes
index 6055081..ea567e4 100644 (file)
@@ -26,10 +26,6 @@ def OpA : NS_Op<"one_normal_operand_op", []> {
 // CHECK:        assert(operands.size() == 1u && "mismatched number of parameters");
 // CHECK:        tblgen_state->addOperands(operands);
 
-// CHECK:      LogicalResult OpA::verify() {
-// CHECK:        if (!((this->getOperation()->getOperand(0)->getType().isInteger(32))))
-// CHECK-NEXT:     return emitOpError("operand #0 must be 32-bit integer");
-
 def OpB : NS_Op<"one_variadic_operand_op", []> {
   let arguments = (ins Variadic<I32>:$input);
 }
@@ -52,20 +48,6 @@ def OpD : NS_Op<"mix_variadic_and_normal_inputs_op", [SameVariadicOperandSize]>
 // CHECK-LABEL: ArrayRef<Value *> OpDOperandAdaptor::input3
 // CHECK-NEXT:    return getODSOperands(2);
 
-// TODO(b/134305899): Move to use TestDialect after fixing verification.
-
-// CHECK-LABEL: Operation::operand_range OpD::getODSOperands(unsigned index)
-// CHECK-NEXT:    bool isVariadic[] = {true, false, true};
-// CHECK-NEXT:    int prevVariadicCount = 0;
-// CHECK-NEXT:    for (int i = 0; i < index; ++i)
-// CHECK-NEXT:      if (isVariadic[i]) ++prevVariadicCount;
-
-// CHECK:         int variadicSize = (getOperation()->getNumOperands() - 1) / 2;
-// CHECK:         int offset = index + (variadicSize - 1) * prevVariadicCount;
-// CHECK-NEXT:    int size = isVariadic[index] ? variadicSize : 1;
-
-// CHECK:         return {std::next(getOperation()->operand_begin(), offset), std::next(getOperation()->operand_begin(), offset + size)};
-
 // CHECK-LABEL: Operation::operand_range OpD::input1
 // CHECK-NEXT: return getODSOperands(0);
 
index e0f14e4..83f804a 100644 (file)
@@ -17,10 +17,6 @@ def OpA : NS_Op<"one_normal_result_op", []> {
 // CHECK:         assert(resultTypes.size() == 1u && "mismatched number of return types");
 // CHECK-NEXT:    tblgen_state->addTypes(resultTypes);
 
-// CHECK-LABEL: LogicalResult OpA::verify()
-// CHECK:         if (!((this->getOperation()->getResult(0)->getType().isInteger(32))))
-// CHECK-NEXT:      return emitOpError("result #0 must be 32-bit integer");
-
 def OpB : NS_Op<"same_input_output_type_op", [SameOperandsAndResultType]> {
   let arguments = (ins I32:$x);
   let results = (outs I32:$y);
@@ -90,20 +86,6 @@ def OpI : NS_Op<"mix_variadic_and_normal_results_op", [SameVariadicResultSize]>
   let results = (outs Variadic<AnyTensor>:$output1, AnyTensor:$output2, Variadic<AnyTensor>:$output3);
 }
 
-// TODO(b/134305899): Move to use TestDialect after fixing verification.
-
-// CHECK-LABEL: Operation::result_range OpI::getODSResults(unsigned index)
-// CHECK-NEXT:   bool isVariadic[] = {true, false, true};
-// CHECK-NEXT:   int prevVariadicCount = 0;
-// CHECK-NEXT:   for (int i = 0; i < index; ++i)
-// CHECK-NEXT:     if (isVariadic[i]) ++prevVariadicCount;
-
-// CHECK:        int variadicSize = (getOperation()->getNumResults() - 1) / 2;
-// CHECK:        int offset = index + (variadicSize - 1) * prevVariadicCount;
-// CHECK-NEXT:   int size = isVariadic[index] ? variadicSize : 1;
-
-// CHECK:        return {std::next(getOperation()->result_begin(), offset), std::next(getOperation()->result_begin(), offset + size)};
-
 // CHECK-LABEL: Operation::result_range OpI::output1
 // CHECK-NEXT:    return getODSResults(0);
 
index 454a01b..7cf5a8d 100644 (file)
@@ -16,7 +16,8 @@ def OpA : NS_Op<"op_for_CPred_containing_multiple_same_placeholder", []> {
 }
 
 // CHECK-LABEL: OpA::verify
-// CHECK: if (!((this->getOperation()->getOperand(0)->getType().isInteger(32) || this->getOperation()->getOperand(0)->getType().isF32())))
+// CHECK: for (Value *v : getODSOperands(0)) {
+// CHECK:   if (!((v->getType().isInteger(32) || v->getType().isF32())))
 
 def OpB : NS_Op<"op_for_And_PredOpTrait", [
     PredOpTrait<"both first and second holds",
@@ -103,4 +104,5 @@ def OpK : NS_Op<"op_for_AnyTensorOf", []> {
 }
 
 // CHECK-LABEL: OpK::verify
-// CHECK: if (!(((this->getOperation()->getOperand(0)->getType().isa<TensorType>())) && (((this->getOperation()->getOperand(0)->getType().cast<ShapedType>().getElementType().isF32())) || ((this->getOperation()->getOperand(0)->getType().cast<ShapedType>().getElementType().isInteger(32))))))
+// CHECK: for (Value *v : getODSOperands(0)) {
+// CHECK: if (!(((v->getType().isa<TensorType>())) && (((v->getType().cast<ShapedType>().getElementType().isF32())) || ((v->getType().cast<ShapedType>().getElementType().isInteger(32))))))
index 7183f34..7718a0d 100644 (file)
@@ -448,6 +448,12 @@ private:
   // Generates verify method for the operation.
   void genVerifier();
 
+  // Generates verify statements for operands and results in the operation.
+  // The generated code will be attached to `body`.
+  void genOperandResultVerifier(OpMethodBody &body,
+                                Operator::value_range values,
+                                StringRef valueKind);
+
   // Generates verify statements for regions in the operation.
   // The generated code will be attached to `body`.
   void genRegionVerifier(OpMethodBody &body);
@@ -1022,39 +1028,8 @@ void OpEmitter::genVerifier() {
     body << "  }\n";
   }
 
-  // Emits verification code for an operand or result.
-  auto verifyValue = [&](const tblgen::NamedTypeConstraint &value, int index,
-                         bool isOperand) -> void {
-    // TODO: Handle variadic operand/result verification.
-    if (value.isVariadic())
-      return;
-
-    // TODO: Commonality between matchers could be extracted to have a more
-    // concise code.
-    if (value.hasPredicate()) {
-      auto description = value.constraint.getDescription();
-      body << "  if (!("
-           << tgfmt(
-                  value.constraint.getConditionTemplate(),
-                  &verifyCtx.withSelf("this->getOperation()->get" +
-                                      Twine(isOperand ? "Operand" : "Result") +
-                                      "(" + Twine(index) + ")->getType()"))
-           << ")) {\n";
-      body << "    return emitOpError(\"" << (isOperand ? "operand" : "result")
-           << " #" << index
-           << (description.empty() ? " type precondition failed"
-                                   : " must be " + Twine(description))
-           << "\");\n  }\n";
-    }
-  };
-
-  for (int i = 0, e = op.getNumOperands(); i < e; ++i) {
-    verifyValue(op.getOperand(i), i, /*isOperand=*/true);
-  }
-
-  for (int i = 0, e = op.getNumResults(); i < e; ++i) {
-    verifyValue(op.getResult(i), i, /*isOperand=*/false);
-  }
+  genOperandResultVerifier(body, op.getOperands(), "operand");
+  genOperandResultVerifier(body, op.getResults(), "result");
 
   for (auto &trait : op.getTraits()) {
     if (auto t = dyn_cast<tblgen::PredOpTrait>(&trait)) {
@@ -1073,6 +1048,37 @@ void OpEmitter::genVerifier() {
     body << "  return mlir::success();\n";
 }
 
+void OpEmitter::genOperandResultVerifier(OpMethodBody &body,
+                                         Operator::value_range values,
+                                         StringRef valueKind) {
+  FmtContext fctx;
+  unsigned i = 0;
+  for (auto &staticValue : values) {
+    if (!staticValue.hasPredicate())
+      continue;
+
+    // Emit a loop to check all the dynamic values in the pack.
+    body << formatv("  for (Value *v : getODS{0}{1}s({2})) {{\n",
+                    // Capitalize the first letter to match the function name
+                    valueKind.substr(0, 1).upper(), valueKind.substr(1), i);
+
+    auto description = staticValue.constraint.getDescription();
+    body << "    (void)v;\n";
+    body << "    if (!("
+         << tgfmt(staticValue.constraint.getConditionTemplate(),
+                  &fctx.withSelf("v->getType()"))
+         << "))\n";
+    body << "      return emitOpError(\""
+         // TODO(b/129706806): Use the name of the operand/result here
+         << valueKind << " #" << i
+         << (description.empty() ? " type precondition failed"
+                                 : " must be " + Twine(description))
+         << "\");\n";
+    body << "  }\n";
+    ++i;
+  }
+}
+
 void OpEmitter::genRegionVerifier(OpMethodBody &body) {
   unsigned numRegions = op.getNumRegions();