From bf9d66586c388c0aa223644b1d224227443ae34b Mon Sep 17 00:00:00 2001
From: Yi Wang
Date: Wed, 1 Sep 2021 17:32:39 -0700
Subject: [PATCH] [DDP Comm Hook] Create a noop hook for performance debugging
 (#64344)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/64344

As title.

Additionally, avoid using a numpy array in test_ddp_hooks.py.

ghstack-source-id: 137170449

Test Plan: buck test mode/dev-nosan caffe2/test/distributed/algorithms/ddp_comm_hooks:test_ddp_hooks -- test_ddp_comm_hook_noop_hook

Reviewed By: rohan-varma

Differential Revision: D30693220

fbshipit-source-id: e17f0d1c6198863cf20a53566f586a6bff602522
---
 .../algorithms/ddp_comm_hooks/test_ddp_hooks.py   | 34 ++++++++++++++++++----
 .../algorithms/ddp_comm_hooks/__init__.py         |  4 +++
 .../algorithms/ddp_comm_hooks/debugging_hooks.py  | 26 +++++++++++++++++
 3 files changed, 58 insertions(+), 6 deletions(-)
 create mode 100644 torch/distributed/algorithms/ddp_comm_hooks/debugging_hooks.py

diff --git a/test/distributed/algorithms/ddp_comm_hooks/test_ddp_hooks.py b/test/distributed/algorithms/ddp_comm_hooks/test_ddp_hooks.py
index 3d00712..67175b2 100644
--- a/test/distributed/algorithms/ddp_comm_hooks/test_ddp_hooks.py
+++ b/test/distributed/algorithms/ddp_comm_hooks/test_ddp_hooks.py
@@ -2,7 +2,6 @@
 import os
 import sys
 
-import numpy as np
 import torch
 from torch import nn
 import torch.distributed as dist
@@ -105,7 +104,9 @@ class DistributedDataParallelCommHookTest(MultiProcessTestCase):
         # Run backward
         output.mean().backward()
 
-        return [p.grad.data.cpu().numpy() for p in model.parameters()]
+        # The model has only one layer, so return the gradient of its single parameter.
+        param = next(model.parameters())
+        return param.grad
 
     @requires_nccl()
     @skip_if_lt_x_gpu(2)
@@ -122,7 +123,7 @@ class DistributedDataParallelCommHookTest(MultiProcessTestCase):
         # Register hook case, get the hook grads.
         hook_grads = self._get_grads(process_group, DDPCommHookType.ALLREDUCE)
 
-        np.testing.assert_allclose(hook_grads, reference_grads, rtol=1e-5, atol=0)
+        torch.testing.assert_allclose(hook_grads, reference_grads, rtol=1e-5, atol=0)
 
     @requires_nccl()
     @skip_if_lt_x_gpu(2)
@@ -139,7 +140,7 @@ class DistributedDataParallelCommHookTest(MultiProcessTestCase):
         # Register hook case, get the hook grads.
         hook_grads = self._get_grads(process_group, DDPCommHookType.FP16_COMPRESS)
 
-        np.testing.assert_allclose(hook_grads, reference_grads, rtol=1e-5, atol=1e-4)
+        torch.testing.assert_allclose(hook_grads, reference_grads, rtol=1e-5, atol=1e-4)
 
     @requires_nccl()
     @skip_if_lt_x_gpu(2)
@@ -156,7 +157,7 @@ class DistributedDataParallelCommHookTest(MultiProcessTestCase):
         # Register hook case, get the hook grads.
         hook_grads = self._get_grads(process_group, DDPCommHookType.QUANTIZE_PER_TENSOR)
 
-        np.testing.assert_allclose(hook_grads, reference_grads, rtol=1e-5, atol=1e-4)
+        torch.testing.assert_allclose(hook_grads, reference_grads, rtol=1e-5, atol=1e-4)
 
     @requires_nccl()
     @skip_if_lt_x_gpu(2)
@@ -175,7 +176,28 @@ class DistributedDataParallelCommHookTest(MultiProcessTestCase):
         hook_grads = self._get_grads(
             process_group, DDPCommHookType.QUANTIZE_PER_CHANNEL
         )
-        np.testing.assert_allclose(hook_grads, reference_grads, rtol=1e-5, atol=1e-4)
+        torch.testing.assert_allclose(hook_grads, reference_grads, rtol=1e-5, atol=1e-4)
+
+
+    @requires_nccl()
+    @skip_if_lt_x_gpu(2)
+    def test_ddp_comm_hook_noop_hook(self):
+        """
+        This unit test verifies that the ``noop`` hook registered case, followed by a
+        subsequent allreduce, gives the same result as the no hook registered case.
+        """
+        store = dist.FileStore(self.file_name, self.world_size)
+        process_group = dist.ProcessGroupNCCL(store, self.rank, self.world_size)
+
+        # No hook registered case, get the reference grads.
+        reference_grads = self._get_grads(process_group, None)
+        # Register hook case, get the hook grads.
+        hook_grads = self._get_grads(process_group, DDPCommHookType.NOOP)
+        # Apply a subsequent allreduce to average the grads.
+        hook_grads.div_(self.world_size)
+        dist.all_reduce(hook_grads, group=process_group)
+
+        torch.testing.assert_allclose(hook_grads, reference_grads, rtol=1e-5, atol=0)
 
     @requires_nccl()
     @skip_if_lt_x_gpu(2)
diff --git a/torch/distributed/algorithms/ddp_comm_hooks/__init__.py b/torch/distributed/algorithms/ddp_comm_hooks/__init__.py
index c3f3b06..ff22a81 100644
--- a/torch/distributed/algorithms/ddp_comm_hooks/__init__.py
+++ b/torch/distributed/algorithms/ddp_comm_hooks/__init__.py
@@ -5,6 +5,7 @@ import torch.distributed as dist
 from torch.nn.parallel import DistributedDataParallel
 
 from . import (
+    debugging_hooks as debugging,
     default_hooks as default,
     powerSGD_hook as powerSGD,
     quantization_hooks as quantization,
@@ -78,6 +79,9 @@ class DDPCommHookType(Enum):
         comm_hook=powerSGD.batched_powerSGD_hook,
         matrix_approximation_rank=2,
     )
+    NOOP = partial(
+        _ddp_comm_hook_wrapper, comm_hook=debugging.noop_hook,
+    )
 
 
 def register_ddp_comm_hook(
diff --git a/torch/distributed/algorithms/ddp_comm_hooks/debugging_hooks.py b/torch/distributed/algorithms/ddp_comm_hooks/debugging_hooks.py
new file mode 100644
index 0000000..0c60762
--- /dev/null
+++ b/torch/distributed/algorithms/ddp_comm_hooks/debugging_hooks.py
@@ -0,0 +1,26 @@
+from typing import Any
+
+import torch
+import torch.distributed as dist
+
+
+def noop_hook(_: Any, bucket: dist.GradBucket) -> torch.futures.Future[torch.Tensor]:
+    """
+    This DDP communication hook returns a future that wraps the input,
+    so it is a noop that does not incur any communication overhead.
+
+    This hook should **only** be used for headroom analysis of allreduce optimization,
+    instead of the normal gradient synchronization.
+    For example, if less than a 10% speedup of training time is observed after this hook is registered,
+    it usually implies that allreduce is not a performance bottleneck for this case.
+    Such instrumentation can be particularly useful if GPU traces cannot be easily retrieved
+    or the trace analysis is complicated by factors such as the overlap between allreduce
+    and computation or the desynchronization across ranks.
+
+    Example::
+        >>> ddp_model.register_comm_hook(None, noop_hook)
+    """
+    fut: torch.futures.Future[torch.Tensor] = torch.futures.Future()
+    fut.set_result(bucket.buffer())
+
+    return fut
-- 
2.7.4
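
Usage sketch (an illustration, not part of the patch above): assuming each rank has already called torch.distributed.init_process_group with the NCCL backend and `ddp_model` is an existing DistributedDataParallel instance, the noop hook could be registered for headroom analysis in either of two ways:

    # Minimal sketch; `ddp_model` is assumed to be a DistributedDataParallel model
    # created after torch.distributed.init_process_group(backend="nccl") on this rank.
    from torch.distributed.algorithms.ddp_comm_hooks import (
        DDPCommHookType,
        register_ddp_comm_hook,
    )
    from torch.distributed.algorithms.ddp_comm_hooks.debugging_hooks import noop_hook

    # Option 1: register the hook directly on the DDP-wrapped model.
    # Only one comm hook may be registered per model, so use one option or the other.
    ddp_model.register_comm_hook(state=None, hook=noop_hook)

    # Option 2: register through the DDPCommHookType enum extended by this patch.
    # register_ddp_comm_hook(DDPCommHookType.NOOP, ddp_model)

If allreduce were a real bottleneck, skipping it with this noop hook would shrink per-iteration time noticeably; if iteration time barely changes, gradient communication is not the limiting factor.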