inference-engine/src/mkldnn_plugin/mkldnn_primitive.cpp
// Copyright (C) 2018 Intel Corporation
//
// SPDX-License-Identifier: Apache-2.0
//

#include <mkldnn_types.h>
#include "mkldnn_primitive.h"
#include "../../thirdparty/mkl-dnn/src/common/primitive_desc.hpp"
#include "../../thirdparty/mkl-dnn/src/common/memory_pd.hpp"
#include "../../thirdparty/mkl-dnn/src/cpu/cpu_concat.hpp"

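// Note: the last three includes pull in internal mkl-dnn headers so that
// setBatchLimit() below can inspect and modify primitive/memory descriptors
// directly rather than through the public mkldnn C++ API.
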
using namespace MKLDNNPlugin;

MKLDNNPrimitive::MKLDNNPrimitive() {}
MKLDNNPrimitive::MKLDNNPrimitive(const std::shared_ptr<mkldnn::primitive>& prim): prim(prim) {}

MKLDNNPrimitive::operator std::shared_ptr<mkldnn::primitive>() {
    return prim;
}

MKLDNNPrimitive::operator bool() {
    return prim ? true : false;
}

mkldnn::primitive MKLDNNPrimitive::operator*() {
    return *prim;
}

void MKLDNNPrimitive::reset(mkldnn::primitive* prim) {
    this->prim.reset(prim);
}

MKLDNNPrimitive &MKLDNNPrimitive::operator=(const std::shared_ptr<mkldnn::primitive>& prim) {
    this->prim = prim;
    return *this;
}

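// Rewrites the batch dimension (dims[0] and the corresponding padded dim) of the
// primitive's input/output memory descriptors so the primitive runs with a batch
// no larger than the original one. Original batch sizes are cached on first use;
// if the requested batch exceeds any original batch, all descriptors are rolled
// back and an exception is thrown.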
void MKLDNNPrimitive::setBatchLimit(int batch, size_t inputNum, size_t outputNum) {
    bool success = true;
    auto * primDesc = prim->get_primitive_desc();
    auto * concatPrimDesc = dynamic_cast<const mkldnn::impl::cpu::cpu_concat_pd_t *>(primDesc);
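    // Patch the batch dimension of every input memory descriptor (and, for concat
    // primitives, of the corresponding src_image descriptor), remembering the
    // original batch the first time each input is seen.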
    for (int i = 0; success && i < primDesc->n_inputs() && i < inputNum; i++) {
        // Depthwise layers contain weights as an extra input; stop at the first
        // input whose rank differs from the data input
        if (primDesc->input_pd()->desc()->ndims != primDesc->input_pd(i)->desc()->ndims)
            break;
        auto * memDesc = const_cast<mkldnn_memory_desc_t *>(primDesc->input_pd(i)->desc());
        if (originInputBatches.size() <= i)
            originInputBatches.push_back(memDesc->dims[0]);

        if (batch > originInputBatches[i])
            success = false;
        memDesc->dims[0] = batch;
        memDesc->layout_desc.blocking.padding_dims[0] = batch;
        if (concatPrimDesc != nullptr) {
            memDesc = const_cast<mkldnn_memory_desc_t *>(concatPrimDesc->src_image_pd(i)->desc());
            memDesc->dims[0] = batch;
            memDesc->layout_desc.blocking.padding_dims[0] = batch;
        }
    }
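    // Do the same for the output memory descriptors, skipping in-place outputs
    // that share a descriptor with the matching input (already patched above).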
    for (int i = 0; success && i < primDesc->n_outputs() && i < outputNum; i++) {
        if (primDesc->output_pd()->desc()->ndims != primDesc->output_pd(i)->desc()->ndims)
            break;
        auto * memDesc = const_cast<mkldnn_memory_desc_t *>(primDesc->output_pd(i)->desc());
        if (i < inputNum && memDesc == primDesc->input_pd(i)->desc())
            continue;
        if (originOutputBatches.size() <= i)
            originOutputBatches.push_back(memDesc->dims[0]);

        if (batch > originOutputBatches[i])
            success = false;
        memDesc->dims[0] = batch;
        memDesc->layout_desc.blocking.padding_dims[0] = batch;
    }

    if (success)
        return;

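    // The requested batch exceeded one of the original batches: restore every
    // descriptor to its original batch size and report the failure.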
    for (int i = 0; i < primDesc->n_inputs() && i < originInputBatches.size(); i++) {
        auto * memDesc = const_cast<mkldnn_memory_desc_t *>(primDesc->input_pd(i)->desc());
        memDesc->dims[0] = originInputBatches[i];
        memDesc->layout_desc.blocking.padding_dims[0] = originInputBatches[i];
    }
    for (int i = 0; i < primDesc->n_outputs() && i < originOutputBatches.size(); i++) {
        auto * memDesc = const_cast<mkldnn_memory_desc_t *>(primDesc->output_pd(i)->desc());
        memDesc->dims[0] = originOutputBatches[i];
        memDesc->layout_desc.blocking.padding_dims[0] = originOutputBatches[i];
    }

    THROW_IE_EXCEPTION << "Dynamic batch cannot be changed!";
}