/*
 * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
 * Copyright 2019 The TensorFlow Authors. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
18 #ifndef __NNFW_CKER_OPTIMIZED_GEMM_H__
19 #define __NNFW_CKER_OPTIMIZED_GEMM_H__
#include "cker/eigen/eigen_gemm_eigen.h"
#include "cker/Shape.h"
#include "cker/Types.h"

#include <ruy/context.h>

#include <type_traits>
34 #if defined(CKER_X86_PLATFORM)
36 /* From tensorflow/tensorflow/lite/kernels/cpu_backend_gemm_x86.h */
37 template <typename LhsScalar, typename RhsScalar, typename AccumScalar, typename DstScalar,
38 QuantizationFlavor quantization_flavor>
41 static void Run(const MatrixParams<LhsScalar> &, const LhsScalar *,
42 const MatrixParams<RhsScalar> &, const RhsScalar *,
43 const MatrixParams<DstScalar> &, DstScalar *,
44 const GemmParams<AccumScalar, DstScalar, quantization_flavor> &)
47 std::is_floating_point<LhsScalar>::value && std::is_floating_point<RhsScalar>::value &&
48 std::is_floating_point<AccumScalar>::value && std::is_floating_point<DstScalar>::value &&
49 quantization_flavor != QuantizationFlavor::kFloatingPoint,
50 "GemmImplX86 does not supported types other than float yet.");
54 // For float, defer to eigen for now.
55 template <> struct GemmImplX86<float, float, float, float, QuantizationFlavor::kFloatingPoint>
57 static void Run(const MatrixParams<float> &lhs_params, const float *lhs_data,
58 const MatrixParams<float> &rhs_params, const float *rhs_data,
59 const MatrixParams<float> &dst_params, float *dst_data,
60 const GemmParams<float, float, QuantizationFlavor::kFloatingPoint> ¶ms)
62 detail::GemmImplUsingEigen::Run(lhs_params, lhs_data, rhs_params, rhs_data, dst_params,
67 /* From tensorflow/tensorflow/lite/kernels/cpu_backend_gemm.h */
68 /* GEMM dispatch implementation for x86.
70 template <typename LhsScalar, typename RhsScalar, typename AccumScalar, typename DstScalar,
71 QuantizationFlavor quantization_flavor>
72 struct GemmImpl : GemmImplX86<LhsScalar, RhsScalar, AccumScalar, DstScalar, quantization_flavor>
76 /* From tensorflow/tensorflow/lite/kernels/cpu_backend_gemm.h */
77 template <typename LhsScalar, typename RhsScalar, typename AccumScalar, typename DstScalar,
78 QuantizationFlavor quantization_flavor>
79 void Gemm(const MatrixParams<LhsScalar> &lhs_params, const LhsScalar *lhs_data,
80 const MatrixParams<RhsScalar> &rhs_params, const RhsScalar *rhs_data,
81 const MatrixParams<DstScalar> &dst_params, DstScalar *dst_data,
82 const GemmParams<AccumScalar, DstScalar, quantization_flavor> ¶ms)
84 // Generic case: dispatch to any backend as a general GEMM.
85 GemmImpl<LhsScalar, RhsScalar, AccumScalar, DstScalar, quantization_flavor>::Run(
86 lhs_params, lhs_data, rhs_params, rhs_data, dst_params, dst_data, params);
89 // From tensorflow/tensorflow/lite/kernels/cpu_backend_gemm_params.h
90 inline CachePolicy DefaultCachePolicy(bool is_constant_data)
92 return is_constant_data ? CachePolicy::kCacheIfLargeSpeedup : CachePolicy::kNeverCache;
94 #endif // CKER_X86_PLATFORM
96 } // namespace optimized
100 #endif // __NNFW_CKER_OPTIMIZED_GEMM_H__