Expose MaybeGetMinimumShape for use in cost estimators other than OpLevelCostEstimator.
authorA. Unique TensorFlower <gardener@tensorflow.org>
Fri, 11 May 2018 22:07:48 +0000 (15:07 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Fri, 11 May 2018 22:19:32 +0000 (15:19 -0700)
PiperOrigin-RevId: 196315239

tensorflow/core/grappler/costs/op_level_cost_estimator.cc
tensorflow/core/grappler/costs/op_level_cost_estimator.h

index fbdd311..b8e3375 100644 (file)
@@ -129,33 +129,6 @@ int64 GetOutputSize(const int64 input, const int64 filter, const int64 stride,
   }
 }
 
-// Return a minimum shape if the shape is unknown. If known, return the original
-// shape.
-TensorShapeProto MaybeGetMinimumShape(const TensorShapeProto& original_shape,
-                                      int rank, bool* found_unknown_shapes) {
-  auto shape = original_shape;
-  if (shape.unknown_rank() || shape.dim_size() < rank) {
-    *found_unknown_shapes = true;
-    TensorShapeProto::Dim dim;
-    VLOG(2) << "Use minimum shape because the rank is unknown.";
-    // The size of each dimension is at least 1, if unknown.
-    dim.set_size(1);
-    for (int i = 0; i < rank; i++) {
-      *shape.add_dim() = dim;
-    }
-  } else {
-    for (int i = 0; i < shape.dim_size(); i++) {
-      if (shape.dim(i).size() < 0) {
-        *found_unknown_shapes = true;
-        VLOG(2) << "Use minimum dim size 1 because the shape is unknown.";
-        // The size of each dimension is at least 1, if unknown.
-        shape.mutable_dim(i)->set_size(1);
-      }
-    }
-  }
-  return shape;
-}
-
 // Return the output element count of a binary element-wise op considering
 // broadcasting.
 int64 CwiseOutputElementCount(const TensorShapeProto& input_shape_1,
@@ -187,6 +160,33 @@ int64 CwiseOutputElementCount(const TensorShapeProto& input_shape_1,
 
 }  // namespace
 
+// Return a minimum shape if the shape is unknown. If known, return the original
+// shape.
+TensorShapeProto MaybeGetMinimumShape(const TensorShapeProto& original_shape,
+                                      int rank, bool* found_unknown_shapes) {
+  auto shape = original_shape;
+  if (shape.unknown_rank() || shape.dim_size() < rank) {
+    *found_unknown_shapes = true;
+    TensorShapeProto::Dim dim;
+    VLOG(2) << "Use minimum shape because the rank is unknown.";
+    // The size of each dimension is at least 1, if unknown.
+    dim.set_size(1);
+    for (int i = 0; i < rank; i++) {
+      *shape.add_dim() = dim;
+    }
+  } else {
+    for (int i = 0; i < shape.dim_size(); i++) {
+      if (shape.dim(i).size() < 0) {
+        *found_unknown_shapes = true;
+        VLOG(2) << "Use minimum dim size 1 because the shape is unknown.";
+        // The size of each dimension is at least 1, if unknown.
+        shape.mutable_dim(i)->set_size(1);
+      }
+    }
+  }
+  return shape;
+}
+
 OpLevelCostEstimator::OpLevelCostEstimator() {
   // Syntactic sugar to build and return a lambda that takes an OpInfo and
   // returns a cost.
index 35649f7..d384f57 100644 (file)
@@ -30,6 +30,8 @@ namespace grappler {
 
 bool GetTensorShapeProtoFromTensorProto(const TensorProto& tensor_proto,
                                         TensorShapeProto* tensor_shape_proto);
+TensorShapeProto MaybeGetMinimumShape(const TensorShapeProto& original_shape,
+                                      int rank, bool* found_unknown_shapes);
 
 class OpLevelCostEstimator {
  public: