Add support to FuncOp for managing argument attributes. The syntax for argument attri...
authorRiver Riddle <riverriddle@google.com>
Tue, 4 Jun 2019 18:01:32 +0000 (11:01 -0700)
committerMehdi Amini <joker.eph@gmail.com>
Sun, 9 Jun 2019 23:16:36 +0000 (16:16 -0700)
  func @foo(i1 {dialect.attr: 10 : i64})

  func @foo(%arg0: i1 {dialect.attr: 10 : i64}) {
    return
  }

PiperOrigin-RevId: 251473338

mlir/include/mlir/IR/Attributes.h
mlir/include/mlir/IR/Function.h
mlir/include/mlir/IR/OpDefinition.h
mlir/lib/IR/Function.cpp
mlir/test/IR/func-op.mlir

index ddc05d1..b7b300b 100644 (file)
@@ -350,6 +350,7 @@ public:
   using iterator = llvm::ArrayRef<NamedAttribute>::iterator;
   iterator begin() const;
   iterator end() const;
+  bool empty() const { return size() == 0; }
   size_t size() const;
 
   /// Methods for supporting type inquiry through isa, cast, and dyn_cast.
@@ -802,9 +803,14 @@ inline ::llvm::hash_code hash_value(Attribute arg) {
 /// searches for everything.
 class NamedAttributeList {
 public:
-  NamedAttributeList(DictionaryAttr attrs = nullptr) : attrs(attrs) {}
+  NamedAttributeList(DictionaryAttr attrs = nullptr)
+      : attrs((attrs && !attrs.empty()) ? attrs : nullptr) {}
   NamedAttributeList(ArrayRef<NamedAttribute> attributes);
 
+  /// Return the underlying dictionary attribute. This may be null, if this list
+  /// has no attributes.
+  DictionaryAttr getDictionary() const { return attrs; }
+
   /// Return all of the attributes on this operation.
   ArrayRef<NamedAttribute> getAttrs() const;
 
index ed7083e..9ffcda4 100644 (file)
@@ -24,6 +24,7 @@
 
 #include "mlir/IR/Block.h"
 #include "mlir/IR/OpDefinition.h"
+#include "llvm/ADT/SmallString.h"
 
 namespace mlir {
 class BlockAndValueMapping;
@@ -395,6 +396,83 @@ public:
   llvm::iterator_range<args_iterator> getArguments() {
     return {args_begin(), args_end()};
   }
+
+  //===--------------------------------------------------------------------===//
+  // Argument Attributes
+  //===--------------------------------------------------------------------===//
+
+  /// FuncOp allows for attaching attributes to each of the respective function
+  /// arguments. These argument attributes are stored as DictionaryAttrs in the
+  /// main operation attribute dictionary. The name of these entries is `arg`
+  /// followed by the index of the argument. These argument attribute
+  /// dictionaries are optional, and will generally only exist if they are
+  /// non-empty.
+
+  /// Return all of the attributes for the argument at 'index'.
+  ArrayRef<NamedAttribute> getArgAttrs(unsigned index) {
+    auto argDict = getArgAttrDict(index);
+    return argDict ? argDict.getValue() : llvm::None;
+  }
+
+  /// Return all argument attributes of this function.
+  void getAllArgAttrs(SmallVectorImpl<DictionaryAttr> &result) {
+    for (unsigned i = 0, e = getNumArguments(); i != e; ++i)
+      result.emplace_back(getArgAttrDict(i));
+  }
+
+  /// Return the specified attribute, if present, for the argument at 'index',
+  /// null otherwise.
+  Attribute getArgAttr(unsigned index, Identifier name) {
+    auto argDict = getArgAttrDict(index);
+    return argDict ? argDict.get(name) : nullptr;
+  }
+  Attribute getArgAttr(unsigned index, StringRef name) {
+    auto argDict = getArgAttrDict(index);
+    return argDict ? argDict.get(name) : nullptr;
+  }
+
+  template <typename AttrClass>
+  AttrClass getArgAttrOfType(unsigned index, Identifier name) {
+    return getArgAttr(index, name).dyn_cast_or_null<AttrClass>();
+  }
+  template <typename AttrClass>
+  AttrClass getArgAttrOfType(unsigned index, StringRef name) {
+    return getArgAttr(index, name).dyn_cast_or_null<AttrClass>();
+  }
+
+  /// Set the attributes held by the argument at 'index'.
+  void setArgAttrs(unsigned index, ArrayRef<NamedAttribute> attributes);
+  void setArgAttrs(unsigned index, NamedAttributeList attributes);
+  void setAllArgAttrs(ArrayRef<NamedAttributeList> attributes) {
+    assert(attributes.size() == getNumArguments());
+    for (unsigned i = 0, e = attributes.size(); i != e; ++i)
+      setArgAttrs(i, attributes[i]);
+  }
+
+  /// If the an attribute exists with the specified name, change it to the new
+  /// value. Otherwise, add a new attribute with the specified name/value.
+  void setArgAttr(unsigned index, Identifier name, Attribute value);
+  void setArgAttr(unsigned index, StringRef name, Attribute value) {
+    setArgAttr(index, Identifier::get(name, getContext()), value);
+  }
+
+  /// Remove the attribute 'name' from the argument at 'index'.
+  NamedAttributeList::RemoveResult removeArgAttr(unsigned index,
+                                                 Identifier name);
+
+private:
+  /// Returns the attribute entry name for the set of argument attributes at
+  /// index 'arg'.
+  static StringRef getArgAttrName(unsigned arg, SmallVectorImpl<char> &out);
+
+  /// Returns the dictionary attribute corresponding to the argument at 'index'.
+  /// If there are no argument attributes at 'index', a null attribute is
+  /// returned.
+  DictionaryAttr getArgAttrDict(unsigned index) {
+    assert(index < getNumArguments() && "invalid argument number");
+    SmallString<8> nameOut;
+    return getAttrOfType<DictionaryAttr>(getArgAttrName(index, nameOut));
+  }
 };
 
 } // end namespace mlir
index 0ede2f8..35c248f 100644 (file)
@@ -106,6 +106,15 @@ public:
     setAttr(Identifier::get(name, getContext()), value);
   }
 
