[torch/deploy] add torch.distributed to build (#63918)
author Michael Suo <suo@fb.com>
Fri, 27 Aug 2021 03:54:54 +0000 (20:54 -0700)
committer Facebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Fri, 27 Aug 2021 03:58:44 +0000 (20:58 -0700)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/63918

Previously we were building with `USE_DISTRIBUTED` off because, for historical reasons, c10d was built as a separate library. Since then, lw has merged the c10d build into libtorch, so this is fairly easy to turn on.

Differential Revision: D30492442

**NOTE FOR REVIEWERS**: This PR has internal Facebook-specific changes or comments; please review them on [Phabricator](https://our.intern.facebook.com/intern/diff/D30492442/)!

Test Plan: added unit tests (`TorchpyTest.UsesDistributed`, `TorchDeployGPUTest.UsesDistributed`) that load a packaged module which imports `torch.distributed` inside a `torch::deploy` interpreter.
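
For reference, a condensed sketch of what the new tests exercise, assuming the `torch::deploy` embedding API (the `deploy.h` header path is an assumption) and the `uses_distributed` package generated by `generate_examples.py`:

```cpp
// Load the generated "uses_distributed" package and import it inside a
// torch::deploy interpreter. The packaged source does
// `import torch.distributed; assert torch.distributed.is_available()`,
// so the import only succeeds if distributed made it into the deploy build.
#include <torch/csrc/deploy/deploy.h>

#include <string>

void checkUsesDistributed(const std::string& packagePath) {
  torch::deploy::InterpreterManager manager(1);  // a single interpreter suffices
  torch::deploy::Package package = manager.load_package(packagePath);
  auto session = package.acquire_session();
  // Throws if torch.distributed is unavailable in the embedded interpreter.
  session.self.attr("import_module")({"uses_distributed"});
}
```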

Reviewed By: wconstab

Pulled By: suo

fbshipit-source-id: 843b8fcf349a72a7f6fcbd1fcc8961268690fb8c

tools/build_variables.bzl
torch/CMakeLists.txt
torch/csrc/deploy/example/generate_examples.py
torch/csrc/deploy/test_deploy.cpp
torch/csrc/deploy/test_deploy_gpu.cpp
torch/csrc/distributed/c10d/frontend.cpp
torch/csrc/distributed/c10d/frontend.hpp
torch/csrc/distributed/c10d/frontend_cuda.cpp [new file with mode: 0644]
torch/csrc/distributed/c10d/frontend_cuda.hpp [new file with mode: 0644]
torch/csrc/distributed/c10d/init.cpp
torch/csrc/distributed/rpc/request_callback_impl.cpp

diff --git a/tools/build_variables.bzl b/tools/build_variables.bzl
index 3f6225358ac978ce7026b153c23ce7fb1d022a44..650830b3143f0095d21ca908f068d23d07bd5ec8 100644 (file)
@@ -340,6 +340,7 @@ libtorch_core_sources = sorted(core_sources_common + core_sources_full + core_tr
 
 # These files are the only ones that are supported on Windows.
 libtorch_distributed_base_sources = [
+    "torch/csrc/distributed/c10d/frontend.cpp",
     "torch/csrc/distributed/c10d/comm.cpp",
     "torch/csrc/distributed/c10d/default_comm_hooks.cpp",
     "torch/csrc/distributed/c10d/FileStore.cpp",
@@ -351,6 +352,7 @@ libtorch_distributed_base_sources = [
     "torch/csrc/distributed/c10d/ProcessGroupGloo.cpp",
     "torch/csrc/distributed/c10d/ProcessGroupMPI.cpp",
     "torch/csrc/distributed/c10d/ProcessGroupWrapper.cpp",
+    "torch/csrc/distributed/c10d/quantization/quantization.cpp",
     "torch/csrc/distributed/c10d/reducer.cpp",
     "torch/csrc/distributed/c10d/sequence_num.cpp",
     "torch/csrc/distributed/c10d/Store.cpp",
@@ -548,6 +550,7 @@ libtorch_cuda_distributed_base_sources = [
 
 # These files are only supported on Linux (and others) but not on Windows.
 libtorch_cuda_distributed_extra_sources = [
+    "torch/csrc/distributed/c10d/frontend_cuda.cpp",
     "torch/csrc/distributed/c10d/NCCLUtils.cpp",
     "torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp",
     "torch/csrc/distributed/rpc/tensorpipe_cuda.cpp",
@@ -735,10 +738,8 @@ libtorch_python_core_sources = [
 ]
 
 libtorch_python_distributed_core_sources = [
-    "torch/csrc/distributed/c10d/frontend.cpp",
     "torch/csrc/distributed/c10d/init.cpp",
     "torch/csrc/distributed/c10d/python_comm_hook.cpp",
-    "torch/csrc/distributed/c10d/quantization/quantization.cpp",
 ]
 
 libtorch_python_distributed_sources = libtorch_python_distributed_core_sources + [
diff --git a/torch/CMakeLists.txt b/torch/CMakeLists.txt
index 761605fadcce8cb66d8e552ef62af39d1c4ac305..7c086855612ca9359c2f94ad548761d639067331 100644 (file)
@@ -214,11 +214,78 @@ add_custom_command(
     WORKING_DIRECTORY
     "${TORCH_ROOT}"
 )
+if(USE_DISTRIBUTED)
+    if(WIN32)
+      append_filelist("libtorch_python_distributed_core_sources" TORCH_PYTHON_SRCS)
+    else()
+      append_filelist("libtorch_python_distributed_sources" TORCH_PYTHON_SRCS)
+    endif()
+    # Disable certain warnings for GCC-9.X
+    if(CMAKE_COMPILER_IS_GNUCXX AND (CMAKE_CXX_COMPILER_VERSION VERSION_GREATER 9.0.0))
+      set_source_files_properties(${TORCH_SRC_DIR}/csrc/distributed/autograd/init.cpp PROPERTIES COMPILE_FLAGS "-Wno-cast-function-type")
+      set_source_files_properties(${TORCH_SRC_DIR}/csrc/distributed/rpc/testing/init.cpp PROPERTIES COMPILE_FLAGS "-Wno-cast-function-type")
+      set_source_files_properties(${TORCH_SRC_DIR}/csrc/distributed/c10d/init.cpp PROPERTIES COMPILE_FLAGS "-Wno-cast-function-type")
+    endif()
+    # NCCL is a private dependency of libtorch, but libtorch_python includes
+    # some private headers of libtorch, which in turn include NCCL. As a hacky
+    # alternative to making NCCL a public dependency of libtorch, we make it
+    # a private dependency of libtorch_python as well.
+    if(USE_NCCL)
+      list(APPEND TORCH_PYTHON_LINK_LIBRARIES __caffe2_nccl)
+    endif()
+    # Same for MPI.
+    if(USE_MPI)
+      list(APPEND TORCH_PYTHON_LINK_LIBRARIES ${MPI_CXX_LIBRARIES})
+    endif()
+    list(APPEND TORCH_PYTHON_COMPILE_DEFINITIONS USE_C10D)
+
+endif()
+
+if(USE_NCCL AND NOT WIN32)
+    list(APPEND TORCH_PYTHON_SRCS
+      ${TORCH_SRC_DIR}/csrc/cuda/python_nccl.cpp)
+    list(APPEND TORCH_PYTHON_COMPILE_DEFINITIONS USE_NCCL)
+endif()
+
 
 # WARNING- any TORCH_PYTHON_COMPILE_DEFINITIONS above this line
 #          affect both torch_python and DEPLOY interpreter.
 if(USE_DEPLOY)
   add_library(torch_python_obj OBJECT ${TORCH_PYTHON_SRCS})
+  if(USE_DISTRIBUTED)
+    # Set c10d-related compile definitions. For a "normal" build of
+    # libtorch_python, these are set on libtorch as PUBLIC so they are
+    # automatically propagated when libtorch_python links against libtorch. But
+    # since in the deploy build we are intentionally *not* linking against
+    # libtorch, we need to set them manually here.
+    list(APPEND TORCH_PYTHON_COMPILE_DEFINITIONS USE_DISTRIBUTED)
+    if(USE_GLOO AND USE_C10D_GLOO)
+      list(APPEND TORCH_PYTHON_COMPILE_DEFINITIONS USE_C10D_GLOO)
+    endif()
+    if(USE_NCCL AND USE_C10D_NCCL)
+        list(APPEND TORCH_PYTHON_COMPILE_DEFINITIONS USE_C10D_NCCL)
+        # Put nccl headers on the include path. We are specifically only setting
+        # include dirs here instead of linking against __caffe2_nccl wholesale
+        # to ensure we aren't accidentally replicating the nccl lib.
+        target_include_directories(torch_python_obj PRIVATE $<TARGET_PROPERTY:__caffe2_nccl,INTERFACE_INCLUDE_DIRECTORIES>)
+    endif()
+    if(USE_MPI AND USE_C10D_MPI)
+      list(APPEND TORCH_PYTHON_COMPILE_DEFINITIONS USE_C10D_MPI)
+    endif()
+
+    # Pass USE_RPC in order to reduce use of
+    # #if defined(USE_DISTRIBUTED) && !defined(_WIN32)
+    # need to be removed when RPC is supported
+    if(NOT WIN32)
+      target_compile_definitions(torch_cpu PUBLIC USE_RPC)
+    endif()
+    if(USE_TENSORPIPE)
+      list(APPEND TORCH_PYTHON_COMPILE_DEFINITIONS USE_TENSORPIPE)
+    endif()
+
+    # Set c10d-related include directories as well.
+    target_include_directories(torch_python_obj PRIVATE $<BUILD_INTERFACE:${TORCH_SRC_DIR}/csrc/distributed>)
+  endif()
   target_compile_definitions(torch_python_obj PRIVATE "-DTHP_BUILD_MAIN_LIB -DUSE_DEPLOY")
 
   target_compile_definitions(torch_python_obj PRIVATE ${TORCH_PYTHON_COMPILE_DEFINITIONS})
@@ -268,38 +335,6 @@ if(CMAKE_CXX_COMPILER_ID STREQUAL "GNU")
   set_source_files_properties(${TORCH_SRC_DIR}/csrc/utils/throughput_benchmark.cpp PROPERTIES COMPILE_FLAGS -Wno-attributes)
 endif()
 
-if(USE_DISTRIBUTED)
-    if(WIN32)
-      append_filelist("libtorch_python_distributed_core_sources" TORCH_PYTHON_SRCS)
-    else()
-      append_filelist("libtorch_python_distributed_sources" TORCH_PYTHON_SRCS)
-    endif()
-    # Disable certain warnings for GCC-9.X
-    if(CMAKE_COMPILER_IS_GNUCXX AND (CMAKE_CXX_COMPILER_VERSION VERSION_GREATER 9.0.0))
-      set_source_files_properties(${TORCH_SRC_DIR}/csrc/distributed/autograd/init.cpp PROPERTIES COMPILE_FLAGS "-Wno-cast-function-type")
-      set_source_files_properties(${TORCH_SRC_DIR}/csrc/distributed/rpc/testing/init.cpp PROPERTIES COMPILE_FLAGS "-Wno-cast-function-type")
-      set_source_files_properties(${TORCH_SRC_DIR}/csrc/distributed/c10d/init.cpp PROPERTIES COMPILE_FLAGS "-Wno-cast-function-type")
-    endif()
-    # NCCL is a private dependency of libtorch, but libtorch_python includes
-    # some private headers of libtorch, which in turn include NCCL. As a hacky
-    # alternative to making NCCL a public dependency of libtorch, we make it
-    # a private dependency of libtorch_python as well.
-    if(USE_NCCL)
-      list(APPEND TORCH_PYTHON_LINK_LIBRARIES __caffe2_nccl)
-    endif()
-    # Same for MPI.
-    if(USE_MPI)
-      list(APPEND TORCH_PYTHON_LINK_LIBRARIES ${MPI_CXX_LIBRARIES})
-    endif()
-    list(APPEND TORCH_PYTHON_COMPILE_DEFINITIONS USE_C10D)
-endif()
-
-if(USE_NCCL AND NOT WIN32)
-    list(APPEND TORCH_PYTHON_SRCS
-      ${TORCH_SRC_DIR}/csrc/cuda/python_nccl.cpp)
-    list(APPEND TORCH_PYTHON_COMPILE_DEFINITIONS USE_NCCL)
-endif()
-
 add_library(torch_python SHARED ${TORCH_PYTHON_SRCS})
 if(HAVE_SOVERSION)
   set_target_properties(torch_python PROPERTIES
diff --git a/torch/csrc/deploy/example/generate_examples.py b/torch/csrc/deploy/example/generate_examples.py
index 65f244373d954f4e1c71bb602d1fc484270a6a69..0f279d922157c44371e596a7a4bbca7d919df771 100644 (file)
@@ -79,3 +79,6 @@ if __name__ == "__main__":
         e.save_pickle("fn", "fn.pkl", load_library)
 
     generate_fx_example()
+
+    with PackageExporter(p / "uses_distributed") as e:
+        e.save_source_string("uses_distributed", "import torch.distributed; assert torch.distributed.is_available()")
diff --git a/torch/csrc/deploy/test_deploy.cpp b/torch/csrc/deploy/test_deploy.cpp
index a004db1e0d232e161df520d9a2592c3279080206..53456cacca2add1fc9c2262fa39b4157367c4653 100644 (file)
@@ -366,3 +366,15 @@ TEST(TorchpyTest, SharedLibraryLoad) {
   }
 }
 #endif
+
+TEST(TorchpyTest, UsesDistributed) {
+  const auto model_filename = path(
+      "USES_DISTRIBUTED",
+      "torch/csrc/deploy/example/generated/uses_distributed");
+  torch::deploy::InterpreterManager m(1);
+  torch::deploy::Package p = m.load_package(model_filename);
+  {
+    auto I = p.acquire_session();
+    I.self.attr("import_module")({"uses_distributed"});
+  }
+}
diff --git a/torch/csrc/deploy/test_deploy_gpu.cpp b/torch/csrc/deploy/test_deploy_gpu.cpp
index 8287d1683edca74eff44d8496e9ab6829232a6fa..4e990adcd9e899a36a195d132afc9358988b7a3b 100644 (file)
@@ -53,3 +53,15 @@ TEST(TorchDeployGPUTest, SimpleModel) {
 
   ASSERT_TRUE(ref_output.allclose(output, 1e-03, 1e-05));
 }
+
+TEST(TorchDeployGPUTest, UsesDistributed) {
+  const auto model_filename = path(
+      "USES_DISTRIBUTED",
+      "torch/csrc/deploy/example/generated/uses_distributed");
+  torch::deploy::InterpreterManager m(1);
+  torch::deploy::Package p = m.load_package(model_filename);
+  {
+    auto I = p.acquire_session();
+    I.self.attr("import_module")({"uses_distributed"});
+  }
+}
diff --git a/torch/csrc/distributed/c10d/frontend.cpp b/torch/csrc/distributed/c10d/frontend.cpp
index b65cba79884af3844264f39794c81c4c39a6e386..e5b59f28982f678341dbaa34ec66e9ec5f8c417d 100644 (file)
@@ -3,10 +3,11 @@
 #include <ATen/core/Tensor.h>
 #include <ATen/Functions.h>
 #include <c10/util/Exception.h>
-#include <c10d/PrefixStore.hpp>
 #include <c10d/FileStore.hpp>
 #include <c10d/TCPStore.hpp>
 #include <c10d/Utils.hpp>
+#include <torch/csrc/distributed/c10d/quantization/quantization.h>
+#include <torch/library.h>
 
 #include <chrono>
 #include <sstream>
 #include <c10d/ProcessGroupGloo.hpp>
 #endif
 
-#ifdef USE_C10D_NCCL
-#include <c10d/ProcessGroupNCCL.hpp>
-#endif
-
 #ifdef USE_C10D_MPI
 #include <c10d/ProcessGroupMPI.hpp>
 #endif
@@ -29,6 +26,20 @@ namespace c10d {
 
 namespace {
 
+// Constant initialization, so it is guaranteed to be initialized before
+// static initialization calls which may invoke registerNCCLProcessGroupProvider
+const NCCLProcessGroupProvider stubProvider;
+constexpr const NCCLProcessGroupProvider* defaultStubProviderAddr =
+    &stubProvider;
+inline const NCCLProcessGroupProvider*& getNCCLProcessGroupProviderAddress() {
+  static const NCCLProcessGroupProvider* stubs_ = defaultStubProviderAddr;
+  return stubs_;
+}
+
+const NCCLProcessGroupProvider* GetNCCLProcessGroupProvider() {
+  return getNCCLProcessGroupProviderAddress();
+}
+
 void maybePreprocessComplexTensor(at::Tensor& tensor) {
   if(!tensor.is_complex()) {
     return;
@@ -63,6 +74,11 @@ void assertReduceOpSupportsComplexTensor(ReduceOp op) {
 
 }  // namespace anonymous
 
+void registerNCCLProcessGroupProvider(NCCLProcessGroupProvider* provider) {
+  getNCCLProcessGroupProviderAddress() = provider;
+}
+
+
 std::string Backend::get(const std::string& backend_type) {
   return backend_type;
 }
@@ -207,17 +223,7 @@ c10::intrusive_ptr<ProcessGroup> DistributedC10d::newProcessGroupHelper(
           "Attempting to create GLOO-based process group while GLOO is either not enabled or built");
 #endif // USE_C10D_GLOO
     } else if (backend == "nccl") {
-#ifdef USE_C10D_NCCL
-      auto options = ProcessGroupNCCL::Options::create();
-
-      options->is_high_priority_stream = false;
-      options->timeout = timeout;
-      pg = c10::make_intrusive<ProcessGroupNCCL>(
-          prefix_store, rank, world_size, options);
-#else
-      AT_ERROR(
-          "Attempting to create NCCL-based process group while NCCL is either not enabled or built");
-#endif // USE_C10D_NCCL
+      pg = GetNCCLProcessGroupProvider()->get(prefix_store, rank, world_size, timeout);
     } else {
       // TODO: discuss to figure out how to extend this to third party backends?
       AT_ERROR("Unsupported backend type: ", backend);
@@ -1008,7 +1014,7 @@ void initCustomClassBindings() {
           .def(
               "broadcast",
               [](const c10::intrusive_ptr<::c10d::ProcessGroup>& self,
-                 std::vector<at::Tensor> data) { return self->broadcast(data);
+                  std::vector<at::Tensor> data) { return self->broadcast(data);
           })
           */
           .def(
@@ -1045,14 +1051,14 @@ void initCustomClassBindings() {
           .def(
               "allreduce",
               [](const c10::intrusive_ptr<::c10d::ProcessGroup>& self,
-                 at::Tensor& tensor,
-                 c10::intrusive_ptr<::c10d::ReduceOp> op) {
+                  at::Tensor& tensor,
+                  c10::intrusive_ptr<::c10d::ReduceOp> op) {
                       ::c10d::AllreduceOptions opts;
                       opts.reduceOp = *op;
                       std::vector<at::Tensor> tensors = {tensor};
                       return self->allreduce(tensors, opts);
-                 }
-           )
+                  }
+            )
           */
           // TODO: make AllreduceCoalescedOptions compatible with TorchBind to
           // provide the full API in python.
@@ -1098,8 +1104,8 @@ void initCustomClassBindings() {
           .def(
               "allgather",
               [](const c10::intrusive_ptr<::c10d::ProcessGroup>& self,
-                 std::vector<at::Tensor> output,
-                 at::Tensor input) {
+                  std::vector<at::Tensor> output,
+                  at::Tensor input) {
                 std::vector<std::vector<at::Tensor>> outputs = {
                     std::move(output)};
                 std::vector<at::Tensor> inputs = {std::move(input)};
@@ -1121,8 +1127,8 @@ void initCustomClassBindings() {
           .def(
               "gather",
               [](const c10::intrusive_ptr<::c10d::ProcessGroup>& self,
-                 std::vector<std::vector<at::Tensor>> output_tensors,
-                 std::vector<at::Tensor> input_tensors) {
+                  std::vector<std::vector<at::Tensor>> output_tensors,
+                  std::vector<at::Tensor> input_tensors) {
                 ::c10d::GatherOptions opts;
                 return self->gather(output_tensors, input_tensors, opts);
               })
@@ -1145,8 +1151,8 @@ void initCustomClassBindings() {
           .def(
               "scatter",
               [](const c10::intrusive_ptr<::c10d::ProcessGroup>& self,
-                 std::vector<at::Tensor> outputTensors,
-                 std::vector<std::vector<at::Tensor>> inputTensors) {
+                  std::vector<at::Tensor> outputTensors,
+                  std::vector<std::vector<at::Tensor>> inputTensors) {
                 ::c10d::ScatterOptions opts;
                 self->scatter(outputTensors, inputTensors, opts);
               })
@@ -1169,8 +1175,8 @@ void initCustomClassBindings() {
           // TODO: Enable this method when TorchBind supports
           ReduceScatterOptions. .def( "reduce_scatter",
               [](const c10::intrusive_ptr<::c10d::ProcessGroup>& self,
-                 std::vector<at::Tensor> outputTensors,
-                 std::vector<std::vector<at::Tensor>> inputTensors) {
+                  std::vector<at::Tensor> outputTensors,
+                  std::vector<std::vector<at::Tensor>> inputTensors) {
                 ::c10d::ReduceScatterOptions opts;
                 return self->reduce_scatter(outputTensors, inputTensors, opts);
               })
@@ -1241,95 +1247,6 @@ void initCustomClassBindings() {
                 return self->barrier(opts);
               });
 
-#ifdef USE_C10D_NCCL
-  // XXX: Ideally the Options of ProcessGroupNCCL should be
-  // bound using `def_readwrite` like in pybind11, but we
-  // didn't do that because: 1. no milisecond support yet
-  // 2. no def_readwrite or property support yet.
-  // TODO: make this binding the same as pybind11
-  static const auto ProcessGroupNCCLOptionsTorchBind =
-      torch::class_<::c10d::ProcessGroupNCCL::Options>(
-          "dist_c10d", "ProcessGroupNCCLOptions")
-          .def(torch::init([](int64_t timeout, bool isHighPriorityStream) {
-            auto opTimeout = std::chrono::milliseconds(timeout);
-            auto opts =
-                ::c10d::ProcessGroupNCCL::Options::create(isHighPriorityStream);
-            opts->timeout = opTimeout;
-            return opts;
-          }));
-
-  static const auto ProcessGroupNCCLTorchBind =
-      torch::class_<::c10d::ProcessGroupNCCL>("dist_c10d", "ProcessGroupNCCL")
-          .def_pickle(
-              [](const c10::intrusive_ptr<::c10d::ProcessGroupNCCL>& self) {
-                auto base_process_group =
-                    ::c10::static_intrusive_pointer_cast<::c10d::ProcessGroup>(self);
-                auto name =
-                    ::c10d::DistributedC10d::get()->getNameOfProcessGroup(self);
-                return std::vector<std::string>{name};
-              },
-              [](std::vector<std::string> state) {
-                TORCH_CHECK(
-                    state.size() == 1,
-                    "Expecting exactly 1 state when restoring ProcessGroupNCCL, got: ",
-                    state.size());
-                const auto& process_group_name = state.front();
-                auto base_process_group =
-                    ::c10d::DistributedC10d::get()->getProcessGroupByName(
-                        process_group_name);
-                TORCH_CHECK(
-                    base_process_group.defined(),
-                    "Needed process group not found, ",
-                    "please create a process group with name: ",
-                    process_group_name);
-                c10::intrusive_ptr<::c10d::ProcessGroupNCCL>
-                    process_group_nccl = ::c10::dynamic_intrusive_pointer_cast<
-                        ::c10d::ProcessGroupNCCL>(base_process_group);
-                TORCH_CHECK(
-                    process_group_nccl.defined(),
-                    "Process group ",
-                    process_group_name,
-                    " isn't configured for NCCL backend");
-                return process_group_nccl;
-              })
-          .def(torch::init(
-              [](const c10::intrusive_ptr<::c10d::Store>& store,
-                 int64_t rank,
-                 int64_t size,
-                 c10::intrusive_ptr<::c10d::ProcessGroupNCCL::Options> options,
-                 const std::string& name) {
-                auto pg = c10::make_intrusive<::c10d::ProcessGroupNCCL>(
-                    store, rank, size, options);
-                ::c10d::DistributedC10d::get()->registerProcessGroupName(
-                    pg, name);
-                return pg;
-              }))
-          .def(
-              "alltoall_base",
-              [](const c10::intrusive_ptr<::c10d::ProcessGroupNCCL>& self,
-                 at::Tensor output,
-                 at::Tensor input,
-                 std::vector<int64_t> outputSplitSizes,
-                 std::vector<int64_t> inputSplitSizes) {
-                return self->alltoall_base(
-                    output,
-                    input,
-                    outputSplitSizes,
-                    inputSplitSizes,
-                    ::c10d::AllToAllOptions());
-              })
-          .def(
-              "size",
-              [](const c10::intrusive_ptr<::c10d::ProcessGroupNCCL>& self) {
-                return (int64_t)self->getSize();
-              })
-          .def(
-              "rank",
-              [](const c10::intrusive_ptr<::c10d::ProcessGroupNCCL>& self) {
-                return (int64_t)self->getRank();
-              });
-#endif
-
   static const auto DistributedC10dFrontendTorchBind =
       torch::class_<::c10d::DistributedC10d>("dist_c10d", "frontend")
           .def(torch::init([]() { return ::c10d::DistributedC10d::get(); }))
@@ -1344,4 +1261,12 @@ void initCustomClassBindings() {
               &::c10d::DistributedC10d::getNameOfProcessGroup);
 }
 
+TORCH_LIBRARY(q, m) {
+    m.def("_Bfloat16QuantizedToFloat(Tensor input) -> Tensor");
+    m.def("_FloatToBfloat16Quantized(Tensor input) -> Tensor");
+}
+TORCH_LIBRARY_IMPL(q, CPU, m) {
+    m.impl("_Bfloat16QuantizedToFloat", ::torch::distributed::c10d::quantization::_bfloat16_to_float_cpu);
+    m.impl("_FloatToBfloat16Quantized", ::torch::distributed::c10d::quantization::_float_to_bfloat16_cpu);
+}
 } // namespace c10d
diff --git a/torch/csrc/distributed/c10d/frontend.hpp b/torch/csrc/distributed/c10d/frontend.hpp
index c90cc077b2823f77db49e08a5c88eb3dc235eaf9..b39d8b7a444bffbcee85faae094906b286d4517c 100644 (file)
@@ -2,6 +2,7 @@
 
 #include <ATen/ATen.h>
 #include <c10/util/Optional.h>
+#include <c10d/PrefixStore.hpp>
 #include <c10d/ProcessGroup.hpp>
 #include <c10d/Store.hpp>
 #include <c10d/Types.hpp>
@@ -259,7 +260,26 @@ class TORCH_PYTHON_API DistributedC10d : public torch::CustomClassHolder {
   int64_t group_count_;
 };
 
-// Must be called to initialize Torchbind bindings for c10d.
-void initCustomClassBindings();
+// This class exists as a way to allow us to split NCCL-specific code into a
+// different file. frontend_cuda.cpp will, if USE_C10D_NCCL is defined,
+// override this NCCLProcessGroupProvider with one that will actually do
+// something.
+struct TORCH_API NCCLProcessGroupProvider {
+  virtual c10::intrusive_ptr<ProcessGroup> get(
+      c10::intrusive_ptr<PrefixStore> /*prefix_store*/,
+      int64_t /*rank*/,
+      int64_t /*world_size*/,
+      std::chrono::milliseconds /*timeout*/) const {
+    AT_ERROR(
+        "Attempting to create NCCL-based process group while NCCL is either not enabled or built");
+  }
+
+  virtual ~NCCLProcessGroupProvider() = default;
+};
+
+TORCH_API void registerNCCLProcessGroupProvider(
+    NCCLProcessGroupProvider* provider);
+
+TORCH_API void initCustomClassBindings();
 
 } // namespace c10d
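
The frontend.hpp/frontend.cpp/frontend_cuda.cpp split above boils down to a small provider-registration indirection: the CPU-only build keeps a stub provider whose `get()` errors out, and the NCCL translation unit (compiled only under `USE_C10D_NCCL`) installs the real provider from a static initializer, so `newProcessGroupHelper` never needs an `#ifdef`. A standalone sketch of that pattern, using illustrative names that are not part of the PyTorch API:

```cpp
#include <iostream>
#include <string>

// Default provider: what a build without the NCCL backend sees.
struct ProcessGroupProvider {
  virtual std::string create() const {
    return "error: NCCL backend is not enabled or built";
  }
  virtual ~ProcessGroupProvider() = default;
};

// Single registration slot, defaulting to the stub above.
ProcessGroupProvider*& providerSlot() {
  static ProcessGroupProvider stub;
  static ProcessGroupProvider* slot = &stub;
  return slot;
}
void registerProvider(ProcessGroupProvider* p) { providerSlot() = p; }
ProcessGroupProvider* getProvider() { return providerSlot(); }

// This part would live in the NCCL-only translation unit: a static object
// swaps in the real provider at program start, so callers never need
// #ifdef USE_C10D_NCCL.
struct RealProvider : ProcessGroupProvider {
  std::string create() const override { return "ProcessGroupNCCL created"; }
};
struct RegisterRealProvider {
  RegisterRealProvider() {
    static RealProvider real;
    registerProvider(&real);
  }
} registerRealProvider;

int main() {
  std::cout << getProvider()->create() << "\n";  // real provider wins when linked in
}
```
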
diff --git a/torch/csrc/distributed/c10d/frontend_cuda.cpp b/torch/csrc/distributed/c10d/frontend_cuda.cpp
new file mode 100644 (file)
index 0000000..1b42f13
--- /dev/null
+++ b/torch/csrc/distributed/c10d/frontend_cuda.cpp
@@ -0,0 +1,136 @@
+#include <torch/csrc/distributed/c10d/frontend_cuda.hpp>
+
+#ifdef USE_C10D_NCCL
+
+#include <c10/util/Exception.h>
+#include <c10d/ProcessGroupNCCL.hpp>
+#include <torch/csrc/distributed/c10d/frontend.hpp>
+#include <torch/csrc/distributed/c10d/quantization/quantization_gpu.h>
+#include <torch/library.h>
+
+namespace c10d {
+
+void initCustomClassBindingsNccl() {
+  // XXX: Ideally the Options of ProcessGroupNCCL should be
+  // bound using `def_readwrite` like in pybind11, but we
+  // didn't do that because: 1. no milisecond support yet
+  // 2. no def_readwrite or property support yet.
+  // TODO: make this binding the same as pybind11
+  static const auto ProcessGroupNCCLOptionsTorchBind =
+      torch::class_<::c10d::ProcessGroupNCCL::Options>(
+          "dist_c10d", "ProcessGroupNCCLOptions")
+          .def(torch::init([](int64_t timeout, bool isHighPriorityStream) {
+            auto opTimeout = std::chrono::milliseconds(timeout);
+            auto opts =
+                ::c10d::ProcessGroupNCCL::Options::create(isHighPriorityStream);
+            opts->timeout = opTimeout;
+            return opts;
+          }));
+
+  static const auto ProcessGroupNCCLTorchBind =
+      torch::class_<::c10d::ProcessGroupNCCL>("dist_c10d", "ProcessGroupNCCL")
+          .def_pickle(
+              [](const c10::intrusive_ptr<::c10d::ProcessGroupNCCL>& self) {
+                auto base_process_group =
+                    ::c10::static_intrusive_pointer_cast<::c10d::ProcessGroup>(
+                        self);
+                auto name =
+                    ::c10d::DistributedC10d::get()->getNameOfProcessGroup(self);
+                return std::vector<std::string>{name};
+              },
+              [](std::vector<std::string> state) {
+                TORCH_CHECK(
+                    state.size() == 1,
+                    "Expecting exactly 1 state when restoring ProcessGroupNCCL, got: ",
+                    state.size());
+                const auto& process_group_name = state.front();
+                auto base_process_group =
+                    ::c10d::DistributedC10d::get()->getProcessGroupByName(
+                        process_group_name);
+                TORCH_CHECK(
+                    base_process_group.defined(),
+                    "Needed process group not found, ",
+                    "please create a process group with name: ",
+                    process_group_name);
+                c10::intrusive_ptr<::c10d::ProcessGroupNCCL>
+                    process_group_nccl = ::c10::dynamic_intrusive_pointer_cast<
+                        ::c10d::ProcessGroupNCCL>(base_process_group);
+                TORCH_CHECK(
+                    process_group_nccl.defined(),
+                    "Process group ",
+                    process_group_name,
+                    " isn't configured for NCCL backend");
+                return process_group_nccl;
+              })
+          .def(torch::init(
+              [](const c10::intrusive_ptr<::c10d::Store>& store,
+                 int64_t rank,
+                 int64_t size,
+                 c10::intrusive_ptr<::c10d::ProcessGroupNCCL::Options> options,
+                 const std::string& name) {
+                auto pg = c10::make_intrusive<::c10d::ProcessGroupNCCL>(
+                    store, rank, size, options);
+                ::c10d::DistributedC10d::get()->registerProcessGroupName(
+                    pg, name);
+                return pg;
+              }))
+          .def(
+              "alltoall_base",
+              [](const c10::intrusive_ptr<::c10d::ProcessGroupNCCL>& self,
+                 at::Tensor output,
+                 at::Tensor input,
+                 std::vector<int64_t> outputSplitSizes,
+                 std::vector<int64_t> inputSplitSizes) {
+                return self->alltoall_base(
+                    output,
+                    input,
+                    outputSplitSizes,
+                    inputSplitSizes,
+                    ::c10d::AllToAllOptions());
+              })
+          .def(
+              "size",
+              [](const c10::intrusive_ptr<::c10d::ProcessGroupNCCL>& self) {
+                return (int64_t)self->getSize();
+              })
+          .def(
+              "rank",
+              [](const c10::intrusive_ptr<::c10d::ProcessGroupNCCL>& self) {
+                return (int64_t)self->getRank();
+              });
+}
+
+namespace {
+struct RealNCCLProcessGroupProvider : public NCCLProcessGroupProvider {
+  c10::intrusive_ptr<ProcessGroup> get(
+      c10::intrusive_ptr<PrefixStore> prefix_store,
+      int64_t rank,
+      int64_t world_size,
+      std::chrono::milliseconds timeout) const override {
+    auto options = ProcessGroupNCCL::Options::create();
+    options->is_high_priority_stream = false;
+    options->timeout = timeout;
+    return c10::make_intrusive<ProcessGroupNCCL>(
+        prefix_store, rank, world_size, options);
+  }
+};
+
+struct RegisterNCCLProcessGroupProvider {
+  RegisterNCCLProcessGroupProvider() {
+    static RealNCCLProcessGroupProvider provider;
+    registerNCCLProcessGroupProvider(&provider);
+  }
+};
+
+RegisterNCCLProcessGroupProvider reg;
+
+} // namespace
+#define DISPATCH_TO_CUDA(name, function) \
+    m.impl(name, torch::dispatch(c10::DispatchKey::CUDA, TORCH_FN(function)))
+TORCH_LIBRARY_IMPL(q, CUDA, m) {
+    DISPATCH_TO_CUDA("_Bfloat16QuantizedToFloat", ::torch::distributed::c10d::quantization::_bfloat16_to_float_cuda);
+    DISPATCH_TO_CUDA("_FloatToBfloat16Quantized", ::torch::distributed::c10d::quantization::_float_to_bfloat16_cuda);
+}
+} // namespace c10d
+
+#endif // USE_C10D_NCCL
diff --git a/torch/csrc/distributed/c10d/frontend_cuda.hpp b/torch/csrc/distributed/c10d/frontend_cuda.hpp
new file mode 100644 (file)
index 0000000..a790f2e
--- /dev/null
+++ b/torch/csrc/distributed/c10d/frontend_cuda.hpp
@@ -0,0 +1,12 @@
+#pragma once
+
+#ifdef USE_C10D_NCCL
+#include <c10/macros/Export.h>
+
+namespace c10d {
+
+TORCH_API void initCustomClassBindingsNccl();
+
+}
+
+#endif
diff --git a/torch/csrc/distributed/c10d/init.cpp b/torch/csrc/distributed/c10d/init.cpp
index 6b52d3c05838451b071693bc925fff66568796ef..4bac0ca46edc43be35635e0097bfa05240aab2db 100644 (file)
@@ -17,7 +17,7 @@
 
 #ifdef USE_C10D_NCCL
 #include <c10d/ProcessGroupNCCL.hpp>
-#include <torch/csrc/distributed/c10d/quantization/quantization_gpu.h>
+#include <torch/csrc/distributed/c10d/frontend_cuda.hpp>
 #endif
 
 #ifdef USE_C10D_MPI
@@ -35,7 +35,6 @@
 
 #include <torch/csrc/Exceptions.h>
 #include <torch/csrc/distributed/c10d/python_comm_hook.h>
-#include <torch/csrc/distributed/c10d/quantization/quantization.h>
 #include <torch/csrc/jit/python/pybind_utils.h>
 #include <torch/csrc/utils/object_ptr.h>
 #include <torch/csrc/utils/pybind.h>
@@ -233,6 +232,9 @@ void _register_builtin_comm_hook(
 PyObject* c10d_init(PyObject* _unused, PyObject* noargs) {
   C10_LOG_API_USAGE_ONCE("c10d.python.import");
   ::c10d::initCustomClassBindings();
+#ifdef USE_C10D_NCCL
+  ::c10d::initCustomClassBindingsNccl();
+#endif
 
   auto c10d_module = THPObjectPtr(PyImport_ImportModule("torch.distributed"));
   if (!c10d_module) {
@@ -1646,28 +1648,6 @@ static PyMethodDef methods[] = { // NOLINT
 PyMethodDef* python_functions() {
   return methods;
 }
-
-namespace quantization {
-TORCH_LIBRARY(q, m) {
-    m.def("_Bfloat16QuantizedToFloat(Tensor input) -> Tensor");
-    m.def("_FloatToBfloat16Quantized(Tensor input) -> Tensor");
-}
-    TORCH_LIBRARY_IMPL(q, CPU, m) {
-        m.impl("_Bfloat16QuantizedToFloat", _bfloat16_to_float_cpu);
-        m.impl("_FloatToBfloat16Quantized", _float_to_bfloat16_cpu);
-    }
-
-#ifdef USE_C10D_NCCL
-    #define DISPATCH_TO_CUDA(name, function) \
-        m.impl(name, torch::dispatch(c10::DispatchKey::CUDA, TORCH_FN(function)))
-    TORCH_LIBRARY_IMPL(q, CUDA, m) {
-        DISPATCH_TO_CUDA("_Bfloat16QuantizedToFloat", _bfloat16_to_float_cuda);
-        DISPATCH_TO_CUDA("_FloatToBfloat16Quantized", _float_to_bfloat16_cuda);
-    }
-#endif
-
-} // namespace quantization
-
 } // namespace c10d
 } // namespace distributed
 } // namespace torch
diff --git a/torch/csrc/distributed/rpc/request_callback_impl.cpp b/torch/csrc/distributed/rpc/request_callback_impl.cpp
index 7001209be9851dd0c70d53bfef87b4ab5b9cb5fe..5fbe63ede321c9d6aba2d33a6ad9cfc8dd3e2b7a 100644 (file)
@@ -16,6 +16,7 @@
 #include <torch/csrc/distributed/autograd/rpc_messages/rref_backward_resp.h>
 #include <torch/csrc/distributed/autograd/utils.h>
 #include <torch/csrc/distributed/rpc/profiler/server_process_global_profiler.h>
+#include <torch/csrc/distributed/rpc/py_rref.h>
 #include <torch/csrc/distributed/rpc/python_call.h>
 #include <torch/csrc/distributed/rpc/python_remote_call.h>
 #include <torch/csrc/distributed/rpc/python_resp.h>