Publishing 2019 R1 content
[platform/upstream/dldt.git] / model-optimizer / extensions / front / mxnet / zeros_ext.py
index 00923d2..5fec929 100644 (file)
@@ -1,5 +1,5 @@
 """
- Copyright (c) 2018 Intel Corporation
+ Copyright (c) 2018-2019 Intel Corporation
 
  Licensed under the Apache License, Version 2.0 (the "License");
  you may not use this file except in compliance with the License.
@@ -14,6 +14,7 @@
  limitations under the License.
 """
 
+import ast
 import numpy as np
 
 from mo.front.mxnet.extractors.utils import get_mxnet_layer_attrs
@@ -29,13 +30,16 @@ class ZerosFrontExtractor(FrontExtractorOp):
     def extract(node):
         attrs = get_mxnet_layer_attrs(node.symbol_dict)
         shape = list(attrs.tuple('shape', int, None))
+        zero_shapes = []
         for i, s in enumerate(shape):
             if s == 0:
                 shape[i] = 1
+                zero_shapes.append(i)
 
         update_attrs = {
             'shape': np.ndarray(shape),
             'value': np.zeros(shape),
+            'zero_shapes': zero_shapes
         }
 
         # update the attributes of the node