1 // Copyright (C) 2018 Intel Corporation
3 // SPDX-License-Identifier: Apache-2.0
6 #include <mkldnn_types.h>
7 #include "mkldnn_primitive.h"
8 #include "../../thirdparty/mkl-dnn/src/common/primitive_desc.hpp"
9 #include "../../thirdparty/mkl-dnn/src/common/memory_pd.hpp"
10 #include "../../thirdparty/mkl-dnn/src/cpu/cpu_concat.hpp"
12 using namespace MKLDNNPlugin;
14 MKLDNNPrimitive::MKLDNNPrimitive() {}
15 MKLDNNPrimitive::MKLDNNPrimitive(const std::shared_ptr<mkldnn::primitive>& prim): prim(prim) {}
// Implicit conversion to the wrapped std::shared_ptr<mkldnn::primitive>.
// NOTE(review): the body is not present in this listing; presumably it
// returns the `prim` member (a copy, which adds a shared owner). Confirm.
17 MKLDNNPrimitive::operator std::shared_ptr<mkldnn::primitive>() {
// Truthiness test: true when a primitive is currently held, i.e. the
// `prim` shared_ptr is non-null; false for an empty wrapper.
21 MKLDNNPrimitive::operator bool() {
22 return prim ? true : false;
// Returns the wrapped primitive by value (per the signature, a copy of the
// mkldnn::primitive handle object, not a reference).
// NOTE(review): the body is not present in this listing; presumably it
// dereferences `prim`, which would be undefined for an empty wrapper.
// Confirm that callers check operator bool first.
25 mkldnn::primitive MKLDNNPrimitive::operator*() {
// Replaces the held primitive with `prim`. The shared_ptr takes ownership of
// the raw pointer (it will delete it once the last owner goes away) and drops
// this wrapper's reference to any previously held primitive.
// The parameter shadows the member of the same name, hence `this->prim`.
29 void MKLDNNPrimitive::reset(mkldnn::primitive* prim) {
30 this->prim.reset(prim);
// Assigns a new primitive from a shared_ptr.
// NOTE(review): the body is not present in this listing; presumably it
// stores `prim` into the member (sharing ownership) and returns *this, as
// the reference return type suggests. Confirm.
33 MKLDNNPrimitive &MKLDNNPrimitive::operator=(const std::shared_ptr<mkldnn::primitive>& prim) {
// Re-targets this primitive's memory descriptors to a new batch size.
// For each visited input/output descriptor it writes `batch` into dims[0] and
// layout_desc.blocking.padding_dims[0]. The writes go directly into the
// descriptors owned by the primitive_desc, which are const_cast for this.
//
// batch     - new value for the leading (batch) dimension.
// inputNum  - at most this many inputs are visited (loop bound).
// outputNum - at most this many outputs are visited (loop bound).
//
// The batch each descriptor had the first time its index was visited is
// cached in originInputBatches / originOutputBatches. The trailing loops use
// those cached values to restore the descriptors before raising
// THROW_IE_EXCEPTION.
// NOTE(review): this listing does not contain the declaration of `success` or
// the early-exit guard in front of the restore/throw section. Presumably that
// path runs only when a requested batch exceeds a cached original; confirm
// against the full source.
// NOTE(review): loop index `i` is `int` but is compared with size_t values
// (inputNum, outputNum, vector sizes), which is a signed/unsigned comparison.
38 void MKLDNNPrimitive::setBatchLimit(int batch, size_t inputNum, size_t outputNum) {
40 auto * primDesc = prim->get_primitive_desc();
// Non-null only when the primitive descriptor is a CPU concat descriptor;
// concat additionally keeps per-input src_image_pd(i) descriptors.
41 auto * concatPrimDesc = dynamic_cast<const mkldnn::impl::cpu::cpu_concat_pd_t *>(primDesc);
// Pass 1: input descriptors.
42 for (int i = 0; success && i < primDesc->n_inputs() && i < inputNum; i++) {
43 // Depthwise layers contain weights as an input
// Compares input i's rank with input_pd() called without an index
// (presumably input 0; TODO confirm the default argument). The statement
// handling a mismatch is not present in this listing.
44 if (primDesc->input_pd()->desc()->ndims != primDesc->input_pd(i)->desc()->ndims)
46 auto * memDesc = const_cast<mkldnn_memory_desc_t *>(primDesc->input_pd(i)->desc());
// Remember the original batch only the first time index i is seen.
47 if (originInputBatches.size() <= i)
48 originInputBatches.push_back(memDesc->dims[0]);
// NOTE(review): the statement guarded by this check is not present in this
// listing; presumably it marks the call as unsuccessful. Confirm.
50 if (batch > originInputBatches[i])
52 memDesc->dims[0] = batch;
53 memDesc->layout_desc.blocking.padding_dims[0] = batch;
// Concat: patch the matching per-input image descriptor the same way.
54 if (concatPrimDesc != nullptr) {
55 memDesc = const_cast<mkldnn_memory_desc_t *>(concatPrimDesc->src_image_pd(i)->desc());
56 memDesc->dims[0] = batch;
57 memDesc->layout_desc.blocking.padding_dims[0] = batch;
// Pass 2: output descriptors.
60 for (int i = 0; success && i < primDesc->n_outputs() && i < outputNum; i++) {
// Same rank check as for inputs; the handling statement is not present here.
61 if (primDesc->output_pd()->desc()->ndims != primDesc->output_pd(i)->desc()->ndims)
63 auto * memDesc = const_cast<mkldnn_memory_desc_t *>(primDesc->output_pd(i)->desc());
// Pointer equality: output i uses the very same descriptor object as input i,
// which pass 1 has already patched. The statement handling this case is not
// present in this listing (presumably it skips the output).
64 if (i < inputNum && memDesc == primDesc->input_pd(i)->desc())
66 if (originOutputBatches.size() <= i)
67 originOutputBatches.push_back(memDesc->dims[0]);
// NOTE(review): guarded statement not present in this listing; see the
// matching input check above.
69 if (batch > originOutputBatches[i])
71 memDesc->dims[0] = batch;
72 memDesc->layout_desc.blocking.padding_dims[0] = batch;
// Roll input descriptors back to their cached original batch sizes.
// NOTE(review): unlike pass 1, this does not restore
// concatPrimDesc->src_image_pd(i), so concat image descriptors keep `batch`
// after the exception below.
78 for (int i = 0; i < primDesc->n_inputs() && i < originInputBatches.size(); i++) {
79 auto * memDesc = const_cast<mkldnn_memory_desc_t *>(primDesc->input_pd(i)->desc());
80 memDesc->dims[0] = originInputBatches[i];
81 memDesc->layout_desc.blocking.padding_dims[0] = originInputBatches[i];
// Roll output descriptors back to their cached original batch sizes.
83 for (int i = 0; i < primDesc->n_outputs() && i < originOutputBatches.size(); i++) {
84 auto * memDesc = const_cast<mkldnn_memory_desc_t *>(primDesc->output_pd(i)->desc());
85 memDesc->dims[0] = originOutputBatches[i];
86 memDesc->layout_desc.blocking.padding_dims[0] = originOutputBatches[i];
89 THROW_IE_EXCEPTION << "Dynamic batch cannot be changed!";