Imported Upstream version 1.12.0
[platform/core/ml/nnfw.git] / runtime / onert / backend / cpu / ops / BatchMatMulLayer.cc
1 /*
2  * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include "BatchMatMulLayer.h"
18
19 #include <cker/operation/BatchMatMul.h>
20
21 namespace onert
22 {
23 namespace backend
24 {
25 namespace cpu
26 {
27 namespace ops
28 {
29
30 BatchMatMulLayer::BatchMatMulLayer()
31     : _lhs(nullptr), _rhs(nullptr), _output(nullptr), _adj_x(false), _adj_y(false),
32       _kernel(new nnfw::cker::BatchMatMul())
33 {
34   // DO NOTHING
35 }
36
37 BatchMatMulLayer::~BatchMatMulLayer() = default;
38
39 void BatchMatMulLayer::batchMatMulFloat32()
40 {
41   nnfw::cker::BatchMatMul &batchmatmul_kernel = *_kernel;
42   nnfw::cker::Shape lhs_shape = getTensorShape(_lhs);
43   nnfw::cker::Shape rhs_shape = getTensorShape(_rhs);
44   nnfw::cker::Shape output_shape = getTensorShape(_output);
45
46   // TODO implement for constant input
47
48   batchmatmul_kernel.prepare(lhs_shape, rhs_shape, _adj_x, _adj_y);
49   batchmatmul_kernel(lhs_shape, reinterpret_cast<const float *>(_lhs->buffer()), rhs_shape,
50                      reinterpret_cast<const float *>(_rhs->buffer()), _adj_x, _adj_y, output_shape,
51                      reinterpret_cast<float *>(_output->buffer()));
52 }
53
54 void BatchMatMulLayer::configure(const IPortableTensor *lhs, const IPortableTensor *rhs, bool adj_x,
55                                  bool adj_y, IPortableTensor *output)
56 {
57   assert(lhs != nullptr);
58   assert(rhs != nullptr);
59   assert(output != nullptr);
60
61   _lhs = lhs;
62   _rhs = rhs;
63   _adj_x = adj_x;
64   _adj_y = adj_y;
65   _output = output;
66 }
67
68 void BatchMatMulLayer::run()
69 {
70   if ((_lhs->data_type() == OperandType::FLOAT32) && (_rhs->data_type() == OperandType::FLOAT32))
71   {
72     batchMatMulFloat32();
73   }
74   else
75   {
76     throw std::runtime_error{"BatchMatMul: unsupported data type"};
77   }
78 }
79
80 #undef AVGPOOLING_PARAMETERS
81
82 } // namespace ops
83 } // namespace cpu
84 } // namespace backend
85 } // namespace onert