Preserve function argument locations.
authorDominik Grewe <dominikg@google.com>
Wed, 19 Jan 2022 23:44:43 +0000 (23:44 +0000)
committerMehdi Amini <joker.eph@gmail.com>
Thu, 20 Jan 2022 00:01:12 +0000 (00:01 +0000)
Previously the optional locations of function arguments were dropped in
`parseFunctionArgumentList`. This CL adds another output argument to the
function through which they are now returned. The values are then plumbed
through as an array of optional locations in the various places.

Reviewed By: rriddle

Differential Revision: https://reviews.llvm.org/D117604

mlir/include/mlir/IR/FunctionImplementation.h
mlir/include/mlir/IR/OpImplementation.h
mlir/lib/Dialect/Async/IR/Async.cpp
mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
mlir/lib/IR/FunctionImplementation.cpp
mlir/lib/Parser/Parser.cpp
mlir/test/IR/locations.mlir
mlir/test/lib/Dialect/Test/TestDialect.cpp

index f598bae..d1a15a7 100644 (file)
@@ -54,23 +54,24 @@ using FuncTypeBuilder = function_ref<Type(
 
 /// Parses function arguments using `parser`. The `allowVariadic` argument
 /// indicates whether functions with variadic arguments are supported. The
-/// trailing arguments are populated by this function with names, types and
-/// attributes of the arguments.
+/// trailing arguments are populated by this function with names, types,
+/// attributes and locations of the arguments.
 ParseResult parseFunctionArgumentList(
     OpAsmParser &parser, bool allowAttributes, bool allowVariadic,
     SmallVectorImpl<OpAsmParser::OperandType> &argNames,
     SmallVectorImpl<Type> &argTypes, SmallVectorImpl<NamedAttrList> &argAttrs,
-    bool &isVariadic);
+    SmallVectorImpl<Optional<Location>> &argLocations, bool &isVariadic);
 
 /// Parses a function signature using `parser`. The `allowVariadic` argument
 /// indicates whether functions with variadic arguments are supported. The
-/// trailing arguments are populated by this function with names, types and
-/// attributes of the arguments and those of the results.
+/// trailing arguments are populated by this function with names, types,
+/// attributes and locations of the arguments and those of the results.
 ParseResult
 parseFunctionSignature(OpAsmParser &parser, bool allowVariadic,
                        SmallVectorImpl<OpAsmParser::OperandType> &argNames,
                        SmallVectorImpl<Type> &argTypes,
                        SmallVectorImpl<NamedAttrList> &argAttrs,
+                       SmallVectorImpl<Optional<Location>> &argLocations,
                        bool &isVariadic, SmallVectorImpl<Type> &resultTypes,
                        SmallVectorImpl<NamedAttrList> &resultAttrs);
 
index 954a928..2f8cea6 100644 (file)
@@ -1208,20 +1208,23 @@ public:
 
   /// Parses a region. Any parsed blocks are appended to 'region' and must be
   /// moved to the op regions after the op is created. The first block of the
-  /// region takes 'arguments' of types 'argTypes'. If 'enableNameShadowing' is
-  /// set to true, the argument names are allowed to shadow the names of other
-  /// existing SSA values defined above the region scope. 'enableNameShadowing'
-  /// can only be set to true for regions attached to operations that are
-  /// 'IsolatedFromAbove.
-  virtual ParseResult parseRegion(Region &region,
-                                  ArrayRef<OperandType> arguments = {},
-                                  ArrayRef<Type> argTypes = {},
-                                  bool enableNameShadowing = false) = 0;
+  /// region takes 'arguments' of types 'argTypes'. If `argLocations` is
+  /// non-empty it contains an optional location to be attached to each
+  /// argument. If 'enableNameShadowing' is set to true, the argument names are
+  /// allowed to shadow the names of other existing SSA values defined above the
+  /// region scope. 'enableNameShadowing' can only be set to true for regions
+  /// attached to operations that are 'IsolatedFromAbove'.
+  virtual ParseResult
+  parseRegion(Region &region, ArrayRef<OperandType> arguments = {},
+              ArrayRef<Type> argTypes = {},
+              ArrayRef<Optional<Location>> argLocations = {},
+              bool enableNameShadowing = false) = 0;
 
   /// Parses a region if present.
   virtual OptionalParseResult
   parseOptionalRegion(Region &region, ArrayRef<OperandType> arguments = {},
                       ArrayRef<Type> argTypes = {},
+                      ArrayRef<Optional<Location>> argLocations = {},
                       bool enableNameShadowing = false) = 0;
 
   /// Parses a region if present. If the region is present, a new region is
