Add support for using the new diagnostics infrastructure in the parser. This...
authorRiver Riddle <riverriddle@google.com>
Wed, 8 May 2019 17:29:50 +0000 (10:29 -0700)
committerMehdi Amini <joker.eph@gmail.com>
Sat, 11 May 2019 02:24:05 +0000 (19:24 -0700)
--

PiperOrigin-RevId: 247239436

mlir/include/mlir/IR/Diagnostics.h
mlir/include/mlir/IR/OpDefinition.h
mlir/include/mlir/IR/OpImplementation.h
mlir/lib/IR/Diagnostics.cpp
mlir/lib/Parser/Parser.cpp
mlir/test/IR/invalid.mlir

index 361f2c8..39c28b9 100644 (file)
@@ -24,6 +24,7 @@
 
 #include "mlir/IR/Location.h"
 #include "mlir/Support/LLVM.h"
+#include "mlir/Support/STLExtras.h"
 #include "llvm/ADT/SmallString.h"
 #include "llvm/ADT/Twine.h"
 #include <functional>
@@ -180,7 +181,8 @@ public:
 
   /// Stream operator for inserting new diagnostic arguments.
   template <typename Arg>
-  typename std::enable_if<!std::is_convertible<Arg, StringRef>::value,
+  typename std::enable_if<!std::is_convertible<Arg, Twine>::value ||
+                              std::is_integral<Arg>::value,
                           Diagnostic &>::type
   operator<<(Arg &&val) {
     arguments.push_back(DiagnosticArgument(std::forward<Arg>(val)));
@@ -190,15 +192,52 @@ public:
     arguments.push_back(DiagnosticArgument(val));
     return *this;
   }
+  Diagnostic &operator<<(char val) { return *this << Twine(val); }
   Diagnostic &operator<<(const Twine &val) {
-    llvm::SmallString<0> str;
-    arguments.push_back(DiagnosticArgument(val.toStringRef(str)));
-    stringArguments.emplace_back(std::move(str), arguments.size());
+    // Allocate memory to hold this string.
+    llvm::SmallString<0> data;
+    auto strRef = val.toStringRef(data);
+    strings.push_back(std::unique_ptr<char[]>(new char[strRef.size()]));
+    memcpy(&strings.back()[0], strRef.data(), strRef.size());
+
+    // Add the new string to the argument list.
+    strRef = StringRef(&strings.back()[0], strRef.size());
+    arguments.push_back(DiagnosticArgument(strRef));
     return *this;
   }
+
   /// Stream in an Identifier.
   Diagnostic &operator<<(Identifier val);
 
+  /// Stream in a range.
+  template <typename T> Diagnostic &operator<<(llvm::iterator_range<T> range) {
+    return appendRange(range);
+  }
+  template <typename T> Diagnostic &operator<<(llvm::ArrayRef<T> range) {
+    return appendRange(range);
+  }
+
+  /// Append a range to the diagnostic. The default delimiter between elements
+  /// is ','.
+  template <typename T, template <typename> class Container>
+  Diagnostic &appendRange(const Container<T> &c, const char *delim = ", ") {
+    interleave(
+        c, [&](T a) { *this << a; }, [&]() { *this << delim; });
+    return *this;
+  }
+
+  /// Append arguments to the diagnostic.
+  template <typename Arg1, typename Arg2, typename... Args>
+  Diagnostic &append(Arg1 &&arg1, Arg2 &&arg2, Args &&... args) {
+    append(std::forward<Arg1>(arg1));
+    return append(std::forward<Arg2>(arg2), std::forward<Args>(args)...);
+  }
+  /// Append one argument to the diagnostic.
+  template <typename Arg> Diagnostic &append(Arg &&arg) {
+    *this << std::forward<Arg>(arg);
+    return *this;
+  }
+
   /// Outputs this diagnostic to a stream.
   void print(raw_ostream &os) const;
 
@@ -234,10 +273,9 @@ private:
   /// The current list of arguments.
   SmallVector<DiagnosticArgument, 4> arguments;
 
-  /// A list of string values used as arguments and the corresponding index of
-  /// those arguments. This is used to guarantee the liveness of non-constant
-  /// strings used in diagnostics.
-  std::vector<std::pair<llvm::SmallString<0>, unsigned>> stringArguments;
+  /// A list of string values used as arguments. This is used to guarantee the
+  /// liveness of non-constant strings used in diagnostics.
+  std::vector<std::unique_ptr<char[]>> strings;
 
   /// A list of attached notes.
   NoteVector notes;
