1 // Copyright (C) 2018-2019 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
#include "inference_engine.hpp"
#include "mkldnn_dims.h"
#include "ie_parallel.hpp"

#include <cmath>
#include <limits>
#include <type_traits>
#include <vector>
13 namespace MKLDNNPlugin {
// Prepares mean data for the input described by inputDims / inputInfo.
// NOTE(review): the definition is not visible here; presumably it fills meanValues and/or
// meanBuffer from inputInfo's pre-processing settings — confirm against the .cpp.
void Load(const MKLDNNDims& inputDims, InferenceEngine::InputInfo::Ptr inputInfo);
// Float-input overload of Subtract; the definition is not visible here.
// NOTE(review): presumably applies the same meanBuffer / meanValues subtraction as the
// integral template below, without integer saturation — confirm against the .cpp.
void Subtract(const MKLDNNDims &inputDims, float *input, InferenceEngine::Layout layout);
23 template<typename T, typename std::enable_if<std::is_integral<T>::value>::type* = nullptr>
24 void Subtract(const MKLDNNDims &inputDims, T *input, InferenceEngine::Layout layout) {
25 IE_ASSERT(input != nullptr);
27 if (inputDims.ndims() != 4) {
28 THROW_IE_EXCEPTION << "Expecting input as 4 dimension blob with format NxCxHxW.";
31 if (layout != InferenceEngine::NCHW && layout != InferenceEngine::NHWC) {
32 THROW_IE_EXCEPTION << "Expecting input layout NCHW or NHWC.";
35 int MB = inputDims[0];
36 int srcSize = inputDims.size() / MB;
38 if (meanBuffer && meanBuffer->size()) {
39 const float * meanBufferValues = meanBuffer->readOnly();
41 InferenceEngine::parallel_for2d(MB, srcSize, [&](int mb, int i) {
42 int buf = input[srcSize * mb + i];
43 buf -= meanBufferValues[i];
44 if (buf < std::numeric_limits<T>::min()) buf = std::numeric_limits<T>::min();
45 if (buf > std::numeric_limits<T>::max()) buf = std::numeric_limits<T>::max();
46 input[srcSize * mb + i] = buf;
48 } else if (!meanValues.empty()) {
50 srcSize /= inputDims[1];
52 if (layout == InferenceEngine::NCHW) {
53 InferenceEngine::parallel_for3d(MB, C, srcSize, [&](int mb, int c, int i) {
54 int buf = input[srcSize * mb * C + c * srcSize + i];
56 if (buf < std::numeric_limits<T>::min()) buf = std::numeric_limits<T>::min();
57 if (buf > std::numeric_limits<T>::max()) buf = std::numeric_limits<T>::max();
58 input[srcSize * mb * C + c * srcSize + i] = buf;
60 } else if (layout == InferenceEngine::NHWC) {
61 InferenceEngine::parallel_for2d(MB, srcSize, [&](int mb, int i) {
62 for (int c = 0; c < C; c++) {
63 int buf = input[mb * srcSize * C + i * C + c];
65 if (buf < std::numeric_limits<T>::min()) buf = std::numeric_limits<T>::min();
66 if (buf > std::numeric_limits<T>::max()) buf = std::numeric_limits<T>::max();
67 input[mb * srcSize * C + i * C + c] = buf;
// Mean values used by the integral Subtract overload when meanBuffer is null or empty;
// that code walks them per channel index c.
// NOTE(review): presumably one value per input channel — confirm the sizing done in Load.
std::vector<float> meanValues;
// Optional mean image. When non-null and non-empty it takes precedence over meanValues in
// the integral Subtract overload, which reads it through readOnly() and indexes it by the
// element offset within a single batch item.
InferenceEngine::TBlob<float>::Ptr meanBuffer;
80 } // namespace MKLDNNPlugin