Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / src / mkldnn_plugin / mean_image.cpp
index f1ac17e..dcf11ef 100644 (file)
@@ -1,4 +1,4 @@
-// Copyright (C) 2018 Intel Corporation
+// Copyright (C) 2018-2019 Intel Corporation
 // SPDX-License-Identifier: Apache-2.0
 //
 
@@ -72,13 +72,17 @@ void MeanImage::Load(const MKLDNNDims& inputDims, InputInfo::Ptr inputInfo) {
     }
 }
 
-void MeanImage::Subtract(const MKLDNNDims &inputDims, float *input) {
+void MeanImage::Subtract(const MKLDNNDims &inputDims, float *input, InferenceEngine::Layout layout) {
     IE_ASSERT(input != nullptr);
 
     if (inputDims.ndims() != 4) {
         THROW_IE_EXCEPTION << "Expecting input as 4 dimension blob with format NxCxHxW.";
     }
 
+    if (layout != NCHW && layout != NHWC) {
+        THROW_IE_EXCEPTION << "Expecting input layout NCHW or NHWC.";
+    }
+
     int MB = inputDims[0];
     int srcSize = inputDims.size() / MB;
 
@@ -92,8 +96,15 @@ void MeanImage::Subtract(const MKLDNNDims &inputDims, float *input) {
         int C = inputDims[1];
         srcSize /= inputDims[1];
 
-        parallel_for3d(MB, C, srcSize, [&](int mb, int c, int i) {
-            input[srcSize * mb * C + c * srcSize + i] -= meanValues[c];
-        });
+        if (layout == NCHW) {
+            parallel_for3d(MB, C, srcSize, [&](int mb, int c, int i) {
+                input[mb * C * srcSize + c * srcSize + i] -= meanValues[c];
+            });
+        } else if (layout == NHWC) {
+            parallel_for2d(MB, srcSize, [&](int mb, int i) {
+                for (int c = 0; c < C; c++)
+                    input[mb * srcSize * C + i * C + c] -= meanValues[c];
+            });
+        }
     }
 }