Publishing R3
[platform/upstream/dldt.git] / inference-engine / include / mkldnn / mkldnn_generic_primitive.hpp
1 // Copyright (C) 2018 Intel Corporation
2 //
3 // SPDX-License-Identifier: Apache-2.0
4 //
5
6 /**
7  * @brief a header file for MKL-DNN Generic Primitive API
8  * @file mkldnn_generic_primitive.hpp
9  */
10 #pragma once
11
12 #include "mkldnn_extension_types.hpp"
13 #include "details/ie_irelease.hpp"
14 #include <vector>
15
16 namespace InferenceEngine {
17 namespace MKLDNNPlugin {
18
19 /**
20  * @deprecated use new extensibility API
21  * @brief The MKLDNNGenericFormats stores weights, biases, inputs and outputs of the primitive
22  */
23 class MKLDNNGenericFormats {
24 public:
25     /**
26      * @brief A default constructor
27      * @param ins - vector of inputs
28      * @param outs - vector of outputs
29      * @param weights - weights, format_undef by default
30      * @param biases  - biases, format_undef by default
31      */
32     MKLDNNGenericFormats(const std::vector<MemoryFormat> &ins, const std::vector<MemoryFormat> &outs,
33                          const MemoryFormat weights = MemoryFormat::format_undef,
34                          const MemoryFormat biases = MemoryFormat::format_undef) : inputs(ins), outputs(outs) {
35         this->weights = weights;
36         this->biases = biases;
37     }
38
39     /**
40      * @brief Get input formats
41      * @return vector of input formats
42      */
43     const std::vector<MemoryFormat>& GetInputs() const noexcept {
44         return inputs;
45     }
46
47     /**
48      * @brief Get output formats
49      * @return vector of output formats
50      */
51     const std::vector<MemoryFormat>& GetOutputs() const noexcept {
52         return outputs;
53     }
54
55     /**
56      * @brief Get weights format
57      * @return weights format
58      */
59     const MemoryFormat& GetWeights() const noexcept {
60         return weights;
61     }
62
63     /**
64      * @brief Get biases format
65      * @return biases format
66      */
67     const MemoryFormat& GetBiases() const noexcept {
68         return biases;
69     }
70
71 private:
72     std::vector<MemoryFormat> inputs;
73     std::vector<MemoryFormat> outputs;
74     MemoryFormat weights;
75     MemoryFormat biases;
76 };
77
78 /**
79  * @deprecated use new extensibility API
80  * @brief The IMKLDNNGenericPrimitive is the main Generic Primitive interface
81  */
82 class IMKLDNNGenericPrimitive : public InferenceEngine::details::IRelease {
83 public:
84     void Release() noexcept override {
85         delete this;
86     }
87
88     /**
89      * @brief Sets inputs nd outputs
90      * @param inputs - vector of input primitives
91      * @param outputs - vector of output primitives
92      */
93     void SetMemory(const std::vector<MKLDNNPrimitiveMemory>& inputs,
94                            const std::vector<MKLDNNPrimitiveMemory>& outputs) noexcept {
95         this->inputs = inputs;
96         this->outputs = outputs;
97     }
98
99     /**
100      * @brief Gets supported formats
101      * @return vector of supported formats
102      */
103     virtual std::vector<MKLDNNGenericFormats> GetSupportedFormats() noexcept = 0;
104
105     /**
106      * @brief Entry point of actual execution of primitive.
107      * Error reporting mechanism missed, static check should be done in constructor
108      */
109     virtual void Execute() noexcept = 0;
110
111 protected:
112     /**
113      * @brief Vector of input primitives
114      */
115     std::vector<MKLDNNPrimitiveMemory> inputs;
116     /**
117      * @brief Vector of output primitives
118      */
119     std::vector<MKLDNNPrimitiveMemory> outputs;
120 };
121
122 }  // namespace MKLDNNPlugin
123 }  // namespace InferenceEngine