Publishing 2019 R1 content
[platform/upstream/dldt.git] / model-optimizer / mo / ops / reshape.py
index f616c8d..8cc24f1 100644 (file)
@@ -1,5 +1,5 @@
 """
- Copyright (c) 2018 Intel Corporation
+ Copyright (c) 2018-2019 Intel Corporation
 
  Licensed under the Apache License, Version 2.0 (the "License");
  you may not use this file except in compliance with the License.
 """
 import math
 
-import networkx as nx
 import numpy as np
 
 from mo.front.common.partial_infer.elemental import single_output_infer
 from mo.front.common.partial_infer.reshape import tf_reshape_shape_infer
-from mo.graph.graph import Node
+from mo.graph.graph import Node, Graph
 from mo.ops.op import Op
 from mo.utils.error import Error
 
@@ -29,19 +28,18 @@ class Reshape(Op):
     op = 'Reshape'
     enabled = True
 
-    def __init__(self, graph: nx.MultiDiGraph, attrs: dict):
+    def __init__(self, graph: Graph, attrs: dict):
         super().__init__(graph, {
             'kind': 'op',
             'type': __class__.op,
             'op': __class__.op,
+            'in_ports_count': 2,
+            'out_ports_count': 1,
             'infer': lambda node: single_output_infer(node, tf_reshape_shape_infer,
                                                       lambda node: np.reshape(node.in_node().value,
                                                                               node.out_node().shape))
         }, attrs)
 
-    def supported_attrs(self):
-        return [('dim', lambda node: ','.join(map(str, node['dim'])))]
-
     @staticmethod
     def kaldi_infer(node: Node):
         in_node = node.in_node().in_node()  # prev_layer_node -> data -> this_node
@@ -50,7 +48,7 @@ class Reshape(Op):
         # Convolution/Pooling layers. Therefore there are 4 cases with different
         # partial inference.
         batch = input_shape[0]
-        if in_node.op == 'Convolution' or in_node.op == 'Pooling':
+        if in_node.op in ['Convolution', 'Pooling', 'Permute']:
             output_spatial = np.array([batch, np.prod(input_shape[1:])], dtype=np.int64)
             return Reshape.set_shape_and_dim(node, output_spatial)
         # Supports ONLY NCHW and NH layouts