Add a section of how to link IE with CMake project (#99)
[platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / src / cpu / gemm / jit_avx512_common_gemm_f32.hpp
1 /*******************************************************************************
2 * Copyright 2017-2018 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16
17 #ifndef JIT_AVX512_COMMON_GEMM_F32_HPP
18 #define JIT_AVX512_COMMON_GEMM_F32_HPP
19
20 #include "c_types_map.hpp"
21 #include "../jit_generator.hpp"
22
23 namespace mkldnn {
24 namespace impl {
25 namespace cpu {
26
27 class jit_avx512_common_gemm_f32 {
28 public:
29     void sgemm(const char *transa, const char *transb, const int *M,
30             const int *N, const int *K, const float *alpha, const float *A,
31             const int *lda, const float *B, const int *ldb, const float *beta,
32             float *C, const int *ldc, const float *bias = NULL);
33
34     jit_avx512_common_gemm_f32(
35             char transa, char transb, float beta, bool hasBias = false);
36     ~jit_avx512_common_gemm_f32();
37
38 private:
39     typedef void (*ker)(long long int, long long int, long long int, float *,
40             float *, long long int, float *, long long int, float *, float *,
41             long long int, float *, float *);
42     void sgemm_nocopy_driver(const char *transa, const char *transb, int m,
43             int n, int k, const float *alpha, const float *a, int lda,
44             const float *b, int ldb, const float *beta, float *c, int ldc,
45             const float *bias, float *ws);
46
47     char transa_, transb_;
48     float beta_;
49     bool hasBias_;
50     struct xbyak_gemm;
51     xbyak_gemm *ker_bn_, *ker_b1_, *ker_b0_;
52     int nthrs_;
53 };
54 }
55 }
56 }
57
58 #endif