Publishing 2019 R1 content
[platform/upstream/dldt.git] / model-optimizer / mo / front / common / partial_infer / expand_dims.py
index 50ac4f0..dbdebd5 100644 (file)
@@ -1,5 +1,5 @@
 """
- Copyright (c) 2018 Intel Corporation
+ Copyright (c) 2018-2019 Intel Corporation
 
  Licensed under the Apache License, Version 2.0 (the "License");
  you may not use this file except in compliance with the License.
@@ -44,6 +44,12 @@ def tf_expand_dims_infer(node):
     if expand_axis is None:
         return
 
+    # expand_axis is a position where the new axis is placed
+    # so expand_dims works for negative axis in a different way
+    # not as insert operation
+    if expand_axis < 0:
+        expand_axis += len(input_node.shape) + 1
+
     output_node.shape = np.insert(input_node.shape, expand_axis, [1])
     # convert data type of the shape to int64 explicitly
     output_node.shape = output_node.shape.astype(np.int64)