Allow argument and result names replacement in predicates.
authorJacques Pienaar <jpienaar@google.com>
Thu, 30 May 2019 17:57:23 +0000 (10:57 -0700)
committerMehdi Amini <joker.eph@gmail.com>
Sun, 2 Jun 2019 03:11:01 +0000 (20:11 -0700)
    This allow specifying $x to refer to an operand's named argument (operand or attribute) or result. Skip variadic operands/results for now pending autogenerated discussion of their accessors.

    This adds a new predicate, following feedback on the naming but does not remove the old one. Post feedback I'll do that, potentially in follow up.

--

PiperOrigin-RevId: 250720003

mlir/include/mlir/IR/OpBase.td
mlir/test/TestDialect/TestOps.td
mlir/test/mlir-tblgen/types.mlir
mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp

index edbd273..ce1a87b 100644 (file)
@@ -996,6 +996,14 @@ class TCopVTEtIs<int idx, Type type> : And<[
      ")->getType().cast<ShapedType>().getElementType()",
      type.predicate>]>;
 
+// Predicate to verify that a named argument or result's element type matches a
+// given type.
+class ArgOrResultElementTypeIs<string name, Type type> : And<[
+   SubstLeaves<"$_self",  "$" # name # "->getType()", IsShapedTypePred>,
+   SubstLeaves<"$_self",  "$" # name #
+     "->getType().cast<ShapedType>().getElementType()",
+     type.predicate>]>;
+
 // Predicate to verify that the i'th operand and the j'th operand have the same
 // elemental type.
 // Type Constraint operand `i`'s Element type is Same As operand `j`'s Element
index 915318d..5ffbcbc 100644 (file)
@@ -77,6 +77,17 @@ def SameOperandAndResultShapeOp : TEST_Op<"same_operand_and_result_shape",
   let results = (outs AnyVectorOrTensor:$res);
 }
 
+def ArgAndResHaveFixedElementTypesOp :
+    TEST_Op<"arg_and_res_have_fixed_element_types",
+      [PredOpTrait<"fixed type combination",
+         Or<[And<[ArgOrResultElementTypeIs<"x", I32>,
+                  ArgOrResultElementTypeIs<"y", F32>,
+                  ArgOrResultElementTypeIs<"res", I16>]>,
+             ArgOrResultElementTypeIs<"attr", I8>]>>]> {
+  let arguments = (ins AnyVectorOrTensor:$x, AnyVectorOrTensor:$y, AnyAttr:$attr);
+  let results = (outs AnyVectorOrTensor:$res);
+}
+
 //===----------------------------------------------------------------------===//
 // Test Patterns
 //===----------------------------------------------------------------------===//
index 1472487..e84adcf 100644 (file)
@@ -79,3 +79,34 @@ func @nested_tuple_multi_level_wrong_type() {
   return
 }
 
+// -----
+
+// CHECK-LABEL: @fixed_element_types
+func @fixed_element_types(%arg0: tensor<* x i32>, %arg1: tensor<* x f32>) {
+  %0 = "test.arg_and_res_have_fixed_element_types"(%arg0, %arg1) {attr: ""} : (tensor<* x i32>, tensor<* x f32>) -> tensor<* x i16>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @fixed_element_types
+func @fixed_element_types(%arg0: tensor<* x i32>, %arg1: tensor<* x f32>) {
+  %0 = "test.arg_and_res_have_fixed_element_types"(%arg0, %arg1) {attr: splat<tensor<2xi8>, 1>}: (tensor<* x i32>, tensor<* x f32>) -> tensor<* x i32>
+  return
+}
+
+// -----
+
+func @fixed_element_types(%arg0: tensor<* x i32>, %arg1: tensor<* x f32>) {
+  // expected-error@+1 {{fixed type combination}}
+  %0 = "test.arg_and_res_have_fixed_element_types"(%arg0, %arg1) {attr: ""}: (tensor<* x i32>, tensor<* x f32>) -> tensor<* x i32>
+  return
+}
+
+// -----
+
+func @fixed_element_types(%arg0: tensor<* x i32>, %arg1: tensor<* x f32>) {
+  // expected-error@+1 {{fixed type combination}}
+  %0 = "test.arg_and_res_have_fixed_element_types"(%arg1, %arg0) {attr: ""}: (tensor<* x f32>, tensor<* x i32>) -> tensor<* x i16>
+  return
+}
index 0d21d73..1dc9a95 100644 (file)
@@ -903,10 +903,34 @@ void OpEmitter::genVerifier() {
   FmtContext fctx;
   fctx.withOp("(*this->getOperation())");
 
+  // Populate substitutions for attributes and named operands and results.
+  for (const auto &namedAttr : op.getAttributes())
+    fctx.addSubst(namedAttr.name,
+                  formatv("(&this->getAttr(\"{0}\"))", namedAttr.name));
+  for (int i = 0, e = op.getNumOperands(); i < e; ++i) {
+    auto &value = op.getOperand(i);
+    // Skip from from first variadic operands for now. Else getOperand index
+    // used below doesn't match.
+    if (value.isVariadic())
+      break;
+    if (!value.name.empty())
+      fctx.addSubst(value.name,
+                    formatv("this->getOperation()->getOperand({0})", i));
+  }
+  for (int i = 0, e = op.getNumResults(); i < e; ++i) {
+    auto &value = op.getResult(i);
+    // Skip from from first variadic results for now. Else getResult index used
+    // below doesn't match.
+    if (value.isVariadic())
+      break;
+    if (!value.name.empty())
+      fctx.addSubst(value.name,
+                    formatv("this->getOperation()->getResult({0})", i));
+  }
+
   // Verify the attributes have the correct type.
   for (const auto &namedAttr : op.getAttributes()) {
     const auto &attr = namedAttr.attr;
-
     if (attr.isDerivedAttr())
       continue;