From: Elias Ellison Date: Tue, 19 Feb 2019 20:25:30 +0000 (-0800) Subject: Lightweight String check Utility (#16858) X-Git-Tag: accepted/tizen/6.5/unified/20211028.231830~1212 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=89df22e57b9be274174ec2660e8707f0e0a1fa4b;p=platform%2Fupstream%2Fpytorch.git Lightweight String check Utility (#16858) Summary: light weight implementation of LLVM filecheck utility. Currently only handles string matching - regexes & saving a regex to a variable name can be added as needed. Current intended usage is through FileCheckBuilder python handle, and is shown in the tests. Pull Request resolved: https://github.com/pytorch/pytorch/pull/16858 Differential Revision: D14096244 Pulled By: eellison fbshipit-source-id: c7c8d1457691c105e6ccbb3c1a378d96baac2569 --- diff --git a/setup.py b/setup.py index d21567b..3318bbb 100644 --- a/setup.py +++ b/setup.py @@ -800,6 +800,7 @@ if __name__ == '__main__': 'include/torch/csrc/jit/generated/*.h', 'include/torch/csrc/jit/passes/*.h', 'include/torch/csrc/jit/script/*.h', + 'include/torch/csrc/jit/testing/*.h', 'include/torch/csrc/onnx/*.h', 'include/torch/csrc/utils/*.h', 'include/pybind11/*.h', diff --git a/test/test_jit.py b/test/test_jit.py index 143ba0c..e7f0235 100644 --- a/test/test_jit.py +++ b/test/test_jit.py @@ -5537,6 +5537,108 @@ a") m2.sub2.a.data.zero_() self.assertEqual(torch.zeros(2, 2), m2.forward(torch.randn(3, 2))) + def test_filecheck(self): + from torch.testing import FileCheck + + # def test_accidental_not_used(): + # def unused(): + # a = FileCheck() + # + # with self.capture_stdout() as captured: + # a = FileCheck() + # del a + # self.assertTrue("You have not run this instance of FileCheck" + # in captured[0]) + # + # test_accidental_not_used() + def test_check(): + file = "232" + FileCheck().check("2").check("3").check("2").run(file) + FileCheck().check("232").run(file) + + with self.assertRaisesRegex(RuntimeError, 'Expected to find "22"'): + FileCheck().check("22").run(file) + with self.assertRaisesRegex(RuntimeError, "CHECK: 3"): + FileCheck().check("3").check("3").run(file) + + test_check() + + def test_check_count(): + file = "22222" + FileCheck().check_count("2", 5).run(file) + FileCheck().check_count("22", 2).run(file) + FileCheck().check_count("222", 1).run(file) + + with self.assertRaisesRegex(RuntimeError, 'Expected to find "22"'): + FileCheck().check_count("22", 3).run(file) + + with self.assertRaisesRegex(RuntimeError, "CHECK-COUNT-6: 2"): + FileCheck().check_count("2", 6).run(file) + + test_check_count() + + def test_check_same(): + file = "22\n33" + # FileCheck().check_same("22").run(file) + + with self.assertRaisesRegex(RuntimeError, "Expected to not find"): + FileCheck().check_same("33").run(file) + + file = "22 1 3" + + FileCheck().check("2").check_same("3").run(file) + FileCheck().check_count("2", 2).check_same("3").run(file) + + test_check_same() + + def test_check_next(): + file = "\n1\n2\n3" + FileCheck().check("1").check_next("2").check_next("3").run(file) + FileCheck().check_next("1").check_next("2").check_next("3").run(file) + + with self.assertRaisesRegex(RuntimeError, "Expected to find"): + FileCheck().check("1").check_next("2").run("12") + + with self.assertRaisesRegex(RuntimeError, "Expected to not find"): + FileCheck().check("1").check_next("2").run("1\n\n2") + + test_check_next() + + def test_check_dag(): + fc = FileCheck().check_dag("1").check_dag("2").check_not("2") + fc.run("12") + fc.run("21") + + fc = FileCheck() + fc.check_not("3").check_dag("1").check_dag("2").check_not("3") + fc.run("1 3 2") + fc.run("2 3 1") + + fc = FileCheck().check_dag("1").check_dag("2").check("3") + with self.assertRaisesRegex(RuntimeError, 'Expected to find "3" but did not find it'): + fc.run("1 3 2") + + test_check_dag() + + def test_check_not(): + FileCheck().check_not("2").check("1").run("12") + FileCheck().check("2").check_not("2").run("12") + + with self.assertRaisesRegex(RuntimeError, 'Expected to not find "2"'): + FileCheck().check_not("2").check("1").run("21") + + with self.assertRaisesRegex(RuntimeError, 'Expected to not find "1"'): + FileCheck().check("2").check_not("1").run("21") + + # checks with distinct range matchings + fb = FileCheck().check_count("2", 2).check_count("2", 2).check_not("2") + with self.assertRaisesRegex(RuntimeError, 'Expected to not find "2"'): + fb.run("22 2 22") + + fb = FileCheck().check_count("2", 2).check_not("1").check_count("2", 2) + with self.assertRaisesRegex(RuntimeError, 'Expected to not find "1"'): + fb.run("22 1 22") + def test_script_module_call_noscript(self): class M(torch.jit.ScriptModule): def __init__(self): diff --git a/tools/build_variables.py b/tools/build_variables.py index 9323d55..4273e1b 100644 --- a/tools/build_variables.py +++ b/tools/build_variables.py @@ -98,6 +98,7 @@ libtorch_sources = [ "torch/csrc/jit/script/sugared_value.cpp", "torch/csrc/jit/script/schema_matching.cpp", "torch/csrc/jit/script/parser.cpp", + "torch/csrc/jit/testing/file_check.cpp", "torch/csrc/jit/import_method.cpp", "torch/csrc/jit/hooks_for_testing.cpp", "torch/csrc/jit/script/builtin_functions.cpp", diff --git a/torch/CMakeLists.txt b/torch/CMakeLists.txt index 4e402b9..3faf02c 100644 --- a/torch/CMakeLists.txt +++ b/torch/CMakeLists.txt @@ -171,6 +171,7 @@ set(TORCH_SRCS ${TORCH_SRC_DIR}/csrc/jit/register_special_ops.cpp ${TORCH_SRC_DIR}/csrc/jit/scope.cpp ${TORCH_SRC_DIR}/csrc/jit/script/compiler.cpp + ${TORCH_SRC_DIR}/csrc/jit/testing/file_check.cpp ${TORCH_SRC_DIR}/csrc/jit/script/final_returns.cpp ${TORCH_SRC_DIR}/csrc/jit/script/schema_matching.cpp ${TORCH_SRC_DIR}/csrc/jit/script/type_parser.cpp diff --git a/torch/csrc/jit/script/init.cpp b/torch/csrc/jit/script/init.cpp index ac44724..ca61850 100644 --- a/torch/csrc/jit/script/init.cpp +++ b/torch/csrc/jit/script/init.cpp @@ -8,6 +8,7 @@ #include #include #include +#include #include #include @@ -955,6 +956,17 @@ void initJitScriptBindings(PyObject* module) { }); m.def("_jit_import_methods", import_methods); m.def("_jit_set_emit_module_hook", setEmitModuleHook); + + py::class_(m, "FileCheck") + .def(py::init<>()) + .def("check", &testing::FileCheck::check) + .def("check_not", &testing::FileCheck::check_not) + .def("check_same", &testing::FileCheck::check_same) + .def("check_next", &testing::FileCheck::check_next) + .def("check_count", &testing::FileCheck::check_count) + .def("check_dag", &testing::FileCheck::check_dag) + .def("check_count", &testing::FileCheck::check_count) + .def("run", &testing::FileCheck::run); } } // namespace script diff --git a/torch/csrc/jit/testing/file_check.cpp b/torch/csrc/jit/testing/file_check.cpp new file mode 100644 index 0000000..71ac016 --- /dev/null +++ b/torch/csrc/jit/testing/file_check.cpp @@ -0,0 +1,326 @@ +//==-- llvm/Support/FileCheck.h ---------------------------*- C++ -*-==// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// + +// API modified from llvm::FileCheck + +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +namespace torch { +namespace jit { + +void printQuotedString(std::ostream& stmt, const std::string& str); + +namespace testing { + +enum CheckType { + CHECK, + CHECK_NEXT, + CHECK_SAME, + CHECK_NOT, + CHECK_COUNT, + CHECK_DAG, +}; + +struct Check { + Check( + CheckType type, + std::string str, + c10::optional count = c10::nullopt) + : type_(type), search_str_(std::move(str)) { + count_ = std::move(count); + }; + + CheckType type_; + c10::optional count_; + const std::string search_str_; + + friend std::ostream& operator<<(std::ostream& out, const Check& c); +}; + +std::ostream& operator<<(std::ostream& out, const Check& c) { + switch (c.type_) { + case CHECK: + out << "CHECK"; + break; + case CHECK_NEXT: + out << "CHECK-NEXT"; + break; + case CHECK_SAME: + out << "CHECK-SAME"; + break; + case CHECK_NOT: + out << "CHECK-NOT"; + break; + case CHECK_DAG: + out << "CHECK-DAG"; + break; + case CHECK_COUNT: + out << "CHECK-COUNT-" << *c.count_; + break; + } + out << ": " << c.search_str_; + return out; +}; + +namespace { +size_t assertFind( + const SourceRange& search_range, + const std::string& sub, + const Check& check) { + auto pos = search_range.file_ptr()->find(sub, search_range.start()); + if (pos == std::string::npos || (pos + sub.size()) > search_range.end()) { + auto found_range = SourceRange(search_range.file_ptr(),search_range.start(), + sub.size()); + std::stringstream ss; + ss << "Expected to find "; + printQuotedString(ss, sub); + ss << " but did not find it\n"; + found_range.highlight(ss); + ss << "From " << check << "\n"; + throw std::runtime_error(ss.str()); + } + return pos; +} + +size_t assertFind( + const std::shared_ptr& file, + const std::string& sub, + size_t start, + const Check& check) { + return assertFind(SourceRange(file, start, file->size()), sub, check); +} + +void assertNotFind( + const SourceRange& search_range, + const std::string& sub, + const Check& check) { + auto pos = search_range.file_ptr()->find(sub, search_range.start()); + if (pos != std::string::npos && (pos + sub.size()) <= search_range.end()) { + auto found_range = SourceRange(search_range.file_ptr(), pos, sub.size() + pos); + std::stringstream ss; + ss << "Expected to not find "; + printQuotedString(ss, sub); + ss << " but found it\n"; + found_range.highlight(ss); + ss << "From " << check << "\n"; + throw std::runtime_error(ss.str()); + } +} +} // namespace + +struct FileCheckImpl { + TORCH_API explicit FileCheckImpl() = default; + + TORCH_API void run(const std::string& test_file) { + has_run = true; + doChecks(std::make_shared(test_file)); + } + + TORCH_API void addCheck( + CheckType type, + const std::string& s, + c10::optional count = c10::nullopt) { + Check check(type, s, std::move(count)); + + // consecutive CHECK_DAGs & CHECK_NOTs need to be evaluated as a group + if (groups.size() == 0 || (type != CHECK_NOT && type != CHECK_DAG)) { + groups.push_back({check}); + } else { + auto& last_group = groups.back(); + if (last_group.at(0).type_ == type) { + last_group.push_back(check); + } else { + groups.push_back({check}); + } + } + + has_run = false; + } + + bool has_run = false; + + private: + void doCheckNot( + const std::vector& nots, + const std::shared_ptr& file, + const SourceRange& prev, + const SourceRange& next) { + auto start = prev.end(); // inclusive + auto end = next.start(); // exclusive + if (end < start) { + return; + } + for (const auto& check : nots) { + AT_ASSERT(check.type_ == CHECK_NOT); + assertNotFind(SourceRange(file, start, end), check.search_str_, check); + } + } + + SourceRange matchDagGroup( + const std::vector& group, + const std::shared_ptr& test_file, + const SourceRange& prev) { + size_t group_beg = std::string::npos; + size_t group_end = 0; + + AT_ASSERT(groups.size() != 0); + for (const auto& check : group) { + AT_ASSERT(check.type_ == group[0].type_); + auto pos = assertFind(test_file, check.search_str_, prev.end(), check); + group_beg = std::min(pos, group_beg); + group_end = std::max(pos + check.search_str_.size(), group_end); + } + + return SourceRange(test_file, group_beg, group_end); + } + + SourceRange matchGroup( + const std::vector& group, + const std::shared_ptr& test_file, + const SourceRange& prev) { + AT_ASSERT(group.size() != 0); + CheckType type = group[0].type_; + + if (type == CHECK_DAG) { + return matchDagGroup(group, test_file, prev); + } + AT_ASSERT(type != CHECK_NOT); + AT_ASSERT(group.size() == 1); + + const auto& check = group[0]; + size_t start_range = prev.end(); + size_t end_range = start_range; + + switch (check.type_) { + case CHECK: { + start_range = + assertFind(test_file, check.search_str_, start_range, check); + end_range = start_range + check.search_str_.size(); + } break; + case CHECK_SAME: { + auto pos = assertFind(test_file, check.search_str_, start_range, check); + assertNotFind(SourceRange(test_file, prev.end(), pos), "\n", check); + start_range = pos; + end_range = pos + check.search_str_.size(); + } break; + case CHECK_NEXT: { + auto line_end = assertFind(test_file, "\n", start_range, check); + auto pos = + assertFind(test_file, check.search_str_, line_end + 1, check); + assertNotFind(SourceRange(test_file, line_end + 1, pos), "\n", check); + start_range = pos; + end_range = pos + check.search_str_.size(); + } break; + case CHECK_COUNT: { + auto group_start_range = std::string::npos; + AT_ASSERT(check.count_ && *check.count_ != 0); + for (size_t i = 0; i < *check.count_; ++i) { + start_range = + assertFind(test_file, check.search_str_, start_range, check); + group_start_range = std::min(start_range, group_start_range); + end_range = start_range + check.search_str_.size(); + start_range = end_range; + } + start_range = group_start_range; + } break; + case CHECK_DAG: { + AT_ERROR(); + } break; + case CHECK_NOT: { + AT_ERROR(); + } break; + } + return SourceRange(test_file, start_range, end_range); + } + + void doChecks(const std::shared_ptr& test_file) { + SourceRange prev(test_file, 0, 0); + for (size_t i = 0; i < groups.size(); i++) { + const auto& curr_group = groups[i]; + CheckType type = curr_group.at(0).type_; + if (type != CHECK_NOT) { + prev = matchGroup(curr_group, test_file, prev); + } else { + if (i + 1 < groups.size()) { + const auto& next_group = groups[i + 1]; + AT_ASSERT(next_group.at(0).type_ != CHECK_NOT); + SourceRange after_not = matchGroup(next_group, test_file, prev); + doCheckNot(curr_group, test_file, prev, after_not); + prev = after_not; + ++i; // already checked the group after + } else { + SourceRange end_of_file( + test_file, test_file->size() + 1, test_file->size() + 1); + doCheckNot(curr_group, test_file, prev, end_of_file); + } + } + } + } + + std::vector checks; + std::shared_ptr check_file; + std::vector> groups; +}; + +FileCheck::FileCheck() : fcImpl(new FileCheckImpl()){}; + +FileCheck::~FileCheck() { + if (!fcImpl->has_run) { + std::cout << "You have not run this instance of FileCheck!\n"; + } + fcImpl.reset(); +}; + +void FileCheck::run(const std::string& test_file) { + fcImpl->run(test_file); +}; + +FileCheck* FileCheck::check(const std::string& str) { + fcImpl->addCheck(CHECK, str); + return this; +} + +FileCheck* FileCheck::check_not(const std::string& str) { + fcImpl->addCheck(CHECK_NOT, str); + return this; +} + +FileCheck* FileCheck::check_same(const std::string& str) { + fcImpl->addCheck(CHECK_SAME, str); + return this; +} + +FileCheck* FileCheck::check_next(const std::string& str) { + fcImpl->addCheck(CHECK_NEXT, str); + return this; +} + +FileCheck* FileCheck::check_count(const std::string& str, size_t count) { + fcImpl->addCheck(CHECK_COUNT, str, count); + return this; +} + +FileCheck* FileCheck::check_dag(const std::string& str) { + fcImpl->addCheck(CHECK_DAG, str); + return this; +} + +} // namespace testing +} // namespace jit +} // namespace torch diff --git a/torch/csrc/jit/testing/file_check.h b/torch/csrc/jit/testing/file_check.h new file mode 100644 index 0000000..fe4d46a --- /dev/null +++ b/torch/csrc/jit/testing/file_check.h @@ -0,0 +1,52 @@ +#pragma once + +#include +#include + +namespace torch { +namespace jit { +namespace testing { + +struct FileCheckImpl; + +struct FileCheck { + public: + TORCH_API explicit FileCheck(); + TORCH_API ~FileCheck(); + + // Run FileCheck against test string + TORCH_API void run(const std::string& test_string); + + // Checks that the string occurs, starting at the end of the most recent match + TORCH_API FileCheck* check(const std::string& str); + + // Checks that the string does not occur between the previous match and next + // match Consecutive check_nots test against the same previous match and next + // match + TORCH_API FileCheck* check_not(const std::string& str); + + // Checks that the string occurs on the same line as the previous match + TORCH_API FileCheck* check_same(const std::string& str); + + // Checks that the string occurs on the line immediately following the + // previous match + TORCH_API FileCheck* check_next(const std::string& str); + + // Checks that the string occurs count number of times + TORCH_API FileCheck* check_count(const std::string& str, size_t count); + + // A series of consecutive check_dags get turned into a group of checks + // which can appear in any order relative to each other. + TORCH_API FileCheck* check_dag(const std::string& str); + + // reset checks + TORCH_API void reset(); + + private: + bool has_run = false; + std::unique_ptr fcImpl; +}; + +} // namespace testing +} // namespace jit +} // namespace torch diff --git a/torch/testing/__init__.py b/torch/testing/__init__.py index 250a602..eb2c920 100644 --- a/torch/testing/__init__.py +++ b/torch/testing/__init__.py @@ -5,6 +5,8 @@ The testing package contains testing-specific utilities. import torch import random +FileCheck = torch._C.FileCheck + __all__ = [ 'assert_allclose', 'make_non_contiguous', 'rand_like', 'randn_like' ]