TPU Cost Estimator has been modified to also account for the memory cost in the execu...
authorA. Unique TensorFlower <gardener@tensorflow.org>
Fri, 6 Apr 2018 18:16:17 +0000 (11:16 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Fri, 6 Apr 2018 18:18:52 +0000 (11:18 -0700)
PiperOrigin-RevId: 191913626

tensorflow/core/grappler/costs/op_level_cost_estimator.cc
tensorflow/core/grappler/costs/op_level_cost_estimator.h

index 14e46ec..79735e6 100644 (file)
@@ -459,11 +459,7 @@ Costs OpLevelCostEstimator::PredictOpCountBasedCost(
   Costs costs;
   costs.compute_time = compute_cost;
   costs.memory_time = memory_cost;
-  if (compute_memory_overlap_) {
-    costs.execution_time = std::max(compute_cost, memory_cost);
-  } else {
-    costs.execution_time = compute_cost + memory_cost;
-  }
+  CombineCostsAndUpdateExecutionTime(&costs);
   return costs;
 }
 
@@ -1375,5 +1371,14 @@ Costs OpLevelCostEstimator::PredictFusedBatchNormGrad(
   return costs;
 }
 
+void OpLevelCostEstimator::CombineCostsAndUpdateExecutionTime(
+    Costs* costs) const {
+  if (compute_memory_overlap_) {
+    costs->execution_time = std::max(costs->compute_time, costs->memory_time);
+  } else {
+    costs->execution_time = costs->compute_time + costs->memory_time;
+  }
+}
+
 }  // end namespace grappler
 }  // end namespace tensorflow
index fcbecbb..7080264 100644 (file)
@@ -173,6 +173,11 @@ class OpLevelCostEstimator {
       const TensorShapeProto& original_image_shape, const OpInfo& op_info,
       bool* found_unknown_shapes);
 
+  // This method calculates the execution time depending on whether IO can
+  // overlap with computation. It assumes the memory and the compute times have
+  // already been calculated.
+  void CombineCostsAndUpdateExecutionTime(Costs* costs) const;
+
  protected:
   std::map<string, int> elementwise_ops_;
   typedef std::function<Costs(const OpContext& op_context)> CostImpl;