432b181bdb4ca74f7cd85d058f45519a8c9de7c4
[platform/core/ml/nnfw.git] / compute / cker / include / cker / ruy / RuySupport.h
1 /*
2  * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
3  * Copyright 2017 The TensorFlow Authors. All Rights Reserved.
4  *
5  * Licensed under the Apache License, Version 2.0 (the "License");
6  * you may not use this file except in compliance with the License.
7  * You may obtain a copy of the License at
8  *
9  *      http://www.apache.org/licenses/LICENSE-2.0
10  *
11  * Unless required by applicable law or agreed to in writing, software
12  * distributed under the License is distributed on an "AS IS" BASIS,
13  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14  * See the License for the specific language governing permissions and
15  * limitations under the License.
16  */
17
18 #ifndef __NNFW_CKER_RUY_RUY_SUPPORT_H__
19 #define __NNFW_CKER_RUY_RUY_SUPPORT_H__
20
21 #include <util/ConfigSource.h>
22 #include <ruy/context.h>
23 #include "cker/Types.h"
24
25 namespace
26 {
27 const int kDefaultNumThreadpoolThreads = 4;
28 }
29
30 namespace nnfw
31 {
32 namespace cker
33 {
34 namespace ruy_support
35 {
36
37 struct RuyContext
38 {
39 public:
40   RuyContext() : ruy_context_(new ruy::Context)
41   {
42     SetMaxNumThreads(onert::util::getConfigInt(onert::util::config::RUY_THREADS));
43 #ifdef USE_RUY_GEMV
44     ruy_context_->cache_policy = ruy::kCacheLHSOnNarrowMul;
45 #endif
46   };
47
48   ruy::Context *ruy_context() const { return ruy_context_.get(); }
49
50   static inline RuyContext &GetRuyContext()
51   {
52     static thread_local RuyContext instance;
53     return instance;
54   }
55
56   void SetMaxNumThreads(int max_num_threads)
57   {
58     const int target_num_threads =
59         max_num_threads > -1 ? max_num_threads : kDefaultNumThreadpoolThreads;
60     ruy_context_->max_num_threads = target_num_threads;
61   }
62
63 private:
64   const std::unique_ptr<ruy::Context> ruy_context_;
65 };
66
67 inline ruy::Context *GetRuyContext()
68 {
69   auto &ctx = RuyContext::GetRuyContext();
70   return ctx.ruy_context();
71 }
72
73 template <typename Scalar, typename DataPointer>
74 void MakeRuyMatrix(const MatrixParams<Scalar> &params, DataPointer data_ptr,
75                    ruy::Matrix<Scalar> *dst)
76 {
77   dst->layout.rows = params.rows;
78   dst->layout.cols = params.cols;
79   if (params.order == Order::kColMajor)
80   {
81     dst->layout.order = ruy::Order::kColMajor;
82     dst->layout.stride = params.rows;
83   }
84   else
85   {
86     dst->layout.order = ruy::Order::kRowMajor;
87     dst->layout.stride = params.cols;
88   }
89   // Note that ruy::Matrix::data is a ConstCheckingPtr, not a plain pointer.
90   // It does care whether we assign to it a Scalar* or a const Scalar*.
91   dst->data = data_ptr;
92   dst->zero_point = params.zero_point;
93   dst->cacheable = params.cacheable;
94 }
95
96 template <typename GemmParamsType, typename RuySpecType>
97 void MakeRuySpec(const GemmParamsType &params, RuySpecType *ruy_spec)
98 {
99   // This validation has already been performed by the Gemm API entry point,
100   // but it doesn't hurt to test specifically this again here, where it's
101   // being used.
102   ValidateGemmParams(params);
103
104   ruy_spec->multiplier_fixedpoint = params.multiplier_fixedpoint;
105   ruy_spec->multiplier_exponent = params.multiplier_exponent;
106   ruy_spec->multiplier_fixedpoint_perchannel = params.multiplier_fixedpoint_perchannel;
107   ruy_spec->multiplier_exponent_perchannel = params.multiplier_exponent_perchannel;
108   ruy_spec->bias = params.bias;
109   ruy_spec->clamp_min = params.clamp_min;
110   ruy_spec->clamp_max = params.clamp_max;
111 }
112
113 } // namespace ruy_support
114 } // namespace cker
115 } // namespace nnfw
116
117 #endif // __NNFW_CKER_RUY_RUY_SUPPORT_H__