From f656b7f3e07fc3a6a51cb6083d27abebcc6212bb Mon Sep 17 00:00:00 2001 From: Benoit Steiner Date: Tue, 27 Mar 2018 19:24:40 -0700 Subject: [PATCH] Fixed the interaction between virtual cluster and measuring cost estimator. PiperOrigin-RevId: 190712404 --- .../grappler/costs/measuring_cost_estimator.cc | 23 +++++++++++++++++++--- tensorflow/core/grappler/costs/utils.cc | 4 ++-- tensorflow/python/grappler/cluster_test.py | 16 +++++++++------ 3 files changed, 32 insertions(+), 11 deletions(-) diff --git a/tensorflow/core/grappler/costs/measuring_cost_estimator.cc b/tensorflow/core/grappler/costs/measuring_cost_estimator.cc index ea43206..833205a 100644 --- a/tensorflow/core/grappler/costs/measuring_cost_estimator.cc +++ b/tensorflow/core/grappler/costs/measuring_cost_estimator.cc @@ -18,6 +18,7 @@ limitations under the License. #include #include "tensorflow/core/framework/cost_graph.pb.h" +#include "tensorflow/core/framework/step_stats.pb.h" #include "tensorflow/core/grappler/clusters/cluster.h" #include "tensorflow/core/grappler/costs/robust_stats.h" #include "tensorflow/core/grappler/grappler_item.h" @@ -52,6 +53,8 @@ Status MeasuringCostEstimator::Initialize(const GrapplerItem& item) { Status MeasuringCostEstimator::PredictCosts(const GraphDef& optimized_graph, CostGraphDef* cost_graph, Costs* costs) const { + const bool running_simulation = (cluster_->type() == "virtual"); + std::vector times(measurement_steps_); BlockingCounter barrier(measurement_steps_); @@ -80,9 +83,23 @@ Status MeasuringCostEstimator::PredictCosts(const GraphDef& optimized_graph, } const Costs::MicroSeconds finish = Env::Default()->NowMicros(); - const double time = (finish - start).count() * 1e3; - times[step] = time; - + if (running_simulation) { + // When running simulation, return the estimated runtime, not the time it + // takes to run the simulation. + double time = 0.0; + for (const DeviceStepStats& stepstats : + metadata.step_stats().dev_stats()) { + for (const NodeExecStats& node_stats : stepstats.node_stats()) { + const double completion_time = + node_stats.all_end_rel_micros() + node_stats.all_start_micros(); + time = std::max(time, completion_time * 1e3); + } + } + times[step] = time; + } else { + const double time = (finish - start).count() * 1e3; + times[step] = time; + } if (cost_graph && (step + 1 == measurement_steps_)) { metadata.mutable_cost_graph()->Swap(cost_graph); } diff --git a/tensorflow/core/grappler/costs/utils.cc b/tensorflow/core/grappler/costs/utils.cc index 076945d..f318e39 100644 --- a/tensorflow/core/grappler/costs/utils.cc +++ b/tensorflow/core/grappler/costs/utils.cc @@ -212,8 +212,8 @@ DeviceProperties GetDeviceInfo(const string& device_str) { CudaGpuId cuda_gpu_id; Status s = GpuIdManager::TfToCudaGpuId(tf_gpu_id, &cuda_gpu_id); if (!s.ok()) { - LOG(ERROR) << s; - return unknown; + // We are probably running simulation without linking cuda libraries. + cuda_gpu_id = CudaGpuId(parsed.id); } return GetLocalGPUInfo(cuda_gpu_id); } else if (parsed.type == "CPU") { diff --git a/tensorflow/python/grappler/cluster_test.py b/tensorflow/python/grappler/cluster_test.py index a3c4c2b..26c6f22 100644 --- a/tensorflow/python/grappler/cluster_test.py +++ b/tensorflow/python/grappler/cluster_test.py @@ -87,9 +87,10 @@ class ClusterTest(test.TestCase): def testVirtualCluster(self): with ops.Graph().as_default() as g: - a = random_ops.random_uniform(shape=()) - b = random_ops.random_uniform(shape=()) - c = a + b + with ops.device('/device:GPU:0'): + a = random_ops.random_uniform(shape=[1024, 1024]) + b = random_ops.random_uniform(shape=[1024, 1024]) + c = a + b train_op = ops.get_collection_ref(ops.GraphKeys.TRAIN_OP) train_op.append(c) mg = meta_graph.create_meta_graph_def(graph=g) @@ -102,10 +103,13 @@ class ClusterTest(test.TestCase): 'architecture': '7' }) named_device = device_properties_pb2.NamedDevice( - properties=device_properties, name='/GPU:0') - grappler_cluster = cluster.Cluster(devices=[named_device]) + properties=device_properties, name='/device:GPU:0') + grappler_cluster = cluster.Cluster( + disable_detailed_stats=False, + disable_timeline=False, + devices=[named_device]) op_perfs, run_time, _ = grappler_cluster.MeasureCosts(grappler_item) - self.assertGreater(run_time, 0) + self.assertEqual(run_time, 0.000545) self.assertEqual(len(op_perfs), 15) estimated_perf = grappler_cluster.EstimatePerformance(named_device) -- 2.7.4