1 // Copyright (C) 2018-2019 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
5 #include <gtest/gtest.h>
6 #include <xml_net_builder.hpp>
7 #include <inference_engine/cnn_network_impl.hpp>
8 #include <inference_engine/ie_format_parser.h>
9 #include <inference_engine/ie_layer_validators.hpp>
10 #include <xml_helper.hpp>
11 #include <../shape_infer/built_in_shape_infer_general_test.hpp>
13 #include <../include/ie_data.h>
15 #include "layer_builder.h"
17 using namespace InferenceEngine;
18 using namespace InferenceEngine::details;
// Checks that a layer of the parameterized `type`, configured via
// setParams(valid_params), is accepted by the validator registered for `type`:
// neither parseParams nor checkParams may throw.
// `type` and `valid_params` are not declared in this block; presumably they are
// fixture members of CNNLayerValidationTests — confirm in layer_builder.h.
// NOTE(review): this test's closing brace is not among the visible lines;
// confirm the block is properly terminated in the full file.
20 TEST_P(CNNLayerValidationTests, checkValidParams) {
// Presumably assertThat(type) yields a layer builder for `type` — TODO confirm.
22 assertThat(type)->setParams(valid_params);
23 auto layer = getLayer();
// Look up the validator for this layer type through LayerValidators::getInstance().
24 LayerValidator::Ptr validator = LayerValidators::getInstance()->getValidator(type);
26 ASSERT_NO_THROW(validator->parseParams(layer.get()));
27 ASSERT_NO_THROW(validator->checkParams(layer.get()));
// Checks that after setParams(!valid_params) the validator for `type` rejects
// the layer: both parseParams and checkParams must throw
// InferenceEngineException.
// NOTE(review): `layer` (used inside the loop) is not declared in this block —
// only `layer_` is. Presumably `layer` is a fixture member holding the builder;
// confirm.
// NOTE(review): `layer_` is obtained before setParams is called. This assumes
// getLayer() returns a handle that observes later setParams changes — TODO confirm.
// NOTE(review): the loop index `i` is never used, so each of the
// `numberOfParams` iterations repeats exactly the same calls. If the intent was
// to invalidate one parameter at a time, the loop body does not do that.
// NOTE(review): the closing braces of the loop and of the test are not among the
// visible lines.
30 TEST_P(CNNLayerValidationTests, checkInvalidParams) {
33 int numberOfParams = getNumOfParams();
34 LayerValidator::Ptr validator = LayerValidators::getInstance()->getValidator(type);
35 auto layer_ = getLayer();
36 for (int i = 0; i < numberOfParams; ++i) {
37 layer->setParams(!valid_params);
38 ASSERT_THROW(validator->parseParams(layer_.get()), InferenceEngineException);
39 ASSERT_THROW(validator->checkParams(layer_.get()), InferenceEngineException);
// Checks that after the builder's shapes are set with `!valid_input`, the
// validator for `type` rejects the layer's input dims: checkShapes must throw
// InferenceEngineException.
// NOTE(review): `shapes` is passed to getInOutShapes and read as `shapes.inDims`,
// but its declaration is not among the visible lines — confirm it exists.
// NOTE(review): this test's closing brace is not among the visible lines.
43 TEST_P(CNNLayerValidationTests, checkInvalidInputShapes) {
44 LayerValidator::Ptr validator = LayerValidators::getInstance()->getValidator(type);
// spData is passed to setShapes while still empty; whether setShapes reads or
// fills it is not visible here — TODO confirm.
45 std::vector<DataPtr> spData;
46 assertThat(type)->setShapes(spData, !valid_input);
48 auto layer_ = getLayer();
// Presumably fills `shapes` from the layer; only `shapes.inDims` is checked below.
50 InferenceEngine::details::getInOutShapes(layer_.get(), shapes);
51 ASSERT_THROW(validator->checkShapes(layer_.get(), shapes.inDims), InferenceEngineException);
// Checks that after the builder's shapes are set with `valid_input`, the
// validator for `type` accepts the layer's input dims: checkShapes must not throw.
// NOTE(review): `shapes` is passed to getInOutShapes and read as `shapes.inDims`,
// but its declaration is not among the visible lines — confirm it exists.
// NOTE(review): this test's closing brace is not among the visible lines.
54 TEST_P(CNNLayerValidationTests, checkValidShapes) {
// spData is passed to setShapes while still empty; whether setShapes reads or
// fills it is not visible here — TODO confirm.
56 std::vector<DataPtr> spData;
57 assertThat(type)->setShapes(spData, valid_input);
58 auto layer = getLayer();
59 LayerValidator::Ptr validator = LayerValidators::getInstance()->getValidator(type);
// Presumably fills `shapes` from the layer; only `shapes.inDims` is checked below.
61 InferenceEngine::details::getInOutShapes(layer.get(), shapes);
62 ASSERT_NO_THROW(validator->checkShapes(layer.get(), shapes.inDims));
65 INSTANTIATE_TEST_CASE_P(
66 InstantiationName, CNNLayerValidationTests,