From 0daafe02098825842fbe5d1682e88e63ae6868c1 Mon Sep 17 00:00:00 2001 From: Elias Ellison Date: Wed, 27 Mar 2019 18:11:45 -0700 Subject: [PATCH] Add parsing to file check (#18304) Summary: This allows you to embed checks in IR, making the test more readable. E.g. ``` graph_str = 'graph(%0 : Double(5, 5)): # CHECK: aten::relu %1 : Double(5, 5) = aten::relu(%0) return (%1)' FileCheck().run(graph_str, parseIR(graph_str)) ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/18304 Differential Revision: D14652372 Pulled By: eellison fbshipit-source-id: 7430b9d1dc2b7584704375aac02d7392ecec76a0 --- test/cpp/jit/test_irparser.h | 14 +++ test/test_jit.py | 63 ++++++++++++- torch/csrc/jit/init.cpp | 9 +- torch/csrc/jit/script/init.cpp | 22 ++++- torch/csrc/jit/testing/file_check.cpp | 167 +++++++++++++++++++++++++++++++--- torch/csrc/jit/testing/file_check.h | 8 ++ 6 files changed, 266 insertions(+), 17 deletions(-) diff --git a/test/cpp/jit/test_irparser.h b/test/cpp/jit/test_irparser.h index c8c3eea..af8d28a 100644 --- a/test/cpp/jit/test_irparser.h +++ b/test/cpp/jit/test_irparser.h @@ -2,6 +2,7 @@ #include #include +#include #include "test/cpp/jit/test_base.h" #include @@ -211,6 +212,19 @@ graph(%0 : Tensor, } AT_ASSERT(error_thrown); } + + { + auto graph = std::make_shared(); + const std::string& text = + R"IR( + graph(%a): + # CHECK: return + return (%a))IR"; + + script::parseIR(text, &*graph); + graph->inputs()[0]->type()->expect(); + torch::jit::testing::FileCheck().run(text, *graph); + } } } // namespace jit } // namespace torch diff --git a/test/test_jit.py b/test/test_jit.py index 694926b..c0dc575 100644 --- a/test/test_jit.py +++ b/test/test_jit.py @@ -43,7 +43,7 @@ from common_methods_invocations import create_input, unpack_variables, \ exclude_tensor_method, non_differentiable, EXCLUDE_GRADCHECK, EXCLUDE_FUNCTIONAL from torch.testing import FileCheck from torch._C import TensorType, TupleType, FloatType, IntType, \ - ListType, StringType, DictType + ListType, StringType, DictType, parse_ir from copy import deepcopy import random from typing import List, Dict, Optional, Tuple @@ -6042,6 +6042,14 @@ a") m2.sub2.a.data.zero_() self.assertEqual(torch.zeros(2, 2), m2.forward(torch.randn(3, 2))) + def test_irparser(self): + graph_str = """graph(%0 : Double(5, 5)): + # CHECK: aten::relu + %1 : Double(5, 5) = aten::relu(%0) + return (%1) + """ + FileCheck().run(graph_str, parse_ir(graph_str)) + def test_filecheck(self): def test_check(): file = "232" @@ -6134,6 +6142,59 @@ a") with self.assertRaisesRegex(RuntimeError, 'Expected to not find "1"'): fb.run("22 1 22") + def test_filecheck_parse(self): + def test_check(): + file = """ + # CHECK: 2 + # CHECK: 3 + # CHECK: 2 + 232 + """ + FileCheck().run(checks_file=file, test_file=file) + file = """ + # CHECK: 232 + 232 + """ + FileCheck().run(file, "232") + with self.assertRaisesRegex(RuntimeError, 'Expected to find "232"'): + FileCheck().run(file, "22") + with self.assertRaisesRegex(RuntimeError, 'Expected to find "22"'): + FileCheck().run("# CHECK: 22", "23") + test_check() + + def test_check_count(): + file = "22222" + FileCheck().run("# CHECK-COUNT-5: 2", file) + FileCheck().run("# CHECK-COUNT-EXACTLY-5: 2", file) + FileCheck().run("# CHECK-COUNT-2: 22", file) + FileCheck().run("# CHECK-COUNT-1: 222", file) + + with self.assertRaisesRegex(RuntimeError, 'Expected to not find'): + FileCheck().run("# CHECK-COUNT-EXACTLY-2: 2", file) + test_check_count() + + def test_check_same(): + file = "22\n33" + FileCheck().run("# CHECK-SAME: 22", file) + + with self.assertRaisesRegex(RuntimeError, "Expected to not find"): + FileCheck().run("# CHECK-SAME: 33", file) + + file = "22 1 3" + + FileCheck().run("# CHECK: 2\n # CHECK-SAME: 3", file) + FileCheck().run("# CHECK-COUNT-2: 2\n # CHECK-SAME: 3", file) + test_check_same() + + def test_bad_input(): + with self.assertRaisesRegex(RuntimeError, "Check for bad input"): + FileCheck().run("", "1") + + with self.assertRaisesRegex(RuntimeError, "Could not parse check"): + FileCheck().run("# CHECK1", "") + + test_bad_input() + def test_script_module_call_noscript(self): class M(torch.jit.ScriptModule): def __init__(self): diff --git a/torch/csrc/jit/init.cpp b/torch/csrc/jit/init.cpp index f46dd9f..a74edc3 100644 --- a/torch/csrc/jit/init.cpp +++ b/torch/csrc/jit/init.cpp @@ -8,6 +8,7 @@ #include #include #include +#include #include #include #include @@ -77,7 +78,6 @@ bool loadPythonClasses() { return true; } - } // anonymous namespace #if defined(_WIN32) @@ -369,6 +369,12 @@ void initJITBindings(PyObject* module) { }, py::arg("qualified_name")); + m.def("parse_ir", [](const std::string& input) { + auto graph = std::make_shared(); + script::parseIR(input, &*graph); + return graph; + }); + py::class_(m, "FunctionSchema") .def_property_readonly( "name", [](FunctionSchema& self) { return self.name(); }) @@ -486,6 +492,5 @@ void initJITBindings(PyObject* module) { initBatchTensorBindings(module); initRegisterBatchOpsBindings(module); } - } // namespace jit } // namespace torch diff --git a/torch/csrc/jit/script/init.cpp b/torch/csrc/jit/script/init.cpp index 0692763..8d1fc66 100644 --- a/torch/csrc/jit/script/init.cpp +++ b/torch/csrc/jit/script/init.cpp @@ -13,6 +13,7 @@ #include #include #include +#include #include #include #include @@ -1093,9 +1094,24 @@ void initJitScriptBindings(PyObject* module) { [](testing::FileCheck& f, const std::string& str) { return f.run(str); }) - .def("run", [](testing::FileCheck& f, const Graph& g) { - return f.run(g); - }); + .def( + "run", [](testing::FileCheck& f, const Graph& g) { return f.run(g); }) + .def( + "run", + [](testing::FileCheck& f, + const std::string& input, + const std::string& output) { return f.run(input, output); }, + "Run", + py::arg("checks_file"), + py::arg("test_file")) + .def( + "run", + [](testing::FileCheck& f, const std::string& input, const Graph& g) { + return f.run(input, g); + }, + "Run", + py::arg("checks_file"), + py::arg("graph")); } } // namespace script } // namespace jit diff --git a/torch/csrc/jit/testing/file_check.cpp b/torch/csrc/jit/testing/file_check.cpp index 3af8c5a..bcc9e90 100644 --- a/torch/csrc/jit/testing/file_check.cpp +++ b/torch/csrc/jit/testing/file_check.cpp @@ -79,10 +79,11 @@ std::ostream& operator<<(std::ostream& out, const Check& c) { }; namespace { + size_t assertFind( const SourceRange& search_range, const std::string& sub, - const Check& check) { + std::function extra_msg = nullptr) { auto pos = search_range.file_ptr()->find(sub, search_range.start()); if (pos == std::string::npos || (pos + sub.size()) > search_range.end()) { auto found_range = @@ -92,13 +93,24 @@ size_t assertFind( printQuotedString(ss, sub); ss << " but did not find it\n"; found_range.highlight(ss); - ss << "From " << check << "\n"; + if (extra_msg) { + extra_msg(ss); + } throw std::runtime_error(ss.str()); } return pos; } size_t assertFind( + const SourceRange& search_range, + const std::string& sub, + const Check& check) { + return assertFind(search_range, sub, [&](std::ostream& out) { + out << "From " << check << "\n"; + }); +} + +size_t assertFind( const std::shared_ptr& file, const std::string& sub, size_t start, @@ -123,6 +135,18 @@ void assertNotFind( throw std::runtime_error(ss.str()); } } + +size_t substringCount( + const std::shared_ptr& file, + const std::string& sub) { + size_t occurances = 0; + std::string::size_type pos = 0; + while ((pos = file->find(sub, pos)) != std::string::npos) { + ++occurances; + pos += sub.length(); + } + return occurances; +} } // namespace struct FileCheckImpl { @@ -130,35 +154,143 @@ struct FileCheckImpl { TORCH_API void run(const std::string& test_file) { has_run = true; + + if (groups.size() == 0 || groups[0].size() == 0) { + throw std::runtime_error( + "No checks have been added to this instance of" + "Filecheck! Check for bad input."); + } + doChecks(std::make_shared(test_file)); } - TORCH_API void addCheck( - CheckType type, - const std::string& s, - c10::optional count = c10::nullopt) { - Check check(type, s, std::move(count)); + TORCH_API void run( + const std::string& checks_file, + const std::string& test_file) { + auto checks_ptr = std::make_shared(checks_file); + parseStrings(checks_ptr); + run(test_file); + } + TORCH_API void addCheck(Check check) { // consecutive CHECK_DAGs & CHECK_NOTs need to be evaluated as a group - if (groups.size() == 0 || (type != CHECK_NOT && type != CHECK_DAG)) { + if (groups.size() == 0 || + (check.type_ != CHECK_NOT && check.type_ != CHECK_DAG)) { groups.push_back({check}); } else { auto& last_group = groups.back(); - if (last_group.at(0).type_ == type) { + if (last_group.at(0).type_ == check.type_) { last_group.push_back(check); } else { groups.push_back({check}); } } - has_run = false; } + TORCH_API void addCheck( + CheckType type, + const std::string& s, + c10::optional count = c10::nullopt) { + addCheck(Check(type, s, std::move(count))); + } + bool has_run = false; friend std::ostream& operator<<(std::ostream& out, const FileCheckImpl& fc); private: + bool parseSingleCheck( + const std::shared_ptr& checks_file, + size_t* start) { + const static std::vector> check_pairs = { + {CHECK, ": "}, + {CHECK_NEXT, "-NEXT: "}, + {CHECK_SAME, "-SAME: "}, + {CHECK_NOT, "-NOT: "}, + {CHECK_DAG, "-DAG: "}, + {CHECK_COUNT, "-COUNT-"}, // needs special parsing + }; + + for (const auto& check_pair : check_pairs) { + const std::string& check_suffix = check_pair.second; + auto suffix_pos = checks_file->find(check_suffix, *start); + if (suffix_pos != *start) { + continue; + } + size_t end_check_string = suffix_pos + check_suffix.size(); + CheckType type = check_pair.first; + c10::optional count = c10::nullopt; + auto end_line = checks_file->find("\n", end_check_string); + bool exactly = false; + if (type == CHECK_COUNT) { + const std::string exact = "EXACTLY-"; + if (checks_file->find(exact, end_check_string) == end_check_string) { + exactly = true; + end_check_string += exact.size(); + } + size_t end = assertFind( + SourceRange(checks_file, end_check_string, end_line), ":"); + count = std::stoll( + checks_file->substr(end_check_string, end - end_check_string)); + end_check_string = end + 2; // add ':' and the space + } + auto check = Check( + type, + checks_file->substr(end_check_string, end_line - end_check_string), + count); + addCheck(check); + if (exactly) { + addCheck(CHECK_NOT, check.search_str_); + } + *start = end_line; + return true; + } + return false; + } + + size_t findNextStart( + const std::shared_ptr& checks_file, + size_t prev_end) { + size_t start = checks_file->find("#", prev_end); + if (start == std::string::npos) { + return start; + } + start += 1; + static constexpr size_t max_whitespace = 6; + size_t i = 0; + while (start + i < checks_file->size() && i < max_whitespace) { + auto c = checks_file->at(start + i); + if (c != ' ' && c != '\t') { + break; + } + i++; + } + static const std::string check = "CHECK"; + if (checks_file->substr(start + i, check.size()) == check) { + return start + i + check.size(); + } else { + return findNextStart(checks_file, start + i + 1); + } + } + + void parseStrings(const std::shared_ptr& checks_file) { + size_t start = 0; + start = findNextStart(checks_file, 0); + while (start != std::string::npos) { + bool found_match = parseSingleCheck(checks_file, &start); + if (!found_match) { + std::ostringstream ss; + ss << "Could not parse check at:\n"; + SourceRange(checks_file, start, start + 1).highlight(ss); + ss << "Check for bad input."; + has_run = true; + throw std::runtime_error(ss.str()); + } + start = findNextStart(checks_file, start); + } + } + void doCheckNot( const std::vector& nots, const std::shared_ptr& file, @@ -277,7 +409,6 @@ struct FileCheckImpl { } std::vector checks; - std::shared_ptr check_file; std::vector> groups; }; @@ -309,6 +440,20 @@ void FileCheck::run(const Graph& graph) { fcImpl->run(graph_str.str()); }; +void FileCheck::run( + const std::string& input_checks_string, + const std::string& test_string) { + fcImpl->run(input_checks_string, test_string); +} + +void FileCheck::run( + const std::string& input_checks_string, + const Graph& graph) { + std::stringstream graph_str; + graph_str << graph; + fcImpl->run(input_checks_string, graph_str.str()); +} + FileCheck* FileCheck::check(const std::string& str) { fcImpl->addCheck(CHECK, str); return this; diff --git a/torch/csrc/jit/testing/file_check.h b/torch/csrc/jit/testing/file_check.h index cf80575..d7a7819 100644 --- a/torch/csrc/jit/testing/file_check.h +++ b/torch/csrc/jit/testing/file_check.h @@ -23,6 +23,14 @@ struct FileCheck { // Run FileCheck against dump of graph IR TORCH_API void run(const Graph& graph); + // Parsing input checks string and run against test string / dump of graph IR + TORCH_API void run( + const std::string& input_checks_string, + const std::string& test_string); + TORCH_API void run( + const std::string& input_checks_string, + const Graph& graph); + // Checks that the string occurs, starting at the end of the most recent match TORCH_API FileCheck* check(const std::string& str); -- 2.7.4