ProcessGroupMPI exists only if it is valid (#14809)
author     Pieter Noordhuis <pietern@fb.com>
           Thu, 11 Apr 2019 04:27:51 +0000 (21:27 -0700)
committer  Facebook Github Bot <facebook-github-bot@users.noreply.github.com>
           Thu, 11 Apr 2019 04:36:35 +0000 (21:36 -0700)
Summary:
Previously, MPI process groups were created for all processes, even if
they were not part of the created group. Their MPI_Comm member field
would be MPI_COMM_NULL and they would ignore any calls. Their rank and
size were identical to those of the global process group, and they had
special groupRank and groupSize fields to capture the _real_ rank and
size within the group.

This also meant asymmetry with other process group types, where creating
a new group would either return the process group OR
GroupMember.NON_GROUP_MEMBER. For the MPI backend, it would always
return a process group, and an additional check was needed to verify
whether a process was actually part of the group.

This commit changes this so that every MPI process group is a valid
process group, and by extension we no longer have to special-case MPI
to determine whether a process is part of a group. Now, if `new_group`
returns GroupMember.NON_GROUP_MEMBER, the process is not a member;
otherwise it is.
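
For illustration only (not part of the patch): a minimal sketch of the
new calling convention, assuming the default group was already
initialized with `init_process_group` and that `GroupMember` is imported
from `torch.distributed.distributed_c10d`:

    import torch.distributed as dist
    from torch.distributed.distributed_c10d import GroupMember

    # All ranks in the global group must call new_group, even ranks that
    # are not part of the group being created.
    group = dist.new_group(ranks=[0, 1])

    if group == GroupMember.NON_GROUP_MEMBER:
        # This rank is not in the new group; the same check now applies
        # to the MPI backend as well as Gloo and NCCL.
        pass
    else:
        dist.barrier(group=group)
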
Pull Request resolved: https://github.com/pytorch/pytorch/pull/14809

Differential Revision: D14887937

Pulled By: pietern

fbshipit-source-id: c5bf86d3b33e524cc5004ee68e30103178fa491d

torch/csrc/distributed/c10d/init.cpp
torch/distributed/distributed_c10d.py
torch/lib/c10d/ProcessGroupMPI.cpp
torch/lib/c10d/ProcessGroupMPI.hpp

torch/csrc/distributed/c10d/init.cpp
index ff36947..a0c5ab2 100644
@@ -435,11 +435,17 @@ They are used in specifying strategies for reduction collectives, e.g.,
 #endif
 
 #ifdef USE_C10D_MPI
-  shared_ptr_class_<::c10d::ProcessGroupMPI>(
-      module, "ProcessGroupMPI", processGroup)
-      .def(py::init([](std::vector<int> ranks) {
+  auto processGroupMPI = shared_ptr_class_<::c10d::ProcessGroupMPI>(
+      module, "ProcessGroupMPI", processGroup);
+
+  // Define static create function instead of a constructor, because
+  // this function may return null. This happens if this process is not
+  // part of a sub group that is to be created.
+  processGroupMPI.def_static(
+      "create",
+      [](std::vector<int> ranks) {
         return ::c10d::ProcessGroupMPI::createProcessGroupMPI(ranks);
-      }));
+      });
 #endif
 
   shared_ptr_class_<::c10d::ProcessGroup::Work>(module, "Work")
torch/distributed/distributed_c10d.py
index 6ff0e66..d08e1f9 100644
@@ -110,8 +110,7 @@ class GroupMember(object):
 
 # Cached process groups
 # For NCCL and GLOO pg, it is a map from ProcessGroup to (Backend, Store)
-# For MPI pg, it is a map from ProcessGroup to (Backend, Bool), where bool
-# represents if the ProcessGroup objects is part of the group
+# For MPI pg, it is a map from ProcessGroup to (Backend, None)
 _pg_map = {}
 # Process group's names, map from ProcessGroup to str
 _pg_names = {}
@@ -137,15 +136,9 @@ def _rank_not_in_group(group):
     Helper that checks if the current process's rank is not in a given group
 
     """
-    default_backend, _ = _pg_map[_get_default_group()]
-    if default_backend != Backend.MPI:
-        return group == GroupMember.NON_GROUP_MEMBER
-    else:
-        if group == GroupMember.WORLD:
-            return False
-        else:
-            _, in_group = _pg_map[group]
-            return not in_group
+    if group == GroupMember.WORLD:
+        return False
+    return group == GroupMember.NON_GROUP_MEMBER
 
 
 def _get_group_rank(group, rank):
@@ -345,8 +338,7 @@ def init_process_group(backend,
     on a system that supports MPI. The same applies to NCCL as well.
 
     """
-    global _pg_map
-    global _pg_names
+    global _pg_group_ranks
     global _backend
     global _default_pg
     global _default_pg_init_method
@@ -372,12 +364,14 @@ def init_process_group(backend,
     backend = Backend(backend)
 
     if backend == Backend.MPI:
-        if not is_mpi_available():
-            raise RuntimeError("Distributed package doesn't have MPI built in")
-
-        _default_pg = ProcessGroupMPI([])
-        _pg_map[_default_pg] = (Backend.MPI, True)
-        _pg_names[_default_pg] = group_name
+        _default_pg = _new_process_group_helper(
+            -1,
+            -1,
+            [],
+            Backend.MPI,
+            None,
+            group_name=group_name,
+            timeout=timeout)
     else:
         # backward compatible API
         url = init_method
@@ -392,22 +386,16 @@ def init_process_group(backend,
             store, rank, world_size = next(rendezvous(url))
             store.set_timeout(timeout)
 
-        if backend == Backend.GLOO:
-            _default_pg = ProcessGroupGloo(
-                store,
-                rank,
-                world_size,
-                timeout=timeout)
-            _pg_map[_default_pg] = (Backend.GLOO, store)
-            _pg_names[_default_pg] = group_name
-        elif backend == Backend.NCCL:
-            if not is_nccl_available():
-                raise RuntimeError("Distributed package doesn't have NCCL "
-                                   "built in")
-            _default_pg = ProcessGroupNCCL(store, rank, world_size)
-            _pg_map[_default_pg] = (Backend.NCCL, store)
-            _pg_names[_default_pg] = group_name
+        _default_pg = _new_process_group_helper(
+            world_size,
+            rank,
+            [],
+            backend,
+            store,
+            group_name=group_name,
+            timeout=timeout)
 
+    _pg_group_ranks[_default_pg] = {i: i for i in range(_default_pg.size())}
     _backend = _pg_map[_default_pg][0]
     _default_pg_init_method = init_method
 
@@ -415,14 +403,18 @@ def init_process_group(backend,
 def _new_process_group_helper(world_size,
                               rank,
                               group_ranks,
-                              in_group,
-                              group_name,
-                              timeout=_default_pg_timeout,
-                              backend=None):
+                              backend,
+                              store,
+                              group_name=None,
+                              timeout=_default_pg_timeout):
     """
-    Create a new distributed process group. And the new process group can be
-    used to perform collective operations.
+    Create a new distributed process group.
 
+    This function must be called by ALL processes in the global group, even if
+    the calling process is not part of the newly created group. In that case,
+    this function returns GroupMember.NON_GROUP_MEMBER.
+
+    This function is called with ``group_ranks == []`` for the default group.
     """
     global _pg_map
     global _group_count
@@ -440,25 +432,33 @@ def _new_process_group_helper(world_size,
         raise RuntimeError("Expected timeout argument to be of type"
                            "datetime.timedelta")
 
-    default_backend, default_store = _pg_map[_default_pg]
-    if backend is None:
-        backend = default_backend
-    else:
-        backend = Backend(backend)
+    # The list of group ranks is empty if we're creating the default group.
+    is_default_group = (len(group_ranks) == 0)
 
+    backend = Backend(backend)
     if backend == Backend.MPI:
         if not is_mpi_available():
             raise RuntimeError("Distributed package doesn't have MPI built in")
-        pg = ProcessGroupMPI(group_ranks)
-        _pg_map[pg] = (Backend.MPI, in_group)
+        pg = ProcessGroupMPI.create(group_ranks)
+        if not pg:
+            return GroupMember.NON_GROUP_MEMBER
+        _pg_map[pg] = (Backend.MPI, None)
         _pg_names[pg] = group_name
     else:
-        # Create the prefix store
-        store = PrefixStore(group_name, default_store)
+        # If this is a subgroup (which means group_ranks is specified),
+        # we check if the current process is a member of the new group.
+        if not is_default_group:
+            global_rank = _default_pg.rank()
+            if global_rank not in group_ranks:
+                return GroupMember.NON_GROUP_MEMBER
+
+        # Use the group name as prefix in the default store, such that
+        # a single store can be reused by multiple groups.
+        prefix_store = PrefixStore(group_name, store)
 
         if backend == Backend.GLOO:
             pg = ProcessGroupGloo(
-                store,
+                prefix_store,
                 rank,
                 world_size,
                 timeout=timeout)
@@ -468,11 +468,16 @@ def _new_process_group_helper(world_size,
             if not is_nccl_available():
                 raise RuntimeError("Distributed package doesn't have NCCL "
                                    "built in")
-            pg = ProcessGroupNCCL(store, rank, world_size, group_name)
+            pg = ProcessGroupNCCL(
+                prefix_store,
+                rank,
+                world_size,
+                group_name)
             _pg_map[pg] = (Backend.NCCL, store)
             _pg_names[pg] = group_name
         else:
             raise RuntimeError("Unsupported distributed backend by group")
+
     return pg
 
 
@@ -492,15 +497,14 @@ def destroy_process_group(group=group.WORLD):
     global _default_pg
     global _default_pg_init_method
 
-    default_backend, _ = _pg_map[_get_default_group()]
-    if (default_backend != Backend.MPI and
-            group == GroupMember.NON_GROUP_MEMBER):
+    if group == GroupMember.NON_GROUP_MEMBER:
         return
 
     if group == GroupMember.WORLD:
         pg = _default_pg
     else:
         pg = group
+
     if _pg_map.get(pg, None) is None:
         raise RuntimeError("Invalid process group specified")
 
@@ -1257,23 +1261,19 @@ def new_group(ranks=None, timeout=_default_pg_timeout, backend=None):
     _check_default_pg()
 
     global _pg_group_ranks
-    global _group_count
-    global _pg_names
-
-    group_name = str(_group_count)
-    _group_count += 1
-
-    if group_name in _pg_names.values():
-        raise RuntimeError("The specified group name has already been "
-                           "created, please use a different group name")
 
-    default_backend, _ = _pg_map[_default_pg]
+    default_backend, default_store = _pg_map[_default_pg]
     global_rank = _default_pg.rank()
     global_world_size = _default_pg.size()
 
+    # Default to the same backend as the global process group
+    # if the backend is not specified.
+    if not backend:
+        backend = default_backend
+
     # checks the input ranks
     if ranks is not None:
-        input_ranks = list(ranks)
+        ranks = sorted(ranks)
         group_world_size = len(ranks)
         if group_world_size > global_world_size:
             raise RuntimeError("the new group's world size should be less or "
@@ -1289,41 +1289,22 @@ def new_group(ranks=None, timeout=_default_pg_timeout, backend=None):
         else:
             group_rank = None
     else:
-        input_ranks = []
         ranks = list(range(global_world_size))
         group_world_size = global_world_size
         group_rank = global_rank
 
-    if default_backend == Backend.MPI:
-        in_group = global_rank in ranks
-        pg = _new_process_group_helper(group_world_size,
-                                       group_rank,
-                                       input_ranks,
-                                       in_group,
-                                       group_name,
-                                       timeout=timeout)
-    else:
-        # Release ranks not in the group
-        if global_rank not in ranks:
-            return GroupMember.NON_GROUP_MEMBER
-
-        if default_backend != Backend.MPI:
-            if backend is None:
-                backend = default_backend
-            pg = _new_process_group_helper(group_world_size,
-                                           group_rank,
-                                           input_ranks,
-                                           True,
-                                           group_name,
-                                           timeout=timeout,
-                                           backend=backend)
+    backend = Backend(backend)
+    pg = _new_process_group_helper(group_world_size,
+                                   group_rank,
+                                   ranks,
+                                   backend,
+                                   default_store,
+                                   timeout=timeout)
 
     # Create the global rank to group rank mapping
-    _pg_group_ranks[pg] = {}
-    if default_backend == Backend.MPI:
-        _pg_group_ranks[pg] = pg.group_ranks()
-    else:
-        for rank in range(global_world_size):
-            if rank in ranks:
-                _pg_group_ranks[pg][rank] = ranks.index(rank)
+    _pg_group_ranks[pg] = {
+        global_rank: group_rank
+        for group_rank, global_rank in enumerate(ranks)
+    }
+
     return pg
torch/lib/c10d/ProcessGroupMPI.cpp
index 0c8b3d1..395bbb8 100644
@@ -162,7 +162,6 @@ void ProcessGroupMPI::AsyncWork::populateException() {
 }
 
 // Static global states
-int ProcessGroupMPI::numProcessGroups_ = 0;
 int ProcessGroupMPI::mpiThreadSupport_ = 0;
 std::mutex ProcessGroupMPI::pgGlobalMutex_;
 // We only want to initialize once
@@ -196,73 +195,52 @@ std::shared_ptr<ProcessGroupMPI> ProcessGroupMPI::createProcessGroupMPI(
   // Once initialization
   initMPIOnce();
 
-  std::unique_lock<std::mutex> globalLock(pgGlobalMutex_);
-
+  MPI_Comm groupComm = MPI_COMM_WORLD;
   int rank = -1;
   int size = -1;
 
-  // Update the world size and rank
-  MPI_CHECK(MPI_Comm_size(MPI_COMM_WORLD, &size));
-  MPI_CHECK(MPI_Comm_rank(MPI_COMM_WORLD, &rank));
+  {
+    std::lock_guard<std::mutex> globalLock(pgGlobalMutex_);
+
+    // If no ranks are specified, assume we're creating the root group
+    if (!ranks.empty()) {
+      MPI_Group worldGroup;
+      MPI_Group ranksGroup;
+      MPI_CHECK(MPI_Comm_group(MPI_COMM_WORLD, &worldGroup));
+      MPI_CHECK(
+          MPI_Group_incl(worldGroup, ranks.size(), ranks.data(), &ranksGroup));
+      MPI_CHECK(MPI_Comm_create(MPI_COMM_WORLD, ranksGroup, &groupComm));
+      MPI_CHECK(MPI_Group_free(&worldGroup));
+      MPI_CHECK(MPI_Group_free(&ranksGroup));
+    }
 
-  if (rank < 0 || size < 0) {
-    throw std::runtime_error("Failed to get the world_size / rank");
-  }
+    // Fetch rank and world size for this group (MPI_COMM_WORLD or new)
+    if (groupComm != MPI_COMM_NULL) {
+      MPI_CHECK(MPI_Comm_rank(groupComm, &rank));
+      MPI_CHECK(MPI_Comm_size(groupComm, &size));
 
-  // If no ranks are specified, assume we're creating the root group
-  if (ranks.empty()) {
-    globalLock.unlock();
-    return std::make_shared<ProcessGroupMPI>(rank, size, MPI_COMM_WORLD);
+      if (rank < 0 || size < 0) {
+        throw std::runtime_error("Failed to get the world_size / rank");
+      }
+    }
   }
 
-  MPI_Group worldGroup;
-  MPI_CHECK(MPI_Comm_group(MPI_COMM_WORLD, &worldGroup));
-
-  MPI_Group ranksGroup;
-  MPI_CHECK(
-      MPI_Group_incl(worldGroup, ranks.size(), ranks.data(), &ranksGroup));
-
-  MPI_Comm groupComm;
-  MPI_CHECK(MPI_Comm_create(MPI_COMM_WORLD, ranksGroup, &groupComm));
-
-  MPI_CHECK(MPI_Group_free(&worldGroup));
-  MPI_CHECK(MPI_Group_free(&ranksGroup));
+  // If this process is not part of the group, we don't construct a
+  // process group instance. This is in line with the semantics of the
+  // other process group types.
+  if (groupComm == MPI_COMM_NULL) {
+    return std::shared_ptr<ProcessGroupMPI>();
+  }
 
-  globalLock.unlock();
   return std::make_shared<ProcessGroupMPI>(rank, size, groupComm);
 }
 
 ProcessGroupMPI::ProcessGroupMPI(int rank, int size, MPI_Comm pgComm)
-    : ProcessGroup(rank, size),
-      stop_(false),
-      pgComm_(pgComm),
-      groupRank_(-1),
-      groupSize_(-1) {
-  std::unique_lock<std::mutex> globalLock(pgGlobalMutex_);
-
-  if (pgComm_ != MPI_COMM_NULL) {
-    MPI_CHECK(MPI_Comm_rank(pgComm_, &groupRank_));
-    MPI_CHECK(MPI_Comm_size(pgComm_, &groupSize_));
-    std::vector<int> rankToGroupRank{rank_, groupRank_};
-    std::vector<int> allRankToGroupRank;
-    allRankToGroupRank.resize(2 * groupSize_);
-    MPI_CHECK(MPI_Allgather(
-        rankToGroupRank.data(),
-        2,
-        MPI_INT,
-        allRankToGroupRank.data(),
-        2,
-        MPI_INT,
-        pgComm_));
-    for (size_t i = 0; i < allRankToGroupRank.size(); i += 2) {
-      groupRankMap_[allRankToGroupRank[i]] = allRankToGroupRank[i + 1];
-    }
+    : ProcessGroup(rank, size), stop_(false), pgComm_(pgComm) {
+  if (pgComm_ == MPI_COMM_NULL) {
+    throw std::runtime_error("pgComm_ must not be MPI_COMM_NULL");
   }
 
-  // increase the total PG count
-  ++numProcessGroups_;
-  globalLock.unlock();
-
   // Start the worker thread accepting MPI calls
   workerThread_ = std::thread(&ProcessGroupMPI::runLoop, this);
 }
@@ -284,10 +262,6 @@ void ProcessGroupMPI::destroy() {
 
   // Join the single worker thread
   workerThread_.join();
-
-  // Decrease the number of PG created
-  std::unique_lock<std::mutex> globalLock(pgGlobalMutex_);
-  --numProcessGroups_;
 }
 
 void ProcessGroupMPI::abort() {
@@ -338,9 +312,6 @@ std::shared_ptr<ProcessGroup::Work> ProcessGroupMPI::enqueue(
 std::shared_ptr<ProcessGroup::Work> ProcessGroupMPI::broadcast(
     std::vector<at::Tensor>& tensors,
     const BroadcastOptions& opts) {
-  if (pgComm_ == MPI_COMM_NULL) {
-    return nullptr;
-  }
   checkSingleTensor(tensors);
   std::function<void(std::unique_ptr<WorkEntry>&)> runFunc =
       [opts, this](std::unique_ptr<WorkEntry>& entry) {
@@ -361,9 +332,6 @@ std::shared_ptr<ProcessGroup::Work> ProcessGroupMPI::broadcast(
 std::shared_ptr<ProcessGroup::Work> ProcessGroupMPI::allreduce(
     std::vector<at::Tensor>& tensors,
     const AllreduceOptions& opts) {
-  if (pgComm_ == MPI_COMM_NULL) {
-    return nullptr;
-  }
   checkSingleTensor(tensors);
 
   std::function<void(std::unique_ptr<WorkEntry>&)> runFunc =
@@ -386,17 +354,14 @@ std::shared_ptr<ProcessGroup::Work> ProcessGroupMPI::allreduce(
 std::shared_ptr<ProcessGroup::Work> ProcessGroupMPI::reduce(
     std::vector<at::Tensor>& tensors,
     const ReduceOptions& opts) {
-  if (pgComm_ == MPI_COMM_NULL) {
-    return nullptr;
-  }
   checkSingleTensor(tensors);
 
   std::function<void(std::unique_ptr<WorkEntry>&)> runFunc =
       [opts, this](std::unique_ptr<WorkEntry>& entry) {
         auto data = (entry->src)[0];
         auto dataPtr = (entry->src)[0].data_ptr();
-        void* sendbuf = (groupRank_ == opts.rootRank) ? MPI_IN_PLACE : dataPtr;
-        void* recvbuf = (groupRank_ == opts.rootRank) ? dataPtr : nullptr;
+        void* sendbuf = (rank_ == opts.rootRank) ? MPI_IN_PLACE : dataPtr;
+        void* recvbuf = (rank_ == opts.rootRank) ? dataPtr : nullptr;
 
         std::unique_lock<std::mutex> globalLock(pgGlobalMutex_);
         MPI_CHECK(MPI_Reduce(
@@ -417,16 +382,13 @@ std::shared_ptr<ProcessGroup::Work> ProcessGroupMPI::allgather(
     std::vector<std::vector<at::Tensor>>& outputTensors,
     std::vector<at::Tensor>& inputTensors,
     const AllgatherOptions& opts) {
-  if (pgComm_ == MPI_COMM_NULL) {
-    return nullptr;
-  }
   checkSingleTensor(inputTensors);
   if (outputTensors.size() != 1) {
     throw std::runtime_error(
         "MPI process group only supports a single "
         "tensor op");
   }
-  if (static_cast<size_t>(groupSize_) != outputTensors[0].size()) {
+  if (static_cast<size_t>(size_) != outputTensors[0].size()) {
     throw std::runtime_error(
         "All gather: number of output tensors should equal "
         "to the world size");
@@ -463,12 +425,9 @@ std::shared_ptr<ProcessGroup::Work> ProcessGroupMPI::gather(
     std::vector<std::vector<at::Tensor>>& outputTensors,
     std::vector<at::Tensor>& inputTensors,
     const GatherOptions& opts) {
-  if (pgComm_ == MPI_COMM_NULL) {
-    return nullptr;
-  }
   checkSingleTensor(inputTensors);
 
-  if (groupRank_ != opts.rootRank) {
+  if (rank_ != opts.rootRank) {
     if (outputTensors.size() > 0) {
       throw std::runtime_error(
           "Gather: number of output tensors should be 0 "
@@ -478,7 +437,7 @@ std::shared_ptr<ProcessGroup::Work> ProcessGroupMPI::gather(
     if (outputTensors.size() != 1) {
       throw std::runtime_error("Gather: multi-GPU collective is not supported");
     }
-    if (static_cast<size_t>(groupSize_) != outputTensors[0].size()) {
+    if (static_cast<size_t>(size_) != outputTensors[0].size()) {
       throw std::runtime_error(
           "Gather: number of output tensors should equal "
           "to the world size");
@@ -492,7 +451,7 @@ std::shared_ptr<ProcessGroup::Work> ProcessGroupMPI::gather(
         void* recvbuf = nullptr;
         at::Tensor flatOutputTensor;
 
-        if (groupRank_ == opts.rootRank) {
+        if (rank_ == opts.rootRank) {
           flatOutputTensor = newLikeFlat(entry->dst);
           recvbuf = flatOutputTensor.data_ptr();
         }
@@ -508,7 +467,7 @@ std::shared_ptr<ProcessGroup::Work> ProcessGroupMPI::gather(
             opts.rootRank,
             pgComm_));
 
-        if (groupRank_ == opts.rootRank) {
+        if (rank_ == opts.rootRank) {
           std::vector<at::Tensor>& outputDataVec = entry->dst;
           // copy the flattened output tensors to the outputs
           for (size_t i = 0; i < outputDataVec.size(); ++i) {
@@ -517,7 +476,7 @@ std::shared_ptr<ProcessGroup::Work> ProcessGroupMPI::gather(
         }
       };
 
-  if (groupRank_ == opts.rootRank) {
+  if (rank_ == opts.rootRank) {
     auto entry = std::unique_ptr<WorkEntry>(
         new WorkEntry(&inputTensors, &outputTensors[0], std::move(runFunc)));
     return enqueue(std::move(entry));
@@ -532,12 +491,9 @@ std::shared_ptr<ProcessGroup::Work> ProcessGroupMPI::scatter(
     std::vector<at::Tensor>& outputTensors,
     std::vector<std::vector<at::Tensor>>& inputTensors,
     const ScatterOptions& opts) {
-  if (pgComm_ == MPI_COMM_NULL) {
-    return nullptr;
-  }
   checkSingleTensor(outputTensors);
 
-  if (groupRank_ != opts.rootRank) {
+  if (rank_ != opts.rootRank) {
     if (inputTensors.size() > 0) {
       throw std::runtime_error(
           "Scatter: number of input tensors should be 0 "
@@ -545,9 +501,10 @@ std::shared_ptr<ProcessGroup::Work> ProcessGroupMPI::scatter(
     }
   } else {
     if (inputTensors.size() != 1) {
-      throw std::runtime_error("Scatter: multi-GPU collective is not supported");
+      throw std::runtime_error(
+          "Scatter: multi-GPU collective is not supported");
     }
-    if (static_cast<size_t>(groupSize_) != inputTensors[0].size()) {
+    if (static_cast<size_t>(size_) != inputTensors[0].size()) {
       throw std::runtime_error(
           "Scatter: number of input tensors should equal "
           "to the world size");
@@ -561,7 +518,7 @@ std::shared_ptr<ProcessGroup::Work> ProcessGroupMPI::scatter(
         void* sendbuf = nullptr;
         at::Tensor flatInputTensor;
 
-        if (groupRank_ == opts.rootRank) {
+        if (rank_ == opts.rootRank) {
           std::vector<at::Tensor>& inputDataVec = entry->src;
           flatInputTensor = newLikeFlat(inputDataVec);
           sendbuf = flatInputTensor.data_ptr();
@@ -584,7 +541,7 @@ std::shared_ptr<ProcessGroup::Work> ProcessGroupMPI::scatter(
             pgComm_));
       };
 
-  if (groupRank_ == opts.rootRank) {
+  if (rank_ == opts.rootRank) {
     auto entry = std::unique_ptr<WorkEntry>(
         new WorkEntry(&inputTensors[0], &outputTensors, std::move(runFunc)));
     return enqueue(std::move(entry));
@@ -599,10 +556,6 @@ std::shared_ptr<ProcessGroup::Work> ProcessGroupMPI::send(
     std::vector<at::Tensor>& tensors,
     int dstRank,
     int tag) {
-  if (pgComm_ == MPI_COMM_NULL) {
-    return nullptr;
-  }
-
   checkSingleTensor(tensors);
 
   auto& tensor = tensors[0];
@@ -627,10 +580,6 @@ std::shared_ptr<ProcessGroup::Work> ProcessGroupMPI::recv(
     std::vector<at::Tensor>& tensors,
     int srcRank,
     int tag) {
-  if (pgComm_ == MPI_COMM_NULL) {
-    return nullptr;
-  }
-
   checkSingleTensor(tensors);
 
   auto& tensor = tensors[0];
@@ -654,10 +603,6 @@ std::shared_ptr<ProcessGroup::Work> ProcessGroupMPI::recv(
 std::shared_ptr<ProcessGroup::Work> ProcessGroupMPI::recvAnysource(
     std::vector<at::Tensor>& tensors,
     int tag) {
-  if (pgComm_ == MPI_COMM_NULL) {
-    return nullptr;
-  }
-
   checkSingleTensor(tensors);
 
   auto& tensor = tensors[0];
@@ -680,9 +625,6 @@ std::shared_ptr<ProcessGroup::Work> ProcessGroupMPI::recvAnysource(
 
 std::shared_ptr<ProcessGroup::Work> ProcessGroupMPI::barrier(
     const BarrierOptions& opts) {
-  if (pgComm_ == MPI_COMM_NULL) {
-    return nullptr;
-  }
   std::function<void(std::unique_ptr<WorkEntry>&)> runFunc =
       [this](std::unique_ptr<WorkEntry>& entry) {
         std::unique_lock<std::mutex> globalLock(pgGlobalMutex_);
@@ -694,7 +636,7 @@ std::shared_ptr<ProcessGroup::Work> ProcessGroupMPI::barrier(
 }
 
 std::unordered_map<int, int> ProcessGroupMPI::getGroupRank() {
-  return groupRankMap_;
+  throw std::runtime_error("ProcessGroupMPI does not support getGroupRank");
 }
 
 } // namespace c10d
torch/lib/c10d/ProcessGroupMPI.hpp
index 5d7ec5c..b3511dc 100644
@@ -181,13 +181,9 @@ class ProcessGroupMPI : public ProcessGroup {
   static std::once_flag onceFlagInitMPI;
 
   static std::mutex pgGlobalMutex_;
-  static int numProcessGroups_;
   static int mpiThreadSupport_;
 
   MPI_Comm pgComm_;
-  int groupRank_;
-  int groupSize_;
-  std::unordered_map<int, int> groupRankMap_;
 };
 
 } // namespace c10d