[DDP Comm Hook] Create a noop hook for performance debugging (#64344)
author     Yi Wang <wayi@fb.com>
Thu, 2 Sep 2021 00:32:39 +0000 (17:32 -0700)
committer  Facebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Thu, 2 Sep 2021 00:36:22 +0000 (17:36 -0700)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/64344

This diff adds a ``noop`` DDP communication hook for performance debugging: the hook simply wraps each input gradient bucket in an already-completed future, so no gradient communication takes place and the headroom of allreduce optimization can be measured.

Additionally, avoid using numpy arrays in test_ddp_hooks.py: gradients are now returned and compared as torch tensors via ``torch.testing.assert_allclose``.
ghstack-source-id: 137170449
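
For illustration, a minimal sketch (not part of this diff) of how the new noop hook could be registered on a DDP model for performance debugging; the single-layer model and device setup below are placeholders, and the default process group is assumed to be initialized already:

    import torch
    from torch.nn.parallel import DistributedDataParallel
    from torch.distributed.algorithms.ddp_comm_hooks import debugging_hooks as debugging

    # Assumes torch.distributed.init_process_group("nccl", ...) has been called
    # and that each rank owns exactly one GPU.
    local_rank = torch.cuda.current_device()
    ddp_model = DistributedDataParallel(
        torch.nn.Linear(8, 8).cuda(local_rank), device_ids=[local_rank]
    )

    # With the noop hook registered, DDP wraps each gradient bucket in an
    # already-completed future instead of launching allreduce, so gradient
    # communication is skipped entirely.
    ddp_model.register_comm_hook(state=None, hook=debugging.noop_hook)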

Test Plan: buck test mode/dev-nosan caffe2/test/distributed/algorithms/ddp_comm_hooks:test_ddp_hooks -- test_ddp_comm_hook_noop_hook

Reviewed By: rohan-varma

Differential Revision: D30693220

fbshipit-source-id: e17f0d1c6198863cf20a53566f586a6bff602522

test/distributed/algorithms/ddp_comm_hooks/test_ddp_hooks.py
torch/distributed/algorithms/ddp_comm_hooks/__init__.py
torch/distributed/algorithms/ddp_comm_hooks/debugging_hooks.py [new file with mode: 0644]

diff --git a/test/distributed/algorithms/ddp_comm_hooks/test_ddp_hooks.py b/test/distributed/algorithms/ddp_comm_hooks/test_ddp_hooks.py
index 3d00712..67175b2 100644
@@ -2,7 +2,6 @@
 import os
 import sys
 
-import numpy as np
 import torch
 from torch import nn
 import torch.distributed as dist
@@ -105,7 +104,9 @@ class DistributedDataParallelCommHookTest(MultiProcessTestCase):
         # Run backward
         output.mean().backward()
 
-        return [p.grad.data.cpu().numpy() for p in model.parameters()]
+        # The model has a single layer, so return its only gradient.
+        param = next(model.parameters())
+        return param.grad
 
     @requires_nccl()
     @skip_if_lt_x_gpu(2)
@@ -122,7 +123,7 @@ class DistributedDataParallelCommHookTest(MultiProcessTestCase):
         # Register hook case, get the hook grads.
         hook_grads = self._get_grads(process_group, DDPCommHookType.ALLREDUCE)
 
-        np.testing.assert_allclose(hook_grads, reference_grads, rtol=1e-5, atol=0)
+        torch.testing.assert_allclose(hook_grads, reference_grads, rtol=1e-5, atol=0)
 
     @requires_nccl()
     @skip_if_lt_x_gpu(2)
@@ -139,7 +140,7 @@ class DistributedDataParallelCommHookTest(MultiProcessTestCase):
         # Register hook case, get the hook grads.
         hook_grads = self._get_grads(process_group, DDPCommHookType.FP16_COMPRESS)
 
-        np.testing.assert_allclose(hook_grads, reference_grads, rtol=1e-5, atol=1e-4)
+        torch.testing.assert_allclose(hook_grads, reference_grads, rtol=1e-5, atol=1e-4)
 
     @requires_nccl()
     @skip_if_lt_x_gpu(2)
@@ -156,7 +157,7 @@ class DistributedDataParallelCommHookTest(MultiProcessTestCase):
         # Register hook case, get the hook grads.
         hook_grads = self._get_grads(process_group, DDPCommHookType.QUANTIZE_PER_TENSOR)
 
-        np.testing.assert_allclose(hook_grads, reference_grads, rtol=1e-5, atol=1e-4)
+        torch.testing.assert_allclose(hook_grads, reference_grads, rtol=1e-5, atol=1e-4)
 
     @requires_nccl()
     @skip_if_lt_x_gpu(2)
@@ -175,7 +176,28 @@ class DistributedDataParallelCommHookTest(MultiProcessTestCase):
             process_group, DDPCommHookType.QUANTIZE_PER_CHANNEL
         )
 
-        np.testing.assert_allclose(hook_grads, reference_grads, rtol=1e-5, atol=1e-4)
+        torch.testing.assert_allclose(hook_grads, reference_grads, rtol=1e-5, atol=1e-4)
+
+
+    @requires_nccl()
+    @skip_if_lt_x_gpu(2)
+    def test_ddp_comm_hook_noop_hook(self):
+        """
+        This unit test verifies that the ``noop`` hook registered case, followed by a subsequent
+        allreduce, gives the same result as the no hook registered case.
+        """
+        store = dist.FileStore(self.file_name, self.world_size)
+        process_group = dist.ProcessGroupNCCL(store, self.rank, self.world_size)
+
+        # No hook registered case, get the reference grads.
+        reference_grads = self._get_grads(process_group, None)
+        # Register hook case, get the hook grads.
+        hook_grads = self._get_grads(process_group, DDPCommHookType.NOOP)
+        # Apply a subsequent allreduce to average grads.
+        hook_grads.div_(self.world_size)
+        dist.all_reduce(hook_grads, group=process_group)
+
+        torch.testing.assert_allclose(hook_grads, reference_grads, rtol=1e-5, atol=0)
 
     @requires_nccl()
     @skip_if_lt_x_gpu(2)
diff --git a/torch/distributed/algorithms/ddp_comm_hooks/__init__.py b/torch/distributed/algorithms/ddp_comm_hooks/__init__.py
index c3f3b06..ff22a81 100644
@@ -5,6 +5,7 @@ import torch.distributed as dist
 from torch.nn.parallel import DistributedDataParallel
 
 from . import (
+    debugging_hooks as debugging,
     default_hooks as default,
     powerSGD_hook as powerSGD,
     quantization_hooks as quantization,
@@ -78,6 +79,9 @@ class DDPCommHookType(Enum):
         comm_hook=powerSGD.batched_powerSGD_hook,
         matrix_approximation_rank=2,
     )
+    NOOP = partial(
+        _ddp_comm_hook_wrapper, comm_hook=debugging.noop_hook,
+    )
 
 
 def register_ddp_comm_hook(
diff --git a/torch/distributed/algorithms/ddp_comm_hooks/debugging_hooks.py b/torch/distributed/algorithms/ddp_comm_hooks/debugging_hooks.py
new file mode 100644
index 0000000..0c60762
--- /dev/null
@@ -0,0 +1,26 @@
+from typing import Any
+
+import torch
+import torch.distributed as dist
+
+
+def noop_hook(_: Any, bucket: dist.GradBucket) -> torch.futures.Future[torch.Tensor]:
+    """
+    This DDP communication hook returns a future that wraps the input,
+    so it is a noop that does not incur any communication overhead.
+
+    This hook should **only** be used for headroom analysis of allreduce optimization,
+    instead of the normal gradient synchronization.
+    For example, if less than a 10% speedup in training time is observed after this hook is registered,
+    it usually implies that allreduce is not a performance bottleneck for this case.
+    Such instrumentation can be particularly useful
+    if GPU traces cannot be easily retrieved or the trace analysis is complicated
+    by factors such as the overlap between allreduce and computation or the desynchronization across ranks.
+
+    Example::
+        >>> ddp_model.register_comm_hook(None, noop_hook)
+    """
+    fut: torch.futures.Future[torch.Tensor] = torch.futures.Future()
+    fut.set_result(bucket.buffer())
+
+    return fut
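
As a rough sketch of the headroom analysis described in the ``noop_hook`` docstring (not part of this diff), one could compare the average step time of the default allreduce path against the noop path; the synthetic model, batch size, and step count below are arbitrary placeholders, and the default NCCL process group is assumed to be initialized:

    import time

    import torch
    from torch.nn.parallel import DistributedDataParallel
    from torch.distributed.algorithms.ddp_comm_hooks.debugging_hooks import noop_hook


    def average_step_time(ddp_model, steps=50):
        # Time forward + backward over a synthetic workload on the current GPU.
        data = torch.randn(32, 8, device="cuda")
        torch.cuda.synchronize()
        start = time.time()
        for _ in range(steps):
            ddp_model(data).sum().backward()
            ddp_model.zero_grad()
        torch.cuda.synchronize()
        return (time.time() - start) / steps


    # Assumes torch.distributed.init_process_group("nccl", ...) has been called
    # and that each rank owns exactly one GPU.
    local_rank = torch.cuda.current_device()

    # Baseline: default gradient allreduce.
    baseline = DistributedDataParallel(
        torch.nn.Linear(8, 8).cuda(local_rank), device_ids=[local_rank]
    )
    t_allreduce = average_step_time(baseline)

    # Probe: gradient communication disabled via the noop hook.
    probe = DistributedDataParallel(
        torch.nn.Linear(8, 8).cuda(local_rank), device_ids=[local_rank]
    )
    probe.register_comm_hook(None, noop_hook)
    t_noop = average_step_time(probe)

    # If t_noop is only marginally smaller than t_allreduce, allreduce is
    # unlikely to be the training-time bottleneck, i.e. there is little
    # headroom for further communication optimization.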