Add TransformShardedTileShape helper method to HloSharding
authorA. Unique TensorFlower <gardener@tensorflow.org>
Tue, 13 Mar 2018 18:34:23 +0000 (11:34 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Tue, 13 Mar 2018 18:41:36 +0000 (11:41 -0700)
It transforms an existing sharding to be compatible with a new shape
with an optional transform method to adjust the tile size for the
sharded dimensions.

PiperOrigin-RevId: 188903257

tensorflow/compiler/xla/service/hlo_sharding.cc
tensorflow/compiler/xla/service/hlo_sharding.h
tensorflow/compiler/xla/service/hlo_sharding_test.cc

index afe79c9..aa9ff89 100644 (file)
@@ -348,4 +348,30 @@ OpSharding HloSharding::ToProto() const {
   return result;
 }
 
+HloSharding HloSharding::TransformShardedTileShape(
+    const Shape& new_shape,
+    const std::function<int64(int64, int64)>& transform) const {
+  CHECK(!IsTuple());
+  if (IsTileMaximal()) {
+    return *this;
+  }
+  CHECK_EQ(ShapeUtil::Rank(new_shape), ShapeUtil::Rank(tile_shape()));
+  Shape new_tile_shape;
+  new_tile_shape.set_element_type(tile_shape().element_type());
+  for (int64 i = 0; i < ShapeUtil::Rank(new_shape); ++i) {
+    int64 dim;
+    if (tile_assignment().dim(i) == 1) {
+      dim = new_shape.dimensions(i);
+    } else if (transform) {
+      dim = transform(i, tile_shape().dimensions(i));
+    } else {
+      dim = tile_shape().dimensions(i);
+    }
+    new_tile_shape.add_dimensions(dim);
+  }
+  TF_CHECK_OK(
+      LayoutUtil::CopyLayoutBetweenShapes(tile_shape_, &new_tile_shape));
+  return HloSharding::Tile(new_tile_shape, tile_assignment());
+}
+
 }  // namespace xla
index 7263198..e715dff 100644 (file)
@@ -207,6 +207,19 @@ class HloSharding {
   // REQUIRES: !IsReplicated() && !IsTuple()
   const Array<int64>& tile_assignment() const { return tile_assignment_; }
 
+  // Return a new sharding that can apply to the given new shape.
+  // If this sharding is tile-maximal, the returned sharding will be the same as
+  // this sharding. If this sharding is not tile-maximal, the returned
+  // sharding's tile size will differ:
+  //   - Non-sharded dimensions will be adapted to be the same as `new_shape`;
+  //     tile_dimension(i) = new_shape.dimensions(i);
+  //   - Sharded dimensions will be kept the same unless `transform` is supplied
+  //     in which case tile_dimension(i) = transform(i, tile_dimension(i));
+  // REQUIRES: !IsTuple().
+  HloSharding TransformShardedTileShape(
+      const Shape& new_shape,
+      const std::function<int64(int64, int64)>& transform = nullptr) const;
+
  private:
   HloSharding()
       : replicated_(true),
index 0c7487b..07fc468 100644 (file)
@@ -269,5 +269,18 @@ TEST_F(HloShardingTest, Hash) {
   }
 }
 
+TEST_F(HloShardingTest, TransformShardedTileShapeTest) {
+  HloSharding sharding =
+      HloSharding::Tile(ShapeUtil::MakeShape(F32, {3, 5, 7, 11}),
+                        Array4D<int64>({{{{0, 1}, {2, 3}}}}));
+  HloSharding result = sharding.TransformShardedTileShape(
+      ShapeUtil::MakeShape(F32, {13, 15, 17, 19}),
+      [](int dim, int value) { return dim * 111; });
+  HloSharding expected =
+      HloSharding::Tile(ShapeUtil::MakeShape(F32, {13, 15, 222, 333}),
+                        Array4D<int64>({{{{0, 1}, {2, 3}}}}));
+  EXPECT_EQ(result, expected);
+}
+
 }  // namespace
 }  // namespace xla