1 // Copyright (C) 2018-2019 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
10 #include <mkldnn_types.h>
15 struct primitive_desc_iterator : public handle<mkldnn_primitive_desc_iterator_t> {
17 primitive_desc_iterator(const T &adesc, const mkldnn::primitive_attr &aattr, const engine &aengine) {
18 mkldnn_primitive_desc_iterator_t result;
19 error::wrap_c_api(mkldnn_primitive_desc_iterator_create_v2(
20 &result, &adesc.data, aattr.get(), aengine.get(), nullptr),
21 "could not create a primitive descriptor iterator");
25 template <typename T, typename TF>
26 primitive_desc_iterator(const T &adesc, const mkldnn::primitive_attr &aattr,
27 const engine &aengine, const TF &hint_fwd_primitive_desc) {
28 mkldnn_primitive_desc_iterator_t result;
29 error::wrap_c_api(mkldnn_primitive_desc_iterator_create_v2(&result,
33 hint_fwd_primitive_desc.get()),
34 "could not create a primitive descriptor iterator");
39 memory::primitive_desc fetch() const {
40 memory::primitive_desc adesc;
41 mkldnn_primitive_desc_t cdesc;
43 cdesc = mkldnn_primitive_desc_iterator_fetch(get());
50 mkldnn_status_t status = mkldnn_primitive_desc_iterator_next(get());
51 return status == mkldnn_status_t::mkldnn_success;
54 memory::primitive_desc src_primitive_desc(size_t index = 0) const {
55 memory::primitive_desc adesc;
56 memory::primitive_desc cdesc_elem;
57 mkldnn_primitive_desc_t cdesc;
58 cdesc_elem.reset(mkldnn_primitive_desc_iterator_fetch(get()));
59 const_mkldnn_primitive_desc_t const_cdesc =
60 mkldnn_primitive_desc_query_pd(cdesc_elem.get(),
61 mkldnn::convert_to_c(src_pd), index);
62 error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc, const_cdesc),
63 "could not clone a src primititve descriptor");
68 memory::primitive_desc dst_primitive_desc(size_t index = 0) const {
69 memory::primitive_desc adesc;
70 memory::primitive_desc cdesc_elem;
71 mkldnn_primitive_desc_t cdesc;
72 cdesc_elem.reset(mkldnn_primitive_desc_iterator_fetch(get()));
73 const_mkldnn_primitive_desc_t const_cdesc =
74 mkldnn_primitive_desc_query_pd(cdesc_elem.get(),
75 mkldnn::convert_to_c(dst_pd), index);
76 error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc, const_cdesc),
77 "could not clone a dst primitive descriptor");
83 memory::primitive_desc diff_src_primitive_desc(size_t index = 0) const {
84 memory::primitive_desc adesc;
85 memory::primitive_desc cdesc_elem;
86 mkldnn_primitive_desc_t cdesc;
87 cdesc_elem.reset(mkldnn_primitive_desc_iterator_fetch(get()));
88 const_mkldnn_primitive_desc_t const_cdesc =
89 mkldnn_primitive_desc_query_pd(cdesc_elem.get(),
90 mkldnn::convert_to_c(diff_src_pd), index);
91 error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc, const_cdesc),
92 "could not clone a diff_src primititve descriptor");
97 memory::primitive_desc weights_primitive_desc(size_t index = 0) const {
98 memory::primitive_desc adesc;
99 memory::primitive_desc cdesc_elem;
100 mkldnn_primitive_desc_t cdesc;
101 cdesc_elem.reset(mkldnn_primitive_desc_iterator_fetch(get()));
102 const_mkldnn_primitive_desc_t const_cdesc =
103 mkldnn_primitive_desc_query_pd(cdesc_elem.get(),
104 mkldnn::convert_to_c(weights_pd), index);
105 error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc, const_cdesc),
106 "could not clone a weights primitive descriptor");
111 memory::primitive_desc diff_dst_primitive_desc(size_t index = 0) const {
112 memory::primitive_desc adesc;
113 memory::primitive_desc cdesc_elem;
114 mkldnn_primitive_desc_t cdesc;
115 cdesc_elem.reset(mkldnn_primitive_desc_iterator_fetch(get()));
116 const_mkldnn_primitive_desc_t const_cdesc =
117 mkldnn_primitive_desc_query_pd(cdesc_elem.get(),
118 mkldnn::convert_to_c(diff_dst_pd), index);
119 error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc, const_cdesc),
120 "could not clone a diff_dst primitive descriptor");
125 std::string get_impl_info_str() const {
126 memory::primitive_desc cdesc_elem;
127 cdesc_elem.reset(mkldnn_primitive_desc_iterator_fetch(get()));
129 error::wrap_c_api(mkldnn_primitive_desc_query(cdesc_elem.get(),
130 mkldnn::convert_to_c(impl_info_str), 0, &info),
131 "could not query info string of primitive descriptor");
132 return std::string(info);
135 template <typename T>
136 void getPrimitiveDescriptor(T& pdesc) const {
137 mkldnn_primitive_desc_t cdesc;
139 memory::primitive_desc cdescpd;
141 cdescpd.reset(mkldnn_primitive_desc_iterator_fetch(get()));
142 error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc, cdescpd.get()),
143 "could not clone a src primititve descriptor");
148 } // namespace mkldnn