Publishing R5 content (#72)
[platform/upstream/dldt.git] / inference-engine / src / mkldnn_plugin / mkldnn_descriptor.h
1 // Copyright (C) 2018 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
3 //
4
5 #pragma once
6
#include <memory>
#include <string>
#include <utility>
#include <mkldnn.hpp>
#include <mkldnn/desc_iterator.hpp>
11
12 class MKLDNNDescriptor {
13 public:
14     explicit MKLDNNDescriptor(std::shared_ptr<mkldnn::batch_normalization_forward::desc> desc);
15     operator std::shared_ptr<mkldnn::batch_normalization_forward::desc>();
16
17     explicit MKLDNNDescriptor(std::shared_ptr<mkldnn::convolution_forward::desc> desc);
18     operator std::shared_ptr<mkldnn::convolution_forward::desc>();
19
20     explicit MKLDNNDescriptor(std::shared_ptr<mkldnn::convolution_relu_forward::desc> desc);
21     operator std::shared_ptr<mkldnn::convolution_relu_forward::desc>();
22
23     MKLDNNDescriptor(std::shared_ptr<mkldnn::convolution_backward_data::desc> desc,
24                      std::shared_ptr<mkldnn::convolution_forward::primitive_desc> prim);
25     operator std::shared_ptr<mkldnn::convolution_backward_data::desc>();
26     operator std::shared_ptr<mkldnn::convolution_forward::primitive_desc>();
27
28     explicit MKLDNNDescriptor(std::shared_ptr<mkldnn::inner_product_forward::desc> desc);
29     operator std::shared_ptr<mkldnn::inner_product_forward::desc>();
30
31     explicit MKLDNNDescriptor(std::shared_ptr<mkldnn::lrn_forward::desc> desc);
32     operator std::shared_ptr<mkldnn::lrn_forward::desc>();
33
34     explicit MKLDNNDescriptor(std::shared_ptr<mkldnn::pooling_forward::desc> desc);
35     operator std::shared_ptr<mkldnn::pooling_forward::desc>();
36
37     explicit MKLDNNDescriptor(std::shared_ptr<mkldnn::relu_forward::desc> desc);
38     operator std::shared_ptr<mkldnn::relu_forward::desc>();
39
40     explicit MKLDNNDescriptor(std::shared_ptr<mkldnn::roi_pooling_forward::desc> desc);
41     operator std::shared_ptr<mkldnn::roi_pooling_forward::desc>();
42
43     explicit MKLDNNDescriptor(std::shared_ptr<mkldnn::softmax_forward::desc> desc);
44     operator std::shared_ptr<mkldnn::softmax_forward::desc>();
45
46     explicit MKLDNNDescriptor(std::shared_ptr<mkldnn::depthwise_forward::desc> desc);
47     operator std::shared_ptr<mkldnn::depthwise_forward::desc>();
48
49     explicit MKLDNNDescriptor(std::shared_ptr<mkldnn::rnn_forward::desc> desc);
50     operator std::shared_ptr<mkldnn::rnn_forward::desc>();
51
52     mkldnn::primitive_desc_iterator createPrimitiveDescriptorIterator(const mkldnn::engine &engine,
53             const mkldnn::primitive_attr &attr = mkldnn::primitive_attr()) const;
54
55     size_t outputNumbers() const;
56     size_t inputNumbers() const;
57
58     operator bool();
59
60 private:
61     class IDesc {
62     public:
63         virtual ~IDesc() {}
64         virtual mkldnn::primitive_desc_iterator createPrimitiveDescriptorIterator(const mkldnn::primitive_attr &attr,
65                                                                                   const mkldnn::engine &engine) const = 0;
66     };
67
68     template <class T>
69     class DescFwdImpl: public IDesc {
70         std::shared_ptr<T> desc;
71     public:
72         explicit DescFwdImpl(std::shared_ptr<T> d) : desc(d) {}
73
74         mkldnn::primitive_desc_iterator createPrimitiveDescriptorIterator(const mkldnn::primitive_attr &attr,
75                                                                           const mkldnn::engine &engine) const override {
76             return mkldnn::primitive_desc_iterator(*desc, attr, engine);
77         }
78
79         std::shared_ptr<T>& getPtr() {
80             return desc;
81         }
82     };
83
84
85     template <class T, class P>
86     class DescBwdImpl: public IDesc {
87         std::shared_ptr<T> desc;
88         std::shared_ptr<P> prim;
89
90     public:
91         DescBwdImpl(std::shared_ptr<T> d, std::shared_ptr<P> p) : desc(d), prim(p) {}
92
93         mkldnn::primitive_desc_iterator createPrimitiveDescriptorIterator(const mkldnn::primitive_attr &attr,
94                                                                           const mkldnn::engine &engine) const override {
95             return mkldnn::primitive_desc_iterator(*desc, attr, engine, *prim);
96         }
97
98         std::shared_ptr<T>& getPtr() {
99             return desc;
100         }
101
102         std::shared_ptr<P>& getPrimPtr() {
103             return prim;
104         }
105     };
106
107     std::shared_ptr<IDesc> desc;
108 };