Lightweight String check Utility (#16858)
authorElias Ellison <eellison@fb.com>
Tue, 19 Feb 2019 20:25:30 +0000 (12:25 -0800)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Tue, 19 Feb 2019 20:31:57 +0000 (12:31 -0800)
Summary:
light weight implementation of LLVM filecheck utility. Currently only handles string matching - regexes & saving a regex to a variable name can be added as needed.

Current intended usage is through FileCheckBuilder python handle, and is shown in the tests.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/16858

Differential Revision: D14096244

Pulled By: eellison

fbshipit-source-id: c7c8d1457691c105e6ccbb3c1a378d96baac2569

setup.py
test/test_jit.py
tools/build_variables.py
torch/CMakeLists.txt
torch/csrc/jit/script/init.cpp
torch/csrc/jit/testing/file_check.cpp [new file with mode: 0644]
torch/csrc/jit/testing/file_check.h [new file with mode: 0644]
torch/testing/__init__.py

index d21567b..3318bbb 100644 (file)
--- a/setup.py
+++ b/setup.py
@@ -800,6 +800,7 @@ if __name__ == '__main__':
                 'include/torch/csrc/jit/generated/*.h',
                 'include/torch/csrc/jit/passes/*.h',
                 'include/torch/csrc/jit/script/*.h',
+                'include/torch/csrc/jit/testing/*.h',
                 'include/torch/csrc/onnx/*.h',
                 'include/torch/csrc/utils/*.h',
                 'include/pybind11/*.h',
index 143ba0c..e7f0235 100644 (file)
@@ -5537,6 +5537,108 @@ a")
         m2.sub2.a.data.zero_()
         self.assertEqual(torch.zeros(2, 2), m2.forward(torch.randn(3, 2)))
 
+    def test_filecheck(self):
+        from torch.testing import FileCheck
+
+        # def test_accidental_not_used():
+        #     def unused():
+        #         a = FileCheck()
+        #
+        #     with self.capture_stdout() as captured:
+        #         a = FileCheck()
+        #         del a
+        #     self.assertTrue("You have not run this instance of FileCheck"
+        #                     in captured[0])
+        #
+        # test_accidental_not_used()
+        def test_check():
+            file = "232"
+            FileCheck().check("2").check("3").check("2").run(file)
+            FileCheck().check("232").run(file)
+
+            with self.assertRaisesRegex(RuntimeError, 'Expected to find "22"'):
+                FileCheck().check("22").run(file)
+            with self.assertRaisesRegex(RuntimeError, "CHECK: 3"):
+                FileCheck().check("3").check("3").run(file)
+
+        test_check()
+
+        def test_check_count():
+            file = "22222"
+            FileCheck().check_count("2", 5).run(file)
+            FileCheck().check_count("22", 2).run(file)
+            FileCheck().check_count("222", 1).run(file)
+
+            with self.assertRaisesRegex(RuntimeError, 'Expected to find "22"'):
+                FileCheck().check_count("22", 3).run(file)
+
+            with self.assertRaisesRegex(RuntimeError, "CHECK-COUNT-6: 2"):
+                FileCheck().check_count("2", 6).run(file)
+
+        test_check_count()
+
+        def test_check_same():
+            file = "22\n33"
+            # FileCheck().check_same("22").run(file)
+
+            with self.assertRaisesRegex(RuntimeError, "Expected to not find"):
+                FileCheck().check_same("33").run(file)
+
+            file = "22  1  3"
+
+            FileCheck().check("2").check_same("3").run(file)
+            FileCheck().check_count("2", 2).check_same("3").run(file)
+
+        test_check_same()
+
+        def test_check_next():
+            file = "\n1\n2\n3"
+            FileCheck().check("1").check_next("2").check_next("3").run(file)
+            FileCheck().check_next("1").check_next("2").check_next("3").run(file)
+
+            with self.assertRaisesRegex(RuntimeError, "Expected to find"):
+                FileCheck().check("1").check_next("2").run("12")
+
+            with self.assertRaisesRegex(RuntimeError, "Expected to not find"):
+                FileCheck().check("1").check_next("2").run("1\n\n2")
+
+        test_check_next()
+
+        def test_check_dag():
+            fc = FileCheck().check_dag("1").check_dag("2").check_not("2")
+            fc.run("12")
+            fc.run("21")
+
+            fc = FileCheck()
+            fc.check_not("3").check_dag("1").check_dag("2").check_not("3")
+            fc.run("1 3 2")
+            fc.run("2 3 1")
+
+            fc = FileCheck().check_dag("1").check_dag("2").check("3")
+            with self.assertRaisesRegex(RuntimeError, 'Expected to find "3" but did not find it'):
+                fc.run("1 3 2")
+
+        test_check_dag()
+
+        def test_check_not():
+            FileCheck().check_not("2").check("1").run("12")
+            FileCheck().check("2").check_not("2").run("12")
+
+            with self.assertRaisesRegex(RuntimeError, 'Expected to not find "2"'):
+                FileCheck().check_not("2").check("1").run("21")
+
+            with self.assertRaisesRegex(RuntimeError, 'Expected to not find "1"'):
+                FileCheck().check("2").check_not("1").run("21")
+
+            # checks with distinct range matchings
+            fb = FileCheck().check_count("2", 2).check_count("2", 2).check_not("2")
+            with self.assertRaisesRegex(RuntimeError, 'Expected to not find "2"'):
+                fb.run("22 2 22")
+
+            fb = FileCheck().check_count("2", 2).check_not("1").check_count("2", 2)
+            with self.assertRaisesRegex(RuntimeError, 'Expected to not find "1"'):
+                fb.run("22 1 22")
+
     def test_script_module_call_noscript(self):
         class M(torch.jit.ScriptModule):
             def __init__(self):
index 9323d55..4273e1b 100644 (file)
@@ -98,6 +98,7 @@ libtorch_sources = [
     "torch/csrc/jit/script/sugared_value.cpp",
     "torch/csrc/jit/script/schema_matching.cpp",
     "torch/csrc/jit/script/parser.cpp",
+    "torch/csrc/jit/testing/file_check.cpp",
     "torch/csrc/jit/import_method.cpp",
     "torch/csrc/jit/hooks_for_testing.cpp",
     "torch/csrc/jit/script/builtin_functions.cpp",
index 4e402b9..3faf02c 100644 (file)
@@ -171,6 +171,7 @@ set(TORCH_SRCS
   ${TORCH_SRC_DIR}/csrc/jit/register_special_ops.cpp
   ${TORCH_SRC_DIR}/csrc/jit/scope.cpp
   ${TORCH_SRC_DIR}/csrc/jit/script/compiler.cpp
+  ${TORCH_SRC_DIR}/csrc/jit/testing/file_check.cpp
   ${TORCH_SRC_DIR}/csrc/jit/script/final_returns.cpp
   ${TORCH_SRC_DIR}/csrc/jit/script/schema_matching.cpp
   ${TORCH_SRC_DIR}/csrc/jit/script/type_parser.cpp
index ac44724..ca61850 100644 (file)
@@ -8,6 +8,7 @@
 #include <torch/csrc/jit/script/schema_matching.h>
 #include <torch/csrc/jit/script/sugared_value.h>
 #include <torch/csrc/jit/script/module.h>
+#include <torch/csrc/jit/testing/file_check.h>
 
 #include <torch/csrc/jit/constants.h>
 #include <torch/csrc/jit/hooks_for_testing.h>
@@ -955,6 +956,17 @@ void initJitScriptBindings(PyObject* module) {
       });
   m.def("_jit_import_methods", import_methods);
   m.def("_jit_set_emit_module_hook", setEmitModuleHook);
+
+  py::class_<testing::FileCheck>(m, "FileCheck")
+      .def(py::init<>())
+      .def("check", &testing::FileCheck::check)
+      .def("check_not", &testing::FileCheck::check_not)
+      .def("check_same", &testing::FileCheck::check_same)
+      .def("check_next", &testing::FileCheck::check_next)
+      .def("check_count", &testing::FileCheck::check_count)
+      .def("check_dag", &testing::FileCheck::check_dag)
+      .def("check_count", &testing::FileCheck::check_count)
+      .def("run", &testing::FileCheck::run);
 }
 
 } // namespace script
