Batch of Expect Files removal (#17414)
authoreellison <elias_ellison@brown.edu>
Sat, 23 Feb 2019 01:54:09 +0000 (17:54 -0800)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Sat, 23 Feb 2019 02:11:51 +0000 (18:11 -0800)
Summary:
Batch of removing expect files, and some tests that no longer test anything.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/17414

Differential Revision: D14196342

Pulled By: eellison

fbshipit-source-id: 75c45649d1dd1ce39958fb02f5b7a2622c1d1d01

13 files changed:
test/expect/TestJit.test_cse.expect [deleted file]
test/expect/TestJit.test_inplace_transplant.expect [deleted file]
test/expect/TestJit.test_trace_tuple.expect [deleted file]
test/expect/TestJit.test_warnings.expect [deleted file]
test/expect/TestScript.test_erase_number_types.expect [deleted file]
test/expect/TestScript.test_logical_short_circuit.expect [deleted file]
test/expect/TestScript.test_math_numbers-float.expect [deleted file]
test/expect/TestScript.test_math_numbers-int.expect [deleted file]
test/expect/TestScript.test_math_schema.expect [deleted file]
test/test_jit.py
torch/csrc/jit/script/init.cpp
torch/csrc/jit/testing/file_check.cpp
torch/csrc/jit/testing/file_check.h

diff --git a/test/expect/TestJit.test_cse.expect b/test/expect/TestJit.test_cse.expect
deleted file mode 100644 (file)
index b4b94c6..0000000
+++ /dev/null
@@ -1,10 +0,0 @@
-graph(%0 : Double(2),
-      %1 : Double(2)):
-  %2 : int = prim::Constant[value=1]()
-  %3 : Double(2) = aten::add(%0, %1, %2)
-  %4 : Double(2) = aten::mul(%3, %3)
-  %5 : Double(2) = aten::mul(%4, %3)
-  %6 : Double(2) = aten::tanh(%5)
-  %7 : Double(2) = aten::add(%6, %6, %2)
-  %8 : Double(2) = aten::add(%5, %7, %2)
-  return (%8)
diff --git a/test/expect/TestJit.test_inplace_transplant.expect b/test/expect/TestJit.test_inplace_transplant.expect
deleted file mode 100644 (file)
index dff048b..0000000
+++ /dev/null
@@ -1,9 +0,0 @@
-graph(%0 : Double(1)):
-  %1 : Double(1) = aten::clone(%0)
-  %2 : Long() = prim::Constant[value={2}]()
-  %3 : int = prim::Constant[value=1]()
-  %4 : Double(1) = aten::add_(%1, %2, %3)
-  %5 : Long() = prim::Constant[value={3}]()
-  %6 : int = prim::Constant[value=1]()
-  %7 : Double(1) = aten::add_(%4, %5, %6)
-  return (%7)
diff --git a/test/expect/TestJit.test_trace_tuple.expect b/test/expect/TestJit.test_trace_tuple.expect
deleted file mode 100644 (file)
index d63b7fb..0000000
+++ /dev/null
@@ -1,8 +0,0 @@
-graph(%x : Double(2, 2),
-      %1 : (Double(2, 2), Double(2, 2))):
-  %2 : Double(2, 2), %3 : Double(2, 2) = prim::TupleUnpack(%1)
-  %4 : Double(2, 2) = aten::mul(%x, %3)
-  %5 : Double(2, 2) = aten::mul(%x, %2)
-  %6 : (Double(2, 2), Double(2, 2)) = prim::TupleConstruct(%4, %5)
-  %7 : (Double(2, 2), (Double(2, 2), Double(2, 2))) = prim::TupleConstruct(%x, %6)
-  return (%7)
diff --git a/test/expect/TestJit.test_warnings.expect b/test/expect/TestJit.test_warnings.expect
deleted file mode 100644 (file)
index 4401dfe..0000000
+++ /dev/null
@@ -1,12 +0,0 @@
-graph(%x : Tensor):
-  %1 : string = prim::Constant[value="x is less than 2"]()
-  %2 : int = prim::Constant[value=2]()
-  %3 : Tensor = aten::lt(%x, %2)
-  %4 : bool = prim::Bool(%3)
-   = prim::If(%4)
-    block0():
-       = aten::warn(%1, %2)
-      -> ()
-    block1():
-      -> ()
-  return (%x)
diff --git a/test/expect/TestScript.test_erase_number_types.expect b/test/expect/TestScript.test_erase_number_types.expect
deleted file mode 100644 (file)
index b04fae4..0000000
+++ /dev/null
@@ -1,9 +0,0 @@
-graph(%a : Tensor):
-  %1 : Long() = prim::Constant[value={7}]()
-  %2 : Long() = prim::Constant[value={1}]()
-  %3 : Long() = prim::Constant[value={3}]()
-  %4 : Long() = aten::add(%1, %2)
-  %b : Long() = aten::add(%4, %3)
-  %c.1 : Tensor = aten::add(%a, %b, %2)
-  %c : Tensor = aten::add(%c.1, %b, %2)
-  return (%c)
diff --git a/test/expect/TestScript.test_logical_short_circuit.expect b/test/expect/TestScript.test_logical_short_circuit.expect
deleted file mode 100644 (file)
index 8a3081d..0000000
+++ /dev/null
@@ -1,30 +0,0 @@
-graph(%t : Tensor):
-  %1 : bool = prim::Constant[value=1]()
-  %2 : bool = prim::Constant[value=0]()
-  %c1.1 : int = prim::Constant[value=1]()
-  %4 : int = prim::Constant[value=0]()
-  %5 : bool = prim::If(%2)
-    block0():
-      %6 : Tensor = aten::select(%t, %4, %c1.1)
-      %7 : bool = prim::Bool(%6)
-      -> (%7)
-    block1():
-      -> (%2)
-  %8 : bool = prim::If(%5)
-    block0():
-      -> (%5)
-    block1():
-      %9 : bool = prim::If(%1)
-        block0():
-          -> (%1)
-        block1():
-          %10 : Tensor = aten::select(%t, %4, %c1.1)
-          %11 : bool = prim::Bool(%10)
-          -> (%11)
-      -> (%9)
-  %c1 : int = prim::If(%8)
-    block0():
-      -> (%4)
-    block1():
-      -> (%c1.1)
-  return (%c1)
diff --git a/test/expect/TestScript.test_math_numbers-float.expect b/test/expect/TestScript.test_math_numbers-float.expect
deleted file mode 100644 (file)
index b3b13a7..0000000
+++ /dev/null
@@ -1,5 +0,0 @@
-graph(%x : Tensor):
-  %1 : float = prim::Constant[value=1.1]()
-  %2 : float = prim::Constant[value=3.1]()
-  %3 : float = aten::add(%1, %2)
-  return (%3)
diff --git a/test/expect/TestScript.test_math_numbers-int.expect b/test/expect/TestScript.test_math_numbers-int.expect
deleted file mode 100644 (file)
index 0c84227..0000000
+++ /dev/null
@@ -1,5 +0,0 @@
-graph(%x : Tensor):
-  %1 : int = prim::Constant[value=7]()
-  %2 : int = prim::Constant[value=8]()
-  %3 : int = aten::add(%1, %2)
-  return (%3)
diff --git a/test/expect/TestScript.test_math_schema.expect b/test/expect/TestScript.test_math_schema.expect
deleted file mode 100644 (file)
index 2d56aeb..0000000
+++ /dev/null
@@ -1,5 +0,0 @@
-graph(%x : Tensor,
-      %y : Tensor):
-  %2 : int = prim::Constant[value=1]()
-  %3 : Tensor = aten::add(%x, %y, %2)
-  return (%3)
index 9a425af..8ad6e2c 100644 (file)
@@ -39,6 +39,7 @@ import copy
 from common_methods_invocations import method_tests as autograd_method_tests
 from common_methods_invocations import create_input, unpack_variables, \
     exclude_tensor_method, non_differentiable, EXCLUDE_GRADCHECK, EXCLUDE_FUNCTIONAL
+from torch.testing import FileCheck
 from copy import deepcopy
 import random
 from typing import List, Dict, Optional
@@ -949,7 +950,11 @@ class TestJit(JitTestCase):
 
         trace, _ = torch.jit.get_trace_graph(fn, (x, y))
         self.run_pass('cse', trace)
-        self.assertExpectedGraph(trace)
+        do_exactly = True
+        FileCheck().check_count("add", 1).check_count("mul", 2, do_exactly) \
+            .check_count("tanh", 1, do_exactly).check_count("add", 2, do_exactly).check_next("return")  \
+            .run(str(trace))
+
         self.assertExportImport(trace, (x, y))
 
     def test_recursive_cse(self):
@@ -964,16 +969,7 @@ class TestJit(JitTestCase):
 
         graph = torch.jit.script(fn).graph
         self.run_pass('cse', graph)
-        self.assertExpectedGraph(graph)
-
-    def test_scalar(self):
-        # NB: must not require grad; if it requires grad, it's always a Tensor
-        x = torch.tensor(2.)
-        y = torch.tensor(3.)
-
-        def fn(x, y):
-            return x - y
-        trace, _ = torch.jit.get_trace_graph(fn, (x, y))
+        FileCheck().check("block").check_not("aten::add").check_not("aten::gt").run(str(graph))
 
     def test_shape_analysis_broadcast(self):
         def broadcast(a, b):
@@ -1031,7 +1027,9 @@ class TestJit(JitTestCase):
             return y
 
         trace, _ = torch.jit.get_trace_graph(fn, (x,))
-        self.assertExpectedGraph(trace)
+        FileCheck().check_count("aten::clone", 1, exactly=True) \
+            .check_count("aten::add_", 2, exactly=True) \
+            .check_next("return").run(str(trace))
         self.assertExportImport(trace, (x,))
 
     def test_inplace_flags(self):
@@ -1178,7 +1176,9 @@ class TestJit(JitTestCase):
         x, y = torch.randn(2, 2), (torch.ones(2, 2), torch.randn(2, 2))
         traced_fn = torch.jit.trace(fn, (x, y))
         self.assertEqual(traced_fn(x, y), fn(x, y))
-        self.assertExpectedGraph(traced_fn.graph)
+        # should be a tuple nested within another tuple
+        FileCheck().check_count("prim::TupleConstruct", 2, exactly=True).check_next("return") \
+            .run(str(traced_fn.graph))
         self.assertExportImport(traced_fn.graph, (x, y))
 
     def test_trace_random(self):
@@ -2242,7 +2242,7 @@ class TestJit(JitTestCase):
                 warnings.warn("x is less than 2")
             return x
 
-        self.assertExpectedGraph(fn.graph)
+        FileCheck().check("aten::warn").run(str(fn.graph))
 
     def test_no_erroneous_warnings(self):
         import warnings
@@ -4673,38 +4673,6 @@ a")
         inputs = self._make_scalar_vars([-1234, 4321], torch.int64)
         self.checkScript(func, inputs, optimize=True)
 
-    def test_math_schema(self):
-        # This should use the add(Tensor, Tensor) schema.
-        # Also tests to see if alpha={1} is lifted correctly.
-        def fn(x, y):
-            return x + y
-
-        graph = torch.jit.script(fn).graph
-        self.assertExpectedGraph(graph)
-
-    def test_math_tensor_number(self):
-        # Test that 7 is casted to tensor, then casted to the
-        # correct type, and finally added to x.
-        def fn(x):
-            return x + 7
-
-        graph = torch.jit.script(fn).graph
-        self.assertExpectedGraph(graph)
-
-    def test_math_numbers(self):
-        # Test that the numbers are casted to tensor,
-        # added, and then casted back.
-        def fn1(x):
-            return 7 + 8
-
-        def fn2(x):
-            return 1.1 + 3.1
-
-        graph1 = torch.jit.script(fn1).graph
-        self.assertExpectedGraph(graph1, subname="int")
-        graph2 = torch.jit.script(fn2).graph
-        self.assertExpectedGraph(graph2, subname="float")
-
     def test_math_ops(self):
 
         def test_floor():
@@ -4883,6 +4851,16 @@ a")
                 c1 = 0
             return c1
 
+        self.assertEqual(0, testNoThrows(torch.randn(0)))
+        ifs = testNoThrows.graph.findAllNodes("prim::If", recurse=False)
+
+        # three ifs at the top level, and the second one has a nested if for
+        # the or (True or bool(t[1])) expression
+        self.assertTrue(len(ifs) == 3)
+        self.assertTrue(ifs[0].findNode("prim::If") is None)
+        self.assertTrue(ifs[1].findNode("prim::If").findNode("prim::If") is None)
+        self.assertTrue(ifs[2].findNode("prim::If") is None)
+
         @torch.jit.script
         def throwsOr(t):
             c0 = False or bool(t[1])
@@ -4894,8 +4872,6 @@ a")
             print(c0)
 
         t = torch.randn(0)
-        self.assertEqual(0, testNoThrows(torch.randn(0)))
-        self.assertExpectedGraph(testNoThrows.graph)
         with self.assertRaisesRegex(RuntimeError, "index 1 out of range for tensor of size"):
             throwsOr(t)
         with self.assertRaisesRegex(RuntimeError, "index 1 out of range for tensor of size"):
@@ -5663,19 +5639,6 @@ a")
         self.assertEqual(torch.zeros(2, 2), m2.forward(torch.randn(3, 2)))
 
     def test_filecheck(self):
-        from torch.testing import FileCheck
-
-        # def test_accidental_not_used():
-        #     def unused():
-        #         a = FileCheck()
-        #
-        #     with self.capture_stdout() as captured:
-        #         a = FileCheck()
-        #         del a
-        #     self.assertTrue("You have not run this instance of FileCheck"
-        #                     in captured[0])
-        #
-        # test_accidental_not_used()
         def test_check():
             file = "232"
             FileCheck().check("2").check("3").check("2").run(file)
@@ -5694,6 +5657,9 @@ a")
             FileCheck().check_count("22", 2).run(file)
             FileCheck().check_count("222", 1).run(file)
 
+            with self.assertRaisesRegex(RuntimeError, 'Expected to not find'):
+                FileCheck().check_count("2", 4, exactly=True).run(file)
+
             with self.assertRaisesRegex(RuntimeError, 'Expected to find "22"'):
                 FileCheck().check_count("22", 3).run(file)
 
@@ -5704,7 +5670,7 @@ a")
 
         def test_check_same():
             file = "22\n33"
-            FileCheck().check_same("22").run(file)
+            FileCheck().check_same("22").run(file)
 
             with self.assertRaisesRegex(RuntimeError, "Expected to not find"):
                 FileCheck().check_same("33").run(file)
@@ -7878,9 +7844,11 @@ a")
             return c
 
         graph = torch.jit.script(func).graph
+        FileCheck().check("int = prim::Constant").check("aten::add_").run(str(graph))
         self.run_pass('remove_inplace_ops', graph)
         self.run_pass('erase_number_types', graph)
-        self.assertExpectedGraph(graph)
+        self.run_pass('dce', graph)
+        FileCheck().check_not("int = prim::Constant").check_not("aten::add_").run(str(graph))
 
     def test_mm_batching(self):
         lstm_cell = torch.jit.script(LSTMCellS)
index 372b49d..978810d 100644 (file)
@@ -488,7 +488,8 @@ std::shared_ptr<SugaredValue> toSugaredValue(
     } else if (py::isinstance<py::float_>(obj)) {
       return toSimple(g.insertConstant(py::cast<double>(obj), nullptr, loc));
     } else if (py::isinstance<py::str>(obj)) {
-      return toSimple(g.insertConstant(py::cast<std::string>(obj), nullptr, loc));
+      return toSimple(
+          g.insertConstant(py::cast<std::string>(obj), nullptr, loc));
     } else if (obj.is(py::none())) {
       return toSimple(g.insertConstant(IValue(), nullptr, loc));
     } else if (THPDevice_Check(obj.ptr())) {
@@ -605,7 +606,6 @@ Resolver pythonResolver(const ResolutionCallback& rcb) {
     return toSugaredValue(obj, m, loc);
   };
 }
-
 } // namespace
 
 FunctionSchema getSchemaWithNameAndDefaults(
@@ -1005,9 +1005,18 @@ void initJitScriptBindings(PyObject* module) {
       .def("check_count", &testing::FileCheck::check_count)
       .def("check_dag", &testing::FileCheck::check_dag)
       .def("check_count", &testing::FileCheck::check_count)
+      .def(
+          "check_count",
+          [](testing::FileCheck& f,
+             const std::string& str,
+             size_t count,
+             bool exactly) { return f.check_count(str, count, exactly); },
+          "Check Count",
+          py::arg("str"),
+          py::arg("count"),
+          py::arg("exactly") = false)
       .def("run", &testing::FileCheck::run);
 }
-
 } // namespace script
 } // namespace jit
 } // namespace torch
