From 0a5149019f71735554bb89d1f4e9ad198ceb9df3 Mon Sep 17 00:00:00 2001
From: Avery Wang
Date: Thu, 16 Sep 2021 16:37:52 -0700
Subject: [PATCH] Added logging for the Reducer's non-member functions. (#65023)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/65023

Added an optional logging parameter to the non-member functions `compute_bucket_assignment_by_size` and `verify_replica0_across_processes`. If a logger is provided, the `TORCH_CHECK` assertions are replaced with a wrapper that logs the error to the DDP reducer's logger before calling `TORCH_CHECK`. If no logger is provided, `TORCH_CHECK` is called as before.

Modified the Python-side calls to `_compute_bucket_assignment_by_size` and `_verify_model_across_ranks` to include a logger whenever possible. A notable exception is when these non-member functions are called in DDP's constructor - we cannot pass in a logger there because it may not have been initialized yet.

We also added 4 new tests:
- `test_compute_bucket_assignment_by_size_sparse_error_{with, without}_logger`, which test `_compute_bucket_assignment_by_size` to ensure that sparse tensors are rejected and that the errors are logged.
- `test_verify_model_across_rank_{with, without}_logger`, which call `_verify_model_across_ranks` to ensure that ill-formed models (ranks whose number of parameters differs from rank 0) are rejected and that the errors are logged.

The test `test_ddp_model_diff_across_ranks` remains functionally unchanged - while it does construct an ill-formed DDP instance that triggers the error in `_verify_model_across_ranks`, we cannot check the logger because the error is raised in the constructor. Lastly, cleaned up `test_ddp_model_diff_across_ranks` so that the logic for choosing the context manager and expected error message is clearer.

Test Plan:
**Build commands**
`buck build mode/dev-nosan //caffe2/test/distributed:distributed_nccl_spawn --keep-going`
`buck build mode/dev-nosan //caffe2/test/distributed:distributed_gloo_spawn --keep-going`

**Test commands**
Test for `_compute_bucket_assignment_by_size` (Python) / `compute_bucket_assignment_by_size` (C++):
`BACKEND={nccl, gloo} WORLD_SIZE=2 ../buck-out/dev/gen/caffe2/test/distributed/distributed_{nccl, gloo}_spawn#binary.par -r test_compute_bucket_assignment_by_size_sparse_error_{with, without}_logger`

Test for `_verify_model_across_ranks` (Python) / `verify_replica0_across_processes` (C++):
`BACKEND={nccl, gloo} WORLD_SIZE=2 ../buck-out/dev/gen/caffe2/test/distributed/distributed_{nccl, gloo}_spawn#binary.par -r test_verify_model_across_ranks_{with, without}_logger`

Test that constructs an ill-formed DDP instance (this change only cleans up that test):
`BACKEND={nccl, gloo} WORLD_SIZE=2 ../buck-out/dev/gen/caffe2/test/distributed/distributed_{nccl, gloo}_spawn#binary.par -r test_ddp_model_diff_across_ranks`

Reviewed By: rohan-varma

Differential Revision: D30924790

fbshipit-source-id: dae6fa82485a204a6a4b022f2d073417d07ebb2f
---
 torch/csrc/distributed/c10d/init.cpp                |  26 +++-
 torch/csrc/distributed/c10d/reducer.cpp             |  57 ++++---
 torch/csrc/distributed/c10d/reducer.hpp             |   8 +-
 .../_internal/distributed/distributed_test.py       | 169 ++++++++++++++++++---
 4 files changed, 209 insertions(+), 51 deletions(-)

diff --git a/torch/csrc/distributed/c10d/init.cpp b/torch/csrc/distributed/c10d/init.cpp
index 4bac0ca..77e2dae 100644
--- a/torch/csrc/distributed/c10d/init.cpp
+++ b/torch/csrc/distributed/c10d/init.cpp
@@ -1544,18 +1544,40 @@ Example::

   module.def(
       "_compute_bucket_assignment_by_size",
-      &::c10d::compute_bucket_assignment_by_size,
+      [](const std::vector<at::Tensor>& tensors,
+         const std::vector<size_t>& bucket_size_limits,
+         const std::vector<bool>& expect_sparse_gradient,
+         const std::vector<int64_t>& tensor_indices,
+         const c10::optional<std::shared_ptr<::c10d::Logger>>& logger) {
+        if (logger.has_value()) {
+          std::weak_ptr<::c10d::Logger> logger_weakref = logger.value();
+          return ::c10d::compute_bucket_assignment_by_size(tensors, bucket_size_limits, expect_sparse_gradient, tensor_indices, {logger_weakref});
+        } else {
+          return ::c10d::compute_bucket_assignment_by_size(tensors, bucket_size_limits, expect_sparse_gradient, tensor_indices, {});
+        }
+      },
       py::arg("tensors"),
       py::arg("bucket_size"),
       py::arg("expect_sparse_gradient") = std::vector<bool>(),
       py::arg("tensor_indices") = std::vector<int64_t>(),
+      py::arg("logger") = c10::optional<std::shared_ptr<::c10d::Logger>>{},
       py::call_guard<py::gil_scoped_release>());

   module.def(
       "_verify_model_across_ranks",
-      &::c10d::verify_replica0_across_processes,
+      [](const c10::intrusive_ptr<::c10d::ProcessGroup>& process_group,
+         const std::vector<std::vector<at::Tensor>>& model_replicas,
+         const c10::optional<std::shared_ptr<::c10d::Logger>>& logger) {
+        if (logger.has_value()) {
+          std::weak_ptr<::c10d::Logger> logger_weakref = logger.value();
+          verify_replica0_across_processes(process_group, model_replicas, {logger_weakref});
+        } else {
+          verify_replica0_across_processes(process_group, model_replicas, {});
+        }
+      },
       py::arg("process_group"),
       py::arg("replicas"),
+      py::arg("logger") = c10::optional<std::shared_ptr<::c10d::Logger>>{},
       py::call_guard<py::gil_scoped_release>());

   module.def(
diff --git a/torch/csrc/distributed/c10d/reducer.cpp b/torch/csrc/distributed/c10d/reducer.cpp
index 91db615..3e726f4 100644
--- a/torch/csrc/distributed/c10d/reducer.cpp
+++ b/torch/csrc/distributed/c10d/reducer.cpp
@@ -1702,7 +1702,8 @@ bool Reducer::rebuild_buckets() {
           rebuilt_params_,
           bucket_size_limits,
           expect_sparse_gradients_[0],
-          rebuilt_param_indices_);
+          rebuilt_param_indices_,
+          logger_);

   if (ddp_set_last_bucket_as_small) {
     // Reverse again because buckets were rebuilt in the opposite of gradient
@@ -1958,7 +1959,8 @@ compute_bucket_assignment_by_size(
     const std::vector<at::Tensor>& tensors,
     const std::vector<size_t>& bucket_size_limits,
     const std::vector<bool>& expect_sparse_gradient,
-    const std::vector<int64_t>& tensor_indices) {
+    const std::vector<int64_t>& tensor_indices,
+    const c10::optional<std::weak_ptr<c10d::Logger>>& logger) {
   // Either expect_sparse_gradient is not specified or it has as many elements
   // as the vector with tensors.
   TORCH_INTERNAL_ASSERT(
@@ -1988,9 +1990,12 @@ compute_bucket_assignment_by_size(
   for (const auto i : c10::irange(tensors.size())) {
     const auto& tensor = tensors[i];
-    // TODO: This is not a reducer method so it does not have access to logger,
-    // pass in logger directly here.
-    TORCH_CHECK(!tensor.is_sparse(), "No support for sparse tensors.");
+    auto msg = std::string("No support for sparse tensors.");
+    if (logger.has_value()) {
+      REDUCER_CHECK(!tensor.is_sparse(), logger.value(), msg);
+    } else {
+      TORCH_CHECK(!tensor.is_sparse(), msg);
+    }

     // when tensor_indices is empty, the index of tensors[i] assigned to
     // bucket is i, otherwise the tensor index is tensor_indices[i].
@@ -2078,8 +2083,9 @@ compute_bucket_assignment_by_size(
 // Verifies corresponding params in replica 0 have the same sizes/strides
 // across processes.
 void verify_replica0_across_processes(
-    c10::intrusive_ptr<c10d::ProcessGroup> process_group,
-    std::vector<std::vector<at::Tensor>> model_replicas) {
+    const c10::intrusive_ptr<c10d::ProcessGroup>& process_group,
+    const std::vector<std::vector<at::Tensor>>& model_replicas,
+    const c10::optional<std::weak_ptr<c10d::Logger>>& logger) {
   size_t i = 0;
   for (const auto& t : model_replicas[0]) {
     i += 2 * t.dim();
@@ -2116,26 +2122,27 @@ void verify_replica0_across_processes(
       // I'd like to include which process we are in the message,
       // but ProcessGroup::getRank is not public!
       for (const auto& sz : t.sizes()) {
-        // TODO: pass in logger and use REDUCER_CHECK.
-        TORCH_CHECK(
-            sz == control_accessor[i++],
-            "replicas[0][",
-            p,
-            "] in this process"
-            " with sizes ",
-            t.sizes(),
-            " appears not to match sizes of the same param in process 0.");
+        auto msg = c10::str("replicas[0][", p, "] in this process",
+                          " with sizes ",
+                          t.sizes(),
+                          " appears not to match sizes of the same param in process 0.");
+        if (logger.has_value()) {
+          REDUCER_CHECK(sz == control_accessor[i++], logger.value(), msg)
+        } else {
+          TORCH_CHECK(sz == control_accessor[i++], msg)
+        }
+
       }
       for (const auto& str : t.strides()) {
-        // TODO: pass in logger and use REDUCER_CHECK.
-        TORCH_CHECK(
-            str == control_accessor[i++],
-            "replicas[0][",
-            p,
-            "] in this process"
-            " with strides ",
-            t.strides(),
-            " appears not to match strides of the same param in process 0.");
+        auto msg = c10::str("replicas[0][", p, "] in this process",
+                          " with sizes ",
+                          t.sizes(),
+                          " appears not to match strides of the same param in process 0.");
+        if (logger.has_value()) {
+          REDUCER_CHECK(str == control_accessor[i++], logger.value(), msg)
+        } else {
+          TORCH_CHECK(str == control_accessor[i++], msg)
+        }
       }
     }
   }
diff --git a/torch/csrc/distributed/c10d/reducer.hpp b/torch/csrc/distributed/c10d/reducer.hpp
index 9f46e63..8dc42b6 100644
--- a/torch/csrc/distributed/c10d/reducer.hpp
+++ b/torch/csrc/distributed/c10d/reducer.hpp
@@ -602,11 +602,13 @@ compute_bucket_assignment_by_size(
     const std::vector<at::Tensor>& tensors,
     const std::vector<size_t>& bucket_size,
     const std::vector<bool>& expect_sparse_gradient = {},
-    const std::vector<int64_t>& tensor_indices = {});
+    const std::vector<int64_t>& tensor_indices = {},
+    const c10::optional<std::weak_ptr<c10d::Logger>>& logger = {});

 // Verify models across all processes are the same as model on rank 0 with
 // respect to no. of params and matching dtype/size/layout.
 TORCH_API void verify_replica0_across_processes(
-    c10::intrusive_ptr<c10d::ProcessGroup> process_group,
-    std::vector<std::vector<at::Tensor>> model_replicas);
+    const c10::intrusive_ptr<c10d::ProcessGroup>& process_group,
+    const std::vector<std::vector<at::Tensor>>& model_replicas,
+    const c10::optional<std::weak_ptr<c10d::Logger>>& logger);

 } // namespace c10d
diff --git a/torch/testing/_internal/distributed/distributed_test.py b/torch/testing/_internal/distributed/distributed_test.py
index 613e23e..999b2dc 100644
--- a/torch/testing/_internal/distributed/distributed_test.py
+++ b/torch/testing/_internal/distributed/distributed_test.py
@@ -6857,11 +6857,20 @@ class DistributedTest:
             ):
                 dist.scatter_object_list([], scatter_list, src=src_rank)

-        @require_backend({"gloo", "nccl"})
-        @require_backends_available({"gloo", "nccl"})
-        @skip_if_lt_x_gpu(2)
-        @skip_if_rocm
-        def test_ddp_model_diff_across_ranks(self):
+        def _generate_sparse_tensors_for_bucket_assignment_test(self):
+            tensors = [
+                torch.empty([50], dtype=torch.float),
+                torch.empty([25], dtype=torch.double),
+                torch.empty([50], dtype=torch.float),
+                torch.empty([25], dtype=torch.double),
+                torch.empty([50], dtype=torch.float),
+                torch.empty([25], dtype=torch.double),
+            ]
+
+            tensors_sparse = [t.to_sparse() for t in tensors]
+            return tensors_sparse
+
+        def _test_compute_bucket_assignment_by_size(self, use_logger):
             group_gloo = dist.new_group(
                 timeout=timedelta(seconds=60), backend=dist.Backend.GLOO
             )
@@ -6872,9 +6881,49 @@ class DistributedTest:
                 backend=dist.get_backend(), timeout=timedelta(seconds=5)
             )
             torch.cuda.set_device(self.rank)
-            # Creates network with different sized embedding table on different
-            # ranks. This should throw an error during DDP init.
-            net = EmbeddingNet(self.rank)
+
+            # Create a valid model. The constructor initializes the logger that we use later.
+            # We never actually use the rest of the model - we only need its logger.
+            net = EmbeddingNet(0)
+            net = torch.nn.parallel.DistributedDataParallel(
+                net.to(self.rank),
+                device_ids=[self.rank],
+                process_group=group_to_use,
+            )
+
+            # if we don't pass a logger then we can only check that an exception was thrown.
+            expected_err = "No support for sparse tensors."
+            with self.assertRaisesRegex(RuntimeError, expected_err):
+                tensors_sparse = self._generate_sparse_tensors_for_bucket_assignment_test()
+                if use_logger:
+                    result = dist._compute_bucket_assignment_by_size(
+                        tensors_sparse,
+                        [400],
+                        logger=net.logger)
+                else:
+                    result = dist._compute_bucket_assignment_by_size(tensors_sparse, [400])
+            if use_logger:
+                verify_ddp_error_logged(net, expected_err)
+
+            # Perform gloo-based barrier to ensure one rank doesn't exit test
+            # early which causes failure with Barrier.sync.
+            dist.barrier(group_gloo)
+
+        @require_backend({"gloo", "nccl"})
+        @require_backends_available({"gloo", "nccl"})
+        @skip_if_lt_x_gpu(2)
+        @skip_if_rocm
+        def test_compute_bucket_assignment_by_size_sparse_error_without_logger(self):
+            self._test_compute_bucket_assignment_by_size(use_logger=False)
+
+        @require_backend({"gloo", "nccl"})
+        @require_backends_available({"gloo", "nccl"})
+        @skip_if_lt_x_gpu(2)
+        @skip_if_rocm
+        def test_compute_bucket_assignment_by_size_sparse_error_with_logger(self):
+            self._test_compute_bucket_assignment_by_size(use_logger=True)
+
+        def _determine_expected_error_verify_model_across_rank(self, group_to_use):
             # When running with NCCL backend, we don't expect an error on rank 0,
             # rather, it will be taken down by NCCL_ASYNC_ERROR_HANDLING. When
             # running with Gloo or with debug mode wrapper, we expect the error
@@ -6882,21 +6931,98 @@ class DistributedTest:
             is_detail_dbg_mode = (
                 dist._get_debug_mode() == dist._DistributedDebugLevel.DETAIL
             )
-            rank_0_ctx = (
-                self.assertRaisesRegex(
-                    RuntimeError, "Caught collective operation timeout"
-                )
-                if dist.get_backend(group_to_use) == dist.Backend.NCCL
-                and not is_detail_dbg_mode
-                # Gloo can raise various exception messages, so just assert
-                # Runtime error here.
-                else self.assertRaises(RuntimeError)
+            if self.rank == 0:
+                if dist.get_backend(group_to_use) == dist.Backend.NCCL and not is_detail_dbg_mode:
+                    expected_err = "Caught collective operation timeout"
+                    ctx = self.assertRaisesRegex(RuntimeError, expected_err)
+                else:
+                    expected_err = None
+                    ctx = self.assertRaises(RuntimeError)
+            else:
+                expected_err = "appears not to match"
+                ctx = self.assertRaisesRegex(RuntimeError, expected_err)
+            return ctx, expected_err
+
+        def _test_verify_model_across_rank(self, use_logger):
+            group_gloo = dist.new_group(
+                timeout=timedelta(seconds=60), backend=dist.Backend.GLOO
             )
-            ctx = (
-                rank_0_ctx
-                if self.rank == 0
-                else self.assertRaisesRegex(RuntimeError, "appears not to match")
+            # Set NCCL_BLOCKING_WAIT and use a new NCCL group to improve test
+            # determinism.
+            os.environ["NCCL_BLOCKING_WAIT"] = "1"
+            group_to_use = dist.new_group(
+                backend=dist.get_backend(), timeout=timedelta(seconds=5)
             )
+            torch.cuda.set_device(self.rank)
+            ctx, expected_err = self._determine_expected_error_verify_model_across_rank(group_to_use)
+
+            # Create a valid model. The constructor initializes the logger that we use later.
+            net = EmbeddingNet(0)
+            net = torch.nn.parallel.DistributedDataParallel(
+                net.to(self.rank),
+                device_ids=[self.rank],
+                process_group=group_to_use,
+            )
+
+            # Modify the model so that the number of parameters are different for each rank.
+            # This will cause a RuntimeError to be thrown below in dist._verify_model_across_ranks,
+            # so we can check if the correct error is thrown and is logged.
+            # We can't do this in the constructor above otherwise the logger will
+            # not be properly initialized.
+            net.module.lin = nn.Linear(100 if self.rank == 0 else 10, 1)
+
+            # if we pass a logger we can verify that it was logged
+            with ctx:
+                if use_logger:
+                    dist._verify_model_across_ranks(net.process_group, [list(net.parameters())], net.logger)
+                else:
+                    dist._verify_model_across_ranks(net.process_group, [list(net.parameters())])
+                # Should only be run by rank 0, and blocking_wait catches and
+                # reports exception.
+                dist.barrier(group_to_use)
+
+            # We don't check when self.rank != 0 because the logger doesn't log
+            # the error "Caught collective operation" as that is not thrown in the reducer.
+            if use_logger and self.rank != 0:
+                verify_ddp_error_logged(net, expected_err)
+
+            # Perform gloo-based barrier to ensure one rank doesn't exit test
+            # early which causes failure with Barrier.sync.
+            dist.barrier(group_gloo)
+
+        @require_backend({"gloo", "nccl"})
+        @require_backends_available({"gloo", "nccl"})
+        @skip_if_lt_x_gpu(2)
+        @skip_if_rocm
+        def test_verify_model_across_rank_with_logger(self):
+            self._test_verify_model_across_rank(use_logger=True)
+
+        @require_backend({"gloo", "nccl"})
+        @require_backends_available({"gloo", "nccl"})
+        @skip_if_lt_x_gpu(2)
+        @skip_if_rocm
+        def test_verify_model_across_rank_without_logger(self):
+            self._test_verify_model_across_rank(use_logger=False)
+
+        @require_backend({"gloo", "nccl"})
+        @require_backends_available({"gloo", "nccl"})
+        @skip_if_lt_x_gpu(2)
+        @skip_if_rocm
+        def test_ddp_model_diff_across_ranks(self):
+            group_gloo = dist.new_group(
+                timeout=timedelta(seconds=60), backend=dist.Backend.GLOO
+            )
+            # Set NCCL_BLOCKING_WAIT and use a new NCCL group to improve test
+            # determinism.
+            os.environ["NCCL_BLOCKING_WAIT"] = "1"
+            group_to_use = dist.new_group(
+                backend=dist.get_backend(), timeout=timedelta(seconds=5)
+            )
+            torch.cuda.set_device(self.rank)
+            ctx, expected_err = self._determine_expected_error_verify_model_across_rank(group_to_use)
+            # Creates network with different sized embedding table on different
+            # ranks. This should throw an error during DDP init.
+            net = EmbeddingNet(self.rank)
             with ctx:
                 net = torch.nn.parallel.DistributedDataParallel(
                     net.to(self.rank),
@@ -6906,6 +7032,7 @@ class DistributedTest:
                 # Should only be run by rank 0, and blocking_wait catches and
                 # reports exception.
                 dist.barrier(group_to_use)
+            # can't use verify_ddp_error_logged here because net was never properly constructed

             # Perform gloo-based barrier to ensure one rank doesn't exit test
             # early which causes failure with Barrier.sync.
-- 
2.7.4
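
**Usage sketch (illustrative, not part of the patch)**
The snippet below is a minimal sketch of how the new optional `logger` argument could be exercised from Python, mirroring the new tests. The helper name `demo_logger_checks`, the `nn.Linear` module, and the assumption that the default process group and a CUDA device are already set up are hypothetical; `net.logger`, `net.process_group`, `dist._compute_bucket_assignment_by_size`, and `dist._verify_model_across_ranks` come from this patch.

```python
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP


def demo_logger_checks(rank):
    # Assumes dist.init_process_group(...) has already been called and this
    # process owns GPU `rank`; wrapping any module in DDP initializes the
    # logger that the new `logger` argument expects.
    net = DDP(torch.nn.Linear(10, 1).to(rank), device_ids=[rank])

    # Sparse tensors are rejected; when a logger is passed, the error is also
    # recorded in the DDP logging data before the RuntimeError is raised.
    tensors_sparse = [torch.empty([50]).to_sparse(), torch.empty([25]).to_sparse()]
    try:
        dist._compute_bucket_assignment_by_size(
            tensors_sparse, [400], logger=net.logger
        )
    except RuntimeError as e:
        assert "No support for sparse tensors." in str(e)

    # The model-verification entry point accepts the same optional logger.
    dist._verify_model_across_ranks(
        net.process_group, [list(net.parameters())], net.logger
    )
```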