Imported Upstream version 1.7.0
[platform/core/ml/nnfw.git] / runtime / onert / backend / cpu / ops / FusedBatchNormLayer.cc
1 /*
2  * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include "FusedBatchNormLayer.h"
18
19 #include <cker/operation/FusedBatchNorm.h>
20
21 namespace onert
22 {
23 namespace backend
24 {
25 namespace cpu
26 {
27 namespace ops
28 {
29
30 FusedBatchNormLayer::FusedBatchNormLayer()
31     : _inputs(), _output(nullptr), _epsilon(0), _is_training(true),
32       _fusedbatchnorm_kernel(new nnfw::cker::FusedBatchNorm())
33 {
34   // DO NOTHING
35 }
36
37 FusedBatchNormLayer::~FusedBatchNormLayer() = default;
38
39 void FusedBatchNormLayer::fusedbatchnormFloat32()
40 {
41   uint32_t num_inputs = _inputs.size();
42   nnfw::cker::FusedBatchNorm &kernel = *_fusedbatchnorm_kernel;
43
44   kernel.prepare();
45
46   std::vector<nnfw::cker::Shape> inputShapes;
47   std::vector<const float *> inputFloatPtrs;
48
49   for (uint32_t i = 0; i < num_inputs; i++)
50   {
51     inputShapes.emplace_back(getTensorShape(_inputs[i]));
52     inputFloatPtrs.emplace_back(reinterpret_cast<const float *>(_inputs[i]->buffer()));
53   }
54
55   nnfw::cker::FusedBatchNormParams param;
56
57   param.epsilon = _epsilon;
58   param.is_training = _is_training;
59   param.data_format = _data_format;
60
61   kernel(inputShapes, inputFloatPtrs, getTensorShape(_output),
62          reinterpret_cast<float *>(_output->buffer()), param);
63 }
64
65 void FusedBatchNormLayer::run()
66 {
67   if (_output->data_type() == OperandType::FLOAT32)
68   {
69     fusedbatchnormFloat32();
70   }
71   else
72   {
73     throw std::runtime_error{"FusedBatchNorm: unsupported data type"};
74   }
75 }
76
77 void FusedBatchNormLayer::configure(const std::vector<const IPortableTensor *> &inputs,
78                                     float epsilon, bool is_training, std::string data_format,
79                                     IPortableTensor *output)
80 {
81   assert(inputs.size() > 0);
82   assert(output != nullptr);
83
84   _inputs = inputs;
85   _output = output;
86   _epsilon = epsilon;
87   _is_training = is_training;
88   _data_format = data_format;
89 }
90
91 } // namespace ops
92 } // namespace cpu
93 } // namespace backend
94 } // namespace onert