Publishing 2020.1 content
[platform/upstream/dldt.git] / inference-engine / include / ie_preprocess.hpp
1 // Copyright (C) 2018-2020 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
3 //
4
5 /**
6  * @brief This header file provides structures to store info about pre-processing of network inputs (scale, mean image,
7  * ...)
8  *
9  * @file ie_preprocess.hpp
10  */
11 #pragma once
12
13 #include <memory>
14 #include <vector>
15
16 #include "ie_blob.h"
17
18 namespace InferenceEngine {
19
20 /**
21  * @brief This structure stores info about pre-processing of network inputs (scale, mean image, ...)
22  */
23 struct PreProcessChannel {
24     /** @brief Scale parameter for a channel */
25     float stdScale = 1;
26
27     /** @brief Mean value for a channel */
28     float meanValue = 0;
29
30     /** @brief Mean data for a channel */
31     Blob::Ptr meanData;
32
33     /** @brief Smart pointer to an instance */
34     using Ptr = std::shared_ptr<PreProcessChannel>;
35 };
36
37 /**
38  * @brief Defines available types of mean
39  */
40 enum MeanVariant {
41     MEAN_IMAGE, /**< mean value is specified for each input pixel */
42     MEAN_VALUE, /**< mean value is specified for each input channel */
43     NONE,       /**< no mean value specified */
44 };
45
46 /**
47  * @enum ResizeAlgorithm
48  * @brief Represents the list of supported resize algorithms.
49  */
50 enum ResizeAlgorithm { NO_RESIZE = 0, RESIZE_BILINEAR, RESIZE_AREA };
51
52 /**
53  * @brief This class stores pre-process information for the input
54  */
55 class PreProcessInfo {
56     // Channel data
57     std::vector<PreProcessChannel::Ptr> _channelsInfo;
58     MeanVariant _variant = NONE;
59
60     // Resize Algorithm to be applied for input before inference if needed.
61     ResizeAlgorithm _resizeAlg = NO_RESIZE;
62
63     // Color format to be used in on-demand color conversions applied to input before inference
64     ColorFormat _colorFormat = ColorFormat::RAW;
65
66 public:
67     /**
68      * @brief Overloaded [] operator to safely get the channel by an index
69      *
70      * Throws an exception if channels are empty
71      *
72      * @param index Index of the channel to get
73      * @return The pre-process channel instance
74      */
75     PreProcessChannel::Ptr& operator[](size_t index) {
76         if (_channelsInfo.empty()) {
77             THROW_IE_EXCEPTION << "accessing pre-process when nothing was set.";
78         }
79         if (index >= _channelsInfo.size()) {
80             THROW_IE_EXCEPTION << "pre process index " << index << " is out of bounds.";
81         }
82         return _channelsInfo[index];
83     }
84
85     /**
86      * @brief operator [] to safely get the channel preprocessing information by index.
87      *
88      * Throws exception if channels are empty or index is out of border
89      *
90      * @param index Index of the channel to get
91      * @return The const preprocess channel instance
92      */
93     const PreProcessChannel::Ptr& operator[](size_t index) const {
94         if (_channelsInfo.empty()) {
95             THROW_IE_EXCEPTION << "accessing pre-process when nothing was set.";
96         }
97         if (index >= _channelsInfo.size()) {
98             THROW_IE_EXCEPTION << "pre process index " << index << " is out of bounds.";
99         }
100         return _channelsInfo[index];
101     }
102
103     /**
104      * @brief Returns a number of channels to preprocess
105      *
106      * @return The number of channels
107      */
108     size_t getNumberOfChannels() const {
109         return _channelsInfo.size();
110     }
111
112     /**
113      * @brief Initializes with given number of channels
114      *
115      * @param numberOfChannels Number of channels to initialize
116      */
117     void init(const size_t numberOfChannels) {
118         _channelsInfo.resize(numberOfChannels);
119         for (auto& channelInfo : _channelsInfo) {
120             channelInfo = std::make_shared<PreProcessChannel>();
121         }
122     }
123
124     /**
125      * @brief Sets mean image values if operation is applicable.
126      *
127      * Also sets the mean type to MEAN_IMAGE for all channels
128      *
129      * @param meanImage Blob with a mean image
130      */
131     void setMeanImage(const Blob::Ptr& meanImage) {
132         if (meanImage.get() == nullptr) {
133             THROW_IE_EXCEPTION << "Failed to set invalid mean image: nullptr";
134         } else if (meanImage.get()->getTensorDesc().getLayout() != Layout::CHW) {
135             THROW_IE_EXCEPTION << "Mean image layout should be CHW";
136         } else if (meanImage.get()->getTensorDesc().getDims().size() != 3) {
137             THROW_IE_EXCEPTION << "Failed to set invalid mean image: number of dimensions != 3";
138         } else if (meanImage.get()->getTensorDesc().getDims()[0] != getNumberOfChannels()) {
139             THROW_IE_EXCEPTION << "Failed to set invalid mean image: number of channels != " << getNumberOfChannels();
140         }
141         _variant = MEAN_IMAGE;
142     }
143
144     /**
145      * @brief Sets mean image values if operation is applicable.
146      *
147      * Also sets the mean type to MEAN_IMAGE for a particular channel
148      *
149      * @param meanImage Blob with a mean image
150      * @param channel Index of a particular channel
151      */
152     void setMeanImageForChannel(const Blob::Ptr& meanImage, const size_t channel) {
153         if (meanImage.get() == nullptr) {
154             THROW_IE_EXCEPTION << "Failed to set invalid mean image for channel: nullptr";
155         } else if (meanImage.get()->getTensorDesc().getDims().size() != 2) {
156             THROW_IE_EXCEPTION << "Failed to set invalid mean image for channel: number of dimensions != 2";
157         } else if (channel >= _channelsInfo.size()) {
158             THROW_IE_EXCEPTION << "Channel " << channel
159                                << " exceed number of PreProcess channels: " << _channelsInfo.size();
160         }
161         _variant = MEAN_IMAGE;
162         _channelsInfo[channel]->meanData = meanImage;
163     }
164
165     /**
166      * @brief Sets a type of mean operation
167      *
168      * @param variant Type of mean operation to set
169      */
170     void setVariant(const MeanVariant& variant) {
171         _variant = variant;
172     }
173
174     /**
175      * @brief Gets a type of mean operation
176      *
177      * @return The type of mean operation
178      */
179     MeanVariant getMeanVariant() const {
180         return _variant;
181     }
182
183     /**
184      * @brief Sets resize algorithm to be used during pre-processing
185      *
186      * @param alg Resize algorithm
187      */
188     void setResizeAlgorithm(const ResizeAlgorithm& alg) {
189         _resizeAlg = alg;
190     }
191
192     /**
193      * @brief Gets preconfigured resize algorithm
194      *
195      * @return Resize algorithm
196      */
197     ResizeAlgorithm getResizeAlgorithm() const {
198         return _resizeAlg;
199     }
200
201     /**
202      * @brief Changes the color format of the input data provided by the user
203      *
204      * This function should be called before loading the network to the plugin
205      * Setting color format different from ColorFormat::RAW enables automatic color conversion
206      * (as a part of built-in preprocessing routine)
207      *
208      * @param fmt A new color format associated with the input
209      */
210     void setColorFormat(ColorFormat fmt) {
211         _colorFormat = fmt;
212     }
213
214     /**
215      * @brief Gets a color format associated with the input
216      *
217      * @details By default, the color format is ColorFormat::RAW meaning
218      *          there is no particular color format assigned to the input
219      * @return Color format.
220      */
221     ColorFormat getColorFormat() const {
222         return _colorFormat;
223     }
224 };
225 }  // namespace InferenceEngine