* \return The text representation.
*/
TVM_DLL String AsText(const ObjectRef& node,
- bool show_meta_data = true,
- runtime::TypedPackedFunc<String(ObjectRef)> annotate = nullptr);
+ bool show_meta_data = true,
+ runtime::TypedPackedFunc<String(ObjectRef)> annotate = nullptr);
} // namespace tvm
#endif // TVM_IR_MODULE_H_
}
/*!
+ * \brief Test whether a node has been put in meta
+ * \param node The query node
+ * \return whether the node has been put in meta
+ */
+ bool InMeta(const ObjectRef& node) {
+ return meta_repr_.find(node) != meta_repr_.end();
+ }
+
+ /*!
* \brief Print a key value pair
*/
Doc PrintKeyValue(const std::string& str, const Doc& v) const {
*/
/*!
- * \file text_format_printer.cc
+ * \file relay_text_printer.cc
* \brief Printer to print out the IR text format
* that can be parsed by a parser.
*
#include "meta_data.h"
#include "../relay/analysis/dependency_graph.h"
#include "../ir/attr_functor.h"
+#include "text_printer.h"
namespace tvm {
namespace relay {
-class RelayTextPrinter :
- public ExprFunctor<Doc(const Expr&)>,
- public PatternFunctor<Doc(const Pattern&)>,
- public TypeFunctor<Doc(const Type&)>,
- public AttrFunctor<Doc(const ObjectRef&)> {
- public:
- explicit RelayTextPrinter(bool show_meta_data,
- runtime::TypedPackedFunc<std::string(ObjectRef)> annotate)
- : show_meta_data_(show_meta_data),
- annotate_(annotate) {}
-
- /*!
- * \brief Print additional info about expr in comment.
- * \param expr The expression.
- */
- Doc PrintOptionalInfo(const Expr& expr) {
- Doc doc;
- // default annotations
- if (annotate_ == nullptr) {
- if ((expr.as<ConstantNode>() || expr.as<CallNode>()) && expr->checked_type_.defined()) {
- doc << " /* ty=" << Print(expr->checked_type()) << " */";
- }
- } else {
- std::string annotated_expr = annotate_(expr);
- if (annotated_expr != "") {
- doc << annotated_expr;
- }
+/*!
+ * \brief Print additional info about expr in comment.
+ * \param expr The expression.
+ */
+Doc RelayTextPrinter::PrintOptionalInfo(const Expr& expr) {
+ Doc doc;
+ // default annotations
+ if (annotate_ == nullptr) {
+ if ((expr.as<ConstantNode>() || expr.as<CallNode>()) && expr->checked_type_.defined()) {
+ doc << " /* ty=" << Print(expr->checked_type()) << " */";
+ }
+ } else {
+ std::string annotated_expr = annotate_(expr);
+ if (annotated_expr != "") {
+ doc << annotated_expr;
}
-
- return doc;
}
- // indent a new body
- Doc PrintBody(const ObjectRef& node, int indent = 2) {
- Doc doc;
- Doc body;
- doc << "{";
- doc << Doc::Indent(indent, body << Doc::NewLine() << PrintScope(node)) << Doc::NewLine();
- doc << "}";
- return doc;
- }
+ return doc;
+}
- // create a new scope by creating a new printer object. This allows temp var
- // numbers to be reused and prevents hoisted vars from escaping too far
- Doc PrintScope(const ObjectRef& node) {
- // print in a new scope
- doc_stack_.push_back(Doc());
- // must print first so doc_stack_.back() reference doesn't become stale
- Doc doc = Print(node, false, true);
- doc = doc_stack_.back() << doc;
- doc_stack_.pop_back();
- return doc;
+// indent a new body
+Doc RelayTextPrinter::PrintBody(const ObjectRef& node, int indent) {
+ Doc doc;
+ Doc body;
+ doc << "{";
+ doc << Doc::Indent(indent, body << Doc::NewLine() << PrintScope(node)) << Doc::NewLine();
+ doc << "}";
+ return doc;
+}
+
+// create a new scope by creating a new printer object. This allows temp var
+// numbers to be reused and prevents hoisted vars from escaping too far
+Doc RelayTextPrinter::PrintScope(const ObjectRef& node) {
+ // print in a new scope
+ doc_stack_.push_back(Doc());
+ // must print first so doc_stack_.back() reference doesn't become stale
+ Doc doc = Print(node, false, true);
+ doc = doc_stack_.back() << doc;
+ doc_stack_.pop_back();
+ return doc;
+}
+
+Doc RelayTextPrinter::PrintFinal(const ObjectRef& node) {
+ if (node->IsInstance<BaseFuncNode>() &&
+ !node->IsInstance<relay::FunctionNode>()) {
+ // Temporarily skip non-relay functions.
+ // TODO(tvm-team) enhance the code to work for all functions
+ } else if (node.as<ExprNode>()) {
+ Expr expr = Downcast<Expr>(node);
+ dg_ = DependencyGraph::Create(&arena_, expr);
}
- Doc PrintFinal(const ObjectRef& node) {
- if (node->IsInstance<BaseFuncNode>() &&
- !node->IsInstance<relay::FunctionNode>()) {
- // Temporarily skip non-relay functions.
- // TODO(tvm-team) enhance the code to work for all functions
- } else if (node.as<ExprNode>()) {
- Expr expr = Downcast<Expr>(node);
- dg_ = DependencyGraph::Create(&arena_, expr);
- }
+ Doc doc;
+ doc << PrintScope(node);
+ return doc;
+}
- Doc doc;
- doc << PrintScope(node);
- if (!meta_.empty()) {
- doc << Doc::NewLine();
- if (show_meta_data_) {
- // append meta data in the end.
- doc << "METADATA:" << Doc::NewLine() << meta_.GetMetaSection();
- } else {
- doc << "// meta data omitted. you can use show_meta_data=True to include meta data";
- }
- }
- return doc;
+Doc RelayTextPrinter::Print(const ObjectRef& node, bool meta, bool try_inline) {
+ bool is_non_relay_func =
+ node->IsInstance<BaseFuncNode>() &&
+ !node->IsInstance<relay::FunctionNode>();
+ if (node.as<ExprNode>() && !is_non_relay_func) {
+ return PrintExpr(Downcast<Expr>(node), meta, try_inline);
+ } else if (node.as<TypeNode>()) {
+ return PrintType(Downcast<Type>(node), meta);
+ } else if (node.as<PatternNode>()) {
+ return PrintPattern(Downcast<Pattern>(node), meta);
+ } else if (node.as<IRModuleNode>()) {
+ return PrintMod(Downcast<IRModule>(node));
+ } else {
+ // default module.
+ std::ostringstream os;
+ os << node;
+ return Doc::RawText(os.str());
}
+}
- std::vector<Doc> PrintCallAttrs(const Attrs& attrs, const Expr& op);
- std::vector<Doc> PrintFuncAttrs(const Attrs& attrs);
-
- Doc Print(const ObjectRef& node, bool meta = false, bool try_inline = false) {
- bool is_non_relay_func =
- node->IsInstance<BaseFuncNode>() &&
- !node->IsInstance<relay::FunctionNode>();
- if (node.as<ExprNode>() && !is_non_relay_func) {
- return PrintExpr(Downcast<Expr>(node), meta, try_inline);
- } else if (node.as<TypeNode>()) {
- return PrintType(Downcast<Type>(node), meta);
- } else if (node.as<PatternNode>()) {
- return PrintPattern(Downcast<Pattern>(node), meta);
- } else if (node.as<IRModuleNode>()) {
- return PrintMod(Downcast<IRModule>(node));
- } else {
- // default module.
+Doc RelayTextPrinter::TempVar(int n) {
+ Doc doc;
+ return doc << "%" << n;
+}
+
+Doc RelayTextPrinter::AllocTemp() {
+ return TempVar(temp_var_counter_++);
+}
+
+/*!
+ * \brief get a unique name with the corresponding prefix
+ * \param prefix The prefix of the name
+ * \return The returned name.
+ */
+Doc RelayTextPrinter::GetUniqueName(const std::string& prefix) {
+ std::string unique_prefix = prefix;
+ auto it = name_alloc_map_.find(prefix);
+ if (it != name_alloc_map_.end()) {
+ while (true) {
std::ostringstream os;
- os << node;
- return Doc::RawText(os.str());
+ os << prefix << (++it->second);
+ std::string name = os.str();
+ if (name_alloc_map_.count(name) == 0) {
+ unique_prefix = name;
+ break;
+ }
}
}
+ name_alloc_map_[unique_prefix] = 0;
+ return Doc::Text(unique_prefix);
+}
- Doc TempVar(int n) {
- Doc doc;
- return doc << "%" << n;
- }
-
- Doc AllocTemp() {
- return TempVar(temp_var_counter_++);
- }
-
- /*!
- * \brief get a unique name with the corresponding prefix
- * \param prefix The prefix of the name
- * \return The returned name.
- */
- Doc GetUniqueName(const std::string& prefix) {
- std::string unique_prefix = prefix;
- auto it = name_alloc_map_.find(prefix);
- if (it != name_alloc_map_.end()) {
- while (true) {
- std::ostringstream os;
- os << prefix << (++it->second);
- std::string name = os.str();
- if (name_alloc_map_.count(name) == 0) {
- unique_prefix = name;
- break;
- }
- }
- }
- name_alloc_map_[unique_prefix] = 0;
- return Doc::Text(unique_prefix);
- }
-
- Doc Print(Kind k) {
- switch (k) {
- case kType:
- return Doc::Text("Type");
- case kShapeVar:
- return Doc::Text("Shape");
- case kBaseType:
- return Doc::Text("BaseType");
- case kConstraint:
- return Doc::Text("Constraint");
- case kAdtHandle:
- return Doc::Text("AdtHandle");
- case kTypeData:
- return Doc::Text("TypeData");
- default:
- LOG(ERROR) << "Unknown Kind";
- throw;
- }
+Doc RelayTextPrinter::Print(Kind k) {
+ switch (k) {
+ case kType:
+ return Doc::Text("Type");
+ case kShapeVar:
+ return Doc::Text("Shape");
+ case kBaseType:
+ return Doc::Text("BaseType");
+ case kConstraint:
+ return Doc::Text("Constraint");
+ case kAdtHandle:
+ return Doc::Text("AdtHandle");
+ case kTypeData:
+ return Doc::Text("TypeData");
+ default:
+ LOG(ERROR) << "Unknown Kind";
+ throw;
}
- /*!
- * \brief Allocate name to a type variable.
- * \param var The input type variable.
- * \return The corresponding name.
- */
- Doc AllocTypeVar(const TypeVar& var) {
- if (memo_type_.count(var)) {
- Doc val = memo_type_[var];
- val << "-malformed-ir";
- return val;
- }
- std::string name = var->name_hint;
- if (name.length() == 0 || !std::isalpha(name[0])) {
- name = "t" + name;
- }
- Doc val = GetUniqueName(name);
- memo_type_[var] = val;
- if (var->kind != kType) {
- val << ": " << Print(var->kind);
- }
+}
+/*!
+ * \brief Allocate name to a type variable.
+ * \param var The input type variable.
+ * \return The corresponding name.
+ */
+Doc RelayTextPrinter::AllocTypeVar(const TypeVar& var) {
+ if (memo_type_.count(var)) {
+ Doc val = memo_type_[var];
+ val << "-malformed-ir";
return val;
}
+ std::string name = var->name_hint;
+ if (name.length() == 0 || !std::isalpha(name[0])) {
+ name = "t" + name;
+ }
+ Doc val = GetUniqueName(name);
+ memo_type_[var] = val;
+ if (var->kind != kType) {
+ val << ": " << Print(var->kind);
+ }
+ return val;
+}
- /*!
- * \brief Allocate name to a variable.
- * \param var The input variable.
- * \return The corresponding name.
- */
- Doc AllocVar(const Var& var) {
- // still print if ir is malformed, but show the error.
- if (memo_.count(var)) {
- Doc val = memo_[var];
- val << "-malformed-ir";
- return val;
- }
- std::string name = var->name_hint();
- // always make sure first name is alpha
- if (name.length() == 0 || !std::isalpha(name[0])) {
- name = "v" + name;
- }
- Doc val = GetUniqueName("%" + name);
- memo_[var] = val;
- if (var->type_annotation.defined()) {
- val << ": " << Print(var->type_annotation);
- }
+/*!
+ * \brief Allocate name to a variable.
+ * \param var The input variable.
+ * \return The corresponding name.
+ */
+Doc RelayTextPrinter::AllocVar(const Var& var) {
+ // still print if ir is malformed, but show the error.
+ if (memo_.count(var)) {
+ Doc val = memo_[var];
+ val << "-malformed-ir";
return val;
}
-
- bool IsUnique(const Expr& expr) {
- auto it = dg_.expr_node.find(expr);
- if (it == dg_.expr_node.end()) {
- return true;
- } else {
- return !(it->second->parents.head && it->second->parents.head->next);
- }
+ std::string name = var->name_hint();
+ // always make sure first name is alpha
+ if (name.length() == 0 || !std::isalpha(name[0])) {
+ name = "v" + name;
}
+ Doc val = GetUniqueName("%" + name);
+ memo_[var] = val;
+ if (var->type_annotation.defined()) {
+ val << ": " << Print(var->type_annotation);
+ }
+ return val;
+}
- bool AlwaysInline(const Expr& expr) {
- return expr.as<GlobalVarNode>() || expr.as<ConstantNode>() || expr.as<OpNode>() ||
- expr.as<VarNode>() || expr.as<ConstructorNode>();
+bool RelayTextPrinter::IsUnique(const Expr& expr) {
+ auto it = dg_.expr_node.find(expr);
+ if (it == dg_.expr_node.end()) {
+ return true;
+ } else {
+ return !(it->second->parents.head && it->second->parents.head->next);
}
+}
- //------------------------------------
- // Overload of Expr printing functions
- //------------------------------------
- Doc PrintExpr(const Expr& expr, bool meta, bool try_inline) {
- // Exploit memoization to print GNF.
- // The first time we visit an expression, we need to allocate a temp var
- // for it. Every subsequent time we can just use its assigned variable.
- // This works since hashing uses pointer equality.
+bool RelayTextPrinter::AlwaysInline(const Expr& expr) {
+ return expr.as<GlobalVarNode>() || expr.as<ConstantNode>() || expr.as<OpNode>() ||
+ expr.as<VarNode>() || expr.as<ConstructorNode>();
+}
- // determine whether to inline
- bool inline_expr = AlwaysInline(expr);
- if (try_inline) {
- inline_expr |= IsUnique(expr);
- }
+//------------------------------------
+// Overload of Expr printing functions
+//------------------------------------
+Doc RelayTextPrinter::PrintExpr(const Expr& expr, bool meta, bool try_inline) {
+ // Exploit memoization to print GNF.
+ // The first time we visit an expression, we need to allocate a temp var
+ // for it. Every subsequent time we can just use its assigned variable.
+ // This works since hashing uses pointer equality.
+
+ // determine whether to inline
+ bool inline_expr = AlwaysInline(expr);
+ if (try_inline) {
+ inline_expr |= IsUnique(expr);
+ }
+
+ auto it = memo_.find(expr);
+ if (it != memo_.end()) return it->second;
+
+ Doc printed_expr;
+ if (meta) {
+ printed_expr = meta_->GetMetaNode(GetRef<ObjectRef>(expr.get()));
+ } else if (!inline_expr && expr.as<LetNode>()) {
+ // wrap GNFed let in brackets
+ Doc body;
+ printed_expr << "(";
+ printed_expr << Doc::Indent(2, body << Doc::NewLine() << VisitExpr(expr)) << Doc::NewLine();
+ printed_expr << ")";
+ } else {
+ printed_expr = VisitExpr(expr);
+ }
- auto it = memo_.find(expr);
- if (it != memo_.end()) return it->second;
-
- Doc printed_expr;
- if (meta) {
- printed_expr = meta_.GetMetaNode(GetRef<ObjectRef>(expr.get()));
- } else if (!inline_expr && expr.as<LetNode>()) {
- // wrap GNFed let in brackets
- Doc body;
- printed_expr << "(";
- printed_expr << Doc::Indent(2, body << Doc::NewLine() << VisitExpr(expr)) << Doc::NewLine();
- printed_expr << ")";
- } else {
- printed_expr = VisitExpr(expr);
- }
+ printed_expr << PrintOptionalInfo(expr);
- printed_expr << PrintOptionalInfo(expr);
-
- // add expr to doc
- if (expr.as<VarNode>()) {
- // This is our first time visiting the var and we hit the VarNode case
- // in the visitor. Thus the variable is free.
- doc_stack_.back() << "free_var " << printed_expr << Doc::NewLine();
- // Memoization is done in AllocVar.
- return memo_[expr];
- } else if (inline_expr) {
- memo_[expr] = printed_expr;
- return printed_expr;
- } else {
- Doc temp_var = AllocTemp();
- memo_[expr] = temp_var;
- doc_stack_.back() << temp_var << " = " << printed_expr << ";" << Doc::NewLine();
- return temp_var;
- }
+ // add expr to doc
+ if (expr.as<VarNode>()) {
+ // This is our first time visiting the var and we hit the VarNode case
+ // in the visitor. Thus the variable is free.
+ doc_stack_.back() << "free_var " << printed_expr << Doc::NewLine();
+ // Memoization is done in AllocVar.
+ return memo_[expr];
+ } else if (inline_expr) {
+ memo_[expr] = printed_expr;
+ return printed_expr;
+ } else {
+ Doc temp_var = AllocTemp();
+ memo_[expr] = temp_var;
+ doc_stack_.back() << temp_var << " = " << printed_expr << ";" << Doc::NewLine();
+ return temp_var;
}
+}
+
+// Should only be triggered when op is a free variable being visited for the
+// first time.
+Doc RelayTextPrinter::VisitExpr_(const VarNode* op) {
+ return AllocVar(GetRef<Var>(op));
+}
- // Should only be triggered when op is a free variable being visited for the
- // first time.
- Doc VisitExpr_(const VarNode* op) final {
- return AllocVar(GetRef<Var>(op));
+/*!
+ * \brief special method to print out const scalar
+ * \param dtype The data type
+ * \param value The value to be printed.
+ */
+template<typename T>
+Doc RelayTextPrinter::ScalarLiteral(DataType dtype, const T& value) {
+ std::ostringstream os;
+ if (dtype == DataType::Int(32)) {
+ os << value;
+ } else if (dtype == DataType::Float(32)) {
+ os << value << 'f';
+ } else if (dtype == DataType::Float(64)) {
+ os << value;
+ } else if (dtype == DataType::Bool()) {
+ return Doc::PyBoolLiteral(value != 0);
+ } else {
+ os << value;
}
+ return Doc::Text(os.str());
+}
- /*!
- * \brief special method to print out const scalar
- * \param dtype The data type
- * \param value The value to be printed.
- */
- template<typename T>
- static Doc ScalarLiteral(DataType dtype, const T& value) {
+Doc RelayTextPrinter::VisitExpr_(const ConstantNode* op) {
+ // Print out simple scalars directly.
+ if (op->is_scalar()) {
std::ostringstream os;
+ DataType dtype = DataType(op->data->dtype);
+ CHECK_EQ(op->data->ctx.device_type, kDLCPU);
if (dtype == DataType::Int(32)) {
- os << value;
+ return ScalarLiteral(dtype, static_cast<const int32_t*>(op->data->data)[0]);
+ } else if (dtype == DataType::Int(64)) {
+ return ScalarLiteral(dtype, static_cast<const int64_t*>(op->data->data)[0]);
} else if (dtype == DataType::Float(32)) {
- os << value << 'f';
+ return ScalarLiteral(dtype, static_cast<const float*>(op->data->data)[0]);
} else if (dtype == DataType::Float(64)) {
- os << value;
+ return ScalarLiteral(dtype, static_cast<const double*>(op->data->data)[0]);
} else if (dtype == DataType::Bool()) {
- return Doc::PyBoolLiteral(value != 0);
- } else {
- os << value;
+ return ScalarLiteral(dtype, static_cast<const uint8_t*>(op->data->data)[0]);
}
- return Doc::Text(os.str());
}
+ // default fall-back, record it as meta node.
+ Doc doc;
+ return doc << Print(GetRef<ObjectRef>(op), true);
+}
- Doc VisitExpr_(const ConstantNode* op) final {
- // Print out simple scalars directly.
- if (op->is_scalar()) {
- std::ostringstream os;
- DataType dtype = DataType(op->data->dtype);
- CHECK_EQ(op->data->ctx.device_type, kDLCPU);
- if (dtype == DataType::Int(32)) {
- return ScalarLiteral(dtype, static_cast<const int32_t*>(op->data->data)[0]);
- } else if (dtype == DataType::Int(64)) {
- return ScalarLiteral(dtype, static_cast<const int64_t*>(op->data->data)[0]);
- } else if (dtype == DataType::Float(32)) {
- return ScalarLiteral(dtype, static_cast<const float*>(op->data->data)[0]);
- } else if (dtype == DataType::Float(64)) {
- return ScalarLiteral(dtype, static_cast<const double*>(op->data->data)[0]);
- } else if (dtype == DataType::Bool()) {
- return ScalarLiteral(dtype, static_cast<const uint8_t*>(op->data->data)[0]);
- }
- }
- // default fall-back, record it as meta node.
- Doc doc;
- return doc << Print(GetRef<ObjectRef>(op), true);
+Doc RelayTextPrinter::VisitExpr_(const TupleNode* op) {
+ std::vector<Doc> fields;
+ for (Expr field : op->fields) {
+ fields.push_back(Print(field));
}
+ Doc doc;
+ doc << "(" << Doc::Concat(fields);
+ // conform to python tuple format (1,)
+ if (op->fields.size() == 1) {
+ doc << ",";
+ }
+ return doc << ")";
+}
- Doc VisitExpr_(const TupleNode* op) final {
- std::vector<Doc> fields;
- for (Expr field : op->fields) {
- fields.push_back(Print(field));
- }
- Doc doc;
- doc << "(" << Doc::Concat(fields);
- // conform to python tuple format (1,)
- if (op->fields.size() == 1) {
- doc << ",";
+Doc RelayTextPrinter::VisitExpr_(const TupleGetItemNode* op) {
+ Doc doc;
+ return doc << Print(op->tuple) << "." << op->index;
+}
+
+Doc RelayTextPrinter::VisitExpr_(const IfNode* op) {
+ Doc doc;
+ doc << "if (" << Print(op->cond) << ") ";
+ doc << PrintBody(op->true_branch);
+ doc << " else ";
+ doc << PrintBody(op->false_branch);
+ return doc;
+}
+
+Doc RelayTextPrinter::VisitExpr_(const LetNode* op) {
+ Doc doc;
+ doc
+ << "let "
+ << AllocVar(op->var)
+ << " = "
+ << Print(op->value, false, true)
+ << ";"
+ << Doc::NewLine();
+ // we use a scope here so GNF hoisting doesn't escape too far
+ // and nested, unique lets are not hoisted
+ doc << PrintScope(op->body);
+ return doc;
+}
+
+Doc RelayTextPrinter::PrintFunc(const Doc& prefix, const relay::Function& fn) {
+ Doc doc;
+ doc << prefix;
+ if (fn->type_params.size() > 0) {
+ doc << "[";
+ std::vector<Doc> type_params;
+ for (const TypeVar& tv : fn->type_params) {
+ type_params.push_back(Doc::Text(tv->name_hint));
}
- return doc << ")";
+ doc << Doc::Concat(type_params);
+ doc << "]";
}
-
- Doc VisitExpr_(const TupleGetItemNode* op) final {
- Doc doc;
- return doc << Print(op->tuple) << "." << op->index;
+ doc << "(";
+ std::vector<Doc> params;
+ for (Var param : fn->params) {
+ params.push_back(AllocVar(param));
}
-
- Doc VisitExpr_(const IfNode* op) final {
- Doc doc;
- doc << "if (" << Print(op->cond) << ") ";
- doc << PrintBody(op->true_branch);
- doc << " else ";
- doc << PrintBody(op->false_branch);
- return doc;
+ for (const Doc& d : PrintFuncAttrs(fn->attrs)) {
+ params.push_back(d);
}
-
- Doc VisitExpr_(const LetNode* op) final {
- Doc doc;
- doc
- << "let "
- << AllocVar(op->var)
- << " = "
- << Print(op->value, false, true)
- << ";"
- << Doc::NewLine();
- // we use a scope here so GNF hoisting doesn't escape too far
- // and nested, unique lets are not hoisted
- doc << PrintScope(op->body);
- return doc;
+ doc << Doc::Concat(params) << ") ";
+ if (fn->ret_type.defined()) {
+ doc << "-> " << Print(fn->ret_type) << " ";
}
+ doc << PrintBody(fn->body);
+ return doc;
+}
- Doc PrintFunc(const Doc& prefix, const relay::Function& fn) {
+Doc RelayTextPrinter::PrintFunc(const Doc& prefix, const BaseFunc& base_func) {
+ if (auto* n = base_func.as<relay::FunctionNode>()) {
+ return PrintFunc(prefix, GetRef<relay::Function>(n));
+ } else if (auto* n = base_func.as<tir::PrimFuncNode>()) {
+ std::ostringstream os;
+ os << GetRef<tir::PrimFunc>(n);
+ return Doc::RawText(os.str());
+ } else {
+ // def @xyz = meta['ExternalFunc'][id]
Doc doc;
- doc << prefix;
- if (fn->type_params.size() > 0) {
- doc << "[";
- std::vector<Doc> type_params;
- for (const TypeVar& tv : fn->type_params) {
- type_params.push_back(Doc::Text(tv->name_hint));
- }
- doc << Doc::Concat(type_params);
- doc << "]";
- }
- doc << "(";
- std::vector<Doc> params;
- for (Var param : fn->params) {
- params.push_back(AllocVar(param));
- }
- for (const Doc& d : PrintFuncAttrs(fn->attrs)) {
- params.push_back(d);
- }
- doc << Doc::Concat(params) << ") ";
- if (fn->ret_type.defined()) {
- doc << "-> " << Print(fn->ret_type) << " ";
- }
- doc << PrintBody(fn->body);
+ doc << prefix << " = " << meta_->GetMetaNode(base_func);
return doc;
}
+}
- Doc PrintFunc(const Doc& prefix, const BaseFunc& base_func) {
- if (auto* n = base_func.as<relay::FunctionNode>()) {
- return PrintFunc(prefix, GetRef<relay::Function>(n));
- } else if (auto* n = base_func.as<tir::PrimFuncNode>()) {
- std::ostringstream os;
- os << GetRef<tir::PrimFunc>(n);
- return Doc::RawText(os.str());
- } else {
- // def @xyz = meta['ExternalFunc'][id]
- Doc doc;
- doc << prefix << " = " << meta_.GetMetaNode(base_func);
- return doc;
+Doc RelayTextPrinter::PrintMod(const IRModule& mod) {
+ Doc doc;
+ int counter = 0;
+ // type definitions
+ for (const auto& kv : mod->type_definitions) {
+ if (counter++ != 0) {
+ doc << Doc::NewLine();
}
+ doc << Print(kv.second);
+ doc << Doc::NewLine();
}
-
- Doc PrintMod(const IRModule& mod) {
- Doc doc;
- int counter = 0;
- // type definitions
- for (const auto& kv : mod->type_definitions) {
- if (counter++ != 0) {
- doc << Doc::NewLine();
- }
- doc << Print(kv.second);
- doc << Doc::NewLine();
+ // functions
+ for (const auto& kv : mod->functions) {
+ if (kv.second.as<relay::FunctionNode>()) {
+ dg_ = DependencyGraph::Create(&arena_, kv.second);
}
- // functions
- for (const auto& kv : mod->functions) {
- if (kv.second.as<relay::FunctionNode>()) {
- dg_ = DependencyGraph::Create(&arena_, kv.second);
- }
- if (counter++ != 0) {
- doc << Doc::NewLine();
- }
- std::ostringstream os;
- os << "def @" << kv.first->name_hint;
- doc << PrintFunc(Doc::Text(os.str()), kv.second);
+ if (counter++ != 0) {
doc << Doc::NewLine();
}
- return doc;
- }
-
- Doc VisitExpr_(const FunctionNode* op) final {
- return PrintFunc(Doc::Text("fn "), GetRef<Function>(op));
+ std::ostringstream os;
+ os << "def @" << kv.first->name_hint;
+ doc << PrintFunc(Doc::Text(os.str()), kv.second);
+ doc << Doc::NewLine();
}
+ return doc;
+}
- Doc VisitExpr_(const GlobalVarNode* op) final {
- return Doc::Text('@' + op->name_hint);
- }
+Doc RelayTextPrinter::VisitExpr_(const FunctionNode* op) {
+ return PrintFunc(Doc::Text("fn "), GetRef<Function>(op));
+}
- Doc VisitExpr_(const OpNode* op) final {
- return Doc::Text(op->name);
- }
+Doc RelayTextPrinter::VisitExpr_(const GlobalVarNode* op) {
+ return Doc::Text('@' + op->name_hint);
+}
- Doc VisitExpr_(const CallNode* op) final {
- Doc doc;
- // visit args first so they are lifted before the op
- // this places op closer to its call site
- std::vector<Doc> args;
- for (const Expr& arg : op->args) {
- args.push_back(Print(arg));
- }
- for (const Doc& d : PrintCallAttrs(op->attrs, op->op)) {
- args.push_back(d);
- }
- const auto* cons_node = op->op.as<ConstructorNode>();
- if (cons_node) {
- doc << cons_node->name_hint;
- } else {
- doc << Print(op->op);
- }
+Doc RelayTextPrinter::VisitExpr_(const OpNode* op) {
+ return Doc::Text(op->name);
+}
- if (cons_node && cons_node->inputs.size() == 0) {
- // don't print as a call if it's a 0-arity cons
- return doc;
- } else {
- return doc << "(" << Doc::Concat(args) << ")";
- }
+Doc RelayTextPrinter::VisitExpr_(const CallNode* op) {
+ Doc doc;
+ // visit args first so they are lifted before the op
+ // this places op closer to its call site
+ std::vector<Doc> args;
+ for (const Expr& arg : op->args) {
+ args.push_back(Print(arg));
+ }
+ for (const Doc& d : PrintCallAttrs(op->attrs, op->op)) {
+ args.push_back(d);
+ }
+ const auto* cons_node = op->op.as<ConstructorNode>();
+ if (cons_node) {
+ doc << cons_node->name_hint;
+ } else {
+ doc << Print(op->op);
}
- Doc VisitExpr_(const RefCreateNode* op) final {
- Doc doc;
- return doc << "ref(" << Print(op->value) << ")";
+ if (cons_node && cons_node->inputs.size() == 0) {
+ // don't print as a call if it's a 0-arity cons
+ return doc;
+ } else {
+ return doc << "(" << Doc::Concat(args) << ")";
}
+}
- Doc VisitExpr_(const RefReadNode* op) final {
- Doc doc;
- return doc << Print(op->ref) << "^";
- }
+Doc RelayTextPrinter::VisitExpr_(const RefCreateNode* op) {
+ Doc doc;
+ return doc << "ref(" << Print(op->value) << ")";
+}
- Doc VisitExpr_(const RefWriteNode* op) final {
- Doc doc;
- return doc << "(" << Print(op->ref) << " := " << Print(op->value) << ")";
- }
+Doc RelayTextPrinter::VisitExpr_(const RefReadNode* op) {
+ Doc doc;
+ return doc << Print(op->ref) << "^";
+}
- Doc VisitExpr_(const MatchNode* op) final {
- // TODO(jmp): Lots of code duplication here because PrintBody and PrintScope don't accept Docs.
- Doc doc;
- Doc body;
- doc << "match";
- if (!op->complete) {
- doc << "?";
- }
- doc << " (" << Print(op->data) << ") {";
- std::vector<Doc> clause_docs;
- for (const auto& clause : op->clauses) {
- Doc clause_doc;
- clause_doc << PrintPattern(clause->lhs, false) << " => ";
- Doc rhs_doc = PrintScope(clause->rhs);
- if (clause->rhs.as<LetNode>()) {
- // only add braces if there are multiple lines on the rhs
- rhs_doc = Doc::Brace("{", rhs_doc, "}");
- }
- clause_doc << rhs_doc << ",";
- clause_docs.push_back(clause_doc);
- }
- doc << Doc::Indent(2, body << Doc::NewLine() << Doc::Concat(clause_docs, Doc::NewLine()))
- << Doc::NewLine() << "}";
- return doc;
- }
+Doc RelayTextPrinter::VisitExpr_(const RefWriteNode* op) {
+ Doc doc;
+ return doc << "(" << Print(op->ref) << " := " << Print(op->value) << ")";
+}
- Doc PrintPattern(const Pattern& pattern, bool meta) {
- auto it = memo_pattern_.find(pattern);
- if (it != memo_pattern_.end()) return it->second;
- Doc printed_pattern;
- if (meta) {
- printed_pattern = meta_.GetMetaNode(GetRef<ObjectRef>(pattern.get()));
- } else {
- printed_pattern = VisitPattern(pattern);
- }
- memo_pattern_[pattern] = printed_pattern;
- return printed_pattern;
- }
+Doc RelayTextPrinter::VisitExpr_(const MatchNode* op) {
+ // TODO(jmp): Lots of code duplication here because PrintBody and PrintScope don't accept Docs.
+ Doc doc;
+ Doc body;
+ doc << "match";
+ if (!op->complete) {
+ doc << "?";
+ }
+ doc << " (" << Print(op->data) << ") {";
+ std::vector<Doc> clause_docs;
+ for (const auto& clause : op->clauses) {
+ Doc clause_doc;
+ clause_doc << PrintPattern(clause->lhs, false) << " => ";
+ Doc rhs_doc = PrintScope(clause->rhs);
+ if (clause->rhs.as<LetNode>()) {
+ // only add braces if there are multiple lines on the rhs
+ rhs_doc = Doc::Brace("{", rhs_doc, "}");
+ }
+ clause_doc << rhs_doc << ",";
+ clause_docs.push_back(clause_doc);
+ }
+ doc << Doc::Indent(2, body << Doc::NewLine() << Doc::Concat(clause_docs, Doc::NewLine()))
+ << Doc::NewLine() << "}";
+ return doc;
+}
- Doc VisitPattern_(const PatternConstructorNode* p) final {
- Doc doc;
- doc << p->constructor->name_hint;
- if (!p->patterns.empty()) {
- doc << "(";
- std::vector<Doc> pats;
- for (const auto& pat : p->patterns) {
- pats.push_back(Print(pat));
- }
- doc << Doc::Concat(pats) << ")";
- }
- return doc;
+Doc RelayTextPrinter::PrintPattern(const Pattern& pattern, bool meta) {
+ auto it = memo_pattern_.find(pattern);
+ if (it != memo_pattern_.end()) return it->second;
+ Doc printed_pattern;
+ if (meta) {
+ printed_pattern = meta_->GetMetaNode(GetRef<ObjectRef>(pattern.get()));
+ } else {
+ printed_pattern = VisitPattern(pattern);
}
+ memo_pattern_[pattern] = printed_pattern;
+ return printed_pattern;
+}
- Doc VisitPattern_(const PatternTupleNode* pt) final {
- Doc doc;
+Doc RelayTextPrinter::VisitPattern_(const PatternConstructorNode* p) {
+ Doc doc;
+ doc << p->constructor->name_hint;
+ if (!p->patterns.empty()) {
doc << "(";
std::vector<Doc> pats;
- for (const auto& pat : pt->patterns) {
+ for (const auto& pat : p->patterns) {
pats.push_back(Print(pat));
}
doc << Doc::Concat(pats) << ")";
- return doc;
}
+ return doc;
+}
- Doc VisitPattern_(const PatternWildcardNode* pw) final {
- return Doc::Text("_");
+Doc RelayTextPrinter::VisitPattern_(const PatternTupleNode* pt) {
+ Doc doc;
+ doc << "(";
+ std::vector<Doc> pats;
+ for (const auto& pat : pt->patterns) {
+ pats.push_back(Print(pat));
}
+ doc << Doc::Concat(pats) << ")";
+ return doc;
+}
- Doc VisitPattern_(const PatternVarNode* pv) final {
- return AllocVar(pv->var);
- }
+Doc RelayTextPrinter::VisitPattern_(const PatternWildcardNode* pw) {
+ return Doc::Text("_");
+}
- Doc VisitExpr_(const ConstructorNode* n) final {
- Doc doc;
- doc << n->name_hint;
- if (in_adt_def_ && n->inputs.size() != 0) {
- doc << "(";
- std::vector<Doc> inputs;
- for (Type input : n->inputs) {
- inputs.push_back(Print(input));
- }
- doc << Doc::Concat(inputs) << ")";
- }
- return doc;
- }
+Doc RelayTextPrinter::VisitPattern_(const PatternVarNode* pv) {
+ return AllocVar(pv->var);
+}
- //------------------------------------
- // Overload of Type printing functions
- //------------------------------------
- Doc PrintType(const Type& type, bool meta) {
- auto it = memo_type_.find(type);
- if (it != memo_type_.end()) return it->second;
- Doc printed_type;
- if (meta) {
- printed_type = meta_.GetMetaNode(GetRef<ObjectRef>(type.get()));
- } else {
- printed_type = VisitType(type);
+Doc RelayTextPrinter::VisitExpr_(const ConstructorNode* n) {
+ Doc doc;
+ doc << n->name_hint;
+ if (in_adt_def_ && n->inputs.size() != 0) {
+ doc << "(";
+ std::vector<Doc> inputs;
+ for (Type input : n->inputs) {
+ inputs.push_back(Print(input));
}
- memo_type_[type] = printed_type;
- return printed_type;
+ doc << Doc::Concat(inputs) << ")";
}
+ return doc;
+}
- Doc VisitTypeDefault_(const Object* node) final {
- // by default always print as meta data
- return Print(GetRef<ObjectRef>(node), true);
+//------------------------------------
+// Overload of Type printing functions
+//------------------------------------
+Doc RelayTextPrinter::PrintType(const Type& type, bool meta) {
+ auto it = memo_type_.find(type);
+ if (it != memo_type_.end()) return it->second;
+ Doc printed_type;
+ if (meta) {
+ printed_type = meta_->GetMetaNode(GetRef<ObjectRef>(type.get()));
+ } else {
+ printed_type = VisitType(type);
}
+ memo_type_[type] = printed_type;
+ return printed_type;
+}
- Doc VisitType_(const TypeVarNode* node) final {
- return Doc::Text(node->name_hint);
- }
+Doc RelayTextPrinter::VisitTypeDefault_(const Object* node) {
+ // by default always print as meta data
+ return Print(GetRef<ObjectRef>(node), true);
+}
- Doc VisitType_(const GlobalTypeVarNode* node) final {
- return Doc::Text(node->name_hint);
- }
+Doc RelayTextPrinter::VisitType_(const TypeVarNode* node) {
+ return Doc::Text(node->name_hint);
+}
- Doc VisitType_(const TypeCallNode* node) final {
- Doc doc = PrintType(node->func, false);
- std::vector<Doc> args;
- for (const Type& t : node->args) {
- args.push_back(PrintType(t, false));
- }
- doc << "[";
- doc << Doc::Concat(args);
- doc << "]";
- return doc;
- }
+Doc RelayTextPrinter::VisitType_(const GlobalTypeVarNode* node) {
+ return Doc::Text(node->name_hint);
+}
- Doc PrintDType(DataType dtype) {
- return Doc::Text(runtime::DLDataType2String(dtype));
+Doc RelayTextPrinter::VisitType_(const TypeCallNode* node) {
+ Doc doc = PrintType(node->func, false);
+ std::vector<Doc> args;
+ for (const Type& t : node->args) {
+ args.push_back(PrintType(t, false));
}
+ doc << "[";
+ doc << Doc::Concat(args);
+ doc << "]";
+ return doc;
+}
- Doc VisitType_(const TensorTypeNode* node) final {
- // scalar type
- if (node->shape.size() == 0) {
- return PrintDType(node->dtype);
- }
- Doc doc;
- doc << "Tensor[(";
- std::vector<Doc> shapes;
- for (ObjectRef shape : node->shape) {
- shapes.push_back(PrintAttr(shape));
- }
- doc << Doc::Concat(shapes);
- return doc << "), " << PrintDType(node->dtype) << "]";
+Doc RelayTextPrinter::PrintDType(DataType dtype) {
+ return Doc::Text(runtime::DLDataType2String(dtype));
+}
+
+Doc RelayTextPrinter::VisitType_(const TensorTypeNode* node) {
+ // scalar type
+ if (node->shape.size() == 0) {
+ return PrintDType(node->dtype);
}
+ Doc doc;
+ doc << "Tensor[(";
+ std::vector<Doc> shapes;
+ for (ObjectRef shape : node->shape) {
+ shapes.push_back(PrintAttr(shape));
+ }
+ doc << Doc::Concat(shapes);
+ return doc << "), " << PrintDType(node->dtype) << "]";
+}
- Doc VisitType_(const TupleTypeNode* node) final {
- std::vector<Doc> fields;
- for (Type field : node->fields) {
- fields.push_back(Print(field));
- }
- Doc doc;
- doc << "(" << Doc::Concat(fields);
- // conform to python tuple format (1,)
- if (node->fields.size() == 1) {
- doc << ",";
- }
- return doc << ")";
+Doc RelayTextPrinter::VisitType_(const TupleTypeNode* node) {
+ std::vector<Doc> fields;
+ for (Type field : node->fields) {
+ fields.push_back(Print(field));
+ }
+ Doc doc;
+ doc << "(" << Doc::Concat(fields);
+ // conform to python tuple format (1,)
+ if (node->fields.size() == 1) {
+ doc << ",";
}
+ return doc << ")";
+}
- Doc VisitType_(const FuncTypeNode* node) final {
- Doc doc;
- doc << "fn ";
- if (node->type_params.size() != 0) {
- doc << "[";
- std::vector<Doc> type_params;
- for (Type type_param : node->type_params) {
- type_params.push_back(Print(type_param));
- }
- doc << Doc::Concat(type_params);
- doc << "]";
- }
- std::vector<Doc> arg_types;
- for (Type arg_type : node->arg_types) {
- arg_types.push_back(Print(arg_type));
+Doc RelayTextPrinter::VisitType_(const FuncTypeNode* node) {
+ Doc doc;
+ doc << "fn ";
+ if (node->type_params.size() != 0) {
+ doc << "[";
+ std::vector<Doc> type_params;
+ for (Type type_param : node->type_params) {
+ type_params.push_back(Print(type_param));
}
- return doc << "(" << Doc::Concat(arg_types) << ") -> " << Print(node->ret_type);
+ doc << Doc::Concat(type_params);
+ doc << "]";
}
-
- Doc VisitType_(const RelayRefTypeNode* node) final {
- Doc doc;
- return doc << "ref(" << Print(node->value) << ")";
+ std::vector<Doc> arg_types;
+ for (Type arg_type : node->arg_types) {
+ arg_types.push_back(Print(arg_type));
}
+ return doc << "(" << Doc::Concat(arg_types) << ") -> " << Print(node->ret_type);
+}
- Doc VisitType_(const TypeDataNode* node) final {
- in_adt_def_ = true;
- Doc doc;
- doc << "type " << Print(node->header);
-
- // type vars
- if (node->type_vars.size() != 0) {
- doc << "[";
- std::vector<Doc> type_vars;
- for (Type type_var : node->type_vars) {
- type_vars.push_back(Print(type_var));
- }
- doc << Doc::Concat(type_vars) << "]";
- }
- doc << " ";
+Doc RelayTextPrinter::VisitType_(const RelayRefTypeNode* node) {
+ Doc doc;
+ return doc << "ref(" << Print(node->value) << ")";
+}
- std::vector<Doc> constructor_docs;
- for (Constructor constructor : node->constructors) {
- constructor_docs.push_back(Print(constructor, /* meta */ false, /* try_inline */ true));
- }
- Doc separator;
- separator << "," << Doc::NewLine();
- Doc adt_body;
- adt_body << Doc::Concat(constructor_docs, separator);
- // add trailing comma if there are any constructors
- if (!constructor_docs.empty()) {
- adt_body << ",";
+Doc RelayTextPrinter::VisitType_(const TypeDataNode* node) {
+ in_adt_def_ = true;
+ Doc doc;
+ doc << "type " << Print(node->header);
+
+ // type vars
+ if (node->type_vars.size() != 0) {
+ doc << "[";
+ std::vector<Doc> type_vars;
+ for (Type type_var : node->type_vars) {
+ type_vars.push_back(Print(type_var));
}
- doc << Doc::Brace("{", adt_body, "}");
- in_adt_def_ = false;
- return doc;
+ doc << Doc::Concat(type_vars) << "]";
}
+ doc << " ";
- //------------------------------------
- // Overload of Attr printing functions
- //------------------------------------
+ std::vector<Doc> constructor_docs;
+ for (Constructor constructor : node->constructors) {
+ constructor_docs.push_back(Print(constructor, /* meta */ false, /* try_inline */ true));
+ }
+ Doc separator;
+ separator << "," << Doc::NewLine();
+ Doc adt_body;
+ adt_body << Doc::Concat(constructor_docs, separator);
+ // add trailing comma if there are any constructors
+ if (!constructor_docs.empty()) {
+ adt_body << ",";
+ }
+ doc << Doc::Brace("{", adt_body, "}");
+ in_adt_def_ = false;
+ return doc;
+}
- Doc PrintAttr(const ObjectRef& value, bool meta = false) {
- if (value.defined()) {
- Doc printed_attr;
- if (value.as<tvm::tir::AnyNode>()) {
- printed_attr << "?";
- } else if (meta) {
- printed_attr = meta_.GetMetaNode(Downcast<ObjectRef>(value));
- } else {
- printed_attr = VisitAttr(value);
- }
- return printed_attr;
+//------------------------------------
+// Overload of Attr printing functions
+//------------------------------------
+
+Doc RelayTextPrinter::PrintAttr(const ObjectRef& value, bool meta) {
+ if (value.defined()) {
+ Doc printed_attr;
+ if (value.as<tvm::tir::AnyNode>()) {
+ printed_attr << "?";
+ } else if (meta) {
+ printed_attr = meta_->GetMetaNode(Downcast<ObjectRef>(value));
} else {
- return Doc::Text("None");
+ printed_attr = VisitAttr(value);
}
+ return printed_attr;
+ } else {
+ return Doc::Text("None");
}
+}
- Doc VisitAttrDefault_(const Object* op) final {
- return PrintAttr(GetRef<ObjectRef>(op), true);
- }
+Doc RelayTextPrinter::VisitAttrDefault_(const Object* op) {
+ return PrintAttr(GetRef<ObjectRef>(op), true);
+}
- Doc VisitAttr_(const ArrayNode* op) final {
- Doc doc;
- doc << "[";
- std::vector<Doc> arr_vals;
- for (auto val : op->data) {
- arr_vals.push_back(PrintAttr(val));
- }
- doc << Doc::Concat(arr_vals);
- doc << "]";
- return doc;
- }
+Doc RelayTextPrinter::VisitAttr_(const ArrayNode* op) {
+ Doc doc;
+ doc << "[";
+ std::vector<Doc> arr_vals;
+ for (auto val : op->data) {
+ arr_vals.push_back(PrintAttr(val));
+ }
+ doc << Doc::Concat(arr_vals);
+ doc << "]";
+ return doc;
+}
- Doc VisitAttr_(const tir::IntImmNode* op) final {
- return ScalarLiteral(op->dtype, op->value);
- }
+Doc RelayTextPrinter::VisitAttr_(const tir::IntImmNode* op) {
+ return ScalarLiteral(op->dtype, op->value);
+}
- Doc VisitAttr_(const tir::FloatImmNode* op) final {
- return ScalarLiteral(op->dtype, op->value);
- }
+Doc RelayTextPrinter::VisitAttr_(const tir::FloatImmNode* op) {
+ return ScalarLiteral(op->dtype, op->value);
+}
- Doc VisitAttr_(const tir::StringImmNode* op) final {
- return Doc::StrLiteral(op->value);
- }
+Doc RelayTextPrinter::VisitAttr_(const tir::StringImmNode* op) {
+ return Doc::StrLiteral(op->value);
+}
- private:
- /*! \brief Whether to print meta data. */
- bool show_meta_data_;
- /*! \brief additional comment function */
- runtime::TypedPackedFunc<std::string(ObjectRef)> annotate_;
- /*! \brief Stack of docs to implement scoped GNFing. */
- std::vector<Doc> doc_stack_{};
- /*! \brief Map from Expr to Doc */
- std::unordered_map<Expr, Doc, ObjectHash, ObjectEqual> memo_;
- /*! \brief Map from Type to Doc */
- std::unordered_map<Type, Doc, ObjectHash, ObjectEqual> memo_type_;
- /*! \brief Map from Type to Doc */
- std::unordered_map<Pattern, Doc, ObjectHash, ObjectEqual> memo_pattern_;
- /*! \brief name allocation map */
- std::unordered_map<std::string, int> name_alloc_map_;
- /*! \brief meta data context */
- TextMetaDataContext meta_;
- /*! \brief counter of temporary variable */
- size_t temp_var_counter_{0};
- /*! \brief whether the printer is currently in an ADT definition */
- bool in_adt_def_;
- /*! \brief arena for dependency graph */
- support::Arena arena_;
- /*! \brief dependency graph of the expr */
- DependencyGraph dg_;
- class AttrPrinter;
- friend class AttrPrinter;
-};
/*!
* \brief Attribute printer which prints the attributes in the call.
if (op_node && (attrs->type_index() != op_node->attrs_type_index)) {
// fallback
Doc doc;
- doc << meta_.GetMetaNode(attrs);
+ doc << meta_->GetMetaNode(attrs);
docs.push_back(doc);
return docs;
} else {
}
return docs;
}
-} // namespace relay
-static const char* kSemVer = "v0.0.4";
-
-// TODO(tvm-team): split into files, related: arith/analyzer.h
-//
-// - text_printer.h (common header)
-// - text_printer.cc (prints modules dispatch into relay and tir files)
-// - type_text_printer.cc(specific printing logics for types,
-// can also consider put under type_text_printer)
-// - Implements AsText
-// - relay_text_printer.cc (specific printing logics for relay)
-// - tir_text_printer.cc (specific printing logics for TIR)
-String PrettyPrint(const ObjectRef& node) {
- Doc doc;
- doc << relay::RelayTextPrinter(false, nullptr).PrintFinal(node);
- return doc.str();
-}
-
-String AsText(const ObjectRef& node,
- bool show_meta_data,
- runtime::TypedPackedFunc<String(ObjectRef)> annotate) {
- Doc doc;
- doc << kSemVer << Doc::NewLine();
- runtime::TypedPackedFunc<std::string(ObjectRef)> ftyped = nullptr;
- if (annotate != nullptr) {
- ftyped = runtime::TypedPackedFunc<std::string(ObjectRef)>(
- [&annotate](const ObjectRef& expr) -> std::string {
- return annotate(expr);
- });
- }
- doc << relay::RelayTextPrinter(show_meta_data, ftyped).PrintFinal(node);
- return doc.str();
-}
-
-TVM_REGISTER_GLOBAL("ir.PrettyPrint")
-.set_body_typed(PrettyPrint);
-
-TVM_REGISTER_GLOBAL("ir.AsText")
-.set_body_typed(AsText);
+} // namespace relay
} // namespace tvm
--- /dev/null
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file text_printer.cc
+ * \brief Printer to print out the unified IR text format
+ * that can be parsed by a parser.
+ */
+
+#include <tvm/tir/function.h>
+#include <string>
+#include "text_printer.h"
+
+namespace tvm {
+
+static const char* kSemVer = "v0.0.4";
+
+// TODO(tvm-team): split into files, related: arith/analyzer.h
+//
+// - text_printer.h (common header)
+// - text_printer.cc (prints modules dispatch into relay and tir files)
+// - type_text_printer.cc(specific printing logics for types,
+// can also consider put under type_text_printer)
+// - Implements AsText
+// - relay_text_printer.cc (specific printing logics for relay)
+// - tir_text_printer.cc (specific printing logics for TIR)
+
+Doc TextPrinter::PrintMod(const IRModule& mod) {
+ Doc doc;
+ int counter = 0;
+ // type definitions
+ for (const auto& kv : mod->type_definitions) {
+ if (counter++ != 0) {
+ doc << Doc::NewLine();
+ }
+ doc << relay_text_printer_.Print(kv.second);
+ doc << Doc::NewLine();
+ }
+ // functions
+ for (const auto& kv : mod->functions) {
+ if (kv.second.as<relay::FunctionNode>()) {
+ relay_text_printer_.dg_ =
+ relay::DependencyGraph::Create(&relay_text_printer_.arena_, kv.second);
+ }
+ if (counter++ != 0) {
+ doc << Doc::NewLine();
+ }
+ if (kv.second.as<relay::FunctionNode>()) {
+ std::ostringstream os;
+ os << "def @" << kv.first->name_hint;
+ doc << relay_text_printer_.PrintFunc(Doc::Text(os.str()), kv.second);
+ } else if (kv.second.as<tir::PrimFuncNode>()) {
+ doc << tir_text_printer_.PrintPrimFunc(Downcast<tir::PrimFunc>(kv.second));
+ }
+ doc << Doc::NewLine();
+ }
+ return doc;
+}
+
+String PrettyPrint(const ObjectRef& node) {
+ Doc doc;
+ doc << TextPrinter(false, nullptr).PrintFinal(node);
+ return doc.str();
+}
+
+String AsText(const ObjectRef& node,
+ bool show_meta_data,
+ runtime::TypedPackedFunc<String(ObjectRef)> annotate) {
+ Doc doc;
+ doc << kSemVer << Doc::NewLine();
+ runtime::TypedPackedFunc<std::string(ObjectRef)> ftyped = nullptr;
+ if (annotate != nullptr) {
+ ftyped = runtime::TypedPackedFunc<std::string(ObjectRef)>(
+ [&annotate](const ObjectRef& expr) -> std::string {
+ return annotate(expr);
+ });
+ }
+ doc << TextPrinter(show_meta_data, ftyped).PrintFinal(node);
+ return doc.str();
+}
+
+TVM_REGISTER_GLOBAL("ir.PrettyPrint")
+.set_body_typed(PrettyPrint);
+
+TVM_REGISTER_GLOBAL("ir.AsText")
+.set_body_typed(AsText);
+
+} // namespace tvm
--- /dev/null
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file text_printer.h
+ * \brief Printer to print out the unified IR text format
+ * that can be parsed by a parser.
+ */
+
+#ifndef TVM_PRINTER_TEXT_PRINTER_H_
+#define TVM_PRINTER_TEXT_PRINTER_H_
+
+#include <tvm/relay/expr_functor.h>
+#include <tvm/relay/pattern_functor.h>
+#include <tvm/tir/expr_functor.h>
+#include <tvm/tir/stmt_functor.h>
+#include <tvm/tir/op.h>
+#include <tvm/ir/type_functor.h>
+#include <tvm/ir/module.h>
+#include <tvm/tir/function.h>
+#include <tvm/relay/expr_functor.h>
+#include <tvm/relay/pattern_functor.h>
+#include <unordered_map>
+#include <vector>
+#include <string>
+#include "../relay/analysis/dependency_graph.h"
+#include "../ir/attr_functor.h"
+
+#include "doc.h"
+#include "meta_data.h"
+#include "text_printer.h"
+
+namespace tvm {
+class TextPrinter;
+} // namespace tvm
+
+namespace tvm {
+namespace relay {
+
+class RelayTextPrinter :
+ public ExprFunctor<Doc(const Expr&)>,
+ public PatternFunctor<Doc(const Pattern&)>,
+ public TypeFunctor<Doc(const Type&)>,
+ public AttrFunctor<Doc(const ObjectRef&)> {
+ public:
+ explicit RelayTextPrinter(bool show_meta_data,
+ TextMetaDataContext* meta,
+ runtime::TypedPackedFunc<std::string(ObjectRef)> annotate)
+ : show_meta_data_(show_meta_data), annotate_(annotate), meta_(meta) {}
+
+ /*!
+ * \brief Print additional info about expr in comment.
+ * \param expr The expression.
+ */
+ Doc PrintOptionalInfo(const Expr& expr);
+ // indent a new body
+ Doc PrintBody(const ObjectRef& node, int indent = 2);
+ // create a new scope by creating a new printer object. This allows temp var
+ // numbers to be reused and prevents hoisted vars from escaping too far
+ Doc PrintScope(const ObjectRef& node);
+ Doc PrintFinal(const ObjectRef& node);
+ std::vector<Doc> PrintCallAttrs(const Attrs& attrs, const Expr& op);
+ std::vector<Doc> PrintFuncAttrs(const Attrs& attrs);
+
+ Doc Print(const ObjectRef& node, bool meta = false, bool try_inline = false);
+
+ Doc TempVar(int n);
+ Doc AllocTemp();
+ /*!
+ * \brief get a unique name with the corresponding prefix
+ * \param prefix The prefix of the name
+ * \return The returned name.
+ */
+ Doc GetUniqueName(const std::string& prefix);
+ Doc Print(Kind k);
+ /*!
+ * \brief Allocate name to a type variable.
+ * \param var The input type variable.
+ * \return The corresponding name.
+ */
+ Doc AllocTypeVar(const TypeVar& var);
+ /*!
+ * \brief Allocate name to a variable.
+ * \param var The input variable.
+ * \return The corresponding name.
+ */
+ Doc AllocVar(const Var& var);
+ bool IsUnique(const Expr& expr);
+ bool AlwaysInline(const Expr& expr);
+
+ Doc PrintFunc(const Doc& prefix, const relay::Function& fn);
+ Doc PrintFunc(const Doc& prefix, const BaseFunc& base_func);
+ Doc PrintMod(const IRModule& mod);
+
+ //------------------------------------
+ // Overload of Expr printing functions
+ //------------------------------------
+ Doc PrintExpr(const Expr& expr, bool meta, bool try_inline);
+ // Should only be triggered when op is a free variable being visited for the
+ // first time.
+ Doc VisitExpr_(const VarNode* op) final;
+ /*!
+ * \brief special method to print out const scalar
+ * \param dtype The data type
+ * \param value The value to be printed.
+ */
+ template <typename T>
+ static Doc ScalarLiteral(DataType dtype, const T& value);
+ Doc VisitExpr_(const ConstantNode* op) final;
+ Doc VisitExpr_(const TupleNode* op) final;
+ Doc VisitExpr_(const TupleGetItemNode* op) final;
+ Doc VisitExpr_(const IfNode* op) final;
+ Doc VisitExpr_(const LetNode* op) final;
+ Doc VisitExpr_(const FunctionNode* op) final;
+ Doc VisitExpr_(const GlobalVarNode* op) final;
+ Doc VisitExpr_(const OpNode* op) final;
+ Doc VisitExpr_(const CallNode* op) final;
+ Doc VisitExpr_(const RefCreateNode* op) final;
+ Doc VisitExpr_(const RefReadNode* op) final;
+ Doc VisitExpr_(const RefWriteNode* op) final;
+ Doc VisitExpr_(const MatchNode* op) final;
+ Doc PrintPattern(const Pattern& pattern, bool meta);
+ Doc VisitPattern_(const PatternConstructorNode* p) final;
+ Doc VisitPattern_(const PatternTupleNode* pt) final;
+ Doc VisitPattern_(const PatternWildcardNode* pw) final;
+ Doc VisitPattern_(const PatternVarNode* pv) final;
+ Doc VisitExpr_(const ConstructorNode* n) final;
+ //------------------------------------
+ // Overload of Type printing functions
+ //------------------------------------
+ Doc PrintType(const Type& type, bool meta);
+ Doc VisitTypeDefault_(const Object* node) final;
+ Doc VisitType_(const TypeVarNode* node) final;
+ Doc VisitType_(const GlobalTypeVarNode* node);
+ Doc VisitType_(const TypeCallNode* node) final;
+ Doc PrintDType(DataType dtype);
+ Doc VisitType_(const TensorTypeNode* node) final;
+ Doc VisitType_(const TupleTypeNode* node) final;
+ Doc VisitType_(const FuncTypeNode* node) final;
+ Doc VisitType_(const RelayRefTypeNode* node) final;
+ Doc VisitType_(const TypeDataNode* node) final;
+ //------------------------------------
+ // Overload of Attr printing functions
+ //------------------------------------
+ Doc PrintAttr(const ObjectRef& value, bool meta = false);
+ Doc VisitAttrDefault_(const Object* op) final;
+ Doc VisitAttr_(const ArrayNode* op) final;
+ Doc VisitAttr_(const tir::IntImmNode* op) final;
+ Doc VisitAttr_(const tir::FloatImmNode* op) final;
+ Doc VisitAttr_(const tir::StringImmNode* op) final;
+
+ private:
+ /*! \brief Whether to print meta data. */
+ bool show_meta_data_;
+ /*! \brief additional comment function */
+ runtime::TypedPackedFunc<std::string(ObjectRef)> annotate_;
+ /*! \brief Stack of docs to implement scoped GNFing. */
+ std::vector<Doc> doc_stack_{};
+ /*! \brief Map from Expr to Doc */
+ std::unordered_map<Expr, Doc, ObjectHash, ObjectEqual> memo_;
+ /*! \brief Map from Type to Doc */
+ std::unordered_map<Type, Doc, ObjectHash, ObjectEqual> memo_type_;
+ /*! \brief Map from Type to Doc */
+ std::unordered_map<Pattern, Doc, ObjectHash, ObjectEqual> memo_pattern_;
+ /*! \brief name allocation map */
+ std::unordered_map<std::string, int> name_alloc_map_;
+ /*! \brief meta data context */
+ TextMetaDataContext* meta_;
+ /*! \brief counter of temporary variable */
+ size_t temp_var_counter_{0};
+ /*! \brief whether the printer is currently in an ADT definition */
+ bool in_adt_def_;
+ /*! \brief arena for dependency graph */
+ support::Arena arena_;
+ /*! \brief dependency graph of the expr */
+ DependencyGraph dg_;
+ class AttrPrinter;
+ friend class AttrPrinter;
+ friend class tvm::TextPrinter;
+};
+
+} // namespace relay
+} // namespace tvm
+
+namespace tvm {
+namespace tir {
+
+/*!
+ * \brief Meta node collector
+ * If we decide to put some node into meta, then all the sub-nodes inside
+ * it need to be put in meta as well, since when parsing we need to know
+ * whether two refs are the same
+ */
+class MetaCollector : public StmtExprVisitor {
+ public:
+ explicit MetaCollector(TextMetaDataContext* meta) : meta_(meta) {}
+
+ void Collect(const ObjectRef& n) {
+ // these nodes can be print directly(StringLiteral or use identifier to identify)
+ if (!n.defined() || n.as<StringImmNode>() || n.as<StringObj>() || n.as<SizeVarNode>()
+ || n.as<VarNode>() || n.as<BufferNode>() || n.as<IterVarNode>()) {
+ return;
+ }
+ if (n->IsInstance<StmtNode>()) {
+ VisitStmt(Downcast<Stmt>(n));
+ } else if (n->IsInstance<PrimExprNode>()) {
+ VisitExpr(Downcast<PrimExpr>(n));
+ }
+ }
+
+ void VisitStmt(const Stmt& n) override {
+ meta_->GetMetaNode(n);
+ StmtVisitor::VisitStmt(n);
+ }
+
+ void VisitExpr(const PrimExpr& n) override {
+ meta_->GetMetaNode(n);
+ ExprVisitor::VisitExpr(n);
+ }
+
+ private:
+ TextMetaDataContext* meta_;
+};
+
+class TIRTextPrinter : public StmtFunctor<Doc(const Stmt&)>,
+ public ExprFunctor<Doc(const PrimExpr&)>,
+ public TypeFunctor<Doc(const Type&)> {
+ public:
+ explicit TIRTextPrinter(bool show_meta, TextMetaDataContext* meta)
+ : show_meta_(show_meta), meta_(meta), meta_collector_(meta) {}
+
+ /*! \brief Print the node */
+ Doc Print(const ObjectRef& node);
+
+ private:
+ /*! \brief whether show meta data */
+ bool show_meta_;
+ /*! \brief meta data context */
+ TextMetaDataContext* meta_;
+ /*! \brief meta collector */
+ MetaCollector meta_collector_;
+ /*! \brief Map from Var to Doc */
+ std::unordered_map<Var, Doc, ObjectHash, ObjectEqual> memo_var_;
+ /*! \brief Map from Buffer to Doc */
+ std::unordered_map<Buffer, Doc, ObjectHash, ObjectEqual> memo_buf_;
+ /*! \brief name allocation map */
+ std::unordered_map<std::string, int> name_alloc_map_;
+
+ friend class tvm::TextPrinter;
+
+ Doc VisitExpr_(const IntImmNode* op) override;
+ Doc VisitExpr_(const FloatImmNode* op) override;
+ Doc VisitExpr_(const StringImmNode* op) override;
+ Doc VisitExpr_(const CastNode* op) override;
+ Doc VisitExpr_(const VarNode* op) override;
+ Doc VisitExpr_(const AddNode* op) override;
+ Doc VisitExpr_(const SubNode* op) override;
+ Doc VisitExpr_(const MulNode* op) override;
+ Doc VisitExpr_(const DivNode* op) override;
+ Doc VisitExpr_(const ModNode* op) override;
+ Doc VisitExpr_(const FloorDivNode* op) override;
+ Doc VisitExpr_(const FloorModNode* op) override;
+ Doc VisitExpr_(const MinNode* op) override;
+ Doc VisitExpr_(const MaxNode* op) override;
+ Doc VisitExpr_(const EQNode* op) override;
+ Doc VisitExpr_(const NENode* op) override;
+ Doc VisitExpr_(const LTNode* op) override;
+ Doc VisitExpr_(const LENode* op) override;
+ Doc VisitExpr_(const GTNode* op) override;
+ Doc VisitExpr_(const GENode* op) override;
+ Doc VisitExpr_(const AndNode* op) override;
+ Doc VisitExpr_(const OrNode* op) override;
+ Doc VisitExpr_(const NotNode* op) override;
+ Doc VisitExpr_(const SelectNode* op) override;
+ Doc VisitExpr_(const BufferLoadNode* op) override;
+ Doc VisitExpr_(const LoadNode* op) override;
+ Doc VisitExpr_(const RampNode* op) override;
+ Doc VisitExpr_(const BroadcastNode* op) override;
+ Doc VisitExpr_(const LetNode* op) override;
+ Doc VisitExpr_(const CallNode* op) override;
+ Doc VisitExpr_(const ShuffleNode* op) override;
+ Doc VisitExpr_(const ReduceNode* op) override;
+ Doc VisitExprDefault_(const Object* op) override;
+
+ Doc VisitStmt_(const LetStmtNode* op) override;
+ Doc VisitStmt_(const AttrStmtNode* op) override;
+ Doc VisitStmt_(const AssertStmtNode* op) override;
+ Doc VisitStmt_(const StoreNode* op) override;
+ Doc VisitStmt_(const BufferStoreNode* op) override;
+ Doc VisitStmt_(const BufferRealizeNode* op) override;
+ Doc VisitStmt_(const AllocateNode* op) override;
+ Doc VisitStmt_(const FreeNode* op) override;
+ Doc VisitStmt_(const IfThenElseNode* op) override;
+ Doc VisitStmt_(const SeqStmtNode* op) override;
+ Doc VisitStmt_(const EvaluateNode* op) override;
+ Doc VisitStmt_(const ForNode* op) override;
+ Doc VisitStmt_(const PrefetchNode* op) override;
+ Doc VisitStmtDefault_(const Object* op) override;
+
+ Doc VisitType_(const PrimTypeNode* node) override;
+ Doc VisitType_(const PointerTypeNode* node) override;
+ Doc VisitType_(const TupleTypeNode* node) override;
+
+ Doc PrintIRModule(const IRModule& module);
+ Doc PrintPrimFunc(const PrimFunc& primFunc);
+ Doc PrintArray(const ArrayNode* op);
+ Doc PrintIterVar(const IterVarNode* op);
+ Doc PrintRange(const RangeNode* op);
+ Doc PrintBuffer(const BufferNode* op);
+ Doc PrintString(const StringObj* op) {
+ return Doc::StrLiteral(op->data);
+ }
+
+ /*!
+ * \brief special method to print out data type
+ * \param dtype The data type
+ */
+ static Doc PrintDType(DataType dtype);
+ /*!
+ * \brief special method to print out const scalar
+ * \param dtype The data type
+ * \param data The pointer to hold the data.
+ */
+ template <typename T>
+ static Doc PrintConstScalar(DataType dtype, const T& data);
+ Doc GetUniqueName(std::string prefix);
+ Doc AllocVar(const Var& var);
+ Doc AllocBuf(const Buffer& buffer);
+ /*!
+ * \brief special method to render vectors of docs with a separator
+ * \param vec vector of docs
+ * \param sep separator
+ */
+ static Doc PrintSep(const std::vector<Doc>& vec, const Doc& sep);
+ Doc PrintBody(const Stmt& body, bool indent = true);
+};
+
+} // namespace tir
+} // namespace tvm
+
+namespace tvm {
+
+class TextPrinter {
+ public:
+ explicit TextPrinter(bool show_meta_data,
+ const runtime::TypedPackedFunc<std::string(ObjectRef)>& annotate)
+ : show_meta_data_(show_meta_data), annotate_(annotate),
+ relay_text_printer_(show_meta_data, &meta_, annotate),
+ tir_text_printer_(show_meta_data, &meta_) {}
+
+ /*! \brief whether show meta data */
+ bool show_meta_data_;
+ /*! \brief meta data context */
+ TextMetaDataContext meta_;
+ /*! \brief additional comment function */
+ runtime::TypedPackedFunc<std::string(ObjectRef)> annotate_;
+ /*! \brief Relay Text Printer */
+ relay::RelayTextPrinter relay_text_printer_;
+ /*! \brief TIR Text Printer */
+ tir::TIRTextPrinter tir_text_printer_;
+
+ Doc PrintFinal(const ObjectRef& node) {
+ Doc doc;
+ if (node->IsInstance<IRModuleNode>()) {
+ doc << PrintMod(Downcast<IRModule>(node));
+ } else if (node->IsInstance<tir::PrimFuncNode>() || node->IsInstance<PrimExprNode>()
+ || node->IsInstance<tir::StmtNode>()) {
+ doc << tir_text_printer_.Print(node);
+ } else {
+ doc << relay_text_printer_.PrintFinal(node);
+ }
+ if (!meta_.empty()) {
+ doc << Doc::NewLine();
+ if (show_meta_data_) {
+ // append meta data in the end.
+ doc << "METADATA:" << Doc::NewLine() << meta_.GetMetaSection();
+ } else {
+ doc << "// meta data omitted. you can use show_meta_data=True to include meta data";
+ }
+ }
+ return doc;
+ }
+
+ Doc PrintMod(const IRModule& mod);
+};
+} // namespace tvm
+
+#endif // TVM_PRINTER_TEXT_PRINTER_H_
--- /dev/null
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file tir_text_printer.cc
+ * \brief Printer to print out the IR text format
+ * that can be parsed by a parser.
+ */
+
+#include <tvm/ir/module.h>
+#include <tvm/ir/type.h>
+#include <tvm/ir/type_functor.h>
+#include <tvm/node/serialization.h>
+#include <tvm/tir/expr.h>
+#include <tvm/tir/function.h>
+#include <tvm/tir/op.h>
+#include <tvm/tir/stmt.h>
+
+#include <algorithm>
+#include <string>
+
+#include "doc.h"
+#include "meta_data.h"
+#include "text_printer.h"
+
+namespace tvm {
+namespace tir {
+
+Doc TIRTextPrinter::Print(const ObjectRef& node) {
+ if (!node.defined()) return Doc::Text("(nullptr)");
+ if (node->IsInstance<StmtNode>()) {
+ return VisitStmt(Downcast<Stmt>(node));
+ } else if (node->IsInstance<AnyNode>()) {
+ return Doc::Text("?");
+ } else if (node->IsInstance<PrimExprNode>()) {
+ return VisitExpr(Downcast<PrimExpr>(node));
+ } else if (node->IsInstance<TypeNode>()) {
+ return VisitType(Downcast<Type>(node));
+ } else if (node->IsInstance<PrimFuncNode>()) {
+ return PrintPrimFunc(Downcast<PrimFunc>(node));
+ } else if (node->IsInstance<IRModuleNode>()) {
+ return PrintIRModule(Downcast<IRModule>(node));
+ } else if (node->IsInstance<ArrayNode>()) {
+ return PrintArray(node.as<ArrayNode>());
+ } else if (node->IsInstance<IterVarNode>()) {
+ return PrintIterVar(node.as<IterVarNode>());
+ } else if (node->IsInstance<RangeNode>()) {
+ return PrintRange(node.as<RangeNode>());
+ } else if (node->IsInstance<BufferNode>()) {
+ return PrintBuffer(node.as<BufferNode>());
+ } else if (node->IsInstance<StringObj>()) {
+ return PrintString(node.as<StringObj>());
+ } else {
+ return this->meta_->GetMetaNode(node);
+ }
+}
+
+Doc TIRTextPrinter::PrintPrimFunc(const PrimFunc& primFunc) {
+ const auto* op = primFunc.operator->();
+ const auto& signature = op->func_type_annotation();
+ // collect Meta in DictAttr
+ for (const auto& it : primFunc->attrs->dict) {
+ meta_collector_.Collect(it.second);
+ }
+ // collect buffers in buffer_map
+ memo_var_.clear();
+ memo_buf_.clear();
+ for (const auto& it : op->buffer_map) {
+ memo_buf_[it.second] = AllocBuf(it.second);
+ }
+ // print PrimFunc
+ Doc doc;
+ doc << "primfn" << "(";
+ // print params and its type annotation
+ std::vector<Doc> params;
+ for (const auto& param : op->params) {
+ params.push_back(Print(param));
+ }
+ Doc sep;
+ doc << PrintSep(params, Doc::Indent(9, Doc::Text(", "))) << ")";
+ // print return type
+ doc << " -> " << Print(signature->ret_type);
+ // print attr
+ Doc attr_doc;
+ std::vector<Doc> attr_docs;
+ for (const auto& it : op->attrs->dict) {
+ attr_docs.push_back(Doc::StrLiteral(it.first) << ": " << Print(it.second));
+ }
+ attr_doc << Doc::NewLine() << "attr = {" << PrintSep(attr_docs, Doc::Text(", ")) << "}";
+ doc << Doc::Indent(2, attr_doc);
+ // print all the buffers in the tree
+ Doc buffer_doc;
+ std::vector<Doc> buffer_docs;
+ for (const auto& it : memo_buf_) {
+ const auto& buf = it.first;
+ buffer_docs.push_back(Print(buf)
+ << Doc::Text(": Buffer(") << Print(buf->data) << ", "
+ << PrintDType(buf->dtype) << ", " << Print(buf->shape) << ", "
+ << Print(buf->strides));
+ if (!is_zero(buf->elem_offset)) {
+ buffer_docs.back() << ", elem_offset=" << Print(buf->elem_offset);
+ }
+ if (buf->scope != "global") {
+ buffer_docs.back() << ", scope=" << Doc::StrLiteral(buf->scope);
+ }
+ if (buf->data_alignment != 128) {
+ buffer_docs.back() << ", align=" << buf->data_alignment;
+ }
+ if (buf->offset_factor != 1) {
+ buffer_docs.back() << ", offset_factor=" << buf->offset_factor;
+ }
+ if (buf->buffer_type != 1) {
+ buffer_docs.back() << ", type=" << Doc::StrLiteral("auto");
+ }
+ buffer_docs.back() << ")";
+ }
+ buffer_doc << Doc::NewLine() << "buffers = {";
+ buffer_doc << PrintSep(buffer_docs, Doc::Indent(11, Doc::Text(",") << Doc::NewLine()));
+ doc << Doc::Indent(2, buffer_doc) << "}";
+ // print buffer_map
+ std::vector<Doc> buffer_map_doc;
+ for (const auto& it : op->buffer_map) {
+ buffer_map_doc.push_back(Print(it.first) << ": " << Print(it.second));
+ }
+ doc << Doc::Indent(2, Doc::NewLine()
+ << "buffer_map = {" << PrintSep(buffer_map_doc, Doc::Text(", ")) << "}");
+ doc << PrintBody(op->body);
+ return doc;
+}
+
+Doc TIRTextPrinter::PrintIRModule(const IRModule& module) {
+ const auto* op = module.operator->();
+ Doc doc;
+
+ Doc body;
+ body << Doc::NewLine();
+ std::vector<Doc> functions;
+ for (auto it = op->functions.begin(); it != op->functions.end(); ++it) {
+ if ((*it).second.as<PrimFuncNode>()) {
+ functions.push_back(Print((*it).second));
+ }
+ }
+ body << TIRTextPrinter::PrintSep(functions, Doc::NewLine() << Doc::NewLine());
+ doc << Doc::Indent(0, body);
+ return doc;
+}
+
+Doc TIRTextPrinter::PrintArray(const ArrayNode* op) {
+ Doc doc;
+ doc << '[';
+ for (size_t i = 0; i < op->data.size(); ++i) {
+ if (i != 0) {
+ doc << ", ";
+ }
+ doc << Print(op->data[i]);
+ }
+ doc << ']';
+ return doc;
+}
+
+Doc TIRTextPrinter::PrintIterVar(const IterVarNode* op) {
+ Doc doc;
+ doc << "IterVar(" << Print(op->var);
+ if (op->dom.defined()) {
+ doc << ", [" << Print(op->dom) << "], ";
+ } else {
+ doc << ", " << Print(op->dom) << ", ";
+ }
+ doc << Doc::StrLiteral(IterVarType2String(op->iter_type)) << ", ";
+ doc << Doc::StrLiteral(op->thread_tag) << ")";
+ return doc;
+}
+
+Doc TIRTextPrinter::PrintRange(const RangeNode* op) {
+ return Print(op->min) << ":" << Print(op->min + op->extent);
+}
+
+Doc TIRTextPrinter::PrintBuffer(const BufferNode* op) {
+ const Buffer& buffer = GetRef<Buffer>(op);
+ CHECK_GT(memo_buf_.count(buffer), 0);
+ return meta_->InMeta(buffer) ? meta_->GetMetaNode(buffer) : memo_buf_[buffer];
+}
+
+Doc TIRTextPrinter::VisitExprDefault_(const Object* op) {
+ return this->meta_->GetMetaNode(GetRef<ObjectRef>(op));
+}
+
+Doc TIRTextPrinter::VisitStmtDefault_(const Object* op) {
+ return this->meta_->GetMetaNode(GetRef<ObjectRef>(op));
+}
+
+Doc TIRTextPrinter::VisitExpr_(const IntImmNode* op) {
+ return PrintConstScalar<int64_t>(op->dtype, op->value);
+}
+
+Doc TIRTextPrinter::VisitExpr_(const FloatImmNode* op) {
+ return PrintConstScalar<double>(op->dtype, op->value);
+}
+
+Doc TIRTextPrinter::VisitExpr_(const StringImmNode* op) { return Doc::StrLiteral(op->value); }
+
+Doc TIRTextPrinter::VisitExpr_(const CastNode* op) {
+ Doc doc;
+ doc << "cast(" << PrintDType(op->dtype) << ", " << Print(op->value) << ")";
+ return doc;
+}
+
+Doc TIRTextPrinter::VisitExpr_(const VarNode* op) {
+ const Var& var = GetRef<Var>(op);
+ return meta_->InMeta(var) ? meta_->GetMetaNode(var) : AllocVar(GetRef<Var>(op));
+}
+
+#define TVM_DECLARE_TIR_HYBRID_PRINTER_BINOP(OpName, OpString) \
+ Doc TIRTextPrinter::VisitExpr_(const OpName* op) { \
+ Doc doc; \
+ doc << "(" << Print(op->a) << OpString; \
+ doc << Print(op->b) << ")"; \
+ return doc; \
+ }
+
+TVM_DECLARE_TIR_HYBRID_PRINTER_BINOP(AddNode, " + ")
+TVM_DECLARE_TIR_HYBRID_PRINTER_BINOP(SubNode, " - ")
+TVM_DECLARE_TIR_HYBRID_PRINTER_BINOP(MulNode, "*")
+TVM_DECLARE_TIR_HYBRID_PRINTER_BINOP(DivNode, " / ")
+TVM_DECLARE_TIR_HYBRID_PRINTER_BINOP(ModNode, " % ")
+TVM_DECLARE_TIR_HYBRID_PRINTER_BINOP(EQNode, " == ")
+TVM_DECLARE_TIR_HYBRID_PRINTER_BINOP(NENode, " != ")
+TVM_DECLARE_TIR_HYBRID_PRINTER_BINOP(LTNode, " < ")
+TVM_DECLARE_TIR_HYBRID_PRINTER_BINOP(LENode, " <= ")
+TVM_DECLARE_TIR_HYBRID_PRINTER_BINOP(GTNode, " > ")
+TVM_DECLARE_TIR_HYBRID_PRINTER_BINOP(GENode, " >= ")
+TVM_DECLARE_TIR_HYBRID_PRINTER_BINOP(AndNode, " && ")
+TVM_DECLARE_TIR_HYBRID_PRINTER_BINOP(OrNode, " || ")
+
+Doc TIRTextPrinter::VisitExpr_(const FloorDivNode* op) {
+ Doc doc;
+ doc << "floordiv(" << Print(op->a) << ", " << Print(op->b) << ")";
+ return doc;
+}
+
+Doc TIRTextPrinter::VisitExpr_(const FloorModNode* op) {
+ Doc doc;
+ doc << "floormod(" << Print(op->a) << ", " << Print(op->b) << ")";
+ return doc;
+}
+
+Doc TIRTextPrinter::VisitExpr_(const MinNode* op) {
+ Doc doc;
+ doc << "min(" << Print(op->a) << ", " << Print(op->b) << ")";
+ return doc;
+}
+
+Doc TIRTextPrinter::VisitExpr_(const MaxNode* op) {
+ Doc doc;
+ doc << "max(" << Print(op->a) << ", " << Print(op->b) << ")";
+ return doc;
+}
+
+Doc TIRTextPrinter::VisitExpr_(const NotNode* op) {
+ Doc doc;
+ doc << "!" << Print(op->a);
+ return doc;
+}
+
+Doc TIRTextPrinter::VisitExpr_(const SelectNode* op) {
+ Doc doc;
+ doc << "select(" << Print(op->condition) << ", " << Print(op->true_value) << ", "
+ << Print(op->false_value);
+ return doc;
+}
+
+Doc TIRTextPrinter::VisitExpr_(const BufferLoadNode* op) {
+ Doc doc;
+ doc << Print(op->buffer) << Print(op->indices);
+ return doc;
+}
+
+Doc TIRTextPrinter::VisitExpr_(const LoadNode* op) {
+ Doc doc;
+ doc << "(" << PrintDType(op->dtype) << "*)"
+ << Print(op->buffer_var) << "[" << Print(op->index) << "])";
+ if (!is_one(op->predicate)) {
+ doc << " if " << Print(op->predicate);
+ }
+ return doc;
+}
+
+Doc TIRTextPrinter::VisitExpr_(const RampNode* op) {
+ Doc doc;
+ doc << "ramp(" << Print(op->base) << ", " << Print(op->stride) << ", " << op->lanes << ")";
+ return doc;
+}
+
+Doc TIRTextPrinter::VisitExpr_(const BroadcastNode* op) {
+ Doc doc;
+ doc << "broadcast(" << Print(op->value) << ", " << op->lanes << ")";
+ return doc;
+}
+
+Doc TIRTextPrinter::VisitExpr_(const LetNode* op) {
+ Doc doc;
+ doc << "let " << Print(op->var) << " = " << Print(op->value) << " in " << Print(op->body);
+ return doc;
+}
+
+inline const char* CallType2String(CallNode::CallType t) {
+ switch (t) {
+ case CallNode::Extern:return "extern";
+ case CallNode::ExternCPlusPlus:return "extern_cpp";
+ case CallNode::PureExtern:return "pure_extern";
+ case CallNode::Halide:return "halide";
+ case CallNode::Intrinsic:return "intrin";
+ case CallNode::PureIntrinsic:return "pure_intrin";
+ }
+ LOG(FATAL) << "Unknown CallType";
+ return "Unknown";
+}
+
+Doc TIRTextPrinter::VisitExpr_(const CallNode* op) {
+ Doc doc;
+ doc << "@" << Doc::Text(op->name) << "(";
+ std::vector<Doc> args;
+ for (const auto& arg : op->args) {
+ args.push_back(Print(arg));
+ }
+ doc << PrintSep(args, Doc::Text(", "))
+ << ", dtype=" << PrintDType(op->dtype)
+ << ", type=" << Doc::StrLiteral(CallType2String(op->call_type))
+ << ", index=" << op->value_index << ")";
+ return doc;
+}
+
+Doc TIRTextPrinter::VisitExpr_(const ShuffleNode* op) {
+ Doc doc;
+ doc << "shuffle(" << Print(op->vectors) << ", " << Print(op->indices) << ")";
+ return doc;
+}
+
+Doc TIRTextPrinter::VisitExpr_(const ReduceNode* op) {
+ Doc doc;
+ doc << "reduce(" << Print(op->combiner) << ", " << Print(op->source) << ", " << Print(op->axis)
+ << ", " << op->value_index << ")";
+ return doc;
+}
+
+Doc TIRTextPrinter::VisitStmt_(const LetStmtNode* op) {
+ Doc doc;
+ doc << "let " << Print(op->var) << " = " << Print(op->value) << PrintBody(op->body);
+ return doc;
+}
+
+Doc TIRTextPrinter::VisitStmt_(const AttrStmtNode* op) {
+ Doc doc;
+ meta_collector_.Collect(op->node);
+ doc << "attr [" << Print(op->node) << "] " << Doc::StrLiteral(op->attr_key) << " = "
+ << Print(op->value);
+ if (op->body->IsInstance<SeqStmtNode>()) {
+ doc << PrintBody(op->body);
+ } else {
+ doc << ";" << Doc::NewLine() << Print(op->body);
+ }
+ return doc;
+}
+
+Doc TIRTextPrinter::VisitStmt_(const AssertStmtNode* op) {
+ Doc doc;
+ doc << "assert(" << Print(op->condition) << ", " << Print(op->message) << ")"
+ << PrintBody(op->body);
+ return doc;
+}
+
+Doc TIRTextPrinter::VisitStmt_(const StoreNode* op) {
+ Doc doc;
+ doc << Print(op->buffer_var) << "[" << Print(op->index) << "] = " << Print(op->value);
+ if (!is_one(op->predicate)) {
+ doc << " if " << Print(op->predicate);
+ }
+ return doc;
+}
+
+Doc TIRTextPrinter::VisitStmt_(const BufferStoreNode* op) {
+ Doc doc;
+ doc << Print(op->buffer) << Print(op->indices) << " = " << Print(op->value);
+ return doc;
+}
+
+Doc TIRTextPrinter::VisitStmt_(const BufferRealizeNode* op) {
+ Doc doc;
+ doc << "realize(" << Print(op->buffer) << ", " << Print(op->bounds) << ", "
+ << Print(op->condition) << PrintBody(op->body) << ")";
+ return doc;
+}
+
+Doc TIRTextPrinter::VisitStmt_(const AllocateNode* op) {
+ Doc doc;
+ doc << "allocate(" << Print(op->buffer_var) << ", " << PrintDType(op->dtype) << ", "
+ << Print(op->extents) << ")";
+ if (!is_one(op->condition)) {
+ doc << " if " << Print(op->condition);
+ }
+ if (op->body->IsInstance<SeqStmtNode>()) {
+ doc << PrintBody(op->body);
+ } else {
+ doc << ";" << Doc::NewLine() << Print(op->body);
+ }
+ return doc;
+}
+
+Doc TIRTextPrinter::VisitStmt_(const FreeNode* op) {
+ Doc doc;
+ doc << "free(" << Print(op->buffer_var) << ")";
+ return doc;
+}
+
+Doc TIRTextPrinter::VisitStmt_(const IfThenElseNode* op) {
+ Doc doc;
+ doc << "if " << Print(op->condition) << PrintBody(op->then_case);
+ if (!is_one(op->condition) && op->else_case.defined()) {
+ doc << " else" << PrintBody(op->else_case);
+ }
+ return doc;
+}
+
+Doc TIRTextPrinter::VisitStmt_(const SeqStmtNode* op) {
+ std::vector<Doc> stmts;
+ Doc seq_doc, doc;
+ for (Stmt stmt : op->seq) {
+ seq_doc << Doc::NewLine() << Print(stmt);
+ }
+ doc << " {" << Doc::Indent(2, seq_doc) << Doc::NewLine() << "}";
+ return doc;
+}
+
+Doc TIRTextPrinter::VisitStmt_(const EvaluateNode* op) {
+ Doc doc;
+ doc << Print(op->value);
+ return doc;
+}
+
+inline const char* ForType2String(ForType t) {
+ switch (t) {
+ case ForType::Serial:return "serial";
+ case ForType::Parallel:return "parallel";
+ case ForType::Vectorized:return "vectorized";
+ case ForType::Unrolled:return "unroll";
+ }
+ LOG(FATAL) << "Unknown ForType";
+ return "Unknown";
+}
+
+Doc TIRTextPrinter::VisitStmt_(const ForNode* op) {
+ Doc doc;
+ doc << "for (" << Print(op->loop_var) << ", " << Print(op->min) << ", "
+ << Print(op->min + op->extent) << ")";
+ if (op->for_type != ForType::Serial) {
+ doc << " " << Doc::StrLiteral(ForType2String(op->for_type));
+ }
+ doc << PrintBody(op->body);
+ return doc;
+}
+
+Doc TIRTextPrinter::VisitStmt_(const PrefetchNode* op) {
+ Doc doc;
+ doc << "prefetch(" << Print(op->buffer) << ", " << Print(op->bounds) << ")";
+ return doc;
+}
+
+Doc TIRTextPrinter::VisitType_(const PrimTypeNode* node) {
+ Doc doc;
+ doc << PrintDType(node->dtype);
+ return doc;
+}
+
+Doc TIRTextPrinter::VisitType_(const PointerTypeNode* node) {
+ Doc doc;
+ doc << "Pointer(" << Print(node->element_type) << ")";
+ return doc;
+}
+
+Doc TIRTextPrinter::VisitType_(const TupleTypeNode* node) {
+ std::vector<Doc> fields;
+ for (Type field : node->fields) {
+ fields.push_back(Print(field));
+ }
+ Doc doc;
+ doc << "(" << Doc::Concat(fields);
+ // conform to python tuple format (1,)
+ if (node->fields.size() == 1) {
+ doc << ",";
+ }
+ return doc << ")";
+}
+
+Doc TIRTextPrinter::PrintDType(DataType dtype) {
+ return Doc::Text(runtime::DLDataType2String(dtype));
+}
+
+template <typename T>
+Doc TIRTextPrinter::PrintConstScalar(DataType dtype, const T& data) {
+ Doc doc;
+ std::ostringstream os;
+ os << data;
+ if (dtype == DataType::Int(32)) {
+ doc << Doc::Text(os.str());
+ } else {
+ if (dtype.bits() == 1 && dtype.lanes() == 1 && dtype.code() == kDLUInt) {
+ doc << ((data == 1) ? "True" : "False");
+ return doc;
+ }
+ doc << Doc::Text(os.str());
+ switch (dtype.code()) {
+ case kDLInt: doc << "i"; break;
+ case kDLUInt: doc << "u"; break;
+ case kDLFloat: doc << "f"; break;
+ }
+ doc << Doc::Text(std::to_string(dtype.bits()));
+ if (dtype.lanes() != 1) doc << "x" << Doc::Text(std::to_string(dtype.lanes()));
+ }
+ return doc;
+}
+
+Doc TIRTextPrinter::GetUniqueName(std::string prefix) {
+ // std::replace(prefix.begin(), prefix.end(), '.', '_');
+ std::string unique_prefix = prefix;
+ auto it = name_alloc_map_.find(prefix);
+ if (it != name_alloc_map_.end()) {
+ while (name_alloc_map_.count(
+ unique_prefix = prefix + "_" + std::to_string(++it->second)) > 0) {}
+ }
+ name_alloc_map_[unique_prefix] = 0;
+ return Doc::Text(unique_prefix);
+}
+
+Doc TIRTextPrinter::AllocVar(const Var& var) {
+ const auto& it = memo_var_.find(var);
+ if (it != memo_var_.end()) {
+ return it->second;
+ }
+ std::string name = var->name_hint;
+ if (name.length() == 0 || !std::isalpha(name[0])) {
+ name = "v" + name;
+ }
+ Doc val = GetUniqueName(name);
+ memo_var_[var] = val;
+ return val << ": " << Print(GetType(var));
+}
+
+Doc TIRTextPrinter::AllocBuf(const Buffer& buffer) {
+ const auto& it = memo_buf_.find(buffer);
+ if (it != memo_buf_.end()) {
+ return it->second;
+ }
+ std::string name = buffer->name;
+ if (name.length() == 0 || !std::isalpha(name[0])) {
+ name = "buf_" + name;
+ }
+ Doc val = GetUniqueName(name);
+ memo_buf_[buffer] = val;
+ return val;
+}
+
+Doc TIRTextPrinter::PrintSep(const std::vector<Doc>& vec, const Doc& sep) {
+ Doc seq;
+ if (vec.size() != 0) {
+ seq = vec[0];
+ for (size_t i = 1; i < vec.size(); i++) {
+ seq << sep << vec[i];
+ }
+ }
+ return seq;
+}
+
+Doc TIRTextPrinter::PrintBody(const Stmt& body, bool indent) {
+ Doc doc;
+ if (body->IsInstance<SeqStmtNode>()) return Print(body);
+ doc << " {" << Doc::Indent(2, Doc::NewLine() << Print(body)) << Doc::NewLine() << "}";
+ return doc;
+}
+
+} // namespace tir
+} // namespace tvm
e2 = (tvm.te.max(5, a * 4) < 0)
res2 = tvm.arith.deduce_bound(a, e2, {b: b_s, c: c_s, d: d_s}, {})
- assert str(res2.max_value) == "neg_inf"
- assert str(res2.min_value) == "pos_inf"
+ assert str(res2.max_value) == "neg_inf: handle"
+ assert str(res2.min_value) == "pos_inf: handle"
# expression containing variable a is on rhs
e2 = (zero < tvm.te.max(5, a * 4))
res2 = tvm.arith.deduce_bound(a, e2, {b: b_s, c: c_s, d: d_s}, {})
- assert str(res2.max_value) == "neg_inf"
- assert str(res2.min_value) == "pos_inf"
+ assert str(res2.max_value) == "neg_inf: handle"
+ assert str(res2.min_value) == "pos_inf: handle"
e3 = (-b)+a*c-d
res3 = tvm.arith.deduce_bound(a, e3>=0, {b: b_s, c: c_s, d: d_s}, {b: b_s, d: d_s})
# Unsatisfiable `EQ`, variable as one of the Operand
res5 = tvm.arith.deduce_bound(a, (a == b), {b: b_s}, {b: b_s})
- assert str(res5.max_value) == "neg_inf"
- assert str(res5.min_value) == "pos_inf"
+ assert str(res5.max_value) == "neg_inf: handle"
+ assert str(res5.min_value) == "pos_inf: handle"
# variable `a` on the RHS side
res6 = tvm.arith.deduce_bound(a, 10 == a, {}, {})
# Unsatisfiable Mul in `EQ`
e5 = (4 * a == b)
res9 = tvm.arith.deduce_bound(a, e5, {b: b_s}, {})
- assert str(res9.max_value) == "neg_inf"
- assert str(res9.min_value) == "pos_inf"
+ assert str(res9.max_value) == "neg_inf: handle"
+ assert str(res9.min_value) == "pos_inf: handle"
# Unsatisfiable Mul in `EQ`
res10 = tvm.arith.deduce_bound(a, (b * a == b), {b: b_s}, {}) # simplifier is not able to prove that (b % b == 0)
- assert str(res10.max_value) == "neg_inf"
- assert str(res10.min_value) == "pos_inf"
+ assert str(res10.max_value) == "neg_inf: handle"
+ assert str(res10.min_value) == "pos_inf: handle"
def test_check():
stmt = tvm.lower(s, [A, C])["main"].body
assert isinstance(stmt.body.body, tvm.tir.Evaluate)
assert len(stmt.body.body.value.args) == 5
- assert str(stmt.body.body.value.args[3]) == "(i*i)"
- assert str(stmt.body.body.value.args[4]) == "(i + j)"
+ assert str(stmt.body.body.value.args[3]) == "(i: int32*i)"
+ assert str(stmt.body.body.value.args[4]) == "(i: int32 + j: int32)"
if __name__ == "__main__":
test_singleton()
a = te.var('a')
b = te.var('b')
c = a + b
- assert str(c) == '(%s + %s)' % (a.name, b.name)
+ assert str(c) == '(%s: int32 + %s: int32)' % (a.name, b.name)
def test_stmt():
assert False
except ValueError:
pass
- assert str(tvm.tir.any(x < y)) == '(%s < %s)' % (x.name, y.name)
- assert str(tvm.tir.any(x < y, x > z)) == '((%s < %s) || (%s > %s))' % (
+ assert str(tvm.tir.any(x < y)) == '(%s: int32 < %s: int32)' % (x.name, y.name)
+ assert str(tvm.tir.any(x < y, x > z)) == '((%s: int32 < %s: int32) || (%s > %s: int32))' % (
x.name, y.name, x.name, z.name)
assert str(tvm.tir.any(x < y, y > z + 1, x < z * 2)) == \
- '(((%s < %s) || (%s > (%s + 1))) || (%s < (%s*2)))' % (
+ '(((%s: int32 < %s: int32) || (%s > (%s: int32 + 1))) || (%s < (%s*2)))' % (
x.name, y.name, y.name, z.name, x.name, z.name)
assert False
except ValueError:
pass
- assert str(tvm.tir.all(x < y)) == '(%s < %s)' % (x.name, y.name)
- assert str(tvm.tir.all(x < y, x > z)) == '((%s < %s) && (%s > %s))' % (
+ assert str(tvm.tir.all(x < y)) == '(%s: int32 < %s: int32)' % (x.name, y.name)
+ assert str(tvm.tir.all(x < y, x > z)) == '((%s: int32 < %s: int32) && (%s > %s: int32))' % (
x.name, y.name, x.name, z.name)
assert str(tvm.tir.all(x < y, y > z + 1, x < z * 2)) == \
- '(((%s < %s) && (%s > (%s + 1))) && (%s < (%s*2)))' % (
+ '(((%s: int32 < %s: int32) && (%s > (%s: int32 + 1))) && (%s < (%s*2)))' % (
x.name, y.name, y.name, z.name, x.name, z.name)
def test_bitwise():
x = te.var('x')
y = te.var('y')
- assert str(x << y) == 'shift_left(x, y)'
- assert str(x >> y) == 'shift_right(x, y)'
- assert str(x & y) == 'bitwise_and(x, y)'
- assert str(x | y) == 'bitwise_or(x, y)'
- assert str(x ^ y) == 'bitwise_xor(x, y)'
- assert str(10 & x) == 'bitwise_and(10, x)'
- assert str(10 | x) == 'bitwise_or(10, x)'
- assert str(10 ^ x) == 'bitwise_xor(10, x)'
- assert str(10 >> x) == 'shift_right(10, x)'
- assert str(10 << x) == 'shift_left(10, x)'
- assert str(10 % x) == 'floormod(10, x)'
- assert str(~x) == 'bitwise_not(x)'
+ assert str(x << y) == '@shift_left(x: int32, y: int32, dtype=int32, type="pure_intrin", index=0)'
+ assert str(x >> y) == '@shift_right(x: int32, y: int32, dtype=int32, type="pure_intrin", index=0)'
+ assert str(x & y) == '@bitwise_and(x: int32, y: int32, dtype=int32, type="pure_intrin", index=0)'
+ assert str(x | y) == '@bitwise_or(x: int32, y: int32, dtype=int32, type="pure_intrin", index=0)'
+ assert str(x ^ y) == '@bitwise_xor(x: int32, y: int32, dtype=int32, type="pure_intrin", index=0)'
+ assert str(10 & x) == '@bitwise_and(10, x: int32, dtype=int32, type="pure_intrin", index=0)'
+ assert str(10 | x) == '@bitwise_or(10, x: int32, dtype=int32, type="pure_intrin", index=0)'
+ assert str(10 ^ x) == '@bitwise_xor(10, x: int32, dtype=int32, type="pure_intrin", index=0)'
+ assert str(10 >> x) == '@shift_right(10, x: int32, dtype=int32, type="pure_intrin", index=0)'
+ assert str(10 << x) == '@shift_left(10, x: int32, dtype=int32, type="pure_intrin", index=0)'
+ assert str(10 % x) == 'floormod(10, x: int32)'
+ assert str(~x) == '@bitwise_not(x: int32, dtype=int32, type="pure_intrin", index=0)'
assert(tvm.tir.const(1, "int8x2") >> 1).dtype == "int8x2"
assert(x >> tvm.tir.const(1, "int32x2")).dtype == "int32x2"
assert(te.var("z", "int8x2") << tvm.tir.const(1, "int8x2")).dtype == "int8x2"
def test_isnan():
x = te.var('x', 'float32')
- assert str(tvm.tir.isnan(x)) == 'isnan(x)'
+ assert str(tvm.tir.isnan(x)) == '@isnan(x: float32, dtype=bool, type="pure_intrin", index=0)'
assert str(tvm.tir.isnan(x).dtype) == 'bool'
y = te.var('y', 'float16')
- assert str(tvm.tir.isnan(y)) == 'isnan(float32(y))'
+ assert str(tvm.tir.isnan(y)) == '@isnan(cast(float32, y: float16), dtype=bool, type="pure_intrin", index=0)'
z = te.var('z', 'int32')
- assert str(tvm.tir.isnan(z)) == '(bool)0'
+ assert str(tvm.tir.isnan(z)) == 'False'
k = te.var('k', 'int8x2')
assert str(tvm.tir.isnan(k).dtype) == 'uint1x2'