Publishing 2019 R1 content
[platform/upstream/dldt.git] / model-optimizer / mo / middle / passes / fusing / decomposition.py
index 737074f..cf6739d 100644 (file)
@@ -1,5 +1,5 @@
 """
- Copyright (c) 2018 Intel Corporation
+ Copyright (c) 2018-2019 Intel Corporation
 
  Licensed under the Apache License, Version 2.0 (the "License");
  you may not use this file except in compliance with the License.
@@ -19,7 +19,7 @@ import logging as log
 import networkx as nx
 import numpy as np
 
-from mo.graph.graph import Node
+from mo.graph.graph import Node, Graph
 from mo.middle.passes.eliminate import merge_data_nodes
 from mo.middle.pattern_match import apply_pattern
 from mo.ops.lin_op import Mul, Add
@@ -27,7 +27,7 @@ from mo.ops.op import Op
 from mo.ops.reshape import Reshape
 
 
-def convert_batch_norm(graph: nx.MultiDiGraph):
+def convert_batch_norm(graph: Graph):
     """
     This function finds FusedBatchNorm layer (or BatchNorm for MXNet) and replaces with Mul->Add->Mul->Add sequence.
     """
@@ -78,7 +78,7 @@ def convert_batch_norm(graph: nx.MultiDiGraph):
             _fused_batch_norm_decomposition(graph, tinput, toutput, const, beta, scale, shift, can_be_fused)
 
 
-def _fused_batch_norm_decomposition(graph: nx.MultiDiGraph, tinput: Node, toutput: Node, gamma: Node, beta: Node,
+def _fused_batch_norm_decomposition(graph: Graph, tinput: Node, toutput: Node, gamma: Node, beta: Node,
                                     mean: np.ndarray, variance: np.ndarray, can_be_fused=True):
     """
     This is common function for TF, Caffe and MXNet
@@ -113,64 +113,108 @@ def _fused_batch_norm_decomposition(graph: nx.MultiDiGraph, tinput: Node, toutpu
         data_nodes=toutput)
 
 
-def convert_scale_shift_to_mul_add(graph: nx.MultiDiGraph):
-    nodes = [Node(graph, node) for node in graph.nodes() if Node(graph, node).soft_get('op') == 'ScaleShift']
+def convert_scale_shift_to_mul_add(graph: Graph):
+    nodes = graph.get_op_nodes(op='ScaleShift')
     for node in nodes:
         if node.soft_get('can_be_fused') is False:
             continue
 
+        ports_count = len(node.in_ports())
+
+        input_port = node.in_port(0)
+        scale_port = node.in_port(1) if ports_count > 1 and not node.in_port(1).disconnected() else None
+        shift_port = node.in_port(2) if ports_count > 2 and not node.in_port(2).disconnected() else None
+        output_port = node.out_port(0)
+
         has_biases = True
         has_weights = True
+
         # We don't need zero biases
-        if len(node.in_nodes()) < 3 or all([x == 0 for x in node.in_node(2).value]):
+        if shift_port is None or (shift_port.data.get_value() is not None and all([x == 0 for x in shift_port.data.get_value()])):
             has_biases = False
-        input_node = node.in_node(0)
-        scale_node = node.in_node(1)
-        shift_node = node.in_node(2) if has_biases else None
-        output_node = node.out_node()
 
-        if scale_node.has_valid("value") and all([x == 1 for x in scale_node.value]):
+        # We don't need weights with ones
+        if scale_port is None or (scale_port.data.get_value() is not None and all([x == 1 for x in scale_port.data.get_value()])):
             has_weights = False
 
-        mul_node = Mul(graph, dict(name=node.name + "/Mul_"))
-        add_node = Add(graph, dict(name=node.name + "/Add_"))
-
-        # Disconnect ScaleShift node
-        graph.remove_edge(input_node.id, node.id)
-        graph.remove_edge(node.id, output_node.id)
+        mul_op = Mul(graph, dict(name=node.name + "/Mul_"))
+        add_op = Add(graph, dict(name=node.name + "/Add_"))
 
         # Expand dims for current layout
