return result;
}
+HloSharding HloSharding::TransformShardedTileShape(
+ const Shape& new_shape,
+ const std::function<int64(int64, int64)>& transform) const {
+ CHECK(!IsTuple());
+ if (IsTileMaximal()) {
+ return *this;
+ }
+ CHECK_EQ(ShapeUtil::Rank(new_shape), ShapeUtil::Rank(tile_shape()));
+ Shape new_tile_shape;
+ new_tile_shape.set_element_type(tile_shape().element_type());
+ for (int64 i = 0; i < ShapeUtil::Rank(new_shape); ++i) {
+ int64 dim;
+ if (tile_assignment().dim(i) == 1) {
+ dim = new_shape.dimensions(i);
+ } else if (transform) {
+ dim = transform(i, tile_shape().dimensions(i));
+ } else {
+ dim = tile_shape().dimensions(i);
+ }
+ new_tile_shape.add_dimensions(dim);
+ }
+ TF_CHECK_OK(
+ LayoutUtil::CopyLayoutBetweenShapes(tile_shape_, &new_tile_shape));
+ return HloSharding::Tile(new_tile_shape, tile_assignment());
+}
+
} // namespace xla
// REQUIRES: !IsReplicated() && !IsTuple()
const Array<int64>& tile_assignment() const { return tile_assignment_; }
+ // Return a new sharding that can apply to the given new shape.
+ // If this sharding is tile-maximal, the returned sharding will be the same as
+ // this sharding. If this sharding is not tile-maximal, the returned
+ // sharding's tile size will differ:
+ // - Non-sharded dimensions will be adapted to be the same as `new_shape`;
+ // tile_dimension(i) = new_shape.dimensions(i);
+ // - Sharded dimensions will be kept the same unless `transform` is supplied
+ // in which case tile_dimension(i) = transform(i, tile_dimension(i));
+ // REQUIRES: !IsTuple().
+ HloSharding TransformShardedTileShape(
+ const Shape& new_shape,
+ const std::function<int64(int64, int64)>& transform = nullptr) const;
+
private:
HloSharding()
: replicated_(true),
}
}
+TEST_F(HloShardingTest, TransformShardedTileShapeTest) {
+ HloSharding sharding =
+ HloSharding::Tile(ShapeUtil::MakeShape(F32, {3, 5, 7, 11}),
+ Array4D<int64>({{{{0, 1}, {2, 3}}}}));
+ HloSharding result = sharding.TransformShardedTileShape(
+ ShapeUtil::MakeShape(F32, {13, 15, 17, 19}),
+ [](int dim, int value) { return dim * 111; });
+ HloSharding expected =
+ HloSharding::Tile(ShapeUtil::MakeShape(F32, {13, 15, 222, 333}),
+ Array4D<int64>({{{{0, 1}, {2, 3}}}}));
+ EXPECT_EQ(result, expected);
+}
+
} // namespace
} // namespace xla