*******************************************************************************/
#include "src/common/mkldnn_thread.hpp"
+#include "src/common/math_utils.hpp"
#include "conv/conv_common.hpp"
};
// NOTE(review): this span is an UNAPPLIED unified-diff hunk -- the "+"/"-"
// markers are still embedded in the text, and two hunks appear to be fused
// (the post-ops lambda and the tail of an enclosing compute kernel) with
// their separating context lost.  It is not valid C++ as-is; the patch must
// be applied (keep context and "+" lines, drop "-" lines) and the hunks
// re-separated before this compiles.  Comments below annotate intent only.
//
// Applies the attr post-op chain to one accumulated output value.
//   conv_res - in/out: convolution accumulator, updated in place
//   dst      - previous destination value, read only by the SUM post-op
auto maybe_post_ops = [&](float &conv_res, float dst) {
+ using namespace mkldnn::impl::math;
+
const auto &ops = p->attr.post_ops;
// Post-ops are applied in declaration order.
for (int idx = 0; idx < ops.len; ++idx) {
using pk = attr_t::post_ops_t::kind_t;
const auto &e = ops.entry[idx];
+
// Shorthands for the eltwise parameters used by the cases below
// (unused for the SUM case, which has its own scale).
+ const auto &s = e.eltwise.scale;
+ const auto &a = e.eltwise.alpha;
+ const auto &b = e.eltwise.beta;
+
switch (e.kind) {
// Old open-coded SUM/RELU-only handling, removed by this hunk:
- case pk::SUM:
- conv_res += e.sum.scale * dst;
- break;
- case pk::RELU:
- conv_res = e.eltwise.scale * (conv_res < 0 ? 0 : conv_res);
- break;
// New handling: SUM plus the full eltwise set via math_utils helpers.
+ case pk::SUM: conv_res += e.sum.scale * dst; break;
+ case pk::RELU: conv_res = s*relu_fwd(conv_res, a); break;
+ case pk::TANH: conv_res = s*tanh_fwd(conv_res); break;
+ case pk::ELU: conv_res = s*elu_fwd(conv_res, a); break;
+ case pk::SQUARE: conv_res = s*square_fwd(conv_res); break;
+ case pk::ABS: conv_res = s*abs_fwd(conv_res); break;
+ case pk::SQRT: conv_res = s*sqrt_fwd(conv_res); break;
+ case pk::LINEAR: conv_res = s*linear_fwd(conv_res, a, b); break;
+ case pk::BRELU: conv_res = s*bounded_relu_fwd(conv_res, a); break;
+ case pk::SRELU: conv_res = s*soft_relu_fwd(conv_res); break;
+ case pk::LOGISTIC: conv_res = s*logistic_fwd(conv_res); break;
default:
assert(!"unknown attr::post_ops::kind");
}
// NOTE(review): the bias add below presumably belongs to the enclosing
// compute kernel, not to the post-ops loop -- it looks like a context
// line from a second, fused hunk.  TODO: confirm against upstream.
conv_res += ((float*)bia_m)[bia_off];
}
// Legacy merge==RELU path removed; superseded by the RELU post-op above.
- if (p->merge == RELU && conv_res < 0)
- conv_res = 0;
-
// Tail of the enclosing FWD kernel: output scaling, then post-ops.
maybe_scale(conv_res, g * p->oc / p->g + oc);
maybe_post_ops(conv_res, dst);
}
};
+ /* Used for Deconv FWD */
+ auto maybe_post_ops = [&](float &conv_res, float dst) {
+ using namespace mkldnn::impl::math;
+
+ const auto &ops = p->attr.post_ops;
+ for (int idx = 0; idx < ops.len; ++idx) {
+ using pk = attr_t::post_ops_t::kind_t;
+ const auto &e = ops.entry[idx];
+
+ const auto &s = e.eltwise.scale;
+ const auto &a = e.eltwise.alpha;
+ const auto &b = e.eltwise.beta;
+
+ switch (e.kind) {
+ case pk::SUM: conv_res += e.sum.scale * dst; break;
+ case pk::RELU: conv_res = s*relu_fwd(conv_res, a); break;
+ case pk::TANH: conv_res = s*tanh_fwd(conv_res); break;
+ case pk::ELU: conv_res = s*elu_fwd(conv_res, a); break;
+ case pk::SQUARE: conv_res = s*square_fwd(conv_res); break;
+ case pk::ABS: conv_res = s*abs_fwd(conv_res); break;
+ case pk::SQRT: conv_res = s*sqrt_fwd(conv_res); break;
+ case pk::LINEAR: conv_res = s*linear_fwd(conv_res, a, b); break;
+ case pk::BRELU: conv_res = s*bounded_relu_fwd(conv_res, a); break;
+ case pk::SRELU: conv_res = s*soft_relu_fwd(conv_res); break;
+ case pk::LOGISTIC: conv_res = s*logistic_fwd(conv_res); break;
+ default:
+ assert(!"unknown attr::post_ops::kind");
+ }
+ }
+ };
+
/* Reference BWD_D driver: one task per diff_src element.
 * Diff markers from the original unapplied hunk have been resolved:
 * the accumulator is now a local (conv_res) so that post-ops can read
 * the PREVIOUS destination value (ds) -- required by the SUM post-op --
 * before the result is written back. */
mkldnn::impl::parallel_nd(p->g, p->mb, p->ic / p->g, p->id, p->ih, p->iw,
    [&](int g, int mb, int ic, int id, int ih, int iw) {
    size_t src_off = src_off_f(p, mb, g, ic, id, ih, iw);
    float &ds = ((float*)diff_src_m)[src_off];
    /* accumulate into a local, not into ds: ds still holds the prior
     * destination value that the SUM post-op must read */
    float conv_res = 0;
    if (fast)
        ker_fast(conv_res, g, mb, ic, id, ih, iw);
    else
        ker(conv_res, g, mb, ic, id, ih, iw);
    if (p->dir & FLAG_BIA) {
        const size_t bia_off = (size_t)g * p->ic / p->g + ic;
        conv_res += ((float*)bia_m)[bia_off];
    }
    /* output scale first, then the post-op chain */
    maybe_scale(conv_res, g * p->ic / p->g + ic);
    maybe_post_ops(conv_res, ds);

    ds = conv_res;
}
);
}