diff --git a/torch/csrc/jit/testing/file_check.cpp b/torch/csrc/jit/testing/file_check.cpp
new file mode 100644 (file)
index 0000000..71ac016
--- /dev/null
@@ -0,0 +1,326 @@
+//==-- llvm/Support/FileCheck.h ---------------------------*- C++ -*-==//
+//
+//                     The LLVM Compiler Infrastructure
+//
+// This file is distributed under the University of Illinois Open Source
+// License. See LICENSE.TXT for details.
+//
+//===----------------------------------------------------------------------===//
+
+// API modified from llvm::FileCheck
+
+#include <c10/util/Exception.h>
+#include <c10/util/Optional.h>
+#include <torch/csrc/WindowsTorchApiMacro.h>
+#include <torch/csrc/jit/source_range.h>
+#include <algorithm>
+#include <iostream>
+#include <sstream>
+#include <string>
+
+#include <torch/csrc/jit/passes/python_print.h>
+#include <torch/csrc/jit/testing/file_check.h>
+
+namespace torch {
+namespace jit {
+
+void printQuotedString(std::ostream& stmt, const std::string& str);
+
+namespace testing {
+
+enum CheckType {
+  CHECK,
+  CHECK_NEXT,
+  CHECK_SAME,
+  CHECK_NOT,
+  CHECK_COUNT,
+  CHECK_DAG,
+};
+
+struct Check {
+  Check(
+      CheckType type,
+      std::string str,
+      c10::optional<size_t> count = c10::nullopt)
+      : type_(type), search_str_(std::move(str)) {
+    count_ = std::move(count);
+  };
+
+  CheckType type_;
+  c10::optional<size_t> count_;
+  const std::string search_str_;
+
+  friend std::ostream& operator<<(std::ostream& out, const Check& c);
+};
+
+std::ostream& operator<<(std::ostream& out, const Check& c) {
+  switch (c.type_) {
+    case CHECK:
+      out << "CHECK";
+      break;
+    case CHECK_NEXT:
+      out << "CHECK-NEXT";
+      break;
+    case CHECK_SAME:
+      out << "CHECK-SAME";
+      break;
+    case CHECK_NOT:
+      out << "CHECK-NOT";
+      break;
+    case CHECK_DAG:
+      out << "CHECK-DAG";
+      break;
+    case CHECK_COUNT:
+      out << "CHECK-COUNT-" << *c.count_;
+      break;
+  }
+  out << ": " << c.search_str_;
+  return out;
+};
+
+namespace {
+size_t assertFind(
+    const SourceRange& search_range,
+    const std::string& sub,
+    const Check& check) {
+  auto pos = search_range.file_ptr()->find(sub, search_range.start());
+  if (pos == std::string::npos || (pos + sub.size()) > search_range.end()) {
+    auto found_range = SourceRange(search_range.file_ptr(),search_range.start(),
+        sub.size());
+    std::stringstream ss;
+    ss << "Expected to find ";
+    printQuotedString(ss, sub);
+    ss << " but did not find it\n";
+    found_range.highlight(ss);
+    ss << "From " << check << "\n";
+    throw std::runtime_error(ss.str());
+  }
+  return pos;
+}
+
+size_t assertFind(
+    const std::shared_ptr<std::string>& file,
+    const std::string& sub,
+    size_t start,
+    const Check& check) {
+  return assertFind(SourceRange(file, start, file->size()), sub, check);
+}
+
+void assertNotFind(
+    const SourceRange& search_range,
+    const std::string& sub,
+    const Check& check) {
+  auto pos = search_range.file_ptr()->find(sub, search_range.start());
+  if (pos != std::string::npos && (pos + sub.size()) <= search_range.end()) {
+    auto found_range = SourceRange(search_range.file_ptr(), pos, sub.size() + pos);
+    std::stringstream ss;
+    ss << "Expected to not find ";
+    printQuotedString(ss, sub);
+    ss << " but found it\n";
+    found_range.highlight(ss);
+    ss << "From " << check << "\n";
+    throw std::runtime_error(ss.str());
+  }
+}
+} // namespace
+
+struct FileCheckImpl {
+  TORCH_API explicit FileCheckImpl() = default;
+
+  TORCH_API void run(const std::string& test_file) {
+    has_run = true;
+    doChecks(std::make_shared<std::string>(test_file));
+  }
+
+  TORCH_API void addCheck(
+      CheckType type,
+      const std::string& s,
+      c10::optional<size_t> count = c10::nullopt) {
+    Check check(type, s, std::move(count));
+
+    // consecutive CHECK_DAGs & CHECK_NOTs need to be evaluated as a group
+    if (groups.size() == 0 || (type != CHECK_NOT && type != CHECK_DAG)) {
+      groups.push_back({check});
+    } else {
+      auto& last_group = groups.back();
+      if (last_group.at(0).type_ == type) {
+        last_group.push_back(check);
+      } else {
+        groups.push_back({check});
+      }
+    }
+
+    has_run = false;
+  }
+
+  bool has_run = false;
+
+ private:
+  void doCheckNot(
+      const std::vector<Check>& nots,
+      const std::shared_ptr<std::string>& file,
+      const SourceRange& prev,
+      const SourceRange& next) {
+    auto start = prev.end(); // inclusive
+    auto end = next.start(); // exclusive
+    if (end < start) {
+      return;
+    }
+    for (const auto& check : nots) {
+      AT_ASSERT(check.type_ == CHECK_NOT);
+      assertNotFind(SourceRange(file, start, end), check.search_str_, check);
+    }
+  }
+
+  SourceRange matchDagGroup(
+      const std::vector<Check>& group,
+      const std::shared_ptr<std::string>& test_file,
+      const SourceRange& prev) {
+    size_t group_beg = std::string::npos;
+    size_t group_end = 0;
+
+    AT_ASSERT(groups.size() != 0);
+    for (const auto& check : group) {
+      AT_ASSERT(check.type_ == group[0].type_);
+      auto pos = assertFind(test_file, check.search_str_, prev.end(), check);
+      group_beg = std::min(pos, group_beg);
+      group_end = std::max(pos + check.search_str_.size(), group_end);
+    }
+
+    return SourceRange(test_file, group_beg, group_end);
+  }
+
+  SourceRange matchGroup(
+      const std::vector<Check>& group,
+      const std::shared_ptr<std::string>& test_file,
+      const SourceRange& prev) {
+    AT_ASSERT(group.size() != 0);
+    CheckType type = group[0].type_;
+
+    if (type == CHECK_DAG) {
+      return matchDagGroup(group, test_file, prev);
+    }
+    AT_ASSERT(type != CHECK_NOT);
+    AT_ASSERT(group.size() == 1);
+
+    const auto& check = group[0];
+    size_t start_range = prev.end();
+    size_t end_range = start_range;
+
+    switch (check.type_) {
+      case CHECK: {
+        start_range =
+            assertFind(test_file, check.search_str_, start_range, check);
+        end_range = start_range + check.search_str_.size();
+      } break;
+      case CHECK_SAME: {
+        auto pos = assertFind(test_file, check.search_str_, start_range, check);
+        assertNotFind(SourceRange(test_file, prev.end(), pos), "\n", check);
+        start_range = pos;
+        end_range = pos + check.search_str_.size();
+      } break;
+      case CHECK_NEXT: {
+        auto line_end = assertFind(test_file, "\n", start_range, check);
+        auto pos =
+            assertFind(test_file, check.search_str_, line_end + 1, check);
+        assertNotFind(SourceRange(test_file, line_end + 1, pos), "\n", check);
+        start_range = pos;
+        end_range = pos + check.search_str_.size();
+      } break;
+      case CHECK_COUNT: {
+        auto group_start_range = std::string::npos;
+        AT_ASSERT(check.count_ && *check.count_ != 0);
+        for (size_t i = 0; i < *check.count_; ++i) {
+          start_range =
+              assertFind(test_file, check.search_str_, start_range, check);
+          group_start_range = std::min(start_range, group_start_range);
+          end_range = start_range + check.search_str_.size();
+          start_range = end_range;
+        }
+        start_range = group_start_range;
+      } break;
+      case CHECK_DAG: {
+        AT_ERROR();
+      } break;
+      case CHECK_NOT: {
+        AT_ERROR();
+      } break;
+    }
+    return SourceRange(test_file, start_range, end_range);
+  }
+
+  void doChecks(const std::shared_ptr<std::string>& test_file) {
+    SourceRange prev(test_file, 0, 0);
+    for (size_t i = 0; i < groups.size(); i++) {
+      const auto& curr_group = groups[i];
+      CheckType type = curr_group.at(0).type_;
+      if (type != CHECK_NOT) {
+        prev = matchGroup(curr_group, test_file, prev);
+      } else {
+        if (i + 1 < groups.size()) {
+          const auto& next_group = groups[i + 1];
+          AT_ASSERT(next_group.at(0).type_ != CHECK_NOT);
+          SourceRange after_not = matchGroup(next_group, test_file, prev);
+          doCheckNot(curr_group, test_file, prev, after_not);
+          prev = after_not;
+          ++i; // already checked the group after
+        } else {
+          SourceRange end_of_file(
+              test_file, test_file->size() + 1, test_file->size() + 1);
+          doCheckNot(curr_group, test_file, prev, end_of_file);
+        }
+      }
+    }
+  }
+
+  std::vector<Check> checks;
+  std::shared_ptr<std::string> check_file;
+  std::vector<std::vector<Check>> groups;
+};
+
+FileCheck::FileCheck() : fcImpl(new FileCheckImpl()){};
+
+FileCheck::~FileCheck() {
+  if (!fcImpl->has_run) {
+    std::cout << "You have not run this instance of FileCheck!\n";
+  }
+  fcImpl.reset();
+};
+
+void FileCheck::run(const std::string& test_file) {
+  fcImpl->run(test_file);
+};
+
+FileCheck* FileCheck::check(const std::string& str) {
+  fcImpl->addCheck(CHECK, str);
+  return this;
+}
+
+FileCheck* FileCheck::check_not(const std::string& str) {
+  fcImpl->addCheck(CHECK_NOT, str);
+  return this;
+}
+
+FileCheck* FileCheck::check_same(const std::string& str) {
+  fcImpl->addCheck(CHECK_SAME, str);
+  return this;
+}
+
+FileCheck* FileCheck::check_next(const std::string& str) {
+  fcImpl->addCheck(CHECK_NEXT, str);
+  return this;
+}
+
+FileCheck* FileCheck::check_count(const std::string& str, size_t count) {
+  fcImpl->addCheck(CHECK_COUNT, str, count);
+  return this;
+}
+
+FileCheck* FileCheck::check_dag(const std::string& str) {
+  fcImpl->addCheck(CHECK_DAG, str);
+  return this;
+}
+
+} // namespace testing
+} // namespace jit
+} // namespace torch
diff --git a/torch/csrc/jit/testing/file_check.h b/torch/csrc/jit/testing/file_check.h
new file mode 100644 (file)
index 0000000..fe4d46a
--- /dev/null
@@ -0,0 +1,52 @@
+#pragma once
+
+#include <torch/csrc/WindowsTorchApiMacro.h>
+#include <torch/csrc/jit/testing/file_check.h>
+
+namespace torch {
+namespace jit {
+namespace testing {
+
+struct FileCheckImpl;
+
+struct FileCheck {
+ public:
+  TORCH_API explicit FileCheck();
+  TORCH_API ~FileCheck();
+
+  // Run FileCheck against test string
+  TORCH_API void run(const std::string& test_string);
+
+  // Checks that the string occurs, starting at the end of the most recent match
+  TORCH_API FileCheck* check(const std::string& str);
+
+  // Checks that the string does not occur between the previous match and next
+  // match Consecutive check_nots test against the same previous match and next
+  // match
+  TORCH_API FileCheck* check_not(const std::string& str);
+
+  // Checks that the string occurs on the same line as the previous match
+  TORCH_API FileCheck* check_same(const std::string& str);
+
+  // Checks that the string occurs on the line immediately following the
+  // previous match
+  TORCH_API FileCheck* check_next(const std::string& str);
+
+  // Checks that the string occurs count number of times
+  TORCH_API FileCheck* check_count(const std::string& str, size_t count);
+
+  // A series of consecutive check_dags get turned into a group of checks
+  // which can appear in any order relative to each other.
+  TORCH_API FileCheck* check_dag(const std::string& str);
+
+  // reset checks
+  TORCH_API void reset();
+
+ private:
+  bool has_run = false;
+  std::unique_ptr<FileCheckImpl> fcImpl;
+};
+
+} // namespace testing
+} // namespace jit
+} // namespace torch
index 250a602..eb2c920 100644 (file)
@@ -5,6 +5,8 @@ The testing package contains testing-specific utilities.
 import torch
 import random
 
+FileCheck = torch._C.FileCheck
+
 __all__ = [
     'assert_allclose', 'make_non_contiguous', 'rand_like', 'randn_like'
 ]