Publishing R3
[platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / tests / benchdnn / bnorm / ref_bnorm.cpp
1 /*******************************************************************************
2 * Copyright 2017-2018 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16
17 #include "bnorm/bnorm.hpp"
18
19 namespace bnorm {
20
21 void compute_ref_fwd(const prb_t *p, const dnn_mem_t &src, dnn_mem_t &mean,
22         dnn_mem_t &var, const dnn_mem_t &ss, dnn_mem_t &dst) {
23     auto maybe_post_ops = [&](float &bn_res, float dst) {
24         const auto &ops = p->attr.post_ops;
25         for (int idx = 0; idx < ops.len; ++idx) {
26             using pk = attr_t::post_ops_t::kind_t;
27             const auto &e = ops.entry[idx];
28             switch (e.kind) {
29             case pk::SUM:
30                 bn_res += e.sum.scale * dst;
31                 break;
32             case pk::RELU:
33                 bn_res = e.eltwise.scale * (bn_res < 0 ? 0 : bn_res);
34                 break;
35             default:
36                 assert(!"unknown attr::post_ops::kind");
37             }
38         }
39     };
40 #   pragma omp parallel for
41     for (int c = 0; c < p->ic; ++c) {
42         float smean = ((float *)mean)[c];
43         float svar = ((float *)var)[c];
44         float rcp_denom = (float)(1.0f / (sqrtf(svar + p->eps)));
45
46         float gamma = p->flags & USE_SCALESHIFT ? ((float *)ss)[c] : 1;
47         float beta = p->flags & USE_SCALESHIFT ? ((float *)ss)[p->ic + c] : 0;
48
49         for (int mb = 0; mb < p->mb; ++mb)
50         for (int d = 0; d < p->id; ++d)
51         for (int h = 0; h < p->ih; ++h)
52         for (int w = 0; w < p->iw; ++w) {
53             auto off = data_off(p, mb, c, d, h, w);
54             float res = gamma * (((float *)src)[off] - smean) * rcp_denom + beta;
55             float &D = ((float *)dst)[off];
56             if ((p->flags & FUSE_BN_RELU) && res < 0) res = 0;
57             maybe_post_ops(res, D);
58             D = res;
59         }
60     }
61 }
62
63 void compute_ref_bwd(const prb_t *p, const dnn_mem_t &src,
64         const dnn_mem_t &mean, const dnn_mem_t &var, const dnn_mem_t &d_dst,
65         const dnn_mem_t &ss, const dnn_mem_t &rmask, dnn_mem_t &d_src,
66         dnn_mem_t &d_ss) {
67     const float NHW = p->mb * p->id * p->ih * p->iw;
68
69 #   pragma omp parallel for
70     for (int c = 0; c < p->ic; ++c) {
71         float smean = ((float *)mean)[c];
72         float svar = ((float *)var)[c];
73         float rcp_denom = 1.f / sqrtf(svar + p->eps);
74
75         float gamma = p->flags & USE_SCALESHIFT ? ((float *)ss)[c] : 1;
76
77         float d_gamma = 0;
78         float d_beta = 0;
79
80         for (int mb = 0; mb < p->mb; ++mb)
81         for (int d = 0; d < p->id; ++d)
82         for (int h = 0; h < p->ih; ++h)
83         for (int w = 0; w < p->iw; ++w) {
84             auto off = data_off(p, mb, c, d, h, w);
85             float dd = ((float *)d_dst)[off];
86             if ((p->flags & FUSE_BN_RELU) && ((float *)rmask)[off] == 0)
87                 dd = 0;
88
89             d_gamma += dd * (((float *)src)[off] - smean);
90             d_beta += dd;
91         }
92         d_gamma *= rcp_denom;
93
94         if ((p->flags & USE_SCALESHIFT) && (p->dir & FLAG_WEI)) {
95             ((float *)d_ss)[c] = d_gamma;
96             ((float *)d_ss)[p->ic + c] = d_beta;
97         }
98
99         for (int mb = 0; mb < p->mb; ++mb)
100         for (int d = 0; d < p->id; ++d)
101         for (int h = 0; h < p->ih; ++h)
102         for (int w = 0; w < p->iw; ++w) {
103             auto off = data_off(p, mb, c, d, h, w);
104             float dd = ((float *)d_dst)[off];
105             if ((p->flags & FUSE_BN_RELU) && ((float *)rmask)[off] == 0)
106                 dd = 0;
107             float ds = dd;
108
109             if (!(p->flags & GLOB_STATS)) {
110                 const float x = ((float *)src)[off] - smean;
111                 ds -= (d_beta + x * d_gamma * rcp_denom) / NHW;
112             }
113
114             ((float *)d_src)[off] = rcp_denom * ds * gamma;
115         }
116     }
117 }
118
119 }