Publishing 2019 R1 content
[platform/upstream/dldt.git] / model-optimizer / mo / front / tf / replacement.py
1 """
2  Copyright (c) 2017-2019 Intel Corporation
3
4  Licensed under the Apache License, Version 2.0 (the "License");
5  you may not use this file except in compliance with the License.
6  You may obtain a copy of the License at
7
8       http://www.apache.org/licenses/LICENSE-2.0
9
10  Unless required by applicable law or agreed to in writing, software
11  distributed under the License is distributed on an "AS IS" BASIS,
12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  See the License for the specific language governing permissions and
14  limitations under the License.
15 """
16 import logging as log
17
18 from mo.front.common.custom_replacement_registry import CustomReplacementRegistry
19 from mo.front.common.replacement import FrontReplacementSubgraph, FrontReplacementPattern
20 from mo.front.subgraph_matcher import SubgraphMatcher, SubgraphMatch
21 from mo.front.tf.custom_subgraph_call import merge_nodes
22 from mo.graph.graph import Graph
23 from mo.ops.op import Op
24 from mo.utils import class_registration
25 from mo.utils.graph import is_connected_component
26 from mo.utils.replacement_pattern import ReplacementPattern
27
28
class FrontReplacementFromConfigFileGeneral(FrontReplacementPattern):
    """
    Base class for front replacers that transform the graph using the 'custom_attributes' section of the custom
    replacement descriptions registered under this class's 'replacement_id'.

    Sub-classes must override 'transform_graph'.
    """
    replacement_id = ""

    def __init__(self):
        super().__init__()

    def transform_graph(self, graph, replacement_descriptions):
        """
        Transform the graph using the custom attributes of a single replacement description.
        :param graph: graph to transform.
        :param replacement_descriptions: the 'custom_attributes' value of one replacement description (despite the
            plural name, 'find_and_replace_pattern' passes one description's attributes per call).
        :raises NotImplementedError: always; sub-classes must override this method.
        """
        # NotImplementedError derives from Exception, so existing "except Exception" handlers still catch it.
        raise NotImplementedError('Function "transform_graph" must be overridden in the sub-class')

    def find_and_replace_pattern(self, graph: Graph):
        """
        Call 'transform_graph' once for every registered replacement description with this class's
        'replacement_id' that has a 'custom_attributes' entry. Descriptions without it are logged and skipped.
        Nothing is done (only an info message is logged) when no descriptions are registered for the id.
        :param graph: graph to transform.
        """
        replacement_descriptions = CustomReplacementRegistry().get_custom_replacement_description(self.replacement_id)
        # covers both "no descriptions registered" (None) and an empty list
        if not replacement_descriptions:
            log.info("Failed to find custom replacement description with id '{}'".format(self.replacement_id))
            return
        for desc in replacement_descriptions:
            # NOTE(review): reads the description's private '_replacement_desc' dict directly; a public accessor
            # (e.g. 'custom_attributes') may exist on the descriptor class - confirm before switching to it.
            if 'custom_attributes' in desc._replacement_desc:
                self.transform_graph(graph, desc._replacement_desc['custom_attributes'])
            else:
                log.info("Failed to find \'custom_attributes\' in replacement description with id '{}'".format(
                    self.replacement_id))

    # presumably per-class registries used by the class registration machinery; declared here so this class does
    # not share them with its parent - TODO confirm
    registered_ops = {}
    registered_cls = []

    @classmethod
    def class_type(cls):
        """:return: the registration category of this replacer (front replacer)."""
        return class_registration.ClassType.FRONT_REPLACER


# presumably keeps this abstract base class itself from being run as a replacer - confirm in ReplacementPattern
ReplacementPattern.excluded_replacers.append(FrontReplacementFromConfigFileGeneral)
62
63
class FrontReplacementFromConfigFileSubGraph(FrontReplacementSubgraph):
    """
    Replace sub-graphs defined in the custom replacement configuration file with a sub-graph of operations.
    """
    replacement_id = ""

    def __init__(self):
        super().__init__()

    def nodes_to_remove(self, graph: Graph, match: SubgraphMatch):
        """
        :param graph: graph being transformed (unused).
        :param match: object describing the matched sub-graph.
        :return: names of all nodes of the matched sub-graph.
        """
        return match.matched_nodes_names()

    def find_and_replace_pattern(self, graph: Graph):
        """
        For every registered replacement description with this class's 'replacement_id', find all matching
        sub-graph instances and replace each one via 'replace_sub_graph'. A match whose nodes do not form a connected
        sub-graph is still replaced, but a warning is logged and the matched nodes are dumped for inspection.
        Nothing is done (only an info message is logged) when no descriptions are registered for the id.
        :param graph: graph to transform.
        """
        replacement_descriptions = CustomReplacementRegistry().get_custom_replacement_description(self.replacement_id)
        # covers both "no descriptions registered" (None) and an empty list, consistently with
        # FrontReplacementFromConfigFileGeneral
        if not replacement_descriptions:
            log.info("Failed to find custom replacement description with id '{}'".format(self.replacement_id))
            return
        # there may be several custom replacement descriptions sharing the same replacement id
        for replacement_description in replacement_descriptions:
            sub_graph_matcher = SubgraphMatcher(replacement_description)
            for match in sub_graph_matcher.matched_sub_graph_instances(graph):
                if not is_connected_component(graph, match.matched_nodes_names()):
                    log.warning("The following nodes don't form connected sub-graph: {}".format(
                        match.matched_nodes_names()))
                    # dump the offending nodes to help debugging the configuration file
                    graph.dump_graph_for_graphviz(match.matched_nodes_names())
                self.replace_sub_graph(graph, match)

    # presumably per-class registries used by the class registration machinery; declared here so this class does
    # not share them with its parent - TODO confirm
    registered_ops = {}
    registered_cls = []

    @classmethod
    def class_type(cls):
        """:return: the registration category of this replacer (front replacer)."""
        return class_registration.ClassType.FRONT_REPLACER


