[mlir] Refactor the representation of function-like argument/result attributes.
authorRiver Riddle <riddleriver@gmail.com>
Sat, 8 May 2021 02:30:25 +0000 (19:30 -0700)
committerRiver Riddle <riddleriver@gmail.com>
Sat, 8 May 2021 02:32:31 +0000 (19:32 -0700)
The current design uses a unique entry for each argument/result attribute, with the name of the entry being something like "arg0". This provides for a somewhat sparse design, but ends up being much more expensive (from a runtime perspective) in-practice. The design requires building a string every time we lookup the dictionary for a specific arg/result, and also requires N attribute lookups when collecting all of the arg/result attribute dictionaries.

This revision restructures the design to instead have an ArrayAttr that contains all of the attribute dictionaries for arguments and another for results. This design reduces the number of attribute name lookups to 1, and allows for O(1) lookup for individual element dictionaries. The major downside is that we can end up with larger memory usage, as the ArrayAttr contains an entry for each element even if that element has no attributes. If the memory usage becomes too problematic, we can experiment with a more sparse structure that still provides a lot of the wins in this revision.

This dropped the compilation time of a somewhat large TensorFlow model from ~650 seconds to ~400 seconds.

Differential Revision: https://reviews.llvm.org/D102035

19 files changed:
mlir/include/mlir/Dialect/GPU/GPUOps.td
mlir/include/mlir/IR/BuiltinAttributes.td
mlir/include/mlir/IR/FunctionImplementation.h
mlir/include/mlir/IR/FunctionSupport.h
mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp
mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
mlir/lib/IR/BuiltinDialect.cpp
mlir/lib/IR/FunctionImplementation.cpp
mlir/lib/IR/FunctionSupport.cpp
mlir/lib/Transforms/Utils/DialectConversion.cpp
mlir/test/Dialect/LLVMIR/func.mlir
mlir/test/IR/invalid-func-op.mlir
mlir/test/IR/test-func-set-type.mlir

index 5bd6956..a29a22a 100644 (file)
@@ -213,16 +213,6 @@ def GPU_GPUFuncOp : GPU_Op<"func", [HasParent<"GPUModuleOp">,
           GPUDialect::getKernelFuncAttrName()) != nullptr;
     }
 
