Publishing 2019 R1 content
[platform/upstream/dldt.git] / model-optimizer / extensions / middle / GemmResolver.py
index 29a39b9..edef22a 100644 (file)
@@ -1,5 +1,5 @@
 """
- Copyright (c) 2018 Intel Corporation
+ Copyright (c) 2018-2019 Intel Corporation
 
  Licensed under the Apache License, Version 2.0 (the "License");
  you may not use this file except in compliance with the License.
  limitations under the License.
 """
 
-import networkx as nx
-
-from extensions.middle.NormalizeFullyConnected import NormalizeFullyConnected
 from mo.front.common.partial_infer.utils import mark_input_bins, assign_dims_to_weights, int64_array
+from mo.graph.graph import Graph
 from mo.middle.replacement import MiddleReplacementPattern
 from mo.ops.op import PermuteAttrs
 
 
 class GemmResolver(MiddleReplacementPattern):
     enabled = True
+    graph_condition = [lambda graph: graph.graph['fw'] != 'tf']
 
     def run_before(self):
+        from extensions.middle.NormalizeFullyConnected import NormalizeFullyConnected
         return [NormalizeFullyConnected]
 
+    def run_after(self):
+        from extensions.middle.pass_separator import MiddleStart
+        return [MiddleStart]
+
     def pattern(self):
         return dict(
             nodes=[
-                   ('input_0', dict(kind='data')),
-                   ('input_1', dict(kind='data')),
-                   ('fc', dict(op='MatMul')),
-                   ('fc_data', dict(kind='data'))],
+                ('input_0', dict(kind='data')),
+                ('input_1', dict(kind='data')),
+                ('fc', dict(op='MatMul')),
+                ('fc_data', dict(kind='data'))],
             edges=[
                 ('input_0', 'fc', {'in': 0}),
                 ('input_1', 'fc', {'in': 1}),
@@ -42,9 +46,10 @@ class GemmResolver(MiddleReplacementPattern):
             ]
         )
 
-    def replace_pattern(self, graph: nx.MultiDiGraph, match: dict):
+    def replace_pattern(self, graph: Graph, match: dict):
         if not match['input_0'].has_valid('value') and not match['input_1'].has_valid('value') or \
-                not match['input_0'].has_valid('value') and match['input_1'].has_valid('value') and match['input_1'].shape.size > 2:
+                not match['input_0'].has_valid('value') and match['input_1'].has_valid('value') and match[
+            'input_1'].shape.size > 2:
             match['fc']['type'] = 'GEMM'
         elif not match['input_0'].has_valid('value') and match['input_1'].has_valid('value'):
             match['fc']['type'] = 'FullyConnected'
@@ -57,6 +62,3 @@ class GemmResolver(MiddleReplacementPattern):
             weights_shape = weights_node.shape
 
             node['out-size'] = weights_shape[1]
-
-
-