Remove the need for passing a location to parseAttribute/parseType.
authorRiver Riddle <riverriddle@google.com>
Fri, 1 Nov 2019 22:39:30 +0000 (15:39 -0700)
committerA. Unique TensorFlower <gardener@tensorflow.org>
Fri, 1 Nov 2019 22:40:16 +0000 (15:40 -0700)
Now that a proper parser is passed to these methods, there isn't a need to explicitly pass a source location. The source location can be recovered from the parser as necessary. This removes the need to explicitly decode an SMLoc in the case where we don't need to, which can be expensive.

This requires adding some basic nesting support to the parser for supporting nested parsers to allow for remapping source locations of the nested parsers to the top level parser for accurate diagnostics. This is due to the fact that the attribute and type parsers use different source buffers than the top level parser, as they may be represented in string form.

PiperOrigin-RevId: 278014858

14 files changed:
mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h
mlir/include/mlir/Dialect/Linalg/IR/LinalgTypes.h
mlir/include/mlir/Dialect/QuantOps/QuantOps.h
mlir/include/mlir/Dialect/SPIRV/SPIRVDialect.h
mlir/include/mlir/IR/Diagnostics.h
mlir/include/mlir/IR/Dialect.h
mlir/include/mlir/IR/DialectImplementation.h
mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
mlir/lib/Dialect/Linalg/IR/LinalgTypes.cpp
mlir/lib/Dialect/QuantOps/IR/TypeParser.cpp
mlir/lib/Dialect/SPIRV/SPIRVDialect.cpp
mlir/lib/IR/Dialect.cpp
mlir/lib/Parser/Lexer.h
mlir/lib/Parser/Parser.cpp

index d09e815..eb39537 100644 (file)
@@ -173,7 +173,7 @@ public:
   llvm::Module &getLLVMModule();
 
   /// Parse a type registered to this dialect.
-  Type parseType(DialectAsmParser &parser, Location loc) const override;
+  Type parseType(DialectAsmParser &parser) const override;
 
   /// Print a type registered to this dialect.
   void printType(Type type, DialectAsmPrinter &os) const override;
index 8888953..81c3097 100644 (file)
@@ -37,7 +37,7 @@ public:
   static StringRef getDialectNamespace() { return "linalg"; }
 
   /// Parse a type registered to this dialect.
-  Type parseType(DialectAsmParser &parser, Location loc) const override;
+  Type parseType(DialectAsmParser &parser) const override;
 
   /// Print a type registered to this dialect.
   void printType(Type type, DialectAsmPrinter &os) const override;
index f1ac383..020d349 100644 (file)
@@ -35,7 +35,7 @@ public:
   QuantizationDialect(MLIRContext *context);
 
   /// Parse a type registered to this dialect.
-  Type parseType(DialectAsmParser &parser, Location loc) const override;
+  Type parseType(DialectAsmParser &parser) const override;
 
   /// Print a type registered to this dialect.
   void printType(Type type, DialectAsmPrinter &os) const override;
index 6401eba..2571e5d 100644 (file)
@@ -46,7 +46,7 @@ public:
   static std::string getAttributeName(Decoration decoration);
 
   /// Parses a type registered to this dialect.
-  Type parseType(DialectAsmParser &parser, Location loc) const override;
+  Type parseType(DialectAsmParser &parser) const override;
 
   /// Prints a type registered to this dialect.
   void printType(Type type, DialectAsmPrinter &os) const override;
index d533273..1d284f6 100644 (file)
@@ -391,7 +391,7 @@ private:
   friend DiagnosticEngine;
 
   /// The engine that this diagnostic is to report to.
-  DiagnosticEngine *owner;
+  DiagnosticEngine *owner = nullptr;
 
   /// The raw diagnostic that is inflight to be reported.
   llvm::Optional<Diagnostic> impl;
index bd84bee..4880fd0 100644 (file)
@@ -117,8 +117,7 @@ public:
 
   /// Parse an attribute registered to this dialect. If 'type' is nonnull, it
   /// refers to the expected type of the attribute.
