input = torch.ones(1)
self.assertEqual(fn2(input), input)
+ def test_out_of_order_methods(self):
+ # Remove this when import/export is implemented for classes
+ with self.disableModuleHook():
+ @torch.jit.script
+ class FooTest:
+ def __init__(self, x):
+ self.x = x
+ self.x = self.get_stuff(x)
+
+ def get_stuff(self, y):
+ return self.x + y
+
+ @torch.jit.script
+ def fn(x):
+ f = FooTest(x)
+ return f.x
+
+ input = torch.ones(1)
+ self.assertEqual(fn(input), input + input)
+
for test in autograd_method_tests():
add_autograd_test(*test)