struct Module;
+struct NamedModule {
+ std::string name;
+ std::shared_ptr<Module> module;
+};
+
struct NamedIValue {
NamedIValue(std::string name, TypePtr type, IValue ivalue)
: name_(name),
struct Module {
TH_DISALLOW_COPY_AND_ASSIGN(Module);
- Module() : name_("__main__"), optimize_(true) {}
-
- const std::string& name() const {
- return name_;
- }
+ Module() : optimize(true) {}
// note this doesn't change the flags of existing methods just ones
// added afterward.
void set_optimized(bool o) {
- optimize_ = o;
+ optimize = o;
}
bool is_optimized() const {
- return optimize_;
+ return optimize;
}
IValue forward(std::vector<IValue> inputs) {
void register_module(
const std::string& name,
std::shared_ptr<Module> module) {
- if (module->parent_) {
- AT_ERROR(
- "Attempting to assign submodule '",
- name,
- "' but it is already a submodule of another ScriptModule '", module->parent_->name(), "'");
- }
- module->parent_ = this;
- module->name_ = name;
- insert(name, modules_, EntityType::MODULE, std::move(module));
+ insert(name, modules_, EntityType::MODULE, {name, std::move(module)});
}
Method& create_method(
std::unique_ptr<Method> method(new Method(
this,
name,
- optimize_,
+ optimize,
std::move(graph),
std::move(member_inputs),
nullptr));
std::unique_ptr<Method> method(new Method(
this,
name,
- optimize_,
+ optimize,
std::make_shared<Graph>(),
{},
std::move(creator)));
}
std::shared_ptr<Module> get_module(const std::string& name) const {
- return modules_[get_offset(name, EntityType::MODULE)];
+ return modules_[get_offset(name, EntityType::MODULE)].module;
}
- c10::ArrayRef<std::shared_ptr<Module>> get_modules() const {
+ c10::ArrayRef<NamedModule> get_modules() const {
return modules_;
}
c10::ArrayRef<NamedIValue> get_parameters() const {
}
return nullptr;
}
- std::shared_ptr<Module> find_module(const std::string& name) {
+ NamedModule* find_module(const std::string& name) {
auto offset = find_offset(name, EntityType::MODULE);
- return offset ? modules_[*offset] : nullptr;
+ return offset ? &modules_[*offset] : nullptr;
}
Method* find_method(const std::string& name) {
auto offset = find_offset(name, EntityType::METHOD);
}
void apply(std::function<void(Module&)> fn) {
for (auto& submod : get_modules()) {
- submod->apply(fn);
+ submod.module->apply(fn);
}
fn(*this);
}
/// Enables "training" mode.
void train(bool on = true) {
for (auto& submod : get_modules()) {
- submod->train(on);
+ submod.module->train(on);
}
register_buffer("training", torch::tensor(on ? 1 : 0, at::kLong));
}
parameter_remap[attr.slot()] = curr->find_buffer(attr.name())->slot();
}
for (auto& mod : get_modules()) {
- names.push_back(mod->name());
+ names.push_back(mod.name);
// Submodules must be translated first, otherwise parameter_remap entries
// will not be filled in for methods of this module.
- mod->copy_into(module_lookup, parameter_remap, names);
+ mod.module->copy_into(module_lookup, parameter_remap, names);
names.pop_back();
}
for (auto& method : get_methods()) {
const c10::optional<at::ScalarType>& dtype,
bool non_blocking);
+ // modules have a single namespace, but spread over 4 different concepts:
+ // parameters, attributes, methods, and sub-modules
+ // we store individual lists of each concept, and a single map to
+ // unify the namespace and ensure fast lookup
+
+ // invariant: to ensure initial_ivalues of Methods stay valid,
+ // it is only legal to _add_ new modules and parameters.
+ // removing them will allow initial_ivalues to point to invalid parameters
+ // no such restriction exists for methods
+ std::vector<NamedModule> modules_;
+ std::vector<NamedIValue> parameters_;
+ std::vector<NamedIValue> attributes_;
+ std::vector<std::unique_ptr<Method>> methods_;
+
static const char* toString(EntityType t) {
switch (t) {
case EntityType::MODULE:
return list.back();
}
- // modules have a single namespace, but spread over 4 different concepts:
- // parameters, attributes, methods, and sub-modules
- // we store individual lists of each concept, and a single map to
- // unify the namespace and ensure fast lookup
-
- // invariant: to ensure initial_ivalues of Methods stay valid,
- // it is only legal to _add_ new modules and parameters.
- // removing them will allow initial_ivalues to point to invalid parameters
- // no such restriction exists for methods
- std::vector<std::shared_ptr<Module>> modules_;
- std::vector<NamedIValue> parameters_;
- std::vector<NamedIValue> attributes_;
- std::vector<std::unique_ptr<Method>> methods_;
-
std::unordered_map<std::string, Entry> dict_;
- std::string name_;
- // back reference to parent of this Module if present
- Module* parent_ = nullptr;
- bool optimize_;
+ bool optimize;
};
// returns nullptr and fills in failure_messages if the callee does not