"""
- Copyright (c) 2018 Intel Corporation
+ Copyright (c) 2018-2019 Intel Corporation
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
for i in range(0, len(exp_shape)):
self.assertEqual(exp_shape[i], res_shape[i])
+ def test_expand_dims_infer_one_input_3(self):
+ graph = build_graph(nodes_attributes,
+ [('input_1', 'expand_dims'),
+ ('expand_dims', 'out')],
+ {'input_1': {'shape': np.array([3, 256, 256])},
+ 'expand_dims': {'expand_axis': -1}
+ })
+
+ expand_dims_node = Node(graph, 'expand_dims')
+
+ tf_expand_dims_infer(expand_dims_node)
+ exp_shape = np.array([3, 256, 256, 1])
+ res_shape = expand_dims_node.out_node().shape
+ self.assertEqual(len(exp_shape), len(res_shape))
+ for i in range(0, len(exp_shape)):
+ self.assertEqual(exp_shape[i], res_shape[i])
+
+ def test_expand_dims_infer_one_input_4(self):
+ graph = build_graph(nodes_attributes,
+ [('input_1', 'expand_dims'),
+ ('expand_dims', 'out')],
+ {'input_1': {'shape': np.array([3, 256, 256])},
+ 'expand_dims': {'expand_axis': -2}
+ })
+
+ expand_dims_node = Node(graph, 'expand_dims')
+
+ tf_expand_dims_infer(expand_dims_node)
+ exp_shape = np.array([3, 256, 1, 256])
+ res_shape = expand_dims_node.out_node().shape
+ self.assertEqual(len(exp_shape), len(res_shape))
+ for i in range(0, len(exp_shape)):
+ self.assertEqual(exp_shape[i], res_shape[i])
+
def test_expand_dims_infer_one_input_negative(self):
graph = build_graph(nodes_attributes,
[('input_1', 'expand_dims'),