Register CPU/CUDA fuser dynamically (#15887)
authorZachary DeVito <zdevito@fb.com>
Fri, 11 Jan 2019 18:45:40 +0000 (10:45 -0800)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Fri, 11 Jan 2019 18:50:35 +0000 (10:50 -0800)
Summary:
This avoids a bunch of conditional compilation logic
Pull Request resolved: https://github.com/pytorch/pytorch/pull/15887

Reviewed By: eellison

Differential Revision: D13613239

Pulled By: zdevito

fbshipit-source-id: a18fc69676b3ef19b4469ab58d8714d1f6efccbb

33 files changed:
CONTRIBUTING.md
test/cpp/jit/tests.h
test/test_jit.py
tools/build_variables.py
torch/CMakeLists.txt
torch/csrc/jit/fuser/arg_spec.h
torch/csrc/jit/fuser/codegen.cpp
torch/csrc/jit/fuser/codegen.h
torch/csrc/jit/fuser/compiler.cpp
torch/csrc/jit/fuser/compiler.h
torch/csrc/jit/fuser/config.h.in [deleted file]
torch/csrc/jit/fuser/cpu/dynamic_library.h
torch/csrc/jit/fuser/cpu/dynamic_library_unix.cpp [new file with mode: 0644]
torch/csrc/jit/fuser/cpu/dynamic_library_win.cpp [new file with mode: 0644]
torch/csrc/jit/fuser/cpu/fused_kernel.cpp
torch/csrc/jit/fuser/cpu/fused_kernel.h
torch/csrc/jit/fuser/cpu/resource_strings.h
torch/csrc/jit/fuser/cpu/temp_file.h
torch/csrc/jit/fuser/cuda/fused_kernel.cpp
torch/csrc/jit/fuser/cuda/fused_kernel.h
torch/csrc/jit/fuser/cuda/resource_strings.h
torch/csrc/jit/fuser/executor.cpp
torch/csrc/jit/fuser/executor.h
torch/csrc/jit/fuser/fallback.h
torch/csrc/jit/fuser/fused_kernel.h
torch/csrc/jit/fuser/interface.cpp
torch/csrc/jit/fuser/kernel_cache.h
torch/csrc/jit/fuser/kernel_spec.h
torch/csrc/jit/fuser/partition_desc.h
torch/csrc/jit/fuser/tensor_desc.h
torch/csrc/jit/fuser/tensor_info.h
torch/csrc/jit/init.cpp
torch/csrc/jit/passes/graph_fuser.cpp

index ec73437..9946970 100644 (file)
@@ -434,6 +434,11 @@ static_assert(std::is_same(A*, decltype(A::singleton()))::value, "hmm");
   are too large. Splitting such files into separate files helps.
   (Example: `THTensorMath`, `THTensorMoreMath`, `THTensorEvenMoreMath`.)
 
+* MSVC's preprocessor (but not the standard compiler) has a bug
+  where it incorrectly tokenizes raw string literals, ending when it sees a `"`.
+  This causes preprocessor tokens inside the literal like an`#endif`  to be incorrectly
+  treated as preprocessor directives. See https://godbolt.org/z/eVTIJq as an example.
+
 ### Running Clang-Tidy
 
 [Clang-Tidy](https://clang.llvm.org/extra/clang-tidy/index.html) is a C++
index a0c71fd..4ae3eb4 100644 (file)
@@ -941,7 +941,9 @@ void testRegisterFusionCachesKernel(std::ostream& out = std::cout) {
     c.value()->setUniqueName(cname);
     d.value()->setUniqueName(dname);
     graph->registerOutput(d.value());
+    torch::jit::overrideCanFuseOnCPU(true);
     FuseGraph(graph);
+    torch::jit::overrideCanFuseOnCPU(false);
     return graph;
   };
 
index c6cbe30..8e7918b 100644 (file)
@@ -10711,6 +10711,7 @@ class TestFuser(JitTestCase):
 
     @unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
     @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
+    @skipIfRocm
     def test_small_constant_cuda(self):
         def fn_test_small_constant(x, y):
             return (1e-8 * x + 5e-9 * y) * 1e8
index 08a2218..b6511c3 100644 (file)
@@ -100,6 +100,14 @@ torch_sources_no_python_default = [
     "torch/csrc/jit/c10_ops/layer_norm.cpp",
     "torch/csrc/utils/tensor_flatten.cpp",
     "torch/csrc/utils/variadic.cpp",
+    "torch/csrc/jit/fuser/kernel_cache.cpp",
+    "torch/csrc/jit/fuser/compiler.cpp",
+    "torch/csrc/jit/fuser/executor.cpp",
+    "torch/csrc/jit/fuser/codegen.cpp",
+    "torch/csrc/jit/fuser/fallback.cpp",
+    "torch/csrc/jit/fuser/cpu/fused_kernel.cpp",
+    "torch/csrc/jit/fuser/cpu/dynamic_library_unix.cpp",
+    "torch/csrc/jit/fuser/interface.cpp",
 ]
 
 
@@ -129,13 +137,10 @@ def torch_vars():
 
     r["torch_sources_no_python"] = (
         torch_sources_no_python_default
-        + ["torch/csrc/cuda/comm.cpp", "torch/csrc/cuda/nccl.cpp"]
-        + native.glob(["torch/csrc/jit/fuser/**/*.cpp"])
+        + ["torch/csrc/cuda/comm.cpp", "torch/csrc/cuda/nccl.cpp", "torch/csrc/jit/fuser/cuda/fused_kernel.cpp"]
     )
 
-    r["torch_sources_no_python_cpu"] = torch_sources_no_python_default + native.glob(
-        ["torch/csrc/jit/fuser/**/*.cpp"], exclude=["torch/csrc/jit/fuser/cuda/*.cpp"]
-    )
+    r["torch_sources_no_python_cpu"] = torch_sources_no_python_default
 
     r["torch_csrc_flags"] = {
         "compiler_flags": [
@@ -172,8 +177,6 @@ def torch_vars():
             "-Icaffe2/torch/csrc",
             "-Icaffe2/torch/csrc/nn",
             "-Icaffe2/torch/lib",
-            "-DUSE_CPU_FUSER_FBCODE=1",
-            "-DUSE_CUDA_FUSER_FBCODE=1",
         ],
     }
 
@@ -185,7 +188,5 @@ def torch_vars():
         "-Icaffe2/torch/csrc",
         "-Icaffe2/torch/csrc/nn",
         "-Icaffe2/torch/lib",
-        "-DUSE_CPU_FUSER_FBCODE=1",
-        "-DUSE_CUDA_FUSER_FBCODE=0",
     ]
     return r
index beaaa72..309fcd4 100644 (file)
@@ -203,36 +203,29 @@ set(TORCH_SRCS
   ${TORCH_SRC_DIR}/csrc/jit/c10_ops/layer_norm.cpp
   ${TORCH_SRC_DIR}/csrc/utils/tensor_flatten.cpp
   ${TORCH_SRC_DIR}/csrc/utils/variadic.cpp
+  ${TORCH_SRC_DIR}/csrc/jit/fuser/kernel_cache.cpp
+  ${TORCH_SRC_DIR}/csrc/jit/fuser/compiler.cpp
+  ${TORCH_SRC_DIR}/csrc/jit/fuser/executor.cpp
+  ${TORCH_SRC_DIR}/csrc/jit/fuser/codegen.cpp
+  ${TORCH_SRC_DIR}/csrc/jit/fuser/fallback.cpp
   ${TORCH_ROOT}/test/cpp/jit/no-gtest.cpp
   )
 
-SET(USE_CPU_FUSER 0)
-if (NOT WIN32)
-  SET(USE_CPU_FUSER 1)
-
+if (WIN32)
   list(APPEND TORCH_SRCS
-    ${TORCH_SRC_DIR}/csrc/jit/fuser/kernel_cache.cpp
-    ${TORCH_SRC_DIR}/csrc/jit/fuser/compiler.cpp
-    ${TORCH_SRC_DIR}/csrc/jit/fuser/executor.cpp
-    ${TORCH_SRC_DIR}/csrc/jit/fuser/codegen.cpp
-    ${TORCH_SRC_DIR}/csrc/jit/fuser/fallback.cpp
-    ${TORCH_SRC_DIR}/csrc/jit/fuser/cpu/fused_kernel.cpp
+    ${TORCH_SRC_DIR}/csrc/jit/fuser/cpu/dynamic_library_win.cpp
   )
-endif()
-
-SET(USE_CUDA_FUSER 0)
-if (USE_CUDA AND NOT USE_ROCM AND NOT WIN32)
-  SET(USE_CUDA_FUSER 1)
-
+else ()
   list(APPEND TORCH_SRCS
-    ${TORCH_SRC_DIR}/csrc/jit/fuser/cuda/fused_kernel.cpp
+    ${TORCH_SRC_DIR}/csrc/jit/fuser/cpu/dynamic_library_unix.cpp
+    ${TORCH_SRC_DIR}/csrc/jit/fuser/cpu/fused_kernel.cpp
   )
-
-endif()
-
-CONFIGURE_FILE(
-    ${TORCH_SRC_DIR}/csrc/jit/fuser/config.h.in
-    ${CMAKE_CURRENT_SOURCE_DIR}/csrc/jit/fuser/config.h)
+  if (USE_CUDA AND NOT USE_ROCM)
+    list(APPEND TORCH_SRCS
+      ${TORCH_SRC_DIR}/csrc/jit/fuser/cuda/fused_kernel.cpp
+    )
+  endif()
+endif ()
 
 if (NOT NO_API)
   list(APPEND TORCH_SRCS
index d099395..c7c7e23 100644 (file)
@@ -1,7 +1,4 @@
 #pragma once
-#include <torch/csrc/jit/fuser/config.h>
-#if USE_CUDA_FUSER || USE_CPU_FUSER
-
 #include <ATen/ATen.h>
 #include <torch/csrc/WindowsTorchApiMacro.h>
 #include <torch/csrc/jit/fuser/tensor_desc.h>
@@ -60,5 +57,3 @@ struct TORCH_API ArgSpec {
 } // namespace fuser
 } // namespace jit
 } // namespace torch
