Publishing 2019 R1 content
[platform/upstream/dldt.git] / model-optimizer / extensions / ops / gather.py
index 255fd1f..210633d 100644 (file)
@@ -1,5 +1,5 @@
 """
- Copyright (c) 2017-2018 Intel Corporation
+ Copyright (c) 2017-2019 Intel Corporation
 
  Licensed under the Apache License, Version 2.0 (the "License");
  you may not use this file except in compliance with the License.
  limitations under the License.
 """
 
+import logging as log
+
 import networkx as nx
 import numpy as np
 
-from mo.graph.graph import Node
+from mo.graph.graph import Node, Graph
 from mo.ops.op import Op
 
 
 class Gather(Op):
     op = 'Gather'
 
-    def __init__(self, graph: nx.MultiDiGraph, attrs: dict):
+    def __init__(self, graph: Graph, attrs: dict):
         mandatory_props = {
             'type': __class__.op,
             'op': __class__.op,
             'axis': 0,
-            'infer': __class__.infer
+            'in_ports_count': 3,
+            'out_ports_count': 1,
+            'infer': __class__.infer,
         }
         super().__init__(graph, mandatory_props, attrs)
 
@@ -62,6 +66,6 @@ class Gather(Op):
 
         shape = np.concatenate((data.shape[:axis], indices.shape))
         if axis < len(data.shape) - 1:
-            shape = np.concatenate((shape, data.shape[axis+1:]))
+            shape = np.concatenate((shape, data.shape[axis + 1:]))
 
         node.out_node(0).shape = np.array(shape, dtype=np.int64)