From: Pieter Noordhuis
Date: Thu, 11 Apr 2019 04:27:51 +0000 (-0700)
Subject: ProcessGroupMPI exists only if it is valid (#14809)
X-Git-Tag: accepted/tizen/6.5/unified/20211028.231830~278
X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=ce166d949d8e5ff5be2beb7feaff226646a11518;p=platform%2Fupstream%2Fpytorch.git

ProcessGroupMPI exists only if it is valid (#14809)

Summary:
Previously, MPI process groups were created for all processes, even if
they were not part of the created group. Their MPI_Comm member field
would be MPI_COMM_NULL and they would ignore any calls. Their rank and
size were identical to those of the global process group, and they had
special groupRank and groupSize fields to capture the _real_ rank and
size.

This also meant asymmetry with the other process group types, where
creating a new group returns either the process group OR
GroupMember.NON_GROUP_MEMBER. For the MPI process group, it would
always return a process group, and an additional check was needed to
verify whether a process was actually part of that group.

This commit changes this such that every MPI process group is a valid
process group, and by extension we no longer have to special-case MPI
to determine whether a process is part of a group. Now, if the value
returned by `new_group` is GroupMember.NON_GROUP_MEMBER, the process is
not a member; otherwise, it is.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/14809

Differential Revision: D14887937

Pulled By: pietern

fbshipit-source-id: c5bf86d3b33e524cc5004ee68e30103178fa491d
---

diff --git a/torch/csrc/distributed/c10d/init.cpp b/torch/csrc/distributed/c10d/init.cpp
index ff36947..a0c5ab2 100644
--- a/torch/csrc/distributed/c10d/init.cpp
+++ b/torch/csrc/distributed/c10d/init.cpp
@@ -435,11 +435,17 @@ They are used in specifying strategies for reduction collectives, e.g.,
 #endif
 #ifdef USE_C10D_MPI
-  shared_ptr_class_<::c10d::ProcessGroupMPI>(
-      module, "ProcessGroupMPI", processGroup)
-      .def(py::init([](std::vector<int> ranks) {
+  auto processGroupMPI = shared_ptr_class_<::c10d::ProcessGroupMPI>(
+      module, "ProcessGroupMPI", processGroup);
+
+  // Define static create function instead of a constructor, because
+  // this function may return null. This happens if this process is not
+  // part of a sub group that is to be created.
+ processGroupMPI.def_static( + "create", + [](std::vector ranks) { return ::c10d::ProcessGroupMPI::createProcessGroupMPI(ranks); - })); + }); #endif shared_ptr_class_<::c10d::ProcessGroup::Work>(module, "Work") diff --git a/torch/distributed/distributed_c10d.py b/torch/distributed/distributed_c10d.py index 6ff0e66..d08e1f9 100644 --- a/torch/distributed/distributed_c10d.py +++ b/torch/distributed/distributed_c10d.py @@ -110,8 +110,7 @@ class GroupMember(object): # Cached process groups # For NCCL and GLOO pg, it is a map from ProcessGroup to (Backend, Store) -# For MPI pg, it is a map from ProcessGroup to (Backend, Bool), where bool -# represents if the ProcessGroup objects is part of the group +# For MPI pg, it is a map from ProcessGroup to (Backend, None) _pg_map = {} # Process group's names, map from ProcessGroup to str _pg_names = {} @@ -137,15 +136,9 @@ def _rank_not_in_group(group): Helper that checks if the current process's rank is not in a given group """ - default_backend, _ = _pg_map[_get_default_group()] - if default_backend != Backend.MPI: - return group == GroupMember.NON_GROUP_MEMBER - else: - if group == GroupMember.WORLD: - return False - else: - _, in_group = _pg_map[group] - return not in_group + if group == GroupMember.WORLD: + return False + return group == GroupMember.NON_GROUP_MEMBER def _get_group_rank(group, rank): @@ -345,8 +338,7 @@ def init_process_group(backend, on a system that supports MPI. The same applies to NCCL as well. """ - global _pg_map - global _pg_names + global _pg_group_ranks global _backend global _default_pg global _default_pg_init_method @@ -372,12 +364,14 @@ def init_process_group(backend, backend = Backend(backend) if backend == Backend.MPI: - if not is_mpi_available(): - raise RuntimeError("Distributed package doesn't have MPI built in") - - _default_pg = ProcessGroupMPI([]) - _pg_map[_default_pg] = (Backend.MPI, True) - _pg_names[_default_pg] = group_name + _default_pg = _new_process_group_helper( + -1, + -1, + [], + Backend.MPI, + None, + group_name=group_name, + timeout=timeout) else: # backward compatible API url = init_method @@ -392,22 +386,16 @@ def init_process_group(backend, store, rank, world_size = next(rendezvous(url)) store.set_timeout(timeout) - if backend == Backend.GLOO: - _default_pg = ProcessGroupGloo( - store, - rank, - world_size, - timeout=timeout) - _pg_map[_default_pg] = (Backend.GLOO, store) - _pg_names[_default_pg] = group_name - elif backend == Backend.NCCL: - if not is_nccl_available(): - raise RuntimeError("Distributed package doesn't have NCCL " - "built in") - _default_pg = ProcessGroupNCCL(store, rank, world_size) - _pg_map[_default_pg] = (Backend.NCCL, store) - _pg_names[_default_pg] = group_name + _default_pg = _new_process_group_helper( + world_size, + rank, + [], + backend, + store, + group_name=group_name, + timeout=timeout) + _pg_group_ranks[_default_pg] = {i: i for i in range(_default_pg.size())} _backend = _pg_map[_default_pg][0] _default_pg_init_method = init_method @@ -415,14 +403,18 @@ def init_process_group(backend, def _new_process_group_helper(world_size, rank, group_ranks, - in_group, - group_name, - timeout=_default_pg_timeout, - backend=None): + backend, + store, + group_name=None, + timeout=_default_pg_timeout): """ - Create a new distributed process group. And the new process group can be - used to perform collective operations. + Create a new distributed process group. 
+ This function must be called by ALL processes in the global group, even if + the calling process is not part of the newly created group. In that case, + this function returns GroupMember.NON_GROUP_MEMBER. + + This function is called with ``group_ranks == []`` for the default group. """ global _pg_map global _group_count @@ -440,25 +432,33 @@ def _new_process_group_helper(world_size, raise RuntimeError("Expected timeout argument to be of type" "datetime.timedelta") - default_backend, default_store = _pg_map[_default_pg] - if backend is None: - backend = default_backend - else: - backend = Backend(backend) + # The list of group ranks is empty if we're creating the default group. + is_default_group = (len(group_ranks) == 0) + backend = Backend(backend) if backend == Backend.MPI: if not is_mpi_available(): raise RuntimeError("Distributed package doesn't have MPI built in") - pg = ProcessGroupMPI(group_ranks) - _pg_map[pg] = (Backend.MPI, in_group) + pg = ProcessGroupMPI.create(group_ranks) + if not pg: + return GroupMember.NON_GROUP_MEMBER + _pg_map[pg] = (Backend.MPI, None) _pg_names[pg] = group_name else: - # Create the prefix store - store = PrefixStore(group_name, default_store) + # If this is a subgroup (which means group_ranks is specified), + # we check if the current process is a member of the new group. + if not is_default_group: + global_rank = _default_pg.rank() + if global_rank not in group_ranks: + return GroupMember.NON_GROUP_MEMBER + + # Use the group name as prefix in the default store, such that + # a single store can be reused by multiple groups. + prefix_store = PrefixStore(group_name, store) if backend == Backend.GLOO: pg = ProcessGroupGloo( - store, + prefix_store, rank, world_size, timeout=timeout) @@ -468,11 +468,16 @@ def _new_process_group_helper(world_size, if not is_nccl_available(): raise RuntimeError("Distributed package doesn't have NCCL " "built in") - pg = ProcessGroupNCCL(store, rank, world_size, group_name) + pg = ProcessGroupNCCL( + prefix_store, + rank, + world_size, + group_name) _pg_map[pg] = (Backend.NCCL, store) _pg_names[pg] = group_name else: raise RuntimeError("Unsupported distributed backend by group") + return pg @@ -492,15 +497,14 @@ def destroy_process_group(group=group.WORLD): global _default_pg global _default_pg_init_method - default_backend, _ = _pg_map[_get_default_group()] - if (default_backend != Backend.MPI and - group == GroupMember.NON_GROUP_MEMBER): + if group == GroupMember.NON_GROUP_MEMBER: return if group == GroupMember.WORLD: pg = _default_pg else: pg = group + if _pg_map.get(pg, None) is None: raise RuntimeError("Invalid process group specified") @@ -1257,23 +1261,19 @@ def new_group(ranks=None, timeout=_default_pg_timeout, backend=None): _check_default_pg() global _pg_group_ranks - global _group_count - global _pg_names - - group_name = str(_group_count) - _group_count += 1 - - if group_name in _pg_names.values(): - raise RuntimeError("The specified group name has already been " - "created, please use a different group name") - default_backend, _ = _pg_map[_default_pg] + default_backend, default_store = _pg_map[_default_pg] global_rank = _default_pg.rank() global_world_size = _default_pg.size() + # Default to the same backend as the global process group + # if the backend is not specified. 
+ if not backend: + backend = default_backend + # checks the input ranks if ranks is not None: - input_ranks = list(ranks) + ranks = sorted(ranks) group_world_size = len(ranks) if group_world_size > global_world_size: raise RuntimeError("the new group's world size should be less or " @@ -1289,41 +1289,22 @@ def new_group(ranks=None, timeout=_default_pg_timeout, backend=None): else: group_rank = None else: - input_ranks = [] ranks = list(range(global_world_size)) group_world_size = global_world_size group_rank = global_rank - if default_backend == Backend.MPI: - in_group = global_rank in ranks - pg = _new_process_group_helper(group_world_size, - group_rank, - input_ranks, - in_group, - group_name, - timeout=timeout) - else: - # Release ranks not in the group - if global_rank not in ranks: - return GroupMember.NON_GROUP_MEMBER - - if default_backend != Backend.MPI: - if backend is None: - backend = default_backend - pg = _new_process_group_helper(group_world_size, - group_rank, - input_ranks, - True, - group_name, - timeout=timeout, - backend=backend) + backend = Backend(backend) + pg = _new_process_group_helper(group_world_size, + group_rank, + ranks, + backend, + default_store, + timeout=timeout) # Create the global rank to group rank mapping - _pg_group_ranks[pg] = {} - if default_backend == Backend.MPI: - _pg_group_ranks[pg] = pg.group_ranks() - else: - for rank in range(global_world_size): - if rank in ranks: - _pg_group_ranks[pg][rank] = ranks.index(rank) + _pg_group_ranks[pg] = { + global_rank: group_rank + for group_rank, global_rank in enumerate(ranks) + } + return pg diff --git a/torch/lib/c10d/ProcessGroupMPI.cpp b/torch/lib/c10d/ProcessGroupMPI.cpp index 0c8b3d1..395bbb8 100644 --- a/torch/lib/c10d/ProcessGroupMPI.cpp +++ b/torch/lib/c10d/ProcessGroupMPI.cpp @@ -162,7 +162,6 @@ void ProcessGroupMPI::AsyncWork::populateException() { } // Static global states -int ProcessGroupMPI::numProcessGroups_ = 0; int ProcessGroupMPI::mpiThreadSupport_ = 0; std::mutex ProcessGroupMPI::pgGlobalMutex_; // We only want to initialize once @@ -196,73 +195,52 @@ std::shared_ptr ProcessGroupMPI::createProcessGroupMPI( // Once initialization initMPIOnce(); - std::unique_lock globalLock(pgGlobalMutex_); - + MPI_Comm groupComm = MPI_COMM_WORLD; int rank = -1; int size = -1; - // Update the world size and rank - MPI_CHECK(MPI_Comm_size(MPI_COMM_WORLD, &size)); - MPI_CHECK(MPI_Comm_rank(MPI_COMM_WORLD, &rank)); + { + std::lock_guard globalLock(pgGlobalMutex_); + + // If no ranks are specified, assume we're creating the root group + if (!ranks.empty()) { + MPI_Group worldGroup; + MPI_Group ranksGroup; + MPI_CHECK(MPI_Comm_group(MPI_COMM_WORLD, &worldGroup)); + MPI_CHECK( + MPI_Group_incl(worldGroup, ranks.size(), ranks.data(), &ranksGroup)); + MPI_CHECK(MPI_Comm_create(MPI_COMM_WORLD, ranksGroup, &groupComm)); + MPI_CHECK(MPI_Group_free(&worldGroup)); + MPI_CHECK(MPI_Group_free(&ranksGroup)); + } - if (rank < 0 || size < 0) { - throw std::runtime_error("Failed to get the world_size / rank"); - } + // Fetch rank and world size for this group (MPI_COMM_WORLD or new) + if (groupComm != MPI_COMM_NULL) { + MPI_CHECK(MPI_Comm_rank(groupComm, &rank)); + MPI_CHECK(MPI_Comm_size(groupComm, &size)); - // If no ranks are specified, assume we're creating the root group - if (ranks.empty()) { - globalLock.unlock(); - return std::make_shared(rank, size, MPI_COMM_WORLD); + if (rank < 0 || size < 0) { + throw std::runtime_error("Failed to get the world_size / rank"); + } + } } - MPI_Group worldGroup; - 
MPI_CHECK(MPI_Comm_group(MPI_COMM_WORLD, &worldGroup)); - - MPI_Group ranksGroup; - MPI_CHECK( - MPI_Group_incl(worldGroup, ranks.size(), ranks.data(), &ranksGroup)); - - MPI_Comm groupComm; - MPI_CHECK(MPI_Comm_create(MPI_COMM_WORLD, ranksGroup, &groupComm)); - - MPI_CHECK(MPI_Group_free(&worldGroup)); - MPI_CHECK(MPI_Group_free(&ranksGroup)); + // If this process is not part of the group, we don't construct a + // process group instance. This is in line with the semantics of the + // other process group types. + if (groupComm == MPI_COMM_NULL) { + return std::shared_ptr(); + } - globalLock.unlock(); return std::make_shared(rank, size, groupComm); } ProcessGroupMPI::ProcessGroupMPI(int rank, int size, MPI_Comm pgComm) - : ProcessGroup(rank, size), - stop_(false), - pgComm_(pgComm), - groupRank_(-1), - groupSize_(-1) { - std::unique_lock globalLock(pgGlobalMutex_); - - if (pgComm_ != MPI_COMM_NULL) { - MPI_CHECK(MPI_Comm_rank(pgComm_, &groupRank_)); - MPI_CHECK(MPI_Comm_size(pgComm_, &groupSize_)); - std::vector rankToGroupRank{rank_, groupRank_}; - std::vector allRankToGroupRank; - allRankToGroupRank.resize(2 * groupSize_); - MPI_CHECK(MPI_Allgather( - rankToGroupRank.data(), - 2, - MPI_INT, - allRankToGroupRank.data(), - 2, - MPI_INT, - pgComm_)); - for (size_t i = 0; i < allRankToGroupRank.size(); i += 2) { - groupRankMap_[allRankToGroupRank[i]] = allRankToGroupRank[i + 1]; - } + : ProcessGroup(rank, size), stop_(false), pgComm_(pgComm) { + if (pgComm_ == MPI_COMM_NULL) { + throw std::runtime_error("pgComm_ must not be MPI_COMM_NULL"); } - // increase the total PG count - ++numProcessGroups_; - globalLock.unlock(); - // Start the worker thread accepting MPI calls workerThread_ = std::thread(&ProcessGroupMPI::runLoop, this); } @@ -284,10 +262,6 @@ void ProcessGroupMPI::destroy() { // Join the single worker thread workerThread_.join(); - - // Decrease the number of PG created - std::unique_lock globalLock(pgGlobalMutex_); - --numProcessGroups_; } void ProcessGroupMPI::abort() { @@ -338,9 +312,6 @@ std::shared_ptr ProcessGroupMPI::enqueue( std::shared_ptr ProcessGroupMPI::broadcast( std::vector& tensors, const BroadcastOptions& opts) { - if (pgComm_ == MPI_COMM_NULL) { - return nullptr; - } checkSingleTensor(tensors); std::function&)> runFunc = [opts, this](std::unique_ptr& entry) { @@ -361,9 +332,6 @@ std::shared_ptr ProcessGroupMPI::broadcast( std::shared_ptr ProcessGroupMPI::allreduce( std::vector& tensors, const AllreduceOptions& opts) { - if (pgComm_ == MPI_COMM_NULL) { - return nullptr; - } checkSingleTensor(tensors); std::function&)> runFunc = @@ -386,17 +354,14 @@ std::shared_ptr ProcessGroupMPI::allreduce( std::shared_ptr ProcessGroupMPI::reduce( std::vector& tensors, const ReduceOptions& opts) { - if (pgComm_ == MPI_COMM_NULL) { - return nullptr; - } checkSingleTensor(tensors); std::function&)> runFunc = [opts, this](std::unique_ptr& entry) { auto data = (entry->src)[0]; auto dataPtr = (entry->src)[0].data_ptr(); - void* sendbuf = (groupRank_ == opts.rootRank) ? MPI_IN_PLACE : dataPtr; - void* recvbuf = (groupRank_ == opts.rootRank) ? dataPtr : nullptr; + void* sendbuf = (rank_ == opts.rootRank) ? MPI_IN_PLACE : dataPtr; + void* recvbuf = (rank_ == opts.rootRank) ? 
dataPtr : nullptr; std::unique_lock globalLock(pgGlobalMutex_); MPI_CHECK(MPI_Reduce( @@ -417,16 +382,13 @@ std::shared_ptr ProcessGroupMPI::allgather( std::vector>& outputTensors, std::vector& inputTensors, const AllgatherOptions& opts) { - if (pgComm_ == MPI_COMM_NULL) { - return nullptr; - } checkSingleTensor(inputTensors); if (outputTensors.size() != 1) { throw std::runtime_error( "MPI process group only supports a single " "tensor op"); } - if (static_cast(groupSize_) != outputTensors[0].size()) { + if (static_cast(size_) != outputTensors[0].size()) { throw std::runtime_error( "All gather: number of output tensors should equal " "to the world size"); @@ -463,12 +425,9 @@ std::shared_ptr ProcessGroupMPI::gather( std::vector>& outputTensors, std::vector& inputTensors, const GatherOptions& opts) { - if (pgComm_ == MPI_COMM_NULL) { - return nullptr; - } checkSingleTensor(inputTensors); - if (groupRank_ != opts.rootRank) { + if (rank_ != opts.rootRank) { if (outputTensors.size() > 0) { throw std::runtime_error( "Gather: number of output tensors should be 0 " @@ -478,7 +437,7 @@ std::shared_ptr ProcessGroupMPI::gather( if (outputTensors.size() != 1) { throw std::runtime_error("Gather: multi-GPU collective is not supported"); } - if (static_cast(groupSize_) != outputTensors[0].size()) { + if (static_cast(size_) != outputTensors[0].size()) { throw std::runtime_error( "Gather: number of output tensors should equal " "to the world size"); @@ -492,7 +451,7 @@ std::shared_ptr ProcessGroupMPI::gather( void* recvbuf = nullptr; at::Tensor flatOutputTensor; - if (groupRank_ == opts.rootRank) { + if (rank_ == opts.rootRank) { flatOutputTensor = newLikeFlat(entry->dst); recvbuf = flatOutputTensor.data_ptr(); } @@ -508,7 +467,7 @@ std::shared_ptr ProcessGroupMPI::gather( opts.rootRank, pgComm_)); - if (groupRank_ == opts.rootRank) { + if (rank_ == opts.rootRank) { std::vector& outputDataVec = entry->dst; // copy the flattened output tensors to the outputs for (size_t i = 0; i < outputDataVec.size(); ++i) { @@ -517,7 +476,7 @@ std::shared_ptr ProcessGroupMPI::gather( } }; - if (groupRank_ == opts.rootRank) { + if (rank_ == opts.rootRank) { auto entry = std::unique_ptr( new WorkEntry(&inputTensors, &outputTensors[0], std::move(runFunc))); return enqueue(std::move(entry)); @@ -532,12 +491,9 @@ std::shared_ptr ProcessGroupMPI::scatter( std::vector& outputTensors, std::vector>& inputTensors, const ScatterOptions& opts) { - if (pgComm_ == MPI_COMM_NULL) { - return nullptr; - } checkSingleTensor(outputTensors); - if (groupRank_ != opts.rootRank) { + if (rank_ != opts.rootRank) { if (inputTensors.size() > 0) { throw std::runtime_error( "Scatter: number of input tensors should be 0 " @@ -545,9 +501,10 @@ std::shared_ptr ProcessGroupMPI::scatter( } } else { if (inputTensors.size() != 1) { - throw std::runtime_error("Scatter: multi-GPU collective is not supported"); + throw std::runtime_error( + "Scatter: multi-GPU collective is not supported"); } - if (static_cast(groupSize_) != inputTensors[0].size()) { + if (static_cast(size_) != inputTensors[0].size()) { throw std::runtime_error( "Scatter: number of input tensors should equal " "to the world size"); @@ -561,7 +518,7 @@ std::shared_ptr ProcessGroupMPI::scatter( void* sendbuf = nullptr; at::Tensor flatInputTensor; - if (groupRank_ == opts.rootRank) { + if (rank_ == opts.rootRank) { std::vector& inputDataVec = entry->src; flatInputTensor = newLikeFlat(inputDataVec); sendbuf = flatInputTensor.data_ptr(); @@ -584,7 +541,7 @@ std::shared_ptr 
ProcessGroupMPI::scatter( pgComm_)); }; - if (groupRank_ == opts.rootRank) { + if (rank_ == opts.rootRank) { auto entry = std::unique_ptr( new WorkEntry(&inputTensors[0], &outputTensors, std::move(runFunc))); return enqueue(std::move(entry)); @@ -599,10 +556,6 @@ std::shared_ptr ProcessGroupMPI::send( std::vector& tensors, int dstRank, int tag) { - if (pgComm_ == MPI_COMM_NULL) { - return nullptr; - } - checkSingleTensor(tensors); auto& tensor = tensors[0]; @@ -627,10 +580,6 @@ std::shared_ptr ProcessGroupMPI::recv( std::vector& tensors, int srcRank, int tag) { - if (pgComm_ == MPI_COMM_NULL) { - return nullptr; - } - checkSingleTensor(tensors); auto& tensor = tensors[0]; @@ -654,10 +603,6 @@ std::shared_ptr ProcessGroupMPI::recv( std::shared_ptr ProcessGroupMPI::recvAnysource( std::vector& tensors, int tag) { - if (pgComm_ == MPI_COMM_NULL) { - return nullptr; - } - checkSingleTensor(tensors); auto& tensor = tensors[0]; @@ -680,9 +625,6 @@ std::shared_ptr ProcessGroupMPI::recvAnysource( std::shared_ptr ProcessGroupMPI::barrier( const BarrierOptions& opts) { - if (pgComm_ == MPI_COMM_NULL) { - return nullptr; - } std::function&)> runFunc = [this](std::unique_ptr& entry) { std::unique_lock globalLock(pgGlobalMutex_); @@ -694,7 +636,7 @@ std::shared_ptr ProcessGroupMPI::barrier( } std::unordered_map ProcessGroupMPI::getGroupRank() { - return groupRankMap_; + throw std::runtime_error("ProcessGroupMPI does not support getGroupRank"); } } // namespace c10d diff --git a/torch/lib/c10d/ProcessGroupMPI.hpp b/torch/lib/c10d/ProcessGroupMPI.hpp index 5d7ec5c..b3511dc 100644 --- a/torch/lib/c10d/ProcessGroupMPI.hpp +++ b/torch/lib/c10d/ProcessGroupMPI.hpp @@ -181,13 +181,9 @@ class ProcessGroupMPI : public ProcessGroup { static std::once_flag onceFlagInitMPI; static std::mutex pgGlobalMutex_; - static int numProcessGroups_; static int mpiThreadSupport_; MPI_Comm pgComm_; - int groupRank_; - int groupSize_; - std::unordered_map groupRankMap_; }; } // namespace c10d
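
For reviewers who want to see the user-visible effect of this patch, here is a
minimal usage sketch of the new, backend-uniform new_group() semantics from the
Python side. It is not part of the patch; it assumes an MPI-enabled PyTorch
build launched under something like `mpirun -n 4 python demo.py`, and the
subgroup ranks and print statements are illustrative only.

# Minimal sketch (assumes an MPI-enabled build launched via mpirun).
import torch.distributed as dist


def main():
    # With the MPI backend, rank and world size come from the MPI launcher,
    # so no store/init_method/world_size arguments are needed here.
    dist.init_process_group(backend="mpi")

    # new_group() must be called by every process in the global group,
    # including processes that will not be members of the new group.
    subgroup = dist.new_group(ranks=[0, 1])

    # After this patch the MPI backend behaves like Gloo/NCCL: non-members
    # receive GroupMember.NON_GROUP_MEMBER instead of a process group whose
    # MPI_Comm is MPI_COMM_NULL and whose collectives silently no-op.
    if subgroup == dist.GroupMember.NON_GROUP_MEMBER:
        print("global rank %d is not a member of the subgroup"
              % dist.get_rank())
    else:
        # Within a valid MPI subgroup, rank()/size() now describe the
        # subgroup itself rather than the world group.
        print("global rank %d has subgroup rank %d of %d"
              % (dist.get_rank(),
                 dist.get_rank(subgroup),
                 dist.get_world_size(subgroup)))

    dist.destroy_process_group()


if __name__ == "__main__":
    main()

The same NON_GROUP_MEMBER check works unchanged for the Gloo and NCCL backends,
which is the point of the patch: callers no longer need a backend-specific
check to find out whether an MPI process group they were handed is one they
actually belong to.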