Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / src / cpu / jit_uni_softmax.cpp
1 /*******************************************************************************
2 * Copyright 2017 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16
17 #include "mkldnn_types.h"
18 #include "mkldnn_thread.hpp"
19 #include "nstl.hpp"
20 #include "utils.hpp"
21 #include "jit_generator.hpp"
22 #include "type_helpers.hpp"
23
24 #include "jit_uni_softmax.hpp"
25
26 namespace mkldnn {
27 namespace impl {
28 namespace cpu {
29
30 using namespace Xbyak;
31 using namespace mkldnn::impl::status;
32 using namespace mkldnn::impl::memory_format;
33 using namespace mkldnn::impl::utils;
34
35 template <cpu_isa_t isa>
36 jit_uni_softmax_fwd_t<isa>::jit_uni_softmax_fwd_t(const pd_t *apd,
37         const input_vector &inputs, const output_vector &outputs)
38         : cpu_primitive_t(apd, inputs, outputs)
39 {
40     kernel_ = new jit_uni_softmax_kernel_f32<isa>(pd()->jpp_);
41 }
42
43 template <cpu_isa_t isa>
44 jit_uni_softmax_fwd_t<isa>::~jit_uni_softmax_fwd_t() {
45     delete kernel_;
46 }
47
48 template <cpu_isa_t isa>
49 void jit_uni_softmax_fwd_t<isa>::execute_forward() const
50 {
51     auto src = reinterpret_cast<const data_t *>(this->input_memory(0));
52     auto dst = reinterpret_cast<data_t *>(this->memory(0));
53
54     const memory_desc_wrapper data_d(pd()->src_pd());
55
56     const auto &jpp = pd()->jpp_;
57
58     size_t outer_size = utils::array_product(pd()->src_pd()->desc()->dims, pd()->desc()->softmax_axis);
59
60     size_t dim = jpp.channels * jpp.inner_size;
61
62     if (jpp.inner_size > 1) {
63         auto ker = [&](const int ithr, const int nthr) {
64             size_t start{0}, end{0};
65
66             const size_t work_amount = outer_size;
67             balance211(work_amount, nthr, ithr, start, end);
68
69             size_t ou{0};
70             nd_iterator_init(start, ou, outer_size);
71
72             for (size_t iwork = start; iwork < end; ++iwork) {
73                 auto args = jit_softmax_call_s();
74                 args.channels = jpp.channels;
75                 args.work = jpp.inner_size;
76                 size_t off = data_d.off_l(ou * dim);
77                 args.src = src + off;
78                 args.dst = dst + off;
79
80                 (*kernel_)(&args);
81
82                 nd_iterator_step(ou, outer_size);
83             }
84         };
85
86         parallel(0, ker);
87     } else {
88         auto ker = [&](const int ithr, const int nthr) {
89             size_t start{0}, end{0};
90
91             int ou_blocks = div_up(outer_size, jpp.outer_block);
92
93             const size_t work_amount = ou_blocks;
94             balance211(work_amount, nthr, ithr, start, end);
95
96             size_t oub{0};
97             nd_iterator_init(start, oub, ou_blocks);
98
99             for (size_t iwork = start; iwork < end; ++iwork) {
100                 size_t work = nstl::min(jpp.outer_block, outer_size - oub * jpp.outer_block);
101
102                 auto args = jit_softmax_call_s();
103                 args.channels = jpp.channels;
104                 args.work = work;
105                 size_t off = data_d.off_l(oub * jpp.outer_block * dim);
106                 args.src = src + off;
107                 args.dst = dst + off;
108
109                 (*kernel_)(&args);
110
111                 nd_iterator_step(oub, ou_blocks);
112             }
113         };
114
115         parallel(0, ker);
116     }
117 }
118
119 template struct jit_uni_softmax_fwd_t<sse42>;
120 template struct jit_uni_softmax_fwd_t<avx2>;
121 template struct jit_uni_softmax_fwd_t<avx512_common>;
122
123 }
124 }
125 }