2 Copyright (c) 2019 Intel Corporation
4 Licensed under the Apache License, Version 2.0 (the "License");
5 you may not use this file except in compliance with the License.
6 You may obtain a copy of the License at
8 http://www.apache.org/licenses/LICENSE-2.0
10 Unless required by applicable law or agreed to in writing, software
11 distributed under the License is distributed on an "AS IS" BASIS,
12 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 See the License for the specific language governing permissions and
14 limitations under the License.
18 from extensions.ops.splice import Splice
19 from mo.front.common.partial_infer.utils import int64_array
20 from mo.graph.graph import Graph, Node
21 from mo.middle.replacement import MiddleReplacementPattern
22 from mo.ops.concat import Concat
23 from mo.ops.crop import Crop
24 from mo.ops.memory import Memory
25 from mo.ops.result import Result
26 from mo.utils.error import Error
class ReplaceMemoryOffsetNodePattern(MiddleReplacementPattern):
    """
    Replace MemoryOffset with Splice.

    Works on MemoryOffset pairs whose matched node has ``has_default=False``;
    if the partner node has ``has_default`` set, the pair is left untouched.
    Of the two nodes, the one with a producer on input 0 supplies the data,
    and the consumers of the other one are rewired to a Crop of a new Splice
    node. The Splice context is ``t..0`` for t < 0 and ``0..t`` otherwise.
    Any other consumer of the producer (not a MemoryOffset or Splice) is fed
    through its own Crop of the same Splice output.
    """
    enabled = True

    @staticmethod
    def pattern():
        return dict(
            nodes=[('op', dict(op='MemoryOffset', has_default=False))],
            edges=[])

    @staticmethod
    def replace_pattern(graph: Graph, match: dict):
        """
        Rewire one MemoryOffset pair through a Splice node and delete the pair.

        :param graph: graph being transformed (modified in place)
        :param match: pattern match; ``match['op']`` is the MemoryOffset node
        """
        node = match['op']
        pair_node = Node(graph, node.pair_name)

        # Pairs involving an IfDefined-style partner are not handled here.
        if pair_node.has_default:
            return

        # The node that has a producer on input 0 is the "source side"; its
        # partner's consumers are the ones that want the offset value.
        if node.in_port(0).get_source() is not None:
            source_side, sink_side = node, pair_node
        else:
            source_side, sink_side = pair_node, node
        producer = source_side.in_port(0).get_source()
        op_output_id = source_side.out_port(0).get_destination().node.id
        delayed_consumers = sink_side.out_port(0).get_destinations()

        # NOTE(review): the input is assumed to be 2-D (in_shape[0], in_shape[1]) — confirm.
        in_shape = producer.data.get_shape().copy()
        node_id = node.id
        node_name = node.name
        node_t = node.t
        frame_size = in_shape[1]

        def make_crop(offset):
            # A 'Splice_Crop' node taking `frame_size` values from `offset` along axis 1.
            return Crop(graph, {'name': 'Splice_Crop',
                                'axis': int64_array([1]),
                                'offset': int64_array([offset]),
                                'dim': int64_array([frame_size])}).create_node()

        if node_t < 0:
            context = int64_array(range(node_t, 1))
        else:
            context = int64_array(range(0, node_t + 1))
        splice = Splice(graph, {'name': node_name, 'id': node_id, 'context': context}).create_node()
        splice.in_port(0).connect(producer)

        # Slice at context offset t: the first frame when t < 0, the last one when t > 0.
        crop = make_crop(max(0, frame_size * node_t))
        splice.out_port(0).connect(crop.in_port(0))
        splice.out_port(0).data.set_shape(int64_array([in_shape[0], (abs(node_t) + 1) * frame_size]))

        # Remaining consumers of the producer are moved onto the slice at
        # context offset 0 (the un-shifted frame, presuming Splice lays frames
        # out in context order along axis 1).
        for in_port in producer.get_destinations():
            if in_port.node['op'] in ('MemoryOffset', 'Splice'):
                continue
            crop_input = make_crop(-min(0, frame_size * node_t))
            splice.out_port(0).connect(crop_input.in_port(0))

            in_port.disconnect()
            crop_input.out_port(0).connect(in_port)
            crop_input.out_port(0).data.set_shape(in_shape)

        for dest_port in delayed_consumers:
            dest_port.connect(crop.out_port(0))

        for obsolete_id in (op_output_id, node.id, pair_node.id):
            graph.remove_node(obsolete_id)


