Imported Upstream version 1.25.0
[platform/core/ml/nnfw.git] / runtime / onert / backend / cpu / ops / BatchMatMulLayer.cc
1 /*
2  * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include "BatchMatMulLayer.h"
18
19 #include <cker/operation/BatchMatMul.h>
20
21 namespace onert
22 {
23 namespace backend
24 {
25 namespace cpu
26 {
27 namespace ops
28 {
29
30 BatchMatMulLayer::BatchMatMulLayer()
31   : _lhs(nullptr), _rhs(nullptr), _output(nullptr), _adj_x(false), _adj_y(false),
32     _kernel(new nnfw::cker::BatchMatMul())
33 {
34   // DO NOTHING
35 }
36
37 BatchMatMulLayer::~BatchMatMulLayer() = default;
38
39 void BatchMatMulLayer::batchMatMulFloat32()
40 {
41   nnfw::cker::BatchMatMul &batchmatmul_kernel = *_kernel;
42   nnfw::cker::Shape lhs_shape = getShape(_lhs);
43   nnfw::cker::Shape rhs_shape = getShape(_rhs);
44   nnfw::cker::Shape output_shape = getShape(_output);
45
46   // TODO implement for constant input
47
48   batchmatmul_kernel.prepare(lhs_shape, rhs_shape, _adj_x, _adj_y);
49   batchmatmul_kernel(lhs_shape, getBuffer<float>(_lhs), rhs_shape, getBuffer<float>(_rhs), _adj_x,
50                      _adj_y, output_shape, getBuffer<float>(_output));
51 }
52
53 void BatchMatMulLayer::configure(const IPortableTensor *lhs, const IPortableTensor *rhs, bool adj_x,
54                                  bool adj_y, IPortableTensor *output)
55 {
56   assert(lhs != nullptr);
57   assert(rhs != nullptr);
58   assert(output != nullptr);
59
60   _lhs = lhs;
61   _rhs = rhs;
62   _adj_x = adj_x;
63   _adj_y = adj_y;
64   _output = output;
65 }
66
67 void BatchMatMulLayer::run()
68 {
69   if ((_lhs->data_type() == OperandType::FLOAT32) && (_rhs->data_type() == OperandType::FLOAT32))
70   {
71     batchMatMulFloat32();
72   }
73   else
74   {
75     throw std::runtime_error{"BatchMatMul: unsupported data type"};
76   }
77 }
78
79 #undef AVGPOOLING_PARAMETERS
80
81 } // namespace ops
82 } // namespace cpu
83 } // namespace backend
84 } // namespace onert