From c69092ae0d39f9a5161f098d933d0a2ec570a2c5 Mon Sep 17 00:00:00 2001 From: Tianqi Chen Date: Mon, 13 Jan 2020 16:18:01 -0800 Subject: [PATCH] [REFACTOR][IR] Unified IR IRModule structure. (#4699) This PR brings relay::Module as the unified IRModule structure. IRModule will be used as the basic unit for transformations through out the stack. - Rename relay::Module -> IRModule - Move relay/module.h -> ir/module.h - ModuleNode::FromExpr -> IRModule::FromExpr - FromText -> IRModule::FromText --- docs/dev/relay_pass_infra.rst | 2 +- include/tvm/{relay => ir}/module.h | 145 ++++++++-------- include/tvm/ir/type_relation.h | 7 +- include/tvm/ir_pass.h | 2 +- include/tvm/relay/analysis.h | 18 +- include/tvm/relay/base.h | 3 - include/tvm/relay/error.h | 7 +- include/tvm/relay/feature.h | 7 +- include/tvm/relay/interpreter.h | 4 +- include/tvm/relay/transform.h | 22 +-- src/{relay => }/ir/module.cc | 164 +++++++++--------- src/relay/backend/build_module.cc | 10 +- src/relay/backend/compile_engine.cc | 4 +- .../backend/contrib/codegen_c/codegen.cc | 4 +- src/relay/backend/graph_runtime_codegen.cc | 2 +- src/relay/backend/interpreter.cc | 6 +- src/relay/backend/vm/compiler.cc | 6 +- src/relay/backend/vm/compiler.h | 6 +- src/relay/backend/vm/inline_primitives.cc | 10 +- src/relay/backend/vm/lambda_lift.cc | 10 +- src/relay/backend/vm/removed_unused_funcs.cc | 10 +- src/relay/ir/alpha_equal.cc | 4 +- src/relay/ir/error.cc | 4 +- src/relay/ir/pretty_printer.cc | 8 +- src/relay/op/tensor/transform.h | 1 + src/relay/pass/alter_op_layout.cc | 4 +- src/relay/pass/canonicalize_cast.cc | 4 +- src/relay/pass/canonicalize_ops.cc | 4 +- src/relay/pass/combine_parallel_conv2d.cc | 4 +- src/relay/pass/combine_parallel_dense.cc | 4 +- src/relay/pass/combine_parallel_op_batch.cc | 4 +- src/relay/pass/convert_layout.cc | 4 +- src/relay/pass/dead_code.cc | 4 +- src/relay/pass/device_annotation.cc | 4 +- src/relay/pass/eliminate_common_subexpr.cc | 4 +- src/relay/pass/eta_expand.cc | 10 +- src/relay/pass/feature.cc | 6 +- src/relay/pass/fold_constant.cc | 12 +- src/relay/pass/fold_scale_axis.cc | 8 +- src/relay/pass/fuse_ops.cc | 6 +- src/relay/pass/gradient.cc | 8 +- src/relay/pass/kind_check.cc | 8 +- src/relay/pass/legalize.cc | 4 +- src/relay/pass/match_exhaustion.cc | 18 +- src/relay/pass/partial_eval.cc | 16 +- src/relay/pass/pass_manager.cc | 40 ++--- src/relay/pass/print_ir.cc | 4 +- src/relay/pass/quantize/annotate.cc | 4 +- src/relay/pass/quantize/partition.cc | 4 +- src/relay/pass/quantize/realize.cc | 6 +- src/relay/pass/simplify_inference.cc | 4 +- src/relay/pass/to_a_normal_form.cc | 8 +- src/relay/pass/to_cps.cc | 30 ++-- src/relay/pass/to_graph_normal_form.cc | 4 +- src/relay/pass/type_infer.cc | 12 +- src/relay/pass/type_solver.cc | 6 +- src/relay/pass/type_solver.h | 4 +- src/relay/pass/util.cc | 22 +-- tests/cpp/relay_pass_type_infer_test.cc | 2 +- tests/cpp/relay_transform_sequential.cc | 6 +- 60 files changed, 384 insertions(+), 374 deletions(-) rename include/tvm/{relay => ir}/module.h (79%) rename src/{relay => }/ir/module.cc (70%) diff --git a/docs/dev/relay_pass_infra.rst b/docs/dev/relay_pass_infra.rst index 57dcca1bb..60d2b7296 100644 --- a/docs/dev/relay_pass_infra.rst +++ b/docs/dev/relay_pass_infra.rst @@ -353,7 +353,7 @@ registration. auto fx = relay::FunctionNode::make(tvm::Array{ y }, call, relay::Type(), {}); // Create a module for optimization. - auto mod = relay::ModuleNode::FromExpr(fx); + auto mod = IRModule::FromExpr(fx); // Create a sequential pass. tvm::Array pass_seqs{ diff --git a/include/tvm/relay/module.h b/include/tvm/ir/module.h similarity index 79% rename from include/tvm/relay/module.h rename to include/tvm/ir/module.h index 4fc31eb00..735a5f50d 100644 --- a/include/tvm/relay/module.h +++ b/include/tvm/ir/module.h @@ -18,55 +18,41 @@ */ /*! - * \file tvm/relay/module.h - * \brief The global environment: contains information needed to - * compile & optimize Relay programs. + * \file tvm/ir/module.h + * \brief IRModule that holds the functions and type definitions. */ -#ifndef TVM_RELAY_MODULE_H_ -#define TVM_RELAY_MODULE_H_ - -#include -#include -#include -#include -#include +#ifndef TVM_IR_MODULE_H_ +#define TVM_IR_MODULE_H_ + +#include +#include +#include + #include #include #include #include namespace tvm { -namespace relay { - -struct Module; - -/*! \brief The global environment of Relay programs. - * - * The global environment contains the global - * information needed to compile a Relay program. - * - * It contains all global functions, and configuration - * options. +class IRModule; +/*! + * \brief IRModule that holds functions and type definitions. * - * Many operations require access to the global - * Module. We pass the Module by value - * in a functional style as an explicit argument, - * but we mutate the Module while optimizing - * Relay programs. + * IRModule is the basic unit for all IR transformations across the stack. * - * The functional style allows users to construct custom - * environments easily, for example each thread can store - * a Module while auto-tuning. + * Many operations require access to the global IRModule. + * We pass the IRModule by value in a functional style as an explicit argument, + * but we mutate the Module while optimizing programs. + * \sa IRModule */ - -class ModuleNode : public RelayNode { +class IRModuleNode : public Object { public: /*! \brief A map from ids to all global functions. */ tvm::Map functions; /*! \brief A map from global type vars to ADT type data. */ tvm::Map type_definitions; - ModuleNode() {} + IRModuleNode() {} void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("functions", &functions); @@ -75,10 +61,6 @@ class ModuleNode : public RelayNode { v->Visit("global_type_var_map_", &global_type_var_map_); } - TVM_DLL static Module make(tvm::Map global_funcs, - tvm::Map global_type_defs, - std::unordered_set imports = {}); - /*! * \brief Add a function to the global environment. * \param var The var of the global function. @@ -219,7 +201,7 @@ class ModuleNode : public RelayNode { * functions in another environment. * \param other The other environment. */ - TVM_DLL void Update(const Module& other); + TVM_DLL void Update(const IRModule& other); /*! * \brief Import Relay code from the file at path. @@ -243,24 +225,8 @@ class ModuleNode : public RelayNode { */ TVM_DLL std::unordered_set Imports() const; - /*! \brief Construct a module from a standalone expression. - * - * Allows one to optionally pass a global function map and - * map of type definitions as well. - * - * \param expr The expression to set as the main function to the module. - * \param global_funcs The global function map. - * \param type_definitions Map of global type definitions - * - * \returns A module with expr set as the main function. - */ - TVM_DLL static Module FromExpr( - const Expr& expr, - const tvm::Map& global_funcs = {}, - const tvm::Map& type_definitions = {}); - static constexpr const char* _type_key = "relay.Module"; - TVM_DECLARE_FINAL_OBJECT_INFO(ModuleNode, Object); + TVM_DECLARE_FINAL_OBJECT_INFO(IRModuleNode, Object); private: /*! \brief Helper function for registering a typedef's constructors */ @@ -285,27 +251,62 @@ class ModuleNode : public RelayNode { importing is idempotent for each module. */ std::unordered_set import_set_; + friend class IRModule; }; -struct Module : public ObjectRef { - Module() {} - explicit Module(ObjectPtr<::tvm::Object> p) : ObjectRef(p) {} - - ModuleNode* operator->() const { - return static_cast(get_mutable()); +/*! + * \brief Managed reference class to IRModuleNode. + * \sa IRModuleNode + */ +class IRModule : public ObjectRef { + public: + /*! + * \brief constructor + * \param functions Functions in the module. + * \param type_definitions Type definitions in the module. + * \param import_set Set of imported files in the module + */ + TVM_DLL explicit IRModule(tvm::Map functions, + tvm::Map type_definitions = {}, + std::unordered_set import_set = {}); + /*! \brief default constructor */ + IRModule() {} + /*! + * \brief constructor + * \param n The object pointer. + */ + explicit IRModule(ObjectPtr n) : ObjectRef(n) {} + /*! \return mutable pointers to the node. */ + IRModuleNode* operator->() const { + auto* ptr = get_mutable(); + CHECK(ptr != nullptr); + return static_cast(ptr); } + /*! + * \brief Construct a module from a standalone expression. + * + * Allows one to optionally pass a global function map and + * map of type definitions as well. + * + * \param expr The expression to set as the main function to the module. + * \param global_funcs The global function map. + * \param type_definitions Map of global type definitions + * + * \returns A module with expr set as the main function. + */ + TVM_DLL static IRModule FromExpr( + const RelayExpr& expr, + const tvm::Map& global_funcs = {}, + const tvm::Map& type_definitions = {}); - using ContainerType = ModuleNode; + /*! + * \brief Parse text format source file into an IRModule. + * \param text A string of Relay source code. + * \param source_path The path to the source file. + * \return A Relay module. + */ + TVM_DLL static IRModule FromText(const std::string& text, const std::string& source_path); }; -/*! \brief Parse Relay source into a module. - * \param source A string of Relay source code. - * \param source_name The name of the source file. - * \return A Relay module. - */ -Module FromText(const std::string& source, const std::string& source_name); - -} // namespace relay } // namespace tvm - -#endif // TVM_RELAY_MODULE_H_ +#endif // TVM_IR_MODULE_H_ diff --git a/include/tvm/ir/type_relation.h b/include/tvm/ir/type_relation.h index 71d1d9eb4..db3582ec6 100644 --- a/include/tvm/ir/type_relation.h +++ b/include/tvm/ir/type_relation.h @@ -30,9 +30,8 @@ namespace tvm { // TODO(tqchen): remove after migrate Module to ir. -namespace relay { -struct Module; -} +class IRModule; + /*! * \brief reporter that reports back to the @@ -76,7 +75,7 @@ class TypeReporterNode : public Object { * \brief Retrieve the current global module. * \return The global module. */ - TVM_DLL virtual relay::Module GetModule() = 0; + TVM_DLL virtual IRModule GetModule() = 0; // solver is not serializable. void VisitAttrs(tvm::AttrVisitor* v) {} diff --git a/include/tvm/ir_pass.h b/include/tvm/ir_pass.h index 36ca03f5b..891d3245c 100644 --- a/include/tvm/ir_pass.h +++ b/include/tvm/ir_pass.h @@ -71,7 +71,7 @@ Stmt CanonicalSimplify(Stmt stmt, * \return Canonicalized expression. */ TVM_DLL PrimExpr CanonicalSimplify(PrimExpr expr, - Map vrange = Map()); + Map vrange = Map()); /*! * \brief Deep compare lhs and rhs diff --git a/include/tvm/relay/analysis.h b/include/tvm/relay/analysis.h index 8c14f024f..87dd5b408 100644 --- a/include/tvm/relay/analysis.h +++ b/include/tvm/relay/analysis.h @@ -26,7 +26,7 @@ #include #include -#include +#include #include #include @@ -49,7 +49,7 @@ namespace relay { * * \return The kind of the passed type. */ -TVM_DLL Kind KindCheck(const Type& t, const Module& mod); +TVM_DLL Kind KindCheck(const Type& t, const IRModule& mod); /*! * \brief Check whether an expression is constant. @@ -188,7 +188,7 @@ TVM_DLL tvm::Array AllVars(const Expr& expr); * * \return List of free vars, in the PostDFS order visited by expr. */ -TVM_DLL tvm::Array FreeTypeVars(const Expr& expr, const Module& mod); +TVM_DLL tvm::Array FreeTypeVars(const Expr& expr, const IRModule& mod); /*! * \brief Get free TypeVars from type t. @@ -201,7 +201,7 @@ TVM_DLL tvm::Array FreeTypeVars(const Expr& expr, const Module& mod); * * \return List of free type vars, in the PostDFS order visited by type. */ -TVM_DLL tvm::Array FreeTypeVars(const Type& t, const Module& mod); +TVM_DLL tvm::Array FreeTypeVars(const Type& t, const IRModule& mod); /*! * \brief Get all bound type variables from expression expr. @@ -214,7 +214,7 @@ TVM_DLL tvm::Array FreeTypeVars(const Type& t, const Module& mod); * * \return List of bound type vars, in the PostDFS order in the expression. */ -TVM_DLL tvm::Array BoundTypeVars(const Expr& expr, const Module& mod); +TVM_DLL tvm::Array BoundTypeVars(const Expr& expr, const IRModule& mod); /*! * \brief Get all bound type variables from type t. @@ -227,7 +227,7 @@ TVM_DLL tvm::Array BoundTypeVars(const Expr& expr, const Module& mod); * * \return List of bound type vars, in the PostDFS order visited by type. */ -TVM_DLL tvm::Array BoundTypeVars(const Type& t, const Module& mod); +TVM_DLL tvm::Array BoundTypeVars(const Type& t, const IRModule& mod); /*! * \brief Get all type variables in expression expr. @@ -237,7 +237,7 @@ TVM_DLL tvm::Array BoundTypeVars(const Type& t, const Module& mod); * * \return List of type vars, in the PostDFS order in the expression. */ -TVM_DLL tvm::Array AllTypeVars(const Expr& expr, const Module& mod); +TVM_DLL tvm::Array AllTypeVars(const Expr& expr, const IRModule& mod); /*! * \brief Get all type variables in type t. @@ -247,7 +247,7 @@ TVM_DLL tvm::Array AllTypeVars(const Expr& expr, const Module& mod); * * \return List of type vars, in the PostDFS order visited by type. */ -TVM_DLL tvm::Array AllTypeVars(const Type& t, const Module& mod); +TVM_DLL tvm::Array AllTypeVars(const Type& t, const IRModule& mod); /*! * \brief Collect the device mapping information of each expression. @@ -277,7 +277,7 @@ TVM_DLL Map CollectDeviceAnnotationOps(const Expr& expr); * \return Returns a list of cases (as patterns) that are not handled by the match * expression. */ -TVM_DLL Array UnmatchedCases(const Match& match, const Module& mod); +TVM_DLL Array UnmatchedCases(const Match& match, const IRModule& mod); /*! \brief A hashing structure in the style of std::hash. */ struct StructuralHash { diff --git a/include/tvm/relay/base.h b/include/tvm/relay/base.h index f2db652c5..45d060a88 100644 --- a/include/tvm/relay/base.h +++ b/include/tvm/relay/base.h @@ -106,9 +106,6 @@ class Id : public ObjectRef { TVM_DEFINE_OBJECT_REF_METHODS(Id, ObjectRef, IdNode); }; - -struct Module; - } // namespace relay } // namespace tvm diff --git a/include/tvm/relay/error.h b/include/tvm/relay/error.h index 4cd999fb4..1c91b6e6c 100644 --- a/include/tvm/relay/error.h +++ b/include/tvm/relay/error.h @@ -24,13 +24,16 @@ #ifndef TVM_RELAY_ERROR_H_ #define TVM_RELAY_ERROR_H_ +#include + #include #include #include #include + #include "./base.h" #include "./expr.h" -#include "./module.h" + namespace tvm { namespace relay { @@ -146,7 +149,7 @@ class ErrorReporter { * \param module The module to report errors on. * \param use_color Controls whether to colorize the output. */ - void RenderErrors(const Module& module, bool use_color = true); + void RenderErrors(const IRModule& module, bool use_color = true); inline bool AnyErrors() { return errors_.size() != 0; diff --git a/include/tvm/relay/feature.h b/include/tvm/relay/feature.h index 8292344b3..744d7c4e1 100644 --- a/include/tvm/relay/feature.h +++ b/include/tvm/relay/feature.h @@ -26,6 +26,8 @@ #include #include +#include + #include namespace tvm { @@ -141,7 +143,6 @@ class FeatureSet { */ FeatureSet DetectFeature(const RelayExpr& expr); -struct Module; /*! * \brief Calculate the feature of the program. * @@ -149,7 +150,7 @@ struct Module; * * \return The FeatureSet. */ -FeatureSet DetectFeature(const Module& mod); +FeatureSet DetectFeature(const IRModule& mod); /*! * \brief Calculate the feature of the program. @@ -159,7 +160,7 @@ FeatureSet DetectFeature(const Module& mod); * * \return The FeatureSet. */ -inline FeatureSet DetectFeature(const Expr& expr, const Module& mod) { +inline FeatureSet DetectFeature(const Expr& expr, const IRModule& mod) { return DetectFeature(expr) + DetectFeature(mod); } diff --git a/include/tvm/relay/interpreter.h b/include/tvm/relay/interpreter.h index dc35fc264..73868008a 100644 --- a/include/tvm/relay/interpreter.h +++ b/include/tvm/relay/interpreter.h @@ -35,7 +35,7 @@ #define TVM_RELAY_INTERPRETER_H_ #include -#include +#include #include #include @@ -62,7 +62,7 @@ namespace relay { * \return A function that takes in an expression and returns a value. */ runtime::TypedPackedFunc -CreateInterpreter(Module mod, DLContext context, Target target); +CreateInterpreter(IRModule mod, DLContext context, Target target); /*! \brief A Relay closure, i.e a scope and a function. */ class Closure; diff --git a/include/tvm/relay/transform.h b/include/tvm/relay/transform.h index 1b6155f4c..d57740c99 100644 --- a/include/tvm/relay/transform.h +++ b/include/tvm/relay/transform.h @@ -61,7 +61,7 @@ #include #include #include -#include +#include #include #include #include @@ -236,7 +236,7 @@ class PassNode : public RelayNode { * * \return The transformed module. */ - Module operator()(const Module& mod) const { + IRModule operator()(const IRModule& mod) const { return this->operator()(mod, PassContext::Current()); } @@ -248,8 +248,8 @@ class PassNode : public RelayNode { * * \return The transformed module. */ - virtual Module operator()(const Module& mod, - const PassContext& pass_ctx) const = 0; + virtual IRModule operator()(const IRModule& mod, + const PassContext& pass_ctx) const = 0; void VisitAttrs(tvm::AttrVisitor* v) {} @@ -266,7 +266,7 @@ class Pass : public ObjectRef { * * \return The transformed module. */ - Module operator()(const Module& mod) const { + IRModule operator()(const IRModule& mod) const { const PassNode* node = operator->(); CHECK(node != nullptr); return node->operator()(mod); @@ -279,8 +279,8 @@ class Pass : public ObjectRef { * * \return The transformed module. */ - Module operator()(const Module& mod, - const PassContext& pass_ctx) const { + IRModule operator()(const IRModule& mod, + const PassContext& pass_ctx) const { const PassNode* node = operator->(); CHECK(node != nullptr); return node->operator()(mod, pass_ctx); @@ -329,7 +329,7 @@ class Sequential : public Pass { * \return The created module pass. */ Pass CreateModulePass( - const runtime::TypedPackedFunc& pass_func, + const runtime::TypedPackedFunc& pass_func, int opt_level, const std::string& name, const tvm::Array& required); @@ -345,7 +345,7 @@ Pass CreateModulePass( * \return The created function pass. */ TVM_DLL Pass CreateFunctionPass(const runtime::TypedPackedFunc< - Function(Function, Module, PassContext)>& pass_func, + Function(Function, IRModule, PassContext)>& pass_func, int opt_level, const std::string& name, const tvm::Array& required); @@ -624,7 +624,7 @@ TVM_DLL Expr Bind(const Expr& expr, const tvm::Map& binds); * \note this function mutates mod and is not thread-safe. */ TVM_DLL Function InferType(const Function& f, - const Module& mod, + const IRModule& mod, const GlobalVar& var); /*! @@ -689,7 +689,7 @@ TVM_DLL Expr RewriteAnnotatedOps(const Expr& expr, int fallback_device); * * \return the converted Function. */ -TVM_DLL Function ToCPS(const Function& f, const Module& mod); +TVM_DLL Function ToCPS(const Function& f, const IRModule& mod); /*! * \brief Remove the continuation argument of a CPS function. diff --git a/src/relay/ir/module.cc b/src/ir/module.cc similarity index 70% rename from src/relay/ir/module.cc rename to src/ir/module.cc index 0d2aebc8c..09abac712 100644 --- a/src/relay/ir/module.cc +++ b/src/ir/module.cc @@ -21,29 +21,33 @@ * \file module.cc * \brief The global module in Relay. */ -#include +#include +#include +// NOTE on dependencies on relay analysis. +// We calls into relay's analysis module to verify correctness +// when a relay function is presented. +// These dependency does not happen at the interface-level. +// And is only used to enhance developer experiences when relay +// functions are presented. #include #include + #include #include #include namespace tvm { -namespace relay { - -using tvm::NodePrinter; -using namespace runtime; -Module ModuleNode::make(tvm::Map global_funcs, - tvm::Map global_type_defs, - std::unordered_set imports) { - auto n = make_object(); - n->functions = std::move(global_funcs); - n->type_definitions = std::move(global_type_defs); +IRModule::IRModule(tvm::Map functions, + tvm::Map type_definitions, + std::unordered_set import_set) { + auto n = make_object(); + n->functions = std::move(functions); + n->type_definitions = std::move(type_definitions); n->global_type_var_map_ = {}; n->global_var_map_ = {}; n->constructor_tag_map_ = {}; - n->import_set_ = imports; + n->import_set_ = std::move(import_set); for (const auto& kv : n->functions) { // set global var map @@ -59,26 +63,25 @@ Module ModuleNode::make(tvm::Map global_funcs, n->global_type_var_map_.Set(kv.first->name_hint, kv.first); n->RegisterConstructors(kv.first, kv.second); } - - return Module(n); + data_ = std::move(n); } -bool ModuleNode::ContainGlobalVar(const std::string& name) const { +bool IRModuleNode::ContainGlobalVar(const std::string& name) const { return global_var_map_.find(name) != global_var_map_.end(); } -bool ModuleNode::ContainGlobalTypeVar(const std::string& name) const { +bool IRModuleNode::ContainGlobalTypeVar(const std::string& name) const { return global_type_var_map_.find(name) != global_type_var_map_.end(); } -GlobalVar ModuleNode::GetGlobalVar(const std::string& name) const { +GlobalVar IRModuleNode::GetGlobalVar(const std::string& name) const { auto it = global_var_map_.find(name); CHECK(it != global_var_map_.end()) << "Cannot find global var " << name << " in the Module"; return (*it).second; } -tvm::Array ModuleNode::GetGlobalVars() const { +tvm::Array IRModuleNode::GetGlobalVars() const { std::vector global_vars; for (const auto& pair : global_var_map_) { global_vars.push_back(pair.second); @@ -86,7 +89,7 @@ tvm::Array ModuleNode::GetGlobalVars() const { return tvm::Array(global_vars); } -GlobalTypeVar ModuleNode::GetGlobalTypeVar(const std::string& name) const { +GlobalTypeVar IRModuleNode::GetGlobalTypeVar(const std::string& name) const { CHECK(global_type_var_map_.defined()); auto it = global_type_var_map_.find(name); CHECK(it != global_type_var_map_.end()) @@ -94,7 +97,7 @@ GlobalTypeVar ModuleNode::GetGlobalTypeVar(const std::string& name) const { return (*it).second; } -tvm::Array ModuleNode::GetGlobalTypeVars() const { +tvm::Array IRModuleNode::GetGlobalTypeVars() const { std::vector global_type_vars; for (const auto& pair : global_type_var_map_) { global_type_vars.push_back(pair.second); @@ -112,7 +115,7 @@ tvm::Array concat(const tvm::Array& l, const tvm::Array& r) { } // helper function to run type check -relay::Function RunTypeCheck(const Module& mod, +relay::Function RunTypeCheck(const IRModule& mod, const GlobalVar& var, relay::Function f) { auto func = Downcast(relay::DeDup(std::move(f))); @@ -146,12 +149,12 @@ relay::Function RunTypeCheck(const Module& mod, return checked_func; } -void ModuleNode::Add(const GlobalVar& var, - const BaseFunc& f, - bool update) { +void IRModuleNode::Add(const GlobalVar& var, + const BaseFunc& f, + bool update) { BaseFunc checked_func = f; if (auto* ptr = f.as()) { - checked_func = RunTypeCheck(GetRef(this), + checked_func = RunTypeCheck(GetRef(this), var, GetRef(ptr)); } @@ -162,16 +165,16 @@ void ModuleNode::Add(const GlobalVar& var, if (functions.find(var) != functions.end()) { CHECK(update) << "Already have definition for " << var->name_hint; - auto old_type = functions[var].as()->checked_type(); - CHECK(AlphaEqual(type, old_type)) + auto old_type = functions[var].as()->checked_type(); + CHECK(relay::AlphaEqual(type, old_type)) << "Module#update changes type, not possible in this mode."; } var->checked_type_ = type; AddUnchecked(var, checked_func); } -void ModuleNode::AddUnchecked(const GlobalVar& var, - const BaseFunc& func) { +void IRModuleNode::AddUnchecked(const GlobalVar& var, + const BaseFunc& func) { this->functions.Set(var, func); auto it = global_var_map_.find(var->name_hint); @@ -185,7 +188,7 @@ void ModuleNode::AddUnchecked(const GlobalVar& var, global_var_map_.Set(var->name_hint, var); } -void ModuleNode::RegisterConstructors(const GlobalTypeVar& var, const TypeData& type) { +void IRModuleNode::RegisterConstructors(const GlobalTypeVar& var, const TypeData& type) { // We hash the global type var name to use as a globally unique prefix for tags. // The hash will be used as the most significant byte of the tag, with the index of // the constructor in the less significant bytes @@ -197,19 +200,19 @@ void ModuleNode::RegisterConstructors(const GlobalTypeVar& var, const TypeData& } } -void ModuleNode::AddTypeDef(const GlobalTypeVar& var, - const TypeData& type, - bool update) { +void IRModuleNode::AddTypeDef(const GlobalTypeVar& var, + const TypeData& type, + bool update) { AddTypeDefUnchecked(var, type, update); // need to kind check at the end because the check can look up // a definition potentially - CHECK(relay::KindCheck(type, GetRef(this)) == Kind::kTypeData) + CHECK(relay::KindCheck(type, GetRef(this)) == TypeKind::kTypeData) << "Invalid or malformed typedata given to module: " << type; } -void ModuleNode::AddTypeDefUnchecked(const GlobalTypeVar& var, - const TypeData& type, - bool update) { +void IRModuleNode::AddTypeDefUnchecked(const GlobalTypeVar& var, + const TypeData& type, + bool update) { this->type_definitions.Set(var, type); if (!update) { // set global type var map @@ -220,55 +223,55 @@ void ModuleNode::AddTypeDefUnchecked(const GlobalTypeVar& var, RegisterConstructors(var, type); } -void ModuleNode::Update(const GlobalVar& var, - const BaseFunc& func) { +void IRModuleNode::Update(const GlobalVar& var, + const BaseFunc& func) { this->Add(var, func, true); } -void ModuleNode::UpdateTypeDef(const GlobalTypeVar& var, - const TypeData& type) { +void IRModuleNode::UpdateTypeDef(const GlobalTypeVar& var, + const TypeData& type) { this->AddTypeDef(var, type, true); } -void ModuleNode::Remove(const GlobalVar& var) { +void IRModuleNode::Remove(const GlobalVar& var) { auto functions_node = this->functions.CopyOnWrite(); functions_node->data.erase(var); auto gvar_node = global_var_map_.CopyOnWrite(); gvar_node->data.erase(var->name_hint); } -BaseFunc ModuleNode::Lookup(const GlobalVar& var) const { +BaseFunc IRModuleNode::Lookup(const GlobalVar& var) const { auto it = functions.find(var); CHECK(it != functions.end()) << "There is no definition of " << var->name_hint; return (*it).second; } -BaseFunc ModuleNode::Lookup(const std::string& name) const { +BaseFunc IRModuleNode::Lookup(const std::string& name) const { GlobalVar id = this->GetGlobalVar(name); return this->Lookup(id); } -TypeData ModuleNode::LookupTypeDef(const GlobalTypeVar& var) const { +TypeData IRModuleNode::LookupTypeDef(const GlobalTypeVar& var) const { auto it = type_definitions.find(var); CHECK(it != type_definitions.end()) << "There is no definition of " << var->name_hint; return (*it).second; } -TypeData ModuleNode::LookupTypeDef(const std::string& name) const { +TypeData IRModuleNode::LookupTypeDef(const std::string& name) const { GlobalTypeVar id = this->GetGlobalTypeVar(name); return this->LookupTypeDef(id); } -Constructor ModuleNode::LookupTag(const int32_t tag) { +Constructor IRModuleNode::LookupTag(const int32_t tag) { auto it = constructor_tag_map_.find(tag); CHECK(it != constructor_tag_map_.end()) << "There is no constructor with the tag " << tag; return (*it).second; } -void ModuleNode::Update(const Module& mod) { +void IRModuleNode::Update(const IRModule& mod) { // add functions and type defs. we add them unchecked first, so all definitions // can reference each other, independent of the order in which they were defined. for (auto pair : mod->functions) { @@ -285,11 +288,11 @@ void ModuleNode::Update(const Module& mod) { } } -Module ModuleNode::FromExpr( +IRModule IRModule::FromExpr( const RelayExpr& expr, const tvm::Map& global_funcs, const tvm::Map& type_definitions) { - auto mod = ModuleNode::make(global_funcs, type_definitions); + auto mod = IRModule(global_funcs, type_definitions); BaseFunc func; if (auto* func_node = expr.as()) { func = GetRef(func_node); @@ -303,7 +306,7 @@ Module ModuleNode::FromExpr( return mod; } -void ModuleNode::Import(const std::string& path) { +void IRModuleNode::Import(const std::string& path) { if (this->import_set_.count(path) == 0) { this->import_set_.insert(path); DLOG(INFO) << "Importing: " << path; @@ -311,102 +314,102 @@ void ModuleNode::Import(const std::string& path) { std::string file_contents { std::istreambuf_iterator(src_file), std::istreambuf_iterator() }; - auto mod_to_import = FromText(file_contents, path); + auto mod_to_import = IRModule::FromText(file_contents, path); Update(mod_to_import); } } -void ModuleNode::ImportFromStd(const std::string& path) { +void IRModuleNode::ImportFromStd(const std::string& path) { auto* f = tvm::runtime::Registry::Get("tvm.relay.std_path"); CHECK(f != nullptr) << "The Relay std_path is not set, please register tvm.relay.std_path."; std::string std_path = (*f)(); return this->Import(std_path + "/" + path); } -std::unordered_set ModuleNode::Imports() const { +std::unordered_set IRModuleNode::Imports() const { return this->import_set_; } -Module FromText(const std::string& source, const std::string& source_name) { +IRModule IRModule::FromText(const std::string& text, const std::string& source_path) { auto* f = tvm::runtime::Registry::Get("relay.fromtext"); CHECK(f != nullptr) << "The Relay std_path is not set, please register tvm.relay.std_path."; - Module mod = (*f)(source, source_name); + IRModule mod = (*f)(text, source_path); return mod; } -TVM_REGISTER_NODE_TYPE(ModuleNode); +TVM_REGISTER_NODE_TYPE(IRModuleNode); TVM_REGISTER_GLOBAL("relay._make.Module") .set_body_typed([](tvm::Map funcs, tvm::Map types) { - return ModuleNode::make(funcs, types, {}); + return IRModule(funcs, types, {}); }); TVM_REGISTER_GLOBAL("relay._module.Module_Add") .set_body([](TVMArgs args, TVMRetValue* ret) { - Module mod = args[0]; + IRModule mod = args[0]; GlobalVar var = args[1]; ObjectRef val = args[2]; bool update = args[3]; - CHECK(val->IsInstance()); + CHECK(val->IsInstance()); if (val->IsInstance()) { mod->Add(var, Downcast(val), update); } else if (val->IsInstance()) { GlobalVar gv = Downcast(val); - auto mod_copy = Module(make_object(*mod.operator->())); + auto mod_copy = IRModule(make_object(*mod.operator->())); mod_copy = relay::transform::EtaExpand( /* expand_constructor */ false, /* expand_global_var */ true)(mod_copy); auto func = mod_copy->Lookup(gv->name_hint); mod->Add(var, Downcast(func), update); } else { - auto func = FunctionNode::make({}, Downcast(val), Type(nullptr), {}); + auto func = relay::FunctionNode::make({}, Downcast(val), Type(nullptr), {}); mod->Add(var, func, update); } *ret = mod; }); TVM_REGISTER_GLOBAL("relay._module.Module_AddDef") -.set_body_method(&ModuleNode::AddTypeDef); +.set_body_method(&IRModuleNode::AddTypeDef); TVM_REGISTER_GLOBAL("relay._module.Module_GetGlobalVar") -.set_body_method(&ModuleNode::GetGlobalVar); +.set_body_method(&IRModuleNode::GetGlobalVar); TVM_REGISTER_GLOBAL("relay._module.Module_GetGlobalVars") -.set_body_method(&ModuleNode::GetGlobalVars); +.set_body_method(&IRModuleNode::GetGlobalVars); TVM_REGISTER_GLOBAL("relay._module.Module_GetGlobalTypeVars") -.set_body_method(&ModuleNode::GetGlobalTypeVars); +.set_body_method(&IRModuleNode::GetGlobalTypeVars); TVM_REGISTER_GLOBAL("relay._module.Module_ContainGlobalVar") -.set_body_method(&ModuleNode::ContainGlobalVar); +.set_body_method(&IRModuleNode::ContainGlobalVar); TVM_REGISTER_GLOBAL("relay._module.Module_GetGlobalTypeVar") -.set_body_method(&ModuleNode::GetGlobalTypeVar); +.set_body_method(&IRModuleNode::GetGlobalTypeVar); TVM_REGISTER_GLOBAL("relay._module.Module_Lookup") -.set_body_typed([](Module mod, GlobalVar var) { +.set_body_typed([](IRModule mod, GlobalVar var) { return mod->Lookup(var); }); TVM_REGISTER_GLOBAL("relay._module.Module_Lookup_str") -.set_body_typed([](Module mod, std::string var) { +.set_body_typed([](IRModule mod, std::string var) { return mod->Lookup(var); }); TVM_REGISTER_GLOBAL("relay._module.Module_LookupDef") -.set_body_typed([](Module mod, GlobalTypeVar var) { +.set_body_typed([](IRModule mod, GlobalTypeVar var) { return mod->LookupTypeDef(var); }); TVM_REGISTER_GLOBAL("relay._module.Module_LookupDef_str") -.set_body_typed([](Module mod, std::string var) { +.set_body_typed([](IRModule mod, std::string var) { return mod->LookupTypeDef(var); }); TVM_REGISTER_GLOBAL("relay._module.Module_LookupTag") -.set_body_typed([](Module mod, int32_t tag) { +.set_body_typed([](IRModule mod, int32_t tag) { return mod->LookupTag(tag); }); @@ -414,29 +417,28 @@ TVM_REGISTER_GLOBAL("relay._module.Module_FromExpr") .set_body_typed([](RelayExpr e, tvm::Map funcs, tvm::Map type_defs) { - return ModuleNode::FromExpr(e, funcs, type_defs); + return IRModule::FromExpr(e, funcs, type_defs); }); TVM_REGISTER_GLOBAL("relay._module.Module_Update") -.set_body_typed([](Module mod, Module from) { +.set_body_typed([](IRModule mod, IRModule from) { mod->Update(from); }); TVM_REGISTER_GLOBAL("relay._module.Module_Import") -.set_body_typed([](Module mod, std::string path) { +.set_body_typed([](IRModule mod, std::string path) { mod->Import(path); }); TVM_REGISTER_GLOBAL("relay._module.Module_ImportFromStd") -.set_body_typed([](Module mod, std::string path) { +.set_body_typed([](IRModule mod, std::string path) { mod->ImportFromStd(path); });; TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) -.set_dispatch([](const ObjectRef& ref, NodePrinter* p) { - auto* node = static_cast(ref.get()); - p->stream << "ModuleNode( " << node->functions << ")"; +.set_dispatch([](const ObjectRef& ref, NodePrinter* p) { + auto* node = static_cast(ref.get()); + p->stream << "IRModuleNode( " << node->functions << ")"; }); -} // namespace relay } // namespace tvm diff --git a/src/relay/backend/build_module.cc b/src/relay/backend/build_module.cc index f64d55678..0458dfd55 100644 --- a/src/relay/backend/build_module.cc +++ b/src/relay/backend/build_module.cc @@ -294,7 +294,7 @@ class RelayBuildModule : public runtime::ModuleNode { * * \return relay::Module The updated Relay module after optimization. */ - relay::Module Optimize( + IRModule Optimize( Function func, const TargetsMap& targets, const std::unordered_map& params) { @@ -303,7 +303,7 @@ class RelayBuildModule : public runtime::ModuleNode { } // Perform Module->Module optimizations. - relay::Module relay_module = relay::ModuleNode::FromExpr(func); + IRModule relay_module = IRModule::FromExpr(func); Array pass_seqs; @@ -408,8 +408,8 @@ class RelayBuildModule : public runtime::ModuleNode { * * \return updated_module The updated module after device annotation. */ - relay::Module RunDeviceAnnotationPass(const relay::Module& relay_module, - int fallback_device) { + IRModule RunDeviceAnnotationPass(const IRModule& relay_module, + int fallback_device) { UpdateHeterogeneousInputs(fallback_device); auto rewrite = transform::RewriteAnnotatedOps(fallback_device); auto updated_module = rewrite(relay_module); @@ -461,7 +461,7 @@ class RelayBuildModule : public runtime::ModuleNode { Function func, const std::unordered_map& params) { // Optimize input Relay Function and returns Relay Module - relay::Module relay_module = Optimize(func, targets_, params); + IRModule relay_module = Optimize(func, targets_, params); // Get the updated function. func = Downcast(relay_module->Lookup("main")); diff --git a/src/relay/backend/compile_engine.cc b/src/relay/backend/compile_engine.cc index e95e03bb2..d4a7cb1f2 100644 --- a/src/relay/backend/compile_engine.cc +++ b/src/relay/backend/compile_engine.cc @@ -613,7 +613,7 @@ class CompileEngineImpl : public CompileEngineNode { } Array LowerExternalFunctions() { - std::unordered_map ext_mods; + std::unordered_map ext_mods; std::vector cached_ext_funcs; for (const auto& it : cache_) { auto src_func = it.first->source_func; @@ -623,7 +623,7 @@ class CompileEngineImpl : public CompileEngineNode { const tvm::ir::StringImmNode* code_gen = compiler.as(); CHECK(code_gen) << "No external codegen is set"; if (ext_mods.find(code_gen->value) == ext_mods.end()) { - ext_mods[code_gen->value] = relay::ModuleNode::make({}, {}); + ext_mods[code_gen->value] = IRModule({}, {}); } auto ext_symbol = FunctionGetAttr(src_func, attr::kExternalSymbol); const tvm::ir::StringImmNode* symbol_name = ext_symbol.as(); diff --git a/src/relay/backend/contrib/codegen_c/codegen.cc b/src/relay/backend/contrib/codegen_c/codegen.cc index 642dbb022..0504b2edf 100644 --- a/src/relay/backend/contrib/codegen_c/codegen.cc +++ b/src/relay/backend/contrib/codegen_c/codegen.cc @@ -186,8 +186,8 @@ class CSourceCodegen : public CSourceModuleCodegenBase { if (ref->IsInstance()) { GenCFunc(Downcast(ref)); - } else if (ref->IsInstance()) { - relay::Module mod = Downcast(ref); + } else if (ref->IsInstance()) { + IRModule mod = Downcast(ref); for (const auto& it : mod->functions) { GenCFunc(Downcast(it.second)); } diff --git a/src/relay/backend/graph_runtime_codegen.cc b/src/relay/backend/graph_runtime_codegen.cc index 3ff72b3cc..2e18e46e4 100644 --- a/src/relay/backend/graph_runtime_codegen.cc +++ b/src/relay/backend/graph_runtime_codegen.cc @@ -24,7 +24,7 @@ #include #include -#include +#include #include #include diff --git a/src/relay/backend/interpreter.cc b/src/relay/backend/interpreter.cc index 432ad29b1..68af247aa 100644 --- a/src/relay/backend/interpreter.cc +++ b/src/relay/backend/interpreter.cc @@ -233,7 +233,7 @@ class Interpreter : public ExprFunctor, PatternFunctor { public: - Interpreter(Module mod, DLContext context, Target target) + Interpreter(IRModule mod, DLContext context, Target target) : mod_(mod), context_(context), target_(target), @@ -761,7 +761,7 @@ class Interpreter : private: // Module - Module mod_; + IRModule mod_; // For simplicity we only run the interpreter on a single context. // Context to run the interpreter on. DLContext context_; @@ -779,7 +779,7 @@ class Interpreter : TypedPackedFunc CreateInterpreter( - Module mod, + IRModule mod, DLContext context, Target target) { if (mod.defined()) { diff --git a/src/relay/backend/vm/compiler.cc b/src/relay/backend/vm/compiler.cc index b27c55bb9..ce3972f1c 100644 --- a/src/relay/backend/vm/compiler.cc +++ b/src/relay/backend/vm/compiler.cc @@ -752,7 +752,7 @@ PackedFunc VMCompiler::GetFunction(const std::string& name, if (name == "lower") { return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { CHECK_EQ(args.num_args, 3); - Module mod = args[0]; + IRModule mod = args[0]; this->Lower(mod, args[1], args[2]); }); } else if (name == "codegen") { @@ -813,7 +813,7 @@ relay::Function VMCompiler::BindParamsByName( return ret; } -void VMCompiler::Lower(Module mod, +void VMCompiler::Lower(IRModule mod, const TargetsMap& targets, const tvm::Target& target_host) { CHECK_EQ(targets.size(), 1) @@ -884,7 +884,7 @@ void VMCompiler::Lower(Module mod, } } -Module VMCompiler::OptimizeModule(const Module& mod, const TargetsMap& targets) { +IRModule VMCompiler::OptimizeModule(const IRModule& mod, const TargetsMap& targets) { Array pass_seqs; Array entry_functions{tvm::PrimExpr{"main"}}; pass_seqs.push_back(transform::RemoveUnusedFunctions(entry_functions)); diff --git a/src/relay/backend/vm/compiler.h b/src/relay/backend/vm/compiler.h index 7efcb4ba8..00bde1153 100644 --- a/src/relay/backend/vm/compiler.h +++ b/src/relay/backend/vm/compiler.h @@ -62,7 +62,7 @@ using TargetsMap = Map; struct VMCompilerContext { // The module context for the compilation - Module module; + IRModule module; // Error reporter ErrorReporter err_reporter; // Map from a unique integer to ADT constructor tag @@ -107,7 +107,7 @@ class VMCompiler : public runtime::ModuleNode { to target mapping. For homogeneous compilation, it is a build target. * \param target_host Host compilation target, if target is device. */ - void Lower(Module mod, + void Lower(IRModule mod, const TargetsMap& targets, const tvm::Target& target_host); @@ -125,7 +125,7 @@ class VMCompiler : public runtime::ModuleNode { relay::Function func, const std::unordered_map& params); - Module OptimizeModule(const Module& mod, const TargetsMap& targets); + IRModule OptimizeModule(const IRModule& mod, const TargetsMap& targets); void PopulateGlobalMap(); diff --git a/src/relay/backend/vm/inline_primitives.cc b/src/relay/backend/vm/inline_primitives.cc index abd8d2902..9c1608c41 100644 --- a/src/relay/backend/vm/inline_primitives.cc +++ b/src/relay/backend/vm/inline_primitives.cc @@ -52,10 +52,10 @@ namespace vm { * (fn(...) { ... })(...) */ struct PrimitiveInliner : ExprMutator { - Module module_; + IRModule module_; std::unordered_map var_map; - explicit PrimitiveInliner(const Module& module) : module_(module) {} + explicit PrimitiveInliner(const IRModule& module) : module_(module) {} Expr VisitExpr_(const LetNode* let_node) { var_map.insert({let_node->var, VisitExpr(let_node->value)}); @@ -106,7 +106,7 @@ struct PrimitiveInliner : ExprMutator { } } - Module Inline() { + IRModule Inline() { auto gvar_funcs = module_->functions; for (auto pair : gvar_funcs) { auto global = pair.first; @@ -137,8 +137,8 @@ struct PrimitiveInliner : ExprMutator { namespace transform { Pass InlinePrimitives() { - runtime::TypedPackedFunc pass_func = - [=](Module m, PassContext pc) { + runtime::TypedPackedFunc pass_func = + [=](IRModule m, PassContext pc) { return relay::vm::PrimitiveInliner(m).Inline(); }; auto inline_pass = CreateModulePass(pass_func, 1, "Inline", {}); diff --git a/src/relay/backend/vm/lambda_lift.cc b/src/relay/backend/vm/lambda_lift.cc index 34f03d2d7..2527e6005 100644 --- a/src/relay/backend/vm/lambda_lift.cc +++ b/src/relay/backend/vm/lambda_lift.cc @@ -60,7 +60,7 @@ Function MarkClosure(const Function& func) { */ class LambdaLifter : public ExprMutator { public: - explicit LambdaLifter(const Module& module) : module_(module) {} + explicit LambdaLifter(const IRModule& module) : module_(module) {} Expr VisitExpr_(const LetNode* let_node) final { bool is_lambda = false; @@ -184,7 +184,7 @@ class LambdaLifter : public ExprMutator { } } - Module Lift() { + IRModule Lift() { // There is an ordering bug here. auto glob_funcs = module_->functions; for (auto pair : glob_funcs) { @@ -204,7 +204,7 @@ class LambdaLifter : public ExprMutator { private: std::unordered_map lambda_map_; std::vector letrec_; - Module module_; + IRModule module_; }; } // namespace vm @@ -212,8 +212,8 @@ class LambdaLifter : public ExprMutator { namespace transform { Pass LambdaLift() { - runtime::TypedPackedFunc pass_func = - [=](Module m, PassContext pc) { + runtime::TypedPackedFunc pass_func = + [=](IRModule m, PassContext pc) { return relay::vm::LambdaLifter(m).Lift(); }; return CreateModulePass(pass_func, 1, "LambdaLift", {}); diff --git a/src/relay/backend/vm/removed_unused_funcs.cc b/src/relay/backend/vm/removed_unused_funcs.cc index 5de2e9283..419b09588 100644 --- a/src/relay/backend/vm/removed_unused_funcs.cc +++ b/src/relay/backend/vm/removed_unused_funcs.cc @@ -40,7 +40,7 @@ namespace vm { * \brief Detects all the functions that can be possibly called by entry function. */ struct CallTracer : ExprVisitor { - Module module_; + IRModule module_; // Record the names of all encountered functions std::unordered_set called_funcs_; @@ -48,7 +48,7 @@ struct CallTracer : ExprVisitor { // Record the expressions that are being visited std::unordered_set visiting_; - explicit CallTracer(const Module& module) + explicit CallTracer(const IRModule& module) : module_{module}, called_funcs_{}, visiting_{} {} @@ -99,7 +99,7 @@ struct CallTracer : ExprVisitor { * * \return The module with dead functions removed. */ -Module RemoveUnusedFunctions(const Module& module, +IRModule RemoveUnusedFunctions(const IRModule& module, Array entry_funcs) { std::unordered_set called_funcs{}; for (auto entry : entry_funcs) { @@ -122,8 +122,8 @@ Module RemoveUnusedFunctions(const Module& module, namespace transform { Pass RemoveUnusedFunctions(Array entry_functions) { - runtime::TypedPackedFunc pass_func = - [=](Module m, PassContext pc) { + runtime::TypedPackedFunc pass_func = + [=](IRModule m, PassContext pc) { return relay::vm::RemoveUnusedFunctions(m, entry_functions); }; return CreateModulePass(pass_func, 1, "RemoveUnusedFunctions", {}); diff --git a/src/relay/ir/alpha_equal.cc b/src/relay/ir/alpha_equal.cc index bef045443..4398e448d 100644 --- a/src/relay/ir/alpha_equal.cc +++ b/src/relay/ir/alpha_equal.cc @@ -60,8 +60,8 @@ class AlphaEqualHandler: if (!rhs->IsInstance()) return false; return ExprEqual(Downcast(lhs), Downcast(rhs)); } - if (const auto lhsm = lhs.as()) { - auto rhsm = rhs.as(); + if (const auto lhsm = lhs.as()) { + auto rhsm = rhs.as(); if (!rhsm) return false; if (lhsm->functions.size() != rhsm->functions.size()) return false; for (const auto& p : lhsm->functions) { diff --git a/src/relay/ir/error.cc b/src/relay/ir/error.cc index 7c47c7441..5c23f3326 100644 --- a/src/relay/ir/error.cc +++ b/src/relay/ir/error.cc @@ -23,7 +23,7 @@ */ #include -#include +#include #include #include #include @@ -39,7 +39,7 @@ void RelayErrorStream::Raise() const { template using NodeMap = std::unordered_map; -void ErrorReporter::RenderErrors(const Module& module, bool use_color) { +void ErrorReporter::RenderErrors(const IRModule& module, bool use_color) { // First we pick an error reporting strategy for each error. // TODO(@jroesch): Spanned errors are currently not supported. for (auto err : this->errors_) { diff --git a/src/relay/ir/pretty_printer.cc b/src/relay/ir/pretty_printer.cc index c88c3a024..25650c776 100644 --- a/src/relay/ir/pretty_printer.cc +++ b/src/relay/ir/pretty_printer.cc @@ -33,7 +33,7 @@ #include #include -#include +#include #include #include "doc.h" #include "type_functor.h" @@ -242,8 +242,8 @@ class PrettyPrinter : return PrintType(Downcast(node), meta); } else if (node.as()) { return PrintPattern(Downcast(node), meta); - } else if (node.as()) { - return PrintMod(Downcast(node)); + } else if (node.as()) { + return PrintMod(Downcast(node)); } else { Doc doc; return doc << node; @@ -525,7 +525,7 @@ class PrettyPrinter : } } - Doc PrintMod(const Module& mod) { + Doc PrintMod(const IRModule& mod) { Doc doc; int counter = 0; // type definitions diff --git a/src/relay/op/tensor/transform.h b/src/relay/op/tensor/transform.h index a702db898..74a630cd9 100644 --- a/src/relay/op/tensor/transform.h +++ b/src/relay/op/tensor/transform.h @@ -24,6 +24,7 @@ #ifndef TVM_RELAY_OP_TENSOR_TRANSFORM_H_ #define TVM_RELAY_OP_TENSOR_TRANSFORM_H_ +#include #include #include #include diff --git a/src/relay/pass/alter_op_layout.cc b/src/relay/pass/alter_op_layout.cc index 20a57faeb..b027e5edb 100644 --- a/src/relay/pass/alter_op_layout.cc +++ b/src/relay/pass/alter_op_layout.cc @@ -118,8 +118,8 @@ Expr AlterOpLayout(const Expr& expr) { namespace transform { Pass AlterOpLayout() { - runtime::TypedPackedFunc pass_func = - [=](Function f, Module m, PassContext pc) { + runtime::TypedPackedFunc pass_func = + [=](Function f, IRModule m, PassContext pc) { return Downcast(relay::alter_op_layout::AlterOpLayout(f)); }; return CreateFunctionPass(pass_func, 3, "AlterOpLayout", diff --git a/src/relay/pass/canonicalize_cast.cc b/src/relay/pass/canonicalize_cast.cc index 7a52dcf1b..85ca66ddf 100644 --- a/src/relay/pass/canonicalize_cast.cc +++ b/src/relay/pass/canonicalize_cast.cc @@ -129,8 +129,8 @@ Expr CanonicalizeCast(const Expr& e) { namespace transform { Pass CanonicalizeCast() { - runtime::TypedPackedFunc pass_func = - [=](Function f, Module m, PassContext pc) { + runtime::TypedPackedFunc pass_func = + [=](Function f, IRModule m, PassContext pc) { return Downcast(CanonicalizeCast(f)); }; return CreateFunctionPass(pass_func, 3, "CanonicalizeCast", diff --git a/src/relay/pass/canonicalize_ops.cc b/src/relay/pass/canonicalize_ops.cc index 222651687..5def35bf2 100644 --- a/src/relay/pass/canonicalize_ops.cc +++ b/src/relay/pass/canonicalize_ops.cc @@ -69,8 +69,8 @@ Expr CanonicalizeOps(const Expr& e) { namespace transform { Pass CanonicalizeOps() { - runtime::TypedPackedFunc pass_func = - [=](Function f, Module m, PassContext pc) { + runtime::TypedPackedFunc pass_func = + [=](Function f, IRModule m, PassContext pc) { return Downcast(CanonicalizeOps(f)); }; return CreateFunctionPass(pass_func, 3, "CanonicalizeOps", diff --git a/src/relay/pass/combine_parallel_conv2d.cc b/src/relay/pass/combine_parallel_conv2d.cc index 5cb4c4597..530b19948 100644 --- a/src/relay/pass/combine_parallel_conv2d.cc +++ b/src/relay/pass/combine_parallel_conv2d.cc @@ -216,8 +216,8 @@ Expr CombineParallelConv2D(const Expr& expr, uint64_t min_num_branches) { namespace transform { Pass CombineParallelConv2D(uint64_t min_num_branches) { - runtime::TypedPackedFunc pass_func = - [=](Function f, Module m, PassContext pc) { + runtime::TypedPackedFunc pass_func = + [=](Function f, IRModule m, PassContext pc) { return Downcast(CombineParallelConv2D(f, min_num_branches)); }; return CreateFunctionPass(pass_func, 4, "CombineParallelConv2d", diff --git a/src/relay/pass/combine_parallel_dense.cc b/src/relay/pass/combine_parallel_dense.cc index 81a4806fc..7cf161860 100644 --- a/src/relay/pass/combine_parallel_dense.cc +++ b/src/relay/pass/combine_parallel_dense.cc @@ -76,8 +76,8 @@ Expr CombineParallelDense(const Expr& expr, uint64_t min_num_branches) { namespace transform { Pass CombineParallelDense(uint64_t min_num_branches) { - runtime::TypedPackedFunc pass_func = - [=](Function f, Module m, PassContext pc) { + runtime::TypedPackedFunc pass_func = + [=](Function f, IRModule m, PassContext pc) { return Downcast(CombineParallelDense(f, min_num_branches)); }; return CreateFunctionPass(pass_func, 4, "CombineParallelDense", diff --git a/src/relay/pass/combine_parallel_op_batch.cc b/src/relay/pass/combine_parallel_op_batch.cc index b240ba75f..f1514b5d7 100644 --- a/src/relay/pass/combine_parallel_op_batch.cc +++ b/src/relay/pass/combine_parallel_op_batch.cc @@ -186,8 +186,8 @@ namespace transform { Pass CombineParallelOpBatch(const std::string& op_name, const std::string& batch_op_name, uint64_t min_num_branches) { - runtime::TypedPackedFunc pass_func = - [=](Function f, Module m, PassContext pc) { + runtime::TypedPackedFunc pass_func = + [=](Function f, IRModule m, PassContext pc) { return Downcast(CombineParallelOpBatch(f, op_name, batch_op_name, diff --git a/src/relay/pass/convert_layout.cc b/src/relay/pass/convert_layout.cc index df711bf06..20007d289 100644 --- a/src/relay/pass/convert_layout.cc +++ b/src/relay/pass/convert_layout.cc @@ -128,8 +128,8 @@ Expr ConvertLayout(const Expr& expr, const std::string& desired_layout) { namespace transform { Pass ConvertLayout(const std::string& desired_layout) { - runtime::TypedPackedFunc pass_func = - [=](Function f, Module m, PassContext pc) { + runtime::TypedPackedFunc pass_func = + [=](Function f, IRModule m, PassContext pc) { return Downcast(relay::convert_op_layout::ConvertLayout(f, desired_layout)); }; return CreateFunctionPass( diff --git a/src/relay/pass/dead_code.cc b/src/relay/pass/dead_code.cc index 05324af4c..deb26aac7 100644 --- a/src/relay/pass/dead_code.cc +++ b/src/relay/pass/dead_code.cc @@ -140,8 +140,8 @@ Expr DeadCodeElimination(const Expr& e, bool inline_once) { namespace transform { Pass DeadCodeElimination(bool inline_once) { - runtime::TypedPackedFunc pass_func = - [=](Function f, Module m, PassContext pc) { + runtime::TypedPackedFunc pass_func = + [=](Function f, IRModule m, PassContext pc) { return Downcast(DeadCodeElimination(f, inline_once)); }; return CreateFunctionPass(pass_func, 1, "DeadCodeElimination", {}); diff --git a/src/relay/pass/device_annotation.cc b/src/relay/pass/device_annotation.cc index 3ef501d97..286305157 100644 --- a/src/relay/pass/device_annotation.cc +++ b/src/relay/pass/device_annotation.cc @@ -572,8 +572,8 @@ TVM_REGISTER_GLOBAL("relay._analysis.CollectDeviceAnnotationOps") namespace transform { Pass RewriteAnnotatedOps(int fallback_device) { - runtime::TypedPackedFunc pass_func = - [=](Function f, Module m, PassContext pc) { + runtime::TypedPackedFunc pass_func = + [=](Function f, IRModule m, PassContext pc) { return Downcast(RewriteAnnotatedOps(f, fallback_device)); }; return CreateFunctionPass(pass_func, 1, "RewriteAnnotatedOps", diff --git a/src/relay/pass/eliminate_common_subexpr.cc b/src/relay/pass/eliminate_common_subexpr.cc index 04aef0e3c..bf08b0715 100644 --- a/src/relay/pass/eliminate_common_subexpr.cc +++ b/src/relay/pass/eliminate_common_subexpr.cc @@ -87,8 +87,8 @@ Expr EliminateCommonSubexpr(const Expr& expr, PackedFunc callback) { namespace transform { Pass EliminateCommonSubexpr(PackedFunc fskip) { - runtime::TypedPackedFunc pass_func = - [=](Function f, Module m, PassContext pc) { + runtime::TypedPackedFunc pass_func = + [=](Function f, IRModule m, PassContext pc) { return Downcast(EliminateCommonSubexpr(f, fskip)); }; return CreateFunctionPass(pass_func, 3, "EliminateCommonSubexpr", diff --git a/src/relay/pass/eta_expand.cc b/src/relay/pass/eta_expand.cc index 5716da657..3d7ed9390 100644 --- a/src/relay/pass/eta_expand.cc +++ b/src/relay/pass/eta_expand.cc @@ -57,7 +57,7 @@ class TypeVarReplacer : public TypeMutator { */ class EtaExpander : public ExprMutator { public: - explicit EtaExpander(const Module& mod, bool expand_constructor, bool expand_global_var) + explicit EtaExpander(const IRModule& mod, bool expand_constructor, bool expand_global_var) : mod_(mod), type_var_replacer_(TypeVarReplacer()), expand_constructor_(expand_constructor), @@ -66,7 +66,7 @@ class EtaExpander : public ExprMutator { << "must expand at least one language feature"; } - Module Expand() { + IRModule Expand() { for (GlobalVar global_var : mod_->GetGlobalVars()) { const BaseFunc base_func = mod_->Lookup(global_var); if (auto* n = base_func.as()) { @@ -147,7 +147,7 @@ class EtaExpander : public ExprMutator { private: /*! \brief reference to module being expanded */ - const Module mod_; + const IRModule mod_; /*! \brief type variable replacer */ TypeVarReplacer type_var_replacer_; /*! \brief whether to expand constructor nodes */ @@ -161,8 +161,8 @@ class EtaExpander : public ExprMutator { namespace transform { Pass EtaExpand(bool expand_constructor, bool expand_global_var) { - runtime::TypedPackedFunc pass_func = - [=](Module mod, PassContext pc) { + runtime::TypedPackedFunc pass_func = + [=](IRModule mod, PassContext pc) { return eta_expand::EtaExpander(mod, expand_constructor, expand_global_var).Expand(); }; return CreateModulePass(pass_func, 1, "EtaExpand", {}); diff --git a/src/relay/pass/feature.cc b/src/relay/pass/feature.cc index ad0ce9509..43f28ae28 100644 --- a/src/relay/pass/feature.cc +++ b/src/relay/pass/feature.cc @@ -25,7 +25,7 @@ #include #include #include -#include +#include #include "pass_util.h" namespace tvm { @@ -89,7 +89,7 @@ FeatureSet DetectFeature(const Expr& expr) { return fd.fs; } -FeatureSet DetectFeature(const Module& mod) { +FeatureSet DetectFeature(const IRModule& mod) { FeatureSet fs = FeatureSet::No(); if (mod.defined()) { for (const auto& f : mod->functions) { @@ -99,7 +99,7 @@ FeatureSet DetectFeature(const Module& mod) { return fs; } -Array PyDetectFeature(const Expr& expr, const Module& mod) { +Array PyDetectFeature(const Expr& expr, const IRModule& mod) { FeatureSet fs = DetectFeature(expr) + DetectFeature(mod); return static_cast>(fs); } diff --git a/src/relay/pass/fold_constant.cc b/src/relay/pass/fold_constant.cc index 352a1d77c..af4f4390a 100644 --- a/src/relay/pass/fold_constant.cc +++ b/src/relay/pass/fold_constant.cc @@ -79,7 +79,7 @@ TVM_REGISTER_GLOBAL("relay._analysis.check_constant") // or make a more powerful partial evaluator. class ConstantFolder : public ExprMutator { public: - explicit ConstantFolder(FInterpreter executor, Module module) + explicit ConstantFolder(FInterpreter executor, IRModule module) : executor_(executor), module_(module), shape_of_op_(Op::Get("shape_of")), @@ -168,7 +168,7 @@ class ConstantFolder : public ExprMutator { // Internal constant checker ConstantChecker checker_; // Module - Module module_; + IRModule module_; // Cache the following ops for equivalence checking in this pass. const Op& shape_of_op_; @@ -209,7 +209,7 @@ class ConstantFolder : public ExprMutator { // TODO(@jroesch): fix this func = FunctionNode::make(FreeVars(expr), expr, Type(), FreeTypeVars(expr, module_), {}); } - auto mod = ModuleNode::make( + auto mod = IRModule( {}, module_->type_definitions, module_->Imports()); @@ -277,7 +277,7 @@ class ConstantFolder : public ExprMutator { }; -Expr FoldConstant(const Expr& expr, const Module& mod) { +Expr FoldConstant(const Expr& expr, const IRModule& mod) { DLContext ctx; ctx.device_type = kDLCPU; ctx.device_id = 0; @@ -292,8 +292,8 @@ Expr FoldConstant(const Expr& expr, const Module& mod) { namespace transform { Pass FoldConstant() { - runtime::TypedPackedFunc pass_func = - [=](Function f, Module m, PassContext pc) { + runtime::TypedPackedFunc pass_func = + [=](Function f, IRModule m, PassContext pc) { return Downcast(FoldConstant(f, m)); }; return CreateFunctionPass(pass_func, 2, "FoldConstant", {}); diff --git a/src/relay/pass/fold_scale_axis.cc b/src/relay/pass/fold_scale_axis.cc index ddb3ac069..9d8d54522 100644 --- a/src/relay/pass/fold_scale_axis.cc +++ b/src/relay/pass/fold_scale_axis.cc @@ -949,8 +949,8 @@ Expr BackwardFoldScaleAxis(const Expr& data) { namespace transform { Pass ForwardFoldScaleAxis() { - runtime::TypedPackedFunc pass_func = - [=](Function f, Module m, PassContext pc) { + runtime::TypedPackedFunc pass_func = + [=](Function f, IRModule m, PassContext pc) { return Downcast( relay::fold_scale_axis::ForwardFoldScaleAxis(f)); }; @@ -962,8 +962,8 @@ TVM_REGISTER_GLOBAL("relay._transform.ForwardFoldScaleAxis") .set_body_typed(ForwardFoldScaleAxis); Pass BackwardFoldScaleAxis() { - runtime::TypedPackedFunc pass_func = - [=](Function f, Module m, PassContext pc) { + runtime::TypedPackedFunc pass_func = + [=](Function f, IRModule m, PassContext pc) { return Downcast( relay::fold_scale_axis::BackwardFoldScaleAxis(f)); }; diff --git a/src/relay/pass/fuse_ops.cc b/src/relay/pass/fuse_ops.cc index c217d0653..bf38a48b2 100644 --- a/src/relay/pass/fuse_ops.cc +++ b/src/relay/pass/fuse_ops.cc @@ -970,15 +970,15 @@ class FuseMutator : private ExprMutator { } }; -Expr FuseOps(const Expr& expr, int fuse_opt_level, const Module& module) { +Expr FuseOps(const Expr& expr, int fuse_opt_level, const IRModule& module) { return FuseMutator().Transform(expr, fuse_opt_level); } namespace transform { Pass FuseOps(int fuse_opt_level) { - runtime::TypedPackedFunc pass_func = - [=](Function f, Module m, PassContext pc) { + runtime::TypedPackedFunc pass_func = + [=](Function f, IRModule m, PassContext pc) { int opt_level = fuse_opt_level == -1 ? pc->opt_level : fuse_opt_level; return Downcast(FuseOps(f, opt_level, m)); }; diff --git a/src/relay/pass/gradient.cc b/src/relay/pass/gradient.cc index 838012890..78f17bcc6 100644 --- a/src/relay/pass/gradient.cc +++ b/src/relay/pass/gradient.cc @@ -67,7 +67,7 @@ Type WithGradientType(const Type&); /*! return an expression that represent differentiation of e (according to WithGradientType). * This version only work on first order code without control flow. */ -Expr FirstOrderGradient(const Expr& e, const Module& mod); +Expr FirstOrderGradient(const Expr& e, const IRModule& mod); Type WithGradientType(const Type& t) { // TODO(M.K.): stricter checking @@ -80,7 +80,7 @@ Type WithGradientType(const Type& t) { } //! \brief if the expression is a GlobalVar, transform to it's expression. -Expr DeGlobal(const Module& mod, const Expr& e) { +Expr DeGlobal(const IRModule& mod, const Expr& e) { if (const auto* x = e.as()) { BaseFunc base_func = mod->Lookup(GetRef(x)); if (auto* n = base_func.as()) { @@ -222,7 +222,7 @@ Type GradRetType(const Function& f) { return TupleTypeNode::make({f->ret_type, TupleTypeNode::make(vt)}); } -Expr FirstOrderGradient(const Expr& re, const Module& mod) { +Expr FirstOrderGradient(const Expr& re, const IRModule& mod) { // Currently we first remove any global functions for the first // order case. auto e = DeGlobal(mod, re); @@ -532,7 +532,7 @@ bool MissingGrad(const Expr& e) { return false; } -Expr Gradient(const Expr& re, const Module& mod) { +Expr Gradient(const Expr& re, const IRModule& mod) { auto e = DeGlobal(mod, re); auto f = e.as(); CHECK(f) << "input need to be a function"; diff --git a/src/relay/pass/kind_check.cc b/src/relay/pass/kind_check.cc index 081f132d0..2d207b5f6 100644 --- a/src/relay/pass/kind_check.cc +++ b/src/relay/pass/kind_check.cc @@ -41,10 +41,10 @@ namespace relay { using namespace tvm::runtime; struct KindChecker : TypeFunctor { - const Module& mod; + const IRModule& mod; ErrorReporter err_reporter; - explicit KindChecker(const Module& mod) : mod(mod), err_reporter() {} + explicit KindChecker(const IRModule& mod) : mod(mod), err_reporter() {} void ReportFatalError(const Error& err) { this->err_reporter.Report(err); @@ -177,7 +177,7 @@ struct KindChecker : TypeFunctor { } }; -Kind KindCheck(const Type& t, const Module& mod) { +Kind KindCheck(const Type& t, const IRModule& mod) { KindChecker kc(mod); return kc.Check(t); } @@ -185,7 +185,7 @@ Kind KindCheck(const Type& t, const Module& mod) { TVM_REGISTER_GLOBAL("relay._analysis.check_kind") .set_body([](TVMArgs args, TVMRetValue* ret) { if (args.size() == 1) { - *ret = KindCheck(args[0], ModuleNode::make({}, {})); + *ret = KindCheck(args[0], IRModule({}, {})); } else { *ret = KindCheck(args[0], args[1]); } diff --git a/src/relay/pass/legalize.cc b/src/relay/pass/legalize.cc index 654c91e9e..63608084e 100644 --- a/src/relay/pass/legalize.cc +++ b/src/relay/pass/legalize.cc @@ -98,8 +98,8 @@ Expr Legalize(const Expr& expr, const std::string& legalize_map_attr_name) { namespace transform { Pass Legalize(const std::string& legalize_map_attr_name) { - runtime::TypedPackedFunc pass_func = - [=](Function f, Module m, PassContext pc) { + runtime::TypedPackedFunc pass_func = + [=](Function f, IRModule m, PassContext pc) { return Downcast(relay::legalize::Legalize(f, legalize_map_attr_name)); }; return CreateFunctionPass(pass_func, 1, "Legalize", {ir::StringImmNode::make("InferType")}); diff --git a/src/relay/pass/match_exhaustion.cc b/src/relay/pass/match_exhaustion.cc index 6e18c630d..161b6827f 100644 --- a/src/relay/pass/match_exhaustion.cc +++ b/src/relay/pass/match_exhaustion.cc @@ -155,17 +155,17 @@ Array> CartesianProduct(Array> fields) { Array ExpandWildcardsConstructor(const PatternConstructor& clause_ctor, const Pattern& cand, - const Module& mod); + const IRModule& mod); Array ExpandWildcardsTuple(const PatternTuple& clause_tuple, const Pattern& cand, - const Module& mod); + const IRModule& mod); // Expands all wildcards in the candidate pattern once // Returns a list of all possible expansions. Array ExpandWildcards(const Pattern& clause_pat, const Pattern& cand, - const Module& mod) { + const IRModule& mod) { if (auto clause_ctor = clause_pat.as()) { return ExpandWildcardsConstructor(GetRef(clause_ctor), cand, mod); } else { @@ -178,7 +178,7 @@ Array ExpandWildcards(const Pattern& clause_pat, // Returns a list of all possible expansions. Array ExpandWildcardsConstructor(const PatternConstructor& clause_ctor, const Pattern& cand, - const Module& mod) { + const IRModule& mod) { auto gtv = Downcast(clause_ctor->constructor->belong_to); // for a wildcard node, create constructor nodes with wildcards for all args. @@ -228,7 +228,7 @@ Array ExpandWildcardsConstructor(const PatternConstructor& clause_ctor, // Returns a list of all possible expansions. Array ExpandWildcardsTuple(const PatternTuple& clause_tuple, const Pattern& cand, - const Module& mod) { + const IRModule& mod) { // for a wildcard node, create constructor nodes with wildcards for all args. if (cand.as()) { Array args; @@ -271,7 +271,7 @@ Array ExpandWildcardsTuple(const PatternTuple& clause_tuple, * \return Returns a list of cases that are not handled by the match * expression. */ -Array UnmatchedCases(const Match& match, const Module& mod) { +Array UnmatchedCases(const Match& match, const IRModule& mod) { /* algorithm: * candidates = { Wildcard } * while candidates not empty { @@ -328,10 +328,10 @@ Array UnmatchedCases(const Match& match, const Module& mod) { // expose for testing only TVM_REGISTER_GLOBAL("relay._analysis.unmatched_cases") .set_body_typed( - [](const Match& match, const Module& mod_ref) { - Module call_mod = mod_ref; + [](const Match& match, const IRModule& mod_ref) { + IRModule call_mod = mod_ref; if (!call_mod.defined()) { - call_mod = ModuleNode::make({}, {}); + call_mod = IRModule({}, {}); } return UnmatchedCases(match, call_mod); }); diff --git a/src/relay/pass/partial_eval.cc b/src/relay/pass/partial_eval.cc index a2e8d06a6..4c343bd30 100644 --- a/src/relay/pass/partial_eval.cc +++ b/src/relay/pass/partial_eval.cc @@ -569,7 +569,7 @@ FInterpreter CPUInterpreter() { // in case we are already in a build context. With fresh_build_ctx(BuildConfig::Create()); - return CreateInterpreter(Module(nullptr), CPUContext(), target); + return CreateInterpreter(IRModule(nullptr), CPUContext(), target); } using FuncId = int; @@ -623,7 +623,7 @@ Function AsFunc(const Expr& e) { class PartialEvaluator : public ExprFunctor, public PatternFunctor { public: - PartialEvaluator(const Module& mod) : mod_(mod) { } + PartialEvaluator(const IRModule& mod) : mod_(mod) { } PStatic VisitExpr(const Expr& e, LetList* ll) final { PStatic ret = ExprFunctor::VisitExpr(e, ll); @@ -954,7 +954,7 @@ class PartialEvaluator : public ExprFunctor PStatic ConstEvaluate(const Expr& expr, LetList* ll) { std::vector passes = {transform::FuseOps(0), transform::InferType()}; - auto mod = ModuleNode::FromExpr(expr); + auto mod = IRModule::FromExpr(expr); auto seq = transform::Sequential(passes); mod = seq(mod); auto entry_func = Downcast(mod->Lookup("main")); @@ -1184,7 +1184,7 @@ class PartialEvaluator : public ExprFunctor private: Environment env_; - Module mod_; + IRModule mod_; std::unordered_map gv_map_; /*! Termination checking is done as follows: * We have finitely many FunctionIds. @@ -1255,7 +1255,7 @@ Expr PostProcess(const Expr& e) { } // namespace partial_eval -Module PartialEval(const Module& m) { +IRModule PartialEval(const IRModule& m) { relay::partial_eval::PartialEvaluator pe(m); std::vector gvs; for (const auto& p : m->functions) { @@ -1270,9 +1270,9 @@ Module PartialEval(const Module& m) { namespace transform { Pass PartialEval() { - runtime::TypedPackedFunc pass_func = - [=](Module m, PassContext pc) { - return PartialEval(m); + runtime::TypedPackedFunc pass_func = + [=](IRModule m, PassContext pc) { + return relay::PartialEval(m); }; return CreateModulePass(pass_func, 1, "PartialEvaluate", {}); } diff --git a/src/relay/pass/pass_manager.cc b/src/relay/pass/pass_manager.cc index 7d270591a..bcd4451d6 100644 --- a/src/relay/pass/pass_manager.cc +++ b/src/relay/pass/pass_manager.cc @@ -98,7 +98,7 @@ class ModulePassNode : public PassNode { * implement the algorithm in the `pass_func` and let it run on a module. It * will then remove the dead code including the unused functions in the module. */ - runtime::TypedPackedFunc pass_func; + runtime::TypedPackedFunc pass_func; ModulePassNode() = default; @@ -114,7 +114,7 @@ class ModulePassNode : public PassNode { * * \return Return the updated module. */ - Module operator()(const Module& mod, const PassContext& pass_ctx) const final; + IRModule operator()(const IRModule& mod, const PassContext& pass_ctx) const final; /*! * \brief Get the pass information/meta data. @@ -122,7 +122,7 @@ class ModulePassNode : public PassNode { PassInfo Info() const override { return pass_info; } TVM_DLL static ModulePass make( - runtime::TypedPackedFunc pass_func, + runtime::TypedPackedFunc pass_func, PassInfo pass_info); static constexpr const char* _type_key = "relay.ModulePass"; @@ -155,7 +155,7 @@ class FunctionPassNode : public PassNode { * `pass_func` and let it run on a given module. The same `pass_func` will * then be applied on each function in the module. */ - runtime::TypedPackedFunc pass_func; + runtime::TypedPackedFunc pass_func; FunctionPassNode() = default; @@ -171,7 +171,7 @@ class FunctionPassNode : public PassNode { * * \return Return the updated module. */ - Module operator()(const Module& mod, const PassContext& pass_ctx) const final; + IRModule operator()(const IRModule& mod, const PassContext& pass_ctx) const final; /*! * \brief Get the pass information/meta data. @@ -179,7 +179,7 @@ class FunctionPassNode : public PassNode { PassInfo Info() const override { return pass_info; } TVM_DLL static FunctionPass make( - runtime::TypedPackedFunc pass_func, + runtime::TypedPackedFunc pass_func, PassInfo pass_info); static constexpr const char* _type_key = "relay.FunctionPass"; @@ -248,7 +248,7 @@ class SequentialNode : public PassNode { * metadata, i.e. required_passes. Likely, we can have a data structure, i.e. * PassInfo, to store the relevant information including the parent passes. */ - void ResolveDependency(const Module& mod); + void ResolveDependency(const IRModule& mod); /*! * \brief Perform optimizations on a series of passes. The aforementioned @@ -261,7 +261,7 @@ class SequentialNode : public PassNode { * * \return Return the updated module. */ - Module operator()(const Module& mod, const PassContext& pass_ctx) const final; + IRModule operator()(const IRModule& mod, const PassContext& pass_ctx) const final; static constexpr const char* _type_key = "relay.Sequential"; TVM_DECLARE_FINAL_OBJECT_INFO(SequentialNode, PassNode); @@ -278,7 +278,7 @@ PassInfo PassInfoNode::make(int opt_level, } ModulePass ModulePassNode::make( - runtime::TypedPackedFunc pass_func, + runtime::TypedPackedFunc pass_func, PassInfo pass_info) { auto n = make_object(); n->pass_func = std::move(pass_func); @@ -287,7 +287,7 @@ ModulePass ModulePassNode::make( } // Module -> Module optimizations. -Module ModulePassNode::operator()(const Module& mod, +IRModule ModulePassNode::operator()(const IRModule& mod, const PassContext& pass_ctx) const { const PassInfo& pass_info = Info(); DLOG(INFO) << "Executing module pass : " @@ -295,13 +295,13 @@ Module ModulePassNode::operator()(const Module& mod, << " with opt level: " << pass_info->opt_level; CHECK(mod.defined()); - Module updated_mod = pass_func(mod, pass_ctx); + IRModule updated_mod = pass_func(mod, pass_ctx); CHECK(updated_mod.defined()); return updated_mod; } FunctionPass FunctionPassNode::make( - runtime::TypedPackedFunc pass_func, + runtime::TypedPackedFunc pass_func, PassInfo pass_info) { auto n = make_object(); n->pass_func = std::move(pass_func); @@ -310,7 +310,7 @@ FunctionPass FunctionPassNode::make( } // Perform Module -> Module optimizations at the Function level. -Module FunctionPassNode::operator()(const Module& mod, +IRModule FunctionPassNode::operator()(const IRModule& mod, const PassContext& pass_ctx) const { const PassInfo& pass_info = Info(); CHECK(mod.defined()); @@ -320,7 +320,7 @@ Module FunctionPassNode::operator()(const Module& mod, << pass_info->opt_level; // Execute the pass function and return a new module. - Module updated_mod = ModuleNode::make(mod->functions, mod->type_definitions, mod->Imports()); + IRModule updated_mod = IRModule(mod->functions, mod->type_definitions, mod->Imports()); std::vector > updates; for (const auto& it : updated_mod->functions) { // only picks up relay::Function @@ -364,7 +364,7 @@ const SequentialNode* Sequential::operator->() const { return static_cast(get()); } -void SequentialNode::ResolveDependency(const Module& mod) { +void SequentialNode::ResolveDependency(const IRModule& mod) { // TODO(zhiics) Implement it. // 1. Consider the required passes for each pass. // 2. Only resolve the enabled passes. @@ -410,9 +410,9 @@ Pass GetPass(const std::string& pass_name) { // TODO(zhiics): we currenlty only sequentially execute each pass in // a Sequential without the consideration of their orders. The phase // ordering problem needs to be handled in the future. -Module SequentialNode::operator()(const Module& module, +IRModule SequentialNode::operator()(const IRModule& module, const PassContext& pass_ctx) const { - Module mod = module; + IRModule mod = module; for (const Pass& pass : passes) { CHECK(pass.defined()) << "Found undefined pass for optimization."; const PassInfo& pass_info = pass->Info(); @@ -429,7 +429,7 @@ Module SequentialNode::operator()(const Module& module, } Pass CreateModulePass( - const runtime::TypedPackedFunc& pass_func, + const runtime::TypedPackedFunc& pass_func, int opt_level, const std::string& name, const tvm::Array& required) { @@ -438,7 +438,7 @@ Pass CreateModulePass( } Pass CreateFunctionPass( - const runtime::TypedPackedFunc& pass_func, + const runtime::TypedPackedFunc& pass_func, int opt_level, const std::string& name, const tvm::Array& required) { @@ -479,7 +479,7 @@ TVM_REGISTER_GLOBAL("relay._transform.MakeModulePass") TVM_REGISTER_GLOBAL("relay._transform.RunPass") .set_body([](TVMArgs args, TVMRetValue* ret) { Pass pass = args[0]; - Module mod = args[1]; + IRModule mod = args[1]; *ret = pass(mod); }); diff --git a/src/relay/pass/print_ir.cc b/src/relay/pass/print_ir.cc index 5191a2ec5..e7ed89b68 100644 --- a/src/relay/pass/print_ir.cc +++ b/src/relay/pass/print_ir.cc @@ -32,8 +32,8 @@ namespace relay { namespace transform { Pass PrintIR(bool show_meta_data) { - runtime::TypedPackedFunc pass_func = - [=](Module m, PassContext pc) { + runtime::TypedPackedFunc pass_func = + [=](IRModule m, PassContext pc) { LOG(INFO) << "Dumping the module IR: " << std::endl << AsText(m, show_meta_data); return m; }; diff --git a/src/relay/pass/quantize/annotate.cc b/src/relay/pass/quantize/annotate.cc index 5e1083a95..4b7f15a36 100644 --- a/src/relay/pass/quantize/annotate.cc +++ b/src/relay/pass/quantize/annotate.cc @@ -92,8 +92,8 @@ Pass QuantizeAnnotate() { return e; }; - runtime::TypedPackedFunc pass_func = - [=](Function f, Module m, PassContext pc) { + runtime::TypedPackedFunc pass_func = + [=](Function f, IRModule m, PassContext pc) { auto func = Downcast(ForwardRewrite(f, "FQAnnotateRewrite", nullptr, fmulti_ref)); auto new_params = func->params; for (const auto& x : FreeVars(func)) { diff --git a/src/relay/pass/quantize/partition.cc b/src/relay/pass/quantize/partition.cc index 6ad05e8ca..8fb65a4c7 100644 --- a/src/relay/pass/quantize/partition.cc +++ b/src/relay/pass/quantize/partition.cc @@ -78,8 +78,8 @@ TVM_REGISTER_GLOBAL("relay._quantize.make_partition_expr") }); Pass QuantizePartition() { - runtime::TypedPackedFunc pass_func = - [=](Function f, Module m, PassContext pc) { + runtime::TypedPackedFunc pass_func = + [=](Function f, IRModule m, PassContext pc) { auto ret = Downcast( ForwardRewrite(f, "FQPartitionRewrite", nullptr, nullptr)); return ret; diff --git a/src/relay/pass/quantize/realize.cc b/src/relay/pass/quantize/realize.cc index 5f808714b..ae3724502 100644 --- a/src/relay/pass/quantize/realize.cc +++ b/src/relay/pass/quantize/realize.cc @@ -190,7 +190,7 @@ Expr QuantizeRealize(const Call& ref_call, } Expr FoldConstantOpt(const Expr& expr) { - auto mod = ModuleNode::FromExpr(expr); + auto mod = IRModule::FromExpr(expr); mod = transform::FoldConstant()(mod); auto entry_func = Downcast(mod->Lookup("main")); return expr.as() == nullptr ? entry_func->body : entry_func; @@ -522,8 +522,8 @@ RELAY_REGISTER_OP("annotation.cast_hint") .set_attr("FQRealizeRewrite", CastHintRealize); Pass QuantizeRealizePass() { - runtime::TypedPackedFunc pass_func = - [=](Function f, Module m, PassContext pc) { + runtime::TypedPackedFunc pass_func = + [=](Function f, IRModule m, PassContext pc) { return Downcast( ForwardRewrite(f, "FQRealizeRewrite", nullptr, nullptr)); }; diff --git a/src/relay/pass/simplify_inference.cc b/src/relay/pass/simplify_inference.cc index 108edfcfe..32fc06f93 100644 --- a/src/relay/pass/simplify_inference.cc +++ b/src/relay/pass/simplify_inference.cc @@ -183,8 +183,8 @@ Expr SimplifyInference(const Expr& e) { namespace transform { Pass SimplifyInference() { - runtime::TypedPackedFunc pass_func = - [=](Function f, Module m, PassContext pc) { + runtime::TypedPackedFunc pass_func = + [=](Function f, IRModule m, PassContext pc) { return Downcast(SimplifyInference(f)); }; return CreateFunctionPass(pass_func, 0, "SimplifyInference", diff --git a/src/relay/pass/to_a_normal_form.cc b/src/relay/pass/to_a_normal_form.cc index c839beb42..c08deefd2 100644 --- a/src/relay/pass/to_a_normal_form.cc +++ b/src/relay/pass/to_a_normal_form.cc @@ -291,7 +291,7 @@ Expr ToANormalFormAux(const Expr& e) { return Fill::ToANormalForm(e, dg, &node_scope); } -Module ToANormalForm(const Module& m) { +IRModule ToANormalForm(const IRModule& m) { DLOG(INFO) << "ToANF:" << std::endl << m; tvm::Map updates; @@ -321,9 +321,9 @@ Module ToANormalForm(const Module& m) { namespace transform { Pass ToANormalForm() { - runtime::TypedPackedFunc pass_func = - [=](Module m, PassContext pc) { - return ToANormalForm(m); + runtime::TypedPackedFunc pass_func = + [=](IRModule m, PassContext pc) { + return relay::ToANormalForm(m); }; return CreateModulePass(pass_func, 1, "ToANormalForm", {}); } diff --git a/src/relay/pass/to_cps.cc b/src/relay/pass/to_cps.cc index 898e4e9f8..3ffc4ad95 100644 --- a/src/relay/pass/to_cps.cc +++ b/src/relay/pass/to_cps.cc @@ -111,21 +111,27 @@ using VarMap = std::unordered_map; */ using MCont = std::function; -Function ToCPS(const Function& f, const Module& m, CPSMap* cm); +Function ToCPS(const Function& f, const IRModule& m, CPSMap* cm); -Function ToCPS(const Function& f, const Module& m, CPSMap* cm, VarMap* vm, const TypeVar& answer) { - std::function remap = [&](const Var& v) { return vm->count(v) == 0 ? v : vm->at(v); }; +Function ToCPS(const Function& f, + const IRModule& m, + CPSMap* cm, + VarMap* vm, + const TypeVar& answer) { + std::function remap = [&](const Var& v) { + return vm->count(v) == 0 ? v : vm->at(v); + }; auto function_type = Downcast(f->checked_type()); // Each MCont can be used at most once. struct CPSFunctor : ExprFunctor, PatternMutator { CPSFunctor(const std::function& remap, const TypeVar& answer, - const Module& m, + const IRModule& m, VarMap* vm, CPSMap* cm) : remap(remap), answer(answer), m(m), vm(vm), cm(cm) { } const std::function& remap; TypeVar answer; - Module m; + IRModule m; VarMap* vm; CPSMap* cm; @@ -295,7 +301,7 @@ Function ToCPS(const Function& f, const Module& m, CPSMap* cm, VarMap* vm, const f->attrs); } -Function ToCPS(const Function& f, const Module& m, CPSMap* cm) { +Function ToCPS(const Function& f, const IRModule& m, CPSMap* cm) { TypeVar answer = TypeVarNode::make("answer", kType); VarMap var; struct Remapper : ExprVisitor, PatternVisitor { @@ -325,7 +331,7 @@ Function ToCPS(const Function& f, const Module& m, CPSMap* cm) { return FunctionNode::make(ret->params, ret->body, ret->ret_type, new_type_params, ret->attrs); } -Function ToCPS(const Function& f, const Module& m) { +Function ToCPS(const Function& f, const IRModule& m) { CPSMap cps; return ToCPS(f, m, &cps); } @@ -368,7 +374,7 @@ Function UnCPS(const Function& f) { } TVM_REGISTER_GLOBAL("relay._transform.to_cps") -.set_body_typed(static_cast(ToCPS)); +.set_body_typed(static_cast(ToCPS)); TVM_REGISTER_GLOBAL("relay._transform.un_cps") .set_body_typed(UnCPS); @@ -376,8 +382,8 @@ TVM_REGISTER_GLOBAL("relay._transform.un_cps") namespace transform { Pass ToCPS() { - runtime::TypedPackedFunc pass_func = - [=](Function f, Module m, PassContext pc) { + runtime::TypedPackedFunc pass_func = + [=](Function f, IRModule m, PassContext pc) { return Function(ToCPS(f, m)); }; return CreateFunctionPass(pass_func, 1, "ToCPS", {}); @@ -388,8 +394,8 @@ TVM_REGISTER_GLOBAL("relay._transform.ToCPS") Pass UnCPS() { - runtime::TypedPackedFunc pass_func = - [=](Function f, Module m, PassContext pc) { + runtime::TypedPackedFunc pass_func = + [=](Function f, IRModule m, PassContext pc) { return Function(UnCPS(f)); }; return CreateFunctionPass(pass_func, 1, "UnCPS", {}); diff --git a/src/relay/pass/to_graph_normal_form.cc b/src/relay/pass/to_graph_normal_form.cc index c9eeefddc..b6ff2490a 100644 --- a/src/relay/pass/to_graph_normal_form.cc +++ b/src/relay/pass/to_graph_normal_form.cc @@ -79,8 +79,8 @@ Expr ToGraphNormalForm(const Expr& e) { namespace transform { Pass ToGraphNormalForm() { - runtime::TypedPackedFunc pass_func = - [=](Function f, Module m, PassContext pc) { + runtime::TypedPackedFunc pass_func = + [=](Function f, IRModule m, PassContext pc) { return Downcast(ToGraphNormalForm(f)); }; return CreateFunctionPass(pass_func, 1, "ToGraphNormalForm", {}); diff --git a/src/relay/pass/type_infer.cc b/src/relay/pass/type_infer.cc index 23ed83caf..876cf481c 100644 --- a/src/relay/pass/type_infer.cc +++ b/src/relay/pass/type_infer.cc @@ -105,7 +105,7 @@ class TypeInferencer : private ExprFunctor, public: // constructors - explicit TypeInferencer(Module mod, GlobalVar current_func) + explicit TypeInferencer(IRModule mod, GlobalVar current_func) : mod_(mod), current_func_(current_func), err_reporter(), solver_(current_func, mod, &this->err_reporter) { CHECK(mod.defined()) << "internal error: Module must be set in the type inferencer"; @@ -118,7 +118,7 @@ class TypeInferencer : private ExprFunctor, // type resolver that maps back to type class Resolver; // internal environment - Module mod_; + IRModule mod_; // The current function being type checked. GlobalVar current_func_; @@ -798,7 +798,7 @@ void EnsureCheckedType(const Expr& e) { AllCheckTypePopulated().VisitExpr(e); } -Expr InferType(const Expr& expr, const Module& mod) { +Expr InferType(const Expr& expr, const IRModule& mod) { auto main = mod->GetGlobalVar("main"); auto inferencer = TypeInferencer(mod, main); auto e = inferencer.Infer(expr); @@ -811,7 +811,7 @@ Expr InferType(const Expr& expr, const Module& mod) { } Function InferType(const Function& func, - const Module& mod, + const IRModule& mod, const GlobalVar& var) { CHECK(mod.defined()) << "internal error: module must be set for type inference"; Function func_copy = Function(make_object(*func.operator->())); @@ -832,8 +832,8 @@ Function InferType(const Function& func, namespace transform { Pass InferType() { - runtime::TypedPackedFunc pass_func = - [=](Function f, Module m, PassContext pc) { + runtime::TypedPackedFunc pass_func = + [=](Function f, IRModule m, PassContext pc) { return Downcast(InferType(f, m)); }; return CreateFunctionPass(pass_func, 0, "InferType", {}); diff --git a/src/relay/pass/type_solver.cc b/src/relay/pass/type_solver.cc index ceed96471..372b35169 100644 --- a/src/relay/pass/type_solver.cc +++ b/src/relay/pass/type_solver.cc @@ -60,7 +60,7 @@ class TypeSolver::Reporter : public TypeReporterNode { location = ref; } - TVM_DLL Module GetModule() final { + TVM_DLL IRModule GetModule() final { return this->solver_->module_; } @@ -531,7 +531,7 @@ class TypeSolver::Merger : public TypeFunctor { // constructor TypeSolver::TypeSolver( const GlobalVar& current_func, - const Module& module, + const IRModule& module, ErrorReporter* err_reporter) : reporter_(make_object(this)), current_func(current_func), @@ -661,7 +661,7 @@ TVM_REGISTER_GLOBAL("relay._analysis._test_type_solver") using runtime::PackedFunc; using runtime::TypedPackedFunc; ErrorReporter *err_reporter = new ErrorReporter(); - auto module = ModuleNode::make({}, {}); + auto module = IRModule({}, {}); auto dummy_fn_name = GlobalVar("test"); module->Add(dummy_fn_name, FunctionNode::make({}, TupleNode::make({}), Type(), {}, {})); auto solver = std::make_shared(dummy_fn_name, module, err_reporter); diff --git a/src/relay/pass/type_solver.h b/src/relay/pass/type_solver.h index bf1ac716c..eba1bea7c 100644 --- a/src/relay/pass/type_solver.h +++ b/src/relay/pass/type_solver.h @@ -62,7 +62,7 @@ using common::LinkedList; */ class TypeSolver { public: - TypeSolver(const GlobalVar& current_func, const Module& _mod, ErrorReporter* err_reporter); + TypeSolver(const GlobalVar& current_func, const IRModule& _mod, ErrorReporter* err_reporter); ~TypeSolver(); /*! * \brief Add a type constraint to the solver. @@ -179,7 +179,7 @@ class TypeSolver { /*! \brief Error reporting. */ ErrorReporter* err_reporter_; /*! \brief The module. */ - Module module_; + IRModule module_; /*! * \brief GetTypeNode that is corresponds to t. diff --git a/src/relay/pass/util.cc b/src/relay/pass/util.cc index 577f49212..e45b15a0c 100644 --- a/src/relay/pass/util.cc +++ b/src/relay/pass/util.cc @@ -72,7 +72,7 @@ class TypeVarTVisitor : public TypeVisitor { class TypeVarEVisitor : private ExprVisitor { public: - explicit TypeVarEVisitor(const Module& mod) : mod_(mod) {} + explicit TypeVarEVisitor(const IRModule& mod) : mod_(mod) {} Array CollectFree() { Array ret; @@ -156,7 +156,7 @@ class TypeVarEVisitor : private ExprVisitor { private: InsertionSet type_vars_; InsertionSet bound_type_vars_; - const Module& mod_; + const IRModule& mod_; }; class VarVisitor : protected ExprVisitor, protected PatternVisitor { @@ -234,27 +234,27 @@ class VarVisitor : protected ExprVisitor, protected PatternVisitor { InsertionSet bound_vars_; }; -tvm::Array FreeTypeVars(const Expr& expr, const Module& mod) { +tvm::Array FreeTypeVars(const Expr& expr, const IRModule& mod) { return TypeVarEVisitor(mod).Free(expr); } -tvm::Array FreeTypeVars(const Type& type, const Module& mod) { +tvm::Array FreeTypeVars(const Type& type, const IRModule& mod) { return TypeVarEVisitor(mod).Free(type); } -tvm::Array BoundTypeVars(const Expr& expr, const Module& mod) { +tvm::Array BoundTypeVars(const Expr& expr, const IRModule& mod) { return TypeVarEVisitor(mod).Bound(expr); } -tvm::Array BoundTypeVars(const Type& type, const Module& mod) { +tvm::Array BoundTypeVars(const Type& type, const IRModule& mod) { return TypeVarEVisitor(mod).Bound(type); } -tvm::Array AllTypeVars(const Expr& expr, const Module& mod) { +tvm::Array AllTypeVars(const Expr& expr, const IRModule& mod) { return TypeVarEVisitor(mod).All(expr); } -tvm::Array AllTypeVars(const Type& type, const Module& mod) { +tvm::Array AllTypeVars(const Type& type, const IRModule& mod) { return TypeVarEVisitor(mod).All(type); } @@ -293,7 +293,7 @@ TVM_REGISTER_GLOBAL("relay._analysis.all_vars") TVM_REGISTER_GLOBAL("relay._analysis.free_type_vars") .set_body([](TVMArgs args, TVMRetValue* ret) { ObjectRef x = args[0]; - Module mod = args[1]; + IRModule mod = args[1]; if (x.as()) { *ret = FreeTypeVars(Downcast(x), mod); } else { @@ -304,7 +304,7 @@ TVM_REGISTER_GLOBAL("relay._analysis.free_type_vars") TVM_REGISTER_GLOBAL("relay._analysis.bound_type_vars") .set_body([](TVMArgs args, TVMRetValue* ret) { ObjectRef x = args[0]; - Module mod = args[1]; + IRModule mod = args[1]; if (x.as()) { *ret = BoundTypeVars(Downcast(x), mod); } else { @@ -315,7 +315,7 @@ TVM_REGISTER_GLOBAL("relay._analysis.bound_type_vars") TVM_REGISTER_GLOBAL("relay._analysis.all_type_vars") .set_body([](TVMArgs args, TVMRetValue* ret) { ObjectRef x = args[0]; - Module mod = args[1]; + IRModule mod = args[1]; if (x.as()) { *ret = AllTypeVars(Downcast(x), mod); } else { diff --git a/tests/cpp/relay_pass_type_infer_test.cc b/tests/cpp/relay_pass_type_infer_test.cc index f727404fb..070892b1a 100644 --- a/tests/cpp/relay_pass_type_infer_test.cc +++ b/tests/cpp/relay_pass_type_infer_test.cc @@ -33,7 +33,7 @@ TEST(Relay, SelfReference) { auto y = relay::VarNode::make("y", tensor_type); auto call = relay::CallNode::make(f, Array{ y }); auto fx = relay::FunctionNode::make(tvm::Array{ y }, call, relay::Type(), {}); - auto mod = relay::ModuleNode::FromExpr(fx); + auto mod = IRModule::FromExpr(fx); mod = relay::transform::InferType()(mod); auto type_fx = mod->Lookup("main"); diff --git a/tests/cpp/relay_transform_sequential.cc b/tests/cpp/relay_transform_sequential.cc index 22934a6b8..8321c580a 100644 --- a/tests/cpp/relay_transform_sequential.cc +++ b/tests/cpp/relay_transform_sequential.cc @@ -22,7 +22,7 @@ #include #include #include -#include +#include #include #include #include @@ -73,7 +73,7 @@ TEST(Relay, Sequential) { relay::transform::AlterOpLayout() }; relay::transform::Pass seq = relay::transform::Sequential(pass_seqs); - auto mod = relay::ModuleNode::FromExpr(func); + auto mod = IRModule::FromExpr(func); auto pass_ctx = relay::transform::PassContext::Create(); pass_ctx->opt_level = 3; pass_ctx->fallback_device = 1; @@ -100,7 +100,7 @@ TEST(Relay, Sequential) { relay::FunctionNode::make(relay::FreeVars(zz), zz, relay::Type(), {}); // Infer type for the expected function. - auto mod1 = relay::ModuleNode::FromExpr(expected_func); + auto mod1 = IRModule::FromExpr(expected_func); mod1 = relay::transform::InferType()(mod1); auto expected = mod1->Lookup("main"); CHECK(relay::AlphaEqual(f, expected)); -- 2.34.1