Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / src / extension / common / softmax.h
1 // Copyright (C) 2018-2019 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
3 //
4
5 #pragma once
6
7 #define USE_FAST_EXP 0
8
9 #if USE_FAST_EXP
10 #include "fast_exp.h"
11 #else
12
13 #include "opt_exp.h"
14
15 #endif
16
17 #include <cmath>
18 #include "defs.h"
19 #include "ie_parallel.hpp"
20
21
22 static inline
23 void softmax_many_batches(const float *src_data, float *dst_data, int B, int C, int H, int W) {
24     InferenceEngine::parallel_for(B * H * W, [&](size_t i) {
25         const float *psrc = src_data + (i / (H * W)) * C * H * W - (i / (H * W)) * H * W;
26         float *pdst = dst_data + (i / (H * W)) * C * H * W - (i / (H * W)) * H * W;
27
28         float max = psrc[i];
29         for (int c = 0; c < C; c++) {
30             float val = psrc[c * H * W + i];
31             if (val > max) max = val;
32         }
33
34         float expSum = 0;
35         for (int c = 0; c < C; c++) {
36             pdst[c * H * W + i] = exp(psrc[c * H * W + i] - max);
37             expSum += pdst[c * H * W + i];
38         }
39
40         for (int c = 0; c < C; c++) {
41             pdst[c * H * W + i] = pdst[c * H * W + i] / expSum;
42         }
43     });
44 }
45
46 static inline
47 void softmax_generic(const float *src_data, float *dst_data, int B, int C, int H, int W) {
48     for (int b = 0; b < B; b++) {
49 #if defined(HAVE_AVX2)
50         for (int i = 0; i <= H*W - 8; i += 8) {
51             __m256 vmax = _mm256_loadu_ps(src_data + b*C*H*W + i);
52             for (int c = 0; c < C; c++) {
53                 __m256 vval = _mm256_loadu_ps(src_data + b*C*H*W + c*H*W + i);
54                 __m256 vmask = _mm256_cmp_ps(vval, vmax, _CMP_GT_OS);
55                 vmax = _mm256_blendv_ps(vmax, vval, vmask);
56             }
57
58             __m256 vexpSum = _mm256_setzero_ps();
59             for (int c = 0; c < C; c++) {
60                 __m256 vval = _mm256_loadu_ps(src_data + b*C*H*W + c*H*W + i);
61 #if USE_FAST_EXP
62                 __m256 vres = _avx_fast_exp_ps(_mm256_sub_ps(vval, vmax));
63 #else
64                 __m256 vres = _avx_opt_exp_ps(_mm256_sub_ps(vval, vmax));
65 #endif
66                 vexpSum = _mm256_add_ps(vexpSum, vres);
67                 _mm256_storeu_ps(dst_data + b*C*H*W + c*H*W + i, vres);
68             }
69
70             for (int c = 0; c < C; c++) {
71                 __m256 vval = _mm256_loadu_ps(dst_data + b*C*H*W + c*H*W + i);
72                 _mm256_storeu_ps(dst_data + b*C*H*W + c*H*W + i, _mm256_div_ps(vval, vexpSum));
73             }
74         }
75 #elif defined(HAVE_SSE)
76         for (int i = 0; i <= H*W - 4; i += 4) {
77             __m128 vmax = _mm_loadu_ps(src_data + b*C*H*W + i);
78             for (int c = 0; c < C; c++) {
79                 __m128 vval = _mm_loadu_ps(src_data + b*C*H*W + c*H*W + i);
80                 __m128 vmask = _mm_cmpgt_ps(vval, vmax);
81                 vmax = _mm_blendv_ps(vmax, vval, vmask);
82             }
83
84             __m128 vexpSum = _mm_setzero_ps();
85             for (int c = 0; c < C; c++) {
86                 __m128 vval = _mm_loadu_ps(src_data + b*C*H*W + c*H*W + i);
87 #if USE_FAST_EXP
88                 __m128 vres = _sse_fast_exp_ps(_mm_sub_ps(vval, vmax));
89 #else
90                 __m128 vres = _sse_opt_exp_ps(_mm_sub_ps(vval, vmax));
91 #endif
92                 vexpSum = _mm_add_ps(vexpSum, vres);
93                 _mm_storeu_ps(dst_data + b*C*H*W + c*H*W + i, vres);
94             }
95
96             for (int c = 0; c < C; c++) {
97                 __m128 vval = _mm_loadu_ps(dst_data + b*C*H*W + c*H*W + i);
98                 _mm_storeu_ps(dst_data + b*C*H*W + c*H*W + i, _mm_div_ps(vval, vexpSum));
99             }
100         }
101 #endif
102
103 #if defined(HAVE_AVX2)
104         int start = (H*W / 8) * 8;
105 #elif defined(HAVE_SSE)
106         int start = (H*W / 4) * 4;
107 #else
108         int start = 0;
109 #endif
110         for (int i = start; i < H * W; i++) {
111             float max = src_data[b * C * H * W + i];
112             for (int c = 0; c < C; c++) {
113                 float val = src_data[b * C * H * W + c * H * W + i];
114                 if (val > max) max = val;
115             }
116
117             float expSum = 0;
118             for (int c = 0; c < C; c++) {
119                 dst_data[b * C * H * W + c * H * W + i] = exp(src_data[b * C * H * W + c * H * W + i] - max);
120                 expSum += dst_data[b * C * H * W + c * H * W + i];
121             }
122
123             for (int c = 0; c < C; c++) {
124                 dst_data[b * C * H * W + c * H * W + i] = dst_data[b * C * H * W + c * H * W + i] / expSum;
125             }
126         }
127     }
128 }