int output_pos = 0;
string input_node_name = ParseNodeName(node->input(0), &output_pos);
const NodeDef* input = node_map_->GetNode(input_node_name);
- if (input->op() == "Reshape") {
+ if (input->op() == "Reshape" && !HasControlInputs(*input)) {
reshape->set_input(0, input->input(0));
node_map_->UpdateInput(reshape->name(), input->name(), input->input(0));
nodes_to_simplify->PushBack(reshape);
return reshape->name();
}
- // If the reshape is a no-op, forward its input to its consumers. This is
- // considered aggressive, because users may state that the placeholder
- // outputs tensors of shape [M, N] while feeding it with tensors of shape
- // [M*N] (or worse). The reshape nodes are then necessary to update the
- // tensor metadata to the required shape.
- if (ReshapeIsIdentity(*reshape, *input, output_pos, *graph_properties_)) {
+ // If the reshape is a no-op, forward its input to its consumers, unless it
+ // anchors a control dependency since we want to make sure that control
+ // dependency is triggered.
+ if (ReshapeIsIdentity(*reshape, *input, output_pos, *graph_properties_) &&
+ !HasControlInputs(*reshape)) {
return reshape->input(0);
}
}
return num_outputs;
}
+bool HasControlInputs(const NodeDef& node) {
+ int num_inputs = node.input_size();
+ if (num_inputs > 0 && IsControlInput(node.input(num_inputs - 1))) {
+ return true;
+ }
+ return false;
+}
+
int NumNonControlInputs(const NodeDef& node) {
int num_inputs = node.input_size();
for (const string& input : node.input()) {
// some of the outputs may be unconnected.
int NumOutputs(const NodeDef& node, GraphDef* graph);
+// Returns true iff the node has at least one control input.
+bool HasControlInputs(const NodeDef& node);
+
// Number of connected non-control inputs.
int NumNonControlInputs(const NodeDef& node);