# Cached process groups
# For NCCL and GLOO pg, it is a map from ProcessGroup to (Backend, Store)
-# For MPI pg, it is a map from ProcessGroup to (Backend, Bool), where bool
-# represents if the ProcessGroup objects is part of the group
+# For MPI pg, it is a map from ProcessGroup to (Backend, None)
_pg_map = {}
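+# Illustrative cache entries (a sketch, not exhaustive):
+#   _pg_map[pg] -> (Backend.GLOO, store)   # GLOO/NCCL keep their rendezvous Store
+#   _pg_map[pg] -> (Backend.MPI, None)     # MPI needs no store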
# Process group's names, map from ProcessGroup to str
_pg_names = {}
Helper that checks if the current process's rank is not in a given group
"""
- default_backend, _ = _pg_map[_get_default_group()]
- if default_backend != Backend.MPI:
- return group == GroupMember.NON_GROUP_MEMBER
- else:
- if group == GroupMember.WORLD:
- return False
- else:
- _, in_group = _pg_map[group]
- return not in_group
+ if group == GroupMember.WORLD:
+ return False
+ return group == GroupMember.NON_GROUP_MEMBER
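+# Minimal usage sketch (hypothetical caller): ranks excluded from a subgroup
+# receive GroupMember.NON_GROUP_MEMBER back from new_group(), so:
+#
+#     group = new_group(ranks=[0, 1])
+#     if _rank_not_in_group(group):
+#         return  # this process is not a member of `group`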
def _get_group_rank(group, rank):
on a system that supports MPI. The same applies to NCCL as well.
"""
- global _pg_map
- global _pg_names
+ global _pg_group_ranks
global _backend
global _default_pg
global _default_pg_init_method
backend = Backend(backend)
if backend == Backend.MPI:
- if not is_mpi_available():
- raise RuntimeError("Distributed package doesn't have MPI built in")
-
- _default_pg = ProcessGroupMPI([])
- _pg_map[_default_pg] = (Backend.MPI, True)
- _pg_names[_default_pg] = group_name
+ _default_pg = _new_process_group_helper(
+ -1,
+ -1,
+ [],
+ Backend.MPI,
+ None,
+ group_name=group_name,
+ timeout=timeout)
else:
# backward compatible API
url = init_method
store, rank, world_size = next(rendezvous(url))
store.set_timeout(timeout)
- if backend == Backend.GLOO:
- _default_pg = ProcessGroupGloo(
- store,
- rank,
- world_size,
- timeout=timeout)
- _pg_map[_default_pg] = (Backend.GLOO, store)
- _pg_names[_default_pg] = group_name
- elif backend == Backend.NCCL:
- if not is_nccl_available():
- raise RuntimeError("Distributed package doesn't have NCCL "
- "built in")
- _default_pg = ProcessGroupNCCL(store, rank, world_size)
- _pg_map[_default_pg] = (Backend.NCCL, store)
- _pg_names[_default_pg] = group_name
+ _default_pg = _new_process_group_helper(
+ world_size,
+ rank,
+ [],
+ backend,
+ store,
+ group_name=group_name,
+ timeout=timeout)
+ _pg_group_ranks[_default_pg] = {i: i for i in range(_default_pg.size())}
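+    # For the default group the mapping is the identity, e.g. {0: 0, 1: 1, ...}.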
_backend = _pg_map[_default_pg][0]
_default_pg_init_method = init_method
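+# Example invocation (a sketch; assumes env:// rendezvous with MASTER_ADDR,
+# MASTER_PORT, RANK, and WORLD_SIZE exported for every process):
+#
+#     import torch.distributed as dist
+#     dist.init_process_group(backend="gloo", init_method="env://")
+#     print(dist.get_rank(), dist.get_world_size())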
def _new_process_group_helper(world_size,
rank,
group_ranks,
- in_group,
- group_name,
- timeout=_default_pg_timeout,
- backend=None):
+ backend,
+ store,
+ group_name=None,
+ timeout=_default_pg_timeout):
"""
- Create a new distributed process group. And the new process group can be
- used to perform collective operations.
+ Create a new distributed process group.
+ This function must be called by ALL processes in the global group, even if
+ the calling process is not part of the newly created group. In that case,
+ this function returns GroupMember.NON_GROUP_MEMBER.
+
+ This function is called with ``group_ranks == []`` for the default group.
"""
global _pg_map
global _group_count
raise RuntimeError("Expected timeout argument to be of type"
"datetime.timedelta")
- default_backend, default_store = _pg_map[_default_pg]
- if backend is None:
- backend = default_backend
- else:
- backend = Backend(backend)
+ # The list of group ranks is empty if we're creating the default group.
+ is_default_group = (len(group_ranks) == 0)
+ backend = Backend(backend)
if backend == Backend.MPI:
if not is_mpi_available():
raise RuntimeError("Distributed package doesn't have MPI built in")
- pg = ProcessGroupMPI(group_ranks)
- _pg_map[pg] = (Backend.MPI, in_group)
+ pg = ProcessGroupMPI.create(group_ranks)
+ if not pg:
+ return GroupMember.NON_GROUP_MEMBER
+ _pg_map[pg] = (Backend.MPI, None)
_pg_names[pg] = group_name
else:
- # Create the prefix store
- store = PrefixStore(group_name, default_store)
+ # If this is a subgroup (which means group_ranks is specified),
+ # we check if the current process is a member of the new group.
+ if not is_default_group:
+ global_rank = _default_pg.rank()
+ if global_rank not in group_ranks:
+ return GroupMember.NON_GROUP_MEMBER
+
+ # Use the group name as prefix in the default store, such that
+ # a single store can be reused by multiple groups.
+ prefix_store = PrefixStore(group_name, store)
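+        # Illustrative: two groups named "1" and "2" can both call
+        # prefix_store.set("foo", v) and the keys stay distinct in the
+        # shared underlying store.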
if backend == Backend.GLOO:
pg = ProcessGroupGloo(
- store,
+ prefix_store,
rank,
world_size,
timeout=timeout)
if not is_nccl_available():
raise RuntimeError("Distributed package doesn't have NCCL "
"built in")
- pg = ProcessGroupNCCL(store, rank, world_size, group_name)
+ pg = ProcessGroupNCCL(
+ prefix_store,
+ rank,
+ world_size,
+ group_name)
_pg_map[pg] = (Backend.NCCL, store)
_pg_names[pg] = group_name
else:
raise RuntimeError("Unsupported distributed backend by group")
+
return pg
global _default_pg
global _default_pg_init_method
- default_backend, _ = _pg_map[_get_default_group()]
- if (default_backend != Backend.MPI and
- group == GroupMember.NON_GROUP_MEMBER):
+ if group == GroupMember.NON_GROUP_MEMBER:
return
if group == GroupMember.WORLD:
pg = _default_pg
else:
pg = group
+
if _pg_map.get(pg, None) is None:
raise RuntimeError("Invalid process group specified")
_check_default_pg()
global _pg_group_ranks
- global _group_count
- global _pg_names
-
- group_name = str(_group_count)
- _group_count += 1
-
- if group_name in _pg_names.values():
- raise RuntimeError("The specified group name has already been "
- "created, please use a different group name")
- default_backend, _ = _pg_map[_default_pg]
+ default_backend, default_store = _pg_map[_default_pg]
global_rank = _default_pg.rank()
global_world_size = _default_pg.size()
+ # Default to the same backend as the global process group
+ # if the backend is not specified.
+ if not backend:
+ backend = default_backend
+
# checks the input ranks
if ranks is not None:
- input_ranks = list(ranks)
+ ranks = sorted(ranks)
group_world_size = len(ranks)
if group_world_size > global_world_size:
raise RuntimeError("the new group's world size should be less or "
else:
group_rank = None
else:
- input_ranks = []
ranks = list(range(global_world_size))
group_world_size = global_world_size
group_rank = global_rank
- if default_backend == Backend.MPI:
- in_group = global_rank in ranks
- pg = _new_process_group_helper(group_world_size,
- group_rank,
- input_ranks,
- in_group,
- group_name,
- timeout=timeout)
- else:
- # Release ranks not in the group
- if global_rank not in ranks:
- return GroupMember.NON_GROUP_MEMBER
-
- if default_backend != Backend.MPI:
- if backend is None:
- backend = default_backend
- pg = _new_process_group_helper(group_world_size,
- group_rank,
- input_ranks,
- True,
- group_name,
- timeout=timeout,
- backend=backend)
+ backend = Backend(backend)
+ pg = _new_process_group_helper(group_world_size,
+ group_rank,
+ ranks,
+ backend,
+ default_store,
+ timeout=timeout)
# Create the global rank to group rank mapping
- _pg_group_ranks[pg] = {}
- if default_backend == Backend.MPI:
- _pg_group_ranks[pg] = pg.group_ranks()
- else:
- for rank in range(global_world_size):
- if rank in ranks:
- _pg_group_ranks[pg][rank] = ranks.index(rank)
+ _pg_group_ranks[pg] = {
+ global_rank: group_rank
+ for group_rank, global_rank in enumerate(ranks)
+ }
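+    # E.g. ranks == [2, 5] gives {2: 0, 5: 1}: global rank 2 is group rank 0.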
+
return pg
}
// Static global states
-int ProcessGroupMPI::numProcessGroups_ = 0;
int ProcessGroupMPI::mpiThreadSupport_ = 0;
std::mutex ProcessGroupMPI::pgGlobalMutex_;
// We only want to initialize once
// Perform the one-time MPI initialization
initMPIOnce();
- std::unique_lock<std::mutex> globalLock(pgGlobalMutex_);
-
+ MPI_Comm groupComm = MPI_COMM_WORLD;
int rank = -1;
int size = -1;
- // Update the world size and rank
- MPI_CHECK(MPI_Comm_size(MPI_COMM_WORLD, &size));
- MPI_CHECK(MPI_Comm_rank(MPI_COMM_WORLD, &rank));
+ {
+ std::lock_guard<std::mutex> globalLock(pgGlobalMutex_);
+
+ // If no ranks are specified, assume we're creating the root group
+ if (!ranks.empty()) {
+ MPI_Group worldGroup;
+ MPI_Group ranksGroup;
+ MPI_CHECK(MPI_Comm_group(MPI_COMM_WORLD, &worldGroup));
+ MPI_CHECK(
+ MPI_Group_incl(worldGroup, ranks.size(), ranks.data(), &ranksGroup));
+ MPI_CHECK(MPI_Comm_create(MPI_COMM_WORLD, ranksGroup, &groupComm));
+ MPI_CHECK(MPI_Group_free(&worldGroup));
+ MPI_CHECK(MPI_Group_free(&ranksGroup));
+ }
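+    // Note: MPI_Comm_create is collective over MPI_COMM_WORLD and yields
+    // MPI_COMM_NULL on ranks outside ranksGroup; the checks below rely on
+    // that to skip the rank/size queries and to return a null pointer.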
- if (rank < 0 || size < 0) {
- throw std::runtime_error("Failed to get the world_size / rank");
- }
+ // Fetch rank and world size for this group (MPI_COMM_WORLD or new)
+ if (groupComm != MPI_COMM_NULL) {
+ MPI_CHECK(MPI_Comm_rank(groupComm, &rank));
+ MPI_CHECK(MPI_Comm_size(groupComm, &size));
- // If no ranks are specified, assume we're creating the root group
- if (ranks.empty()) {
- globalLock.unlock();
- return std::make_shared<ProcessGroupMPI>(rank, size, MPI_COMM_WORLD);
+ if (rank < 0 || size < 0) {
+ throw std::runtime_error("Failed to get the world_size / rank");
+ }
+ }
}
- MPI_Group worldGroup;
- MPI_CHECK(MPI_Comm_group(MPI_COMM_WORLD, &worldGroup));
-
- MPI_Group ranksGroup;
- MPI_CHECK(
- MPI_Group_incl(worldGroup, ranks.size(), ranks.data(), &ranksGroup));
-
- MPI_Comm groupComm;
- MPI_CHECK(MPI_Comm_create(MPI_COMM_WORLD, ranksGroup, &groupComm));
-
- MPI_CHECK(MPI_Group_free(&worldGroup));
- MPI_CHECK(MPI_Group_free(&ranksGroup));
+ // If this process is not part of the group, we don't construct a
+ // process group instance. This is in line with the semantics of the
+ // other process group types.
+ if (groupComm == MPI_COMM_NULL) {
+ return std::shared_ptr<ProcessGroupMPI>();
+ }
- globalLock.unlock();
return std::make_shared<ProcessGroupMPI>(rank, size, groupComm);
}
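// Illustrative call pattern (the factory name createProcessGroupMPI is an
// assumption from context; Python reaches it via ProcessGroupMPI.create):
//
//   auto pg = ProcessGroupMPI::createProcessGroupMPI({0, 1});
//   if (pg) {
//     // Only global ranks 0 and 1 receive a non-null group; all other
//     // ranks get an empty shared_ptr and must not issue collectives.
//   }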
ProcessGroupMPI::ProcessGroupMPI(int rank, int size, MPI_Comm pgComm)
- : ProcessGroup(rank, size),
- stop_(false),
- pgComm_(pgComm),
- groupRank_(-1),
- groupSize_(-1) {
- std::unique_lock<std::mutex> globalLock(pgGlobalMutex_);
-
- if (pgComm_ != MPI_COMM_NULL) {
- MPI_CHECK(MPI_Comm_rank(pgComm_, &groupRank_));
- MPI_CHECK(MPI_Comm_size(pgComm_, &groupSize_));
- std::vector<int> rankToGroupRank{rank_, groupRank_};
- std::vector<int> allRankToGroupRank;
- allRankToGroupRank.resize(2 * groupSize_);
- MPI_CHECK(MPI_Allgather(
- rankToGroupRank.data(),
- 2,
- MPI_INT,
- allRankToGroupRank.data(),
- 2,
- MPI_INT,
- pgComm_));
- for (size_t i = 0; i < allRankToGroupRank.size(); i += 2) {
- groupRankMap_[allRankToGroupRank[i]] = allRankToGroupRank[i + 1];
- }
+ : ProcessGroup(rank, size), stop_(false), pgComm_(pgComm) {
+ if (pgComm_ == MPI_COMM_NULL) {
+ throw std::runtime_error("pgComm_ must not be MPI_COMM_NULL");
}
- // increase the total PG count
- ++numProcessGroups_;
- globalLock.unlock();
-
// Start the worker thread accepting MPI calls
workerThread_ = std::thread(&ProcessGroupMPI::runLoop, this);
}
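// Every collective below wraps its MPI call in a WorkEntry and enqueue()s
// it; runLoop() drains the queue on workerThread_, so all MPI calls on
// pgComm_ are serialized on a single thread.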
// Join the single worker thread
workerThread_.join();
-
- // Decrease the number of PG created
- std::unique_lock<std::mutex> globalLock(pgGlobalMutex_);
- --numProcessGroups_;
}
void ProcessGroupMPI::abort() {
std::shared_ptr<ProcessGroup::Work> ProcessGroupMPI::broadcast(
std::vector<at::Tensor>& tensors,
const BroadcastOptions& opts) {
- if (pgComm_ == MPI_COMM_NULL) {
- return nullptr;
- }
checkSingleTensor(tensors);
std::function<void(std::unique_ptr<WorkEntry>&)> runFunc =
[opts, this](std::unique_ptr<WorkEntry>& entry) {
std::shared_ptr<ProcessGroup::Work> ProcessGroupMPI::allreduce(
std::vector<at::Tensor>& tensors,
const AllreduceOptions& opts) {
- if (pgComm_ == MPI_COMM_NULL) {
- return nullptr;
- }
checkSingleTensor(tensors);
std::function<void(std::unique_ptr<WorkEntry>&)> runFunc =
std::shared_ptr<ProcessGroup::Work> ProcessGroupMPI::reduce(
std::vector<at::Tensor>& tensors,
const ReduceOptions& opts) {
- if (pgComm_ == MPI_COMM_NULL) {
- return nullptr;
- }
checkSingleTensor(tensors);
std::function<void(std::unique_ptr<WorkEntry>&)> runFunc =
[opts, this](std::unique_ptr<WorkEntry>& entry) {
auto data = (entry->src)[0];
auto dataPtr = (entry->src)[0].data_ptr();
- void* sendbuf = (groupRank_ == opts.rootRank) ? MPI_IN_PLACE : dataPtr;
- void* recvbuf = (groupRank_ == opts.rootRank) ? dataPtr : nullptr;
+ void* sendbuf = (rank_ == opts.rootRank) ? MPI_IN_PLACE : dataPtr;
+ void* recvbuf = (rank_ == opts.rootRank) ? dataPtr : nullptr;
std::unique_lock<std::mutex> globalLock(pgGlobalMutex_);
MPI_CHECK(MPI_Reduce(
std::vector<std::vector<at::Tensor>>& outputTensors,
std::vector<at::Tensor>& inputTensors,
const AllgatherOptions& opts) {
- if (pgComm_ == MPI_COMM_NULL) {
- return nullptr;
- }
checkSingleTensor(inputTensors);
if (outputTensors.size() != 1) {
throw std::runtime_error(
"MPI process group only supports a single "
"tensor op");
}
- if (static_cast<size_t>(groupSize_) != outputTensors[0].size()) {
+ if (static_cast<size_t>(size_) != outputTensors[0].size()) {
throw std::runtime_error(
"All gather: number of output tensors should equal "
"to the world size");
std::vector<std::vector<at::Tensor>>& outputTensors,
std::vector<at::Tensor>& inputTensors,
const GatherOptions& opts) {
- if (pgComm_ == MPI_COMM_NULL) {
- return nullptr;
- }
checkSingleTensor(inputTensors);
- if (groupRank_ != opts.rootRank) {
+ if (rank_ != opts.rootRank) {
if (outputTensors.size() > 0) {
throw std::runtime_error(
"Gather: number of output tensors should be 0 "
if (outputTensors.size() != 1) {
throw std::runtime_error("Gather: multi-GPU collective is not supported");
}
- if (static_cast<size_t>(groupSize_) != outputTensors[0].size()) {
+ if (static_cast<size_t>(size_) != outputTensors[0].size()) {
throw std::runtime_error(
"Gather: number of output tensors should equal "
"to the world size");
void* recvbuf = nullptr;
at::Tensor flatOutputTensor;
- if (groupRank_ == opts.rootRank) {
+ if (rank_ == opts.rootRank) {
flatOutputTensor = newLikeFlat(entry->dst);
recvbuf = flatOutputTensor.data_ptr();
}
opts.rootRank,
pgComm_));
- if (groupRank_ == opts.rootRank) {
+ if (rank_ == opts.rootRank) {
std::vector<at::Tensor>& outputDataVec = entry->dst;
// copy the flattened output tensors to the outputs
for (size_t i = 0; i < outputDataVec.size(); ++i) {
}
};
- if (groupRank_ == opts.rootRank) {
+ if (rank_ == opts.rootRank) {
auto entry = std::unique_ptr<WorkEntry>(
new WorkEntry(&inputTensors, &outputTensors[0], std::move(runFunc)));
return enqueue(std::move(entry));
std::vector<at::Tensor>& outputTensors,
std::vector<std::vector<at::Tensor>>& inputTensors,
const ScatterOptions& opts) {
- if (pgComm_ == MPI_COMM_NULL) {
- return nullptr;
- }
checkSingleTensor(outputTensors);
- if (groupRank_ != opts.rootRank) {
+ if (rank_ != opts.rootRank) {
if (inputTensors.size() > 0) {
throw std::runtime_error(
"Scatter: number of input tensors should be 0 "
}
} else {
if (inputTensors.size() != 1) {
- throw std::runtime_error("Scatter: multi-GPU collective is not supported");
+ throw std::runtime_error(
+ "Scatter: multi-GPU collective is not supported");
}
- if (static_cast<size_t>(groupSize_) != inputTensors[0].size()) {
+ if (static_cast<size_t>(size_) != inputTensors[0].size()) {
throw std::runtime_error(
"Scatter: number of input tensors should equal "
"to the world size");
void* sendbuf = nullptr;
at::Tensor flatInputTensor;
- if (groupRank_ == opts.rootRank) {
+ if (rank_ == opts.rootRank) {
std::vector<at::Tensor>& inputDataVec = entry->src;
flatInputTensor = newLikeFlat(inputDataVec);
sendbuf = flatInputTensor.data_ptr();
pgComm_));
};
- if (groupRank_ == opts.rootRank) {
+ if (rank_ == opts.rootRank) {
auto entry = std::unique_ptr<WorkEntry>(
new WorkEntry(&inputTensors[0], &outputTensors, std::move(runFunc)));
return enqueue(std::move(entry));
std::vector<at::Tensor>& tensors,
int dstRank,
int tag) {
- if (pgComm_ == MPI_COMM_NULL) {
- return nullptr;
- }
-
checkSingleTensor(tensors);
auto& tensor = tensors[0];
std::vector<at::Tensor>& tensors,
int srcRank,
int tag) {
- if (pgComm_ == MPI_COMM_NULL) {
- return nullptr;
- }
-
checkSingleTensor(tensors);
auto& tensor = tensors[0];
std::shared_ptr<ProcessGroup::Work> ProcessGroupMPI::recvAnysource(
std::vector<at::Tensor>& tensors,
int tag) {
- if (pgComm_ == MPI_COMM_NULL) {
- return nullptr;
- }
-
checkSingleTensor(tensors);
auto& tensor = tensors[0];
std::shared_ptr<ProcessGroup::Work> ProcessGroupMPI::barrier(
const BarrierOptions& opts) {
- if (pgComm_ == MPI_COMM_NULL) {
- return nullptr;
- }
std::function<void(std::unique_ptr<WorkEntry>&)> runFunc =
[this](std::unique_ptr<WorkEntry>& entry) {
std::unique_lock<std::mutex> globalLock(pgGlobalMutex_);
}
std::unordered_map<int, int> ProcessGroupMPI::getGroupRank() {
- return groupRankMap_;
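+  // The global-rank -> group-rank mapping is now maintained on the Python
+  // side (see _pg_group_ranks above), so this accessor is unsupported.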
+ throw std::runtime_error("ProcessGroupMPI does not support getGroupRank");
}
} // namespace c10d