"""
 Copyright (c) 2017-2019 Intel Corporation

 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
 You may obtain a copy of the License at

      http://www.apache.org/licenses/LICENSE-2.0

 Unless required by applicable law or agreed to in writing, software
 distributed under the License is distributed on an "AS IS" BASIS,
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
"""
import logging as log

from mo.front.common.custom_replacement_registry import CustomReplacementRegistry
from mo.front.common.replacement import FrontReplacementSubgraph, FrontReplacementPattern
from mo.front.subgraph_matcher import SubgraphMatcher, SubgraphMatch
from mo.front.tf.custom_subgraph_call import merge_nodes
from mo.graph.graph import Graph
from mo.ops.op import Op
from mo.utils import class_registration
from mo.utils.graph import is_connected_component
from mo.utils.replacement_pattern import ReplacementPattern


class FrontReplacementFromConfigFileGeneral(FrontReplacementPattern):
    """
    Translates graph to transform with the configuration files with custom attributes.

    Sub-classes set ``replacement_id`` and implement ``transform_graph``. ``find_and_replace_pattern`` fetches every
    custom replacement description registered under that id from ``CustomReplacementRegistry`` and passes the value
    stored under its 'custom_attributes' key to ``transform_graph``.
    """
    replacement_id = ""

    def transform_graph(self, graph, replacement_descriptions):
        """
        Transform the graph using one description's custom attributes. Must be overridden in sub-classes.

        :param graph: graph to transform.
        :param replacement_descriptions: value stored under 'custom_attributes' in one replacement description.
        :raises NotImplementedError: always, in this base class (a subclass of ``Exception``, so existing
            ``except Exception`` handlers still catch it).
        """
        raise NotImplementedError('Function "transform_graph" must be overridden in the sub-class')

    def find_and_replace_pattern(self, graph: Graph):
        """
        Call ``transform_graph`` once for every description registered under ``self.replacement_id``.

        If the registry returns nothing (``None`` or an empty list) an info message is logged and the graph is left
        untouched. Descriptions lacking a 'custom_attributes' entry are skipped with an info message.

        :param graph: graph to transform.
        """
        replacement_descriptions = CustomReplacementRegistry().get_custom_replacement_description(self.replacement_id)
        if not replacement_descriptions:
            # covers both an unknown id (None) and an empty list of descriptions
            log.info("Failed to find custom replacement description with id '{}'".format(self.replacement_id))
            return
        for desc in replacement_descriptions:
            desc_attrs = desc._replacement_desc
            if 'custom_attributes' in desc_attrs:
                self.transform_graph(graph, desc_attrs['custom_attributes'])
            else:
                log.info("Failed to find \'custom_attributes\' in replacement description with id '{}'".format(
                    self.replacement_id))

    # Class-level containers owned by this class (they shadow any inherited ones); presumably filled by
    # class_registration — confirm against mo.utils.class_registration.
    registered_ops = {}
    registered_cls = []

    @classmethod
    def class_type(cls):
        """Return the registration category of this replacer."""
        return class_registration.ClassType.FRONT_REPLACER


# The base class is not usable by itself (transform_graph raises), so keep it out of the applied replacers.
ReplacementPattern.excluded_replacers.append(FrontReplacementFromConfigFileGeneral)


class FrontReplacementFromConfigFileSubGraph(FrontReplacementSubgraph):
    """
    Replace sub-graph defined in the configuration files with a sub-graph of operations.

    Every custom replacement description registered under ``replacement_id`` is matched against the graph with
    ``SubgraphMatcher``, and each match found is passed to ``replace_sub_graph``.
    """
    replacement_id = ""

    def nodes_to_remove(self, graph: Graph, match: SubgraphMatch):
        """
        Report every node of the matched sub-graph as a node to remove.

        :param graph: graph to operate on (unused).
        :param match: object describing matched sub-graph.
        :return: names of all matched nodes.
        """
        return match.matched_nodes_names()

    def find_and_replace_pattern(self, graph: Graph):
        """
        Replace every sub-graph matching a description registered under ``self.replacement_id``.

        If the registry returns nothing (``None`` or an empty list) an info message is logged and the graph is left
        untouched. A match whose nodes do not form a connected sub-graph is reported (warning plus a graphviz dump of
        the matched nodes) but is still replaced.

        :param graph: graph to transform.
        """
        replacement_descriptions = CustomReplacementRegistry().get_custom_replacement_description(self.replacement_id)
        if not replacement_descriptions:
            # covers both an unknown id (None) and an empty list of descriptions
            log.info("Failed to find custom replacement description with id '{}'".format(self.replacement_id))
            return
        # several custom replacement descriptions may share the same replacement id
        for replacement_description in replacement_descriptions:
            sub_graph_matcher = SubgraphMatcher(replacement_description)
            for match in sub_graph_matcher.matched_sub_graph_instances(graph):
                matched_nodes = match.matched_nodes_names()
                if not is_connected_component(graph, matched_nodes):
                    log.warning("The following nodes don't form connected sub-graph: {}".format(matched_nodes))
                    graph.dump_graph_for_graphviz(matched_nodes)
                self.replace_sub_graph(graph, match)

    # Class-level containers owned by this class (they shadow any inherited ones); presumably filled by
    # class_registration — confirm against mo.utils.class_registration.
    registered_ops = {}
    registered_cls = []

    @classmethod
    def class_type(cls):
        """Return the registration category of this replacer."""
        return class_registration.ClassType.FRONT_REPLACER


