Re-land Parsing file check (#18570)
authoreellison <elias_ellison@brown.edu>
Fri, 29 Mar 2019 22:35:37 +0000 (15:35 -0700)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Fri, 29 Mar 2019 22:46:59 +0000 (15:46 -0700)
Summary:
The last time I tried to land it there was a merge race with the docs coverage test lol. Re-landing with the fix.

Re-land of https://github.com/pytorch/pytorch/pull/18304
Pull Request resolved: https://github.com/pytorch/pytorch/pull/18570

Differential Revision: D14668859

Pulled By: eellison

fbshipit-source-id: 3825a35ddc6179a0d433d70d22b5c1a96c20b21a

test/cpp/jit/test_irparser.h
test/test_docs_coverage.py
test/test_jit.py
torch/csrc/jit/init.cpp
torch/csrc/jit/script/init.cpp
torch/csrc/jit/testing/file_check.cpp
torch/csrc/jit/testing/file_check.h

index c8c3eea..af8d28a 100644 (file)
@@ -2,6 +2,7 @@
 
 #include <torch/csrc/jit/ir.h>
 #include <torch/csrc/jit/irparser.h>
+#include <torch/csrc/jit/testing/file_check.h>
 #include "test/cpp/jit/test_base.h"
 
 #include <sstream>
@@ -211,6 +212,19 @@ graph(%0 : Tensor,
     }
     AT_ASSERT(error_thrown);
   }
+
+  {
+    auto graph = std::make_shared<Graph>();
+    const std::string& text =
+        R"IR(
+    graph(%a):
+    # CHECK: return
+      return (%a))IR";
+
+    script::parseIR(text, &*graph);
+    graph->inputs()[0]->type()->expect<TensorType>();
+    torch::jit::testing::FileCheck().run(text, *graph);
+  }
 }
 } // namespace jit
 } // namespace torch
index 3b565c3..02bad78 100644 (file)
@@ -37,6 +37,7 @@ class TestDocCoverage(unittest.TestCase):
             # below are some jit functions
             'wait', 'fork', 'parse_type_comment', 'import_ir_module',
             'import_ir_module_from_buffer', 'merge_type_from_type_comment',
+            'parse_ir',
 
             # below are symbols mistakely binded to torch.*, but should
             # go to torch.nn.functional.* instead
index 6229f0c..7d0f1c9 100644 (file)
@@ -43,7 +43,7 @@ from common_methods_invocations import create_input, unpack_variables, \
     exclude_tensor_method, non_differentiable, EXCLUDE_GRADCHECK, EXCLUDE_FUNCTIONAL
 from torch.testing import FileCheck
 from torch._C import TensorType, TupleType, FloatType, IntType, \
-    ListType, StringType, DictType
+    ListType, StringType, DictType, parse_ir
 from copy import deepcopy
 import random
 from typing import List, Dict, Optional, Tuple
@@ -5540,6 +5540,14 @@ a")
         m2.sub2.a.data.zero_()
         self.assertEqual(torch.zeros(2, 2), m2.forward(torch.randn(3, 2)))
 
+    def test_irparser(self):
+        graph_str = """graph(%0 : Double(5, 5)):
+          # CHECK: aten::relu
+          %1 : Double(5, 5) = aten::relu(%0)
+          return (%1)
+        """
+        FileCheck().run(graph_str, parse_ir(graph_str))
+
     def test_filecheck(self):
         def test_check():
             file = "232"
@@ -5632,6 +5640,59 @@ a")
             with self.assertRaisesRegex(RuntimeError, 'Expected to not find "1"'):
                 fb.run("22 1 22")
 
