OtherType *getTFControlType();
OtherType *getTFStringType();
OtherType *getTFResourceType();
+ OtherType *getTFVariantType();
IntegerType *getIntegerType(unsigned width);
FunctionType *getFunctionType(ArrayRef<Type *> inputs,
ArrayRef<Type *> results);
// TensorFlow types.
TFControl,
TFResource,
+ TFVariant,
TFString,
/// These are marker for the first and last 'other' type.
bool isAffineInt() const { return getKind() == Kind::AffineInt; }
bool isTFControl() const { return getKind() == Kind::TFControl; }
bool isTFResource() const { return getKind() == Kind::TFResource; }
+ bool isTFVariant() const { return getKind() == Kind::TFVariant; }
bool isTFString() const { return getKind() == Kind::TFString; }
bool isBF16() const { return getKind() == Kind::BF16; }
bool isF16() const { return getKind() == Kind::F16; }
static OtherType *getTFControl(MLIRContext *ctx);
static OtherType *getTFString(MLIRContext *ctx);
static OtherType *getTFResource(MLIRContext *ctx);
+ static OtherType *getTFVariant(MLIRContext *ctx);
/// Print the current type.
void print(raw_ostream &os) const;
inline OtherType *Type::getTFString(MLIRContext *ctx) {
return OtherType::get(Kind::TFString, ctx);
}
+inline OtherType *Type::getTFVariant(MLIRContext *ctx) {
+ return OtherType::get(Kind::TFVariant, ctx);
+}
/// Function types map from a list of inputs to a list of results.
class FunctionType : public Type {
~MemRefType() = delete;
};
+/// Return true if the specified element type is ok in a tensor.
+static bool isValidTensorElementType(Type *type) {
+ return isa<FloatType>(type) || isa<VectorType>(type) ||
+ isa<IntegerType>(type) || isa<OtherType>(type);
+}
+
} // end namespace mlir
#endif // MLIR_IR_TYPES_H
case Type::Kind::TFResource:
os << "tf_resource";
return;
+ case Type::Kind::TFVariant:
+ os << "tf_variant";
+ return;
case Type::Kind::TFString:
os << "tf_string";
return;
OtherType *Builder::getTFResourceType() { return Type::getTFResource(context); }
+OtherType *Builder::getTFVariantType() { return Type::getTFVariant(context); }
+
OtherType *Builder::getTFStringType() { return Type::getTFString(context); }
IntegerType *Builder::getIntegerType(unsigned width) {
: VectorOrTensorType(Kind::Vector, context, elementType, shape.size()),
shapeElements(shape.data()) {}
-/// Return true if the specified element type is ok in a tensor.
-static bool isValidTensorElementType(Type *type, MLIRContext *context) {
- return isa<FloatType>(type) || isa<VectorType>(type) ||
- isa<IntegerType>(type) || type == Type::getTFString(context) ||
- type == Type::getTFResource(context);
-}
-
TensorType::TensorType(Kind kind, Type *elementType, MLIRContext *context)
: VectorOrTensorType(kind, context, elementType) {
- assert(isValidTensorElementType(elementType, context));
+ assert(isValidTensorElementType(elementType));
}
RankedTensorType::RankedTensorType(ArrayRef<int> shape, Type *elementType,
case Token::kw_tf_resource:
consumeToken(Token::kw_tf_resource);
return builder.getTFResourceType();
+ case Token::kw_tf_variant:
+ consumeToken(Token::kw_tf_variant);
+ return builder.getTFVariantType();
case Token::kw_tf_string:
consumeToken(Token::kw_tf_string);
return builder.getTFStringType();
if (!elementType || parseToken(Token::greater, "expected '>' in tensor type"))
return nullptr;
- if (!isa<IntegerType>(elementType) && !isa<FloatType>(elementType) &&
- !isa<VectorType>(elementType))
+ if (!isValidTensorElementType(elementType))
return (emitError(typeLoc, "invalid tensor element type"), nullptr);
if (isUnranked)
TOK_KEYWORD(tensor)
TOK_KEYWORD(tf_control)
TOK_KEYWORD(tf_resource)
+TOK_KEYWORD(tf_variant)
TOK_KEYWORD(tf_string)
TOK_KEYWORD(to)
TOK_KEYWORD(true)