Code string API for fuser testing (#18884)
authorJames Reed <jamesreed@fb.com>
Sat, 6 Apr 2019 00:10:13 +0000 (17:10 -0700)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Sat, 6 Apr 2019 00:13:17 +0000 (17:13 -0700)
Summary:
This adds a C++ function `debugGetFusedKernelCode` as well as a Python binding `_jit_fuser_get_fused_kernel_code` that will, given a FusionGroup graph and a set of specified inputs, return the compiled kernel source code. We can then check the contents of this source code for verification of the fuser codegen backend.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/18884

Differential Revision: D14795508

Pulled By: jamesr66a

fbshipit-source-id: 8f6e9dd13ebbb517737d893b0b5f5e9aa06af124

test/test_jit.py
torch/csrc/jit/fuser/executor.cpp
torch/csrc/jit/fuser/executor.h
torch/csrc/jit/fuser/interface.cpp
torch/csrc/jit/fuser/interface.h
torch/csrc/jit/init.cpp

index 166366c..c2a347a 100644 (file)
@@ -4355,6 +4355,35 @@ a")
 
         self.assertEqual(cu.test_integral_shape_inference(*inputs), outputs)
 
+    @unittest.skipIf(RUN_CUDA, 'This tests the CPU fuser')
+    @unittest.skipIf(IS_WINDOWS or IS_SANDCASTLE, "NYI: fuser support for Windows or Sandcastle")
+    @enable_cpu_fuser
+    def test_batchnorm_fuser_cpu(self):
+        code = '''
+            graph(%3 : Tensor,
+                  %7 : Tensor,
+                  %12 : Float(*, *),
+                  %13 : Tensor,
+                  %25 : Tensor):
+                %23 : int = prim::Constant[value=1]()
+                %22 : float = prim::Constant[value=1e-05]()
+                %26 : Tensor = aten::sqrt(%25)
+                %24 : Tensor = aten::add(%26, %22, %23)
+                %20 : Tensor = aten::reciprocal(%24)
+                %norm_invstd : Tensor = aten::mul(%20, %23)
+                %15 : Tensor = aten::sub(%12, %13, %23)
+                %11 : Tensor = aten::mul(%15, %norm_invstd)
+                %8 : Tensor = aten::mul(%11, %7)
+                %5 : Tensor = aten::add(%8, %3, %23)
+                %1 : Float(*, *) = aten::relu(%5)
+                return (%1)
+        '''
+
+        graph = parse_ir(code)
+        inputs = 5 * [torch.rand(26, 2048, dtype=torch.float)]
+        code = torch._C._jit_fuser_get_fused_kernel_code(graph, inputs)
+        FileCheck().check('sqrtf').run(code)
+
     def test_fuser_multiple_blocks(self):
         cu = torch.jit.CompilationUnit('''
         def test_fuser_multiple_blocks(this, that, theother, meme):
index 2003366..51b154f 100644 (file)
@@ -317,7 +317,7 @@ void launchFusion(
   fusion.launch_raw(numel, arguments);
 }
 
-bool runFusion(const int64_t key, Stack& stack) {
+bool runFusion(const int64_t key, Stack& stack, std::string* code_out) {
   // Short-circuits if fusion isn't enabled
   if (!canFuseOnCPU() && !canFuseOnGPU())
     return false;
@@ -373,6 +373,10 @@ bool runFusion(const int64_t key, Stack& stack) {
   maybe_kernel = spec.findKernel(arg_spec);
   AT_ASSERT(maybe_kernel);
 
+  if (code_out) {
+    *code_out = maybe_kernel.value()->code();
+  }
+
   // Launches fusion
   std::vector<at::Tensor> raw_outputs;
   launchFusion(*(*maybe_kernel), device, inputs, all_inputs, raw_outputs);
index 63f5c70..20a1f14 100644 (file)
@@ -1,7 +1,9 @@
 #pragma once
 
-#include <torch/csrc/WindowsTorchApiMacro.h>
 #include <ATen/core/stack.h>
+#include <torch/csrc/WindowsTorchApiMacro.h>
+#include <torch/csrc/jit/fuser/fused_kernel.h>
+#include <torch/csrc/jit/fuser/kernel_spec.h>
 
 #include <cstdint>
 
@@ -11,7 +13,10 @@ namespace fuser {
 
 // Runs the fusion associated with the key (see registerFusion() in interface.h)
 // on the inputs taken from the given Stack.
-TORCH_API bool runFusion(const int64_t key, Stack& stack);
+TORCH_API bool runFusion(
+    const int64_t key,
+    Stack& stack,
+    std::string* code_out = nullptr);
 
 } // namespace fuser
 } // namespace jit
index 9e2509a..dcf64e9 100644 (file)
@@ -3,6 +3,7 @@
 #include <torch/csrc/jit/fuser/compiler.h>
 #include <torch/csrc/jit/fuser/executor.h>
 #include <torch/csrc/jit/fuser/fallback.h>
+#include <torch/csrc/jit/fuser/kernel_cache.h>
 
 #include <stdexcept>
 
@@ -63,6 +64,33 @@ std::vector<at::Tensor> debugLaunchGraph(
   return fmap(stack, [](const IValue& iv) { return iv.toTensor(); });
 }
 
+std::string debugGetFusedKernelCode(
+    Graph& graph,
+    at::ArrayRef<at::Tensor> inputs) {
+  // Creates a fusion group node
+  auto wrapper_graph = std::make_shared<Graph>();
+  Node* fusion_group =
+      wrapper_graph->insertNode(wrapper_graph->createFusionGroup());
+  fusion_group->g_(attr::Subgraph, graph.copy());
+  for (size_t i = 0; i < graph.inputs().size(); ++i) {
+    fusion_group->addInput(wrapper_graph->addInput());
+  }
+  for (size_t i = 0; i < graph.outputs().size(); ++i) {
+    wrapper_graph->registerOutput(fusion_group->addOutput());
+  }
+
+  // Creates the stack, registers and runs the fusion
+  Stack stack = fmap<IValue>(inputs);
+  const auto key = fuser::registerFusion(fusion_group);
+
+  std::string code;
+  if (!fuser::runFusion(key, stack, &code)) {
+    throw std::runtime_error("Could not run fusion for graph");
+  }
+
+  return code;
+}
+
 size_t nCompiledKernels() {
   return fuser::nCompiledKernels();
 }
index 8988a24..d34dbc0 100644 (file)
@@ -39,6 +39,11 @@ TORCH_API std::vector<at::Tensor> debugLaunchGraph(
     Graph& graph,
     at::ArrayRef<at::Tensor> inputs);
 
+// Treats the given graph as a fusion group and returns the generated code.
+TORCH_API std::string debugGetFusedKernelCode(
+    Graph& graph,
+    at::ArrayRef<at::Tensor> inputs);
+
 TORCH_API size_t nCompiledKernels();
 
 } // namespace jit
index 922873f..731a5ad 100644 (file)
@@ -226,6 +226,11 @@ void initJITBindings(PyObject* module) {
              const std::string& unqualified_op_name) {
             auto stack = toStack(args);
             checkAliasAnnotation(g, std::move(stack), unqualified_op_name);
+          })
+      .def(
+          "_jit_fuser_get_fused_kernel_code",
+          [](Graph& g, std::vector<at::Tensor> inps) {
+            return debugGetFusedKernelCode(g, inps);
           });
 
   // NOLINTNEXTLINE(bugprone-unused-raii)