Publishing 2019 R3 content
[platform/upstream/dldt.git] / model-optimizer / extensions / middle / ReplaceMemoryOffsetWithSplice.py
1 """
2  Copyright (c) 2019 Intel Corporation
3
4  Licensed under the Apache License, Version 2.0 (the "License");
5  you may not use this file except in compliance with the License.
6  You may obtain a copy of the License at
7
8       http://www.apache.org/licenses/LICENSE-2.0
9
10  Unless required by applicable law or agreed to in writing, software
11  distributed under the License is distributed on an "AS IS" BASIS,
12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  See the License for the specific language governing permissions and
14  limitations under the License.
15 """
16 import numpy as np
17
18 from extensions.ops.splice import Splice
19 from mo.front.common.partial_infer.utils import int64_array
20 from mo.graph.graph import Graph, Node
21 from mo.middle.replacement import MiddleReplacementPattern
22 from mo.ops.concat import Concat
23 from mo.ops.crop import Crop
24 from mo.ops.memory import Memory
25 from mo.ops.result import Result
26 from mo.utils.error import Error
27
28
29 class ReplaceMemoryOffsetNodePattern(MiddleReplacementPattern):
30     """
31     Replace MemoryOffset with Splice
32     """
33     enabled = False
34
35     @staticmethod
36     def pattern():
37         return dict(
38             nodes=[('op', dict(op='MemoryOffset', has_default=False))],
39             edges=[])
40
41     @staticmethod
42     def replace_pattern(graph: Graph, match: dict):
43         node = match['op']
44         pair_node = Node(graph, node.pair_name)
45
46         if pair_node.has_default:
47             return
48
49         if node.in_port(0).get_source() is not None:
50             input_node_out_port = node.in_port(0).get_source()
51             op_output_id = node.out_port(0).get_destination().node.id
52             out_node_in_ports = pair_node.out_port(0).get_destinations()
53         else:
54             input_node_out_port = pair_node.in_port(0).get_source()
55             op_output_id = pair_node.out_port(0).get_destination().node.id
56             out_node_in_ports = node.out_port(0).get_destinations()
57
58         in_shape = input_node_out_port.data.get_shape().copy()
59
60         node_id = node.id
61         node_name = node.name
62         node_t = node.t
63
64         splice = Splice(graph, {'name': node_name,
65                                 'id': node_id,
66                                 'context': int64_array(range(node_t, 1)) if node_t < 0 else int64_array(range(0, node_t+1))}).create_node()
67         splice.in_port(0).connect(input_node_out_port)
68
69         # offset of Crop will be 0 (first element) if node_t < 0 and in_shape[1]*node_t (last element) if node_t > 0
70         crop = Crop(graph, {'name': 'Splice_Crop',
71                             'axis': int64_array([1]),
72                             'offset': int64_array([max(0, in_shape[1] * node_t)]),
73                             'dim': int64_array([in_shape[1]])}).create_node()
74
75         splice.out_port(0).connect(crop.in_port(0))
76         splice.out_port(0).data.set_shape(int64_array([in_shape[0], (abs(node_t) + 1) * in_shape[1]]))
77
78         outs = input_node_out_port.get_destinations()
79         for in_port in outs:
80             out_ = in_port.node
81             if out_['op'] != 'MemoryOffset' and out_['op'] != 'Splice':
82                 crop_input = Crop(graph, {'name': 'Splice_Crop',
83                                           'axis': int64_array([1]),
84                                           'offset': int64_array([-min(0, in_shape[1] * node_t)]),
85                                           'dim': int64_array([in_shape[1]])}).create_node()
86                 splice.out_port(0).connect(crop_input.in_port(0))
87
88                 in_port.disconnect()
89                 crop_input.out_port(0).connect(in_port)
90                 crop_input.out_port(0).data.set_shape(in_shape)
91
92         for dest_port in out_node_in_ports:
93             dest_port.connect(crop.out_port(0))
94
95         graph.remove_node(op_output_id)
96         graph.remove_node(node.id)
97         graph.remove_node(pair_node.id)
98
99
100 class ReplaceMemoryOffsetWithMemoryNodePattern(MiddleReplacementPattern):
101     """
102     Replace MemoryOffset with Memory if IfDefined used with it to avoid cycles
103     """
104     enabled = False
105
106     @staticmethod
107     def pattern():
108         return dict(
109             nodes=[('op', dict(op='MemoryOffset', has_default=True))],
110             edges=[])
111
112     @staticmethod
113     def replace_pattern(graph: Graph, match: dict):
114         node = match['op']
115         pair_node = Node(graph, node.pair_name)
116
117         if node.t >= 0:
118             raise Error('Does not support IfDefined with t > 0')
119
120         if node.in_port(0).get_source() is not None:
121             input_port = node.in_port(0).get_source()
122             op_output_id = node.out_port(0).get_destination().node.id
123             out_port = pair_node.out_port(0)
124             node_name = node.name
125             pair_name = pair_node.name
126         else:
127             input_port = pair_node.in_port(0).get_source()
128             op_output_id = pair_node.out_port(0).get_destination().node.id
129             out_port = node.out_port(0)
130             node_name = pair_node.name
131             pair_name = node.name
132
133         in_shape = input_port.data.get_shape()
134         node_t = abs(node.t)
135
136         memory_out = Memory(graph, {'name': pair_name, 'id': node_name+pair_name,
137                                     'index': 1, 'size': 2,
138                                     'shape': np.array([in_shape[1]*node_t])}).create_node()
139         if node_t > 1:
140             crop_concat = Crop(graph, {'name': 'Memory_crop', 'dim': np.array([in_shape[1]*(node_t-1)]),
141                                        'offset': np.array([in_shape[1]]), 'axis': np.array([1])}).create_node()
142             memory_out.out_port(0).connect(crop_concat.in_port(0))
143             memory_out.out_port(0).data.set_shape(np.array([in_shape[0], memory_out.shape[0]]))
144             concat = Concat(graph, {'name': 'Memory_concat'}).create_node()
145             concat.add_sequence_of_ports('in', range(2))
146             crop_concat.out_port(0).connect(concat.in_port(0))
147             crop_concat.out_port(0).data.set_shape(np.array([in_shape[0], crop_concat.dim]))
148             concat.in_port(1).connect(input_port)
149             memory_in = Memory(graph, {'name': node_name, 'id': node_name + pair_name,
150                                        'index': 0, 'size': 2,
151                                        'shape': memory_out.shape}).create_node()
152             concat.out_port(0).connect(memory_in.in_port(0))
153             concat.out_port(0).data.set_shape(np.array([in_shape[0], memory_in.shape[0]]))
154             out = Result(graph, {'name': 'Memory_output'}).create_node()
155             memory_in.out_port(0).connect(out.in_port(0))
156             memory_in.out_port(0).data.set_shape(np.array([in_shape[0], memory_out.shape[0]]))
157
158             crop_out = Crop(graph, {'name': 'Memory_crop_out', 'dim': np.array([in_shape[1]]),
159                                     'offset': np.array([0]), 'axis': np.array([1])}).create_node()
160             memory_out.out_port(0).connect(crop_out.in_port(0))
161             out_port.get_connection().set_source(crop_out.out_port(0))
162             crop_out.out_port(0).data.set_shape(np.array([in_shape[0], crop_out.dim]))
163         else:
164             memory_in = Memory(graph, {'name': node_name, 'id': node_name + pair_name,
165                                        'index': 0, 'size': 2,
166                                        'shape': memory_out.shape}).create_node()
167             memory_in.in_port(0).connect(input_port)
168             out = Result(graph, {'name': 'Memory_output'}).create_node()
169             memory_in.out_port(0).connect(out.in_port(0))
170             memory_in.out_port(0).data.set_shape(np.array([in_shape[0], memory_out.shape[0]]))
171             out_port.get_connection().set_source(memory_out.out_port(0))
172             memory_out.out_port(0).data.set_shape(np.array([in_shape[0], memory_out.shape[0]]))
173
174         graph.remove_node(op_output_id)
175         graph.remove_node(node.id)
176         graph.remove_node(pair_node.id)