-        broadcast_dims_cnt = len(input_node.shape) - 2 if graph.graph['layout'] == 'NCHW' else 0
-        if scale_node.has_valid("value"):
-            Op.expand_node_shape(scale_node, broadcast_dims_cnt)
-        else:
-            # insert reshape to make shapes similar
-            reshape_dims = np.zeros(len(input_node.shape), dtype=np.int64)
+        broadcast_dims_cnt = len(input_port.data.get_shape()) - 2 if graph.graph['layout'] == 'NCHW' else 0
+
+        # In case if we have constant weights/biases we have to broadcast them according to graph layout
+        # otherwise we insert Reshape with broadcast dim attribute.
+        def broadcast_value(port):
+            value = np.array(port.data.get_value())
+            for idx in range(broadcast_dims_cnt):
+                value = np.expand_dims(value, axis=-1)
+            port.data.set_value(value)
+
+        def broadcast_with_reshape(port):
+            input_shape = input_port.data.get_shape()
+            reshape_dims = np.zeros(len(input_shape), dtype=np.int64)
             for i in range(0, node.axis):
                 reshape_dims[i] = 1
-            for i in range(node.axis, node.axis + len(scale_node.shape)):
-                reshape_dims[i] = scale_node.shape[i-node.axis]
-            for i in range(node.axis + len(scale_node.shape), len(input_node.shape)):
+            data_shape = port.data.get_shape()
+            for i in range(node.axis, node.axis + len(data_shape)):
+                reshape_dims[i] = data_shape[i - node.axis]
+            for i in range(node.axis + len(data_shape), len(input_shape)):
                 reshape_dims[i] = 1
-            reshape = Reshape(graph, dict(name=scale_node.name+"/Broadcast_",
-                                          dim=reshape_dims))
-            scale_node = reshape.create_node_with_data(inputs=[scale_node])
+            reshape = Reshape(graph, dict(name=port.node.name + "/Broadcast_", dim=reshape_dims)).create_node()
+            port.get_connection().set_destination(reshape.in_port(0))
+            reshape.out_port(0).connect(port)
 
-        Op.expand_node_shape(shift_node, broadcast_dims_cnt)
+        if has_weights and scale_port.data.get_value() is not None:
+            broadcast_value(scale_port)
+        elif has_weights:
+            broadcast_with_reshape(scale_port)
 
-        # Connect input->mul->out->add->out
-        if has_biases:
-            add_node.create_node_with_data(
-                inputs=[mul_node.create_node_with_data(inputs=[input_node, scale_node]), shift_node],
-                data_nodes=output_node)
+        if has_biases and shift_port.data.get_value() is not None:
+            broadcast_value(shift_port)
+        elif has_biases:
+            broadcast_with_reshape(shift_port)
+
+        if has_biases and has_weights:
+            # Connect input->mul->out->add->out
+            add_node = add_op.create_node()
+            mul_node = mul_op.create_node()
+
+            # Connect Mul operation with inputs
+            input_port.get_connection().set_destination(mul_node.in_port(0))
+            scale_port.get_connection().set_destination(mul_node.in_port(1))
+
+            # Connect Add operation with inputs
+            mul_node.out_port(0).connect(add_node.in_port(0))
+            shift_port.get_connection().set_destination(add_node.in_port(1))
+
+            output_port.get_connection().set_source(add_node.out_port(0))
         elif has_weights:
-            mul_node.create_node_with_data(inputs=[input_node, scale_node], data_nodes=output_node)
+            # Connect input->mul->out
+            mul_node = mul_op.create_node()
+
+            # Connect Mul operation with inputs
+            input_port.get_connection().set_destination(mul_node.in_port(0))
+            scale_port.get_connection().set_destination(mul_node.in_port(1))
+
+            output_port.get_connection().set_source(mul_node.out_port(0))
+        elif has_biases:
+            # Connect input->add->out
+            add_node = add_op.create_node()
+
+            # Connect Add operation with inputs
+            input_port.get_connection().set_destination(add_node.in_port(0))
+            shift_port.get_connection().set_destination(add_node.in_port(1))
+
+            output_port.get_connection().set_source(add_node.out_port(0))
         else:
-            merge_data_nodes(graph, input_node, output_node)
-            graph.remove_node(output_node.id)
+            # Connect input->out
+            producer_port = input_port.get_source()
+            input_port.disconnect()
+            output_port.get_connection().set_source(producer_port)
 
 
-def _bn_to_mul_add_action(graph: nx.MultiDiGraph, match: dict):
+def _bn_to_mul_add_action(graph: Graph, match: dict):
     # Data nodes
     tinput = match['input']
     toutput = match['output']
@@ -209,7 +253,7 @@ def _bn_to_mul_add_action(graph: nx.MultiDiGraph, match: dict):
                                    data_nodes=toutput)
 
 
-def convert_bn_to_mul_add(graph: nx.MultiDiGraph):
+def convert_bn_to_mul_add(graph: Graph):
     apply_pattern(
         graph,
         nodes=[