Imported Upstream version 1.12.0
[platform/core/ml/nnfw.git] / compute / cker / include / cker / ruy / RuySupport.h
1 /*
2  * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
3  * Copyright 2017 The TensorFlow Authors. All Rights Reserved.
4  *
5  * Licensed under the Apache License, Version 2.0 (the "License");
6  * you may not use this file except in compliance with the License.
7  * You may obtain a copy of the License at
8  *
9  *      http://www.apache.org/licenses/LICENSE-2.0
10  *
11  * Unless required by applicable law or agreed to in writing, software
12  * distributed under the License is distributed on an "AS IS" BASIS,
13  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14  * See the License for the specific language governing permissions and
15  * limitations under the License.
16  */
17
18 #ifndef __NNFW_CKER_RUY_RUY_SUPPORT_H__
19 #define __NNFW_CKER_RUY_RUY_SUPPORT_H__
20
21 #include <util/ConfigSource.h>
22 #include <ruy/matrix.h>
23 #include <ruy/ruy.h>
24 #include <cassert>
25 #include "cker/Types.h"
26
27 namespace nnfw
28 {
29 namespace cker
30 {
31 namespace ruy_support
32 {
33
34 inline ruy::CachePolicy ToRuyCachePolicy(CachePolicy cache_policy)
35 {
36   switch (cache_policy)
37   {
38     case CachePolicy::kNeverCache:
39       return ruy::CachePolicy::kNeverCache;
40     case CachePolicy::kCacheIfLargeSpeedup:
41       return ruy::CachePolicy::kCacheIfLargeSpeedup;
42     case CachePolicy::kAlwaysCache:
43       return ruy::CachePolicy::kAlwaysCache;
44     default:
45       assert(false);
46       return ruy::CachePolicy::kNeverCache;
47   }
48 }
49
50 template <typename Scalar, typename DataPointer>
51 void MakeRuyMatrix(const MatrixParams<Scalar> &params, DataPointer data_ptr,
52                    ruy::Matrix<Scalar> *dst, bool use_caching = false)
53 {
54   ruy::Order ruy_order =
55     params.order == Order::kColMajor ? ruy::Order::kColMajor : ruy::Order::kRowMajor;
56   ruy::MakeSimpleLayout(params.rows, params.cols, ruy_order, dst->mutable_layout());
57   // Note that ruy::Matrix::data is a ConstCheckingPtr, not a plain pointer.
58   // It does care whether we assign to it a Scalar* or a const Scalar*.
59   dst->set_data(data_ptr);
60   dst->set_zero_point(params.zero_point);
61   if (use_caching)
62   {
63     dst->set_cache_policy(ToRuyCachePolicy(params.cache_policy));
64   }
65 }
66
67 template <typename GemmParamsType, typename RuySpecType>
68 void MakeRuyMulParams(const GemmParamsType &params, RuySpecType *ruy_mul_params)
69 {
70   // This validation has already been performed by the Gemm API entry point,
71   // but it doesn't hurt to test specifically this again here, where it's
72   // being used.
73   ValidateGemmParams(params);
74
75   ruy_mul_params->set_multiplier_fixedpoint(params.multiplier_fixedpoint);
76   ruy_mul_params->set_multiplier_exponent(params.multiplier_exponent);
77   ruy_mul_params->set_multiplier_fixedpoint_perchannel(params.multiplier_fixedpoint_perchannel);
78   ruy_mul_params->set_multiplier_exponent_perchannel(params.multiplier_exponent_perchannel);
79   ruy_mul_params->set_bias(params.bias);
80   ruy_mul_params->set_clamp_min(params.clamp_min);
81   ruy_mul_params->set_clamp_max(params.clamp_max);
82 }
83
84 } // namespace ruy_support
85 } // namespace cker
86 } // namespace nnfw
87
88 #endif // __NNFW_CKER_RUY_RUY_SUPPORT_H__