From: Michael Suo
Date: Fri, 27 Aug 2021 03:54:54 +0000 (-0700)
Subject: [torch/deploy] add torch.distributed to build (#63918)
X-Git-Tag: accepted/tizen/unified/20221215.051042~664
X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=63c90ec3bf6c9445a36199f65e0523a5e6532b0d;p=platform%2Fupstream%2Fpytorch.git

[torch/deploy] add torch.distributed to build (#63918)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/63918

Previously we were building with `USE_DISTRIBUTED` off, because c10d was built as a separate library for historical reasons. Since then, lw has merged the c10d build into libtorch, so this is fairly easy to turn on.

Differential Revision: D30492442

**NOTE FOR REVIEWERS**: This PR has internal Facebook specific changes or comments, please review them on [Phabricator](https://our.intern.facebook.com/intern/diff/D30492442/)!

Test Plan: added a unit test

Reviewed By: wconstab

Pulled By: suo

fbshipit-source-id: 843b8fcf349a72a7f6fcbd1fcc8961268690fb8c
---
diff --git a/tools/build_variables.bzl b/tools/build_variables.bzl index 3f6225358a..650830b314 100644 --- a/tools/build_variables.bzl +++ b/tools/build_variables.bzl @@ -340,6 +340,7 @@ libtorch_core_sources = sorted(core_sources_common + core_sources_full + core_tr # These files are the only ones that are supported on Windows. libtorch_distributed_base_sources = [ + "torch/csrc/distributed/c10d/frontend.cpp", "torch/csrc/distributed/c10d/comm.cpp", "torch/csrc/distributed/c10d/default_comm_hooks.cpp", "torch/csrc/distributed/c10d/FileStore.cpp", @@ -351,6 +352,7 @@ libtorch_distributed_base_sources = [ "torch/csrc/distributed/c10d/ProcessGroupGloo.cpp", "torch/csrc/distributed/c10d/ProcessGroupMPI.cpp", "torch/csrc/distributed/c10d/ProcessGroupWrapper.cpp", + "torch/csrc/distributed/c10d/quantization/quantization.cpp", "torch/csrc/distributed/c10d/reducer.cpp", "torch/csrc/distributed/c10d/sequence_num.cpp", "torch/csrc/distributed/c10d/Store.cpp", @@ -548,6 +550,7 @@ libtorch_cuda_distributed_base_sources = [ # These files are only supported on Linux (and others) but not on Windows.
libtorch_cuda_distributed_extra_sources = [ + "torch/csrc/distributed/c10d/frontend_cuda.cpp", "torch/csrc/distributed/c10d/NCCLUtils.cpp", "torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp", "torch/csrc/distributed/rpc/tensorpipe_cuda.cpp", @@ -735,10 +738,8 @@ libtorch_python_core_sources = [ ] libtorch_python_distributed_core_sources = [ - "torch/csrc/distributed/c10d/frontend.cpp", "torch/csrc/distributed/c10d/init.cpp", "torch/csrc/distributed/c10d/python_comm_hook.cpp", - "torch/csrc/distributed/c10d/quantization/quantization.cpp", ] libtorch_python_distributed_sources = libtorch_python_distributed_core_sources + [ diff --git a/torch/CMakeLists.txt b/torch/CMakeLists.txt index 761605fadc..7c08685561 100644 --- a/torch/CMakeLists.txt +++ b/torch/CMakeLists.txt @@ -214,11 +214,78 @@ add_custom_command( WORKING_DIRECTORY "${TORCH_ROOT}" ) +if(USE_DISTRIBUTED) + if(WIN32) + append_filelist("libtorch_python_distributed_core_sources" TORCH_PYTHON_SRCS) + else() + append_filelist("libtorch_python_distributed_sources" TORCH_PYTHON_SRCS) + endif() + # Disable certain warnings for GCC-9.X + if(CMAKE_COMPILER_IS_GNUCXX AND (CMAKE_CXX_COMPILER_VERSION VERSION_GREATER 9.0.0)) + set_source_files_properties(${TORCH_SRC_DIR}/csrc/distributed/autograd/init.cpp PROPERTIES COMPILE_FLAGS "-Wno-cast-function-type") + set_source_files_properties(${TORCH_SRC_DIR}/csrc/distributed/rpc/testing/init.cpp PROPERTIES COMPILE_FLAGS "-Wno-cast-function-type") + set_source_files_properties(${TORCH_SRC_DIR}/csrc/distributed/c10d/init.cpp PROPERTIES COMPILE_FLAGS "-Wno-cast-function-type") + endif() + # NCCL is a private dependency of libtorch, but libtorch_python includes + # some private headers of libtorch, which in turn include NCCL. As a hacky + # alternative to making NCCL a public dependency of libtorch, we make it + # a private dependency of libtorch_python as well. + if(USE_NCCL) + list(APPEND TORCH_PYTHON_LINK_LIBRARIES __caffe2_nccl) + endif() + # Same for MPI. + if(USE_MPI) + list(APPEND TORCH_PYTHON_LINK_LIBRARIES ${MPI_CXX_LIBRARIES}) + endif() + list(APPEND TORCH_PYTHON_COMPILE_DEFINITIONS USE_C10D) + +endif() + +if(USE_NCCL AND NOT WIN32) + list(APPEND TORCH_PYTHON_SRCS + ${TORCH_SRC_DIR}/csrc/cuda/python_nccl.cpp) + list(APPEND TORCH_PYTHON_COMPILE_DEFINITIONS USE_NCCL) +endif() + # WARNING- any TORCH_PYTHON_COMPILE_DEFINITIONS above this line # affect both torch_python and DEPLOY interpreter. if(USE_DEPLOY) add_library(torch_python_obj OBJECT ${TORCH_PYTHON_SRCS}) + if(USE_DISTRIBUTED) + # Set c10d-related compile definitions. For a "normal" build of + # libtorch_python, these are set on libtorch as PUBLIC so they are + # automatically propagated when libtorch_python links against libtorch. But + # since in the deploy build we are intentionally *not* linking against + # libtorch, we need to set them manually here. + list(APPEND TORCH_PYTHON_COMPILE_DEFINITIONS USE_DISTRIBUTED) + if(USE_GLOO AND USE_C10D_GLOO) + list(APPEND TORCH_PYTHON_COMPILE_DEFINITIONS USE_C10D_GLOO) + endif() + if(USE_NCCL AND USE_C10D_NCCL) + list(APPEND TORCH_PYTHON_COMPILE_DEFINITIONS USE_C10D_NCCL) + # Put nccl headers on the include path. We are specifically only setting + # include dirs here instead of linking against __caffe2_nccl wholesale + # to ensure we aren't accidentally replicating the nccl lib. 
+ target_include_directories(torch_python_obj PRIVATE $) + endif() + if(USE_MPI AND USE_C10D_MPI) + list(APPEND TORCH_PYTHON_COMPILE_DEFINITIONS USE_C10D_MPI) + endif() + + # Pass USE_RPC in order to reduce use of + # #if defined(USE_DISTRIBUTED) && !defined(_WIN32) + # need to be removed when RPC is supported + if(NOT WIN32) + target_compile_definitions(torch_cpu PUBLIC USE_RPC) + endif() + if(USE_TENSORPIPE) + list(APPEND TORCH_PYTHON_COMPILE_DEFINITIONS USE_TENSORPIPE) + endif() + + # Set c10d-related include directories as well. + target_include_directories(torch_python_obj PRIVATE $) + endif() target_compile_definitions(torch_python_obj PRIVATE "-DTHP_BUILD_MAIN_LIB -DUSE_DEPLOY") target_compile_definitions(torch_python_obj PRIVATE ${TORCH_PYTHON_COMPILE_DEFINITIONS}) @@ -268,38 +335,6 @@ if(CMAKE_CXX_COMPILER_ID STREQUAL "GNU") set_source_files_properties(${TORCH_SRC_DIR}/csrc/utils/throughput_benchmark.cpp PROPERTIES COMPILE_FLAGS -Wno-attributes) endif() -if(USE_DISTRIBUTED) - if(WIN32) - append_filelist("libtorch_python_distributed_core_sources" TORCH_PYTHON_SRCS) - else() - append_filelist("libtorch_python_distributed_sources" TORCH_PYTHON_SRCS) - endif() - # Disable certain warnings for GCC-9.X - if(CMAKE_COMPILER_IS_GNUCXX AND (CMAKE_CXX_COMPILER_VERSION VERSION_GREATER 9.0.0)) - set_source_files_properties(${TORCH_SRC_DIR}/csrc/distributed/autograd/init.cpp PROPERTIES COMPILE_FLAGS "-Wno-cast-function-type") - set_source_files_properties(${TORCH_SRC_DIR}/csrc/distributed/rpc/testing/init.cpp PROPERTIES COMPILE_FLAGS "-Wno-cast-function-type") - set_source_files_properties(${TORCH_SRC_DIR}/csrc/distributed/c10d/init.cpp PROPERTIES COMPILE_FLAGS "-Wno-cast-function-type") - endif() - # NCCL is a private dependency of libtorch, but libtorch_python includes - # some private headers of libtorch, which in turn include NCCL. As a hacky - # alternative to making NCCL a public dependency of libtorch, we make it - # a private dependency of libtorch_python as well. - if(USE_NCCL) - list(APPEND TORCH_PYTHON_LINK_LIBRARIES __caffe2_nccl) - endif() - # Same for MPI. 
- if(USE_MPI) - list(APPEND TORCH_PYTHON_LINK_LIBRARIES ${MPI_CXX_LIBRARIES}) - endif() - list(APPEND TORCH_PYTHON_COMPILE_DEFINITIONS USE_C10D) -endif() - -if(USE_NCCL AND NOT WIN32) - list(APPEND TORCH_PYTHON_SRCS - ${TORCH_SRC_DIR}/csrc/cuda/python_nccl.cpp) - list(APPEND TORCH_PYTHON_COMPILE_DEFINITIONS USE_NCCL) -endif() - add_library(torch_python SHARED ${TORCH_PYTHON_SRCS}) if(HAVE_SOVERSION) set_target_properties(torch_python PROPERTIES diff --git a/torch/csrc/deploy/example/generate_examples.py b/torch/csrc/deploy/example/generate_examples.py index 65f244373d..0f279d9221 100644 --- a/torch/csrc/deploy/example/generate_examples.py +++ b/torch/csrc/deploy/example/generate_examples.py @@ -79,3 +79,6 @@ if __name__ == "__main__": e.save_pickle("fn", "fn.pkl", load_library) generate_fx_example() + + with PackageExporter(p / "uses_distributed") as e: + e.save_source_string("uses_distributed", "import torch.distributed; assert torch.distributed.is_available()") diff --git a/torch/csrc/deploy/test_deploy.cpp b/torch/csrc/deploy/test_deploy.cpp index a004db1e0d..53456cacca 100644 --- a/torch/csrc/deploy/test_deploy.cpp +++ b/torch/csrc/deploy/test_deploy.cpp @@ -366,3 +366,15 @@ TEST(TorchpyTest, SharedLibraryLoad) { } } #endif + +TEST(TorchpyTest, UsesDistributed) { + const auto model_filename = path( + "USES_DISTRIBUTED", + "torch/csrc/deploy/example/generated/uses_distributed"); + torch::deploy::InterpreterManager m(1); + torch::deploy::Package p = m.load_package(model_filename); + { + auto I = p.acquire_session(); + I.self.attr("import_module")({"uses_distributed"}); + } +} diff --git a/torch/csrc/deploy/test_deploy_gpu.cpp b/torch/csrc/deploy/test_deploy_gpu.cpp index 8287d1683e..4e990adcd9 100644 --- a/torch/csrc/deploy/test_deploy_gpu.cpp +++ b/torch/csrc/deploy/test_deploy_gpu.cpp @@ -53,3 +53,15 @@ TEST(TorchDeployGPUTest, SimpleModel) { ASSERT_TRUE(ref_output.allclose(output, 1e-03, 1e-05)); } + +TEST(TorchDeployGPUTest, UsesDistributed) { + const auto model_filename = path( + "USES_DISTRIBUTED", + "torch/csrc/deploy/example/generated/uses_distributed"); + torch::deploy::InterpreterManager m(1); + torch::deploy::Package p = m.load_package(model_filename); + { + auto I = p.acquire_session(); + I.self.attr("import_module")({"uses_distributed"}); + } +} diff --git a/torch/csrc/distributed/c10d/frontend.cpp b/torch/csrc/distributed/c10d/frontend.cpp index b65cba7988..e5b59f2898 100644 --- a/torch/csrc/distributed/c10d/frontend.cpp +++ b/torch/csrc/distributed/c10d/frontend.cpp @@ -3,10 +3,11 @@ #include #include #include -#include #include #include #include +#include +#include #include #include @@ -17,10 +18,6 @@ #include #endif -#ifdef USE_C10D_NCCL -#include -#endif - #ifdef USE_C10D_MPI #include #endif @@ -29,6 +26,20 @@ namespace c10d { namespace { +// Constant initialization, so it is guaranteed to be initialized before +// static initialization calls which may invoke registerNCCLProcessGroupProvider +const NCCLProcessGroupProvider stubProvider; +constexpr const NCCLProcessGroupProvider* defaultStubProviderAddr = + &stubProvider; +inline const NCCLProcessGroupProvider*& getNCCLProcessGroupProviderAddress() { + static const NCCLProcessGroupProvider* stubs_ = defaultStubProviderAddr; + return stubs_; +} + +const NCCLProcessGroupProvider* GetNCCLProcessGroupProvider() { + return getNCCLProcessGroupProviderAddress(); +} + void maybePreprocessComplexTensor(at::Tensor& tensor) { if(!tensor.is_complex()) { return; @@ -63,6 +74,11 @@ void assertReduceOpSupportsComplexTensor(ReduceOp op) 
{ } // namespace anonymous +void registerNCCLProcessGroupProvider(NCCLProcessGroupProvider* provider) { + getNCCLProcessGroupProviderAddress() = provider; +} + + std::string Backend::get(const std::string& backend_type) { return backend_type; } @@ -207,17 +223,7 @@ c10::intrusive_ptr DistributedC10d::newProcessGroupHelper( "Attempting to create GLOO-based process group while GLOO is either not enabled or built"); #endif // USE_C10D_GLOO } else if (backend == "nccl") { -#ifdef USE_C10D_NCCL - auto options = ProcessGroupNCCL::Options::create(); - - options->is_high_priority_stream = false; - options->timeout = timeout; - pg = c10::make_intrusive( - prefix_store, rank, world_size, options); -#else - AT_ERROR( - "Attempting to create NCCL-based process group while NCCL is either not enabled or built"); -#endif // USE_C10D_NCCL + pg = GetNCCLProcessGroupProvider()->get(prefix_store, rank, world_size, timeout); } else { // TODO: discuss to figure out how to extend this to third party backends? AT_ERROR("Unsupported backend type: ", backend); @@ -1008,7 +1014,7 @@ void initCustomClassBindings() { .def( "broadcast", [](const c10::intrusive_ptr<::c10d::ProcessGroup>& self, - std::vector data) { return self->broadcast(data); + std::vector data) { return self->broadcast(data); }) */ .def( @@ -1045,14 +1051,14 @@ void initCustomClassBindings() { .def( "allreduce", [](const c10::intrusive_ptr<::c10d::ProcessGroup>& self, - at::Tensor& tensor, - c10::intrusive_ptr<::c10d::ReduceOp> op) { + at::Tensor& tensor, + c10::intrusive_ptr<::c10d::ReduceOp> op) { ::c10d::AllreduceOptions opts; opts.reduceOp = *op; std::vector tensors = {tensor}; return self->allreduce(tensors, opts); - } - ) + } + ) */ // TODO: make AllreduceCoalescedOptions compatible with TorchBind to // provide the full API in python. @@ -1098,8 +1104,8 @@ void initCustomClassBindings() { .def( "allgather", [](const c10::intrusive_ptr<::c10d::ProcessGroup>& self, - std::vector output, - at::Tensor input) { + std::vector output, + at::Tensor input) { std::vector> outputs = { std::move(output)}; std::vector inputs = {std::move(input)}; @@ -1121,8 +1127,8 @@ void initCustomClassBindings() { .def( "gather", [](const c10::intrusive_ptr<::c10d::ProcessGroup>& self, - std::vector> output_tensors, - std::vector input_tensors) { + std::vector> output_tensors, + std::vector input_tensors) { ::c10d::GatherOptions opts; return self->gather(output_tensors, input_tensors, opts); }) @@ -1145,8 +1151,8 @@ void initCustomClassBindings() { .def( "scatter", [](const c10::intrusive_ptr<::c10d::ProcessGroup>& self, - std::vector outputTensors, - std::vector> inputTensors) { + std::vector outputTensors, + std::vector> inputTensors) { ::c10d::ScatterOptions opts; self->scatter(outputTensors, inputTensors, opts); }) @@ -1169,8 +1175,8 @@ void initCustomClassBindings() { // TODO: Enable this method when TorchBind supports ReduceScatterOptions. .def( "reduce_scatter", [](const c10::intrusive_ptr<::c10d::ProcessGroup>& self, - std::vector outputTensors, - std::vector> inputTensors) { + std::vector outputTensors, + std::vector> inputTensors) { ::c10d::ReduceScatterOptions opts; return self->reduce_scatter(outputTensors, inputTensors, opts); }) @@ -1241,95 +1247,6 @@ void initCustomClassBindings() { return self->barrier(opts); }); -#ifdef USE_C10D_NCCL - // XXX: Ideally the Options of ProcessGroupNCCL should be - // bound using `def_readwrite` like in pybind11, but we - // didn't do that because: 1. no milisecond support yet - // 2. 
no def_readwrite or property support yet. - // TODO: make this binding the same as pybind11 - static const auto ProcessGroupNCCLOptionsTorchBind = - torch::class_<::c10d::ProcessGroupNCCL::Options>( - "dist_c10d", "ProcessGroupNCCLOptions") - .def(torch::init([](int64_t timeout, bool isHighPriorityStream) { - auto opTimeout = std::chrono::milliseconds(timeout); - auto opts = - ::c10d::ProcessGroupNCCL::Options::create(isHighPriorityStream); - opts->timeout = opTimeout; - return opts; - })); - - static const auto ProcessGroupNCCLTorchBind = - torch::class_<::c10d::ProcessGroupNCCL>("dist_c10d", "ProcessGroupNCCL") - .def_pickle( - [](const c10::intrusive_ptr<::c10d::ProcessGroupNCCL>& self) { - auto base_process_group = - ::c10::static_intrusive_pointer_cast<::c10d::ProcessGroup>(self); - auto name = - ::c10d::DistributedC10d::get()->getNameOfProcessGroup(self); - return std::vector{name}; - }, - [](std::vector state) { - TORCH_CHECK( - state.size() == 1, - "Expecting exactly 1 state when restoring ProcessGroupNCCL, got: ", - state.size()); - const auto& process_group_name = state.front(); - auto base_process_group = - ::c10d::DistributedC10d::get()->getProcessGroupByName( - process_group_name); - TORCH_CHECK( - base_process_group.defined(), - "Needed process group not found, ", - "please create a process group with name: ", - process_group_name); - c10::intrusive_ptr<::c10d::ProcessGroupNCCL> - process_group_nccl = ::c10::dynamic_intrusive_pointer_cast< - ::c10d::ProcessGroupNCCL>(base_process_group); - TORCH_CHECK( - process_group_nccl.defined(), - "Process group ", - process_group_name, - " isn't configured for NCCL backend"); - return process_group_nccl; - }) - .def(torch::init( - [](const c10::intrusive_ptr<::c10d::Store>& store, - int64_t rank, - int64_t size, - c10::intrusive_ptr<::c10d::ProcessGroupNCCL::Options> options, - const std::string& name) { - auto pg = c10::make_intrusive<::c10d::ProcessGroupNCCL>( - store, rank, size, options); - ::c10d::DistributedC10d::get()->registerProcessGroupName( - pg, name); - return pg; - })) - .def( - "alltoall_base", - [](const c10::intrusive_ptr<::c10d::ProcessGroupNCCL>& self, - at::Tensor output, - at::Tensor input, - std::vector outputSplitSizes, - std::vector inputSplitSizes) { - return self->alltoall_base( - output, - input, - outputSplitSizes, - inputSplitSizes, - ::c10d::AllToAllOptions()); - }) - .def( - "size", - [](const c10::intrusive_ptr<::c10d::ProcessGroupNCCL>& self) { - return (int64_t)self->getSize(); - }) - .def( - "rank", - [](const c10::intrusive_ptr<::c10d::ProcessGroupNCCL>& self) { - return (int64_t)self->getRank(); - }); -#endif - static const auto DistributedC10dFrontendTorchBind = torch::class_<::c10d::DistributedC10d>("dist_c10d", "frontend") .def(torch::init([]() { return ::c10d::DistributedC10d::get(); })) @@ -1344,4 +1261,12 @@ void initCustomClassBindings() { &::c10d::DistributedC10d::getNameOfProcessGroup); } +TORCH_LIBRARY(q, m) { + m.def("_Bfloat16QuantizedToFloat(Tensor input) -> Tensor"); + m.def("_FloatToBfloat16Quantized(Tensor input) -> Tensor"); +} +TORCH_LIBRARY_IMPL(q, CPU, m) { + m.impl("_Bfloat16QuantizedToFloat", ::torch::distributed::c10d::quantization::_bfloat16_to_float_cpu); + m.impl("_FloatToBfloat16Quantized", ::torch::distributed::c10d::quantization::_float_to_bfloat16_cpu); +} } // namespace c10d diff --git a/torch/csrc/distributed/c10d/frontend.hpp b/torch/csrc/distributed/c10d/frontend.hpp index c90cc077b2..b39d8b7a44 100644 --- a/torch/csrc/distributed/c10d/frontend.hpp +++ 
b/torch/csrc/distributed/c10d/frontend.hpp @@ -2,6 +2,7 @@ #include #include +#include #include #include #include @@ -259,7 +260,26 @@ class TORCH_PYTHON_API DistributedC10d : public torch::CustomClassHolder { int64_t group_count_; }; -// Must be called to initialize Torchbind bindings for c10d. -void initCustomClassBindings(); +// This class exists as a way to allow us to split NCCL-specific code into a +// different file. frontend_cuda.cpp will, if USE_C10D_NCCL is defined, +// override this NCCLProcessGroupProvider with one that will actually do +// something. +struct TORCH_API NCCLProcessGroupProvider { + virtual c10::intrusive_ptr get( + c10::intrusive_ptr /*prefix_store*/, + int64_t /*rank*/, + int64_t /*world_size*/, + std::chrono::milliseconds /*timeout*/) const { + AT_ERROR( + "Attempting to create NCCL-based process group while NCCL is either not enabled or built"); + } + + virtual ~NCCLProcessGroupProvider() = default; +}; + +TORCH_API void registerNCCLProcessGroupProvider( + NCCLProcessGroupProvider* provider); + +TORCH_API void initCustomClassBindings(); } // namespace c10d diff --git a/torch/csrc/distributed/c10d/frontend_cuda.cpp b/torch/csrc/distributed/c10d/frontend_cuda.cpp new file mode 100644 index 0000000000..1b42f13b3c --- /dev/null +++ b/torch/csrc/distributed/c10d/frontend_cuda.cpp @@ -0,0 +1,136 @@ +#include + +#ifdef USE_C10D_NCCL + +#include +#include +#include +#include +#include + +namespace c10d { + +void initCustomClassBindingsNccl() { + // XXX: Ideally the Options of ProcessGroupNCCL should be + // bound using `def_readwrite` like in pybind11, but we + // didn't do that because: 1. no milisecond support yet + // 2. no def_readwrite or property support yet. + // TODO: make this binding the same as pybind11 + static const auto ProcessGroupNCCLOptionsTorchBind = + torch::class_<::c10d::ProcessGroupNCCL::Options>( + "dist_c10d", "ProcessGroupNCCLOptions") + .def(torch::init([](int64_t timeout, bool isHighPriorityStream) { + auto opTimeout = std::chrono::milliseconds(timeout); + auto opts = + ::c10d::ProcessGroupNCCL::Options::create(isHighPriorityStream); + opts->timeout = opTimeout; + return opts; + })); + + static const auto ProcessGroupNCCLTorchBind = + torch::class_<::c10d::ProcessGroupNCCL>("dist_c10d", "ProcessGroupNCCL") + .def_pickle( + [](const c10::intrusive_ptr<::c10d::ProcessGroupNCCL>& self) { + auto base_process_group = + ::c10::static_intrusive_pointer_cast<::c10d::ProcessGroup>( + self); + auto name = + ::c10d::DistributedC10d::get()->getNameOfProcessGroup(self); + return std::vector{name}; + }, + [](std::vector state) { + TORCH_CHECK( + state.size() == 1, + "Expecting exactly 1 state when restoring ProcessGroupNCCL, got: ", + state.size()); + const auto& process_group_name = state.front(); + auto base_process_group = + ::c10d::DistributedC10d::get()->getProcessGroupByName( + process_group_name); + TORCH_CHECK( + base_process_group.defined(), + "Needed process group not found, ", + "please create a process group with name: ", + process_group_name); + c10::intrusive_ptr<::c10d::ProcessGroupNCCL> + process_group_nccl = ::c10::dynamic_intrusive_pointer_cast< + ::c10d::ProcessGroupNCCL>(base_process_group); + TORCH_CHECK( + process_group_nccl.defined(), + "Process group ", + process_group_name, + " isn't configured for NCCL backend"); + return process_group_nccl; + }) + .def(torch::init( + [](const c10::intrusive_ptr<::c10d::Store>& store, + int64_t rank, + int64_t size, + c10::intrusive_ptr<::c10d::ProcessGroupNCCL::Options> options, + const 
std::string& name) { + auto pg = c10::make_intrusive<::c10d::ProcessGroupNCCL>( + store, rank, size, options); + ::c10d::DistributedC10d::get()->registerProcessGroupName( + pg, name); + return pg; + })) + .def( + "alltoall_base", + [](const c10::intrusive_ptr<::c10d::ProcessGroupNCCL>& self, + at::Tensor output, + at::Tensor input, + std::vector outputSplitSizes, + std::vector inputSplitSizes) { + return self->alltoall_base( + output, + input, + outputSplitSizes, + inputSplitSizes, + ::c10d::AllToAllOptions()); + }) + .def( + "size", + [](const c10::intrusive_ptr<::c10d::ProcessGroupNCCL>& self) { + return (int64_t)self->getSize(); + }) + .def( + "rank", + [](const c10::intrusive_ptr<::c10d::ProcessGroupNCCL>& self) { + return (int64_t)self->getRank(); + }); +} + +namespace { +struct RealNCCLProcessGroupProvider : public NCCLProcessGroupProvider { + c10::intrusive_ptr get( + c10::intrusive_ptr prefix_store, + int64_t rank, + int64_t world_size, + std::chrono::milliseconds timeout) const override { + auto options = ProcessGroupNCCL::Options::create(); + options->is_high_priority_stream = false; + options->timeout = timeout; + return c10::make_intrusive( + prefix_store, rank, world_size, options); + } +}; + +struct RegisterNCCLProcessGroupProvider { + RegisterNCCLProcessGroupProvider() { + static RealNCCLProcessGroupProvider provider; + registerNCCLProcessGroupProvider(&provider); + } +}; + +RegisterNCCLProcessGroupProvider reg; + +} // namespace +#define DISPATCH_TO_CUDA(name, function) \ + m.impl(name, torch::dispatch(c10::DispatchKey::CUDA, TORCH_FN(function))) +TORCH_LIBRARY_IMPL(q, CUDA, m) { + DISPATCH_TO_CUDA("_Bfloat16QuantizedToFloat", ::torch::distributed::c10d::quantization::_bfloat16_to_float_cuda); + DISPATCH_TO_CUDA("_FloatToBfloat16Quantized", ::torch::distributed::c10d::quantization::_float_to_bfloat16_cuda); +} +} // namespace c10d + +#endif // USE_C10D_NCCL diff --git a/torch/csrc/distributed/c10d/frontend_cuda.hpp b/torch/csrc/distributed/c10d/frontend_cuda.hpp new file mode 100644 index 0000000000..a790f2e847 --- /dev/null +++ b/torch/csrc/distributed/c10d/frontend_cuda.hpp @@ -0,0 +1,12 @@ +#pragma once + +#ifdef USE_C10D_NCCL +#include + +namespace c10d { + +TORCH_API void initCustomClassBindingsNccl(); + +} + +#endif diff --git a/torch/csrc/distributed/c10d/init.cpp b/torch/csrc/distributed/c10d/init.cpp index 6b52d3c058..4bac0ca46e 100644 --- a/torch/csrc/distributed/c10d/init.cpp +++ b/torch/csrc/distributed/c10d/init.cpp @@ -17,7 +17,7 @@ #ifdef USE_C10D_NCCL #include -#include +#include #endif #ifdef USE_C10D_MPI @@ -35,7 +35,6 @@ #include #include -#include #include #include #include @@ -233,6 +232,9 @@ void _register_builtin_comm_hook( PyObject* c10d_init(PyObject* _unused, PyObject* noargs) { C10_LOG_API_USAGE_ONCE("c10d.python.import"); ::c10d::initCustomClassBindings(); +#ifdef USE_C10D_NCCL + ::c10d::initCustomClassBindingsNccl(); +#endif auto c10d_module = THPObjectPtr(PyImport_ImportModule("torch.distributed")); if (!c10d_module) { @@ -1646,28 +1648,6 @@ static PyMethodDef methods[] = { // NOLINT PyMethodDef* python_functions() { return methods; } - -namespace quantization { -TORCH_LIBRARY(q, m) { - m.def("_Bfloat16QuantizedToFloat(Tensor input) -> Tensor"); - m.def("_FloatToBfloat16Quantized(Tensor input) -> Tensor"); -} - TORCH_LIBRARY_IMPL(q, CPU, m) { - m.impl("_Bfloat16QuantizedToFloat", _bfloat16_to_float_cpu); - m.impl("_FloatToBfloat16Quantized", _float_to_bfloat16_cpu); - } - -#ifdef USE_C10D_NCCL - #define DISPATCH_TO_CUDA(name, function) \ - 
m.impl(name, torch::dispatch(c10::DispatchKey::CUDA, TORCH_FN(function))) - TORCH_LIBRARY_IMPL(q, CUDA, m) { - DISPATCH_TO_CUDA("_Bfloat16QuantizedToFloat", _bfloat16_to_float_cuda); - DISPATCH_TO_CUDA("_FloatToBfloat16Quantized", _float_to_bfloat16_cuda); - } -#endif - -} // namespace quantization - } // namespace c10d } // namespace distributed } // namespace torch diff --git a/torch/csrc/distributed/rpc/request_callback_impl.cpp b/torch/csrc/distributed/rpc/request_callback_impl.cpp index 7001209be9..5fbe63ede3 100644 --- a/torch/csrc/distributed/rpc/request_callback_impl.cpp +++ b/torch/csrc/distributed/rpc/request_callback_impl.cpp @@ -16,6 +16,7 @@ #include #include #include +#include #include #include #include
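
For readers skimming the diff above: the core refactor in frontend.cpp / frontend_cuda.cpp is a small provider-registration pattern, so the always-built frontend only sees an abstract NCCLProcessGroupProvider and the NCCL-specific translation unit registers a real implementation when USE_C10D_NCCL is defined. Below is a minimal stand-alone C++ sketch of that pattern with illustrative names (Provider, RealProvider, Registrar); it is not the actual c10d code, and it uses a function-local static where the patch uses a constant-initialized namespace-scope stub to sidestep static-initialization order.

#include <iostream>
#include <stdexcept>

// Default provider: lives in the always-built file and simply errors out,
// mirroring the stub NCCLProcessGroupProvider declared in frontend.hpp.
struct Provider {
  virtual ~Provider() = default;
  virtual void create() const {
    throw std::runtime_error("NCCL backend is not enabled in this build");
  }
};

// Function-local statics ensure the slot is initialized before first use.
Provider*& providerSlot() {
  static Provider stub;
  static Provider* slot = &stub;
  return slot;
}

void registerProvider(Provider* p) { providerSlot() = p; }

// --- what the NCCL-only translation unit would contain ---
struct RealProvider : Provider {
  void create() const override { std::cout << "creating real process group\n"; }
};

// Static object whose constructor runs at load time and swaps in the real provider.
struct Registrar {
  Registrar() {
    static RealProvider real;
    registerProvider(&real);
  }
} registrar;

int main() {
  // Dispatches to RealProvider because the registrar above was linked in;
  // without that translation unit, the stub would throw the "not enabled" error.
  providerSlot()->create();
}

In the patch itself, RegisterNCCLProcessGroupProvider in frontend_cuda.cpp plays the role of Registrar, and DistributedC10d::newProcessGroupHelper calls GetNCCLProcessGroupProvider()->get(...) instead of constructing ProcessGroupNCCL directly.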