1 // Copyright (C) 2018 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
5 #include <gtest/gtest.h>
6 #include <extension/ext_list.hpp>
7 #include <xml_net_builder.hpp>
8 #include <inference_engine/cnn_network_impl.hpp>
9 #include <inference_engine/shape_infer/ie_reshaper.hpp>
10 #include <cpp/ie_cnn_net_reader.h>
11 #include <test_model_path.hpp>
12 #include <inference_engine/debug.h>
13 #include <ie_extension.h>
14 #include <tests_common.hpp>
15 #include "built_in_shape_infer_general_test.hpp"
17 using namespace InferenceEngine;
18 using namespace InferenceEngine::details;
19 using namespace ShapeInfer;
// Test fixture for shape inference through the CPU extension library.
// Extends BuiltInShapeInferImplTest and additionally owns a ShapeInferExtension
// that is exposed to the tests through `holder`.
21 class CPUExtShapeInferTests : public BuiltInShapeInferImplTest {
// Extension instance owned by the fixture; initialised in the constructor.
23 InferenceEngine::ShapeInferExtension shapeInferExt;
// Constructs shapeInferExt from the library name produced by
// TestsCommon::make_so_name("cpu_extension").
24 CPUExtShapeInferTests () : shapeInferExt(TestsCommon::make_so_name("cpu_extension")) {}
// Runs the base fixture's SetUp() first, then points `holder` at shapeInferExt.
26 void SetUp() override {
27 BuiltInShapeInferImplTest::SetUp();
// The deleter is a no-op lambda: shapeInferExt is a member of this fixture, so the
// shared_ptr only refers to it and must not delete it.
28 holder = std::shared_ptr<IShapeInferExtension>(&shapeInferExt, [](IShapeInferExtension*){});
// Parameterized test: fetches the shape-infer implementation for `type` and calls
// inferShapes() on it directly with the new input dimensions.
32 TEST_P(CPUExtShapeInferTests, impl) {
33 auto impl = getShapeInferImpl(type);
// Abort the test if no implementation was returned for this layer type.
34 ASSERT_NE(nullptr, impl);
// inferShapes() itself must not throw; its status code is stored in sts, and
// resp.msg is appended to the failure output of the status checks below.
35 ASSERT_NO_THROW(sts = impl->inferShapes(newInOutShapes.inDims, layerParams.data, blobs, outShapes, &resp));
// Success expectation: status OK and inferred output shapes equal to
// newInOutShapes.outDims.
38 ASSERT_EQ(int(OK), sts) << resp.msg;
39 ASSERT_EQ(newInOutShapes.outDims, outShapes);
// Failure expectation: status GENERAL_ERROR.
// NOTE(review): as written, this check and the OK check above apply to the same
// sts; presumably they belong to alternative branches — confirm.
41 ASSERT_EQ(GENERAL_ERROR, sts) << resp.msg;
// Parameterized test: builds a single-layer network for `type`, registers the
// extension on a Reshaper and runs it with the new input shapes.
45 TEST_P(CPUExtShapeInferTests, reshaper) {
// Network is built from the layer type, the original in/out shapes, the layer
// parameters and the layer data name.
46 auto cnnNetworkImplPtr = buildSingleLayerNetwork<3>(type, inOutShapes, &layerParams.data, layerDataName);
47 auto reshaper = std::make_shared<Reshaper>(*cnnNetworkImplPtr);
// Input shapes passed to run() below, derived from the network and
// newInOutShapes.inDims.
48 auto inputShapes = setInputShapes(*cnnNetworkImplPtr.get(), newInOutShapes.inDims);
// Register the extension held in `holder` with the reshaper before running it.
49 reshaper->AddExtension(holder);
// Success expectation: run() completes and the network's inputs/outputs are
// checked against newInOutShapes.
52 reshaper->run(inputShapes);
53 checkNetworkInOut(*cnnNetworkImplPtr, newInOutShapes);
// Failure expectation: run() throws InferenceEngineException.
// NOTE(review): as written, run() is invoked both plainly above and inside
// ASSERT_THROW here; presumably these are alternative branches — confirm.
55 ASSERT_THROW(reshaper->run(inputShapes), InferenceEngine::details::InferenceEngineException);
59 INSTANTIATE_TEST_CASE_P(
60 CPUExtGeneralImpls, CPUExtShapeInferTests,
62 ::testing::make_tuple(LayerType("SpatialTransformer"),
63 InOutShapes({{{1, 6, 5, 5}, {1, 3}},
65 NewInOutShapes({{{2, 6, 5, 6}, {1, 3}},
67 MapParams(MapStrStr()),
68 LayerDataName("data"),