Imported Upstream version 1.8.0
[platform/core/ml/nnfw.git] / compute / cker / include / cker / ruy / RuySupport.h
1 /*
2  * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
3  * Copyright 2017 The TensorFlow Authors. All Rights Reserved.
4  *
5  * Licensed under the Apache License, Version 2.0 (the "License");
6  * you may not use this file except in compliance with the License.
7  * You may obtain a copy of the License at
8  *
9  *      http://www.apache.org/licenses/LICENSE-2.0
10  *
11  * Unless required by applicable law or agreed to in writing, software
12  * distributed under the License is distributed on an "AS IS" BASIS,
13  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14  * See the License for the specific language governing permissions and
15  * limitations under the License.
16  */
17
18 #ifndef __NNFW_CKER_RUY_RUY_SUPPORT_H__
19 #define __NNFW_CKER_RUY_RUY_SUPPORT_H__
20
21 #include <util/ConfigSource.h>
22 #include <ruy/context.h>
23 #include "cker/Types.h"
24
25 namespace nnfw
26 {
27 namespace cker
28 {
29 namespace ruy_support
30 {
31
32 template <typename Scalar, typename DataPointer>
33 void MakeRuyMatrix(const MatrixParams<Scalar> &params, DataPointer data_ptr,
34                    ruy::Matrix<Scalar> *dst)
35 {
36   dst->layout.rows = params.rows;
37   dst->layout.cols = params.cols;
38   if (params.order == Order::kColMajor)
39   {
40     dst->layout.order = ruy::Order::kColMajor;
41     dst->layout.stride = params.rows;
42   }
43   else
44   {
45     dst->layout.order = ruy::Order::kRowMajor;
46     dst->layout.stride = params.cols;
47   }
48   // Note that ruy::Matrix::data is a ConstCheckingPtr, not a plain pointer.
49   // It does care whether we assign to it a Scalar* or a const Scalar*.
50   dst->data = data_ptr;
51   dst->zero_point = params.zero_point;
52   dst->cacheable = params.cacheable;
53 }
54
55 template <typename GemmParamsType, typename RuySpecType>
56 void MakeRuySpec(const GemmParamsType &params, RuySpecType *ruy_spec)
57 {
58   // This validation has already been performed by the Gemm API entry point,
59   // but it doesn't hurt to test specifically this again here, where it's
60   // being used.
61   ValidateGemmParams(params);
62
63   ruy_spec->multiplier_fixedpoint = params.multiplier_fixedpoint;
64   ruy_spec->multiplier_exponent = params.multiplier_exponent;
65   ruy_spec->multiplier_fixedpoint_perchannel = params.multiplier_fixedpoint_perchannel;
66   ruy_spec->multiplier_exponent_perchannel = params.multiplier_exponent_perchannel;
67   ruy_spec->bias = params.bias;
68   ruy_spec->clamp_min = params.clamp_min;
69   ruy_spec->clamp_max = params.clamp_max;
70 }
71
72 } // namespace ruy_support
73 } // namespace cker
74 } // namespace nnfw
75
76 #endif // __NNFW_CKER_RUY_RUY_SUPPORT_H__