2 Copyright (c) 2018-2019 Intel Corporation
4 Licensed under the Apache License, Version 2.0 (the "License");
5 you may not use this file except in compliance with the License.
6 You may obtain a copy of the License at
8 http://www.apache.org/licenses/LICENSE-2.0
10 Unless required by applicable law or agreed to in writing, software
11 distributed under the License is distributed on an "AS IS" BASIS,
12 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 See the License for the specific language governing permissions and
14 limitations under the License.
18 from unittest.mock import patch
20 from mo.front.caffe.extractor import check_phase, register_caffe_python_extractor
21 from mo.front.extractor import CaffePythonFrontExtractorOp
22 from mo.graph.graph import Node
23 from mo.utils.unittest.extractors import FakeMultiParam
24 from mo.utils.unittest.graph import build_graph
# Attributes of the two graph nodes used by the tests in this module: each
# node is an 'Identity' operation and gets its own attribute dict.
nodes_attributes = {
    node_name: {'type': 'Identity', 'kind': 'op'}
    for node_name in ('node_1', 'node_2')
}
class TestExtractor(unittest.TestCase):
    """Tests for ``check_phase`` and ``register_caffe_python_extractor``."""

    @staticmethod
    def _make_node(include_param=None):
        """Build the two-node test graph and return ``node_1`` wrapped in ``Node``.

        :param include_param: dict of fake layer parameters; it is wrapped in
            ``FakeMultiParam`` and stored as the ``pb`` attribute of ``node_1``.
            With ``None`` the node gets no ``pb`` attribute at all.
        :return: ``Node`` for ``node_1`` of the freshly built graph.
        """
        update_attrs = {}
        # ``is not None`` on purpose: an empty dict is a valid (empty) ``pb``.
        if include_param is not None:
            update_attrs['node_1'] = {'pb': FakeMultiParam(include_param)}
        graph = build_graph(nodes_attributes, [('node_1', 'node_2')], update_attrs)
        return Node(graph, 'node_1')

    def test_check_phase_train_phase(self):
        """An include rule carrying phase 0 is reported as ``{'phase': 0}``."""
        # NOTE(review): the phase value is taken from the expected result below;
        # confirm it matches the original fixture.
        phase_param = {'phase': 0}
        include_param = {'include': [FakeMultiParam(phase_param)]}

        res = check_phase(self._make_node(include_param))

        exp_res = {'phase': 0}
        self.assertEqual(res, exp_res)

    def test_check_phase_test_phase(self):
        """An include rule carrying phase 1 is reported as ``{'phase': 1}``."""
        # NOTE(review): the phase value is taken from the expected result below;
        # confirm it matches the original fixture.
        phase_param = {'phase': 1}
        include_param = {'include': [FakeMultiParam(phase_param)]}

        res = check_phase(self._make_node(include_param))

        exp_res = {'phase': 1}
        self.assertEqual(res, exp_res)

    def test_check_phase_no_phase(self):
        """An include rule without a ``phase`` entry."""
        phase_param = {}
        include_param = {'include': [FakeMultiParam(phase_param)]}

        res = check_phase(self._make_node(include_param))

        # NOTE(review): expected value assumed to be an empty dict (nothing to
        # report) -- confirm against check_phase.
        exp_res = {}
        self.assertEqual(res, exp_res)

    def test_check_phase_no_include(self):
        """A ``pb`` that has no ``include`` entry at all."""
        include_param = {}

        res = check_phase(self._make_node(include_param))

        # NOTE(review): expected value assumed to be an empty dict -- confirm
        # against check_phase.
        exp_res = {}
        self.assertEqual(res, exp_res)

    def test_check_phase_no_pb(self):
        """A node that carries no ``pb`` attribute."""
        res = check_phase(self._make_node())

        # NOTE(review): expected value assumed to be an empty dict -- confirm
        # against check_phase.
        exp_res = {}
        self.assertEqual(res, exp_res)

    @patch('mo.ops.activation.Activation')
    def test_register_caffe_python_extractor_by_name(self, op_mock):
        """Registering with an explicit name makes that name appear in ``registered_ops``."""
        op_mock.op = 'TestLayer'
        # NOTE(review): any name different from ``op_mock.op`` exercises the
        # explicit-name path; confirm the value used by the original test.
        name = 'myTestLayer'
        register_caffe_python_extractor(op_mock, name)
        self.assertIn(name, CaffePythonFrontExtractorOp.registered_ops)

    @patch('mo.ops.activation.Activation')
    def test_register_caffe_python_extractor_by_op(self, op_mock):
        """Registering without a name makes ``op_mock.op`` appear in ``registered_ops``."""
        op_mock.op = 'TestLayer'
        register_caffe_python_extractor(op_mock)
        self.assertIn(op_mock.op, CaffePythonFrontExtractorOp.registered_ops)