Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / src / cpu / ref_softmax.hpp
1 /*******************************************************************************
2 * Copyright 2016-2018 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16
17 #ifndef CPU_REF_SOFTMAX_HPP
18 #define CPU_REF_SOFTMAX_HPP
19
20 #include <assert.h>
21
22 #include "c_types_map.hpp"
23 #include "memory_tracking.hpp"
24 #include "type_helpers.hpp"
25 #include "utils.hpp"
26
27 #include "cpu_softmax_pd.hpp"
28
29 namespace mkldnn {
30 namespace impl {
31 namespace cpu {
32
33 template <impl::data_type_t data_type>
34 struct ref_softmax_fwd_t: public cpu_primitive_t {
35     struct pd_t: public cpu_softmax_fwd_pd_t {
36         pd_t(engine_t *engine, const softmax_desc_t *adesc,
37                 const primitive_attr_t *attr,
38                 const softmax_fwd_pd_t *hint_fwd_pd)
39             : cpu_softmax_fwd_pd_t(engine, adesc, attr, hint_fwd_pd) {}
40
41         DECLARE_COMMON_PD_T("ref:any", ref_softmax_fwd_t);
42
43         virtual status_t init() override {
44             using namespace prop_kind;
45             assert(engine()->kind() == engine_kind::cpu);
46             bool ok = true
47                 && utils::one_of(desc()->prop_kind, forward_inference,
48                         forward_training)
49                 && data_pd_.desc()->data_type == data_type
50                 && attr()->has_default_values();
51             if (!ok) return status::unimplemented;
52
53             init_scratchpad();
54
55             return status::success;
56         }
57
58     private:
59         void init_scratchpad() {
60             const int inner_size = utils::array_product(
61                     desc()->data_desc.dims + desc()->softmax_axis + 1,
62                     desc()->data_desc.ndims - desc()->softmax_axis - 1);
63
64             if (inner_size > 1) {
65                 auto scratchpad = scratchpad_registry().registrar();
66                 scratchpad.book(memory_tracking::names::key_softmax_reduction,
67                         sizeof(data_t) * 2 * inner_size);
68             }
69         }
70     };
71
72     ref_softmax_fwd_t(const pd_t *apd, const input_vector &inputs,
73             const output_vector &outputs)
74         : cpu_primitive_t(apd, inputs, outputs)
75     {
76         auto ndims = pd()->desc()->data_desc.ndims;
77         auto dims = pd()->desc()->data_desc.dims;
78         auto axis = pd()->desc()->softmax_axis;
79
80         outer_size_ = utils::array_product(dims, axis);
81         channels_ = dims[axis];
82         inner_size_ = utils::array_product(dims + axis + 1, ndims - axis - 1);
83
84         const memory_desc_wrapper data_d(pd()->src_pd());
85         use_dense_ = inner_size_ == 1 && data_d.is_dense()
86             && data_d.blocking_desc().block_dims[axis] == 1
87             && data_d.blocking_desc().strides[0][axis] == 1;
88     }
89     ~ref_softmax_fwd_t() {}
90
91     typedef typename prec_traits<data_type>::type data_t;
92
93     virtual void execute(event_t *e) const {
94         if (use_dense_) execute_forward_dense();
95         else execute_forward_generic();
96         e->set_state(event_t::ready);
97     }
98
99 private:
100     void execute_forward_dense() const;
101     void execute_forward_generic() const;
102
103     void _max(int n, const data_t *x, data_t *max_data) const;
104     void _sub(int n, data_t alpha, const data_t *x, data_t *y) const;
105     void _exp(int n, const data_t *a, data_t *r) const;
106     void _exp_parallel(int n, const data_t *a, data_t *r) const;
107     void _sum(int n, const data_t *x, data_t *sum_data) const;
108     void _scal(int n, data_t alpha, data_t *x) const;
109     void _scal_parallel(int n, data_t alpha, data_t *x) const;
110
111     const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }
112
113     bool use_dense_;
114     int outer_size_, channels_, inner_size_;
115 };
116
117 template <impl::data_type_t data_type>
118 struct ref_softmax_bwd_t: public cpu_primitive_t {
119     struct pd_t: public cpu_softmax_bwd_pd_t {
120         pd_t(engine_t *engine, const softmax_desc_t *adesc,
121                 const primitive_attr_t *attr,
122                 const softmax_fwd_pd_t *hint_fwd_pd)
123             : cpu_softmax_bwd_pd_t(engine, adesc, attr, hint_fwd_pd) {}
124
125         DECLARE_COMMON_PD_T("ref:any", ref_softmax_bwd_t);
126
127         virtual status_t init() override {
128             using namespace prop_kind;
129             assert(engine()->kind() == engine_kind::cpu);
130             bool ok = true
131                 && utils::one_of(desc()->prop_kind, backward_data)
132                 && diff_src_pd_.desc()->data_type == data_type
133                 && diff_dst_pd_.desc()->data_type == data_type
134                 && attr()->has_default_values();
135             if (!ok) return status::unimplemented;
136
137             return status::success;
138         }
139     };
140
141     ref_softmax_bwd_t(const pd_t *apd, const input_vector &inputs,
142             const output_vector &outputs)
143         : cpu_primitive_t(apd, inputs, outputs) {
144         auto dims = pd()->desc()->diff_desc.dims;
145         auto axis = pd()->desc()->softmax_axis;
146         auto ndims = pd()->desc()->diff_desc.ndims;
147
148         outer_size_ = utils::array_product(dims, axis);
149         channels_ = dims[axis];
150         inner_size_ = utils::array_product(dims + axis + 1, ndims - axis - 1);
151
152         // Diff desc as well as data desc whould be checked
153         const memory_desc_wrapper data_d(pd()->dst_pd());
154         const memory_desc_wrapper diff_d(pd()->diff_dst_pd());
155         use_dense_ = true
156             && inner_size_ == 1
157             && diff_d == data_d
158             && diff_d.is_dense()
159             && diff_d.blocking_desc().block_dims[axis] == 1
160             && diff_d.blocking_desc().strides[0][axis] == 1;
161     }
162     ~ref_softmax_bwd_t() {}
163
164     typedef typename prec_traits<data_type>::type data_t;
165
166     virtual void execute(event_t *e) const {
167         if (use_dense_) execute_backward_dense();
168         else execute_backward_generic();
169         e->set_state(event_t::ready);
170     }
171
172 private:
173     void execute_backward_dense() const;
174     void execute_backward_generic() const;
175     const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }
176
177     bool use_dense_;
178     int outer_size_, channels_, inner_size_;
179 };
180
181
182 }
183 }
184 }
185
186 #endif
187
188 // vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s