updated readme file due to moving CMake scripts to the root folder
[platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / src / cpu / gemm_x8s8s32x_inner_product.cpp
1 /*******************************************************************************
2 * Copyright 2018 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16
17 #include "math_utils.hpp"
18 #include "mkldnn_thread.hpp"
19 #include "simple_q10n.hpp"
20 #include "gemm_x8s8s32x_inner_product.hpp"
21
22 namespace mkldnn {
23 namespace impl {
24 namespace cpu {
25
26 using namespace math;
27 using namespace memory_format;
28 using namespace memory_tracking::names;
29
30 template <data_type_t src_type, data_type_t dst_type>
31 void gemm_x8s8s32x_inner_product_fwd_t<src_type, dst_type
32         >::execute_forward() const {
33     auto src = reinterpret_cast<const src_data_t *>(this->input_memory(0));
34     auto weights = reinterpret_cast<const wei_data_t *>(this->input_memory(1));
35     auto bias = reinterpret_cast<const char *>(this->input_memory(2));
36     auto dst = reinterpret_cast<dst_data_t *>(this->memory());
37
38     const int MB = pd()->MB();
39     const int OC = pd()->OC();
40
41     bool wei_tr = utils::one_of(pd()->weights_pd()->desc()->format,
42             oi, oiw, owi, oihw, ohwi, oidhw, odhwi);
43
44     const int M = OC;
45     const int N = MB;
46     const int K = pd()->IC_total_padded();
47     const int8_t off_a = 0, off_b = 0;
48     const int32_t off_c = 0;
49
50     const float *scales = pd()->attr()->output_scales_.scales_;
51
52     acc_data_t *acc = pd()->dst_is_acc_
53         ? (acc_data_t *)dst
54         : scratchpad().template get<acc_data_t>(key_iprod_int_dat_in_acc_dt);
55
56     const float onef = 1.0, zerof = 0.0;
57
58     if (src_type == data_type::u8) {
59         mkldnn_gemm_s8u8s32(wei_tr ? "T" : "N", "N", "F", &M, &N, &K, &onef,
60                 weights, wei_tr ? &K : &M, &off_a, (uint8_t *)src, &K, &off_b, &zerof,
61                 acc, &M, &off_c);
62     } else if (src_type == data_type::s8) {
63         mkldnn_gemm_s8s8s32(wei_tr ? "T" : "N", "N", "F", &M, &N, &K, &onef,
64                 weights, wei_tr ? &K : &M, &off_a, (int8_t *)src, &K, &off_b, &zerof,
65                 acc, &M, &off_c);
66     } else {
67         assert(!"incorrect src type");
68     }
69
70     if (!pd()->attr()->has_default_values() || !pd()->dst_is_acc_
71             || pd()->with_bias()) {
72         const bool force_sequential = MB * OC < 2000;
73         parallel(force_sequential ? 1 : 0, (size_t)OC * MB, [&](int ithr, int nthr) {
74             size_t start = 0, end = 0;
75             balance211((size_t)OC * MB, nthr, ithr, start, end);
76             (*pp_kernel_)(dst, acc, bias, scales, start, end);
77         });
78     }
79 }
80
81 using namespace data_type;
82
83 template struct gemm_x8s8s32x_inner_product_fwd_t<u8, f32>;
84 template struct gemm_x8s8s32x_inner_product_fwd_t<u8, s32>;
85 template struct gemm_x8s8s32x_inner_product_fwd_t<u8, s8>;
86 template struct gemm_x8s8s32x_inner_product_fwd_t<u8, u8>;
87 template struct gemm_x8s8s32x_inner_product_fwd_t<s8, f32>;
88 template struct gemm_x8s8s32x_inner_product_fwd_t<s8, s32>;
89 template struct gemm_x8s8s32x_inner_product_fwd_t<s8, s8>;
90 template struct gemm_x8s8s32x_inner_product_fwd_t<s8, u8>;
91 }
92 }
93 }