1 /*******************************************************************************
2 * Copyright 2018 Intel Corporation
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
17 #include "mkldnn_types.h"
18 #include "c_types_map.hpp"
19 #include "jit_uni_x8s8s32x_1x1_convolution.hpp"
25 using namespace mkldnn::impl::status;
26 using namespace mkldnn::impl::memory_format;
27 using namespace mkldnn::impl::utils;
// Forward execution of the x8s8s32x 1x1 convolution, specialized on the ISA,
// the fused-ReLU flag and the source/destination data types.
// The work is split into items (minibatch n, group g, output-row group ohb,
// OC-block group ocb), and each thread walks its own contiguous range of items.
// For every item it fills the jit_1x1_conv_call_s argument structure `p`.
// NOTE(review): `p` is presumably consumed by kernel_ afterwards; confirm.
29 template <cpu_isa_t isa, bool with_relu, data_type_t src_type, data_type_t dst_type>
30 void _jit_uni_x8s8s32x_1x1_convolution_fwd_t<isa, with_relu, src_type, dst_type>::execute_forward() {
// Raw pointers to the primitive's memories: input 0 = src, 1 = weights,
// 2 = bias, and the output memory = dst.
// The bias is viewed as bytes because its element type is only known at run
// time (see bia_dt_size below).
31 auto src = reinterpret_cast<const src_data_t *>(this->input_memory(0));
32 auto weights = reinterpret_cast<const wei_data_t *>(this->input_memory(1));
33 auto bias = reinterpret_cast<const char *>(this->input_memory(2));
34 auto dst = reinterpret_cast<dst_data_t *>(this->memory());
// Descriptor wrappers used below to turn logical indices into offsets via blk_off().
36 const memory_desc_wrapper src_d(conf_.src_pd());
37 const memory_desc_wrapper weights_d(conf_.weights_pd(0));
38 const memory_desc_wrapper dst_d(conf_.dst_pd());
39 const memory_desc_wrapper bias_d(conf_.weights_pd(1));
// Blocking/shape parameters shared with the JIT kernel.
41 const auto &jcp = kernel_->jcp;
// OC blocks are grouped by nb_oc_blocking and output rows by nb_oh_blocking.
// One work item is a (n, g, row group, OC-block group) tuple.
43 int ocb_work = utils::div_up(jcp.nb_oc, jcp.nb_oc_blocking);
44 int ohb_work = utils::div_up(jcp.oh, jcp.nb_oh_blocking);
45 const int work_amount = jcp.mb * jcp.ngroups * ocb_work * ohb_work;
// Spatial strides and top/left padding from the convolution descriptor.
47 const int stride_h = conf_.cdesc()->strides[0];
48 const int stride_w = conf_.cdesc()->strides[1];
49 const int pad_t = conf_.cdesc()->padding[0][0];
50 const int pad_l = conf_.cdesc()->padding[0][1];
// Size in bytes of one bias element, or 0 when the convolution has no bias.
52 const size_t bia_dt_size = conf_.with_bias()
53 ? types::data_type_size(conf_.cdesc()->bias_desc.data_type) : 0;
// Output scales from the primitive attributes.
// They are indexed per OC only when jcp.is_oc_scale is set (see p.scales below).
55 const auto &oscales = conf_.attr()->output_scales_;
// Per-thread worker for thread `ithr` out of `nthr`.
57 auto ker = [&](const int ithr, const int nthr) {
58 jit_1x1_conv_call_s p = {};
// Each thread uses its own slice of the accumulator scratch ws_,
// starting at offset ithr * ws_per_thread_.
59 p.acc_s32 = ws_ + ithr * ws_per_thread_;
// Output pixels per output row. os = oh_ * oh_block below is therefore a
// linear (oh, ow) offset.
61 const int oh_block = jcp.ow;
// Presumably splits [0, work_amount) evenly among nthr threads and returns
// this thread's [start, end).
64 balance211(work_amount, nthr, ithr, start, end);
// Decode the first work index into (n, g, ohb, ocb).
// Presumably the last pair (ocb) varies fastest.
66 int n{0}, g{0}, ocb{0}, ohb{0};
67 nd_iterator_init(start, n, jcp.mb, g, jcp.ngroups, ohb,
68 ohb_work, ocb, ocb_work);
70 for (int iwork = start; iwork < end; ++iwork) {
// First OC block of this item and the nominal number of OC blocks it covers.
71 int oc_ = ocb * jcp.nb_oc_blocking;
72 int oc_num = jcp.nb_oc_blocking;
// First output row of this item and the nominal number of rows it covers.
74 int oh_ = ohb * jcp.nb_oh_blocking;
75 int oh_num = jcp.nb_oh_blocking;
// Actual row count, clamped so the last row group does not run past jcp.oh.
77 int oh_step = nstl::min(oh_ + oh_num, jcp.oh) - oh_;
// Linear output-spatial offset of the first pixel and its (oh, ow) coordinates.
// Since oh_block == jcp.ow, this yields oh == oh_ and ow == 0.
79 const int os = oh_ * oh_block;
80 const int oh = os / jcp.ow;
81 const int ow = os % jcp.ow;
// Matching input coordinates. Positions falling into the top/left padding
// (negative values) are clamped to 0.
83 const int ih = nstl::max(oh * stride_h - pad_t, 0);
84 const int iw = nstl::max(ow * stride_w - pad_l, 0);
// Number of output pixels for this item.
// Presumably this_block_size() clamps os + (oh_step * oh_block) to jcp.os.
86 p.os_dim = this_block_size(os, jcp.os, oh_step * oh_block);
// Number of OC blocks for this item, clamped to jcp.nb_oc.
87 p.oc_dim = nstl::min(oc_ + oc_num, jcp.nb_oc) - oc_;
// NOTE(review): the group index g is used only for the weights offset below.
// It is not applied to the dst/src/bias/scales offsets. That is correct only
// if ngroups == 1 here, or if groups are accounted for elsewhere; confirm.
89 const size_t dst_off = dst_d.blk_off(n, oc_*jcp.oc_block, oh, ow);
90 p.output_data = &dst[dst_off];
// NOTE(review): the byte multiplier bia_dt_size is applied *inside* blk_off().
// If blk_off() adds a non-zero padding offset (in elements), this byte offset
// is wrong; blk_off(oc_ * jcp.oc_block) * bia_dt_size may be what was meant.
// Confirm.
93 p.bias_data = &bias[bias_d.blk_off(oc_ * jcp.oc_block * bia_dt_size)];
// Scale index: the per-OC offset when jcp.is_oc_scale is non-zero, otherwise
// index 0 (a single common scale).
95 p.scales = &oscales.scales_[jcp.is_oc_scale * oc_ * jcp.oc_block];
// Weights for OC block oc_. Grouped weights take an extra leading group index.
96 p.oc_data = &weights[conf_.with_groups() ? weights_d.blk_off(g, oc_, 0) : weights_d.blk_off(oc_, 0)];
// Source starts at channel 0 of image n at input position (ih, iw).
97 p.is_data = src + src_d.blk_off(n, 0, ih, iw);
// Advance (n, g, ohb, ocb) to the next work item.
101 nd_iterator_step(n, jcp.mb, g, jcp.ngroups, ohb,
102 ohb_work, ocb, ocb_work);
109 template void _jit_uni_x8s8s32x_1x1_convolution_fwd_t<avx2, true, data_type::u8, data_type::u8>::execute_forward();
110 template void _jit_uni_x8s8s32x_1x1_convolution_fwd_t<avx2, true, data_type::u8, data_type::s8>::execute_forward();
111 template void _jit_uni_x8s8s32x_1x1_convolution_fwd_t<avx2, true, data_type::u8, data_type::s32>::execute_forward();
112 template void _jit_uni_x8s8s32x_1x1_convolution_fwd_t<avx2, true, data_type::u8, data_type::f32>::execute_forward();
113 template void _jit_uni_x8s8s32x_1x1_convolution_fwd_t<avx2, false, data_type::u8, data_type::u8>::execute_forward();
114 template void _jit_uni_x8s8s32x_1x1_convolution_fwd_t<avx2, false, data_type::u8, data_type::s8>::execute_forward();
115 template void _jit_uni_x8s8s32x_1x1_convolution_fwd_t<avx2, false, data_type::u8, data_type::s32>::execute_forward();
116 template void _jit_uni_x8s8s32x_1x1_convolution_fwd_t<avx2, false, data_type::u8, data_type::f32>::execute_forward();
118 template void _jit_uni_x8s8s32x_1x1_convolution_fwd_t<avx2, true, data_type::s8, data_type::u8>::execute_forward();
119 template void _jit_uni_x8s8s32x_1x1_convolution_fwd_t<avx2, true, data_type::s8, data_type::s8>::execute_forward();
120 template void _jit_uni_x8s8s32x_1x1_convolution_fwd_t<avx2, true, data_type::s8, data_type::s32>::execute_forward();
121 template void _jit_uni_x8s8s32x_1x1_convolution_fwd_t<avx2, true, data_type::s8, data_type::f32>::execute_forward();
122 template void _jit_uni_x8s8s32x_1x1_convolution_fwd_t<avx2, false, data_type::s8, data_type::u8>::execute_forward();
123 template void _jit_uni_x8s8s32x_1x1_convolution_fwd_t<avx2, false, data_type::s8, data_type::s8>::execute_forward();
124 template void _jit_uni_x8s8s32x_1x1_convolution_fwd_t<avx2, false, data_type::s8, data_type::s32>::execute_forward();
125 template void _jit_uni_x8s8s32x_1x1_convolution_fwd_t<avx2, false, data_type::s8, data_type::f32>::execute_forward();
127 template void _jit_uni_x8s8s32x_1x1_convolution_fwd_t<sse42, true, data_type::u8, data_type::u8>::execute_forward();
128 template void _jit_uni_x8s8s32x_1x1_convolution_fwd_t<sse42, true, data_type::u8, data_type::s8>::execute_forward();
129 template void _jit_uni_x8s8s32x_1x1_convolution_fwd_t<sse42, true, data_type::u8, data_type::s32>::execute_forward();
130 template void _jit_uni_x8s8s32x_1x1_convolution_fwd_t<sse42, true, data_type::u8, data_type::f32>::execute_forward();
131 template void _jit_uni_x8s8s32x_1x1_convolution_fwd_t<sse42, false, data_type::u8, data_type::u8>::execute_forward();
132 template void _jit_uni_x8s8s32x_1x1_convolution_fwd_t<sse42, false, data_type::u8, data_type::s8>::execute_forward();
133 template void _jit_uni_x8s8s32x_1x1_convolution_fwd_t<sse42, false, data_type::u8, data_type::s32>::execute_forward();
134 template void _jit_uni_x8s8s32x_1x1_convolution_fwd_t<sse42, false, data_type::u8, data_type::f32>::execute_forward();
136 template void _jit_uni_x8s8s32x_1x1_convolution_fwd_t<sse42, true, data_type::s8, data_type::u8>::execute_forward();
137 template void _jit_uni_x8s8s32x_1x1_convolution_fwd_t<sse42, true, data_type::s8, data_type::s8>::execute_forward();
138 template void _jit_uni_x8s8s32x_1x1_convolution_fwd_t<sse42, true, data_type::s8, data_type::s32>::execute_forward();
139 template void _jit_uni_x8s8s32x_1x1_convolution_fwd_t<sse42, true, data_type::s8, data_type::f32>::execute_forward();
140 template void _jit_uni_x8s8s32x_1x1_convolution_fwd_t<sse42, false, data_type::s8, data_type::u8>::execute_forward();
141 template void _jit_uni_x8s8s32x_1x1_convolution_fwd_t<sse42, false, data_type::s8, data_type::s8>::execute_forward();
142 template void _jit_uni_x8s8s32x_1x1_convolution_fwd_t<sse42, false, data_type::s8, data_type::s32>::execute_forward();
143 template void _jit_uni_x8s8s32x_1x1_convolution_fwd_t<sse42, false, data_type::s8, data_type::f32>::execute_forward();