2 * Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
17 #include "nncc/core/ADT/tensor/Shape.h"
19 #include <gtest/gtest.h>
21 TEST(ADT_TENSOR_SHAPE, ctor)
23 nncc::core::ADT::tensor::Shape shape;
25 ASSERT_EQ(0, shape.rank());
28 TEST(ADT_TENSOR_SHAPE, ctor_initializer_list)
30 nncc::core::ADT::tensor::Shape shape{1, 3, 5, 7};
32 ASSERT_EQ(4, shape.rank());
34 ASSERT_EQ(1, shape.dim(0));
35 ASSERT_EQ(3, shape.dim(1));
36 ASSERT_EQ(5, shape.dim(2));
37 ASSERT_EQ(7, shape.dim(3));
40 TEST(ADT_TENSOR_SHAPE, resize)
42 nncc::core::ADT::tensor::Shape shape;
46 ASSERT_EQ(4, shape.rank());
49 TEST(ADT_TENSOR_SHAPE, dim)
51 nncc::core::ADT::tensor::Shape shape;
55 uint32_t dims[4] = {3, 5, 2, 7};
57 for (uint32_t axis = 0; axis < 4; ++axis)
59 shape.dim(axis) = dims[axis];
60 ASSERT_EQ(dims[axis], shape.dim(axis));
64 TEST(ADT_TENSOR_SHAPE, copy)
66 const nncc::core::ADT::tensor::Shape original{3, 5, 2, 7};
67 const nncc::core::ADT::tensor::Shape copied{original};
69 ASSERT_EQ(copied.rank(), original.rank());
71 for (uint32_t axis = 0; axis < 4; ++axis)
73 ASSERT_EQ(copied.dim(axis), original.dim(axis));
77 TEST(ADT_TENSOR_SHAPE, num_elements_rank_0)
79 using nncc::core::ADT::tensor::Shape;
80 using nncc::core::ADT::tensor::num_elements;
84 ASSERT_EQ(1, num_elements(rank_0_shape));
87 TEST(ADT_TENSOR_SHAPE, num_elements_zero)
89 using nncc::core::ADT::tensor::Shape;
90 using nncc::core::ADT::tensor::num_elements;
92 ASSERT_EQ(0, num_elements(Shape{0, 0, 0, 0}));
95 TEST(ADT_TENSOR_SHAPE, num_elements_nonzero)
97 using nncc::core::ADT::tensor::Shape;
98 using nncc::core::ADT::tensor::num_elements;
100 ASSERT_EQ(6, num_elements(Shape{2, 3}));
103 TEST(ADT_TENSOR_SHAPE, num_elements_nulldim)
105 using nncc::core::ADT::tensor::Shape;
106 using nncc::core::ADT::tensor::num_elements;
108 ASSERT_EQ(0, num_elements(Shape{2, 0, 3}));
111 TEST(ADT_TENSOR_SHAPE, squeeze_neg)
113 using nncc::core::ADT::tensor::Shape;
114 using nncc::core::ADT::tensor::squeeze;
116 auto squeezed = squeeze(Shape{3, 5, 2});
118 ASSERT_EQ(3, squeezed.rank());
119 ASSERT_EQ(3, squeezed.dim(0));
120 ASSERT_EQ(5, squeezed.dim(1));
121 ASSERT_EQ(2, squeezed.dim(2));
124 TEST(ADT_TENSOR_SHAPE, squeeze_neg_0)
126 using nncc::core::ADT::tensor::Shape;
127 using nncc::core::ADT::tensor::squeeze;
129 auto squeezed = squeeze(Shape{3, 0, 2});
131 ASSERT_EQ(3, squeezed.rank());
132 ASSERT_EQ(3, squeezed.dim(0));
133 ASSERT_EQ(0, squeezed.dim(1));
134 ASSERT_EQ(2, squeezed.dim(2));
137 TEST(ADT_TENSOR_SHAPE, squeeze_pos)
139 using nncc::core::ADT::tensor::Shape;
140 using nncc::core::ADT::tensor::squeeze;
142 auto squeezed = squeeze(Shape{3, 1, 2});
144 ASSERT_EQ(2, squeezed.rank());
145 ASSERT_EQ(3, squeezed.dim(0));
146 ASSERT_EQ(2, squeezed.dim(1));
149 TEST(ADT_TENSOR_SHAPE, squeeze_nested)
151 using nncc::core::ADT::tensor::Shape;
152 using nncc::core::ADT::tensor::squeeze;
154 Shape shape{3, 1, 2};
156 shape.squeeze().squeeze();
158 ASSERT_EQ(2, shape.rank());
159 ASSERT_EQ(3, shape.dim(0));
160 ASSERT_EQ(2, shape.dim(1));
163 TEST(ADT_TENSOR_SHAPE, eq_negative_on_unmatched_rank)
165 const nncc::core::ADT::tensor::Shape left{1, 1, 1};
166 const nncc::core::ADT::tensor::Shape right{1, 1, 1, 1};
168 ASSERT_FALSE(left == right);
171 TEST(ADT_TENSOR_SHAPE, eq_negative_on_unmatched_dim)
173 const nncc::core::ADT::tensor::Shape left{2, 3};
174 const nncc::core::ADT::tensor::Shape right{2, 4};
176 ASSERT_FALSE(left == right);
179 TEST(ADT_TENSOR_SHAPE, eq_positive)
181 const nncc::core::ADT::tensor::Shape left{2, 3};
182 const nncc::core::ADT::tensor::Shape right{2, 3};
184 ASSERT_TRUE(left == right);