JIT_TEST(FromQualString)
JIT_TEST(InternedStrings)
JIT_TEST(IValue)
+JIT_TEST(RegisterFusionCachesKernel)
JIT_TEST(SchemaParser)
JIT_TEST(TopologicalIndex)
JIT_TEST(TopologicalMove)
#include "torch/csrc/jit/passes/constant_propagation.h"
#include "torch/csrc/jit/passes/create_autodiff_subgraphs.h"
#include "torch/csrc/jit/passes/dead_code_elimination.h"
+#include "torch/csrc/jit/passes/graph_fuser.h"
#include "torch/csrc/jit/passes/lower_grad_of.h"
#include "torch/csrc/jit/passes/lower_tuples.h"
#include "torch/csrc/jit/passes/requires_grad_analysis.h"
out << "\n";
}
+void testRegisterFusionCachesKernel(std::ostream& out = std::cout) {
+ // Build up a fake graph with a FusionGroup
+ auto createGraphWithNames = [](std::string cname, std::string dname) {
+ auto graph = std::make_shared<Graph>();
+ at::ScalarType s = at::ScalarType::Float;
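+    // A "complete" tensor type: dtype, device, sizes, and strides are all known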
+ auto type = CompleteTensorType::create(s, at::kCPU, {2, 3, 4}, {12, 4, 1});
+ auto a = SymbolicVariable::asNewInput(*graph, type);
+ auto b = SymbolicVariable::asNewInput(*graph, type);
+ auto c = a * b;
+ auto d = c * a;
+ c.value()->setUniqueName(cname);
+ d.value()->setUniqueName(dname);
+ graph->registerOutput(d.value());
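+    // Run the graph fuser to collapse the pointwise multiplies into a
+    // prim::FusionGroup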
+ FuseGraph(graph);
+ return graph;
+ };
+
+ auto getFusionGroup = [](const std::shared_ptr<Graph>& graph) {
+ const auto& nodes = graph->nodes();
+ auto maybe_fusion_group = std::find_if(
+ nodes.begin(), nodes.end(),
+ [](const Node* node) { return node->kind() == prim::FusionGroup; });
+ JIT_ASSERTM(
+ maybe_fusion_group != nodes.end(),
+ "testRegisterFusionCachesKernel: could not create FusionGroup");
+ return *maybe_fusion_group;
+ };
+
+ // Create two alpha-equivalent fusion groups
+ auto graph1 = createGraphWithNames("c1", "d1");
+ auto fg1 = getFusionGroup(graph1);
+
+ auto graph2 = createGraphWithNames("c2", "d2");
+ auto fg2 = getFusionGroup(graph2);
+
+ // Register both with the fusion compiler.
+ auto expected_key = registerFusion(fg1);
+ auto second_key = registerFusion(fg2);
+
+  // Because the graphs are alpha-equivalent, they should return the same key
+  // and therefore share a KernelSpec, so their specializations share kernels
+ ASSERT_EQ(second_key, expected_key);
+}
+
void testCreateAutodiffSubgraphs(std::ostream& out = std::cout) {
auto graph = build_lstm();
CreateAutodiffSubgraphs(graph, /*threshold=*/2);
ge = self.checkTrace(scaleshift, inputs)
self.assertExpectedGraph(ge.graph_for(*inputs))
- @staticmethod
- def _test_cast_Float(self, device):
- def f(x, y):
- z = x.float()
- return z + y
-
- inputs = [
- torch.randn(4, 4, dtype=torch.double, device=device),
- torch.randn(4, 4, dtype=torch.float, device=device),
- ]
-
- ge = self.checkScript(f, inputs)
- self.assertAllFused(ge.graph_for(*inputs))
-
- @unittest.skipIf(IS_WINDOWS or IS_SANDCASTLE, "NYI: fuser support for Windows or Sandcastle")
- @enable_cpu_fuser
- def test_cast_Float(self):
- return self._test_cast_Float(self, 'cpu')
-
- @unittest.skipIf(IS_WINDOWS or IS_SANDCASTLE, "NYI: fuser support for Windows or Sandcastle")
- @unittest.skipIf(not RUN_CUDA, "No CUDA")
- @skipIfRocm
- def test_cast_Float_cuda(self):
- return self._test_cast_Float(self, 'cuda')
-
@unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
@unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
@unittest.skipIf(not RUN_CUDA_HALF, "no half support")
ge = self.checkTrace(self.fn_test_exp, (x, y))
self.assertAllFused(ge.graph_for(x, y))
+ @unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
+ @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
+ @unittest.skipIf(not RUN_CUDA_MULTI_GPU, "needs non-zero device")
+ @skipIfRocm
+ @enable_cpu_fuser
+ def test_fusion_reuse_multi_gpu(self):
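+        # A chain of pointwise multiplies that fuses into a single FusionGroup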
+ def fn(x, y):
+ return x * y * x * y
+
+ inputs_cpu = [
+ torch.randn(4, 4, dtype=torch.float),
+ torch.randn(4, 4, dtype=torch.float),
+ ]
+ inputs_cuda0 = [x.cuda(0) for x in inputs_cpu]
+ inputs_cuda1 = [y.cuda(1) for y in inputs_cpu]
+
+ # Should not crash; these should compile different kernels.
+ ge = self.checkScript(fn, inputs_cpu)
+ self.assertAllFused(ge.graph_for(*inputs_cpu))
+ ge(*inputs_cuda0)
+ ge(*inputs_cuda1)
+
+ @unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
+ @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
+ @unittest.skipIf(not RUN_CUDA_MULTI_GPU, "needs non-zero device")
+ @skipIfRocm
+ @enable_cpu_fuser
+ def test_kernel_cache_multi_gpu(self):
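+        # Identity helper; the fuser does not fuse through this call, so each
+        # chain below lands in its own FusionGroup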
+ def not_fusible(x):
+ return x
+
+ def fn(x, y, z):
+ x_out = x * x * x * x * x # fusion: lambda x. x * x * x * x * x
+ y_out = y * y * y * y * y
+ z_out = z * z * z * z * z
+ return not_fusible(x_out), not_fusible(y_out), not_fusible(z_out)
+
+ inputs = [
+ torch.randn(4, 4, dtype=torch.float),
+ torch.randn(4, 4, dtype=torch.float, device='cuda:0'),
+ torch.randn(4, 4, dtype=torch.float, device='cuda:1'),
+ ]
+
+ prev_cache_size = torch._C._jit_debug_fuser_num_cached_kernel_specs()
+
+ # There are 3 FusionGroups. Because they have the same graph, they
+ # should reuse the same KernelSpec in the KernelSpec cache.
+ ge = self.checkScript(fn, inputs)
+ self.assertGraphContainsExactly(
+ ge.graph_for(*inputs), 'prim::FusionGroup', 3, True)
+ new_cache_size = torch._C._jit_debug_fuser_num_cached_kernel_specs()
+ # XXX: This assumes that the same kernel isn't already used by another test
+ self.assertEqual(new_cache_size - prev_cache_size, 1)
+
# TODO: This test doesn't offer anything valuable, maybe we should delete it
@unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
@unittest.skipIf(not RUN_CUDA_MULTI_GPU, "needs non-zero device")
#include <torch/csrc/jit/type.h>
#include <torch/csrc/jit/code_template.h>
#include <torch/csrc/jit/assertions.h>
+#include <torch/csrc/jit/passes/canonicalize.h>
#include <torch/csrc/jit/passes/shape_analysis.h>
#include <torch/csrc/jit/fuser/interface.h>
#include <torch/csrc/jit/fuser/kernel_cache.h>
#include <torch/csrc/jit/fuser/codegen.h>
#include <torch/csrc/jit/fuser/tensor_desc.h>
+#include "torch/csrc/jit/fuser/interface.h"
#if USE_CUDA_FUSER
#include <torch/csrc/jit/fuser/cuda/fused_kernel.h>
}
int64_t registerFusion(const Node* fusion_group) {
- // Creates and stores the FusionSpec
- auto graph = fusion_group->g(attr::Subgraph)->copy();
- EraseShapeInformation(graph);
- const auto key = store(graph);
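+  // Canonicalize and strip shapes so that alpha-equivalent subgraphs
+  // stringify identically and can share a single cache entry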
+ auto graph = normalizeGraphForCache(fusion_group->g(attr::Subgraph));
- if (canFuseOnCPU() || canFuseOnGPU()) {
- const auto maybe_spec = retrieve(key);
- JIT_ASSERT(maybe_spec);
- upfrontCompilation(**maybe_spec);
+ // Don't re-register the fusion if we can use a pre-existing one
+ const auto maybe_spec = lookupGraph(graph);
+ if (maybe_spec) {
+ return (*maybe_spec)->key();
}
+  // Unconditionally create and register the fusion.
+  // This is necessary to support our global fusions-disabled flag: if code
+  // first runs with fusions disabled and later runs with fusions enabled,
+  // the spec returned from the cache the second time around must already be
+  // valid (i.e. upfrontCompilation must have been run on it).
+ const auto key = store(graph);
+ const auto maybe_retrieved_spec = retrieve(key);
+ JIT_ASSERT(maybe_retrieved_spec);
+ upfrontCompilation(**maybe_retrieved_spec);
+
return key;
}
namespace torch { namespace jit { namespace fuser {
-// Performs device-independent "upfront" compilation of the given fusion_group
+// Performs device-independent "upfront" compilation of the given fusion_group,
+// if it has not been registered already.
// Returns a key that can be used to run the fusion later
TORCH_API int64_t registerFusion(const Node* fusion_group);
#include <torch/csrc/jit/fuser/kernel_cache.h>
+#include <torch/csrc/jit/passes/canonicalize.h>
+#include <torch/csrc/jit/passes/shape_analysis.h>
#include <unordered_map>
#include <mutex>
// occurs. This is a critical property for thread-safety.
std::mutex mutex_;
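+
+  // Monotonically increasing counter used to mint fusion keys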
int64_t kernel_counter{0};
+
+ // Map of fusion key to KernelSpec
std::unordered_map<int64_t, KernelSpec> specMap_;
+
+ // Map of pretty-printed graph string to fusion key
+ // Used to check if a graph has already been cached in specMap_
+ std::unordered_map<std::string, int64_t> graphToKey_;
};
static KernelCacheImpl& getKernelCache() {
return cache;
}
+int64_t debugNumCachedKernelSpecs() {
+ auto& cache = getKernelCache();
+ std::lock_guard<std::mutex> guard{cache.mutex_};
+ return cache.specMap_.size();
+}
+
+std::shared_ptr<Graph> normalizeGraphForCache(const std::shared_ptr<Graph>& graph) {
+ auto result = Canonicalize(graph, /*keep_unique_names=*/false);
+ EraseShapeInformation(result);
+ return result;
+}
+
-// TODO: lookup by historic string key to start, then issue key
-// as appropriate for faster lookup in the future
+// TODO: look up by the string key first, then issue an integer key
+// that callers can use for faster lookup in the future
+// precondition: graph has been normalized via normalizeGraphForCache
int64_t store(std::shared_ptr<Graph> graph) {
auto& cache = getKernelCache();
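+  // Stringify the graph before taking the lock to keep the critical section short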
+ std::string repr = graph->toString();
+
std::lock_guard<std::mutex> guard{cache.mutex_};
const auto key = cache.kernel_counter++;
cache.specMap_.emplace(
std::piecewise_construct
, std::forward_as_tuple(key)
, std::forward_as_tuple(key, graph));
+  cache.graphToKey_.emplace(std::move(repr), key);
return key;
}
+// XXX: Does not grab the mutex; callers must hold cache.mutex_
+static at::optional<KernelSpec*> nolock_retrieve(
+ KernelCacheImpl& cache, const int64_t key) {
+ auto it = cache.specMap_.find(key);
+ if (it == cache.specMap_.end()) return at::nullopt;
+ return &(it->second);
+}
+
at::optional<KernelSpec*> retrieve(const int64_t key) {
auto& cache = getKernelCache();
std::lock_guard<std::mutex> guard{cache.mutex_};
- auto it = cache.specMap_.find(key);
- if (it == cache.specMap_.end()) return nullptr;
- return &(it->second);
+ return nolock_retrieve(cache, key);
+}
+
+// precondition: graph has been normalized via normalizeGraphForCache
+at::optional<KernelSpec*> lookupGraph(std::shared_ptr<Graph> graph) {
+ auto& cache = getKernelCache();
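+  // As in store(), stringify before taking the lock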
+ std::string repr = graph->toString();
+
+ std::lock_guard<std::mutex> guard{cache.mutex_};
+ auto it = cache.graphToKey_.find(repr);
+ if (it == cache.graphToKey_.end()) return at::nullopt;
+ return nolock_retrieve(cache, it->second);
}
} // namespace fuser
} // namespace jit
-} // namespace torch
\ No newline at end of file
+} // namespace torch
// A thread-safe cache interface.
+// Normalizes the graph by canonicalizing and erasing shape information
+TORCH_API std::shared_ptr<Graph> normalizeGraphForCache(const std::shared_ptr<Graph>& graph);
+
// Stores the given graph, returning the key used to access it
TORCH_API int64_t store(std::shared_ptr<Graph> graph);
+// Returns the KernelSpec for the given graph, if one has been stored
+// (the graph must first be normalized via normalizeGraphForCache)
+TORCH_API at::optional<KernelSpec*> lookupGraph(std::shared_ptr<Graph> graph);
+
-// Returns the graph corresponding to the given key (if it exists)
+// Returns the KernelSpec corresponding to the given key (if it exists)
TORCH_API at::optional<KernelSpec*> retrieve(const int64_t key);
+// Returns the size of the fusion key -> KernelSpec cache.
+// Only used for testing.
+TORCH_API int64_t debugNumCachedKernelSpecs();
+
} // namespace fuser
} // namespace jit
} // namespace torch
#include <torch/csrc/jit/export.h>
#include <torch/csrc/jit/import.h>
#include <torch/csrc/jit/argument_spec.h>
+#include <torch/csrc/jit/fuser/kernel_cache.h>
#include <torch/csrc/jit/passes/remove_expands.h>
#include <torch/csrc/jit/passes/graph_fuser.h>
#include <torch/csrc/jit/passes/onnx.h>
py::class_<python::IODescriptor>(m, "IODescriptor"); // NOLINT(bugprone-unused-raii)
m.def("_jit_init", loadPythonClasses)
+#if USE_CUDA_FUSER || USE_CPU_FUSER
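+    // Exposed for testing; only available when a fuser backend is built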
+ .def("_jit_debug_fuser_num_cached_kernel_specs",
+ torch::jit::fuser::debugNumCachedKernelSpecs)
+#endif
.def("_jit_pass_onnx", ToONNX)
.def("_jit_pass_lower_all_tuples", LowerAllTuples)
.def("_jit_pass_onnx_peephole", PeepholeOptimizeONNX)
// Canonicalize a graph, renumbering it so that all structurally equivalent
// graphs have same numbers.
-std::shared_ptr<Graph> Canonicalize(const std::shared_ptr<Graph>& graph) {
+// keep_unique_names: If false, removes unique names and replaces them
+//   with normal value names.
+//   If true, preserves the values' existing unique names.
+std::shared_ptr<Graph> Canonicalize(
+ const std::shared_ptr<Graph>& graph, bool keep_unique_names) {
auto r = std::make_shared<Graph>(graph->current_scope());
std::unordered_map<Value*, Value*> rn_env;
auto rn_fn = [&](Value* v) { return rn_env.at(v); };
for (auto* input : graph->inputs()) {
auto* r_input = r->addInput();
r_input->copyMetadata(input);
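+    // setUniqueName("") clears the unique name, reverting the value to its
+    // default numeric name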
+ if (!keep_unique_names) r_input->setUniqueName("");
rn_env[input] = r_input;
}
for (auto* node : graph->nodes()) {
auto* r_node = r->createClone(node, rn_fn);
+ if (!keep_unique_names) {
+ for (auto* output : r_node->outputs()) {
+ output->setUniqueName("");
+ }
+ }
r->appendNode(r_node);
auto outputs = node->outputs();
auto r_outputs = r_node->outputs();
rn_env[outputs.at(i)] = r_outputs.at(i);
}
if (node->hasAttribute(attr::Subgraph)) {
- r_node->g_(attr::Subgraph, Canonicalize(node->g(attr::Subgraph)));
+ r_node->g_(attr::Subgraph, Canonicalize(node->g(attr::Subgraph), keep_unique_names));
}
}
for (auto* output : graph->outputs()) {
namespace torch { namespace jit {
-TORCH_API std::shared_ptr<Graph> Canonicalize(const std::shared_ptr<Graph>& graph);
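+// keep_unique_names defaults to true so existing callers keep their values'
+// unique names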
+TORCH_API std::shared_ptr<Graph> Canonicalize(
+    const std::shared_ptr<Graph>& graph, bool keep_unique_names = true);
}}