# XXX: This assumes that the same kernel isn't already used by another test
self.assertEqual(new_cache_size - prev_cache_size, 1)
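# Aside, not part of the diff: a minimal standalone sketch of the cache-size check
# above. It assumes a CUDA build with the legacy fuser enabled and that the private
# debug hook torch._C._jit_debug_fuser_num_cached_kernel_specs() is available; the
# function and variable names below are illustrative, not the test's own helpers.
import torch

def fusible(x, y):
    return torch.sigmoid(torch.tanh(x * y))

x = torch.randn(4, 4, device='cuda')
y = torch.randn(4, 4, device='cuda')

prev_cache_size = torch._C._jit_debug_fuser_num_cached_kernel_specs()
traced = torch.jit.trace(fusible, (x, y))
traced(x, y)  # run a couple of times so the executor optimizes the graph and
traced(x, y)  # registers the fused-kernel spec
new_cache_size = torch._C._jit_debug_fuser_num_cached_kernel_specs()
# exactly one new spec, provided no other test already compiled the same kernel
assert new_cache_size - prev_cache_size == 1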
- # TODO: This test doesn't offer anything valuable, maybe we should delete it
@unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
@unittest.skipIf(not RUN_CUDA_MULTI_GPU, "needs non-zero device")
@skipIfRocm
- def test_last_device_cuda(self):
+ def test_nonzero_device_cuda(self):
device = 'cuda:' + str(1)
x = torch.tensor([0.4], dtype=torch.float, device=device)
y = torch.tensor([0.7], dtype=torch.float, device=device)
def doit(x, y):
return torch.sigmoid(torch.tanh(x * (x + y) + x))
ge = self.checkTrace(doit, (x, y))
- self.assertExpectedGraph(ge.graph_for(x, y))
+ self.assertAllFused(ge.graph_for(x, y))
@unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
@unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
def test_relu(self):
x = torch.randn(4, 4, dtype=torch.float, device='cuda')
y = torch.randn(4, 4, dtype=torch.float, device='cuda')
ge = self.checkTrace(self.fn_test_relu, (x, y))
+ self.assertAllFused(ge.graph_for(x, y))
@staticmethod
def fn_test_erf(x):
inputs = [torch.randn(2, 2, dtype=torch.float, device='cuda')]
ge = self.checkScript(should_fuse, inputs)
- self.assertExpectedGraph(ge.graph_for(*inputs), subname='1')
+ self.assertAllFused(ge.graph_for(*inputs))
inputs = [
torch.randn(2, 2, dtype=torch.float, device='cuda'),
torch.tensor(3., dtype=torch.float, device='cuda'),
]
ge = self.checkScript(should_not_fuse, inputs)
- self.assertExpectedGraph(ge.graph_for(*inputs), subname='2')
+ self.assertGraphContainsExactly(
+ ge.graph_for(*inputs), 'prim::FusionGroup', 0, consider_subgraphs=True)
@unittest.skipIf(IS_WINDOWS or IS_SANDCASTLE, "NYI: fuser support for Windows or Sandcastle")
@enable_cpu_fuser
self.assertEqual(result2, expected2)
self.assertAllFused(script_f.graph_for(x, y), except_for={'prim::TupleConstruct'})
- # TODO: This test seems dead
- @unittest.skipIf(not IS_WINDOWS, "Testing Fuse skipped on windows")
+ @unittest.skipIf(not IS_WINDOWS, "Test that the fuser is disabled on Windows")
@unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
- def test_windows(self):
+ def test_windows_cuda(self):
def scaleshift(x, scale, shift):
return x * scale + shift
- graph = torch.jit.script(scaleshift).graph
-
inputs = [
torch.randn(4, 4, dtype=torch.float, device='cuda'),
torch.randn(4, dtype=torch.float, device='cuda'),
torch.randn(4, dtype=torch.float, device='cuda'),
]
- ge = self.checkTrace(scaleshift, inputs)
- fuse_graph = ge.graph_for(*inputs)
-
- def run_graph(graph, inputs):
- m = torch.jit.ScriptModule()
- m._create_method_from_graph("forward", graph)
- return m(*inputs)
-
- self.assertEqual(run_graph(graph, inputs), run_graph(fuse_graph, inputs))
+ ge = self.checkScript(scaleshift, inputs)
+ self.assertGraphContainsExactly(
+ ge.graph_for(*inputs), 'prim::FusionGroup', 0, consider_subgraphs=True)
# NB: torch.jit.script, when used as a function, uses the current scope
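# Aside, not part of the diff: the note above is about name resolution. When
# torch.jit.script is called as a function, free variables in the compiled
# function are looked up in the surrounding Python scope ("the current scope").
# A minimal sketch with illustrative names:
import torch

alpha = 2.0  # module-level name, visible in the current scope

def scale_relu(x):
    # `alpha` and `torch` are resolved from the surrounding scope at compile time
    return torch.relu(x) * alpha

scripted = torch.jit.script(scale_relu)
print(scripted(torch.tensor([-1.0, 3.0])))  # tensor([0., 6.])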