2 * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
17 #include "FusedBatchNormLayer.h"
19 #include <cker/operation/FusedBatchNorm.h>
// Creates an unconfigured layer: no inputs, a null output tensor, epsilon 0,
// is_training true, and a freshly heap-allocated cker FusedBatchNorm kernel.
// configure() must be called before run(), since run() dereferences _output.
30 FusedBatchNormLayer::FusedBatchNormLayer()
31 : _inputs(), _output(nullptr), _epsilon(0), _is_training(true),
32 _fusedbatchnorm_kernel(new nnfw::cker::FusedBatchNorm())
// Out-of-line defaulted destructor.
// NOTE(review): this assumes _fusedbatchnorm_kernel is an owning smart pointer
// (e.g. std::unique_ptr) for the kernel allocated with `new` above. If it is a raw
// pointer, this defaulted destructor leaks the kernel. Confirm in the header.
37 FusedBatchNormLayer::~FusedBatchNormLayer() = default;
// Runs the cker FusedBatchNorm kernel on float32 data.
// Gathers the shape and a raw float pointer for every configured input tensor
// (in _inputs order), copies the stored attributes (epsilon, is_training,
// data_format) into a FusedBatchNormParams, and invokes the kernel. The kernel
// receives _output's shape and a mutable float pointer to _output's buffer.
// NOTE(review): every input buffer is reinterpreted as float without a per-tensor
// dtype check; run() only checks _output's data type before calling this.
// NOTE(review): the expected meaning and order of the inputs is defined by
// nnfw::cker::FusedBatchNorm, not here. Confirm there.
39 void FusedBatchNormLayer::fusedbatchnormFloat32()
41 uint32_t num_inputs = _inputs.size();
42 nnfw::cker::FusedBatchNorm &kernel = *_fusedbatchnorm_kernel;
// Index-aligned vectors: inputShapes[i] and inputFloatPtrs[i] both describe _inputs[i].
46 std::vector<nnfw::cker::Shape> inputShapes;
47 std::vector<const float *> inputFloatPtrs;
49 for (uint32_t i = 0; i < num_inputs; i++)
51 inputShapes.emplace_back(getTensorShape(_inputs[i]));
52 inputFloatPtrs.emplace_back(reinterpret_cast<const float *>(_inputs[i]->buffer()));
// Copy the layer's stored attributes into the kernel parameter struct.
55 nnfw::cker::FusedBatchNormParams param;
57 param.epsilon = _epsilon;
58 param.is_training = _is_training;
59 param.data_format = _data_format;
// Hand all inputs, the output destination and the parameters to the kernel.
61 kernel(inputShapes, inputFloatPtrs, getTensorShape(_output),
62 reinterpret_cast<float *>(_output->buffer()), param);
65 void FusedBatchNormLayer::run()
67 if (_output->data_type() == OperandType::FLOAT32)
69 fusedbatchnormFloat32();
73 throw std::runtime_error{"FusedBatchNorm: unsupported data type"};
// Binds the layer to its tensors and batch-norm attributes; call before run().
// @param inputs       input tensors for the kernel; must be non-empty (asserted)
// @param epsilon      batch-norm epsilon; fusedbatchnormFloat32() passes _epsilon to the kernel
// @param is_training  stored in _is_training and forwarded as FusedBatchNormParams::is_training
// @param data_format  stored in _data_format and forwarded as FusedBatchNormParams::data_format
//                     (NOTE(review): accepted values are defined by cker. Confirm there.)
// @param output       destination tensor; must be non-null (asserted)
// The preconditions are checked only with assert(), so NDEBUG builds do not enforce them.
77 void FusedBatchNormLayer::configure(const std::vector<const IPortableTensor *> &inputs,
78 float epsilon, bool is_training, std::string data_format,
79 IPortableTensor *output)
81 assert(inputs.size() > 0);
82 assert(output != nullptr);
// NOTE(review): fusedbatchnormFloat32() reads _inputs, _output and _epsilon. Make sure
// they are assigned here along with the fields below.
87 _is_training = is_training;
// NOTE(review): data_format is a by-value sink parameter; std::move(data_format) would avoid a copy.
88 _data_format = data_format;
93 } // namespace backend