Add a section of how to link IE with CMake project (#99)
[platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / src / cpu / jit_uni_x8s8s32x_1x1_convolution.cpp
1 /*******************************************************************************
2 * Copyright 2018 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16
17 #include "mkldnn_types.h"
18 #include "c_types_map.hpp"
19 #include "jit_uni_x8s8s32x_1x1_convolution.hpp"
20
21 namespace mkldnn {
22 namespace impl {
23 namespace cpu {
24
25 using namespace mkldnn::impl::status;
26 using namespace mkldnn::impl::memory_format;
27 using namespace mkldnn::impl::utils;
28
29 template <cpu_isa_t isa, bool with_relu, data_type_t src_type, data_type_t dst_type>
30 void _jit_uni_x8s8s32x_1x1_convolution_fwd_t<isa, with_relu, src_type, dst_type>::execute_forward() {
31     auto src = reinterpret_cast<const src_data_t *>(this->input_memory(0));
32     auto weights = reinterpret_cast<const wei_data_t *>(this->input_memory(1));
33     auto bias = reinterpret_cast<const char *>(this->input_memory(2));
34     auto dst = reinterpret_cast<dst_data_t *>(this->memory());
35
36     const memory_desc_wrapper src_d(conf_.src_pd());
37     const memory_desc_wrapper weights_d(conf_.weights_pd(0));
38     const memory_desc_wrapper dst_d(conf_.dst_pd());
39     const memory_desc_wrapper bias_d(conf_.weights_pd(1));
40
41     const auto &jcp = kernel_->jcp;
42
43     int ocb_work = utils::div_up(jcp.nb_oc, jcp.nb_oc_blocking);
44     int ohb_work = utils::div_up(jcp.oh, jcp.nb_oh_blocking);
45     const int work_amount = jcp.mb * jcp.ngroups * ocb_work * ohb_work;
46
47     const int stride_h = conf_.cdesc()->strides[0];
48     const int stride_w = conf_.cdesc()->strides[1];
49     const int pad_t = conf_.cdesc()->padding[0][0];
50     const int pad_l = conf_.cdesc()->padding[0][1];
51
52     const size_t bia_dt_size = conf_.with_bias()
53         ? types::data_type_size(conf_.cdesc()->bias_desc.data_type) : 0;
54
55     const auto &oscales = conf_.attr()->output_scales_;
56
57     auto ker = [&](const int ithr, const int nthr) {
58         jit_1x1_conv_call_s p = {};
59         p.acc_s32 = ws_ + ithr * ws_per_thread_;
60
61         const int oh_block = jcp.ow;
62
63         int start{0}, end{0};
64         balance211(work_amount, nthr, ithr, start, end);
65
66         int n{0}, g{0}, ocb{0}, ohb{0};
67         nd_iterator_init(start, n, jcp.mb, g, jcp.ngroups, ohb,
68                          ohb_work, ocb, ocb_work);
69
70         for (int iwork = start; iwork < end; ++iwork) {
71             int oc_ = ocb * jcp.nb_oc_blocking;
72             int oc_num = jcp.nb_oc_blocking;
73
74             int oh_ = ohb * jcp.nb_oh_blocking;
75             int oh_num = jcp.nb_oh_blocking;
76
77             int oh_step = nstl::min(oh_ + oh_num, jcp.oh) - oh_;
78
79             const int os = oh_ * oh_block;
80             const int oh = os / jcp.ow;
81             const int ow = os % jcp.ow;
82
83             const int ih = nstl::max(oh * stride_h - pad_t, 0);
84             const int iw = nstl::max(ow * stride_w - pad_l, 0);
85
86             p.os_dim = this_block_size(os, jcp.os, oh_step * oh_block);
87             p.oc_dim = nstl::min(oc_ + oc_num, jcp.nb_oc) - oc_;
88
89             const size_t dst_off = dst_d.blk_off(n, oc_*jcp.oc_block, oh, ow);
90             p.output_data = &dst[dst_off];
91
92             if (bias)
93                 p.bias_data = &bias[bias_d.blk_off(oc_ * jcp.oc_block * bia_dt_size)];
94
95             p.scales = &oscales.scales_[jcp.is_oc_scale * oc_ * jcp.oc_block];
96             p.oc_data = &weights[conf_.with_groups() ? weights_d.blk_off(g, oc_, 0) : weights_d.blk_off(oc_, 0)];
97             p.is_data = src + src_d.blk_off(n, 0, ih, iw);
98
99             kernel_->jit_ker(&p);
100
101             nd_iterator_step(n, jcp.mb, g, jcp.ngroups, ohb,
102                              ohb_work, ocb, ocb_work);
103         }
104     };
105
106     parallel(0, ker);
107 }
108
109 template void _jit_uni_x8s8s32x_1x1_convolution_fwd_t<avx2, true, data_type::u8, data_type::u8>::execute_forward();
110 template void _jit_uni_x8s8s32x_1x1_convolution_fwd_t<avx2, true, data_type::u8, data_type::s8>::execute_forward();
111 template void _jit_uni_x8s8s32x_1x1_convolution_fwd_t<avx2, true, data_type::u8, data_type::s32>::execute_forward();
112 template void _jit_uni_x8s8s32x_1x1_convolution_fwd_t<avx2, true, data_type::u8, data_type::f32>::execute_forward();
113 template void _jit_uni_x8s8s32x_1x1_convolution_fwd_t<avx2, false, data_type::u8, data_type::u8>::execute_forward();
114 template void _jit_uni_x8s8s32x_1x1_convolution_fwd_t<avx2, false, data_type::u8, data_type::s8>::execute_forward();
115 template void _jit_uni_x8s8s32x_1x1_convolution_fwd_t<avx2, false, data_type::u8, data_type::s32>::execute_forward();
116 template void _jit_uni_x8s8s32x_1x1_convolution_fwd_t<avx2, false, data_type::u8, data_type::f32>::execute_forward();
117
118 template void _jit_uni_x8s8s32x_1x1_convolution_fwd_t<avx2, true, data_type::s8, data_type::u8>::execute_forward();
119 template void _jit_uni_x8s8s32x_1x1_convolution_fwd_t<avx2, true, data_type::s8, data_type::s8>::execute_forward();
120 template void _jit_uni_x8s8s32x_1x1_convolution_fwd_t<avx2, true, data_type::s8, data_type::s32>::execute_forward();
121 template void _jit_uni_x8s8s32x_1x1_convolution_fwd_t<avx2, true, data_type::s8, data_type::f32>::execute_forward();
122 template void _jit_uni_x8s8s32x_1x1_convolution_fwd_t<avx2, false, data_type::s8, data_type::u8>::execute_forward();
123 template void _jit_uni_x8s8s32x_1x1_convolution_fwd_t<avx2, false, data_type::s8, data_type::s8>::execute_forward();
124 template void _jit_uni_x8s8s32x_1x1_convolution_fwd_t<avx2, false, data_type::s8, data_type::s32>::execute_forward();
125 template void _jit_uni_x8s8s32x_1x1_convolution_fwd_t<avx2, false, data_type::s8, data_type::f32>::execute_forward();
126
127 template void _jit_uni_x8s8s32x_1x1_convolution_fwd_t<sse42, true, data_type::u8, data_type::u8>::execute_forward();
128 template void _jit_uni_x8s8s32x_1x1_convolution_fwd_t<sse42, true, data_type::u8, data_type::s8>::execute_forward();
129 template void _jit_uni_x8s8s32x_1x1_convolution_fwd_t<sse42, true, data_type::u8, data_type::s32>::execute_forward();
130 template void _jit_uni_x8s8s32x_1x1_convolution_fwd_t<sse42, true, data_type::u8, data_type::f32>::execute_forward();
131 template void _jit_uni_x8s8s32x_1x1_convolution_fwd_t<sse42, false, data_type::u8, data_type::u8>::execute_forward();
132 template void _jit_uni_x8s8s32x_1x1_convolution_fwd_t<sse42, false, data_type::u8, data_type::s8>::execute_forward();
133 template void _jit_uni_x8s8s32x_1x1_convolution_fwd_t<sse42, false, data_type::u8, data_type::s32>::execute_forward();
134 template void _jit_uni_x8s8s32x_1x1_convolution_fwd_t<sse42, false, data_type::u8, data_type::f32>::execute_forward();
135
136 template void _jit_uni_x8s8s32x_1x1_convolution_fwd_t<sse42, true, data_type::s8, data_type::u8>::execute_forward();
137 template void _jit_uni_x8s8s32x_1x1_convolution_fwd_t<sse42, true, data_type::s8, data_type::s8>::execute_forward();
138 template void _jit_uni_x8s8s32x_1x1_convolution_fwd_t<sse42, true, data_type::s8, data_type::s32>::execute_forward();
139 template void _jit_uni_x8s8s32x_1x1_convolution_fwd_t<sse42, true, data_type::s8, data_type::f32>::execute_forward();
140 template void _jit_uni_x8s8s32x_1x1_convolution_fwd_t<sse42, false, data_type::s8, data_type::u8>::execute_forward();
141 template void _jit_uni_x8s8s32x_1x1_convolution_fwd_t<sse42, false, data_type::s8, data_type::s8>::execute_forward();
142 template void _jit_uni_x8s8s32x_1x1_convolution_fwd_t<sse42, false, data_type::s8, data_type::s32>::execute_forward();
143 template void _jit_uni_x8s8s32x_1x1_convolution_fwd_t<sse42, false, data_type::s8, data_type::f32>::execute_forward();
144
145 }
146 }
147 }