[Application] match nntrainer and pytorch yolo model
[platform/core/ml/nntrainer.git] / Applications / YOLO / jni / yolo_v2_loss.cpp
1 // SPDX-License-Identifier: Apache-2.0
2 /**
3  * Copyright (C) 2023 Hyeonseok Lee <hs89.lee@samsung.com>
4  *
5  * @file   yolo_v2_loss.cpp
6  * @date   07 March 2023
7  * @brief  This file contains the yolo v2 loss layer
8  * @see    https://github.com/nnstreamer/nntrainer
9  * @author Hyeonseok Lee <hs89.lee@samsung.com>
10  * @bug    No known bugs except for NYI items
11  *
12  */
13
14 #include "yolo_v2_loss.h"
15
16 namespace custom {
17
/** @brief index of the single input / output slot used by this layer */
static constexpr size_t SINGLE_INOUT_IDX = 0;

/**
 * @brief identifiers of the tensors requested from the layer context
 * @note each enumerator is used as an index into wt_idx
 */
enum YoloV2LossParams {
  /* activated predictions split from the layer input */
  bbox_x_pred = 0,
  bbox_y_pred,
  bbox_w_pred,
  bbox_h_pred,
  confidence_pred,
  class_pred,

  /* sqrt(anchor * prediction) of the box width / height */
  bbox_w_pred_anchor,
  bbox_h_pred_anchor,

  /* ground truth, reset to zero at the start of every forwarding */
  bbox_x_gt,
  bbox_y_gt,
  bbox_w_gt,
  bbox_h_gt,
  confidence_gt,
  class_gt,

  /* masks applied to predictions and ground truth before the mse */
  bbox_class_mask,
  iou_mask,

  /* intermediate results of the iou calculation */
  bbox1_width,
  bbox1_height,
  is_xy_min_max,
  intersection_width,
  intersection_height,
  unions,
};
44
45 namespace props {
46 MaxObjectNumber::MaxObjectNumber(const unsigned &value) { set(value); }
47 ClassNumber::ClassNumber(const unsigned &value) { set(value); }
48 GridHeightNumber::GridHeightNumber(const unsigned &value) { set(value); }
49 GridWidthNumber::GridWidthNumber(const unsigned &value) { set(value); }
50 } // namespace props
51
52 /**
53  * @brief mse
54  *
55  * @param pred prediction
56  * @param ground_truth ground truth
57  * @return float loss
58  * @todo make loss behaves like acti_func
59  */
60 float mse(nntrainer::Tensor &pred, nntrainer::Tensor &ground_truth) {
61   nntrainer::Tensor residual;
62   pred.subtract(ground_truth, residual);
63
64   float l2norm = residual.l2norm();
65   l2norm *= l2norm / residual.size();
66
67   return l2norm;
68 }
69
70 /**
71  * @brief backwarding of mse
72  *
73  * @param pred prediction
74  * @param ground_truth ground truth
75  * @param outgoing_derivative outgoing derivative
76  */
77 void msePrime(nntrainer::Tensor &pred, nntrainer::Tensor &ground_truth,
78               nntrainer::Tensor &outgoing_derivative) {
79   pred.subtract(ground_truth, outgoing_derivative);
80   float divider = ((float)pred.size()) / 2;
81   if (outgoing_derivative.divide_i(divider) != ML_ERROR_NONE) {
82     throw std::runtime_error(
83       "[YoloV2LossLayer::calcDerivative] Error when calculating loss");
84   }
85 }
86
87 /**
88  * @brief calculate iou
89  *
90  * @param bbox1_x1 bbox1_x1
91  * @param bbox1_y1 bbox1_y1
92  * @param bbox1_w bbox1_w
93  * @param bbox1_h bbox1_h
94  * @param bbox2_x1 bbox2_x1
95  * @param bbox2_y1 bbox2_y1
96  * @param bbox2_w bbox2_w
97  * @param bbox2_h bbox2_h
98  * @param[out] bbox1_width bbox1 width
99  * @param[out] bbox1_height bbox1 height
100  * @param[out] is_xy_min_max For x1, y1 this value is 1 if x1 > x2, y1 > y2 and
101  * for x2, y2 this is value is 1 if x2 < x1, y2 < y1. else 0.
102  * @param[out] intersection_width intersection width
103  * @param[out] intersection_height intersection height
104  * @param[out] unions unions
105  * @return nntrainer::Tensor iou
106  */
107 nntrainer::Tensor
108 calc_iou(nntrainer::Tensor &bbox1_x1, nntrainer::Tensor &bbox1_y1,
109          nntrainer::Tensor &bbox1_w, nntrainer::Tensor &bbox1_h,
110          nntrainer::Tensor &bbox2_x1, nntrainer::Tensor &bbox2_y1,
111          nntrainer::Tensor &bbox2_w, nntrainer::Tensor &bbox2_h,
112          nntrainer::Tensor &bbox1_width, nntrainer::Tensor &bbox1_height,
113          nntrainer::Tensor &is_xy_min_max,
114          nntrainer::Tensor &intersection_width,
115          nntrainer::Tensor &intersection_height, nntrainer::Tensor &unions) {
116   nntrainer::Tensor bbox1_x2 = bbox1_x1.add(bbox1_w);
117   nntrainer::Tensor bbox1_y2 = bbox1_y1.add(bbox1_h);
118   nntrainer::Tensor bbox2_x2 = bbox2_x1.add(bbox2_w);
119   nntrainer::Tensor bbox2_y2 = bbox2_y1.add(bbox2_h);
120
121   bbox1_x2.subtract(bbox1_x1, bbox1_width);
122   bbox1_y2.subtract(bbox1_y1, bbox1_height);
123   nntrainer::Tensor bbox1 = bbox1_width.multiply(bbox1_height);
124
125   nntrainer::Tensor bbox2_width = bbox2_x2.subtract(bbox2_x1);
126   nntrainer::Tensor bbox2_height = bbox2_y2.subtract(bbox2_y1);
127   nntrainer::Tensor bbox2 = bbox2_width.multiply(bbox2_height);
128
129   auto min_func = [&](nntrainer::Tensor &bbox1_xy, nntrainer::Tensor &bbox2_xy,
130                       nntrainer::Tensor &intersection_xy) {
131     std::transform(bbox1_xy.getData(), bbox1_xy.getData() + bbox1_xy.size(),
132                    bbox2_xy.getData(), intersection_xy.getData(),
133                    [](float x1, float x2) { return std::min(x1, x2); });
134   };
135   auto max_func = [&](nntrainer::Tensor &bbox1_xy, nntrainer::Tensor &bbox2_xy,
136                       nntrainer::Tensor &intersection_xy) {
137     std::transform(bbox1_xy.getData(), bbox1_xy.getData() + bbox1_xy.size(),
138                    bbox2_xy.getData(), intersection_xy.getData(),
139                    [](float x1, float x2) { return std::max(x1, x2); });
140   };
141
142   nntrainer::Tensor intersection_x1(bbox1_x1.getDim());
143   nntrainer::Tensor intersection_x2(bbox1_x1.getDim());
144   nntrainer::Tensor intersection_y1(bbox1_y1.getDim());
145   nntrainer::Tensor intersection_y2(bbox1_y1.getDim());
146   max_func(bbox1_x1, bbox2_x1, intersection_x1);
147   min_func(bbox1_x2, bbox2_x2, intersection_x2);
148   max_func(bbox1_y1, bbox2_y1, intersection_y1);
149   min_func(bbox1_y2, bbox2_y2, intersection_y2);
150
151   auto is_min_max_func = [&](nntrainer::Tensor &xy,
152                              nntrainer::Tensor &intersection,
153                              nntrainer::Tensor &is_min_max) {
154     std::transform(xy.getData(), xy.getData() + xy.size(),
155                    intersection.getData(), is_min_max.getData(),
156                    [](float x, float m) {
157                      return nntrainer::absFloat(x - m) < 1e-4 ? 1.0 : 0.0;
158                    });
159   };
160
161   nntrainer::Tensor is_bbox1_x1_max(bbox1_x1.getDim());
162   nntrainer::Tensor is_bbox1_y1_max(bbox1_x1.getDim());
163   nntrainer::Tensor is_bbox1_x2_min(bbox1_x1.getDim());
164   nntrainer::Tensor is_bbox1_y2_min(bbox1_x1.getDim());
165   is_min_max_func(bbox1_x1, intersection_x1, is_bbox1_x1_max);
166   is_min_max_func(bbox1_y1, intersection_y1, is_bbox1_y1_max);
167   is_min_max_func(bbox1_x2, intersection_x2, is_bbox1_x2_min);
168   is_min_max_func(bbox1_y2, intersection_y2, is_bbox1_y2_min);
169
170   nntrainer::Tensor is_bbox_min_max = nntrainer::Tensor::cat(
171     {is_bbox1_x1_max, is_bbox1_y1_max, is_bbox1_x2_min, is_bbox1_y2_min}, 3);
172   is_xy_min_max.copyData(is_bbox_min_max);
173
174   intersection_x2.subtract(intersection_x1, intersection_width);
175   intersection_width.apply_i(nntrainer::ActiFunc::relu);
176   intersection_y2.subtract(intersection_y1, intersection_height);
177   intersection_height.apply_i(nntrainer::ActiFunc::relu);
178
179   nntrainer::Tensor intersection =
180     intersection_width.multiply(intersection_height);
181   bbox1.add(bbox2, unions);
182   unions.subtract_i(intersection);
183
184   return intersection.divide(unions);
185 }
186
187 /**
188  * @brief calculate iou graident
189  * @details Let say bbox_pred as x, intersection as f(x), union as g(x) and iou
190  * as y. Then y = f(x)/g(x). Also g(x) = bbox1 + bbox2 - f(x). Partial
191  * derivative of y with respect to x will be (f'(x)g(x) - f(x)g'(x))/(g(x)^2).
192  * Partial derivative of g(x) with respect to x will be bbox1'(x) - f'(x).
193  * @param confidence_gt_grad incoming derivative for iou
194  * @param bbox1_width bbox1_width
195  * @param bbox1_height bbox1_height
196  * @param is_xy_min_max For x1, y1 this value is 1 if x1 > x2, y1 > y2 and for
197  * x2, y2 this is value is 1 if x2 < x1, y2 < y1. else 0.
198  * @param intersection_width intersection width
199  * @param intersection_height intersection height
200  * @param unions unions
201  * @return std::vector<nntrainer::Tensor> iou_grad
202  */
203 std::vector<nntrainer::Tensor> calc_iou_grad(
204   nntrainer::Tensor &confidence_gt_grad, nntrainer::Tensor &bbox1_width,
205   nntrainer::Tensor &bbox1_height, nntrainer::Tensor &is_xy_min_max,
206   nntrainer::Tensor &intersection_width, nntrainer::Tensor &intersection_height,
207   nntrainer::Tensor &unions) {
208   nntrainer::Tensor intersection =
209     intersection_width.multiply(intersection_height);
210
211   // 1. calculate intersection local gradient [f'(x)]
212   nntrainer::Tensor intersection_width_relu_prime =
213     intersection_width.apply(nntrainer::ActiFunc::reluPrime);
214   nntrainer::Tensor intersection_height_relu_prime =
215     intersection_height.apply(nntrainer::ActiFunc::reluPrime);
216
217   nntrainer::Tensor intersection_x2_local_grad =
218     intersection_width_relu_prime.multiply(intersection_height);
219   nntrainer::Tensor intersection_y2_local_grad =
220     intersection_height_relu_prime.multiply(intersection_width);
221   nntrainer::Tensor intersection_x1_local_grad =
222     intersection_x2_local_grad.multiply(-1.0);
223   nntrainer::Tensor intersection_y1_local_grad =
224     intersection_y2_local_grad.multiply(-1.0);
225
226   nntrainer::Tensor intersection_local_grad = nntrainer::Tensor::cat(
227     {intersection_x1_local_grad, intersection_y1_local_grad,
228      intersection_x2_local_grad, intersection_y2_local_grad},
229     3);
230   intersection_local_grad.multiply_i(is_xy_min_max);
231
232   // 2. calculate union local gradient [g'(x)]
233   nntrainer::Tensor bbox1_x1_grad = bbox1_height.multiply(-1.0);
234   nntrainer::Tensor bbox1_y1_grad = bbox1_width.multiply(-1.0);
235   nntrainer::Tensor bbox1_x2_grad = bbox1_height;
236   nntrainer::Tensor bbox1_y2_grad = bbox1_width;
237   nntrainer::Tensor bbox1_grad = nntrainer::Tensor::cat(
238     {bbox1_x1_grad, bbox1_y1_grad, bbox1_x2_grad, bbox1_y2_grad}, 3);
239
240   nntrainer::Tensor unions_local_grad =
241     bbox1_grad.subtract(intersection_local_grad);
242
243   // 3. calculate iou local gradient [(f'(x)g(x) - f(x)g'(x))/(g(x)^2)]
244   nntrainer::Tensor lhs = intersection_local_grad.multiply(unions);
245   nntrainer::Tensor rhs = unions_local_grad.multiply(intersection);
246   nntrainer::Tensor iou_grad = lhs.subtract(rhs);
247   iou_grad.divide_i(unions);
248   iou_grad.divide_i(unions);
249
250   // 3. multiply with incoming derivative
251   iou_grad.multiply_i(confidence_gt_grad);
252
253   auto splitted_iou_grad = iou_grad.split({1, 1, 1, 1}, 3);
254   std::vector<nntrainer::Tensor> ret = {
255     splitted_iou_grad[0].add(splitted_iou_grad[2]),
256     splitted_iou_grad[1].add(splitted_iou_grad[3]), splitted_iou_grad[2],
257     splitted_iou_grad[3]};
258   return ret;
259 }
260
261 YoloV2LossLayer::YoloV2LossLayer() :
262   anchors_w({1, 1, NUM_ANCHOR, 1}, anchors_w_buf),
263   anchors_h({1, 1, NUM_ANCHOR, 1}, anchors_h_buf),
264   sigmoid(nntrainer::ActivationType::ACT_SIGMOID, true),
265   softmax(nntrainer::ActivationType::ACT_SOFTMAX, true),
266   yolo_v2_loss_props(props::MaxObjectNumber(), props::ClassNumber(),
267                      props::GridHeightNumber(), props::GridWidthNumber()) {
268   anchors_ratio = anchors_w.divide(anchors_h);
269   wt_idx.fill(std::numeric_limits<unsigned>::max());
270 }
271
272 void YoloV2LossLayer::finalize(nntrainer::InitLayerContext &context) {
273   nntrainer::TensorDim input_dim =
274     context.getInputDimensions()[SINGLE_INOUT_IDX];
275   const unsigned int batch_size = input_dim.batch();
276   const unsigned int class_number =
277     std::get<props::ClassNumber>(yolo_v2_loss_props).get();
278   const unsigned int grid_height_number =
279     std::get<props::GridHeightNumber>(yolo_v2_loss_props).get();
280   const unsigned int grid_width_number =
281     std::get<props::GridWidthNumber>(yolo_v2_loss_props).get();
282   const unsigned int max_object_number =
283     std::get<props::MaxObjectNumber>(yolo_v2_loss_props).get();
284   nntrainer::TensorDim label_dim(batch_size, 1, max_object_number, 5);
285   context.setOutputDimensions({label_dim});
286
287   nntrainer::TensorDim bbox_x_pred_dim(
288     batch_size, grid_height_number * grid_width_number, NUM_ANCHOR, 1);
289   wt_idx[YoloV2LossParams::bbox_x_pred] = context.requestTensor(
290     bbox_x_pred_dim, "bbox_x_pred", nntrainer::Tensor::Initializer::NONE, true,
291     nntrainer::TensorLifespan::FORWARD_DERIV_LIFESPAN);
292
293   nntrainer::TensorDim bbox_y_pred_dim(
294     batch_size, grid_height_number * grid_width_number, NUM_ANCHOR, 1);
295   wt_idx[YoloV2LossParams::bbox_y_pred] = context.requestTensor(
296     bbox_y_pred_dim, "bbox_y_pred", nntrainer::Tensor::Initializer::NONE, true,
297     nntrainer::TensorLifespan::FORWARD_DERIV_LIFESPAN);
298
299   nntrainer::TensorDim bbox_w_pred_dim(
300     batch_size, grid_height_number * grid_width_number, NUM_ANCHOR, 1);
301   wt_idx[YoloV2LossParams::bbox_w_pred] = context.requestTensor(
302     bbox_w_pred_dim, "bbox_w_pred", nntrainer::Tensor::Initializer::NONE, true,
303     nntrainer::TensorLifespan::FORWARD_DERIV_LIFESPAN);
304
305   nntrainer::TensorDim bbox_h_pred_dim(
306     batch_size, grid_height_number * grid_width_number, NUM_ANCHOR, 1);
307   wt_idx[YoloV2LossParams::bbox_h_pred] = context.requestTensor(
308     bbox_h_pred_dim, "bbox_h_pred", nntrainer::Tensor::Initializer::NONE, true,
309     nntrainer::TensorLifespan::FORWARD_DERIV_LIFESPAN);
310
311   nntrainer::TensorDim confidence_pred_dim(
312     batch_size, grid_height_number * grid_width_number, NUM_ANCHOR, 1);
313   wt_idx[YoloV2LossParams::confidence_pred] =
314     context.requestTensor(confidence_pred_dim, "confidence_pred",
315                           nntrainer::Tensor::Initializer::NONE, true,
316                           nntrainer::TensorLifespan::FORWARD_DERIV_LIFESPAN);
317
318   nntrainer::TensorDim class_pred_dim(batch_size,
319                                       grid_height_number * grid_width_number,
320                                       NUM_ANCHOR, class_number);
321   wt_idx[YoloV2LossParams::class_pred] = context.requestTensor(
322     class_pred_dim, "class_pred", nntrainer::Tensor::Initializer::NONE, true,
323     nntrainer::TensorLifespan::FORWARD_DERIV_LIFESPAN);
324
325   nntrainer::TensorDim bbox_w_pred_anchor_dim(
326     batch_size, grid_height_number * grid_width_number, NUM_ANCHOR, 1);
327   wt_idx[YoloV2LossParams::bbox_w_pred_anchor] =
328     context.requestTensor(bbox_w_pred_anchor_dim, "bbox_w_pred_anchor",
329                           nntrainer::Tensor::Initializer::NONE, false,
330                           nntrainer::TensorLifespan::FORWARD_DERIV_LIFESPAN);
331
332   nntrainer::TensorDim bbox_h_pred_anchor_dim(
333     batch_size, grid_height_number * grid_width_number, NUM_ANCHOR, 1);
334   wt_idx[YoloV2LossParams::bbox_h_pred_anchor] =
335     context.requestTensor(bbox_h_pred_anchor_dim, "bbox_h_pred_anchor",
336                           nntrainer::Tensor::Initializer::NONE, false,
337                           nntrainer::TensorLifespan::FORWARD_DERIV_LIFESPAN);
338
339   nntrainer::TensorDim bbox_x_gt_dim(
340     batch_size, grid_height_number * grid_width_number, NUM_ANCHOR, 1);
341   wt_idx[YoloV2LossParams::bbox_x_gt] = context.requestTensor(
342     bbox_x_gt_dim, "bbox_x_gt", nntrainer::Tensor::Initializer::NONE, false,
343     nntrainer::TensorLifespan::FORWARD_DERIV_LIFESPAN);
344
345   nntrainer::TensorDim bbox_y_gt_dim(
346     batch_size, grid_height_number * grid_width_number, NUM_ANCHOR, 1);
347   wt_idx[YoloV2LossParams::bbox_y_gt] = context.requestTensor(
348     bbox_y_gt_dim, "bbox_y_gt", nntrainer::Tensor::Initializer::NONE, false,
349     nntrainer::TensorLifespan::FORWARD_DERIV_LIFESPAN);
350
351   nntrainer::TensorDim bbox_w_gt_dim(
352     batch_size, grid_height_number * grid_width_number, NUM_ANCHOR, 1);
353   wt_idx[YoloV2LossParams::bbox_w_gt] = context.requestTensor(
354     bbox_w_gt_dim, "bbox_w_gt", nntrainer::Tensor::Initializer::NONE, false,
355     nntrainer::TensorLifespan::FORWARD_DERIV_LIFESPAN);
356
357   nntrainer::TensorDim bbox_h_gt_dim(
358     batch_size, grid_height_number * grid_width_number, NUM_ANCHOR, 1);
359   wt_idx[YoloV2LossParams::bbox_h_gt] = context.requestTensor(
360     bbox_h_gt_dim, "bbox_h_gt", nntrainer::Tensor::Initializer::NONE, false,
361     nntrainer::TensorLifespan::FORWARD_DERIV_LIFESPAN);
362
363   nntrainer::TensorDim confidence_gt_dim(
364     batch_size, grid_height_number * grid_width_number, NUM_ANCHOR, 1);
365   wt_idx[YoloV2LossParams::confidence_gt] = context.requestTensor(
366     confidence_gt_dim, "confidence_gt", nntrainer::Tensor::Initializer::NONE,
367     false, nntrainer::TensorLifespan::FORWARD_DERIV_LIFESPAN);
368
369   nntrainer::TensorDim class_gt_dim(batch_size,
370                                     grid_height_number * grid_width_number,
371                                     NUM_ANCHOR, class_number);
372   wt_idx[YoloV2LossParams::class_gt] = context.requestTensor(
373     class_gt_dim, "class_gt", nntrainer::Tensor::Initializer::NONE, false,
374     nntrainer::TensorLifespan::FORWARD_DERIV_LIFESPAN);
375
376   nntrainer::TensorDim bbox_class_mask_dim(
377     batch_size, grid_height_number * grid_width_number, NUM_ANCHOR, 1);
378   wt_idx[YoloV2LossParams::bbox_class_mask] =
379     context.requestTensor(bbox_class_mask_dim, "bbox_class_mask",
380                           nntrainer::Tensor::Initializer::NONE, false,
381                           nntrainer::TensorLifespan::FORWARD_DERIV_LIFESPAN);
382
383   nntrainer::TensorDim iou_mask_dim(
384     batch_size, grid_height_number * grid_width_number, NUM_ANCHOR, 1);
385   wt_idx[YoloV2LossParams::iou_mask] = context.requestTensor(
386     iou_mask_dim, "iou_mask", nntrainer::Tensor::Initializer::NONE, false,
387     nntrainer::TensorLifespan::FORWARD_DERIV_LIFESPAN);
388
389   nntrainer::TensorDim bbox1_width_dim(
390     batch_size, grid_height_number * grid_width_number, NUM_ANCHOR, 1);
391   wt_idx[YoloV2LossParams::bbox1_width] = context.requestTensor(
392     bbox1_width_dim, "bbox1_width", nntrainer::Tensor::Initializer::NONE, false,
393     nntrainer::TensorLifespan::FORWARD_DERIV_LIFESPAN);
394
395   nntrainer::TensorDim bbox1_height_dim(
396     batch_size, grid_height_number * grid_width_number, NUM_ANCHOR, 1);
397   wt_idx[YoloV2LossParams::bbox1_height] = context.requestTensor(
398     bbox1_height_dim, "bbox1_height", nntrainer::Tensor::Initializer::NONE,
399     false, nntrainer::TensorLifespan::FORWARD_DERIV_LIFESPAN);
400
401   nntrainer::TensorDim is_xy_min_max_dim(
402     batch_size, grid_height_number * grid_width_number, NUM_ANCHOR, 4);
403   wt_idx[YoloV2LossParams::is_xy_min_max] = context.requestTensor(
404     is_xy_min_max_dim, "is_xy_min_max", nntrainer::Tensor::Initializer::NONE,
405     false, nntrainer::TensorLifespan::FORWARD_DERIV_LIFESPAN);
406
407   nntrainer::TensorDim intersection_width_dim(
408     batch_size, grid_height_number * grid_width_number, NUM_ANCHOR, 1);
409   wt_idx[YoloV2LossParams::intersection_width] =
410     context.requestTensor(intersection_width_dim, "intersection_width",
411                           nntrainer::Tensor::Initializer::NONE, false,
412                           nntrainer::TensorLifespan::FORWARD_DERIV_LIFESPAN);
413
414   nntrainer::TensorDim intersection_height_dim(
415     batch_size, grid_height_number * grid_width_number, NUM_ANCHOR, 1);
416   wt_idx[YoloV2LossParams::intersection_height] =
417     context.requestTensor(intersection_height_dim, "intersection_height",
418                           nntrainer::Tensor::Initializer::NONE, false,
419                           nntrainer::TensorLifespan::FORWARD_DERIV_LIFESPAN);
420
421   nntrainer::TensorDim unions_dim(
422     batch_size, grid_height_number * grid_width_number, NUM_ANCHOR, 1);
423   wt_idx[YoloV2LossParams::unions] = context.requestTensor(
424     unions_dim, "unions", nntrainer::Tensor::Initializer::NONE, false,
425     nntrainer::TensorLifespan::FORWARD_DERIV_LIFESPAN);
426 }
427
428 void YoloV2LossLayer::forwarding(nntrainer::RunLayerContext &context,
429                                  bool training) {
430   const unsigned int max_object_number =
431     std::get<props::MaxObjectNumber>(yolo_v2_loss_props).get();
432
433   nntrainer::Tensor &input = context.getInput(SINGLE_INOUT_IDX);
434
435   std::vector<nntrainer::Tensor> splited_input =
436     input.split({1, 1, 1, 1, 1, max_object_number}, 3);
437   nntrainer::Tensor bbox_x_pred_ = splited_input[0];
438   nntrainer::Tensor bbox_y_pred_ = splited_input[1];
439   nntrainer::Tensor bbox_w_pred_ = splited_input[2];
440   nntrainer::Tensor bbox_h_pred_ = splited_input[3];
441   nntrainer::Tensor confidence_pred_ = splited_input[4];
442   nntrainer::Tensor class_pred_ = splited_input[5];
443
444   nntrainer::Tensor &bbox_x_pred =
445     context.getTensor(wt_idx[YoloV2LossParams::bbox_x_pred]);
446   nntrainer::Tensor &bbox_y_pred =
447     context.getTensor(wt_idx[YoloV2LossParams::bbox_y_pred]);
448   nntrainer::Tensor &bbox_w_pred =
449     context.getTensor(wt_idx[YoloV2LossParams::bbox_w_pred]);
450   nntrainer::Tensor &bbox_h_pred =
451     context.getTensor(wt_idx[YoloV2LossParams::bbox_h_pred]);
452
453   nntrainer::Tensor &confidence_pred =
454     context.getTensor(wt_idx[YoloV2LossParams::confidence_pred]);
455   nntrainer::Tensor &class_pred =
456     context.getTensor(wt_idx[YoloV2LossParams::class_pred]);
457
458   nntrainer::Tensor &bbox_w_pred_anchor =
459     context.getTensor(wt_idx[YoloV2LossParams::bbox_w_pred_anchor]);
460   nntrainer::Tensor &bbox_h_pred_anchor =
461     context.getTensor(wt_idx[YoloV2LossParams::bbox_h_pred_anchor]);
462
463   bbox_x_pred.copyData(bbox_x_pred_);
464   bbox_y_pred.copyData(bbox_y_pred_);
465   bbox_w_pred.copyData(bbox_w_pred_);
466   bbox_h_pred.copyData(bbox_h_pred_);
467
468   confidence_pred.copyData(confidence_pred_);
469   class_pred.copyData(class_pred_);
470
471   nntrainer::Tensor &bbox_x_gt =
472     context.getTensor(wt_idx[YoloV2LossParams::bbox_x_gt]);
473   nntrainer::Tensor &bbox_y_gt =
474     context.getTensor(wt_idx[YoloV2LossParams::bbox_y_gt]);
475   nntrainer::Tensor &bbox_w_gt =
476     context.getTensor(wt_idx[YoloV2LossParams::bbox_w_gt]);
477   nntrainer::Tensor &bbox_h_gt =
478     context.getTensor(wt_idx[YoloV2LossParams::bbox_h_gt]);
479
480   nntrainer::Tensor &confidence_gt =
481     context.getTensor(wt_idx[YoloV2LossParams::confidence_gt]);
482   nntrainer::Tensor &class_gt =
483     context.getTensor(wt_idx[YoloV2LossParams::class_gt]);
484
485   nntrainer::Tensor &bbox_class_mask =
486     context.getTensor(wt_idx[YoloV2LossParams::bbox_class_mask]);
487   nntrainer::Tensor &iou_mask =
488     context.getTensor(wt_idx[YoloV2LossParams::iou_mask]);
489
490   bbox_x_gt.setValue(0);
491   bbox_y_gt.setValue(0);
492   bbox_w_gt.setValue(0);
493   bbox_h_gt.setValue(0);
494
495   confidence_gt.setValue(0);
496   class_gt.setValue(0);
497
498   // init mask
499   bbox_class_mask.setValue(0);
500   iou_mask.setValue(0.5);
501
502   // activate pred
503   sigmoid.run_fn(bbox_x_pred, bbox_x_pred);
504   sigmoid.run_fn(bbox_y_pred, bbox_y_pred);
505   bbox_w_pred.apply_i(nntrainer::exp_util);
506   bbox_h_pred.apply_i(nntrainer::exp_util);
507   sigmoid.run_fn(confidence_pred, confidence_pred);
508   softmax.run_fn(class_pred, class_pred);
509
510   bbox_w_pred_anchor.copyData(bbox_w_pred);
511   bbox_h_pred_anchor.copyData(bbox_h_pred);
512
513   // apply anchors to bounding box
514   bbox_w_pred_anchor.multiply_i(anchors_w);
515   bbox_h_pred_anchor.multiply_i(anchors_h);
516   bbox_w_pred_anchor.apply_i(nntrainer::sqrtFloat);
517   bbox_h_pred_anchor.apply_i(nntrainer::sqrtFloat);
518
519   generate_ground_truth(context);
520
521   nntrainer::Tensor bbox_pred = nntrainer::Tensor::cat(
522     {bbox_x_pred, bbox_y_pred, bbox_w_pred_anchor, bbox_h_pred_anchor}, 3);
523   nntrainer::Tensor masked_bbox_pred = bbox_pred.multiply(bbox_class_mask);
524   nntrainer::Tensor masked_confidence_pred = confidence_pred.multiply(iou_mask);
525   nntrainer::Tensor masked_class_pred = class_pred.multiply(bbox_class_mask);
526
527   nntrainer::Tensor bbox_gt =
528     nntrainer::Tensor::cat({bbox_x_gt, bbox_y_gt, bbox_w_gt, bbox_h_gt}, 3);
529   nntrainer::Tensor masked_bbox_gt = bbox_gt.multiply(bbox_class_mask);
530   nntrainer::Tensor masked_confidence_gt = confidence_gt.multiply(iou_mask);
531   nntrainer::Tensor masked_class_gt = class_gt.multiply(bbox_class_mask);
532
533   float bbox_loss = mse(masked_bbox_pred, masked_bbox_gt);
534   float confidence_loss = mse(masked_confidence_pred, masked_confidence_gt);
535   float class_loss = mse(masked_class_pred, masked_class_gt);
536
537   float loss = 5 * bbox_loss + confidence_loss + class_loss;
538 }
539
540 void YoloV2LossLayer::calcDerivative(nntrainer::RunLayerContext &context) {
541   nntrainer::Tensor &bbox_x_pred =
542     context.getTensor(wt_idx[YoloV2LossParams::bbox_x_pred]);
543   nntrainer::Tensor &bbox_x_pred_grad =
544     context.getTensorGrad(wt_idx[YoloV2LossParams::bbox_x_pred]);
545   nntrainer::Tensor &bbox_y_pred =
546     context.getTensor(wt_idx[YoloV2LossParams::bbox_y_pred]);
547   nntrainer::Tensor &bbox_y_pred_grad =
548     context.getTensorGrad(wt_idx[YoloV2LossParams::bbox_y_pred]);
549   nntrainer::Tensor &bbox_w_pred =
550     context.getTensor(wt_idx[YoloV2LossParams::bbox_w_pred]);
551   nntrainer::Tensor &bbox_w_pred_grad =
552     context.getTensorGrad(wt_idx[YoloV2LossParams::bbox_w_pred]);
553   nntrainer::Tensor &bbox_h_pred =
554     context.getTensor(wt_idx[YoloV2LossParams::bbox_h_pred]);
555   nntrainer::Tensor &bbox_h_pred_grad =
556     context.getTensorGrad(wt_idx[YoloV2LossParams::bbox_h_pred]);
557
558   nntrainer::Tensor &confidence_pred =
559     context.getTensor(wt_idx[YoloV2LossParams::confidence_pred]);
560   nntrainer::Tensor &confidence_pred_grad =
561     context.getTensorGrad(wt_idx[YoloV2LossParams::confidence_pred]);
562   nntrainer::Tensor &class_pred =
563     context.getTensor(wt_idx[YoloV2LossParams::class_pred]);
564   nntrainer::Tensor &class_pred_grad =
565     context.getTensorGrad(wt_idx[YoloV2LossParams::class_pred]);
566
567   nntrainer::Tensor &bbox_w_pred_anchor =
568     context.getTensor(wt_idx[YoloV2LossParams::bbox_w_pred_anchor]);
569   nntrainer::Tensor &bbox_h_pred_anchor =
570     context.getTensor(wt_idx[YoloV2LossParams::bbox_h_pred_anchor]);
571
572   nntrainer::Tensor &bbox_x_gt =
573     context.getTensor(wt_idx[YoloV2LossParams::bbox_x_gt]);
574   nntrainer::Tensor &bbox_y_gt =
575     context.getTensor(wt_idx[YoloV2LossParams::bbox_y_gt]);
576   nntrainer::Tensor &bbox_w_gt =
577     context.getTensor(wt_idx[YoloV2LossParams::bbox_w_gt]);
578   nntrainer::Tensor &bbox_h_gt =
579     context.getTensor(wt_idx[YoloV2LossParams::bbox_h_gt]);
580
581   nntrainer::Tensor &confidence_gt =
582     context.getTensor(wt_idx[YoloV2LossParams::confidence_gt]);
583   nntrainer::Tensor &class_gt =
584     context.getTensor(wt_idx[YoloV2LossParams::class_gt]);
585
586   nntrainer::Tensor &bbox_class_mask =
587     context.getTensor(wt_idx[YoloV2LossParams::bbox_class_mask]);
588   nntrainer::Tensor &iou_mask =
589     context.getTensor(wt_idx[YoloV2LossParams::iou_mask]);
590
591   nntrainer::Tensor &bbox1_width =
592     context.getTensor(wt_idx[YoloV2LossParams::bbox1_width]);
593   nntrainer::Tensor &bbox1_height =
594     context.getTensor(wt_idx[YoloV2LossParams::bbox1_height]);
595   nntrainer::Tensor &is_xy_min_max =
596     context.getTensor(wt_idx[YoloV2LossParams::is_xy_min_max]);
597   nntrainer::Tensor &intersection_width =
598     context.getTensor(wt_idx[YoloV2LossParams::intersection_width]);
599   nntrainer::Tensor &intersection_height =
600     context.getTensor(wt_idx[YoloV2LossParams::intersection_height]);
601   nntrainer::Tensor &unions =
602     context.getTensor(wt_idx[YoloV2LossParams::unions]);
603
604   nntrainer::Tensor bbox_pred = nntrainer::Tensor::cat(
605     {bbox_x_pred, bbox_y_pred, bbox_w_pred_anchor, bbox_h_pred_anchor}, 3);
606   nntrainer::Tensor masked_bbox_pred = bbox_pred.multiply(bbox_class_mask);
607   nntrainer::Tensor masked_confidence_pred = confidence_pred.multiply(iou_mask);
608   nntrainer::Tensor masked_class_pred = class_pred.multiply(bbox_class_mask);
609
610   nntrainer::Tensor bbox_gt =
611     nntrainer::Tensor::cat({bbox_x_gt, bbox_y_gt, bbox_w_gt, bbox_h_gt}, 3);
612   nntrainer::Tensor masked_bbox_gt = bbox_gt.multiply(bbox_class_mask);
613   nntrainer::Tensor masked_confidence_gt = confidence_gt.multiply(iou_mask);
614   nntrainer::Tensor masked_class_gt = class_gt.multiply(bbox_class_mask);
615
616   nntrainer::Tensor masked_bbox_pred_grad;
617   nntrainer::Tensor masked_confidence_pred_grad;
618   nntrainer::Tensor masked_confidence_gt_grad;
619   nntrainer::Tensor masked_class_pred_grad;
620
621   nntrainer::Tensor confidence_gt_grad;
622
623   msePrime(masked_bbox_pred, masked_bbox_gt, masked_bbox_pred_grad);
624   msePrime(masked_confidence_pred, masked_confidence_gt,
625            masked_confidence_pred_grad);
626   msePrime(masked_confidence_gt, masked_confidence_pred,
627            masked_confidence_gt_grad);
628   msePrime(masked_class_pred, masked_class_gt, masked_class_pred_grad);
629
630   masked_bbox_pred_grad.multiply_i(5);
631
632   nntrainer::Tensor bbox_pred_grad;
633
634   masked_bbox_pred_grad.multiply(bbox_class_mask, bbox_pred_grad);
635   masked_confidence_pred_grad.multiply(iou_mask, confidence_pred_grad);
636   masked_confidence_gt_grad.multiply(iou_mask, confidence_gt_grad);
637   masked_class_pred_grad.multiply(bbox_class_mask, class_pred_grad);
638
639   std::vector<nntrainer::Tensor> splitted_bbox_pred_grad =
640     bbox_pred_grad.split({1, 1, 1, 1}, 3);
641   bbox_x_pred_grad.copyData(splitted_bbox_pred_grad[0]);
642   bbox_y_pred_grad.copyData(splitted_bbox_pred_grad[1]);
643   bbox_w_pred_grad.copyData(splitted_bbox_pred_grad[2]);
644   bbox_h_pred_grad.copyData(splitted_bbox_pred_grad[3]);
645
646   // std::vector<nntrainer::Tensor> bbox_pred_iou_grad =
647   //   calc_iou_grad(confidence_gt_grad, bbox1_width, bbox1_height,
648   //   is_xy_min_max,
649   //                 intersection_width, intersection_height, unions);
650   // bbox_x_pred_grad.add_i(bbox_pred_iou_grad[0]);
651   // bbox_y_pred_grad.add_i(bbox_pred_iou_grad[1]);
652   // bbox_w_pred_grad.add_i(bbox_pred_iou_grad[2]);
653   // bbox_h_pred_grad.add_i(bbox_pred_iou_grad[3]);
654
655   /**
656    * @brief calculate gradient for applying anchors to bounding box
657    * @details Let say bbox_pred as x, anchor as c indicated that anchor is
658    * constant for bbox_pred and bbox_pred_anchor as y. Then we can denote y =
659    * sqrt(cx). Partial derivative of y with respect to x will be
660    * sqrt(c)/(2*sqrt(x)) which is equivalent to sqrt(cx)/(2x) and we can replace
661    * sqrt(cx) with y.
662    * @note divide by bbox_pred(x) will not be executed because bbox_pred_grad
663    * will be multiply by bbox_pred(x) soon after.
664    */
665   bbox_w_pred_grad.multiply_i(bbox_w_pred_anchor);
666   bbox_h_pred_grad.multiply_i(bbox_h_pred_anchor);
667   /** intended comment */
668   // bbox_w_pred_grad.divide_i(bbox_w_pred);
669   // bbox_h_pred_grad.divide_i(bbox_h_pred);
670   bbox_w_pred_grad.divide_i(2);
671   bbox_h_pred_grad.divide_i(2);
672
673   sigmoid.run_prime_fn(bbox_x_pred, bbox_x_pred, bbox_x_pred_grad,
674                        bbox_x_pred_grad);
675   sigmoid.run_prime_fn(bbox_y_pred, bbox_y_pred, bbox_y_pred_grad,
676                        bbox_y_pred_grad);
677   /** intended comment */
678   // bbox_w_pred_grad.multiply_i(bbox_w_pred);
679   // bbox_h_pred_grad.multiply_i(bbox_h_pred);
680   sigmoid.run_prime_fn(confidence_pred, confidence_pred, confidence_pred_grad,
681                        confidence_pred_grad);
682   softmax.run_prime_fn(class_pred, class_pred, class_pred_grad,
683                        class_pred_grad);
684
685   nntrainer::Tensor outgoing_derivative_ = nntrainer::Tensor::cat(
686     {bbox_x_pred_grad, bbox_y_pred_grad, bbox_w_pred_grad, bbox_h_pred_grad,
687      confidence_pred_grad, class_pred_grad},
688     3);
689   nntrainer::Tensor &outgoing_derivative =
690     context.getOutgoingDerivative(SINGLE_INOUT_IDX);
691   outgoing_derivative.copyData(outgoing_derivative_);
692 }
693
694 void YoloV2LossLayer::exportTo(nntrainer::Exporter &exporter,
695                                const ml::train::ExportMethods &method) const {
696   exporter.saveResult(yolo_v2_loss_props, method, this);
697 }
698
699 void YoloV2LossLayer::setProperty(const std::vector<std::string> &values) {
700   auto remain_props = loadProperties(values, yolo_v2_loss_props);
701   NNTR_THROW_IF(!remain_props.empty(), std::invalid_argument)
702     << "[YoloV2LossLayer] Unknown Layer Properties count " +
703          std::to_string(values.size());
704 }
705
706 void YoloV2LossLayer::setBatch(nntrainer::RunLayerContext &context,
707                                unsigned int batch) {
708   context.updateTensor(wt_idx[YoloV2LossParams::bbox_x_pred], batch);
709   context.updateTensor(wt_idx[YoloV2LossParams::bbox_y_pred], batch);
710   context.updateTensor(wt_idx[YoloV2LossParams::bbox_w_pred], batch);
711   context.updateTensor(wt_idx[YoloV2LossParams::bbox_h_pred], batch);
712   context.updateTensor(wt_idx[YoloV2LossParams::confidence_pred], batch);
713   context.updateTensor(wt_idx[YoloV2LossParams::class_pred], batch);
714   context.updateTensor(wt_idx[YoloV2LossParams::bbox_w_pred_anchor], batch);
715   context.updateTensor(wt_idx[YoloV2LossParams::bbox_h_pred_anchor], batch);
716
717   context.updateTensor(wt_idx[YoloV2LossParams::bbox_x_gt], batch);
718   context.updateTensor(wt_idx[YoloV2LossParams::bbox_y_gt], batch);
719   context.updateTensor(wt_idx[YoloV2LossParams::bbox_w_gt], batch);
720   context.updateTensor(wt_idx[YoloV2LossParams::bbox_h_gt], batch);
721   context.updateTensor(wt_idx[YoloV2LossParams::confidence_gt], batch);
722   context.updateTensor(wt_idx[YoloV2LossParams::class_gt], batch);
723   context.updateTensor(wt_idx[YoloV2LossParams::bbox_class_mask], batch);
724   context.updateTensor(wt_idx[YoloV2LossParams::iou_mask], batch);
725
726   context.updateTensor(wt_idx[YoloV2LossParams::bbox1_width], batch);
727   context.updateTensor(wt_idx[YoloV2LossParams::bbox1_height], batch);
728   context.updateTensor(wt_idx[YoloV2LossParams::is_xy_min_max], batch);
729   context.updateTensor(wt_idx[YoloV2LossParams::intersection_width], batch);
730   context.updateTensor(wt_idx[YoloV2LossParams::intersection_height], batch);
731   context.updateTensor(wt_idx[YoloV2LossParams::unions], batch);
732 }
733
734 unsigned int YoloV2LossLayer::find_responsible_anchors(float bbox_ratio) {
735   nntrainer::Tensor similarity = anchors_ratio.subtract(bbox_ratio);
736   similarity.apply_i(nntrainer::absFloat);
737   auto data = similarity.getData();
738
739   auto min_iter = std::min_element(data, data + NUM_ANCHOR);
740   return std::distance(data, min_iter);
741 }
742
743 void YoloV2LossLayer::generate_ground_truth(
744   nntrainer::RunLayerContext &context) {
745   const unsigned int max_object_number =
746     std::get<props::MaxObjectNumber>(yolo_v2_loss_props).get();
747   const unsigned int grid_height_number =
748     std::get<props::GridHeightNumber>(yolo_v2_loss_props).get();
749   const unsigned int grid_width_number =
750     std::get<props::GridWidthNumber>(yolo_v2_loss_props).get();
751
752   nntrainer::Tensor &label = context.getLabel(SINGLE_INOUT_IDX);
753
754   nntrainer::Tensor &bbox_x_pred =
755     context.getTensor(wt_idx[YoloV2LossParams::bbox_x_pred]);
756   nntrainer::Tensor &bbox_y_pred =
757     context.getTensor(wt_idx[YoloV2LossParams::bbox_y_pred]);
758   nntrainer::Tensor &bbox_w_pred_anchor =
759     context.getTensor(wt_idx[YoloV2LossParams::bbox_w_pred_anchor]);
760   nntrainer::Tensor &bbox_h_pred_anchor =
761     context.getTensor(wt_idx[YoloV2LossParams::bbox_h_pred_anchor]);
762
763   nntrainer::Tensor &bbox_x_gt =
764     context.getTensor(wt_idx[YoloV2LossParams::bbox_x_gt]);
765   nntrainer::Tensor &bbox_y_gt =
766     context.getTensor(wt_idx[YoloV2LossParams::bbox_y_gt]);
767   nntrainer::Tensor &bbox_w_gt =
768     context.getTensor(wt_idx[YoloV2LossParams::bbox_w_gt]);
769   nntrainer::Tensor &bbox_h_gt =
770     context.getTensor(wt_idx[YoloV2LossParams::bbox_h_gt]);
771
772   nntrainer::Tensor &confidence_gt =
773     context.getTensor(wt_idx[YoloV2LossParams::confidence_gt]);
774   nntrainer::Tensor &class_gt =
775     context.getTensor(wt_idx[YoloV2LossParams::class_gt]);
776
777   nntrainer::Tensor &bbox_class_mask =
778     context.getTensor(wt_idx[YoloV2LossParams::bbox_class_mask]);
779   nntrainer::Tensor &iou_mask =
780     context.getTensor(wt_idx[YoloV2LossParams::iou_mask]);
781
782   nntrainer::Tensor &bbox1_width =
783     context.getTensor(wt_idx[YoloV2LossParams::bbox1_width]);
784   nntrainer::Tensor &bbox1_height =
785     context.getTensor(wt_idx[YoloV2LossParams::bbox1_height]);
786   nntrainer::Tensor &is_xy_min_max =
787     context.getTensor(wt_idx[YoloV2LossParams::is_xy_min_max]);
788   nntrainer::Tensor &intersection_width =
789     context.getTensor(wt_idx[YoloV2LossParams::intersection_width]);
790   nntrainer::Tensor &intersection_height =
791     context.getTensor(wt_idx[YoloV2LossParams::intersection_height]);
792   nntrainer::Tensor &unions =
793     context.getTensor(wt_idx[YoloV2LossParams::unions]);
794
795   const unsigned int batch_size = bbox_x_pred.getDim().batch();
796
797   std::vector<nntrainer::Tensor> splited_label =
798     label.split({1, 1, 1, 1, 1}, 3);
799   nntrainer::Tensor bbox_x_label = splited_label[0];
800   nntrainer::Tensor bbox_y_label = splited_label[1];
801   nntrainer::Tensor bbox_w_label = splited_label[2];
802   nntrainer::Tensor bbox_h_label = splited_label[3];
803   nntrainer::Tensor class_label = splited_label[4];
804
805   bbox_x_label.multiply_i(grid_width_number);
806   bbox_y_label.multiply_i(grid_height_number);
807
808   for (unsigned int batch = 0; batch < batch_size; ++batch) {
809     for (unsigned int object = 0; object < max_object_number; ++object) {
810       if (!bbox_w_label.getValue(batch, 0, object, 0) &&
811           !bbox_h_label.getValue(batch, 0, object, 0)) {
812         break;
813       }
814       unsigned int grid_x_index = bbox_x_label.getValue(batch, 0, object, 0);
815       unsigned int grid_y_index = bbox_y_label.getValue(batch, 0, object, 0);
816       unsigned int grid_index = grid_y_index * grid_width_number + grid_x_index;
817       unsigned int responsible_anchor =
818         find_responsible_anchors(bbox_w_label.getValue(batch, 0, object, 0) /
819                                  bbox_h_label.getValue(batch, 0, object, 0));
820
821       bbox_x_gt.setValue(batch, grid_index, responsible_anchor, 0,
822                          bbox_x_label.getValue(batch, 0, object, 0) -
823                            grid_x_index);
824       bbox_y_gt.setValue(batch, grid_index, responsible_anchor, 0,
825                          bbox_y_label.getValue(batch, 0, object, 0) -
826                            grid_y_index);
827       bbox_w_gt.setValue(
828         batch, grid_index, responsible_anchor, 0,
829         nntrainer::sqrtFloat(bbox_w_label.getValue(batch, 0, object, 0)));
830       bbox_h_gt.setValue(
831         batch, grid_index, responsible_anchor, 0,
832         nntrainer::sqrtFloat(bbox_h_label.getValue(batch, 0, object, 0)));
833
834       class_gt.setValue(batch, grid_index, responsible_anchor,
835                         class_label.getValue(batch, 0, object, 0), 1);
836       bbox_class_mask.setValue(batch, grid_index, responsible_anchor, 0, 1);
837       iou_mask.setValue(batch, grid_index, responsible_anchor, 0, 1);
838     }
839   }
840
841   nntrainer::Tensor iou = calc_iou(
842     bbox_x_pred, bbox_y_pred, bbox_w_pred_anchor, bbox_h_pred_anchor, bbox_x_gt,
843     bbox_y_gt, bbox_w_gt, bbox_h_gt, bbox1_width, bbox1_height, is_xy_min_max,
844     intersection_width, intersection_height, unions);
845   confidence_gt.copyData(iou);
846 }
847
#ifdef PLUGGABLE

/**
 * @brief Allocate a yolo v2 loss layer for the plugin entry point
 *
 * @return nntrainer::Layer* newly allocated layer, to be released with
 * destory_yolo_v2_loss_layer()
 */
nntrainer::Layer *create_yolo_v2_loss_layer() { return new YoloV2LossLayer(); }

/**
 * @brief Release a layer created by create_yolo_v2_loss_layer()
 *
 * @param layer layer to delete
 */
void destory_yolo_v2_loss_layer(nntrainer::Layer *layer) {
  delete layer;
}

/**
 * @note ml_train_layer_pluggable is the entry point nntrainer uses to
 * register a plugin layer; it pairs the create and destroy functions above
 */
extern "C" {
nntrainer::LayerPluggable ml_train_layer_pluggable{
  create_yolo_v2_loss_layer, destory_yolo_v2_loss_layer};
}

#endif
867 } // namespace custom