Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / src / mkldnn_plugin / mkldnn / desc_iterator.hpp
1 // Copyright (C) 2018-2019 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
3 //
4
5 #pragma once
6
7 #include "mkldnn.hpp"
8
9 #include <string>
10 #include <mkldnn_types.h>
11 #include <mkldnn.h>
12
13 namespace mkldnn {
14
15 struct primitive_desc_iterator : public handle<mkldnn_primitive_desc_iterator_t> {
16     template <typename T>
17     primitive_desc_iterator(const T &adesc, const mkldnn::primitive_attr &aattr, const engine &aengine) {
18         mkldnn_primitive_desc_iterator_t result;
19         error::wrap_c_api(mkldnn_primitive_desc_iterator_create_v2(
20                 &result, &adesc.data, aattr.get(), aengine.get(), nullptr),
21                           "could not create a primitive descriptor iterator");
22         reset(result);
23     }
24
25     template <typename T, typename TF>
26     primitive_desc_iterator(const T &adesc, const mkldnn::primitive_attr &aattr,
27             const engine &aengine, const TF &hint_fwd_primitive_desc) {
28         mkldnn_primitive_desc_iterator_t result;
29         error::wrap_c_api(mkldnn_primitive_desc_iterator_create_v2(&result,
30                         &adesc.data,
31                         aattr.get(),
32                         aengine.get(),
33                         hint_fwd_primitive_desc.get()),
34                 "could not create a primitive descriptor iterator");
35         reset(result);
36     }
37
38
39     memory::primitive_desc fetch() const {
40         memory::primitive_desc adesc;
41         mkldnn_primitive_desc_t cdesc;
42
43         cdesc = mkldnn_primitive_desc_iterator_fetch(get());
44
45         adesc.reset(cdesc);
46         return adesc;
47     }
48
49     bool next() {
50         mkldnn_status_t status = mkldnn_primitive_desc_iterator_next(get());
51         return status == mkldnn_status_t::mkldnn_success;
52     }
53
54     memory::primitive_desc src_primitive_desc(size_t index = 0) const {
55         memory::primitive_desc adesc;
56         memory::primitive_desc cdesc_elem;
57         mkldnn_primitive_desc_t cdesc;
58         cdesc_elem.reset(mkldnn_primitive_desc_iterator_fetch(get()));
59         const_mkldnn_primitive_desc_t const_cdesc =
60                 mkldnn_primitive_desc_query_pd(cdesc_elem.get(),
61                                                mkldnn::convert_to_c(src_pd), index);
62         error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc, const_cdesc),
63                           "could not clone a src primititve descriptor");
64         adesc.reset(cdesc);
65         return adesc;
66     }
67
68     memory::primitive_desc dst_primitive_desc(size_t index = 0) const {
69         memory::primitive_desc adesc;
70         memory::primitive_desc cdesc_elem;
71         mkldnn_primitive_desc_t cdesc;
72         cdesc_elem.reset(mkldnn_primitive_desc_iterator_fetch(get()));
73         const_mkldnn_primitive_desc_t const_cdesc =
74                 mkldnn_primitive_desc_query_pd(cdesc_elem.get(),
75                                                mkldnn::convert_to_c(dst_pd), index);
76         error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc, const_cdesc),
77                           "could not clone a dst primitive descriptor");
78         adesc.reset(cdesc);
79         return adesc;
80     }
81
82
83     memory::primitive_desc diff_src_primitive_desc(size_t index = 0) const {
84         memory::primitive_desc adesc;
85         memory::primitive_desc cdesc_elem;
86         mkldnn_primitive_desc_t cdesc;
87         cdesc_elem.reset(mkldnn_primitive_desc_iterator_fetch(get()));
88         const_mkldnn_primitive_desc_t const_cdesc =
89                 mkldnn_primitive_desc_query_pd(cdesc_elem.get(),
90                                                mkldnn::convert_to_c(diff_src_pd), index);
91         error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc, const_cdesc),
92                           "could not clone a diff_src primititve descriptor");
93         adesc.reset(cdesc);
94         return adesc;
95     }
96
97     memory::primitive_desc weights_primitive_desc(size_t index = 0) const {
98         memory::primitive_desc adesc;
99         memory::primitive_desc cdesc_elem;
100         mkldnn_primitive_desc_t cdesc;
101         cdesc_elem.reset(mkldnn_primitive_desc_iterator_fetch(get()));
102         const_mkldnn_primitive_desc_t const_cdesc =
103                 mkldnn_primitive_desc_query_pd(cdesc_elem.get(),
104                                                mkldnn::convert_to_c(weights_pd), index);
105         error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc, const_cdesc),
106                           "could not clone a weights primitive descriptor");
107         adesc.reset(cdesc);
108         return adesc;
109     }
110
111     memory::primitive_desc diff_dst_primitive_desc(size_t index = 0) const {
112         memory::primitive_desc adesc;
113         memory::primitive_desc cdesc_elem;
114         mkldnn_primitive_desc_t cdesc;
115         cdesc_elem.reset(mkldnn_primitive_desc_iterator_fetch(get()));
116         const_mkldnn_primitive_desc_t const_cdesc =
117                 mkldnn_primitive_desc_query_pd(cdesc_elem.get(),
118                                                mkldnn::convert_to_c(diff_dst_pd), index);
119         error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc, const_cdesc),
120                           "could not clone a diff_dst primitive descriptor");
121         adesc.reset(cdesc);
122         return adesc;
123     }
124
125     std::string get_impl_info_str() const {
126         memory::primitive_desc cdesc_elem;
127         cdesc_elem.reset(mkldnn_primitive_desc_iterator_fetch(get()));
128         const char *info;
129         error::wrap_c_api(mkldnn_primitive_desc_query(cdesc_elem.get(),
130                         mkldnn::convert_to_c(impl_info_str), 0, &info),
131                 "could not query info string of primitive descriptor");
132         return std::string(info);
133     }
134
135     template <typename T>
136     void getPrimitiveDescriptor(T& pdesc) const {
137         mkldnn_primitive_desc_t cdesc;
138
139         memory::primitive_desc cdescpd;
140
141         cdescpd.reset(mkldnn_primitive_desc_iterator_fetch(get()));
142         error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc, cdescpd.get()),
143                           "could not clone a src primititve descriptor");
144         pdesc.reset(cdesc);
145     }
146 };
147
148 }  // namespace mkldnn