return nullptr;
}
-static bool IsSwappable(GraphView::OutputPort output) {
+static bool IsSwappable(const GraphView& graph, GraphView::OutputPort output) {
+// Returns true if the tensor produced at `output` is a candidate for being
+// swapped out to host memory: it must not be persistent, must have a known,
+// non-reference output type, and (for same-device forwarding ops) its fanin
+// must itself be swappable.
const NodeDef& node = *output.node;
+ // There is no point in swapping out persistent tensors, since the tensor will
+ // continue to use memory.
if (IsPersistent(node)) {
return false;
}
+// NOTE(review): `op_def`'s declaration (presumably `const OpDef*`) lies in
+// context elided from this hunk. Unregistered ops are conservatively treated
+// as unswappable.
if (!OpRegistry::Global()->LookUpOpDef(node.op(), &op_def).ok()) {
return false;
}
-
DataType dtype;
if (!OutputTypeForNode(node, *op_def, output.port_id, &dtype).ok()) {
return false;
}
+ // References can only refer to persistent memory: therefore the node isn't
+ // swappable.
+ if (IsRefType(dtype)) {
+ return false;
+ }
- return !IsRefType(dtype);
+ if (output.node->op() == "Identity" || output.node->op() == "Reshape") {
+ // If placed on the same device, these nodes are just forwarding references
+ // to their input. Therefore they are swappable iff their fanin is swappable
+ // or it resides on a different device.
+ GraphView::InputPort input;
+ input.node = output.node;
+ input.port_id = 0;
+ GraphView::OutputPort fanin = graph.GetRegularFanin(input);
+ // NOTE(review): the recursion below assumes the graph is a DAG; a cycle of
+ // same-device Identity/Reshape nodes would recurse without bound —
+ // presumably guaranteed by graph construction upstream, TODO confirm.
+ if (fanin.node->device() == node.device()) {
+ return IsSwappable(graph, fanin);
+ }
+ }
+ return true;
}
static NodeDef* FindSwapOutTrigger(
// Don't bother with small tensors.
continue;
}
- // Don't try to swap out persistent data
GraphView::OutputPort port =
graph.GetOutputPort(live_tensor.node, live_tensor.output_id);
- if (!IsSwappable(port)) {
+ if (!IsSwappable(graph, port)) {
continue;
}
Costs::NanoSeconds execution_time(-1);
TEST_F(MemoryOptimizerTest, SwappingHeuristics) {
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
- Output a = ops::Variable(s.WithOpName("a").WithDevice("/gpu:0"),
+ Output v = ops::Variable(s.WithOpName("v").WithDevice("/gpu:0"),
{128, 128, 8}, DT_FLOAT);
- Output b = ops::Identity(s.WithOpName("b").WithDevice("/gpu:0"), {a});
- Output c = ops::Identity(s.WithOpName("c").WithDevice("/gpu:0"), {a});
- Output d = ops::Identity(s.WithOpName("d").WithDevice("/gpu:0"), {a});
+ Output a = ops::Identity(s.WithOpName("a").WithDevice("/gpu:0"), v);
+ Output b = ops::Square(s.WithOpName("b").WithDevice("/gpu:0"), v);
+ Output c = ops::Sqrt(s.WithOpName("c").WithDevice("/gpu:0"), a);
+ Output d = ops::Identity(s.WithOpName("d").WithDevice("/gpu:0"), b);
Output axis = ops::Const(s.WithOpName("axis"), 0);
Output e =
- ops::Concat(s.WithOpName("e").WithDevice("/gpu:0"), {b, c, d}, axis);
+ ops::Concat(s.WithOpName("e").WithDevice("/gpu:0"), {a, b, c, d}, axis);
GrapplerItem item;
TF_CHECK_OK(s.ToGraphDef(&item.graph));
for (int64 input_id : val.list().i()) {
inputs_to_swap.insert(input_id);
}
- EXPECT_EQ(std::set<int>({0, 1, 2}), inputs_to_swap);
+ EXPECT_EQ(std::set<int>({1, 2, 3}), inputs_to_swap);
}
}
}
TEST_F(MemoryOptimizerTest, UnswappableInputs) {
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
- Output a = ops::Variable(s.WithOpName("a").WithDevice("/gpu:0"),
+ Output v = ops::Variable(s.WithOpName("v").WithDevice("/gpu:0"),
{128, 128, 8}, DT_FLOAT);
+ Output a = ops::Square(s.WithOpName("a").WithDevice("/gpu:0"), v);
Output b = ops::Identity(s.WithOpName("b").WithDevice("/gpu:0"), {a});
Output c = ops::Identity(s.WithOpName("c").WithDevice("/gpu:0"), {a});
Output index = ops::Const(s.WithOpName("index"), {0});
Output indices = ops::Tile(s.WithOpName("indices"), index, {128});
Output d =
- ops::ScatterAdd(s.WithOpName("d").WithDevice("/gpu:0"), a, indices, c);
+ ops::ScatterAdd(s.WithOpName("d").WithDevice("/gpu:0"), v, indices, c);
Output axis = ops::Const(s.WithOpName("axis"), 0);
Output e =
ops::Concat(s.WithOpName("e").WithDevice("/gpu:0"), {b, c, d}, axis);