1 // Copyright (C) 2018-2019 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
5 #include <mkldnn_types.h>
6 #include "mkldnn_primitive.h"
7 #include "../../thirdparty/mkl-dnn/src/common/primitive_desc.hpp"
8 #include "../../thirdparty/mkl-dnn/src/common/memory_pd.hpp"
9 #include "../../thirdparty/mkl-dnn/src/cpu/cpu_concat.hpp"
11 using namespace MKLDNNPlugin;
13 MKLDNNPrimitive::MKLDNNPrimitive() {}
15 MKLDNNPrimitive::operator bool() {
16 return prim ? true : false;
// Dereference operator; returns an mkldnn::primitive by value.
// NOTE(review): the body of this function is not visible in this excerpt —
// presumably it dereferences `prim`; confirm, and confirm what happens when
// `prim` is empty.
19 mkldnn::primitive MKLDNNPrimitive::operator*() {
23 void MKLDNNPrimitive::reset(mkldnn::primitive* prim) {
24 this->prim.reset(prim);
// Assignment from a shared pointer to an mkldnn::primitive; returns a
// reference to an MKLDNNPrimitive.
// NOTE(review): the body of this function is not visible in this excerpt —
// presumably it stores `prim` in the member of the same name and returns
// *this; confirm.
27 MKLDNNPrimitive &MKLDNNPrimitive::operator=(const std::shared_ptr<mkldnn::primitive>& prim) {
// Overwrites dims[0] (treated here as the batch size) and the matching
// blocking padding_dims[0] of the primitive's input and output memory
// descriptors with `batch`, editing the descriptors in place.
//
// Parameters:
//   batch     - value written into dims[0] / padding_dims[0].
//   inputNum  - upper bound on how many input descriptors are visited.
//   outputNum - upper bound on how many output descriptors are visited.
//
// The first time an index is visited, its current dims[0] is recorded in
// originInputBatches / originOutputBatches; later calls compare `batch`
// against that recorded value.
//
// The tail of the function writes the recorded values back and then raises
// THROW_IE_EXCEPTION with "Dynamic batch cannot be changed!".
//
// NOTE(review): this excerpt is missing several lines (the declaration of
// `success`, the statements guarded by some `if`s, the closing braces and
// whatever separates the success path from the rollback path). Comments below
// describe only what is visible; confirm the rest against the full file.
32 void MKLDNNPrimitive::setBatchLimit(int batch, size_t inputNum, size_t outputNum) {
// Descriptor of the wrapped primitive; every memory descriptor touched below
// is reached through it.
34 auto * primDesc = prim->get_primitive_desc();
// Non-null only when the descriptor is a cpu_concat_pd_t; in that case the
// per-source "image" descriptors are patched as well (see below).
35 auto * concatPrimDesc = dynamic_cast<const mkldnn::impl::cpu::cpu_concat_pd_t *>(primDesc);
// Input pass. The loop stops as soon as `success` becomes false.
// NOTE(review): `i` is int while inputNum is size_t, so the comparison mixes
// signed and unsigned operands.
36 for (int i = 0; success && i < primDesc->n_inputs() && i < inputNum; i++) {
37 // Depthwise layers contain weights as an input
// Compares the rank of input i with that of input_pd() called without an
// index (presumably the first input).
// NOTE(review): the statement guarded by this condition is not visible here.
38 if (primDesc->input_pd()->desc()->ndims != primDesc->input_pd(i)->desc()->ndims)
// desc() yields a pointer-to-const; constness is cast away so the descriptor
// can be edited in place.
40 auto * memDesc = const_cast<mkldnn_memory_desc_t *>(primDesc->input_pd(i)->desc());
// Record the original dims[0] the first time this index is seen.
41 if (originInputBatches.size() <= i)
42 originInputBatches.push_back(memDesc->dims[0]);
// NOTE(review): the statement guarded by this condition is not visible here;
// presumably a batch larger than the recorded original is treated as failure.
44 if (batch > originInputBatches[i])
46 memDesc->dims[0] = batch;
47 memDesc->layout_desc.blocking.padding_dims[0] = batch;
// Concat descriptors expose a second descriptor per source (src_image_pd);
// it receives the same batch value.
48 if (concatPrimDesc != nullptr) {
49 memDesc = const_cast<mkldnn_memory_desc_t *>(concatPrimDesc->src_image_pd(i)->desc());
50 memDesc->dims[0] = batch;
51 memDesc->layout_desc.blocking.padding_dims[0] = batch;
// Output pass; mirrors the input pass.
54 for (int i = 0; success && i < primDesc->n_outputs() && i < outputNum; i++) {
// NOTE(review): the statement guarded by this condition is not visible here.
55 if (primDesc->output_pd()->desc()->ndims != primDesc->output_pd(i)->desc()->ndims)
57 auto * memDesc = const_cast<mkldnn_memory_desc_t *>(primDesc->output_pd(i)->desc());
// True when output i uses the very same descriptor object as input i.
// NOTE(review): the guarded statement is not visible here; presumably such an
// output is skipped because the input pass already patched it. If an index is
// skipped without recording anything, the push_back below would store a later
// value at a smaller position than `i`, and originOutputBatches[i] could then
// read out of range — confirm.
58 if (i < inputNum && memDesc == primDesc->input_pd(i)->desc())
// Record the original dims[0] the first time this index is seen.
60 if (originOutputBatches.size() <= i)
61 originOutputBatches.push_back(memDesc->dims[0]);
// NOTE(review): the statement guarded by this condition is not visible here.
63 if (batch > originOutputBatches[i])
65 memDesc->dims[0] = batch;
66 memDesc->layout_desc.blocking.padding_dims[0] = batch;
// Rollback: write the recorded original batch back into every input
// descriptor that has a recorded value.
// NOTE(review): presumably reached only when `success` is false — the guard is
// not visible here. The concat src_image_pd descriptors patched above are not
// restored by any visible line — confirm whether that is intended.
72 for (int i = 0; i < primDesc->n_inputs() && i < originInputBatches.size(); i++) {
73 auto * memDesc = const_cast<mkldnn_memory_desc_t *>(primDesc->input_pd(i)->desc());
74 memDesc->dims[0] = originInputBatches[i];
75 memDesc->layout_desc.blocking.padding_dims[0] = originInputBatches[i];
// Same rollback for the output descriptors.
77 for (int i = 0; i < primDesc->n_outputs() && i < originOutputBatches.size(); i++) {
78 auto * memDesc = const_cast<mkldnn_memory_desc_t *>(primDesc->output_pd(i)->desc());
79 memDesc->dims[0] = originOutputBatches[i];
80 memDesc->layout_desc.blocking.padding_dims[0] = originOutputBatches[i];
// Report the failure to the caller after the descriptors have been restored.
83 THROW_IE_EXCEPTION << "Dynamic batch cannot be changed!";