From: Bram Wasti Date: Fri, 12 Apr 2019 21:53:17 +0000 (-0700) Subject: Add pass registration mechanism (#18587) X-Git-Tag: accepted/tizen/6.5/unified/20211028.231830~247 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=b1539412db2e9610afcbe62270b4d92885431430;p=platform%2Fupstream%2Fpytorch.git Add pass registration mechanism (#18587) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/18587 ghimport-source-id: 80d753f7046a2a719e0c076684f44fa2059a0921 Differential Revision: D14901227 Pulled By: bwasti fbshipit-source-id: 56511d0313419b63945a36b80e9ea51abdef2bd4 --- diff --git a/test/cpp/jit/test.cpp b/test/cpp/jit/test.cpp index 90b38f3..a15cf00 100644 --- a/test/cpp/jit/test.cpp +++ b/test/cpp/jit/test.cpp @@ -47,6 +47,7 @@ namespace jit { _(FromQualString) \ _(InternedStrings) \ _(IValue) \ + _(PassManagement) \ _(Proto) \ _(RegisterFusionCachesKernel) \ _(SchemaParser) \ diff --git a/test/cpp/jit/test_misc.h b/test/cpp/jit/test_misc.h index ea20266..f5a637c 100644 --- a/test/cpp/jit/test_misc.h +++ b/test/cpp/jit/test_misc.h @@ -16,6 +16,7 @@ #include "torch/csrc/jit/fuser/interface.h" #include "torch/csrc/jit/import.h" #include "torch/csrc/jit/interpreter.h" +#include "torch/csrc/jit/pass_manager.h" #include "torch/csrc/jit/passes/alias_analysis.h" #include "torch/csrc/jit/passes/common_subexpression_elimination.h" #include "torch/csrc/jit/passes/constant_propagation.h" @@ -721,6 +722,32 @@ void testModuleDefine() { AT_ASSERT(result.toTensor().item() == 6) } +static int testPassValue = 0; +void fakePass(std::shared_ptr& g) { + testPassValue++; + return; +} + +RegisterPass p(fakePass); + +void testPassManagement() { + std::shared_ptr graph = std::make_shared(); + script::parseIR( + R"IR( +graph(%a): + return (%a))IR", + &*graph); + + std::vector stack = {IValue(torch::randn({22}, at::kCPU))}; + auto run = [&](std::shared_ptr& graph, std::vector stack) { + GraphExecutor executor(graph); + executor.run(stack); + return stack; + }; + run(graph, stack); + AT_ASSERT(testPassValue); +} + } // namespace test } // namespace jit } // namespace torch diff --git a/tools/build_variables.py b/tools/build_variables.py index ff27ce3..eeaab35 100644 --- a/tools/build_variables.py +++ b/tools/build_variables.py @@ -55,6 +55,7 @@ libtorch_sources = [ "torch/csrc/jit/constants.cpp", "torch/csrc/jit/node_hashing.cpp", "torch/csrc/jit/export.cpp", + "torch/csrc/jit/pass_manager.cpp", "torch/csrc/jit/pickler.cpp", "torch/csrc/jit/graph_executor.cpp", "torch/csrc/jit/import.cpp", diff --git a/torch/CMakeLists.txt b/torch/CMakeLists.txt index 4b2281b..0f6b84c 100644 --- a/torch/CMakeLists.txt +++ b/torch/CMakeLists.txt @@ -125,6 +125,7 @@ set(TORCH_SRCS ${TORCH_SRC_DIR}/csrc/jit/attributes.cpp ${TORCH_SRC_DIR}/csrc/jit/argument_spec.cpp ${TORCH_SRC_DIR}/csrc/jit/export.cpp + ${TORCH_SRC_DIR}/csrc/jit/pass_manager.cpp ${TORCH_SRC_DIR}/csrc/jit/pickler.cpp ${TORCH_SRC_DIR}/csrc/jit/generated/register_aten_ops_0.cpp ${TORCH_SRC_DIR}/csrc/jit/generated/register_aten_ops_1.cpp diff --git a/torch/csrc/jit/backends/bleh b/torch/csrc/jit/backends/bleh new file mode 100644 index 0000000..e69de29 diff --git a/torch/csrc/jit/graph_executor.cpp b/torch/csrc/jit/graph_executor.cpp index 31ef09b..966246c 100644 --- a/torch/csrc/jit/graph_executor.cpp +++ b/torch/csrc/jit/graph_executor.cpp @@ -8,6 +8,7 @@ #include #include #include +#include #include #include #include @@ -642,6 +643,9 @@ struct GraphExecutorImpl { } void runNondiffOptimization(std::shared_ptr& graph) { + for (const auto& pass : getCustomPasses()) { + pass(graph); + } FuseGraph(graph); } diff --git a/torch/csrc/jit/pass_manager.cpp b/torch/csrc/jit/pass_manager.cpp new file mode 100644 index 0000000..91a7730 --- /dev/null +++ b/torch/csrc/jit/pass_manager.cpp @@ -0,0 +1,16 @@ +#include + +namespace torch { +namespace jit { + +std::vector& getCustomPasses() { + static std::vector passes; + return passes; +} + +RegisterPass::RegisterPass(Pass p) { + getCustomPasses().emplace_back(std::move(p)); +} + +} // namespace jit +} // namespace torch diff --git a/torch/csrc/jit/pass_manager.h b/torch/csrc/jit/pass_manager.h new file mode 100644 index 0000000..47bb1f7 --- /dev/null +++ b/torch/csrc/jit/pass_manager.h @@ -0,0 +1,29 @@ +#pragma once + +#include + +/* `getCustomPasses()` returns a vector of passes that will be executed after + * differentiation but before any fusion. This is the de-facto location + * for compiler backends to insert passes. + * + * Static registration of a pass can be done by creating a global + * `RegisterPass r(Pass)` variable in a compilation unit. + * + * pass_manager.h uses a Meyer's singleton + * to store a vector of `Pass`es, which modify the IR graph in place. + */ + +namespace torch { +namespace jit { + +// A pass modifies a Graph in place. +using Pass = std::function&)>; + +std::vector& getCustomPasses(); + +struct RegisterPass { + RegisterPass(Pass p); +}; + +} // namespace jit +} // namespace torch