[TableGen] Add the `TCopVTEtAreSameAt` PredOpTrait
authorLei Zhang <antiagainst@google.com>
Thu, 2 May 2019 19:08:17 +0000 (12:08 -0700)
committerMehdi Amini <joker.eph@gmail.com>
Mon, 6 May 2019 15:25:18 +0000 (08:25 -0700)
    This trait is used for specifying operands at the indices from a given
    list are with the same element type.

--

PiperOrigin-RevId: 246364735

mlir/include/mlir/IR/OpBase.td
mlir/test/mlir-tblgen/predicate.td

index f45f748..fd84701 100644 (file)
 #define OP_BASE
 
 //===----------------------------------------------------------------------===//
+// Common utilities for defining TableGen mechanisms
+//===----------------------------------------------------------------------===//
+
+// Concatenates a list of integers into a string separated with comma.
+class Stringify<list<int> integers> {
+  string result = !foldl(/*init*/!cast<string>(!head(integers)),
+                         /*list*/!tail(integers), prev, cur, prev # ", " # cur);
+}
+
+//===----------------------------------------------------------------------===//
 // Predicate definitions
 //===----------------------------------------------------------------------===//
 
@@ -320,8 +330,7 @@ class Vector<Type t, list<int> dims> : ContainerType<t, AllOf<[
     // Match dims. Construct an ArrayRef with the elements of `dims` by folding
     // over the list.
     CPred<"$_self.cast<VectorType>().getShape() == ArrayRef{{" #
-      !foldl("", dims, sum, element, sum #
-       !if(!empty(sum), "", ",") # !cast<string>(element)) # "}">]>,
+      Stringify<dims>.result # "}">]>,
     "$_self.cast<VectorType>().getElementType()",
     "vector"> {
   list<int> dimensions = dims;
@@ -952,6 +961,23 @@ class TCresVTEtIsSameAsOp<int i, int j> : AllOf<[
           "getElementType() == $_op.getOperand(" # j # ")->getType()."
           "cast<VectorOrTensorType>().getElementType()">]>;
 
+// Predicate OpTrait to verify that all the operands at the given `indices`
+// have the same element type.
+// Type Constraint operands' Vector or Tensor Element type are all Same At
+// the given `indices`.
+//
+// Precondition:
+// 1) all operands involved are of vector or tensor type and
+// 2) the indices are not out of range.
+class TCopVTEtAreSameAt<list<int> indices = []> : PredOpTrait<
+    "operands indexed at " # Stringify<indices>.result #
+      " should all have the same type",
+    // We query the operands' types into a list and check they are all the same.
+    CPred<"llvm::is_splat(mlir::functional::map("
+            "[this](unsigned i) { return this->getOperand(i)->getType()"
+              ".cast<VectorOrTensorType>().getElementType(); }, "
+            "llvm::ArrayRef<unsigned>({" # Stringify<indices>.result # "})))">>;
+
 //===----------------------------------------------------------------------===//
 // Pattern definitions
 //===----------------------------------------------------------------------===//
index 0c57c85..8869e5d 100644 (file)
@@ -92,3 +92,20 @@ def OpI: NS_Op<"op_for_arr_min_value_at_index", []> {
 // CHECK-LABEL: OpI::verify()
 // CHECK: (((tblgen_attr.cast<ArrayAttr>().size() > 0)) && ((tblgen_attr.cast<ArrayAttr>().getValue()[0].cast<IntegerAttr>().getInt() >= 8)))))
 // CHECK-SAME:    return emitOpError("attribute 'attr' failed to satisfy constraint: array attribute whose 0-th element must be at least 8");
+
+def OpJ: NS_Op<"op_for_TCopVTEtAreSameAt",
+               [TCopVTEtAreSameAt<[0, 2, 3]>]> {
+  let arguments = (ins
+    Tensor:$a,
+    Tensor:$b,
+    Tensor:$c,
+    Tensor:$d,
+    Tensor:$e
+  );
+}
+
+// CHECK-LABEL: OpJ::verify()
+// CHECK:      llvm::is_splat(mlir::functional::map(
+// CHECK-SAME:   [this](unsigned i) { return this->getOperand(i)->getType().cast<VectorOrTensorType>().getElementType(); },
+// CHECK-SAME:   llvm::ArrayRef<unsigned>({0, 2, 3})))
+// CHECK:   return emitOpError("failed to verify that operands indexed at 0, 2, 3 should all have the same type");