Support TF Variant type in the tf/mlir roundtrip pass.
authorFeng Liu <fengliuai@google.com>
Thu, 20 Sep 2018 04:15:43 +0000 (21:15 -0700)
committerjpienaar <jpienaar@google.com>
Fri, 29 Mar 2019 20:16:18 +0000 (13:16 -0700)
PiperOrigin-RevId: 213748573

mlir/include/mlir/IR/Builders.h
mlir/include/mlir/IR/Types.h
mlir/lib/IR/AsmPrinter.cpp
mlir/lib/IR/Builders.cpp
mlir/lib/IR/Types.cpp
mlir/lib/Parser/Parser.cpp
mlir/lib/Parser/TokenKinds.def

index b71808476fef4c0ab9ff0b53b3532cc9a92d3668..82a3bbf3be7896a18e5fa56e71d47dec9045b208 100644 (file)
@@ -77,6 +77,7 @@ public:
   OtherType *getTFControlType();
   OtherType *getTFStringType();
   OtherType *getTFResourceType();
+  OtherType *getTFVariantType();
   IntegerType *getIntegerType(unsigned width);
   FunctionType *getFunctionType(ArrayRef<Type *> inputs,
                                 ArrayRef<Type *> results);
index 04a1e26f7ef53049637b7c3dc0d8afdded1b1dff..53f03e5944247cbaf2c2c6e6fb00c2677d46adb9 100644 (file)
@@ -41,6 +41,7 @@ public:
     // TensorFlow types.
     TFControl,
     TFResource,
+    TFVariant,
     TFString,
 
     /// These are marker for the first and last 'other' type.
@@ -75,6 +76,7 @@ public:
   bool isAffineInt() const { return getKind() == Kind::AffineInt; }
   bool isTFControl() const { return getKind() == Kind::TFControl; }
   bool isTFResource() const { return getKind() == Kind::TFResource; }
+  bool isTFVariant() const { return getKind() == Kind::TFVariant; }
   bool isTFString() const { return getKind() == Kind::TFString; }
   bool isBF16() const { return getKind() == Kind::BF16; }
   bool isF16() const { return getKind() == Kind::F16; }
@@ -94,6 +96,7 @@ public:
   static OtherType *getTFControl(MLIRContext *ctx);
   static OtherType *getTFString(MLIRContext *ctx);
   static OtherType *getTFResource(MLIRContext *ctx);
+  static OtherType *getTFVariant(MLIRContext *ctx);
 
   /// Print the current type.
   void print(raw_ostream &os) const;
@@ -224,6 +227,9 @@ inline OtherType *Type::getTFResource(MLIRContext *ctx) {
 inline OtherType *Type::getTFString(MLIRContext *ctx) {
   return OtherType::get(Kind::TFString, ctx);
 }
+inline OtherType *Type::getTFVariant(MLIRContext *ctx) {
+  return OtherType::get(Kind::TFVariant, ctx);
+}
 
 /// Function types map from a list of inputs to a list of results.
 class FunctionType : public Type {
@@ -432,6 +438,12 @@ private:
   ~MemRefType() = delete;
 };
 
+/// Return true if the specified element type is ok in a tensor.
+static bool isValidTensorElementType(Type *type) {
+  return isa<FloatType>(type) || isa<VectorType>(type) ||
+         isa<IntegerType>(type) || isa<OtherType>(type);
+}
+
 } // end namespace mlir
 
 #endif  // MLIR_IR_TYPES_H
index e9104472ed661406a88018f8a121ca98a1b99832..a64fd06dc70f934f17c0741715e8889db9a7f66e 100644 (file)
@@ -480,6 +480,9 @@ void ModulePrinter::printType(const Type *type) {
   case Type::Kind::TFResource:
     os << "tf_resource";
     return;
+  case Type::Kind::TFVariant:
+    os << "tf_variant";
+    return;
   case Type::Kind::TFString:
     os << "tf_string";
     return;
index bc8f82fff140b8f46d59811787d35ddc2fabb87f..c4f70934f8b0faff383cb23b8094ce4e285b23db 100644 (file)
@@ -66,6 +66,8 @@ OtherType *Builder::getTFControlType() { return Type::getTFControl(context); }
 
 OtherType *Builder::getTFResourceType() { return Type::getTFResource(context); }
 
+OtherType *Builder::getTFVariantType() { return Type::getTFVariant(context); }
+
 OtherType *Builder::getTFStringType() { return Type::getTFString(context); }
 
 IntegerType *Builder::getIntegerType(unsigned width) {
index 400ef0037228374485286ec932d14df354a7e695..2eedf8a4ace9e56540aa03cb6be663285fa7bbe2 100644 (file)
@@ -60,16 +60,9 @@ VectorType::VectorType(ArrayRef<unsigned> shape, Type *elementType,
     : VectorOrTensorType(Kind::Vector, context, elementType, shape.size()),
       shapeElements(shape.data()) {}
 
-/// Return true if the specified element type is ok in a tensor.
-static bool isValidTensorElementType(Type *type, MLIRContext *context) {
-  return isa<FloatType>(type) || isa<VectorType>(type) ||
-         isa<IntegerType>(type) || type == Type::getTFString(context) ||
-         type == Type::getTFResource(context);
-}
-
 TensorType::TensorType(Kind kind, Type *elementType, MLIRContext *context)
     : VectorOrTensorType(kind, context, elementType) {
-  assert(isValidTensorElementType(elementType, context));
+  assert(isValidTensorElementType(elementType));
 }
 
 RankedTensorType::RankedTensorType(ArrayRef<int> shape, Type *elementType,
index e68e86cae0f6844b5672865c2b83e723f897aec4..ed0dd69bbc44725fec688a3ebb09b5b47bf432f3 100644 (file)
@@ -336,6 +336,9 @@ Type *Parser::parseType() {
   case Token::kw_tf_resource:
     consumeToken(Token::kw_tf_resource);
     return builder.getTFResourceType();
+  case Token::kw_tf_variant:
+    consumeToken(Token::kw_tf_variant);
+    return builder.getTFVariantType();
   case Token::kw_tf_string:
     consumeToken(Token::kw_tf_string);
     return builder.getTFStringType();
@@ -468,8 +471,7 @@ Type *Parser::parseTensorType() {
   if (!elementType || parseToken(Token::greater, "expected '>' in tensor type"))
     return nullptr;
 
-  if (!isa<IntegerType>(elementType) && !isa<FloatType>(elementType) &&
-      !isa<VectorType>(elementType))
+  if (!isValidTensorElementType(elementType))
     return (emitError(typeLoc, "invalid tensor element type"), nullptr);
 
   if (isUnranked)
index 1a98eed90c3ddb80239bab10d7ad1bcd655e1e95..3e88d764729137ba4b8d80e3645391fa853452f4 100644 (file)
@@ -114,6 +114,7 @@ TOK_KEYWORD(step)
 TOK_KEYWORD(tensor)
 TOK_KEYWORD(tf_control)
 TOK_KEYWORD(tf_resource)
+TOK_KEYWORD(tf_variant)
 TOK_KEYWORD(tf_string)
 TOK_KEYWORD(to)
 TOK_KEYWORD(true)