-// Copyright (C) 2018 Intel Corporation
+// Copyright (C) 2018-2019 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
}
}
-void MeanImage::Subtract(const MKLDNNDims &inputDims, float *input) {
+void MeanImage::Subtract(const MKLDNNDims &inputDims, float *input, InferenceEngine::Layout layout) {
IE_ASSERT(input != nullptr);
if (inputDims.ndims() != 4) {
THROW_IE_EXCEPTION << "Expecting input as 4 dimension blob with format NxCxHxW.";
}
+ if (layout != NCHW && layout != NHWC) {
+ THROW_IE_EXCEPTION << "Expecting input layout NCHW or NHWC.";
+ }
+
int MB = inputDims[0];
int srcSize = inputDims.size() / MB;
int C = inputDims[1];
srcSize /= inputDims[1];
- parallel_for3d(MB, C, srcSize, [&](int mb, int c, int i) {
- input[srcSize * mb * C + c * srcSize + i] -= meanValues[c];
- });
+ if (layout == NCHW) {
+ parallel_for3d(MB, C, srcSize, [&](int mb, int c, int i) {
+ input[mb * C * srcSize + c * srcSize + i] -= meanValues[c];
+ });
+ } else if (layout == NHWC) {
+ parallel_for2d(MB, srcSize, [&](int mb, int i) {
+ for (int c = 0; c < C; c++)
+ input[mb * srcSize * C + i * C + c] -= meanValues[c];
+ });
+ }
}
}