1 // Copyright (C) 2018-2019 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
5 #include "mean_image.h"
6 #include "ie_parallel.hpp"
9 using namespace MKLDNNPlugin;
10 using namespace InferenceEngine;
// Default constructor: the object starts with a null meanBuffer, i.e. no mean
// image is attached until Load() assigns one.
12 MeanImage::MeanImage() : meanBuffer(nullptr) {
// Reads the mean-subtraction settings from the input's pre-processing info and
// stores them in this object: either one value per channel (meanValues) or a
// full mean image packed into a single buffer (meanBuffer).
//
// inputDims - dimensions of the network input; inputDims[1] is compared
//             against the number of pre-processing channels.
// inputInfo - input description whose PreProcessInfo supplies the mean data.
//
// Throws (THROW_IE_EXCEPTION) when the channel counts differ, when a per-channel
// mean blob is missing / not FP32 / of unexpected size, or when the mean
// variant is not one of the handled ones.
15 void MeanImage::Load(const MKLDNNDims& inputDims, InputInfo::Ptr inputInfo) {
16 PreProcessInfo &pp = inputInfo->getPreProcess();
17 size_t inChannels = pp.getNumberOfChannels();
// Zero pre-processing channels is handled as a special case; the body of this
// branch is not visible in this excerpt.
18 if (inChannels == 0) {
// The pre-processing info must describe exactly as many channels as the input has.
23 if (inChannels != inputDims[1]) {
24 THROW_IE_EXCEPTION << "channels mismatch between mean and input";
// Dispatch on the kind of mean data supplied (case labels are not visible in this excerpt).
29 switch (pp.getMeanVariant()) {
31 // mean image common value per channel (1x1xC)
32 meanValues.resize(inChannels);
// Copy each channel's scalar mean into meanValues.
34 for (unsigned channel = 0; channel < inChannels; channel++) {
35 meanValues[channel] = pp[channel]->meanValue;
40 // since MKLDNN expects all channels in the same buffer - we copy it here as it comes from different channels...
// Width and height are taken from channel 0's mean blob and used as the
// expected size for every channel.
// NOTE(review): pp[0]->meanData is dereferenced here before the null check
// performed inside the loop below - a missing mean blob on channel 0 would be
// dereferenced before that check can throw. Confirm and consider validating first.
41 auto meanWidth = pp[0]->meanData->dims()[0];
42 auto meanHeight = pp[0]->meanData->dims()[1];
// One FP32 blob in CHW layout that holds the mean planes of all channels.
45 meanBuffer = make_shared_blob<float>(Precision::FP32, CHW, { meanWidth, meanHeight, inChannels });
47 meanBuffer->allocate();
// Validate every channel's mean blob and append it to the packed buffer.
49 for (unsigned channel = 0; channel < inChannels; channel++) {
50 Blob::Ptr meanBlob = pp[channel]->meanData;
51 if (!meanBlob || meanBlob->precision() != Precision::FP32)
52 THROW_IE_EXCEPTION << "mean image not provided or not in Float 32";
// Each channel's blob must hold exactly meanHeight * meanWidth elements.
53 if (meanBlob->size() != meanHeight*meanWidth) {
54 THROW_IE_EXCEPTION << "mean image size does not match expected network input, expecting " << meanWidth << " x " << meanHeight;
56 // todo: cast to TBlob and make sure it is floats
// Destination is offset by `channel` blobs' worth of elements; the second
// argument is the byte count left in meanBuffer from that offset.
// (ie_memcpy is presumably a bounds-checked memcpy taking
// dest, dest capacity, src, byte count - confirm against its declaration.)
57 ie_memcpy(meanBuffer->data() + channel*meanBlob->size(), meanBuffer->byteSize() - channel*meanBlob->byteSize(),
58 meanBlob->buffer(), meanBlob->byteSize());
64 // there is no mean image. So disable mean image step
// Any other mean variant is rejected.
70 THROW_IE_EXCEPTION << "Unsupported mean variant: " << pp.getMeanVariant();
// Subtracts the loaded mean from `input` in place.
//
// inputDims - dimensions of `input`; must have exactly 4 dimensions.
//             inputDims[0] is used as the batch count and inputDims[1] as the
//             channel divisor.
// input     - float data to modify; must not be null.
// layout    - memory layout of `input`; only NCHW and NHWC are accepted.
//
// If a non-empty meanBuffer is present it is subtracted element-wise from each
// batch item; otherwise, if meanValues is non-empty, the per-channel value is
// subtracted. If neither is set, no subtraction branch shown here is taken.
// Throws (THROW_IE_EXCEPTION) on a non-4D input or an unsupported layout.
75 void MeanImage::Subtract(const MKLDNNDims &inputDims, float *input, InferenceEngine::Layout layout) {
76 IE_ASSERT(input != nullptr);
78 if (inputDims.ndims() != 4) {
79 THROW_IE_EXCEPTION << "Expecting input as 4 dimension blob with format NxCxHxW.";
82 if (layout != NCHW && layout != NHWC) {
83 THROW_IE_EXCEPTION << "Expecting input layout NCHW or NHWC.";
// MB is the first dimension (batch); srcSize is the size of one batch item,
// assuming inputDims.size() is the total element count - TODO confirm.
// NOTE(review): MB is used as a divisor with no visible zero check.
86 int MB = inputDims[0];
87 int srcSize = inputDims.size() / MB;
// Mean-image path: used when a non-empty meanBuffer was loaded.
89 if (meanBuffer && meanBuffer->size()) {
90 const float * meanBufferValues = meanBuffer->readOnly();
// The same mean element i is subtracted from element i of every batch item.
// NOTE(review): this path does not consult `layout`, and no check that
// meanBuffer holds at least srcSize elements is visible here - confirm the
// buffer ordering and size match the input.
92 parallel_for2d(MB, srcSize, [&](int mb, int i) {
93 input[srcSize * mb + i] -= meanBufferValues[i];
// Per-channel mean-value path.
95 } else if (!meanValues.empty()) {
// srcSize becomes the per-channel element count of one batch item.
97 srcSize /= inputDims[1];
// `C` is declared on a line not visible in this excerpt - presumably the
// channel count inputDims[1]; confirm.
// Channel-major indexing: all srcSize elements of channel c are contiguous.
// (The condition guarding this branch is not visible; presumably layout == NCHW.)
100 parallel_for3d(MB, C, srcSize, [&](int mb, int c, int i) {
101 input[mb * C * srcSize + c * srcSize + i] -= meanValues[c];
// Channel-interleaved indexing: the C channel values of position i are contiguous.
103 } else if (layout == NHWC) {
104 parallel_for2d(MB, srcSize, [&](int mb, int i) {
105 for (int c = 0; c < C; c++)
106 input[mb * srcSize * C + i * C + c] -= meanValues[c];