2 * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
17 #include "BatchMatMulLayer.h"
19 #include <cker/operation/BatchMatMul.h>
// Creates an unconfigured layer: the lhs/rhs/output tensor pointers start as nullptr and both
// adjoint flags (_adj_x, _adj_y) start as false. NOTE(review): these are presumably filled in by
// configure() -- its assignments are not visible in this view.
// The cker BatchMatMul kernel object is heap-allocated here, once per layer instance.
BatchMatMulLayer::BatchMatMulLayer()
  : _lhs(nullptr), _rhs(nullptr), _output(nullptr), _adj_x(false), _adj_y(false),
    _kernel(new nnfw::cker::BatchMatMul())
// Defaulted, but defined out of line in this source file. NOTE(review): presumably so that the
// owner of _kernel is destroyed where nnfw::cker::BatchMatMul is a complete type (the header may
// only forward-declare it) -- confirm against BatchMatMulLayer.h.
BatchMatMulLayer::~BatchMatMulLayer() = default;
39 void BatchMatMulLayer::batchMatMulFloat32()
41 nnfw::cker::BatchMatMul &batchmatmul_kernel = *_kernel;
42 nnfw::cker::Shape lhs_shape = getShape(_lhs);
43 nnfw::cker::Shape rhs_shape = getShape(_rhs);
44 nnfw::cker::Shape output_shape = getShape(_output);
46 // TODO implement for constant input
48 batchmatmul_kernel.prepare(lhs_shape, rhs_shape, _adj_x, _adj_y);
49 batchmatmul_kernel(lhs_shape, getBuffer<float>(_lhs), rhs_shape, getBuffer<float>(_rhs), _adj_x,
50 _adj_y, output_shape, getBuffer<float>(_output));
// Binds this layer to its operands.
//   lhs, rhs     : input tensors (read-only); must be non-null.
//   adj_x, adj_y : flags that batchMatMulFloat32() forwards to the cker kernel via _adj_x/_adj_y.
//   output       : destination tensor; must be non-null.
// Null arguments are only caught by assert(), i.e. in builds where NDEBUG is not defined.
// NOTE(review): the member assignments that follow the asserts are not visible in this view.
void BatchMatMulLayer::configure(const IPortableTensor *lhs, const IPortableTensor *rhs, bool adj_x,
                                 bool adj_y, IPortableTensor *output)
assert(lhs != nullptr);
assert(rhs != nullptr);
assert(output != nullptr);
// Executes the layer by dispatching on operand data type.
// Only the case where both lhs and rhs are FLOAT32 is accepted; the output tensor's type is not
// checked here. Any other combination ends in std::runtime_error.
void BatchMatMulLayer::run()
if ((_lhs->data_type() == OperandType::FLOAT32) && (_rhs->data_type() == OperandType::FLOAT32))
// NOTE(review): the body of this branch is not visible here -- presumably it calls
// batchMatMulFloat32().
throw std::runtime_error{"BatchMatMul: unsupported data type"};
// NOTE(review): nothing visible in this file defines or uses AVGPOOLING_PARAMETERS; this #undef
// looks like a leftover copied from an average-pooling layer implementation -- confirm and remove.
#undef AVGPOOLING_PARAMETERS
83 } // namespace backend