2 * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
17 #include "BatchMatMulLayer.h"
19 #include <cker/operation/BatchMatMul.h>
// Constructs an unconfigured BatchMatMul layer.
// The lhs/rhs/output tensor pointers start as nullptr and both adjoint flags start as false;
// configure() is expected to supply the real values before run().
// A cker BatchMatMul kernel object is allocated up front and reused by every run.
// NOTE(review): the constructor body is not visible in this excerpt.
// NOTE(review): _kernel is initialised from a raw `new`; presumably it is a smart pointer
// declared in BatchMatMulLayer.h -- if it is a std::unique_ptr, consider std::make_unique
// (when the project's C++ standard allows it). Confirm against the header.
30 BatchMatMulLayer::BatchMatMulLayer()
31 : _lhs(nullptr), _rhs(nullptr), _output(nullptr), _adj_x(false), _adj_y(false),
32 _kernel(new nnfw::cker::BatchMatMul())
// Defaulted destructor, defined out-of-line in this translation unit.
// NOTE(review): presumably kept out of the header so nnfw::cker::BatchMatMul only needs to be a
// complete type here (cker/operation/BatchMatMul.h is included above) -- confirm.
37 BatchMatMulLayer::~BatchMatMulLayer() = default;
39 void BatchMatMulLayer::batchMatMulFloat32()
41 nnfw::cker::BatchMatMul &batchmatmul_kernel = *_kernel;
42 nnfw::cker::Shape lhs_shape = getTensorShape(_lhs);
43 nnfw::cker::Shape rhs_shape = getTensorShape(_rhs);
44 nnfw::cker::Shape output_shape = getTensorShape(_output);
46 // TODO implement for constant input
48 batchmatmul_kernel.prepare(lhs_shape, rhs_shape, _adj_x, _adj_y);
49 batchmatmul_kernel(lhs_shape, reinterpret_cast<const float *>(_lhs->buffer()), rhs_shape,
50 reinterpret_cast<const float *>(_rhs->buffer()), _adj_x, _adj_y, output_shape,
51 reinterpret_cast<float *>(_output->buffer()));
// Binds the layer to its input/output tensors and adjoint flags.
//   lhs, rhs - input tensors (read-only through this interface).
//   adj_x    - adjoint flag for lhs; adj_y - adjoint flag for rhs.
//   output   - destination tensor.
// All three tensor pointers must be non-null. This is enforced only by assert(), so the
// checks disappear in NDEBUG builds.
// NOTE(review): the rest of this function is not visible in this excerpt. Presumably it stores
// the arguments into _lhs/_rhs/_output/_adj_x/_adj_y, which batchMatMulFloat32() reads -- confirm.
54 void BatchMatMulLayer::configure(const IPortableTensor *lhs, const IPortableTensor *rhs, bool adj_x,
55 bool adj_y, IPortableTensor *output)
57 assert(lhs != nullptr);
58 assert(rhs != nullptr);
59 assert(output != nullptr);
// Executes the layer for the currently configured tensors.
// Only the case where both lhs and rhs are FLOAT32 is handled. The output tensor's data type is
// not part of the check below.
// NOTE(review): the FLOAT32 branch body and the else structure are not visible in this excerpt.
// Presumably the branch dispatches to batchMatMulFloat32(), and the throw below is the fallback
// for any other type combination -- confirm.
// Throws std::runtime_error for unsupported data types.
68 void BatchMatMulLayer::run()
70 if ((_lhs->data_type() == OperandType::FLOAT32) && (_rhs->data_type() == OperandType::FLOAT32))
76 throw std::runtime_error{"BatchMatMul: unsupported data type"};
84 } // namespace backend