-    /// Change the type of this function in place. This is an extremely
-    /// dangerous operation and it is up to the caller to ensure that this is
-    /// legal for this function, and to restore invariants:
-    ///  - the entry block args must be updated to match the function params.
-    ///  - the argument/result attributes may need an update: if the new type
-    ///  has less parameters we drop the extra attributes, if there are more
-    ///  parameters they won't have any attributes.
-    // TODO: consider removing this function thanks to rewrite patterns.
-    void setType(FunctionType newType);
-
     /// Returns the number of buffers located in the workgroup memory.
     unsigned getNumWorkgroupAttributions() {
       return (*this)->getAttrOfType<IntegerAttr>(
index c248ad5..05dbc6b 100644 (file)
@@ -300,7 +300,7 @@ def Builtin_DictionaryAttr : Builtin_Attr<"Dictionary"> {
   }];
   let parameters = (ins ArrayRefParameter<"NamedAttribute", "">:$value);
   let builders = [
-    AttrBuilder<(ins "ArrayRef<NamedAttribute>":$value)>
+    AttrBuilder<(ins CArg<"ArrayRef<NamedAttribute>", "llvm::None">:$value)>
   ];
   let extraClassDeclaration = [{
     using ValueType = ArrayRef<NamedAttribute>;
index c19100c..cb7776f 100644 (file)
@@ -20,7 +20,7 @@
 
 namespace mlir {
 
-namespace impl {
+namespace function_like_impl {
 
 /// A named class for passing around the variadic flag.
 class VariadicFlag {
@@ -38,6 +38,9 @@ private:
 /// Internally, argument and result attributes are stored as dict attributes
 /// with special names given by getResultAttrName, getArgumentAttrName.
 void addArgAndResultAttrs(Builder &builder, OperationState &result,
+                          ArrayRef<DictionaryAttr> argAttrs,
+                          ArrayRef<DictionaryAttr> resultAttrs);
+void addArgAndResultAttrs(Builder &builder, OperationState &result,
                           ArrayRef<NamedAttrList> argAttrs,
                           ArrayRef<NamedAttrList> resultAttrs);
 
@@ -103,7 +106,7 @@ void printFunctionAttributes(OpAsmPrinter &p, Operation *op, unsigned numInputs,
                              unsigned numResults,
                              ArrayRef<StringRef> elided = {});
 
-} // namespace impl
+} // namespace function_like_impl
 
 } // namespace mlir
 
index a3be1a7..21d6e37 100644 (file)
 
 namespace mlir {
 
-namespace impl {
+namespace function_like_impl {
 
 /// Return the name of the attribute used for function types.
 inline StringRef getTypeAttrName() { return "type"; }
 
-/// Return the name of the attribute used for function arguments.
-inline StringRef getArgAttrName(unsigned arg, SmallVectorImpl<char> &out) {
-  out.clear();
-  return ("arg" + Twine(arg)).toStringRef(out);
-}
-
-/// Returns true if the given name is a valid argument attribute name.
-inline bool isArgAttrName(StringRef name) {
-  APInt unused;
-  return name.startswith("arg") &&
-         !name.drop_front(3).getAsInteger(/*Radix=*/10, unused);
-}
+/// Return the name of the attribute used for function argument attributes.
+inline StringRef getArgDictAttrName() { return "arg_attrs"; }
 
-/// Return the name of the attribute used for function results.
-inline StringRef getResultAttrName(unsigned arg, SmallVectorImpl<char> &out) {
-  out.clear();
-  return ("result" + Twine(arg)).toStringRef(out);
-}
+/// Return the name of the attribute used for function argument attributes.
+inline StringRef getResultDictAttrName() { return "res_attrs"; }
 
 /// Returns the dictionary attribute corresponding to the argument at 'index'.
 /// If there are no argument attributes at 'index', a null attribute is
 /// returned.
-inline DictionaryAttr getArgAttrDict(Operation *op, unsigned index) {
-  SmallString<8> nameOut;
-  return op->getAttrOfType<DictionaryAttr>(getArgAttrName(index, nameOut));
-}
+DictionaryAttr getArgAttrDict(Operation *op, unsigned index);
 
 /// Returns the dictionary attribute corresponding to the result at 'index'.
 /// If there are no result attributes at 'index', a null attribute is
 /// returned.
-inline DictionaryAttr getResultAttrDict(Operation *op, unsigned index) {
-  SmallString<8> nameOut;
-  return op->getAttrOfType<DictionaryAttr>(getResultAttrName(index, nameOut));
-}
+DictionaryAttr getResultAttrDict(Operation *op, unsigned index);
+
+namespace detail {
+/// Update the given index into an argument or result attribute dictionary.
+void setArgResAttrDict(Operation *op, StringRef attrName,
+                       unsigned numTotalIndices, unsigned index,
+                       DictionaryAttr attrs);
+} // namespace detail
+
+/// Set all of the argument or result attribute dictionaries for a function. The
+/// size of `attrs` is expected to match the number of arguments/results of the
+/// given `op`.
+void setAllArgAttrDicts(Operation *op, ArrayRef<DictionaryAttr> attrs);
+void setAllArgAttrDicts(Operation *op, ArrayRef<Attribute> attrs);
+void setAllResultAttrDicts(Operation *op, ArrayRef<DictionaryAttr> attrs);
+void setAllResultAttrDicts(Operation *op, ArrayRef<Attribute> attrs);
 
 /// Return all of the attributes for the argument at 'index'.
 inline ArrayRef<NamedAttribute> getArgAttrs(Operation *op, unsigned index) {
@@ -87,7 +83,7 @@ void setFunctionType(Operation *op, FunctionType newType);
 /// Get a FunctionLike operation's body.
 Region &getFunctionBody(Operation *op);
 
-} // namespace impl
+} // namespace function_like_impl
 
 namespace OpTrait {
 
@@ -142,7 +138,7 @@ public:
   bool isExternal() { return empty(); }
 
   Region &getBody() {
-    return ::mlir::impl::getFunctionBody(this->getOperation());
+    return function_like_impl::getFunctionBody(this->getOperation());
   }
 
   /// Delete all blocks from this function.
@@ -194,7 +190,9 @@ public:
   //===--------------------------------------------------------------------===//
 
   /// Return the name of the attribute used for function types.
-  static StringRef getTypeAttrName() { return ::mlir::impl::getTypeAttrName(); }
+  static StringRef getTypeAttrName() {
+    return function_like_impl::getTypeAttrName();
+  }
 
   TypeAttr getTypeAttr() {
     return this->getOperation()->template getAttrOfType<TypeAttr>(
@@ -207,7 +205,7 @@ public:
   /// hide this one if the concrete class does not use FunctionType for the
   /// function type under the hood.
   FunctionType getType() {
-    return ::mlir::impl::getFunctionType(this->getOperation());
+    return function_like_impl::getFunctionType(this->getOperation());
   }
 
   /// Return the type of this function without the specified arguments and
@@ -277,8 +275,8 @@ public:
   void eraseArguments(ArrayRef<unsigned> argIndices) {
     unsigned originalNumArgs = getNumArguments();
     Type newType = getTypeWithoutArgsAndResults(argIndices, {});
-    ::mlir::impl::eraseFunctionArguments(this->getOperation(), argIndices,
-                                         originalNumArgs, newType);
+    function_like_impl::eraseFunctionArguments(this->getOperation(), argIndices,
+                                               originalNumArgs, newType);
   }
 
   /// Erase a single result at `resultIndex`.
@@ -289,8 +287,8 @@ public:
   void eraseResults(ArrayRef<unsigned> resultIndices) {
     unsigned originalNumResults = getNumResults();
     Type newType = getTypeWithoutArgsAndResults({}, resultIndices);
-    ::mlir::impl::eraseFunctionResults(this->getOperation(), resultIndices,
-                                       originalNumResults, newType);
+    function_like_impl::eraseFunctionResults(
+        this->getOperation(), resultIndices, originalNumResults, newType);
   }
 
   //===--------------------------------------------------------------------===//
@@ -306,14 +304,23 @@ public:
 
   /// Return all of the attributes for the argument at 'index'.
   ArrayRef<NamedAttribute> getArgAttrs(unsigned index) {
-    return ::mlir::impl::getArgAttrs(this->getOperation(), index);
+    return function_like_impl::getArgAttrs(this->getOperation(), index);
   }
 
-  /// Return all argument attributes of this function. If an argument does not
-  /// have any attributes, the corresponding entry in `result` is nullptr.
+  /// Return an ArrayAttr containing all argument attribute dictionaries of this
+  /// function, or nullptr if no arguments have attributes.
+  ArrayAttr getAllArgAttrs() {
+    return this->getOperation()->template getAttrOfType<ArrayAttr>(
+        function_like_impl::getArgDictAttrName());
+  }
+  /// Return all argument attributes of this function.
   void getAllArgAttrs(SmallVectorImpl<DictionaryAttr> &result) {
-    for (unsigned i = 0, e = getNumArguments(); i != e; ++i)
-      result.emplace_back(getArgAttrDict(i));
+    if (ArrayAttr argAttrs = getAllArgAttrs()) {
+      auto argAttrRange = argAttrs.template getAsRange<DictionaryAttr>();
+      result.append(argAttrRange.begin(), argAttrRange.end());
+    } else {
+      result.resize(getNumArguments());
+    }
   }
 
   /// Return the specified attribute, if present, for the argument at 'index',
@@ -342,7 +349,19 @@ public:
   /// Set the attributes held by the argument at 'index'. `attributes` may be
   /// null, in which case any existing argument attributes are removed.
   void setArgAttrs(unsigned index, DictionaryAttr attributes);
-  void setAllArgAttrs(ArrayRef<DictionaryAttr> attributes);
+  void setAllArgAttrs(ArrayRef<DictionaryAttr> attributes) {
+    assert(attributes.size() == getNumArguments());
+    function_like_impl::setAllArgAttrDicts(this->getOperation(), attributes);
+  }
+  void setAllArgAttrs(ArrayRef<Attribute> attributes) {
+    assert(attributes.size() == getNumArguments());
+    function_like_impl::setAllArgAttrDicts(this->getOperation(), attributes);
+  }
+  void setAllArgAttrs(ArrayAttr attributes) {
+    assert(attributes.size() == getNumArguments());
+    this->getOperation()->setAttr(function_like_impl::getArgDictAttrName(),
+                                  attributes);
+  }
 
   /// If the an attribute exists with the specified name, change it to the new
   /// value. Otherwise, add a new attribute with the specified name/value.
@@ -370,14 +389,23 @@ public:
 
   /// Return all of the attributes for the result at 'index'.
   ArrayRef<NamedAttribute> getResultAttrs(unsigned index) {
-    return ::mlir::impl::getResultAttrs(this->getOperation(), index);
+    return function_like_impl::getResultAttrs(this->getOperation(), index);
   }
 
-  /// Return all result attributes of this function. If a result does not have
-  /// any attributes, the corresponding entry in `result` is nullptr.
+  /// Return an ArrayAttr containing all result attribute dictionaries of this
+  /// function, or nullptr if no result have attributes.
+  ArrayAttr getAllResultAttrs() {
+    return this->getOperation()->template getAttrOfType<ArrayAttr>(
+        function_like_impl::getResultDictAttrName());
+  }
+  /// Return all result attributes of this function.
   void getAllResultAttrs(SmallVectorImpl<DictionaryAttr> &result) {
-    for (unsigned i = 0, e = getNumResults(); i != e; ++i)
-      result.emplace_back(getResultAttrDict(i));
+    if (ArrayAttr argAttrs = getAllResultAttrs()) {
+      auto argAttrRange = argAttrs.template getAsRange<DictionaryAttr>();
+      result.append(argAttrRange.begin(), argAttrRange.end());
+    } else {
+      result.resize(getNumResults());
+    }
   }
 
   /// Return the specified attribute, if present, for the result at 'index',
@@ -402,10 +430,23 @@ public:
 
   /// Set the attributes held by the result at 'index'.
   void setResultAttrs(unsigned index, ArrayRef<NamedAttribute> attributes);
+
   /// Set the attributes held by the result at 'index'. `attributes` may be
   /// null, in which case any existing argument attributes are removed.
   void setResultAttrs(unsigned index, DictionaryAttr attributes);
-  void setAllResultAttrs(ArrayRef<DictionaryAttr> attributes);
+  void setAllResultAttrs(ArrayRef<DictionaryAttr> attributes) {
+    assert(attributes.size() == getNumResults());
+    function_like_impl::setAllResultAttrDicts(this->getOperation(), attributes);
+  }
+  void setAllResultAttrs(ArrayRef<Attribute> attributes) {
+    assert(attributes.size() == getNumResults());
+    function_like_impl::setAllResultAttrDicts(this->getOperation(), attributes);
+  }
+  void setAllResultAttrs(ArrayAttr attributes) {
+    assert(attributes.size() == getNumResults());
+    this->getOperation()->setAttr(function_like_impl::getResultDictAttrName(),
+                                  attributes);
+  }
 
   /// If the an attribute exists with the specified name, change it to the new
   /// value. Otherwise, add a new attribute with the specified name/value.
@@ -422,25 +463,12 @@ public:
   Attribute removeResultAttr(unsigned index, Identifier name);
 
 protected:
-  /// Returns the attribute entry name for the set of argument attributes at
-  /// 'index'.
-  static StringRef getArgAttrName(unsigned index, SmallVectorImpl<char> &out) {
-    return ::mlir::impl::getArgAttrName(index, out);
-  }
-
   /// Returns the dictionary attribute corresponding to the argument at 'index'.
   /// If there are no argument attributes at 'index', a null attribute is
   /// returned.
   DictionaryAttr getArgAttrDict(unsigned index) {
     assert(index < getNumArguments() && "invalid argument number");
-    return ::mlir::impl::getArgAttrDict(this->getOperation(), index);
-  }
-
-  /// Returns the attribute entry name for the set of result attributes at
-  /// 'index'.
-  static StringRef getResultAttrName(unsigned index,
-                                     SmallVectorImpl<char> &out) {
-    return ::mlir::impl::getResultAttrName(index, out);
+    return function_like_impl::getArgAttrDict(this->getOperation(), index);
   }
 
   /// Returns the dictionary attribute corresponding to the result at 'index'.
@@ -448,7 +476,7 @@ protected:
   /// returned.
   DictionaryAttr getResultAttrDict(unsigned index) {
     assert(index < getNumResults() && "invalid result number");
-    return ::mlir::impl::getResultAttrDict(this->getOperation(), index);
+    return function_like_impl::getResultAttrDict(this->getOperation(), index);
   }
 
   /// Hook for concrete classes to verify that the type attribute respects
@@ -475,9 +503,7 @@ LogicalResult FunctionLike<ConcreteType>::verifyBody() {
 
 template <typename ConcreteType>
 LogicalResult FunctionLike<ConcreteType>::verifyTrait(Operation *op) {
-  MLIRContext *ctx = op->getContext();
   auto funcOp = cast<ConcreteType>(op);
-
   if (!funcOp.isTypeAttrValid())
     return funcOp.emitOpError("requires a type attribute '")
            << getTypeAttrName() << '\'';
@@ -485,35 +511,69 @@ LogicalResult FunctionLike<ConcreteType>::verifyTrait(Operation *op) {
   if (failed(funcOp.verifyType()))
     return failure();
 
-  for (unsigned i = 0, e = funcOp.getNumArguments(); i != e; ++i) {
-    // Verify that all of the argument attributes are dialect attributes, i.e.
-    // that they contain a dialect prefix in their name.  Call the dialect, if
-    // registered, to verify the attributes themselves.
-    for (auto attr : funcOp.getArgAttrs(i)) {
-      if (!attr.first.strref().contains('.'))
-        return funcOp.emitOpError("arguments may only have dialect attributes");
-      auto dialectNamePair = attr.first.strref().split('.');
-      if (auto *dialect = ctx->getLoadedDialect(dialectNamePair.first)) {
-        if (failed(dialect->verifyRegionArgAttribute(op, /*regionIndex=*/0,
-                                                     /*argIndex=*/i, attr)))
-          return failure();
+  if (ArrayAttr allArgAttrs = funcOp.getAllArgAttrs()) {
+    unsigned numArgs = funcOp.getNumArguments();
+    if (allArgAttrs.size() != numArgs) {
+      return funcOp.emitOpError()
+             << "expects argument attribute array `"
+             << function_like_impl::getArgDictAttrName()
+             << "` to have the same number of elements as the number of "
+                "function arguments, got "
+             << allArgAttrs.size() << ", but expected " << numArgs;
+    }
+    for (unsigned i = 0; i != numArgs; ++i) {
+      DictionaryAttr argAttrs = allArgAttrs[i].dyn_cast<DictionaryAttr>();
+      if (!argAttrs) {
+        return funcOp.emitOpError() << "expects argument attribute dictionary "
+                                       "to be a DictionaryAttr, but got `"
+                                    << allArgAttrs[i] << "`";
+      }
+
+      // Verify that all of the argument attributes are dialect attributes, i.e.
+      // that they contain a dialect prefix in their name.  Call the dialect, if
+      // registered, to verify the attributes themselves.
+      for (auto attr : argAttrs) {
+        if (!attr.first.strref().contains('.'))
+          return funcOp.emitOpError(
+              "arguments may only have dialect attributes");
+        if (Dialect *dialect = attr.first.getDialect()) {
+          if (failed(dialect->verifyRegionArgAttribute(op, /*regionIndex=*/0,
+                                                       /*argIndex=*/i, attr)))
+            return failure();
+        }
       }
     }
   }
+  if (ArrayAttr allResultAttrs = funcOp.getAllResultAttrs()) {
+    unsigned numResults = funcOp.getNumResults();
+    if (allResultAttrs.size() != numResults) {
+      return funcOp.emitOpError()
+             << "expects result attribute array `"
+             << function_like_impl::getResultDictAttrName()
+             << "` to have the same number of elements as the number of "
+                "function results, got "
+             << allResultAttrs.size() << ", but expected " << numResults;
+    }
+    for (unsigned i = 0; i != numResults; ++i) {
+      DictionaryAttr resultAttrs = allResultAttrs[i].dyn_cast<DictionaryAttr>();
+      if (!resultAttrs) {
+        return funcOp.emitOpError() << "expects result attribute dictionary "
+                                       "to be a DictionaryAttr, but got `"
+                                    << allResultAttrs[i] << "`";
+      }
 
-  for (unsigned i = 0, e = funcOp.getNumResults(); i != e; ++i) {
-    // Verify that all of the result attributes are dialect attributes, i.e.
-    // that they contain a dialect prefix in their name.  Call the dialect, if
-    // registered, to verify the attributes themselves.
-    for (auto attr : funcOp.getResultAttrs(i)) {
-      if (!attr.first.strref().contains('.'))
-        return funcOp.emitOpError("results may only have dialect attributes");
-      auto dialectNamePair = attr.first.strref().split('.');
-      if (auto *dialect = ctx->getLoadedDialect(dialectNamePair.first)) {
-        if (failed(dialect->verifyRegionResultAttribute(op, /*regionIndex=*/0,
-                                                        /*resultIndex=*/i,
-                                                        attr)))
-          return failure();
+      // Verify that all of the result attributes are dialect attributes, i.e.
+      // that they contain a dialect prefix in their name.  Call the dialect, if
+      // registered, to verify the attributes themselves.
+      for (auto attr : resultAttrs) {
+        if (!attr.first.strref().contains('.'))
+          return funcOp.emitOpError("results may only have dialect attributes");
+        if (Dialect *dialect = attr.first.getDialect()) {
+          if (failed(dialect->verifyRegionResultAttribute(op, /*regionIndex=*/0,
+                                                          /*resultIndex=*/i,
+                                                          attr)))
+            return failure();
+        }
       }
     }
   }
@@ -551,7 +611,7 @@ Block *FunctionLike<ConcreteType>::addBlock() {
 
 template <typename ConcreteType>
 void FunctionLike<ConcreteType>::setType(FunctionType newType) {
-  ::mlir::impl::setFunctionType(this->getOperation(), newType);
+  function_like_impl::setFunctionType(this->getOperation(), newType);
 }
 
 //===----------------------------------------------------------------------===//
@@ -563,45 +623,19 @@ template <typename ConcreteType>
 void FunctionLike<ConcreteType>::setArgAttrs(
     unsigned index, ArrayRef<NamedAttribute> attributes) {
   assert(index < getNumArguments() && "invalid argument number");
-  SmallString<8> nameOut;
-  getArgAttrName(index, nameOut);
-
   Operation *op = this->getOperation();
-  if (attributes.empty())
-    return (void)op->removeAttr(nameOut);
-  op->setAttr(nameOut, DictionaryAttr::get(op->getContext(), attributes));
+  return function_like_impl::detail::setArgResAttrDict(
+      op, function_like_impl::getArgDictAttrName(), getNumArguments(), index,
+      DictionaryAttr::get(op->getContext(), attributes));
 }
 
 template <typename ConcreteType>
 void FunctionLike<ConcreteType>::setArgAttrs(unsigned index,
                                              DictionaryAttr attributes) {
-  assert(index < getNumArguments() && "invalid argument number");
-  SmallString<8> nameOut;
-  if (!attributes || attributes.empty())
-    this->getOperation()->removeAttr(getArgAttrName(index, nameOut));
-  else
-    return this->getOperation()->setAttr(getArgAttrName(index, nameOut),
-                                         attributes);
-}
-
-template <typename ConcreteType>
-void FunctionLike<ConcreteType>::setAllArgAttrs(
-    ArrayRef<DictionaryAttr> attributes) {
-  assert(attributes.size() == getNumArguments());
-  NamedAttrList attrs = this->getOperation()->getAttrs();
-
-  // Instead of calling setArgAttrs() multiple times, which rebuild the
-  // attribute dictionary every time, build a new list of attributes for the
-  // operation so that we rebuild the attribute dictionary in one shot.
-  SmallString<8> argAttrName;
-  for (unsigned i = 0, e = attributes.size(); i != e; ++i) {
-    StringRef attrName = getArgAttrName(i, argAttrName);
-    if (!attributes[i] || attributes[i].empty())
-      attrs.erase(attrName);
-    else
-      attrs.set(attrName, attributes[i]);
-  }
-  this->getOperation()->setAttrs(attrs);
+  Operation *op = this->getOperation();
+  return function_like_impl::detail::setArgResAttrDict(
+      op, function_like_impl::getArgDictAttrName(), getNumArguments(), index,
+      attributes ? attributes : DictionaryAttr::get(op->getContext()));
 }
 
 /// If the an attribute exists with the specified name, change it to the new
@@ -640,45 +674,20 @@ template <typename ConcreteType>
 void FunctionLike<ConcreteType>::setResultAttrs(
     unsigned index, ArrayRef<NamedAttribute> attributes) {
   assert(index < getNumResults() && "invalid result number");
-  SmallString<8> nameOut;
-  getResultAttrName(index, nameOut);
-
-  if (attributes.empty())
-    return (void)this->getOperation()->removeAttr(nameOut);
   Operation *op = this->getOperation();
-  op->setAttr(nameOut, DictionaryAttr::get(op->getContext(), attributes));
+  return function_like_impl::detail::setArgResAttrDict(
+      op, function_like_impl::getResultDictAttrName(), getNumResults(), index,
+      DictionaryAttr::get(op->getContext(), attributes));
 }
 
 template <typename ConcreteType>
 void FunctionLike<ConcreteType>::setResultAttrs(unsigned index,
                                                 DictionaryAttr attributes) {
   assert(index < getNumResults() && "invalid result number");
-  SmallString<8> nameOut;
-  if (!attributes || attributes.empty())
-    this->getOperation()->removeAttr(getResultAttrName(index, nameOut));
-  else
-    this->getOperation()->setAttr(getResultAttrName(index, nameOut),
-                                  attributes);
-}
-
-template <typename ConcreteType>
-void FunctionLike<ConcreteType>::setAllResultAttrs(
-    ArrayRef<DictionaryAttr> attributes) {
-  assert(attributes.size() == getNumResults());
-  NamedAttrList attrs = this->getOperation()->getAttrs();
-
-  // Instead of calling setResultAttrs() multiple times, which rebuild the
-  // attribute dictionary every time, build a new list of attributes for the
-  // operation so that we rebuild the attribute dictionary in one shot.
-  SmallString<8> resultAttrName;
-  for (unsigned i = 0, e = attributes.size(); i != e; ++i) {
-    StringRef attrName = getResultAttrName(i, resultAttrName);
-    if (!attributes[i] || attributes[i].empty())
-      attrs.erase(attrName);
-    else
-      attrs.set(attrName, attributes[i]);
-  }
-  this->getOperation()->setAttrs(attrs);
+  Operation *op = this->getOperation();
+  return function_like_impl::detail::setArgResAttrDict(
+      op, function_like_impl::getResultDictAttrName(), getNumResults(), index,
+      attributes ? attributes : DictionaryAttr::get(op->getContext()));
 }
 
 /// If the an attribute exists with the specified name, change it to the new
index 0833953..67f699a 100644 (file)
@@ -58,7 +58,7 @@ GPUFuncOpLowering::matchAndRewrite(gpu::GPUFuncOp gpuFuncOp,
   SmallVector<NamedAttribute, 4> attributes;
   for (const auto &attr : gpuFuncOp->getAttrs()) {
     if (attr.first == SymbolTable::getSymbolAttrName() ||
-        attr.first == impl::getTypeAttrName() ||
+        attr.first == function_like_impl::getTypeAttrName() ||
         attr.first == gpu::GPUFuncOp::getNumWorkgroupAttributionsAttrName())
       continue;
     attributes.push_back(attr);
index 2066deb..fa4bbff 100644 (file)
@@ -195,7 +195,7 @@ lowerAsEntryFunction(gpu::GPUFuncOp funcOp, TypeConverter &typeConverter,
       rewriter.getFunctionType(signatureConverter.getConvertedTypes(),
                                llvm::None));
   for (const auto &namedAttr : funcOp->getAttrs()) {
-    if (namedAttr.first == impl::getTypeAttrName() ||
+    if (namedAttr.first == function_like_impl::getTypeAttrName() ||
         namedAttr.first == SymbolTable::getSymbolAttrName())
       continue;
     newFuncOp->setAttr(namedAttr.first, namedAttr.second);
index 5f94804..3949cd4 100644 (file)
@@ -1211,8 +1211,10 @@ static void filterFuncAttributes(ArrayRef<NamedAttribute> attrs,
                                  SmallVectorImpl<NamedAttribute> &result) {
   for (const auto &attr : attrs) {
     if (attr.first == SymbolTable::getSymbolAttrName() ||
-        attr.first == impl::getTypeAttrName() || attr.first == "std.varargs" ||
-        (filterArgAttrs && impl::isArgAttrName(attr.first.strref())))
+        attr.first == function_like_impl::getTypeAttrName() ||
+        attr.first == "std.varargs" ||
+        (filterArgAttrs &&
+         attr.first == function_like_impl::getArgDictAttrName()))
       continue;
     result.push_back(attr);
   }
@@ -1395,19 +1397,19 @@ protected:
     SmallVector<NamedAttribute, 4> attributes;
     filterFuncAttributes(funcOp->getAttrs(), /*filterArgAttrs=*/true,
                          attributes);
-    for (unsigned i = 0, e = funcOp.getNumArguments(); i < e; ++i) {
-      auto attr = impl::getArgAttrDict(funcOp, i);
-      if (!attr)
-        continue;
-
-      auto mapping = result.getInputMapping(i);
-      assert(mapping.hasValue() && "unexpected deletion of function argument");
-
-      SmallString<8> name;
-      for (size_t j = 0; j < mapping->size; ++j) {
-        impl::getArgAttrName(mapping->inputNo + j, name);
-        attributes.push_back(rewriter.getNamedAttr(name, attr));
+    if (ArrayAttr argAttrDicts = funcOp.getAllArgAttrs()) {
+      SmallVector<Attribute, 4> newArgAttrs(
+          llvmType.cast<LLVM::LLVMFunctionType>().getNumParams());
+      for (unsigned i = 0, e = funcOp.getNumArguments(); i < e; ++i) {
+        auto mapping = result.getInputMapping(i);
+        assert(mapping.hasValue() &&
+               "unexpected deletion of function argument");
+        for (size_t j = 0; j < mapping->size; ++j)
+          newArgAttrs[mapping->inputNo + j] = argAttrDicts[i];
       }
+      attributes.push_back(
+          rewriter.getNamedAttr(function_like_impl::getArgDictAttrName(),
+                                rewriter.getArrayAttr(newArgAttrs)));
     }
 
     // Create an LLVM function, use external linkage by default until MLIR
index 1fa687f..1f081d8 100644 (file)
@@ -599,9 +599,9 @@ parseLaunchFuncOperands(OpAsmParser &parser,
     return success();
   SmallVector<NamedAttrList, 4> argAttrs;
   bool isVariadic = false;
-  return impl::parseFunctionArgumentList(parser, /*allowAttributes=*/false,
-                                         /*allowVariadic=*/false, argNames,
-                                         argTypes, argAttrs, isVariadic);
+  return function_like_impl::parseFunctionArgumentList(
+      parser, /*allowAttributes=*/false,
+      /*allowVariadic=*/false, argNames, argTypes, argAttrs, isVariadic);
 }
 
 static void printLaunchFuncOperands(OpAsmPrinter &printer, Operation *,
@@ -717,7 +717,7 @@ static ParseResult parseGPUFuncOp(OpAsmParser &parser, OperationState &result) {
     return failure();
 
   auto signatureLocation = parser.getCurrentLocation();
-  if (failed(impl::parseFunctionSignature(
+  if (failed(function_like_impl::parseFunctionSignature(
           parser, /*allowVariadic=*/false, entryArgs, argTypes, argAttrs,
           isVariadic, resultTypes, resultAttrs)))
     return failure();
@@ -756,7 +756,8 @@ static ParseResult parseGPUFuncOp(OpAsmParser &parser, OperationState &result) {
   // Parse attributes.
   if (failed(parser.parseOptionalAttrDictWithKeyword(result.attributes)))
     return failure();
-  mlir::impl::addArgAndResultAttrs(builder, result, argAttrs, resultAttrs);
+  function_like_impl::addArgAndResultAttrs(builder, result, argAttrs,
+                                           resultAttrs);
 
   // Parse the region. If no argument names were provided, take all names
   // (including those of attributions) from the entry block.
@@ -781,33 +782,22 @@ static void printGPUFuncOp(OpAsmPrinter &p, GPUFuncOp op) {
   p.printSymbolName(op.getName());
 
   FunctionType type = op.getType();
-  impl::printFunctionSignature(p, op.getOperation(), type.getInputs(),
-                               /*isVariadic=*/false, type.getResults());
+  function_like_impl::printFunctionSignature(
+      p, op.getOperation(), type.getInputs(),
+      /*isVariadic=*/false, type.getResults());
 
   printAttributions(p, op.getWorkgroupKeyword(), op.getWorkgroupAttributions());
   printAttributions(p, op.getPrivateKeyword(), op.getPrivateAttributions());
   if (op.isKernel())
     p << ' ' << op.getKernelKeyword();
 
-  impl::printFunctionAttributes(p, op.getOperation(), type.getNumInputs(),
-                                type.getNumResults(),
-                                {op.getNumWorkgroupAttributionsAttrName(),
-                                 GPUDialect::getKernelFuncAttrName()});
+  function_like_impl::printFunctionAttributes(
+      p, op.getOperation(), type.getNumInputs(), type.getNumResults(),
+      {op.getNumWorkgroupAttributionsAttrName(),
+       GPUDialect::getKernelFuncAttrName()});
   p.printRegion(op.getBody(), /*printEntryBlockArgs=*/false);
 }
 
-void GPUFuncOp::setType(FunctionType newType) {
-  auto oldType = getType();
-  assert(newType.getNumResults() == oldType.getNumResults() &&
-         "unimplemented: changes to the number of results");
-
-  SmallVector<char, 16> nameBuf;
-  for (int i = newType.getNumInputs(), e = oldType.getNumInputs(); i < e; i++)
-    (*this)->removeAttr(getArgAttrName(i, nameBuf));
-
-  (*this)->setAttr(getTypeAttrName(), TypeAttr::get(newType));
-}
-
 /// Hook for FunctionLike verifier.
 LogicalResult GPUFuncOp::verifyType() {
   Type type = getTypeAttr().getValue();
index e1ad37e..12e6ccc 100644 (file)
@@ -1732,21 +1732,19 @@ void LLVMFuncOp::build(OpBuilder &builder, OperationState &result,
   if (argAttrs.empty())
     return;
 
-  unsigned numInputs = type.cast<LLVMFunctionType>().getNumParams();
-  assert(numInputs == argAttrs.size() &&
+  assert(type.cast<LLVMFunctionType>().getNumParams() == argAttrs.size() &&
          "expected as many argument attribute lists as arguments");
-  SmallString<8> argAttrName;
-  for (unsigned i = 0; i < numInputs; ++i)
-    if (DictionaryAttr argDict = argAttrs[i])
-      result.addAttribute(getArgAttrName(i, argAttrName), argDict);
+  function_like_impl::addArgAndResultAttrs(builder, result, argAttrs,
+                                           /*resultAttrs=*/llvm::None);
 }
 
 // Builds an LLVM function type from the given lists of input and output types.
 // Returns a null type if any of the types provided are non-LLVM types, or if
 // there is more than one output type.
-static Type buildLLVMFunctionType(OpAsmParser &parser, llvm::SMLoc loc,
-                                  ArrayRef<Type> inputs, ArrayRef<Type> outputs,
-                                  impl::VariadicFlag variadicFlag) {
+static Type
+buildLLVMFunctionType(OpAsmParser &parser, llvm::SMLoc loc,
+                      ArrayRef<Type> inputs, ArrayRef<Type> outputs,
+                      function_like_impl::VariadicFlag variadicFlag) {
   Builder &b = parser.getBuilder();
   if (outputs.size() > 1) {
     parser.emitError(loc, "failed to construct function type: expected zero or "
@@ -1803,22 +1801,23 @@ static ParseResult parseLLVMFuncOp(OpAsmParser &parser,
   auto signatureLocation = parser.getCurrentLocation();
   if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(),
                              result.attributes) ||
-      impl::parseFunctionSignature(parser, /*allowVariadic=*/true, entryArgs,
-                                   argTypes, argAttrs, isVariadic, resultTypes,
-                                   resultAttrs))
+      function_like_impl::parseFunctionSignature(
+          parser, /*allowVariadic=*/true, entryArgs, argTypes, argAttrs,
+          isVariadic, resultTypes, resultAttrs))
     return failure();
 
   auto type =
       buildLLVMFunctionType(parser, signatureLocation, argTypes, resultTypes,
-                            impl::VariadicFlag(isVariadic));
+                            function_like_impl::VariadicFlag(isVariadic));
   if (!type)
     return failure();
-  result.addAttribute(impl::getTypeAttrName(), TypeAttr::get(type));
+  result.addAttribute(function_like_impl::getTypeAttrName(),
+                      TypeAttr::get(type));
 
   if (failed(parser.parseOptionalAttrDictWithKeyword(result.attributes)))
     return failure();
-  impl::addArgAndResultAttrs(parser.getBuilder(), result, argAttrs,
-                             resultAttrs);
+  function_like_impl::addArgAndResultAttrs(parser.getBuilder(), result,
+                                           argAttrs, resultAttrs);
 
   auto *body = result.addRegion();
   OptionalParseResult parseResult = parser.parseOptionalRegion(
@@ -1846,9 +1845,10 @@ static void printLLVMFuncOp(OpAsmPrinter &p, LLVMFuncOp op) {
   if (!returnType.isa<LLVMVoidType>())
     resTypes.push_back(returnType);
 
-  impl::printFunctionSignature(p, op, argTypes, op.isVarArg(), resTypes);
-  impl::printFunctionAttributes(p, op, argTypes.size(), resTypes.size(),
-                                {getLinkageAttrName()});
+  function_like_impl::printFunctionSignature(p, op, argTypes, op.isVarArg(),
+                                             resTypes);
+  function_like_impl::printFunctionAttributes(
+      p, op, argTypes.size(), resTypes.size(), {getLinkageAttrName()});
 
   // Print the body if this is not an external function.
   Region &body = op.body();
index 4ca7da6..04b6353 100644 (file)
@@ -99,7 +99,7 @@ struct FunctionNonEntryBlockConversion : public ConversionPattern {
   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                   ConversionPatternRewriter &rewriter) const override {
     rewriter.startRootUpdate(op);
-    Region &region = mlir::impl::getFunctionBody(op);
+    Region &region = function_like_impl::getFunctionBody(op);
     SmallVector<TypeConverter::SignatureConversion, 2> conversions;
 
     for (Block &block : llvm::drop_begin(region, 1)) {
index 31a92c3..c74528c 100644 (file)
@@ -1783,13 +1783,14 @@ static ParseResult parseFuncOp(OpAsmParser &parser, OperationState &state) {
 
   // Parse the function signature.
   bool isVariadic = false;
-  if (impl::parseFunctionSignature(parser, /*allowVariadic=*/false, entryArgs,
-                                   argTypes, argAttrs, isVariadic, resultTypes,
-                                   resultAttrs))
+  if (function_like_impl::parseFunctionSignature(
+          parser, /*allowVariadic=*/false, entryArgs, argTypes, argAttrs,
+          isVariadic, resultTypes, resultAttrs))
     return failure();
 
   auto fnType = builder.getFunctionType(argTypes, resultTypes);
-  state.addAttribute(impl::getTypeAttrName(), TypeAttr::get(fnType));
+  state.addAttribute(function_like_impl::getTypeAttrName(),
+                     TypeAttr::get(fnType));
 
   // Parse the optional function control keyword.
   spirv::FunctionControl fnControl;
@@ -1803,7 +1804,8 @@ static ParseResult parseFuncOp(OpAsmParser &parser, OperationState &state) {
   // Add the attributes to the function arguments.
   assert(argAttrs.size() == argTypes.size());
   assert(resultAttrs.size() == resultTypes.size());
-  impl::addArgAndResultAttrs(builder, state, argAttrs, resultAttrs);
+  function_like_impl::addArgAndResultAttrs(builder, state, argAttrs,
+                                           resultAttrs);
 
   // Parse the optional function body.
   auto *body = state.addRegion();
@@ -1817,11 +1819,12 @@ static void print(spirv::FuncOp fnOp, OpAsmPrinter &printer) {
   printer << spirv::FuncOp::getOperationName() << " ";
   printer.printSymbolName(fnOp.sym_name());
   auto fnType = fnOp.getType();
-  impl::printFunctionSignature(printer, fnOp, fnType.getInputs(),
-                               /*isVariadic=*/false, fnType.getResults());
+  function_like_impl::printFunctionSignature(printer, fnOp, fnType.getInputs(),
+                                             /*isVariadic=*/false,
+                                             fnType.getResults());
   printer << " \"" << spirv::stringifyFunctionControl(fnOp.function_control())
           << "\"";
-  impl::printFunctionAttributes(
+  function_like_impl::printFunctionAttributes(
       printer, fnOp, fnType.getNumInputs(), fnType.getNumResults(),
       {spirv::attributeName<spirv::FunctionControl>()});
 
index 5a59021..6e807a7 100644 (file)
@@ -582,7 +582,7 @@ FuncOpConversion::matchAndRewrite(FuncOp funcOp, ArrayRef<Value> operands,
 
   // Copy over all attributes other than the function name and type.
   for (const auto &namedAttr : funcOp->getAttrs()) {
-    if (namedAttr.first != impl::getTypeAttrName() &&
+    if (namedAttr.first != function_like_impl::getTypeAttrName() &&
         namedAttr.first != SymbolTable::getSymbolAttrName())
       newFuncOp->setAttr(namedAttr.first, namedAttr.second);
   }
index e1706f2..728443e 100644 (file)
@@ -106,27 +106,25 @@ void FuncOp::build(OpBuilder &builder, OperationState &state, StringRef name,
   if (argAttrs.empty())
     return;
   assert(type.getNumInputs() == argAttrs.size());
-  SmallString<8> argAttrName;
-  for (unsigned i = 0, e = type.getNumInputs(); i != e; ++i)
-    if (DictionaryAttr argDict = argAttrs[i])
-      state.addAttribute(getArgAttrName(i, argAttrName), argDict);
+  function_like_impl::addArgAndResultAttrs(builder, state, argAttrs,
+                                           /*resultAttrs=*/llvm::None);
 }
 
 static ParseResult parseFuncOp(OpAsmParser &parser, OperationState &result) {
   auto buildFuncType = [](Builder &builder, ArrayRef<Type> argTypes,
-                          ArrayRef<Type> results, impl::VariadicFlag,
-                          std::string &) {
+                          ArrayRef<Type> results,
+                          function_like_impl::VariadicFlag, std::string &) {
     return builder.getFunctionType(argTypes, results);
   };
 
-  return impl::parseFunctionLikeOp(parser, result, /*allowVariadic=*/false,
-                                   buildFuncType);
+  return function_like_impl::parseFunctionLikeOp(
+      parser, result, /*allowVariadic=*/false, buildFuncType);
 }
 
 static void print(FuncOp op, OpAsmPrinter &p) {
   FunctionType fnType = op.getType();
-  impl::printFunctionLikeOp(p, op, fnType.getInputs(), /*isVariadic=*/false,
-                            fnType.getResults());
+  function_like_impl::printFunctionLikeOp(
+      p, op, fnType.getInputs(), /*isVariadic=*/false, fnType.getResults());
 }
 
 static LogicalResult verify(FuncOp op) {
@@ -170,30 +168,39 @@ void FuncOp::cloneInto(FuncOp dest, BlockAndValueMapping &mapper) {
 /// to cloned sub-values with the corresponding value that is copied, and adds
 /// those mappings to the mapper.
 FuncOp FuncOp::clone(BlockAndValueMapping &mapper) {
-  FunctionType newType = getType();
+  // Create the new function.
+  FuncOp newFunc = cast<FuncOp>(getOperation()->cloneWithoutRegions());
 
   // If the function has a body, then the user might be deleting arguments to
   // the function by specifying them in the mapper. If so, we don't add the
   // argument to the input type vector.
-  bool isExternalFn = isExternal();
-  if (!isExternalFn) {
-    SmallVector<Type, 4> inputTypes;
-    inputTypes.reserve(newType.getNumInputs());
-    for (unsigned i = 0, e = getNumArguments(); i != e; ++i)
+  if (!isExternal()) {
+    FunctionType oldType = getType();
+
+    unsigned oldNumArgs = oldType.getNumInputs();
+    SmallVector<Type, 4> newInputs;
+    newInputs.reserve(oldNumArgs);
+    for (unsigned i = 0; i != oldNumArgs; ++i)
       if (!mapper.contains(getArgument(i)))
-        inputTypes.push_back(newType.getInput(i));
-    newType = FunctionType::get(getContext(), inputTypes, newType.getResults());
+        newInputs.push_back(oldType.getInput(i));
+
+    /// If any of the arguments were dropped, update the type and drop any
+    /// necessary argument attributes.
+    if (newInputs.size() != oldNumArgs) {
+      newFunc.setType(FunctionType::get(oldType.getContext(), newInputs,
+                                        oldType.getResults()));
+
+      if (ArrayAttr argAttrs = getAllArgAttrs()) {
+        SmallVector<Attribute> newArgAttrs;
+        newArgAttrs.reserve(newInputs.size());
+        for (unsigned i = 0; i != oldNumArgs; ++i)
+          if (!mapper.contains(getArgument(i)))
+            newArgAttrs.push_back(argAttrs[i]);
+        newFunc.setAllArgAttrs(newArgAttrs);
+      }
+    }
   }
 
-  // Create the new function.
-  FuncOp newFunc = cast<FuncOp>(getOperation()->cloneWithoutRegions());
-  newFunc.setType(newType);
-
-  /// Set the argument attributes for arguments that aren't being replaced.
-  for (unsigned i = 0, e = getNumArguments(), destI = 0; i != e; ++i)
-    if (isExternalFn || !mapper.contains(getArgument(i)))
-      newFunc.setArgAttrs(destI++, getArgAttrs(i));
-
   /// Clone the current function into the new one and return it.
   cloneInto(newFunc, mapper);
   return newFunc;
index 4bec168..aadf545 100644 (file)
@@ -13,7 +13,7 @@
 
 using namespace mlir;
 
-ParseResult mlir::impl::parseFunctionArgumentList(
+ParseResult mlir::function_like_impl::parseFunctionArgumentList(
     OpAsmParser &parser, bool allowAttributes, bool allowVariadic,
     SmallVectorImpl<OpAsmParser::OperandType> &argNames,
     SmallVectorImpl<Type> &argTypes, SmallVectorImpl<NamedAttrList> &argAttrs,
@@ -125,7 +125,7 @@ parseFunctionResultList(OpAsmParser &parser, SmallVectorImpl<Type> &resultTypes,
 /// indicates whether functions with variadic arguments are supported. The
 /// trailing arguments are populated by this function with names, types and
 /// attributes of the arguments and those of the results.
-ParseResult mlir::impl::parseFunctionSignature(
+ParseResult mlir::function_like_impl::parseFunctionSignature(
     OpAsmParser &parser, bool allowVariadic,
     SmallVectorImpl<OpAsmParser::OperandType> &argNames,
     SmallVectorImpl<Type> &argTypes, SmallVectorImpl<NamedAttrList> &argAttrs,
@@ -140,29 +140,53 @@ ParseResult mlir::impl::parseFunctionSignature(
   return success();
 }
 
-void mlir::impl::addArgAndResultAttrs(Builder &builder, OperationState &result,
-                                      ArrayRef<NamedAttrList> argAttrs,
-                                      ArrayRef<NamedAttrList> resultAttrs) {
-  // Add the attributes to the function arguments.
-  SmallString<8> attrNameBuf;
-  for (unsigned i = 0, e = argAttrs.size(); i != e; ++i)
-    if (!argAttrs[i].empty())
-      result.addAttribute(getArgAttrName(i, attrNameBuf),
-                          builder.getDictionaryAttr(argAttrs[i]));
+/// Implementation of `addArgAndResultAttrs` that is attribute list type
+/// agnostic.
+template <typename AttrListT, typename AttrArrayBuildFnT>
+static void addArgAndResultAttrsImpl(Builder &builder, OperationState &result,
+                                     ArrayRef<AttrListT> argAttrs,
+                                     ArrayRef<AttrListT> resultAttrs,
+                                     AttrArrayBuildFnT &&buildAttrArrayFn) {
+  auto nonEmptyAttrsFn = [](const AttrListT &attrs) { return !attrs.empty(); };
 
+  // Add the attributes to the function arguments.
+  if (!argAttrs.empty() && llvm::any_of(argAttrs, nonEmptyAttrsFn)) {
+    ArrayAttr attrDicts = builder.getArrayAttr(buildAttrArrayFn(argAttrs));
+    result.addAttribute(function_like_impl::getArgDictAttrName(), attrDicts);
+  }
   // Add the attributes to the function results.
-  for (unsigned i = 0, e = resultAttrs.size(); i != e; ++i)
-    if (!resultAttrs[i].empty())
-      result.addAttribute(getResultAttrName(i, attrNameBuf),
-                          builder.getDictionaryAttr(resultAttrs[i]));
+  if (!resultAttrs.empty() && llvm::any_of(resultAttrs, nonEmptyAttrsFn)) {
+    ArrayAttr attrDicts = builder.getArrayAttr(buildAttrArrayFn(resultAttrs));
+    result.addAttribute(function_like_impl::getResultDictAttrName(), attrDicts);
+  }
+}
+
+void mlir::function_like_impl::addArgAndResultAttrs(
+    Builder &builder, OperationState &result, ArrayRef<DictionaryAttr> argAttrs,
+    ArrayRef<DictionaryAttr> resultAttrs) {
+  auto buildFn = [](ArrayRef<DictionaryAttr> attrs) {
+    return ArrayRef<Attribute>(attrs.data(), attrs.size());
+  };
+  addArgAndResultAttrsImpl(builder, result, argAttrs, resultAttrs, buildFn);
+}
+void mlir::function_like_impl::addArgAndResultAttrs(
+    Builder &builder, OperationState &result, ArrayRef<NamedAttrList> argAttrs,
+    ArrayRef<NamedAttrList> resultAttrs) {
+  MLIRContext *context = builder.getContext();
+  auto buildFn = [=](ArrayRef<NamedAttrList> attrs) {
+    return llvm::to_vector<8>(
+        llvm::map_range(attrs, [=](const NamedAttrList &attrList) -> Attribute {
+          return attrList.getDictionary(context);
+        }));
+  };
+  addArgAndResultAttrsImpl(builder, result, argAttrs, resultAttrs, buildFn);
 }
 
 /// Parser implementation for function-like operations.  Uses `funcTypeBuilder`
 /// to construct the custom function type given lists of input and output types.
-ParseResult
-mlir::impl::parseFunctionLikeOp(OpAsmParser &parser, OperationState &result,
-                                bool allowVariadic,
-                                mlir::impl::FuncTypeBuilder funcTypeBuilder) {
+ParseResult mlir::function_like_impl::parseFunctionLikeOp(
+    OpAsmParser &parser, OperationState &result, bool allowVariadic,
+    FuncTypeBuilder funcTypeBuilder) {
   SmallVector<OpAsmParser::OperandType, 4> entryArgs;
   SmallVector<NamedAttrList, 4> argAttrs;
   SmallVector<NamedAttrList, 4> resultAttrs;
@@ -187,13 +211,14 @@ mlir::impl::parseFunctionLikeOp(OpAsmParser &parser, OperationState &result,
     return failure();
 
   std::string errorMessage;
-  if (auto type = funcTypeBuilder(builder, argTypes, resultTypes,
-                                  impl::VariadicFlag(isVariadic), errorMessage))
-    result.addAttribute(getTypeAttrName(), TypeAttr::get(type));
-  else
+  Type type = funcTypeBuilder(builder, argTypes, resultTypes,
+                              VariadicFlag(isVariadic), errorMessage);
+  if (!type) {
     return parser.emitError(signatureLocation)
            << "failed to construct function type"
            << (errorMessage.empty() ? "" : ": ") << errorMessage;
+  }
+  result.addAttribute(getTypeAttrName(), TypeAttr::get(type));
 
   // If function attributes are present, parse them.
   NamedAttrList parsedAttributes;
@@ -236,35 +261,38 @@ mlir::impl::parseFunctionLikeOp(OpAsmParser &parser, OperationState &result,
   return success();
 }
 
-// Print a function result list.
+/// Print a function result list. The provided `attrs` must either be null, or
+/// contain a set of DictionaryAttrs of the same arity as `types`.
 static void printFunctionResultList(OpAsmPrinter &p, ArrayRef<Type> types,
-                                    ArrayRef<ArrayRef<NamedAttribute>> attrs) {
+                                    ArrayAttr attrs) {
   assert(!types.empty() && "Should not be called for empty result list.");
+  assert((!attrs || attrs.size() == types.size()) &&
+         "Invalid number of attributes.");
+
   auto &os = p.getStream();
-  bool needsParens =
-      types.size() > 1 || types[0].isa<FunctionType>() || !attrs[0].empty();
+  bool needsParens = types.size() > 1 || types[0].isa<FunctionType>() ||
+                     (attrs && !attrs[0].cast<DictionaryAttr>().empty());
   if (needsParens)
     os << '(';
-  llvm::interleaveComma(
-      llvm::zip(types, attrs), os,
-      [&](const std::tuple<Type, ArrayRef<NamedAttribute>> &t) {
-        p.printType(std::get<0>(t));
-        p.printOptionalAttrDict(std::get<1>(t));
-      });
+  llvm::interleaveComma(llvm::seq<size_t>(0, types.size()), os, [&](size_t i) {
+    p.printType(types[i]);
+    if (attrs)
+      p.printOptionalAttrDict(attrs[i].cast<DictionaryAttr>().getValue());
+  });
   if (needsParens)
     os << ')';
 }
 
 /// Print the signature of the function-like operation `op`.  Assumes `op` has
 /// the FunctionLike trait and passed the verification.
-void mlir::impl::printFunctionSignature(OpAsmPrinter &p, Operation *op,
-                                        ArrayRef<Type> argTypes,
-                                        bool isVariadic,
-                                        ArrayRef<Type> resultTypes) {
+void mlir::function_like_impl::printFunctionSignature(
+    OpAsmPrinter &p, Operation *op, ArrayRef<Type> argTypes, bool isVariadic,
+    ArrayRef<Type> resultTypes) {
   Region &body = op->getRegion(0);
   bool isExternal = body.empty();
 
   p << '(';
+  ArrayAttr argAttrs = op->getAttrOfType<ArrayAttr>(getArgDictAttrName());
   for (unsigned i = 0, e = argTypes.size(); i < e; ++i) {
     if (i > 0)
       p << ", ";
@@ -275,7 +303,8 @@ void mlir::impl::printFunctionSignature(OpAsmPrinter &p, Operation *op,
     }
 
     p.printType(argTypes[i]);
-    p.printOptionalAttrDict(::mlir::impl::getArgAttrs(op, i));
+    if (argAttrs)
+      p.printOptionalAttrDict(argAttrs[i].cast<DictionaryAttr>().getValue());
   }
 
   if (isVariadic) {
@@ -288,9 +317,7 @@ void mlir::impl::printFunctionSignature(OpAsmPrinter &p, Operation *op,
 
   if (!resultTypes.empty()) {
     p.getStream() << " -> ";
-    SmallVector<ArrayRef<NamedAttribute>, 4> resultAttrs;
-    for (int i = 0, e = resultTypes.size(); i < e; ++i)
-      resultAttrs.push_back(::mlir::impl::getResultAttrs(op, i));
+    auto resultAttrs = op->getAttrOfType<ArrayAttr>(getResultDictAttrName());
     printFunctionResultList(p, resultTypes, resultAttrs);
   }
 }
@@ -300,39 +327,25 @@ void mlir::impl::printFunctionSignature(OpAsmPrinter &p, Operation *op,
 /// function-like operation internally are not printed. Nothing is printed
 /// if all attributes are elided. Assumes `op` has the `FunctionLike` trait and
 /// passed the verification.
-void mlir::impl::printFunctionAttributes(OpAsmPrinter &p, Operation *op,
-                                         unsigned numInputs,
-                                         unsigned numResults,
-                                         ArrayRef<StringRef> elided) {
+void mlir::function_like_impl::printFunctionAttributes(
+    OpAsmPrinter &p, Operation *op, unsigned numInputs, unsigned numResults,
+    ArrayRef<StringRef> elided) {
   // Print out function attributes, if present.
   SmallVector<StringRef, 2> ignoredAttrs = {
-      ::mlir::SymbolTable::getSymbolAttrName(), getTypeAttrName()};
+      ::mlir::SymbolTable::getSymbolAttrName(), getTypeAttrName(),
+      getArgDictAttrName(), getResultDictAttrName()};
   ignoredAttrs.append(elided.begin(), elided.end());
 
-  SmallString<8> attrNameBuf;
-
-  // Ignore any argument attributes.
-  std::vector<SmallString<8>> argAttrStorage;
-  for (unsigned i = 0; i != numInputs; ++i)
-    if (op->getAttr(getArgAttrName(i, attrNameBuf)))
-      argAttrStorage.emplace_back(attrNameBuf);
-  ignoredAttrs.append(argAttrStorage.begin(), argAttrStorage.end());
-
-  // Ignore any result attributes.
-  std::vector<SmallString<8>> resultAttrStorage;
-  for (unsigned i = 0; i != numResults; ++i)
-    if (op->getAttr(getResultAttrName(i, attrNameBuf)))
-      resultAttrStorage.emplace_back(attrNameBuf);
-  ignoredAttrs.append(resultAttrStorage.begin(), resultAttrStorage.end());
-
   p.printOptionalAttrDictWithKeyword(op->getAttrs(), ignoredAttrs);
 }
 
 /// Printer implementation for function-like operations.  Accepts lists of
 /// argument and result types to use while printing.
-void mlir::impl::printFunctionLikeOp(OpAsmPrinter &p, Operation *op,
-                                     ArrayRef<Type> argTypes, bool isVariadic,
-                                     ArrayRef<Type> resultTypes) {
+void mlir::function_like_impl::printFunctionLikeOp(OpAsmPrinter &p,
+                                                   Operation *op,
+                                                   ArrayRef<Type> argTypes,
+                                                   bool isVariadic,
+                                                   ArrayRef<Type> resultTypes) {
   // Print the operation and the function name.
   auto funcName =
       op->getAttrOfType<StringAttr>(SymbolTable::getSymbolAttrName())
index 347ea15..2538271 100644 (file)
@@ -31,103 +31,199 @@ inline void iterateIndicesExcept(unsigned totalIndices,
 // Function Arguments and Results.
 //===----------------------------------------------------------------------===//
 
-void mlir::impl::eraseFunctionArguments(Operation *op,
-                                        ArrayRef<unsigned> argIndices,
-                                        unsigned originalNumArgs,
-                                        Type newType) {
+static bool isEmptyAttrDict(Attribute attr) {
+  return attr.cast<DictionaryAttr>().empty();
+}
+
+DictionaryAttr mlir::function_like_impl::getArgAttrDict(Operation *op,
+                                                        unsigned index) {
+  ArrayAttr attrs = op->getAttrOfType<ArrayAttr>(getArgDictAttrName());
+  DictionaryAttr argAttrs =
+      attrs ? attrs[index].cast<DictionaryAttr>() : DictionaryAttr();
+  return (argAttrs && !argAttrs.empty()) ? argAttrs : DictionaryAttr();
+}
+
+DictionaryAttr mlir::function_like_impl::getResultAttrDict(Operation *op,
+                                                           unsigned index) {
+  ArrayAttr attrs = op->getAttrOfType<ArrayAttr>(getResultDictAttrName());
+  DictionaryAttr resAttrs =
+      attrs ? attrs[index].cast<DictionaryAttr>() : DictionaryAttr();
+  return (resAttrs && !resAttrs.empty()) ? resAttrs : DictionaryAttr();
+}
+
+void mlir::function_like_impl::detail::setArgResAttrDict(
+    Operation *op, StringRef attrName, unsigned numTotalIndices, unsigned index,
+    DictionaryAttr attrs) {
+  ArrayAttr allAttrs = op->getAttrOfType<ArrayAttr>(attrName);
+  if (!allAttrs) {
+    if (attrs.empty())
+      return;
+
+    // If this attribute is not empty, we need to create a new attribute array.
+    SmallVector<Attribute, 8> newAttrs(numTotalIndices,
+                                       DictionaryAttr::get(op->getContext()));
+    newAttrs[index] = attrs;
+    op->setAttr(attrName, ArrayAttr::get(op->getContext(), newAttrs));
+    return;
+  }
+  // Check to see if the attribute is different from what we already have.
+  if (allAttrs[index] == attrs)
+    return;
+
+  // If it is, check to see if the attribute array would now contain only empty
+  // dictionaries.
+  ArrayRef<Attribute> rawAttrArray = allAttrs.getValue();
+  if (attrs.empty() &&
+      llvm::all_of(rawAttrArray.take_front(index), isEmptyAttrDict) &&
+      llvm::all_of(rawAttrArray.drop_front(index + 1), isEmptyAttrDict)) {
+    op->removeAttr(attrName);
+    return;
+  }
+
+  // Otherwise, create a new attribute array with the updated dictionary.
+  SmallVector<Attribute, 8> newAttrs(rawAttrArray.begin(), rawAttrArray.end());
+  newAttrs[index] = attrs;
+  op->setAttr(attrName, ArrayAttr::get(op->getContext(), newAttrs));
+}
+
+/// Set all of the argument or result attribute dictionaries for a function.
+static void setAllArgResAttrDicts(Operation *op, StringRef attrName,
+                                  ArrayRef<Attribute> attrs) {
+  if (llvm::all_of(attrs, isEmptyAttrDict))
+    op->removeAttr(attrName);
+  else
+    op->setAttr(attrName, ArrayAttr::get(op->getContext(), attrs));
+}
+
+void mlir::function_like_impl::setAllArgAttrDicts(
+    Operation *op, ArrayRef<DictionaryAttr> attrs) {
+  setAllArgAttrDicts(op, ArrayRef<Attribute>(attrs.data(), attrs.size()));
+}
+void mlir::function_like_impl::setAllArgAttrDicts(Operation *op,
+                                                  ArrayRef<Attribute> attrs) {
+  auto wrappedAttrs = llvm::map_range(attrs, [op](Attribute attr) -> Attribute {
+    return !attr ? DictionaryAttr::get(op->getContext()) : attr;
+  });
+  setAllArgResAttrDicts(op, getArgDictAttrName(),
+                        llvm::to_vector<8>(wrappedAttrs));
+}
+
+void mlir::function_like_impl::setAllResultAttrDicts(
+    Operation *op, ArrayRef<DictionaryAttr> attrs) {
+  setAllResultAttrDicts(op, ArrayRef<Attribute>(attrs.data(), attrs.size()));
+}
+void mlir::function_like_impl::setAllResultAttrDicts(
+    Operation *op, ArrayRef<Attribute> attrs) {
+  auto wrappedAttrs = llvm::map_range(attrs, [op](Attribute attr) -> Attribute {
+    return !attr ? DictionaryAttr::get(op->getContext()) : attr;
+  });
+  setAllArgResAttrDicts(op, getResultDictAttrName(),
+                        llvm::to_vector<8>(wrappedAttrs));
+}
+
+void mlir::function_like_impl::eraseFunctionArguments(
+    Operation *op, ArrayRef<unsigned> argIndices, unsigned originalNumArgs,
+    Type newType) {
   // There are 3 things that need to be updated:
   // - Function type.
   // - Arg attrs.
   // - Block arguments of entry block.
   Block &entry = op->getRegion(0).front();
-  SmallString<8> nameBuf;
-
-  // Collect arg attrs to set.
-  SmallVector<DictionaryAttr, 4> newArgAttrs;
-  iterateIndicesExcept(originalNumArgs, argIndices, [&](unsigned i) {
-    newArgAttrs.emplace_back(getArgAttrDict(op, i));
-  });
-
-  // Remove any arg attrs that are no longer needed.
-  for (unsigned i = newArgAttrs.size(), e = originalNumArgs; i < e; ++i)
-    op->removeAttr(getArgAttrName(i, nameBuf));
-
-  // Set the function type.
-  op->setAttr(getTypeAttrName(), TypeAttr::get(newType));
 
-  // Set the new arg attrs, or remove them if empty.
-  for (unsigned i = 0, e = newArgAttrs.size(); i != e; ++i) {
-    auto nameAttr = getArgAttrName(i, nameBuf);
-    if (newArgAttrs[i] && !newArgAttrs[i].empty())
-      op->setAttr(nameAttr, newArgAttrs[i]);
-    else
-      op->removeAttr(nameAttr);
+  // Update the argument attributes of the function.
+  if (auto argAttrs = op->getAttrOfType<ArrayAttr>(getArgDictAttrName())) {
+    SmallVector<DictionaryAttr, 4> newArgAttrs;
+    newArgAttrs.reserve(argAttrs.size());
+    iterateIndicesExcept(originalNumArgs, argIndices, [&](unsigned i) {
+      newArgAttrs.emplace_back(argAttrs[i].cast<DictionaryAttr>());
+    });
+    setAllArgAttrDicts(op, newArgAttrs);
   }
 
-  // Update the entry block's arguments.
+  // Update the function type and any entry block arguments.
+  op->setAttr(getTypeAttrName(), TypeAttr::get(newType));
   entry.eraseArguments(argIndices);
 }
 
-void mlir::impl::eraseFunctionResults(Operation *op,
-                                      ArrayRef<unsigned> resultIndices,
-                                      unsigned originalNumResults,
-                                      Type newType) {
+void mlir::function_like_impl::eraseFunctionResults(
+    Operation *op, ArrayRef<unsigned> resultIndices,
+    unsigned originalNumResults, Type newType) {
   // There are 2 things that need to be updated:
   // - Function type.
   // - Result attrs.
-  SmallString<8> nameBuf;
-
-  // Collect result attrs to set.
-  SmallVector<DictionaryAttr, 4> newResultAttrs;
-  iterateIndicesExcept(originalNumResults, resultIndices, [&](unsigned i) {
-    newResultAttrs.emplace_back(getResultAttrDict(op, i));
-  });
 
-  // Remove any result attrs that are no longer needed.
-  for (unsigned i = newResultAttrs.size(), e = originalNumResults; i < e; ++i)
-    op->removeAttr(getResultAttrName(i, nameBuf));
+  // Update the result attributes of the function.
+  if (auto resAttrs = op->getAttrOfType<ArrayAttr>(getResultDictAttrName())) {
+    SmallVector<DictionaryAttr, 4> newResultAttrs;
+    newResultAttrs.reserve(resAttrs.size());
+    iterateIndicesExcept(originalNumResults, resultIndices, [&](unsigned i) {
+      newResultAttrs.emplace_back(resAttrs[i].cast<DictionaryAttr>());
+    });
+    setAllResultAttrDicts(op, newResultAttrs);
+  }
 
-  // Set the function type.
+  // Update the function type.
   op->setAttr(getTypeAttrName(), TypeAttr::get(newType));
-
-  // Set the new result attrs, or remove them if empty.
-  for (unsigned i = 0, e = newResultAttrs.size(); i != e; ++i) {
-    auto nameAttr = getResultAttrName(i, nameBuf);
-    if (newResultAttrs[i] && !newResultAttrs[i].empty())
-      op->setAttr(nameAttr, newResultAttrs[i]);
-    else
-      op->removeAttr(nameAttr);
-  }
 }
 
 //===----------------------------------------------------------------------===//
 // Function type signature.
 //===----------------------------------------------------------------------===//
 
-FunctionType mlir::impl::getFunctionType(Operation *op) {
+FunctionType mlir::function_like_impl::getFunctionType(Operation *op) {
   assert(op->hasTrait<OpTrait::FunctionLike>());
-  return op->getAttrOfType<TypeAttr>(mlir::impl::getTypeAttrName())
+  return op->getAttrOfType<TypeAttr>(getTypeAttrName())
       .getValue()
       .cast<FunctionType>();
 }
 
-void mlir::impl::setFunctionType(Operation *op, FunctionType newType) {
+void mlir::function_like_impl::setFunctionType(Operation *op,
+                                               FunctionType newType) {
   assert(op->hasTrait<OpTrait::FunctionLike>());
-  SmallVector<char, 16> nameBuf;
   FunctionType oldType = getFunctionType(op);
-
-  for (int i = newType.getNumInputs(), e = oldType.getNumInputs(); i < e; i++)
-    op->removeAttr(getArgAttrName(i, nameBuf));
-  for (int i = newType.getNumResults(), e = oldType.getNumResults(); i < e; i++)
-    op->removeAttr(getResultAttrName(i, nameBuf));
   op->setAttr(getTypeAttrName(), TypeAttr::get(newType));
+
+  // Functor used to update the argument and result attributes of the function.
+  auto updateAttrFn = [&](StringRef attrName, unsigned oldCount,
+                          unsigned newCount, auto setAttrFn) {
+    if (oldCount == newCount)
+      return;
+    // The new type has no arguments/results, just drop the attribute.
+    if (newCount == 0) {
+      op->removeAttr(attrName);
+      return;
+    }
+    ArrayAttr attrs = op->getAttrOfType<ArrayAttr>(attrName);
+    if (!attrs)
+      return;
+
+    // The new type has less arguments/results, take the first N attributes.
+    if (newCount < oldCount)
+      return setAttrFn(op, attrs.getValue().take_front(newCount));
+
+    // Otherwise, the new type has more arguments/results. Initialize the new
+    // arguments/results with empty attributes.
+    SmallVector<Attribute> newAttrs(attrs.begin(), attrs.end());
+    newAttrs.resize(newCount);
+    setAttrFn(op, newAttrs);
+  };
+
+  // Update the argument and result attributes.
+  updateAttrFn(function_like_impl::getArgDictAttrName(), oldType.getNumInputs(),
+               newType.getNumInputs(), [&](Operation *op, auto &&attrs) {
+                 setAllArgAttrDicts(op, attrs);
+               });
+  updateAttrFn(
+      function_like_impl::getResultDictAttrName(), oldType.getNumResults(),
+      newType.getNumResults(),
+      [&](Operation *op, auto &&attrs) { setAllResultAttrDicts(op, attrs); });
 }
 
 //===----------------------------------------------------------------------===//
 // Function body.
 //===----------------------------------------------------------------------===//
 
-Region &mlir::impl::getFunctionBody(Operation *op) {
+Region &mlir::function_like_impl::getFunctionBody(Operation *op) {
   assert(op->hasTrait<OpTrait::FunctionLike>());
   return op->getRegion(0);
 }
index c8bb22e..00b006c 100644 (file)
@@ -2628,15 +2628,15 @@ struct FunctionLikeSignatureConversion : public ConversionPattern {
   LogicalResult
   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                   ConversionPatternRewriter &rewriter) const override {
-    FunctionType type = mlir::impl::getFunctionType(op);
+    FunctionType type = function_like_impl::getFunctionType(op);
 
     // Convert the original function types.
     TypeConverter::SignatureConversion result(type.getNumInputs());
     SmallVector<Type, 1> newResults;
     if (failed(typeConverter->convertSignatureArgs(type.getInputs(), result)) ||
         failed(typeConverter->convertTypes(type.getResults(), newResults)) ||
-        failed(rewriter.convertRegionTypes(&mlir::impl::getFunctionBody(op),
-                                           *typeConverter, &result)))
+        failed(rewriter.convertRegionTypes(
+            &function_like_impl::getFunctionBody(op), *typeConverter, &result)))
       return failure();
 
     // Update the function signature in-place.
@@ -2644,7 +2644,7 @@ struct FunctionLikeSignatureConversion : public ConversionPattern {
                                      result.getConvertedTypes(), newResults);
 
     rewriter.updateRootInPlace(
-        op, [&] { mlir::impl::setFunctionType(op, newType); });
+        op, [&] { function_like_impl::setFunctionType(op, newType); });
 
     return success();
   }
index ab32af2..e52acf6 100644 (file)
@@ -35,7 +35,7 @@ module {
   // CHECK: attributes {xxx = {yyy = 42 : i64}}
   "llvm.func"() ({
   }) {sym_name = "qux", type = !llvm.func<void (ptr<i64>, i64)>,
-      arg0 = {llvm.noalias = true}, xxx = {yyy = 42}} : () -> ()
+      arg_attrs = [{llvm.noalias = true}, {}], xxx = {yyy = 42}} : () -> ()
 
   // CHECK: llvm.func @roundtrip1()
   llvm.func @roundtrip1()
index c2ceefe..6b15a07 100644 (file)
@@ -94,3 +94,22 @@ func private @invalid_symbol_name_attr() attributes { sym_name = "x" }
 // expected-error@+1 {{'type' is an inferred attribute and should not be specified in the explicit attribute dictionary}}
 func private @invalid_symbol_type_attr() attributes { type = "x" }
 
+// -----
+
+// expected-error@+1 {{argument attribute array `arg_attrs` to have the same number of elements as the number of function arguments}}
+func private @invalid_arg_attrs() attributes { arg_attrs = [{}] }
+
+// -----
+
+// expected-error@+1 {{expects argument attribute dictionary to be a DictionaryAttr, but got `10 : i64`}}
+func private @invalid_arg_attrs(i32) attributes { arg_attrs = [10] }
+
+// -----
+
+// expected-error@+1 {{result attribute array `res_attrs` to have the same number of elements as the number of function results}}
+func private @invalid_res_attrs() attributes { res_attrs = [{}] }
+
+// -----
+
+// expected-error@+1 {{expects result attribute dictionary to be a DictionaryAttr, but got `10 : i64`}}
+func private @invalid_res_attrs() -> i32 attributes { res_attrs = [10] }
index 05a1393..42f56ae 100644 (file)
@@ -9,7 +9,6 @@
 // Test case: The setType call needs to erase some arg attrs.
 
 // CHECK: func private @erase_arg(f32 {test.A})
-// CHECK-NOT: attributes{{.*arg[0-9]}}
 func private @t(f32)
 func private @erase_arg(%arg0: f32 {test.A}, %arg1: f32 {test.B})
 attributes {test.set_type_from = @t}
@@ -19,7 +18,6 @@ attributes {test.set_type_from = @t}
 // Test case: The setType call needs to erase some result attrs.
 
 // CHECK: func private @erase_result() -> (f32 {test.A})
-// CHECK-NOT: attributes{{.*result[0-9]}}
 func private @t() -> (f32)
 func private @erase_result() -> (f32 {test.A}, f32 {test.B})
 attributes {test.set_type_from = @t}