From ffc7158bf2f97916305217e4203ef846c00161ce Mon Sep 17 00:00:00 2001 From: Elias Ellison Date: Thu, 28 Mar 2019 00:09:36 -0700 Subject: [PATCH] Revert D14652372: [pytorch][PR] Add parsing to file check Differential Revision: D14652372 Original commit changeset: 7430b9d1dc2b fbshipit-source-id: fa3d0f68515fe53447746469844d2db20c1292e0 --- test/cpp/jit/test_irparser.h | 14 --- test/test_jit.py | 63 +------------ torch/csrc/jit/init.cpp | 9 +- torch/csrc/jit/script/init.cpp | 22 +---- torch/csrc/jit/testing/file_check.cpp | 167 +++------------------------------- torch/csrc/jit/testing/file_check.h | 8 -- 6 files changed, 17 insertions(+), 266 deletions(-) diff --git a/test/cpp/jit/test_irparser.h b/test/cpp/jit/test_irparser.h index af8d28a..c8c3eea 100644 --- a/test/cpp/jit/test_irparser.h +++ b/test/cpp/jit/test_irparser.h @@ -2,7 +2,6 @@ #include #include -#include #include "test/cpp/jit/test_base.h" #include @@ -212,19 +211,6 @@ graph(%0 : Tensor, } AT_ASSERT(error_thrown); } - - { - auto graph = std::make_shared(); - const std::string& text = - R"IR( - graph(%a): - # CHECK: return - return (%a))IR"; - - script::parseIR(text, &*graph); - graph->inputs()[0]->type()->expect(); - torch::jit::testing::FileCheck().run(text, *graph); - } } } // namespace jit } // namespace torch diff --git a/test/test_jit.py b/test/test_jit.py index 028e926..b1b5c98 100644 --- a/test/test_jit.py +++ b/test/test_jit.py @@ -43,7 +43,7 @@ from common_methods_invocations import create_input, unpack_variables, \ exclude_tensor_method, non_differentiable, EXCLUDE_GRADCHECK, EXCLUDE_FUNCTIONAL from torch.testing import FileCheck from torch._C import TensorType, TupleType, FloatType, IntType, \ - ListType, StringType, DictType, parse_ir + ListType, StringType, DictType from copy import deepcopy import random from typing import List, Dict, Optional, Tuple @@ -6044,14 +6044,6 @@ a") m2.sub2.a.data.zero_() self.assertEqual(torch.zeros(2, 2), m2.forward(torch.randn(3, 2))) - def test_irparser(self): - graph_str = """graph(%0 : Double(5, 5)): - # CHECK: aten::relu - %1 : Double(5, 5) = aten::relu(%0) - return (%1) - """ - FileCheck().run(graph_str, parse_ir(graph_str)) - def test_filecheck(self): def test_check(): file = "232" @@ -6144,59 +6136,6 @@ a") with self.assertRaisesRegex(RuntimeError, 'Expected to not find "1"'): fb.run("22 1 22") - def test_filecheck_parse(self): - def test_check(): - file = """ - # CHECK: 2 - # CHECK: 3 - # CHECK: 2 - 232 - """ - FileCheck().run(checks_file=file, test_file=file) - file = """ - # CHECK: 232 - 232 - """ - FileCheck().run(file, "232") - with self.assertRaisesRegex(RuntimeError, 'Expected to find "232"'): - FileCheck().run(file, "22") - with self.assertRaisesRegex(RuntimeError, 'Expected to find "22"'): - FileCheck().run("# CHECK: 22", "23") - test_check() - - def test_check_count(): - file = "22222" - FileCheck().run("# CHECK-COUNT-5: 2", file) - FileCheck().run("# CHECK-COUNT-EXACTLY-5: 2", file) - FileCheck().run("# CHECK-COUNT-2: 22", file) - FileCheck().run("# CHECK-COUNT-1: 222", file) - - with self.assertRaisesRegex(RuntimeError, 'Expected to not find'): - FileCheck().run("# CHECK-COUNT-EXACTLY-2: 2", file) - test_check_count() - - def test_check_same(): - file = "22\n33" - FileCheck().run("# CHECK-SAME: 22", file) - - with self.assertRaisesRegex(RuntimeError, "Expected to not find"): - FileCheck().run("# CHECK-SAME: 33", file) - - file = "22 1 3" - - FileCheck().run("# CHECK: 2\n # CHECK-SAME: 3", file) - FileCheck().run("# CHECK-COUNT-2: 2\n # CHECK-SAME: 3", file) - test_check_same() - - def test_bad_input(): - with self.assertRaisesRegex(RuntimeError, "Check for bad input"): - FileCheck().run("", "1") - - with self.assertRaisesRegex(RuntimeError, "Could not parse check"): - FileCheck().run("# CHECK1", "") - - test_bad_input() - def test_script_module_call_noscript(self): class M(torch.jit.ScriptModule): def __init__(self): diff --git a/torch/csrc/jit/init.cpp b/torch/csrc/jit/init.cpp index a74edc3..f46dd9f 100644 --- a/torch/csrc/jit/init.cpp +++ b/torch/csrc/jit/init.cpp @@ -8,7 +8,6 @@ #include #include #include -#include #include #include #include @@ -78,6 +77,7 @@ bool loadPythonClasses() { return true; } + } // anonymous namespace #if defined(_WIN32) @@ -369,12 +369,6 @@ void initJITBindings(PyObject* module) { }, py::arg("qualified_name")); - m.def("parse_ir", [](const std::string& input) { - auto graph = std::make_shared(); - script::parseIR(input, &*graph); - return graph; - }); - py::class_(m, "FunctionSchema") .def_property_readonly( "name", [](FunctionSchema& self) { return self.name(); }) @@ -492,5 +486,6 @@ void initJITBindings(PyObject* module) { initBatchTensorBindings(module); initRegisterBatchOpsBindings(module); } + } // namespace jit } // namespace torch diff --git a/torch/csrc/jit/script/init.cpp b/torch/csrc/jit/script/init.cpp index 8d1fc66..0692763 100644 --- a/torch/csrc/jit/script/init.cpp +++ b/torch/csrc/jit/script/init.cpp @@ -13,7 +13,6 @@ #include #include #include -#include #include #include #include @@ -1094,24 +1093,9 @@ void initJitScriptBindings(PyObject* module) { [](testing::FileCheck& f, const std::string& str) { return f.run(str); }) - .def( - "run", [](testing::FileCheck& f, const Graph& g) { return f.run(g); }) - .def( - "run", - [](testing::FileCheck& f, - const std::string& input, - const std::string& output) { return f.run(input, output); }, - "Run", - py::arg("checks_file"), - py::arg("test_file")) - .def( - "run", - [](testing::FileCheck& f, const std::string& input, const Graph& g) { - return f.run(input, g); - }, - "Run", - py::arg("checks_file"), - py::arg("graph")); + .def("run", [](testing::FileCheck& f, const Graph& g) { + return f.run(g); + }); } } // namespace script } // namespace jit diff --git a/torch/csrc/jit/testing/file_check.cpp b/torch/csrc/jit/testing/file_check.cpp index bcc9e90..3af8c5a 100644 --- a/torch/csrc/jit/testing/file_check.cpp +++ b/torch/csrc/jit/testing/file_check.cpp @@ -79,11 +79,10 @@ std::ostream& operator<<(std::ostream& out, const Check& c) { }; namespace { - size_t assertFind( const SourceRange& search_range, const std::string& sub, - std::function extra_msg = nullptr) { + const Check& check) { auto pos = search_range.file_ptr()->find(sub, search_range.start()); if (pos == std::string::npos || (pos + sub.size()) > search_range.end()) { auto found_range = @@ -93,24 +92,13 @@ size_t assertFind( printQuotedString(ss, sub); ss << " but did not find it\n"; found_range.highlight(ss); - if (extra_msg) { - extra_msg(ss); - } + ss << "From " << check << "\n"; throw std::runtime_error(ss.str()); } return pos; } size_t assertFind( - const SourceRange& search_range, - const std::string& sub, - const Check& check) { - return assertFind(search_range, sub, [&](std::ostream& out) { - out << "From " << check << "\n"; - }); -} - -size_t assertFind( const std::shared_ptr& file, const std::string& sub, size_t start, @@ -135,18 +123,6 @@ void assertNotFind( throw std::runtime_error(ss.str()); } } - -size_t substringCount( - const std::shared_ptr& file, - const std::string& sub) { - size_t occurances = 0; - std::string::size_type pos = 0; - while ((pos = file->find(sub, pos)) != std::string::npos) { - ++occurances; - pos += sub.length(); - } - return occurances; -} } // namespace struct FileCheckImpl { @@ -154,45 +130,28 @@ struct FileCheckImpl { TORCH_API void run(const std::string& test_file) { has_run = true; - - if (groups.size() == 0 || groups[0].size() == 0) { - throw std::runtime_error( - "No checks have been added to this instance of" - "Filecheck! Check for bad input."); - } - doChecks(std::make_shared(test_file)); } - TORCH_API void run( - const std::string& checks_file, - const std::string& test_file) { - auto checks_ptr = std::make_shared(checks_file); - parseStrings(checks_ptr); - run(test_file); - } + TORCH_API void addCheck( + CheckType type, + const std::string& s, + c10::optional count = c10::nullopt) { + Check check(type, s, std::move(count)); - TORCH_API void addCheck(Check check) { // consecutive CHECK_DAGs & CHECK_NOTs need to be evaluated as a group - if (groups.size() == 0 || - (check.type_ != CHECK_NOT && check.type_ != CHECK_DAG)) { + if (groups.size() == 0 || (type != CHECK_NOT && type != CHECK_DAG)) { groups.push_back({check}); } else { auto& last_group = groups.back(); - if (last_group.at(0).type_ == check.type_) { + if (last_group.at(0).type_ == type) { last_group.push_back(check); } else { groups.push_back({check}); } } - has_run = false; - } - TORCH_API void addCheck( - CheckType type, - const std::string& s, - c10::optional count = c10::nullopt) { - addCheck(Check(type, s, std::move(count))); + has_run = false; } bool has_run = false; @@ -200,97 +159,6 @@ struct FileCheckImpl { friend std::ostream& operator<<(std::ostream& out, const FileCheckImpl& fc); private: - bool parseSingleCheck( - const std::shared_ptr& checks_file, - size_t* start) { - const static std::vector> check_pairs = { - {CHECK, ": "}, - {CHECK_NEXT, "-NEXT: "}, - {CHECK_SAME, "-SAME: "}, - {CHECK_NOT, "-NOT: "}, - {CHECK_DAG, "-DAG: "}, - {CHECK_COUNT, "-COUNT-"}, // needs special parsing - }; - - for (const auto& check_pair : check_pairs) { - const std::string& check_suffix = check_pair.second; - auto suffix_pos = checks_file->find(check_suffix, *start); - if (suffix_pos != *start) { - continue; - } - size_t end_check_string = suffix_pos + check_suffix.size(); - CheckType type = check_pair.first; - c10::optional count = c10::nullopt; - auto end_line = checks_file->find("\n", end_check_string); - bool exactly = false; - if (type == CHECK_COUNT) { - const std::string exact = "EXACTLY-"; - if (checks_file->find(exact, end_check_string) == end_check_string) { - exactly = true; - end_check_string += exact.size(); - } - size_t end = assertFind( - SourceRange(checks_file, end_check_string, end_line), ":"); - count = std::stoll( - checks_file->substr(end_check_string, end - end_check_string)); - end_check_string = end + 2; // add ':' and the space - } - auto check = Check( - type, - checks_file->substr(end_check_string, end_line - end_check_string), - count); - addCheck(check); - if (exactly) { - addCheck(CHECK_NOT, check.search_str_); - } - *start = end_line; - return true; - } - return false; - } - - size_t findNextStart( - const std::shared_ptr& checks_file, - size_t prev_end) { - size_t start = checks_file->find("#", prev_end); - if (start == std::string::npos) { - return start; - } - start += 1; - static constexpr size_t max_whitespace = 6; - size_t i = 0; - while (start + i < checks_file->size() && i < max_whitespace) { - auto c = checks_file->at(start + i); - if (c != ' ' && c != '\t') { - break; - } - i++; - } - static const std::string check = "CHECK"; - if (checks_file->substr(start + i, check.size()) == check) { - return start + i + check.size(); - } else { - return findNextStart(checks_file, start + i + 1); - } - } - - void parseStrings(const std::shared_ptr& checks_file) { - size_t start = 0; - start = findNextStart(checks_file, 0); - while (start != std::string::npos) { - bool found_match = parseSingleCheck(checks_file, &start); - if (!found_match) { - std::ostringstream ss; - ss << "Could not parse check at:\n"; - SourceRange(checks_file, start, start + 1).highlight(ss); - ss << "Check for bad input."; - has_run = true; - throw std::runtime_error(ss.str()); - } - start = findNextStart(checks_file, start); - } - } - void doCheckNot( const std::vector& nots, const std::shared_ptr& file, @@ -409,6 +277,7 @@ struct FileCheckImpl { } std::vector checks; + std::shared_ptr check_file; std::vector> groups; }; @@ -440,20 +309,6 @@ void FileCheck::run(const Graph& graph) { fcImpl->run(graph_str.str()); }; -void FileCheck::run( - const std::string& input_checks_string, - const std::string& test_string) { - fcImpl->run(input_checks_string, test_string); -} - -void FileCheck::run( - const std::string& input_checks_string, - const Graph& graph) { - std::stringstream graph_str; - graph_str << graph; - fcImpl->run(input_checks_string, graph_str.str()); -} - FileCheck* FileCheck::check(const std::string& str) { fcImpl->addCheck(CHECK, str); return this; diff --git a/torch/csrc/jit/testing/file_check.h b/torch/csrc/jit/testing/file_check.h index d7a7819..cf80575 100644 --- a/torch/csrc/jit/testing/file_check.h +++ b/torch/csrc/jit/testing/file_check.h @@ -23,14 +23,6 @@ struct FileCheck { // Run FileCheck against dump of graph IR TORCH_API void run(const Graph& graph); - // Parsing input checks string and run against test string / dump of graph IR - TORCH_API void run( - const std::string& input_checks_string, - const std::string& test_string); - TORCH_API void run( - const std::string& input_checks_string, - const Graph& graph); - // Checks that the string occurs, starting at the end of the most recent match TORCH_API FileCheck* check(const std::string& str); -- 2.7.4