Publishing 2019 R1 content
[platform/upstream/dldt.git] / model-optimizer / extensions / front / kaldi / replace_splice_node_pattern.py
index 360a225..9c14e2c 100644 (file)
@@ -1,5 +1,5 @@
 """
- Copyright (c) 2018 Intel Corporation
+ Copyright (c) 2018-2019 Intel Corporation
 
  Licensed under the Apache License, Version 2.0 (the "License");
  you may not use this file except in compliance with the License.
 """
 import numpy as np
 
-import networkx as nx
-
 from extensions.front.kaldi.replace_lstm_node_pattern import unique_id
 from mo.front.common.partial_infer.utils import int64_array
 from mo.front.common.replacement import FrontReplacementOp
-from mo.graph.graph import Node
+from mo.graph.graph import Node, Graph
 from mo.ops.concat import Concat
 from mo.ops.crop import Crop
 from mo.ops.memory import Memory
@@ -49,7 +47,7 @@ class ReplaceSpliceNodePattern(FrontReplacementOp):
     op = "Splice"
     enabled = True
 
-    def replace_op(self, graph: nx.MultiDiGraph, node: Node):
+    def replace_op(self, graph: Graph, node: Node):
         input_node = node.in_nodes()[0]
         memory_pair_id = unique_id('id')
         # Memory(in)
@@ -72,6 +70,7 @@ class ReplaceSpliceNodePattern(FrontReplacementOp):
         #         Concat
         # Input  /
         concat_node = Concat(graph, {'name': 'Splice_Concat',
+                                     'in_ports_count': 2,
                                      'axis': 1}).create_node([crop, input_node])
 
         # Concat -> Memory(out)