-
-#endif // USE_CUDA_FUSER || USE_CPU_FUSER
index 907f2d1..b62589c 100644 (file)
@@ -4,18 +4,12 @@
 #include <torch/csrc/jit/assertions.h>
 #include <torch/csrc/jit/code_template.h>
 #include <torch/csrc/jit/fuser/compiler.h>
-#include <torch/csrc/jit/fuser/config.h>
 #include <torch/csrc/jit/fuser/interface.h>
 #include <torch/csrc/jit/fuser/tensor_info.h>
 #include <torch/csrc/jit/ir.h>
 
-#if USE_CUDA_FUSER
-#include <torch/csrc/jit/fuser/cuda/resource_strings.h>
-#endif
-
-#if USE_CPU_FUSER
 #include <torch/csrc/jit/fuser/cpu/resource_strings.h>
-#endif
+#include <torch/csrc/jit/fuser/cuda/resource_strings.h>
 
 #include <cmath>
 #include <cstdint>
@@ -314,7 +308,6 @@ std::string generateKernel(
   for (const auto& input : inputs) {
     emitFormal(input.first, input.second);
   }
-  
 
   // Writes output parameters
   for (const auto& output : outputs) {
@@ -387,9 +380,8 @@ std::string generateKernel(
     }
   }
 
-// Includes headers
-// Note: CUDA kernels support halfs and random generation, CPU kernels do not
-#if USE_CUDA_FUSER
+  // Includes headers
+  // Note: CUDA kernels support halfs and random generation, CPU kernels do not
   if (has_half_tensor) {
     env.s("HalfHeader", cuda::half_support_literal);
   } else {
@@ -405,7 +397,6 @@ std::string generateKernel(
     env.s("RandParam", "");
     env.s("RandInit", "");
   }
-#endif // USE_CUDA_FUSER
 
   // Insantiates the CUDA or CPU-specific templates
   env.s("tensorOffsets", tensorOffsets.str());
@@ -414,19 +405,11 @@ std::string generateKernel(
   env.v("argument_loads", argument_loads);
   std::string code_string;
   if (use_cuda) {
-#if USE_CUDA_FUSER
     env.s("type_declarations", cuda::type_declarations_template.format(env));
     code_string = cuda::cuda_compilation_unit_template.format(env);
-#else
-    throw std::runtime_error("CUDA Fusion requested but not supported.");
-#endif // USE_CUDA_FUSER
   } else {
-#if USE_CPU_FUSER
     env.s("type_declarations", cpu::type_declarations_template.format(env));
     code_string = cpu::cpu_compilation_unit_template.format(env);
-#else
-    throw std::runtime_error("CPU Fusion requested but not supported");
-#endif // USE_CPU_FUSER
   }
 
   if (debugFuser()) {
index 21ce507..1135cfc 100644 (file)
@@ -1,6 +1,4 @@
 #pragma once
-#include <torch/csrc/jit/fuser/config.h>
-#if USE_CUDA_FUSER || USE_CPU_FUSER
 
 #include <torch/csrc/WindowsTorchApiMacro.h>
 #include <torch/csrc/jit/fuser/arg_spec.h>
@@ -29,5 +27,3 @@ TORCH_API std::string generateKernel(
 } // namespace fuser
 } // namespace jit
 } // namespace torch
-
-#endif // USE_CUDA_FUSER || USE_CPU_FUSER
index 12cafc0..12e0bce 100644 (file)
 #include <torch/csrc/jit/type.h>
 #include "torch/csrc/jit/fuser/interface.h"
 
-#if USE_CUDA_FUSER
-#include <torch/csrc/jit/fuser/cuda/fused_kernel.h>
-#endif // USE_CUDA_FUSER
-
-#if USE_CPU_FUSER
-#include <torch/csrc/jit/fuser/cpu/fused_kernel.h>
-#endif // USE_CUDA_FUSER
-
 #include <atomic>
 #include <iostream>
 #include <memory>
@@ -35,6 +27,32 @@ namespace torch {
 namespace jit {
 namespace fuser {
 
+std::mutex fusion_backends_lock_;
+static std::unordered_map<at::Device::Type, FusedKernelConstructor>&
+getFusionBackends() {
+  static std::unordered_map<at::Device::Type, FusedKernelConstructor>
+      fusion_backends;
+  return fusion_backends;
+}
+
+void registerFusionBackend(
+    at::Device::Type backend_type,
+    FusedKernelConstructor ctor) {
+  std::lock_guard<std::mutex> guard(fusion_backends_lock_);
+  getFusionBackends()[backend_type] = std::move(ctor);
+}
+
+bool hasFusionBackend(at::Device::Type backend_type) {
+  std::lock_guard<std::mutex> guard(fusion_backends_lock_);
+  return getFusionBackends().count(backend_type);
+}
+
+const FusedKernelConstructor& getConstructor(at::Device::Type backend_type) {
+  std::lock_guard<std::mutex> guard(fusion_backends_lock_);
+  return getFusionBackends().at(backend_type);
+}
+
+
 // Counter for number of kernels compiled, used for debugging and
 // creating arbitrary kernel names.
 static std::atomic<size_t> next_kernel_id{0};
@@ -232,38 +250,19 @@ std::shared_ptr<FusedKernel> compileKernel(
 
   const std::string name = "kernel_" + std::to_string(next_kernel_id++);
   const bool use_cuda = device.is_cuda();
-  std::string code = generateKernel(name, *graph, flat_inputs, flat_outputs, use_cuda);
-  std::shared_ptr<FusedKernel> fused_kernel;
-  if (use_cuda) {
-#if USE_CUDA_FUSER
-    fused_kernel = std::make_shared<cuda::FusedKernelCUDA>(
-        device.index(),
-        name,
-        code,
-        input_desc,
-        output_desc,
-        chunk_desc,
-        concat_desc,
-        spec.hasRandom());
-#else
-    throw std::runtime_error("CUDA Fusion is not supported on this build.");
-#endif // USE_CUDA_FUSER
-  } else {
-#if USE_CPU_FUSER
-    fused_kernel = std::make_shared<cpu::FusedKernelCPU>(
-        name,
-        code,
-        input_desc,
-        output_desc,
-        chunk_desc,
-        concat_desc,
-        spec.hasRandom());
-#else
-    throw std::runtime_error("CPU Fusion is not supported on this build.");
-#endif // USE_CPU_FUSER
-  }
-
-  return fused_kernel;
+  std::string code =
+      generateKernel(name, *graph, flat_inputs, flat_outputs, use_cuda);
+  const FusedKernelConstructor& kernel_ctor =
+      getConstructor(use_cuda ? at::DeviceType::CUDA : at::DeviceType::CPU);
+  return kernel_ctor(
+      device.index(),
+      name,
+      code,
+      input_desc,
+      output_desc,
+      chunk_desc,
+      concat_desc,
+      spec.hasRandom());
 }
 
 } // namespace fuser
index 38e1ef1..a16f0d9 100644 (file)
@@ -1,10 +1,7 @@
 #pragma once
-#include <torch/csrc/jit/fuser/config.h>
-#if USE_CUDA_FUSER || USE_CPU_FUSER
 
 #include <torch/csrc/WindowsTorchApiMacro.h>
 #include <torch/csrc/jit/fuser/arg_spec.h>
-#include <torch/csrc/jit/fuser/config.h>
 #include <torch/csrc/jit/fuser/fused_kernel.h>
 #include <torch/csrc/jit/fuser/interface.h>
 #include <torch/csrc/jit/fuser/kernel_spec.h>
@@ -36,8 +33,28 @@ TORCH_API size_t nCompiledKernels();
 
 TORCH_API int debugFuser();
 
+using FusedKernelConstructor = std::function<std::shared_ptr<FusedKernel>(
+    int16_t device,
+    std::string name,
+    std::string code,
+    std::vector<TensorDesc> input_desc,
+    std::vector<TensorDesc> output_desc,
+    std::vector<PartitionDesc> chunk_desc,
+    std::vector<PartitionDesc> concat_desc,
+    bool has_random)>;
+
+TORCH_API void registerFusionBackend(
+    at::Device::Type backend_type,
+    FusedKernelConstructor ctor);
+TORCH_API bool hasFusionBackend(at::Device::Type backend_type);
+struct TORCH_API RegisterFusionBackend {
+  RegisterFusionBackend(
+      at::Device::Type backend_type,
+      FusedKernelConstructor ctor) {
+    registerFusionBackend(backend_type, std::move(ctor));
+  }
+};
+
 } // namespace fuser
 } // namespace jit
 } // namespace torch
-
-#endif // USE_CUDA_FUSER || USE_CPU_FUSER
diff --git a/torch/csrc/jit/fuser/config.h.in b/torch/csrc/jit/fuser/config.h.in
deleted file mode 100644 (file)
index 02306ed..0000000
+++ /dev/null
@@ -1,6 +0,0 @@
-#pragma once
-
-// clang-format off
-#define USE_CPU_FUSER @USE_CPU_FUSER@
-#define USE_CUDA_FUSER @USE_CUDA_FUSER@
-// clang-format on
index 25f8e39..0d380ee 100644 (file)
@@ -1,43 +1,21 @@
 #pragma once
-#include <torch/csrc/jit/fuser/config.h>
-#if USE_CPU_FUSER
 
 #include <torch/csrc/jit/assertions.h>
 #include <torch/csrc/utils/disallow_copy.h>
 
-#include <dlfcn.h>
-
 namespace torch {
 namespace jit {
 namespace fuser {
 namespace cpu {
 
-static void* checkDL(void* x) {
-  if (!x) {
-    AT_ERROR("error in dlopen or dlsym: ", dlerror());
-  }
-
-  return x;
-}
-
 struct DynamicLibrary {
   TH_DISALLOW_COPY_AND_ASSIGN(DynamicLibrary);
 
-  DynamicLibrary(const char* name) {
-    // NOLINTNEXTLINE(hicpp-signed-bitwise)
-    handle = checkDL(dlopen(name, RTLD_LOCAL | RTLD_NOW));
-  }
+  DynamicLibrary(const char* name);
 
-  void* sym(const char* name) {
-    JIT_ASSERT(handle);
-    return checkDL(dlsym(handle, name));
-  }
+  void* sym(const char* name);
 
-  ~DynamicLibrary() {
-    if (!handle)
-      return;
-    dlclose(handle);
-  }
+  ~DynamicLibrary();
 
  private:
   void* handle = nullptr;
@@ -47,5 +25,3 @@ struct DynamicLibrary {
 } // namespace fuser
 } // namespace jit
 } // namespace torch
-
-#endif // USE_CPU_FUSER
diff --git a/torch/csrc/jit/fuser/cpu/dynamic_library_unix.cpp b/torch/csrc/jit/fuser/cpu/dynamic_library_unix.cpp
new file mode 100644 (file)
index 0000000..1289a70
--- /dev/null
@@ -0,0 +1,39 @@
+
+#include <torch/csrc/jit/assertions.h>
+#include <torch/csrc/jit/fuser/cpu/dynamic_library.h>
+#include <torch/csrc/utils/disallow_copy.h>
+
+#include <dlfcn.h>
+
+namespace torch {
+namespace jit {
+namespace fuser {
+namespace cpu {
+
+static void* checkDL(void* x) {
+  if (!x) {
+    AT_ERROR("error in dlopen or dlsym: ", dlerror());
+  }
+
+  return x;
+}
+DynamicLibrary::DynamicLibrary(const char* name) {
+  // NOLINTNEXTLINE(hicpp-signed-bitwise)
+  handle = checkDL(dlopen(name, RTLD_LOCAL | RTLD_NOW));
+}
+
+void* DynamicLibrary::sym(const char* name) {
+  JIT_ASSERT(handle);
+  return checkDL(dlsym(handle, name));
+}
+
+DynamicLibrary::~DynamicLibrary() {
+  if (!handle)
+    return;
+  dlclose(handle);
+}
+
+} // namespace cpu
+} // namespace fuser
+} // namespace jit
+} // namespace torch
diff --git a/torch/csrc/jit/fuser/cpu/dynamic_library_win.cpp b/torch/csrc/jit/fuser/cpu/dynamic_library_win.cpp
new file mode 100644 (file)
index 0000000..4393a61
--- /dev/null
@@ -0,0 +1,24 @@
+#include <torch/csrc/jit/assertions.h>
+#include <torch/csrc/jit/fuser/cpu/dynamic_library.h>
+#include <torch/csrc/utils/disallow_copy.h>
+
+namespace torch {
+namespace jit {
+namespace fuser {
+namespace cpu {
+
+DynamicLibrary::DynamicLibrary(const char* name) {
+  // NOLINTNEXTLINE(hicpp-signed-bitwise)
+  AT_ERROR("NYI: DynamicLibrary on Windows");
+}
+
+void* DynamicLibrary::sym(const char* name) {
+  AT_ERROR("NYI: DynamicLibrary on Windows");
+}
+
+DynamicLibrary::~DynamicLibrary() {}
+
+} // namespace cpu
+} // namespace fuser
+} // namespace jit
+} // namespace torch
index dbe954a..6cb070b 100644 (file)
@@ -1,5 +1,4 @@
 #include <torch/csrc/jit/fuser/cpu/fused_kernel.h>
-
 #include <torch/csrc/jit/assertions.h>
 #include <torch/csrc/jit/code_template.h>
 #include <torch/csrc/jit/fuser/compiler.h>
@@ -131,6 +130,26 @@ FusedKernelCPU::FusedKernelCPU(
 #pragma GCC diagnostic pop
 }
 
+static std::shared_ptr<FusedKernel> createFusionKernel(
+    int16_t device,
+    std::string name,
+    std::string code,
+    std::vector<TensorDesc> input_desc,
+    std::vector<TensorDesc> output_desc,
+    std::vector<PartitionDesc> chunk_desc,
+    std::vector<PartitionDesc> concat_desc,
+    bool has_random) {
+  return std::make_shared<FusedKernelCPU>(
+      std::move(name),
+      std::move(code),
+      std::move(input_desc),
+      std::move(output_desc),
+      std::move(chunk_desc),
+      std::move(concat_desc),
+      has_random);
+}
+
+RegisterFusionBackend reg(at::DeviceType::CPU, createFusionKernel);
 } // namespace cpu
 } // namespace fuser
 } // namespace jit
index 272c837..3116f3c 100644 (file)
@@ -1,6 +1,4 @@
 #pragma once
-#include <torch/csrc/jit/fuser/config.h>
-#if USE_CPU_FUSER
 
 #include <ATen/ATen.h>
 #include <torch/csrc/WindowsTorchApiMacro.h>
@@ -46,5 +44,3 @@ struct TORCH_API FusedKernelCPU : public ::torch::jit::fuser::FusedKernel {
 } // namespace fuser
 } // namespace jit
 } // namespace torch
