struct Module;
-struct NamedModule {
- std::string name;
- std::shared_ptr<Module> module;
-};
-
struct NamedIValue {
NamedIValue(std::string name, TypePtr type, IValue ivalue)
: name_(name),
struct Module {
TH_DISALLOW_COPY_AND_ASSIGN(Module);
- Module() : optimize(true) {}
+ Module() : name_("__main__"), optimize_(true) {}
+
+ const std::string& name() const {
+ return name_;
+ }
// note this doesn't change the flags of existing methods just ones
// added afterward.
void set_optimized(bool o) {
- optimize = o;
+ optimize_ = o;
}
bool is_optimized() const {
- return optimize;
+ return optimize_;
}
IValue forward(std::vector<IValue> inputs) {
void register_module(
const std::string& name,
std::shared_ptr<Module> module) {
- insert(name, modules_, EntityType::MODULE, {name, std::move(module)});
+ // We would like to enable more stringent error checking at this point,
+ // but because script functions are considered modules, it is possible
+ // to hit this situation without knowing it. For now this is disabled
+ // until a later PR that distinguishes script functions from script modules.
+ // See TestScript.test_submodule_twice for example failure
+ // if (module->parent_) {
+ // AT_WARN(
+ // "Attempting to assign submodule '",
+ // name,
+ // "' but it is already a submodule of another ScriptModule '", module->parent_->name(), "'",
+ // " Modules of this form do not import and export correctly. This use is deprecated and may be"
+ // " removed in a future version.");
+ // }
+ module->parent_ = this;
+ module->name_ = name;
+ insert(name, modules_, EntityType::MODULE, std::move(module));
}
Method& create_method(
std::unique_ptr<Method> method(new Method(
this,
name,
- optimize,
+ optimize_,
std::move(graph),
std::move(member_inputs),
nullptr));
std::unique_ptr<Method> method(new Method(
this,
name,
- optimize,
+ optimize_,
std::make_shared<Graph>(),
{},
std::move(creator)));
}
std::shared_ptr<Module> get_module(const std::string& name) const {
- return modules_[get_offset(name, EntityType::MODULE)].module;
+ return modules_[get_offset(name, EntityType::MODULE)];
}
- c10::ArrayRef<NamedModule> get_modules() const {
+ c10::ArrayRef<std::shared_ptr<Module>> get_modules() const {
return modules_;
}
c10::ArrayRef<NamedIValue> get_parameters() const {
}
return nullptr;
}
- NamedModule* find_module(const std::string& name) {
+ std::shared_ptr<Module> find_module(const std::string& name) {
auto offset = find_offset(name, EntityType::MODULE);
- return offset ? &modules_[*offset] : nullptr;
+ return offset ? modules_[*offset] : nullptr;
}
Method* find_method(const std::string& name) {
auto offset = find_offset(name, EntityType::METHOD);
}
void apply(std::function<void(Module&)> fn) {
for (auto& submod : get_modules()) {
- submod.module->apply(fn);
+ submod->apply(fn);
}
fn(*this);
}
/// Enables "training" mode.
void train(bool on = true) {
for (auto& submod : get_modules()) {
- submod.module->train(on);
+ submod->train(on);
}
register_buffer("training", torch::tensor(on ? 1 : 0, at::kLong));
}
parameter_remap[attr.slot()] = curr->find_buffer(attr.name())->slot();
}
for (auto& mod : get_modules()) {
- names.push_back(mod.name);
+ names.push_back(mod->name());
// Submodules must be translated first, otherwise parameter_remap entries
// will not be filled in for methods of this module.
- mod.module->copy_into(module_lookup, parameter_remap, names);
+ mod->copy_into(module_lookup, parameter_remap, names);
names.pop_back();
}
for (auto& method : get_methods()) {
const c10::optional<at::ScalarType>& dtype,
bool non_blocking);
- // modules have a single namespace, but spread over 4 different concepts:
- // parameters, attributes, methods, and sub-modules
- // we store individual lists of each concept, and a single map to
- // unify the namespace and ensure fast lookup
-
- // invariant: to ensure initial_ivalues of Methods stay valid,
- // it is only legal to _add_ new modules and parameters.
- // removing them will allow initial_ivalues to point to invalid parameters
- // no such restriction exists for methods
- std::vector<NamedModule> modules_;
- std::vector<NamedIValue> parameters_;
- std::vector<NamedIValue> attributes_;
- std::vector<std::unique_ptr<Method>> methods_;
-
static const char* toString(EntityType t) {
switch (t) {
case EntityType::MODULE:
return list.back();
}
+ // modules have a single namespace, but spread over 4 different concepts:
+ // parameters, attributes, methods, and sub-modules
+ // we store individual lists of each concept, and a single map to
+ // unify the namespace and ensure fast lookup
+
+ // invariant: to ensure initial_ivalues of Methods stay valid,
+ // it is only legal to _add_ new modules and parameters.
+ // removing them will allow initial_ivalues to point to invalid parameters
+ // no such restriction exists for methods
+ std::vector<std::shared_ptr<Module>> modules_;
+ std::vector<NamedIValue> parameters_;
+ std::vector<NamedIValue> attributes_;
+ std::vector<std::unique_ptr<Method>> methods_;
+
std::unordered_map<std::string, Entry> dict_;
+ std::string name_;
- bool optimize;
+ // back reference to parent of this Module if present
+ Module* parent_ = nullptr;
+ bool optimize_;
};
// returns nullptr and fills in failure_messages if the callee does not