Revert D14652372: [pytorch][PR] Add parsing to file check
authorElias Ellison <eellison@fb.com>
Thu, 28 Mar 2019 07:09:36 +0000 (00:09 -0700)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Thu, 28 Mar 2019 07:12:47 +0000 (00:12 -0700)
Differential Revision:
D14652372

Original commit changeset: 7430b9d1dc2b

fbshipit-source-id: fa3d0f68515fe53447746469844d2db20c1292e0

test/cpp/jit/test_irparser.h
test/test_jit.py
torch/csrc/jit/init.cpp
torch/csrc/jit/script/init.cpp
torch/csrc/jit/testing/file_check.cpp
torch/csrc/jit/testing/file_check.h

index af8d28a..c8c3eea 100644 (file)
@@ -2,7 +2,6 @@
 
 #include <torch/csrc/jit/ir.h>
 #include <torch/csrc/jit/irparser.h>
-#include <torch/csrc/jit/testing/file_check.h>
 #include "test/cpp/jit/test_base.h"
 
 #include <sstream>
@@ -212,19 +211,6 @@ graph(%0 : Tensor,
     }
     AT_ASSERT(error_thrown);
   }
-
-  {
-    auto graph = std::make_shared<Graph>();
-    const std::string& text =
-        R"IR(
-    graph(%a):
-    # CHECK: return
-      return (%a))IR";
-
-    script::parseIR(text, &*graph);
-    graph->inputs()[0]->type()->expect<TensorType>();
-    torch::jit::testing::FileCheck().run(text, *graph);
-  }
 }
 } // namespace jit
 } // namespace torch
index 028e926..b1b5c98 100644 (file)
@@ -43,7 +43,7 @@ from common_methods_invocations import create_input, unpack_variables, \
     exclude_tensor_method, non_differentiable, EXCLUDE_GRADCHECK, EXCLUDE_FUNCTIONAL
 from torch.testing import FileCheck
 from torch._C import TensorType, TupleType, FloatType, IntType, \
-    ListType, StringType, DictType, parse_ir
+    ListType, StringType, DictType
 from copy import deepcopy
 import random
 from typing import List, Dict, Optional, Tuple
@@ -6044,14 +6044,6 @@ a")
         m2.sub2.a.data.zero_()
         self.assertEqual(torch.zeros(2, 2), m2.forward(torch.randn(3, 2)))
 
-    def test_irparser(self):
-        graph_str = """graph(%0 : Double(5, 5)):
-          # CHECK: aten::relu
-          %1 : Double(5, 5) = aten::relu(%0)
-          return (%1)
-        """
-        FileCheck().run(graph_str, parse_ir(graph_str))
-
     def test_filecheck(self):
         def test_check():
             file = "232"
@@ -6144,59 +6136,6 @@ a")
             with self.assertRaisesRegex(RuntimeError, 'Expected to not find "1"'):
                 fb.run("22 1 22")
 
-    def test_filecheck_parse(self):
-        def test_check():
-            file = """
-                # CHECK: 2
-                # CHECK: 3
-                # CHECK: 2
-                232
-                """
-            FileCheck().run(checks_file=file, test_file=file)
-            file = """
-                # CHECK: 232
-                232
-                """
-            FileCheck().run(file, "232")
-            with self.assertRaisesRegex(RuntimeError, 'Expected to find "232"'):
-                FileCheck().run(file, "22")
-            with self.assertRaisesRegex(RuntimeError, 'Expected to find "22"'):
-                FileCheck().run("# CHECK: 22", "23")
-        test_check()
-
-        def test_check_count():
-            file = "22222"
-            FileCheck().run("# CHECK-COUNT-5: 2", file)
-            FileCheck().run("# CHECK-COUNT-EXACTLY-5: 2", file)
-            FileCheck().run("# CHECK-COUNT-2: 22", file)
-            FileCheck().run("# CHECK-COUNT-1: 222", file)
-
-            with self.assertRaisesRegex(RuntimeError, 'Expected to not find'):
-                FileCheck().run("# CHECK-COUNT-EXACTLY-2: 2", file)
-        test_check_count()
-
-        def test_check_same():
-            file = "22\n33"
-            FileCheck().run("# CHECK-SAME: 22", file)
-
-            with self.assertRaisesRegex(RuntimeError, "Expected to not find"):
-                FileCheck().run("# CHECK-SAME: 33", file)
-
-            file = "22  1  3"
-
-            FileCheck().run("# CHECK: 2\n # CHECK-SAME: 3", file)
-            FileCheck().run("# CHECK-COUNT-2: 2\n # CHECK-SAME: 3", file)
-        test_check_same()
-
-        def test_bad_input():
-            with self.assertRaisesRegex(RuntimeError, "Check for bad input"):
-                FileCheck().run("", "1")
-
-            with self.assertRaisesRegex(RuntimeError, "Could not parse check"):
-                FileCheck().run("# CHECK1", "")
-
-        test_bad_input()
-
     def test_script_module_call_noscript(self):
         class M(torch.jit.ScriptModule):
             def __init__(self):
