#include <c10d/default_comm_hooks.hpp>
+#include <c10/core/ScalarType.h>
+#include <c10/util/Exception.h>
#include <c10d/ProcessGroup.hpp>
#include <c10d/comm.hpp>
"ProcessGroup::allreduce should return TensorList");
auto reduce_tensor = result.toTensorVector()[0];
+ TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
+ reduce_tensor.scalar_type() == at::ScalarType::Half,
+ "Expected reduced tensor to be fp16 in FP16CompressHook, but got type ",
+ reduce_tensor.scalar_type()
+ );
decompressed_tensor.copy_(reduce_tensor);
return c10::IValue(decompressed_tensor);
};