Adding collective quantization API (#62142)
authorMarjan Fariborz <marjanf@fb.com>
Mon, 9 Aug 2021 15:09:49 +0000 (08:09 -0700)
committerFacebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Mon, 9 Aug 2021 15:11:22 +0000 (08:11 -0700)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/62142

Created a wrapper that takes the collective op and a quantization type as arguments. It quantizes the input, performs the collective op, and then dequantizes the output.

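A minimal usage sketch (illustrative only; assumes a single-node NCCL process group with one GPU per
rank, already initialized, and that group/async_op are passed as keyword arguments, since the wrapper
only reads them from kwargs):

    import torch
    import torch.distributed as dist
    from torch.distributed.algorithms.quantization import auto_quantize, DQuantType

    rank, world_size = dist.get_rank(), dist.get_world_size()
    device = torch.device("cuda", rank)
    inputs = [torch.full((4,), float(rank), device=device) for _ in range(world_size)]
    outputs = [torch.zeros(4, device=device) for _ in range(world_size)]

    # Inputs are clamped/cast to FP16, exchanged via all_to_all, and the dequantized
    # FP32 results are written back into the outputs list.
    quantized_all_to_all = auto_quantize(dist.all_to_all, DQuantType.FP16)
    quantized_all_to_all(outputs, inputs)
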
Test Plan:
Tested through distributed_gloo_fork.
e.g., buck test mode/dev-nosan caffe2/test/distributed:distributed_nccl_fork -- test_all_to_all_quantized

Reviewed By: wanchaol

Differential Revision: D29682812

fbshipit-source-id: 79c39105ff11270008caa9f566361452fe82a92e

torch/distributed/algorithms/quantization.py [new file with mode: 0644]
torch/testing/_internal/distributed/distributed_test.py

diff --git a/torch/distributed/algorithms/quantization.py b/torch/distributed/algorithms/quantization.py
new file mode 100644 (file)
index 0000000..dead78a
--- /dev/null
@@ -0,0 +1,131 @@
+import functools
+import torch
+import torch.distributed as dist
+
+
+from enum import Enum
+
+
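+# Representable range of torch.float16; inputs are clamped to this range before
+# downcasting so that out-of-range float32 values do not overflow to +/-inf.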
+TORCH_HALF_MIN = torch.finfo(torch.float16).min
+TORCH_HALF_MAX = torch.finfo(torch.float16).max
+
+class DQuantType(Enum):
+    FP16 = "fp16"
+
+    def __str__(self) -> str:
+        return self.value
+
+
+def _fp32_to_fp16_with_clamp(tensor: torch.Tensor) -> torch.Tensor:
+    return torch.clamp(tensor, TORCH_HALF_MIN, TORCH_HALF_MAX).half()
+
+def _quantize_tensor(tensor, qtype):
+    if not isinstance(tensor, torch.Tensor):
+        raise RuntimeError(
+            f"_quantize_tensor expecting torch.Tensor as input but found {type(tensor)}"
+        )
+    if (qtype == DQuantType.FP16):
+        return _fp32_to_fp16_with_clamp(tensor)
+    else:
+        raise RuntimeError(
+            f'Quantization type {qtype} is not supported'
+        )
+
+def _quantize_tensor_list(tensor_list, qtype):
+    if not isinstance(tensor_list, list) or not all(
+        isinstance(p, torch.Tensor) for p in tensor_list
+    ):
+        raise RuntimeError(
+            f"_quantize_tensor_list expecting list of torch.Tensor as input but found {type(tensor_list)}"
+        )
+    if (qtype == DQuantType.FP16):
+        quantized_tensor_list = [_quantize_tensor(t, qtype) for t in tensor_list]
+        return quantized_tensor_list
+    else:
+        raise RuntimeError(
+            f'Quantization type {qtype} is not supported'
+        )
+
+def _dequantize_tensor(tensor, qtype, quant_loss=None):
+    if not isinstance(tensor, torch.Tensor):
+        raise RuntimeError(
+            f"_dequantize_tensor expecting torch.Tensor as input but found {type(tensor)}"
+        )
+    if (qtype == DQuantType.FP16):
+        if tensor.dtype != torch.float16:
+            raise RuntimeError(
+                f"tensor dtype is {tensor.dtype} while expected to be FP16."
+            )
+        elif quant_loss is None:
+            return tensor.float()
+        else:
+            return tensor.float() / quant_loss
+    else:
+        raise RuntimeError(
+            f'Quantization type {qtype} is not supported'
+        )
+
+
+def _dequantize_tensor_list(tensor_list, qtype, quant_loss=None):
+    if not isinstance(tensor_list, list) or not all(
+        isinstance(p, torch.Tensor) for p in tensor_list
+    ):
+        raise RuntimeError(
+            f"_dequantize_tensor_list expecting list of torch.Tensor as input but found {type(tensor_list)}"
+        )
+    if (qtype == DQuantType.FP16):
+        dequantized_tensor_list = [_dequantize_tensor(t, qtype, quant_loss=quant_loss) for t in tensor_list]
+        return dequantized_tensor_list
+    else:
+        raise RuntimeError(
+            f'Quantization type {qtype} is not supported'
+        )
+
+
+def auto_quantize(func, qtype, quant_loss=None):
+    """
+    This is a prototype API that automatically quantizes the input tensors, runs the collective op on the
+    quantized tensors, and then dequantizes the output.
+
+    Currently it only supports:
+        * FP16 quantization
+        * all_gather, all_to_all collective ops
+
+    Args:
+        func (callable): A function representing collective operations.
+        qtype (DQuantType): Quantization method.
+        quant_loss (float, optional): If given, the dequantized values are divided by this
+            factor, which can be used to improve the accuracy of the dequantization.
+
+    Returns:
+        (Callable): the same collective as ``func``, but with automatic quantization of its
+            inputs and dequantization of its outputs.
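+
+    Example (illustrative; assumes the default process group has already been initialized)::
+
+        >>> tensor = torch.ones(2) * dist.get_rank()
+        >>> tensor_list = [torch.zeros(2) for _ in range(dist.get_world_size())]
+        >>> quantized_all_gather = auto_quantize(dist.all_gather, DQuantType.FP16)
+        >>> quantized_all_gather(tensor_list, tensor)  # tensor_list is filled with the gathered values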
+    """
+    @functools.wraps(func)
+    def wrapper(*args, **kwargs):
+        group = kwargs.get('group', None)
+        async_op = kwargs.get('async_op', False)
+        if (async_op is True):
+            raise RuntimeError(
+                'The async_op=True mode is not supported yet.'
+            )
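+
+        # Quantize the input and output buffers, run the collective in the reduced
+        # precision, then write the dequantized results back into the output list.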
+        if (func == dist.all_gather):
+            tensors = args[0]
+            input_tensor = _quantize_tensor(args[1], qtype)
+            out_tensors = _quantize_tensor_list(tensors, qtype)
+            dist.all_gather(out_tensors, input_tensor, group=group, async_op=async_op)
+            for i, t in enumerate(_dequantize_tensor_list(out_tensors, qtype, quant_loss=quant_loss)):
+                tensors[i] = t
+
+        elif (func == dist.all_to_all):
+            tensors = args[0]
+            input_tensors = _quantize_tensor_list(args[1], qtype)
+            out_tensors = _quantize_tensor_list(tensors, qtype)
+            dist.all_to_all(out_tensors, input_tensors, group=group, async_op=async_op)
+            for i, t in enumerate(_dequantize_tensor_list(out_tensors, qtype, quant_loss=quant_loss)):
+                tensors[i] = t
+
+        else:
+            raise RuntimeError(
+                f"The collective op {func} is not supported yet"
+            )
+
+    return wrapper
diff --git a/torch/testing/_internal/distributed/distributed_test.py b/torch/testing/_internal/distributed/distributed_test.py
index 21449a7..7ec9100 100644 (file)
@@ -19,6 +19,7 @@ import torch.distributed.algorithms.ddp_comm_hooks.post_localSGD_hook as post_lo
 import torch.distributed.algorithms.ddp_comm_hooks.powerSGD_hook as powerSGD
 import torch.distributed.algorithms.model_averaging.averagers as averagers
 import torch.distributed.algorithms.model_averaging.utils as model_averaging_utils
+import torch.distributed.algorithms.quantization as quant
 import torch.nn as nn
 import torch.nn.functional as F
 from torch._utils_internal import TEST_MASTER_ADDR as MASTER_ADDR
@@ -28,6 +29,7 @@ from torch.distributed.algorithms.ddp_comm_hooks import default_hooks as default
 from torch.distributed.algorithms.ddp_comm_hooks import (
     quantization as quantization_hooks,
 )
+from torch.distributed.algorithms.quantization import DQuantType
 from torch.distributed.distributed_c10d import (
     get_world_size,
     _get_default_group,
@@ -2736,11 +2738,15 @@ class DistributedTest:
 
         # ALL GATHER
         def _test_all_gather_helper(
-            self, group, group_id, rank, cuda=False, rank_to_GPU=None, dtype=torch.float
+            self, group, group_id, rank, cuda=False, rank_to_GPU=None, dtype=torch.float, qtype=None
         ):
             for dest in group:
                 tensor = _build_tensor(dest + 1, rank, dtype=dtype)
                 tensors = [_build_tensor(dest + 1, -1, dtype=dtype) for i in group]
+                if qtype is not None:
+                    allgather = quant.auto_quantize(dist.all_gather, qtype, quant_loss=None)
+                else:
+                    allgather = dist.all_gather
                 if cuda:
                     tensor = tensor.cuda(rank_to_GPU[rank][0])
                     tensors = [t.cuda(rank_to_GPU[rank][0]) for t in tensors]
@@ -2751,10 +2757,11 @@ class DistributedTest:
                 self.call_dist_op(
                     ":all_gather",
                     False,
-                    dist.all_gather,
+                    allgather,
                     tensors,
                     tensor,
                     group_id,
+                    False,
                     tensor_shapes=tensor_shapes,
                 )
 
@@ -2805,6 +2812,12 @@ class DistributedTest:
             group, group_id, rank = self._init_full_group_test()
             self._test_all_gather_helper(group, group_id, rank)
 
+        @sandcastle_skip_if(BACKEND == "nccl", "Nccl does not support CPU tensors")
+        @sandcastle_skip_if(BACKEND == "mpi", "all_gather_quantized does not support MPI")
+        def test_all_gather_quantized(self):
+            group, group_id, rank = self._init_global_test()
+            self._test_all_gather_helper(group, group_id, rank, dtype=torch.float32, qtype=DQuantType.FP16)
+
         def _run_all_gather_coalesced_and_verify(
             self, output_tensor_lists, input_tensors, expected_tensors, group_id
         ):
@@ -3007,6 +3020,7 @@ class DistributedTest:
             cuda=False,
             rank_to_GPU=None,
             dtype=torch.float,
+            qtype=None
         ):
             if group_id is not None:
                 size = len(group)
@@ -3027,7 +3041,11 @@ class DistributedTest:
                         t.cuda(rank_to_GPU[rank][0]) for t in expected_tensors
                     ]
                     out_tensors = [t.cuda(rank_to_GPU[rank][0]) for t in out_tensors]
-                dist.all_to_all(out_tensors, in_tensors, group=group_id)
+                if qtype is not None:
+                    quantize_alltoall = quant.auto_quantize(dist.all_to_all, qtype, quant_loss=None)
+                    quantize_alltoall(out_tensors, in_tensors, group=group_id)
+                else:
+                    dist.all_to_all(out_tensors, in_tensors, group=group_id)
                 for t1, t2 in zip(out_tensors, expected_tensors):
                     self.assertEqual(t1, t2)
             self._barrier()
@@ -3110,6 +3128,20 @@ class DistributedTest:
             group, group_id, rank = self._init_global_test()
             self._test_all_to_all_helper(group, group_id, rank)
 
+        @sandcastle_skip_if(BACKEND != "nccl", "Only NCCL supports all_to_all")
+        @skip_if_rocm
+        def test_all_to_all_quantized(self):
+            group, group_id, rank = self._init_global_test()
+            rank_to_GPU = self._init_multigpu_helper()
+            self._test_all_to_all_helper(
+                group,
+                group_id,
+                rank,
+                cuda=True,
+                rank_to_GPU=rank_to_GPU,
+                dtype=torch.float32,
+                qtype=DQuantType.FP16)
+
         @sandcastle_skip_if(BACKEND != "nccl", "Only NCCL supports CUDA all_to_all")
         @skip_if_rocm
         def test_all_to_all_cuda(self):