From: Richard Zou
Date: Thu, 13 Dec 2018 15:51:08 +0000 (-0800)
Subject: Reuse KernelSpec for FusionGroups with equivalent graphs (#14541)
X-Git-Tag: accepted/tizen/6.5/unified/20211028.231830~2273
X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=b14d6d730a7fac0d6d31be600878d1a7ca4af217;p=platform%2Fupstream%2Fpytorch.git

Reuse KernelSpec for FusionGroups with equivalent graphs (#14541)

Summary:
Before this PR, loop unrolling plus the graph fuser produced multiple
FusionGroups with identical bodies (differing only in variable names) for JIT
LSTMs. Each FusionGroup was registered under a separate fusion key, and each
key triggered a separate compilation of the same specializations.

This PR makes it so that when a FusionGroup is registered with the fusion
compiler, the compiler first checks the KernelSpec cache to see whether the
FusionGroup's graph is already present. If it is, the compiler returns the
corresponding KernelSpec's key, so the compiled kernels are shared. To make
this lookup reliable, graphs are canonicalized before being cached; I added a
flag to the canonicalize pass that removes the unique names of values.

This shortens the compile time for a JIT LSTM (seq_len of 100, loop unroll
factor of 8) from 5.3s to 2.3s. Most of this compile time is spent in the
graph fuser and/or fusion compiler. While this PR leaves only one unique
kernel in the forward pass, the backward pass (after loop unrolling) still
produces six distinct kernels, which should be investigated.
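Concretely, "equivalent" means equal after normalization. The new
normalizeGraphForCache helper (shown in full in the kernel_cache.cpp hunk
below; condensed here for the summary) composes the two passes:

  std::shared_ptr<Graph> normalizeGraphForCache(
      const std::shared_ptr<Graph>& graph) {
    // Renumber values and drop their unique names, so that alpha-equivalent
    // graphs become structurally identical...
    auto result = Canonicalize(graph, /*keep_unique_names=*/false);
    // ...then erase shape information, which is specialized separately.
    EraseShapeInformation(result);
    return result;
  }

Two FusionGroup subgraphs that differ only in value names normalize to the
same graph, print to the same string, and therefore map to the same fusion
key and KernelSpec.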
Pull Request resolved: https://github.com/pytorch/pytorch/pull/14541

Differential Revision: D13324487

Pulled By: zou3519

fbshipit-source-id: b841d82ed35a959b5cfc72db033bf5a7b42cc4fb
---

diff --git a/test/cpp/jit/gtest.cpp b/test/cpp/jit/gtest.cpp
index fb032db..6715588 100644
--- a/test/cpp/jit/gtest.cpp
+++ b/test/cpp/jit/gtest.cpp
@@ -23,6 +23,7 @@ JIT_TEST(DynamicDAG)
 JIT_TEST(FromQualString)
 JIT_TEST(InternedStrings)
 JIT_TEST(IValue)
+JIT_TEST(RegisterFusionCachesKernel)
 JIT_TEST(SchemaParser)
 JIT_TEST(TopologicalIndex)
 JIT_TEST(TopologicalMove)
diff --git a/test/cpp/jit/tests.h b/test/cpp/jit/tests.h
index db1caf2..5782dc0 100644
--- a/test/cpp/jit/tests.h
+++ b/test/cpp/jit/tests.h
@@ -45,6 +45,7 @@
 #include "torch/csrc/jit/passes/constant_propagation.h"
 #include "torch/csrc/jit/passes/create_autodiff_subgraphs.h"
 #include "torch/csrc/jit/passes/dead_code_elimination.h"
+#include "torch/csrc/jit/passes/graph_fuser.h"
 #include "torch/csrc/jit/passes/lower_grad_of.h"
 #include "torch/csrc/jit/passes/lower_tuples.h"
 #include "torch/csrc/jit/passes/requires_grad_analysis.h"
@@ -878,6 +879,50 @@ void testDifferentiateWithRequiresGrad(std::ostream& out = std::cout) {
   out << "\n";
 }
 
+void testRegisterFusionCachesKernel(std::ostream& out = std::cout) {
+  // Build up a fake graph with a FusionGroup
+  auto createGraphWithNames = [](std::string cname, std::string dname) {
+    auto graph = std::make_shared<Graph>();
+    at::ScalarType s = at::ScalarType::Float;
+    auto type = CompleteTensorType::create(s, at::kCPU, {2, 3, 4}, {12, 4, 1});
+    auto a = SymbolicVariable::asNewInput(*graph, type);
+    auto b = SymbolicVariable::asNewInput(*graph, type);
+    auto c = a * b;
+    auto d = c * a;
+    c.value()->setUniqueName(cname);
+    d.value()->setUniqueName(dname);
+    graph->registerOutput(d.value());
+    FuseGraph(graph);
+    return graph;
+  };
+
+  auto getFusionGroup = [](const std::shared_ptr<Graph>& graph) {
+    const auto& nodes = graph->nodes();
+    auto maybe_fusion_group = std::find_if(
+        nodes.begin(), nodes.end(),
+        [](const Node* node) { return node->kind() == prim::FusionGroup; });
+    JIT_ASSERTM(
+        maybe_fusion_group != nodes.end(),
+        "testRegisterFusionCachesKernel: could not create FusionGroup");
+    return *maybe_fusion_group;
+  };
+
+  // Create two alpha-equivalent fusion groups
+  auto graph1 = createGraphWithNames("c1", "d1");
+  auto fg1 = getFusionGroup(graph1);
+
+  auto graph2 = createGraphWithNames("c2", "d2");
+  auto fg2 = getFusionGroup(graph2);
+
+  // Register both with the fusion compiler.
+  auto expected_key = registerFusion(fg1);
+  auto second_key = registerFusion(fg2);
+
+  // Because the graphs are alpha-equivalent, they should return the same key
+  // and therefore share a KernelSpec (and its compiled kernels) across
+  // specializations
+  ASSERT_EQ(second_key, expected_key);
+}
+
 void testCreateAutodiffSubgraphs(std::ostream& out = std::cout) {
   auto graph = build_lstm();
   CreateAutodiffSubgraphs(graph, /*threshold=*/2);
diff --git a/test/test_jit.py b/test/test_jit.py
index 0000e93..80801c3 100644
--- a/test/test_jit.py
+++ b/test/test_jit.py
@@ -9627,31 +9627,6 @@ class TestFuser(JitTestCase):
         ge = self.checkTrace(scaleshift, inputs)
         self.assertExpectedGraph(ge.graph_for(*inputs))
 
-    @staticmethod
-    def _test_cast_Float(self, device):
-        def f(x, y):
-            z = x.float()
-            return z + y
-
-        inputs = [
-            torch.randn(4, 4, dtype=torch.double, device=device),
-            torch.randn(4, 4, dtype=torch.float, device=device),
-        ]
-
-        ge = self.checkScript(f, inputs)
-        self.assertAllFused(ge.graph_for(*inputs))
-
-    @unittest.skipIf(IS_WINDOWS or IS_SANDCASTLE, "NYI: fuser support for Windows or Sandcastle")
-    @enable_cpu_fuser
-    def test_cast_Float(self):
-        return self._test_cast_Float(self, 'cpu')
-
-    @unittest.skipIf(IS_WINDOWS or IS_SANDCASTLE, "NYI: fuser support for Windows or Sandcastle")
-    @unittest.skipIf(not RUN_CUDA, "No CUDA")
-    @skipIfRocm
-    def test_cast_Float_cuda(self):
-        return self._test_cast_Float(self, 'cuda')
-
     @unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
     @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
     @unittest.skipIf(not RUN_CUDA_HALF, "no half support")
@@ -9942,6 +9917,60 @@ class TestFuser(JitTestCase):
         ge = self.checkTrace(self.fn_test_exp, (x, y))
         self.assertAllFused(ge.graph_for(x, y))
 
+    @unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
+    @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
+    @unittest.skipIf(not RUN_CUDA_MULTI_GPU, "needs non-zero device")
+    @skipIfRocm
+    @enable_cpu_fuser
+    def test_fusion_reuse_multi_gpu(self):
+        def fn(x, y):
+            return x * y * x * y
+
+        inputs_cpu = [
+            torch.randn(4, 4, dtype=torch.float),
+            torch.randn(4, 4, dtype=torch.float),
+        ]
+        inputs_cuda0 = [x.cuda(0) for x in inputs_cpu]
+        inputs_cuda1 = [y.cuda(1) for y in inputs_cpu]
+
+        # Should not crash; these should compile different kernels.
+        ge = self.checkScript(fn, inputs_cpu)
+        self.assertAllFused(ge.graph_for(*inputs_cpu))
+        ge(*inputs_cuda0)
+        ge(*inputs_cuda1)
+
+    @unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
+    @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
+    @unittest.skipIf(not RUN_CUDA_MULTI_GPU, "needs non-zero device")
+    @skipIfRocm
+    @enable_cpu_fuser
+    def test_kernel_cache_multi_gpu(self):
+        def not_fusible(x):
+            return x
+
+        def fn(x, y, z):
+            x_out = x * x * x * x * x  # fusion: lambda x. x * x * x * x * x
+            y_out = y * y * y * y * y
+            z_out = z * z * z * z * z
+            return not_fusible(x_out), not_fusible(y_out), not_fusible(z_out)
+
+        inputs = [
+            torch.randn(4, 4, dtype=torch.float),
+            torch.randn(4, 4, dtype=torch.float, device='cuda:0'),
+            torch.randn(4, 4, dtype=torch.float, device='cuda:1'),
+        ]
+
+        prev_cache_size = torch._C._jit_debug_fuser_num_cached_kernel_specs()
+
+        # There are 3 FusionGroups. Because they have the same graph, they
+        # should reuse the same KernelSpec in the KernelSpec cache.
+        ge = self.checkScript(fn, inputs)
+        self.assertGraphContainsExactly(
+            ge.graph_for(*inputs), 'prim::FusionGroup', 3, True)
+        new_cache_size = torch._C._jit_debug_fuser_num_cached_kernel_specs()
+        # XXX: This assumes that the same kernel isn't already used by another test
+        self.assertEqual(new_cache_size - prev_cache_size, 1)
+
     # TODO: This test doesn't offer anything valuable, maybe we should delete it
     @unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
     @unittest.skipIf(not RUN_CUDA_MULTI_GPU, "needs non-zero device")
diff --git a/torch/csrc/jit/fuser/compiler.cpp b/torch/csrc/jit/fuser/compiler.cpp
index 6615310..dadf9cc 100644
--- a/torch/csrc/jit/fuser/compiler.cpp
+++ b/torch/csrc/jit/fuser/compiler.cpp
@@ -5,11 +5,13 @@
 #include
 #include
 #include
+#include
 #include
 #include
 #include
 #include
 #include
+#include "torch/csrc/jit/fuser/interface.h"
 
 #if USE_CUDA_FUSER
 #include
@@ -132,17 +134,24 @@ static void upfrontCompilation(KernelSpec& spec) {
 }
 
 int64_t registerFusion(const Node* fusion_group) {
-  // Creates and stores the FusionSpec
-  auto graph = fusion_group->g(attr::Subgraph)->copy();
-  EraseShapeInformation(graph);
-  const auto key = store(graph);
+  auto graph = normalizeGraphForCache(fusion_group->g(attr::Subgraph));
 
-  if (canFuseOnCPU() || canFuseOnGPU()) {
-    const auto maybe_spec = retrieve(key);
-    JIT_ASSERT(maybe_spec);
-    upfrontCompilation(**maybe_spec);
+  // Don't re-register the fusion if we can use a pre-existing one
+  const auto maybe_spec = lookupGraph(graph);
+  if (maybe_spec) {
+    return (*maybe_spec)->key();
   }
 
+  // Unconditionally create and register the fusion.
+  // This is necessary to support our global disable-fusions flag: if someone
+  // runs some code under no-fusions mode and then runs some code with fusions
+  // enabled, the second time around the returned spec from the cache should
+  // be a valid spec (must have had upfrontCompilation run on it).
+  const auto key = store(graph);
+  const auto maybe_retrieved_spec = retrieve(key);
+  JIT_ASSERT(maybe_retrieved_spec);
+  upfrontCompilation(**maybe_retrieved_spec);
+
   return key;
 }
diff --git a/torch/csrc/jit/fuser/compiler.h b/torch/csrc/jit/fuser/compiler.h
index c25e6db..2a7f6f0 100644
--- a/torch/csrc/jit/fuser/compiler.h
+++ b/torch/csrc/jit/fuser/compiler.h
@@ -16,7 +16,8 @@ namespace torch { namespace jit { namespace fuser {
 
-// Performs device-independent "upfront" compilation of the given fusion_group
+// Performs device-independent "upfront" compilation of the given fusion_group,
+// if it has not been registered already.
 // Returns a key that can be used to run the fusion later
 TORCH_API int64_t registerFusion(const Node* fusion_group);
diff --git a/torch/csrc/jit/fuser/kernel_cache.cpp b/torch/csrc/jit/fuser/kernel_cache.cpp
index cd93447..3c52ad8 100644
--- a/torch/csrc/jit/fuser/kernel_cache.cpp
+++ b/torch/csrc/jit/fuser/kernel_cache.cpp
@@ -1,4 +1,6 @@
 #include
+#include <torch/csrc/jit/passes/canonicalize.h>
+#include <torch/csrc/jit/passes/shape_analysis.h>
 #include
 #include
@@ -11,7 +13,13 @@ struct KernelCacheImpl {
   // occurs. This is a critical property for thread-safety.
   std::mutex mutex_;
   int64_t kernel_counter{0};
+
+  // Map of fusion key to KernelSpec
   std::unordered_map<int64_t, KernelSpec> specMap_;
+
+  // Map of pretty-printed graph string to fusion key
+  // Used to check if a graph has already been cached in specMap_
+  std::unordered_map<std::string, int64_t> graphToKey_;
 };
 
 static KernelCacheImpl& getKernelCache() {
@@ -19,27 +27,60 @@ static KernelCacheImpl& getKernelCache() {
   return cache;
 }
 
+int64_t debugNumCachedKernelSpecs() {
+  auto& cache = getKernelCache();
+  std::lock_guard<std::mutex> guard{cache.mutex_};
+  return cache.specMap_.size();
+}
+
+std::shared_ptr<Graph> normalizeGraphForCache(const std::shared_ptr<Graph>& graph) {
+  auto result = Canonicalize(graph, /*keep_unique_names=*/false);
+  EraseShapeInformation(result);
+  return result;
+}
+
 // TODO: lookup by historic string key to start, then issue key
 // as appropriate for faster lookup in the future
+// precondition: graph has been normalized via normalizeGraphForCache
 int64_t store(std::shared_ptr<Graph> graph) {
   auto& cache = getKernelCache();
+  std::string repr = graph->toString();
+
   std::lock_guard<std::mutex> guard{cache.mutex_};
   const auto key = cache.kernel_counter++;
   cache.specMap_.emplace(
     std::piecewise_construct
     , std::forward_as_tuple(key)
     , std::forward_as_tuple(key, graph));
+  cache.graphToKey_.emplace(std::make_pair(std::move(repr), key));
   return key;
 }
 
+// XXX: Does not grab mutex
+static at::optional<KernelSpec*> nolock_retrieve(
+    KernelCacheImpl& cache, const int64_t key) {
+  auto it = cache.specMap_.find(key);
+  if (it == cache.specMap_.end()) return at::nullopt;
+  return &(it->second);
+}
+
 at::optional<KernelSpec*> retrieve(const int64_t key) {
   auto& cache = getKernelCache();
   std::lock_guard<std::mutex> guard{cache.mutex_};
-  auto it = cache.specMap_.find(key);
-  if (it == cache.specMap_.end()) return nullptr;
-  return &(it->second);
+  return nolock_retrieve(cache, key);
+}
+
+// precondition: graph has been normalized via normalizeGraphForCache
+at::optional<KernelSpec*> lookupGraph(std::shared_ptr<Graph> graph) {
+  auto& cache = getKernelCache();
+  std::string repr = graph->toString();
+
+  std::lock_guard<std::mutex> guard{cache.mutex_};
+  auto it = cache.graphToKey_.find(repr);
+  if (it == cache.graphToKey_.end()) return at::nullopt;
+  return nolock_retrieve(cache, it->second);
 }
 
 } // namespace fuser
 } // namespace jit
-} // namespace torch
\ No newline at end of file
+} // namespace torch
diff --git a/torch/csrc/jit/fuser/kernel_cache.h b/torch/csrc/jit/fuser/kernel_cache.h
index 88c8eca..63b4710 100644
--- a/torch/csrc/jit/fuser/kernel_cache.h
+++ b/torch/csrc/jit/fuser/kernel_cache.h
@@ -14,12 +14,22 @@ namespace torch { namespace jit { namespace fuser {
 
 // A thread-safe cache interface.
 
+// Normalizes the graph by canonicalizing and erasing shape information
+TORCH_API std::shared_ptr<Graph> normalizeGraphForCache(const std::shared_ptr<Graph>& graph);
+
 // Stores the given graph, returning the key used to access it
 TORCH_API int64_t store(std::shared_ptr<Graph> graph);
 
+// Given a graph, find a KernelSpec based on it
+TORCH_API at::optional<KernelSpec*> lookupGraph(std::shared_ptr<Graph> graph);
+
 // Returns the graph corresponding to the given key (if it exists)
 TORCH_API at::optional<KernelSpec*> retrieve(const int64_t key);
 
+// Returns the size of the fusion key -> KernelSpec cache.
+// Only used for testing.
+TORCH_API int64_t debugNumCachedKernelSpecs();
+
 } // namespace fuser
 } // namespace jit
 } // namespace torch
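Taken together, the intended calling pattern for this cache interface is the
following (a hypothetical caller, shown only for illustration; the real one
is registerFusion() in compiler.cpp above):

  // Both lookupGraph() and store() expect an already-normalized graph.
  auto normalized = normalizeGraphForCache(subgraph);
  if (const auto spec = lookupGraph(normalized)) {
    return (*spec)->key();    // hit: reuse the existing KernelSpec's key
  }
  return store(normalized);   // miss: issue and return a fresh key

Note that normalization and the graph-to-string conversion happen before the
mutex is taken; only the map lookups and insertions are serialized.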
diff --git a/torch/csrc/jit/init.cpp b/torch/csrc/jit/init.cpp
index 2ddd772..e5b05d0 100644
--- a/torch/csrc/jit/init.cpp
+++ b/torch/csrc/jit/init.cpp
@@ -8,6 +8,7 @@
 #include
 #include
 #include
+#include <torch/csrc/jit/fuser/kernel_cache.h>
 #include
 #include
 #include
@@ -93,6 +94,10 @@ void initJITBindings(PyObject *module) {
   py::class_<python::IODescriptor>(m, "IODescriptor"); // NOLINT(bugprone-unused-raii)
 
   m.def("_jit_init", loadPythonClasses)
+#if USE_CUDA_FUSER || USE_CPU_FUSER
+      .def("_jit_debug_fuser_num_cached_kernel_specs",
+           torch::jit::fuser::debugNumCachedKernelSpecs)
+#endif
       .def("_jit_pass_onnx", ToONNX)
       .def("_jit_pass_lower_all_tuples", LowerAllTuples)
       .def("_jit_pass_onnx_peephole", PeepholeOptimizeONNX)
diff --git a/torch/csrc/jit/passes/canonicalize.cpp b/torch/csrc/jit/passes/canonicalize.cpp
index 7f60025..80971e6 100644
--- a/torch/csrc/jit/passes/canonicalize.cpp
+++ b/torch/csrc/jit/passes/canonicalize.cpp
@@ -4,17 +4,27 @@ namespace torch { namespace jit {
 
 // Canonicalize a graph, renumbering it so that all structurally equivalent
 // graphs have same numbers.
-std::shared_ptr<Graph> Canonicalize(const std::shared_ptr<Graph>& graph) {
+// keep_unique_names: If false, canonicalizes unique names by removing them
+// and replacing them with normal value names.
+// Otherwise, ignores values with unique names.
+std::shared_ptr<Graph> Canonicalize(
+    const std::shared_ptr<Graph>& graph, bool keep_unique_names) {
   auto r = std::make_shared<Graph>(graph->current_scope());
   std::unordered_map<Value*, Value*> rn_env;
   auto rn_fn = [&](Value* v) { return rn_env.at(v); };
   for (auto* input : graph->inputs()) {
     auto* r_input = r->addInput();
     r_input->copyMetadata(input);
+    if (!keep_unique_names) r_input->setUniqueName("");
     rn_env[input] = r_input;
   }
   for (auto* node : graph->nodes()) {
     auto* r_node = r->createClone(node, rn_fn);
+    if (!keep_unique_names) {
+      for (auto* output : r_node->outputs()) {
+        output->setUniqueName("");
+      }
+    }
     r->appendNode(r_node);
     auto outputs = node->outputs();
     auto r_outputs = r_node->outputs();
@@ -22,7 +32,7 @@ std::shared_ptr<Graph> Canonicalize(const std::shared_ptr<Graph>& graph) {
       rn_env[outputs.at(i)] = r_outputs.at(i);
     }
     if (node->hasAttribute(attr::Subgraph)) {
-      r_node->g_(attr::Subgraph, Canonicalize(node->g(attr::Subgraph)));
+      r_node->g_(attr::Subgraph, Canonicalize(node->g(attr::Subgraph), keep_unique_names));
     }
   }
   for (auto* output : graph->outputs()) {
diff --git a/torch/csrc/jit/passes/canonicalize.h b/torch/csrc/jit/passes/canonicalize.h
index 4c85bfd..0d1e1fa 100644
--- a/torch/csrc/jit/passes/canonicalize.h
+++ b/torch/csrc/jit/passes/canonicalize.h
@@ -4,6 +4,7 @@
 
 namespace torch { namespace jit {
 
-TORCH_API std::shared_ptr<Graph> Canonicalize(const std::shared_ptr<Graph>& graph);
+TORCH_API std::shared_ptr<Graph> Canonicalize(
+    const std::shared_ptr<Graph>& graph, bool keep_unique_names=true);
 
 }}
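The effect of the new flag can be seen with the helpers from the C++ test
above; the following sketch (illustrative only, not part of the patch) shows
that two graphs differing only in unique value names become textually
identical once canonicalized with keep_unique_names=false, which is exactly
the property the string-keyed KernelSpec cache relies on:

  // graph1 and graph2 compute the same thing but name their values
  // "c1"/"d1" vs. "c2"/"d2", so their printed forms differ...
  auto graph1 = createGraphWithNames("c1", "d1");
  auto graph2 = createGraphWithNames("c2", "d2");
  JIT_ASSERT(graph1->toString() != graph2->toString());

  // ...until canonicalization strips the unique names and renumbers values,
  // after which both graphs print (and therefore cache) identically.
  auto canon1 = Canonicalize(graph1, /*keep_unique_names=*/false);
  auto canon2 = Canonicalize(graph2, /*keep_unique_names=*/false);
  JIT_ASSERT(canon1->toString() == canon2->toString());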