index 71ac016..8025e5f 100644 (file)
@@ -85,8 +85,8 @@ size_t assertFind(
     const Check& check) {
   auto pos = search_range.file_ptr()->find(sub, search_range.start());
   if (pos == std::string::npos || (pos + sub.size()) > search_range.end()) {
-    auto found_range = SourceRange(search_range.file_ptr(),search_range.start(),
-        sub.size());
+    auto found_range =
+        SourceRange(search_range.file_ptr(), search_range.start(), sub.size());
     std::stringstream ss;
     ss << "Expected to find ";
     printQuotedString(ss, sub);
@@ -112,7 +112,8 @@ void assertNotFind(
     const Check& check) {
   auto pos = search_range.file_ptr()->find(sub, search_range.start());
   if (pos != std::string::npos && (pos + sub.size()) <= search_range.end()) {
-    auto found_range = SourceRange(search_range.file_ptr(), pos, sub.size() + pos);
+    auto found_range =
+        SourceRange(search_range.file_ptr(), pos, sub.size() + pos);
     std::stringstream ss;
     ss << "Expected to not find ";
     printQuotedString(ss, sub);
@@ -311,8 +312,14 @@ FileCheck* FileCheck::check_next(const std::string& str) {
   return this;
 }
 
-FileCheck* FileCheck::check_count(const std::string& str, size_t count) {
+FileCheck* FileCheck::check_count(
+    const std::string& str,
+    size_t count,
+    bool exactly) {
   fcImpl->addCheck(CHECK_COUNT, str, count);
+  if (exactly) {
+    fcImpl->addCheck(CHECK_NOT, str);
+  }
   return this;
 }
 
@@ -320,7 +327,6 @@ FileCheck* FileCheck::check_dag(const std::string& str) {
   fcImpl->addCheck(CHECK_DAG, str);
   return this;
 }
-
 } // namespace testing
 } // namespace jit
 } // namespace torch
index fe4d46a..3043125 100644 (file)
@@ -32,8 +32,12 @@ struct FileCheck {
   // previous match
   TORCH_API FileCheck* check_next(const std::string& str);
 
-  // Checks that the string occurs count number of times
-  TORCH_API FileCheck* check_count(const std::string& str, size_t count);
+  // Checks that the string occurs count number of times. If exactly is true,
+  // checks that there are exactly count many matches
+  TORCH_API FileCheck* check_count(
+      const std::string& str,
+      size_t count,
+      bool exactly = false);
 
   // A series of consecutive check_dags get turned into a group of checks
   // which can appear in any order relative to each other.
@@ -46,7 +50,6 @@ struct FileCheck {
   bool has_run = false;
   std::unique_ptr<FileCheckImpl> fcImpl;
 };
-
 } // namespace testing
 } // namespace jit
 } // namespace torch