Tablegen helpers for accessing properties of shaped types
authorGeoffrey Martin-Noble <gcmn@google.com>
Sat, 28 Sep 2019 00:34:56 +0000 (17:34 -0700)
committerA. Unique TensorFlower <gardener@tensorflow.org>
Sat, 28 Sep 2019 00:35:34 +0000 (17:35 -0700)
Tablegen's lack of functions continues to be annoying

PiperOrigin-RevId: 271680947

mlir/include/mlir/IR/OpBase.td
mlir/test/lib/TestDialect/TestOps.td

index 43e9b5e..89609ed 100644 (file)
 // Common utilities for defining TableGen mechanisms
 //===----------------------------------------------------------------------===//
 
-// Concatenates a list of strings with a separator (default ", ")
-class StrJoin<list<string> strings, string sep = ", "> {
-  string result =
-      !if(!empty(strings), "",
-          !foldl(!head(strings), !tail(strings), prev, cur, prev # sep # cur));
+// A workaround for the inability to define functions in Tablegen.
+//
+// The template parameter defines a string that can be extracted from an
+// instance of this class by accessing the "result" member. Subclasses can take
+// their own template parameters as function "arguments" and use them to
+// populate result.
+// For example, if it didn't already exist, a concat function could be defined
+// like:
+//
+// class StrConcat<list<string> strings> :
+//     StrFunc<!foldl("", strings, prev, cur, prev # cur)>
+//
+// and then called like
+//
+// StrConcat<["a", "b", "c"]>.result
+//
+// to get the string "abc"
+class StrFunc<string r> {
+  string result = r;
 }
 
+// Concatenates a list of strings with a separator (default ", ")
+class StrJoin<list<string> strings, string sep = ", "> :
+    StrFunc<!if(!empty(strings), "",
+         !foldl(!head(strings), !tail(strings), prev, cur, prev # sep # cur))>;
+
 // Concatenates a list of integers into a string with a separator (default ", ")
 class StrJoinInt<list<int> integers, string sep = ", "> :
     StrJoin<!foreach(i, integers, !cast<string>(i)), sep>;
@@ -1437,6 +1456,14 @@ def HasNoUseOf: Constraint<
 
 // TODO(b/135033717): Improve the autogenerated error messages.
 
+class Rank<string name> :
+    StrFunc<"$" # name # ".getType().cast<ShapedType>().getRank()">;
+
+class ElementCount<string name> :
+  StrFunc<"$" # name # ".getType().cast<ShapedType>().getNumElements()">;
+
+class ElementType<string name> : StrFunc<"getElementTypeOrSelf($" # name # ")">;
+
 class AllMatchPred<list<string> values> :
     CPred<"llvm::is_splat(llvm::makeArrayRef({"# StrJoin<values>.result #"}))">;
 
@@ -1454,17 +1481,15 @@ class AllMatchSameOperatorTrait<list<string> names, string operator,
         AllMatchSameOperatorPred<names, operator>>;
 
 class AllElementCountsMatch<list<string> names> :
-    AllMatchSameOperatorTrait<
-      names, "$_self.getType().cast<ShapedType>().getNumElements()",
-      "element count">;
+    AllMatchSameOperatorTrait<names, ElementCount<"_self">.result,
+                              "element count">;
 
 class AllElementTypesMatch<list<string> names> :
-    AllMatchSameOperatorTrait<names,
-                              "getElementTypeOrSelf($_self)", "element type">;
+    AllMatchSameOperatorTrait<names, ElementType<"_self">.result,
+                              "element type">;
 
 class AllRanksMatch<list<string> names> :
-    AllMatchSameOperatorTrait<
-        names, "$_self.getType().cast<ShapedType>().getRank()", "rank">;
+    AllMatchSameOperatorTrait<names, Rank<"_self">.result, "rank">;
 
 class AllTypesMatch<list<string> names> :
     AllMatchSameOperatorTrait<names, "$_self.getType()", "type">;
index 862861a..0c2e53b 100644 (file)
@@ -293,9 +293,8 @@ def FourEqualsFive :
 
 def OperandRankEqualsResultSize :
     TEST_Op<"operand_rank_equals_result_size",
-            [AllMatch<["$operand.getType().cast<ShapedType>().getRank()",
-                       "$result.getType().cast<ShapedType>().getNumElements()"
-                      ], "operand rank equals result size">]> {
+            [AllMatch<[Rank<"operand">.result, ElementCount<"result">.result],
+                      "operand rank equals result size">]> {
   let arguments = (ins AnyTensor:$operand);
   let results = (outs AnyTensor:$result);
 }