trace.set_graph(graph)
return graph
+ def checkScript(self,
+ script,
+ inputs,
+ optimize=True,
+ outputs=None,
+ name='func',
+ capture_output=False,
+ frames_up=1,
+ check_expected=False):
+ if isinstance(script, str):
+ cu = torch.jit.CompilationUnit(script, optimize, _frames_up=frames_up)
+ ge = getattr(cu, name)
+ else:
+ if capture_output:
+ with self.capture_stdout() as captured:
+ outputs = script(*inputs)
+ else:
+ outputs = script(*inputs)
+ # Check the string frontend first
+ source = textwrap.dedent(inspect.getsource(script))
+ self.checkScript(
+ source,
+ inputs,
+ optimize,
+ outputs,
+ script.__name__,
+ capture_output,
+ frames_up=2,
+ check_expected=check_expected)
+ # Continue checking the Python frontend
+ ge = torch.jit.script(script, optimize, _frames_up=1)
+
+ if capture_output:
+ with self.capture_stdout() as captured:
+ outputs_ge = ge(*inputs)
+ if not WINDOWS:
+ self.assertExpected(captured[0], subname='stdout')
+ else:
+ outputs_ge = ge(*inputs)
+ self.assertEqual(outputs, outputs_ge)
+
+ if check_expected:
+ self.assertExpectedGraph(ge.graph)
+
+ return ge
+
def checkTrace(self, func, reference_tensors, input_tensors=None,
optimize=True, drop=None, allow_unused=False, verbose=False,
inputs_require_grads=True, check_tolerance=1e-5, export_import=True):
return ge
- def assertAllFused(self, graph, except_for=()):
- if [n.kind() for n in graph.nodes()] == ['prim::DifferentiableGraph']:
- graph = next(graph.nodes()).g('Subgraph')
- allowed_nodes = {'prim::Constant', 'prim::FusionGroup'} | set(except_for)
- self.assertTrue(all(node.kind() in allowed_nodes for node in graph.nodes()),
- 'got {}'.format(graph))
- self.assertTrue([node.kind() for node in graph.nodes()].count('prim::FusionGroup') == 1)
-
def assertExportImport(self, trace, inputs):
graph = trace if isinstance(trace, torch._C.Graph) else trace.graph()
m = torch.jit.ScriptModule()
self.assertExportImport(trace, (t,) + tuple(model.parameters()))
self.assertExpectedONNXGraph(trace)
- @unittest.skipIf(not IS_WINDOWS, "Testing Fuse skipped on windows")
- @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
- def test_windows_fuse(self):
- def scaleshift(x, scale, shift):
- return x * scale + shift
+ def test_canonicalize_tensor_iterator(self):
+ x = torch.randn(4, 4)
- graph = torch.jit.script(scaleshift).graph
+ def f(x):
+ x = x + 2
+ x = x - 4
+ x = x * 6
+ x = x / 8
+ return x
- inputs = [
- torch.randn(4, 4, dtype=torch.float, device='cuda'),
- torch.randn(4, dtype=torch.float, device='cuda'),
- torch.randn(4, dtype=torch.float, device='cuda'),
- ]
+ traced = torch.jit.trace(f, (x,))
+ f(x)
+ graph = traced.graph_for(x)
+ # There should be 4 int constants for the right sides of operators, plus two
+ # for alpha arguments for add and sub
+ self.assertTrue(str(traced.graph_for(x)).count(': int = prim::Constant'), 6)
- ge = self.checkTrace(scaleshift, inputs)
- fuse_graph = ge.graph_for(*inputs)
+ # TODO: adapt this test to check that GraphExecutor treats them differently
+ @unittest.skip("Need to be adjusted to Graph Executor")
+ def test_arg_configurations(self):
+ """Different arg configurations should trigger different traces"""
+ x = Variable(torch.FloatTensor(4, 4).uniform_())
+ x_double = Variable(x.data.double())
+ x_grad = Variable(x.data.clone(), requires_grad=True)
+ y = Variable(torch.randn(4))
- def run_graph(graph, inputs):
- m = torch.jit.ScriptModule()
- m._create_method_from_graph("forward", graph)
- return m(*inputs)
+ configurations = [
+ (x,),
+ (x_double,),
+ (x_grad,),
+ (y,),
+ ([x, x],),
+ ([x, y],),
+ ]
+ if torch.cuda.is_available():
+ x_cuda = Variable(x.data.cuda())
+ configurations += [
+ (x_cuda,),
+ ([x, x_cuda],),
+ ([x_cuda, x],),
+ ([[x_cuda, x]],),
+ ]
+ if torch.cuda.device_count() > 1:
+ x_cuda_1 = Variable(x.data.cuda(1))
+ configurations += [
+ (x_cuda_1,),
+ ([x_cuda, x_cuda_1],),
+ ]
- self.assertEqual(run_graph(graph, inputs), run_graph(fuse_graph, inputs))
+ @torch.jit.compile(nderivs=0)
+ def fn(*args):
+ in_vars, _ = torch._C._jit_flatten(args)
+ return in_vars[0] + 1
- @unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
- @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
- @skipIfRocm
- def test_broadcast_fusion_cuda(self):
- def scaleshift(x, scale, shift):
- return x * scale + shift
+ for i, config in enumerate(configurations):
+ self.assertFalse(fn.has_trace_for(*config))
+ fn(*config)
+ self.assertTrue(fn.has_trace_for(*config))
+ for unk_config in configurations[i + 1:]:
+ self.assertFalse(fn.has_trace_for(*unk_config))
+ self.assertEqual(fn.hits, 0)
- inputs = [
- torch.randn(4, 4, dtype=torch.float, device='cuda'),
- torch.randn(4, dtype=torch.float, device='cuda'),
- torch.randn(4, dtype=torch.float, device='cuda'),
- ]
- ge = self.checkTrace(scaleshift, inputs)
- self.assertExpectedGraph(ge.graph_for(*inputs))
+ def test_cse(self):
+ x = torch.tensor([0.4, 0.3], requires_grad=True)
+ y = torch.tensor([0.7, 0.5], requires_grad=True)
- # TODO: Fuser doesn't work at all when inputs require grad. Fix that
- @unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
- @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
- @skipIfRocm
- def test_lstm_fusion_cuda(self):
- inputs = get_lstm_inputs('cuda')
- ge = self.checkTrace(LSTMCellF, inputs)
- self.assertExpectedGraph(ge.graph_for(*inputs))
+ def fn(x, y):
+ w = (x + y) * (x + y) * (x + y)
+ t = torch.tanh(w) + torch.tanh(w)
+ z = (x + y) * (x + y) * (x + y) + t
+ return z
- @unittest.skipIf(IS_WINDOWS or IS_SANDCASTLE, "NYI: fuser support for Windows or Sandcastle")
- @unittest.skip("Test is flaky, see https://github.com/pytorch/pytorch/issues/8746")
- @enable_cpu_fuser
- def test_lstm_fusion_cpu(self):
- inputs = get_lstm_inputs('cpu')
- try:
- ge = self.checkTrace(LSTMCellF, inputs)
- self.assertExpectedGraph(ge.graph_for(*inputs))
- except RuntimeError as e:
- if 'Failed to compile' in e.args[0]:
- warnings.warn('CPU fuser test has failed! This is not a hard failure, '
- 'because the kernels sometimes trigger bugs in compilers '
- '(most notably GCC 7.2).')
- raise unittest.SkipTest('Failed to compile')
- else:
- raise
+ trace, _ = torch.jit.get_trace_graph(fn, (x, y))
+ self.run_pass('cse', trace)
+ self.assertExpectedGraph(trace)
+ self.assertExportImport(trace, (x, y))
- @unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
- @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
- @skipIfRocm
- def test_lstm_fusion_concat_cuda(self):
- inputs = get_lstm_inputs('cuda')
- ge = self.checkTrace(LSTMCellC, inputs)
- self.assertExpectedGraph(ge.graph_for(*inputs))
+ def test_recursive_cse(self):
+ x = torch.tensor([0.1])
+ y = torch.tensor([0.2])
- @unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
- @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
- @skipIfRocm
- def test_concat_fusion_cuda(self):
- hx = torch.randn(3, 20, dtype=torch.float, device='cuda')
- cx = torch.randn(3, 20, dtype=torch.float, device='cuda')
+ def fn(x, y):
+ z = x
+ if bool(x + y > x):
+ z = x + y
+ return z
- def foo(hx, cx):
- return torch.cat((hx + cx, hx * cx))
+ graph = torch.jit.script(fn).graph
+ self.run_pass('cse', graph)
+ self.assertExpectedGraph(graph)
- ge = self.checkTrace(foo, (hx, cx))
- self.assertExpectedGraph(ge.graph_for(hx, cx))
+ def test_scalar(self):
+ # NB: must not require grad; if it requires grad, it's always a Tensor
+ x = torch.tensor(2.)
+ y = torch.tensor(3.)
- @unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
- @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
- @skipIfRocm
- def test_concat_fusion_invariant_cuda(self):
- # Invariant: the output of prim::FusedConcat may
- # not be an input to any node inside the FusionGroup.
- def fn(x, y, z):
- x1 = x + y
- y1 = x - y
- w = torch.cat([x1, y1])
- return w + z
+ def fn(x, y):
+ return x - y
+ trace, _ = torch.jit.get_trace_graph(fn, (x, y))
- x = torch.randn(2, 2, dtype=torch.float, device='cuda')
- y = torch.randn(2, 2, dtype=torch.float, device='cuda')
- z = torch.randn(4, 2, dtype=torch.float, device='cuda')
- ge = self.checkTrace(fn, (x, y, z))
- self.assertExpectedGraph(ge.graph_for(x, y, z))
+ def test_shape_analysis_broadcast(self):
+ def broadcast(a, b):
+ return a + b
- @unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
- @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
- @skipIfRocm
- def test_fusion_distribute_cuda(self):
- def f(x, y):
- z1, z2 = (x + y).chunk(2, dim=1)
- return z1 * z2
+ x = torch.randn(3, 1, 5, requires_grad=True)
+ y = torch.randn(4, 1, 8, 5, requires_grad=True)
- x = torch.randn(4, 4, dtype=torch.float, device='cuda')
- y = torch.randn(4, 4, dtype=torch.float, device='cuda')
+ graph = torch.jit.script(broadcast).graph
+ torch._C._jit_pass_complete_shape_analysis(graph, (x, y), False)
+ self.assertExpectedGraph(graph)
- ge = self.checkTrace(f, (x, y))
- self.assertExpectedGraph(ge.graph_for(x, y))
+ # TODO: update verify to work with GraphExecutors
+ @unittest.skip("verify needs to be updated to work with GraphExecutors")
+ def test_verify(self):
+ x = torch.tensor([0.4], requires_grad=True)
+ y = torch.tensor([0.7], requires_grad=True)
- @unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
- @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
- @skipIfRocm
- def test_fusion_rand_cuda(self):
- class M(torch.jit.ScriptModule):
- __constants__ = ['d']
+ @torch.jit.compile
+ def f(x, y):
+ z = torch.sigmoid(x * (x + y))
+ w = torch.abs(x * x * x + y) + Variable(torch.ones(1))
+ return z, w
- def __init__(self):
- self.d = torch.device('cuda')
+ torch.jit.verify(f, (x, y), loss_fn=lambda z, w: z * w, devices=[])
- @torch.jit.script_method
- def create(self, x):
- return x * x + x + torch.rand_like(x)
+ @suppress_warnings
+ def test_constant(self):
+ x = torch.randn(2, 2, requires_grad=True)
- x = torch.zeros([3, 4, 5], dtype=torch.float, device='cuda')
- m = M()
- out1 = m.create(x)
- out2 = m.create(x)
- self.assertNotEqual(out1, out2)
- self.assertTrue(torch.all(out1 >= 0))
- self.assertTrue(torch.all(out1 < 1))
- self.assertTrue(torch.all(out2 >= 0))
- self.assertTrue(torch.all(out2 < 1))
+ def f(x):
+ return x.matmul(torch.diag(torch.tensor([2., 2.])))
- @unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
- @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
- @skipIfRocm
- def test_fusion_arg_configurations_cuda(self):
- # A smoke test to make sure we won't use the same kernel for contiguous
- # and non-contiguous arguments.
- # TODO: add optionally enabled debug counters to the fuser to verify
- # that we really can tell the difference between configurations
- def f(x, y):
- z1, z2 = (x + y).chunk(2, dim=1)
- return z1 * z2
+ self.checkTrace(f, (x,), (torch.ones(2, 2, requires_grad=True),))
- x = torch.randn(4, 4, dtype=torch.float, device='cuda')
- y = torch.randn(4, 4, dtype=torch.float, device='cuda')
- traced_f = torch.jit.trace(f, (x, y,))
- self.assertEqual(traced_f(x.t().contiguous(), y), traced_f(x.t(), y))
+ def test_legacy_fail(self):
+ class MyLegacyFn(Function):
+ def forward(self, x):
+ return x
- @unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
- @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
- @skipIfRocm
- def test_fusion_checks_cat_inputs(self):
- # We shouldn't treat cat nodes as broadcasting. All their inputs
- # need to be checked for having the same map size, before we can
- # run the kernel.
- @torch.jit.script
- def f(x, y):
- return torch.cat([x + 2 * x + x ** 2, y + 4 * y + y ** 3], dim=0)
+ def backward(self, grad_output):
+ return grad_output
- # NOTE: y is broadcastable to x, but output of f(x, y) should have
- # shape 3x4, and not 4x4.
- x = torch.randn(2, 4, dtype=torch.float, device='cuda')
- y = torch.randn(1, 4, dtype=torch.float, device='cuda')
+ x = torch.tensor([0.], requires_grad=True)
+ with self.assertRaisesRegex(RuntimeError, "MyLegacyFn"):
+ torch.jit.get_trace_graph(lambda x: MyLegacyFn()(x), (x,))
- self.assertEqual(f(x, y).shape, (3, 4))
- self.assertAllFused(f.graph_for(x, y))
+ def test_inplace_transplant(self):
+ x = torch.tensor([0.], requires_grad=True)
- @staticmethod
- def fn_test_comparison_gt_lt(x, y):
- mask = (x > 0).type_as(x)
- z = x * mask + y
- mask = (x < 0).type_as(x)
- z = z * mask + y
- return z
+ def fn(x):
+ y = x.clone()
+ y.add_(2)
+ y.add_(3)
+ return y
- @unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
- @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
- @skipIfRocm
- def test_comparison_gt_lt_cuda(self):
- x = torch.randn(4, 4, dtype=torch.float, device='cuda')
- y = torch.randn(4, 4, dtype=torch.float, device='cuda')
+ trace, _ = torch.jit.get_trace_graph(fn, (x,))
+ self.assertExpectedGraph(trace)
+ self.assertExportImport(trace, (x,))
- ge = self.checkTrace(self.fn_test_comparison_gt_lt, (x, y))
- self.assertAllFused(ge.graph_for(x, y))
+ def test_inplace_flags(self):
+ class InplaceFn(Function):
+ @staticmethod
+ def forward(ctx, x):
+ ctx.mark_dirty(x)
+ return x.add_(1)
- @unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
- @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
- @skipIfRocm
- def test_comparison_ge_le_cuda(self):
- def f(x, y):
- mask = (x >= 0).type_as(x)
- z = x * mask + y
- mask = (x <= 0).type_as(x)
- z = z * mask + y
- return z
+ @staticmethod
+ def backward(ctx, go):
+ return go
- x = torch.randn(4, 4, dtype=torch.float, device='cuda')
- y = torch.randn(4, 4, dtype=torch.float, device='cuda')
+ class RegularFn(Function):
+ @staticmethod
+ def forward(ctx, x):
+ return x.add(1)
- ge = self.checkTrace(f, (x, y))
- self.assertAllFused(ge.graph_for(x, y))
+ @staticmethod
+ def backward(ctx, go):
+ return go
- @unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
- @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
- @skipIfRocm
- def test_comparison_eq_ne(self):
- def f(x, y):
- mask = (x == 0).type_as(x)
- z = x * mask + y
- mask = (x != 0).type_as(x)
- z = z * mask + y
- return z
+ x = torch.tensor([0.], requires_grad=True)
- x = torch.randn(4, 4, dtype=torch.float, device='cuda')
- y = torch.randn(4, 4, dtype=torch.float, device='cuda')
+ def fn(x):
+ y = RegularFn.apply(x)
+ y = InplaceFn.apply(y)
+ y = InplaceFn.apply(y)
+ y = RegularFn.apply(y)
+ return y
- ge = self.checkTrace(f, (x, y))
- self.assertAllFused(ge.graph_for(x, y))
+ trace, _ = torch.jit.get_trace_graph(fn, (x,), _force_outplace=True)
+ self.run_pass('dce', trace)
+ ops = [n for n in trace.graph().nodes()]
+ for op in ops:
+ self.assertTrue(op.hasAttribute('inplace'))
+ inplace_flags = [False, True, True, False]
+ for op, is_inplace in zip(ops, inplace_flags):
+ self.assertEqual(op.i('inplace'), is_inplace)
- @staticmethod
- def fn_test_relu(x, y):
- return F.relu(x + .5 * y)
+ def test_inplace_check(self):
+ class MyInplaceFn(Function):
+ @staticmethod
+ def forward(self, x):
+ x.add_(1)
+ self.mark_dirty(x)
+ return x
- @unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
- @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
- @skipIfRocm
- def test_relu_cuda(self):
- x = torch.randn(4, 4, dtype=torch.float, device='cuda')
- y = torch.randn(4, 4, dtype=torch.float, device='cuda')
+ @staticmethod
+ def backward(self, grad):
+ return grad
- ge = self.checkTrace(self.fn_test_relu, (x, y))
+ def fn(x):
+ return MyInplaceFn.apply(x)
- @unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
- @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
- def test_small_constant_cuda(self):
- def fn_test_small_constant(x, y):
- return (1e-8 * x + 5e-9 * y) * 1e8
- x = torch.randn(4, 4, dtype=torch.float, device='cuda')
- y = torch.randn(4, 4, dtype=torch.float, device='cuda')
+ x = torch.randn(5, 5)
+ ge = torch._C.GraphExecutor(fn, (x,), lambda var: '', _force_outplace=True)
+ with self.assertRaisesRegex(RuntimeError, 'inplace MyInplaceFn'):
+ ge(x)
- ge = self.checkTrace(fn_test_small_constant, (x, y))
+ def do_trace_size(self, requires_grad):
+ def fn(x):
+ return x.view(x.shape[1] * 2, x.size(0), 2)
- @staticmethod
- def fn_test_exp(x, y):
- return (x + .5 * y).exp()
+ x = torch.randn(5, 2, 4, requires_grad=requires_grad)
+ y = torch.randn(4, 8, 4, requires_grad=requires_grad)
- @unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
- @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
- @skipIfRocm
- def test_exp_cuda(self):
- x = torch.randn(4, 4, dtype=torch.float, device='cuda')
- y = torch.randn(4, 4, dtype=torch.float, device='cuda')
+ # Check that it behaves as expected
+ traced_fn = torch.jit.trace(fn, x)
+ self.assertEqual(traced_fn(y), fn(y))
+ self.assertEqual(traced_fn(x), fn(x))
- ge = self.checkTrace(self.fn_test_exp, (x, y))
+ # Check that the trace looks ok
+ trace, _ = torch.jit.get_trace_graph(fn, (x,))
+ self.assertExpectedGraph(trace)
- @unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
- @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
- @unittest.skipIf(not RUN_CUDA_HALF, "no half support")
- def test_cuda_half(self):
- x = torch.randn(4, 4, dtype=torch.half, device='cuda')
- y = torch.randn(4, 4, dtype=torch.half, device='cuda')
+ def test_trace_size(self):
+ self.do_trace_size(False)
- funcs = [
- self.fn_test_comparison_gt_lt,
- self.fn_test_relu,
- self.fn_test_exp
- ]
+ # test the different graph_executor path that happens when
+ # gradients are required and sizes are involved
+ def test_trace_size_with_grad(self):
+ self.do_trace_size(True)
- # Note: Non fused inputs must be float to prevent loss of precision
- inputs = (x.float(), y.float())
- fusion_inputs = (x, y)
- for fn in funcs:
- local_inputs = [t.clone().requires_grad_() for t in inputs]
- local_fusion_inputs = [t.clone().requires_grad_() for t in fusion_inputs]
+ def test_trace_casts(self):
+ casts = [
+ lambda x: x.byte(),
+ lambda x: x.float(),
+ lambda x: x.cpu(),
+ lambda x: x.to(device='cpu'),
+ lambda x: x.to(dtype=torch.int64),
+ lambda x: x.to(device='cpu', dtype=torch.float),
+ lambda x: x.to(x)
+ ]
- # Verifies outputs
- fusion = torch.jit.trace(fn, local_fusion_inputs, check_trace=False, optimize=True)
- outputs = fn(*local_inputs)
- fusion_outputs = fusion(*local_fusion_inputs)
- outputs_half = [t.half() for t in outputs]
- self.assertEqual(outputs_half, fusion_outputs)
+ def assertContainsCast(trace):
+ self.assertEqual(sum(n.kind() == 'aten::to' for n in trace.graph.nodes()), 1)
- # Verifies gradients
- for output, fusion_output in zip(outputs_half, fusion_outputs):
- grads = torch.autograd.grad(
- output.float().sum(), local_inputs, allow_unused=True, retain_graph=True)
- fusion_grads = torch.autograd.grad(
- fusion_output.sum(), local_fusion_inputs, allow_unused=True, retain_graph=True)
- grads_half = [t.half() for t in grads]
- self.assertEqual(grads_half, fusion_grads)
+ for cast in casts:
+ trace = torch.jit.trace(cast, torch.randn(2, 2))
+ assertContainsCast(trace)
+ x = torch.randn(2, 2)
+ self.assertEqual(trace(x), cast(x))
- def test_canonicalize_tensor_iterator(self):
- x = torch.randn(4, 4)
+ def to_tensor(x, y):
+ return x.to(y)
- def f(x):
- x = x + 2
- x = x - 4
- x = x * 6
- x = x / 8
- return x
+ to_tensor_trace = torch.jit.trace(to_tensor, (torch.randn(2, 2), torch.randn(1, 8)))
+ assertContainsCast(to_tensor_trace)
+ x, y = torch.randn(2, 2), torch.randn(1, 10)
+ self.assertEqual(to_tensor_trace(x, y), to_tensor(x, y))
- traced = torch.jit.trace(f, (x,))
- f(x)
- graph = traced.graph_for(x)
- # There should be 4 int constants for the right sides of operators, plus two
- # for alpha arguments for add and sub
- self.assertTrue(str(traced.graph_for(x)).count(': int = prim::Constant'), 6)
+ def test_trace_warn(self):
+ def fn(x):
+ int(x) # Warning 1.
+ y = x * 1
+ if y: # Warning 2.
+ pass
+ q = [x, x * 4]
+ z = q[y] # Warning 3.
+ float(z) # Warning 4.
+ z.tolist() # Warning 5.
+ z.numpy() # Warning 6.
+ for elem in torch.ones(4, 4): # Warning 7.
+ pass
+ return z + 4
- # TODO: adapt this test to check that GraphExecutor treats them differently
- @unittest.skip("Need to be adjusted to Graph Executor")
- def test_arg_configurations(self):
- """Different arg configurations should trigger different traces"""
- x = Variable(torch.FloatTensor(4, 4).uniform_())
- x_double = Variable(x.data.double())
- x_grad = Variable(x.data.clone(), requires_grad=True)
- y = Variable(torch.randn(4))
+ with warnings.catch_warnings(record=True) as warns:
+ traced_fn = torch.jit.trace(fn, torch.tensor([1]))
+ warns = [str(w.message) for w in warns]
+ self.assertEqual(len(warns), 7)
+ self.assertIn('a Python integer', warns[0])
+ self.assertIn('a Python boolean', warns[1])
+ self.assertIn('a Python index', warns[2])
+ self.assertIn('a Python float', warns[3])
+ self.assertIn('a Python list', warns[4])
+ self.assertIn('a NumPy array', warns[5])
+ self.assertIn('Iterating over', warns[6])
- configurations = [
- (x,),
- (x_double,),
- (x_grad,),
- (y,),
- ([x, x],),
- ([x, y],),
- ]
- if torch.cuda.is_available():
- x_cuda = Variable(x.data.cuda())
- configurations += [
- (x_cuda,),
- ([x, x_cuda],),
- ([x_cuda, x],),
- ([[x_cuda, x]],),
- ]
- if torch.cuda.device_count() > 1:
- x_cuda_1 = Variable(x.data.cuda(1))
- configurations += [
- (x_cuda_1,),
- ([x_cuda, x_cuda_1],),
- ]
+ def test_trace_tuple(self):
+ def fn(x, y):
+ return x, (x * y[1], x * y[0])
- @torch.jit.compile(nderivs=0)
- def fn(*args):
- in_vars, _ = torch._C._jit_flatten(args)
- return in_vars[0] + 1
+ x, y = torch.randn(2, 2), (torch.ones(2, 2), torch.randn(2, 2))
+ traced_fn = torch.jit.trace(fn, (x, y))
+ self.assertEqual(traced_fn(x, y), fn(x, y))
+ self.assertExpectedGraph(traced_fn.graph)
+ self.assertExportImport(traced_fn.graph, (x, y))
- for i, config in enumerate(configurations):
- self.assertFalse(fn.has_trace_for(*config))
- fn(*config)
- self.assertTrue(fn.has_trace_for(*config))
- for unk_config in configurations[i + 1:]:
- self.assertFalse(fn.has_trace_for(*unk_config))
- self.assertEqual(fn.hits, 0)
+ def test_trace_random(self):
+ def f(mean, std):
+ return torch.normal(mean, std)
- def test_cse(self):
- x = torch.tensor([0.4, 0.3], requires_grad=True)
- y = torch.tensor([0.7, 0.5], requires_grad=True)
+ traced = torch.jit.trace(f, (torch.zeros(2, 3), torch.ones(2, 3)), check_trace=False)
+ mean, std = torch.zeros(5, 5), torch.ones(5, 5)
+ with torch.random.fork_rng(devices=[]):
+ output = f(mean, std)
+ traced_output = traced(mean, std)
+ self.assertEqual(output, traced_output)
- def fn(x, y):
- w = (x + y) * (x + y) * (x + y)
- t = torch.tanh(w) + torch.tanh(w)
- z = (x + y) * (x + y) * (x + y) + t
- return z
+ def test_trace_tensor_factory(self):
+ def run(**kwargs):
+ inputs_require_grads = kwargs.pop('inputs_require_grads', True)
- trace, _ = torch.jit.get_trace_graph(fn, (x, y))
- self.run_pass('cse', trace)
- self.assertExpectedGraph(trace)
- self.assertExportImport(trace, (x, y))
+ def fn(x):
+ return x + torch.ones(2, 3, **kwargs)
- def test_recursive_cse(self):
- x = torch.tensor([0.1])
- y = torch.tensor([0.2])
+ input_kwargs = kwargs.copy()
+ if 'out' in input_kwargs:
+ del input_kwargs['out']
+ input = torch.ones(2, 3, **input_kwargs)
+ self.checkTrace(fn, (input,), inputs_require_grads=inputs_require_grads)
+ # check we recorded 'ones' and did not just record a constant
+ tfn = torch.jit.trace(fn, input)
+ self.assertTrue("ones" in str(tfn.graph))
+ run()
+ run(dtype=torch.int, inputs_require_grads=False)
+ run(out=torch.tensor([]))
+ if RUN_CUDA:
+ run(device="cuda:0")
+ if RUN_CUDA_MULTI_GPU:
+ run(device="cuda:1")
- def fn(x, y):
- z = x
- if bool(x + y > x):
- z = x + y
- return z
+ # TODO: implement
+ @unittest.expectedFailure
+ def test_output_unflatten(self):
+ """Check that outputs of traced functions retain the original structure and nesting"""
+ def fn(x):
+ return (x * 2, (x ** 2, x + 4, (x + 2,), ), x * 4)
- graph = torch.jit.script(fn).graph
- self.run_pass('cse', graph)
- self.assertExpectedGraph(graph)
+ self.checkTrace(fn, (torch.randn(2, 2),))
- def test_scalar(self):
- # NB: must not require grad; if it requires grad, it's always a Tensor
- x = torch.tensor(2.)
- y = torch.tensor(3.)
+ # TODO: implement
+ @unittest.expectedFailure
+ def test_input_flatten(self):
+ """Check that inputs to traced functions are flattened"""
- def fn(x, y):
- return x - y
- trace, _ = torch.jit.get_trace_graph(fn, (x, y))
+ def fn(x, t):
+ y, z = t
+ return x * y * z
- def test_shape_analysis_broadcast(self):
- def broadcast(a, b):
- return a + b
+ inputs = (torch.randn(1), (torch.randn(1), torch.randn(1)))
+ self.checkTrace(fn, inputs)
- x = torch.randn(3, 1, 5, requires_grad=True)
- y = torch.randn(4, 1, 8, 5, requires_grad=True)
+ # TODO: adapt to a GraphExecutor test
+ @unittest.skip("Need to instrument GraphExecutors a bit more")
+ def test_flags(self):
+ x, y = torch.randn(2, 2)
+ y = Variable(torch.randn(2, 2))
- graph = torch.jit.script(broadcast).graph
- torch._C._jit_pass_complete_shape_analysis(graph, (x, y), False)
- self.assertExpectedGraph(graph)
+ @torch.jit.compile
+ def fn(x, y):
+ return (x * x + y * y + x * y).sum()
- @unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
- @unittest.skipIf(not RUN_CUDA_MULTI_GPU, "needs non-zero device")
- @skipIfRocm
- def test_fuse_last_device_cuda(self):
- device = 'cuda:' + str(1)
- x = torch.tensor([0.4], dtype=torch.float, device=device)
- y = torch.tensor([0.7], dtype=torch.float, device=device)
+ grads = {}
+ for rx, ry in product((True, False), repeat=2):
+ x.requires_grad = rx
+ y.requires_grad = ry
- def doit(x, y):
- return torch.sigmoid(torch.tanh(x * (x + y) + x))
+ self.assertFalse(fn.has_trace_for(x, y))
+ out = fn(x, y)
- ge = self.checkTrace(doit, (x, y))
- self.assertExpectedGraph(ge.graph_for(x, y))
+ self.assertFalse(fn.has_trace_for(x, y))
+ for v, name, compute in [(x, 'x', rx), (y, 'y', ry)]:
+ if not compute:
+ continue
+ grad_v, = torch.autograd.grad(out, v, retain_graph=True)
+ expected_grad = grads.setdefault(name, grad_v)
+ self.assertEqual(grad_v, expected_grad)
+ self.assertEqual(fn.has_trace_for(x, y), rx or ry)
- # TODO: update verify to work with GraphExecutors
- @unittest.skip("verify needs to be updated to work with GraphExecutors")
- def test_verify(self):
+ def test_python_ir(self):
x = torch.tensor([0.4], requires_grad=True)
y = torch.tensor([0.7], requires_grad=True)
- @torch.jit.compile
- def f(x, y):
- z = torch.sigmoid(x * (x + y))
- w = torch.abs(x * x * x + y) + Variable(torch.ones(1))
- return z, w
+ def doit(x, y):
+ return torch.sigmoid(torch.tanh(x * (x + y)))
- torch.jit.verify(f, (x, y), loss_fn=lambda z, w: z * w, devices=[])
+ trace, _ = torch.jit.get_trace_graph(doit, (x, y))
+ self.run_pass('dce', trace)
+ self.run_pass('canonicalize', trace)
+ g = trace.graph()
+ g2 = torch._C.Graph()
+ g_to_g2 = {}
+ for node in g.inputs():
+ g_to_g2[node] = g2.addInput()
+ for node in g.nodes():
+ n_ = g2.createClone(node, lambda x: g_to_g2[x])
+ g2.appendNode(n_)
+ for o, no in zip(node.outputs(), n_.outputs()):
+ g_to_g2[o] = no
- @suppress_warnings
- def test_constant(self):
- x = torch.randn(2, 2, requires_grad=True)
+ for node in g.outputs():
+ g2.registerOutput(g_to_g2[node])
- def f(x):
- return x.matmul(torch.diag(torch.tensor([2., 2.])))
+ t_node = g2.create("prim::TensorTest").t_("a", torch.ones([2, 2]))
+ self.assertEqual(t_node.attributeNames(), ["a"])
+ g2.appendNode(t_node)
+ self.assertTrue(torch.equal(torch.ones(2, 2), t_node.t("a")))
+ self.assertExpected(str(g2))
- self.checkTrace(f, (x,), (torch.ones(2, 2, requires_grad=True),))
+ @unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
+ @unittest.skipIf(not RUN_CUDA, "cpp tests require CUDA")
+ @skipIfRocm
+ def test_cpp_cuda(self):
+ # rather than rebuild assertExpected in cpp,
+ # just glob all the cpp outputs into one file for now
+ self.assertExpected(torch._C._jit_run_cpp_tests())
- def test_legacy_fail(self):
- class MyLegacyFn(Function):
- def forward(self, x):
- return x
+ def test_batchnorm(self):
+ x = torch.ones(2, 2, 2, 2)
+ trace, _ = torch.jit.get_trace_graph(nn.BatchNorm2d(2), x, _force_outplace=True)
+ self.assertExpectedGraph(trace)
- def backward(self, grad_output):
- return grad_output
+ def test_dropout(self):
+ x = torch.ones(2, 2)
+ trace, _ = torch.jit.get_trace_graph(nn.Dropout(0.6), x)
+ self.assertExpectedGraph(trace)
- x = torch.tensor([0.], requires_grad=True)
- with self.assertRaisesRegex(RuntimeError, "MyLegacyFn"):
- torch.jit.get_trace_graph(lambda x: MyLegacyFn()(x), (x,))
+ def test_conv(self):
+ x = torch.ones(20, 16, 50, 40)
+ trace, _ = torch.jit.get_trace_graph(nn.Conv2d(16, 13, 3, bias=False), x)
+ self.assertExpectedGraph(trace)
- def test_inplace_transplant(self):
- x = torch.tensor([0.], requires_grad=True)
+ def test_repeated_input(self):
+ def fn(a, b):
+ return a + b
- def fn(x):
- y = x.clone()
- y.add_(2)
- y.add_(3)
- return y
+ ge = self.checkTrace(fn, [torch.randn(2, 2)] * 2)
+ self.assertExpectedGraph(ge.graph)
- trace, _ = torch.jit.get_trace_graph(fn, (x,))
- self.assertExpectedGraph(trace)
- self.assertExportImport(trace, (x,))
+ def test_repeated_output(self):
+ def fn(a, b):
+ z = a + b
+ return z, z
- def test_inplace_flags(self):
- class InplaceFn(Function):
- @staticmethod
- def forward(ctx, x):
- ctx.mark_dirty(x)
- return x.add_(1)
+ ge = self.checkTrace(fn, [torch.randn(2, 2) for _ in range(2)])
+ self.assertExpectedGraph(ge.graph)
- @staticmethod
- def backward(ctx, go):
- return go
+ @skipIfNoTorchVision
+ def test_alexnet(self):
+ x = torch.ones(1, 3, 224, 224)
+ trace, _ = torch.jit.get_trace_graph(torchvision.models.AlexNet(), x)
+ self.run_pass('cse', trace)
+ self.assertExpectedGraph(trace)
- class RegularFn(Function):
- @staticmethod
- def forward(ctx, x):
- return x.add(1)
+ # Inplace copies don't work with tracer yet.
+ # This is actually somewhat important to support correctly
+ # as all backwards functions of views are implemented
+ # as a zero filled tensor with a gradient fill on the
+ # viewed portion.
+ def test_inplace_copy(self):
+ x = torch.randn(4, 4, requires_grad=True)
- @staticmethod
- def backward(ctx, go):
- return go
+ def f(x):
+ out = Variable(torch.zeros(x.size()))
+ out.copy_(x)
+ return out
- x = torch.tensor([0.], requires_grad=True)
+ trace, z = torch.jit.get_trace_graph(f, (x, ))
+ self.run_pass('dce', trace)
+ self.assertExpectedGraph(trace)
+ self.assertExportImport(trace, (x,))
- def fn(x):
- y = RegularFn.apply(x)
- y = InplaceFn.apply(y)
- y = InplaceFn.apply(y)
- y = RegularFn.apply(y)
- return y
+ def test_shared_param(self):
- trace, _ = torch.jit.get_trace_graph(fn, (x,), _force_outplace=True)
- self.run_pass('dce', trace)
- ops = [n for n in trace.graph().nodes()]
- for op in ops:
- self.assertTrue(op.hasAttribute('inplace'))
- inplace_flags = [False, True, True, False]
- for op, is_inplace in zip(ops, inplace_flags):
- self.assertEqual(op.i('inplace'), is_inplace)
+ class MyModule(torch.nn.Module):
+ def __init__(self):
+ super(MyModule, self).__init__()
+ self.b = self.a = nn.Parameter(torch.randn(2, 2))
- def test_inplace_check(self):
- class MyInplaceFn(Function):
- @staticmethod
def forward(self, x):
- x.add_(1)
- self.mark_dirty(x)
- return x
+ return x * self.a + self.b
- @staticmethod
- def backward(self, grad):
- return grad
+ m = MyModule()
+ trace, _ = torch.jit.get_trace_graph(m, (torch.randn(2, 2),))
+ self.assertEqual(len(list(trace.graph().inputs())), 2)
+ self.assertExpectedGraph(trace)
- def fn(x):
- return MyInplaceFn.apply(x)
+ def test_nested_inplace(self):
+ x = torch.randn(2, 2)
+ trace, _ = torch.jit.get_trace_graph(
+ lambda x: F.threshold(x, 0, 0, inplace=True), (x, ))
+ self.assertExpectedGraph(trace)
+ self.assertExportImport(trace, (x,))
- x = torch.randn(5, 5)
- ge = torch._C.GraphExecutor(fn, (x,), lambda var: '', _force_outplace=True)
- with self.assertRaisesRegex(RuntimeError, 'inplace MyInplaceFn'):
- ge(x)
+ def run_ge_tests(self, optimize, use_cuda):
+ def rand(*args):
+ t = torch.rand(*args).float()
+ if use_cuda:
+ t = t.cuda()
+ return t
+ self.checkTrace(lambda a, b: a * b + b,
+ [rand(1), rand(1)], [rand(2, 3), rand(2, 3)],
+ optimize=optimize)
+ # trivial identity
+ self.checkTrace(lambda a, b: (
+ b, a), [rand(1), rand(1)], optimize=optimize)
- def do_trace_size(self, requires_grad):
- def fn(x):
- return x.view(x.shape[1] * 2, x.size(0), 2)
+ def foo(a):
+ t = a * a
+ return t * t, 4 * t
+ self.checkTrace(foo, [rand(1)], optimize=optimize)
+ # unused input
+ self.checkTrace(
+ lambda a, b: a * a, [rand(1), rand(1)], optimize=optimize,
+ allow_unused=True)
+ # test outputs that do not get used in grad
+ self.checkTrace(foo, [rand(1)], drop=1, optimize=optimize)
+ # test autograd fallback
+ self.checkTrace(lambda a, b: a * b /
+ (a - 2 * b) + b, [rand(1), rand(1)],
+ optimize=optimize)
- x = torch.randn(5, 2, 4, requires_grad=requires_grad)
- y = torch.randn(4, 8, 4, requires_grad=requires_grad)
+ def test_ge_unoptimized(self):
+ self.run_ge_tests(False, False)
- # Check that it behaves as expected
- traced_fn = torch.jit.trace(fn, x)
- self.assertEqual(traced_fn(y), fn(y))
- self.assertEqual(traced_fn(x), fn(x))
+ @unittest.skipIf(IS_WINDOWS or IS_SANDCASTLE, "NYI: fuser support for Windows or Sandcastle")
+ @enable_cpu_fuser
+ def test_ge_optimized(self):
+ self.run_ge_tests(True, False)
- # Check that the trace looks ok
- trace, _ = torch.jit.get_trace_graph(fn, (x,))
- self.assertExpectedGraph(trace)
+ @unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
+ @unittest.skipIf(not RUN_CUDA, "requires CUDA")
+ @skipIfRocm
+ def test_ge_cuda(self):
+ self.run_ge_tests(True, True)
- def test_trace_size(self):
- self.do_trace_size(False)
+ # more manual test of graph executor that can be used as a scratchpad
+ def test_ge(self):
+ def foo(a, b):
+ return a * b / (a - b) + b
+ V = Variable
+ a, b = V(torch.rand(1)), V(torch.rand(1))
+ ge = torch._C.GraphExecutor(foo, (a, b), lambda var: '')
+ a, b = V(torch.rand(1), requires_grad=True), V(
+ torch.rand(1), requires_grad=True)
+ r, = ge(a, b)
+ da, db = torch.autograd.grad(r + 3, [a, b], create_graph=True)
- # test the different graph_executor path that happens when
- # gradients are required and sizes are involved
- def test_trace_size_with_grad(self):
- self.do_trace_size(True)
+ l2 = (da * db + db * db)
+ g2result = torch.autograd.grad(l2, [da, db])
- def test_trace_casts(self):
- casts = [
- lambda x: x.byte(),
- lambda x: x.float(),
- lambda x: x.cpu(),
- lambda x: x.to(device='cpu'),
- lambda x: x.to(dtype=torch.int64),
- lambda x: x.to(device='cpu', dtype=torch.float),
- lambda x: x.to(x)
- ]
+ r = foo(a, b)
+ da2, db2 = torch.autograd.grad(r + 3, [a, b], create_graph=True)
+ self.assertEqual(da, da2)
+ self.assertEqual(db, db2)
+ l3 = (da2 * db2 + db2 * db2)
+ g2result2 = torch.autograd.grad(l3, [da2, db2])
+ self.assertEqual(g2result, g2result2)
- def assertContainsCast(trace):
- self.assertEqual(sum(n.kind() == 'aten::to' for n in trace.graph.nodes()), 1)
+ def test_trace_annotation(self):
+ @_trace(torch.rand(1))
+ def foo(a):
+ return a + a + a
- for cast in casts:
- trace = torch.jit.trace(cast, torch.randn(2, 2))
- assertContainsCast(trace)
- x = torch.randn(2, 2)
- self.assertEqual(trace(x), cast(x))
+ x = torch.randn(5, 5)
+ self.assertEqual(foo(x), x + x + x)
- def to_tensor(x, y):
- return x.to(y)
+ def test_trace_script(self):
+ @torch.jit.script
+ def func1(x):
+ # type: (Tuple[Tensor, Tensor]) -> Tensor
+ return x[0] + x[1]
- to_tensor_trace = torch.jit.trace(to_tensor, (torch.randn(2, 2), torch.randn(1, 8)))
- assertContainsCast(to_tensor_trace)
- x, y = torch.randn(2, 2), torch.randn(1, 10)
- self.assertEqual(to_tensor_trace(x, y), to_tensor(x, y))
+ @torch.jit.script
+ def func2(x):
+ # type: (List[Tensor]) -> Tensor
+ return x[0] + x[1]
- def test_trace_warn(self):
- def fn(x):
- int(x) # Warning 1.
- y = x * 1
- if y: # Warning 2.
- pass
- q = [x, x * 4]
- z = q[y] # Warning 3.
- float(z) # Warning 4.
- z.tolist() # Warning 5.
- z.numpy() # Warning 6.
- for elem in torch.ones(4, 4): # Warning 7.
- pass
- return z + 4
+ a = torch.randn(5)
+ b = torch.randn(5)
- with warnings.catch_warnings(record=True) as warns:
- traced_fn = torch.jit.trace(fn, torch.tensor([1]))
- warns = [str(w.message) for w in warns]
- self.assertEqual(len(warns), 7)
- self.assertIn('a Python integer', warns[0])
- self.assertIn('a Python boolean', warns[1])
- self.assertIn('a Python index', warns[2])
- self.assertIn('a Python float', warns[3])
- self.assertIn('a Python list', warns[4])
- self.assertIn('a NumPy array', warns[5])
- self.assertIn('Iterating over', warns[6])
+ expected = func1((a, b))
+ traced = torch.jit.trace(func1, ((a, b),))
+ result = traced((a, b))
+ self.assertEqual(expected, result)
- def test_trace_tuple(self):
- def fn(x, y):
- return x, (x * y[1], x * y[0])
+ expected = func2((a, b))
+ traced = torch.jit.trace(func2, ((a, b),))
+ result = traced((a, b))
+ self.assertEqual(expected, result)
- x, y = torch.randn(2, 2), (torch.ones(2, 2), torch.randn(2, 2))
- traced_fn = torch.jit.trace(fn, (x, y))
- self.assertEqual(traced_fn(x, y), fn(x, y))
- self.assertExpectedGraph(traced_fn.graph)
- self.assertExportImport(traced_fn.graph, (x, y))
+ def test_einsum(self):
+ def outer(x, y):
+ return torch.einsum('i,j->ij', (x, y))
- def test_trace_random(self):
- def f(mean, std):
- return torch.normal(mean, std)
+ traced = torch.jit.trace(outer, (torch.randn(4), torch.randn(5)))
+ script = torch.jit.script(outer)
+ fns = [traced, script]
+ x, y = torch.randn(10), torch.randn(2)
+ for fn in [traced, script]:
+ self.assertGraphContains(fn.graph, kind='aten::einsum')
+ self.assertEqual(fn(x, y), outer(x, y))
- traced = torch.jit.trace(f, (torch.zeros(2, 3), torch.ones(2, 3)), check_trace=False)
- mean, std = torch.zeros(5, 5), torch.ones(5, 5)
- with torch.random.fork_rng(devices=[]):
- output = f(mean, std)
- traced_output = traced(mean, std)
- self.assertEqual(output, traced_output)
+ @unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
+ @unittest.skipIf(not RUN_CUDA, "calls .cuda()")
+ @skipIfRocm
+ def test_traced_module_cuda(self):
+ class Model(nn.Module):
+ def __init__(self, num_features, num_layers):
+ super(Model, self).__init__()
+ self.num_layers = num_layers
+ layers = [[nn.Linear(num_features, num_features), nn.Sigmoid()]
+ for _ in range(num_layers)]
+ self.submodule = nn.Sequential(*chain(*layers))
- def test_trace_tensor_factory(self):
- def run(**kwargs):
- inputs_require_grads = kwargs.pop('inputs_require_grads', True)
+ def forward(self, x):
+ for i in range(self.num_layers):
+ x = self.submodule[i](x) + x
+ return x
- def fn(x):
- return x + torch.ones(2, 3, **kwargs)
+ model = Model(5, 3)
+ x = torch.randn(2, 5)
+ traced_model = torch.jit.trace(model, x)
- input_kwargs = kwargs.copy()
- if 'out' in input_kwargs:
- del input_kwargs['out']
- input = torch.ones(2, 3, **input_kwargs)
- self.checkTrace(fn, (input,), inputs_require_grads=inputs_require_grads)
- # check we recorded 'ones' and did not just record a constant
- tfn = torch.jit.trace(fn, input)
- self.assertTrue("ones" in str(tfn.graph))
- run()
- run(dtype=torch.int, inputs_require_grads=False)
- run(out=torch.tensor([]))
- if RUN_CUDA:
- run(device="cuda:0")
- if RUN_CUDA_MULTI_GPU:
- run(device="cuda:1")
+ # We're missing some attributes these modules had initially. Make sure we can
+ # still get the __repr__()
+ model.__repr__()
- # TODO: implement
- @unittest.expectedFailure
- def test_output_unflatten(self):
- """Check that outputs of traced functions retain the original structure and nesting"""
- def fn(x):
- return (x * 2, (x ** 2, x + 4, (x + 2,), ), x * 4)
+ # XXX: indexing sequentials is broken
+ linear_submodule = next(iter(traced_model.submodule._modules.values()))
- self.checkTrace(fn, (torch.randn(2, 2),))
+ # All attributes that aren't parameters should raise
+ with self.assertRaises(AttributeError):
+ linear_submodule.in_features
+ linear_submodule.weight
+ with self.assertRaises(RuntimeError):
+ traced_model.asdf = 4
+ linear_submodule.weight = nn.Parameter(torch.randn(linear_submodule.weight.shape))
+ with self.assertRaises(RuntimeError):
+ del linear_submodule.weight
- # TODO: implement
- @unittest.expectedFailure
- def test_input_flatten(self):
- """Check that inputs to traced functions are flattened"""
+ # Submodules can't be called
+ with self.assertRaises(RuntimeError):
+ linear_submodule(x)
- def fn(x, t):
- y, z = t
- return x * y * z
+ # Type casts
+ linear_submodule.cuda()
+ traced_model.float().cuda()
+ cuda_out = traced_model(x.float().cuda())
+ traced_model.cpu()
+ cpu_out = traced_model(x.float())
+ self.assertEqual(cpu_out, cuda_out)
+ traced_model.double()
- inputs = (torch.randn(1), (torch.randn(1), torch.randn(1)))
- self.checkTrace(fn, inputs)
+ # state_dict + load_state_dict
+ state = {k: v.clone() for k, v in traced_model.state_dict().items()}
+ new_state = {k: v.clone().fill_(1) for k, v in state.items()}
+ out = traced_model(x)
+ traced_model.load_state_dict(new_state)
+ out_ones = traced_model(x)
+ traced_model.load_state_dict(state)
+ out_state = traced_model(x)
+ self.assertEqual(out, out_state)
+ self.assertNotEqual(out, out_ones)
- # TODO: adapt to a GraphExecutor test
- @unittest.skip("Need to instrument GraphExecutors a bit more")
- def test_flags(self):
- x, y = torch.randn(2, 2)
- y = Variable(torch.randn(2, 2))
+ def test_python_function(self):
+ class MyFn(Function):
+ @staticmethod
+ def forward(ctx, x):
+ return x + 1
- @torch.jit.compile
- def fn(x, y):
- return (x * x + y * y + x * y).sum()
+ @staticmethod
+ def backward(ctx, grad_output):
+ return grad_output
- grads = {}
- for rx, ry in product((True, False), repeat=2):
- x.requires_grad = rx
- y.requires_grad = ry
+ @_trace(torch.zeros(2))
+ def fn(x):
+ return MyFn.apply(x + 2) + 3
- self.assertFalse(fn.has_trace_for(x, y))
- out = fn(x, y)
+ x = torch.tensor([1., 2., 3.])
+ y = torch.randn(2, 2, requires_grad=True)
+ fn(x)
+ fn(y)
- self.assertFalse(fn.has_trace_for(x, y))
- for v, name, compute in [(x, 'x', rx), (y, 'y', ry)]:
- if not compute:
- continue
- grad_v, = torch.autograd.grad(out, v, retain_graph=True)
- expected_grad = grads.setdefault(name, grad_v)
- self.assertEqual(grad_v, expected_grad)
- self.assertEqual(fn.has_trace_for(x, y), rx or ry)
+ def test_python_function_tup(self):
+ class MyFn(Function):
+ @staticmethod
+ def forward(ctx, x):
+ return x + 1, x - 1
- def test_python_ir(self):
- x = torch.tensor([0.4], requires_grad=True)
- y = torch.tensor([0.7], requires_grad=True)
+ @staticmethod
+ def backward(ctx, grad_output):
+ return grad_output, grad_output
- def doit(x, y):
- return torch.sigmoid(torch.tanh(x * (x + y)))
+ @_trace(torch.zeros(2))
+ def fn(x):
+ a, b = MyFn.apply(x + 2)
+ return a + b + 3
+ x = torch.tensor([1., 2., 3.])
+ y = torch.randn(2, 2, requires_grad=True)
+ fn(x)
+ fn(y)
- trace, _ = torch.jit.get_trace_graph(doit, (x, y))
- self.run_pass('dce', trace)
- self.run_pass('canonicalize', trace)
- g = trace.graph()
- g2 = torch._C.Graph()
- g_to_g2 = {}
- for node in g.inputs():
- g_to_g2[node] = g2.addInput()
- for node in g.nodes():
- n_ = g2.createClone(node, lambda x: g_to_g2[x])
- g2.appendNode(n_)
- for o, no in zip(node.outputs(), n_.outputs()):
- g_to_g2[o] = no
+ def test_decompose_addmm(self):
+ @torch.jit.script
+ def addmm(mat, mat1, mat2, alpha, beta):
+ a = mat.addmm(mat1, mat2)
+ b = mat.addmm(mat1, mat2, alpha=1.0, beta=1.0)
+ c = mat.addmm(mat1, mat2, alpha=4.20, beta=2.0)
+ d = mat.addmm(mat1, mat2, alpha=int(alpha), beta=int(beta))
- for node in g.outputs():
- g2.registerOutput(g_to_g2[node])
+ return a + b + c + d
- t_node = g2.create("prim::TensorTest").t_("a", torch.ones([2, 2]))
- self.assertEqual(t_node.attributeNames(), ["a"])
- g2.appendNode(t_node)
- self.assertTrue(torch.equal(torch.ones(2, 2), t_node.t("a")))
- self.assertExpected(str(g2))
+ mat = torch.randn(2, 2)
+ mat1 = torch.randn(2, 4)
+ mat2 = torch.randn(4, 2)
+ alpha = torch.FloatTensor([123.0])
+ beta = torch.FloatTensor([321.0])
- @unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
- @unittest.skipIf(not RUN_CUDA, "cpp tests require CUDA")
- @skipIfRocm
- def test_cpp_cuda(self):
- # rather than rebuild assertExpected in cpp,
- # just glob all the cpp outputs into one file for now
- self.assertExpected(torch._C._jit_run_cpp_tests())
+ out_ref = addmm(mat, mat1, mat2, alpha, beta)
+ self.run_pass('canonicalize_ops', addmm.graph)
+ out_test = addmm(mat, mat1, mat2, alpha, beta)
+ self.assertEqual(out_ref, out_test)
+ self.assertExpected(canonical(addmm.graph))
- def test_batchnorm(self):
- x = torch.ones(2, 2, 2, 2)
- trace, _ = torch.jit.get_trace_graph(nn.BatchNorm2d(2), x, _force_outplace=True)
- self.assertExpectedGraph(trace)
+ def test_index_put(self):
+ ten = torch.zeros(3, 3)
+ mask = torch.Tensor([[True, True, True],
+ [True, False, False],
+ [True, True, False]]).byte()
- def test_dropout(self):
- x = torch.ones(2, 2)
- trace, _ = torch.jit.get_trace_graph(nn.Dropout(0.6), x)
- self.assertExpectedGraph(trace)
+ def test_fn(ten, mask):
+ ten[mask] = torch.ones(6)
+ return ten
- def test_conv(self):
- x = torch.ones(20, 16, 50, 40)
- trace, _ = torch.jit.get_trace_graph(nn.Conv2d(16, 13, 3, bias=False), x)
- self.assertExpectedGraph(trace)
+ traced_test_fn = torch.jit.trace(test_fn, (ten, mask))
- def test_repeated_input(self):
- def fn(a, b):
- return a + b
+ ten = torch.rand(3, 3)
+ self.assertEqual(test_fn(ten, mask), traced_test_fn(ten, mask))
- ge = self.checkTrace(fn, [torch.randn(2, 2)] * 2)
- self.assertExpectedGraph(ge.graph)
-
- def test_repeated_output(self):
- def fn(a, b):
- z = a + b
- return z, z
+ def test_sparse_tensors_error(self):
+ def get_sparse():
+ return torch.sparse.FloatTensor(2, 3)
- ge = self.checkTrace(fn, [torch.randn(2, 2) for _ in range(2)])
- self.assertExpectedGraph(ge.graph)
+ @torch.jit.script
+ def sparse(input):
+ output = get_sparse()
+ return output, input
- @skipIfNoTorchVision
- def test_alexnet(self):
- x = torch.ones(1, 3, 224, 224)
- trace, _ = torch.jit.get_trace_graph(torchvision.models.AlexNet(), x)
- self.run_pass('cse', trace)
- self.assertExpectedGraph(trace)
+ with self.assertRaisesRegex(RuntimeError, "sparse tensors not supported"):
+ sparse(get_sparse())
- # Inplace copies don't work with tracer yet.
- # This is actually somewhat important to support correctly
- # as all backwards functions of views are implemented
- # as a zero filled tensor with a gradient fill on the
- # viewed portion.
- def test_inplace_copy(self):
- x = torch.randn(4, 4, requires_grad=True)
+ with self.assertRaisesRegex(RuntimeError, "sparse tensors not supported"):
+ sparse(torch.tensor([1]))
- def f(x):
- out = Variable(torch.zeros(x.size()))
- out.copy_(x)
- return out
+ def test_tuple_specialization(self):
+ @torch.jit.script
+ def f(t):
+ # type: (Tuple[Tensor, Tensor]) -> Tensor
+ x, y = t
+ return x + y
- trace, z = torch.jit.get_trace_graph(f, (x, ))
- self.run_pass('dce', trace)
- self.assertExpectedGraph(trace)
- self.assertExportImport(trace, (x,))
+ t = torch.randn(2, 2), torch.randn(2, 2)
+ f(t)
+ graph = f.graph_for(t)
+ input_types = list(next(graph.inputs()).type().elements())
+ for t in input_types:
+ self.assertEqual(t.kind(), 'TensorType')
- def test_shared_param(self):
+ def test_constant_prop_simple(self):
+ @torch.jit.script
+ def constant_prop(input_tensor):
+ a = 2 * 3
+ b = a + 2
+ return b + input_tensor
- class MyModule(torch.nn.Module):
- def __init__(self):
- super(MyModule, self).__init__()
- self.b = self.a = nn.Parameter(torch.randn(2, 2))
+ x = torch.tensor(2)
+ out_ref = constant_prop(x)
+ self.run_pass('constant_propagation', constant_prop.graph)
+ out_test = constant_prop(torch.tensor(2))
+ self.assertEqual(out_ref, out_test)
+ self.assertExpected(canonical(constant_prop.graph))
- def forward(self, x):
- return x * self.a + self.b
+ def test_constant_prop_nested(self):
+ @torch.jit.script
+ def constant_prop(a):
+ b = 2 + 1
+ if bool(a < 2):
+ c = b + 2
+ else:
+ c = b - 2
+ return c
+ out_ref = constant_prop(torch.tensor(2))
+ self.run_pass('constant_propagation', constant_prop.graph)
+ out_test = constant_prop(torch.tensor(2))
+ self.assertEqual(out_ref, out_test)
+ self.assertExpected(canonical(constant_prop.graph))
- m = MyModule()
- trace, _ = torch.jit.get_trace_graph(m, (torch.randn(2, 2),))
- self.assertEqual(len(list(trace.graph().inputs())), 2)
- self.assertExpectedGraph(trace)
+ def test_constant_prop_print(self):
+ @torch.jit.script
+ def constant_prop(input_tensor):
+ a = 2 * 3
+ print(a)
+ b = a + 2
+ return b + input_tensor
- def test_nested_inplace(self):
- x = torch.randn(2, 2)
- trace, _ = torch.jit.get_trace_graph(
- lambda x: F.threshold(x, 0, 0, inplace=True), (x, ))
- self.assertExpectedGraph(trace)
- self.assertExportImport(trace, (x,))
+ self.run_pass('constant_propagation', constant_prop.graph)
+ self.assertExpected(canonical(constant_prop.graph))
- def run_ge_tests(self, optimize, use_cuda):
- def rand(*args):
- t = torch.rand(*args).float()
- if use_cuda:
- t = t.cuda()
- return t
- self.checkTrace(lambda a, b: a * b + b,
- [rand(1), rand(1)], [rand(2, 3), rand(2, 3)],
- optimize=optimize)
- # trivial identity
- self.checkTrace(lambda a, b: (
- b, a), [rand(1), rand(1)], optimize=optimize)
+ def test_constant_prop_rand(self):
+ @torch.jit.script
+ def constant_prop():
+ a = torch.randn([3])
+ b = a + 2
+ return b
- def foo(a):
- t = a * a
- return t * t, 4 * t
- self.checkTrace(foo, [rand(1)], optimize=optimize)
- # unused input
- self.checkTrace(
- lambda a, b: a * a, [rand(1), rand(1)], optimize=optimize,
- allow_unused=True)
- # test outputs that do not get used in grad
- self.checkTrace(foo, [rand(1)], drop=1, optimize=optimize)
- # test autograd fallback
- self.checkTrace(lambda a, b: a * b /
- (a - 2 * b) + b, [rand(1), rand(1)],
- optimize=optimize)
+ self.run_pass('constant_propagation', constant_prop.graph)
+ self.assertExpected(canonical(constant_prop.graph))
- def test_ge_unoptimized(self):
- self.run_ge_tests(False, False)
+ def test_trace_records_names(self):
+ def foo(bar, baz):
+ baz = bar + 3
+ quick_brown_fox = torch.neg(baz)
+ for i in range(20):
+ yeet = quick_brown_fox - 3.14
+ return yeet
- def _test_fused_abs(self, device='cpu'):
+ traced = torch.jit.trace(foo, (torch.rand(3, 3), torch.rand(3, 3)))
+ graph_str = str(traced.graph)
+ assert 'bar' in graph_str
+ assert 'baz' in graph_str
+ assert 'quick_brown_fox' in graph_str
+ def test_constant_prop_if_constant(self):
@torch.jit.script
- def func(x):
- return x.abs() * 2
-
- a = torch.randn(5, device=device)
- self.assertEqual(func(a), a.abs() * 2)
- self.assertAllFused(func.graph_for(a))
+ def constant_prop(a, b):
+ c0 = 1
+ c1 = 1
+ c2 = 1
+ if bool(a): # -> c0, c1
+ if bool(b): # -> c0
+ if True: # -> c0
+ c0 = c0 + 1
+ if False:
+ c1 = c1 + 1
+ c2 = c2 + 1
+ else: # -> c0, c1
+ c1 = c1 + 1
- @unittest.skipIf(IS_WINDOWS or IS_SANDCASTLE, "NYI: fuser support for Windows or Sandcastle")
- @enable_cpu_fuser
- def test_fused_abs_cpu(self):
- self._test_fused_abs()
+ if True: # inlined
+ c0 = c0 + 1 # dynamic
+ c2 = c2 + 4 # set to 5
+ return a + c0 + c1 + c2
- @unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
- @unittest.skipIf(not RUN_CUDA, "requires CUDA")
- @skipIfRocm
- def test_fused_abs_cuda(self):
- self._test_fused_abs(device="cuda")
+ self.run_pass('constant_propagation', constant_prop.graph)
+ self.assertExpected(canonical(constant_prop.graph))
- @unittest.skipIf(IS_WINDOWS or IS_SANDCASTLE, "NYI: fuser support for Windows or Sandcastle")
- @enable_cpu_fuser
- def test_ge_optimized(self):
- self.run_ge_tests(True, False)
+ def test_constant_prop_loop_constant(self):
+ @torch.jit.script
+ def constant_prop():
+ b = 0
+ while True:
+ b = 1
+ while False:
+ b = 2
+ return b
- @unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
- @unittest.skipIf(not RUN_CUDA, "requires CUDA")
- @skipIfRocm
- def test_ge_cuda(self):
- self.run_ge_tests(True, True)
+ self.run_pass('constant_propagation', constant_prop.graph)
+ self.assertExpected(canonical(constant_prop.graph))
- @unittest.skipIf(IS_WINDOWS or IS_SANDCASTLE, "NYI: fuser support for Windows or Sandcastle")
- @enable_cpu_fuser
- def test_fused_where_and_typing(self):
- def f(x, y):
- mask = x > y
- res = torch.where(mask, x, y)
- return mask, res
+ def test_trace_detach(self):
+ def foo(x, w):
+ return torch.matmul(x, w).detach()
- script_f = torch.jit.script(f)
+ traced = torch.jit.trace(foo, (torch.rand(3, 4), torch.rand(4, 5)))
- x = torch.randn(4, 4, dtype=torch.double)
- y = torch.randn(4, 4, dtype=torch.double)
+ self.assertExpectedGraph(traced.graph)
+ x, w = torch.rand(3, 4), torch.rand(4, 5, requires_grad=True)
+ traced_result = traced(x, w)
+ self.assertEqual(foo(x, w), traced_result)
+ self.assertFalse(traced_result.requires_grad)
+ self.assertIsNone(traced_result.grad_fn)
- result1, result2 = script_f(x, y)
- expected1, expected2 = f(x, y)
- self.assertEqual(result1, expected1)
- self.assertEqual(result2, expected2)
- self.assertAllFused(script_f.graph_for(x, y))
+ def test_trace_detach_inplace(self):
+ def foo(x, w):
+ y = torch.matmul(x, w)
+ y.detach_()
+ return y
- # more manual test of graph executor that can be used as a scratchpad
- def test_ge(self):
- def foo(a, b):
- return a * b / (a - b) + b
- V = Variable
- a, b = V(torch.rand(1)), V(torch.rand(1))
- ge = torch._C.GraphExecutor(foo, (a, b), lambda var: '')
- a, b = V(torch.rand(1), requires_grad=True), V(
- torch.rand(1), requires_grad=True)
- r, = ge(a, b)
- da, db = torch.autograd.grad(r + 3, [a, b], create_graph=True)
+ traced = torch.jit.trace(foo, (torch.rand(3, 4), torch.rand(4, 5)))
- l2 = (da * db + db * db)
- g2result = torch.autograd.grad(l2, [da, db])
+ self.assertExpectedGraph(traced.graph)
+ x, w = torch.rand(3, 4), torch.rand(4, 5)
+ traced_result = traced(x, w)
+ self.assertEqual(foo(x, w), traced_result)
+ self.assertFalse(traced_result.requires_grad)
+ self.assertIsNone(traced_result.grad_fn)
- r = foo(a, b)
- da2, db2 = torch.autograd.grad(r + 3, [a, b], create_graph=True)
- self.assertEqual(da, da2)
- self.assertEqual(db, db2)
- l3 = (da2 * db2 + db2 * db2)
- g2result2 = torch.autograd.grad(l3, [da2, db2])
- self.assertEqual(g2result, g2result2)
+ def test_trace_detach_onnx_erase(self):
+ class Mod(torch.nn.Module):
+ def forward(self, x, w):
+ return torch.matmul(x, w).detach()
- def test_trace_annotation(self):
- @_trace(torch.rand(1))
- def foo(a):
- return a + a + a
+ f = io.BytesIO()
+ self.assertExpected(torch.onnx.export_to_pretty_string(
+ Mod(), (torch.rand(3, 4), torch.rand(4, 5)), f))
- x = torch.randn(5, 5)
- self.assertEqual(foo(x), x + x + x)
+ def test_trace_slice_full_dim(self):
+ def foo(x):
+ return x[0:5, 0] + 1.0
- def test_trace_script(self):
- @torch.jit.script
- def func1(x):
- # type: (Tuple[Tensor, Tensor]) -> Tensor
- return x[0] + x[1]
+ traced = torch.jit.trace(foo, (torch.rand(5, 4),))
+ test_x = torch.rand(6, 3)
+ self.assertEqual(foo(test_x), traced(test_x))
- @torch.jit.script
- def func2(x):
- # type: (List[Tensor]) -> Tensor
- return x[0] + x[1]
+ def test_export_expand_aten_fallback(self):
+ class ExpandTest(torch.jit.ScriptModule):
+ @torch.jit.script_method
+ def forward(self, x):
+ y = x
+ for i in range(5):
+ y = x.expand([3, 4, i])
+ return y
- a = torch.randn(5)
- b = torch.randn(5)
+ mod = ExpandTest()
+ example_outs = mod(torch.rand(3, 4, 1))
+ f = io.BytesIO()
+ with self.assertRaisesRegex(RuntimeError, 'Could not export a broadcasted operation'):
+ torch.onnx.export_to_pretty_string(mod, (torch.rand(3, 4, 1),), f, verbose=False,
+ example_outputs=example_outs)
- expected = func1((a, b))
- traced = torch.jit.trace(func1, ((a, b),))
- result = traced((a, b))
- self.assertEqual(expected, result)
+ self.assertExpected(
+ torch.onnx.export_to_pretty_string(mod, (torch.rand(3, 4, 1),), f, verbose=False,
+ example_outputs=example_outs,
+ operator_export_type=OperatorExportTypes.ONNX_ATEN_FALLBACK))
- expected = func2((a, b))
- traced = torch.jit.trace(func2, ((a, b),))
- result = traced((a, b))
- self.assertEqual(expected, result)
+ def test_export_dropout(self):
+ test = torch.nn.Dropout()
+ test.eval()
- def test_einsum(self):
- def outer(x, y):
- return torch.einsum('i,j->ij', (x, y))
+ traced = torch.jit.trace(test, (torch.rand(3, 4),), check_trace=False)
+ imported = self.getExportImportCopy(traced)
+ x = torch.randn(3, 4)
+ self.assertEqual(traced(x), imported(x))
- traced = torch.jit.trace(outer, (torch.randn(4), torch.randn(5)))
- script = torch.jit.script(outer)
- fns = [traced, script]
- x, y = torch.randn(10), torch.randn(2)
- for fn in [traced, script]:
- self.assertGraphContains(fn.graph, kind='aten::einsum')
- self.assertEqual(fn(x, y), outer(x, y))
+ @unittest.skipIf(not RUN_CUDA, "requires CUDA")
+ def test_cuda_export_restore(self):
+ class Sub(torch.jit.ScriptModule):
+ def __init__(self):
+ super(Sub, self).__init__()
+ self.weight = nn.Parameter(torch.randn(3, 4))
- @unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
- @unittest.skipIf(not RUN_CUDA, "calls .cuda()")
- @skipIfRocm
- def test_traced_module_cuda(self):
- class Model(nn.Module):
- def __init__(self, num_features, num_layers):
- super(Model, self).__init__()
- self.num_layers = num_layers
- layers = [[nn.Linear(num_features, num_features), nn.Sigmoid()]
- for _ in range(num_layers)]
- self.submodule = nn.Sequential(*chain(*layers))
+ @torch.jit.script_method
+ def forward(self, thing):
+ return self.weight + thing
- def forward(self, x):
- for i in range(self.num_layers):
- x = self.submodule[i](x) + x
- return x
+ class M(torch.jit.ScriptModule):
+ def __init__(self):
+ super(M, self).__init__()
+ self.mod = Sub()
- model = Model(5, 3)
- x = torch.randn(2, 5)
- traced_model = torch.jit.trace(model, x)
+ @torch.jit.script_method
+ def forward(self, v):
+ return self.mod(v)
+ m = M()
+ m.cuda()
+ m2 = self.getExportImportCopy(m)
+ m2.cuda()
+ input = torch.rand(3, 4).cuda()
+ self.assertEqual(m(input), m2(input))
- # We're missing some attributes these modules had initially. Make sure we can
- # still get the __repr__()
- model.__repr__()
+ def test_export_batchnorm(self):
+ for mode in ['eval', 'train']:
+ for clazz in [
+ torch.nn.BatchNorm1d(100),
+ torch.nn.BatchNorm1d(100, affine=False),
+ torch.nn.BatchNorm2d(100),
+ torch.nn.BatchNorm2d(100, affine=False)]:
+ getattr(clazz, mode)()
- # XXX: indexing sequentials is broken
- linear_submodule = next(iter(traced_model.submodule._modules.values()))
+ input = torch.randn(20, 100) if isinstance(clazz, torch.nn.BatchNorm1d) else \
+ torch.randn(20, 100, 35, 45)
- # All attributes that aren't parameters should raise
- with self.assertRaises(AttributeError):
- linear_submodule.in_features
- linear_submodule.weight
- with self.assertRaises(RuntimeError):
- traced_model.asdf = 4
- linear_submodule.weight = nn.Parameter(torch.randn(linear_submodule.weight.shape))
- with self.assertRaises(RuntimeError):
- del linear_submodule.weight
+ traced = torch.jit.trace(clazz, (input,))
+ imported = self.getExportImportCopy(traced)
+ x = torch.randn(20, 100) if isinstance(clazz, torch.nn.BatchNorm1d) else \
+ torch.randn(20, 100, 35, 45)
+ self.assertEqual(traced(x), imported(x))
- # Submodules can't be called
- with self.assertRaises(RuntimeError):
- linear_submodule(x)
+ def test_export_rnn(self):
+ for clazz in [nn.RNN(10, 20, 2), nn.GRU(10, 20, 2)]:
+ class RNNTest(torch.nn.Module):
+ def __init__(self):
+ super(RNNTest, self).__init__()
+ self.rnn = clazz
- # Type casts
- linear_submodule.cuda()
- traced_model.float().cuda()
- cuda_out = traced_model(x.float().cuda())
- traced_model.cpu()
- cpu_out = traced_model(x.float())
- self.assertEqual(cpu_out, cuda_out)
- traced_model.double()
+ def forward(self, x, lengths, h0):
+ packed = torch.nn.utils.rnn.pack_padded_sequence(x, lengths)
+ out, h = self.rnn(packed, h0)
+ padded_outs, _ = torch.nn.utils.rnn.pad_packed_sequence(out)
+ return padded_outs
- # state_dict + load_state_dict
- state = {k: v.clone() for k, v in traced_model.state_dict().items()}
- new_state = {k: v.clone().fill_(1) for k, v in state.items()}
- out = traced_model(x)
- traced_model.load_state_dict(new_state)
- out_ones = traced_model(x)
- traced_model.load_state_dict(state)
- out_state = traced_model(x)
- self.assertEqual(out, out_state)
- self.assertNotEqual(out, out_ones)
+ test = RNNTest()
- def test_python_function(self):
- class MyFn(Function):
- @staticmethod
- def forward(ctx, x):
- return x + 1
+ traced = torch.jit.trace(test, (torch.randn(5, 3, 10), torch.LongTensor([3, 2, 1]), torch.randn(2, 3, 20)))
+ imported = self.getExportImportCopy(traced)
+ # NB: We make sure to pass in a batch with a different max sequence
+ # length to ensure that the argument stashing for pad_packed works
+ # properly.
+ x, lengths, h0 = torch.randn(7, 4, 10), torch.LongTensor([7, 3, 2, 1]), torch.randn(2, 4, 20)
+ self.assertEqual(traced(x, lengths, h0), imported(x, lengths, h0))
- @staticmethod
- def backward(ctx, grad_output):
- return grad_output
+ def test_export_lstm(self):
+ class LSTMTest(torch.nn.Module):
+ def __init__(self):
+ super(LSTMTest, self).__init__()
+ self.rnn = nn.LSTM(10, 20, 2)
- @_trace(torch.zeros(2))
- def fn(x):
- return MyFn.apply(x + 2) + 3
+ def forward(self, x, lengths, hiddens):
+ h0, c0 = hiddens
+ packed = torch.nn.utils.rnn.pack_padded_sequence(x, lengths)
+ out, (h, c) = self.rnn(packed, (h0, c0))
+ padded_outs, _ = torch.nn.utils.rnn.pad_packed_sequence(out)
+ return padded_outs
- x = torch.tensor([1., 2., 3.])
- y = torch.randn(2, 2, requires_grad=True)
- fn(x)
- fn(y)
+ test = LSTMTest()
- def test_python_function_tup(self):
- class MyFn(Function):
- @staticmethod
- def forward(ctx, x):
- return x + 1, x - 1
+ traced = torch.jit.trace(test, (torch.randn(5, 3, 10),
+ torch.LongTensor([3, 2, 1]),
+ (torch.randn(2, 3, 20), torch.randn(2, 3, 20))))
+ imported = self.getExportImportCopy(traced)
+ x, lengths, h0, c0 = \
+ torch.randn(7, 3, 10), torch.LongTensor([7, 5, 2]), torch.randn(2, 3, 20), torch.randn(2, 3, 20)
+ self.assertEqual(traced(x, lengths, (h0, c0)), imported(x, lengths, (h0, c0)))
- @staticmethod
- def backward(ctx, grad_output):
- return grad_output, grad_output
+ def test_trace_variable_instantiation(self):
+ def random_foo(x):
+ return Variable(Variable(x) + 1.0)
- @_trace(torch.zeros(2))
- def fn(x):
- a, b = MyFn.apply(x + 2)
- return a + b + 3
- x = torch.tensor([1., 2., 3.])
- y = torch.randn(2, 2, requires_grad=True)
- fn(x)
- fn(y)
+ random_foo_traced = torch.jit.trace(random_foo, (torch.rand(3, 4),))
- def test_decompose_addmm(self):
- @torch.jit.script
- def addmm(mat, mat1, mat2, alpha, beta):
- a = mat.addmm(mat1, mat2)
- b = mat.addmm(mat1, mat2, alpha=1.0, beta=1.0)
- c = mat.addmm(mat1, mat2, alpha=4.20, beta=2.0)
- d = mat.addmm(mat1, mat2, alpha=int(alpha), beta=int(beta))
+ x = torch.rand(5, 6)
+ self.assertEqual(random_foo(x), random_foo_traced(x))
- return a + b + c + d
+ def test_trace_slice_expr_complete_type(self):
+ def random_foo(x):
+ return x + 1.0
- mat = torch.randn(2, 2)
- mat1 = torch.randn(2, 4)
- mat2 = torch.randn(4, 2)
- alpha = torch.FloatTensor([123.0])
- beta = torch.FloatTensor([321.0])
+ random_foo_traced = torch.jit.trace(random_foo, (torch.rand(3, 4),))
- out_ref = addmm(mat, mat1, mat2, alpha, beta)
- self.run_pass('canonicalize_ops', addmm.graph)
- out_test = addmm(mat, mat1, mat2, alpha, beta)
- self.assertEqual(out_ref, out_test)
- self.assertExpected(canonical(addmm.graph))
+ @torch.jit.script
+ def random_bar(x):
+ return random_foo_traced(x)[0:1]
- def test_index_put(self):
- ten = torch.zeros(3, 3)
- mask = torch.Tensor([[True, True, True],
- [True, False, False],
- [True, True, False]]).byte()
+ x = torch.rand(3, 4)
+ self.assertEqual(random_bar(x), (x + 1)[0:1])
- def test_fn(ten, mask):
- ten[mask] = torch.ones(6)
- return ten
+ def test_export_tensoroption_to(self):
+ def foo(x):
+ return x.new_tensor(x[0]).cpu() + x
- traced_test_fn = torch.jit.trace(test_fn, (ten, mask))
+ traced = torch.jit.trace(foo, (torch.rand([2])))
+ example_outputs = traced(torch.rand([2]))
- ten = torch.rand(3, 3)
- self.assertEqual(test_fn(ten, mask), traced_test_fn(ten, mask))
+ f = io.BytesIO()
+ self.assertExpected(torch.onnx._export_to_pretty_string(traced, (torch.rand([2]),), f,
+ example_outputs=example_outputs))
- def test_sparse_tensors_error(self):
- def get_sparse():
- return torch.sparse.FloatTensor(2, 3)
+ def test_pretty_printer(self):
+ @torch.jit.script
+ def if_test(a, b):
+ # FIXME: use 0 instead of a.
+ # c = 0
+ c = a
+ if bool(a < b):
+ c = b
+ else:
+ c = a
+ return c
@torch.jit.script
- def sparse(input):
- output = get_sparse()
- return output, input
+ def if_one(a, b):
+ c = b
+ if bool(a < b):
+ c = a
+ return c
- with self.assertRaisesRegex(RuntimeError, "sparse tensors not supported"):
- sparse(get_sparse())
+ @torch.jit.script
+ def while_test(a, i):
+ while bool(i < 3):
+ a *= a
+ i += 1
+ return a
- with self.assertRaisesRegex(RuntimeError, "sparse tensors not supported"):
- sparse(torch.tensor([1]))
+ @torch.jit.script
+ def while_if_test(a, b):
+ c = 0
+ while bool(a < 10):
+ a = a + 1
+ b = b + 1
+ if bool(a > b):
+ c = 2
+ else:
+ c = 3
+ return a + 1 + c
- def test_tuple_specialization(self):
@torch.jit.script
- def f(t):
- # type: (Tuple[Tensor, Tensor]) -> Tensor
- x, y = t
- return x + y
+ def loop_use_test(y):
+ x = y + 1
+ z = x + 5
+ while bool(y < 8):
+ y += 1
+ z = x
+ return x, z
- t = torch.randn(2, 2), torch.randn(2, 2)
- f(t)
- graph = f.graph_for(t)
- input_types = list(next(graph.inputs()).type().elements())
- for t in input_types:
- self.assertEqual(t.kind(), 'TensorType')
+ def python_fn(x):
+ return x + 10
- def test_constant_prop_simple(self):
@torch.jit.script
- def constant_prop(input_tensor):
- a = 2 * 3
- b = a + 2
- return b + input_tensor
+ def python_op_name_test(y):
+ return python_fn(y)
- x = torch.tensor(2)
- out_ref = constant_prop(x)
- self.run_pass('constant_propagation', constant_prop.graph)
- out_test = constant_prop(torch.tensor(2))
- self.assertEqual(out_ref, out_test)
- self.assertExpected(canonical(constant_prop.graph))
+ @torch.jit.script
+ def empty_int_list_test(y):
+ x = torch.jit.annotate(List[int], [])
+ return x[0]
- def test_constant_prop_nested(self):
@torch.jit.script
- def constant_prop(a):
- b = 2 + 1
- if bool(a < 2):
- c = b + 2
- else:
- c = b - 2
- return c
- out_ref = constant_prop(torch.tensor(2))
- self.run_pass('constant_propagation', constant_prop.graph)
- out_test = constant_prop(torch.tensor(2))
- self.assertEqual(out_ref, out_test)
- self.assertExpected(canonical(constant_prop.graph))
+ def empty_float_list_test(y):
+ return [1.0, 2.0, 3.0]
- def test_constant_prop_print(self):
@torch.jit.script
- def constant_prop(input_tensor):
- a = 2 * 3
- print(a)
- b = a + 2
- return b + input_tensor
+ def print_weird_test(y):
+ print("hi\016")
- self.run_pass('constant_propagation', constant_prop.graph)
- self.assertExpected(canonical(constant_prop.graph))
+ self.assertExpected(if_test.graph.pretty_print(), "if_test")
+ self.assertExpected(if_one.graph.pretty_print(), "if_one")
+ self.assertExpected(while_test.graph.pretty_print(), "while_test")
+ self.assertExpected(while_if_test.graph.pretty_print(), "while_if_test")
+ self.assertExpected(loop_use_test.graph.pretty_print(), "loop_use_test")
+ self.assertExpected(python_op_name_test.graph.pretty_print(), "python_op_name_test")
+ self.assertExpected(empty_int_list_test.graph.pretty_print(), "empty_int_list_test")
+ self.assertExpected(empty_float_list_test.graph.pretty_print(), "empty_float_list_test")
+ self.assertExpected(print_weird_test.graph.pretty_print(), "print_weird_test")
- def test_constant_prop_rand(self):
- @torch.jit.script
- def constant_prop():
- a = torch.randn([3])
- b = a + 2
- return b
+ def test_cu_escaped_number(self):
+ cu = torch.jit.CompilationUnit('''
+ def foo(a):
+ print("hi\016")
+ ''')
+ self.assertExpected(cu.foo.graph.pretty_print())
- self.run_pass('constant_propagation', constant_prop.graph)
- self.assertExpected(canonical(constant_prop.graph))
+ def test_import_method(self):
+ @torch.jit.script
+ def foo(x, y):
+ return 2 * x + y
- def test_trace_records_names(self):
- def foo(bar, baz):
- baz = bar + 3
- quick_brown_fox = torch.neg(baz)
- for i in range(20):
- yeet = quick_brown_fox - 3.14
- return yeet
+ r, _ = foo._python_print()
+ mod = torch.jit.ScriptModule()
+ torch._C._jit_import_methods(mod, "op_version_set = 0\n{}".format(r), [])
+ self.assertExpected(mod.graph.pretty_print())
- traced = torch.jit.trace(foo, (torch.rand(3, 3), torch.rand(3, 3)))
- graph_str = str(traced.graph)
- assert 'bar' in graph_str
- assert 'baz' in graph_str
- assert 'quick_brown_fox' in graph_str
+ def test_function_default_values(self):
+ outer_var = torch.tensor(20)
+ outer_var2 = torch.tensor(30)
+ a = torch.tensor(0.5)
+ b = torch.tensor(10)
- def test_constant_prop_if_constant(self):
@torch.jit.script
- def constant_prop(a, b):
- c0 = 1
- c1 = 1
- c2 = 1
- if bool(a): # -> c0, c1
- if bool(b): # -> c0
- if True: # -> c0
- c0 = c0 + 1
- if False:
- c1 = c1 + 1
- c2 = c2 + 1
- else: # -> c0, c1
- c1 = c1 + 1
+ def simple_fn(x, a=a, b=b, c=outer_var + outer_var2):
+ return x + a + b + c
- if True: # inlined
- c0 = c0 + 1 # dynamic
- c2 = c2 + 4 # set to 5
- return a + c0 + c1 + c2
+ self.assertExpectedGraph(simple_fn.graph, "simple")
+ self.assertEqual(
+ simple_fn(torch.ones(1)),
+ torch.ones(1) + 0.5 + 10 + (20 + 30))
+ self.assertEqual(
+ simple_fn(torch.ones(1), torch.tensor(1), torch.tensor(3), torch.tensor(4)),
+ torch.ones(1) + 1 + 3 + 4)
- self.run_pass('constant_propagation', constant_prop.graph)
- self.assertExpected(canonical(constant_prop.graph))
+ outer_c = torch.tensor(9)
+ outer_flag = torch.tensor(False)
- def test_constant_prop_loop_constant(self):
@torch.jit.script
- def constant_prop():
- b = 0
- while True:
- b = 1
- while False:
- b = 2
- return b
-
- self.run_pass('constant_propagation', constant_prop.graph)
- self.assertExpected(canonical(constant_prop.graph))
+ def bool_fn(x, a=outer_c, flag=outer_flag):
+ if bool(flag):
+ result = x
+ else:
+ result = x + a
+ return result
- def test_trace_detach(self):
- def foo(x, w):
- return torch.matmul(x, w).detach()
+ self.assertExpectedGraph(bool_fn.graph, "bool")
+ self.assertEqual(bool_fn(torch.ones(1)), torch.ones(1) + 9)
+ self.assertEqual(
+ bool_fn(torch.ones(1), torch.tensor(1), torch.tensor(True)),
+ torch.ones(1))
- traced = torch.jit.trace(foo, (torch.rand(3, 4), torch.rand(4, 5)))
+ @torch.jit.script
+ def none_fn(x=None):
+ # type: (Optional[int]) -> Optional[int]
+ return x
- self.assertExpectedGraph(traced.graph)
- x, w = torch.rand(3, 4), torch.rand(4, 5, requires_grad=True)
- traced_result = traced(x, w)
- self.assertEqual(foo(x, w), traced_result)
- self.assertFalse(traced_result.requires_grad)
- self.assertIsNone(traced_result.grad_fn)
+ self.assertExpectedGraph(none_fn.graph, "none")
+ self.assertEqual(none_fn(), None)
+ self.assertEqual(none_fn(1), 1)
- def test_trace_detach_inplace(self):
- def foo(x, w):
- y = torch.matmul(x, w)
- y.detach_()
- return y
-
- traced = torch.jit.trace(foo, (torch.rand(3, 4), torch.rand(4, 5)))
+ @torch.jit.script
+ def hints(x, a=0.5, b=10):
+ # type: (Tensor, float, int) -> Tensor
+ return x + a + b
- self.assertExpectedGraph(traced.graph)
- x, w = torch.rand(3, 4), torch.rand(4, 5)
- traced_result = traced(x, w)
- self.assertEqual(foo(x, w), traced_result)
- self.assertFalse(traced_result.requires_grad)
- self.assertIsNone(traced_result.grad_fn)
+ self.assertExpectedGraph(hints.graph, "type_hints")
+ self.assertEqual(hints(torch.ones(1)), torch.ones(1) + 0.5 + 10)
- def test_trace_detach_onnx_erase(self):
- class Mod(torch.nn.Module):
- def forward(self, x, w):
- return torch.matmul(x, w).detach()
+ with self.assertRaisesRegex(RuntimeError, "Expected a default value"):
- f = io.BytesIO()
- self.assertExpected(torch.onnx.export_to_pretty_string(
- Mod(), (torch.rand(3, 4), torch.rand(4, 5)), f))
+ @torch.jit.script
+ def hints_bad_types(x, a=10, b=0.5):
+ # type: (Tensor, float, int) -> Tensor
+ return x + a + b
- def test_trace_slice_full_dim(self):
- def foo(x):
- return x[0:5, 0] + 1.0
+ def test_module_default_values(self):
+ four = torch.tensor(4)
- traced = torch.jit.trace(foo, (torch.rand(5, 4),))
- test_x = torch.rand(6, 3)
- self.assertEqual(foo(test_x), traced(test_x))
+ class Test(torch.jit.ScriptModule):
+ def __init__(self):
+ super(Test, self).__init__()
- def test_export_expand_aten_fallback(self):
- class ExpandTest(torch.jit.ScriptModule):
@torch.jit.script_method
- def forward(self, x):
- y = x
- for i in range(5):
- y = x.expand([3, 4, i])
- return y
+ def forward(self, input, other=four):
+ return input + other
- mod = ExpandTest()
- example_outs = mod(torch.rand(3, 4, 1))
- f = io.BytesIO()
- with self.assertRaisesRegex(RuntimeError, 'Could not export a broadcasted operation'):
- torch.onnx.export_to_pretty_string(mod, (torch.rand(3, 4, 1),), f, verbose=False,
- example_outputs=example_outs)
+ t = Test()
+ self.assertEqual(t(torch.ones(1)), torch.ones(1) + 4)
- self.assertExpected(
- torch.onnx.export_to_pretty_string(mod, (torch.rand(3, 4, 1),), f, verbose=False,
- example_outputs=example_outs,
- operator_export_type=OperatorExportTypes.ONNX_ATEN_FALLBACK))
+ def test_warnings(self):
+ import warnings
- def test_export_dropout(self):
- test = torch.nn.Dropout()
- test.eval()
+ @torch.jit.script
+ def fn(x):
+ if bool(x < 2):
+ warnings.warn("x is less than 2")
+ return x
- traced = torch.jit.trace(test, (torch.rand(3, 4),), check_trace=False)
- imported = self.getExportImportCopy(traced)
- x = torch.randn(3, 4)
- self.assertEqual(traced(x), imported(x))
+ self.assertExpectedGraph(fn.graph)
- @unittest.skipIf(not RUN_CUDA, "requires CUDA")
- def test_cuda_export_restore(self):
- class Sub(torch.jit.ScriptModule):
- def __init__(self):
- super(Sub, self).__init__()
- self.weight = nn.Parameter(torch.randn(3, 4))
- @torch.jit.script_method
- def forward(self, thing):
- return self.weight + thing
+class TestBatched(TestCase):
+ # generate random examples and create an batchtensor with them
+ def rand_batch(self, *dims):
+ dims = [dim for dim in dims if dim != ()]
+ xs = [torch.rand(1, *(random.randint(1, size) if b else size for b, size in dims[1:]),
+ requires_grad=True) for i in range(dims[0])]
+ xb = BatchTensor(xs, torch.tensor([b for b, d in dims[1:]]).byte())
+ return xs, xb
- class M(torch.jit.ScriptModule):
- def __init__(self):
- super(M, self).__init__()
- self.mod = Sub()
+ def test_create_batchtensor(self):
+ # create from tensorlist
+ xs, batch = self.rand_batch(4, (True, 3), (False, 2), (True, 5))
+ self.assertEqual(xs, batch.examples())
+ # create from data, mask, dims
+ batch2 = BatchTensor(batch.get_data(), batch.get_mask(), batch.get_dims())
+ self.assertEqual(xs, batch2.examples())
+ # expand a tensor to a batchtensor given batch_size
+ xs = torch.rand(3, 4, 5)
+ batch3 = BatchTensor(xs, 2)
+ xs = xs.unsqueeze(0)
+ self.assertEqual([xs, xs], batch3.examples())
- @torch.jit.script_method
- def forward(self, v):
- return self.mod(v)
- m = M()
- m.cuda()
- m2 = self.getExportImportCopy(m)
- m2.cuda()
- input = torch.rand(3, 4).cuda()
- self.assertEqual(m(input), m2(input))
+ def test_batch_elementwise_unary(self):
+ @torch.jit.batch(batch_size=4)
+ def tanh(a):
+ return torch.tanh(a)
- def test_export_batchnorm(self):
- for mode in ['eval', 'train']:
- for clazz in [
- torch.nn.BatchNorm1d(100),
- torch.nn.BatchNorm1d(100, affine=False),
- torch.nn.BatchNorm2d(100),
- torch.nn.BatchNorm2d(100, affine=False)]:
- getattr(clazz, mode)()
+ xs, batch = self.rand_batch(4, (True, 3), (False, 2))
+ res_batch = tanh(batch)
+ res = [torch.tanh(xs[j]) for j in range(4)]
+ self.assertEqual(res, res_batch.examples())
- input = torch.randn(20, 100) if isinstance(clazz, torch.nn.BatchNorm1d) else \
- torch.randn(20, 100, 35, 45)
+ def test_batch_elementwise_binary(self):
+ @torch.jit.batch(batch_size=4)
+ def add(a, b):
+ return a + b
- traced = torch.jit.trace(clazz, (input,))
- imported = self.getExportImportCopy(traced)
- x = torch.randn(20, 100) if isinstance(clazz, torch.nn.BatchNorm1d) else \
- torch.randn(20, 100, 35, 45)
- self.assertEqual(traced(x), imported(x))
+ xs, batch = self.rand_batch(4, (True, 3), (False, 2))
+ xs2, batch2 = xs, batch
+ res_batch = add(batch, batch2)
+ res = [torch.add(xs[j], xs2[j]) for j in range(4)]
+ self.assertEqual(res, res_batch.examples())
- def test_export_rnn(self):
- for clazz in [nn.RNN(10, 20, 2), nn.GRU(10, 20, 2)]:
- class RNNTest(torch.nn.Module):
- def __init__(self):
- super(RNNTest, self).__init__()
- self.rnn = clazz
+ # test broadcast
+ xs, batch = self.rand_batch(4, (False, 3), (False, 2))
+ b = torch.rand(3, 2)
+ res_batch = add(batch, b)
+ res = [torch.add(xs[j], b) for j in range(4)]
+ self.assertEqual(res, res_batch.examples())
- def forward(self, x, lengths, h0):
- packed = torch.nn.utils.rnn.pack_padded_sequence(x, lengths)
- out, h = self.rnn(packed, h0)
- padded_outs, _ = torch.nn.utils.rnn.pad_packed_sequence(out)
- return padded_outs
+ def test_batch_mm(self):
+ @torch.jit.batch(batch_size=4)
+ def mm(a, b):
+ return torch.mm(a, b)
- test = RNNTest()
+ xs, batch = self.rand_batch(4, (True, 3), (False, 2))
+ xs2, batch2 = self.rand_batch(4, (False, 2), (True, 3))
+ res_batch = mm(batch, batch2)
+ res = [torch.mm(xs[j].squeeze(0), xs2[j].squeeze(0)).unsqueeze(0) for j in range(4)]
+ self.assertEqual(res, res_batch.examples())
- traced = torch.jit.trace(test, (torch.randn(5, 3, 10), torch.LongTensor([3, 2, 1]), torch.randn(2, 3, 20)))
- imported = self.getExportImportCopy(traced)
- # NB: We make sure to pass in a batch with a different max sequence
- # length to ensure that the argument stashing for pad_packed works
- # properly.
- x, lengths, h0 = torch.randn(7, 4, 10), torch.LongTensor([7, 3, 2, 1]), torch.randn(2, 4, 20)
- self.assertEqual(traced(x, lengths, h0), imported(x, lengths, h0))
+ # test broadcast
+ b = torch.rand(2, 4)
+ res_batch = mm(batch, b)
+ res = [torch.mm(xs[j].squeeze(0), b).unsqueeze(0) for j in range(4)]
+ self.assertEqual(res, res_batch.examples())
- def test_export_lstm(self):
- class LSTMTest(torch.nn.Module):
- def __init__(self):
- super(LSTMTest, self).__init__()
- self.rnn = nn.LSTM(10, 20, 2)
+ def test_batch_matmul(self):
+ @torch.jit.batch(batch_size=4)
+ def matmul(a, b):
+ return torch.matmul(a, b)
- def forward(self, x, lengths, hiddens):
- h0, c0 = hiddens
- packed = torch.nn.utils.rnn.pack_padded_sequence(x, lengths)
- out, (h, c) = self.rnn(packed, (h0, c0))
- padded_outs, _ = torch.nn.utils.rnn.pad_packed_sequence(out)
- return padded_outs
+ def matmul_test(xs, batch, xs2, batch2):
+ ys = [torch.matmul(xs[j].squeeze(0), xs2[j].squeeze(0)).unsqueeze(0) for j in range(4)]
+ ybs = matmul(batch, batch2)
+ self.assertEqual(ys, ybs.examples())
- test = LSTMTest()
+ # 1 dimension * 1 dimension
+ xs, batch = self.rand_batch(4, (False, 2))
+ xs2, batch2 = self.rand_batch(4, (False, 2))
+ matmul_test(xs, batch, xs2, batch2)
+ # 1 dimension * 2 dimension
+ xs, batch = self.rand_batch(4, (False, 2))
+ xs2, batch2 = self.rand_batch(4, (False, 2), (True, 3))
+ matmul_test(xs, batch, xs2, batch2)
+ # 2 dimension * 1 dimensions
+ xs, batch = self.rand_batch(4, (True, 3), (False, 2))
+ xs2, batch2 = self.rand_batch(4, (False, 2))
+ matmul_test(xs, batch, xs2, batch2)
+ # 2 dimension * 2 dimension
+ xs, batch = self.rand_batch(4, (True, 3), (False, 2))
+ xs2, batch2 = self.rand_batch(4, (False, 2), (True, 3))
+ matmul_test(xs, batch, xs2, batch2)
- traced = torch.jit.trace(test, (torch.randn(5, 3, 10),
- torch.LongTensor([3, 2, 1]),
- (torch.randn(2, 3, 20), torch.randn(2, 3, 20))))
- imported = self.getExportImportCopy(traced)
- x, lengths, h0, c0 = \
- torch.randn(7, 3, 10), torch.LongTensor([7, 5, 2]), torch.randn(2, 3, 20), torch.randn(2, 3, 20)
- self.assertEqual(traced(x, lengths, (h0, c0)), imported(x, lengths, (h0, c0)))
+ def test_batch_select(self):
+ @torch.jit.batch(batch_size=4)
+ def select(x):
+ return torch.select(x, 1, 0)
- def test_trace_variable_instantiation(self):
- def random_foo(x):
- return Variable(Variable(x) + 1.0)
+ xs, batch = self.rand_batch(4, (True, 3), (True, 2))
+ res_batch = select(batch)
+ res = [torch.select(xs[j], 1, 0) for j in range(4)]
+ self.assertEqual(res, res_batch.examples())
- random_foo_traced = torch.jit.trace(random_foo, (torch.rand(3, 4),))
+ xs, batch = self.rand_batch(4, (False, 3), (True, 2))
+ res_batch = select(batch)
+ res = [torch.select(xs[j], 1, 0) for j in range(4)]
+ self.assertEqual(res, res_batch.examples())
- x = torch.rand(5, 6)
- self.assertEqual(random_foo(x), random_foo_traced(x))
+ def test_batch_index_select(self):
+ @torch.jit.batch(batch_size=4)
+ def index_select(x, ind):
+ return x.index_select(1, ind)
- def test_trace_slice_expr_complete_type(self):
- def random_foo(x):
- return x + 1.0
+ xs, batch = self.rand_batch(4, (False, 5), (True, 2))
+ ind = [torch.randint(0, 4, (1,), dtype=torch.long) for i in range(4)]
+ ind_batch = BatchTensor(ind, torch.tensor([]).byte())
+ res_batch = index_select(batch, ind_batch)
+ res = [torch.index_select(xs[j], 1, ind[j]) for j in range(4)]
+ self.assertEqual(res, res_batch.examples())
- random_foo_traced = torch.jit.trace(random_foo, (torch.rand(3, 4),))
+ def test_batch_where(self):
+ @torch.jit.batch(batch_size=4)
+ def where(c, a, b):
+ return torch.where(c, a, b)
- @torch.jit.script
- def random_bar(x):
- return random_foo_traced(x)[0:1]
+ xs, batch = self.rand_batch(4, (False, 3), (False, 2))
+ xs2, batch2 = self.rand_batch(4, (False, 3), (False, 2))
- x = torch.rand(3, 4)
- self.assertEqual(random_bar(x), (x + 1)[0:1])
+ dims = [4, (False, 3), (False, 2)]
+ xs_cond = [torch.rand(1, 3, 2).byte() for i in range(dims[0])]
+ batch_cond = BatchTensor(xs_cond, torch.tensor([b for b, d in dims[1:]]))
- def test_export_tensoroption_to(self):
- def foo(x):
- return x.new_tensor(x[0]).cpu() + x
+ res_batch = where(batch_cond, batch, batch2)
+ res = [torch.where(xs_cond[j], xs[j], xs2[j]) for j in range(4)]
+ self.assertEqual(res, res_batch.examples())
- traced = torch.jit.trace(foo, (torch.rand([2])))
- example_outputs = traced(torch.rand([2]))
+ def test_batch_argmax(self):
+ @torch.jit.batch(batch_size=4)
+ def argmax(a):
+ return torch.argmax(a, 1)
- f = io.BytesIO()
- self.assertExpected(torch.onnx._export_to_pretty_string(traced, (torch.rand([2]),), f,
- example_outputs=example_outputs))
+ xs, batch = self.rand_batch(4, (True, 5), (True, 6))
+ res_batch = argmax(batch)
+ res = [torch.argmax(xs[j], 1) for j in range(4)]
+ self.assertEqual(res, res_batch.examples())
- def test_pretty_printer(self):
- @torch.jit.script
- def if_test(a, b):
- # FIXME: use 0 instead of a.
- # c = 0
- c = a
- if bool(a < b):
- c = b
- else:
- c = a
- return c
+ @torch.jit.batch(batch_size=4)
+ def argmax(a):
+ return torch.argmax(a, 1, False)
- @torch.jit.script
- def if_one(a, b):
- c = b
- if bool(a < b):
- c = a
- return c
+ res_batch = argmax(batch)
+ res = [torch.argmax(xs[j], 1, False) for j in range(4)]
+ self.assertEqual(res, res_batch.examples())
- @torch.jit.script
- def while_test(a, i):
- while bool(i < 3):
- a *= a
- i += 1
- return a
+ def test_batch_topk(self):
+ @torch.jit.batch(batch_size=4)
+ def topk(a):
+ return torch.topk(a, 3, 1)
- @torch.jit.script
- def while_if_test(a, b):
- c = 0
- while bool(a < 10):
- a = a + 1
- b = b + 1
- if bool(a > b):
- c = 2
- else:
- c = 3
- return a + 1 + c
+ xs, batch = self.rand_batch(4, (False, 5), (True, 6))
- @torch.jit.script
- def loop_use_test(y):
- x = y + 1
- z = x + 5
- while bool(y < 8):
- y += 1
- z = x
- return x, z
+ # along static dim
+ res_batch = topk(batch)
+ res = [torch.topk(xs[j], 3, 1)[0] for j in range(4)]
+ res_idx = [torch.topk(xs[j], 3, 1)[1] for j in range(4)]
+ self.assertEqual(res, res_batch[0].examples())
+ self.assertEqual(res_idx, res_batch[1].examples())
- def python_fn(x):
- return x + 10
+ @torch.jit.batch(batch_size=4)
+ def topk(a):
+ return torch.topk(a, 1, 2)
- @torch.jit.script
- def python_op_name_test(y):
- return python_fn(y)
+ # along dynamic dim
+ res_batch = topk(batch)
+ res = [torch.topk(xs[j], 1, 2)[0] for j in range(4)]
+ res_idx = [torch.topk(xs[j], 1, 2)[1] for j in range(4)]
+ self.assertEqual(res, res_batch[0].examples())
+ self.assertEqual(res_idx, res_batch[1].examples())
- @torch.jit.script
- def empty_int_list_test(y):
- x = torch.jit.annotate(List[int], [])
- return x[0]
+ def test_batch_softmax(self):
+ @torch.jit.batch(batch_size=4)
+ def softmax(a):
+ return torch.softmax(a, 1)
- @torch.jit.script
- def empty_float_list_test(y):
- return [1.0, 2.0, 3.0]
+ xs, batch = self.rand_batch(4, (False, 5), (True, 6))
- @torch.jit.script
- def print_weird_test(y):
- print("hi\016")
+ # along static dim
+ res_batch = softmax(batch)
+ res = [torch.softmax(xs[j], 1) for j in range(4)]
+ self.assertEqual(res, res_batch.examples())
- self.assertExpected(if_test.graph.pretty_print(), "if_test")
- self.assertExpected(if_one.graph.pretty_print(), "if_one")
- self.assertExpected(while_test.graph.pretty_print(), "while_test")
- self.assertExpected(while_if_test.graph.pretty_print(), "while_if_test")
- self.assertExpected(loop_use_test.graph.pretty_print(), "loop_use_test")
- self.assertExpected(python_op_name_test.graph.pretty_print(), "python_op_name_test")
- self.assertExpected(empty_int_list_test.graph.pretty_print(), "empty_int_list_test")
- self.assertExpected(empty_float_list_test.graph.pretty_print(), "empty_float_list_test")
- self.assertExpected(print_weird_test.graph.pretty_print(), "print_weird_test")
+ @torch.jit.batch(batch_size=4)
+ def softmax(a):
+ return torch.softmax(a, 2)
- def test_cu_escaped_number(self):
- cu = torch.jit.CompilationUnit('''
- def foo(a):
- print("hi\016")
- ''')
- self.assertExpected(cu.foo.graph.pretty_print())
+ # along dynamic dim
+ res_batch = softmax(batch)
+ res = [torch.softmax(xs[j], 2) for j in range(4)]
+ self.assertEqual(res, res_batch.examples())
- def test_import_method(self):
- @torch.jit.script
- def foo(x, y):
- return 2 * x + y
+ def test_batch_view(self):
+ @torch.jit.batch(batch_size=4)
+ def view(a):
+ return a.view([4, -1, 3])
- r, _ = foo._python_print()
- mod = torch.jit.ScriptModule()
- torch._C._jit_import_methods(mod, "op_version_set = 0\n{}".format(r), [])
- self.assertExpected(mod.graph.pretty_print())
+ xs, batch = self.rand_batch(4, (True, 5), (False, 3))
+ res_batch = view(batch)
+ res = [xs[j].view([1, -1, 3]) for j in range(4)]
+ self.assertEqual(res, res_batch.examples())
- def test_function_default_values(self):
- outer_var = torch.tensor(20)
- outer_var2 = torch.tensor(30)
- a = torch.tensor(0.5)
- b = torch.tensor(10)
+ def test_batch_cat(self):
+ @torch.jit.batch(batch_size=4)
+ def cat2(a, b):
+ return torch.cat([a, b], 2)
- @torch.jit.script
- def simple_fn(x, a=a, b=b, c=outer_var + outer_var2):
- return x + a + b + c
+ xs, batch = self.rand_batch(4, (True, 5), (False, 3))
+ xs2, batch2 = xs, batch
+ res_batch = cat2(batch, batch2)
+ res = [torch.cat([xs[j], xs2[j]], 2) for j in range(4)]
+ self.assertEqual(res, res_batch.examples())
- self.assertExpectedGraph(simple_fn.graph, "simple")
- self.assertEqual(
- simple_fn(torch.ones(1)),
- torch.ones(1) + 0.5 + 10 + (20 + 30))
- self.assertEqual(
- simple_fn(torch.ones(1), torch.tensor(1), torch.tensor(3), torch.tensor(4)),
- torch.ones(1) + 1 + 3 + 4)
+ def test_batch_sum(self):
+ @torch.jit.batch(batch_size=4)
+ def batch_sum(a):
+ return a.sum()
- outer_c = torch.tensor(9)
- outer_flag = torch.tensor(False)
+ xs, batch = self.rand_batch(4, (True, 5), (False, 3))
+ res_batch = batch_sum(batch)
+ res = [xs[j].sum().unsqueeze(0) for j in range(4)]
+ self.assertEqual(res, res_batch.examples())
- @torch.jit.script
- def bool_fn(x, a=outer_c, flag=outer_flag):
- if bool(flag):
- result = x
+ def test_if_else(self):
+ def single_if(a, b):
+ if bool(a > b):
+ a = a + b
else:
- result = x + a
- return result
-
- self.assertExpectedGraph(bool_fn.graph, "bool")
- self.assertEqual(bool_fn(torch.ones(1)), torch.ones(1) + 9)
- self.assertEqual(
- bool_fn(torch.ones(1), torch.tensor(1), torch.tensor(True)),
- torch.ones(1))
+ a = a - b
+ return a
- @torch.jit.script
- def none_fn(x=None):
- # type: (Optional[int]) -> Optional[int]
- return x
+ batch_if = torch.jit.batch(batch_size=4)(single_if)
- self.assertExpectedGraph(none_fn.graph, "none")
- self.assertEqual(none_fn(), None)
- self.assertEqual(none_fn(1), 1)
+ a, batch_a = self.rand_batch(4, ())
+ b, batch_b = self.rand_batch(4, ())
+ res_batch = batch_if(batch_a, batch_b)
+ res = [single_if(a[j], b[j]) for j in range(4)]
+ self.assertEqual(res, res_batch.examples())
- @torch.jit.script
- def hints(x, a=0.5, b=10):
- # type: (Tensor, float, int) -> Tensor
- return x + a + b
+ script_if = torch.jit.script(single_if)
+ graph = torch.to_batch_graph(script_if.graph)
+ self.assertExpected(canonical(graph))
- self.assertExpectedGraph(hints.graph, "type_hints")
- self.assertEqual(hints(torch.ones(1)), torch.ones(1) + 0.5 + 10)
+ def test_if_else_with_scalar(self):
+ def single_if(a, b):
+ if bool(a > 0.1):
+ a = a + b
+ else:
+ a = a - b
+ return a
- with self.assertRaisesRegex(RuntimeError, "Expected a default value"):
+ batch_if = torch.jit.batch(batch_size=4)(single_if)
- @torch.jit.script
- def hints_bad_types(x, a=10, b=0.5):
- # type: (Tensor, float, int) -> Tensor
- return x + a + b
+ a, batch_a = self.rand_batch(4, ())
+ b, batch_b = self.rand_batch(4, ())
+ res_batch = batch_if(batch_a, batch_b)
+ res = [single_if(a[j], b[j]) for j in range(4)]
+ self.assertEqual(res, res_batch.examples())
- def test_module_default_values(self):
- four = torch.tensor(4)
+ script_if = torch.jit.script(single_if)
+ graph = torch.to_batch_graph(script_if.graph)
+ self.assertExpected(canonical(graph))
- class Test(torch.jit.ScriptModule):
- def __init__(self):
- super(Test, self).__init__()
+ def test_if_noelse(self):
+ def single_if(a, b):
+ if bool(a > b):
+ a = a + b
+ return a
- @torch.jit.script_method
- def forward(self, input, other=four):
- return input + other
+ batch_if = torch.jit.batch(batch_size=4)(single_if)
- t = Test()
- self.assertEqual(t(torch.ones(1)), torch.ones(1) + 4)
+ a, batch_a = self.rand_batch(4, ())
+ b, batch_b = self.rand_batch(4, ())
+ res_batch = batch_if(batch_a, batch_b)
+ res = [single_if(a[j], b[j]) for j in range(4)]
+ self.assertEqual(res, res_batch.examples())
- def test_warnings(self):
- import warnings
+ script_if = torch.jit.script(single_if)
+ graph = torch.to_batch_graph(script_if.graph)
+ self.assertExpected(canonical(graph))
- @torch.jit.script
- def fn(x):
- if bool(x < 2):
- warnings.warn("x is less than 2")
- return x
+ def test_if_noelse_with_scalar(self):
+ def single_if(a, b):
+ if bool(a > 0.1):
+ a = a + b
+ return a
- self.assertExpectedGraph(fn.graph)
+ batch_if = torch.jit.batch(batch_size=4)(single_if)
+ a, batch_a = self.rand_batch(4, ())
+ b, batch_b = self.rand_batch(4, ())
+ res_batch = batch_if(batch_a, batch_b)
+ res = [single_if(a[j], b[j]) for j in range(4)]
+ self.assertEqual(res, res_batch.examples())
-class TestBatched(TestCase):
- # generate random examples and create an batchtensor with them
- def rand_batch(self, *dims):
- dims = [dim for dim in dims if dim != ()]
- xs = [torch.rand(1, *(random.randint(1, size) if b else size for b, size in dims[1:]),
- requires_grad=True) for i in range(dims[0])]
- xb = BatchTensor(xs, torch.tensor([b for b, d in dims[1:]]).byte())
- return xs, xb
+ script_if = torch.jit.script(single_if)
+ graph = torch.to_batch_graph(script_if.graph)
+ self.assertExpected(canonical(graph))
- def test_create_batchtensor(self):
- # create from tensorlist
- xs, batch = self.rand_batch(4, (True, 3), (False, 2), (True, 5))
- self.assertEqual(xs, batch.examples())
- # create from data, mask, dims
- batch2 = BatchTensor(batch.get_data(), batch.get_mask(), batch.get_dims())
- self.assertEqual(xs, batch2.examples())
- # expand a tensor to a batchtensor given batch_size
- xs = torch.rand(3, 4, 5)
- batch3 = BatchTensor(xs, 2)
- xs = xs.unsqueeze(0)
- self.assertEqual([xs, xs], batch3.examples())
+ def test_while(self):
+ def single_while(a, b):
+ while bool(a > b):
+ a = a - b
+ return a
- def test_batch_elementwise_unary(self):
- @torch.jit.batch(batch_size=4)
- def tanh(a):
- return torch.tanh(a)
+ batch_while = torch.jit.batch(batch_size=4)(single_while)
- xs, batch = self.rand_batch(4, (True, 3), (False, 2))
- res_batch = tanh(batch)
- res = [torch.tanh(xs[j]) for j in range(4)]
+ a, batch_a = self.rand_batch(4, ())
+ b = [torch.abs(torch.rand(1)) for i in range(4)]
+ batch_b = BatchTensor(b, torch.tensor([]).byte())
+ res_batch = batch_while(batch_a, batch_b)
+ res = [single_while(a[j], b[j]) for j in range(4)]
self.assertEqual(res, res_batch.examples())
- def test_batch_elementwise_binary(self):
- @torch.jit.batch(batch_size=4)
- def add(a, b):
- return a + b
-
- xs, batch = self.rand_batch(4, (True, 3), (False, 2))
- xs2, batch2 = xs, batch
- res_batch = add(batch, batch2)
- res = [torch.add(xs[j], xs2[j]) for j in range(4)]
- self.assertEqual(res, res_batch.examples())
+ script_while = torch.jit.script(single_while)
+ graph = torch.to_batch_graph(script_while.graph)
+ self.assertExpected(canonical(graph))
- # test broadcast
- xs, batch = self.rand_batch(4, (False, 3), (False, 2))
- b = torch.rand(3, 2)
- res_batch = add(batch, b)
- res = [torch.add(xs[j], b) for j in range(4)]
- self.assertEqual(res, res_batch.examples())
+ def test_for(self):
+ def single_for(x, y):
+ for _ in range(10):
+ x = x + y
+ return x
- def test_batch_mm(self):
- @torch.jit.batch(batch_size=4)
- def mm(a, b):
- return torch.mm(a, b)
+ batch_for = torch.jit.batch(batch_size=4)(single_for)
- xs, batch = self.rand_batch(4, (True, 3), (False, 2))
- xs2, batch2 = self.rand_batch(4, (False, 2), (True, 3))
- res_batch = mm(batch, batch2)
- res = [torch.mm(xs[j].squeeze(0), xs2[j].squeeze(0)).unsqueeze(0) for j in range(4)]
+ a, batch_a = self.rand_batch(4, ())
+ b, batch_b = self.rand_batch(4, ())
+ res_batch = batch_for(batch_a, batch_b)
+ res = [single_for(a[j], b[j]) for j in range(4)]
self.assertEqual(res, res_batch.examples())
- # test broadcast
- b = torch.rand(2, 4)
- res_batch = mm(batch, b)
- res = [torch.mm(xs[j].squeeze(0), b).unsqueeze(0) for j in range(4)]
- self.assertEqual(res, res_batch.examples())
+ script_for = torch.jit.script(single_for)
+ graph = torch.to_batch_graph(script_for.graph)
+ self.assertExpected(canonical(graph))
- def test_batch_matmul(self):
- @torch.jit.batch(batch_size=4)
- def matmul(a, b):
- return torch.matmul(a, b)
+ def test_lstm(self):
+ def LSTM(x_all, h, c, w_xi, w_xf, w_xo, w_xc, w_hi, w_hf, w_ho, w_hc, b_i, b_f, b_o, b_c):
+ for i in range(x_all.size(1)):
+ x = x_all.select(1, i)
+ i_t = torch.matmul(x, w_xi) + torch.matmul(h, w_hi) + b_i
+ f_t = torch.matmul(x, w_xf) + torch.matmul(h, w_hf) + b_f
+ o_t = torch.matmul(x, w_xo) + torch.matmul(h, w_ho) + b_o
+ # activations
+ i_t = torch.sigmoid(i_t)
+ f_t = torch.sigmoid(f_t)
+ o_t = torch.sigmoid(o_t)
+ # cell computations
+ c_t = torch.matmul(x, w_xc) + torch.matmul(h, w_hc) + b_c
+ c_t = torch.tanh(c_t)
+ c_t = torch.mul(c_t, f_t) + torch.mul(i_t, c_t)
+ h_t = torch.mul(o_t, torch.tanh(c_t))
+ h = h_t
+ c = c_t
+ return h
- def matmul_test(xs, batch, xs2, batch2):
- ys = [torch.matmul(xs[j].squeeze(0), xs2[j].squeeze(0)).unsqueeze(0) for j in range(4)]
- ybs = matmul(batch, batch2)
- self.assertEqual(ys, ybs.examples())
+ LSTM_batch = torch.jit.batch(batch_size=4)(LSTM)
- # 1 dimension * 1 dimension
- xs, batch = self.rand_batch(4, (False, 2))
- xs2, batch2 = self.rand_batch(4, (False, 2))
- matmul_test(xs, batch, xs2, batch2)
- # 1 dimension * 2 dimension
- xs, batch = self.rand_batch(4, (False, 2))
- xs2, batch2 = self.rand_batch(4, (False, 2), (True, 3))
- matmul_test(xs, batch, xs2, batch2)
- # 2 dimension * 1 dimensions
- xs, batch = self.rand_batch(4, (True, 3), (False, 2))
- xs2, batch2 = self.rand_batch(4, (False, 2))
- matmul_test(xs, batch, xs2, batch2)
- # 2 dimension * 2 dimension
- xs, batch = self.rand_batch(4, (True, 3), (False, 2))
- xs2, batch2 = self.rand_batch(4, (False, 2), (True, 3))
- matmul_test(xs, batch, xs2, batch2)
+ batch_size, input_size, hidden_size = 4, 3, 2
+ xs, batch = self.rand_batch(batch_size, (True, 4), (False, input_size))
+ hx, h_batch = self.rand_batch(batch_size, (False, hidden_size))
+ cx, c_batch = self.rand_batch(batch_size, (False, hidden_size))
- def test_batch_select(self):
- @torch.jit.batch(batch_size=4)
- def select(x):
- return torch.select(x, 1, 0)
+ # input to hidden weights
+ w_xi = torch.rand(input_size, hidden_size)
+ w_xf = torch.rand(input_size, hidden_size)
+ w_xo = torch.rand(input_size, hidden_size)
+ w_xc = torch.rand(input_size, hidden_size)
+ # hidden to hidden weights
+ w_hi = torch.rand(hidden_size, hidden_size)
+ w_hf = torch.rand(hidden_size, hidden_size)
+ w_ho = torch.rand(hidden_size, hidden_size)
+ w_hc = torch.rand(hidden_size, hidden_size)
+ # bias terms
+ b_i = torch.rand(hidden_size)
+ b_f = torch.rand(hidden_size)
+ b_o = torch.rand(hidden_size)
+ b_c = torch.rand(hidden_size)
- xs, batch = self.rand_batch(4, (True, 3), (True, 2))
- res_batch = select(batch)
- res = [torch.select(xs[j], 1, 0) for j in range(4)]
- self.assertEqual(res, res_batch.examples())
+ ys = [LSTM(xs[j], hx[j], cx[j], w_xi, w_xf, w_xo, w_xc,
+ w_hi, w_hf, w_ho, w_hc, b_i, b_f, b_o, b_c) for j in range(batch_size)]
+ ybs = LSTM_batch(batch, h_batch, c_batch, w_xi, w_xf, w_xo, w_xc,
+ w_hi, w_hf, w_ho, w_hc, b_i, b_f, b_o, b_c)
+ self.assertEqual(ys, ybs.examples())
- xs, batch = self.rand_batch(4, (False, 3), (True, 2))
- res_batch = select(batch)
- res = [torch.select(xs[j], 1, 0) for j in range(4)]
- self.assertEqual(res, res_batch.examples())
+ def test_greedy_search(self):
+ def greedy(x, h, c, embed, w_xi, w_xf, w_xo, w_xc, w_hi, w_hf, w_ho, w_hc,
+ b_i, b_f, b_o, b_c, w_hs, b_s, iter_num):
+ iter_count = torch.zeros_like(iter_num)
+ while bool(iter_count < iter_num):
+ iter_count = iter_count + 1
+ # LSTM Cell
+ i_t = torch.matmul(x, w_xi) + torch.matmul(h, w_hi) + b_i
+ f_t = torch.matmul(x, w_xf) + torch.matmul(h, w_hf) + b_f
+ o_t = torch.matmul(x, w_xo) + torch.matmul(h, w_ho) + b_o
+ # activations
+ i_t = torch.sigmoid(i_t)
+ f_t = torch.sigmoid(f_t)
+ o_t = torch.sigmoid(o_t)
+ # cell computations
+ c_t = torch.matmul(x, w_xc) + torch.matmul(h, w_hc) + b_c
+ c_t = torch.tanh(c_t)
+ c_t = torch.mul(c_t, f_t) + torch.mul(i_t, c_t)
+ h_t = torch.mul(o_t, torch.tanh(c_t))
+ h = h_t
+ c = c_t
+ # calculate feature with max probability
+ s_t = torch.matmul(h_t, w_hs) + b_s
+ p_t = torch.softmax(s_t, 1)
+ i_t = torch.argmax(p_t, 1)
+ x = embed.index_select(1, i_t).squeeze(1)
+ return h
- def test_batch_index_select(self):
- @torch.jit.batch(batch_size=4)
- def index_select(x, ind):
- return x.index_select(1, ind)
-
- xs, batch = self.rand_batch(4, (False, 5), (True, 2))
- ind = [torch.randint(0, 4, (1,), dtype=torch.long) for i in range(4)]
- ind_batch = BatchTensor(ind, torch.tensor([]).byte())
- res_batch = index_select(batch, ind_batch)
- res = [torch.index_select(xs[j], 1, ind[j]) for j in range(4)]
- self.assertEqual(res, res_batch.examples())
-
- def test_batch_where(self):
- @torch.jit.batch(batch_size=4)
- def where(c, a, b):
- return torch.where(c, a, b)
-
- xs, batch = self.rand_batch(4, (False, 3), (False, 2))
- xs2, batch2 = self.rand_batch(4, (False, 3), (False, 2))
-
- dims = [4, (False, 3), (False, 2)]
- xs_cond = [torch.rand(1, 3, 2).byte() for i in range(dims[0])]
- batch_cond = BatchTensor(xs_cond, torch.tensor([b for b, d in dims[1:]]))
-
- res_batch = where(batch_cond, batch, batch2)
- res = [torch.where(xs_cond[j], xs[j], xs2[j]) for j in range(4)]
- self.assertEqual(res, res_batch.examples())
-
- def test_batch_argmax(self):
- @torch.jit.batch(batch_size=4)
- def argmax(a):
- return torch.argmax(a, 1)
-
- xs, batch = self.rand_batch(4, (True, 5), (True, 6))
- res_batch = argmax(batch)
- res = [torch.argmax(xs[j], 1) for j in range(4)]
- self.assertEqual(res, res_batch.examples())
-
- @torch.jit.batch(batch_size=4)
- def argmax(a):
- return torch.argmax(a, 1, False)
-
- res_batch = argmax(batch)
- res = [torch.argmax(xs[j], 1, False) for j in range(4)]
- self.assertEqual(res, res_batch.examples())
-
- def test_batch_topk(self):
- @torch.jit.batch(batch_size=4)
- def topk(a):
- return torch.topk(a, 3, 1)
-
- xs, batch = self.rand_batch(4, (False, 5), (True, 6))
-
- # along static dim
- res_batch = topk(batch)
- res = [torch.topk(xs[j], 3, 1)[0] for j in range(4)]
- res_idx = [torch.topk(xs[j], 3, 1)[1] for j in range(4)]
- self.assertEqual(res, res_batch[0].examples())
- self.assertEqual(res_idx, res_batch[1].examples())
-
- @torch.jit.batch(batch_size=4)
- def topk(a):
- return torch.topk(a, 1, 2)
-
- # along dynamic dim
- res_batch = topk(batch)
- res = [torch.topk(xs[j], 1, 2)[0] for j in range(4)]
- res_idx = [torch.topk(xs[j], 1, 2)[1] for j in range(4)]
- self.assertEqual(res, res_batch[0].examples())
- self.assertEqual(res_idx, res_batch[1].examples())
-
- def test_batch_softmax(self):
- @torch.jit.batch(batch_size=4)
- def softmax(a):
- return torch.softmax(a, 1)
-
- xs, batch = self.rand_batch(4, (False, 5), (True, 6))
-
- # along static dim
- res_batch = softmax(batch)
- res = [torch.softmax(xs[j], 1) for j in range(4)]
- self.assertEqual(res, res_batch.examples())
-
- @torch.jit.batch(batch_size=4)
- def softmax(a):
- return torch.softmax(a, 2)
-
- # along dynamic dim
- res_batch = softmax(batch)
- res = [torch.softmax(xs[j], 2) for j in range(4)]
- self.assertEqual(res, res_batch.examples())
-
- def test_batch_view(self):
- @torch.jit.batch(batch_size=4)
- def view(a):
- return a.view([4, -1, 3])
-
- xs, batch = self.rand_batch(4, (True, 5), (False, 3))
- res_batch = view(batch)
- res = [xs[j].view([1, -1, 3]) for j in range(4)]
- self.assertEqual(res, res_batch.examples())
-
- def test_batch_cat(self):
- @torch.jit.batch(batch_size=4)
- def cat2(a, b):
- return torch.cat([a, b], 2)
-
- xs, batch = self.rand_batch(4, (True, 5), (False, 3))
- xs2, batch2 = xs, batch
- res_batch = cat2(batch, batch2)
- res = [torch.cat([xs[j], xs2[j]], 2) for j in range(4)]
- self.assertEqual(res, res_batch.examples())
-
- def test_batch_sum(self):
- @torch.jit.batch(batch_size=4)
- def batch_sum(a):
- return a.sum()
-
- xs, batch = self.rand_batch(4, (True, 5), (False, 3))
- res_batch = batch_sum(batch)
- res = [xs[j].sum().unsqueeze(0) for j in range(4)]
- self.assertEqual(res, res_batch.examples())
-
- def test_if_else(self):
- def single_if(a, b):
- if bool(a > b):
- a = a + b
- else:
- a = a - b
- return a
-
- batch_if = torch.jit.batch(batch_size=4)(single_if)
-
- a, batch_a = self.rand_batch(4, ())
- b, batch_b = self.rand_batch(4, ())
- res_batch = batch_if(batch_a, batch_b)
- res = [single_if(a[j], b[j]) for j in range(4)]
- self.assertEqual(res, res_batch.examples())
-
- script_if = torch.jit.script(single_if)
- graph = torch.to_batch_graph(script_if.graph)
- self.assertExpected(canonical(graph))
-
- def test_if_else_with_scalar(self):
- def single_if(a, b):
- if bool(a > 0.1):
- a = a + b
- else:
- a = a - b
- return a
-
- batch_if = torch.jit.batch(batch_size=4)(single_if)
-
- a, batch_a = self.rand_batch(4, ())
- b, batch_b = self.rand_batch(4, ())
- res_batch = batch_if(batch_a, batch_b)
- res = [single_if(a[j], b[j]) for j in range(4)]
- self.assertEqual(res, res_batch.examples())
-
- script_if = torch.jit.script(single_if)
- graph = torch.to_batch_graph(script_if.graph)
- self.assertExpected(canonical(graph))
-
- def test_if_noelse(self):
- def single_if(a, b):
- if bool(a > b):
- a = a + b
- return a
-
- batch_if = torch.jit.batch(batch_size=4)(single_if)
-
- a, batch_a = self.rand_batch(4, ())
- b, batch_b = self.rand_batch(4, ())
- res_batch = batch_if(batch_a, batch_b)
- res = [single_if(a[j], b[j]) for j in range(4)]
- self.assertEqual(res, res_batch.examples())
-
- script_if = torch.jit.script(single_if)
- graph = torch.to_batch_graph(script_if.graph)
- self.assertExpected(canonical(graph))
-
- def test_if_noelse_with_scalar(self):
- def single_if(a, b):
- if bool(a > 0.1):
- a = a + b
- return a
-
- batch_if = torch.jit.batch(batch_size=4)(single_if)
-
- a, batch_a = self.rand_batch(4, ())
- b, batch_b = self.rand_batch(4, ())
- res_batch = batch_if(batch_a, batch_b)
- res = [single_if(a[j], b[j]) for j in range(4)]
- self.assertEqual(res, res_batch.examples())
-
- script_if = torch.jit.script(single_if)
- graph = torch.to_batch_graph(script_if.graph)
- self.assertExpected(canonical(graph))
-
- def test_while(self):
- def single_while(a, b):
- while bool(a > b):
- a = a - b
- return a
-
- batch_while = torch.jit.batch(batch_size=4)(single_while)
-
- a, batch_a = self.rand_batch(4, ())
- b = [torch.abs(torch.rand(1)) for i in range(4)]
- batch_b = BatchTensor(b, torch.tensor([]).byte())
- res_batch = batch_while(batch_a, batch_b)
- res = [single_while(a[j], b[j]) for j in range(4)]
- self.assertEqual(res, res_batch.examples())
-
- script_while = torch.jit.script(single_while)
- graph = torch.to_batch_graph(script_while.graph)
- self.assertExpected(canonical(graph))
-
- def test_for(self):
- def single_for(x, y):
- for _ in range(10):
- x = x + y
- return x
-
- batch_for = torch.jit.batch(batch_size=4)(single_for)
-
- a, batch_a = self.rand_batch(4, ())
- b, batch_b = self.rand_batch(4, ())
- res_batch = batch_for(batch_a, batch_b)
- res = [single_for(a[j], b[j]) for j in range(4)]
- self.assertEqual(res, res_batch.examples())
-
- script_for = torch.jit.script(single_for)
- graph = torch.to_batch_graph(script_for.graph)
- self.assertExpected(canonical(graph))
-
- def test_lstm(self):
- def LSTM(x_all, h, c, w_xi, w_xf, w_xo, w_xc, w_hi, w_hf, w_ho, w_hc, b_i, b_f, b_o, b_c):
- for i in range(x_all.size(1)):
- x = x_all.select(1, i)
- i_t = torch.matmul(x, w_xi) + torch.matmul(h, w_hi) + b_i
- f_t = torch.matmul(x, w_xf) + torch.matmul(h, w_hf) + b_f
- o_t = torch.matmul(x, w_xo) + torch.matmul(h, w_ho) + b_o
- # activations
- i_t = torch.sigmoid(i_t)
- f_t = torch.sigmoid(f_t)
- o_t = torch.sigmoid(o_t)
- # cell computations
- c_t = torch.matmul(x, w_xc) + torch.matmul(h, w_hc) + b_c
- c_t = torch.tanh(c_t)
- c_t = torch.mul(c_t, f_t) + torch.mul(i_t, c_t)
- h_t = torch.mul(o_t, torch.tanh(c_t))
- h = h_t
- c = c_t
- return h
-
- LSTM_batch = torch.jit.batch(batch_size=4)(LSTM)
-
- batch_size, input_size, hidden_size = 4, 3, 2
- xs, batch = self.rand_batch(batch_size, (True, 4), (False, input_size))
- hx, h_batch = self.rand_batch(batch_size, (False, hidden_size))
- cx, c_batch = self.rand_batch(batch_size, (False, hidden_size))
-
- # input to hidden weights
- w_xi = torch.rand(input_size, hidden_size)
- w_xf = torch.rand(input_size, hidden_size)
- w_xo = torch.rand(input_size, hidden_size)
- w_xc = torch.rand(input_size, hidden_size)
- # hidden to hidden weights
- w_hi = torch.rand(hidden_size, hidden_size)
- w_hf = torch.rand(hidden_size, hidden_size)
- w_ho = torch.rand(hidden_size, hidden_size)
- w_hc = torch.rand(hidden_size, hidden_size)
- # bias terms
- b_i = torch.rand(hidden_size)
- b_f = torch.rand(hidden_size)
- b_o = torch.rand(hidden_size)
- b_c = torch.rand(hidden_size)
-
- ys = [LSTM(xs[j], hx[j], cx[j], w_xi, w_xf, w_xo, w_xc,
- w_hi, w_hf, w_ho, w_hc, b_i, b_f, b_o, b_c) for j in range(batch_size)]
- ybs = LSTM_batch(batch, h_batch, c_batch, w_xi, w_xf, w_xo, w_xc,
- w_hi, w_hf, w_ho, w_hc, b_i, b_f, b_o, b_c)
- self.assertEqual(ys, ybs.examples())
-
- def test_greedy_search(self):
- def greedy(x, h, c, embed, w_xi, w_xf, w_xo, w_xc, w_hi, w_hf, w_ho, w_hc,
- b_i, b_f, b_o, b_c, w_hs, b_s, iter_num):
- iter_count = torch.zeros_like(iter_num)
- while bool(iter_count < iter_num):
- iter_count = iter_count + 1
- # LSTM Cell
- i_t = torch.matmul(x, w_xi) + torch.matmul(h, w_hi) + b_i
- f_t = torch.matmul(x, w_xf) + torch.matmul(h, w_hf) + b_f
- o_t = torch.matmul(x, w_xo) + torch.matmul(h, w_ho) + b_o
- # activations
- i_t = torch.sigmoid(i_t)
- f_t = torch.sigmoid(f_t)
- o_t = torch.sigmoid(o_t)
- # cell computations
- c_t = torch.matmul(x, w_xc) + torch.matmul(h, w_hc) + b_c
- c_t = torch.tanh(c_t)
- c_t = torch.mul(c_t, f_t) + torch.mul(i_t, c_t)
- h_t = torch.mul(o_t, torch.tanh(c_t))
- h = h_t
- c = c_t
- # calculate feature with max probability
- s_t = torch.matmul(h_t, w_hs) + b_s
- p_t = torch.softmax(s_t, 1)
- i_t = torch.argmax(p_t, 1)
- x = embed.index_select(1, i_t).squeeze(1)
- return h
-
- greedy_batch = torch.jit.batch(batch_size=4)(greedy)
+ greedy_batch = torch.jit.batch(batch_size=4)(greedy)
batch_size, input_size, hidden_size, vocab_size = 4, 6, 8, 7
xs, batch = self.rand_batch(batch_size, (False, input_size))
ge = torch.jit.script(script, optimize)
ge(*inputs)
- def checkScript(self,
- script,
- inputs,
- optimize=True,
- outputs=None,
- name='func',
- capture_output=False,
- frames_up=1,
- check_expected=False):
- if isinstance(script, str):
- cu = torch.jit.CompilationUnit(script, optimize, _frames_up=frames_up)
- ge = getattr(cu, name)
- else:
- if capture_output:
- with self.capture_stdout() as captured:
- outputs = script(*inputs)
- else:
- outputs = script(*inputs)
- # Check the string frontend first
- source = textwrap.dedent(inspect.getsource(script))
- self.checkScript(
- source,
- inputs,
- optimize,
- outputs,
- script.__name__,
- capture_output,
- frames_up=2,
- check_expected=check_expected)
- # Continue checking the Python frontend
- ge = torch.jit.script(script, optimize, _frames_up=1)
-
- if capture_output:
- with self.capture_stdout() as captured:
- outputs_ge = ge(*inputs)
- if not WINDOWS:
- self.assertExpected(captured[0], subname='stdout')
- else:
- outputs_ge = ge(*inputs)
- self.assertEqual(outputs, outputs_ge)
-
- if check_expected:
- self.assertExpectedGraph(ge.graph)
-
- return ge
-
def test_training_param(self):
class What(torch.jit.ScriptModule):
@torch.jit.script_method
b = torch.rand(1, requires_grad=True)
self.checkScript(func, (a, b), optimize=True)
- @unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
- @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
- @skipIfRocm
- def test_clamp_fusion(self):
- def func2(a, b):
- return torch.clamp(a + b, min=0, max=2)
-
- def funcInf(a, b):
- return torch.clamp(a + b, min=0, max=float('inf'))
-
- a = torch.randn(4, 4, dtype=torch.float, device='cuda', requires_grad=True)
- b = torch.randn(4, 4, dtype=torch.float, device='cuda')
-
- funcs = (func2, funcInf)
- for f in funcs:
- s = self.checkScript(f, (a, b))
- self.assertAllFused(s.graph_for(a, b), except_for={'aten::size'})
-
- c = s(a, b)
- c.sum().backward()
- graph = backward_graph(s)
- self.assertAllFused(graph, except_for={'prim::SumToSize'})
-
def test_mul(self):
def func(a, b):
return a * b
func2.graph, (torch.zeros(1, 1, 1, 1, 4),), False)
self.assertExpected(canonical(func2.graph), subname='2')
- @unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
- @unittest.skipIf(not RUN_CUDA, "No CUDA")
- @skipIfRocm
- def test_chunk_fusion_cuda(self):
- def fn(x):
- a, b, c = x.chunk(3, 1)
- return a * b + c
-
- inputs = [torch.randn(10, 6, dtype=torch.float, device='cuda')]
-
- self.checkScript(fn, inputs)
-
- fn_script = torch.jit.script(fn)
- _ = fn_script(*inputs)
- self.assertExpectedGraph(fn_script.graph_for(*inputs))
-
- @unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
- @unittest.skipIf(not RUN_CUDA, "No CUDA")
- @skipIfRocm
- def test_chunk_multiple_fusion_cuda(self):
- # The arguments are intentionally used out of order as a test to see
- # if the fusion compiler adds extra args in the correct order
- def fn(s, x, y, z):
- z1, z2 = z.chunk(2, 2)
- x1, x2, x3 = x.chunk(3, 1)
- y1, y2 = y.chunk(2, 0)
- return s + x1 + x2 + x3 + y1 + y2 + z1 + z2
-
- inputs = [
- torch.randn(5, 2, 3, dtype=torch.float, device='cuda'),
- torch.randn(5, 6, 3, dtype=torch.float, device='cuda'),
- torch.randn(10, 2, 3, dtype=torch.float, device='cuda'),
- torch.randn(5, 2, 6, dtype=torch.float, device='cuda'),
- ]
-
- self.checkScript(fn, inputs)
-
- fn_script = torch.jit.script(fn)
- _ = fn_script(*inputs)
- self.assertExpectedGraph(fn_script.graph_for(*inputs))
-
- @staticmethod
- def _test_chunk_fusion_correctness(self, device='cpu'):
- def chunk_4_0(x):
- x0, x1, x2, x3 = x.chunk(4, 0)
- return x0 + x1 + x2 + x3
-
- def chunk_4_1(x):
- x0, x1, x2, x3 = x.chunk(4, 1)
- return x0 + x1 + x2 + x3
-
- def chunk_4_last(x):
- x0, x1, x2, x3 = x.chunk(4, 2)
- return x0 + x1 + x2 + x3
-
- fns = [chunk_4_0, chunk_4_1, chunk_4_last]
- tensors = [
- # splitSize = 1
- torch.randn(4, 4, 4, dtype=torch.float, device=device),
-
- # contiguous case
- torch.randn(12, 8, 16, dtype=torch.float, device=device),
-
- # non-contiguous case
- torch.randn(12, 8, 16, dtype=torch.float, device=device).transpose(1, 2),
- ]
-
- for tensor in tensors:
- for fn in fns:
- self.checkScript(fn, [tensor])
-
- @unittest.skipIf(IS_WINDOWS or IS_SANDCASTLE, "NYI: fuser support for Windows or Sandcastle")
- @skipIfRocm
- @enable_cpu_fuser
- def test_chunk_fusion_correctness(self):
- return self._test_chunk_fusion_correctness(self, 'cpu')
-
- @unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
- @unittest.skipIf(not RUN_CUDA, "No CUDA")
- @skipIfRocm
- def test_chunk_fusion_correctness_cuda(self):
- return self._test_chunk_fusion_correctness(self, 'cuda')
-
def test_cat(self):
@torch.jit.script
def func(x):
self.checkScript(func2, ())
- @unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
- @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
- @skipIfRocm
- def test_tensor_scalar_fusion_cuda(self):
- def should_fuse(x):
- z = 3.
- y = x + z
- return x * y
-
- # XXX: right now we only support fusing scalars if
- # they're constant (#9940)
- def should_not_fuse(x, z):
- y = x + int(z)
- return x * y
-
- inputs = [torch.randn(2, 2, dtype=torch.float, device='cuda')]
- ge = self.checkScript(should_fuse, inputs)
- self.assertExpectedGraph(ge.graph_for(*inputs), subname='1')
-
- inputs = [
- torch.randn(2, 2, dtype=torch.float, device='cuda'),
- torch.tensor(3., dtype=torch.float, device='cuda'),
- ]
- ge = self.checkScript(should_not_fuse, inputs)
- self.assertExpectedGraph(ge.graph_for(*inputs), subname='2')
-
def test_list_ops(self):
def test_equality():
a = [1, 2, 3]
# NOTE: cannot optimize yet because broadcasts are not inserted before the fuser runs
self.checkScript(script, [alpha, beta, x, y], optimize=False, outputs=outputs)
- def test_view_shape_prop(self):
- cu = torch.jit.CompilationUnit('''
- def test_view_shape_prop(a):
- return a.view(size=[-1])
- ''')
- inputs = [torch.zeros(10, 10)]
- outputs = torch.zeros(100)
-
- real_outs = cu.test_view_shape_prop(*inputs)
- self.assertEqual(real_outs, outputs)
-
- def test_view_listconstruct_shape_prop(self):
- def fn(x):
- B = x.size(0)
- C = x.size(1)
- T = x.size(2)
- return x.view(T, B, C)
-
- x = torch.randn(3, 1, 5, requires_grad=True)
- graph = torch.jit.script(fn).graph
- torch._C._jit_pass_shape_analysis(graph, (x,), False)
- self.assertTrue(next(graph.outputs()).type().kind() != 'DynamicType')
-
- def test_integral_shape_inference(self):
- cu = torch.jit.CompilationUnit('''
- def test_integral_shape_inference(a):
- return a / a
- ''')
- inputs = [torch.ones(10, 10).type(torch.LongTensor)]
- outputs = torch.ones(10, 10)
-
- self.assertEqual(cu.test_integral_shape_inference(*inputs), outputs)
-
- def test_fuser_multiple_blocks(self):
- cu = torch.jit.CompilationUnit('''
- def test_fuser_multiple_blocks(this, that, theother, meme):
- i = 0
- while i < 20:
- this = torch.cat([this, meme], dim=0)
- that = torch.cat([that, meme], dim=0)
- theother = torch.cat([theother, meme], dim=0)
- i = i + 1
- return this, that, theother
- ''')
-
- inputs = [torch.ones(0, 10, 10)] * 3
- inputs += [torch.ones(1, 10, 10)]
- outputs = [torch.ones(20, 10, 10)] * 3
-
- self.assertEqual(cu.test_fuser_multiple_blocks(*inputs), outputs)
-
- @unittest.skipIf(IS_WINDOWS or IS_SANDCASTLE, "NYI: fuser support for Windows or Sandcastle")
- @enable_cpu_fuser
- def test_scalar_fusion(self):
- def fn(x, y):
- return 2 * x + y
-
- x = torch.tensor(0.1, dtype=torch.float, device='cpu')
- y = torch.tensor(1, dtype=torch.float, device='cpu')
- ge = self.checkScript(fn, (x, y))
- self.assertExpectedGraph(ge.graph_for(x, y))
-
- @unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
- @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
- @skipIfRocm
- def test_fusion_chunk_motion_deduplicates_inputs(self):
- def func1(x):
- z = x * x
- z0, z1 = z.chunk(2)
- return z0 * z1
-
- def func2(x):
- z = x * x * x
- z0, z1 = z.chunk(2)
- return z0 * z1
-
- inputs = [
- torch.tensor([1.1, 1.2], device='cuda', dtype=torch.float),
- ]
- for func in [func1, func2]:
- module = self.checkScript(func, inputs)
- forward_graph = module.graph_for(*inputs)
- self.assertGraphContainsExactly(forward_graph, 'prim::FusionGroup', 1)
- fusion_group = list(forward_graph.nodes())[-1]
- self.assertEqual(len(list(fusion_group.inputs())), 1)
-
- @unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
- @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
- @skipIfRocm
- def test_lstm_gates_permutations_fusion_cuda(self):
- # lstm has gates = x.mm(w_ih.t()) + hx.mm(w_hh.t()) + b_ih + b_hh.
- # Test that any permutation of this will still result in one FusionGroup.
- choices = ['x.mm(w_ih.t())', 'hx.mm(w_hh.t())', 'b_ih', 'b_hh']
- template = dedent('''
- def cell(x, hx, cx, w_ih, w_hh, b_ih, b_hh):
- gates = {} + {} + {} + {}
- ingate, forgetgate, cellgate, outgate = gates.chunk(4, 1)
- return ingate * forgetgate * cellgate * outgate
+ def test_view_shape_prop(self):
+ cu = torch.jit.CompilationUnit('''
+ def test_view_shape_prop(a):
+ return a.view(size=[-1])
''')
- for permutation in itertools.permutations(choices, len(choices)):
- code = template.format(*permutation)
- scope = {}
- exec(code, globals(), scope)
- cu = torch.jit.CompilationUnit(code)
+ inputs = [torch.zeros(10, 10)]
+ outputs = torch.zeros(100)
- inputs = get_lstm_inputs('cuda', training=False)
- self.assertEqual(cu.cell(*inputs), scope['cell'](*inputs))
- forward_graph = cu.cell.graph_for(*inputs)
- self.assertGraphContainsExactly(forward_graph, 'prim::FusionGroup', 1)
+ real_outs = cu.test_view_shape_prop(*inputs)
+ self.assertEqual(real_outs, outputs)
- @unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
- @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
- @skipIfRocm
- def test_lstm_fusion_cuda(self):
- inputs = get_lstm_inputs('cuda', training=True)
- module = self.checkScript(LSTMCellS, inputs)
- forward_graph = module.graph_for(*inputs)
- self.assertGraphContainsExactly(
- forward_graph, 'prim::FusionGroup', 1, consider_subgraphs=True)
- self.assertExpectedGraph(forward_graph, subname='forward')
+ def test_view_listconstruct_shape_prop(self):
+ def fn(x):
+ B = x.size(0)
+ C = x.size(1)
+ T = x.size(2)
+ return x.view(T, B, C)
- hy, cy = module(*inputs)
- (hy + cy).sum().backward()
- self.assertExpectedGraph(backward_graph(module), subname='backward')
+ x = torch.randn(3, 1, 5, requires_grad=True)
+ graph = torch.jit.script(fn).graph
+ torch._C._jit_pass_shape_analysis(graph, (x,), False)
+ self.assertTrue(next(graph.outputs()).type().kind() != 'DynamicType')
- @unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
- @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
- @skipIfRocm
- def test_milstm_fusion_cuda(self):
- inputs = get_milstm_inputs('cuda', training=True)
- module = self.checkScript(MiLSTMCell, inputs)
- forward_graph = module.graph_for(*inputs)
- self.assertGraphContainsExactly(
- forward_graph, 'prim::FusionGroup', 1, consider_subgraphs=True)
- self.assertExpectedGraph(forward_graph, subname='forward')
+ def test_integral_shape_inference(self):
+ cu = torch.jit.CompilationUnit('''
+ def test_integral_shape_inference(a):
+ return a / a
+ ''')
+ inputs = [torch.ones(10, 10).type(torch.LongTensor)]
+ outputs = torch.ones(10, 10)
- hy, cy = module(*inputs)
- (hy + cy).sum().backward()
- self.assertExpectedGraph(backward_graph(module), subname='backward')
+ self.assertEqual(cu.test_integral_shape_inference(*inputs), outputs)
+
+ def test_fuser_multiple_blocks(self):
+ cu = torch.jit.CompilationUnit('''
+ def test_fuser_multiple_blocks(this, that, theother, meme):
+ i = 0
+ while i < 20:
+ this = torch.cat([this, meme], dim=0)
+ that = torch.cat([that, meme], dim=0)
+ theother = torch.cat([theother, meme], dim=0)
+ i = i + 1
+ return this, that, theother
+ ''')
+
+ inputs = [torch.ones(0, 10, 10)] * 3
+ inputs += [torch.ones(1, 10, 10)]
+ outputs = [torch.ones(20, 10, 10)] * 3
+
+ self.assertEqual(cu.test_fuser_multiple_blocks(*inputs), outputs)
def test_dropout_script(self):
m.eval()
self.assertEqual(m(), 1)
- def test_script_module_for(self):
+ def test_script_module_for(self):
+ class M(torch.jit.ScriptModule):
+ __constants__ = ['b']
+
+ def __init__(self):
+ super(M, self).__init__(False)
+ self.b = [1, 2, 3, 4]
+
+ @torch.jit.script_method
+ def forward(self):
+ sum = 0
+ for i in self.b:
+ sum += i
+ return sum
+
+ m = M()
+ self.assertEqual(m(), 10)
+
+ def test_script_module_for2(self):
+ class Sub(torch.jit.ScriptModule):
+ def __init__(self):
+ super(Sub, self).__init__(False)
+ self.weight = nn.Parameter(torch.randn(2))
+
+ @torch.jit.script_method
+ def forward(self, thing):
+ return self.weight + thing
+
+ class M(torch.jit.ScriptModule):
+ __constants__ = ['mods']
+
+ def __init__(self):
+ super(M, self).__init__(False)
+ self.mods = nn.ModuleList([Sub() for i in range(10)])
+
+ @torch.jit.script_method
+ def forward(self, v):
+ for m in self.mods:
+ v = m(v)
+ return v
+
+ i = torch.Tensor(2)
+ m = M()
+ o = m(i)
+ v = i
+ for sub in m.mods:
+ v = sub(v)
+ self.assertEqual(o, v)
+
+ def test_script_module_const_submodule_fail(self):
+ class Sub(torch.jit.ScriptModule):
+ def __init__(self):
+ super(Sub, self).__init__(False)
+ self.weight = nn.Parameter(torch.randn(2))
+
+ @torch.jit.script_method
+ def forward(self, thing):
+ return self.weight + thing
+
+ class M(torch.jit.ScriptModule):
+ def __init__(self):
+ super(M, self).__init__(False)
+ self.mods = [Sub() for _ in range(10)]
+
+ @torch.jit.script_method
+ def forward(self):
+ for _ in self.mods:
+ print(1)
+ return 4
+
+ with self.assertRaisesRegex(RuntimeError, "did you forget to add it __constants__"):
+ M()
+
+ class DerivedStateModule(torch.jit.ScriptModule):
+ def __init__(self):
+ super(TestScript.DerivedStateModule, self).__init__()
+ self.param = torch.nn.Parameter(torch.ones(3, 4, dtype=torch.float))
+ self.register_buffer('derived', torch.neg(self.param).detach())
+
+ # This is a flag so we can test that the pack method was called
+ self.register_buffer('pack_called', torch.zeros(1, dtype=torch.long))
+ # This is a flag so we can test that the unpack method was called
+ self.register_buffer('unpack_called', torch.zeros(1, dtype=torch.long))
+
+ @torch.jit.script_method
+ def _pack(self):
+ self.pack_called.set_(torch.ones(1, dtype=torch.long))
+ self.derived.set_(torch.rand(1, dtype=torch.float).detach())
+
+ @torch.jit.script_method
+ def _unpack(self):
+ self.unpack_called.set_(torch.ones(1, dtype=torch.long))
+ self.derived.set_(torch.neg(self.param).detach())
+
+ @torch.jit.script_method
+ def forward(self, x):
+ return x + self.derived
+
+ def test_pack_unpack_state(self):
+ sm = TestScript.DerivedStateModule()
+ x = torch.rand(3, 4, dtype=torch.float)
+ torch.testing.assert_allclose(sm(x), x + torch.neg(torch.ones(3, 4, dtype=torch.float)))
+
+ # Test save path
+ self.assertFalse(sm.pack_called.item())
+ self.assertFalse(sm.unpack_called.item())
+ sm.apply(lambda s: s._pack())
+ imported = self.getExportImportCopy(sm)
+ sm.apply(lambda s: s._unpack())
+ imported.apply(lambda s: s._unpack())
+ # ensure pack was called before serialization
+ self.assertTrue(sm.pack_called.item())
+ # ensure unpack was called after serialization so as to leave the module in an initialized state
+ self.assertTrue(sm.unpack_called.item())
+
+ torch.testing.assert_allclose(sm.derived, torch.neg(sm.param))
+
+ # Test load paths
+ self.assertTrue(imported.unpack_called.item())
+ torch.testing.assert_allclose(imported(x), x + torch.neg(torch.ones(3, 4, dtype=torch.float)))
+
+ def test_pack_unpack_nested(self):
+ class SubSubMod(torch.jit.ScriptModule):
+ def __init__(self):
+ super(SubSubMod, self).__init__()
+ self.register_buffer('buf', torch.ones(3, 4) * 3)
+
+ @torch.jit.script_method
+ def _pack(self):
+ self.buf.set_(torch.zeros(1, dtype=torch.double))
+
+ @torch.jit.script_method
+ def _unpack(self):
+ self.buf.set_(torch.ones(3, 4, dtype=torch.double) * 3)
+
+ @torch.jit.script_method
+ def forward(self, x):
+ return x + self.buf
+
+ class SubMod(torch.jit.ScriptModule):
+ def __init__(self):
+ super(SubMod, self).__init__()
+ self.register_buffer('buf', torch.ones(3, 4) * 2)
+ self.ssm = SubSubMod()
+
+ @torch.jit.script_method
+ def _pack(self):
+ self.buf.set_(torch.zeros(1, dtype=torch.double))
+
+ @torch.jit.script_method
+ def _unpack(self):
+ self.buf.set_(torch.ones(3, 4, dtype=torch.double) * 2)
+
+ @torch.jit.script_method
+ def forward(self, x):
+ return self.ssm(x + self.buf)
+
+ class Mod(torch.jit.ScriptModule):
+ def __init__(self):
+ super(Mod, self).__init__()
+ self.submod = SubMod()
+ self.register_buffer('buf', torch.ones(3, 4) * 1)
+
+ @torch.jit.script_method
+ def _pack(self):
+ self.buf.set_(torch.zeros(1, dtype=torch.double))
+
+ @torch.jit.script_method
+ def _unpack(self):
+ self.buf.set_(torch.ones(3, 4, dtype=torch.double))
+
+ @torch.jit.script_method
+ def forward(self, x):
+ return self.submod(x + self.buf)
+
+ m = Mod()
+ torch.testing.assert_allclose(m(torch.zeros(3, 4)), torch.ones(3, 4) * 6)
+ m.apply(lambda s: s._pack())
+ torch.testing.assert_allclose(m(torch.zeros(3, 4)), torch.zeros(3, 4))
+ m.apply(lambda s: s._unpack())
+ torch.testing.assert_allclose(m(torch.zeros(3, 4)), torch.ones(3, 4) * 6)
+
+ def test_script_module_not_tuple(self):
class M(torch.jit.ScriptModule):
- __constants__ = ['b']
+ __constants__ = ['mods']
def __init__(self):
super(M, self).__init__(False)
- self.b = [1, 2, 3, 4]
+ self.mods = 1
@torch.jit.script_method
- def forward(self):
- sum = 0
- for i in self.b:
- sum += i
- return sum
-
- m = M()
- self.assertEqual(m(), 10)
+ def forward(self, v):
+ for m in self.mods:
+ print(m)
+ return v
+ with self.assertRaisesRegex(RuntimeError, "cannot be used as a tuple"):
+ M()
- def test_script_module_for2(self):
+ def test_script_sequential_for(self):
class Sub(torch.jit.ScriptModule):
def __init__(self):
super(Sub, self).__init__(False)
def __init__(self):
super(M, self).__init__(False)
- self.mods = nn.ModuleList([Sub() for i in range(10)])
+ self.mods = nn.Sequential(Sub(), Sub(), Sub())
@torch.jit.script_method
def forward(self, v):
v = m(v)
return v
+ @torch.jit.script_method
+ def forward2(self, v):
+ return self.mods(v)
+
i = torch.Tensor(2)
m = M()
o = m(i)
v = sub(v)
self.assertEqual(o, v)
- def test_script_module_const_submodule_fail(self):
+ o2 = m.forward2(i)
+ self.assertEqual(o2, v)
+
+ def test_script_sequential_multi_output_fail(self):
class Sub(torch.jit.ScriptModule):
def __init__(self):
super(Sub, self).__init__(False)
def forward(self, thing):
return self.weight + thing
+ class ReturnMulti(torch.jit.ScriptModule):
+ def __init__(self):
+ super(ReturnMulti, self).__init__(False)
+
+ @torch.jit.script_method
+ def forward(self, x):
+ return x, x, x
+
+ class HaveSequential(torch.jit.ScriptModule):
+ __constants__ = ['someseq']
+
+ def __init__(self):
+ super(HaveSequential, self).__init__(False)
+ self.someseq = nn.Sequential(
+ Sub(),
+ ReturnMulti(),
+ Sub()
+ )
+
+ @torch.jit.script_method
+ def forward(self, x):
+ return self.someseq(x)
+
+ with self.assertRaisesRegex(RuntimeError, "(Tensor, Tensor, Tensor)"):
+ hs = HaveSequential()
+ i = torch.Tensor(2)
+ hs(i)
+
+ def test_constant_as_attr(self):
class M(torch.jit.ScriptModule):
+ __constants__ = ['dim']
+
def __init__(self):
super(M, self).__init__(False)
- self.mods = [Sub() for _ in range(10)]
+ self.dim = 1
@torch.jit.script_method
- def forward(self):
- for _ in self.mods:
- print(1)
- return 4
+ def forward(self, v):
+ return torch.cat([v, v, v], dim=self.dim)
+ v = torch.zeros(1, 1)
+ self.assertEqual(torch.cat([v, v, v], dim=1), M()(v))
- with self.assertRaisesRegex(RuntimeError, "did you forget to add it __constants__"):
- M()
+ class StarTestSumStarred(torch.nn.Module):
+ def __init__(self):
+ super(TestScript.StarTestSumStarred, self).__init__()
- class DerivedStateModule(torch.jit.ScriptModule):
+ def forward(self, *inputs):
+ output = inputs[0]
+ for i in range(1, len(inputs)):
+ output += inputs[i]
+ return output
+
+ class StarTestReturnThree(torch.nn.Module):
def __init__(self):
- super(TestScript.DerivedStateModule, self).__init__()
- self.param = torch.nn.Parameter(torch.ones(3, 4, dtype=torch.float))
- self.register_buffer('derived', torch.neg(self.param).detach())
+ super(TestScript.StarTestReturnThree, self).__init__()
- # This is a flag so we can test that the pack method was called
- self.register_buffer('pack_called', torch.zeros(1, dtype=torch.long))
- # This is a flag so we can test that the unpack method was called
- self.register_buffer('unpack_called', torch.zeros(1, dtype=torch.long))
+ def forward(self, rep):
+ return rep, rep, rep
- @torch.jit.script_method
- def _pack(self):
- self.pack_called.set_(torch.ones(1, dtype=torch.long))
- self.derived.set_(torch.rand(1, dtype=torch.float).detach())
+ def test_script_star_expr(self):
- @torch.jit.script_method
- def _unpack(self):
- self.unpack_called.set_(torch.ones(1, dtype=torch.long))
- self.derived.set_(torch.neg(self.param).detach())
+ class M2(torch.jit.ScriptModule):
+ def __init__(self):
+ super(M2, self).__init__(True)
+ self.m = torch.jit.trace(TestScript.StarTestSumStarred(),
+ (torch.ones(4, 3), torch.ones(4, 3), torch.ones(4, 3)))
+ self.g = torch.jit.trace(TestScript.StarTestReturnThree(), torch.ones(4, 3))
- @torch.jit.script_method
- def forward(self, x):
- return x + self.derived
+ @torch.jit.script_method
+ def forward(self, rep):
+ tup = self.g(rep)
+ return self.m(*tup)
- def test_pack_unpack_state(self):
- sm = TestScript.DerivedStateModule()
- x = torch.rand(3, 4, dtype=torch.float)
- torch.testing.assert_allclose(sm(x), x + torch.neg(torch.ones(3, 4, dtype=torch.float)))
+ m = M2()
+ self.assertEqual(m(torch.zeros(4, 3)), 3 * torch.zeros(4, 3))
- # Test save path
- self.assertFalse(sm.pack_called.item())
- self.assertFalse(sm.unpack_called.item())
- sm.apply(lambda s: s._pack())
- imported = self.getExportImportCopy(sm)
- sm.apply(lambda s: s._unpack())
- imported.apply(lambda s: s._unpack())
- # ensure pack was called before serialization
- self.assertTrue(sm.pack_called.item())
- # ensure unpack was called after serialization so as to leave the module in an initialized state
- self.assertTrue(sm.unpack_called.item())
+ def test_script_star_expr_string(self):
+ class M2(torch.jit.ScriptModule):
+ def __init__(self):
+ super(M2, self).__init__(True)
+ self.m = torch.jit.trace(TestScript.StarTestSumStarred(),
+ (torch.ones(4, 3), torch.ones(4, 3), torch.ones(4, 3)))
+ self.g = torch.jit.trace(TestScript.StarTestReturnThree(), torch.ones(4, 3))
- torch.testing.assert_allclose(sm.derived, torch.neg(sm.param))
+ self.define('''
+ def forward(self, rep):
+ tup = self.g(rep)
+ return self.m(*tup)
+ ''')
- # Test load paths
- self.assertTrue(imported.unpack_called.item())
- torch.testing.assert_allclose(imported(x), x + torch.neg(torch.ones(3, 4, dtype=torch.float)))
+ m = M2()
+ self.assertEqual(m(torch.zeros(4, 3)), 3 * torch.zeros(4, 3))
- def test_pack_unpack_nested(self):
- class SubSubMod(torch.jit.ScriptModule):
+ class StarTestSumAndReturnThree(torch.nn.Module):
+ def __init__(self):
+ super(TestScript.StarTestSumAndReturnThree, self).__init__()
+
+ def forward(self, *inputs):
+ output = inputs[0]
+ for i in range(1, len(inputs)):
+ output += inputs[i]
+ return output, output, output
+
+ def test_script_star_assign(self):
+ class M2(torch.jit.ScriptModule):
def __init__(self):
- super(SubSubMod, self).__init__()
- self.register_buffer('buf', torch.ones(3, 4) * 3)
+ super(M2, self).__init__(True)
+ self.g = torch.jit.trace(TestScript.StarTestSumAndReturnThree(), torch.ones(4, 3))
+ self.define('''
+ def forward(self, rep):
+ head, *tail = self.g(rep)
+ return head
+ ''')
- @torch.jit.script_method
- def _pack(self):
- self.buf.set_(torch.zeros(1, dtype=torch.double))
+ m = M2()
+ self.assertEqual(m(torch.zeros(4, 3)), 3 * torch.zeros(4, 3))
- @torch.jit.script_method
- def _unpack(self):
- self.buf.set_(torch.ones(3, 4, dtype=torch.double) * 3)
+ def test_script_module_star_assign2(self):
+ class M2(torch.jit.ScriptModule):
+ def __init__(self):
+ super(M2, self).__init__(True)
+ self.g = torch.jit.trace(
+ TestScript.StarTestSumAndReturnThree(),
+ (torch.ones(4, 3), torch.ones(4, 3), torch.ones(4, 3)),
+ _force_outplace=True)
+ self.define('''
+ def forward(self, rep):
+ *head, tail = self.g(rep, rep, rep)
+ return tail
+ ''')
+
+ m = M2()
+ self.assertEqual(m(torch.ones(4, 3)), 3 * torch.ones(4, 3))
+
+ def test_script_module_star_assign2_inplace(self):
+ class M2(torch.jit.ScriptModule):
+ def __init__(self):
+ super(M2, self).__init__(True)
+ self.g = torch.jit.trace(
+ TestScript.StarTestSumAndReturnThree(),
+ (torch.ones(4, 3), torch.ones(4, 3), torch.ones(4, 3)),
+ _force_outplace=False)
+ self.define('''
+ def forward(self, rep):
+ *head, tail = self.g(rep, rep, rep)
+ return tail
+ ''')
+
+ m = M2()
+ # since forward() makes three aliases to the input `rep` before passing
+ # it to StarTestSumAndReturnThree(), in-place behavior will be different
+ # than the above out of place.
+ self.assertEqual(m(torch.ones(4, 3)), 4 * torch.ones(4, 3))
+
+ def test_script_module_star_assign_fail_pythonop(self):
+
+ with self.assertRaisesRegex(RuntimeError, "cannot be used as a tuple"):
+ class M2(torch.jit.ScriptModule):
+ def __init__(self):
+ super(M2, self).__init__(True)
+
+ def myfunc():
+ return torch.zeros(1, 2, 3), torch.zeros(1, 2, 3)
+
+ self.define('''
+ def forward(self, rep):
+ a, *b = myfunc()
+ return a
+ ''')
+
+ m = M2()
+ m(torch.zeros(4, 3))
+
+ def test_script_module_star_assign_fail_builtin(self):
+ with self.assertRaisesRegex(RuntimeError, "cannot be used as a tuple"):
+ class M2(torch.jit.ScriptModule):
+ def __init__(self):
+ super(M2, self).__init__(True)
+
+ self.define('''
+ def forward(self, rep):
+ a, *b = torch.neg(rep)
+ return a
+ ''')
+
+ m = M2()
+ m(torch.zeros(4, 3))
- @torch.jit.script_method
- def forward(self, x):
- return x + self.buf
+ def test_pack_padded_pad_packed_trace(self):
+ from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
+ T, B, C = 3, 5, 7
- class SubMod(torch.jit.ScriptModule):
+ class PadPackedWrapper(torch.nn.Module):
def __init__(self):
- super(SubMod, self).__init__()
- self.register_buffer('buf', torch.ones(3, 4) * 2)
- self.ssm = SubSubMod()
+ super(PadPackedWrapper, self).__init__()
- @torch.jit.script_method
- def _pack(self):
- self.buf.set_(torch.zeros(1, dtype=torch.double))
+ def forward(self, x, seq_lens):
+ x = pack_padded_sequence(x, seq_lens)
+ x, _ = pad_packed_sequence(x)
+ return x
- @torch.jit.script_method
- def _unpack(self):
- self.buf.set_(torch.ones(3, 4, dtype=torch.double) * 2)
+ x = np.ones((T, B, C))
+ seq_lens = np.array([3, 3, 2, 2, 1], dtype=np.int32)
+ # set padding value so we can test equivalence
+ for b in range(B):
+ if seq_lens[b] < T:
+ x[seq_lens[b]:, b, :] = 0
+ seq_lens = torch.from_numpy(seq_lens)
+ x = torch.autograd.Variable(torch.from_numpy(x), requires_grad=True)
- @torch.jit.script_method
- def forward(self, x):
- return self.ssm(x + self.buf)
+ m = PadPackedWrapper()
+ m_traced = torch.jit.trace(m, (x, seq_lens,))
- class Mod(torch.jit.ScriptModule):
- def __init__(self):
- super(Mod, self).__init__()
- self.submod = SubMod()
- self.register_buffer('buf', torch.ones(3, 4) * 1)
+ y = m(x, seq_lens)
+ loss = torch.sum(y)
+ loss.backward()
+ grad = x.grad.clone()
+ x.grad.zero_()
- @torch.jit.script_method
- def _pack(self):
- self.buf.set_(torch.zeros(1, dtype=torch.double))
+ y_traced = m_traced(x, seq_lens)
+ loss_traced = torch.sum(y_traced)
+ loss_traced.backward()
+ grad_traced = x.grad.clone()
- @torch.jit.script_method
- def _unpack(self):
- self.buf.set_(torch.ones(3, 4, dtype=torch.double))
+ self.assertEqual(y_traced, x)
+ self.assertEqual(y_traced, y)
+ self.assertEqual(grad, grad_traced)
- @torch.jit.script_method
- def forward(self, x):
- return self.submod(x + self.buf)
+ f = io.BytesIO()
+ torch.onnx._export(m, (x, seq_lens), f, verbose=False)
- m = Mod()
- torch.testing.assert_allclose(m(torch.zeros(3, 4)), torch.ones(3, 4) * 6)
- m.apply(lambda s: s._pack())
- torch.testing.assert_allclose(m(torch.zeros(3, 4)), torch.zeros(3, 4))
- m.apply(lambda s: s._unpack())
- torch.testing.assert_allclose(m(torch.zeros(3, 4)), torch.ones(3, 4) * 6)
+ def test_script_outputs(self):
+ with self.assertRaisesRegex(RuntimeError, "cannot be used as a tuple"):
+ @torch.jit.script
+ def foo(a):
+ c, d = a + a
+ return c + d
- def test_script_module_not_tuple(self):
- class M(torch.jit.ScriptModule):
- __constants__ = ['mods']
+ @torch.jit.script
+ def return3():
+ return 1, 2, 3
- def __init__(self):
- super(M, self).__init__(False)
- self.mods = 1
+ with self.assertRaisesRegex(RuntimeError, "too many values to unpack"):
+ @torch.jit.script
+ def bind2():
+ a, b = return3()
+ print(a)
+ print(b)
- @torch.jit.script_method
- def forward(self, v):
- for m in self.mods:
- print(m)
- return v
- with self.assertRaisesRegex(RuntimeError, "cannot be used as a tuple"):
- M()
+ @unittest.skipIf(not RUN_CUDA, "requires CUDA")
+ def test_script_get_device_cuda(self):
+ @torch.jit.script
+ def foo(a):
+ return a.get_device()
- def test_script_sequential_for(self):
- class Sub(torch.jit.ScriptModule):
- def __init__(self):
- super(Sub, self).__init__(False)
- self.weight = nn.Parameter(torch.randn(2))
+ v = torch.randn(1, device='cuda')
+ self.assertEqual(foo(v), 0)
- @torch.jit.script_method
- def forward(self, thing):
- return self.weight + thing
+ def test_script_chunk(self):
+ @torch.jit.script
+ def foo(a):
+ b, c = torch.chunk(a, dim=0, chunks=2)
+ return b
+ v = torch.rand(10, 3)
+ self.assertEqual(torch.chunk(v, dim=0, chunks=2)[0], foo(v))
- class M(torch.jit.ScriptModule):
- __constants__ = ['mods']
+ def test_rnn_trace_override(self):
+ from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
+ num_layers = 3
+ T, B, C = 11, 5, 7
- def __init__(self):
- super(M, self).__init__(False)
- self.mods = nn.Sequential(Sub(), Sub(), Sub())
+ class RNNTraceWrapper(torch.nn.Module):
+ def __init__(self, cell_type):
+ super(RNNTraceWrapper, self).__init__()
+ if cell_type == 'RNN':
+ self.rnn = torch.nn.RNN(input_size=C, hidden_size=C, num_layers=num_layers)
+ elif cell_type == 'LSTM':
+ self.rnn = torch.nn.LSTM(input_size=C, hidden_size=C, num_layers=num_layers)
+ elif cell_type == 'GRU':
+ self.rnn = torch.nn.GRU(input_size=C, hidden_size=C, num_layers=num_layers)
- @torch.jit.script_method
- def forward(self, v):
- for m in self.mods:
- v = m(v)
- return v
+ def forward(self, x, seq_lens):
+ x = pack_padded_sequence(x, seq_lens)
+ x, _ = self.rnn(x)
+ x, _ = pad_packed_sequence(x)
+ return x
- @torch.jit.script_method
- def forward2(self, v):
- return self.mods(v)
+ for cell_type in ['RNN', 'LSTM', 'GRU']:
+ x = torch.ones(T, B, C, requires_grad=True)
+ seq_lens = torch.from_numpy(np.array([11, 3, 2, 2, 1], dtype=np.int32))
- i = torch.Tensor(2)
- m = M()
- o = m(i)
- v = i
- for sub in m.mods:
- v = sub(v)
- self.assertEqual(o, v)
+ m = RNNTraceWrapper(cell_type)
+ m_traced = torch.jit.trace(m, (x, seq_lens,))
- o2 = m.forward2(i)
- self.assertEqual(o2, v)
+ y = m(x, seq_lens)
+ loss = torch.sum(y)
+ loss.backward()
+ grad = x.grad.clone()
+ x.grad.zero_()
- def test_script_sequential_multi_output_fail(self):
- class Sub(torch.jit.ScriptModule):
- def __init__(self):
- super(Sub, self).__init__(False)
- self.weight = nn.Parameter(torch.randn(2))
+ y_traced = m_traced(x, seq_lens)
+ loss_traced = torch.sum(y_traced)
+ loss_traced.backward()
+ grad_traced = x.grad.clone()
- @torch.jit.script_method
- def forward(self, thing):
- return self.weight + thing
+ self.assertEqual(y_traced, y)
+ self.assertEqual(grad, grad_traced)
- class ReturnMulti(torch.jit.ScriptModule):
- def __init__(self):
- super(ReturnMulti, self).__init__(False)
+ f = io.BytesIO()
+ torch.onnx._export(m, (x, seq_lens), f, verbose=False)
- @torch.jit.script_method
- def forward(self, x):
- return x, x, x
+ def test_python_call_non_tensor(self):
+ def foo(a, b, c):
+ # type: (Tensor, int, Tuple[Tensor, int]) -> Tuple[int, Tensor]
+ d, e = c
+ return b + e, a + d
- class HaveSequential(torch.jit.ScriptModule):
- __constants__ = ['someseq']
+ @torch.jit.script
+ def bar():
+ x = torch.ones(3, 4)
+ a, b = foo(x, 3, (x, 3))
+ return a, b
- def __init__(self):
- super(HaveSequential, self).__init__(False)
- self.someseq = nn.Sequential(
- Sub(),
- ReturnMulti(),
- Sub()
- )
+ self.assertEqual((6, torch.ones(3, 4) + 1), bar())
- @torch.jit.script_method
- def forward(self, x):
- return self.someseq(x)
+ def test_python_call_non_tensor_wrong(self):
+ with self.assertRaisesRegex(RuntimeError, r"but instead got value of type tuple"):
+ def foo():
+ # type: () -> Tensor
+ return ((3, 4),)
- with self.assertRaisesRegex(RuntimeError, "(Tensor, Tensor, Tensor)"):
- hs = HaveSequential()
- i = torch.Tensor(2)
- hs(i)
+ @torch.jit.script
+ def bar():
+ return foo()
- def test_constant_as_attr(self):
- class M(torch.jit.ScriptModule):
- __constants__ = ['dim']
+ bar()
- def __init__(self):
- super(M, self).__init__(False)
- self.dim = 1
+ def test_tuples(self):
+ def foo(i):
+ a = (i + 4, i * 2)
+ c = a
+ # some nonsense with if-statements and loops to check
+ # that tuple lowering doesn't fail
+ if True:
+ c = (i * 9, i + 1)
+ t0, t1 = c
+ while False:
+ t0, t1 = c
+ c = (t1, t0)
+ x = (1,)
+ y = 1,
+ return t0, x, y
- @torch.jit.script_method
- def forward(self, v):
- return torch.cat([v, v, v], dim=self.dim)
- v = torch.zeros(1, 1)
- self.assertEqual(torch.cat([v, v, v], dim=1), M()(v))
+ v = torch.rand(10, 3)
+ self.checkScript(foo, (v,))
- class StarTestSumStarred(torch.nn.Module):
- def __init__(self):
- super(TestScript.StarTestSumStarred, self).__init__()
+ with self.assertRaisesRegex(RuntimeError, r"variable 'a' previously has type \(Tensor, Tensor\)"):
+ @torch.jit.script
+ def mixtypes(x):
+ a = (x, x)
+ if True:
+ a = 4
- def forward(self, *inputs):
- output = inputs[0]
- for i in range(1, len(inputs)):
- output += inputs[i]
- return output
+ def test_if_tuple_sizes(self):
+ with self.assertRaisesRegex(RuntimeError, "Type mismatch"):
+ @torch.jit.script
+ def diff_tuple_sizes(x):
+ if False:
+ c0 = ((x, x), (x, x, x))
+ else:
+ c0 = ((x, x, x), (x, x))
+ return c0
- class StarTestReturnThree(torch.nn.Module):
- def __init__(self):
- super(TestScript.StarTestReturnThree, self).__init__()
+ def test_if_different_type(self):
+ with self.assertRaisesRegex(RuntimeError, "Type mismatch: c0 is set to type int "
+ "in the true branch and type float in the false branch:"):
+ @torch.jit.script
+ def diff_type_used():
+ if False:
+ c0 = 1
+ else:
+ c0 = 1.0
+ return c0
- def forward(self, rep):
- return rep, rep, rep
+ with self.assertRaisesRegex(RuntimeError, "variable 'c0' previously has type float"):
+ @torch.jit.script
+ def diff_existing_type(x):
+ c0 = 1.0
+ if False:
+ c0 = 1
+ print(x)
+ return x
- def test_script_star_expr(self):
+ @torch.jit.script
+ def diff_type_unused():
+ if True:
+ c0 = 1
+ print(c0)
+ else:
+ c0 = 1.0
+ print(c0)
+ return 1
- class M2(torch.jit.ScriptModule):
- def __init__(self):
- super(M2, self).__init__(True)
- self.m = torch.jit.trace(TestScript.StarTestSumStarred(),
- (torch.ones(4, 3), torch.ones(4, 3), torch.ones(4, 3)))
- self.g = torch.jit.trace(TestScript.StarTestReturnThree(), torch.ones(4, 3))
+ def test_if_list(self):
+ # testing that different length lists don't throw error
+ @torch.jit.script
+ def test_list(x):
+ if True:
+ c = [x, x]
+ else:
+ c = [x, x, x]
+ return torch.cat(c)
- @torch.jit.script_method
- def forward(self, rep):
- tup = self.g(rep)
- return self.m(*tup)
+ b = torch.zeros(2, 4)
+ test_list.graph.propagate_shapes((b,), False)
+ self.assertExpected(canonical(test_list.graph))
- m = M2()
- self.assertEqual(m(torch.zeros(4, 3)), 3 * torch.zeros(4, 3))
+ def test_if_supertype(self):
+ @torch.jit.script
+ def tensor_unifying(x, y, z):
- def test_script_star_expr_string(self):
- class M2(torch.jit.ScriptModule):
- def __init__(self):
- super(M2, self).__init__(True)
- self.m = torch.jit.trace(TestScript.StarTestSumStarred(),
- (torch.ones(4, 3), torch.ones(4, 3), torch.ones(4, 3)))
- self.g = torch.jit.trace(TestScript.StarTestReturnThree(), torch.ones(4, 3))
+ # testing dynamic is appropriately set for y and z
+ if True:
+ x, y, z = x, y, z
+ else:
+ x, y, z = x, x, y
- self.define('''
- def forward(self, rep):
- tup = self.g(rep)
- return self.m(*tup)
- ''')
+ return x, y, z
- m = M2()
- self.assertEqual(m(torch.zeros(4, 3)), 3 * torch.zeros(4, 3))
+ a = torch.zeros(2, 2, dtype=torch.float)
+ b = torch.zeros(2, 4, dtype=torch.long)
+ c = torch.zeros(2, 4, dtype=torch.float)
- class StarTestSumAndReturnThree(torch.nn.Module):
- def __init__(self):
- super(TestScript.StarTestSumAndReturnThree, self).__init__()
+ tensor_unifying.graph.propagate_shapes((a, b, c), False)
+ self.assertExpected(canonical(tensor_unifying.graph))
- def forward(self, *inputs):
- output = inputs[0]
- for i in range(1, len(inputs)):
- output += inputs[i]
- return output, output, output
+ def test_type_annotations_repeated_list(self):
+ @torch.jit.script
+ def float_fn(x, y):
+ # type: (float, BroadcastingList3[float]) -> List[float]
+ return y
+ self.assertEqual(float_fn(2.0, 1.0), float_fn(2.0, [1.0, 1.0, 1.0]))
+ self.assertEqual(float_fn(2.0, 1.0), float_fn(2.0, (1.0, 1.0, 1.0)))
- def test_script_star_assign(self):
- class M2(torch.jit.ScriptModule):
- def __init__(self):
- super(M2, self).__init__(True)
- self.g = torch.jit.trace(TestScript.StarTestSumAndReturnThree(), torch.ones(4, 3))
- self.define('''
- def forward(self, rep):
- head, *tail = self.g(rep)
- return head
- ''')
+ @torch.jit.script
+ def float_fn_call():
+ print(float_fn(1.0, 1.0))
+ print(float_fn(1.0, (1.0, 1.0, 1.0)))
- m = M2()
- self.assertEqual(m(torch.zeros(4, 3)), 3 * torch.zeros(4, 3))
+ @torch.jit.script
+ def int_fn(x):
+ # type: (BroadcastingList3[int]) -> List[int]
+ return x
+ self.assertEqual(int_fn(1), int_fn([1, 1, 1]))
+ self.assertEqual(int_fn(1), int_fn((1, 1, 1)))
- def test_script_module_star_assign2(self):
- class M2(torch.jit.ScriptModule):
- def __init__(self):
- super(M2, self).__init__(True)
- self.g = torch.jit.trace(
- TestScript.StarTestSumAndReturnThree(),
- (torch.ones(4, 3), torch.ones(4, 3), torch.ones(4, 3)),
- _force_outplace=True)
- self.define('''
- def forward(self, rep):
- *head, tail = self.g(rep, rep, rep)
- return tail
- ''')
+ @torch.jit.script
+ def int_fn_call():
+ print(int_fn(1))
+ print(int_fn((1, 1, 1)))
- m = M2()
- self.assertEqual(m(torch.ones(4, 3)), 3 * torch.ones(4, 3))
+ with self.assertRaisesRegex(RuntimeError, "expected number"):
+ @torch.jit.script
+ def fn(x):
+ # type: (BroadcastingListx[int]) -> List[int]
+ return x
- def test_script_module_star_assign2_inplace(self):
- class M2(torch.jit.ScriptModule):
- def __init__(self):
- super(M2, self).__init__(True)
- self.g = torch.jit.trace(
- TestScript.StarTestSumAndReturnThree(),
- (torch.ones(4, 3), torch.ones(4, 3), torch.ones(4, 3)),
- _force_outplace=False)
- self.define('''
- def forward(self, rep):
- *head, tail = self.g(rep, rep, rep)
- return tail
- ''')
+ with self.assertRaisesRegex(RuntimeError, "Unknown type constructor"):
+ @torch.jit.script
+ def nested(x, y):
+ # type: (int, Tuple[int, int[2]]) -> List[int]
+ return x
- m = M2()
- # since forward() makes three aliases to the input `rep` before passing
- # it to StarTestSumAndReturnThree(), in-place behavior will be different
- # than the above out of place.
- self.assertEqual(m(torch.ones(4, 3)), 4 * torch.ones(4, 3))
+ def test_ntuple_builtins(self):
+ from torch.nn.modules.utils import _single, _pair, _triple, _quadruple
- def test_script_module_star_assign_fail_pythonop(self):
+ def test_ints():
+ return _single(1), _pair(2), _triple(3), _quadruple(4)
- with self.assertRaisesRegex(RuntimeError, "cannot be used as a tuple"):
- class M2(torch.jit.ScriptModule):
- def __init__(self):
- super(M2, self).__init__(True)
+ def test_floats():
+ return _single(1), _pair(2.1), _triple(3.1), _quadruple(4.1)
- def myfunc():
- return torch.zeros(1, 2, 3), torch.zeros(1, 2, 3)
+ self.checkScript(test_ints, ())
+ self.checkScript(test_floats, ())
- self.define('''
- def forward(self, rep):
- a, *b = myfunc()
- return a
- ''')
+ def test_embedding_renorm_grad_error(self):
+ # Testing that the builtin call to embedding_renorm_ correctly throws
+ # Error when .backward() is called on its input
- m = M2()
- m(torch.zeros(4, 3))
+ def embedding_norm(input, embedding_matrix, max_norm):
+ F.embedding(input, embedding_matrix, max_norm=0.01)
- def test_script_module_star_assign_fail_builtin(self):
- with self.assertRaisesRegex(RuntimeError, "cannot be used as a tuple"):
- class M2(torch.jit.ScriptModule):
- def __init__(self):
- super(M2, self).__init__(True)
+ @torch.jit.script
+ def embedding_norm_script(input, embedding_matrix, max_norm):
+ # type: (Tensor, Tensor, float)
+ F.embedding(input, embedding_matrix, max_norm=0.01)
- self.define('''
- def forward(self, rep):
- a, *b = torch.neg(rep)
- return a
- ''')
+ for fun in [embedding_norm, embedding_norm_script]:
+ input = torch.tensor([[1, 2, 4, 5], [4, 3, 2, 9]])
+ embedding_matrix = torch.randn(10, 3)
- m = M2()
- m(torch.zeros(4, 3))
+ var1 = torch.randn(10, 3, requires_grad=True)
+ var2 = var1.detach().requires_grad_()
+ output1 = var1 * embedding_matrix
+ output2 = var2 * embedding_matrix
- def test_pack_padded_pad_packed_trace(self):
- from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
- T, B, C = 3, 5, 7
+ output1.sum().backward()
- class PadPackedWrapper(torch.nn.Module):
- def __init__(self):
- super(PadPackedWrapper, self).__init__()
+ ignore = F.embedding(input, embedding_matrix, max_norm=0.01)
+ with self.assertRaisesRegex(RuntimeError, "modified"):
+ output2.sum().backward()
+
+ def test_type_annotations(self):
+ def fn(x, y):
+ # type: (Tensor, Tensor) -> Tuple[Tensor, Tensor, Tensor]
+ return x, x * 2, x * 3
- def forward(self, x, seq_lens):
- x = pack_padded_sequence(x, seq_lens)
- x, _ = pad_packed_sequence(x)
- return x
+ with self.assertRaisesRegex(RuntimeError, r"need 4 values .* found only 3"):
+ @torch.jit.script
+ def script_fn(x):
+ x, y, z, w = fn(x, x)
- x = np.ones((T, B, C))
- seq_lens = np.array([3, 3, 2, 2, 1], dtype=np.int32)
- # set padding value so we can test equivalence
- for b in range(B):
- if seq_lens[b] < T:
- x[seq_lens[b]:, b, :] = 0
- seq_lens = torch.from_numpy(seq_lens)
- x = torch.autograd.Variable(torch.from_numpy(x), requires_grad=True)
+ with self.assertRaisesRegex(RuntimeError, r"too many values .* need 2 but found 3"):
+ @torch.jit.script
+ def script_fn2(x):
+ x, y = fn(x, x)
- m = PadPackedWrapper()
- m_traced = torch.jit.trace(m, (x, seq_lens,))
+ def fn_unpack(x):
+ y, z, w = fn(x, x)
+ return y
- y = m(x, seq_lens)
- loss = torch.sum(y)
- loss.backward()
- grad = x.grad.clone()
- x.grad.zero_()
+ def fn_index(x):
+ q = fn(x, x)
+ return x
- y_traced = m_traced(x, seq_lens)
- loss_traced = torch.sum(y_traced)
- loss_traced.backward()
- grad_traced = x.grad.clone()
+ def fn_string(str, strpair):
+ # type: (str, Tuple[str, str]) -> Tuple[str, int, str, str]
+ str1, str2 = strpair
+ return str, 2, str1, str2
- self.assertEqual(y_traced, x)
- self.assertEqual(y_traced, y)
- self.assertEqual(grad, grad_traced)
+ x = torch.ones(2, 2)
+ self.checkScript(fn_unpack, (x,), optimize=True)
+ self.checkScript(fn_index, (x,), optimize=True)
+ self.checkScript(fn_string, ("1", ("3", "4")), optimize=True)
- f = io.BytesIO()
- torch.onnx._export(m, (x, seq_lens), f, verbose=False)
+ def test_type_annotations_varargs(self):
+ def fn_varargs(x, *args):
+ return args[0] if args else x
- def test_script_outputs(self):
- with self.assertRaisesRegex(RuntimeError, "cannot be used as a tuple"):
- @torch.jit.script
- def foo(a):
- c, d = a + a
- return c + d
+ def fn1(x, y, z):
+ return fn_varargs(x)
- @torch.jit.script
- def return3():
- return 1, 2, 3
+ def fn2(x, y, z):
+ return fn_varargs(x, y)
- with self.assertRaisesRegex(RuntimeError, "too many values to unpack"):
- @torch.jit.script
- def bind2():
- a, b = return3()
- print(a)
- print(b)
+ def fn3(x, y, z):
+ return fn_varargs(x, y, z)
- @unittest.skipIf(not RUN_CUDA, "requires CUDA")
- def test_script_get_device_cuda(self):
- @torch.jit.script
- def foo(a):
- return a.get_device()
+ x, y, z = [torch.randn(2, 2) for _ in range(3)]
+ self.checkScript(fn1, (x, y, z), optimize=True)
+ self.checkScript(fn2, (x, y, z), optimize=True)
+ self.checkScript(fn3, (x, y, z), optimize=True)
- v = torch.randn(1, device='cuda')
- self.assertEqual(foo(v), 0)
+ @unittest.skipIf(not PY35, "Python 3.5 needed")
+ def test_type_annotation_py3(self):
+ import importlib.util
- def test_script_chunk(self):
- @torch.jit.script
- def foo(a):
- b, c = torch.chunk(a, dim=0, chunks=2)
- return b
- v = torch.rand(10, 3)
- self.assertEqual(torch.chunk(v, dim=0, chunks=2)[0], foo(v))
+ code = dedent("""
+ import torch
+ from torch import Tensor
+ from typing import Tuple
- def test_rnn_trace_override(self):
- from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
- num_layers = 3
- T, B, C = 11, 5, 7
+ def fn(x : torch.Tensor, y : Tensor, z) -> Tuple[Tensor, Tensor, Tensor]:
+ return (x, y + z, z)
+ """)
- class RNNTraceWrapper(torch.nn.Module):
- def __init__(self, cell_type):
- super(RNNTraceWrapper, self).__init__()
- if cell_type == 'RNN':
- self.rnn = torch.nn.RNN(input_size=C, hidden_size=C, num_layers=num_layers)
- elif cell_type == 'LSTM':
- self.rnn = torch.nn.LSTM(input_size=C, hidden_size=C, num_layers=num_layers)
- elif cell_type == 'GRU':
- self.rnn = torch.nn.GRU(input_size=C, hidden_size=C, num_layers=num_layers)
+ with tempfile.TemporaryDirectory() as tmp_dir:
+ script_path = os.path.join(tmp_dir, 'script.py')
+ with open(script_path, 'w') as f:
+ f.write(code)
+ fn = get_fn('test_type_annotation_py3', script_path)
- def forward(self, x, seq_lens):
- x = pack_padded_sequence(x, seq_lens)
- x, _ = self.rnn(x)
- x, _ = pad_packed_sequence(x)
- return x
+ with self.assertRaisesRegex(RuntimeError, r"expected a value of type Tensor for argument"
+ r" '0' but found \(Tensor, Tensor\)"):
+ @torch.jit.script
+ def bad_fn(x):
+ x, y = fn((x, x), x, x)
+ return y
- for cell_type in ['RNN', 'LSTM', 'GRU']:
- x = torch.ones(T, B, C, requires_grad=True)
- seq_lens = torch.from_numpy(np.array([11, 3, 2, 2, 1], dtype=np.int32))
+ with self.assertRaisesRegex(RuntimeError, r"too many values .* need 2 but found 3"):
+ @torch.jit.script
+ def bad_fn2(x):
+ x, y = fn(x, x, x)
+ return y
- m = RNNTraceWrapper(cell_type)
- m_traced = torch.jit.trace(m, (x, seq_lens,))
+ with self.assertRaisesRegex(RuntimeError, r"need 4 values .* found only 3"):
+ @torch.jit.script
+ def bad_fn3(x):
+ x, y, z, w = fn(x, x, x)
+ return y
- y = m(x, seq_lens)
- loss = torch.sum(y)
- loss.backward()
- grad = x.grad.clone()
- x.grad.zero_()
+ def good_fn(x):
+ y, z, w = fn(x, x, x)
+ return y, z, w
- y_traced = m_traced(x, seq_lens)
- loss_traced = torch.sum(y_traced)
- loss_traced.backward()
- grad_traced = x.grad.clone()
+ self.checkScript(good_fn, (torch.ones(2, 2),), optimize=True)
- self.assertEqual(y_traced, y)
- self.assertEqual(grad, grad_traced)
+ def test_type_annotation_module(self):
+ class BaseModule(torch.jit.ScriptModule):
+ def foo(self, x):
+ # type: (Tensor) -> Tensor
+ return x + 1
- f = io.BytesIO()
- torch.onnx._export(m, (x, seq_lens), f, verbose=False)
+ def bar(self, x, y):
+ # type: (Tensor, Tensor) -> Tuple[Tensor, Tensor]
+ return x + y, y
- def test_python_call_non_tensor(self):
- def foo(a, b, c):
- # type: (Tensor, int, Tuple[Tensor, int]) -> Tuple[int, Tensor]
- d, e = c
- return b + e, a + d
+ def baz(self, x, y):
+ return x
- @torch.jit.script
- def bar():
- x = torch.ones(3, 4)
- a, b = foo(x, 3, (x, 3))
- return a, b
+ class ModuleTooMany(BaseModule):
+ @torch.jit.script_method
+ def method(self, x):
+ return self.foo(x, x)
- self.assertEqual((6, torch.ones(3, 4) + 1), bar())
+ class ModuleTooFew(BaseModule):
+ @torch.jit.script_method
+ def method(self, x):
+ return self.bar(x)
- def test_python_call_non_tensor_wrong(self):
- with self.assertRaisesRegex(RuntimeError, r"but instead got value of type tuple"):
- def foo():
- # type: () -> Tensor
- return ((3, 4),)
+ class ModuleTooManyAssign(BaseModule):
+ @torch.jit.script_method
+ def method(self, x):
+ y, z, w = self.bar(x, x)
+ return x
- @torch.jit.script
- def bar():
- return foo()
+ class ModuleDefault(BaseModule):
+ @torch.jit.script_method
+ def method(self, x):
+ y = self.baz(x)
+ return x
- bar()
+ with self.assertRaisesRegex(RuntimeError, "expected at most 1 arguments but found 2"):
+ ModuleTooMany()
+ with self.assertRaisesRegex(RuntimeError, "argument 1 not provided"):
+ ModuleTooFew()
+ with self.assertRaisesRegex(RuntimeError, "need 3 values .* found only 2"):
+ ModuleTooManyAssign()
+ with self.assertRaisesRegex(RuntimeError, "argument 1 not provided."):
+ ModuleDefault()
- def test_tuples(self):
- def foo(i):
- a = (i + 4, i * 2)
- c = a
- # some nonsense with if-statements and loops to check
- # that tuple lowering doesn't fail
- if True:
- c = (i * 9, i + 1)
- t0, t1 = c
- while False:
- t0, t1 = c
- c = (t1, t0)
- x = (1,)
- y = 1,
- return t0, x, y
+ def test_script_define_order(self):
+ class M(torch.jit.ScriptModule):
+ def __init__(self):
+ pass
- v = torch.rand(10, 3)
- self.checkScript(foo, (v,))
+ @torch.jit.script_method
+ def call_foo(self, input):
+ return self.foo(input)
- with self.assertRaisesRegex(RuntimeError, r"variable 'a' previously has type \(Tensor, Tensor\)"):
- @torch.jit.script
- def mixtypes(x):
- a = (x, x)
- if True:
- a = 4
+ @torch.jit.script_method
+ def foo(self, input):
+ return input + 1
+ m = M()
+ self.assertEqual(2, m.call_foo(torch.ones((), dtype=torch.int64)))
- def test_if_tuple_sizes(self):
- with self.assertRaisesRegex(RuntimeError, "Type mismatch"):
- @torch.jit.script
- def diff_tuple_sizes(x):
- if False:
- c0 = ((x, x), (x, x, x))
- else:
- c0 = ((x, x, x), (x, x))
- return c0
+ def test_script_define_order_recursive_fail(self):
+ class M(torch.jit.ScriptModule):
+ def __init__(self):
+ pass
- def test_if_different_type(self):
- with self.assertRaisesRegex(RuntimeError, "Type mismatch: c0 is set to type int "
- "in the true branch and type float in the false branch:"):
- @torch.jit.script
- def diff_type_used():
- if False:
- c0 = 1
- else:
- c0 = 1.0
- return c0
+ @torch.jit.script_method
+ def call_foo(self, input):
+ return self.foo(input)
- with self.assertRaisesRegex(RuntimeError, "variable 'c0' previously has type float"):
- @torch.jit.script
- def diff_existing_type(x):
- c0 = 1.0
- if False:
- c0 = 1
- print(x)
- return x
+ @torch.jit.script_method
+ def foo(self, input):
+ self.call_foo(input)
- @torch.jit.script
- def diff_type_unused():
- if True:
- c0 = 1
- print(c0)
- else:
- c0 = 1.0
- print(c0)
- return 1
+ with self.assertRaisesRegex(RuntimeError, 'called recursively involving'):
+ M()
- def test_if_list(self):
- # testing that different length lists don't throw error
- @torch.jit.script
- def test_list(x):
- if True:
- c = [x, x]
- else:
- c = [x, x, x]
- return torch.cat(c)
+ def test_script_kwargs_fn_call(self):
+ class M(torch.jit.ScriptModule):
+ def __init__(self):
+ pass
- b = torch.zeros(2, 4)
- test_list.graph.propagate_shapes((b,), False)
- self.assertExpected(canonical(test_list.graph))
+ @torch.jit.script_method
+ def call_foo(self, input):
+ return self.foo(input=input, bar=1)
- def test_if_supertype(self):
+ @torch.jit.script_method
+ def foo(self, bar, input):
+ # type: (int, Tensor) -> Tensor
+ return input + bar
+ m = M()
+ self.assertEqual(2, m.call_foo(torch.ones((), dtype=torch.int64)))
+
+ @unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
+ def test_trace_of_script(self):
@torch.jit.script
- def tensor_unifying(x, y, z):
+ def foo(a, c):
+ b = 0.0
+ if bool(a == 0.0):
+ b = 1.0
+ return b + c
- # testing dynamic is appropriately set for y and z
- if True:
- x, y, z = x, y, z
- else:
- x, y, z = x, x, y
+ a = torch.ones(1, dtype=torch.float)
- return x, y, z
+ @_trace(torch.zeros(1, dtype=torch.float))
+ def use(b):
+ return foo(b - 1.0, a) + 1.0
- a = torch.zeros(2, 2, dtype=torch.float)
- b = torch.zeros(2, 4, dtype=torch.long)
- c = torch.zeros(2, 4, dtype=torch.float)
+ # test we propagated shapes through the function
+ self.assertTrue("Dynamic" not in str(use.graph))
- tensor_unifying.graph.propagate_shapes((a, b, c), False)
- self.assertExpected(canonical(tensor_unifying.graph))
+ self.assertEqual(3, use(torch.ones(1, dtype=torch.float)))
+ self.assertEqual(2, use(torch.zeros(1, dtype=torch.float)))
- def test_type_annotations_repeated_list(self):
+ def test_if_define(self):
@torch.jit.script
- def float_fn(x, y):
- # type: (float, BroadcastingList3[float]) -> List[float]
- return y
- self.assertEqual(float_fn(2.0, 1.0), float_fn(2.0, [1.0, 1.0, 1.0]))
- self.assertEqual(float_fn(2.0, 1.0), float_fn(2.0, (1.0, 1.0, 1.0)))
+ def foo(a):
+ if bool(a == 0):
+ b = 1
+ else:
+ b = 0
+ return b + 1
@torch.jit.script
- def float_fn_call():
- print(float_fn(1.0, 1.0))
- print(float_fn(1.0, (1.0, 1.0, 1.0)))
+ def foo2(a):
+ b = 0
+ if bool(a == 0):
+ b = 1
+ return b + 1
@torch.jit.script
- def int_fn(x):
- # type: (BroadcastingList3[int]) -> List[int]
- return x
- self.assertEqual(int_fn(1), int_fn([1, 1, 1]))
- self.assertEqual(int_fn(1), int_fn((1, 1, 1)))
+ def foo3(a):
+ b = 1
+ if bool(a == 0):
+ c = 4
+ else:
+ b = 0
+ return b + 1
- @torch.jit.script
- def int_fn_call():
- print(int_fn(1))
- print(int_fn((1, 1, 1)))
+ a = torch.ones(1, dtype=torch.long)
+ b = torch.zeros(1, dtype=torch.long)
+ self.assertEqual(1, foo(a))
+ self.assertEqual(2, foo(b))
+ self.assertEqual(1, foo2(a))
+ self.assertEqual(2, foo2(b))
+ self.assertEqual(1, foo3(a))
+ self.assertEqual(2, foo3(b))
- with self.assertRaisesRegex(RuntimeError, "expected number"):
- @torch.jit.script
- def fn(x):
- # type: (BroadcastingListx[int]) -> List[int]
- return x
+ def test_script_module_export_submodule(self):
+ class M1(torch.jit.ScriptModule):
+ def __init__(self):
+ super(M1, self).__init__(False)
+ self.weight = nn.Parameter(torch.randn(2))
- with self.assertRaisesRegex(RuntimeError, "Unknown type constructor"):
- @torch.jit.script
- def nested(x, y):
- # type: (int, Tuple[int, int[2]]) -> List[int]
- return x
+ @torch.jit.script_method
+ def forward(self, thing):
+ return self.weight + thing
- def test_ntuple_builtins(self):
- from torch.nn.modules.utils import _single, _pair, _triple, _quadruple
+ class M2(torch.jit.ScriptModule):
+ def __init__(self):
+ super(M2, self).__init__(False)
+ # test submodule
+ self.sub = M1()
+ self.weight = nn.Parameter(torch.randn(2, 3))
+ self.bias = nn.Parameter(torch.randn(2))
+ self.define("""
+ def hi(self, a):
+ return self.weight.mm(a)
+ """)
- def test_ints():
- return _single(1), _pair(2), _triple(3), _quadruple(4)
+ @torch.jit.script_method
+ def doit(self, input):
+ return self.weight.mm(input)
- def test_floats():
- return _single(1), _pair(2.1), _triple(3.1), _quadruple(4.1)
+ @torch.jit.script_method
+ def doit2(self, input):
+ return self.weight.mm(input)
- self.checkScript(test_ints, ())
- self.checkScript(test_floats, ())
+ @torch.jit.script_method
+ def doit3(self, input):
+ return input + torch.ones([1], dtype=torch.double)
- def test_embedding_renorm_grad_error(self):
- # Testing that the builtin call to embedding_renorm_ correctly throws
- # Error when .backward() is called on its input
+ @torch.jit.script_method
+ def forward(self, input):
+ a = self.doit(input)
+ b = self.doit2(input)
+ c = self.hi(input)
+ return a + b + self.bias + c
- def embedding_norm(input, embedding_matrix, max_norm):
- F.embedding(input, embedding_matrix, max_norm=0.01)
+ m_orig = M2()
+ m_import = self.getExportImportCopy(m_orig)
- @torch.jit.script
- def embedding_norm_script(input, embedding_matrix, max_norm):
- # type: (Tensor, Tensor, float)
- F.embedding(input, embedding_matrix, max_norm=0.01)
+ input = torch.randn(3, 2)
+ self.assertEqual(m_orig.doit(input), m_import.doit(input))
+ self.assertEqual(m_orig.hi(input), m_import.hi(input))
+ self.assertEqual(m_orig.doit3(input), m_import.doit3(input))
+ self.assertEqual(m_orig.forward(input), m_import.forward(input))
- for fun in [embedding_norm, embedding_norm_script]:
- input = torch.tensor([[1, 2, 4, 5], [4, 3, 2, 9]])
- embedding_matrix = torch.randn(10, 3)
+ @skipIfNoTorchVision
+ def test_script_module_trace_resnet18(self):
+ x = torch.ones(1, 3, 224, 224)
+ m_orig = torch.jit.trace(torchvision.models.resnet18(), torch.ones(1, 3, 224, 224))
+ m_import = self.getExportImportCopy(m_orig)
- var1 = torch.randn(10, 3, requires_grad=True)
- var2 = var1.detach().requires_grad_()
- output1 = var1 * embedding_matrix
- output2 = var2 * embedding_matrix
+ input = torch.randn(1, 3, 224, 224, requires_grad=True)
+ output_orig = m_orig(input)
+ output_orig.sum().backward()
+ grad_orig = input.grad.clone()
+ input.grad.zero_()
- output1.sum().backward()
+ output_import = m_import(input)
+ output_import.sum().backward()
+ grad_import = input.grad.clone()
- ignore = F.embedding(input, embedding_matrix, max_norm=0.01)
- with self.assertRaisesRegex(RuntimeError, "modified"):
- output2.sum().backward()
+ self.assertEqual(output_orig, output_import)
+ self.assertEqual(grad_orig, grad_import)
- def test_type_annotations(self):
- def fn(x, y):
- # type: (Tensor, Tensor) -> Tuple[Tensor, Tensor, Tensor]
- return x, x * 2, x * 3
+ @skipIfNoTorchVision
+ def test_script_module_script_resnet(self):
+ def conv1x1(in_planes, out_planes, stride=1):
+ """1x1 convolution"""
+ return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
- with self.assertRaisesRegex(RuntimeError, r"need 4 values .* found only 3"):
- @torch.jit.script
- def script_fn(x):
- x, y, z, w = fn(x, x)
+ def conv3x3(in_planes, out_planes, stride=1):
+ """3x3 convolution with padding"""
+ return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
+ padding=1, bias=False)
- with self.assertRaisesRegex(RuntimeError, r"too many values .* need 2 but found 3"):
- @torch.jit.script
- def script_fn2(x):
- x, y = fn(x, x)
+ class BasicBlock(torch.jit.ScriptModule):
+ expansion = 1
+ __constants__ = ['downsample']
- def fn_unpack(x):
- y, z, w = fn(x, x)
- return y
+ def __init__(self, inplanes, planes, stride=1, downsample=None):
+ super(BasicBlock, self).__init__()
+ self.conv1 = conv3x3(inplanes, planes, stride)
+ self.bn1 = nn.BatchNorm2d(planes)
+ self.relu = nn.ReLU(inplace=True)
+ self.conv2 = conv3x3(planes, planes)
+ self.bn2 = nn.BatchNorm2d(planes)
+ self.downsample = downsample
+ self.stride = stride
- def fn_index(x):
- q = fn(x, x)
- return x
+ @torch.jit.script_method
+ def forward(self, x):
+ residual = x
- def fn_string(str, strpair):
- # type: (str, Tuple[str, str]) -> Tuple[str, int, str, str]
- str1, str2 = strpair
- return str, 2, str1, str2
+ out = self.conv1(x)
+ out = self.bn1(out)
+ out = self.relu(out)
- x = torch.ones(2, 2)
- self.checkScript(fn_unpack, (x,), optimize=True)
- self.checkScript(fn_index, (x,), optimize=True)
- self.checkScript(fn_string, ("1", ("3", "4")), optimize=True)
+ out = self.conv2(out)
+ out = self.bn2(out)
- def test_type_annotations_varargs(self):
- def fn_varargs(x, *args):
- return args[0] if args else x
+ if self.downsample is not None:
+ residual = self.downsample(x)
- def fn1(x, y, z):
- return fn_varargs(x)
+ out += residual
+ out = self.relu(out)
- def fn2(x, y, z):
- return fn_varargs(x, y)
+ return out
- def fn3(x, y, z):
- return fn_varargs(x, y, z)
+ class ResNet(torch.jit.ScriptModule):
+ __constants__ = ['layer1', 'layer2', 'layer3', 'layer4']
- x, y, z = [torch.randn(2, 2) for _ in range(3)]
- self.checkScript(fn1, (x, y, z), optimize=True)
- self.checkScript(fn2, (x, y, z), optimize=True)
- self.checkScript(fn3, (x, y, z), optimize=True)
+ def __init__(self, block, layers, num_classes=1000):
+ super(ResNet, self).__init__()
+ self.inplanes = 64
+ self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
+ bias=False)
+ self.bn1 = nn.BatchNorm2d(64)
+ self.relu = nn.ReLU(inplace=True)
+ self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
+ self.layer1 = self._make_layer(block, 64, layers[0])
+ self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
+ self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
+ self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
+ self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
+ self.fc = nn.Linear(512 * block.expansion, num_classes)
- @unittest.skipIf(not PY35, "Python 3.5 needed")
- def test_type_annotation_py3(self):
- import importlib.util
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
+ elif isinstance(m, nn.BatchNorm2d):
+ nn.init.constant_(m.weight, 1)
+ nn.init.constant_(m.bias, 0)
- code = dedent("""
- import torch
- from torch import Tensor
- from typing import Tuple
+ def _make_layer(self, block, planes, blocks, stride=1):
+ downsample = None
+ if stride != 1 or self.inplanes != planes * block.expansion:
+ downsample = nn.Sequential(
+ conv1x1(self.inplanes, planes * block.expansion, stride),
+ nn.BatchNorm2d(planes * block.expansion),
+ )
- def fn(x : torch.Tensor, y : Tensor, z) -> Tuple[Tensor, Tensor, Tensor]:
- return (x, y + z, z)
- """)
+ layers = []
+ layers.append(block(self.inplanes, planes, stride, downsample))
+ self.inplanes = planes * block.expansion
+ for _ in range(1, blocks):
+ layers.append(block(self.inplanes, planes))
- with tempfile.TemporaryDirectory() as tmp_dir:
- script_path = os.path.join(tmp_dir, 'script.py')
- with open(script_path, 'w') as f:
- f.write(code)
- fn = get_fn('test_type_annotation_py3', script_path)
+ return nn.Sequential(*layers)
- with self.assertRaisesRegex(RuntimeError, r"expected a value of type Tensor for argument"
- r" '0' but found \(Tensor, Tensor\)"):
- @torch.jit.script
- def bad_fn(x):
- x, y = fn((x, x), x, x)
- return y
+ @torch.jit.script_method
+ def forward(self, x):
+ x = self.conv1(x)
+ x = self.bn1(x)
+ x = self.relu(x)
+ x = self.maxpool(x)
- with self.assertRaisesRegex(RuntimeError, r"too many values .* need 2 but found 3"):
- @torch.jit.script
- def bad_fn2(x):
- x, y = fn(x, x, x)
- return y
+ x = self.layer1(x)
+ x = self.layer2(x)
+ x = self.layer3(x)
+ x = self.layer4(x)
- with self.assertRaisesRegex(RuntimeError, r"need 4 values .* found only 3"):
- @torch.jit.script
- def bad_fn3(x):
- x, y, z, w = fn(x, x, x)
- return y
+ x = self.avgpool(x)
+ x = x.view(x.size(0), -1)
+ x = self.fc(x)
- def good_fn(x):
- y, z, w = fn(x, x, x)
- return y, z, w
+ return x
- self.checkScript(good_fn, (torch.ones(2, 2),), optimize=True)
+ resnet18 = ResNet(BasicBlock, [2, 2, 2, 2])
- def test_type_annotation_module(self):
- class BaseModule(torch.jit.ScriptModule):
- def foo(self, x):
- # type: (Tensor) -> Tensor
- return x + 1
+ resnet18_imported = self.getExportImportCopy(resnet18)
- def bar(self, x, y):
- # type: (Tensor, Tensor) -> Tuple[Tensor, Tensor]
- return x + y, y
+ input = torch.randn(1, 3, 224, 224, requires_grad=True)
+ output_orig = resnet18(input)
+ output_orig.sum().backward()
+ grad_orig = input.grad.clone()
+ input.grad.zero_()
- def baz(self, x, y):
- return x
+ output_import = resnet18_imported(input)
+ output_import.sum().backward()
+ grad_import = input.grad.clone()
- class ModuleTooMany(BaseModule):
- @torch.jit.script_method
- def method(self, x):
- return self.foo(x, x)
+ self.assertEqual(output_orig, output_import)
+ self.assertEqual(grad_orig, grad_import)
- class ModuleTooFew(BaseModule):
- @torch.jit.script_method
- def method(self, x):
- return self.bar(x)
+ def test_script_module_export_tensor_type(self):
+ class M(torch.jit.ScriptModule):
- class ModuleTooManyAssign(BaseModule):
- @torch.jit.script_method
- def method(self, x):
- y, z, w = self.bar(x, x)
- return x
+ def __init__(self, type):
+ super(M, self).__init__(False)
+ self.param = torch.nn.Parameter(torch.zeros((5, 5), dtype=type).random_())
- class ModuleDefault(BaseModule):
@torch.jit.script_method
- def method(self, x):
- y = self.baz(x)
- return x
+ def foo(self):
+ return self.param
- with self.assertRaisesRegex(RuntimeError, "expected at most 1 arguments but found 2"):
- ModuleTooMany()
- with self.assertRaisesRegex(RuntimeError, "argument 1 not provided"):
- ModuleTooFew()
- with self.assertRaisesRegex(RuntimeError, "need 3 values .* found only 2"):
- ModuleTooManyAssign()
- with self.assertRaisesRegex(RuntimeError, "argument 1 not provided."):
- ModuleDefault()
+ for type in [torch.float, torch.double]:
+ m_orig = M(type)
+ m_import = self.getExportImportCopy(m_orig)
+ # check to make sure the storage wasn't resized
+ self.assertTrue(m_orig.param.storage().size() == 25)
+ self.assertEqual(m_orig.foo(), m_import.foo())
+ self.assertTrue(m_orig.foo().dtype == m_import.foo().dtype)
- def test_script_define_order(self):
+ @unittest.skipIf(not RUN_CUDA, "testing cuda tensors require CUDA")
+ def test_script_module_export_tensor_cuda(self):
class M(torch.jit.ScriptModule):
+
def __init__(self):
- pass
+ super(M, self).__init__(False)
+ self.param = torch.nn.Parameter(torch.zeros((5, 5), device='cuda:0').random_())
@torch.jit.script_method
- def call_foo(self, input):
- return self.foo(input)
+ def foo(self):
+ return self.param
- @torch.jit.script_method
- def foo(self, input):
- return input + 1
- m = M()
- self.assertEqual(2, m.call_foo(torch.ones((), dtype=torch.int64)))
+ m_orig = M()
+ m_import = self.getExportImportCopy(m_orig)
+ # check to make sure the storage wasn't resized
+ self.assertTrue(m_orig.param.storage().size() == 25)
+ self.assertTrue(m_import.foo().device == torch.device('cuda:0'))
+ self.assertEqual(m_orig.foo(), m_import.foo())
+ self.assertTrue(m_orig.foo().dtype == m_import.foo().dtype)
- def test_script_define_order_recursive_fail(self):
+ def test_script_module_export_blocks(self):
class M(torch.jit.ScriptModule):
- def __init__(self):
- pass
+ def __init__(self, n, m):
+ super(M, self).__init__()
+ self.weight = torch.nn.Parameter(torch.rand(n, m))
@torch.jit.script_method
- def call_foo(self, input):
- return self.foo(input)
+ def forward(self, input):
+ if bool(input.sum() > 0):
+ output = self.weight.mv(input)
+ else:
+ output = self.weight + input
+ return output
+
+ m_orig = M(200, 200)
+ m_import = self.getExportImportCopy(m_orig)
+
+ t = torch.rand(200)
+ self.assertEqual(m_orig(t), m_import(t))
+
+ def test_script_module_export_shared_storage(self):
+ class M(torch.jit.ScriptModule):
+
+ def __init__(self):
+ super(M, self).__init__(False)
+ self.param1 = torch.nn.Parameter(torch.rand(5, 5))
+ self.param2 = torch.nn.Parameter(self.param1[3])
+ self.param3 = torch.nn.Parameter(torch.rand(5, 5))
+ self.param4 = torch.nn.Parameter(torch.rand(11, 5)[1:6])
@torch.jit.script_method
- def foo(self, input):
- self.call_foo(input)
+ def foo(self):
+ return self.param1 + self.param2 + self.param3 + self.param4
- with self.assertRaisesRegex(RuntimeError, 'called recursively involving'):
- M()
+ m_orig = M()
+ m_import = self.getExportImportCopy(m_orig)
- def test_script_kwargs_fn_call(self):
- class M(torch.jit.ScriptModule):
+ self.assertEqual(m_orig.foo(), m_import.foo())
+ self.assertTrue(m_import.param1.storage().data_ptr() == m_import.param2.storage().data_ptr())
+ self.assertTrue(m_import.param1.storage().data_ptr() != m_import.param3.storage().data_ptr())
+
+ def test_onnx_export_script_module(self):
+ class ModuleToExport(torch.jit.ScriptModule):
def __init__(self):
- pass
+ super(ModuleToExport, self).__init__()
@torch.jit.script_method
- def call_foo(self, input):
- return self.foo(input=input, bar=1)
+ def forward(self, x):
+ y = x - x
+ return x + x
- @torch.jit.script_method
- def foo(self, bar, input):
- # type: (int, Tensor) -> Tensor
- return input + bar
- m = M()
- self.assertEqual(2, m.call_foo(torch.ones((), dtype=torch.int64)))
+ mte = ModuleToExport()
+ outputs = mte(torch.zeros(1, 2, 3))
+ self.assertExpected(torch.onnx.export_to_pretty_string(
+ mte, (torch.zeros(1, 2, 3),), None, verbose=False,
+ example_outputs=outputs))
- @unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
- def test_trace_of_script(self):
- @torch.jit.script
- def foo(a, c):
- b = 0.0
- if bool(a == 0.0):
- b = 1.0
- return b + c
+ def test_onnx_export_script_python_fail(self):
+ class ModuleToInline(torch.jit.ScriptModule):
+ def __init__(self):
+ super(ModuleToInline, self).__init__()
- a = torch.ones(1, dtype=torch.float)
+ def forward(self, x):
+ return torch.neg(x)
- @_trace(torch.zeros(1, dtype=torch.float))
- def use(b):
- return foo(b - 1.0, a) + 1.0
+ class ModuleToExport(torch.jit.ScriptModule):
+ def __init__(self):
+ super(ModuleToExport, self).__init__()
+ self.mod = ModuleToInline()
- # test we propagated shapes through the function
- self.assertTrue("Dynamic" not in str(use.graph))
+ @torch.jit.script_method
+ def forward(self, x):
+ y = self.mod(x)
+ return y + y
- self.assertEqual(3, use(torch.ones(1, dtype=torch.float)))
- self.assertEqual(2, use(torch.zeros(1, dtype=torch.float)))
+ mte = ModuleToExport()
+ outputs = mte(torch.zeros(1, 2, 3))
+ f = io.BytesIO()
+ with self.assertRaisesRegex(RuntimeError, "Couldn't export Python operator"):
+ torch.onnx._export(mte, (torch.zeros(1, 2, 3),), f, verbose=False,
+ example_outputs=outputs)
- def test_if_define(self):
- @torch.jit.script
- def foo(a):
- if bool(a == 0):
- b = 1
- else:
- b = 0
- return b + 1
+ def test_onnx_export_script_inline_trace(self):
+ class ModuleToInline(torch.jit.ScriptModule):
+ def __init__(self):
+ super(ModuleToInline, self).__init__()
- @torch.jit.script
- def foo2(a):
- b = 0
- if bool(a == 0):
- b = 1
- return b + 1
+ def forward(self, x):
+ return torch.neg(x)
- @torch.jit.script
- def foo3(a):
- b = 1
- if bool(a == 0):
- c = 4
- else:
- b = 0
- return b + 1
+ class ModuleToExport(torch.jit.ScriptModule):
+ def __init__(self):
+ super(ModuleToExport, self).__init__()
+ self.mod = torch.jit.trace(ModuleToInline(), torch.zeros(1, 2, 3))
- a = torch.ones(1, dtype=torch.long)
- b = torch.zeros(1, dtype=torch.long)
- self.assertEqual(1, foo(a))
- self.assertEqual(2, foo(b))
- self.assertEqual(1, foo2(a))
- self.assertEqual(2, foo2(b))
- self.assertEqual(1, foo3(a))
- self.assertEqual(2, foo3(b))
+ @torch.jit.script_method
+ def forward(self, x):
+ y = self.mod(x)
+ return y + y
- def test_script_module_export_submodule(self):
- class M1(torch.jit.ScriptModule):
+ mte = ModuleToExport()
+ outputs = mte(torch.zeros(1, 2, 3))
+ self.assertExpected(torch.onnx.export_to_pretty_string(
+ mte, (torch.zeros(1, 2, 3),), None, verbose=False,
+ example_outputs=outputs))
+
+ def test_onnx_export_script_inline_script(self):
+ class ModuleToInline(torch.jit.ScriptModule):
def __init__(self):
- super(M1, self).__init__(False)
- self.weight = nn.Parameter(torch.randn(2))
+ super(ModuleToInline, self).__init__()
@torch.jit.script_method
- def forward(self, thing):
- return self.weight + thing
+ def forward(self, x):
+ return torch.neg(x)
- class M2(torch.jit.ScriptModule):
+ class ModuleToExport(torch.jit.ScriptModule):
def __init__(self):
- super(M2, self).__init__(False)
- # test submodule
- self.sub = M1()
- self.weight = nn.Parameter(torch.randn(2, 3))
- self.bias = nn.Parameter(torch.randn(2))
- self.define("""
- def hi(self, a):
- return self.weight.mm(a)
- """)
+ super(ModuleToExport, self).__init__()
+ self.mod = ModuleToInline()
@torch.jit.script_method
- def doit(self, input):
- return self.weight.mm(input)
+ def forward(self, x):
+ y = self.mod(x)
+ return y + y
- @torch.jit.script_method
- def doit2(self, input):
- return self.weight.mm(input)
+ mte = ModuleToExport()
+ outputs = mte(torch.zeros(1, 2, 3))
+ self.assertExpected(torch.onnx.export_to_pretty_string(
+ mte, (torch.zeros(1, 2, 3),), None, verbose=False,
+ example_outputs=outputs))
- @torch.jit.script_method
- def doit3(self, input):
- return input + torch.ones([1], dtype=torch.double)
+ def test_onnx_export_script_module_loop(self):
+ class ModuleToExport(torch.jit.ScriptModule):
+ def __init__(self):
+ super(ModuleToExport, self).__init__()
@torch.jit.script_method
- def forward(self, input):
- a = self.doit(input)
- b = self.doit2(input)
- c = self.hi(input)
- return a + b + self.bias + c
+ def forward(self, x):
+ # test if we support end to end onnx export on loop and
+ # nested loops with and without loop index
+ for _ in range(5):
+ for i in range(3):
+ x = x + i
+ return x
- m_orig = M2()
- m_import = self.getExportImportCopy(m_orig)
+ mte = ModuleToExport()
+ outputs = mte(torch.zeros(1, 2, 3))
+ self.assertExpected(torch.onnx.export_to_pretty_string(
+ mte, (torch.zeros(1, 2, 3),), None, verbose=False,
+ example_outputs=outputs))
- input = torch.randn(3, 2)
- self.assertEqual(m_orig.doit(input), m_import.doit(input))
- self.assertEqual(m_orig.hi(input), m_import.hi(input))
- self.assertEqual(m_orig.doit3(input), m_import.doit3(input))
- self.assertEqual(m_orig.forward(input), m_import.forward(input))
+ def test_onnx_export_script_truediv(self):
+ class ModuleToExport(torch.jit.ScriptModule):
+ def __init__(self):
+ super(ModuleToExport, self).__init__()
- @skipIfNoTorchVision
- def test_script_module_trace_resnet18(self):
- x = torch.ones(1, 3, 224, 224)
- m_orig = torch.jit.trace(torchvision.models.resnet18(), torch.ones(1, 3, 224, 224))
- m_import = self.getExportImportCopy(m_orig)
+ @torch.jit.script_method
+ def forward(self, x):
+ z = x.size(0) / 2
+ return x + z
- input = torch.randn(1, 3, 224, 224, requires_grad=True)
- output_orig = m_orig(input)
- output_orig.sum().backward()
- grad_orig = input.grad.clone()
- input.grad.zero_()
+ mte = ModuleToExport()
+ outputs = mte(torch.zeros(1, 2, 3))
+ self.assertExpected(torch.onnx.export_to_pretty_string(
+ mte, (torch.zeros(1, 2, 3),), None, verbose=False,
+ example_outputs=outputs))
- output_import = m_import(input)
- output_import.sum().backward()
- grad_import = input.grad.clone()
+ def test_onnx_raw_export_script_truediv(self):
+ class ModuleToExport(torch.jit.ScriptModule):
+ def __init__(self):
+ super(ModuleToExport, self).__init__()
- self.assertEqual(output_orig, output_import)
- self.assertEqual(grad_orig, grad_import)
+ @torch.jit.script_method
+ def forward(self, x):
+ z = x.size(0) / 2
+ return x + z
- @skipIfNoTorchVision
- def test_script_module_script_resnet(self):
- def conv1x1(in_planes, out_planes, stride=1):
- """1x1 convolution"""
- return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
+ mte = ModuleToExport()
+ outputs = mte(torch.zeros(1, 2, 3))
+ self.assertExpected(torch.onnx.export_to_pretty_string(
+ mte, (torch.zeros(1, 2, 3),), None, verbose=False,
+ example_outputs=outputs, export_raw_ir=True))
- def conv3x3(in_planes, out_planes, stride=1):
- """3x3 convolution with padding"""
- return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
- padding=1, bias=False)
+ def test_onnx_export_script_non_alpha_add_sub(self):
+ class ModuleToExport(torch.jit.ScriptModule):
+ def __init__(self):
+ super(ModuleToExport, self).__init__()
- class BasicBlock(torch.jit.ScriptModule):
- expansion = 1
- __constants__ = ['downsample']
+ @torch.jit.script_method
+ def forward(self, x):
+ bs = x.size(0) + 1
+ return bs - 1
+
+ mte = ModuleToExport()
+ outputs = torch.LongTensor([mte(torch.rand(3, 4))])
+ self.assertExpected(torch.onnx.export_to_pretty_string(
+ mte, (torch.rand(3, 4),), None, verbose=False,
+ example_outputs=outputs))
- def __init__(self, inplanes, planes, stride=1, downsample=None):
- super(BasicBlock, self).__init__()
- self.conv1 = conv3x3(inplanes, planes, stride)
- self.bn1 = nn.BatchNorm2d(planes)
- self.relu = nn.ReLU(inplace=True)
- self.conv2 = conv3x3(planes, planes)
- self.bn2 = nn.BatchNorm2d(planes)
- self.downsample = downsample
- self.stride = stride
+ def test_onnx_export_script_module_if(self):
+ class ModuleToExport(torch.jit.ScriptModule):
+ def __init__(self):
+ super(ModuleToExport, self).__init__()
@torch.jit.script_method
def forward(self, x):
- residual = x
+ if bool(torch.sum(x) > 0):
+ x = torch.neg(x)
+ return x
- out = self.conv1(x)
- out = self.bn1(out)
- out = self.relu(out)
+ mte = ModuleToExport()
+ outputs = mte(torch.zeros(1, 2, 3, dtype=torch.long))
+ self.assertExpected(torch.onnx.export_to_pretty_string(
+ mte, (torch.zeros(1, 2, 3),), None, verbose=False,
+ example_outputs=outputs))
- out = self.conv2(out)
- out = self.bn2(out)
+ def test_onnx_export_script_inline_params(self):
+ class ModuleToInline(torch.jit.ScriptModule):
+ def __init__(self):
+ super(ModuleToInline, self).__init__()
+ self.m = torch.nn.Parameter(torch.ones(3, 3))
+ self.unused = torch.nn.Parameter(torch.ones(1, 2, 3))
- if self.downsample is not None:
- residual = self.downsample(x)
+ @torch.jit.script_method
+ def forward(self, x):
+ return torch.mm(x, self.m)
- out += residual
- out = self.relu(out)
+ class ModuleToExport(torch.jit.ScriptModule):
+ def __init__(self):
+ super(ModuleToExport, self).__init__()
+ self.mod = ModuleToInline()
+ self.param = torch.nn.Parameter(torch.ones(3, 4))
- return out
+ @torch.jit.script_method
+ def forward(self, x):
+ y = self.mod(x)
+ return torch.mm(y, self.param)
- class ResNet(torch.jit.ScriptModule):
- __constants__ = ['layer1', 'layer2', 'layer3', 'layer4']
+ mte = ModuleToExport()
+ result = mte(torch.zeros(2, 3))
+ reference = torch.mm(torch.mm(torch.zeros(2, 3), torch.ones(3, 3)), torch.ones(3, 4))
+ self.assertEqual(result, reference)
+ self.assertExpected(torch.onnx.export_to_pretty_string(
+ mte, (torch.ones(2, 3),), None, verbose=False,
+ example_outputs=result, propagate=True))
- def __init__(self, block, layers, num_classes=1000):
- super(ResNet, self).__init__()
- self.inplanes = 64
- self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
- bias=False)
- self.bn1 = nn.BatchNorm2d(64)
- self.relu = nn.ReLU(inplace=True)
- self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
- self.layer1 = self._make_layer(block, 64, layers[0])
- self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
- self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
- self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
- self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
- self.fc = nn.Linear(512 * block.expansion, num_classes)
+ def test_trace_with_size(self):
+ @_trace(torch.zeros(1, 1))
+ def foo(x):
+ return x + 1
- for m in self.modules():
- if isinstance(m, nn.Conv2d):
- nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
- elif isinstance(m, nn.BatchNorm2d):
- nn.init.constant_(m.weight, 1)
- nn.init.constant_(m.bias, 0)
+ @torch.jit.script
+ def bar(x):
+ y = int(foo(x))
+ if True:
+ y = 7
+ return y + 1
- def _make_layer(self, block, planes, blocks, stride=1):
- downsample = None
- if stride != 1 or self.inplanes != planes * block.expansion:
- downsample = nn.Sequential(
- conv1x1(self.inplanes, planes * block.expansion, stride),
- nn.BatchNorm2d(planes * block.expansion),
- )
+ self.assertEqual(8, bar(torch.ones(1, 1)))
- layers = []
- layers.append(block(self.inplanes, planes, stride, downsample))
- self.inplanes = planes * block.expansion
- for _ in range(1, blocks):
- layers.append(block(self.inplanes, planes))
+ def test_tracing_slicing(self):
+ @_trace(torch.zeros(10))
+ def foo_trace(x):
+ return x[-5:-3]
- return nn.Sequential(*layers)
+ @torch.jit.script
+ def foo_script(x):
+ return x[-5:-3]
- @torch.jit.script_method
- def forward(self, x):
- x = self.conv1(x)
- x = self.bn1(x)
- x = self.relu(x)
- x = self.maxpool(x)
+ def foo(x):
+ return x[-5:-3]
- x = self.layer1(x)
- x = self.layer2(x)
- x = self.layer3(x)
- x = self.layer4(x)
+ a = torch.arange(0, 8)
+ b = torch.arange(0, 20)
+ self.assertEqual(foo_trace(a), foo_script(a))
+ self.assertEqual(foo_trace(a), foo(a))
+ self.assertNotEqual(foo_trace(a), foo_trace(b))
- x = self.avgpool(x)
- x = x.view(x.size(0), -1)
- x = self.fc(x)
+ def test_tracing_indexing(self):
+ @_trace(torch.zeros(10))
+ def foo_trace(x):
+ return x[-2]
- return x
+ @torch.jit.script
+ def foo_script(x):
+ return x[-2]
- resnet18 = ResNet(BasicBlock, [2, 2, 2, 2])
+ def foo(x):
+ return x[-2]
- resnet18_imported = self.getExportImportCopy(resnet18)
+ a = torch.arange(0, 8)
+ b = torch.arange(0, 20)
+ self.assertEqual(foo_script(a), foo_trace(a))
+ self.assertEqual(foo_trace(a), foo(a))
+ self.assertNotEqual(foo_trace(a), foo_trace(b))
- input = torch.randn(1, 3, 224, 224, requires_grad=True)
- output_orig = resnet18(input)
- output_orig.sum().backward()
- grad_orig = input.grad.clone()
- input.grad.zero_()
+ def test_index_select_shape_prop(self):
- output_import = resnet18_imported(input)
- output_import.sum().backward()
- grad_import = input.grad.clone()
+ @torch.jit.script
+ def foo(x, y):
+ return torch.index_select(x, index=y, dim=1)
- self.assertEqual(output_orig, output_import)
- self.assertEqual(grad_orig, grad_import)
+ a = torch.zeros(2, 2)
+ b = torch.zeros(4, dtype=torch.long)
+ torch._C._jit_pass_complete_shape_analysis(foo.graph, (a, b), False)
+ self.assertExpected(canonical(foo.graph))
- def test_script_module_export_tensor_type(self):
- class M(torch.jit.ScriptModule):
+ def test_onnx_export_speculate(self):
- def __init__(self, type):
- super(M, self).__init__(False)
- self.param = torch.nn.Parameter(torch.zeros((5, 5), dtype=type).random_())
+ class Foo(torch.jit.ScriptModule):
+ def __init__(self, m):
+ super(Foo, self).__init__()
+ self.m = m
@torch.jit.script_method
- def foo(self):
- return self.param
+ def forward(self, x):
+ x += x
+ # because we are testing if we emit `if` statement correctly
+ # we cannot use `True` as the condition. Constant prop
+ # would remove the `if` statements.
+ c = torch.sum(x) > 4
+ if bool(c):
+ if bool(c):
+ y = self.m(x)
+ else:
+ y = self.m(x)
+ else:
+ y = self.m(x)
+ return y
- for type in [torch.float, torch.double]:
- m_orig = M(type)
- m_import = self.getExportImportCopy(m_orig)
- # check to make sure the storage wasn't resized
- self.assertTrue(m_orig.param.storage().size() == 25)
- self.assertEqual(m_orig.foo(), m_import.foo())
- self.assertTrue(m_orig.foo().dtype == m_import.foo().dtype)
+ linear = torch.jit.trace(nn.Linear(10, 20).float(), torch.zeros(1, 10, dtype=torch.float))
- @unittest.skipIf(not RUN_CUDA, "testing cuda tensors require CUDA")
- def test_script_module_export_tensor_cuda(self):
- class M(torch.jit.ScriptModule):
+ @torch.jit.script
+ def transpose(x):
+ return x.t()
- def __init__(self):
- super(M, self).__init__(False)
- self.param = torch.nn.Parameter(torch.zeros((5, 5), device='cuda:0').random_())
+ f1 = Foo(transpose)
+ outputs_f1 = f1(torch.ones(1, 10, dtype=torch.float))
+ f2 = Foo(linear)
+ outputs_f2 = f2(torch.ones(1, 10, dtype=torch.float))
- @torch.jit.script_method
- def foo(self):
- return self.param
+ onnx_ish = torch.onnx.export_to_pretty_string(
+ f1,
+ (torch.ones(1, 10, dtype=torch.float), ),
+ None, verbose=False, example_outputs=outputs_f1)
+ self.assertExpected(onnx_ish, subname='f1')
+ onnx_ish = torch.onnx.export_to_pretty_string(
+ f2,
+ (torch.ones(1, 10, dtype=torch.float), ),
+ None, verbose=False, example_outputs=outputs_f2)
+ self.assertExpected(onnx_ish, subname='f2')
+
+ def test_onnx_export_shape_reshape(self):
+ class Foo(torch.nn.Module):
+ def forward(self, x):
+ import torch.onnx.operators
+ x = x.repeat(5, 1, 1)
+ shape = torch.onnx.operators.shape_as_tensor(x)
+ reshaped = torch.onnx.operators.reshape_from_tensor_shape(x, shape)
+ return reshaped
+
+ foo = torch.jit.trace(Foo(), torch.zeros(1, 2, 3))
+ outputs = foo(torch.zeros(1, 2, 3))
+ f = io.BytesIO()
+ s = torch.onnx.export_to_pretty_string(foo, (torch.zeros(1, 2, 3)), f,
+ example_outputs=outputs)
+ self.assertExpected(s)
+
+ def test_shape_analysis_loop(self):
+ def foo(a, b, x):
+ c = a
+ # on the first iteration of the loop it appears that
+ # c should have a expand to the size of b
+ # but on the second+ iterations, there is no broadcast and the
+ # sizes are different.
+ # previously this would cause the compiler to (1) enter an infinite
+ # loop trying to compute the shape, and (2) insert invalid
+ # broadcasts.
+ # this test ensure we don't regress on these issues
+ for _ in range(2):
+ a = c + b
+ c = x
+ b = x
+ return a
- m_orig = M()
- m_import = self.getExportImportCopy(m_orig)
- # check to make sure the storage wasn't resized
- self.assertTrue(m_orig.param.storage().size() == 25)
- self.assertTrue(m_import.foo().device == torch.device('cuda:0'))
- self.assertEqual(m_orig.foo(), m_import.foo())
- self.assertTrue(m_orig.foo().dtype == m_import.foo().dtype)
+ self.checkScript(foo, (torch.zeros(1), torch.zeros(4), torch.zeros(5)), optimize=False)
- def test_script_module_export_blocks(self):
- class M(torch.jit.ScriptModule):
- def __init__(self, n, m):
- super(M, self).__init__()
- self.weight = torch.nn.Parameter(torch.rand(n, m))
+ def test_intlist_args(self):
+ def func_1(x):
+ return torch.nn.functional.adaptive_avg_pool1d(x, 1)
- @torch.jit.script_method
- def forward(self, input):
- if bool(input.sum() > 0):
- output = self.weight.mv(input)
- else:
- output = self.weight + input
- return output
+ def func_2(x):
+ return torch.nn.functional.adaptive_avg_pool1d(x, output_size=1)
- m_orig = M(200, 200)
- m_import = self.getExportImportCopy(m_orig)
+ def func_3(x):
+ return torch.nn.functional.adaptive_avg_pool1d(x, output_size=[1])
- t = torch.rand(200)
- self.assertEqual(m_orig(t), m_import(t))
+ x = torch.randn(8, 8, 8)
+ self.checkScript(func_1, [x], optimize=True)
+ self.checkScript(func_2, [x], optimize=True)
+ self.checkScript(func_3, [x], optimize=True)
- def test_script_module_export_shared_storage(self):
- class M(torch.jit.ScriptModule):
+ def test_wrong_implicit_expand(self):
- def __init__(self):
- super(M, self).__init__(False)
- self.param1 = torch.nn.Parameter(torch.rand(5, 5))
- self.param2 = torch.nn.Parameter(self.param1[3])
- self.param3 = torch.nn.Parameter(torch.rand(5, 5))
- self.param4 = torch.nn.Parameter(torch.rand(11, 5)[1:6])
+ @_trace(torch.zeros(3), torch.zeros(1))
+ def foo(a, b):
+ return a + b
- @torch.jit.script_method
- def foo(self):
- return self.param1 + self.param2 + self.param3 + self.param4
+ a = torch.rand(4)
+ b = torch.rand(4)
+ self.assertEqual(a + b, foo(a, b))
- m_orig = M()
- m_import = self.getExportImportCopy(m_orig)
+ def test_builtin_args_fails(self):
- self.assertEqual(m_orig.foo(), m_import.foo())
- self.assertTrue(m_import.param1.storage().data_ptr() == m_import.param2.storage().data_ptr())
- self.assertTrue(m_import.param1.storage().data_ptr() != m_import.param3.storage().data_ptr())
+ with self.assertRaisesRegex(RuntimeError, 'expected at most'):
+ @torch.jit.script
+ def f0(a):
+ torch.sum(a, a, a, a)
- def test_onnx_export_script_module(self):
- class ModuleToExport(torch.jit.ScriptModule):
- def __init__(self):
- super(ModuleToExport, self).__init__()
+ with self.assertRaisesRegex(RuntimeError, 'argument self not provided'):
+ @torch.jit.script
+ def f1(a):
+ torch.sum(foo=4)
- @torch.jit.script_method
- def forward(self, x):
- y = x - x
- return x + x
+ with self.assertRaisesRegex(RuntimeError, 'specified twice'):
+ @torch.jit.script
+ def f2(a):
+ torch.sum(a, self=a)
- mte = ModuleToExport()
- outputs = mte(torch.zeros(1, 2, 3))
- self.assertExpected(torch.onnx.export_to_pretty_string(
- mte, (torch.zeros(1, 2, 3),), None, verbose=False,
- example_outputs=outputs))
+ with self.assertRaisesRegex(RuntimeError, 'not provided'):
+ @torch.jit.script
+ def f3(a):
+ torch.sum(dim=4)
- def test_onnx_export_script_python_fail(self):
- class ModuleToInline(torch.jit.ScriptModule):
- def __init__(self):
- super(ModuleToInline, self).__init__()
+ with self.assertRaisesRegex(RuntimeError, 'for argument \'tensors\' but found Tensor'):
+ @torch.jit.script
+ def f4(a):
+ torch.cat(a)
- def forward(self, x):
- return torch.neg(x)
+ with self.assertRaisesRegex(RuntimeError, r'argument \'tensors\' but found int\[\]'):
+ @torch.jit.script
+ def f5(a):
+ torch.cat([3])
- class ModuleToExport(torch.jit.ScriptModule):
- def __init__(self):
- super(ModuleToExport, self).__init__()
- self.mod = ModuleToInline()
+ with self.assertRaisesRegex(RuntimeError, 'Lists must contain only a single type'):
+ @torch.jit.script
+ def f6(a):
+ a.expand(size=[3, [4]])
- @torch.jit.script_method
- def forward(self, x):
- y = self.mod(x)
- return y + y
+ with self.assertRaisesRegex(RuntimeError, 'xpected a value of type Tensor for argument \'self\''):
+ @torch.jit.script
+ def f7(a):
+ torch.sum([4])
- mte = ModuleToExport()
- outputs = mte(torch.zeros(1, 2, 3))
- f = io.BytesIO()
- with self.assertRaisesRegex(RuntimeError, "Couldn't export Python operator"):
- torch.onnx._export(mte, (torch.zeros(1, 2, 3),), f, verbose=False,
- example_outputs=outputs)
+ def test_builtin_args(self):
- def test_onnx_export_script_inline_trace(self):
- class ModuleToInline(torch.jit.ScriptModule):
- def __init__(self):
- super(ModuleToInline, self).__init__()
+ def t0(a):
+ # default arg dim
+ return torch.cat([a, a])
- def forward(self, x):
- return torch.neg(x)
+ self.checkScript(t0, (torch.zeros(1, 1),))
- class ModuleToExport(torch.jit.ScriptModule):
- def __init__(self):
- super(ModuleToExport, self).__init__()
- self.mod = torch.jit.trace(ModuleToInline(), torch.zeros(1, 2, 3))
+ def t1(a):
+ # keywords out of order
+ return torch.cat(dim=1, tensors=[a, a])
- @torch.jit.script_method
- def forward(self, x):
- y = self.mod(x)
- return y + y
+ self.checkScript(t1, (torch.zeros(1, 1, 2),))
- mte = ModuleToExport()
- outputs = mte(torch.zeros(1, 2, 3))
- self.assertExpected(torch.onnx.export_to_pretty_string(
- mte, (torch.zeros(1, 2, 3),), None, verbose=False,
- example_outputs=outputs))
+ def t2(a):
+ # mix const/non-const attributes
+ if True:
+ b = 1
+ else:
+ b = 0
+ return torch.sum(a, dim=b, keepdim=False)
- def test_onnx_export_script_inline_script(self):
- class ModuleToInline(torch.jit.ScriptModule):
- def __init__(self):
- super(ModuleToInline, self).__init__()
+ self.checkScript(t2, (torch.zeros(1, 1, 2),))
- @torch.jit.script_method
- def forward(self, x):
- return torch.neg(x)
+ def test_parser_type_annotations(self):
+ cu = torch.jit.CompilationUnit('''
+ def foo(x : Tensor, y : Tuple[Tuple[Tensor, Tensor], Tensor]) -> Tuple[Tensor, Tensor]:
+ return x, x
+ ''')
- class ModuleToExport(torch.jit.ScriptModule):
- def __init__(self):
- super(ModuleToExport, self).__init__()
- self.mod = ModuleToInline()
+ self.assertExpected(cu.__getattr__('foo').pretty_print_schema())
- @torch.jit.script_method
- def forward(self, x):
- y = self.mod(x)
- return y + y
+ def test_parser_type_annotations_comment(self):
+ cu = torch.jit.CompilationUnit('''
+ def foo(x, y):
+ # type: (Tensor, Tuple[Tuple[Tensor, Tensor], Tensor]) -> Tuple[Tensor, Tensor]
+ return x, x
+ ''')
- mte = ModuleToExport()
- outputs = mte(torch.zeros(1, 2, 3))
- self.assertExpected(torch.onnx.export_to_pretty_string(
- mte, (torch.zeros(1, 2, 3),), None, verbose=False,
- example_outputs=outputs))
+ self.assertExpected(cu.__getattr__('foo').pretty_print_schema())
- def test_onnx_export_script_module_loop(self):
- class ModuleToExport(torch.jit.ScriptModule):
- def __init__(self):
- super(ModuleToExport, self).__init__()
+ def test_parser_type_annotations_unknown_type(self):
+ with self.assertRaisesRegex(RuntimeError, r'Unknown type name Foo'):
+ cu = torch.jit.CompilationUnit('''
+ def foo(x : Tensor, y : Tuple[Tuple[Foo, Tensor], Tensor]) -> Tuple[Tensor, Tensor]:
+ return x, x
+ ''')
- @torch.jit.script_method
- def forward(self, x):
- # test if we support end to end onnx export on loop and
- # nested loops with and without loop index
- for _ in range(5):
- for i in range(3):
- x = x + i
- return x
+ def test_parser_type_annotations_subscript_non_ident(self):
+ with self.assertRaisesRegex(RuntimeError, r'Subscripted type must be a type identifier'):
+ cu = torch.jit.CompilationUnit('''
+ def foo(x : Tensor, y : Tuple[Tensor, Tensor][Tensor]) -> Tuple[Tensor, Tensor]:
+ return x, x
+ ''')
- mte = ModuleToExport()
- outputs = mte(torch.zeros(1, 2, 3))
- self.assertExpected(torch.onnx.export_to_pretty_string(
- mte, (torch.zeros(1, 2, 3),), None, verbose=False,
- example_outputs=outputs))
+ def test_parser_type_annotations_subscript_tensor(self):
+ with self.assertRaisesRegex(RuntimeError, r'Unknown type constructor Tensor'):
+ cu = torch.jit.CompilationUnit('''
+ def foo(x : Tensor, y : Tensor[Tensor, Tensor]) -> Tuple[Tensor, Tensor]:
+ return x, x
+ ''')
- def test_onnx_export_script_truediv(self):
- class ModuleToExport(torch.jit.ScriptModule):
- def __init__(self):
- super(ModuleToExport, self).__init__()
+ def test_parser_type_annotations_incompatible_expression(self):
+ with self.assertRaisesRegex(RuntimeError, r'Expression of type \+ cannot be used in a type expression'):
+ cu = torch.jit.CompilationUnit('''
+ def foo(x : Tensor, y : Tuple[3 + 4, Tensor]) -> Tuple[Tensor, Tensor]:
+ return x, x
+ ''')
- @torch.jit.script_method
- def forward(self, x):
- z = x.size(0) / 2
- return x + z
+ def test_gather_dynamic_index(self):
+ def t(x):
+ gather1 = x[0]
+ idx = 0 + 1
+ gather2 = x[idx]
+ return gather1 + gather2
- mte = ModuleToExport()
- outputs = mte(torch.zeros(1, 2, 3))
- self.assertExpected(torch.onnx.export_to_pretty_string(
- mte, (torch.zeros(1, 2, 3),), None, verbose=False,
- example_outputs=outputs))
+ self.checkScript(t, (torch.zeros(3, 2, 3),))
- def test_onnx_raw_export_script_truediv(self):
- class ModuleToExport(torch.jit.ScriptModule):
- def __init__(self):
- super(ModuleToExport, self).__init__()
+ def test_slice_dynamic_index(self):
+ def t(x):
+ slice1 = x[0:1]
+ zero = 0
+ one = zero + 1
+ slice2 = x[zero:one]
+ return slice1 + slice2
- @torch.jit.script_method
- def forward(self, x):
- z = x.size(0) / 2
- return x + z
+ self.checkScript(t, (torch.zeros(3, 2, 3),))
- mte = ModuleToExport()
- outputs = mte(torch.zeros(1, 2, 3))
- self.assertExpected(torch.onnx.export_to_pretty_string(
- mte, (torch.zeros(1, 2, 3),), None, verbose=False,
- example_outputs=outputs, export_raw_ir=True))
+ def test_addmm_grad(self):
+ """ This test checks several things:
+ 1. An expand node was inserted before the addmm operating on the
+ bias term.
+ 2. The fused form of addmm appears in the ultimate graph that's
+ executed.
+ 3. A sum op was emitted for accumulating gradients along the 0th
+ (expanded) dimension of the bias term.
+ 4. The correct symbolic representation for the backward pass of the
+ mm operator was emitted (x.t() -> mm)
- def test_onnx_export_script_non_alpha_add_sub(self):
- class ModuleToExport(torch.jit.ScriptModule):
- def __init__(self):
- super(ModuleToExport, self).__init__()
+ TODO: we should actually check these conditions once we have a way
+ to dump the GraphExecutor state. Namely the processed forward graph
+ and the backward graph.
+ """
+ @torch.jit.script
+ def addmm_grad_test(b, x, w):
+ return torch.addmm(b, x, w)
- @torch.jit.script_method
- def forward(self, x):
- bs = x.size(0) + 1
- return bs - 1
+ # Initialize param and input values
+ w_init = torch.rand(2, 5)
+ b_init = torch.rand(5)
+ x = torch.rand(3, 2)
- mte = ModuleToExport()
- outputs = torch.LongTensor([mte(torch.rand(3, 4))])
- self.assertExpected(torch.onnx.export_to_pretty_string(
- mte, (torch.rand(3, 4),), None, verbose=False,
- example_outputs=outputs))
+ # Clone trainable params
+ b = b_init.clone()
+ b.requires_grad_()
+ w = w_init.clone()
+ w.requires_grad_()
- def test_onnx_export_script_module_if(self):
- class ModuleToExport(torch.jit.ScriptModule):
- def __init__(self):
- super(ModuleToExport, self).__init__()
+ # Test symbolic differentiation
+ y = addmm_grad_test(b, x, w)
+ y.sum().backward()
- @torch.jit.script_method
- def forward(self, x):
- if bool(torch.sum(x) > 0):
- x = torch.neg(x)
- return x
+ # clone params for autograd reference
+ b_ref = b_init.clone()
+ b_ref.requires_grad_()
+ w_ref = w_init.clone()
+ w_ref.requires_grad_()
+ y_ref = torch.addmm(b_ref, x, w_ref)
+ y_ref.sum().backward()
- mte = ModuleToExport()
- outputs = mte(torch.zeros(1, 2, 3, dtype=torch.long))
- self.assertExpected(torch.onnx.export_to_pretty_string(
- mte, (torch.zeros(1, 2, 3),), None, verbose=False,
- example_outputs=outputs))
+ self.assertEqual(w.grad, w_ref.grad)
+ self.assertEqual(b.grad, b_ref.grad)
+
+ def test_zeros(self):
+ class M(torch.jit.ScriptModule):
+ __constants__ = ['d']
- def test_onnx_export_script_inline_params(self):
- class ModuleToInline(torch.jit.ScriptModule):
def __init__(self):
- super(ModuleToInline, self).__init__()
- self.m = torch.nn.Parameter(torch.ones(3, 3))
- self.unused = torch.nn.Parameter(torch.ones(1, 2, 3))
+ self.d = torch.device('cpu')
@torch.jit.script_method
- def forward(self, x):
- return torch.mm(x, self.m)
+ def create(self):
+ return torch.zeros([1, 1, 2], dtype=torch.float, device=self.d, layout=torch.strided)
- class ModuleToExport(torch.jit.ScriptModule):
- def __init__(self):
- super(ModuleToExport, self).__init__()
- self.mod = ModuleToInline()
- self.param = torch.nn.Parameter(torch.ones(3, 4))
+ r = M().create()
+ self.assertEqual(r.dtype, torch.float)
+ self.assertEqual(torch.zeros([1, 1, 2], dtype=torch.float), r)
- @torch.jit.script_method
- def forward(self, x):
- y = self.mod(x)
- return torch.mm(y, self.param)
+ def test_vararg_zeros(self):
+ def foo():
+ return torch.zeros(3, 4, 5, dtype=torch.int)
- mte = ModuleToExport()
- result = mte(torch.zeros(2, 3))
- reference = torch.mm(torch.mm(torch.zeros(2, 3), torch.ones(3, 3)), torch.ones(3, 4))
- self.assertEqual(result, reference)
- self.assertExpected(torch.onnx.export_to_pretty_string(
- mte, (torch.ones(2, 3),), None, verbose=False,
- example_outputs=result, propagate=True))
+ self.checkScript(foo, ())
- def test_trace_with_size(self):
- @_trace(torch.zeros(1, 1))
- def foo(x):
- return x + 1
+ def test_rand(self):
+ def test_rand():
+ a = torch.rand([3, 4])
+ return a + 1.0 - a
- @torch.jit.script
- def bar(x):
- y = int(foo(x))
- if True:
- y = 7
- return y + 1
+ self.checkScript(test_rand, ())
- self.assertEqual(8, bar(torch.ones(1, 1)))
+ def test_erase_number_types(self):
+ def func(a):
+ b = 7 + 1 + 3
+ c = a + b
+ c += b
+ return c
- def test_tracing_slicing(self):
- @_trace(torch.zeros(10))
- def foo_trace(x):
- return x[-5:-3]
+ graph = torch.jit.script(func).graph
+ self.run_pass('remove_inplace_ops', graph)
+ self.run_pass('erase_number_types', graph)
+ self.assertExpectedGraph(graph)
- @torch.jit.script
- def foo_script(x):
- return x[-5:-3]
+ def test_mm_batching(self):
+ lstm_cell = torch.jit.script(LSTMCellS)
- def foo(x):
- return x[-5:-3]
+ def lstm(x, hx, cx, w_ih, w_hh, b_ih, b_hh):
+ for i in range(x.size(0)):
+ hx, cx = lstm_cell(x[i], hx, cx, w_ih, w_hh, b_ih, b_hh)
+ return hx
- a = torch.arange(0, 8)
- b = torch.arange(0, 20)
- self.assertEqual(foo_trace(a), foo_script(a))
- self.assertEqual(foo_trace(a), foo(a))
- self.assertNotEqual(foo_trace(a), foo_trace(b))
+ slstm = torch.jit.script(lstm)
- def test_tracing_indexing(self):
- @_trace(torch.zeros(10))
- def foo_trace(x):
- return x[-2]
+ inputs = get_lstm_inputs('cpu', training=True, seq_length=10)
+ slstm(*inputs).sum().backward()
- @torch.jit.script
- def foo_script(x):
- return x[-2]
+ fw_graph = slstm.graph_for(*inputs)
+ bw_graph = backward_graph(slstm, diff_graph_idx=0)
+ self.assertTrue('prim::MMBatchSide' in str(fw_graph))
+ self.assertTrue('prim::MMTreeReduce' in str(bw_graph))
- def foo(x):
- return x[-2]
+ sout = slstm(*inputs)
+ out = lstm(*inputs)
+ self.assertEqual(slstm(*inputs), lstm(*inputs))
+ self.assertEqual(torch.autograd.grad(slstm(*inputs).sum(), inputs),
+ torch.autograd.grad(lstm(*inputs).sum(), inputs))
- a = torch.arange(0, 8)
- b = torch.arange(0, 20)
- self.assertEqual(foo_script(a), foo_trace(a))
- self.assertEqual(foo_trace(a), foo(a))
- self.assertNotEqual(foo_trace(a), foo_trace(b))
+ def test_loop_unrolling(self):
+ def fn(x):
+ y = 0
+ for i in range(int(x)):
+ y += i
+ return y
- def test_index_select_shape_prop(self):
+ graph = torch.jit.script(fn).graph
+ self.run_pass('loop_unrolling', graph)
+ self.assertExpectedGraph(graph)
+ self.checkScript(fn, (torch.tensor(10),))
- @torch.jit.script
- def foo(x, y):
- return torch.index_select(x, index=y, dim=1)
+ def test_loop_unrolling_const(self):
+ def fn():
+ y = 0
+ for i in range(10):
+ y += 1
+ return y
- a = torch.zeros(2, 2)
- b = torch.zeros(4, dtype=torch.long)
- torch._C._jit_pass_complete_shape_analysis(foo.graph, (a, b), False)
- self.assertExpected(canonical(foo.graph))
+ def fn2():
+ y = 0
+ for i in range(10):
+ y += i
+ return y
- def test_onnx_export_speculate(self):
+ def check(fn, name):
+ graph = torch.jit.script(fn).graph
+ self.run_pass('loop_unrolling', graph)
+ self.assertExpectedGraph(graph, subname=name)
+ self.checkScript(fn, ())
- class Foo(torch.jit.ScriptModule):
- def __init__(self, m):
- super(Foo, self).__init__()
- self.m = m
+ check(fn, 'add_const')
+ check(fn2, 'add_iter')
- @torch.jit.script_method
- def forward(self, x):
- x += x
- # because we are testing if we emit `if` statement correctly
- # we cannot use `True` as the condition. Constant prop
- # would remove the `if` statements.
- c = torch.sum(x) > 4
- if bool(c):
- if bool(c):
- y = self.m(x)
- else:
- y = self.m(x)
- else:
- y = self.m(x)
- return y
+ def test_loop_unrolling_nested(self):
+ def fn(x):
+ y = 0
+ for i in range(10):
+ for j in range(int(x)):
+ y += j
+ return y
- linear = torch.jit.trace(nn.Linear(10, 20).float(), torch.zeros(1, 10, dtype=torch.float))
+ graph = torch.jit.script(fn).graph
+ self.run_pass('loop_unrolling', graph)
+ self.assertExpectedGraph(graph)
+ self.checkScript(fn, (torch.tensor(10),))
- @torch.jit.script
- def transpose(x):
- return x.t()
+ def test_loop_unroll_unused_counter(self):
+ def fn(x):
+ y = 0
+ for i in range(int(x)):
+ y += 1
+ return y
- f1 = Foo(transpose)
- outputs_f1 = f1(torch.ones(1, 10, dtype=torch.float))
- f2 = Foo(linear)
- outputs_f2 = f2(torch.ones(1, 10, dtype=torch.float))
+ graph = torch.jit.script(fn).graph
+ self.run_pass('loop_unrolling', graph)
+ self.assertExpectedGraph(graph)
- onnx_ish = torch.onnx.export_to_pretty_string(
- f1,
- (torch.ones(1, 10, dtype=torch.float), ),
- None, verbose=False, example_outputs=outputs_f1)
- self.assertExpected(onnx_ish, subname='f1')
- onnx_ish = torch.onnx.export_to_pretty_string(
- f2,
- (torch.ones(1, 10, dtype=torch.float), ),
- None, verbose=False, example_outputs=outputs_f2)
- self.assertExpected(onnx_ish, subname='f2')
+ def test_loop_unroll_negative(self):
+ def fn(x):
+ y = 0
+ for i in range(int(x)):
+ y += 1
+ return y
- def test_onnx_export_shape_reshape(self):
- class Foo(torch.nn.Module):
- def forward(self, x):
- import torch.onnx.operators
- x = x.repeat(5, 1, 1)
- shape = torch.onnx.operators.shape_as_tensor(x)
- reshaped = torch.onnx.operators.reshape_from_tensor_shape(x, shape)
- return reshaped
+ self.checkScript(fn, (torch.tensor(-20),))
+ self.checkScript(fn, (torch.tensor(-2),))
+ self.checkScript(fn, (torch.tensor(-1),))
+ self.checkScript(fn, (torch.tensor(0),))
+ self.checkScript(fn, (torch.tensor(1),))
+ self.checkScript(fn, (torch.tensor(2),))
- foo = torch.jit.trace(Foo(), torch.zeros(1, 2, 3))
- outputs = foo(torch.zeros(1, 2, 3))
- f = io.BytesIO()
- s = torch.onnx.export_to_pretty_string(foo, (torch.zeros(1, 2, 3)), f,
- example_outputs=outputs)
- self.assertExpected(s)
+ def test_where(self):
+ def fn(x, y):
+ return torch.where(x > 0.0, x, y)
- def test_shape_analysis_loop(self):
- def foo(a, b, x):
- c = a
- # on the first iteration of the loop it appears that
- # c should have a expand to the size of b
- # but on the second+ iterations, there is no broadcast and the
- # sizes are different.
- # previously this would cause the compiler to (1) enter an infinite
- # loop trying to compute the shape, and (2) insert invalid
- # broadcasts.
- # this test ensure we don't regress on these issues
- for _ in range(2):
- a = c + b
- c = x
- b = x
- return a
+ self.checkScript(fn, (torch.randn(3, 2, dtype=torch.float), torch.ones(3, 2, dtype=torch.float)))
- self.checkScript(foo, (torch.zeros(1), torch.zeros(4), torch.zeros(5)), optimize=False)
+ def test_where_method(self):
+ def fn(x, y):
+ return x.where(x > 0.0, y)
- def test_intlist_args(self):
- def func_1(x):
- return torch.nn.functional.adaptive_avg_pool1d(x, 1)
+ self.checkScript(fn, (torch.randn(3, 2, dtype=torch.float), torch.ones(3, 2, dtype=torch.float)))
- def func_2(x):
- return torch.nn.functional.adaptive_avg_pool1d(x, output_size=1)
+ def test_reassign_module_lhs(self):
+ with self.assertRaisesRegex(RuntimeError, 'Cannot re-assign \'self\' because it has type value and self is'
+ ' not a first-class value. Only reassignments to first-class values are allowed'):
+ class ReassignSelfLHS(torch.jit.ScriptModule):
+ @torch.jit.script_method
+ def forward(self, x):
+ for i in range(20):
+ self = x
+ return self
- def func_3(x):
- return torch.nn.functional.adaptive_avg_pool1d(x, output_size=[1])
+ ReassignSelfLHS()
- x = torch.randn(8, 8, 8)
- self.checkScript(func_1, [x], optimize=True)
- self.checkScript(func_2, [x], optimize=True)
- self.checkScript(func_3, [x], optimize=True)
+ def test_reassign_module_rhs(self):
+ with self.assertRaisesRegex(RuntimeError, 'Cannot re-assign \'x\' to a value of type module because x is not a'
+ ' first-class value. Only reassignments to first-class values are allowed'):
+ class ReassignSelfRHS(torch.jit.ScriptModule):
+ @torch.jit.script_method
+ def forward(self, x):
+ for i in range(20):
+ x = self
+ return self
- def test_wrong_implicit_expand(self):
+ ReassignSelfRHS()
- @_trace(torch.zeros(3), torch.zeros(1))
- def foo(a, b):
- return a + b
+ def test_unknown_builtin(self):
+ with self.assertRaisesRegex(RuntimeError, 'unknown builtin op'):
+ @torch.jit.script
+ def unknown_builtin(x):
+ return x.splork(3)
- a = torch.rand(4)
- b = torch.rand(4)
- self.assertEqual(a + b, foo(a, b))
+ def test_return_tuple(self):
+ def return_tuple(x):
+ a = (x, x)
+ return a, x
+ self.checkScript(return_tuple, (torch.rand(4),))
- def test_builtin_args_fails(self):
+ def test_method_no_self(self):
+ with self.assertRaisesRegex(RuntimeError, 'methods must have a self argument'):
+ class MethodNoSelf(torch.jit.ScriptModule):
+ @torch.jit.script_method
+ def forward():
+ return torch.zeros(3, 4)
- with self.assertRaisesRegex(RuntimeError, 'expected at most'):
- @torch.jit.script
- def f0(a):
- torch.sum(a, a, a, a)
+ MethodNoSelf()
- with self.assertRaisesRegex(RuntimeError, 'argument self not provided'):
+ def test_return_stmt_not_at_end(self):
+ with self.assertRaisesRegex(RuntimeError, 'return statements can appear only at the end of the function body'):
@torch.jit.script
- def f1(a):
- torch.sum(foo=4)
+ def return_stmt_wrong(x):
+ if bool(x > 3):
+ return 3
+ else:
+ return x
- with self.assertRaisesRegex(RuntimeError, 'specified twice'):
+ def test_for_range_no_arg(self):
+ with self.assertRaisesRegex(RuntimeError, r'range\(\) expects 1 argument but got 0'):
@torch.jit.script
- def f2(a):
- torch.sum(a, self=a)
+ def range_no_arg(x):
+ for i in range():
+ x += 1
+ return x
- with self.assertRaisesRegex(RuntimeError, 'not provided'):
- @torch.jit.script
- def f3(a):
- torch.sum(dim=4)
+ def test_list_iterables(self):
+ with self.assertRaisesRegex(RuntimeError, 'List of iterables is not supported currently'):
+ cu = torch.jit.CompilationUnit('''
+ def list_iterables(x):
+ for i, j in [2, 3, 4], [5, 6, 7]:
+ x += i
+ x += j
+ return x
+ ''')
- with self.assertRaisesRegex(RuntimeError, 'for argument \'tensors\' but found Tensor'):
- @torch.jit.script
- def f4(a):
- torch.cat(a)
+ def test_for_tuple_unpack(self):
+ with self.assertRaisesRegex(RuntimeError, 'Iteration variable unpacking is not supported'):
+ cu = torch.jit.CompilationUnit('''
+ def for_tuple_unpack(x, y):
+ for i, j in [[3, 4], [5, 6], [7, 8]]:
+ x += i
+ y += j
+ return x, y
+ ''')
- with self.assertRaisesRegex(RuntimeError, r'argument \'tensors\' but found int\[\]'):
- @torch.jit.script
- def f5(a):
- torch.cat([3])
+ def test_single_starred_lhs(self):
+ with self.assertRaisesRegex(RuntimeError, 'A Starred expression may only appear on the lhs within the presence'
+ ' of another non-starred expression'):
+ cu = torch.jit.CompilationUnit('''
+ def single_starred_lhs(x):
+ a = (x, x, x)
+ *b, = a
+ return b
+ ''')
- with self.assertRaisesRegex(RuntimeError, 'Lists must contain only a single type'):
+ def test_singleton_tuple_unpack(self):
+ def foo(a):
+ b, = (a,)
+ return b + 1
+ self.checkScript(foo, (torch.rand(3),))
+
+ def test_multi_reduction(self):
+ with self.assertRaisesRegex(
+ RuntimeError,
+ 'augmented assignment can only have one LHS expression'):
+ cu = torch.jit.CompilationUnit('''
+ def multi_reduction(x):
+ a, b += x
+ return a, b
+ ''')
+
+ def test_invalid_call_arguments(self):
+ with self.assertRaisesRegex(RuntimeError, 'arguments for call are not valid'):
@torch.jit.script
- def f6(a):
- a.expand(size=[3, [4]])
+ def invalid_call_arguments(x):
+ return torch.unsqueeze(3, 4, 5, 6, 7, 8)
+
+ def test_invalid_lhs_assignment(self):
+ with self.assertRaisesRegex(RuntimeError, 'unexpected expression'):
+ cu = torch.jit.CompilationUnit('''
+ def invalid_lhs_assignment(x):
+ x + 1 = x
+ return x
+ ''')
+
+ def test_multi_starred_expr_lhs(self):
+ with self.assertRaisesRegex(RuntimeError, 'Only one starred expression is allowed on the lhs'):
+ cu = torch.jit.CompilationUnit('''
+ def multi_starred_expr_lhs():
+ a, *b, *c = [1, 2, 3, 4, 5, 6]
+ return a
+ ''')
+
+ def test_pack_tuple_into_non_var(self):
+ with self.assertRaisesRegex(RuntimeError, 'Cannot pack a tuple into a non-variable'):
+ cu = torch.jit.CompilationUnit('''
+ def pack_tuple_into_non_var(x):
+ a, *1 = (3, 4, 5)
+ return x
+ ''')
+
+ def test_print_kwargs(self):
+ with self.assertRaisesRegex(RuntimeError, 'print doesn\'t accept any keyword arguments'):
+ cu = torch.jit.CompilationUnit('''
+ def print_kwargs(x):
+ print(x, flush=True)
+ return x
+ ''')
- with self.assertRaisesRegex(RuntimeError, 'xpected a value of type Tensor for argument \'self\''):
+ def test_builtin_use_as_value(self):
+ with self.assertRaisesRegex(RuntimeError, 'builtin cannot be used as a value'):
@torch.jit.script
- def f7(a):
- torch.sum([4])
-
- def test_builtin_args(self):
+ def builtin_use_as_value(x):
+ return x.unsqueeze
- def t0(a):
- # default arg dim
- return torch.cat([a, a])
+ def test_wrong_use_as_tuple(self):
+ with self.assertRaisesRegex(RuntimeError, 'cannot be used as a tuple'):
+ def test_fn():
+ return 3
- self.checkScript(t0, (torch.zeros(1, 1),))
+ @torch.jit.script
+ def wrong_use_as_tuple(self):
+ a, b = test_fn
+ return a
- def t1(a):
- # keywords out of order
- return torch.cat(dim=1, tensors=[a, a])
+ def test_wrong_attr_lookup(self):
+ with self.assertRaisesRegex(RuntimeError, 'attribute lookup is not defined on builtin'):
+ @torch.jit.script
+ def wrong_attr_lookup(self, x):
+ a = x.unsqueeze.myattr
+ return a
- self.checkScript(t1, (torch.zeros(1, 1, 2),))
+ def test_wrong_use_as_callable(self):
+ with self.assertRaisesRegex(RuntimeError, 'cannot call a value'):
+ @torch.jit.script
+ def wrong_use_as_callable(x):
+ return x(3, 4, 5)
- def t2(a):
- # mix const/non-const attributes
- if True:
- b = 1
- else:
- b = 0
- return torch.sum(a, dim=b, keepdim=False)
+ def test_python_val_doesnt_have_attr(self):
+ with self.assertRaisesRegex(RuntimeError, 'object has no attribute abcd'):
- self.checkScript(t2, (torch.zeros(1, 1, 2),))
+ @torch.jit.script
+ def python_val_doesnt_have_attr():
+ # this has to be a module otherwise attr lookup would not be
+ # allowed in the first place
+ return shutil.abcd
- def test_parser_type_annotations(self):
- cu = torch.jit.CompilationUnit('''
- def foo(x : Tensor, y : Tuple[Tuple[Tensor, Tensor], Tensor]) -> Tuple[Tensor, Tensor]:
- return x, x
- ''')
+ def test_wrong_module_attr_lookup(self):
+ with self.assertRaisesRegex(RuntimeError, 'python value of type \'type\' cannot be used as a value:'):
+ import io
- self.assertExpected(cu.__getattr__('foo').pretty_print_schema())
+ @torch.jit.script
+ def wrong_module_attr_lookup():
+ return io.BytesIO
- def test_parser_type_annotations_comment(self):
- cu = torch.jit.CompilationUnit('''
- def foo(x, y):
- # type: (Tensor, Tuple[Tuple[Tensor, Tensor], Tensor]) -> Tuple[Tensor, Tensor]
- return x, x
- ''')
+ def test_wrong_method_call_inputs(self):
+ with self.assertRaisesRegex(RuntimeError, 'argument y not provided'):
+ class SomeModule(torch.jit.ScriptModule):
- self.assertExpected(cu.__getattr__('foo').pretty_print_schema())
+ @torch.jit.script_method
+ def foo(self, x, y):
+ return x
- def test_parser_type_annotations_unknown_type(self):
- with self.assertRaisesRegex(RuntimeError, r'Unknown type name Foo'):
- cu = torch.jit.CompilationUnit('''
- def foo(x : Tensor, y : Tuple[Tuple[Foo, Tensor], Tensor]) -> Tuple[Tensor, Tensor]:
- return x, x
- ''')
+ @torch.jit.script_method
+ def forward(self, x, y):
+ return self.foo(x)
+ SomeModule()
- def test_parser_type_annotations_subscript_non_ident(self):
- with self.assertRaisesRegex(RuntimeError, r'Subscripted type must be a type identifier'):
+ def test_single_starred_expr_for_loop(self):
+ with self.assertRaisesRegex(RuntimeError, 'unexpected expression'):
cu = torch.jit.CompilationUnit('''
- def foo(x : Tensor, y : Tuple[Tensor, Tensor][Tensor]) -> Tuple[Tensor, Tensor]:
- return x, x
+ def test():
+ x = 0
+ for *a in [1, 2, 3]:
+ x = x + 1
+ return x
''')
- def test_parser_type_annotations_subscript_tensor(self):
- with self.assertRaisesRegex(RuntimeError, r'Unknown type constructor Tensor'):
+ def test_duplicate(self):
+ with self.assertRaisesRegex(RuntimeError, 'Method \'test\' already defined'):
cu = torch.jit.CompilationUnit('''
- def foo(x : Tensor, y : Tensor[Tensor, Tensor]) -> Tuple[Tensor, Tensor]:
- return x, x
- ''')
+ def test():
+ return 1
- def test_parser_type_annotations_incompatible_expression(self):
- with self.assertRaisesRegex(RuntimeError, r'Expression of type \+ cannot be used in a type expression'):
- cu = torch.jit.CompilationUnit('''
- def foo(x : Tensor, y : Tuple[3 + 4, Tensor]) -> Tuple[Tensor, Tensor]:
- return x, x
+ def test():
+ return 2
''')
- def test_gather_dynamic_index(self):
- def t(x):
- gather1 = x[0]
- idx = 0 + 1
- gather2 = x[idx]
- return gather1 + gather2
-
- self.checkScript(t, (torch.zeros(3, 2, 3),))
+ def test_call_ge(self):
+ with self.assertRaisesRegex(RuntimeError, 'expected at most 1 arguments but found 3'):
+ @_trace(torch.zeros(1, 2, 3))
+ def foo(x):
+ return x
- def test_slice_dynamic_index(self):
- def t(x):
- slice1 = x[0:1]
- zero = 0
- one = zero + 1
- slice2 = x[zero:one]
- return slice1 + slice2
+ @torch.jit.script
+ def test_fn():
+ return foo(torch.full([1], 1), torch.full([1], 2), torch.full([1], 3))
- self.checkScript(t, (torch.zeros(3, 2, 3),))
+ def test_wrong_return_type(self):
+ with self.assertRaisesRegex(RuntimeError, 'but instead got value of type tuple'):
+ def somefunc():
+ # type: () -> Tuple[Tuple[Tensor, Tensor]]
+ return torch.zeros(3, 4), torch.zeros(4, 5)
- def test_addmm_grad(self):
- """ This test checks several things:
- 1. An expand node was inserted before the addmm operating on the
- bias term.
- 2. The fused form of addmm appears in the ultimate graph that's
- executed.
- 3. A sum op was emitted for accumulating gradients along the 0th
- (expanded) dimension of the bias term.
- 4. The correct symbolic representation for the backward pass of the
- mm operator was emitted (x.t() -> mm)
+ @torch.jit.script
+ def wrong_return_type():
+ return somefunc()
+ wrong_return_type()
- TODO: we should actually check these conditions once we have a way
- to dump the GraphExecutor state. Namely the processed forward graph
- and the backward graph.
- """
- @torch.jit.script
- def addmm_grad_test(b, x, w):
- return torch.addmm(b, x, w)
+ # Tests for calling between different front-end modes
+ def test_call_python_fn_from_tracing_fn(self):
+ def python_fn(x):
+ return torch.neg(x)
- # Initialize param and input values
- w_init = torch.rand(2, 5)
- b_init = torch.rand(5)
- x = torch.rand(3, 2)
+ @_trace(torch.rand(3, 4))
+ def traced_fn(x):
+ return python_fn(x) + 1
- # Clone trainable params
- b = b_init.clone()
- b.requires_grad_()
- w = w_init.clone()
- w.requires_grad_()
+ # The neg op in the python function should be properly inlined to the
+ # graph
+ self.assertExpected(canonical(traced_fn.graph))
- # Test symbolic differentiation
- y = addmm_grad_test(b, x, w)
- y.sum().backward()
+ def test_call_python_mod_from_tracing_fn(self):
+ class PythonMod(torch.nn.Module):
+ def __init__(self):
+ super(PythonMod, self).__init__()
+ self.param = torch.nn.Parameter(torch.rand(4, 3))
- # clone params for autograd reference
- b_ref = b_init.clone()
- b_ref.requires_grad_()
- w_ref = w_init.clone()
- w_ref.requires_grad_()
- y_ref = torch.addmm(b_ref, x, w_ref)
- y_ref.sum().backward()
+ def forward(self, x):
+ return torch.mm(x, self.param)
- self.assertEqual(w.grad, w_ref.grad)
- self.assertEqual(b.grad, b_ref.grad)
+ pm = PythonMod()
- def test_zeros(self):
- class M(torch.jit.ScriptModule):
- __constants__ = ['d']
+ @_trace(torch.rand(3, 4))
+ def traced_fn(x):
+ return pm(x) + 1.0
- def __init__(self):
- self.d = torch.device('cpu')
+ # Note: the parameter self.param from the Python module is inlined
+ # into the graph
+ self.assertExpected(canonical(traced_fn.graph))
- @torch.jit.script_method
- def create(self):
- return torch.zeros([1, 1, 2], dtype=torch.float, device=self.d, layout=torch.strided)
+ def test_call_traced_fn_from_tracing_fn(self):
+ @_trace(torch.rand(3, 4))
+ def traced_fn1(x):
+ return torch.neg(x)
- r = M().create()
- self.assertEqual(r.dtype, torch.float)
- self.assertEqual(torch.zeros([1, 1, 2], dtype=torch.float), r)
+ @_trace(torch.rand(3, 4))
+ def traced_fn(x):
+ return traced_fn1(x) + 1
- def test_vararg_zeros(self):
- def foo():
- return torch.zeros(3, 4, 5, dtype=torch.int)
+ self.assertExpected(canonical(traced_fn.graph))
- self.checkScript(foo, ())
+ def test_call_traced_mod_from_tracing_fn(self):
+ class TracedModule(torch.nn.Module):
+ def __init__(self):
+ super(TracedModule, self).__init__()
+ self.param = torch.nn.Parameter(torch.rand(4, 3))
- @unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
- def test_rand(self):
+ def forward(self, x):
+ return torch.mm(x, self.param)
- def test_rand():
- a = torch.rand([3, 4])
- return a + 1.0 - a
+ tm = torch.jit.trace(TracedModule(), torch.rand(3, 4))
- self.checkScript(test_rand, ())
+ @_trace(torch.rand(3, 4))
+ def traced_fn(x):
+ return tm(x) + 1.0
- def test_erase_number_types(self):
- def func(a):
- b = 7 + 1 + 3
- c = a + b
- c += b
- return c
+ # Note: the parameter self.param from the Python module is inlined
+ # into the graph
+ self.assertExpected(canonical(traced_fn.graph))
- graph = torch.jit.script(func).graph
- self.run_pass('remove_inplace_ops', graph)
- self.run_pass('erase_number_types', graph)
- self.assertExpectedGraph(graph)
+ def test_call_script_fn_from_tracing_fn(self):
+ @torch.jit.script
+ def script_fn(x):
+ return torch.neg(x)
- def test_mm_batching(self):
- lstm_cell = torch.jit.script(LSTMCellS)
+ @_trace(torch.rand(3, 4))
+ def traced_fn(x):
+ return script_fn(x) + 1
- def lstm(x, hx, cx, w_ih, w_hh, b_ih, b_hh):
- for i in range(x.size(0)):
- hx, cx = lstm_cell(x[i], hx, cx, w_ih, w_hh, b_ih, b_hh)
- return hx
+ self.assertExpected(canonical(traced_fn.graph))
- slstm = torch.jit.script(lstm)
+ def test_call_script_mod_from_tracing_fn(self):
+ class ScriptMod(torch.jit.ScriptModule):
+ def __init__(self):
+ super(ScriptMod, self).__init__()
+ self.param = torch.nn.Parameter(torch.rand(4, 3))
- inputs = get_lstm_inputs('cpu', training=True, seq_length=10)
- slstm(*inputs).sum().backward()
+ @torch.jit.script_method
+ def forward(self, x):
+ return torch.mm(x, self.param)
- fw_graph = slstm.graph_for(*inputs)
- bw_graph = backward_graph(slstm, diff_graph_idx=0)
- self.assertTrue('prim::MMBatchSide' in str(fw_graph))
- self.assertTrue('prim::MMTreeReduce' in str(bw_graph))
+ sm = ScriptMod()
- sout = slstm(*inputs)
- out = lstm(*inputs)
- self.assertEqual(slstm(*inputs), lstm(*inputs))
- self.assertEqual(torch.autograd.grad(slstm(*inputs).sum(), inputs),
- torch.autograd.grad(lstm(*inputs).sum(), inputs))
+ @_trace(torch.rand(3, 4))
+ def traced_fn(x):
+ return sm(x) + 1.0
- def test_loop_unrolling(self):
- def fn(x):
- y = 0
- for i in range(int(x)):
- y += i
- return y
+ self.assertExpected(canonical(traced_fn.graph))
- graph = torch.jit.script(fn).graph
- self.run_pass('loop_unrolling', graph)
- self.assertExpectedGraph(graph)
- self.checkScript(fn, (torch.tensor(10),))
+ def test_call_python_fn_from_traced_module(self):
+ def python_fn(x):
+ return torch.neg(x)
- def test_loop_unrolling_const(self):
- def fn():
- y = 0
- for i in range(10):
- y += 1
- return y
+ class TracedModule(torch.nn.Module):
+ def __init__(self):
+ super(TracedModule, self).__init__()
+ self.param = torch.nn.Parameter(torch.rand(4, 3))
- def fn2():
- y = 0
- for i in range(10):
- y += i
- return y
+ def forward(self, x):
+ return torch.mm(python_fn(x), self.param)
- def check(fn, name):
- graph = torch.jit.script(fn).graph
- self.run_pass('loop_unrolling', graph)
- self.assertExpectedGraph(graph, subname=name)
- self.checkScript(fn, ())
+ tm = torch.jit.trace(TracedModule(), torch.rand(3, 4))
- check(fn, 'add_const')
- check(fn2, 'add_iter')
+ # Note: parameter self.param from the traced module should appear as
+ # an input to the graph and the neg op from the Python function should
+ # be properly inlined
+ self.assertExpected(canonical(tm.graph))
- def test_loop_unrolling_nested(self):
- def fn(x):
- y = 0
- for i in range(10):
- for j in range(int(x)):
- y += j
- return y
+ def test_call_python_mod_from_traced_module(self):
+ class PythonModule(torch.nn.Module):
+ def __init__(self):
+ super(PythonModule, self).__init__()
+ self.param = torch.nn.Parameter(torch.rand(5, 7))
- graph = torch.jit.script(fn).graph
- self.run_pass('loop_unrolling', graph)
- self.assertExpectedGraph(graph)
- self.checkScript(fn, (torch.tensor(10),))
+ def forward(self, x):
+ return torch.mm(x, self.param)
- def test_loop_unroll_unused_counter(self):
- def fn(x):
- y = 0
- for i in range(int(x)):
- y += 1
- return y
+ class TracedModule(torch.nn.Module):
+ def __init__(self):
+ super(TracedModule, self).__init__()
+ self.param = torch.nn.Parameter(torch.rand(4, 5))
+ self.mod = PythonModule()
- graph = torch.jit.script(fn).graph
- self.run_pass('loop_unrolling', graph)
- self.assertExpectedGraph(graph)
+ def forward(self, x):
+ return self.mod(torch.mm(x, self.param)) + 1.0
- def test_loop_unroll_negative(self):
- def fn(x):
- y = 0
- for i in range(int(x)):
- y += 1
- return y
+ tm = torch.jit.trace(TracedModule(), torch.rand(3, 4))
- self.checkScript(fn, (torch.tensor(-20),))
- self.checkScript(fn, (torch.tensor(-2),))
- self.checkScript(fn, (torch.tensor(-1),))
- self.checkScript(fn, (torch.tensor(0),))
- self.checkScript(fn, (torch.tensor(1),))
- self.checkScript(fn, (torch.tensor(2),))
+ # Note: the parameters from both modules should appear in the flattened
+ # inputs of the graph. All ops from both modules should be inlined.
+ self.assertExpected(canonical(tm.graph))
- def test_where(self):
- def fn(x, y):
- return torch.where(x > 0.0, x, y)
+ def test_call_traced_fn_from_traced_module(self):
+ @_trace(torch.rand(3, 4))
+ def traced_fn(x):
+ return torch.neg(x)
- self.checkScript(fn, (torch.randn(3, 2, dtype=torch.float), torch.ones(3, 2, dtype=torch.float)))
+ class TracedModule(torch.nn.Module):
+ def __init__(self):
+ super(TracedModule, self).__init__()
+ self.param = torch.nn.Parameter(torch.rand(4, 5))
- def test_where_method(self):
- def fn(x, y):
- return x.where(x > 0.0, y)
+ def forward(self, x):
+ return traced_fn(torch.mm(x, self.param))
- self.checkScript(fn, (torch.randn(3, 2, dtype=torch.float), torch.ones(3, 2, dtype=torch.float)))
+ tm = torch.jit.trace(TracedModule(), torch.rand(3, 4))
+ # Note: neg op from the traced function should be properly inlined
+ self.assertExpected(canonical(tm.graph))
- def test_reassign_module_lhs(self):
- with self.assertRaisesRegex(RuntimeError, 'Cannot re-assign \'self\' because it has type value and self is'
- ' not a first-class value. Only reassignments to first-class values are allowed'):
- class ReassignSelfLHS(torch.jit.ScriptModule):
- @torch.jit.script_method
- def forward(self, x):
- for i in range(20):
- self = x
- return self
+ def test_call_traced_module_from_traced_module(self):
+ class TracedModule1(torch.nn.Module):
+ def __init__(self):
+ super(TracedModule1, self).__init__()
+ self.param = torch.nn.Parameter(torch.rand(5, 7))
- ReassignSelfLHS()
+ def forward(self, x):
+ return torch.mm(x, self.param)
- def test_reassign_module_rhs(self):
- with self.assertRaisesRegex(RuntimeError, 'Cannot re-assign \'x\' to a value of type module because x is not a'
- ' first-class value. Only reassignments to first-class values are allowed'):
- class ReassignSelfRHS(torch.jit.ScriptModule):
- @torch.jit.script_method
- def forward(self, x):
- for i in range(20):
- x = self
- return self
+ class TracedModule(torch.nn.Module):
+ def __init__(self):
+ super(TracedModule, self).__init__()
+ self.param = torch.nn.Parameter(torch.rand(4, 5))
+ self.mod = torch.jit.trace(TracedModule1(), torch.rand(3, 5))
- ReassignSelfRHS()
+ def forward(self, x):
+ return self.mod(torch.mm(x, self.param)) + 1.0
- def test_unknown_builtin(self):
- with self.assertRaisesRegex(RuntimeError, 'unknown builtin op'):
- @torch.jit.script
- def unknown_builtin(x):
- return x.splork(3)
+ tm = torch.jit.trace(TracedModule(), torch.rand(3, 4))
- def test_return_tuple(self):
- def return_tuple(x):
- a = (x, x)
- return a, x
- self.checkScript(return_tuple, (torch.rand(4),))
+ # Note: the parameters from both modules should appear in the flattened
+ # inputs of the graph. All ops from both modules should be inlined.
+ self.assertExpected(canonical(tm.graph))
- def test_method_no_self(self):
- with self.assertRaisesRegex(RuntimeError, 'methods must have a self argument'):
- class MethodNoSelf(torch.jit.ScriptModule):
- @torch.jit.script_method
- def forward():
- return torch.zeros(3, 4)
+ def test_call_script_fn_from_traced_module(self):
+ @torch.jit.script
+ def traced_fn(x):
+ return torch.neg(x)
- MethodNoSelf()
+ class TracedModule(torch.nn.Module):
+ def __init__(self):
+ super(TracedModule, self).__init__()
+ self.param = torch.nn.Parameter(torch.rand(4, 5))
- def test_return_stmt_not_at_end(self):
- with self.assertRaisesRegex(RuntimeError, 'return statements can appear only at the end of the function body'):
- @torch.jit.script
- def return_stmt_wrong(x):
- if bool(x > 3):
- return 3
- else:
- return x
+ def forward(self, x):
+ return traced_fn(torch.mm(x, self.param))
- def test_for_range_no_arg(self):
- with self.assertRaisesRegex(RuntimeError, r'range\(\) expects 1 argument but got 0'):
- @torch.jit.script
- def range_no_arg(x):
- for i in range():
- x += 1
- return x
+ tm = torch.jit.trace(TracedModule(), torch.rand(3, 4))
+ # Note: neg op from the script function should be properly inlined
+ self.assertExpected(canonical(tm.graph))
- def test_list_iterables(self):
- with self.assertRaisesRegex(RuntimeError, 'List of iterables is not supported currently'):
- cu = torch.jit.CompilationUnit('''
- def list_iterables(x):
- for i, j in [2, 3, 4], [5, 6, 7]:
- x += i
- x += j
- return x
- ''')
+ def test_call_script_module_from_traced_module(self):
+ class ScriptMod(torch.jit.ScriptModule):
+ def __init__(self):
+ super(ScriptMod, self).__init__()
+ self.param = torch.nn.Parameter(torch.rand(5, 7))
+
+ @torch.jit.script_method
+ def forward(self, x):
+ return torch.mm(x, self.param)
- def test_for_tuple_unpack(self):
- with self.assertRaisesRegex(RuntimeError, 'Iteration variable unpacking is not supported'):
- cu = torch.jit.CompilationUnit('''
- def for_tuple_unpack(x, y):
- for i, j in [[3, 4], [5, 6], [7, 8]]:
- x += i
- y += j
- return x, y
- ''')
+ class TracedModule(torch.nn.Module):
+ def __init__(self):
+ super(TracedModule, self).__init__()
+ self.param = torch.nn.Parameter(torch.rand(4, 5))
+ self.mod = ScriptMod()
- def test_single_starred_lhs(self):
- with self.assertRaisesRegex(RuntimeError, 'A Starred expression may only appear on the lhs within the presence'
- ' of another non-starred expression'):
- cu = torch.jit.CompilationUnit('''
- def single_starred_lhs(x):
- a = (x, x, x)
- *b, = a
- return b
- ''')
+ def forward(self, x):
+ return self.mod(torch.mm(x, self.param)) + 1.0
- def test_singleton_tuple_unpack(self):
- def foo(a):
- b, = (a,)
- return b + 1
- self.checkScript(foo, (torch.rand(3),))
+ tm = torch.jit.trace(TracedModule(), torch.rand(3, 4))
- def test_multi_reduction(self):
- with self.assertRaisesRegex(
- RuntimeError,
- 'augmented assignment can only have one LHS expression'):
- cu = torch.jit.CompilationUnit('''
- def multi_reduction(x):
- a, b += x
- return a, b
- ''')
+ # Note: the parameters from both modules should appear in the flattened
+ # inputs of the graph. All ops from both modules should be inlined.
+ self.assertExpected(canonical(tm.graph))
- def test_invalid_call_arguments(self):
- with self.assertRaisesRegex(RuntimeError, 'arguments for call are not valid'):
- @torch.jit.script
- def invalid_call_arguments(x):
- return torch.unsqueeze(3, 4, 5, 6, 7, 8)
+ def test_call_python_fn_from_script_fn(self):
+ def python_fn(x):
+ return torch.neg(x)
- def test_invalid_lhs_assignment(self):
- with self.assertRaisesRegex(RuntimeError, 'unexpected expression'):
- cu = torch.jit.CompilationUnit('''
- def invalid_lhs_assignment(x):
- x + 1 = x
- return x
- ''')
+ @torch.jit.script
+ def script_fn(x):
+ return python_fn(x) + 1
- def test_multi_starred_expr_lhs(self):
- with self.assertRaisesRegex(RuntimeError, 'Only one starred expression is allowed on the lhs'):
- cu = torch.jit.CompilationUnit('''
- def multi_starred_expr_lhs():
- a, *b, *c = [1, 2, 3, 4, 5, 6]
- return a
- ''')
+ # Note: the call to python_fn appears as `^python_fn()` and is called
+ # as a PythonOp in the interpreter
+ self.assertExpected(canonical(script_fn.graph))
- def test_pack_tuple_into_non_var(self):
- with self.assertRaisesRegex(RuntimeError, 'Cannot pack a tuple into a non-variable'):
- cu = torch.jit.CompilationUnit('''
- def pack_tuple_into_non_var(x):
- a, *1 = (3, 4, 5)
- return x
- ''')
+ def test_call_python_mod_from_script_fn(self):
+ class PythonModule(torch.nn.Module):
+ def __init__(self):
+ super(PythonModule, self).__init__()
+ self.param = torch.nn.Parameter(torch.rand(5, 7))
- def test_print_kwargs(self):
- with self.assertRaisesRegex(RuntimeError, 'print doesn\'t accept any keyword arguments'):
- cu = torch.jit.CompilationUnit('''
- def print_kwargs(x):
- print(x, flush=True)
- return x
- ''')
+ def forward(self, x):
+ return torch.mm(x, self.param)
- def test_builtin_use_as_value(self):
- with self.assertRaisesRegex(RuntimeError, 'builtin cannot be used as a value'):
- @torch.jit.script
- def builtin_use_as_value(x):
- return x.unsqueeze
+ pm = PythonModule()
- def test_wrong_use_as_tuple(self):
- with self.assertRaisesRegex(RuntimeError, 'cannot be used as a tuple'):
- def test_fn():
- return 3
+ @torch.jit.script
+ def script_fn(x):
+ return pm(x) + 1
- @torch.jit.script
- def wrong_use_as_tuple(self):
- a, b = test_fn
- return a
+ # Note: call to pm(x) appears as ^<python_value>() in the trace.
+ # Parameters are NOT inlined.
+ self.assertExpected(str(script_fn.graph))
- def test_wrong_attr_lookup(self):
- with self.assertRaisesRegex(RuntimeError, 'attribute lookup is not defined on builtin'):
- @torch.jit.script
- def wrong_attr_lookup(self, x):
- a = x.unsqueeze.myattr
- return a
+ def test_call_traced_fn_from_script_fn(self):
+ @_trace(torch.rand(3, 4))
+ def traced_fn(x):
+ return torch.neg(x)
- def test_wrong_use_as_callable(self):
- with self.assertRaisesRegex(RuntimeError, 'cannot call a value'):
- @torch.jit.script
- def wrong_use_as_callable(x):
- return x(3, 4, 5)
+ @torch.jit.script
+ def script_fn(x):
+ return traced_fn(x) + 1
- def test_python_val_doesnt_have_attr(self):
- with self.assertRaisesRegex(RuntimeError, 'object has no attribute abcd'):
+ # Note: the neg op from traced_fn should be properly inlined into the
+ # script function's graph
+ self.assertExpected(str(script_fn.graph))
- @torch.jit.script
- def python_val_doesnt_have_attr():
- # this has to be a module otherwise attr lookup would not be
- # allowed in the first place
- return shutil.abcd
+ def test_call_traced_mod_from_script_fn(self):
+ class TracedModule(torch.nn.Module):
+ def __init__(self):
+ super(TracedModule, self).__init__()
- def test_wrong_module_attr_lookup(self):
- with self.assertRaisesRegex(RuntimeError, 'python value of type \'type\' cannot be used as a value:'):
- import io
+ def forward(self, x):
+ return torch.mm(x, torch.zeros(4, 3))
- @torch.jit.script
- def wrong_module_attr_lookup():
- return io.BytesIO
+ tm = torch.jit.trace(TracedModule(), torch.rand(3, 4))
- def test_wrong_method_call_inputs(self):
- with self.assertRaisesRegex(RuntimeError, 'argument y not provided'):
- class SomeModule(torch.jit.ScriptModule):
+ @torch.jit.script
+ def script_fn(x):
+ return tm(x) + 1
- @torch.jit.script_method
- def foo(self, x, y):
- return x
+ self.assertExpected(str(script_fn.graph))
- @torch.jit.script_method
- def forward(self, x, y):
- return self.foo(x)
- SomeModule()
+ def test_call_script_fn_from_script_fn(self):
+ @torch.jit.script
+ def script_fn1(x):
+ return torch.neg(x)
- def test_single_starred_expr_for_loop(self):
- with self.assertRaisesRegex(RuntimeError, 'unexpected expression'):
- cu = torch.jit.CompilationUnit('''
- def test():
- x = 0
- for *a in [1, 2, 3]:
- x = x + 1
- return x
- ''')
+ @torch.jit.script
+ def script_fn(x):
+ return script_fn1(x) + 1
- def test_duplicate(self):
- with self.assertRaisesRegex(RuntimeError, 'Method \'test\' already defined'):
- cu = torch.jit.CompilationUnit('''
- def test():
- return 1
+ # Note: the neg op from script_fn1 should be properly inlined into the
+ # graph of script_fn
+ self.assertExpected(canonical(script_fn.graph))
- def test():
- return 2
- ''')
+ def test_call_script_mod_from_script_fn(self):
+ class ScriptMod(torch.jit.ScriptModule):
+ def __init__(self):
+ super(ScriptMod, self).__init__()
- def test_call_ge(self):
- with self.assertRaisesRegex(RuntimeError, 'expected at most 1 arguments but found 3'):
- @_trace(torch.zeros(1, 2, 3))
- def foo(x):
- return x
+ @torch.jit.script_method
+ def forward(self, x):
+ return torch.mm(x, torch.zeros([4, 3]))
- @torch.jit.script
- def test_fn():
- return foo(torch.full([1], 1), torch.full([1], 2), torch.full([1], 3))
+ sm = ScriptMod()
- def test_wrong_return_type(self):
- with self.assertRaisesRegex(RuntimeError, 'but instead got value of type tuple'):
- def somefunc():
- # type: () -> Tuple[Tuple[Tensor, Tensor]]
- return torch.zeros(3, 4), torch.zeros(4, 5)
+ @torch.jit.script
+ def script_fn(x):
+ return sm(x) + 1
- @torch.jit.script
- def wrong_return_type():
- return somefunc()
- wrong_return_type()
+ self.assertExpected(canonical(script_fn.graph))
- # Tests for calling between different front-end modes
- def test_call_python_fn_from_tracing_fn(self):
+ def test_call_python_fn_from_script_module(self):
def python_fn(x):
return torch.neg(x)
- @_trace(torch.rand(3, 4))
- def traced_fn(x):
- return python_fn(x) + 1
+ class ScriptMod(torch.jit.ScriptModule):
+ def __init__(self):
+ super(ScriptMod, self).__init__()
+ self.param = torch.nn.Parameter(torch.rand(4, 3))
- # The neg op in the python function should be properly inlined to the
- # graph
- self.assertExpected(canonical(traced_fn.graph))
+ @torch.jit.script_method
+ def forward(self, x):
+ return python_fn(torch.mm(x, self.param))
- def test_call_python_mod_from_tracing_fn(self):
+ sm = ScriptMod()
+ self.assertExpected(str(sm.__getattr__('forward').graph))
+
+ def test_call_python_mod_from_script_module(self):
class PythonMod(torch.nn.Module):
def __init__(self):
super(PythonMod, self).__init__()
+ self.param = torch.nn.Parameter(torch.rand(3, 5))
+
+ def forward(self, x):
+ return torch.mm(x, self.param)
+
+ class ScriptMod(torch.jit.ScriptModule):
+ def __init__(self):
+ super(ScriptMod, self).__init__()
self.param = torch.nn.Parameter(torch.rand(4, 3))
+ self.pm = PythonMod()
+ @torch.jit.script_method
def forward(self, x):
- return torch.mm(x, self.param)
+ return self.pm(torch.mm(x, self.param))
- pm = PythonMod()
+ sm = ScriptMod()
+ # Note: the call into PythonMod appears as ^<python_value>(). Parameters
+ # are NOT inlined
+ self.assertExpected(str(sm.graph))
- @_trace(torch.rand(3, 4))
+ def test_call_tracing_fn_from_script_module(self):
+ @_trace(torch.rand(3, 3))
def traced_fn(x):
- return pm(x) + 1.0
-
- # Note: the parameter self.param from the Python module is inlined
- # into the graph
- self.assertExpected(canonical(traced_fn.graph))
-
- def test_call_traced_fn_from_tracing_fn(self):
- @_trace(torch.rand(3, 4))
- def traced_fn1(x):
return torch.neg(x)
- @_trace(torch.rand(3, 4))
- def traced_fn(x):
- return traced_fn1(x) + 1
+ class ScriptMod(torch.jit.ScriptModule):
+ def __init__(self):
+ super(ScriptMod, self).__init__()
+ self.param = torch.nn.Parameter(torch.rand(4, 3))
- self.assertExpected(canonical(traced_fn.graph))
+ @torch.jit.script_method
+ def forward(self, x):
+ return traced_fn(torch.mm(x, self.param))
- def test_call_traced_mod_from_tracing_fn(self):
- class TracedModule(torch.nn.Module):
+ sm = ScriptMod()
+ self.assertExpected(str(sm.__getattr__('forward').graph))
+
+ def test_call_tracing_mod_from_script_module(self):
+ class TracedMod(torch.nn.Module):
def __init__(self):
- super(TracedModule, self).__init__()
- self.param = torch.nn.Parameter(torch.rand(4, 3))
+ super(TracedMod, self).__init__()
+ self.param = torch.nn.Parameter(torch.rand(3, 5))
def forward(self, x):
return torch.mm(x, self.param)
- tm = torch.jit.trace(TracedModule(), torch.rand(3, 4))
+ class ScriptMod(torch.jit.ScriptModule):
+ def __init__(self):
+ super(ScriptMod, self).__init__()
+ self.param = torch.nn.Parameter(torch.rand(4, 3))
+ self.tm = torch.jit.trace(TracedMod(), torch.rand(3, 3))
- @_trace(torch.rand(3, 4))
- def traced_fn(x):
- return tm(x) + 1.0
+ @torch.jit.script_method
+ def forward(self, x):
+ return self.tm(torch.mm(x, self.param))
- # Note: the parameter self.param from the Python module is inlined
- # into the graph
- self.assertExpected(canonical(traced_fn.graph))
+ sm = ScriptMod()
+ # Note: the parameters from both modules should appear in the flattened
+ # input list to the graph. The mm op from TracedMod should be properly
+ # inlined
+ self.assertExpected(str(sm.graph))
- def test_call_script_fn_from_tracing_fn(self):
+ def test_call_script_fn_from_script_module(self):
@torch.jit.script
def script_fn(x):
return torch.neg(x)
- @_trace(torch.rand(3, 4))
- def traced_fn(x):
- return script_fn(x) + 1
-
- self.assertExpected(canonical(traced_fn.graph))
-
- def test_call_script_mod_from_tracing_fn(self):
class ScriptMod(torch.jit.ScriptModule):
def __init__(self):
super(ScriptMod, self).__init__()
@torch.jit.script_method
def forward(self, x):
- return torch.mm(x, self.param)
+ return script_fn(torch.mm(x, self.param))
sm = ScriptMod()
+ self.assertExpected(canonical(sm.__getattr__('forward').graph))
- @_trace(torch.rand(3, 4))
- def traced_fn(x):
- return sm(x) + 1.0
-
- self.assertExpected(canonical(traced_fn.graph))
+ def test_call_script_mod_from_script_module(self):
+ class ScriptMod1(torch.jit.ScriptModule):
+ def __init__(self):
+ super(ScriptMod1, self).__init__()
+ self.param = torch.nn.Parameter(torch.rand(3, 5))
- def test_call_python_fn_from_traced_module(self):
- def python_fn(x):
- return torch.neg(x)
+ @torch.jit.script_method
+ def forward(self, x):
+ return torch.mm(x, self.param)
- class TracedModule(torch.nn.Module):
+ class ScriptMod(torch.jit.ScriptModule):
def __init__(self):
- super(TracedModule, self).__init__()
+ super(ScriptMod, self).__init__()
self.param = torch.nn.Parameter(torch.rand(4, 3))
+ self.tm = ScriptMod1()
+ @torch.jit.script_method
def forward(self, x):
- return torch.mm(python_fn(x), self.param)
+ return self.tm(torch.mm(x, self.param))
- tm = torch.jit.trace(TracedModule(), torch.rand(3, 4))
+ sm = ScriptMod()
+ # Note: the parameters from both modules should appear in the flattened
+ # input list to the graph. The mm op from ScriptMod1 should be properly
+ # inlined
+ self.assertExpected(canonical(sm.graph))
- # Note: parameter self.param from the traced module should appear as
- # an input to the graph and the neg op from the Python function should
- # be properly inlined
- self.assertExpected(canonical(tm.graph))
+ def test_module_with_params_called_fails(self):
+ with self.assertRaisesRegex(RuntimeError, "Attempted to inline a Module with parameters. Stateful "
+ "modules to be inlined must be submodules of the callee."):
+ class ScriptMod(torch.jit.ScriptModule):
+ def __init__(self):
+ super(ScriptMod, self).__init__()
+ self.param = torch.nn.Parameter(torch.rand(3, 3))
- def test_call_python_mod_from_traced_module(self):
- class PythonModule(torch.nn.Module):
- def __init__(self):
- super(PythonModule, self).__init__()
- self.param = torch.nn.Parameter(torch.rand(5, 7))
+ @torch.jit.script_method
+ def forward(self, x):
+ return torch.mm(x, self.param)
- def forward(self, x):
- return torch.mm(x, self.param)
+ sm = ScriptMod()
- class TracedModule(torch.nn.Module):
- def __init__(self):
- super(TracedModule, self).__init__()
- self.param = torch.nn.Parameter(torch.rand(4, 5))
- self.mod = PythonModule()
+ @torch.jit.script
+ def some_func(x):
+ return sm(x)
- def forward(self, x):
- return self.mod(torch.mm(x, self.param)) + 1.0
+ def test_index_put_trace_with_view(self):
+ @_trace(torch.rand(100), torch.tensor([1, 2, 3, 4]), torch.rand(1, 1, 1, 4))
+ def test_index_put(target, indices, rhs):
+ target[indices] = rhs
+ return target
- tm = torch.jit.trace(TracedModule(), torch.rand(3, 4))
+ self.assertExpectedGraph(test_index_put.graph)
- # Note: the parameters from both modules should appear in the flattened
- # inputs of the graph. All ops from both modules should be inlined.
- self.assertExpected(canonical(tm.graph))
+ def test_index_put_trace_without_view(self):
+ @_trace(torch.rand(100), torch.tensor([1, 2, 3, 4]), torch.rand(4))
+ def test_index_put(target, indices, rhs):
+ target[indices] = rhs
+ return target
- def test_call_traced_fn_from_traced_module(self):
- @_trace(torch.rand(3, 4))
- def traced_fn(x):
- return torch.neg(x)
+ self.assertExpectedGraph(test_index_put.graph)
- class TracedModule(torch.nn.Module):
- def __init__(self):
- super(TracedModule, self).__init__()
- self.param = torch.nn.Parameter(torch.rand(4, 5))
+ def test_tuple_indexing(self):
+ def tuple_index(a):
+ if bool(a):
+ b = (1, 2)
+ else:
+ b = (0, 2)
+ return b[-2], b[1]
+
+ self.checkScript(tuple_index, (torch.tensor([1]),))
+ self.checkScript(tuple_index, (torch.tensor([1]),), optimize=True)
+ tuple_comp = torch.jit.script(tuple_index)
+ self.assertExpectedGraph(tuple_comp.graph)
+ self.run_pass('lower_all_tuples', tuple_comp.graph)
+ m = torch.jit.ScriptModule()
+ m._create_method_from_graph("forward", tuple_comp.graph)
+ self.assertEqual(m(torch.tensor(1)), (1, 2))
+
+ with self.assertRaisesRegex(RuntimeError, "tuple indices must be integer constants"):
+ @torch.jit.script
+ def test_non_constant_input(a):
+ if bool(a):
+ b = 1
+ else:
+ b = 0
+ c = (0, 1)
+ return c[b]
+
+ def test_indexing_float():
+ c = (1, 2)
+ return c[0.1]
+ self.checkScriptRaisesRegex(test_indexing_float, (), Exception,
+ "tuple indices must")
+
+ def test_indexing_out_of_bounds_pos():
+ c = (1, 2)
+ return c[2]
+
+ self.checkScriptRaisesRegex(test_indexing_out_of_bounds_pos, (), Exception,
+ "out of range")
+
+ def test_indexing_out_of_bounds_neg():
+ c = (1, 2)
+ return c[-3]
+
+ self.checkScriptRaisesRegex(test_indexing_out_of_bounds_pos, (), Exception,
+ "out of range")
+
+ def test_tuple_slicing(self):
+ def tuple_slice(a):
+ if bool(a):
+ b = (1, 2, 3, 4)
+ else:
+ b = (4, 3, 2, 1)
+ c = b[-4:4]
+ d = b[0:]
+ e = c[1:-1]
+ return e
+
+ self.checkScript(tuple_slice, (torch.tensor([1]),), optimize=True)
+ tuple_comp = torch.jit.script(tuple_slice)
+ self.assertExpectedGraph(tuple_comp.graph)
+ self.run_pass('lower_all_tuples', tuple_comp.graph)
+ self.assertTrue('Tuple' not in str(tuple_comp.graph))
+ m = torch.jit.ScriptModule()
+ m._create_method_from_graph("forward", tuple_comp.graph)
+ self.assertEqual(m(torch.tensor(1)), (2, 3))
+
+ @torch.jit.script
+ def test_indexing_end_out_of_bounds():
+ c = (1, 2)
+ return c[2:10]
+
+ # output is None in script and () in python
+ self.assertEqual(test_indexing_end_out_of_bounds(), None)
+
+ def test_unwrap_optional_builtin(self):
+ def test(x):
+ # type: (Optional[int]) -> int
+ x = torch.jit._unwrap_optional(x)
+ x = x + x
+ return x
+
+ self.checkScript(test, (3,))
+
+ with self.assertRaisesRegex(AssertionError, "Unwrapping null optional"):
+ test(None)
+
+ test_script = torch.jit.script(test)
+ with self.assertRaisesRegex(RuntimeError, "Unwrapping null optional"):
+ test_script(None)
+
+ @torch.jit.script
+ def test_test():
+ return torch.jit._unwrap_optional(1)
+
+ with self.assertRaisesRegex(RuntimeError, "cannot match an Optional\\[T\\] to None"):
+ @torch.jit.script
+ def test_no_type():
+ # type: () -> int
+ return torch.jit._unwrap_optional(None)
- def forward(self, x):
- return traced_fn(torch.mm(x, self.param))
+ def test_indexing_error(self):
+ with self.assertRaisesRegex(RuntimeError, "Indexing only supported on lists, tensors, and tuples"):
+ @torch.jit.script
+ def test_wrong_type():
+ a = 8
+ return a[0]
- tm = torch.jit.trace(TracedModule(), torch.rand(3, 4))
- # Note: neg op from the traced function should be properly inlined
- self.assertExpected(canonical(tm.graph))
+ def test_annotated_script_fn(self):
+ @torch.jit.script
+ def foo(x, y, z):
+ # type: (Tensor, Tuple[Tensor, Tensor, Tensor], Tuple[Tensor, Tuple[Tensor, Tensor]]) -> Tensor
+ return x
- def test_call_traced_module_from_traced_module(self):
- class TracedModule1(torch.nn.Module):
- def __init__(self):
- super(TracedModule1, self).__init__()
- self.param = torch.nn.Parameter(torch.rand(5, 7))
+ self.assertExpected(foo.__getattr__('forward').pretty_print_schema())
- def forward(self, x):
- return torch.mm(x, self.param)
+ def test_annotated_script_method(self):
+ class SM(torch.jit.ScriptModule):
+ @torch.jit.script_method
+ def forward(self, x, y):
+ # type: (Tuple[Tensor, Tensor], Tensor) -> Tuple[Tensor, Tensor, Tensor]
+ return y, y, y
- class TracedModule(torch.nn.Module):
- def __init__(self):
- super(TracedModule, self).__init__()
- self.param = torch.nn.Parameter(torch.rand(4, 5))
- self.mod = torch.jit.trace(TracedModule1(), torch.rand(3, 5))
+ sm = SM()
- def forward(self, x):
- return self.mod(torch.mm(x, self.param)) + 1.0
+ self.assertExpected(sm.__getattr__('forward').pretty_print_schema())
- tm = torch.jit.trace(TracedModule(), torch.rand(3, 4))
+ def test_annotated_script_fn_return_mismatch(self):
+ with self.assertRaisesRegex(RuntimeError, r"Return value at position 0 was annotated as "
+ r"having type \(Tensor, Tensor\) but is "
+ r"actually of type Tensor"):
+ @torch.jit.script
+ def return_tup(x):
+ # type: (Tensor) -> Tuple[Tuple[Tensor, Tensor], Tensor]
+ return x, x
- # Note: the parameters from both modules should appear in the flattened
- # inputs of the graph. All ops from both modules should be inlined.
- self.assertExpected(canonical(tm.graph))
+ def test_annotated_script_fn_arg_mismatch(self):
+ with self.assertRaisesRegex(RuntimeError, r"arguments for call are not valid"):
+ @torch.jit.script
+ def tuple_arg(x):
+ # type: (Tuple[Tensor, Tensor]) -> Tensor
+ return x + 1
- def test_call_script_fn_from_traced_module(self):
+ def test_script_non_tensor_args_outputs(self):
@torch.jit.script
- def traced_fn(x):
- return torch.neg(x)
-
- class TracedModule(torch.nn.Module):
- def __init__(self):
- super(TracedModule, self).__init__()
- self.param = torch.nn.Parameter(torch.rand(4, 5))
+ def fn(x, y):
+ # type: (Tensor, float) -> float
+ return float((x + y).sum())
- def forward(self, x):
- return traced_fn(torch.mm(x, self.param))
+ x = torch.ones(2, 2)
+ z = fn(x, 1)
+ self.assertIsInstance(z, float)
+ self.assertEqual(z, 8.)
- tm = torch.jit.trace(TracedModule(), torch.rand(3, 4))
- # Note: neg op from the script function should be properly inlined
- self.assertExpected(canonical(tm.graph))
+ @unittest.skip('https://github.com/pytorch/pytorch/issues/9595')
+ def test_inline_and_run_annotated_script_fn(self):
+ @torch.jit.script
+ def to_inline(x, y):
+ # type: (Tuple[Tensor, Tensor], Tensor) -> Tensor
+ return y
- def test_call_script_module_from_traced_module(self):
- class ScriptMod(torch.jit.ScriptModule):
- def __init__(self):
- super(ScriptMod, self).__init__()
- self.param = torch.nn.Parameter(torch.rand(5, 7))
+ @torch.jit.script
+ def some_func(x):
+ return to_inline((x, x), x)
- @torch.jit.script_method
- def forward(self, x):
- return torch.mm(x, self.param)
+ x = torch.rand(3, 4)
+ self.assertEqual(some_func(x), x)
- class TracedModule(torch.nn.Module):
- def __init__(self):
- super(TracedModule, self).__init__()
- self.param = torch.nn.Parameter(torch.rand(4, 5))
- self.mod = ScriptMod()
+ def test_file_format_serialization(self):
+ import tempfile
+ filename = tempfile.mktemp()
+ writer = torch._C.PyTorchFileWriter(filename)
+ import os
+ import random
+ buffers = [os.urandom(size) for size in [random.randint(1, 100) for i in range(20)]]
+ offsets = []
+ for i, buf in enumerate(buffers):
+ writer.write_record(str(i), buf, len(buf))
+ offsets.append(i)
+ import pickle
+ serialized_offsets = pickle.dumps(offsets)
+ writer.write_record("meta", serialized_offsets, len(serialized_offsets))
+ writer.write_end_of_file()
- def forward(self, x):
- return self.mod(torch.mm(x, self.param)) + 1.0
+ reader = torch._C.PyTorchFileReader(filename)
+ serialized_offsets_read = reader.get_record("meta")
+ parsed_serialized_offsets = pickle.loads(serialized_offsets)
- tm = torch.jit.trace(TracedModule(), torch.rand(3, 4))
+ for i, offset in enumerate(parsed_serialized_offsets):
+ data = reader.get_record(str(offset))
+ assert(data == buffers[i])
- # Note: the parameters from both modules should appear in the flattened
- # inputs of the graph. All ops from both modules should be inlined.
- self.assertExpected(canonical(tm.graph))
+ # for each type, the input type annotation and corresponding return type annotation
+ def type_input_return_pairs(self):
+ return [
+ ('Tensor', 'Tensor'),
+ ('torch.Tensor', 'Tensor'),
+ ('str', 'str'),
+ ('int', 'int'),
+ ('bool', 'bool'),
+ ('BroadcastingList3[float]', 'List[float]'),
+ ('BroadcastingList2[int]', 'List[int]'),
+ ('List[int]', 'List[int]'),
+ ('Optional[int]', 'Optional[int]'),
+ ]
- def test_call_python_fn_from_script_fn(self):
- def python_fn(x):
- return torch.neg(x)
+ # replacing code input & return type pair
+ def format_code(self, code, pair):
+ return code.format(input=pair[0], output=pair[1])
- @torch.jit.script
- def script_fn(x):
- return python_fn(x) + 1
+ # ***** Type annotation tests ****
+ # Test combinations of:
+ # {String frontend, Python AST Frontend}
+ # {Python 3-style type annotations, MyPy-style type comments}
+ # {Script method, Script function}
- # Note: the call to python_fn appears as `^python_fn()` and is called
- # as a PythonOp in the interpreter
- self.assertExpected(canonical(script_fn.graph))
+ # String frontend , Python 3-style type annotations , Script function
+ def test_annot_string_py3_fn(self):
+ code = '''
+ def foo(x : {input}, y : Tuple[Tensor, Tensor]) -> Tuple[{output}, {output}]:
+ return x, x
+ '''
+ test_str = []
+ for pair in self.type_input_return_pairs():
+ cu = torch.jit.CompilationUnit(self.format_code(code, pair))
+ test_str.append(cu.__getattr__('foo').pretty_print_schema())
+ self.assertExpected("\n".join(test_str))
- def test_call_python_mod_from_script_fn(self):
- class PythonModule(torch.nn.Module):
+ # String frontend , Python 3-style type annotations , Script method
+ def test_annot_string_py3_method(self):
+ class TestModule(torch.jit.ScriptModule):
def __init__(self):
- super(PythonModule, self).__init__()
- self.param = torch.nn.Parameter(torch.rand(5, 7))
+ super(TestModule, self).__init__()
- def forward(self, x):
- return torch.mm(x, self.param)
+ code = '''
+ def foo(self, x : {input}, y : Tuple[Tensor, Tensor]) -> Tuple[{output}, {output}]:
+ return x, x
+ '''
+ test_str = []
+ for pair in self.type_input_return_pairs():
+ tm = TestModule()
+ tm.define(self.format_code(code, pair))
+ test_str.append(tm.__getattr__('foo').pretty_print_schema())
+ self.assertExpected("\n".join(test_str))
- pm = PythonModule()
+ # String frontend , MyPy-style type comments , Script function
+ def test_annot_string_mypy_fn(self):
+ code = '''
+ def foo(x, y):
+ # type: ({input}, Tuple[Tensor, Tensor]) -> Tuple[{output}, {output}]
+ return x, x
+ '''
+ test_str = []
+ for pair in self.type_input_return_pairs():
+ cu = torch.jit.CompilationUnit(self.format_code(code, pair))
+ test_str.append(cu.__getattr__('foo').pretty_print_schema())
+ self.assertExpected("\n".join(test_str))
- @torch.jit.script
- def script_fn(x):
- return pm(x) + 1
+ # String frontend , MyPy-style type comments , Script method
+ def test_annot_string_mypy_method(self):
+ class TestModule(torch.jit.ScriptModule):
+ def __init__(self):
+ super(TestModule, self).__init__()
- # Note: call to pm(x) appears as ^<python_value>() in the trace.
- # Parameters are NOT inlined.
- self.assertExpected(str(script_fn.graph))
+ code = '''
+ def foo(self, x, y):
+ # type: ({input}, Tuple[Tensor, Tensor]) -> Tuple[{output}, {output}]
+ return x, x
+ '''
- def test_call_traced_fn_from_script_fn(self):
- @_trace(torch.rand(3, 4))
- def traced_fn(x):
- return torch.neg(x)
+ test_str = []
+ for pair in self.type_input_return_pairs():
+ tm = TestModule()
+ tm.define(self.format_code(code, pair))
+ test_str.append(tm.__getattr__('foo').pretty_print_schema())
+ self.assertExpected("\n".join(test_str))
- @torch.jit.script
- def script_fn(x):
- return traced_fn(x) + 1
+ # Helper function to eval Python3 code without causing a syntax error for
+ # this file under py2
+ def _get_py3_code(self, code, fn_name):
+ with tempfile.TemporaryDirectory() as tmp_dir:
+ script_path = os.path.join(tmp_dir, 'script.py')
+ with open(script_path, 'w') as f:
+ f.write(code)
+ import importlib.util
+ spec = importlib.util.spec_from_file_location(fn_name, script_path)
+ module = importlib.util.module_from_spec(spec)
+ spec.loader.exec_module(module)
+ fn = getattr(module, fn_name)
+ return fn
- # Note: the neg op from traced_fn should be properly inlined into the
- # script function's graph
- self.assertExpected(str(script_fn.graph))
+ # Python AST Frontend , Python 3-style type annotations , Script function
+ @unittest.skipIf(not PY35, "Python 3.5 needed")
+ def test_annot_ast_py3_fn(self):
+ code = dedent('''
+ from typing import Tuple, List, Optional
+ from torch import Tensor
+ from torch.jit.annotations import BroadcastingList2, BroadcastingList3
+ import torch
+ @torch.jit.script
+ def foo(x : {input}, y : Tuple[Tensor, Tensor]) -> Tuple[{output}, {output}]:
+ return x, x
+ ''')
+ test_str = []
+ for pair in self.type_input_return_pairs():
+ fn = self._get_py3_code(self.format_code(code, pair), 'foo')
+ test_str.append(fn.__getattr__('forward').pretty_print_schema())
+ self.assertExpected("\n".join(test_str))
- def test_call_traced_mod_from_script_fn(self):
- class TracedModule(torch.nn.Module):
- def __init__(self):
- super(TracedModule, self).__init__()
+ # Python AST Frontend , Python 3-style type annotations , Script method
+ @unittest.skipIf(not PY35, "Python 3.5 needed")
+ def test_annot_ast_py3_method(self):
+ code = dedent('''
+ from typing import Tuple, List, Optional
+ from torch import Tensor
+ from torch.jit.annotations import BroadcastingList2, \\
+ BroadcastingList3
+ import torch
+ class FooModule(torch.jit.ScriptModule):
+ @torch.jit.script_method
+ def foo(self, x : {input}, y : Tuple[Tensor, Tensor]) -> Tuple[{output}, {output}]:
+ return x, x
+ instance = FooModule()
+ ''')
- def forward(self, x):
- return torch.mm(x, torch.zeros(4, 3))
+ test_str = []
+ for pair in self.type_input_return_pairs():
+ fn = self._get_py3_code(self.format_code(code, pair), 'instance')
+ test_str.append(fn.__getattr__('foo').pretty_print_schema())
+ self.assertExpected("\n".join(test_str))
- tm = torch.jit.trace(TracedModule(), torch.rand(3, 4))
+ # Python AST Frontend , MyPy-style type comments , Script function
+ @unittest.skipIf(not PY35, "Python 3.5 needed")
+ def test_annot_ast_mypy_fn(self):
+ code = dedent('''
+ import torch
+ @torch.jit.script
+ def foo(x, y):
+ # type: ({input}, Tuple[Tensor, Tensor]) -> Tuple[{output}, {output}]
+ return x, x
+ ''')
- @torch.jit.script
- def script_fn(x):
- return tm(x) + 1
+ test_str = []
+ for pair in self.type_input_return_pairs():
+ fn = self._get_py3_code(self.format_code(code, pair), 'foo')
+ test_str.append(fn.__getattr__('forward').pretty_print_schema())
+ self.assertExpected("\n".join(test_str))
- self.assertExpected(str(script_fn.graph))
+ # Python AST Frontend , MyPy-style type comments , Script method
+ @unittest.skipIf(not PY35, "Python 3.5 needed")
+ def test_annot_ast_mypy_method(self):
+ code = dedent('''
+ import torch
+ class FooModule(torch.jit.ScriptModule):
+ @torch.jit.script_method
+ def foo(self, x, y):
+ # type: ({input}, Tuple[Tensor, Tensor]) -> Tuple[{output}, {output}]
+ return x, x
+ instance = FooModule()
+ ''')
- def test_call_script_fn_from_script_fn(self):
- @torch.jit.script
- def script_fn1(x):
- return torch.neg(x)
+ test_str = []
+ for pair in self.type_input_return_pairs():
+ fn = self._get_py3_code(self.format_code(code, pair), 'instance')
+ test_str.append(fn.__getattr__('foo').pretty_print_schema())
+ self.assertExpected("\n".join(test_str))
- @torch.jit.script
- def script_fn(x):
- return script_fn1(x) + 1
+ def test_method_casts_script(self):
+ cast_types = [
+ 'byte', 'char', 'double', 'float', 'int', 'long', 'short'
+ ]
- # Note: the neg op from script_fn1 should be properly inlined into the
- # graph of script_fn
- self.assertExpected(canonical(script_fn.graph))
+ for cast_type in cast_types:
+ cu = torch.jit.CompilationUnit('''
+ def cast_to(x):
+ return x.{cast_type}()
+ '''.format(cast_type=cast_type))
- def test_call_script_mod_from_script_fn(self):
- class ScriptMod(torch.jit.ScriptModule):
- def __init__(self):
- super(ScriptMod, self).__init__()
+ x = torch.rand(3, 4, 5) * 128
+ cu_result = cu.cast_to(x)
+ reference = getattr(x, cast_type)()
+ self.assertEqual(cu_result, reference)
- @torch.jit.script_method
+ def test_listconstruct_erasure(self):
+ class FooMod(torch.nn.Module):
def forward(self, x):
- return torch.mm(x, torch.zeros([4, 3]))
+ mask = x < 0.0
+ return x[mask]
- sm = ScriptMod()
+ import io
+ f = io.BytesIO()
+ self.assertExpected(torch.onnx.export_to_pretty_string(
+ FooMod(), (torch.rand(3, 4),), f,
+ operator_export_type=OperatorExportTypes.ONNX_ATEN_FALLBACK))
- @torch.jit.script
- def script_fn(x):
- return sm(x) + 1
+ def test_trace_checker_arange_as_constant(self):
+ with self.assertRaisesRegex(torch.jit.TracingCheckError, r'Graphs differed across invocations!'):
+ @_trace(torch.rand(3, 4), check_inputs=[(torch.rand(4, 5),)])
+ def foo(x):
+ y = torch.arange(0, x.shape[0]).double()
+ return x + y.unsqueeze(1)
- self.assertExpected(canonical(script_fn.graph))
+ @suppress_warnings
+ def test_trace_checker_dot_data(self):
+ with self.assertRaisesRegex(torch.jit.TracingCheckError, r'Tensor-valued Constant nodes differed in value '
+ r'across invocations'):
+ @_trace(torch.rand(3, 4), check_inputs=[(torch.rand(3, 4),)])
+ def foo(x):
+ y = x.data
+ return x + y
- def test_call_python_fn_from_script_module(self):
- def python_fn(x):
- return torch.neg(x)
+ @suppress_warnings
+ def test_trace_checker_control_flow(self):
+ def foo(x):
+ for _ in range(x.size(0)):
+ x = torch.neg(x)
+ return x
- class ScriptMod(torch.jit.ScriptModule):
- def __init__(self):
- super(ScriptMod, self).__init__()
- self.param = torch.nn.Parameter(torch.rand(4, 3))
+ with self.assertRaisesRegex(torch.jit.TracingCheckError, r'Graphs differed across invocations!'):
+ torch.jit.trace(foo, torch.randn(3, 4), check_inputs=[torch.randn(4, 4)])
- @torch.jit.script_method
- def forward(self, x):
- return python_fn(torch.mm(x, self.param))
+ @suppress_warnings
+ def test_trace_checker_memoization(self):
+ with self.assertRaisesRegex(torch.jit.TracingCheckError, r'Graphs differed across invocations!'):
+ def foo(x):
+ if not hasattr(foo, 'cache'):
+ foo.cache = torch.neg(x)
+ return x + foo.cache
- sm = ScriptMod()
- self.assertExpected(str(sm.__getattr__('forward').graph))
+ traced = torch.jit.trace(foo, torch.rand(3, 4), check_inputs=[(torch.rand(3, 4),)])
- def test_call_python_mod_from_script_module(self):
- class PythonMod(torch.nn.Module):
- def __init__(self):
- super(PythonMod, self).__init__()
- self.param = torch.nn.Parameter(torch.rand(3, 5))
+ def checkTracerWarning(self, *args, **kwargs):
+ with warnings.catch_warnings(record=True) as warns:
+ torch.jit.trace(*args, **kwargs)
+ self.assertGreater(len(warns), 0)
+ for warn in warns:
+ self.assertIn("cause the trace to be incorrect", str(warn.message))
- def forward(self, x):
- return torch.mm(x, self.param)
+ def test_trace_checker_slice_lhs(self):
+ def foo(x):
+ for i in range(3):
+ x[i, :] = torch.zeros(4)
+ return x
- class ScriptMod(torch.jit.ScriptModule):
- def __init__(self):
- super(ScriptMod, self).__init__()
- self.param = torch.nn.Parameter(torch.rand(4, 3))
- self.pm = PythonMod()
+ self.assertWarnsRegex(lambda: torch.jit.trace(foo, torch.rand(3, 4), check_inputs=[torch.rand(3, 4)]),
+ 'Output nr 1. of the traced function does not match the '
+ 'corresponding output of the Python function')
- @torch.jit.script_method
- def forward(self, x):
- return self.pm(torch.mm(x, self.param))
+ def test_trace_checker_inplace_on_view(self):
+ def foo(x):
+ x.view(-1).add_(-x.view(-1))
+ return x
- sm = ScriptMod()
- # Note: the call into PythonMod appears as ^<python_value>(). Parameters
- # are NOT inlined
- self.assertExpected(str(sm.graph))
+ self.assertWarnsRegex(lambda: torch.jit.trace(foo,
+ torch.rand(3, 4),
+ check_inputs=[torch.rand(5, 6)],
+ _force_outplace=True),
+ 'Output nr 1. of the traced function does not match the '
+ 'corresponding output of the Python function')
- def test_call_tracing_fn_from_script_module(self):
- @_trace(torch.rand(3, 3))
- def traced_fn(x):
- return torch.neg(x)
+ def test_lhs_index_fails(self):
+ def foo(x):
+ x[0, 1] = 4
+ return x
+ self.checkTracerWarning(foo, torch.rand(3, 4), _force_outplace=True)
- class ScriptMod(torch.jit.ScriptModule):
- def __init__(self):
- super(ScriptMod, self).__init__()
- self.param = torch.nn.Parameter(torch.rand(4, 3))
+ def test_lhs_index_trivial(self):
+ def foo(y, x):
+ y[...] = x
+ return y
+ self.checkTrace(foo, (torch.rand(3, 4), torch.rand(4)), inputs_require_grads=False)
- @torch.jit.script_method
- def forward(self, x):
- return traced_fn(torch.mm(x, self.param))
+ def test_inplace_warn(self):
+ def foo(x):
+ x.view(-1).add_(-x.view(-1))
+ return x
+ self.checkTracerWarning(foo, torch.rand(3, 4), _force_outplace=True)
- sm = ScriptMod()
- self.assertExpected(str(sm.__getattr__('forward').graph))
+ @suppress_warnings
+ def test_trace_checker_dropout_train(self):
+ def foo(x):
+ return torch.dropout(x, p=0.5, train=True)
- def test_call_tracing_mod_from_script_module(self):
- class TracedMod(torch.nn.Module):
- def __init__(self):
- super(TracedMod, self).__init__()
- self.param = torch.nn.Parameter(torch.rand(3, 5))
+ self.assertWarnsRegex(lambda: torch.jit.trace(foo, torch.rand(3, 4), check_inputs=[torch.rand(5, 6)]),
+ 'Output nr 1. of the traced function does not match the '
+ 'corresponding output of the Python function')
+ self.assertWarnsRegex(lambda: torch.jit.trace(foo, torch.rand(3, 4), check_inputs=[torch.rand(5, 6)]),
+ 'Trace had nondeterministic nodes')
- def forward(self, x):
- return torch.mm(x, self.param)
+ def test_trace_checker_dropout_notrain(self):
+ input = torch.rand(3, 4)
- class ScriptMod(torch.jit.ScriptModule):
- def __init__(self):
- super(ScriptMod, self).__init__()
- self.param = torch.nn.Parameter(torch.rand(4, 3))
- self.tm = torch.jit.trace(TracedMod(), torch.rand(3, 3))
+ @_trace(input)
+ def foo(x):
+ return torch.dropout(x, p=0.5, train=False)
+
+ self.assertEqual(foo(input), input)
+ def test_export_dynamic_slice(self):
+ class DynamicSliceExportMod(torch.jit.ScriptModule):
@torch.jit.script_method
def forward(self, x):
- return self.tm(torch.mm(x, self.param))
+ retval = x[0]
+ for i in range(x.size(1)):
+ retval += torch.sum(x[0:i], dim=0)
+ return retval
- sm = ScriptMod()
- # Note: the parameters from both modules should appear in the flattened
- # input list to the graph. The mm op from TracedMod should be properly
- # inlined
- self.assertExpected(str(sm.graph))
+ mod = DynamicSliceExportMod()
- def test_call_script_fn_from_script_module(self):
- @torch.jit.script
- def script_fn(x):
- return torch.neg(x)
+ input = torch.rand(3, 4, 5)
+ example_outs = mod(input)
- class ScriptMod(torch.jit.ScriptModule):
- def __init__(self):
- super(ScriptMod, self).__init__()
- self.param = torch.nn.Parameter(torch.rand(4, 3))
+ f = io.BytesIO()
+ exported = torch.onnx.export_to_pretty_string(
+ DynamicSliceExportMod(), (input,), f, example_outputs=example_outs)
+ self.assertExpected(exported)
- @torch.jit.script_method
- def forward(self, x):
- return script_fn(torch.mm(x, self.param))
+ def test_string_frontend_elif(self):
+ code = '''
+ def elif_test(niter : int):
+ rv = 0
+ for i in range(niter):
+ if i % 3 == 0 and i % 5 == 0:
+ rv += 35
+ elif i % 3 == 0:
+ rv += 3
+ elif i % 5 == 0:
+ rv += 5
+ else:
+ rv += i
+ return rv
+ '''
- sm = ScriptMod()
- self.assertExpected(canonical(sm.__getattr__('forward').graph))
+ self.checkScript(code, (101,), name='elif_test', outputs=3028)
- def test_call_script_mod_from_script_module(self):
- class ScriptMod1(torch.jit.ScriptModule):
- def __init__(self):
- super(ScriptMod1, self).__init__()
- self.param = torch.nn.Parameter(torch.rand(3, 5))
+ def test_addmm_fusion(self):
+ class AddmmWrapper(torch.nn.Module):
+ def forward(self, x, y, c):
+ return torch.mm(x, y) + c
- @torch.jit.script_method
- def forward(self, x):
- return torch.mm(x, self.param)
+ # Test addmm fusion is disabled for normal Jit
+ x, y, c = torch.rand(3, 4), torch.rand(4, 5), torch.rand(3, 5)
+ f = io.BytesIO()
+ pretty = torch.onnx.export_to_pretty_string(AddmmWrapper(), (x, y, c), f)
+ self.assertExpected(pretty, 'onnx')
- class ScriptMod(torch.jit.ScriptModule):
+ jit_trace = torch.jit.trace(AddmmWrapper(), (x, y, c))
+ ge_graph = jit_trace.__getattr__('forward').graph_for(x, y, c)
+ self.assertExpectedGraph(ge_graph, 'jit')
+
+ def test_pyop_exception_message(self):
+ class Foo(torch.jit.ScriptModule):
def __init__(self):
- super(ScriptMod, self).__init__()
- self.param = torch.nn.Parameter(torch.rand(4, 3))
- self.tm = ScriptMod1()
+ super(Foo, self).__init__()
+ self.conv = nn.Conv2d(1, 10, kernel_size=5)
@torch.jit.script_method
def forward(self, x):
- return self.tm(torch.mm(x, self.param))
-
- sm = ScriptMod()
- # Note: the parameters from both modules should appear in the flattened
- # input list to the graph. The mm op from ScriptMod1 should be properly
- # inlined
- self.assertExpected(canonical(sm.graph))
-
- def test_module_with_params_called_fails(self):
- with self.assertRaisesRegex(RuntimeError, "Attempted to inline a Module with parameters. Stateful "
- "modules to be inlined must be submodules of the callee."):
- class ScriptMod(torch.jit.ScriptModule):
- def __init__(self):
- super(ScriptMod, self).__init__()
- self.param = torch.nn.Parameter(torch.rand(3, 3))
-
- @torch.jit.script_method
- def forward(self, x):
- return torch.mm(x, self.param)
-
- sm = ScriptMod()
+ return self.conv(x)
+ foo = Foo()
+ # testing that the correct error message propagates
+ with self.assertRaisesRegex(RuntimeError, "Expected 4-dimensional input for 4-dimensional weight"):
+ foo(torch.ones([123])) # wrong size
- @torch.jit.script
- def some_func(x):
- return sm(x)
+ def test_exceptions(self):
+ cu = torch.jit.CompilationUnit('''
+ def foo(cond):
+ if bool(cond):
+ raise ValueError(3)
+ return 1
+ ''')
- def test_index_put_trace_with_view(self):
- @_trace(torch.rand(100), torch.tensor([1, 2, 3, 4]), torch.rand(1, 1, 1, 4))
- def test_index_put(target, indices, rhs):
- target[indices] = rhs
- return target
+ cu.foo(torch.tensor(0))
+ with self.assertRaisesRegex(torch.jit.Error, "Exception"):
+ cu.foo(torch.tensor(1))
- self.assertExpectedGraph(test_index_put.graph)
+ @torch.jit.script
+ def foo(cond):
+ a = 3
+ if bool(cond):
+ raise ArbitraryError(a, "hi")
+ if False:
+ raise ArbitraryError
+ return a
- def test_index_put_trace_without_view(self):
- @_trace(torch.rand(100), torch.tensor([1, 2, 3, 4]), torch.rand(4))
- def test_index_put(target, indices, rhs):
- target[indices] = rhs
- return target
+ foo(torch.tensor(0))
+ # we don't currently validate the name of the exception
+ with self.assertRaisesRegex(torch.jit.Error, "Exception"):
+ foo(torch.tensor(1))
- self.assertExpectedGraph(test_index_put.graph)
+ @torch.jit.script
+ def foo_except_used():
+ a = Exception()
+ print(a)
+ raise a
- def test_tuple_indexing(self):
- def tuple_index(a):
- if bool(a):
- b = (1, 2)
- else:
- b = (0, 2)
- return b[-2], b[1]
+ # a not DCEd
+ with self.assertRaisesRegex(RuntimeError, "expected value of type Tensor"):
+ foo_except_used()
- self.checkScript(tuple_index, (torch.tensor([1]),))
- self.checkScript(tuple_index, (torch.tensor([1]),), optimize=True)
- tuple_comp = torch.jit.script(tuple_index)
- self.assertExpectedGraph(tuple_comp.graph)
- self.run_pass('lower_all_tuples', tuple_comp.graph)
- m = torch.jit.ScriptModule()
- m._create_method_from_graph("forward", tuple_comp.graph)
- self.assertEqual(m(torch.tensor(1)), (1, 2))
+ # We don't validate the expr following raise
+ @torch.jit.script
+ def foo():
+ raise 3 + 4
- with self.assertRaisesRegex(RuntimeError, "tuple indices must be integer constants"):
+ # no control flow analysis yet
+ with self.assertRaisesRegex(RuntimeError, "undefined value a"):
@torch.jit.script
- def test_non_constant_input(a):
- if bool(a):
- b = 1
+ def foo():
+ if True:
+ a = 1
else:
- b = 0
- c = (0, 1)
- return c[b]
-
- def test_indexing_float():
- c = (1, 2)
- return c[0.1]
- self.checkScriptRaisesRegex(test_indexing_float, (), Exception,
- "tuple indices must")
+ raise Exception("Hi")
+ return a
- def test_indexing_out_of_bounds_pos():
- c = (1, 2)
- return c[2]
+ def test_assertions(self):
+ cu = torch.jit.CompilationUnit('''
+ def foo(cond):
+ assert bool(cond), "hi"
+ return 0
+ ''')
- self.checkScriptRaisesRegex(test_indexing_out_of_bounds_pos, (), Exception,
- "out of range")
+ cu.foo(torch.tensor(1))
+ with self.assertRaisesRegex(torch.jit.Error, "Exception"):
+ cu.foo(torch.tensor(0))
- def test_indexing_out_of_bounds_neg():
- c = (1, 2)
- return c[-3]
+ @torch.jit.script
+ def foo(cond):
+ assert bool(cond), "hi"
- self.checkScriptRaisesRegex(test_indexing_out_of_bounds_pos, (), Exception,
- "out of range")
+ foo(torch.tensor(1))
+ # we don't currently validate the name of the exception
+ with self.assertRaisesRegex(torch.jit.Error, "Exception"):
+ foo(torch.tensor(0))
- def test_tuple_slicing(self):
- def tuple_slice(a):
- if bool(a):
- b = (1, 2, 3, 4)
- else:
- b = (4, 3, 2, 1)
- c = b[-4:4]
- d = b[0:]
- e = c[1:-1]
- return e
+ def test_weak_script_function(self):
+ outer_var = 10
+ outer_var2 = 11
- self.checkScript(tuple_slice, (torch.tensor([1]),), optimize=True)
- tuple_comp = torch.jit.script(tuple_slice)
- self.assertExpectedGraph(tuple_comp.graph)
- self.run_pass('lower_all_tuples', tuple_comp.graph)
- self.assertTrue('Tuple' not in str(tuple_comp.graph))
- m = torch.jit.ScriptModule()
- m._create_method_from_graph("forward", tuple_comp.graph)
- self.assertEqual(m(torch.tensor(1)), (2, 3))
+ def not_a_script_fn(x):
+ return x + 2
@torch.jit.script
- def test_indexing_end_out_of_bounds():
- c = (1, 2)
- return c[2:10]
-
- # output is None in script and () in python
- self.assertEqual(test_indexing_end_out_of_bounds(), None)
+ def even_more_inner(x):
+ return x + 1
- def test_unwrap_optional_builtin(self):
- def test(x):
- # type: (Optional[int]) -> int
- x = torch.jit._unwrap_optional(x)
- x = x + x
- return x
+ @torch.jit.script
+ def inner(x):
+ return not_a_script_fn(x) + x + even_more_inner(x)
- self.checkScript(test, (3,))
+ @torch.jit.script
+ def strong_script_fn(x):
+ if bool(x.norm() > 2):
+ x = x + 3
+ return x + 4 + inner(x)
- with self.assertRaisesRegex(AssertionError, "Unwrapping null optional"):
- test(None)
+ @torch._jit_internal.weak_script
+ def weak_script_fn_inner(x):
+ return x + 6 + not_a_script_fn(x)
- test_script = torch.jit.script(test)
- with self.assertRaisesRegex(RuntimeError, "Unwrapping null optional"):
- test_script(None)
+ @torch._jit_internal.weak_script
+ def weak_script_fn(x):
+ return x + 5 + weak_script_fn_inner(x) + weak_script_fn_inner(x)
- @torch.jit.script
- def test_test():
- return torch.jit._unwrap_optional(1)
+ def fn(x):
+ x = not_a_script_fn(x)
+ x = strong_script_fn(x)
+ return weak_script_fn(x)
- with self.assertRaisesRegex(RuntimeError, "cannot match an Optional\\[T\\] to None"):
- @torch.jit.script
- def test_no_type():
- # type: () -> int
- return torch.jit._unwrap_optional(None)
+ input = torch.randn(3, 4, 5)
+ self.checkScript(fn, (input,))
- def test_indexing_error(self):
- with self.assertRaisesRegex(RuntimeError, "Indexing only supported on lists, tensors, and tuples"):
- @torch.jit.script
- def test_wrong_type():
- a = 8
- return a[0]
+ def test_python_op_exception(self):
+ def python_op(x):
+ raise Exception("bad!")
- def test_annotated_script_fn(self):
@torch.jit.script
- def foo(x, y, z):
- # type: (Tensor, Tuple[Tensor, Tensor, Tensor], Tuple[Tensor, Tuple[Tensor, Tensor]]) -> Tensor
- return x
+ def fn(x):
+ return python_op(x)
- self.assertExpected(foo.__getattr__('forward').pretty_print_schema())
+ with self.assertRaisesRegex(RuntimeError, "operation failed in interpreter"):
+ fn(torch.tensor(4))
- def test_annotated_script_method(self):
- class SM(torch.jit.ScriptModule):
- @torch.jit.script_method
- def forward(self, x, y):
- # type: (Tuple[Tensor, Tensor], Tensor) -> Tuple[Tensor, Tensor, Tensor]
- return y, y, y
+ def test_trace_contiguous(self):
+ def foo(x):
+ return x[:, :, ::2].contiguous().view(12)
- sm = SM()
+ x = torch.rand(2, 3, 4)
+ traced = torch.jit.trace(foo, (x,))
+ y = traced(x)
+ self.assertNotEqual(x.storage().data_ptr(), y.storage().data_ptr())
- self.assertExpected(sm.__getattr__('forward').pretty_print_schema())
+ # This tests the logic in THPVariable_contiguous. There is short-circuiting
+ # code that prevents us from even getting to VariableType::contiguous, since
+ # it is an optimization that prevents us from acquiring the GIL for touching
+ # the device. We needed to add the tracing logic directly into the
+ # THPVariable_contiguous function only for the path where we are skipping
+ # dispatch into contiguous. We should see an aten::contiguous in this trace!
+ def test_trace_contiguous_short_circuit(self):
+ def foo(x):
+ return x.contiguous()
- def test_annotated_script_fn_return_mismatch(self):
- with self.assertRaisesRegex(RuntimeError, r"Return value at position 0 was annotated as "
- r"having type \(Tensor, Tensor\) but is "
- r"actually of type Tensor"):
- @torch.jit.script
- def return_tup(x):
- # type: (Tensor) -> Tuple[Tuple[Tensor, Tensor], Tensor]
- return x, x
+ x = torch.rand(2, 3, 4)
+ traced = torch.jit.trace(foo, (x,))
+ self.assertExpectedGraph(traced.graph)
- def test_annotated_script_fn_arg_mismatch(self):
- with self.assertRaisesRegex(RuntimeError, r"arguments for call are not valid"):
- @torch.jit.script
- def tuple_arg(x):
- # type: (Tuple[Tensor, Tensor]) -> Tensor
- return x + 1
+ def test_weak_module(self):
- def test_script_non_tensor_args_outputs(self):
- @torch.jit.script
- def fn(x, y):
- # type: (Tensor, float) -> float
- return float((x + y).sum())
+ @torch._jit_internal.weak_module
+ class Weak(torch.nn.Module):
+ __constants__ = ['number']
- x = torch.ones(2, 2)
- z = fn(x, 1)
- self.assertIsInstance(z, float)
- self.assertEqual(z, 8.)
+ def __init__(self):
+ super(Weak, self).__init__()
+ self.number = 199
- @unittest.skip('https://github.com/pytorch/pytorch/issues/9595')
- def test_inline_and_run_annotated_script_fn(self):
- @torch.jit.script
- def to_inline(x, y):
- # type: (Tuple[Tensor, Tensor], Tensor) -> Tensor
- return y
+ def python_op_in_weak_module(self, x):
+ return x + 123
- @torch.jit.script
- def some_func(x):
- return to_inline((x, x), x)
+ @torch._jit_internal.weak_script_method
+ def forward(self, x):
+ return 55 + self.number + self.python_op_in_weak_module(x)
- x = torch.rand(3, 4)
- self.assertEqual(some_func(x), x)
+ class OtherStrong(torch.jit.ScriptModule):
+ __constants__ = ['number']
- def test_file_format_serialization(self):
- import tempfile
- filename = tempfile.mktemp()
- writer = torch._C.PyTorchFileWriter(filename)
- import os
- import random
- buffers = [os.urandom(size) for size in [random.randint(1, 100) for i in range(20)]]
- offsets = []
- for i, buf in enumerate(buffers):
- writer.write_record(str(i), buf, len(buf))
- offsets.append(i)
- import pickle
- serialized_offsets = pickle.dumps(offsets)
- writer.write_record("meta", serialized_offsets, len(serialized_offsets))
- writer.write_end_of_file()
+ def __init__(self):
+ super(OtherStrong, self).__init__()
+ self.number = 357
- reader = torch._C.PyTorchFileReader(filename)
- serialized_offsets_read = reader.get_record("meta")
- parsed_serialized_offsets = pickle.loads(serialized_offsets)
+ def python_op_in_strong_module(self, x):
+ return x + 456
- for i, offset in enumerate(parsed_serialized_offsets):
- data = reader.get_record(str(offset))
- assert(data == buffers[i])
+ @torch.jit.script_method
+ def forward(self, x):
+ return x + self.number + self.python_op_in_strong_module(x)
- # for each type, the input type annotation and corresponding return type annotation
- def type_input_return_pairs(self):
- return [
- ('Tensor', 'Tensor'),
- ('torch.Tensor', 'Tensor'),
- ('str', 'str'),
- ('int', 'int'),
- ('bool', 'bool'),
- ('BroadcastingList3[float]', 'List[float]'),
- ('BroadcastingList2[int]', 'List[int]'),
- ('List[int]', 'List[int]'),
- ('Optional[int]', 'Optional[int]'),
- ]
+ class Passthrough(torch.jit.ScriptModule):
+ def __init__(self):
+ super(Passthrough, self).__init__()
+ self.weak = Weak()
- # replacing code input & return type pair
- def format_code(self, code, pair):
- return code.format(input=pair[0], output=pair[1])
+ @torch.jit.script_method
+ def forward(self, x):
+ return self.weak(x)
- # ***** Type annotation tests ****
- # Test combinations of:
- # {String frontend, Python AST Frontend}
- # {Python 3-style type annotations, MyPy-style type comments}
- # {Script method, Script function}
+ weak_mod = Weak()
+ x = torch.ones(1)
+ expected_result = 55 + 199 + (x + 123)
- # String frontend , Python 3-style type annotations , Script function
- def test_annot_string_py3_fn(self):
- code = '''
- def foo(x : {input}, y : Tuple[Tensor, Tensor]) -> Tuple[{output}, {output}]:
- return x, x
- '''
- test_str = []
- for pair in self.type_input_return_pairs():
- cu = torch.jit.CompilationUnit(self.format_code(code, pair))
- test_str.append(cu.__getattr__('foo').pretty_print_schema())
- self.assertExpected("\n".join(test_str))
+ # Ensure weak mod is running without the JIT by passing the wrong type
+ # (i.e. not a tensor)
+ weak_mod(2)
- # String frontend , Python 3-style type annotations , Script method
- def test_annot_string_py3_method(self):
- class TestModule(torch.jit.ScriptModule):
+ python_result = weak_mod(x)
+ strong_mod = Passthrough()
+ script_result = strong_mod(x)
+
+ self.assertEqual(python_result, expected_result)
+ self.assertEqual(script_result, expected_result)
+
+ class Strong(torch.jit.ScriptModule):
def __init__(self):
- super(TestModule, self).__init__()
+ super(Strong, self).__init__()
+ self.weak = Weak()
+ self.strong = OtherStrong()
- code = '''
- def foo(self, x : {input}, y : Tuple[Tensor, Tensor]) -> Tuple[{output}, {output}]:
- return x, x
- '''
- test_str = []
- for pair in self.type_input_return_pairs():
- tm = TestModule()
- tm.define(self.format_code(code, pair))
- test_str.append(tm.__getattr__('foo').pretty_print_schema())
- self.assertExpected("\n".join(test_str))
+ @torch.jit.script_method
+ def forward(self, x):
+ y = 2 * x
+ return y + 1 + self.weak(y) + self.strong(y)
- # String frontend , MyPy-style type comments , Script function
- def test_annot_string_mypy_fn(self):
- code = '''
- def foo(x, y):
- # type: ({input}, Tuple[Tensor, Tensor]) -> Tuple[{output}, {output}]
- return x, x
- '''
- test_str = []
- for pair in self.type_input_return_pairs():
- cu = torch.jit.CompilationUnit(self.format_code(code, pair))
- test_str.append(cu.__getattr__('foo').pretty_print_schema())
- self.assertExpected("\n".join(test_str))
+ strong_mod = Strong()
+ strong_mod2 = Strong()
+ x = torch.ones(1)
+ expected_result = (x * 2) + 1 + (55 + 199 + x * 2 + 123) + (x * 2 + 357 + x * 2 + 456)
+ script_result = strong_mod(x)
+ script_result2 = strong_mod2(x)
+ self.assertEqual(script_result, expected_result)
+ self.assertEqual(script_result, script_result2)
- # String frontend , MyPy-style type comments , Script method
- def test_annot_string_mypy_method(self):
- class TestModule(torch.jit.ScriptModule):
- def __init__(self):
- super(TestModule, self).__init__()
+ def test_weak_module_parameters_and_buffers(self):
+ weights = torch.randn(10, 10)
+ bias = torch.randn(10)
+ weights2 = torch.randn(10, 10)
+ bias2 = torch.randn(10)
- code = '''
- def foo(self, x, y):
- # type: ({input}, Tuple[Tensor, Tensor]) -> Tuple[{output}, {output}]
- return x, x
- '''
+ @torch._jit_internal.weak_module
+ class TestLinear(torch.nn.Module):
+ def __init__(self, in_features, out_features):
+ super(TestLinear, self).__init__()
+ self.in_features = in_features
+ self.out_features = out_features
+ self.weight = torch.nn.Parameter(torch.Tensor(out_features, in_features))
+ self.bias = torch.nn.Parameter(torch.Tensor(out_features))
+ self.register_buffer('counter', torch.ones(out_features))
+ self.reset_parameters()
- test_str = []
- for pair in self.type_input_return_pairs():
- tm = TestModule()
- tm.define(self.format_code(code, pair))
- test_str.append(tm.__getattr__('foo').pretty_print_schema())
- self.assertExpected("\n".join(test_str))
+ def reset_parameters(self):
+ torch.nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
+ if self.bias is not None:
+ fan_in, _ = torch.nn.init._calculate_fan_in_and_fan_out(self.weight)
+ bound = 1 / math.sqrt(fan_in)
+ torch.nn.init.uniform_(self.bias, -bound, bound)
- # Helper function to eval Python3 code without causing a syntax error for
- # this file under py2
- def _get_py3_code(self, code, fn_name):
- with tempfile.TemporaryDirectory() as tmp_dir:
- script_path = os.path.join(tmp_dir, 'script.py')
- with open(script_path, 'w') as f:
- f.write(code)
- import importlib.util
- spec = importlib.util.spec_from_file_location(fn_name, script_path)
- module = importlib.util.module_from_spec(spec)
- spec.loader.exec_module(module)
- fn = getattr(module, fn_name)
- return fn
+ @torch._jit_internal.weak_script_method
+ def forward(self, input):
+ return F.linear(input, self.weight, self.bias) + self.counter
- # Python AST Frontend , Python 3-style type annotations , Script function
- @unittest.skipIf(not PY35, "Python 3.5 needed")
- def test_annot_ast_py3_fn(self):
- code = dedent('''
- from typing import Tuple, List, Optional
- from torch import Tensor
- from torch.jit.annotations import BroadcastingList2, BroadcastingList3
- import torch
- @torch.jit.script
- def foo(x : {input}, y : Tuple[Tensor, Tensor]) -> Tuple[{output}, {output}]:
- return x, x
- ''')
- test_str = []
- for pair in self.type_input_return_pairs():
- fn = self._get_py3_code(self.format_code(code, pair), 'foo')
- test_str.append(fn.__getattr__('forward').pretty_print_schema())
- self.assertExpected("\n".join(test_str))
+ # Initialize a ScriptModule that uses the weak module above multiple times
+ class Strong(torch.jit.ScriptModule):
+ def __init__(self):
+ super(Strong, self).__init__()
+ self.fc1 = TestLinear(10, 10)
+ self.fc1.weight = torch.nn.Parameter(weights)
+ self.fc1.bias = torch.nn.Parameter(bias)
+ self.fc2 = TestLinear(10, 10)
+ self.fc2.weight = torch.nn.Parameter(weights2)
+ self.fc2.bias = torch.nn.Parameter(bias2)
- # Python AST Frontend , Python 3-style type annotations , Script method
- @unittest.skipIf(not PY35, "Python 3.5 needed")
- def test_annot_ast_py3_method(self):
- code = dedent('''
- from typing import Tuple, List, Optional
- from torch import Tensor
- from torch.jit.annotations import BroadcastingList2, \\
- BroadcastingList3
- import torch
- class FooModule(torch.jit.ScriptModule):
- @torch.jit.script_method
- def foo(self, x : {input}, y : Tuple[Tensor, Tensor]) -> Tuple[{output}, {output}]:
- return x, x
- instance = FooModule()
- ''')
+ @torch.jit.script_method
+ def forward(self, x):
+ return x + self.fc1(x) + self.fc1(x) + self.fc2(x)
- test_str = []
- for pair in self.type_input_return_pairs():
- fn = self._get_py3_code(self.format_code(code, pair), 'instance')
- test_str.append(fn.__getattr__('foo').pretty_print_schema())
- self.assertExpected("\n".join(test_str))
+ strong_mod = Strong()
- # Python AST Frontend , MyPy-style type comments , Script function
- @unittest.skipIf(not PY35, "Python 3.5 needed")
- def test_annot_ast_mypy_fn(self):
- code = dedent('''
- import torch
- @torch.jit.script
- def foo(x, y):
- # type: ({input}, Tuple[Tensor, Tensor]) -> Tuple[{output}, {output}]
- return x, x
- ''')
+ # Run same calculation as module
+ inp = torch.ones(10)
+ lin = torch.nn.Linear(10, 10)
+ lin.weight = torch.nn.Parameter(weights)
+ lin.bias = torch.nn.Parameter(bias)
+ lin2 = torch.nn.Linear(10, 10)
+ lin2.weight = torch.nn.Parameter(weights2)
+ lin2.bias = torch.nn.Parameter(bias2)
+ expected_result = inp + (lin(inp) + torch.ones(10)) * 2 + lin2(inp) + torch.ones(10)
- test_str = []
- for pair in self.type_input_return_pairs():
- fn = self._get_py3_code(self.format_code(code, pair), 'foo')
- test_str.append(fn.__getattr__('forward').pretty_print_schema())
- self.assertExpected("\n".join(test_str))
+ self.assertEqual(strong_mod(inp), expected_result)
+ self.assertExportImportModule(strong_mod, (inp,))
- # Python AST Frontend , MyPy-style type comments , Script method
- @unittest.skipIf(not PY35, "Python 3.5 needed")
- def test_annot_ast_mypy_method(self):
- code = dedent('''
- import torch
- class FooModule(torch.jit.ScriptModule):
- @torch.jit.script_method
- def foo(self, x, y):
- # type: ({input}, Tuple[Tensor, Tensor]) -> Tuple[{output}, {output}]
- return x, x
- instance = FooModule()
- ''')
+ def test_weak_module_nested(self):
+ @torch._jit_internal.weak_module
+ class OtherWeak(torch.nn.Module):
+ __constants__ = ['constant']
- test_str = []
- for pair in self.type_input_return_pairs():
- fn = self._get_py3_code(self.format_code(code, pair), 'instance')
- test_str.append(fn.__getattr__('foo').pretty_print_schema())
- self.assertExpected("\n".join(test_str))
+ def __init__(self, in_features, out_features):
+ super(OtherWeak, self).__init__()
+ self.in_features = in_features
+ self.out_features = out_features
+ self.weight = torch.nn.Parameter(torch.ones(out_features, in_features))
+ self.bias = torch.nn.Parameter(torch.ones(out_features))
+ self.constant = 3
- def test_method_casts_script(self):
- cast_types = [
- 'byte', 'char', 'double', 'float', 'int', 'long', 'short'
- ]
+ @torch._jit_internal.weak_script_method
+ def forward(self, x):
+ return x * x + self.constant + F.linear(x, self.weight, self.bias)
- for cast_type in cast_types:
- cu = torch.jit.CompilationUnit('''
- def cast_to(x):
- return x.{cast_type}()
- '''.format(cast_type=cast_type))
+ class OtherStrong(torch.jit.ScriptModule):
- x = torch.rand(3, 4, 5) * 128
- cu_result = cu.cast_to(x)
- reference = getattr(x, cast_type)()
- self.assertEqual(cu_result, reference)
+ def __init__(self):
+ super(OtherStrong, self).__init__()
- def test_listconstruct_erasure(self):
- class FooMod(torch.nn.Module):
+ @torch.jit.script_method
def forward(self, x):
- mask = x < 0.0
- return x[mask]
+ return x + 27
- import io
- f = io.BytesIO()
- self.assertExpected(torch.onnx.export_to_pretty_string(
- FooMod(), (torch.rand(3, 4),), f,
- operator_export_type=OperatorExportTypes.ONNX_ATEN_FALLBACK))
+ @torch._jit_internal.weak_module
+ class Weak(torch.nn.Module):
+ def __init__(self, in_features, out_features):
+ super(Weak, self).__init__()
+ self.in_features = in_features
+ self.out_features = out_features
+ self.weight = torch.nn.Parameter(2 * torch.ones(out_features, in_features))
+ self.bias = torch.nn.Parameter(2 * torch.ones(out_features))
+ self.weak_submodule = OtherWeak(10, 10)
+ self.strong_submodule = OtherStrong()
- def test_trace_checker_arange_as_constant(self):
- with self.assertRaisesRegex(torch.jit.TracingCheckError, r'Graphs differed across invocations!'):
- @_trace(torch.rand(3, 4), check_inputs=[(torch.rand(4, 5),)])
- def foo(x):
- y = torch.arange(0, x.shape[0]).double()
- return x + y.unsqueeze(1)
+ @torch._jit_internal.weak_script_method
+ def forward(self, x):
+ return x + self.weak_submodule(x) + self.strong_submodule(x) \
+ + F.linear(x, self.weight, self.bias)
- @suppress_warnings
- def test_trace_checker_dot_data(self):
- with self.assertRaisesRegex(torch.jit.TracingCheckError, r'Tensor-valued Constant nodes differed in value '
- r'across invocations'):
- @_trace(torch.rand(3, 4), check_inputs=[(torch.rand(3, 4),)])
- def foo(x):
- y = x.data
- return x + y
+ class Strong(torch.jit.ScriptModule):
+ __constants__ = ['constant']
- @suppress_warnings
- def test_trace_checker_control_flow(self):
- def foo(x):
- for _ in range(x.size(0)):
- x = torch.neg(x)
- return x
+ def __init__(self):
+ super(Strong, self).__init__()
+ self.weak = Weak(10, 10)
- with self.assertRaisesRegex(torch.jit.TracingCheckError, r'Graphs differed across invocations!'):
- torch.jit.trace(foo, torch.randn(3, 4), check_inputs=[torch.randn(4, 4)])
+ @torch.jit.script_method
+ def forward(self, x):
+ return x + self.weak(x)
- @suppress_warnings
- def test_trace_checker_memoization(self):
- with self.assertRaisesRegex(torch.jit.TracingCheckError, r'Graphs differed across invocations!'):
- def foo(x):
- if not hasattr(foo, 'cache'):
- foo.cache = torch.neg(x)
- return x + foo.cache
+ strong_mod = Strong()
+ inp = torch.randn(10)
+ result = strong_mod(inp)
+ expected_result = inp + (inp + inp * inp + inp + 27) + 3 \
+ + F.linear(inp, torch.ones(10, 10), torch.ones(10)) \
+ + F.linear(inp, 2 * torch.ones(10, 10), 2 * torch.ones(10))
+ self.assertEqual(result, expected_result)
- traced = torch.jit.trace(foo, torch.rand(3, 4), check_inputs=[(torch.rand(3, 4),)])
+ def test_weak_module_submodule(self):
+ @torch._jit_internal.weak_module
+ class Weak(torch.nn.Module):
+ def __init__(self):
+ super(Weak, self).__init__()
+ self.param = torch.nn.Parameter(100 * torch.ones(5))
- def checkTracerWarning(self, *args, **kwargs):
- with warnings.catch_warnings(record=True) as warns:
- torch.jit.trace(*args, **kwargs)
- self.assertGreater(len(warns), 0)
- for warn in warns:
- self.assertIn("cause the trace to be incorrect", str(warn.message))
+ @torch._jit_internal.weak_script_method
+ def forward(self, x):
+ return x + self.param
+
+ weak = Weak()
- def test_trace_checker_slice_lhs(self):
- def foo(x):
- for i in range(3):
- x[i, :] = torch.zeros(4)
- return x
+ class OtherStrong(torch.jit.ScriptModule):
+ def __init__(self):
+ super(OtherStrong, self).__init__()
+ self.weak = weak
+ self.weak2 = weak
- self.assertWarnsRegex(lambda: torch.jit.trace(foo, torch.rand(3, 4), check_inputs=[torch.rand(3, 4)]),
- 'Output nr 1. of the traced function does not match the '
- 'corresponding output of the Python function')
+ @torch.jit.script_method
+ def forward(self, x):
+ return x + self.weak(x)
- def test_trace_checker_inplace_on_view(self):
- def foo(x):
- x.view(-1).add_(-x.view(-1))
- return x
+ class Strong(torch.jit.ScriptModule):
+ def __init__(self):
+ super(Strong, self).__init__()
+ self.weak = Weak()
- self.assertWarnsRegex(lambda: torch.jit.trace(foo,
- torch.rand(3, 4),
- check_inputs=[torch.rand(5, 6)],
- _force_outplace=True),
- 'Output nr 1. of the traced function does not match the '
- 'corresponding output of the Python function')
+ @torch.jit.script_method
+ def forward(self, x):
+ return self.weak(x) + weak(x)
- def test_lhs_index_fails(self):
- def foo(x):
- x[0, 1] = 4
- return x
- self.checkTracerWarning(foo, torch.rand(3, 4), _force_outplace=True)
+ other_strong_mod = OtherStrong()
- def test_lhs_index_trivial(self):
- def foo(y, x):
- y[...] = x
- return y
- self.checkTrace(foo, (torch.rand(3, 4), torch.rand(4)), inputs_require_grads=False)
+ self.assertIs(other_strong_mod.weak, other_strong_mod.weak2)
- def test_inplace_warn(self):
- def foo(x):
- x.view(-1).add_(-x.view(-1))
- return x
- self.checkTracerWarning(foo, torch.rand(3, 4), _force_outplace=True)
+ with self.assertRaisesRegex(RuntimeError, "Attempted to inline a Module with param"):
+ strong_mod = Strong()
- @suppress_warnings
- def test_trace_checker_dropout_train(self):
- def foo(x):
- return torch.dropout(x, p=0.5, train=True)
+ def test_weak_module_copying(self):
+ class Submodule(torch.nn.Module):
+ def __init__(self):
+ super(Submodule, self).__init__()
- self.assertWarnsRegex(lambda: torch.jit.trace(foo, torch.rand(3, 4), check_inputs=[torch.rand(5, 6)]),
- 'Output nr 1. of the traced function does not match the '
- 'corresponding output of the Python function')
- self.assertWarnsRegex(lambda: torch.jit.trace(foo, torch.rand(3, 4), check_inputs=[torch.rand(5, 6)]),
- 'Trace had nondeterministic nodes')
+ def forward(self, x):
+ return x + 100
- def test_trace_checker_dropout_notrain(self):
- input = torch.rand(3, 4)
+ @torch._jit_internal.weak_module
+ class Weak(torch.nn.Module):
+ def __init__(self, in_features, out_features):
+ super(Weak, self).__init__()
+ self.weight = torch.nn.Parameter(torch.ones(out_features, in_features))
+ self.bias = torch.nn.Parameter(torch.ones(out_features))
+ self.register_buffer("buffer", torch.ones(out_features))
+ self.submodule = Submodule()
- @_trace(input)
- def foo(x):
- return torch.dropout(x, p=0.5, train=False)
+ @torch._jit_internal.weak_script_method
+ def forward(self, x):
+ return F.linear(x, self.weight, self.bias) \
+ + self.buffer + self.submodule(x)
- self.assertEqual(foo(input), input)
+ class Strong(torch.jit.ScriptModule):
+ def __init__(self, weak):
+ super(Strong, self).__init__()
+ self.weak = weak
- def test_export_dynamic_slice(self):
- class DynamicSliceExportMod(torch.jit.ScriptModule):
@torch.jit.script_method
def forward(self, x):
- retval = x[0]
- for i in range(x.size(1)):
- retval += torch.sum(x[0:i], dim=0)
- return retval
+ return self.weak(x)
- mod = DynamicSliceExportMod()
+ inp = torch.ones(5, 5) * 5
+ weak_mod = Weak(5, 5)
+ strong_mod = Strong(weak_mod)
- input = torch.rand(3, 4, 5)
- example_outs = mod(input)
+ self.assertTrue(isinstance(strong_mod.weak, torch.jit.ScriptModule))
+ self.assertFalse(isinstance(weak_mod, torch.jit.ScriptModule))
- f = io.BytesIO()
- exported = torch.onnx.export_to_pretty_string(
- DynamicSliceExportMod(), (input,), f, example_outputs=example_outs)
- self.assertExpected(exported)
+ self.assertIs(strong_mod.weak.weight, weak_mod.weight)
+ self.assertIs(strong_mod.weak.buffer, weak_mod.buffer)
+ self.assertIs(strong_mod.weak.submodule, weak_mod.submodule)
- def test_string_frontend_elif(self):
- code = '''
- def elif_test(niter : int):
- rv = 0
- for i in range(niter):
- if i % 3 == 0 and i % 5 == 0:
- rv += 35
- elif i % 3 == 0:
- rv += 3
- elif i % 5 == 0:
- rv += 5
- else:
- rv += i
- return rv
- '''
+ # Test lookup fallback
+ weak_mod.new_attribute = 10
+ self.assertIs(strong_mod.weak.new_attribute, weak_mod.new_attribute)
- self.checkScript(code, (101,), name='elif_test', outputs=3028)
+ weak_mod.weight.data += torch.ones(5, 5) * 100
+ self.assertTrue(strong_mod(inp).allclose(weak_mod(inp)))
- def test_addmm_fusion(self):
- class AddmmWrapper(torch.nn.Module):
- def forward(self, x, y, c):
- return torch.mm(x, y) + c
+ # Re-assignment is not tracked
+ weak_mod.weight = torch.nn.Parameter(torch.ones(5, 5) * 100)
+ self.assertFalse(strong_mod(inp).allclose(weak_mod(inp)))
- # Test addmm fusion is disabled for normal Jit
- x, y, c = torch.rand(3, 4), torch.rand(4, 5), torch.rand(3, 5)
- f = io.BytesIO()
- pretty = torch.onnx.export_to_pretty_string(AddmmWrapper(), (x, y, c), f)
- self.assertExpected(pretty, 'onnx')
+ def test_backend_cudnn_enabled(self):
+ # Only test that this compiles
+ @torch.jit.script
+ def fn(x):
+ if torch.backends.cudnn.enabled:
+ x = x + 2
+ else:
+ x = x + 3
+ return x
- jit_trace = torch.jit.trace(AddmmWrapper(), (x, y, c))
- ge_graph = jit_trace.__getattr__('forward').graph_for(x, y, c)
- self.assertExpectedGraph(ge_graph, 'jit')
+ def test_inplace_add(self):
- def test_pyop_exception_message(self):
- class Foo(torch.jit.ScriptModule):
- def __init__(self):
- super(Foo, self).__init__()
- self.conv = nn.Conv2d(1, 10, kernel_size=5)
+ def foo(a, b):
+ c = a + b
+ c.add_(b)
+ return c
+ self.checkScript(foo, (torch.rand(3), torch.rand(3)))
- @torch.jit.script_method
- def forward(self, x):
- return self.conv(x)
- foo = Foo()
- # testing that the correct error message propagates
- with self.assertRaisesRegex(RuntimeError, "Expected 4-dimensional input for 4-dimensional weight"):
- foo(torch.ones([123])) # wrong size
+ def test_add_out(self):
+ def foo(a, b):
+ c = a + b
+ e = 2 * a
+ torch.add(c, b, out=e)
+ return e
+ self.checkScript(foo, (torch.rand(3), torch.rand(3)))
- def test_exceptions(self):
- cu = torch.jit.CompilationUnit('''
- def foo(cond):
- if bool(cond):
- raise ValueError(3)
- return 1
- ''')
+ def test_augmented_assign(self):
+ def foo(a, b):
+ a += b
+ a -= b
+ a /= b
+ a *= b
+ return a, b
+ self.checkScript(foo, (torch.rand(3), torch.rand(3)), check_expected=True)
- cu.foo(torch.tensor(0))
- with self.assertRaisesRegex(torch.jit.Error, "Exception"):
- cu.foo(torch.tensor(1))
+ def test_pass(self):
+ def foo(x):
+ # type: (bool) -> int
+ for _i in range(3):
+ pass
+ if x:
+ pass
+ else:
+ pass
+ return 3
- @torch.jit.script
- def foo(cond):
- a = 3
- if bool(cond):
- raise ArbitraryError(a, "hi")
- if False:
- raise ArbitraryError
- return a
+ self.checkScript(foo, (True,))
- foo(torch.tensor(0))
- # we don't currently validate the name of the exception
- with self.assertRaisesRegex(torch.jit.Error, "Exception"):
- foo(torch.tensor(1))
+ def test_optional_conversion(self):
+ @torch.jit.script
+ def other_fn(x=None):
+ # type: (Optional[int]) -> int
+ return torch.jit._unwrap_optional(x)
@torch.jit.script
- def foo_except_used():
- a = Exception()
- print(a)
- raise a
+ def fn(x):
+ # type: (int) -> int
+ return other_fn(x)
- # a not DCEd
- with self.assertRaisesRegex(RuntimeError, "expected value of type Tensor"):
- foo_except_used()
+ self.assertEqual(fn(2), 2)
- # We don't validate the expr following raise
@torch.jit.script
- def foo():
- raise 3 + 4
-
- # no control flow analysis yet
- with self.assertRaisesRegex(RuntimeError, "undefined value a"):
- @torch.jit.script
- def foo():
- if True:
- a = 1
- else:
- raise Exception("Hi")
- return a
+ def unify_to_optional(x):
+ # type: (bool) -> Optional[int]
+ if x:
+ a = None
+ else:
+ a = 2
+ return a
- def test_assertions(self):
- cu = torch.jit.CompilationUnit('''
- def foo(cond):
- assert bool(cond), "hi"
- return 0
- ''')
+ self.assertEqual(unify_to_optional(True), None)
+ self.assertEqual(unify_to_optional(False), 2)
- cu.foo(torch.tensor(1))
- with self.assertRaisesRegex(torch.jit.Error, "Exception"):
- cu.foo(torch.tensor(0))
+ @torch.jit.script
+ def opt_list(x):
+ # type: (Optional[List[float]]) -> int
+ return 2
@torch.jit.script
- def foo(cond):
- assert bool(cond), "hi"
+ def broadcast_opt_list(x):
+ # type: (Optional[BroadcastingList2[float]]) -> int
+ return 2
- foo(torch.tensor(1))
- # we don't currently validate the name of the exception
- with self.assertRaisesRegex(torch.jit.Error, "Exception"):
- foo(torch.tensor(0))
+ @torch.jit.script
+ def opt_list_tuple_caller(x):
+ # type: (Tuple[float, float]) -> int
+ return opt_list(x) + broadcast_opt_list(x)
- def test_weak_script_function(self):
- outer_var = 10
- outer_var2 = 11
+ self.assertEqual(opt_list_tuple_caller((2., 3.)), 4)
- def not_a_script_fn(x):
- return x + 2
+ def test_lhs_indexing(self):
+ def foo(a, b):
+ a = a.clone()
+ a[0] = b
+ return a
+ self.checkScript(foo, (torch.rand(2, 3), torch.rand(3)))
- @torch.jit.script
- def even_more_inner(x):
- return x + 1
+ def test_lhs_advanced_indexing_assignment(self):
+ def foo(x, y):
+ a = torch.exp(x)
+ b = x == 1
+ a[b] = y[b]
+ return a
+ self.checkScript(foo, (torch.ones(4, 3), torch.ones(4, 3)))
- @torch.jit.script
- def inner(x):
- return not_a_script_fn(x) + x + even_more_inner(x)
+ def test_lhs_advanced_indexing_augmented_assignment(self):
+ def foo(x, y):
+ a = torch.exp(x)
+ b = x == 1
+ a[b] += y[b]
+ return a
+ self.checkScript(foo, (torch.ones(4, 3), torch.ones(4, 3)))
- @torch.jit.script
- def strong_script_fn(x):
- if bool(x.norm() > 2):
- x = x + 3
- return x + 4 + inner(x)
+ def test_lhs_indexing_list(self):
+ def foo(a, b):
+ ls = [a]
+ ls[0] = b
+ return ls
+ self.checkScript(foo, (torch.rand(2, 3), torch.rand(3)))
- @torch._jit_internal.weak_script
- def weak_script_fn_inner(x):
- return x + 6 + not_a_script_fn(x)
+ def test_inplace_copy_script(self):
+ def foo(x):
+ a = torch.rand(3, 4)
+ a.copy_(x)
+ return a
+ self.checkScript(foo, (torch.rand(3, 4),))
- @torch._jit_internal.weak_script
- def weak_script_fn(x):
- return x + 5 + weak_script_fn_inner(x) + weak_script_fn_inner(x)
+ def test_lhs_indexing_increment(self):
+ def foo(a, b):
+ a[0] += b
+ return a
+ self.checkScript(foo, (torch.rand(2, 3), torch.rand(3)))
- def fn(x):
- x = not_a_script_fn(x)
- x = strong_script_fn(x)
- return weak_script_fn(x)
+ def test_lhs_indexing_increment_list(self):
+ def foo(a, b):
+ a = a.clone()
+ ls = [a, b]
+ ls[0] += b
+ return ls
+ self.checkScript(foo, (torch.rand(2, 3), torch.rand(3)))
- input = torch.randn(3, 4, 5)
- self.checkScript(fn, (input,))
+ def test_lhs_indexing_increment_list_prim(self):
+ def foo():
+ ls = [1, 2, 3]
+ ls[0] += 5
+ return ls
+ self.checkScript(foo, ())
- def test_python_op_exception(self):
- def python_op(x):
- raise Exception("bad!")
+ def test_lhs_indexing_multi(self):
+ def foo(a, b):
+ a = a.clone()
+ foo, a[0], bar = (1, b, 3)
+ return foo, a, bar
+ self.checkScript(foo, (torch.rand(2, 3), torch.rand(3)))
- @torch.jit.script
- def fn(x):
- return python_op(x)
+ def test_bool_dispatch(self):
+ with self.disableModuleHook(): # TODO: Python print broadcasting list
+ def kwarg_false(x):
+ # type: (Tensor) -> Tensor
+ return F.max_pool1d(x, 1, 1, return_indices=False)
+ self.checkScript(kwarg_false, (torch.randn(3, 3, 3),))
- with self.assertRaisesRegex(RuntimeError, "operation failed in interpreter"):
- fn(torch.tensor(4))
+ def kwarg_true(x):
+ # type: (Tensor) -> Tuple[Tensor, Tensor]
+ return F.max_pool1d(x, 1, 1, return_indices=True)
+ self.checkScript(kwarg_true, (torch.randn(3, 3, 3),))
- def test_trace_contiguous(self):
- def foo(x):
- return x[:, :, ::2].contiguous().view(12)
+ def full_kwarg_false(x):
+ # type: (Tensor) -> Tensor
+ return F.max_pool1d(x, 1, 1, ceil_mode=False, return_indices=False)
+ self.checkScript(full_kwarg_false, (torch.randn(3, 3, 3),))
- x = torch.rand(2, 3, 4)
- traced = torch.jit.trace(foo, (x,))
- y = traced(x)
- self.assertNotEqual(x.storage().data_ptr(), y.storage().data_ptr())
+ def full_kwarg_true(x):
+ # type: (Tensor) -> Tuple[Tensor, Tensor]
+ return F.max_pool1d(x, 1, 1, ceil_mode=False, return_indices=True)
+ self.checkScript(full_kwarg_true, (torch.randn(3, 3, 3),))
- # This tests the logic in THPVariable_contiguous. There is short-circuiting
- # code that prevents us from even getting to VariableType::contiguous, since
- # it is an optimization that prevents us from acquiring the GIL for touching
- # the device. We needed to add the tracing logic directly into the
- # THPVariable_contiguous function only for the path where we are skipping
- # dispatch into contiguous. We should see an aten::contiguous in this trace!
- def test_trace_contiguous_short_circuit(self):
- def foo(x):
- return x.contiguous()
+ def use_default(x):
+ # type: (Tensor) -> Tensor
+ return F.max_pool1d(x, 1, 1)
+ self.checkScript(use_default, (torch.randn(3, 3, 3),))
- x = torch.rand(2, 3, 4)
- traced = torch.jit.trace(foo, (x,))
- self.assertExpectedGraph(traced.graph)
+ def arg_false(x):
+ # type: (Tensor) -> Tensor
+ return F.max_pool1d(x, 1, 1, 0, 1, False, False)
+ self.checkScript(arg_false, (torch.randn(3, 3, 3),))
- def test_weak_module(self):
+ def arg_true(x):
+ # type: (Tensor) -> Tuple[Tensor, Tensor]
+ return F.max_pool1d(x, 1, 1, 0, 1, False, True)
+ self.checkScript(arg_true, (torch.randn(3, 3, 3),))
- @torch._jit_internal.weak_module
- class Weak(torch.nn.Module):
- __constants__ = ['number']
+ def test_infer_size(self):
+ from torch._C import _infer_size
- def __init__(self):
- super(Weak, self).__init__()
- self.number = 199
+ def fn(x, y):
+ # type: (Tensor, Tensor) -> List[int]
+ return _infer_size(x.size(), y.size())
- def python_op_in_weak_module(self, x):
- return x + 123
+ self.checkScript(fn, (torch.ones(2, 4, 2), torch.ones(2, 4, 2)))
- @torch._jit_internal.weak_script_method
- def forward(self, x):
- return 55 + self.number + self.python_op_in_weak_module(x)
+ def test_mutable_dce(self):
+ @torch.jit.script
+ def foo():
+ a = torch.rand(2, 3)
+ a += torch.rand(2, 3)
+ b = torch.rand(2, 3)
+ b += torch.rand(2, 3)
+ # b should be cleaned up but not a
+ return a
- class OtherStrong(torch.jit.ScriptModule):
- __constants__ = ['number']
+ self.assertExpectedGraph(foo.graph)
- def __init__(self):
- super(OtherStrong, self).__init__()
- self.number = 357
+ def test_mutable_dce_block(self):
+ @torch.jit.script
+ def foo():
+ a = torch.rand(2, 3)
+ a += torch.rand(2, 3)
+ b = torch.rand(2, 3)
+ if bool(a > torch.zeros(2, 3)):
+ b += torch.rand(2, 3)
+ a += torch.rand(2, 3)
+ # a should be cleaned up but not b
+ return b
- def python_op_in_strong_module(self, x):
- return x + 456
+ self.assertExpectedGraph(foo.graph)
- @torch.jit.script_method
- def forward(self, x):
- return x + self.number + self.python_op_in_strong_module(x)
+ def test_mutable_dce_graph_input(self):
+ @torch.jit.script
+ def foo(a):
+ a += torch.rand(2, 3)
+ # shouldn't clean up `a` even though it's not used in the output
- class Passthrough(torch.jit.ScriptModule):
- def __init__(self):
- super(Passthrough, self).__init__()
- self.weak = Weak()
+ self.assertExpectedGraph(foo.graph)
+
+ def test_mutable_dce_list(self):
+ @torch.jit.script
+ def foo(a):
+ l = []
+ l.append(a)
+ c = l[0]
+ b = torch.rand(2, 3)
+ c += torch.rand(2, 3)
+ return b
+
+ self.assertExpectedGraph(foo.graph)
+
+ def test_mutable_dce_loop(self):
+ @torch.jit.script
+ def foo(a):
+ l = []
+ l.append(a)
+ i = 0
+ b = torch.rand(2, 3)
+ while i < 1:
+ dead = torch.rand(2, 3)
+ c = l[0]
+ c += torch.rand(2, 3)
+ i += 1
+ return b
- @torch.jit.script_method
- def forward(self, x):
- return self.weak(x)
+ self.assertExpectedGraph(foo.graph)
- weak_mod = Weak()
- x = torch.ones(1)
- expected_result = 55 + 199 + (x + 123)
- # Ensure weak mod is running without the JIT by passing the wrong type
- # (i.e. not a tensor)
- weak_mod(2)
+class MnistNet(nn.Module):
+ def __init__(self):
+ super(MnistNet, self).__init__()
+ self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
+ self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
+ self.conv2_drop = nn.Dropout2d()
+ self.fc1 = nn.Linear(320, 50)
+ self.fc2 = nn.Linear(50, 10)
- python_result = weak_mod(x)
- strong_mod = Passthrough()
- script_result = strong_mod(x)
+ def forward(self, x):
+ x = F.relu(F.max_pool2d(self.conv1(x), 2))
+ x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
+ x = x.view(-1, 320)
+ x = F.relu(self.fc1(x))
+ x = F.dropout(x, training=self.training)
+ x = self.fc2(x)
+ return F.log_softmax(x, dim=1)
- self.assertEqual(python_result, expected_result)
- self.assertEqual(script_result, expected_result)
- class Strong(torch.jit.ScriptModule):
- def __init__(self):
- super(Strong, self).__init__()
- self.weak = Weak()
- self.strong = OtherStrong()
+class TestEndToEndHybridFrontendModels(JitTestCase):
+ @staticmethod
+ def _test_dcgan_models(self, device, check_export_import=True):
+ class DCGANGenerator(nn.Module):
+ def __init__(self, nz, ngf, nc):
+ super(DCGANGenerator, self).__init__()
+ self.main = nn.Sequential(
+ # input is Z, going into a convolution
+ nn.ConvTranspose2d(nz, ngf * 8, 4, 1, 0, bias=False),
+ nn.BatchNorm2d(ngf * 8),
+ nn.ReLU(True),
+ # state size. (ngf*8) x 4 x 4
+ nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False),
+ nn.BatchNorm2d(ngf * 4),
+ nn.ReLU(True),
+ # state size. (ngf*4) x 8 x 8
+ nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=False),
+ nn.BatchNorm2d(ngf * 2),
+ nn.ReLU(True),
+ # state size. (ngf*2) x 16 x 16
+ nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 1, bias=False),
+ nn.BatchNorm2d(ngf),
+ nn.ReLU(True),
+ # state size. (ngf) x 32 x 32
+ nn.ConvTranspose2d(ngf, nc, 4, 2, 1, bias=False),
+ nn.Tanh()
+ # state size. (nc) x 64 x 64
+ )
- @torch.jit.script_method
- def forward(self, x):
- y = 2 * x
- return y + 1 + self.weak(y) + self.strong(y)
+ def forward(self, input):
+ return self.main(input)
- strong_mod = Strong()
- strong_mod2 = Strong()
- x = torch.ones(1)
- expected_result = (x * 2) + 1 + (55 + 199 + x * 2 + 123) + (x * 2 + 357 + x * 2 + 456)
- script_result = strong_mod(x)
- script_result2 = strong_mod2(x)
- self.assertEqual(script_result, expected_result)
- self.assertEqual(script_result, script_result2)
+ class DCGANDiscriminator(nn.Module):
+ def __init__(self, nc, ndf):
+ super(DCGANDiscriminator, self).__init__()
+ self.main = nn.Sequential(
+ # input is (nc) x 64 x 64
+ nn.Conv2d(nc, ndf, 4, 2, 1, bias=False),
+ nn.LeakyReLU(0.2, inplace=True),
+ # state size. (ndf) x 32 x 32
+ nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),
+ nn.BatchNorm2d(ndf * 2),
+ nn.LeakyReLU(0.2, inplace=True),
+ # state size. (ndf*2) x 16 x 16
+ nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),
+ nn.BatchNorm2d(ndf * 4),
+ nn.LeakyReLU(0.2, inplace=True),
+ # state size. (ndf*4) x 8 x 8
+ nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False),
+ nn.BatchNorm2d(ndf * 8),
+ nn.LeakyReLU(0.2, inplace=True),
+ # state size. (ndf*8) x 4 x 4
+ nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False),
+ nn.Sigmoid()
+ )
- def test_weak_module_parameters_and_buffers(self):
- weights = torch.randn(10, 10)
- bias = torch.randn(10)
- weights2 = torch.randn(10, 10)
- bias2 = torch.randn(10)
+ def forward(self, input):
+ return self.main(input).view(-1, 1).squeeze(1)
- @torch._jit_internal.weak_module
- class TestLinear(torch.nn.Module):
- def __init__(self, in_features, out_features):
- super(TestLinear, self).__init__()
- self.in_features = in_features
- self.out_features = out_features
- self.weight = torch.nn.Parameter(torch.Tensor(out_features, in_features))
- self.bias = torch.nn.Parameter(torch.Tensor(out_features))
- self.register_buffer('counter', torch.ones(out_features))
- self.reset_parameters()
+ bs, nz, ngf, nc, ndf = 5, 6, 9, 3, 10
+ self.checkTrace(DCGANGenerator(nz, ngf, nc).to(device),
+ (torch.rand(bs, nz, 1, 1, device=device),),
+ export_import=check_export_import)
+ example_input = DCGANGenerator(nz, ngf, nc).to(device)(torch.rand(bs, nz, 1, 1, device=device))
+ self.checkTrace(DCGANDiscriminator(nc, ndf).to(device), (example_input,),
+ export_import=check_export_import)
- def reset_parameters(self):
- torch.nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
- if self.bias is not None:
- fan_in, _ = torch.nn.init._calculate_fan_in_and_fan_out(self.weight)
- bound = 1 / math.sqrt(fan_in)
- torch.nn.init.uniform_(self.bias, -bound, bound)
+ def test_dcgan_models(self):
+ self._test_dcgan_models(self, device='cpu')
- @torch._jit_internal.weak_script_method
- def forward(self, input):
- return F.linear(input, self.weight, self.bias) + self.counter
+ @unittest.skipIf(not RUN_CUDA, "no CUDA")
+ @skipIfRocm
+ def test_dcgan_models_cuda(self):
+ # XXX: export_import on CUDA modules doesn't work (#11480)
+ self._test_dcgan_models(self, device='cuda', check_export_import=False)
- # Initialize a ScriptModule that uses the weak module above multiple times
- class Strong(torch.jit.ScriptModule):
+ @staticmethod
+ def _test_neural_style(self, device, check_export_import=True):
+ class TransformerNet(torch.nn.Module):
def __init__(self):
- super(Strong, self).__init__()
- self.fc1 = TestLinear(10, 10)
- self.fc1.weight = torch.nn.Parameter(weights)
- self.fc1.bias = torch.nn.Parameter(bias)
- self.fc2 = TestLinear(10, 10)
- self.fc2.weight = torch.nn.Parameter(weights2)
- self.fc2.bias = torch.nn.Parameter(bias2)
+ super(TransformerNet, self).__init__()
+ # Initial convolution layers
+ self.conv1 = ConvLayer(3, 32, kernel_size=9, stride=1)
+ self.in1 = torch.nn.InstanceNorm2d(32, affine=True)
+ self.conv2 = ConvLayer(32, 64, kernel_size=3, stride=2)
+ self.in2 = torch.nn.InstanceNorm2d(64, affine=True)
+ self.conv3 = ConvLayer(64, 128, kernel_size=3, stride=2)
+ self.in3 = torch.nn.InstanceNorm2d(128, affine=True)
+ # Residual layers
+ self.res1 = ResidualBlock(128)
+ self.res2 = ResidualBlock(128)
+ self.res3 = ResidualBlock(128)
+ self.res4 = ResidualBlock(128)
+ self.res5 = ResidualBlock(128)
+ # Upsampling Layers
+ self.deconv1 = UpsampleConvLayer(128, 64, kernel_size=3, stride=1, upsample=2)
+ self.in4 = torch.nn.InstanceNorm2d(64, affine=True)
+ self.deconv2 = UpsampleConvLayer(64, 32, kernel_size=3, stride=1, upsample=2)
+ self.in5 = torch.nn.InstanceNorm2d(32, affine=True)
+ self.deconv3 = ConvLayer(32, 3, kernel_size=9, stride=1)
+ # Non-linearities
+ self.relu = torch.nn.ReLU()
+
+ def forward(self, X):
+ y = self.relu(self.in1(self.conv1(X)))
+ y = self.relu(self.in2(self.conv2(y)))
+ y = self.relu(self.in3(self.conv3(y)))
+ y = self.res1(y)
+ y = self.res2(y)
+ y = self.res3(y)
+ y = self.res4(y)
+ y = self.res5(y)
+ y = self.relu(self.in4(self.deconv1(y)))
+ y = self.relu(self.in5(self.deconv2(y)))
+ y = self.deconv3(y)
+ return y
+
+ class ConvLayer(torch.nn.Module):
+ def __init__(self, in_channels, out_channels, kernel_size, stride):
+ super(ConvLayer, self).__init__()
+ reflection_padding = kernel_size // 2
+ self.reflection_pad = torch.nn.ReflectionPad2d(reflection_padding)
+ self.conv2d = torch.nn.Conv2d(in_channels, out_channels, kernel_size, stride)
- @torch.jit.script_method
def forward(self, x):
- return x + self.fc1(x) + self.fc1(x) + self.fc2(x)
+ out = self.reflection_pad(x)
+ out = self.conv2d(out)
+ return out
- strong_mod = Strong()
+ class ResidualBlock(torch.nn.Module):
+ """ResidualBlock
+ introduced in: https://arxiv.org/abs/1512.03385
+ recommended architecture: http://torch.ch/blog/2016/02/04/resnets.html
+ """
- # Run same calculation as module
- inp = torch.ones(10)
- lin = torch.nn.Linear(10, 10)
- lin.weight = torch.nn.Parameter(weights)
- lin.bias = torch.nn.Parameter(bias)
- lin2 = torch.nn.Linear(10, 10)
- lin2.weight = torch.nn.Parameter(weights2)
- lin2.bias = torch.nn.Parameter(bias2)
- expected_result = inp + (lin(inp) + torch.ones(10)) * 2 + lin2(inp) + torch.ones(10)
+ def __init__(self, channels):
+ super(ResidualBlock, self).__init__()
+ self.conv1 = ConvLayer(channels, channels, kernel_size=3, stride=1)
+ self.in1 = torch.nn.InstanceNorm2d(channels, affine=True)
+ self.conv2 = ConvLayer(channels, channels, kernel_size=3, stride=1)
+ self.in2 = torch.nn.InstanceNorm2d(channels, affine=True)
+ self.relu = torch.nn.ReLU()
- self.assertEqual(strong_mod(inp), expected_result)
- self.assertExportImportModule(strong_mod, (inp,))
+ def forward(self, x):
+ residual = x
+ out = self.relu(self.in1(self.conv1(x)))
+ out = self.in2(self.conv2(out))
+ out = out + residual
+ return out
- def test_weak_module_nested(self):
- @torch._jit_internal.weak_module
- class OtherWeak(torch.nn.Module):
- __constants__ = ['constant']
+ class UpsampleConvLayer(torch.nn.Module):
+ """UpsampleConvLayer
+ Upsamples the input and then does a convolution. This method gives better results
+ compared to ConvTranspose2d.
+ ref: http://distill.pub/2016/deconv-checkerboard/
+ """
- def __init__(self, in_features, out_features):
- super(OtherWeak, self).__init__()
- self.in_features = in_features
- self.out_features = out_features
- self.weight = torch.nn.Parameter(torch.ones(out_features, in_features))
- self.bias = torch.nn.Parameter(torch.ones(out_features))
- self.constant = 3
+ def __init__(self, in_channels, out_channels, kernel_size, stride, upsample=None):
+ super(UpsampleConvLayer, self).__init__()
+ self.upsample = upsample
+ if upsample:
+ self.upsample_layer = torch.nn.Upsample(mode='nearest', scale_factor=upsample)
+ reflection_padding = kernel_size // 2
+ self.reflection_pad = torch.nn.ReflectionPad2d(reflection_padding)
+ self.conv2d = torch.nn.Conv2d(in_channels, out_channels, kernel_size, stride)
- @torch._jit_internal.weak_script_method
def forward(self, x):
- return x * x + self.constant + F.linear(x, self.weight, self.bias)
+ x_in = x
+ if self.upsample:
+ x_in = self.upsample_layer(x_in)
+ out = self.reflection_pad(x_in)
+ out = self.conv2d(out)
+ return out
- class OtherStrong(torch.jit.ScriptModule):
+ self.checkTrace(TransformerNet(), (torch.rand(5, 3, 64, 64),), export_import=check_export_import)
- def __init__(self):
- super(OtherStrong, self).__init__()
+ def test_neural_style(self):
+ self._test_neural_style(self, device='cpu')
- @torch.jit.script_method
- def forward(self, x):
- return x + 27
+ @unittest.skipIf(not RUN_CUDA, "no CUDA")
+ def test_neural_style_cuda(self):
+ # XXX: export_import on CUDA modules doesn't work (#11480)
+ self._test_neural_style(self, device='cuda', check_export_import=False)
- @torch._jit_internal.weak_module
- class Weak(torch.nn.Module):
- def __init__(self, in_features, out_features):
- super(Weak, self).__init__()
- self.in_features = in_features
- self.out_features = out_features
- self.weight = torch.nn.Parameter(2 * torch.ones(out_features, in_features))
- self.bias = torch.nn.Parameter(2 * torch.ones(out_features))
- self.weak_submodule = OtherWeak(10, 10)
- self.strong_submodule = OtherStrong()
+ @staticmethod
+ def _test_mnist(self, device, check_export_import=True):
+ # eval() is present because dropout makes this nondeterministic
+ self.checkTrace(MnistNet().to(device).eval(), (torch.rand(5, 1, 28, 28, device=device),),
+ export_import=check_export_import)
- @torch._jit_internal.weak_script_method
- def forward(self, x):
- return x + self.weak_submodule(x) + self.strong_submodule(x) \
- + F.linear(x, self.weight, self.bias)
+ def test_mnist(self):
+ self._test_mnist(self, device='cpu')
- class Strong(torch.jit.ScriptModule):
- __constants__ = ['constant']
+ @unittest.skipIf(not RUN_CUDA, "no CUDA")
+ @skipIfRocm
+ def test_mnist_cuda(self):
+ # XXX: export_import on CUDA modules doesn't work (#11480)
+ self._test_mnist(self, device='cuda', check_export_import=False)
- def __init__(self):
- super(Strong, self).__init__()
- self.weak = Weak(10, 10)
+ @unittest.skipIf(not RUN_CUDA, "no CUDA")
+ @skipIfRocm
+ def test_mnist_training_leaks_no_memory_cuda(self):
+ net = MnistNet().cuda()
+ # MnistNet uses dropout, don't check its trace
+ traced_net = torch.jit.trace(net, [torch.randn(5, 1, 28, 28, device='cuda')],
+ check_trace=False)
- @torch.jit.script_method
- def forward(self, x):
- return x + self.weak(x)
+ def train(iters):
+ for _ in range(iters):
+ # Get some fake data
+ inp = torch.randn(5, 1, 28, 28, device='cuda')
+ out = traced_net(inp)
- strong_mod = Strong()
- inp = torch.randn(10)
- result = strong_mod(inp)
- expected_result = inp + (inp + inp * inp + inp + 27) + 3 \
- + F.linear(inp, torch.ones(10, 10), torch.ones(10)) \
- + F.linear(inp, 2 * torch.ones(10, 10), 2 * torch.ones(10))
- self.assertEqual(result, expected_result)
+ # Here's some fake loss
+ out.sum().backward()
- def test_weak_module_submodule(self):
- @torch._jit_internal.weak_module
- class Weak(torch.nn.Module):
- def __init__(self):
- super(Weak, self).__init__()
- self.param = torch.nn.Parameter(100 * torch.ones(5))
+ # Zero out grads
+ traced_net.zero_grad()
- @torch._jit_internal.weak_script_method
- def forward(self, x):
- return x + self.param
+ # Set it up so the params have .grad fields so they are not reported as leaks
+ train(1)
- weak = Weak()
+ with self.assertLeaksNoCudaTensors():
+ train(5)
- class OtherStrong(torch.jit.ScriptModule):
+ @staticmethod
+ def _test_reinforcement_learning(self, device, test_export_import=True):
+ class Policy(nn.Module):
def __init__(self):
- super(OtherStrong, self).__init__()
- self.weak = weak
- self.weak2 = weak
+ super(Policy, self).__init__()
+ self.affine1 = nn.Linear(4, 128)
+ self.affine2 = nn.Linear(128, 2)
- @torch.jit.script_method
def forward(self, x):
- return x + self.weak(x)
+ x = F.relu(self.affine1(x))
+ action_scores = self.affine2(x)
+ return F.softmax(action_scores, dim=1)
- class Strong(torch.jit.ScriptModule):
- def __init__(self):
- super(Strong, self).__init__()
- self.weak = Weak()
+ self.checkTrace(Policy().to(device), (torch.rand(1, 4, device=device),),
+ export_import=test_export_import)
- @torch.jit.script_method
- def forward(self, x):
- return self.weak(x) + weak(x)
+ def test_reinforcement_learning(self):
+ self._test_reinforcement_learning(self, device='cpu')
- other_strong_mod = OtherStrong()
+ @unittest.skipIf(not RUN_CUDA, "no CUDA")
+ def test_reinforcement_learning_cuda(self):
+ # XXX: export_import on CUDA modules doesn't work (#11480)
+ self._test_reinforcement_learning(self, device='cuda', test_export_import=False)
- self.assertIs(other_strong_mod.weak, other_strong_mod.weak2)
+ @staticmethod
+ def _test_snli(self, device, check_export_import=True):
+ class Bottle(nn.Module):
- with self.assertRaisesRegex(RuntimeError, "Attempted to inline a Module with param"):
- strong_mod = Strong()
+ def forward(self, input):
+ if len(input.size()) <= 2:
+ return super(Bottle, self).forward(input)
+ size = input.size()[:2]
+ out = super(Bottle, self).forward(input.view(size[0] * size[1], -1))
+ return out.view(size[0], size[1], -1)
- def test_weak_module_copying(self):
- class Submodule(torch.nn.Module):
- def __init__(self):
- super(Submodule, self).__init__()
+ class Linear(Bottle, nn.Linear):
+ pass
- def forward(self, x):
- return x + 100
+ class Encoder(nn.Module):
- @torch._jit_internal.weak_module
- class Weak(torch.nn.Module):
- def __init__(self, in_features, out_features):
- super(Weak, self).__init__()
- self.weight = torch.nn.Parameter(torch.ones(out_features, in_features))
- self.bias = torch.nn.Parameter(torch.ones(out_features))
- self.register_buffer("buffer", torch.ones(out_features))
- self.submodule = Submodule()
+ def __init__(self, config):
+ super(Encoder, self).__init__()
+ self.config = config
+ input_size = config.d_proj if config.projection else config.d_embed
+ dropout = 0 if config.n_layers == 1 else config.dp_ratio
+ self.rnn = nn.LSTM(input_size=input_size, hidden_size=config.d_hidden,
+ num_layers=config.n_layers, dropout=dropout,
+ bidirectional=config.birnn)
- @torch._jit_internal.weak_script_method
- def forward(self, x):
- return F.linear(x, self.weight, self.bias) \
- + self.buffer + self.submodule(x)
+ def forward(self, inputs):
+ batch_size = inputs.size()[1]
+ state_shape = self.config.n_cells, batch_size, self.config.d_hidden
+ h0 = c0 = inputs.new_zeros(state_shape)
+ outputs, (ht, ct) = self.rnn(inputs, (h0, c0))
+ return ht[-1] if not self.config.birnn else ht[-2:].transpose(0, 1).contiguous().view(batch_size, -1)
- class Strong(torch.jit.ScriptModule):
- def __init__(self, weak):
- super(Strong, self).__init__()
- self.weak = weak
+ class SNLIClassifier(nn.Module):
- @torch.jit.script_method
- def forward(self, x):
- return self.weak(x)
+ def __init__(self, config):
+ super(SNLIClassifier, self).__init__()
+ self.config = config
+ self.embed = nn.Embedding(config.n_embed, config.d_embed)
+ self.projection = Linear(config.d_embed, config.d_proj)
+ self.encoder = Encoder(config)
+ self.dropout = nn.Dropout(p=config.dp_ratio)
+ self.relu = nn.ReLU()
+ seq_in_size = 2 * config.d_hidden
+ if self.config.birnn:
+ seq_in_size *= 2
+ lin_config = [seq_in_size] * 2
+ self.out = nn.Sequential(
+ Linear(*lin_config),
+ self.relu,
+ self.dropout,
+ Linear(*lin_config),
+ self.relu,
+ self.dropout,
+ Linear(*lin_config),
+ self.relu,
+ self.dropout,
+ Linear(seq_in_size, config.d_out))
- inp = torch.ones(5, 5) * 5
- weak_mod = Weak(5, 5)
- strong_mod = Strong(weak_mod)
+ def forward(self, premise, hypothesis):
+ prem_embed = self.embed(premise)
+ hypo_embed = self.embed(hypothesis)
+ if self.config.fix_emb:
+ prem_embed = prem_embed.detach()
+ hypo_embed = hypo_embed.detach()
+ if self.config.projection:
+ prem_embed = self.relu(self.projection(prem_embed))
+ hypo_embed = self.relu(self.projection(hypo_embed))
+ premise = self.encoder(prem_embed)
+ hypothesis = self.encoder(hypo_embed)
+ scores = self.out(torch.cat([premise, hypothesis], 1))
+ return scores
- self.assertTrue(isinstance(strong_mod.weak, torch.jit.ScriptModule))
- self.assertFalse(isinstance(weak_mod, torch.jit.ScriptModule))
+ class Config:
+ n_embed = 100
+ d_embed = 100
+ d_proj = 300
+ dp_ratio = 0.0 # For deterministic testing TODO: change by fixing seed in checkTrace?
+ d_hidden = 300
+ birnn = True
+ d_out = 300
+ fix_emb = True
+ projection = True
+ n_layers = 2
+ n_cells = 4 # 2 * n_layers because birnn = True
- self.assertIs(strong_mod.weak.weight, weak_mod.weight)
- self.assertIs(strong_mod.weak.buffer, weak_mod.buffer)
- self.assertIs(strong_mod.weak.submodule, weak_mod.submodule)
+ premise = torch.LongTensor(48, 128).random_(0, 100).to(device)
+ hypothesis = torch.LongTensor(24, 128).random_(0, 100).to(device)
- # Test lookup fallback
- weak_mod.new_attribute = 10
- self.assertIs(strong_mod.weak.new_attribute, weak_mod.new_attribute)
+ self.checkTrace(SNLIClassifier(Config()).to(device), (premise, hypothesis),
+ inputs_require_grads=False, export_import=check_export_import)
- weak_mod.weight.data += torch.ones(5, 5) * 100
- self.assertTrue(strong_mod(inp).allclose(weak_mod(inp)))
+ @skipIfRocm
+ def test_snli(self):
+ self._test_snli(self, device='cpu')
- # Re-assignment is not tracked
- weak_mod.weight = torch.nn.Parameter(torch.ones(5, 5) * 100)
- self.assertFalse(strong_mod(inp).allclose(weak_mod(inp)))
+ @skipIfRocm
+ @unittest.skipIf(not RUN_CUDA, "no CUDA")
+ def test_snli_cuda(self):
+ # XXX: export_import on CUDA modules doesn't work (#11480)
+ self._test_snli(self, device='cuda', check_export_import=False)
- def test_backend_cudnn_enabled(self):
- # Only test that this compiles
- @torch.jit.script
- def fn(x):
- if torch.backends.cudnn.enabled:
- x = x + 2
- else:
- x = x + 3
- return x
+ @staticmethod
+ def _test_super_resolution(self, device, check_export_import=True):
+ import torch.nn.init as init
- def test_inplace_add(self):
+ class Net(nn.Module):
- def foo(a, b):
- c = a + b
- c.add_(b)
- return c
- self.checkScript(foo, (torch.rand(3), torch.rand(3)))
+ def __init__(self, upscale_factor):
+ super(Net, self).__init__()
- def test_add_out(self):
- def foo(a, b):
- c = a + b
- e = 2 * a
- torch.add(c, b, out=e)
- return e
- self.checkScript(foo, (torch.rand(3), torch.rand(3)))
+ self.relu = nn.ReLU()
+ self.conv1 = nn.Conv2d(1, 64, (5, 5), (1, 1), (2, 2))
+ self.conv2 = nn.Conv2d(64, 64, (3, 3), (1, 1), (1, 1))
+ self.conv3 = nn.Conv2d(64, 32, (3, 3), (1, 1), (1, 1))
+ self.conv4 = nn.Conv2d(32, upscale_factor ** 2, (3, 3), (1, 1), (1, 1))
+ self.pixel_shuffle = nn.PixelShuffle(upscale_factor)
- def test_augmented_assign(self):
- def foo(a, b):
- a += b
- a -= b
- a /= b
- a *= b
- return a, b
- self.checkScript(foo, (torch.rand(3), torch.rand(3)), check_expected=True)
+ def forward(self, x):
+ x = self.relu(self.conv1(x))
+ x = self.relu(self.conv2(x))
+ x = self.relu(self.conv3(x))
+ x = self.pixel_shuffle(self.conv4(x))
+ return x
- def test_pass(self):
- def foo(x):
- # type: (bool) -> int
- for _i in range(3):
- pass
- if x:
- pass
- else:
- pass
- return 3
+ net = Net(upscale_factor=4).to(device)
+ self.checkTrace(net, (torch.rand(5, 1, 64, 64, device=device),),
+ export_import=check_export_import)
- self.checkScript(foo, (True,))
+ @skipIfRocm
+ def test_super_resolution(self):
+ self._test_super_resolution(self, device='cpu')
- def test_optional_conversion(self):
- @torch.jit.script
- def other_fn(x=None):
- # type: (Optional[int]) -> int
- return torch.jit._unwrap_optional(x)
+ @skipIfRocm
+ @unittest.skipIf(not RUN_CUDA, 'no CUDA')
+ def test_super_resolution_cuda(self):
+ # XXX: export_import on CUDA modules doesn't work (#11480)
+ self._test_super_resolution(self, device='cuda', check_export_import=False)
- @torch.jit.script
- def fn(x):
- # type: (int) -> int
- return other_fn(x)
+ @suppress_warnings
+ def test_time_sequence_prediction(self):
+ class Sequence(torch.jit.ScriptModule):
+ def __init__(self):
+ super(Sequence, self).__init__()
+ self.lstm1 = nn.LSTMCell(1, 51)
+ self.lstm2 = nn.LSTMCell(51, 51)
+ self.linear = nn.Linear(51, 1)
- self.assertEqual(fn(2), 2)
+ # TODO: could not pass tuple to a python Op and type annotations
+ # is not descending to python signature, hence the wrapper
+ # see https://github.com/pytorch/pytorch/issues/8778
+ # and https://github.com/pytorch/pytorch/issues/8777
+ def test_lstm1(self, input, hx, cx):
+ # type: (Tensor, Tensor, Tensor) -> Tuple[Tensor, Tensor]
+ return self.lstm1(input, (hx, cx))
- @torch.jit.script
- def unify_to_optional(x):
- # type: (bool) -> Optional[int]
- if x:
- a = None
- else:
- a = 2
- return a
+ def test_lstm2(self, input, hx, cx):
+ # type: (Tensor, Tensor, Tensor) -> Tuple[Tensor, Tensor]
+ return self.lstm2(input, (hx, cx))
- self.assertEqual(unify_to_optional(True), None)
- self.assertEqual(unify_to_optional(False), 2)
+ # TODO: could not support tensor constructors in script
+ # see https://github.com/pytorch/pytorch/issues/8814
+ def test_tensor(self):
+ return torch.tensor([], dtype=torch.double)
- @torch.jit.script
- def opt_list(x):
- # type: (Optional[List[float]]) -> int
- return 2
+ @torch.jit.script_method
+ def forward(self, input):
+ # TODO: add future as input with default val
+ # see https://github.com/pytorch/pytorch/issues/8724
+ outputs = self.test_tensor()
+ h_t = torch.zeros((3, 51), dtype=torch.double)
+ c_t = torch.zeros((3, 51), dtype=torch.double)
+ h_t2 = torch.zeros((3, 51), dtype=torch.double)
+ c_t2 = torch.zeros((3, 51), dtype=torch.double)
- @torch.jit.script
- def broadcast_opt_list(x):
- # type: (Optional[BroadcastingList2[float]]) -> int
- return 2
+ output = torch.zeros([3, 51])
+ future = 2
- @torch.jit.script
- def opt_list_tuple_caller(x):
- # type: (Tuple[float, float]) -> int
- return opt_list(x) + broadcast_opt_list(x)
+ # TODO: chunk call should appear as the for loop iterable
+ # We hard-code it to 4 for now.
+ a, b, c, d = input.chunk(input.size(1), dim=1)
+ for input_t in (a, b, c, d):
+ h_t, c_t = self.test_lstm1(input_t, h_t, c_t)
+ h_t2, c_t2 = self.test_lstm2(h_t, h_t2, c_t2)
+ output = self.linear(h_t2)
+ outputs = torch.cat((outputs, output), 1)
+ for _ in range(future): # if we should predict the future
+ h_t, c_t = self.test_lstm1(output, h_t, c_t)
+ h_t2, c_t2 = self.test_lstm2(h_t, h_t2, c_t2)
+ output = self.linear(h_t2)
+ outputs = torch.cat((outputs, output), 1)
+ return outputs
- self.assertEqual(opt_list_tuple_caller((2., 3.)), 4)
+ # TODO: toggle export_import once above issues are fixed
+ self.checkTrace(Sequence(), (torch.rand(3, 4),),
+ export_import=False)
- def test_lhs_indexing(self):
- def foo(a, b):
- a = a.clone()
- a[0] = b
- return a
- self.checkScript(foo, (torch.rand(2, 3), torch.rand(3)))
+ @staticmethod
+ def _test_vae(self, device, check_export_import=True):
+ class VAE(nn.Module):
+ def __init__(self):
+ super(VAE, self).__init__()
- def test_lhs_advanced_indexing_assignment(self):
- def foo(x, y):
- a = torch.exp(x)
- b = x == 1
- a[b] = y[b]
- return a
- self.checkScript(foo, (torch.ones(4, 3), torch.ones(4, 3)))
+ self.fc1 = nn.Linear(784, 400)
+ self.fc21 = nn.Linear(400, 20)
+ self.fc22 = nn.Linear(400, 20)
+ self.fc3 = nn.Linear(20, 400)
+ self.fc4 = nn.Linear(400, 784)
- def test_lhs_advanced_indexing_augmented_assignment(self):
- def foo(x, y):
- a = torch.exp(x)
- b = x == 1
- a[b] += y[b]
- return a
- self.checkScript(foo, (torch.ones(4, 3), torch.ones(4, 3)))
+ def encode(self, x):
+ h1 = F.relu(self.fc1(x))
+ return self.fc21(h1), self.fc22(h1)
- def test_lhs_indexing_list(self):
- def foo(a, b):
- ls = [a]
- ls[0] = b
- return ls
- self.checkScript(foo, (torch.rand(2, 3), torch.rand(3)))
+ def reparameterize(self, mu, logvar):
+ if self.training:
+ std = torch.exp(0.5 * logvar)
+ eps = torch.randn_like(std)
+ return eps.mul(std).add_(mu)
+ else:
+ return mu
- def test_inplace_copy_script(self):
- def foo(x):
- a = torch.rand(3, 4)
- a.copy_(x)
- return a
- self.checkScript(foo, (torch.rand(3, 4),))
+ def decode(self, z):
+ h3 = F.relu(self.fc3(z))
+ return torch.sigmoid(self.fc4(h3))
- def test_lhs_indexing_increment(self):
- def foo(a, b):
- a[0] += b
- return a
- self.checkScript(foo, (torch.rand(2, 3), torch.rand(3)))
+ def forward(self, x):
+ mu, logvar = self.encode(x.view(-1, 784))
+ z = self.reparameterize(mu, logvar)
+ return self.decode(z), mu, logvar
- def test_lhs_indexing_increment_list(self):
- def foo(a, b):
- a = a.clone()
- ls = [a, b]
- ls[0] += b
- return ls
- self.checkScript(foo, (torch.rand(2, 3), torch.rand(3)))
+ # eval() is present because randn_like makes this nondeterministic
+ self.checkTrace(VAE().to(device).eval(), (torch.rand(128, 1, 28, 28, device=device),),
+ export_import=check_export_import)
- def test_lhs_indexing_increment_list_prim(self):
- def foo():
- ls = [1, 2, 3]
- ls[0] += 5
- return ls
- self.checkScript(foo, ())
+ def test_vae(self):
+ self._test_vae(self, device='cpu')
- def test_lhs_indexing_multi(self):
- def foo(a, b):
- a = a.clone()
- foo, a[0], bar = (1, b, 3)
- return foo, a, bar
- self.checkScript(foo, (torch.rand(2, 3), torch.rand(3)))
+ @unittest.skipIf(not RUN_CUDA, "no CUDA")
+ def test_vae_cuda(self):
+ # XXX: export_import on CUDA modules doesn't work (#11480)
+ self._test_vae(self, device='cuda', check_export_import=False)
- def test_bool_dispatch(self):
- with self.disableModuleHook(): # TODO: Python print broadcasting list
- def kwarg_false(x):
- # type: (Tensor) -> Tensor
- return F.max_pool1d(x, 1, 1, return_indices=False)
- self.checkScript(kwarg_false, (torch.randn(3, 3, 3),))
- def kwarg_true(x):
- # type: (Tensor) -> Tuple[Tensor, Tensor]
- return F.max_pool1d(x, 1, 1, return_indices=True)
- self.checkScript(kwarg_true, (torch.randn(3, 3, 3),))
+# Smoke tests for export methods
+class TestPytorchExportModes(JitTestCase):
+ class MyModel(nn.Module):
+ def __init__(self):
+ super(TestPytorchExportModes.MyModel, self).__init__()
+
+ def forward(self, x):
+ return x.transpose(0, 1)
- def full_kwarg_false(x):
- # type: (Tensor) -> Tensor
- return F.max_pool1d(x, 1, 1, ceil_mode=False, return_indices=False)
- self.checkScript(full_kwarg_false, (torch.randn(3, 3, 3),))
+ def test_protobuf(self):
+ torch_model = TestPytorchExportModes.MyModel()
+ fake_input = Variable(torch.randn(1, 1, 224, 224), requires_grad=True)
+ f = io.BytesIO()
+ torch.onnx._export(torch_model, (fake_input), f, verbose=False,
+ export_type=torch.onnx.ExportTypes.PROTOBUF_FILE)
- def full_kwarg_true(x):
- # type: (Tensor) -> Tuple[Tensor, Tensor]
- return F.max_pool1d(x, 1, 1, ceil_mode=False, return_indices=True)
- self.checkScript(full_kwarg_true, (torch.randn(3, 3, 3),))
+ def test_zipfile(self):
+ torch_model = TestPytorchExportModes.MyModel()
+ fake_input = Variable(torch.randn(1, 1, 224, 224), requires_grad=True)
+ f = io.BytesIO()
+ torch.onnx._export(torch_model, (fake_input), f, verbose=False,
+ export_type=torch.onnx.ExportTypes.ZIP_ARCHIVE)
- def use_default(x):
- # type: (Tensor) -> Tensor
- return F.max_pool1d(x, 1, 1)
- self.checkScript(use_default, (torch.randn(3, 3, 3),))
+ def test_compressed_zipfile(self):
+ torch_model = TestPytorchExportModes.MyModel()
+ fake_input = Variable(torch.randn(1, 1, 224, 224), requires_grad=True)
+ f = io.BytesIO()
+ torch.onnx._export(torch_model, (fake_input), f, verbose=False,
+ export_type=torch.onnx.ExportTypes.COMPRESSED_ZIP_ARCHIVE)
- def arg_false(x):
- # type: (Tensor) -> Tensor
- return F.max_pool1d(x, 1, 1, 0, 1, False, False)
- self.checkScript(arg_false, (torch.randn(3, 3, 3),))
+ def test_directory(self):
+ torch_model = TestPytorchExportModes.MyModel()
+ fake_input = Variable(torch.randn(1, 1, 224, 224), requires_grad=True)
+ d = tempfile.mkdtemp()
+ torch.onnx._export(torch_model, (fake_input), d, verbose=False,
+ export_type=torch.onnx.ExportTypes.DIRECTORY)
+ shutil.rmtree(d)
- def arg_true(x):
- # type: (Tensor) -> Tuple[Tensor, Tensor]
- return F.max_pool1d(x, 1, 1, 0, 1, False, True)
- self.checkScript(arg_true, (torch.randn(3, 3, 3),))
+ @skipIfRocm
+ @skipIfNoLapack
+ def test_aten_fallback(self):
+ class ModelWithAtenNotONNXOp(nn.Module):
+ def forward(self, x, y):
+ abcd = x + y
+ defg = torch.qr(abcd)
+ return defg
- def test_infer_size(self):
- from torch._C import _infer_size
+ x = torch.rand(3, 4)
+ y = torch.rand(3, 4)
+ f = io.BytesIO()
+ exported = torch.onnx.export_to_pretty_string(
+ ModelWithAtenNotONNXOp(), (x, y), f,
+ operator_export_type=OperatorExportTypes.ONNX_ATEN_FALLBACK)
+ self.assertExpected(exported)
- def fn(x, y):
- # type: (Tensor, Tensor) -> List[int]
- return _infer_size(x.size(), y.size())
+ # torch.fmod is using to test ONNX_ATEN.
+ # If you plan to remove fmod from aten, or found this test failed.
+ # please contact @Rui.
+ @skipIfRocm
+ def test_onnx_aten(self):
+ class ModelWithAtenFmod(nn.Module):
+ def forward(self, x, y):
+ return torch.fmod(x, y)
- self.checkScript(fn, (torch.ones(2, 4, 2), torch.ones(2, 4, 2)))
+ f = io.BytesIO()
+ x = torch.randn(3, 4, dtype=torch.float32)
+ y = torch.randn(3, 4, dtype=torch.float32)
+ exported = torch.onnx.export_to_pretty_string(
+ ModelWithAtenFmod(), (x, y), f,
+ operator_export_type=OperatorExportTypes.ONNX_ATEN)
+ self.assertExpected(exported)
- def test_mutable_dce(self):
- @torch.jit.script
- def foo():
- a = torch.rand(2, 3)
- a += torch.rand(2, 3)
- b = torch.rand(2, 3)
- b += torch.rand(2, 3)
- # b should be cleaned up but not a
- return a
- self.assertExpectedGraph(foo.graph)
+# known to be failing in tracer
+EXCLUDE_TRACED = {
+ 'test_split_dim',
+ 'test_split_dim_neg0',
- def test_mutable_dce_block(self):
- @torch.jit.script
- def foo():
- a = torch.rand(2, 3)
- a += torch.rand(2, 3)
- b = torch.rand(2, 3)
- if bool(a > torch.zeros(2, 3)):
- b += torch.rand(2, 3)
- a += torch.rand(2, 3)
- # a should be cleaned up but not b
- return b
+ # The following fail due to #12024.
+ # A prim::ListConstruct is involved and the indices get traced as DynamicType,
+ # which always require_grad. This causes a crash in autodiff.
+ 'test___getitem___adv_index',
+ 'test___getitem___adv_index_beg',
+ 'test___getitem___adv_index_comb',
+ 'test___getitem___adv_index_dup',
+ 'test___getitem___adv_index_sub',
+ 'test___getitem___adv_index_sub_2',
+ 'test___getitem___adv_index_sub_3',
+ 'test___getitem___adv_index_var',
+}
- self.assertExpectedGraph(foo.graph)
+EXCLUDE_TYPE_CHECK = {
+ # slogdet tests use itemgetter to select its only differentiable output,
+ # but this happens outside of the graph we handle, so there are fewer
+ # reference outputs than graph outputs.
+ 'test_slogdet_1x1_neg_det',
+ 'test_slogdet_1x1_pos_det',
+ 'test_slogdet_distinct_singular_values',
+ 'test_slogdet_neg_det',
+ 'test_slogdet_pos_det',
+ 'test_slogdet_symmetric',
+ 'test_slogdet_symmetric_pd',
+}
- def test_mutable_dce_graph_input(self):
- @torch.jit.script
- def foo(a):
- a += torch.rand(2, 3)
- # shouldn't clean up `a` even though it's not used in the output
+# known to be failing in script
+EXCLUDE_SCRIPT = {
+ 'test_norm_fro',
+ 'test_norm_fro_default',
+ 'test_norm_nuc',
- self.assertExpectedGraph(foo.graph)
+ # aten op has additional cudnn argument
+ 'test_nn_unfold',
- def test_mutable_dce_list(self):
- @torch.jit.script
- def foo(a):
- l = []
- l.append(a)
- c = l[0]
- b = torch.rand(2, 3)
- c += torch.rand(2, 3)
- return b
+ # flaky test - TODO fix
+ 'test_nn_ctc_loss',
- self.assertExpectedGraph(foo.graph)
+ # unknown builtin op
+ 'test_nn_fold',
+}
- def test_mutable_dce_loop(self):
- @torch.jit.script
- def foo(a):
- l = []
- l.append(a)
- i = 0
- b = torch.rand(2, 3)
- while i < 1:
- dead = torch.rand(2, 3)
- c = l[0]
- c += torch.rand(2, 3)
- i += 1
- return b
+EXCLUDE_PYTHON_PRINT = {
+ # no support for BroadcastingList in python printer
+ 'test_nn_max_unpool1d',
+ 'test_nn_max_unpool2d',
+ 'test_nn_max_unpool3d',
+ 'test_nn_max_pool1d',
+ 'test_nn_max_pool2d',
+ 'test_nn_max_pool3d',
+ 'test_nn_max_pool1d_with_indices',
+}
- self.assertExpectedGraph(foo.graph)
+EXCLUDE_SCRIPT_MODULES = {
+ 'test_nn_AdaptiveAvgPool2d_tuple_none',
+ 'test_nn_AdaptiveAvgPool3d_tuple_none',
+ 'test_nn_AdaptiveMaxPool2d_tuple_none',
+ 'test_nn_AdaptiveMaxPool3d_tuple_none',
+}
+DISABLE_AUTODIFF_SUBGRAPH_INLINING = {
+ 'test_nn_avg_pool2d',
+ 'test_nn_log_softmax',
+ 'test_nn_threshold',
+ 'test_nn_nll_loss',
+}
-class MnistNet(nn.Module):
- def __init__(self):
- super(MnistNet, self).__init__()
- self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
- self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
- self.conv2_drop = nn.Dropout2d()
- self.fc1 = nn.Linear(320, 50)
- self.fc2 = nn.Linear(50, 10)
- def forward(self, x):
- x = F.relu(F.max_pool2d(self.conv1(x), 2))
- x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
- x = x.view(-1, 320)
- x = F.relu(self.fc1(x))
- x = F.dropout(x, training=self.training)
- x = self.fc2(x)
- return F.log_softmax(x, dim=1)
+# make a new function where all non-tensor arguments in 'args' have been partially
+# applied, and all tensor arguments remain.
+# used to trace functions when some arguments are not tensors
+def partial_apply_nontensors(fn, args, **kwargs):
+ source = ['t' if isinstance(arg, torch.Tensor) else 's' for arg in args]
+
+ def new_fn(*tensors_):
+ tensors = iter(tensors_)
+ return fn(*(args[i] if s == 's' else next(tensors) for i, s in enumerate(source)), **kwargs)
+ return new_fn, [arg for arg in args if isinstance(arg, torch.Tensor)]
-class TestEndToEndHybridFrontendModels(JitTestCase):
- @staticmethod
- def _test_dcgan_models(self, device, check_export_import=True):
- class DCGANGenerator(nn.Module):
- def __init__(self, nz, ngf, nc):
- super(DCGANGenerator, self).__init__()
- self.main = nn.Sequential(
- # input is Z, going into a convolution
- nn.ConvTranspose2d(nz, ngf * 8, 4, 1, 0, bias=False),
- nn.BatchNorm2d(ngf * 8),
- nn.ReLU(True),
- # state size. (ngf*8) x 4 x 4
- nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False),
- nn.BatchNorm2d(ngf * 4),
- nn.ReLU(True),
- # state size. (ngf*4) x 8 x 8
- nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=False),
- nn.BatchNorm2d(ngf * 2),
- nn.ReLU(True),
- # state size. (ngf*2) x 16 x 16
- nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 1, bias=False),
- nn.BatchNorm2d(ngf),
- nn.ReLU(True),
- # state size. (ngf) x 32 x 32
- nn.ConvTranspose2d(ngf, nc, 4, 2, 1, bias=False),
- nn.Tanh()
- # state size. (nc) x 64 x 64
- )
- def forward(self, input):
- return self.main(input)
+# create a trace function from input fn
+#
+# disable_autodiff_subgraph_inlining:
+# Don't inline autodiff subgraphs so we can test autodiff
+def create_traced_fn(self, fn,
+ disable_autodiff_subgraph_inlining=False):
+ def traced_fn(*inputs, **kwargs):
+ fn_tensors, inputs_tensors = partial_apply_nontensors(fn, inputs, **kwargs)
+ traced = torch.jit.trace(fn_tensors, inputs_tensors)
+ self.assertExportImport(traced.graph, inputs_tensors)
+ if disable_autodiff_subgraph_inlining:
+ traced.debug_disable_autodiff_subgraph_inlining()
+ output = traced(*inputs_tensors)
+ traced_fn.last_graph = traced.graph_for(*inputs_tensors)
+ return output
+ return traced_fn
- class DCGANDiscriminator(nn.Module):
- def __init__(self, nc, ndf):
- super(DCGANDiscriminator, self).__init__()
- self.main = nn.Sequential(
- # input is (nc) x 64 x 64
- nn.Conv2d(nc, ndf, 4, 2, 1, bias=False),
- nn.LeakyReLU(0.2, inplace=True),
- # state size. (ndf) x 32 x 32
- nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),
- nn.BatchNorm2d(ndf * 2),
- nn.LeakyReLU(0.2, inplace=True),
- # state size. (ndf*2) x 16 x 16
- nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),
- nn.BatchNorm2d(ndf * 4),
- nn.LeakyReLU(0.2, inplace=True),
- # state size. (ndf*4) x 8 x 8
- nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False),
- nn.BatchNorm2d(ndf * 8),
- nn.LeakyReLU(0.2, inplace=True),
- # state size. (ndf*8) x 4 x 4
- nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False),
- nn.Sigmoid()
- )
+script_template = '''
+def the_method({}):
+ return {}
+'''
- def forward(self, input):
- return self.main(input).view(-1, 1).squeeze(1)
+script_method_template = '''
+def forward({}):
+ return {}
+'''
- bs, nz, ngf, nc, ndf = 5, 6, 9, 3, 10
- self.checkTrace(DCGANGenerator(nz, ngf, nc).to(device),
- (torch.rand(bs, nz, 1, 1, device=device),),
- export_import=check_export_import)
- example_input = DCGANGenerator(nz, ngf, nc).to(device)(torch.rand(bs, nz, 1, 1, device=device))
- self.checkTrace(DCGANDiscriminator(nc, ndf).to(device), (example_input,),
- export_import=check_export_import)
- def test_dcgan_models(self):
- self._test_dcgan_models(self, device='cpu')
+def get_constant(x):
+ if x == inf:
+ return 'float(\'inf\')' if PY2 else 'math.inf'
+ if x == -inf:
+ return 'float(\'-inf\')' if PY2 else '-math.inf'
+ return x
- @unittest.skipIf(not RUN_CUDA, "no CUDA")
- @skipIfRocm
- def test_dcgan_models_cuda(self):
- # XXX: export_import on CUDA modules doesn't work (#11480)
- self._test_dcgan_models(self, device='cuda', check_export_import=False)
- @staticmethod
- def _test_neural_style(self, device, check_export_import=True):
- class TransformerNet(torch.nn.Module):
- def __init__(self):
- super(TransformerNet, self).__init__()
- # Initial convolution layers
- self.conv1 = ConvLayer(3, 32, kernel_size=9, stride=1)
- self.in1 = torch.nn.InstanceNorm2d(32, affine=True)
- self.conv2 = ConvLayer(32, 64, kernel_size=3, stride=2)
- self.in2 = torch.nn.InstanceNorm2d(64, affine=True)
- self.conv3 = ConvLayer(64, 128, kernel_size=3, stride=2)
- self.in3 = torch.nn.InstanceNorm2d(128, affine=True)
- # Residual layers
- self.res1 = ResidualBlock(128)
- self.res2 = ResidualBlock(128)
- self.res3 = ResidualBlock(128)
- self.res4 = ResidualBlock(128)
- self.res5 = ResidualBlock(128)
- # Upsampling Layers
- self.deconv1 = UpsampleConvLayer(128, 64, kernel_size=3, stride=1, upsample=2)
- self.in4 = torch.nn.InstanceNorm2d(64, affine=True)
- self.deconv2 = UpsampleConvLayer(64, 32, kernel_size=3, stride=1, upsample=2)
- self.in5 = torch.nn.InstanceNorm2d(32, affine=True)
- self.deconv3 = ConvLayer(32, 3, kernel_size=9, stride=1)
- # Non-linearities
- self.relu = torch.nn.ReLU()
+def get_script_args(args):
+ formals = []
+ tensors = []
+ actuals = []
+ for arg in args:
+ if isinstance(arg, torch.Tensor):
+ name = 'i{}'.format(len(formals))
+ formals.append(name)
+ actuals.append(name)
+ tensors.append(arg)
+ elif isinstance(arg, str):
+ actuals.append("'{}'".format(arg))
+ else:
+ actuals.append(str(get_constant(arg)))
+ return (formals, tensors, actuals)
- def forward(self, X):
- y = self.relu(self.in1(self.conv1(X)))
- y = self.relu(self.in2(self.conv2(y)))
- y = self.relu(self.in3(self.conv3(y)))
- y = self.res1(y)
- y = self.res2(y)
- y = self.res3(y)
- y = self.res4(y)
- y = self.res5(y)
- y = self.relu(self.in4(self.deconv1(y)))
- y = self.relu(self.in5(self.deconv2(y)))
- y = self.deconv3(y)
- return y
- class ConvLayer(torch.nn.Module):
- def __init__(self, in_channels, out_channels, kernel_size, stride):
- super(ConvLayer, self).__init__()
- reflection_padding = kernel_size // 2
- self.reflection_pad = torch.nn.ReflectionPad2d(reflection_padding)
- self.conv2d = torch.nn.Conv2d(in_channels, out_channels, kernel_size, stride)
+# create a script function from (name, func_type, output_process_fn),
+# returns a function takes in (args, kwargs) and runs the compiled function and
+# then applies the post process fn to the outputs
+def create_script_fn(self, method_name, func_type, output_process_fn,
+ disable_autodiff_subgraph_inlining=False):
+ def script_fn(*args, **kwargs):
+ formals, tensors, actuals = get_script_args(args)
+ kwargs_str = ''
+ for k, v in kwargs.items():
+ kwargs_str += ', ' + k + '=' + str(v)
+ if func_type == 'functional':
+ call = 'torch.{}({}{})'.format(method_name, ', '.join(actuals), kwargs_str)
+ elif func_type == 'method':
+ call = '{}.{}({}{})'.format(actuals[0], method_name, ', '.join(actuals[1:]), kwargs_str)
+ elif func_type == 'nn_functional':
+ call = 'torch.nn.functional.{}({}{})'.format(method_name, ', '.join(actuals), kwargs_str)
+ else:
+ raise 'Unsupported function type'
- def forward(self, x):
- out = self.reflection_pad(x)
- out = self.conv2d(out)
- return out
+ script = script_template.format(', '.join(formals), call)
- class ResidualBlock(torch.nn.Module):
- """ResidualBlock
- introduced in: https://arxiv.org/abs/1512.03385
- recommended architecture: http://torch.ch/blog/2016/02/04/resnets.html
- """
+ CU = torch.jit.CompilationUnit(script)
+ if disable_autodiff_subgraph_inlining:
+ CU.the_method.debug_disable_autodiff_subgraph_inlining()
+ self.assertExportImport(CU.the_method.graph, tensors)
+ output = output_process_fn(CU.the_method(*tensors))
+ script_fn.last_graph = CU.the_method.graph_for(*tensors)
+ return output
+ return script_fn
- def __init__(self, channels):
- super(ResidualBlock, self).__init__()
- self.conv1 = ConvLayer(channels, channels, kernel_size=3, stride=1)
- self.in1 = torch.nn.InstanceNorm2d(channels, affine=True)
- self.conv2 = ConvLayer(channels, channels, kernel_size=3, stride=1)
- self.in2 = torch.nn.InstanceNorm2d(channels, affine=True)
- self.relu = torch.nn.ReLU()
- def forward(self, x):
- residual = x
- out = self.relu(self.in1(self.conv1(x)))
- out = self.in2(self.conv2(out))
- out = out + residual
- return out
+def check_alias_annotation(method_name, args, kwargs):
+ formals, tensors, actuals = get_script_args(args)
+ kwargs_str = ''
+ for k, v in kwargs.items():
+ kwargs_str += ', ' + k + '=' + str(v)
+ call = '{}.{}({}{})'.format(actuals[0], method_name, ', '.join(actuals[1:]), kwargs_str)
+ script = script_template.format(', '.join(formals), call)
+ CU = torch.jit.CompilationUnit(script)
+ torch._C._jit_check_alias_annotation(CU.the_method.graph, tuple(tensors), method_name)
- class UpsampleConvLayer(torch.nn.Module):
- """UpsampleConvLayer
- Upsamples the input and then does a convolution. This method gives better results
- compared to ConvTranspose2d.
- ref: http://distill.pub/2016/deconv-checkerboard/
- """
- def __init__(self, in_channels, out_channels, kernel_size, stride, upsample=None):
- super(UpsampleConvLayer, self).__init__()
- self.upsample = upsample
- if upsample:
- self.upsample_layer = torch.nn.Upsample(mode='nearest', scale_factor=upsample)
- reflection_padding = kernel_size // 2
- self.reflection_pad = torch.nn.ReflectionPad2d(reflection_padding)
- self.conv2d = torch.nn.Conv2d(in_channels, out_channels, kernel_size, stride)
+def check_output_types(self, func, ref_outputs, args, kwargs):
+ graph = getattr(func, 'last_graph', None)
+ if not isinstance(ref_outputs, tuple):
+ ref_outputs = (ref_outputs,)
+ types = [o.type() for o in graph.outputs()]
+ self.assertEqual(len(types), len(ref_outputs))
+ for i, (t, ref_out) in enumerate(zip(types, ref_outputs)):
+ if isinstance(ref_out, list):
+ assert len(ref_out) > 0
+ elem = ref_out[0]
+ assert isinstance(elem, torch.Tensor)
+ self.assertTrue(t.isSubtypeOf(torch._C.ListType.ofTensors()))
+ else:
+ ref_type = torch._C.Type.inferFrom(ref_out)
+ self.assertTrue(ref_type.isSubtypeOf(t))
- def forward(self, x):
- x_in = x
- if self.upsample:
- x_in = self.upsample_layer(x_in)
- out = self.reflection_pad(x_in)
- out = self.conv2d(out)
- return out
- self.checkTrace(TransformerNet(), (torch.rand(5, 3, 64, 64),), export_import=check_export_import)
+def check_against_reference(self, func, reference_func, args, kwargs=None,
+ allow_unused=True, check_types=True, no_grad=False):
+ kwargs = kwargs if kwargs else {}
- def test_neural_style(self):
- self._test_neural_style(self, device='cpu')
+ def allSum(vs):
+ if isinstance(vs, torch.Tensor):
+ vs = (vs,)
+ return sum([(i + 1) * v.sum()
+ for i, v in enumerate(vs)
+ if v is not None and v.dtype.is_floating_point])
- @unittest.skipIf(not RUN_CUDA, "no CUDA")
- def test_neural_style_cuda(self):
- # XXX: export_import on CUDA modules doesn't work (#11480)
- self._test_neural_style(self, device='cuda', check_export_import=False)
+ def clone_inputs(requires_grad):
+ inputs = [
+ arg.detach().clone().requires_grad_(requires_grad and arg.requires_grad)
+ if isinstance(arg, torch.Tensor) else arg for arg in args
+ ]
+ return inputs, [input for input in inputs if isinstance(input, torch.Tensor) and input.requires_grad]
+
+ nograd_inputs, nograd_tensors = clone_inputs(False)
+ recording_inputs, recording_tensors = clone_inputs(True)
+
+ # test no gradients case
+ outputs = self.runAndSaveRNG(reference_func, nograd_inputs, kwargs)
+ outputs_test = self.runAndSaveRNG(func, nograd_inputs, kwargs)
+ self.assertEqual(outputs, outputs_test)
+
+ if check_types:
+ check_output_types(self, func, outputs_test, nograd_inputs, kwargs)
- @staticmethod
- def _test_mnist(self, device, check_export_import=True):
- # eval() is present because dropout makes this nondeterministic
- self.checkTrace(MnistNet().to(device).eval(), (torch.rand(5, 1, 28, 28, device=device),),
- export_import=check_export_import)
+ if no_grad:
+ # skip grad tests
+ return
- def test_mnist(self):
- self._test_mnist(self, device='cpu')
+ # test single grad case
+ outputs = self.runAndSaveRNG(reference_func, recording_inputs, kwargs)
+ grads = torch.autograd.grad(allSum(outputs), recording_tensors,
+ allow_unused=allow_unused)
- @unittest.skipIf(not RUN_CUDA, "no CUDA")
- @skipIfRocm
- def test_mnist_cuda(self):
- # XXX: export_import on CUDA modules doesn't work (#11480)
- self._test_mnist(self, device='cuda', check_export_import=False)
+ outputs_test = self.runAndSaveRNG(func, recording_inputs, kwargs)
+ grads_test = torch.autograd.grad(allSum(outputs_test), recording_tensors,
+ allow_unused=allow_unused)
+ self.assertEqual(outputs, outputs_test)
+ self.assertEqual(grads, grads_test)
- @unittest.skipIf(not RUN_CUDA, "no CUDA")
- @skipIfRocm
- def test_mnist_training_leaks_no_memory_cuda(self):
- net = MnistNet().cuda()
- # MnistNet uses dropout, don't check its trace
- traced_net = torch.jit.trace(net, [torch.randn(5, 1, 28, 28, device='cuda')],
- check_trace=False)
+ # test the grad grad case
+ if self._testMethodName in nn_functional_single_grad:
+ return
- def train(iters):
- for _ in range(iters):
- # Get some fake data
- inp = torch.randn(5, 1, 28, 28, device='cuda')
- out = traced_net(inp)
+ outputs = self.runAndSaveRNG(reference_func, recording_inputs, kwargs)
+ l1 = allSum(outputs)
+ grads = torch.autograd.grad(l1, recording_tensors, create_graph=True,
+ allow_unused=allow_unused)
+ l2 = (allSum(grads) * l1)
+ grads2 = torch.autograd.grad(l2, recording_tensors, allow_unused=allow_unused)
- # Here's some fake loss
- out.sum().backward()
+ recording_inputs, recording_tensors = clone_inputs(True)
- # Zero out grads
- traced_net.zero_grad()
+ outputs_test = self.runAndSaveRNG(func, recording_inputs, kwargs)
+ l1_test = allSum(outputs_test)
+ grads_test = torch.autograd.grad(
+ l1_test, recording_tensors, create_graph=True, allow_unused=allow_unused)
+ l2_test = (allSum(grads_test) * l1_test)
+ grads2_test = torch.autograd.grad(l2_test, recording_tensors, allow_unused=allow_unused)
- # Set it up so the params have .grad fields so they are not reported as leaks
- train(1)
+ self.assertEqual(outputs, outputs_test)
+ self.assertEqual(grads, grads_test)
+ for g2, g2_test in zip(grads2, grads2_test):
+ if g2 is None and g2_test is None:
+ continue
+ self.assertTrue(torch.allclose(g2, g2_test, atol=5e-4, rtol=1e-4))
- with self.assertLeaksNoCudaTensors():
- train(5)
- @staticmethod
- def _test_reinforcement_learning(self, device, test_export_import=True):
- class Policy(nn.Module):
- def __init__(self):
- super(Policy, self).__init__()
- self.affine1 = nn.Linear(4, 128)
- self.affine2 = nn.Linear(128, 2)
+class TestFuser(JitTestCase):
+ def assertAllFused(self, graph, except_for=()):
+ if [n.kind() for n in graph.nodes()] == ['prim::DifferentiableGraph']:
+ graph = next(graph.nodes()).g('Subgraph')
+ allowed_nodes = {'prim::Constant', 'prim::FusionGroup'} | set(except_for)
+ self.assertTrue(all(node.kind() in allowed_nodes for node in graph.nodes()),
+ 'got {}'.format(graph))
+ self.assertTrue([node.kind() for node in graph.nodes()].count('prim::FusionGroup') == 1)
- def forward(self, x):
- x = F.relu(self.affine1(x))
- action_scores = self.affine2(x)
- return F.softmax(action_scores, dim=1)
+ def _test_fused_abs(self, device='cpu'):
- self.checkTrace(Policy().to(device), (torch.rand(1, 4, device=device),),
- export_import=test_export_import)
+ @torch.jit.script
+ def func(x):
+ return x.abs() * 2
- def test_reinforcement_learning(self):
- self._test_reinforcement_learning(self, device='cpu')
+ a = torch.randn(5, device=device)
+ self.assertEqual(func(a), a.abs() * 2)
+ self.assertAllFused(func.graph_for(a))
- @unittest.skipIf(not RUN_CUDA, "no CUDA")
- def test_reinforcement_learning_cuda(self):
- # XXX: export_import on CUDA modules doesn't work (#11480)
- self._test_reinforcement_learning(self, device='cuda', test_export_import=False)
+ @unittest.skipIf(IS_WINDOWS or IS_SANDCASTLE, "NYI: fuser support for Windows or Sandcastle")
+ @enable_cpu_fuser
+ def test_abs_cpu(self):
+ self._test_fused_abs()
- @staticmethod
- def _test_snli(self, device, check_export_import=True):
- class Bottle(nn.Module):
+ @unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
+ @unittest.skipIf(not RUN_CUDA, "requires CUDA")
+ @skipIfRocm
+ def test_abs_cuda(self):
+ self._test_fused_abs(device="cuda")
- def forward(self, input):
- if len(input.size()) <= 2:
- return super(Bottle, self).forward(input)
- size = input.size()[:2]
- out = super(Bottle, self).forward(input.view(size[0] * size[1], -1))
- return out.view(size[0], size[1], -1)
+ @unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
+ @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
+ @skipIfRocm
+ def test_arg_configurations_smoke_cuda(self):
+ # A smoke test to make sure we won't use the same kernel for contiguous
+ # and non-contiguous arguments.
+ # TODO: add optionally enabled debug counters to the fuser to verify
+ # that we really can tell the difference between configurations
+ def f(x, y):
+ z1, z2 = (x + y).chunk(2, dim=1)
+ return z1 * z2
- class Linear(Bottle, nn.Linear):
- pass
+ x = torch.randn(4, 4, dtype=torch.float, device='cuda')
+ y = torch.randn(4, 4, dtype=torch.float, device='cuda')
+ traced_f = torch.jit.trace(f, (x, y,))
+ self.assertEqual(traced_f(x.t().contiguous(), y), traced_f(x.t(), y))
- class Encoder(nn.Module):
+ @unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
+ @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
+ @skipIfRocm
+ def test_broadcast_cuda(self):
+ def scaleshift(x, scale, shift):
+ return x * scale + shift
- def __init__(self, config):
- super(Encoder, self).__init__()
- self.config = config
- input_size = config.d_proj if config.projection else config.d_embed
- dropout = 0 if config.n_layers == 1 else config.dp_ratio
- self.rnn = nn.LSTM(input_size=input_size, hidden_size=config.d_hidden,
- num_layers=config.n_layers, dropout=dropout,
- bidirectional=config.birnn)
+ inputs = [
+ torch.randn(4, 4, dtype=torch.float, device='cuda'),
+ torch.randn(4, dtype=torch.float, device='cuda'),
+ torch.randn(4, dtype=torch.float, device='cuda'),
+ ]
+ ge = self.checkTrace(scaleshift, inputs)
+ self.assertExpectedGraph(ge.graph_for(*inputs))
- def forward(self, inputs):
- batch_size = inputs.size()[1]
- state_shape = self.config.n_cells, batch_size, self.config.d_hidden
- h0 = c0 = inputs.new_zeros(state_shape)
- outputs, (ht, ct) = self.rnn(inputs, (h0, c0))
- return ht[-1] if not self.config.birnn else ht[-2:].transpose(0, 1).contiguous().view(batch_size, -1)
+ @unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
+ @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
+ @unittest.skipIf(not RUN_CUDA_HALF, "no half support")
+ def test_cuda_half(self):
+ x = torch.randn(4, 4, dtype=torch.half, device='cuda')
+ y = torch.randn(4, 4, dtype=torch.half, device='cuda')
- class SNLIClassifier(nn.Module):
+ funcs = [
+ self.fn_test_comparison_gt_lt,
+ self.fn_test_relu,
+ self.fn_test_exp
+ ]
- def __init__(self, config):
- super(SNLIClassifier, self).__init__()
- self.config = config
- self.embed = nn.Embedding(config.n_embed, config.d_embed)
- self.projection = Linear(config.d_embed, config.d_proj)
- self.encoder = Encoder(config)
- self.dropout = nn.Dropout(p=config.dp_ratio)
- self.relu = nn.ReLU()
- seq_in_size = 2 * config.d_hidden
- if self.config.birnn:
- seq_in_size *= 2
- lin_config = [seq_in_size] * 2
- self.out = nn.Sequential(
- Linear(*lin_config),
- self.relu,
- self.dropout,
- Linear(*lin_config),
- self.relu,
- self.dropout,
- Linear(*lin_config),
- self.relu,
- self.dropout,
- Linear(seq_in_size, config.d_out))
+ # Note: Non fused inputs must be float to prevent loss of precision
+ inputs = (x.float(), y.float())
+ fusion_inputs = (x, y)
+ for fn in funcs:
+ local_inputs = [t.clone().requires_grad_() for t in inputs]
+ local_fusion_inputs = [t.clone().requires_grad_() for t in fusion_inputs]
- def forward(self, premise, hypothesis):
- prem_embed = self.embed(premise)
- hypo_embed = self.embed(hypothesis)
- if self.config.fix_emb:
- prem_embed = prem_embed.detach()
- hypo_embed = hypo_embed.detach()
- if self.config.projection:
- prem_embed = self.relu(self.projection(prem_embed))
- hypo_embed = self.relu(self.projection(hypo_embed))
- premise = self.encoder(prem_embed)
- hypothesis = self.encoder(hypo_embed)
- scores = self.out(torch.cat([premise, hypothesis], 1))
- return scores
+ # Verifies outputs
+ fusion = torch.jit.trace(fn, local_fusion_inputs, check_trace=False, optimize=True)
+ outputs = fn(*local_inputs)
+ fusion_outputs = fusion(*local_fusion_inputs)
+ outputs_half = [t.half() for t in outputs]
+ self.assertEqual(outputs_half, fusion_outputs)
- class Config:
- n_embed = 100
- d_embed = 100
- d_proj = 300
- dp_ratio = 0.0 # For deterministic testing TODO: change by fixing seed in checkTrace?
- d_hidden = 300
- birnn = True
- d_out = 300
- fix_emb = True
- projection = True
- n_layers = 2
- n_cells = 4 # 2 * n_layers because birnn = True
+ # Verifies gradients
+ for output, fusion_output in zip(outputs_half, fusion_outputs):
+ grads = torch.autograd.grad(
+ output.float().sum(), local_inputs, allow_unused=True, retain_graph=True)
+ fusion_grads = torch.autograd.grad(
+ fusion_output.sum(), local_fusion_inputs, allow_unused=True, retain_graph=True)
+ grads_half = [t.half() for t in grads]
+ self.assertEqual(grads_half, fusion_grads)
- premise = torch.LongTensor(48, 128).random_(0, 100).to(device)
- hypothesis = torch.LongTensor(24, 128).random_(0, 100).to(device)
+ @unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
+ @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
+ @skipIfRocm
+ def test_checks_cat_inputs(self):
+ # We shouldn't treat cat nodes as broadcasting. All their inputs
+ # need to be checked for having the same map size, before we can
+ # run the kernel.
+ @torch.jit.script
+ def f(x, y):
+ return torch.cat([x + 2 * x + x ** 2, y + 4 * y + y ** 3], dim=0)
+
+ # NOTE: y is broadcastable to x, but output of f(x, y) should have
+ # shape 3x4, and not 4x4.
+ x = torch.randn(2, 4, dtype=torch.float, device='cuda')
+ y = torch.randn(1, 4, dtype=torch.float, device='cuda')
- self.checkTrace(SNLIClassifier(Config()).to(device), (premise, hypothesis),
- inputs_require_grads=False, export_import=check_export_import)
+ self.assertEqual(f(x, y).shape, (3, 4))
+ self.assertAllFused(f.graph_for(x, y))
+ @unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
+ @unittest.skipIf(not RUN_CUDA, "No CUDA")
@skipIfRocm
- def test_snli(self):
- self._test_snli(self, device='cpu')
+ def test_chunk_cuda(self):
+ def fn(x):
+ a, b, c = x.chunk(3, 1)
+ return a * b + c
- @skipIfRocm
- @unittest.skipIf(not RUN_CUDA, "no CUDA")
- def test_snli_cuda(self):
- # XXX: export_import on CUDA modules doesn't work (#11480)
- self._test_snli(self, device='cuda', check_export_import=False)
+ inputs = [torch.randn(10, 6, dtype=torch.float, device='cuda')]
+
+ ge = self.checkScript(fn, inputs)
+ self.assertExpectedGraph(ge.graph_for(*inputs))
@staticmethod
- def _test_super_resolution(self, device, check_export_import=True):
- import torch.nn.init as init
+ def _test_chunk_correctness(self, device='cpu'):
+ def chunk_4_0(x):
+ x0, x1, x2, x3 = x.chunk(4, 0)
+ return x0 + x1 + x2 + x3
- class Net(nn.Module):
+ def chunk_4_1(x):
+ x0, x1, x2, x3 = x.chunk(4, 1)
+ return x0 + x1 + x2 + x3
- def __init__(self, upscale_factor):
- super(Net, self).__init__()
+ def chunk_4_last(x):
+ x0, x1, x2, x3 = x.chunk(4, 2)
+ return x0 + x1 + x2 + x3
- self.relu = nn.ReLU()
- self.conv1 = nn.Conv2d(1, 64, (5, 5), (1, 1), (2, 2))
- self.conv2 = nn.Conv2d(64, 64, (3, 3), (1, 1), (1, 1))
- self.conv3 = nn.Conv2d(64, 32, (3, 3), (1, 1), (1, 1))
- self.conv4 = nn.Conv2d(32, upscale_factor ** 2, (3, 3), (1, 1), (1, 1))
- self.pixel_shuffle = nn.PixelShuffle(upscale_factor)
+ fns = [chunk_4_0, chunk_4_1, chunk_4_last]
+ tensors = [
+ # splitSize = 1
+ torch.randn(4, 4, 4, dtype=torch.float, device=device),
- def forward(self, x):
- x = self.relu(self.conv1(x))
- x = self.relu(self.conv2(x))
- x = self.relu(self.conv3(x))
- x = self.pixel_shuffle(self.conv4(x))
- return x
+ # contiguous case
+ torch.randn(12, 8, 16, dtype=torch.float, device=device),
- net = Net(upscale_factor=4).to(device)
- self.checkTrace(net, (torch.rand(5, 1, 64, 64, device=device),),
- export_import=check_export_import)
+ # non-contiguous case
+ torch.randn(12, 8, 16, dtype=torch.float, device=device).transpose(1, 2),
+ ]
- @skipIfRocm
- def test_super_resolution(self):
- self._test_super_resolution(self, device='cpu')
+ for tensor in tensors:
+ for fn in fns:
+ self.checkScript(fn, [tensor])
+ @unittest.skipIf(IS_WINDOWS or IS_SANDCASTLE, "NYI: fuser support for Windows or Sandcastle")
@skipIfRocm
- @unittest.skipIf(not RUN_CUDA, 'no CUDA')
- def test_super_resolution_cuda(self):
- # XXX: export_import on CUDA modules doesn't work (#11480)
- self._test_super_resolution(self, device='cuda', check_export_import=False)
-
- @suppress_warnings
- def test_time_sequence_prediction(self):
- class Sequence(torch.jit.ScriptModule):
- def __init__(self):
- super(Sequence, self).__init__()
- self.lstm1 = nn.LSTMCell(1, 51)
- self.lstm2 = nn.LSTMCell(51, 51)
- self.linear = nn.Linear(51, 1)
-
- # TODO: could not pass tuple to a python Op and type annotations
- # is not descending to python signature, hence the wrapper
- # see https://github.com/pytorch/pytorch/issues/8778
- # and https://github.com/pytorch/pytorch/issues/8777
- def test_lstm1(self, input, hx, cx):
- # type: (Tensor, Tensor, Tensor) -> Tuple[Tensor, Tensor]
- return self.lstm1(input, (hx, cx))
-
- def test_lstm2(self, input, hx, cx):
- # type: (Tensor, Tensor, Tensor) -> Tuple[Tensor, Tensor]
- return self.lstm2(input, (hx, cx))
-
- # TODO: could not support tensor constructors in script
- # see https://github.com/pytorch/pytorch/issues/8814
- def test_tensor(self):
- return torch.tensor([], dtype=torch.double)
-
- @torch.jit.script_method
- def forward(self, input):
- # TODO: add future as input with default val
- # see https://github.com/pytorch/pytorch/issues/8724
- outputs = self.test_tensor()
- h_t = torch.zeros((3, 51), dtype=torch.double)
- c_t = torch.zeros((3, 51), dtype=torch.double)
- h_t2 = torch.zeros((3, 51), dtype=torch.double)
- c_t2 = torch.zeros((3, 51), dtype=torch.double)
+ @enable_cpu_fuser
+ def test_chunk_correctness(self):
+ return self._test_chunk_correctness(self, 'cpu')
- output = torch.zeros([3, 51])
- future = 2
+ @unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
+ @unittest.skipIf(not RUN_CUDA, "No CUDA")
+ @skipIfRocm
+ def test_chunk_correctness_cuda(self):
+ return self._test_chunk_correctness(self, 'cuda')
- # TODO: chunk call should appear as the for loop iterable
- # We hard-code it to 4 for now.
- a, b, c, d = input.chunk(input.size(1), dim=1)
- for input_t in (a, b, c, d):
- h_t, c_t = self.test_lstm1(input_t, h_t, c_t)
- h_t2, c_t2 = self.test_lstm2(h_t, h_t2, c_t2)
- output = self.linear(h_t2)
- outputs = torch.cat((outputs, output), 1)
- for _ in range(future): # if we should predict the future
- h_t, c_t = self.test_lstm1(output, h_t, c_t)
- h_t2, c_t2 = self.test_lstm2(h_t, h_t2, c_t2)
- output = self.linear(h_t2)
- outputs = torch.cat((outputs, output), 1)
- return outputs
+ @unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
+ @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
+ @skipIfRocm
+ def test_chunk_distributes_cuda(self):
+ def f(x, y):
+ z1, z2 = (x + y).chunk(2, dim=1)
+ return z1 * z2
- # TODO: toggle export_import once above issues are fixed
- self.checkTrace(Sequence(), (torch.rand(3, 4),),
- export_import=False)
+ x = torch.randn(4, 4, dtype=torch.float, device='cuda')
+ y = torch.randn(4, 4, dtype=torch.float, device='cuda')
- @staticmethod
- def _test_vae(self, device, check_export_import=True):
- class VAE(nn.Module):
- def __init__(self):
- super(VAE, self).__init__()
+ ge = self.checkTrace(f, (x, y))
+ self.assertExpectedGraph(ge.graph_for(x, y))
- self.fc1 = nn.Linear(784, 400)
- self.fc21 = nn.Linear(400, 20)
- self.fc22 = nn.Linear(400, 20)
- self.fc3 = nn.Linear(20, 400)
- self.fc4 = nn.Linear(400, 784)
+ @unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
+ @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
+ @skipIfRocm
+ def test_chunk_motion_deduplicates_inputs(self):
+ def func1(x):
+ z = x * x
+ z0, z1 = z.chunk(2)
+ return z0 * z1
- def encode(self, x):
- h1 = F.relu(self.fc1(x))
- return self.fc21(h1), self.fc22(h1)
+ def func2(x):
+ z = x * x * x
+ z0, z1 = z.chunk(2)
+ return z0 * z1
- def reparameterize(self, mu, logvar):
- if self.training:
- std = torch.exp(0.5 * logvar)
- eps = torch.randn_like(std)
- return eps.mul(std).add_(mu)
- else:
- return mu
+ inputs = [
+ torch.tensor([1.1, 1.2], device='cuda', dtype=torch.float),
+ ]
+ for func in [func1, func2]:
+ module = self.checkScript(func, inputs)
+ forward_graph = module.graph_for(*inputs)
+ self.assertGraphContainsExactly(forward_graph, 'prim::FusionGroup', 1)
+ fusion_group = list(forward_graph.nodes())[-1]
+ self.assertEqual(len(list(fusion_group.inputs())), 1)
- def decode(self, z):
- h3 = F.relu(self.fc3(z))
- return torch.sigmoid(self.fc4(h3))
+ @unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
+ @unittest.skipIf(not RUN_CUDA, "No CUDA")
+ @skipIfRocm
+ def test_chunk_multiple_cuda(self):
+ # The arguments are intentionally used out of order as a test to see
+ # if the fusion compiler adds extra args in the correct order
+ def fn(s, x, y, z):
+ z1, z2 = z.chunk(2, 2)
+ x1, x2, x3 = x.chunk(3, 1)
+ y1, y2 = y.chunk(2, 0)
+ return s + x1 + x2 + x3 + y1 + y2 + z1 + z2
- def forward(self, x):
- mu, logvar = self.encode(x.view(-1, 784))
- z = self.reparameterize(mu, logvar)
- return self.decode(z), mu, logvar
+ inputs = [
+ torch.randn(5, 2, 3, dtype=torch.float, device='cuda'),
+ torch.randn(5, 6, 3, dtype=torch.float, device='cuda'),
+ torch.randn(10, 2, 3, dtype=torch.float, device='cuda'),
+ torch.randn(5, 2, 6, dtype=torch.float, device='cuda'),
+ ]
- # eval() is present because randn_like makes this nondeterministic
- self.checkTrace(VAE().to(device).eval(), (torch.rand(128, 1, 28, 28, device=device),),
- export_import=check_export_import)
+ ge = self.checkScript(fn, inputs)
+ self.assertExpectedGraph(ge.graph_for(*inputs))
- def test_vae(self):
- self._test_vae(self, device='cpu')
+ @unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
+ @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
+ @skipIfRocm
+ def test_clamp(self):
+ def func2(a, b):
+ return torch.clamp(a + b, min=0, max=2)
- @unittest.skipIf(not RUN_CUDA, "no CUDA")
- def test_vae_cuda(self):
- # XXX: export_import on CUDA modules doesn't work (#11480)
- self._test_vae(self, device='cuda', check_export_import=False)
+ def funcInf(a, b):
+ return torch.clamp(a + b, min=0, max=float('inf'))
+ a = torch.randn(4, 4, dtype=torch.float, device='cuda', requires_grad=True)
+ b = torch.randn(4, 4, dtype=torch.float, device='cuda')
-# Smoke tests for export methods
-class TestPytorchExportModes(JitTestCase):
- class MyModel(nn.Module):
- def __init__(self):
- super(TestPytorchExportModes.MyModel, self).__init__()
+ funcs = (func2, funcInf)
+ for f in funcs:
+ s = self.checkScript(f, (a, b))
+ self.assertAllFused(s.graph_for(a, b), except_for={'aten::size'})
- def forward(self, x):
- return x.transpose(0, 1)
+ c = s(a, b)
+ c.sum().backward()
+ graph = backward_graph(s)
+ self.assertAllFused(graph, except_for={'prim::SumToSize'})
- def test_protobuf(self):
- torch_model = TestPytorchExportModes.MyModel()
- fake_input = Variable(torch.randn(1, 1, 224, 224), requires_grad=True)
- f = io.BytesIO()
- torch.onnx._export(torch_model, (fake_input), f, verbose=False,
- export_type=torch.onnx.ExportTypes.PROTOBUF_FILE)
+ @unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
+ @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
+ @skipIfRocm
+ def test_comparison_eq_ne(self):
+ def f(x, y):
+ mask = (x == 0).type_as(x)
+ z = x * mask + y
+ mask = (x != 0).type_as(x)
+ z = z * mask + y
+ return z
- def test_zipfile(self):
- torch_model = TestPytorchExportModes.MyModel()
- fake_input = Variable(torch.randn(1, 1, 224, 224), requires_grad=True)
- f = io.BytesIO()
- torch.onnx._export(torch_model, (fake_input), f, verbose=False,
- export_type=torch.onnx.ExportTypes.ZIP_ARCHIVE)
+ x = torch.randn(4, 4, dtype=torch.float, device='cuda')
+ y = torch.randn(4, 4, dtype=torch.float, device='cuda')
- def test_compressed_zipfile(self):
- torch_model = TestPytorchExportModes.MyModel()
- fake_input = Variable(torch.randn(1, 1, 224, 224), requires_grad=True)
- f = io.BytesIO()
- torch.onnx._export(torch_model, (fake_input), f, verbose=False,
- export_type=torch.onnx.ExportTypes.COMPRESSED_ZIP_ARCHIVE)
+ ge = self.checkTrace(f, (x, y))
+ self.assertAllFused(ge.graph_for(x, y))
- def test_directory(self):
- torch_model = TestPytorchExportModes.MyModel()
- fake_input = Variable(torch.randn(1, 1, 224, 224), requires_grad=True)
- d = tempfile.mkdtemp()
- torch.onnx._export(torch_model, (fake_input), d, verbose=False,
- export_type=torch.onnx.ExportTypes.DIRECTORY)
- shutil.rmtree(d)
+ @staticmethod
+ def fn_test_comparison_gt_lt(x, y):
+ mask = (x > 0).type_as(x)
+ z = x * mask + y
+ mask = (x < 0).type_as(x)
+ z = z * mask + y
+ return z
+ @unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
+ @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
@skipIfRocm
- @skipIfNoLapack
- def test_aten_fallback(self):
- class ModelWithAtenNotONNXOp(nn.Module):
- def forward(self, x, y):
- abcd = x + y
- defg = torch.qr(abcd)
- return defg
+ def test_comparison_gt_lt_cuda(self):
+ x = torch.randn(4, 4, dtype=torch.float, device='cuda')
+ y = torch.randn(4, 4, dtype=torch.float, device='cuda')
- x = torch.rand(3, 4)
- y = torch.rand(3, 4)
- f = io.BytesIO()
- exported = torch.onnx.export_to_pretty_string(
- ModelWithAtenNotONNXOp(), (x, y), f,
- operator_export_type=OperatorExportTypes.ONNX_ATEN_FALLBACK)
- self.assertExpected(exported)
+ ge = self.checkTrace(self.fn_test_comparison_gt_lt, (x, y))
+ self.assertAllFused(ge.graph_for(x, y))
- # torch.fmod is using to test ONNX_ATEN.
- # If you plan to remove fmod from aten, or found this test failed.
- # please contact @Rui.
+ @unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
+ @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
@skipIfRocm
- def test_onnx_aten(self):
- class ModelWithAtenFmod(nn.Module):
- def forward(self, x, y):
- return torch.fmod(x, y)
+ def test_comparison_ge_le_cuda(self):
+ def f(x, y):
+ mask = (x >= 0).type_as(x)
+ z = x * mask + y
+ mask = (x <= 0).type_as(x)
+ z = z * mask + y
+ return z
- f = io.BytesIO()
- x = torch.randn(3, 4, dtype=torch.float32)
- y = torch.randn(3, 4, dtype=torch.float32)
- exported = torch.onnx.export_to_pretty_string(
- ModelWithAtenFmod(), (x, y), f,
- operator_export_type=OperatorExportTypes.ONNX_ATEN)
- self.assertExpected(exported)
+ x = torch.randn(4, 4, dtype=torch.float, device='cuda')
+ y = torch.randn(4, 4, dtype=torch.float, device='cuda')
+ ge = self.checkTrace(f, (x, y))
+ self.assertAllFused(ge.graph_for(x, y))
-# known to be failing in tracer
-EXCLUDE_TRACED = {
- 'test_split_dim',
- 'test_split_dim_neg0',
+ @unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
+ @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
+ @skipIfRocm
+ def test_concat_cuda(self):
+ hx = torch.randn(3, 20, dtype=torch.float, device='cuda')
+ cx = torch.randn(3, 20, dtype=torch.float, device='cuda')
- # The following fail due to #12024.
- # A prim::ListConstruct is involved and the indices get traced as DynamicType,
- # which always require_grad. This causes a crash in autodiff.
- 'test___getitem___adv_index',
- 'test___getitem___adv_index_beg',
- 'test___getitem___adv_index_comb',
- 'test___getitem___adv_index_dup',
- 'test___getitem___adv_index_sub',
- 'test___getitem___adv_index_sub_2',
- 'test___getitem___adv_index_sub_3',
- 'test___getitem___adv_index_var',
-}
+ def foo(hx, cx):
+ return torch.cat((hx + cx, hx * cx))
-EXCLUDE_TYPE_CHECK = {
- # slogdet tests use itemgetter to select its only differentiable output,
- # but this happens outside of the graph we handle, so there are fewer
- # reference outputs than graph outputs.
- 'test_slogdet_1x1_neg_det',
- 'test_slogdet_1x1_pos_det',
- 'test_slogdet_distinct_singular_values',
- 'test_slogdet_neg_det',
- 'test_slogdet_pos_det',
- 'test_slogdet_symmetric',
- 'test_slogdet_symmetric_pd',
-}
+ ge = self.checkTrace(foo, (hx, cx))
+ self.assertExpectedGraph(ge.graph_for(hx, cx))
-# known to be failing in script
-EXCLUDE_SCRIPT = {
- 'test_norm_fro',
- 'test_norm_fro_default',
- 'test_norm_nuc',
+ @unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
+ @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
+ @skipIfRocm
+ def test_concat_invariant_cuda(self):
+ # Invariant: the output of prim::FusedConcat may
+ # not be an input to any node inside the FusionGroup.
+ def fn(x, y, z):
+ x1 = x + y
+ y1 = x - y
+ w = torch.cat([x1, y1])
+ return w + z
- # aten op has additional cudnn argument
- 'test_nn_unfold',
+ x = torch.randn(2, 2, dtype=torch.float, device='cuda')
+ y = torch.randn(2, 2, dtype=torch.float, device='cuda')
+ z = torch.randn(4, 2, dtype=torch.float, device='cuda')
+ ge = self.checkTrace(fn, (x, y, z))
+ self.assertExpectedGraph(ge.graph_for(x, y, z))
- # flaky test - TODO fix
- 'test_nn_ctc_loss',
+ @staticmethod
+ def fn_test_exp(x, y):
+ return (x + .5 * y).exp()
- # unknown builtin op
- 'test_nn_fold',
-}
+ @unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
+ @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
+ @skipIfRocm
+ def test_exp_cuda(self):
+ x = torch.randn(4, 4, dtype=torch.float, device='cuda')
+ y = torch.randn(4, 4, dtype=torch.float, device='cuda')
-EXCLUDE_PYTHON_PRINT = {
- # no support for BroadcastingList in python printer
- 'test_nn_max_unpool1d',
- 'test_nn_max_unpool2d',
- 'test_nn_max_unpool3d',
- 'test_nn_max_pool1d',
- 'test_nn_max_pool2d',
- 'test_nn_max_pool3d',
- 'test_nn_max_pool1d_with_indices',
-}
+ ge = self.checkTrace(self.fn_test_exp, (x, y))
+ self.assertAllFused(ge.graph_for(x, y))
-EXCLUDE_SCRIPT_MODULES = {
- 'test_nn_AdaptiveAvgPool2d_tuple_none',
- 'test_nn_AdaptiveAvgPool3d_tuple_none',
- 'test_nn_AdaptiveMaxPool2d_tuple_none',
- 'test_nn_AdaptiveMaxPool3d_tuple_none',
-}
+ # TODO: This test doesn't offer anything valuable, maybe we should delete it
+ @unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
+ @unittest.skipIf(not RUN_CUDA_MULTI_GPU, "needs non-zero device")
+ @skipIfRocm
+ def test_last_device_cuda(self):
+ device = 'cuda:' + str(1)
+ x = torch.tensor([0.4], dtype=torch.float, device=device)
+ y = torch.tensor([0.7], dtype=torch.float, device=device)
-DISABLE_AUTODIFF_SUBGRAPH_INLINING = {
- 'test_nn_avg_pool2d',
- 'test_nn_log_softmax',
- 'test_nn_threshold',
- 'test_nn_nll_loss',
-}
+ def doit(x, y):
+ return torch.sigmoid(torch.tanh(x * (x + y) + x))
+ ge = self.checkTrace(doit, (x, y))
+ self.assertExpectedGraph(ge.graph_for(x, y))
-# make a new function where all non-tensor arguments in 'args' have been partially
-# applied, and all tensor arguments remain.
-# used to trace functions when some arguments are not tensors
-def partial_apply_nontensors(fn, args, **kwargs):
- source = ['t' if isinstance(arg, torch.Tensor) else 's' for arg in args]
+ @unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
+ @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
+ @skipIfRocm
+ def test_lstm_cuda(self):
+ inputs = get_lstm_inputs('cuda', training=True)
+ module = self.checkScript(LSTMCellS, inputs)
+ forward_graph = module.graph_for(*inputs)
+ self.assertGraphContainsExactly(
+ forward_graph, 'prim::FusionGroup', 1, consider_subgraphs=True)
+ self.assertExpectedGraph(forward_graph, subname='forward')
- def new_fn(*tensors_):
- tensors = iter(tensors_)
- return fn(*(args[i] if s == 's' else next(tensors) for i, s in enumerate(source)), **kwargs)
+ hy, cy = module(*inputs)
+ (hy + cy).sum().backward()
+ self.assertExpectedGraph(backward_graph(module), subname='backward')
- return new_fn, [arg for arg in args if isinstance(arg, torch.Tensor)]
+ @unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
+ @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
+ @skipIfRocm
+ def test_lstm_concat_cuda(self):
+ inputs = get_lstm_inputs('cuda')
+ ge = self.checkTrace(LSTMCellC, inputs)
+ self.assertExpectedGraph(ge.graph_for(*inputs))
+ @unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
+ @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
+ @skipIfRocm
+ def test_lstm_gates_permutations_cuda(self):
+ # lstm has gates = x.mm(w_ih.t()) + hx.mm(w_hh.t()) + b_ih + b_hh.
+ # Test that any permutation of this will still result in one FusionGroup.
+ choices = ['x.mm(w_ih.t())', 'hx.mm(w_hh.t())', 'b_ih', 'b_hh']
+ template = dedent('''
+ def cell(x, hx, cx, w_ih, w_hh, b_ih, b_hh):
+ gates = {} + {} + {} + {}
+ ingate, forgetgate, cellgate, outgate = gates.chunk(4, 1)
+ return ingate * forgetgate * cellgate * outgate
+ ''')
+ for permutation in itertools.permutations(choices, len(choices)):
+ code = template.format(*permutation)
+ scope = {}
+ exec(code, globals(), scope)
+ cu = torch.jit.CompilationUnit(code)
-# create a trace function from input fn
-#
-# disable_autodiff_subgraph_inlining:
-# Don't inline autodiff subgraphs so we can test autodiff
-def create_traced_fn(self, fn,
- disable_autodiff_subgraph_inlining=False):
- def traced_fn(*inputs, **kwargs):
- fn_tensors, inputs_tensors = partial_apply_nontensors(fn, inputs, **kwargs)
- traced = torch.jit.trace(fn_tensors, inputs_tensors)
- self.assertExportImport(traced.graph, inputs_tensors)
- if disable_autodiff_subgraph_inlining:
- traced.debug_disable_autodiff_subgraph_inlining()
- output = traced(*inputs_tensors)
- traced_fn.last_graph = traced.graph_for(*inputs_tensors)
- return output
- return traced_fn
+ inputs = get_lstm_inputs('cuda', training=False)
+ self.assertEqual(cu.cell(*inputs), scope['cell'](*inputs))
+ forward_graph = cu.cell.graph_for(*inputs)
+ self.assertGraphContainsExactly(forward_graph, 'prim::FusionGroup', 1)
-script_template = '''
-def the_method({}):
- return {}
-'''
+ # TODO: Fuser doesn't work at all when inputs require grad. Fix that
+ @unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
+ @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
+ @skipIfRocm
+ def test_lstm_traced_cuda(self):
+ inputs = get_lstm_inputs('cuda')
+ ge = self.checkTrace(LSTMCellF, inputs)
+ self.assertExpectedGraph(ge.graph_for(*inputs))
+
+ @unittest.skipIf(IS_WINDOWS or IS_SANDCASTLE, "NYI: fuser support for Windows or Sandcastle")
+ @unittest.skip("Test is flaky, see https://github.com/pytorch/pytorch/issues/8746")
+ @enable_cpu_fuser
+ def test_lstm_traced_cpu(self):
+ inputs = get_lstm_inputs('cpu')
+ try:
+ ge = self.checkTrace(LSTMCellF, inputs)
+ self.assertExpectedGraph(ge.graph_for(*inputs))
+ except RuntimeError as e:
+ if 'Failed to compile' in e.args[0]:
+ warnings.warn('CPU fuser test has failed! This is not a hard failure, '
+ 'because the kernels sometimes trigger bugs in compilers '
+ '(most notably GCC 7.2).')
+ raise unittest.SkipTest('Failed to compile')
+ else:
+ raise
-script_method_template = '''
-def forward({}):
- return {}
-'''
+ @unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
+ @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
+ @skipIfRocm
+ def test_milstm_cuda(self):
+ inputs = get_milstm_inputs('cuda', training=True)
+ module = self.checkScript(MiLSTMCell, inputs)
+ forward_graph = module.graph_for(*inputs)
+ self.assertGraphContainsExactly(
+ forward_graph, 'prim::FusionGroup', 1, consider_subgraphs=True)
+ self.assertExpectedGraph(forward_graph, subname='forward')
+ hy, cy = module(*inputs)
+ (hy + cy).sum().backward()
+ self.assertExpectedGraph(backward_graph(module), subname='backward')
-def get_constant(x):
- if x == inf:
- return 'float(\'inf\')' if PY2 else 'math.inf'
- if x == -inf:
- return 'float(\'-inf\')' if PY2 else '-math.inf'
- return x
+ # TODO: At some point we supported fusion of torch.rand_like but not anymore
+ @unittest.expectedFailure
+ @unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
+ @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
+ @skipIfRocm
+ def test_rand_cuda(self):
+ class M(torch.jit.ScriptModule):
+ __constants__ = ['d']
+ def __init__(self):
+ self.d = torch.device('cuda')
-def get_script_args(args):
- formals = []
- tensors = []
- actuals = []
- for arg in args:
- if isinstance(arg, torch.Tensor):
- name = 'i{}'.format(len(formals))
- formals.append(name)
- actuals.append(name)
- tensors.append(arg)
- elif isinstance(arg, str):
- actuals.append("'{}'".format(arg))
- else:
- actuals.append(str(get_constant(arg)))
- return (formals, tensors, actuals)
+ @torch.jit.script_method
+ def create(self, x):
+ return x * x + x + torch.rand_like(x)
+ x = torch.zeros([3, 4, 5], dtype=torch.float, device='cuda')
+ m = M()
+ out1 = m.create(x)
+ out2 = m.create(x)
+ self.assertNotEqual(out1, out2)
+ self.assertTrue(torch.all(out1 >= 0))
+ self.assertTrue(torch.all(out1 < 1))
+ self.assertTrue(torch.all(out2 >= 0))
+ self.assertTrue(torch.all(out2 < 1))
+ self.assertAllFused(m.create.graph_for(x))
-# create a script function from (name, func_type, output_process_fn),
-# returns a function takes in (args, kwargs) and runs the compiled function and
-# then applies the post process fn to the outputs
-def create_script_fn(self, method_name, func_type, output_process_fn,
- disable_autodiff_subgraph_inlining=False):
- def script_fn(*args, **kwargs):
- formals, tensors, actuals = get_script_args(args)
- kwargs_str = ''
- for k, v in kwargs.items():
- kwargs_str += ', ' + k + '=' + str(v)
- if func_type == 'functional':
- call = 'torch.{}({}{})'.format(method_name, ', '.join(actuals), kwargs_str)
- elif func_type == 'method':
- call = '{}.{}({}{})'.format(actuals[0], method_name, ', '.join(actuals[1:]), kwargs_str)
- elif func_type == 'nn_functional':
- call = 'torch.nn.functional.{}({}{})'.format(method_name, ', '.join(actuals), kwargs_str)
- else:
- raise 'Unsupported function type'
+ @staticmethod
+ def fn_test_relu(x, y):
+ return F.relu(x + .5 * y)
- script = script_template.format(', '.join(formals), call)
+ @unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
+ @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
+ @skipIfRocm
+ def test_relu_cuda(self):
+ x = torch.randn(4, 4, dtype=torch.float, device='cuda')
+ y = torch.randn(4, 4, dtype=torch.float, device='cuda')
- CU = torch.jit.CompilationUnit(script)
- if disable_autodiff_subgraph_inlining:
- CU.the_method.debug_disable_autodiff_subgraph_inlining()
- self.assertExportImport(CU.the_method.graph, tensors)
- output = output_process_fn(CU.the_method(*tensors))
- script_fn.last_graph = CU.the_method.graph_for(*tensors)
- return output
- return script_fn
+ ge = self.checkTrace(self.fn_test_relu, (x, y))
+ @unittest.skipIf(IS_WINDOWS or IS_SANDCASTLE, "NYI: fuser support for Windows or Sandcastle")
+ @enable_cpu_fuser
+ def test_scalar(self):
+ def fn(x, y):
+ return 2 * x + y
-def check_alias_annotation(method_name, args, kwargs):
- formals, tensors, actuals = get_script_args(args)
- kwargs_str = ''
- for k, v in kwargs.items():
- kwargs_str += ', ' + k + '=' + str(v)
- call = '{}.{}({}{})'.format(actuals[0], method_name, ', '.join(actuals[1:]), kwargs_str)
- script = script_template.format(', '.join(formals), call)
- CU = torch.jit.CompilationUnit(script)
- torch._C._jit_check_alias_annotation(CU.the_method.graph, tuple(tensors), method_name)
+ x = torch.tensor(0.1, dtype=torch.float, device='cpu')
+ y = torch.tensor(1, dtype=torch.float, device='cpu')
+ ge = self.checkScript(fn, (x, y))
+ self.assertExpectedGraph(ge.graph_for(x, y))
+ @unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
+ @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
+ def test_small_constant_cuda(self):
+ def fn_test_small_constant(x, y):
+ return (1e-8 * x + 5e-9 * y) * 1e8
+ x = torch.randn(4, 4, dtype=torch.float, device='cuda')
+ y = torch.randn(4, 4, dtype=torch.float, device='cuda')
-def check_output_types(self, func, ref_outputs, args, kwargs):
- graph = getattr(func, 'last_graph', None)
- if not isinstance(ref_outputs, tuple):
- ref_outputs = (ref_outputs,)
- types = [o.type() for o in graph.outputs()]
- self.assertEqual(len(types), len(ref_outputs))
- for i, (t, ref_out) in enumerate(zip(types, ref_outputs)):
- if isinstance(ref_out, list):
- assert len(ref_out) > 0
- elem = ref_out[0]
- assert isinstance(elem, torch.Tensor)
- self.assertTrue(t.isSubtypeOf(torch._C.ListType.ofTensors()))
- else:
- ref_type = torch._C.Type.inferFrom(ref_out)
- self.assertTrue(ref_type.isSubtypeOf(t))
+ ge = self.checkTrace(fn_test_small_constant, (x, y))
+ self.assertAllFused(ge.graph_for(x, y))
+ @unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
+ @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
+ @skipIfRocm
+ def test_tensor_scalar_ops_cuda(self):
+ def should_fuse(x):
+ z = 3.
+ y = x + z
+ return x * y
-def check_against_reference(self, func, reference_func, args, kwargs=None,
- allow_unused=True, check_types=True, no_grad=False):
- kwargs = kwargs if kwargs else {}
+ # XXX: right now we only support fusing scalars if
+ # they're constant (#9940)
+ def should_not_fuse(x, z):
+ y = x + int(z)
+ return x * y
- def allSum(vs):
- if isinstance(vs, torch.Tensor):
- vs = (vs,)
- return sum([(i + 1) * v.sum()
- for i, v in enumerate(vs)
- if v is not None and v.dtype.is_floating_point])
+ inputs = [torch.randn(2, 2, dtype=torch.float, device='cuda')]
+ ge = self.checkScript(should_fuse, inputs)
+ self.assertExpectedGraph(ge.graph_for(*inputs), subname='1')
- def clone_inputs(requires_grad):
inputs = [
- arg.detach().clone().requires_grad_(requires_grad and arg.requires_grad)
- if isinstance(arg, torch.Tensor) else arg for arg in args
+ torch.randn(2, 2, dtype=torch.float, device='cuda'),
+ torch.tensor(3., dtype=torch.float, device='cuda'),
]
- return inputs, [input for input in inputs if isinstance(input, torch.Tensor) and input.requires_grad]
-
- nograd_inputs, nograd_tensors = clone_inputs(False)
- recording_inputs, recording_tensors = clone_inputs(True)
+ ge = self.checkScript(should_not_fuse, inputs)
+ self.assertExpectedGraph(ge.graph_for(*inputs), subname='2')
- # test no gradients case
- outputs = self.runAndSaveRNG(reference_func, nograd_inputs, kwargs)
- outputs_test = self.runAndSaveRNG(func, nograd_inputs, kwargs)
- self.assertEqual(outputs, outputs_test)
+ @unittest.skipIf(IS_WINDOWS or IS_SANDCASTLE, "NYI: fuser support for Windows or Sandcastle")
+ @enable_cpu_fuser
+ def test_where_and_typing(self):
+ def f(x, y):
+ mask = x > y
+ res = torch.where(mask, x, y)
+ return mask, res
- if check_types:
- check_output_types(self, func, outputs_test, nograd_inputs, kwargs)
+ script_f = torch.jit.script(f)
- if no_grad:
- # skip grad tests
- return
+ x = torch.randn(4, 4, dtype=torch.double)
+ y = torch.randn(4, 4, dtype=torch.double)
- # test single grad case
- outputs = self.runAndSaveRNG(reference_func, recording_inputs, kwargs)
- grads = torch.autograd.grad(allSum(outputs), recording_tensors,
- allow_unused=allow_unused)
+ result1, result2 = script_f(x, y)
+ expected1, expected2 = f(x, y)
+ self.assertEqual(result1, expected1)
+ self.assertEqual(result2, expected2)
+ self.assertAllFused(script_f.graph_for(x, y))
- outputs_test = self.runAndSaveRNG(func, recording_inputs, kwargs)
- grads_test = torch.autograd.grad(allSum(outputs_test), recording_tensors,
- allow_unused=allow_unused)
- self.assertEqual(outputs, outputs_test)
- self.assertEqual(grads, grads_test)
+ # TODO: This test seems dead
+ @unittest.skipIf(not IS_WINDOWS, "Testing Fuse skipped on windows")
+ @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
+ def test_windows(self):
+ def scaleshift(x, scale, shift):
+ return x * scale + shift
- # test the grad grad case
- if self._testMethodName in nn_functional_single_grad:
- return
+ graph = torch.jit.script(scaleshift).graph
- outputs = self.runAndSaveRNG(reference_func, recording_inputs, kwargs)
- l1 = allSum(outputs)
- grads = torch.autograd.grad(l1, recording_tensors, create_graph=True,
- allow_unused=allow_unused)
- l2 = (allSum(grads) * l1)
- grads2 = torch.autograd.grad(l2, recording_tensors, allow_unused=allow_unused)
+ inputs = [
+ torch.randn(4, 4, dtype=torch.float, device='cuda'),
+ torch.randn(4, dtype=torch.float, device='cuda'),
+ torch.randn(4, dtype=torch.float, device='cuda'),
+ ]
- recording_inputs, recording_tensors = clone_inputs(True)
+ ge = self.checkTrace(scaleshift, inputs)
+ fuse_graph = ge.graph_for(*inputs)
- outputs_test = self.runAndSaveRNG(func, recording_inputs, kwargs)
- l1_test = allSum(outputs_test)
- grads_test = torch.autograd.grad(
- l1_test, recording_tensors, create_graph=True, allow_unused=allow_unused)
- l2_test = (allSum(grads_test) * l1_test)
- grads2_test = torch.autograd.grad(l2_test, recording_tensors, allow_unused=allow_unused)
+ def run_graph(graph, inputs):
+ m = torch.jit.ScriptModule()
+ m._create_method_from_graph("forward", graph)
+ return m(*inputs)
- self.assertEqual(outputs, outputs_test)
- self.assertEqual(grads, grads_test)
- for g2, g2_test in zip(grads2, grads2_test):
- if g2 is None and g2_test is None:
- continue
- self.assertTrue(torch.allclose(g2, g2_test, atol=5e-4, rtol=1e-4))
+ self.assertEqual(run_graph(graph, inputs), run_graph(fuse_graph, inputs))
# NB: torch.jit.script, when used as a function, uses the current scope