node_name, node_->input(pos), const_name, dtype,
input_node->attr().at("_output_shapes").list().shape(output_pos),
true);
- node_map_->UpdateOutput(node_->input(pos), node_->name(), node_name);
+ node_map_->UpdateOutput(NodeName(node_->input(pos)), node_->name(),
+ node_name);
node_map_->AddOutput(node_name, node_->name());
*node_->mutable_input(pos) = node_name;
}
auto added_node =
AddNodeDataFormatOp(name, node_->input(input_pos), op, dtype, true);
*node_->mutable_input(input_pos) = added_node->name();
- node_map_->UpdateOutput(added_node->input(0), node_->name(),
+ node_map_->UpdateOutput(NodeName(added_node->input(0)), node_->name(),
added_node->name());
node_map_->AddOutput(added_node->name(), node_->name());
}
AddNodeReshape(reshape_node_name, node_->input(vector_index),
shape_const_node_name, node_->attr().at("T").type());
node_map_->AddOutput(shape_const_node_name, reshape_node_name);
- node_map_->UpdateOutput(node_->input(vector_index), node_->name(),
- reshape_node_name);
+ node_map_->UpdateOutput(NodeName(node_->input(vector_index)),
+ node_->name(), reshape_node_name);
node_map_->AddOutput(reshape_node_name, node_->name());
*node_->mutable_input(vector_index) = reshape_node_name;
}
conv = _two_layer_model(x)
dim = array_ops.placeholder(dtype='int32')
split = array_ops.split(conv, 2, axis=dim)
- output = math_ops.reduce_sum(split[0])
+ scale = constant_op.constant(0.1, shape=[32])
+ offset = constant_op.constant(0.3, shape=[32])
+ bn0 = nn.fused_batch_norm(split[0], scale, offset)
+ bn1 = nn.fused_batch_norm(split[1], scale, offset)
+ add = bn0[0] + bn1[0]
+ output = array_ops.identity(add)
with session.Session() as sess:
output_val_ref = sess.run(output, feed_dict={dim: 3})
num_transposes += 1
nodes.append(node.name)
- # Four transposes were initially added in the Expand phase of
- # LayoutOptimizer; two of them are cancelled out in the Collapse phase.
expected_num_transposes = 2
self.assertEqual(expected_num_transposes, num_transposes)
self._assert_trans_nhwc_to_nchw('Conv2D-0', nodes)
- self._assert_trans_nchw_to_nhwc('split-0-0', nodes)
+ self._assert_trans_nchw_to_nhwc('add_2-0-0', nodes)
self._assert_map_nhwc_to_nchw('split-0', nodes)
self.assertAllClose(output_val_ref, output_val, atol=1e-3)