}
}
-// Return a minimum shape if the shape is unknown. If known, return the original
-// shape.
-TensorShapeProto MaybeGetMinimumShape(const TensorShapeProto& original_shape,
- int rank, bool* found_unknown_shapes) {
- auto shape = original_shape;
- if (shape.unknown_rank() || shape.dim_size() < rank) {
- *found_unknown_shapes = true;
- TensorShapeProto::Dim dim;
- VLOG(2) << "Use minimum shape because the rank is unknown.";
- // The size of each dimension is at least 1, if unknown.
- dim.set_size(1);
- for (int i = 0; i < rank; i++) {
- *shape.add_dim() = dim;
- }
- } else {
- for (int i = 0; i < shape.dim_size(); i++) {
- if (shape.dim(i).size() < 0) {
- *found_unknown_shapes = true;
- VLOG(2) << "Use minimum dim size 1 because the shape is unknown.";
- // The size of each dimension is at least 1, if unknown.
- shape.mutable_dim(i)->set_size(1);
- }
- }
- }
- return shape;
-}
-
// Return the output element count of a binary element-wise op considering
// broadcasting.
int64 CwiseOutputElementCount(const TensorShapeProto& input_shape_1,
} // namespace
+// Return a minimum shape if the shape is unknown. If known, return the original
+// shape.
+TensorShapeProto MaybeGetMinimumShape(const TensorShapeProto& original_shape,
+ int rank, bool* found_unknown_shapes) {
+ auto shape = original_shape;
+ if (shape.unknown_rank() || shape.dim_size() < rank) {
+ *found_unknown_shapes = true;
+ TensorShapeProto::Dim dim;
+ VLOG(2) << "Use minimum shape because the rank is unknown.";
+ // The size of each dimension is at least 1, if unknown.
+ dim.set_size(1);
+ for (int i = 0; i < rank; i++) {
+ *shape.add_dim() = dim;
+ }
+ } else {
+ for (int i = 0; i < shape.dim_size(); i++) {
+ if (shape.dim(i).size() < 0) {
+ *found_unknown_shapes = true;
+ VLOG(2) << "Use minimum dim size 1 because the shape is unknown.";
+ // The size of each dimension is at least 1, if unknown.
+ shape.mutable_dim(i)->set_size(1);
+ }
+ }
+ }
+ return shape;
+}
+
OpLevelCostEstimator::OpLevelCostEstimator() {
// Syntactic sugar to build and return a lambda that takes an OpInfo and
// returns a cost.