m2.sub2.a.data.zero_()
self.assertEqual(torch.zeros(2, 2), m2.forward(torch.randn(3, 2)))
+ def test_filecheck(self):
+ from torch.testing import FileCheck
+
+ # def test_accidental_not_used():
+ # def unused():
+ # a = FileCheck()
+ #
+ # with self.capture_stdout() as captured:
+ # a = FileCheck()
+ # del a
+ # self.assertTrue("You have not run this instance of FileCheck"
+ # in captured[0])
+ #
+ # test_accidental_not_used()
+ def test_check():
+ file = "232"
+ FileCheck().check("2").check("3").check("2").run(file)
+ FileCheck().check("232").run(file)
+
+ with self.assertRaisesRegex(RuntimeError, 'Expected to find "22"'):
+ FileCheck().check("22").run(file)
+ with self.assertRaisesRegex(RuntimeError, "CHECK: 3"):
+ FileCheck().check("3").check("3").run(file)
+
+ test_check()
+
+ def test_check_count():
+ file = "22222"
+ FileCheck().check_count("2", 5).run(file)
+ FileCheck().check_count("22", 2).run(file)
+ FileCheck().check_count("222", 1).run(file)
+
+ with self.assertRaisesRegex(RuntimeError, 'Expected to find "22"'):
+ FileCheck().check_count("22", 3).run(file)
+
+ with self.assertRaisesRegex(RuntimeError, "CHECK-COUNT-6: 2"):
+ FileCheck().check_count("2", 6).run(file)
+
+ test_check_count()
+
+ def test_check_same():
+ file = "22\n33"
+ # FileCheck().check_same("22").run(file)
+
+ with self.assertRaisesRegex(RuntimeError, "Expected to not find"):
+ FileCheck().check_same("33").run(file)
+
+ file = "22 1 3"
+
+ FileCheck().check("2").check_same("3").run(file)
+ FileCheck().check_count("2", 2).check_same("3").run(file)
+
+ test_check_same()
+
+ def test_check_next():
+ file = "\n1\n2\n3"
+ FileCheck().check("1").check_next("2").check_next("3").run(file)
+ FileCheck().check_next("1").check_next("2").check_next("3").run(file)
+
+ with self.assertRaisesRegex(RuntimeError, "Expected to find"):
+ FileCheck().check("1").check_next("2").run("12")
+
+ with self.assertRaisesRegex(RuntimeError, "Expected to not find"):
+ FileCheck().check("1").check_next("2").run("1\n\n2")
+
+ test_check_next()
+
+ def test_check_dag():
+ fc = FileCheck().check_dag("1").check_dag("2").check_not("2")
+ fc.run("12")
+ fc.run("21")
+
+ fc = FileCheck()
+ fc.check_not("3").check_dag("1").check_dag("2").check_not("3")
+ fc.run("1 3 2")
+ fc.run("2 3 1")
+
+ fc = FileCheck().check_dag("1").check_dag("2").check("3")
+ with self.assertRaisesRegex(RuntimeError, 'Expected to find "3" but did not find it'):
+ fc.run("1 3 2")
+
+ test_check_dag()
+
+ def test_check_not():
+ FileCheck().check_not("2").check("1").run("12")
+ FileCheck().check("2").check_not("2").run("12")
+
+ with self.assertRaisesRegex(RuntimeError, 'Expected to not find "2"'):
+ FileCheck().check_not("2").check("1").run("21")
+
+ with self.assertRaisesRegex(RuntimeError, 'Expected to not find "1"'):
+ FileCheck().check("2").check_not("1").run("21")
+
+ # checks with distinct range matchings
+ fb = FileCheck().check_count("2", 2).check_count("2", 2).check_not("2")
+ with self.assertRaisesRegex(RuntimeError, 'Expected to not find "2"'):
+ fb.run("22 2 22")
+
+ fb = FileCheck().check_count("2", 2).check_not("1").check_count("2", 2)
+ with self.assertRaisesRegex(RuntimeError, 'Expected to not find "1"'):
+ fb.run("22 1 22")
+
def test_script_module_call_noscript(self):
class M(torch.jit.ScriptModule):
def __init__(self):
--- /dev/null
+//==-- llvm/Support/FileCheck.h ---------------------------*- C++ -*-==//
+//
+// The LLVM Compiler Infrastructure
+//
+// This file is distributed under the University of Illinois Open Source
+// License. See LICENSE.TXT for details.
+//
+//===----------------------------------------------------------------------===//
+
+// API modified from llvm::FileCheck
+
+#include <c10/util/Exception.h>
+#include <c10/util/Optional.h>
+#include <torch/csrc/WindowsTorchApiMacro.h>
+#include <torch/csrc/jit/source_range.h>
+#include <algorithm>
+#include <iostream>
+#include <sstream>
+#include <string>
+
+#include <torch/csrc/jit/passes/python_print.h>
+#include <torch/csrc/jit/testing/file_check.h>
+
+namespace torch {
+namespace jit {
+
+void printQuotedString(std::ostream& stmt, const std::string& str);
+
+namespace testing {
+
+enum CheckType {
+ CHECK,
+ CHECK_NEXT,
+ CHECK_SAME,
+ CHECK_NOT,
+ CHECK_COUNT,
+ CHECK_DAG,
+};
+
+struct Check {
+ Check(
+ CheckType type,
+ std::string str,
+ c10::optional<size_t> count = c10::nullopt)
+ : type_(type), search_str_(std::move(str)) {
+ count_ = std::move(count);
+ };
+
+ CheckType type_;
+ c10::optional<size_t> count_;
+ const std::string search_str_;
+
+ friend std::ostream& operator<<(std::ostream& out, const Check& c);
+};
+
+std::ostream& operator<<(std::ostream& out, const Check& c) {
+ switch (c.type_) {
+ case CHECK:
+ out << "CHECK";
+ break;
+ case CHECK_NEXT:
+ out << "CHECK-NEXT";
+ break;
+ case CHECK_SAME:
+ out << "CHECK-SAME";
+ break;
+ case CHECK_NOT:
+ out << "CHECK-NOT";
+ break;
+ case CHECK_DAG:
+ out << "CHECK-DAG";
+ break;
+ case CHECK_COUNT:
+ out << "CHECK-COUNT-" << *c.count_;
+ break;
+ }
+ out << ": " << c.search_str_;
+ return out;
+};
+
+namespace {
+size_t assertFind(
+ const SourceRange& search_range,
+ const std::string& sub,
+ const Check& check) {
+ auto pos = search_range.file_ptr()->find(sub, search_range.start());
+ if (pos == std::string::npos || (pos + sub.size()) > search_range.end()) {
+ auto found_range = SourceRange(search_range.file_ptr(),search_range.start(),
+ sub.size());
+ std::stringstream ss;
+ ss << "Expected to find ";
+ printQuotedString(ss, sub);
+ ss << " but did not find it\n";
+ found_range.highlight(ss);
+ ss << "From " << check << "\n";
+ throw std::runtime_error(ss.str());
+ }
+ return pos;
+}
+
+size_t assertFind(
+ const std::shared_ptr<std::string>& file,
+ const std::string& sub,
+ size_t start,
+ const Check& check) {
+ return assertFind(SourceRange(file, start, file->size()), sub, check);
+}
+
+void assertNotFind(
+ const SourceRange& search_range,
+ const std::string& sub,
+ const Check& check) {
+ auto pos = search_range.file_ptr()->find(sub, search_range.start());
+ if (pos != std::string::npos && (pos + sub.size()) <= search_range.end()) {
+ auto found_range = SourceRange(search_range.file_ptr(), pos, sub.size() + pos);
+ std::stringstream ss;
+ ss << "Expected to not find ";
+ printQuotedString(ss, sub);
+ ss << " but found it\n";
+ found_range.highlight(ss);
+ ss << "From " << check << "\n";
+ throw std::runtime_error(ss.str());
+ }
+}
+} // namespace
+
+struct FileCheckImpl {
+ TORCH_API explicit FileCheckImpl() = default;
+
+ TORCH_API void run(const std::string& test_file) {
+ has_run = true;
+ doChecks(std::make_shared<std::string>(test_file));
+ }
+
+ TORCH_API void addCheck(
+ CheckType type,
+ const std::string& s,
+ c10::optional<size_t> count = c10::nullopt) {
+ Check check(type, s, std::move(count));
+
+ // consecutive CHECK_DAGs & CHECK_NOTs need to be evaluated as a group
+ if (groups.size() == 0 || (type != CHECK_NOT && type != CHECK_DAG)) {
+ groups.push_back({check});
+ } else {
+ auto& last_group = groups.back();
+ if (last_group.at(0).type_ == type) {
+ last_group.push_back(check);
+ } else {
+ groups.push_back({check});
+ }
+ }
+
+ has_run = false;
+ }
+
+ bool has_run = false;
+
+ private:
+ void doCheckNot(
+ const std::vector<Check>& nots,
+ const std::shared_ptr<std::string>& file,
+ const SourceRange& prev,
+ const SourceRange& next) {
+ auto start = prev.end(); // inclusive
+ auto end = next.start(); // exclusive
+ if (end < start) {
+ return;
+ }
+ for (const auto& check : nots) {
+ AT_ASSERT(check.type_ == CHECK_NOT);
+ assertNotFind(SourceRange(file, start, end), check.search_str_, check);
+ }
+ }
+
+ SourceRange matchDagGroup(
+ const std::vector<Check>& group,
+ const std::shared_ptr<std::string>& test_file,
+ const SourceRange& prev) {
+ size_t group_beg = std::string::npos;
+ size_t group_end = 0;
+
+ AT_ASSERT(groups.size() != 0);
+ for (const auto& check : group) {
+ AT_ASSERT(check.type_ == group[0].type_);
+ auto pos = assertFind(test_file, check.search_str_, prev.end(), check);
+ group_beg = std::min(pos, group_beg);
+ group_end = std::max(pos + check.search_str_.size(), group_end);
+ }
+
+ return SourceRange(test_file, group_beg, group_end);
+ }
+
+ SourceRange matchGroup(
+ const std::vector<Check>& group,
+ const std::shared_ptr<std::string>& test_file,
+ const SourceRange& prev) {
+ AT_ASSERT(group.size() != 0);
+ CheckType type = group[0].type_;
+
+ if (type == CHECK_DAG) {
+ return matchDagGroup(group, test_file, prev);
+ }
+ AT_ASSERT(type != CHECK_NOT);
+ AT_ASSERT(group.size() == 1);
+
+ const auto& check = group[0];
+ size_t start_range = prev.end();
+ size_t end_range = start_range;
+
+ switch (check.type_) {
+ case CHECK: {
+ start_range =
+ assertFind(test_file, check.search_str_, start_range, check);
+ end_range = start_range + check.search_str_.size();
+ } break;
+ case CHECK_SAME: {
+ auto pos = assertFind(test_file, check.search_str_, start_range, check);
+ assertNotFind(SourceRange(test_file, prev.end(), pos), "\n", check);
+ start_range = pos;
+ end_range = pos + check.search_str_.size();
+ } break;
+ case CHECK_NEXT: {
+ auto line_end = assertFind(test_file, "\n", start_range, check);
+ auto pos =
+ assertFind(test_file, check.search_str_, line_end + 1, check);
+ assertNotFind(SourceRange(test_file, line_end + 1, pos), "\n", check);
+ start_range = pos;
+ end_range = pos + check.search_str_.size();
+ } break;
+ case CHECK_COUNT: {
+ auto group_start_range = std::string::npos;
+ AT_ASSERT(check.count_ && *check.count_ != 0);
+ for (size_t i = 0; i < *check.count_; ++i) {
+ start_range =
+ assertFind(test_file, check.search_str_, start_range, check);
+ group_start_range = std::min(start_range, group_start_range);
+ end_range = start_range + check.search_str_.size();
+ start_range = end_range;
+ }
+ start_range = group_start_range;
+ } break;
+ case CHECK_DAG: {
+ AT_ERROR();
+ } break;
+ case CHECK_NOT: {
+ AT_ERROR();
+ } break;
+ }
+ return SourceRange(test_file, start_range, end_range);
+ }
+
+ void doChecks(const std::shared_ptr<std::string>& test_file) {
+ SourceRange prev(test_file, 0, 0);
+ for (size_t i = 0; i < groups.size(); i++) {
+ const auto& curr_group = groups[i];
+ CheckType type = curr_group.at(0).type_;
+ if (type != CHECK_NOT) {
+ prev = matchGroup(curr_group, test_file, prev);
+ } else {
+ if (i + 1 < groups.size()) {
+ const auto& next_group = groups[i + 1];
+ AT_ASSERT(next_group.at(0).type_ != CHECK_NOT);
+ SourceRange after_not = matchGroup(next_group, test_file, prev);
+ doCheckNot(curr_group, test_file, prev, after_not);
+ prev = after_not;
+ ++i; // already checked the group after
+ } else {
+ SourceRange end_of_file(
+ test_file, test_file->size() + 1, test_file->size() + 1);
+ doCheckNot(curr_group, test_file, prev, end_of_file);
+ }
+ }
+ }
+ }
+
+ std::vector<Check> checks;
+ std::shared_ptr<std::string> check_file;
+ std::vector<std::vector<Check>> groups;
+};
+
+FileCheck::FileCheck() : fcImpl(new FileCheckImpl()){};
+
+FileCheck::~FileCheck() {
+ if (!fcImpl->has_run) {
+ std::cout << "You have not run this instance of FileCheck!\n";
+ }
+ fcImpl.reset();
+};
+
+void FileCheck::run(const std::string& test_file) {
+ fcImpl->run(test_file);
+};
+
+FileCheck* FileCheck::check(const std::string& str) {
+ fcImpl->addCheck(CHECK, str);
+ return this;
+}
+
+FileCheck* FileCheck::check_not(const std::string& str) {
+ fcImpl->addCheck(CHECK_NOT, str);
+ return this;
+}
+
+FileCheck* FileCheck::check_same(const std::string& str) {
+ fcImpl->addCheck(CHECK_SAME, str);
+ return this;
+}
+
+FileCheck* FileCheck::check_next(const std::string& str) {
+ fcImpl->addCheck(CHECK_NEXT, str);
+ return this;
+}
+
+FileCheck* FileCheck::check_count(const std::string& str, size_t count) {
+ fcImpl->addCheck(CHECK_COUNT, str, count);
+ return this;
+}
+
+FileCheck* FileCheck::check_dag(const std::string& str) {
+ fcImpl->addCheck(CHECK_DAG, str);
+ return this;
+}
+
+} // namespace testing
+} // namespace jit
+} // namespace torch