exclude_tensor_method, non_differentiable, EXCLUDE_GRADCHECK, EXCLUDE_FUNCTIONAL
from torch.testing import FileCheck
from torch._C import TensorType, TupleType, FloatType, IntType, \
- ListType, StringType, DictType, parse_ir
+ ListType, StringType, DictType
from copy import deepcopy
import random
from typing import List, Dict, Optional, Tuple
m2.sub2.a.data.zero_()
self.assertEqual(torch.zeros(2, 2), m2.forward(torch.randn(3, 2)))
- def test_irparser(self):
- graph_str = """graph(%0 : Double(5, 5)):
- # CHECK: aten::relu
- %1 : Double(5, 5) = aten::relu(%0)
- return (%1)
- """
- FileCheck().run(graph_str, parse_ir(graph_str))
-
def test_filecheck(self):
def test_check():
file = "232"
with self.assertRaisesRegex(RuntimeError, 'Expected to not find "1"'):
fb.run("22 1 22")
- def test_filecheck_parse(self):
- def test_check():
- file = """
- # CHECK: 2
- # CHECK: 3
- # CHECK: 2
- 232
- """
- FileCheck().run(checks_file=file, test_file=file)
- file = """
- # CHECK: 232
- 232
- """
- FileCheck().run(file, "232")
- with self.assertRaisesRegex(RuntimeError, 'Expected to find "232"'):
- FileCheck().run(file, "22")
- with self.assertRaisesRegex(RuntimeError, 'Expected to find "22"'):
- FileCheck().run("# CHECK: 22", "23")
- test_check()
-
- def test_check_count():
- file = "22222"
- FileCheck().run("# CHECK-COUNT-5: 2", file)
- FileCheck().run("# CHECK-COUNT-EXACTLY-5: 2", file)
- FileCheck().run("# CHECK-COUNT-2: 22", file)
- FileCheck().run("# CHECK-COUNT-1: 222", file)
-
- with self.assertRaisesRegex(RuntimeError, 'Expected to not find'):
- FileCheck().run("# CHECK-COUNT-EXACTLY-2: 2", file)
- test_check_count()
-
- def test_check_same():
- file = "22\n33"
- FileCheck().run("# CHECK-SAME: 22", file)
-
- with self.assertRaisesRegex(RuntimeError, "Expected to not find"):
- FileCheck().run("# CHECK-SAME: 33", file)
-
- file = "22 1 3"
-
- FileCheck().run("# CHECK: 2\n # CHECK-SAME: 3", file)
- FileCheck().run("# CHECK-COUNT-2: 2\n # CHECK-SAME: 3", file)
- test_check_same()
-
- def test_bad_input():
- with self.assertRaisesRegex(RuntimeError, "Check for bad input"):
- FileCheck().run("", "1")
-
- with self.assertRaisesRegex(RuntimeError, "Could not parse check"):
- FileCheck().run("# CHECK1", "")
-
- test_bad_input()
-
def test_script_module_call_noscript(self):
class M(torch.jit.ScriptModule):
def __init__(self):
};
namespace {
-
size_t assertFind(
const SourceRange& search_range,
const std::string& sub,
- std::function<void(std::ostream& out)> extra_msg = nullptr) {
+ const Check& check) {
auto pos = search_range.file_ptr()->find(sub, search_range.start());
if (pos == std::string::npos || (pos + sub.size()) > search_range.end()) {
auto found_range =
printQuotedString(ss, sub);
ss << " but did not find it\n";
found_range.highlight(ss);
- if (extra_msg) {
- extra_msg(ss);
- }
+ ss << "From " << check << "\n";
throw std::runtime_error(ss.str());
}
return pos;
}
size_t assertFind(
- const SourceRange& search_range,
- const std::string& sub,
- const Check& check) {
- return assertFind(search_range, sub, [&](std::ostream& out) {
- out << "From " << check << "\n";
- });
-}
-
-size_t assertFind(
const std::shared_ptr<std::string>& file,
const std::string& sub,
size_t start,
throw std::runtime_error(ss.str());
}
}
-
} // namespace
struct FileCheckImpl {
TORCH_API void run(const std::string& test_file) {
has_run = true;
-
- if (groups.size() == 0 || groups[0].size() == 0) {
- throw std::runtime_error(
- "No checks have been added to this instance of"
- "Filecheck! Check for bad input.");
- }
-
doChecks(std::make_shared<std::string>(test_file));
}
- TORCH_API void run(
- const std::string& checks_file,
- const std::string& test_file) {
- auto checks_ptr = std::make_shared<std::string>(checks_file);
- parseStrings(checks_ptr);
- run(test_file);
- }
+ TORCH_API void addCheck(
+ CheckType type,
+ const std::string& s,
+ c10::optional<size_t> count = c10::nullopt) {
+ Check check(type, s, std::move(count));
- TORCH_API void addCheck(Check check) {
// consecutive CHECK_DAGs & CHECK_NOTs need to be evaluated as a group
- if (groups.size() == 0 ||
- (check.type_ != CHECK_NOT && check.type_ != CHECK_DAG)) {
+ if (groups.size() == 0 || (type != CHECK_NOT && type != CHECK_DAG)) {
groups.push_back({check});
} else {
auto& last_group = groups.back();
- if (last_group.at(0).type_ == check.type_) {
+ if (last_group.at(0).type_ == type) {
last_group.push_back(check);
} else {
groups.push_back({check});
}
}
- has_run = false;
- }
- TORCH_API void addCheck(
- CheckType type,
- const std::string& s,
- c10::optional<size_t> count = c10::nullopt) {
- addCheck(Check(type, s, std::move(count)));
+ has_run = false;
}
bool has_run = false;
friend std::ostream& operator<<(std::ostream& out, const FileCheckImpl& fc);
private:
- bool parseSingleCheck(
- const std::shared_ptr<std::string>& checks_file,
- size_t* start) {
- const static std::vector<std::pair<CheckType, std::string>> check_pairs = {
- {CHECK, ": "},
- {CHECK_NEXT, "-NEXT: "},
- {CHECK_SAME, "-SAME: "},
- {CHECK_NOT, "-NOT: "},
- {CHECK_DAG, "-DAG: "},
- {CHECK_COUNT, "-COUNT-"}, // needs special parsing
- };
-
- for (const auto& check_pair : check_pairs) {
- const std::string& check_suffix = check_pair.second;
- auto suffix_pos = checks_file->find(check_suffix, *start);
- if (suffix_pos != *start) {
- continue;
- }
- size_t end_check_string = suffix_pos + check_suffix.size();
- CheckType type = check_pair.first;
- c10::optional<size_t> count = c10::nullopt;
- auto end_line = checks_file->find("\n", end_check_string);
- bool exactly = false;
- if (type == CHECK_COUNT) {
- const std::string exact = "EXACTLY-";
- if (checks_file->find(exact, end_check_string) == end_check_string) {
- exactly = true;
- end_check_string += exact.size();
- }
- size_t end = assertFind(
- SourceRange(checks_file, end_check_string, end_line), ":");
- count = std::stoll(
- checks_file->substr(end_check_string, end - end_check_string));
- end_check_string = end + 2; // add ':' and the space
- }
- auto check = Check(
- type,
- checks_file->substr(end_check_string, end_line - end_check_string),
- count);
- addCheck(check);
- if (exactly) {
- addCheck(CHECK_NOT, check.search_str_);
- }
- *start = end_line;
- return true;
- }
- return false;
- }
-
- size_t findNextStart(
- const std::shared_ptr<std::string>& checks_file,
- size_t prev_end) {
- size_t start = checks_file->find("#", prev_end);
- if (start == std::string::npos) {
- return start;
- }
- start += 1;
- static constexpr size_t max_whitespace = 6;
- size_t i = 0;
- while (start + i < checks_file->size() && i < max_whitespace) {
- auto c = checks_file->at(start + i);
- if (c != ' ' && c != '\t') {
- break;
- }
- i++;
- }
- static const std::string check = "CHECK";
- if (checks_file->substr(start + i, check.size()) == check) {
- return start + i + check.size();
- } else {
- return findNextStart(checks_file, start + i + 1);
- }
- }
-
- void parseStrings(const std::shared_ptr<std::string>& checks_file) {
- size_t start = 0;
- start = findNextStart(checks_file, 0);
- while (start != std::string::npos) {
- bool found_match = parseSingleCheck(checks_file, &start);
- if (!found_match) {
- std::ostringstream ss;
- ss << "Could not parse check at:\n";
- SourceRange(checks_file, start, start + 1).highlight(ss);
- ss << "Check for bad input.";
- has_run = true;
- throw std::runtime_error(ss.str());
- }
- start = findNextStart(checks_file, start);
- }
- }
-
void doCheckNot(
const std::vector<Check>& nots,
const std::shared_ptr<std::string>& file,
}
std::vector<Check> checks;
+ std::shared_ptr<std::string> check_file;
std::vector<std::vector<Check>> groups;
};
fcImpl->run(graph_str.str());
};
-void FileCheck::run(
- const std::string& input_checks_string,
- const std::string& test_string) {
- fcImpl->run(input_checks_string, test_string);
-}
-
-void FileCheck::run(
- const std::string& input_checks_string,
- const Graph& graph) {
- std::stringstream graph_str;
- graph_str << graph;
- fcImpl->run(input_checks_string, graph_str.str());
-}
-
FileCheck* FileCheck::check(const std::string& str) {
fcImpl->addCheck(CHECK, str);
return this;