with self.assertRaisesRegex(ValueError, 'shard_offsets should be >=0'):
ShardMetadata(shard_offsets=[-1, 0], shard_lengths=[1, 1], placement="cuda:0")
- with self.assertRaisesRegex(ValueError, 'shard_lengths should be > 0'):
- ShardMetadata(shard_offsets=[0, 0], shard_lengths=[0, 1], placement="cuda:0")
+ with self.assertRaisesRegex(ValueError, 'shard_lengths should be >= 0'):
+ ShardMetadata(shard_offsets=[0, 0], shard_lengths=[-1, 1], placement="cuda:0")
with self.assertRaisesRegex(ValueError, 'Empty shard list provided'):
EnumerableShardingSpec([])
for i in range(len(self.shard_offsets)):
if self.shard_offsets[i] < 0:
raise ValueError('shard_offsets should be >=0')
- if self.shard_lengths[i] <= 0:
- raise ValueError('shard_lengths should be > 0')
+ if self.shard_lengths[i] < 0:
+ raise ValueError('shard_lengths should be >= 0')