Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / src / inference_engine / shape_infer / const_infer / ie_const_infer_impl.hpp
1 // Copyright (C) 2019 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
3 //
4
5 #pragma once
6
7 #include <map>
8 #include <vector>
9 #include <memory>
10 #include <string>
11 #include "ie_layer_validators.hpp"
12
13 namespace InferenceEngine {
14 namespace ShapeInfer {
15
16 /**
17  * @experimental
18  * @class IConstInferImpl
19  * @brief This class provides interface for the layer's implementation to propagate const
20  */
21 class IConstInferImpl {
22 public:
23     using Ptr = std::shared_ptr<IConstInferImpl>;
24
25     virtual ~IConstInferImpl() = default;
26
27
28     /**
29      * @brief all shapes are valid, blobs are allocated
30      *
31      */
32     virtual void infer(const std::vector<Blob::CPtr>& inData,
33                        const std::map<std::string, std::string>& params,
34                        const std::map<std::string, Blob::Ptr>& blobs,
35                        std::vector<Blob::Ptr>& outData) = 0;
36 };
37
38 class ConstInferImpl : public IConstInferImpl {
39 public:
40     explicit ConstInferImpl(const std::string& type) : _type(type) {
41         _validator = details::LayerValidators::getInstance()->getValidator(_type);
42         if (!_validator)
43             THROW_IE_EXCEPTION << "Internal error: failed to find validator for layer with type: " << _type;
44     }
45
46     virtual void inferImpl(const std::vector<Blob::CPtr>& inData,
47                            const std::map<std::string, std::string>& params,
48                            const std::map<std::string, Blob::Ptr>& blobs,
49                            std::vector<Blob::Ptr>& outData) = 0;
50
51     void infer(const std::vector<Blob::CPtr>& inData,
52                const std::map<std::string, std::string>& params,
53                const std::map<std::string, Blob::Ptr>& blobs,
54                std::vector<Blob::Ptr>& outData) override;
55
56 protected:
57     std::string _type;
58     // to get parsed descendant CNNLayer from map<string,string>
59     details::LayerValidator::Ptr _validator;
60 };
61
62 }  // namespace ShapeInfer
63 }  // namespace InferenceEngine
64