Implement 'to' on ScriptModules (#15340)
authorvfdev-5 <vfdev.5@gmail.com>
Wed, 19 Dec 2018 18:34:37 +0000 (10:34 -0800)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Wed, 19 Dec 2018 18:41:23 +0000 (10:41 -0800)
Summary:
Following #6008
Fixes "Implement 'to' on ScriptModules #7354"

cc zdevito
Pull Request resolved: https://github.com/pytorch/pytorch/pull/15340

Differential Revision: D13506646

Pulled By: zdevito

fbshipit-source-id: 318fea2e8e51a37ce9844efa4c8db67d45a66317

test/test_jit.py
torch/jit/__init__.py

index 82f89ca..84daa2e 100644 (file)
@@ -1478,6 +1478,11 @@ class TestJit(JitTestCase):
         traced_model.cpu()
         cpu_out = traced_model(x.float())
         self.assertEqual(cpu_out, cuda_out)
+        traced_model.to('cuda')
+        cuda_out = traced_model(x.float().cuda())
+        traced_model.to('cpu')
+        cpu_out = traced_model(x.float())
+        self.assertEqual(cpu_out, cuda_out)
         traced_model.double()
 
         # state_dict + load_state_dict
index ab06a34..88f52f9 100644 (file)
@@ -1267,7 +1267,7 @@ def _get_methods(cls):
 
 _compiled_methods_whitelist = {
     'forward', 'register_buffer', 'register_parameter', 'add_module',
-    '_apply', 'apply', 'cuda', 'cpu', 'type', 'float', 'double', 'half',
+    '_apply', 'apply', 'cuda', 'cpu', 'to', 'type', 'float', 'double', 'half',
     'state_dict', 'load_state_dict', '_load_from_state_dict',
     '_named_members', 'parameters', 'named_parameters',
     'buffers', 'named_buffers', 'children', 'named_children', 'modules',