1 /*******************************************************************************
2 * Copyright 2017-2018 Intel Corporation
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
17 #include "bnorm/bnorm.hpp"
// compute_ref_fwd: forward batch normalization computed element by element.
//
// For every channel c in [0, p->ic) and every (mb, d, h, w) position
// addressed by data_off(p, mb, c, d, h, w):
//   res = gamma * (src - mean[c]) * rcp_denom + beta,
//   rcp_denom = 1 / sqrtf(var[c] + p->eps)
// gamma = ss[c] and beta = ss[p->ic + c] when p->flags has USE_SCALESHIFT;
// otherwise gamma = 1 and beta = 0. With FUSE_BN_RELU set, negative results
// are clamped to 0. The post-ops in p->attr.post_ops are then applied
// through maybe_post_ops.
//
// Parameters:
//   p    - problem descriptor (ic/mb/id/ih/iw sizes, eps, flags, attr).
//   src  - input values, read as float at data_off offsets.
//   mean - per-channel mean, read as float at index c.
//   var  - per-channel variance, read as float at index c.
//   ss   - scale/shift storage: gamma at [c], beta at [p->ic + c]. It is
//          read only when USE_SCALESHIFT is set.
//   dst  - output memory. The value already stored at each offset is
//          passed to the post-ops.
//
// NOTE(review): this copy of the file is missing lines. These include the
// switch/case labels in maybe_post_ops, the write of `res` back into `D`,
// and the closing braces. Confirm against the original source before
// relying on or building this code.
21 void compute_ref_fwd(const prb_t *p, const dnn_mem_t &src, dnn_mem_t &mean,
22 dnn_mem_t &var, const dnn_mem_t &ss, dnn_mem_t &dst) {
// Applies each configured post-op, in entry order, to bn_res in place.
// The float `dst` parameter is the value already stored at the output
// location. Inside the lambda it shadows the enclosing dnn_mem_t `dst`.
23 auto maybe_post_ops = [&](float &bn_res, float dst) {
24 const auto &ops = p->attr.post_ops;
25 for (int idx = 0; idx < ops.len; ++idx) {
26 using pk = attr_t::post_ops_t::kind_t;
27 const auto &e = ops.entry[idx];
// sum entry: add the existing output value scaled by e.sum.scale.
30 bn_res += e.sum.scale * dst;
// eltwise entry: clamp negatives to 0, then multiply by e.eltwise.scale.
33 bn_res = e.eltwise.scale * (bn_res < 0 ? 0 : bn_res);
// Any other post-op kind is unsupported by this reference.
36 assert(!"unknown attr::post_ops::kind");
// Per-channel loop, marked for OpenMP parallel execution.
40 # pragma omp parallel for
41 for (int c = 0; c < p->ic; ++c) {
42 float smean = ((float *)mean)[c];
43 float svar = ((float *)var)[c];
// Reciprocal of the standard deviation, so the inner loop multiplies
// instead of dividing.
44 float rcp_denom = (float)(1.0f / (sqrtf(svar + p->eps)));
// Scale/shift default to the identity (1, 0) unless USE_SCALESHIFT is set.
46 float gamma = p->flags & USE_SCALESHIFT ? ((float *)ss)[c] : 1;
47 float beta = p->flags & USE_SCALESHIFT ? ((float *)ss)[p->ic + c] : 0;
// Visit every (mb, d, h, w) position of channel c.
49 for (int mb = 0; mb < p->mb; ++mb)
50 for (int d = 0; d < p->id; ++d)
51 for (int h = 0; h < p->ih; ++h)
52 for (int w = 0; w < p->iw; ++w) {
53 auto off = data_off(p, mb, c, d, h, w);
54 float res = gamma * (((float *)src)[off] - smean) * rcp_denom + beta;
// D aliases the output element. Its current value feeds the post-ops.
55 float &D = ((float *)dst)[off];
// Optional fused ReLU, applied before the post-ops.
56 if ((p->flags & FUSE_BN_RELU) && res < 0) res = 0;
57 maybe_post_ops(res, D);
63 void compute_ref_bwd(const prb_t *p, const dnn_mem_t &src,
64 const dnn_mem_t &mean, const dnn_mem_t &var, const dnn_mem_t &d_dst,
65 const dnn_mem_t &ss, const dnn_mem_t &rmask, dnn_mem_t &d_src,
67 const float NHW = p->mb * p->id * p->ih * p->iw;
69 # pragma omp parallel for
70 for (int c = 0; c < p->ic; ++c) {
71 float smean = ((float *)mean)[c];
72 float svar = ((float *)var)[c];
73 float rcp_denom = 1.f / sqrtf(svar + p->eps);
75 float gamma = p->flags & USE_SCALESHIFT ? ((float *)ss)[c] : 1;
80 for (int mb = 0; mb < p->mb; ++mb)
81 for (int d = 0; d < p->id; ++d)
82 for (int h = 0; h < p->ih; ++h)
83 for (int w = 0; w < p->iw; ++w) {
84 auto off = data_off(p, mb, c, d, h, w);
85 float dd = ((float *)d_dst)[off];
86 if ((p->flags & FUSE_BN_RELU) && ((float *)rmask)[off] == 0)
89 d_gamma += dd * (((float *)src)[off] - smean);
94 if ((p->flags & USE_SCALESHIFT) && (p->dir & FLAG_WEI)) {
95 ((float *)d_ss)[c] = d_gamma;
96 ((float *)d_ss)[p->ic + c] = d_beta;
99 for (int mb = 0; mb < p->mb; ++mb)
100 for (int d = 0; d < p->id; ++d)
101 for (int h = 0; h < p->ih; ++h)
102 for (int w = 0; w < p->iw; ++w) {
103 auto off = data_off(p, mb, c, d, h, w);
104 float dd = ((float *)d_dst)[off];
105 if ((p->flags & FUSE_BN_RELU) && ((float *)rmask)[off] == 0)
109 if (!(p->flags & GLOB_STATS)) {
110 const float x = ((float *)src)[off] - smean;
111 ds -= (d_beta + x * d_gamma * rcp_denom) / NHW;
114 ((float *)d_src)[off] = rcp_denom * ds * gamma;