Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / src / mkldnn_plugin / mean_image.h
index 24dc816..eba0762 100644 (file)
@@ -1,4 +1,4 @@
-// Copyright (C) 2018 Intel Corporation
+// Copyright (C) 2018-2019 Intel Corporation
 // SPDX-License-Identifier: Apache-2.0
 //
 
@@ -18,16 +18,20 @@ public:
 
 public:
     void Load(const MKLDNNDims& inputDims, InferenceEngine::InputInfo::Ptr inputInfo);
-    void Subtract(const MKLDNNDims &inputDims, float *input);
+    void Subtract(const MKLDNNDims &inputDims, float *input, InferenceEngine::Layout layout);
 
     template<typename T, typename std::enable_if<std::is_integral<T>::value>::type* = nullptr>
-    void Subtract(const MKLDNNDims &inputDims, T *input) {
+    void Subtract(const MKLDNNDims &inputDims, T *input, InferenceEngine::Layout layout) {
         IE_ASSERT(input != nullptr);
 
         if (inputDims.ndims() != 4) {
             THROW_IE_EXCEPTION << "Expecting input as 4 dimension blob with format NxCxHxW.";
         }
 
+        if (layout != InferenceEngine::NCHW && layout != InferenceEngine::NHWC) {
+            THROW_IE_EXCEPTION << "Expecting input layout NCHW or NHWC.";
+        }
+
         int MB = inputDims[0];
         int srcSize = inputDims.size() / MB;
 
@@ -45,13 +49,25 @@ public:
             int C = inputDims[1];
             srcSize /= inputDims[1];
 
-            InferenceEngine::parallel_for3d(MB, C, srcSize, [&](int mb, int c, int i) {
-                int buf = input[srcSize * mb * C + c * srcSize + i];
-                buf -= meanValues[c];
-                if (buf < std::numeric_limits<T>::min()) buf = std::numeric_limits<T>::min();
-                if (buf > std::numeric_limits<T>::max()) buf = std::numeric_limits<T>::max();
-                input[srcSize * mb * C + c * srcSize + i] = buf;
-            });
+            if (layout == InferenceEngine::NCHW) {
+                InferenceEngine::parallel_for3d(MB, C, srcSize, [&](int mb, int c, int i) {
+                    int buf = input[srcSize * mb * C + c * srcSize + i];
+                    buf -= meanValues[c];
+                    if (buf < std::numeric_limits<T>::min()) buf = std::numeric_limits<T>::min();
+                    if (buf > std::numeric_limits<T>::max()) buf = std::numeric_limits<T>::max();
+                    input[srcSize * mb * C + c * srcSize + i] = buf;
+                });
+            } else if (layout == InferenceEngine::NHWC) {
+                InferenceEngine::parallel_for2d(MB, srcSize, [&](int mb, int i) {
+                    for (int c = 0; c < C; c++) {
+                        int buf = input[mb * srcSize * C + i * C + c];
+                        buf -= meanValues[c];
+                        if (buf < std::numeric_limits<T>::min()) buf = std::numeric_limits<T>::min();
+                        if (buf > std::numeric_limits<T>::max()) buf = std::numeric_limits<T>::max();
+                        input[mb * srcSize * C + i * C + c] = buf;
+                    }
+                });
+            }
         }
     }