Reuse KernelSpec for FusionGroups with equivalent graphs (#14541)
authorRichard Zou <zou3519@gmail.com>
Thu, 13 Dec 2018 15:51:08 +0000 (07:51 -0800)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Thu, 13 Dec 2018 15:54:35 +0000 (07:54 -0800)
Summary:
Before this PR, loop unrolling + the graph fuser was creating multiple
FusionGroups with the same bodies (with different variable names) for
JIT LSTMs. Each FusionGroup got registered to a separate fusion key;
each key resulted in a different compilation for the same
specializations.

This PR makes it so that when registering FusionGroups with the fusion
compiler, the compiler first checks the KernelSpec cache to see if the
FusionGroup's graph exists already. If it does, then return the
corresponding KernelSpec's key to share compiled kernels.

In addition, graphs in the KernelSpec cache are canonicalized before
being cached. I added a flag to the canonicalize pass to remove unique
names of values.

This shortens the compile time for a JIT LSTM (seq_len of 100, loop
unroll factor of 8) from 5.3s to 2.3s. Most of this compile time is
running the graph fuser and/or fusion compiler; while this PR
makes it so that there is only one unique kernel in the forward pass,
there are a lot of different kernels (6) in the backward pass
(after loop unrolling) that should be investigated.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/14541

Differential Revision: D13324487

Pulled By: zou3519

fbshipit-source-id: b841d82ed35a959b5cfc72db033bf5a7b42cc4fb

test/cpp/jit/gtest.cpp
test/cpp/jit/tests.h
test/test_jit.py
torch/csrc/jit/fuser/compiler.cpp
torch/csrc/jit/fuser/compiler.h
torch/csrc/jit/fuser/kernel_cache.cpp
torch/csrc/jit/fuser/kernel_cache.h
torch/csrc/jit/init.cpp
torch/csrc/jit/passes/canonicalize.cpp
torch/csrc/jit/passes/canonicalize.h

index fb032db..6715588 100644 (file)
@@ -23,6 +23,7 @@ JIT_TEST(DynamicDAG)
 JIT_TEST(FromQualString)
 JIT_TEST(InternedStrings)
 JIT_TEST(IValue)
+JIT_TEST(RegisterFusionCachesKernel)
 JIT_TEST(SchemaParser)
 JIT_TEST(TopologicalIndex)
 JIT_TEST(TopologicalMove)
index db1caf2..5782dc0 100644 (file)
@@ -45,6 +45,7 @@
 #include "torch/csrc/jit/passes/constant_propagation.h"
 #include "torch/csrc/jit/passes/create_autodiff_subgraphs.h"
 #include "torch/csrc/jit/passes/dead_code_elimination.h"
+#include "torch/csrc/jit/passes/graph_fuser.h"
 #include "torch/csrc/jit/passes/lower_grad_of.h"
 #include "torch/csrc/jit/passes/lower_tuples.h"
 #include "torch/csrc/jit/passes/requires_grad_analysis.h"
@@ -878,6 +879,50 @@ void testDifferentiateWithRequiresGrad(std::ostream& out = std::cout) {
   out << "\n";
 }
 
+void testRegisterFusionCachesKernel(std::ostream& out = std::cout) {
+  // Build up a fake graph with a FusionGroup
+  auto createGraphWithNames = [](std::string cname, std::string dname) {
+    auto graph = std::make_shared<Graph>();
+    at::ScalarType s = at::ScalarType::Float;
+    auto type = CompleteTensorType::create(s, at::kCPU, {2, 3, 4}, {12, 4, 1});
+    auto a = SymbolicVariable::asNewInput(*graph, type);
+    auto b = SymbolicVariable::asNewInput(*graph, type);
+    auto c = a * b;
+    auto d = c * a;
+    c.value()->setUniqueName(cname);
+    d.value()->setUniqueName(dname);
+    graph->registerOutput(d.value());
+    FuseGraph(graph);
+    return graph;
+  };
+
+  auto getFusionGroup = [](const std::shared_ptr<Graph>& graph) {
+    const auto& nodes = graph->nodes();
+    auto maybe_fusion_group = std::find_if(
+        nodes.begin(), nodes.end(),
+        [](const Node* node) { return node->kind() == prim::FusionGroup; });
+    JIT_ASSERTM(
+        maybe_fusion_group != nodes.end(),
+        "testRegisterFusionCachesKernel: could not create FusionGroup");
+    return *maybe_fusion_group;
+  };
+
+  // Create two alpha-equivalent fusion groups
+  auto graph1 = createGraphWithNames("c1", "d1");
+  auto fg1 = getFusionGroup(graph1);
+
+  auto graph2 = createGraphWithNames("c2", "d2");
+  auto fg2 = getFusionGroup(graph2);
+
+  // Register both with the fusion compiler.
+  auto expected_key = registerFusion(fg1);
+  auto second_key = registerFusion(fg2);
+
+  // Because the graphs are alpha-equivalent, they should return the same key
+  // and therefore share a KernelSpec to share kernels for specializations
+  ASSERT_EQ(second_key, expected_key);
+}
+
 void testCreateAutodiffSubgraphs(std::ostream& out = std::cout) {
   auto graph = build_lstm();
   CreateAutodiffSubgraphs(graph, /*threshold=*/2);
index 0000e93..80801c3 100644 (file)
@@ -9627,31 +9627,6 @@ class TestFuser(JitTestCase):
         ge = self.checkTrace(scaleshift, inputs)
         self.assertExpectedGraph(ge.graph_for(*inputs))
 
-    @staticmethod
-    def _test_cast_Float(self, device):
-        def f(x, y):
-            z = x.float()
-            return z + y
-
-        inputs = [
-            torch.randn(4, 4, dtype=torch.double, device=device),
-            torch.randn(4, 4, dtype=torch.float, device=device),
-        ]
-
-        ge = self.checkScript(f, inputs)
-        self.assertAllFused(ge.graph_for(*inputs))
-
-    @unittest.skipIf(IS_WINDOWS or IS_SANDCASTLE, "NYI: fuser support for Windows or Sandcastle")
-    @enable_cpu_fuser
-    def test_cast_Float(self):
-        return self._test_cast_Float(self, 'cpu')
-
-    @unittest.skipIf(IS_WINDOWS or IS_SANDCASTLE, "NYI: fuser support for Windows or Sandcastle")
-    @unittest.skipIf(not RUN_CUDA, "No CUDA")
-    @skipIfRocm
-    def test_cast_Float_cuda(self):
-        return self._test_cast_Float(self, 'cuda')
-
     @unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
     @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
     @unittest.skipIf(not RUN_CUDA_HALF, "no half support")
@@ -9942,6 +9917,60 @@ class TestFuser(JitTestCase):
         ge = self.checkTrace(self.fn_test_exp, (x, y))
         self.assertAllFused(ge.graph_for(x, y))
 
+    @unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
+    @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
+    @unittest.skipIf(not RUN_CUDA_MULTI_GPU, "needs non-zero device")
+    @skipIfRocm
+    @enable_cpu_fuser
+    def test_fusion_reuse_multi_gpu(self):
+        def fn(x, y):
+            return x * y * x * y
+
+        inputs_cpu = [
+            torch.randn(4, 4, dtype=torch.float),
+            torch.randn(4, 4, dtype=torch.float),
+        ]
+        inputs_cuda0 = [x.cuda(0) for x in inputs_cpu]
+        inputs_cuda1 = [y.cuda(1) for y in inputs_cpu]
+
+        # Should not crash; these should compile different kernels.
+        ge = self.checkScript(fn, inputs_cpu)
+        self.assertAllFused(ge.graph_for(*inputs_cpu))
+        ge(*inputs_cuda0)
+        ge(*inputs_cuda1)
+
+    @unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
+    @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
+    @unittest.skipIf(not RUN_CUDA_MULTI_GPU, "needs non-zero device")
+    @skipIfRocm
+    @enable_cpu_fuser
+    def test_kernel_cache_multi_gpu(self):
+        def not_fusible(x):
+            return x
+
+        def fn(x, y, z):
+            x_out = x * x * x * x * x  # fusion: lambda x. x * x * x * x * x
+            y_out = y * y * y * y * y
+            z_out = z * z * z * z * z
+            return not_fusible(x_out), not_fusible(y_out), not_fusible(z_out)
+
+        inputs = [
+            torch.randn(4, 4, dtype=torch.float),
+            torch.randn(4, 4, dtype=torch.float, device='cuda:0'),
+            torch.randn(4, 4, dtype=torch.float, device='cuda:1'),
+        ]
+
+        prev_cache_size = torch._C._jit_debug_fuser_num_cached_kernel_specs()
+
+        # There are 3 FusionGroups. Because they have the same graph, they
+        # should reuse the same KernelSpec in the KernelSpec cache.
+        ge = self.checkScript(fn, inputs)
+        self.assertGraphContainsExactly(
+            ge.graph_for(*inputs), 'prim::FusionGroup', 3, True)
+        new_cache_size = torch._C._jit_debug_fuser_num_cached_kernel_specs()
+        # XXX: This assumes that the same kernel isn't already used by another test
+        self.assertEqual(new_cache_size - prev_cache_size, 1)
+
     # TODO: This test doesn't offer anything valuable, maybe we should delete it
     @unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
     @unittest.skipIf(not RUN_CUDA_MULTI_GPU, "needs non-zero device")
index 6615310..dadf9cc 100644 (file)
@@ -5,11 +5,13 @@
 #include <torch/csrc/jit/type.h>
 #include <torch/csrc/jit/code_template.h>
 #include <torch/csrc/jit/assertions.h>
+#include <torch/csrc/jit/passes/canonicalize.h>
 #include <torch/csrc/jit/passes/shape_analysis.h>
 #include <torch/csrc/jit/fuser/interface.h>
 #include <torch/csrc/jit/fuser/kernel_cache.h>
 #include <torch/csrc/jit/fuser/codegen.h>
 #include <torch/csrc/jit/fuser/tensor_desc.h>
+#include "torch/csrc/jit/fuser/interface.h"
 
 #if USE_CUDA_FUSER
   #include <torch/csrc/jit/fuser/cuda/fused_kernel.h>
@@ -132,17 +134,24 @@ static void upfrontCompilation(KernelSpec& spec) {
 }
 
 int64_t registerFusion(const Node* fusion_group) {
-  // Creates and stores the FusionSpec
-  auto graph = fusion_group->g(attr::Subgraph)->copy();
-  EraseShapeInformation(graph);
-  const auto key = store(graph);
+  auto graph = normalizeGraphForCache(fusion_group->g(attr::Subgraph));
 
-  if (canFuseOnCPU() || canFuseOnGPU()) {
-    const auto maybe_spec = retrieve(key);
-    JIT_ASSERT(maybe_spec);
-    upfrontCompilation(**maybe_spec);
+  // Don't re-register the fusion if we can use a pre-existing one
+  const auto maybe_spec = lookupGraph(graph);
+  if (maybe_spec) {
+    return (*maybe_spec)->key();
   }
 
+  // Unconditionally create and register the fusion
+  // This is necessary to support our global disable fusions flag: if someone
+  // runs some code under no-fusions mode and then runs some code with fusions
+  // enabled, the second time around the returned spec from the cache should
+  // be a valid spec (must have had upfrontCompilation run on it).
+  const auto key = store(graph);
+  const auto maybe_retrieved_spec = retrieve(key);
+  JIT_ASSERT(maybe_retrieved_spec);
+  upfrontCompilation(**maybe_retrieved_spec);
+
   return key;
 }
 
index c25e6db..2a7f6f0 100644 (file)
@@ -16,7 +16,8 @@
 
 namespace torch { namespace jit { namespace fuser {
 
-// Performs device-independent "upfront" compilation of the given fusion_group
+// Performs device-independent "upfront" compilation of the given fusion_group,
+// if it has not been registered already.
 // Returns a key that can be used to run the fusion later
 TORCH_API int64_t registerFusion(const Node* fusion_group);
 
index cd93447..3c52ad8 100644 (file)
@@ -1,4 +1,6 @@
 #include <torch/csrc/jit/fuser/kernel_cache.h>
+#include <torch/csrc/jit/passes/canonicalize.h>
+#include <torch/csrc/jit/passes/shape_analysis.h>
 
 #include <unordered_map>
 #include <mutex>
@@ -11,7 +13,13 @@ struct KernelCacheImpl {
   // occurs. This is a critical property for thread-safety. 
   std::mutex mutex_;
   int64_t kernel_counter{0};
+
+  // Map of fusion key to KernelSpec
   std::unordered_map<int64_t, KernelSpec> specMap_;
+
+  // Map of pretty-printed graph string to fusion key
+  // Used to check if a graph has already been cached in specMap_
+  std::unordered_map<std::string, int64_t> graphToKey_;
 };
 
 static KernelCacheImpl& getKernelCache() {
@@ -19,27 +27,60 @@ static KernelCacheImpl& getKernelCache() {
   return cache;
 }
 
+int64_t debugNumCachedKernelSpecs() {
+  auto& cache = getKernelCache();
+  std::lock_guard<std::mutex> guard{cache.mutex_};
+  return cache.specMap_.size();
+}
+
+std::shared_ptr<Graph> normalizeGraphForCache(const std::shared_ptr<Graph>& graph) {
+  auto result = Canonicalize(graph, /*keep_unique_names=*/false);
+  EraseShapeInformation(result);
+  return result;
+}
+
 // TODO: lookup by historic string key to start, then issue key
 // as appropriate for faster lookup in the future
+// precondition: graph has been normalized via normalizeGraphForCache
 int64_t store(std::shared_ptr<Graph> graph) {
   auto& cache = getKernelCache();
+  std::string repr = graph->toString();
+
   std::lock_guard<std::mutex> guard{cache.mutex_};
   const auto key = cache.kernel_counter++;
   cache.specMap_.emplace(
     std::piecewise_construct
   , std::forward_as_tuple(key)
   , std::forward_as_tuple(key, graph));
+  cache.graphToKey_.emplace(std::make_pair(std::move(repr), key));
   return key;
 }
 
+// XXX: Does not grab mutex
+static at::optional<KernelSpec*> nolock_retrieve(
+    KernelCacheImpl& cache, const int64_t key) {
+  auto it = cache.specMap_.find(key);
+  if (it == cache.specMap_.end()) return at::nullopt;
+  return &(it->second);
+}
+
 at::optional<KernelSpec*> retrieve(const int64_t key) { 
   auto& cache = getKernelCache();
   std::lock_guard<std::mutex> guard{cache.mutex_};
-  auto it = cache.specMap_.find(key);
-  if (it == cache.specMap_.end()) return nullptr;
-  return &(it->second);
+  return nolock_retrieve(cache, key);
+}
+
+// precondition: graph has been normalized via normalizeGraphForCache
+at::optional<KernelSpec*> lookupGraph(std::shared_ptr<Graph> graph) {
+  auto& cache = getKernelCache();
+  std::string repr = graph->toString();
+
+  std::lock_guard<std::mutex> guard{cache.mutex_};
+  auto it = cache.graphToKey_.find(repr);
+  if (it == cache.graphToKey_.end()) return at::nullopt;
+  return nolock_retrieve(cache, it->second);
 }
 
 } // namespace fuser
 } // namespace jit
-} // namespace torch
\ No newline at end of file
+} // namespace torch
index 88c8eca..63b4710 100644 (file)
@@ -14,12 +14,22 @@ namespace torch { namespace jit { namespace fuser {
 
 // A thread-safe cache interface.
 
+// Normalizes the graph by canonicalizing and erasing shape information
+TORCH_API std::shared_ptr<Graph> normalizeGraphForCache(const std::shared_ptr<Graph>& graph);
+
 // Stores the given graph, returning the key used to access it
 TORCH_API int64_t store(std::shared_ptr<Graph> graph);
 
+// Given a graph, find a KernelSpec based on it
+TORCH_API at::optional<KernelSpec*> lookupGraph(std::shared_ptr<Graph> graph);
+
 // Returns the graph corresponding to the given key (if it exists)
 TORCH_API at::optional<KernelSpec*> retrieve(const int64_t key);
 
+// Returns the size of the fusion key -> KernelSpec cache.
+// Only used for testing.
+TORCH_API int64_t debugNumCachedKernelSpecs();
+
 } // namespace fuser
 } // namespace jit
 } // namespace torch
index 2ddd772..e5b05d0 100644 (file)
@@ -8,6 +8,7 @@
 #include <torch/csrc/jit/export.h>
 #include <torch/csrc/jit/import.h>
 #include <torch/csrc/jit/argument_spec.h>
+#include <torch/csrc/jit/fuser/kernel_cache.h>
 #include <torch/csrc/jit/passes/remove_expands.h>
 #include <torch/csrc/jit/passes/graph_fuser.h>
 #include <torch/csrc/jit/passes/onnx.h>
@@ -93,6 +94,10 @@ void initJITBindings(PyObject *module) {
   py::class_<python::IODescriptor>(m, "IODescriptor"); // NOLINT(bugprone-unused-raii)
 
   m.def("_jit_init", loadPythonClasses)
+#if USE_CUDA_FUSER || USE_CPU_FUSER
+   .def("_jit_debug_fuser_num_cached_kernel_specs",
+       torch::jit::fuser::debugNumCachedKernelSpecs)
+#endif
    .def("_jit_pass_onnx", ToONNX)
    .def("_jit_pass_lower_all_tuples", LowerAllTuples)
    .def("_jit_pass_onnx_peephole", PeepholeOptimizeONNX)
index 7f60025..80971e6 100644 (file)
@@ -4,17 +4,27 @@ namespace torch { namespace jit {
 
 // Canonicalize a graph, renumbering it so that all structurally equivalent
 // graphs have same numbers.
-std::shared_ptr<Graph> Canonicalize(const std::shared_ptr<Graph>& graph) {
+// keep_unique_names: If false, canonicalizes unique names by removing them
+//   and replacing them with normal value names.
+//   Otherwise, ignores values with unique names.
+std::shared_ptr<Graph> Canonicalize(
+    const std::shared_ptr<Graph>& graph, bool keep_unique_names) {
   auto r = std::make_shared<Graph>(graph->current_scope());
   std::unordered_map<Value*, Value*> rn_env;
   auto rn_fn = [&](Value* v) { return rn_env.at(v); };
   for (auto* input : graph->inputs()) {
     auto* r_input = r->addInput();
     r_input->copyMetadata(input);
+    if (!keep_unique_names) r_input->setUniqueName("");
     rn_env[input] = r_input;
   }
   for (auto* node : graph->nodes()) {
     auto* r_node = r->createClone(node, rn_fn);
+    if (!keep_unique_names) {
+      for (auto* output : r_node->outputs()) {
+        output->setUniqueName("");
+      }
+    }
     r->appendNode(r_node);
     auto outputs = node->outputs();
     auto r_outputs = r_node->outputs();
@@ -22,7 +32,7 @@ std::shared_ptr<Graph> Canonicalize(const std::shared_ptr<Graph>& graph) {
       rn_env[outputs.at(i)] = r_outputs.at(i);
     }
     if (node->hasAttribute(attr::Subgraph)) {
-      r_node->g_(attr::Subgraph, Canonicalize(node->g(attr::Subgraph)));
+      r_node->g_(attr::Subgraph, Canonicalize(node->g(attr::Subgraph), keep_unique_names));
     }
   }
   for (auto* output : graph->outputs()) {
index 4c85bfd..0d1e1fa 100644 (file)
@@ -4,6 +4,7 @@
 
 namespace torch { namespace jit {
 
-TORCH_API std::shared_ptr<Graph> Canonicalize(const std::shared_ptr<Graph>& graph);
+TORCH_API std::shared_ptr<Graph> Canonicalize(
+    const std::shared_ptr<Graph>& graph, bool keep_unique_names=true);
 
 }}