@@ -262,6 +300,7 @@ public:
       : owner(rhs.owner), impl(std::move(rhs.impl)) {
     // Reset the rhs diagnostic.
     rhs.impl.reset();
+    rhs.abandon();
   }
   ~InFlightDiagnostic() {
     if (isInFlight())
@@ -280,20 +319,20 @@ public:
 
   /// Attaches a note to this diagnostic.
   Diagnostic &attachNote(llvm::Optional<Location> noteLoc = llvm::None) {
-    assert(isInFlight() && "diagnostic not inflight");
+    assert(isActive() && "diagnostic not active");
     return impl->attachNote(noteLoc);
   }
 
   /// Reports the diagnostic to the engine.
   void report();
 
+  /// Abandons this diagnostic so that it will no longer be reported.
+  void abandon();
+
   /// Allow an inflight diagnostic to be converted to 'failure', otherwise
   /// 'success' if this is an empty diagnostic.
   operator LogicalResult() const;
 
-  /// Returns if the diagnostic is still in flight.
-  bool isInFlight() const { return impl.hasValue(); }
-
 private:
   InFlightDiagnostic &operator=(const InFlightDiagnostic &) = delete;
   InFlightDiagnostic &operator=(InFlightDiagnostic &&) = delete;
@@ -302,10 +341,17 @@ private:
 
   /// Add an argument to the internal diagnostic.
   template <typename Arg> void appendArgument(Arg &&arg) {
-    assert(isInFlight() && "diagnostic not inflight");
-    *impl << std::forward<Arg>(arg);
+    assert(isActive() && "diagnostic not active");
+    if (isInFlight())
+      *impl << std::forward<Arg>(arg);
   }
 
+  /// Returns if the diagnostic is still active, i.e. it has a live diagnostic.
+  bool isActive() const { return impl.hasValue(); }
+
+  /// Returns if the diagnostic is still in flight to be reported.
+  bool isInFlight() const { return owner; }
+
   // Allow access to the constructor.
   friend DiagnosticEngine;
 
index 986b9a4..c42050c 100644 (file)
@@ -62,6 +62,9 @@ class ParseResult : public LogicalResult {
 public:
   ParseResult(LogicalResult result = success()) : LogicalResult(result) {}
 
+  // Allow diagnostics emitted during parsing to be converted to failure.
+  ParseResult(const InFlightDiagnostic &) : LogicalResult(failure()) {}
+
   /// Failure is true in a boolean context.
   explicit operator bool() const { return failed(*this); }
 };
index 8d15921..0b39d4c 100644 (file)
@@ -209,7 +209,7 @@ public:
   /// Parse a keyword.
   ParseResult parseKeyword(const char *keyword, const Twine &msg = "") {
     if (parseOptionalKeyword(keyword))
-      return emitError(getNameLoc(), "expected '" + Twine(keyword) + "'" + msg);
+      return emitError(getNameLoc(), "expected '") << keyword << "'" << msg;
     return success();
   }
 
@@ -370,9 +370,9 @@ public:
                                       ArrayRef<Type> types, llvm::SMLoc loc,
                                       SmallVectorImpl<Value *> &result) {
     if (operands.size() != types.size())
-      return emitError(loc, Twine(operands.size()) +
-                                " operands present, but expected " +
-                                Twine(types.size()));
+      return emitError(loc)
+             << operands.size() << " operands present, but expected "
+             << types.size();
 
     for (unsigned i = 0, e = operands.size(); i != e; ++i)
       if (resolveOperand(operands[i], types[i], result))
@@ -386,7 +386,8 @@ public:
                                           Function *&result) = 0;
 
   /// Emit a diagnostic at the specified location and return failure.
-  virtual ParseResult emitError(llvm::SMLoc loc, const Twine &message) = 0;
+  virtual InFlightDiagnostic emitError(llvm::SMLoc loc,
+                                       const Twine &message = {}) = 0;
 };
 
 } // end namespace mlir