-
-#endif // USE_CPU_FUSER
index 8d9e13a..1314459 100644 (file)
@@ -1,6 +1,4 @@
 #pragma once
-#include <torch/csrc/jit/fuser/config.h>
-#if USE_CPU_FUSER
 
 #include <torch/csrc/jit/code_template.h>
 
@@ -67,5 +65,3 @@ void ${kernelName}(IndexType totalElements, void ** args) {
 } // namespace fuser
 } // namespace jit
 } // namespace torch
-
-#endif // USE_CPU_FUSER
index b889974..fd782a1 100644 (file)
@@ -1,6 +1,4 @@
 #pragma once
-#include <torch/csrc/jit/fuser/config.h>
-#if USE_CPU_FUSER
 
 #include <ATen/ATen.h>
 #include <torch/csrc/WindowsTorchApiMacro.h>
@@ -68,5 +66,3 @@ struct TempFile {
 } // namespace fuser
 } // namespace jit
 } // namespace torch
-
-#endif // USE_CPU_FUSER
index 3f6ba60..bd287d7 100644 (file)
@@ -1,4 +1,5 @@
 #include <torch/csrc/jit/fuser/cuda/fused_kernel.h>
+#include <torch/csrc/jit/fuser/compiler.h>
 
 #include <ATen/cuda/CUDAContext.h>
 #include <THC/THC.h>
