1 // Copyright (C) 2018-2019 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
19 #include "ie_parallel.hpp"
23 void softmax_many_batches(const float *src_data, float *dst_data, int B, int C, int H, int W) {
24 InferenceEngine::parallel_for(B * H * W, [&](size_t i) {
25 const float *psrc = src_data + (i / (H * W)) * C * H * W - (i / (H * W)) * H * W;
26 float *pdst = dst_data + (i / (H * W)) * C * H * W - (i / (H * W)) * H * W;
29 for (int c = 0; c < C; c++) {
30 float val = psrc[c * H * W + i];
31 if (val > max) max = val;
35 for (int c = 0; c < C; c++) {
36 pdst[c * H * W + i] = exp(psrc[c * H * W + i] - max);
37 expSum += pdst[c * H * W + i];
40 for (int c = 0; c < C; c++) {
41 pdst[c * H * W + i] = pdst[c * H * W + i] / expSum;
47 void softmax_generic(const float *src_data, float *dst_data, int B, int C, int H, int W) {
48 for (int b = 0; b < B; b++) {
49 #if defined(HAVE_AVX2)
50 for (int i = 0; i <= H*W - 8; i += 8) {
51 __m256 vmax = _mm256_loadu_ps(src_data + b*C*H*W + i);
52 for (int c = 0; c < C; c++) {
53 __m256 vval = _mm256_loadu_ps(src_data + b*C*H*W + c*H*W + i);
54 __m256 vmask = _mm256_cmp_ps(vval, vmax, _CMP_GT_OS);
55 vmax = _mm256_blendv_ps(vmax, vval, vmask);
58 __m256 vexpSum = _mm256_setzero_ps();
59 for (int c = 0; c < C; c++) {
60 __m256 vval = _mm256_loadu_ps(src_data + b*C*H*W + c*H*W + i);
62 __m256 vres = _avx_fast_exp_ps(_mm256_sub_ps(vval, vmax));
64 __m256 vres = _avx_opt_exp_ps(_mm256_sub_ps(vval, vmax));
66 vexpSum = _mm256_add_ps(vexpSum, vres);
67 _mm256_storeu_ps(dst_data + b*C*H*W + c*H*W + i, vres);
70 for (int c = 0; c < C; c++) {
71 __m256 vval = _mm256_loadu_ps(dst_data + b*C*H*W + c*H*W + i);
72 _mm256_storeu_ps(dst_data + b*C*H*W + c*H*W + i, _mm256_div_ps(vval, vexpSum));
75 #elif defined(HAVE_SSE)
76 for (int i = 0; i <= H*W - 4; i += 4) {
77 __m128 vmax = _mm_loadu_ps(src_data + b*C*H*W + i);
78 for (int c = 0; c < C; c++) {
79 __m128 vval = _mm_loadu_ps(src_data + b*C*H*W + c*H*W + i);
80 __m128 vmask = _mm_cmpgt_ps(vval, vmax);
81 vmax = _mm_blendv_ps(vmax, vval, vmask);
84 __m128 vexpSum = _mm_setzero_ps();
85 for (int c = 0; c < C; c++) {
86 __m128 vval = _mm_loadu_ps(src_data + b*C*H*W + c*H*W + i);
88 __m128 vres = _sse_fast_exp_ps(_mm_sub_ps(vval, vmax));
90 __m128 vres = _sse_opt_exp_ps(_mm_sub_ps(vval, vmax));
92 vexpSum = _mm_add_ps(vexpSum, vres);
93 _mm_storeu_ps(dst_data + b*C*H*W + c*H*W + i, vres);
96 for (int c = 0; c < C; c++) {
97 __m128 vval = _mm_loadu_ps(dst_data + b*C*H*W + c*H*W + i);
98 _mm_storeu_ps(dst_data + b*C*H*W + c*H*W + i, _mm_div_ps(vval, vexpSum));
103 #if defined(HAVE_AVX2)
104 int start = (H*W / 8) * 8;
105 #elif defined(HAVE_SSE)
106 int start = (H*W / 4) * 4;
110 for (int i = start; i < H * W; i++) {
111 float max = src_data[b * C * H * W + i];
112 for (int c = 0; c < C; c++) {
113 float val = src_data[b * C * H * W + c * H * W + i];
114 if (val > max) max = val;
118 for (int c = 0; c < C; c++) {
119 dst_data[b * C * H * W + c * H * W + i] = exp(src_data[b * C * H * W + c * H * W + i] - max);
120 expSum += dst_data[b * C * H * W + c * H * W + i];
123 for (int c = 0; c < C; c++) {
124 dst_data[b * C * H * W + c * H * W + i] = dst_data[b * C * H * W + c * H * W + i] / expSum;