1 // Copyright (C) 2018 Intel Corporation
3 // SPDX-License-Identifier: Apache-2.0
6 #include "mean_image.h"
8 using namespace MKLDNNPlugin;
9 using namespace InferenceEngine;
// MeanImage default constructor: meanBuffer is initialized to nullptr, so the
// object starts without a packed mean-image buffer (Load() creates one via
// make_shared_blob in its mean-image path).
// NOTE(review): this copy of the file does not contain the constructor's
// closing brace (the embedded original numbering jumps from 11 to 14);
// confirm the body against the canonical source.
11 MeanImage::MeanImage() : meanBuffer(nullptr) {
// MeanImage::Load -- caches mean-subtraction settings taken from
// inputInfo->getPreProcess() on this object, selected by pp.getMeanVariant():
//   * one path stores one scalar per channel: meanValues[c] = pp[c]->meanValue;
//   * another packs every channel's mean image into a single FP32 CHW blob,
//     meanBuffer, created with dims { meanWidth, meanHeight, inChannels }.
// Throws (THROW_IE_EXCEPTION) when:
//   * pp.getNumberOfChannels() differs from inputDims[1];
//   * a channel's mean image is null or not Precision::FP32;
//   * a channel's mean image size is not meanHeight*meanWidth;
//   * the mean variant is unsupported.
// NOTE(review): this copy of the function is lossy. The embedded original
// line numbers skip (e.g. 17->22, 28->30, 56->62, 62->68). Several pieces are
// not present: the body of the inChannels == 0 branch, the switch's case
// labels and break statements, and the closing braces. Confirm against the
// canonical source before changing any logic here.
14 void MeanImage::Load(const MKLDNNDims& inputDims, InputInfo::Ptr inputInfo) {
15 PreProcessInfo &pp = inputInfo->getPreProcess();
16 size_t inChannels = pp.getNumberOfChannels();
17 if (inChannels == 0) {
// NOTE(review): the body of this branch is not present in this copy.
22 if (inChannels != inputDims[1]) {
23 THROW_IE_EXCEPTION << "channels mismatch between mean and input";
28 switch (pp.getMeanVariant()) {
30 // mean image common value per channel (1x1xC)
// One scalar mean per channel, read from pp[channel]->meanValue.
31 meanValues.resize(inChannels);
33 for (unsigned channel = 0; channel < inChannels; channel++) {
34 meanValues[channel] = pp[channel]->meanValue;
39 // since MKLDNN expects all channels in the same buffer - we copy it here as it comes from different channels...
// NOTE(review): pp[0]->meanData is dereferenced here, before the per-channel
// null/precision check in the loop below. A null channel-0 mean image would
// therefore be dereferenced before the "mean image not provided" exception
// can be raised. Consider validating pp[0]->meanData first.
40 auto meanWidth = pp[0]->meanData->dims()[0];
41 auto meanHeight = pp[0]->meanData->dims()[1];
// One blob that holds every channel's mean image back-to-back.
44 meanBuffer = make_shared_blob<float>(Precision::FP32, CHW, { meanWidth, meanHeight, inChannels });
46 meanBuffer->allocate();
48 for (unsigned channel = 0; channel < inChannels; channel++) {
49 Blob::Ptr meanBlob = pp[channel]->meanData;
50 if (!meanBlob || meanBlob->precision() != Precision::FP32)
51 THROW_IE_EXCEPTION << "mean image not provided or not in Float 32";
52 if (meanBlob->size() != meanHeight*meanWidth) {
53 THROW_IE_EXCEPTION << "mean image size does not match expected network input, expecting " << meanWidth << " x " << meanHeight;
55 // todo: cast to TBlob and make sure it is floats
// NOTE(review): the precision() != Precision::FP32 check above already
// rejects non-FP32 mean blobs, so this todo may be stale.
// Channel `channel` is copied to offset channel*meanBlob->size() from
// meanBuffer->data(). The offset is presumably in float elements; confirm
// data()'s return type. The size check above guarantees every channel
// contributes exactly meanHeight*meanWidth elements.
56 memcpy(meanBuffer->data() + channel*meanBlob->size(), meanBlob->buffer(), meanBlob->byteSize());
62 // there is no mean image. So disable mean image step
68 THROW_IE_EXCEPTION << "Unsupported mean variant: " << pp.getMeanVariant();
73 void MeanImage::Subtract(const MKLDNNDims &inputDims, float *input) {
74 IE_ASSERT(input != nullptr);
76 if (inputDims.ndims() != 4) {
77 THROW_IE_EXCEPTION << "Expecting input as 4 dimension blob with format NxCxHxW.";
80 int MB = inputDims[0];
81 int srcSize = inputDims.size() / MB;
83 if (meanBuffer && meanBuffer->size()) {
84 const float * meanBufferValues = meanBuffer->readOnly();
85 # pragma omp parallel for collapse(2) schedule(static)
86 for (int mb = 0; mb < MB; mb++) {
87 for (int i = 0; i < srcSize; i++) {
88 input[srcSize * mb + i] -= meanBufferValues[i];
91 } else if (!meanValues.empty()) {
93 srcSize /= inputDims[1];
95 # pragma omp parallel for collapse(3) schedule(static)
96 for (int mb = 0; mb < MB; mb++) {
97 for (int c = 0; c < C; c++) {
98 for (int i = 0; i < srcSize; i++) {
99 input[srcSize * mb * C + c * srcSize + i] -= meanValues[c];