@@ -198,6 +199,28 @@ void FusedKernelCUDA::launch_raw(
   at::cuda::set_device(prior_device);
 }
 
+static std::shared_ptr<FusedKernel> createFusionKernel(
+    int16_t device,
+    std::string name,
+    std::string code,
+    std::vector<TensorDesc> input_desc,
+    std::vector<TensorDesc> output_desc,
+    std::vector<PartitionDesc> chunk_desc,
+    std::vector<PartitionDesc> concat_desc,
+    bool has_random) {
+  return std::make_shared<FusedKernelCUDA>(
+      device,
+      std::move(name),
+      std::move(code),
+      std::move(input_desc),
+      std::move(output_desc),
+      std::move(chunk_desc),
+      std::move(concat_desc),
+      has_random);
+}
+
+RegisterFusionBackend reg(at::DeviceType::CUDA, createFusionKernel);
+
 } // namespace cuda
 } // namespace fuser
 } // namespace jit
index 233c001..c14a5ff 100644 (file)
@@ -1,6 +1,4 @@
 #pragma once
-#include <torch/csrc/jit/fuser/config.h>
-#if USE_CUDA_FUSER
 
 #include <ATen/ATen.h>
 #include <torch/csrc/WindowsTorchApiMacro.h>
@@ -60,5 +58,3 @@ struct TORCH_API FusedKernelCUDA : public ::torch::jit::fuser::FusedKernel {
 } // namespace fuser
 } // namespace jit
 } // namespace torch
