[DDP] Log num threads (#64072)
author: Rohan Varma <rvarm1@fb.com>
Thu, 2 Sep 2021 01:12:02 +0000 (18:12 -0700)
committer: Facebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Thu, 2 Sep 2021 01:36:15 +0000 (18:36 -0700)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/64072

Log the number of Gloo threads (as `gloo_num_threads`) in the DDP logging data.
ghstack-source-id: 137119480

Test Plan: CI

Reviewed By: mrshenli

Differential Revision: D30596083

fbshipit-source-id: 2b4f6e762cb5d850be6056bcc5922029a1af3c91

torch/csrc/distributed/c10d/ProcessGroupGloo.hpp
torch/csrc/distributed/c10d/logger.cpp
torch/testing/_internal/distributed/distributed_test.py

index 036ce91..5c0c76a 100644 (file)
@@ -318,6 +318,10 @@ class TORCH_API ProcessGroupGloo : public ProcessGroup {
   // may indicate that there is some sort of collective desynchronization.
   uint64_t getSequenceNumberForGroup() override;
 
+  int getNumThreads() {  // Accessor: returns the `threads` value held in this group's options_.
+    return options_->threads;
+  }
+
  protected:
   std::unique_ptr<::gloo::rendezvous::Store> store_;
   const c10::intrusive_ptr<Options> options_;
index 0bb960a..b1efd0b 100644 (file)
@@ -4,6 +4,10 @@
 #include <fmt/format.h>
 #include <string>
 
+#ifdef USE_C10D_GLOO
+#include <c10d/ProcessGroupGloo.hpp>
+#endif
+
 namespace c10d {
 
 // When training runs at these iterations, log the runtime
@@ -68,6 +72,13 @@ void Logger::set_env_variables() {
         parse_env("GLOO_SOCKET_IFNAME");
     ddp_logging_data_->strs_map["gloo_device_transport"] =
         parse_env("GLOO_DEVICE_TRANSPORT");
+
+    #ifdef USE_C10D_GLOO
+    auto gloo_pg =
+        static_cast<c10d::ProcessGroupGloo*>(reducer_->process_group_.get());
+    auto n_threads = gloo_pg->getNumThreads();
+    ddp_logging_data_->ints_map["gloo_num_threads"] = n_threads;
+    #endif
   }
 }
 
index f17842e..613e23e 100644 (file)
@@ -5074,6 +5074,12 @@ class DistributedTest:
                     ddp_logging_data.get("gloo_device_transport"),
                     parse_env("GLOO_DEVICE_TRANSPORT"),
                 )
+                default_gloo_threads = 2
+                self.assertEqual(
+                    ddp_logging_data.get("gloo_num_threads"),
+                    default_gloo_threads,
+                )
+
             self.assertEqual(ddp_logging_data.get("nccl_socket_ifname"), None)
             self.assertEqual(ddp_logging_data.get("nccl_blocking_wait"), None)
             self.assertEqual(ddp_logging_data.get("nccl_async_error_handling"), None)