bool operator==(const HloSharding& other) const {
return replicated_ == other.replicated_ && maximal_ == other.maximal_ &&
- protobuf_util::ProtobufEquals(tile_shape_, other.tile_shape_) &&
+ ShapeUtil::Compatible(tile_shape_, other.tile_shape_) &&
tile_assignment_ == other.tile_assignment_ &&
tuple_elements_ == other.tuple_elements_;
}
// REQUIRES: !IsReplicated() && !IsTuple()
const Array<int64>& tile_assignment() const { return tile_assignment_; }
+ // Returns the flattened list of all the leaf shardings in a tuple shape, by
+ // pre-order walk (ShapeTree iterator order).
+ // REQUIRES: IsTuple().
+ const std::vector<HloSharding>& tuple_elements() const {
+ return tuple_elements_;
+ }
+
// Return a new sharding that can apply to the given new shape.
// If this sharding is tile-maximal, the returned sharding will be the same as
// this sharding. If this sharding is not tile-maximal, the returned