index c6c83ca..b7b708b 100644 (file)
@@ -220,6 +220,7 @@ static ParseResult parseExecuteOp(OpAsmParser &parser, OperationState &result) {
   Region *body = result.addRegion();
   if (parser.parseRegion(*body, /*arguments=*/{unwrappedArgs},
                          /*argTypes=*/{unwrappedTypes},
+                         /*argLocations=*/{},
                          /*enableNameShadowing=*/false))
     return failure();
 
index a73d64c..429b7ff 100644 (file)
@@ -670,11 +670,13 @@ parseLaunchFuncOperands(OpAsmParser &parser,
                         SmallVectorImpl<Type> &argTypes) {
   if (parser.parseOptionalKeyword("args"))
     return success();
-  SmallVector<NamedAttrList, 4> argAttrs;
+  SmallVector<NamedAttrList> argAttrs;
+  SmallVector<Optional<Location>> argLocations;
   bool isVariadic = false;
   return function_interface_impl::parseFunctionArgumentList(
       parser, /*allowAttributes=*/false,
-      /*allowVariadic=*/false, argNames, argTypes, argAttrs, isVariadic);
+      /*allowVariadic=*/false, argNames, argTypes, argAttrs, argLocations,
+      isVariadic);
 }
 
 static void printLaunchFuncOperands(OpAsmPrinter &printer, Operation *,
@@ -776,11 +778,12 @@ parseAttributions(OpAsmParser &parser, StringRef keyword,
 ///                 (`->` function-result-list)? memory-attribution `kernel`?
 ///                 function-attributes? region
 static ParseResult parseGPUFuncOp(OpAsmParser &parser, OperationState &result) {
-  SmallVector<OpAsmParser::OperandType, 8> entryArgs;
-  SmallVector<NamedAttrList, 1> argAttrs;
-  SmallVector<NamedAttrList, 1> resultAttrs;
-  SmallVector<Type, 8> argTypes;
-  SmallVector<Type, 4> resultTypes;
+  SmallVector<OpAsmParser::OperandType> entryArgs;
+  SmallVector<NamedAttrList> argAttrs;
+  SmallVector<NamedAttrList> resultAttrs;
+  SmallVector<Type> argTypes;
+  SmallVector<Type> resultTypes;
+  SmallVector<Optional<Location>> argLocations;
   bool isVariadic;
 
   // Parse the function name.
@@ -792,7 +795,7 @@ static ParseResult parseGPUFuncOp(OpAsmParser &parser, OperationState &result) {
   auto signatureLocation = parser.getCurrentLocation();
   if (failed(function_interface_impl::parseFunctionSignature(
           parser, /*allowVariadic=*/false, entryArgs, argTypes, argAttrs,
-          isVariadic, resultTypes, resultAttrs)))
+          argLocations, isVariadic, resultTypes, resultAttrs)))
     return failure();
 
   if (entryArgs.empty() && !argTypes.empty())
index 897203d..2d00939 100644 (file)
@@ -2033,11 +2033,12 @@ static ParseResult parseLLVMFuncOp(OpAsmParser &parser,
                            parser, result, LLVM::Linkage::External)));
 
   StringAttr nameAttr;
-  SmallVector<OpAsmParser::OperandType, 8> entryArgs;
-  SmallVector<NamedAttrList, 1> argAttrs;
-  SmallVector<NamedAttrList, 1> resultAttrs;
-  SmallVector<Type, 8> argTypes;
-  SmallVector<Type, 4> resultTypes;
+  SmallVector<OpAsmParser::OperandType> entryArgs;
+  SmallVector<NamedAttrList> argAttrs;
+  SmallVector<NamedAttrList> resultAttrs;
+  SmallVector<Type> argTypes;
+  SmallVector<Type> resultTypes;
+  SmallVector<Optional<Location>> argLocations;
   bool isVariadic;
 
   auto signatureLocation = parser.getCurrentLocation();
@@ -2045,7 +2046,7 @@ static ParseResult parseLLVMFuncOp(OpAsmParser &parser,
                              result.attributes) ||
       function_interface_impl::parseFunctionSignature(
           parser, /*allowVariadic=*/true, entryArgs, argTypes, argAttrs,
-          isVariadic, resultTypes, resultAttrs))
+          argLocations, isVariadic, resultTypes, resultAttrs))
     return failure();
 
   auto type =
