// may indicate that there is some sort of collective desynchronization.
uint64_t getSequenceNumberForGroup() override;
+ // Returns the `threads` value stored in this process group's options.
+ // Declared const: it only reads options_ and never modifies the group,
+ // so it can be called through const pointers and references as well.
+ int getNumThreads() const {
+ return options_->threads;
+ }
+
protected:
std::unique_ptr<::gloo::rendezvous::Store> store_;
const c10::intrusive_ptr<Options> options_;
#include <fmt/format.h>
#include <string>
+#ifdef USE_C10D_GLOO
+#include <c10d/ProcessGroupGloo.hpp>
+#endif
+
namespace c10d {
// When training runs at these iterations, log the runtime
parse_env("GLOO_SOCKET_IFNAME");
ddp_logging_data_->strs_map["gloo_device_transport"] =
parse_env("GLOO_DEVICE_TRANSPORT");
+
+ #ifdef USE_C10D_GLOO
+ auto gloo_pg =
+ static_cast<c10d::ProcessGroupGloo*>(reducer_->process_group_.get());
+ auto n_threads = gloo_pg->getNumThreads();
+ ddp_logging_data_->ints_map["gloo_num_threads"] = n_threads;
+ #endif
}
}
ddp_logging_data.get("gloo_device_transport"),
parse_env("GLOO_DEVICE_TRANSPORT"),
)
+ default_gloo_threads = 2
+ self.assertEqual(
+ ddp_logging_data.get("gloo_num_threads"),
+ default_gloo_threads,
+ )
+
self.assertEqual(ddp_logging_data.get("nccl_socket_ifname"), None)
self.assertEqual(ddp_logging_data.get("nccl_blocking_wait"), None)
self.assertEqual(ddp_logging_data.get("nccl_async_error_handling"), None)