", input_shape=", ShapeUtil::HumanString(shape));
}
- // The tile shape must not be the same as the input shape without maximal_
- // also set. If this is the case, we're not actually sharded and the correct
- // constructor should have been used.
- if (ShapeUtil::Equal(shape, tile_shape_)) {
+ // The correct constructor have to be used to create tile maximal shardings.
+ if (tile_assignment_.num_elements() == 1) {
return tensorflow::errors::InvalidArgument(
- "Tile shape is the same as the input shape. If a replicated sharding "
- "was intended, use HloSharding::Replicated(). If a device placement "
- "was intended, use HloSharding::AssignDevice()");
+ "Tile assignment only contains a single device. If a replicated "
+ "sharding was intended, use HloSharding::Replicated(). If a device "
+ "placement was intended, use HloSharding::AssignDevice()");
}
- // The tile shape must not be greater than the input shape in any dimension.
- for (int64 i = 0, e = ShapeUtil::Rank(shape); i != e; ++i) {
- auto tile_dim = tile_shape_.dimensions(i);
- auto shape_dim = shape.dimensions(i);
- if (tile_dim > shape_dim) {
- return tensorflow::errors::InvalidArgument(
- StrCat("Tile is larger than input shape (dimension ", i, ", ",
- tile_dim, " > ", shape_dim));
- }
- }
-
- // The tile assignment tensor must be exactly dimensioned to ceil(shape[dim]
- // tile[dim]) for every dimension contained within tile.
+ // The tile assignment tensor must contain enough element to cover the full
+ // shape with tiles of the specified size.
for (int64 i = 0, e = tile_assignment_.dimensions().size(); i != e; ++i) {
- int64 expected_dim =
- CeilOfRatio(shape.dimensions(i), tile_shape_.dimensions(i));
- if (tile_assignment_.dimensions()[i] != expected_dim) {
+ int64 total_tile_size = tile_assignment_.dim(i) * tile_shape_.dimensions(i);
+ if (shape.dimensions(i) > total_tile_size) {
return tensorflow::errors::InvalidArgument(
- StrCat("Tile assignment tensor has incorrect shape. Dimension ", i,
- " expected ", expected_dim, " but got ",
- tile_assignment_.dimensions()[i]));
+ StrCat("Tile assignment tensor has too few element to cover the full "
+ "shape. Dimension ",
+ i, ", shape ", shape.dimensions(i), ", total size ",
+ total_tile_size));
}
}
}
{
- // Test should pass.
+ // Test should fail because of more devices used then `num_device`.
Shape tile_shape = ShapeUtil::MakeShape(U32, {2, 3});
HloSharding sharding =
HloSharding::Tile(tile_shape, MakeArray({2, 2}, {0, 1, 2, 3}));
}
{
- // Test should fail due to the tile being larger than the input space.
- Shape tile_shape = ShapeUtil::MakeShape(U32, {2, 3});
- HloSharding sharding =
- HloSharding::Tile(tile_shape, MakeArray({2, 2}, {0, 1, 2, 3}));
- EXPECT_IS_NOT_OK(sharding.Validate(ShapeUtil::MakeShape(F32, {2, 2}),
- /*num_devices=*/4));
- }
-
- {
- // Test should fail due to the tile not dividing the input space into 4
- // sections (even with padding).
+ // Test should fail because the total tiled size in dimension 0 is 4 but we
+ // have 6 elements along that dimensions.
Shape tile_shape = ShapeUtil::MakeShape(U32, {2, 3});
HloSharding sharding =
HloSharding::Tile(tile_shape, MakeArray({2, 2}, {0, 1, 2, 3}));