[TIR][Printer] text format printer considering future parsing use (#5483)
authorBohan Hou <32121147+spectrometerHBH@users.noreply.github.com>
Sat, 9 May 2020 15:35:58 +0000 (23:35 +0800)
committerGitHub <noreply@github.com>
Sat, 9 May 2020 15:35:58 +0000 (08:35 -0700)
include/tvm/ir/module.h
src/printer/meta_data.h
src/printer/relay_text_printer.cc
src/printer/text_printer.cc [new file with mode: 0644]
src/printer/text_printer.h [new file with mode: 0644]
src/printer/tir_text_printer.cc [new file with mode: 0644]
tests/python/unittest/test_arith_deduce_bound.py
tests/python/unittest/test_te_schedule.py
tests/python/unittest/test_tir_nodes.py

index d113860..ae78383 100644 (file)
@@ -363,7 +363,7 @@ TVM_DLL String PrettyPrint(const ObjectRef& node);
  * \return The text representation.
  */
 TVM_DLL String AsText(const ObjectRef& node,
-                           bool show_meta_data = true,
-                           runtime::TypedPackedFunc<String(ObjectRef)> annotate = nullptr);
+                      bool show_meta_data = true,
+                      runtime::TypedPackedFunc<String(ObjectRef)> annotate = nullptr);
 }  // namespace tvm
 #endif  // TVM_IR_MODULE_H_