+  /// Remove the attribute with the specified name if it exists.  The return
+  /// value indicates whether the attribute was present or not.
+  NamedAttributeList::RemoveResult removeAttr(Identifier name) {
+    return state->removeAttr(name);
+  }
+  NamedAttributeList::RemoveResult removeAttr(StringRef name) {
+    return state->removeAttr(Identifier::get(name, getContext()));
+  }
+
   /// Return true if there are no users of any results of this operation.
   bool use_empty() { return state->use_empty(); }
 
index f4ee155..f53c715 100644 (file)
@@ -269,12 +269,11 @@ parseArgumentList(OpAsmParser *parser, SmallVectorImpl<Type> &argTypes,
     // Add the argument type.
     argTypes.push_back(argumentType);
 
-    // TODO(riverriddle) Parse argument attributes.
-    // Parse the attribute dict.
-    // SmallVector<NamedAttribute, 2> attrs;
-    // if (parser->parseOptionalAttributeDict(attrs))
-    //  return failure();
-    // argAttrs.push_back(attrs);
+    // Parse any argument attributes.
+    SmallVector<NamedAttribute, 2> attrs;
+    if (parser->parseOptionalAttributeDict(attrs))
+      return failure();
+    argAttrs.push_back(attrs);
     return success();
   };
 
@@ -331,12 +330,12 @@ ParseResult FuncOp::parse(OpAsmParser *parser, OperationState *result) {
     if (parser->parseOptionalAttributeDict(result->attributes))
       return failure();
 
-  // TODO(riverriddle) Parse argument attributes.
   // Add the attributes to the function arguments.
-  // for (unsigned i = 0, e = type.getNumInputs(); i != e; ++i)
-  //   if (!argAttrs[i].empty())
-  //     result->addAttribute(("arg" + Twine(i)).str(),
-  //                          builder.getDictionaryAttr(argAttrs[i]));
+  SmallString<8> argAttrName;
+  for (unsigned i = 0, e = type.getNumInputs(); i != e; ++i)
+    if (!argAttrs[i].empty())
+      result->addAttribute(getArgAttrName(i, argAttrName),
+                           builder.getDictionaryAttr(argAttrs[i]));
 
   // Parse the optional function body.
   auto *body = result->addRegion();
@@ -362,11 +361,9 @@ static void printFunctionSignature(OpAsmPrinter *p, FuncOp op) {
       *p << ": ";
     }
 
+    // Print the type followed by any argument attributes.
     p->printType(fnType.getInput(i));
-
-    // TODO(riverriddle) Print argument attributes.
-    // Print the attributes for this argument.
-    // p->printOptionalAttrDict(op.getArgAttrs(i));
+    p->printOptionalAttrDict(op.getArgAttrs(i));
   }
   *p << ')';
 
