"""
- Copyright (c) 2018 Intel Corporation
+ Copyright (c) 2018-2019 Intel Corporation
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
import networkx as nx
import numpy as np
-from mo.graph.graph import Node
+from mo.graph.graph import Node, Graph
from mo.middle.passes.eliminate import merge_data_nodes
from mo.middle.pattern_match import apply_pattern
from mo.ops.lin_op import Mul, Add
from mo.ops.reshape import Reshape
-def convert_batch_norm(graph: nx.MultiDiGraph):
+def convert_batch_norm(graph: Graph):
"""
This function finds a FusedBatchNorm layer (or BatchNorm for MXNet) and replaces it with a Mul->Add->Mul->Add sequence.
"""
_fused_batch_norm_decomposition(graph, tinput, toutput, const, beta, scale, shift, can_be_fused)
-def _fused_batch_norm_decomposition(graph: nx.MultiDiGraph, tinput: Node, toutput: Node, gamma: Node, beta: Node,
+def _fused_batch_norm_decomposition(graph: Graph, tinput: Node, toutput: Node, gamma: Node, beta: Node,
mean: np.ndarray, variance: np.ndarray, can_be_fused=True):
"""
This is a common function for TF, Caffe and MXNet.
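+ It builds the Mul->Add->Mul->Add subgraph between tinput and toutput; gamma and beta
+ arrive as graph nodes, while mean and variance come in as constant numpy arrays.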
data_nodes=toutput)
-def convert_scale_shift_to_mul_add(graph: nx.MultiDiGraph):
- nodes = [Node(graph, node) for node in graph.nodes() if Node(graph, node).soft_get('op') == 'ScaleShift']
+def convert_scale_shift_to_mul_add(graph: Graph):
+ nodes = graph.get_op_nodes(op='ScaleShift')
for node in nodes:
if node.soft_get('can_be_fused') is False:
continue
+ ports_count = len(node.in_ports())
+
+ input_port = node.in_port(0)
+ scale_port = node.in_port(1) if ports_count > 1 and not node.in_port(1).disconnected() else None
+ shift_port = node.in_port(2) if ports_count > 2 and not node.in_port(2).disconnected() else None
+ output_port = node.out_port(0)
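+ # ScaleShift inputs are positional: 0 - data, 1 - scale (weights), 2 - shift (biases);
+ # scale/shift may be absent or disconnected, hence the guards above.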
+
has_biases = True
has_weights = True
+
# We don't need zero biases
- if len(node.in_nodes()) < 3 or all([x == 0 for x in node.in_node(2).value]):
+ if shift_port is None or (shift_port.data.get_value() is not None and all([x == 0 for x in shift_port.data.get_value()])):
has_biases = False
- input_node = node.in_node(0)
- scale_node = node.in_node(1)
- shift_node = node.in_node(2) if has_biases else None
- output_node = node.out_node()
- if scale_node.has_valid("value") and all([x == 1 for x in scale_node.value]):
+ # We don't need weights that are all ones
+ if scale_port is None or (scale_port.data.get_value() is not None and all([x == 1 for x in scale_port.data.get_value()])):
has_weights = False
- mul_node = Mul(graph, dict(name=node.name + "/Mul_"))
- add_node = Add(graph, dict(name=node.name + "/Add_"))
-
- # Disconnect ScaleShift node
- graph.remove_edge(input_node.id, node.id)
- graph.remove_edge(node.id, output_node.id)
+ mul_op = Mul(graph, dict(name=node.name + "/Mul_"))
+ add_op = Add(graph, dict(name=node.name + "/Add_"))
# Expand dims for current layout
- broadcast_dims_cnt = len(input_node.shape) - 2 if graph.graph['layout'] == 'NCHW' else 0
- if scale_node.has_valid("value"):
- Op.expand_node_shape(scale_node, broadcast_dims_cnt)
- else:
- # insert reshape to make shapes similar
- reshape_dims = np.zeros(len(input_node.shape), dtype=np.int64)
+ broadcast_dims_cnt = len(input_port.data.get_shape()) - 2 if graph.graph['layout'] == 'NCHW' else 0
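+ # For NCHW the per-channel constants need trailing singleton dims to line up with
+ # the C axis; for non-NCHW layouts (e.g. NHWC) the channel axis is already last,
+ # so no expansion is needed.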
+
+ # If the weights/biases are constants, we broadcast them in place according to the
+ # graph layout; otherwise we insert a Reshape with the broadcast dims.
+ def broadcast_value(port):
+ value = np.array(port.data.get_value())
+ for idx in range(broadcast_dims_cnt):
+ value = np.expand_dims(value, axis=-1)
+ port.data.set_value(value)
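+ # For example, with an NCHW rank-4 input, broadcast_dims_cnt == 2, so a per-channel
+ # value of shape [C] becomes [C, 1, 1] and broadcasts over H and W.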
+
+ def broadcast_with_reshape(port):
+ input_shape = input_port.data.get_shape()
+ reshape_dims = np.zeros(len(input_shape), dtype=np.int64)
for i in range(0, node.axis):
reshape_dims[i] = 1
- for i in range(node.axis, node.axis + len(scale_node.shape)):
- reshape_dims[i] = scale_node.shape[i-node.axis]
- for i in range(node.axis + len(scale_node.shape), len(input_node.shape)):
+ data_shape = port.data.get_shape()
+ for i in range(node.axis, node.axis + len(data_shape)):
+ reshape_dims[i] = data_shape[i - node.axis]
+ for i in range(node.axis + len(data_shape), len(input_shape)):
reshape_dims[i] = 1
- reshape = Reshape(graph, dict(name=scale_node.name+"/Broadcast_",
- dim=reshape_dims))
- scale_node = reshape.create_node_with_data(inputs=[scale_node])
+ reshape = Reshape(graph, dict(name=port.node.name + "/Broadcast_", dim=reshape_dims)).create_node()
+ port.get_connection().set_destination(reshape.in_port(0))
+ reshape.out_port(0).connect(port)
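+ # For example, with axis == 1, input shape [N, C, H, W] and scale shape [C],
+ # reshape_dims comes out as [1, C, 1, 1], so the inserted Reshape makes the
+ # non-constant scale/shift input broadcastable against the data.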
- Op.expand_node_shape(shift_node, broadcast_dims_cnt)
+ if has_weights and scale_port.data.get_value() is not None:
+ broadcast_value(scale_port)
+ elif has_weights:
+ broadcast_with_reshape(scale_port)
- # Connect input->mul->out->add->out
- if has_biases:
- add_node.create_node_with_data(
- inputs=[mul_node.create_node_with_data(inputs=[input_node, scale_node]), shift_node],
- data_nodes=output_node)
+ if has_biases and shift_port.data.get_value() is not None:
+ broadcast_value(shift_port)
+ elif has_biases:
+ broadcast_with_reshape(shift_port)
+
+ if has_biases and has_weights:
+ # Connect input->mul->add->out
+ add_node = add_op.create_node()
+ mul_node = mul_op.create_node()
+
+ # Connect Mul operation with inputs
+ input_port.get_connection().set_destination(mul_node.in_port(0))
+ scale_port.get_connection().set_destination(mul_node.in_port(1))
+
+ # Connect Add operation with inputs
+ mul_node.out_port(0).connect(add_node.in_port(0))
+ shift_port.get_connection().set_destination(add_node.in_port(1))
+
+ output_port.get_connection().set_source(add_node.out_port(0))
elif has_weights:
- mul_node.create_node_with_data(inputs=[input_node, scale_node], data_nodes=output_node)
+ # Connect input->mul->out
+ mul_node = mul_op.create_node()
+
+ # Connect Mul operation with inputs
+ input_port.get_connection().set_destination(mul_node.in_port(0))
+ scale_port.get_connection().set_destination(mul_node.in_port(1))
+
+ output_port.get_connection().set_source(mul_node.out_port(0))
+ elif has_biases:
+ # Connect input->add->out
+ add_node = add_op.create_node()
+
+ # Connect Add operation with inputs
+ input_port.get_connection().set_destination(add_node.in_port(0))
+ shift_port.get_connection().set_destination(add_node.in_port(1))
+
+ output_port.get_connection().set_source(add_node.out_port(0))
else:
- merge_data_nodes(graph, input_node, output_node)
- graph.remove_node(output_node.id)
+ # Connect input->out
+ producer_port = input_port.get_source()
+ input_port.disconnect()
+ output_port.get_connection().set_source(producer_port)
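+ # At this point the ScaleShift node is fully disconnected; presumably the shared
+ # graph clean-up pass sweeps it (and any orphaned constant inputs) away afterwards.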
-def _bn_to_mul_add_action(graph: nx.MultiDiGraph, match: dict):
+def _bn_to_mul_add_action(graph: Graph, match: dict):
# Data nodes
tinput = match['input']
toutput = match['output']
data_nodes=toutput)
-def convert_bn_to_mul_add(graph: nx.MultiDiGraph):
+def convert_bn_to_mul_add(graph: Graph):
apply_pattern(
graph,
nodes=[