index 877be72..27c602d 100644 (file)
@@ -113,17 +113,23 @@ Diagnostic &Diagnostic::attachNote(llvm::Optional<Location> noteLoc) {
 /// Allow an inflight diagnostic to be converted to 'failure', otherwise
 /// 'success' if this is an empty diagnostic.
 InFlightDiagnostic::operator LogicalResult() const {
-  return failure(isInFlight());
+  return failure(isActive());
 }
 
 /// Reports the diagnostic to the engine.
 void InFlightDiagnostic::report() {
+  // If this diagnostic is still inflight and it hasn't been abandoned, then
+  // report it.
   if (isInFlight()) {
     owner->emit(*impl);
-    impl.reset();
+    owner = nullptr;
   }
+  impl.reset();
 }
 
+/// Abandons this diagnostic.
+void InFlightDiagnostic::abandon() { owner = nullptr; }
+
 //===----------------------------------------------------------------------===//
 // DiagnosticEngineImpl
 //===----------------------------------------------------------------------===//
index d637994..245a50f 100644 (file)
@@ -122,10 +122,10 @@ public:
   }
 
   /// Emit an error and return failure.
-  ParseResult emitError(const Twine &message) {
+  InFlightDiagnostic emitError(const Twine &message = {}) {
     return emitError(state.curToken.getLoc(), message);
   }
-  ParseResult emitError(SMLoc loc, const Twine &message);
+  InFlightDiagnostic emitError(SMLoc loc, const Twine &message = {});
 
   /// Advance the current lexer onto the next token.
   void consumeToken() {
@@ -240,14 +240,14 @@ private:
 // Helper methods.
 //===----------------------------------------------------------------------===//
 
-ParseResult Parser::emitError(SMLoc loc, const Twine &message) {
+InFlightDiagnostic Parser::emitError(SMLoc loc, const Twine &message) {
+  auto diag = getContext()->emitError(getEncodedSourceLocation(loc), message);
+
   // If we hit a parse error in response to a lexer error, then the lexer
   // already reported the error.
   if (getToken().is(Token::error))
-    return failure();
-
-  getContext()->emitError(getEncodedSourceLocation(loc), message);
-  return failure();
+    diag.abandon();
+  return diag;
 }
 
 /// Consume the specified token if present and return success.  On failure,
@@ -1328,16 +1328,12 @@ Attribute Parser::parseAttribute(Type type) {
       auto sameElementNum =
           indicesType.getDimSize(0) == valuesType.getDimSize(0);
       if (!sameShape || !sameElementNum) {
-        std::string str;
-        llvm::raw_string_ostream s(str);
-        s << "expected shape ([";
-        interleaveComma(type.getShape(), s);
-        s << "]); inferred shape of indices literal ([";
-        interleaveComma(indicesType.getShape(), s);
-        s << "]); inferred shape of values literal ([";
-        interleaveComma(valuesType.getShape(), s);
-        s << "])";
-        return (emitError(s.str()), nullptr);
+        emitError() << "expected shape ([" << type.getShape()
+                    << "]); inferred shape of indices literal (["
+                    << indicesType.getShape()
+                    << "]); inferred shape of values literal (["
+                    << valuesType.getShape() << "])";
+        return nullptr;
       }
 
       if (parseToken(Token::greater, "expected '>'"))
@@ -1398,14 +1394,10 @@ DenseElementsAttr Parser::parseDenseElementsAttr(VectorOrTensorType type) {
     return nullptr;
 
   if (literalParser.getShape() != type.getShape()) {
-    std::string str;
-    llvm::raw_string_ostream s(str);
-    s << "inferred shape of elements literal ([";
-    interleaveComma(literalParser.getShape(), s);
-    s << "]) does not match type ([";
-    interleaveComma(type.getShape(), s);
-    s << "])";
-    return (emitError(s.str()), nullptr);
+    emitError() << "inferred shape of elements literal (["
+                << literalParser.getShape() << "]) does not match type (["
+                << type.getShape() << "])";
+    return nullptr;
   }
 
   return builder.getDenseElementsAttr(type, literalParser.getValues())
@@ -2038,7 +2030,7 @@ ParseResult AffineParser::parseIdentifierDefinition(AffineExpr idExpr) {
   auto name = getTokenSpelling();
   for (auto entry : dimsAndSymbols) {
     if (entry.first == name)
-      return emitError("redefinition of identifier '" + Twine(name) + "'");
+      return emitError("redefinition of identifier '" + name + "'");
   }
   consumeToken(Token::bare_identifier);
 
@@ -2743,8 +2735,8 @@ ParseResult FunctionParser::parseOptionalSSAUseAndTypeList(
     return failure();
 
   if (valueIDs.size() != types.size())
-    return emitError("expected " + Twine(valueIDs.size()) +
-                     " types to match operand list");
+    return emitError("expected ")
+           << valueIDs.size() << " types to match operand list";
 
   results.reserve(valueIDs.size());
   for (unsigned i = 0, e = valueIDs.size(); i != e; ++i) {
@@ -2946,9 +2938,9 @@ ParseResult FunctionParser::parseOperation() {
     if (op->getNumResults() == 0)
       return emitError(loc, "cannot name an operation with no results");
     if (numExpectedResults != op->getNumResults())
-      return emitError(loc, "operation defines " + Twine(op->getNumResults()) +
-                                " results but was provided " +
-                                Twine(numExpectedResults) + " to bind");
+      return emitError(loc, "operation defines ")
+             << op->getNumResults() << " results but was provided "
+             << numExpectedResults << " to bind";
 
     // If the number of result names matches the number of operation results, we
     // can directly use the provided names.
@@ -3066,9 +3058,9 @@ Operation *FunctionParser::parseGenericOperation() {
   auto operandTypes = fnType.getInputs();
   if (operandTypes.size() != operandInfos.size()) {
     auto plural = "s"[operandInfos.size() == 1];
-    return (emitError(typeLoc, "expected " + llvm::utostr(operandInfos.size()) +
-                                   " operand type" + plural + " but had " +
-                                   llvm::utostr(operandTypes.size())),
+    return (emitError(typeLoc, "expected ")
+                << operandInfos.size() << " operand type" << plural
+                << " but had " << operandTypes.size(),
             nullptr);
   }
 
@@ -3158,8 +3150,8 @@ public:
       return parseOperandList(result, requiredOperandCount, delimiter);
     }
     if (requiredOperandCount != -1)
-      return emitError(parser.getToken().getLoc(),
-                       "expected " + Twine(requiredOperandCount) + " operands");
+      return emitError(parser.getToken().getLoc(), "expected ")
+             << requiredOperandCount << " operands";
     return success();
   }
 
@@ -3322,8 +3314,8 @@ public:
     }
 
     if (requiredOperandCount != -1 && result.size() != requiredOperandCount)
-      return emitError(startLoc,
-                       "expected " + Twine(requiredOperandCount) + " operands");
+      return emitError(startLoc, "expected ")
+             << requiredOperandCount << " operands";
     return success();
   }
 
@@ -3369,8 +3361,10 @@ public:
       return failure();
     if (auto defLoc = parser.getDefinitionLoc(argument.name, argument.number)) {
       parser.emitError(argument.location,
-                       "redefinition of SSA value '" + argument.name + "'");
-      return parser.emitError(*defLoc, "previously defined here");
+                       "redefinition of SSA value '" + argument.name + "'")
+              .attachNote(parser.getEncodedSourceLocation(*defLoc))
+          << "previously defined here";
+      return failure();
     }
     return success();
   }
@@ -3395,10 +3389,9 @@ public:
   }
 
   /// Emit a diagnostic at the specified location and return failure.
-  ParseResult emitError(llvm::SMLoc loc, const Twine &message) override {
+  InFlightDiagnostic emitError(llvm::SMLoc loc, const Twine &message) override {
     emittedError = true;
-    return parser.emitError(loc,
-                            "custom op '" + Twine(opName) + "' " + message);
+    return parser.emitError(loc, "custom op '" + opName + "' " + message);
   }
 
   bool didEmitError() const { return emittedError; }
@@ -3750,8 +3743,7 @@ ParseResult ModuleParser::parseFunc() {
 
   // Verify no name collision / redefinition.
   if (function->getName() != name)
-    return emitError(loc,
-                     "redefinition of function named '" + name.str() + "'");
+    return emitError(loc, "redefinition of function named '") << name << "'";
 
   // Parse an optional trailing location.
   if (parseOptionalTrailingLocation(function))
@@ -3795,8 +3787,8 @@ ParseResult ModuleParser::finalizeModule() {
     // Resolve the reference.
     auto *resolvedFunction = getModule()->getNamedFunction(name);
     if (!resolvedFunction) {
-      forwardRef.second->emitError("reference to undefined function '" +
-                                   name.str() + "'");
+      forwardRef.second->emitError("reference to undefined function '")
+          << name << "'";
       return failure();
     }
 
index 1288fe0..020b056 100644 (file)
@@ -420,7 +420,7 @@ func @undef() {
 // -----
 
 func @duplicate_induction_var() {
-  affine.for %i = 1 to 10 {   // expected-error {{previously defined here}}
+  affine.for %i = 1 to 10 {   // expected-note {{previously defined here}}
     affine.for %i = 1 to 10 { // expected-error {{redefinition of SSA value '%i'}}
     }
   }