Add new helpers to HLO sharding.
authorA. Unique TensorFlower <gardener@tensorflow.org>
Mon, 19 Mar 2018 12:01:05 +0000 (05:01 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Mon, 19 Mar 2018 12:05:36 +0000 (05:05 -0700)
PiperOrigin-RevId: 189569053

tensorflow/compiler/xla/service/hlo_sharding.h

index e715dff..3827323 100644 (file)
@@ -173,7 +173,7 @@ class HloSharding {
 
   bool operator==(const HloSharding& other) const {
     return replicated_ == other.replicated_ && maximal_ == other.maximal_ &&
-           protobuf_util::ProtobufEquals(tile_shape_, other.tile_shape_) &&
+           ShapeUtil::Compatible(tile_shape_, other.tile_shape_) &&
            tile_assignment_ == other.tile_assignment_ &&
            tuple_elements_ == other.tuple_elements_;
   }
@@ -207,6 +207,13 @@ class HloSharding {
   // REQUIRES: !IsReplicated() && !IsTuple()
   const Array<int64>& tile_assignment() const { return tile_assignment_; }
 
+  // Returns the flattened list of all the leaf shardings in a tuple shape, by
+  // pre-order walk (ShapeTree iterator order).
+  // REQUIRES: IsTuple().
+  const std::vector<HloSharding>& tuple_elements() const {
+    return tuple_elements_;
+  }
+
   // Return a new sharding that can apply to the given new shape.
   // If this sharding is tile-maximal, the returned sharding will be the same as
   // this sharding. If this sharding is not tile-maximal, the returned