From 6a76ee04de5f10b76cc8f97cc254da43905d170b Mon Sep 17 00:00:00 2001
From: Marjan Fariborz
Date: Fri, 27 Aug 2021 12:45:01 -0700
Subject: [PATCH] Adding alltoall_single collective to collective quantization API (#63154)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/63154

The collective quantization API now supports all_gather, all_to_all, and all_to_all_single. The test is also included.

ghstack-source-id: 136856877

Test Plan: buck test mode/dev-nosan //caffe2/test/distributed/algorithms/quantization:DistQuantizationTests_nccl -- test_all_to_all_single_bfp16

Reviewed By: wanchaol

Differential Revision: D30255251

fbshipit-source-id: 856f4fa12de104689a03a0c8dc9e3ecfd41cad29
---
 .../algorithms/quantization/test_quantization.py | 61 ++++++++++++++++++++++
 .../algorithms/quantization/quantization.py      | 14 +++-
 2 files changed, 71 insertions(+), 4 deletions(-)

diff --git a/test/distributed/algorithms/quantization/test_quantization.py b/test/distributed/algorithms/quantization/test_quantization.py
index 505f805..e60539f 100644
--- a/test/distributed/algorithms/quantization/test_quantization.py
+++ b/test/distributed/algorithms/quantization/test_quantization.py
@@ -148,6 +148,46 @@ if BACKEND == "gloo" or BACKEND == "nccl":
                 dtype=torch.float32,
                 qtype=DQuantType.BFP16)
 
+        @requires_nccl()
+        @sandcastle_skip_if(BACKEND != "nccl", "Only nccl backend supports all_to_all_single_fp16")
+        @skip_if_lt_x_gpu(int(os.environ["WORLD_SIZE"]))
+        def test_all_to_all_single_fp16(self):
+            store = dist.FileStore(self.file_name, self.world_size)
+            dist.init_process_group(store=store, rank=self.rank, world_size=self.world_size, backend='nccl')
+            device = torch.device(f"cuda:{self.rank}")
+            group = list(range(0, self.world_size))
+            group_id = dist.new_group(range(self.world_size))
+            rank_to_GPU = self._init_multigpu_helper()
+            self._test_all_to_all_single(
+                group,
+                group_id,
+                self.rank,
+                cuda=True,
+                rank_to_GPU=rank_to_GPU,
+                dtype=torch.float32,
+                qtype=DQuantType.FP16
+            )
+
+        @requires_nccl()
+        @sandcastle_skip_if(BACKEND != "nccl", "Only nccl backend supports all_to_all_single_bfp16")
+        @skip_if_lt_x_gpu(int(os.environ["WORLD_SIZE"]))
+        def test_all_to_all_single_bfp16(self):
+            store = dist.FileStore(self.file_name, self.world_size)
+            dist.init_process_group(store=store, rank=self.rank, world_size=self.world_size, backend='nccl')
+            device = torch.device(f"cuda:{self.rank}")
+            group = list(range(0, self.world_size))
+            group_id = dist.new_group(range(self.world_size))
+            rank_to_GPU = self._init_multigpu_helper()
+            self._test_all_to_all_single(
+                group,
+                group_id,
+                self.rank,
+                cuda=True,
+                rank_to_GPU=rank_to_GPU,
+                dtype=torch.float32,
+                qtype=DQuantType.BFP16
+            )
+
         def _test_all_gather(
             self, group, group_id, rank, cuda=False, rank_to_GPU=None, dtype=torch.float, qtype=None):
             for dest in group:
@@ -203,5 +243,26 @@ if BACKEND == "gloo" or BACKEND == "nccl":
             for t1, t2 in zip(out_tensors, expected_tensors):
                 self.assertEqual(t1, t2)
 
+        def _test_all_to_all_single(
+            self, group, group_id, rank, cuda=False, rank_to_GPU=None, dtype=torch.float, qtype=DQuantType.FP16
+        ):
+            if group_id is not None:
+                size = len(group)
+                in_splits = [i + 1 for i in group]
+                out_splits = [rank + 1 for _ in group]
+                in_tensor = torch.ones([sum(in_splits), size], dtype=dtype) * rank
+                out_tensor = torch.ones([(rank + 1) * size, size], dtype=dtype)
+                expected_tensor = torch.cat(
+                    [torch.ones([rank + 1, size], dtype=dtype) * i for i in group]
+                )
+                if cuda:
+                    rank_to_GPU = rank_to_GPU[rank][0]
+                    in_tensor = in_tensor.cuda(rank_to_GPU)
+                    expected_tensor = expected_tensor.cuda(rank_to_GPU)
+                    out_tensor = out_tensor.cuda(rank_to_GPU)
+                quantize_alltoall_single = quant.auto_quantize(dist.all_to_all_single, qtype, quant_loss=None)
+                quantize_alltoall_single(out_tensor, in_tensor, out_splits=out_splits, in_splits=in_splits, group=group_id)
+                self.assertEqual(out_tensor, expected_tensor)
+
 if __name__ == "__main__":
     run_tests()
diff --git a/torch/distributed/algorithms/quantization/quantization.py b/torch/distributed/algorithms/quantization/quantization.py
index d58c58c..a5e9b46 100644
--- a/torch/distributed/algorithms/quantization/quantization.py
+++ b/torch/distributed/algorithms/quantization/quantization.py
@@ -90,18 +90,14 @@ def auto_quantize(func, qtype, quant_loss=None):
     """
     This is a prototype API that automatically quantize the input tensors, choose the precision
     types, and pass other necessary arguments and then dequantizes the output.
-
     Currently it only supports:
         . FP16 and BFP16 quantization method supported for gloo and nccl backends
         . all_gather, all_to_all collective ops
-
     Note: BFP16 only supports 2D tensors.
-
     Args:
         func (callable): A function representing collective operations.
         qtype (QuantType): Quantization method
        quant_loss (float, optional): This can be used to improve accuracy in the dequantization.
-
     Returns:
         (callable): the same collective as func but enables automatic quantization/dequantization.
     """
@@ -129,6 +125,16 @@ def auto_quantize(func, qtype, quant_loss=None):
             for i, t in enumerate(_dequantize_tensor_list(out_tensors, qtype, quant_loss=quant_loss)):
                 tensors[i] = t
+        elif (func == dist.all_to_all_single):
+            tensors = args[0]
+            out_splits = kwargs.get('out_splits', None)
+            in_splits = kwargs.get('in_splits', None)
+            # Quantizing the input/output tensor
+            input_tensors = _quantize_tensor(args[1], qtype)
+            out_tensors = _quantize_tensor(tensors, qtype)
+            dist.all_to_all_single(out_tensors, input_tensors, out_splits, in_splits, group=group)
+            for i, t in enumerate(_dequantize_tensor(out_tensors, qtype, quant_loss=quant_loss)):
+                tensors[i] = t
         else:
             raise RuntimeError(
                 f"The collective op {func} is not supported yet"
             )
-- 
2.7.4
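
Editor's note (not part of the patch): a minimal usage sketch of the quantized all_to_all_single path this patch adds. It assumes one process per GPU with RANK, WORLD_SIZE, MASTER_ADDR, and MASTER_PORT set by a launcher such as torchrun, and that auto_quantize and DQuantType are importable from torch.distributed.algorithms.quantization.quantization as suggested by the test above; the tensor shape and equal splits are illustrative only.

    import torch
    import torch.distributed as dist
    import torch.distributed.algorithms.quantization.quantization as quant
    from torch.distributed.algorithms.quantization.quantization import DQuantType


    def main():
        # Assumes RANK / WORLD_SIZE / MASTER_ADDR / MASTER_PORT come from the launcher.
        dist.init_process_group(backend="nccl")
        rank = dist.get_rank()
        world_size = dist.get_world_size()
        torch.cuda.set_device(rank)
        group = dist.new_group(range(world_size))

        # One row per destination rank; float32 buffers, quantized to FP16 on the wire.
        in_tensor = torch.full((world_size, 4), float(rank), dtype=torch.float32, device="cuda")
        out_tensor = torch.empty((world_size, 4), dtype=torch.float32, device="cuda")
        in_splits = [1] * world_size   # send one row to every rank
        out_splits = [1] * world_size  # receive one row from every rank

        # Wrap the collective: the input is quantized before communication and the
        # received buffer is dequantized back into out_tensor afterwards.
        quantized_alltoall_single = quant.auto_quantize(
            dist.all_to_all_single, DQuantType.FP16, quant_loss=None
        )
        quantized_alltoall_single(
            out_tensor, in_tensor, out_splits=out_splits, in_splits=in_splits, group=group
        )

        # Row i of out_tensor now holds the row sent by rank i.
        print(f"rank {rank}: received {out_tensor[:, 0].tolist()}")
        dist.destroy_process_group()


    if __name__ == "__main__":
        main()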