Imported Upstream version 1.25.0
[platform/core/ml/nnfw.git] / compute / cker / include / cker / operation / optimized / Gemm.h
1 /*
2  * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
3  * Copyright 2019 The TensorFlow Authors. All Rights Reserved.
4  *
5  * Licensed under the Apache License, Version 2.0 (the "License");
6  * you may not use this file except in compliance with the License.
7  * You may obtain a copy of the License at
8  *
9  *      http://www.apache.org/licenses/LICENSE-2.0
10  *
11  * Unless required by applicable law or agreed to in writing, software
12  * distributed under the License is distributed on an "AS IS" BASIS,
13  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14  * See the License for the specific language governing permissions and
15  * limitations under the License.
16  */
17
18 #ifndef __NNFW_CKER_OPTIMIZED_GEMM_H__
19 #define __NNFW_CKER_OPTIMIZED_GEMM_H__
20
21 #include "cker/eigen/eigen_gemm_eigen.h"
22 #include "cker/Shape.h"
23 #include "cker/Types.h"
24
25 #include <ruy/context.h>
26
27 namespace nnfw
28 {
29 namespace cker
30 {
31 namespace optimized
32 {
33
34 #if defined(CKER_X86_PLATFORM)
35
36 /* From tensorflow/tensorflow/lite/kernels/cpu_backend_gemm_x86.h */
37 template <typename LhsScalar, typename RhsScalar, typename AccumScalar, typename DstScalar,
38           QuantizationFlavor quantization_flavor>
39 struct GemmImplX86
40 {
41   static void Run(const MatrixParams<LhsScalar> &, const LhsScalar *,
42                   const MatrixParams<RhsScalar> &, const RhsScalar *,
43                   const MatrixParams<DstScalar> &, DstScalar *,
44                   const GemmParams<AccumScalar, DstScalar, quantization_flavor> &)
45   {
46     static_assert(
47       std::is_floating_point<LhsScalar>::value && std::is_floating_point<RhsScalar>::value &&
48         std::is_floating_point<AccumScalar>::value && std::is_floating_point<DstScalar>::value &&
49         quantization_flavor != QuantizationFlavor::kFloatingPoint,
50       "GemmImplX86 does not supported types other than float yet.");
51   }
52 };
53
54 // For float, defer to eigen for now.
55 template <> struct GemmImplX86<float, float, float, float, QuantizationFlavor::kFloatingPoint>
56 {
57   static void Run(const MatrixParams<float> &lhs_params, const float *lhs_data,
58                   const MatrixParams<float> &rhs_params, const float *rhs_data,
59                   const MatrixParams<float> &dst_params, float *dst_data,
60                   const GemmParams<float, float, QuantizationFlavor::kFloatingPoint> &params)
61   {
62     detail::GemmImplUsingEigen::Run(lhs_params, lhs_data, rhs_params, rhs_data, dst_params,
63                                     dst_data, params);
64   }
65 };
66
67 /* From tensorflow/tensorflow/lite/kernels/cpu_backend_gemm.h */
68 /* GEMM dispatch implementation for x86.
69  */
70 template <typename LhsScalar, typename RhsScalar, typename AccumScalar, typename DstScalar,
71           QuantizationFlavor quantization_flavor>
72 struct GemmImpl : GemmImplX86<LhsScalar, RhsScalar, AccumScalar, DstScalar, quantization_flavor>
73 {
74 };
75
76 /* From tensorflow/tensorflow/lite/kernels/cpu_backend_gemm.h */
77 template <typename LhsScalar, typename RhsScalar, typename AccumScalar, typename DstScalar,
78           QuantizationFlavor quantization_flavor>
79 void Gemm(const MatrixParams<LhsScalar> &lhs_params, const LhsScalar *lhs_data,
80           const MatrixParams<RhsScalar> &rhs_params, const RhsScalar *rhs_data,
81           const MatrixParams<DstScalar> &dst_params, DstScalar *dst_data,
82           const GemmParams<AccumScalar, DstScalar, quantization_flavor> &params)
83 {
84   // Generic case: dispatch to any backend as a general GEMM.
85   GemmImpl<LhsScalar, RhsScalar, AccumScalar, DstScalar, quantization_flavor>::Run(
86     lhs_params, lhs_data, rhs_params, rhs_data, dst_params, dst_data, params);
87 }
88
89 // From tensorflow/tensorflow/lite/kernels/cpu_backend_gemm_params.h
90 inline CachePolicy DefaultCachePolicy(bool is_constant_data)
91 {
92   return is_constant_data ? CachePolicy::kCacheIfLargeSpeedup : CachePolicy::kNeverCache;
93 }
94 #endif // CKER_X86_PLATFORM
95
96 } // namespace optimized
97 } // namespace cker
98 } // namespace nnfw
99
100 #endif // __NNFW_CKER_OPTIMIZED_GEMM_H__