Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / src / mkldnn_plugin / mkldnn_primitive.cpp
1 // Copyright (C) 2018-2019 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
3 //
4
5 #include <mkldnn_types.h>
6 #include "mkldnn_primitive.h"
7 #include "../../thirdparty/mkl-dnn/src/common/primitive_desc.hpp"
8 #include "../../thirdparty/mkl-dnn/src/common/memory_pd.hpp"
9 #include "../../thirdparty/mkl-dnn/src/cpu/cpu_concat.hpp"
10
11 using namespace MKLDNNPlugin;
12
13 MKLDNNPrimitive::MKLDNNPrimitive() {}
14
15 MKLDNNPrimitive::operator bool() {
16     return prim ? true : false;
17 }
18
19 mkldnn::primitive MKLDNNPrimitive::operator*() {
20     return *prim;
21 }
22
23 void MKLDNNPrimitive::reset(mkldnn::primitive* prim) {
24     this->prim.reset(prim);
25 }
26
27 MKLDNNPrimitive &MKLDNNPrimitive::operator=(const std::shared_ptr<mkldnn::primitive>& prim) {
28     this->prim = prim;
29     return *this;
30 }
31
// Dynamic-batch support: rewrites the batch (outermost, dims[0]) dimension of the
// primitive's input/output memory descriptors in place so the already-created
// primitive processes only `batch` items. The original batch of each descriptor
// is cached in originInputBatches/originOutputBatches on first use so later calls
// can move the batch up or down within the original limit.
//
// NOTE(review): this deliberately reaches into mkl-dnn internals via const_cast
// (hence the ../thirdparty includes) — it depends on mkl-dnn's descriptor layout.
//
// Throws if `batch` exceeds any cached original batch; in that case every
// descriptor already recorded is rolled back to its original batch first.
void MKLDNNPrimitive::setBatchLimit(int batch, size_t inputNum, size_t outputNum) {
    bool success = true;
    auto * primDesc = prim->get_primitive_desc();
    // Concat keeps additional per-source "image" descriptors that must be
    // patched alongside the regular input descriptors.
    auto * concatPrimDesc = dynamic_cast<const mkldnn::impl::cpu::cpu_concat_pd_t *>(primDesc);
    for (int i = 0; success && i < primDesc->n_inputs() && i < inputNum; i++) {
        // Depthwise layers contains weights as input; stop at the first input
        // whose rank differs from input 0 — those are parameters, not data.
        if (primDesc->input_pd()->desc()->ndims != primDesc->input_pd(i)->desc()->ndims)
            break;
        auto * memDesc = const_cast<mkldnn_memory_desc_t *>(primDesc->input_pd(i)->desc());
        // First call for this input: remember the batch it was created with.
        if (originInputBatches.size() <= i)
            originInputBatches.push_back(memDesc->dims[0]);

        // Requested batch may not exceed the batch the primitive was built for.
        if (batch > originInputBatches[i])
            success = false;
        memDesc->dims[0] = batch;
        memDesc->layout_desc.blocking.padding_dims[0] = batch;
        if (concatPrimDesc != nullptr) {
            memDesc = const_cast<mkldnn_memory_desc_t *>(concatPrimDesc->src_image_pd(i)->desc());
            memDesc->dims[0] = batch;
            memDesc->layout_desc.blocking.padding_dims[0] = batch;
        }
    }
    for (int i = 0; success && i < primDesc->n_outputs() && i < outputNum; i++) {
        if (primDesc->output_pd()->desc()->ndims != primDesc->output_pd(i)->desc()->ndims)
            break;
        auto * memDesc = const_cast<mkldnn_memory_desc_t *>(primDesc->output_pd(i)->desc());
        // In-place execution: the output shares the input's descriptor, which
        // the loop above already patched to `batch` — skip it so we don't cache
        // the modified value as the "original" batch.
        if (i < inputNum && memDesc == primDesc->input_pd(i)->desc())
            continue;
        // NOTE(review): when the `continue` above fires for some index i, no
        // entry is appended at that position, so originOutputBatches indices can
        // shift relative to output indices and originOutputBatches[i] below may
        // even be past the end. Looks safe only if in-place outputs come after
        // all non-in-place ones — TODO confirm against callers.
        if (originOutputBatches.size() <= i)
            originOutputBatches.push_back(memDesc->dims[0]);

        if (batch > originOutputBatches[i])
            success = false;
        memDesc->dims[0] = batch;
        memDesc->layout_desc.blocking.padding_dims[0] = batch;
    }

    if (success)
        return;

    // Requested batch exceeded an original one: restore every recorded
    // descriptor to its cached original batch, then report the failure.
    // NOTE(review): concat src_image descriptors patched above are not rolled
    // back here — presumably unreachable for concat, but verify.
    for (int i = 0; i < primDesc->n_inputs() && i < originInputBatches.size(); i++) {
        auto * memDesc = const_cast<mkldnn_memory_desc_t *>(primDesc->input_pd(i)->desc());
        memDesc->dims[0] = originInputBatches[i];
        memDesc->layout_desc.blocking.padding_dims[0] = originInputBatches[i];
    }
    for (int i = 0; i < primDesc->n_outputs() && i < originOutputBatches.size(); i++) {
        auto * memDesc = const_cast<mkldnn_memory_desc_t *>(primDesc->output_pd(i)->desc());
        memDesc->dims[0] = originOutputBatches[i];
        memDesc->layout_desc.blocking.padding_dims[0] = originOutputBatches[i];
    }

    THROW_IE_EXCEPTION << "Dynamic batch cannot be changed!";
}