Publishing 2019 R1 content
[platform/upstream/dldt.git] / model-optimizer / mo / front / caffe / extractor_test.py
1 """
2  Copyright (c) 2018-2019 Intel Corporation
3
4  Licensed under the Apache License, Version 2.0 (the "License");
5  you may not use this file except in compliance with the License.
6  You may obtain a copy of the License at
7
8       http://www.apache.org/licenses/LICENSE-2.0
9
10  Unless required by applicable law or agreed to in writing, software
11  distributed under the License is distributed on an "AS IS" BASIS,
12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  See the License for the specific language governing permissions and
14  limitations under the License.
15 """
16
17 import unittest
18 from unittest.mock import patch
19
20 from mo.front.caffe.extractor import check_phase, register_caffe_python_extractor
21 from mo.front.extractor import CaffePythonFrontExtractorOp
22 from mo.graph.graph import Node
23 from mo.utils.unittest.extractors import FakeMultiParam
24 from mo.utils.unittest.graph import build_graph
25
# Attributes for the minimal two-node test graph used by every check_phase test:
# both nodes are plain 'Identity' ops. The comprehension creates a separate
# attribute dict per node, so the two entries never share state.
nodes_attributes = {
    node_name: {'type': 'Identity', 'kind': 'op'}
    for node_name in ('node_1', 'node_2')
}
28
29
class TestExtractor(unittest.TestCase):
    """Tests for the Caffe front-end helpers ``check_phase`` and
    ``register_caffe_python_extractor``."""

    @staticmethod
    def _make_node(update_attributes=None):
        """Build the ``node_1 -> node_2`` test graph and return ``node_1``.

        :param update_attributes: per-node attribute overrides forwarded to
            ``build_graph``; ``None`` means no overrides.
        :return: ``Node`` wrapper for ``node_1``.
        """
        graph = build_graph(nodes_attributes,
                            [('node_1', 'node_2')],
                            update_attributes if update_attributes is not None else {})
        return Node(graph, 'node_1')

    def _make_node_with_pb(self, pb_param):
        """Return ``node_1`` whose ``pb`` attribute is a fake message built from ``pb_param``."""
        return self._make_node({'node_1': {'pb': FakeMultiParam(pb_param)}})

    def _make_node_with_include(self, phase_param):
        """Return ``node_1`` whose ``pb.include`` holds one fake entry built from ``phase_param``."""
        return self._make_node_with_pb({'include': [FakeMultiParam(phase_param)]})

    def test_check_phase_train_phase(self):
        node = self._make_node_with_include({'phase': 0})
        self.assertEqual(check_phase(node), {'phase': 0})

    def test_check_phase_test_phase(self):
        node = self._make_node_with_include({'phase': 1})
        self.assertEqual(check_phase(node), {'phase': 1})

    def test_check_phase_no_phase(self):
        # An include entry without a 'phase' field yields no phase information.
        node = self._make_node_with_include({})
        self.assertEqual(check_phase(node), {})

    def test_check_phase_no_include(self):
        # A pb message without an 'include' field yields no phase information.
        node = self._make_node_with_pb({})
        self.assertEqual(check_phase(node), {})

    def test_check_phase_no_pb(self):
        # A node with no 'pb' attribute at all must also be handled.
        node = self._make_node()
        self.assertEqual(check_phase(node), {})

    @patch('mo.ops.activation.Activation')
    def test_register_caffe_python_extractor_by_name(self, op_mock):
        op_mock.op = 'TestLayer'
        name = 'myTestLayer'
        # registered_ops lives on the class, so anything registered here would
        # outlive this test and leak into others; patch.dict snapshots the
        # registry and restores it on exit (relies on it being a dict-like
        # mapping with copy/clear/update).
        with patch.dict(CaffePythonFrontExtractorOp.registered_ops):
            register_caffe_python_extractor(op_mock, name)
            self.assertIn(name, CaffePythonFrontExtractorOp.registered_ops)

    @patch('mo.ops.activation.Activation')
    def test_register_caffe_python_extractor_by_op(self, op_mock):
        op_mock.op = 'TestLayer'
        # Same isolation as above: undo the registration when the test ends.
        with patch.dict(CaffePythonFrontExtractorOp.registered_ops):
            # No explicit name: the op's own 'op' attribute is expected to be used.
            register_caffe_python_extractor(op_mock)
            self.assertIn(op_mock.op, CaffePythonFrontExtractorOp.registered_ops)