Simplify several usages of attributes now that they always have a type and, trans...
authorRiver Riddle <riverriddle@google.com>
Mon, 6 May 2019 19:40:43 +0000 (12:40 -0700)
committerMehdi Amini <joker.eph@gmail.com>
Sat, 11 May 2019 02:22:41 +0000 (19:22 -0700)
    This also fixes a bug where FunctionAttrs were not being remapped for function and function argument attributes.

--

PiperOrigin-RevId: 246876924

13 files changed:
mlir/bindings/python/pybind.cpp
mlir/include/mlir/IR/Attributes.h
mlir/include/mlir/IR/Function.h
mlir/include/mlir/IR/Operation.h
mlir/lib/IR/AttributeDetail.h
mlir/lib/IR/Attributes.cpp
mlir/lib/IR/Builders.cpp
mlir/lib/IR/Function.cpp
mlir/lib/IR/MLIRContext.cpp
mlir/lib/IR/Operation.cpp
mlir/lib/Transforms/DialectConversion.cpp
mlir/lib/Transforms/Utils/Utils.cpp
mlir/test/IR/parser.mlir

index 5f90a71..720f381 100644 (file)
@@ -571,7 +571,7 @@ PythonMLIRModule::declareFunction(const std::string &name,
       inAttrs.emplace_back(Identifier::get(named.name, &mlirContext),
                            mlir::Attribute::getFromOpaquePointer(
                                reinterpret_cast<const void *>(named.value)));
-    inputAttrs.emplace_back(&mlirContext, inAttrs);
+    inputAttrs.emplace_back(inAttrs);
   }
 
   // Create the function itself.
