bucket_size_limits.push_back(first_bucket_bytes_cap_);
bucket_size_limits.push_back(bucket_bytes_cap_);
std::vector<size_t> per_bucket_size_limits;
+ auto ddp_set_last_bucket_as_small =
+ (parse_env("DDP_SET_LAST_BUCKET_CAP").compare("1") == 0);
+
+ if (ddp_set_last_bucket_as_small) {
+ // Reverse so that first_bucket_bytes_cap_ (smaller bucket) becomes the last
+ // bucket. We cannot simply pass in {bucket_bytes_cap_, first_bucket_bytes_cap_}
+ // as the bucket order, since we would immediately advance to the 2nd element
+ // after the first bucket, whereas we only want the last bucket to have
+ // a smaller size.
+ std::reverse(rebuilt_params_.begin(), rebuilt_params_.end());
+ std::reverse(rebuilt_param_indices_.begin(), rebuilt_param_indices_.end());
+ }
std::tie(rebuilt_bucket_indices, per_bucket_size_limits) =
compute_bucket_assignment_by_size(
rebuilt_params_,
expect_sparse_gradients_[0],
rebuilt_param_indices_);
+ if (ddp_set_last_bucket_as_small) {
+ // Reverse again because buckets were rebuilt in the opposite of gradient
+ // ready order.
+ std::reverse(rebuilt_bucket_indices.begin(), rebuilt_bucket_indices.end());
+ std::reverse(per_bucket_size_limits.begin(), per_bucket_size_limits.end());
+ }
+
if (ddp_debug_level_ != c10d::DistributedDebugLevel::OFF) {
TORCH_INTERNAL_ASSERT(
rebuilt_bucket_indices.size() == per_bucket_size_limits.size())
initialize_buckets(
std::move(rebuilt_bucket_indices), std::move(per_bucket_size_limits));
+
return true;
}
torch.cuda.set_device(self.rank)
default_bucket_cap_mb = 25 * (1024 ** 2)
first_bucket_bytes_mb = dist._DEFAULT_FIRST_BUCKET_BYTES
+ os.environ["DDP_SET_LAST_BUCKET_CAP"] = "1"
class MyModel(nn.Module):
def __init__(self):
device_ids=[self.rank]
)
inp = torch.randn(10, 2)
+ rebuilt_bucket_index = 2
for i in range(6):
out = ddp(inp).sum()
out.backward()
logging_data = ddp._get_ddp_logging_data()
- if i < 2:
- bucket_size_limits = [
- int(b) for b in logging_data["initial_bucket_size_limits"].split(", ")
- ]
- # first_bucket_bytes is actually the last because we reverse
- # parameter bucket order.
- self.assertEqual(bucket_size_limits[-1], first_bucket_bytes_mb)
- for j, bucket_size in enumerate(bucket_size_limits):
- if j != len(bucket_size_limits) - 1:
- self.assertEqual(bucket_size, default_bucket_cap_mb)
- else:
- bucket_size_limits = [
- int(b) for b in logging_data["rebuilt_bucket_size_limits"].split(", ")
- ]
- # TODO: rebuild buckets places first bucket at beginning, but
- # might be better to move it to end.
- self.assertEqual(
- bucket_size_limits[0], first_bucket_bytes_mb
- )
- for j, bucket_size in enumerate(bucket_size_limits):
- if j != 0:
- self.assertEqual(bucket_size, default_bucket_cap_mb)
+ bucket_size_limits = [
+ int(b) for b in logging_data[
+ "{}_bucket_size_limits".format(
+ "initial" if i < rebuilt_bucket_index else "rebuilt"
+ )
+ ].split(", ")
+ ]
+ # first_bucket_bytes is actually the last because we reverse
+ # parameter bucket order under the DDP_SET_LAST_BUCKET_CAP flag.
+ self.assertEqual(bucket_size_limits[-1], first_bucket_bytes_mb)
+ for j, bucket_size in enumerate(bucket_size_limits):
+ if j != len(bucket_size_limits) - 1:
+ self.assertEqual(bucket_size, default_bucket_cap_mb)
@skip_if_lt_x_gpu(2)
@sandcastle_skip_if(