-
-#endif // USE_CUDA_FUSER
index ce56b81..27f4894 100644 (file)
@@ -1,6 +1,4 @@
 #pragma once
-#include <torch/csrc/jit/fuser/config.h>
-#if USE_CUDA_FUSER
 
 #include <torch/csrc/WindowsTorchApiMacro.h>
 #include <torch/csrc/jit/code_template.h>
@@ -191,6 +189,15 @@ constexpr auto half_support_literal = R"(
       asm("{  cvt.f32.f16 %0, %1;}\n" : "=f"(val) : "h"(__HALF_TO_CUS(h)));
       return val;
     }
+)"
+// MSVC's preprocesor (but not the standard compiler) has a bug
+// where it incorrectly tokenizes raw string literals, ending when it sees a "
+// this causes the #endif in this string literal to be treated as a preprocessor
+// token which, in turn, cause sccache on windows CI to fail.
+// See https://godbolt.org/z/eVTIJq as an example.
+// This workaround uses string-pasting to separate the " and the #endif into different
+// strings
+R"(
   #endif /* defined(__CUDACC__) */
 #endif /* defined(__cplusplus) */
 #undef __HALF_TO_US
@@ -203,5 +210,3 @@ typedef __half half;
 } // namespace fuser
 } // namespace jit
 } // namespace torch
