FunctionSupport: wrap around bool to have a more semantic callback type
authorAlex Zinenko <zinenko@google.com>
Thu, 8 Aug 2019 19:11:27 +0000 (12:11 -0700)
committerA. Unique TensorFlower <gardener@tensorflow.org>
Thu, 8 Aug 2019 19:11:54 +0000 (12:11 -0700)
This changes the type of the function type-building callback from
(ArrayRef<Type>, ArrayRef<Type>, bool, string &) to (ArrayRef<Type>,
ArrayRef<Type>, VariadicFlag, String &) to make the intended use clear from the
callback signature alone.

Also rearrange type definitions in Parser.cpp to make them more sorted
alphabetically.

PiperOrigin-RevId: 262405851

mlir/include/mlir/IR/FunctionSupport.h
mlir/lib/IR/Function.cpp
mlir/lib/IR/FunctionSupport.cpp
mlir/lib/LLVMIR/IR/LLVMDialect.cpp
mlir/lib/Parser/Parser.cpp

index ec12001..75a0a67 100644 (file)
@@ -53,13 +53,24 @@ inline ArrayRef<NamedAttribute> getArgAttrs(Operation *op, unsigned index) {
   return argDict ? argDict.getValue() : llvm::None;
 }
 
+/// A named class for passing around the variadic flag.
+class VariadicFlag {
+public:
+  explicit VariadicFlag(bool variadic) : variadic(variadic) {}
+  bool isVariadic() const { return variadic; }
+
+private:
+  /// Underlying storage.
+  bool variadic;
+};
+
 /// Callback type for `parseFunctionLikeOp`, the callback should produce the
 /// type that will be associated with a function-like operation from lists of
-/// function arguments and results, the boolean operand is true if the function
+/// function arguments and results, VariadicFlag indicates whether the function
 /// should have variadic arguments; in case of error, it may populate the last
 /// argument with a message.
 using FuncTypeBuilder = llvm::function_ref<Type(
-    Builder &, ArrayRef<Type>, ArrayRef<Type>, bool, std::string &)>;
+    Builder &, ArrayRef<Type>, ArrayRef<Type>, VariadicFlag, std::string &)>;
 
 /// Parser implementation for function-like operations.  Uses
 /// `funcTypeBuilder` to construct the custom function type given lists of
index e4d1960..fb54f85 100644 (file)
@@ -77,7 +77,8 @@ void FuncOp::build(Builder *builder, OperationState *result, StringRef name,
 
 ParseResult FuncOp::parse(OpAsmParser *parser, OperationState *result) {
   auto buildFuncType = [](Builder &builder, ArrayRef<Type> argTypes,
-                          ArrayRef<Type> results, bool, std::string &) {
+                          ArrayRef<Type> results, impl::VariadicFlag,
+                          std::string &) {
     return builder.getFunctionType(argTypes, results);
   };
 
index 7416e64..064e438 100644 (file)
@@ -132,8 +132,8 @@ mlir::impl::parseFunctionLikeOp(OpAsmParser *parser, OperationState *result,
     return failure();
 
   std::string errorMessage;
-  if (auto type =
-          funcTypeBuilder(builder, argTypes, results, isVariadic, errorMessage))
+  if (auto type = funcTypeBuilder(builder, argTypes, results,
+                                  impl::VariadicFlag(isVariadic), errorMessage))
     result->addAttribute(getTypeAttrName(), builder.getTypeAttr(type));
   else
     return parser->emitError(signatureLocation)
index 9c49338..8db9abe 100644 (file)
@@ -730,7 +730,8 @@ void LLVMFuncOp::build(Builder *builder, OperationState *result, StringRef name,
 // Returns a null type if any of the types provided are non-LLVM types, or if
 // there is more than one output type.
 static Type buildLLVMFunctionType(Builder &b, ArrayRef<Type> inputs,
-                                  ArrayRef<Type> outputs, bool isVariadic,
+                                  ArrayRef<Type> outputs,
+                                  impl::VariadicFlag variadicFlag,
                                   std::string &errorMessage) {
   if (outputs.size() > 1) {
     errorMessage = "expected zero or one function result";
@@ -761,7 +762,8 @@ static Type buildLLVMFunctionType(Builder &b, ArrayRef<Type> inputs,
     errorMessage = "expected LLVM type for function results";
     return {};
   }
-  return LLVMType::getFunctionTy(llvmOutput, llvmInputs, isVariadic);
+  return LLVMType::getFunctionTy(llvmOutput, llvmInputs,
+                                 variadicFlag.isVariadic());
 }
 
 // Print the LLVMFuncOp.  Collects argument and result types and passes them
index 5e722ad..14280bb 100644 (file)
@@ -3260,6 +3260,11 @@ public:
     return success(parser.consumeIf(Token::comma));
   }
 
+  /// Parses a `...` if present.
+  ParseResult parseOptionalEllipsis() override {
+    return success(parser.consumeIf(Token::ellipsis));
+  }
+
   /// Parse a `=` token.
   ParseResult parseEqual() override {
     return parser.parseToken(Token::equal, "expected '='");
@@ -3319,11 +3324,6 @@ public:
     return success(parser.consumeIf(Token::r_square));
   }
 
-  /// Parses a `...` if present.
-  ParseResult parseOptionalEllipsis() override {
-    return success(parser.consumeIf(Token::ellipsis));
-  }
-
   //===--------------------------------------------------------------------===//
   // Attribute Parsing
   //===--------------------------------------------------------------------===//