Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / tests / unit / shape_infer / built_in_holder_test.cpp
1 // Copyright (C) 2018-2019 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
3 //
4
#include <algorithm>
#include <iterator>
#include <list>
#include <map>
#include <memory>
#include <string>
#include <vector>

#include <gtest/gtest.h>

#include <inference_engine/shape_infer/built-in/ie_built_in_holder.hpp>
#include <shape_infer/mock_ishape_infer_impl.hpp>
#include <inference_engine/shape_infer/built-in/ie_equal_shape_infer.hpp>
10
11 using namespace InferenceEngine;
12 using namespace ShapeInfer;
13
14 class ShapeInferHolderTest : public ::testing::Test {
15 protected:
16     StatusCode sts = GENERAL_ERROR;
17     ResponseDesc resp;
18     std::vector<InferenceEngine::SizeVector> outShapes;
19     std::map<std::string, std::string> params;
20     std::map<std::string, Blob::Ptr> blobs;
21
22     std::list<std::string> _expectedTypes = {
23             "Power",
24             "Convolution",
25             "Deconvolution",
26             "Pooling",
27             "LRN",
28             "Norm",
29             "SoftMax",
30             "ReLU",
31             "Clamp",
32             "Split",
33             "Slice",
34             "Concat",
35             "Eltwise",
36             "ScaleShift",
37             "PReLU",
38             "Crop",
39             "Reshape",
40             "Tile",
41             "BatchNormalization",
42             "Input",
43             "Memory",
44             "Const",
45             "Gemm"
46     };
47
48     void TearDown() override {
49     }
50
51     void SetUp() override {
52     }
53
54 public:
55
56 };
57
58 TEST_F(ShapeInferHolderTest, canCreateHolder) {
59     ASSERT_NO_THROW(BuiltInShapeInferHolder());
60 }
61
62 TEST_F(ShapeInferHolderTest, DISABLED_allRegistered) {
63     auto holder = std::make_shared<BuiltInShapeInferHolder>();
64     char** types = nullptr;
65     unsigned int size = 0;
66     ASSERT_NO_THROW(sts = holder->getShapeInferTypes(types, size, &resp));
67     std::list<std::string> actualTypes;
68     for (int i = 0; i < size; i++) {
69         actualTypes.emplace_back(types[i], strlen(types[i]));
70     }
71
72     _expectedTypes.sort();
73     actualTypes.sort();
74
75     std::vector<std::string> different_words;
76     std::set_difference(actualTypes.begin(), actualTypes.end(),
77                         _expectedTypes.begin(), _expectedTypes.end(),
78                         std::back_inserter(different_words));
79     // TODO: update expectedTypes!
80     ASSERT_EQ(19, different_words.size());
81 }
82
83
84 TEST_F(ShapeInferHolderTest, returnNullForNotKnown) {
85     IShapeInferImpl::Ptr impl;
86
87     sts = BuiltInShapeInferHolder().getShapeInferImpl(impl, "NOT_KNOWN_TYPE", &resp);
88     ASSERT_FALSE(impl) << resp.msg;
89     ASSERT_EQ(NOT_FOUND, sts);
90 }
91
92 class ShapeInferNotSupportedTest
93         : public ShapeInferHolderTest, public testing::WithParamInterface<std::string> {
94 };
95
96 TEST_P(ShapeInferNotSupportedTest, returnNotFoundOnNotSupported) {
97     std::string type = GetParam();
98     IShapeInferImpl::Ptr impl;
99
100     sts = BuiltInShapeInferHolder().getShapeInferImpl(impl, type.c_str(), &resp);
101     ASSERT_FALSE(impl) << resp.msg;
102     ASSERT_EQ(NOT_FOUND, sts) << resp.msg;
103 }
104
105 // TODO: list all not supported later
106 INSTANTIATE_TEST_CASE_P(
107         NotSupported, ShapeInferNotSupportedTest, ::testing::Values("NOT_SUPPORTED"));