-
-#endif // USE_CUDA_FUSER
index b8f43bc..446bd6f 100644 (file)
@@ -4,7 +4,6 @@
 #include <ATen/ExpandUtils.h>
 #include <c10/util/Optional.h>
 #include <torch/csrc/jit/fuser/compiler.h>
-#include <torch/csrc/jit/fuser/config.h>
 #include <torch/csrc/jit/fuser/interface.h>
 #include <torch/csrc/jit/fuser/kernel_cache.h>
 #include <torch/csrc/jit/fuser/kernel_spec.h>
index 9af2cd9..852e0bd 100644 (file)
@@ -1,6 +1,4 @@
 #pragma once
-#include <torch/csrc/jit/fuser/config.h>
-#if USE_CUDA_FUSER || USE_CPU_FUSER
 
 #include <torch/csrc/WindowsTorchApiMacro.h>
 #include <torch/csrc/jit/stack.h>
@@ -18,5 +16,3 @@ TORCH_API bool runFusion(const int64_t key, Stack& stack);
 } // namespace fuser
 } // namespace jit
 } // namespace torch
-
-#endif // USE_CUDA_FUSER || USE_CPU_FUSER
index ab55218..8eeccce 100644 (file)
@@ -1,6 +1,4 @@
 #pragma once
-#include <torch/csrc/jit/fuser/config.h>
-#if USE_CUDA_FUSER || USE_CPU_FUSER
 
 #include <torch/csrc/jit/stack.h>
 
@@ -15,5 +13,3 @@ void runFallback(int64_t key, Stack& stack);
 } // namespace fuser
 } // namespace jit
 } // namespace torch
