Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / src / cpu / jit_avx512_core_x8s8s32x_1x1_deconvolution.hpp
1
2 /*******************************************************************************
3 * Copyright 2018 Intel Corporation
4 *
5 * Licensed under the Apache License, Version 2.0 (the "License");
6 * you may not use this file except in compliance with the License.
7 * You may obtain a copy of the License at
8 *
9 *     http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 *******************************************************************************/
17
18 #ifndef CPU_JIT_AVX512_CORE_X8S8S32X_1X1_DECONVOLUTION_HPP
19 #define CPU_JIT_AVX512_CORE_X8S8S32X_1X1_DECONVOLUTION_HPP
20
21 #include "c_types_map.hpp"
22 #include "cpu_deconvolution_pd.hpp"
23 #include "cpu_engine.hpp"
24 #include "cpu_reducer.hpp"
25 #include "mkldnn_thread.hpp"
26 #include "utils.hpp"
27 #include "cpu_convolution_pd.hpp"
28 #include "type_helpers.hpp"
29 #include "primitive_iterator.hpp"
30
31 #include "jit_uni_1x1_conv_utils.hpp"
32 #include "jit_avx512_core_x8s8s32x_1x1_convolution.hpp"
33
34 namespace mkldnn {
35 namespace impl {
36 namespace cpu {
37
38 template <impl::data_type_t src_type, impl::data_type_t dst_type>
39 struct jit_avx512_core_x8s8s32x_1x1_deconvolution_fwd_t
40         : public cpu_primitive_t {
41     struct pd_t : public cpu_deconvolution_fwd_pd_t {
42         pd_t(engine_t *engine, const deconvolution_desc_t *adesc,
43                 const primitive_attr_t *attr,
44                 const deconvolution_fwd_pd_t *hint_fwd_pd)
45             : cpu_deconvolution_fwd_pd_t(engine, adesc, attr, hint_fwd_pd)
46             , conv_pd_(nullptr) {}
47
48         pd_t(const pd_t &other)
49             : cpu_deconvolution_fwd_pd_t(other)
50             , conv_pd_(other.conv_pd_->clone())
51             , conv_supports_bias_(other.conv_supports_bias_) {}
52
53         ~pd_t() { delete conv_pd_; }
54
55         DECLARE_DECONVOLUTION_PD_T(
56                 jit_avx512_core_x8s8s32x_1x1_deconvolution_fwd_t<src_type,
57                         dst_type>);
58
59         status_t init_convolution() {
60
61             convolution_desc_t cd;
62             status_t status;
63
64             auto dd = this->desc();
65             status = conv_desc_init(&cd, prop_kind::forward_training,
66                     alg_kind::convolution_direct, &(dd->src_desc),
67                     &(dd->weights_desc), &(dd->bias_desc), &(dd->dst_desc),
68                     dd->strides, dd->dilates, dd->padding[0], dd->padding[1],
69                     dd->padding_kind);
70
71             if (status == status::success) {
72                 status = mkldnn_primitive_desc::create<
73                         typename mkldnn::impl::cpu::
74                                 jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t<src_type,
75                                         dst_type>::pd_t>(&conv_pd_,
76                         (op_desc_t *)&cd, &(this->attr_), this->engine_,
77                         nullptr);
78             }
79
80             if (status == status::success) {
81                 status = set_default_params();
82             }
83
84             return status;
85         };
86
87         virtual status_t init() override {
88             using namespace prop_kind;
89             status_t status;
90
91             assert(this->engine()->kind() == engine_kind::cpu);
92             bool ok = true && utils::one_of(this->desc()->prop_kind,
93                                       prop_kind::forward_training,
94                                       prop_kind::forward_inference)
95                     && this->desc()->alg_kind == alg_kind::deconvolution_direct
96                     && !this->has_zero_dim_memory()
97                     && this->desc()->src_desc.data_type == src_type
98                     && this->desc()->dst_desc.data_type == dst_type
99                     && this->desc()->weights_desc.data_type == data_type::s8
100                     && IMPLICATION(this->with_bias(),
101                                utils::one_of(this->desc()->bias_desc.data_type,
102                                            data_type::f32, data_type::s32,
103                                            data_type::s8, data_type::u8))
104                     && this->desc()->accum_data_type == data_type::s32;
105
106             if (ok)
107                 status = init_convolution();
108             else
109                 status = status::unimplemented;
110
111             return status;
112         }
113
114     protected:
115         virtual status_t set_default_params() {
116             using namespace memory_format;
117             auto conv_1x1_pd_ = static_cast<typename mkldnn::impl::cpu::
118                             jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t<src_type,
119                                     dst_type>::pd_t *>(conv_pd_);
120             CHECK(this->src_pd_.set_format(
121                     conv_1x1_pd_->src_pd()->desc()->format));
122             CHECK(this->dst_pd_.set_format(
123                     conv_1x1_pd_->dst_pd()->desc()->format));
124             CHECK(this->weights_pd_.set_format(
125                     conv_1x1_pd_->weights_pd()->desc()->format));
126             if (this->with_bias())
127                 CHECK(this->bias_pd_.set_format(
128                         conv_1x1_pd_->weights_pd(1)->desc()->format));
129             return status::success;
130         }
131
132         primitive_desc_t *conv_pd_;
133         bool conv_supports_bias_;
134     };
135
136     jit_avx512_core_x8s8s32x_1x1_deconvolution_fwd_t(const pd_t *apd,
137             const input_vector &inputs, const output_vector &outputs)
138         : cpu_primitive_t(apd, inputs, outputs), conv_p_(nullptr) {}
139
140     ~jit_avx512_core_x8s8s32x_1x1_deconvolution_fwd_t() {
141         delete this->conv_p_;
142     }
143
144     virtual void execute(event_t *e) const {
145         switch (pd()->desc()->prop_kind) {
146         case prop_kind::forward_training:
147         case prop_kind::forward_inference: (conv_p_)->execute(e); break;
148         default: assert(!"invalid prop_kind");
149         }
150         e->set_state(event_t::ready);
151     }
152
153 private:
154     const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }
155     primitive_t *conv_p_;
156 };
157
158 }
159 }
160 }
161
162 #endif /* CPU_JIT_AVX512_CORE_X8S8S32X_1X1_DECONVOLUTION_HPP */