Imported Upstream version 1.22.1
[platform/core/ml/nnfw.git] / onert-micro / luci-interpreter / pal / linux / PALBatchMatMul.h
1 /*
2  * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
3  * Copyright 2019 The TensorFlow Authors. All Rights Reserved.
4  *
5  * Licensed under the Apache License, Version 2.0 (the "License");
6  * you may not use this file except in compliance with the License.
7  * You may obtain a copy of the License at
8  *
9  *    http://www.apache.org/licenses/LICENSE-2.0
10  *
11  * Unless required by applicable law or agreed to in writing, software
12  * distributed under the License is distributed on an "AS IS" BASIS,
13  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14  * See the License for the specific language governing permissions and
15  * limitations under the License.
16  */
17
18 #ifndef LUCI_INTERPRETER_PAL_BATCHMATMUL_H
19 #define LUCI_INTERPRETER_PAL_BATCHMATMUL_H
20
#include <tensorflow/lite/kernels/internal/reference/batch_matmul.h>

#include <cassert>
22
23 namespace luci_interpreter_pal
24 {
25 inline void BatchMatMul(const tflite::RuntimeShape &lhs_shape, const float *lhs_data,
26                         const tflite::RuntimeShape &rhs_shape, const float *rhs_data,
27                         const tflite::RuntimeShape &output_shape, float *output_data)
28 {
29   tflite::reference_ops::BatchMatMul(lhs_shape, lhs_data, rhs_shape, rhs_data, output_shape,
30                                      output_data);
31 }
32
33 static inline void SetupScratchpadTensor(luci_interpreter::Tensor *lhs_scratchpad,
34                                          luci_interpreter::Tensor *rhs_scratchpad,
35                                          const tflite::RuntimeShape &lhs_shape,
36                                          const tflite::RuntimeShape &rhs_shape)
37 {
38   // Scratchpad for transposed LHS
39   {
40     auto lhs_rank = lhs_shape.DimensionsCount();
41     luci_interpreter::Shape scratchpad_size(lhs_rank);
42     for (int i = 0; i < lhs_rank - 2; ++i)
43     {
44       scratchpad_size.dim(i) = lhs_shape.Dims(i);
45     }
46     scratchpad_size.dim(lhs_rank - 2) = lhs_shape.Dims(lhs_rank - 1);
47     scratchpad_size.dim(lhs_rank - 1) = lhs_shape.Dims(lhs_rank - 2);
48
49     lhs_scratchpad->resize(scratchpad_size);
50   }
51   // Scratchpad for transposed RHS
52   {
53     auto rhs_rank = rhs_shape.DimensionsCount();
54     luci_interpreter::Shape scratchpad_size(rhs_rank);
55     for (int i = 0; i < rhs_rank - 2; ++i)
56     {
57       scratchpad_size.dim(i) = rhs_shape.Dims(i);
58     }
59     scratchpad_size.dim(rhs_rank - 2) = rhs_shape.Dims(rhs_rank - 1);
60     scratchpad_size.dim(rhs_rank - 1) = rhs_shape.Dims(rhs_rank - 2);
61
62     rhs_scratchpad->resize(scratchpad_size);
63   }
64 }
65
66 } // namespace luci_interpreter_pal
67
68 #endif // LUCI_INTERPRETER_PAL_BATCHMATMUL_H