#include "tensorflow/core/grappler/costs/graph_memory.h"
#include <list>
#include "tensorflow/core/framework/allocation_description.pb.h"
+#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/step_stats.pb.h"
#include "tensorflow/core/framework/tensor_description.pb.h"
NodeMap node_map(&item_.graph);
for (const auto& dev_stats : timeline.dev_stats()) {
+ const string& device_name = dev_stats.device();
+ const bool is_gpu = (device_name.find("GPU:") != string::npos ||
+                      device_name.find("gpu:") != string::npos);
std::list<LiveTensor>& device_tensors =
live_tensors_per_device[dev_stats.device()];
for (const auto& node_stats : dev_stats.node_stats()) {
// graph (e.g. _Send/_Recv nodes).
continue;
}
- for (const string& input : node->input()) {
+ std::unordered_set<int> swapped_inputs;
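+ // Collect the input ports that are marked for swapping to host memory
+ // through the "_swap_to_host" attribute.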
+ if (is_gpu) {
+ auto it = node->attr().find("_swap_to_host");
+ if (it != node->attr().end()) {
+ const AttrValue& val = it->second;
+ for (int port_id : val.list().i()) {
+ swapped_inputs.insert(port_id);
+ }
+ }
+ }
+ for (int i = 0; i < node->input_size(); ++i) {
+ if (swapped_inputs.find(i) != swapped_inputs.end()) {
+ // The memory of swapped inputs will be released as early as possible:
+ // therefore ignore this input when determining the deallocation time
+ // of the tensor.
+ continue;
+ }
+ const string& input = node->input(i);
int position;
string input_node = ParseNodeName(input, &position);
if (position < 0) {
EXPECT_EQ(gpu_expected, gpu_tensors);
}
+TEST_F(GraphMemoryTest, GpuSwapping) {
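+  // Build a trivial test graph of Square/AddN stages with 1024*1024-element
+  // tensors placed on a single GPU.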
+ TrivialTestGraphInputYielder fake_input(4, 2, 1024 * 1024, false, {"/GPU:0"});
+ GrapplerItem item;
+ CHECK(fake_input.NextItem(&item));
+ item.feed.clear();
+
+ {
+ // Estimate the max memory usage for the graph.
+ GraphMemory memory(item);
+ Status s = memory.InferStatically(devices_);
+ TF_CHECK_OK(s);
+
+ const GraphMemory::MemoryUsage& gpu_mem =
+ memory.GetPeakMemoryUsage("/GPU:0");
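+    // The peak consists of 5 live tensors of 1024*1024 4-byte elements,
+    // i.e. 5 * 4MiB = 20971520 bytes.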
+ EXPECT_EQ(20971520, gpu_mem.used_memory);
+ std::set<string> gpu_tensors;
+ for (const auto& t : gpu_mem.live_tensors) {
+ gpu_tensors.insert(strings::StrCat(t.node, ":", t.output_id));
+ }
+ std::set<string> gpu_expected;
+ gpu_expected.insert("Square:0");
+ gpu_expected.insert("Square_1:0");
+ gpu_expected.insert("AddN:0");
+ gpu_expected.insert("AddN_1:0");
+ gpu_expected.insert("AddN_2:0");
+ EXPECT_EQ(gpu_expected, gpu_tensors);
+ }
+
+ {
+ // Swap the first input to node AddN_1: its fanin (the square nodes) should
+ // not appear in the max cut anymore.
+ for (auto& node : *item.graph.mutable_node()) {
+ if (node.name() == "AddN_1") {
+ (*node.mutable_attr())["_swap_to_host"].mutable_list()->add_i(0);
+ }
+ }
+ GraphMemory memory(item);
+ Status s = memory.InferStatically(devices_);
+ TF_CHECK_OK(s);
+ const GraphMemory::MemoryUsage& new_gpu_mem =
+ memory.GetPeakMemoryUsage("/GPU:0");
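+    // The peak is still 5 tensors of 4MiB, but the swapped Square outputs no
+    // longer appear in the max cut: it now consists solely of AddN outputs.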
+ EXPECT_EQ(20971520, new_gpu_mem.used_memory);
+ std::set<string> new_gpu_tensors;
+ for (const auto& t : new_gpu_mem.live_tensors) {
+ new_gpu_tensors.insert(strings::StrCat(t.node, ":", t.output_id));
+ }
+ std::set<string> new_gpu_expected;
+ new_gpu_expected.insert("AddN:0");
+ new_gpu_expected.insert("AddN_1:0");
+ new_gpu_expected.insert("AddN_2:0");
+ new_gpu_expected.insert("AddN_3:0");
+ new_gpu_expected.insert("AddN_4:0");
+ EXPECT_EQ(new_gpu_expected, new_gpu_tensors);
+ }
+}
+
TEST_F(GraphMemoryTest, CtrlDependencies) {
// Build a simple graph with a control dependency.
Scope s = Scope::NewRootScope();
GraphDef CreateGraphDef(int num_stages, int width, int tensor_size,
bool use_multiple_devices, bool insert_queue,
const std::vector<string>& device_names) {
- CHECK_GE(device_names.size(), width);
-
using namespace ::tensorflow::ops; // NOLINT(build/namespaces)
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
std::vector<Output> this_stage;
for (int j = 0; j < width; j++) {
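+    // Cycle through the available devices: this makes it possible to build
+    // the graph with fewer devices than the stage width.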
if (last_stage.size() == 1) {
- Output unary_op =
- Square(s.WithDevice(device_names[use_multiple_devices ? j : 0]),
- last_stage[0]);
+ Output unary_op = Square(
+ s.WithDevice(
+ device_names[use_multiple_devices ? j % device_names.size()
+ : 0]),
+ last_stage[0]);
this_stage.push_back(unary_op);
} else {
Output combine =
- AddN(s.WithDevice(device_names[use_multiple_devices ? j : 0]),
+ AddN(s.WithDevice(
+ device_names[use_multiple_devices ? j % device_names.size()
+ : 0]),
last_stage);
this_stage.push_back(combine);
}