1 // Copyright (C) 2018-2019 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
5 #include <gtest/gtest.h>
7 #include <inference_engine/shape_infer/built-in/ie_built_in_holder.hpp>
8 #include <shape_infer/mock_ishape_infer_impl.hpp>
9 #include <inference_engine/shape_infer/built-in/ie_equal_shape_infer.hpp>
11 using namespace InferenceEngine;
12 using namespace ShapeInfer;
// Fixture shared by the BuiltInShapeInferHolder tests below: it carries the status
// code of the most recent holder call and the list of type names that
// DISABLED_allRegistered compares against the holder's reported types.
14 class ShapeInferHolderTest : public ::testing::Test {
// Status of the most recent holder call made by a test; starts as GENERAL_ERROR.
16 StatusCode sts = GENERAL_ERROR;
// NOTE(review): outShapes, params and blobs are not referenced by any test visible
// in this excerpt — confirm they are still needed.
18 std::vector<InferenceEngine::SizeVector> outShapes;
19 std::map<std::string, std::string> params;
20 std::map<std::string, Blob::Ptr> blobs;
// Type names that DISABLED_allRegistered compares with holder->getShapeInferTypes();
// that test sorts this list in place before comparing.
22 std::list<std::string> _expectedTypes = {
48 void TearDown() override {
51 void SetUp() override {
// Default-constructing a BuiltInShapeInferHolder must not throw.
58 TEST_F(ShapeInferHolderTest, canCreateHolder) {
59 ASSERT_NO_THROW(BuiltInShapeInferHolder());
// Compares the type names reported by BuiltInShapeInferHolder::getShapeInferTypes()
// with _expectedTypes. The DISABLED_ prefix makes googletest skip this test unless
// disabled tests are explicitly requested.
62 TEST_F(ShapeInferHolderTest, DISABLED_allRegistered) {
63 auto holder = std::make_shared<BuiltInShapeInferHolder>();
64 char** types = nullptr;
65 unsigned int size = 0;
// NOTE(review): the status stored in `sts` is not asserted anywhere in this excerpt,
// and who releases `types` (and the strings it points to) is not visible here —
// confirm the call's result is checked and the buffer is not leaked.
66 ASSERT_NO_THROW(sts = holder->getShapeInferTypes(types, size, &resp));
67 std::list<std::string> actualTypes;
// Copy each reported C string into actualTypes.
// NOTE(review): `i` is a signed int compared against the unsigned `size`
// (-Wsign-compare); an unsigned index would match the type of `size`.
68 for (int i = 0; i < size; i++) {
69 actualTypes.emplace_back(types[i], strlen(types[i]));
72 _expectedTypes.sort();
// std::set_difference requires BOTH input ranges to be sorted. Only _expectedTypes
// is sorted in the visible lines — confirm actualTypes is sorted as well.
// The result holds the names the holder reports that are missing from _expectedTypes.
75 std::vector<std::string> different_words;
76 std::set_difference(actualTypes.begin(), actualTypes.end(),
77 _expectedTypes.begin(), _expectedTypes.end(),
78 std::back_inserter(different_words));
79 // TODO: add the missing names to _expectedTypes; 19 is the number of reported types currently absent from it.
// NOTE(review): compares an int literal with a size_t; 19u would avoid a sign-compare warning.
80 ASSERT_EQ(19, different_words.size());
// Asking for the implementation of an unknown type name must leave `impl` empty
// and return NOT_FOUND.
84 TEST_F(ShapeInferHolderTest, returnNullForNotKnown) {
85 IShapeInferImpl::Ptr impl;
87 sts = BuiltInShapeInferHolder().getShapeInferImpl(impl, "NOT_KNOWN_TYPE", &resp);
88 ASSERT_FALSE(impl) << resp.msg;
// NOTE(review): unlike returnNotFoundOnNotSupported, this status assertion does not
// stream resp.msg, so a failure here prints no holder diagnostic.
89 ASSERT_EQ(NOT_FOUND, sts);
// Value-parameterized variant of ShapeInferHolderTest: the std::string parameter
// is a type name that returnNotFoundOnNotSupported passes to getShapeInferImpl().
92 class ShapeInferNotSupportedTest
93 : public ShapeInferHolderTest, public testing::WithParamInterface<std::string> {
// For each parameterized type name, getShapeInferImpl() must leave `impl` empty
// and return NOT_FOUND; resp.msg is printed on failure.
96 TEST_P(ShapeInferNotSupportedTest, returnNotFoundOnNotSupported) {
97 std::string type = GetParam();
98 IShapeInferImpl::Ptr impl;
100 sts = BuiltInShapeInferHolder().getShapeInferImpl(impl, type.c_str(), &resp);
101 ASSERT_FALSE(impl) << resp.msg;
102 ASSERT_EQ(NOT_FOUND, sts) << resp.msg;
105 // TODO: list every type the built-in holder does not support; for now only the name "NOT_SUPPORTED" is checked.
// Runs returnNotFoundOnNotSupported once per value given to ::testing::Values.
106 INSTANTIATE_TEST_CASE_P(
107 NotSupported, ShapeInferNotSupportedTest, ::testing::Values("NOT_SUPPORTED"));