traced_model.cpu()
cpu_out = traced_model(x.float())
self.assertEqual(cpu_out, cuda_out)
+ traced_model.to('cuda')
+ cuda_out = traced_model(x.float().cuda())
+ traced_model.to('cpu')
+ cpu_out = traced_model(x.float())
+ self.assertEqual(cpu_out, cuda_out)
traced_model.double()
# state_dict + load_state_dict
_compiled_methods_whitelist = {
'forward', 'register_buffer', 'register_parameter', 'add_module',
- '_apply', 'apply', 'cuda', 'cpu', 'type', 'float', 'double', 'half',
+ '_apply', 'apply', 'cuda', 'cpu', 'to', 'type', 'float', 'double', 'half',
'state_dict', 'load_state_dict', '_load_from_state_dict',
'_named_members', 'parameters', 'named_parameters',
'buffers', 'named_buffers', 'children', 'named_children', 'modules',