# Generic base for config-file driven sub-graph replacers: keep it out of the applied replacers.
ReplacementPattern.excluded_replacers.append(FrontReplacementFromConfigFileSubGraph)


class FrontReplacementFromConfigFileOp(FrontReplacementFromConfigFileSubGraph):
    """
    Replace sub-graph defined in the configuration file with a single operation.

    The operation is the one named by the 'op' of the matching custom replacement description, instantiated with
    that description's custom attributes.
    """
    replacement_id = ""

    def input_edges_match(self,  # pylint: disable=method-hidden
                          graph: Graph,
                          match: SubgraphMatch,
                          new_sub_graph: dict):
        """
        Map the sub-graph input edges to the input ports of the single replacement node.

        Input port ``i`` of the matched sub-graph is connected to input port ``i`` of ``new_sub_graph['new_node']``.
        ``generate_sub_graph`` shadows this method with a no-op on the instance when ``merge_nodes`` has already
        created the edges.

        :param graph: graph to operate on (unused).
        :param match: object describing matched sub-graph.
        :param new_sub_graph: dictionary of Node objects forming the new sub-graph; must contain 'new_node'.
        :return: dict mapping (matched input node id, its port) -> (new node id, sub-graph input port).
        """
        edges = dict()
        for sub_graph_input_port in range(match.inputs_count()):
            # just create a single edge for each input port of the sub-graph: only the first (node, port) pair
            # returned by match.input_nodes() is used
            input_node, input_port = match.input_nodes(sub_graph_input_port)[0]
            edges[(input_node.id, input_port)] = (new_sub_graph['new_node'].id, sub_graph_input_port)
        return edges

    def output_edges_match(self,  # pylint: disable=method-hidden
                           graph: Graph,
                           match: SubgraphMatch,
                           new_sub_graph: dict):
        """
        Map the sub-graph output edges to the output ports of the single replacement node.

        Output port ``i`` of the matched sub-graph is taken over by output port ``i`` of ``new_sub_graph['new_node']``.
        ``generate_sub_graph`` shadows this method with a no-op on the instance when ``merge_nodes`` has already
        created the edges.

        :param graph: graph to operate on (unused).
        :param match: object describing matched sub-graph.
        :param new_sub_graph: dictionary of Node objects forming the new sub-graph; must contain 'new_node'.
        :return: dict mapping (matched output node id, its port) -> (new node id, sub-graph output port).
        """
        edges = dict()
        # prepare the output edges matching based on the custom replacement configuration file
        for sub_graph_output_port in range(match.outputs_count()):
            output_node, output_port = match.output_node(sub_graph_output_port)
            edges[(output_node.id, output_port)] = (new_sub_graph['new_node'].id, sub_graph_output_port)
        return edges

    def generate_sub_graph(self, graph: Graph, match: SubgraphMatch):
        """
        Create the single node that replaces the matched sub-graph.

        The op class is looked up by ``replacement_desc.op`` and instantiated with the description's custom
        attributes, whose keys become ``op.default_backend_attrs``.

        * If the op defines no 'infer' attribute, the matched nodes are merged into one node with ``merge_nodes``
          (which already creates the input/output edges), the node gets a unique name derived from the op 'type',
          the op attributes except 'name' and 'op' are copied onto it, and the edge-matching methods are replaced on
          this instance by no-ops.
        * Otherwise a fresh node is added with ``op.add_node`` and its edges are produced by ``input_edges_match`` /
          ``output_edges_match``.

        :param graph: graph to operate on.
        :param match: object describing matched sub-graph.
        :return: dictionary with the created node under the 'new_node' key.
        """
        replacement_desc = match.custom_replacement_desc
        custom_attributes = replacement_desc.custom_attributes
        op = Op.get_op_class_by_name(replacement_desc.op)(graph, custom_attributes)
        op.default_backend_attrs = list(custom_attributes.keys())
        if 'infer' not in op.attrs:
            op.substitute_ie_attrs(op.attrs)
            node = merge_nodes(graph, match.matched_nodes_names(), replacement_desc.get_inputs_description(),
                               replacement_desc.get_outputs_description())
            node.name = graph.unique_id(op.attrs['type'])
            node_attrs = graph.node[node.id]
            # copy attributes which are defined in the custom operation
            for key, value in op.attrs.items():
                if key not in ('name', 'op'):
                    node_attrs[key] = value
            # functions below should return nothing because 'merge_nodes' already created input/output edges
            self.input_edges_match = lambda gr, ma, new_sub_graph: dict()  # pylint: disable=method-hidden
            self.output_edges_match = lambda gr, ma, new_sub_graph: dict()  # pylint: disable=method-hidden
        else:
            # A previous merged match may have left the no-op lambdas above on this instance; drop them so this
            # node's edges come from the class methods instead of stale empty mappings.
            for method_name in ('input_edges_match', 'output_edges_match'):
                vars(self).pop(method_name, None)
            node = op.add_node(name=op.attrs['type'] + '_')
            node.type = op.attrs['type']
        return {'new_node': node}

    # Class-level containers owned by this class (they shadow any inherited ones); presumably filled by
    # class_registration — confirm against mo.utils.class_registration.
    registered_ops = {}
    registered_cls = []

    @classmethod
    def class_type(cls):
        """Return the registration category of this replacer."""
        return class_registration.ClassType.FRONT_REPLACER


# Generic base for config-file driven single-op replacers: keep it out of the applied replacers.
ReplacementPattern.excluded_replacers.append(FrontReplacementFromConfigFileOp)