Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / src / mkldnn_plugin / mkldnn_descriptor.cpp
1 // Copyright (C) 2018-2019 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
3 //
4
5 #include <details/ie_exception.hpp>
6 #include "mkldnn_descriptor.h"
7
8 mkldnn::primitive_desc_iterator MKLDNNDescriptor::createPrimitiveDescriptorIterator(const mkldnn::engine &engine,
9                                                                                     const mkldnn::primitive_attr &attr) const {
10     return desc->createPrimitiveDescriptorIterator(attr, engine);
11 }
12
13 MKLDNNDescriptor::operator bool() {
14     return desc.get() != nullptr;
15 }
16
17 size_t MKLDNNDescriptor::inputNumbers() const {
18     DescFwdImpl<mkldnn::roi_pooling_forward::desc> *roiPooling =
19             dynamic_cast<DescFwdImpl<mkldnn::roi_pooling_forward::desc> *>(desc.get());
20     if (roiPooling != nullptr) {
21         return roiPooling->getPtr()->c_api_inputs.size();
22     }
23     return 1;
24 }
25
26 size_t MKLDNNDescriptor::outputNumbers() const {
27     return 1;
28 }
29
30 MKLDNNDescriptor::MKLDNNDescriptor(std::shared_ptr<mkldnn::batch_normalization_forward::desc> desc) {
31     this->desc.reset(new DescFwdImpl<mkldnn::batch_normalization_forward::desc>(desc));
32 }
33
34 MKLDNNDescriptor::operator std::shared_ptr<mkldnn::batch_normalization_forward::desc>() {
35     DescFwdImpl<mkldnn::batch_normalization_forward::desc> *typeDesc =
36             dynamic_cast<DescFwdImpl<mkldnn::batch_normalization_forward::desc> *>(desc.get());
37     if (typeDesc == nullptr) {
38         THROW_IE_EXCEPTION << "Cannot cast descriptor!";
39     }
40     return typeDesc->getPtr();
41 }
42
43 MKLDNNDescriptor::MKLDNNDescriptor(std::shared_ptr<mkldnn::convolution_forward::desc> desc) {
44     this->desc.reset(new DescFwdImpl<mkldnn::convolution_forward::desc>(desc));
45 }
46
47 MKLDNNDescriptor::operator std::shared_ptr<mkldnn::convolution_forward::desc>() {
48     DescFwdImpl<mkldnn::convolution_forward::desc> *typeDesc =
49             dynamic_cast<DescFwdImpl<mkldnn::convolution_forward::desc> *>(desc.get());
50     if (typeDesc == nullptr) {
51         THROW_IE_EXCEPTION << "Cannot cast descriptor!";
52     }
53     return typeDesc->getPtr();
54 }
55
56 MKLDNNDescriptor::MKLDNNDescriptor(std::shared_ptr<mkldnn::convolution_backward_data::desc> desc,
57                                    std::shared_ptr<mkldnn::convolution_forward::primitive_desc> prim) {
58     this->desc.reset(
59             new DescBwdImpl<mkldnn::convolution_backward_data::desc,
60                     mkldnn::convolution_forward::primitive_desc>(desc, prim));
61 }
62
63 MKLDNNDescriptor::operator std::shared_ptr<mkldnn::convolution_backward_data::desc>() {
64     DescBwdImpl<mkldnn::convolution_backward_data::desc, mkldnn::convolution_forward::primitive_desc> *typeDesc =
65             dynamic_cast<DescBwdImpl<mkldnn::convolution_backward_data::desc,
66                     mkldnn::convolution_forward::primitive_desc> *>(desc.get());
67     if (typeDesc == nullptr) {
68         THROW_IE_EXCEPTION << "Cannot cast descriptor!";
69     }
70     return typeDesc->getPtr();
71 }
72
73 MKLDNNDescriptor::operator std::shared_ptr<mkldnn::convolution_forward::primitive_desc>() {
74     DescBwdImpl<mkldnn::convolution_backward_data::desc, mkldnn::convolution_forward::primitive_desc> *typeDesc =
75             dynamic_cast<DescBwdImpl<mkldnn::convolution_backward_data::desc,
76                     mkldnn::convolution_forward::primitive_desc> *>(desc.get());
77     if (typeDesc == nullptr) {
78         THROW_IE_EXCEPTION << "Cannot cast descriptor!";
79     }
80     return typeDesc->getPrimPtr();
81 }
82
83 MKLDNNDescriptor::MKLDNNDescriptor(std::shared_ptr<mkldnn::inner_product_forward::desc> desc) {
84     this->desc.reset(new DescFwdImpl<mkldnn::inner_product_forward::desc>(desc));
85 }
86
87 MKLDNNDescriptor::operator std::shared_ptr<mkldnn::inner_product_forward::desc>() {
88     DescFwdImpl<mkldnn::inner_product_forward::desc> *typeDesc =
89             dynamic_cast<DescFwdImpl<mkldnn::inner_product_forward::desc> *>(desc.get());
90     if (typeDesc == nullptr) {
91         THROW_IE_EXCEPTION << "Cannot cast descriptor!";
92     }
93     return typeDesc->getPtr();
94 }
95
96 MKLDNNDescriptor::MKLDNNDescriptor(std::shared_ptr<mkldnn::lrn_forward::desc> desc) {
97     this->desc.reset(new DescFwdImpl<mkldnn::lrn_forward::desc>(desc));
98 }
99
100 MKLDNNDescriptor::operator std::shared_ptr<mkldnn::lrn_forward::desc>() {
101     DescFwdImpl<mkldnn::lrn_forward::desc> *typeDesc =
102             dynamic_cast<DescFwdImpl<mkldnn::lrn_forward::desc> *>(desc.get());
103     if (typeDesc == nullptr) {
104         THROW_IE_EXCEPTION << "Cannot cast descriptor!";
105     }
106     return typeDesc->getPtr();
107 }
108
109 MKLDNNDescriptor::MKLDNNDescriptor(std::shared_ptr<mkldnn::pooling_forward::desc> desc) {
110     this->desc.reset(new DescFwdImpl<mkldnn::pooling_forward::desc>(desc));
111 }
112
113 MKLDNNDescriptor::operator std::shared_ptr<mkldnn::pooling_forward::desc>() {
114     DescFwdImpl<mkldnn::pooling_forward::desc> *typeDesc =
115             dynamic_cast<DescFwdImpl<mkldnn::pooling_forward::desc> *>(desc.get());
116     if (typeDesc == nullptr) {
117         THROW_IE_EXCEPTION << "Cannot cast descriptor!";
118     }
119     return typeDesc->getPtr();
120 }
121
122 MKLDNNDescriptor::MKLDNNDescriptor(std::shared_ptr<mkldnn::roi_pooling_forward::desc> desc) {
123     this->desc.reset(new DescFwdImpl<mkldnn::roi_pooling_forward::desc>(desc));
124 }
125
126 MKLDNNDescriptor::operator std::shared_ptr<mkldnn::roi_pooling_forward::desc>() {
127     DescFwdImpl<mkldnn::roi_pooling_forward::desc> *typeDesc =
128             dynamic_cast<DescFwdImpl<mkldnn::roi_pooling_forward::desc> *>(desc.get());
129     if (typeDesc == nullptr) {
130         THROW_IE_EXCEPTION << "Cannot cast descriptor!";
131     }
132     return typeDesc->getPtr();
133 }
134
135 MKLDNNDescriptor::MKLDNNDescriptor(std::shared_ptr<mkldnn::softmax_forward::desc> desc) {
136     this->desc.reset(new DescFwdImpl<mkldnn::softmax_forward::desc>(desc));
137 }
138
139 MKLDNNDescriptor::operator std::shared_ptr<mkldnn::softmax_forward::desc>() {
140     DescFwdImpl<mkldnn::softmax_forward::desc> *typeDesc =
141             dynamic_cast<DescFwdImpl<mkldnn::softmax_forward::desc> *>(desc.get());
142     if (typeDesc == nullptr) {
143         THROW_IE_EXCEPTION << "Cannot cast descriptor!";
144     }
145     return typeDesc->getPtr();
146 }
147
148 MKLDNNDescriptor::MKLDNNDescriptor(std::shared_ptr<mkldnn::depthwise_forward::desc> desc) {
149     this->desc.reset(new DescFwdImpl<mkldnn::depthwise_forward::desc>(desc));
150 }
151
152 MKLDNNDescriptor::operator std::shared_ptr<mkldnn::depthwise_forward::desc>() {
153     DescFwdImpl<mkldnn::depthwise_forward::desc> *typeDesc =
154             dynamic_cast<DescFwdImpl<mkldnn::depthwise_forward::desc> *>(desc.get());
155     if (typeDesc == nullptr) {
156         THROW_IE_EXCEPTION << "Cannot cast descriptor!";
157     }
158     return typeDesc->getPtr();
159 }
160
161 MKLDNNDescriptor::MKLDNNDescriptor(std::shared_ptr<mkldnn::rnn_forward::desc> desc) {
162     this->desc.reset(new DescFwdImpl<mkldnn::rnn_forward::desc>(desc));
163 }
164
165 MKLDNNDescriptor::operator std::shared_ptr<mkldnn::rnn_forward::desc>() {
166     DescFwdImpl<mkldnn::rnn_forward::desc> *typeDesc =
167             dynamic_cast<DescFwdImpl<mkldnn::rnn_forward::desc> *>(desc.get());
168     if (typeDesc == nullptr) {
169         THROW_IE_EXCEPTION << "Cannot cast descriptor!";
170     }
171     return typeDesc->getPtr();
172 }
173
174 MKLDNNDescriptor::MKLDNNDescriptor(std::shared_ptr<mkldnn::eltwise_forward::desc> desc) {
175     this->desc.reset(new DescFwdImpl<mkldnn::eltwise_forward::desc>(desc));
176 }
177
178 MKLDNNDescriptor::operator std::shared_ptr<mkldnn::eltwise_forward::desc>() {
179     DescFwdImpl<mkldnn::eltwise_forward::desc> *typeDesc =
180             dynamic_cast<DescFwdImpl<mkldnn::eltwise_forward::desc> *>(desc.get());
181         if (typeDesc == nullptr) {
182         THROW_IE_EXCEPTION << "Cannot cast descriptor!";
183     }
184     return typeDesc->getPtr();
185 }
186
187 MKLDNNDescriptor::MKLDNNDescriptor(std::shared_ptr<mkldnn::binarization_forward::desc> desc) {
188     this->desc.reset(new DescFwdImpl<mkldnn::binarization_forward::desc>(desc));
189 }
190
191 MKLDNNDescriptor::operator std::shared_ptr<mkldnn::binarization_forward::desc>() {
192     auto *typeDesc = dynamic_cast<DescFwdImpl<mkldnn::binarization_forward::desc> *>(desc.get());
193     if (typeDesc == nullptr) {
194         THROW_IE_EXCEPTION << "Cannot cast descriptor!";
195     }
196     return typeDesc->getPtr();
197 }
198
199 MKLDNNDescriptor::MKLDNNDescriptor(std::shared_ptr<mkldnn::binary_convolution_forward::desc> desc) {
200     this->desc.reset(new DescFwdImpl<mkldnn::binary_convolution_forward::desc>(desc));
201 }
202
203 MKLDNNDescriptor::operator std::shared_ptr<mkldnn::binary_convolution_forward::desc>() {
204     auto *typeDesc = dynamic_cast<DescFwdImpl<mkldnn::binary_convolution_forward::desc> *>(desc.get());
205     if (typeDesc == nullptr) {
206         THROW_IE_EXCEPTION << "Cannot cast descriptor!";
207     }
208     return typeDesc->getPtr();
209 }