Publishing 2019 R1 content
[platform/upstream/dldt.git] / model-optimizer / extensions / front / mxnet / max_ext.py
index 3db428c..4af1468 100644 (file)
@@ -1,5 +1,5 @@
 """
- Copyright (c) 2018 Intel Corporation
+ Copyright (c) 2018-2019 Intel Corporation
 
  Licensed under the Apache License, Version 2.0 (the "License");
  you may not use this file except in compliance with the License.
@@ -14,6 +14,8 @@
  limitations under the License.
 """
 
+import numpy as np
+
 from mo.front.extractor import FrontExtractorOp
 from mo.front.mxnet.extractors.utils import get_mxnet_layer_attrs
 from mo.ops.reduce import Reduce
@@ -27,7 +29,7 @@ class MaxFrontExtractor(FrontExtractorOp):
     def extract(node):
         attrs = get_mxnet_layer_attrs(node.symbol_dict)
         data = {
-            'axis': [attrs.int('axis', 0)],
+            'axis': np.array([attrs.int('axis', 0)], dtype=np.int64),
             'reduce_type': 'max',
             'keep_dims': False
         }