/*******************************************************************************
* Copyright 2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
17 #include "math_utils.hpp"
18 #include "mkldnn_thread.hpp"
19 #include "simple_q10n.hpp"
20 #include "gemm_x8s8s32x_inner_product.hpp"
27 using namespace memory_format;
28 using namespace memory_tracking::names;
30 template <data_type_t src_type, data_type_t dst_type>
31 void gemm_x8s8s32x_inner_product_fwd_t<src_type, dst_type
32 >::execute_forward() const {
33 auto src = reinterpret_cast<const src_data_t *>(this->input_memory(0));
34 auto weights = reinterpret_cast<const wei_data_t *>(this->input_memory(1));
35 auto bias = reinterpret_cast<const char *>(this->input_memory(2));
36 auto dst = reinterpret_cast<dst_data_t *>(this->memory());
38 const int MB = pd()->MB();
39 const int OC = pd()->OC();
41 bool wei_tr = utils::one_of(pd()->weights_pd()->desc()->format,
42 oi, oiw, owi, oihw, ohwi, oidhw, odhwi);
46 const int K = pd()->IC_total_padded();
47 const int8_t off_a = 0, off_b = 0;
48 const int32_t off_c = 0;
50 const float *scales = pd()->attr()->output_scales_.scales_;
52 acc_data_t *acc = pd()->dst_is_acc_
54 : scratchpad().template get<acc_data_t>(key_iprod_int_dat_in_acc_dt);
56 const float onef = 1.0, zerof = 0.0;
58 if (src_type == data_type::u8) {
59 mkldnn_gemm_s8u8s32(wei_tr ? "T" : "N", "N", "F", &M, &N, &K, &onef,
60 weights, wei_tr ? &K : &M, &off_a, (uint8_t *)src, &K, &off_b, &zerof,
62 } else if (src_type == data_type::s8) {
63 mkldnn_gemm_s8s8s32(wei_tr ? "T" : "N", "N", "F", &M, &N, &K, &onef,
64 weights, wei_tr ? &K : &M, &off_a, (int8_t *)src, &K, &off_b, &zerof,
67 assert(!"incorrect src type");
70 if (!pd()->attr()->has_default_values() || !pd()->dst_is_acc_
71 || pd()->with_bias()) {
72 const bool force_sequential = MB * OC < 2000;
73 parallel(force_sequential ? 1 : 0, (size_t)OC * MB, [&](int ithr, int nthr) {
74 size_t start = 0, end = 0;
75 balance211((size_t)OC * MB, nthr, ithr, start, end);
76 (*pp_kernel_)(dst, acc, bias, scales, start, end);
81 using namespace data_type;
83 template struct gemm_x8s8s32x_inner_product_fwd_t<u8, f32>;
84 template struct gemm_x8s8s32x_inner_product_fwd_t<u8, s32>;
85 template struct gemm_x8s8s32x_inner_product_fwd_t<u8, s8>;
86 template struct gemm_x8s8s32x_inner_product_fwd_t<u8, u8>;
87 template struct gemm_x8s8s32x_inner_product_fwd_t<s8, f32>;
88 template struct gemm_x8s8s32x_inner_product_fwd_t<s8, s32>;
89 template struct gemm_x8s8s32x_inner_product_fwd_t<s8, s8>;
90 template struct gemm_x8s8s32x_inner_product_fwd_t<s8, u8>;