From 17d3acf40c764b2c06ffcce3cab3151a4b5d09b4 Mon Sep 17 00:00:00 2001 From: River Riddle Date: Tue, 30 Apr 2019 14:26:04 -0700 Subject: [PATCH] Ensure that every Attribute contains a Type. If an Attribute does not provide a type explicitly, the type is defaulted to NoneType. -- PiperOrigin-RevId: 246021088 --- mlir/include/mlir/IR/Attributes.h | 13 +--- mlir/lib/IR/AttributeDetail.h | 122 +++++++++++++++++++++----------------- mlir/lib/IR/Attributes.cpp | 36 ++++------- mlir/lib/StandardOps/Ops.cpp | 19 +----- 4 files changed, 83 insertions(+), 107 deletions(-) diff --git a/mlir/include/mlir/IR/Attributes.h b/mlir/include/mlir/IR/Attributes.h index 5deb2bb..9d09c38 100644 --- a/mlir/include/mlir/IR/Attributes.h +++ b/mlir/include/mlir/IR/Attributes.h @@ -47,7 +47,6 @@ struct AffineMapAttributeStorage; struct IntegerSetAttributeStorage; struct TypeAttributeStorage; struct FunctionAttributeStorage; -struct ElementsAttributeStorage; struct SplatElementsAttributeStorage; struct DenseElementsAttributeStorage; struct DenseIntElementsAttributeStorage; @@ -125,6 +124,9 @@ public: /// Return the classification for this attribute. Kind getKind() const; + /// Return the type of this attribute. + Type getType() const; + /// Return true if this field is, or contains, a function attribute. bool isOrContainsFunction() const; @@ -177,8 +179,6 @@ class NumericAttr : public Attribute { public: using Attribute::Attribute; - Type getType() const; - static bool kindof(Kind kind); }; @@ -192,8 +192,6 @@ public: bool getValue() const; - Type getType() const; - /// Methods for support type inquiry through isa, cast, and dyn_cast. static bool kindof(Kind kind) { return kind == Kind::Bool; } }; @@ -211,8 +209,6 @@ public: // TODO(jpienaar): Change callers to use getValue instead. int64_t getInt() const; - Type getType() const; - /// Methods for support type inquiry through isa, cast, and dyn_cast. static bool kindof(Kind kind) { return kind == Kind::Integer; } }; @@ -238,8 +234,6 @@ public: /// precision. double getValueAsDouble() const; - Type getType() const; - /// Methods for support type inquiry through isa, cast, and dyn_cast. static bool kindof(Kind kind) { return kind == Kind::Float; } }; @@ -353,7 +347,6 @@ public: class ElementsAttr : public NumericAttr { public: using NumericAttr::NumericAttr; - using ImplType = detail::ElementsAttributeStorage; VectorOrTensorType getType() const; diff --git a/mlir/lib/IR/AttributeDetail.h b/mlir/lib/IR/AttributeDetail.h index 89ac240..3db5ec8 100644 --- a/mlir/lib/IR/AttributeDetail.h +++ b/mlir/lib/IR/AttributeDetail.h @@ -24,23 +24,43 @@ #include "mlir/IR/AffineMap.h" #include "mlir/IR/Attributes.h" +#include "mlir/IR/Function.h" #include "mlir/IR/Identifier.h" #include "mlir/IR/IntegerSet.h" #include "mlir/IR/MLIRContext.h" #include "mlir/IR/StandardTypes.h" #include "mlir/Support/StorageUniquer.h" #include "llvm/ADT/APFloat.h" +#include "llvm/ADT/PointerIntPair.h" #include "llvm/Support/TrailingObjects.h" namespace mlir { namespace detail { -/// Base storage class appearing in an attribute. +/// Base storage class appearing in an attribute. Derived storage classes should +/// only be constructed within the context of the AttributeUniquer. struct AttributeStorage : public StorageUniquer::BaseStorage { - AttributeStorage(bool isOrContainsFunctionCache = false) - : isOrContainsFunctionCache(isOrContainsFunctionCache) {} - - /// This field is true if this is, or contains, a function attribute. - bool isOrContainsFunctionCache : 1; + /// Construct a new attribute storage instance with the given type and a + /// boolean that signals if the derived attribute is or contains a function + /// pointer. + /// Note: All attributes require a valid type. If a null type is provided + /// here, the type of the attribute will automatically default to + /// NoneType upon initialization in the uniquer. + AttributeStorage(Type type = {}, bool isOrContainsFunctionCache = false) + : typeAndContainsFunctionAttrPair(type, isOrContainsFunctionCache) {} + AttributeStorage(bool isOrContainsFunctionCache) + : AttributeStorage(/*type=*/{}, isOrContainsFunctionCache) {} + + bool isOrContainsFunctionCache() const { + return typeAndContainsFunctionAttrPair.getInt(); + } + + Type getType() const { return typeAndContainsFunctionAttrPair.getPointer(); } + void setType(Type type) { typeAndContainsFunctionAttrPair.setPointer(type); } + + /// This field is a pair of: + /// - The type of the attribute value. + /// - A boolean that is true if this is, or contains, a function attribute. + llvm::PointerIntPair typeAndContainsFunctionAttrPair; }; // A utility class to get, or create, unique instances of attributes within an @@ -54,7 +74,7 @@ public: !std::is_same::value, T>::type get(MLIRContext *ctx, Attribute::Kind kind, Args &&... args) { return ctx->getAttributeUniquer().getComplex( - /*initFn=*/{}, static_cast(kind), + getInitFn(ctx), static_cast(kind), std::forward(args)...); } @@ -66,7 +86,7 @@ public: std::is_same::value, T>::type get(MLIRContext *ctx, Attribute::Kind kind) { return ctx->getAttributeUniquer().getSimple( - /*initFn=*/{}, static_cast(kind)); + getInitFn(ctx), static_cast(kind)); } /// Erase a uniqued instance of attribute T. This overload is used for @@ -78,6 +98,15 @@ public: return ctx->getAttributeUniquer().eraseComplex( static_cast(kind), std::forward(args)...); } + + /// Generate a functor to initialize a new attribute storage instance. + static std::function getInitFn(MLIRContext *ctx) { + return [ctx](AttributeStorage *storage) { + // If the attribute did not provide a type, then default to NoneType. + if (!storage->getType()) + storage->setType(NoneType::get(ctx)); + }; + } }; using AttributeStorageAllocator = StorageUniquer::StorageAllocator; @@ -86,7 +115,8 @@ using AttributeStorageAllocator = StorageUniquer::StorageAllocator; struct BoolAttributeStorage : public AttributeStorage { using KeyTy = std::pair; - BoolAttributeStorage(Type type, bool value) : type(type), value(value) {} + BoolAttributeStorage(Type type, bool value) + : AttributeStorage(type), value(value) {} /// We only check equality for and hash with the boolean key parameter. bool operator==(const KeyTy &key) const { return key.second == value; } @@ -100,7 +130,6 @@ struct BoolAttributeStorage : public AttributeStorage { BoolAttributeStorage(IntegerType::get(1, key.first), key.second); } - Type type; bool value; }; @@ -111,13 +140,13 @@ struct IntegerAttributeStorage final using KeyTy = std::pair; IntegerAttributeStorage(Type type, size_t numObjects) - : type(type), numObjects(numObjects) { + : AttributeStorage(type), numObjects(numObjects) { assert((type.isIndex() || type.isa()) && "invalid type"); } /// Key equality and hash functions. bool operator==(const KeyTy &key) const { - return key == KeyTy(type, getValue()); + return key == KeyTy(getType(), getValue()); } static unsigned hashKey(const KeyTy &key) { return llvm::hash_combine(key.first, llvm::hash_value(key.second)); @@ -142,13 +171,12 @@ struct IntegerAttributeStorage final /// Returns an APInt representing the stored value. APInt getValue() const { - if (type.isIndex()) + if (getType().isIndex()) return APInt(64, {getTrailingObjects(), numObjects}); - return APInt(type.getIntOrFloatBitWidth(), + return APInt(getType().getIntOrFloatBitWidth(), {getTrailingObjects(), numObjects}); } - Type type; size_t numObjects; }; @@ -160,12 +188,11 @@ struct FloatAttributeStorage final FloatAttributeStorage(const llvm::fltSemantics &semantics, Type type, size_t numObjects) - : semantics(semantics), type(type.cast()), - numObjects(numObjects) {} + : AttributeStorage(type), semantics(semantics), numObjects(numObjects) {} /// Key equality and hash functions. bool operator==(const KeyTy &key) const { - return key.first == type && key.second.bitwiseIsEqual(getValue()); + return key.first == getType() && key.second.bitwiseIsEqual(getValue()); } static unsigned hashKey(const KeyTy &key) { return llvm::hash_combine(key.first, llvm::hash_value(key.second)); @@ -197,7 +224,6 @@ struct FloatAttributeStorage final } const llvm::fltSemantics &semantics; - FloatType type; size_t numObjects; }; @@ -307,7 +333,8 @@ struct FunctionAttributeStorage : public AttributeStorage { using KeyTy = Function *; FunctionAttributeStorage(Function *value) - : AttributeStorage(/*isOrContainsFunctionCache=*/true), value(value) {} + : AttributeStorage(value->getType(), /*isOrContainsFunctionCache=*/true), + value(value) {} /// Key equality function. bool operator==(const KeyTy &key) const { return key == value; } @@ -329,26 +356,17 @@ struct FunctionAttributeStorage : public AttributeStorage { Function *value; }; -/// A base attribute representing a reference to a vector or tensor constant. -struct ElementsAttributeStorage : public AttributeStorage { - ElementsAttributeStorage(VectorOrTensorType type) : type(type) {} - VectorOrTensorType type; -}; - /// An attribute representing a reference to a vector or tensor constant, /// inwhich all elements have the same value. -struct SplatElementsAttributeStorage : public ElementsAttributeStorage { - using KeyTy = std::pair; +struct SplatElementsAttributeStorage : public AttributeStorage { + using KeyTy = std::pair; - SplatElementsAttributeStorage(VectorOrTensorType type, Attribute elt) - : ElementsAttributeStorage(type), elt(elt) {} + SplatElementsAttributeStorage(Type type, Attribute elt) + : AttributeStorage(type), elt(elt) {} /// Key equality and hash functions. bool operator==(const KeyTy &key) const { - return key == std::make_pair(type, elt); - } - static unsigned hashKey(const KeyTy &key) { - return llvm::hash_combine(key.first, key.second); + return key == std::make_pair(getType(), elt); } /// Construct a new storage instance. @@ -362,16 +380,15 @@ struct SplatElementsAttributeStorage : public ElementsAttributeStorage { }; /// An attribute representing a reference to a dense vector or tensor object. -struct DenseElementsAttributeStorage : public ElementsAttributeStorage { - using KeyTy = std::pair>; +struct DenseElementsAttributeStorage : public AttributeStorage { + using KeyTy = std::pair>; - DenseElementsAttributeStorage(VectorOrTensorType ty, ArrayRef data) - : ElementsAttributeStorage(ty), data(data) {} + DenseElementsAttributeStorage(Type ty, ArrayRef data) + : AttributeStorage(ty), data(data) {} /// Key equality and hash functions. - bool operator==(const KeyTy &key) const { return key == KeyTy(type, data); } - static unsigned hashKey(const KeyTy &key) { - return llvm::hash_combine(key.first, key.second); + bool operator==(const KeyTy &key) const { + return key == KeyTy(getType(), data); } /// Construct a new storage instance. @@ -398,16 +415,15 @@ struct DenseElementsAttributeStorage : public ElementsAttributeStorage { /// An attribute representing a reference to a tensor constant with opaque /// content. -struct OpaqueElementsAttributeStorage : public ElementsAttributeStorage { - using KeyTy = std::tuple; +struct OpaqueElementsAttributeStorage : public AttributeStorage { + using KeyTy = std::tuple; - OpaqueElementsAttributeStorage(VectorOrTensorType type, Dialect *dialect, - StringRef bytes) - : ElementsAttributeStorage(type), dialect(dialect), bytes(bytes) {} + OpaqueElementsAttributeStorage(Type type, Dialect *dialect, StringRef bytes) + : AttributeStorage(type), dialect(dialect), bytes(bytes) {} /// Key equality and hash functions. bool operator==(const KeyTy &key) const { - return key == std::make_tuple(type, dialect, bytes); + return key == std::make_tuple(getType(), dialect, bytes); } static unsigned hashKey(const KeyTy &key) { return llvm::hash_combine(std::get<0>(key), std::get<1>(key), @@ -429,18 +445,16 @@ struct OpaqueElementsAttributeStorage : public ElementsAttributeStorage { }; /// An attribute representing a reference to a sparse vector or tensor object. -struct SparseElementsAttributeStorage : public ElementsAttributeStorage { - using KeyTy = - std::tuple; +struct SparseElementsAttributeStorage : public AttributeStorage { + using KeyTy = std::tuple; - SparseElementsAttributeStorage(VectorOrTensorType type, - DenseIntElementsAttr indices, + SparseElementsAttributeStorage(Type type, DenseIntElementsAttr indices, DenseElementsAttr values) - : ElementsAttributeStorage(type), indices(indices), values(values) {} + : AttributeStorage(type), indices(indices), values(values) {} /// Key equality and hash functions. bool operator==(const KeyTy &key) const { - return key == std::make_tuple(type, indices, values); + return key == std::make_tuple(getType(), indices, values); } static unsigned hashKey(const KeyTy &key) { return llvm::hash_combine(std::get<0>(key), std::get<1>(key), diff --git a/mlir/lib/IR/Attributes.cpp b/mlir/lib/IR/Attributes.cpp index 4df46df..aa1127c 100644 --- a/mlir/lib/IR/Attributes.cpp +++ b/mlir/lib/IR/Attributes.cpp @@ -30,8 +30,11 @@ Attribute::Kind Attribute::getKind() const { return static_cast(attr->getKind()); } +/// Return the type of this attribute. +Type Attribute::getType() const { return attr->getType(); } + bool Attribute::isOrContainsFunction() const { - return attr->isOrContainsFunctionCache; + return attr->isOrContainsFunctionCache(); } // Given an attribute that could refer to a function attribute in the remapping @@ -79,19 +82,6 @@ UnitAttr UnitAttr::get(MLIRContext *context) { // NumericAttr //===----------------------------------------------------------------------===// -Type NumericAttr::getType() const { - if (auto boolAttr = dyn_cast()) - return boolAttr.getType(); - if (auto intAttr = dyn_cast()) - return intAttr.getType(); - if (auto floatAttr = dyn_cast()) - return floatAttr.getType(); - if (auto elemAttr = dyn_cast()) - return elemAttr.getType(); - - llvm_unreachable("unhandled NumericAttr subclass"); -} - bool NumericAttr::kindof(Kind kind) { return BoolAttr::kindof(kind) || IntegerAttr::kindof(kind) || FloatAttr::kindof(kind) || ElementsAttr::kindof(kind); @@ -109,8 +99,6 @@ BoolAttr BoolAttr::get(bool value, MLIRContext *context) { bool BoolAttr::getValue() const { return static_cast(attr)->value; } -Type BoolAttr::getType() const { return static_cast(attr)->type; } - //===----------------------------------------------------------------------===// // IntegerAttr //===----------------------------------------------------------------------===// @@ -135,10 +123,6 @@ APInt IntegerAttr::getValue() const { int64_t IntegerAttr::getInt() const { return getValue().getSExtValue(); } -Type IntegerAttr::getType() const { - return static_cast(attr)->type; -} - //===----------------------------------------------------------------------===// // FloatAttr //===----------------------------------------------------------------------===// @@ -185,8 +169,6 @@ APFloat FloatAttr::getValue() const { return static_cast(attr)->getValue(); } -Type FloatAttr::getType() const { return static_cast(attr)->type; } - double FloatAttr::getValueAsDouble() const { const auto &semantics = getType().cast().getFloatSemantics(); auto value = getValue(); @@ -281,14 +263,16 @@ Function *FunctionAttr::getValue() const { return static_cast(attr)->value; } -FunctionType FunctionAttr::getType() const { return getValue()->getType(); } +FunctionType FunctionAttr::getType() const { + return Attribute::getType().cast(); +} //===----------------------------------------------------------------------===// // ElementsAttr //===----------------------------------------------------------------------===// VectorOrTensorType ElementsAttr::getType() const { - return static_cast(attr)->type; + return Attribute::getType().cast(); } /// Return the value at the given index. If index does not refer to a valid @@ -315,8 +299,8 @@ Attribute ElementsAttr::getValue(ArrayRef index) const { SplatElementsAttr SplatElementsAttr::get(VectorOrTensorType type, Attribute elt) { - assert(elt.cast().getType() == type.getElementType() && - "value should be of the given type"); + assert(elt.getType() == type.getElementType() && + "value should be of the given element type"); return AttributeUniquer::get( type.getContext(), Attribute::Kind::SplatElements, type, elt); } diff --git a/mlir/lib/StandardOps/Ops.cpp b/mlir/lib/StandardOps/Ops.cpp index 3aa35cd..a05d83d 100644 --- a/mlir/lib/StandardOps/Ops.cpp +++ b/mlir/lib/StandardOps/Ops.cpp @@ -920,24 +920,11 @@ void ConstantOp::build(Builder *builder, OperationState *result, Type type, result->types.push_back(type); } -// Extracts and returns a type of an attribute if it has one. Returns a null -// type otherwise. Currently, NumericAttrs and FunctionAttrs have types. -static Type getAttributeType(Attribute attr) { - assert(attr && "expected non-null attribute"); - if (auto numericAttr = attr.dyn_cast()) - return numericAttr.getType(); - if (auto functionAttr = attr.dyn_cast()) - return functionAttr.getType(); - return {}; -} - /// Builds a constant with the specified attribute value and type extracted /// from the attribute. The attribute must have a type. void ConstantOp::build(Builder *builder, OperationState *result, Attribute value) { - Type t = getAttributeType(value); - assert(t && "expected an attribute with a type"); - return build(builder, result, t, value); + return build(builder, result, value.getType(), value); } void ConstantOp::print(OpAsmPrinter *p) { @@ -1018,9 +1005,7 @@ LogicalResult ConstantOp::verify() { return success(); } - auto attrType = getAttributeType(value); - if (!attrType) - return emitOpError("requires 'value' attribute to have a type"); + auto attrType = value.getType(); if (attrType != type) return emitOpError("requires the type of the 'value' attribute to match " "that of the operation result"); -- 2.7.4