From 983e0eea9532bdb54a945d522da8d5c1c55a6256 Mon Sep 17 00:00:00 2001 From: River Riddle Date: Mon, 6 May 2019 12:40:43 -0700 Subject: [PATCH] Simplify several usages of attributes now that they always have a type and, transitively, access to the context. This also fixes a bug where FunctionAttrs were not being remapped for function and function argument attributes. -- PiperOrigin-RevId: 246876924 --- mlir/bindings/python/pybind.cpp | 4 +-- mlir/include/mlir/IR/Attributes.h | 18 +++++++----- mlir/include/mlir/IR/Function.h | 17 ++++++----- mlir/include/mlir/IR/Operation.h | 9 +++--- mlir/lib/IR/AttributeDetail.h | 3 +- mlir/lib/IR/Attributes.cpp | 49 +++++++++++++++++-------------- mlir/lib/IR/Builders.cpp | 6 ++-- mlir/lib/IR/Function.cpp | 6 ++-- mlir/lib/IR/MLIRContext.cpp | 6 ++-- mlir/lib/IR/Operation.cpp | 2 +- mlir/lib/Transforms/DialectConversion.cpp | 4 +-- mlir/lib/Transforms/Utils/Utils.cpp | 27 ++++++++++++----- mlir/test/IR/parser.mlir | 12 ++++++++ 13 files changed, 95 insertions(+), 68 deletions(-) diff --git a/mlir/bindings/python/pybind.cpp b/mlir/bindings/python/pybind.cpp index 5f90a71..720f381e 100644 --- a/mlir/bindings/python/pybind.cpp +++ b/mlir/bindings/python/pybind.cpp @@ -571,7 +571,7 @@ PythonMLIRModule::declareFunction(const std::string &name, inAttrs.emplace_back(Identifier::get(named.name, &mlirContext), mlir::Attribute::getFromOpaquePointer( reinterpret_cast(named.value))); - inputAttrs.emplace_back(&mlirContext, inAttrs); + inputAttrs.emplace_back(inAttrs); } // Create the function itself. @@ -634,7 +634,7 @@ PYBIND11_MODULE(pybind, m) { }); m.def("constant_function", [](PythonFunction func) -> PythonValueHandle { auto *function = reinterpret_cast(func.function); - auto attr = FunctionAttr::get(function, function->getContext()); + auto attr = FunctionAttr::get(function); return ValueHandle::create(function->getType(), attr); }); m.def("appendTo", [](const PythonBlockHandle &handle) { diff --git a/mlir/include/mlir/IR/Attributes.h b/mlir/include/mlir/IR/Attributes.h index 2ff4937..56b8c7b 100644 --- a/mlir/include/mlir/IR/Attributes.h +++ b/mlir/include/mlir/IR/Attributes.h @@ -127,6 +127,9 @@ public: /// Return the type of this attribute. Type getType() const; + /// Return the context this attribute belongs to. + MLIRContext *getContext() const; + /// Return true if this field is, or contains, a function attribute. bool isOrContainsFunction() const; @@ -135,8 +138,7 @@ public: /// remapping table. Return the original attribute if it (or any of nested /// attributes) is not present in the table. Attribute remapFunctionAttrs( - const llvm::DenseMap &remappingTable, - MLIRContext *context) const; + const llvm::DenseMap &remappingTable) const; /// Print the attribute. void print(raw_ostream &os) const; @@ -299,7 +301,7 @@ public: using ImplType = detail::TypeAttributeStorage; using ValueType = Type; - static TypeAttr get(Type type, MLIRContext *context); + static TypeAttr get(Type value); Type getValue() const; @@ -320,7 +322,7 @@ public: using ImplType = detail::FunctionAttributeStorage; using ValueType = Function *; - static FunctionAttr get(Function *value, MLIRContext *context); + static FunctionAttr get(Function *value); Function *getValue() const; @@ -642,13 +644,13 @@ using NamedAttribute = std::pair; class NamedAttributeList { public: NamedAttributeList() : attrs(nullptr) {} - NamedAttributeList(MLIRContext *context, ArrayRef attributes); + NamedAttributeList(ArrayRef attributes); /// Return all of the attributes on this operation. ArrayRef getAttrs() const; /// Replace the held attributes with ones provided in 'newAttrs'. - void setAttrs(MLIRContext *context, ArrayRef attributes); + void setAttrs(ArrayRef attributes); /// Return the specified attribute if present, null otherwise. Attribute get(StringRef name) const; @@ -656,13 +658,13 @@ public: /// If the an attribute exists with the specified name, change it to the new /// value. Otherwise, add a new attribute with the specified name/value. - void set(MLIRContext *context, Identifier name, Attribute value); + void set(Identifier name, Attribute value); enum class RemoveResult { Removed, NotFound }; /// Remove the attribute with the specified name if it exists. The return /// value indicates whether the attribute was present or not. - RemoveResult remove(MLIRContext *context, Identifier name); + RemoveResult remove(Identifier name); private: detail::AttributeListStorage *attrs; diff --git a/mlir/include/mlir/IR/Function.h b/mlir/include/mlir/IR/Function.h index 8a5b28b..6860e1c 100644 --- a/mlir/include/mlir/IR/Function.h +++ b/mlir/include/mlir/IR/Function.h @@ -157,6 +157,9 @@ public: /// Return all of the attributes on this function. ArrayRef getAttrs() { return attrs.getAttrs(); } + /// Return the internal attribute list on this function. + NamedAttributeList &getAttrList() { return attrs; } + /// Return all of the attributes for the argument at 'index'. ArrayRef getArgAttrs(unsigned index) { assert(index < getNumArguments() && "invalid argument number"); @@ -165,13 +168,13 @@ public: /// Set the attributes held by this function. void setAttrs(ArrayRef attributes) { - attrs.setAttrs(getContext(), attributes); + attrs.setAttrs(attributes); } /// Set the attributes held by the argument at 'index'. void setArgAttrs(unsigned index, ArrayRef attributes) { assert(index < getNumArguments() && "invalid argument number"); - argAttrs[index].setAttrs(getContext(), attributes); + argAttrs[index].setAttrs(attributes); } /// Return all argument attributes of this function. @@ -212,15 +215,13 @@ public: /// If the an attribute exists with the specified name, change it to the new /// value. Otherwise, add a new attribute with the specified name/value. - void setAttr(Identifier name, Attribute value) { - attrs.set(getContext(), name, value); - } + void setAttr(Identifier name, Attribute value) { attrs.set(name, value); } void setAttr(StringRef name, Attribute value) { setAttr(Identifier::get(name, getContext()), value); } void setArgAttr(unsigned index, Identifier name, Attribute value) { assert(index < getNumArguments() && "invalid argument number"); - argAttrs[index].set(getContext(), name, value); + argAttrs[index].set(name, value); } void setArgAttr(unsigned index, StringRef name, Attribute value) { setArgAttr(index, Identifier::get(name, getContext()), value); @@ -229,12 +230,12 @@ public: /// Remove the attribute with the specified name if it exists. The return /// value indicates whether the attribute was present or not. NamedAttributeList::RemoveResult removeAttr(Identifier name) { - return attrs.remove(getContext(), name); + return attrs.remove(name); } NamedAttributeList::RemoveResult removeArgAttr(unsigned index, Identifier name) { assert(index < getNumArguments() && "invalid argument number"); - return attrs.remove(getContext(), name); + return attrs.remove(name); } //===--------------------------------------------------------------------===// diff --git a/mlir/include/mlir/IR/Operation.h b/mlir/include/mlir/IR/Operation.h index 0f1f3e9..9f605aa 100644 --- a/mlir/include/mlir/IR/Operation.h +++ b/mlir/include/mlir/IR/Operation.h @@ -239,6 +239,9 @@ public: /// Return all of the attributes on this operation. ArrayRef getAttrs() { return attrs.getAttrs(); } + /// Return the internal attribute list on this operation. + NamedAttributeList &getAttrList() { return attrs; } + /// Return the specified attribute if present, null otherwise. Attribute getAttr(Identifier name) { return attrs.get(name); } Attribute getAttr(StringRef name) { return attrs.get(name); } @@ -253,9 +256,7 @@ public: /// If the an attribute exists with the specified name, change it to the new /// value. Otherwise, add a new attribute with the specified name/value. - void setAttr(Identifier name, Attribute value) { - attrs.set(getContext(), name, value); - } + void setAttr(Identifier name, Attribute value) { attrs.set(name, value); } void setAttr(StringRef name, Attribute value) { setAttr(Identifier::get(name, getContext()), value); } @@ -263,7 +264,7 @@ public: /// Remove the attribute with the specified name if it exists. The return /// value indicates whether the attribute was present or not. NamedAttributeList::RemoveResult removeAttr(Identifier name) { - return attrs.remove(getContext(), name); + return attrs.remove(name); } //===--------------------------------------------------------------------===// diff --git a/mlir/lib/IR/AttributeDetail.h b/mlir/lib/IR/AttributeDetail.h index d1802f4..aab4445 100644 --- a/mlir/lib/IR/AttributeDetail.h +++ b/mlir/lib/IR/AttributeDetail.h @@ -408,8 +408,7 @@ public: /// Given a list of NamedAttribute's, canonicalize the list (sorting /// by name) and return the unique'd result. Note that the empty list is /// represented with a null pointer. - static AttributeListStorage *get(ArrayRef attrs, - MLIRContext *context); + static AttributeListStorage *get(ArrayRef attrs); /// Return the element constants for this aggregate constant. These are /// known to all be constants. diff --git a/mlir/lib/IR/Attributes.cpp b/mlir/lib/IR/Attributes.cpp index ab15699..62c3c93 100644 --- a/mlir/lib/IR/Attributes.cpp +++ b/mlir/lib/IR/Attributes.cpp @@ -67,6 +67,9 @@ Attribute::Kind Attribute::getKind() const { /// Return the type of this attribute. Type Attribute::getType() const { return attr->getType(); } +/// Return the context this attribute belongs to. +MLIRContext *Attribute::getContext() const { return getType().getContext(); } + bool Attribute::isOrContainsFunction() const { return attr->isOrContainsFunctionCache(); } @@ -75,8 +78,7 @@ bool Attribute::isOrContainsFunction() const { // table, walk it and rewrite it to use the mapped function. If it doesn't // refer to anything in the table, then it is returned unmodified. Attribute Attribute::remapFunctionAttrs( - const llvm::DenseMap &remappingTable, - MLIRContext *context) const { + const llvm::DenseMap &remappingTable) const { // Most attributes are trivially unrelated to function attributes, skip them // rapidly. if (!isOrContainsFunction()) @@ -93,7 +95,7 @@ Attribute Attribute::remapFunctionAttrs( SmallVector remappedElts; bool anyChange = false; for (auto elt : arrayAttr.getValue()) { - auto newElt = elt.remapFunctionAttrs(remappingTable, context); + auto newElt = elt.remapFunctionAttrs(remappingTable); remappedElts.push_back(newElt); anyChange |= (elt != newElt); } @@ -101,7 +103,7 @@ Attribute Attribute::remapFunctionAttrs( if (!anyChange) return *this; - return ArrayAttr::get(remappedElts, context); + return ArrayAttr::get(remappedElts, getContext()); } //===----------------------------------------------------------------------===// @@ -262,8 +264,9 @@ IntegerSet IntegerSetAttr::getValue() const { // TypeAttr //===----------------------------------------------------------------------===// -TypeAttr TypeAttr::get(Type value, MLIRContext *context) { - return AttributeUniquer::get(context, Attribute::Kind::Type, value); +TypeAttr TypeAttr::get(Type value) { + return AttributeUniquer::get(value.getContext(), + Attribute::Kind::Type, value); } Type TypeAttr::getValue() const { return static_cast(attr)->value; } @@ -272,10 +275,10 @@ Type TypeAttr::getValue() const { return static_cast(attr)->value; } // FunctionAttr //===----------------------------------------------------------------------===// -FunctionAttr FunctionAttr::get(Function *value, MLIRContext *context) { +FunctionAttr FunctionAttr::get(Function *value) { assert(value && "Cannot get FunctionAttr for a null function"); - return AttributeUniquer::get(context, Attribute::Kind::Function, - value); + return AttributeUniquer::get(value->getContext(), + Attribute::Kind::Function, value); } /// This function is used by the internals of the Function class to null out @@ -737,9 +740,8 @@ Attribute SparseElementsAttr::getValue(ArrayRef index) const { // NamedAttributeList //===----------------------------------------------------------------------===// -NamedAttributeList::NamedAttributeList(MLIRContext *context, - ArrayRef attributes) { - setAttrs(context, attributes); +NamedAttributeList::NamedAttributeList(ArrayRef attributes) { + setAttrs(attributes); } /// Return all of the attributes on this operation. @@ -748,8 +750,7 @@ ArrayRef NamedAttributeList::getAttrs() const { } /// Replace the held attributes with ones provided in 'newAttrs'. -void NamedAttributeList::setAttrs(MLIRContext *context, - ArrayRef attributes) { +void NamedAttributeList::setAttrs(ArrayRef attributes) { // Don't create an attribute list if there are no attributes. if (attributes.empty()) { attrs = nullptr; @@ -759,7 +760,7 @@ void NamedAttributeList::setAttrs(MLIRContext *context, assert(llvm::all_of(attributes, [](const NamedAttribute &attr) { return attr.second; }) && "attributes cannot have null entries"); - attrs = AttributeListStorage::get(attributes, context); + attrs = AttributeListStorage::get(attributes); } /// Return the specified attribute if present, null otherwise. @@ -778,8 +779,7 @@ Attribute NamedAttributeList::get(Identifier name) const { /// If the an attribute exists with the specified name, change it to the new /// value. Otherwise, add a new attribute with the specified name/value. -void NamedAttributeList::set(MLIRContext *context, Identifier name, - Attribute value) { +void NamedAttributeList::set(Identifier name, Attribute value) { assert(value && "attributes may never be null"); // If we already have this attribute, replace it. @@ -788,27 +788,32 @@ void NamedAttributeList::set(MLIRContext *context, Identifier name, for (auto &elt : newAttrs) if (elt.first == name) { elt.second = value; - attrs = AttributeListStorage::get(newAttrs, context); + attrs = AttributeListStorage::get(newAttrs); return; } // Otherwise, add it. newAttrs.push_back({name, value}); - attrs = AttributeListStorage::get(newAttrs, context); + attrs = AttributeListStorage::get(newAttrs); } /// Remove the attribute with the specified name if it exists. The return /// value indicates whether the attribute was present or not. -auto NamedAttributeList::remove(MLIRContext *context, Identifier name) - -> RemoveResult { +auto NamedAttributeList::remove(Identifier name) -> RemoveResult { auto origAttrs = getAttrs(); for (unsigned i = 0, e = origAttrs.size(); i != e; ++i) { if (origAttrs[i].first == name) { + // Handle the simple case of removing the only attribute in the list. + if (e == 1) { + attrs = nullptr; + return RemoveResult::Removed; + } + SmallVector newAttrs; newAttrs.reserve(origAttrs.size() - 1); newAttrs.append(origAttrs.begin(), origAttrs.begin() + i); newAttrs.append(origAttrs.begin() + i + 1, origAttrs.end()); - attrs = AttributeListStorage::get(newAttrs, context); + attrs = AttributeListStorage::get(newAttrs); return RemoveResult::Removed; } } diff --git a/mlir/lib/IR/Builders.cpp b/mlir/lib/IR/Builders.cpp index af066ba..a6036a9 100644 --- a/mlir/lib/IR/Builders.cpp +++ b/mlir/lib/IR/Builders.cpp @@ -167,12 +167,10 @@ IntegerSetAttr Builder::getIntegerSetAttr(IntegerSet set) { return IntegerSetAttr::get(set); } -TypeAttr Builder::getTypeAttr(Type type) { - return TypeAttr::get(type, context); -} +TypeAttr Builder::getTypeAttr(Type type) { return TypeAttr::get(type); } FunctionAttr Builder::getFunctionAttr(Function *value) { - return FunctionAttr::get(value, context); + return FunctionAttr::get(value); } ElementsAttr Builder::getSplatElementsAttr(VectorOrTensorType type, diff --git a/mlir/lib/IR/Function.cpp b/mlir/lib/IR/Function.cpp index 7651abf..1f9a1f5 100644 --- a/mlir/lib/IR/Function.cpp +++ b/mlir/lib/IR/Function.cpp @@ -30,15 +30,13 @@ using namespace mlir; Function::Function(Location location, StringRef name, FunctionType type, ArrayRef attrs) : name(Identifier::get(name, type.getContext())), location(location), - type(type), attrs(type.getContext(), attrs), - argAttrs(type.getNumInputs()), body(this) {} + type(type), attrs(attrs), argAttrs(type.getNumInputs()), body(this) {} Function::Function(Location location, StringRef name, FunctionType type, ArrayRef attrs, ArrayRef argAttrs) : name(Identifier::get(name, type.getContext())), location(location), - type(type), attrs(type.getContext(), attrs), argAttrs(argAttrs), - body(this) {} + type(type), attrs(attrs), argAttrs(argAttrs), body(this) {} Function::~Function() { // Clean up function attributes referring to this function. diff --git a/mlir/lib/IR/MLIRContext.cpp b/mlir/lib/IR/MLIRContext.cpp index ac041ae..249a1e1 100644 --- a/mlir/lib/IR/MLIRContext.cpp +++ b/mlir/lib/IR/MLIRContext.cpp @@ -849,8 +849,8 @@ static int compareNamedAttributes(const NamedAttribute *lhs, /// Given a list of NamedAttribute's, canonicalize the list (sorting /// by name) and return the unique'd result. Note that the empty list is /// represented with a null pointer. -AttributeListStorage *AttributeListStorage::get(ArrayRef attrs, - MLIRContext *context) { +AttributeListStorage * +AttributeListStorage::get(ArrayRef attrs) { // We need to sort the element list to canonicalize it, but we also don't want // to do a ton of work in the super common case where the element list is // already sorted. @@ -888,7 +888,7 @@ AttributeListStorage *AttributeListStorage::get(ArrayRef attrs, } } - auto &impl = context->getImpl(); + auto &impl = attrs[0].second.getContext()->getImpl(); // Safely get or create an attribute instance. return safeGetOrCreate(impl.attributeLists, attrs, impl.attributeMutex, [&] { diff --git a/mlir/lib/IR/Operation.cpp b/mlir/lib/IR/Operation.cpp index 2c97988..91b32de 100644 --- a/mlir/lib/IR/Operation.cpp +++ b/mlir/lib/IR/Operation.cpp @@ -102,7 +102,7 @@ Operation *Operation::create(Location location, OperationName name, ArrayRef successors, unsigned numRegions, bool resizableOperandList, MLIRContext *context) { return create(location, name, operands, resultTypes, - NamedAttributeList(context, attributes), successors, numRegions, + NamedAttributeList(attributes), successors, numRegions, resizableOperandList, context); } diff --git a/mlir/lib/Transforms/DialectConversion.cpp b/mlir/lib/Transforms/DialectConversion.cpp index 66c3b2d..831a68a 100644 --- a/mlir/lib/Transforms/DialectConversion.cpp +++ b/mlir/lib/Transforms/DialectConversion.cpp @@ -316,8 +316,8 @@ LogicalResult impl::FunctionConversion::run(Module *module) { if (!converted) return failure(); - auto origFuncAttr = FunctionAttr::get(func, context); - auto convertedFuncAttr = FunctionAttr::get(converted, context); + auto origFuncAttr = FunctionAttr::get(func); + auto convertedFuncAttr = FunctionAttr::get(converted); convertedFuncs.push_back(converted); functionAttrRemapping.insert({origFuncAttr, convertedFuncAttr}); } diff --git a/mlir/lib/Transforms/Utils/Utils.cpp b/mlir/lib/Transforms/Utils/Utils.cpp index 422d6b1..1ab821a 100644 --- a/mlir/lib/Transforms/Utils/Utils.cpp +++ b/mlir/lib/Transforms/Utils/Utils.cpp @@ -290,33 +290,44 @@ void mlir::createAffineComputationSlice( } } -void mlir::remapFunctionAttrs( - Operation &op, const DenseMap &remappingTable) { - for (auto attr : op.getAttrs()) { +static void +remapFunctionAttrs(NamedAttributeList &attrs, + const DenseMap &remappingTable) { + for (auto attr : attrs.getAttrs()) { // Do the remapping, if we got the same thing back, then it must contain // functions that aren't getting remapped. - auto newVal = - attr.second.remapFunctionAttrs(remappingTable, op.getContext()); + auto newVal = attr.second.remapFunctionAttrs(remappingTable); if (newVal == attr.second) continue; // Otherwise, replace the existing attribute with the new one. It is safe // to mutate the attribute list while we walk it because underlying // attribute lists are uniqued and immortal. - op.setAttr(attr.first, newVal); + attrs.set(attr.first, newVal); } } void mlir::remapFunctionAttrs( + Operation &op, const DenseMap &remappingTable) { + ::remapFunctionAttrs(op.getAttrList(), remappingTable); +} + +void mlir::remapFunctionAttrs( Function &fn, const DenseMap &remappingTable) { + // Remap the attributes of the function. + ::remapFunctionAttrs(fn.getAttrList(), remappingTable); + + // Remap the attributes of the arguments of this function. + for (auto &attrList : fn.getAllArgAttrs()) + ::remapFunctionAttrs(attrList, remappingTable); + // Look at all operations in a Function. fn.walk([&](Operation *op) { remapFunctionAttrs(*op, remappingTable); }); } void mlir::remapFunctionAttrs( Module &module, const DenseMap &remappingTable) { - for (auto &fn : module) { + for (auto &fn : module) remapFunctionAttrs(fn, remappingTable); - } } diff --git a/mlir/test/IR/parser.mlir b/mlir/test/IR/parser.mlir index a565c3b..2b28f80 100644 --- a/mlir/test/IR/parser.mlir +++ b/mlir/test/IR/parser.mlir @@ -901,3 +901,15 @@ func @none_type() { %none_val = "foo.unknown_op"() : () -> none return } + +// CHECK-LABEL: func @fn_attr_remap +// CHECK: {some_dialect.arg_attr: @fn_attr_ref : () -> ()} +func @fn_attr_remap(%arg0: i1 {some_dialect.arg_attr: @fn_attr_ref : () -> ()}) -> () + // CHECK-NEXT: {some_dialect.fn_attr: @fn_attr_ref : () -> ()} + attributes {some_dialect.fn_attr: @fn_attr_ref : () -> ()} { + return +} + +// CHECK-LABEL: func @fn_attr_ref +func @fn_attr_ref() -> () + -- 2.7.4