2 * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
3 * Copyright 2017 The TensorFlow Authors. All Rights Reserved.
5 * Licensed under the Apache License, Version 2.0 (the "License");
6 * you may not use this file except in compliance with the License.
7 * You may obtain a copy of the License at
9 * http://www.apache.org/licenses/LICENSE-2.0
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
18 #ifndef __NNFW_CKER_RUY_RUY_SUPPORT_H__
19 #define __NNFW_CKER_RUY_RUY_SUPPORT_H__
21 #include <util/ConfigSource.h>
22 #include <ruy/context.h>
23 #include "cker/Types.h"
// Worker-thread cap that RuyContext::SetMaxNumThreads() falls back to when the
// requested thread count is negative.
constexpr int kDefaultNumThreadpoolThreads{4};
40 RuyContext() : ruy_context_(new ruy::Context)
42 SetMaxNumThreads(onert::util::getConfigInt(onert::util::config::RUY_THREADS));
44 ruy_context_->cache_policy = ruy::kCacheLHSOnNarrowMul;
48 ruy::Context *ruy_context() const { return ruy_context_.get(); }
50 static inline RuyContext &GetRuyContext()
52 static thread_local RuyContext instance;
56 void SetMaxNumThreads(int max_num_threads)
58 const int target_num_threads =
59 max_num_threads > -1 ? max_num_threads : kDefaultNumThreadpoolThreads;
60 ruy_context_->max_num_threads = target_num_threads;
// The owned ruy::Context, created once in the constructor's initializer list.
// The unique_ptr itself is const, so it can never be reseated or released.
// This also makes RuyContext non-copyable and non-movable.
const std::unique_ptr<ruy::Context> ruy_context_;
67 inline ruy::Context *GetRuyContext()
69 auto &ctx = RuyContext::GetRuyContext();
70 return ctx.ruy_context();
73 template <typename Scalar, typename DataPointer>
74 void MakeRuyMatrix(const MatrixParams<Scalar> ¶ms, DataPointer data_ptr,
75 ruy::Matrix<Scalar> *dst)
77 dst->layout.rows = params.rows;
78 dst->layout.cols = params.cols;
79 if (params.order == Order::kColMajor)
81 dst->layout.order = ruy::Order::kColMajor;
82 dst->layout.stride = params.rows;
86 dst->layout.order = ruy::Order::kRowMajor;
87 dst->layout.stride = params.cols;
89 // Note that ruy::Matrix::data is a ConstCheckingPtr, not a plain pointer.
90 // It does care whether we assign to it a Scalar* or a const Scalar*.
92 dst->zero_point = params.zero_point;
93 dst->cacheable = params.cacheable;
// Copies the output-stage parameters of a cker Gemm call into a ruy spec:
// the fixed-point multiplier and exponent (global and per-channel), the bias,
// and the clamp range.
//
// @param params   cker Gemm parameters. They are re-validated here.
// @param ruy_spec spec to fill in. It is dereferenced unconditionally, so it
//                 must not be null.
// NOTE(review): if `bias` and the per-channel fields are pointers, they are
// copied shallowly, so the data they point to must outlive the ruy call.
// Confirm against the GemmParams definition.
template <typename GemmParamsType, typename RuySpecType>
void MakeRuySpec(const GemmParamsType &params, RuySpecType *ruy_spec)
{
  // This validation has already been performed by the Gemm API entry point,
  // but it doesn't hurt to test specifically this again here, where the
  // parameters are actually consumed.
  ValidateGemmParams(params);

  ruy_spec->multiplier_fixedpoint = params.multiplier_fixedpoint;
  ruy_spec->multiplier_exponent = params.multiplier_exponent;
  ruy_spec->multiplier_fixedpoint_perchannel = params.multiplier_fixedpoint_perchannel;
  ruy_spec->multiplier_exponent_perchannel = params.multiplier_exponent_perchannel;
  ruy_spec->bias = params.bias;
  ruy_spec->clamp_min = params.clamp_min;
  ruy_spec->clamp_max = params.clamp_max;
}
113 } // namespace ruy_support
117 #endif // __NNFW_CKER_RUY_RUY_SUPPORT_H__