Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / src / mkldnn_plugin / mkldnn_descriptor.h
1 // Copyright (C) 2018-2019 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
3 //
4
5 #pragma once
6
#include <memory>
#include <string>
#include <utility>
#include <mkldnn.hpp>
#include <mkldnn/desc_iterator.hpp>
11
12 class MKLDNNDescriptor {
13 public:
14     explicit MKLDNNDescriptor(std::shared_ptr<mkldnn::batch_normalization_forward::desc> desc);
15     operator std::shared_ptr<mkldnn::batch_normalization_forward::desc>();
16
17     explicit MKLDNNDescriptor(std::shared_ptr<mkldnn::convolution_forward::desc> desc);
18     operator std::shared_ptr<mkldnn::convolution_forward::desc>();
19
20     MKLDNNDescriptor(std::shared_ptr<mkldnn::convolution_backward_data::desc> desc,
21                      std::shared_ptr<mkldnn::convolution_forward::primitive_desc> prim);
22     operator std::shared_ptr<mkldnn::convolution_backward_data::desc>();
23     operator std::shared_ptr<mkldnn::convolution_forward::primitive_desc>();
24
25     explicit MKLDNNDescriptor(std::shared_ptr<mkldnn::inner_product_forward::desc> desc);
26     operator std::shared_ptr<mkldnn::inner_product_forward::desc>();
27
28     explicit MKLDNNDescriptor(std::shared_ptr<mkldnn::lrn_forward::desc> desc);
29     operator std::shared_ptr<mkldnn::lrn_forward::desc>();
30
31     explicit MKLDNNDescriptor(std::shared_ptr<mkldnn::pooling_forward::desc> desc);
32     operator std::shared_ptr<mkldnn::pooling_forward::desc>();
33
34     explicit MKLDNNDescriptor(std::shared_ptr<mkldnn::roi_pooling_forward::desc> desc);
35     operator std::shared_ptr<mkldnn::roi_pooling_forward::desc>();
36
37     explicit MKLDNNDescriptor(std::shared_ptr<mkldnn::softmax_forward::desc> desc);
38     operator std::shared_ptr<mkldnn::softmax_forward::desc>();
39
40     explicit MKLDNNDescriptor(std::shared_ptr<mkldnn::depthwise_forward::desc> desc);
41     operator std::shared_ptr<mkldnn::depthwise_forward::desc>();
42
43     explicit MKLDNNDescriptor(std::shared_ptr<mkldnn::rnn_forward::desc> desc);
44     operator std::shared_ptr<mkldnn::rnn_forward::desc>();
45
46     explicit MKLDNNDescriptor(std::shared_ptr<mkldnn::eltwise_forward::desc> desc);
47     operator std::shared_ptr<mkldnn::eltwise_forward::desc>();
48
49     explicit MKLDNNDescriptor(std::shared_ptr<mkldnn::binarization_forward::desc> desc);
50     operator std::shared_ptr<mkldnn::binarization_forward::desc>();
51
52     explicit MKLDNNDescriptor(std::shared_ptr<mkldnn::binary_convolution_forward::desc> desc);
53     operator std::shared_ptr<mkldnn::binary_convolution_forward::desc>();
54
55     mkldnn::primitive_desc_iterator createPrimitiveDescriptorIterator(const mkldnn::engine &engine,
56             const mkldnn::primitive_attr &attr = mkldnn::primitive_attr()) const;
57
58     size_t outputNumbers() const;
59     size_t inputNumbers() const;
60
61     operator bool();
62
63 private:
64     class IDesc {
65     public:
66         virtual ~IDesc() {}
67         virtual mkldnn::primitive_desc_iterator createPrimitiveDescriptorIterator(const mkldnn::primitive_attr &attr,
68                                                                                   const mkldnn::engine &engine) const = 0;
69     };
70
71     template <class T>
72     class DescFwdImpl: public IDesc {
73         std::shared_ptr<T> desc;
74     public:
75         explicit DescFwdImpl(std::shared_ptr<T> d) : desc(d) {}
76
77         mkldnn::primitive_desc_iterator createPrimitiveDescriptorIterator(const mkldnn::primitive_attr &attr,
78                                                                           const mkldnn::engine &engine) const override {
79             return mkldnn::primitive_desc_iterator(*desc, attr, engine);
80         }
81
82         std::shared_ptr<T>& getPtr() {
83             return desc;
84         }
85     };
86
87
88     template <class T, class P>
89     class DescBwdImpl: public IDesc {
90         std::shared_ptr<T> desc;
91         std::shared_ptr<P> prim;
92
93     public:
94         DescBwdImpl(std::shared_ptr<T> d, std::shared_ptr<P> p) : desc(d), prim(p) {}
95
96         mkldnn::primitive_desc_iterator createPrimitiveDescriptorIterator(const mkldnn::primitive_attr &attr,
97                                                                           const mkldnn::engine &engine) const override {
98             return mkldnn::primitive_desc_iterator(*desc, attr, engine, *prim);
99         }
100
101         std::shared_ptr<T>& getPtr() {
102             return desc;
103         }
104
105         std::shared_ptr<P>& getPrimPtr() {
106             return prim;
107         }
108     };
109
110     std::shared_ptr<IDesc> desc;
111 };