+    def test_filecheck_parse(self):
+        def test_check():
+            file = """
+                # CHECK: 2
+                # CHECK: 3
+                # CHECK: 2
+                232
+                """
+            FileCheck().run(checks_file=file, test_file=file)
+            file = """
+                # CHECK: 232
+                232
+                """
+            FileCheck().run(file, "232")
+            with self.assertRaisesRegex(RuntimeError, 'Expected to find "232"'):
+                FileCheck().run(file, "22")
+            with self.assertRaisesRegex(RuntimeError, 'Expected to find "22"'):
+                FileCheck().run("# CHECK: 22", "23")
+        test_check()
+
+        def test_check_count():
+            file = "22222"
+            FileCheck().run("# CHECK-COUNT-5: 2", file)
+            FileCheck().run("# CHECK-COUNT-EXACTLY-5: 2", file)
+            FileCheck().run("# CHECK-COUNT-2: 22", file)
+            FileCheck().run("# CHECK-COUNT-1: 222", file)
+
+            with self.assertRaisesRegex(RuntimeError, 'Expected to not find'):
+                FileCheck().run("# CHECK-COUNT-EXACTLY-2: 2", file)
+        test_check_count()
+
+        def test_check_same():
+            file = "22\n33"
+            FileCheck().run("# CHECK-SAME: 22", file)
+
+            with self.assertRaisesRegex(RuntimeError, "Expected to not find"):
+                FileCheck().run("# CHECK-SAME: 33", file)
+
+            file = "22  1  3"
+
+            FileCheck().run("# CHECK: 2\n # CHECK-SAME: 3", file)
+            FileCheck().run("# CHECK-COUNT-2: 2\n # CHECK-SAME: 3", file)
+        test_check_same()
+
+        def test_bad_input():
+            with self.assertRaisesRegex(RuntimeError, "Check for bad input"):
+                FileCheck().run("", "1")
+
+            with self.assertRaisesRegex(RuntimeError, "Could not parse check"):
+                FileCheck().run("# CHECK1", "")
+
+        test_bad_input()
+
     def test_script_module_call_noscript(self):
         class M(torch.jit.ScriptModule):
             def __init__(self):
index 489c438..528dc43 100644 (file)
@@ -7,6 +7,7 @@
 #include <torch/csrc/jit/fuser/kernel_cache.h>
 #include <torch/csrc/jit/graph_executor.h>
 #include <torch/csrc/jit/import.h>
+#include <torch/csrc/jit/irparser.h>
 #include <torch/csrc/jit/operator.h>
 #include <torch/csrc/jit/passes/canonicalize.h>
 #include <torch/csrc/jit/passes/canonicalize_ops.h>
@@ -75,7 +76,6 @@ bool loadPythonClasses() {
 
   return true;
 }
-
 } // anonymous namespace
 
 #if defined(_WIN32)
@@ -375,6 +375,12 @@ void initJITBindings(PyObject* module) {
       },
       py::arg("qualified_name"));
 
+  m.def("parse_ir", [](const std::string& input) {
+    auto graph = std::make_shared<Graph>();
+    script::parseIR(input, &*graph);
+    return graph;
+  });
+
   py::class_<FunctionSchema>(m, "FunctionSchema")
       .def_property_readonly(
           "name", [](FunctionSchema& self) { return self.name(); })
@@ -490,6 +496,5 @@ void initJITBindings(PyObject* module) {
   script::initTreeViewBindings(module);
   script::initJitScriptBindings(module);
 }
-
 } // namespace jit
 } // namespace torch
index f4d1c89..deebed3 100644 (file)
@@ -13,6 +13,7 @@
 #include <torch/csrc/jit/constants.h>
 #include <torch/csrc/jit/hooks_for_testing.h>
 #include <torch/csrc/jit/import_source.h>
+#include <torch/csrc/jit/irparser.h>
 #include <torch/csrc/jit/passes/python_print.h>
 #include <torch/csrc/jit/pybind_utils.h>
 #include <torch/csrc/jit/python_tracer.h>
@@ -1098,9 +1099,24 @@ void initJitScriptBindings(PyObject* module) {
           [](testing::FileCheck& f, const std::string& str) {
             return f.run(str);
           })
-      .def("run", [](testing::FileCheck& f, const Graph& g) {
-        return f.run(g);
-      });
+      .def(
+          "run", [](testing::FileCheck& f, const Graph& g) { return f.run(g); })
+      .def(
+          "run",
+          [](testing::FileCheck& f,
+             const std::string& input,
+             const std::string& output) { return f.run(input, output); },
+          "Run",
+          py::arg("checks_file"),
+          py::arg("test_file"))
+      .def(
+          "run",
+          [](testing::FileCheck& f, const std::string& input, const Graph& g) {
+            return f.run(input, g);
+          },
+          "Run",
+          py::arg("checks_file"),
+          py::arg("graph"));
 }
 } // namespace script
 } // namespace jit
