np.testing.assert_allclose(hook_grads, reference_grads, rtol=1e-5, atol=1e-4)
+ @requires_nccl()
+ @skip_if_lt_x_gpu(2)
+ def test_is_last_hook(self):  # Checks that bucket.is_last() is True only for the final bucket handed to the comm hook.
+
+ store = dist.FileStore(self.file_name, self.world_size)
+ process_group = dist.ProcessGroupNCCL(store, self.rank, self.world_size)
+
+ def hook(flags, bucket):  # comm hook; `flags` is presumably the list registered as hook state below -- confirm against register_comm_hook
+ flags.append(bucket.is_last())  # record, in call order, what each bucket reports
+ fut = torch.futures.Future()
+ fut.set_result(bucket.buffer())  # resolve the future right away with the bucket's own buffer
+ return fut
+
+ flags = []
+ device_id = gpus_for_rank(self.world_size)[self.rank][0]  # first entry of this rank's slot in the gpus_for_rank result
+ model = nn.Sequential(  # 11 bias-free Linear layers; presumably sized so gradients span several buckets -- TODO confirm
+ nn.Linear(2, 4000, bias=False),
+ *[nn.Linear(4000, 4000, bias=False) for _ in range(10)]
+ )
+ gpu_model = DistributedDataParallel(
+ model.to(device_id),
+ device_ids=[device_id],
+ process_group=process_group,
+ )
+ gpu_model.register_comm_hook(state=flags, hook=hook)  # the same `flags` list is inspected by the assertions below
+ input = torch.randn(10, 2)  # NOTE(review): shadows the builtin `input`
+ gpu_model(input).sum().backward()  # single forward/backward pass; presumably invokes the hook once per bucket -- confirm
+ self.assertTrue(flags[-1])  # the last recorded bucket must report is_last(); raises IndexError if the hook never ran
+ self.assertFalse(any(flags[:-1]))  # no earlier bucket may report is_last()
+
if __name__ == "__main__":
assert (
public:
explicit GradBucket(
size_t index,
+ size_t bucket_count,
const at::Tensor& tensor,
const std::vector<size_t>& offsets,
const std::vector<size_t>& lengths,
const std::vector<c10::IntArrayRef>& sizes_vec,
const std::vector<at::Tensor>& parameters)
: index_(index),
+ bucket_count_(bucket_count),
buffer_(tensor),
offsets_(offsets),
lengths_(lengths),
// Returns whether this bucket is the last bucket to allreduce in an iteration.
// The replacement line treats the bucket whose index_ equals
// bucket_count_ - 1 as the last one (the removed line used index_ == 0).
// NOTE(review): bucket_count_ is a size_t, so `bucket_count_ - 1` wraps
// around if the count is ever 0 -- presumably callers always pass a
// non-zero bucket count; confirm.
bool isLast() const {
- return index_ == 0;
+ return index_ == bucket_count_ - 1;
}
private:
size_t index_;
+ size_t bucket_count_;
at::Tensor buffer_;
// Per-variable info in buffer_.
auto variables_for_bucket = get_variables_for_bucket(i, bucket);
gradBuckets.emplace_back(
i,
+ buckets_.size(),
return_zero_tensors ? at::zeros_like(bucket.replicas[0].contents)
: bucket.replicas[0].contents,
bucket.replicas[0].offsets,
auto variables_for_bucket = get_variables_for_bucket(next_bucket_, bucket);
GradBucket grad_bucket(
next_bucket_,
+ buckets_.size(),
tensors[0],
// Since we only support single-process single-device
// mode, there is always only one replica in the bucket.