Softens the requirements in the HLO sharding validation
authorA. Unique TensorFlower <gardener@tensorflow.org>
Tue, 24 Apr 2018 14:38:49 +0000 (07:38 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Tue, 24 Apr 2018 14:41:40 +0000 (07:41 -0700)
The goal is to support tiled shardings where the last N tile have no data.

PiperOrigin-RevId: 194085302

tensorflow/compiler/xla/service/hlo_sharding.cc
tensorflow/compiler/xla/service/hlo_sharding_test.cc

index 1b42349..994de44 100644 (file)
@@ -256,37 +256,24 @@ Status HloSharding::ValidateNonTuple(const Shape& shape,
         ", input_shape=", ShapeUtil::HumanString(shape));
   }
 
-  // The tile shape must not be the same as the input shape without maximal_
-  // also set. If this is the case, we're not actually sharded and the correct
-  // constructor should have been used.
-  if (ShapeUtil::Equal(shape, tile_shape_)) {
+  // The correct constructor have to be used to create tile maximal shardings.
+  if (tile_assignment_.num_elements() == 1) {
     return tensorflow::errors::InvalidArgument(
-        "Tile shape is the same as the input shape. If a replicated sharding "
-        "was intended, use HloSharding::Replicated(). If a device placement "
-        "was intended, use HloSharding::AssignDevice()");
+        "Tile assignment only contains a single device. If a replicated "
+        "sharding was intended, use HloSharding::Replicated(). If a device "
+        "placement was intended, use HloSharding::AssignDevice()");
   }
 
-  // The tile shape must not be greater than the input shape in any dimension.
-  for (int64 i = 0, e = ShapeUtil::Rank(shape); i != e; ++i) {
-    auto tile_dim = tile_shape_.dimensions(i);
-    auto shape_dim = shape.dimensions(i);
-    if (tile_dim > shape_dim) {
-      return tensorflow::errors::InvalidArgument(
-          StrCat("Tile is larger than input shape (dimension ", i, ", ",
-                 tile_dim, " > ", shape_dim));
-    }
-  }
-
-  // The tile assignment tensor must be exactly dimensioned to ceil(shape[dim]
-  // tile[dim]) for every dimension contained within tile.
+  // The tile assignment tensor must contain enough element to cover the full
+  // shape with tiles of the specified size.
   for (int64 i = 0, e = tile_assignment_.dimensions().size(); i != e; ++i) {
-    int64 expected_dim =
-        CeilOfRatio(shape.dimensions(i), tile_shape_.dimensions(i));
-    if (tile_assignment_.dimensions()[i] != expected_dim) {
+    int64 total_tile_size = tile_assignment_.dim(i) * tile_shape_.dimensions(i);
+    if (shape.dimensions(i) > total_tile_size) {
       return tensorflow::errors::InvalidArgument(
-          StrCat("Tile assignment tensor has incorrect shape. Dimension ", i,
-                 " expected ", expected_dim, " but got ",
-                 tile_assignment_.dimensions()[i]));
+          StrCat("Tile assignment tensor has too few element to cover the full "
+                 "shape. Dimension ",
+                 i, ", shape ", shape.dimensions(i), ", total size ",
+                 total_tile_size));
     }
   }
 
index 69ea423..3bf0d25 100644 (file)
@@ -88,7 +88,7 @@ TEST_F(HloShardingTest, Tile) {
   }
 
   {
-    // Test should pass.
+    // Test should fail because of more devices used then `num_device`.
     Shape tile_shape = ShapeUtil::MakeShape(U32, {2, 3});
     HloSharding sharding =
         HloSharding::Tile(tile_shape, MakeArray({2, 2}, {0, 1, 2, 3}));
@@ -97,17 +97,8 @@ TEST_F(HloShardingTest, Tile) {
   }
 
   {
-    // Test should fail due to the tile being larger than the input space.
-    Shape tile_shape = ShapeUtil::MakeShape(U32, {2, 3});
-    HloSharding sharding =
-        HloSharding::Tile(tile_shape, MakeArray({2, 2}, {0, 1, 2, 3}));
-    EXPECT_IS_NOT_OK(sharding.Validate(ShapeUtil::MakeShape(F32, {2, 2}),
-                                       /*num_devices=*/4));
-  }
-
-  {
-    // Test should fail due to the tile not dividing the input space into 4
-    // sections (even with padding).
+    // Test should fail because the total tiled size in dimension 0 is 4 but we
+    // have 6 elements along that dimensions.
     Shape tile_shape = ShapeUtil::MakeShape(U32, {2, 3});
     HloSharding sharding =
         HloSharding::Tile(tile_shape, MakeArray({2, 2}, {0, 1, 2, 3}));