1 //*****************************************************************************
2 // Copyright 2017-2020 Intel Corporation
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
8 // http://www.apache.org/licenses/LICENSE-2.0
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 //*****************************************************************************
17 #include "gtest/gtest.h"
19 #include "ngraph/builder/autobroadcast.hpp"
20 #include "ngraph/ngraph.hpp"
22 NGRAPH_SUPPRESS_DEPRECATED_START
25 using namespace ngraph;
27 shared_ptr<op::Parameter> getParamFromShape(const Shape& shape)
29 return make_shared<op::Parameter>(element::f32, shape);
32 inline const Shape& getShapeFromParam(const shared_ptr<Node>& node)
34 return node->get_shape();
37 // input shapes are equal so AutoBroadcast does nothing
38 TEST(autobroadcast, no_broadcast_equal)
40 Shape s2345{2, 3, 4, 5};
41 auto lhs = getParamFromShape(s2345);
42 auto rhs = getParamFromShape(s2345);
44 auto shaped = builder::numpy_broadcast({lhs, rhs});
45 const shared_ptr<Node>& ab_lhs = shaped.first;
46 const shared_ptr<Node>& ab_rhs = shaped.second;
48 EXPECT_EQ(ab_lhs, lhs); // no change
49 EXPECT_EQ(getShapeFromParam(ab_lhs), s2345);
51 EXPECT_EQ(ab_rhs, rhs); // no change
52 EXPECT_EQ(getShapeFromParam(ab_rhs), s2345);
55 // input shapes are incompatable
56 TEST(autobroadcast, no_broadcast_incompatable)
58 Shape s2345{2, 3, 4, 5};
59 Shape s6789{6, 7, 8, 9};
60 auto lhs = getParamFromShape(s2345);
61 auto rhs = getParamFromShape(s6789);
63 EXPECT_THROW(builder::numpy_broadcast({lhs, rhs}),
64 builder::numpy_autobroadcast_incompatible_shapes);
67 // basic broadcast test
69 // lhs broadcast to 2,3
70 TEST(autobroadcast, normal_broadcast_2d)
74 auto lhs = getParamFromShape(s3);
75 auto rhs = getParamFromShape(s23);
77 auto shaped = builder::numpy_broadcast({lhs, rhs});
78 const shared_ptr<Node>& ab_lhs = shaped.first;
79 const shared_ptr<Node>& ab_rhs = shaped.second;
81 EXPECT_NE(ab_lhs, lhs);
82 EXPECT_EQ(getShapeFromParam(ab_lhs), s23);
84 EXPECT_EQ(ab_rhs, rhs); // no change
85 EXPECT_EQ(getShapeFromParam(ab_rhs), s23);
88 // basic broadcast test
90 // lhs broadcast to 2,3,4
91 TEST(autobroadcast, normal_broadcast_3d)
95 auto lhs = getParamFromShape(s34);
96 auto rhs = getParamFromShape(s234);
98 auto shaped = builder::numpy_broadcast({lhs, rhs});
99 const shared_ptr<Node>& ab_lhs = shaped.first;
100 const shared_ptr<Node>& ab_rhs = shaped.second;
102 EXPECT_NE(ab_lhs, lhs);
103 EXPECT_EQ(getShapeFromParam(ab_lhs), s234);
105 EXPECT_EQ(ab_rhs, rhs); // no change
106 EXPECT_EQ(getShapeFromParam(ab_rhs), s234);
109 // basic broadcast test
111 // lhs broadcast to 2,3,4,5
112 TEST(autobroadcast, normal_broadcast_4d)
115 Shape s2345{2, 3, 4, 5};
116 auto lhs = getParamFromShape(s345);
117 auto rhs = getParamFromShape(s2345);
119 auto shaped = builder::numpy_broadcast({lhs, rhs});
120 const shared_ptr<Node>& ab_lhs = shaped.first;
121 const shared_ptr<Node>& ab_rhs = shaped.second;
123 EXPECT_NE(ab_lhs, lhs);
124 EXPECT_EQ(getShapeFromParam(ab_lhs), s2345);
126 EXPECT_EQ(ab_rhs, rhs); // no change
127 EXPECT_EQ(getShapeFromParam(ab_rhs), s2345);
130 // basic reshape and broadcast test
131 // rhs reshape to 2,3,4 then
132 // rhs broadcast to 2,3,4,5
133 TEST(autobroadcast, reshape_1x_broadcast)
135 Shape s2345{2, 3, 4, 5};
136 Shape s2341{2, 3, 4, 1};
137 auto lhs = getParamFromShape(s2345);
138 auto rhs = getParamFromShape(s2341);
140 auto shaped = builder::numpy_broadcast({lhs, rhs});
141 const shared_ptr<Node>& ab_lhs = shaped.first;
142 const shared_ptr<Node>& ab_rhs = shaped.second;
144 EXPECT_EQ(ab_lhs, lhs); // no change
145 EXPECT_EQ(getShapeFromParam(ab_lhs), s2345);
147 EXPECT_NE(ab_rhs, rhs);
148 EXPECT_EQ(getShapeFromParam(ab_rhs), s2345);
151 // same as above, but additionally
152 // lhs reshape to 2,4,5 then
153 // lhs broadcast to 2,3,4,5
154 TEST(autobroadcast, reshape_2x_broadcast)
156 Shape s2145{2, 1, 4, 5};
157 Shape s2341{2, 3, 4, 1};
158 auto lhs = getParamFromShape(s2145);
159 auto rhs = getParamFromShape(s2341);
161 auto shaped = builder::numpy_broadcast({lhs, rhs});
162 const shared_ptr<Node>& ab_lhs = shaped.first;
163 const shared_ptr<Node>& ab_rhs = shaped.second;
165 Shape s2345{2, 3, 4, 5};
167 EXPECT_NE(ab_lhs, lhs);
168 EXPECT_EQ(getShapeFromParam(ab_lhs), s2345);
170 EXPECT_NE(ab_rhs, rhs);
171 EXPECT_EQ(getShapeFromParam(ab_rhs), s2345);
174 // matching singular dimension on axis 2
175 // should not require reshape of either lhs or rhs
176 // i.e. this should be the same as normal broadcast casse
177 // rhs broadcast to 2,3,1,5
178 TEST(autobroadcast, broadcast_with_dim1)
180 Shape s2315{2, 3, 1, 5};
182 auto lhs = getParamFromShape(s2315);
183 auto rhs = getParamFromShape(s315);
185 auto shaped = builder::numpy_broadcast({lhs, rhs});
186 const shared_ptr<Node>& ab_lhs = shaped.first;
187 const shared_ptr<Node>& ab_rhs = shaped.second;
189 EXPECT_EQ(ab_lhs, lhs); // no change
190 EXPECT_EQ(getShapeFromParam(ab_lhs), s2315);
192 EXPECT_NE(ab_rhs, rhs);
193 EXPECT_EQ(getShapeFromParam(ab_rhs), s2315);
197 // rhs reshape to 1,3,4,5 with no broadcast
198 TEST(autobroadcast, broadcast_with_leading_dim1)
200 Shape s1345{1, 3, 4, 5};
202 auto lhs = getParamFromShape(s1345);
203 auto rhs = getParamFromShape(s345);
205 auto shaped = builder::numpy_broadcast({lhs, rhs});
206 const shared_ptr<Node>& ab_lhs = shaped.first;
207 const shared_ptr<Node>& ab_rhs = shaped.second;
209 EXPECT_EQ(ab_lhs, lhs); // no change
210 EXPECT_EQ(getShapeFromParam(ab_lhs), s1345);
212 EXPECT_NE(ab_rhs, rhs);
213 EXPECT_EQ(getShapeFromParam(ab_rhs), s1345);
216 TEST(autobroadcast, make_node_2_args)
220 auto lhs = getParamFromShape(s21);
221 auto rhs = getParamFromShape(s23);
223 shared_ptr<Node> op = builder::make_with_numpy_broadcast<op::Add>(lhs, rhs);
224 EXPECT_NE(op, nullptr);
227 TEST(autobroadcast, make_node_3_args)
232 auto predicates = make_shared<op::Parameter>(element::boolean, s23);
233 auto lhs = getParamFromShape(s21);
234 auto rhs = getParamFromShape(s23);
236 shared_ptr<Node> op = builder::make_with_numpy_broadcast<op::Select>(predicates, lhs, rhs);
237 EXPECT_NE(op, nullptr);
240 TEST(autobroadcast, numpy_broadcast_for_matmul_op_2d)
242 const Shape lhs{3, 1, 4, 6};
243 const Shape rhs{6, 5};
244 const auto lhs_node = make_shared<op::Parameter>(element::f32, lhs);
245 const auto rhs_node = make_shared<op::Parameter>(element::f32, rhs);
247 const OutputVector result = builder::numpy_broadcast_for_matmul_operation(lhs_node, rhs_node);
249 EXPECT_EQ(result.at(0).get_shape(), (Shape{3, 1, 4, 6}));
250 EXPECT_EQ(result.at(1).get_shape(), (Shape{3, 1, 6, 5}));
253 TEST(autobroadcast, numpy_broadcast_for_matmul_op_3d)
255 const Shape lhs{3, 1, 4, 6};
256 const Shape rhs{2, 6, 5};
257 const auto lhs_node = make_shared<op::Parameter>(element::f32, lhs);
258 const auto rhs_node = make_shared<op::Parameter>(element::f32, rhs);
260 const OutputVector result = builder::numpy_broadcast_for_matmul_operation(lhs_node, rhs_node);
262 EXPECT_EQ(result.at(0).get_shape(), (Shape{3, 2, 4, 6}));
263 EXPECT_EQ(result.at(1).get_shape(), (Shape{3, 2, 6, 5}));
266 TEST(autobroadcast, numpy_broadcast_for_matmul_op_nop)
268 const Shape lhs{4, 6};
269 const Shape rhs{6, 5};
270 const auto lhs_node = make_shared<op::Parameter>(element::f32, lhs);
271 const auto rhs_node = make_shared<op::Parameter>(element::f32, rhs);
273 const OutputVector result = builder::numpy_broadcast_for_matmul_operation(lhs_node, rhs_node);
275 EXPECT_EQ(result.at(0).get_shape(), (Shape{4, 6}));
276 EXPECT_EQ(result.at(1).get_shape(), (Shape{6, 5}));
279 TEST(autobroadcast, opset1_legacy_broadcast_scalar)
281 const Shape lhs{2, 3, 4, 5};
283 size_t start_match_axis{3};
284 const auto lhs_node = make_shared<op::Parameter>(element::f32, lhs);
285 const auto rhs_node = make_shared<op::Parameter>(element::f32, rhs);
287 const Output<Node> result = builder::opset1::legacy_broadcast_for_binary_operation(
288 lhs_node, rhs_node, start_match_axis);
290 EXPECT_EQ(result.get_shape(), lhs);
293 TEST(autobroadcast, opset1_legacy_broadcast_1elem_tensor)
295 const Shape lhs{2, 3, 4, 5};
296 const Shape rhs{1, 1, 1};
297 size_t start_match_axis{1};
298 const auto lhs_node = make_shared<op::Parameter>(element::f32, lhs);
299 const auto rhs_node = make_shared<op::Parameter>(element::f32, rhs);
301 const Output<Node> result = builder::opset1::legacy_broadcast_for_binary_operation(
302 lhs_node, rhs_node, start_match_axis);
304 EXPECT_EQ(result.get_shape(), lhs);
307 TEST(autobroadcast, opset1_legacy_broadcast_1d)
309 const Shape lhs{2, 3, 4, 5};
311 size_t start_match_axis{3};
312 const auto lhs_node = make_shared<op::Parameter>(element::f32, lhs);
313 const auto rhs_node = make_shared<op::Parameter>(element::f32, rhs);
315 const Output<Node> result = builder::opset1::legacy_broadcast_for_binary_operation(
316 lhs_node, rhs_node, start_match_axis);
318 EXPECT_EQ(result.get_shape(), lhs);
321 TEST(autobroadcast, opset1_legacy_broadcast_2d)
323 const Shape lhs{2, 3, 4, 5};
324 const Shape rhs{4, 5};
325 size_t start_match_axis{2};
326 const auto lhs_node = make_shared<op::Parameter>(element::f32, lhs);
327 const auto rhs_node = make_shared<op::Parameter>(element::f32, rhs);
329 const Output<Node> result = builder::opset1::legacy_broadcast_for_binary_operation(
330 lhs_node, rhs_node, start_match_axis);
332 EXPECT_EQ(result.get_shape(), lhs);
335 TEST(autobroadcast, opset1_legacy_broadcast_2d_inside)
337 const Shape lhs{2, 3, 4, 5};
338 const Shape rhs{3, 4};
339 size_t start_match_axis{1};
340 const auto lhs_node = make_shared<op::Parameter>(element::f32, lhs);
341 const auto rhs_node = make_shared<op::Parameter>(element::f32, rhs);
343 const Output<Node> result = builder::opset1::legacy_broadcast_for_binary_operation(
344 lhs_node, rhs_node, start_match_axis);
346 EXPECT_EQ(result.get_shape(), lhs);
349 TEST(autobroadcast, opset1_legacy_broadcast_1d_left)
351 const Shape lhs{2, 3, 4, 5};
353 size_t start_match_axis{0};
354 const auto lhs_node = make_shared<op::Parameter>(element::f32, lhs);
355 const auto rhs_node = make_shared<op::Parameter>(element::f32, rhs);
357 const Output<Node> result = builder::opset1::legacy_broadcast_for_binary_operation(
358 lhs_node, rhs_node, start_match_axis);
360 EXPECT_EQ(result.get_shape(), lhs);
363 TEST(autobroadcast, opset1_legacy_broadcast_identical)
365 const Shape lhs{2, 3, 4, 5};
366 size_t start_match_axis{0};
367 const auto lhs_node = make_shared<op::Parameter>(element::f32, lhs);
368 const auto rhs_node = make_shared<op::Parameter>(element::f32, lhs);
370 const Output<Node> result = builder::opset1::legacy_broadcast_for_binary_operation(
371 lhs_node, rhs_node, start_match_axis);
373 EXPECT_EQ(result.get_shape(), lhs);
376 TEST(autobroadcast, axes_mapping_from_bcast_axes)
378 const Shape output_shape{2, 3, 4, 5};
379 const Shape input_shape{3, 5};
380 const AxisSet broadcast_axes{0, 2};
382 auto axes_mapping = builder::opset1::get_axes_mapping_output(output_shape, broadcast_axes);
383 EXPECT_TRUE(op::is_constant(axes_mapping.get_node()));
384 Shape axes_mapping_shape = as_type<op::v0::Constant>(axes_mapping.get_node())->get_shape_val();
385 EXPECT_EQ(axes_mapping_shape.size(), 2);
386 EXPECT_EQ(axes_mapping_shape, (Shape{1, 3}));
389 TEST(autobroadcast, axes_mapping_from_bcast_axes_scalar)
391 const Shape output_shape{2, 3, 4, 5};
392 const Shape input_shape{};
393 const AxisSet broadcast_axes{0, 1, 2, 3};
395 auto axes_mapping = builder::opset1::get_axes_mapping_output(output_shape, broadcast_axes);
396 EXPECT_TRUE(op::is_constant(axes_mapping.get_node()));
397 Shape axes_mapping_shape = as_type<op::v0::Constant>(axes_mapping.get_node())->get_shape_val();
398 EXPECT_EQ(axes_mapping_shape.size(), 0);
399 EXPECT_EQ(axes_mapping_shape, (Shape{}));
402 TEST(autobroadcast, axes_mapping_from_bcast_axes_identical)
404 const Shape output_shape{2, 3, 4, 5};
405 const Shape input_shape(output_shape);
406 const AxisSet broadcast_axes{};
408 auto axes_mapping = builder::opset1::get_axes_mapping_output(output_shape, broadcast_axes);
409 EXPECT_TRUE(op::is_constant(axes_mapping.get_node()));
410 Shape axes_mapping_shape = as_type<op::v0::Constant>(axes_mapping.get_node())->get_shape_val();
411 EXPECT_EQ(axes_mapping_shape.size(), output_shape.size());
412 EXPECT_EQ(axes_mapping_shape, (Shape{0, 1, 2, 3}));
415 TEST(autobroadcast, axes_mapping_start_match_axis)
417 const Shape output_shape{2, 3, 4, 5};
418 const Shape input_shape{3, 4};
419 const std::size_t start_match_axis{1};
422 builder::opset1::get_axes_mapping_output(output_shape, input_shape, start_match_axis);
423 EXPECT_TRUE(op::is_constant(axes_mapping.get_node()));
424 Shape axes_mapping_shape = as_type<op::v0::Constant>(axes_mapping.get_node())->get_shape_val();
425 EXPECT_EQ(axes_mapping_shape.size(), 2);
426 EXPECT_EQ(axes_mapping_shape, (Shape{1, 2}));
429 TEST(autobroadcast, axes_mapping_start_match_axis_scalar)
431 const Shape output_shape{2, 3, 4, 5};
432 const Shape input_shape{};
433 const std::size_t start_match_axis{4};
436 builder::opset1::get_axes_mapping_output(output_shape, input_shape, start_match_axis);
437 EXPECT_TRUE(op::is_constant(axes_mapping.get_node()));
438 Shape axes_mapping_shape = as_type<op::v0::Constant>(axes_mapping.get_node())->get_shape_val();
439 EXPECT_EQ(axes_mapping_shape.size(), 0);
440 EXPECT_EQ(axes_mapping_shape, (Shape{}));
443 TEST(autobroadcast, axes_mapping_start_match_axis_identical)
445 const Shape output_shape{2, 3, 4, 5};
446 const Shape input_shape{2, 3, 4, 5};
447 const std::size_t start_match_axis{0};
450 builder::opset1::get_axes_mapping_output(output_shape, input_shape, start_match_axis);
451 EXPECT_TRUE(op::is_constant(axes_mapping.get_node()));
452 Shape axes_mapping_shape = as_type<op::v0::Constant>(axes_mapping.get_node())->get_shape_val();
453 EXPECT_EQ(axes_mapping_shape.size(), output_shape.size());
454 EXPECT_EQ(axes_mapping_shape, (Shape{0, 1, 2, 3}));