-
-#endif // USE_CUDA_FUSER || USE_CPU_FUSER
index 39a590c..46be16a 100644 (file)
@@ -1,6 +1,4 @@
 #pragma once
-#include <torch/csrc/jit/fuser/config.h>
-#if USE_CUDA_FUSER || USE_CPU_FUSER
 
 #include <ATen/ATen.h>
 #include <torch/csrc/jit/fuser/partition_desc.h>
@@ -95,5 +93,3 @@ struct FusedKernel {
 } // namespace fuser
 } // namespace jit
 } // namespace torch
-
-#endif // USE_CUDA_FUSER || USE_CPU_FUSER
index 4e63c6f..9e2509a 100644 (file)
@@ -1,11 +1,8 @@
 #include <torch/csrc/jit/fuser/interface.h>
 
-#include <torch/csrc/jit/fuser/config.h>
-#if USE_CUDA_FUSER || USE_CPU_FUSER
 #include <torch/csrc/jit/fuser/compiler.h>
 #include <torch/csrc/jit/fuser/executor.h>
 #include <torch/csrc/jit/fuser/fallback.h>
-#endif // USE_CUDA_FUSER || USE_CPU_FUSER
 
 #include <stdexcept>
 
@@ -20,37 +17,22 @@ bool cpu_fuser_enabled = false;
 } // namespace detail
 
 int64_t registerFusion(const Node* fusion_group) {
-#if USE_CUDA_FUSER || USE_CPU_FUSER
   return fuser::registerFusion(fusion_group);
-#else
-  throw std::runtime_error("Fusion not supported for this build.");
-#endif // USE_CUDA_FUSER || USE_CPU_FUSER
 }
 
 void runFusion(const int64_t key, Stack& stack) {
-#if USE_CUDA_FUSER || USE_CPU_FUSER
   const auto result = fuser::runFusion(key, stack);
   if (!result)
     fuser::runFallback(key, stack);
-#else
-  throw std::runtime_error("Fusion not supported for this build.");
-#endif // USE_CUDA_FUSER || USE_CPU_FUSER
 }
 
 bool canFuseOnCPU() {
-#if USE_CPU_FUSER
-  return detail::cpu_fuser_enabled;
-#endif // USE_CPU_FUSER
-
-  return false;
+  return fuser::hasFusionBackend(at::DeviceType::CPU) &&
+      detail::cpu_fuser_enabled;
 }
 
 bool canFuseOnGPU() {
-#if USE_CUDA_FUSER
-  return true;
-#endif // USE_CUDA_FUSER
-
-  return false;
+  return fuser::hasFusionBackend(at::DeviceType::CUDA);
 }
 
 void overrideCanFuseOnCPU(bool value) {
@@ -62,7 +44,6 @@ void overrideCanFuseOnCPU(bool value) {
 std::vector<at::Tensor> debugLaunchGraph(
     Graph& graph,
     at::ArrayRef<at::Tensor> inputs) {
-#if USE_CUDA_FUSER || USE_CPU_FUSER
   // Creates a fusion group node
   auto wrapper_graph = std::make_shared<Graph>();
   Node* fusion_group =
@@ -80,17 +61,10 @@ std::vector<at::Tensor> debugLaunchGraph(
   const auto key = fuser::registerFusion(fusion_group);
   fuser::runFusion(key, stack);
   return fmap(stack, [](const IValue& iv) { return iv.toTensor(); });
-#else
-  throw std::runtime_error("Fusion not supported for this build.");
-#endif // USE_CUDA_FUSER || USE_CPU_FUSER
 }
 
 size_t nCompiledKernels() {
-#if USE_CUDA_FUSER || USE_CPU_FUSER
   return fuser::nCompiledKernels();
-#else
-  return 0;
-#endif // USE_CUDA_FUSER || USE_CPU_FUSER
 }
 
 } // namespace jit
index 792591c..56f81ad 100644 (file)
@@ -1,6 +1,4 @@
 #pragma once
-#include <torch/csrc/jit/fuser/config.h>
-#if USE_CUDA_FUSER || USE_CPU_FUSER
 
 #include <c10/util/Optional.h>
 #include <torch/csrc/WindowsTorchApiMacro.h>
@@ -36,5 +34,3 @@ TORCH_API int64_t debugNumCachedKernelSpecs();
 } // namespace fuser
 } // namespace jit
 } // namespace torch
-
-#endif // USE_CUDA_FUSER || USE_CPU_FUSER
index bb5a873..3e2d995 100644 (file)
@@ -1,6 +1,4 @@
 #pragma once
-#include <torch/csrc/jit/fuser/config.h>
-#if USE_CUDA_FUSER || USE_CPU_FUSER
 
 #include <ATen/ATen.h>
 #include <c10/util/Optional.h>
