[DDP][Grad compression] Fix fp16 cpp hook (#63375)
authorRohan Varma <rvarm1@fb.com>
Wed, 18 Aug 2021 18:38:11 +0000 (11:38 -0700)
committerFacebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Wed, 18 Aug 2021 18:49:35 +0000 (11:49 -0700)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/63375

`tensor.copy_(tensor.to(torch::kFloat16));` leaves the bucket buffer as
float32: `copy_` converts the source back to the destination tensor's dtype,
so the fp16 compression never actually takes effect (confirmed by the
before/after logs below).

Tested by adding the following line:

```
LOG(INFO) << "Type is: " << compressed_tensor.scalar_type();
```

before:

```
I0816 17:03:09.823688 364141 default_comm_hooks.cpp:21] Type is: Float
```
after:

```
I0816 17:01:16.779052 353924 default_comm_hooks.cpp:21] Type is: Half
```
ghstack-source-id: 136056092

Test Plan: ci

Reviewed By: SciPioneer

Differential Revision: D30356256

fbshipit-source-id: 8208a705acd7628541cd43c8bf61d007dfdd2435

torch/csrc/distributed/c10d/default_comm_hooks.cpp

index 9d13099c424c655ddbce8908db9d8ab9aebd8f1c..91700baa2e4a584754e3eb15eeb51c774a8d30c6 100644 (file)
@@ -16,21 +16,23 @@ c10::intrusive_ptr<c10::ivalue::Future> AllReduceCommHook::runHook(
 
 c10::intrusive_ptr<c10::ivalue::Future> FP16CompressCommHook::runHook(
     GradBucket& bucket) {
-  auto& tensor = bucket.getBufferRef();
-  tensor.copy_(tensor.to(torch::kFloat16));
-  std::vector<at::Tensor> tensors = {tensor};
+
+  auto compressed_tensor = bucket.getBufferRef().to(torch::kFloat16);
   // Apply the division first to avoid overflow.
-  tensors[0] /= state_->getSize();
+  compressed_tensor /= state_->getSize();
+  std::vector<at::Tensor> tensors = {compressed_tensor};
 
   auto allreduce_fut = state_->allreduce(tensors)->getFuture();
-  auto decompress = [](c10::ivalue::Future& allreduce_fut) {
+  auto decompressed_tensor = bucket.getBufferRef();
+  auto decompress = [decompressed_tensor](c10::ivalue::Future& allreduce_fut) {
     auto result = allreduce_fut.value();
     TORCH_INTERNAL_ASSERT(
         result.isTensorList(),
         "ProcessGroup::allreduce should return TensorList");
+
     auto reduce_tensor = result.toTensorVector()[0];
-    reduce_tensor.copy_(reduce_tensor.to(torch::kFloat));
-    return c10::IValue(reduce_tensor);
+    decompressed_tensor.copy_(reduce_tensor);
+    return c10::IValue(decompressed_tensor);
   };
 
   return allreduce_fut->then(decompress, allreduce_fut->elementType());