Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / src / mkldnn_plugin / mean_image.cpp
1 // Copyright (C) 2018-2019 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
3 //
4
5 #include "mean_image.h"
6 #include "ie_parallel.hpp"
7 #include "ie_memcpy.h"
8
9 using namespace MKLDNNPlugin;
10 using namespace InferenceEngine;
11
12 MeanImage::MeanImage() : meanBuffer(nullptr) {
13 }
14
15 void MeanImage::Load(const MKLDNNDims& inputDims, InputInfo::Ptr inputInfo) {
16     PreProcessInfo &pp = inputInfo->getPreProcess();
17     size_t inChannels = pp.getNumberOfChannels();
18     if (inChannels == 0) {
19         meanBuffer = nullptr;
20         return;
21     }
22
23     if (inChannels != inputDims[1]) {
24         THROW_IE_EXCEPTION << "channels mismatch between mean and input";
25     }
26
27     ResponseDesc resp;
28
29     switch (pp.getMeanVariant()) {
30         case MEAN_VALUE: {
31             // mean image common value per channel (1x1xC)
32             meanValues.resize(inChannels);
33
34             for (unsigned channel = 0; channel < inChannels; channel++) {
35                 meanValues[channel] = pp[channel]->meanValue;
36             }
37         }
38         break;
39         case MEAN_IMAGE: {
40             // since MKLDNN expects all channels in the same buffer - we copy it here as it comes from different channels...
41             auto meanWidth = pp[0]->meanData->dims()[0];
42             auto meanHeight = pp[0]->meanData->dims()[1];
43
44
45             meanBuffer = make_shared_blob<float>(Precision::FP32, CHW, { meanWidth, meanHeight, inChannels });
46
47             meanBuffer->allocate();
48
49             for (unsigned channel = 0; channel < inChannels; channel++) {
50                 Blob::Ptr meanBlob = pp[channel]->meanData;
51                 if (!meanBlob || meanBlob->precision() != Precision::FP32)
52                     THROW_IE_EXCEPTION << "mean image not provided or not in Float 32";
53                 if (meanBlob->size() != meanHeight*meanWidth) {
54                     THROW_IE_EXCEPTION << "mean image size does not match expected network input, expecting " << meanWidth << " x " << meanHeight;
55                 }
56                 // todo: cast to TBlob and make sure it is floats
57                 ie_memcpy(meanBuffer->data() + channel*meanBlob->size(), meanBuffer->byteSize() - channel*meanBlob->byteSize(),
58                           meanBlob->buffer(), meanBlob->byteSize());
59             }
60         }
61             break;
62
63         case NONE: {
64             // there is no mean image. So disable mean image step
65             meanBuffer = nullptr;
66         }
67             break;
68
69         default: {
70             THROW_IE_EXCEPTION << "Unsupported mean variant: " << pp.getMeanVariant();
71         }
72     }
73 }
74
75 void MeanImage::Subtract(const MKLDNNDims &inputDims, float *input, InferenceEngine::Layout layout) {
76     IE_ASSERT(input != nullptr);
77
78     if (inputDims.ndims() != 4) {
79         THROW_IE_EXCEPTION << "Expecting input as 4 dimension blob with format NxCxHxW.";
80     }
81
82     if (layout != NCHW && layout != NHWC) {
83         THROW_IE_EXCEPTION << "Expecting input layout NCHW or NHWC.";
84     }
85
86     int MB = inputDims[0];
87     int srcSize = inputDims.size() / MB;
88
89     if (meanBuffer && meanBuffer->size()) {
90         const float * meanBufferValues = meanBuffer->readOnly();
91
92         parallel_for2d(MB, srcSize, [&](int mb, int i) {
93             input[srcSize * mb + i] -= meanBufferValues[i];
94         });
95     } else if (!meanValues.empty()) {
96         int C = inputDims[1];
97         srcSize /= inputDims[1];
98
99         if (layout == NCHW) {
100             parallel_for3d(MB, C, srcSize, [&](int mb, int c, int i) {
101                 input[mb * C * srcSize + c * srcSize + i] -= meanValues[c];
102             });
103         } else if (layout == NHWC) {
104             parallel_for2d(MB, srcSize, [&](int mb, int i) {
105                 for (int c = 0; c < C; c++)
106                     input[mb * srcSize * C + i * C + c] -= meanValues[c];
107             });
108         }
109     }
110 }