-  virtual Attribute parseAttribute(DialectAsmParser &parser, Type type,
-                                   Location loc) const;
+  virtual Attribute parseAttribute(DialectAsmParser &parser, Type type) const;
 
   /// Print an attribute registered to this dialect. Note: The type of the
   /// attribute need not be printed by this method as it is always printed by
@@ -128,7 +127,7 @@ public:
   }
 
   /// Parse a type registered to this dialect.
-  virtual Type parseType(DialectAsmParser &parser, Location loc) const;
+  virtual Type parseType(DialectAsmParser &parser) const;
 
   /// Print a type registered to this dialect.
   virtual void printType(Type, DialectAsmPrinter &) const {
index c662a4c..c538b81 100644 (file)
@@ -129,6 +129,9 @@ public:
   /// Return the location of the original name token.
   virtual llvm::SMLoc getNameLoc() const = 0;
 
+  /// Re-encode the given source location as an MLIR location and return it.
+  virtual Location getEncodedSourceLoc(llvm::SMLoc loc) = 0;
+
   /// Returns the full specification of the symbol being parsed. This allows for
   /// using a separate parser if necessary.
   virtual StringRef getFullSymbolSpec() const = 0;
index 39decf9..cc2de87 100644 (file)
@@ -1250,7 +1250,7 @@ llvm::LLVMContext &LLVMDialect::getLLVMContext() { return impl->llvmContext; }
 llvm::Module &LLVMDialect::getLLVMModule() { return impl->module; }
 
 /// Parse a type registered to this dialect.
-Type LLVMDialect::parseType(DialectAsmParser &parser, Location loc) const {
+Type LLVMDialect::parseType(DialectAsmParser &parser) const {
   StringRef tyData = parser.getFullSymbolSpec();
 
   // LLVM is not thread-safe, so lock access to it.
@@ -1259,7 +1259,8 @@ Type LLVMDialect::parseType(DialectAsmParser &parser, Location loc) const {
   llvm::SMDiagnostic errorMessage;
   llvm::Type *type = llvm::parseType(tyData, errorMessage, impl->module);
   if (!type)
-    return (emitError(loc, errorMessage.getMessage()), nullptr);
+    return (parser.emitError(parser.getNameLoc(), errorMessage.getMessage()),
+            nullptr);
   return LLVMType::get(getContext(), type);
 }
 
index 4a7bcd8..41e76a8 100644 (file)
@@ -108,8 +108,8 @@ Optional<int64_t> mlir::linalg::BufferType::getBufferSize() {
   return getImpl()->getBufferSize();
 }
 
-Type mlir::linalg::LinalgDialect::parseType(DialectAsmParser &parser,
-                                            Location loc) const {
+Type mlir::linalg::LinalgDialect::parseType(DialectAsmParser &parser) const {
+  Location loc = parser.getEncodedSourceLoc(parser.getNameLoc());
   StringRef spec = parser.getFullSymbolSpec();
   StringRef origSpec = spec;
   MLIRContext *context = getContext();
index 360c1b5..26212f6 100644 (file)
@@ -616,8 +616,8 @@ bool TypeParser::parseQuantParams(double &scale, int64_t &zeroPoint) {
 }
 
 /// Parse a type registered to this dialect.
-Type QuantizationDialect::parseType(DialectAsmParser &parser,
-                                    Location loc) const {
+Type QuantizationDialect::parseType(DialectAsmParser &parser) const {
+  Location loc = parser.getEncodedSourceLoc(parser.getNameLoc());
   TypeParser typeParser(parser.getFullSymbolSpec(), getContext(), loc);
   Type parsedType = typeParser.parseType();
   if (parsedType == nullptr) {
index 26d1ff1..abe4724 100644 (file)
@@ -610,8 +610,9 @@ static Type parseStructType(SPIRVDialect const &dialect, StringRef spec,
 //              | pointer-type
 //              | runtime-array-type
 //              | struct-type
-Type SPIRVDialect::parseType(DialectAsmParser &parser, Location loc) const {
+Type SPIRVDialect::parseType(DialectAsmParser &parser) const {
   StringRef spec = parser.getFullSymbolSpec();
+  Location loc = parser.getEncodedSourceLoc(parser.getNameLoc());
 
   if (spec.startswith("array"))
     return parseArrayType(*this, spec, loc);
index 7882e4f..c6266b0 100644 (file)
@@ -102,23 +102,23 @@ LogicalResult Dialect::verifyRegionResultAttribute(Operation *, unsigned,
 }
 
 /// Parse an attribute registered to this dialect.
-Attribute Dialect::parseAttribute(DialectAsmParser &parser, Type type,
-                                  Location loc) const {
-  emitError(loc) << "dialect '" << getNamespace()
-                 << "' provides no attribute parsing hook";
+Attribute Dialect::parseAttribute(DialectAsmParser &parser, Type type) const {
+  parser.emitError(parser.getNameLoc())
+      << "dialect '" << getNamespace()
+      << "' provides no attribute parsing hook";
   return Attribute();
 }
 
 /// Parse a type registered to this dialect.
-Type Dialect::parseType(DialectAsmParser &parser, Location loc) const {
+Type Dialect::parseType(DialectAsmParser &parser) const {
   // If this dialect allows unknown types, then represent this with OpaqueType.
   if (allowsUnknownTypes()) {
     auto ns = Identifier::get(getNamespace(), getContext());
     return OpaqueType::get(ns, parser.getFullSymbolSpec(), getContext());
   }
 
-  emitError(loc) << "dialect '" << getNamespace()
-                 << "' provides no type parsing hook";
+  parser.emitError(parser.getNameLoc())
+      << "dialect '" << getNamespace() << "' provides no type parsing hook";
   return Type();
 }
 
index 896c26c..b180771 100644 (file)
@@ -45,6 +45,9 @@ public:
   /// at the designated point in the input.
   void resetPointer(const char *newPointer) { curPtr = newPointer; }
 
+  /// Returns the start of the buffer.
+  const char *getBufferBegin() { return curBuffer.data(); }
+
 private:
   // Helpers.
   Token formToken(Token::Kind kind, const char *tokStart) {
index a6e0227..368f262 100644 (file)
@@ -52,16 +52,27 @@ namespace {
 class Parser;
 
 //===----------------------------------------------------------------------===//
-// AliasState
+// SymbolState
 //===----------------------------------------------------------------------===//
 
-/// This class contains record of any parsed top-level aliases.
-struct AliasState {
+/// This class contains record of any parsed top-level symbols.
+struct SymbolState {
   // A map from attribute alias identifier to Attribute.
   llvm::StringMap<Attribute> attributeAliasDefinitions;
 
   // A map from type alias identifier to Type.
   llvm::StringMap<Type> typeAliasDefinitions;
+
+  /// A set of locations into the main parser memory buffer for each of the
+  /// active nested parsers. Given that some nested parsers, i.e. custom dialect
+  /// parsers, operate on a temporary memory buffer, this provides an anchor
+  /// point for emitting diagnostics.
+  SmallVector<llvm::SMLoc, 1> nestedParserLocs;
+
+  /// The top-level lexer that contains the original memory buffer provided by
+  /// the user. This is used by nested parsers to get a properly encoded source
+  /// location.
+  Lexer *topLevelLexer = nullptr;
 };
 
 //===----------------------------------------------------------------------===//
@@ -72,9 +83,18 @@ struct AliasState {
 /// such as the current lexer position etc.
 struct ParserState {
   ParserState(const llvm::SourceMgr &sourceMgr, MLIRContext *ctx,
-              AliasState &aliases)
+              SymbolState &symbols)
       : context(ctx), lex(sourceMgr, ctx), curToken(lex.lexToken()),
-        aliases(aliases) {}
+        symbols(symbols), parserDepth(symbols.nestedParserLocs.size()) {
+    // Set the top level lexer for the symbol state if one doesn't exist.
+    if (!symbols.topLevelLexer)
+      symbols.topLevelLexer = &lex;
+  }
+  ~ParserState() {
+    // Reset the top level lexer if it refers the lexer in our state.
+    if (symbols.topLevelLexer == &lex)
+      symbols.topLevelLexer = nullptr;
+  }
   ParserState(const ParserState &) = delete;
   void operator=(const ParserState &) = delete;
 
@@ -87,8 +107,11 @@ struct ParserState {
   /// This is the next token that hasn't been consumed yet.
   Token curToken;
 
-  /// Any parsed alias state.
-  AliasState &aliases;
+  /// The current state for symbol parsing.
+  SymbolState &symbols;
+
+  /// The depth of this parser in the nested parsing stack.
+  size_t parserDepth;
 };
 
 //===----------------------------------------------------------------------===//
@@ -140,7 +163,32 @@ public:
   /// Encode the specified source location information into an attribute for
   /// attachment to the IR.
   Location getEncodedSourceLocation(llvm::SMLoc loc) {
-    return state.lex.getEncodedSourceLocation(loc);
+    // If there are no active nested parsers, we can get the encoded source
+    // location directly.
+    if (state.parserDepth == 0)
+      return state.lex.getEncodedSourceLocation(loc);
+    // Otherwise, we need to re-encode it to point to the top level buffer.
+    return state.symbols.topLevelLexer->getEncodedSourceLocation(
+        remapLocationToTopLevelBuffer(loc));
+  }
+
+  /// Remaps the given SMLoc to the top level lexer of the parser. This is used
+  /// to adjust locations of potentially nested parsers to ensure that they can
+  /// be emitted properly as diagnostics.
+  llvm::SMLoc remapLocationToTopLevelBuffer(llvm::SMLoc loc) {
+    // If there are no active nested parsers, we can return location directly.
+    SymbolState &symbols = state.symbols;
+    if (state.parserDepth == 0)
+      return loc;
+    assert(symbols.topLevelLexer && "expected valid top-level lexer");
+
+    // Otherwise, we need to remap the location to the main parser. This is
+    // simply offseting the location onto the location of the last nested
+    // parser.
+    size_t offset = loc.getPointer() - state.lex.getBufferBegin();
+    auto *rawLoc =
+        symbols.nestedParserLocs[state.parserDepth - 1].getPointer() + offset;
+    return llvm::SMLoc::getFromPointer(rawLoc);
   }
 
   //===--------------------------------------------------------------------===//
@@ -388,6 +436,11 @@ public:
   /// Return the location of the original name token.
   llvm::SMLoc getNameLoc() const override { return nameLoc; }
 
+  /// Re-encode the given source location as an MLIR location and return it.
+  Location getEncodedSourceLoc(llvm::SMLoc loc) override {
+    return parser.getEncodedSourceLocation(loc);
+  }
+
   /// Returns the full specification of the symbol being parsed. This allows
   /// for using a separate parser if necessary.
   StringRef getFullSymbolSpec() const override { return fullSpec; }
@@ -517,7 +570,7 @@ static Symbol parseExtendedSymbol(Parser &p, Token::Kind identifierTok,
       return (p.emitError("expected string literal data in dialect symbol"),
               nullptr);
     symbolData = p.getToken().getStringValue();
-    loc = p.getToken().getLoc();
+    loc = llvm::SMLoc::getFromPointer(p.getToken().getLoc().getPointer() + 1);
     p.consumeToken(Token::string);
 
     // Consume the '>'.
@@ -529,6 +582,7 @@ static Symbol parseExtendedSymbol(Parser &p, Token::Kind identifierTok,
     auto dotHalves = identifier.split('.');
     dialectName = dotHalves.first;
     auto prettyName = dotHalves.second;
+    loc = llvm::SMLoc::getFromPointer(prettyName.data());
 
     // If the dialect's symbol is followed immediately by a <, then lex the body
     // of it into prettyName.
@@ -541,8 +595,16 @@ static Symbol parseExtendedSymbol(Parser &p, Token::Kind identifierTok,
     symbolData = prettyName.str();
   }
 
+  // Record the name location of the type remapped to the top level buffer.
+  llvm::SMLoc locInTopLevelBuffer = p.remapLocationToTopLevelBuffer(loc);
+  p.getState().symbols.nestedParserLocs.push_back(locInTopLevelBuffer);
+
   // Call into the provided symbol construction function.
-  return createSymbol(dialectName, symbolData, loc);
+  Symbol sym = createSymbol(dialectName, symbolData, loc);
+
+  // Pop the last parser location.
+  p.getState().symbols.nestedParserLocs.pop_back();
+  return sym;
 }
 
 /// Parses a symbol, of type 'T', and returns it if parsing was successful. If
@@ -550,14 +612,14 @@ static Symbol parseExtendedSymbol(Parser &p, Token::Kind identifierTok,
 /// string is returned in 'numRead'.
 template <typename T, typename ParserFn>
 static T parseSymbol(llvm::StringRef inputStr, MLIRContext *context,
-                     AliasState &aliasState, ParserFn &&parserFn,
+                     SymbolState &symbolState, ParserFn &&parserFn,
                      size_t *numRead = nullptr) {
   SourceMgr sourceMgr;
   auto memBuffer = MemoryBuffer::getMemBuffer(
       inputStr, /*BufferName=*/"<mlir_parser_buffer>",
       /*RequiresNullTerminator=*/false);
   sourceMgr.AddNewSourceBuffer(std::move(memBuffer), SMLoc());
-  ParserState state(sourceMgr, context, aliasState);
+  ParserState state(sourceMgr, context, symbolState);
   Parser parser(state);
 
   Token startTok = parser.getToken();
@@ -573,8 +635,7 @@ static T parseSymbol(llvm::StringRef inputStr, MLIRContext *context,
 
     // Otherwise, ensure that all of the tokens were parsed.
   } else if (startTok.getLoc() != endTok.getLoc() && endTok.isNot(Token::eof)) {
-    parser.emitError(endTok.getLoc(),
-                     "encountered unexpected tokens after parsing");
+    parser.emitError(endTok.getLoc(), "encountered unexpected token");
     return T();
   }
   return symbol;
@@ -585,13 +646,12 @@ static T parseSymbol(llvm::StringRef inputStr, MLIRContext *context,
 //===----------------------------------------------------------------------===//
 
 InFlightDiagnostic Parser::emitError(SMLoc loc, const Twine &message) {
-  auto diag = mlir::emitError(getEncodedSourceLocation(loc), message);
-
   // If we hit a parse error in response to a lexer error, then the lexer
   // already reported the error.
   if (getToken().is(Token::error))
-    diag.abandon();
-  return diag;
+    return InFlightDiagnostic();
+
+  return mlir::emitError(getEncodedSourceLocation(loc), message);
 }
 
 //===----------------------------------------------------------------------===//
@@ -701,24 +761,22 @@ Type Parser::parseComplexType() {
 ///
 Type Parser::parseExtendedType() {
   return parseExtendedSymbol<Type>(
-      *this, Token::exclamation_identifier, state.aliases.typeAliasDefinitions,
+      *this, Token::exclamation_identifier, state.symbols.typeAliasDefinitions,
       [&](StringRef dialectName, StringRef symbolData,
           llvm::SMLoc loc) -> Type {
-        Location encodedLoc = getEncodedSourceLocation(loc);
-
         // If we found a registered dialect, then ask it to parse the type.
         if (auto *dialect = state.context->getRegisteredDialect(dialectName)) {
           return parseSymbol<Type>(
-              symbolData, state.context, state.aliases, [&](Parser &parser) {
+              symbolData, state.context, state.symbols, [&](Parser &parser) {
                 CustomDialectAsmParser customParser(symbolData, parser);
-                return dialect->parseType(customParser, encodedLoc);
+                return dialect->parseType(customParser);
               });
         }
 
         // Otherwise, form a new opaque type.
         return OpaqueType::getChecked(
             Identifier::get(dialectName, state.context), symbolData,
-            state.context, encodedLoc);
+            state.context, getEncodedSourceLocation(loc));
       });
 }
 
@@ -1315,7 +1373,7 @@ Parser::parseAttributeDict(SmallVectorImpl<NamedAttribute> &attributes) {
 ///
 Attribute Parser::parseExtendedAttr(Type type) {
   Attribute attr = parseExtendedSymbol<Attribute>(
-      *this, Token::hash_identifier, state.aliases.attributeAliasDefinitions,
+      *this, Token::hash_identifier, state.symbols.attributeAliasDefinitions,
       [&](StringRef dialectName, StringRef symbolData,
           llvm::SMLoc loc) -> Attribute {
         // Parse an optional trailing colon type.
@@ -1326,10 +1384,9 @@ Attribute Parser::parseExtendedAttr(Type type) {
         // If we found a registered dialect, then ask it to parse the attribute.
         if (auto *dialect = state.context->getRegisteredDialect(dialectName)) {
           return parseSymbol<Attribute>(
-              symbolData, state.context, state.aliases, [&](Parser &parser) {
+              symbolData, state.context, state.symbols, [&](Parser &parser) {
                 CustomDialectAsmParser customParser(symbolData, parser);
-                return dialect->parseAttribute(customParser, attrType,
-                                               getEncodedSourceLocation(loc));
+                return dialect->parseAttribute(customParser, attrType);
               });
         }
 
@@ -4242,7 +4299,7 @@ ParseResult ModuleParser::parseAttributeAliasDef() {
   StringRef aliasName = getTokenSpelling().drop_front();
 
   // Check for redefinitions.
-  if (getState().aliases.attributeAliasDefinitions.count(aliasName) > 0)
+  if (getState().symbols.attributeAliasDefinitions.count(aliasName) > 0)
     return emitError("redefinition of attribute alias id '" + aliasName + "'");
 
   // Make sure this isn't invading the dialect attribute namespace.
@@ -4261,7 +4318,7 @@ ParseResult ModuleParser::parseAttributeAliasDef() {
   if (!attr)
     return failure();
 
-  getState().aliases.attributeAliasDefinitions[aliasName] = attr;
+  getState().symbols.attributeAliasDefinitions[aliasName] = attr;
   return success();
 }
 
@@ -4274,7 +4331,7 @@ ParseResult ModuleParser::parseTypeAliasDef() {
   StringRef aliasName = getTokenSpelling().drop_front();
 
   // Check for redefinitions.
-  if (getState().aliases.typeAliasDefinitions.count(aliasName) > 0)
+  if (getState().symbols.typeAliasDefinitions.count(aliasName) > 0)
     return emitError("redefinition of type alias id '" + aliasName + "'");
 
   // Make sure this isn't invading the dialect type namespace.
@@ -4295,7 +4352,7 @@ ParseResult ModuleParser::parseTypeAliasDef() {
     return failure();
 
   // Register this alias with the parser state.
-  getState().aliases.typeAliasDefinitions.try_emplace(aliasName, aliasedType);
+  getState().symbols.typeAliasDefinitions.try_emplace(aliasName, aliasedType);
   return success();
 }
 
@@ -4374,7 +4431,7 @@ OwningModuleRef mlir::parseSourceFile(const llvm::SourceMgr &sourceMgr,
   OwningModuleRef module(ModuleOp::create(FileLineColLoc::get(
       sourceBuf->getBufferIdentifier(), /*line=*/0, /*column=*/0, context)));
 
-  AliasState aliasState;
+  SymbolState aliasState;
   ParserState state(sourceMgr, context, aliasState);
   if (ModuleParser(state).parseModule(*module))
     return nullptr;
@@ -4440,7 +4497,7 @@ OwningModuleRef mlir::parseSourceString(StringRef moduleStr,
 template <typename T, typename ParserFn>
 static T parseSymbol(llvm::StringRef inputStr, MLIRContext *context,
                      size_t &numRead, ParserFn &&parserFn) {
-  AliasState aliasState;
+  SymbolState aliasState;
   return parseSymbol<T>(
       inputStr, context, aliasState,
       [&](Parser &parser) {