c10::intrusive_ptr<c10::ivalue::Future> FP16CompressCommHook::runHook(
GradBucket& bucket) {
- auto& tensor = bucket.getBufferRef();
- tensor.copy_(tensor.to(torch::kFloat16));
- std::vector<at::Tensor> tensors = {tensor};
+
+ auto compressed_tensor = bucket.getBufferRef().to(torch::kFloat16);
// Apply the division first to avoid overflow.
- tensors[0] /= state_->getSize();
+ compressed_tensor /= state_->getSize();
+ std::vector<at::Tensor> tensors = {compressed_tensor};
auto allreduce_fut = state_->allreduce(tensors)->getFuture();
- auto decompress = [](c10::ivalue::Future& allreduce_fut) {
+ auto decompressed_tensor = bucket.getBufferRef();
+ auto decompress = [decompressed_tensor](c10::ivalue::Future& allreduce_fut) {
auto result = allreduce_fut.value();
TORCH_INTERNAL_ASSERT(
result.isTensorList(),
"ProcessGroup::allreduce should return TensorList");
+
auto reduce_tensor = result.toTensorVector()[0];
- reduce_tensor.copy_(reduce_tensor.to(torch::kFloat));
- return c10::IValue(reduce_tensor);
+ decompressed_tensor.copy_(reduce_tensor);
+ return c10::IValue(decompressed_tensor);
};
return allreduce_fut->then(decompress, allreduce_fut->elementType());