index d390692..8bf58ec 100644 (file)
@@ -109,6 +109,15 @@ class TextMetaDataContext {
   }
 
   /*!
+   * \brief Test whether a node has been put in meta
+   * \param node The query node
+   * \return whether the node has been put in meta
+   */
+  bool InMeta(const ObjectRef& node) {
+    return meta_repr_.find(node) != meta_repr_.end();
+  }
+
+  /*!
    * \brief Print a key value pair
    */
   Doc PrintKeyValue(const std::string& str, const Doc& v) const {
index 2e675c8..9e6abee 100644 (file)
@@ -18,7 +18,7 @@
  */
 
 /*!
- * \file text_format_printer.cc
+ * \file relay_text_printer.cc
  * \brief Printer to print out the IR text format
  *        that can be parsed by a parser.
  *
 #include "meta_data.h"
 #include "../relay/analysis/dependency_graph.h"
 #include "../ir/attr_functor.h"
+#include "text_printer.h"
 
 namespace tvm {
 namespace relay {
 
-class RelayTextPrinter :
-    public ExprFunctor<Doc(const Expr&)>,
-    public PatternFunctor<Doc(const Pattern&)>,
-    public TypeFunctor<Doc(const Type&)>,
-    public AttrFunctor<Doc(const ObjectRef&)> {
- public:
-  explicit RelayTextPrinter(bool show_meta_data,
-                            runtime::TypedPackedFunc<std::string(ObjectRef)> annotate)
-      : show_meta_data_(show_meta_data),
-        annotate_(annotate) {}
-
-  /*!
-    * \brief Print additional info about expr in comment.
-    * \param expr The expression.
-    */
-  Doc PrintOptionalInfo(const Expr& expr) {
-    Doc doc;
-    // default annotations
-    if (annotate_ == nullptr) {
-      if ((expr.as<ConstantNode>() || expr.as<CallNode>()) && expr->checked_type_.defined()) {
-        doc << " /* ty=" << Print(expr->checked_type()) << " */";
-      }
-    } else {
-      std::string annotated_expr = annotate_(expr);
-      if (annotated_expr != "") {
-        doc << annotated_expr;
-      }
+/*!
+  * \brief Print additional info about expr in comment.
+  * \param expr The expression.
+  */
+Doc RelayTextPrinter::PrintOptionalInfo(const Expr& expr) {
+  Doc doc;
+  // default annotations
+  if (annotate_ == nullptr) {
+    if ((expr.as<ConstantNode>() || expr.as<CallNode>()) && expr->checked_type_.defined()) {
+      doc << " /* ty=" << Print(expr->checked_type()) << " */";
+    }
+  } else {
+    std::string annotated_expr = annotate_(expr);
+    if (annotated_expr != "") {
+      doc << annotated_expr;
     }
-
-    return doc;
   }
 
-  // indent a new body
-  Doc PrintBody(const ObjectRef& node, int indent = 2) {
-    Doc doc;
-    Doc body;
-    doc << "{";
-    doc << Doc::Indent(indent, body << Doc::NewLine() << PrintScope(node)) << Doc::NewLine();
-    doc << "}";
-    return doc;
-  }
+  return doc;
+}
 
-  // create a new scope by creating a new printer object. This allows temp var
-  // numbers to be reused and prevents hoisted vars from escaping too far
-  Doc PrintScope(const ObjectRef& node) {
-    // print in a new scope
-    doc_stack_.push_back(Doc());
-    // must print first so doc_stack_.back() reference doesn't become stale
-    Doc doc = Print(node, false, true);
-    doc = doc_stack_.back() << doc;
-    doc_stack_.pop_back();
-    return doc;
+// indent a new body
+Doc RelayTextPrinter::PrintBody(const ObjectRef& node, int indent) {
+  Doc doc;
+  Doc body;
+  doc << "{";
+  doc << Doc::Indent(indent, body << Doc::NewLine() << PrintScope(node)) << Doc::NewLine();
+  doc << "}";
+  return doc;
+}
+
+// create a new scope by creating a new printer object. This allows temp var
+// numbers to be reused and prevents hoisted vars from escaping too far
+Doc RelayTextPrinter::PrintScope(const ObjectRef& node) {
+  // print in a new scope
+  doc_stack_.push_back(Doc());
+  // must print first so doc_stack_.back() reference doesn't become stale
+  Doc doc = Print(node, false, true);
+  doc = doc_stack_.back() << doc;
+  doc_stack_.pop_back();
+  return doc;
+}
+
+Doc RelayTextPrinter::PrintFinal(const ObjectRef& node) {
+  if (node->IsInstance<BaseFuncNode>() &&
+      !node->IsInstance<relay::FunctionNode>()) {
+    // Temporarily skip non-relay functions.
+    // TODO(tvm-team) enhance the code to work for all functions
+  } else if (node.as<ExprNode>()) {
+    Expr expr = Downcast<Expr>(node);
+    dg_ = DependencyGraph::Create(&arena_, expr);
   }
 
-  Doc PrintFinal(const ObjectRef& node) {
-    if (node->IsInstance<BaseFuncNode>() &&
-        !node->IsInstance<relay::FunctionNode>()) {
-      // Temporarily skip non-relay functions.
-      // TODO(tvm-team) enhance the code to work for all functions
-    } else if (node.as<ExprNode>()) {
-      Expr expr = Downcast<Expr>(node);
-      dg_ = DependencyGraph::Create(&arena_, expr);
-    }
+  Doc doc;
+  doc << PrintScope(node);
+  return doc;
+}
 
-    Doc doc;
-    doc << PrintScope(node);
-    if (!meta_.empty()) {
-      doc << Doc::NewLine();
-      if (show_meta_data_) {
-        // append meta data in the end.
-        doc << "METADATA:" << Doc::NewLine() << meta_.GetMetaSection();
-      } else {
-        doc << "// meta data omitted. you can use show_meta_data=True to include meta data";
-      }
-    }
-    return doc;
+Doc RelayTextPrinter::Print(const ObjectRef& node, bool meta, bool try_inline) {
+  bool is_non_relay_func =
+      node->IsInstance<BaseFuncNode>() &&
+      !node->IsInstance<relay::FunctionNode>();
+  if (node.as<ExprNode>() && !is_non_relay_func) {
+    return PrintExpr(Downcast<Expr>(node), meta, try_inline);
+  } else if (node.as<TypeNode>()) {
+    return PrintType(Downcast<Type>(node), meta);
+  } else if (node.as<PatternNode>()) {
+    return PrintPattern(Downcast<Pattern>(node), meta);
+  } else if (node.as<IRModuleNode>()) {
+    return PrintMod(Downcast<IRModule>(node));
+  } else {
+    // default module.
+    std::ostringstream os;
+    os << node;
+    return Doc::RawText(os.str());
   }
+}
 
-  std::vector<Doc> PrintCallAttrs(const Attrs& attrs, const Expr& op);
-  std::vector<Doc> PrintFuncAttrs(const Attrs& attrs);
-
-  Doc Print(const ObjectRef& node, bool meta = false, bool try_inline = false) {
-    bool is_non_relay_func =
-        node->IsInstance<BaseFuncNode>() &&
-        !node->IsInstance<relay::FunctionNode>();
-    if (node.as<ExprNode>() && !is_non_relay_func) {
-      return PrintExpr(Downcast<Expr>(node), meta, try_inline);
-    } else if (node.as<TypeNode>()) {
-      return PrintType(Downcast<Type>(node), meta);
-    } else if (node.as<PatternNode>()) {
-      return PrintPattern(Downcast<Pattern>(node), meta);
-    } else if (node.as<IRModuleNode>()) {
-      return PrintMod(Downcast<IRModule>(node));
-    } else {
-      // default module.
+Doc RelayTextPrinter::TempVar(int n) {
+  Doc doc;
+  return doc << "%" << n;
+}
+
+Doc RelayTextPrinter::AllocTemp() {
+  return TempVar(temp_var_counter_++);
+}
+
+/*!
+  * \brief get a unique name with the corresponding prefix
+  * \param prefix The prefix of the name
+  * \return The returned name.
+  */
+Doc RelayTextPrinter::GetUniqueName(const std::string& prefix) {
+  std::string unique_prefix = prefix;
+  auto it = name_alloc_map_.find(prefix);
+  if (it != name_alloc_map_.end()) {
+    while (true) {
       std::ostringstream os;
-      os << node;
-      return Doc::RawText(os.str());
+      os << prefix << (++it->second);
+      std::string name = os.str();
+      if (name_alloc_map_.count(name) == 0) {
+        unique_prefix = name;
+        break;
+      }
     }
   }
+  name_alloc_map_[unique_prefix] = 0;
+  return Doc::Text(unique_prefix);
+}
 
-  Doc TempVar(int n) {
-    Doc doc;
-    return doc << "%" << n;
-  }
-
-  Doc AllocTemp() {
-    return TempVar(temp_var_counter_++);
-  }
-
-  /*!
-    * \brief get a unique name with the corresponding prefix
-    * \param prefix The prefix of the name
-    * \return The returned name.
-    */
-  Doc GetUniqueName(const std::string& prefix) {
-    std::string unique_prefix = prefix;
-    auto it = name_alloc_map_.find(prefix);
-    if (it != name_alloc_map_.end()) {
-      while (true) {
-        std::ostringstream os;
-        os << prefix << (++it->second);
-        std::string name = os.str();
-        if (name_alloc_map_.count(name) == 0) {
-          unique_prefix = name;
-          break;
-        }
-      }
-    }
-    name_alloc_map_[unique_prefix] = 0;
-    return Doc::Text(unique_prefix);
-  }
-
-  Doc Print(Kind k) {
-    switch (k) {
-    case kType:
-      return Doc::Text("Type");
-    case kShapeVar:
-      return Doc::Text("Shape");
-    case kBaseType:
-      return Doc::Text("BaseType");
-    case kConstraint:
-      return Doc::Text("Constraint");
-    case kAdtHandle:
-      return Doc::Text("AdtHandle");
-    case kTypeData:
-      return Doc::Text("TypeData");
-    default:
-      LOG(ERROR) << "Unknown Kind";
-      throw;
-    }
+Doc RelayTextPrinter::Print(Kind k) {
+  switch (k) {
+  case kType:
+    return Doc::Text("Type");
+  case kShapeVar:
+    return Doc::Text("Shape");
+  case kBaseType:
+    return Doc::Text("BaseType");
+  case kConstraint:
+    return Doc::Text("Constraint");
+  case kAdtHandle:
+    return Doc::Text("AdtHandle");
+  case kTypeData:
+    return Doc::Text("TypeData");
+  default:
+    LOG(ERROR) << "Unknown Kind";
+    throw;
   }
-  /*!
-   * \brief Allocate name to a type variable.
-   * \param var The input type variable.
-   * \return The corresponding name.
-   */
-  Doc AllocTypeVar(const TypeVar& var) {
-    if (memo_type_.count(var)) {
-      Doc val = memo_type_[var];
-      val << "-malformed-ir";
-      return val;
-    }
-    std::string name = var->name_hint;
-    if (name.length() == 0 || !std::isalpha(name[0])) {
-      name = "t" + name;
-    }
-    Doc val = GetUniqueName(name);
-    memo_type_[var] = val;
-    if (var->kind != kType) {
-      val << ": " << Print(var->kind);
-    }
+}
+/*!
+ * \brief Allocate name to a type variable.
+ * \param var The input type variable.
+ * \return The corresponding name.
+ */
+Doc RelayTextPrinter::AllocTypeVar(const TypeVar& var) {
+  if (memo_type_.count(var)) {
+    Doc val = memo_type_[var];
+    val << "-malformed-ir";
     return val;
   }
+  std::string name = var->name_hint;
+  if (name.length() == 0 || !std::isalpha(name[0])) {
+    name = "t" + name;
+  }
+  Doc val = GetUniqueName(name);
+  memo_type_[var] = val;
+  if (var->kind != kType) {
+    val << ": " << Print(var->kind);
+  }
+  return val;
+}
 
-  /*!
-   * \brief Allocate name to a variable.
-   * \param var The input variable.
-   * \return The corresponding name.
-   */
-  Doc AllocVar(const Var& var) {
-    // still print if ir is malformed, but show the error.
-    if (memo_.count(var)) {
-      Doc val = memo_[var];
-      val << "-malformed-ir";
-      return val;
-    }
-    std::string name = var->name_hint();
-    // always make sure first name is alpha
-    if (name.length() == 0 || !std::isalpha(name[0])) {
-      name = "v" + name;
-    }
-    Doc val = GetUniqueName("%" + name);
-    memo_[var] = val;
-    if (var->type_annotation.defined()) {
-      val << ": " << Print(var->type_annotation);
-    }
+/*!
+ * \brief Allocate name to a variable.
+ * \param var The input variable.
+ * \return The corresponding name.
+ */
+Doc RelayTextPrinter::AllocVar(const Var& var) {
+  // still print if ir is malformed, but show the error.
+  if (memo_.count(var)) {
+    Doc val = memo_[var];
+    val << "-malformed-ir";
     return val;
   }
-
-  bool IsUnique(const Expr& expr) {
-    auto it = dg_.expr_node.find(expr);
-    if (it == dg_.expr_node.end()) {
-      return true;
-    } else {
-      return !(it->second->parents.head && it->second->parents.head->next);
-    }
+  std::string name = var->name_hint();
+  // always make sure first name is alpha
+  if (name.length() == 0 || !std::isalpha(name[0])) {
+    name = "v" + name;
   }
+  Doc val = GetUniqueName("%" + name);
+  memo_[var] = val;
+  if (var->type_annotation.defined()) {
+    val << ": " << Print(var->type_annotation);
+  }
+  return val;
+}
 
-  bool AlwaysInline(const Expr& expr) {
-    return expr.as<GlobalVarNode>() || expr.as<ConstantNode>() || expr.as<OpNode>() ||
-           expr.as<VarNode>() || expr.as<ConstructorNode>();
+bool RelayTextPrinter::IsUnique(const Expr& expr) {
+  auto it = dg_.expr_node.find(expr);
+  if (it == dg_.expr_node.end()) {
+    return true;
+  } else {
+    return !(it->second->parents.head && it->second->parents.head->next);
   }
+}
 
-  //------------------------------------
-  // Overload of Expr printing functions
-  //------------------------------------
-  Doc PrintExpr(const Expr& expr, bool meta, bool try_inline) {
-    // Exploit memoization to print GNF.
-    // The first time we visit an expression, we need to allocate a temp var
-    // for it. Every subsequent time we can just use its assigned variable.
-    // This works since hashing uses pointer equality.
+bool RelayTextPrinter::AlwaysInline(const Expr& expr) {
+  return expr.as<GlobalVarNode>() || expr.as<ConstantNode>() || expr.as<OpNode>() ||
+         expr.as<VarNode>() || expr.as<ConstructorNode>();
+}
 
-    // determine whether to inline
-    bool inline_expr = AlwaysInline(expr);
-    if (try_inline) {
-      inline_expr |= IsUnique(expr);
-    }
+//------------------------------------
+// Overload of Expr printing functions
+//------------------------------------
+Doc RelayTextPrinter::PrintExpr(const Expr& expr, bool meta, bool try_inline) {
+  // Exploit memoization to print GNF.
+  // The first time we visit an expression, we need to allocate a temp var
+  // for it. Every subsequent time we can just use its assigned variable.
+  // This works since hashing uses pointer equality.
+
+  // determine whether to inline
+  bool inline_expr = AlwaysInline(expr);
+  if (try_inline) {
+    inline_expr |= IsUnique(expr);
+  }
+
+  auto it = memo_.find(expr);
+  if (it != memo_.end()) return it->second;
+
+  Doc printed_expr;
+  if (meta) {
+    printed_expr = meta_->GetMetaNode(GetRef<ObjectRef>(expr.get()));
+  } else if (!inline_expr && expr.as<LetNode>()) {
+    // wrap GNFed let in brackets
+    Doc body;
+    printed_expr << "(";
+    printed_expr << Doc::Indent(2, body << Doc::NewLine() << VisitExpr(expr)) << Doc::NewLine();
+    printed_expr << ")";
+  } else {
+    printed_expr = VisitExpr(expr);
+  }
 
-    auto it = memo_.find(expr);
-    if (it != memo_.end()) return it->second;
-
-    Doc printed_expr;
-    if (meta) {
-      printed_expr = meta_.GetMetaNode(GetRef<ObjectRef>(expr.get()));
-    } else if (!inline_expr && expr.as<LetNode>()) {
-      // wrap GNFed let in brackets
-      Doc body;
-      printed_expr << "(";
-      printed_expr << Doc::Indent(2, body << Doc::NewLine() << VisitExpr(expr)) << Doc::NewLine();
-      printed_expr << ")";
-    } else {
-      printed_expr = VisitExpr(expr);
-    }
+  printed_expr << PrintOptionalInfo(expr);
 
-    printed_expr << PrintOptionalInfo(expr);
-
-    // add expr to doc
-    if (expr.as<VarNode>()) {
-      // This is our first time visiting the var and we hit the VarNode case
-      // in the visitor. Thus the variable is free.
-      doc_stack_.back() << "free_var " << printed_expr << Doc::NewLine();
-      // Memoization is done in AllocVar.
-      return memo_[expr];
-    } else if (inline_expr) {
-      memo_[expr] = printed_expr;
-      return printed_expr;
-    } else {
-      Doc temp_var = AllocTemp();
-      memo_[expr] = temp_var;
-      doc_stack_.back() << temp_var << " = " << printed_expr << ";" << Doc::NewLine();
-      return temp_var;
-    }
+  // add expr to doc
+  if (expr.as<VarNode>()) {
+    // This is our first time visiting the var and we hit the VarNode case
+    // in the visitor. Thus the variable is free.
+    doc_stack_.back() << "free_var " << printed_expr << Doc::NewLine();
+    // Memoization is done in AllocVar.
+    return memo_[expr];
+  } else if (inline_expr) {
+    memo_[expr] = printed_expr;
+    return printed_expr;
+  } else {
+    Doc temp_var = AllocTemp();
+    memo_[expr] = temp_var;
+    doc_stack_.back() << temp_var << " = " << printed_expr << ";" << Doc::NewLine();
+    return temp_var;
   }
+}
+
+// Should only be triggered when op is a free variable being visited for the
+// first time.
+Doc RelayTextPrinter::VisitExpr_(const VarNode* op) {
+  return AllocVar(GetRef<Var>(op));
+}
 
-  // Should only be triggered when op is a free variable being visited for the
-  // first time.
-  Doc VisitExpr_(const VarNode* op) final {
-    return AllocVar(GetRef<Var>(op));
+/*!
+ * \brief special method to print out const scalar
+ * \param dtype The data type
+ * \param value The value to be printed.
+ */
+template<typename T>
+Doc RelayTextPrinter::ScalarLiteral(DataType dtype, const T& value) {
+  std::ostringstream os;
+  if (dtype == DataType::Int(32)) {
+    os << value;
+  } else if (dtype == DataType::Float(32)) {
+    os << value << 'f';
+  } else if (dtype == DataType::Float(64)) {
+    os << value;
+  } else if (dtype == DataType::Bool()) {
+    return Doc::PyBoolLiteral(value != 0);
+  } else {
+    os << value;
   }
+  return Doc::Text(os.str());
+}
 
-  /*!
-   * \brief special method to print out const scalar
-   * \param dtype The data type
-   * \param value The value to be printed.
-   */
-  template<typename T>
-  static Doc ScalarLiteral(DataType dtype, const T& value) {
+Doc RelayTextPrinter::VisitExpr_(const ConstantNode* op) {
+  // Print out simple scalars directly.
+  if (op->is_scalar()) {
     std::ostringstream os;
+    DataType dtype = DataType(op->data->dtype);
+    CHECK_EQ(op->data->ctx.device_type, kDLCPU);
     if (dtype == DataType::Int(32)) {
-      os << value;
+      return ScalarLiteral(dtype, static_cast<const int32_t*>(op->data->data)[0]);
+    } else if (dtype == DataType::Int(64)) {
+      return ScalarLiteral(dtype, static_cast<const int64_t*>(op->data->data)[0]);
     } else if (dtype == DataType::Float(32)) {
-      os << value << 'f';
+      return ScalarLiteral(dtype, static_cast<const float*>(op->data->data)[0]);
     } else if (dtype == DataType::Float(64)) {
-      os << value;
+      return ScalarLiteral(dtype, static_cast<const double*>(op->data->data)[0]);
     } else if (dtype == DataType::Bool()) {
-      return Doc::PyBoolLiteral(value != 0);
-    } else {
-      os << value;
+      return ScalarLiteral(dtype, static_cast<const uint8_t*>(op->data->data)[0]);
     }
-    return Doc::Text(os.str());
   }
+  // default fall-back, record it as meta node.
+  Doc doc;
+  return doc << Print(GetRef<ObjectRef>(op), true);
+}
 
-  Doc VisitExpr_(const ConstantNode* op) final {
-    // Print out simple scalars directly.
-    if (op->is_scalar()) {
-      std::ostringstream os;
-      DataType dtype = DataType(op->data->dtype);
-      CHECK_EQ(op->data->ctx.device_type, kDLCPU);
-      if (dtype == DataType::Int(32)) {
-        return ScalarLiteral(dtype, static_cast<const int32_t*>(op->data->data)[0]);
-      } else if (dtype == DataType::Int(64)) {
-        return ScalarLiteral(dtype, static_cast<const int64_t*>(op->data->data)[0]);
-      } else if (dtype == DataType::Float(32)) {
-        return ScalarLiteral(dtype, static_cast<const float*>(op->data->data)[0]);
-      } else if (dtype == DataType::Float(64)) {
-        return ScalarLiteral(dtype, static_cast<const double*>(op->data->data)[0]);
-      } else if (dtype == DataType::Bool()) {
-        return ScalarLiteral(dtype, static_cast<const uint8_t*>(op->data->data)[0]);
-      }
-    }
-    // default fall-back, record it as meta node.
-    Doc doc;
-    return doc << Print(GetRef<ObjectRef>(op), true);
+Doc RelayTextPrinter::VisitExpr_(const TupleNode* op) {
+  std::vector<Doc> fields;
+  for (Expr field : op->fields) {
+    fields.push_back(Print(field));
   }
+  Doc doc;
+  doc << "(" << Doc::Concat(fields);
+  // conform to python tuple format (1,)
+  if (op->fields.size() == 1) {
+    doc << ",";
+  }
+  return doc << ")";
+}
 
-  Doc VisitExpr_(const TupleNode* op) final {
-    std::vector<Doc> fields;
-    for (Expr field : op->fields) {
-      fields.push_back(Print(field));
-    }
-    Doc doc;
-    doc << "(" << Doc::Concat(fields);
-    // conform to python tuple format (1,)
-    if (op->fields.size() == 1) {
-      doc << ",";
+Doc RelayTextPrinter::VisitExpr_(const TupleGetItemNode* op) {
+  Doc doc;
+  return doc << Print(op->tuple) << "." << op->index;
+}
+
+Doc RelayTextPrinter::VisitExpr_(const IfNode* op) {
+  Doc doc;
+  doc << "if (" << Print(op->cond) << ") ";
+  doc << PrintBody(op->true_branch);
+  doc << " else ";
+  doc << PrintBody(op->false_branch);
+  return doc;
+}
+
+Doc RelayTextPrinter::VisitExpr_(const LetNode* op) {
+  Doc doc;
+  doc
+    << "let "
+    << AllocVar(op->var)
+    << " = "
+    << Print(op->value, false, true)
+    << ";"
+    << Doc::NewLine();
+  // we use a scope here so GNF hoisting doesn't escape too far
+  // and nested, unique lets are not hoisted
+  doc << PrintScope(op->body);
+  return doc;
+}
+
+Doc RelayTextPrinter::PrintFunc(const Doc& prefix, const relay::Function& fn) {
+  Doc doc;
+  doc << prefix;
+  if (fn->type_params.size() > 0) {
+    doc << "[";
+    std::vector<Doc> type_params;
+    for (const TypeVar& tv : fn->type_params) {
+      type_params.push_back(Doc::Text(tv->name_hint));
     }
-    return doc << ")";
+    doc << Doc::Concat(type_params);
+    doc << "]";
   }
-
-  Doc VisitExpr_(const TupleGetItemNode* op) final {
-    Doc doc;
-    return doc << Print(op->tuple) << "." << op->index;
+  doc << "(";
+  std::vector<Doc> params;
+  for (Var param : fn->params) {
+    params.push_back(AllocVar(param));
   }
-
-  Doc VisitExpr_(const IfNode* op) final {
-    Doc doc;
-    doc << "if (" << Print(op->cond) << ") ";
-    doc << PrintBody(op->true_branch);
-    doc << " else ";
-    doc << PrintBody(op->false_branch);
-    return doc;
+  for (const Doc& d : PrintFuncAttrs(fn->attrs)) {
+    params.push_back(d);
   }
-
-  Doc VisitExpr_(const LetNode* op) final {
-    Doc doc;
-    doc
-      << "let "
-      << AllocVar(op->var)
-      << " = "
-      << Print(op->value, false, true)
-      << ";"
-      << Doc::NewLine();
-    // we use a scope here so GNF hoisting doesn't escape too far
-    // and nested, unique lets are not hoisted
-    doc << PrintScope(op->body);
-    return doc;
+  doc << Doc::Concat(params) << ") ";
+  if (fn->ret_type.defined()) {
+    doc << "-> " << Print(fn->ret_type) << " ";
   }
+  doc << PrintBody(fn->body);
+  return doc;
+}
 
-  Doc PrintFunc(const Doc& prefix, const relay::Function& fn) {
+Doc RelayTextPrinter::PrintFunc(const Doc& prefix, const BaseFunc& base_func) {
+  if (auto* n = base_func.as<relay::FunctionNode>()) {
+    return PrintFunc(prefix, GetRef<relay::Function>(n));
+  } else if (auto* n = base_func.as<tir::PrimFuncNode>()) {
+    std::ostringstream os;
+    os << GetRef<tir::PrimFunc>(n);
+    return Doc::RawText(os.str());
+  } else {
+    // def @xyz = meta['ExternalFunc'][id]
     Doc doc;
-    doc << prefix;
-    if (fn->type_params.size() > 0) {
-      doc << "[";
-      std::vector<Doc> type_params;
-      for (const TypeVar& tv : fn->type_params) {
-        type_params.push_back(Doc::Text(tv->name_hint));
-      }
-      doc << Doc::Concat(type_params);
-      doc << "]";
-    }
-    doc << "(";
-    std::vector<Doc> params;
-    for (Var param : fn->params) {
-      params.push_back(AllocVar(param));
-    }
-    for (const Doc& d : PrintFuncAttrs(fn->attrs)) {
-      params.push_back(d);
-    }
-    doc << Doc::Concat(params) << ") ";
-    if (fn->ret_type.defined()) {
-      doc << "-> " << Print(fn->ret_type) << " ";
-    }
-    doc << PrintBody(fn->body);
+    doc << prefix << " = " <<  meta_->GetMetaNode(base_func);
     return doc;
   }
+}
 
-  Doc PrintFunc(const Doc& prefix, const BaseFunc& base_func) {
-    if (auto* n = base_func.as<relay::FunctionNode>()) {
-      return PrintFunc(prefix, GetRef<relay::Function>(n));
-    } else if (auto* n = base_func.as<tir::PrimFuncNode>()) {
-      std::ostringstream os;
-      os << GetRef<tir::PrimFunc>(n);
-      return Doc::RawText(os.str());
-    } else {
-      // def @xyz = meta['ExternalFunc'][id]
-      Doc doc;
-      doc << prefix << " = " <<  meta_.GetMetaNode(base_func);
-      return doc;
+Doc RelayTextPrinter::PrintMod(const IRModule& mod) {
+  Doc doc;
+  int counter = 0;
+  // type definitions
+  for (const auto& kv : mod->type_definitions) {
+    if (counter++ != 0) {
+      doc << Doc::NewLine();
     }
+    doc << Print(kv.second);
+    doc << Doc::NewLine();
   }
-
-  Doc PrintMod(const IRModule& mod) {
-    Doc doc;
-    int counter = 0;
-    // type definitions
-    for (const auto& kv : mod->type_definitions) {
-      if (counter++ != 0) {
-        doc << Doc::NewLine();
-      }
-      doc << Print(kv.second);
-      doc << Doc::NewLine();
+  // functions
+  for (const auto& kv : mod->functions) {
+    if (kv.second.as<relay::FunctionNode>()) {
+      dg_ = DependencyGraph::Create(&arena_, kv.second);
     }
-    // functions
-    for (const auto& kv : mod->functions) {
-      if (kv.second.as<relay::FunctionNode>()) {
-        dg_ = DependencyGraph::Create(&arena_, kv.second);
-      }
-      if (counter++ != 0) {
-        doc << Doc::NewLine();
-      }
-      std::ostringstream os;
-      os << "def @" << kv.first->name_hint;
-      doc << PrintFunc(Doc::Text(os.str()), kv.second);
+    if (counter++ != 0) {
       doc << Doc::NewLine();
     }
-    return doc;
-  }
-
-  Doc VisitExpr_(const FunctionNode* op) final {
-    return PrintFunc(Doc::Text("fn "), GetRef<Function>(op));
+    std::ostringstream os;
+    os << "def @" << kv.first->name_hint;
+    doc << PrintFunc(Doc::Text(os.str()), kv.second);
+    doc << Doc::NewLine();
   }
+  return doc;
+}
 
-  Doc VisitExpr_(const GlobalVarNode* op) final {
-    return Doc::Text('@' + op->name_hint);
-  }
+Doc RelayTextPrinter::VisitExpr_(const FunctionNode* op) {
+  return PrintFunc(Doc::Text("fn "), GetRef<Function>(op));
+}
 
-  Doc VisitExpr_(const OpNode* op) final {
-    return Doc::Text(op->name);
-  }
+Doc RelayTextPrinter::VisitExpr_(const GlobalVarNode* op) {
+  return Doc::Text('@' + op->name_hint);
+}
 
-  Doc VisitExpr_(const CallNode* op) final {
-    Doc doc;
-    // visit args first so they are lifted before the op
-    // this places op closer to its call site
-    std::vector<Doc> args;
-    for (const Expr& arg : op->args) {
-      args.push_back(Print(arg));
-    }
-    for (const Doc& d : PrintCallAttrs(op->attrs, op->op)) {
-      args.push_back(d);
-    }
-    const auto* cons_node = op->op.as<ConstructorNode>();
-    if (cons_node) {
-      doc << cons_node->name_hint;
-    } else {
-      doc << Print(op->op);
-    }
+Doc RelayTextPrinter::VisitExpr_(const OpNode* op) {
+  return Doc::Text(op->name);
+}
 
-    if (cons_node && cons_node->inputs.size() == 0) {
-      // don't print as a call if it's a 0-arity cons
-      return doc;
-    } else {
-      return doc << "(" << Doc::Concat(args) << ")";
-    }
+Doc RelayTextPrinter::VisitExpr_(const CallNode* op) {
+  Doc doc;
+  // visit args first so they are lifted before the op
+  // this places op closer to its call site
+  std::vector<Doc> args;
+  for (const Expr& arg : op->args) {
+    args.push_back(Print(arg));
+  }
+  for (const Doc& d : PrintCallAttrs(op->attrs, op->op)) {
+    args.push_back(d);
+  }
+  const auto* cons_node = op->op.as<ConstructorNode>();
+  if (cons_node) {
+    doc << cons_node->name_hint;
+  } else {
+    doc << Print(op->op);
   }
 
-  Doc VisitExpr_(const RefCreateNode* op) final {
-    Doc doc;
-    return doc << "ref(" << Print(op->value) << ")";
+  if (cons_node && cons_node->inputs.size() == 0) {
+    // don't print as a call if it's a 0-arity cons
+    return doc;
+  } else {
+    return doc << "(" << Doc::Concat(args) << ")";
   }
+}
 
-  Doc VisitExpr_(const RefReadNode* op) final {
-    Doc doc;
-    return doc << Print(op->ref) << "^";
-  }
+Doc RelayTextPrinter::VisitExpr_(const RefCreateNode* op) {
+  Doc doc;
+  return doc << "ref(" << Print(op->value) << ")";
+}
 
-  Doc VisitExpr_(const RefWriteNode* op) final {
-    Doc doc;
-    return doc << "(" << Print(op->ref) << " := " << Print(op->value) << ")";
-  }
+Doc RelayTextPrinter::VisitExpr_(const RefReadNode* op) {
+  Doc doc;
+  return doc << Print(op->ref) << "^";
+}
 
-  Doc VisitExpr_(const MatchNode* op) final {
-    // TODO(jmp): Lots of code duplication here because PrintBody and PrintScope don't accept Docs.
-    Doc doc;
-    Doc body;
-    doc << "match";
-    if (!op->complete) {
-      doc << "?";
-    }
-    doc << " (" << Print(op->data) << ") {";
-    std::vector<Doc> clause_docs;
-    for (const auto& clause : op->clauses) {
-      Doc clause_doc;
-      clause_doc << PrintPattern(clause->lhs, false) << " => ";
-      Doc rhs_doc = PrintScope(clause->rhs);
-      if (clause->rhs.as<LetNode>()) {
-        // only add braces if there are multiple lines on the rhs
-        rhs_doc = Doc::Brace("{", rhs_doc, "}");
-      }
-      clause_doc << rhs_doc << ",";
-      clause_docs.push_back(clause_doc);
-    }
-    doc << Doc::Indent(2, body << Doc::NewLine() << Doc::Concat(clause_docs, Doc::NewLine()))
-        << Doc::NewLine() << "}";
-    return doc;
-  }
+Doc RelayTextPrinter::VisitExpr_(const RefWriteNode* op) {
+  Doc doc;
+  return doc << "(" << Print(op->ref) << " := " << Print(op->value) << ")";
+}
 
-  Doc PrintPattern(const Pattern& pattern, bool meta) {
-    auto it = memo_pattern_.find(pattern);
-    if (it != memo_pattern_.end()) return it->second;
-    Doc printed_pattern;
-    if (meta) {
-      printed_pattern = meta_.GetMetaNode(GetRef<ObjectRef>(pattern.get()));
-    } else {
-      printed_pattern = VisitPattern(pattern);
-    }
-    memo_pattern_[pattern] = printed_pattern;
-    return printed_pattern;
-  }
+Doc RelayTextPrinter::VisitExpr_(const MatchNode* op) {
+  // TODO(jmp): Lots of code duplication here because PrintBody and PrintScope don't accept Docs.
+  Doc doc;
+  Doc body;
+  doc << "match";
+  if (!op->complete) {
+    doc << "?";
+  }
+  doc << " (" << Print(op->data) << ") {";
+  std::vector<Doc> clause_docs;
+  for (const auto& clause : op->clauses) {
+    Doc clause_doc;
+    clause_doc << PrintPattern(clause->lhs, false) << " => ";
+    Doc rhs_doc = PrintScope(clause->rhs);
+    if (clause->rhs.as<LetNode>()) {
+      // only add braces if there are multiple lines on the rhs
+      rhs_doc = Doc::Brace("{", rhs_doc, "}");
+    }
+    clause_doc << rhs_doc << ",";
+    clause_docs.push_back(clause_doc);
+  }
+  doc << Doc::Indent(2, body << Doc::NewLine() << Doc::Concat(clause_docs, Doc::NewLine()))
+      << Doc::NewLine() << "}";
+  return doc;
+}
 
-  Doc VisitPattern_(const PatternConstructorNode* p) final {
-    Doc doc;
-    doc << p->constructor->name_hint;
-    if (!p->patterns.empty()) {
-      doc << "(";
-      std::vector<Doc> pats;
-      for (const auto& pat : p->patterns) {
-        pats.push_back(Print(pat));
-      }
-      doc << Doc::Concat(pats) << ")";
-    }
-    return doc;
+Doc RelayTextPrinter::PrintPattern(const Pattern& pattern, bool meta) {
+  auto it = memo_pattern_.find(pattern);
+  if (it != memo_pattern_.end()) return it->second;
+  Doc printed_pattern;
+  if (meta) {
+    printed_pattern = meta_->GetMetaNode(GetRef<ObjectRef>(pattern.get()));
+  } else {
+    printed_pattern = VisitPattern(pattern);
   }
+  memo_pattern_[pattern] = printed_pattern;
+  return printed_pattern;
+}
 
-  Doc VisitPattern_(const PatternTupleNode* pt) final {
-    Doc doc;
+Doc RelayTextPrinter::VisitPattern_(const PatternConstructorNode* p) {
+  Doc doc;
+  doc << p->constructor->name_hint;
+  if (!p->patterns.empty()) {
     doc << "(";
     std::vector<Doc> pats;
-    for (const auto& pat : pt->patterns) {
+    for (const auto& pat : p->patterns) {
       pats.push_back(Print(pat));
     }
     doc << Doc::Concat(pats) << ")";
-    return doc;
   }
+  return doc;
+}
 
-  Doc VisitPattern_(const PatternWildcardNode* pw) final {
-    return Doc::Text("_");
+Doc RelayTextPrinter::VisitPattern_(const PatternTupleNode* pt) {
+  Doc doc;
+  doc << "(";
+  std::vector<Doc> pats;
+  for (const auto& pat : pt->patterns) {
+    pats.push_back(Print(pat));
   }
+  doc << Doc::Concat(pats) << ")";
+  return doc;
+}
 
-  Doc VisitPattern_(const PatternVarNode* pv) final {
-    return AllocVar(pv->var);
-  }
+Doc RelayTextPrinter::VisitPattern_(const PatternWildcardNode* pw) {
+  return Doc::Text("_");
+}
 
-  Doc VisitExpr_(const ConstructorNode* n) final {
-    Doc doc;
-    doc << n->name_hint;
-    if (in_adt_def_ && n->inputs.size() != 0) {
-      doc << "(";
-      std::vector<Doc> inputs;
-      for (Type input : n->inputs) {
-        inputs.push_back(Print(input));
-      }
-      doc << Doc::Concat(inputs) << ")";
-    }
-    return doc;
-  }
+Doc RelayTextPrinter::VisitPattern_(const PatternVarNode* pv) {
+  return AllocVar(pv->var);
+}
 
-  //------------------------------------
-  // Overload of Type printing functions
-  //------------------------------------
-  Doc PrintType(const Type& type, bool meta) {
-    auto it = memo_type_.find(type);
-    if (it != memo_type_.end()) return it->second;
-    Doc printed_type;
-    if (meta) {
-      printed_type = meta_.GetMetaNode(GetRef<ObjectRef>(type.get()));
-    } else {
-      printed_type = VisitType(type);
+Doc RelayTextPrinter::VisitExpr_(const ConstructorNode* n) {
+  Doc doc;
+  doc << n->name_hint;
+  if (in_adt_def_ && n->inputs.size() != 0) {
+    doc << "(";
+    std::vector<Doc> inputs;
+    for (Type input : n->inputs) {
+      inputs.push_back(Print(input));
     }
-    memo_type_[type] = printed_type;
-    return printed_type;
+    doc << Doc::Concat(inputs) << ")";
   }
+  return doc;
+}
 
-  Doc VisitTypeDefault_(const Object* node) final {
-    // by default always print as meta data
-    return Print(GetRef<ObjectRef>(node), true);
+//------------------------------------
+// Overload of Type printing functions
+//------------------------------------
+Doc RelayTextPrinter::PrintType(const Type& type, bool meta) {
+  auto it = memo_type_.find(type);
+  if (it != memo_type_.end()) return it->second;
+  Doc printed_type;
+  if (meta) {
+    printed_type = meta_->GetMetaNode(GetRef<ObjectRef>(type.get()));
+  } else {
+    printed_type = VisitType(type);
   }
+  memo_type_[type] = printed_type;
+  return printed_type;
+}
 
-  Doc VisitType_(const TypeVarNode* node) final {
-    return Doc::Text(node->name_hint);
-  }
+Doc RelayTextPrinter::VisitTypeDefault_(const Object* node) {
+  // by default always print as meta data
+  return Print(GetRef<ObjectRef>(node), true);
+}
 
-  Doc VisitType_(const GlobalTypeVarNode* node) final {
-    return Doc::Text(node->name_hint);
-  }
+Doc RelayTextPrinter::VisitType_(const TypeVarNode* node) {
+  return Doc::Text(node->name_hint);
+}
 
-  Doc VisitType_(const TypeCallNode* node) final {
-    Doc doc = PrintType(node->func, false);
-    std::vector<Doc> args;
-    for (const Type& t : node->args) {
-      args.push_back(PrintType(t, false));
-    }
-    doc << "[";
-    doc << Doc::Concat(args);
-    doc << "]";
-    return doc;
-  }
+Doc RelayTextPrinter::VisitType_(const GlobalTypeVarNode* node) {
+  return Doc::Text(node->name_hint);
+}
 
-  Doc PrintDType(DataType dtype) {
-    return Doc::Text(runtime::DLDataType2String(dtype));
+Doc RelayTextPrinter::VisitType_(const TypeCallNode* node) {
+  Doc doc = PrintType(node->func, false);
+  std::vector<Doc> args;
+  for (const Type& t : node->args) {
+    args.push_back(PrintType(t, false));
   }
+  doc << "[";
+  doc << Doc::Concat(args);
+  doc << "]";
+  return doc;
+}
 
-  Doc VisitType_(const TensorTypeNode* node) final {
-    // scalar type
-    if (node->shape.size() == 0) {
-      return PrintDType(node->dtype);
-    }
-    Doc doc;
-    doc << "Tensor[(";
-    std::vector<Doc> shapes;
-    for (ObjectRef shape : node->shape) {
-      shapes.push_back(PrintAttr(shape));
-    }
-    doc << Doc::Concat(shapes);
-    return doc << "), " << PrintDType(node->dtype) << "]";
+Doc RelayTextPrinter::PrintDType(DataType dtype) {
+  return Doc::Text(runtime::DLDataType2String(dtype));
+}
+
+Doc RelayTextPrinter::VisitType_(const TensorTypeNode* node) {
+  // scalar type
+  if (node->shape.size() == 0) {
+    return PrintDType(node->dtype);
   }
+  Doc doc;
+  doc << "Tensor[(";
+  std::vector<Doc> shapes;
+  for (ObjectRef shape : node->shape) {
+    shapes.push_back(PrintAttr(shape));
+  }
+  doc << Doc::Concat(shapes);
+  return doc << "), " << PrintDType(node->dtype) << "]";
+}
 
-  Doc VisitType_(const TupleTypeNode* node) final {
-    std::vector<Doc> fields;
-    for (Type field : node->fields) {
-      fields.push_back(Print(field));
-    }
-    Doc doc;
-    doc << "(" << Doc::Concat(fields);
-    // conform to python tuple format (1,)
-    if (node->fields.size() == 1) {
-      doc << ",";
-    }
-    return doc << ")";
+Doc RelayTextPrinter::VisitType_(const TupleTypeNode* node) {
+  std::vector<Doc> fields;
+  for (Type field : node->fields) {
+    fields.push_back(Print(field));
+  }
+  Doc doc;
+  doc << "(" << Doc::Concat(fields);
+  // conform to python tuple format (1,)
+  if (node->fields.size() == 1) {
+    doc << ",";
   }
+  return doc << ")";
+}
 
-  Doc VisitType_(const FuncTypeNode* node) final {
-    Doc doc;
-    doc << "fn ";
-    if (node->type_params.size() != 0) {
-      doc << "[";
-      std::vector<Doc> type_params;
-      for (Type type_param : node->type_params) {
-        type_params.push_back(Print(type_param));
-      }
-      doc << Doc::Concat(type_params);
-      doc << "]";
-    }
-    std::vector<Doc> arg_types;
-    for (Type arg_type : node->arg_types) {
-      arg_types.push_back(Print(arg_type));
+Doc RelayTextPrinter::VisitType_(const FuncTypeNode* node) {
+  Doc doc;
+  doc << "fn ";
+  if (node->type_params.size() != 0) {
+    doc << "[";
+    std::vector<Doc> type_params;
+    for (Type type_param : node->type_params) {
+      type_params.push_back(Print(type_param));
     }
-    return doc << "(" << Doc::Concat(arg_types) << ") -> " << Print(node->ret_type);
+    doc << Doc::Concat(type_params);
+    doc << "]";
   }
-
-  Doc VisitType_(const RelayRefTypeNode* node) final {
-    Doc doc;
-    return doc << "ref(" << Print(node->value) << ")";
+  std::vector<Doc> arg_types;
+  for (Type arg_type : node->arg_types) {
+    arg_types.push_back(Print(arg_type));
   }
+  return doc << "(" << Doc::Concat(arg_types) << ") -> " << Print(node->ret_type);
+}
 
-  Doc VisitType_(const TypeDataNode* node) final {
-    in_adt_def_ = true;
-    Doc doc;
-    doc << "type " << Print(node->header);
-
-    // type vars
-    if (node->type_vars.size() != 0) {
-      doc << "[";
-      std::vector<Doc> type_vars;
-      for (Type type_var : node->type_vars) {
-        type_vars.push_back(Print(type_var));
-      }
-      doc << Doc::Concat(type_vars) << "]";
-    }
-    doc << " ";
+Doc RelayTextPrinter::VisitType_(const RelayRefTypeNode* node) {
+  Doc doc;
+  return doc << "ref(" << Print(node->value) << ")";
+}
 
-    std::vector<Doc> constructor_docs;
-    for (Constructor constructor : node->constructors) {
-      constructor_docs.push_back(Print(constructor, /* meta */ false, /* try_inline */ true));
-    }
-    Doc separator;
-    separator << "," << Doc::NewLine();
-    Doc adt_body;
-    adt_body << Doc::Concat(constructor_docs, separator);
-    // add trailing comma if there are any constructors
-    if (!constructor_docs.empty()) {
-      adt_body << ",";
+Doc RelayTextPrinter::VisitType_(const TypeDataNode* node) {
+  in_adt_def_ = true;
+  Doc doc;
+  doc << "type " << Print(node->header);
+
+  // type vars
+  if (node->type_vars.size() != 0) {
+    doc << "[";
+    std::vector<Doc> type_vars;
+    for (Type type_var : node->type_vars) {
+      type_vars.push_back(Print(type_var));
     }
-    doc << Doc::Brace("{", adt_body, "}");
-    in_adt_def_ = false;
-    return doc;
+    doc << Doc::Concat(type_vars) << "]";
   }
+  doc << " ";
 
-  //------------------------------------
-  // Overload of Attr printing functions
-  //------------------------------------
+  std::vector<Doc> constructor_docs;
+  for (Constructor constructor : node->constructors) {
+    constructor_docs.push_back(Print(constructor, /* meta */ false, /* try_inline */ true));
+  }
+  Doc separator;
+  separator << "," << Doc::NewLine();
+  Doc adt_body;
+  adt_body << Doc::Concat(constructor_docs, separator);
+  // add trailing comma if there are any constructors
+  if (!constructor_docs.empty()) {
+    adt_body << ",";
+  }
+  doc << Doc::Brace("{", adt_body, "}");
+  in_adt_def_ = false;
+  return doc;
+}
 
-  Doc PrintAttr(const ObjectRef& value, bool meta = false) {
-    if (value.defined()) {
-      Doc printed_attr;
-      if (value.as<tvm::tir::AnyNode>()) {
-        printed_attr << "?";
-      } else if (meta) {
-        printed_attr = meta_.GetMetaNode(Downcast<ObjectRef>(value));
-      } else {
-        printed_attr = VisitAttr(value);
-      }
-      return printed_attr;
+//------------------------------------
+// Overload of Attr printing functions
+//------------------------------------
+
+Doc RelayTextPrinter::PrintAttr(const ObjectRef& value, bool meta) {
+  if (value.defined()) {
+    Doc printed_attr;
+    if (value.as<tvm::tir::AnyNode>()) {
+      printed_attr << "?";
+    } else if (meta) {
+      printed_attr = meta_->GetMetaNode(Downcast<ObjectRef>(value));
     } else {
-      return Doc::Text("None");
+      printed_attr = VisitAttr(value);
     }
+    return printed_attr;
+  } else {
+    return Doc::Text("None");
   }
+}
 
-  Doc VisitAttrDefault_(const Object* op) final {
-    return PrintAttr(GetRef<ObjectRef>(op), true);
-  }
+Doc RelayTextPrinter::VisitAttrDefault_(const Object* op) {
+  return PrintAttr(GetRef<ObjectRef>(op), true);
+}
 
-  Doc VisitAttr_(const ArrayNode* op) final {
-    Doc doc;
-    doc << "[";
-    std::vector<Doc> arr_vals;
-    for (auto val : op->data) {
-      arr_vals.push_back(PrintAttr(val));
-    }
-    doc << Doc::Concat(arr_vals);
-    doc << "]";
-    return doc;
-  }
+Doc RelayTextPrinter::VisitAttr_(const ArrayNode* op) {
+  Doc doc;
+  doc << "[";
+  std::vector<Doc> arr_vals;
+  for (auto val : op->data) {
+    arr_vals.push_back(PrintAttr(val));
+  }
+  doc << Doc::Concat(arr_vals);
+  doc << "]";
+  return doc;
+}
 
-  Doc VisitAttr_(const tir::IntImmNode* op) final {
-    return ScalarLiteral(op->dtype, op->value);
-  }
+Doc RelayTextPrinter::VisitAttr_(const tir::IntImmNode* op) {
+  return ScalarLiteral(op->dtype, op->value);
+}
 
-  Doc VisitAttr_(const tir::FloatImmNode* op) final {
-    return ScalarLiteral(op->dtype, op->value);
-  }
+Doc RelayTextPrinter::VisitAttr_(const tir::FloatImmNode* op) {
+  return ScalarLiteral(op->dtype, op->value);
+}
 
-  Doc VisitAttr_(const tir::StringImmNode* op) final {
-    return Doc::StrLiteral(op->value);
-  }
+Doc RelayTextPrinter::VisitAttr_(const tir::StringImmNode* op) {
+  return Doc::StrLiteral(op->value);
+}
 
- private:
-  /*! \brief Whether to print meta data. */
-  bool show_meta_data_;
-  /*! \brief additional comment function */
-  runtime::TypedPackedFunc<std::string(ObjectRef)> annotate_;
-  /*! \brief Stack of docs to implement scoped GNFing. */
-  std::vector<Doc> doc_stack_{};
-  /*! \brief Map from Expr to Doc */
-  std::unordered_map<Expr, Doc, ObjectHash, ObjectEqual> memo_;
-  /*! \brief Map from Type to Doc */
-  std::unordered_map<Type, Doc, ObjectHash, ObjectEqual> memo_type_;
-  /*! \brief Map from Type to Doc */
-  std::unordered_map<Pattern, Doc, ObjectHash, ObjectEqual> memo_pattern_;
-  /*! \brief name allocation map */
-  std::unordered_map<std::string, int> name_alloc_map_;
-  /*! \brief meta data context */
-  TextMetaDataContext meta_;
-  /*! \brief counter of temporary variable */
-  size_t temp_var_counter_{0};
-  /*! \brief whether the printer is currently in an ADT definition */
-  bool in_adt_def_;
-  /*! \brief arena for dependency graph */
-  support::Arena arena_;
-  /*! \brief dependency graph of the expr */
-  DependencyGraph dg_;
-  class AttrPrinter;
-  friend class AttrPrinter;
-};
 
 /*!
  * \brief Attribute printer which prints the attributes in the call.
@@ -883,7 +833,7 @@ std::vector<Doc> RelayTextPrinter::PrintCallAttrs(
   if (op_node && (attrs->type_index() != op_node->attrs_type_index)) {
     // fallback
     Doc doc;
-    doc << meta_.GetMetaNode(attrs);
+    doc << meta_->GetMetaNode(attrs);
     docs.push_back(doc);
     return docs;
   } else {
@@ -905,44 +855,6 @@ std::vector<Doc> RelayTextPrinter::PrintFuncAttrs(const Attrs& attrs) {
   }
   return docs;
 }
-}  // namespace relay
 
-static const char* kSemVer = "v0.0.4";
-
-// TODO(tvm-team): split into files, related: arith/analyzer.h
-//
-// - text_printer.h (common header)
-// - text_printer.cc (prints modules dispatch into relay and tir files)
-//    - type_text_printer.cc(specific printing logics for types,
-//      can also consider put under type_text_printer)
-//    - Implements AsText
-// - relay_text_printer.cc (specific printing logics for relay)
-// - tir_text_printer.cc (specific printing logics for TIR)
-String PrettyPrint(const ObjectRef& node) {
-  Doc doc;
-  doc << relay::RelayTextPrinter(false, nullptr).PrintFinal(node);
-  return doc.str();
-}
-
-String AsText(const ObjectRef& node,
-                   bool show_meta_data,
-                   runtime::TypedPackedFunc<String(ObjectRef)> annotate) {
-  Doc doc;
-  doc << kSemVer << Doc::NewLine();
-  runtime::TypedPackedFunc<std::string(ObjectRef)> ftyped = nullptr;
-  if (annotate != nullptr) {
-    ftyped = runtime::TypedPackedFunc<std::string(ObjectRef)>(
-      [&annotate](const ObjectRef& expr) -> std::string {
-        return annotate(expr);
-      });
-  }
-  doc << relay::RelayTextPrinter(show_meta_data, ftyped).PrintFinal(node);
-  return doc.str();
-}
-
-TVM_REGISTER_GLOBAL("ir.PrettyPrint")
-.set_body_typed(PrettyPrint);
-
-TVM_REGISTER_GLOBAL("ir.AsText")
-.set_body_typed(AsText);
+}  // namespace relay
 }  // namespace tvm
diff --git a/src/printer/text_printer.cc b/src/printer/text_printer.cc
new file mode 100644 (file)
index 0000000..592aabe
--- /dev/null
@@ -0,0 +1,104 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file text_printer.cc
+ * \brief Printer to print out the unified IR text format
+ *        that can be parsed by a parser.
+ */
+
+#include <tvm/tir/function.h>
+#include <string>
+#include "text_printer.h"
+
+namespace tvm {
+
+static const char* kSemVer = "v0.0.4";
+
+// TODO(tvm-team): split into files, related: arith/analyzer.h
+//
+// - text_printer.h (common header)
+// - text_printer.cc (prints modules dispatch into relay and tir files)
+//    - type_text_printer.cc(specific printing logics for types,
+//      can also consider put under type_text_printer)
+//    - Implements AsText
+// - relay_text_printer.cc (specific printing logics for relay)
+// - tir_text_printer.cc (specific printing logics for TIR)
+
+Doc TextPrinter::PrintMod(const IRModule& mod) {
+  Doc doc;
+  int counter = 0;
+  // type definitions
+  for (const auto& kv : mod->type_definitions) {
+    if (counter++ != 0) {
+      doc << Doc::NewLine();
+    }
+    doc << relay_text_printer_.Print(kv.second);
+    doc << Doc::NewLine();
+  }
+  // functions
+  for (const auto& kv : mod->functions) {
+    if (kv.second.as<relay::FunctionNode>()) {
+      relay_text_printer_.dg_ =
+          relay::DependencyGraph::Create(&relay_text_printer_.arena_, kv.second);
+    }
+    if (counter++ != 0) {
+      doc << Doc::NewLine();
+    }
+    if (kv.second.as<relay::FunctionNode>()) {
+      std::ostringstream os;
+      os << "def @" << kv.first->name_hint;
+      doc << relay_text_printer_.PrintFunc(Doc::Text(os.str()), kv.second);
+    } else if (kv.second.as<tir::PrimFuncNode>()) {
+      doc << tir_text_printer_.PrintPrimFunc(Downcast<tir::PrimFunc>(kv.second));
+    }
+    doc << Doc::NewLine();
+  }
+  return doc;
+}
+
+String PrettyPrint(const ObjectRef& node) {
+  Doc doc;
+  doc << TextPrinter(false, nullptr).PrintFinal(node);
+  return doc.str();
+}
+
+String AsText(const ObjectRef& node,
+              bool show_meta_data,
+              runtime::TypedPackedFunc<String(ObjectRef)> annotate) {
+  Doc doc;
+  doc << kSemVer << Doc::NewLine();
+  runtime::TypedPackedFunc<std::string(ObjectRef)> ftyped = nullptr;
+  if (annotate != nullptr) {
+    ftyped = runtime::TypedPackedFunc<std::string(ObjectRef)>(
+        [&annotate](const ObjectRef& expr) -> std::string {
+          return annotate(expr);
+        });
+  }
+  doc << TextPrinter(show_meta_data, ftyped).PrintFinal(node);
+  return doc.str();
+}
+
+TVM_REGISTER_GLOBAL("ir.PrettyPrint")
+.set_body_typed(PrettyPrint);
+
+TVM_REGISTER_GLOBAL("ir.AsText")
+.set_body_typed(AsText);
+
+}  // namespace tvm
diff --git a/src/printer/text_printer.h b/src/printer/text_printer.h
new file mode 100644 (file)
index 0000000..63767af
--- /dev/null
@@ -0,0 +1,404 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file text_printer.h
+ * \brief Printer to print out the unified IR text format
+ *        that can be parsed by a parser.
+ */
+
+#ifndef TVM_PRINTER_TEXT_PRINTER_H_
+#define TVM_PRINTER_TEXT_PRINTER_H_
+
+#include <tvm/relay/expr_functor.h>
+#include <tvm/relay/pattern_functor.h>
+#include <tvm/tir/expr_functor.h>
+#include <tvm/tir/stmt_functor.h>
+#include <tvm/tir/op.h>
+#include <tvm/ir/type_functor.h>
+#include <tvm/ir/module.h>
+#include <tvm/tir/function.h>
+#include <tvm/relay/expr_functor.h>
+#include <tvm/relay/pattern_functor.h>
+#include <unordered_map>
+#include <vector>
+#include <string>
+#include "../relay/analysis/dependency_graph.h"
+#include "../ir/attr_functor.h"
+
+#include "doc.h"
+#include "meta_data.h"
+#include "text_printer.h"
+
+namespace tvm {
+class TextPrinter;
+}  // namespace tvm
+
+namespace tvm {
+namespace relay {
+
+class RelayTextPrinter :
+    public ExprFunctor<Doc(const Expr&)>,
+    public PatternFunctor<Doc(const Pattern&)>,
+    public TypeFunctor<Doc(const Type&)>,
+    public AttrFunctor<Doc(const ObjectRef&)> {
+ public:
+  explicit RelayTextPrinter(bool show_meta_data,
+                            TextMetaDataContext* meta,
+                            runtime::TypedPackedFunc<std::string(ObjectRef)> annotate)
+      : show_meta_data_(show_meta_data), annotate_(annotate), meta_(meta) {}
+
+  /*!
+    * \brief Print additional info about expr in comment.
+    * \param expr The expression.
+    */
+  Doc PrintOptionalInfo(const Expr& expr);
+  // indent a new body
+  Doc PrintBody(const ObjectRef& node, int indent = 2);
+  // create a new scope by creating a new printer object. This allows temp var
+  // numbers to be reused and prevents hoisted vars from escaping too far
+  Doc PrintScope(const ObjectRef& node);
+  Doc PrintFinal(const ObjectRef& node);
+  std::vector<Doc> PrintCallAttrs(const Attrs& attrs, const Expr& op);
+  std::vector<Doc> PrintFuncAttrs(const Attrs& attrs);
+
+  Doc Print(const ObjectRef& node, bool meta = false, bool try_inline = false);
+
+  Doc TempVar(int n);
+  Doc AllocTemp();
+  /*!
+    * \brief get a unique name with the corresponding prefix
+    * \param prefix The prefix of the name
+    * \return The returned name.
+    */
+  Doc GetUniqueName(const std::string& prefix);
+  Doc Print(Kind k);
+  /*!
+   * \brief Allocate name to a type variable.
+   * \param var The input type variable.
+   * \return The corresponding name.
+   */
+  Doc AllocTypeVar(const TypeVar& var);
+  /*!
+   * \brief Allocate name to a variable.
+   * \param var The input variable.
+   * \return The corresponding name.
+   */
+  Doc AllocVar(const Var& var);
+  bool IsUnique(const Expr& expr);
+  bool AlwaysInline(const Expr& expr);
+
+  Doc PrintFunc(const Doc& prefix, const relay::Function& fn);
+  Doc PrintFunc(const Doc& prefix, const BaseFunc& base_func);
+  Doc PrintMod(const IRModule& mod);
+
+  //------------------------------------
+  // Overload of Expr printing functions
+  //------------------------------------
+  Doc PrintExpr(const Expr& expr, bool meta, bool try_inline);
+  // Should only be triggered when op is a free variable being visited for the
+  // first time.
+  Doc VisitExpr_(const VarNode* op) final;
+  /*!
+   * \brief special method to print out const scalar
+   * \param dtype The data type
+   * \param value The value to be printed.
+   */
+  template <typename T>
+  static Doc ScalarLiteral(DataType dtype, const T& value);
+  Doc VisitExpr_(const ConstantNode* op) final;
+  Doc VisitExpr_(const TupleNode* op) final;
+  Doc VisitExpr_(const TupleGetItemNode* op) final;
+  Doc VisitExpr_(const IfNode* op) final;
+  Doc VisitExpr_(const LetNode* op) final;
+  Doc VisitExpr_(const FunctionNode* op) final;
+  Doc VisitExpr_(const GlobalVarNode* op) final;
+  Doc VisitExpr_(const OpNode* op) final;
+  Doc VisitExpr_(const CallNode* op) final;
+  Doc VisitExpr_(const RefCreateNode* op) final;
+  Doc VisitExpr_(const RefReadNode* op) final;
+  Doc VisitExpr_(const RefWriteNode* op) final;
+  Doc VisitExpr_(const MatchNode* op) final;
+  Doc PrintPattern(const Pattern& pattern, bool meta);
+  Doc VisitPattern_(const PatternConstructorNode* p) final;
+  Doc VisitPattern_(const PatternTupleNode* pt) final;
+  Doc VisitPattern_(const PatternWildcardNode* pw) final;
+  Doc VisitPattern_(const PatternVarNode* pv) final;
+  Doc VisitExpr_(const ConstructorNode* n) final;
+  //------------------------------------
+  // Overload of Type printing functions
+  //------------------------------------
+  Doc PrintType(const Type& type, bool meta);
+  Doc VisitTypeDefault_(const Object* node) final;
+  Doc VisitType_(const TypeVarNode* node) final;
+  Doc VisitType_(const GlobalTypeVarNode* node);
+  Doc VisitType_(const TypeCallNode* node) final;
+  Doc PrintDType(DataType dtype);
+  Doc VisitType_(const TensorTypeNode* node) final;
+  Doc VisitType_(const TupleTypeNode* node) final;
+  Doc VisitType_(const FuncTypeNode* node) final;
+  Doc VisitType_(const RelayRefTypeNode* node) final;
+  Doc VisitType_(const TypeDataNode* node) final;
+  //------------------------------------
+  // Overload of Attr printing functions
+  //------------------------------------
+  Doc PrintAttr(const ObjectRef& value, bool meta = false);
+  Doc VisitAttrDefault_(const Object* op) final;
+  Doc VisitAttr_(const ArrayNode* op) final;
+  Doc VisitAttr_(const tir::IntImmNode* op) final;
+  Doc VisitAttr_(const tir::FloatImmNode* op) final;
+  Doc VisitAttr_(const tir::StringImmNode* op) final;
+
+ private:
+  /*! \brief Whether to print meta data. */
+  bool show_meta_data_;
+  /*! \brief additional comment function */
+  runtime::TypedPackedFunc<std::string(ObjectRef)> annotate_;
+  /*! \brief Stack of docs to implement scoped GNFing. */
+  std::vector<Doc> doc_stack_{};
+  /*! \brief Map from Expr to Doc */
+  std::unordered_map<Expr, Doc, ObjectHash, ObjectEqual> memo_;
+  /*! \brief Map from Type to Doc */
+  std::unordered_map<Type, Doc, ObjectHash, ObjectEqual> memo_type_;
+  /*! \brief Map from Type to Doc */
+  std::unordered_map<Pattern, Doc, ObjectHash, ObjectEqual> memo_pattern_;
+  /*! \brief name allocation map */
+  std::unordered_map<std::string, int> name_alloc_map_;
+  /*! \brief meta data context */
+  TextMetaDataContext* meta_;
+  /*! \brief counter of temporary variable */
+  size_t temp_var_counter_{0};
+  /*! \brief whether the printer is currently in an ADT definition */
+  bool in_adt_def_;
+  /*! \brief arena for dependency graph */
+  support::Arena arena_;
+  /*! \brief dependency graph of the expr */
+  DependencyGraph dg_;
+  class AttrPrinter;
+  friend class AttrPrinter;
+  friend class tvm::TextPrinter;
+};
+
+}  // namespace relay
+}  // namespace tvm
+
+namespace tvm {
+namespace tir {
+
+/*!
+ *  \brief Meta node collector
+ *  If we decide to put some node into meta, then all the sub-nodes inside
+ *  it need to be put in meta as well, since when parsing we need to know
+ *  whether two refs are the same
+ */
+class MetaCollector : public StmtExprVisitor {
+ public:
+  explicit MetaCollector(TextMetaDataContext* meta) : meta_(meta) {}
+
+  void Collect(const ObjectRef& n) {
+    // these nodes can be print directly(StringLiteral or use identifier to identify)
+    if (!n.defined() || n.as<StringImmNode>() || n.as<StringObj>() || n.as<SizeVarNode>()
+        || n.as<VarNode>() || n.as<BufferNode>() || n.as<IterVarNode>()) {
+      return;
+    }
+    if (n->IsInstance<StmtNode>()) {
+      VisitStmt(Downcast<Stmt>(n));
+    } else if (n->IsInstance<PrimExprNode>()) {
+      VisitExpr(Downcast<PrimExpr>(n));
+    }
+  }
+
+  void VisitStmt(const Stmt& n) override {
+    meta_->GetMetaNode(n);
+    StmtVisitor::VisitStmt(n);
+  }
+
+  void VisitExpr(const PrimExpr& n) override {
+    meta_->GetMetaNode(n);
+    ExprVisitor::VisitExpr(n);
+  }
+
+ private:
+  TextMetaDataContext* meta_;
+};
+
+class TIRTextPrinter : public StmtFunctor<Doc(const Stmt&)>,
+                       public ExprFunctor<Doc(const PrimExpr&)>,
+                       public TypeFunctor<Doc(const Type&)> {
+ public:
+  explicit TIRTextPrinter(bool show_meta, TextMetaDataContext* meta)
+    : show_meta_(show_meta), meta_(meta), meta_collector_(meta) {}
+
+  /*! \brief Print the node */
+  Doc Print(const ObjectRef& node);
+
+ private:
+  /*! \brief whether show meta data */
+  bool show_meta_;
+  /*! \brief meta data context */
+  TextMetaDataContext* meta_;
+  /*! \brief meta collector */
+  MetaCollector meta_collector_;
+  /*! \brief Map from Var to Doc */
+  std::unordered_map<Var, Doc, ObjectHash, ObjectEqual> memo_var_;
+  /*! \brief Map from Buffer to Doc */
+  std::unordered_map<Buffer, Doc, ObjectHash, ObjectEqual> memo_buf_;
+  /*! \brief name allocation map */
+  std::unordered_map<std::string, int> name_alloc_map_;
+
+  friend class tvm::TextPrinter;
+
+  Doc VisitExpr_(const IntImmNode* op) override;
+  Doc VisitExpr_(const FloatImmNode* op) override;
+  Doc VisitExpr_(const StringImmNode* op) override;
+  Doc VisitExpr_(const CastNode* op) override;
+  Doc VisitExpr_(const VarNode* op) override;
+  Doc VisitExpr_(const AddNode* op) override;
+  Doc VisitExpr_(const SubNode* op) override;
+  Doc VisitExpr_(const MulNode* op) override;
+  Doc VisitExpr_(const DivNode* op) override;
+  Doc VisitExpr_(const ModNode* op) override;
+  Doc VisitExpr_(const FloorDivNode* op) override;
+  Doc VisitExpr_(const FloorModNode* op) override;
+  Doc VisitExpr_(const MinNode* op) override;
+  Doc VisitExpr_(const MaxNode* op) override;
+  Doc VisitExpr_(const EQNode* op) override;
+  Doc VisitExpr_(const NENode* op) override;
+  Doc VisitExpr_(const LTNode* op) override;
+  Doc VisitExpr_(const LENode* op) override;
+  Doc VisitExpr_(const GTNode* op) override;
+  Doc VisitExpr_(const GENode* op) override;
+  Doc VisitExpr_(const AndNode* op) override;
+  Doc VisitExpr_(const OrNode* op) override;
+  Doc VisitExpr_(const NotNode* op) override;
+  Doc VisitExpr_(const SelectNode* op) override;
+  Doc VisitExpr_(const BufferLoadNode* op) override;
+  Doc VisitExpr_(const LoadNode* op) override;
+  Doc VisitExpr_(const RampNode* op) override;
+  Doc VisitExpr_(const BroadcastNode* op) override;
+  Doc VisitExpr_(const LetNode* op) override;
+  Doc VisitExpr_(const CallNode* op) override;
+  Doc VisitExpr_(const ShuffleNode* op) override;
+  Doc VisitExpr_(const ReduceNode* op) override;
+  Doc VisitExprDefault_(const Object* op) override;
+
+  Doc VisitStmt_(const LetStmtNode* op) override;
+  Doc VisitStmt_(const AttrStmtNode* op) override;
+  Doc VisitStmt_(const AssertStmtNode* op) override;
+  Doc VisitStmt_(const StoreNode* op) override;
+  Doc VisitStmt_(const BufferStoreNode* op) override;
+  Doc VisitStmt_(const BufferRealizeNode* op) override;
+  Doc VisitStmt_(const AllocateNode* op) override;
+  Doc VisitStmt_(const FreeNode* op) override;
+  Doc VisitStmt_(const IfThenElseNode* op) override;
+  Doc VisitStmt_(const SeqStmtNode* op) override;
+  Doc VisitStmt_(const EvaluateNode* op) override;
+  Doc VisitStmt_(const ForNode* op) override;
+  Doc VisitStmt_(const PrefetchNode* op) override;
+  Doc VisitStmtDefault_(const Object* op) override;
+
+  Doc VisitType_(const PrimTypeNode* node) override;
+  Doc VisitType_(const PointerTypeNode* node) override;
+  Doc VisitType_(const TupleTypeNode* node) override;
+
+  Doc PrintIRModule(const IRModule& module);
+  Doc PrintPrimFunc(const PrimFunc& primFunc);
+  Doc PrintArray(const ArrayNode* op);
+  Doc PrintIterVar(const IterVarNode* op);
+  Doc PrintRange(const RangeNode* op);
+  Doc PrintBuffer(const BufferNode* op);
+  Doc PrintString(const StringObj* op) {
+    return Doc::StrLiteral(op->data);
+  }
+
+  /*!
+   * \brief special method to print out data type
+   * \param dtype The data type
+   */
+  static Doc PrintDType(DataType dtype);
+  /*!
+   * \brief special method to print out const scalar
+   * \param dtype The data type
+   * \param data The pointer to hold the data.
+   */
+  template <typename T>
+  static Doc PrintConstScalar(DataType dtype, const T& data);
+  Doc GetUniqueName(std::string prefix);
+  Doc AllocVar(const Var& var);
+  Doc AllocBuf(const Buffer& buffer);
+  /*!
+   * \brief special method to render vectors of docs with a separator
+   * \param vec vector of docs
+   * \param sep separator
+   */
+  static Doc PrintSep(const std::vector<Doc>& vec, const Doc& sep);
+  Doc PrintBody(const Stmt& body, bool indent = true);
+};
+
+}  // namespace tir
+}  // namespace tvm
+
+namespace tvm {
+
+class TextPrinter {
+ public:
+  explicit TextPrinter(bool show_meta_data,
+                       const runtime::TypedPackedFunc<std::string(ObjectRef)>& annotate)
+      : show_meta_data_(show_meta_data), annotate_(annotate),
+        relay_text_printer_(show_meta_data, &meta_, annotate),
+        tir_text_printer_(show_meta_data, &meta_) {}
+
+  /*! \brief whether show meta data */
+  bool show_meta_data_;
+  /*! \brief meta data context */
+  TextMetaDataContext meta_;
+  /*! \brief additional comment function */
+  runtime::TypedPackedFunc<std::string(ObjectRef)> annotate_;
+  /*! \brief Relay Text Printer */
+  relay::RelayTextPrinter relay_text_printer_;
+  /*! \brief TIR Text Printer */
+  tir::TIRTextPrinter tir_text_printer_;
+
+  Doc PrintFinal(const ObjectRef& node) {
+    Doc doc;
+    if (node->IsInstance<IRModuleNode>()) {
+      doc << PrintMod(Downcast<IRModule>(node));
+    } else if (node->IsInstance<tir::PrimFuncNode>() || node->IsInstance<PrimExprNode>()
+              || node->IsInstance<tir::StmtNode>()) {
+      doc << tir_text_printer_.Print(node);
+    } else {
+      doc << relay_text_printer_.PrintFinal(node);
+    }
+    if (!meta_.empty()) {
+      doc << Doc::NewLine();
+      if (show_meta_data_) {
+        // append meta data in the end.
+        doc << "METADATA:" << Doc::NewLine() << meta_.GetMetaSection();
+      } else {
+        doc << "// meta data omitted. you can use show_meta_data=True to include meta data";
+      }
+    }
+    return doc;
+  }
+
+  Doc PrintMod(const IRModule& mod);
+};
+}  // namespace tvm
+
+#endif  // TVM_PRINTER_TEXT_PRINTER_H_
diff --git a/src/printer/tir_text_printer.cc b/src/printer/tir_text_printer.cc
new file mode 100644 (file)
index 0000000..a5754d7
--- /dev/null
@@ -0,0 +1,597 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file tir_text_printer.cc
+ * \brief Printer to print out the IR text format
+ *        that can be parsed by a parser.
+ */
+
+#include <tvm/ir/module.h>
+#include <tvm/ir/type.h>
+#include <tvm/ir/type_functor.h>
+#include <tvm/node/serialization.h>
+#include <tvm/tir/expr.h>
+#include <tvm/tir/function.h>
+#include <tvm/tir/op.h>
+#include <tvm/tir/stmt.h>
+
+#include <algorithm>
+#include <string>
+
+#include "doc.h"
+#include "meta_data.h"
+#include "text_printer.h"
+
+namespace tvm {
+namespace tir {
+
+Doc TIRTextPrinter::Print(const ObjectRef& node) {
+  if (!node.defined()) return Doc::Text("(nullptr)");
+  if (node->IsInstance<StmtNode>()) {
+    return VisitStmt(Downcast<Stmt>(node));
+  } else if (node->IsInstance<AnyNode>()) {
+    return Doc::Text("?");
+  } else if (node->IsInstance<PrimExprNode>()) {
+    return VisitExpr(Downcast<PrimExpr>(node));
+  } else if (node->IsInstance<TypeNode>()) {
+    return VisitType(Downcast<Type>(node));
+  } else if (node->IsInstance<PrimFuncNode>()) {
+    return PrintPrimFunc(Downcast<PrimFunc>(node));
+  } else if (node->IsInstance<IRModuleNode>()) {
+    return PrintIRModule(Downcast<IRModule>(node));
+  } else if (node->IsInstance<ArrayNode>()) {
+    return PrintArray(node.as<ArrayNode>());
+  } else if (node->IsInstance<IterVarNode>()) {
+    return PrintIterVar(node.as<IterVarNode>());
+  } else if (node->IsInstance<RangeNode>()) {
+    return PrintRange(node.as<RangeNode>());
+  } else if (node->IsInstance<BufferNode>()) {
+    return PrintBuffer(node.as<BufferNode>());
+  } else if (node->IsInstance<StringObj>()) {
+    return PrintString(node.as<StringObj>());
+  } else {
+    return this->meta_->GetMetaNode(node);
+  }
+}
+
+Doc TIRTextPrinter::PrintPrimFunc(const PrimFunc& primFunc) {
+  const auto* op = primFunc.operator->();
+  const auto& signature = op->func_type_annotation();
+  // collect Meta in DictAttr
+  for (const auto& it : primFunc->attrs->dict) {
+    meta_collector_.Collect(it.second);
+  }
+  // collect buffers in buffer_map
+  memo_var_.clear();
+  memo_buf_.clear();
+  for (const auto& it : op->buffer_map) {
+    memo_buf_[it.second] = AllocBuf(it.second);
+  }
+  // print PrimFunc
+  Doc doc;
+  doc << "primfn" << "(";
+  // print params and its type annotation
+  std::vector<Doc> params;
+  for (const auto& param : op->params) {
+    params.push_back(Print(param));
+  }
+  Doc sep;
+  doc << PrintSep(params, Doc::Indent(9, Doc::Text(", "))) << ")";
+  // print return type
+  doc << " -> " << Print(signature->ret_type);
+  // print attr
+  Doc attr_doc;
+  std::vector<Doc> attr_docs;
+  for (const auto& it : op->attrs->dict) {
+    attr_docs.push_back(Doc::StrLiteral(it.first) << ": " << Print(it.second));
+  }
+  attr_doc << Doc::NewLine() << "attr = {" << PrintSep(attr_docs, Doc::Text(", ")) << "}";
+  doc << Doc::Indent(2, attr_doc);
+  // print all the buffers in the tree
+  Doc buffer_doc;
+  std::vector<Doc> buffer_docs;
+  for (const auto& it : memo_buf_) {
+    const auto& buf = it.first;
+    buffer_docs.push_back(Print(buf)
+                              << Doc::Text(": Buffer(") << Print(buf->data) << ", "
+                              << PrintDType(buf->dtype) << ", " << Print(buf->shape) << ", "
+                              << Print(buf->strides));
+    if (!is_zero(buf->elem_offset)) {
+      buffer_docs.back() << ", elem_offset=" << Print(buf->elem_offset);
+    }
+    if (buf->scope != "global") {
+      buffer_docs.back() << ", scope=" << Doc::StrLiteral(buf->scope);
+    }
+    if (buf->data_alignment != 128) {
+      buffer_docs.back() << ", align=" << buf->data_alignment;
+    }
+    if (buf->offset_factor != 1) {
+      buffer_docs.back() << ", offset_factor=" << buf->offset_factor;
+    }
+    if (buf->buffer_type != 1) {
+      buffer_docs.back() << ", type=" << Doc::StrLiteral("auto");
+    }
+    buffer_docs.back() << ")";
+  }
+  buffer_doc << Doc::NewLine() << "buffers = {";
+  buffer_doc << PrintSep(buffer_docs, Doc::Indent(11, Doc::Text(",") << Doc::NewLine()));
+  doc << Doc::Indent(2, buffer_doc) << "}";
+  // print buffer_map
+  std::vector<Doc> buffer_map_doc;
+  for (const auto& it : op->buffer_map) {
+    buffer_map_doc.push_back(Print(it.first) << ": " << Print(it.second));
+  }
+  doc << Doc::Indent(2, Doc::NewLine()
+      << "buffer_map = {" << PrintSep(buffer_map_doc, Doc::Text(", ")) << "}");
+  doc << PrintBody(op->body);
+  return doc;
+}
+
+Doc TIRTextPrinter::PrintIRModule(const IRModule& module) {
+  const auto* op = module.operator->();
+  Doc doc;
+
+  Doc body;
+  body << Doc::NewLine();
+  std::vector<Doc> functions;
+  for (auto it = op->functions.begin(); it != op->functions.end(); ++it) {
+    if ((*it).second.as<PrimFuncNode>()) {
+      functions.push_back(Print((*it).second));
+    }
+  }
+  body << TIRTextPrinter::PrintSep(functions, Doc::NewLine() << Doc::NewLine());
+  doc << Doc::Indent(0, body);
+  return doc;
+}
+
+Doc TIRTextPrinter::PrintArray(const ArrayNode* op) {
+  Doc doc;
+  doc << '[';
+  for (size_t i = 0; i < op->data.size(); ++i) {
+    if (i != 0) {
+      doc << ", ";
+    }
+    doc << Print(op->data[i]);
+  }
+  doc << ']';
+  return doc;
+}
+
+Doc TIRTextPrinter::PrintIterVar(const IterVarNode* op) {
+  Doc doc;
+  doc << "IterVar(" << Print(op->var);
+  if (op->dom.defined()) {
+    doc << ", [" << Print(op->dom) << "], ";
+  } else {
+    doc << ", " << Print(op->dom) << ", ";
+  }
+  doc << Doc::StrLiteral(IterVarType2String(op->iter_type)) << ", ";
+  doc << Doc::StrLiteral(op->thread_tag) << ")";
+  return doc;
+}
+
+Doc TIRTextPrinter::PrintRange(const RangeNode* op) {
+  return Print(op->min) << ":" << Print(op->min + op->extent);
+}
+
+Doc TIRTextPrinter::PrintBuffer(const BufferNode* op) {
+  const Buffer& buffer = GetRef<Buffer>(op);
+  CHECK_GT(memo_buf_.count(buffer), 0);
+  return meta_->InMeta(buffer) ? meta_->GetMetaNode(buffer) : memo_buf_[buffer];
+}
+
+Doc TIRTextPrinter::VisitExprDefault_(const Object* op) {
+  return this->meta_->GetMetaNode(GetRef<ObjectRef>(op));
+}
+
+Doc TIRTextPrinter::VisitStmtDefault_(const Object* op) {
+  return this->meta_->GetMetaNode(GetRef<ObjectRef>(op));
+}
+
+Doc TIRTextPrinter::VisitExpr_(const IntImmNode* op) {
+  return PrintConstScalar<int64_t>(op->dtype, op->value);
+}
+
+Doc TIRTextPrinter::VisitExpr_(const FloatImmNode* op) {
+  return PrintConstScalar<double>(op->dtype, op->value);
+}
+
+Doc TIRTextPrinter::VisitExpr_(const StringImmNode* op) { return Doc::StrLiteral(op->value); }
+
+Doc TIRTextPrinter::VisitExpr_(const CastNode* op) {
+  Doc doc;
+  doc << "cast(" << PrintDType(op->dtype) << ", " << Print(op->value) << ")";
+  return doc;
+}
+
+Doc TIRTextPrinter::VisitExpr_(const VarNode* op) {
+  const Var& var = GetRef<Var>(op);
+  return meta_->InMeta(var) ? meta_->GetMetaNode(var) : AllocVar(GetRef<Var>(op));
+}
+
+#define TVM_DECLARE_TIR_HYBRID_PRINTER_BINOP(OpName, OpString)     \
+  Doc TIRTextPrinter::VisitExpr_(const OpName* op) {               \
+    Doc doc;                                                       \
+    doc << "(" << Print(op->a) << OpString;                        \
+    doc << Print(op->b) << ")";                                    \
+    return doc;                                                    \
+  }
+
+TVM_DECLARE_TIR_HYBRID_PRINTER_BINOP(AddNode, " + ")
+TVM_DECLARE_TIR_HYBRID_PRINTER_BINOP(SubNode, " - ")
+TVM_DECLARE_TIR_HYBRID_PRINTER_BINOP(MulNode, "*")
+TVM_DECLARE_TIR_HYBRID_PRINTER_BINOP(DivNode, " / ")
+TVM_DECLARE_TIR_HYBRID_PRINTER_BINOP(ModNode, " % ")
+TVM_DECLARE_TIR_HYBRID_PRINTER_BINOP(EQNode, " == ")
+TVM_DECLARE_TIR_HYBRID_PRINTER_BINOP(NENode, " != ")
+TVM_DECLARE_TIR_HYBRID_PRINTER_BINOP(LTNode, " < ")
+TVM_DECLARE_TIR_HYBRID_PRINTER_BINOP(LENode, " <= ")
+TVM_DECLARE_TIR_HYBRID_PRINTER_BINOP(GTNode, " > ")
+TVM_DECLARE_TIR_HYBRID_PRINTER_BINOP(GENode, " >= ")
+TVM_DECLARE_TIR_HYBRID_PRINTER_BINOP(AndNode, " && ")
+TVM_DECLARE_TIR_HYBRID_PRINTER_BINOP(OrNode, " || ")
+
+Doc TIRTextPrinter::VisitExpr_(const FloorDivNode* op) {
+  Doc doc;
+  doc << "floordiv(" << Print(op->a) << ", " << Print(op->b) << ")";
+  return doc;
+}
+
+Doc TIRTextPrinter::VisitExpr_(const FloorModNode* op) {
+  Doc doc;
+  doc << "floormod(" << Print(op->a) << ", " << Print(op->b) << ")";
+  return doc;
+}
+
+Doc TIRTextPrinter::VisitExpr_(const MinNode* op) {
+  Doc doc;
+  doc << "min(" << Print(op->a) << ", " << Print(op->b) << ")";
+  return doc;
+}
+
+Doc TIRTextPrinter::VisitExpr_(const MaxNode* op) {
+  Doc doc;
+  doc << "max(" << Print(op->a) << ", " << Print(op->b) << ")";
+  return doc;
+}
+
+Doc TIRTextPrinter::VisitExpr_(const NotNode* op) {
+  Doc doc;
+  doc << "!" << Print(op->a);
+  return doc;
+}
+
+Doc TIRTextPrinter::VisitExpr_(const SelectNode* op) {
+  Doc doc;
+  doc << "select(" << Print(op->condition) << ", " << Print(op->true_value) << ", "
+      << Print(op->false_value);
+  return doc;
+}
+
+Doc TIRTextPrinter::VisitExpr_(const BufferLoadNode* op) {
+  Doc doc;
+  doc << Print(op->buffer) << Print(op->indices);
+  return doc;
+}
+
+Doc TIRTextPrinter::VisitExpr_(const LoadNode* op) {
+  Doc doc;
+  doc << "(" << PrintDType(op->dtype) << "*)"
+      << Print(op->buffer_var) << "[" << Print(op->index) << "])";
+  if (!is_one(op->predicate)) {
+    doc << " if " << Print(op->predicate);
+  }
+  return doc;
+}
+
+Doc TIRTextPrinter::VisitExpr_(const RampNode* op) {
+  Doc doc;
+  doc << "ramp(" << Print(op->base) << ", " << Print(op->stride) << ", " << op->lanes << ")";
+  return doc;
+}
+
+Doc TIRTextPrinter::VisitExpr_(const BroadcastNode* op) {
+  Doc doc;
+  doc << "broadcast(" << Print(op->value) << ", " << op->lanes << ")";
+  return doc;
+}
+
+Doc TIRTextPrinter::VisitExpr_(const LetNode* op) {
+  Doc doc;
+  doc << "let " << Print(op->var) << " = " << Print(op->value) << " in " << Print(op->body);
+  return doc;
+}
+
+inline const char* CallType2String(CallNode::CallType t) {
+  switch (t) {
+    case CallNode::Extern:return "extern";
+    case CallNode::ExternCPlusPlus:return "extern_cpp";
+    case CallNode::PureExtern:return "pure_extern";
+    case CallNode::Halide:return "halide";
+    case CallNode::Intrinsic:return "intrin";
+    case CallNode::PureIntrinsic:return "pure_intrin";
+  }
+  LOG(FATAL) << "Unknown CallType";
+  return "Unknown";
+}
+
+Doc TIRTextPrinter::VisitExpr_(const CallNode* op) {
+  Doc doc;
+  doc << "@" << Doc::Text(op->name) << "(";
+  std::vector<Doc> args;
+  for (const auto& arg : op->args) {
+    args.push_back(Print(arg));
+  }
+  doc << PrintSep(args, Doc::Text(", "))
+      << ", dtype=" << PrintDType(op->dtype)
+      << ", type=" << Doc::StrLiteral(CallType2String(op->call_type))
+      << ", index=" << op->value_index << ")";
+  return doc;
+}
+
+Doc TIRTextPrinter::VisitExpr_(const ShuffleNode* op) {
+  Doc doc;
+  doc << "shuffle(" << Print(op->vectors) << ", " << Print(op->indices) << ")";
+  return doc;
+}
+
+Doc TIRTextPrinter::VisitExpr_(const ReduceNode* op) {
+  Doc doc;
+  doc << "reduce(" << Print(op->combiner) << ", " << Print(op->source) << ", " << Print(op->axis)
+      << ", " << op->value_index << ")";
+  return doc;
+}
+
+Doc TIRTextPrinter::VisitStmt_(const LetStmtNode* op) {
+  Doc doc;
+  doc << "let " << Print(op->var) << " = " << Print(op->value) << PrintBody(op->body);
+  return doc;
+}
+
+Doc TIRTextPrinter::VisitStmt_(const AttrStmtNode* op) {
+  Doc doc;
+  meta_collector_.Collect(op->node);
+  doc << "attr [" << Print(op->node) << "] " << Doc::StrLiteral(op->attr_key) << " = "
+      << Print(op->value);
+  if (op->body->IsInstance<SeqStmtNode>()) {
+    doc << PrintBody(op->body);
+  } else {
+    doc << ";" << Doc::NewLine() << Print(op->body);
+  }
+  return doc;
+}
+
+Doc TIRTextPrinter::VisitStmt_(const AssertStmtNode* op) {
+  Doc doc;
+  doc << "assert(" << Print(op->condition) << ", " << Print(op->message) << ")"
+      << PrintBody(op->body);
+  return doc;
+}
+
+Doc TIRTextPrinter::VisitStmt_(const StoreNode* op) {
+  Doc doc;
+  doc << Print(op->buffer_var) << "[" << Print(op->index) << "] = " << Print(op->value);
+  if (!is_one(op->predicate)) {
+    doc << " if " << Print(op->predicate);
+  }
+  return doc;
+}
+
+Doc TIRTextPrinter::VisitStmt_(const BufferStoreNode* op) {
+  Doc doc;
+  doc << Print(op->buffer) << Print(op->indices) << " = " << Print(op->value);
+  return doc;
+}
+
+Doc TIRTextPrinter::VisitStmt_(const BufferRealizeNode* op) {
+  Doc doc;
+  doc << "realize(" << Print(op->buffer) << ", " << Print(op->bounds) << ", "
+      << Print(op->condition) << PrintBody(op->body) << ")";
+  return doc;
+}
+
+Doc TIRTextPrinter::VisitStmt_(const AllocateNode* op) {
+  Doc doc;
+  doc << "allocate(" << Print(op->buffer_var) << ", " << PrintDType(op->dtype) << ", "
+      << Print(op->extents) << ")";
+  if (!is_one(op->condition)) {
+    doc << " if " << Print(op->condition);
+  }
+  if (op->body->IsInstance<SeqStmtNode>()) {
+    doc << PrintBody(op->body);
+  } else {
+    doc << ";" << Doc::NewLine() << Print(op->body);
+  }
+  return doc;
+}
+
+Doc TIRTextPrinter::VisitStmt_(const FreeNode* op) {
+  Doc doc;
+  doc << "free(" << Print(op->buffer_var) << ")";
+  return doc;
+}
+
+Doc TIRTextPrinter::VisitStmt_(const IfThenElseNode* op) {
+  Doc doc;
+  doc << "if " << Print(op->condition) << PrintBody(op->then_case);
+  if (!is_one(op->condition) && op->else_case.defined()) {
+    doc << " else" << PrintBody(op->else_case);
+  }
+  return doc;
+}
+
+Doc TIRTextPrinter::VisitStmt_(const SeqStmtNode* op) {
+  std::vector<Doc> stmts;
+  Doc seq_doc, doc;
+  for (Stmt stmt : op->seq) {
+    seq_doc << Doc::NewLine() << Print(stmt);
+  }
+  doc << " {" << Doc::Indent(2, seq_doc) << Doc::NewLine() << "}";
+  return doc;
+}
+
+Doc TIRTextPrinter::VisitStmt_(const EvaluateNode* op) {
+  Doc doc;
+  doc << Print(op->value);
+  return doc;
+}
+
+inline const char* ForType2String(ForType t) {
+  switch (t) {
+    case ForType::Serial:return "serial";
+    case ForType::Parallel:return "parallel";
+    case ForType::Vectorized:return "vectorized";
+    case ForType::Unrolled:return "unroll";
+  }
+  LOG(FATAL) << "Unknown ForType";
+  return "Unknown";
+}
+
+Doc TIRTextPrinter::VisitStmt_(const ForNode* op) {
+  Doc doc;
+  doc << "for (" << Print(op->loop_var) << ", " << Print(op->min) << ", "
+      << Print(op->min + op->extent) << ")";
+  if (op->for_type != ForType::Serial) {
+    doc << " " << Doc::StrLiteral(ForType2String(op->for_type));
+  }
+  doc << PrintBody(op->body);
+  return doc;
+}
+
+Doc TIRTextPrinter::VisitStmt_(const PrefetchNode* op) {
+  Doc doc;
+  doc << "prefetch(" << Print(op->buffer) << ", " << Print(op->bounds) << ")";
+  return doc;
+}
+
+Doc TIRTextPrinter::VisitType_(const PrimTypeNode* node) {
+  Doc doc;
+  doc << PrintDType(node->dtype);
+  return doc;
+}
+
+Doc TIRTextPrinter::VisitType_(const PointerTypeNode* node) {
+  Doc doc;
+  doc << "Pointer(" << Print(node->element_type) << ")";
+  return doc;
+}
+
+Doc TIRTextPrinter::VisitType_(const TupleTypeNode* node) {
+  std::vector<Doc> fields;
+  for (Type field : node->fields) {
+    fields.push_back(Print(field));
+  }
+  Doc doc;
+  doc << "(" << Doc::Concat(fields);
+  // conform to python tuple format (1,)
+  if (node->fields.size() == 1) {
+    doc << ",";
+  }
+  return doc << ")";
+}
+
+Doc TIRTextPrinter::PrintDType(DataType dtype) {
+  return Doc::Text(runtime::DLDataType2String(dtype));
+}
+
+template <typename T>
+Doc TIRTextPrinter::PrintConstScalar(DataType dtype, const T& data) {
+  Doc doc;
+  std::ostringstream os;
+  os << data;
+  if (dtype == DataType::Int(32)) {
+    doc << Doc::Text(os.str());
+  } else {
+    if (dtype.bits() == 1 && dtype.lanes() == 1 && dtype.code() == kDLUInt) {
+      doc << ((data == 1) ? "True" : "False");
+      return doc;
+    }
+    doc << Doc::Text(os.str());
+    switch (dtype.code()) {
+      case kDLInt: doc << "i"; break;
+      case kDLUInt: doc << "u"; break;
+      case kDLFloat: doc << "f"; break;
+    }
+    doc << Doc::Text(std::to_string(dtype.bits()));
+    if (dtype.lanes() != 1) doc << "x" << Doc::Text(std::to_string(dtype.lanes()));
+  }
+  return doc;
+}
+
+Doc TIRTextPrinter::GetUniqueName(std::string prefix) {
+  // std::replace(prefix.begin(), prefix.end(), '.', '_');
+  std::string unique_prefix = prefix;
+  auto it = name_alloc_map_.find(prefix);
+  if (it != name_alloc_map_.end()) {
+    while (name_alloc_map_.count(
+        unique_prefix = prefix + "_" + std::to_string(++it->second)) > 0) {}
+  }
+  name_alloc_map_[unique_prefix] = 0;
+  return Doc::Text(unique_prefix);
+}
+
+Doc TIRTextPrinter::AllocVar(const Var& var) {
+  const auto& it = memo_var_.find(var);
+  if (it != memo_var_.end()) {
+    return it->second;
+  }
+  std::string name = var->name_hint;
+  if (name.length() == 0 || !std::isalpha(name[0])) {
+    name = "v" + name;
+  }
+  Doc val = GetUniqueName(name);
+  memo_var_[var] = val;
+  return val << ": " << Print(GetType(var));
+}
+
+Doc TIRTextPrinter::AllocBuf(const Buffer& buffer) {
+  const auto& it = memo_buf_.find(buffer);
+  if (it != memo_buf_.end()) {
+    return it->second;
+  }
+  std::string name = buffer->name;
+  if (name.length() == 0 || !std::isalpha(name[0])) {
+    name = "buf_" + name;
+  }
+  Doc val = GetUniqueName(name);
+  memo_buf_[buffer] = val;
+  return val;
+}
+
+Doc TIRTextPrinter::PrintSep(const std::vector<Doc>& vec, const Doc& sep) {
+  Doc seq;
+  if (vec.size() != 0) {
+    seq = vec[0];
+    for (size_t i = 1; i < vec.size(); i++) {
+      seq << sep << vec[i];
+    }
+  }
+  return seq;
+}
+
+Doc TIRTextPrinter::PrintBody(const Stmt& body, bool indent) {
+  Doc doc;
+  if (body->IsInstance<SeqStmtNode>()) return Print(body);
+  doc << " {" << Doc::Indent(2, Doc::NewLine() << Print(body)) << Doc::NewLine() << "}";
+  return doc;
+}
+
+}  // namespace tir
+}  // namespace tvm
index 6efb67b..372f0e9 100644 (file)
@@ -64,14 +64,14 @@ def test_deduce():
 
     e2 = (tvm.te.max(5, a * 4) < 0)
     res2 = tvm.arith.deduce_bound(a, e2, {b: b_s, c: c_s, d: d_s}, {})
-    assert str(res2.max_value) == "neg_inf"
-    assert str(res2.min_value) == "pos_inf"
+    assert str(res2.max_value) == "neg_inf: handle"
+    assert str(res2.min_value) == "pos_inf: handle"
 
     # expression containing variable a is on rhs
     e2 = (zero < tvm.te.max(5, a * 4))
     res2 = tvm.arith.deduce_bound(a, e2, {b: b_s, c: c_s, d: d_s}, {})
-    assert str(res2.max_value) == "neg_inf"
-    assert str(res2.min_value) == "pos_inf"
+    assert str(res2.max_value) == "neg_inf: handle"
+    assert str(res2.min_value) == "pos_inf: handle"
 
     e3 = (-b)+a*c-d
     res3 = tvm.arith.deduce_bound(a, e3>=0, {b: b_s, c: c_s, d: d_s}, {b: b_s, d: d_s})
@@ -88,8 +88,8 @@ def test_deduce():
 
     # Unsatisfiable `EQ`, variable as one of the Operand
     res5 = tvm.arith.deduce_bound(a, (a == b), {b: b_s}, {b: b_s})
-    assert str(res5.max_value) == "neg_inf"
-    assert str(res5.min_value) == "pos_inf"
+    assert str(res5.max_value) == "neg_inf: handle"
+    assert str(res5.min_value) == "pos_inf: handle"
 
     # variable `a` on the RHS side
     res6 = tvm.arith.deduce_bound(a, 10 == a, {}, {})
@@ -111,13 +111,13 @@ def test_deduce():
     # Unsatisfiable Mul in `EQ`
     e5 = (4 * a == b)
     res9 = tvm.arith.deduce_bound(a, e5, {b: b_s}, {})
-    assert str(res9.max_value) == "neg_inf"
-    assert str(res9.min_value) == "pos_inf"
+    assert str(res9.max_value) == "neg_inf: handle"
+    assert str(res9.min_value) == "pos_inf: handle"
 
     # Unsatisfiable Mul in `EQ`
     res10 = tvm.arith.deduce_bound(a, (b * a == b), {b: b_s}, {})    # simplifier is not able to prove that (b % b == 0)
-    assert str(res10.max_value) == "neg_inf"
-    assert str(res10.min_value) == "pos_inf"
+    assert str(res10.max_value) == "neg_inf: handle"
+    assert str(res10.min_value) == "pos_inf: handle"
 
 
 def test_check():
index 9e4d45e..9b8d406 100644 (file)
@@ -286,8 +286,8 @@ def test_tensor_intrin_scalar_params():
     stmt = tvm.lower(s, [A, C])["main"].body
     assert isinstance(stmt.body.body, tvm.tir.Evaluate)
     assert len(stmt.body.body.value.args) == 5
-    assert str(stmt.body.body.value.args[3]) == "(i*i)"
-    assert str(stmt.body.body.value.args[4]) == "(i + j)"
+    assert str(stmt.body.body.value.args[3]) == "(i: int32*i)"
+    assert str(stmt.body.body.value.args[4]) == "(i: int32 + j: int32)"
 
 if __name__ == "__main__":
     test_singleton()
index 468ab1d..36c9c76 100644 (file)
@@ -103,7 +103,7 @@ def test_basic():
     a = te.var('a')
     b = te.var('b')
     c =  a + b
-    assert str(c) == '(%s + %s)' % (a.name, b.name)
+    assert str(c) == '(%s: int32 + %s: int32)' % (a.name, b.name)
 
 
 def test_stmt():
@@ -138,11 +138,11 @@ def test_any():
         assert False
     except ValueError:
         pass
-    assert str(tvm.tir.any(x < y)) == '(%s < %s)' % (x.name, y.name)
-    assert str(tvm.tir.any(x < y, x > z)) == '((%s < %s) || (%s > %s))' % (
+    assert str(tvm.tir.any(x < y)) == '(%s: int32 < %s: int32)' % (x.name, y.name)
+    assert str(tvm.tir.any(x < y, x > z)) == '((%s: int32 < %s: int32) || (%s > %s: int32))' % (
         x.name, y.name, x.name, z.name)
     assert str(tvm.tir.any(x < y, y > z + 1, x < z * 2)) == \
-        '(((%s < %s) || (%s > (%s + 1))) || (%s < (%s*2)))' % (
+        '(((%s: int32 < %s: int32) || (%s > (%s: int32 + 1))) || (%s < (%s*2)))' % (
             x.name, y.name, y.name, z.name, x.name, z.name)
 
 
@@ -160,29 +160,29 @@ def test_all():
         assert False
     except ValueError:
         pass
-    assert str(tvm.tir.all(x < y)) == '(%s < %s)' % (x.name, y.name)
-    assert str(tvm.tir.all(x < y, x > z)) == '((%s < %s) && (%s > %s))' % (
+    assert str(tvm.tir.all(x < y)) == '(%s: int32 < %s: int32)' % (x.name, y.name)
+    assert str(tvm.tir.all(x < y, x > z)) == '((%s: int32 < %s: int32) && (%s > %s: int32))' % (
         x.name, y.name, x.name, z.name)
     assert str(tvm.tir.all(x < y, y > z + 1, x < z * 2)) == \
-        '(((%s < %s) && (%s > (%s + 1))) && (%s < (%s*2)))' % (
+        '(((%s: int32 < %s: int32) && (%s > (%s: int32 + 1))) && (%s < (%s*2)))' % (
             x.name, y.name, y.name, z.name, x.name, z.name)
 
 
 def test_bitwise():
     x = te.var('x')
     y = te.var('y')
-    assert str(x << y) == 'shift_left(x, y)'
-    assert str(x >> y) == 'shift_right(x, y)'
-    assert str(x & y) == 'bitwise_and(x, y)'
-    assert str(x | y) == 'bitwise_or(x, y)'
-    assert str(x ^ y) == 'bitwise_xor(x, y)'
-    assert str(10 & x) == 'bitwise_and(10, x)'
-    assert str(10 | x) == 'bitwise_or(10, x)'
-    assert str(10 ^ x) == 'bitwise_xor(10, x)'
-    assert str(10 >> x) == 'shift_right(10, x)'
-    assert str(10 << x) == 'shift_left(10, x)'
-    assert str(10 % x) == 'floormod(10, x)'
-    assert str(~x) == 'bitwise_not(x)'
+    assert str(x << y) == '@shift_left(x: int32, y: int32, dtype=int32, type="pure_intrin", index=0)'
+    assert str(x >> y) == '@shift_right(x: int32, y: int32, dtype=int32, type="pure_intrin", index=0)'
+    assert str(x & y) == '@bitwise_and(x: int32, y: int32, dtype=int32, type="pure_intrin", index=0)'
+    assert str(x | y) == '@bitwise_or(x: int32, y: int32, dtype=int32, type="pure_intrin", index=0)'
+    assert str(x ^ y) == '@bitwise_xor(x: int32, y: int32, dtype=int32, type="pure_intrin", index=0)'
+    assert str(10 & x) == '@bitwise_and(10, x: int32, dtype=int32, type="pure_intrin", index=0)'
+    assert str(10 | x) == '@bitwise_or(10, x: int32, dtype=int32, type="pure_intrin", index=0)'
+    assert str(10 ^ x) == '@bitwise_xor(10, x: int32, dtype=int32, type="pure_intrin", index=0)'
+    assert str(10 >> x) == '@shift_right(10, x: int32, dtype=int32, type="pure_intrin", index=0)'
+    assert str(10 << x) == '@shift_left(10, x: int32, dtype=int32, type="pure_intrin", index=0)'
+    assert str(10 % x) == 'floormod(10, x: int32)'
+    assert str(~x) == '@bitwise_not(x: int32, dtype=int32, type="pure_intrin", index=0)'
     assert(tvm.tir.const(1, "int8x2") >> 1).dtype == "int8x2"
     assert(x >> tvm.tir.const(1, "int32x2")).dtype == "int32x2"
     assert(te.var("z", "int8x2") << tvm.tir.const(1, "int8x2")).dtype == "int8x2"
@@ -239,12 +239,12 @@ def test_divide_by_zero():
 
 def test_isnan():
     x = te.var('x', 'float32')
-    assert str(tvm.tir.isnan(x)) == 'isnan(x)'
+    assert str(tvm.tir.isnan(x)) == '@isnan(x: float32, dtype=bool, type="pure_intrin", index=0)'
     assert str(tvm.tir.isnan(x).dtype) == 'bool'
     y = te.var('y', 'float16')
-    assert str(tvm.tir.isnan(y)) == 'isnan(float32(y))'
+    assert str(tvm.tir.isnan(y)) == '@isnan(cast(float32, y: float16), dtype=bool, type="pure_intrin", index=0)'
     z = te.var('z', 'int32')
-    assert str(tvm.tir.isnan(z)) == '(bool)0'
+    assert str(tvm.tir.isnan(z)) == 'False'
     k = te.var('k', 'int8x2')
     assert str(tvm.tir.isnan(k).dtype) == 'uint1x2'