Publishing R3
[platform/upstream/dldt.git] / inference-engine / src / mkldnn_plugin / mean_image.cpp
1 // Copyright (C) 2018 Intel Corporation
2 //
3 // SPDX-License-Identifier: Apache-2.0
4 //
5
6 #include "mean_image.h"
7
8 using namespace MKLDNNPlugin;
9 using namespace InferenceEngine;
10
11 MeanImage::MeanImage() : meanBuffer(nullptr) {
12 }
13
14 void MeanImage::Load(const MKLDNNDims& inputDims, InputInfo::Ptr inputInfo) {
15     PreProcessInfo &pp = inputInfo->getPreProcess();
16     size_t inChannels = pp.getNumberOfChannels();
17     if (inChannels == 0) {
18         meanBuffer = nullptr;
19         return;
20     }
21
22     if (inChannels != inputDims[1]) {
23         THROW_IE_EXCEPTION << "channels mismatch between mean and input";
24     }
25
26     ResponseDesc resp;
27
28     switch (pp.getMeanVariant()) {
29         case MEAN_VALUE: {
30             // mean image common value per channel (1x1xC)
31             meanValues.resize(inChannels);
32
33             for (unsigned channel = 0; channel < inChannels; channel++) {
34                 meanValues[channel] = pp[channel]->meanValue;
35             }
36         }
37         break;
38         case MEAN_IMAGE: {
39             // since MKLDNN expects all channels in the same buffer - we copy it here as it comes from different channels...
40             auto meanWidth = pp[0]->meanData->dims()[0];
41             auto meanHeight = pp[0]->meanData->dims()[1];
42
43
44             meanBuffer = make_shared_blob<float>(Precision::FP32, CHW, { meanWidth, meanHeight, inChannels });
45
46             meanBuffer->allocate();
47
48             for (unsigned channel = 0; channel < inChannels; channel++) {
49                 Blob::Ptr meanBlob = pp[channel]->meanData;
50                 if (!meanBlob || meanBlob->precision() != Precision::FP32)
51                     THROW_IE_EXCEPTION << "mean image not provided or not in Float 32";
52                 if (meanBlob->size() != meanHeight*meanWidth) {
53                     THROW_IE_EXCEPTION << "mean image size does not match expected network input, expecting " << meanWidth << " x " << meanHeight;
54                 }
55                 // todo: cast to TBlob and make sure it is floats
56                 memcpy(meanBuffer->data() + channel*meanBlob->size(), meanBlob->buffer(), meanBlob->byteSize());
57             }
58         }
59             break;
60
61         case NONE: {
62             // there is no mean image. So disable mean image step
63             meanBuffer = nullptr;
64         }
65             break;
66
67         default: {
68             THROW_IE_EXCEPTION << "Unsupported mean variant: " << pp.getMeanVariant();
69         }
70     }
71 }
72
73 void MeanImage::Subtract(const MKLDNNDims &inputDims, float *input) {
74     IE_ASSERT(input != nullptr);
75
76     if (inputDims.ndims() != 4) {
77         THROW_IE_EXCEPTION << "Expecting input as 4 dimension blob with format NxCxHxW.";
78     }
79
80     int MB = inputDims[0];
81     int srcSize = inputDims.size() / MB;
82
83     if (meanBuffer && meanBuffer->size()) {
84         const float * meanBufferValues = meanBuffer->readOnly();
85 #   pragma omp parallel for collapse(2) schedule(static)
86         for (int mb = 0; mb < MB; mb++) {
87             for (int i = 0; i < srcSize; i++) {
88                 input[srcSize * mb + i] -= meanBufferValues[i];
89             }
90         }
91     } else if (!meanValues.empty()) {
92         int C = inputDims[1];
93         srcSize /= inputDims[1];
94
95 #   pragma omp parallel for collapse(3) schedule(static)
96         for (int mb = 0; mb < MB; mb++) {
97             for (int c = 0; c < C; c++) {
98                 for (int i = 0; i < srcSize; i++) {
99                     input[srcSize * mb * C + c * srcSize + i] -= meanValues[c];
100                 }
101             }
102         }
103     }
104 }