2 Copyright (c) 2018 Intel Corporation
4 Licensed under the Apache License, Version 2.0 (the "License");
5 you may not use this file except in compliance with the License.
6 You may obtain a copy of the License at
8 http://www.apache.org/licenses/LICENSE-2.0
10 Unless required by applicable law or agreed to in writing, software
11 distributed under the License is distributed on an "AS IS" BASIS,
12 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 See the License for the specific language governing permissions and
14 limitations under the License.
20 from google.protobuf import text_format
22 from mo.front.caffe.loader import caffe_pb_to_nx
23 from mo.front.caffe.proto import caffe_pb2
24 from mo.utils.error import Error
26 proto_str_one_input = 'name: "network" ' \
41 proto_str_old_styled_multi_input = 'name: "network" ' \
51 proto_str_input = 'name: "network" ' \
61 proto_str_multi_input = 'name: "network" ' \
77 proto_str_old_styled_input = 'name: "network" ' \
84 layer_proto_str = 'layer { ' \
86 'type: "Convolution" ' \
91 proto_same_name_layers = 'layer { ' \
93 'type: "Convolution" ' \
99 'type: "Convolution" ' \
class TestLoader(unittest.TestCase):
    """Tests for ``caffe_pb_to_nx`` on small prototxt snippets.

    Each test parses one or more of the module-level prototxt strings into a
    ``caffe_pb2.NetParameter`` and calls ``caffe_pb_to_nx(net, None)``. The
    call returns a ``(graph, input_shapes)`` pair, and the tests inspect one
    or both parts.
    """

    @staticmethod
    def _parse(proto_text):
        """Return a fresh ``caffe_pb2.NetParameter`` filled from ``proto_text``."""
        net = caffe_pb2.NetParameter()
        text_format.Merge(proto_text, net)
        return net

    def _check_input_shapes(self, proto_text, expected_input_shapes):
        """Convert ``proto_text`` and compare the reported input shapes.

        ``caffe_pb_to_nx`` is called with ``None`` as its second argument.
        NOTE(review): that argument is presumably the optional weights model;
        its meaning is not visible here.

        Only the names listed in ``expected_input_shapes`` are compared. Extra
        entries in the returned ``input_shapes`` dict are not inspected. A
        missing name surfaces as a ``KeyError``.
        """
        _, input_shapes = caffe_pb_to_nx(self._parse(proto_text), None)
        for name, shape in expected_input_shapes.items():
            np.testing.assert_array_equal(input_shapes[name], shape)

    def test_caffe_pb_to_nx_one_input(self):
        self._check_input_shapes(
            proto_str_one_input,
            {'Input0': np.array([1, 3, 224, 224])},
        )

    def test_caffe_pb_to_nx_old_styled_multi_input(self):
        # The old-styled multi-input prototxt is expected to be rejected with Error.
        net = self._parse(proto_str_old_styled_multi_input + layer_proto_str)
        with self.assertRaises(Error):
            caffe_pb_to_nx(net, None)

    def test_caffe_pb_to_nx_old_styled_input(self):
        self._check_input_shapes(
            proto_str_old_styled_input + layer_proto_str,
            {'data': np.array([1, 3, 224, 224])},
        )

    def test_caffe_pb_to_standart_input(self):
        # NOTE(review): "standart" is a typo for "standard". The name is kept
        # so existing test IDs stay stable.
        self._check_input_shapes(
            proto_str_input + layer_proto_str,
            {'data': np.array([1, 3, 224, 224])},
        )

    def test_caffe_pb_to_multi_input(self):
        self._check_input_shapes(
            proto_str_multi_input + layer_proto_str,
            {
                'data': np.array([1, 3, 224, 224]),
                'data1': np.array([1, 3]),
            },
        )

    def test_caffe_same_name_layer(self):
        net = self._parse(proto_str_multi_input + proto_same_name_layers)
        graph, _ = caffe_pb_to_nx(net, None)
        # 6 nodes because: 2 inputs + 2 convolutions + 2 output nodes
        self.assertEqual(len(graph.nodes()), 6)