[REFACTOR][IR] Move error.h into ir (#4701)
authorTianqi Chen <tqchen@users.noreply.github.com>
Tue, 14 Jan 2020 04:16:03 +0000 (20:16 -0800)
committerGitHub <noreply@github.com>
Tue, 14 Jan 2020 04:16:03 +0000 (20:16 -0800)
We will use a single ErrorReporter to report errors during
program transformations.

17 files changed:
include/tvm/ir/error.h [moved from include/tvm/relay/error.h with 66% similarity]
include/tvm/relay/expr_functor.h
include/tvm/relay/pattern_functor.h
include/tvm/relay/transform.h
src/ir/error.cc [moved from src/relay/ir/error.cc with 88% similarity]
src/relay/backend/vm/compiler.cc
src/relay/backend/vm/compiler.h
src/relay/op/tensor/transform.cc
src/relay/op/tensor/transform.h
src/relay/op/type_relations.cc
src/relay/op/type_relations.h
src/relay/pass/kind_check.cc
src/relay/pass/match_exhaustion.cc
src/relay/pass/type_infer.cc
src/relay/pass/type_solver.cc
src/relay/pass/type_solver.h
src/relay/qnn/op/concatenate.cc

similarity index 66%
rename from include/tvm/relay/error.h
rename to include/tvm/ir/error.h
index 1c91b6e..94064ae 100644 (file)
  */
 
 /*!
- * \file error.h
- * \brief The set of errors raised by Relay.
+ * \file tvm/ir/error.h
+ * \brief Utilities for error tracking and reporting.
  */
-#ifndef TVM_RELAY_ERROR_H_
-#define TVM_RELAY_ERROR_H_
+#ifndef TVM_IR_ERROR_H_
+#define TVM_IR_ERROR_H_
 
+#include <tvm/ir/span.h>
 #include <tvm/ir/module.h>
 
 #include <string>
 #include <sstream>
 #include <unordered_map>
 
-#include "./base.h"
-#include "./expr.h"
-
-
 namespace tvm {
-namespace relay {
-
-#define RELAY_ERROR(msg) (RelayErrorStream() << msg)
-
-// Forward declaratio for RelayErrorStream.
-struct Error;
-
-/*! \brief A wrapper around std::stringstream.
+/*!
+ * \brief A wrapper around std::stringstream to build error.
  *
- * This is designed to avoid platform specific
- * issues compiling and using std::stringstream
- * for error reporting.
+ * Can be consumed by Error to construct an error.
+ *
+ * \code
+ *
+ * void ReportError(const Error& err);
+ *
+ * void Test(int number) {
+ *   // Use error reporter to construct an error.
+ *   ReportError(ErrorBuilder() << "This is an error number=" << number);
+ * }
+ *
+ * \endcode
  */
-struct RelayErrorStream {
-  std::stringstream ss;
-
+struct ErrorBuilder {
+ public:
   template<typename T>
-  RelayErrorStream& operator<<(const T& t) {
-    ss << t;
+  ErrorBuilder& operator<<(const T& val) {  // NOLINT(*)
+    stream_ << val;
     return *this;
   }
 
-  std::string str() const {
-    return ss.str();
-  }
-
-  void Raise() const;
+ private:
+  std::stringstream stream_;
+  friend class Error;
 };
 
-struct Error : public dmlc::Error {
-  Span sp;
-  explicit Error(const std::string& msg) : dmlc::Error(msg), sp(nullptr) {}
-  Error(const RelayErrorStream& msg) : dmlc::Error(msg.str()), sp(nullptr) {} // NOLINT(*)
-  Error(const Error& err) : dmlc::Error(err.what()), sp(nullptr) {}
-  Error() : dmlc::Error(""), sp(nullptr) {}
+/*!
+ * \brief Custom Error class to be thrown during compilation.
+ */
+class Error : public dmlc::Error {
+ public:
+  /*! \brief Location of the error */
+  Span span;
+  /*!
+   * \brief construct error from message.
+   * \param msg The message
+   */
+  explicit Error(const std::string& msg) : dmlc::Error(msg), span(nullptr) {}
+  /*!
+   * \brief construct error from error builder.
+   * \param err The error builder
+   */
+  Error(const ErrorBuilder& err) : dmlc::Error(err.stream_.str()), span(nullptr) {} // NOLINT(*)
+  /*!
+   * \brief copy constructor.
+   * \param other The other ereor.
+   */
+  Error(const Error& other) : dmlc::Error(other.what()), span(other.span) {} // NOLINT(*)
+  /*!
+   * \brief default constructor. */
+  Error() : dmlc::Error(""), span(nullptr) {}
 };
 
-/*! \brief An abstraction around how errors are stored and reported.
+/*!
+ * \brief An abstraction around how errors are stored and reported.
  * Designed to be opaque to users, so we can support a robust and simpler
  * error reporting mode, as well as a more complex mode.
  *
@@ -94,23 +111,26 @@ struct Error : public dmlc::Error {
  */
 class ErrorReporter {
  public:
+  /*! \brief default constructor. */
   ErrorReporter() : errors_(), node_to_error_() {}
 
-  /*! \brief Report a tvm::relay::Error.
+  /*!
+   * \brief Report a tvm::Error.
    *
    * This API is useful for reporting spanned errors.
    *
    * \param err The error to report.
    */
   void Report(const Error& err) {
-    if (!err.sp.defined()) {
+    if (!err.span.defined()) {
       throw err;
     }
 
     this->errors_.push_back(err);
   }
 
-  /*! \brief Report an error against a program, using the full program
+  /*!
+   * \brief Report an error against a program, using the full program
    * error reporting strategy.
    *
    * This error reporting method requires the global function in which
@@ -121,12 +141,13 @@ class ErrorReporter {
    * \param node The expression or type to report the error at.
    * \param err The error message to report.
    */
-  inline void ReportAt(const GlobalVar& global, const ObjectRef& node, std::stringstream& err) {
+  void ReportAt(const GlobalVar& global, const ObjectRef& node, std::stringstream& err) {
     std::string err_msg = err.str();
     this->ReportAt(global, node, Error(err_msg));
   }
 
-  /*! \brief Report an error against a program, using the full program
+  /*!
+   * \brief Report an error against a program, using the full program
    * error reporting strategy.
    *
    * This error reporting method requires the global function in which
@@ -139,7 +160,8 @@ class ErrorReporter {
    */
   void ReportAt(const GlobalVar& global, const ObjectRef& node, const Error& err);
 
-  /*! \brief Render all reported errors and exit the program.
+  /*!
+   * \brief Render all reported errors and exit the program.
    *
    * This function should be used after executing a pass to render reported errors.
    *
@@ -161,7 +183,5 @@ class ErrorReporter {
   std::unordered_map<ObjectRef, GlobalVar, ObjectHash, ObjectEqual> node_to_gv_;
 };
 
-}  // namespace relay
 }  // namespace tvm
-
-#endif  // TVM_RELAY_ERROR_H_
+#endif  // TVM_IR_ERROR_H_
index f1d7152..68cef94 100644 (file)
 #define TVM_RELAY_EXPR_FUNCTOR_H_
 
 #include <tvm/node/functor.h>
+#include <tvm/ir/error.h>
+
 #include <string>
 #include <utility>
 #include <unordered_map>
+
 #include "./expr.h"
 #include "./adt.h"
 #include "./op.h"
-#include "./error.h"
+
 
 namespace tvm {
 namespace relay {
index 71a024f..6e0fb17 100644 (file)
 #define TVM_RELAY_PATTERN_FUNCTOR_H_
 
 #include <tvm/node/functor.h>
+#include <tvm/ir/error.h>
+
 #include <string>
 #include <utility>
 #include <unordered_map>
+
 #include "./expr.h"
 #include "./op.h"
-#include "./error.h"
 #include "./adt.h"
 
 namespace tvm {
index d57740c..a740ea4 100644 (file)
@@ -59,7 +59,7 @@
 #include <tvm/base.h>
 #include <tvm/packed_func_ext.h>
 #include <tvm/relay/attrs/transform.h>
-#include <tvm/relay/error.h>
+#include <tvm/ir/error.h>
 #include <tvm/relay/expr.h>
 #include <tvm/ir/module.h>
 #include <tvm/relay/op.h>
similarity index 88%
rename from src/relay/ir/error.cc
rename to src/ir/error.cc
index 5c23f33..99db14e 100644 (file)
  */
 
 /*!
- * \file error_reporter.h
- * \brief The set of errors raised by Relay.
+ * \file ir/error.cc
+ * \brief Utilities for error tracking and reporting.
  */
 
-#include <tvm/relay/expr.h>
 #include <tvm/ir/module.h>
-#include <tvm/relay/error.h>
+#include <tvm/ir/error.h>
+// NOTE on dependencies on relay AsText.
+// We calls into relay's printing module for better rendering.
+// These dependency does not happen at the interface-level.
+// And is only used to enhance developer experiences when relay
+// functions are presented.
+#include <tvm/relay/expr.h>
+
 #include <string>
 #include <vector>
 #include <rang.hpp>
 
 namespace tvm {
-namespace relay {
-
-void RelayErrorStream::Raise() const {
-  throw Error(*this);
-}
 
 template<typename T, typename U>
 using NodeMap = std::unordered_map<T, U, ObjectHash, ObjectEqual>;
@@ -43,7 +44,7 @@ void ErrorReporter::RenderErrors(const IRModule& module, bool use_color) {
   // First we pick an error reporting strategy for each error.
   // TODO(@jroesch): Spanned errors are currently not supported.
   for (auto err : this->errors_) {
-    CHECK(!err.sp.defined()) << "attempting to use spanned errors, currently not supported";
+    CHECK(!err.span.defined()) << "attempting to use spanned errors, currently not supported";
   }
 
   NodeMap<GlobalVar, NodeMap<ObjectRef, std::string>> error_maps;
@@ -110,7 +111,7 @@ void ErrorReporter::RenderErrors(const IRModule& module, bool use_color) {
     //
     // The annotation callback will annotate the error messages
     // contained in the map.
-    annotated_prog << AsText(func, false, [&err_map](tvm::relay::Expr expr) {
+    annotated_prog << relay::AsText(func, false, [&err_map](tvm::relay::Expr expr) {
       auto it = err_map.find(expr);
       if (it != err_map.end()) {
         CHECK_NE(it->second.size(), 0);
@@ -144,5 +145,4 @@ void ErrorReporter::ReportAt(const GlobalVar& global, const ObjectRef& node, con
   this->node_to_gv_.insert({ node, global });
 }
 
-}  // namespace relay
 }  // namespace tvm
index ce3972f..be0ce5a 100644 (file)
@@ -23,7 +23,7 @@
  */
 
 #include <tvm/operation.h>
-#include <tvm/relay/error.h>
+#include <tvm/ir/error.h>
 #include <tvm/relay/expr_functor.h>
 #include <tvm/relay/interpreter.h>
 #include <tvm/relay/qnn/transform.h>
index 00bde11..07e704f 100644 (file)
@@ -25,7 +25,7 @@
 #ifndef TVM_RELAY_BACKEND_VM_COMPILER_H_
 #define TVM_RELAY_BACKEND_VM_COMPILER_H_
 
-#include <tvm/relay/error.h>
+#include <tvm/ir/error.h>
 #include <tvm/relay/expr_functor.h>
 #include <tvm/relay/interpreter.h>
 #include <tvm/logging.h>
index c9d824d..aa643c4 100644 (file)
@@ -22,7 +22,7 @@
  * \brief Transform operators.
  */
 #include <tvm/relay/op.h>
-#include <tvm/relay/error.h>
+#include <tvm/ir/error.h>
 #include <tvm/relay/attrs/transform.h>
 #include <tvm/expr_operator.h>
 #include <tvm/ir.h>
@@ -392,7 +392,7 @@ bool StackRel(const Array<Type>& types,
     for (size_t j = 0; j < first->shape.size(); ++j) {
       if (j == static_cast<size_t>(axis)) continue;
       if (reporter->AssertEQ(first->shape[j], e->shape[j])) continue;
-      throw relay::Error("relay.stack requires all tensors have the same shape "
+      throw Error("relay.stack requires all tensors have the same shape "
                          "on non-stacking axes");
     }
   }
index 74a630c..a1cbf7a 100644 (file)
@@ -24,7 +24,7 @@
 #ifndef TVM_RELAY_OP_TENSOR_TRANSFORM_H_
 #define TVM_RELAY_OP_TENSOR_TRANSFORM_H_
 
-#include <tvm/relay/error.h>
+#include <tvm/ir/error.h>
 #include <vector>
 #include <algorithm>
 #include <limits>
@@ -48,10 +48,10 @@ bool ConcatenateRel(const Array<Type>& types,
   */
   const auto* tensor_tuple = types[0].as<TupleTypeNode>();
   if (tensor_tuple == nullptr) {
-    throw relay::Error(
-        RELAY_ERROR(
-          "concatenate requires a tuple of tensors as the first argument, found "
-        << PrettyPrint(types[0])));
+    throw Error(
+        ErrorBuilder()
+        << "concatenate requires a tuple of tensors as the first argument, found "
+        << PrettyPrint(types[0]));
   } else if (types[0].as<IncompleteTypeNode>() != nullptr) {
     return false;
   }
@@ -68,10 +68,10 @@ bool ConcatenateRel(const Array<Type>& types,
   // Sanity check: axis
   int axis = param->axis;
   if (!(-ndim <= axis && axis < ndim)) {
-    throw relay::Error(RELAY_ERROR(
+    throw Error(ErrorBuilder() <<
       "concatenate only accepts `axis` in [-ndim, ndim)" <<
       ", but got axis = " << axis <<
-      ", and ndim = " << ndim));
+      ", and ndim = " << ndim);
   }
   axis = axis < 0 ? ndim + axis : axis;
 
@@ -85,16 +85,16 @@ bool ConcatenateRel(const Array<Type>& types,
     int e_ndim = static_cast<int>(e->shape.size());
     const DataType& e_dtype = e->dtype;
     if (e_ndim != ndim) {
-      throw relay::Error("relay.concatenate requires all tensors have the same ndim");
+      throw Error("relay.concatenate requires all tensors have the same ndim");
     }
     if (e_dtype != dtype) {
-      throw relay::Error("relay.concatenate requires all tensors have the same dtype");
+      throw Error("relay.concatenate requires all tensors have the same dtype");
     }
     for (size_t j = 0; j < first->shape.size(); ++j) {
       if (j == static_cast<size_t>(axis)) continue;
       if (reporter->AssertEQ(first->shape[j], e->shape[j])) continue;
-      throw relay::Error("relay.concatenate requires all tensors have the same shape "
-                         "on non-concatenating axes");
+      throw Error("relay.concatenate requires all tensors have the same shape "
+                   "on non-concatenating axes");
     }
   }
 
index 13baefb..fbaf665 100644 (file)
@@ -93,9 +93,9 @@ Type ConcreteBroadcast(const TensorType& t1,
     } else if (EqualCheck(s1, s2)) {
       oshape.push_back(s1);
     } else {
-      RELAY_ERROR(
-          "Incompatible broadcast type "
-              << t1 << " and " << t2).Raise();
+      throw Error(ErrorBuilder()
+          << "Incompatible broadcast type "
+          << t1 << " and " << t2);
     }
   }
 
index f52bf78..80e555b 100644 (file)
@@ -25,7 +25,7 @@
 #ifndef TVM_RELAY_OP_TYPE_RELATIONS_H_
 #define TVM_RELAY_OP_TYPE_RELATIONS_H_
 
-#include <tvm/relay/error.h>
+#include <tvm/ir/error.h>
 #include <tvm/relay/type.h>
 #include <string>
 
index 2d207b5..55fd78a 100644 (file)
@@ -32,7 +32,7 @@
  * contains a data type such as `int`, `float`, `uint`.
  */
 #include <tvm/relay/analysis.h>
-#include <tvm/relay/error.h>
+#include <tvm/ir/error.h>
 #include "../ir/type_functor.h"
 
 namespace tvm {
@@ -55,11 +55,12 @@ struct KindChecker : TypeFunctor<Kind(const Type&)> {
                         Kind expected, const std::string& description) {
     Kind k = this->VisitType(t);
     if (k != expected) {
-      ReportFatalError(RELAY_ERROR("Incorrect kind for a " << description
-                                   << ". Type " << t << " inside " << outer
-                                   << " is of kind " << k
-                                   << " but was expected to be "
-                                   << expected));
+      ReportFatalError(ErrorBuilder()
+        << "Incorrect kind for a " << description
+        << ". Type " << t << " inside " << outer
+        << " is of kind " << k
+        << " but was expected to be "
+        << expected);
     }
   }
 
@@ -127,8 +128,9 @@ struct KindChecker : TypeFunctor<Kind(const Type&)> {
     TypeCall tc = GetRef<TypeCall>(op);
     const auto* gtv = op->func.as<GlobalTypeVarNode>();
     if (gtv == nullptr) {
-      ReportFatalError(RELAY_ERROR("The callee in " << tc
-                                   << " is not a global type var, but is " << op->func));
+      ReportFatalError(
+        ErrorBuilder() <<"The callee in " << tc
+        << " is not a global type var, but is " << op->func);
     }
 
     CheckKindMatches(op->func, tc, Kind::kAdtHandle, "type call function");
@@ -141,8 +143,9 @@ struct KindChecker : TypeFunctor<Kind(const Type&)> {
     auto var = GetRef<GlobalTypeVar>(gtv);
     auto data = mod->LookupTypeDef(var);
     if (data->type_vars.size() != op->args.size()) {
-      ReportFatalError(RELAY_ERROR("Expected " << data->type_vars.size() << "arguments for " << tc
-                                   << "; got " << op->args.size()));
+      ReportFatalError(ErrorBuilder()
+        << "Expected " << data->type_vars.size() << "arguments for " << tc
+        << "; got " << op->args.size());
     }
     return Kind::kType;
   }
@@ -161,8 +164,9 @@ struct KindChecker : TypeFunctor<Kind(const Type&)> {
 
     for (const auto& con : op->constructors) {
       if (!con->belong_to.same_as(op->header)) {
-        ReportFatalError(RELAY_ERROR(con << " has header " << con->belong_to
-                                     << " but " << op << " has header " << op->header));
+        ReportFatalError(ErrorBuilder()
+          <<con << " has header " << con->belong_to
+          << " but " << op << " has header " << op->header);
       }
 
       for (const Type& t : con->inputs) {
index 161b682..885c47e 100644 (file)
@@ -28,7 +28,7 @@
  * dynamic error unless exhaustiveness is checked in advance.
  */
 #include <tvm/relay/adt.h>
-#include <tvm/relay/error.h>
+#include <tvm/ir/error.h>
 #include <tvm/relay/expr_functor.h>
 #include <tvm/relay/pattern_functor.h>
 #include <stack>
index 876cf48..faf42ab 100644 (file)
@@ -38,7 +38,7 @@
  * constraints we will trigger an error.
  */
 
-#include <tvm/relay/error.h>
+#include <tvm/ir/error.h>
 #include <tvm/relay/expr_functor.h>
 #include <tvm/relay/pattern_functor.h>
 #include <tvm/relay/analysis.h>
@@ -144,11 +144,12 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)>,
     } catch (const dmlc::Error &e) {
       this->ReportFatalError(
         expr,
-        RELAY_ERROR("Error unifying `"
+        ErrorBuilder()
+          << "Error unifying `"
           << t1
           << "` and `"
           << t2
-          << "`: " << e.what()));
+          << "`: " << e.what());
       return Type();
     }
   }
@@ -188,9 +189,9 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)>,
     if (!mod_.defined()) {
       this->ReportFatalError(
         GetRef<GlobalVar>(op),
-        RELAY_ERROR(
+        ErrorBuilder() <<
           "Cannot do type inference on global variables " \
-          "without a module"));
+          "without a module");
     }
     Expr e = mod_->Lookup(var);
     return e->checked_type();
@@ -239,16 +240,18 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)>,
 
     auto* tc = unified.as<TypeCallNode>();
     if (!tc) {
-      this->ReportFatalError(pc, RELAY_ERROR("Expected a type call, got " << unified));
+      this->ReportFatalError(pc, ErrorBuilder() << "Expected a type call, got " << unified);
     }
     if (td->header != tc->func) {
-      this->ReportFatalError(pc, RELAY_ERROR("ADT headers must match, but we have "
-                                             << td->header << " and " << tc->func));
+      this->ReportFatalError(pc,
+        ErrorBuilder() << "ADT headers must match, but we have "
+                        << td->header << " and " << tc->func);
     }
     if (td->type_vars.size() != tc->args.size()) {
-      this->ReportFatalError(pc, RELAY_ERROR("The number of type args must match"
-                                             << "the number of type vars in the type data: "
-                                             << td->type_vars.size() << " != " << tc->args.size()));
+      this->ReportFatalError(pc,
+        ErrorBuilder() << "The number of type args must match"
+                       << "the number of type vars in the type data: "
+                       << td->type_vars.size() << " != " << tc->args.size());
     }
     std::unordered_map<TypeVar, Type, ObjectHash, ObjectEqual> type_var_map_;
     for (size_t i = 0; i < td->type_vars.size(); ++i) {
@@ -256,9 +259,10 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)>,
     }
     CHECK(con->constructor->inputs.size() == con->patterns.size()) << "not enough pattern";
     if (con->constructor->inputs.size() != con->patterns.size()) {
-      this->ReportFatalError(pc, RELAY_ERROR("Not enough inputs for the constructor; "
-                                             << "expected " << con->constructor->inputs.size()
-                                             << ", got " << con->patterns.size()));
+      this->ReportFatalError(pc,
+        ErrorBuilder() << "Not enough inputs for the constructor; "
+                       << "expected " << con->constructor->inputs.size()
+                       << ", got " << con->patterns.size());
     }
     for (size_t i = 0; i < con->constructor->inputs.size(); ++i) {
       VisitPattern(con->patterns[i], Bind(con->constructor->inputs[i], type_var_map_));
@@ -278,7 +282,7 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)>,
 
     auto* tt = unified.as<TupleTypeNode>();
     if (!tt) {
-      this->ReportFatalError(pt, RELAY_ERROR("Expected a tuple type, got " << unified));
+      this->ReportFatalError(pt, ErrorBuilder() << "Expected a tuple type, got " << unified);
     }
     CHECK(tup->patterns.size() == tt->fields.size()) << "not enough pattern";
     for (size_t i = 0; i < tup->patterns.size(); ++i) {
@@ -310,7 +314,7 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)>,
       Match match = GetRef<Match>(op);
       Array<Pattern> unmatched_cases = UnmatchedCases(match, this->mod_);
       if (unmatched_cases.size() != 0) {
-        RelayErrorStream ss;
+        ErrorBuilder ss;
         ss << "match expression does not handle the following cases: ";
         int i = 0;
         for (auto cs : unmatched_cases) {
@@ -454,8 +458,9 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)>,
     if (fn_ty_node == nullptr && inc_ty_node == nullptr) {
       this->ReportFatalError(
         GetRef<Call>(call),
-        RELAY_ERROR("only expressions with function types can be called, found "
-        << ftype));
+        ErrorBuilder()
+          << "only expressions with function types can be called, found "
+          << ftype);
     }
 
     // incomplete type => it must be a function taking the arg types
@@ -470,11 +475,12 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)>,
     Array<Type> type_args = call->type_args;
     if (type_args.size() > fn_ty_node->type_params.size()) {
       this->ReportFatalError(GetRef<Call>(call),
-        RELAY_ERROR("Incorrect number of type args in "
+        ErrorBuilder()
+          << "Incorrect number of type args in "
           << call->span << ": "
           << "Expected "
           << fn_ty_node->type_params.size()
-          << "but got " << type_args.size()));
+          << "but got " << type_args.size());
     }
 
     FuncType fn_ty = InstantiateFuncType(fn_ty_node, type_args);
@@ -488,13 +494,15 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)>,
       if (type_arity < number_of_args) {
         this->ReportFatalError(
           GetRef<Call>(call),
-          RELAY_ERROR("the function is provided too many arguments "
-          << "expected " << type_arity << ", found " << number_of_args));
+          ErrorBuilder()
+            << "the function is provided too many arguments "
+            << "expected " << type_arity << ", found " << number_of_args);
       } else {
         this->ReportFatalError(
           GetRef<Call>(call),
-          RELAY_ERROR("the function is provided too few arguments "
-          << "expected " << type_arity << ", found " << number_of_args));
+          ErrorBuilder()
+            << "the function is provided too few arguments "
+            << "expected " << type_arity << ", found " << number_of_args);
       }
     }
 
index 372b351..d0d8b43 100644 (file)
@@ -124,10 +124,11 @@ class TypeSolver::Unifier : public TypeFunctor<Type(const Type&, const Type&)> {
     } else {
       Type resolved = this->VisitType(lhs->resolved_type, rhs->resolved_type);
       if (!resolved.defined()) {
-        solver_->ReportError(RELAY_ERROR("unable to unify: "
-                                         << "`" << PrettyPrint(lhs->resolved_type) << "` and `"
-                                         << PrettyPrint(rhs->resolved_type) << "`"),
-                             this->loc);
+        solver_->ReportError(
+          ErrorBuilder() << "unable to unify: "
+                         << "`" << PrettyPrint(lhs->resolved_type) << "` and `"
+                         << PrettyPrint(rhs->resolved_type) << "`",
+          this->loc);
         return lhs->resolved_type;
       } else {
         TypeNode* top = solver_->GetTypeNode(resolved);
@@ -225,13 +226,13 @@ class TypeSolver::Unifier : public TypeFunctor<Type(const Type&, const Type&)> {
     tvm::Array<IndexExpr> shape;
     if (tt1->shape.size() != tt2->shape.size()) {
       this->solver_->ReportError(
-        RELAY_ERROR(
+        ErrorBuilder() <<
           "tensor type `" << PrettyPrint(tt1) <<
           "` has " <<  tt1->shape.size() <<
           " dimensions, while `" <<
           PrettyPrint(tt2) <<
           "` has " << tt2->shape.size() <<
-          " dimensions"), this->loc);
+          " dimensions", this->loc);
       return Type(nullptr);
     }
 
@@ -253,7 +254,7 @@ class TypeSolver::Unifier : public TypeFunctor<Type(const Type&, const Type&)> {
     }
 
     if (mismatches.size() != 0) {
-      RelayErrorStream err;
+      ErrorBuilder err;
       err << "in particular ";
       for (auto mismatch : mismatches) {
         err << "dimension "
@@ -639,10 +640,11 @@ bool TypeSolver::Solve() {
       rnode->resolved = false;
     } catch (const dmlc::Error& err) {
       rnode->resolved = false;
-      this->ReportError(RELAY_ERROR("an internal invariant was violated while "
-                                    "typechecking your program "
-                                    << err.what()),
-                        rnode->location);
+      this->ReportError(
+        ErrorBuilder() << "an internal invariant was violated while "
+                       << "typechecking your program "
+                       << err.what(),
+        rnode->location);
     }
 
     // Mark inqueue as false after the function call
index eba1bea..00a43ec 100644 (file)
@@ -27,7 +27,7 @@
 #include <tvm/relay/expr.h>
 #include <tvm/relay/type.h>
 #include <tvm/relay/analysis.h>
-#include <tvm/relay/error.h>
+#include <tvm/ir/error.h>
 #include <vector>
 #include <queue>
 #include <unordered_map>
index 685fb9f..26093f2 100644 (file)
@@ -41,9 +41,10 @@ bool QnnConcatenateRel(const Array<Type>& types, int num_inputs, const Attrs& at
   // Check the scale and zero point types
   const auto* input_scales_tuple = types[1].as<TupleTypeNode>();
   if (input_scales_tuple == nullptr) {
-    throw relay::Error(
-        RELAY_ERROR("qnn concatenate requires a tuple of scales as the second argument, found "
-                    << PrettyPrint(types[1])));
+    throw Error(
+        ErrorBuilder()
+        << "qnn concatenate requires a tuple of scales as the second argument, found "
+        << PrettyPrint(types[1]));
   }
   for (const auto& input_scale : input_scales_tuple->fields) {
     CHECK(IsScalarType(input_scale, DataType::Float(32)));  // input_scales[idx]
@@ -51,9 +52,10 @@ bool QnnConcatenateRel(const Array<Type>& types, int num_inputs, const Attrs& at
 
   const auto* input_zero_points_tuple = types[2].as<TupleTypeNode>();
   if (input_zero_points_tuple == nullptr) {
-    throw relay::Error(
-        RELAY_ERROR("qnn concatenate requires a tuple of zero_points as the third argument, found "
-                    << PrettyPrint(types[2])));
+    throw Error(
+        ErrorBuilder()
+        << "qnn concatenate requires a tuple of zero_points as the third argument, found "
+        << PrettyPrint(types[2]));
   }
   for (const auto& input_zero_point : input_zero_points_tuple->fields) {
     CHECK(IsScalarType(input_zero_point, DataType::Int(32)));  // input_zero_points[idx]