exclude_tensor_method, non_differentiable, EXCLUDE_GRADCHECK, EXCLUDE_FUNCTIONAL
from torch.testing import FileCheck
from torch._C import TensorType, TupleType, FloatType, IntType, \
- ListType, StringType, DictType
+ ListType, StringType, DictType, parse_ir
from copy import deepcopy
import random
from typing import List, Dict, Optional, Tuple
m2.sub2.a.data.zero_()
self.assertEqual(torch.zeros(2, 2), m2.forward(torch.randn(3, 2)))
+ def test_irparser(self):
+ graph_str = """graph(%0 : Double(5, 5)):
+ # CHECK: aten::relu
+ %1 : Double(5, 5) = aten::relu(%0)
+ return (%1)
+ """
+ FileCheck().run(graph_str, parse_ir(graph_str))
+
def test_filecheck(self):
def test_check():
file = "232"
with self.assertRaisesRegex(RuntimeError, 'Expected to not find "1"'):
fb.run("22 1 22")
+ def test_filecheck_parse(self):
+ def test_check():
+ file = """
+ # CHECK: 2
+ # CHECK: 3
+ # CHECK: 2
+ 232
+ """
+ FileCheck().run(checks_file=file, test_file=file)
+ file = """
+ # CHECK: 232
+ 232
+ """
+ FileCheck().run(file, "232")
+ with self.assertRaisesRegex(RuntimeError, 'Expected to find "232"'):
+ FileCheck().run(file, "22")
+ with self.assertRaisesRegex(RuntimeError, 'Expected to find "22"'):
+ FileCheck().run("# CHECK: 22", "23")
+ test_check()
+
+ def test_check_count():
+ file = "22222"
+ FileCheck().run("# CHECK-COUNT-5: 2", file)
+ FileCheck().run("# CHECK-COUNT-EXACTLY-5: 2", file)
+ FileCheck().run("# CHECK-COUNT-2: 22", file)
+ FileCheck().run("# CHECK-COUNT-1: 222", file)
+
+ with self.assertRaisesRegex(RuntimeError, 'Expected to not find'):
+ FileCheck().run("# CHECK-COUNT-EXACTLY-2: 2", file)
+ test_check_count()
+
+ def test_check_same():
+ file = "22\n33"
+ FileCheck().run("# CHECK-SAME: 22", file)
+
+ with self.assertRaisesRegex(RuntimeError, "Expected to not find"):
+ FileCheck().run("# CHECK-SAME: 33", file)
+
+ file = "22 1 3"
+
+ FileCheck().run("# CHECK: 2\n # CHECK-SAME: 3", file)
+ FileCheck().run("# CHECK-COUNT-2: 2\n # CHECK-SAME: 3", file)
+ test_check_same()
+
+ def test_bad_input():
+ with self.assertRaisesRegex(RuntimeError, "Check for bad input"):
+ FileCheck().run("", "1")
+
+ with self.assertRaisesRegex(RuntimeError, "Could not parse check"):
+ FileCheck().run("# CHECK1", "")
+
+ test_bad_input()
+
def test_script_module_call_noscript(self):
class M(torch.jit.ScriptModule):
def __init__(self):
};
namespace {
+
size_t assertFind(
const SourceRange& search_range,
const std::string& sub,
- const Check& check) {
+ std::function<void(std::ostream& out)> extra_msg = nullptr) {
auto pos = search_range.file_ptr()->find(sub, search_range.start());
if (pos == std::string::npos || (pos + sub.size()) > search_range.end()) {
auto found_range =
printQuotedString(ss, sub);
ss << " but did not find it\n";
found_range.highlight(ss);
- ss << "From " << check << "\n";
+ if (extra_msg) {
+ extra_msg(ss);
+ }
throw std::runtime_error(ss.str());
}
return pos;
}
size_t assertFind(
+ const SourceRange& search_range,
+ const std::string& sub,
+ const Check& check) {
+ return assertFind(search_range, sub, [&](std::ostream& out) {
+ out << "From " << check << "\n";
+ });
+}
+
+size_t assertFind(
const std::shared_ptr<std::string>& file,
const std::string& sub,
size_t start,
throw std::runtime_error(ss.str());
}
}
+
} // namespace
struct FileCheckImpl {
TORCH_API void run(const std::string& test_file) {
has_run = true;
+
+ if (groups.size() == 0 || groups[0].size() == 0) {
+ throw std::runtime_error(
+ "No checks have been added to this instance of"
+ "Filecheck! Check for bad input.");
+ }
+
doChecks(std::make_shared<std::string>(test_file));
}
- TORCH_API void addCheck(
- CheckType type,
- const std::string& s,
- c10::optional<size_t> count = c10::nullopt) {
- Check check(type, s, std::move(count));
+ TORCH_API void run(
+ const std::string& checks_file,
+ const std::string& test_file) {
+ auto checks_ptr = std::make_shared<std::string>(checks_file);
+ parseStrings(checks_ptr);
+ run(test_file);
+ }
+ TORCH_API void addCheck(Check check) {
// consecutive CHECK_DAGs & CHECK_NOTs need to be evaluated as a group
- if (groups.size() == 0 || (type != CHECK_NOT && type != CHECK_DAG)) {
+ if (groups.size() == 0 ||
+ (check.type_ != CHECK_NOT && check.type_ != CHECK_DAG)) {
groups.push_back({check});
} else {
auto& last_group = groups.back();
- if (last_group.at(0).type_ == type) {
+ if (last_group.at(0).type_ == check.type_) {
last_group.push_back(check);
} else {
groups.push_back({check});
}
}
-
has_run = false;
}
+ TORCH_API void addCheck(
+ CheckType type,
+ const std::string& s,
+ c10::optional<size_t> count = c10::nullopt) {
+ addCheck(Check(type, s, std::move(count)));
+ }
+
bool has_run = false;
friend std::ostream& operator<<(std::ostream& out, const FileCheckImpl& fc);
private:
+ bool parseSingleCheck(
+ const std::shared_ptr<std::string>& checks_file,
+ size_t* start) {
+ const static std::vector<std::pair<CheckType, std::string>> check_pairs = {
+ {CHECK, ": "},
+ {CHECK_NEXT, "-NEXT: "},
+ {CHECK_SAME, "-SAME: "},
+ {CHECK_NOT, "-NOT: "},
+ {CHECK_DAG, "-DAG: "},
+ {CHECK_COUNT, "-COUNT-"}, // needs special parsing
+ };
+
+ for (const auto& check_pair : check_pairs) {
+ const std::string& check_suffix = check_pair.second;
+ auto suffix_pos = checks_file->find(check_suffix, *start);
+ if (suffix_pos != *start) {
+ continue;
+ }
+ size_t end_check_string = suffix_pos + check_suffix.size();
+ CheckType type = check_pair.first;
+ c10::optional<size_t> count = c10::nullopt;
+ auto end_line = checks_file->find("\n", end_check_string);
+ bool exactly = false;
+ if (type == CHECK_COUNT) {
+ const std::string exact = "EXACTLY-";
+ if (checks_file->find(exact, end_check_string) == end_check_string) {
+ exactly = true;
+ end_check_string += exact.size();
+ }
+ size_t end = assertFind(
+ SourceRange(checks_file, end_check_string, end_line), ":");
+ count = std::stoll(
+ checks_file->substr(end_check_string, end - end_check_string));
+ end_check_string = end + 2; // add ':' and the space
+ }
+ auto check = Check(
+ type,
+ checks_file->substr(end_check_string, end_line - end_check_string),
+ count);
+ addCheck(check);
+ if (exactly) {
+ addCheck(CHECK_NOT, check.search_str_);
+ }
+ *start = end_line;
+ return true;
+ }
+ return false;
+ }
+
+ size_t findNextStart(
+ const std::shared_ptr<std::string>& checks_file,
+ size_t prev_end) {
+ size_t start = checks_file->find("#", prev_end);
+ if (start == std::string::npos) {
+ return start;
+ }
+ start += 1;
+ static constexpr size_t max_whitespace = 6;
+ size_t i = 0;
+ while (start + i < checks_file->size() && i < max_whitespace) {
+ auto c = checks_file->at(start + i);
+ if (c != ' ' && c != '\t') {
+ break;
+ }
+ i++;
+ }
+ static const std::string check = "CHECK";
+ if (checks_file->substr(start + i, check.size()) == check) {
+ return start + i + check.size();
+ } else {
+ return findNextStart(checks_file, start + i + 1);
+ }
+ }
+
+ void parseStrings(const std::shared_ptr<std::string>& checks_file) {
+ size_t start = 0;
+ start = findNextStart(checks_file, 0);
+ while (start != std::string::npos) {
+ bool found_match = parseSingleCheck(checks_file, &start);
+ if (!found_match) {
+ std::ostringstream ss;
+ ss << "Could not parse check at:\n";
+ SourceRange(checks_file, start, start + 1).highlight(ss);
+ ss << "Check for bad input.";
+ has_run = true;
+ throw std::runtime_error(ss.str());
+ }
+ start = findNextStart(checks_file, start);
+ }
+ }
+
void doCheckNot(
const std::vector<Check>& nots,
const std::shared_ptr<std::string>& file,
}
std::vector<Check> checks;
- std::shared_ptr<std::string> check_file;
std::vector<std::vector<Check>> groups;
};
fcImpl->run(graph_str.str());
};
+void FileCheck::run(
+ const std::string& input_checks_string,
+ const std::string& test_string) {
+ fcImpl->run(input_checks_string, test_string);
+}
+
+void FileCheck::run(
+ const std::string& input_checks_string,
+ const Graph& graph) {
+ std::stringstream graph_str;
+ graph_str << graph;
+ fcImpl->run(input_checks_string, graph_str.str());
+}
+
FileCheck* FileCheck::check(const std::string& str) {
fcImpl->addCheck(CHECK, str);
return this;