4689784d945484fb950aa5866450dc385773498b
[platform/upstream/dldt.git] / model-optimizer / extensions / middle / ReplaceSpliceNodePattern_test.py
1 """
2  Copyright (c) 2018-2019 Intel Corporation
3
4  Licensed under the Apache License, Version 2.0 (the "License");
5  you may not use this file except in compliance with the License.
6  You may obtain a copy of the License at
7
8       http://www.apache.org/licenses/LICENSE-2.0
9
10  Unless required by applicable law or agreed to in writing, software
11  distributed under the License is distributed on an "AS IS" BASIS,
12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  See the License for the specific language governing permissions and
14  limitations under the License.
15 """
16 import unittest
17
18 from extensions.middle.ReplaceSpliceNodePattern import ReplaceSpliceNodePattern
19 from mo.graph.graph import Node
20 from mo.utils.unittest.graph import build_graph, compare_graphs
21
22
23 class ReplaceSpliceNodePatternTests(unittest.TestCase):
24     @classmethod
25     def setUpClass(cls):
26         cls.nodes_attributes = {
27             'placeholder': {'kind': 'op', 'op': None},
28             'in_node': {'kind': 'data', 'shape': [1, 13]},
29             'splice': {'kind': 'op', 'op': 'Splice', 'context': range(-5, 6), 'const_dim': 0},
30             'splice_data': {'kind': 'data', 'shape': [1, 143]},
31             'out_placeholder': {'kind': 'op', 'op': 'placeholder'},
32         }
33
34     def test_splice(self):
35         graph = build_graph(self.nodes_attributes,
36                             [('placeholder', 'in_node'),
37                              ('in_node', 'splice'),
38                              ('splice', 'splice_data'),
39                              ('splice_data', 'out_placeholder')])
40         ReplaceSpliceNodePattern().find_and_replace_pattern(graph)
41
42         ref_graph = build_graph({'in_placeholder': {'kind': 'op', 'op': None},
43                                  'in_node': {'kind': 'data', 'shape': [1, 13]},
44                                  'memory_in': {'kind': 'op', 'op': 'Memory'},
45                                  'memory_in_data': {'kind': 'data'},
46                                  'crop_mem':  {'kind': 'op', 'op': 'Crop', 'offset': 13, 'dim': 130},
47                                  'crop_mem_data': {'kind': 'data'},
48                                  'concat': {'kind': 'op', 'op': 'Concat'},
49                                  'concat_data': {'kind': 'data', 'shape': [1, 143]},
50                                  'memory_out': {'kind': 'op', 'op': 'Memory'},
51                                  'memory_out_data': {'kind': 'data'},
52                                  'result': {'kind': 'op', 'op': 'Result'},
53                                  'out_placeholder': {'kind': 'op', 'op': 'placeholder'},
54                                  },
55                                 [
56                                     ('in_placeholder', 'in_node'),
57                                     ('memory_in', 'memory_in_data'),
58                                     ('memory_in_data', 'crop_mem'),
59                                     ('crop_mem', 'crop_mem_data'),
60                                     ('crop_mem_data', 'concat', {'in': 0}),
61                                     ('in_node', 'concat', {'in': 1}),
62                                     ('concat', 'concat_data'),
63                                     ('concat_data', 'memory_out'),
64                                     ('memory_out', 'memory_out_data'),
65                                     ('memory_out_data', 'result'),
66                                     ('concat_data', 'out_placeholder'),
67                                 ]
68                                 )
69
70         (flag, resp) = compare_graphs(graph, ref_graph, 'out_placeholder')
71         self.assertTrue(flag, resp)
72
73     def test_splice_with_constdim(self):
74         graph = build_graph(self.nodes_attributes,
75                             [('placeholder', 'in_node'),
76                              ('in_node', 'splice'),
77                              ('splice', 'splice_data'),
78                              ('splice_data', 'out_placeholder')])
79         Node(graph, 'splice')['const_dim'] = 10
80         Node(graph, 'splice_data')['shape'] = [1, 43]
81         ReplaceSpliceNodePattern().find_and_replace_pattern(graph)
82
83         ref_graph = build_graph({'in_placeholder': {'kind': 'op', 'op': None},
84                                  'in_node': {'kind': 'data', 'shape': [1, 13]},
85                                  'split': {'kind': 'op', 'op': 'Split'},
86                                  'split_data_0': {'kind': 'data'},
87                                  'split_data_1': {'kind': 'data'},
88                                  'memory_in': {'kind': 'op', 'op': 'Memory'},
89                                  'memory_in_data': {'kind': 'data'},
90                                  'crop_mem': {'kind': 'op', 'op': 'Crop', 'offset': 3, 'dim': 30},
91                                  'crop_mem_data': {'kind': 'data'},
92                                  'concat': {'kind': 'op', 'op': 'Concat'},
93                                  'concat_data': {'kind': 'data'},
94                                  'memory_out': {'kind': 'op', 'op': 'Memory'},
95                                  'memory_out_data': {'kind': 'data'},
96                                  'result': {'kind': 'op', 'op': 'Result'},
97                                  'memory_in_constdims': {'kind': 'op', 'op': 'Memory'},
98                                  'memory_in_constdims_data': {'kind': 'data'},
99                                  'crop_mem_constdims': {'kind': 'op', 'op': 'Crop', 'offset': 10, 'dim': 100},
100                                  'crop_mem_constdims_data': {'kind': 'data'},
101                                  'concat_constdims': {'kind': 'op', 'op': 'Concat'},
102                                  'concat_constdims_data': {'kind': 'data'},
103                                  'memory_out_constdims': {'kind': 'op', 'op': 'Memory'},
104                                  'memory_out_constdims_data': {'kind': 'data'},
105                                  'result_constdims': {'kind': 'op', 'op': 'Result'},
106                                  'crop_first_constdims': {'kind': 'op', 'op': 'Crop', 'offset': 0, 'dim': 10},
107                                  'crop_first_constdims_data': {'kind': 'data'},
108                                  'concat_all': {'kind': 'op', 'op': 'Concat'},
109                                  'concat_all_data': {'kind': 'data', 'shape': [1, 43]},
110                                  'out_placeholder': {'kind': 'op', 'op': 'placeholder'},
111                                  },
112                                 [
113                                     ('in_placeholder', 'in_node'),
114                                     ('in_node', 'split'),
115                                     ('split', 'split_data_0', {'out': 0}),
116                                     ('split', 'split_data_1', {'out': 1}),
117                                     ('memory_in', 'memory_in_data'),
118                                     ('memory_in_data', 'crop_mem'),
119                                     ('crop_mem', 'crop_mem_data'),
120                                     ('crop_mem_data', 'concat', {'in': 0}),
121                                     ('split_data_0', 'concat', {'in': 1}),
122                                     ('concat', 'concat_data'),
123                                     ('concat_data', 'memory_out'),
124                                     ('memory_out', 'memory_out_data'),
125                                     ('memory_out_data', 'result'),
126                                     ('memory_in_constdims', 'memory_in_constdims_data'),
127                                     ('memory_in_constdims_data', 'crop_mem_constdims'),
128                                     ('crop_mem_constdims', 'crop_mem_constdims_data'),
129                                     ('crop_mem_constdims_data', 'concat_constdims', {'in': 0}),
130                                     ('split_data_1', 'concat_constdims', {'in': 1}),
131                                     ('concat_constdims', 'concat_constdims_data'),
132                                     ('concat_constdims_data', 'memory_out_constdims'),
133                                     ('memory_out_constdims', 'memory_out_constdims_data'),
134                                     ('memory_out_constdims_data', 'result_constdims'),
135                                     ('concat_constdims_data', 'crop_first_constdims'),
136                                     ('crop_first_constdims', 'crop_first_constdims_data'),
137                                     ('crop_first_constdims_data', 'concat_all', {'in': 1}),
138                                     ('concat_data', 'concat_all', {'in': 0}),
139                                     ('concat_all', 'concat_all_data'),
140                                     ('concat_all_data', 'out_placeholder'),
141                                 ]
142                                 )
143
144         (flag, resp) = compare_graphs(graph, ref_graph, 'out_placeholder')
145         self.assertTrue(flag, resp)