Add support for function result attributes.
authorSean Silva <silvasean@google.com>
Fri, 18 Oct 2019 23:02:56 +0000 (16:02 -0700)
committerA. Unique TensorFlower <gardener@tensorflow.org>
Fri, 18 Oct 2019 23:03:28 +0000 (16:03 -0700)
This allows dialect-specific attributes to be attached to func results. (or more specifically, FunctionLike ops).

For example:

```
func @f() -> (i32 {my_dialect.some_attr = 3})
```

This attaches my_dialect.some_attr with value 3 to the first result of func @f.

Another more complex example:

```
func @g() -> (i32, f32 {my_dialect.some_attr = "foo", other_dialect.some_other_attr = [1,2,3]}, i1)
```

Here, the second result has two attributes attached.

PiperOrigin-RevId: 275564165

13 files changed:
mlir/g3doc/LangRef.md
mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
mlir/include/mlir/IR/Dialect.h
mlir/include/mlir/IR/Function.h
mlir/include/mlir/IR/FunctionSupport.h
mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
mlir/lib/IR/Dialect.cpp
mlir/lib/IR/FunctionSupport.cpp
mlir/test/IR/invalid-func-op.mlir
mlir/test/IR/invalid.mlir
mlir/test/IR/parser.mlir
mlir/test/lib/TestDialect/TestDialect.cpp
mlir/test/lib/TestDialect/TestDialect.h

