self.assertTrue(imported.ssm.asm._has_parameter('param'))
self.assertTrue(hasattr(imported.ssm.asm, 'param'))
+ def test_trace_parameter(self):
+ class Param(nn.Module):
+ def __init__(self):
+ super(Param, self).__init__()
+ self.register_parameter("bias", nn.Parameter(torch.Tensor(4, 4)))
+
+ def forward(self, x):
+ return x
+
+ class M3(torch.jit.ScriptModule):
+ def __init__(self, model):
+ super(M3, self).__init__(False)
+ self.traced = torch.jit.trace(model, (torch.rand(3, 3)))
+
+ @torch.jit.script_method
+ def forward(self, x):
+ return self.traced(x)
+
+ class M2(nn.Module):
+ def __init__(self, model):
+ super(M2, self).__init__()
+ self.module = M3(model)
+
+ def forward(self, x):
+ return self.module(x)
+
+ class M1(torch.jit.ScriptModule):
+ def __init__(self, model):
+ super(M1, self).__init__(False)
+ self.traced = torch.jit.trace(M2(model), (torch.rand(3, 3)))
+
+ @torch.jit.script_method
+ def forward(self, x):
+ return self.traced(x)
+
+ module = M1(Param())
+ f = io.BytesIO()
+ torch.jit.save(module, f)
+
def test_call_traced_module_from_traced_module(self):
class TracedModule1(torch.nn.Module):
def __init__(self):
void save(const std::string& filename);
- void copy_into(std::function<std::shared_ptr<Module>(std::vector<std::string>)> module_lookup, std::vector<std::string> names = {}) const {
- std::unordered_map<at::Tensor*, at::Tensor*> parameter_remap;
+ void copy_into(std::function<std::shared_ptr<Module>(
+ std::vector<std::string>)> module_lookup,
+ // parameter_remap is needed when a parent module uses a parameter of a submodule
+ std::unordered_map<at::Tensor*, at::Tensor*>& parameter_remap,
+ std::vector<std::string> names = {}) const {
auto curr = module_lookup(names);
for (auto &kv : parameters) {
curr->register_parameter(kv.key(), *kv.value().slot(), kv.value().is_buffer);
}
for (auto &kv : modules) {
names.push_back(kv.key());
- kv.value().module->copy_into(module_lookup, names);
+ // Submodules must be translated first, otherwise parameter_remap entries
+ // will not be filled in for methods of this module.
+ kv.value().module->copy_into(module_lookup, parameter_remap, names);
names.pop_back();
}
for (auto &kv : methods) {
std::vector<at::Tensor*> params;
for (auto &p : kv.value()->params()) {
- params.push_back(parameter_remap[p]);
+ params.push_back(parameter_remap.at(p));
}
curr->create_method(kv.key(), kv.value()->graph()->copy(), params);
}
setattr(curr, name, ScriptModule())
curr = getattr(curr, name)
return curr
- self._copy_into(module_lookup, [])
+ self._copy_into(module_lookup, {}, [])
return m
def __getstate__(self):