[Application] Setup meson file for libyolov2_loss_layer.so
[platform/core/ml/nntrainer.git] / Applications / YOLO / jni / yolo_v2_loss.cpp
1 // SPDX-License-Identifier: Apache-2.0
2 /**
3  * Copyright (C) 2023 Hyeonseok Lee <hs89.lee@samsung.com>
4  *
5  * @file   yolo_v2_loss.cpp
6  * @date   07 March 2023
7  * @brief  This file contains the yolo v2 loss layer
8  * @see    https://github.com/nnstreamer/nntrainer
9  * @author Hyeonseok Lee <hs89.lee@samsung.com>
10  * @bug    No known bugs except for NYI items
11  *
12  */
13
14 #include "yolo_v2_loss.h"
15 #include <nntrainer_log.h>
16
17 namespace custom {
18
/** @brief index used to select this layer's input (and its dimension) */
static constexpr size_t SINGLE_INOUT_IDX{0};
20
/**
 * @brief Slots of wt_idx, one per tensor requested in
 * YoloV2LossLayer::finalize()
 * @note Kept as an unscoped enum because the enumerators are used directly as
 * array indices (wt_idx[YoloV2LossParams::...]).
 */
enum YoloV2LossParams {
  // predictions split from the input and activated in forwarding()
  bbox_x_pred = 0,    /**< x prediction (sigmoid applied) */
  bbox_y_pred,        /**< y prediction (sigmoid applied) */
  bbox_w_pred,        /**< width prediction (exp applied) */
  bbox_h_pred,        /**< height prediction (exp applied) */
  confidence_pred,    /**< confidence prediction (sigmoid applied) */
  class_pred,         /**< class prediction (softmax applied) */
  bbox_w_pred_anchor, /**< sqrt(bbox_w_pred * anchors_w) */
  bbox_h_pred_anchor, /**< sqrt(bbox_h_pred * anchors_h) */

  // ground truth; zeroed in forwarding() before generate_ground_truth()
  bbox_x_gt,
  bbox_y_gt,
  bbox_w_gt,
  bbox_h_gt,
  confidence_gt,
  class_gt,

  // masks; forwarding() initialises them to 0 and 0.5 respectively
  bbox_class_mask,
  iou_mask,

  // named after the [out] arguments of calc_iou()
  bbox1_width,
  bbox1_height,
  is_xy_min_max,
  intersection_width,
  intersection_height,
  unions,
};
45
46 namespace props {
47 MaxObjectNumber::MaxObjectNumber(const unsigned &value) { set(value); }
48 ClassNumber::ClassNumber(const unsigned &value) { set(value); }
49 GridHeightNumber::GridHeightNumber(const unsigned &value) { set(value); }
50 GridWidthNumber::GridWidthNumber(const unsigned &value) { set(value); }
51 } // namespace props
52
53 /**
54  * @brief mse
55  *
56  * @param pred prediction
57  * @param ground_truth ground truth
58  * @return float loss
59  * @todo make loss behaves like acti_func
60  */
61 float mse(nntrainer::Tensor &pred, nntrainer::Tensor &ground_truth) {
62   nntrainer::Tensor residual;
63   pred.subtract(ground_truth, residual);
64
65   float l2norm = residual.l2norm();
66   l2norm *= l2norm / residual.size();
67
68   return l2norm;
69 }
70
71 /**
72  * @brief backwarding of mse
73  *
74  * @param pred prediction
75  * @param ground_truth ground truth
76  * @param outgoing_derivative outgoing derivative
77  */
78 void msePrime(nntrainer::Tensor &pred, nntrainer::Tensor &ground_truth,
79               nntrainer::Tensor &outgoing_derivative) {
80   pred.subtract(ground_truth, outgoing_derivative);
81   float divider = ((float)pred.size()) / 2;
82   if (outgoing_derivative.divide_i(divider) != ML_ERROR_NONE) {
83     throw std::runtime_error(
84       "[YoloV2LossLayer::calcDerivative] Error when calculating loss");
85   }
86 }
87
88 /**
89  * @brief calculate iou
90  *
91  * @param bbox1_x1 bbox1_x1
92  * @param bbox1_y1 bbox1_y1
93  * @param bbox1_w bbox1_w
94  * @param bbox1_h bbox1_h
95  * @param bbox2_x1 bbox2_x1
96  * @param bbox2_y1 bbox2_y1
97  * @param bbox2_w bbox2_w
98  * @param bbox2_h bbox2_h
99  * @param[out] bbox1_width bbox1 width
100  * @param[out] bbox1_height bbox1 height
101  * @param[out] is_xy_min_max For x1, y1 this value is 1 if x1 > x2, y1 > y2 and
102  * for x2, y2 this is value is 1 if x2 < x1, y2 < y1. else 0.
103  * @param[out] intersection_width intersection width
104  * @param[out] intersection_height intersection height
105  * @param[out] unions unions
106  * @return nntrainer::Tensor iou
107  */
108 nntrainer::Tensor
109 calc_iou(nntrainer::Tensor &bbox1_x1, nntrainer::Tensor &bbox1_y1,
110          nntrainer::Tensor &bbox1_w, nntrainer::Tensor &bbox1_h,
111          nntrainer::Tensor &bbox2_x1, nntrainer::Tensor &bbox2_y1,
112          nntrainer::Tensor &bbox2_w, nntrainer::Tensor &bbox2_h,
113          nntrainer::Tensor &bbox1_width, nntrainer::Tensor &bbox1_height,
114          nntrainer::Tensor &is_xy_min_max,
115          nntrainer::Tensor &intersection_width,
116          nntrainer::Tensor &intersection_height, nntrainer::Tensor &unions) {
117   nntrainer::Tensor bbox1_x2 = bbox1_x1.add(bbox1_w);
118   nntrainer::Tensor bbox1_y2 = bbox1_y1.add(bbox1_h);
119   nntrainer::Tensor bbox2_x2 = bbox2_x1.add(bbox2_w);
120   nntrainer::Tensor bbox2_y2 = bbox2_y1.add(bbox2_h);
121
122   bbox1_x2.subtract(bbox1_x1, bbox1_width);
123   bbox1_y2.subtract(bbox1_y1, bbox1_height);
124   nntrainer::Tensor bbox1 = bbox1_width.multiply(bbox1_height);
125
126   nntrainer::Tensor bbox2_width = bbox2_x2.subtract(bbox2_x1);
127   nntrainer::Tensor bbox2_height = bbox2_y2.subtract(bbox2_y1);
128   nntrainer::Tensor bbox2 = bbox2_width.multiply(bbox2_height);
129
130   auto min_func = [&](nntrainer::Tensor &bbox1_xy, nntrainer::Tensor &bbox2_xy,
131                       nntrainer::Tensor &intersection_xy) {
132     std::transform(bbox1_xy.getData(), bbox1_xy.getData() + bbox1_xy.size(),
133                    bbox2_xy.getData(), intersection_xy.getData(),
134                    [](float x1, float x2) { return std::min(x1, x2); });
135   };
136   auto max_func = [&](nntrainer::Tensor &bbox1_xy, nntrainer::Tensor &bbox2_xy,
137                       nntrainer::Tensor &intersection_xy) {
138     std::transform(bbox1_xy.getData(), bbox1_xy.getData() + bbox1_xy.size(),
139                    bbox2_xy.getData(), intersection_xy.getData(),
140                    [](float x1, float x2) { return std::max(x1, x2); });
141   };
142
143   nntrainer::Tensor intersection_x1(bbox1_x1.getDim());
144   nntrainer::Tensor intersection_x2(bbox1_x1.getDim());
145   nntrainer::Tensor intersection_y1(bbox1_y1.getDim());
146   nntrainer::Tensor intersection_y2(bbox1_y1.getDim());
147   max_func(bbox1_x1, bbox2_x1, intersection_x1);
148   min_func(bbox1_x2, bbox2_x2, intersection_x2);
149   max_func(bbox1_y1, bbox2_y1, intersection_y1);
150   min_func(bbox1_y2, bbox2_y2, intersection_y2);
151
152   auto is_min_max_func = [&](nntrainer::Tensor &xy,
153                              nntrainer::Tensor &intersection,
154                              nntrainer::Tensor &is_min_max) {
155     std::transform(xy.getData(), xy.getData() + xy.size(),
156                    intersection.getData(), is_min_max.getData(),
157                    [](float x, float m) {
158                      return nntrainer::absFloat(x - m) < 1e-4 ? 1.0 : 0.0;
159                    });
160   };
161
162   nntrainer::Tensor is_bbox1_x1_max(bbox1_x1.getDim());
163   nntrainer::Tensor is_bbox1_y1_max(bbox1_x1.getDim());
164   nntrainer::Tensor is_bbox1_x2_min(bbox1_x1.getDim());
165   nntrainer::Tensor is_bbox1_y2_min(bbox1_x1.getDim());
166   is_min_max_func(bbox1_x1, intersection_x1, is_bbox1_x1_max);
167   is_min_max_func(bbox1_y1, intersection_y1, is_bbox1_y1_max);
168   is_min_max_func(bbox1_x2, intersection_x2, is_bbox1_x2_min);
169   is_min_max_func(bbox1_y2, intersection_y2, is_bbox1_y2_min);
170
171   nntrainer::Tensor is_bbox_min_max = nntrainer::Tensor::cat(
172     {is_bbox1_x1_max, is_bbox1_y1_max, is_bbox1_x2_min, is_bbox1_y2_min}, 3);
173   is_xy_min_max.copyData(is_bbox_min_max);
174
175   intersection_x2.subtract(intersection_x1, intersection_width);
176
177   auto type_intersection_width = intersection_width.getDataType();
178   if (type_intersection_width == ml::train::TensorDim::DataType::FP32) {
179     intersection_width.apply_i<float>(nntrainer::ActiFunc::relu<float>);
180   } else if (type_intersection_width == ml::train::TensorDim::DataType::FP16) {
181 #ifdef ENABLE_FP16
182     intersection_width.apply_i<_FP16>(nntrainer::ActiFunc::relu<float>);
183 #else
184     throw std::runtime_error("Not supported data type");
185 #endif
186   }
187
188   intersection_y2.subtract(intersection_y1, intersection_height);
189
190   auto type_intersection_height = intersection_height.getDataType();
191   if (type_intersection_height == ml::train::TensorDim::DataType::FP32) {
192     intersection_height.apply_i<float>(nntrainer::ActiFunc::relu<float>);
193   } else if (type_intersection_height == ml::train::TensorDim::DataType::FP16) {
194 #ifdef ENABLE_FP16
195     intersection_height.apply_i<_FP16>(nntrainer::ActiFunc::relu<_FP16>);
196 #else
197     throw std::runtime_error("Not supported data type");
198 #endif
199   }
200
201   nntrainer::Tensor intersection =
202     intersection_width.multiply(intersection_height);
203   bbox1.add(bbox2, unions);
204   unions.subtract_i(intersection);
205
206   return intersection.divide(unions);
207 }
208
209 /**
210  * @brief calculate iou graident
211  * @details Let say bbox_pred as x, intersection as f(x), union as g(x) and iou
212  * as y. Then y = f(x)/g(x). Also g(x) = bbox1 + bbox2 - f(x). Partial
213  * derivative of y with respect to x will be (f'(x)g(x) - f(x)g'(x))/(g(x)^2).
214  * Partial derivative of g(x) with respect to x will be bbox1'(x) - f'(x).
215  * @param confidence_gt_grad incoming derivative for iou
216  * @param bbox1_width bbox1_width
217  * @param bbox1_height bbox1_height
218  * @param is_xy_min_max For x1, y1 this value is 1 if x1 > x2, y1 > y2 and for
219  * x2, y2 this is value is 1 if x2 < x1, y2 < y1. else 0.
220  * @param intersection_width intersection width
221  * @param intersection_height intersection height
222  * @param unions unions
223  * @return std::vector<nntrainer::Tensor> iou_grad
224  */
225 std::vector<nntrainer::Tensor> calc_iou_grad(
226   nntrainer::Tensor &confidence_gt_grad, nntrainer::Tensor &bbox1_width,
227   nntrainer::Tensor &bbox1_height, nntrainer::Tensor &is_xy_min_max,
228   nntrainer::Tensor &intersection_width, nntrainer::Tensor &intersection_height,
229   nntrainer::Tensor &unions) {
230   nntrainer::Tensor intersection =
231     intersection_width.multiply(intersection_height);
232
233   // 1. calculate intersection local gradient [f'(x)]
234   nntrainer::Tensor intersection_width_relu_prime;
235   nntrainer::Tensor intersection_height_relu_prime;
236   auto type_intersection_width = intersection_width.getDataType();
237   if (type_intersection_width == ml::train::TensorDim::DataType::FP32) {
238     intersection_width_relu_prime =
239       intersection_width.apply<float>(nntrainer::ActiFunc::reluPrime<float>);
240   } else if (type_intersection_width == ml::train::TensorDim::DataType::FP16) {
241 #ifdef ENABLE_FP16
242     intersection_height_relu_prime =
243       intersection_height.apply<_FP16>(nntrainer::ActiFunc::reluPrime<_FP16>);
244 #else
245     throw std::runtime_error("Not supported data type");
246 #endif
247   }
248
249   nntrainer::Tensor intersection_x2_local_grad =
250     intersection_width_relu_prime.multiply(intersection_height);
251   nntrainer::Tensor intersection_y2_local_grad =
252     intersection_height_relu_prime.multiply(intersection_width);
253   nntrainer::Tensor intersection_x1_local_grad =
254     intersection_x2_local_grad.multiply(-1.0);
255   nntrainer::Tensor intersection_y1_local_grad =
256     intersection_y2_local_grad.multiply(-1.0);
257
258   nntrainer::Tensor intersection_local_grad = nntrainer::Tensor::cat(
259     {intersection_x1_local_grad, intersection_y1_local_grad,
260      intersection_x2_local_grad, intersection_y2_local_grad},
261     3);
262   intersection_local_grad.multiply_i(is_xy_min_max);
263
264   // 2. calculate union local gradient [g'(x)]
265   nntrainer::Tensor bbox1_x1_grad = bbox1_height.multiply(-1.0);
266   nntrainer::Tensor bbox1_y1_grad = bbox1_width.multiply(-1.0);
267   nntrainer::Tensor bbox1_x2_grad = bbox1_height;
268   nntrainer::Tensor bbox1_y2_grad = bbox1_width;
269   nntrainer::Tensor bbox1_grad = nntrainer::Tensor::cat(
270     {bbox1_x1_grad, bbox1_y1_grad, bbox1_x2_grad, bbox1_y2_grad}, 3);
271
272   nntrainer::Tensor unions_local_grad =
273     bbox1_grad.subtract(intersection_local_grad);
274
275   // 3. calculate iou local gradient [(f'(x)g(x) - f(x)g'(x))/(g(x)^2)]
276   nntrainer::Tensor lhs = intersection_local_grad.multiply(unions);
277   nntrainer::Tensor rhs = unions_local_grad.multiply(intersection);
278   nntrainer::Tensor iou_grad = lhs.subtract(rhs);
279   iou_grad.divide_i(unions);
280   iou_grad.divide_i(unions);
281
282   // 3. multiply with incoming derivative
283   iou_grad.multiply_i(confidence_gt_grad);
284
285   auto splitted_iou_grad = iou_grad.split({1, 1, 1, 1}, 3);
286   std::vector<nntrainer::Tensor> ret = {
287     splitted_iou_grad[0].add(splitted_iou_grad[2]),
288     splitted_iou_grad[1].add(splitted_iou_grad[3]), splitted_iou_grad[2],
289     splitted_iou_grad[3]};
290   return ret;
291 }
292
293 YoloV2LossLayer::YoloV2LossLayer() :
294   anchors_w({1, 1, NUM_ANCHOR, 1}, anchors_w_buf),
295   anchors_h({1, 1, NUM_ANCHOR, 1}, anchors_h_buf),
296   sigmoid(nntrainer::ActivationType::ACT_SIGMOID, true),
297   softmax(nntrainer::ActivationType::ACT_SOFTMAX, true),
298   yolo_v2_loss_props(props::MaxObjectNumber(), props::ClassNumber(),
299                      props::GridHeightNumber(), props::GridWidthNumber()) {
300   anchors_ratio = anchors_w.divide(anchors_h);
301   wt_idx.fill(std::numeric_limits<unsigned>::max());
302 }
303
304 void YoloV2LossLayer::finalize(nntrainer::InitLayerContext &context) {
305   nntrainer::TensorDim input_dim =
306     context.getInputDimensions()[SINGLE_INOUT_IDX];
307   const unsigned int batch_size = input_dim.batch();
308   const unsigned int class_number =
309     std::get<props::ClassNumber>(yolo_v2_loss_props).get();
310   const unsigned int grid_height_number =
311     std::get<props::GridHeightNumber>(yolo_v2_loss_props).get();
312   const unsigned int grid_width_number =
313     std::get<props::GridWidthNumber>(yolo_v2_loss_props).get();
314   const unsigned int max_object_number =
315     std::get<props::MaxObjectNumber>(yolo_v2_loss_props).get();
316   nntrainer::TensorDim label_dim(batch_size, 1, max_object_number, 5);
317   context.setOutputDimensions({label_dim});
318
319   nntrainer::TensorDim bbox_x_pred_dim(
320     batch_size, grid_height_number * grid_width_number, NUM_ANCHOR, 1);
321   wt_idx[YoloV2LossParams::bbox_x_pred] = context.requestTensor(
322     bbox_x_pred_dim, "bbox_x_pred", nntrainer::Tensor::Initializer::NONE, true,
323     nntrainer::TensorLifespan::FORWARD_DERIV_LIFESPAN);
324
325   nntrainer::TensorDim bbox_y_pred_dim(
326     batch_size, grid_height_number * grid_width_number, NUM_ANCHOR, 1);
327   wt_idx[YoloV2LossParams::bbox_y_pred] = context.requestTensor(
328     bbox_y_pred_dim, "bbox_y_pred", nntrainer::Tensor::Initializer::NONE, true,
329     nntrainer::TensorLifespan::FORWARD_DERIV_LIFESPAN);
330
331   nntrainer::TensorDim bbox_w_pred_dim(
332     batch_size, grid_height_number * grid_width_number, NUM_ANCHOR, 1);
333   wt_idx[YoloV2LossParams::bbox_w_pred] = context.requestTensor(
334     bbox_w_pred_dim, "bbox_w_pred", nntrainer::Tensor::Initializer::NONE, true,
335     nntrainer::TensorLifespan::FORWARD_DERIV_LIFESPAN);
336
337   nntrainer::TensorDim bbox_h_pred_dim(
338     batch_size, grid_height_number * grid_width_number, NUM_ANCHOR, 1);
339   wt_idx[YoloV2LossParams::bbox_h_pred] = context.requestTensor(
340     bbox_h_pred_dim, "bbox_h_pred", nntrainer::Tensor::Initializer::NONE, true,
341     nntrainer::TensorLifespan::FORWARD_DERIV_LIFESPAN);
342
343   nntrainer::TensorDim confidence_pred_dim(
344     batch_size, grid_height_number * grid_width_number, NUM_ANCHOR, 1);
345   wt_idx[YoloV2LossParams::confidence_pred] =
346     context.requestTensor(confidence_pred_dim, "confidence_pred",
347                           nntrainer::Tensor::Initializer::NONE, true,
348                           nntrainer::TensorLifespan::FORWARD_DERIV_LIFESPAN);
349
350   nntrainer::TensorDim class_pred_dim(batch_size,
351                                       grid_height_number * grid_width_number,
352                                       NUM_ANCHOR, class_number);
353   wt_idx[YoloV2LossParams::class_pred] = context.requestTensor(
354     class_pred_dim, "class_pred", nntrainer::Tensor::Initializer::NONE, true,
355     nntrainer::TensorLifespan::FORWARD_DERIV_LIFESPAN);
356
357   nntrainer::TensorDim bbox_w_pred_anchor_dim(
358     batch_size, grid_height_number * grid_width_number, NUM_ANCHOR, 1);
359   wt_idx[YoloV2LossParams::bbox_w_pred_anchor] =
360     context.requestTensor(bbox_w_pred_anchor_dim, "bbox_w_pred_anchor",
361                           nntrainer::Tensor::Initializer::NONE, false,
362                           nntrainer::TensorLifespan::FORWARD_DERIV_LIFESPAN);
363
364   nntrainer::TensorDim bbox_h_pred_anchor_dim(
365     batch_size, grid_height_number * grid_width_number, NUM_ANCHOR, 1);
366   wt_idx[YoloV2LossParams::bbox_h_pred_anchor] =
367     context.requestTensor(bbox_h_pred_anchor_dim, "bbox_h_pred_anchor",
368                           nntrainer::Tensor::Initializer::NONE, false,
369                           nntrainer::TensorLifespan::FORWARD_DERIV_LIFESPAN);
370
371   nntrainer::TensorDim bbox_x_gt_dim(
372     batch_size, grid_height_number * grid_width_number, NUM_ANCHOR, 1);
373   wt_idx[YoloV2LossParams::bbox_x_gt] = context.requestTensor(
374     bbox_x_gt_dim, "bbox_x_gt", nntrainer::Tensor::Initializer::NONE, false,
375     nntrainer::TensorLifespan::FORWARD_DERIV_LIFESPAN);
376
377   nntrainer::TensorDim bbox_y_gt_dim(
378     batch_size, grid_height_number * grid_width_number, NUM_ANCHOR, 1);
379   wt_idx[YoloV2LossParams::bbox_y_gt] = context.requestTensor(
380     bbox_y_gt_dim, "bbox_y_gt", nntrainer::Tensor::Initializer::NONE, false,
381     nntrainer::TensorLifespan::FORWARD_DERIV_LIFESPAN);
382
383   nntrainer::TensorDim bbox_w_gt_dim(
384     batch_size, grid_height_number * grid_width_number, NUM_ANCHOR, 1);
385   wt_idx[YoloV2LossParams::bbox_w_gt] = context.requestTensor(
386     bbox_w_gt_dim, "bbox_w_gt", nntrainer::Tensor::Initializer::NONE, false,
387     nntrainer::TensorLifespan::FORWARD_DERIV_LIFESPAN);
388
389   nntrainer::TensorDim bbox_h_gt_dim(
390     batch_size, grid_height_number * grid_width_number, NUM_ANCHOR, 1);
391   wt_idx[YoloV2LossParams::bbox_h_gt] = context.requestTensor(
392     bbox_h_gt_dim, "bbox_h_gt", nntrainer::Tensor::Initializer::NONE, false,
393     nntrainer::TensorLifespan::FORWARD_DERIV_LIFESPAN);
394
395   nntrainer::TensorDim confidence_gt_dim(
396     batch_size, grid_height_number * grid_width_number, NUM_ANCHOR, 1);
397   wt_idx[YoloV2LossParams::confidence_gt] = context.requestTensor(
398     confidence_gt_dim, "confidence_gt", nntrainer::Tensor::Initializer::NONE,
399     false, nntrainer::TensorLifespan::FORWARD_DERIV_LIFESPAN);
400
401   nntrainer::TensorDim class_gt_dim(batch_size,
402                                     grid_height_number * grid_width_number,
403                                     NUM_ANCHOR, class_number);
404   wt_idx[YoloV2LossParams::class_gt] = context.requestTensor(
405     class_gt_dim, "class_gt", nntrainer::Tensor::Initializer::NONE, false,
406     nntrainer::TensorLifespan::FORWARD_DERIV_LIFESPAN);
407
408   nntrainer::TensorDim bbox_class_mask_dim(
409     batch_size, grid_height_number * grid_width_number, NUM_ANCHOR, 1);
410   wt_idx[YoloV2LossParams::bbox_class_mask] =
411     context.requestTensor(bbox_class_mask_dim, "bbox_class_mask",
412                           nntrainer::Tensor::Initializer::NONE, false,
413                           nntrainer::TensorLifespan::FORWARD_DERIV_LIFESPAN);
414
415   nntrainer::TensorDim iou_mask_dim(
416     batch_size, grid_height_number * grid_width_number, NUM_ANCHOR, 1);
417   wt_idx[YoloV2LossParams::iou_mask] = context.requestTensor(
418     iou_mask_dim, "iou_mask", nntrainer::Tensor::Initializer::NONE, false,
419     nntrainer::TensorLifespan::FORWARD_DERIV_LIFESPAN);
420
421   nntrainer::TensorDim bbox1_width_dim(
422     batch_size, grid_height_number * grid_width_number, NUM_ANCHOR, 1);
423   wt_idx[YoloV2LossParams::bbox1_width] = context.requestTensor(
424     bbox1_width_dim, "bbox1_width", nntrainer::Tensor::Initializer::NONE, false,
425     nntrainer::TensorLifespan::FORWARD_DERIV_LIFESPAN);
426
427   nntrainer::TensorDim bbox1_height_dim(
428     batch_size, grid_height_number * grid_width_number, NUM_ANCHOR, 1);
429   wt_idx[YoloV2LossParams::bbox1_height] = context.requestTensor(
430     bbox1_height_dim, "bbox1_height", nntrainer::Tensor::Initializer::NONE,
431     false, nntrainer::TensorLifespan::FORWARD_DERIV_LIFESPAN);
432
433   nntrainer::TensorDim is_xy_min_max_dim(
434     batch_size, grid_height_number * grid_width_number, NUM_ANCHOR, 4);
435   wt_idx[YoloV2LossParams::is_xy_min_max] = context.requestTensor(
436     is_xy_min_max_dim, "is_xy_min_max", nntrainer::Tensor::Initializer::NONE,
437     false, nntrainer::TensorLifespan::FORWARD_DERIV_LIFESPAN);
438
439   nntrainer::TensorDim intersection_width_dim(
440     batch_size, grid_height_number * grid_width_number, NUM_ANCHOR, 1);
441   wt_idx[YoloV2LossParams::intersection_width] =
442     context.requestTensor(intersection_width_dim, "intersection_width",
443                           nntrainer::Tensor::Initializer::NONE, false,
444                           nntrainer::TensorLifespan::FORWARD_DERIV_LIFESPAN);
445
446   nntrainer::TensorDim intersection_height_dim(
447     batch_size, grid_height_number * grid_width_number, NUM_ANCHOR, 1);
448   wt_idx[YoloV2LossParams::intersection_height] =
449     context.requestTensor(intersection_height_dim, "intersection_height",
450                           nntrainer::Tensor::Initializer::NONE, false,
451                           nntrainer::TensorLifespan::FORWARD_DERIV_LIFESPAN);
452
453   nntrainer::TensorDim unions_dim(
454     batch_size, grid_height_number * grid_width_number, NUM_ANCHOR, 1);
455   wt_idx[YoloV2LossParams::unions] = context.requestTensor(
456     unions_dim, "unions", nntrainer::Tensor::Initializer::NONE, false,
457     nntrainer::TensorLifespan::FORWARD_DERIV_LIFESPAN);
458 }
459
460 void YoloV2LossLayer::forwarding(nntrainer::RunLayerContext &context,
461                                  bool training) {
462   const unsigned int max_object_number =
463     std::get<props::MaxObjectNumber>(yolo_v2_loss_props).get();
464
465   nntrainer::Tensor &input = context.getInput(SINGLE_INOUT_IDX);
466
467   std::vector<nntrainer::Tensor> splited_input =
468     input.split({1, 1, 1, 1, 1, max_object_number}, 3);
469   nntrainer::Tensor bbox_x_pred_ = splited_input[0];
470   nntrainer::Tensor bbox_y_pred_ = splited_input[1];
471   nntrainer::Tensor bbox_w_pred_ = splited_input[2];
472   nntrainer::Tensor bbox_h_pred_ = splited_input[3];
473   nntrainer::Tensor confidence_pred_ = splited_input[4];
474   nntrainer::Tensor class_pred_ = splited_input[5];
475
476   nntrainer::Tensor &bbox_x_pred =
477     context.getTensor(wt_idx[YoloV2LossParams::bbox_x_pred]);
478   nntrainer::Tensor &bbox_y_pred =
479     context.getTensor(wt_idx[YoloV2LossParams::bbox_y_pred]);
480   nntrainer::Tensor &bbox_w_pred =
481     context.getTensor(wt_idx[YoloV2LossParams::bbox_w_pred]);
482   nntrainer::Tensor &bbox_h_pred =
483     context.getTensor(wt_idx[YoloV2LossParams::bbox_h_pred]);
484
485   nntrainer::Tensor &confidence_pred =
486     context.getTensor(wt_idx[YoloV2LossParams::confidence_pred]);
487   nntrainer::Tensor &class_pred =
488     context.getTensor(wt_idx[YoloV2LossParams::class_pred]);
489
490   nntrainer::Tensor &bbox_w_pred_anchor =
491     context.getTensor(wt_idx[YoloV2LossParams::bbox_w_pred_anchor]);
492   nntrainer::Tensor &bbox_h_pred_anchor =
493     context.getTensor(wt_idx[YoloV2LossParams::bbox_h_pred_anchor]);
494
495   bbox_x_pred.copyData(bbox_x_pred_);
496   bbox_y_pred.copyData(bbox_y_pred_);
497   bbox_w_pred.copyData(bbox_w_pred_);
498   bbox_h_pred.copyData(bbox_h_pred_);
499
500   confidence_pred.copyData(confidence_pred_);
501   class_pred.copyData(class_pred_);
502
503   nntrainer::Tensor &bbox_x_gt =
504     context.getTensor(wt_idx[YoloV2LossParams::bbox_x_gt]);
505   nntrainer::Tensor &bbox_y_gt =
506     context.getTensor(wt_idx[YoloV2LossParams::bbox_y_gt]);
507   nntrainer::Tensor &bbox_w_gt =
508     context.getTensor(wt_idx[YoloV2LossParams::bbox_w_gt]);
509   nntrainer::Tensor &bbox_h_gt =
510     context.getTensor(wt_idx[YoloV2LossParams::bbox_h_gt]);
511
512   nntrainer::Tensor &confidence_gt =
513     context.getTensor(wt_idx[YoloV2LossParams::confidence_gt]);
514   nntrainer::Tensor &class_gt =
515     context.getTensor(wt_idx[YoloV2LossParams::class_gt]);
516
517   nntrainer::Tensor &bbox_class_mask =
518     context.getTensor(wt_idx[YoloV2LossParams::bbox_class_mask]);
519   nntrainer::Tensor &iou_mask =
520     context.getTensor(wt_idx[YoloV2LossParams::iou_mask]);
521
522   bbox_x_gt.setValue(0);
523   bbox_y_gt.setValue(0);
524   bbox_w_gt.setValue(0);
525   bbox_h_gt.setValue(0);
526
527   confidence_gt.setValue(0);
528   class_gt.setValue(0);
529
530   // init mask
531   bbox_class_mask.setValue(0);
532   iou_mask.setValue(0.5);
533
534   // activate pred
535   sigmoid.run_fn(bbox_x_pred, bbox_x_pred);
536   sigmoid.run_fn(bbox_y_pred, bbox_y_pred);
537
538   auto type_bbox_w_pred = bbox_w_pred.getDataType();
539   if (type_bbox_w_pred == ml::train::TensorDim::DataType::FP32) {
540     bbox_w_pred.apply_i<float>(nntrainer::exp_util<float>);
541   } else if (type_bbox_w_pred == ml::train::TensorDim::DataType::FP16) {
542 #ifdef ENABLE_FP16
543     bbox_w_pred.apply_i<_FP16>(nntrainer::exp_util<_FP16>);
544 #else
545     throw std::runtime_error("Not supported data type");
546 #endif
547   }
548
549   auto type_bbox_h_pred = bbox_h_pred.getDataType();
550   if (type_bbox_h_pred == ml::train::TensorDim::DataType::FP32) {
551     bbox_h_pred.apply_i<float>(nntrainer::exp_util<float>);
552   } else if (type_bbox_h_pred == ml::train::TensorDim::DataType::FP16) {
553 #ifdef ENABLE_FP16
554     bbox_h_pred.apply_i<_FP16>(nntrainer::exp_util<_FP16>);
555 #else
556     throw std::runtime_error("Not supported data type");
557 #endif
558   }
559
560   sigmoid.run_fn(confidence_pred, confidence_pred);
561   softmax.run_fn(class_pred, class_pred);
562
563   bbox_w_pred_anchor.copyData(bbox_w_pred);
564   bbox_h_pred_anchor.copyData(bbox_h_pred);
565
566   // apply anchors to bounding box
567   bbox_w_pred_anchor.multiply_i(anchors_w);
568   auto type_bbox_w_pred_anchor = bbox_w_pred_anchor.getDataType();
569   if (type_bbox_w_pred_anchor == ml::train::TensorDim::DataType::FP32) {
570     bbox_w_pred_anchor.apply_i<float>(nntrainer::sqrtFloat);
571   } else if (type_bbox_w_pred_anchor == ml::train::TensorDim::DataType::FP16) {
572 #ifdef ENABLE_FP16
573     bbox_w_pred_anchor.apply_i<_FP16>(nntrainer::sqrtFloat);
574 #else
575     throw std::runtime_error("Not supported data type");
576 #endif
577   }
578
579   bbox_h_pred_anchor.multiply_i(anchors_h);
580   auto type_bbox_h_pred_anchor = bbox_h_pred_anchor.getDataType();
581   if (type_bbox_h_pred_anchor == ml::train::TensorDim::DataType::FP32) {
582     bbox_h_pred_anchor.apply_i<float>(nntrainer::sqrtFloat);
583   } else if (type_bbox_h_pred_anchor == ml::train::TensorDim::DataType::FP16) {
584 #ifdef ENABLE_FP16
585     bbox_h_pred_anchor.apply_i<_FP16>(nntrainer::sqrtFloat);
586 #else
587     throw std::runtime_error("Not supported data type");
588 #endif
589   }
590
591   generate_ground_truth(context);
592
593   nntrainer::Tensor bbox_pred = nntrainer::Tensor::cat(
594     {bbox_x_pred, bbox_y_pred, bbox_w_pred_anchor, bbox_h_pred_anchor}, 3);
595   nntrainer::Tensor masked_bbox_pred = bbox_pred.multiply(bbox_class_mask);
596   nntrainer::Tensor masked_confidence_pred = confidence_pred.multiply(iou_mask);
597   nntrainer::Tensor masked_class_pred = class_pred.multiply(bbox_class_mask);
598
599   nntrainer::Tensor bbox_gt =
600     nntrainer::Tensor::cat({bbox_x_gt, bbox_y_gt, bbox_w_gt, bbox_h_gt}, 3);
601   nntrainer::Tensor masked_bbox_gt = bbox_gt.multiply(bbox_class_mask);
602   nntrainer::Tensor masked_confidence_gt = confidence_gt.multiply(iou_mask);
603   nntrainer::Tensor masked_class_gt = class_gt.multiply(bbox_class_mask);
604
605   float bbox_loss = mse(masked_bbox_pred, masked_bbox_gt);
606   float confidence_loss = mse(masked_confidence_pred, masked_confidence_gt);
607   float class_loss = mse(masked_class_pred, masked_class_gt);
608
609   float loss = 5 * bbox_loss + confidence_loss + class_loss;
610   ml_logd("Current iteration loss: %f", loss);
611 }
612
613 void YoloV2LossLayer::calcDerivative(nntrainer::RunLayerContext &context) {
614   nntrainer::Tensor &bbox_x_pred =
615     context.getTensor(wt_idx[YoloV2LossParams::bbox_x_pred]);
616   nntrainer::Tensor &bbox_x_pred_grad =
617     context.getTensorGrad(wt_idx[YoloV2LossParams::bbox_x_pred]);
618   nntrainer::Tensor &bbox_y_pred =
619     context.getTensor(wt_idx[YoloV2LossParams::bbox_y_pred]);
620   nntrainer::Tensor &bbox_y_pred_grad =
621     context.getTensorGrad(wt_idx[YoloV2LossParams::bbox_y_pred]);
622   nntrainer::Tensor &bbox_w_pred =
623     context.getTensor(wt_idx[YoloV2LossParams::bbox_w_pred]);
624   nntrainer::Tensor &bbox_w_pred_grad =
625     context.getTensorGrad(wt_idx[YoloV2LossParams::bbox_w_pred]);
626   nntrainer::Tensor &bbox_h_pred =
627     context.getTensor(wt_idx[YoloV2LossParams::bbox_h_pred]);
628   nntrainer::Tensor &bbox_h_pred_grad =
629     context.getTensorGrad(wt_idx[YoloV2LossParams::bbox_h_pred]);
630
631   nntrainer::Tensor &confidence_pred =
632     context.getTensor(wt_idx[YoloV2LossParams::confidence_pred]);
633   nntrainer::Tensor &confidence_pred_grad =
634     context.getTensorGrad(wt_idx[YoloV2LossParams::confidence_pred]);
635   nntrainer::Tensor &class_pred =
636     context.getTensor(wt_idx[YoloV2LossParams::class_pred]);
637   nntrainer::Tensor &class_pred_grad =
638     context.getTensorGrad(wt_idx[YoloV2LossParams::class_pred]);
639
640   nntrainer::Tensor &bbox_w_pred_anchor =
641     context.getTensor(wt_idx[YoloV2LossParams::bbox_w_pred_anchor]);
642   nntrainer::Tensor &bbox_h_pred_anchor =
643     context.getTensor(wt_idx[YoloV2LossParams::bbox_h_pred_anchor]);
644
645   nntrainer::Tensor &bbox_x_gt =
646     context.getTensor(wt_idx[YoloV2LossParams::bbox_x_gt]);
647   nntrainer::Tensor &bbox_y_gt =
648     context.getTensor(wt_idx[YoloV2LossParams::bbox_y_gt]);
649   nntrainer::Tensor &bbox_w_gt =
650     context.getTensor(wt_idx[YoloV2LossParams::bbox_w_gt]);
651   nntrainer::Tensor &bbox_h_gt =
652     context.getTensor(wt_idx[YoloV2LossParams::bbox_h_gt]);
653
654   nntrainer::Tensor &confidence_gt =
655     context.getTensor(wt_idx[YoloV2LossParams::confidence_gt]);
656   nntrainer::Tensor &class_gt =
657     context.getTensor(wt_idx[YoloV2LossParams::class_gt]);
658
659   nntrainer::Tensor &bbox_class_mask =
660     context.getTensor(wt_idx[YoloV2LossParams::bbox_class_mask]);
661   nntrainer::Tensor &iou_mask =
662     context.getTensor(wt_idx[YoloV2LossParams::iou_mask]);
663
664   nntrainer::Tensor &bbox1_width =
665     context.getTensor(wt_idx[YoloV2LossParams::bbox1_width]);
666   nntrainer::Tensor &bbox1_height =
667     context.getTensor(wt_idx[YoloV2LossParams::bbox1_height]);
668   nntrainer::Tensor &is_xy_min_max =
669     context.getTensor(wt_idx[YoloV2LossParams::is_xy_min_max]);
670   nntrainer::Tensor &intersection_width =
671     context.getTensor(wt_idx[YoloV2LossParams::intersection_width]);
672   nntrainer::Tensor &intersection_height =
673     context.getTensor(wt_idx[YoloV2LossParams::intersection_height]);
674   nntrainer::Tensor &unions =
675     context.getTensor(wt_idx[YoloV2LossParams::unions]);
676
677   nntrainer::Tensor bbox_pred = nntrainer::Tensor::cat(
678     {bbox_x_pred, bbox_y_pred, bbox_w_pred_anchor, bbox_h_pred_anchor}, 3);
679   nntrainer::Tensor masked_bbox_pred = bbox_pred.multiply(bbox_class_mask);
680   nntrainer::Tensor masked_confidence_pred = confidence_pred.multiply(iou_mask);
681   nntrainer::Tensor masked_class_pred = class_pred.multiply(bbox_class_mask);
682
683   nntrainer::Tensor bbox_gt =
684     nntrainer::Tensor::cat({bbox_x_gt, bbox_y_gt, bbox_w_gt, bbox_h_gt}, 3);
685   nntrainer::Tensor masked_bbox_gt = bbox_gt.multiply(bbox_class_mask);
686   nntrainer::Tensor masked_confidence_gt = confidence_gt.multiply(iou_mask);
687   nntrainer::Tensor masked_class_gt = class_gt.multiply(bbox_class_mask);
688
689   nntrainer::Tensor masked_bbox_pred_grad;
690   nntrainer::Tensor masked_confidence_pred_grad;
691   nntrainer::Tensor masked_confidence_gt_grad;
692   nntrainer::Tensor masked_class_pred_grad;
693
694   nntrainer::Tensor confidence_gt_grad;
695
696   msePrime(masked_bbox_pred, masked_bbox_gt, masked_bbox_pred_grad);
697   msePrime(masked_confidence_pred, masked_confidence_gt,
698            masked_confidence_pred_grad);
699   msePrime(masked_confidence_gt, masked_confidence_pred,
700            masked_confidence_gt_grad);
701   msePrime(masked_class_pred, masked_class_gt, masked_class_pred_grad);
702
703   masked_bbox_pred_grad.multiply_i(5);
704
705   nntrainer::Tensor bbox_pred_grad;
706
707   masked_bbox_pred_grad.multiply(bbox_class_mask, bbox_pred_grad);
708   masked_confidence_pred_grad.multiply(iou_mask, confidence_pred_grad);
709   masked_confidence_gt_grad.multiply(iou_mask, confidence_gt_grad);
710   masked_class_pred_grad.multiply(bbox_class_mask, class_pred_grad);
711
712   std::vector<nntrainer::Tensor> splitted_bbox_pred_grad =
713     bbox_pred_grad.split({1, 1, 1, 1}, 3);
714   bbox_x_pred_grad.copyData(splitted_bbox_pred_grad[0]);
715   bbox_y_pred_grad.copyData(splitted_bbox_pred_grad[1]);
716   bbox_w_pred_grad.copyData(splitted_bbox_pred_grad[2]);
717   bbox_h_pred_grad.copyData(splitted_bbox_pred_grad[3]);
718
719   // std::vector<nntrainer::Tensor> bbox_pred_iou_grad =
720   //   calc_iou_grad(confidence_gt_grad, bbox1_width, bbox1_height,
721   //   is_xy_min_max,
722   //                 intersection_width, intersection_height, unions);
723   // bbox_x_pred_grad.add_i(bbox_pred_iou_grad[0]);
724   // bbox_y_pred_grad.add_i(bbox_pred_iou_grad[1]);
725   // bbox_w_pred_grad.add_i(bbox_pred_iou_grad[2]);
726   // bbox_h_pred_grad.add_i(bbox_pred_iou_grad[3]);
727
728   /**
729    * @brief calculate gradient for applying anchors to bounding box
730    * @details Let say bbox_pred as x, anchor as c indicated that anchor is
731    * constant for bbox_pred and bbox_pred_anchor as y. Then we can denote y =
732    * sqrt(cx). Partial derivative of y with respect to x will be
733    * sqrt(c)/(2*sqrt(x)) which is equivalent to sqrt(cx)/(2x) and we can replace
734    * sqrt(cx) with y.
735    * @note divide by bbox_pred(x) will not be executed because bbox_pred_grad
736    * will be multiply by bbox_pred(x) soon after.
737    */
738   bbox_w_pred_grad.multiply_i(bbox_w_pred_anchor);
739   bbox_h_pred_grad.multiply_i(bbox_h_pred_anchor);
740   /** intended comment */
741   // bbox_w_pred_grad.divide_i(bbox_w_pred);
742   // bbox_h_pred_grad.divide_i(bbox_h_pred);
743   bbox_w_pred_grad.divide_i(2);
744   bbox_h_pred_grad.divide_i(2);
745
746   sigmoid.run_prime_fn(bbox_x_pred, bbox_x_pred, bbox_x_pred_grad,
747                        bbox_x_pred_grad);
748   sigmoid.run_prime_fn(bbox_y_pred, bbox_y_pred, bbox_y_pred_grad,
749                        bbox_y_pred_grad);
750   /** intended comment */
751   // bbox_w_pred_grad.multiply_i(bbox_w_pred);
752   // bbox_h_pred_grad.multiply_i(bbox_h_pred);
753   sigmoid.run_prime_fn(confidence_pred, confidence_pred, confidence_pred_grad,
754                        confidence_pred_grad);
755   softmax.run_prime_fn(class_pred, class_pred, class_pred_grad,
756                        class_pred_grad);
757
758   nntrainer::Tensor outgoing_derivative_ = nntrainer::Tensor::cat(
759     {bbox_x_pred_grad, bbox_y_pred_grad, bbox_w_pred_grad, bbox_h_pred_grad,
760      confidence_pred_grad, class_pred_grad},
761     3);
762   nntrainer::Tensor &outgoing_derivative =
763     context.getOutgoingDerivative(SINGLE_INOUT_IDX);
764   outgoing_derivative.copyData(outgoing_derivative_);
765 }
766
767 void YoloV2LossLayer::exportTo(nntrainer::Exporter &exporter,
768                                const ml::train::ExportMethods &method) const {
769   exporter.saveResult(yolo_v2_loss_props, method, this);
770 }
771
772 void YoloV2LossLayer::setProperty(const std::vector<std::string> &values) {
773   auto remain_props = loadProperties(values, yolo_v2_loss_props);
774   NNTR_THROW_IF(!remain_props.empty(), std::invalid_argument)
775     << "[YoloV2LossLayer] Unknown Layer Properties count " +
776          std::to_string(values.size());
777 }
778
779 void YoloV2LossLayer::setBatch(nntrainer::RunLayerContext &context,
780                                unsigned int batch) {
781   context.updateTensor(wt_idx[YoloV2LossParams::bbox_x_pred], batch);
782   context.updateTensor(wt_idx[YoloV2LossParams::bbox_y_pred], batch);
783   context.updateTensor(wt_idx[YoloV2LossParams::bbox_w_pred], batch);
784   context.updateTensor(wt_idx[YoloV2LossParams::bbox_h_pred], batch);
785   context.updateTensor(wt_idx[YoloV2LossParams::confidence_pred], batch);
786   context.updateTensor(wt_idx[YoloV2LossParams::class_pred], batch);
787   context.updateTensor(wt_idx[YoloV2LossParams::bbox_w_pred_anchor], batch);
788   context.updateTensor(wt_idx[YoloV2LossParams::bbox_h_pred_anchor], batch);
789
790   context.updateTensor(wt_idx[YoloV2LossParams::bbox_x_gt], batch);
791   context.updateTensor(wt_idx[YoloV2LossParams::bbox_y_gt], batch);
792   context.updateTensor(wt_idx[YoloV2LossParams::bbox_w_gt], batch);
793   context.updateTensor(wt_idx[YoloV2LossParams::bbox_h_gt], batch);
794   context.updateTensor(wt_idx[YoloV2LossParams::confidence_gt], batch);
795   context.updateTensor(wt_idx[YoloV2LossParams::class_gt], batch);
796   context.updateTensor(wt_idx[YoloV2LossParams::bbox_class_mask], batch);
797   context.updateTensor(wt_idx[YoloV2LossParams::iou_mask], batch);
798
799   context.updateTensor(wt_idx[YoloV2LossParams::bbox1_width], batch);
800   context.updateTensor(wt_idx[YoloV2LossParams::bbox1_height], batch);
801   context.updateTensor(wt_idx[YoloV2LossParams::is_xy_min_max], batch);
802   context.updateTensor(wt_idx[YoloV2LossParams::intersection_width], batch);
803   context.updateTensor(wt_idx[YoloV2LossParams::intersection_height], batch);
804   context.updateTensor(wt_idx[YoloV2LossParams::unions], batch);
805 }
806
807 unsigned int YoloV2LossLayer::find_responsible_anchors(float bbox_ratio) {
808   nntrainer::Tensor similarity = anchors_ratio.subtract(bbox_ratio);
809   auto data_type = similarity.getDataType();
810   if (data_type == ml::train::TensorDim::DataType::FP32) {
811     similarity.apply_i<float>(nntrainer::absFloat);
812   } else if (data_type == ml::train::TensorDim::DataType::FP16) {
813 #ifdef ENABLE_FP16
814     similarity.apply_i<_FP16>(nntrainer::absFloat);
815 #else
816     throw std::runtime_error("Not supported data type");
817 #endif
818   }
819   auto data = similarity.getData();
820
821   auto min_iter = std::min_element(data, data + NUM_ANCHOR);
822   return std::distance(data, min_iter);
823 }
824
825 void YoloV2LossLayer::generate_ground_truth(
826   nntrainer::RunLayerContext &context) {
827   const unsigned int max_object_number =
828     std::get<props::MaxObjectNumber>(yolo_v2_loss_props).get();
829   const unsigned int grid_height_number =
830     std::get<props::GridHeightNumber>(yolo_v2_loss_props).get();
831   const unsigned int grid_width_number =
832     std::get<props::GridWidthNumber>(yolo_v2_loss_props).get();
833
834   nntrainer::Tensor &label = context.getLabel(SINGLE_INOUT_IDX);
835
836   nntrainer::Tensor &bbox_x_pred =
837     context.getTensor(wt_idx[YoloV2LossParams::bbox_x_pred]);
838   nntrainer::Tensor &bbox_y_pred =
839     context.getTensor(wt_idx[YoloV2LossParams::bbox_y_pred]);
840   nntrainer::Tensor &bbox_w_pred_anchor =
841     context.getTensor(wt_idx[YoloV2LossParams::bbox_w_pred_anchor]);
842   nntrainer::Tensor &bbox_h_pred_anchor =
843     context.getTensor(wt_idx[YoloV2LossParams::bbox_h_pred_anchor]);
844
845   nntrainer::Tensor &bbox_x_gt =
846     context.getTensor(wt_idx[YoloV2LossParams::bbox_x_gt]);
847   nntrainer::Tensor &bbox_y_gt =
848     context.getTensor(wt_idx[YoloV2LossParams::bbox_y_gt]);
849   nntrainer::Tensor &bbox_w_gt =
850     context.getTensor(wt_idx[YoloV2LossParams::bbox_w_gt]);
851   nntrainer::Tensor &bbox_h_gt =
852     context.getTensor(wt_idx[YoloV2LossParams::bbox_h_gt]);
853
854   nntrainer::Tensor &confidence_gt =
855     context.getTensor(wt_idx[YoloV2LossParams::confidence_gt]);
856   nntrainer::Tensor &class_gt =
857     context.getTensor(wt_idx[YoloV2LossParams::class_gt]);
858
859   nntrainer::Tensor &bbox_class_mask =
860     context.getTensor(wt_idx[YoloV2LossParams::bbox_class_mask]);
861   nntrainer::Tensor &iou_mask =
862     context.getTensor(wt_idx[YoloV2LossParams::iou_mask]);
863
864   nntrainer::Tensor &bbox1_width =
865     context.getTensor(wt_idx[YoloV2LossParams::bbox1_width]);
866   nntrainer::Tensor &bbox1_height =
867     context.getTensor(wt_idx[YoloV2LossParams::bbox1_height]);
868   nntrainer::Tensor &is_xy_min_max =
869     context.getTensor(wt_idx[YoloV2LossParams::is_xy_min_max]);
870   nntrainer::Tensor &intersection_width =
871     context.getTensor(wt_idx[YoloV2LossParams::intersection_width]);
872   nntrainer::Tensor &intersection_height =
873     context.getTensor(wt_idx[YoloV2LossParams::intersection_height]);
874   nntrainer::Tensor &unions =
875     context.getTensor(wt_idx[YoloV2LossParams::unions]);
876
877   const unsigned int batch_size = bbox_x_pred.getDim().batch();
878
879   std::vector<nntrainer::Tensor> splited_label =
880     label.split({1, 1, 1, 1, 1}, 3);
881   nntrainer::Tensor bbox_x_label = splited_label[0];
882   nntrainer::Tensor bbox_y_label = splited_label[1];
883   nntrainer::Tensor bbox_w_label = splited_label[2];
884   nntrainer::Tensor bbox_h_label = splited_label[3];
885   nntrainer::Tensor class_label = splited_label[4];
886
887   bbox_x_label.multiply_i(grid_width_number);
888   bbox_y_label.multiply_i(grid_height_number);
889
890   for (unsigned int batch = 0; batch < batch_size; ++batch) {
891     for (unsigned int object = 0; object < max_object_number; ++object) {
892       if (!bbox_w_label.getValue(batch, 0, object, 0) &&
893           !bbox_h_label.getValue(batch, 0, object, 0)) {
894         break;
895       }
896       unsigned int grid_x_index = bbox_x_label.getValue(batch, 0, object, 0);
897       unsigned int grid_y_index = bbox_y_label.getValue(batch, 0, object, 0);
898       unsigned int grid_index = grid_y_index * grid_width_number + grid_x_index;
899       unsigned int responsible_anchor =
900         find_responsible_anchors(bbox_w_label.getValue(batch, 0, object, 0) /
901                                  bbox_h_label.getValue(batch, 0, object, 0));
902
903       bbox_x_gt.setValue(batch, grid_index, responsible_anchor, 0,
904                          bbox_x_label.getValue(batch, 0, object, 0) -
905                            grid_x_index);
906       bbox_y_gt.setValue(batch, grid_index, responsible_anchor, 0,
907                          bbox_y_label.getValue(batch, 0, object, 0) -
908                            grid_y_index);
909       bbox_w_gt.setValue(
910         batch, grid_index, responsible_anchor, 0,
911         nntrainer::sqrtFloat(bbox_w_label.getValue(batch, 0, object, 0)));
912       bbox_h_gt.setValue(
913         batch, grid_index, responsible_anchor, 0,
914         nntrainer::sqrtFloat(bbox_h_label.getValue(batch, 0, object, 0)));
915
916       class_gt.setValue(batch, grid_index, responsible_anchor,
917                         class_label.getValue(batch, 0, object, 0), 1);
918       bbox_class_mask.setValue(batch, grid_index, responsible_anchor, 0, 1);
919       iou_mask.setValue(batch, grid_index, responsible_anchor, 0, 1);
920     }
921   }
922
923   nntrainer::Tensor iou = calc_iou(
924     bbox_x_pred, bbox_y_pred, bbox_w_pred_anchor, bbox_h_pred_anchor, bbox_x_gt,
925     bbox_y_gt, bbox_w_gt, bbox_h_gt, bbox1_width, bbox1_height, is_xy_min_max,
926     intersection_width, intersection_height, unions);
927   confidence_gt.copyData(iou);
928 }
929
930 #ifdef PLUGGABLE
931
932 nntrainer::Layer *create_yolo_v2_loss_layer() {
933   auto layer = new YoloV2LossLayer();
934   return layer;
935 }
936
937 void destory_yolo_v2_loss_layer(nntrainer::Layer *layer) { delete layer; }
938
939 /**
940  * @note ml_train_layer_pluggable defines the entry point for nntrainer to
941  * register a plugin layer
942  */
943 extern "C" {
944 nntrainer::LayerPluggable ml_train_layer_pluggable{create_yolo_v2_loss_layer,
945                                                    destory_yolo_v2_loss_layer};
946 }
947
948 #endif
949 } // namespace custom