Map<GlobalTypeVar, TypeData> type_definitions = {},
std::unordered_set<String> import_set = {});
/*! \brief default constructor */
- IRModule() {}
+ IRModule() : IRModule(Map<GlobalVar, BaseFunc>()) {}
/*!
* \brief constructor
* \param n The object pointer.
}
/*!
- * \brief Construct an empty module.
- *
- * \returns The constructed module
- */
- static IRModule Empty() { return IRModule(Map<GlobalVar, BaseFunc>()); }
- /*!
* \brief Construct a module from a standalone expression.
*
* Allows one to optionally pass a global function map and
/*! \brief Declare the container type. */
using ContainerType = IRModuleNode;
+
+ /*! \brief Declare whether Ref is nullable. */
+ static constexpr bool _type_is_nullable = false;
+
// allow copy on write.
TVM_DEFINE_OBJECT_REF_COW_METHOD(IRModuleNode);
};
result: tvm.relay.Function
The output function.
"""
- return _ffi_api.to_cps(func, mod)
+ use_mod = mod if mod is not None else tvm.ir.IRModule()
+ return _ffi_api.to_cps(func, use_mod)
def un_cps(func):
return fs;
}
-Array<Integer> PyDetectFeature(const Expr& expr, const IRModule& mod) {
- FeatureSet fs = DetectFeature(expr) + DetectFeature(mod);
+Array<Integer> PyDetectFeature(const Expr& expr, const Optional<IRModule>& mod) {
+ FeatureSet fs = DetectFeature(expr);
+ if (mod.defined()) {
+ fs = fs + DetectFeature(mod.value());
+ }
return static_cast<Array<Integer>>(fs);
}
// expose for testing only
TVM_REGISTER_GLOBAL("relay.analysis.unmatched_cases")
- .set_body_typed([](const Match& match, const IRModule& mod_ref) {
- IRModule call_mod = mod_ref;
- if (!call_mod.defined()) {
- call_mod = IRModule({}, {});
- }
+ .set_body_typed([](const Match& match, const Optional<IRModule>& mod_ref) {
+ IRModule call_mod = mod_ref.defined() ? mod_ref.value() : IRModule({}, {});
return UnmatchedCases(match, call_mod);
});
/*! \brief The schedule to the function */
te::Schedule schedule;
/*! \brief The lowered functions to support the function. */
- IRModule funcs = IRModule::Empty();
+ IRModule funcs = IRModule();
/*! \brief Parameter usage states in the shape function. */
tvm::Array<Integer> shape_func_param_states;
for (auto& kv : lowered_funcs_) {
if (ret.lowered_funcs.count(kv.first) == 0) {
- ret.lowered_funcs.Set(kv.first, IRModule::Empty());
+ ret.lowered_funcs.Set(kv.first, IRModule());
}
auto& mod = ret.lowered_funcs[kv.first];
mod->Update(kv.second);
CCacheKey key = (*pf0)(func, target);
CachedFunc lowered_func = (*pf1)(compile_engine_, key);
if (!lowered_funcs_.count(target->str())) {
- lowered_funcs_[target->str()] = IRModule::Empty();
+ lowered_funcs_[target->str()] = IRModule();
}
lowered_funcs_[target->str()]->Update(lowered_func->funcs);
return GraphAddCallNode(op, _GetUniqueName(lowered_func->func_name), lowered_func->func_name);
/*! return an expression that represent differentiation of e (according to WithGradientType).
* This version only work on first order code without control flow.
*/
-Expr FirstOrderGradient(const Expr& e, const IRModule& mod);
+Expr FirstOrderGradient(const Expr& e, const Optional<IRModule>& mod);
Type WithGradientType(const Type& t) {
// TODO(M.K.): stricter checking
}
//! \brief if the expression is a GlobalVar, transform to it's expression.
-Expr DeGlobal(const IRModule& mod, const Expr& e) {
- if (const auto* x = e.as<GlobalVarNode>()) {
- BaseFunc base_func = mod->Lookup(GetRef<GlobalVar>(x));
+Expr DeGlobal(const Optional<IRModule>& mod, const Expr& e) {
+ const auto* x = e.as<GlobalVarNode>();
+
+ if (mod.defined() && (x)) {
+ BaseFunc base_func = mod.value()->Lookup(GetRef<GlobalVar>(x));
if (auto* n = base_func.as<FunctionNode>()) {
return n->body;
} else {
return TupleType({f->ret_type, TupleType(vt)});
}
-Expr FirstOrderGradient(const Expr& re, const IRModule& mod) {
+Expr FirstOrderGradient(const Expr& re, const Optional<IRModule>& mod) {
// Currently we first remove any global functions for the first
// order case.
auto e = DeGlobal(mod, re);
return false;
}
-Expr Gradient(const Expr& re, const IRModule& mod) {
+Expr Gradient(const Expr& re, const Optional<IRModule>& mod) {
auto e = DeGlobal(mod, re);
auto f = e.as<FunctionNode>();
CHECK(f) << "input need to be a function";
auto pass_func = [](IRModule mod, PassContext ctx) {
IRModuleNode* mod_ptr = mod.CopyOnWrite();
auto* func_dict = mod_ptr->functions.CopyOnWrite();
- IRModule device_mod = IRModule::Empty();
+ IRModule device_mod = IRModule();
for (auto& kv : func_dict->data) {
if (kv.second->IsInstance<PrimFuncNode>()) {