index a74edc3..f46dd9f 100644 (file)
@@ -8,7 +8,6 @@
 #include <torch/csrc/jit/fuser/kernel_cache.h>
 #include <torch/csrc/jit/graph_executor.h>
 #include <torch/csrc/jit/import.h>
-#include <torch/csrc/jit/irparser.h>
 #include <torch/csrc/jit/operator.h>
 #include <torch/csrc/jit/passes/canonicalize.h>
 #include <torch/csrc/jit/passes/canonicalize_ops.h>
@@ -78,6 +77,7 @@ bool loadPythonClasses() {
 
   return true;
 }
+
 } // anonymous namespace
 
 #if defined(_WIN32)
@@ -369,12 +369,6 @@ void initJITBindings(PyObject* module) {
       },
       py::arg("qualified_name"));
 
-  m.def("parse_ir", [](const std::string& input) {
-    auto graph = std::make_shared<Graph>();
-    script::parseIR(input, &*graph);
-    return graph;
-  });
-
   py::class_<FunctionSchema>(m, "FunctionSchema")
       .def_property_readonly(
           "name", [](FunctionSchema& self) { return self.name(); })
@@ -492,5 +486,6 @@ void initJITBindings(PyObject* module) {
   initBatchTensorBindings(module);
   initRegisterBatchOpsBindings(module);
 }
+
 } // namespace jit
 } // namespace torch
index 8d1fc66..0692763 100644 (file)
@@ -13,7 +13,6 @@
 #include <torch/csrc/jit/constants.h>
 #include <torch/csrc/jit/hooks_for_testing.h>
 #include <torch/csrc/jit/import_source.h>
-#include <torch/csrc/jit/irparser.h>
 #include <torch/csrc/jit/passes/python_print.h>
 #include <torch/csrc/jit/passes/to_batch.h>
 #include <torch/csrc/jit/pybind_utils.h>
@@ -1094,24 +1093,9 @@ void initJitScriptBindings(PyObject* module) {
           [](testing::FileCheck& f, const std::string& str) {
             return f.run(str);
           })
-      .def(
-          "run", [](testing::FileCheck& f, const Graph& g) { return f.run(g); })
-      .def(
-          "run",
-          [](testing::FileCheck& f,
-             const std::string& input,
-             const std::string& output) { return f.run(input, output); },
-          "Run",
-          py::arg("checks_file"),
-          py::arg("test_file"))
-      .def(
-          "run",
-          [](testing::FileCheck& f, const std::string& input, const Graph& g) {
-            return f.run(input, g);
-          },
-          "Run",
-          py::arg("checks_file"),
-          py::arg("graph"));
+      .def("run", [](testing::FileCheck& f, const Graph& g) {
+        return f.run(g);
+      });
 }
 } // namespace script
 } // namespace jit
