Minor cleanup for TestFuser tests (#15134)
authorRichard Zou <zou3519@gmail.com>
Wed, 19 Dec 2018 00:13:39 +0000 (16:13 -0800)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Wed, 19 Dec 2018 00:33:59 +0000 (16:33 -0800)
Summary:
Changelog:
- change some expect tests that didn't have to be expect tests,
  instead use self.assertAllFused
- Some of the fuser tests weren't using self.assertAllFused.
- Minor test renames

cc apaszke
Pull Request resolved: https://github.com/pytorch/pytorch/pull/15134

Differential Revision: D13507481

Pulled By: zou3519

fbshipit-source-id: dd0788530a60bb5ed2f42b961fae3db2b4404b64

test/expect/TestFuser.test_last_device_cuda.expect [deleted file]
test/expect/TestFuser.test_tensor_scalar_ops_cuda-1.expect [deleted file]
test/expect/TestFuser.test_tensor_scalar_ops_cuda-2.expect [deleted file]
test/test_jit.py

diff --git a/test/expect/TestFuser.test_last_device_cuda.expect b/test/expect/TestFuser.test_last_device_cuda.expect
deleted file mode 100644 (file)
index b2ef06b..0000000
+++ /dev/null
@@ -1,16 +0,0 @@
-graph(%x : Float(*)
-      %y : Float(*)) {
-  %2 : Float(*) = prim::FusionGroup_0(%x, %y)
-  return (%2);
-}
-with prim::FusionGroup_0 = graph(%0 : Float(*)
-      %1 : Float(*)) {
-  %2 : int = prim::Constant[value=1]()
-  %3 : Float(*) = aten::add(%0, %1, %2)
-  %4 : Float(*) = aten::mul(%0, %3)
-  %5 : int = prim::Constant[value=1]()
-  %6 : Float(*) = aten::add(%4, %0, %5)
-  %7 : Float(*) = aten::tanh(%6)
-  %8 : Float(*) = aten::sigmoid(%7)
-  return (%8);
-}
diff --git a/test/expect/TestFuser.test_tensor_scalar_ops_cuda-1.expect b/test/expect/TestFuser.test_tensor_scalar_ops_cuda-1.expect
deleted file mode 100644 (file)
index 60ccea5..0000000
+++ /dev/null
@@ -1,11 +0,0 @@
-graph(%x : Float(*, *)) {
-  %1 : Float(*, *) = prim::FusionGroup_0(%x)
-  return (%1);
-}
-with prim::FusionGroup_0 = graph(%0 : Float(*, *)) {
-  %z : float = prim::Constant[value=3]()
-  %2 : int = prim::Constant[value=1]()
-  %y : Float(*, *) = aten::add(%0, %z, %2)
-  %4 : Float(*, *) = aten::mul(%0, %y)
-  return (%4);
-}
diff --git a/test/expect/TestFuser.test_tensor_scalar_ops_cuda-2.expect b/test/expect/TestFuser.test_tensor_scalar_ops_cuda-2.expect
deleted file mode 100644 (file)
index 0748792..0000000
+++ /dev/null
@@ -1,8 +0,0 @@
-graph(%x : Float(*, *)
-      %z : Float()) {
-  %2 : int = prim::Constant[value=1]()
-  %3 : int = prim::Int(%z)
-  %y : Float(*, *) = aten::add(%x, %3, %2)
-  %5 : Float(*, *) = aten::mul(%x, %y)
-  return (%5);
-}
index 8d88e8c..f8d946a 100644 (file)
@@ -10070,11 +10070,10 @@ class TestFuser(JitTestCase):
         # XXX: This assumes that the same kernel isn't already used by another test
         self.assertEqual(new_cache_size - prev_cache_size, 1)
 
-    # TODO: This test doesn't offer anything valuable, maybe we should delete it
     @unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
     @unittest.skipIf(not RUN_CUDA_MULTI_GPU, "needs non-zero device")
     @skipIfRocm
-    def test_last_device_cuda(self):
+    def test_nonzero_device_cuda(self):
         device = 'cuda:' + str(1)
         x = torch.tensor([0.4], dtype=torch.float, device=device)
         y = torch.tensor([0.7], dtype=torch.float, device=device)
@@ -10083,7 +10082,7 @@ class TestFuser(JitTestCase):
             return torch.sigmoid(torch.tanh(x * (x + y) + x))
 
         ge = self.checkTrace(doit, (x, y))
-        self.assertExpectedGraph(ge.graph_for(x, y))
+        self.assertAllFused(ge.graph_for(x, y))
 
     @unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
     @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
@@ -10212,6 +10211,7 @@ class TestFuser(JitTestCase):
         y = torch.randn(4, 4, dtype=torch.float, device='cuda')
 
         ge = self.checkTrace(self.fn_test_relu, (x, y))
+        self.assertAllFused(ge.graph_for(x, y))
 
     @staticmethod
     def fn_test_erf(x):
@@ -10266,14 +10266,15 @@ class TestFuser(JitTestCase):
 
         inputs = [torch.randn(2, 2, dtype=torch.float, device='cuda')]
         ge = self.checkScript(should_fuse, inputs)
-        self.assertExpectedGraph(ge.graph_for(*inputs), subname='1')
+        self.assertAllFused(ge.graph_for(*inputs))
 
         inputs = [
             torch.randn(2, 2, dtype=torch.float, device='cuda'),
             torch.tensor(3., dtype=torch.float, device='cuda'),
         ]
         ge = self.checkScript(should_not_fuse, inputs)
-        self.assertExpectedGraph(ge.graph_for(*inputs), subname='2')
+        self.assertGraphContainsExactly(
+            ge.graph_for(*inputs), 'prim::FusionGroup', 0, consider_subgraphs=True)
 
     @unittest.skipIf(IS_WINDOWS or IS_SANDCASTLE, "NYI: fuser support for Windows or Sandcastle")
     @enable_cpu_fuser
@@ -10294,30 +10295,21 @@ class TestFuser(JitTestCase):
         self.assertEqual(result2, expected2)
         self.assertAllFused(script_f.graph_for(x, y), except_for={'prim::TupleConstruct'})
 
-    # TODO: This test seems dead
-    @unittest.skipIf(not IS_WINDOWS, "Testing Fuse skipped on windows")
+    @unittest.skipIf(not IS_WINDOWS, "Test that the fuser is disabled on Windows")
     @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
-    def test_windows(self):
+    def test_windows_cuda(self):
         def scaleshift(x, scale, shift):
             return x * scale + shift
 
-        graph = torch.jit.script(scaleshift).graph
-
         inputs = [
             torch.randn(4, 4, dtype=torch.float, device='cuda'),
             torch.randn(4, dtype=torch.float, device='cuda'),
             torch.randn(4, dtype=torch.float, device='cuda'),
         ]
 
-        ge = self.checkTrace(scaleshift, inputs)
-        fuse_graph = ge.graph_for(*inputs)
-
-        def run_graph(graph, inputs):
-            m = torch.jit.ScriptModule()
-            m._create_method_from_graph("forward", graph)
-            return m(*inputs)
-
-        self.assertEqual(run_graph(graph, inputs), run_graph(fuse_graph, inputs))
+        ge = self.checkScript(scaleshift, inputs)
+        self.assertGraphContainsExactly(
+            ge.graph_for(*inputs), 'prim::FusionGroup', 0, consider_subgraphs=True)
 
 
 # NB: torch.jit.script, when used as a function, uses the current scope