1 /*******************************************************************************
2 * Copyright 2018 Intel Corporation
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
17 #ifndef CPU_DECONVOLUTION_FWD_PD_HPP
18 #define CPU_DECONVOLUTION_FWD_PD_HPP
22 #include "c_types_map.hpp"
23 #include "deconvolution_pd.hpp"
24 #include "convolution_pd.hpp"
25 #include "cpu_engine.hpp"
26 #include "cpu_memory.hpp"
27 #include "cpu_primitive.hpp"
28 #include "type_helpers.hpp"
/* DECLARE_DECONVOLUTION_PD_t(impl_type...)
 *
 * Boilerplate member definitions meant to be expanded inside the `pd_t` of a
 * CPU deconvolution implementation. The macro argument(s) name the
 * implementation (primitive) type: it is used both in `new (...)(...)` and in
 * a pointer cast. The argument list is variadic, presumably so that a type
 * spelled with commas (e.g. a template-id) can be passed -- TODO confirm.
 *
 * The expansion defines:
 *  - clone(): returns a heap-allocated copy of *this (pd_t copy constructor).
 *  - create_primitive(primitive, inputs, outputs):
 *      1. copies n_inputs() inputs and n_outputs() outputs into vectors,
 *         constructs `new impl_type(this, ins, outs)` and stores it in
 *         *primitive through safe_ptr_assign (the status is kept in `ret`);
 *      2. creates the wrapped convolution primitive through conv_pd_; for
 *         prop_kind::backward_weights only the first two inputs are passed,
 *         in swapped order (inputs[1], inputs[0]);
 *      3. stores that convolution primitive in the new object's conv_p_;
 *      4. if mkldnn_verbose()->level >= 2, prints
 *         "mkldnn_verbose,create,<info>,<ms>" where <ms> is the get_msec()
 *         time elapsed since the start of create_primitive (this includes
 *         creating the convolution primitive).
 *  - name(): forwards to conv_pd_->name().
 *
 * Requirements implied by the names used here: the enclosing pd_t needs a
 * copy constructor, a `conv_pd_` pointer to a primitive descriptor, and
 * desc()/n_inputs()/n_outputs()/info(); the implementation type needs an
 * accessible `conv_p_` member assignable from `primitive_t *`.
 *
 * NOTE(review): in this listing the create_primitive expansion looks
 * incomplete. There is no visible opening brace after the signature, and no
 * `else` between the backward_weights branch and the following unconditional
 * conv_pd_->create_primitive(...) call. As shown, that call would also run
 * for backward_weights, overwriting conv_primitive. The closing braces of the
 * two `if`s and of the function are also absent, as is the final `return`.
 * Confirm against the upstream header.
 *
 * NOTE(review): the status returned by conv_pd_->create_primitive is ignored
 * and conv_primitive is not initialized, so on failure conv_p_ may receive an
 * indeterminate pointer. Likewise `ret` is not checked before *primitive is
 * cast (C-style) and dereferenced to set conv_p_.
 */
