1 // Copyright (C) 2018-2019 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
10 #include <mkldnn/desc_iterator.hpp>
12 class MKLDNNDescriptor {
// Each constructor below wraps a shared mkldnn operation descriptor of one specific
// type; the conversion operator declared next to it returns the same shared_ptr type.
// NOTE(review): the conversion operators are implicit (not `explicit`), and what they
// yield when the held descriptor is of a different type is not visible in this header
// -- confirm in the implementation before relying on it.
//
// Forward batch-normalization descriptor.
14 explicit MKLDNNDescriptor(std::shared_ptr<mkldnn::batch_normalization_forward::desc> desc);
15 operator std::shared_ptr<mkldnn::batch_normalization_forward::desc>();
// Forward convolution descriptor.
17 explicit MKLDNNDescriptor(std::shared_ptr<mkldnn::convolution_forward::desc> desc);
18 operator std::shared_ptr<mkldnn::convolution_forward::desc>();
// Backward-data convolution: takes the backward descriptor together with a forward
// convolution primitive descriptor (presumably the forward "hint" mkldnn requires to
// build the backward primitive -- confirm). Each part has its own conversion operator.
// This constructor takes two arguments, so it is not marked `explicit`.
20 MKLDNNDescriptor(std::shared_ptr<mkldnn::convolution_backward_data::desc> desc,
21 std::shared_ptr<mkldnn::convolution_forward::primitive_desc> prim);
22 operator std::shared_ptr<mkldnn::convolution_backward_data::desc>();
23 operator std::shared_ptr<mkldnn::convolution_forward::primitive_desc>();
// Forward inner-product (fully connected) descriptor.
25 explicit MKLDNNDescriptor(std::shared_ptr<mkldnn::inner_product_forward::desc> desc);
26 operator std::shared_ptr<mkldnn::inner_product_forward::desc>();
// Forward LRN descriptor.
28 explicit MKLDNNDescriptor(std::shared_ptr<mkldnn::lrn_forward::desc> desc);
29 operator std::shared_ptr<mkldnn::lrn_forward::desc>();
// Forward pooling descriptor.
31 explicit MKLDNNDescriptor(std::shared_ptr<mkldnn::pooling_forward::desc> desc);
32 operator std::shared_ptr<mkldnn::pooling_forward::desc>();
// Forward ROI-pooling descriptor.
34 explicit MKLDNNDescriptor(std::shared_ptr<mkldnn::roi_pooling_forward::desc> desc);
35 operator std::shared_ptr<mkldnn::roi_pooling_forward::desc>();
// Forward softmax descriptor.
37 explicit MKLDNNDescriptor(std::shared_ptr<mkldnn::softmax_forward::desc> desc);
38 operator std::shared_ptr<mkldnn::softmax_forward::desc>();
// Forward depthwise descriptor.
40 explicit MKLDNNDescriptor(std::shared_ptr<mkldnn::depthwise_forward::desc> desc);
41 operator std::shared_ptr<mkldnn::depthwise_forward::desc>();
// Forward RNN descriptor.
43 explicit MKLDNNDescriptor(std::shared_ptr<mkldnn::rnn_forward::desc> desc);
44 operator std::shared_ptr<mkldnn::rnn_forward::desc>();
// Forward element-wise (eltwise) descriptor.
46 explicit MKLDNNDescriptor(std::shared_ptr<mkldnn::eltwise_forward::desc> desc);
47 operator std::shared_ptr<mkldnn::eltwise_forward::desc>();
// Forward binarization descriptor.
49 explicit MKLDNNDescriptor(std::shared_ptr<mkldnn::binarization_forward::desc> desc);
50 operator std::shared_ptr<mkldnn::binarization_forward::desc>();
// Forward binary-convolution descriptor.
52 explicit MKLDNNDescriptor(std::shared_ptr<mkldnn::binary_convolution_forward::desc> desc);
53 operator std::shared_ptr<mkldnn::binary_convolution_forward::desc>();
// Builds an mkldnn::primitive_desc_iterator for `engine` from the wrapped descriptor.
// `attr` defaults to a default-constructed mkldnn::primitive_attr.
// Presumably delegates to the held IDesc implementation -- confirm in the .cpp.
// NOTE(review): the parameter order here is (engine, attr), the reverse of the
// virtual createPrimitiveDescriptorIterator(attr, engine) overridden by
// DescFwdImpl / DescBwdImpl; take care when forwarding arguments.
55 mkldnn::primitive_desc_iterator createPrimitiveDescriptorIterator(const mkldnn::engine &engine,
56 const mkldnn::primitive_attr &attr = mkldnn::primitive_attr()) const;
// Number of outputs / inputs of the described primitive.
// NOTE(review): implementation not visible in this header; presumably derived from
// the wrapped mkldnn descriptor -- confirm.
58 size_t outputNumbers() const;
59 size_t inputNumbers() const;
// Interface hook: build a primitive_desc_iterator from the concrete mkldnn descriptor
// held by the implementing class. DescFwdImpl and DescBwdImpl override it.
// Argument order is (attr, engine).
67 virtual mkldnn::primitive_desc_iterator createPrimitiveDescriptorIterator(const mkldnn::primitive_attr &attr,
68 const mkldnn::engine &engine) const = 0;
72 class DescFwdImpl: public IDesc {
73 std::shared_ptr<T> desc;
75 explicit DescFwdImpl(std::shared_ptr<T> d) : desc(d) {}
// Passes the stored descriptor, `attr` and `engine` straight to mkldnn's
// primitive_desc_iterator constructor.
// `desc` is dereferenced without a check, so it must be non-null.
77 mkldnn::primitive_desc_iterator createPrimitiveDescriptorIterator(const mkldnn::primitive_attr &attr,
78 const mkldnn::engine &engine) const override {
79 return mkldnn::primitive_desc_iterator(*desc, attr, engine);
82 std::shared_ptr<T>& getPtr() {
88 template <class T, class P>
89 class DescBwdImpl: public IDesc {
90 std::shared_ptr<T> desc;
91 std::shared_ptr<P> prim;
94 DescBwdImpl(std::shared_ptr<T> d, std::shared_ptr<P> p) : desc(d), prim(p) {}
// Same as the forward variant, but additionally passes the stored primitive
// descriptor `*prim` as the fourth primitive_desc_iterator constructor argument.
// Both `desc` and `prim` are dereferenced unchecked and must be non-null.
96 mkldnn::primitive_desc_iterator createPrimitiveDescriptorIterator(const mkldnn::primitive_attr &attr,
97 const mkldnn::engine &engine) const override {
98 return mkldnn::primitive_desc_iterator(*desc, attr, engine, *prim);
101 std::shared_ptr<T>& getPtr() {
105 std::shared_ptr<P>& getPrimPtr() {
// The wrapped descriptor, held behind the abstract IDesc interface.
// NOTE(review): presumably set by the MKLDNNDescriptor constructors and inspected by
// the conversion operators -- confirm in the implementation file.
110 std::shared_ptr<IDesc> desc;