# Note: neg op from the traced function should be properly inlined
self.assertExpected(canonical(tm.graph))
+ def test_trace_hierarchy(self):
+ # Test that we preserve the module hierarchy for a ScriptModule
+ # submodule during tracing
+
+ class AnotherScriptMod(torch.jit.ScriptModule):
+ def __init__(self):
+ super(AnotherScriptMod, self).__init__()
+ self.param = torch.nn.Parameter(torch.rand(1, 2, 3))
+
+ @torch.jit.script_method
+ def bar(self):
+ return torch.zeros(4, 5)
+
+ class SomeScriptMod(torch.jit.ScriptModule):
+ def __init__(self):
+ super(SomeScriptMod, self).__init__()
+ self.asm = AnotherScriptMod()
+
+ @torch.jit.script_method
+ def foo(self):
+ return torch.zeros(3, 4)
+
+ @torch.jit.script_method
+ def bar(self):
+ return torch.zeros(4, 3)
+
+ class TraceMe(torch.nn.Module):
+ def __init__(self):
+ super(TraceMe, self).__init__()
+ self.ssm = SomeScriptMod()
+
+ def forward(self, x):
+ return self.ssm.bar() + x
+
+ orig = TraceMe()
+ traced = torch.jit.trace(orig, (torch.rand(4, 3, dtype=torch.float),))
+ # For each of these checks, verify that *BOTH* the underlying
+ # _C.ScriptModule object has the expected method/param *AND* that the
+ # Python object wrapping it exposes it as an attribute.
+ self.assertTrue(traced.ssm._has_method('foo'))
+ self.assertTrue(hasattr(traced.ssm, 'foo'))
+
+ imported = self.getExportImportCopy(traced)
+
+ self.assertTrue(imported.ssm._has_method('foo'))
+ self.assertTrue(hasattr(imported.ssm, 'foo'))
+
+ self.assertTrue(imported.ssm.asm._has_method('bar'))
+ self.assertTrue(hasattr(imported.ssm.asm, 'bar'))
+
+ self.assertTrue(imported.ssm.asm._has_parameter('param'))
+ self.assertTrue(hasattr(imported.ssm.asm, 'param'))
+
def test_call_traced_module_from_traced_module(self):
class TracedModule1(torch.nn.Module):
def __init__(self):
class ScriptMod(torch.jit.ScriptModule):
def __init__(self):
super(ScriptMod, self).__init__()
- self.param = torch.nn.Parameter(torch.rand(5, 7))
+ self.param_foo = torch.nn.Parameter(torch.rand(5, 7))
@torch.jit.script_method
def forward(self, x):
- return torch.mm(x, self.param)
+ return torch.mm(x, self.param_foo)
class TracedModule(torch.nn.Module):
def __init__(self):
void save(const std::string& filename);
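+ // Recursively copies this module's parameters, submodules, and method
+ // graphs into a destination hierarchy. module_lookup takes a path of
+ // submodule names and is expected to return (creating if necessary) the
+ // destination module at that path; names tracks the current path while
+ // recursing.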
+ void copy_into(
+ std::function<std::shared_ptr<Module>(std::vector<std::string>)> module_lookup,
+ std::vector<std::string> names = {}) const {
+ std::unordered_map<at::Tensor*, at::Tensor*> parameter_remap;
+ auto curr = module_lookup(names);
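+ // Register each parameter (or buffer) on the destination module and
+ // record the mapping from old parameter slot to new slot so the copied
+ // method graphs can be pointed at the new slots below.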
+ for (auto &kv : parameters) {
+ curr->register_parameter(kv.key(), *kv.value().slot(), kv.value().is_buffer);
+ parameter_remap[kv.value().slot()] = curr->parameter_slot(kv.key());
+ }
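+ // Recurse into submodules, extending the qualified name path for each.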
+ for (auto &kv : modules) {
+ names.push_back(kv.key());
+ kv.value().module->copy_into(module_lookup, names);
+ names.pop_back();
+ }
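+ // Copy each method's graph, remapping its parameter inputs to the
+ // corresponding slots in the destination module.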
+ for (auto &kv : methods) {
+ std::vector<at::Tensor*> params;
+ for (auto &p : kv.value()->params()) {
+ params.push_back(parameter_remap[p]);
+ }
+ curr->create_method(kv.key(), kv.value()->graph()->copy(), params);
+ }
+ }
+
private:
void to_impl(
const c10::optional<at::Device>& device,
rcb = createResolutionCallback(frames_up=1)
self._define(lang, rcb, True)
+ def copy(self):
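+ # Build a fresh ScriptModule hierarchy and have the C++ side copy
+ # parameters, submodules, and method graphs into it.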
+ m = ScriptModule()
+
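+ # Resolve a qualified name path to the corresponding submodule of
+ # the copy, creating empty ScriptModules along the way as needed.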
+ def module_lookup(names):
+ curr = m
+ for name in names:
+ if not hasattr(curr, name):
+ setattr(curr, name, ScriptModule())
+ curr = getattr(curr, name)
+ return curr
+ self._copy_into(module_lookup, [])
+ return m
+
class WeakScriptModuleProxy(ScriptModule):
def __init__(self, original, stubs):
# Guards behavior of __setattr__ and __getattr__ so ScriptModule
raise ValueError("Modules that have hooks assigned can't be compiled")
for name, submodule in orig._modules.items():
- self._modules[name] = TracedModule(submodule, id_set, optimize=optimize)
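+ # ScriptModule submodules already carry their graphs and hierarchy;
+ # copy them as-is instead of wrapping them in TracedModule so the
+ # hierarchy is preserved through tracing.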
+ if isinstance(submodule, ScriptModule) and not isinstance(submodule, TracedModule):
+ self._modules[name] = submodule.copy()
+ else:
+ self._modules[name] = TracedModule(submodule, id_set, optimize=optimize)
self._freeze()