class ReplaceMemoryOffsetWithMemoryNodePattern(MiddleReplacementPattern):
    """
    Replace MemoryOffset with Memory if IfDefined used with it to avoid cycles.

    Works on MemoryOffset pairs whose matched node has ``has_default=True``.
    Only negative offsets are supported. The pair is replaced by two Memory
    nodes sharing the id ``node_name + pair_name``:

    * ``memory_in`` (index 0) receives the value to store; its output goes to
      a new Result node;
    * ``memory_out`` (index 1) feeds the consumers of the original pair.

    Both hold ``|t| * in_shape[1]`` values. For |t| > 1 the stored history is
    shifted by one frame: a Crop drops the first ``in_shape[1]`` values of
    ``memory_out``, a Concat appends the current input, and the result is
    written to ``memory_in``. Consumers then read the first ``in_shape[1]``
    values of ``memory_out``. For |t| == 1 the input is written to
    ``memory_in`` directly and consumers read ``memory_out`` as is.
    """
    enabled = True

    @staticmethod
    def pattern():
        return dict(
            nodes=[('op', dict(op='MemoryOffset', has_default=True))],
            edges=[])

    @staticmethod
    def replace_pattern(graph: Graph, match: dict):
        """
        Replace one IfDefined MemoryOffset pair with a Memory write/read pair.

        :param graph: graph being transformed (modified in place)
        :param match: pattern match; ``match['op']`` is the MemoryOffset node
        :raises Error: if the offset ``t`` of the matched node is not negative
        """
        node = match['op']
        pair_node = Node(graph, node.pair_name)

        if node.t >= 0:
            raise Error('Does not support IfDefined with t > 0')

        # The node with a producer on input 0 provides the data; the other
        # node's output connection is re-sourced from the new Memory subgraph.
        if node.in_port(0).get_source() is not None:
            input_port = node.in_port(0).get_source()
            op_output_id = node.out_port(0).get_destination().node.id
            out_port = pair_node.out_port(0)
            node_name = node.name
            pair_name = pair_node.name
        else:
            input_port = pair_node.in_port(0).get_source()
            op_output_id = pair_node.out_port(0).get_destination().node.id
            out_port = node.out_port(0)
            node_name = pair_node.name
            pair_name = node.name

        in_shape = input_port.data.get_shape()
        node_t = abs(node.t)
        # NOTE(review): the input is assumed to be 2-D (batch, features) — confirm.
        batch, frame_size = in_shape[0], in_shape[1]
        memory_id = node_name + pair_name

        memory_out = Memory(graph, {'name': pair_name, 'id': memory_id,
                                    'index': 1, 'size': 2,
                                    'shape': int64_array([frame_size * node_t])}).create_node()
        if node_t > 1:
            # Drop the oldest frame of the stored history ...
            crop_concat = Crop(graph, {'name': 'Memory_crop', 'dim': int64_array([frame_size * (node_t - 1)]),
                                       'offset': int64_array([frame_size]), 'axis': int64_array([1])}).create_node()
            memory_out.out_port(0).connect(crop_concat.in_port(0))
            memory_out.out_port(0).data.set_shape(int64_array([batch, memory_out.shape[0]]))
            # ... and append the current input (Concat inputs: [history, input]).
            concat = Concat(graph, {'name': 'Memory_concat'}).create_node()
            concat.add_sequence_of_ports('in', range(2))
            crop_concat.out_port(0).connect(concat.in_port(0))
            # `dim` is a 1-element array: index it so the shape is a flat
            # [batch, dim] vector instead of a ragged nested sequence.
            crop_concat.out_port(0).data.set_shape(int64_array([batch, crop_concat.dim[0]]))
            concat.in_port(1).connect(input_port)

            memory_in = Memory(graph, {'name': node_name, 'id': memory_id,
                                       'index': 0, 'size': 2,
                                       'shape': memory_out.shape}).create_node()
            concat.out_port(0).connect(memory_in.in_port(0))
            concat.out_port(0).data.set_shape(int64_array([batch, memory_in.shape[0]]))
            out = Result(graph, {'name': 'Memory_output'}).create_node()
            memory_in.out_port(0).connect(out.in_port(0))
            memory_in.out_port(0).data.set_shape(int64_array([batch, memory_out.shape[0]]))

            # Consumers of the pair read the first frame of the stored history.
            crop_out = Crop(graph, {'name': 'Memory_crop_out', 'dim': int64_array([frame_size]),
                                    'offset': int64_array([0]), 'axis': int64_array([1])}).create_node()
            memory_out.out_port(0).connect(crop_out.in_port(0))
            out_port.get_connection().set_source(crop_out.out_port(0))
            crop_out.out_port(0).data.set_shape(int64_array([batch, crop_out.dim[0]]))
        else:
            # |t| == 1: store the input as is and hand memory_out to the consumers.
            memory_in = Memory(graph, {'name': node_name, 'id': memory_id,
                                       'index': 0, 'size': 2,
                                       'shape': memory_out.shape}).create_node()
            memory_in.in_port(0).connect(input_port)
            out = Result(graph, {'name': 'Memory_output'}).create_node()
            memory_in.out_port(0).connect(out.in_port(0))
            memory_in.out_port(0).data.set_shape(int64_array([batch, memory_out.shape[0]]))
            out_port.get_connection().set_source(memory_out.out_port(0))
            memory_out.out_port(0).data.set_shape(int64_array([batch, memory_out.shape[0]]))

        graph.remove_node(op_output_id)
        graph.remove_node(node.id)
        graph.remove_node(pair_node.id)