index 3af8c5a..741502b 100644 (file)
@@ -79,10 +79,11 @@ std::ostream& operator<<(std::ostream& out, const Check& c) {
 };
 
 namespace {
+
 size_t assertFind(
     const SourceRange& search_range,
     const std::string& sub,
-    const Check& check) {
+    std::function<void(std::ostream& out)> extra_msg = nullptr) {
   auto pos = search_range.file_ptr()->find(sub, search_range.start());
   if (pos == std::string::npos || (pos + sub.size()) > search_range.end()) {
     auto found_range =
@@ -92,13 +93,24 @@ size_t assertFind(
     printQuotedString(ss, sub);
     ss << " but did not find it\n";
     found_range.highlight(ss);
-    ss << "From " << check << "\n";
+    if (extra_msg) {
+      extra_msg(ss);
+    }
     throw std::runtime_error(ss.str());
   }
   return pos;
 }
 
 size_t assertFind(
+    const SourceRange& search_range,
+    const std::string& sub,
+    const Check& check) {
+  return assertFind(search_range, sub, [&](std::ostream& out) {
+    out << "From " << check << "\n";
+  });
+}
+
+size_t assertFind(
     const std::shared_ptr<std::string>& file,
     const std::string& sub,
     size_t start,
@@ -123,6 +135,7 @@ void assertNotFind(
     throw std::runtime_error(ss.str());
   }
 }
+
 } // namespace
 
 struct FileCheckImpl {
@@ -130,35 +143,143 @@ struct FileCheckImpl {
 
   TORCH_API void run(const std::string& test_file) {
     has_run = true;
+
+    if (groups.size() == 0 || groups[0].size() == 0) {
+      throw std::runtime_error(
+          "No checks have been added to this instance of"
+          "Filecheck! Check for bad input.");
+    }
+
     doChecks(std::make_shared<std::string>(test_file));
   }
 
-  TORCH_API void addCheck(
-      CheckType type,
-      const std::string& s,
-      c10::optional<size_t> count = c10::nullopt) {
-    Check check(type, s, std::move(count));
+  TORCH_API void run(
+      const std::string& checks_file,
+      const std::string& test_file) {
+    auto checks_ptr = std::make_shared<std::string>(checks_file);
+    parseStrings(checks_ptr);
+    run(test_file);
+  }
 
+  TORCH_API void addCheck(Check check) {
     // consecutive CHECK_DAGs & CHECK_NOTs need to be evaluated as a group
-    if (groups.size() == 0 || (type != CHECK_NOT && type != CHECK_DAG)) {
+    if (groups.size() == 0 ||
+        (check.type_ != CHECK_NOT && check.type_ != CHECK_DAG)) {
       groups.push_back({check});
     } else {
       auto& last_group = groups.back();
-      if (last_group.at(0).type_ == type) {
+      if (last_group.at(0).type_ == check.type_) {
         last_group.push_back(check);
       } else {
         groups.push_back({check});
       }
     }
-
     has_run = false;
   }
 
+  TORCH_API void addCheck(
+      CheckType type,
+      const std::string& s,
+      c10::optional<size_t> count = c10::nullopt) {
+    addCheck(Check(type, s, std::move(count)));
+  }
+
   bool has_run = false;
 
   friend std::ostream& operator<<(std::ostream& out, const FileCheckImpl& fc);
 
  private:
+  bool parseSingleCheck(
+      const std::shared_ptr<std::string>& checks_file,
+      size_t* start) {
+    const static std::vector<std::pair<CheckType, std::string>> check_pairs = {
+        {CHECK, ": "},
+        {CHECK_NEXT, "-NEXT: "},
+        {CHECK_SAME, "-SAME: "},
+        {CHECK_NOT, "-NOT: "},
+        {CHECK_DAG, "-DAG: "},
+        {CHECK_COUNT, "-COUNT-"}, // needs special parsing
+    };
+
+    for (const auto& check_pair : check_pairs) {
+      const std::string& check_suffix = check_pair.second;
+      auto suffix_pos = checks_file->find(check_suffix, *start);
+      if (suffix_pos != *start) {
+        continue;
+      }
+      size_t end_check_string = suffix_pos + check_suffix.size();
+      CheckType type = check_pair.first;
+      c10::optional<size_t> count = c10::nullopt;
+      auto end_line = checks_file->find("\n", end_check_string);
+      bool exactly = false;
+      if (type == CHECK_COUNT) {
+        const std::string exact = "EXACTLY-";
+        if (checks_file->find(exact, end_check_string) == end_check_string) {
+          exactly = true;
+          end_check_string += exact.size();
+        }
+        size_t end = assertFind(
+            SourceRange(checks_file, end_check_string, end_line), ":");
+        count = std::stoll(
+            checks_file->substr(end_check_string, end - end_check_string));
+        end_check_string = end + 2; // add ':' and the space
+      }
+      auto check = Check(
+          type,
+          checks_file->substr(end_check_string, end_line - end_check_string),
+          count);
+      addCheck(check);
+      if (exactly) {
+        addCheck(CHECK_NOT, check.search_str_);
+      }
+      *start = end_line;
+      return true;
+    }
+    return false;
+  }
+
+  size_t findNextStart(
+      const std::shared_ptr<std::string>& checks_file,
+      size_t prev_end) {
+    size_t start = checks_file->find("#", prev_end);
+    if (start == std::string::npos) {
+      return start;
+    }
+    start += 1;
+    static constexpr size_t max_whitespace = 6;
+    size_t i = 0;
+    while (start + i < checks_file->size() && i < max_whitespace) {
+      auto c = checks_file->at(start + i);
+      if (c != ' ' && c != '\t') {
+        break;
+      }
+      i++;
+    }
+    static const std::string check = "CHECK";
+    if (checks_file->substr(start + i, check.size()) == check) {
+      return start + i + check.size();
+    } else {
+      return findNextStart(checks_file, start + i + 1);
+    }
+  }
+
+  void parseStrings(const std::shared_ptr<std::string>& checks_file) {
+    size_t start = 0;
+    start = findNextStart(checks_file, 0);
+    while (start != std::string::npos) {
+      bool found_match = parseSingleCheck(checks_file, &start);
+      if (!found_match) {
+        std::ostringstream ss;
+        ss << "Could not parse check at:\n";
+        SourceRange(checks_file, start, start + 1).highlight(ss);
+        ss << "Check for bad input.";
+        has_run = true;
+        throw std::runtime_error(ss.str());
+      }
+      start = findNextStart(checks_file, start);
+    }
+  }
+
   void doCheckNot(
       const std::vector<Check>& nots,
       const std::shared_ptr<std::string>& file,
@@ -277,7 +398,6 @@ struct FileCheckImpl {
   }
 
   std::vector<Check> checks;
-  std::shared_ptr<std::string> check_file;
   std::vector<std::vector<Check>> groups;
 };
 
@@ -309,6 +429,20 @@ void FileCheck::run(const Graph& graph) {
   fcImpl->run(graph_str.str());
 };
 
+void FileCheck::run(
+    const std::string& input_checks_string,
+    const std::string& test_string) {
+  fcImpl->run(input_checks_string, test_string);
+}
+
+void FileCheck::run(
+    const std::string& input_checks_string,
+    const Graph& graph) {
+  std::stringstream graph_str;
+  graph_str << graph;
+  fcImpl->run(input_checks_string, graph_str.str());
+}
+
 FileCheck* FileCheck::check(const std::string& str) {
   fcImpl->addCheck(CHECK, str);
   return this;
index cf80575..d7a7819 100644 (file)
@@ -23,6 +23,14 @@ struct FileCheck {
   // Run FileCheck against dump of graph IR
   TORCH_API void run(const Graph& graph);
 
+  // Parsing input checks string and run against test string / dump of graph IR
+  TORCH_API void run(
+      const std::string& input_checks_string,
+      const std::string& test_string);
+  TORCH_API void run(
+      const std::string& input_checks_string,
+      const Graph& graph);
+
   // Checks that the string occurs, starting at the end of the most recent match
   TORCH_API FileCheck* check(const std::string& str);