Deprecate nGraph v0 ops and builders (#1856)
[platform/upstream/dldt.git] / ngraph / test / type_prop / split.cpp
1 //*****************************************************************************
2 // Copyright 2017-2020 Intel Corporation
3 //
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
7 //
8 //     http://www.apache.org/licenses/LICENSE-2.0
9 //
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 //*****************************************************************************
16
17 #include "gtest/gtest.h"
18 #include "ngraph/ngraph.hpp"
19 #include "util/type_prop.hpp"
20
21 NGRAPH_SUPPRESS_DEPRECATED_START
22
23 using namespace std;
24 using namespace ngraph;
25
26 TEST(type_prop, split)
27 {
28     const auto data = make_shared<op::Parameter>(element::i32, Shape{2, 6});
29
30     try
31     {
32         const std::vector<size_t> splits = {1, 6}; // should sum up to 6
33         const auto axis = op::Constant::create(element::i64, Shape{}, {1});
34         const auto split = make_shared<op::Split>(data, axis, splits);
35         FAIL() << "Split node was created with incorrect data.";
36     }
37     catch (const NodeValidationFailure& error)
38     {
39         EXPECT_HAS_SUBSTRING(
40             error.what(), std::string("has to be equal to the sum of splits passed to the op: 7"));
41     }
42
43     try
44     {
45         const std::vector<size_t> splits = {4, 2};
46         const auto axis = op::Constant::create(element::i64, Shape{}, {-5});
47         const auto split = make_shared<op::Split>(data, axis, splits); // invalid axis
48         FAIL() << "Split node was created with incorrect data.";
49     }
50     catch (const ngraph_error& error)
51     {
52         EXPECT_HAS_SUBSTRING(error.what(), std::string("Parameter axis -5 out of the tensor rank"));
53     }
54
55     const auto axis = op::Constant::create(element::i64, Shape{}, {1});
56     const auto split = make_shared<op::Split>(data, axis, 2);
57     EXPECT_EQ(split->outputs().size(), 2);
58     EXPECT_EQ(split->get_output_shape(0), (Shape{2, 3}));
59     EXPECT_EQ(split->get_output_shape(1), (Shape{2, 3}));
60     EXPECT_EQ(split->get_output_element_type(0), element::i32);
61     EXPECT_EQ(split->get_output_element_type(1), element::i32);
62 }
63
64 TEST(type_prop, split_axis_must_be_scalar)
65 {
66     const auto data = make_shared<op::Parameter>(element::i32, Shape{2, 6});
67     const std::vector<size_t> splits = {1, 6};
68     const auto axis = op::Constant::create(element::i64, Shape{2}, {0, 1});
69
70     try
71     {
72         const auto split = make_shared<op::Split>(data, axis, splits);
73         FAIL() << "Incorrect axis of Split not detected.";
74     }
75     catch (const NodeValidationFailure& error)
76     {
77         EXPECT_HAS_SUBSTRING(error.what(), std::string("The 'axis' input node must be scalar"));
78     }
79     catch (...)
80     {
81         FAIL() << "Deduced type check failed for unexpected reason.";
82     }
83 }
84
85 TEST(type_prop, split_axis_must_be_constant)
86 {
87     const auto data = make_shared<op::Parameter>(element::i32, Shape{2, 6});
88     const std::vector<size_t> splits = {1, 6};
89     const auto axis = make_shared<op::Parameter>(element::i32, Shape{});
90
91     try
92     {
93         const auto split = make_shared<op::Split>(data, axis, splits);
94         FAIL() << "Not constant axis of Split not detected.";
95     }
96     catch (const NodeValidationFailure& error)
97     {
98         EXPECT_HAS_SUBSTRING(error.what(), std::string("The 'axis' input node must be constant"));
99     }
100     catch (...)
101     {
102         FAIL() << "Deduced type check failed for unexpected reason.";
103     }
104 }