Clean up container type names in OpBase
authorGeoffrey Martin-Noble <gcmn@google.com>
Tue, 21 May 2019 17:45:30 +0000 (10:45 -0700)
committerMehdi Amini <joker.eph@gmail.com>
Sun, 2 Jun 2019 02:55:03 +0000 (19:55 -0700)
    Establish the following convention:
    1. Container class types end in "Of" (e.g. TensorOf) and take a list of allowed types.
    2. An X container where only a single type is allowed is called TypeX (e.g. I32Tensor).
    3. An X container where any type is allowed is called AnyX (e.g. AnyTensor).

--

PiperOrigin-RevId: 249281018

mlir/g3doc/OpDefinitions.md
mlir/include/mlir/Dialect/QuantOps/QuantPredicates.td
mlir/include/mlir/IR/OpBase.td
mlir/include/mlir/StandardOps/Ops.td
mlir/test/mlir-tblgen/predicate.td

index 90fbfaf..4befbde 100644 (file)
@@ -450,7 +450,7 @@ the entity's declaration place as described in
 
 To help modelling constraints of common types, a set of `TypeConstraint`s are
 created; they are the `Type` subclass hierarchy. It includes `F32` for the
-constraints of being a float, `TypedTensor<F32>` for the constraints of being
+constraints of being a float, `TensorOf<[F32]>` for the constraints of being
 a float tensor, and so on.
 
 Similarly, a set of `AttrConstraint`s are created for helping modelling
index 9db5af4..d000946 100644 (file)
@@ -28,8 +28,8 @@
 
 class quant_TypedPrimitiveOrContainer<Type etype> :
     Type<Or<[etype.predicate,
-                Tensor<etype>.predicate,
-                Vector<etype>.predicate]>,
+                TensorOf<[etype]>.predicate,
+                VectorOf<[etype]>.predicate]>,
          "primitive/tensor/vector of " # etype.description>;
 
 // An implementation of QuantizedType.
index 429032d..a7b69b6 100644 (file)
@@ -344,49 +344,72 @@ class ShapedContainerType<Type etype, Pred containerPred, string descr> :
 
 // Vector types.
 
-class Vector<Type t> : ShapedContainerType<t, IsVectorTypePred, "vector">;
+class VectorOf<list<Type> allowedTypes, string elementDescription = ""> :
+  ShapedContainerType<AnyTypeOf<allowedTypes, elementDescription>,
+                      IsVectorTypePred, "vector">;
 
-def AnyVector : Vector<AnyType>;
+def AnyVector : VectorOf<[AnyType]>;
 
 // Tensor types.
 
-class Tensor<Type t> : ShapedContainerType<t, IsTensorTypePred, "tensor">;
-
-def AnyTensor : Tensor<AnyType>;
-
 // Any tensor type whose element type is from the given `allowedTypes` list
-class AnyTensorOf<list<Type> allowedTypes, string elementDescription = ""> :
-  Tensor<AnyTypeOf<allowedTypes, elementDescription>>;
+class TensorOf<list<Type> allowedTypes, string elementDescription = ""> :
+  ShapedContainerType<AnyTypeOf<allowedTypes, elementDescription>,
+                      IsTensorTypePred, "tensor">;
+
+def AnyTensor : TensorOf<[AnyType]>;
 
 // TODO(b/130807343) Fix description to contain element information.
 class StaticShapeTensor<Type t>
-    : Type<And<[ Tensor<t>.predicate, HasStaticShapePred ]>,
+    : Type<And<[ TensorOf<[t]>.predicate, HasStaticShapePred ]>,
            "statically shaped tensor">;
 
 def AnyStaticShapeTensor : StaticShapeTensor<AnyType>;
 
-def I1Tensor   : Tensor<I1>;
-def I8Tensor   : Tensor<I8>;
-def I16Tensor  : Tensor<I16>;
-def I32Tensor  : Tensor<I32>;
-def I64Tensor  : Tensor<I64>;
+def I1Tensor   : TensorOf<[I1]>;
+def I8Tensor   : TensorOf<[I8]>;
+def I16Tensor  : TensorOf<[I16]>;
+def I32Tensor  : TensorOf<[I32]>;
+def I64Tensor  : TensorOf<[I64]>;
+
+def BF16Tensor : TensorOf<[BF16]>;
+def F16Tensor  : TensorOf<[F16]>;
+def F32Tensor  : TensorOf<[F32]>;
+def F64Tensor  : TensorOf<[F64]>;
+
+// Memref type.
+
+// TODO(b/132735995) Use ShapedContainerType when MemRef subclasses ShapedType.
+// Memrefs are blocks of data with fixed type and rank.
+class MemRefOf<list<Type> allowedTypes, string elementDescription = ""> :
+    ContainerType<AnyTypeOf<allowedTypes, elementDescription>, IsMemRefTypePred,
+                    "$_self.cast<MemRefType>().getElementType()", "memref">;
+
+def AnyMemRef : MemRefOf<[AnyType]>;
+
+// Memref declarations handle any memref, independent of rank, size, (static or
+// dynamic), layout, or memory space.
+def I1MemRef  : MemRefOf<[I1]>;
+def I8MemRef  : MemRefOf<[I8]>;
+def I16MemRef : MemRefOf<[I16]>;
+def I32MemRef : MemRefOf<[I32]>;
+def I64MemRef : MemRefOf<[I64]>;
 
