Costs costs;
costs.compute_time = compute_cost;
costs.memory_time = memory_cost;
- if (compute_memory_overlap_) {
- costs.execution_time = std::max(compute_cost, memory_cost);
- } else {
- costs.execution_time = compute_cost + memory_cost;
- }
+ CombineCostsAndUpdateExecutionTime(&costs);
return costs;
}
return costs;
}
+void OpLevelCostEstimator::CombineCostsAndUpdateExecutionTime(
+ Costs* costs) const {
+ if (compute_memory_overlap_) {
+ costs->execution_time = std::max(costs->compute_time, costs->memory_time);
+ } else {
+ costs->execution_time = costs->compute_time + costs->memory_time;
+ }
+}
+
} // end namespace grappler
} // end namespace tensorflow
const TensorShapeProto& original_image_shape, const OpInfo& op_info,
bool* found_unknown_shapes);
+ // This method calculates the execution time depending on whether IO can
+ // overlap with computation. It assumes the memory and the compute times have
+ // already been calculated.
+ void CombineCostsAndUpdateExecutionTime(Costs* costs) const;
+
protected:
std::map<string, int> elementwise_ops_;
typedef std::function<Costs(const OpContext& op_context)> CostImpl;