"""
 Copyright (c) 2019 Intel Corporation

 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
 You may obtain a copy of the License at

      http://www.apache.org/licenses/LICENSE-2.0

 Unless required by applicable law or agreed to in writing, software
 distributed under the License is distributed on an "AS IS" BASIS,
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
"""
import numpy as np

from extensions.ops.gather import Gather
from mo.back.replacement import BackReplacementPattern
from mo.graph.graph import Graph, Node
from mo.ops.const import Const
class LayoutChangeForConstantShapePaths(BackReplacementPattern):
    """
    Keep shape-computing sub-graphs consistent with the layout change of a TensorFlow model.

    The pass runs only for TF models converted with ``keep_shape_ops`` (see ``graph_condition``).

    1. After every ``Shape`` op whose output is a 4- or 5-element vector, a ``Gather`` is inserted that
       reorders the vector from NC* to N*C order.
    2. Starting from those ``Shape`` ops, the pass walks forward through nodes whose outputs carry a
       constant value. Where such a path reaches a node whose output has no value, it inserts a
       ``Gather`` on each input port of that node that is fed from the path. That ``Gather`` reorders
       4-/5-element vectors back to NC* order.
    """
    # NOTE(review): ``enabled`` / ``force_clean_up`` are not set in this class, so the
    # BackReplacementPattern defaults apply -- confirm that is intended.
    graph_condition = [lambda graph: graph.graph['fw'] == 'tf',
                       lambda graph: graph.graph['cmd_params'].keep_shape_ops]

    @staticmethod
    def if_has_value(graph: Graph, node_name: str):
        """Return whether node ``node_name`` of ``graph`` has a valid ``value`` attribute."""
        return Node(graph, node_name).has_valid('value')

    def search_of_constant_path_end(self, graph: Graph, node_name: str, visited: set):
        """
        Walk ``graph`` breadth-first from ``node_name`` along successors that hold constant values.

        For every dequeued node, a valid ``permute_attrs`` attribute is reset to ``None``. Each
        successor of that node that is not in ``visited`` is then examined:

        * If the successor has a valid ``value``, the current node is added to ``visited``, and the
          successor's own successors are queued.
        * Otherwise the current node is recorded as an end of the constant path.

        :param graph: graph to traverse.
        :param node_name: id of the start node (a ``Shape`` op when called from this class).
        :param visited: ids of nodes known to be on constant shape paths; updated in place.
        :return: list of node ids where a constant path ends. The same id may appear more than once.
        """
        from collections import deque

        queue = deque([node_name])
        ends = []
        while queue:
            cur_node = queue.popleft()
            node = Node(graph, cur_node)
            if node.has_valid('permute_attrs'):
                # NOTE(review): presumably keeps later permutation passes from re-permuting
                # ops on the shape path -- confirm.
                node['permute_attrs'] = None
            for _, out_node_name in graph.out_edges(cur_node):
                if out_node_name in visited:
                    continue
                if self.if_has_value(graph, out_node_name):
                    # The output is constant: cur_node is on the path; continue past that output.
                    visited.add(cur_node)
                    queue.extend(op for _, op in graph.out_edges(out_node_name))
                else:
                    # The value could not be propagated past cur_node: the constant path ends here.
                    ends.append(cur_node)
        return ends

    def find_and_replace_pattern(self, graph: Graph):
        """
        Insert the ``Gather`` reorderings described in the class docstring.

        :param graph: graph modified in place.
        """
        # 1. Inserting Gather to N*C format on constant shape paths
        #   - Search for Shape ops
        #   - Inserting Gather after them in case of [4] or [5] output shape
        shape_ops = graph.get_op_nodes(op='Shape')
        constant_shape_paths = set()

        for shape in shape_ops:
            shape_of_shape_op_output = shape.out_node().shape

            if np.array_equal(shape_of_shape_op_output, [4]):
                index = np.array([0, 2, 3, 1])
            elif np.array_equal(shape_of_shape_op_output, [5]):
                index = np.array([0, 2, 3, 4, 1])
            else:
                # Only 4- and 5-element shape vectors are reordered.
                continue

            const = Const(graph, {'value': index}).create_node()
            gather = Gather(graph, {}).create_node()

            # Re-point the existing Shape consumers to the Gather output,
            # then feed Shape -> Gather(data=port 0, indices=port 1).
            shape.out_port(0).get_connection().set_source(gather.out_port(0))
            shape.out_port(0).connect(gather.in_port(0))
            const.out_port(0).connect(gather.in_port(1))

            constant_shape_paths.add(gather.id)

        # 2. Inserting Gather to NC* format
        #   - Search from the Shape ops found in the previous step for nodes without a value that are
        #     n-th children of a Shape op
        #       * i.e. MO could not propagate a value there, so a data path begins
        #   - Inserting Gather on ports fed by operations in the `constant_shape_paths` set
        constant_shape_ends = []
        for shape in shape_ops:
            constant_shape_ends.extend(self.search_of_constant_path_end(graph, node_name=shape.id,
                                                                        visited=constant_shape_paths))

        for end in constant_shape_ends:
            node = Node(graph, end)
            in_ports = [in_port for in_port in node.in_ports().values()
                        if in_port.get_source().node.id in constant_shape_paths]

            for in_port in in_ports:
                port_shape = in_port.data.get_shape()

                # These indices must be the inverse of the step-1 permutations.
                if np.array_equal(port_shape, [4]):
                    index = np.array([0, 3, 1, 2])  # inverse of [0, 2, 3, 1]
                elif np.array_equal(port_shape, [5]):
                    index = np.array([0, 4, 1, 2, 3])  # inverse of [0, 2, 3, 4, 1]
                else:
                    continue

                const = Const(graph, {'value': index}).create_node()
                gather = Gather(graph, {}).create_node()

                # Re-route: the port's current source now feeds Gather port 0,
                # and the Gather output feeds the original port.
                in_port.get_connection().set_destination(gather.in_port(0))
                const.out_port(0).connect(gather.in_port(1))
                gather.out_port(0).connect(in_port)