@@ -58,23 +56,20 @@ struct TORCH_API KernelSpec {
   // Note: assumes the spec is a single block
   // Note: This is the appropriate place to generalize if you want to add other
   //  passes to upfront compilation that walk the graph.
-  KernelSpec(
-    const int64_t _key, 
-    const std::shared_ptr<Graph>& _graph)
-  : key_{_key},
-    graph_{_graph},
-    code_{_graph},
-    nInputs_{_graph->inputs().size()},
-    inputBroadcastGroups_{},
-    inputChunks_{},
-    has_random_{false},
-    kernels_{} {
-    
+  KernelSpec(const int64_t _key, const std::shared_ptr<Graph>& _graph)
+      : key_{_key},
+        graph_{_graph},
+        code_{_graph},
+        nInputs_{_graph->inputs().size()},
+        inputBroadcastGroups_{},
+        inputChunks_{},
+        has_random_{false},
+        kernels_{} {
     for (const auto& n : graph_->nodes()) {
       if (n->kind() == aten::rand_like) {
         has_random_ = true;
         break;
-      } 
+      }
     }
   }
 
@@ -142,5 +137,3 @@ struct TORCH_API KernelSpec {
 } // namespace fuser
 } // namespace jit
 } // namespace torch
-
-#endif // USE_CPU_FUSER || USE_CUDA_FUSER
index 16408d5..5c57d3e 100644 (file)
@@ -1,6 +1,4 @@
 #pragma once
-#include <torch/csrc/jit/fuser/config.h>
-#if USE_CUDA_FUSER || USE_CPU_FUSER
 
 #include <torch/csrc/WindowsTorchApiMacro.h>
 #include <torch/csrc/jit/assertions.h>
@@ -62,5 +60,3 @@ struct TORCH_API PartitionDesc {
 } // namespace fuser
 } // namespace jit
 } // namespace torch
-
-#endif // USE_CUDA_FUSER || USE_CPU_FUSER
index fb02867..908c189 100644 (file)
@@ -1,6 +1,4 @@
 #pragma once
-#include <torch/csrc/jit/fuser/config.h>
-#if USE_CUDA_FUSER || USE_CPU_FUSER
 
 #include <ATen/ATen.h>
 #include <torch/csrc/WindowsTorchApiMacro.h>
@@ -99,5 +97,3 @@ inline std::ostream& operator<<(std::ostream& out, const TensorDesc& d) {
 } // namespace fuser
 } // namespace jit
 } // namespace torch
-
-#endif // USE_CUDA_FUSER || USE_CPU_FUSER
index 161cf3b..157629b 100644 (file)
@@ -1,7 +1,4 @@
 #pragma once
-#include <torch/csrc/jit/fuser/config.h>
-#if USE_CUDA_FUSER || USE_CPU_FUSER
-
 #include <torch/csrc/WindowsTorchApiMacro.h>
 
 #include <cstdint>
@@ -29,5 +26,3 @@ struct TORCH_API TensorInfo {
 } // namespace fuser
 } // namespace jit
 } // namespace torch
-
-#endif // USE_CUDA_FUSER || USE_CPU_FUSER
index f9b4e88..620ad9c 100644 (file)
@@ -92,15 +92,13 @@ void initJITBindings(PyObject* module) {
 
   py::register_exception<JITException>(m, "JITException");
 
-  py::class_<python::IODescriptor>(
+  py::class_<python::IODescriptor> iodescriptor(
       m, "IODescriptor"); // NOLINT(bugprone-unused-raii)
 
   m.def("_jit_init", loadPythonClasses)
-#if USE_CUDA_FUSER || USE_CPU_FUSER
       .def(
           "_jit_debug_fuser_num_cached_kernel_specs",
           torch::jit::fuser::debugNumCachedKernelSpecs)
-#endif
       .def("_jit_pass_onnx", ToONNX)
       .def("_jit_pass_lower_all_tuples", LowerAllTuples)
       .def("_jit_pass_onnx_peephole", PeepholeOptimizeONNX)
index 39ae5bc..fd7e3e1 100644 (file)
 #include <torch/csrc/jit/passes/utils/subgraph_utils.h>
 #include <torch/csrc/jit/script/compiler.h>
 #include <torch/csrc/jit/symbolic_variable.h>
-#include <unordered_map>
 
-#ifdef USE_CUDA
-#include <cuda.h> // for CUDA_VERSION
-#endif
+#include <unordered_map>
 
 namespace torch {
 namespace jit {
@@ -1207,19 +1204,16 @@ void PeepholeOptimizeShapeExpressions(Block* block) {
 } // anonymous namespace
 
 void FuseGraph(std::shared_ptr<Graph>& graph) {
-// NYI on Windows
-#ifndef _WIN32
-
-  GraphFuser(graph->block(), graph).run();
-  // After FuseGraph some common subexpressions may come back
-  EliminateCommonSubexpression(graph);
-  // We might have emitted a fair amount of useless shape propagating code, so
-  // remove it
-  EliminateDeadCode(graph);
-  // Improve the quality of shape propagation code that was left
-  PeepholeOptimizeShapeExpressions(graph->block());
-
-#endif
+  if (canFuseOnCPU() || canFuseOnGPU()) {
+    GraphFuser(graph->block(), graph).run();
+    // After FuseGraph some common subexpressions may come back
+    EliminateCommonSubexpression(graph);
+    // We might have emitted a fair amount of useless shape propagating code, so
+    // remove it
+    EliminateDeadCode(graph);
+    // Improve the quality of shape propagation code that was left
+    PeepholeOptimizeShapeExpressions(graph->block());
+  }
 }
 
 } // namespace jit