Publishing 2019 R1 content
[platform/upstream/dldt.git] / model-optimizer / mo / front / common / partial_infer / expand_dims_test.py
index 69dbc44..119c3c2 100644 (file)
@@ -1,5 +1,5 @@
 """
- Copyright (c) 2018 Intel Corporation
+ Copyright (c) 2018-2019 Intel Corporation
 
  Licensed under the Apache License, Version 2.0 (the "License");
  you may not use this file except in compliance with the License.
@@ -146,6 +146,40 @@ class TestExpandDimsInfer(unittest.TestCase):
         for i in range(0, len(exp_shape)):
             self.assertEqual(exp_shape[i], res_shape[i])
 
+    def test_expand_dims_infer_one_input_3(self):
+        graph = build_graph(nodes_attributes,
+                            [('input_1', 'expand_dims'),
+                             ('expand_dims', 'out')],
+                            {'input_1': {'shape': np.array([3, 256, 256])},
+                             'expand_dims': {'expand_axis': -1}
+                             })
+
+        expand_dims_node = Node(graph, 'expand_dims')
+
+        tf_expand_dims_infer(expand_dims_node)
+        exp_shape = np.array([3, 256, 256, 1])
+        res_shape = expand_dims_node.out_node().shape
+        self.assertEqual(len(exp_shape), len(res_shape))
+        for i in range(0, len(exp_shape)):
+            self.assertEqual(exp_shape[i], res_shape[i])
+
+    def test_expand_dims_infer_one_input_4(self):
+        graph = build_graph(nodes_attributes,
+                            [('input_1', 'expand_dims'),
+                             ('expand_dims', 'out')],
+                            {'input_1': {'shape': np.array([3, 256, 256])},
+                             'expand_dims': {'expand_axis': -2}
+                             })
+
+        expand_dims_node = Node(graph, 'expand_dims')
+
+        tf_expand_dims_infer(expand_dims_node)
+        exp_shape = np.array([3, 256, 1, 256])
+        res_shape = expand_dims_node.out_node().shape
+        self.assertEqual(len(exp_shape), len(res_shape))
+        for i in range(0, len(exp_shape)):
+            self.assertEqual(exp_shape[i], res_shape[i])
+
     def test_expand_dims_infer_one_input_negative(self):
         graph = build_graph(nodes_attributes,
                             [('input_1', 'expand_dims'),