1 // Copyright (C) 2020 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
#include "subgraph_tests/split_trivial_permute_concat.hpp"

#include <cstdint>
#include <map>
#include <memory>
#include <sstream>
#include <string>
#include <tuple>
#include <vector>

#include "common_test_utils/common_utils.hpp"
#include "functional_test_utils/precision_utils.hpp"
#include "functional_test_utils/skip_tests_config.hpp"
#include "ngraph_functions/utils/ngraph_helpers.hpp"
15 namespace LayerTestsDefinitions {
16 std::string SplitTrivialPermuteConcatTest::getTestCaseName(const testing::TestParamInfo<SplitTrivialPermuteConcatTuple>& obj) {
17 InferenceEngine::Precision netPrecision;
18 std::string targetName;
19 std::vector<size_t> inputShape;
22 std::tie(netPrecision, targetName, inputShape, splitAxis, concatAxis, std::ignore) = obj.param;
23 std::ostringstream results;
25 results << "netPRC=" << netPrecision.name() << "_";
27 for (size_t size : inputShape)
28 results << size << "_";
29 results << "SA=" << splitAxis << "_";
30 results << "CA=" << concatAxis << "_";
31 results << "targetDevice=" << targetName;
35 void SplitTrivialPermuteConcatTest::SetUp() {
36 InferenceEngine::Precision netPrecision;
37 std::vector<size_t> inputShape;
40 std::map<std::string, std::string> config;
41 std::tie(netPrecision, targetDevice, inputShape, splitAxis, concatAxis, config) = this->GetParam();
42 configuration.insert(config.begin(), config.end());
43 auto ngPrc = FuncTestUtils::PrecisionUtils::convertIE2nGraphPrc(netPrecision);
44 auto input = ngraph::builder::makeParams(ngPrc, { inputShape });
45 auto split = ngraph::builder::makeSplit(input[0], ngPrc, 2, splitAxis);
47 auto permute_in_params = std::make_shared<ngraph::opset1::Constant>(ngraph::element::i64,
49 ngraph::Shape{ {0, 3, 2, 1} });
50 auto permute_0 = std::make_shared<ngraph::opset1::Transpose>(split->output(0), permute_in_params);
51 auto permute_1 = std::make_shared<ngraph::opset1::Transpose>(split->output(1), permute_in_params);
53 auto concat = std::make_shared<ngraph::opset1::Concat>(ngraph::OutputVector{ split->output(0), split->output(1) }, concatAxis);
54 auto act = ngraph::builder::makeActivation(concat, ngPrc, ngraph::helpers::ActivationTypes::Relu);
55 function = std::make_shared<ngraph::Function>(act, input, "split_trivial_permute_concat");
58 TEST_P(SplitTrivialPermuteConcatTest, CompareWithRefs) {
61 } // namespace LayerTestsDefinitions