Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / include / ie_preprocess.hpp
1 // Copyright (C) 2018-2019 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
3 //
4
5 /**
6  * @brief This header file provides structures to store info about pre-processing of network inputs (scale, mean image, ...)
7  * @file ie_preprocess.hpp
8  */
9 #pragma once
10
11 #include "ie_blob.h"
12 #include <vector>
13 #include <memory>
14
15 namespace InferenceEngine {
16
17 /**
18  * @brief This structure stores info about pre-processing of network inputs (scale, mean image, ...)
19  */
20 struct PreProcessChannel {
21     /** @brief Scale parameter for a channel */
22     float stdScale = 1;
23
24     /** @brief Mean value for a channel */
25     float meanValue = 0;
26
27     /** @brief Mean data for a channel */
28     Blob::Ptr meanData;
29
30     /** @brief Smart pointer to an instance */
31     using Ptr = std::shared_ptr<PreProcessChannel>;
32 };
33
34 /**
35  * @brief Defines available types of mean
36  */
37 enum MeanVariant {
38     MEAN_IMAGE, /**< mean value is specified for each input pixel */
39     MEAN_VALUE, /**< mean value is specified for each input channel */
40     NONE,       /**< no mean value specified */
41 };
42
43 /**
44  * @enum ResizeAlgorithm
45  * @brief Represents the list of supported resize algorithms.
46  */
47 enum ResizeAlgorithm {
48     NO_RESIZE = 0,
49     RESIZE_BILINEAR,
50     RESIZE_AREA
51 };
52
53 /**
54  * @brief This class stores pre-process information for the input
55  */
56 class PreProcessInfo {
57     // Channel data
58     std::vector<PreProcessChannel::Ptr> _channelsInfo;
59     MeanVariant _variant = NONE;
60
61     // Resize Algorithm to be applied for input before inference if needed.
62     ResizeAlgorithm _resizeAlg = NO_RESIZE;
63
64 public:
65     /**
66      * @brief Overloaded [] operator to safely get the channel by an index. 
67      * Throws an exception if channels are empty.
68      * @param index Index of the channel to get
69      * @return The pre-process channel instance
70      */
71     PreProcessChannel::Ptr &operator[](size_t index) {
72         if (_channelsInfo.empty()) {
73             THROW_IE_EXCEPTION << "accessing pre-process when nothing was set.";
74         }
75         if (index >= _channelsInfo.size()) {
76             THROW_IE_EXCEPTION << "pre process index " << index << " is out of bounds.";
77         }
78         return _channelsInfo[index];
79     }
80
81     /**
82      * @brief operator [] to safely get the channel preprocessing information by index.
83      * Throws exception if channels are empty or index is out of border
84      *
85      * @param index Index of the channel to get
86      * @return The const preprocess channel instance
87      */
88     const PreProcessChannel::Ptr &operator[](size_t index) const {
89         if (_channelsInfo.empty()) {
90             THROW_IE_EXCEPTION << "accessing pre-process when nothing was set.";
91         }
92         if (index >= _channelsInfo.size()) {
93             THROW_IE_EXCEPTION << "pre process index " << index << " is out of bounds.";
94         }
95         return _channelsInfo[index];
96     }
97
98     /**
99      * @brief Returns a number of channels to preprocess
100      * @return The number of channels
101      */
102     size_t getNumberOfChannels() const {
103         return _channelsInfo.size();
104     }
105
106     /**
107      * @brief Initializes with given number of channels
108      * @param numberOfChannels Number of channels to initialize
109      */
110     void init(const size_t numberOfChannels) {
111         _channelsInfo.resize(numberOfChannels);
112         for (auto &channelInfo : _channelsInfo) {
113             channelInfo = std::make_shared<PreProcessChannel>();
114         }
115     }
116
117     /**
118      * @brief Sets mean image values if operation is applicable.
119      * Also sets the mean type to MEAN_IMAGE for all channels
120      * @param meanImage Blob with a mean image
121      */
122     void setMeanImage(const Blob::Ptr &meanImage) {
123         if (meanImage.get() == nullptr) {
124             THROW_IE_EXCEPTION << "Failed to set invalid mean image: nullptr";
125         } else if (meanImage.get()->dims().size() != 3) {
126             THROW_IE_EXCEPTION << "Failed to set invalid mean image: number of dimensions != 3";
127         } else if (meanImage.get()->dims()[2] != getNumberOfChannels()) {
128             THROW_IE_EXCEPTION << "Failed to set invalid mean image: number of channels != "
129                                << getNumberOfChannels();
130         } else if (meanImage.get()->layout() != Layout::CHW) {
131             THROW_IE_EXCEPTION << "Mean image layout should be CHW";
132         }
133         _variant = MEAN_IMAGE;
134     }
135
136     /**
137      * @brief Sets mean image values if operation is applicable.
138      * Also sets the mean type to MEAN_IMAGE for a particular channel
139      * @param meanImage Blob with a mean image
140      * @param channel Index of a particular channel
141      */
142     void setMeanImageForChannel(const Blob::Ptr &meanImage, const size_t channel) {
143         if (meanImage.get() == nullptr) {
144             THROW_IE_EXCEPTION << "Failed to set invalid mean image for channel: nullptr";
145         } else if (meanImage.get()->dims().size() != 2) {
146             THROW_IE_EXCEPTION << "Failed to set invalid mean image for channel: number of dimensions != 2";
147         } else if (channel >= _channelsInfo.size()) {
148             THROW_IE_EXCEPTION << "Channel " << channel << " exceed number of PreProcess channels: "
149                                << _channelsInfo.size();
150         }
151         _variant = MEAN_IMAGE;
152         _channelsInfo[channel]->meanData = meanImage;
153     }
154
155     /**
156      * @brief Sets a type of mean operation
157      * @param variant Type of mean operation to set
158      */
159     void setVariant(const MeanVariant &variant) {
160         _variant = variant;
161     }
162
163     /**
164      * @brief Gets a type of mean operation
165      * @return The type of mean operation
166      */
167     MeanVariant getMeanVariant() const {
168         return _variant;
169     }
170
171     /**
172      * @brief Sets resize algorithm to be used during pre-processing.
173      * @param alg Resize algorithm.
174      */
175     void setResizeAlgorithm(const ResizeAlgorithm &alg) {
176         _resizeAlg = alg;
177     }
178
179     /**
180      * @brief Gets preconfigured resize algorithm.
181      * @return Resize algorithm.
182      */
183     ResizeAlgorithm getResizeAlgorithm() const {
184         return _resizeAlg;
185     }
186 };
187 }  // namespace InferenceEngine