Split off fuser tests in test_jit.py to their own test case (#15072)
author: Richard Zou <zou3519@gmail.com>
    Tue, 11 Dec 2018 22:50:33 +0000 (14:50 -0800)
committer: Facebook Github Bot <facebook-github-bot@users.noreply.github.com>
    Tue, 11 Dec 2018 22:55:06 +0000 (14:55 -0800)
Summary:
This PR creates TestFuser inside test_jit.py to serve as a home for graph-fuser-specific
tests.

This was a useful exercise because, now that all the fuser tests are in
one place, I can spot redundant and bit-rotted tests to clean up in a
future PR.
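
For orientation, the tests deleted from TestJit/TestScript in the diff below reappear
under the new TestFuser case. A minimal sketch of that case, based on the removed
test_broadcast_fusion_cuda test and on helpers already defined in test_jit.py
(JitTestCase, checkTrace, RUN_CUDA, IS_WINDOWS, skipIfRocm); the actual TestFuser body
is not reproduced in the hunks shown here:

    class TestFuser(JitTestCase):
        @unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
        @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
        @skipIfRocm
        def test_broadcast_cuda(self):
            # formerly TestJit.test_broadcast_fusion_cuda
            def scaleshift(x, scale, shift):
                return x * scale + shift

            inputs = [
                torch.randn(4, 4, dtype=torch.float, device='cuda'),
                torch.randn(4, dtype=torch.float, device='cuda'),
                torch.randn(4, dtype=torch.float, device='cuda'),
            ]
            ge = self.checkTrace(scaleshift, inputs)
            self.assertExpectedGraph(ge.graph_for(*inputs))
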
Pull Request resolved: https://github.com/pytorch/pytorch/pull/15072

Differential Revision: D13421458

Pulled By: zou3519

fbshipit-source-id: 80b1a7712feff75a0c186d1664601c4edbbca694

18 files changed:
test/expect/TestFuser.test_broadcast_cuda.expect [moved from test/expect/TestJit.test_broadcast_fusion_cuda.expect with 100% similarity]
test/expect/TestFuser.test_chunk_cuda.expect [moved from test/expect/TestScript.test_chunk_fusion_cuda.expect with 100% similarity]
test/expect/TestFuser.test_chunk_distributes_cuda.expect [moved from test/expect/TestJit.test_fusion_distribute_cuda.expect with 100% similarity]
test/expect/TestFuser.test_chunk_multiple_cuda.expect [moved from test/expect/TestScript.test_chunk_multiple_fusion_cuda.expect with 100% similarity]
test/expect/TestFuser.test_concat_cuda.expect [moved from test/expect/TestJit.test_concat_fusion_cuda.expect with 100% similarity]
test/expect/TestFuser.test_concat_invariant_cuda.expect [moved from test/expect/TestJit.test_concat_fusion_invariant_cuda.expect with 100% similarity]
test/expect/TestFuser.test_last_device_cuda.expect [moved from test/expect/TestJit.test_fuse_last_device_cuda.expect with 100% similarity]
test/expect/TestFuser.test_lstm_concat_cuda.expect [moved from test/expect/TestJit.test_lstm_fusion_concat_cuda.expect with 100% similarity]
test/expect/TestFuser.test_lstm_cuda-backward.expect [moved from test/expect/TestScript.test_lstm_fusion_cuda-backward.expect with 100% similarity]
test/expect/TestFuser.test_lstm_cuda-forward.expect [moved from test/expect/TestScript.test_lstm_fusion_cuda-forward.expect with 100% similarity]
test/expect/TestFuser.test_lstm_traced_cpu.expect [moved from test/expect/TestJit.test_lstm_fusion_cpu.expect with 100% similarity]
test/expect/TestFuser.test_lstm_traced_cuda.expect [moved from test/expect/TestJit.test_lstm_fusion_cuda.expect with 100% similarity]
test/expect/TestFuser.test_milstm_cuda-backward.expect [moved from test/expect/TestScript.test_milstm_fusion_cuda-backward.expect with 100% similarity]
test/expect/TestFuser.test_milstm_cuda-forward.expect [moved from test/expect/TestScript.test_milstm_fusion_cuda-forward.expect with 100% similarity]
test/expect/TestFuser.test_scalar.expect [moved from test/expect/TestScript.test_scalar_fusion.expect with 100% similarity]
test/expect/TestFuser.test_tensor_scalar_ops_cuda-1.expect [moved from test/expect/TestScript.test_tensor_scalar_fusion_cuda-1.expect with 100% similarity]
test/expect/TestFuser.test_tensor_scalar_ops_cuda-2.expect [moved from test/expect/TestScript.test_tensor_scalar_fusion_cuda-2.expect with 100% similarity]
test/test_jit.py
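
With the fuser tests consolidated under one case, they can be run on their own, e.g.
python test/test_jit.py TestFuser (assuming test_jit.py keeps its standard unittest
entry point).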

index 42979e3..30f319e 100644
@@ -347,6 +347,52 @@ class JitTestCase(TestCase):
             trace.set_graph(graph)
         return graph
 
+    def checkScript(self,
+                    script,
+                    inputs,
+                    optimize=True,
+                    outputs=None,
+                    name='func',
+                    capture_output=False,
+                    frames_up=1,
+                    check_expected=False):
+        if isinstance(script, str):
+            cu = torch.jit.CompilationUnit(script, optimize, _frames_up=frames_up)
+            ge = getattr(cu, name)
+        else:
+            if capture_output:
+                with self.capture_stdout() as captured:
+                    outputs = script(*inputs)
+            else:
+                outputs = script(*inputs)
+            # Check the string frontend first
+            source = textwrap.dedent(inspect.getsource(script))
+            self.checkScript(
+                source,
+                inputs,
+                optimize,
+                outputs,
+                script.__name__,
+                capture_output,
+                frames_up=2,
+                check_expected=check_expected)
+            # Continue checking the Python frontend
+            ge = torch.jit.script(script, optimize, _frames_up=1)
+
+        if capture_output:
+            with self.capture_stdout() as captured:
+                outputs_ge = ge(*inputs)
+            if not WINDOWS:
+                self.assertExpected(captured[0], subname='stdout')
+        else:
+            outputs_ge = ge(*inputs)
+        self.assertEqual(outputs, outputs_ge)
+
+        if check_expected:
+            self.assertExpectedGraph(ge.graph)
+
+        return ge
+
     def checkTrace(self, func, reference_tensors, input_tensors=None,
                    optimize=True, drop=None, allow_unused=False, verbose=False,
                    inputs_require_grads=True, check_tolerance=1e-5, export_import=True):
@@ -433,14 +479,6 @@ class JitTestCase(TestCase):
 
         return ge
 
-    def assertAllFused(self, graph, except_for=()):
-        if [n.kind() for n in graph.nodes()] == ['prim::DifferentiableGraph']:
-            graph = next(graph.nodes()).g('Subgraph')
-        allowed_nodes = {'prim::Constant', 'prim::FusionGroup'} | set(except_for)
-        self.assertTrue(all(node.kind() in allowed_nodes for node in graph.nodes()),
-                        'got {}'.format(graph))
-        self.assertTrue([node.kind() for node in graph.nodes()].count('prim::FusionGroup') == 1)
-
     def assertExportImport(self, trace, inputs):
         graph = trace if isinstance(trace, torch._C.Graph) else trace.graph()
         m = torch.jit.ScriptModule()
@@ -760,2041 +798,1685 @@ class TestJit(JitTestCase):
         self.assertExportImport(trace, (t,) + tuple(model.parameters()))
         self.assertExpectedONNXGraph(trace)
 
-    @unittest.skipIf(not IS_WINDOWS, "Testing Fuse skipped on windows")
-    @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
-    def test_windows_fuse(self):
-        def scaleshift(x, scale, shift):
-            return x * scale + shift
+    def test_canonicalize_tensor_iterator(self):
+        x = torch.randn(4, 4)
 
-        graph = torch.jit.script(scaleshift).graph
+        def f(x):
+            x = x + 2
+            x = x - 4
+            x = x * 6
+            x = x / 8
+            return x
 
-        inputs = [
-            torch.randn(4, 4, dtype=torch.float, device='cuda'),
-            torch.randn(4, dtype=torch.float, device='cuda'),
-            torch.randn(4, dtype=torch.float, device='cuda'),
-        ]
+        traced = torch.jit.trace(f, (x,))
+        f(x)
+        graph = traced.graph_for(x)
+        # There should be 4 int constants for the right sides of operators, plus two
+        # for alpha arguments for add and sub
+        self.assertTrue(str(traced.graph_for(x)).count(': int = prim::Constant'), 6)
 
-        ge = self.checkTrace(scaleshift, inputs)
-        fuse_graph = ge.graph_for(*inputs)
+    # TODO: adapt this test to check that GraphExecutor treats them differently
+    @unittest.skip("Need to be adjusted to Graph Executor")
+    def test_arg_configurations(self):
+        """Different arg configurations should trigger different traces"""
+        x = Variable(torch.FloatTensor(4, 4).uniform_())
+        x_double = Variable(x.data.double())
+        x_grad = Variable(x.data.clone(), requires_grad=True)
+        y = Variable(torch.randn(4))
 
-        def run_graph(graph, inputs):
-            m = torch.jit.ScriptModule()
-            m._create_method_from_graph("forward", graph)
-            return m(*inputs)
+        configurations = [
+            (x,),
+            (x_double,),
+            (x_grad,),
+            (y,),
+            ([x, x],),
+            ([x, y],),
+        ]
+        if torch.cuda.is_available():
+            x_cuda = Variable(x.data.cuda())
+            configurations += [
+                (x_cuda,),
+                ([x, x_cuda],),
+                ([x_cuda, x],),
+                ([[x_cuda, x]],),
+            ]
+            if torch.cuda.device_count() > 1:
+                x_cuda_1 = Variable(x.data.cuda(1))
+                configurations += [
+                    (x_cuda_1,),
+                    ([x_cuda, x_cuda_1],),
+                ]
 
-        self.assertEqual(run_graph(graph, inputs), run_graph(fuse_graph, inputs))
+        @torch.jit.compile(nderivs=0)
+        def fn(*args):
+            in_vars, _ = torch._C._jit_flatten(args)
+            return in_vars[0] + 1
 
-    @unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
-    @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
-    @skipIfRocm
-    def test_broadcast_fusion_cuda(self):
-        def scaleshift(x, scale, shift):
-            return x * scale + shift
+        for i, config in enumerate(configurations):
+            self.assertFalse(fn.has_trace_for(*config))
+            fn(*config)
+            self.assertTrue(fn.has_trace_for(*config))
+            for unk_config in configurations[i + 1:]:
+                self.assertFalse(fn.has_trace_for(*unk_config))
+        self.assertEqual(fn.hits, 0)
 
-        inputs = [
-            torch.randn(4, 4, dtype=torch.float, device='cuda'),
-            torch.randn(4, dtype=torch.float, device='cuda'),
-            torch.randn(4, dtype=torch.float, device='cuda'),
-        ]
-        ge = self.checkTrace(scaleshift, inputs)
-        self.assertExpectedGraph(ge.graph_for(*inputs))
+    def test_cse(self):
+        x = torch.tensor([0.4, 0.3], requires_grad=True)
+        y = torch.tensor([0.7, 0.5], requires_grad=True)
 
-    # TODO: Fuser doesn't work at all when inputs require grad. Fix that
-    @unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
-    @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
-    @skipIfRocm
-    def test_lstm_fusion_cuda(self):
-        inputs = get_lstm_inputs('cuda')
-        ge = self.checkTrace(LSTMCellF, inputs)
-        self.assertExpectedGraph(ge.graph_for(*inputs))
+        def fn(x, y):
+            w = (x + y) * (x + y) * (x + y)
+            t = torch.tanh(w) + torch.tanh(w)
+            z = (x + y) * (x + y) * (x + y) + t
+            return z
 
-    @unittest.skipIf(IS_WINDOWS or IS_SANDCASTLE, "NYI: fuser support for Windows or Sandcastle")
-    @unittest.skip("Test is flaky, see https://github.com/pytorch/pytorch/issues/8746")
-    @enable_cpu_fuser
-    def test_lstm_fusion_cpu(self):
-        inputs = get_lstm_inputs('cpu')
-        try:
-            ge = self.checkTrace(LSTMCellF, inputs)
-            self.assertExpectedGraph(ge.graph_for(*inputs))
-        except RuntimeError as e:
-            if 'Failed to compile' in e.args[0]:
-                warnings.warn('CPU fuser test has failed! This is not a hard failure, '
-                              'because the kernels sometimes trigger bugs in compilers '
-                              '(most notably GCC 7.2).')
-                raise unittest.SkipTest('Failed to compile')
-            else:
-                raise
+        trace, _ = torch.jit.get_trace_graph(fn, (x, y))
+        self.run_pass('cse', trace)
+        self.assertExpectedGraph(trace)
+        self.assertExportImport(trace, (x, y))
 
-    @unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
-    @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
-    @skipIfRocm
-    def test_lstm_fusion_concat_cuda(self):
-        inputs = get_lstm_inputs('cuda')
-        ge = self.checkTrace(LSTMCellC, inputs)
-        self.assertExpectedGraph(ge.graph_for(*inputs))
+    def test_recursive_cse(self):
+        x = torch.tensor([0.1])
+        y = torch.tensor([0.2])
 
-    @unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
-    @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
-    @skipIfRocm
-    def test_concat_fusion_cuda(self):
-        hx = torch.randn(3, 20, dtype=torch.float, device='cuda')
-        cx = torch.randn(3, 20, dtype=torch.float, device='cuda')
+        def fn(x, y):
+            z = x
+            if bool(x + y > x):
+                z = x + y
+            return z
 
-        def foo(hx, cx):
-            return torch.cat((hx + cx, hx * cx))
+        graph = torch.jit.script(fn).graph
+        self.run_pass('cse', graph)
+        self.assertExpectedGraph(graph)
 
-        ge = self.checkTrace(foo, (hx, cx))
-        self.assertExpectedGraph(ge.graph_for(hx, cx))
+    def test_scalar(self):
+        # NB: must not require grad; if it requires grad, it's always a Tensor
+        x = torch.tensor(2.)
+        y = torch.tensor(3.)
 
-    @unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
-    @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
-    @skipIfRocm
-    def test_concat_fusion_invariant_cuda(self):
-        # Invariant: the output of prim::FusedConcat may
-        # not be an input to any node inside the FusionGroup.
-        def fn(x, y, z):
-            x1 = x + y
-            y1 = x - y
-            w = torch.cat([x1, y1])
-            return w + z
+        def fn(x, y):
+            return x - y
+        trace, _ = torch.jit.get_trace_graph(fn, (x, y))
 
-        x = torch.randn(2, 2, dtype=torch.float, device='cuda')
-        y = torch.randn(2, 2, dtype=torch.float, device='cuda')
-        z = torch.randn(4, 2, dtype=torch.float, device='cuda')
-        ge = self.checkTrace(fn, (x, y, z))
-        self.assertExpectedGraph(ge.graph_for(x, y, z))
+    def test_shape_analysis_broadcast(self):
+        def broadcast(a, b):
+            return a + b
 
-    @unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
-    @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
-    @skipIfRocm
-    def test_fusion_distribute_cuda(self):
-        def f(x, y):
-            z1, z2 = (x + y).chunk(2, dim=1)
-            return z1 * z2
+        x = torch.randn(3, 1, 5, requires_grad=True)
+        y = torch.randn(4, 1, 8, 5, requires_grad=True)
 
-        x = torch.randn(4, 4, dtype=torch.float, device='cuda')
-        y = torch.randn(4, 4, dtype=torch.float, device='cuda')
+        graph = torch.jit.script(broadcast).graph
+        torch._C._jit_pass_complete_shape_analysis(graph, (x, y), False)
+        self.assertExpectedGraph(graph)
 
-        ge = self.checkTrace(f, (x, y))
-        self.assertExpectedGraph(ge.graph_for(x, y))
+    # TODO: update verify to work with GraphExecutors
+    @unittest.skip("verify needs to be updated to work with GraphExecutors")
+    def test_verify(self):
+        x = torch.tensor([0.4], requires_grad=True)
+        y = torch.tensor([0.7], requires_grad=True)
 
-    @unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
-    @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
-    @skipIfRocm
-    def test_fusion_rand_cuda(self):
-        class M(torch.jit.ScriptModule):
-            __constants__ = ['d']
+        @torch.jit.compile
+        def f(x, y):
+            z = torch.sigmoid(x * (x + y))
+            w = torch.abs(x * x * x + y) + Variable(torch.ones(1))
+            return z, w
 
-            def __init__(self):
-                self.d = torch.device('cuda')
+        torch.jit.verify(f, (x, y), loss_fn=lambda z, w: z * w, devices=[])
 
-            @torch.jit.script_method
-            def create(self, x):
-                return x * x + x + torch.rand_like(x)
+    @suppress_warnings
+    def test_constant(self):
+        x = torch.randn(2, 2, requires_grad=True)
 
-        x = torch.zeros([3, 4, 5], dtype=torch.float, device='cuda')
-        m = M()
-        out1 = m.create(x)
-        out2 = m.create(x)
-        self.assertNotEqual(out1, out2)
-        self.assertTrue(torch.all(out1 >= 0))
-        self.assertTrue(torch.all(out1 < 1))
-        self.assertTrue(torch.all(out2 >= 0))
-        self.assertTrue(torch.all(out2 < 1))
+        def f(x):
+            return x.matmul(torch.diag(torch.tensor([2., 2.])))
 
-    @unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
-    @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
-    @skipIfRocm
-    def test_fusion_arg_configurations_cuda(self):
-        # A smoke test to make sure we won't use the same kernel for contiguous
-        # and non-contiguous arguments.
-        # TODO: add optionally enabled debug counters to the fuser to verify
-        #       that we really can tell the difference between configurations
-        def f(x, y):
-            z1, z2 = (x + y).chunk(2, dim=1)
-            return z1 * z2
+        self.checkTrace(f, (x,), (torch.ones(2, 2, requires_grad=True),))
 
-        x = torch.randn(4, 4, dtype=torch.float, device='cuda')
-        y = torch.randn(4, 4, dtype=torch.float, device='cuda')
-        traced_f = torch.jit.trace(f, (x, y,))
-        self.assertEqual(traced_f(x.t().contiguous(), y), traced_f(x.t(), y))
+    def test_legacy_fail(self):
+        class MyLegacyFn(Function):
+            def forward(self, x):
+                return x
 
-    @unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
-    @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
-    @skipIfRocm
-    def test_fusion_checks_cat_inputs(self):
-        # We shouldn't treat cat nodes as broadcasting. All their inputs
-        # need to be checked for having the same map size, before we can
-        # run the kernel.
-        @torch.jit.script
-        def f(x, y):
-            return torch.cat([x + 2 * x + x ** 2, y + 4 * y + y ** 3], dim=0)
+            def backward(self, grad_output):
+                return grad_output
 
-        # NOTE: y is broadcastable to x, but output of f(x, y) should have
-        # shape 3x4, and not 4x4.
-        x = torch.randn(2, 4, dtype=torch.float, device='cuda')
-        y = torch.randn(1, 4, dtype=torch.float, device='cuda')
+        x = torch.tensor([0.], requires_grad=True)
+        with self.assertRaisesRegex(RuntimeError, "MyLegacyFn"):
+            torch.jit.get_trace_graph(lambda x: MyLegacyFn()(x), (x,))
 
-        self.assertEqual(f(x, y).shape, (3, 4))
-        self.assertAllFused(f.graph_for(x, y))
+    def test_inplace_transplant(self):
+        x = torch.tensor([0.], requires_grad=True)
 
-    @staticmethod
-    def fn_test_comparison_gt_lt(x, y):
-        mask = (x > 0).type_as(x)
-        z = x * mask + y
-        mask = (x < 0).type_as(x)
-        z = z * mask + y
-        return z
+        def fn(x):
+            y = x.clone()
+            y.add_(2)
+            y.add_(3)
+            return y
 
-    @unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
-    @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
-    @skipIfRocm
-    def test_comparison_gt_lt_cuda(self):
-        x = torch.randn(4, 4, dtype=torch.float, device='cuda')
-        y = torch.randn(4, 4, dtype=torch.float, device='cuda')
+        trace, _ = torch.jit.get_trace_graph(fn, (x,))
+        self.assertExpectedGraph(trace)
+        self.assertExportImport(trace, (x,))
 
-        ge = self.checkTrace(self.fn_test_comparison_gt_lt, (x, y))
-        self.assertAllFused(ge.graph_for(x, y))
+    def test_inplace_flags(self):
+        class InplaceFn(Function):
+            @staticmethod
+            def forward(ctx, x):
+                ctx.mark_dirty(x)
+                return x.add_(1)
 
-    @unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
-    @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
-    @skipIfRocm
-    def test_comparison_ge_le_cuda(self):
-        def f(x, y):
-            mask = (x >= 0).type_as(x)
-            z = x * mask + y
-            mask = (x <= 0).type_as(x)
-            z = z * mask + y
-            return z
+            @staticmethod
+            def backward(ctx, go):
+                return go
 
-        x = torch.randn(4, 4, dtype=torch.float, device='cuda')
-        y = torch.randn(4, 4, dtype=torch.float, device='cuda')
+        class RegularFn(Function):
+            @staticmethod
+            def forward(ctx, x):
+                return x.add(1)
 
-        ge = self.checkTrace(f, (x, y))
-        self.assertAllFused(ge.graph_for(x, y))
+            @staticmethod
+            def backward(ctx, go):
+                return go
 
-    @unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
-    @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
-    @skipIfRocm
-    def test_comparison_eq_ne(self):
-        def f(x, y):
-            mask = (x == 0).type_as(x)
-            z = x * mask + y
-            mask = (x != 0).type_as(x)
-            z = z * mask + y
-            return z
+        x = torch.tensor([0.], requires_grad=True)
 
-        x = torch.randn(4, 4, dtype=torch.float, device='cuda')
-        y = torch.randn(4, 4, dtype=torch.float, device='cuda')
+        def fn(x):
+            y = RegularFn.apply(x)
+            y = InplaceFn.apply(y)
+            y = InplaceFn.apply(y)
+            y = RegularFn.apply(y)
+            return y
 
-        ge = self.checkTrace(f, (x, y))
-        self.assertAllFused(ge.graph_for(x, y))
+        trace, _ = torch.jit.get_trace_graph(fn, (x,), _force_outplace=True)
+        self.run_pass('dce', trace)
+        ops = [n for n in trace.graph().nodes()]
+        for op in ops:
+            self.assertTrue(op.hasAttribute('inplace'))
+        inplace_flags = [False, True, True, False]
+        for op, is_inplace in zip(ops, inplace_flags):
+            self.assertEqual(op.i('inplace'), is_inplace)
 
-    @staticmethod
-    def fn_test_relu(x, y):
-        return F.relu(x + .5 * y)
+    def test_inplace_check(self):
+        class MyInplaceFn(Function):
+            @staticmethod
+            def forward(self, x):
+                x.add_(1)
+                self.mark_dirty(x)
+                return x
 
-    @unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
-    @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
-    @skipIfRocm
-    def test_relu_cuda(self):
-        x = torch.randn(4, 4, dtype=torch.float, device='cuda')
-        y = torch.randn(4, 4, dtype=torch.float, device='cuda')
+            @staticmethod
+            def backward(self, grad):
+                return grad
 
-        ge = self.checkTrace(self.fn_test_relu, (x, y))
+        def fn(x):
+            return MyInplaceFn.apply(x)
 
-    @unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
-    @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
-    def test_small_constant_cuda(self):
-        def fn_test_small_constant(x, y):
-            return (1e-8 * x + 5e-9 * y) * 1e8
-        x = torch.randn(4, 4, dtype=torch.float, device='cuda')
-        y = torch.randn(4, 4, dtype=torch.float, device='cuda')
+        x = torch.randn(5, 5)
+        ge = torch._C.GraphExecutor(fn, (x,), lambda var: '', _force_outplace=True)
+        with self.assertRaisesRegex(RuntimeError, 'inplace MyInplaceFn'):
+            ge(x)
 
-        ge = self.checkTrace(fn_test_small_constant, (x, y))
+    def do_trace_size(self, requires_grad):
+        def fn(x):
+            return x.view(x.shape[1] * 2, x.size(0), 2)
 
-    @staticmethod
-    def fn_test_exp(x, y):
-        return (x + .5 * y).exp()
+        x = torch.randn(5, 2, 4, requires_grad=requires_grad)
+        y = torch.randn(4, 8, 4, requires_grad=requires_grad)
 
-    @unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
-    @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
-    @skipIfRocm
-    def test_exp_cuda(self):
-        x = torch.randn(4, 4, dtype=torch.float, device='cuda')
-        y = torch.randn(4, 4, dtype=torch.float, device='cuda')
+        # Check that it behaves as expected
+        traced_fn = torch.jit.trace(fn, x)
+        self.assertEqual(traced_fn(y), fn(y))
+        self.assertEqual(traced_fn(x), fn(x))
 
-        ge = self.checkTrace(self.fn_test_exp, (x, y))
+        # Check that the trace looks ok
+        trace, _ = torch.jit.get_trace_graph(fn, (x,))
+        self.assertExpectedGraph(trace)
 
-    @unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
-    @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
-    @unittest.skipIf(not RUN_CUDA_HALF, "no half support")
-    def test_cuda_half(self):
-        x = torch.randn(4, 4, dtype=torch.half, device='cuda')
-        y = torch.randn(4, 4, dtype=torch.half, device='cuda')
+    def test_trace_size(self):
+        self.do_trace_size(False)
 
-        funcs = [
-            self.fn_test_comparison_gt_lt,
-            self.fn_test_relu,
-            self.fn_test_exp
-        ]
+    # test the different graph_executor path that happens when
+    # gradients are required and sizes are involved
+    def test_trace_size_with_grad(self):
+        self.do_trace_size(True)
 
-        # Note: Non fused inputs must be float to prevent loss of precision
-        inputs = (x.float(), y.float())
-        fusion_inputs = (x, y)
-        for fn in funcs:
-            local_inputs = [t.clone().requires_grad_() for t in inputs]
-            local_fusion_inputs = [t.clone().requires_grad_() for t in fusion_inputs]
+    def test_trace_casts(self):
+        casts = [
+            lambda x: x.byte(),
+            lambda x: x.float(),
+            lambda x: x.cpu(),
+            lambda x: x.to(device='cpu'),
+            lambda x: x.to(dtype=torch.int64),
+            lambda x: x.to(device='cpu', dtype=torch.float),
+            lambda x: x.to(x)
+        ]
 
-            # Verifies outputs
-            fusion = torch.jit.trace(fn, local_fusion_inputs, check_trace=False, optimize=True)
-            outputs = fn(*local_inputs)
-            fusion_outputs = fusion(*local_fusion_inputs)
-            outputs_half = [t.half() for t in outputs]
-            self.assertEqual(outputs_half, fusion_outputs)
+        def assertContainsCast(trace):
+            self.assertEqual(sum(n.kind() == 'aten::to' for n in trace.graph.nodes()), 1)
 
-            # Verifies gradients
-            for output, fusion_output in zip(outputs_half, fusion_outputs):
-                grads = torch.autograd.grad(
-                    output.float().sum(), local_inputs, allow_unused=True, retain_graph=True)
-                fusion_grads = torch.autograd.grad(
-                    fusion_output.sum(), local_fusion_inputs, allow_unused=True, retain_graph=True)
-                grads_half = [t.half() for t in grads]
-                self.assertEqual(grads_half, fusion_grads)
+        for cast in casts:
+            trace = torch.jit.trace(cast, torch.randn(2, 2))
+            assertContainsCast(trace)
+            x = torch.randn(2, 2)
+            self.assertEqual(trace(x), cast(x))
 
-    def test_canonicalize_tensor_iterator(self):
-        x = torch.randn(4, 4)
+        def to_tensor(x, y):
+            return x.to(y)
 
-        def f(x):
-            x = x + 2
-            x = x - 4
-            x = x * 6
-            x = x / 8
-            return x
+        to_tensor_trace = torch.jit.trace(to_tensor, (torch.randn(2, 2), torch.randn(1, 8)))
+        assertContainsCast(to_tensor_trace)
+        x, y = torch.randn(2, 2), torch.randn(1, 10)
+        self.assertEqual(to_tensor_trace(x, y), to_tensor(x, y))
 
-        traced = torch.jit.trace(f, (x,))
-        f(x)
-        graph = traced.graph_for(x)
-        # There should be 4 int constants for the right sides of operators, plus two
-        # for alpha arguments for add and sub
-        self.assertTrue(str(traced.graph_for(x)).count(': int = prim::Constant'), 6)
+    def test_trace_warn(self):
+        def fn(x):
+            int(x)  # Warning 1.
+            y = x * 1
+            if y:   # Warning 2.
+                pass
+            q = [x, x * 4]
+            z = q[y]  # Warning 3.
+            float(z)  # Warning 4.
+            z.tolist()  # Warning 5.
+            z.numpy()  # Warning 6.
+            for elem in torch.ones(4, 4):  # Warning 7.
+                pass
+            return z + 4
 
-    # TODO: adapt this test to check that GraphExecutor treats them differently
-    @unittest.skip("Need to be adjusted to Graph Executor")
-    def test_arg_configurations(self):
-        """Different arg configurations should trigger different traces"""
-        x = Variable(torch.FloatTensor(4, 4).uniform_())
-        x_double = Variable(x.data.double())
-        x_grad = Variable(x.data.clone(), requires_grad=True)
-        y = Variable(torch.randn(4))
+        with warnings.catch_warnings(record=True) as warns:
+            traced_fn = torch.jit.trace(fn, torch.tensor([1]))
+        warns = [str(w.message) for w in warns]
+        self.assertEqual(len(warns), 7)
+        self.assertIn('a Python integer', warns[0])
+        self.assertIn('a Python boolean', warns[1])
+        self.assertIn('a Python index', warns[2])
+        self.assertIn('a Python float', warns[3])
+        self.assertIn('a Python list', warns[4])
+        self.assertIn('a NumPy array', warns[5])
+        self.assertIn('Iterating over', warns[6])
 
-        configurations = [
-            (x,),
-            (x_double,),
-            (x_grad,),
-            (y,),
-            ([x, x],),
-            ([x, y],),
-        ]
-        if torch.cuda.is_available():
-            x_cuda = Variable(x.data.cuda())
-            configurations += [
-                (x_cuda,),
-                ([x, x_cuda],),
-                ([x_cuda, x],),
-                ([[x_cuda, x]],),
-            ]
-            if torch.cuda.device_count() > 1:
-                x_cuda_1 = Variable(x.data.cuda(1))
-                configurations += [
-                    (x_cuda_1,),
-                    ([x_cuda, x_cuda_1],),
-                ]
+    def test_trace_tuple(self):
+        def fn(x, y):
+            return x, (x * y[1], x * y[0])
 
-        @torch.jit.compile(nderivs=0)
-        def fn(*args):
-            in_vars, _ = torch._C._jit_flatten(args)
-            return in_vars[0] + 1
+        x, y = torch.randn(2, 2), (torch.ones(2, 2), torch.randn(2, 2))
+        traced_fn = torch.jit.trace(fn, (x, y))
+        self.assertEqual(traced_fn(x, y), fn(x, y))
+        self.assertExpectedGraph(traced_fn.graph)
+        self.assertExportImport(traced_fn.graph, (x, y))
 
-        for i, config in enumerate(configurations):
-            self.assertFalse(fn.has_trace_for(*config))
-            fn(*config)
-            self.assertTrue(fn.has_trace_for(*config))
-            for unk_config in configurations[i + 1:]:
-                self.assertFalse(fn.has_trace_for(*unk_config))
-        self.assertEqual(fn.hits, 0)
+    def test_trace_random(self):
+        def f(mean, std):
+            return torch.normal(mean, std)
 
-    def test_cse(self):
-        x = torch.tensor([0.4, 0.3], requires_grad=True)
-        y = torch.tensor([0.7, 0.5], requires_grad=True)
+        traced = torch.jit.trace(f, (torch.zeros(2, 3), torch.ones(2, 3)), check_trace=False)
+        mean, std = torch.zeros(5, 5), torch.ones(5, 5)
+        with torch.random.fork_rng(devices=[]):
+            output = f(mean, std)
+        traced_output = traced(mean, std)
+        self.assertEqual(output, traced_output)
 
-        def fn(x, y):
-            w = (x + y) * (x + y) * (x + y)
-            t = torch.tanh(w) + torch.tanh(w)
-            z = (x + y) * (x + y) * (x + y) + t
-            return z
+    def test_trace_tensor_factory(self):
+        def run(**kwargs):
+            inputs_require_grads = kwargs.pop('inputs_require_grads', True)
 
-        trace, _ = torch.jit.get_trace_graph(fn, (x, y))
-        self.run_pass('cse', trace)
-        self.assertExpectedGraph(trace)
-        self.assertExportImport(trace, (x, y))
+            def fn(x):
+                return x + torch.ones(2, 3, **kwargs)
 
-    def test_recursive_cse(self):
-        x = torch.tensor([0.1])
-        y = torch.tensor([0.2])
+            input_kwargs = kwargs.copy()
+            if 'out' in input_kwargs:
+                del input_kwargs['out']
+            input = torch.ones(2, 3, **input_kwargs)
+            self.checkTrace(fn, (input,), inputs_require_grads=inputs_require_grads)
+            # check we recorded 'ones' and did not just record a constant
+            tfn = torch.jit.trace(fn, input)
+            self.assertTrue("ones" in str(tfn.graph))
+        run()
+        run(dtype=torch.int, inputs_require_grads=False)
+        run(out=torch.tensor([]))
+        if RUN_CUDA:
+            run(device="cuda:0")
+        if RUN_CUDA_MULTI_GPU:
+            run(device="cuda:1")
 
-        def fn(x, y):
-            z = x
-            if bool(x + y > x):
-                z = x + y
-            return z
+    # TODO: implement
+    @unittest.expectedFailure
+    def test_output_unflatten(self):
+        """Check that outputs of traced functions retain the original structure and nesting"""
+        def fn(x):
+            return (x * 2, (x ** 2, x + 4, (x + 2,), ), x * 4)
 
-        graph = torch.jit.script(fn).graph
-        self.run_pass('cse', graph)
-        self.assertExpectedGraph(graph)
+        self.checkTrace(fn, (torch.randn(2, 2),))
 
-    def test_scalar(self):
-        # NB: must not require grad; if it requires grad, it's always a Tensor
-        x = torch.tensor(2.)
-        y = torch.tensor(3.)
+    # TODO: implement
+    @unittest.expectedFailure
+    def test_input_flatten(self):
+        """Check that inputs to traced functions are flattened"""
 
-        def fn(x, y):
-            return x - y
-        trace, _ = torch.jit.get_trace_graph(fn, (x, y))
+        def fn(x, t):
+            y, z = t
+            return x * y * z
 
-    def test_shape_analysis_broadcast(self):
-        def broadcast(a, b):
-            return a + b
+        inputs = (torch.randn(1), (torch.randn(1), torch.randn(1)))
+        self.checkTrace(fn, inputs)
 
-        x = torch.randn(3, 1, 5, requires_grad=True)
-        y = torch.randn(4, 1, 8, 5, requires_grad=True)
+    # TODO: adapt to a GraphExecutor test
+    @unittest.skip("Need to instrument GraphExecutors a bit more")
+    def test_flags(self):
+        x, y = torch.randn(2, 2)
+        y = Variable(torch.randn(2, 2))
 
-        graph = torch.jit.script(broadcast).graph
-        torch._C._jit_pass_complete_shape_analysis(graph, (x, y), False)
-        self.assertExpectedGraph(graph)
+        @torch.jit.compile
+        def fn(x, y):
+            return (x * x + y * y + x * y).sum()
 
-    @unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
-    @unittest.skipIf(not RUN_CUDA_MULTI_GPU, "needs non-zero device")
-    @skipIfRocm
-    def test_fuse_last_device_cuda(self):
-        device = 'cuda:' + str(1)
-        x = torch.tensor([0.4], dtype=torch.float, device=device)
-        y = torch.tensor([0.7], dtype=torch.float, device=device)
+        grads = {}
+        for rx, ry in product((True, False), repeat=2):
+            x.requires_grad = rx
+            y.requires_grad = ry
 
-        def doit(x, y):
-            return torch.sigmoid(torch.tanh(x * (x + y) + x))
+            self.assertFalse(fn.has_trace_for(x, y))
+            out = fn(x, y)
 
-        ge = self.checkTrace(doit, (x, y))
-        self.assertExpectedGraph(ge.graph_for(x, y))
+            self.assertFalse(fn.has_trace_for(x, y))
+            for v, name, compute in [(x, 'x', rx), (y, 'y', ry)]:
+                if not compute:
+                    continue
+                grad_v, = torch.autograd.grad(out, v, retain_graph=True)
+                expected_grad = grads.setdefault(name, grad_v)
+                self.assertEqual(grad_v, expected_grad)
+            self.assertEqual(fn.has_trace_for(x, y), rx or ry)
 
-    # TODO: update verify to work with GraphExecutors
-    @unittest.skip("verify needs to be updated to work with GraphExecutors")
-    def test_verify(self):
+    def test_python_ir(self):
         x = torch.tensor([0.4], requires_grad=True)
         y = torch.tensor([0.7], requires_grad=True)
 
-        @torch.jit.compile
-        def f(x, y):
-            z = torch.sigmoid(x * (x + y))
-            w = torch.abs(x * x * x + y) + Variable(torch.ones(1))
-            return z, w
+        def doit(x, y):
+            return torch.sigmoid(torch.tanh(x * (x + y)))
 
-        torch.jit.verify(f, (x, y), loss_fn=lambda z, w: z * w, devices=[])
+        trace, _ = torch.jit.get_trace_graph(doit, (x, y))
+        self.run_pass('dce', trace)
+        self.run_pass('canonicalize', trace)
+        g = trace.graph()
+        g2 = torch._C.Graph()
+        g_to_g2 = {}
+        for node in g.inputs():
+            g_to_g2[node] = g2.addInput()
+        for node in g.nodes():
+            n_ = g2.createClone(node, lambda x: g_to_g2[x])
+            g2.appendNode(n_)
+            for o, no in zip(node.outputs(), n_.outputs()):
+                g_to_g2[o] = no
 
-    @suppress_warnings
-    def test_constant(self):
-        x = torch.randn(2, 2, requires_grad=True)
+        for node in g.outputs():
+            g2.registerOutput(g_to_g2[node])
 
-        def f(x):
-            return x.matmul(torch.diag(torch.tensor([2., 2.])))
+        t_node = g2.create("prim::TensorTest").t_("a", torch.ones([2, 2]))
+        self.assertEqual(t_node.attributeNames(), ["a"])
+        g2.appendNode(t_node)
+        self.assertTrue(torch.equal(torch.ones(2, 2), t_node.t("a")))
+        self.assertExpected(str(g2))
 
-        self.checkTrace(f, (x,), (torch.ones(2, 2, requires_grad=True),))
+    @unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
+    @unittest.skipIf(not RUN_CUDA, "cpp tests require CUDA")
+    @skipIfRocm
+    def test_cpp_cuda(self):
+        # rather than rebuild assertExpected in cpp,
+        # just glob all the cpp outputs into one file for now
+        self.assertExpected(torch._C._jit_run_cpp_tests())
 
-    def test_legacy_fail(self):
-        class MyLegacyFn(Function):
-            def forward(self, x):
-                return x
+    def test_batchnorm(self):
+        x = torch.ones(2, 2, 2, 2)
+        trace, _ = torch.jit.get_trace_graph(nn.BatchNorm2d(2), x, _force_outplace=True)
+        self.assertExpectedGraph(trace)
 
-            def backward(self, grad_output):
-                return grad_output
+    def test_dropout(self):
+        x = torch.ones(2, 2)
+        trace, _ = torch.jit.get_trace_graph(nn.Dropout(0.6), x)
+        self.assertExpectedGraph(trace)
 
-        x = torch.tensor([0.], requires_grad=True)
-        with self.assertRaisesRegex(RuntimeError, "MyLegacyFn"):
-            torch.jit.get_trace_graph(lambda x: MyLegacyFn()(x), (x,))
+    def test_conv(self):
+        x = torch.ones(20, 16, 50, 40)
+        trace, _ = torch.jit.get_trace_graph(nn.Conv2d(16, 13, 3, bias=False), x)
+        self.assertExpectedGraph(trace)
 
-    def test_inplace_transplant(self):
-        x = torch.tensor([0.], requires_grad=True)
+    def test_repeated_input(self):
+        def fn(a, b):
+            return a + b
 
-        def fn(x):
-            y = x.clone()
-            y.add_(2)
-            y.add_(3)
-            return y
+        ge = self.checkTrace(fn, [torch.randn(2, 2)] * 2)
+        self.assertExpectedGraph(ge.graph)
 
-        trace, _ = torch.jit.get_trace_graph(fn, (x,))
-        self.assertExpectedGraph(trace)
-        self.assertExportImport(trace, (x,))
+    def test_repeated_output(self):
+        def fn(a, b):
+            z = a + b
+            return z, z
 
-    def test_inplace_flags(self):
-        class InplaceFn(Function):
-            @staticmethod
-            def forward(ctx, x):
-                ctx.mark_dirty(x)
-                return x.add_(1)
+        ge = self.checkTrace(fn, [torch.randn(2, 2) for _ in range(2)])
+        self.assertExpectedGraph(ge.graph)
 
-            @staticmethod
-            def backward(ctx, go):
-                return go
+    @skipIfNoTorchVision
+    def test_alexnet(self):
+        x = torch.ones(1, 3, 224, 224)
+        trace, _ = torch.jit.get_trace_graph(torchvision.models.AlexNet(), x)
+        self.run_pass('cse', trace)
+        self.assertExpectedGraph(trace)
 
-        class RegularFn(Function):
-            @staticmethod
-            def forward(ctx, x):
-                return x.add(1)
+    # Inplace copies don't work with tracer yet.
+    # This is actually somewhat important to support correctly
+    # as all backwards functions of views are implemented
+    # as a zero filled tensor with a gradient fill on the
+    # viewed portion.
+    def test_inplace_copy(self):
+        x = torch.randn(4, 4, requires_grad=True)
 
-            @staticmethod
-            def backward(ctx, go):
-                return go
+        def f(x):
+            out = Variable(torch.zeros(x.size()))
+            out.copy_(x)
+            return out
 
-        x = torch.tensor([0.], requires_grad=True)
+        trace, z = torch.jit.get_trace_graph(f, (x, ))
+        self.run_pass('dce', trace)
+        self.assertExpectedGraph(trace)
+        self.assertExportImport(trace, (x,))
 
-        def fn(x):
-            y = RegularFn.apply(x)
-            y = InplaceFn.apply(y)
-            y = InplaceFn.apply(y)
-            y = RegularFn.apply(y)
-            return y
+    def test_shared_param(self):
 
-        trace, _ = torch.jit.get_trace_graph(fn, (x,), _force_outplace=True)
-        self.run_pass('dce', trace)
-        ops = [n for n in trace.graph().nodes()]
-        for op in ops:
-            self.assertTrue(op.hasAttribute('inplace'))
-        inplace_flags = [False, True, True, False]
-        for op, is_inplace in zip(ops, inplace_flags):
-            self.assertEqual(op.i('inplace'), is_inplace)
+        class MyModule(torch.nn.Module):
+            def __init__(self):
+                super(MyModule, self).__init__()
+                self.b = self.a = nn.Parameter(torch.randn(2, 2))
 
-    def test_inplace_check(self):
-        class MyInplaceFn(Function):
-            @staticmethod
             def forward(self, x):
-                x.add_(1)
-                self.mark_dirty(x)
-                return x
+                return x * self.a + self.b
 
-            @staticmethod
-            def backward(self, grad):
-                return grad
+        m = MyModule()
+        trace, _ = torch.jit.get_trace_graph(m, (torch.randn(2, 2),))
+        self.assertEqual(len(list(trace.graph().inputs())), 2)
+        self.assertExpectedGraph(trace)
 
-        def fn(x):
-            return MyInplaceFn.apply(x)
+    def test_nested_inplace(self):
+        x = torch.randn(2, 2)
+        trace, _ = torch.jit.get_trace_graph(
+            lambda x: F.threshold(x, 0, 0, inplace=True), (x, ))
+        self.assertExpectedGraph(trace)
+        self.assertExportImport(trace, (x,))
 
-        x = torch.randn(5, 5)
-        ge = torch._C.GraphExecutor(fn, (x,), lambda var: '', _force_outplace=True)
-        with self.assertRaisesRegex(RuntimeError, 'inplace MyInplaceFn'):
-            ge(x)
+    def run_ge_tests(self, optimize, use_cuda):
+        def rand(*args):
+            t = torch.rand(*args).float()
+            if use_cuda:
+                t = t.cuda()
+            return t
+        self.checkTrace(lambda a, b: a * b + b,
+                        [rand(1), rand(1)], [rand(2, 3), rand(2, 3)],
+                        optimize=optimize)
+        # trivial identity
+        self.checkTrace(lambda a, b: (
+            b, a), [rand(1), rand(1)], optimize=optimize)
 
-    def do_trace_size(self, requires_grad):
-        def fn(x):
-            return x.view(x.shape[1] * 2, x.size(0), 2)
+        def foo(a):
+            t = a * a
+            return t * t, 4 * t
+        self.checkTrace(foo, [rand(1)], optimize=optimize)
+        # unused input
+        self.checkTrace(
+            lambda a, b: a * a, [rand(1), rand(1)], optimize=optimize,
+            allow_unused=True)
+        # test outputs that do not get used in grad
+        self.checkTrace(foo, [rand(1)], drop=1, optimize=optimize)
+        # test autograd fallback
+        self.checkTrace(lambda a, b: a * b /
+                        (a - 2 * b) + b, [rand(1), rand(1)],
+                        optimize=optimize)
 
-        x = torch.randn(5, 2, 4, requires_grad=requires_grad)
-        y = torch.randn(4, 8, 4, requires_grad=requires_grad)
+    def test_ge_unoptimized(self):
+        self.run_ge_tests(False, False)
 
-        # Check that it behaves as expected
-        traced_fn = torch.jit.trace(fn, x)
-        self.assertEqual(traced_fn(y), fn(y))
-        self.assertEqual(traced_fn(x), fn(x))
+    @unittest.skipIf(IS_WINDOWS or IS_SANDCASTLE, "NYI: fuser support for Windows or Sandcastle")
+    @enable_cpu_fuser
+    def test_ge_optimized(self):
+        self.run_ge_tests(True, False)
 
-        # Check that the trace looks ok
-        trace, _ = torch.jit.get_trace_graph(fn, (x,))
-        self.assertExpectedGraph(trace)
+    @unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
+    @unittest.skipIf(not RUN_CUDA, "requires CUDA")
+    @skipIfRocm
+    def test_ge_cuda(self):
+        self.run_ge_tests(True, True)
 
-    def test_trace_size(self):
-        self.do_trace_size(False)
+    # more manual test of graph executor that can be used as a scratchpad
+    def test_ge(self):
+        def foo(a, b):
+            return a * b / (a - b) + b
+        V = Variable
+        a, b = V(torch.rand(1)), V(torch.rand(1))
+        ge = torch._C.GraphExecutor(foo, (a, b), lambda var: '')
+        a, b = V(torch.rand(1), requires_grad=True), V(
+            torch.rand(1), requires_grad=True)
+        r, = ge(a, b)
+        da, db = torch.autograd.grad(r + 3, [a, b], create_graph=True)
 
-    # test the different graph_executor path that happens when
-    # gradients are required and sizes are involved
-    def test_trace_size_with_grad(self):
-        self.do_trace_size(True)
+        l2 = (da * db + db * db)
+        g2result = torch.autograd.grad(l2, [da, db])
 
-    def test_trace_casts(self):
-        casts = [
-            lambda x: x.byte(),
-            lambda x: x.float(),
-            lambda x: x.cpu(),
-            lambda x: x.to(device='cpu'),
-            lambda x: x.to(dtype=torch.int64),
-            lambda x: x.to(device='cpu', dtype=torch.float),
-            lambda x: x.to(x)
-        ]
+        r = foo(a, b)
+        da2, db2 = torch.autograd.grad(r + 3, [a, b], create_graph=True)
+        self.assertEqual(da, da2)
+        self.assertEqual(db, db2)
+        l3 = (da2 * db2 + db2 * db2)
+        g2result2 = torch.autograd.grad(l3, [da2, db2])
+        self.assertEqual(g2result, g2result2)
 
-        def assertContainsCast(trace):
-            self.assertEqual(sum(n.kind() == 'aten::to' for n in trace.graph.nodes()), 1)
+    def test_trace_annotation(self):
+        @_trace(torch.rand(1))
+        def foo(a):
+            return a + a + a
 
-        for cast in casts:
-            trace = torch.jit.trace(cast, torch.randn(2, 2))
-            assertContainsCast(trace)
-            x = torch.randn(2, 2)
-            self.assertEqual(trace(x), cast(x))
+        x = torch.randn(5, 5)
+        self.assertEqual(foo(x), x + x + x)
 
-        def to_tensor(x, y):
-            return x.to(y)
+    def test_trace_script(self):
+        @torch.jit.script
+        def func1(x):
+            # type: (Tuple[Tensor, Tensor]) -> Tensor
+            return x[0] + x[1]
 
-        to_tensor_trace = torch.jit.trace(to_tensor, (torch.randn(2, 2), torch.randn(1, 8)))
-        assertContainsCast(to_tensor_trace)
-        x, y = torch.randn(2, 2), torch.randn(1, 10)
-        self.assertEqual(to_tensor_trace(x, y), to_tensor(x, y))
+        @torch.jit.script
+        def func2(x):
+            # type: (List[Tensor]) -> Tensor
+            return x[0] + x[1]
 
-    def test_trace_warn(self):
-        def fn(x):
-            int(x)  # Warning 1.
-            y = x * 1
-            if y:   # Warning 2.
-                pass
-            q = [x, x * 4]
-            z = q[y]  # Warning 3.
-            float(z)  # Warning 4.
-            z.tolist()  # Warning 5.
-            z.numpy()  # Warning 6.
-            for elem in torch.ones(4, 4):  # Warning 7.
-                pass
-            return z + 4
+        a = torch.randn(5)
+        b = torch.randn(5)
 
-        with warnings.catch_warnings(record=True) as warns:
-            traced_fn = torch.jit.trace(fn, torch.tensor([1]))
-        warns = [str(w.message) for w in warns]
-        self.assertEqual(len(warns), 7)
-        self.assertIn('a Python integer', warns[0])
-        self.assertIn('a Python boolean', warns[1])
-        self.assertIn('a Python index', warns[2])
-        self.assertIn('a Python float', warns[3])
-        self.assertIn('a Python list', warns[4])
-        self.assertIn('a NumPy array', warns[5])
-        self.assertIn('Iterating over', warns[6])
+        expected = func1((a, b))
+        traced = torch.jit.trace(func1, ((a, b),))
+        result = traced((a, b))
+        self.assertEqual(expected, result)
 
-    def test_trace_tuple(self):
-        def fn(x, y):
-            return x, (x * y[1], x * y[0])
+        expected = func2((a, b))
+        traced = torch.jit.trace(func2, ((a, b),))
+        result = traced((a, b))
+        self.assertEqual(expected, result)
 
-        x, y = torch.randn(2, 2), (torch.ones(2, 2), torch.randn(2, 2))
-        traced_fn = torch.jit.trace(fn, (x, y))
-        self.assertEqual(traced_fn(x, y), fn(x, y))
-        self.assertExpectedGraph(traced_fn.graph)
-        self.assertExportImport(traced_fn.graph, (x, y))
+    def test_einsum(self):
+        def outer(x, y):
+            return torch.einsum('i,j->ij', (x, y))
 
-    def test_trace_random(self):
-        def f(mean, std):
-            return torch.normal(mean, std)
+        traced = torch.jit.trace(outer, (torch.randn(4), torch.randn(5)))
+        script = torch.jit.script(outer)
+        fns = [traced, script]
+        x, y = torch.randn(10), torch.randn(2)
+        for fn in [traced, script]:
+            self.assertGraphContains(fn.graph, kind='aten::einsum')
+            self.assertEqual(fn(x, y), outer(x, y))
 
-        traced = torch.jit.trace(f, (torch.zeros(2, 3), torch.ones(2, 3)), check_trace=False)
-        mean, std = torch.zeros(5, 5), torch.ones(5, 5)
-        with torch.random.fork_rng(devices=[]):
-            output = f(mean, std)
-        traced_output = traced(mean, std)
-        self.assertEqual(output, traced_output)
+    @unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
+    @unittest.skipIf(not RUN_CUDA, "calls .cuda()")
+    @skipIfRocm
+    def test_traced_module_cuda(self):
+        class Model(nn.Module):
+            def __init__(self, num_features, num_layers):
+                super(Model, self).__init__()
+                self.num_layers = num_layers
+                layers = [[nn.Linear(num_features, num_features), nn.Sigmoid()]
+                          for _ in range(num_layers)]
+                self.submodule = nn.Sequential(*chain(*layers))
 
-    def test_trace_tensor_factory(self):
-        def run(**kwargs):
-            inputs_require_grads = kwargs.pop('inputs_require_grads', True)
+            def forward(self, x):
+                for i in range(self.num_layers):
+                    x = self.submodule[i](x) + x
+                return x
 
-            def fn(x):
-                return x + torch.ones(2, 3, **kwargs)
+        model = Model(5, 3)
+        x = torch.randn(2, 5)
+        traced_model = torch.jit.trace(model, x)
 
-            input_kwargs = kwargs.copy()
-            if 'out' in input_kwargs:
-                del input_kwargs['out']
-            input = torch.ones(2, 3, **input_kwargs)
-            self.checkTrace(fn, (input,), inputs_require_grads=inputs_require_grads)
-            # check we recorded 'ones' and did not just record a constant
-            tfn = torch.jit.trace(fn, input)
-            self.assertTrue("ones" in str(tfn.graph))
-        run()
-        run(dtype=torch.int, inputs_require_grads=False)
-        run(out=torch.tensor([]))
-        if RUN_CUDA:
-            run(device="cuda:0")
-        if RUN_CUDA_MULTI_GPU:
-            run(device="cuda:1")
+        # We're missing some attributes these modules had initially. Make sure we can
+        # still get the __repr__()
+        model.__repr__()
 
-    # TODO: implement
-    @unittest.expectedFailure
-    def test_output_unflatten(self):
-        """Check that outputs of traced functions retain the original structure and nesting"""
-        def fn(x):
-            return (x * 2, (x ** 2, x + 4, (x + 2,), ), x * 4)
+        # XXX: indexing sequentials is broken
+        linear_submodule = next(iter(traced_model.submodule._modules.values()))
 
-        self.checkTrace(fn, (torch.randn(2, 2),))
+        # All attributes that aren't parameters should raise
+        with self.assertRaises(AttributeError):
+            linear_submodule.in_features
+        linear_submodule.weight
+        with self.assertRaises(RuntimeError):
+            traced_model.asdf = 4
+        linear_submodule.weight = nn.Parameter(torch.randn(linear_submodule.weight.shape))
+        with self.assertRaises(RuntimeError):
+            del linear_submodule.weight
 
-    # TODO: implement
-    @unittest.expectedFailure
-    def test_input_flatten(self):
-        """Check that inputs to traced functions are flattened"""
+        # Submodules can't be called
+        with self.assertRaises(RuntimeError):
+            linear_submodule(x)
 
-        def fn(x, t):
-            y, z = t
-            return x * y * z
+        # Type casts
+        linear_submodule.cuda()
+        traced_model.float().cuda()
+        cuda_out = traced_model(x.float().cuda())
+        traced_model.cpu()
+        cpu_out = traced_model(x.float())
+        self.assertEqual(cpu_out, cuda_out)
+        traced_model.double()
 
-        inputs = (torch.randn(1), (torch.randn(1), torch.randn(1)))
-        self.checkTrace(fn, inputs)
+        # state_dict + load_state_dict
+        state = {k: v.clone() for k, v in traced_model.state_dict().items()}
+        new_state = {k: v.clone().fill_(1) for k, v in state.items()}
+        out = traced_model(x)
+        traced_model.load_state_dict(new_state)
+        out_ones = traced_model(x)
+        traced_model.load_state_dict(state)
+        out_state = traced_model(x)
+        self.assertEqual(out, out_state)
+        self.assertNotEqual(out, out_ones)
 
-    # TODO: adapt to a GraphExecutor test
-    @unittest.skip("Need to instrument GraphExecutors a bit more")
-    def test_flags(self):
-        x, y = torch.randn(2, 2)
-        y = Variable(torch.randn(2, 2))
+    def test_python_function(self):
+        class MyFn(Function):
+            @staticmethod
+            def forward(ctx, x):
+                return x + 1
 
-        @torch.jit.compile
-        def fn(x, y):
-            return (x * x + y * y + x * y).sum()
+            @staticmethod
+            def backward(ctx, grad_output):
+                return grad_output
 
-        grads = {}
-        for rx, ry in product((True, False), repeat=2):
-            x.requires_grad = rx
-            y.requires_grad = ry
+        @_trace(torch.zeros(2))
+        def fn(x):
+            return MyFn.apply(x + 2) + 3
 
-            self.assertFalse(fn.has_trace_for(x, y))
-            out = fn(x, y)
+        x = torch.tensor([1., 2., 3.])
+        y = torch.randn(2, 2, requires_grad=True)
+        fn(x)
+        fn(y)
 
-            self.assertFalse(fn.has_trace_for(x, y))
-            for v, name, compute in [(x, 'x', rx), (y, 'y', ry)]:
-                if not compute:
-                    continue
-                grad_v, = torch.autograd.grad(out, v, retain_graph=True)
-                expected_grad = grads.setdefault(name, grad_v)
-                self.assertEqual(grad_v, expected_grad)
-            self.assertEqual(fn.has_trace_for(x, y), rx or ry)
+    def test_python_function_tup(self):
+        class MyFn(Function):
+            @staticmethod
+            def forward(ctx, x):
+                return x + 1, x - 1
 
-    def test_python_ir(self):
-        x = torch.tensor([0.4], requires_grad=True)
-        y = torch.tensor([0.7], requires_grad=True)
+            @staticmethod
+            def backward(ctx, grad_output):
+                return grad_output, grad_output
 
-        def doit(x, y):
-            return torch.sigmoid(torch.tanh(x * (x + y)))
+        @_trace(torch.zeros(2))
+        def fn(x):
+            a, b = MyFn.apply(x + 2)
+            return a + b + 3
+        x = torch.tensor([1., 2., 3.])
+        y = torch.randn(2, 2, requires_grad=True)
+        fn(x)
+        fn(y)
 
-        trace, _ = torch.jit.get_trace_graph(doit, (x, y))
-        self.run_pass('dce', trace)
-        self.run_pass('canonicalize', trace)
-        g = trace.graph()
-        g2 = torch._C.Graph()
-        g_to_g2 = {}
-        for node in g.inputs():
-            g_to_g2[node] = g2.addInput()
-        for node in g.nodes():
-            n_ = g2.createClone(node, lambda x: g_to_g2[x])
-            g2.appendNode(n_)
-            for o, no in zip(node.outputs(), n_.outputs()):
-                g_to_g2[o] = no
+    def test_decompose_addmm(self):
+        @torch.jit.script
+        def addmm(mat, mat1, mat2, alpha, beta):
+            a = mat.addmm(mat1, mat2)
+            b = mat.addmm(mat1, mat2, alpha=1.0, beta=1.0)
+            c = mat.addmm(mat1, mat2, alpha=4.20, beta=2.0)
+            d = mat.addmm(mat1, mat2, alpha=int(alpha), beta=int(beta))
 
-        for node in g.outputs():
-            g2.registerOutput(g_to_g2[node])
+            return a + b + c + d
 
-        t_node = g2.create("prim::TensorTest").t_("a", torch.ones([2, 2]))
-        self.assertEqual(t_node.attributeNames(), ["a"])
-        g2.appendNode(t_node)
-        self.assertTrue(torch.equal(torch.ones(2, 2), t_node.t("a")))
-        self.assertExpected(str(g2))
+        mat = torch.randn(2, 2)
+        mat1 = torch.randn(2, 4)
+        mat2 = torch.randn(4, 2)
+        alpha = torch.FloatTensor([123.0])
+        beta = torch.FloatTensor([321.0])
 
-    @unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
-    @unittest.skipIf(not RUN_CUDA, "cpp tests require CUDA")
-    @skipIfRocm
-    def test_cpp_cuda(self):
-        # rather than rebuild assertExpected in cpp,
-        # just glob all the cpp outputs into one file for now
-        self.assertExpected(torch._C._jit_run_cpp_tests())
+        out_ref = addmm(mat, mat1, mat2, alpha, beta)
+        self.run_pass('canonicalize_ops', addmm.graph)
+        out_test = addmm(mat, mat1, mat2, alpha, beta)
+        self.assertEqual(out_ref, out_test)
+        self.assertExpected(canonical(addmm.graph))
 
-    def test_batchnorm(self):
-        x = torch.ones(2, 2, 2, 2)
-        trace, _ = torch.jit.get_trace_graph(nn.BatchNorm2d(2), x, _force_outplace=True)
-        self.assertExpectedGraph(trace)
+    def test_index_put(self):
+        ten = torch.zeros(3, 3)
+        mask = torch.Tensor([[True, True, True],
+                             [True, False, False],
+                             [True, True, False]]).byte()
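+        # the mask selects exactly six elements, matching the torch.ones(6) assigned below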
 
-    def test_dropout(self):
-        x = torch.ones(2, 2)
-        trace, _ = torch.jit.get_trace_graph(nn.Dropout(0.6), x)
-        self.assertExpectedGraph(trace)
+        def test_fn(ten, mask):
+            ten[mask] = torch.ones(6)
+            return ten
 
-    def test_conv(self):
-        x = torch.ones(20, 16, 50, 40)
-        trace, _ = torch.jit.get_trace_graph(nn.Conv2d(16, 13, 3, bias=False), x)
-        self.assertExpectedGraph(trace)
+        traced_test_fn = torch.jit.trace(test_fn, (ten, mask))
 
-    def test_repeated_input(self):
-        def fn(a, b):
-            return a + b
+        ten = torch.rand(3, 3)
+        self.assertEqual(test_fn(ten, mask), traced_test_fn(ten, mask))
 
-        ge = self.checkTrace(fn, [torch.randn(2, 2)] * 2)
-        self.assertExpectedGraph(ge.graph)
-
-    def test_repeated_output(self):
-        def fn(a, b):
-            z = a + b
-            return z, z
+    def test_sparse_tensors_error(self):
+        def get_sparse():
+            return torch.sparse.FloatTensor(2, 3)
 
-        ge = self.checkTrace(fn, [torch.randn(2, 2) for _ in range(2)])
-        self.assertExpectedGraph(ge.graph)
+        @torch.jit.script
+        def sparse(input):
+            output = get_sparse()
+            return output, input
 
-    @skipIfNoTorchVision
-    def test_alexnet(self):
-        x = torch.ones(1, 3, 224, 224)
-        trace, _ = torch.jit.get_trace_graph(torchvision.models.AlexNet(), x)
-        self.run_pass('cse', trace)
-        self.assertExpectedGraph(trace)
+        with self.assertRaisesRegex(RuntimeError, "sparse tensors not supported"):
+            sparse(get_sparse())
 
-    # Inplace copies don't work with tracer yet.
-    # This is actually somewhat important to support correctly
-    # as all backwards functions of views are implemented
-    # as a zero filled tensor with a gradient fill on the
-    # viewed portion.
-    def test_inplace_copy(self):
-        x = torch.randn(4, 4, requires_grad=True)
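+        # a dense input also fails: get_sparse() inside the script fn still produces a sparse tensor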
+        with self.assertRaisesRegex(RuntimeError, "sparse tensors not supported"):
+            sparse(torch.tensor([1]))
 
-        def f(x):
-            out = Variable(torch.zeros(x.size()))
-            out.copy_(x)
-            return out
+    def test_tuple_specialization(self):
+        @torch.jit.script
+        def f(t):
+            # type: (Tuple[Tensor, Tensor]) -> Tensor
+            x, y = t
+            return x + y
 
-        trace, z = torch.jit.get_trace_graph(f, (x, ))
-        self.run_pass('dce', trace)
-        self.assertExpectedGraph(trace)
-        self.assertExportImport(trace, (x,))
+        t = torch.randn(2, 2), torch.randn(2, 2)
+        f(t)
+        graph = f.graph_for(t)
+        input_types = list(next(graph.inputs()).type().elements())
+        for t in input_types:
+            self.assertEqual(t.kind(), 'TensorType')
 
-    def test_shared_param(self):
+    def test_constant_prop_simple(self):
+        @torch.jit.script
+        def constant_prop(input_tensor):
+            a = 2 * 3
+            b = a + 2
+            return b + input_tensor
 
-        class MyModule(torch.nn.Module):
-            def __init__(self):
-                super(MyModule, self).__init__()
-                self.b = self.a = nn.Parameter(torch.randn(2, 2))
+        x = torch.tensor(2)
+        out_ref = constant_prop(x)
+        self.run_pass('constant_propagation', constant_prop.graph)
+        out_test = constant_prop(torch.tensor(2))
+        self.assertEqual(out_ref, out_test)
+        self.assertExpected(canonical(constant_prop.graph))
 
-            def forward(self, x):
-                return x * self.a + self.b
+    def test_constant_prop_nested(self):
+        @torch.jit.script
+        def constant_prop(a):
+            b = 2 + 1
+            if bool(a < 2):
+                c = b + 2
+            else:
+                c = b - 2
+            return c
+        out_ref = constant_prop(torch.tensor(2))
+        self.run_pass('constant_propagation', constant_prop.graph)
+        out_test = constant_prop(torch.tensor(2))
+        self.assertEqual(out_ref, out_test)
+        self.assertExpected(canonical(constant_prop.graph))
 
-        m = MyModule()
-        trace, _ = torch.jit.get_trace_graph(m, (torch.randn(2, 2),))
-        self.assertEqual(len(list(trace.graph().inputs())), 2)
-        self.assertExpectedGraph(trace)
+    def test_constant_prop_print(self):
+        @torch.jit.script
+        def constant_prop(input_tensor):
+            a = 2 * 3
+            print(a)
+            b = a + 2
+            return b + input_tensor
 
-    def test_nested_inplace(self):
-        x = torch.randn(2, 2)
-        trace, _ = torch.jit.get_trace_graph(
-            lambda x: F.threshold(x, 0, 0, inplace=True), (x, ))
-        self.assertExpectedGraph(trace)
-        self.assertExportImport(trace, (x,))
+        self.run_pass('constant_propagation', constant_prop.graph)
+        self.assertExpected(canonical(constant_prop.graph))
 
-    def run_ge_tests(self, optimize, use_cuda):
-        def rand(*args):
-            t = torch.rand(*args).float()
-            if use_cuda:
-                t = t.cuda()
-            return t
-        self.checkTrace(lambda a, b: a * b + b,
-                        [rand(1), rand(1)], [rand(2, 3), rand(2, 3)],
-                        optimize=optimize)
-        # trivial identity
-        self.checkTrace(lambda a, b: (
-            b, a), [rand(1), rand(1)], optimize=optimize)
+    def test_constant_prop_rand(self):
+        @torch.jit.script
+        def constant_prop():
+            a = torch.randn([3])
+            b = a + 2
+            return b
 
-        def foo(a):
-            t = a * a
-            return t * t, 4 * t
-        self.checkTrace(foo, [rand(1)], optimize=optimize)
-        # unused input
-        self.checkTrace(
-            lambda a, b: a * a, [rand(1), rand(1)], optimize=optimize,
-            allow_unused=True)
-        # test outputs that do not get used in grad
-        self.checkTrace(foo, [rand(1)], drop=1, optimize=optimize)
-        # test autograd fallback
-        self.checkTrace(lambda a, b: a * b /
-                        (a - 2 * b) + b, [rand(1), rand(1)],
-                        optimize=optimize)
+        self.run_pass('constant_propagation', constant_prop.graph)
+        self.assertExpected(canonical(constant_prop.graph))
 
-    def test_ge_unoptimized(self):
-        self.run_ge_tests(False, False)
+    def test_trace_records_names(self):
+        def foo(bar, baz):
+            baz = bar + 3
+            quick_brown_fox = torch.neg(baz)
+            for i in range(20):
+                yeet = quick_brown_fox - 3.14
+            return yeet
 
-    def _test_fused_abs(self, device='cpu'):
+        traced = torch.jit.trace(foo, (torch.rand(3, 3), torch.rand(3, 3)))
+        graph_str = str(traced.graph)
+        assert 'bar' in graph_str
+        assert 'baz' in graph_str
+        assert 'quick_brown_fox' in graph_str
 
+    def test_constant_prop_if_constant(self):
         @torch.jit.script
-        def func(x):
-            return x.abs() * 2
-
-        a = torch.randn(5, device=device)
-        self.assertEqual(func(a), a.abs() * 2)
-        self.assertAllFused(func.graph_for(a))
+        def constant_prop(a, b):
+            c0 = 1
+            c1 = 1
+            c2 = 1
+            if bool(a):  # -> c0, c1
+                if bool(b):  # -> c0
+                    if True:  # -> c0
+                        c0 = c0 + 1
+                        if False:
+                            c1 = c1 + 1
+                            c2 = c2 + 1
+            else:  # -> c0, c1
+                c1 = c1 + 1
 
-    @unittest.skipIf(IS_WINDOWS or IS_SANDCASTLE, "NYI: fuser support for Windows or Sandcastle")
-    @enable_cpu_fuser
-    def test_fused_abs_cpu(self):
-        self._test_fused_abs()
+            if True:  # inlined
+                c0 = c0 + 1  # dynamic
+                c2 = c2 + 4  # set to 5
+            return a + c0 + c1 + c2
 
-    @unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
-    @unittest.skipIf(not RUN_CUDA, "requires CUDA")
-    @skipIfRocm
-    def test_fused_abs_cuda(self):
-        self._test_fused_abs(device="cuda")
+        self.run_pass('constant_propagation', constant_prop.graph)
+        self.assertExpected(canonical(constant_prop.graph))
 
-    @unittest.skipIf(IS_WINDOWS or IS_SANDCASTLE, "NYI: fuser support for Windows or Sandcastle")
-    @enable_cpu_fuser
-    def test_ge_optimized(self):
-        self.run_ge_tests(True, False)
+    def test_constant_prop_loop_constant(self):
+        @torch.jit.script
+        def constant_prop():
+            b = 0
+            while True:
+                b = 1
+            while False:
+                b = 2
+            return b
 
-    @unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
-    @unittest.skipIf(not RUN_CUDA, "requires CUDA")
-    @skipIfRocm
-    def test_ge_cuda(self):
-        self.run_ge_tests(True, True)
+        self.run_pass('constant_propagation', constant_prop.graph)
+        self.assertExpected(canonical(constant_prop.graph))
 
-    @unittest.skipIf(IS_WINDOWS or IS_SANDCASTLE, "NYI: fuser support for Windows or Sandcastle")
-    @enable_cpu_fuser
-    def test_fused_where_and_typing(self):
-        def f(x, y):
-            mask = x > y
-            res = torch.where(mask, x, y)
-            return mask, res
+    def test_trace_detach(self):
+        def foo(x, w):
+            return torch.matmul(x, w).detach()
 
-        script_f = torch.jit.script(f)
+        traced = torch.jit.trace(foo, (torch.rand(3, 4), torch.rand(4, 5)))
 
-        x = torch.randn(4, 4, dtype=torch.double)
-        y = torch.randn(4, 4, dtype=torch.double)
+        self.assertExpectedGraph(traced.graph)
+        x, w = torch.rand(3, 4), torch.rand(4, 5, requires_grad=True)
+        traced_result = traced(x, w)
+        self.assertEqual(foo(x, w), traced_result)
+        self.assertFalse(traced_result.requires_grad)
+        self.assertIsNone(traced_result.grad_fn)
 
-        result1, result2 = script_f(x, y)
-        expected1, expected2 = f(x, y)
-        self.assertEqual(result1, expected1)
-        self.assertEqual(result2, expected2)
-        self.assertAllFused(script_f.graph_for(x, y))
+    def test_trace_detach_inplace(self):
+        def foo(x, w):
+            y = torch.matmul(x, w)
+            y.detach_()
+            return y
 
-    # more manual test of graph executor that can be used as a scratchpad
-    def test_ge(self):
-        def foo(a, b):
-            return a * b / (a - b) + b
-        V = Variable
-        a, b = V(torch.rand(1)), V(torch.rand(1))
-        ge = torch._C.GraphExecutor(foo, (a, b), lambda var: '')
-        a, b = V(torch.rand(1), requires_grad=True), V(
-            torch.rand(1), requires_grad=True)
-        r, = ge(a, b)
-        da, db = torch.autograd.grad(r + 3, [a, b], create_graph=True)
+        traced = torch.jit.trace(foo, (torch.rand(3, 4), torch.rand(4, 5)))
 
-        l2 = (da * db + db * db)
-        g2result = torch.autograd.grad(l2, [da, db])
+        self.assertExpectedGraph(traced.graph)
+        x, w = torch.rand(3, 4), torch.rand(4, 5)
+        traced_result = traced(x, w)
+        self.assertEqual(foo(x, w), traced_result)
+        self.assertFalse(traced_result.requires_grad)
+        self.assertIsNone(traced_result.grad_fn)
 
-        r = foo(a, b)
-        da2, db2 = torch.autograd.grad(r + 3, [a, b], create_graph=True)
-        self.assertEqual(da, da2)
-        self.assertEqual(db, db2)
-        l3 = (da2 * db2 + db2 * db2)
-        g2result2 = torch.autograd.grad(l3, [da2, db2])
-        self.assertEqual(g2result, g2result2)
+    def test_trace_detach_onnx_erase(self):
+        class Mod(torch.nn.Module):
+            def forward(self, x, w):
+                return torch.matmul(x, w).detach()
 
-    def test_trace_annotation(self):
-        @_trace(torch.rand(1))
-        def foo(a):
-            return a + a + a
+        f = io.BytesIO()
+        self.assertExpected(torch.onnx.export_to_pretty_string(
+            Mod(), (torch.rand(3, 4), torch.rand(4, 5)), f))
 
-        x = torch.randn(5, 5)
-        self.assertEqual(foo(x), x + x + x)
+    def test_trace_slice_full_dim(self):
+        def foo(x):
+            return x[0:5, 0] + 1.0
 
-    def test_trace_script(self):
-        @torch.jit.script
-        def func1(x):
-            # type: (Tuple[Tensor, Tensor]) -> Tensor
-            return x[0] + x[1]
+        traced = torch.jit.trace(foo, (torch.rand(5, 4),))
+        test_x = torch.rand(6, 3)
+        self.assertEqual(foo(test_x), traced(test_x))
 
-        @torch.jit.script
-        def func2(x):
-            # type: (List[Tensor]) -> Tensor
-            return x[0] + x[1]
+    def test_export_expand_aten_fallback(self):
+        class ExpandTest(torch.jit.ScriptModule):
+            @torch.jit.script_method
+            def forward(self, x):
+                y = x
+                for i in range(5):
+                    y = x.expand([3, 4, i])
+                return y
 
-        a = torch.randn(5)
-        b = torch.randn(5)
+        mod = ExpandTest()
+        example_outs = mod(torch.rand(3, 4, 1))
+        f = io.BytesIO()
+        with self.assertRaisesRegex(RuntimeError, 'Could not export a broadcasted operation'):
+            torch.onnx.export_to_pretty_string(mod, (torch.rand(3, 4, 1),), f, verbose=False,
+                                               example_outputs=example_outs)
 
-        expected = func1((a, b))
-        traced = torch.jit.trace(func1, ((a, b),))
-        result = traced((a, b))
-        self.assertEqual(expected, result)
+        self.assertExpected(
+            torch.onnx.export_to_pretty_string(mod, (torch.rand(3, 4, 1),), f, verbose=False,
+                                               example_outputs=example_outs,
+                                               operator_export_type=OperatorExportTypes.ONNX_ATEN_FALLBACK))
 
-        expected = func2((a, b))
-        traced = torch.jit.trace(func2, ((a, b),))
-        result = traced((a, b))
-        self.assertEqual(expected, result)
+    def test_export_dropout(self):
+        test = torch.nn.Dropout()
+        test.eval()
 
-    def test_einsum(self):
-        def outer(x, y):
-            return torch.einsum('i,j->ij', (x, y))
+        traced = torch.jit.trace(test, (torch.rand(3, 4),), check_trace=False)
+        imported = self.getExportImportCopy(traced)
+        x = torch.randn(3, 4)
+        self.assertEqual(traced(x), imported(x))
 
-        traced = torch.jit.trace(outer, (torch.randn(4), torch.randn(5)))
-        script = torch.jit.script(outer)
-        fns = [traced, script]
-        x, y = torch.randn(10), torch.randn(2)
-        for fn in [traced, script]:
-            self.assertGraphContains(fn.graph, kind='aten::einsum')
-            self.assertEqual(fn(x, y), outer(x, y))
+    @unittest.skipIf(not RUN_CUDA, "requires CUDA")
+    def test_cuda_export_restore(self):
+        class Sub(torch.jit.ScriptModule):
+            def __init__(self):
+                super(Sub, self).__init__()
+                self.weight = nn.Parameter(torch.randn(3, 4))
 
-    @unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
-    @unittest.skipIf(not RUN_CUDA, "calls .cuda()")
-    @skipIfRocm
-    def test_traced_module_cuda(self):
-        class Model(nn.Module):
-            def __init__(self, num_features, num_layers):
-                super(Model, self).__init__()
-                self.num_layers = num_layers
-                layers = [[nn.Linear(num_features, num_features), nn.Sigmoid()]
-                          for _ in range(num_layers)]
-                self.submodule = nn.Sequential(*chain(*layers))
+            @torch.jit.script_method
+            def forward(self, thing):
+                return self.weight + thing
 
-            def forward(self, x):
-                for i in range(self.num_layers):
-                    x = self.submodule[i](x) + x
-                return x
+        class M(torch.jit.ScriptModule):
+            def __init__(self):
+                super(M, self).__init__()
+                self.mod = Sub()
 
-        model = Model(5, 3)
-        x = torch.randn(2, 5)
-        traced_model = torch.jit.trace(model, x)
+            @torch.jit.script_method
+            def forward(self, v):
+                return self.mod(v)
+        m = M()
+        m.cuda()
+        m2 = self.getExportImportCopy(m)
+        m2.cuda()
+        input = torch.rand(3, 4).cuda()
+        self.assertEqual(m(input), m2(input))
 
-        # We're missing some attributes these modules had initially. Make sure we can
-        # still get the __repr__()
-        model.__repr__()
+    def test_export_batchnorm(self):
+        for mode in ['eval', 'train']:
+            for clazz in [
+                    torch.nn.BatchNorm1d(100),
+                    torch.nn.BatchNorm1d(100, affine=False),
+                    torch.nn.BatchNorm2d(100),
+                    torch.nn.BatchNorm2d(100, affine=False)]:
+                getattr(clazz, mode)()
 
-        # XXX: indexing sequentials is broken
-        linear_submodule = next(iter(traced_model.submodule._modules.values()))
+                input = torch.randn(20, 100) if isinstance(clazz, torch.nn.BatchNorm1d) else \
+                    torch.randn(20, 100, 35, 45)
 
-        # All attributes that aren't parameters should raise
-        with self.assertRaises(AttributeError):
-            linear_submodule.in_features
-        linear_submodule.weight
-        with self.assertRaises(RuntimeError):
-            traced_model.asdf = 4
-        linear_submodule.weight = nn.Parameter(torch.randn(linear_submodule.weight.shape))
-        with self.assertRaises(RuntimeError):
-            del linear_submodule.weight
+                traced = torch.jit.trace(clazz, (input,))
+                imported = self.getExportImportCopy(traced)
+                x = torch.randn(20, 100) if isinstance(clazz, torch.nn.BatchNorm1d) else \
+                    torch.randn(20, 100, 35, 45)
+                self.assertEqual(traced(x), imported(x))
 
-        # Submodules can't be called
-        with self.assertRaises(RuntimeError):
-            linear_submodule(x)
+    def test_export_rnn(self):
+        for clazz in [nn.RNN(10, 20, 2), nn.GRU(10, 20, 2)]:
+            class RNNTest(torch.nn.Module):
+                def __init__(self):
+                    super(RNNTest, self).__init__()
+                    self.rnn = clazz
 
-        # Type casts
-        linear_submodule.cuda()
-        traced_model.float().cuda()
-        cuda_out = traced_model(x.float().cuda())
-        traced_model.cpu()
-        cpu_out = traced_model(x.float())
-        self.assertEqual(cpu_out, cuda_out)
-        traced_model.double()
+                def forward(self, x, lengths, h0):
+                    packed = torch.nn.utils.rnn.pack_padded_sequence(x, lengths)
+                    out, h = self.rnn(packed, h0)
+                    padded_outs, _ = torch.nn.utils.rnn.pad_packed_sequence(out)
+                    return padded_outs
 
-        # state_dict + load_state_dict
-        state = {k: v.clone() for k, v in traced_model.state_dict().items()}
-        new_state = {k: v.clone().fill_(1) for k, v in state.items()}
-        out = traced_model(x)
-        traced_model.load_state_dict(new_state)
-        out_ones = traced_model(x)
-        traced_model.load_state_dict(state)
-        out_state = traced_model(x)
-        self.assertEqual(out, out_state)
-        self.assertNotEqual(out, out_ones)
+            test = RNNTest()
 
-    def test_python_function(self):
-        class MyFn(Function):
-            @staticmethod
-            def forward(ctx, x):
-                return x + 1
+            traced = torch.jit.trace(test, (torch.randn(5, 3, 10), torch.LongTensor([3, 2, 1]), torch.randn(2, 3, 20)))
+            imported = self.getExportImportCopy(traced)
+            # NB: We make sure to pass in a batch with a different max sequence
+            # length to ensure that the argument stashing for pad_packed works
+            # properly.
+            x, lengths, h0 = torch.randn(7, 4, 10), torch.LongTensor([7, 3, 2, 1]), torch.randn(2, 4, 20)
+            self.assertEqual(traced(x, lengths, h0), imported(x, lengths, h0))
 
-            @staticmethod
-            def backward(ctx, grad_output):
-                return grad_output
+    def test_export_lstm(self):
+        class LSTMTest(torch.nn.Module):
+            def __init__(self):
+                super(LSTMTest, self).__init__()
+                self.rnn = nn.LSTM(10, 20, 2)
 
-        @_trace(torch.zeros(2))
-        def fn(x):
-            return MyFn.apply(x + 2) + 3
+            def forward(self, x, lengths, hiddens):
+                h0, c0 = hiddens
+                packed = torch.nn.utils.rnn.pack_padded_sequence(x, lengths)
+                out, (h, c) = self.rnn(packed, (h0, c0))
+                padded_outs, _ = torch.nn.utils.rnn.pad_packed_sequence(out)
+                return padded_outs
 
-        x = torch.tensor([1., 2., 3.])
-        y = torch.randn(2, 2, requires_grad=True)
-        fn(x)
-        fn(y)
+        test = LSTMTest()
 
-    def test_python_function_tup(self):
-        class MyFn(Function):
-            @staticmethod
-            def forward(ctx, x):
-                return x + 1, x - 1
+        traced = torch.jit.trace(test, (torch.randn(5, 3, 10),
+                                        torch.LongTensor([3, 2, 1]),
+                                        (torch.randn(2, 3, 20), torch.randn(2, 3, 20))))
+        imported = self.getExportImportCopy(traced)
+        x, lengths, h0, c0 = \
+            torch.randn(7, 3, 10), torch.LongTensor([7, 5, 2]), torch.randn(2, 3, 20), torch.randn(2, 3, 20)
+        self.assertEqual(traced(x, lengths, (h0, c0)), imported(x, lengths, (h0, c0)))
 
-            @staticmethod
-            def backward(ctx, grad_output):
-                return grad_output, grad_output
+    def test_trace_variable_instantiation(self):
+        def random_foo(x):
+            return Variable(Variable(x) + 1.0)
 
-        @_trace(torch.zeros(2))
-        def fn(x):
-            a, b = MyFn.apply(x + 2)
-            return a + b + 3
-        x = torch.tensor([1., 2., 3.])
-        y = torch.randn(2, 2, requires_grad=True)
-        fn(x)
-        fn(y)
+        random_foo_traced = torch.jit.trace(random_foo, (torch.rand(3, 4),))
 
-    def test_decompose_addmm(self):
-        @torch.jit.script
-        def addmm(mat, mat1, mat2, alpha, beta):
-            a = mat.addmm(mat1, mat2)
-            b = mat.addmm(mat1, mat2, alpha=1.0, beta=1.0)
-            c = mat.addmm(mat1, mat2, alpha=4.20, beta=2.0)
-            d = mat.addmm(mat1, mat2, alpha=int(alpha), beta=int(beta))
+        x = torch.rand(5, 6)
+        self.assertEqual(random_foo(x), random_foo_traced(x))
 
-            return a + b + c + d
+    def test_trace_slice_expr_complete_type(self):
+        def random_foo(x):
+            return x + 1.0
 
-        mat = torch.randn(2, 2)
-        mat1 = torch.randn(2, 4)
-        mat2 = torch.randn(4, 2)
-        alpha = torch.FloatTensor([123.0])
-        beta = torch.FloatTensor([321.0])
+        random_foo_traced = torch.jit.trace(random_foo, (torch.rand(3, 4),))
 
-        out_ref = addmm(mat, mat1, mat2, alpha, beta)
-        self.run_pass('canonicalize_ops', addmm.graph)
-        out_test = addmm(mat, mat1, mat2, alpha, beta)
-        self.assertEqual(out_ref, out_test)
-        self.assertExpected(canonical(addmm.graph))
+        @torch.jit.script
+        def random_bar(x):
+            return random_foo_traced(x)[0:1]
 
-    def test_index_put(self):
-        ten = torch.zeros(3, 3)
-        mask = torch.Tensor([[True, True, True],
-                             [True, False, False],
-                             [True, True, False]]).byte()
+        x = torch.rand(3, 4)
+        self.assertEqual(random_bar(x), (x + 1)[0:1])
 
-        def test_fn(ten, mask):
-            ten[mask] = torch.ones(6)
-            return ten
+    def test_export_tensoroption_to(self):
+        def foo(x):
+            return x.new_tensor(x[0]).cpu() + x
 
-        traced_test_fn = torch.jit.trace(test_fn, (ten, mask))
+        traced = torch.jit.trace(foo, (torch.rand([2])))
+        example_outputs = traced(torch.rand([2]))
 
-        ten = torch.rand(3, 3)
-        self.assertEqual(test_fn(ten, mask), traced_test_fn(ten, mask))
+        f = io.BytesIO()
+        self.assertExpected(torch.onnx._export_to_pretty_string(traced, (torch.rand([2]),), f,
+                                                                example_outputs=example_outputs))
 
-    def test_sparse_tensors_error(self):
-        def get_sparse():
-            return torch.sparse.FloatTensor(2, 3)
+    def test_pretty_printer(self):
+        @torch.jit.script
+        def if_test(a, b):
+            # FIXME: use 0 instead of a.
+            # c = 0
+            c = a
+            if bool(a < b):
+                c = b
+            else:
+                c = a
+            return c
 
         @torch.jit.script
-        def sparse(input):
-            output = get_sparse()
-            return output, input
+        def if_one(a, b):
+            c = b
+            if bool(a < b):
+                c = a
+            return c
 
-        with self.assertRaisesRegex(RuntimeError, "sparse tensors not supported"):
-            sparse(get_sparse())
+        @torch.jit.script
+        def while_test(a, i):
+            while bool(i < 3):
+                a *= a
+                i += 1
+            return a
 
-        with self.assertRaisesRegex(RuntimeError, "sparse tensors not supported"):
-            sparse(torch.tensor([1]))
+        @torch.jit.script
+        def while_if_test(a, b):
+            c = 0
+            while bool(a < 10):
+                a = a + 1
+                b = b + 1
+                if bool(a > b):
+                    c = 2
+                else:
+                    c = 3
+            return a + 1 + c
 
-    def test_tuple_specialization(self):
         @torch.jit.script
-        def f(t):
-            # type: (Tuple[Tensor, Tensor]) -> Tensor
-            x, y = t
-            return x + y
+        def loop_use_test(y):
+            x = y + 1
+            z = x + 5
+            while bool(y < 8):
+                y += 1
+                z = x
+            return x, z
 
-        t = torch.randn(2, 2), torch.randn(2, 2)
-        f(t)
-        graph = f.graph_for(t)
-        input_types = list(next(graph.inputs()).type().elements())
-        for t in input_types:
-            self.assertEqual(t.kind(), 'TensorType')
+        def python_fn(x):
+            return x + 10
 
-    def test_constant_prop_simple(self):
         @torch.jit.script
-        def constant_prop(input_tensor):
-            a = 2 * 3
-            b = a + 2
-            return b + input_tensor
+        def python_op_name_test(y):
+            return python_fn(y)
 
-        x = torch.tensor(2)
-        out_ref = constant_prop(x)
-        self.run_pass('constant_propagation', constant_prop.graph)
-        out_test = constant_prop(torch.tensor(2))
-        self.assertEqual(out_ref, out_test)
-        self.assertExpected(canonical(constant_prop.graph))
+        @torch.jit.script
+        def empty_int_list_test(y):
+            x = torch.jit.annotate(List[int], [])
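+            # indexing the empty list is fine here; these graphs are only pretty-printed, never run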
+            return x[0]
 
-    def test_constant_prop_nested(self):
         @torch.jit.script
-        def constant_prop(a):
-            b = 2 + 1
-            if bool(a < 2):
-                c = b + 2
-            else:
-                c = b - 2
-            return c
-        out_ref = constant_prop(torch.tensor(2))
-        self.run_pass('constant_propagation', constant_prop.graph)
-        out_test = constant_prop(torch.tensor(2))
-        self.assertEqual(out_ref, out_test)
-        self.assertExpected(canonical(constant_prop.graph))
+        def empty_float_list_test(y):
+            return [1.0, 2.0, 3.0]
 
-    def test_constant_prop_print(self):
         @torch.jit.script
-        def constant_prop(input_tensor):
-            a = 2 * 3
-            print(a)
-            b = a + 2
-            return b + input_tensor
+        def print_weird_test(y):
+            print("hi\016")
 
-        self.run_pass('constant_propagation', constant_prop.graph)
-        self.assertExpected(canonical(constant_prop.graph))
+        self.assertExpected(if_test.graph.pretty_print(), "if_test")
+        self.assertExpected(if_one.graph.pretty_print(), "if_one")
+        self.assertExpected(while_test.graph.pretty_print(), "while_test")
+        self.assertExpected(while_if_test.graph.pretty_print(), "while_if_test")
+        self.assertExpected(loop_use_test.graph.pretty_print(), "loop_use_test")
+        self.assertExpected(python_op_name_test.graph.pretty_print(), "python_op_name_test")
+        self.assertExpected(empty_int_list_test.graph.pretty_print(), "empty_int_list_test")
+        self.assertExpected(empty_float_list_test.graph.pretty_print(), "empty_float_list_test")
+        self.assertExpected(print_weird_test.graph.pretty_print(), "print_weird_test")
 
-    def test_constant_prop_rand(self):
-        @torch.jit.script
-        def constant_prop():
-            a = torch.randn([3])
-            b = a + 2
-            return b
+    def test_cu_escaped_number(self):
+        cu = torch.jit.CompilationUnit('''
+            def foo(a):
+                print("hi\016")
+        ''')
+        self.assertExpected(cu.foo.graph.pretty_print())
 
-        self.run_pass('constant_propagation', constant_prop.graph)
-        self.assertExpected(canonical(constant_prop.graph))
+    def test_import_method(self):
+        @torch.jit.script
+        def foo(x, y):
+            return 2 * x + y
 
-    def test_trace_records_names(self):
-        def foo(bar, baz):
-            baz = bar + 3
-            quick_brown_fox = torch.neg(baz)
-            for i in range(20):
-                yeet = quick_brown_fox - 3.14
-            return yeet
+        r, _ = foo._python_print()
+        mod = torch.jit.ScriptModule()
+        torch._C._jit_import_methods(mod, "op_version_set = 0\n{}".format(r), [])
+        self.assertExpected(mod.graph.pretty_print())
 
-        traced = torch.jit.trace(foo, (torch.rand(3, 3), torch.rand(3, 3)))
-        graph_str = str(traced.graph)
-        assert 'bar' in graph_str
-        assert 'baz' in graph_str
-        assert 'quick_brown_fox' in graph_str
+    def test_function_default_values(self):
+        outer_var = torch.tensor(20)
+        outer_var2 = torch.tensor(30)
+        a = torch.tensor(0.5)
+        b = torch.tensor(10)
 
-    def test_constant_prop_if_constant(self):
         @torch.jit.script
-        def constant_prop(a, b):
-            c0 = 1
-            c1 = 1
-            c2 = 1
-            if bool(a):  # -> c0, c1
-                if bool(b):  # -> c0
-                    if True:  # -> c0
-                        c0 = c0 + 1
-                        if False:
-                            c1 = c1 + 1
-                            c2 = c2 + 1
-            else:  # -> c0, c1
-                c1 = c1 + 1
+        def simple_fn(x, a=a, b=b, c=outer_var + outer_var2):
+            return x + a + b + c
 
-            if True:  # inlined
-                c0 = c0 + 1  # dynamic
-                c2 = c2 + 4  # set to 5
-            return a + c0 + c1 + c2
+        self.assertExpectedGraph(simple_fn.graph, "simple")
+        self.assertEqual(
+            simple_fn(torch.ones(1)),
+            torch.ones(1) + 0.5 + 10 + (20 + 30))
+        self.assertEqual(
+            simple_fn(torch.ones(1), torch.tensor(1), torch.tensor(3), torch.tensor(4)),
+            torch.ones(1) + 1 + 3 + 4)
 
-        self.run_pass('constant_propagation', constant_prop.graph)
-        self.assertExpected(canonical(constant_prop.graph))
+        outer_c = torch.tensor(9)
+        outer_flag = torch.tensor(False)
 
-    def test_constant_prop_loop_constant(self):
         @torch.jit.script
-        def constant_prop():
-            b = 0
-            while True:
-                b = 1
-            while False:
-                b = 2
-            return b
-
-        self.run_pass('constant_propagation', constant_prop.graph)
-        self.assertExpected(canonical(constant_prop.graph))
+        def bool_fn(x, a=outer_c, flag=outer_flag):
+            if bool(flag):
+                result = x
+            else:
+                result = x + a
+            return result
 
-    def test_trace_detach(self):
-        def foo(x, w):
-            return torch.matmul(x, w).detach()
+        self.assertExpectedGraph(bool_fn.graph, "bool")
+        self.assertEqual(bool_fn(torch.ones(1)), torch.ones(1) + 9)
+        self.assertEqual(
+            bool_fn(torch.ones(1), torch.tensor(1), torch.tensor(True)),
+            torch.ones(1))
 
-        traced = torch.jit.trace(foo, (torch.rand(3, 4), torch.rand(4, 5)))
+        @torch.jit.script
+        def none_fn(x=None):
+            # type: (Optional[int]) -> Optional[int]
+            return x
 
-        self.assertExpectedGraph(traced.graph)
-        x, w = torch.rand(3, 4), torch.rand(4, 5, requires_grad=True)
-        traced_result = traced(x, w)
-        self.assertEqual(foo(x, w), traced_result)
-        self.assertFalse(traced_result.requires_grad)
-        self.assertIsNone(traced_result.grad_fn)
+        self.assertExpectedGraph(none_fn.graph, "none")
+        self.assertEqual(none_fn(), None)
+        self.assertEqual(none_fn(1), 1)
 
-    def test_trace_detach_inplace(self):
-        def foo(x, w):
-            y = torch.matmul(x, w)
-            y.detach_()
-            return y
-
-        traced = torch.jit.trace(foo, (torch.rand(3, 4), torch.rand(4, 5)))
+        @torch.jit.script
+        def hints(x, a=0.5, b=10):
+            # type: (Tensor, float, int) -> Tensor
+            return x + a + b
 
-        self.assertExpectedGraph(traced.graph)
-        x, w = torch.rand(3, 4), torch.rand(4, 5)
-        traced_result = traced(x, w)
-        self.assertEqual(foo(x, w), traced_result)
-        self.assertFalse(traced_result.requires_grad)
-        self.assertIsNone(traced_result.grad_fn)
+        self.assertExpectedGraph(hints.graph, "type_hints")
+        self.assertEqual(hints(torch.ones(1)), torch.ones(1) + 0.5 + 10)
 
-    def test_trace_detach_onnx_erase(self):
-        class Mod(torch.nn.Module):
-            def forward(self, x, w):
-                return torch.matmul(x, w).detach()
+        with self.assertRaisesRegex(RuntimeError, "Expected a default value"):
 
-        f = io.BytesIO()
-        self.assertExpected(torch.onnx.export_to_pretty_string(
-            Mod(), (torch.rand(3, 4), torch.rand(4, 5)), f))
+            @torch.jit.script
+            def hints_bad_types(x, a=10, b=0.5):
+                # type: (Tensor, float, int) -> Tensor
+                return x + a + b
 
-    def test_trace_slice_full_dim(self):
-        def foo(x):
-            return x[0:5, 0] + 1.0
+    def test_module_default_values(self):
+        four = torch.tensor(4)
 
-        traced = torch.jit.trace(foo, (torch.rand(5, 4),))
-        test_x = torch.rand(6, 3)
-        self.assertEqual(foo(test_x), traced(test_x))
+        class Test(torch.jit.ScriptModule):
+            def __init__(self):
+                super(Test, self).__init__()
 
-    def test_export_expand_aten_fallback(self):
-        class ExpandTest(torch.jit.ScriptModule):
             @torch.jit.script_method
-            def forward(self, x):
-                y = x
-                for i in range(5):
-                    y = x.expand([3, 4, i])
-                return y
+            def forward(self, input, other=four):
+                return input + other
 
-        mod = ExpandTest()
-        example_outs = mod(torch.rand(3, 4, 1))
-        f = io.BytesIO()
-        with self.assertRaisesRegex(RuntimeError, 'Could not export a broadcasted operation'):
-            torch.onnx.export_to_pretty_string(mod, (torch.rand(3, 4, 1),), f, verbose=False,
-                                               example_outputs=example_outs)
+        t = Test()
+        self.assertEqual(t(torch.ones(1)), torch.ones(1) + 4)
 
-        self.assertExpected(
-            torch.onnx.export_to_pretty_string(mod, (torch.rand(3, 4, 1),), f, verbose=False,
-                                               example_outputs=example_outs,
-                                               operator_export_type=OperatorExportTypes.ONNX_ATEN_FALLBACK))
+    def test_warnings(self):
+        import warnings
 
-    def test_export_dropout(self):
-        test = torch.nn.Dropout()
-        test.eval()
+        @torch.jit.script
+        def fn(x):
+            if bool(x < 2):
+                warnings.warn("x is less than 2")
+            return x
 
-        traced = torch.jit.trace(test, (torch.rand(3, 4),), check_trace=False)
-        imported = self.getExportImportCopy(traced)
-        x = torch.randn(3, 4)
-        self.assertEqual(traced(x), imported(x))
+        self.assertExpectedGraph(fn.graph)
 
-    @unittest.skipIf(not RUN_CUDA, "requires CUDA")
-    def test_cuda_export_restore(self):
-        class Sub(torch.jit.ScriptModule):
-            def __init__(self):
-                super(Sub, self).__init__()
-                self.weight = nn.Parameter(torch.randn(3, 4))
 
-            @torch.jit.script_method
-            def forward(self, thing):
-                return self.weight + thing
+class TestBatched(TestCase):
+    # generate random examples and create a BatchTensor from them
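+    # dims[0] is the number of examples; each later entry is a (dynamic, size)
+    # pair, and dynamic dims get a random length in [1, size] for each example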
+    def rand_batch(self, *dims):
+        dims = [dim for dim in dims if dim != ()]
+        xs = [torch.rand(1, *(random.randint(1, size) if b else size for b, size in dims[1:]),
+                         requires_grad=True) for i in range(dims[0])]
+        xb = BatchTensor(xs, torch.tensor([b for b, d in dims[1:]]).byte())
+        return xs, xb
 
-        class M(torch.jit.ScriptModule):
-            def __init__(self):
-                super(M, self).__init__()
-                self.mod = Sub()
+    def test_create_batchtensor(self):
+        # create from a list of tensors
+        xs, batch = self.rand_batch(4, (True, 3), (False, 2), (True, 5))
+        self.assertEqual(xs, batch.examples())
+        # create from data, mask, dims
+        batch2 = BatchTensor(batch.get_data(), batch.get_mask(), batch.get_dims())
+        self.assertEqual(xs, batch2.examples())
+        # expand a tensor to a BatchTensor given batch_size
+        xs = torch.rand(3, 4, 5)
+        batch3 = BatchTensor(xs, 2)
+        xs = xs.unsqueeze(0)
+        self.assertEqual([xs, xs], batch3.examples())
 
-            @torch.jit.script_method
-            def forward(self, v):
-                return self.mod(v)
-        m = M()
-        m.cuda()
-        m2 = self.getExportImportCopy(m)
-        m2.cuda()
-        input = torch.rand(3, 4).cuda()
-        self.assertEqual(m(input), m2(input))
+    def test_batch_elementwise_unary(self):
+        @torch.jit.batch(batch_size=4)
+        def tanh(a):
+            return torch.tanh(a)
 
-    def test_export_batchnorm(self):
-        for mode in ['eval', 'train']:
-            for clazz in [
-                    torch.nn.BatchNorm1d(100),
-                    torch.nn.BatchNorm1d(100, affine=False),
-                    torch.nn.BatchNorm2d(100),
-                    torch.nn.BatchNorm2d(100, affine=False)]:
-                getattr(clazz, mode)()
+        xs, batch = self.rand_batch(4, (True, 3), (False, 2))
+        res_batch = tanh(batch)
+        res = [torch.tanh(xs[j]) for j in range(4)]
+        self.assertEqual(res, res_batch.examples())
 
-                input = torch.randn(20, 100) if isinstance(clazz, torch.nn.BatchNorm1d) else \
-                    torch.randn(20, 100, 35, 45)
+    def test_batch_elementwise_binary(self):
+        @torch.jit.batch(batch_size=4)
+        def add(a, b):
+            return a + b
 
-                traced = torch.jit.trace(clazz, (input,))
-                imported = self.getExportImportCopy(traced)
-                x = torch.randn(20, 100) if isinstance(clazz, torch.nn.BatchNorm1d) else \
-                    torch.randn(20, 100, 35, 45)
-                self.assertEqual(traced(x), imported(x))
+        xs, batch = self.rand_batch(4, (True, 3), (False, 2))
+        xs2, batch2 = xs, batch
+        res_batch = add(batch, batch2)
+        res = [torch.add(xs[j], xs2[j]) for j in range(4)]
+        self.assertEqual(res, res_batch.examples())
 
-    def test_export_rnn(self):
-        for clazz in [nn.RNN(10, 20, 2), nn.GRU(10, 20, 2)]:
-            class RNNTest(torch.nn.Module):
-                def __init__(self):
-                    super(RNNTest, self).__init__()
-                    self.rnn = clazz
+        # test broadcast
+        xs, batch = self.rand_batch(4, (False, 3), (False, 2))
+        b = torch.rand(3, 2)
+        res_batch = add(batch, b)
+        res = [torch.add(xs[j], b) for j in range(4)]
+        self.assertEqual(res, res_batch.examples())
 
-                def forward(self, x, lengths, h0):
-                    packed = torch.nn.utils.rnn.pack_padded_sequence(x, lengths)
-                    out, h = self.rnn(packed, h0)
-                    padded_outs, _ = torch.nn.utils.rnn.pad_packed_sequence(out)
-                    return padded_outs
+    def test_batch_mm(self):
+        @torch.jit.batch(batch_size=4)
+        def mm(a, b):
+            return torch.mm(a, b)
 
-            test = RNNTest()
+        xs, batch = self.rand_batch(4, (True, 3), (False, 2))
+        xs2, batch2 = self.rand_batch(4, (False, 2), (True, 3))
+        res_batch = mm(batch, batch2)
+        res = [torch.mm(xs[j].squeeze(0), xs2[j].squeeze(0)).unsqueeze(0) for j in range(4)]
+        self.assertEqual(res, res_batch.examples())
 
-            traced = torch.jit.trace(test, (torch.randn(5, 3, 10), torch.LongTensor([3, 2, 1]), torch.randn(2, 3, 20)))
-            imported = self.getExportImportCopy(traced)
-            # NB: We make sure to pass in a batch with a different max sequence
-            # length to ensure that the argument stashing for pad_packed works
-            # properly.
-            x, lengths, h0 = torch.randn(7, 4, 10), torch.LongTensor([7, 3, 2, 1]), torch.randn(2, 4, 20)
-            self.assertEqual(traced(x, lengths, h0), imported(x, lengths, h0))
+        # test broadcast
+        b = torch.rand(2, 4)
+        res_batch = mm(batch, b)
+        res = [torch.mm(xs[j].squeeze(0), b).unsqueeze(0) for j in range(4)]
+        self.assertEqual(res, res_batch.examples())
 
-    def test_export_lstm(self):
-        class LSTMTest(torch.nn.Module):
-            def __init__(self):
-                super(LSTMTest, self).__init__()
-                self.rnn = nn.LSTM(10, 20, 2)
+    def test_batch_matmul(self):
+        @torch.jit.batch(batch_size=4)
+        def matmul(a, b):
+            return torch.matmul(a, b)
 
-            def forward(self, x, lengths, hiddens):
-                h0, c0 = hiddens
-                packed = torch.nn.utils.rnn.pack_padded_sequence(x, lengths)
-                out, (h, c) = self.rnn(packed, (h0, c0))
-                padded_outs, _ = torch.nn.utils.rnn.pad_packed_sequence(out)
-                return padded_outs
+        def matmul_test(xs, batch, xs2, batch2):
+            ys = [torch.matmul(xs[j].squeeze(0), xs2[j].squeeze(0)).unsqueeze(0) for j in range(4)]
+            ybs = matmul(batch, batch2)
+            self.assertEqual(ys, ybs.examples())
 
-        test = LSTMTest()
+        # 1 dimension * 1 dimension
+        xs, batch = self.rand_batch(4, (False, 2))
+        xs2, batch2 = self.rand_batch(4, (False, 2))
+        matmul_test(xs, batch, xs2, batch2)
+        # 1 dimension * 2 dimensions
+        xs, batch = self.rand_batch(4, (False, 2))
+        xs2, batch2 = self.rand_batch(4, (False, 2), (True, 3))
+        matmul_test(xs, batch, xs2, batch2)
+        # 2 dimensions * 1 dimension
+        xs, batch = self.rand_batch(4, (True, 3), (False, 2))
+        xs2, batch2 = self.rand_batch(4, (False, 2))
+        matmul_test(xs, batch, xs2, batch2)
+        # 2 dimensions * 2 dimensions
+        xs, batch = self.rand_batch(4, (True, 3), (False, 2))
+        xs2, batch2 = self.rand_batch(4, (False, 2), (True, 3))
+        matmul_test(xs, batch, xs2, batch2)
 
-        traced = torch.jit.trace(test, (torch.randn(5, 3, 10),
-                                        torch.LongTensor([3, 2, 1]),
-                                        (torch.randn(2, 3, 20), torch.randn(2, 3, 20))))
-        imported = self.getExportImportCopy(traced)
-        x, lengths, h0, c0 = \
-            torch.randn(7, 3, 10), torch.LongTensor([7, 5, 2]), torch.randn(2, 3, 20), torch.randn(2, 3, 20)
-        self.assertEqual(traced(x, lengths, (h0, c0)), imported(x, lengths, (h0, c0)))
+    def test_batch_select(self):
+        @torch.jit.batch(batch_size=4)
+        def select(x):
+            return torch.select(x, 1, 0)
 
-    def test_trace_variable_instantiation(self):
-        def random_foo(x):
-            return Variable(Variable(x) + 1.0)
+        xs, batch = self.rand_batch(4, (True, 3), (True, 2))
+        res_batch = select(batch)
+        res = [torch.select(xs[j], 1, 0) for j in range(4)]
+        self.assertEqual(res, res_batch.examples())
 
-        random_foo_traced = torch.jit.trace(random_foo, (torch.rand(3, 4),))
+        xs, batch = self.rand_batch(4, (False, 3), (True, 2))
+        res_batch = select(batch)
+        res = [torch.select(xs[j], 1, 0) for j in range(4)]
+        self.assertEqual(res, res_batch.examples())
 
-        x = torch.rand(5, 6)
-        self.assertEqual(random_foo(x), random_foo_traced(x))
+    def test_batch_index_select(self):
+        @torch.jit.batch(batch_size=4)
+        def index_select(x, ind):
+            return x.index_select(1, ind)
 
-    def test_trace_slice_expr_complete_type(self):
-        def random_foo(x):
-            return x + 1.0
+        xs, batch = self.rand_batch(4, (False, 5), (True, 2))
+        ind = [torch.randint(0, 4, (1,), dtype=torch.long) for i in range(4)]
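+        # empty mask: each index example has no dims beyond the leading batch dim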
+        ind_batch = BatchTensor(ind, torch.tensor([]).byte())
+        res_batch = index_select(batch, ind_batch)
+        res = [torch.index_select(xs[j], 1, ind[j]) for j in range(4)]
+        self.assertEqual(res, res_batch.examples())
 
-        random_foo_traced = torch.jit.trace(random_foo, (torch.rand(3, 4),))
+    def test_batch_where(self):
+        @torch.jit.batch(batch_size=4)
+        def where(c, a, b):
+            return torch.where(c, a, b)
 
-        @torch.jit.script
-        def random_bar(x):
-            return random_foo_traced(x)[0:1]
+        xs, batch = self.rand_batch(4, (False, 3), (False, 2))
+        xs2, batch2 = self.rand_batch(4, (False, 3), (False, 2))
 
-        x = torch.rand(3, 4)
-        self.assertEqual(random_bar(x), (x + 1)[0:1])
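+        # build the condition batch by hand so the examples are byte tensors
+        # (rand_batch only produces float examples)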
+        dims = [4, (False, 3), (False, 2)]
+        xs_cond = [torch.rand(1, 3, 2).byte() for i in range(dims[0])]
+        batch_cond = BatchTensor(xs_cond, torch.tensor([b for b, d in dims[1:]]))
 
-    def test_export_tensoroption_to(self):
-        def foo(x):
-            return x.new_tensor(x[0]).cpu() + x
+        res_batch = where(batch_cond, batch, batch2)
+        res = [torch.where(xs_cond[j], xs[j], xs2[j]) for j in range(4)]
+        self.assertEqual(res, res_batch.examples())
 
-        traced = torch.jit.trace(foo, (torch.rand([2])))
-        example_outputs = traced(torch.rand([2]))
+    def test_batch_argmax(self):
+        @torch.jit.batch(batch_size=4)
+        def argmax(a):
+            return torch.argmax(a, 1)
 
-        f = io.BytesIO()
-        self.assertExpected(torch.onnx._export_to_pretty_string(traced, (torch.rand([2]),), f,
-                                                                example_outputs=example_outputs))
+        xs, batch = self.rand_batch(4, (True, 5), (True, 6))
+        res_batch = argmax(batch)
+        res = [torch.argmax(xs[j], 1) for j in range(4)]
+        self.assertEqual(res, res_batch.examples())
 
-    def test_pretty_printer(self):
-        @torch.jit.script
-        def if_test(a, b):
-            # FIXME: use 0 instead of a.
-            # c = 0
-            c = a
-            if bool(a < b):
-                c = b
-            else:
-                c = a
-            return c
+        @torch.jit.batch(batch_size=4)
+        def argmax(a):
+            return torch.argmax(a, 1, False)
 
-        @torch.jit.script
-        def if_one(a, b):
-            c = b
-            if bool(a < b):
-                c = a
-            return c
+        res_batch = argmax(batch)
+        res = [torch.argmax(xs[j], 1, False) for j in range(4)]
+        self.assertEqual(res, res_batch.examples())
 
-        @torch.jit.script
-        def while_test(a, i):
-            while bool(i < 3):
-                a *= a
-                i += 1
-            return a
+    def test_batch_topk(self):
+        @torch.jit.batch(batch_size=4)
+        def topk(a):
+            return torch.topk(a, 3, 1)
 
-        @torch.jit.script
-        def while_if_test(a, b):
-            c = 0
-            while bool(a < 10):
-                a = a + 1
-                b = b + 1
-                if bool(a > b):
-                    c = 2
-                else:
-                    c = 3
-            return a + 1 + c
+        xs, batch = self.rand_batch(4, (False, 5), (True, 6))
 
-        @torch.jit.script
-        def loop_use_test(y):
-            x = y + 1
-            z = x + 5
-            while bool(y < 8):
-                y += 1
-                z = x
-            return x, z
+        # along static dim
+        res_batch = topk(batch)
+        res = [torch.topk(xs[j], 3, 1)[0] for j in range(4)]
+        res_idx = [torch.topk(xs[j], 3, 1)[1] for j in range(4)]
+        self.assertEqual(res, res_batch[0].examples())
+        self.assertEqual(res_idx, res_batch[1].examples())
 
-        def python_fn(x):
-            return x + 10
+        @torch.jit.batch(batch_size=4)
+        def topk(a):
+            return torch.topk(a, 1, 2)
 
-        @torch.jit.script
-        def python_op_name_test(y):
-            return python_fn(y)
+        # along dynamic dim
+        res_batch = topk(batch)
+        res = [torch.topk(xs[j], 1, 2)[0] for j in range(4)]
+        res_idx = [torch.topk(xs[j], 1, 2)[1] for j in range(4)]
+        self.assertEqual(res, res_batch[0].examples())
+        self.assertEqual(res_idx, res_batch[1].examples())
 
-        @torch.jit.script
-        def empty_int_list_test(y):
-            x = torch.jit.annotate(List[int], [])
-            return x[0]
+    def test_batch_softmax(self):
+        @torch.jit.batch(batch_size=4)
+        def softmax(a):
+            return torch.softmax(a, 1)
 
-        @torch.jit.script
-        def empty_float_list_test(y):
-            return [1.0, 2.0, 3.0]
+        xs, batch = self.rand_batch(4, (False, 5), (True, 6))
 
-        @torch.jit.script
-        def print_weird_test(y):
-            print("hi\016")
+        # along static dim
+        res_batch = softmax(batch)
+        res = [torch.softmax(xs[j], 1) for j in range(4)]
+        self.assertEqual(res, res_batch.examples())
 
-        self.assertExpected(if_test.graph.pretty_print(), "if_test")
-        self.assertExpected(if_one.graph.pretty_print(), "if_one")
-        self.assertExpected(while_test.graph.pretty_print(), "while_test")
-        self.assertExpected(while_if_test.graph.pretty_print(), "while_if_test")
-        self.assertExpected(loop_use_test.graph.pretty_print(), "loop_use_test")
-        self.assertExpected(python_op_name_test.graph.pretty_print(), "python_op_name_test")
-        self.assertExpected(empty_int_list_test.graph.pretty_print(), "empty_int_list_test")
-        self.assertExpected(empty_float_list_test.graph.pretty_print(), "empty_float_list_test")
-        self.assertExpected(print_weird_test.graph.pretty_print(), "print_weird_test")
+        @torch.jit.batch(batch_size=4)
+        def softmax(a):
+            return torch.softmax(a, 2)
 
-    def test_cu_escaped_number(self):
-        cu = torch.jit.CompilationUnit('''
-            def foo(a):
-                print("hi\016")
-        ''')
-        self.assertExpected(cu.foo.graph.pretty_print())
+        # along dynamic dim
+        res_batch = softmax(batch)
+        res = [torch.softmax(xs[j], 2) for j in range(4)]
+        self.assertEqual(res, res_batch.examples())
 
-    def test_import_method(self):
-        @torch.jit.script
-        def foo(x, y):
-            return 2 * x + y
+    def test_batch_view(self):
+        @torch.jit.batch(batch_size=4)
+        def view(a):
+            return a.view([4, -1, 3])
 
-        r, _ = foo._python_print()
-        mod = torch.jit.ScriptModule()
-        torch._C._jit_import_methods(mod, "op_version_set = 0\n{}".format(r), [])
-        self.assertExpected(mod.graph.pretty_print())
+        xs, batch = self.rand_batch(4, (True, 5), (False, 3))
+        res_batch = view(batch)
+        res = [xs[j].view([1, -1, 3]) for j in range(4)]
+        self.assertEqual(res, res_batch.examples())
 
-    def test_function_default_values(self):
-        outer_var = torch.tensor(20)
-        outer_var2 = torch.tensor(30)
-        a = torch.tensor(0.5)
-        b = torch.tensor(10)
+    def test_batch_cat(self):
+        @torch.jit.batch(batch_size=4)
+        def cat2(a, b):
+            return torch.cat([a, b], 2)
 
-        @torch.jit.script
-        def simple_fn(x, a=a, b=b, c=outer_var + outer_var2):
-            return x + a + b + c
+        xs, batch = self.rand_batch(4, (True, 5), (False, 3))
+        xs2, batch2 = xs, batch
+        res_batch = cat2(batch, batch2)
+        res = [torch.cat([xs[j], xs2[j]], 2) for j in range(4)]
+        self.assertEqual(res, res_batch.examples())
 
-        self.assertExpectedGraph(simple_fn.graph, "simple")
-        self.assertEqual(
-            simple_fn(torch.ones(1)),
-            torch.ones(1) + 0.5 + 10 + (20 + 30))
-        self.assertEqual(
-            simple_fn(torch.ones(1), torch.tensor(1), torch.tensor(3), torch.tensor(4)),
-            torch.ones(1) + 1 + 3 + 4)
+    def test_batch_sum(self):
+        @torch.jit.batch(batch_size=4)
+        def batch_sum(a):
+            return a.sum()
 
-        outer_c = torch.tensor(9)
-        outer_flag = torch.tensor(False)
+        xs, batch = self.rand_batch(4, (True, 5), (False, 3))
+        res_batch = batch_sum(batch)
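+        # each per-example sum is a scalar; unsqueeze(0) restores the leading example dim for comparison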
+        res = [xs[j].sum().unsqueeze(0) for j in range(4)]
+        self.assertEqual(res, res_batch.examples())
 
-        @torch.jit.script
-        def bool_fn(x, a=outer_c, flag=outer_flag):
-            if bool(flag):
-                result = x
+    def test_if_else(self):
+        def single_if(a, b):
+            if bool(a > b):
+                a = a + b
             else:
-                result = x + a
-            return result
-
-        self.assertExpectedGraph(bool_fn.graph, "bool")
-        self.assertEqual(bool_fn(torch.ones(1)), torch.ones(1) + 9)
-        self.assertEqual(
-            bool_fn(torch.ones(1), torch.tensor(1), torch.tensor(True)),
-            torch.ones(1))
+                a = a - b
+            return a
 
-        @torch.jit.script
-        def none_fn(x=None):
-            # type: (Optional[int]) -> Optional[int]
-            return x
+        batch_if = torch.jit.batch(batch_size=4)(single_if)
 
-        self.assertExpectedGraph(none_fn.graph, "none")
-        self.assertEqual(none_fn(), None)
-        self.assertEqual(none_fn(1), 1)
+        a, batch_a = self.rand_batch(4, ())
+        b, batch_b = self.rand_batch(4, ())
+        res_batch = batch_if(batch_a, batch_b)
+        res = [single_if(a[j], b[j]) for j in range(4)]
+        self.assertEqual(res, res_batch.examples())
 
-        @torch.jit.script
-        def hints(x, a=0.5, b=10):
-            # type: (Tensor, float, int) -> Tensor
-            return x + a + b
+        script_if = torch.jit.script(single_if)
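+        # also check the batch-transformed graph of the scripted function against the expect file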
+        graph = torch.to_batch_graph(script_if.graph)
+        self.assertExpected(canonical(graph))
 
-        self.assertExpectedGraph(hints.graph, "type_hints")
-        self.assertEqual(hints(torch.ones(1)), torch.ones(1) + 0.5 + 10)
+    def test_if_else_with_scalar(self):
+        def single_if(a, b):
+            if bool(a > 0.1):
+                a = a + b
+            else:
+                a = a - b
+            return a
 
-        with self.assertRaisesRegex(RuntimeError, "Expected a default value"):
+        batch_if = torch.jit.batch(batch_size=4)(single_if)
 
-            @torch.jit.script
-            def hints_bad_types(x, a=10, b=0.5):
-                # type: (Tensor, float, int) -> Tensor
-                return x + a + b
+        a, batch_a = self.rand_batch(4, ())
+        b, batch_b = self.rand_batch(4, ())
+        res_batch = batch_if(batch_a, batch_b)
+        res = [single_if(a[j], b[j]) for j in range(4)]
+        self.assertEqual(res, res_batch.examples())
 
-    def test_module_default_values(self):
-        four = torch.tensor(4)
+        script_if = torch.jit.script(single_if)
+        graph = torch.to_batch_graph(script_if.graph)
+        self.assertExpected(canonical(graph))
 
-        class Test(torch.jit.ScriptModule):
-            def __init__(self):
-                super(Test, self).__init__()
+    def test_if_noelse(self):
+        def single_if(a, b):
+            if bool(a > b):
+                a = a + b
+            return a
 
-            @torch.jit.script_method
-            def forward(self, input, other=four):
-                return input + other
+        batch_if = torch.jit.batch(batch_size=4)(single_if)
 
-        t = Test()
-        self.assertEqual(t(torch.ones(1)), torch.ones(1) + 4)
+        a, batch_a = self.rand_batch(4, ())
+        b, batch_b = self.rand_batch(4, ())
+        res_batch = batch_if(batch_a, batch_b)
+        res = [single_if(a[j], b[j]) for j in range(4)]
+        self.assertEqual(res, res_batch.examples())
 
-    def test_warnings(self):
-        import warnings
+        script_if = torch.jit.script(single_if)
+        graph = torch.to_batch_graph(script_if.graph)
+        self.assertExpected(canonical(graph))
 
-        @torch.jit.script
-        def fn(x):
-            if bool(x < 2):
-                warnings.warn("x is less than 2")
-            return x
+    def test_if_noelse_with_scalar(self):
+        def single_if(a, b):
+            if bool(a > 0.1):
+                a = a + b
+            return a
 
-        self.assertExpectedGraph(fn.graph)
+        batch_if = torch.jit.batch(batch_size=4)(single_if)
 
+        a, batch_a = self.rand_batch(4, ())
+        b, batch_b = self.rand_batch(4, ())
+        res_batch = batch_if(batch_a, batch_b)
+        res = [single_if(a[j], b[j]) for j in range(4)]
+        self.assertEqual(res, res_batch.examples())
 
-class TestBatched(TestCase):
-    # generate random examples and create a batchtensor with them
-    def rand_batch(self, *dims):
-        dims = [dim for dim in dims if dim != ()]
-        xs = [torch.rand(1, *(random.randint(1, size) if b else size for b, size in dims[1:]),
-                         requires_grad=True) for i in range(dims[0])]
-        xb = BatchTensor(xs, torch.tensor([b for b, d in dims[1:]]).byte())
-        return xs, xb
+        script_if = torch.jit.script(single_if)
+        graph = torch.to_batch_graph(script_if.graph)
+        self.assertExpected(canonical(graph))
 
-    def test_create_batchtensor(self):
-        # create from tensorlist
-        xs, batch = self.rand_batch(4, (True, 3), (False, 2), (True, 5))
-        self.assertEqual(xs, batch.examples())
-        # create from data, mask, dims
-        batch2 = BatchTensor(batch.get_data(), batch.get_mask(), batch.get_dims())
-        self.assertEqual(xs, batch2.examples())
-        # expand a tensor to a batchtensor given batch_size
-        xs = torch.rand(3, 4, 5)
-        batch3 = BatchTensor(xs, 2)
-        xs = xs.unsqueeze(0)
-        self.assertEqual([xs, xs], batch3.examples())
+    def test_while(self):
+        def single_while(a, b):
+            while bool(a > b):
+                a = a - b
+            return a
 
-    def test_batch_elementwise_unary(self):
-        @torch.jit.batch(batch_size=4)
-        def tanh(a):
-            return torch.tanh(a)
+        batch_while = torch.jit.batch(batch_size=4)(single_while)
 
-        xs, batch = self.rand_batch(4, (True, 3), (False, 2))
-        res_batch = tanh(batch)
-        res = [torch.tanh(xs[j]) for j in range(4)]
+        a, batch_a = self.rand_batch(4, ())
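+        # keep b positive so the while loop terminates for every example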
+        b = [torch.abs(torch.rand(1)) for i in range(4)]
+        batch_b = BatchTensor(b, torch.tensor([]).byte())
+        res_batch = batch_while(batch_a, batch_b)
+        res = [single_while(a[j], b[j]) for j in range(4)]
         self.assertEqual(res, res_batch.examples())
 
-    def test_batch_elementwise_binary(self):
-        @torch.jit.batch(batch_size=4)
-        def add(a, b):
-            return a + b
-
-        xs, batch = self.rand_batch(4, (True, 3), (False, 2))
-        xs2, batch2 = xs, batch
-        res_batch = add(batch, batch2)
-        res = [torch.add(xs[j], xs2[j]) for j in range(4)]
-        self.assertEqual(res, res_batch.examples())
+        script_while = torch.jit.script(single_while)
+        graph = torch.to_batch_graph(script_while.graph)
+        self.assertExpected(canonical(graph))
 
-        # test broadcast
-        xs, batch = self.rand_batch(4, (False, 3), (False, 2))
-        b = torch.rand(3, 2)
-        res_batch = add(batch, b)
-        res = [torch.add(xs[j], b) for j in range(4)]
-        self.assertEqual(res, res_batch.examples())
+    def test_for(self):
+        def single_for(x, y):
+            for _ in range(10):
+                x = x + y
+            return x
 
-    def test_batch_mm(self):
-        @torch.jit.batch(batch_size=4)
-        def mm(a, b):
-            return torch.mm(a, b)
+        batch_for = torch.jit.batch(batch_size=4)(single_for)
 
-        xs, batch = self.rand_batch(4, (True, 3), (False, 2))
-        xs2, batch2 = self.rand_batch(4, (False, 2), (True, 3))
-        res_batch = mm(batch, batch2)
-        res = [torch.mm(xs[j].squeeze(0), xs2[j].squeeze(0)).unsqueeze(0) for j in range(4)]
+        a, batch_a = self.rand_batch(4, ())
+        b, batch_b = self.rand_batch(4, ())
+        res_batch = batch_for(batch_a, batch_b)
+        res = [single_for(a[j], b[j]) for j in range(4)]
         self.assertEqual(res, res_batch.examples())
 
-        # test broadcast
-        b = torch.rand(2, 4)
-        res_batch = mm(batch, b)
-        res = [torch.mm(xs[j].squeeze(0), b).unsqueeze(0) for j in range(4)]
-        self.assertEqual(res, res_batch.examples())
+        script_for = torch.jit.script(single_for)
+        graph = torch.to_batch_graph(script_for.graph)
+        self.assertExpected(canonical(graph))
 
-    def test_batch_matmul(self):
-        @torch.jit.batch(batch_size=4)
-        def matmul(a, b):
-            return torch.matmul(a, b)
+    def test_lstm(self):
+        def LSTM(x_all, h, c, w_xi, w_xf, w_xo, w_xc, w_hi, w_hf, w_ho, w_hc, b_i, b_f, b_o, b_c):
+            for i in range(x_all.size(1)):
+                x = x_all.select(1, i)
+                i_t = torch.matmul(x, w_xi) + torch.matmul(h, w_hi) + b_i
+                f_t = torch.matmul(x, w_xf) + torch.matmul(h, w_hf) + b_f
+                o_t = torch.matmul(x, w_xo) + torch.matmul(h, w_ho) + b_o
+                # activations
+                i_t = torch.sigmoid(i_t)
+                f_t = torch.sigmoid(f_t)
+                o_t = torch.sigmoid(o_t)
+                # cell computations
+                c_t = torch.matmul(x, w_xc) + torch.matmul(h, w_hc) + b_c
+                c_t = torch.tanh(c_t)
+                c_t = torch.mul(c_t, f_t) + torch.mul(i_t, c_t)
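+                # note: both terms use the freshly computed candidate c_t; the carried cell state c is not read in the update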
+                h_t = torch.mul(o_t, torch.tanh(c_t))
+                h = h_t
+                c = c_t
+            return h
 
-        def matmul_test(xs, batch, xs2, batch2):
-            ys = [torch.matmul(xs[j].squeeze(0), xs2[j].squeeze(0)).unsqueeze(0) for j in range(4)]
-            ybs = matmul(batch, batch2)
-            self.assertEqual(ys, ybs.examples())
+        LSTM_batch = torch.jit.batch(batch_size=4)(LSTM)
 
-        # 1 dimension * 1 dimension
-        xs, batch = self.rand_batch(4, (False, 2))
-        xs2, batch2 = self.rand_batch(4, (False, 2))
-        matmul_test(xs, batch, xs2, batch2)
-        # 1 dimension * 2 dimension
-        xs, batch = self.rand_batch(4, (False, 2))
-        xs2, batch2 = self.rand_batch(4, (False, 2), (True, 3))
-        matmul_test(xs, batch, xs2, batch2)
-        # 2 dimension * 1 dimensions
-        xs, batch = self.rand_batch(4, (True, 3), (False, 2))
-        xs2, batch2 = self.rand_batch(4, (False, 2))
-        matmul_test(xs, batch, xs2, batch2)
-        # 2 dimension * 2 dimension
-        xs, batch = self.rand_batch(4, (True, 3), (False, 2))
-        xs2, batch2 = self.rand_batch(4, (False, 2), (True, 3))
-        matmul_test(xs, batch, xs2, batch2)
+        batch_size, input_size, hidden_size = 4, 3, 2
+        xs, batch = self.rand_batch(batch_size, (True, 4), (False, input_size))
+        hx, h_batch = self.rand_batch(batch_size, (False, hidden_size))
+        cx, c_batch = self.rand_batch(batch_size, (False, hidden_size))
 
-    def test_batch_select(self):
-        @torch.jit.batch(batch_size=4)
-        def select(x):
-            return torch.select(x, 1, 0)
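+        # the weights and biases below are ordinary tensors shared across all examples in the batch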
+        # input to hidden weights
+        w_xi = torch.rand(input_size, hidden_size)
+        w_xf = torch.rand(input_size, hidden_size)
+        w_xo = torch.rand(input_size, hidden_size)
+        w_xc = torch.rand(input_size, hidden_size)
+        # hidden to hidden weights
+        w_hi = torch.rand(hidden_size, hidden_size)
+        w_hf = torch.rand(hidden_size, hidden_size)
+        w_ho = torch.rand(hidden_size, hidden_size)
+        w_hc = torch.rand(hidden_size, hidden_size)
+        # bias terms
+        b_i = torch.rand(hidden_size)
+        b_f = torch.rand(hidden_size)
+        b_o = torch.rand(hidden_size)
+        b_c = torch.rand(hidden_size)
 
-        xs, batch = self.rand_batch(4, (True, 3), (True, 2))
-        res_batch = select(batch)
-        res = [torch.select(xs[j], 1, 0) for j in range(4)]
-        self.assertEqual(res, res_batch.examples())
+        ys = [LSTM(xs[j], hx[j], cx[j], w_xi, w_xf, w_xo, w_xc,
+                   w_hi, w_hf, w_ho, w_hc, b_i, b_f, b_o, b_c) for j in range(batch_size)]
+        ybs = LSTM_batch(batch, h_batch, c_batch, w_xi, w_xf, w_xo, w_xc,
+                         w_hi, w_hf, w_ho, w_hc, b_i, b_f, b_o, b_c)
+        self.assertEqual(ys, ybs.examples())
 
-        xs, batch = self.rand_batch(4, (False, 3), (True, 2))
-        res_batch = select(batch)
-        res = [torch.select(xs[j], 1, 0) for j in range(4)]
-        self.assertEqual(res, res_batch.examples())
+    def test_greedy_search(self):
+        def greedy(x, h, c, embed, w_xi, w_xf, w_xo, w_xc, w_hi, w_hf, w_ho, w_hc,
+                   b_i, b_f, b_o, b_c, w_hs, b_s, iter_num):
+            iter_count = torch.zeros_like(iter_num)
+            while bool(iter_count < iter_num):
+                iter_count = iter_count + 1
+                # LSTM Cell
+                i_t = torch.matmul(x, w_xi) + torch.matmul(h, w_hi) + b_i
+                f_t = torch.matmul(x, w_xf) + torch.matmul(h, w_hf) + b_f
+                o_t = torch.matmul(x, w_xo) + torch.matmul(h, w_ho) + b_o
+                # activations
+                i_t = torch.sigmoid(i_t)
+                f_t = torch.sigmoid(f_t)
+                o_t = torch.sigmoid(o_t)
+                # cell computations
+                c_t = torch.matmul(x, w_xc) + torch.matmul(h, w_hc) + b_c
+                c_t = torch.tanh(c_t)
+                c_t = torch.mul(c_t, f_t) + torch.mul(i_t, c_t)
+                h_t = torch.mul(o_t, torch.tanh(c_t))
+                h = h_t
+                c = c_t
+                # calculate feature with max probability
+                s_t = torch.matmul(h_t, w_hs) + b_s
+                p_t = torch.softmax(s_t, 1)
+                i_t = torch.argmax(p_t, 1)
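+                # look up the embedding of the chosen index and feed it back as the next input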
+                x = embed.index_select(1, i_t).squeeze(1)
+            return h
 
-    def test_batch_index_select(self):
-        @torch.jit.batch(batch_size=4)
-        def index_select(x, ind):
-            return x.index_select(1, ind)
-
-        xs, batch = self.rand_batch(4, (False, 5), (True, 2))
-        ind = [torch.randint(0, 4, (1,), dtype=torch.long) for i in range(4)]
-        ind_batch = BatchTensor(ind, torch.tensor([]).byte())
-        res_batch = index_select(batch, ind_batch)
-        res = [torch.index_select(xs[j], 1, ind[j]) for j in range(4)]
-        self.assertEqual(res, res_batch.examples())
-
-    def test_batch_where(self):
-        @torch.jit.batch(batch_size=4)
-        def where(c, a, b):
-            return torch.where(c, a, b)
-
-        xs, batch = self.rand_batch(4, (False, 3), (False, 2))
-        xs2, batch2 = self.rand_batch(4, (False, 3), (False, 2))
-
-        dims = [4, (False, 3), (False, 2)]
-        xs_cond = [torch.rand(1, 3, 2).byte() for i in range(dims[0])]
-        batch_cond = BatchTensor(xs_cond, torch.tensor([b for b, d in dims[1:]]))
-
-        res_batch = where(batch_cond, batch, batch2)
-        res = [torch.where(xs_cond[j], xs[j], xs2[j]) for j in range(4)]
-        self.assertEqual(res, res_batch.examples())
-
-    def test_batch_argmax(self):
-        @torch.jit.batch(batch_size=4)
-        def argmax(a):
-            return torch.argmax(a, 1)
-
-        xs, batch = self.rand_batch(4, (True, 5), (True, 6))
-        res_batch = argmax(batch)
-        res = [torch.argmax(xs[j], 1) for j in range(4)]
-        self.assertEqual(res, res_batch.examples())
-
-        @torch.jit.batch(batch_size=4)
-        def argmax(a):
-            return torch.argmax(a, 1, False)
-
-        res_batch = argmax(batch)
-        res = [torch.argmax(xs[j], 1, False) for j in range(4)]
-        self.assertEqual(res, res_batch.examples())
-
-    def test_batch_topk(self):
-        @torch.jit.batch(batch_size=4)
-        def topk(a):
-            return torch.topk(a, 3, 1)
-
-        xs, batch = self.rand_batch(4, (False, 5), (True, 6))
-
-        # along static dim
-        res_batch = topk(batch)
-        res = [torch.topk(xs[j], 3, 1)[0] for j in range(4)]
-        res_idx = [torch.topk(xs[j], 3, 1)[1] for j in range(4)]
-        self.assertEqual(res, res_batch[0].examples())
-        self.assertEqual(res_idx, res_batch[1].examples())
-
-        @torch.jit.batch(batch_size=4)
-        def topk(a):
-            return torch.topk(a, 1, 2)
-
-        # along dynamic dim
-        res_batch = topk(batch)
-        res = [torch.topk(xs[j], 1, 2)[0] for j in range(4)]
-        res_idx = [torch.topk(xs[j], 1, 2)[1] for j in range(4)]
-        self.assertEqual(res, res_batch[0].examples())
-        self.assertEqual(res_idx, res_batch[1].examples())
-
-    def test_batch_softmax(self):
-        @torch.jit.batch(batch_size=4)
-        def softmax(a):
-            return torch.softmax(a, 1)
-
-        xs, batch = self.rand_batch(4, (False, 5), (True, 6))
-
-        # along static dim
-        res_batch = softmax(batch)
-        res = [torch.softmax(xs[j], 1) for j in range(4)]
-        self.assertEqual(res, res_batch.examples())
-
-        @torch.jit.batch(batch_size=4)
-        def softmax(a):
-            return torch.softmax(a, 2)
-
-        # along dynamic dim
-        res_batch = softmax(batch)
-        res = [torch.softmax(xs[j], 2) for j in range(4)]
-        self.assertEqual(res, res_batch.examples())
-
-    def test_batch_view(self):
-        @torch.jit.batch(batch_size=4)
-        def view(a):
-            return a.view([4, -1, 3])
-
-        xs, batch = self.rand_batch(4, (True, 5), (False, 3))
-        res_batch = view(batch)
-        res = [xs[j].view([1, -1, 3]) for j in range(4)]
-        self.assertEqual(res, res_batch.examples())
-
-    def test_batch_cat(self):
-        @torch.jit.batch(batch_size=4)
-        def cat2(a, b):
-            return torch.cat([a, b], 2)
-
-        xs, batch = self.rand_batch(4, (True, 5), (False, 3))
-        xs2, batch2 = xs, batch
-        res_batch = cat2(batch, batch2)
-        res = [torch.cat([xs[j], xs2[j]], 2) for j in range(4)]
-        self.assertEqual(res, res_batch.examples())
-
-    def test_batch_sum(self):
-        @torch.jit.batch(batch_size=4)
-        def batch_sum(a):
-            return a.sum()
-
-        xs, batch = self.rand_batch(4, (True, 5), (False, 3))
-        res_batch = batch_sum(batch)
-        res = [xs[j].sum().unsqueeze(0) for j in range(4)]
-        self.assertEqual(res, res_batch.examples())
-
-    def test_if_else(self):
-        def single_if(a, b):
-            if bool(a > b):
-                a = a + b
-            else:
-                a = a - b
-            return a
-
-        batch_if = torch.jit.batch(batch_size=4)(single_if)
-
-        a, batch_a = self.rand_batch(4, ())
-        b, batch_b = self.rand_batch(4, ())
-        res_batch = batch_if(batch_a, batch_b)
-        res = [single_if(a[j], b[j]) for j in range(4)]
-        self.assertEqual(res, res_batch.examples())
-
-        script_if = torch.jit.script(single_if)
-        graph = torch.to_batch_graph(script_if.graph)
-        self.assertExpected(canonical(graph))
-
-    def test_if_else_with_scalar(self):
-        def single_if(a, b):
-            if bool(a > 0.1):
-                a = a + b
-            else:
-                a = a - b
-            return a
-
-        batch_if = torch.jit.batch(batch_size=4)(single_if)
-
-        a, batch_a = self.rand_batch(4, ())
-        b, batch_b = self.rand_batch(4, ())
-        res_batch = batch_if(batch_a, batch_b)
-        res = [single_if(a[j], b[j]) for j in range(4)]
-        self.assertEqual(res, res_batch.examples())
-
-        script_if = torch.jit.script(single_if)
-        graph = torch.to_batch_graph(script_if.graph)
-        self.assertExpected(canonical(graph))
-
-    def test_if_noelse(self):
-        def single_if(a, b):
-            if bool(a > b):
-                a = a + b
-            return a
-
-        batch_if = torch.jit.batch(batch_size=4)(single_if)
-
-        a, batch_a = self.rand_batch(4, ())
-        b, batch_b = self.rand_batch(4, ())
-        res_batch = batch_if(batch_a, batch_b)
-        res = [single_if(a[j], b[j]) for j in range(4)]
-        self.assertEqual(res, res_batch.examples())
-
-        script_if = torch.jit.script(single_if)
-        graph = torch.to_batch_graph(script_if.graph)
-        self.assertExpected(canonical(graph))
-
-    def test_if_noelse_with_scalar(self):
-        def single_if(a, b):
-            if bool(a > 0.1):
-                a = a + b
-            return a
-
-        batch_if = torch.jit.batch(batch_size=4)(single_if)
-
-        a, batch_a = self.rand_batch(4, ())
-        b, batch_b = self.rand_batch(4, ())
-        res_batch = batch_if(batch_a, batch_b)
-        res = [single_if(a[j], b[j]) for j in range(4)]
-        self.assertEqual(res, res_batch.examples())
-
-        script_if = torch.jit.script(single_if)
-        graph = torch.to_batch_graph(script_if.graph)
-        self.assertExpected(canonical(graph))
-
-    def test_while(self):
-        def single_while(a, b):
-            while bool(a > b):
-                a = a - b
-            return a
-
-        batch_while = torch.jit.batch(batch_size=4)(single_while)
-
-        a, batch_a = self.rand_batch(4, ())
-        b = [torch.abs(torch.rand(1)) for i in range(4)]
-        batch_b = BatchTensor(b, torch.tensor([]).byte())
-        res_batch = batch_while(batch_a, batch_b)
-        res = [single_while(a[j], b[j]) for j in range(4)]
-        self.assertEqual(res, res_batch.examples())
-
-        script_while = torch.jit.script(single_while)
-        graph = torch.to_batch_graph(script_while.graph)
-        self.assertExpected(canonical(graph))
-
-    def test_for(self):
-        def single_for(x, y):
-            for _ in range(10):
-                x = x + y
-            return x
-
-        batch_for = torch.jit.batch(batch_size=4)(single_for)
-
-        a, batch_a = self.rand_batch(4, ())
-        b, batch_b = self.rand_batch(4, ())
-        res_batch = batch_for(batch_a, batch_b)
-        res = [single_for(a[j], b[j]) for j in range(4)]
-        self.assertEqual(res, res_batch.examples())
-
-        script_for = torch.jit.script(single_for)
-        graph = torch.to_batch_graph(script_for.graph)
-        self.assertExpected(canonical(graph))
-
-    def test_lstm(self):
-        def LSTM(x_all, h, c, w_xi, w_xf, w_xo, w_xc, w_hi, w_hf, w_ho, w_hc, b_i, b_f, b_o, b_c):
-            for i in range(x_all.size(1)):
-                x = x_all.select(1, i)
-                i_t = torch.matmul(x, w_xi) + torch.matmul(h, w_hi) + b_i
-                f_t = torch.matmul(x, w_xf) + torch.matmul(h, w_hf) + b_f
-                o_t = torch.matmul(x, w_xo) + torch.matmul(h, w_ho) + b_o
-                # activations
-                i_t = torch.sigmoid(i_t)
-                f_t = torch.sigmoid(f_t)
-                o_t = torch.sigmoid(o_t)
-                # cell computations
-                c_t = torch.matmul(x, w_xc) + torch.matmul(h, w_hc) + b_c
-                c_t = torch.tanh(c_t)
-                c_t = torch.mul(c_t, f_t) + torch.mul(i_t, c_t)
-                h_t = torch.mul(o_t, torch.tanh(c_t))
-                h = h_t
-                c = c_t
-            return h
-
-        LSTM_batch = torch.jit.batch(batch_size=4)(LSTM)
-
-        batch_size, input_size, hidden_size = 4, 3, 2
-        xs, batch = self.rand_batch(batch_size, (True, 4), (False, input_size))
-        hx, h_batch = self.rand_batch(batch_size, (False, hidden_size))
-        cx, c_batch = self.rand_batch(batch_size, (False, hidden_size))
-
-        # input to hidden weights
-        w_xi = torch.rand(input_size, hidden_size)
-        w_xf = torch.rand(input_size, hidden_size)
-        w_xo = torch.rand(input_size, hidden_size)
-        w_xc = torch.rand(input_size, hidden_size)
-        # hidden to hidden weights
-        w_hi = torch.rand(hidden_size, hidden_size)
-        w_hf = torch.rand(hidden_size, hidden_size)
-        w_ho = torch.rand(hidden_size, hidden_size)
-        w_hc = torch.rand(hidden_size, hidden_size)
-        # bias terms
-        b_i = torch.rand(hidden_size)
-        b_f = torch.rand(hidden_size)
-        b_o = torch.rand(hidden_size)
-        b_c = torch.rand(hidden_size)
-
-        ys = [LSTM(xs[j], hx[j], cx[j], w_xi, w_xf, w_xo, w_xc,
-                   w_hi, w_hf, w_ho, w_hc, b_i, b_f, b_o, b_c) for j in range(batch_size)]
-        ybs = LSTM_batch(batch, h_batch, c_batch, w_xi, w_xf, w_xo, w_xc,
-                         w_hi, w_hf, w_ho, w_hc, b_i, b_f, b_o, b_c)
-        self.assertEqual(ys, ybs.examples())
-
-    def test_greedy_search(self):
-        def greedy(x, h, c, embed, w_xi, w_xf, w_xo, w_xc, w_hi, w_hf, w_ho, w_hc,
-                   b_i, b_f, b_o, b_c, w_hs, b_s, iter_num):
-            iter_count = torch.zeros_like(iter_num)
-            while bool(iter_count < iter_num):
-                iter_count = iter_count + 1
-                # LSTM Cell
-                i_t = torch.matmul(x, w_xi) + torch.matmul(h, w_hi) + b_i
-                f_t = torch.matmul(x, w_xf) + torch.matmul(h, w_hf) + b_f
-                o_t = torch.matmul(x, w_xo) + torch.matmul(h, w_ho) + b_o
-                # activations
-                i_t = torch.sigmoid(i_t)
-                f_t = torch.sigmoid(f_t)
-                o_t = torch.sigmoid(o_t)
-                # cell computations
-                c_t = torch.matmul(x, w_xc) + torch.matmul(h, w_hc) + b_c
-                c_t = torch.tanh(c_t)
-                c_t = torch.mul(c_t, f_t) + torch.mul(i_t, c_t)
-                h_t = torch.mul(o_t, torch.tanh(c_t))
-                h = h_t
-                c = c_t
-                # calculate feature with max probability
-                s_t = torch.matmul(h_t, w_hs) + b_s
-                p_t = torch.softmax(s_t, 1)
-                i_t = torch.argmax(p_t, 1)
-                x = embed.index_select(1, i_t).squeeze(1)
-            return h
-
-        greedy_batch = torch.jit.batch(batch_size=4)(greedy)
+        greedy_batch = torch.jit.batch(batch_size=4)(greedy)
 
         batch_size, input_size, hidden_size, vocab_size = 4, 6, 8, 7
         xs, batch = self.rand_batch(batch_size, (False, input_size))
@@ -2979,52 +2661,6 @@ class TestScript(JitTestCase):
             ge = torch.jit.script(script, optimize)
             ge(*inputs)
 
-    def checkScript(self,
-                    script,
-                    inputs,
-                    optimize=True,
-                    outputs=None,
-                    name='func',
-                    capture_output=False,
-                    frames_up=1,
-                    check_expected=False):
-        if isinstance(script, str):
-            cu = torch.jit.CompilationUnit(script, optimize, _frames_up=frames_up)
-            ge = getattr(cu, name)
-        else:
-            if capture_output:
-                with self.capture_stdout() as captured:
-                    outputs = script(*inputs)
-            else:
-                outputs = script(*inputs)
-            # Check the string frontend first
-            source = textwrap.dedent(inspect.getsource(script))
-            self.checkScript(
-                source,
-                inputs,
-                optimize,
-                outputs,
-                script.__name__,
-                capture_output,
-                frames_up=2,
-                check_expected=check_expected)
-            # Continue checking the Python frontend
-            ge = torch.jit.script(script, optimize, _frames_up=1)
-
-        if capture_output:
-            with self.capture_stdout() as captured:
-                outputs_ge = ge(*inputs)
-            if not WINDOWS:
-                self.assertExpected(captured[0], subname='stdout')
-        else:
-            outputs_ge = ge(*inputs)
-        self.assertEqual(outputs, outputs_ge)
-
-        if check_expected:
-            self.assertExpectedGraph(ge.graph)
-
-        return ge
-
     def test_training_param(self):
         class What(torch.jit.ScriptModule):
             @torch.jit.script_method
@@ -3288,29 +2924,6 @@ a")
         b = torch.rand(1, requires_grad=True)
         self.checkScript(func, (a, b), optimize=True)
 
-    @unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
-    @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
-    @skipIfRocm
-    def test_clamp_fusion(self):
-        def func2(a, b):
-            return torch.clamp(a + b, min=0, max=2)
-
-        def funcInf(a, b):
-            return torch.clamp(a + b, min=0, max=float('inf'))
-
-        a = torch.randn(4, 4, dtype=torch.float, device='cuda', requires_grad=True)
-        b = torch.randn(4, 4, dtype=torch.float, device='cuda')
-
-        funcs = (func2, funcInf)
-        for f in funcs:
-            s = self.checkScript(f, (a, b))
-            self.assertAllFused(s.graph_for(a, b), except_for={'aten::size'})
-
-            c = s(a, b)
-            c.sum().backward()
-            graph = backward_graph(s)
-            self.assertAllFused(graph, except_for={'prim::SumToSize'})
-
     def test_mul(self):
         def func(a, b):
             return a * b
@@ -3733,89 +3346,6 @@ a")
             func2.graph, (torch.zeros(1, 1, 1, 1, 4),), False)
         self.assertExpected(canonical(func2.graph), subname='2')
 
-    @unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
-    @unittest.skipIf(not RUN_CUDA, "No CUDA")
-    @skipIfRocm
-    def test_chunk_fusion_cuda(self):
-        def fn(x):
-            a, b, c = x.chunk(3, 1)
-            return a * b + c
-
-        inputs = [torch.randn(10, 6, dtype=torch.float, device='cuda')]
-
-        self.checkScript(fn, inputs)
-
-        fn_script = torch.jit.script(fn)
-        _ = fn_script(*inputs)
-        self.assertExpectedGraph(fn_script.graph_for(*inputs))
-
-    @unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
-    @unittest.skipIf(not RUN_CUDA, "No CUDA")
-    @skipIfRocm
-    def test_chunk_multiple_fusion_cuda(self):
-        # The arguments are intentionally used out of order as a test to see
-        # if the fusion compiler adds extra args in the correct order
-        def fn(s, x, y, z):
-            z1, z2 = z.chunk(2, 2)
-            x1, x2, x3 = x.chunk(3, 1)
-            y1, y2 = y.chunk(2, 0)
-            return s + x1 + x2 + x3 + y1 + y2 + z1 + z2
-
-        inputs = [
-            torch.randn(5, 2, 3, dtype=torch.float, device='cuda'),
-            torch.randn(5, 6, 3, dtype=torch.float, device='cuda'),
-            torch.randn(10, 2, 3, dtype=torch.float, device='cuda'),
-            torch.randn(5, 2, 6, dtype=torch.float, device='cuda'),
-        ]
-
-        self.checkScript(fn, inputs)
-
-        fn_script = torch.jit.script(fn)
-        _ = fn_script(*inputs)
-        self.assertExpectedGraph(fn_script.graph_for(*inputs))
-
-    @staticmethod
-    def _test_chunk_fusion_correctness(self, device='cpu'):
-        def chunk_4_0(x):
-            x0, x1, x2, x3 = x.chunk(4, 0)
-            return x0 + x1 + x2 + x3
-
-        def chunk_4_1(x):
-            x0, x1, x2, x3 = x.chunk(4, 1)
-            return x0 + x1 + x2 + x3
-
-        def chunk_4_last(x):
-            x0, x1, x2, x3 = x.chunk(4, 2)
-            return x0 + x1 + x2 + x3
-
-        fns = [chunk_4_0, chunk_4_1, chunk_4_last]
-        tensors = [
-            # splitSize = 1
-            torch.randn(4, 4, 4, dtype=torch.float, device=device),
-
-            # contiguous case
-            torch.randn(12, 8, 16, dtype=torch.float, device=device),
-
-            # non-contiguous case
-            torch.randn(12, 8, 16, dtype=torch.float, device=device).transpose(1, 2),
-        ]
-
-        for tensor in tensors:
-            for fn in fns:
-                self.checkScript(fn, [tensor])
-
-    @unittest.skipIf(IS_WINDOWS or IS_SANDCASTLE, "NYI: fuser support for Windows or Sandcastle")
-    @skipIfRocm
-    @enable_cpu_fuser
-    def test_chunk_fusion_correctness(self):
-        return self._test_chunk_fusion_correctness(self, 'cpu')
-
-    @unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
-    @unittest.skipIf(not RUN_CUDA, "No CUDA")
-    @skipIfRocm
-    def test_chunk_fusion_correctness_cuda(self):
-        return self._test_chunk_fusion_correctness(self, 'cuda')
-
     def test_cat(self):
         @torch.jit.script
         def func(x):
@@ -3950,32 +3480,6 @@ a")
 
         self.checkScript(func2, ())
 
-    @unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
-    @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
-    @skipIfRocm
-    def test_tensor_scalar_fusion_cuda(self):
-        def should_fuse(x):
-            z = 3.
-            y = x + z
-            return x * y
-
-        # XXX: right now we only support fusing scalars if
-        # they're constant (#9940)
-        def should_not_fuse(x, z):
-            y = x + int(z)
-            return x * y
-
-        inputs = [torch.randn(2, 2, dtype=torch.float, device='cuda')]
-        ge = self.checkScript(should_fuse, inputs)
-        self.assertExpectedGraph(ge.graph_for(*inputs), subname='1')
-
-        inputs = [
-            torch.randn(2, 2, dtype=torch.float, device='cuda'),
-            torch.tensor(3., dtype=torch.float, device='cuda'),
-        ]
-        ge = self.checkScript(should_not_fuse, inputs)
-        self.assertExpectedGraph(ge.graph_for(*inputs), subname='2')
-
     def test_list_ops(self):
         def test_equality():
             a = [1, 2, 3]
@@ -4184,145 +3688,56 @@ a")
         # NOTE: cannot optimize yet because broadcasts are not inserted before the fuser runs
         self.checkScript(script, [alpha, beta, x, y], optimize=False, outputs=outputs)
 
-    def test_view_shape_prop(self):
-        cu = torch.jit.CompilationUnit('''
-        def test_view_shape_prop(a):
-            return a.view(size=[-1])
-        ''')
-        inputs = [torch.zeros(10, 10)]
-        outputs = torch.zeros(100)
-
-        real_outs = cu.test_view_shape_prop(*inputs)
-        self.assertEqual(real_outs, outputs)
-
-    def test_view_listconstruct_shape_prop(self):
-        def fn(x):
-            B = x.size(0)
-            C = x.size(1)
-            T = x.size(2)
-            return x.view(T, B, C)
-
-        x = torch.randn(3, 1, 5, requires_grad=True)
-        graph = torch.jit.script(fn).graph
-        torch._C._jit_pass_shape_analysis(graph, (x,), False)
-        self.assertTrue(next(graph.outputs()).type().kind() != 'DynamicType')
-
-    def test_integral_shape_inference(self):
-        cu = torch.jit.CompilationUnit('''
-        def test_integral_shape_inference(a):
-            return a / a
-        ''')
-        inputs = [torch.ones(10, 10).type(torch.LongTensor)]
-        outputs = torch.ones(10, 10)
-
-        self.assertEqual(cu.test_integral_shape_inference(*inputs), outputs)
-
-    def test_fuser_multiple_blocks(self):
-        cu = torch.jit.CompilationUnit('''
-        def test_fuser_multiple_blocks(this, that, theother, meme):
-            i = 0
-            while i < 20:
-                this = torch.cat([this, meme], dim=0)
-                that = torch.cat([that, meme], dim=0)
-                theother = torch.cat([theother, meme], dim=0)
-                i = i + 1
-            return this, that, theother
-        ''')
-
-        inputs = [torch.ones(0, 10, 10)] * 3
-        inputs += [torch.ones(1, 10, 10)]
-        outputs = [torch.ones(20, 10, 10)] * 3
-
-        self.assertEqual(cu.test_fuser_multiple_blocks(*inputs), outputs)
-
-    @unittest.skipIf(IS_WINDOWS or IS_SANDCASTLE, "NYI: fuser support for Windows or Sandcastle")
-    @enable_cpu_fuser
-    def test_scalar_fusion(self):
-        def fn(x, y):
-            return 2 * x + y
-
-        x = torch.tensor(0.1, dtype=torch.float, device='cpu')
-        y = torch.tensor(1, dtype=torch.float, device='cpu')
-        ge = self.checkScript(fn, (x, y))
-        self.assertExpectedGraph(ge.graph_for(x, y))
-
-    @unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
-    @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
-    @skipIfRocm
-    def test_fusion_chunk_motion_deduplicates_inputs(self):
-        def func1(x):
-            z = x * x
-            z0, z1 = z.chunk(2)
-            return z0 * z1
-
-        def func2(x):
-            z = x * x * x
-            z0, z1 = z.chunk(2)
-            return z0 * z1
-
-        inputs = [
-            torch.tensor([1.1, 1.2], device='cuda', dtype=torch.float),
-        ]
-        for func in [func1, func2]:
-            module = self.checkScript(func, inputs)
-            forward_graph = module.graph_for(*inputs)
-            self.assertGraphContainsExactly(forward_graph, 'prim::FusionGroup', 1)
-            fusion_group = list(forward_graph.nodes())[-1]
-            self.assertEqual(len(list(fusion_group.inputs())), 1)
-
-    @unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
-    @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
-    @skipIfRocm
-    def test_lstm_gates_permutations_fusion_cuda(self):
-        # lstm has gates = x.mm(w_ih.t()) + hx.mm(w_hh.t()) + b_ih + b_hh.
-        # Test that any permutation of this will still result in one FusionGroup.
-        choices = ['x.mm(w_ih.t())', 'hx.mm(w_hh.t())', 'b_ih', 'b_hh']
-        template = dedent('''
-        def cell(x, hx, cx, w_ih, w_hh, b_ih, b_hh):
-            gates = {} + {} + {} + {}
-            ingate, forgetgate, cellgate, outgate = gates.chunk(4, 1)
-            return ingate * forgetgate * cellgate * outgate
+    def test_view_shape_prop(self):
+        cu = torch.jit.CompilationUnit('''
+        def test_view_shape_prop(a):
+            return a.view(size=[-1])
         ''')
-        for permutation in itertools.permutations(choices, len(choices)):
-            code = template.format(*permutation)
-            scope = {}
-            exec(code, globals(), scope)
-            cu = torch.jit.CompilationUnit(code)
+        inputs = [torch.zeros(10, 10)]
+        outputs = torch.zeros(100)
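+        # view(size=[-1]) flattens the 10x10 input into a length-100 vector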
 
-            inputs = get_lstm_inputs('cuda', training=False)
-            self.assertEqual(cu.cell(*inputs), scope['cell'](*inputs))
-            forward_graph = cu.cell.graph_for(*inputs)
-            self.assertGraphContainsExactly(forward_graph, 'prim::FusionGroup', 1)
+        real_outs = cu.test_view_shape_prop(*inputs)
+        self.assertEqual(real_outs, outputs)
 
-    @unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
-    @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
-    @skipIfRocm
-    def test_lstm_fusion_cuda(self):
-        inputs = get_lstm_inputs('cuda', training=True)
-        module = self.checkScript(LSTMCellS, inputs)
-        forward_graph = module.graph_for(*inputs)
-        self.assertGraphContainsExactly(
-            forward_graph, 'prim::FusionGroup', 1, consider_subgraphs=True)
-        self.assertExpectedGraph(forward_graph, subname='forward')
+    def test_view_listconstruct_shape_prop(self):
+        def fn(x):
+            B = x.size(0)
+            C = x.size(1)
+            T = x.size(2)
+            return x.view(T, B, C)
 
-        hy, cy = module(*inputs)
-        (hy + cy).sum().backward()
-        self.assertExpectedGraph(backward_graph(module), subname='backward')
+        x = torch.randn(3, 1, 5, requires_grad=True)
+        graph = torch.jit.script(fn).graph
+        torch._C._jit_pass_shape_analysis(graph, (x,), False)
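+        # after shape analysis the graph output should have a concrete type rather than DynamicType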
+        self.assertTrue(next(graph.outputs()).type().kind() != 'DynamicType')
 
-    @unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
-    @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
-    @skipIfRocm
-    def test_milstm_fusion_cuda(self):
-        inputs = get_milstm_inputs('cuda', training=True)
-        module = self.checkScript(MiLSTMCell, inputs)
-        forward_graph = module.graph_for(*inputs)
-        self.assertGraphContainsExactly(
-            forward_graph, 'prim::FusionGroup', 1, consider_subgraphs=True)
-        self.assertExpectedGraph(forward_graph, subname='forward')
+    def test_integral_shape_inference(self):
+        cu = torch.jit.CompilationUnit('''
+        def test_integral_shape_inference(a):
+            return a / a
+        ''')
+        inputs = [torch.ones(10, 10).type(torch.LongTensor)]
+        outputs = torch.ones(10, 10)
 
-        hy, cy = module(*inputs)
-        (hy + cy).sum().backward()
-        self.assertExpectedGraph(backward_graph(module), subname='backward')
+        self.assertEqual(cu.test_integral_shape_inference(*inputs), outputs)
+
+    def test_fuser_multiple_blocks(self):
+        cu = torch.jit.CompilationUnit('''
+        def test_fuser_multiple_blocks(this, that, theother, meme):
+            i = 0
+            while i < 20:
+                this = torch.cat([this, meme], dim=0)
+                that = torch.cat([that, meme], dim=0)
+                theother = torch.cat([theother, meme], dim=0)
+                i = i + 1
+            return this, that, theother
+        ''')
+
+        inputs = [torch.ones(0, 10, 10)] * 3
+        inputs += [torch.ones(1, 10, 10)]
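+        # 20 loop iterations concatenate 20 copies of meme onto each empty tensor, giving (20, 10, 10) outputs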
+        outputs = [torch.ones(20, 10, 10)] * 3
+
+        self.assertEqual(cu.test_fuser_multiple_blocks(*inputs), outputs)
 
     def test_dropout_script(self):
 
@@ -5352,25 +4767,205 @@ a")
         m.eval()
         self.assertEqual(m(), 1)
 
-    def test_script_module_for(self):
+    def test_script_module_for(self):
+        class M(torch.jit.ScriptModule):
+            __constants__ = ['b']
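+            # the list must appear in __constants__ so it can be iterated from a script method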
+
+            def __init__(self):
+                super(M, self).__init__(False)
+                self.b = [1, 2, 3, 4]
+
+            @torch.jit.script_method
+            def forward(self):
+                sum = 0
+                for i in self.b:
+                    sum += i
+                return sum
+
+        m = M()
+        self.assertEqual(m(), 10)
+
+    def test_script_module_for2(self):
+        class Sub(torch.jit.ScriptModule):
+            def __init__(self):
+                super(Sub, self).__init__(False)
+                self.weight = nn.Parameter(torch.randn(2))
+
+            @torch.jit.script_method
+            def forward(self, thing):
+                return self.weight + thing
+
+        class M(torch.jit.ScriptModule):
+            __constants__ = ['mods']
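+            # the ModuleList must be marked constant to be iterable inside the script method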
+
+            def __init__(self):
+                super(M, self).__init__(False)
+                self.mods = nn.ModuleList([Sub() for i in range(10)])
+
+            @torch.jit.script_method
+            def forward(self, v):
+                for m in self.mods:
+                    v = m(v)
+                return v
+
+        i = torch.Tensor(2)
+        m = M()
+        o = m(i)
+        v = i
+        for sub in m.mods:
+            v = sub(v)
+        self.assertEqual(o, v)
+
+    def test_script_module_const_submodule_fail(self):
+        class Sub(torch.jit.ScriptModule):
+            def __init__(self):
+                super(Sub, self).__init__(False)
+                self.weight = nn.Parameter(torch.randn(2))
+
+            @torch.jit.script_method
+            def forward(self, thing):
+                return self.weight + thing
+
+        class M(torch.jit.ScriptModule):
+            def __init__(self):
+                super(M, self).__init__(False)
+                self.mods = [Sub() for _ in range(10)]
+
+            @torch.jit.script_method
+            def forward(self):
+                for _ in self.mods:
+                    print(1)
+                return 4
+
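+        # a plain Python list of submodules is not supported; the error points to __constants__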
+        with self.assertRaisesRegex(RuntimeError, "did you forget to add it __constants__"):
+            M()
+
+    class DerivedStateModule(torch.jit.ScriptModule):
+        def __init__(self):
+            super(TestScript.DerivedStateModule, self).__init__()
+            self.param = torch.nn.Parameter(torch.ones(3, 4, dtype=torch.float))
+            self.register_buffer('derived', torch.neg(self.param).detach())
+
+            # This is a flag so we can test that the pack method was called
+            self.register_buffer('pack_called', torch.zeros(1, dtype=torch.long))
+            # This is a flag so we can test that the unpack method was called
+            self.register_buffer('unpack_called', torch.zeros(1, dtype=torch.long))
+
+        @torch.jit.script_method
+        def _pack(self):
+            self.pack_called.set_(torch.ones(1, dtype=torch.long))
+            self.derived.set_(torch.rand(1, dtype=torch.float).detach())
+
+        @torch.jit.script_method
+        def _unpack(self):
+            self.unpack_called.set_(torch.ones(1, dtype=torch.long))
+            self.derived.set_(torch.neg(self.param).detach())
+
+        @torch.jit.script_method
+        def forward(self, x):
+            return x + self.derived
+
+    def test_pack_unpack_state(self):
+        sm = TestScript.DerivedStateModule()
+        x = torch.rand(3, 4, dtype=torch.float)
+        torch.testing.assert_allclose(sm(x), x + torch.neg(torch.ones(3, 4, dtype=torch.float)))
+
+        # Test save path
+        self.assertFalse(sm.pack_called.item())
+        self.assertFalse(sm.unpack_called.item())
+        sm.apply(lambda s: s._pack())
+        imported = self.getExportImportCopy(sm)
+        sm.apply(lambda s: s._unpack())
+        imported.apply(lambda s: s._unpack())
+        # ensure pack was called before serialization
+        self.assertTrue(sm.pack_called.item())
+        # ensure unpack was called after serialization so as to leave the module in an initialized state
+        self.assertTrue(sm.unpack_called.item())
+
+        torch.testing.assert_allclose(sm.derived, torch.neg(sm.param))
+
+        # Test load paths
+        self.assertTrue(imported.unpack_called.item())
+        torch.testing.assert_allclose(imported(x), x + torch.neg(torch.ones(3, 4, dtype=torch.float)))
+
+    def test_pack_unpack_nested(self):
+        class SubSubMod(torch.jit.ScriptModule):
+            def __init__(self):
+                super(SubSubMod, self).__init__()
+                self.register_buffer('buf', torch.ones(3, 4) * 3)
+
+            @torch.jit.script_method
+            def _pack(self):
+                self.buf.set_(torch.zeros(1, dtype=torch.double))
+
+            @torch.jit.script_method
+            def _unpack(self):
+                self.buf.set_(torch.ones(3, 4, dtype=torch.double) * 3)
+
+            @torch.jit.script_method
+            def forward(self, x):
+                return x + self.buf
+
+        class SubMod(torch.jit.ScriptModule):
+            def __init__(self):
+                super(SubMod, self).__init__()
+                self.register_buffer('buf', torch.ones(3, 4) * 2)
+                self.ssm = SubSubMod()
+
+            @torch.jit.script_method
+            def _pack(self):
+                self.buf.set_(torch.zeros(1, dtype=torch.double))
+
+            @torch.jit.script_method
+            def _unpack(self):
+                self.buf.set_(torch.ones(3, 4, dtype=torch.double) * 2)
+
+            @torch.jit.script_method
+            def forward(self, x):
+                return self.ssm(x + self.buf)
+
+        class Mod(torch.jit.ScriptModule):
+            def __init__(self):
+                super(Mod, self).__init__()
+                self.submod = SubMod()
+                self.register_buffer('buf', torch.ones(3, 4) * 1)
+
+            @torch.jit.script_method
+            def _pack(self):
+                self.buf.set_(torch.zeros(1, dtype=torch.double))
+
+            @torch.jit.script_method
+            def _unpack(self):
+                self.buf.set_(torch.ones(3, 4, dtype=torch.double))
+
+            @torch.jit.script_method
+            def forward(self, x):
+                return self.submod(x + self.buf)
+
+        m = Mod()
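+        # the unpacked buffers hold 1, 2, and 3, so the nested forward adds 6 to the input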
+        torch.testing.assert_allclose(m(torch.zeros(3, 4)), torch.ones(3, 4) * 6)
+        m.apply(lambda s: s._pack())
+        torch.testing.assert_allclose(m(torch.zeros(3, 4)), torch.zeros(3, 4))
+        m.apply(lambda s: s._unpack())
+        torch.testing.assert_allclose(m(torch.zeros(3, 4)), torch.ones(3, 4) * 6)
+
+    def test_script_module_not_tuple(self):
         class M(torch.jit.ScriptModule):
-            __constants__ = ['b']
+            __constants__ = ['mods']
 
             def __init__(self):
                 super(M, self).__init__(False)
-                self.b = [1, 2, 3, 4]
+                self.mods = 1
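+                # an integer constant cannot be iterated over, so compiling forward fails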
 
             @torch.jit.script_method
-            def forward(self):
-                sum = 0
-                for i in self.b:
-                    sum += i
-                return sum
-
-        m = M()
-        self.assertEqual(m(), 10)
+            def forward(self, v):
+                for m in self.mods:
+                    print(m)
+                return v
+        with self.assertRaisesRegex(RuntimeError, "cannot be used as a tuple"):
+            M()
 
-    def test_script_module_for2(self):
+    def test_script_sequential_for(self):
         class Sub(torch.jit.ScriptModule):
             def __init__(self):
                 super(Sub, self).__init__(False)
@@ -5385,7 +4980,7 @@ a")
 
             def __init__(self):
                 super(M, self).__init__(False)
-                self.mods = nn.ModuleList([Sub() for i in range(10)])
+                self.mods = nn.Sequential(Sub(), Sub(), Sub())
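+                # a constant nn.Sequential can be iterated over or called directly (see forward2)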
 
             @torch.jit.script_method
             def forward(self, v):
@@ -5393,6 +4988,10 @@ a")
                     v = m(v)
                 return v
 
+            @torch.jit.script_method
+            def forward2(self, v):
+                return self.mods(v)
+
         i = torch.Tensor(2)
         m = M()
         o = m(i)
@@ -5401,7 +5000,10 @@ a")
             v = sub(v)
         self.assertEqual(o, v)
 
-    def test_script_module_const_submodule_fail(self):
+        o2 = m.forward2(i)
+        self.assertEqual(o2, v)
+
+    def test_script_sequential_multi_output_fail(self):
         class Sub(torch.jit.ScriptModule):
             def __init__(self):
                 super(Sub, self).__init__(False)
                 self.weight = nn.Parameter(torch.randn(2))
 
             @torch.jit.script_method
             def forward(self, thing):
                 return self.weight + thing
 
+        class ReturnMulti(torch.jit.ScriptModule):
+            def __init__(self):
+                super(ReturnMulti, self).__init__(False)
+
+            @torch.jit.script_method
+            def forward(self, x):
+                return x, x, x
+
+        class HaveSequential(torch.jit.ScriptModule):
+            __constants__ = ['someseq']
+
+            def __init__(self):
+                super(HaveSequential, self).__init__(False)
+                self.someseq = nn.Sequential(
+                    Sub(),
+                    ReturnMulti(),
+                    Sub()
+                )
+
+            @torch.jit.script_method
+            def forward(self, x):
+                return self.someseq(x)
+
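+        # ReturnMulti yields a 3-tuple, which the following Sub cannot accept as its single Tensor input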
+        with self.assertRaisesRegex(RuntimeError, "(Tensor, Tensor, Tensor)"):
+            hs = HaveSequential()
+            i = torch.Tensor(2)
+            hs(i)
+
+    def test_constant_as_attr(self):
         class M(torch.jit.ScriptModule):
+            __constants__ = ['dim']
+
             def __init__(self):
                 super(M, self).__init__(False)
-                self.mods = [Sub() for _ in range(10)]
+                self.dim = 1
 
             @torch.jit.script_method
-            def forward(self):
-                for _ in self.mods:
-                    print(1)
-                return 4
+            def forward(self, v):
+                return torch.cat([v, v, v], dim=self.dim)
+        v = torch.zeros(1, 1)
+        self.assertEqual(torch.cat([v, v, v], dim=1), M()(v))
 
-        with self.assertRaisesRegex(RuntimeError, "did you forget to add it __constants__"):
-            M()
+    class StarTestSumStarred(torch.nn.Module):
+        def __init__(self):
+            super(TestScript.StarTestSumStarred, self).__init__()
 
-    class DerivedStateModule(torch.jit.ScriptModule):
+        def forward(self, *inputs):
+            output = inputs[0]
+            for i in range(1, len(inputs)):
+                output += inputs[i]
+            return output
+
+    class StarTestReturnThree(torch.nn.Module):
         def __init__(self):
-            super(TestScript.DerivedStateModule, self).__init__()
-            self.param = torch.nn.Parameter(torch.ones(3, 4, dtype=torch.float))
-            self.register_buffer('derived', torch.neg(self.param).detach())
+            super(TestScript.StarTestReturnThree, self).__init__()
 
-            # This is a flag so we can test that the pack method was called
-            self.register_buffer('pack_called', torch.zeros(1, dtype=torch.long))
-            # This is a flag so we can test that the unpack method was called
-            self.register_buffer('unpack_called', torch.zeros(1, dtype=torch.long))
+        def forward(self, rep):
+            return rep, rep, rep
 
-        @torch.jit.script_method
-        def _pack(self):
-            self.pack_called.set_(torch.ones(1, dtype=torch.long))
-            self.derived.set_(torch.rand(1, dtype=torch.float).detach())
+    def test_script_star_expr(self):
 
-        @torch.jit.script_method
-        def _unpack(self):
-            self.unpack_called.set_(torch.ones(1, dtype=torch.long))
-            self.derived.set_(torch.neg(self.param).detach())
+        class M2(torch.jit.ScriptModule):
+            def __init__(self):
+                super(M2, self).__init__(True)
+                self.m = torch.jit.trace(TestScript.StarTestSumStarred(),
+                                         (torch.ones(4, 3), torch.ones(4, 3), torch.ones(4, 3)))
+                self.g = torch.jit.trace(TestScript.StarTestReturnThree(), torch.ones(4, 3))
 
-        @torch.jit.script_method
-        def forward(self, x):
-            return x + self.derived
+            @torch.jit.script_method
+            def forward(self, rep):
+                tup = self.g(rep)
+                return self.m(*tup)
 
-    def test_pack_unpack_state(self):
-        sm = TestScript.DerivedStateModule()
-        x = torch.rand(3, 4, dtype=torch.float)
-        torch.testing.assert_allclose(sm(x), x + torch.neg(torch.ones(3, 4, dtype=torch.float)))
+        m = M2()
+        self.assertEqual(m(torch.zeros(4, 3)), 3 * torch.zeros(4, 3))
 
-        # Test save path
-        self.assertFalse(sm.pack_called.item())
-        self.assertFalse(sm.unpack_called.item())
-        sm.apply(lambda s: s._pack())
-        imported = self.getExportImportCopy(sm)
-        sm.apply(lambda s: s._unpack())
-        imported.apply(lambda s: s._unpack())
-        # ensure pack was called before serialization
-        self.assertTrue(sm.pack_called.item())
-        # ensure unpack was called after serialization so as to leave the module in an initialized state
-        self.assertTrue(sm.unpack_called.item())
+    def test_script_star_expr_string(self):
+        class M2(torch.jit.ScriptModule):
+            def __init__(self):
+                super(M2, self).__init__(True)
+                self.m = torch.jit.trace(TestScript.StarTestSumStarred(),
+                                         (torch.ones(4, 3), torch.ones(4, 3), torch.ones(4, 3)))
+                self.g = torch.jit.trace(TestScript.StarTestReturnThree(), torch.ones(4, 3))
 
-        torch.testing.assert_allclose(sm.derived, torch.neg(sm.param))
+                self.define('''
+            def forward(self, rep):
+                tup = self.g(rep)
+                return self.m(*tup)
+                ''')
 
-        # Test load paths
-        self.assertTrue(imported.unpack_called.item())
-        torch.testing.assert_allclose(imported(x), x + torch.neg(torch.ones(3, 4, dtype=torch.float)))
+        m = M2()
+        self.assertEqual(m(torch.zeros(4, 3)), 3 * torch.zeros(4, 3))
 
-    def test_pack_unpack_nested(self):
-        class SubSubMod(torch.jit.ScriptModule):
+    class StarTestSumAndReturnThree(torch.nn.Module):
+        def __init__(self):
+            super(TestScript.StarTestSumAndReturnThree, self).__init__()
+
+        def forward(self, *inputs):
+            output = inputs[0]
+            for i in range(1, len(inputs)):
+                output += inputs[i]
+            return output, output, output
+
+    def test_script_star_assign(self):
+        class M2(torch.jit.ScriptModule):
             def __init__(self):
-                super(SubSubMod, self).__init__()
-                self.register_buffer('buf', torch.ones(3, 4) * 3)
+                super(M2, self).__init__(True)
+                self.g = torch.jit.trace(TestScript.StarTestSumAndReturnThree(), torch.ones(4, 3))
+                self.define('''
+            def forward(self, rep):
+                head, *tail = self.g(rep)
+                return head
+                ''')
 
-            @torch.jit.script_method
-            def _pack(self):
-                self.buf.set_(torch.zeros(1, dtype=torch.double))
+        m = M2()
+        self.assertEqual(m(torch.zeros(4, 3)), 3 * torch.zeros(4, 3))
 
-            @torch.jit.script_method
-            def _unpack(self):
-                self.buf.set_(torch.ones(3, 4, dtype=torch.double) * 3)
+    def test_script_module_star_assign2(self):
+        class M2(torch.jit.ScriptModule):
+            def __init__(self):
+                super(M2, self).__init__(True)
+                self.g = torch.jit.trace(
+                    TestScript.StarTestSumAndReturnThree(),
+                    (torch.ones(4, 3), torch.ones(4, 3), torch.ones(4, 3)),
+                    _force_outplace=True)
+                self.define('''
+            def forward(self, rep):
+                *head, tail = self.g(rep, rep, rep)
+                return tail
+                ''')
+
+        m = M2()
+        self.assertEqual(m(torch.ones(4, 3)), 3 * torch.ones(4, 3))
+
+    def test_script_module_star_assign2_inplace(self):
+        class M2(torch.jit.ScriptModule):
+            def __init__(self):
+                super(M2, self).__init__(True)
+                self.g = torch.jit.trace(
+                    TestScript.StarTestSumAndReturnThree(),
+                    (torch.ones(4, 3), torch.ones(4, 3), torch.ones(4, 3)),
+                    _force_outplace=False)
+                self.define('''
+            def forward(self, rep):
+                *head, tail = self.g(rep, rep, rep)
+                return tail
+                ''')
+
+        m = M2()
+        # since forward() passes three aliases of the input `rep` to
+        # StarTestSumAndReturnThree(), the in-place adds behave differently
+        # from the out-of-place version above.
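+        # (the first in-place add already turns `rep` into 2 * ones, so the
+        # final sum is 4 * ones instead of 3 * ones)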
+        self.assertEqual(m(torch.ones(4, 3)), 4 * torch.ones(4, 3))
+
+    def test_script_module_star_assign_fail_pythonop(self):
+
+        with self.assertRaisesRegex(RuntimeError, "cannot be used as a tuple"):
+            class M2(torch.jit.ScriptModule):
+                def __init__(self):
+                    super(M2, self).__init__(True)
+
+                    def myfunc():
+                        return torch.zeros(1, 2, 3), torch.zeros(1, 2, 3)
+
+                    self.define('''
+                def forward(self, rep):
+                    a, *b = myfunc()
+                    return a
+                    ''')
+
+            m = M2()
+            m(torch.zeros(4, 3))
+
+    def test_script_module_star_assign_fail_builtin(self):
+        with self.assertRaisesRegex(RuntimeError, "cannot be used as a tuple"):
+            class M2(torch.jit.ScriptModule):
+                def __init__(self):
+                    super(M2, self).__init__(True)
+
+                    self.define('''
+                def forward(self, rep):
+                    a, *b = torch.neg(rep)
+                    return a
+                    ''')
+
+            m = M2()
+            m(torch.zeros(4, 3))
 
-            @torch.jit.script_method
-            def forward(self, x):
-                return x + self.buf
+    def test_pack_padded_pad_packed_trace(self):
+        from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
+        T, B, C = 3, 5, 7
 
-        class SubMod(torch.jit.ScriptModule):
+        class PadPackedWrapper(torch.nn.Module):
             def __init__(self):
-                super(SubMod, self).__init__()
-                self.register_buffer('buf', torch.ones(3, 4) * 2)
-                self.ssm = SubSubMod()
+                super(PadPackedWrapper, self).__init__()
 
-            @torch.jit.script_method
-            def _pack(self):
-                self.buf.set_(torch.zeros(1, dtype=torch.double))
+            def forward(self, x, seq_lens):
+                x = pack_padded_sequence(x, seq_lens)
+                x, _ = pad_packed_sequence(x)
+                return x
 
-            @torch.jit.script_method
-            def _unpack(self):
-                self.buf.set_(torch.ones(3, 4, dtype=torch.double) * 2)
+        x = np.ones((T, B, C))
+        seq_lens = np.array([3, 3, 2, 2, 1], dtype=np.int32)
+        # set padding value so we can test equivalence
+        for b in range(B):
+            if seq_lens[b] < T:
+                x[seq_lens[b]:, b, :] = 0
+        seq_lens = torch.from_numpy(seq_lens)
+        x = torch.autograd.Variable(torch.from_numpy(x), requires_grad=True)
 
-            @torch.jit.script_method
-            def forward(self, x):
-                return self.ssm(x + self.buf)
+        m = PadPackedWrapper()
+        m_traced = torch.jit.trace(m, (x, seq_lens,))
 
-        class Mod(torch.jit.ScriptModule):
-            def __init__(self):
-                super(Mod, self).__init__()
-                self.submod = SubMod()
-                self.register_buffer('buf', torch.ones(3, 4) * 1)
+        y = m(x, seq_lens)
+        loss = torch.sum(y)
+        loss.backward()
+        grad = x.grad.clone()
+        x.grad.zero_()
 
-            @torch.jit.script_method
-            def _pack(self):
-                self.buf.set_(torch.zeros(1, dtype=torch.double))
+        y_traced = m_traced(x, seq_lens)
+        loss_traced = torch.sum(y_traced)
+        loss_traced.backward()
+        grad_traced = x.grad.clone()
 
-            @torch.jit.script_method
-            def _unpack(self):
-                self.buf.set_(torch.ones(3, 4, dtype=torch.double))
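+        # the padded positions in x were zeroed above, so the pack/pad round
+        # trip should reproduce x exactly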
+        self.assertEqual(y_traced, x)
+        self.assertEqual(y_traced, y)
+        self.assertEqual(grad, grad_traced)
 
-            @torch.jit.script_method
-            def forward(self, x):
-                return self.submod(x + self.buf)
+        f = io.BytesIO()
+        torch.onnx._export(m, (x, seq_lens), f, verbose=False)
 
-        m = Mod()
-        torch.testing.assert_allclose(m(torch.zeros(3, 4)), torch.ones(3, 4) * 6)
-        m.apply(lambda s: s._pack())
-        torch.testing.assert_allclose(m(torch.zeros(3, 4)), torch.zeros(3, 4))
-        m.apply(lambda s: s._unpack())
-        torch.testing.assert_allclose(m(torch.zeros(3, 4)), torch.ones(3, 4) * 6)
+    def test_script_outputs(self):
+        with self.assertRaisesRegex(RuntimeError, "cannot be used as a tuple"):
+            @torch.jit.script
+            def foo(a):
+                c, d = a + a
+                return c + d
 
-    def test_script_module_not_tuple(self):
-        class M(torch.jit.ScriptModule):
-            __constants__ = ['mods']
+        @torch.jit.script
+        def return3():
+            return 1, 2, 3
 
-            def __init__(self):
-                super(M, self).__init__(False)
-                self.mods = 1
+        with self.assertRaisesRegex(RuntimeError, "too many values to unpack"):
+            @torch.jit.script
+            def bind2():
+                a, b = return3()
+                print(a)
+                print(b)
 
-            @torch.jit.script_method
-            def forward(self, v):
-                for m in self.mods:
-                    print(m)
-                return v
-        with self.assertRaisesRegex(RuntimeError, "cannot be used as a tuple"):
-            M()
+    @unittest.skipIf(not RUN_CUDA, "requires CUDA")
+    def test_script_get_device_cuda(self):
+        @torch.jit.script
+        def foo(a):
+            return a.get_device()
 
-    def test_script_sequential_for(self):
-        class Sub(torch.jit.ScriptModule):
-            def __init__(self):
-                super(Sub, self).__init__(False)
-                self.weight = nn.Parameter(torch.randn(2))
+        v = torch.randn(1, device='cuda')
+        self.assertEqual(foo(v), 0)
 
-            @torch.jit.script_method
-            def forward(self, thing):
-                return self.weight + thing
+    def test_script_chunk(self):
+        @torch.jit.script
+        def foo(a):
+            b, c = torch.chunk(a, dim=0, chunks=2)
+            return b
+        v = torch.rand(10, 3)
+        self.assertEqual(torch.chunk(v, dim=0, chunks=2)[0], foo(v))
 
-        class M(torch.jit.ScriptModule):
-            __constants__ = ['mods']
+    def test_rnn_trace_override(self):
+        from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
+        num_layers = 3
+        T, B, C = 11, 5, 7
 
-            def __init__(self):
-                super(M, self).__init__(False)
-                self.mods = nn.Sequential(Sub(), Sub(), Sub())
+        class RNNTraceWrapper(torch.nn.Module):
+            def __init__(self, cell_type):
+                super(RNNTraceWrapper, self).__init__()
+                if cell_type == 'RNN':
+                    self.rnn = torch.nn.RNN(input_size=C, hidden_size=C, num_layers=num_layers)
+                elif cell_type == 'LSTM':
+                    self.rnn = torch.nn.LSTM(input_size=C, hidden_size=C, num_layers=num_layers)
+                elif cell_type == 'GRU':
+                    self.rnn = torch.nn.GRU(input_size=C, hidden_size=C, num_layers=num_layers)
 
-            @torch.jit.script_method
-            def forward(self, v):
-                for m in self.mods:
-                    v = m(v)
-                return v
+            def forward(self, x, seq_lens):
+                x = pack_padded_sequence(x, seq_lens)
+                x, _ = self.rnn(x)
+                x, _ = pad_packed_sequence(x)
+                return x
 
-            @torch.jit.script_method
-            def forward2(self, v):
-                return self.mods(v)
+        for cell_type in ['RNN', 'LSTM', 'GRU']:
+            x = torch.ones(T, B, C, requires_grad=True)
+            seq_lens = torch.from_numpy(np.array([11, 3, 2, 2, 1], dtype=np.int32))
 
-        i = torch.Tensor(2)
-        m = M()
-        o = m(i)
-        v = i
-        for sub in m.mods:
-            v = sub(v)
-        self.assertEqual(o, v)
+            m = RNNTraceWrapper(cell_type)
+            m_traced = torch.jit.trace(m, (x, seq_lens,))
 
-        o2 = m.forward2(i)
-        self.assertEqual(o2, v)
+            y = m(x, seq_lens)
+            loss = torch.sum(y)
+            loss.backward()
+            grad = x.grad.clone()
+            x.grad.zero_()
 
-    def test_script_sequential_multi_output_fail(self):
-        class Sub(torch.jit.ScriptModule):
-            def __init__(self):
-                super(Sub, self).__init__(False)
-                self.weight = nn.Parameter(torch.randn(2))
+            y_traced = m_traced(x, seq_lens)
+            loss_traced = torch.sum(y_traced)
+            loss_traced.backward()
+            grad_traced = x.grad.clone()
 
-            @torch.jit.script_method
-            def forward(self, thing):
-                return self.weight + thing
+            self.assertEqual(y_traced, y)
+            self.assertEqual(grad, grad_traced)
 
-        class ReturnMulti(torch.jit.ScriptModule):
-            def __init__(self):
-                super(ReturnMulti, self).__init__(False)
+            f = io.BytesIO()
+            torch.onnx._export(m, (x, seq_lens), f, verbose=False)
 
-            @torch.jit.script_method
-            def forward(self, x):
-                return x, x, x
+    def test_python_call_non_tensor(self):
+        def foo(a, b, c):
+            # type: (Tensor, int, Tuple[Tensor, int]) -> Tuple[int, Tensor]
+            d, e = c
+            return b + e, a + d
 
-        class HaveSequential(torch.jit.ScriptModule):
-            __constants__ = ['someseq']
+        @torch.jit.script
+        def bar():
+            x = torch.ones(3, 4)
+            a, b = foo(x, 3, (x, 3))
+            return a, b
 
-            def __init__(self):
-                super(HaveSequential, self).__init__(False)
-                self.someseq = nn.Sequential(
-                    Sub(),
-                    ReturnMulti(),
-                    Sub()
-                )
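+        # foo returns (b + e, a + d) = (3 + 3, x + x), so bar() should yield
+        # (6, torch.ones(3, 4) + 1)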
+        self.assertEqual((6, torch.ones(3, 4) + 1), bar())
 
-            @torch.jit.script_method
-            def forward(self, x):
-                return self.someseq(x)
+    def test_python_call_non_tensor_wrong(self):
+        with self.assertRaisesRegex(RuntimeError, r"but instead got value of type tuple"):
+            def foo():
+                # type: () -> Tensor
+                return ((3, 4),)
 
-        with self.assertRaisesRegex(RuntimeError, "(Tensor, Tensor, Tensor)"):
-            hs = HaveSequential()
-            i = torch.Tensor(2)
-            hs(i)
+            @torch.jit.script
+            def bar():
+                return foo()
 
-    def test_constant_as_attr(self):
-        class M(torch.jit.ScriptModule):
-            __constants__ = ['dim']
+            bar()
 
-            def __init__(self):
-                super(M, self).__init__(False)
-                self.dim = 1
+    def test_tuples(self):
+        def foo(i):
+            a = (i + 4, i * 2)
+            c = a
+            # some nonsense with if-statements and loops to check
+            # that tuple lowering doesn't fail
+            if True:
+                c = (i * 9, i + 1)
+            t0, t1 = c
+            while False:
+                t0, t1 = c
+                c = (t1, t0)
+            x = (1,)
+            y = 1,
+            return t0, x, y
 
-            @torch.jit.script_method
-            def forward(self, v):
-                return torch.cat([v, v, v], dim=self.dim)
-        v = torch.zeros(1, 1)
-        self.assertEqual(torch.cat([v, v, v], dim=1), M()(v))
+        v = torch.rand(10, 3)
+        self.checkScript(foo, (v,))
 
-    class StarTestSumStarred(torch.nn.Module):
-        def __init__(self):
-            super(TestScript.StarTestSumStarred, self).__init__()
+        with self.assertRaisesRegex(RuntimeError, r"variable 'a' previously has type \(Tensor, Tensor\)"):
+            @torch.jit.script
+            def mixtypes(x):
+                a = (x, x)
+                if True:
+                    a = 4
 
-        def forward(self, *inputs):
-            output = inputs[0]
-            for i in range(1, len(inputs)):
-                output += inputs[i]
-            return output
+    def test_if_tuple_sizes(self):
+        with self.assertRaisesRegex(RuntimeError, "Type mismatch"):
+            @torch.jit.script
+            def diff_tuple_sizes(x):
+                if False:
+                    c0 = ((x, x), (x, x, x))
+                else:
+                    c0 = ((x, x, x), (x, x))
+                return c0
 
-    class StarTestReturnThree(torch.nn.Module):
-        def __init__(self):
-            super(TestScript.StarTestReturnThree, self).__init__()
+    def test_if_different_type(self):
+        with self.assertRaisesRegex(RuntimeError, "Type mismatch: c0 is set to type int "
+                                    "in the true branch and type float in the false branch:"):
+            @torch.jit.script
+            def diff_type_used():
+                if False:
+                    c0 = 1
+                else:
+                    c0 = 1.0
+                return c0
 
-        def forward(self, rep):
-            return rep, rep, rep
+        with self.assertRaisesRegex(RuntimeError, "variable 'c0' previously has type float"):
+            @torch.jit.script
+            def diff_existing_type(x):
+                c0 = 1.0
+                if False:
+                    c0 = 1
+                    print(x)
+                return x
 
-    def test_script_star_expr(self):
+        @torch.jit.script
+        def diff_type_unused():
+            if True:
+                c0 = 1
+                print(c0)
+            else:
+                c0 = 1.0
+                print(c0)
+            return 1
 
-        class M2(torch.jit.ScriptModule):
-            def __init__(self):
-                super(M2, self).__init__(True)
-                self.m = torch.jit.trace(TestScript.StarTestSumStarred(),
-                                         (torch.ones(4, 3), torch.ones(4, 3), torch.ones(4, 3)))
-                self.g = torch.jit.trace(TestScript.StarTestReturnThree(), torch.ones(4, 3))
+    def test_if_list(self):
+        # testing that lists of different lengths don't throw an error
+        @torch.jit.script
+        def test_list(x):
+            if True:
+                c = [x, x]
+            else:
+                c = [x, x, x]
+            return torch.cat(c)
 
-            @torch.jit.script_method
-            def forward(self, rep):
-                tup = self.g(rep)
-                return self.m(*tup)
+        b = torch.zeros(2, 4)
+        test_list.graph.propagate_shapes((b,), False)
+        self.assertExpected(canonical(test_list.graph))
 
-        m = M2()
-        self.assertEqual(m(torch.zeros(4, 3)), 3 * torch.zeros(4, 3))
+    def test_if_supertype(self):
+        @torch.jit.script
+        def tensor_unifying(x, y, z):
 
-    def test_script_star_expr_string(self):
-        class M2(torch.jit.ScriptModule):
-            def __init__(self):
-                super(M2, self).__init__(True)
-                self.m = torch.jit.trace(TestScript.StarTestSumStarred(),
-                                         (torch.ones(4, 3), torch.ones(4, 3), torch.ones(4, 3)))
-                self.g = torch.jit.trace(TestScript.StarTestReturnThree(), torch.ones(4, 3))
+            # testing that the dynamic type is appropriately set for y and z
+            if True:
+                x, y, z = x, y, z
+            else:
+                x, y, z = x, x, y
 
-                self.define('''
-            def forward(self, rep):
-                tup = self.g(rep)
-                return self.m(*tup)
-                ''')
+            return x, y, z
 
-        m = M2()
-        self.assertEqual(m(torch.zeros(4, 3)), 3 * torch.zeros(4, 3))
+        a = torch.zeros(2, 2, dtype=torch.float)
+        b = torch.zeros(2, 4, dtype=torch.long)
+        c = torch.zeros(2, 4, dtype=torch.float)
 
-    class StarTestSumAndReturnThree(torch.nn.Module):
-        def __init__(self):
-            super(TestScript.StarTestSumAndReturnThree, self).__init__()
+        tensor_unifying.graph.propagate_shapes((a, b, c), False)
+        self.assertExpected(canonical(tensor_unifying.graph))
 
-        def forward(self, *inputs):
-            output = inputs[0]
-            for i in range(1, len(inputs)):
-                output += inputs[i]
-            return output, output, output
+    def test_type_annotations_repeated_list(self):
+        @torch.jit.script
+        def float_fn(x, y):
+            # type: (float, BroadcastingList3[float]) -> List[float]
+            return y
+        self.assertEqual(float_fn(2.0, 1.0), float_fn(2.0, [1.0, 1.0, 1.0]))
+        self.assertEqual(float_fn(2.0, 1.0), float_fn(2.0, (1.0, 1.0, 1.0)))
 
-    def test_script_star_assign(self):
-        class M2(torch.jit.ScriptModule):
-            def __init__(self):
-                super(M2, self).__init__(True)
-                self.g = torch.jit.trace(TestScript.StarTestSumAndReturnThree(), torch.ones(4, 3))
-                self.define('''
-            def forward(self, rep):
-                head, *tail = self.g(rep)
-                return head
-                ''')
+        @torch.jit.script
+        def float_fn_call():
+            print(float_fn(1.0, 1.0))
+            print(float_fn(1.0, (1.0, 1.0, 1.0)))
 
-        m = M2()
-        self.assertEqual(m(torch.zeros(4, 3)), 3 * torch.zeros(4, 3))
+        @torch.jit.script
+        def int_fn(x):
+            # type: (BroadcastingList3[int]) -> List[int]
+            return x
+        self.assertEqual(int_fn(1), int_fn([1, 1, 1]))
+        self.assertEqual(int_fn(1), int_fn((1, 1, 1)))
 
-    def test_script_module_star_assign2(self):
-        class M2(torch.jit.ScriptModule):
-            def __init__(self):
-                super(M2, self).__init__(True)
-                self.g = torch.jit.trace(
-                    TestScript.StarTestSumAndReturnThree(),
-                    (torch.ones(4, 3), torch.ones(4, 3), torch.ones(4, 3)),
-                    _force_outplace=True)
-                self.define('''
-            def forward(self, rep):
-                *head, tail = self.g(rep, rep, rep)
-                return tail
-                ''')
+        @torch.jit.script
+        def int_fn_call():
+            print(int_fn(1))
+            print(int_fn((1, 1, 1)))
 
-        m = M2()
-        self.assertEqual(m(torch.ones(4, 3)), 3 * torch.ones(4, 3))
+        with self.assertRaisesRegex(RuntimeError, "expected number"):
+            @torch.jit.script
+            def fn(x):
+                # type: (BroadcastingListx[int]) -> List[int]
+                return x
 
-    def test_script_module_star_assign2_inplace(self):
-        class M2(torch.jit.ScriptModule):
-            def __init__(self):
-                super(M2, self).__init__(True)
-                self.g = torch.jit.trace(
-                    TestScript.StarTestSumAndReturnThree(),
-                    (torch.ones(4, 3), torch.ones(4, 3), torch.ones(4, 3)),
-                    _force_outplace=False)
-                self.define('''
-            def forward(self, rep):
-                *head, tail = self.g(rep, rep, rep)
-                return tail
-                ''')
+        with self.assertRaisesRegex(RuntimeError, "Unknown type constructor"):
+            @torch.jit.script
+            def nested(x, y):
+                # type: (int, Tuple[int, int[2]]) -> List[int]
+                return x
 
-        m = M2()
-        # since forward() makes three aliases to the input `rep` before passing
-        # it to StarTestSumAndReturnThree(), in-place behavior will be different
-        # than the above out of place.
-        self.assertEqual(m(torch.ones(4, 3)), 4 * torch.ones(4, 3))
+    def test_ntuple_builtins(self):
+        from torch.nn.modules.utils import _single, _pair, _triple, _quadruple
 
-    def test_script_module_star_assign_fail_pythonop(self):
+        def test_ints():
+            return _single(1), _pair(2), _triple(3), _quadruple(4)
 
-        with self.assertRaisesRegex(RuntimeError, "cannot be used as a tuple"):
-            class M2(torch.jit.ScriptModule):
-                def __init__(self):
-                    super(M2, self).__init__(True)
+        def test_floats():
+            return _single(1), _pair(2.1), _triple(3.1), _quadruple(4.1)
 
-                    def myfunc():
-                        return torch.zeros(1, 2, 3), torch.zeros(1, 2, 3)
+        self.checkScript(test_ints, ())
+        self.checkScript(test_floats, ())
 
-                    self.define('''
-                def forward(self, rep):
-                    a, *b = myfunc()
-                    return a
-                    ''')
+    def test_embedding_renorm_grad_error(self):
+        # Testing that the builtin call to embedding_renorm_ correctly throws
+        # an error when .backward() is called on its input
 
-            m = M2()
-            m(torch.zeros(4, 3))
+        def embedding_norm(input, embedding_matrix, max_norm):
+            F.embedding(input, embedding_matrix, max_norm=0.01)
 
-    def test_script_module_star_assign_fail_builtin(self):
-        with self.assertRaisesRegex(RuntimeError, "cannot be used as a tuple"):
-            class M2(torch.jit.ScriptModule):
-                def __init__(self):
-                    super(M2, self).__init__(True)
+        @torch.jit.script
+        def embedding_norm_script(input, embedding_matrix, max_norm):
+            # type: (Tensor, Tensor, float)
+            F.embedding(input, embedding_matrix, max_norm=0.01)
 
-                    self.define('''
-                def forward(self, rep):
-                    a, *b = torch.neg(rep)
-                    return a
-                    ''')
+        for fun in [embedding_norm, embedding_norm_script]:
+            input = torch.tensor([[1, 2, 4, 5], [4, 3, 2, 9]])
+            embedding_matrix = torch.randn(10, 3)
 
-            m = M2()
-            m(torch.zeros(4, 3))
+            var1 = torch.randn(10, 3, requires_grad=True)
+            var2 = var1.detach().requires_grad_()
+            output1 = var1 * embedding_matrix
+            output2 = var2 * embedding_matrix
 
-    def test_pack_padded_pad_packed_trace(self):
-        from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
-        T, B, C = 3, 5, 7
+            output1.sum().backward()
 
-        class PadPackedWrapper(torch.nn.Module):
-            def __init__(self):
-                super(PadPackedWrapper, self).__init__()
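+            # F.embedding with max_norm renormalizes embedding_matrix in place
+            # (via embedding_renorm_), invalidating the graph output2 depends on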
+            ignore = F.embedding(input, embedding_matrix, max_norm=0.01)
+            with self.assertRaisesRegex(RuntimeError, "modified"):
+                output2.sum().backward()
+
+    def test_type_annotations(self):
+        def fn(x, y):
+            # type: (Tensor, Tensor) -> Tuple[Tensor, Tensor, Tensor]
+            return x, x * 2, x * 3
 
-            def forward(self, x, seq_lens):
-                x = pack_padded_sequence(x, seq_lens)
-                x, _ = pad_packed_sequence(x)
-                return x
+        with self.assertRaisesRegex(RuntimeError, r"need 4 values .* found only 3"):
+            @torch.jit.script
+            def script_fn(x):
+                x, y, z, w = fn(x, x)
 
-        x = np.ones((T, B, C))
-        seq_lens = np.array([3, 3, 2, 2, 1], dtype=np.int32)
-        # set padding value so we can test equivalence
-        for b in range(B):
-            if seq_lens[b] < T:
-                x[seq_lens[b]:, b, :] = 0
-        seq_lens = torch.from_numpy(seq_lens)
-        x = torch.autograd.Variable(torch.from_numpy(x), requires_grad=True)
+        with self.assertRaisesRegex(RuntimeError, r"too many values .* need 2 but found 3"):
+            @torch.jit.script
+            def script_fn2(x):
+                x, y = fn(x, x)
 
-        m = PadPackedWrapper()
-        m_traced = torch.jit.trace(m, (x, seq_lens,))
+        def fn_unpack(x):
+            y, z, w = fn(x, x)
+            return y
 
-        y = m(x, seq_lens)
-        loss = torch.sum(y)
-        loss.backward()
-        grad = x.grad.clone()
-        x.grad.zero_()
+        def fn_index(x):
+            q = fn(x, x)
+            return x
 
-        y_traced = m_traced(x, seq_lens)
-        loss_traced = torch.sum(y_traced)
-        loss_traced.backward()
-        grad_traced = x.grad.clone()
+        def fn_string(str, strpair):
+            # type: (str, Tuple[str, str]) -> Tuple[str, int, str, str]
+            str1, str2 = strpair
+            return str, 2, str1, str2
 
-        self.assertEqual(y_traced, x)
-        self.assertEqual(y_traced, y)
-        self.assertEqual(grad, grad_traced)
+        x = torch.ones(2, 2)
+        self.checkScript(fn_unpack, (x,), optimize=True)
+        self.checkScript(fn_index, (x,), optimize=True)
+        self.checkScript(fn_string, ("1", ("3", "4")), optimize=True)
 
-        f = io.BytesIO()
-        torch.onnx._export(m, (x, seq_lens), f, verbose=False)
+    def test_type_annotations_varargs(self):
+        def fn_varargs(x, *args):
+            return args[0] if args else x
 
-    def test_script_outputs(self):
-        with self.assertRaisesRegex(RuntimeError, "cannot be used as a tuple"):
-            @torch.jit.script
-            def foo(a):
-                c, d = a + a
-                return c + d
+        def fn1(x, y, z):
+            return fn_varargs(x)
 
-        @torch.jit.script
-        def return3():
-            return 1, 2, 3
+        def fn2(x, y, z):
+            return fn_varargs(x, y)
 
-        with self.assertRaisesRegex(RuntimeError, "too many values to unpack"):
-            @torch.jit.script
-            def bind2():
-                a, b = return3()
-                print(a)
-                print(b)
+        def fn3(x, y, z):
+            return fn_varargs(x, y, z)
 
-    @unittest.skipIf(not RUN_CUDA, "requires CUDA")
-    def test_script_get_device_cuda(self):
-        @torch.jit.script
-        def foo(a):
-            return a.get_device()
+        x, y, z = [torch.randn(2, 2) for _ in range(3)]
+        self.checkScript(fn1, (x, y, z), optimize=True)
+        self.checkScript(fn2, (x, y, z), optimize=True)
+        self.checkScript(fn3, (x, y, z), optimize=True)
 
-        v = torch.randn(1, device='cuda')
-        self.assertEqual(foo(v), 0)
+    @unittest.skipIf(not PY35, "Python 3.5 needed")
+    def test_type_annotation_py3(self):
+        import importlib.util
 
-    def test_script_chunk(self):
-        @torch.jit.script
-        def foo(a):
-            b, c = torch.chunk(a, dim=0, chunks=2)
-            return b
-        v = torch.rand(10, 3)
-        self.assertEqual(torch.chunk(v, dim=0, chunks=2)[0], foo(v))
+        code = dedent("""
+        import torch
+        from torch import Tensor
+        from typing import Tuple
 
-    def test_rnn_trace_override(self):
-        from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
-        num_layers = 3
-        T, B, C = 11, 5, 7
+        def fn(x : torch.Tensor, y : Tensor, z) -> Tuple[Tensor, Tensor, Tensor]:
+            return (x, y + z, z)
+        """)
 
-        class RNNTraceWrapper(torch.nn.Module):
-            def __init__(self, cell_type):
-                super(RNNTraceWrapper, self).__init__()
-                if cell_type == 'RNN':
-                    self.rnn = torch.nn.RNN(input_size=C, hidden_size=C, num_layers=num_layers)
-                elif cell_type == 'LSTM':
-                    self.rnn = torch.nn.LSTM(input_size=C, hidden_size=C, num_layers=num_layers)
-                elif cell_type == 'GRU':
-                    self.rnn = torch.nn.GRU(input_size=C, hidden_size=C, num_layers=num_layers)
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            script_path = os.path.join(tmp_dir, 'script.py')
+            with open(script_path, 'w') as f:
+                f.write(code)
+            fn = get_fn('test_type_annotation_py3', script_path)
 
-            def forward(self, x, seq_lens):
-                x = pack_padded_sequence(x, seq_lens)
-                x, _ = self.rnn(x)
-                x, _ = pad_packed_sequence(x)
-                return x
+            with self.assertRaisesRegex(RuntimeError, r"expected a value of type Tensor for argument"
+                                                      r" '0' but found \(Tensor, Tensor\)"):
+                @torch.jit.script
+                def bad_fn(x):
+                    x, y = fn((x, x), x, x)
+                    return y
 
-        for cell_type in ['RNN', 'LSTM', 'GRU']:
-            x = torch.ones(T, B, C, requires_grad=True)
-            seq_lens = torch.from_numpy(np.array([11, 3, 2, 2, 1], dtype=np.int32))
+            with self.assertRaisesRegex(RuntimeError, r"too many values .* need 2 but found 3"):
+                @torch.jit.script
+                def bad_fn2(x):
+                    x, y = fn(x, x, x)
+                    return y
 
-            m = RNNTraceWrapper(cell_type)
-            m_traced = torch.jit.trace(m, (x, seq_lens,))
+            with self.assertRaisesRegex(RuntimeError, r"need 4 values .* found only 3"):
+                @torch.jit.script
+                def bad_fn3(x):
+                    x, y, z, w = fn(x, x, x)
+                    return y
 
-            y = m(x, seq_lens)
-            loss = torch.sum(y)
-            loss.backward()
-            grad = x.grad.clone()
-            x.grad.zero_()
+            def good_fn(x):
+                y, z, w = fn(x, x, x)
+                return y, z, w
 
-            y_traced = m_traced(x, seq_lens)
-            loss_traced = torch.sum(y_traced)
-            loss_traced.backward()
-            grad_traced = x.grad.clone()
+            self.checkScript(good_fn, (torch.ones(2, 2),), optimize=True)
 
-            self.assertEqual(y_traced, y)
-            self.assertEqual(grad, grad_traced)
+    def test_type_annotation_module(self):
+        class BaseModule(torch.jit.ScriptModule):
+            def foo(self, x):
+                # type: (Tensor) -> Tensor
+                return x + 1
 
-            f = io.BytesIO()
-            torch.onnx._export(m, (x, seq_lens), f, verbose=False)
+            def bar(self, x, y):
+                # type: (Tensor, Tensor) -> Tuple[Tensor, Tensor]
+                return x + y, y
 
-    def test_python_call_non_tensor(self):
-        def foo(a, b, c):
-            # type: (Tensor, int, Tuple[Tensor, int]) -> Tuple[int, Tensor]
-            d, e = c
-            return b + e, a + d
+            def baz(self, x, y):
+                return x
 
-        @torch.jit.script
-        def bar():
-            x = torch.ones(3, 4)
-            a, b = foo(x, 3, (x, 3))
-            return a, b
+        class ModuleTooMany(BaseModule):
+            @torch.jit.script_method
+            def method(self, x):
+                return self.foo(x, x)
 
-        self.assertEqual((6, torch.ones(3, 4) + 1), bar())
+        class ModuleTooFew(BaseModule):
+            @torch.jit.script_method
+            def method(self, x):
+                return self.bar(x)
 
-    def test_python_call_non_tensor_wrong(self):
-        with self.assertRaisesRegex(RuntimeError, r"but instead got value of type tuple"):
-            def foo():
-                # type: () -> Tensor
-                return ((3, 4),)
+        class ModuleTooManyAssign(BaseModule):
+            @torch.jit.script_method
+            def method(self, x):
+                y, z, w = self.bar(x, x)
+                return x
 
-            @torch.jit.script
-            def bar():
-                return foo()
+        class ModuleDefault(BaseModule):
+            @torch.jit.script_method
+            def method(self, x):
+                y = self.baz(x)
+                return x
 
-            bar()
+        with self.assertRaisesRegex(RuntimeError, "expected at most 1 arguments but found 2"):
+            ModuleTooMany()
+        with self.assertRaisesRegex(RuntimeError, "argument 1 not provided"):
+            ModuleTooFew()
+        with self.assertRaisesRegex(RuntimeError, "need 3 values .* found only 2"):
+            ModuleTooManyAssign()
+        with self.assertRaisesRegex(RuntimeError, "argument 1 not provided."):
+            ModuleDefault()
 
-    def test_tuples(self):
-        def foo(i):
-            a = (i + 4, i * 2)
-            c = a
-            # some nonsense with if-statements and loops to check
-            # that tuple lowering doesn't fail
-            if True:
-                c = (i * 9, i + 1)
-            t0, t1 = c
-            while False:
-                t0, t1 = c
-                c = (t1, t0)
-            x = (1,)
-            y = 1,
-            return t0, x, y
+    def test_script_define_order(self):
+        class M(torch.jit.ScriptModule):
+            def __init__(self):
+                pass
 
-        v = torch.rand(10, 3)
-        self.checkScript(foo, (v,))
+            @torch.jit.script_method
+            def call_foo(self, input):
+                return self.foo(input)
 
-        with self.assertRaisesRegex(RuntimeError, r"variable 'a' previously has type \(Tensor, Tensor\)"):
-            @torch.jit.script
-            def mixtypes(x):
-                a = (x, x)
-                if True:
-                    a = 4
+            @torch.jit.script_method
+            def foo(self, input):
+                return input + 1
+        m = M()
+        self.assertEqual(2, m.call_foo(torch.ones((), dtype=torch.int64)))
 
-    def test_if_tuple_sizes(self):
-        with self.assertRaisesRegex(RuntimeError, "Type mismatch"):
-            @torch.jit.script
-            def diff_tuple_sizes(x):
-                if False:
-                    c0 = ((x, x), (x, x, x))
-                else:
-                    c0 = ((x, x, x), (x, x))
-                return c0
+    def test_script_define_order_recursive_fail(self):
+        class M(torch.jit.ScriptModule):
+            def __init__(self):
+                pass
 
-    def test_if_different_type(self):
-        with self.assertRaisesRegex(RuntimeError, "Type mismatch: c0 is set to type int "
-                                    "in the true branch and type float in the false branch:"):
-            @torch.jit.script
-            def diff_type_used():
-                if False:
-                    c0 = 1
-                else:
-                    c0 = 1.0
-                return c0
+            @torch.jit.script_method
+            def call_foo(self, input):
+                return self.foo(input)
 
-        with self.assertRaisesRegex(RuntimeError, "variable 'c0' previously has type float"):
-            @torch.jit.script
-            def diff_existing_type(x):
-                c0 = 1.0
-                if False:
-                    c0 = 1
-                    print(x)
-                return x
+            @torch.jit.script_method
+            def foo(self, input):
+                self.call_foo(input)
 
-        @torch.jit.script
-        def diff_type_unused():
-            if True:
-                c0 = 1
-                print(c0)
-            else:
-                c0 = 1.0
-                print(c0)
-            return 1
+        with self.assertRaisesRegex(RuntimeError, 'called recursively involving'):
+            M()
 
-    def test_if_list(self):
-        # testing that different length lists don't throw error
-        @torch.jit.script
-        def test_list(x):
-            if True:
-                c = [x, x]
-            else:
-                c = [x, x, x]
-            return torch.cat(c)
+    def test_script_kwargs_fn_call(self):
+        class M(torch.jit.ScriptModule):
+            def __init__(self):
+                pass
 
-        b = torch.zeros(2, 4)
-        test_list.graph.propagate_shapes((b,), False)
-        self.assertExpected(canonical(test_list.graph))
+            @torch.jit.script_method
+            def call_foo(self, input):
+                return self.foo(input=input, bar=1)
 
-    def test_if_supertype(self):
+            @torch.jit.script_method
+            def foo(self, bar, input):
+                # type: (int, Tensor) -> Tensor
+                return input + bar
+        m = M()
+        self.assertEqual(2, m.call_foo(torch.ones((), dtype=torch.int64)))
+
+    @unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
+    def test_trace_of_script(self):
         @torch.jit.script
-        def tensor_unifying(x, y, z):
+        def foo(a, c):
+            b = 0.0
+            if bool(a == 0.0):
+                b = 1.0
+            return b + c
 
-            # testing dynamic is appropriately set for y and z
-            if True:
-                x, y, z = x, y, z
-            else:
-                x, y, z = x, x, y
+        a = torch.ones(1, dtype=torch.float)
 
-            return x, y, z
+        @_trace(torch.zeros(1, dtype=torch.float))
+        def use(b):
+            return foo(b - 1.0, a) + 1.0
 
-        a = torch.zeros(2, 2, dtype=torch.float)
-        b = torch.zeros(2, 4, dtype=torch.long)
-        c = torch.zeros(2, 4, dtype=torch.float)
+        # test that we propagated shapes through the function
+        self.assertTrue("Dynamic" not in str(use.graph))
 
-        tensor_unifying.graph.propagate_shapes((a, b, c), False)
-        self.assertExpected(canonical(tensor_unifying.graph))
+        self.assertEqual(3, use(torch.ones(1, dtype=torch.float)))
+        self.assertEqual(2, use(torch.zeros(1, dtype=torch.float)))
 
-    def test_type_annotations_repeated_list(self):
+    def test_if_define(self):
         @torch.jit.script
-        def float_fn(x, y):
-            # type: (float, BroadcastingList3[float]) -> List[float]
-            return y
-        self.assertEqual(float_fn(2.0, 1.0), float_fn(2.0, [1.0, 1.0, 1.0]))
-        self.assertEqual(float_fn(2.0, 1.0), float_fn(2.0, (1.0, 1.0, 1.0)))
+        def foo(a):
+            if bool(a == 0):
+                b = 1
+            else:
+                b = 0
+            return b + 1
 
         @torch.jit.script
-        def float_fn_call():
-            print(float_fn(1.0, 1.0))
-            print(float_fn(1.0, (1.0, 1.0, 1.0)))
+        def foo2(a):
+            b = 0
+            if bool(a == 0):
+                b = 1
+            return b + 1
 
         @torch.jit.script
-        def int_fn(x):
-            # type: (BroadcastingList3[int]) -> List[int]
-            return x
-        self.assertEqual(int_fn(1), int_fn([1, 1, 1]))
-        self.assertEqual(int_fn(1), int_fn((1, 1, 1)))
+        def foo3(a):
+            b = 1
+            if bool(a == 0):
+                c = 4
+            else:
+                b = 0
+            return b + 1
 
-        @torch.jit.script
-        def int_fn_call():
-            print(int_fn(1))
-            print(int_fn((1, 1, 1)))
+        a = torch.ones(1, dtype=torch.long)
+        b = torch.zeros(1, dtype=torch.long)
+        self.assertEqual(1, foo(a))
+        self.assertEqual(2, foo(b))
+        self.assertEqual(1, foo2(a))
+        self.assertEqual(2, foo2(b))
+        self.assertEqual(1, foo3(a))
+        self.assertEqual(2, foo3(b))
 
-        with self.assertRaisesRegex(RuntimeError, "expected number"):
-            @torch.jit.script
-            def fn(x):
-                # type: (BroadcastingListx[int]) -> List[int]
-                return x
+    def test_script_module_export_submodule(self):
+        class M1(torch.jit.ScriptModule):
+            def __init__(self):
+                super(M1, self).__init__(False)
+                self.weight = nn.Parameter(torch.randn(2))
 
-        with self.assertRaisesRegex(RuntimeError, "Unknown type constructor"):
-            @torch.jit.script
-            def nested(x, y):
-                # type: (int, Tuple[int, int[2]]) -> List[int]
-                return x
+            @torch.jit.script_method
+            def forward(self, thing):
+                return self.weight + thing
 
-    def test_ntuple_builtins(self):
-        from torch.nn.modules.utils import _single, _pair, _triple, _quadruple
+        class M2(torch.jit.ScriptModule):
+            def __init__(self):
+                super(M2, self).__init__(False)
+                # test submodule
+                self.sub = M1()
+                self.weight = nn.Parameter(torch.randn(2, 3))
+                self.bias = nn.Parameter(torch.randn(2))
+                self.define("""
+                    def hi(self, a):
+                        return self.weight.mm(a)
+                """)
 
-        def test_ints():
-            return _single(1), _pair(2), _triple(3), _quadruple(4)
+            @torch.jit.script_method
+            def doit(self, input):
+                return self.weight.mm(input)
 
-        def test_floats():
-            return _single(1), _pair(2.1), _triple(3.1), _quadruple(4.1)
+            @torch.jit.script_method
+            def doit2(self, input):
+                return self.weight.mm(input)
 
-        self.checkScript(test_ints, ())
-        self.checkScript(test_floats, ())
+            @torch.jit.script_method
+            def doit3(self, input):
+                return input + torch.ones([1], dtype=torch.double)
 
-    def test_embedding_renorm_grad_error(self):
-        # Testing that the builtin call to embedding_renorm_ correctly throws
-        # Error when .backward() is called on its input
+            @torch.jit.script_method
+            def forward(self, input):
+                a = self.doit(input)
+                b = self.doit2(input)
+                c = self.hi(input)
+                return a + b + self.bias + c
 
-        def embedding_norm(input, embedding_matrix, max_norm):
-            F.embedding(input, embedding_matrix, max_norm=0.01)
+        m_orig = M2()
+        m_import = self.getExportImportCopy(m_orig)
 
-        @torch.jit.script
-        def embedding_norm_script(input, embedding_matrix, max_norm):
-            # type: (Tensor, Tensor, float)
-            F.embedding(input, embedding_matrix, max_norm=0.01)
+        input = torch.randn(3, 2)
+        self.assertEqual(m_orig.doit(input), m_import.doit(input))
+        self.assertEqual(m_orig.hi(input), m_import.hi(input))
+        self.assertEqual(m_orig.doit3(input), m_import.doit3(input))
+        self.assertEqual(m_orig.forward(input), m_import.forward(input))
 
-        for fun in [embedding_norm, embedding_norm_script]:
-            input = torch.tensor([[1, 2, 4, 5], [4, 3, 2, 9]])
-            embedding_matrix = torch.randn(10, 3)
+    @skipIfNoTorchVision
+    def test_script_module_trace_resnet18(self):
+        x = torch.ones(1, 3, 224, 224)
+        m_orig = torch.jit.trace(torchvision.models.resnet18(), torch.ones(1, 3, 224, 224))
+        m_import = self.getExportImportCopy(m_orig)
 
-            var1 = torch.randn(10, 3, requires_grad=True)
-            var2 = var1.detach().requires_grad_()
-            output1 = var1 * embedding_matrix
-            output2 = var2 * embedding_matrix
+        input = torch.randn(1, 3, 224, 224, requires_grad=True)
+        output_orig = m_orig(input)
+        output_orig.sum().backward()
+        grad_orig = input.grad.clone()
+        input.grad.zero_()
 
-            output1.sum().backward()
+        output_import = m_import(input)
+        output_import.sum().backward()
+        grad_import = input.grad.clone()
 
-            ignore = F.embedding(input, embedding_matrix, max_norm=0.01)
-            with self.assertRaisesRegex(RuntimeError, "modified"):
-                output2.sum().backward()
+        self.assertEqual(output_orig, output_import)
+        self.assertEqual(grad_orig, grad_import)
 
-    def test_type_annotations(self):
-        def fn(x, y):
-            # type: (Tensor, Tensor) -> Tuple[Tensor, Tensor, Tensor]
-            return x, x * 2, x * 3
+    @skipIfNoTorchVision
+    def test_script_module_script_resnet(self):
+        def conv1x1(in_planes, out_planes, stride=1):
+            """1x1 convolution"""
+            return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
 
-        with self.assertRaisesRegex(RuntimeError, r"need 4 values .* found only 3"):
-            @torch.jit.script
-            def script_fn(x):
-                x, y, z, w = fn(x, x)
+        def conv3x3(in_planes, out_planes, stride=1):
+            """3x3 convolution with padding"""
+            return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
+                             padding=1, bias=False)
 
-        with self.assertRaisesRegex(RuntimeError, r"too many values .* need 2 but found 3"):
-            @torch.jit.script
-            def script_fn2(x):
-                x, y = fn(x, x)
+        class BasicBlock(torch.jit.ScriptModule):
+            expansion = 1
+            __constants__ = ['downsample']
 
-        def fn_unpack(x):
-            y, z, w = fn(x, x)
-            return y
+            def __init__(self, inplanes, planes, stride=1, downsample=None):
+                super(BasicBlock, self).__init__()
+                self.conv1 = conv3x3(inplanes, planes, stride)
+                self.bn1 = nn.BatchNorm2d(planes)
+                self.relu = nn.ReLU(inplace=True)
+                self.conv2 = conv3x3(planes, planes)
+                self.bn2 = nn.BatchNorm2d(planes)
+                self.downsample = downsample
+                self.stride = stride
 
-        def fn_index(x):
-            q = fn(x, x)
-            return x
+            @torch.jit.script_method
+            def forward(self, x):
+                residual = x
 
-        def fn_string(str, strpair):
-            # type: (str, Tuple[str, str]) -> Tuple[str, int, str, str]
-            str1, str2 = strpair
-            return str, 2, str1, str2
+                out = self.conv1(x)
+                out = self.bn1(out)
+                out = self.relu(out)
 
-        x = torch.ones(2, 2)
-        self.checkScript(fn_unpack, (x,), optimize=True)
-        self.checkScript(fn_index, (x,), optimize=True)
-        self.checkScript(fn_string, ("1", ("3", "4")), optimize=True)
+                out = self.conv2(out)
+                out = self.bn2(out)
 
-    def test_type_annotations_varargs(self):
-        def fn_varargs(x, *args):
-            return args[0] if args else x
+                if self.downsample is not None:
+                    residual = self.downsample(x)
 
-        def fn1(x, y, z):
-            return fn_varargs(x)
+                out += residual
+                out = self.relu(out)
 
-        def fn2(x, y, z):
-            return fn_varargs(x, y)
+                return out
 
-        def fn3(x, y, z):
-            return fn_varargs(x, y, z)
+        class ResNet(torch.jit.ScriptModule):
+            __constants__ = ['layer1', 'layer2', 'layer3', 'layer4']
 
-        x, y, z = [torch.randn(2, 2) for _ in range(3)]
-        self.checkScript(fn1, (x, y, z), optimize=True)
-        self.checkScript(fn2, (x, y, z), optimize=True)
-        self.checkScript(fn3, (x, y, z), optimize=True)
+            def __init__(self, block, layers, num_classes=1000):
+                super(ResNet, self).__init__()
+                self.inplanes = 64
+                self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
+                                       bias=False)
+                self.bn1 = nn.BatchNorm2d(64)
+                self.relu = nn.ReLU(inplace=True)
+                self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
+                self.layer1 = self._make_layer(block, 64, layers[0])
+                self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
+                self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
+                self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
+                self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
+                self.fc = nn.Linear(512 * block.expansion, num_classes)
 
-    @unittest.skipIf(not PY35, "Python 3.5 needed")
-    def test_type_annotation_py3(self):
-        import importlib.util
+                for m in self.modules():
+                    if isinstance(m, nn.Conv2d):
+                        nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
+                    elif isinstance(m, nn.BatchNorm2d):
+                        nn.init.constant_(m.weight, 1)
+                        nn.init.constant_(m.bias, 0)
 
-        code = dedent("""
-        import torch
-        from torch import Tensor
-        from typing import Tuple
+            def _make_layer(self, block, planes, blocks, stride=1):
+                downsample = None
+                if stride != 1 or self.inplanes != planes * block.expansion:
+                    downsample = nn.Sequential(
+                        conv1x1(self.inplanes, planes * block.expansion, stride),
+                        nn.BatchNorm2d(planes * block.expansion),
+                    )
 
-        def fn(x : torch.Tensor, y : Tensor, z) -> Tuple[Tensor, Tensor, Tensor]:
-            return (x, y + z, z)
-        """)
+                layers = []
+                layers.append(block(self.inplanes, planes, stride, downsample))
+                self.inplanes = planes * block.expansion
+                for _ in range(1, blocks):
+                    layers.append(block(self.inplanes, planes))
 
-        with tempfile.TemporaryDirectory() as tmp_dir:
-            script_path = os.path.join(tmp_dir, 'script.py')
-            with open(script_path, 'w') as f:
-                f.write(code)
-            fn = get_fn('test_type_annotation_py3', script_path)
+                return nn.Sequential(*layers)
 
-            with self.assertRaisesRegex(RuntimeError, r"expected a value of type Tensor for argument"
-                                                      r" '0' but found \(Tensor, Tensor\)"):
-                @torch.jit.script
-                def bad_fn(x):
-                    x, y = fn((x, x), x, x)
-                    return y
+            @torch.jit.script_method
+            def forward(self, x):
+                x = self.conv1(x)
+                x = self.bn1(x)
+                x = self.relu(x)
+                x = self.maxpool(x)
 
-            with self.assertRaisesRegex(RuntimeError, r"too many values .* need 2 but found 3"):
-                @torch.jit.script
-                def bad_fn2(x):
-                    x, y = fn(x, x, x)
-                    return y
+                x = self.layer1(x)
+                x = self.layer2(x)
+                x = self.layer3(x)
+                x = self.layer4(x)
 
-            with self.assertRaisesRegex(RuntimeError, r"need 4 values .* found only 3"):
-                @torch.jit.script
-                def bad_fn3(x):
-                    x, y, z, w = fn(x, x, x)
-                    return y
+                x = self.avgpool(x)
+                x = x.view(x.size(0), -1)
+                x = self.fc(x)
 
-            def good_fn(x):
-                y, z, w = fn(x, x, x)
-                return y, z, w
+                return x
 
-            self.checkScript(good_fn, (torch.ones(2, 2),), optimize=True)
+        resnet18 = ResNet(BasicBlock, [2, 2, 2, 2])
 
-    def test_type_annotation_module(self):
-        class BaseModule(torch.jit.ScriptModule):
-            def foo(self, x):
-                # type: (Tensor) -> Tensor
-                return x + 1
+        resnet18_imported = self.getExportImportCopy(resnet18)
 
-            def bar(self, x, y):
-                # type: (Tensor, Tensor) -> Tuple[Tensor, Tensor]
-                return x + y, y
+        input = torch.randn(1, 3, 224, 224, requires_grad=True)
+        output_orig = resnet18(input)
+        output_orig.sum().backward()
+        grad_orig = input.grad.clone()
+        input.grad.zero_()
 
-            def baz(self, x, y):
-                return x
+        output_import = resnet18_imported(input)
+        output_import.sum().backward()
+        grad_import = input.grad.clone()
 
-        class ModuleTooMany(BaseModule):
-            @torch.jit.script_method
-            def method(self, x):
-                return self.foo(x, x)
+        self.assertEqual(output_orig, output_import)
+        self.assertEqual(grad_orig, grad_import)
 
-        class ModuleTooFew(BaseModule):
-            @torch.jit.script_method
-            def method(self, x):
-                return self.bar(x)
+    def test_script_module_export_tensor_type(self):
+        class M(torch.jit.ScriptModule):
 
-        class ModuleTooManyAssign(BaseModule):
-            @torch.jit.script_method
-            def method(self, x):
-                y, z, w = self.bar(x, x)
-                return x
+            def __init__(self, type):
+                super(M, self).__init__(False)
+                self.param = torch.nn.Parameter(torch.zeros((5, 5), dtype=type).random_())
 
-        class ModuleDefault(BaseModule):
             @torch.jit.script_method
-            def method(self, x):
-                y = self.baz(x)
-                return x
+            def foo(self):
+                return self.param
 
-        with self.assertRaisesRegex(RuntimeError, "expected at most 1 arguments but found 2"):
-            ModuleTooMany()
-        with self.assertRaisesRegex(RuntimeError, "argument 1 not provided"):
-            ModuleTooFew()
-        with self.assertRaisesRegex(RuntimeError, "need 3 values .* found only 2"):
-            ModuleTooManyAssign()
-        with self.assertRaisesRegex(RuntimeError, "argument 1 not provided."):
-            ModuleDefault()
+        for type in [torch.float, torch.double]:
+            m_orig = M(type)
+            m_import = self.getExportImportCopy(m_orig)
+            # check to make sure the storage wasn't resized
+            self.assertTrue(m_orig.param.storage().size() == 25)
+            self.assertEqual(m_orig.foo(), m_import.foo())
+            self.assertTrue(m_orig.foo().dtype == m_import.foo().dtype)
 
-    def test_script_define_order(self):
+    @unittest.skipIf(not RUN_CUDA, "testing cuda tensors requires CUDA")
+    def test_script_module_export_tensor_cuda(self):
         class M(torch.jit.ScriptModule):
+
             def __init__(self):
-                pass
+                super(M, self).__init__(False)
+                self.param = torch.nn.Parameter(torch.zeros((5, 5), device='cuda:0').random_())
 
             @torch.jit.script_method
-            def call_foo(self, input):
-                return self.foo(input)
+            def foo(self):
+                return self.param
 
-            @torch.jit.script_method
-            def foo(self, input):
-                return input + 1
-        m = M()
-        self.assertEqual(2, m.call_foo(torch.ones((), dtype=torch.int64)))
+        m_orig = M()
+        m_import = self.getExportImportCopy(m_orig)
+        # check to make sure the storage wasn't resized
+        self.assertTrue(m_orig.param.storage().size() == 25)
+        self.assertTrue(m_import.foo().device == torch.device('cuda:0'))
+        self.assertEqual(m_orig.foo(), m_import.foo())
+        self.assertTrue(m_orig.foo().dtype == m_import.foo().dtype)
 
-    def test_script_define_order_recursive_fail(self):
+    def test_script_module_export_blocks(self):
         class M(torch.jit.ScriptModule):
-            def __init__(self):
-                pass
+            def __init__(self, n, m):
+                super(M, self).__init__()
+                self.weight = torch.nn.Parameter(torch.rand(n, m))
 
             @torch.jit.script_method
-            def call_foo(self, input):
-                return self.foo(input)
+            def forward(self, input):
+                if bool(input.sum() > 0):
+                    output = self.weight.mv(input)
+                else:
+                    output = self.weight + input
+                return output
+
+        m_orig = M(200, 200)
+        m_import = self.getExportImportCopy(m_orig)
+
+        t = torch.rand(200)
+        self.assertEqual(m_orig(t), m_import(t))
+
+    def test_script_module_export_shared_storage(self):
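+        # param2 is a view into param1; after export/import the two should
+        # still share storage, while param3 keeps its own storage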
+        class M(torch.jit.ScriptModule):
+
+            def __init__(self):
+                super(M, self).__init__(False)
+                self.param1 = torch.nn.Parameter(torch.rand(5, 5))
+                self.param2 = torch.nn.Parameter(self.param1[3])
+                self.param3 = torch.nn.Parameter(torch.rand(5, 5))
+                self.param4 = torch.nn.Parameter(torch.rand(11, 5)[1:6])
 
             @torch.jit.script_method
-            def foo(self, input):
-                self.call_foo(input)
+            def foo(self):
+                return self.param1 + self.param2 + self.param3 + self.param4
 
-        with self.assertRaisesRegex(RuntimeError, 'called recursively involving'):
-            M()
+        m_orig = M()
+        m_import = self.getExportImportCopy(m_orig)
 
-    def test_script_kwargs_fn_call(self):
-        class M(torch.jit.ScriptModule):
+        self.assertEqual(m_orig.foo(), m_import.foo())
+        self.assertTrue(m_import.param1.storage().data_ptr() == m_import.param2.storage().data_ptr())
+        self.assertTrue(m_import.param1.storage().data_ptr() != m_import.param3.storage().data_ptr())
+
+    def test_onnx_export_script_module(self):
+        class ModuleToExport(torch.jit.ScriptModule):
             def __init__(self):
-                pass
+                super(ModuleToExport, self).__init__()
 
             @torch.jit.script_method
-            def call_foo(self, input):
-                return self.foo(input=input, bar=1)
+            def forward(self, x):
+                y = x - x
+                return x + x
 
-            @torch.jit.script_method
-            def foo(self, bar, input):
-                # type: (int, Tensor) -> Tensor
-                return input + bar
-        m = M()
-        self.assertEqual(2, m.call_foo(torch.ones((), dtype=torch.int64)))
+        mte = ModuleToExport()
+        outputs = mte(torch.zeros(1, 2, 3))
+        self.assertExpected(torch.onnx.export_to_pretty_string(
+            mte, (torch.zeros(1, 2, 3),), None, verbose=False,
+            example_outputs=outputs))
 
-    @unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
-    def test_trace_of_script(self):
-        @torch.jit.script
-        def foo(a, c):
-            b = 0.0
-            if bool(a == 0.0):
-                b = 1.0
-            return b + c
+    def test_onnx_export_script_python_fail(self):
+        class ModuleToInline(torch.jit.ScriptModule):
+            def __init__(self):
+                super(ModuleToInline, self).__init__()
 
-        a = torch.ones(1, dtype=torch.float)
+            def forward(self, x):
+                return torch.neg(x)
 
-        @_trace(torch.zeros(1, dtype=torch.float))
-        def use(b):
-            return foo(b - 1.0, a) + 1.0
+        class ModuleToExport(torch.jit.ScriptModule):
+            def __init__(self):
+                super(ModuleToExport, self).__init__()
+                self.mod = ModuleToInline()
 
-        # test we propagated shapes through the function
-        self.assertTrue("Dynamic" not in str(use.graph))
+            @torch.jit.script_method
+            def forward(self, x):
+                y = self.mod(x)
+                return y + y
 
-        self.assertEqual(3, use(torch.ones(1, dtype=torch.float)))
-        self.assertEqual(2, use(torch.zeros(1, dtype=torch.float)))
+        mte = ModuleToExport()
+        outputs = mte(torch.zeros(1, 2, 3))
+        f = io.BytesIO()
+        with self.assertRaisesRegex(RuntimeError, "Couldn't export Python operator"):
+            torch.onnx._export(mte, (torch.zeros(1, 2, 3),), f, verbose=False,
+                               example_outputs=outputs)
 
-    def test_if_define(self):
-        @torch.jit.script
-        def foo(a):
-            if bool(a == 0):
-                b = 1
-            else:
-                b = 0
-            return b + 1
+    def test_onnx_export_script_inline_trace(self):
+        class ModuleToInline(torch.jit.ScriptModule):
+            def __init__(self):
+                super(ModuleToInline, self).__init__()
 
-        @torch.jit.script
-        def foo2(a):
-            b = 0
-            if bool(a == 0):
-                b = 1
-            return b + 1
+            def forward(self, x):
+                return torch.neg(x)
 
-        @torch.jit.script
-        def foo3(a):
-            b = 1
-            if bool(a == 0):
-                c = 4
-            else:
-                b = 0
-            return b + 1
+        class ModuleToExport(torch.jit.ScriptModule):
+            def __init__(self):
+                super(ModuleToExport, self).__init__()
+                self.mod = torch.jit.trace(ModuleToInline(), torch.zeros(1, 2, 3))
 
-        a = torch.ones(1, dtype=torch.long)
-        b = torch.zeros(1, dtype=torch.long)
-        self.assertEqual(1, foo(a))
-        self.assertEqual(2, foo(b))
-        self.assertEqual(1, foo2(a))
-        self.assertEqual(2, foo2(b))
-        self.assertEqual(1, foo3(a))
-        self.assertEqual(2, foo3(b))
+            @torch.jit.script_method
+            def forward(self, x):
+                y = self.mod(x)
+                return y + y
 
-    def test_script_module_export_submodule(self):
-        class M1(torch.jit.ScriptModule):
+        mte = ModuleToExport()
+        outputs = mte(torch.zeros(1, 2, 3))
+        self.assertExpected(torch.onnx.export_to_pretty_string(
+            mte, (torch.zeros(1, 2, 3),), None, verbose=False,
+            example_outputs=outputs))
+
+    def test_onnx_export_script_inline_script(self):
+        class ModuleToInline(torch.jit.ScriptModule):
             def __init__(self):
-                super(M1, self).__init__(False)
-                self.weight = nn.Parameter(torch.randn(2))
+                super(ModuleToInline, self).__init__()
 
             @torch.jit.script_method
-            def forward(self, thing):
-                return self.weight + thing
+            def forward(self, x):
+                return torch.neg(x)
 
-        class M2(torch.jit.ScriptModule):
+        class ModuleToExport(torch.jit.ScriptModule):
             def __init__(self):
-                super(M2, self).__init__(False)
-                # test submodule
-                self.sub = M1()
-                self.weight = nn.Parameter(torch.randn(2, 3))
-                self.bias = nn.Parameter(torch.randn(2))
-                self.define("""
-                    def hi(self, a):
-                        return self.weight.mm(a)
-                """)
+                super(ModuleToExport, self).__init__()
+                self.mod = ModuleToInline()
 
             @torch.jit.script_method
-            def doit(self, input):
-                return self.weight.mm(input)
+            def forward(self, x):
+                y = self.mod(x)
+                return y + y
 
-            @torch.jit.script_method
-            def doit2(self, input):
-                return self.weight.mm(input)
+        mte = ModuleToExport()
+        outputs = mte(torch.zeros(1, 2, 3))
+        self.assertExpected(torch.onnx.export_to_pretty_string(
+            mte, (torch.zeros(1, 2, 3),), None, verbose=False,
+            example_outputs=outputs))
 
-            @torch.jit.script_method
-            def doit3(self, input):
-                return input + torch.ones([1], dtype=torch.double)
+    def test_onnx_export_script_module_loop(self):
+        class ModuleToExport(torch.jit.ScriptModule):
+            def __init__(self):
+                super(ModuleToExport, self).__init__()
 
             @torch.jit.script_method
-            def forward(self, input):
-                a = self.doit(input)
-                b = self.doit2(input)
-                c = self.hi(input)
-                return a + b + self.bias + c
+            def forward(self, x):
+                # test end-to-end onnx export of loops and nested loops,
+                # with and without a loop index
+                for _ in range(5):
+                    for i in range(3):
+                        x = x + i
+                return x
 
-        m_orig = M2()
-        m_import = self.getExportImportCopy(m_orig)
+        mte = ModuleToExport()
+        outputs = mte(torch.zeros(1, 2, 3))
+        self.assertExpected(torch.onnx.export_to_pretty_string(
+            mte, (torch.zeros(1, 2, 3),), None, verbose=False,
+            example_outputs=outputs))
 
-        input = torch.randn(3, 2)
-        self.assertEqual(m_orig.doit(input), m_import.doit(input))
-        self.assertEqual(m_orig.hi(input), m_import.hi(input))
-        self.assertEqual(m_orig.doit3(input), m_import.doit3(input))
-        self.assertEqual(m_orig.forward(input), m_import.forward(input))
+    def test_onnx_export_script_truediv(self):
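+        # exercise true division of an int size by an int constant inside a
+        # scripted forward before exporting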
+        class ModuleToExport(torch.jit.ScriptModule):
+            def __init__(self):
+                super(ModuleToExport, self).__init__()
 
-    @skipIfNoTorchVision
-    def test_script_module_trace_resnet18(self):
-        x = torch.ones(1, 3, 224, 224)
-        m_orig = torch.jit.trace(torchvision.models.resnet18(), torch.ones(1, 3, 224, 224))
-        m_import = self.getExportImportCopy(m_orig)
+            @torch.jit.script_method
+            def forward(self, x):
+                z = x.size(0) / 2
+                return x + z
 
-        input = torch.randn(1, 3, 224, 224, requires_grad=True)
-        output_orig = m_orig(input)
-        output_orig.sum().backward()
-        grad_orig = input.grad.clone()
-        input.grad.zero_()
+        mte = ModuleToExport()
+        outputs = mte(torch.zeros(1, 2, 3))
+        self.assertExpected(torch.onnx.export_to_pretty_string(
+            mte, (torch.zeros(1, 2, 3),), None, verbose=False,
+            example_outputs=outputs))
 
-        output_import = m_import(input)
-        output_import.sum().backward()
-        grad_import = input.grad.clone()
+    def test_onnx_raw_export_script_truediv(self):
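+        # same computation as above, but exported as raw IR
+        # (export_raw_ir=True) rather than ONNX operators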
+        class ModuleToExport(torch.jit.ScriptModule):
+            def __init__(self):
+                super(ModuleToExport, self).__init__()
 
-        self.assertEqual(output_orig, output_import)
-        self.assertEqual(grad_orig, grad_import)
+            @torch.jit.script_method
+            def forward(self, x):
+                z = x.size(0) / 2
+                return x + z
 
-    @skipIfNoTorchVision
-    def test_script_module_script_resnet(self):
-        def conv1x1(in_planes, out_planes, stride=1):
-            """1x1 convolution"""
-            return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
+        mte = ModuleToExport()
+        outputs = mte(torch.zeros(1, 2, 3))
+        self.assertExpected(torch.onnx.export_to_pretty_string(
+            mte, (torch.zeros(1, 2, 3),), None, verbose=False,
+            example_outputs=outputs, export_raw_ir=True))
 
-        def conv3x3(in_planes, out_planes, stride=1):
-            """3x3 convolution with padding"""
-            return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
-                             padding=1, bias=False)
+    def test_onnx_export_script_non_alpha_add_sub(self):
+        class ModuleToExport(torch.jit.ScriptModule):
+            def __init__(self):
+                super(ModuleToExport, self).__init__()
 
-        class BasicBlock(torch.jit.ScriptModule):
-            expansion = 1
-            __constants__ = ['downsample']
+            @torch.jit.script_method
+            def forward(self, x):
+                bs = x.size(0) + 1
+                return bs - 1
+
+        mte = ModuleToExport()
+        outputs = torch.LongTensor([mte(torch.rand(3, 4))])
+        self.assertExpected(torch.onnx.export_to_pretty_string(
+            mte, (torch.rand(3, 4),), None, verbose=False,
+            example_outputs=outputs))
 
-            def __init__(self, inplanes, planes, stride=1, downsample=None):
-                super(BasicBlock, self).__init__()
-                self.conv1 = conv3x3(inplanes, planes, stride)
-                self.bn1 = nn.BatchNorm2d(planes)
-                self.relu = nn.ReLU(inplace=True)
-                self.conv2 = conv3x3(planes, planes)
-                self.bn2 = nn.BatchNorm2d(planes)
-                self.downsample = downsample
-                self.stride = stride
+    def test_onnx_export_script_module_if(self):
+        class ModuleToExport(torch.jit.ScriptModule):
+            def __init__(self):
+                super(ModuleToExport, self).__init__()
 
             @torch.jit.script_method
             def forward(self, x):
-                residual = x
+                if bool(torch.sum(x) > 0):
+                    x = torch.neg(x)
+                return x
 
-                out = self.conv1(x)
-                out = self.bn1(out)
-                out = self.relu(out)
+        mte = ModuleToExport()
+        outputs = mte(torch.zeros(1, 2, 3, dtype=torch.long))
+        self.assertExpected(torch.onnx.export_to_pretty_string(
+            mte, (torch.zeros(1, 2, 3),), None, verbose=False,
+            example_outputs=outputs))
 
-                out = self.conv2(out)
-                out = self.bn2(out)
+    def test_onnx_export_script_inline_params(self):
+        class ModuleToInline(torch.jit.ScriptModule):
+            def __init__(self):
+                super(ModuleToInline, self).__init__()
+                self.m = torch.nn.Parameter(torch.ones(3, 3))
+                self.unused = torch.nn.Parameter(torch.ones(1, 2, 3))
 
-                if self.downsample is not None:
-                    residual = self.downsample(x)
+            @torch.jit.script_method
+            def forward(self, x):
+                return torch.mm(x, self.m)
 
-                out += residual
-                out = self.relu(out)
+        class ModuleToExport(torch.jit.ScriptModule):
+            def __init__(self):
+                super(ModuleToExport, self).__init__()
+                self.mod = ModuleToInline()
+                self.param = torch.nn.Parameter(torch.ones(3, 4))
 
-                return out
+            @torch.jit.script_method
+            def forward(self, x):
+                y = self.mod(x)
+                return torch.mm(y, self.param)
 
-        class ResNet(torch.jit.ScriptModule):
-            __constants__ = ['layer1', 'layer2', 'layer3', 'layer4']
+        mte = ModuleToExport()
+        result = mte(torch.zeros(2, 3))
+        reference = torch.mm(torch.mm(torch.zeros(2, 3), torch.ones(3, 3)), torch.ones(3, 4))
+        self.assertEqual(result, reference)
+        self.assertExpected(torch.onnx.export_to_pretty_string(
+            mte, (torch.ones(2, 3),), None, verbose=False,
+            example_outputs=result, propagate=True))
 
-            def __init__(self, block, layers, num_classes=1000):
-                super(ResNet, self).__init__()
-                self.inplanes = 64
-                self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
-                                       bias=False)
-                self.bn1 = nn.BatchNorm2d(64)
-                self.relu = nn.ReLU(inplace=True)
-                self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
-                self.layer1 = self._make_layer(block, 64, layers[0])
-                self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
-                self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
-                self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
-                self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
-                self.fc = nn.Linear(512 * block.expansion, num_classes)
+    def test_trace_with_size(self):
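+        # call a traced function from script and convert its output with
+        # int(); the constant branch below then sets y to 7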
+        @_trace(torch.zeros(1, 1))
+        def foo(x):
+            return x + 1
 
-                for m in self.modules():
-                    if isinstance(m, nn.Conv2d):
-                        nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
-                    elif isinstance(m, nn.BatchNorm2d):
-                        nn.init.constant_(m.weight, 1)
-                        nn.init.constant_(m.bias, 0)
+        @torch.jit.script
+        def bar(x):
+            y = int(foo(x))
+            if True:
+                y = 7
+            return y + 1
 
-            def _make_layer(self, block, planes, blocks, stride=1):
-                downsample = None
-                if stride != 1 or self.inplanes != planes * block.expansion:
-                    downsample = nn.Sequential(
-                        conv1x1(self.inplanes, planes * block.expansion, stride),
-                        nn.BatchNorm2d(planes * block.expansion),
-                    )
+        self.assertEqual(8, bar(torch.ones(1, 1)))
 
-                layers = []
-                layers.append(block(self.inplanes, planes, stride, downsample))
-                self.inplanes = planes * block.expansion
-                for _ in range(1, blocks):
-                    layers.append(block(self.inplanes, planes))
+    def test_tracing_slicing(self):
+        @_trace(torch.zeros(10))
+        def foo_trace(x):
+            return x[-5:-3]
 
-                return nn.Sequential(*layers)
+        @torch.jit.script
+        def foo_script(x):
+            return x[-5:-3]
 
-            @torch.jit.script_method
-            def forward(self, x):
-                x = self.conv1(x)
-                x = self.bn1(x)
-                x = self.relu(x)
-                x = self.maxpool(x)
+        def foo(x):
+            return x[-5:-3]
 
-                x = self.layer1(x)
-                x = self.layer2(x)
-                x = self.layer3(x)
-                x = self.layer4(x)
+        a = torch.arange(0, 8)
+        b = torch.arange(0, 20)
+        self.assertEqual(foo_trace(a), foo_script(a))
+        self.assertEqual(foo_trace(a), foo(a))
+        self.assertNotEqual(foo_trace(a), foo_trace(b))
 
-                x = self.avgpool(x)
-                x = x.view(x.size(0), -1)
-                x = self.fc(x)
+    def test_tracing_indexing(self):
+        @_trace(torch.zeros(10))
+        def foo_trace(x):
+            return x[-2]
 
-                return x
+        @torch.jit.script
+        def foo_script(x):
+            return x[-2]
 
-        resnet18 = ResNet(BasicBlock, [2, 2, 2, 2])
+        def foo(x):
+            return x[-2]
 
-        resnet18_imported = self.getExportImportCopy(resnet18)
+        a = torch.arange(0, 8)
+        b = torch.arange(0, 20)
+        self.assertEqual(foo_script(a), foo_trace(a))
+        self.assertEqual(foo_trace(a), foo(a))
+        self.assertNotEqual(foo_trace(a), foo_trace(b))
 
-        input = torch.randn(1, 3, 224, 224, requires_grad=True)
-        output_orig = resnet18(input)
-        output_orig.sum().backward()
-        grad_orig = input.grad.clone()
-        input.grad.zero_()
+    def test_index_select_shape_prop(self):
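+        # run complete shape analysis on the graph and compare the annotated
+        # graph against the expect file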
 
-        output_import = resnet18_imported(input)
-        output_import.sum().backward()
-        grad_import = input.grad.clone()
+        @torch.jit.script
+        def foo(x, y):
+            return torch.index_select(x, index=y, dim=1)
 
-        self.assertEqual(output_orig, output_import)
-        self.assertEqual(grad_orig, grad_import)
+        a = torch.zeros(2, 2)
+        b = torch.zeros(4, dtype=torch.long)
+        torch._C._jit_pass_complete_shape_analysis(foo.graph, (a, b), False)
+        self.assertExpected(canonical(foo.graph))
 
-    def test_script_module_export_tensor_type(self):
-        class M(torch.jit.ScriptModule):
+    def test_onnx_export_speculate(self):
 
-            def __init__(self, type):
-                super(M, self).__init__(False)
-                self.param = torch.nn.Parameter(torch.zeros((5, 5), dtype=type).random_())
+        class Foo(torch.jit.ScriptModule):
+            def __init__(self, m):
+                super(Foo, self).__init__()
+                self.m = m
 
             @torch.jit.script_method
-            def foo(self):
-                return self.param
+            def forward(self, x):
+                x += x
+                # because we are testing whether we emit the `if` statement
+                # correctly, we cannot use `True` as the condition; constant
+                # prop would remove the `if` statements.
+                c = torch.sum(x) > 4
+                if bool(c):
+                    if bool(c):
+                        y = self.m(x)
+                    else:
+                        y = self.m(x)
+                else:
+                    y = self.m(x)
+                return y
 
-        for type in [torch.float, torch.double]:
-            m_orig = M(type)
-            m_import = self.getExportImportCopy(m_orig)
-            # check to make sure the storage wasn't resized
-            self.assertTrue(m_orig.param.storage().size() == 25)
-            self.assertEqual(m_orig.foo(), m_import.foo())
-            self.assertTrue(m_orig.foo().dtype == m_import.foo().dtype)
+        linear = torch.jit.trace(nn.Linear(10, 20).float(), torch.zeros(1, 10, dtype=torch.float))
 
-    @unittest.skipIf(not RUN_CUDA, "testing cuda tensors require CUDA")
-    def test_script_module_export_tensor_cuda(self):
-        class M(torch.jit.ScriptModule):
+        @torch.jit.script
+        def transpose(x):
+            return x.t()
 
-            def __init__(self):
-                super(M, self).__init__(False)
-                self.param = torch.nn.Parameter(torch.zeros((5, 5), device='cuda:0').random_())
+        f1 = Foo(transpose)
+        outputs_f1 = f1(torch.ones(1, 10, dtype=torch.float))
+        f2 = Foo(linear)
+        outputs_f2 = f2(torch.ones(1, 10, dtype=torch.float))
 
-            @torch.jit.script_method
-            def foo(self):
-                return self.param
+        onnx_ish = torch.onnx.export_to_pretty_string(
+            f1,
+            (torch.ones(1, 10, dtype=torch.float), ),
+            None, verbose=False, example_outputs=outputs_f1)
+        self.assertExpected(onnx_ish, subname='f1')
+        onnx_ish = torch.onnx.export_to_pretty_string(
+            f2,
+            (torch.ones(1, 10, dtype=torch.float), ),
+            None, verbose=False, example_outputs=outputs_f2)
+        self.assertExpected(onnx_ish, subname='f2')
+
+    def test_onnx_export_shape_reshape(self):
+        class Foo(torch.nn.Module):
+            def forward(self, x):
+                import torch.onnx.operators
+                x = x.repeat(5, 1, 1)
+                shape = torch.onnx.operators.shape_as_tensor(x)
+                reshaped = torch.onnx.operators.reshape_from_tensor_shape(x, shape)
+                return reshaped
+
+        foo = torch.jit.trace(Foo(), torch.zeros(1, 2, 3))
+        outputs = foo(torch.zeros(1, 2, 3))
+        f = io.BytesIO()
+        s = torch.onnx.export_to_pretty_string(foo, (torch.zeros(1, 2, 3)), f,
+                                               example_outputs=outputs)
+        self.assertExpected(s)
+
+    def test_shape_analysis_loop(self):
+        def foo(a, b, x):
+            c = a
+            # on the first iteration of the loop it appears that
+            # c should be expanded to the size of b,
+            # but on the second+ iterations there is no broadcast and the
+            # sizes are different.
+            # previously this would cause the compiler to (1) enter an infinite
+            # loop trying to compute the shape, and (2) insert invalid
+            # broadcasts.
+            # this test ensures we don't regress on these issues
+            for _ in range(2):
+                a = c + b
+                c = x
+                b = x
+            return a
 
-        m_orig = M()
-        m_import = self.getExportImportCopy(m_orig)
-        # check to make sure the storage wasn't resized
-        self.assertTrue(m_orig.param.storage().size() == 25)
-        self.assertTrue(m_import.foo().device == torch.device('cuda:0'))
-        self.assertEqual(m_orig.foo(), m_import.foo())
-        self.assertTrue(m_orig.foo().dtype == m_import.foo().dtype)
+        self.checkScript(foo, (torch.zeros(1), torch.zeros(4), torch.zeros(5)), optimize=False)
 
-    def test_script_module_export_blocks(self):
-        class M(torch.jit.ScriptModule):
-            def __init__(self, n, m):
-                super(M, self).__init__()
-                self.weight = torch.nn.Parameter(torch.rand(n, m))
+    def test_intlist_args(self):
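+        # an int[] argument should accept a bare int, an int keyword argument,
+        # and an explicit int list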
+        def func_1(x):
+            return torch.nn.functional.adaptive_avg_pool1d(x, 1)
 
-            @torch.jit.script_method
-            def forward(self, input):
-                if bool(input.sum() > 0):
-                    output = self.weight.mv(input)
-                else:
-                    output = self.weight + input
-                return output
+        def func_2(x):
+            return torch.nn.functional.adaptive_avg_pool1d(x, output_size=1)
 
-        m_orig = M(200, 200)
-        m_import = self.getExportImportCopy(m_orig)
+        def func_3(x):
+            return torch.nn.functional.adaptive_avg_pool1d(x, output_size=[1])
 
-        t = torch.rand(200)
-        self.assertEqual(m_orig(t), m_import(t))
+        x = torch.randn(8, 8, 8)
+        self.checkScript(func_1, [x], optimize=True)
+        self.checkScript(func_2, [x], optimize=True)
+        self.checkScript(func_3, [x], optimize=True)
 
-    def test_script_module_export_shared_storage(self):
-        class M(torch.jit.ScriptModule):
+    def test_wrong_implicit_expand(self):
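+        # the trace must not bake in the broadcast of b to a's traced shape;
+        # same-sized inputs should still add element-wise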
 
-            def __init__(self):
-                super(M, self).__init__(False)
-                self.param1 = torch.nn.Parameter(torch.rand(5, 5))
-                self.param2 = torch.nn.Parameter(self.param1[3])
-                self.param3 = torch.nn.Parameter(torch.rand(5, 5))
-                self.param4 = torch.nn.Parameter(torch.rand(11, 5)[1:6])
+        @_trace(torch.zeros(3), torch.zeros(1))
+        def foo(a, b):
+            return a + b
 
-            @torch.jit.script_method
-            def foo(self):
-                return self.param1 + self.param2 + self.param3 + self.param4
+        a = torch.rand(4)
+        b = torch.rand(4)
+        self.assertEqual(a + b, foo(a, b))
 
-        m_orig = M()
-        m_import = self.getExportImportCopy(m_orig)
+    def test_builtin_args_fails(self):
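+        # each definition below misuses a builtin's schema and should fail to
+        # compile with a descriptive error message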
 
-        self.assertEqual(m_orig.foo(), m_import.foo())
-        self.assertTrue(m_import.param1.storage().data_ptr() == m_import.param2.storage().data_ptr())
-        self.assertTrue(m_import.param1.storage().data_ptr() != m_import.param3.storage().data_ptr())
+        with self.assertRaisesRegex(RuntimeError, 'expected at most'):
+            @torch.jit.script
+            def f0(a):
+                torch.sum(a, a, a, a)
 
-    def test_onnx_export_script_module(self):
-        class ModuleToExport(torch.jit.ScriptModule):
-            def __init__(self):
-                super(ModuleToExport, self).__init__()
+        with self.assertRaisesRegex(RuntimeError, 'argument self not provided'):
+            @torch.jit.script
+            def f1(a):
+                torch.sum(foo=4)
 
-            @torch.jit.script_method
-            def forward(self, x):
-                y = x - x
-                return x + x
+        with self.assertRaisesRegex(RuntimeError, 'specified twice'):
+            @torch.jit.script
+            def f2(a):
+                torch.sum(a, self=a)
 
-        mte = ModuleToExport()
-        outputs = mte(torch.zeros(1, 2, 3))
-        self.assertExpected(torch.onnx.export_to_pretty_string(
-            mte, (torch.zeros(1, 2, 3),), None, verbose=False,
-            example_outputs=outputs))
+        with self.assertRaisesRegex(RuntimeError, 'not provided'):
+            @torch.jit.script
+            def f3(a):
+                torch.sum(dim=4)
 
-    def test_onnx_export_script_python_fail(self):
-        class ModuleToInline(torch.jit.ScriptModule):
-            def __init__(self):
-                super(ModuleToInline, self).__init__()
+        with self.assertRaisesRegex(RuntimeError, 'for argument \'tensors\' but found Tensor'):
+            @torch.jit.script
+            def f4(a):
+                torch.cat(a)
 
-            def forward(self, x):
-                return torch.neg(x)
+        with self.assertRaisesRegex(RuntimeError, r'argument \'tensors\' but found int\[\]'):
+            @torch.jit.script
+            def f5(a):
+                torch.cat([3])
 
-        class ModuleToExport(torch.jit.ScriptModule):
-            def __init__(self):
-                super(ModuleToExport, self).__init__()
-                self.mod = ModuleToInline()
+        with self.assertRaisesRegex(RuntimeError, 'Lists must contain only a single type'):
+            @torch.jit.script
+            def f6(a):
+                a.expand(size=[3, [4]])
 
-            @torch.jit.script_method
-            def forward(self, x):
-                y = self.mod(x)
-                return y + y
+        with self.assertRaisesRegex(RuntimeError, 'xpected a value of type Tensor for argument \'self\''):
+            @torch.jit.script
+            def f7(a):
+                torch.sum([4])
 
-        mte = ModuleToExport()
-        outputs = mte(torch.zeros(1, 2, 3))
-        f = io.BytesIO()
-        with self.assertRaisesRegex(RuntimeError, "Couldn't export Python operator"):
-            torch.onnx._export(mte, (torch.zeros(1, 2, 3),), f, verbose=False,
-                               example_outputs=outputs)
+    def test_builtin_args(self):
 
-    def test_onnx_export_script_inline_trace(self):
-        class ModuleToInline(torch.jit.ScriptModule):
-            def __init__(self):
-                super(ModuleToInline, self).__init__()
+        def t0(a):
+            # default arg dim
+            return torch.cat([a, a])
 
-            def forward(self, x):
-                return torch.neg(x)
+        self.checkScript(t0, (torch.zeros(1, 1),))
 
-        class ModuleToExport(torch.jit.ScriptModule):
-            def __init__(self):
-                super(ModuleToExport, self).__init__()
-                self.mod = torch.jit.trace(ModuleToInline(), torch.zeros(1, 2, 3))
+        def t1(a):
+            # keywords out of order
+            return torch.cat(dim=1, tensors=[a, a])
 
-            @torch.jit.script_method
-            def forward(self, x):
-                y = self.mod(x)
-                return y + y
+        self.checkScript(t1, (torch.zeros(1, 1, 2),))
 
-        mte = ModuleToExport()
-        outputs = mte(torch.zeros(1, 2, 3))
-        self.assertExpected(torch.onnx.export_to_pretty_string(
-            mte, (torch.zeros(1, 2, 3),), None, verbose=False,
-            example_outputs=outputs))
+        def t2(a):
+            # mix const/non-const attributes
+            if True:
+                b = 1
+            else:
+                b = 0
+            return torch.sum(a, dim=b, keepdim=False)
 
-    def test_onnx_export_script_inline_script(self):
-        class ModuleToInline(torch.jit.ScriptModule):
-            def __init__(self):
-                super(ModuleToInline, self).__init__()
+        self.checkScript(t2, (torch.zeros(1, 1, 2),))
 
-            @torch.jit.script_method
-            def forward(self, x):
-                return torch.neg(x)
+    def test_parser_type_annotations(self):
+        cu = torch.jit.CompilationUnit('''
+            def foo(x : Tensor, y : Tuple[Tuple[Tensor, Tensor], Tensor]) -> Tuple[Tensor, Tensor]:
+                return x, x
+        ''')
 
-        class ModuleToExport(torch.jit.ScriptModule):
-            def __init__(self):
-                super(ModuleToExport, self).__init__()
-                self.mod = ModuleToInline()
+        self.assertExpected(cu.__getattr__('foo').pretty_print_schema())
 
-            @torch.jit.script_method
-            def forward(self, x):
-                y = self.mod(x)
-                return y + y
+    def test_parser_type_annotations_comment(self):
+        cu = torch.jit.CompilationUnit('''
+            def foo(x, y):
+                # type: (Tensor, Tuple[Tuple[Tensor, Tensor], Tensor]) -> Tuple[Tensor, Tensor]
+                return x, x
+        ''')
 
-        mte = ModuleToExport()
-        outputs = mte(torch.zeros(1, 2, 3))
-        self.assertExpected(torch.onnx.export_to_pretty_string(
-            mte, (torch.zeros(1, 2, 3),), None, verbose=False,
-            example_outputs=outputs))
+        self.assertExpected(cu.__getattr__('foo').pretty_print_schema())
 
-    def test_onnx_export_script_module_loop(self):
-        class ModuleToExport(torch.jit.ScriptModule):
-            def __init__(self):
-                super(ModuleToExport, self).__init__()
+    def test_parser_type_annotations_unknown_type(self):
+        with self.assertRaisesRegex(RuntimeError, r'Unknown type name Foo'):
+            cu = torch.jit.CompilationUnit('''
+                def foo(x : Tensor, y : Tuple[Tuple[Foo, Tensor], Tensor]) -> Tuple[Tensor, Tensor]:
+                    return x, x
+            ''')
 
-            @torch.jit.script_method
-            def forward(self, x):
-                # test if we support end to end onnx export on loop and
-                # nested loops with and without loop index
-                for _ in range(5):
-                    for i in range(3):
-                        x = x + i
-                return x
+    def test_parser_type_annotations_subscript_non_ident(self):
+        with self.assertRaisesRegex(RuntimeError, r'Subscripted type must be a type identifier'):
+            cu = torch.jit.CompilationUnit('''
+                def foo(x : Tensor, y : Tuple[Tensor, Tensor][Tensor]) -> Tuple[Tensor, Tensor]:
+                    return x, x
+            ''')
 
-        mte = ModuleToExport()
-        outputs = mte(torch.zeros(1, 2, 3))
-        self.assertExpected(torch.onnx.export_to_pretty_string(
-            mte, (torch.zeros(1, 2, 3),), None, verbose=False,
-            example_outputs=outputs))
+    def test_parser_type_annotations_subscript_tensor(self):
+        with self.assertRaisesRegex(RuntimeError, r'Unknown type constructor Tensor'):
+            cu = torch.jit.CompilationUnit('''
+                def foo(x : Tensor, y : Tensor[Tensor, Tensor]) -> Tuple[Tensor, Tensor]:
+                    return x, x
+            ''')
 
-    def test_onnx_export_script_truediv(self):
-        class ModuleToExport(torch.jit.ScriptModule):
-            def __init__(self):
-                super(ModuleToExport, self).__init__()
+    def test_parser_type_annotations_incompatible_expression(self):
+        with self.assertRaisesRegex(RuntimeError, r'Expression of type \+ cannot be used in a type expression'):
+            cu = torch.jit.CompilationUnit('''
+                def foo(x : Tensor, y : Tuple[3 + 4, Tensor]) -> Tuple[Tensor, Tensor]:
+                    return x, x
+            ''')
 
-            @torch.jit.script_method
-            def forward(self, x):
-                z = x.size(0) / 2
-                return x + z
+    def test_gather_dynamic_index(self):
+        def t(x):
+            gather1 = x[0]
+            idx = 0 + 1
+            gather2 = x[idx]
+            return gather1 + gather2
 
-        mte = ModuleToExport()
-        outputs = mte(torch.zeros(1, 2, 3))
-        self.assertExpected(torch.onnx.export_to_pretty_string(
-            mte, (torch.zeros(1, 2, 3),), None, verbose=False,
-            example_outputs=outputs))
+        self.checkScript(t, (torch.zeros(3, 2, 3),))
 
-    def test_onnx_raw_export_script_truediv(self):
-        class ModuleToExport(torch.jit.ScriptModule):
-            def __init__(self):
-                super(ModuleToExport, self).__init__()
+    def test_slice_dynamic_index(self):
+        def t(x):
+            slice1 = x[0:1]
+            zero = 0
+            one = zero + 1
+            slice2 = x[zero:one]
+            return slice1 + slice2
 
-            @torch.jit.script_method
-            def forward(self, x):
-                z = x.size(0) / 2
-                return x + z
+        self.checkScript(t, (torch.zeros(3, 2, 3),))
 
-        mte = ModuleToExport()
-        outputs = mte(torch.zeros(1, 2, 3))
-        self.assertExpected(torch.onnx.export_to_pretty_string(
-            mte, (torch.zeros(1, 2, 3),), None, verbose=False,
-            example_outputs=outputs, export_raw_ir=True))
+    def test_addmm_grad(self):
+        """ This test checks several things:
+            1. An expand node was inserted before the addmm operating on the
+               bias term.
+            2. The fused form of addmm appears in the ultimate graph that's
+               executed.
+            3. A sum op was emitted for accumulating gradients along the 0th
+               (expanded) dimension of the bias term.
+            4. The correct symbolic representation for the backward pass of the
+               mm operator was emitted (x.t() -> mm)
 
-    def test_onnx_export_script_non_alpha_add_sub(self):
-        class ModuleToExport(torch.jit.ScriptModule):
-            def __init__(self):
-                super(ModuleToExport, self).__init__()
+            TODO: we should actually check these conditions once we have a way
+            to dump the GraphExecutor state. Namely the processed forward graph
+            and the backward graph.
+        """
+        @torch.jit.script
+        def addmm_grad_test(b, x, w):
+            return torch.addmm(b, x, w)
 
-            @torch.jit.script_method
-            def forward(self, x):
-                bs = x.size(0) + 1
-                return bs - 1
+        # Initialize param and input values
+        w_init = torch.rand(2, 5)
+        b_init = torch.rand(5)
+        x = torch.rand(3, 2)
 
-        mte = ModuleToExport()
-        outputs = torch.LongTensor([mte(torch.rand(3, 4))])
-        self.assertExpected(torch.onnx.export_to_pretty_string(
-            mte, (torch.rand(3, 4),), None, verbose=False,
-            example_outputs=outputs))
+        # Clone trainable params
+        b = b_init.clone()
+        b.requires_grad_()
+        w = w_init.clone()
+        w.requires_grad_()
 
-    def test_onnx_export_script_module_if(self):
-        class ModuleToExport(torch.jit.ScriptModule):
-            def __init__(self):
-                super(ModuleToExport, self).__init__()
+        # Test symbolic differentiation
+        y = addmm_grad_test(b, x, w)
+        y.sum().backward()
 
-            @torch.jit.script_method
-            def forward(self, x):
-                if bool(torch.sum(x) > 0):
-                    x = torch.neg(x)
-                return x
+        # clone params for autograd reference
+        b_ref = b_init.clone()
+        b_ref.requires_grad_()
+        w_ref = w_init.clone()
+        w_ref.requires_grad_()
+        y_ref = torch.addmm(b_ref, x, w_ref)
+        y_ref.sum().backward()
 
-        mte = ModuleToExport()
-        outputs = mte(torch.zeros(1, 2, 3, dtype=torch.long))
-        self.assertExpected(torch.onnx.export_to_pretty_string(
-            mte, (torch.zeros(1, 2, 3),), None, verbose=False,
-            example_outputs=outputs))
+        self.assertEqual(w.grad, w_ref.grad)
+        self.assertEqual(b.grad, b_ref.grad)
+
+    def test_zeros(self):
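+        # torch.zeros with explicit dtype and layout, and a device taken from
+        # a module constant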
+        class M(torch.jit.ScriptModule):
+            __constants__ = ['d']
 
-    def test_onnx_export_script_inline_params(self):
-        class ModuleToInline(torch.jit.ScriptModule):
             def __init__(self):
-                super(ModuleToInline, self).__init__()
-                self.m = torch.nn.Parameter(torch.ones(3, 3))
-                self.unused = torch.nn.Parameter(torch.ones(1, 2, 3))
+                self.d = torch.device('cpu')
 
             @torch.jit.script_method
-            def forward(self, x):
-                return torch.mm(x, self.m)
+            def create(self):
+                return torch.zeros([1, 1, 2], dtype=torch.float, device=self.d, layout=torch.strided)
 
-        class ModuleToExport(torch.jit.ScriptModule):
-            def __init__(self):
-                super(ModuleToExport, self).__init__()
-                self.mod = ModuleToInline()
-                self.param = torch.nn.Parameter(torch.ones(3, 4))
+        r = M().create()
+        self.assertEqual(r.dtype, torch.float)
+        self.assertEqual(torch.zeros([1, 1, 2], dtype=torch.float), r)
 
-            @torch.jit.script_method
-            def forward(self, x):
-                y = self.mod(x)
-                return torch.mm(y, self.param)
+    def test_vararg_zeros(self):
+        def foo():
+            return torch.zeros(3, 4, 5, dtype=torch.int)
 
-        mte = ModuleToExport()
-        result = mte(torch.zeros(2, 3))
-        reference = torch.mm(torch.mm(torch.zeros(2, 3), torch.ones(3, 3)), torch.ones(3, 4))
-        self.assertEqual(result, reference)
-        self.assertExpected(torch.onnx.export_to_pretty_string(
-            mte, (torch.ones(2, 3),), None, verbose=False,
-            example_outputs=result, propagate=True))
+        self.checkScript(foo, ())
 
-    def test_trace_with_size(self):
-        @_trace(torch.zeros(1, 1))
-        def foo(x):
-            return x + 1
+    def test_rand(self):
+        def test_rand():
+            a = torch.rand([3, 4])
+            return a + 1.0 - a
 
-        @torch.jit.script
-        def bar(x):
-            y = int(foo(x))
-            if True:
-                y = 7
-            return y + 1
+        self.checkScript(test_rand, ())
 
-        self.assertEqual(8, bar(torch.ones(1, 1)))
+    def test_erase_number_types(self):
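+        # run remove_inplace_ops followed by erase_number_types and compare
+        # the resulting graph against the expect file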
+        def func(a):
+            b = 7 + 1 + 3
+            c = a + b
+            c += b
+            return c
 
-    def test_tracing_slicing(self):
-        @_trace(torch.zeros(10))
-        def foo_trace(x):
-            return x[-5:-3]
+        graph = torch.jit.script(func).graph
+        self.run_pass('remove_inplace_ops', graph)
+        self.run_pass('erase_number_types', graph)
+        self.assertExpectedGraph(graph)
 
-        @torch.jit.script
-        def foo_script(x):
-            return x[-5:-3]
+    def test_mm_batching(self):
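+        # the optimized LSTM graphs should contain the batched mm ops
+        # (prim::MMBatchSide in forward, prim::MMTreeReduce in backward) and
+        # still match the unbatched reference outputs and gradients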
+        lstm_cell = torch.jit.script(LSTMCellS)
 
-        def foo(x):
-            return x[-5:-3]
+        def lstm(x, hx, cx, w_ih, w_hh, b_ih, b_hh):
+            for i in range(x.size(0)):
+                hx, cx = lstm_cell(x[i], hx, cx, w_ih, w_hh, b_ih, b_hh)
+            return hx
 
-        a = torch.arange(0, 8)
-        b = torch.arange(0, 20)
-        self.assertEqual(foo_trace(a), foo_script(a))
-        self.assertEqual(foo_trace(a), foo(a))
-        self.assertNotEqual(foo_trace(a), foo_trace(b))
+        slstm = torch.jit.script(lstm)
 
-    def test_tracing_indexing(self):
-        @_trace(torch.zeros(10))
-        def foo_trace(x):
-            return x[-2]
+        inputs = get_lstm_inputs('cpu', training=True, seq_length=10)
+        slstm(*inputs).sum().backward()
 
-        @torch.jit.script
-        def foo_script(x):
-            return x[-2]
+        fw_graph = slstm.graph_for(*inputs)
+        bw_graph = backward_graph(slstm, diff_graph_idx=0)
+        self.assertTrue('prim::MMBatchSide' in str(fw_graph))
+        self.assertTrue('prim::MMTreeReduce' in str(bw_graph))
 
-        def foo(x):
-            return x[-2]
+        sout = slstm(*inputs)
+        out = lstm(*inputs)
+        self.assertEqual(slstm(*inputs), lstm(*inputs))
+        self.assertEqual(torch.autograd.grad(slstm(*inputs).sum(), inputs),
+                         torch.autograd.grad(lstm(*inputs).sum(), inputs))
 
-        a = torch.arange(0, 8)
-        b = torch.arange(0, 20)
-        self.assertEqual(foo_script(a), foo_trace(a))
-        self.assertEqual(foo_trace(a), foo(a))
-        self.assertNotEqual(foo_trace(a), foo_trace(b))
+    def test_loop_unrolling(self):
+        def fn(x):
+            y = 0
+            for i in range(int(x)):
+                y += i
+            return y
 
-    def test_index_select_shape_prop(self):
+        graph = torch.jit.script(fn).graph
+        self.run_pass('loop_unrolling', graph)
+        self.assertExpectedGraph(graph)
+        self.checkScript(fn, (torch.tensor(10),))
 
-        @torch.jit.script
-        def foo(x, y):
-            return torch.index_select(x, index=y, dim=1)
+    def test_loop_unrolling_const(self):
+        def fn():
+            y = 0
+            for i in range(10):
+                y += 1
+            return y
 
-        a = torch.zeros(2, 2)
-        b = torch.zeros(4, dtype=torch.long)
-        torch._C._jit_pass_complete_shape_analysis(foo.graph, (a, b), False)
-        self.assertExpected(canonical(foo.graph))
+        def fn2():
+            y = 0
+            for i in range(10):
+                y += i
+            return y
 
-    def test_onnx_export_speculate(self):
+        def check(fn, name):
+            graph = torch.jit.script(fn).graph
+            self.run_pass('loop_unrolling', graph)
+            self.assertExpectedGraph(graph, subname=name)
+            self.checkScript(fn, ())
 
-        class Foo(torch.jit.ScriptModule):
-            def __init__(self, m):
-                super(Foo, self).__init__()
-                self.m = m
+        check(fn, 'add_const')
+        check(fn2, 'add_iter')
 
-            @torch.jit.script_method
-            def forward(self, x):
-                x += x
-                # because we are testing if we emit `if` statement correctly
-                # we cannot use `True` as the condition. Constant prop
-                # would remove the `if` statements.
-                c = torch.sum(x) > 4
-                if bool(c):
-                    if bool(c):
-                        y = self.m(x)
-                    else:
-                        y = self.m(x)
-                else:
-                    y = self.m(x)
-                return y
+    def test_loop_unrolling_nested(self):
+        def fn(x):
+            y = 0
+            for i in range(10):
+                for j in range(int(x)):
+                    y += j
+            return y
 
-        linear = torch.jit.trace(nn.Linear(10, 20).float(), torch.zeros(1, 10, dtype=torch.float))
+        graph = torch.jit.script(fn).graph
+        self.run_pass('loop_unrolling', graph)
+        self.assertExpectedGraph(graph)
+        self.checkScript(fn, (torch.tensor(10),))
 
-        @torch.jit.script
-        def transpose(x):
-            return x.t()
+    def test_loop_unroll_unused_counter(self):
+        def fn(x):
+            y = 0
+            for i in range(int(x)):
+                y += 1
+            return y
 
-        f1 = Foo(transpose)
-        outputs_f1 = f1(torch.ones(1, 10, dtype=torch.float))
-        f2 = Foo(linear)
-        outputs_f2 = f2(torch.ones(1, 10, dtype=torch.float))
+        graph = torch.jit.script(fn).graph
+        self.run_pass('loop_unrolling', graph)
+        self.assertExpectedGraph(graph)
 
-        onnx_ish = torch.onnx.export_to_pretty_string(
-            f1,
-            (torch.ones(1, 10, dtype=torch.float), ),
-            None, verbose=False, example_outputs=outputs_f1)
-        self.assertExpected(onnx_ish, subname='f1')
-        onnx_ish = torch.onnx.export_to_pretty_string(
-            f2,
-            (torch.ones(1, 10, dtype=torch.float), ),
-            None, verbose=False, example_outputs=outputs_f2)
-        self.assertExpected(onnx_ish, subname='f2')
+    def test_loop_unroll_negative(self):
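+        # unrolled loops must still behave correctly for non-positive and
+        # small trip counts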
+        def fn(x):
+            y = 0
+            for i in range(int(x)):
+                y += 1
+            return y
 
-    def test_onnx_export_shape_reshape(self):
-        class Foo(torch.nn.Module):
-            def forward(self, x):
-                import torch.onnx.operators
-                x = x.repeat(5, 1, 1)
-                shape = torch.onnx.operators.shape_as_tensor(x)
-                reshaped = torch.onnx.operators.reshape_from_tensor_shape(x, shape)
-                return reshaped
+        self.checkScript(fn, (torch.tensor(-20),))
+        self.checkScript(fn, (torch.tensor(-2),))
+        self.checkScript(fn, (torch.tensor(-1),))
+        self.checkScript(fn, (torch.tensor(0),))
+        self.checkScript(fn, (torch.tensor(1),))
+        self.checkScript(fn, (torch.tensor(2),))
 
-        foo = torch.jit.trace(Foo(), torch.zeros(1, 2, 3))
-        outputs = foo(torch.zeros(1, 2, 3))
-        f = io.BytesIO()
-        s = torch.onnx.export_to_pretty_string(foo, (torch.zeros(1, 2, 3)), f,
-                                               example_outputs=outputs)
-        self.assertExpected(s)
+    def test_where(self):
+        def fn(x, y):
+            return torch.where(x > 0.0, x, y)
 
-    def test_shape_analysis_loop(self):
-        def foo(a, b, x):
-            c = a
-            # on the first iteration of the loop it appears that
-            # c should have a expand to the size of b
-            # but on the second+ iterations, there is no broadcast and the
-            # sizes are different.
-            # previously this would cause the compiler to (1) enter an infinite
-            # loop trying to compute the shape, and (2) insert invalid
-            # broadcasts.
-            # this test ensure we don't regress on these issues
-            for _ in range(2):
-                a = c + b
-                c = x
-                b = x
-            return a
+        self.checkScript(fn, (torch.randn(3, 2, dtype=torch.float), torch.ones(3, 2, dtype=torch.float)))
 
-        self.checkScript(foo, (torch.zeros(1), torch.zeros(4), torch.zeros(5)), optimize=False)
+    def test_where_method(self):
+        def fn(x, y):
+            return x.where(x > 0.0, y)
 
-    def test_intlist_args(self):
-        def func_1(x):
-            return torch.nn.functional.adaptive_avg_pool1d(x, 1)
+        self.checkScript(fn, (torch.randn(3, 2, dtype=torch.float), torch.ones(3, 2, dtype=torch.float)))
 
-        def func_2(x):
-            return torch.nn.functional.adaptive_avg_pool1d(x, output_size=1)
+    def test_reassign_module_lhs(self):
+        with self.assertRaisesRegex(RuntimeError, 'Cannot re-assign \'self\' because it has type value and self is'
+                                    ' not a first-class value.  Only reassignments to first-class values are allowed'):
+            class ReassignSelfLHS(torch.jit.ScriptModule):
+                @torch.jit.script_method
+                def forward(self, x):
+                    for i in range(20):
+                        self = x
+                    return self
 
-        def func_3(x):
-            return torch.nn.functional.adaptive_avg_pool1d(x, output_size=[1])
+            ReassignSelfLHS()
 
-        x = torch.randn(8, 8, 8)
-        self.checkScript(func_1, [x], optimize=True)
-        self.checkScript(func_2, [x], optimize=True)
-        self.checkScript(func_3, [x], optimize=True)
+    def test_reassign_module_rhs(self):
+        with self.assertRaisesRegex(RuntimeError, 'Cannot re-assign \'x\' to a value of type module because x is not a'
+                                    ' first-class value.  Only reassignments to first-class values are allowed'):
+            class ReassignSelfRHS(torch.jit.ScriptModule):
+                @torch.jit.script_method
+                def forward(self, x):
+                    for i in range(20):
+                        x = self
+                    return self
 
-    def test_wrong_implicit_expand(self):
+            ReassignSelfRHS()
 
-        @_trace(torch.zeros(3), torch.zeros(1))
-        def foo(a, b):
-            return a + b
+    def test_unknown_builtin(self):
+        with self.assertRaisesRegex(RuntimeError, 'unknown builtin op'):
+            @torch.jit.script
+            def unknown_builtin(x):
+                return x.splork(3)
 
-        a = torch.rand(4)
-        b = torch.rand(4)
-        self.assertEqual(a + b, foo(a, b))
+    def test_return_tuple(self):
+        def return_tuple(x):
+            a = (x, x)
+            return a, x
+        self.checkScript(return_tuple, (torch.rand(4),))
 
-    def test_builtin_args_fails(self):
+    def test_method_no_self(self):
+        with self.assertRaisesRegex(RuntimeError, 'methods must have a self argument'):
+            class MethodNoSelf(torch.jit.ScriptModule):
+                @torch.jit.script_method
+                def forward():
+                    return torch.zeros(3, 4)
 
-        with self.assertRaisesRegex(RuntimeError, 'expected at most'):
-            @torch.jit.script
-            def f0(a):
-                torch.sum(a, a, a, a)
+            MethodNoSelf()
 
-        with self.assertRaisesRegex(RuntimeError, 'argument self not provided'):
+    def test_return_stmt_not_at_end(self):
+        with self.assertRaisesRegex(RuntimeError, 'return statements can appear only at the end of the function body'):
             @torch.jit.script
-            def f1(a):
-                torch.sum(foo=4)
+            def return_stmt_wrong(x):
+                if bool(x > 3):
+                    return 3
+                else:
+                    return x
 
-        with self.assertRaisesRegex(RuntimeError, 'specified twice'):
+    def test_for_range_no_arg(self):
+        with self.assertRaisesRegex(RuntimeError, r'range\(\) expects 1 argument but got 0'):
             @torch.jit.script
-            def f2(a):
-                torch.sum(a, self=a)
+            def range_no_arg(x):
+                for i in range():
+                    x += 1
+                return x
 
-        with self.assertRaisesRegex(RuntimeError, 'not provided'):
-            @torch.jit.script
-            def f3(a):
-                torch.sum(dim=4)
+    def test_list_iterables(self):
+        with self.assertRaisesRegex(RuntimeError, 'List of iterables is not supported currently'):
+            cu = torch.jit.CompilationUnit('''
+            def list_iterables(x):
+                for i, j in [2, 3, 4], [5, 6, 7]:
+                    x += i
+                    x += j
+                return x
+            ''')
 
-        with self.assertRaisesRegex(RuntimeError, 'for argument \'tensors\' but found Tensor'):
-            @torch.jit.script
-            def f4(a):
-                torch.cat(a)
+    def test_for_tuple_unpack(self):
+        with self.assertRaisesRegex(RuntimeError, 'Iteration variable unpacking is not supported'):
+            cu = torch.jit.CompilationUnit('''
+            def for_tuple_unpack(x, y):
+                for i, j in [[3, 4], [5, 6], [7, 8]]:
+                    x += i
+                    y += j
+                return x, y
+            ''')
 
-        with self.assertRaisesRegex(RuntimeError, r'argument \'tensors\' but found int\[\]'):
-            @torch.jit.script
-            def f5(a):
-                torch.cat([3])
+    def test_single_starred_lhs(self):
+        with self.assertRaisesRegex(RuntimeError, 'A Starred expression may only appear on the lhs within the presence'
+                                                  ' of another non-starred expression'):
+            cu = torch.jit.CompilationUnit('''
+            def single_starred_lhs(x):
+                a = (x, x, x)
+                *b, = a
+                return b
+            ''')
 
-        with self.assertRaisesRegex(RuntimeError, 'Lists must contain only a single type'):
+    def test_singleton_tuple_unpack(self):
+        def foo(a):
+            b, = (a,)
+            return b + 1
+        self.checkScript(foo, (torch.rand(3),))
+
+    def test_multi_reduction(self):
+        with self.assertRaisesRegex(
+                RuntimeError,
+                'augmented assignment can only have one LHS expression'):
+            cu = torch.jit.CompilationUnit('''
+            def multi_reduction(x):
+                a, b += x
+                return a, b
+            ''')
+
+    def test_invalid_call_arguments(self):
+        with self.assertRaisesRegex(RuntimeError, 'arguments for call are not valid'):
             @torch.jit.script
-            def f6(a):
-                a.expand(size=[3, [4]])
+            def invalid_call_arguments(x):
+                return torch.unsqueeze(3, 4, 5, 6, 7, 8)
+
+    def test_invalid_lhs_assignment(self):
+        with self.assertRaisesRegex(RuntimeError, 'unexpected expression'):
+            cu = torch.jit.CompilationUnit('''
+            def invalid_lhs_assignment(x):
+                x + 1 = x
+                return x
+            ''')
+
+    def test_multi_starred_expr_lhs(self):
+        with self.assertRaisesRegex(RuntimeError, 'Only one starred expression is allowed on the lhs'):
+            cu = torch.jit.CompilationUnit('''
+            def multi_starred_expr_lhs():
+                a, *b, *c = [1, 2, 3, 4, 5, 6]
+                return a
+            ''')
+
+    def test_pack_tuple_into_non_var(self):
+        with self.assertRaisesRegex(RuntimeError, 'Cannot pack a tuple into a non-variable'):
+            cu = torch.jit.CompilationUnit('''
+            def pack_tuple_into_non_var(x):
+                a, *1 = (3, 4, 5)
+                return x
+            ''')
+
+    def test_print_kwargs(self):
+        with self.assertRaisesRegex(RuntimeError, 'print doesn\'t accept any keyword arguments'):
+            cu = torch.jit.CompilationUnit('''
+            def print_kwargs(x):
+                print(x, flush=True)
+                return x
+            ''')
 
-        with self.assertRaisesRegex(RuntimeError, 'xpected a value of type Tensor for argument \'self\''):
+    def test_builtin_use_as_value(self):
+        with self.assertRaisesRegex(RuntimeError, 'builtin cannot be used as a value'):
             @torch.jit.script
-            def f7(a):
-                torch.sum([4])
-
-    def test_builtin_args(self):
+            def builtin_use_as_value(x):
+                return x.unsqueeze
 
-        def t0(a):
-            # default arg dim
-            return torch.cat([a, a])
+    def test_wrong_use_as_tuple(self):
+        with self.assertRaisesRegex(RuntimeError, 'cannot be used as a tuple'):
+            def test_fn():
+                return 3
 
-        self.checkScript(t0, (torch.zeros(1, 1),))
+            @torch.jit.script
+            def wrong_use_as_tuple(self):
+                a, b = test_fn
+                return a
 
-        def t1(a):
-            # keywords out of order
-            return torch.cat(dim=1, tensors=[a, a])
+    def test_wrong_attr_lookup(self):
+        with self.assertRaisesRegex(RuntimeError, 'attribute lookup is not defined on builtin'):
+            @torch.jit.script
+            def wrong_attr_lookup(self, x):
+                a = x.unsqueeze.myattr
+                return a
 
-        self.checkScript(t1, (torch.zeros(1, 1, 2),))
+    def test_wrong_use_as_callable(self):
+        with self.assertRaisesRegex(RuntimeError, 'cannot call a value'):
+            @torch.jit.script
+            def wrong_use_as_callable(x):
+                return x(3, 4, 5)
 
-        def t2(a):
-            # mix const/non-const attributes
-            if True:
-                b = 1
-            else:
-                b = 0
-            return torch.sum(a, dim=b, keepdim=False)
+    def test_python_val_doesnt_have_attr(self):
+        with self.assertRaisesRegex(RuntimeError, 'object has no attribute abcd'):
 
-        self.checkScript(t2, (torch.zeros(1, 1, 2),))
+            @torch.jit.script
+            def python_val_doesnt_have_attr():
+                # this has to be a module otherwise attr lookup would not be
+                # allowed in the first place
+                return shutil.abcd
 
-    def test_parser_type_annotations(self):
-        cu = torch.jit.CompilationUnit('''
-            def foo(x : Tensor, y : Tuple[Tuple[Tensor, Tensor], Tensor]) -> Tuple[Tensor, Tensor]:
-                return x, x
-        ''')
+    def test_wrong_module_attr_lookup(self):
+        with self.assertRaisesRegex(RuntimeError, 'python value of type \'type\' cannot be used as a value:'):
+            import io
 
-        self.assertExpected(cu.__getattr__('foo').pretty_print_schema())
+            @torch.jit.script
+            def wrong_module_attr_lookup():
+                return io.BytesIO
 
-    def test_parser_type_annotations_comment(self):
-        cu = torch.jit.CompilationUnit('''
-            def foo(x, y):
-                # type: (Tensor, Tuple[Tuple[Tensor, Tensor], Tensor]) -> Tuple[Tensor, Tensor]
-                return x, x
-        ''')
+    def test_wrong_method_call_inputs(self):
+        with self.assertRaisesRegex(RuntimeError, 'argument y not provided'):
+            class SomeModule(torch.jit.ScriptModule):
 
-        self.assertExpected(cu.__getattr__('foo').pretty_print_schema())
+                @torch.jit.script_method
+                def foo(self, x, y):
+                    return x
 
-    def test_parser_type_annotations_unknown_type(self):
-        with self.assertRaisesRegex(RuntimeError, r'Unknown type name Foo'):
-            cu = torch.jit.CompilationUnit('''
-                def foo(x : Tensor, y : Tuple[Tuple[Foo, Tensor], Tensor]) -> Tuple[Tensor, Tensor]:
-                    return x, x
-            ''')
+                @torch.jit.script_method
+                def forward(self, x, y):
+                    return self.foo(x)
+            SomeModule()
 
-    def test_parser_type_annotations_subscript_non_ident(self):
-        with self.assertRaisesRegex(RuntimeError, r'Subscripted type must be a type identifier'):
+    def test_single_starred_expr_for_loop(self):
+        with self.assertRaisesRegex(RuntimeError, 'unexpected expression'):
             cu = torch.jit.CompilationUnit('''
-                def foo(x : Tensor, y : Tuple[Tensor, Tensor][Tensor]) -> Tuple[Tensor, Tensor]:
-                    return x, x
+            def test():
+                x = 0
+                for *a in [1, 2, 3]:
+                    x = x + 1
+                return x
             ''')
 
-    def test_parser_type_annotations_subscript_tensor(self):
-        with self.assertRaisesRegex(RuntimeError, r'Unknown type constructor Tensor'):
+    def test_duplicate(self):
+        with self.assertRaisesRegex(RuntimeError, 'Method \'test\' already defined'):
             cu = torch.jit.CompilationUnit('''
-                def foo(x : Tensor, y : Tensor[Tensor, Tensor]) -> Tuple[Tensor, Tensor]:
-                    return x, x
-            ''')
+            def test():
+                return 1
 
-    def test_parser_type_annotations_incompatible_expression(self):
-        with self.assertRaisesRegex(RuntimeError, r'Expression of type \+ cannot be used in a type expression'):
-            cu = torch.jit.CompilationUnit('''
-                def foo(x : Tensor, y : Tuple[3 + 4, Tensor]) -> Tuple[Tensor, Tensor]:
-                    return x, x
+            def test():
+                return 2
             ''')
 
-    def test_gather_dynamic_index(self):
-        def t(x):
-            gather1 = x[0]
-            idx = 0 + 1
-            gather2 = x[idx]
-            return gather1 + gather2
-
-        self.checkScript(t, (torch.zeros(3, 2, 3),))
+    def test_call_ge(self):
+        with self.assertRaisesRegex(RuntimeError, 'expected at most 1 arguments but found 3'):
+            @_trace(torch.zeros(1, 2, 3))
+            def foo(x):
+                return x
 
-    def test_slice_dynamic_index(self):
-        def t(x):
-            slice1 = x[0:1]
-            zero = 0
-            one = zero + 1
-            slice2 = x[zero:one]
-            return slice1 + slice2
+            @torch.jit.script
+            def test_fn():
+                return foo(torch.full([1], 1), torch.full([1], 2), torch.full([1], 3))
 
-        self.checkScript(t, (torch.zeros(3, 2, 3),))
+    def test_wrong_return_type(self):
+        with self.assertRaisesRegex(RuntimeError, 'but instead got value of type tuple'):
+            def somefunc():
+                # type: () -> Tuple[Tuple[Tensor, Tensor]]
+                return torch.zeros(3, 4), torch.zeros(4, 5)
 
-    def test_addmm_grad(self):
-        """ This test checks several things:
-            1. An expand node was inserted before the addmm operating on the
-               bias term.
-            2. The fused form of addmm appears in the ultimate graph that's
-               executed.
-            3. A sum op was emitted for accumulating gradients along the 0th
-               (expanded) dimension of the bias term.
-            4. The correct symbolic representation for the backward pass of the
-               mm operator was emitted (x.t() -> mm)
+            @torch.jit.script
+            def wrong_return_type():
+                return somefunc()
+            wrong_return_type()
 
-            TODO: we should actually check these conditions once we have a way
-            to dump the GraphExecutor state. Namely the processed forward graph
-            and the backward graph.
-        """
-        @torch.jit.script
-        def addmm_grad_test(b, x, w):
-            return torch.addmm(b, x, w)
+    # Tests for calling between different front-end modes
+    def test_call_python_fn_from_tracing_fn(self):
+        def python_fn(x):
+            return torch.neg(x)
 
-        # Initialize param and input values
-        w_init = torch.rand(2, 5)
-        b_init = torch.rand(5)
-        x = torch.rand(3, 2)
+        @_trace(torch.rand(3, 4))
+        def traced_fn(x):
+            return python_fn(x) + 1
 
-        # Clone trainable params
-        b = b_init.clone()
-        b.requires_grad_()
-        w = w_init.clone()
-        w.requires_grad_()
+        # The neg op in the Python function should be properly inlined into
+        # the graph
+        self.assertExpected(canonical(traced_fn.graph))
 
-        # Test symbolic differentiation
-        y = addmm_grad_test(b, x, w)
-        y.sum().backward()
+    def test_call_python_mod_from_tracing_fn(self):
+        class PythonMod(torch.nn.Module):
+            def __init__(self):
+                super(PythonMod, self).__init__()
+                self.param = torch.nn.Parameter(torch.rand(4, 3))
 
-        # clone params for autograd reference
-        b_ref = b_init.clone()
-        b_ref.requires_grad_()
-        w_ref = w_init.clone()
-        w_ref.requires_grad_()
-        y_ref = torch.addmm(b_ref, x, w_ref)
-        y_ref.sum().backward()
+            def forward(self, x):
+                return torch.mm(x, self.param)
 
-        self.assertEqual(w.grad, w_ref.grad)
-        self.assertEqual(b.grad, b_ref.grad)
+        pm = PythonMod()
 
-    def test_zeros(self):
-        class M(torch.jit.ScriptModule):
-            __constants__ = ['d']
+        @_trace(torch.rand(3, 4))
+        def traced_fn(x):
+            return pm(x) + 1.0
 
-            def __init__(self):
-                self.d = torch.device('cpu')
+        # Note: the parameter self.param from the Python module is inlined
+        # into the graph
+        self.assertExpected(canonical(traced_fn.graph))
 
-            @torch.jit.script_method
-            def create(self):
-                return torch.zeros([1, 1, 2], dtype=torch.float, device=self.d, layout=torch.strided)
+    def test_call_traced_fn_from_tracing_fn(self):
+        @_trace(torch.rand(3, 4))
+        def traced_fn1(x):
+            return torch.neg(x)
 
-        r = M().create()
-        self.assertEqual(r.dtype, torch.float)
-        self.assertEqual(torch.zeros([1, 1, 2], dtype=torch.float), r)
+        @_trace(torch.rand(3, 4))
+        def traced_fn(x):
+            return traced_fn1(x) + 1
 
-    def test_vararg_zeros(self):
-        def foo():
-            return torch.zeros(3, 4, 5, dtype=torch.int)
+        self.assertExpected(canonical(traced_fn.graph))
 
-        self.checkScript(foo, ())
+    def test_call_traced_mod_from_tracing_fn(self):
+        class TracedModule(torch.nn.Module):
+            def __init__(self):
+                super(TracedModule, self).__init__()
+                self.param = torch.nn.Parameter(torch.rand(4, 3))
 
-    @unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
-    def test_rand(self):
+            def forward(self, x):
+                return torch.mm(x, self.param)
 
-        def test_rand():
-            a = torch.rand([3, 4])
-            return a + 1.0 - a
+        tm = torch.jit.trace(TracedModule(), torch.rand(3, 4))
 
-        self.checkScript(test_rand, ())
+        @_trace(torch.rand(3, 4))
+        def traced_fn(x):
+            return tm(x) + 1.0
 
-    def test_erase_number_types(self):
-        def func(a):
-            b = 7 + 1 + 3
-            c = a + b
-            c += b
-            return c
+        # Note: the parameter self.param from the traced module is inlined
+        # into the graph
+        self.assertExpected(canonical(traced_fn.graph))
 
-        graph = torch.jit.script(func).graph
-        self.run_pass('remove_inplace_ops', graph)
-        self.run_pass('erase_number_types', graph)
-        self.assertExpectedGraph(graph)
+    def test_call_script_fn_from_tracing_fn(self):
+        @torch.jit.script
+        def script_fn(x):
+            return torch.neg(x)
 
-    def test_mm_batching(self):
-        lstm_cell = torch.jit.script(LSTMCellS)
+        @_trace(torch.rand(3, 4))
+        def traced_fn(x):
+            return script_fn(x) + 1
 
-        def lstm(x, hx, cx, w_ih, w_hh, b_ih, b_hh):
-            for i in range(x.size(0)):
-                hx, cx = lstm_cell(x[i], hx, cx, w_ih, w_hh, b_ih, b_hh)
-            return hx
+        self.assertExpected(canonical(traced_fn.graph))
 
-        slstm = torch.jit.script(lstm)
+    def test_call_script_mod_from_tracing_fn(self):
+        class ScriptMod(torch.jit.ScriptModule):
+            def __init__(self):
+                super(ScriptMod, self).__init__()
+                self.param = torch.nn.Parameter(torch.rand(4, 3))
 
-        inputs = get_lstm_inputs('cpu', training=True, seq_length=10)
-        slstm(*inputs).sum().backward()
+            @torch.jit.script_method
+            def forward(self, x):
+                return torch.mm(x, self.param)
 
-        fw_graph = slstm.graph_for(*inputs)
-        bw_graph = backward_graph(slstm, diff_graph_idx=0)
-        self.assertTrue('prim::MMBatchSide' in str(fw_graph))
-        self.assertTrue('prim::MMTreeReduce' in str(bw_graph))
+        sm = ScriptMod()
 
-        sout = slstm(*inputs)
-        out = lstm(*inputs)
-        self.assertEqual(slstm(*inputs), lstm(*inputs))
-        self.assertEqual(torch.autograd.grad(slstm(*inputs).sum(), inputs),
-                         torch.autograd.grad(lstm(*inputs).sum(), inputs))
+        @_trace(torch.rand(3, 4))
+        def traced_fn(x):
+            return sm(x) + 1.0
 
-    def test_loop_unrolling(self):
-        def fn(x):
-            y = 0
-            for i in range(int(x)):
-                y += i
-            return y
+        self.assertExpected(canonical(traced_fn.graph))
 
-        graph = torch.jit.script(fn).graph
-        self.run_pass('loop_unrolling', graph)
-        self.assertExpectedGraph(graph)
-        self.checkScript(fn, (torch.tensor(10),))
+    def test_call_python_fn_from_traced_module(self):
+        def python_fn(x):
+            return torch.neg(x)
 
-    def test_loop_unrolling_const(self):
-        def fn():
-            y = 0
-            for i in range(10):
-                y += 1
-            return y
+        class TracedModule(torch.nn.Module):
+            def __init__(self):
+                super(TracedModule, self).__init__()
+                self.param = torch.nn.Parameter(torch.rand(4, 3))
 
-        def fn2():
-            y = 0
-            for i in range(10):
-                y += i
-            return y
+            def forward(self, x):
+                return torch.mm(python_fn(x), self.param)
 
-        def check(fn, name):
-            graph = torch.jit.script(fn).graph
-            self.run_pass('loop_unrolling', graph)
-            self.assertExpectedGraph(graph, subname=name)
-            self.checkScript(fn, ())
+        tm = torch.jit.trace(TracedModule(), torch.rand(3, 4))
 
-        check(fn, 'add_const')
-        check(fn2, 'add_iter')
+        # Note: parameter self.param from the traced module should appear as
+        # an input to the graph and the neg op from the Python function should
+        # be properly inlined
+        self.assertExpected(canonical(tm.graph))
 
-    def test_loop_unrolling_nested(self):
-        def fn(x):
-            y = 0
-            for i in range(10):
-                for j in range(int(x)):
-                    y += j
-            return y
+    def test_call_python_mod_from_traced_module(self):
+        class PythonModule(torch.nn.Module):
+            def __init__(self):
+                super(PythonModule, self).__init__()
+                self.param = torch.nn.Parameter(torch.rand(5, 7))
 
-        graph = torch.jit.script(fn).graph
-        self.run_pass('loop_unrolling', graph)
-        self.assertExpectedGraph(graph)
-        self.checkScript(fn, (torch.tensor(10),))
+            def forward(self, x):
+                return torch.mm(x, self.param)
 
-    def test_loop_unroll_unused_counter(self):
-        def fn(x):
-            y = 0
-            for i in range(int(x)):
-                y += 1
-            return y
+        class TracedModule(torch.nn.Module):
+            def __init__(self):
+                super(TracedModule, self).__init__()
+                self.param = torch.nn.Parameter(torch.rand(4, 5))
+                self.mod = PythonModule()
 
-        graph = torch.jit.script(fn).graph
-        self.run_pass('loop_unrolling', graph)
-        self.assertExpectedGraph(graph)
+            def forward(self, x):
+                return self.mod(torch.mm(x, self.param)) + 1.0
 
-    def test_loop_unroll_negative(self):
-        def fn(x):
-            y = 0
-            for i in range(int(x)):
-                y += 1
-            return y
+        tm = torch.jit.trace(TracedModule(), torch.rand(3, 4))
 
-        self.checkScript(fn, (torch.tensor(-20),))
-        self.checkScript(fn, (torch.tensor(-2),))
-        self.checkScript(fn, (torch.tensor(-1),))
-        self.checkScript(fn, (torch.tensor(0),))
-        self.checkScript(fn, (torch.tensor(1),))
-        self.checkScript(fn, (torch.tensor(2),))
+        # Note: the parameters from both modules should appear in the flattened
+        # inputs of the graph. All ops from both modules should be inlined.
+        self.assertExpected(canonical(tm.graph))
 
-    def test_where(self):
-        def fn(x, y):
-            return torch.where(x > 0.0, x, y)
+    def test_call_traced_fn_from_traced_module(self):
+        @_trace(torch.rand(3, 4))
+        def traced_fn(x):
+            return torch.neg(x)
 
-        self.checkScript(fn, (torch.randn(3, 2, dtype=torch.float), torch.ones(3, 2, dtype=torch.float)))
+        class TracedModule(torch.nn.Module):
+            def __init__(self):
+                super(TracedModule, self).__init__()
+                self.param = torch.nn.Parameter(torch.rand(4, 5))
 
-    def test_where_method(self):
-        def fn(x, y):
-            return x.where(x > 0.0, y)
+            def forward(self, x):
+                return traced_fn(torch.mm(x, self.param))
 
-        self.checkScript(fn, (torch.randn(3, 2, dtype=torch.float), torch.ones(3, 2, dtype=torch.float)))
+        tm = torch.jit.trace(TracedModule(), torch.rand(3, 4))
+        # Note: neg op from the traced function should be properly inlined
+        self.assertExpected(canonical(tm.graph))
 
-    def test_reassign_module_lhs(self):
-        with self.assertRaisesRegex(RuntimeError, 'Cannot re-assign \'self\' because it has type value and self is'
-                                    ' not a first-class value.  Only reassignments to first-class values are allowed'):
-            class ReassignSelfLHS(torch.jit.ScriptModule):
-                @torch.jit.script_method
-                def forward(self, x):
-                    for i in range(20):
-                        self = x
-                    return self
+    def test_call_traced_module_from_traced_module(self):
+        class TracedModule1(torch.nn.Module):
+            def __init__(self):
+                super(TracedModule1, self).__init__()
+                self.param = torch.nn.Parameter(torch.rand(5, 7))
 
-            ReassignSelfLHS()
+            def forward(self, x):
+                return torch.mm(x, self.param)
 
-    def test_reassign_module_rhs(self):
-        with self.assertRaisesRegex(RuntimeError, 'Cannot re-assign \'x\' to a value of type module because x is not a'
-                                    ' first-class value.  Only reassignments to first-class values are allowed'):
-            class ReassignSelfRHS(torch.jit.ScriptModule):
-                @torch.jit.script_method
-                def forward(self, x):
-                    for i in range(20):
-                        x = self
-                    return self
+        class TracedModule(torch.nn.Module):
+            def __init__(self):
+                super(TracedModule, self).__init__()
+                self.param = torch.nn.Parameter(torch.rand(4, 5))
+                self.mod = torch.jit.trace(TracedModule1(), torch.rand(3, 5))
 
-            ReassignSelfRHS()
+            def forward(self, x):
+                return self.mod(torch.mm(x, self.param)) + 1.0
 
-    def test_unknown_builtin(self):
-        with self.assertRaisesRegex(RuntimeError, 'unknown builtin op'):
-            @torch.jit.script
-            def unknown_builtin(x):
-                return x.splork(3)
+        tm = torch.jit.trace(TracedModule(), torch.rand(3, 4))
 
-    def test_return_tuple(self):
-        def return_tuple(x):
-            a = (x, x)
-            return a, x
-        self.checkScript(return_tuple, (torch.rand(4),))
+        # Note: the parameters from both modules should appear in the flattened
+        # inputs of the graph. All ops from both modules should be inlined.
+        self.assertExpected(canonical(tm.graph))
 
-    def test_method_no_self(self):
-        with self.assertRaisesRegex(RuntimeError, 'methods must have a self argument'):
-            class MethodNoSelf(torch.jit.ScriptModule):
-                @torch.jit.script_method
-                def forward():
-                    return torch.zeros(3, 4)
+    def test_call_script_fn_from_traced_module(self):
+        @torch.jit.script
+        def traced_fn(x):
+            return torch.neg(x)
 
-            MethodNoSelf()
+        class TracedModule(torch.nn.Module):
+            def __init__(self):
+                super(TracedModule, self).__init__()
+                self.param = torch.nn.Parameter(torch.rand(4, 5))
 
-    def test_return_stmt_not_at_end(self):
-        with self.assertRaisesRegex(RuntimeError, 'return statements can appear only at the end of the function body'):
-            @torch.jit.script
-            def return_stmt_wrong(x):
-                if bool(x > 3):
-                    return 3
-                else:
-                    return x
+            def forward(self, x):
+                return traced_fn(torch.mm(x, self.param))
 
-    def test_for_range_no_arg(self):
-        with self.assertRaisesRegex(RuntimeError, r'range\(\) expects 1 argument but got 0'):
-            @torch.jit.script
-            def range_no_arg(x):
-                for i in range():
-                    x += 1
-                return x
+        tm = torch.jit.trace(TracedModule(), torch.rand(3, 4))
+        # Note: neg op from the script function should be properly inlined
+        self.assertExpected(canonical(tm.graph))
 
-    def test_list_iterables(self):
-        with self.assertRaisesRegex(RuntimeError, 'List of iterables is not supported currently'):
-            cu = torch.jit.CompilationUnit('''
-            def list_iterables(x):
-                for i, j in [2, 3, 4], [5, 6, 7]:
-                    x += i
-                    x += j
-                return x
-            ''')
+    def test_call_script_module_from_traced_module(self):
+        class ScriptMod(torch.jit.ScriptModule):
+            def __init__(self):
+                super(ScriptMod, self).__init__()
+                self.param = torch.nn.Parameter(torch.rand(5, 7))
+
+            @torch.jit.script_method
+            def forward(self, x):
+                return torch.mm(x, self.param)
 
-    def test_for_tuple_unpack(self):
-        with self.assertRaisesRegex(RuntimeError, 'Iteration variable unpacking is not supported'):
-            cu = torch.jit.CompilationUnit('''
-            def for_tuple_unpack(x, y):
-                for i, j in [[3, 4], [5, 6], [7, 8]]:
-                    x += i
-                    y += j
-                return x, y
-            ''')
+        class TracedModule(torch.nn.Module):
+            def __init__(self):
+                super(TracedModule, self).__init__()
+                self.param = torch.nn.Parameter(torch.rand(4, 5))
+                self.mod = ScriptMod()
 
-    def test_single_starred_lhs(self):
-        with self.assertRaisesRegex(RuntimeError, 'A Starred expression may only appear on the lhs within the presence'
-                                                  ' of another non-starred expression'):
-            cu = torch.jit.CompilationUnit('''
-            def single_starred_lhs(x):
-                a = (x, x, x)
-                *b, = a
-                return b
-            ''')
+            def forward(self, x):
+                return self.mod(torch.mm(x, self.param)) + 1.0
 
-    def test_singleton_tuple_unpack(self):
-        def foo(a):
-            b, = (a,)
-            return b + 1
-        self.checkScript(foo, (torch.rand(3),))
+        tm = torch.jit.trace(TracedModule(), torch.rand(3, 4))
 
-    def test_multi_reduction(self):
-        with self.assertRaisesRegex(
-                RuntimeError,
-                'augmented assignment can only have one LHS expression'):
-            cu = torch.jit.CompilationUnit('''
-            def multi_reduction(x):
-                a, b += x
-                return a, b
-            ''')
+        # Note: the parameters from both modules should appear in the flattened
+        # inputs of the graph. All ops from both modules should be inlined.
+        self.assertExpected(canonical(tm.graph))
 
-    def test_invalid_call_arguments(self):
-        with self.assertRaisesRegex(RuntimeError, 'arguments for call are not valid'):
-            @torch.jit.script
-            def invalid_call_arguments(x):
-                return torch.unsqueeze(3, 4, 5, 6, 7, 8)
+    def test_call_python_fn_from_script_fn(self):
+        def python_fn(x):
+            return torch.neg(x)
 
-    def test_invalid_lhs_assignment(self):
-        with self.assertRaisesRegex(RuntimeError, 'unexpected expression'):
-            cu = torch.jit.CompilationUnit('''
-            def invalid_lhs_assignment(x):
-                x + 1 = x
-                return x
-            ''')
+        @torch.jit.script
+        def script_fn(x):
+            return python_fn(x) + 1
 
-    def test_multi_starred_expr_lhs(self):
-        with self.assertRaisesRegex(RuntimeError, 'Only one starred expression is allowed on the lhs'):
-            cu = torch.jit.CompilationUnit('''
-            def multi_starred_expr_lhs():
-                a, *b, *c = [1, 2, 3, 4, 5, 6]
-                return a
-            ''')
+        # Note: the call to python_fn appears as `^python_fn()` and is called
+        # as a PythonOp in the interpreter
+        self.assertExpected(canonical(script_fn.graph))
 
-    def test_pack_tuple_into_non_var(self):
-        with self.assertRaisesRegex(RuntimeError, 'Cannot pack a tuple into a non-variable'):
-            cu = torch.jit.CompilationUnit('''
-            def pack_tuple_into_non_var(x):
-                a, *1 = (3, 4, 5)
-                return x
-            ''')
+    def test_call_python_mod_from_script_fn(self):
+        class PythonModule(torch.nn.Module):
+            def __init__(self):
+                super(PythonModule, self).__init__()
+                self.param = torch.nn.Parameter(torch.rand(5, 7))
 
-    def test_print_kwargs(self):
-        with self.assertRaisesRegex(RuntimeError, 'print doesn\'t accept any keyword arguments'):
-            cu = torch.jit.CompilationUnit('''
-            def print_kwargs(x):
-                print(x, flush=True)
-                return x
-            ''')
+            def forward(self, x):
+                return torch.mm(x, self.param)
 
-    def test_builtin_use_as_value(self):
-        with self.assertRaisesRegex(RuntimeError, 'builtin cannot be used as a value'):
-            @torch.jit.script
-            def builtin_use_as_value(x):
-                return x.unsqueeze
+        pm = PythonModule()
 
-    def test_wrong_use_as_tuple(self):
-        with self.assertRaisesRegex(RuntimeError, 'cannot be used as a tuple'):
-            def test_fn():
-                return 3
+        @torch.jit.script
+        def script_fn(x):
+            return pm(x) + 1
 
-            @torch.jit.script
-            def wrong_use_as_tuple(self):
-                a, b = test_fn
-                return a
+        # Note: the call to pm(x) appears as ^<python_value>() in the graph.
+        # Parameters are NOT inlined.
+        self.assertExpected(str(script_fn.graph))
 
-    def test_wrong_attr_lookup(self):
-        with self.assertRaisesRegex(RuntimeError, 'attribute lookup is not defined on builtin'):
-            @torch.jit.script
-            def wrong_attr_lookup(self, x):
-                a = x.unsqueeze.myattr
-                return a
+    def test_call_traced_fn_from_script_fn(self):
+        @_trace(torch.rand(3, 4))
+        def traced_fn(x):
+            return torch.neg(x)
 
-    def test_wrong_use_as_callable(self):
-        with self.assertRaisesRegex(RuntimeError, 'cannot call a value'):
-            @torch.jit.script
-            def wrong_use_as_callable(x):
-                return x(3, 4, 5)
+        @torch.jit.script
+        def script_fn(x):
+            return traced_fn(x) + 1
 
-    def test_python_val_doesnt_have_attr(self):
-        with self.assertRaisesRegex(RuntimeError, 'object has no attribute abcd'):
+        # Note: the neg op from traced_fn should be properly inlined into the
+        # script function's graph
+        self.assertExpected(str(script_fn.graph))
 
-            @torch.jit.script
-            def python_val_doesnt_have_attr():
-                # this has to be a module otherwise attr lookup would not be
-                # allowed in the first place
-                return shutil.abcd
+    def test_call_traced_mod_from_script_fn(self):
+        class TracedModule(torch.nn.Module):
+            def __init__(self):
+                super(TracedModule, self).__init__()
 
-    def test_wrong_module_attr_lookup(self):
-        with self.assertRaisesRegex(RuntimeError, 'python value of type \'type\' cannot be used as a value:'):
-            import io
+            def forward(self, x):
+                return torch.mm(x, torch.zeros(4, 3))
 
-            @torch.jit.script
-            def wrong_module_attr_lookup():
-                return io.BytesIO
+        tm = torch.jit.trace(TracedModule(), torch.rand(3, 4))
 
-    def test_wrong_method_call_inputs(self):
-        with self.assertRaisesRegex(RuntimeError, 'argument y not provided'):
-            class SomeModule(torch.jit.ScriptModule):
+        @torch.jit.script
+        def script_fn(x):
+            return tm(x) + 1
 
-                @torch.jit.script_method
-                def foo(self, x, y):
-                    return x
+        self.assertExpected(str(script_fn.graph))
 
-                @torch.jit.script_method
-                def forward(self, x, y):
-                    return self.foo(x)
-            SomeModule()
+    def test_call_script_fn_from_script_fn(self):
+        @torch.jit.script
+        def script_fn1(x):
+            return torch.neg(x)
 
-    def test_single_starred_expr_for_loop(self):
-        with self.assertRaisesRegex(RuntimeError, 'unexpected expression'):
-            cu = torch.jit.CompilationUnit('''
-            def test():
-                x = 0
-                for *a in [1, 2, 3]:
-                    x = x + 1
-                return x
-            ''')
+        @torch.jit.script
+        def script_fn(x):
+            return script_fn1(x) + 1
 
-    def test_duplicate(self):
-        with self.assertRaisesRegex(RuntimeError, 'Method \'test\' already defined'):
-            cu = torch.jit.CompilationUnit('''
-            def test():
-                return 1
+        # Note: the neg op from script_fn1 should be properly inlined into the
+        # graph of script_fn
+        self.assertExpected(canonical(script_fn.graph))
 
-            def test():
-                return 2
-            ''')
+    def test_call_script_mod_from_script_fn(self):
+        class ScriptMod(torch.jit.ScriptModule):
+            def __init__(self):
+                super(ScriptMod, self).__init__()
 
-    def test_call_ge(self):
-        with self.assertRaisesRegex(RuntimeError, 'expected at most 1 arguments but found 3'):
-            @_trace(torch.zeros(1, 2, 3))
-            def foo(x):
-                return x
+            @torch.jit.script_method
+            def forward(self, x):
+                return torch.mm(x, torch.zeros([4, 3]))
 
-            @torch.jit.script
-            def test_fn():
-                return foo(torch.full([1], 1), torch.full([1], 2), torch.full([1], 3))
+        sm = ScriptMod()
 
-    def test_wrong_return_type(self):
-        with self.assertRaisesRegex(RuntimeError, 'but instead got value of type tuple'):
-            def somefunc():
-                # type: () -> Tuple[Tuple[Tensor, Tensor]]
-                return torch.zeros(3, 4), torch.zeros(4, 5)
+        @torch.jit.script
+        def script_fn(x):
+            return sm(x) + 1
 
-            @torch.jit.script
-            def wrong_return_type():
-                return somefunc()
-            wrong_return_type()
+        self.assertExpected(canonical(script_fn.graph))
 
-    # Tests for calling between different front-end modes
-    def test_call_python_fn_from_tracing_fn(self):
+    def test_call_python_fn_from_script_module(self):
         def python_fn(x):
             return torch.neg(x)
 
-        @_trace(torch.rand(3, 4))
-        def traced_fn(x):
-            return python_fn(x) + 1
+        class ScriptMod(torch.jit.ScriptModule):
+            def __init__(self):
+                super(ScriptMod, self).__init__()
+                self.param = torch.nn.Parameter(torch.rand(4, 3))
 
-        # The neg op in the python function should be properly inlined to the
-        # graph
-        self.assertExpected(canonical(traced_fn.graph))
+            @torch.jit.script_method
+            def forward(self, x):
+                return python_fn(torch.mm(x, self.param))
 
-    def test_call_python_mod_from_tracing_fn(self):
+        sm = ScriptMod()
+        self.assertExpected(str(sm.__getattr__('forward').graph))
+
+    def test_call_python_mod_from_script_module(self):
         class PythonMod(torch.nn.Module):
             def __init__(self):
                 super(PythonMod, self).__init__()
+                self.param = torch.nn.Parameter(torch.rand(3, 5))
+
+            def forward(self, x):
+                return torch.mm(x, self.param)
+
+        class ScriptMod(torch.jit.ScriptModule):
+            def __init__(self):
+                super(ScriptMod, self).__init__()
                 self.param = torch.nn.Parameter(torch.rand(4, 3))
+                self.pm = PythonMod()
 
+            @torch.jit.script_method
             def forward(self, x):
-                return torch.mm(x, self.param)
+                return self.pm(torch.mm(x, self.param))
 
-        pm = PythonMod()
+        sm = ScriptMod()
+        # Note: the call into PythonMod appears as ^<python_value>(). Parameters
+        # are NOT inlined
+        self.assertExpected(str(sm.graph))
 
-        @_trace(torch.rand(3, 4))
+    def test_call_tracing_fn_from_script_module(self):
+        @_trace(torch.rand(3, 3))
         def traced_fn(x):
-            return pm(x) + 1.0
-
-        # Note: the parameter self.param from the Python module is inlined
-        # into the graph
-        self.assertExpected(canonical(traced_fn.graph))
-
-    def test_call_traced_fn_from_tracing_fn(self):
-        @_trace(torch.rand(3, 4))
-        def traced_fn1(x):
             return torch.neg(x)
 
-        @_trace(torch.rand(3, 4))
-        def traced_fn(x):
-            return traced_fn1(x) + 1
+        class ScriptMod(torch.jit.ScriptModule):
+            def __init__(self):
+                super(ScriptMod, self).__init__()
+                self.param = torch.nn.Parameter(torch.rand(4, 3))
 
-        self.assertExpected(canonical(traced_fn.graph))
+            @torch.jit.script_method
+            def forward(self, x):
+                return traced_fn(torch.mm(x, self.param))
 
-    def test_call_traced_mod_from_tracing_fn(self):
-        class TracedModule(torch.nn.Module):
+        sm = ScriptMod()
+        self.assertExpected(str(sm.__getattr__('forward').graph))
+
+    def test_call_tracing_mod_from_script_module(self):
+        class TracedMod(torch.nn.Module):
             def __init__(self):
-                super(TracedModule, self).__init__()
-                self.param = torch.nn.Parameter(torch.rand(4, 3))
+                super(TracedMod, self).__init__()
+                self.param = torch.nn.Parameter(torch.rand(3, 5))
 
             def forward(self, x):
                 return torch.mm(x, self.param)
 
-        tm = torch.jit.trace(TracedModule(), torch.rand(3, 4))
+        class ScriptMod(torch.jit.ScriptModule):
+            def __init__(self):
+                super(ScriptMod, self).__init__()
+                self.param = torch.nn.Parameter(torch.rand(4, 3))
+                self.tm = torch.jit.trace(TracedMod(), torch.rand(3, 3))
 
-        @_trace(torch.rand(3, 4))
-        def traced_fn(x):
-            return tm(x) + 1.0
+            @torch.jit.script_method
+            def forward(self, x):
+                return self.tm(torch.mm(x, self.param))
 
-        # Note: the parameter self.param from the Python module is inlined
-        # into the graph
-        self.assertExpected(canonical(traced_fn.graph))
+        sm = ScriptMod()
+        # Note: the parameters from both modules should appear in the flattened
+        # input list to the graph. The mm op from TracedMod should be properly
+        # inlined
+        self.assertExpected(str(sm.graph))
 
-    def test_call_script_fn_from_tracing_fn(self):
+    def test_call_script_fn_from_script_module(self):
         @torch.jit.script
         def script_fn(x):
             return torch.neg(x)
 
-        @_trace(torch.rand(3, 4))
-        def traced_fn(x):
-            return script_fn(x) + 1
-
-        self.assertExpected(canonical(traced_fn.graph))
-
-    def test_call_script_mod_from_tracing_fn(self):
         class ScriptMod(torch.jit.ScriptModule):
             def __init__(self):
                 super(ScriptMod, self).__init__()
 
             @torch.jit.script_method
             def forward(self, x):
-                return torch.mm(x, self.param)
+                return script_fn(torch.mm(x, self.param))
 
         sm = ScriptMod()
+        self.assertExpected(canonical(sm.__getattr__('forward').graph))
 
-        @_trace(torch.rand(3, 4))
-        def traced_fn(x):
-            return sm(x) + 1.0
-
-        self.assertExpected(canonical(traced_fn.graph))
+    def test_call_script_mod_from_script_module(self):
+        class ScriptMod1(torch.jit.ScriptModule):
+            def __init__(self):
+                super(ScriptMod1, self).__init__()
+                self.param = torch.nn.Parameter(torch.rand(3, 5))
 
-    def test_call_python_fn_from_traced_module(self):
-        def python_fn(x):
-            return torch.neg(x)
+            @torch.jit.script_method
+            def forward(self, x):
+                return torch.mm(x, self.param)
 
-        class TracedModule(torch.nn.Module):
+        class ScriptMod(torch.jit.ScriptModule):
             def __init__(self):
-                super(TracedModule, self).__init__()
+                super(ScriptMod, self).__init__()
                 self.param = torch.nn.Parameter(torch.rand(4, 3))
+                self.tm = ScriptMod1()
 
+            @torch.jit.script_method
             def forward(self, x):
-                return torch.mm(python_fn(x), self.param)
+                return self.tm(torch.mm(x, self.param))
 
-        tm = torch.jit.trace(TracedModule(), torch.rand(3, 4))
+        sm = ScriptMod()
+        # Note: the parameters from both modules should appear in the flattened
+        # input list to the graph. The mm op from ScriptMod1 should be properly
+        # inlined
+        self.assertExpected(canonical(sm.graph))
 
-        # Note: parameter self.param from the traced module should appear as
-        # an input to the graph and the neg op from the Python function should
-        # be properly inlined
-        self.assertExpected(canonical(tm.graph))
+    def test_module_with_params_called_fails(self):
+        with self.assertRaisesRegex(RuntimeError, "Attempted to inline a Module with parameters. Stateful "
+                                                  "modules to be inlined must be submodules of the callee."):
+            class ScriptMod(torch.jit.ScriptModule):
+                def __init__(self):
+                    super(ScriptMod, self).__init__()
+                    self.param = torch.nn.Parameter(torch.rand(3, 3))
 
-    def test_call_python_mod_from_traced_module(self):
-        class PythonModule(torch.nn.Module):
-            def __init__(self):
-                super(PythonModule, self).__init__()
-                self.param = torch.nn.Parameter(torch.rand(5, 7))
+                @torch.jit.script_method
+                def forward(self, x):
+                    return torch.mm(x, self.param)
 
-            def forward(self, x):
-                return torch.mm(x, self.param)
+            sm = ScriptMod()
 
-        class TracedModule(torch.nn.Module):
-            def __init__(self):
-                super(TracedModule, self).__init__()
-                self.param = torch.nn.Parameter(torch.rand(4, 5))
-                self.mod = PythonModule()
+            @torch.jit.script
+            def some_func(x):
+                return sm(x)
 
-            def forward(self, x):
-                return self.mod(torch.mm(x, self.param)) + 1.0
+    def test_index_put_trace_with_view(self):
+        @_trace(torch.rand(100), torch.tensor([1, 2, 3, 4]), torch.rand(1, 1, 1, 4))
+        def test_index_put(target, indices, rhs):
+            target[indices] = rhs
+            return target
 
-        tm = torch.jit.trace(TracedModule(), torch.rand(3, 4))
+        self.assertExpectedGraph(test_index_put.graph)
 
-        # Note: the parameters from both modules should appear in the flattened
-        # inputs of the graph. All ops from both modules should be inlined.
-        self.assertExpected(canonical(tm.graph))
+    def test_index_put_trace_without_view(self):
+        @_trace(torch.rand(100), torch.tensor([1, 2, 3, 4]), torch.rand(4))
+        def test_index_put(target, indices, rhs):
+            target[indices] = rhs
+            return target
 
-    def test_call_traced_fn_from_traced_module(self):
-        @_trace(torch.rand(3, 4))
-        def traced_fn(x):
-            return torch.neg(x)
+        self.assertExpectedGraph(test_index_put.graph)
 
-        class TracedModule(torch.nn.Module):
-            def __init__(self):
-                super(TracedModule, self).__init__()
-                self.param = torch.nn.Parameter(torch.rand(4, 5))
+    def test_tuple_indexing(self):
+        def tuple_index(a):
+            if bool(a):
+                b = (1, 2)
+            else:
+                b = (0, 2)
+            return b[-2], b[1]
+
+        self.checkScript(tuple_index, (torch.tensor([1]),))
+        self.checkScript(tuple_index, (torch.tensor([1]),), optimize=True)
+        tuple_comp = torch.jit.script(tuple_index)
+        self.assertExpectedGraph(tuple_comp.graph)
+        self.run_pass('lower_all_tuples', tuple_comp.graph)
+        m = torch.jit.ScriptModule()
+        m._create_method_from_graph("forward", tuple_comp.graph)
+        self.assertEqual(m(torch.tensor(1)), (1, 2))
+
+        with self.assertRaisesRegex(RuntimeError, "tuple indices must be integer constants"):
+            @torch.jit.script
+            def test_non_constant_input(a):
+                if bool(a):
+                    b = 1
+                else:
+                    b = 0
+                c = (0, 1)
+                return c[b]
+
+        def test_indexing_float():
+            c = (1, 2)
+            return c[0.1]
+        self.checkScriptRaisesRegex(test_indexing_float, (), Exception,
+                                    "tuple indices must")
+
+        def test_indexing_out_of_bounds_pos():
+            c = (1, 2)
+            return c[2]
+
+        self.checkScriptRaisesRegex(test_indexing_out_of_bounds_pos, (), Exception,
+                                    "out of range")
+
+        def test_indexing_out_of_bounds_neg():
+            c = (1, 2)
+            return c[-3]
+
+        self.checkScriptRaisesRegex(test_indexing_out_of_bounds_neg, (), Exception,
+                                    "out of range")
+
+    def test_tuple_slicing(self):
+        def tuple_slice(a):
+            if bool(a):
+                b = (1, 2, 3, 4)
+            else:
+                b = (4, 3, 2, 1)
+            c = b[-4:4]
+            d = b[0:]
+            e = c[1:-1]
+            return e
+
+        self.checkScript(tuple_slice, (torch.tensor([1]),), optimize=True)
+        tuple_comp = torch.jit.script(tuple_slice)
+        self.assertExpectedGraph(tuple_comp.graph)
+        self.run_pass('lower_all_tuples', tuple_comp.graph)
+        self.assertTrue('Tuple' not in str(tuple_comp.graph))
+        m = torch.jit.ScriptModule()
+        m._create_method_from_graph("forward", tuple_comp.graph)
+        self.assertEqual(m(torch.tensor(1)), (2, 3))
+
+        @torch.jit.script
+        def test_indexing_end_out_of_bounds():
+            c = (1, 2)
+            return c[2:10]
+
+        # output is None in script and () in python
+        self.assertEqual(test_indexing_end_out_of_bounds(), None)
+
+    def test_unwrap_optional_builtin(self):
+        def test(x):
+            # type: (Optional[int]) -> int
+            x = torch.jit._unwrap_optional(x)
+            x = x + x
+            return x
+
+        self.checkScript(test, (3,))
+
+        with self.assertRaisesRegex(AssertionError, "Unwrapping null optional"):
+            test(None)
+
+        test_script = torch.jit.script(test)
+        with self.assertRaisesRegex(RuntimeError, "Unwrapping null optional"):
+            test_script(None)
+
+        @torch.jit.script
+        def test_test():
+            return torch.jit._unwrap_optional(1)
+
+        with self.assertRaisesRegex(RuntimeError, "cannot match an Optional\\[T\\] to None"):
+            @torch.jit.script
+            def test_no_type():
+                # type: () -> int
+                return torch.jit._unwrap_optional(None)
 
-            def forward(self, x):
-                return traced_fn(torch.mm(x, self.param))
+    def test_indexing_error(self):
+        with self.assertRaisesRegex(RuntimeError, "Indexing only supported on lists, tensors, and tuples"):
+            @torch.jit.script
+            def test_wrong_type():
+                a = 8
+                return a[0]
 
-        tm = torch.jit.trace(TracedModule(), torch.rand(3, 4))
-        # Note: neg op from the traced function should be properly inlined
-        self.assertExpected(canonical(tm.graph))
+    def test_annotated_script_fn(self):
+        @torch.jit.script
+        def foo(x, y, z):
+            # type: (Tensor, Tuple[Tensor, Tensor, Tensor], Tuple[Tensor, Tuple[Tensor, Tensor]]) -> Tensor
+            return x
 
-    def test_call_traced_module_from_traced_module(self):
-        class TracedModule1(torch.nn.Module):
-            def __init__(self):
-                super(TracedModule1, self).__init__()
-                self.param = torch.nn.Parameter(torch.rand(5, 7))
+        self.assertExpected(foo.__getattr__('forward').pretty_print_schema())
 
-            def forward(self, x):
-                return torch.mm(x, self.param)
+    def test_annotated_script_method(self):
+        class SM(torch.jit.ScriptModule):
+            @torch.jit.script_method
+            def forward(self, x, y):
+                # type: (Tuple[Tensor, Tensor], Tensor) -> Tuple[Tensor, Tensor, Tensor]
+                return y, y, y
 
-        class TracedModule(torch.nn.Module):
-            def __init__(self):
-                super(TracedModule, self).__init__()
-                self.param = torch.nn.Parameter(torch.rand(4, 5))
-                self.mod = torch.jit.trace(TracedModule1(), torch.rand(3, 5))
+        sm = SM()
 
-            def forward(self, x):
-                return self.mod(torch.mm(x, self.param)) + 1.0
+        self.assertExpected(sm.__getattr__('forward').pretty_print_schema())
 
-        tm = torch.jit.trace(TracedModule(), torch.rand(3, 4))
+    def test_annotated_script_fn_return_mismatch(self):
+        with self.assertRaisesRegex(RuntimeError, r"Return value at position 0 was annotated as "
+                                                  r"having type \(Tensor, Tensor\) but is "
+                                                  r"actually of type Tensor"):
+            @torch.jit.script
+            def return_tup(x):
+                # type: (Tensor) -> Tuple[Tuple[Tensor, Tensor], Tensor]
+                return x, x
 
-        # Note: the parameters from both modules should appear in the flattened
-        # inputs of the graph. All ops from both modules should be inlined.
-        self.assertExpected(canonical(tm.graph))
+    def test_annotated_script_fn_arg_mismatch(self):
+        with self.assertRaisesRegex(RuntimeError, r"arguments for call are not valid"):
+            @torch.jit.script
+            def tuple_arg(x):
+                # type: (Tuple[Tensor, Tensor]) -> Tensor
+                return x + 1
 
-    def test_call_script_fn_from_traced_module(self):
+    def test_script_non_tensor_args_outputs(self):
         @torch.jit.script
-        def traced_fn(x):
-            return torch.neg(x)
-
-        class TracedModule(torch.nn.Module):
-            def __init__(self):
-                super(TracedModule, self).__init__()
-                self.param = torch.nn.Parameter(torch.rand(4, 5))
+        def fn(x, y):
+            # type: (Tensor, float) -> float
+            return float((x + y).sum())
 
-            def forward(self, x):
-                return traced_fn(torch.mm(x, self.param))
+        x = torch.ones(2, 2)
+        z = fn(x, 1)
+        self.assertIsInstance(z, float)
+        self.assertEqual(z, 8.)
 
-        tm = torch.jit.trace(TracedModule(), torch.rand(3, 4))
-        # Note: neg op from the script function should be properly inlined
-        self.assertExpected(canonical(tm.graph))
+    @unittest.skip('https://github.com/pytorch/pytorch/issues/9595')
+    def test_inline_and_run_annotated_script_fn(self):
+        @torch.jit.script
+        def to_inline(x, y):
+            # type: (Tuple[Tensor, Tensor], Tensor) -> Tensor
+            return y
 
-    def test_call_script_module_from_traced_module(self):
-        class ScriptMod(torch.jit.ScriptModule):
-            def __init__(self):
-                super(ScriptMod, self).__init__()
-                self.param = torch.nn.Parameter(torch.rand(5, 7))
+        @torch.jit.script
+        def some_func(x):
+            return to_inline((x, x), x)
 
-            @torch.jit.script_method
-            def forward(self, x):
-                return torch.mm(x, self.param)
+        x = torch.rand(3, 4)
+        self.assertEqual(some_func(x), x)
 
-        class TracedModule(torch.nn.Module):
-            def __init__(self):
-                super(TracedModule, self).__init__()
-                self.param = torch.nn.Parameter(torch.rand(4, 5))
-                self.mod = ScriptMod()
+    def test_file_format_serialization(self):
+        import tempfile
+        filename = tempfile.mktemp()
+        writer = torch._C.PyTorchFileWriter(filename)
+        import os
+        import random
+        buffers = [os.urandom(size) for size in [random.randint(1, 100) for i in range(20)]]
+        offsets = []
+        for i, buf in enumerate(buffers):
+            writer.write_record(str(i), buf, len(buf))
+            offsets.append(i)
+        import pickle
+        serialized_offsets = pickle.dumps(offsets)
+        writer.write_record("meta", serialized_offsets, len(serialized_offsets))
+        writer.write_end_of_file()
 
-            def forward(self, x):
-                return self.mod(torch.mm(x, self.param)) + 1.0
+        reader = torch._C.PyTorchFileReader(filename)
+        serialized_offsets_read = reader.get_record("meta")
+        parsed_serialized_offsets = pickle.loads(serialized_offsets_read)
 
-        tm = torch.jit.trace(TracedModule(), torch.rand(3, 4))
+        for i, offset in enumerate(parsed_serialized_offsets):
+            data = reader.get_record(str(offset))
+            assert(data == buffers[i])
 
-        # Note: the parameters from both modules should appear in the flattened
-        # inputs of the graph. All ops from both modules should be inlined.
-        self.assertExpected(canonical(tm.graph))
+    # for each type, the input type annotation and corresponding return type annotation
+    def type_input_return_pairs(self):
+        return [
+            ('Tensor', 'Tensor'),
+            ('torch.Tensor', 'Tensor'),
+            ('str', 'str'),
+            ('int', 'int'),
+            ('bool', 'bool'),
+            ('BroadcastingList3[float]', 'List[float]'),
+            ('BroadcastingList2[int]', 'List[int]'),
+            ('List[int]', 'List[int]'),
+            ('Optional[int]', 'Optional[int]'),
+        ]
 
-    def test_call_python_fn_from_script_fn(self):
-        def python_fn(x):
-            return torch.neg(x)
+    # fill the {input}/{output} placeholders of a code template from a type pair
+    def format_code(self, code, pair):
+        return code.format(input=pair[0], output=pair[1])
 
-        @torch.jit.script
-        def script_fn(x):
-            return python_fn(x) + 1
+    # ***** Type annotation tests *****
+    # Test combinations of:
+    # {String frontend, Python AST Frontend}
+    # {Python 3-style type annotations, MyPy-style type comments}
+    # {Script method, Script function}
 
-        # Note: the call to python_fn appears as `^python_fn()` and is called
-        # as a PythonOp in the interpreter
-        self.assertExpected(canonical(script_fn.graph))
+    #  String frontend , Python 3-style type annotations , Script function
+    def test_annot_string_py3_fn(self):
+        code = '''
+            def foo(x : {input}, y : Tuple[Tensor, Tensor]) -> Tuple[{output}, {output}]:
+                return x, x
+        '''
+        test_str = []
+        for pair in self.type_input_return_pairs():
+            cu = torch.jit.CompilationUnit(self.format_code(code, pair))
+            test_str.append(cu.__getattr__('foo').pretty_print_schema())
+        self.assertExpected("\n".join(test_str))
 
-    def test_call_python_mod_from_script_fn(self):
-        class PythonModule(torch.nn.Module):
+    #  String frontend , Python 3-style type annotations , Script method
+    def test_annot_string_py3_method(self):
+        class TestModule(torch.jit.ScriptModule):
             def __init__(self):
-                super(PythonModule, self).__init__()
-                self.param = torch.nn.Parameter(torch.rand(5, 7))
+                super(TestModule, self).__init__()
 
-            def forward(self, x):
-                return torch.mm(x, self.param)
+        code = '''
+            def foo(self, x : {input}, y : Tuple[Tensor, Tensor]) -> Tuple[{output}, {output}]:
+                return x, x
+        '''
+        test_str = []
+        for pair in self.type_input_return_pairs():
+            tm = TestModule()
+            tm.define(self.format_code(code, pair))
+            test_str.append(tm.__getattr__('foo').pretty_print_schema())
+        self.assertExpected("\n".join(test_str))
 
-        pm = PythonModule()
+    #  String frontend , MyPy-style type comments , Script function
+    def test_annot_string_mypy_fn(self):
+        code = '''
+            def foo(x, y):
+                # type: ({input}, Tuple[Tensor, Tensor]) -> Tuple[{output}, {output}]
+                return x, x
+        '''
+        test_str = []
+        for pair in self.type_input_return_pairs():
+            cu = torch.jit.CompilationUnit(self.format_code(code, pair))
+            test_str.append(cu.__getattr__('foo').pretty_print_schema())
+        self.assertExpected("\n".join(test_str))
 
-        @torch.jit.script
-        def script_fn(x):
-            return pm(x) + 1
+    #  String frontend , MyPy-style type comments , Script method
+    def test_annot_string_mypy_method(self):
+        class TestModule(torch.jit.ScriptModule):
+            def __init__(self):
+                super(TestModule, self).__init__()
 
-        # Note: call to pm(x) appears as ^<python_value>() in the trace.
-        # Parameters are NOT inlined.
-        self.assertExpected(str(script_fn.graph))
+        code = '''
+        def foo(self, x, y):
+            # type: ({input}, Tuple[Tensor, Tensor]) -> Tuple[{output}, {output}]
+            return x, x
+        '''
 
-    def test_call_traced_fn_from_script_fn(self):
-        @_trace(torch.rand(3, 4))
-        def traced_fn(x):
-            return torch.neg(x)
+        test_str = []
+        for pair in self.type_input_return_pairs():
+            tm = TestModule()
+            tm.define(self.format_code(code, pair))
+            test_str.append(tm.__getattr__('foo').pretty_print_schema())
+        self.assertExpected("\n".join(test_str))
 
-        @torch.jit.script
-        def script_fn(x):
-            return traced_fn(x) + 1
+    # Helper to load and execute Python 3-only code without causing a syntax
+    # error when this file is parsed under Python 2
+    def _get_py3_code(self, code, fn_name):
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            script_path = os.path.join(tmp_dir, 'script.py')
+            with open(script_path, 'w') as f:
+                f.write(code)
+            import importlib.util
+            spec = importlib.util.spec_from_file_location(fn_name, script_path)
+            module = importlib.util.module_from_spec(spec)
+            spec.loader.exec_module(module)
+            fn = getattr(module, fn_name)
+            return fn
 
-        # Note: the neg op from traced_fn should be properly inlined into the
-        # script function's graph
-        self.assertExpected(str(script_fn.graph))
+    #  Python AST Frontend , Python 3-style type annotations , Script function
+    @unittest.skipIf(not PY35, "Python 3.5 needed")
+    def test_annot_ast_py3_fn(self):
+        code = dedent('''
+            from typing import Tuple, List, Optional
+            from torch import Tensor
+            from torch.jit.annotations import BroadcastingList2, BroadcastingList3
+            import torch
+            @torch.jit.script
+            def foo(x : {input}, y : Tuple[Tensor, Tensor]) -> Tuple[{output}, {output}]:
+                return x, x
+        ''')
+        test_str = []
+        for pair in self.type_input_return_pairs():
+            fn = self._get_py3_code(self.format_code(code, pair), 'foo')
+            test_str.append(fn.__getattr__('forward').pretty_print_schema())
+        self.assertExpected("\n".join(test_str))
 
-    def test_call_traced_mod_from_script_fn(self):
-        class TracedModule(torch.nn.Module):
-            def __init__(self):
-                super(TracedModule, self).__init__()
+    #  Python AST Frontend , Python 3-style type annotations , Script method
+    @unittest.skipIf(not PY35, "Python 3.5 needed")
+    def test_annot_ast_py3_method(self):
+        code = dedent('''
+            from typing import Tuple, List, Optional
+            from torch import Tensor
+            from torch.jit.annotations import BroadcastingList2, \\
+                BroadcastingList3
+            import torch
+            class FooModule(torch.jit.ScriptModule):
+                @torch.jit.script_method
+                def foo(self, x : {input}, y : Tuple[Tensor, Tensor]) -> Tuple[{output}, {output}]:
+                    return x, x
+            instance = FooModule()
+        ''')
 
-            def forward(self, x):
-                return torch.mm(x, torch.zeros(4, 3))
+        test_str = []
+        for pair in self.type_input_return_pairs():
+            fn = self._get_py3_code(self.format_code(code, pair), 'instance')
+            test_str.append(fn.__getattr__('foo').pretty_print_schema())
+        self.assertExpected("\n".join(test_str))
 
-        tm = torch.jit.trace(TracedModule(), torch.rand(3, 4))
+    #  Python AST Frontend , MyPy-style type comments , Script function
+    @unittest.skipIf(not PY35, "Python 3.5 needed")
+    def test_annot_ast_mypy_fn(self):
+        code = dedent('''
+            import torch
+            @torch.jit.script
+            def foo(x, y):
+                # type: ({input}, Tuple[Tensor, Tensor]) -> Tuple[{output}, {output}]
+                return x, x
+        ''')
 
-        @torch.jit.script
-        def script_fn(x):
-            return tm(x) + 1
+        test_str = []
+        for pair in self.type_input_return_pairs():
+            fn = self._get_py3_code(self.format_code(code, pair), 'foo')
+            test_str.append(fn.__getattr__('forward').pretty_print_schema())
+        self.assertExpected("\n".join(test_str))
 
-        self.assertExpected(str(script_fn.graph))
+    #  Python AST Frontend , MyPy-style type comments , Script method
+    @unittest.skipIf(not PY35, "Python 3.5 needed")
+    def test_annot_ast_mypy_method(self):
+        code = dedent('''
+            import torch
+            class FooModule(torch.jit.ScriptModule):
+                @torch.jit.script_method
+                def foo(self, x, y):
+                    # type: ({input}, Tuple[Tensor, Tensor]) -> Tuple[{output}, {output}]
+                    return x, x
+            instance = FooModule()
+        ''')
 
-    def test_call_script_fn_from_script_fn(self):
-        @torch.jit.script
-        def script_fn1(x):
-            return torch.neg(x)
+        test_str = []
+        for pair in self.type_input_return_pairs():
+            fn = self._get_py3_code(self.format_code(code, pair), 'instance')
+            test_str.append(fn.__getattr__('foo').pretty_print_schema())
+        self.assertExpected("\n".join(test_str))
 
-        @torch.jit.script
-        def script_fn(x):
-            return script_fn1(x) + 1
+    def test_method_casts_script(self):
+        cast_types = [
+            'byte', 'char', 'double', 'float', 'int', 'long', 'short'
+        ]
 
-        # Note: the neg op from script_fn1 should be properly inlined into the
-        # graph of script_fn
-        self.assertExpected(canonical(script_fn.graph))
+        for cast_type in cast_types:
+            cu = torch.jit.CompilationUnit('''
+            def cast_to(x):
+                return x.{cast_type}()
+            '''.format(cast_type=cast_type))
 
-    def test_call_script_mod_from_script_fn(self):
-        class ScriptMod(torch.jit.ScriptModule):
-            def __init__(self):
-                super(ScriptMod, self).__init__()
+            x = torch.rand(3, 4, 5) * 128
+            cu_result = cu.cast_to(x)
+            reference = getattr(x, cast_type)()
+            self.assertEqual(cu_result, reference)
 
-            @torch.jit.script_method
+    def test_listconstruct_erasure(self):
+        class FooMod(torch.nn.Module):
             def forward(self, x):
-                return torch.mm(x, torch.zeros([4, 3]))
+                mask = x < 0.0
+                return x[mask]
 
-        sm = ScriptMod()
+        import io
+        f = io.BytesIO()
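+        # Mask indexing lowers through a ListConstruct of index tensors; per the test name, the
+        # ONNX export with ATen fallback is expected to erase it.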
+        self.assertExpected(torch.onnx.export_to_pretty_string(
+            FooMod(), (torch.rand(3, 4),), f,
+            operator_export_type=OperatorExportTypes.ONNX_ATEN_FALLBACK))
 
-        @torch.jit.script
-        def script_fn(x):
-            return sm(x) + 1
+    def test_trace_checker_arange_as_constant(self):
+        with self.assertRaisesRegex(torch.jit.TracingCheckError, r'Graphs differed across invocations!'):
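+            # x.shape[0] is baked into the traced arange as a constant, so a check_input with a
+            # different first dimension produces a different graph.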
+            @_trace(torch.rand(3, 4), check_inputs=[(torch.rand(4, 5),)])
+            def foo(x):
+                y = torch.arange(0, x.shape[0]).double()
+                return x + y.unsqueeze(1)
 
-        self.assertExpected(canonical(script_fn.graph))
+    @suppress_warnings
+    def test_trace_checker_dot_data(self):
+        with self.assertRaisesRegex(torch.jit.TracingCheckError, r'Tensor-valued Constant nodes differed in value '
+                                                                 r'across invocations'):
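+            # x.data detaches the value from tracing, so y is recorded as a tensor-valued constant
+            # whose value differs between the two invocations.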
+            @_trace(torch.rand(3, 4), check_inputs=[(torch.rand(3, 4),)])
+            def foo(x):
+                y = x.data
+                return x + y
 
-    def test_call_python_fn_from_script_module(self):
-        def python_fn(x):
-            return torch.neg(x)
+    @suppress_warnings
+    def test_trace_checker_control_flow(self):
+        def foo(x):
+            for _ in range(x.size(0)):
+                x = torch.neg(x)
+            return x
 
-        class ScriptMod(torch.jit.ScriptModule):
-            def __init__(self):
-                super(ScriptMod, self).__init__()
-                self.param = torch.nn.Parameter(torch.rand(4, 3))
+        with self.assertRaisesRegex(torch.jit.TracingCheckError, r'Graphs differed across invocations!'):
+            torch.jit.trace(foo, torch.randn(3, 4), check_inputs=[torch.randn(4, 4)])
 
-            @torch.jit.script_method
-            def forward(self, x):
-                return python_fn(torch.mm(x, self.param))
+    @suppress_warnings
+    def test_trace_checker_memoization(self):
+        with self.assertRaisesRegex(torch.jit.TracingCheckError, r'Graphs differed across invocations!'):
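+            # foo caches the negation of the first input it sees, so re-tracing with check_inputs
+            # produces a different graph.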
+            def foo(x):
+                if not hasattr(foo, 'cache'):
+                    foo.cache = torch.neg(x)
+                return x + foo.cache
 
-        sm = ScriptMod()
-        self.assertExpected(str(sm.__getattr__('forward').graph))
+            traced = torch.jit.trace(foo, torch.rand(3, 4), check_inputs=[(torch.rand(3, 4),)])
 
-    def test_call_python_mod_from_script_module(self):
-        class PythonMod(torch.nn.Module):
-            def __init__(self):
-                super(PythonMod, self).__init__()
-                self.param = torch.nn.Parameter(torch.rand(3, 5))
+    def checkTracerWarning(self, *args, **kwargs):
+        with warnings.catch_warnings(record=True) as warns:
+            torch.jit.trace(*args, **kwargs)
+        self.assertGreater(len(warns), 0)
+        for warn in warns:
+            self.assertIn("cause the trace to be incorrect", str(warn.message))
 
-            def forward(self, x):
-                return torch.mm(x, self.param)
+    def test_trace_checker_slice_lhs(self):
+        def foo(x):
+            for i in range(3):
+                x[i, :] = torch.zeros(4)
+            return x
 
-        class ScriptMod(torch.jit.ScriptModule):
-            def __init__(self):
-                super(ScriptMod, self).__init__()
-                self.param = torch.nn.Parameter(torch.rand(4, 3))
-                self.pm = PythonMod()
+        self.assertWarnsRegex(lambda: torch.jit.trace(foo, torch.rand(3, 4), check_inputs=[torch.rand(3, 4)]),
+                              'Output nr 1. of the traced function does not match the '
+                              'corresponding output of the Python function')
 
-            @torch.jit.script_method
-            def forward(self, x):
-                return self.pm(torch.mm(x, self.param))
+    def test_trace_checker_inplace_on_view(self):
+        def foo(x):
+            x.view(-1).add_(-x.view(-1))
+            return x
 
-        sm = ScriptMod()
-        # Note: the call into PythonMod appears as ^<python_value>(). Parameters
-        # are NOT inlined
-        self.assertExpected(str(sm.graph))
+        self.assertWarnsRegex(lambda: torch.jit.trace(foo,
+                                                      torch.rand(3, 4),
+                                                      check_inputs=[torch.rand(5, 6)],
+                                                      _force_outplace=True),
+                              'Output nr 1. of the traced function does not match the '
+                              'corresponding output of the Python function')
 
-    def test_call_tracing_fn_from_script_module(self):
-        @_trace(torch.rand(3, 3))
-        def traced_fn(x):
-            return torch.neg(x)
+    def test_lhs_index_fails(self):
+        def foo(x):
+            x[0, 1] = 4
+            return x
+        self.checkTracerWarning(foo, torch.rand(3, 4), _force_outplace=True)
 
-        class ScriptMod(torch.jit.ScriptModule):
-            def __init__(self):
-                super(ScriptMod, self).__init__()
-                self.param = torch.nn.Parameter(torch.rand(4, 3))
+    def test_lhs_index_trivial(self):
+        def foo(y, x):
+            y[...] = x
+            return y
+        self.checkTrace(foo, (torch.rand(3, 4), torch.rand(4)), inputs_require_grads=False)
 
-            @torch.jit.script_method
-            def forward(self, x):
-                return traced_fn(torch.mm(x, self.param))
+    def test_inplace_warn(self):
+        def foo(x):
+            x.view(-1).add_(-x.view(-1))
+            return x
+        self.checkTracerWarning(foo, torch.rand(3, 4), _force_outplace=True)
 
-        sm = ScriptMod()
-        self.assertExpected(str(sm.__getattr__('forward').graph))
+    @suppress_warnings
+    def test_trace_checker_dropout_train(self):
+        def foo(x):
+            return torch.dropout(x, p=0.5, train=True)
 
-    def test_call_tracing_mod_from_script_module(self):
-        class TracedMod(torch.nn.Module):
-            def __init__(self):
-                super(TracedMod, self).__init__()
-                self.param = torch.nn.Parameter(torch.rand(3, 5))
+        self.assertWarnsRegex(lambda: torch.jit.trace(foo, torch.rand(3, 4), check_inputs=[torch.rand(5, 6)]),
+                              'Output nr 1. of the traced function does not match the '
+                              'corresponding output of the Python function')
+        self.assertWarnsRegex(lambda: torch.jit.trace(foo, torch.rand(3, 4), check_inputs=[torch.rand(5, 6)]),
+                              'Trace had nondeterministic nodes')
 
-            def forward(self, x):
-                return torch.mm(x, self.param)
+    def test_trace_checker_dropout_notrain(self):
+        input = torch.rand(3, 4)
 
-        class ScriptMod(torch.jit.ScriptModule):
-            def __init__(self):
-                super(ScriptMod, self).__init__()
-                self.param = torch.nn.Parameter(torch.rand(4, 3))
-                self.tm = torch.jit.trace(TracedMod(), torch.rand(3, 3))
+        @_trace(input)
+        def foo(x):
+            return torch.dropout(x, p=0.5, train=False)
+
+        self.assertEqual(foo(input), input)
 
+    def test_export_dynamic_slice(self):
+        class DynamicSliceExportMod(torch.jit.ScriptModule):
             @torch.jit.script_method
             def forward(self, x):
-                return self.tm(torch.mm(x, self.param))
+                retval = x[0]
+                for i in range(x.size(1)):
+                    retval += torch.sum(x[0:i], dim=0)
+                return retval
 
-        sm = ScriptMod()
-        # Note: the parameters from both modules should appear in the flattened
-        # input list to the graph. The mm op from TracedMod should be properly
-        # inlined
-        self.assertExpected(str(sm.graph))
+        mod = DynamicSliceExportMod()
 
-    def test_call_script_fn_from_script_module(self):
-        @torch.jit.script
-        def script_fn(x):
-            return torch.neg(x)
+        input = torch.rand(3, 4, 5)
+        example_outs = mod(input)
 
-        class ScriptMod(torch.jit.ScriptModule):
-            def __init__(self):
-                super(ScriptMod, self).__init__()
-                self.param = torch.nn.Parameter(torch.rand(4, 3))
+        f = io.BytesIO()
+        exported = torch.onnx.export_to_pretty_string(
+            DynamicSliceExportMod(), (input,), f, example_outputs=example_outs)
+        self.assertExpected(exported)
 
-            @torch.jit.script_method
-            def forward(self, x):
-                return script_fn(torch.mm(x, self.param))
+    def test_string_frontend_elif(self):
+        code = '''
+            def elif_test(niter : int):
+                rv = 0
+                for i in range(niter):
+                    if i % 3 == 0 and i % 5 == 0:
+                        rv += 35
+                    elif i % 3 == 0:
+                        rv += 3
+                    elif i % 5 == 0:
+                        rv += 5
+                    else:
+                        rv += i
+                return rv
+        '''
 
-        sm = ScriptMod()
-        self.assertExpected(canonical(sm.__getattr__('forward').graph))
+        self.checkScript(code, (101,), name='elif_test', outputs=3028)
 
-    def test_call_script_mod_from_script_module(self):
-        class ScriptMod1(torch.jit.ScriptModule):
-            def __init__(self):
-                super(ScriptMod1, self).__init__()
-                self.param = torch.nn.Parameter(torch.rand(3, 5))
+    def test_addmm_fusion(self):
+        class AddmmWrapper(torch.nn.Module):
+            def forward(self, x, y, c):
+                return torch.mm(x, y) + c
 
-            @torch.jit.script_method
-            def forward(self, x):
-                return torch.mm(x, self.param)
+        # Check that addmm fusion is disabled for the normal JIT path
+        x, y, c = torch.rand(3, 4), torch.rand(4, 5), torch.rand(3, 5)
+        f = io.BytesIO()
+        pretty = torch.onnx.export_to_pretty_string(AddmmWrapper(), (x, y, c), f)
+        self.assertExpected(pretty, 'onnx')
 
-        class ScriptMod(torch.jit.ScriptModule):
+        jit_trace = torch.jit.trace(AddmmWrapper(), (x, y, c))
+        ge_graph = jit_trace.__getattr__('forward').graph_for(x, y, c)
+        self.assertExpectedGraph(ge_graph, 'jit')
+
+    def test_pyop_exception_message(self):
+        class Foo(torch.jit.ScriptModule):
             def __init__(self):
-                super(ScriptMod, self).__init__()
-                self.param = torch.nn.Parameter(torch.rand(4, 3))
-                self.tm = ScriptMod1()
+                super(Foo, self).__init__()
+                self.conv = nn.Conv2d(1, 10, kernel_size=5)
 
             @torch.jit.script_method
             def forward(self, x):
-                return self.tm(torch.mm(x, self.param))
-
-        sm = ScriptMod()
-        # Note: the parameters from both modules should appear in the flattened
-        # input list to the graph. The mm op from ScriptMod1 should be properly
-        # inlined
-        self.assertExpected(canonical(sm.graph))
-
-    def test_module_with_params_called_fails(self):
-        with self.assertRaisesRegex(RuntimeError, "Attempted to inline a Module with parameters. Stateful "
-                                                  "modules to be inlined must be submodules of the callee."):
-            class ScriptMod(torch.jit.ScriptModule):
-                def __init__(self):
-                    super(ScriptMod, self).__init__()
-                    self.param = torch.nn.Parameter(torch.rand(3, 3))
-
-                @torch.jit.script_method
-                def forward(self, x):
-                    return torch.mm(x, self.param)
-
-            sm = ScriptMod()
+                return self.conv(x)
+        foo = Foo()
+        # testing that the correct error message propagates
+        with self.assertRaisesRegex(RuntimeError, "Expected 4-dimensional input for 4-dimensional weight"):
+            foo(torch.ones([123]))  # wrong size
 
-            @torch.jit.script
-            def some_func(x):
-                return sm(x)
+    def test_exceptions(self):
+        cu = torch.jit.CompilationUnit('''
+            def foo(cond):
+                if bool(cond):
+                    raise ValueError(3)
+                return 1
+        ''')
 
-    def test_index_put_trace_with_view(self):
-        @_trace(torch.rand(100), torch.tensor([1, 2, 3, 4]), torch.rand(1, 1, 1, 4))
-        def test_index_put(target, indices, rhs):
-            target[indices] = rhs
-            return target
+        cu.foo(torch.tensor(0))
+        with self.assertRaisesRegex(torch.jit.Error, "Exception"):
+            cu.foo(torch.tensor(1))
 
-        self.assertExpectedGraph(test_index_put.graph)
+        @torch.jit.script
+        def foo(cond):
+            a = 3
+            if bool(cond):
+                raise ArbitraryError(a, "hi")
+                if False:
+                    raise ArbitraryError
+            return a
 
-    def test_index_put_trace_without_view(self):
-        @_trace(torch.rand(100), torch.tensor([1, 2, 3, 4]), torch.rand(4))
-        def test_index_put(target, indices, rhs):
-            target[indices] = rhs
-            return target
+        foo(torch.tensor(0))
+        # we don't currently validate the name of the exception
+        with self.assertRaisesRegex(torch.jit.Error, "Exception"):
+            foo(torch.tensor(1))
 
-        self.assertExpectedGraph(test_index_put.graph)
+        @torch.jit.script
+        def foo_except_used():
+            a = Exception()
+            print(a)
+            raise a
 
-    def test_tuple_indexing(self):
-        def tuple_index(a):
-            if bool(a):
-                b = (1, 2)
-            else:
-                b = (0, 2)
-            return b[-2], b[1]
+        # `a` is used, so it is not dead-code-eliminated
+        with self.assertRaisesRegex(RuntimeError, "expected value of type Tensor"):
+            foo_except_used()
 
-        self.checkScript(tuple_index, (torch.tensor([1]),))
-        self.checkScript(tuple_index, (torch.tensor([1]),), optimize=True)
-        tuple_comp = torch.jit.script(tuple_index)
-        self.assertExpectedGraph(tuple_comp.graph)
-        self.run_pass('lower_all_tuples', tuple_comp.graph)
-        m = torch.jit.ScriptModule()
-        m._create_method_from_graph("forward", tuple_comp.graph)
-        self.assertEqual(m(torch.tensor(1)), (1, 2))
+        # We don't validate the expr following raise
+        @torch.jit.script
+        def foo():
+            raise 3 + 4
 
-        with self.assertRaisesRegex(RuntimeError, "tuple indices must be integer constants"):
+        # no control flow analysis yet
+        with self.assertRaisesRegex(RuntimeError, "undefined value a"):
             @torch.jit.script
-            def test_non_constant_input(a):
-                if bool(a):
-                    b = 1
+            def foo():
+                if True:
+                    a = 1
                 else:
-                    b = 0
-                c = (0, 1)
-                return c[b]
-
-        def test_indexing_float():
-            c = (1, 2)
-            return c[0.1]
-        self.checkScriptRaisesRegex(test_indexing_float, (), Exception,
-                                    "tuple indices must")
+                    raise Exception("Hi")
+                return a
 
-        def test_indexing_out_of_bounds_pos():
-            c = (1, 2)
-            return c[2]
+    def test_assertions(self):
+        cu = torch.jit.CompilationUnit('''
+            def foo(cond):
+                assert bool(cond), "hi"
+                return 0
+        ''')
 
-        self.checkScriptRaisesRegex(test_indexing_out_of_bounds_pos, (), Exception,
-                                    "out of range")
+        cu.foo(torch.tensor(1))
+        with self.assertRaisesRegex(torch.jit.Error, "Exception"):
+            cu.foo(torch.tensor(0))
 
-        def test_indexing_out_of_bounds_neg():
-            c = (1, 2)
-            return c[-3]
+        @torch.jit.script
+        def foo(cond):
+            assert bool(cond), "hi"
 
-        self.checkScriptRaisesRegex(test_indexing_out_of_bounds_pos, (), Exception,
-                                    "out of range")
+        foo(torch.tensor(1))
+        # we don't currently validate the name of the exception
+        with self.assertRaisesRegex(torch.jit.Error, "Exception"):
+            foo(torch.tensor(0))
 
-    def test_tuple_slicing(self):
-        def tuple_slice(a):
-            if bool(a):
-                b = (1, 2, 3, 4)
-            else:
-                b = (4, 3, 2, 1)
-            c = b[-4:4]
-            d = b[0:]
-            e = c[1:-1]
-            return e
+    def test_weak_script_function(self):
+        outer_var = 10
+        outer_var2 = 11
 
-        self.checkScript(tuple_slice, (torch.tensor([1]),), optimize=True)
-        tuple_comp = torch.jit.script(tuple_slice)
-        self.assertExpectedGraph(tuple_comp.graph)
-        self.run_pass('lower_all_tuples', tuple_comp.graph)
-        self.assertTrue('Tuple' not in str(tuple_comp.graph))
-        m = torch.jit.ScriptModule()
-        m._create_method_from_graph("forward", tuple_comp.graph)
-        self.assertEqual(m(torch.tensor(1)), (2, 3))
+        def not_a_script_fn(x):
+            return x + 2
 
         @torch.jit.script
-        def test_indexing_end_out_of_bounds():
-            c = (1, 2)
-            return c[2:10]
-
-        # output is None in script and () in python
-        self.assertEqual(test_indexing_end_out_of_bounds(), None)
+        def even_more_inner(x):
+            return x + 1
 
-    def test_unwrap_optional_builtin(self):
-        def test(x):
-            # type: (Optional[int]) -> int
-            x = torch.jit._unwrap_optional(x)
-            x = x + x
-            return x
+        @torch.jit.script
+        def inner(x):
+            return not_a_script_fn(x) + x + even_more_inner(x)
 
-        self.checkScript(test, (3,))
+        @torch.jit.script
+        def strong_script_fn(x):
+            if bool(x.norm() > 2):
+                x = x + 3
+            return x + 4 + inner(x)
 
-        with self.assertRaisesRegex(AssertionError, "Unwrapping null optional"):
-            test(None)
+        @torch._jit_internal.weak_script
+        def weak_script_fn_inner(x):
+            return x + 6 + not_a_script_fn(x)
 
-        test_script = torch.jit.script(test)
-        with self.assertRaisesRegex(RuntimeError, "Unwrapping null optional"):
-            test_script(None)
+        @torch._jit_internal.weak_script
+        def weak_script_fn(x):
+            return x + 5 + weak_script_fn_inner(x) + weak_script_fn_inner(x)
 
-        @torch.jit.script
-        def test_test():
-            return torch.jit._unwrap_optional(1)
+        def fn(x):
+            x = not_a_script_fn(x)
+            x = strong_script_fn(x)
+            return weak_script_fn(x)
 
-        with self.assertRaisesRegex(RuntimeError, "cannot match an Optional\\[T\\] to None"):
-            @torch.jit.script
-            def test_no_type():
-                # type: () -> int
-                return torch.jit._unwrap_optional(None)
+        input = torch.randn(3, 4, 5)
+        self.checkScript(fn, (input,))
 
-    def test_indexing_error(self):
-        with self.assertRaisesRegex(RuntimeError, "Indexing only supported on lists, tensors, and tuples"):
-            @torch.jit.script
-            def test_wrong_type():
-                a = 8
-                return a[0]
+    def test_python_op_exception(self):
+        def python_op(x):
+            raise Exception("bad!")
 
-    def test_annotated_script_fn(self):
         @torch.jit.script
-        def foo(x, y, z):
-            # type: (Tensor, Tuple[Tensor, Tensor, Tensor], Tuple[Tensor, Tuple[Tensor, Tensor]]) -> Tensor
-            return x
+        def fn(x):
+            return python_op(x)
 
-        self.assertExpected(foo.__getattr__('forward').pretty_print_schema())
+        with self.assertRaisesRegex(RuntimeError, "operation failed in interpreter"):
+            fn(torch.tensor(4))
 
-    def test_annotated_script_method(self):
-        class SM(torch.jit.ScriptModule):
-            @torch.jit.script_method
-            def forward(self, x, y):
-                # type: (Tuple[Tensor, Tensor], Tensor) -> Tuple[Tensor, Tensor, Tensor]
-                return y, y, y
+    def test_trace_contiguous(self):
+        def foo(x):
+            return x[:, :, ::2].contiguous().view(12)
 
-        sm = SM()
+        x = torch.rand(2, 3, 4)
+        traced = torch.jit.trace(foo, (x,))
+        y = traced(x)
+        self.assertNotEqual(x.storage().data_ptr(), y.storage().data_ptr())
 
-        self.assertExpected(sm.__getattr__('forward').pretty_print_schema())
+    # This tests the logic in THPVariable_contiguous. There is short-circuiting
+    # code that prevents us from even getting to VariableType::contiguous, since
+    # it is an optimization that prevents us from acquiring the GIL for touching
+    # the device. We needed to add the tracing logic directly into the
+    # THPVariable_contiguous function only for the path where we are skipping
+    # dispatch into contiguous. We should see an aten::contiguous in this trace!
+    def test_trace_contiguous_short_circuit(self):
+        def foo(x):
+            return x.contiguous()
 
-    def test_annotated_script_fn_return_mismatch(self):
-        with self.assertRaisesRegex(RuntimeError, r"Return value at position 0 was annotated as "
-                                                  r"having type \(Tensor, Tensor\) but is "
-                                                  r"actually of type Tensor"):
-            @torch.jit.script
-            def return_tup(x):
-                # type: (Tensor) -> Tuple[Tuple[Tensor, Tensor], Tensor]
-                return x, x
+        x = torch.rand(2, 3, 4)
+        traced = torch.jit.trace(foo, (x,))
+        self.assertExpectedGraph(traced.graph)
 
-    def test_annotated_script_fn_arg_mismatch(self):
-        with self.assertRaisesRegex(RuntimeError, r"arguments for call are not valid"):
-            @torch.jit.script
-            def tuple_arg(x):
-                # type: (Tuple[Tensor, Tensor]) -> Tensor
-                return x + 1
+    def test_weak_module(self):
 
-    def test_script_non_tensor_args_outputs(self):
-        @torch.jit.script
-        def fn(x, y):
-            # type: (Tensor, float) -> float
-            return float((x + y).sum())
+        @torch._jit_internal.weak_module
+        class Weak(torch.nn.Module):
+            __constants__ = ['number']
 
-        x = torch.ones(2, 2)
-        z = fn(x, 1)
-        self.assertIsInstance(z, float)
-        self.assertEqual(z, 8.)
+            def __init__(self):
+                super(Weak, self).__init__()
+                self.number = 199
 
-    @unittest.skip('https://github.com/pytorch/pytorch/issues/9595')
-    def test_inline_and_run_annotated_script_fn(self):
-        @torch.jit.script
-        def to_inline(x, y):
-            # type: (Tuple[Tensor, Tensor], Tensor) -> Tensor
-            return y
+            def python_op_in_weak_module(self, x):
+                return x + 123
 
-        @torch.jit.script
-        def some_func(x):
-            return to_inline((x, x), x)
+            @torch._jit_internal.weak_script_method
+            def forward(self, x):
+                return 55 + self.number + self.python_op_in_weak_module(x)
 
-        x = torch.rand(3, 4)
-        self.assertEqual(some_func(x), x)
+        class OtherStrong(torch.jit.ScriptModule):
+            __constants__ = ['number']
 
-    def test_file_format_serialization(self):
-        import tempfile
-        filename = tempfile.mktemp()
-        writer = torch._C.PyTorchFileWriter(filename)
-        import os
-        import random
-        buffers = [os.urandom(size) for size in [random.randint(1, 100) for i in range(20)]]
-        offsets = []
-        for i, buf in enumerate(buffers):
-            writer.write_record(str(i), buf, len(buf))
-            offsets.append(i)
-        import pickle
-        serialized_offsets = pickle.dumps(offsets)
-        writer.write_record("meta", serialized_offsets, len(serialized_offsets))
-        writer.write_end_of_file()
+            def __init__(self):
+                super(OtherStrong, self).__init__()
+                self.number = 357
 
-        reader = torch._C.PyTorchFileReader(filename)
-        serialized_offsets_read = reader.get_record("meta")
-        parsed_serialized_offsets = pickle.loads(serialized_offsets)
+            def python_op_in_strong_module(self, x):
+                return x + 456
 
-        for i, offset in enumerate(parsed_serialized_offsets):
-            data = reader.get_record(str(offset))
-            assert(data == buffers[i])
+            @torch.jit.script_method
+            def forward(self, x):
+                return x + self.number + self.python_op_in_strong_module(x)
 
-    # for each type, the input type annotation and corresponding return type annotation
-    def type_input_return_pairs(self):
-        return [
-            ('Tensor', 'Tensor'),
-            ('torch.Tensor', 'Tensor'),
-            ('str', 'str'),
-            ('int', 'int'),
-            ('bool', 'bool'),
-            ('BroadcastingList3[float]', 'List[float]'),
-            ('BroadcastingList2[int]', 'List[int]'),
-            ('List[int]', 'List[int]'),
-            ('Optional[int]', 'Optional[int]'),
-        ]
+        class Passthrough(torch.jit.ScriptModule):
+            def __init__(self):
+                super(Passthrough, self).__init__()
+                self.weak = Weak()
 
-    # replacing code input & return type pair
-    def format_code(self, code, pair):
-        return code.format(input=pair[0], output=pair[1])
+            @torch.jit.script_method
+            def forward(self, x):
+                return self.weak(x)
 
-    # ***** Type annotation tests ****
-    # Test combinations of:
-    # {String frontend, Python AST Frontend}
-    # {Python 3-style type annotations, MyPy-style type comments}
-    # {Script method, Script function}
+        weak_mod = Weak()
+        x = torch.ones(1)
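+        # Mirrors Weak.forward: 55 + self.number (199) + python_op_in_weak_module(x), i.e. x + 123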
+        expected_result = 55 + 199 + (x + 123)
 
-    #  String frontend , Python 3-style type annotations , Script function
-    def test_annot_string_py3_fn(self):
-        code = '''
-            def foo(x : {input}, y : Tuple[Tensor, Tensor]) -> Tuple[{output}, {output}]:
-                return x, x
-        '''
-        test_str = []
-        for pair in self.type_input_return_pairs():
-            cu = torch.jit.CompilationUnit(self.format_code(code, pair))
-            test_str.append(cu.__getattr__('foo').pretty_print_schema())
-        self.assertExpected("\n".join(test_str))
+        # Ensure weak mod is running without the JIT by passing the wrong type
+        # (i.e. not a tensor)
+        weak_mod(2)
 
-    #  String frontend , Python 3-style type annotations , Script method
-    def test_annot_string_py3_method(self):
-        class TestModule(torch.jit.ScriptModule):
+        python_result = weak_mod(x)
+        strong_mod = Passthrough()
+        script_result = strong_mod(x)
+
+        self.assertEqual(python_result, expected_result)
+        self.assertEqual(script_result, expected_result)
+
+        class Strong(torch.jit.ScriptModule):
             def __init__(self):
-                super(TestModule, self).__init__()
+                super(Strong, self).__init__()
+                self.weak = Weak()
+                self.strong = OtherStrong()
 
-        code = '''
-            def foo(self, x : {input}, y : Tuple[Tensor, Tensor]) -> Tuple[{output}, {output}]:
-                return x, x
-        '''
-        test_str = []
-        for pair in self.type_input_return_pairs():
-            tm = TestModule()
-            tm.define(self.format_code(code, pair))
-            test_str.append(tm.__getattr__('foo').pretty_print_schema())
-        self.assertExpected("\n".join(test_str))
+            @torch.jit.script_method
+            def forward(self, x):
+                y = 2 * x
+                return y + 1 + self.weak(y) + self.strong(y)
 
-    #  String frontend , MyPy-style type comments , Script function
-    def test_annot_string_mypy_fn(self):
-        code = '''
-            def foo(x, y):
-                # type: ({input}, Tuple[Tensor, Tensor]) -> Tuple[{output}, {output}]
-                return x, x
-        '''
-        test_str = []
-        for pair in self.type_input_return_pairs():
-            cu = torch.jit.CompilationUnit(self.format_code(code, pair))
-            test_str.append(cu.__getattr__('foo').pretty_print_schema())
-        self.assertExpected("\n".join(test_str))
+        strong_mod = Strong()
+        strong_mod2 = Strong()
+        x = torch.ones(1)
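+        # Strong.forward computes y = 2 * x, then returns y + 1 + Weak(y) + OtherStrong(y)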
+        expected_result = (x * 2) + 1 + (55 + 199 + x * 2 + 123) + (x * 2 + 357 + x * 2 + 456)
+        script_result = strong_mod(x)
+        script_result2 = strong_mod2(x)
+        self.assertEqual(script_result, expected_result)
+        self.assertEqual(script_result, script_result2)
 
-    #  String frontend , MyPy-style type comments , Script method
-    def test_annot_string_mypy_method(self):
-        class TestModule(torch.jit.ScriptModule):
-            def __init__(self):
-                super(TestModule, self).__init__()
+    def test_weak_module_parameters_and_buffers(self):
+        weights = torch.randn(10, 10)
+        bias = torch.randn(10)
+        weights2 = torch.randn(10, 10)
+        bias2 = torch.randn(10)
 
-        code = '''
-        def foo(self, x, y):
-            # type: ({input}, Tuple[Tensor, Tensor]) -> Tuple[{output}, {output}]
-            return x, x
-        '''
+        @torch._jit_internal.weak_module
+        class TestLinear(torch.nn.Module):
+            def __init__(self, in_features, out_features):
+                super(TestLinear, self).__init__()
+                self.in_features = in_features
+                self.out_features = out_features
+                self.weight = torch.nn.Parameter(torch.Tensor(out_features, in_features))
+                self.bias = torch.nn.Parameter(torch.Tensor(out_features))
+                self.register_buffer('counter', torch.ones(out_features))
+                self.reset_parameters()
 
-        test_str = []
-        for pair in self.type_input_return_pairs():
-            tm = TestModule()
-            tm.define(self.format_code(code, pair))
-            test_str.append(tm.__getattr__('foo').pretty_print_schema())
-        self.assertExpected("\n".join(test_str))
+            def reset_parameters(self):
+                torch.nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
+                if self.bias is not None:
+                    fan_in, _ = torch.nn.init._calculate_fan_in_and_fan_out(self.weight)
+                    bound = 1 / math.sqrt(fan_in)
+                    torch.nn.init.uniform_(self.bias, -bound, bound)
 
-    # Helper function to eval Python3 code without causing a syntax error for
-    # this file under py2
-    def _get_py3_code(self, code, fn_name):
-        with tempfile.TemporaryDirectory() as tmp_dir:
-            script_path = os.path.join(tmp_dir, 'script.py')
-            with open(script_path, 'w') as f:
-                f.write(code)
-            import importlib.util
-            spec = importlib.util.spec_from_file_location(fn_name, script_path)
-            module = importlib.util.module_from_spec(spec)
-            spec.loader.exec_module(module)
-            fn = getattr(module, fn_name)
-            return fn
+            @torch._jit_internal.weak_script_method
+            def forward(self, input):
+                return F.linear(input, self.weight, self.bias) + self.counter
 
-    #  Python AST Frontend , Python 3-style type annotations , Script function
-    @unittest.skipIf(not PY35, "Python 3.5 needed")
-    def test_annot_ast_py3_fn(self):
-        code = dedent('''
-            from typing import Tuple, List, Optional
-            from torch import Tensor
-            from torch.jit.annotations import BroadcastingList2, BroadcastingList3
-            import torch
-            @torch.jit.script
-            def foo(x : {input}, y : Tuple[Tensor, Tensor]) -> Tuple[{output}, {output}]:
-                return x, x
-        ''')
-        test_str = []
-        for pair in self.type_input_return_pairs():
-            fn = self._get_py3_code(self.format_code(code, pair), 'foo')
-            test_str.append(fn.__getattr__('forward').pretty_print_schema())
-        self.assertExpected("\n".join(test_str))
+        # Initialize a ScriptModule that uses the weak module above multiple times
+        class Strong(torch.jit.ScriptModule):
+            def __init__(self):
+                super(Strong, self).__init__()
+                self.fc1 = TestLinear(10, 10)
+                self.fc1.weight = torch.nn.Parameter(weights)
+                self.fc1.bias = torch.nn.Parameter(bias)
+                self.fc2 = TestLinear(10, 10)
+                self.fc2.weight = torch.nn.Parameter(weights2)
+                self.fc2.bias = torch.nn.Parameter(bias2)
 
-    #  Python AST Frontend , Python 3-style type annotations , Script method
-    @unittest.skipIf(not PY35, "Python 3.5 needed")
-    def test_annot_ast_py3_method(self):
-        code = dedent('''
-            from typing import Tuple, List, Optional
-            from torch import Tensor
-            from torch.jit.annotations import BroadcastingList2, \\
-                BroadcastingList3
-            import torch
-            class FooModule(torch.jit.ScriptModule):
-                @torch.jit.script_method
-                def foo(self, x : {input}, y : Tuple[Tensor, Tensor]) -> Tuple[{output}, {output}]:
-                    return x, x
-            instance = FooModule()
-        ''')
+            @torch.jit.script_method
+            def forward(self, x):
+                return x + self.fc1(x) + self.fc1(x) + self.fc2(x)
 
-        test_str = []
-        for pair in self.type_input_return_pairs():
-            fn = self._get_py3_code(self.format_code(code, pair), 'instance')
-            test_str.append(fn.__getattr__('foo').pretty_print_schema())
-        self.assertExpected("\n".join(test_str))
+        strong_mod = Strong()
 
-    #  Python AST Frontend , MyPy-style type comments , Script function
-    @unittest.skipIf(not PY35, "Python 3.5 needed")
-    def test_annot_ast_mypy_fn(self):
-        code = dedent('''
-            import torch
-            @torch.jit.script
-            def foo(x, y):
-                # type: ({input}, Tuple[Tensor, Tensor]) -> Tuple[{output}, {output}]
-                return x, x
-        ''')
+        # Run the same calculation as the ScriptModule above using plain nn.Linear layers
+        inp = torch.ones(10)
+        lin = torch.nn.Linear(10, 10)
+        lin.weight = torch.nn.Parameter(weights)
+        lin.bias = torch.nn.Parameter(bias)
+        lin2 = torch.nn.Linear(10, 10)
+        lin2.weight = torch.nn.Parameter(weights2)
+        lin2.bias = torch.nn.Parameter(bias2)
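+        # Strong.forward applies fc1 twice and fc2 once; each TestLinear call also adds its
+        # ones-initialized 'counter' buffer.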
+        expected_result = inp + (lin(inp) + torch.ones(10)) * 2 + lin2(inp) + torch.ones(10)
 
-        test_str = []
-        for pair in self.type_input_return_pairs():
-            fn = self._get_py3_code(self.format_code(code, pair), 'foo')
-            test_str.append(fn.__getattr__('forward').pretty_print_schema())
-        self.assertExpected("\n".join(test_str))
+        self.assertEqual(strong_mod(inp), expected_result)
+        self.assertExportImportModule(strong_mod, (inp,))
 
-    #  Python AST Frontend , MyPy-style type comments , Script method
-    @unittest.skipIf(not PY35, "Python 3.5 needed")
-    def test_annot_ast_mypy_method(self):
-        code = dedent('''
-            import torch
-            class FooModule(torch.jit.ScriptModule):
-                @torch.jit.script_method
-                def foo(self, x, y):
-                    # type: ({input}, Tuple[Tensor, Tensor]) -> Tuple[{output}, {output}]
-                    return x, x
-            instance = FooModule()
-        ''')
+    def test_weak_module_nested(self):
+        @torch._jit_internal.weak_module
+        class OtherWeak(torch.nn.Module):
+            __constants__ = ['constant']
 
-        test_str = []
-        for pair in self.type_input_return_pairs():
-            fn = self._get_py3_code(self.format_code(code, pair), 'instance')
-            test_str.append(fn.__getattr__('foo').pretty_print_schema())
-        self.assertExpected("\n".join(test_str))
+            def __init__(self, in_features, out_features):
+                super(OtherWeak, self).__init__()
+                self.in_features = in_features
+                self.out_features = out_features
+                self.weight = torch.nn.Parameter(torch.ones(out_features, in_features))
+                self.bias = torch.nn.Parameter(torch.ones(out_features))
+                self.constant = 3
 
-    def test_method_casts_script(self):
-        cast_types = [
-            'byte', 'char', 'double', 'float', 'int', 'long', 'short'
-        ]
+            @torch._jit_internal.weak_script_method
+            def forward(self, x):
+                return x * x + self.constant + F.linear(x, self.weight, self.bias)
 
-        for cast_type in cast_types:
-            cu = torch.jit.CompilationUnit('''
-            def cast_to(x):
-                return x.{cast_type}()
-            '''.format(cast_type=cast_type))
+        class OtherStrong(torch.jit.ScriptModule):
 
-            x = torch.rand(3, 4, 5) * 128
-            cu_result = cu.cast_to(x)
-            reference = getattr(x, cast_type)()
-            self.assertEqual(cu_result, reference)
+            def __init__(self):
+                super(OtherStrong, self).__init__()
 
-    def test_listconstruct_erasure(self):
-        class FooMod(torch.nn.Module):
+            @torch.jit.script_method
             def forward(self, x):
-                mask = x < 0.0
-                return x[mask]
+                return x + 27
 
-        import io
-        f = io.BytesIO()
-        self.assertExpected(torch.onnx.export_to_pretty_string(
-            FooMod(), (torch.rand(3, 4),), f,
-            operator_export_type=OperatorExportTypes.ONNX_ATEN_FALLBACK))
+        @torch._jit_internal.weak_module
+        class Weak(torch.nn.Module):
+            def __init__(self, in_features, out_features):
+                super(Weak, self).__init__()
+                self.in_features = in_features
+                self.out_features = out_features
+                self.weight = torch.nn.Parameter(2 * torch.ones(out_features, in_features))
+                self.bias = torch.nn.Parameter(2 * torch.ones(out_features))
+                self.weak_submodule = OtherWeak(10, 10)
+                self.strong_submodule = OtherStrong()
 
-    def test_trace_checker_arange_as_constant(self):
-        with self.assertRaisesRegex(torch.jit.TracingCheckError, r'Graphs differed across invocations!'):
-            @_trace(torch.rand(3, 4), check_inputs=[(torch.rand(4, 5),)])
-            def foo(x):
-                y = torch.arange(0, x.shape[0]).double()
-                return x + y.unsqueeze(1)
+            @torch._jit_internal.weak_script_method
+            def forward(self, x):
+                return x + self.weak_submodule(x) + self.strong_submodule(x) \
+                    + F.linear(x, self.weight, self.bias)
 
-    @suppress_warnings
-    def test_trace_checker_dot_data(self):
-        with self.assertRaisesRegex(torch.jit.TracingCheckError, r'Tensor-valued Constant nodes differed in value '
-                                                                 r'across invocations'):
-            @_trace(torch.rand(3, 4), check_inputs=[(torch.rand(3, 4),)])
-            def foo(x):
-                y = x.data
-                return x + y
+        class Strong(torch.jit.ScriptModule):
+            __constants__ = ['constant']
 
-    @suppress_warnings
-    def test_trace_checker_control_flow(self):
-        def foo(x):
-            for _ in range(x.size(0)):
-                x = torch.neg(x)
-            return x
+            def __init__(self):
+                super(Strong, self).__init__()
+                self.weak = Weak(10, 10)
 
-        with self.assertRaisesRegex(torch.jit.TracingCheckError, r'Graphs differed across invocations!'):
-            torch.jit.trace(foo, torch.randn(3, 4), check_inputs=[torch.randn(4, 4)])
+            @torch.jit.script_method
+            def forward(self, x):
+                return x + self.weak(x)
 
-    @suppress_warnings
-    def test_trace_checker_memoization(self):
-        with self.assertRaisesRegex(torch.jit.TracingCheckError, r'Graphs differed across invocations!'):
-            def foo(x):
-                if not hasattr(foo, 'cache'):
-                    foo.cache = torch.neg(x)
-                return x + foo.cache
+        strong_mod = Strong()
+        inp = torch.randn(10)
+        result = strong_mod(inp)
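+        # Strong returns inp + Weak(inp); Weak adds OtherWeak (x * x + 3 + F.linear with ones),
+        # OtherStrong (x + 27), and its own F.linear with weights and bias of 2 * ones.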
+        expected_result = inp + (inp + inp * inp + inp + 27) + 3 \
+            + F.linear(inp, torch.ones(10, 10), torch.ones(10)) \
+            + F.linear(inp, 2 * torch.ones(10, 10), 2 * torch.ones(10))
+        self.assertEqual(result, expected_result)
 
-            traced = torch.jit.trace(foo, torch.rand(3, 4), check_inputs=[(torch.rand(3, 4),)])
+    def test_weak_module_submodule(self):
+        @torch._jit_internal.weak_module
+        class Weak(torch.nn.Module):
+            def __init__(self):
+                super(Weak, self).__init__()
+                self.param = torch.nn.Parameter(100 * torch.ones(5))
 
-    def checkTracerWarning(self, *args, **kwargs):
-        with warnings.catch_warnings(record=True) as warns:
-            torch.jit.trace(*args, **kwargs)
-        self.assertGreater(len(warns), 0)
-        for warn in warns:
-            self.assertIn("cause the trace to be incorrect", str(warn.message))
+            @torch._jit_internal.weak_script_method
+            def forward(self, x):
+                return x + self.param
+
+        weak = Weak()
 
-    def test_trace_checker_slice_lhs(self):
-        def foo(x):
-            for i in range(3):
-                x[i, :] = torch.zeros(4)
-            return x
+        class OtherStrong(torch.jit.ScriptModule):
+            def __init__(self):
+                super(OtherStrong, self).__init__()
+                self.weak = weak
+                self.weak2 = weak
 
-        self.assertWarnsRegex(lambda: torch.jit.trace(foo, torch.rand(3, 4), check_inputs=[torch.rand(3, 4)]),
-                              'Output nr 1. of the traced function does not match the '
-                              'corresponding output of the Python function')
+            @torch.jit.script_method
+            def forward(self, x):
+                return x + self.weak(x)
 
-    def test_trace_checker_inplace_on_view(self):
-        def foo(x):
-            x.view(-1).add_(-x.view(-1))
-            return x
+        class Strong(torch.jit.ScriptModule):
+            def __init__(self):
+                super(Strong, self).__init__()
+                self.weak = Weak()
 
-        self.assertWarnsRegex(lambda: torch.jit.trace(foo,
-                                                      torch.rand(3, 4),
-                                                      check_inputs=[torch.rand(5, 6)],
-                                                      _force_outplace=True),
-                              'Output nr 1. of the traced function does not match the '
-                              'corresponding output of the Python function')
+            @torch.jit.script_method
+            def forward(self, x):
+                return self.weak(x) + weak(x)
 
-    def test_lhs_index_fails(self):
-        def foo(x):
-            x[0, 1] = 4
-            return x
-        self.checkTracerWarning(foo, torch.rand(3, 4), _force_outplace=True)
+        other_strong_mod = OtherStrong()
 
-    def test_lhs_index_trivial(self):
-        def foo(y, x):
-            y[...] = x
-            return y
-        self.checkTrace(foo, (torch.rand(3, 4), torch.rand(4)), inputs_require_grads=False)
+        self.assertIs(other_strong_mod.weak, other_strong_mod.weak2)
 
-    def test_inplace_warn(self):
-        def foo(x):
-            x.view(-1).add_(-x.view(-1))
-            return x
-        self.checkTracerWarning(foo, torch.rand(3, 4), _force_outplace=True)
+        with self.assertRaisesRegex(RuntimeError, "Attempted to inline a Module with param"):
+            strong_mod = Strong()
 
-    @suppress_warnings
-    def test_trace_checker_dropout_train(self):
-        def foo(x):
-            return torch.dropout(x, p=0.5, train=True)
+    def test_weak_module_copying(self):
+        class Submodule(torch.nn.Module):
+            def __init__(self):
+                super(Submodule, self).__init__()
 
-        self.assertWarnsRegex(lambda: torch.jit.trace(foo, torch.rand(3, 4), check_inputs=[torch.rand(5, 6)]),
-                              'Output nr 1. of the traced function does not match the '
-                              'corresponding output of the Python function')
-        self.assertWarnsRegex(lambda: torch.jit.trace(foo, torch.rand(3, 4), check_inputs=[torch.rand(5, 6)]),
-                              'Trace had nondeterministic nodes')
+            def forward(self, x):
+                return x + 100
 
-    def test_trace_checker_dropout_notrain(self):
-        input = torch.rand(3, 4)
+        @torch._jit_internal.weak_module
+        class Weak(torch.nn.Module):
+            def __init__(self, in_features, out_features):
+                super(Weak, self).__init__()
+                self.weight = torch.nn.Parameter(torch.ones(out_features, in_features))
+                self.bias = torch.nn.Parameter(torch.ones(out_features))
+                self.register_buffer("buffer", torch.ones(out_features))
+                self.submodule = Submodule()
 
-        @_trace(input)
-        def foo(x):
-            return torch.dropout(x, p=0.5, train=False)
+            @torch._jit_internal.weak_script_method
+            def forward(self, x):
+                return F.linear(x, self.weight, self.bias) \
+                    + self.buffer + self.submodule(x)
 
-        self.assertEqual(foo(input), input)
+        class Strong(torch.jit.ScriptModule):
+            def __init__(self, weak):
+                super(Strong, self).__init__()
+                self.weak = weak
 
-    def test_export_dynamic_slice(self):
-        class DynamicSliceExportMod(torch.jit.ScriptModule):
             @torch.jit.script_method
             def forward(self, x):
-                retval = x[0]
-                for i in range(x.size(1)):
-                    retval += torch.sum(x[0:i], dim=0)
-                return retval
+                return self.weak(x)
 
-        mod = DynamicSliceExportMod()
+        inp = torch.ones(5, 5) * 5
+        weak_mod = Weak(5, 5)
+        strong_mod = Strong(weak_mod)
 
-        input = torch.rand(3, 4, 5)
-        example_outs = mod(input)
+        self.assertTrue(isinstance(strong_mod.weak, torch.jit.ScriptModule))
+        self.assertFalse(isinstance(weak_mod, torch.jit.ScriptModule))
 
-        f = io.BytesIO()
-        exported = torch.onnx.export_to_pretty_string(
-            DynamicSliceExportMod(), (input,), f, example_outputs=example_outs)
-        self.assertExpected(exported)
+        self.assertIs(strong_mod.weak.weight, weak_mod.weight)
+        self.assertIs(strong_mod.weak.buffer, weak_mod.buffer)
+        self.assertIs(strong_mod.weak.submodule, weak_mod.submodule)
 
-    def test_string_frontend_elif(self):
-        code = '''
-            def elif_test(niter : int):
-                rv = 0
-                for i in range(niter):
-                    if i % 3 == 0 and i % 5 == 0:
-                        rv += 35
-                    elif i % 3 == 0:
-                        rv += 3
-                    elif i % 5 == 0:
-                        rv += 5
-                    else:
-                        rv += i
-                return rv
-        '''
+        # Test lookup fallback
+        weak_mod.new_attribute = 10
+        self.assertIs(strong_mod.weak.new_attribute, weak_mod.new_attribute)
 
-        self.checkScript(code, (101,), name='elif_test', outputs=3028)
+        weak_mod.weight.data += torch.ones(5, 5) * 100
+        self.assertTrue(strong_mod(inp).allclose(weak_mod(inp)))
 
-    def test_addmm_fusion(self):
-        class AddmmWrapper(torch.nn.Module):
-            def forward(self, x, y, c):
-                return torch.mm(x, y) + c
+        # Re-assignment is not tracked
+        weak_mod.weight = torch.nn.Parameter(torch.ones(5, 5) * 100)
+        self.assertFalse(strong_mod(inp).allclose(weak_mod(inp)))
 
-        # Test addmm fusion is disabled for normal Jit
-        x, y, c = torch.rand(3, 4), torch.rand(4, 5), torch.rand(3, 5)
-        f = io.BytesIO()
-        pretty = torch.onnx.export_to_pretty_string(AddmmWrapper(), (x, y, c), f)
-        self.assertExpected(pretty, 'onnx')
+    def test_backend_cudnn_enabled(self):
+        # Only test that this compiles
+        @torch.jit.script
+        def fn(x):
+            if torch.backends.cudnn.enabled:
+                x = x + 2
+            else:
+                x = x + 3
+            return x
 
-        jit_trace = torch.jit.trace(AddmmWrapper(), (x, y, c))
-        ge_graph = jit_trace.__getattr__('forward').graph_for(x, y, c)
-        self.assertExpectedGraph(ge_graph, 'jit')
+    def test_inplace_add(self):
 
-    def test_pyop_exception_message(self):
-        class Foo(torch.jit.ScriptModule):
-            def __init__(self):
-                super(Foo, self).__init__()
-                self.conv = nn.Conv2d(1, 10, kernel_size=5)
+        def foo(a, b):
+            c = a + b
+            c.add_(b)
+            return c
+        self.checkScript(foo, (torch.rand(3), torch.rand(3)))
 
-            @torch.jit.script_method
-            def forward(self, x):
-                return self.conv(x)
-        foo = Foo()
-        # testing that the correct error message propagates
-        with self.assertRaisesRegex(RuntimeError, "Expected 4-dimensional input for 4-dimensional weight"):
-            foo(torch.ones([123]))  # wrong size
+    def test_add_out(self):
+        def foo(a, b):
+            c = a + b
+            e = 2 * a
+            torch.add(c, b, out=e)
+            return e
+        self.checkScript(foo, (torch.rand(3), torch.rand(3)))
 
-    def test_exceptions(self):
-        cu = torch.jit.CompilationUnit('''
-            def foo(cond):
-                if bool(cond):
-                    raise ValueError(3)
-                return 1
-        ''')
+    def test_augmented_assign(self):
+        def foo(a, b):
+            a += b
+            a -= b
+            a /= b
+            a *= b
+            return a, b
+        self.checkScript(foo, (torch.rand(3), torch.rand(3)), check_expected=True)
 
-        cu.foo(torch.tensor(0))
-        with self.assertRaisesRegex(torch.jit.Error, "Exception"):
-            cu.foo(torch.tensor(1))
+    def test_pass(self):
+        def foo(x):
+            # type: (bool) -> int
+            for _i in range(3):
+                pass
+            if x:
+                pass
+            else:
+                pass
+            return 3
 
-        @torch.jit.script
-        def foo(cond):
-            a = 3
-            if bool(cond):
-                raise ArbitraryError(a, "hi")
-                if False:
-                    raise ArbitraryError
-            return a
+        self.checkScript(foo, (True,))
 
-        foo(torch.tensor(0))
-        # we don't currently validate the name of the exception
-        with self.assertRaisesRegex(torch.jit.Error, "Exception"):
-            foo(torch.tensor(1))
+    def test_optional_conversion(self):
+        @torch.jit.script
+        def other_fn(x=None):
+            # type: (Optional[int]) -> int
+            return torch.jit._unwrap_optional(x)
 
         @torch.jit.script
-        def foo_except_used():
-            a = Exception()
-            print(a)
-            raise a
+        def fn(x):
+            # type: (int) -> int
+            return other_fn(x)
 
-        # a not DCEd
-        with self.assertRaisesRegex(RuntimeError, "expected value of type Tensor"):
-            foo_except_used()
+        self.assertEqual(fn(2), 2)
 
-        # We don't validate the expr following raise
         @torch.jit.script
-        def foo():
-            raise 3 + 4
-
-        # no control flow analysis yet
-        with self.assertRaisesRegex(RuntimeError, "undefined value a"):
-            @torch.jit.script
-            def foo():
-                if True:
-                    a = 1
-                else:
-                    raise Exception("Hi")
-                return a
+        def unify_to_optional(x):
+            # type: (bool) -> Optional[int]
+            if x:
+                a = None
+            else:
+                a = 2
+            return a
 
-    def test_assertions(self):
-        cu = torch.jit.CompilationUnit('''
-            def foo(cond):
-                assert bool(cond), "hi"
-                return 0
-        ''')
+        self.assertEqual(unify_to_optional(True), None)
+        self.assertEqual(unify_to_optional(False), 2)
 
-        cu.foo(torch.tensor(1))
-        with self.assertRaisesRegex(torch.jit.Error, "Exception"):
-            cu.foo(torch.tensor(0))
+        @torch.jit.script
+        def opt_list(x):
+            # type: (Optional[List[float]]) -> int
+            return 2
 
         @torch.jit.script
-        def foo(cond):
-            assert bool(cond), "hi"
+        def broadcast_opt_list(x):
+            # type: (Optional[BroadcastingList2[float]]) -> int
+            return 2
 
-        foo(torch.tensor(1))
-        # we don't currently validate the name of the exception
-        with self.assertRaisesRegex(torch.jit.Error, "Exception"):
-            foo(torch.tensor(0))
+        @torch.jit.script
+        def opt_list_tuple_caller(x):
+            # type: (Tuple[float, float]) -> int
+            return opt_list(x) + broadcast_opt_list(x)
 
-    def test_weak_script_function(self):
-        outer_var = 10
-        outer_var2 = 11
+        self.assertEqual(opt_list_tuple_caller((2., 3.)), 4)
 
-        def not_a_script_fn(x):
-            return x + 2
+    def test_lhs_indexing(self):
+        def foo(a, b):
+            a = a.clone()
+            a[0] = b
+            return a
+        self.checkScript(foo, (torch.rand(2, 3), torch.rand(3)))
 
-        @torch.jit.script
-        def even_more_inner(x):
-            return x + 1
+    def test_lhs_advanced_indexing_assignment(self):
+        def foo(x, y):
+            a = torch.exp(x)
+            b = x == 1
+            a[b] = y[b]
+            return a
+        self.checkScript(foo, (torch.ones(4, 3), torch.ones(4, 3)))
 
-        @torch.jit.script
-        def inner(x):
-            return not_a_script_fn(x) + x + even_more_inner(x)
+    def test_lhs_advanced_indexing_augmented_assignment(self):
+        def foo(x, y):
+            a = torch.exp(x)
+            b = x == 1
+            a[b] += y[b]
+            return a
+        self.checkScript(foo, (torch.ones(4, 3), torch.ones(4, 3)))
 
-        @torch.jit.script
-        def strong_script_fn(x):
-            if bool(x.norm() > 2):
-                x = x + 3
-            return x + 4 + inner(x)
+    def test_lhs_indexing_list(self):
+        def foo(a, b):
+            ls = [a]
+            ls[0] = b
+            return ls
+        self.checkScript(foo, (torch.rand(2, 3), torch.rand(3)))
 
-        @torch._jit_internal.weak_script
-        def weak_script_fn_inner(x):
-            return x + 6 + not_a_script_fn(x)
+    def test_inplace_copy_script(self):
+        def foo(x):
+            a = torch.rand(3, 4)
+            a.copy_(x)
+            return a
+        self.checkScript(foo, (torch.rand(3, 4),))
 
-        @torch._jit_internal.weak_script
-        def weak_script_fn(x):
-            return x + 5 + weak_script_fn_inner(x) + weak_script_fn_inner(x)
+    def test_lhs_indexing_increment(self):
+        def foo(a, b):
+            a[0] += b
+            return a
+        self.checkScript(foo, (torch.rand(2, 3), torch.rand(3)))
 
-        def fn(x):
-            x = not_a_script_fn(x)
-            x = strong_script_fn(x)
-            return weak_script_fn(x)
+    def test_lhs_indexing_increment_list(self):
+        def foo(a, b):
+            a = a.clone()
+            ls = [a, b]
+            ls[0] += b
+            return ls
+        self.checkScript(foo, (torch.rand(2, 3), torch.rand(3)))
 
-        input = torch.randn(3, 4, 5)
-        self.checkScript(fn, (input,))
+    def test_lhs_indexing_increment_list_prim(self):
+        def foo():
+            ls = [1, 2, 3]
+            ls[0] += 5
+            return ls
+        self.checkScript(foo, ())
 
-    def test_python_op_exception(self):
-        def python_op(x):
-            raise Exception("bad!")
+    def test_lhs_indexing_multi(self):
+        def foo(a, b):
+            a = a.clone()
+            foo, a[0], bar = (1, b, 3)
+            return foo, a, bar
+        self.checkScript(foo, (torch.rand(2, 3), torch.rand(3)))
 
-        @torch.jit.script
-        def fn(x):
-            return python_op(x)
+    def test_bool_dispatch(self):
+        with self.disableModuleHook():  # TODO: Python print broadcasting list
+            def kwarg_false(x):
+                # type: (Tensor) -> Tensor
+                return F.max_pool1d(x, 1, 1, return_indices=False)
+            self.checkScript(kwarg_false, (torch.randn(3, 3, 3),))
 
-        with self.assertRaisesRegex(RuntimeError, "operation failed in interpreter"):
-            fn(torch.tensor(4))
+            def kwarg_true(x):
+                # type: (Tensor) -> Tuple[Tensor, Tensor]
+                return F.max_pool1d(x, 1, 1, return_indices=True)
+            self.checkScript(kwarg_true, (torch.randn(3, 3, 3),))
 
-    def test_trace_contiguous(self):
-        def foo(x):
-            return x[:, :, ::2].contiguous().view(12)
+            def full_kwarg_false(x):
+                # type: (Tensor) -> Tensor
+                return F.max_pool1d(x, 1, 1, ceil_mode=False, return_indices=False)
+            self.checkScript(full_kwarg_false, (torch.randn(3, 3, 3),))
 
-        x = torch.rand(2, 3, 4)
-        traced = torch.jit.trace(foo, (x,))
-        y = traced(x)
-        self.assertNotEqual(x.storage().data_ptr(), y.storage().data_ptr())
+            def full_kwarg_true(x):
+                # type: (Tensor) -> Tuple[Tensor, Tensor]
+                return F.max_pool1d(x, 1, 1, ceil_mode=False, return_indices=True)
+            self.checkScript(full_kwarg_true, (torch.randn(3, 3, 3),))
 
-    # This tests the logic in THPVariable_contiguous. There is short-circuiting
-    # code that prevents us from even getting to VariableType::contiguous, since
-    # it is an optimization that prevents us from acquiring the GIL for touching
-    # the device. We needed to add the tracing logic directly into the
-    # THPVariable_contiguous function only for the path where we are skipping
-    # dispatch into contiguous. We should see an aten::contiguous in this trace!
-    def test_trace_contiguous_short_circuit(self):
-        def foo(x):
-            return x.contiguous()
+            def use_default(x):
+                # type: (Tensor) -> Tensor
+                return F.max_pool1d(x, 1, 1)
+            self.checkScript(use_default, (torch.randn(3, 3, 3),))
 
-        x = torch.rand(2, 3, 4)
-        traced = torch.jit.trace(foo, (x,))
-        self.assertExpectedGraph(traced.graph)
+            def arg_false(x):
+                # type: (Tensor) -> Tensor
+                return F.max_pool1d(x, 1, 1, 0, 1, False, False)
+            self.checkScript(arg_false, (torch.randn(3, 3, 3),))
 
-    def test_weak_module(self):
+            def arg_true(x):
+                # type: (Tensor) -> Tuple[Tensor, Tensor]
+                return F.max_pool1d(x, 1, 1, 0, 1, False, True)
+            self.checkScript(arg_true, (torch.randn(3, 3, 3),))
 
-        @torch._jit_internal.weak_module
-        class Weak(torch.nn.Module):
-            __constants__ = ['number']
+    def test_infer_size(self):
+        from torch._C import _infer_size
 
-            def __init__(self):
-                super(Weak, self).__init__()
-                self.number = 199
+        def fn(x, y):
+            # type: (Tensor, Tensor) -> List[int]
+            return _infer_size(x.size(), y.size())
 
-            def python_op_in_weak_module(self, x):
-                return x + 123
+        self.checkScript(fn, (torch.ones(2, 4, 2), torch.ones(2, 4, 2)))
 
-            @torch._jit_internal.weak_script_method
-            def forward(self, x):
-                return 55 + self.number + self.python_op_in_weak_module(x)
+    def test_mutable_dce(self):
+        @torch.jit.script
+        def foo():
+            a = torch.rand(2, 3)
+            a += torch.rand(2, 3)
+            b = torch.rand(2, 3)
+            b += torch.rand(2, 3)
+            # b should be cleaned up but not a
+            return a
 
-        class OtherStrong(torch.jit.ScriptModule):
-            __constants__ = ['number']
+        self.assertExpectedGraph(foo.graph)
 
-            def __init__(self):
-                super(OtherStrong, self).__init__()
-                self.number = 357
+    def test_mutable_dce_block(self):
+        @torch.jit.script
+        def foo():
+            a = torch.rand(2, 3)
+            a += torch.rand(2, 3)
+            b = torch.rand(2, 3)
+            if bool(a > torch.zeros(2, 3)):
+                b += torch.rand(2, 3)
+                a += torch.rand(2, 3)
+            # a should be cleaned up but not b
+            return b
 
-            def python_op_in_strong_module(self, x):
-                return x + 456
+        self.assertExpectedGraph(foo.graph)
 
-            @torch.jit.script_method
-            def forward(self, x):
-                return x + self.number + self.python_op_in_strong_module(x)
+    def test_mutable_dce_graph_input(self):
+        @torch.jit.script
+        def foo(a):
+            a += torch.rand(2, 3)
+            # shouldn't clean up `a` even though it's not used in the output
 
-        class Passthrough(torch.jit.ScriptModule):
-            def __init__(self):
-                super(Passthrough, self).__init__()
-                self.weak = Weak()
+        self.assertExpectedGraph(foo.graph)
+
+    def test_mutable_dce_list(self):
+        @torch.jit.script
+        def foo(a):
+            l = []
+            l.append(a)
+            c = l[0]
+            b = torch.rand(2, 3)
+            c += torch.rand(2, 3)
+            return b
+
+        self.assertExpectedGraph(foo.graph)
+
+    def test_mutable_dce_loop(self):
+        @torch.jit.script
+        def foo(a):
+            l = []
+            l.append(a)
+            i = 0
+            b = torch.rand(2, 3)
+            while i < 1:
+                dead = torch.rand(2, 3)
+                c = l[0]
+                c += torch.rand(2, 3)
+                i += 1
+            return b
 
-            @torch.jit.script_method
-            def forward(self, x):
-                return self.weak(x)
+        self.assertExpectedGraph(foo.graph)
 
-        weak_mod = Weak()
-        x = torch.ones(1)
-        expected_result = 55 + 199 + (x + 123)
 
-        # Ensure weak mod is running without the JIT by passing the wrong type
-        # (i.e. not a tensor)
-        weak_mod(2)
+class MnistNet(nn.Module):
+    def __init__(self):
+        super(MnistNet, self).__init__()
+        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
+        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
+        self.conv2_drop = nn.Dropout2d()
+        self.fc1 = nn.Linear(320, 50)
+        self.fc2 = nn.Linear(50, 10)
 
-        python_result = weak_mod(x)
-        strong_mod = Passthrough()
-        script_result = strong_mod(x)
+    def forward(self, x):
+        x = F.relu(F.max_pool2d(self.conv1(x), 2))
+        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
+        x = x.view(-1, 320)
+        x = F.relu(self.fc1(x))
+        x = F.dropout(x, training=self.training)
+        x = self.fc2(x)
+        return F.log_softmax(x, dim=1)
 
-        self.assertEqual(python_result, expected_result)
-        self.assertEqual(script_result, expected_result)
 
-        class Strong(torch.jit.ScriptModule):
-            def __init__(self):
-                super(Strong, self).__init__()
-                self.weak = Weak()
-                self.strong = OtherStrong()
+class TestEndToEndHybridFrontendModels(JitTestCase):
+    @staticmethod
+    def _test_dcgan_models(self, device, check_export_import=True):
+        class DCGANGenerator(nn.Module):
+            def __init__(self, nz, ngf, nc):
+                super(DCGANGenerator, self).__init__()
+                self.main = nn.Sequential(
+                    # input is Z, going into a convolution
+                    nn.ConvTranspose2d(nz, ngf * 8, 4, 1, 0, bias=False),
+                    nn.BatchNorm2d(ngf * 8),
+                    nn.ReLU(True),
+                    # state size. (ngf*8) x 4 x 4
+                    nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False),
+                    nn.BatchNorm2d(ngf * 4),
+                    nn.ReLU(True),
+                    # state size. (ngf*4) x 8 x 8
+                    nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=False),
+                    nn.BatchNorm2d(ngf * 2),
+                    nn.ReLU(True),
+                    # state size. (ngf*2) x 16 x 16
+                    nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 1, bias=False),
+                    nn.BatchNorm2d(ngf),
+                    nn.ReLU(True),
+                    # state size. (ngf) x 32 x 32
+                    nn.ConvTranspose2d(ngf, nc, 4, 2, 1, bias=False),
+                    nn.Tanh()
+                    # state size. (nc) x 64 x 64
+                )
 
-            @torch.jit.script_method
-            def forward(self, x):
-                y = 2 * x
-                return y + 1 + self.weak(y) + self.strong(y)
+            def forward(self, input):
+                return self.main(input)
 
-        strong_mod = Strong()
-        strong_mod2 = Strong()
-        x = torch.ones(1)
-        expected_result = (x * 2) + 1 + (55 + 199 + x * 2 + 123) + (x * 2 + 357 + x * 2 + 456)
-        script_result = strong_mod(x)
-        script_result2 = strong_mod2(x)
-        self.assertEqual(script_result, expected_result)
-        self.assertEqual(script_result, script_result2)
+        class DCGANDiscriminator(nn.Module):
+            def __init__(self, nc, ndf):
+                super(DCGANDiscriminator, self).__init__()
+                self.main = nn.Sequential(
+                    # input is (nc) x 64 x 64
+                    nn.Conv2d(nc, ndf, 4, 2, 1, bias=False),
+                    nn.LeakyReLU(0.2, inplace=True),
+                    # state size. (ndf) x 32 x 32
+                    nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),
+                    nn.BatchNorm2d(ndf * 2),
+                    nn.LeakyReLU(0.2, inplace=True),
+                    # state size. (ndf*2) x 16 x 16
+                    nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),
+                    nn.BatchNorm2d(ndf * 4),
+                    nn.LeakyReLU(0.2, inplace=True),
+                    # state size. (ndf*4) x 8 x 8
+                    nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False),
+                    nn.BatchNorm2d(ndf * 8),
+                    nn.LeakyReLU(0.2, inplace=True),
+                    # state size. (ndf*8) x 4 x 4
+                    nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False),
+                    nn.Sigmoid()
+                )
 
-    def test_weak_module_parameters_and_buffers(self):
-        weights = torch.randn(10, 10)
-        bias = torch.randn(10)
-        weights2 = torch.randn(10, 10)
-        bias2 = torch.randn(10)
+            def forward(self, input):
+                return self.main(input).view(-1, 1).squeeze(1)
 
-        @torch._jit_internal.weak_module
-        class TestLinear(torch.nn.Module):
-            def __init__(self, in_features, out_features):
-                super(TestLinear, self).__init__()
-                self.in_features = in_features
-                self.out_features = out_features
-                self.weight = torch.nn.Parameter(torch.Tensor(out_features, in_features))
-                self.bias = torch.nn.Parameter(torch.Tensor(out_features))
-                self.register_buffer('counter', torch.ones(out_features))
-                self.reset_parameters()
+        bs, nz, ngf, nc, ndf = 5, 6, 9, 3, 10
+        self.checkTrace(DCGANGenerator(nz, ngf, nc).to(device),
+                        (torch.rand(bs, nz, 1, 1, device=device),),
+                        export_import=check_export_import)
+        example_input = DCGANGenerator(nz, ngf, nc).to(device)(torch.rand(bs, nz, 1, 1, device=device))
+        self.checkTrace(DCGANDiscriminator(nc, ndf).to(device), (example_input,),
+                        export_import=check_export_import)
 
-            def reset_parameters(self):
-                torch.nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
-                if self.bias is not None:
-                    fan_in, _ = torch.nn.init._calculate_fan_in_and_fan_out(self.weight)
-                    bound = 1 / math.sqrt(fan_in)
-                    torch.nn.init.uniform_(self.bias, -bound, bound)
+    def test_dcgan_models(self):
+        self._test_dcgan_models(self, device='cpu')
 
-            @torch._jit_internal.weak_script_method
-            def forward(self, input):
-                return F.linear(input, self.weight, self.bias) + self.counter
+    @unittest.skipIf(not RUN_CUDA, "no CUDA")
+    @skipIfRocm
+    def test_dcgan_models_cuda(self):
+        # XXX: export_import on CUDA modules doesn't work (#11480)
+        self._test_dcgan_models(self, device='cuda', check_export_import=False)
 
-        # Initialize a ScriptModule that uses the weak module above multiple times
-        class Strong(torch.jit.ScriptModule):
+    @staticmethod
+    def _test_neural_style(self, device, check_export_import=True):
+        class TransformerNet(torch.nn.Module):
             def __init__(self):
-                super(Strong, self).__init__()
-                self.fc1 = TestLinear(10, 10)
-                self.fc1.weight = torch.nn.Parameter(weights)
-                self.fc1.bias = torch.nn.Parameter(bias)
-                self.fc2 = TestLinear(10, 10)
-                self.fc2.weight = torch.nn.Parameter(weights2)
-                self.fc2.bias = torch.nn.Parameter(bias2)
+                super(TransformerNet, self).__init__()
+                # Initial convolution layers
+                self.conv1 = ConvLayer(3, 32, kernel_size=9, stride=1)
+                self.in1 = torch.nn.InstanceNorm2d(32, affine=True)
+                self.conv2 = ConvLayer(32, 64, kernel_size=3, stride=2)
+                self.in2 = torch.nn.InstanceNorm2d(64, affine=True)
+                self.conv3 = ConvLayer(64, 128, kernel_size=3, stride=2)
+                self.in3 = torch.nn.InstanceNorm2d(128, affine=True)
+                # Residual layers
+                self.res1 = ResidualBlock(128)
+                self.res2 = ResidualBlock(128)
+                self.res3 = ResidualBlock(128)
+                self.res4 = ResidualBlock(128)
+                self.res5 = ResidualBlock(128)
+                # Upsampling Layers
+                self.deconv1 = UpsampleConvLayer(128, 64, kernel_size=3, stride=1, upsample=2)
+                self.in4 = torch.nn.InstanceNorm2d(64, affine=True)
+                self.deconv2 = UpsampleConvLayer(64, 32, kernel_size=3, stride=1, upsample=2)
+                self.in5 = torch.nn.InstanceNorm2d(32, affine=True)
+                self.deconv3 = ConvLayer(32, 3, kernel_size=9, stride=1)
+                # Non-linearities
+                self.relu = torch.nn.ReLU()
+
+            def forward(self, X):
+                y = self.relu(self.in1(self.conv1(X)))
+                y = self.relu(self.in2(self.conv2(y)))
+                y = self.relu(self.in3(self.conv3(y)))
+                y = self.res1(y)
+                y = self.res2(y)
+                y = self.res3(y)
+                y = self.res4(y)
+                y = self.res5(y)
+                y = self.relu(self.in4(self.deconv1(y)))
+                y = self.relu(self.in5(self.deconv2(y)))
+                y = self.deconv3(y)
+                return y
+
+        class ConvLayer(torch.nn.Module):
+            def __init__(self, in_channels, out_channels, kernel_size, stride):
+                super(ConvLayer, self).__init__()
+                reflection_padding = kernel_size // 2
+                self.reflection_pad = torch.nn.ReflectionPad2d(reflection_padding)
+                self.conv2d = torch.nn.Conv2d(in_channels, out_channels, kernel_size, stride)
 
-            @torch.jit.script_method
             def forward(self, x):
-                return x + self.fc1(x) + self.fc1(x) + self.fc2(x)
+                out = self.reflection_pad(x)
+                out = self.conv2d(out)
+                return out
 
-        strong_mod = Strong()
+        class ResidualBlock(torch.nn.Module):
+            """ResidualBlock
+            introduced in: https://arxiv.org/abs/1512.03385
+            recommended architecture: http://torch.ch/blog/2016/02/04/resnets.html
+            """
 
-        # Run same calculation as module
-        inp = torch.ones(10)
-        lin = torch.nn.Linear(10, 10)
-        lin.weight = torch.nn.Parameter(weights)
-        lin.bias = torch.nn.Parameter(bias)
-        lin2 = torch.nn.Linear(10, 10)
-        lin2.weight = torch.nn.Parameter(weights2)
-        lin2.bias = torch.nn.Parameter(bias2)
-        expected_result = inp + (lin(inp) + torch.ones(10)) * 2 + lin2(inp) + torch.ones(10)
+            def __init__(self, channels):
+                super(ResidualBlock, self).__init__()
+                self.conv1 = ConvLayer(channels, channels, kernel_size=3, stride=1)
+                self.in1 = torch.nn.InstanceNorm2d(channels, affine=True)
+                self.conv2 = ConvLayer(channels, channels, kernel_size=3, stride=1)
+                self.in2 = torch.nn.InstanceNorm2d(channels, affine=True)
+                self.relu = torch.nn.ReLU()
 
-        self.assertEqual(strong_mod(inp), expected_result)
-        self.assertExportImportModule(strong_mod, (inp,))
+            def forward(self, x):
+                residual = x
+                out = self.relu(self.in1(self.conv1(x)))
+                out = self.in2(self.conv2(out))
+                out = out + residual
+                return out
 
-    def test_weak_module_nested(self):
-        @torch._jit_internal.weak_module
-        class OtherWeak(torch.nn.Module):
-            __constants__ = ['constant']
+        class UpsampleConvLayer(torch.nn.Module):
+            """UpsampleConvLayer
+            Upsamples the input and then does a convolution. This method gives better results
+            compared to ConvTranspose2d.
+            ref: http://distill.pub/2016/deconv-checkerboard/
+            """
 
-            def __init__(self, in_features, out_features):
-                super(OtherWeak, self).__init__()
-                self.in_features = in_features
-                self.out_features = out_features
-                self.weight = torch.nn.Parameter(torch.ones(out_features, in_features))
-                self.bias = torch.nn.Parameter(torch.ones(out_features))
-                self.constant = 3
+            def __init__(self, in_channels, out_channels, kernel_size, stride, upsample=None):
+                super(UpsampleConvLayer, self).__init__()
+                self.upsample = upsample
+                if upsample:
+                    self.upsample_layer = torch.nn.Upsample(mode='nearest', scale_factor=upsample)
+                reflection_padding = kernel_size // 2
+                self.reflection_pad = torch.nn.ReflectionPad2d(reflection_padding)
+                self.conv2d = torch.nn.Conv2d(in_channels, out_channels, kernel_size, stride)
 
-            @torch._jit_internal.weak_script_method
             def forward(self, x):
-                return x * x + self.constant + F.linear(x, self.weight, self.bias)
+                x_in = x
+                if self.upsample:
+                    x_in = self.upsample_layer(x_in)
+                out = self.reflection_pad(x_in)
+                out = self.conv2d(out)
+                return out
 
-        class OtherStrong(torch.jit.ScriptModule):
+        self.checkTrace(TransformerNet(), (torch.rand(5, 3, 64, 64),), export_import=check_export_import)
 
-            def __init__(self):
-                super(OtherStrong, self).__init__()
+    def test_neural_style(self):
+        self._test_neural_style(self, device='cpu')
 
-            @torch.jit.script_method
-            def forward(self, x):
-                return x + 27
+    @unittest.skipIf(not RUN_CUDA, "no CUDA")
+    def test_neural_style_cuda(self):
+        # XXX: export_import on CUDA modules doesn't work (#11480)
+        self._test_neural_style(self, device='cuda', check_export_import=False)
 
-        @torch._jit_internal.weak_module
-        class Weak(torch.nn.Module):
-            def __init__(self, in_features, out_features):
-                super(Weak, self).__init__()
-                self.in_features = in_features
-                self.out_features = out_features
-                self.weight = torch.nn.Parameter(2 * torch.ones(out_features, in_features))
-                self.bias = torch.nn.Parameter(2 * torch.ones(out_features))
-                self.weak_submodule = OtherWeak(10, 10)
-                self.strong_submodule = OtherStrong()
+    @staticmethod
+    def _test_mnist(self, device, check_export_import=True):
+        # eval() is present because dropout makes this nondeterministic
+        self.checkTrace(MnistNet().to(device).eval(), (torch.rand(5, 1, 28, 28, device=device),),
+                        export_import=check_export_import)
 
-            @torch._jit_internal.weak_script_method
-            def forward(self, x):
-                return x + self.weak_submodule(x) + self.strong_submodule(x) \
-                    + F.linear(x, self.weight, self.bias)
+    def test_mnist(self):
+        self._test_mnist(self, device='cpu')
 
-        class Strong(torch.jit.ScriptModule):
-            __constants__ = ['constant']
+    @unittest.skipIf(not RUN_CUDA, "no CUDA")
+    @skipIfRocm
+    def test_mnist_cuda(self):
+        # XXX: export_import on CUDA modules doesn't work (#11480)
+        self._test_mnist(self, device='cuda', check_export_import=False)
 
-            def __init__(self):
-                super(Strong, self).__init__()
-                self.weak = Weak(10, 10)
+    @unittest.skipIf(not RUN_CUDA, "no CUDA")
+    @skipIfRocm
+    def test_mnist_training_leaks_no_memory_cuda(self):
+        net = MnistNet().cuda()
+        # MnistNet uses dropout, so don't check its trace
+        traced_net = torch.jit.trace(net, [torch.randn(5, 1, 28, 28, device='cuda')],
+                                     check_trace=False)
 
-            @torch.jit.script_method
-            def forward(self, x):
-                return x + self.weak(x)
+        def train(iters):
+            for _ in range(iters):
+                # Get some fake data
+                inp = torch.randn(5, 1, 28, 28, device='cuda')
+                out = traced_net(inp)
 
-        strong_mod = Strong()
-        inp = torch.randn(10)
-        result = strong_mod(inp)
-        expected_result = inp + (inp + inp * inp + inp + 27) + 3 \
-            + F.linear(inp, torch.ones(10, 10), torch.ones(10)) \
-            + F.linear(inp, 2 * torch.ones(10, 10), 2 * torch.ones(10))
-        self.assertEqual(result, expected_result)
+                # Here's some fake loss
+                out.sum().backward()
 
-    def test_weak_module_submodule(self):
-        @torch._jit_internal.weak_module
-        class Weak(torch.nn.Module):
-            def __init__(self):
-                super(Weak, self).__init__()
-                self.param = torch.nn.Parameter(100 * torch.ones(5))
+                # Zero out grads
+                traced_net.zero_grad()
 
-            @torch._jit_internal.weak_script_method
-            def forward(self, x):
-                return x + self.param
+        # Set it up so the params have .grad fields and are not reported as leaks
+        train(1)
 
-        weak = Weak()
+        with self.assertLeaksNoCudaTensors():
+            train(5)
 
-        class OtherStrong(torch.jit.ScriptModule):
+    @staticmethod
+    def _test_reinforcement_learning(self, device, test_export_import=True):
+        class Policy(nn.Module):
             def __init__(self):
-                super(OtherStrong, self).__init__()
-                self.weak = weak
-                self.weak2 = weak
+                super(Policy, self).__init__()
+                self.affine1 = nn.Linear(4, 128)
+                self.affine2 = nn.Linear(128, 2)
 
-            @torch.jit.script_method
             def forward(self, x):
-                return x + self.weak(x)
+                x = F.relu(self.affine1(x))
+                action_scores = self.affine2(x)
+                return F.softmax(action_scores, dim=1)
 
-        class Strong(torch.jit.ScriptModule):
-            def __init__(self):
-                super(Strong, self).__init__()
-                self.weak = Weak()
+        self.checkTrace(Policy().to(device), (torch.rand(1, 4, device=device),),
+                        export_import=test_export_import)
 
-            @torch.jit.script_method
-            def forward(self, x):
-                return self.weak(x) + weak(x)
+    def test_reinforcement_learning(self):
+        self._test_reinforcement_learning(self, device='cpu')
 
-        other_strong_mod = OtherStrong()
+    @unittest.skipIf(not RUN_CUDA, "no CUDA")
+    def test_reinforcement_learning_cuda(self):
+        # XXX: export_import on CUDA modules doesn't work (#11480)
+        self._test_reinforcement_learning(self, device='cuda', test_export_import=False)
 
-        self.assertIs(other_strong_mod.weak, other_strong_mod.weak2)
+    @staticmethod
+    def _test_snli(self, device, check_export_import=True):
+        class Bottle(nn.Module):
 
-        with self.assertRaisesRegex(RuntimeError, "Attempted to inline a Module with param"):
-            strong_mod = Strong()
+            def forward(self, input):
+                if len(input.size()) <= 2:
+                    return super(Bottle, self).forward(input)
+                size = input.size()[:2]
+                out = super(Bottle, self).forward(input.view(size[0] * size[1], -1))
+                return out.view(size[0], size[1], -1)
 
-    def test_weak_module_copying(self):
-        class Submodule(torch.nn.Module):
-            def __init__(self):
-                super(Submodule, self).__init__()
+        class Linear(Bottle, nn.Linear):
+            pass
 
-            def forward(self, x):
-                return x + 100
+        class Encoder(nn.Module):
 
-        @torch._jit_internal.weak_module
-        class Weak(torch.nn.Module):
-            def __init__(self, in_features, out_features):
-                super(Weak, self).__init__()
-                self.weight = torch.nn.Parameter(torch.ones(out_features, in_features))
-                self.bias = torch.nn.Parameter(torch.ones(out_features))
-                self.register_buffer("buffer", torch.ones(out_features))
-                self.submodule = Submodule()
+            def __init__(self, config):
+                super(Encoder, self).__init__()
+                self.config = config
+                input_size = config.d_proj if config.projection else config.d_embed
+                dropout = 0 if config.n_layers == 1 else config.dp_ratio
+                self.rnn = nn.LSTM(input_size=input_size, hidden_size=config.d_hidden,
+                                   num_layers=config.n_layers, dropout=dropout,
+                                   bidirectional=config.birnn)
 
-            @torch._jit_internal.weak_script_method
-            def forward(self, x):
-                return F.linear(x, self.weight, self.bias) \
-                    + self.buffer + self.submodule(x)
+            def forward(self, inputs):
+                batch_size = inputs.size()[1]
+                state_shape = self.config.n_cells, batch_size, self.config.d_hidden
+                h0 = c0 = inputs.new_zeros(state_shape)
+                outputs, (ht, ct) = self.rnn(inputs, (h0, c0))
+                return ht[-1] if not self.config.birnn else ht[-2:].transpose(0, 1).contiguous().view(batch_size, -1)
 
-        class Strong(torch.jit.ScriptModule):
-            def __init__(self, weak):
-                super(Strong, self).__init__()
-                self.weak = weak
+        class SNLIClassifier(nn.Module):
 
-            @torch.jit.script_method
-            def forward(self, x):
-                return self.weak(x)
+            def __init__(self, config):
+                super(SNLIClassifier, self).__init__()
+                self.config = config
+                self.embed = nn.Embedding(config.n_embed, config.d_embed)
+                self.projection = Linear(config.d_embed, config.d_proj)
+                self.encoder = Encoder(config)
+                self.dropout = nn.Dropout(p=config.dp_ratio)
+                self.relu = nn.ReLU()
+                seq_in_size = 2 * config.d_hidden
+                if self.config.birnn:
+                    seq_in_size *= 2
+                lin_config = [seq_in_size] * 2
+                self.out = nn.Sequential(
+                    Linear(*lin_config),
+                    self.relu,
+                    self.dropout,
+                    Linear(*lin_config),
+                    self.relu,
+                    self.dropout,
+                    Linear(*lin_config),
+                    self.relu,
+                    self.dropout,
+                    Linear(seq_in_size, config.d_out))
 
-        inp = torch.ones(5, 5) * 5
-        weak_mod = Weak(5, 5)
-        strong_mod = Strong(weak_mod)
+            def forward(self, premise, hypothesis):
+                prem_embed = self.embed(premise)
+                hypo_embed = self.embed(hypothesis)
+                if self.config.fix_emb:
+                    prem_embed = prem_embed.detach()
+                    hypo_embed = hypo_embed.detach()
+                if self.config.projection:
+                    prem_embed = self.relu(self.projection(prem_embed))
+                    hypo_embed = self.relu(self.projection(hypo_embed))
+                premise = self.encoder(prem_embed)
+                hypothesis = self.encoder(hypo_embed)
+                scores = self.out(torch.cat([premise, hypothesis], 1))
+                return scores
 
-        self.assertTrue(isinstance(strong_mod.weak, torch.jit.ScriptModule))
-        self.assertFalse(isinstance(weak_mod, torch.jit.ScriptModule))
+        class Config:
+            n_embed = 100
+            d_embed = 100
+            d_proj = 300
+            dp_ratio = 0.0  # for deterministic testing; TODO: change by fixing the seed in checkTrace?
+            d_hidden = 300
+            birnn = True
+            d_out = 300
+            fix_emb = True
+            projection = True
+            n_layers = 2
+            n_cells = 4  # 2 * n_layers because birnn = True
 
-        self.assertIs(strong_mod.weak.weight, weak_mod.weight)
-        self.assertIs(strong_mod.weak.buffer, weak_mod.buffer)
-        self.assertIs(strong_mod.weak.submodule, weak_mod.submodule)
+        premise = torch.LongTensor(48, 128).random_(0, 100).to(device)
+        hypothesis = torch.LongTensor(24, 128).random_(0, 100).to(device)
 
-        # Test lookup fallback
-        weak_mod.new_attribute = 10
-        self.assertIs(strong_mod.weak.new_attribute, weak_mod.new_attribute)
+        self.checkTrace(SNLIClassifier(Config()).to(device), (premise, hypothesis),
+                        inputs_require_grads=False, export_import=check_export_import)
 
-        weak_mod.weight.data += torch.ones(5, 5) * 100
-        self.assertTrue(strong_mod(inp).allclose(weak_mod(inp)))
+    @skipIfRocm
+    def test_snli(self):
+        self._test_snli(self, device='cpu')
 
-        # Re-assignment is not tracked
-        weak_mod.weight = torch.nn.Parameter(torch.ones(5, 5) * 100)
-        self.assertFalse(strong_mod(inp).allclose(weak_mod(inp)))
+    @skipIfRocm
+    @unittest.skipIf(not RUN_CUDA, "no CUDA")
+    def test_snli_cuda(self):
+        # XXX: export_import on CUDA modules doesn't work (#11480)
+        self._test_snli(self, device='cuda', check_export_import=False)
 
-    def test_backend_cudnn_enabled(self):
-        # Only test that this compiles
-        @torch.jit.script
-        def fn(x):
-            if torch.backends.cudnn.enabled:
-                x = x + 2
-            else:
-                x = x + 3
-            return x
+    @staticmethod
+    def _test_super_resolution(self, device, check_export_import=True):
+        import torch.nn.init as init
 
-    def test_inplace_add(self):
+        class Net(nn.Module):
 
-        def foo(a, b):
-            c = a + b
-            c.add_(b)
-            return c
-        self.checkScript(foo, (torch.rand(3), torch.rand(3)))
+            def __init__(self, upscale_factor):
+                super(Net, self).__init__()
 
-    def test_add_out(self):
-        def foo(a, b):
-            c = a + b
-            e = 2 * a
-            torch.add(c, b, out=e)
-            return e
-        self.checkScript(foo, (torch.rand(3), torch.rand(3)))
+                self.relu = nn.ReLU()
+                self.conv1 = nn.Conv2d(1, 64, (5, 5), (1, 1), (2, 2))
+                self.conv2 = nn.Conv2d(64, 64, (3, 3), (1, 1), (1, 1))
+                self.conv3 = nn.Conv2d(64, 32, (3, 3), (1, 1), (1, 1))
+                self.conv4 = nn.Conv2d(32, upscale_factor ** 2, (3, 3), (1, 1), (1, 1))
+                self.pixel_shuffle = nn.PixelShuffle(upscale_factor)
 
-    def test_augmented_assign(self):
-        def foo(a, b):
-            a += b
-            a -= b
-            a /= b
-            a *= b
-            return a, b
-        self.checkScript(foo, (torch.rand(3), torch.rand(3)), check_expected=True)
+            def forward(self, x):
+                x = self.relu(self.conv1(x))
+                x = self.relu(self.conv2(x))
+                x = self.relu(self.conv3(x))
+                x = self.pixel_shuffle(self.conv4(x))
+                return x
 
-    def test_pass(self):
-        def foo(x):
-            # type: (bool) -> int
-            for _i in range(3):
-                pass
-            if x:
-                pass
-            else:
-                pass
-            return 3
+        net = Net(upscale_factor=4).to(device)
+        self.checkTrace(net, (torch.rand(5, 1, 64, 64, device=device),),
+                        export_import=check_export_import)
 
-        self.checkScript(foo, (True,))
+    @skipIfRocm
+    def test_super_resolution(self):
+        self._test_super_resolution(self, device='cpu')
 
-    def test_optional_conversion(self):
-        @torch.jit.script
-        def other_fn(x=None):
-            # type: (Optional[int]) -> int
-            return torch.jit._unwrap_optional(x)
+    @skipIfRocm
+    @unittest.skipIf(not RUN_CUDA, 'no CUDA')
+    def test_super_resolution_cuda(self):
+        # XXX: export_import on CUDA modules doesn't work (#11480)
+        self._test_super_resolution(self, device='cuda', check_export_import=False)
 
-        @torch.jit.script
-        def fn(x):
-            # type: (int) -> int
-            return other_fn(x)
+    @suppress_warnings
+    def test_time_sequence_prediction(self):
+        class Sequence(torch.jit.ScriptModule):
+            def __init__(self):
+                super(Sequence, self).__init__()
+                self.lstm1 = nn.LSTMCell(1, 51)
+                self.lstm2 = nn.LSTMCell(51, 51)
+                self.linear = nn.Linear(51, 1)
 
-        self.assertEqual(fn(2), 2)
+            # TODO: cannot pass a tuple to a Python op, and type annotations
+            # do not descend to the Python signature, hence the wrapper
+            # see https://github.com/pytorch/pytorch/issues/8778
+            # and https://github.com/pytorch/pytorch/issues/8777
+            def test_lstm1(self, input, hx, cx):
+                # type: (Tensor, Tensor, Tensor) -> Tuple[Tensor, Tensor]
+                return self.lstm1(input, (hx, cx))
 
-        @torch.jit.script
-        def unify_to_optional(x):
-            # type: (bool) -> Optional[int]
-            if x:
-                a = None
-            else:
-                a = 2
-            return a
+            def test_lstm2(self, input, hx, cx):
+                # type: (Tensor, Tensor, Tensor) -> Tuple[Tensor, Tensor]
+                return self.lstm2(input, (hx, cx))
 
-        self.assertEqual(unify_to_optional(True), None)
-        self.assertEqual(unify_to_optional(False), 2)
+            # TODO: tensor constructors are not supported in script
+            # see https://github.com/pytorch/pytorch/issues/8814
+            def test_tensor(self):
+                return torch.tensor([], dtype=torch.double)
 
-        @torch.jit.script
-        def opt_list(x):
-            # type: (Optional[List[float]]) -> int
-            return 2
+            @torch.jit.script_method
+            def forward(self, input):
+                # TODO: add future as input with default val
+                # see https://github.com/pytorch/pytorch/issues/8724
+                outputs = self.test_tensor()
+                h_t = torch.zeros((3, 51), dtype=torch.double)
+                c_t = torch.zeros((3, 51), dtype=torch.double)
+                h_t2 = torch.zeros((3, 51), dtype=torch.double)
+                c_t2 = torch.zeros((3, 51), dtype=torch.double)
 
-        @torch.jit.script
-        def broadcast_opt_list(x):
-            # type: (Optional[BroadcastingList2[float]]) -> int
-            return 2
+                output = torch.zeros([3, 51])
+                future = 2
 
-        @torch.jit.script
-        def opt_list_tuple_caller(x):
-            # type: (Tuple[float, float]) -> int
-            return opt_list(x) + broadcast_opt_list(x)
+                # TODO: the chunk call should appear as the for-loop iterable;
+                # we hard-code it to 4 chunks for now.
+                a, b, c, d = input.chunk(input.size(1), dim=1)
+                for input_t in (a, b, c, d):
+                    h_t, c_t = self.test_lstm1(input_t, h_t, c_t)
+                    h_t2, c_t2 = self.test_lstm2(h_t, h_t2, c_t2)
+                    output = self.linear(h_t2)
+                    outputs = torch.cat((outputs, output), 1)
+                for _ in range(future):  # if we should predict the future
+                    h_t, c_t = self.test_lstm1(output, h_t, c_t)
+                    h_t2, c_t2 = self.test_lstm2(h_t, h_t2, c_t2)
+                    output = self.linear(h_t2)
+                    outputs = torch.cat((outputs, output), 1)
+                return outputs
 
-        self.assertEqual(opt_list_tuple_caller((2., 3.)), 4)
+        # TODO: toggle export_import once above issues are fixed
+        self.checkTrace(Sequence(), (torch.rand(3, 4),),
+                        export_import=False)
 
-    def test_lhs_indexing(self):
-        def foo(a, b):
-            a = a.clone()
-            a[0] = b
-            return a
-        self.checkScript(foo, (torch.rand(2, 3), torch.rand(3)))
+    @staticmethod
+    def _test_vae(self, device, check_export_import=True):
+        class VAE(nn.Module):
+            def __init__(self):
+                super(VAE, self).__init__()
 
-    def test_lhs_advanced_indexing_assignment(self):
-        def foo(x, y):
-            a = torch.exp(x)
-            b = x == 1
-            a[b] = y[b]
-            return a
-        self.checkScript(foo, (torch.ones(4, 3), torch.ones(4, 3)))
+                self.fc1 = nn.Linear(784, 400)
+                self.fc21 = nn.Linear(400, 20)
+                self.fc22 = nn.Linear(400, 20)
+                self.fc3 = nn.Linear(20, 400)
+                self.fc4 = nn.Linear(400, 784)
 
-    def test_lhs_advanced_indexing_augmented_assignment(self):
-        def foo(x, y):
-            a = torch.exp(x)
-            b = x == 1
-            a[b] += y[b]
-            return a
-        self.checkScript(foo, (torch.ones(4, 3), torch.ones(4, 3)))
+            def encode(self, x):
+                h1 = F.relu(self.fc1(x))
+                return self.fc21(h1), self.fc22(h1)
 
-    def test_lhs_indexing_list(self):
-        def foo(a, b):
-            ls = [a]
-            ls[0] = b
-            return ls
-        self.checkScript(foo, (torch.rand(2, 3), torch.rand(3)))
+            def reparameterize(self, mu, logvar):
+                if self.training:
+                    std = torch.exp(0.5 * logvar)
+                    eps = torch.randn_like(std)
+                    return eps.mul(std).add_(mu)
+                else:
+                    return mu
 
-    def test_inplace_copy_script(self):
-        def foo(x):
-            a = torch.rand(3, 4)
-            a.copy_(x)
-            return a
-        self.checkScript(foo, (torch.rand(3, 4),))
+            def decode(self, z):
+                h3 = F.relu(self.fc3(z))
+                return torch.sigmoid(self.fc4(h3))
 
-    def test_lhs_indexing_increment(self):
-        def foo(a, b):
-            a[0] += b
-            return a
-        self.checkScript(foo, (torch.rand(2, 3), torch.rand(3)))
+            def forward(self, x):
+                mu, logvar = self.encode(x.view(-1, 784))
+                z = self.reparameterize(mu, logvar)
+                return self.decode(z), mu, logvar
 
-    def test_lhs_indexing_increment_list(self):
-        def foo(a, b):
-            a = a.clone()
-            ls = [a, b]
-            ls[0] += b
-            return ls
-        self.checkScript(foo, (torch.rand(2, 3), torch.rand(3)))
+        # eval() is present because randn_like makes this nondeterministic
+        self.checkTrace(VAE().to(device).eval(), (torch.rand(128, 1, 28, 28, device=device),),
+                        export_import=check_export_import)
 
-    def test_lhs_indexing_increment_list_prim(self):
-        def foo():
-            ls = [1, 2, 3]
-            ls[0] += 5
-            return ls
-        self.checkScript(foo, ())
+    def test_vae(self):
+        self._test_vae(self, device='cpu')
 
-    def test_lhs_indexing_multi(self):
-        def foo(a, b):
-            a = a.clone()
-            foo, a[0], bar = (1, b, 3)
-            return foo, a, bar
-        self.checkScript(foo, (torch.rand(2, 3), torch.rand(3)))
+    @unittest.skipIf(not RUN_CUDA, "no CUDA")
+    def test_vae_cuda(self):
+        # XXX: export_import on CUDA modules doesn't work (#11480)
+        self._test_vae(self, device='cuda', check_export_import=False)
 
-    def test_bool_dispatch(self):
-        with self.disableModuleHook():  # TODO: Python print broadcasting list
-            def kwarg_false(x):
-                # type: (Tensor) -> Tensor
-                return F.max_pool1d(x, 1, 1, return_indices=False)
-            self.checkScript(kwarg_false, (torch.randn(3, 3, 3),))
 
-            def kwarg_true(x):
-                # type: (Tensor) -> Tuple[Tensor, Tensor]
-                return F.max_pool1d(x, 1, 1, return_indices=True)
-            self.checkScript(kwarg_true, (torch.randn(3, 3, 3),))
+# Smoke tests for export methods
+class TestPytorchExportModes(JitTestCase):
+    class MyModel(nn.Module):
+        def __init__(self):
+            super(TestPytorchExportModes.MyModel, self).__init__()
+
+        def forward(self, x):
+            return x.transpose(0, 1)
 
-            def full_kwarg_false(x):
-                # type: (Tensor) -> Tensor
-                return F.max_pool1d(x, 1, 1, ceil_mode=False, return_indices=False)
-            self.checkScript(full_kwarg_false, (torch.randn(3, 3, 3),))
+    def test_protobuf(self):
+        torch_model = TestPytorchExportModes.MyModel()
+        fake_input = Variable(torch.randn(1, 1, 224, 224), requires_grad=True)
+        f = io.BytesIO()
+        torch.onnx._export(torch_model, (fake_input), f, verbose=False,
+                           export_type=torch.onnx.ExportTypes.PROTOBUF_FILE)
 
-            def full_kwarg_true(x):
-                # type: (Tensor) -> Tuple[Tensor, Tensor]
-                return F.max_pool1d(x, 1, 1, ceil_mode=False, return_indices=True)
-            self.checkScript(full_kwarg_true, (torch.randn(3, 3, 3),))
+    def test_zipfile(self):
+        torch_model = TestPytorchExportModes.MyModel()
+        fake_input = Variable(torch.randn(1, 1, 224, 224), requires_grad=True)
+        f = io.BytesIO()
+        torch.onnx._export(torch_model, (fake_input), f, verbose=False,
+                           export_type=torch.onnx.ExportTypes.ZIP_ARCHIVE)
 
-            def use_default(x):
-                # type: (Tensor) -> Tensor
-                return F.max_pool1d(x, 1, 1)
-            self.checkScript(use_default, (torch.randn(3, 3, 3),))
+    def test_compressed_zipfile(self):
+        torch_model = TestPytorchExportModes.MyModel()
+        fake_input = Variable(torch.randn(1, 1, 224, 224), requires_grad=True)
+        f = io.BytesIO()
+        torch.onnx._export(torch_model, (fake_input), f, verbose=False,
+                           export_type=torch.onnx.ExportTypes.COMPRESSED_ZIP_ARCHIVE)
 
-            def arg_false(x):
-                # type: (Tensor) -> Tensor
-                return F.max_pool1d(x, 1, 1, 0, 1, False, False)
-            self.checkScript(arg_false, (torch.randn(3, 3, 3),))
+    def test_directory(self):
+        torch_model = TestPytorchExportModes.MyModel()
+        fake_input = Variable(torch.randn(1, 1, 224, 224), requires_grad=True)
+        d = tempfile.mkdtemp()
+        torch.onnx._export(torch_model, (fake_input), d, verbose=False,
+                           export_type=torch.onnx.ExportTypes.DIRECTORY)
+        shutil.rmtree(d)
 
-            def arg_true(x):
-                # type: (Tensor) -> Tuple[Tensor, Tensor]
-                return F.max_pool1d(x, 1, 1, 0, 1, False, True)
-            self.checkScript(arg_true, (torch.randn(3, 3, 3),))
+    @skipIfRocm
+    @skipIfNoLapack
+    def test_aten_fallback(self):
+        class ModelWithAtenNotONNXOp(nn.Module):
+            def forward(self, x, y):
+                abcd = x + y
+                defg = torch.qr(abcd)
+                return defg
 
-    def test_infer_size(self):
-        from torch._C import _infer_size
+        x = torch.rand(3, 4)
+        y = torch.rand(3, 4)
+        f = io.BytesIO()
+        exported = torch.onnx.export_to_pretty_string(
+            ModelWithAtenNotONNXOp(), (x, y), f,
+            operator_export_type=OperatorExportTypes.ONNX_ATEN_FALLBACK)
+        self.assertExpected(exported)
 
-        def fn(x, y):
-            # type: (Tensor, Tensor) -> List[int]
-            return _infer_size(x.size(), y.size())
+    # torch.fmod is used to test ONNX_ATEN.
+    # If you plan to remove fmod from aten, or find this test failing,
+    # please contact @Rui.
+    @skipIfRocm
+    def test_onnx_aten(self):
+        class ModelWithAtenFmod(nn.Module):
+            def forward(self, x, y):
+                return torch.fmod(x, y)
 
-        self.checkScript(fn, (torch.ones(2, 4, 2), torch.ones(2, 4, 2)))
+        f = io.BytesIO()
+        x = torch.randn(3, 4, dtype=torch.float32)
+        y = torch.randn(3, 4, dtype=torch.float32)
+        exported = torch.onnx.export_to_pretty_string(
+            ModelWithAtenFmod(), (x, y), f,
+            operator_export_type=OperatorExportTypes.ONNX_ATEN)
+        self.assertExpected(exported)
 
-    def test_mutable_dce(self):
-        @torch.jit.script
-        def foo():
-            a = torch.rand(2, 3)
-            a += torch.rand(2, 3)
-            b = torch.rand(2, 3)
-            b += torch.rand(2, 3)
-            # b should be cleaned up but not a
-            return a
 
-        self.assertExpectedGraph(foo.graph)
+# known to be failing in tracer
+EXCLUDE_TRACED = {
+    'test_split_dim',
+    'test_split_dim_neg0',
 
-    def test_mutable_dce_block(self):
-        @torch.jit.script
-        def foo():
-            a = torch.rand(2, 3)
-            a += torch.rand(2, 3)
-            b = torch.rand(2, 3)
-            if bool(a > torch.zeros(2, 3)):
-                b += torch.rand(2, 3)
-                a += torch.rand(2, 3)
-            # a should be cleaned up but not b
-            return b
+    # The following fail due to #12024.
+    # A prim::ListConstruct is involved and the indices get traced as DynamicType,
+    # which always requires grad. This causes a crash in autodiff.
+    'test___getitem___adv_index',
+    'test___getitem___adv_index_beg',
+    'test___getitem___adv_index_comb',
+    'test___getitem___adv_index_dup',
+    'test___getitem___adv_index_sub',
+    'test___getitem___adv_index_sub_2',
+    'test___getitem___adv_index_sub_3',
+    'test___getitem___adv_index_var',
+}
 
-        self.assertExpectedGraph(foo.graph)
+EXCLUDE_TYPE_CHECK = {
+    # slogdet tests use itemgetter to select its only differentiable output,
+    # but this happens outside of the graph we handle, so there are fewer
+    # reference outputs than graph outputs.
+    'test_slogdet_1x1_neg_det',
+    'test_slogdet_1x1_pos_det',
+    'test_slogdet_distinct_singular_values',
+    'test_slogdet_neg_det',
+    'test_slogdet_pos_det',
+    'test_slogdet_symmetric',
+    'test_slogdet_symmetric_pd',
+}
 
-    def test_mutable_dce_graph_input(self):
-        @torch.jit.script
-        def foo(a):
-            a += torch.rand(2, 3)
-            # shouldn't clean up `a` even though it's not used in the output
+# known to be failing in script
+EXCLUDE_SCRIPT = {
+    'test_norm_fro',
+    'test_norm_fro_default',
+    'test_norm_nuc',
 
-        self.assertExpectedGraph(foo.graph)
+    # aten op has additional cudnn argument
+    'test_nn_unfold',
 
-    def test_mutable_dce_list(self):
-        @torch.jit.script
-        def foo(a):
-            l = []
-            l.append(a)
-            c = l[0]
-            b = torch.rand(2, 3)
-            c += torch.rand(2, 3)
-            return b
+    # flaky test - TODO fix
+    'test_nn_ctc_loss',
 
-        self.assertExpectedGraph(foo.graph)
+    # unknown builtin op
+    'test_nn_fold',
+}
 
-    def test_mutable_dce_loop(self):
-        @torch.jit.script
-        def foo(a):
-            l = []
-            l.append(a)
-            i = 0
-            b = torch.rand(2, 3)
-            while i < 1:
-                dead = torch.rand(2, 3)
-                c = l[0]
-                c += torch.rand(2, 3)
-                i += 1
-            return b
+EXCLUDE_PYTHON_PRINT = {
+    # no support for BroadcastingList in python printer
+    'test_nn_max_unpool1d',
+    'test_nn_max_unpool2d',
+    'test_nn_max_unpool3d',
+    'test_nn_max_pool1d',
+    'test_nn_max_pool2d',
+    'test_nn_max_pool3d',
+    'test_nn_max_pool1d_with_indices',
+}
 
-        self.assertExpectedGraph(foo.graph)
+EXCLUDE_SCRIPT_MODULES = {
+    'test_nn_AdaptiveAvgPool2d_tuple_none',
+    'test_nn_AdaptiveAvgPool3d_tuple_none',
+    'test_nn_AdaptiveMaxPool2d_tuple_none',
+    'test_nn_AdaptiveMaxPool3d_tuple_none',
+}
 
+DISABLE_AUTODIFF_SUBGRAPH_INLINING = {
+    'test_nn_avg_pool2d',
+    'test_nn_log_softmax',
+    'test_nn_threshold',
+    'test_nn_nll_loss',
+}
 
-class MnistNet(nn.Module):
-    def __init__(self):
-        super(MnistNet, self).__init__()
-        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
-        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
-        self.conv2_drop = nn.Dropout2d()
-        self.fc1 = nn.Linear(320, 50)
-        self.fc2 = nn.Linear(50, 10)
 
-    def forward(self, x):
-        x = F.relu(F.max_pool2d(self.conv1(x), 2))
-        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
-        x = x.view(-1, 320)
-        x = F.relu(self.fc1(x))
-        x = F.dropout(x, training=self.training)
-        x = self.fc2(x)
-        return F.log_softmax(x, dim=1)
+# make a new function where all non-tensor arguments in 'args' have been partially
+# applied, and all tensor arguments remain.
+# used to trace functions when some arguments are not tensors
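+# Illustrative usage sketch (assumed example, not taken from any specific test):
+#   new_fn, tensor_args = partial_apply_nontensors(torch.add, (torch.ones(2), 3.0))
+#   new_fn(*tensor_args)  # equivalent to torch.add(torch.ones(2), 3.0); the scalar is baked in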
+def partial_apply_nontensors(fn, args, **kwargs):
+    source = ['t' if isinstance(arg, torch.Tensor) else 's' for arg in args]
+
+    def new_fn(*tensors_):
+        tensors = iter(tensors_)
+        return fn(*(args[i] if s == 's' else next(tensors) for i, s in enumerate(source)), **kwargs)
 
+    return new_fn, [arg for arg in args if isinstance(arg, torch.Tensor)]
 
-class TestEndToEndHybridFrontendModels(JitTestCase):
-    @staticmethod
-    def _test_dcgan_models(self, device, check_export_import=True):
-        class DCGANGenerator(nn.Module):
-            def __init__(self, nz, ngf, nc):
-                super(DCGANGenerator, self).__init__()
-                self.main = nn.Sequential(
-                    # input is Z, going into a convolution
-                    nn.ConvTranspose2d(nz, ngf * 8, 4, 1, 0, bias=False),
-                    nn.BatchNorm2d(ngf * 8),
-                    nn.ReLU(True),
-                    # state size. (ngf*8) x 4 x 4
-                    nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False),
-                    nn.BatchNorm2d(ngf * 4),
-                    nn.ReLU(True),
-                    # state size. (ngf*4) x 8 x 8
-                    nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=False),
-                    nn.BatchNorm2d(ngf * 2),
-                    nn.ReLU(True),
-                    # state size. (ngf*2) x 16 x 16
-                    nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 1, bias=False),
-                    nn.BatchNorm2d(ngf),
-                    nn.ReLU(True),
-                    # state size. (ngf) x 32 x 32
-                    nn.ConvTranspose2d(ngf, nc, 4, 2, 1, bias=False),
-                    nn.Tanh()
-                    # state size. (nc) x 64 x 64
-                )
 
-            def forward(self, input):
-                return self.main(input)
+# create a trace function from input fn
+#
+# disable_autodiff_subgraph_inlining:
+#   Don't inline autodiff subgraphs so we can test autodiff
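+# Illustrative usage sketch (assumed example; `self` is the running JitTestCase and
+# `x` is any hypothetical input tensor):
+#   traced_mul = create_traced_fn(self, torch.mul)
+#   out = traced_mul(x, 2.0)      # the non-tensor arg 2.0 is partially applied before tracing
+#   traced_mul.last_graph         # graph specialized for the traced tensor inputs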
+def create_traced_fn(self, fn,
+                     disable_autodiff_subgraph_inlining=False):
+    def traced_fn(*inputs, **kwargs):
+        fn_tensors, inputs_tensors = partial_apply_nontensors(fn, inputs, **kwargs)
+        traced = torch.jit.trace(fn_tensors, inputs_tensors)
+        self.assertExportImport(traced.graph, inputs_tensors)
+        if disable_autodiff_subgraph_inlining:
+            traced.debug_disable_autodiff_subgraph_inlining()
+        output = traced(*inputs_tensors)
+        traced_fn.last_graph = traced.graph_for(*inputs_tensors)
+        return output
+    return traced_fn
 
-        class DCGANDiscriminator(nn.Module):
-            def __init__(self, nc, ndf):
-                super(DCGANDiscriminator, self).__init__()
-                self.main = nn.Sequential(
-                    # input is (nc) x 64 x 64
-                    nn.Conv2d(nc, ndf, 4, 2, 1, bias=False),
-                    nn.LeakyReLU(0.2, inplace=True),
-                    # state size. (ndf) x 32 x 32
-                    nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),
-                    nn.BatchNorm2d(ndf * 2),
-                    nn.LeakyReLU(0.2, inplace=True),
-                    # state size. (ndf*2) x 16 x 16
-                    nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),
-                    nn.BatchNorm2d(ndf * 4),
-                    nn.LeakyReLU(0.2, inplace=True),
-                    # state size. (ndf*4) x 8 x 8
-                    nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False),
-                    nn.BatchNorm2d(ndf * 8),
-                    nn.LeakyReLU(0.2, inplace=True),
-                    # state size. (ndf*8) x 4 x 4
-                    nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False),
-                    nn.Sigmoid()
-                )
+script_template = '''
+def the_method({}):
+    return {}
+'''
 
-            def forward(self, input):
-                return self.main(input).view(-1, 1).squeeze(1)
+script_method_template = '''
+def forward({}):
+    return {}
+'''
 
-        bs, nz, ngf, nc, ndf = 5, 6, 9, 3, 10
-        self.checkTrace(DCGANGenerator(nz, ngf, nc).to(device),
-                        (torch.rand(bs, nz, 1, 1, device=device),),
-                        export_import=check_export_import)
-        example_input = DCGANGenerator(nz, ngf, nc).to(device)(torch.rand(bs, nz, 1, 1, device=device))
-        self.checkTrace(DCGANDiscriminator(nc, ndf).to(device), (example_input,),
-                        export_import=check_export_import)
 
-    def test_dcgan_models(self):
-        self._test_dcgan_models(self, device='cpu')
+def get_constant(x):
+    if x == inf:
+        return 'float(\'inf\')' if PY2 else 'math.inf'
+    if x == -inf:
+        return 'float(\'-inf\')' if PY2 else '-math.inf'
+    return x
 
-    @unittest.skipIf(not RUN_CUDA, "no CUDA")
-    @skipIfRocm
-    def test_dcgan_models_cuda(self):
-        # XXX: export_import on CUDA modules doesn't work (#11480)
-        self._test_dcgan_models(self, device='cuda', check_export_import=False)
 
-    @staticmethod
-    def _test_neural_style(self, device, check_export_import=True):
-        class TransformerNet(torch.nn.Module):
-            def __init__(self):
-                super(TransformerNet, self).__init__()
-                # Initial convolution layers
-                self.conv1 = ConvLayer(3, 32, kernel_size=9, stride=1)
-                self.in1 = torch.nn.InstanceNorm2d(32, affine=True)
-                self.conv2 = ConvLayer(32, 64, kernel_size=3, stride=2)
-                self.in2 = torch.nn.InstanceNorm2d(64, affine=True)
-                self.conv3 = ConvLayer(64, 128, kernel_size=3, stride=2)
-                self.in3 = torch.nn.InstanceNorm2d(128, affine=True)
-                # Residual layers
-                self.res1 = ResidualBlock(128)
-                self.res2 = ResidualBlock(128)
-                self.res3 = ResidualBlock(128)
-                self.res4 = ResidualBlock(128)
-                self.res5 = ResidualBlock(128)
-                # Upsampling Layers
-                self.deconv1 = UpsampleConvLayer(128, 64, kernel_size=3, stride=1, upsample=2)
-                self.in4 = torch.nn.InstanceNorm2d(64, affine=True)
-                self.deconv2 = UpsampleConvLayer(64, 32, kernel_size=3, stride=1, upsample=2)
-                self.in5 = torch.nn.InstanceNorm2d(32, affine=True)
-                self.deconv3 = ConvLayer(32, 3, kernel_size=9, stride=1)
-                # Non-linearities
-                self.relu = torch.nn.ReLU()
+def get_script_args(args):
+    formals = []
+    tensors = []
+    actuals = []
+    for arg in args:
+        if isinstance(arg, torch.Tensor):
+            name = 'i{}'.format(len(formals))
+            formals.append(name)
+            actuals.append(name)
+            tensors.append(arg)
+        elif isinstance(arg, str):
+            actuals.append("'{}'".format(arg))
+        else:
+            actuals.append(str(get_constant(arg)))
+    return (formals, tensors, actuals)
 
-            def forward(self, X):
-                y = self.relu(self.in1(self.conv1(X)))
-                y = self.relu(self.in2(self.conv2(y)))
-                y = self.relu(self.in3(self.conv3(y)))
-                y = self.res1(y)
-                y = self.res2(y)
-                y = self.res3(y)
-                y = self.res4(y)
-                y = self.res5(y)
-                y = self.relu(self.in4(self.deconv1(y)))
-                y = self.relu(self.in5(self.deconv2(y)))
-                y = self.deconv3(y)
-                return y
 
-        class ConvLayer(torch.nn.Module):
-            def __init__(self, in_channels, out_channels, kernel_size, stride):
-                super(ConvLayer, self).__init__()
-                reflection_padding = kernel_size // 2
-                self.reflection_pad = torch.nn.ReflectionPad2d(reflection_padding)
-                self.conv2d = torch.nn.Conv2d(in_channels, out_channels, kernel_size, stride)
+# create a script function from (name, func_type, output_process_fn);
+# returns a function that takes (args, kwargs), runs the compiled function,
+# and then applies the post-process fn to the outputs
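+# As an illustrative sketch (assumed example): method_name='add' with func_type='method'
+# and two tensor arguments fills script_template with
+#   def the_method(i0, i1):
+#       return i0.add(i1)
+# which is then compiled via torch.jit.CompilationUnit and run on the tensor args.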
+def create_script_fn(self, method_name, func_type, output_process_fn,
+                     disable_autodiff_subgraph_inlining=False):
+    def script_fn(*args, **kwargs):
+        formals, tensors, actuals = get_script_args(args)
+        kwargs_str = ''
+        for k, v in kwargs.items():
+            kwargs_str += ', ' + k + '=' + str(v)
+        if func_type == 'functional':
+            call = 'torch.{}({}{})'.format(method_name, ', '.join(actuals), kwargs_str)
+        elif func_type == 'method':
+            call = '{}.{}({}{})'.format(actuals[0], method_name, ', '.join(actuals[1:]), kwargs_str)
+        elif func_type == 'nn_functional':
+            call = 'torch.nn.functional.{}({}{})'.format(method_name, ', '.join(actuals), kwargs_str)
+        else:
+            raise RuntimeError('Unsupported function type')
 
-            def forward(self, x):
-                out = self.reflection_pad(x)
-                out = self.conv2d(out)
-                return out
+        script = script_template.format(', '.join(formals), call)
 
-        class ResidualBlock(torch.nn.Module):
-            """ResidualBlock
-            introduced in: https://arxiv.org/abs/1512.03385
-            recommended architecture: http://torch.ch/blog/2016/02/04/resnets.html
-            """
+        CU = torch.jit.CompilationUnit(script)
+        if disable_autodiff_subgraph_inlining:
+            CU.the_method.debug_disable_autodiff_subgraph_inlining()
+        self.assertExportImport(CU.the_method.graph, tensors)
+        output = output_process_fn(CU.the_method(*tensors))
+        script_fn.last_graph = CU.the_method.graph_for(*tensors)
+        return output
+    return script_fn
 
-            def __init__(self, channels):
-                super(ResidualBlock, self).__init__()
-                self.conv1 = ConvLayer(channels, channels, kernel_size=3, stride=1)
-                self.in1 = torch.nn.InstanceNorm2d(channels, affine=True)
-                self.conv2 = ConvLayer(channels, channels, kernel_size=3, stride=1)
-                self.in2 = torch.nn.InstanceNorm2d(channels, affine=True)
-                self.relu = torch.nn.ReLU()
 
-            def forward(self, x):
-                residual = x
-                out = self.relu(self.in1(self.conv1(x)))
-                out = self.in2(self.conv2(out))
-                out = out + residual
-                return out
+def check_alias_annotation(method_name, args, kwargs):
+    formals, tensors, actuals = get_script_args(args)
+    kwargs_str = ''
+    for k, v in kwargs.items():
+        kwargs_str += ', ' + k + '=' + str(v)
+    call = '{}.{}({}{})'.format(actuals[0], method_name, ', '.join(actuals[1:]), kwargs_str)
+    script = script_template.format(', '.join(formals), call)
+    CU = torch.jit.CompilationUnit(script)
+    torch._C._jit_check_alias_annotation(CU.the_method.graph, tuple(tensors), method_name)
 
-        class UpsampleConvLayer(torch.nn.Module):
-            """UpsampleConvLayer
-            Upsamples the input and then does a convolution. This method gives better results
-            compared to ConvTranspose2d.
-            ref: http://distill.pub/2016/deconv-checkerboard/
-            """
 
-            def __init__(self, in_channels, out_channels, kernel_size, stride, upsample=None):
-                super(UpsampleConvLayer, self).__init__()
-                self.upsample = upsample
-                if upsample:
-                    self.upsample_layer = torch.nn.Upsample(mode='nearest', scale_factor=upsample)
-                reflection_padding = kernel_size // 2
-                self.reflection_pad = torch.nn.ReflectionPad2d(reflection_padding)
-                self.conv2d = torch.nn.Conv2d(in_channels, out_channels, kernel_size, stride)
+def check_output_types(self, func, ref_outputs, args, kwargs):
+    graph = getattr(func, 'last_graph', None)
+    if not isinstance(ref_outputs, tuple):
+        ref_outputs = (ref_outputs,)
+    types = [o.type() for o in graph.outputs()]
+    self.assertEqual(len(types), len(ref_outputs))
+    for i, (t, ref_out) in enumerate(zip(types, ref_outputs)):
+        if isinstance(ref_out, list):
+            assert len(ref_out) > 0
+            elem = ref_out[0]
+            assert isinstance(elem, torch.Tensor)
+            self.assertTrue(t.isSubtypeOf(torch._C.ListType.ofTensors()))
+        else:
+            ref_type = torch._C.Type.inferFrom(ref_out)
+            self.assertTrue(ref_type.isSubtypeOf(t))
 
-            def forward(self, x):
-                x_in = x
-                if self.upsample:
-                    x_in = self.upsample_layer(x_in)
-                out = self.reflection_pad(x_in)
-                out = self.conv2d(out)
-                return out
 
-        self.checkTrace(TransformerNet(), (torch.rand(5, 3, 64, 64),), export_import=check_export_import)
+def check_against_reference(self, func, reference_func, args, kwargs=None,
+                            allow_unused=True, check_types=True, no_grad=False):
+    kwargs = kwargs if kwargs else {}
 
-    def test_neural_style(self):
-        self._test_neural_style(self, device='cpu')
+    def allSum(vs):
+        if isinstance(vs, torch.Tensor):
+            vs = (vs,)
+        return sum([(i + 1) * v.sum()
+                    for i, v in enumerate(vs)
+                    if v is not None and v.dtype.is_floating_point])
 
-    @unittest.skipIf(not RUN_CUDA, "no CUDA")
-    def test_neural_style_cuda(self):
-        # XXX: export_import on CUDA modules doesn't work (#11480)
-        self._test_neural_style(self, device='cuda', check_export_import=False)
+    def clone_inputs(requires_grad):
+        inputs = [
+            arg.detach().clone().requires_grad_(requires_grad and arg.requires_grad)
+            if isinstance(arg, torch.Tensor) else arg for arg in args
+        ]
+        return inputs, [input for input in inputs if isinstance(input, torch.Tensor) and input.requires_grad]
+
+    nograd_inputs, nograd_tensors = clone_inputs(False)
+    recording_inputs, recording_tensors = clone_inputs(True)
+
+    # test no gradients case
+    outputs = self.runAndSaveRNG(reference_func, nograd_inputs, kwargs)
+    outputs_test = self.runAndSaveRNG(func, nograd_inputs, kwargs)
+    self.assertEqual(outputs, outputs_test)
+
+    if check_types:
+        check_output_types(self, func, outputs_test, nograd_inputs, kwargs)
 
-    @staticmethod
-    def _test_mnist(self, device, check_export_import=True):
-        # eval() is present because dropout makes this nondeterministic
-        self.checkTrace(MnistNet().to(device).eval(), (torch.rand(5, 1, 28, 28, device=device),),
-                        export_import=check_export_import)
+    if no_grad:
+        # skip grad tests
+        return
 
-    def test_mnist(self):
-        self._test_mnist(self, device='cpu')
+    # test single grad case
+    outputs = self.runAndSaveRNG(reference_func, recording_inputs, kwargs)
+    grads = torch.autograd.grad(allSum(outputs), recording_tensors,
+                                allow_unused=allow_unused)
 
-    @unittest.skipIf(not RUN_CUDA, "no CUDA")
-    @skipIfRocm
-    def test_mnist_cuda(self):
-        # XXX: export_import on CUDA modules doesn't work (#11480)
-        self._test_mnist(self, device='cuda', check_export_import=False)
+    outputs_test = self.runAndSaveRNG(func, recording_inputs, kwargs)
+    grads_test = torch.autograd.grad(allSum(outputs_test), recording_tensors,
+                                     allow_unused=allow_unused)
+    self.assertEqual(outputs, outputs_test)
+    self.assertEqual(grads, grads_test)
 
-    @unittest.skipIf(not RUN_CUDA, "no CUDA")
-    @skipIfRocm
-    def test_mnist_training_leaks_no_memory_cuda(self):
-        net = MnistNet().cuda()
-        # MnistNet uses dropout, don't check its trace
-        traced_net = torch.jit.trace(net, [torch.randn(5, 1, 28, 28, device='cuda')],
-                                     check_trace=False)
+    # test the grad grad case
+    if self._testMethodName in nn_functional_single_grad:
+        return
 
-        def train(iters):
-            for _ in range(iters):
-                # Get some fake data
-                inp = torch.randn(5, 1, 28, 28, device='cuda')
-                out = traced_net(inp)
+    outputs = self.runAndSaveRNG(reference_func, recording_inputs, kwargs)
+    l1 = allSum(outputs)
+    grads = torch.autograd.grad(l1, recording_tensors, create_graph=True,
+                                allow_unused=allow_unused)
+    l2 = (allSum(grads) * l1)
+    grads2 = torch.autograd.grad(l2, recording_tensors, allow_unused=allow_unused)
 
-                # Here's some fake loss
-                out.sum().backward()
+    recording_inputs, recording_tensors = clone_inputs(True)
 
-                # Zero out grads
-                traced_net.zero_grad()
+    outputs_test = self.runAndSaveRNG(func, recording_inputs, kwargs)
+    l1_test = allSum(outputs_test)
+    grads_test = torch.autograd.grad(
+        l1_test, recording_tensors, create_graph=True, allow_unused=allow_unused)
+    l2_test = (allSum(grads_test) * l1_test)
+    grads2_test = torch.autograd.grad(l2_test, recording_tensors, allow_unused=allow_unused)
 
-        # Set it up so the params have .grad fields so they are not reported as leaks
-        train(1)
+    self.assertEqual(outputs, outputs_test)
+    self.assertEqual(grads, grads_test)
+    for g2, g2_test in zip(grads2, grads2_test):
+        if g2 is None and g2_test is None:
+            continue
+        self.assertTrue(torch.allclose(g2, g2_test, atol=5e-4, rtol=1e-4))
 
-        with self.assertLeaksNoCudaTensors():
-            train(5)
 
-    @staticmethod
-    def _test_reinforcement_learning(self, device, test_export_import=True):
-        class Policy(nn.Module):
-            def __init__(self):
-                super(Policy, self).__init__()
-                self.affine1 = nn.Linear(4, 128)
-                self.affine2 = nn.Linear(128, 2)
+class TestFuser(JitTestCase):
+    def assertAllFused(self, graph, except_for=()):
+        if [n.kind() for n in graph.nodes()] == ['prim::DifferentiableGraph']:
+            graph = next(graph.nodes()).g('Subgraph')
+        allowed_nodes = {'prim::Constant', 'prim::FusionGroup'} | set(except_for)
+        self.assertTrue(all(node.kind() in allowed_nodes for node in graph.nodes()),
+                        'got {}'.format(graph))
+        self.assertTrue([node.kind() for node in graph.nodes()].count('prim::FusionGroup') == 1)
 
-            def forward(self, x):
-                x = F.relu(self.affine1(x))
-                action_scores = self.affine2(x)
-                return F.softmax(action_scores, dim=1)
+    def _test_fused_abs(self, device='cpu'):
 
-        self.checkTrace(Policy().to(device), (torch.rand(1, 4, device=device),),
-                        export_import=test_export_import)
+        @torch.jit.script
+        def func(x):
+            return x.abs() * 2
 
-    def test_reinforcement_learning(self):
-        self._test_reinforcement_learning(self, device='cpu')
+        a = torch.randn(5, device=device)
+        self.assertEqual(func(a), a.abs() * 2)
+        self.assertAllFused(func.graph_for(a))
 
-    @unittest.skipIf(not RUN_CUDA, "no CUDA")
-    def test_reinforcement_learning_cuda(self):
-        # XXX: export_import on CUDA modules doesn't work (#11480)
-        self._test_reinforcement_learning(self, device='cuda', test_export_import=False)
+    @unittest.skipIf(IS_WINDOWS or IS_SANDCASTLE, "NYI: fuser support for Windows or Sandcastle")
+    @enable_cpu_fuser
+    def test_abs_cpu(self):
+        self._test_fused_abs()
 
-    @staticmethod
-    def _test_snli(self, device, check_export_import=True):
-        class Bottle(nn.Module):
+    @unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
+    @unittest.skipIf(not RUN_CUDA, "requires CUDA")
+    @skipIfRocm
+    def test_abs_cuda(self):
+        self._test_fused_abs(device="cuda")
 
-            def forward(self, input):
-                if len(input.size()) <= 2:
-                    return super(Bottle, self).forward(input)
-                size = input.size()[:2]
-                out = super(Bottle, self).forward(input.view(size[0] * size[1], -1))
-                return out.view(size[0], size[1], -1)
+    @unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
+    @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
+    @skipIfRocm
+    def test_arg_configurations_smoke_cuda(self):
+        # A smoke test to make sure we won't use the same kernel for contiguous
+        # and non-contiguous arguments.
+        # TODO: add optionally enabled debug counters to the fuser to verify
+        #       that we really can tell the difference between configurations
+        def f(x, y):
+            z1, z2 = (x + y).chunk(2, dim=1)
+            return z1 * z2
 
-        class Linear(Bottle, nn.Linear):
-            pass
+        x = torch.randn(4, 4, dtype=torch.float, device='cuda')
+        y = torch.randn(4, 4, dtype=torch.float, device='cuda')
+        traced_f = torch.jit.trace(f, (x, y,))
+        self.assertEqual(traced_f(x.t().contiguous(), y), traced_f(x.t(), y))
 
-        class Encoder(nn.Module):
+    @unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
+    @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
+    @skipIfRocm
+    def test_broadcast_cuda(self):
+        def scaleshift(x, scale, shift):
+            return x * scale + shift
 
-            def __init__(self, config):
-                super(Encoder, self).__init__()
-                self.config = config
-                input_size = config.d_proj if config.projection else config.d_embed
-                dropout = 0 if config.n_layers == 1 else config.dp_ratio
-                self.rnn = nn.LSTM(input_size=input_size, hidden_size=config.d_hidden,
-                                   num_layers=config.n_layers, dropout=dropout,
-                                   bidirectional=config.birnn)
+        inputs = [
+            torch.randn(4, 4, dtype=torch.float, device='cuda'),
+            torch.randn(4, dtype=torch.float, device='cuda'),
+            torch.randn(4, dtype=torch.float, device='cuda'),
+        ]
+        ge = self.checkTrace(scaleshift, inputs)
+        self.assertExpectedGraph(ge.graph_for(*inputs))
 
-            def forward(self, inputs):
-                batch_size = inputs.size()[1]
-                state_shape = self.config.n_cells, batch_size, self.config.d_hidden
-                h0 = c0 = inputs.new_zeros(state_shape)
-                outputs, (ht, ct) = self.rnn(inputs, (h0, c0))
-                return ht[-1] if not self.config.birnn else ht[-2:].transpose(0, 1).contiguous().view(batch_size, -1)
+    @unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
+    @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
+    @unittest.skipIf(not RUN_CUDA_HALF, "no half support")
+    def test_cuda_half(self):
+        x = torch.randn(4, 4, dtype=torch.half, device='cuda')
+        y = torch.randn(4, 4, dtype=torch.half, device='cuda')
 
-        class SNLIClassifier(nn.Module):
+        funcs = [
+            self.fn_test_comparison_gt_lt,
+            self.fn_test_relu,
+            self.fn_test_exp
+        ]
 
-            def __init__(self, config):
-                super(SNLIClassifier, self).__init__()
-                self.config = config
-                self.embed = nn.Embedding(config.n_embed, config.d_embed)
-                self.projection = Linear(config.d_embed, config.d_proj)
-                self.encoder = Encoder(config)
-                self.dropout = nn.Dropout(p=config.dp_ratio)
-                self.relu = nn.ReLU()
-                seq_in_size = 2 * config.d_hidden
-                if self.config.birnn:
-                    seq_in_size *= 2
-                lin_config = [seq_in_size] * 2
-                self.out = nn.Sequential(
-                    Linear(*lin_config),
-                    self.relu,
-                    self.dropout,
-                    Linear(*lin_config),
-                    self.relu,
-                    self.dropout,
-                    Linear(*lin_config),
-                    self.relu,
-                    self.dropout,
-                    Linear(seq_in_size, config.d_out))
+            # Note: Non-fused inputs must be float to prevent loss of precision
+        inputs = (x.float(), y.float())
+        fusion_inputs = (x, y)
+        for fn in funcs:
+            local_inputs = [t.clone().requires_grad_() for t in inputs]
+            local_fusion_inputs = [t.clone().requires_grad_() for t in fusion_inputs]
 
-            def forward(self, premise, hypothesis):
-                prem_embed = self.embed(premise)
-                hypo_embed = self.embed(hypothesis)
-                if self.config.fix_emb:
-                    prem_embed = prem_embed.detach()
-                    hypo_embed = hypo_embed.detach()
-                if self.config.projection:
-                    prem_embed = self.relu(self.projection(prem_embed))
-                    hypo_embed = self.relu(self.projection(hypo_embed))
-                premise = self.encoder(prem_embed)
-                hypothesis = self.encoder(hypo_embed)
-                scores = self.out(torch.cat([premise, hypothesis], 1))
-                return scores
+            # Verifies outputs
+            fusion = torch.jit.trace(fn, local_fusion_inputs, check_trace=False, optimize=True)
+            outputs = fn(*local_inputs)
+            fusion_outputs = fusion(*local_fusion_inputs)
+            outputs_half = [t.half() for t in outputs]
+            self.assertEqual(outputs_half, fusion_outputs)
 
-        class Config:
-            n_embed = 100
-            d_embed = 100
-            d_proj = 300
-            dp_ratio = 0.0  # For deterministic testing TODO: change by fixing seed in checkTrace?
-            d_hidden = 300
-            birnn = True
-            d_out = 300
-            fix_emb = True
-            projection = True
-            n_layers = 2
-            n_cells = 4  # 2 * n_layers because birnn = True
+            # Verifies gradients
+            for output, fusion_output in zip(outputs_half, fusion_outputs):
+                grads = torch.autograd.grad(
+                    output.float().sum(), local_inputs, allow_unused=True, retain_graph=True)
+                fusion_grads = torch.autograd.grad(
+                    fusion_output.sum(), local_fusion_inputs, allow_unused=True, retain_graph=True)
+                grads_half = [t.half() for t in grads]
+                self.assertEqual(grads_half, fusion_grads)
 
-        premise = torch.LongTensor(48, 128).random_(0, 100).to(device)
-        hypothesis = torch.LongTensor(24, 128).random_(0, 100).to(device)
+    @unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
+    @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
+    @skipIfRocm
+    def test_checks_cat_inputs(self):
+        # We shouldn't treat cat nodes as broadcasting. All their inputs
+        # need to be checked for having the same map size, before we can
+        # run the kernel.
+        @torch.jit.script
+        def f(x, y):
+            return torch.cat([x + 2 * x + x ** 2, y + 4 * y + y ** 3], dim=0)
+
+        # NOTE: y is broadcastable to x, but output of f(x, y) should have
+        # shape 3x4, and not 4x4.
+        x = torch.randn(2, 4, dtype=torch.float, device='cuda')
+        y = torch.randn(1, 4, dtype=torch.float, device='cuda')
 
-        self.checkTrace(SNLIClassifier(Config()).to(device), (premise, hypothesis),
-                        inputs_require_grads=False, export_import=check_export_import)
+        self.assertEqual(f(x, y).shape, (3, 4))
+        self.assertAllFused(f.graph_for(x, y))
 
+    @unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
+    @unittest.skipIf(not RUN_CUDA, "No CUDA")
     @skipIfRocm
-    def test_snli(self):
-        self._test_snli(self, device='cpu')
+    def test_chunk_cuda(self):
+        def fn(x):
+            a, b, c = x.chunk(3, 1)
+            return a * b + c
 
-    @skipIfRocm
-    @unittest.skipIf(not RUN_CUDA, "no CUDA")
-    def test_snli_cuda(self):
-        # XXX: export_import on CUDA modules doesn't work (#11480)
-        self._test_snli(self, device='cuda', check_export_import=False)
+        inputs = [torch.randn(10, 6, dtype=torch.float, device='cuda')]
+
+        ge = self.checkScript(fn, inputs)
+        self.assertExpectedGraph(ge.graph_for(*inputs))
 
     @staticmethod
-    def _test_super_resolution(self, device, check_export_import=True):
-        import torch.nn.init as init
+    def _test_chunk_correctness(self, device='cpu'):
+        def chunk_4_0(x):
+            x0, x1, x2, x3 = x.chunk(4, 0)
+            return x0 + x1 + x2 + x3
 
-        class Net(nn.Module):
+        def chunk_4_1(x):
+            x0, x1, x2, x3 = x.chunk(4, 1)
+            return x0 + x1 + x2 + x3
 
-            def __init__(self, upscale_factor):
-                super(Net, self).__init__()
+        def chunk_4_last(x):
+            x0, x1, x2, x3 = x.chunk(4, 2)
+            return x0 + x1 + x2 + x3
 
-                self.relu = nn.ReLU()
-                self.conv1 = nn.Conv2d(1, 64, (5, 5), (1, 1), (2, 2))
-                self.conv2 = nn.Conv2d(64, 64, (3, 3), (1, 1), (1, 1))
-                self.conv3 = nn.Conv2d(64, 32, (3, 3), (1, 1), (1, 1))
-                self.conv4 = nn.Conv2d(32, upscale_factor ** 2, (3, 3), (1, 1), (1, 1))
-                self.pixel_shuffle = nn.PixelShuffle(upscale_factor)
+        fns = [chunk_4_0, chunk_4_1, chunk_4_last]
+        tensors = [
+            # splitSize = 1
+            torch.randn(4, 4, 4, dtype=torch.float, device=device),
 
-            def forward(self, x):
-                x = self.relu(self.conv1(x))
-                x = self.relu(self.conv2(x))
-                x = self.relu(self.conv3(x))
-                x = self.pixel_shuffle(self.conv4(x))
-                return x
+            # contiguous case
+            torch.randn(12, 8, 16, dtype=torch.float, device=device),
 
-        net = Net(upscale_factor=4).to(device)
-        self.checkTrace(net, (torch.rand(5, 1, 64, 64, device=device),),
-                        export_import=check_export_import)
+            # non-contiguous case
+            torch.randn(12, 8, 16, dtype=torch.float, device=device).transpose(1, 2),
+        ]
 
-    @skipIfRocm
-    def test_super_resolution(self):
-        self._test_super_resolution(self, device='cpu')
+        for tensor in tensors:
+            for fn in fns:
+                self.checkScript(fn, [tensor])
 
+    @unittest.skipIf(IS_WINDOWS or IS_SANDCASTLE, "NYI: fuser support for Windows or Sandcastle")
     @skipIfRocm
-    @unittest.skipIf(not RUN_CUDA, 'no CUDA')
-    def test_super_resolution_cuda(self):
-        # XXX: export_import on CUDA modules doesn't work (#11480)
-        self._test_super_resolution(self, device='cuda', check_export_import=False)
-
-    @suppress_warnings
-    def test_time_sequence_prediction(self):
-        class Sequence(torch.jit.ScriptModule):
-            def __init__(self):
-                super(Sequence, self).__init__()
-                self.lstm1 = nn.LSTMCell(1, 51)
-                self.lstm2 = nn.LSTMCell(51, 51)
-                self.linear = nn.Linear(51, 1)
-
-            # TODO: could not pass tuple to a python Op and type annotations
-            # is not descending to python signature, hence the wrapper
-            # see https://github.com/pytorch/pytorch/issues/8778
-            # and https://github.com/pytorch/pytorch/issues/8777
-            def test_lstm1(self, input, hx, cx):
-                # type: (Tensor, Tensor, Tensor) -> Tuple[Tensor, Tensor]
-                return self.lstm1(input, (hx, cx))
-
-            def test_lstm2(self, input, hx, cx):
-                # type: (Tensor, Tensor, Tensor) -> Tuple[Tensor, Tensor]
-                return self.lstm2(input, (hx, cx))
-
-            # TODO: could not support tensor constructors in script
-            # see https://github.com/pytorch/pytorch/issues/8814
-            def test_tensor(self):
-                return torch.tensor([], dtype=torch.double)
-
-            @torch.jit.script_method
-            def forward(self, input):
-                # TODO: add future as input with default val
-                # see https://github.com/pytorch/pytorch/issues/8724
-                outputs = self.test_tensor()
-                h_t = torch.zeros((3, 51), dtype=torch.double)
-                c_t = torch.zeros((3, 51), dtype=torch.double)
-                h_t2 = torch.zeros((3, 51), dtype=torch.double)
-                c_t2 = torch.zeros((3, 51), dtype=torch.double)
+    @enable_cpu_fuser
+    def test_chunk_correctness(self):
+        return self._test_chunk_correctness(self, 'cpu')
 
-                output = torch.zeros([3, 51])
-                future = 2
+    @unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
+    @unittest.skipIf(not RUN_CUDA, "No CUDA")
+    @skipIfRocm
+    def test_chunk_correctness_cuda(self):
+        return self._test_chunk_correctness(self, 'cuda')
 
-                # TODO: chunk call should appear as the for loop iterable
-                # We hard-code it to 4 for now.
-                a, b, c, d = input.chunk(input.size(1), dim=1)
-                for input_t in (a, b, c, d):
-                    h_t, c_t = self.test_lstm1(input_t, h_t, c_t)
-                    h_t2, c_t2 = self.test_lstm2(h_t, h_t2, c_t2)
-                    output = self.linear(h_t2)
-                    outputs = torch.cat((outputs, output), 1)
-                for _ in range(future):  # if we should predict the future
-                    h_t, c_t = self.test_lstm1(output, h_t, c_t)
-                    h_t2, c_t2 = self.test_lstm2(h_t, h_t2, c_t2)
-                    output = self.linear(h_t2)
-                    outputs = torch.cat((outputs, output), 1)
-                return outputs
+    @unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
+    @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
+    @skipIfRocm
+    def test_chunk_distributes_cuda(self):
+        def f(x, y):
+            z1, z2 = (x + y).chunk(2, dim=1)
+            return z1 * z2
 
-        # TODO: toggle export_import once above issues are fixed
-        self.checkTrace(Sequence(), (torch.rand(3, 4),),
-                        export_import=False)
+        x = torch.randn(4, 4, dtype=torch.float, device='cuda')
+        y = torch.randn(4, 4, dtype=torch.float, device='cuda')
 
-    @staticmethod
-    def _test_vae(self, device, check_export_import=True):
-        class VAE(nn.Module):
-            def __init__(self):
-                super(VAE, self).__init__()
+        ge = self.checkTrace(f, (x, y))
+        self.assertExpectedGraph(ge.graph_for(x, y))
 
-                self.fc1 = nn.Linear(784, 400)
-                self.fc21 = nn.Linear(400, 20)
-                self.fc22 = nn.Linear(400, 20)
-                self.fc3 = nn.Linear(20, 400)
-                self.fc4 = nn.Linear(400, 784)
+    @unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
+    @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
+    @skipIfRocm
+    def test_chunk_motion_deduplicates_inputs(self):
+        def func1(x):
+            z = x * x
+            z0, z1 = z.chunk(2)
+            return z0 * z1
 
-            def encode(self, x):
-                h1 = F.relu(self.fc1(x))
-                return self.fc21(h1), self.fc22(h1)
+        def func2(x):
+            z = x * x * x
+            z0, z1 = z.chunk(2)
+            return z0 * z1
 
-            def reparameterize(self, mu, logvar):
-                if self.training:
-                    std = torch.exp(0.5 * logvar)
-                    eps = torch.randn_like(std)
-                    return eps.mul(std).add_(mu)
-                else:
-                    return mu
+        inputs = [
+            torch.tensor([1.1, 1.2], device='cuda', dtype=torch.float),
+        ]
+        for func in [func1, func2]:
+            module = self.checkScript(func, inputs)
+            forward_graph = module.graph_for(*inputs)
+            self.assertGraphContainsExactly(forward_graph, 'prim::FusionGroup', 1)
+            fusion_group = list(forward_graph.nodes())[-1]
+            self.assertEqual(len(list(fusion_group.inputs())), 1)
 
-            def decode(self, z):
-                h3 = F.relu(self.fc3(z))
-                return torch.sigmoid(self.fc4(h3))
+    @unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
+    @unittest.skipIf(not RUN_CUDA, "No CUDA")
+    @skipIfRocm
+    def test_chunk_multiple_cuda(self):
+        # The arguments are intentionally used out of order as a test to see
+        # if the fusion compiler adds extra args in the correct order
+        def fn(s, x, y, z):
+            z1, z2 = z.chunk(2, 2)
+            x1, x2, x3 = x.chunk(3, 1)
+            y1, y2 = y.chunk(2, 0)
+            return s + x1 + x2 + x3 + y1 + y2 + z1 + z2
 
-            def forward(self, x):
-                mu, logvar = self.encode(x.view(-1, 784))
-                z = self.reparameterize(mu, logvar)
-                return self.decode(z), mu, logvar
+        inputs = [
+            torch.randn(5, 2, 3, dtype=torch.float, device='cuda'),
+            torch.randn(5, 6, 3, dtype=torch.float, device='cuda'),
+            torch.randn(10, 2, 3, dtype=torch.float, device='cuda'),
+            torch.randn(5, 2, 6, dtype=torch.float, device='cuda'),
+        ]
 
-        # eval() is present because randn_like makes this nondeterministic
-        self.checkTrace(VAE().to(device).eval(), (torch.rand(128, 1, 28, 28, device=device),),
-                        export_import=check_export_import)
+        ge = self.checkScript(fn, inputs)
+        self.assertExpectedGraph(ge.graph_for(*inputs))
 
-    def test_vae(self):
-        self._test_vae(self, device='cpu')
+    @unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
+    @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
+    @skipIfRocm
+    def test_clamp(self):
+        def func2(a, b):
+            return torch.clamp(a + b, min=0, max=2)
 
-    @unittest.skipIf(not RUN_CUDA, "no CUDA")
-    def test_vae_cuda(self):
-        # XXX: export_import on CUDA modules doesn't work (#11480)
-        self._test_vae(self, device='cuda', check_export_import=False)
+        def funcInf(a, b):
+            return torch.clamp(a + b, min=0, max=float('inf'))
 
+        a = torch.randn(4, 4, dtype=torch.float, device='cuda', requires_grad=True)
+        b = torch.randn(4, 4, dtype=torch.float, device='cuda')
 
-# Smoke tests for export methods
-class TestPytorchExportModes(JitTestCase):
-    class MyModel(nn.Module):
-        def __init__(self):
-            super(TestPytorchExportModes.MyModel, self).__init__()
+        funcs = (func2, funcInf)
+        for f in funcs:
+            s = self.checkScript(f, (a, b))
+            self.assertAllFused(s.graph_for(a, b), except_for={'aten::size'})
 
-        def forward(self, x):
-            return x.transpose(0, 1)
+            c = s(a, b)
+            c.sum().backward()
+            graph = backward_graph(s)
+            self.assertAllFused(graph, except_for={'prim::SumToSize'})
 
-    def test_protobuf(self):
-        torch_model = TestPytorchExportModes.MyModel()
-        fake_input = Variable(torch.randn(1, 1, 224, 224), requires_grad=True)
-        f = io.BytesIO()
-        torch.onnx._export(torch_model, (fake_input), f, verbose=False,
-                           export_type=torch.onnx.ExportTypes.PROTOBUF_FILE)
+    @unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
+    @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
+    @skipIfRocm
+    def test_comparison_eq_ne(self):
+        def f(x, y):
+            mask = (x == 0).type_as(x)
+            z = x * mask + y
+            mask = (x != 0).type_as(x)
+            z = z * mask + y
+            return z
 
-    def test_zipfile(self):
-        torch_model = TestPytorchExportModes.MyModel()
-        fake_input = Variable(torch.randn(1, 1, 224, 224), requires_grad=True)
-        f = io.BytesIO()
-        torch.onnx._export(torch_model, (fake_input), f, verbose=False,
-                           export_type=torch.onnx.ExportTypes.ZIP_ARCHIVE)
+        x = torch.randn(4, 4, dtype=torch.float, device='cuda')
+        y = torch.randn(4, 4, dtype=torch.float, device='cuda')
 
-    def test_compressed_zipfile(self):
-        torch_model = TestPytorchExportModes.MyModel()
-        fake_input = Variable(torch.randn(1, 1, 224, 224), requires_grad=True)
-        f = io.BytesIO()
-        torch.onnx._export(torch_model, (fake_input), f, verbose=False,
-                           export_type=torch.onnx.ExportTypes.COMPRESSED_ZIP_ARCHIVE)
+        ge = self.checkTrace(f, (x, y))
+        self.assertAllFused(ge.graph_for(x, y))
 
-    def test_directory(self):
-        torch_model = TestPytorchExportModes.MyModel()
-        fake_input = Variable(torch.randn(1, 1, 224, 224), requires_grad=True)
-        d = tempfile.mkdtemp()
-        torch.onnx._export(torch_model, (fake_input), d, verbose=False,
-                           export_type=torch.onnx.ExportTypes.DIRECTORY)
-        shutil.rmtree(d)
+    @staticmethod
+    def fn_test_comparison_gt_lt(x, y):
+        mask = (x > 0).type_as(x)
+        z = x * mask + y
+        mask = (x < 0).type_as(x)
+        z = z * mask + y
+        return z
 
+    @unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
+    @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
     @skipIfRocm
-    @skipIfNoLapack
-    def test_aten_fallback(self):
-        class ModelWithAtenNotONNXOp(nn.Module):
-            def forward(self, x, y):
-                abcd = x + y
-                defg = torch.qr(abcd)
-                return defg
+    def test_comparison_gt_lt_cuda(self):
+        x = torch.randn(4, 4, dtype=torch.float, device='cuda')
+        y = torch.randn(4, 4, dtype=torch.float, device='cuda')
 
-        x = torch.rand(3, 4)
-        y = torch.rand(3, 4)
-        f = io.BytesIO()
-        exported = torch.onnx.export_to_pretty_string(
-            ModelWithAtenNotONNXOp(), (x, y), f,
-            operator_export_type=OperatorExportTypes.ONNX_ATEN_FALLBACK)
-        self.assertExpected(exported)
+        ge = self.checkTrace(self.fn_test_comparison_gt_lt, (x, y))
+        self.assertAllFused(ge.graph_for(x, y))
 
-    # torch.fmod is using to test ONNX_ATEN.
-    # If you plan to remove fmod from aten, or found this test failed.
-    # please contact @Rui.
+    @unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
+    @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
     @skipIfRocm
-    def test_onnx_aten(self):
-        class ModelWithAtenFmod(nn.Module):
-            def forward(self, x, y):
-                return torch.fmod(x, y)
+    def test_comparison_ge_le_cuda(self):
+        def f(x, y):
+            mask = (x >= 0).type_as(x)
+            z = x * mask + y
+            mask = (x <= 0).type_as(x)
+            z = z * mask + y
+            return z
 
-        f = io.BytesIO()
-        x = torch.randn(3, 4, dtype=torch.float32)
-        y = torch.randn(3, 4, dtype=torch.float32)
-        exported = torch.onnx.export_to_pretty_string(
-            ModelWithAtenFmod(), (x, y), f,
-            operator_export_type=OperatorExportTypes.ONNX_ATEN)
-        self.assertExpected(exported)
+        x = torch.randn(4, 4, dtype=torch.float, device='cuda')
+        y = torch.randn(4, 4, dtype=torch.float, device='cuda')
 
+        ge = self.checkTrace(f, (x, y))
+        self.assertAllFused(ge.graph_for(x, y))
 
-# known to be failing in tracer
-EXCLUDE_TRACED = {
-    'test_split_dim',
-    'test_split_dim_neg0',
+    @unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
+    @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
+    @skipIfRocm
+    def test_concat_cuda(self):
+        hx = torch.randn(3, 20, dtype=torch.float, device='cuda')
+        cx = torch.randn(3, 20, dtype=torch.float, device='cuda')
 
-    # The following fail due to #12024.
-    # A prim::ListConstruct is involved and the indices get traced as DynamicType,
-    # which always require_grad. This causes a crash in autodiff.
-    'test___getitem___adv_index',
-    'test___getitem___adv_index_beg',
-    'test___getitem___adv_index_comb',
-    'test___getitem___adv_index_dup',
-    'test___getitem___adv_index_sub',
-    'test___getitem___adv_index_sub_2',
-    'test___getitem___adv_index_sub_3',
-    'test___getitem___adv_index_var',
-}
+        def foo(hx, cx):
+            return torch.cat((hx + cx, hx * cx))
 
-EXCLUDE_TYPE_CHECK = {
-    # slogdet tests use itemgetter to select its only differentiable output,
-    # but this happens outside of the graph we handle, so there are fewer
-    # reference outputs than graph outputs.
-    'test_slogdet_1x1_neg_det',
-    'test_slogdet_1x1_pos_det',
-    'test_slogdet_distinct_singular_values',
-    'test_slogdet_neg_det',
-    'test_slogdet_pos_det',
-    'test_slogdet_symmetric',
-    'test_slogdet_symmetric_pd',
-}
+        ge = self.checkTrace(foo, (hx, cx))
+        self.assertExpectedGraph(ge.graph_for(hx, cx))
 
-# known to be failing in script
-EXCLUDE_SCRIPT = {
-    'test_norm_fro',
-    'test_norm_fro_default',
-    'test_norm_nuc',
+    @unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
+    @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
+    @skipIfRocm
+    def test_concat_invariant_cuda(self):
+        # Invariant: the output of prim::FusedConcat may
+        # not be an input to any node inside the FusionGroup.
+        def fn(x, y, z):
+            x1 = x + y
+            y1 = x - y
+            w = torch.cat([x1, y1])
+            return w + z
 
-    # aten op has additional cudnn argument
-    'test_nn_unfold',
+        x = torch.randn(2, 2, dtype=torch.float, device='cuda')
+        y = torch.randn(2, 2, dtype=torch.float, device='cuda')
+        z = torch.randn(4, 2, dtype=torch.float, device='cuda')
+        ge = self.checkTrace(fn, (x, y, z))
+        self.assertExpectedGraph(ge.graph_for(x, y, z))
 
-    # flaky test - TODO fix
-    'test_nn_ctc_loss',
+    @staticmethod
+    def fn_test_exp(x, y):
+        return (x + .5 * y).exp()
 
-    # unknown builtin op
-    'test_nn_fold',
-}
+    @unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
+    @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
+    @skipIfRocm
+    def test_exp_cuda(self):
+        x = torch.randn(4, 4, dtype=torch.float, device='cuda')
+        y = torch.randn(4, 4, dtype=torch.float, device='cuda')
 
-EXCLUDE_PYTHON_PRINT = {
-    # no support for BroadcastingList in python printer
-    'test_nn_max_unpool1d',
-    'test_nn_max_unpool2d',
-    'test_nn_max_unpool3d',
-    'test_nn_max_pool1d',
-    'test_nn_max_pool2d',
-    'test_nn_max_pool3d',
-    'test_nn_max_pool1d_with_indices',
-}
+        ge = self.checkTrace(self.fn_test_exp, (x, y))
+        self.assertAllFused(ge.graph_for(x, y))
 
-EXCLUDE_SCRIPT_MODULES = {
-    'test_nn_AdaptiveAvgPool2d_tuple_none',
-    'test_nn_AdaptiveAvgPool3d_tuple_none',
-    'test_nn_AdaptiveMaxPool2d_tuple_none',
-    'test_nn_AdaptiveMaxPool3d_tuple_none',
-}
+    # TODO: This test doesn't offer anything valuable, maybe we should delete it
+    @unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
+    @unittest.skipIf(not RUN_CUDA_MULTI_GPU, "needs non-zero device")
+    @skipIfRocm
+    def test_last_device_cuda(self):
+        device = 'cuda:' + str(1)
+        x = torch.tensor([0.4], dtype=torch.float, device=device)
+        y = torch.tensor([0.7], dtype=torch.float, device=device)
 
-DISABLE_AUTODIFF_SUBGRAPH_INLINING = {
-    'test_nn_avg_pool2d',
-    'test_nn_log_softmax',
-    'test_nn_threshold',
-    'test_nn_nll_loss',
-}
+        def doit(x, y):
+            return torch.sigmoid(torch.tanh(x * (x + y) + x))
 
+        ge = self.checkTrace(doit, (x, y))
+        self.assertExpectedGraph(ge.graph_for(x, y))
 
-# make a new function where all non-tensor arguments in 'args' have been partially
-# applied, and all tensor arguments remain.
-# used to trace functions when some arguments are not tensors
-def partial_apply_nontensors(fn, args, **kwargs):
-    source = ['t' if isinstance(arg, torch.Tensor) else 's' for arg in args]
+    @unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
+    @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
+    @skipIfRocm
+    def test_lstm_cuda(self):
+        inputs = get_lstm_inputs('cuda', training=True)
+        module = self.checkScript(LSTMCellS, inputs)
+        forward_graph = module.graph_for(*inputs)
+        self.assertGraphContainsExactly(
+            forward_graph, 'prim::FusionGroup', 1, consider_subgraphs=True)
+        self.assertExpectedGraph(forward_graph, subname='forward')
 
-    def new_fn(*tensors_):
-        tensors = iter(tensors_)
-        return fn(*(args[i] if s == 's' else next(tensors) for i, s in enumerate(source)), **kwargs)
+        hy, cy = module(*inputs)
+        (hy + cy).sum().backward()
+        self.assertExpectedGraph(backward_graph(module), subname='backward')
 
-    return new_fn, [arg for arg in args if isinstance(arg, torch.Tensor)]
+    @unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
+    @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
+    @skipIfRocm
+    def test_lstm_concat_cuda(self):
+        inputs = get_lstm_inputs('cuda')
+        ge = self.checkTrace(LSTMCellC, inputs)
+        self.assertExpectedGraph(ge.graph_for(*inputs))
 
+    @unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
+    @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
+    @skipIfRocm
+    def test_lstm_gates_permutations_cuda(self):
+        # lstm has gates = x.mm(w_ih.t()) + hx.mm(w_hh.t()) + b_ih + b_hh.
+        # Test that any permutation of this will still result in one FusionGroup.
+        choices = ['x.mm(w_ih.t())', 'hx.mm(w_hh.t())', 'b_ih', 'b_hh']
+        template = dedent('''
+        def cell(x, hx, cx, w_ih, w_hh, b_ih, b_hh):
+            gates = {} + {} + {} + {}
+            ingate, forgetgate, cellgate, outgate = gates.chunk(4, 1)
+            return ingate * forgetgate * cellgate * outgate
+        ''')
+        for permutation in itertools.permutations(choices, len(choices)):
+            code = template.format(*permutation)
+            scope = {}
+            exec(code, globals(), scope)
+            cu = torch.jit.CompilationUnit(code)
 
-# create a trace function from input fn
-#
-# disable_autodiff_subgraph_inlining:
-#   Don't inline autodiff subgraphs so we can test autodiff
-def create_traced_fn(self, fn,
-                     disable_autodiff_subgraph_inlining=False):
-    def traced_fn(*inputs, **kwargs):
-        fn_tensors, inputs_tensors = partial_apply_nontensors(fn, inputs, **kwargs)
-        traced = torch.jit.trace(fn_tensors, inputs_tensors)
-        self.assertExportImport(traced.graph, inputs_tensors)
-        if disable_autodiff_subgraph_inlining:
-            traced.debug_disable_autodiff_subgraph_inlining()
-        output = traced(*inputs_tensors)
-        traced_fn.last_graph = traced.graph_for(*inputs_tensors)
-        return output
-    return traced_fn
+            inputs = get_lstm_inputs('cuda', training=False)
+            self.assertEqual(cu.cell(*inputs), scope['cell'](*inputs))
+            forward_graph = cu.cell.graph_for(*inputs)
+            self.assertGraphContainsExactly(forward_graph, 'prim::FusionGroup', 1)
 
-script_template = '''
-def the_method({}):
-    return {}
-'''
+    # TODO: Fuser doesn't work at all when inputs require grad. Fix that
+    @unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
+    @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
+    @skipIfRocm
+    def test_lstm_traced_cuda(self):
+        inputs = get_lstm_inputs('cuda')
+        ge = self.checkTrace(LSTMCellF, inputs)
+        self.assertExpectedGraph(ge.graph_for(*inputs))
+
+    @unittest.skipIf(IS_WINDOWS or IS_SANDCASTLE, "NYI: fuser support for Windows or Sandcastle")
+    @unittest.skip("Test is flaky, see https://github.com/pytorch/pytorch/issues/8746")
+    @enable_cpu_fuser
+    def test_lstm_traced_cpu(self):
+        inputs = get_lstm_inputs('cpu')
+        try:
+            ge = self.checkTrace(LSTMCellF, inputs)
+            self.assertExpectedGraph(ge.graph_for(*inputs))
+        except RuntimeError as e:
+            if 'Failed to compile' in e.args[0]:
+                warnings.warn('CPU fuser test has failed! This is not a hard failure, '
+                              'because the kernels sometimes trigger bugs in compilers '
+                              '(most notably GCC 7.2).')
+                raise unittest.SkipTest('Failed to compile')
+            else:
+                raise
 
-script_method_template = '''
-def forward({}):
-    return {}
-'''
+    @unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
+    @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
+    @skipIfRocm
+    def test_milstm_cuda(self):
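+        # Script the MiLSTM cell, require exactly one FusionGroup in the forward
+        # graph, and compare forward and backward graphs to their expect files.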
+        inputs = get_milstm_inputs('cuda', training=True)
+        module = self.checkScript(MiLSTMCell, inputs)
+        forward_graph = module.graph_for(*inputs)
+        self.assertGraphContainsExactly(
+            forward_graph, 'prim::FusionGroup', 1, consider_subgraphs=True)
+        self.assertExpectedGraph(forward_graph, subname='forward')
 
+        hy, cy = module(*inputs)
+        (hy + cy).sum().backward()
+        self.assertExpectedGraph(backward_graph(module), subname='backward')
 
-def get_constant(x):
-    if x == inf:
-        return 'float(\'inf\')' if PY2 else 'math.inf'
-    if x == -inf:
-        return 'float(\'-inf\')' if PY2 else '-math.inf'
-    return x
+    # TODO: fusion of torch.rand_like was supported at one point, but it no longer is.
+    @unittest.expectedFailure
+    @unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
+    @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
+    @skipIfRocm
+    def test_rand_cuda(self):
+        class M(torch.jit.ScriptModule):
+            __constants__ = ['d']
 
+            def __init__(self):
+                super(M, self).__init__()
+                self.d = torch.device('cuda')
 
-def get_script_args(args):
-    formals = []
-    tensors = []
-    actuals = []
-    for arg in args:
-        if isinstance(arg, torch.Tensor):
-            name = 'i{}'.format(len(formals))
-            formals.append(name)
-            actuals.append(name)
-            tensors.append(arg)
-        elif isinstance(arg, str):
-            actuals.append("'{}'".format(arg))
-        else:
-            actuals.append(str(get_constant(arg)))
-    return (formals, tensors, actuals)
+            @torch.jit.script_method
+            def create(self, x):
+                return x * x + x + torch.rand_like(x)
 
+        x = torch.zeros([3, 4, 5], dtype=torch.float, device='cuda')
+        m = M()
+        out1 = m.create(x)
+        out2 = m.create(x)
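+        # rand_like must yield fresh values on every call, each in [0, 1), and
+        # the whole expression is expected to compile into a single fused kernel.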
+        self.assertNotEqual(out1, out2)
+        self.assertTrue(torch.all(out1 >= 0))
+        self.assertTrue(torch.all(out1 < 1))
+        self.assertTrue(torch.all(out2 >= 0))
+        self.assertTrue(torch.all(out2 < 1))
+        self.assertAllFused(m.create.graph_for(x))
 
-# create a script function from (name, func_type, output_process_fn),
-# returns a function takes in (args, kwargs) and runs the compiled function and
-# then applies the post process fn to the outputs
-def create_script_fn(self, method_name, func_type, output_process_fn,
-                     disable_autodiff_subgraph_inlining=False):
-    def script_fn(*args, **kwargs):
-        formals, tensors, actuals = get_script_args(args)
-        kwargs_str = ''
-        for k, v in kwargs.items():
-            kwargs_str += ', ' + k + '=' + str(v)
-        if func_type == 'functional':
-            call = 'torch.{}({}{})'.format(method_name, ', '.join(actuals), kwargs_str)
-        elif func_type == 'method':
-            call = '{}.{}({}{})'.format(actuals[0], method_name, ', '.join(actuals[1:]), kwargs_str)
-        elif func_type == 'nn_functional':
-            call = 'torch.nn.functional.{}({}{})'.format(method_name, ', '.join(actuals), kwargs_str)
-        else:
-            raise 'Unsupported function type'
+    @staticmethod
+    def fn_test_relu(x, y):
+        return F.relu(x + .5 * y)
 
-        script = script_template.format(', '.join(formals), call)
+    @unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
+    @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
+    @skipIfRocm
+    def test_relu_cuda(self):
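+        # checkTrace traces fn_test_relu and verifies the traced output matches
+        # eager execution on these inputs.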
+        x = torch.randn(4, 4, dtype=torch.float, device='cuda')
+        y = torch.randn(4, 4, dtype=torch.float, device='cuda')
 
-        CU = torch.jit.CompilationUnit(script)
-        if disable_autodiff_subgraph_inlining:
-            CU.the_method.debug_disable_autodiff_subgraph_inlining()
-        self.assertExportImport(CU.the_method.graph, tensors)
-        output = output_process_fn(CU.the_method(*tensors))
-        script_fn.last_graph = CU.the_method.graph_for(*tensors)
-        return output
-    return script_fn
+        ge = self.checkTrace(self.fn_test_relu, (x, y))
 
+    @unittest.skipIf(IS_WINDOWS or IS_SANDCASTLE, "NYI: fuser support for Windows or Sandcastle")
+    @enable_cpu_fuser
+    def test_scalar(self):
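+        # 0-dim float tensors exercise the scalar path of the CPU fuser.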
+        def fn(x, y):
+            return 2 * x + y
 
-def check_alias_annotation(method_name, args, kwargs):
-    formals, tensors, actuals = get_script_args(args)
-    kwargs_str = ''
-    for k, v in kwargs.items():
-        kwargs_str += ', ' + k + '=' + str(v)
-    call = '{}.{}({}{})'.format(actuals[0], method_name, ', '.join(actuals[1:]), kwargs_str)
-    script = script_template.format(', '.join(formals), call)
-    CU = torch.jit.CompilationUnit(script)
-    torch._C._jit_check_alias_annotation(CU.the_method.graph, tuple(tensors), method_name)
+        x = torch.tensor(0.1, dtype=torch.float, device='cpu')
+        y = torch.tensor(1, dtype=torch.float, device='cpu')
+        ge = self.checkScript(fn, (x, y))
+        self.assertExpectedGraph(ge.graph_for(x, y))
 
+    @unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
+    @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
+    def test_small_constant_cuda(self):
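+        # Very small float literals must not prevent the expression from being
+        # fused completely.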
+        def fn_test_small_constant(x, y):
+            return (1e-8 * x + 5e-9 * y) * 1e8
+        x = torch.randn(4, 4, dtype=torch.float, device='cuda')
+        y = torch.randn(4, 4, dtype=torch.float, device='cuda')
 
-def check_output_types(self, func, ref_outputs, args, kwargs):
-    graph = getattr(func, 'last_graph', None)
-    if not isinstance(ref_outputs, tuple):
-        ref_outputs = (ref_outputs,)
-    types = [o.type() for o in graph.outputs()]
-    self.assertEqual(len(types), len(ref_outputs))
-    for i, (t, ref_out) in enumerate(zip(types, ref_outputs)):
-        if isinstance(ref_out, list):
-            assert len(ref_out) > 0
-            elem = ref_out[0]
-            assert isinstance(elem, torch.Tensor)
-            self.assertTrue(t.isSubtypeOf(torch._C.ListType.ofTensors()))
-        else:
-            ref_type = torch._C.Type.inferFrom(ref_out)
-            self.assertTrue(ref_type.isSubtypeOf(t))
+        ge = self.checkTrace(fn_test_small_constant, (x, y))
+        self.assertAllFused(ge.graph_for(x, y))
 
+    @unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
+    @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
+    @skipIfRocm
+    def test_tensor_scalar_ops_cuda(self):
+        def should_fuse(x):
+            z = 3.
+            y = x + z
+            return x * y
 
-def check_against_reference(self, func, reference_func, args, kwargs=None,
-                            allow_unused=True, check_types=True, no_grad=False):
-    kwargs = kwargs if kwargs else {}
+        # XXX: right now we only support fusing scalars if
+        # they're constant (#9940)
+        def should_not_fuse(x, z):
+            y = x + int(z)
+            return x * y
 
-    def allSum(vs):
-        if isinstance(vs, torch.Tensor):
-            vs = (vs,)
-        return sum([(i + 1) * v.sum()
-                    for i, v in enumerate(vs)
-                    if v is not None and v.dtype.is_floating_point])
+        inputs = [torch.randn(2, 2, dtype=torch.float, device='cuda')]
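+        # The literal 3. in should_fuse is a compile-time constant, so its body
+        # fuses; the runtime scalar passed to should_not_fuse does not (see XXX above).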
+        ge = self.checkScript(should_fuse, inputs)
+        self.assertExpectedGraph(ge.graph_for(*inputs), subname='1')
 
-    def clone_inputs(requires_grad):
         inputs = [
-            arg.detach().clone().requires_grad_(requires_grad and arg.requires_grad)
-            if isinstance(arg, torch.Tensor) else arg for arg in args
+            torch.randn(2, 2, dtype=torch.float, device='cuda'),
+            torch.tensor(3., dtype=torch.float, device='cuda'),
         ]
-        return inputs, [input for input in inputs if isinstance(input, torch.Tensor) and input.requires_grad]
-
-    nograd_inputs, nograd_tensors = clone_inputs(False)
-    recording_inputs, recording_tensors = clone_inputs(True)
+        ge = self.checkScript(should_not_fuse, inputs)
+        self.assertExpectedGraph(ge.graph_for(*inputs), subname='2')
 
-    # test no gradients case
-    outputs = self.runAndSaveRNG(reference_func, nograd_inputs, kwargs)
-    outputs_test = self.runAndSaveRNG(func, nograd_inputs, kwargs)
-    self.assertEqual(outputs, outputs_test)
+    @unittest.skipIf(IS_WINDOWS or IS_SANDCASTLE, "NYI: fuser support for Windows or Sandcastle")
+    @enable_cpu_fuser
+    def test_where_and_typing(self):
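+        # The comparison result has a different dtype than the double inputs;
+        # both outputs must match eager mode and the whole graph should fuse.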
+        def f(x, y):
+            mask = x > y
+            res = torch.where(mask, x, y)
+            return mask, res
 
-    if check_types:
-        check_output_types(self, func, outputs_test, nograd_inputs, kwargs)
+        script_f = torch.jit.script(f)
 
-    if no_grad:
-        # skip grad tests
-        return
+        x = torch.randn(4, 4, dtype=torch.double)
+        y = torch.randn(4, 4, dtype=torch.double)
 
-    # test single grad case
-    outputs = self.runAndSaveRNG(reference_func, recording_inputs, kwargs)
-    grads = torch.autograd.grad(allSum(outputs), recording_tensors,
-                                allow_unused=allow_unused)
+        result1, result2 = script_f(x, y)
+        expected1, expected2 = f(x, y)
+        self.assertEqual(result1, expected1)
+        self.assertEqual(result2, expected2)
+        self.assertAllFused(script_f.graph_for(x, y))
 
-    outputs_test = self.runAndSaveRNG(func, recording_inputs, kwargs)
-    grads_test = torch.autograd.grad(allSum(outputs_test), recording_tensors,
-                                     allow_unused=allow_unused)
-    self.assertEqual(outputs, outputs_test)
-    self.assertEqual(grads, grads_test)
+    # TODO: This test seems dead
+    @unittest.skipIf(not IS_WINDOWS, "This fuser test only runs on Windows")
+    @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
+    def test_windows(self):
+        def scaleshift(x, scale, shift):
+            return x * scale + shift
 
-    # test the grad grad case
-    if self._testMethodName in nn_functional_single_grad:
-        return
+        graph = torch.jit.script(scaleshift).graph
 
-    outputs = self.runAndSaveRNG(reference_func, recording_inputs, kwargs)
-    l1 = allSum(outputs)
-    grads = torch.autograd.grad(l1, recording_tensors, create_graph=True,
-                                allow_unused=allow_unused)
-    l2 = (allSum(grads) * l1)
-    grads2 = torch.autograd.grad(l2, recording_tensors, allow_unused=allow_unused)
+        inputs = [
+            torch.randn(4, 4, dtype=torch.float, device='cuda'),
+            torch.randn(4, dtype=torch.float, device='cuda'),
+            torch.randn(4, dtype=torch.float, device='cuda'),
+        ]
 
-    recording_inputs, recording_tensors = clone_inputs(True)
+        ge = self.checkTrace(scaleshift, inputs)
+        fuse_graph = ge.graph_for(*inputs)
 
-    outputs_test = self.runAndSaveRNG(func, recording_inputs, kwargs)
-    l1_test = allSum(outputs_test)
-    grads_test = torch.autograd.grad(
-        l1_test, recording_tensors, create_graph=True, allow_unused=allow_unused)
-    l2_test = (allSum(grads_test) * l1_test)
-    grads2_test = torch.autograd.grad(l2_test, recording_tensors, allow_unused=allow_unused)
+        def run_graph(graph, inputs):
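+            # Wrap a bare graph in a ScriptModule so it can be invoked directly.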
+            m = torch.jit.ScriptModule()
+            m._create_method_from_graph("forward", graph)
+            return m(*inputs)
 
-    self.assertEqual(outputs, outputs_test)
-    self.assertEqual(grads, grads_test)
-    for g2, g2_test in zip(grads2, grads2_test):
-        if g2 is None and g2_test is None:
-            continue
-        self.assertTrue(torch.allclose(g2, g2_test, atol=5e-4, rtol=1e-4))
+        self.assertEqual(run_graph(graph, inputs), run_graph(fuse_graph, inputs))
 
 
 # NB: torch.jit.script, when used as a function, uses the current scope