# presumably keeps this base class itself from being run as a replacer - confirm in ReplacementPattern
ReplacementPattern.excluded_replacers.append(FrontReplacementFromConfigFileSubGraph)
100
101
class FrontReplacementFromConfigFileOp(FrontReplacementFromConfigFileSubGraph):
    """
    Replace a sub-graph defined in the configuration file with a single operation.
    """
    replacement_id = ""

    def __init__(self):
        super().__init__()

    def input_edges_match(self,  # pylint: disable=method-hidden
                          graph: Graph,
                          match: SubgraphMatch,
                          new_sub_graph: dict):
        """
        Function that generates matching of sub-graph input edges to a new sub-graph input edges. It works in case when
        the sub-graph is replaced with a single custom-layer node.
        :param graph: networkX graph to operate on.
        :param match: object describing matched sub-graph.
        :param new_sub_graph: dictionary of Nodes objects that forms new sub-graph.
        :return: dict mapping (matched input node id, its port) to (new node id, sub-graph input port index).
        """
        input_edges_match = dict()
        inputs_count = match.inputs_count()
        for sub_graph_input_port in range(inputs_count):
            # just create single edge for each input port of the sub-graph: only the first consumer of the port
            # is used
            input_node, input_port = match.input_nodes(sub_graph_input_port)[0]
            input_edges_match[(input_node.id, input_port)] = (new_sub_graph['new_node'].id, sub_graph_input_port)
        return input_edges_match

    def output_edges_match(self,  # pylint: disable=method-hidden
                           graph: Graph,
                           match: SubgraphMatch,
                           new_sub_graph: dict):
        """
        Function that generates matching of sub-graph output edges to a new sub-graph output edges. It works in case
        when the sub-graph is replaced with a single custom-layer node.
        :param graph: networkX graph to operate on.
        :param match: object describing matched sub-graph.
        :param new_sub_graph: dictionary of Nodes objects that forms new sub-graph.
        :return: dict mapping (matched output node id, its port) to (new node id, sub-graph output port index).
        """
        output_edges_match = dict()
        outputs_count = match.outputs_count()
        # prepare output_edges_match based on custom replacement configuration file
        for sub_graph_output_port in range(outputs_count):
            output_node, output_port = match.output_node(sub_graph_output_port)
            output_edges_match[(output_node.id, output_port)] = (new_sub_graph['new_node'].id, sub_graph_output_port)
        return output_edges_match

    def generate_sub_graph(self, graph: Graph, match: SubgraphMatch):
        """
        Create the single operation that replaces the matched sub-graph.

        The operation class is looked up by the 'op' of the replacement description and instantiated with the
        description's custom attributes. If the operation has no 'infer' attribute, the matched nodes are merged into
        one node (which already receives the input/output edges) and the operation attributes are copied onto it;
        the edge-matching functions are then overridden on this instance to return empty matchings. Otherwise a new
        node is added for the operation and the class-level edge-matching methods connect it.
        :param graph: graph to operate on.
        :param match: object describing the matched sub-graph.
        :return: dictionary with the created node under the 'new_node' key.
        """
        replacement_desc = match.custom_replacement_desc
        custom_attributes = replacement_desc.custom_attributes
        op = Op.get_op_class_by_name(replacement_desc.op)(graph, custom_attributes)
        op.default_backend_attrs = list(custom_attributes.keys())
        if 'infer' not in op.attrs:
            # update IE attrs
            op.substitute_ie_attrs(op.attrs)
            node = merge_nodes(graph, match.matched_nodes_names(), replacement_desc.get_inputs_description(),
                               replacement_desc.get_outputs_description())
            node.name = graph.unique_id(op.attrs['type'])
            node_attrs = graph.node[node.id]
            # copy attributes which are defined in the custom operation
            for key, value in op.attrs.items():
                if key not in ('name', 'op'):
                    node_attrs[key] = value
            # functions below should return nothing because 'merge_nodes' already created input/output edges
            self.input_edges_match = lambda gr, ma, new_sub_graph: dict()  # pylint: disable=method-hidden
            self.output_edges_match = lambda gr, ma, new_sub_graph: dict()  # pylint: disable=method-hidden
        else:
            # The overrides above are stored on the instance and would otherwise persist to later matches handled
            # by this same replacer, yielding empty edge matchings for this freshly added node. Remove them so the
            # class-level edge-matching methods apply again.
            self.__dict__.pop('input_edges_match', None)
            self.__dict__.pop('output_edges_match', None)
            node = op.add_node(name=op.attrs['type'] + '_')
            node.type = op.attrs['type']
        return {'new_node': node}

    # presumably per-class registries used by the class registration machinery; declared here so this class does
    # not share them with its parent - TODO confirm
    registered_ops = {}
    registered_cls = []

    @classmethod
    def class_type(cls):
        """:return: the registration category of this replacer (front replacer)."""
        return class_registration.ClassType.FRONT_REPLACER


# presumably keeps this base class itself from being run as a replacer - confirm in ReplacementPattern
ReplacementPattern.excluded_replacers.append(FrontReplacementFromConfigFileOp)