-def BF16Tensor : Tensor<BF16>;
-def F16Tensor  : Tensor<F16>;
-def F32Tensor  : Tensor<F32>;
-def F64Tensor  : Tensor<F64>;
+def BF16MemRef : MemRefOf<[BF16]>;
+def F16MemRef  : MemRefOf<[F16]>;
+def F32MemRef  : MemRefOf<[F32]>;
+def F64MemRef  : MemRefOf<[F64]>;
 
-// This represents a generic tuple without any constraints on elemental type,
-// ranks, or size. As Tuples can contain tensors, vectors, or scalar values
-// there is not only a single elemental type.
-def Tuple : Type<IsTupleTypePred, "tuple">;
+// This represents a generic tuple without any constraints on element type.
+def AnyTuple : Type<IsTupleTypePred, "tuple">;
 
+// TODO(b/132952417) Make this accept a list of types like the classes above.
 // A Tuple that only holds elements of a certain type. This cannot inherit from
 // ContainerType because tuples do not always have a single element type that
 // could be retrieved with elementTypeCall.
-class TypedTuple<Type t> :
+class TupleOf<Type t> :
     Type<And<[
-        Tuple.predicate,
+        IsTupleTypePred,
         Concat<
             [{
               llvm::all_of(
@@ -398,26 +421,6 @@ class TypedTuple<Type t> :
             "; })">
     ]>, "tuple">;
 
-// Memref type.
-
-// Memrefs are blocks of data with fixed type and rank.
-class MemRef<Type t>
-    : ContainerType<t, IsMemRefTypePred,
-                    "$_self.cast<MemRefType>().getElementType()", "memref">;
-
-// Memref declarations handle any memref, independent of rank, size, (static or
-// dynamic), layout, or memory space.
-def I1MemRef  : MemRef<I1>;
-def I8MemRef  : MemRef<I8>;
-def I16MemRef : MemRef<I16>;
-def I32MemRef : MemRef<I32>;
-def I64MemRef : MemRef<I64>;
-
-def BF16MemRef : MemRef<BF16>;
-def F16MemRef  : MemRef<F16>;
-def F32MemRef  : MemRef<F32>;
-def F64MemRef  : MemRef<F64>;
-
 //===----------------------------------------------------------------------===//
 // Common type constraints
 //===----------------------------------------------------------------------===//
@@ -425,12 +428,12 @@ def F64MemRef  : MemRef<F64>;
 // Type constraint for integer-like types: integers, indices, vectors of
 // integers, tensors of integers.
 def IntegerLike : TypeConstraint<Or<[Integer.predicate, Index.predicate,
-        Vector<Integer>.predicate, Tensor<Integer>.predicate]>,
+        VectorOf<[Integer]>.predicate, TensorOf<[Integer]>.predicate]>,
     "integer-like">;
 
 // Type constraint for float-like types: floats, vectors or tensors thereof.
 def FloatLike : TypeConstraint<Or<[Float.predicate,
-        Vector<Float>.predicate, Tensor<Float>.predicate]>,
+        VectorOf<[Float]>.predicate, TensorOf<[Float]>.predicate]>,
     "floating-point-like">;
 
 
index 37b8f84..7cee626 100644 (file)
@@ -146,7 +146,7 @@ def AllocOp : Std_Op<"alloc"> {
   }];
 
   let arguments = (ins Variadic<Index>:$value);
-  let results = (outs MemRef<AnyType>);
+  let results = (outs AnyMemRef);
 
   let builders = [OpBuilder<
     "Builder *builder, OperationState *result, MemRefType memrefType", [{
@@ -303,7 +303,7 @@ def DeallocOp : Std_Op<"dealloc"> {
       dealloc %0 : memref<8x64xf32, (d0, d1) -> (d0, d1), 1>
   }];
 
-  let arguments = (ins MemRef<AnyType>:$memref);
+  let arguments = (ins AnyMemRef:$memref);
 
   let hasCanonicalizer = 1;
 }
@@ -318,7 +318,7 @@ def DimOp : Std_Op<"dim", [NoSideEffect]> {
       %1 = dim %0, 2 : tensor<?x?x?xf32>
   }];
 
-  let arguments = (ins AnyTypeOf<[MemRef<AnyType>, AnyTensor],
+  let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor],
                                  "any tensor or memref type">:$memrefOrTensor,
                        APIntAttr:$index);
   let results = (outs Index);
@@ -410,8 +410,8 @@ def MemRefCastOp : CastOp<"memref_cast"> {
        %3 = memref_cast %1 : memref<4xf32> to memref<?xf32>
   }];
 
-  let arguments = (ins MemRef<AnyType>);
-  let results = (outs MemRef<AnyType>);
+  let arguments = (ins AnyMemRef);
+  let results = (outs AnyMemRef);
 
   let extraClassDeclaration = [{
     /// Return true if `a` and `b` are valid operand and result pairs for
index c1dab2a..97a3689 100644 (file)
@@ -112,7 +112,7 @@ def OpJ: NS_Op<"op_for_TCopVTEtAreSameAt", [
 // CHECK:   return emitOpError("failed to verify that operands indexed at 0, 2, 3 should all have the same type");
 
 def OpK : NS_Op<"op_for_AnyTensorOf", []> {
-  let arguments = (ins AnyTensorOf<[F32, I32]>:$x);
+  let arguments = (ins TensorOf<[F32, I32]>:$x);
 }
 
 // CHECK-LABEL: OpK::verify