"""
- Copyright (c) 2018 Intel Corporation
+ Copyright (c) 2018-2019 Intel Corporation
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
limitations under the License.
"""
-import networkx as nx
import numpy as np
from mo.front.common.replacement import FrontReplacementOp
+from mo.graph.graph import Graph
from mo.ops.const import Const
from mo.ops.lin_op import Mul, Add
op = "ImageScaler"
enabled = True
- def replace_sub_graph(self, graph: nx.MultiDiGraph, match: dict):
+ def replace_sub_graph(self, graph: Graph, match: dict):
# This replacer replace ImageScalar operation to Mul->Add sequence
# Also it check that weights and biases are good
op = match['op']
if all([x == 0 for x in np.nditer(op.bias)]):
has_bias = False
- # Get all outputs for op node
- out_nodes = [node for node in op.out_nodes().values()]
+ assert len(op.in_ports()) == 1
- assert len(op.in_nodes()) == 1
+ last_port = op.in_port(0).get_source()
- last_node = op.in_node()
# Create Mul & Add nodes
if has_weights:
- mul_weights = Const(graph, dict(value=op.scale, shape=op.scale.shape))
- mul_op = Mul(graph, dict(name=op.id + '/mul_'))
- last_node = mul_op.create_node(inputs=[last_node, mul_weights.create_node()])
+ mul_weights = Const(graph, dict(value=op.scale, shape=op.scale.shape)).create_node()
+ mul_op = Mul(graph, dict(name=op.id + '/mul_')).create_node()
+ op.in_port(0).get_connection().set_destination(mul_op.in_port(0))
+ mul_weights.out_port(0).connect(mul_op.in_port(1))
+ last_port = mul_op.out_port(0)
if has_bias:
- add_bias = Const(graph, dict(value=op.bias, shape=op.bias.shape))
- add_op = Add(graph, dict(name=op.id + '/add_'))
- last_node = add_op.create_node(inputs=[last_node, add_bias.create_node()])
-
- # Move edges from ImageScaler to last_node (Mul or Add)
- for out_node in out_nodes:
- edge_attrs = graph.get_edge_data(op.id, out_node.id)[0]
- graph.remove_edge(op.id, out_node.id)
- graph.add_edges_from([(last_node.id, out_node.id, edge_attrs)])
-
- # Disconnect ImageScalar node
- graph.remove_edge(op.in_node().id, op.id)
+ add_bias = Const(graph, dict(value=op.bias, shape=op.bias.shape)).create_node()
+ add_op = Add(graph, dict(name=op.id + '/add_')).create_node()
+ last_port.get_connection().set_destination(add_op.in_port(0))
+ add_bias.out_port(0).connect(add_op.in_port(1))
+ last_port = add_op.out_port(0)
+
+ op.in_port(0).disconnect()
+ op.out_port(0).get_connection().set_source(last_port)