1 // Copyright (C) 2018-2019 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
7 #include <gtest/gtest.h>
9 #include <initializer_list>
12 #include <unordered_set>
13 #include <unordered_map>
15 #include <ie_util_internal.hpp>
16 #include <tests_common.hpp>
17 #include <graph_transformer.h>
18 #include "ie_utils.hpp"
19 #include "blob_factory.hpp"
21 #include "util_test.hpp"
22 #include <details/ie_cnn_network_tools.h>
24 namespace IE = InferenceEngine;
26 class ConstTransformatorTest : public IE::ConstTransformer {
28 explicit ConstTransformatorTest(IE::details::CNNNetworkImpl* network) : IE::ConstTransformer(network) {}
30 const std::map<std::string, bool>
31 getConstLayers(const std::vector<InferenceEngine::CNNLayerPtr>& sortedLayers) override {
32 return ConstTransformer::getConstLayers(sortedLayers);
35 const InferenceEngine::BlobMap getConstData(const std::map<std::string, bool>& constLayers,
36 const std::vector<InferenceEngine::CNNLayerPtr>& sortedLayers) override {
37 return ConstTransformer::getConstData(constLayers, sortedLayers);
40 std::vector<std::string>
41 foldConstSubgraphsInternal(const std::map<std::string, bool>& constLayers, const IE::BlobMap& constData,
42 const std::vector<IE::CNNLayerPtr>& sortedLayers) override {
43 return ConstTransformer::foldConstSubgraphsInternal(constLayers, constData, sortedLayers);
46 void trimShapeInputs(const std::vector<std::string>& constLayers) override {
47 ConstTransformer::trimShapeInputs(constLayers);
52 class RemoveLayerTests : public testing::Test {
54 void SetUp() override;
61 // I2-d2-L2-d5-L4-d6-L5-d9-L10
67 IE::details::CNNNetworkImplPtr getNetwork();
69 IE::CNNLayerPtr getLayer(const std::string& name);
71 IE::DataPtr getData(const std::string& name);
73 IE::BlobMap fillConstData(const std::vector<std::string>& constLayers);
75 IE::BlobMap initConstLayers(const std::vector<std::string>& constLayers);
77 NetBuilder netBuilder;
78 IE::details::CNNNetworkImplPtr net;
79 size_t originalLayersNum;
80 std::unique_ptr<ConstTransformatorTest> testTransformator;
83 class AdvancedShapeInferTests : public RemoveLayerTests {
85 void SetUp() override {};