Publishing 2019 R1 content
[platform/upstream/dldt.git] / model-optimizer / extensions / middle / LayoutChangeForConstantShapePaths.py
1 """
2  Copyright (c) 2019 Intel Corporation
3
4  Licensed under the Apache License, Version 2.0 (the "License");
5  you may not use this file except in compliance with the License.
6  You may obtain a copy of the License at
7
8       http://www.apache.org/licenses/LICENSE-2.0
9
10  Unless required by applicable law or agreed to in writing, software
11  distributed under the License is distributed on an "AS IS" BASIS,
12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  See the License for the specific language governing permissions and
14  limitations under the License.
15 """
16 import numpy as np
17
18 from extensions.ops.gather import Gather
19 from mo.back.replacement import BackReplacementPattern
20 from mo.graph.graph import Graph, Node
21 from mo.ops.const import Const
22
23
class LayoutChangeForConstantShapePaths(BackReplacementPattern):
    """
    Inserts layout-permuting Gather ops around "constant shape paths" of TensorFlow models.

    A constant shape path starts at a `Shape` operation and continues downstream through operations
    whose output data nodes carry a computed `value`. For every `Shape` op producing a 4- or
    5-element vector the pass:
      1. inserts a Gather right after the `Shape` op that reorders the vector from NC* to N*C order;
      2. finds the operations where the constant path ends (an output data node without a computed
         `value`) and, on each of their inputs fed by an operation on the constant path, inserts a
         Gather that reorders the vector back from N*C to NC* order.
    Every operation reached while walking a path (including its ends) gets `permute_attrs` reset
    to None.

    Applies only when the framework is 'tf' and `cmd_params.keep_shape_ops` is set.

    NOTE(review): presumably this keeps shape arithmetic written for the original channels-last TF
    layout valid while the converted graph is channels-first — confirm.
    """
    enabled = False
    graph_condition = [lambda graph: graph.graph['fw'] == 'tf',
                       lambda graph: graph.graph['cmd_params'].keep_shape_ops]
    force_clean_up = True

    # Gather indices reordering a shape vector from NC* to N*C, keyed by the vector length
    # (NCHW -> NHWC for 4 elements, NCDHW -> NDHWC for 5).
    _TO_CHANNELS_LAST = {4: [0, 2, 3, 1], 5: [0, 2, 3, 4, 1]}
    # Inverse permutations, N*C back to NC* (NHWC -> NCHW, NDHWC -> NCDHW).
    _TO_CHANNELS_FIRST = {4: [0, 3, 1, 2], 5: [0, 4, 1, 2, 3]}

    @staticmethod
    def if_has_value(graph: Graph, node_name: str):
        """Return True if the node `node_name` has a valid 'value' attribute."""
        return Node(graph, node_name).has_valid('value')

    @staticmethod
    def _permutation_for(shape, permutations: dict):
        """
        Pick the Gather indices for a shape vector.

        :param shape: shape of the shape vector itself (e.g. [4] for a 4-element vector); may be None
        :param permutations: mapping {vector length: gather indices}
        :return: np.ndarray of gather indices, or None when the vector length is not supported
        """
        for length, order in permutations.items():
            if np.array_equal(shape, [length]):
                return np.array(order)
        return None

    def search_of_constant_path_end(self, graph: Graph, node_name: str, visited: set):
        """
        Walk breadth-first downstream from `node_name` while output data nodes have a computed value.

        For every operation reached, `permute_attrs` is reset to None. An operation is added to
        `visited` when one of its output data nodes has a valid 'value', and the walk then continues
        to the consumers of that data node. An operation is reported as an end when one of its output
        data nodes has no valid 'value'.

        :param graph: graph to traverse
        :param node_name: id of the node to start from (a Shape op in this pass)
        :param visited: ids of operations on constant shape paths; updated in place
        :return: set of ids of operations where the constant path ends
        """
        from collections import deque

        queue = deque([node_name])
        # Operations already expanded in this walk. Re-expanding one has no further effect, but
        # without this guard shared (diamond-shaped) sub-graphs would be walked many times.
        processed = set()
        ends = set()
        while queue:
            cur_node = queue.popleft()
            if cur_node in processed:
                continue
            processed.add(cur_node)

            node = Node(graph, cur_node)
            if node.has_valid('permute_attrs'):
                node['permute_attrs'] = None

            for _, out_node_name in graph.out_edges(cur_node):
                if out_node_name in visited:
                    continue
                if self.if_has_value(graph, out_node_name):
                    visited.add(cur_node)
                    queue.extend(op for _, op in graph.out_edges(out_node_name))
                else:
                    ends.add(cur_node)
        return ends

    def find_and_replace_pattern(self, graph: Graph):
        shape_ops = graph.get_op_nodes(op='Shape')
        constant_shape_paths = set()

        # 1. Reorder the output of every Shape op producing a 4- or 5-element vector to N*C order
        #    by inserting Gather(indices) between the Shape op and its consumers.
        for shape in shape_ops:
            index = self._permutation_for(shape.out_node().shape, self._TO_CHANNELS_LAST)
            if index is None:
                continue

            const = Const(graph, {'value': index}).create_node()
            gather = Gather(graph, {}).create_node()

            shape.out_port(0).get_connection().set_source(gather.out_port(0))
            shape.out_port(0).connect(gather.in_port(0))
            const.out_port(0).connect(gather.in_port(1))

            constant_shape_paths.add(gather.id)

        # 2. Reorder back to NC* order where a constant shape path ends:
        #    - starting from each Shape op, find the first operations whose output value MO could
        #      not compute (the shape vector enters the data path there);
        #    - insert a Gather on each of their input ports that is fed by an operation on a
        #      constant shape path.
        # A set is used because different Shape ops may reach the same end.
        constant_shape_ends = set()
        for shape in shape_ops:
            constant_shape_ends.update(self.search_of_constant_path_end(graph, node_name=shape.id,
                                                                        visited=constant_shape_paths))

        for end in constant_shape_ends:
            node = Node(graph, end)
            in_ports = [in_port for in_port in node.in_ports().values()
                        if in_port.get_source().node.id in constant_shape_paths]

            for in_port in in_ports:
                # Must be the inverse of the permutation used in step 1.
                index = self._permutation_for(in_port.data.get_shape(), self._TO_CHANNELS_FIRST)
                if index is None:
                    continue

                const = Const(graph, {'value': index}).create_node()
                gather = Gather(graph, {}).create_node()

                in_port.get_connection().set_destination(gather.in_port(0))
                const.out_port(0).connect(gather.in_port(1))
                gather.out_port(0).connect(in_port)