Add pass registration mechanism (#18587)
authorBram Wasti <bwasti@fb.com>
Fri, 12 Apr 2019 21:53:17 +0000 (14:53 -0700)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Fri, 12 Apr 2019 22:32:00 +0000 (15:32 -0700)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/18587
ghimport-source-id: 80d753f7046a2a719e0c076684f44fa2059a0921

Differential Revision: D14901227

Pulled By: bwasti

fbshipit-source-id: 56511d0313419b63945a36b80e9ea51abdef2bd4

test/cpp/jit/test.cpp
test/cpp/jit/test_misc.h
tools/build_variables.py
torch/CMakeLists.txt
torch/csrc/jit/backends/bleh [new file with mode: 0644]
torch/csrc/jit/graph_executor.cpp
torch/csrc/jit/pass_manager.cpp [new file with mode: 0644]
torch/csrc/jit/pass_manager.h [new file with mode: 0644]

index 90b38f3..a15cf00 100644 (file)
@@ -47,6 +47,7 @@ namespace jit {
   _(FromQualString)                \
   _(InternedStrings)               \
   _(IValue)                        \
+  _(PassManagement)                \
   _(Proto)                         \
   _(RegisterFusionCachesKernel)    \
   _(SchemaParser)                  \
index ea20266..f5a637c 100644 (file)
@@ -16,6 +16,7 @@
 #include "torch/csrc/jit/fuser/interface.h"
 #include "torch/csrc/jit/import.h"
 #include "torch/csrc/jit/interpreter.h"
+#include "torch/csrc/jit/pass_manager.h"
 #include "torch/csrc/jit/passes/alias_analysis.h"
 #include "torch/csrc/jit/passes/common_subexpression_elimination.h"
 #include "torch/csrc/jit/passes/constant_propagation.h"
@@ -721,6 +722,32 @@ void testModuleDefine() {
   AT_ASSERT(result.toTensor().item<float>() == 6)
 }
 
+static int testPassValue = 0;
+void fakePass(std::shared_ptr<Graph>& g) {
+  testPassValue++;
+  return;
+}
+
+RegisterPass p(fakePass);
+
+void testPassManagement() {
+  std::shared_ptr<Graph> graph = std::make_shared<Graph>();
+  script::parseIR(
+      R"IR(
+graph(%a):
+  return (%a))IR",
+      &*graph);
+
+  std::vector<IValue> stack = {IValue(torch::randn({22}, at::kCPU))};
+  auto run = [&](std::shared_ptr<Graph>& graph, std::vector<IValue> stack) {
+    GraphExecutor executor(graph);
+    executor.run(stack);
+    return stack;
+  };
+  run(graph, stack);
+  AT_ASSERT(testPassValue);
+}
+
 } // namespace test
 } // namespace jit
 } // namespace torch
index ff27ce3..eeaab35 100644 (file)
@@ -55,6 +55,7 @@ libtorch_sources = [
     "torch/csrc/jit/constants.cpp",
     "torch/csrc/jit/node_hashing.cpp",
     "torch/csrc/jit/export.cpp",
+    "torch/csrc/jit/pass_manager.cpp",
     "torch/csrc/jit/pickler.cpp",
     "torch/csrc/jit/graph_executor.cpp",
     "torch/csrc/jit/import.cpp",
index 4b2281b..0f6b84c 100644 (file)
@@ -125,6 +125,7 @@ set(TORCH_SRCS
   ${TORCH_SRC_DIR}/csrc/jit/attributes.cpp
   ${TORCH_SRC_DIR}/csrc/jit/argument_spec.cpp
   ${TORCH_SRC_DIR}/csrc/jit/export.cpp
+  ${TORCH_SRC_DIR}/csrc/jit/pass_manager.cpp
   ${TORCH_SRC_DIR}/csrc/jit/pickler.cpp
   ${TORCH_SRC_DIR}/csrc/jit/generated/register_aten_ops_0.cpp
   ${TORCH_SRC_DIR}/csrc/jit/generated/register_aten_ops_1.cpp
diff --git a/torch/csrc/jit/backends/bleh b/torch/csrc/jit/backends/bleh
new file mode 100644 (file)
index 0000000..e69de29
index 31ef09b..966246c 100644 (file)
@@ -8,6 +8,7 @@
 #include <torch/csrc/jit/custom_operator.h>
 #include <torch/csrc/jit/interpreter.h>
 #include <torch/csrc/jit/ir.h>
+#include <torch/csrc/jit/pass_manager.h>
 #include <torch/csrc/jit/passes/batch_mm.h>
 #include <torch/csrc/jit/passes/canonicalize_ops.h>
 #include <torch/csrc/jit/passes/common_subexpression_elimination.h>
@@ -642,6 +643,9 @@ struct GraphExecutorImpl {
   }
 
   void runNondiffOptimization(std::shared_ptr<Graph>& graph) {
+    for (const auto& pass : getCustomPasses()) {
+      pass(graph);
+    }
     FuseGraph(graph);
   }
 
diff --git a/torch/csrc/jit/pass_manager.cpp b/torch/csrc/jit/pass_manager.cpp
new file mode 100644 (file)
index 0000000..91a7730
--- /dev/null
@@ -0,0 +1,16 @@
+#include <torch/csrc/jit/pass_manager.h>
+
+namespace torch {
+namespace jit {
+
+std::vector<Pass>& getCustomPasses() {
+  static std::vector<Pass> passes;
+  return passes;
+}
+
+RegisterPass::RegisterPass(Pass p) {
+  getCustomPasses().emplace_back(std::move(p));
+}
+
+} // namespace jit
+} // namespace torch
diff --git a/torch/csrc/jit/pass_manager.h b/torch/csrc/jit/pass_manager.h
new file mode 100644 (file)
index 0000000..47bb1f7
--- /dev/null
@@ -0,0 +1,29 @@
+#pragma once
+
+#include <torch/csrc/jit/ir.h>
+
+/* `getCustomPasses()` returns a vector of passes that will be executed after
+ * differentiation but before any fusion.  This is the de-facto location
+ * for compiler backends to insert passes.
+ *
+ * Static registration of a pass can be done by creating a global
+ * `RegisterPass r(Pass)` variable in a compilation unit.
+ *
+ * pass_manager.h uses a Meyer's singleton
+ * to store a vector of `Pass`es, which modify the IR graph in place.
+ */
+
+namespace torch {
+namespace jit {
+
+// A pass modifies a Graph in place.
+using Pass = std::function<void(std::shared_ptr<Graph>&)>;
+
+std::vector<Pass>& getCustomPasses();
+
+struct RegisterPass {
+  RegisterPass(Pass p);
+};
+
+} // namespace jit
+} // namespace torch