Imported Upstream version 1.12.0
[platform/core/ml/nnfw.git] / compute / cker / include / cker / operation / FusedBatchNorm.h
1 /*
2  * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
3  * Copyright 2019 The TensorFlow Authors. All Rights Reserved.
4  *
5  * Licensed under the Apache License, Version 2.0 (the "License");
6  * you may not use this file except in compliance with the License.
7  * You may obtain a copy of the License at
8  *
9  *      http://www.apache.org/licenses/LICENSE-2.0
10  *
11  * Unless required by applicable law or agreed to in writing, software
12  * distributed under the License is distributed on an "AS IS" BASIS,
13  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14  * See the License for the specific language governing permissions and
15  * limitations under the License.
16  */
17
18 #ifndef __NNFW_CKER_FUSEDBATCHNORM_H__
19 #define __NNFW_CKER_FUSEDBATCHNORM_H__
20
21 #include "cker/Types.h"
22 #include "cker/Shape.h"
23 #include "cker/Utils.h"
24
25 #include "cker/operation/Helper/Tensor.h"
26 #include "cker/operation/Helper/MatmulBCast.h"
27
28 #include "Transpose.h"
29 #include "BatchMatMul.h"
30
31 #include <string>
32 #include <vector>
33 #include <map>
34 #include <numeric>
35 #include <algorithm>
36
37 namespace nnfw
38 {
39 namespace cker
40 {
41
42 class FusedBatchNorm
43 {
44 public:
45   FusedBatchNorm() : _prepared(false)
46   {
47     // DO NOTHING
48   }
49
50   void prepare() { _prepared = true; }
51
52   void operator()(const std::vector<Shape> &input_shapes,
53                   const std::vector<const float *> &input_data, const Shape &output_shape,
54                   float *output_data, FusedBatchNormParams param)
55   {
56     // TODO: support fused_batch_norm if is_traninig is false
57     assert(param.is_training == true);
58
59     // TODO: support case where dim[1] != 1 or dim[3] !=1.
60     // Here we only support input tensor of [B, 1, X, 1] shape
61     assert(input_shapes[0].Dims(1) == 1 && input_shapes[0].Dims(3) == 1);
62
63     if (!_prepared)
64
65     {
66       prepare();
67     }
68
69     Tensor transformed_input[5];
70     Tensor transformed_output;
71
72     const int num_inputs = input_shapes.size();
73     std::vector<InputTensor<float>> inputs(num_inputs);
74     for (int i = 0; i < num_inputs; i++)
75     {
76       inputs[i].shape.ReplaceWith(input_shapes[i].DimensionsCount(), input_shapes[i].DimsData());
77       inputs[i].buffer = input_data[i];
78       copyFrom<float>(inputs[i], inputs[i].shape, &transformed_input[i]);
79     }
80
81     InputTensor<float> output;
82     output.shape.ReplaceWith(output_shape.DimensionsCount(), output_shape.DimsData());
83     output.buffer = output_data;
84     copyFrom<float>(output, output.shape, &transformed_output);
85
86     // TODO: support transpose if data_format is NCHW
87     // Here, Eigen use RowMajor kernel(NHWC)
88
89     typename TTypes<float, 4>::Tensor x(transformed_input[0].shaped<float, 4>());
90     typename TTypes<float, 4>::Tensor y(transformed_output.shaped<float, 4>());
91     typename TTypes<float, 1>::Tensor scale(transformed_input[1].shaped<float, 1>());
92     typename TTypes<float, 1>::Tensor offset(transformed_input[2].shaped<float, 1>());
93
94     const int depth = x.dimension(3);
95     const int size = x.size();
96     const int rest_size = size / depth;
97     Eigen::DSizes<Eigen::Index, 2> rest_by_depth(rest_size, depth);
98
99     Eigen::DSizes<Eigen::Index, 2> one_by_depth(1, depth);
100     Eigen::array<int, 1> reduce_dims({0});
101     Eigen::array<int, 2> bcast_spec({rest_size, 1});
102
103     auto x_rest_by_depth = x.reshape(rest_by_depth).template cast<float>();
104     const int rest_size_minus_one = (rest_size > 1) ? (rest_size - 1) : 1;
105     float rest_size_inv = static_cast<float>(1.0f / static_cast<float>(rest_size));
106     // This adjustment is for Bessel's correction
107     float rest_size_adjust =
108       static_cast<float>(rest_size) / static_cast<float>(rest_size_minus_one);
109
110     Eigen::Tensor<float, 1, Eigen::RowMajor> batch_mean(depth);
111     Eigen::Tensor<float, 1, Eigen::RowMajor> batch_variance(depth);
112
113     const Eigen::ThreadPoolDevice &d = *eigen_support::GetThreadPoolDevice();
114
115     batch_mean.device(d) = (x_rest_by_depth.sum(reduce_dims) * rest_size_inv);
116     auto x_centered = x_rest_by_depth - batch_mean.reshape(one_by_depth).broadcast(bcast_spec);
117
118     batch_variance.device(d) = x_centered.square().sum(reduce_dims) * rest_size_inv;
119     auto scaling_factor = ((batch_variance + param.epsilon).rsqrt() * scale)
120                             .eval()
121                             .reshape(one_by_depth)
122                             .broadcast(bcast_spec);
123     auto x_scaled = x_centered * scaling_factor;
124     auto x_shifted =
125       (x_scaled + offset.reshape(one_by_depth).broadcast(bcast_spec)).template cast<float>();
126
127     UNUSED_RELEASE(rest_size_adjust);
128
129     y.reshape(rest_by_depth).device(d) = x_shifted;
130
131     memcpy(output_data, y.data(), output_shape.FlatSize() * sizeof(float));
132   }
133
134   template <typename T>
135   void copyFrom(const InputTensor<T> &input, const Shape &shape, Tensor *output)
136   {
137     Tensor temp_tensor;
138     temp_tensor.shape.ReplaceWith(input.shape.DimensionsCount(), input.shape.DimsData());
139     temp_operand.emplace_back(std::make_unique<float[]>(input.shape.FlatSize()));
140     temp_tensor.buffer = temp_operand.back().get();
141     memcpy(temp_tensor.buffer, input.buffer, input.shape.FlatSize() * sizeof(float));
142
143     copyFrom(temp_tensor, shape, output);
144   }
145
146   void copyFrom(const Tensor &input, const Shape &shape, Tensor *output)
147   {
148     if (output->copyFrom(input, shape))
149       return;
150
151     throw std::runtime_error{"Einsum: Encountered error while reshaping a Tensor"};
152   }
153
154 private:
155   bool _prepared;
156   std::vector<std::unique_ptr<float[]>> temp_operand;
157 };
158
159 } // namespace cker
160 } // namespace nnfw
161
162 #endif // __NNFW_CKER_FUSEDBATCHNORM_H__