Simplify container type definitions
authorGeoffrey Martin-Noble <gcmn@google.com>
Tue, 21 May 2019 22:44:17 +0000 (15:44 -0700)
committerMehdi Amini <joker.eph@gmail.com>
Sun, 2 Jun 2019 02:55:33 +0000 (19:55 -0700)
    The passed element type description is usually unnecessary, and it's just as valid to want to pass a description for the entire container. In either case there's an alternative (Separate element type def or a TypeAlias) and we don't need to pollute the main API.

    To allow for this, I cleaned up the TF op definitions and added some additional utilities.

--

PiperOrigin-RevId: 249340979

mlir/include/mlir/IR/OpBase.td

index a7b69b6..e49e7dc 100644 (file)
@@ -34,9 +34,9 @@ class StrJoin<list<string> strings, string sep = ", "> {
           !foldl(!head(strings), !tail(strings), prev, cur, prev # sep # cur));
 }
 
-// Concatenates a list of integers into a string separated with comma.
-class Stringify<list<int> integers> :
-    StrJoin<!foreach(i, integers, !cast<string>(i))>;
+// Concatenates a list of integers into a string with a separator (default ", ")
+class StrJoinInt<list<int> integers, string sep = ", "> :
+    StrJoin<!foreach(i, integers, !cast<string>(i)), sep>;
 
 //===----------------------------------------------------------------------===//
 // Predicate definitions
@@ -241,6 +241,10 @@ class Dialect {
 class Type<Pred condition, string descr = ""> :
     TypeConstraint<condition, descr>;
 
+// Allows providing an alternative name and description to an existing type def.
+class TypeAlias<Type t, string description = t.description> :
+    Type<t.predicate, description>;
+
 // A variadic type constraint. It expands to zero or more of the base type. This
 // class is used for supporting variadic operands/results. An op can declare no
 // more than one variadic operand/result, and that operand/result must be the
@@ -291,6 +295,11 @@ class I<int width>
       BuildableType<"getIntegerType(" # width # ")"> {
   int bitwidth = width;
 }
+
+class IntOfWidths<list<int> widths> :
+    AnyTypeOf<!foreach(w, widths, I<w>),
+              StrJoinInt<widths, "/">.result # "-bit integer">;
+
 def I1  : I<1>;
 def I8  : I<8>;
 def I16 : I<16>;
@@ -310,6 +319,10 @@ class F<int width>
   int bitwidth = width;
 }
 
+class FloatOfWidths<list<int> widths> :
+    AnyTypeOf<!foreach(w, widths, F<w>),
+              StrJoinInt<widths, "/">.result # "-bit float">;
+
 def F16 : F<16>;
 def F32 : F<32>;
 def F64 : F<64>;
@@ -338,24 +351,22 @@ class ContainerType<Type etype, Pred containerPred, code elementTypeCall,
   code getElementTypeCall = elementTypeCall;
 }
 
-class ShapedContainerType<Type etype, Pred containerPred, string descr> :
-    ContainerType<etype, containerPred,
+class ShapedContainerType<list<Type> allowedTypes, Pred containerPred, string descr> :
+    ContainerType<AnyTypeOf<allowedTypes>, containerPred,
                   "$_self.cast<ShapedType>().getElementType()", descr>;
 
 // Vector types.
 
-class VectorOf<list<Type> allowedTypes, string elementDescription = ""> :
-  ShapedContainerType<AnyTypeOf<allowedTypes, elementDescription>,
-                      IsVectorTypePred, "vector">;
+class VectorOf<list<Type> allowedTypes> :
+  ShapedContainerType<allowedTypes, IsVectorTypePred, "vector">;
 
 def AnyVector : VectorOf<[AnyType]>;
 
 // Tensor types.
 
 // Any tensor type whose element type is from the given `allowedTypes` list
-class TensorOf<list<Type> allowedTypes, string elementDescription = ""> :
-  ShapedContainerType<AnyTypeOf<allowedTypes, elementDescription>,
-                      IsTensorTypePred, "tensor">;
+class TensorOf<list<Type> allowedTypes> :
+  ShapedContainerType<allowedTypes, IsTensorTypePred, "tensor">;
 
 def AnyTensor : TensorOf<[AnyType]>;
 
@@ -381,8 +392,8 @@ def F64Tensor  : TensorOf<[F64]>;
 
 // TODO(b/132735995) Use ShapedContainerType when MemRef subclasses ShapedType.
 // Memrefs are blocks of data with fixed type and rank.
-class MemRefOf<list<Type> allowedTypes, string elementDescription = ""> :
-    ContainerType<AnyTypeOf<allowedTypes, elementDescription>, IsMemRefTypePred,
+class MemRefOf<list<Type> allowedTypes> :
+    ContainerType<AnyTypeOf<allowedTypes>, IsMemRefTypePred,
                     "$_self.cast<MemRefType>().getElementType()", "memref">;
 
 def AnyMemRef : MemRefOf<[AnyType]>;
@@ -992,7 +1003,7 @@ class TCopVTEtAreSameAt<list<int> indices> :
     CPred<"llvm::is_splat(mlir::functional::map("
             "[this](unsigned i) { return this->getOperand(i)->getType()"
               ".cast<ShapedType>().getElementType(); }, "
-            "llvm::ArrayRef<unsigned>({" # Stringify<indices>.result # "})))">;
+            "llvm::ArrayRef<unsigned>({" # StrJoinInt<indices>.result # "})))">;
 
 //===----------------------------------------------------------------------===//
 // Pattern definitions