From 6e8008294b6ed502123feadca93a2968f76b94a8 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Fri, 6 Apr 2018 11:16:17 -0700 Subject: [PATCH] TPU Cost Estimator has been modified to also account for the memory cost in the execution time. Until more sophisticated methods are added, we resort to the roofline model to calculate such cost. PiperOrigin-RevId: 191913626 --- tensorflow/core/grappler/costs/op_level_cost_estimator.cc | 15 ++++++++++----- tensorflow/core/grappler/costs/op_level_cost_estimator.h | 5 +++++ 2 files changed, 15 insertions(+), 5 deletions(-) diff --git a/tensorflow/core/grappler/costs/op_level_cost_estimator.cc b/tensorflow/core/grappler/costs/op_level_cost_estimator.cc index 14e46ec..79735e6 100644 --- a/tensorflow/core/grappler/costs/op_level_cost_estimator.cc +++ b/tensorflow/core/grappler/costs/op_level_cost_estimator.cc @@ -459,11 +459,7 @@ Costs OpLevelCostEstimator::PredictOpCountBasedCost( Costs costs; costs.compute_time = compute_cost; costs.memory_time = memory_cost; - if (compute_memory_overlap_) { - costs.execution_time = std::max(compute_cost, memory_cost); - } else { - costs.execution_time = compute_cost + memory_cost; - } + CombineCostsAndUpdateExecutionTime(&costs); return costs; } @@ -1375,5 +1371,14 @@ Costs OpLevelCostEstimator::PredictFusedBatchNormGrad( return costs; } +void OpLevelCostEstimator::CombineCostsAndUpdateExecutionTime( + Costs* costs) const { + if (compute_memory_overlap_) { + costs->execution_time = std::max(costs->compute_time, costs->memory_time); + } else { + costs->execution_time = costs->compute_time + costs->memory_time; + } +} + } // end namespace grappler } // end namespace tensorflow diff --git a/tensorflow/core/grappler/costs/op_level_cost_estimator.h b/tensorflow/core/grappler/costs/op_level_cost_estimator.h index fcbecbb..7080264 100644 --- a/tensorflow/core/grappler/costs/op_level_cost_estimator.h +++ b/tensorflow/core/grappler/costs/op_level_cost_estimator.h @@ -173,6 +173,11 @@ class 
OpLevelCostEstimator { const TensorShapeProto& original_image_shape, const OpInfo& op_info, bool* found_unknown_shapes); + // This method calculates the execution time depending on whether IO can + // overlap with computation. It assumes the memory and the compute times have + // already been calculated. + void CombineCostsAndUpdateExecutionTime(Costs* costs) const; + protected: std::map<string, int> elementwise_ops_; typedef std::function<Costs(const OpContext& op_context)> CostImpl; -- 2.7.4