From: eellison Date: Fri, 29 Mar 2019 22:35:37 +0000 (-0700) Subject: Re-land Parsing file check (#18570) X-Git-Tag: accepted/tizen/6.5/unified/20211028.231830~546 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=393731ab24a875bf6fc937fbe51179a449e27117;p=platform%2Fupstream%2Fpytorch.git Re-land Parsing file check (#18570) Summary: The last time I tried to land it there was a merge race with the docs coverage test lol. Re-landing with the fix. Re-land of https://github.com/pytorch/pytorch/pull/18304 Pull Request resolved: https://github.com/pytorch/pytorch/pull/18570 Differential Revision: D14668859 Pulled By: eellison fbshipit-source-id: 3825a35ddc6179a0d433d70d22b5c1a96c20b21a --- diff --git a/test/cpp/jit/test_irparser.h b/test/cpp/jit/test_irparser.h index c8c3eea..af8d28a 100644 --- a/test/cpp/jit/test_irparser.h +++ b/test/cpp/jit/test_irparser.h @@ -2,6 +2,7 @@ #include #include +#include #include "test/cpp/jit/test_base.h" #include @@ -211,6 +212,19 @@ graph(%0 : Tensor, } AT_ASSERT(error_thrown); } + + { + auto graph = std::make_shared(); + const std::string& text = + R"IR( + graph(%a): + # CHECK: return + return (%a))IR"; + + script::parseIR(text, &*graph); + graph->inputs()[0]->type()->expect(); + torch::jit::testing::FileCheck().run(text, *graph); + } } } // namespace jit } // namespace torch diff --git a/test/test_docs_coverage.py b/test/test_docs_coverage.py index 3b565c3..02bad78 100644 --- a/test/test_docs_coverage.py +++ b/test/test_docs_coverage.py @@ -37,6 +37,7 @@ class TestDocCoverage(unittest.TestCase): # below are some jit functions 'wait', 'fork', 'parse_type_comment', 'import_ir_module', 'import_ir_module_from_buffer', 'merge_type_from_type_comment', + 'parse_ir', # below are symbols mistakely binded to torch.*, but should # go to torch.nn.functional.* instead diff --git a/test/test_jit.py b/test/test_jit.py index 6229f0c..7d0f1c9 100644 --- a/test/test_jit.py +++ b/test/test_jit.py @@ -43,7 +43,7 @@ from common_methods_invocations import create_input, unpack_variables, \ exclude_tensor_method, non_differentiable, EXCLUDE_GRADCHECK, EXCLUDE_FUNCTIONAL from torch.testing import FileCheck from torch._C import TensorType, TupleType, FloatType, IntType, \ - ListType, StringType, DictType + ListType, StringType, DictType, parse_ir from copy import deepcopy import random from typing import List, Dict, Optional, Tuple @@ -5540,6 +5540,14 @@ a") m2.sub2.a.data.zero_() self.assertEqual(torch.zeros(2, 2), m2.forward(torch.randn(3, 2))) + def test_irparser(self): + graph_str = """graph(%0 : Double(5, 5)): + # CHECK: aten::relu + %1 : Double(5, 5) = aten::relu(%0) + return (%1) + """ + FileCheck().run(graph_str, parse_ir(graph_str)) + def test_filecheck(self): def test_check(): file = "232" @@ -5632,6 +5640,59 @@ a") with self.assertRaisesRegex(RuntimeError, 'Expected to not find "1"'): fb.run("22 1 22") + def test_filecheck_parse(self): + def test_check(): + file = """ + # CHECK: 2 + # CHECK: 3 + # CHECK: 2 + 232 + """ + FileCheck().run(checks_file=file, test_file=file) + file = """ + # CHECK: 232 + 232 + """ + FileCheck().run(file, "232") + with self.assertRaisesRegex(RuntimeError, 'Expected to find "232"'): + FileCheck().run(file, "22") + with self.assertRaisesRegex(RuntimeError, 'Expected to find "22"'): + FileCheck().run("# CHECK: 22", "23") + test_check() + + def test_check_count(): + file = "22222" + FileCheck().run("# CHECK-COUNT-5: 2", file) + FileCheck().run("# CHECK-COUNT-EXACTLY-5: 2", file) + FileCheck().run("# CHECK-COUNT-2: 22", file) + FileCheck().run("# CHECK-COUNT-1: 222", file) + + with self.assertRaisesRegex(RuntimeError, 'Expected to not find'): + FileCheck().run("# CHECK-COUNT-EXACTLY-2: 2", file) + test_check_count() + + def test_check_same(): + file = "22\n33" + FileCheck().run("# CHECK-SAME: 22", file) + + with self.assertRaisesRegex(RuntimeError, "Expected to not find"): + FileCheck().run("# CHECK-SAME: 33", file) + + file = "22 1 3" + + FileCheck().run("# CHECK: 2\n # CHECK-SAME: 3", file) + FileCheck().run("# CHECK-COUNT-2: 2\n # CHECK-SAME: 3", file) + test_check_same() + + def test_bad_input(): + with self.assertRaisesRegex(RuntimeError, "Check for bad input"): + FileCheck().run("", "1") + + with self.assertRaisesRegex(RuntimeError, "Could not parse check"): + FileCheck().run("# CHECK1", "") + + test_bad_input() + def test_script_module_call_noscript(self): class M(torch.jit.ScriptModule): def __init__(self): diff --git a/torch/csrc/jit/init.cpp b/torch/csrc/jit/init.cpp index 489c438..528dc43 100644 --- a/torch/csrc/jit/init.cpp +++ b/torch/csrc/jit/init.cpp @@ -7,6 +7,7 @@ #include #include #include +#include #include #include #include @@ -75,7 +76,6 @@ bool loadPythonClasses() { return true; } - } // anonymous namespace #if defined(_WIN32) @@ -375,6 +375,12 @@ void initJITBindings(PyObject* module) { }, py::arg("qualified_name")); + m.def("parse_ir", [](const std::string& input) { + auto graph = std::make_shared(); + script::parseIR(input, &*graph); + return graph; + }); + py::class_(m, "FunctionSchema") .def_property_readonly( "name", [](FunctionSchema& self) { return self.name(); }) @@ -490,6 +496,5 @@ void initJITBindings(PyObject* module) { script::initTreeViewBindings(module); script::initJitScriptBindings(module); } - } // namespace jit } // namespace torch diff --git a/torch/csrc/jit/script/init.cpp b/torch/csrc/jit/script/init.cpp index f4d1c89..deebed3 100644 --- a/torch/csrc/jit/script/init.cpp +++ b/torch/csrc/jit/script/init.cpp @@ -13,6 +13,7 @@ #include #include #include +#include #include #include #include @@ -1098,9 +1099,24 @@ void initJitScriptBindings(PyObject* module) { [](testing::FileCheck& f, const std::string& str) { return f.run(str); }) - .def("run", [](testing::FileCheck& f, const Graph& g) { - return f.run(g); - }); + .def( + "run", [](testing::FileCheck& f, const Graph& g) { return f.run(g); }) + .def( + "run", + [](testing::FileCheck& f, + const std::string& input, + const std::string& output) { return f.run(input, output); }, + "Run", + py::arg("checks_file"), + py::arg("test_file")) + .def( + "run", + [](testing::FileCheck& f, const std::string& input, const Graph& g) { + return f.run(input, g); + }, + "Run", + py::arg("checks_file"), + py::arg("graph")); } } // namespace script } // namespace jit diff --git a/torch/csrc/jit/testing/file_check.cpp b/torch/csrc/jit/testing/file_check.cpp index 3af8c5a..741502b 100644 --- a/torch/csrc/jit/testing/file_check.cpp +++ b/torch/csrc/jit/testing/file_check.cpp @@ -79,10 +79,11 @@ std::ostream& operator<<(std::ostream& out, const Check& c) { }; namespace { + size_t assertFind( const SourceRange& search_range, const std::string& sub, - const Check& check) { + std::function extra_msg = nullptr) { auto pos = search_range.file_ptr()->find(sub, search_range.start()); if (pos == std::string::npos || (pos + sub.size()) > search_range.end()) { auto found_range = @@ -92,13 +93,24 @@ size_t assertFind( printQuotedString(ss, sub); ss << " but did not find it\n"; found_range.highlight(ss); - ss << "From " << check << "\n"; + if (extra_msg) { + extra_msg(ss); + } throw std::runtime_error(ss.str()); } return pos; } size_t assertFind( + const SourceRange& search_range, + const std::string& sub, + const Check& check) { + return assertFind(search_range, sub, [&](std::ostream& out) { + out << "From " << check << "\n"; + }); +} + +size_t assertFind( const std::shared_ptr& file, const std::string& sub, size_t start, @@ -123,6 +135,7 @@ void assertNotFind( throw std::runtime_error(ss.str()); } } + } // namespace struct FileCheckImpl { @@ -130,35 +143,143 @@ struct FileCheckImpl { TORCH_API void run(const std::string& test_file) { has_run = true; + + if (groups.size() == 0 || groups[0].size() == 0) { + throw std::runtime_error( + "No checks have been added to this instance of" + "Filecheck! Check for bad input."); + } + doChecks(std::make_shared(test_file)); } - TORCH_API void addCheck( - CheckType type, - const std::string& s, - c10::optional count = c10::nullopt) { - Check check(type, s, std::move(count)); + TORCH_API void run( + const std::string& checks_file, + const std::string& test_file) { + auto checks_ptr = std::make_shared(checks_file); + parseStrings(checks_ptr); + run(test_file); + } + TORCH_API void addCheck(Check check) { // consecutive CHECK_DAGs & CHECK_NOTs need to be evaluated as a group - if (groups.size() == 0 || (type != CHECK_NOT && type != CHECK_DAG)) { + if (groups.size() == 0 || + (check.type_ != CHECK_NOT && check.type_ != CHECK_DAG)) { groups.push_back({check}); } else { auto& last_group = groups.back(); - if (last_group.at(0).type_ == type) { + if (last_group.at(0).type_ == check.type_) { last_group.push_back(check); } else { groups.push_back({check}); } } - has_run = false; } + TORCH_API void addCheck( + CheckType type, + const std::string& s, + c10::optional count = c10::nullopt) { + addCheck(Check(type, s, std::move(count))); + } + bool has_run = false; friend std::ostream& operator<<(std::ostream& out, const FileCheckImpl& fc); private: + bool parseSingleCheck( + const std::shared_ptr& checks_file, + size_t* start) { + const static std::vector> check_pairs = { + {CHECK, ": "}, + {CHECK_NEXT, "-NEXT: "}, + {CHECK_SAME, "-SAME: "}, + {CHECK_NOT, "-NOT: "}, + {CHECK_DAG, "-DAG: "}, + {CHECK_COUNT, "-COUNT-"}, // needs special parsing + }; + + for (const auto& check_pair : check_pairs) { + const std::string& check_suffix = check_pair.second; + auto suffix_pos = checks_file->find(check_suffix, *start); + if (suffix_pos != *start) { + continue; + } + size_t end_check_string = suffix_pos + check_suffix.size(); + CheckType type = check_pair.first; + c10::optional count = c10::nullopt; + auto end_line = checks_file->find("\n", end_check_string); + bool exactly = false; + if (type == CHECK_COUNT) { + const std::string exact = "EXACTLY-"; + if (checks_file->find(exact, end_check_string) == end_check_string) { + exactly = true; + end_check_string += exact.size(); + } + size_t end = assertFind( + SourceRange(checks_file, end_check_string, end_line), ":"); + count = std::stoll( + checks_file->substr(end_check_string, end - end_check_string)); + end_check_string = end + 2; // add ':' and the space + } + auto check = Check( + type, + checks_file->substr(end_check_string, end_line - end_check_string), + count); + addCheck(check); + if (exactly) { + addCheck(CHECK_NOT, check.search_str_); + } + *start = end_line; + return true; + } + return false; + } + + size_t findNextStart( + const std::shared_ptr& checks_file, + size_t prev_end) { + size_t start = checks_file->find("#", prev_end); + if (start == std::string::npos) { + return start; + } + start += 1; + static constexpr size_t max_whitespace = 6; + size_t i = 0; + while (start + i < checks_file->size() && i < max_whitespace) { + auto c = checks_file->at(start + i); + if (c != ' ' && c != '\t') { + break; + } + i++; + } + static const std::string check = "CHECK"; + if (checks_file->substr(start + i, check.size()) == check) { + return start + i + check.size(); + } else { + return findNextStart(checks_file, start + i + 1); + } + } + + void parseStrings(const std::shared_ptr& checks_file) { + size_t start = 0; + start = findNextStart(checks_file, 0); + while (start != std::string::npos) { + bool found_match = parseSingleCheck(checks_file, &start); + if (!found_match) { + std::ostringstream ss; + ss << "Could not parse check at:\n"; + SourceRange(checks_file, start, start + 1).highlight(ss); + ss << "Check for bad input."; + has_run = true; + throw std::runtime_error(ss.str()); + } + start = findNextStart(checks_file, start); + } + } + void doCheckNot( const std::vector& nots, const std::shared_ptr& file, @@ -277,7 +398,6 @@ struct FileCheckImpl { } std::vector checks; - std::shared_ptr check_file; std::vector> groups; }; @@ -309,6 +429,20 @@ void FileCheck::run(const Graph& graph) { fcImpl->run(graph_str.str()); }; +void FileCheck::run( + const std::string& input_checks_string, + const std::string& test_string) { + fcImpl->run(input_checks_string, test_string); +} + +void FileCheck::run( + const std::string& input_checks_string, + const Graph& graph) { + std::stringstream graph_str; + graph_str << graph; + fcImpl->run(input_checks_string, graph_str.str()); +} + FileCheck* FileCheck::check(const std::string& str) { fcImpl->addCheck(CHECK, str); return this; diff --git a/torch/csrc/jit/testing/file_check.h b/torch/csrc/jit/testing/file_check.h index cf80575..d7a7819 100644 --- a/torch/csrc/jit/testing/file_check.h +++ b/torch/csrc/jit/testing/file_check.h @@ -23,6 +23,14 @@ struct FileCheck { // Run FileCheck against dump of graph IR TORCH_API void run(const Graph& graph); + // Parsing input checks string and run against test string / dump of graph IR + TORCH_API void run( + const std::string& input_checks_string, + const std::string& test_string); + TORCH_API void run( + const std::string& input_checks_string, + const Graph& graph); + // Checks that the string occurs, starting at the end of the most recent match TORCH_API FileCheck* check(const std::string& str);