self.assertEqual(cu.test_integral_shape_inference(*inputs), outputs)
+ @unittest.skipIf(RUN_CUDA, 'This tests the CPU fuser')
+ @unittest.skipIf(IS_WINDOWS or IS_SANDCASTLE, "NYI: fuser support for Windows or Sandcastle")
+ @enable_cpu_fuser
+ def test_batchnorm_fuser_cpu(self):
+ # Builds a hand-written JIT IR graph of a batchnorm-style computation
+ # (sqrt -> add eps -> reciprocal -> sub/mul/mul/add -> relu), compiles it
+ # through the CPU fuser, and checks the generated C kernel source calls
+ # `sqrtf` — i.e. aten::sqrt was actually lowered into the fused kernel.
+ # NOTE(review): %22 (eps) and %23 (alpha=1) are scalar constants fed to
+ # aten::add/mul; only the five %-numbered Tensor graph inputs are bound
+ # from `inputs` below.
+ code = '''
+ graph(%3 : Tensor,
+ %7 : Tensor,
+ %12 : Float(*, *),
+ %13 : Tensor,
+ %25 : Tensor):
+ %23 : int = prim::Constant[value=1]()
+ %22 : float = prim::Constant[value=1e-05]()
+ %26 : Tensor = aten::sqrt(%25)
+ %24 : Tensor = aten::add(%26, %22, %23)
+ %20 : Tensor = aten::reciprocal(%24)
+ %norm_invstd : Tensor = aten::mul(%20, %23)
+ %15 : Tensor = aten::sub(%12, %13, %23)
+ %11 : Tensor = aten::mul(%15, %norm_invstd)
+ %8 : Tensor = aten::mul(%11, %7)
+ %5 : Tensor = aten::add(%8, %3, %23)
+ %1 : Float(*, *) = aten::relu(%5)
+ return (%1)
+ '''
+
+ graph = parse_ir(code)
+ # Five identical float tensors, one per graph input.
+ inputs = 5 * [torch.rand(26, 2048, dtype=torch.float)]
+ # Debug hook added alongside this test: returns the fused kernel's
+ # generated source as a string (see debugGetFusedKernelCode).
+ code = torch._C._jit_fuser_get_fused_kernel_code(graph, inputs)
+ FileCheck().check('sqrtf').run(code)
+
+
def test_fuser_multiple_blocks(self):
cu = torch.jit.CompilationUnit('''
def test_fuser_multiple_blocks(this, that, theother, meme):
fusion.launch_raw(numel, arguments);
}
-bool runFusion(const int64_t key, Stack& stack) {
+bool runFusion(const int64_t key, Stack& stack, std::string* code_out) {
// Short-circuits if fusion isn't enabled
if (!canFuseOnCPU() && !canFuseOnGPU())
return false;
maybe_kernel = spec.findKernel(arg_spec);
AT_ASSERT(maybe_kernel);
+ if (code_out) {
+ *code_out = maybe_kernel.value()->code();
+ }
+
// Launches fusion
std::vector<at::Tensor> raw_outputs;
launchFusion(*(*maybe_kernel), device, inputs, all_inputs, raw_outputs);
#pragma once
-#include <torch/csrc/WindowsTorchApiMacro.h>
#include <ATen/core/stack.h>
+#include <torch/csrc/WindowsTorchApiMacro.h>
+#include <torch/csrc/jit/fuser/fused_kernel.h>
+#include <torch/csrc/jit/fuser/kernel_spec.h>
#include <cstdint>
// Runs the fusion associated with the key (see registerFusion() in interface.h)
// on the inputs taken from the given Stack.
-TORCH_API bool runFusion(const int64_t key, Stack& stack);
+TORCH_API bool runFusion(
+ const int64_t key,
+ Stack& stack,
+ std::string* code_out = nullptr);
} // namespace fuser
} // namespace jit
#include <torch/csrc/jit/fuser/compiler.h>
#include <torch/csrc/jit/fuser/executor.h>
#include <torch/csrc/jit/fuser/fallback.h>
+#include <torch/csrc/jit/fuser/kernel_cache.h>
#include <stdexcept>
return fmap(stack, [](const IValue& iv) { return iv.toTensor(); });
}
+// Debug helper: wraps `graph` in a fusion-group node, registers it with the
+// fuser, runs it once on `inputs`, and returns the generated kernel source
+// as a string. Throws std::runtime_error if the fusion could not be run
+// (runFusion returns false, e.g. when fusion is disabled for this backend).
+std::string debugGetFusedKernelCode(
+ Graph& graph,
+ at::ArrayRef<at::Tensor> inputs) {
+ // Creates a fusion group node
+ auto wrapper_graph = std::make_shared<Graph>();
+ Node* fusion_group =
+ wrapper_graph->insertNode(wrapper_graph->createFusionGroup());
+ fusion_group->g_(attr::Subgraph, graph.copy());
+ // Mirror the subgraph's inputs/outputs on the wrapper graph so the
+ // fusion-group node's external interface matches the given graph.
+ for (size_t i = 0; i < graph.inputs().size(); ++i) {
+ fusion_group->addInput(wrapper_graph->addInput());
+ }
+ for (size_t i = 0; i < graph.outputs().size(); ++i) {
+ wrapper_graph->registerOutput(fusion_group->addOutput());
+ }
+
+ // Creates the stack, registers and runs the fusion
+ Stack stack = fmap<IValue>(inputs);
+ const auto key = fuser::registerFusion(fusion_group);
+
+ // runFusion fills `code` with the compiled kernel's source via its
+ // optional out-parameter (added in this same change).
+ std::string code;
+ if (!fuser::runFusion(key, stack, &code)) {
+ throw std::runtime_error("Could not run fusion for graph");
+ }
+
+ return code;
+}
+
+
// Forwards to fuser::nCompiledKernels(): the count of kernels the fuser
// has compiled so far.
size_t nCompiledKernels() {
return fuser::nCompiledKernels();
}
Graph& graph,
at::ArrayRef<at::Tensor> inputs);
+// Treats the given graph as a fusion group and returns the generated code.
+TORCH_API std::string debugGetFusedKernelCode(
+ Graph& graph,
+ at::ArrayRef<at::Tensor> inputs);
+
TORCH_API size_t nCompiledKernels();
} // namespace jit
const std::string& unqualified_op_name) {
auto stack = toStack(args);
checkAliasAnnotation(g, std::move(stack), unqualified_op_name);
+ })
+ .def(
+ "_jit_fuser_get_fused_kernel_code",
+ [](Graph& g, std::vector<at::Tensor> inps) {
+ return debugGetFusedKernelCode(g, inps);
});
// NOLINTNEXTLINE(bugprone-unused-raii)