@@ -399,13 +396,20 @@ void FuncOp::print(OpAsmPrinter *p) {
   printFunctionSignature(p, *this);
 
   // Print out function attributes, if present.
-  auto attrs = getAttrs();
+  SmallVector<StringRef, 2> ignoredAttrs = {"name", "type"};
 
-  // We must have more attributes than <name, type>.
-  constexpr unsigned kNumHiddenAttrs = 2;
-  if (attrs.size() > kNumHiddenAttrs) {
+  // Ignore any argument attributes.
+  std::vector<SmallString<8>> argAttrStorage;
+  SmallString<8> argAttrName;
+  for (unsigned i = 0, e = getNumArguments(); i != e; ++i)
+    if (getAttr(getArgAttrName(i, argAttrName)))
+      argAttrStorage.emplace_back(argAttrName);
+  ignoredAttrs.append(argAttrStorage.begin(), argAttrStorage.end());
+
+  auto attrs = getAttrs();
+  if (attrs.size() > ignoredAttrs.size()) {
     *p << "\n  attributes ";
-    p->printOptionalAttrDict(attrs, {"name", "type"});
+    p->printOptionalAttrDict(attrs, ignoredAttrs);
   }
 
   // Print the body if this is not an external function.
@@ -416,3 +420,58 @@ void FuncOp::print(OpAsmPrinter *p) {
   }
   *p << '\n';
 }
+
+//===----------------------------------------------------------------------===//
+// Function Argument Attribute.
+//===----------------------------------------------------------------------===//
+
+/// Set the attributes held by the argument at 'index'.
+void FuncOp::setArgAttrs(unsigned index, ArrayRef<NamedAttribute> attributes) {
+  assert(index < getNumArguments() && "invalid argument number");
+  SmallString<8> nameOut;
+  getArgAttrName(index, nameOut);
+
+  if (attributes.empty())
+    return (void)removeAttr(nameOut);
+  setAttr(nameOut, DictionaryAttr::get(attributes, getContext()));
+}
+
+void FuncOp::setArgAttrs(unsigned index, NamedAttributeList attributes) {
+  assert(index < getNumArguments() && "invalid argument number");
+  SmallString<8> nameOut;
+  if (auto newAttr = attributes.getDictionary())
+    return setAttr(getArgAttrName(index, nameOut), newAttr);
+  removeAttr(getArgAttrName(index, nameOut));
+}
+
+/// If the an attribute exists with the specified name, change it to the new
+/// value. Otherwise, add a new attribute with the specified name/value.
+void FuncOp::setArgAttr(unsigned index, Identifier name, Attribute value) {
+  auto curAttr = getArgAttrDict(index);
+  NamedAttributeList attrList(curAttr);
+  attrList.set(name, value);
+
+  // If the attribute changed, then set the new arg attribute list.
+  if (curAttr != attrList.getDictionary())
+    setArgAttrs(index, attrList);
+}
+
+/// Remove the attribute 'name' from the argument at 'index'.
+NamedAttributeList::RemoveResult FuncOp::removeArgAttr(unsigned index,
+                                                       Identifier name) {
+  // Build an attribute list and remove the attribute at 'name'.
+  NamedAttributeList attrList(getArgAttrDict(index));
+  auto result = attrList.remove(name);
+
+  // If the attribute was removed, then update the argument dictionary.
+  if (result == NamedAttributeList::RemoveResult::Removed)
+    setArgAttrs(index, attrList);
+  return result;
+}
+
+/// Returns the attribute entry name for the set of argument attributes at index
+/// 'arg'.
+StringRef FuncOp::getArgAttrName(unsigned arg, SmallVectorImpl<char> &out) {
+  out.clear();
+  return ("arg" + Twine(arg)).toStringRef(out);
+}
index f264dc2..c9819f2 100644 (file)
@@ -27,3 +27,17 @@ func @func_attributes() {
   func @foo() attributes {foo: true}
   return
 }
+
+
+// CHECK-LABEL: func @func_arg_attributes
+func @func_arg_attributes() {
+  // CHECK-NEXT: func @external_func_arg_attrs(i32, i1 {dialect.attr: 10 : i64}, i32)
+  func @external_func_arg_attrs(i32, i1 {dialect.attr: 10 : i64}, i32)
+
+  // CHECK: func @func_arg_attrs(%i0: i1 {dialect.attr: 10 : i64})
+  func @func_arg_attrs(%i0: i1 {dialect.attr: 10 : i64}) {
+    return
+  }
+
+  return
+}