1 // Copyright (C) 2018 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
10 #include <mkldnn/desc_iterator.hpp>
// Declares, for each supported mkldnn `desc` type, a constructor taking a
// std::shared_ptr to that descriptor and a conversion operator to the same
// pointer type, plus a factory for mkldnn::primitive_desc_iterator.
// NOTE(review): only the declarations are visible in this excerpt; the
// definitions presumably live in the matching .cpp file - confirm there.
12 class MKLDNNDescriptor {
// Batch normalization (forward): construct from a shared descriptor pointer.
14 explicit MKLDNNDescriptor(std::shared_ptr<mkldnn::batch_normalization_forward::desc> desc);
// Non-explicit, non-const conversion to the batch-normalization descriptor pointer type.
15 operator std::shared_ptr<mkldnn::batch_normalization_forward::desc>();
// Convolution (forward): construct from a shared descriptor pointer.
17 explicit MKLDNNDescriptor(std::shared_ptr<mkldnn::convolution_forward::desc> desc);
// Non-explicit, non-const conversion to the convolution-forward descriptor pointer type.
18 operator std::shared_ptr<mkldnn::convolution_forward::desc>();
// Convolution + ReLU (forward): construct from a shared descriptor pointer.
20 explicit MKLDNNDescriptor(std::shared_ptr<mkldnn::convolution_relu_forward::desc> desc);
// Non-explicit, non-const conversion to the convolution_relu_forward descriptor pointer type.
21 operator std::shared_ptr<mkldnn::convolution_relu_forward::desc>();
// Convolution backward-data: unlike the forward variants, this constructor
// takes two arguments - the backward-data descriptor and a
// convolution_forward::primitive_desc.
// NOTE(review): the forward primitive_desc is presumably the hint mkldnn
// needs to build a backward primitive - confirm against the implementation.
23 MKLDNNDescriptor(std::shared_ptr<mkldnn::convolution_backward_data::desc> desc,
24 std::shared_ptr<mkldnn::convolution_forward::primitive_desc> prim);
// Non-explicit, non-const conversion to the backward-data descriptor pointer type.
25 operator std::shared_ptr<mkldnn::convolution_backward_data::desc>();
// Non-explicit, non-const conversion to the forward primitive_desc pointer type.
26 operator std::shared_ptr<mkldnn::convolution_forward::primitive_desc>();
// Inner product (forward): construct from a shared descriptor pointer.
28 explicit MKLDNNDescriptor(std::shared_ptr<mkldnn::inner_product_forward::desc> desc);
// Non-explicit, non-const conversion to the inner-product descriptor pointer type.
29 operator std::shared_ptr<mkldnn::inner_product_forward::desc>();
// LRN (forward): construct from a shared descriptor pointer.
31 explicit MKLDNNDescriptor(std::shared_ptr<mkldnn::lrn_forward::desc> desc);
// Non-explicit, non-const conversion to the lrn_forward descriptor pointer type.
32 operator std::shared_ptr<mkldnn::lrn_forward::desc>();
// Pooling (forward): construct from a shared descriptor pointer.
34 explicit MKLDNNDescriptor(std::shared_ptr<mkldnn::pooling_forward::desc> desc);
// Non-explicit, non-const conversion to the pooling_forward descriptor pointer type.
35 operator std::shared_ptr<mkldnn::pooling_forward::desc>();
// ReLU (forward): construct from a shared descriptor pointer.
37 explicit MKLDNNDescriptor(std::shared_ptr<mkldnn::relu_forward::desc> desc);
// Non-explicit, non-const conversion to the relu_forward descriptor pointer type.
38 operator std::shared_ptr<mkldnn::relu_forward::desc>();
// ROI pooling (forward): construct from a shared descriptor pointer.
40 explicit MKLDNNDescriptor(std::shared_ptr<mkldnn::roi_pooling_forward::desc> desc);
// Non-explicit, non-const conversion to the roi_pooling_forward descriptor pointer type.
41 operator std::shared_ptr<mkldnn::roi_pooling_forward::desc>();
// Softmax (forward): construct from a shared descriptor pointer.
43 explicit MKLDNNDescriptor(std::shared_ptr<mkldnn::softmax_forward::desc> desc);
// Non-explicit, non-const conversion to the softmax_forward descriptor pointer type.
44 operator std::shared_ptr<mkldnn::softmax_forward::desc>();
// Depthwise (forward): construct from a shared descriptor pointer.
46 explicit MKLDNNDescriptor(std::shared_ptr<mkldnn::depthwise_forward::desc> desc);
// Non-explicit, non-const conversion to the depthwise_forward descriptor pointer type.
47 operator std::shared_ptr<mkldnn::depthwise_forward::desc>();
// RNN (forward): construct from a shared descriptor pointer.
49 explicit MKLDNNDescriptor(std::shared_ptr<mkldnn::rnn_forward::desc> desc);
// Non-explicit, non-const conversion to the rnn_forward descriptor pointer type.
50 operator std::shared_ptr<mkldnn::rnn_forward::desc>();
// Returns an mkldnn::primitive_desc_iterator for the given engine.
// `attr` is optional and defaults to a default-constructed
// mkldnn::primitive_attr. Parameter order here is (engine, attr); the
// pure-virtual overload of the same name declared further down takes
// (attr, engine) - keep that in mind when forwarding arguments.
52 mkldnn::primitive_desc_iterator createPrimitiveDescriptorIterator(const mkldnn::engine &engine,
53 const mkldnn::primitive_attr &attr = mkldnn::primitive_attr()) const;
// Const query returning a size_t.
// NOTE(review): presumably the number of outputs of the described
// primitive - the definition is not visible here, confirm in the .cpp.
55 size_t outputNumbers() const;
// Const query returning a size_t.
// NOTE(review): presumably the number of inputs of the described
// primitive - the definition is not visible here, confirm in the .cpp.
56 size_t inputNumbers() const;
// Pure-virtual factory hook; DescFwdImpl and DescBwdImpl below override it.
// Takes (attr, engine) - the reverse of the (engine, attr) order used by the
// non-virtual MKLDNNDescriptor::createPrimitiveDescriptorIterator above.
// NOTE(review): the enclosing declaration is not visible in this excerpt;
// it is presumably the IDesc interface that the Impl classes derive from.
64 virtual mkldnn::primitive_desc_iterator createPrimitiveDescriptorIterator(const mkldnn::primitive_attr &attr,
65 const mkldnn::engine &engine) const = 0;
// IDesc implementation that holds a single shared descriptor of type T.
// NOTE(review): T is presumably a template parameter declared on a line
// that is not part of this excerpt - confirm.
69 class DescFwdImpl: public IDesc {
// The wrapped descriptor.
70 std::shared_ptr<T> desc;
// Stores a copy of the given shared pointer.
72 explicit DescFwdImpl(std::shared_ptr<T> d) : desc(d) {}
// Builds the iterator from the stored descriptor, the attributes and the engine.
74 mkldnn::primitive_desc_iterator createPrimitiveDescriptorIterator(const mkldnn::primitive_attr &attr,
75 const mkldnn::engine &engine) const override {
// `desc` is dereferenced without a null check, so it must be non-null here.
76 return mkldnn::primitive_desc_iterator(*desc, attr, engine);
// Accessor returning a mutable reference to a std::shared_ptr<T>.
// NOTE(review): body not visible here; presumably returns `desc` - confirm.
79 std::shared_ptr<T>& getPtr() {
// IDesc implementation that holds a descriptor of type T together with a
// second shared object of type P, which is passed as an extra argument when
// the primitive descriptor iterator is created.
85 template <class T, class P>
86 class DescBwdImpl: public IDesc {
// The wrapped descriptor.
87 std::shared_ptr<T> desc;
// The companion object forwarded to the iterator constructor.
88 std::shared_ptr<P> prim;
// Stores copies of both shared pointers.
91 DescBwdImpl(std::shared_ptr<T> d, std::shared_ptr<P> p) : desc(d), prim(p) {}
// Builds the iterator from the stored descriptor, the attributes, the engine
// and the stored companion object.
93 mkldnn::primitive_desc_iterator createPrimitiveDescriptorIterator(const mkldnn::primitive_attr &attr,
94 const mkldnn::engine &engine) const override {
// Both `desc` and `prim` are dereferenced without null checks, so both must
// be non-null here.
95 return mkldnn::primitive_desc_iterator(*desc, attr, engine, *prim);
// Accessor returning a mutable reference to a std::shared_ptr<T>.
// NOTE(review): body not visible here; presumably returns `desc` - confirm.
98 std::shared_ptr<T>& getPtr() {
// Accessor returning a mutable reference to a std::shared_ptr<P>.
// NOTE(review): body not visible here; presumably returns `prim` - confirm.
102 std::shared_ptr<P>& getPrimPtr() {
// Shared handle to an IDesc object (IDesc has a pure-virtual method, so this
// can only point at a derived implementation).
// NOTE(review): presumably set by the MKLDNNDescriptor constructors to a
// DescFwdImpl/DescBwdImpl instance - the definitions are not visible here.
107 std::shared_ptr<IDesc> desc;