1 // Copyright (C) 2018-2019 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
5 #include <details/ie_exception.hpp>
6 #include "mkldnn_descriptor.h"
8 mkldnn::primitive_desc_iterator MKLDNNDescriptor::createPrimitiveDescriptorIterator(const mkldnn::engine &engine,
9 const mkldnn::primitive_attr &attr) const {
10 return desc->createPrimitiveDescriptorIterator(attr, engine);
13 MKLDNNDescriptor::operator bool() {
14 return desc.get() != nullptr;
17 size_t MKLDNNDescriptor::inputNumbers() const {
18 DescFwdImpl<mkldnn::roi_pooling_forward::desc> *roiPooling =
19 dynamic_cast<DescFwdImpl<mkldnn::roi_pooling_forward::desc> *>(desc.get());
20 if (roiPooling != nullptr) {
21 return roiPooling->getPtr()->c_api_inputs.size();
26 size_t MKLDNNDescriptor::outputNumbers() const {
30 MKLDNNDescriptor::MKLDNNDescriptor(std::shared_ptr<mkldnn::batch_normalization_forward::desc> desc) {
31 this->desc.reset(new DescFwdImpl<mkldnn::batch_normalization_forward::desc>(desc));
34 MKLDNNDescriptor::operator std::shared_ptr<mkldnn::batch_normalization_forward::desc>() {
35 DescFwdImpl<mkldnn::batch_normalization_forward::desc> *typeDesc =
36 dynamic_cast<DescFwdImpl<mkldnn::batch_normalization_forward::desc> *>(desc.get());
37 if (typeDesc == nullptr) {
38 THROW_IE_EXCEPTION << "Cannot cast descriptor!";
40 return typeDesc->getPtr();
43 MKLDNNDescriptor::MKLDNNDescriptor(std::shared_ptr<mkldnn::convolution_forward::desc> desc) {
44 this->desc.reset(new DescFwdImpl<mkldnn::convolution_forward::desc>(desc));
47 MKLDNNDescriptor::operator std::shared_ptr<mkldnn::convolution_forward::desc>() {
48 DescFwdImpl<mkldnn::convolution_forward::desc> *typeDesc =
49 dynamic_cast<DescFwdImpl<mkldnn::convolution_forward::desc> *>(desc.get());
50 if (typeDesc == nullptr) {
51 THROW_IE_EXCEPTION << "Cannot cast descriptor!";
53 return typeDesc->getPtr();
56 MKLDNNDescriptor::MKLDNNDescriptor(std::shared_ptr<mkldnn::convolution_backward_data::desc> desc,
57 std::shared_ptr<mkldnn::convolution_forward::primitive_desc> prim) {
59 new DescBwdImpl<mkldnn::convolution_backward_data::desc,
60 mkldnn::convolution_forward::primitive_desc>(desc, prim));
63 MKLDNNDescriptor::operator std::shared_ptr<mkldnn::convolution_backward_data::desc>() {
64 DescBwdImpl<mkldnn::convolution_backward_data::desc, mkldnn::convolution_forward::primitive_desc> *typeDesc =
65 dynamic_cast<DescBwdImpl<mkldnn::convolution_backward_data::desc,
66 mkldnn::convolution_forward::primitive_desc> *>(desc.get());
67 if (typeDesc == nullptr) {
68 THROW_IE_EXCEPTION << "Cannot cast descriptor!";
70 return typeDesc->getPtr();
73 MKLDNNDescriptor::operator std::shared_ptr<mkldnn::convolution_forward::primitive_desc>() {
74 DescBwdImpl<mkldnn::convolution_backward_data::desc, mkldnn::convolution_forward::primitive_desc> *typeDesc =
75 dynamic_cast<DescBwdImpl<mkldnn::convolution_backward_data::desc,
76 mkldnn::convolution_forward::primitive_desc> *>(desc.get());
77 if (typeDesc == nullptr) {
78 THROW_IE_EXCEPTION << "Cannot cast descriptor!";
80 return typeDesc->getPrimPtr();
83 MKLDNNDescriptor::MKLDNNDescriptor(std::shared_ptr<mkldnn::inner_product_forward::desc> desc) {
84 this->desc.reset(new DescFwdImpl<mkldnn::inner_product_forward::desc>(desc));
87 MKLDNNDescriptor::operator std::shared_ptr<mkldnn::inner_product_forward::desc>() {
88 DescFwdImpl<mkldnn::inner_product_forward::desc> *typeDesc =
89 dynamic_cast<DescFwdImpl<mkldnn::inner_product_forward::desc> *>(desc.get());
90 if (typeDesc == nullptr) {
91 THROW_IE_EXCEPTION << "Cannot cast descriptor!";
93 return typeDesc->getPtr();
96 MKLDNNDescriptor::MKLDNNDescriptor(std::shared_ptr<mkldnn::lrn_forward::desc> desc) {
97 this->desc.reset(new DescFwdImpl<mkldnn::lrn_forward::desc>(desc));
100 MKLDNNDescriptor::operator std::shared_ptr<mkldnn::lrn_forward::desc>() {
101 DescFwdImpl<mkldnn::lrn_forward::desc> *typeDesc =
102 dynamic_cast<DescFwdImpl<mkldnn::lrn_forward::desc> *>(desc.get());
103 if (typeDesc == nullptr) {
104 THROW_IE_EXCEPTION << "Cannot cast descriptor!";
106 return typeDesc->getPtr();
109 MKLDNNDescriptor::MKLDNNDescriptor(std::shared_ptr<mkldnn::pooling_forward::desc> desc) {
110 this->desc.reset(new DescFwdImpl<mkldnn::pooling_forward::desc>(desc));
113 MKLDNNDescriptor::operator std::shared_ptr<mkldnn::pooling_forward::desc>() {
114 DescFwdImpl<mkldnn::pooling_forward::desc> *typeDesc =
115 dynamic_cast<DescFwdImpl<mkldnn::pooling_forward::desc> *>(desc.get());
116 if (typeDesc == nullptr) {
117 THROW_IE_EXCEPTION << "Cannot cast descriptor!";
119 return typeDesc->getPtr();
122 MKLDNNDescriptor::MKLDNNDescriptor(std::shared_ptr<mkldnn::roi_pooling_forward::desc> desc) {
123 this->desc.reset(new DescFwdImpl<mkldnn::roi_pooling_forward::desc>(desc));
126 MKLDNNDescriptor::operator std::shared_ptr<mkldnn::roi_pooling_forward::desc>() {
127 DescFwdImpl<mkldnn::roi_pooling_forward::desc> *typeDesc =
128 dynamic_cast<DescFwdImpl<mkldnn::roi_pooling_forward::desc> *>(desc.get());
129 if (typeDesc == nullptr) {
130 THROW_IE_EXCEPTION << "Cannot cast descriptor!";
132 return typeDesc->getPtr();
135 MKLDNNDescriptor::MKLDNNDescriptor(std::shared_ptr<mkldnn::softmax_forward::desc> desc) {
136 this->desc.reset(new DescFwdImpl<mkldnn::softmax_forward::desc>(desc));
139 MKLDNNDescriptor::operator std::shared_ptr<mkldnn::softmax_forward::desc>() {
140 DescFwdImpl<mkldnn::softmax_forward::desc> *typeDesc =
141 dynamic_cast<DescFwdImpl<mkldnn::softmax_forward::desc> *>(desc.get());
142 if (typeDesc == nullptr) {
143 THROW_IE_EXCEPTION << "Cannot cast descriptor!";
145 return typeDesc->getPtr();
148 MKLDNNDescriptor::MKLDNNDescriptor(std::shared_ptr<mkldnn::depthwise_forward::desc> desc) {
149 this->desc.reset(new DescFwdImpl<mkldnn::depthwise_forward::desc>(desc));
152 MKLDNNDescriptor::operator std::shared_ptr<mkldnn::depthwise_forward::desc>() {
153 DescFwdImpl<mkldnn::depthwise_forward::desc> *typeDesc =
154 dynamic_cast<DescFwdImpl<mkldnn::depthwise_forward::desc> *>(desc.get());
155 if (typeDesc == nullptr) {
156 THROW_IE_EXCEPTION << "Cannot cast descriptor!";
158 return typeDesc->getPtr();
161 MKLDNNDescriptor::MKLDNNDescriptor(std::shared_ptr<mkldnn::rnn_forward::desc> desc) {
162 this->desc.reset(new DescFwdImpl<mkldnn::rnn_forward::desc>(desc));
165 MKLDNNDescriptor::operator std::shared_ptr<mkldnn::rnn_forward::desc>() {
166 DescFwdImpl<mkldnn::rnn_forward::desc> *typeDesc =
167 dynamic_cast<DescFwdImpl<mkldnn::rnn_forward::desc> *>(desc.get());
168 if (typeDesc == nullptr) {
169 THROW_IE_EXCEPTION << "Cannot cast descriptor!";
171 return typeDesc->getPtr();
174 MKLDNNDescriptor::MKLDNNDescriptor(std::shared_ptr<mkldnn::eltwise_forward::desc> desc) {
175 this->desc.reset(new DescFwdImpl<mkldnn::eltwise_forward::desc>(desc));
178 MKLDNNDescriptor::operator std::shared_ptr<mkldnn::eltwise_forward::desc>() {
179 DescFwdImpl<mkldnn::eltwise_forward::desc> *typeDesc =
180 dynamic_cast<DescFwdImpl<mkldnn::eltwise_forward::desc> *>(desc.get());
181 if (typeDesc == nullptr) {
182 THROW_IE_EXCEPTION << "Cannot cast descriptor!";
184 return typeDesc->getPtr();
187 MKLDNNDescriptor::MKLDNNDescriptor(std::shared_ptr<mkldnn::binarization_forward::desc> desc) {
188 this->desc.reset(new DescFwdImpl<mkldnn::binarization_forward::desc>(desc));
191 MKLDNNDescriptor::operator std::shared_ptr<mkldnn::binarization_forward::desc>() {
192 auto *typeDesc = dynamic_cast<DescFwdImpl<mkldnn::binarization_forward::desc> *>(desc.get());
193 if (typeDesc == nullptr) {
194 THROW_IE_EXCEPTION << "Cannot cast descriptor!";
196 return typeDesc->getPtr();
199 MKLDNNDescriptor::MKLDNNDescriptor(std::shared_ptr<mkldnn::binary_convolution_forward::desc> desc) {
200 this->desc.reset(new DescFwdImpl<mkldnn::binary_convolution_forward::desc>(desc));
203 MKLDNNDescriptor::operator std::shared_ptr<mkldnn::binary_convolution_forward::desc>() {
204 auto *typeDesc = dynamic_cast<DescFwdImpl<mkldnn::binary_convolution_forward::desc> *>(desc.get());
205 if (typeDesc == nullptr) {
206 THROW_IE_EXCEPTION << "Cannot cast descriptor!";
208 return typeDesc->getPtr();