[mlir][Asm] Add support for resolving operation locations after parsing has finished
authorRiver Riddle <riddleriver@gmail.com>
Fri, 13 Nov 2020 07:33:43 +0000 (23:33 -0800)
committerRiver Riddle <riddleriver@gmail.com>
Fri, 13 Nov 2020 07:34:36 +0000 (23:34 -0800)
This revision adds support in the parser/printer for "deferrable" aliases, i.e. those that can be resolved after printing has finished. This allows for printing aliases for operation locations after the module instead of before, i.e. this is now supported:

```
"foo.op"() : () -> () loc(#loc)

#loc = loc("some_location")
```

Differential Revision: https://reviews.llvm.org/D91227

mlir/lib/IR/AsmPrinter.cpp
mlir/lib/Parser/LocationParser.cpp
mlir/lib/Parser/Parser.cpp
mlir/lib/Parser/Parser.h
mlir/test/IR/invalid-locations.mlir
mlir/test/IR/locations.mlir

index 4b91b19..5644abf 100644 (file)
@@ -225,6 +225,38 @@ static raw_ostream &operator<<(raw_ostream &os, NewLineCounter &newLine) {
 //===----------------------------------------------------------------------===//
 
 namespace {
+/// This class represents a specific instance of a symbol Alias.
+class SymbolAlias {
+public:
+  SymbolAlias(StringRef name, bool isDeferrable)
+      : name(name), suffixIndex(0), hasSuffixIndex(false),
+        isDeferrable(isDeferrable) {}
+  SymbolAlias(StringRef name, uint32_t suffixIndex, bool isDeferrable)
+      : name(name), suffixIndex(suffixIndex), hasSuffixIndex(true),
+        isDeferrable(isDeferrable) {}
+
+  /// Print this alias to the given stream.
+  void print(raw_ostream &os) const {
+    os << name;
+    if (hasSuffixIndex)
+      os << suffixIndex;
+  }
+
+  /// Returns true if this alias supports deferred resolution when parsing.
+  bool canBeDeferred() const { return isDeferrable; }
+
+private:
+  /// The main name of the alias.
+  StringRef name;
+  /// The optional suffix index of the alias, if multiple aliases had the same
+  /// name.
+  uint32_t suffixIndex : 30;
+  /// A flag indicating whether this alias has a suffix or not.
+  bool hasSuffixIndex : 1;
+  /// A flag indicating whether this alias may be deferred or not.
+  bool isDeferrable : 1;
+};
+
 /// This class represents a utility that initializes the set of attribute and
 /// type aliases, without the need to store the extra information within the
 /// main AliasState class or pass it around via function arguments.
@@ -236,14 +268,14 @@ public:
       : interfaces(interfaces), aliasAllocator(aliasAllocator),
         aliasOS(aliasBuffer) {}
 
-  void initialize(
-      Operation *op, const OpPrintingFlags &printerFlags,
-      llvm::MapVector<Attribute, std::pair<StringRef, Optional<int>>>
-          &attrToAlias,
-      llvm::MapVector<Type, std::pair<StringRef, Optional<int>>> &typeToAlias);
+  void initialize(Operation *op, const OpPrintingFlags &printerFlags,
+                  llvm::MapVector<Attribute, SymbolAlias> &attrToAlias,
+                  llvm::MapVector<Type, SymbolAlias> &typeToAlias);
 
-  /// Visit the given attribute to see if it has an alias.
-  void visit(Attribute attr);
+  /// Visit the given attribute to see if it has an alias. `canBeDeferred` is
+  /// set to true if the originator of this attribute can resolve the alias
+  /// after parsing has completed (e.g. in the case of operation locations).
+  void visit(Attribute attr, bool canBeDeferred = false);
 
   /// Visit the given type to see if it has an alias.
   void visit(Type type);
@@ -251,9 +283,11 @@ public:
 private:
   /// Try to generate an alias for the provided symbol. If an alias is
   /// generated, the provided alias mapping and reverse mapping are updated.
+  /// Returns success if an alias was generated, failure otherwise.
   template <typename T>
-  void generateAlias(T symbol,
-                     llvm::MapVector<StringRef, std::vector<T>> &aliasToSymbol);
+  LogicalResult
+  generateAlias(T symbol,
+                llvm::MapVector<StringRef, std::vector<T>> &aliasToSymbol);
 
   /// The set of asm interfaces within the context.
   DialectInterfaceCollection<OpAsmDialectInterface> &interfaces;
@@ -268,6 +302,9 @@ private:
   /// The set of visited attributes.
   DenseSet<Attribute> visitedAttributes;
 
+  /// The set of attributes that have aliases *and* can be deferred.
+  DenseSet<Attribute> deferrableAttributes;
+
   /// The set of visited types.
   DenseSet<Type> visitedTypes;
 
@@ -291,7 +328,7 @@ public:
   void print(Operation *op) {
     // Visit the operation location.
     if (printerFlags.shouldPrintDebugInfo())
-      initializer.visit(op->getLoc());
+      initializer.visit(op->getLoc(), /*canBeDeferred=*/true);
 
     // If requested, always print the generic form.
     if (!printerFlags.shouldPrintGenericOpForm()) {
@@ -464,9 +501,10 @@ static StringRef sanitizeIdentifier(StringRef name, SmallString<16> &buffer,
 /// Given a collection of aliases and symbols, initialize a mapping from a
 /// symbol to a given alias.
 template <typename T>
-static void initializeAliases(
-    llvm::MapVector<StringRef, std::vector<T>> &aliasToSymbol,
-    llvm::MapVector<T, std::pair<StringRef, Optional<int>>> &symbolToAlias) {
+static void
+initializeAliases(llvm::MapVector<StringRef, std::vector<T>> &aliasToSymbol,
+                  llvm::MapVector<T, SymbolAlias> &symbolToAlias,
+                  DenseSet<T> *deferrableAliases = nullptr) {
   std::vector<std::pair<StringRef, std::vector<T>>> aliases =
       aliasToSymbol.takeVector();
   llvm::array_pod_sort(aliases.begin(), aliases.end(),
@@ -477,20 +515,24 @@ static void initializeAliases(
   for (auto &it : aliases) {
     // If there is only one instance for this alias, use the name directly.
     if (it.second.size() == 1) {
-      symbolToAlias.insert({it.second.front(), {it.first, llvm::None}});
+      T symbol = it.second.front();
+      bool isDeferrable = deferrableAliases && deferrableAliases->count(symbol);
+      symbolToAlias.insert({symbol, SymbolAlias(it.first, isDeferrable)});
       continue;
     }
     // Otherwise, add the index to the name.
-    for (int i = 0, e = it.second.size(); i < e; ++i)
-      symbolToAlias.insert({it.second[i], {it.first, i}});
+    for (int i = 0, e = it.second.size(); i < e; ++i) {
+      T symbol = it.second[i];
+      bool isDeferrable = deferrableAliases && deferrableAliases->count(symbol);
+      symbolToAlias.insert({symbol, SymbolAlias(it.first, i, isDeferrable)});
+    }
   }
 }
 
 void AliasInitializer::initialize(
     Operation *op, const OpPrintingFlags &printerFlags,
-    llvm::MapVector<Attribute, std::pair<StringRef, Optional<int>>>
-        &attrToAlias,
-    llvm::MapVector<Type, std::pair<StringRef, Optional<int>>> &typeToAlias) {
+    llvm::MapVector<Attribute, SymbolAlias> &attrToAlias,
+    llvm::MapVector<Type, SymbolAlias> &typeToAlias) {
   // Use a dummy printer when walking the IR so that we can collect the
   // attributes/types that will actually be used during printing when
   // considering aliases.
@@ -498,13 +540,25 @@ void AliasInitializer::initialize(
   aliasPrinter.print(op);
 
   // Initialize the aliases sorted by name.
-  initializeAliases(aliasToAttr, attrToAlias);
+  initializeAliases(aliasToAttr, attrToAlias, &deferrableAttributes);
   initializeAliases(aliasToType, typeToAlias);
 }
 
-void AliasInitializer::visit(Attribute attr) {
-  if (!visitedAttributes.insert(attr).second)
+void AliasInitializer::visit(Attribute attr, bool canBeDeferred) {
+  if (!visitedAttributes.insert(attr).second) {
+    // If this attribute already has an alias and this instance can't be
+    // deferred, make sure that the alias isn't deferred.
+    if (!canBeDeferred)
+      deferrableAttributes.erase(attr);
+    return;
+  }
+
+  // Try to generate an alias for this attribute.
+  if (succeeded(generateAlias(attr, aliasToAttr))) {
+    if (canBeDeferred)
+      deferrableAttributes.insert(attr);
     return;
+  }
 
   if (auto arrayAttr = attr.dyn_cast<ArrayAttr>()) {
     for (Attribute element : arrayAttr.getValue())
@@ -515,15 +569,16 @@ void AliasInitializer::visit(Attribute attr) {
   } else if (auto typeAttr = attr.dyn_cast<TypeAttr>()) {
     visit(typeAttr.getValue());
   }
-
-  // Try to generate an alias for this attribute.
-  generateAlias(attr, aliasToAttr);
 }
 
 void AliasInitializer::visit(Type type) {
   if (!visitedTypes.insert(type).second)
     return;
 
+  // Try to generate an alias for this type.
+  if (succeeded(generateAlias(type, aliasToType)))
+    return;
+
   // Visit several subtypes that contain types or atttributes.
   if (auto funcType = type.dyn_cast<FunctionType>()) {
     // Visit input and result types for functions.
@@ -539,13 +594,10 @@ void AliasInitializer::visit(Type type) {
       for (auto map : memref.getAffineMaps())
         visit(AffineMapAttr::get(map));
   }
-
-  // Try to generate an alias for this type.
-  generateAlias(type, aliasToType);
 }
 
 template <typename T>
-void AliasInitializer::generateAlias(
+LogicalResult AliasInitializer::generateAlias(
     T symbol, llvm::MapVector<StringRef, std::vector<T>> &aliasToSymbol) {
   SmallString<16> tempBuffer;
   for (const auto &interface : interfaces) {
@@ -559,8 +611,9 @@ void AliasInitializer::generateAlias(
 
     aliasToSymbol[name].push_back(symbol);
     aliasBuffer.clear();
-    break;
+    return success();
   }
+  return failure();
 }
 
 //===----------------------------------------------------------------------===//
@@ -580,21 +633,31 @@ public:
   /// Returns success if an alias was printed, failure otherwise.
   LogicalResult getAlias(Attribute attr, raw_ostream &os) const;
 
-  /// Print all of the referenced attribute aliases.
-  void printAttributeAliases(raw_ostream &os, NewLineCounter &newLine) const;
-
   /// Get an alias for the given type if it has one and print it in `os`.
   /// Returns success if an alias was printed, failure otherwise.
   LogicalResult getAlias(Type ty, raw_ostream &os) const;
 
-  /// Print all of the referenced type aliases.
-  void printTypeAliases(raw_ostream &os, NewLineCounter &newLine) const;
+  /// Print all of the referenced aliases that can not be resolved in a deferred
+  /// manner.
+  void printNonDeferredAliases(raw_ostream &os, NewLineCounter &newLine) const {
+    printAliases(os, newLine, /*isDeferred=*/false);
+  }
+
+  /// Print all of the referenced aliases that support deferred resolution.
+  void printDeferredAliases(raw_ostream &os, NewLineCounter &newLine) const {
+    printAliases(os, newLine, /*isDeferred=*/true);
+  }
 
 private:
-  /// Mapping between attribute and a pair comprised of a base alias name and a
-  /// count suffix. If the suffix is set to None, it is not displayed.
-  llvm::MapVector<Attribute, std::pair<StringRef, Optional<int>>> attrToAlias;
-  llvm::MapVector<Type, std::pair<StringRef, Optional<int>>> typeToAlias;
+  /// Print all of the referenced aliases that support the provided resolution
+  /// behavior.
+  void printAliases(raw_ostream &os, NewLineCounter &newLine,
+                    bool isDeferred) const;
+
+  /// Mapping between attribute and alias.
+  llvm::MapVector<Attribute, SymbolAlias> attrToAlias;
+  /// Mapping between type and alias.
+  llvm::MapVector<Type, SymbolAlias> typeToAlias;
 
   /// An allocator used for alias names.
   llvm::BumpPtrAllocator aliasAllocator;
@@ -608,44 +671,34 @@ void AliasState::initialize(
   initializer.initialize(op, printerFlags, attrToAlias, typeToAlias);
 }
 
-static void printAlias(raw_ostream &os,
-                       const std::pair<StringRef, Optional<int>> &alias,
-                       char prefix) {
-  os << prefix << alias.first;
-  if (alias.second)
-    os << *alias.second;
-}
-
 LogicalResult AliasState::getAlias(Attribute attr, raw_ostream &os) const {
   auto it = attrToAlias.find(attr);
   if (it == attrToAlias.end())
     return failure();
-
-  printAlias(os, it->second, '#');
+  it->second.print(os << '#');
   return success();
 }
 
-void AliasState::printAttributeAliases(raw_ostream &os,
-                                       NewLineCounter &newLine) const {
-  for (const auto &it : attrToAlias) {
-    printAlias(os, it.second, '#');
-    os << " = " << it.first << newLine;
-  }
-}
-
 LogicalResult AliasState::getAlias(Type ty, raw_ostream &os) const {
   auto it = typeToAlias.find(ty);
   if (it == typeToAlias.end())
     return failure();
 
-  printAlias(os, it->second, '!');
+  it->second.print(os << '!');
   return success();
 }
 
-void AliasState::printTypeAliases(raw_ostream &os,
-                                  NewLineCounter &newLine) const {
-  for (const auto &it : typeToAlias) {
-    printAlias(os, it.second, '!');
+void AliasState::printAliases(raw_ostream &os, NewLineCounter &newLine,
+                              bool isDeferred) const {
+  auto filterFn = [=](const auto &aliasIt) {
+    return aliasIt.second.canBeDeferred() == isDeferred;
+  };
+  for (const auto &it : llvm::make_filter_range(attrToAlias, filterFn)) {
+    it.second.print(os << '#');
+    os << " = " << it.first << newLine;
+  }
+  for (const auto &it : llvm::make_filter_range(typeToAlias, filterFn)) {
+    it.second.print(os << '!');
     os << " = " << it.first << newLine;
   }
 }
@@ -2237,12 +2290,15 @@ private:
 } // end anonymous namespace
 
 void OperationPrinter::print(ModuleOp op) {
-  // Output the aliases at the top level.
-  state->getAliasState().printAttributeAliases(os, newLine);
-  state->getAliasState().printTypeAliases(os, newLine);
+  // Output the aliases at the top level that can't be deferred.
+  state->getAliasState().printNonDeferredAliases(os, newLine);
 
   // Print the module.
   print(op.getOperation());
+  os << newLine;
+
+  // Output the aliases at the top level that can be deferred.
+  state->getAliasState().printDeferredAliases(os, newLine);
 }
 
 void OperationPrinter::print(Operation *op) {
index 7ad8b92..93977f0 100644 (file)
@@ -177,35 +177,3 @@ ParseResult Parser::parseLocationInstance(LocationAttr &loc) {
 
   return emitError("expected location instance");
 }
-
-ParseResult Parser::parseOptionalTrailingLocation(Location &loc) {
-  // If there is a 'loc' we parse a trailing location.
-  if (!consumeIf(Token::kw_loc))
-    return success();
-  if (parseToken(Token::l_paren, "expected '(' in location"))
-    return failure();
-  Token tok = getToken();
-
-  // Check to see if we are parsing a location alias.
-  LocationAttr directLoc;
-  if (tok.is(Token::hash_identifier)) {
-    // TODO: This should be reworked a bit to allow for resolving operation
-    // locations to aliases after the operation has already been parsed(i.e.
-    // allow post parse location fixups).
-    Attribute attr = parseExtendedAttr(Type());
-    if (!attr)
-      return failure();
-    if (!(directLoc = attr.dyn_cast<LocationAttr>()))
-      return emitError(tok.getLoc()) << "expected location, but found " << attr;
-
-    // Otherwise, we parse the location directly.
-  } else if (parseLocationInstance(directLoc)) {
-    return failure();
-  }
-
-  if (parseToken(Token::r_paren, "expected ')' in location"))
-    return failure();
-
-  loc = directLoc;
-  return success();
-}
index d8f6ebd..03768b5 100644 (file)
@@ -183,9 +183,15 @@ public:
   Operation *parseGenericOperation(Block *insertBlock,
                                    Block::iterator insertPt);
 
+  /// Parse an optional trailing location for the given operation.
+  ///
+  ///   trailing-location ::= (`loc` (`(` location `)` | attribute-alias))?
+  ///
+  ParseResult parseTrailingOperationLocation(Operation *op);
+
   /// This is the structure of a result specifier in the assembly syntax,
   /// including the name, number of results, and location.
-  typedef std::tuple<StringRef, unsigned, SMLoc> ResultRecord;
+  using ResultRecord = std::tuple<StringRef, unsigned, SMLoc>;
 
   /// Parse an operation instance that is in the op-defined custom form.
   /// resultInfo specifies information about the "%name =" specifiers.
@@ -297,6 +303,10 @@ private:
   /// their first reference, to allow checking for use of undefined values.
   DenseMap<Value, SMLoc> forwardRefPlaceholders;
 
+  /// A set of operations whose locations reference aliases that have yet to
+  /// be resolved.
+  SmallVector<std::pair<Operation *, Token>, 8> opsWithDeferredLocs;
+
   /// The builder used when creating parsed operation instances.
   OpBuilder opBuilder;
 
@@ -333,6 +343,22 @@ ParseResult OperationParser::finalize() {
     return failure();
   }
 
+  // Resolve the locations of any deferred operations.
+  auto &attributeAliases = getState().symbols.attributeAliasDefinitions;
+  for (std::pair<Operation *, Token> &it : opsWithDeferredLocs) {
+    llvm::SMLoc tokLoc = it.second.getLoc();
+    StringRef identifier = it.second.getSpelling().drop_front();
+    Attribute attr = attributeAliases.lookup(identifier);
+    if (!attr)
+      return emitError(tokLoc) << "operation location alias was never defined";
+
+    LocationAttr locAttr = attr.dyn_cast<LocationAttr>();
+    if (!locAttr)
+      return emitError(tokLoc)
+             << "expected location, but found '" << attr << "'";
+    it.first->setLoc(locAttr);
+  }
+
   return success();
 }
 
@@ -817,11 +843,11 @@ Operation *OperationParser::parseGenericOperation() {
       return nullptr;
   }
 
-  // Parse a location if one is present.
-  if (parseOptionalTrailingLocation(result.location))
+  // Create the operation and try to parse a location for it.
+  Operation *op = opBuilder.createOperation(result);
+  if (parseTrailingOperationLocation(op))
     return nullptr;
-
-  return opBuilder.createOperation(result);
+  return op;
 }
 
 Operation *OperationParser::parseGenericOperation(Block *insertBlock,
@@ -1623,12 +1649,56 @@ OperationParser::parseCustomOperation(ArrayRef<ResultRecord> resultIDs) {
   if (opAsmParser.didEmitError())
     return nullptr;
 
-  // Parse a location if one is present.
-  if (parseOptionalTrailingLocation(opState.location))
+  // Otherwise, create the operation and try to parse a location for it.
+  Operation *op = opBuilder.createOperation(opState);
+  if (parseTrailingOperationLocation(op))
     return nullptr;
+  return op;
+}
+
+ParseResult OperationParser::parseTrailingOperationLocation(Operation *op) {
+  // If there is a 'loc' we parse a trailing location.
+  if (!consumeIf(Token::kw_loc))
+    return success();
+  if (parseToken(Token::l_paren, "expected '(' in location"))
+    return failure();
+  Token tok = getToken();
+
+  // Check to see if we are parsing a location alias.
+  LocationAttr directLoc;
+  if (tok.is(Token::hash_identifier)) {
+    consumeToken();
+
+    StringRef identifier = tok.getSpelling().drop_front();
+    if (identifier.contains('.')) {
+      return emitError(tok.getLoc())
+             << "expected location, but found dialect attribute: '#"
+             << identifier << "'";
+    }
+
+    // If this alias can be resolved, do it now.
+    Attribute attr =
+        getState().symbols.attributeAliasDefinitions.lookup(identifier);
+    if (attr) {
+      if (!(directLoc = attr.dyn_cast<LocationAttr>()))
+        return emitError(tok.getLoc())
+               << "expected location, but found '" << attr << "'";
+    } else {
+      // Otherwise, remember this operation and resolve its location later.
+      opsWithDeferredLocs.emplace_back(op, tok);
+    }
 
-  // Otherwise, we succeeded.  Use the state it parsed as our op information.
-  return opBuilder.createOperation(opState);
+    // Otherwise, we parse the location directly.
+  } else if (parseLocationInstance(directLoc)) {
+    return failure();
+  }
+
+  if (parseToken(Token::r_paren, "expected ')' in location"))
+    return failure();
+
+  if (directLoc)
+    op->setLoc(directLoc);
+  return success();
 }
 
 //===----------------------------------------------------------------------===//
index bec9cf5..e32cca4 100644 (file)
@@ -243,12 +243,6 @@ public:
   /// Parse a name or FileLineCol location instance.
   ParseResult parseNameOrFileLineColLocation(LocationAttr &loc);
 
-  /// Parse an optional trailing location.
-  ///
-  ///   trailing-location ::= (`loc` (`(` location `)` | attribute-alias))?
-  ///
-  ParseResult parseOptionalTrailingLocation(Location &loc);
-
   //===--------------------------------------------------------------------===//
   // Affine Parsing
   //===--------------------------------------------------------------------===//
index 06f7951..175ab34 100644 (file)
@@ -102,7 +102,23 @@ func @location_fused_missing_r_square() {
 // -----
 
 func @location_invalid_alias() {
-  // expected-error@+1 {{expected location, but found #foo.loc}}
+  // expected-error@+1 {{expected location, but found dialect attribute: '#foo.loc'}}
   return loc(#foo.loc)
 }
 
+// -----
+
+func @location_invalid_alias() {
+  // expected-error@+1 {{operation location alias was never defined}}
+  return loc(#invalid_alias)
+}
+
+// -----
+
+func @location_invalid_alias() {
+  // expected-error@+1 {{expected location, but found 'true'}}
+  return loc(#non_loc)
+}
+
+#non_loc = true
+
index 1a83871..f7c9636 100644 (file)
@@ -27,8 +27,8 @@ func @inline_notation() -> i32 {
 // CHECK-LABEL: func @loc_attr(i1 {foo.loc_attr = loc(callsite("foo" at "mysource.cc":10:8))})
 func @loc_attr(i1 {foo.loc_attr = loc(callsite("foo" at "mysource.cc":10:8))})
 
-// CHECK-ALIAS: #[[LOC:.*]] = loc("out_of_line_location")
-#loc = loc("out_of_line_location")
-
-// CHECK-ALIAS: "foo.op"() : () -> () loc(#[[LOC]])
+// CHECK-ALIAS: "foo.op"() : () -> () loc(#[[LOC:.*]])
 "foo.op"() : () -> () loc(#loc)
+
+// CHECK-ALIAS: #[[LOC]] = loc("out_of_line_location")
+#loc = loc("out_of_line_location")