index bcc9e90..3af8c5a 100644 (file)
@@ -79,11 +79,10 @@ std::ostream& operator<<(std::ostream& out, const Check& c) {
 };
 
 namespace {
-
 size_t assertFind(
     const SourceRange& search_range,
     const std::string& sub,
-    std::function<void(std::ostream& out)> extra_msg = nullptr) {
+    const Check& check) {
   auto pos = search_range.file_ptr()->find(sub, search_range.start());
   if (pos == std::string::npos || (pos + sub.size()) > search_range.end()) {
     auto found_range =
@@ -93,24 +92,13 @@ size_t assertFind(
     printQuotedString(ss, sub);
     ss << " but did not find it\n";
     found_range.highlight(ss);
-    if (extra_msg) {
-      extra_msg(ss);
-    }
+    ss << "From " << check << "\n";
     throw std::runtime_error(ss.str());
   }
   return pos;
 }
 
 size_t assertFind(
-    const SourceRange& search_range,
-    const std::string& sub,
-    const Check& check) {
-  return assertFind(search_range, sub, [&](std::ostream& out) {
-    out << "From " << check << "\n";
-  });
-}
-
-size_t assertFind(
     const std::shared_ptr<std::string>& file,
     const std::string& sub,
     size_t start,
@@ -135,18 +123,6 @@ void assertNotFind(
     throw std::runtime_error(ss.str());
   }
 }
-
-size_t substringCount(
-    const std::shared_ptr<std::string>& file,
-    const std::string& sub) {
-  size_t occurances = 0;
-  std::string::size_type pos = 0;
-  while ((pos = file->find(sub, pos)) != std::string::npos) {
-    ++occurances;
-    pos += sub.length();
-  }
-  return occurances;
-}
 } // namespace
 
 struct FileCheckImpl {
@@ -154,45 +130,28 @@ struct FileCheckImpl {
 
   TORCH_API void run(const std::string& test_file) {
     has_run = true;
-
-    if (groups.size() == 0 || groups[0].size() == 0) {
-      throw std::runtime_error(
-          "No checks have been added to this instance of"
-          "Filecheck! Check for bad input.");
-    }
-
     doChecks(std::make_shared<std::string>(test_file));
   }
 
-  TORCH_API void run(
-      const std::string& checks_file,
-      const std::string& test_file) {
-    auto checks_ptr = std::make_shared<std::string>(checks_file);
-    parseStrings(checks_ptr);
-    run(test_file);
-  }
+  TORCH_API void addCheck(
+      CheckType type,
+      const std::string& s,
+      c10::optional<size_t> count = c10::nullopt) {
+    Check check(type, s, std::move(count));
 
-  TORCH_API void addCheck(Check check) {
     // consecutive CHECK_DAGs & CHECK_NOTs need to be evaluated as a group
-    if (groups.size() == 0 ||
-        (check.type_ != CHECK_NOT && check.type_ != CHECK_DAG)) {
+    if (groups.size() == 0 || (type != CHECK_NOT && type != CHECK_DAG)) {
       groups.push_back({check});
     } else {
       auto& last_group = groups.back();
-      if (last_group.at(0).type_ == check.type_) {
+      if (last_group.at(0).type_ == type) {
         last_group.push_back(check);
       } else {
         groups.push_back({check});
       }
     }
-    has_run = false;
-  }
 
-  TORCH_API void addCheck(
-      CheckType type,
-      const std::string& s,
-      c10::optional<size_t> count = c10::nullopt) {
-    addCheck(Check(type, s, std::move(count)));
+    has_run = false;
   }
 
   bool has_run = false;
@@ -200,97 +159,6 @@ struct FileCheckImpl {
   friend std::ostream& operator<<(std::ostream& out, const FileCheckImpl& fc);
 
  private:
-  bool parseSingleCheck(
-      const std::shared_ptr<std::string>& checks_file,
-      size_t* start) {
-    const static std::vector<std::pair<CheckType, std::string>> check_pairs = {
-        {CHECK, ": "},
-        {CHECK_NEXT, "-NEXT: "},
-        {CHECK_SAME, "-SAME: "},
-        {CHECK_NOT, "-NOT: "},
-        {CHECK_DAG, "-DAG: "},
-        {CHECK_COUNT, "-COUNT-"}, // needs special parsing
-    };
-
-    for (const auto& check_pair : check_pairs) {
-      const std::string& check_suffix = check_pair.second;
-      auto suffix_pos = checks_file->find(check_suffix, *start);
-      if (suffix_pos != *start) {
-        continue;
-      }
-      size_t end_check_string = suffix_pos + check_suffix.size();
-      CheckType type = check_pair.first;
-      c10::optional<size_t> count = c10::nullopt;
-      auto end_line = checks_file->find("\n", end_check_string);
-      bool exactly = false;
-      if (type == CHECK_COUNT) {
-        const std::string exact = "EXACTLY-";
-        if (checks_file->find(exact, end_check_string) == end_check_string) {
-          exactly = true;
-          end_check_string += exact.size();
-        }
-        size_t end = assertFind(
-            SourceRange(checks_file, end_check_string, end_line), ":");
-        count = std::stoll(
-            checks_file->substr(end_check_string, end - end_check_string));
-        end_check_string = end + 2; // add ':' and the space
-      }
-      auto check = Check(
-          type,
-          checks_file->substr(end_check_string, end_line - end_check_string),
-          count);
-      addCheck(check);
-      if (exactly) {
-        addCheck(CHECK_NOT, check.search_str_);
-      }
-      *start = end_line;
-      return true;
-    }
-    return false;
-  }
-
-  size_t findNextStart(
-      const std::shared_ptr<std::string>& checks_file,
-      size_t prev_end) {
-    size_t start = checks_file->find("#", prev_end);
-    if (start == std::string::npos) {
-      return start;
-    }
-    start += 1;
-    static constexpr size_t max_whitespace = 6;
-    size_t i = 0;
-    while (start + i < checks_file->size() && i < max_whitespace) {
-      auto c = checks_file->at(start + i);
-      if (c != ' ' && c != '\t') {
-        break;
-      }
-      i++;
-    }
-    static const std::string check = "CHECK";
-    if (checks_file->substr(start + i, check.size()) == check) {
-      return start + i + check.size();
-    } else {
-      return findNextStart(checks_file, start + i + 1);
-    }
-  }
-
-  void parseStrings(const std::shared_ptr<std::string>& checks_file) {
-    size_t start = 0;
-    start = findNextStart(checks_file, 0);
-    while (start != std::string::npos) {
-      bool found_match = parseSingleCheck(checks_file, &start);
-      if (!found_match) {
-        std::ostringstream ss;
-        ss << "Could not parse check at:\n";
-        SourceRange(checks_file, start, start + 1).highlight(ss);
-        ss << "Check for bad input.";
-        has_run = true;
-        throw std::runtime_error(ss.str());
-      }
-      start = findNextStart(checks_file, start);
-    }
-  }
-
   void doCheckNot(
       const std::vector<Check>& nots,
       const std::shared_ptr<std::string>& file,
@@ -409,6 +277,7 @@ struct FileCheckImpl {
   }
 
   std::vector<Check> checks;
+  std::shared_ptr<std::string> check_file;
   std::vector<std::vector<Check>> groups;
 };
 
@@ -440,20 +309,6 @@ void FileCheck::run(const Graph& graph) {
   fcImpl->run(graph_str.str());
 };
 
-void FileCheck::run(
-    const std::string& input_checks_string,
-    const std::string& test_string) {
-  fcImpl->run(input_checks_string, test_string);
-}
-
-void FileCheck::run(
-    const std::string& input_checks_string,
-    const Graph& graph) {
-  std::stringstream graph_str;
-  graph_str << graph;
-  fcImpl->run(input_checks_string, graph_str.str());
-}
-
 FileCheck* FileCheck::check(const std::string& str) {
   fcImpl->addCheck(CHECK, str);
   return this;
index d7a7819..cf80575 100644 (file)
@@ -23,14 +23,6 @@ struct FileCheck {
   // Run FileCheck against dump of graph IR
   TORCH_API void run(const Graph& graph);
 
-  // Parsing input checks string and run against test string / dump of graph IR
-  TORCH_API void run(
-      const std::string& input_checks_string,
-      const std::string& test_string);
-  TORCH_API void run(
-      const std::string& input_checks_string,
-      const Graph& graph);
-
   // Checks that the string occurs, starting at the end of the most recent match
   TORCH_API FileCheck* check(const std::string& str);