From 9dcf033873007b48033b38b428af45abdef97ee7 Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower" <gardener@tensorflow.org>
Date: Tue, 13 Mar 2018 11:34:23 -0700
Subject: [PATCH] Add TransformShardedTileShape helper method to HloSharding

It transforms an existing sharding to be compatible with a new shape,
with an optional transform function to adjust the tile sizes of the
sharded dimensions.

PiperOrigin-RevId: 188903257
---
 tensorflow/compiler/xla/service/hlo_sharding.cc  | 26 ++++++++++++++++++++++
 tensorflow/compiler/xla/service/hlo_sharding.h   | 13 +++++++++++
 .../compiler/xla/service/hlo_sharding_test.cc    | 13 +++++++++++
 3 files changed, 52 insertions(+)

diff --git a/tensorflow/compiler/xla/service/hlo_sharding.cc b/tensorflow/compiler/xla/service/hlo_sharding.cc
index afe79c9..aa9ff89 100644
--- a/tensorflow/compiler/xla/service/hlo_sharding.cc
+++ b/tensorflow/compiler/xla/service/hlo_sharding.cc
@@ -348,4 +348,30 @@ OpSharding HloSharding::ToProto() const {
   return result;
 }
 
+HloSharding HloSharding::TransformShardedTileShape(
+    const Shape& new_shape,
+    const std::function<int64(int64, int64)>& transform) const {
+  CHECK(!IsTuple());
+  if (IsTileMaximal()) {
+    return *this;
+  }
+  CHECK_EQ(ShapeUtil::Rank(new_shape), ShapeUtil::Rank(tile_shape()));
+  Shape new_tile_shape;
+  new_tile_shape.set_element_type(tile_shape().element_type());
+  for (int64 i = 0; i < ShapeUtil::Rank(new_shape); ++i) {
+    int64 dim;
+    if (tile_assignment().dim(i) == 1) {
+      dim = new_shape.dimensions(i);
+    } else if (transform) {
+      dim = transform(i, tile_shape().dimensions(i));
+    } else {
+      dim = tile_shape().dimensions(i);
+    }
+    new_tile_shape.add_dimensions(dim);
+  }
+  TF_CHECK_OK(
+      LayoutUtil::CopyLayoutBetweenShapes(tile_shape_, &new_tile_shape));
+  return HloSharding::Tile(new_tile_shape, tile_assignment());
+}
+
 }  // namespace xla
diff --git a/tensorflow/compiler/xla/service/hlo_sharding.h b/tensorflow/compiler/xla/service/hlo_sharding.h
index 7263198..e715dff 100644
--- a/tensorflow/compiler/xla/service/hlo_sharding.h
+++ b/tensorflow/compiler/xla/service/hlo_sharding.h
@@ -207,6 +207,19 @@ class HloSharding {
   // REQUIRES: !IsReplicated() && !IsTuple()
   const Array<int64>& tile_assignment() const { return tile_assignment_; }
 
+  // Returns a new sharding that can apply to the given new shape.
+  // If this sharding is tile-maximal, the returned sharding will be the same as
+  // this sharding. If this sharding is not tile-maximal, the returned
+  // sharding's tile size will differ:
+  //   - Non-sharded dimensions will be adapted to be the same as `new_shape`;
+  //     tile_dimension(i) = new_shape.dimensions(i);
+  //   - Sharded dimensions will be kept the same unless `transform` is supplied,
+  //     in which case tile_dimension(i) = transform(i, tile_dimension(i));
+  // REQUIRES: !IsTuple().
+  HloSharding TransformShardedTileShape(
+      const Shape& new_shape,
+      const std::function<int64(int64, int64)>& transform = nullptr) const;
+
  private:
   HloSharding()
       : replicated_(true),
diff --git a/tensorflow/compiler/xla/service/hlo_sharding_test.cc b/tensorflow/compiler/xla/service/hlo_sharding_test.cc
index 0c7487b..07fc468 100644
--- a/tensorflow/compiler/xla/service/hlo_sharding_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_sharding_test.cc
@@ -269,5 +269,18 @@ TEST_F(HloShardingTest, Hash) {
   }
 }
 
+TEST_F(HloShardingTest, TransformShardedTileShapeTest) {
+  HloSharding sharding =
+      HloSharding::Tile(ShapeUtil::MakeShape(F32, {3, 5, 7, 11}),
+                        Array4D<int64>({{{{0, 1}, {2, 3}}}}));
+  HloSharding result = sharding.TransformShardedTileShape(
+      ShapeUtil::MakeShape(F32, {13, 15, 17, 19}),
+      [](int dim, int value) { return dim * 111; });
+  HloSharding expected =
+      HloSharding::Tile(ShapeUtil::MakeShape(F32, {13, 15, 222, 333}),
+                        Array4D<int64>({{{{0, 1}, {2, 3}}}}));
+  EXPECT_EQ(result, expected);
+}
+
 }  // namespace
 }  // namespace xla
-- 
2.7.4
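
Usage sketch (editor's note, not part of the commit): the helper re-applies
an existing tiled sharding to a differently shaped array. The snippet below
is a minimal, hypothetical example against the API added above; the shapes,
the `ExampleResharding` function name, and the doubling transform are
illustrative assumptions, not code from this patch.

    #include "tensorflow/compiler/xla/array2d.h"
    #include "tensorflow/compiler/xla/service/hlo_sharding.h"
    #include "tensorflow/compiler/xla/shape_util.h"

    namespace xla {

    // Hypothetical example: a f32[4,8] array tiled across two devices
    // along dimension 1, so each device holds a f32[4,4] tile.
    HloSharding ExampleResharding() {
      HloSharding sharding = HloSharding::Tile(
          ShapeUtil::MakeShape(F32, {4, 4}), Array2D<int64>({{0, 1}}));

      // Adapt the sharding to a f32[6,16] shape. Dimension 0 is not
      // sharded (tile_assignment().dim(0) == 1), so its tile size follows
      // the new shape (4 -> 6). Dimension 1 is sharded, so `transform`
      // receives the old tile size (4) and doubles it to 8, preserving
      // the two-way split of the widened dimension. Result: f32[6,8] tiles.
      return sharding.TransformShardedTileShape(
          ShapeUtil::MakeShape(F32, {6, 16}),
          [](int64 dim, int64 old_tile_size) { return old_tile_size * 2; });
    }

    }  // namespace xla

Omitting `transform` would keep the sharded tile dimension at 4, matching
the default `nullptr` argument in the header.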