31 #define DECLARE_DECONVOLUTION_PD_t(...) \
32 virtual pd_t *clone() const override { return new pd_t(*this); } \
33 virtual status_t create_primitive(primitive_t **primitive, \
34 const primitive_at_t *inputs, const primitive_t **outputs) \
36 double ms = get_msec(); \
37 using namespace prop_kind; \
38 primitive_t::input_vector ins(inputs, inputs + this->n_inputs()); \
39 primitive_t::output_vector outs(outputs, outputs + this->n_outputs()); \
40 auto ret = safe_ptr_assign<primitive_t>( \
41 *primitive, new (__VA_ARGS__)(this, ins, outs)); \
42 primitive_t *conv_primitive; \
43 if (this->desc()->prop_kind == backward_weights) { \
44 primitive_at_t conv_inputs[2]; \
45 conv_inputs[0] = inputs[1]; \
46 conv_inputs[1] = inputs[0]; \
47 conv_pd_->create_primitive( \
48 (&conv_primitive), conv_inputs, outputs); \
50 conv_pd_->create_primitive((&conv_primitive), inputs, outputs); \
51 ((__VA_ARGS__ *)(*primitive))->conv_p_ = conv_primitive; \
52 ms = get_msec() - ms; \
53 if (mkldnn_verbose()->level >= 2) { \
54 printf("mkldnn_verbose,create,%s,%g\n", this->info(), ms); \
59 virtual const char *name() const override { return conv_pd_->name(); }
// Upper-case spelling of DECLARE_DECONVOLUTION_PD_t; forwards all macro
// arguments to it unchanged.
61 #define DECLARE_DECONVOLUTION_PD_T(...) DECLARE_DECONVOLUTION_PD_t(__VA_ARGS__)
67 struct cpu_deconvolution_fwd_pd_t: public deconvolution_fwd_pd_t {
68 using cpu_memory_pd_t = cpu_memory_t::pd_t;
70 cpu_deconvolution_fwd_pd_t(engine_t *engine,
71 const deconvolution_desc_t *adesc,
72 const primitive_attr_t *attr,
73 const deconvolution_fwd_pd_t *hint_fwd_pd)
74 : deconvolution_fwd_pd_t(engine, adesc, attr, hint_fwd_pd)
75 , src_pd_(this->engine_, &this->desc_.src_desc)
76 , dst_pd_(this->engine_, &this->desc_.dst_desc)
77 , weights_pd_(this->engine_, &this->desc_.weights_desc)
78 , bias_pd_(this->engine_, &this->desc_.bias_desc) {}
79 virtual ~cpu_deconvolution_fwd_pd_t() {}
81 virtual const cpu_memory_pd_t *src_pd(int index = 0) const override
82 { return index == 0 ? &src_pd_ : nullptr; }
83 virtual const cpu_memory_pd_t *dst_pd(int index = 0) const override
84 { return index == 0 ? &dst_pd_ : nullptr; }
85 virtual const cpu_memory_pd_t *weights_pd(int index = 0) const override {
86 if (index == 0) return &weights_pd_;
87 if (index == 1 && this->with_bias()) return &bias_pd_;
92 cpu_memory_pd_t src_pd_, dst_pd_;
93 cpu_memory_pd_t weights_pd_, bias_pd_;
96 struct cpu_deconvolution_bwd_data_pd_t: public deconvolution_bwd_data_pd_t {
97 using cpu_memory_pd_t = cpu_memory_t::pd_t;
99 cpu_deconvolution_bwd_data_pd_t(engine_t *engine,
100 const deconvolution_desc_t *adesc,
101 const primitive_attr_t *attr,
102 const deconvolution_fwd_pd_t *hint_fwd_pd)
103 : deconvolution_bwd_data_pd_t(engine, adesc, attr, hint_fwd_pd)
104 , diff_src_pd_(this->engine_, &this->desc_.diff_src_desc)
105 , diff_dst_pd_(this->engine_, &this->desc_.diff_dst_desc)
106 , weights_pd_(this->engine_, &this->desc_.weights_desc) {}
107 virtual ~cpu_deconvolution_bwd_data_pd_t() {}
109 virtual const cpu_memory_pd_t *diff_src_pd(int index = 0) const override
110 { return index == 0 ? &diff_src_pd_ : nullptr; }
111 virtual const cpu_memory_pd_t *diff_dst_pd(int index = 0) const override
112 { return index == 0 ? &diff_dst_pd_ : nullptr; }
113 virtual const cpu_memory_pd_t *weights_pd(int index = 0) const override
114 { return index == 0 ? &weights_pd_ : nullptr; }
117 cpu_memory_pd_t diff_src_pd_, diff_dst_pd_;
118 cpu_memory_pd_t weights_pd_;
121 struct cpu_deconvolution_bwd_weights_pd_t: public deconvolution_bwd_weights_pd_t
123 using cpu_memory_pd_t = cpu_memory_t::pd_t;
125 cpu_deconvolution_bwd_weights_pd_t(engine_t *engine,
126 const deconvolution_desc_t *adesc,
127 const primitive_attr_t *attr,
128 const deconvolution_fwd_pd_t *hint_fwd_pd)
129 : deconvolution_bwd_weights_pd_t(engine, adesc, attr, hint_fwd_pd)
130 , src_pd_(this->engine_, &this->desc_.src_desc)
131 , diff_dst_pd_(this->engine_, &this->desc_.diff_dst_desc)
132 , diff_weights_pd_(this->engine_, &this->desc_.diff_weights_desc)
133 , diff_bias_pd_(this->engine_, &this->desc_.diff_bias_desc) {}
134 virtual ~cpu_deconvolution_bwd_weights_pd_t() {}
136 virtual const cpu_memory_pd_t *src_pd(int index = 0) const override
137 { return index == 0 ? &src_pd_ : nullptr; }
138 virtual const cpu_memory_pd_t *diff_dst_pd(int index = 0) const override
139 { return index == 0 ? &diff_dst_pd_ : nullptr; }
140 virtual const cpu_memory_pd_t *diff_weights_pd(int index = 0) const override
142 if (index == 0) return &diff_weights_pd_;
143 if (index == 1 && this->with_bias()) return &diff_bias_pd_;
148 cpu_memory_pd_t src_pd_;
149 cpu_memory_pd_t diff_dst_pd_;
150 cpu_memory_pd_t diff_weights_pd_, diff_bias_pd_;
159 // vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s