--- /dev/null
+import functools
+from enum import Enum
+
+import torch
+import torch.distributed as dist
+
+TORCH_HALF_MIN = torch.finfo(torch.float16).min
+TORCH_HALF_MAX = torch.finfo(torch.float16).max
+
+
+class DQuantType(Enum):
+    """Quantization schemes supported by the collective wrappers below."""
+
+    FP16 = "fp16"
+
+    def __str__(self) -> str:
+        return self.value
+
+
+def _fp32_to_fp16_with_clamp(tensor: torch.Tensor) -> torch.Tensor:
+    return torch.clamp(tensor, TORCH_HALF_MIN, TORCH_HALF_MAX).half()
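A quick illustration of why the clamp matters (not part of the patch): casting an out-of-range FP32 value straight to FP16 overflows to inf, while the clamped path saturates at the FP16 maximum.

    t = torch.tensor([70000.0])     # 7e4 exceeds TORCH_HALF_MAX (~65504)
    t.half()                        # tensor([inf], dtype=torch.float16)
    _fp32_to_fp16_with_clamp(t)     # tensor([65504.], dtype=torch.float16)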
+
+def _quantize_tensor(tensor, qtype):
+    if not isinstance(tensor, torch.Tensor):
+        raise RuntimeError(
+            f"_quantize_tensor expecting torch.Tensor as input but found {type(tensor)}"
+        )
+    if qtype == DQuantType.FP16:
+        return _fp32_to_fp16_with_clamp(tensor)
+    else:
+        raise RuntimeError(f"Quantization type {qtype} is not supported")
+
+def _quantize_tensor_list(tensor_list, qtype):
+    if not isinstance(tensor_list, list) or not all(
+        isinstance(p, torch.Tensor) for p in tensor_list
+    ):
+        raise RuntimeError(
+            f"_quantize_tensor_list expecting list of torch.Tensor as input but found {type(tensor_list)}"
+        )
+    if qtype == DQuantType.FP16:
+        return [_quantize_tensor(t, qtype) for t in tensor_list]
+    else:
+        raise RuntimeError(f"Quantization type {qtype} is not supported")
+
+def _dequantize_tensor(tensor, qtype, quant_loss=None):
+    if not isinstance(tensor, torch.Tensor):
+        raise RuntimeError(
+            f"_dequantize_tensor expecting torch.Tensor as input but found {type(tensor)}"
+        )
+    if qtype == DQuantType.FP16:
+        if tensor.dtype != torch.float16:
+            raise RuntimeError(
+                f"tensor dtype is {tensor.dtype} while expected to be FP16."
+            )
+        elif quant_loss is None:
+            return tensor.float()
+        else:
+            # Undo the optional scaling applied at quantization time.
+            return tensor.float() / quant_loss
+    else:
+        raise RuntimeError(f"Quantization type {qtype} is not supported")
+
+
+def _dequantize_tensor_list(tensor_list, qtype, quant_loss=None):
+    if not isinstance(tensor_list, list) or not all(
+        isinstance(p, torch.Tensor) for p in tensor_list
+    ):
+        raise RuntimeError(
+            f"_dequantize_tensor_list expecting list of torch.Tensor as input but found {type(tensor_list)}"
+        )
+    if qtype == DQuantType.FP16:
+        # Forward quant_loss so each tensor's dequantization can undo the scaling.
+        return [_dequantize_tensor(t, qtype, quant_loss=quant_loss) for t in tensor_list]
+    else:
+        raise RuntimeError(f"Quantization type {qtype} is not supported")
+
+
+def auto_quantize(func, qtype, quant_loss=None):
+    """
+    This is a prototype API that automatically quantizes the input tensors of a
+    collective op, passes through the other necessary arguments, and then
+    dequantizes the output.
+
+    Currently it only supports:
+        * the FP16 quantization method
+        * the all_gather and all_to_all collective ops
+
+    Args:
+        func (Callable): A function representing a collective operation.
+        qtype (DQuantType): The quantization method to apply.
+        quant_loss (float, optional): An optional scale that can be used to
+            improve accuracy during dequantization.
+
+    Returns:
+        (Callable): The same collective as ``func``, but with automatic
+        quantization/dequantization applied.
+    """
+    @functools.wraps(func)
+    def wrapper(*args, **kwargs):
+        group = kwargs.get('group', None)
+        async_op = kwargs.get('async_op', False)
+        if async_op is True:
+            raise RuntimeError('The async_op=True mode is not supported yet.')
+        if func == dist.all_gather:
+            tensors = args[0]
+            input_tensors = _quantize_tensor(args[1], qtype)
+            out_tensors = _quantize_tensor_list(tensors, qtype)
+            dist.all_gather(out_tensors, input_tensors, group=group, async_op=async_op)
+            # Dequantize back into the caller's output list in place.
+            for i, t in enumerate(_dequantize_tensor_list(out_tensors, qtype, quant_loss=quant_loss)):
+                tensors[i] = t
+        elif func == dist.all_to_all:
+            tensors = args[0]
+            input_tensors = _quantize_tensor_list(args[1], qtype)
+            out_tensors = _quantize_tensor_list(tensors, qtype)
+            dist.all_to_all(out_tensors, input_tensors, group=group, async_op=async_op)
+            # Dequantize back into the caller's output list in place.
+            for i, t in enumerate(_dequantize_tensor_list(out_tensors, qtype, quant_loss=quant_loss)):
+                tensors[i] = t
+        else:
+            raise RuntimeError(f"The collective op {func} is not supported yet")
+
+    return wrapper
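A usage sketch for auto_quantize (illustrative, not part of the patch; assumes a default process group has already been initialized, e.g. with the gloo backend, and uses the same import path as the tests below):

    import torch
    import torch.distributed as dist
    import torch.distributed.algorithms.quantization as quant
    from torch.distributed.algorithms.quantization import DQuantType

    def demo(rank, world_size):
        inp = torch.full((4,), float(rank))                  # FP32 input on this rank
        out = [torch.zeros(4) for _ in range(world_size)]    # FP32 output buffers
        allgather = quant.auto_quantize(dist.all_gather, DQuantType.FP16)
        allgather(out, inp)  # communicates in FP16, writes FP32 results into `out`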
import torch.distributed.algorithms.ddp_comm_hooks.powerSGD_hook as powerSGD
import torch.distributed.algorithms.model_averaging.averagers as averagers
import torch.distributed.algorithms.model_averaging.utils as model_averaging_utils
+import torch.distributed.algorithms.quantization as quant
import torch.nn as nn
import torch.nn.functional as F
from torch._utils_internal import TEST_MASTER_ADDR as MASTER_ADDR
from torch.distributed.algorithms.ddp_comm_hooks import (
quantization as quantization_hooks,
)
+from torch.distributed.algorithms.quantization import DQuantType
from torch.distributed.distributed_c10d import (
get_world_size,
_get_default_group,
    # ALL GATHER
    def _test_all_gather_helper(
-        self, group, group_id, rank, cuda=False, rank_to_GPU=None, dtype=torch.float
+        self, group, group_id, rank, cuda=False, rank_to_GPU=None, dtype=torch.float, qtype=None
    ):
        for dest in group:
            tensor = _build_tensor(dest + 1, rank, dtype=dtype)
            tensors = [_build_tensor(dest + 1, -1, dtype=dtype) for i in group]
+            if qtype is not None:
+                allgather = quant.auto_quantize(dist.all_gather, qtype, quant_loss=None)
+            else:
+                allgather = dist.all_gather
            if cuda:
                tensor = tensor.cuda(rank_to_GPU[rank][0])
                tensors = [t.cuda(rank_to_GPU[rank][0]) for t in tensors]
            self.call_dist_op(
                ":all_gather",
                False,
-                dist.all_gather,
+                allgather,
                tensors,
                tensor,
                group_id,
+                False,
                tensor_shapes=tensor_shapes,
            )
        group, group_id, rank = self._init_full_group_test()
        self._test_all_gather_helper(group, group_id, rank)

+    @sandcastle_skip_if(BACKEND == "nccl", "NCCL does not support CPU tensors")
+    @sandcastle_skip_if(BACKEND == "mpi", "all_gather_quantized does not support MPI")
+    def test_all_gather_quantized(self):
+        group, group_id, rank = self._init_global_test()
+        self._test_all_gather_helper(
+            group, group_id, rank, dtype=torch.float32, qtype=DQuantType.FP16
+        )
+
    def _run_all_gather_coalesced_and_verify(
        self, output_tensor_lists, input_tensors, expected_tensors, group_id
    ):
        cuda=False,
        rank_to_GPU=None,
        dtype=torch.float,
+        qtype=None,
    ):
        if group_id is not None:
            size = len(group)
                    t.cuda(rank_to_GPU[rank][0]) for t in expected_tensors
                ]
                out_tensors = [t.cuda(rank_to_GPU[rank][0]) for t in out_tensors]
-            dist.all_to_all(out_tensors, in_tensors, group=group_id)
+            if qtype is not None:
+                quantize_alltoall = quant.auto_quantize(dist.all_to_all, qtype, quant_loss=None)
+                quantize_alltoall(out_tensors, in_tensors, group=group_id)
+            else:
+                dist.all_to_all(out_tensors, in_tensors, group=group_id)
            for t1, t2 in zip(out_tensors, expected_tensors):
                self.assertEqual(t1, t2)
        self._barrier()
        group, group_id, rank = self._init_global_test()
        self._test_all_to_all_helper(group, group_id, rank)

+    @sandcastle_skip_if(BACKEND != "nccl", "Only NCCL supports all_to_all")
+    @skip_if_rocm
+    def test_all_to_all_quantized(self):
+        group, group_id, rank = self._init_global_test()
+        rank_to_GPU = self._init_multigpu_helper()
+        self._test_all_to_all_helper(
+            group,
+            group_id,
+            rank,
+            cuda=True,
+            rank_to_GPU=rank_to_GPU,
+            dtype=torch.float32,
+            qtype=DQuantType.FP16,
+        )
+
    @sandcastle_skip_if(BACKEND != "nccl", "Only NCCL supports CUDA all_to_all")
    @skip_if_rocm
    def test_all_to_all_cuda(self):