Publishing 2019 R1 content
[platform/upstream/dldt.git] / model-optimizer / extensions / middle / TF_lstm_cell_to_generic.py
index b029b45..20faa4e 100644 (file)
@@ -1,5 +1,5 @@
 """
- Copyright (c) 2018 Intel Corporation
+ Copyright (c) 2018-2019 Intel Corporation
 
  Licensed under the Apache License, Version 2.0 (the "License");
  you may not use this file except in compliance with the License.
  limitations under the License.
 """
 
-import networkx as nx
 import numpy as np
 
 from extensions.middle.FusePermutesSequence import FusePermutesSequence
+from mo.graph.graph import Graph
 from mo.middle.replacement import MiddleReplacementPattern
 
 
@@ -31,7 +31,8 @@ class TensorFlowLSTMtoGeneric(MiddleReplacementPattern):
     enabled = True
 
     def run_after(self):
-        return []
+        from extensions.middle.pass_separator import MiddleStart
+        return [MiddleStart]
 
     def run_before(self):
         return [
@@ -44,7 +45,7 @@ class TensorFlowLSTMtoGeneric(MiddleReplacementPattern):
             edges=[]
         )
 
-    def replace_pattern(self, graph: nx.MultiDiGraph, match: dict):
+    def replace_pattern(self, graph: Graph, match: dict):
         weights_node = match['lstm'].in_node(3)
         biases_node = match['lstm'].in_node(4)
         node = match['lstm']
@@ -61,9 +62,9 @@ class TensorFlowLSTMtoGeneric(MiddleReplacementPattern):
         hidden_size = node.in_node(1).shape[1]
         weights = weights_node.value
         biases = biases_node.value
-        assert weights.shape[0] == input_size + hidden_size, "weights.shape={} input_size={} hidden_size={}".format(
-            weights.shape, input_size, hidden_size)
-        assert weights.shape[1] == biases.shape[0] == 4 * hidden_size,\
+        assert weights.shape[0] == input_size + hidden_size, \
+            "weights.shape={} input_size={} hidden_size={}".format(weights.shape, input_size, hidden_size)
+        assert weights.shape[1] == biases.shape[0] == 4 * hidden_size, \
             "weights.shape={} biases.shape={} hidden_size={}".format(weights.shape, biases.shape, hidden_size)
 
         weights = weights.reshape([