From 673641c2d6a27fa97ee05453d671853731a4c602 Mon Sep 17 00:00:00 2001 From: Benoit Steiner Date: Fri, 22 Dec 2017 12:37:14 -0800 Subject: [PATCH] Updated the virtual cluster to return the proper error code if the simulated peak memory usage exceeds the available memory. PiperOrigin-RevId: 179952918 --- tensorflow/core/grappler/clusters/BUILD | 2 ++ .../core/grappler/clusters/virtual_cluster.cc | 27 +++++++++++++++++++ .../grappler/clusters/virtual_cluster_test.cc | 24 ++++++++++++++++- .../core/grappler/costs/virtual_scheduler.cc | 11 ++++++++ .../core/grappler/costs/virtual_scheduler.h | 3 +++ 5 files changed, 66 insertions(+), 1 deletion(-) diff --git a/tensorflow/core/grappler/clusters/BUILD b/tensorflow/core/grappler/clusters/BUILD index e9ddb86a10..7f8527302e 100644 --- a/tensorflow/core/grappler/clusters/BUILD +++ b/tensorflow/core/grappler/clusters/BUILD @@ -78,6 +78,8 @@ tf_cc_test( srcs = ["virtual_cluster_test.cc"], deps = [ ":virtual_cluster", + "//tensorflow/cc:cc_ops", + "//tensorflow/cc:scope", "//tensorflow/core:protos_all_cc", "//tensorflow/core:test", "//tensorflow/core:test_main", diff --git a/tensorflow/core/grappler/clusters/virtual_cluster.cc b/tensorflow/core/grappler/clusters/virtual_cluster.cc index e1f5925f7e..b97e3d1db1 100644 --- a/tensorflow/core/grappler/clusters/virtual_cluster.cc +++ b/tensorflow/core/grappler/clusters/virtual_cluster.cc @@ -96,6 +96,33 @@ Status VirtualCluster::Run(const GraphDef& graph, if (metadata) { scheduler.Summary(metadata); } + + const std::unordered_map& device = GetDevices(); + std::unordered_map peak_mem_usage = + scheduler.GetPeakMemoryUsage(); + for (const auto& mem_usage : peak_mem_usage) { + const string& device_name = mem_usage.first; + auto it = device.find(device_name); + if (it == device.end()) { + // It's probably the fake send/recv device. Eventually we'll need to + // remove this fake device to ensure proper memory accounting for + // multi-device settings. + continue; + } + const DeviceProperties& dev = it->second; + if (dev.memory_size() <= 0) { + // Available device memory unknown + continue; + } + int64 peak_mem = mem_usage.second; + if (peak_mem >= dev.memory_size()) { + return errors::ResourceExhausted( + "Graph requires ", peak_mem, " bytes of memory on device ", + device_name, " to run ", " but device only has ", dev.memory_size(), + " available."); + } + } + return Status::OK(); } diff --git a/tensorflow/core/grappler/clusters/virtual_cluster_test.cc b/tensorflow/core/grappler/clusters/virtual_cluster_test.cc index ec21f5f426..fd925a6ce7 100644 --- a/tensorflow/core/grappler/clusters/virtual_cluster_test.cc +++ b/tensorflow/core/grappler/clusters/virtual_cluster_test.cc @@ -14,11 +14,14 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/core/grappler/clusters/virtual_cluster.h" +#include "tensorflow/cc/framework/scope.h" +#include "tensorflow/cc/ops/standard_ops.h" #include "tensorflow/core/framework/cost_graph.pb.h" #include "tensorflow/core/framework/step_stats.pb.h" #include "tensorflow/core/framework/tensor_shape.pb.h" #include "tensorflow/core/grappler/grappler_item.h" #include "tensorflow/core/grappler/inputs/trivial_test_graph_input_yielder.h" +#include "tensorflow/core/lib/core/error_codes.pb.h" #include "tensorflow/core/platform/test.h" namespace tensorflow { @@ -37,13 +40,17 @@ class VirtualClusterTest : public ::testing::Test { cpu_device.set_l1_cache_size(32 * 1024); cpu_device.set_l2_cache_size(256 * 1024); cpu_device.set_l3_cache_size(4 * 1024 * 1024); + cpu_device.set_memory_size(1024 * 1024); std::unordered_map devices; devices["/job:localhost/replica:0/task:0/cpu:0"] = cpu_device; cluster_.reset(new VirtualCluster(devices)); TF_CHECK_OK(cluster_->Provision()); } - void TearDown() override { cluster_.reset(); } + void TearDown() override { + TF_CHECK_OK(cluster_->Shutdown()); + cluster_.reset(); + } protected: std::unique_ptr cluster_; @@ -91,6 +98,21 @@ TEST_F(VirtualClusterTest, CostModel) { } } +TEST_F(VirtualClusterTest, OutOfMemory) { + tensorflow::Scope root = tensorflow::Scope::NewRootScope(); + // Create a large variable that can't fit in memory. + auto zero = ops::Variable(root.WithOpName("zero"), {1024, 1024}, DT_FLOAT); + auto identity = ops::Identity(root.WithOpName("i"), zero); + auto identity2 = ops::Identity(root.WithOpName("i2"), identity); + GrapplerItem item; + TF_CHECK_OK(root.ToGraphDef(&item.graph)); + item.fetch.push_back("i2"); + + TF_CHECK_OK(cluster_->Initialize(item)); + Status s = cluster_->Run(item.graph, item.feed, item.fetch, nullptr); + EXPECT_EQ(error::RESOURCE_EXHAUSTED, s.code()); +} + } // namespace } // namespace grappler } // namespace tensorflow diff --git a/tensorflow/core/grappler/costs/virtual_scheduler.cc b/tensorflow/core/grappler/costs/virtual_scheduler.cc index fce50e33d6..fb3bdedcc6 100644 --- a/tensorflow/core/grappler/costs/virtual_scheduler.cc +++ b/tensorflow/core/grappler/costs/virtual_scheduler.cc @@ -979,5 +979,16 @@ Costs VirtualScheduler::Summary(RunMetadata* metadata) { return Summary(); } +const std::unordered_map VirtualScheduler::GetPeakMemoryUsage() + const { + std::unordered_map result; + for (const auto& device : device_) { + const string& name = device.first; + const DeviceState& state = device.second; + result[name] = state.max_memory_usage; + } + return result; +} + } // end namespace grappler } // end namespace tensorflow diff --git a/tensorflow/core/grappler/costs/virtual_scheduler.h b/tensorflow/core/grappler/costs/virtual_scheduler.h index 74088780cb..df8ae5861a 100644 --- a/tensorflow/core/grappler/costs/virtual_scheduler.h +++ b/tensorflow/core/grappler/costs/virtual_scheduler.h @@ -261,6 +261,9 @@ class VirtualScheduler { // If metadata is nullptr, then just calls and return Summary(). Costs Summary(RunMetadata* metadata); + // Return per device peak memory usage. + const std::unordered_map GetPeakMemoryUsage() const; + protected: const std::unordered_map* GetDeviceStates() const { return &device_; -- 2.34.1