* limitations under the License.
*******************************************************************************/
-#include "mkldnn_types.h"
-
#include "c_types_map.hpp"
-#include "utils.hpp"
#include "mkldnn_thread.hpp"
#include "type_helpers.hpp"
+#include "utils.hpp"
+
#include "jit_generator.hpp"
#include "jit_avx512_core_x8s8s32x_1x1_convolution.hpp"
using namespace mkldnn::impl::status;
using namespace mkldnn::impl::memory_format;
+using namespace mkldnn::impl::memory_tracking::names;
using namespace mkldnn::impl::utils;
namespace {
+// Intentionally empty: reserved for file-local helpers.
}
/* convolution forward */
-template <bool with_relu, data_type_t src_type, data_type_t dst_type>
-void _jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t
-    <with_relu, src_type, dst_type>::execute_forward()
+template <data_type_t src_type, data_type_t dst_type>
+void jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t
+    <src_type, dst_type>::execute_forward() const
{
+    // Raw tensor pointers handed in by the primitive framework:
+    // input 0 = source, input 1 = weights, input 2 = bias (type-erased
+    // to char* because its data type is only known at runtime).
    auto src = reinterpret_cast<const src_data_t *>(this->input_memory(0));
    auto weights =
        reinterpret_cast<const wei_data_t *>(this->input_memory(1));
    auto bias = reinterpret_cast<const char *>(this->input_memory(2));
    auto dst = reinterpret_cast<dst_data_t *>(this->memory());
+
+    auto scratchpad = this->scratchpad();
+
+    // On pre-VNNI ISAs with signed input, the weights were pre-scaled by
+    // jcp_.wei_adj_scale; fold the inverse factor into a scratchpad copy of
+    // the output scales so the JIT kernel applies both in a single multiply.
    if (pd()->jcp_.signed_input && pd()->jcp_.ver != ver_vnni) {
        auto local_scales = scratchpad.template get<float>(
                key_conv_adjusted_scales);
        auto scales = pd()->attr()->output_scales_.scales_;
        size_t count = pd()->attr()->output_scales_.count_;
        float factor = 1.f / pd()->jcp_.wei_adj_scale;
        if (count == 1) {
+            // Common (per-tensor) scale: broadcast the adjusted value.
+            // NOTE(review): 16 presumably matches the oc block width (one
+            // zmm of floats) — confirm against jcp.oc_block.
            utils::array_set(local_scales, scales[0] * factor, 16);
        } else {
+            // Per-output-channel scales: adjust each one.
            for (size_t c = 0; c < count; c++)
                local_scales[c] = scales[c] * factor;
        }
    }
+
+    // Fan the work out over jcp.nthr threads; each thread picks its own
+    // share of the batch/group/bcast space inside execute_forward_thr().
    parallel(kernel_->jcp.nthr, [&](const int ithr, const int nthr) {
-        execute_forward_thr(ithr, nthr, src, weights, bias, dst);
+        execute_forward_thr(ithr, nthr, src, weights, bias, dst, scratchpad);
    });
}
-template <bool with_relu, data_type_t src_type, data_type_t dst_type>
-void _jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t<with_relu, src_type, dst_type>
+template <data_type_t src_type, data_type_t dst_type>
+void jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t<src_type, dst_type>
::execute_forward_thr(const int ithr, const int nthr, const src_data_t *src,
-        const wei_data_t *weights, const char *bias, dst_data_t *dst) {
-    const memory_desc_wrapper src_d(conf_.src_pd());
-    const memory_desc_wrapper dst_d(conf_.dst_pd());
-    const memory_desc_wrapper weights_d(conf_.weights_pd(0));
+        const wei_data_t *weights, const char *bias, dst_data_t *dst,
+        const memory_tracking::grantor_t &scratchpad) const {
+    // Layout descriptors used to translate logical coordinates into
+    // physical offsets for the blocked memory formats.
+    const memory_desc_wrapper src_d(pd()->src_pd());
+    const memory_desc_wrapper dst_d(pd()->dst_pd());
+    const memory_desc_wrapper weights_d(pd()->weights_pd(0));
-    const size_t bia_dt_size = conf_.with_bias()
-        ? types::data_type_size(conf_.cdesc()->bias_desc.data_type) : 0;
+    // Bias element size in bytes; 0 when the convolution has no bias, so
+    // the bias pointer arithmetic below degenerates harmlessly.
+    const size_t bia_dt_size = pd()->with_bias()
+        ? types::data_type_size(pd()->desc()->bias_desc.data_type) : 0;
    const auto &jcp = kernel_->jcp;
+    // Scratchpad views: rtus_space backs the re-packed source rows used by
+    // the reduce-src (rtus) path; local_scales is the wei_adj_scale-adjusted
+    // copy of the output scales prepared in execute_forward().
+    auto rtus_space = scratchpad.get<src_data_t>(key_conv_rtus_space);
+    auto local_scales = scratchpad.get<float>(key_conv_adjusted_scales);
    const int work_amount = jcp.mb * jcp.ngroups * jcp.nb_bcast;
-    const int stride_h = conf_.cdesc()->strides[0];
-    const int stride_w = conf_.cdesc()->strides[1];
-    const int pad_t = conf_.cdesc()->padding[0][0];
-    const int pad_l = conf_.cdesc()->padding[0][1];
+    const int stride_h = pd()->desc()->strides[0];
+    const int stride_w = pd()->desc()->strides[1];
+    const int pad_t = pd()->desc()->padding[0][0];
+    const int pad_l = pd()->desc()->padding[0][1];
-    const auto &oscales = conf_.attr()->output_scales_;
+    const auto &oscales = pd()->attr()->output_scales_;
    int offset = jcp.ngroups * (jcp.oc / jcp.oc_block) * (jcp.ic / jcp.ic_block)
        * jcp.oc_block * jcp.ic_block;
    const size_t dst_off = dst_d.blk_off(n, _ocb * jcp.oc_block, oh, ow);
    p.output_data = &dst[dst_off];
-    p.load_data = &weights[conf_.with_groups()
+    p.load_data = &weights[pd()->with_groups()
        ? weights_d.blk_off(g, ocb, icb)
        : weights_d.blk_off(ocb, icb)];
    p.bias_data = &bias[_ocb * jcp.oc_block * bia_dt_size];
    p.compensation = (jcp.signed_input)
        ? &compensation[_ocb * jcp.oc_block] : 0;
+            // Pre-VNNI signed-input kernels read the adjusted scale copy;
+            // everything else uses the user-provided output scales directly.
    p.scales = (jcp.signed_input && jcp.ver != ver_vnni)
-        ? &local_scales_[jcp.is_oc_scale * _ocb * jcp.oc_block]
+        ? &local_scales[jcp.is_oc_scale * _ocb * jcp.oc_block]
        : &oscales.scales_[jcp.is_oc_scale * _ocb * jcp.oc_block];
+            // Reduce-src (rtus) case: the kernel reads from this thread's
+            // private slice of the scratch area instead of the raw source.
-            if (conf_.rtus_.reduce_src_) {
-                rp.ws = scratch_ + ithr * ws_per_thread_
+            if (pd()->rtus_.reduce_src_) {
+                rp.ws = rtus_space + ithr * pd()->rtus_.space_per_thread_
                        + _icb * jcp.is * jcp.ic_block;
                if (ocb == ocb_start) {
                    rp.src = src + src_d.blk_off(n, _icb * jcp.ic_block, ih, iw);
                } else
            p.bcast_data = src + src_d.blk_off(n, _icb * jcp.ic_block, ih, iw);
+            // NOTE(review): oc_off is a byte offset of the current output
+            // channel block — presumably consumed by per-channel post-ops in
+            // the JIT kernel; confirm against the kernel's call-params struct.
+            p.oc_off = _ocb * jcp.oc_block * sizeof(float);
+
            kernel_->jit_ker(&p);
        };
    }
}
-template struct _jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t<false,
-    data_type::u8, data_type::u8>;
-template struct _jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t<true,
-    data_type::u8, data_type::u8>;
-template struct _jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t<false,
-    data_type::s8, data_type::u8>;
-template struct _jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t<true,
-    data_type::s8, data_type::u8>;
-template struct _jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t<false,
-    data_type::u8, data_type::s8>;
-template struct _jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t<true,
-    data_type::u8, data_type::s8>;
-template struct _jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t<false,
-    data_type::s8, data_type::s8>;
-template struct _jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t<true,
-    data_type::s8, data_type::s8>;
-template struct _jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t<false,
-    data_type::u8, data_type::s32>;
-template struct _jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t<true,
-    data_type::u8, data_type::s32>;
-template struct _jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t<false,
-    data_type::s8, data_type::s32>;
-template struct _jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t<true,
-    data_type::s8, data_type::s32>;
-template struct _jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t<false,
-    data_type::u8, data_type::f32>;
-template struct _jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t<true,
-    data_type::u8, data_type::f32>;
-template struct _jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t<false,
-    data_type::s8, data_type::f32>;
-template struct _jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t<true,
-    data_type::s8, data_type::f32>;
+// Explicit instantiations for every supported src/dst data-type pair
+// (u8/s8 sources; u8/s8/s32/f32 destinations).
+// NOTE(review): the old with_relu bool template parameter is gone, halving
+// the instantiation count — presumably fused ReLU is now expressed through
+// primitive attributes/post-ops rather than a template flag; confirm.
+using namespace data_type;
+template struct jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t<u8, u8>;
+template struct jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t<s8, u8>;
+template struct jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t<u8, s8>;
+template struct jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t<s8, s8>;
+template struct jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t<u8, s32>;
+template struct jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t<s8, s32>;
+template struct jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t<u8, f32>;
+template struct jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t<s8, f32>;
+
}
}
}