#include "mkldnn_test_common.hpp"
#include "gtest/gtest.h"
-
+#include "math_utils.hpp"
#include "mkldnn.hpp"
-namespace mkldnn {
-
-
-template <typename T, typename A> inline T relu_fwd(T s, A alpha) {
- return s > 0 ? s : static_cast<T>(s * alpha);
-}
-
-template <typename T> T tanh_fwd(T s) {
- const float e = ::expf(2*s); /* maybe replace with -2*s? */
- return static_cast<T>((e - 1.0) / (e + 1.0));
-}
-
-template <typename T, typename A> T elu_fwd(T s, A alpha) {
- return s > 0 ? s : static_cast<T>(alpha * (::expf(s) - 1));
-}
-
-template <typename T>
-T square_fwd(T s) {
- return s * s;
-}
-
-template <typename T>
-T abs_fwd(T s) {
- return s > 0 ? s : -s;;
-}
-
-template <typename T>
-T sqrt_fwd(T s) {
- return s > 0 ? ::sqrtf(s) : 0;
-}
-
-template <typename T, typename A>
-T linear_fwd(T s, A alpha, A beta) {
- return alpha * s + beta;
-}
-
-template <typename T, typename A>
-T bounded_relu_fwd(T s, A alpha) {
- s = s > 0 ? s : 0;
- return s > alpha ? alpha : s;
-}
-
-template <typename T>
-T soft_relu_fwd(T s) {
- return logf(1 + ::expf(s));
-}
+using namespace mkldnn::impl::math;
-template <typename T>
-T logistic_fwd(T s) {
- T v = ::expf(s);
- return v / (v + 1);
-}
-
-template <typename T, typename A>
-T clamp_fwd(T s, A alpha, A beta) {
- return s > alpha ? (T)(alpha) : s < beta ? (T)(beta) : s;
-}
+namespace mkldnn {
template <typename data_t_src, typename data_t_wei,
typename data_t_acc, typename data_t_dst>
const memory::desc weights_d = weights.get_primitive_desc().desc();
const memory::desc dst_d = dst.get_primitive_desc().desc();
+ size_t padded_ic = src_d.data.layout_desc.blocking.padding_dims[1];
+ size_t padded_oc = dst_d.data.layout_desc.blocking.padding_dims[1];
+
+ size_t padded_ic_w = weights_d.data.format == mkldnn_OhIw8o4i ? weights_d.data.layout_desc.blocking.padding_dims[1] :
+ src_d.data.layout_desc.blocking.padding_dims[1];
+ size_t padded_oc_w = weights_d.data.format == mkldnn_OhIw8o4i ? weights_d.data.layout_desc.blocking.padding_dims[0] :
+ dst_d.data.layout_desc.blocking.padding_dims[1];
+
mkldnn::impl::parallel_nd(c.mb, c.ng, c.oc / c.ng, c.oh, c.ow,
[&](int n, int g, int oc, int oh, int ow) {
- int oidx = n * c.oc * c.oh * c.ow
- + g * c.oc / c.ng * c.oh * c.ow
- + oc * c.oh * c.ow + oh * c.ow + ow;
-
- int didx = map_index(dst_d, oidx);
- dst_data[didx] = bias_data ?
- bias_data[map_index(
- bias.get_primitive_desc().desc(),
- g * c.oc / c.ng + oc)] :
- data_t_dst{0};
- for (int ic = 0; ic < c.ic / c.ng; ic++) {
- for (int kh = 0; kh < c.kh; kh++) {
- for (int kw = 0; kw < c.kw; kw++) {
- int iw = ow * c.strw
- - c.padw + kw * (1 + c.dilw);
- int ih = oh * c.strh
- - c.padh + kh * (1 + c.dilh);
- if (iw < 0 || iw >= c.iw) continue;
- if (ih < 0 || ih >= c.ih) continue;
- int iidx = n * c.ic * c.ih * c.iw
- + g * c.ic / c.ng * c.ih * c.iw
- + ic * c.ih * c.iw + ih * c.iw + iw;
- int widx = g * c.oc / c.ng * c.ic
- / c.ng * c.kh * c.kw
- + oc * c.ic / c.ng * c.kh * c.kw
- + ic * c.kh * c.kw + kh * c.kw + kw;
-
- dst_data[didx]
- += src_data[map_index(src_d, iidx)]
- * weights_data[map_index(
- weights_d, widx)];
- }
- }
+ size_t oidx = n * padded_oc * c.oh * c.ow
+ + g * padded_oc / c.ng * c.oh * c.ow
+ + oc * c.oh * c.ow + oh * c.ow + ow;
+
+ size_t didx = map_index(dst_d, oidx);
+ dst_data[didx] = bias_data
+ ? bias_data[g * c.oc / c.ng + oc] : data_t_dst{0};
+
+ for (int ic = 0; ic < c.ic / c.ng; ic++)
+ for (int kh = 0; kh < c.kh; kh++)
+ for (int kw = 0; kw < c.kw; kw++)
+ {
+ int ih = oh * c.strh - c.padh + kh * (1 + c.dilh);
+ if (ih < 0 || ih >= c.ih) continue;
+ int iw = ow * c.strw - c.padw + kw * (1 + c.dilw);
+ if (iw < 0 || iw >= c.iw) continue;
+
+ size_t iidx = n * padded_ic * c.ih * c.iw
+ + g * padded_ic / c.ng * c.ih * c.iw
+ + ic * c.ih * c.iw + ih * c.iw + iw;
+ size_t widx = g * padded_oc_w / c.ng * padded_ic_w
+ / c.ng * c.kh * c.kw
+ + oc * padded_ic_w / c.ng * c.kh * c.kw
+ + ic * c.kh * c.kw + kh * c.kw + kw;
+
+ dst_data[didx] += src_data[map_index(src_d, iidx)]
+ * weights_data[map_index(weights_d, widx)];
}
+ auto &d = dst_data[didx];
switch (elt_alg) {
- case eltwise_relu:
- dst_data[didx] = relu_fwd(dst_data[didx], elt_alpha);
- break;
- case eltwise_tanh:
- dst_data[didx] = tanh_fwd(dst_data[didx]);
- break;
- case eltwise_elu:
- dst_data[didx] = elu_fwd(dst_data[didx], elt_alpha);
- break;
- case eltwise_square:
- dst_data[didx] = square_fwd(dst_data[didx]);
- break;
- case eltwise_abs:
- dst_data[didx] = abs_fwd(dst_data[didx]);
- break;
- case eltwise_sqrt:
- dst_data[didx] = sqrt_fwd(dst_data[didx]);
- break;
- case eltwise_linear:
- dst_data[didx] = linear_fwd(dst_data[didx], elt_alpha, elt_beta);
- break;
- case eltwise_bounded_relu:
- dst_data[didx] = bounded_relu_fwd(dst_data[didx], elt_alpha);
- break;
- case eltwise_soft_relu:
- dst_data[didx] = soft_relu_fwd(dst_data[didx]);
- break;
- case eltwise_logistic:
- dst_data[didx] = logistic_fwd(dst_data[didx]);
- break;
- default:
- assert(!"unknown alg_kind");
+ case eltwise_relu: d = relu_fwd(d, elt_alpha); break;
+ case eltwise_tanh: d = tanh_fwd(d); break;
+ case eltwise_elu: d = elu_fwd(d, elt_alpha); break;
+ case eltwise_square: d = square_fwd(d); break;
+ case eltwise_abs: d = abs_fwd(d); break;
+ case eltwise_sqrt: d = sqrt_fwd(d); break;
+ case eltwise_linear: d = linear_fwd(d, elt_alpha, elt_beta); break;
+ case eltwise_bounded_relu: d = bounded_relu_fwd(d, elt_alpha); break;
+ case eltwise_soft_relu: d = soft_relu_fwd(d); break;
+ case eltwise_logistic: d = logistic_fwd(d); break;
+ case eltwise_clamp: d = clamp_fwd(d, elt_alpha, elt_beta); break;
+ case eltwise_exp: d = exp_fwd(d); break;
+ default: assert(!"unknown alg_kind");
}
}
);
class convolution_eltwise_test
: public ::testing::TestWithParam<test_convolution_eltwise_params_t> {
protected:
- virtual void SetUp()
- {
+ virtual void SetUp() {
test_convolution_eltwise_params_t p
= ::testing::TestWithParam<
test_convolution_eltwise_params_t>::GetParam();
auto dst_ref = memory({c_dst_desc, eng});
fill_data<data_t_src>(c_src.get_primitive_desc().get_size()
- / sizeof(data_t_src), (data_t_src *)c_src.get_data_handle(), data_t_src(0), data_t_src(1));
+ / sizeof(data_t_src), (data_t_src *)c_src.get_data_handle(),
+ data_t_src(0), data_t_src(1));
+ check_zero_tail<data_t_src>(1, c_src);
fill_data<data_t_wei>(
c_weights.get_primitive_desc().get_size()
- / sizeof(data_t_wei),(data_t_wei *)c_weights.get_data_handle(), data_t_wei(0), data_t_wei(1));
+ / sizeof(data_t_wei),(data_t_wei *)c_weights.get_data_handle(),
+ data_t_wei(0), data_t_wei(1));
+ check_zero_tail<data_t_wei>(1, c_weights);
bool with_bias = p.formats.bias_format != memory::format::format_undef;
auto c_bias_desc = with_bias ?
(data_t_dst *)c_bias.get_data_handle(), 1., true);
}
- std::vector<int> padR = { cd.padh, cd.padw };
+ std::vector<ptrdiff_t> padR = { cd.padh, cd.padw };
for (int i = 0; i < 2; ++i) {
if ((cd.ih - ((cd.kh - 1) * (cd.dilh + 1) + 1) + cd.padh + padR[0])
/ cd.strh + 1 != cd.oh)
compute_ref_conv_eltwise_fwd<data_t_src, data_t_wei, data_t_wei,
data_t_dst>(cd, c_src, c_weights, c_bias, dst_ref, with_bias,
p.alg, eltwise_alpha, eltwise_beta);
- compare_data<data_t_dst>(dst_ref, c_dst);
+ check_zero_tail<data_t_dst>(1, dst_ref);
+
+ compare_data<data_t_dst>(dst_ref, c_dst, 1e-2);
+ check_zero_tail<data_t_dst>(0, c_dst);
}
};