From 5667af3880af35663e5eee55b657e18e437f52e6 Mon Sep 17 00:00:00 2001 From: Richard Zou Date: Tue, 18 Dec 2018 16:13:39 -0800 Subject: [PATCH] Minor cleanup for TestFuser tests (#15134) Summary: Changelog: - change some expect tests that didn't have to be expect tests, instead use self.assertAllFused - Some of the fuser tests weren't using self.assertAllFused. - Minor test renames cc apaszke Pull Request resolved: https://github.com/pytorch/pytorch/pull/15134 Differential Revision: D13507481 Pulled By: zou3519 fbshipit-source-id: dd0788530a60bb5ed2f42b961fae3db2b4404b64 --- test/expect/TestFuser.test_last_device_cuda.expect | 16 ------------ .../TestFuser.test_tensor_scalar_ops_cuda-1.expect | 11 -------- .../TestFuser.test_tensor_scalar_ops_cuda-2.expect | 8 ------ test/test_jit.py | 30 ++++++++-------------- 4 files changed, 11 insertions(+), 54 deletions(-) delete mode 100644 test/expect/TestFuser.test_last_device_cuda.expect delete mode 100644 test/expect/TestFuser.test_tensor_scalar_ops_cuda-1.expect delete mode 100644 test/expect/TestFuser.test_tensor_scalar_ops_cuda-2.expect diff --git a/test/expect/TestFuser.test_last_device_cuda.expect b/test/expect/TestFuser.test_last_device_cuda.expect deleted file mode 100644 index b2ef06b..0000000 --- a/test/expect/TestFuser.test_last_device_cuda.expect +++ /dev/null @@ -1,16 +0,0 @@ -graph(%x : Float(*) - %y : Float(*)) { - %2 : Float(*) = prim::FusionGroup_0(%x, %y) - return (%2); -} -with prim::FusionGroup_0 = graph(%0 : Float(*) - %1 : Float(*)) { - %2 : int = prim::Constant[value=1]() - %3 : Float(*) = aten::add(%0, %1, %2) - %4 : Float(*) = aten::mul(%0, %3) - %5 : int = prim::Constant[value=1]() - %6 : Float(*) = aten::add(%4, %0, %5) - %7 : Float(*) = aten::tanh(%6) - %8 : Float(*) = aten::sigmoid(%7) - return (%8); -} diff --git a/test/expect/TestFuser.test_tensor_scalar_ops_cuda-1.expect b/test/expect/TestFuser.test_tensor_scalar_ops_cuda-1.expect deleted file mode 100644 index 60ccea5..0000000 --- a/test/expect/TestFuser.test_tensor_scalar_ops_cuda-1.expect +++ /dev/null @@ -1,11 +0,0 @@ -graph(%x : Float(*, *)) { - %1 : Float(*, *) = prim::FusionGroup_0(%x) - return (%1); -} -with prim::FusionGroup_0 = graph(%0 : Float(*, *)) { - %z : float = prim::Constant[value=3]() - %2 : int = prim::Constant[value=1]() - %y : Float(*, *) = aten::add(%0, %z, %2) - %4 : Float(*, *) = aten::mul(%0, %y) - return (%4); -} diff --git a/test/expect/TestFuser.test_tensor_scalar_ops_cuda-2.expect b/test/expect/TestFuser.test_tensor_scalar_ops_cuda-2.expect deleted file mode 100644 index 0748792..0000000 --- a/test/expect/TestFuser.test_tensor_scalar_ops_cuda-2.expect +++ /dev/null @@ -1,8 +0,0 @@ -graph(%x : Float(*, *) - %z : Float()) { - %2 : int = prim::Constant[value=1]() - %3 : int = prim::Int(%z) - %y : Float(*, *) = aten::add(%x, %3, %2) - %5 : Float(*, *) = aten::mul(%x, %y) - return (%5); -} diff --git a/test/test_jit.py b/test/test_jit.py index 8d88e8c..f8d946a 100644 --- a/test/test_jit.py +++ b/test/test_jit.py @@ -10070,11 +10070,10 @@ class TestFuser(JitTestCase): # XXX: This assumes that the same kernel isn't already used by another test self.assertEqual(new_cache_size - prev_cache_size, 1) - # TODO: This test doesn't offer anything valuable, maybe we should delete it @unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows") @unittest.skipIf(not RUN_CUDA_MULTI_GPU, "needs non-zero device") @skipIfRocm - def test_last_device_cuda(self): + def test_nonzero_device_cuda(self): device = 'cuda:' + str(1) x = torch.tensor([0.4], dtype=torch.float, device=device) y = torch.tensor([0.7], dtype=torch.float, device=device) @@ -10083,7 +10082,7 @@ class TestFuser(JitTestCase): return torch.sigmoid(torch.tanh(x * (x + y) + x)) ge = self.checkTrace(doit, (x, y)) - self.assertExpectedGraph(ge.graph_for(x, y)) + self.assertAllFused(ge.graph_for(x, y)) @unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows") @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA") @@ -10212,6 +10211,7 @@ class TestFuser(JitTestCase): y = torch.randn(4, 4, dtype=torch.float, device='cuda') ge = self.checkTrace(self.fn_test_relu, (x, y)) + self.assertAllFused(ge.graph_for(x, y)) @staticmethod def fn_test_erf(x): @@ -10266,14 +10266,15 @@ class TestFuser(JitTestCase): inputs = [torch.randn(2, 2, dtype=torch.float, device='cuda')] ge = self.checkScript(should_fuse, inputs) - self.assertExpectedGraph(ge.graph_for(*inputs), subname='1') + self.assertAllFused(ge.graph_for(*inputs)) inputs = [ torch.randn(2, 2, dtype=torch.float, device='cuda'), torch.tensor(3., dtype=torch.float, device='cuda'), ] ge = self.checkScript(should_not_fuse, inputs) - self.assertExpectedGraph(ge.graph_for(*inputs), subname='2') + self.assertGraphContainsExactly( + ge.graph_for(*inputs), 'prim::FusionGroup', 0, consider_subgraphs=True) @unittest.skipIf(IS_WINDOWS or IS_SANDCASTLE, "NYI: fuser support for Windows or Sandcastle") @enable_cpu_fuser @@ -10294,30 +10295,21 @@ class TestFuser(JitTestCase): self.assertEqual(result2, expected2) self.assertAllFused(script_f.graph_for(x, y), except_for={'prim::TupleConstruct'}) - # TODO: This test seems dead - @unittest.skipIf(not IS_WINDOWS, "Testing Fuse skipped on windows") + @unittest.skipIf(not IS_WINDOWS, "Test that the fuser is disabled on Windows") @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA") - def test_windows(self): + def test_windows_cuda(self): def scaleshift(x, scale, shift): return x * scale + shift - graph = torch.jit.script(scaleshift).graph - inputs = [ torch.randn(4, 4, dtype=torch.float, device='cuda'), torch.randn(4, dtype=torch.float, device='cuda'), torch.randn(4, dtype=torch.float, device='cuda'), ] - ge = self.checkTrace(scaleshift, inputs) - fuse_graph = ge.graph_for(*inputs) - - def run_graph(graph, inputs): - m = torch.jit.ScriptModule() - m._create_method_from_graph("forward", graph) - return m(*inputs) - - self.assertEqual(run_graph(graph, inputs), run_graph(fuse_graph, inputs)) + ge = self.checkScript(scaleshift, inputs) + self.assertGraphContainsExactly( + ge.graph_for(*inputs), 'prim::FusionGroup', 0, consider_subgraphs=True) # NB: torch.jit.script, when used as a function, uses the current scope -- 2.7.4