from common_methods_invocations import method_tests as autograd_method_tests
from common_methods_invocations import create_input, unpack_variables, \
exclude_tensor_method, non_differentiable, EXCLUDE_GRADCHECK, EXCLUDE_FUNCTIONAL
+from torch.testing import FileCheck
from copy import deepcopy
import random
from typing import List, Dict, Optional
trace, _ = torch.jit.get_trace_graph(fn, (x, y))
self.run_pass('cse', trace)
- self.assertExpectedGraph(trace)
+ do_exactly = True
+ FileCheck().check_count("add", 1).check_count("mul", 2, do_exactly) \
+ .check_count("tanh", 1, do_exactly).check_count("add", 2, do_exactly).check_next("return") \
+ .run(str(trace))
+
self.assertExportImport(trace, (x, y))
def test_recursive_cse(self):
graph = torch.jit.script(fn).graph
self.run_pass('cse', graph)
- self.assertExpectedGraph(graph)
-
- def test_scalar(self):
- # NB: must not require grad; if it requires grad, it's always a Tensor
- x = torch.tensor(2.)
- y = torch.tensor(3.)
-
- def fn(x, y):
- return x - y
- trace, _ = torch.jit.get_trace_graph(fn, (x, y))
+ FileCheck().check("block").check_not("aten::add").check_not("aten::gt").run(str(graph))
def test_shape_analysis_broadcast(self):
def broadcast(a, b):
return y
trace, _ = torch.jit.get_trace_graph(fn, (x,))
- self.assertExpectedGraph(trace)
+ FileCheck().check_count("aten::clone", 1, exactly=True) \
+ .check_count("aten::add_", 2, exactly=True) \
+ .check_next("return").run(str(trace))
self.assertExportImport(trace, (x,))
def test_inplace_flags(self):
x, y = torch.randn(2, 2), (torch.ones(2, 2), torch.randn(2, 2))
traced_fn = torch.jit.trace(fn, (x, y))
self.assertEqual(traced_fn(x, y), fn(x, y))
- self.assertExpectedGraph(traced_fn.graph)
+ # should be a tuple nested within another tuple
+ FileCheck().check_count("prim::TupleConstruct", 2, exactly=True).check_next("return") \
+ .run(str(traced_fn.graph))
self.assertExportImport(traced_fn.graph, (x, y))
def test_trace_random(self):
warnings.warn("x is less than 2")
return x
- self.assertExpectedGraph(fn.graph)
+ FileCheck().check("aten::warn").run(str(fn.graph))
def test_no_erroneous_warnings(self):
import warnings
inputs = self._make_scalar_vars([-1234, 4321], torch.int64)
self.checkScript(func, inputs, optimize=True)
- def test_math_schema(self):
- # This should use the add(Tensor, Tensor) schema.
- # Also tests to see if alpha={1} is lifted correctly.
- def fn(x, y):
- return x + y
-
- graph = torch.jit.script(fn).graph
- self.assertExpectedGraph(graph)
-
- def test_math_tensor_number(self):
- # Test that 7 is casted to tensor, then casted to the
- # correct type, and finally added to x.
- def fn(x):
- return x + 7
-
- graph = torch.jit.script(fn).graph
- self.assertExpectedGraph(graph)
-
- def test_math_numbers(self):
- # Test that the numbers are casted to tensor,
- # added, and then casted back.
- def fn1(x):
- return 7 + 8
-
- def fn2(x):
- return 1.1 + 3.1
-
- graph1 = torch.jit.script(fn1).graph
- self.assertExpectedGraph(graph1, subname="int")
- graph2 = torch.jit.script(fn2).graph
- self.assertExpectedGraph(graph2, subname="float")
-
def test_math_ops(self):
def test_floor():
c1 = 0
return c1
+ self.assertEqual(0, testNoThrows(torch.randn(0)))
+ ifs = testNoThrows.graph.findAllNodes("prim::If", recurse=False)
+
+ # three ifs at the top level, and the second one has a nested if for
+ # the or (True or bool(t[1])) expression
+ self.assertTrue(len(ifs) == 3)
+ self.assertTrue(ifs[0].findNode("prim::If") is None)
+ self.assertTrue(ifs[1].findNode("prim::If").findNode("prim::If") is None)
+ self.assertTrue(ifs[2].findNode("prim::If") is None)
+
@torch.jit.script
def throwsOr(t):
c0 = False or bool(t[1])
print(c0)
t = torch.randn(0)
- self.assertEqual(0, testNoThrows(torch.randn(0)))
- self.assertExpectedGraph(testNoThrows.graph)
with self.assertRaisesRegex(RuntimeError, "index 1 out of range for tensor of size"):
throwsOr(t)
with self.assertRaisesRegex(RuntimeError, "index 1 out of range for tensor of size"):
self.assertEqual(torch.zeros(2, 2), m2.forward(torch.randn(3, 2)))
def test_filecheck(self):
- from torch.testing import FileCheck
-
- # def test_accidental_not_used():
- # def unused():
- # a = FileCheck()
- #
- # with self.capture_stdout() as captured:
- # a = FileCheck()
- # del a
- # self.assertTrue("You have not run this instance of FileCheck"
- # in captured[0])
- #
- # test_accidental_not_used()
def test_check():
file = "232"
FileCheck().check("2").check("3").check("2").run(file)
FileCheck().check_count("22", 2).run(file)
FileCheck().check_count("222", 1).run(file)
+ with self.assertRaisesRegex(RuntimeError, 'Expected to not find'):
+ FileCheck().check_count("2", 4, exactly=True).run(file)
+
with self.assertRaisesRegex(RuntimeError, 'Expected to find "22"'):
FileCheck().check_count("22", 3).run(file)
def test_check_same():
file = "22\n33"
- # FileCheck().check_same("22").run(file)
+ FileCheck().check_same("22").run(file)
with self.assertRaisesRegex(RuntimeError, "Expected to not find"):
FileCheck().check_same("33").run(file)
return c
graph = torch.jit.script(func).graph
+ FileCheck().check("int = prim::Constant").check("aten::add_").run(str(graph))
self.run_pass('remove_inplace_ops', graph)
self.run_pass('erase_number_types', graph)
- self.assertExpectedGraph(graph)
+ self.run_pass('dce', graph)
+ FileCheck().check_not("int = prim::Constant").check_not("aten::add_").run(str(graph))
def test_mm_batching(self):
lstm_cell = torch.jit.script(LSTMCellS)