-// Copyright (C) 2018 Intel Corporation
+// Copyright (C) 2018-2019 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "mkldnn_dims.h"
#include "mkldnn_edge.h"
#include "mkldnn_memory.h"
+#include "ie_parallel.hpp"
#include <limits>
using namespace mkldnn;
}
initDescriptor(config);
}
+
+// Executes the concat node.
+//
+// Returns immediately when isOptimized() is true. When the destination data type is
+// s8 or u8 the inputs are interleaved by hand: for every iteration index, each parent's
+// block of channels[j] bytes is copied right after the previous parent's block in the
+// destination. Every other data type is delegated to MKLDNNNode::execute().
+//
+// NOTE(review): the manual path uses dims[1] as a byte count and addresses memory as
+// densely packed, channel-innermost data (presumably NHWC with 1-byte elements) --
+// confirm against the memory descriptors selected for this node.
+//
+// @param strm stream forwarded to MKLDNNNode::execute() on the non-int8 path.
+void MKLDNNConcatNode::execute(mkldnn::stream strm) {
+    if (isOptimized()) {
+        return;
+    }
+
+    const MKLDNNMemory& dst_memory = getChildEdgeAt(0)->getMemory();
+    const mkldnn::memory::data_type data_type = dst_memory.GetDataType();
+    const bool isInt8 = (data_type == mkldnn_s8 || data_type == mkldnn_u8);
+
+    if (!isInt8) {
+        MKLDNNNode::execute(strm);
+        return;
+    }
+
+    const size_t num_src = getParentEdges().size();
+    if (num_src == 0) {
+        // No inputs: nothing to copy, and channels[0] below would be out of bounds.
+        return;
+    }
+
+    uint8_t* dst_ptr = reinterpret_cast<uint8_t*>(dst_memory.GetData());
+
+    // Per-parent channel count, source base pointer and destination base pointer.
+    std::vector<size_t> channels;
+    std::vector<const uint8_t*> src_ptrs;
+    std::vector<uint8_t*> dst_ptrs;
+    channels.reserve(num_src);
+    src_ptrs.reserve(num_src);
+    dst_ptrs.reserve(num_src);
+
+    // Running sum of channels; after the loop it is the total channel count of the output.
+    size_t channels_size = 0;
+    for (size_t i = 0; i < num_src; i++) {
+        const MKLDNNMemory& src_mem = getParentEdgeAt(i)->getMemory();
+        const size_t num_channels = src_mem.GetDims()[1];
+
+        channels.push_back(num_channels);
+        src_ptrs.push_back(reinterpret_cast<const uint8_t*>(src_mem.GetData()));
+        // Parent i's block starts after the channels of all preceding parents.
+        dst_ptrs.push_back(dst_ptr + channels_size);
+        channels_size += num_channels;
+    }
+
+    if (channels[0] == 0) {
+        // The iteration count is derived by dividing by channels[0]; avoid dividing by zero.
+        return;
+    }
+
+    // Number of channel blocks to copy per parent, derived from the first parent's size.
+    const size_t iter_count = getParentEdgeAt(0)->getMemory().GetSize() / channels[0];
+
+    // The index is size_t so that large ranges are not truncated to int before the
+    // offset multiplications below.
+    parallel_for(iter_count, [&](size_t i) {
+        const size_t dst_off = i * channels_size;
+        for (size_t j = 0; j < num_src; j++) {
+            memcpy(dst_ptrs[j] + dst_off, src_ptrs[j] + i * channels[j], channels[j]);
+        }
+    });
+}