Added logging for the Reducer's non-member functions. (#65023)
author     Avery Wang <averywang@fb.com>
Thu, 16 Sep 2021 23:37:52 +0000 (16:37 -0700)
committer  Facebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Thu, 16 Sep 2021 23:39:39 +0000 (16:39 -0700)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/65023

Added an optional logger parameter to the non-member functions `compute_bucket_assignment_by_size` and `verify_replica0_across_processes`. If a logger is provided, their `TORCH_CHECK` assertions are routed through the `REDUCER_CHECK` wrapper, which records the error on the DDP reducer's logger before raising via `TORCH_CHECK`. If no logger is provided, `TORCH_CHECK` is called as before.

Modified the Python-side calls to `_compute_bucket_assignment_by_size` and `_verify_model_across_ranks` to pass a logger whenever possible. A notable exception is when these non-member functions are called from DDP's constructor: we cannot pass a logger there because it may not have been initialized yet.
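
For illustration, here is a minimal sketch of the new call shape. It is not taken verbatim from this PR: it assumes a Gloo process group rendezvoused via environment variables and borrows the logger from an already-constructed DDP instance.

```python
import torch
import torch.distributed as dist
import torch.nn as nn

# Assumption: MASTER_ADDR, MASTER_PORT, RANK and WORLD_SIZE are set in the environment.
dist.init_process_group("gloo")

# A correctly constructed DDP instance initializes the logger we reuse below.
net = nn.parallel.DistributedDataParallel(nn.Linear(10, 10))

sparse = [torch.empty(50).to_sparse(), torch.empty(25).to_sparse()]
try:
    # With logger=, the failure is recorded on the reducer's logger before the raise.
    dist._compute_bucket_assignment_by_size(sparse, [400], logger=net.logger)
except RuntimeError as err:
    assert "No support for sparse tensors." in str(err)

# Without a logger the same RuntimeError is still raised, just not logged:
#   dist._compute_bucket_assignment_by_size(sparse, [400])

# _verify_model_across_ranks gains the same optional trailing logger argument:
#   dist._verify_model_across_ranks(net.process_group, [list(net.parameters())], net.logger)
```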

We also added four new tests. `test_compute_bucket_assignment_by_size_sparse_error_{with, without}_logger` exercise `_compute_bucket_assignment_by_size` to ensure that sparse tensors are rejected and that, when a logger is passed, the error is logged. `test_verify_model_across_rank_{with, without}_logger` call `_verify_model_across_ranks` to ensure that ill-formed models (ranks whose parameter counts differ from rank 0's) are rejected and the errors are logged. The test `test_ddp_model_diff_across_ranks` remains functionally unchanged: while it does construct an ill-formed DDP instance that triggers the error in `_verify_model_across_ranks`, we cannot check the logger because the error is raised in the constructor.
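
Condensed (and not verbatim), the with-logger flow on a non-zero rank inside these tests looks roughly like this; `net` is a correctly constructed DDP-wrapped `EmbeddingNet` whose logger we borrow:

```python
# Mutate the module so ranks disagree on parameter shapes, then call the
# verification binding directly with the reducer's logger.
net.module.lin = nn.Linear(100 if self.rank == 0 else 10, 1)

expected_err = "appears not to match"
with self.assertRaisesRegex(RuntimeError, expected_err):
    dist._verify_model_across_ranks(
        net.process_group, [list(net.parameters())], net.logger
    )

# The same message must also have been recorded on the reducer's logger.
verify_ddp_error_logged(net, expected_err)
```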

Lastly, cleaned up `test_ddp_model_diff_across_ranks` so that the logic for choosing the expected context manager and error message is clearer; it now lives in the shared helper `_determine_expected_error_verify_model_across_rank`.
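
Condensed, that shared selection logic (see `_determine_expected_error_verify_model_across_rank` in the diff below) reads roughly as:

```python
# `group_to_use` is the NCCL/Gloo group under test; `is_detail_dbg_mode`
# reflects dist._get_debug_mode() == dist._DistributedDebugLevel.DETAIL.
if self.rank == 0:
    if dist.get_backend(group_to_use) == dist.Backend.NCCL and not is_detail_dbg_mode:
        expected_err = "Caught collective operation timeout"
        ctx = self.assertRaisesRegex(RuntimeError, expected_err)
    else:
        # Gloo (and the debug-mode wrapper) can raise various messages,
        # so only assert that a RuntimeError is raised.
        expected_err = None
        ctx = self.assertRaises(RuntimeError)
else:
    expected_err = "appears not to match"
    ctx = self.assertRaisesRegex(RuntimeError, expected_err)
```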

Test Plan:
**Build commands**
`buck build mode/dev-nosan //caffe2/test/distributed:distributed_nccl_spawn --keep-going`

`buck build mode/dev-nosan //caffe2/test/distributed:distributed_gloo_spawn --keep-going`

**Test commands**
Tests for `_compute_bucket_assignment_by_size` (Python) / `compute_bucket_assignment_by_size` (C++):
`BACKEND={nccl, gloo} WORLD_SIZE=2 ../buck-out/dev/gen/caffe2/test/distributed/distributed_{nccl, gloo}_spawn#binary.par -r test_compute_bucket_assignment_by_size_sparse_error_{with, without}_logger`

Tests for `_verify_model_across_ranks` (Python) / `verify_replica0_across_processes` (C++):
`BACKEND={nccl, gloo} WORLD_SIZE=2 ../buck-out/dev/gen/caffe2/test/distributed/distributed_{nccl, gloo}_spawn#binary.par -r test_verify_model_across_ranks_{with, without}_logger`

Test that constructs an ill-formed DDP instance (this change only cleans up its control flow):
`BACKEND={nccl, gloo} WORLD_SIZE=2 ../buck-out/dev/gen/caffe2/test/distributed/distributed_{nccl, gloo}_spawn#binary.par -r test_ddp_model_diff_across_ranks`

Reviewed By: rohan-varma

Differential Revision: D30924790

fbshipit-source-id: dae6fa82485a204a6a4b022f2d073417d07ebb2f

torch/csrc/distributed/c10d/init.cpp
torch/csrc/distributed/c10d/reducer.cpp
torch/csrc/distributed/c10d/reducer.hpp
torch/testing/_internal/distributed/distributed_test.py

diff --git a/torch/csrc/distributed/c10d/init.cpp b/torch/csrc/distributed/c10d/init.cpp
index 4bac0ca..77e2dae 100644
@@ -1544,18 +1544,40 @@ Example::
 
   module.def(
       "_compute_bucket_assignment_by_size",
-      &::c10d::compute_bucket_assignment_by_size,
+      [](const std::vector<at::Tensor>& tensors,
+            const std::vector<size_t>& bucket_size_limits,
+            const std::vector<bool>& expect_sparse_gradient,
+            const std::vector<int64_t>& tensor_indices,
+            const c10::optional<std::shared_ptr<::c10d::Logger>>& logger) {
+             if (logger.has_value()) {
+                std::weak_ptr<::c10d::Logger> logger_weakref = logger.value();
+                return ::c10d::compute_bucket_assignment_by_size(tensors, bucket_size_limits, expect_sparse_gradient, tensor_indices, {logger_weakref});
+             } else {
+                return ::c10d::compute_bucket_assignment_by_size(tensors, bucket_size_limits, expect_sparse_gradient, tensor_indices, {});
+             }
+      },
       py::arg("tensors"),
       py::arg("bucket_size"),
       py::arg("expect_sparse_gradient") = std::vector<bool>(),
       py::arg("tensor_indices") = std::vector<int64_t>(),
+      py::arg("logger") = c10::optional<std::shared_ptr<::c10d::Logger>>{},
       py::call_guard<py::gil_scoped_release>());
 
   module.def(
       "_verify_model_across_ranks",
-      &::c10d::verify_replica0_across_processes,
+      [](const c10::intrusive_ptr<::c10d::ProcessGroup>& process_group,
+         const std::vector<std::vector<at::Tensor>>& model_replicas,
+         const c10::optional<std::shared_ptr<::c10d::Logger>>& logger) {
+             if (logger.has_value()) {
+                std::weak_ptr<::c10d::Logger> logger_weakref = logger.value();
+                verify_replica0_across_processes(process_group, model_replicas, {logger_weakref});
+             } else {
+                verify_replica0_across_processes(process_group, model_replicas, {});
+             }
+      },
       py::arg("process_group"),
       py::arg("replicas"),
+      py::arg("logger") = c10::optional<std::shared_ptr<::c10d::Logger>>{},
       py::call_guard<py::gil_scoped_release>());
 
   module.def(
diff --git a/torch/csrc/distributed/c10d/reducer.cpp b/torch/csrc/distributed/c10d/reducer.cpp
index 91db615..3e726f4 100644
@@ -1702,7 +1702,8 @@ bool Reducer::rebuild_buckets() {
           rebuilt_params_,
           bucket_size_limits,
           expect_sparse_gradients_[0],
-          rebuilt_param_indices_);
+          rebuilt_param_indices_,
+          logger_);
 
   if (ddp_set_last_bucket_as_small) {
     // Reverse again because buckets were rebuilt in the opposite of gradient
@@ -1958,7 +1959,8 @@ compute_bucket_assignment_by_size(
     const std::vector<at::Tensor>& tensors,
     const std::vector<size_t>& bucket_size_limits,
     const std::vector<bool>& expect_sparse_gradient,
-    const std::vector<int64_t>& tensor_indices) {
+    const std::vector<int64_t>& tensor_indices,
+    const c10::optional<std::weak_ptr<c10d::Logger>>& logger) {
   // Either expect_sparse_gradient is not specified or it has as many elements
   // as the vector with tensors.
   TORCH_INTERNAL_ASSERT(
@@ -1988,9 +1990,12 @@ compute_bucket_assignment_by_size(
 
   for (const auto i : c10::irange(tensors.size())) {
     const auto& tensor = tensors[i];
-    // TODO: This is not a reducer method so it does not have access to logger,
-    // pass in logger directly here.
-    TORCH_CHECK(!tensor.is_sparse(), "No support for sparse tensors.");
+    auto msg = std::string("No support for sparse tensors.");
+    if (logger.has_value()) {
+      REDUCER_CHECK(!tensor.is_sparse(), logger.value(), msg);
+    } else {
+      TORCH_CHECK(!tensor.is_sparse(), msg);
+    }
 
     // when tensor_indices is empty, the index of tensors[i] assigned to
     // bucket is i, otherwise the tensor index is tensor_indices[i].
@@ -2078,8 +2083,9 @@ compute_bucket_assignment_by_size(
 // Verifies corresponding params in replica 0 have the same sizes/strides
 // across processes.
 void verify_replica0_across_processes(
-    c10::intrusive_ptr<c10d::ProcessGroup> process_group,
-    std::vector<std::vector<at::Tensor>> model_replicas) {
+    const c10::intrusive_ptr<c10d::ProcessGroup>& process_group,
+    const std::vector<std::vector<at::Tensor>>& model_replicas,
+    const c10::optional<std::weak_ptr<c10d::Logger>>& logger) {
   size_t i = 0;
   for (const auto& t : model_replicas[0]) {
     i += 2 * t.dim();
@@ -2116,26 +2122,27 @@ void verify_replica0_across_processes(
     // I'd like to include which process we are in the message,
     // but ProcessGroup::getRank is not public!
     for (const auto& sz : t.sizes()) {
-      // TODO: pass in logger and use REDUCER_CHECK.
-      TORCH_CHECK(
-          sz == control_accessor[i++],
-          "replicas[0][",
-          p,
-          "] in this process"
-          " with sizes ",
-          t.sizes(),
-          " appears not to match sizes of the same param in process 0.");
+      auto msg = c10::str("replicas[0][", p, "] in this process",
+                        " with sizes ",
+                        t.sizes(),
+                        " appears not to match sizes of the same param in process 0.");
+      if (logger.has_value()) {
+        REDUCER_CHECK(sz == control_accessor[i++], logger.value(), msg)
+      } else {
+        TORCH_CHECK(sz == control_accessor[i++], msg)
+      }
+
     }
     for (const auto& str : t.strides()) {
-      // TODO: pass in logger and use REDUCER_CHECK.
-      TORCH_CHECK(
-          str == control_accessor[i++],
-          "replicas[0][",
-          p,
-          "] in this process"
-          " with strides ",
-          t.strides(),
-          " appears not to match strides of the same param in process 0.");
+      auto msg = c10::str("replicas[0][", p, "] in this process",
+                        " with strides ",
+                        t.strides(),
+                        " appears not to match strides of the same param in process 0.");
+      if (logger.has_value()) {
+        REDUCER_CHECK(str == control_accessor[i++], logger.value(), msg)
+      } else {
+        TORCH_CHECK(str == control_accessor[i++], msg)
+      }
     }
   }
 }
diff --git a/torch/csrc/distributed/c10d/reducer.hpp b/torch/csrc/distributed/c10d/reducer.hpp
index 9f46e63..8dc42b6 100644
@@ -602,11 +602,13 @@ compute_bucket_assignment_by_size(
     const std::vector<at::Tensor>& tensors,
     const std::vector<size_t>& bucket_size,
     const std::vector<bool>& expect_sparse_gradient = {},
-    const std::vector<int64_t>& tensor_indices = {});
+    const std::vector<int64_t>& tensor_indices = {},
+    const c10::optional<std::weak_ptr<c10d::Logger>>& logger = {});
 
 // Verify models across all processes are the same as model on rank 0 with
 // respect to no. of params and matching dtype/size/layout.
 TORCH_API void verify_replica0_across_processes(
-    c10::intrusive_ptr<c10d::ProcessGroup> process_group,
-    std::vector<std::vector<at::Tensor>> model_replicas);
+    const c10::intrusive_ptr<c10d::ProcessGroup>& process_group,
+    const std::vector<std::vector<at::Tensor>>& model_replicas,
+    const c10::optional<std::weak_ptr<c10d::Logger>>& logger);
 } // namespace c10d
diff --git a/torch/testing/_internal/distributed/distributed_test.py b/torch/testing/_internal/distributed/distributed_test.py
index 613e23e..999b2dc 100644
@@ -6857,11 +6857,20 @@ class DistributedTest:
             ):
                 dist.scatter_object_list([], scatter_list, src=src_rank)
 
-        @require_backend({"gloo", "nccl"})
-        @require_backends_available({"gloo", "nccl"})
-        @skip_if_lt_x_gpu(2)
-        @skip_if_rocm
-        def test_ddp_model_diff_across_ranks(self):
+        def _generate_sparse_tensors_for_bucket_assignment_test(self):
+            tensors = [
+                torch.empty([50], dtype=torch.float),
+                torch.empty([25], dtype=torch.double),
+                torch.empty([50], dtype=torch.float),
+                torch.empty([25], dtype=torch.double),
+                torch.empty([50], dtype=torch.float),
+                torch.empty([25], dtype=torch.double),
+            ]
+
+            tensors_sparse = [t.to_sparse() for t in tensors]
+            return tensors_sparse
+
+        def _test_compute_bucket_assignment_by_size(self, use_logger):
             group_gloo = dist.new_group(
                 timeout=timedelta(seconds=60), backend=dist.Backend.GLOO
             )
@@ -6872,9 +6881,49 @@ class DistributedTest:
                 backend=dist.get_backend(), timeout=timedelta(seconds=5)
             )
             torch.cuda.set_device(self.rank)
-            # Creates network with different sized embedding table on different
-            # ranks. This should throw an error during DDP init.
-            net = EmbeddingNet(self.rank)
+
+            # Create a valid model. The constructor initializes the logger that we use later.
+            # We never actually use the rest of the model - we only need its logger.
+            net = EmbeddingNet(0)
+            net = torch.nn.parallel.DistributedDataParallel(
+                net.to(self.rank),
+                device_ids=[self.rank],
+                process_group=group_to_use,
+            )
+
+            # if we don't pass a logger then we can only check that an exception was thrown.
+            expected_err = "No support for sparse tensors."
+            with self.assertRaisesRegex(RuntimeError, expected_err):
+                tensors_sparse = self._generate_sparse_tensors_for_bucket_assignment_test()
+                if use_logger:
+                    result = dist._compute_bucket_assignment_by_size(
+                        tensors_sparse,
+                        [400],
+                        logger=net.logger)
+                else:
+                    result = dist._compute_bucket_assignment_by_size(tensors_sparse, [400])
+            if use_logger:
+                verify_ddp_error_logged(net, expected_err)
+
+            # Perform gloo-based barrier to ensure one rank doesn't exit test
+            # early which causes failure with Barrier.sync.
+            dist.barrier(group_gloo)
+
+        @require_backend({"gloo", "nccl"})
+        @require_backends_available({"gloo", "nccl"})
+        @skip_if_lt_x_gpu(2)
+        @skip_if_rocm
+        def test_compute_bucket_assignment_by_size_sparse_error_without_logger(self):
+            self._test_compute_bucket_assignment_by_size(use_logger=False)
+
+        @require_backend({"gloo", "nccl"})
+        @require_backends_available({"gloo", "nccl"})
+        @skip_if_lt_x_gpu(2)
+        @skip_if_rocm
+        def test_compute_bucket_assignment_by_size_sparse_error_with_logger(self):
+            self._test_compute_bucket_assignment_by_size(use_logger=True)
+
+        def _determine_expected_error_verify_model_across_rank(self, group_to_use):
             # When running with NCCL backend, we don't expect an error on rank 0,
             # rather, it will be taken down by NCCL_ASYNC_ERROR_HANDLING. When
             # running with Gloo or with debug mode wrapper, we expect the error
@@ -6882,21 +6931,98 @@ class DistributedTest:
             is_detail_dbg_mode = (
                 dist._get_debug_mode() == dist._DistributedDebugLevel.DETAIL
             )
-            rank_0_ctx = (
-                self.assertRaisesRegex(
-                    RuntimeError, "Caught collective operation timeout"
-                )
-                if dist.get_backend(group_to_use) == dist.Backend.NCCL
-                and not is_detail_dbg_mode
-                # Gloo can raise various exception messages, so just assert
-                # Runtime error here.
-                else self.assertRaises(RuntimeError)
+            if self.rank == 0:
+                if dist.get_backend(group_to_use) == dist.Backend.NCCL and not is_detail_dbg_mode:
+                    expected_err = "Caught collective operation timeout"
+                    ctx = self.assertRaisesRegex(RuntimeError, expected_err)
+                else:
+                    expected_err = None
+                    ctx = self.assertRaises(RuntimeError)
+            else:
+                expected_err = "appears not to match"
+                ctx = self.assertRaisesRegex(RuntimeError, expected_err)
+            return ctx, expected_err
+
+        def _test_verify_model_across_rank(self, use_logger):
+            group_gloo = dist.new_group(
+                timeout=timedelta(seconds=60), backend=dist.Backend.GLOO
             )
-            ctx = (
-                rank_0_ctx
-                if self.rank == 0
-                else self.assertRaisesRegex(RuntimeError, "appears not to match")
+            # Set NCCL_BLOCKING_WAIT and use a new NCCL group to improve test
+            # determinism.
+            os.environ["NCCL_BLOCKING_WAIT"] = "1"
+            group_to_use = dist.new_group(
+                backend=dist.get_backend(), timeout=timedelta(seconds=5)
             )
+            torch.cuda.set_device(self.rank)
+            ctx, expected_err = self._determine_expected_error_verify_model_across_rank(group_to_use)
+
+            # Create a valid model. The constructor initializes the logger that we use later.
+            net = EmbeddingNet(0)
+            net = torch.nn.parallel.DistributedDataParallel(
+                net.to(self.rank),
+                device_ids=[self.rank],
+                process_group=group_to_use,
+            )
+
+            # Modify the model so that the number of parameters are different for each rank.
+            # This will cause a RuntimeError to be thrown below in dist._verify_model_across_ranks,
+            # so we can check if the correct error is thrown and is logged.
+            # We can't do this in the constructor above otherwise the logger will
+            # not be properly initialized.
+            net.module.lin = nn.Linear(100 if self.rank == 0 else 10, 1)
+
+            # if we pass a logger we can verify that it was logged
+            with ctx:
+                if use_logger:
+                    dist._verify_model_across_ranks(net.process_group, [list(net.parameters())], net.logger)
+                else:
+                    dist._verify_model_across_ranks(net.process_group, [list(net.parameters())])
+                # Should only be run by rank 0, and blocking_wait catches and
+                # reports exception.
+                dist.barrier(group_to_use)
+
+            # We don't check when self.rank != 0 because the logger doesn't log
+            # the error "Caught collective operation" as that is not thrown in the reducer.
+            if use_logger and self.rank != 0:
+                verify_ddp_error_logged(net, expected_err)
+
+            # Perform gloo-based barrier to ensure one rank doesn't exit test
+            # early which causes failure with Barrier.sync.
+            dist.barrier(group_gloo)
+
+        @require_backend({"gloo", "nccl"})
+        @require_backends_available({"gloo", "nccl"})
+        @skip_if_lt_x_gpu(2)
+        @skip_if_rocm
+        def test_verify_model_across_rank_with_logger(self):
+            self._test_verify_model_across_rank(use_logger=True)
+
+        @require_backend({"gloo", "nccl"})
+        @require_backends_available({"gloo", "nccl"})
+        @skip_if_lt_x_gpu(2)
+        @skip_if_rocm
+        def test_verify_model_across_rank_without_logger(self):
+            self._test_verify_model_across_rank(use_logger=False)
+
+        @require_backend({"gloo", "nccl"})
+        @require_backends_available({"gloo", "nccl"})
+        @skip_if_lt_x_gpu(2)
+        @skip_if_rocm
+        def test_ddp_model_diff_across_ranks(self):
+            group_gloo = dist.new_group(
+                timeout=timedelta(seconds=60), backend=dist.Backend.GLOO
+            )
+            # Set NCCL_BLOCKING_WAIT and use a new NCCL group to improve test
+            # determinism.
+            os.environ["NCCL_BLOCKING_WAIT"] = "1"
+            group_to_use = dist.new_group(
+                backend=dist.get_backend(), timeout=timedelta(seconds=5)
+            )
+            torch.cuda.set_device(self.rank)
+            ctx, expected_err = self._determine_expected_error_verify_model_across_rank(group_to_use)
+            # Creates network with different sized embedding table on different
+            # ranks. This should throw an error during DDP init.
+            net = EmbeddingNet(self.rank)
             with ctx:
                 net = torch.nn.parallel.DistributedDataParallel(
                     net.to(self.rank),
@@ -6906,6 +7032,7 @@ class DistributedTest:
                 # Should only be run by rank 0, and blocking_wait catches and
                 # reports exception.
                 dist.barrier(group_to_use)
+            # can't use verify_ddp_error_logged here because net was never properly constructed
 
             # Perform gloo-based barrier to ensure one rank doesn't exit test
             # early which causes failure with Barrier.sync.