From e7c3ca92f846fd757fd755b2fa0c908bb1775b1b Mon Sep 17 00:00:00 2001 From: Geoffrey Martin-Noble Date: Fri, 27 Sep 2019 17:34:56 -0700 Subject: [PATCH] Tablegen helpers for accessing properties of shaped types Tablegen's lack of functions continues to be annoying PiperOrigin-RevId: 271680947 --- mlir/include/mlir/IR/OpBase.td | 49 +++++++++++++++++++++++++++--------- mlir/test/lib/TestDialect/TestOps.td | 5 ++-- 2 files changed, 39 insertions(+), 15 deletions(-) diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td index 43e9b5e..89609ed 100644 --- a/mlir/include/mlir/IR/OpBase.td +++ b/mlir/include/mlir/IR/OpBase.td @@ -27,13 +27,32 @@ // Common utilities for defining TableGen mechanisms //===----------------------------------------------------------------------===// -// Concatenates a list of strings with a separator (default ", ") -class StrJoin strings, string sep = ", "> { - string result = - !if(!empty(strings), "", - !foldl(!head(strings), !tail(strings), prev, cur, prev # sep # cur)); +// A workaround for the inability to define functions in Tablegen. +// +// The template parameter defines a string that can be extracted from an +// instance of this class by accessing the "result" member. Subclasses can take +// their own template parameters as function "arguments" and use them to +// populate result. +// For example, if it didn't already exist, a concat function could be defined +// like: +// +// class StrConcat strings> : +// StrFunc +// +// and then called like +// +// StrConcat<["a", "b", "c"]>.result +// +// to get the string "abc" +class StrFunc { + string result = r; } +// Concatenates a list of strings with a separator (default ", ") +class StrJoin strings, string sep = ", "> : + StrFunc; + // Concatenates a list of integers into a string with a separator (default ", ") class StrJoinInt integers, string sep = ", "> : StrJoin(i)), sep>; @@ -1437,6 +1456,14 @@ def HasNoUseOf: Constraint< // TODO(b/135033717): Improve the autogenerated error messages. +class Rank : + StrFunc<"$" # name # ".getType().cast().getRank()">; + +class ElementCount : + StrFunc<"$" # name # ".getType().cast().getNumElements()">; + +class ElementType : StrFunc<"getElementTypeOrSelf($" # name # ")">; + class AllMatchPred values> : CPred<"llvm::is_splat(llvm::makeArrayRef({"# StrJoin.result #"}))">; @@ -1454,17 +1481,15 @@ class AllMatchSameOperatorTrait names, string operator, AllMatchSameOperatorPred>; class AllElementCountsMatch names> : - AllMatchSameOperatorTrait< - names, "$_self.getType().cast().getNumElements()", - "element count">; + AllMatchSameOperatorTrait.result, + "element count">; class AllElementTypesMatch names> : - AllMatchSameOperatorTrait; + AllMatchSameOperatorTrait.result, + "element type">; class AllRanksMatch names> : - AllMatchSameOperatorTrait< - names, "$_self.getType().cast().getRank()", "rank">; + AllMatchSameOperatorTrait.result, "rank">; class AllTypesMatch names> : AllMatchSameOperatorTrait; diff --git a/mlir/test/lib/TestDialect/TestOps.td b/mlir/test/lib/TestDialect/TestOps.td index 862861a..0c2e53b 100644 --- a/mlir/test/lib/TestDialect/TestOps.td +++ b/mlir/test/lib/TestDialect/TestOps.td @@ -293,9 +293,8 @@ def FourEqualsFive : def OperandRankEqualsResultSize : TEST_Op<"operand_rank_equals_result_size", - [AllMatch<["$operand.getType().cast().getRank()", - "$result.getType().cast().getNumElements()" - ], "operand rank equals result size">]> { + [AllMatch<[Rank<"operand">.result, ElementCount<"result">.result], + "operand rank equals result size">]> { let arguments = (ins AnyTensor:$operand); let results = (outs AnyTensor:$result); } -- 2.7.4