1 /*******************************************************************************
2 * Copyright 2017 Intel Corporation
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
17 #include "mkldnn_types.h"
18 #include "mkldnn_thread.hpp"
21 #include "jit_generator.hpp"
22 #include "type_helpers.hpp"
24 #include "jit_uni_softmax.hpp"
30 using namespace Xbyak;
31 using namespace mkldnn::impl::status;
32 using namespace mkldnn::impl::memory_format;
33 using namespace mkldnn::impl::utils;
35 template <cpu_isa_t isa>
36 jit_uni_softmax_fwd_t<isa>::jit_uni_softmax_fwd_t(const pd_t *apd,
37 const input_vector &inputs, const output_vector &outputs)
38 : cpu_primitive_t(apd, inputs, outputs)
40 kernel_ = new jit_uni_softmax_kernel_f32<isa>(pd()->jpp_);
43 template <cpu_isa_t isa>
44 jit_uni_softmax_fwd_t<isa>::~jit_uni_softmax_fwd_t() {
48 template <cpu_isa_t isa>
49 void jit_uni_softmax_fwd_t<isa>::execute_forward() const
51 auto src = reinterpret_cast<const data_t *>(this->input_memory(0));
52 auto dst = reinterpret_cast<data_t *>(this->memory(0));
54 const memory_desc_wrapper data_d(pd()->src_pd());
56 const auto &jpp = pd()->jpp_;
58 size_t outer_size = utils::array_product(pd()->src_pd()->desc()->dims, pd()->desc()->softmax_axis);
60 size_t dim = jpp.channels * jpp.inner_size;
62 if (jpp.inner_size > 1) {
63 auto ker = [&](const int ithr, const int nthr) {
64 size_t start{0}, end{0};
66 const size_t work_amount = outer_size;
67 balance211(work_amount, nthr, ithr, start, end);
70 nd_iterator_init(start, ou, outer_size);
72 for (size_t iwork = start; iwork < end; ++iwork) {
73 auto args = jit_softmax_call_s();
74 args.channels = jpp.channels;
75 args.work = jpp.inner_size;
76 size_t off = data_d.off_l(ou * dim);
82 nd_iterator_step(ou, outer_size);
88 auto ker = [&](const int ithr, const int nthr) {
89 size_t start{0}, end{0};
91 int ou_blocks = div_up(outer_size, jpp.outer_block);
93 const size_t work_amount = ou_blocks;
94 balance211(work_amount, nthr, ithr, start, end);
97 nd_iterator_init(start, oub, ou_blocks);
99 for (size_t iwork = start; iwork < end; ++iwork) {
100 size_t work = nstl::min(jpp.outer_block, outer_size - oub * jpp.outer_block);
102 auto args = jit_softmax_call_s();
103 args.channels = jpp.channels;
105 size_t off = data_d.off_l(oub * jpp.outer_block * dim);
106 args.src = src + off;
107 args.dst = dst + off;
111 nd_iterator_step(oub, ou_blocks);
// Explicit instantiations for every ISA the JIT softmax kernel supports.
template struct jit_uni_softmax_fwd_t<sse42>;
template struct jit_uni_softmax_fwd_t<avx2>;
template struct jit_uni_softmax_fwd_t<avx512_common>;