index 6838007..ca4ec92 100644 (file)
@@ -1942,11 +1942,12 @@ static void print(spirv::ExecutionModeOp execModeOp, OpAsmPrinter &printer) {
 //===----------------------------------------------------------------------===//
 
 static ParseResult parseFuncOp(OpAsmParser &parser, OperationState &state) {
-  SmallVector<OpAsmParser::OperandType, 4> entryArgs;
-  SmallVector<NamedAttrList, 4> argAttrs;
-  SmallVector<NamedAttrList, 4> resultAttrs;
-  SmallVector<Type, 4> argTypes;
-  SmallVector<Type, 4> resultTypes;
+  SmallVector<OpAsmParser::OperandType> entryArgs;
+  SmallVector<NamedAttrList> argAttrs;
+  SmallVector<NamedAttrList> resultAttrs;
+  SmallVector<Type> argTypes;
+  SmallVector<Type> resultTypes;
+  SmallVector<Optional<Location>> argLocations;
   auto &builder = parser.getBuilder();
 
   // Parse the name as a symbol.
@@ -1959,7 +1960,7 @@ static ParseResult parseFuncOp(OpAsmParser &parser, OperationState &state) {
   bool isVariadic = false;
   if (function_interface_impl::parseFunctionSignature(
           parser, /*allowVariadic=*/false, entryArgs, argTypes, argAttrs,
-          isVariadic, resultTypes, resultAttrs))
+          argLocations, isVariadic, resultTypes, resultAttrs))
     return failure();
 
   auto fnType = builder.getFunctionType(argTypes, resultTypes);
index 581f6f6..49ffaa2 100644 (file)
@@ -17,7 +17,7 @@ ParseResult mlir::function_interface_impl::parseFunctionArgumentList(
     OpAsmParser &parser, bool allowAttributes, bool allowVariadic,
     SmallVectorImpl<OpAsmParser::OperandType> &argNames,
     SmallVectorImpl<Type> &argTypes, SmallVectorImpl<NamedAttrList> &argAttrs,
-    bool &isVariadic) {
+    SmallVectorImpl<Optional<Location>> &argLocations, bool &isVariadic) {
   if (parser.parseLParen())
     return failure();
 
@@ -60,11 +60,12 @@ ParseResult mlir::function_interface_impl::parseFunctionArgumentList(
       return parser.emitError(loc, "expected arguments without attributes");
     argAttrs.push_back(attrs);
 
-    // Parse a location if specified.  TODO: Don't drop it on the floor.
+    // Parse a location if specified.
     Optional<Location> explicitLoc;
     if (!argument.name.empty() &&
         parser.parseOptionalLocationSpecifier(explicitLoc))
       return failure();
+    argLocations.push_back(explicitLoc);
 
     return success();
   };
@@ -132,11 +133,12 @@ ParseResult mlir::function_interface_impl::parseFunctionSignature(
     OpAsmParser &parser, bool allowVariadic,
     SmallVectorImpl<OpAsmParser::OperandType> &argNames,
     SmallVectorImpl<Type> &argTypes, SmallVectorImpl<NamedAttrList> &argAttrs,
-    bool &isVariadic, SmallVectorImpl<Type> &resultTypes,
+    SmallVectorImpl<Optional<Location>> &argLocations, bool &isVariadic,
+    SmallVectorImpl<Type> &resultTypes,
     SmallVectorImpl<NamedAttrList> &resultAttrs) {
   bool allowArgAttrs = true;
   if (parseFunctionArgumentList(parser, allowArgAttrs, allowVariadic, argNames,
-                                argTypes, argAttrs, isVariadic))
+                                argTypes, argAttrs, argLocations, isVariadic))
     return failure();
   if (succeeded(parser.parseOptionalArrow()))
     return parseFunctionResultList(parser, resultTypes, resultAttrs);
@@ -190,11 +192,12 @@ void mlir::function_interface_impl::addArgAndResultAttrs(
 ParseResult mlir::function_interface_impl::parseFunctionOp(
     OpAsmParser &parser, OperationState &result, bool allowVariadic,
     FuncTypeBuilder funcTypeBuilder) {
-  SmallVector<OpAsmParser::OperandType, 4> entryArgs;
-  SmallVector<NamedAttrList, 4> argAttrs;
-  SmallVector<NamedAttrList, 4> resultAttrs;
-  SmallVector<Type, 4> argTypes;
-  SmallVector<Type, 4> resultTypes;
+  SmallVector<OpAsmParser::OperandType> entryArgs;
+  SmallVector<NamedAttrList> argAttrs;
+  SmallVector<NamedAttrList> resultAttrs;
+  SmallVector<Type> argTypes;
+  SmallVector<Type> resultTypes;
+  SmallVector<Optional<Location>> argLocations;
   auto &builder = parser.getBuilder();
 
   // Parse visibility.
@@ -210,7 +213,8 @@ ParseResult mlir::function_interface_impl::parseFunctionOp(
   llvm::SMLoc signatureLocation = parser.getCurrentLocation();
   bool isVariadic = false;
   if (parseFunctionSignature(parser, allowVariadic, entryArgs, argTypes,
-                             argAttrs, isVariadic, resultTypes, resultAttrs))
+                             argAttrs, argLocations, isVariadic, resultTypes,
+                             resultAttrs))
     return failure();
 
   std::string errorMessage;
@@ -253,6 +257,7 @@ ParseResult mlir::function_interface_impl::parseFunctionOp(
   llvm::SMLoc loc = parser.getCurrentLocation();
   OptionalParseResult parseResult = parser.parseOptionalRegion(
       *body, entryArgs, entryArgs.empty() ? ArrayRef<Type>() : argTypes,
+      entryArgs.empty() ? ArrayRef<Optional<Location>>() : argLocations,
       /*enableNameShadowing=*/false);
   if (parseResult.hasValue()) {
     if (failed(*parseResult))
index 0a5e79c..0cc0bbf 100644 (file)
@@ -363,16 +363,19 @@ public:
   //===--------------------------------------------------------------------===//
 
   /// Parse a region into 'region' with the provided entry block arguments.
-  /// 'isIsolatedNameScope' indicates if the naming scope of this region is
-  /// isolated from those above.
+  /// If non-empty, 'argLocations' contains an optional locations for each
+  /// argument. 'isIsolatedNameScope' indicates if the naming scope of this
+  /// region is isolated from those above.
   ParseResult parseRegion(Region &region,
                           ArrayRef<std::pair<SSAUseInfo, Type>> entryArguments,
+                          ArrayRef<Optional<Location>> argLocations = {},
                           bool isIsolatedNameScope = false);
 
   /// Parse a region body into 'region'.
   ParseResult
   parseRegionBody(Region &region, llvm::SMLoc startLoc,
                   ArrayRef<std::pair<SSAUseInfo, Type>> entryArguments,
+                  ArrayRef<Optional<Location>> argLocations,
                   bool isIsolatedNameScope);
 
   //===--------------------------------------------------------------------===//
@@ -1448,6 +1451,7 @@ public:
   /// effectively defines the SSA values of `arguments` and assigns their type.
   ParseResult parseRegion(Region &region, ArrayRef<OperandType> arguments,
                           ArrayRef<Type> argTypes,
+                          ArrayRef<Optional<Location>> argLocations,
                           bool enableNameShadowing) override {
     assert(arguments.size() == argTypes.size() &&
            "mismatching number of arguments and types");
@@ -1466,19 +1470,22 @@ public:
     (void)isIsolatedFromAbove;
     assert((!enableNameShadowing || isIsolatedFromAbove) &&
            "name shadowing is only allowed on isolated regions");
-    if (parser.parseRegion(region, regionArguments, enableNameShadowing))
+    if (parser.parseRegion(region, regionArguments, argLocations,
+                           enableNameShadowing))
       return failure();
     return success();
   }
 
   /// Parses a region if present.
-  OptionalParseResult parseOptionalRegion(Region &region,
-                                          ArrayRef<OperandType> arguments,
-                                          ArrayRef<Type> argTypes,
-                                          bool enableNameShadowing) override {
+  OptionalParseResult
+  parseOptionalRegion(Region &region, ArrayRef<OperandType> arguments,
+                      ArrayRef<Type> argTypes,
+                      ArrayRef<Optional<Location>> argLocations,
+                      bool enableNameShadowing) override {
     if (parser.getToken().isNot(Token::l_brace))
       return llvm::None;
-    return parseRegion(region, arguments, argTypes, enableNameShadowing);
+    return parseRegion(region, arguments, argTypes, argLocations,
+                       enableNameShadowing);
   }
 
   /// Parses a region if present. If the region is present, a new region is
@@ -1491,7 +1498,8 @@ public:
     if (parser.getToken().isNot(Token::l_brace))
       return llvm::None;
     std::unique_ptr<Region> newRegion = std::make_unique<Region>();
-    if (parseRegion(*newRegion, arguments, argTypes, enableNameShadowing))
+    if (parseRegion(*newRegion, arguments, argTypes, /*argLocations=*/{},
+                    enableNameShadowing))
       return failure();
 
     region = std::move(newRegion);
@@ -1815,7 +1823,7 @@ OperationParser::parseTrailingLocationSpecifier(OpOrArgument opOrArgument) {
 ParseResult OperationParser::parseRegion(
     Region &region,
     ArrayRef<std::pair<OperationParser::SSAUseInfo, Type>> entryArguments,
-    bool isIsolatedNameScope) {
+    ArrayRef<Optional<Location>> argLocations, bool isIsolatedNameScope) {
   // Parse the '{'.
   Token lBraceTok = getToken();
   if (parseToken(Token::l_brace, "expected '{' to begin a region"))
@@ -1827,7 +1835,7 @@ ParseResult OperationParser::parseRegion(
 
   // Parse the region body.
   if ((!entryArguments.empty() || getToken().isNot(Token::r_brace)) &&
-      parseRegionBody(region, lBraceTok.getLoc(), entryArguments,
+      parseRegionBody(region, lBraceTok.getLoc(), entryArguments, argLocations,
                       isIsolatedNameScope)) {
     return failure();
   }
@@ -1843,7 +1851,8 @@ ParseResult OperationParser::parseRegion(
 ParseResult OperationParser::parseRegionBody(
     Region &region, llvm::SMLoc startLoc,
     ArrayRef<std::pair<OperationParser::SSAUseInfo, Type>> entryArguments,
-    bool isIsolatedNameScope) {
+    ArrayRef<Optional<Location>> argLocations, bool isIsolatedNameScope) {
+  assert(argLocations.empty() || argLocations.size() == entryArguments.size());
   auto currentPt = opBuilder.saveInsertionPoint();
 
   // Push a new named value scope.
@@ -1865,7 +1874,9 @@ ParseResult OperationParser::parseRegionBody(
     if (getToken().is(Token::caret_identifier))
       return emitError("invalid block name in region with named arguments");
 
-    for (auto &placeholderArgPair : entryArguments) {
+    for (const auto &it : llvm::enumerate(entryArguments)) {
+      size_t argIndex = it.index();
+      auto &placeholderArgPair = it.value();
       auto &argInfo = placeholderArgPair.first;
 
       // Ensure that the argument was not already defined.
@@ -1875,7 +1886,10 @@ ParseResult OperationParser::parseRegionBody(
                    .attachNote(getEncodedSourceLocation(*defLoc))
                << "previously referenced here";
       }
-      auto loc = getEncodedSourceLocation(placeholderArgPair.first.loc);
+      Location loc =
+          (!argLocations.empty() && argLocations[argIndex])
+              ? *argLocations[argIndex]
+              : getEncodedSourceLocation(placeholderArgPair.first.loc);
       BlockArgument arg = block->addArgument(placeholderArgPair.second, loc);
 
       // Add a definition of this arg to the assembly state if provided.
index f6c4f21..fa1bafc 100644 (file)
@@ -52,7 +52,7 @@ func @escape_strings() {
 // CHECK-LABEL: func @argLocs(
 // CHECK-SAME:  %arg0: i32 loc({{.*}}locations.mlir":[[# @LINE+1]]:15),
 func @argLocs(%x: i32,
-// CHECK-SAME:  %arg1: i64 loc({{.*}}locations.mlir":[[# @LINE+1]]:15))
+// CHECK-SAME:  %arg1: i64 loc("hotdog")
               %y: i64 loc("hotdog")) {
   return
 }
index b1ce202..4d316bd 100644 (file)
@@ -595,7 +595,7 @@ static ParseResult parseIsolatedRegionOp(OpAsmParser &parser,
 
   // Parse the body region, and reuse the operand info as the argument info.
   Region *body = result.addRegion();
-  return parser.parseRegion(*body, argInfo, argType,
+  return parser.parseRegion(*body, argInfo, argType, /*argLocations=*/{},
                             /*enableNameShadowing=*/true);
 }