Add a section of how to link IE with CMake project (#99)
[platform/upstream/dldt.git] / inference-engine / tests / unit / shape_infer / cpu_ext_shape_infer_general_test.cpp
1 // Copyright (C) 2018 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
3 //
4
5 #include <gtest/gtest.h>
6 #include <extension/ext_list.hpp>
7 #include <xml_net_builder.hpp>
8 #include <inference_engine/cnn_network_impl.hpp>
9 #include <inference_engine/shape_infer/ie_reshaper.hpp>
10 #include <cpp/ie_cnn_net_reader.h>
11 #include <test_model_path.hpp>
12 #include <inference_engine/debug.h>
13 #include <ie_extension.h>
14 #include <tests_common.hpp>
15 #include "built_in_shape_infer_general_test.hpp"
16
17 using namespace InferenceEngine;
18 using namespace InferenceEngine::details;
19 using namespace ShapeInfer;
20
21 class CPUExtShapeInferTests : public BuiltInShapeInferImplTest {
22 protected:
23     InferenceEngine::ShapeInferExtension shapeInferExt;
24     CPUExtShapeInferTests () : shapeInferExt(TestsCommon::make_so_name("cpu_extension")) {}
25
26     void SetUp() override {
27         BuiltInShapeInferImplTest::SetUp();
28         holder = std::shared_ptr<IShapeInferExtension>(&shapeInferExt, [](IShapeInferExtension*){});
29     }
30 };
31
32 TEST_P(CPUExtShapeInferTests, impl) {
33     auto impl = getShapeInferImpl(type);
34     ASSERT_NE(nullptr, impl);
35     ASSERT_NO_THROW(sts = impl->inferShapes(newInOutShapes.inDims, layerParams.data, blobs, outShapes, &resp));
36
37     if (canInfer) {
38         ASSERT_EQ(int(OK), sts) << resp.msg;
39         ASSERT_EQ(newInOutShapes.outDims, outShapes);
40     } else {
41         ASSERT_EQ(GENERAL_ERROR, sts) << resp.msg;
42     }
43 }
44
45 TEST_P(CPUExtShapeInferTests, reshaper) {
46     auto cnnNetworkImplPtr = buildSingleLayerNetwork<3>(type, inOutShapes, &layerParams.data, layerDataName);
47     auto reshaper = std::make_shared<Reshaper>(*cnnNetworkImplPtr);
48     auto inputShapes = setInputShapes(*cnnNetworkImplPtr.get(), newInOutShapes.inDims);
49     reshaper->AddExtension(holder);
50
51     if (canInfer) {
52         reshaper->run(inputShapes);
53         checkNetworkInOut(*cnnNetworkImplPtr, newInOutShapes);
54     } else {
55         ASSERT_THROW(reshaper->run(inputShapes), InferenceEngine::details::InferenceEngineException);
56     }
57 }
58
59 INSTANTIATE_TEST_CASE_P(
60         CPUExtGeneralImpls, CPUExtShapeInferTests,
61         ::testing::Values(
62                 ::testing::make_tuple(LayerType("SpatialTransformer"),
63                                       InOutShapes({{{1, 6, 5, 5}, {1, 3}},
64                                                    {{1, 6, 5, 5}}}),
65                                       NewInOutShapes({{{2, 6, 5, 6}, {1, 3}},
66                                                       {{2, 6, 5, 6}}}),
67                                       MapParams(MapStrStr()),
68                                       LayerDataName("data"),
69                                       CanInfer(true))
70         )
71 );