@@ -634,7 +634,7 @@ PYBIND11_MODULE(pybind, m) {
   });
   m.def("constant_function", [](PythonFunction func) -> PythonValueHandle {
     auto *function = reinterpret_cast<Function *>(func.function);
-    auto attr = FunctionAttr::get(function, function->getContext());
+    auto attr = FunctionAttr::get(function);
     return ValueHandle::create<ConstantOp>(function->getType(), attr);
   });
   m.def("appendTo", [](const PythonBlockHandle &handle) {
index 2ff4937..56b8c7b 100644 (file)
@@ -127,6 +127,9 @@ public:
   /// Return the type of this attribute.
   Type getType() const;
 
+  /// Return the context this attribute belongs to.
+  MLIRContext *getContext() const;
+
   /// Return true if this field is, or contains, a function attribute.
   bool isOrContainsFunction() const;
 
@@ -135,8 +138,7 @@ public:
   /// remapping table.  Return the original attribute if it (or any of nested
   /// attributes) is not present in the table.
   Attribute remapFunctionAttrs(
-      const llvm::DenseMap<Attribute, FunctionAttr> &remappingTable,
-      MLIRContext *context) const;
+      const llvm::DenseMap<Attribute, FunctionAttr> &remappingTable) const;
 
   /// Print the attribute.
   void print(raw_ostream &os) const;
@@ -299,7 +301,7 @@ public:
   using ImplType = detail::TypeAttributeStorage;
   using ValueType = Type;
 
-  static TypeAttr get(Type type, MLIRContext *context);
+  static TypeAttr get(Type value);
 
   Type getValue() const;
 
@@ -320,7 +322,7 @@ public:
   using ImplType = detail::FunctionAttributeStorage;
   using ValueType = Function *;
 
-  static FunctionAttr get(Function *value, MLIRContext *context);
+  static FunctionAttr get(Function *value);
 
   Function *getValue() const;
 
@@ -642,13 +644,13 @@ using NamedAttribute = std::pair<Identifier, Attribute>;
 class NamedAttributeList {
 public:
   NamedAttributeList() : attrs(nullptr) {}
-  NamedAttributeList(MLIRContext *context, ArrayRef<NamedAttribute> attributes);
+  NamedAttributeList(ArrayRef<NamedAttribute> attributes);
 
   /// Return all of the attributes on this operation.
   ArrayRef<NamedAttribute> getAttrs() const;
 
   /// Replace the held attributes with ones provided in 'newAttrs'.
-  void setAttrs(MLIRContext *context, ArrayRef<NamedAttribute> attributes);
+  void setAttrs(ArrayRef<NamedAttribute> attributes);
 
   /// Return the specified attribute if present, null otherwise.
   Attribute get(StringRef name) const;
@@ -656,13 +658,13 @@ public:
 
   /// If the an attribute exists with the specified name, change it to the new
   /// value.  Otherwise, add a new attribute with the specified name/value.
-  void set(MLIRContext *context, Identifier name, Attribute value);
+  void set(Identifier name, Attribute value);
 
   enum class RemoveResult { Removed, NotFound };
 
   /// Remove the attribute with the specified name if it exists.  The return
   /// value indicates whether the attribute was present or not.
-  RemoveResult remove(MLIRContext *context, Identifier name);
+  RemoveResult remove(Identifier name);
 
 private:
   detail::AttributeListStorage *attrs;
index 8a5b28b..6860e1c 100644 (file)
@@ -157,6 +157,9 @@ public:
   /// Return all of the attributes on this function.
   ArrayRef<NamedAttribute> getAttrs() { return attrs.getAttrs(); }
 
+  /// Return the internal attribute list on this function.
+  NamedAttributeList &getAttrList() { return attrs; }
+
   /// Return all of the attributes for the argument at 'index'.
   ArrayRef<NamedAttribute> getArgAttrs(unsigned index) {
     assert(index < getNumArguments() && "invalid argument number");
@@ -165,13 +168,13 @@ public:
 
   /// Set the attributes held by this function.
   void setAttrs(ArrayRef<NamedAttribute> attributes) {
-    attrs.setAttrs(getContext(), attributes);
+    attrs.setAttrs(attributes);
   }
 
   /// Set the attributes held by the argument at 'index'.
   void setArgAttrs(unsigned index, ArrayRef<NamedAttribute> attributes) {
     assert(index < getNumArguments() && "invalid argument number");
-    argAttrs[index].setAttrs(getContext(), attributes);
+    argAttrs[index].setAttrs(attributes);
   }
 
   /// Return all argument attributes of this function.
@@ -212,15 +215,13 @@ public:
 
   /// If the an attribute exists with the specified name, change it to the new
   /// value.  Otherwise, add a new attribute with the specified name/value.
-  void setAttr(Identifier name, Attribute value) {
-    attrs.set(getContext(), name, value);
-  }
+  void setAttr(Identifier name, Attribute value) { attrs.set(name, value); }
   void setAttr(StringRef name, Attribute value) {
     setAttr(Identifier::get(name, getContext()), value);
   }
   void setArgAttr(unsigned index, Identifier name, Attribute value) {
     assert(index < getNumArguments() && "invalid argument number");
-    argAttrs[index].set(getContext(), name, value);
+    argAttrs[index].set(name, value);
   }
   void setArgAttr(unsigned index, StringRef name, Attribute value) {
     setArgAttr(index, Identifier::get(name, getContext()), value);
@@ -229,12 +230,12 @@ public:
   /// Remove the attribute with the specified name if it exists.  The return
   /// value indicates whether the attribute was present or not.
   NamedAttributeList::RemoveResult removeAttr(Identifier name) {
-    return attrs.remove(getContext(), name);
+    return attrs.remove(name);
   }
   NamedAttributeList::RemoveResult removeArgAttr(unsigned index,
                                                  Identifier name) {
     assert(index < getNumArguments() && "invalid argument number");
-    return attrs.remove(getContext(), name);
+    return attrs.remove(name);
   }
 
   //===--------------------------------------------------------------------===//
index 0f1f3e9..9f605aa 100644 (file)
@@ -239,6 +239,9 @@ public:
   /// Return all of the attributes on this operation.
   ArrayRef<NamedAttribute> getAttrs() { return attrs.getAttrs(); }
 
+  /// Return the internal attribute list on this operation.
+  NamedAttributeList &getAttrList() { return attrs; }
+
   /// Return the specified attribute if present, null otherwise.
   Attribute getAttr(Identifier name) { return attrs.get(name); }
   Attribute getAttr(StringRef name) { return attrs.get(name); }
@@ -253,9 +256,7 @@ public:
 
   /// If the an attribute exists with the specified name, change it to the new
   /// value.  Otherwise, add a new attribute with the specified name/value.
-  void setAttr(Identifier name, Attribute value) {
-    attrs.set(getContext(), name, value);
-  }
+  void setAttr(Identifier name, Attribute value) { attrs.set(name, value); }
   void setAttr(StringRef name, Attribute value) {
     setAttr(Identifier::get(name, getContext()), value);
   }
@@ -263,7 +264,7 @@ public:
   /// Remove the attribute with the specified name if it exists.  The return
   /// value indicates whether the attribute was present or not.
   NamedAttributeList::RemoveResult removeAttr(Identifier name) {
-    return attrs.remove(getContext(), name);
+    return attrs.remove(name);
   }
 
   //===--------------------------------------------------------------------===//
index d1802f4..aab4445 100644 (file)
@@ -408,8 +408,7 @@ public:
   /// Given a list of NamedAttribute's, canonicalize the list (sorting
   /// by name) and return the unique'd result.  Note that the empty list is
   /// represented with a null pointer.
-  static AttributeListStorage *get(ArrayRef<NamedAttribute> attrs,
-                                   MLIRContext *context);
+  static AttributeListStorage *get(ArrayRef<NamedAttribute> attrs);
 
   /// Return the element constants for this aggregate constant.  These are
   /// known to all be constants.
index ab15699..62c3c93 100644 (file)
@@ -67,6 +67,9 @@ Attribute::Kind Attribute::getKind() const {
 /// Return the type of this attribute.
 Type Attribute::getType() const { return attr->getType(); }
 
+/// Return the context this attribute belongs to.
+MLIRContext *Attribute::getContext() const { return getType().getContext(); }
+
 bool Attribute::isOrContainsFunction() const {
   return attr->isOrContainsFunctionCache();
 }
@@ -75,8 +78,7 @@ bool Attribute::isOrContainsFunction() const {
 // table, walk it and rewrite it to use the mapped function.  If it doesn't
 // refer to anything in the table, then it is returned unmodified.
 Attribute Attribute::remapFunctionAttrs(
-    const llvm::DenseMap<Attribute, FunctionAttr> &remappingTable,
-    MLIRContext *context) const {
+    const llvm::DenseMap<Attribute, FunctionAttr> &remappingTable) const {
   // Most attributes are trivially unrelated to function attributes, skip them
   // rapidly.
   if (!isOrContainsFunction())
@@ -93,7 +95,7 @@ Attribute Attribute::remapFunctionAttrs(
   SmallVector<Attribute, 8> remappedElts;
   bool anyChange = false;
   for (auto elt : arrayAttr.getValue()) {
-    auto newElt = elt.remapFunctionAttrs(remappingTable, context);
+    auto newElt = elt.remapFunctionAttrs(remappingTable);
     remappedElts.push_back(newElt);
     anyChange |= (elt != newElt);
   }
@@ -101,7 +103,7 @@ Attribute Attribute::remapFunctionAttrs(
   if (!anyChange)
     return *this;
 
-  return ArrayAttr::get(remappedElts, context);
+  return ArrayAttr::get(remappedElts, getContext());
 }
 
 //===----------------------------------------------------------------------===//
@@ -262,8 +264,9 @@ IntegerSet IntegerSetAttr::getValue() const {
 // TypeAttr
 //===----------------------------------------------------------------------===//
 
-TypeAttr TypeAttr::get(Type value, MLIRContext *context) {
-  return AttributeUniquer::get<TypeAttr>(context, Attribute::Kind::Type, value);
+TypeAttr TypeAttr::get(Type value) {
+  return AttributeUniquer::get<TypeAttr>(value.getContext(),
+                                         Attribute::Kind::Type, value);
 }
 
 Type TypeAttr::getValue() const { return static_cast<ImplType *>(attr)->value; }
@@ -272,10 +275,10 @@ Type TypeAttr::getValue() const { return static_cast<ImplType *>(attr)->value; }
 // FunctionAttr
 //===----------------------------------------------------------------------===//
 
-FunctionAttr FunctionAttr::get(Function *value, MLIRContext *context) {
+FunctionAttr FunctionAttr::get(Function *value) {
   assert(value && "Cannot get FunctionAttr for a null function");
-  return AttributeUniquer::get<FunctionAttr>(context, Attribute::Kind::Function,
-                                             value);
+  return AttributeUniquer::get<FunctionAttr>(value->getContext(),
+                                             Attribute::Kind::Function, value);
 }
 
 /// This function is used by the internals of the Function class to null out
@@ -737,9 +740,8 @@ Attribute SparseElementsAttr::getValue(ArrayRef<uint64_t> index) const {
 // NamedAttributeList
 //===----------------------------------------------------------------------===//
 
-NamedAttributeList::NamedAttributeList(MLIRContext *context,
-                                       ArrayRef<NamedAttribute> attributes) {
-  setAttrs(context, attributes);
+NamedAttributeList::NamedAttributeList(ArrayRef<NamedAttribute> attributes) {
+  setAttrs(attributes);
 }
 
 /// Return all of the attributes on this operation.
@@ -748,8 +750,7 @@ ArrayRef<NamedAttribute> NamedAttributeList::getAttrs() const {
 }
 
 /// Replace the held attributes with ones provided in 'newAttrs'.
-void NamedAttributeList::setAttrs(MLIRContext *context,
-                                  ArrayRef<NamedAttribute> attributes) {
+void NamedAttributeList::setAttrs(ArrayRef<NamedAttribute> attributes) {
   // Don't create an attribute list if there are no attributes.
   if (attributes.empty()) {
     attrs = nullptr;
@@ -759,7 +760,7 @@ void NamedAttributeList::setAttrs(MLIRContext *context,
   assert(llvm::all_of(attributes,
                       [](const NamedAttribute &attr) { return attr.second; }) &&
          "attributes cannot have null entries");
-  attrs = AttributeListStorage::get(attributes, context);
+  attrs = AttributeListStorage::get(attributes);
 }
 
 /// Return the specified attribute if present, null otherwise.
@@ -778,8 +779,7 @@ Attribute NamedAttributeList::get(Identifier name) const {
 
 /// If the an attribute exists with the specified name, change it to the new
 /// value.  Otherwise, add a new attribute with the specified name/value.
-void NamedAttributeList::set(MLIRContext *context, Identifier name,
-                             Attribute value) {
+void NamedAttributeList::set(Identifier name, Attribute value) {
   assert(value && "attributes may never be null");
 
   // If we already have this attribute, replace it.
@@ -788,27 +788,32 @@ void NamedAttributeList::set(MLIRContext *context, Identifier name,
   for (auto &elt : newAttrs)
     if (elt.first == name) {
       elt.second = value;
-      attrs = AttributeListStorage::get(newAttrs, context);
+      attrs = AttributeListStorage::get(newAttrs);
       return;
     }
 
   // Otherwise, add it.
   newAttrs.push_back({name, value});
-  attrs = AttributeListStorage::get(newAttrs, context);
+  attrs = AttributeListStorage::get(newAttrs);
 }
 
 /// Remove the attribute with the specified name if it exists.  The return
 /// value indicates whether the attribute was present or not.
-auto NamedAttributeList::remove(MLIRContext *context, Identifier name)
-    -> RemoveResult {
+auto NamedAttributeList::remove(Identifier name) -> RemoveResult {
   auto origAttrs = getAttrs();
   for (unsigned i = 0, e = origAttrs.size(); i != e; ++i) {
     if (origAttrs[i].first == name) {
+      // Handle the simple case of removing the only attribute in the list.
+      if (e == 1) {
+        attrs = nullptr;
+        return RemoveResult::Removed;
+      }
+
       SmallVector<NamedAttribute, 8> newAttrs;
       newAttrs.reserve(origAttrs.size() - 1);
       newAttrs.append(origAttrs.begin(), origAttrs.begin() + i);
       newAttrs.append(origAttrs.begin() + i + 1, origAttrs.end());
-      attrs = AttributeListStorage::get(newAttrs, context);
+      attrs = AttributeListStorage::get(newAttrs);
       return RemoveResult::Removed;
     }
   }
index af066ba..a6036a9 100644 (file)
@@ -167,12 +167,10 @@ IntegerSetAttr Builder::getIntegerSetAttr(IntegerSet set) {
   return IntegerSetAttr::get(set);
 }
 
-TypeAttr Builder::getTypeAttr(Type type) {
-  return TypeAttr::get(type, context);
-}
+TypeAttr Builder::getTypeAttr(Type type) { return TypeAttr::get(type); }
 
 FunctionAttr Builder::getFunctionAttr(Function *value) {
-  return FunctionAttr::get(value, context);
+  return FunctionAttr::get(value);
 }
 
 ElementsAttr Builder::getSplatElementsAttr(VectorOrTensorType type,
index 7651abf..1f9a1f5 100644 (file)
@@ -30,15 +30,13 @@ using namespace mlir;
 Function::Function(Location location, StringRef name, FunctionType type,
                    ArrayRef<NamedAttribute> attrs)
     : name(Identifier::get(name, type.getContext())), location(location),
-      type(type), attrs(type.getContext(), attrs),
-      argAttrs(type.getNumInputs()), body(this) {}
+      type(type), attrs(attrs), argAttrs(type.getNumInputs()), body(this) {}
 
 Function::Function(Location location, StringRef name, FunctionType type,
                    ArrayRef<NamedAttribute> attrs,
                    ArrayRef<NamedAttributeList> argAttrs)
     : name(Identifier::get(name, type.getContext())), location(location),
-      type(type), attrs(type.getContext(), attrs), argAttrs(argAttrs),
-      body(this) {}
+      type(type), attrs(attrs), argAttrs(argAttrs), body(this) {}
 
 Function::~Function() {
   // Clean up function attributes referring to this function.
index ac041ae..249a1e1 100644 (file)
@@ -849,8 +849,8 @@ static int compareNamedAttributes(const NamedAttribute *lhs,
 /// Given a list of NamedAttribute's, canonicalize the list (sorting
 /// by name) and return the unique'd result.  Note that the empty list is
 /// represented with a null pointer.
-AttributeListStorage *AttributeListStorage::get(ArrayRef<NamedAttribute> attrs,
-                                                MLIRContext *context) {
+AttributeListStorage *
+AttributeListStorage::get(ArrayRef<NamedAttribute> attrs) {
   // We need to sort the element list to canonicalize it, but we also don't want
   // to do a ton of work in the super common case where the element list is
   // already sorted.
@@ -888,7 +888,7 @@ AttributeListStorage *AttributeListStorage::get(ArrayRef<NamedAttribute> attrs,
     }
   }
 
-  auto &impl = context->getImpl();
+  auto &impl = attrs[0].second.getContext()->getImpl();
 
   // Safely get or create an attribute instance.
   return safeGetOrCreate(impl.attributeLists, attrs, impl.attributeMutex, [&] {
index 2c97988..91b32de 100644 (file)
@@ -102,7 +102,7 @@ Operation *Operation::create(Location location, OperationName name,
                              ArrayRef<Block *> successors, unsigned numRegions,
                              bool resizableOperandList, MLIRContext *context) {
   return create(location, name, operands, resultTypes,
-                NamedAttributeList(context, attributes), successors, numRegions,
+                NamedAttributeList(attributes), successors, numRegions,
                 resizableOperandList, context);
 }
 
index 66c3b2d..831a68a 100644 (file)
@@ -316,8 +316,8 @@ LogicalResult impl::FunctionConversion::run(Module *module) {
     if (!converted)
       return failure();
 
-    auto origFuncAttr = FunctionAttr::get(func, context);
-    auto convertedFuncAttr = FunctionAttr::get(converted, context);
+    auto origFuncAttr = FunctionAttr::get(func);
+    auto convertedFuncAttr = FunctionAttr::get(converted);
     convertedFuncs.push_back(converted);
     functionAttrRemapping.insert({origFuncAttr, convertedFuncAttr});
   }
index 422d6b1..1ab821a 100644 (file)
@@ -290,33 +290,44 @@ void mlir::createAffineComputationSlice(
   }
 }
 
-void mlir::remapFunctionAttrs(
-    Operation &op, const DenseMap<Attribute, FunctionAttr> &remappingTable) {
-  for (auto attr : op.getAttrs()) {
+static void
+remapFunctionAttrs(NamedAttributeList &attrs,
+                   const DenseMap<Attribute, FunctionAttr> &remappingTable) {
+  for (auto attr : attrs.getAttrs()) {
     // Do the remapping, if we got the same thing back, then it must contain
     // functions that aren't getting remapped.
-    auto newVal =
-        attr.second.remapFunctionAttrs(remappingTable, op.getContext());
+    auto newVal = attr.second.remapFunctionAttrs(remappingTable);
     if (newVal == attr.second)
       continue;
 
     // Otherwise, replace the existing attribute with the new one.  It is safe
     // to mutate the attribute list while we walk it because underlying
     // attribute lists are uniqued and immortal.
-    op.setAttr(attr.first, newVal);
+    attrs.set(attr.first, newVal);
   }
 }
 
 void mlir::remapFunctionAttrs(
+    Operation &op, const DenseMap<Attribute, FunctionAttr> &remappingTable) {
+  ::remapFunctionAttrs(op.getAttrList(), remappingTable);
+}
+
+void mlir::remapFunctionAttrs(
     Function &fn, const DenseMap<Attribute, FunctionAttr> &remappingTable) {
 
+  // Remap the attributes of the function.
+  ::remapFunctionAttrs(fn.getAttrList(), remappingTable);
+
+  // Remap the attributes of the arguments of this function.
+  for (auto &attrList : fn.getAllArgAttrs())
+    ::remapFunctionAttrs(attrList, remappingTable);
+
   // Look at all operations in a Function.
   fn.walk([&](Operation *op) { remapFunctionAttrs(*op, remappingTable); });
 }
 
 void mlir::remapFunctionAttrs(
     Module &module, const DenseMap<Attribute, FunctionAttr> &remappingTable) {
-  for (auto &fn : module) {
+  for (auto &fn : module)
     remapFunctionAttrs(fn, remappingTable);
-  }
 }
index a565c3b..2b28f80 100644 (file)
@@ -901,3 +901,15 @@ func @none_type() {
   %none_val = "foo.unknown_op"() : () -> none
   return
 }
+
+// CHECK-LABEL: func @fn_attr_remap
+// CHECK: {some_dialect.arg_attr: @fn_attr_ref : () -> ()}
+func @fn_attr_remap(%arg0: i1 {some_dialect.arg_attr: @fn_attr_ref : () -> ()}) -> ()
+  // CHECK-NEXT: {some_dialect.fn_attr: @fn_attr_ref : () -> ()}
+  attributes {some_dialect.fn_attr: @fn_attr_ref : () -> ()} {
+  return
+}
+
+// CHECK-LABEL: func @fn_attr_ref
+func @fn_attr_ref() -> ()
+