Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / src / mkldnn_plugin / nodes / mkldnn_concat_node.cpp
index fd2893e..ec370ee 100644 (file)
@@ -1,4 +1,4 @@
-// Copyright (C) 2018 Intel Corporation
+// Copyright (C) 2018-2019 Intel Corporation
 // SPDX-License-Identifier: Apache-2.0
 //
 
@@ -16,6 +16,7 @@
 #include "mkldnn_dims.h"
 #include "mkldnn_edge.h"
 #include "mkldnn_memory.h"
+#include "ie_parallel.hpp"
 #include <limits>
 
 using namespace mkldnn;
@@ -509,3 +510,46 @@ void MKLDNNConcatNode::initOptimalPrimitiveDescriptor() {
     }
     initDescriptor(config);
 }
+
+void MKLDNNConcatNode::execute(mkldnn::stream strm) {
+    if (isOptimized()) {
+        return;
+    }
+
+    const MKLDNNMemory& dst_memory = getChildEdgeAt(0)->getMemory();
+    const mkldnn::memory::data_type data_type = dst_memory.GetDataType();
+
+    const bool isInt8 = (data_type == mkldnn_s8 || data_type == mkldnn_u8);
+
+    if (isInt8) {
+        uint8_t* dst_ptr = reinterpret_cast<uint8_t*>(dst_memory.GetData());
+
+        const size_t num_src = getParentEdges().size();
+
+        std::vector<size_t> channels;
+        size_t channels_size = 0;
+        std::vector<const uint8_t*> src_ptrs;
+        std::vector<uint8_t*> dst_ptrs;
+
+        for (size_t i = 0; i < num_src; i++) {
+            const MKLDNNMemory& src_mem = getParentEdgeAt(i)->getMemory();
+            const size_t num_channels = src_mem.GetDims()[1];
+
+            channels.push_back(num_channels);
+            src_ptrs.push_back(reinterpret_cast<const uint8_t*>(src_mem.GetData()));
+            dst_ptrs.push_back(dst_ptr + channels_size);
+            channels_size += num_channels;
+        }
+
+        const size_t iter_count = getParentEdgeAt(0)->getMemory().GetSize() / channels[0];
+
+        parallel_for(iter_count, [&](int i) {
+            const size_t dst_off = i * channels_size;
+            for (int j = 0; j < num_src; j++) {
+                memcpy(dst_ptrs[j] + dst_off, src_ptrs[j] + i * channels[j], channels[j]);
+            }
+        });
+    } else {
+        MKLDNNNode::execute(strm);
+    }
+}