index 493ade1..391f773 100644 (file)
@@ -321,11 +321,19 @@ name via a string attribute like [SymbolRefAttr](#symbol-reference-attribute)):
 function ::= `func` function-signature function-attributes? function-body?
 
 function-signature ::= symbol-ref-id `(` argument-list `)`
-                       (`->` function-result-type)?
+                       (`->` function-result-list)?
+
 argument-list ::= (named-argument (`,` named-argument)*) | /*empty*/
 argument-list ::= (type attribute-dict? (`,` type attribute-dict?)*) | /*empty*/
 named-argument ::= ssa-id `:` type attribute-dict?
 
+function-result-list ::= function-result-list-parens
+                       | non-function-type
+function-result-list-parens ::= `(` `)`
+                              | `(` function-result-list-no-parens `)`
+function-result-list-no-parens ::= function-result (`,` function-result)*
+function-result ::= type attribute-dict?
+
 function-attributes ::= `attributes` attribute-dict
 function-body ::= region
 ```
index a8004a4..35fe515 100644 (file)
@@ -546,6 +546,10 @@ def LLVM_LLVMFuncOp : LLVM_ZeroResultOp<"func",
     // Depends on the type attribute being correct as checked by verifyType.
     unsigned getNumFuncArguments();
 
+    // Hook for OpTrait::FunctionLike, returns the number of function results.
+    // Depends on the type attribute being correct as checked by verifyType.
+    unsigned getNumFuncResults();
+
     // Hook for OpTrait::FunctionLike, called after verifying that the 'type'
     // attribute is present.  This can check for preconditions of the
     // getNumArguments hook not failing.
index 810d11c..bf7db91 100644 (file)
@@ -146,6 +146,15 @@ public:
                                                  unsigned argIndex,
                                                  NamedAttribute);
 
+  /// Verify an attribute from this dialect on the result at 'resultIndex' for
+  /// the region at 'regionIndex' on the given operation. Returns failure if
+  /// the verification failed, success otherwise. This hook may optionally be
+  /// invoked from any operation containing a region.
+  virtual LogicalResult verifyRegionResultAttribute(Operation *,
+                                                    unsigned regionIndex,
+                                                    unsigned resultIndex,
+                                                    NamedAttribute);
+
   /// Verify an attribute from this dialect on the given operation. Returns
   /// failure if the verification failed, success otherwise.
   virtual LogicalResult verifyOperationAttribute(Operation *, NamedAttribute) {
index 95920b3..a7777c6 100644 (file)
@@ -136,16 +136,18 @@ public:
   }
 
 private:
-  // This trait needs access to `getNumFuncArguments` and `verifyType` hooks
-  // defined below.
+  // This trait needs access to the hooks defined below.
   friend class OpTrait::FunctionLike<FuncOp>;
 
   /// Returns the number of arguments. This is a hook for OpTrait::FunctionLike.
   unsigned getNumFuncArguments() { return getType().getInputs().size(); }
 
+  /// Returns the number of results. This is a hook for OpTrait::FunctionLike.
+  unsigned getNumFuncResults() { return getType().getResults().size(); }
+
   /// Hook for OpTrait::FunctionLike, called after verifying that the 'type'
   /// attribute is present and checks if it holds a function type.  Ensures
-  /// getType and getNumFuncArguments can be called safely.
+  /// getType, getNumFuncArguments, and getNumFuncResults can be called safely.
   LogicalResult verifyType() {
     auto type = getTypeAttr().getValue();
     if (!type.isa<FunctionType>())
index 7fa27ff..ccac4d3 100644 (file)
@@ -39,6 +39,12 @@ inline StringRef getArgAttrName(unsigned arg, SmallVectorImpl<char> &out) {
   return ("arg" + Twine(arg)).toStringRef(out);
 }
 
+/// Return the name of the attribute used for function results.
+inline StringRef getResultAttrName(unsigned arg, SmallVectorImpl<char> &out) {
+  out.clear();
+  return ("result" + Twine(arg)).toStringRef(out);
+}
+
 /// Returns the dictionary attribute corresponding to the argument at 'index'.
 /// If there are no argument attributes at 'index', a null attribute is
 /// returned.
@@ -47,12 +53,26 @@ inline DictionaryAttr getArgAttrDict(Operation *op, unsigned index) {
   return op->getAttrOfType<DictionaryAttr>(getArgAttrName(index, nameOut));
 }
 
+/// Returns the dictionary attribute corresponding to the result at 'index'.
+/// If there are no result attributes at 'index', a null attribute is
+/// returned.
+inline DictionaryAttr getResultAttrDict(Operation *op, unsigned index) {
+  SmallString<8> nameOut;
+  return op->getAttrOfType<DictionaryAttr>(getResultAttrName(index, nameOut));
+}
+
 /// Return all of the attributes for the argument at 'index'.
 inline ArrayRef<NamedAttribute> getArgAttrs(Operation *op, unsigned index) {
   auto argDict = getArgAttrDict(op, index);
   return argDict ? argDict.getValue() : llvm::None;
 }
 
+/// Return all of the attributes for the result at 'index'.
+inline ArrayRef<NamedAttribute> getResultAttrs(Operation *op, unsigned index) {
+  auto resultDict = getResultAttrDict(op, index);
+  return resultDict ? resultDict.getValue() : llvm::None;
+}
+
 /// A named class for passing around the variadic flag.
 class VariadicFlag {
 public:
@@ -87,7 +107,7 @@ ParseResult parseFunctionLikeOp(OpAsmParser &parser, OperationState &result,
 /// argument and result types to use while printing.
 void printFunctionLikeOp(OpAsmPrinter &p, Operation *op,
                          ArrayRef<Type> argTypes, bool isVariadic,
-                         ArrayRef<Type> results);
+                         ArrayRef<Type> resultTypes);
 
 } // namespace impl
 
@@ -111,10 +131,13 @@ namespace OpTrait {
 /// - Concrete ops *must* define a member function `getNumFuncArguments()` that
 /// returns the number of function arguments based exclusively on type (so that
 /// it can be called on function declarations).
+/// - Concrete ops *must* define a member function `getNumFuncResults()` that
+/// returns the number of function results based exclusively on type (so that
+/// it can be called on function declarations).
 /// - To verify that the type respects op-specific invariants, concrete ops may
 /// redefine the `verifyType()` hook that will be called after verifying the
 /// presence of the `type` attribute and before any call to
-/// `getNumFuncArguments` from the verifier.
+/// `getNumFuncArguments`/`getNumFuncResults` from the verifier.
 template <typename ConcreteType>
 class FunctionLike : public OpTrait::TraitBase<ConcreteType, FunctionLike> {
 public:
@@ -202,6 +225,10 @@ public:
     return static_cast<ConcreteType *>(this)->getNumFuncArguments();
   }
 
+  unsigned getNumResults() {
+    return static_cast<ConcreteType *>(this)->getNumFuncResults();
+  }
+
   /// Gets argument.
   BlockArgument *getArgument(unsigned idx) {
     return getBlocks().front().getArgument(idx);
@@ -278,11 +305,75 @@ public:
   NamedAttributeList::RemoveResult removeArgAttr(unsigned index,
                                                  Identifier name);
 
+  //===--------------------------------------------------------------------===//
+  // Result Attributes
+  //===--------------------------------------------------------------------===//
+
+  /// FunctionLike operations allow for attaching attributes to each of the
+  /// respective function results. These result attributes are stored as
+  /// DictionaryAttrs in the main operation attribute dictionary. The name of
+  /// these entries is `result` followed by the index of the result. These
+  /// result attribute dictionaries are optional, and will generally only
+  /// exist if they are non-empty.
+
+  /// Return all of the attributes for the result at 'index'.
+  ArrayRef<NamedAttribute> getResultAttrs(unsigned index) {
+    return ::mlir::impl::getResultAttrs(this->getOperation(), index);
+  }
+
+  /// Return all result attributes of this function.
+  void getAllResultAttrs(SmallVectorImpl<NamedAttributeList> &result) {
+    for (unsigned i = 0, e = getNumResults(); i != e; ++i)
+      result.emplace_back(getResultAttrDict(i));
+  }
+
+  /// Return the specified attribute, if present, for the result at 'index',
+  /// null otherwise.
+  Attribute getResultAttr(unsigned index, Identifier name) {
+    auto argDict = getResultAttrDict(index);
+    return argDict ? argDict.get(name) : nullptr;
+  }
+  Attribute getResultAttr(unsigned index, StringRef name) {
+    auto argDict = getResultAttrDict(index);
+    return argDict ? argDict.get(name) : nullptr;
+  }
+
+  template <typename AttrClass>
+  AttrClass getResultAttrOfType(unsigned index, Identifier name) {
+    return getResultAttr(index, name).template dyn_cast_or_null<AttrClass>();
+  }
+  template <typename AttrClass>
+  AttrClass getResultAttrOfType(unsigned index, StringRef name) {
+    return getResultAttr(index, name).template dyn_cast_or_null<AttrClass>();
+  }
+
+  /// Set the attributes held by the result at 'index'.
+  void setResultAttrs(unsigned index, ArrayRef<NamedAttribute> attributes);
+  void setResultAttrs(unsigned index, NamedAttributeList attributes);
+  void setAllResultAttrs(ArrayRef<NamedAttributeList> attributes) {
+    assert(attributes.size() == getNumResults());
+    for (unsigned i = 0, e = attributes.size(); i != e; ++i)
+      setResultAttrs(i, attributes[i]);
+  }
+
+  /// If the an attribute exists with the specified name, change it to the new
+  /// value. Otherwise, add a new attribute with the specified name/value.
+  void setResultAttr(unsigned index, Identifier name, Attribute value);
+  void setResultAttr(unsigned index, StringRef name, Attribute value) {
+    setResultAttr(index,
+                  Identifier::get(name, this->getOperation()->getContext()),
+                  value);
+  }
+
+  /// Remove the attribute 'name' from the result at 'index'.
+  NamedAttributeList::RemoveResult removeResultAttr(unsigned index,
+                                                    Identifier name);
+
 protected:
   /// Returns the attribute entry name for the set of argument attributes at
-  /// index 'arg'.
-  static StringRef getArgAttrName(unsigned arg, SmallVectorImpl<char> &out) {
-    return ::mlir::impl::getArgAttrName(arg, out);
+  /// 'index'.
+  static StringRef getArgAttrName(unsigned index, SmallVectorImpl<char> &out) {
+    return ::mlir::impl::getArgAttrName(index, out);
   }
 
   /// Returns the dictionary attribute corresponding to the argument at 'index'.
@@ -293,6 +384,21 @@ protected:
     return ::mlir::impl::getArgAttrDict(this->getOperation(), index);
   }
 
+  /// Returns the attribute entry name for the set of result attributes at
+  /// 'index'.
+  static StringRef getResultAttrName(unsigned index,
+                                     SmallVectorImpl<char> &out) {
+    return ::mlir::impl::getResultAttrName(index, out);
+  }
+
+  /// Returns the dictionary attribute corresponding to the result at 'index'.
+  /// If there are no result attributes at 'index', a null attribute is
+  /// returned.
+  DictionaryAttr getResultAttrDict(unsigned index) {
+    assert(index < getNumResults() && "invalid result number");
+    return ::mlir::impl::getResultAttrDict(this->getOperation(), index);
+  }
+
   /// Hook for concrete classes to verify that the type attribute respects
   /// op-specific invariants.  Default implementation always succeeds.
   LogicalResult verifyType() { return success(); }
@@ -326,6 +432,23 @@ LogicalResult FunctionLike<ConcreteType>::verifyTrait(Operation *op) {
     }
   }
 
+  for (unsigned i = 0, e = funcOp.getNumResults(); i != e; ++i) {
+    // Verify that all of the result attributes are dialect attributes, i.e.
+    // that they contain a dialect prefix in their name.  Call the dialect, if
+    // registered, to verify the attributes themselves.
+    for (auto attr : funcOp.getResultAttrs(i)) {
+      if (!attr.first.strref().contains('.'))
+        return funcOp.emitOpError("results may only have dialect attributes");
+      auto dialectNamePair = attr.first.strref().split('.');
+      if (auto *dialect = ctx->getRegisteredDialect(dialectNamePair.first)) {
+        if (failed(dialect->verifyRegionResultAttribute(op, /*regionIndex=*/0,
+                                                        /*resultIndex=*/i,
+                                                        attr)))
+          return failure();
+      }
+    }
+  }
+
   // Check that the op has exactly one region for the body.
   if (op->getNumRegions() != 1)
     return funcOp.emitOpError("expects one region");
@@ -354,10 +477,10 @@ void FunctionLike<ConcreteType>::setArgAttrs(
   assert(index < getNumArguments() && "invalid argument number");
   SmallString<8> nameOut;
   getArgAttrName(index, nameOut);
-  Operation *op = this->getOperation();
 
   if (attributes.empty())
     return (void)static_cast<ConcreteType *>(this)->removeAttr(nameOut);
+  Operation *op = this->getOperation();
   op->setAttr(nameOut, DictionaryAttr::get(attributes, op->getContext()));
 }
 
@@ -400,6 +523,64 @@ FunctionLike<ConcreteType>::removeArgAttr(unsigned index, Identifier name) {
   return result;
 }
 
+//===----------------------------------------------------------------------===//
+// Function Result Attribute.
+//===----------------------------------------------------------------------===//
+
+/// Set the attributes held by the result at 'index'.
+template <typename ConcreteType>
+void FunctionLike<ConcreteType>::setResultAttrs(
+    unsigned index, ArrayRef<NamedAttribute> attributes) {
+  assert(index < getNumResults() && "invalid result number");
+  SmallString<8> nameOut;
+  getResultAttrName(index, nameOut);
+
+  if (attributes.empty())
+    return (void)static_cast<ConcreteType *>(this)->removeAttr(nameOut);
+  Operation *op = this->getOperation();
+  op->setAttr(nameOut, DictionaryAttr::get(attributes, op->getContext()));
+}
+
+template <typename ConcreteType>
+void FunctionLike<ConcreteType>::setResultAttrs(unsigned index,
+                                                NamedAttributeList attributes) {
+  assert(index < getNumResults() && "invalid result number");
+  SmallString<8> nameOut;
+  if (auto newAttr = attributes.getDictionary())
+    return this->getOperation()->setAttr(getResultAttrName(index, nameOut),
+                                         newAttr);
+  static_cast<ConcreteType *>(this)->removeAttr(
+      getResultAttrName(index, nameOut));
+}
+
+/// If the an attribute exists with the specified name, change it to the new
+/// value. Otherwise, add a new attribute with the specified name/value.
+template <typename ConcreteType>
+void FunctionLike<ConcreteType>::setResultAttr(unsigned index, Identifier name,
+                                               Attribute value) {
+  auto curAttr = getResultAttrDict(index);
+  NamedAttributeList attrList(curAttr);
+  attrList.set(name, value);
+
+  // If the attribute changed, then set the new arg attribute list.
+  if (curAttr != attrList.getDictionary())
+    setResultAttrs(index, attrList);
+}
+
+/// Remove the attribute 'name' from the result at 'index'.
+template <typename ConcreteType>
+NamedAttributeList::RemoveResult
+FunctionLike<ConcreteType>::removeResultAttr(unsigned index, Identifier name) {
+  // Build an attribute list and remove the attribute at 'name'.
+  NamedAttributeList attrList(getResultAttrDict(index));
+  auto result = attrList.remove(name);
+
+  // If the attribute was removed, then update the result dictionary.
+  if (result == NamedAttributeList::RemoveResult::Removed)
+    setResultAttrs(index, attrList);
+  return result;
+}
+
 } // end namespace OpTrait
 
 } // end namespace mlir
index 23e3889..618ee23 100644 (file)
@@ -1115,6 +1115,21 @@ unsigned LLVMFuncOp::getNumFuncArguments() {
   return getType().getUnderlyingType()->getFunctionNumParams();
 }
 
+// Hook for OpTrait::FunctionLike, returns the number of function results.
+// Depends on the type attribute being correct as checked by verifyType
+unsigned LLVMFuncOp::getNumFuncResults() {
+  llvm::FunctionType *funcType =
+      cast<llvm::FunctionType>(getType().getUnderlyingType());
+  // We model LLVM functions that return void as having zero results,
+  // and all others as having one result.
+  // If we modeled a void return as one result, then it would be possible to
+  // attach an MLIR result attribute to it, and it isn't clear what semantics we
+  // would assign to that.
+  if (funcType->getReturnType()->isVoidTy())
+    return 0;
+  return 1;
+}
+
 static LogicalResult verify(LLVMFuncOp op) {
   if (op.isExternal())
     return success();
index 6a7dcae..f8539c0 100644 (file)
@@ -89,6 +89,15 @@ LogicalResult Dialect::verifyRegionArgAttribute(Operation *, unsigned, unsigned,
   return success();
 }
 
+/// Verify an attribute from this dialect on the result at 'resultIndex' for
+/// the region at 'regionIndex' on the given operation. Returns failure if
+/// the verification failed, success otherwise. This hook may optionally be
+/// invoked from any operation containing a region.
+LogicalResult Dialect::verifyRegionResultAttribute(Operation *, unsigned,
+                                                   unsigned, NamedAttribute) {
+  return success();
+}
+
 /// Parse an attribute registered to this dialect.
 Attribute Dialect::parseAttribute(StringRef attrData, Type type,
                                   Location loc) const {
index 468301e..22f207e 100644 (file)
@@ -88,6 +88,45 @@ parseArgumentList(OpAsmParser &parser, bool allowVariadic,
   return success();
 }
 
+/// Parse a function result list.
+///
+///   function-result-list ::= function-result-list-parens
+///                          | non-function-type
+///   function-result-list-parens ::= `(` `)`
+///                                 | `(` function-result-list-no-parens `)`
+///   function-result-list-no-parens ::= function-result (`,` function-result)*
+///   function-result ::= type attribute-dict?
+///
+static ParseResult parseFunctionResultList(
+    OpAsmParser &parser, SmallVectorImpl<Type> &resultTypes,
+    SmallVectorImpl<SmallVector<NamedAttribute, 2>> &resultAttrs) {
+  if (failed(parser.parseOptionalLParen())) {
+    // We already know that there is no `(`, so parse a type.
+    // Because there is no `(`, it cannot be a function type.
+    Type ty;
+    if (parser.parseType(ty))
+      return failure();
+    resultTypes.push_back(ty);
+    resultAttrs.emplace_back();
+    return success();
+  }
+
+  // Special case for an empty set of parens.
+  if (succeeded(parser.parseOptionalRParen()))
+    return success();
+
+  // Parse individual function results.
+  do {
+    resultTypes.emplace_back();
+    resultAttrs.emplace_back();
+    if (parser.parseType(resultTypes.back()) ||
+        parser.parseOptionalAttributeDict(resultAttrs.back())) {
+      return failure();
+    }
+  } while (succeeded(parser.parseOptionalComma()));
+  return parser.parseRParen();
+}
+
 /// Parse a function signature, starting with a name and including the
 /// parameter list.
 static ParseResult parseFunctionSignature(
@@ -95,12 +134,14 @@ static ParseResult parseFunctionSignature(
     SmallVectorImpl<OpAsmParser::OperandType> &argNames,
     SmallVectorImpl<Type> &argTypes,
     SmallVectorImpl<SmallVector<NamedAttribute, 2>> &argAttrs, bool &isVariadic,
-    SmallVectorImpl<Type> &results) {
+    SmallVectorImpl<Type> &resultTypes,
+    SmallVectorImpl<SmallVector<NamedAttribute, 2>> &resultAttrs) {
   if (parseArgumentList(parser, allowVariadic, argTypes, argNames, argAttrs,
                         isVariadic))
     return failure();
-  // Parse the return types if present.
-  return parser.parseOptionalArrowTypeList(results);
+  if (succeeded(parser.parseOptionalArrow()))
+    return parseFunctionResultList(parser, resultTypes, resultAttrs);
+  return success();
 }
 
 /// Parser implementation for function-like operations.  Uses `funcTypeBuilder`
@@ -111,8 +152,9 @@ mlir::impl::parseFunctionLikeOp(OpAsmParser &parser, OperationState &result,
                                 mlir::impl::FuncTypeBuilder funcTypeBuilder) {
   SmallVector<OpAsmParser::OperandType, 4> entryArgs;
   SmallVector<SmallVector<NamedAttribute, 2>, 4> argAttrs;
+  SmallVector<SmallVector<NamedAttribute, 2>, 4> resultAttrs;
   SmallVector<Type, 4> argTypes;
-  SmallVector<Type, 4> results;
+  SmallVector<Type, 4> resultTypes;
   auto &builder = parser.getBuilder();
 
   // Parse the name as a symbol reference attribute.
@@ -127,11 +169,11 @@ mlir::impl::parseFunctionLikeOp(OpAsmParser &parser, OperationState &result,
   auto signatureLocation = parser.getCurrentLocation();
   bool isVariadic = false;
   if (parseFunctionSignature(parser, allowVariadic, entryArgs, argTypes,
-                             argAttrs, isVariadic, results))
+                             argAttrs, isVariadic, resultTypes, resultAttrs))
     return failure();
 
   std::string errorMessage;
-  if (auto type = funcTypeBuilder(builder, argTypes, results,
+  if (auto type = funcTypeBuilder(builder, argTypes, resultTypes,
                                   impl::VariadicFlag(isVariadic), errorMessage))
     result.addAttribute(getTypeAttrName(), TypeAttr::get(type));
   else
@@ -145,12 +187,18 @@ mlir::impl::parseFunctionLikeOp(OpAsmParser &parser, OperationState &result,
       return failure();
 
   // Add the attributes to the function arguments.
-  SmallString<8> argAttrName;
+  SmallString<8> attrNameBuf;
   for (unsigned i = 0, e = argTypes.size(); i != e; ++i)
     if (!argAttrs[i].empty())
-      result.addAttribute(getArgAttrName(i, argAttrName),
+      result.addAttribute(getArgAttrName(i, attrNameBuf),
                           builder.getDictionaryAttr(argAttrs[i]));
 
+  // Add the attributes to the function results.
+  for (unsigned i = 0, e = resultTypes.size(); i != e; ++i)
+    if (!resultAttrs[i].empty())
+      result.addAttribute(getResultAttrName(i, attrNameBuf),
+                          builder.getDictionaryAttr(resultAttrs[i]));
+
   // Parse the optional function body.
   auto *body = result.addRegion();
   if (parser.parseOptionalRegion(*body, entryArgs,
@@ -161,11 +209,29 @@ mlir::impl::parseFunctionLikeOp(OpAsmParser &parser, OperationState &result,
   return success();
 }
 
+// Print a function result list.
+static void printFunctionResultList(OpAsmPrinter &p, ArrayRef<Type> types,
+                                    ArrayRef<ArrayRef<NamedAttribute>> attrs) {
+  assert(!types.empty() && "Should not be called for empty result list.");
+  auto &os = p.getStream();
+  bool needsParens =
+      types.size() > 1 || types[0].isa<FunctionType>() || !attrs[0].empty();
+  if (needsParens)
+    os << '(';
+  interleaveComma(llvm::zip(types, attrs), os,
+                  [&](const std::tuple<Type, ArrayRef<NamedAttribute>> &t) {
+                    p.printType(std::get<0>(t));
+                    p.printOptionalAttrDict(std::get<1>(t));
+                  });
+  if (needsParens)
+    os << ')';
+}
+
 /// Print the signature of the function-like operation `op`.  Assumes `op` has
 /// the FunctionLike trait and passed the verification.
 static void printSignature(OpAsmPrinter &p, Operation *op,
                            ArrayRef<Type> argTypes, bool isVariadic,
-                           ArrayRef<Type> results) {
+                           ArrayRef<Type> resultTypes) {
   Region &body = op->getRegion(0);
   bool isExternal = body.empty();
 
@@ -190,14 +256,21 @@ static void printSignature(OpAsmPrinter &p, Operation *op,
   }
 
   p << ')';
-  p.printOptionalArrowTypeList(results);
+
+  if (!resultTypes.empty()) {
+    p.getStream() << " -> ";
+    SmallVector<ArrayRef<NamedAttribute>, 4> resultAttrs;
+    for (int i = 0, e = resultTypes.size(); i < e; ++i)
+      resultAttrs.push_back(::mlir::impl::getResultAttrs(op, i));
+    printFunctionResultList(p, resultTypes, resultAttrs);
+  }
 }
 
 /// Printer implementation for function-like operations.  Accepts lists of
 /// argument and result types to use while printing.
 void mlir::impl::printFunctionLikeOp(OpAsmPrinter &p, Operation *op,
                                      ArrayRef<Type> argTypes, bool isVariadic,
-                                     ArrayRef<Type> results) {
+                                     ArrayRef<Type> resultTypes) {
   // Print the operation and the function name.
   auto funcName =
       op->getAttrOfType<StringAttr>(::mlir::SymbolTable::getSymbolAttrName())
@@ -206,20 +279,28 @@ void mlir::impl::printFunctionLikeOp(OpAsmPrinter &p, Operation *op,
   p.printSymbolName(funcName);
 
   // Print the signature.
-  printSignature(p, op, argTypes, isVariadic, results);
+  printSignature(p, op, argTypes, isVariadic, resultTypes);
 
   // Print out function attributes, if present.
   SmallVector<StringRef, 2> ignoredAttrs = {
       ::mlir::SymbolTable::getSymbolAttrName(), getTypeAttrName()};
 
+  SmallString<8> attrNameBuf;
+
   // Ignore any argument attributes.
   std::vector<SmallString<8>> argAttrStorage;
-  SmallString<8> argAttrName;
   for (unsigned i = 0, e = argTypes.size(); i != e; ++i)
-    if (op->getAttr(getArgAttrName(i, argAttrName)))
-      argAttrStorage.emplace_back(argAttrName);
+    if (op->getAttr(getArgAttrName(i, attrNameBuf)))
+      argAttrStorage.emplace_back(attrNameBuf);
   ignoredAttrs.append(argAttrStorage.begin(), argAttrStorage.end());
 
+  // Ignore any result attributes.
+  std::vector<SmallString<8>> resultAttrStorage;
+  for (unsigned i = 0, e = resultTypes.size(); i != e; ++i)
+    if (op->getAttr(getResultAttrName(i, attrNameBuf)))
+      resultAttrStorage.emplace_back(attrNameBuf);
+  ignoredAttrs.append(resultAttrStorage.begin(), resultAttrStorage.end());
+
   auto attrs = op->getAttrs();
   if (attrs.size() > ignoredAttrs.size()) {
     p << "\n  attributes ";
index 16d734d..1554736 100644 (file)
@@ -49,3 +49,27 @@ func @func_op() {
   }
   return
 }
+
+// -----
+
+// expected-error@+1 {{expected non-function type}}
+func @f() -> (foo
+
+// -----
+
+// expected-error@+1 {{expected attribute name}}
+func @f() -> (i1 {)
+
+// -----
+
+// expected-error@+1 {{invalid to use 'test.invalid_attr'}}
+func @f(%arg0: i64 {test.invalid_attr}) {
+  return
+}
+
+// -----
+
+// expected-error@+1 {{invalid to use 'test.invalid_attr'}}
+func @f(%arg0: i64) -> (i64 {test.invalid_attr}) {
+  return %arg0 : i64
+}
index ea368d7..2f8dcc9 100644 (file)
@@ -924,6 +924,11 @@ func @invalid_func_arg_attr(i1 {non_dialect_attr = 10})
 
 // -----
 
+// expected-error @+1 {{results may only have dialect attributes}}
+func @invalid_func_result_attr() -> (i1 {non_dialect_attr = 10})
+
+// -----
+
 // expected-error @+1 {{expected '<' in tuple type}}
 func @invalid_tuple_missing_less(tuple i32>)
 
index dc300c7..13b4856 100644 (file)
@@ -847,6 +847,11 @@ func @func_arg_attrs(%arg0: i1 {dialect.attr = 10 : i64}) {
   return
 }
 
+// CHECK-LABEL: func @func_result_attrs({{.*}}) -> (f32 {dialect.attr = 1 : i64})
+func @func_result_attrs(%arg0: f32) -> (f32 {dialect.attr = 1}) {
+  return %arg0 : f32
+}
+
 // CHECK-LABEL: func @empty_tuple(tuple<>)
 func @empty_tuple(tuple<>)
 
index ee8325f..2e3d97b 100644 (file)
@@ -114,6 +114,24 @@ TestDialect::TestDialect(MLIRContext *context)
   allowUnknownOperations();
 }
 
+LogicalResult TestDialect::verifyRegionArgAttribute(Operation *op,
+                                                    unsigned regionIndex,
+                                                    unsigned argIndex,
+                                                    NamedAttribute namedAttr) {
+  if (namedAttr.first == "test.invalid_attr")
+    return op->emitError() << "invalid to use 'test.invalid_attr'";
+  return success();
+}
+
+LogicalResult
+TestDialect::verifyRegionResultAttribute(Operation *op, unsigned regionIndex,
+                                         unsigned resultIndex,
+                                         NamedAttribute namedAttr) {
+  if (namedAttr.first == "test.invalid_attr")
+    return op->emitError() << "invalid to use 'test.invalid_attr'";
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // Test IsolatedRegionOp - parse passthrough region arguments.
 //===----------------------------------------------------------------------===//
index ffe2a1c..ade0eb8 100644 (file)
@@ -40,6 +40,14 @@ public:
 
   /// Get the canonical string name of the dialect.
   static StringRef getDialectName() { return "test"; }
+
+  LogicalResult verifyRegionArgAttribute(Operation *, unsigned regionIndex,
+                                         unsigned argIndex,
+                                         NamedAttribute) override;
+
+  LogicalResult verifyRegionResultAttribute(Operation *, unsigned regionIndex,
+                                            unsigned resultIndex,
+                                            NamedAttribute) override;
 };
 
 #define GET_OP_CLASSES