[Application] implement yolo v2 forward
author hyeonseok lee <hs89.lee@samsung.com>
Wed, 22 Mar 2023 04:44:40 +0000 (13:44 +0900)
committer Jijoong Moon <jijoong.moon@samsung.com>
Tue, 12 Sep 2023 15:21:01 +0000 (00:21 +0900)
 - Implement yolo v2 forward

Signed-off-by: hyeonseok lee <hs89.lee@samsung.com>
Applications/YOLO/jni/yolo_v2_loss.cpp
Applications/YOLO/jni/yolo_v2_loss.h
nntrainer/utils/util_func.cpp
nntrainer/utils/util_func.h

index 875ae60..5366bc7 100644 (file)
@@ -33,7 +33,13 @@ enum YoloV2LossParams {
   confidence_gt,
   class_gt,
   bbox_class_mask,
-  iou_mask
+  iou_mask,
+  bbox1_width,
+  bbox1_height,
+  is_xy_min_max,
+  intersection_width,
+  intersection_height,
+  unions,
 };
 
 namespace props {
@@ -41,29 +47,432 @@ MaxObjectNumber::MaxObjectNumber(const unsigned &value) { set(value); }
 ClassNumber::ClassNumber(const unsigned &value) { set(value); }
 GridHeightNumber::GridHeightNumber(const unsigned &value) { set(value); }
 GridWidthNumber::GridWidthNumber(const unsigned &value) { set(value); }
-ImageHeightSize::ImageHeightSize(const unsigned &value) { set(value); }
-ImageWidthSize::ImageWidthSize(const unsigned &value) { set(value); }
 } // namespace props
 
+/**
+ * @brief mse
+ *
+ * @param pred prediction
+ * @param ground_truth ground truth
+ * @return float loss
+ * @todo make loss behaves like acti_func
+ */
+float mse(nntrainer::Tensor &pred, nntrainer::Tensor &ground_truth) {
+  nntrainer::Tensor residual;
+  pred.subtract(ground_truth, residual);
+
+  float l2norm = residual.l2norm();
+  l2norm *= l2norm / residual.size();
+
+  return l2norm;
+}
+
+/**
+ * @brief backwarding of mse
+ *
+ * @param pred prediction
+ * @param ground_truth ground truth
+ * @param outgoing_derivative outgoing derivative
+ */
+void msePrime(nntrainer::Tensor &pred, nntrainer::Tensor &ground_truth,
+              nntrainer::Tensor &outgoing_derivative) {
+  /** NYI */
+}
+
+/**
+ * @brief calculate iou
+ *
+ * @param bbox1_x1 bbox1_x1
+ * @param bbox1_y1 bbox1_y1
+ * @param bbox1_w bbox1_w
+ * @param bbox1_h bbox1_h
+ * @param bbox2_x1 bbox2_x1
+ * @param bbox2_y1 bbox2_y1
+ * @param bbox2_w bbox2_w
+ * @param bbox2_h bbox2_h
+ * @param[out] bbox1_width bbox1 width
+ * @param[out] bbox1_height bbox1 height
+ * @param[out] is_xy_min_max For x1, y1 this value is 1 if x1 > x2, y1 > y2 and
+ * for x2, y2 this value is 1 if x2 < x1, y2 < y1, else 0.
+ * @param[out] intersection_width intersection width
+ * @param[out] intersection_height intersection height
+ * @param[out] unions unions
+ * @return nntrainer::Tensor iou
+ */
+nntrainer::Tensor
+calc_iou(nntrainer::Tensor &bbox1_x1, nntrainer::Tensor &bbox1_y1,
+         nntrainer::Tensor &bbox1_w, nntrainer::Tensor &bbox1_h,
+         nntrainer::Tensor &bbox2_x1, nntrainer::Tensor &bbox2_y1,
+         nntrainer::Tensor &bbox2_w, nntrainer::Tensor &bbox2_h,
+         nntrainer::Tensor &bbox1_width, nntrainer::Tensor &bbox1_height,
+         nntrainer::Tensor &is_xy_min_max,
+         nntrainer::Tensor &intersection_width,
+         nntrainer::Tensor &intersection_height, nntrainer::Tensor &unions) {
+  nntrainer::Tensor bbox1_x2 = bbox1_x1.add(bbox1_w);
+  nntrainer::Tensor bbox1_y2 = bbox1_y1.add(bbox1_h);
+  nntrainer::Tensor bbox2_x2 = bbox2_x1.add(bbox2_w);
+  nntrainer::Tensor bbox2_y2 = bbox2_y1.add(bbox2_h);
+
+  bbox1_x2.subtract(bbox1_x1, bbox1_width);
+  bbox1_y2.subtract(bbox1_y1, bbox1_height);
+  nntrainer::Tensor bbox1 = bbox1_width.multiply(bbox1_height);
+
+  nntrainer::Tensor bbox2_width = bbox2_x2.subtract(bbox2_x1);
+  nntrainer::Tensor bbox2_height = bbox2_y2.subtract(bbox2_y1);
+  nntrainer::Tensor bbox2 = bbox2_width.multiply(bbox2_height);
+
+  auto min_func = [&](nntrainer::Tensor &bbox1_xy, nntrainer::Tensor &bbox2_xy,
+                      nntrainer::Tensor &intersection_xy) {
+    std::transform(bbox1_xy.getData(), bbox1_xy.getData() + bbox1_xy.size(),
+                   bbox2_xy.getData(), intersection_xy.getData(),
+                   [](float x1, float x2) { return std::min(x1, x2); });
+  };
+  auto max_func = [&](nntrainer::Tensor &bbox1_xy, nntrainer::Tensor &bbox2_xy,
+                      nntrainer::Tensor &intersection_xy) {
+    std::transform(bbox1_xy.getData(), bbox1_xy.getData() + bbox1_xy.size(),
+                   bbox2_xy.getData(), intersection_xy.getData(),
+                   [](float x1, float x2) { return std::max(x1, x2); });
+  };
+
+  nntrainer::Tensor intersection_x1(bbox1_x1.getDim());
+  nntrainer::Tensor intersection_x2(bbox1_x1.getDim());
+  nntrainer::Tensor intersection_y1(bbox1_y1.getDim());
+  nntrainer::Tensor intersection_y2(bbox1_y1.getDim());
+  max_func(bbox1_x1, bbox2_x1, intersection_x1);
+  min_func(bbox1_x2, bbox2_x2, intersection_x2);
+  max_func(bbox1_y1, bbox2_y1, intersection_y1);
+  min_func(bbox1_y2, bbox2_y2, intersection_y2);
+
+  auto is_min_max_func = [&](nntrainer::Tensor &xy,
+                             nntrainer::Tensor &intersection,
+                             nntrainer::Tensor &is_min_max) {
+    std::transform(xy.getData(), xy.getData() + xy.size(),
+                   intersection.getData(), is_min_max.getData(),
+                   [](float x, float m) {
+                     return nntrainer::absFloat(x - m) < 1e-4 ? 1.0 : 0.0;
+                   });
+  };
+
+  nntrainer::Tensor is_bbox1_x1_max(bbox1_x1.getDim());
+  nntrainer::Tensor is_bbox1_y1_max(bbox1_x1.getDim());
+  nntrainer::Tensor is_bbox1_x2_min(bbox1_x1.getDim());
+  nntrainer::Tensor is_bbox1_y2_min(bbox1_x1.getDim());
+  is_min_max_func(bbox1_x1, intersection_x1, is_bbox1_x1_max);
+  is_min_max_func(bbox1_y1, intersection_y1, is_bbox1_y1_max);
+  is_min_max_func(bbox1_x2, intersection_x2, is_bbox1_x2_min);
+  is_min_max_func(bbox1_y2, intersection_y2, is_bbox1_y2_min);
+
+  nntrainer::Tensor is_bbox_min_max = nntrainer::Tensor::cat(
+    {is_bbox1_x1_max, is_bbox1_y1_max, is_bbox1_x2_min, is_bbox1_y2_min}, 3);
+  is_xy_min_max.copyData(is_bbox_min_max);
+
+  intersection_x2.subtract(intersection_x1, intersection_width);
+  intersection_width.apply_i(nntrainer::ActiFunc::relu);
+  intersection_y2.subtract(intersection_y1, intersection_height);
+  intersection_height.apply_i(nntrainer::ActiFunc::relu);
+
+  nntrainer::Tensor intersection =
+    intersection_width.multiply(intersection_height);
+  bbox1.add(bbox2, unions);
+  unions.subtract_i(intersection);
+
+  return intersection.divide(unions);
+}
+
+/**
+ * @brief calculate iou gradient
+ * @details Let say bbox_pred as x, intersection as f(x), union as g(x) and iou
+ * as y. Then y = f(x)/g(x). Also g(x) = bbox1 + bbox2 - f(x). Partial
+ * derivative of y with respect to x will be (f'(x)g(x) - f(x)g'(x))/(g(x)^2).
+ * Partial derivative of g(x) with respect to x will be bbox1'(x) - f'(x).
+ * @param confidence_gt_grad incoming derivative for iou
+ * @param bbox1_width bbox1_width
+ * @param bbox1_height bbox1_height
+ * @param is_xy_min_max For x1, y1 this value is 1 if x1 > x2, y1 > y2 and for
+ * x2, y2 this value is 1 if x2 < x1, y2 < y1, else 0.
+ * @param intersection_width intersection width
+ * @param intersection_height intersection height
+ * @param unions unions
+ * @return std::vector<nntrainer::Tensor> iou_grad
+ */
+std::vector<nntrainer::Tensor> calc_iou_grad(
+  nntrainer::Tensor &confidence_gt_grad, nntrainer::Tensor &bbox1_width,
+  nntrainer::Tensor &bbox1_height, nntrainer::Tensor &is_xy_min_max,
+  nntrainer::Tensor &intersection_width, nntrainer::Tensor &intersection_height,
+  nntrainer::Tensor &unions) {
+  /** NYI */
+  return {nntrainer::Tensor()};
+}
+
 YoloV2LossLayer::YoloV2LossLayer() :
   anchors_w({1, 1, NUM_ANCHOR, 1}, anchors_w_buf),
   anchors_h({1, 1, NUM_ANCHOR, 1}, anchors_h_buf),
-  sigmoid(nntrainer::ActivationType::ACT_SIGMOID, false),
-  softmax(nntrainer::ActivationType::ACT_SOFTMAX, false),
+  sigmoid(nntrainer::ActivationType::ACT_SIGMOID, true),
+  softmax(nntrainer::ActivationType::ACT_SOFTMAX, true),
   yolo_v2_loss_props(props::MaxObjectNumber(), props::ClassNumber(),
-                     props::GridHeightNumber(), props::GridWidthNumber(),
-                     props::ImageHeightSize(), props::ImageWidthSize()) {
+                     props::GridHeightNumber(), props::GridWidthNumber()) {
   anchors_ratio = anchors_w.divide(anchors_h);
   wt_idx.fill(std::numeric_limits<unsigned>::max());
 }
 
 void YoloV2LossLayer::finalize(nntrainer::InitLayerContext &context) {
-  /** NYI */
+  nntrainer::TensorDim input_dim =
+    context.getInputDimensions()[SINGLE_INOUT_IDX];
+  const unsigned int batch_size = input_dim.batch();
+  const unsigned int class_number =
+    std::get<props::ClassNumber>(yolo_v2_loss_props).get();
+  const unsigned int grid_height_number =
+    std::get<props::GridHeightNumber>(yolo_v2_loss_props).get();
+  const unsigned int grid_width_number =
+    std::get<props::GridWidthNumber>(yolo_v2_loss_props).get();
+  const unsigned int max_object_number =
+    std::get<props::MaxObjectNumber>(yolo_v2_loss_props).get();
+  nntrainer::TensorDim label_dim(batch_size, 1, max_object_number, 5);
+  context.setOutputDimensions({label_dim});
+
+  nntrainer::TensorDim bbox_x_pred_dim(
+    batch_size, grid_height_number * grid_width_number, NUM_ANCHOR, 1);
+  wt_idx[YoloV2LossParams::bbox_x_pred] = context.requestTensor(
+    bbox_x_pred_dim, "bbox_x_pred", nntrainer::Tensor::Initializer::NONE, true,
+    nntrainer::TensorLifespan::FORWARD_DERIV_LIFESPAN);
+
+  nntrainer::TensorDim bbox_y_pred_dim(
+    batch_size, grid_height_number * grid_width_number, NUM_ANCHOR, 1);
+  wt_idx[YoloV2LossParams::bbox_y_pred] = context.requestTensor(
+    bbox_y_pred_dim, "bbox_y_pred", nntrainer::Tensor::Initializer::NONE, true,
+    nntrainer::TensorLifespan::FORWARD_DERIV_LIFESPAN);
+
+  nntrainer::TensorDim bbox_w_pred_dim(
+    batch_size, grid_height_number * grid_width_number, NUM_ANCHOR, 1);
+  wt_idx[YoloV2LossParams::bbox_w_pred] = context.requestTensor(
+    bbox_w_pred_dim, "bbox_w_pred", nntrainer::Tensor::Initializer::NONE, true,
+    nntrainer::TensorLifespan::FORWARD_DERIV_LIFESPAN);
+
+  nntrainer::TensorDim bbox_h_pred_dim(
+    batch_size, grid_height_number * grid_width_number, NUM_ANCHOR, 1);
+  wt_idx[YoloV2LossParams::bbox_h_pred] = context.requestTensor(
+    bbox_h_pred_dim, "bbox_h_pred", nntrainer::Tensor::Initializer::NONE, true,
+    nntrainer::TensorLifespan::FORWARD_DERIV_LIFESPAN);
+
+  nntrainer::TensorDim confidence_pred_dim(
+    batch_size, grid_height_number * grid_width_number, NUM_ANCHOR, 1);
+  wt_idx[YoloV2LossParams::confidence_pred] =
+    context.requestTensor(confidence_pred_dim, "confidence_pred",
+                          nntrainer::Tensor::Initializer::NONE, true,
+                          nntrainer::TensorLifespan::FORWARD_DERIV_LIFESPAN);
+
+  nntrainer::TensorDim class_pred_dim(batch_size,
+                                      grid_height_number * grid_width_number,
+                                      NUM_ANCHOR, class_number);
+  wt_idx[YoloV2LossParams::class_pred] = context.requestTensor(
+    class_pred_dim, "class_pred", nntrainer::Tensor::Initializer::NONE, true,
+    nntrainer::TensorLifespan::FORWARD_DERIV_LIFESPAN);
+
+  nntrainer::TensorDim bbox_w_pred_anchor_dim(
+    batch_size, grid_height_number * grid_width_number, NUM_ANCHOR, 1);
+  wt_idx[YoloV2LossParams::bbox_w_pred_anchor] =
+    context.requestTensor(bbox_w_pred_anchor_dim, "bbox_w_pred_anchor",
+                          nntrainer::Tensor::Initializer::NONE, false,
+                          nntrainer::TensorLifespan::FORWARD_DERIV_LIFESPAN);
+
+  nntrainer::TensorDim bbox_h_pred_anchor_dim(
+    batch_size, grid_height_number * grid_width_number, NUM_ANCHOR, 1);
+  wt_idx[YoloV2LossParams::bbox_h_pred_anchor] =
+    context.requestTensor(bbox_h_pred_anchor_dim, "bbox_h_pred_anchor",
+                          nntrainer::Tensor::Initializer::NONE, false,
+                          nntrainer::TensorLifespan::FORWARD_DERIV_LIFESPAN);
+
+  nntrainer::TensorDim bbox_x_gt_dim(
+    batch_size, grid_height_number * grid_width_number, NUM_ANCHOR, 1);
+  wt_idx[YoloV2LossParams::bbox_x_gt] = context.requestTensor(
+    bbox_x_gt_dim, "bbox_x_gt", nntrainer::Tensor::Initializer::NONE, false,
+    nntrainer::TensorLifespan::FORWARD_DERIV_LIFESPAN);
+
+  nntrainer::TensorDim bbox_y_gt_dim(
+    batch_size, grid_height_number * grid_width_number, NUM_ANCHOR, 1);
+  wt_idx[YoloV2LossParams::bbox_y_gt] = context.requestTensor(
+    bbox_y_gt_dim, "bbox_y_gt", nntrainer::Tensor::Initializer::NONE, false,
+    nntrainer::TensorLifespan::FORWARD_DERIV_LIFESPAN);
+
+  nntrainer::TensorDim bbox_w_gt_dim(
+    batch_size, grid_height_number * grid_width_number, NUM_ANCHOR, 1);
+  wt_idx[YoloV2LossParams::bbox_w_gt] = context.requestTensor(
+    bbox_w_gt_dim, "bbox_w_gt", nntrainer::Tensor::Initializer::NONE, false,
+    nntrainer::TensorLifespan::FORWARD_DERIV_LIFESPAN);
+
+  nntrainer::TensorDim bbox_h_gt_dim(
+    batch_size, grid_height_number * grid_width_number, NUM_ANCHOR, 1);
+  wt_idx[YoloV2LossParams::bbox_h_gt] = context.requestTensor(
+    bbox_h_gt_dim, "bbox_h_gt", nntrainer::Tensor::Initializer::NONE, false,
+    nntrainer::TensorLifespan::FORWARD_DERIV_LIFESPAN);
+
+  nntrainer::TensorDim confidence_gt_dim(
+    batch_size, grid_height_number * grid_width_number, NUM_ANCHOR, 1);
+  wt_idx[YoloV2LossParams::confidence_gt] = context.requestTensor(
+    confidence_gt_dim, "confidence_gt", nntrainer::Tensor::Initializer::NONE,
+    false, nntrainer::TensorLifespan::FORWARD_DERIV_LIFESPAN);
+
+  nntrainer::TensorDim class_gt_dim(batch_size,
+                                    grid_height_number * grid_width_number,
+                                    NUM_ANCHOR, class_number);
+  wt_idx[YoloV2LossParams::class_gt] = context.requestTensor(
+    class_gt_dim, "class_gt", nntrainer::Tensor::Initializer::NONE, false,
+    nntrainer::TensorLifespan::FORWARD_DERIV_LIFESPAN);
+
+  nntrainer::TensorDim bbox_class_mask_dim(
+    batch_size, grid_height_number * grid_width_number, NUM_ANCHOR, 1);
+  wt_idx[YoloV2LossParams::bbox_class_mask] =
+    context.requestTensor(bbox_class_mask_dim, "bbox_class_mask",
+                          nntrainer::Tensor::Initializer::NONE, false,
+                          nntrainer::TensorLifespan::FORWARD_DERIV_LIFESPAN);
+
+  nntrainer::TensorDim iou_mask_dim(
+    batch_size, grid_height_number * grid_width_number, NUM_ANCHOR, 1);
+  wt_idx[YoloV2LossParams::iou_mask] = context.requestTensor(
+    iou_mask_dim, "iou_mask", nntrainer::Tensor::Initializer::NONE, false,
+    nntrainer::TensorLifespan::FORWARD_DERIV_LIFESPAN);
+
+  nntrainer::TensorDim bbox1_width_dim(
+    batch_size, grid_height_number * grid_width_number, NUM_ANCHOR, 1);
+  wt_idx[YoloV2LossParams::bbox1_width] = context.requestTensor(
+    bbox1_width_dim, "bbox1_width", nntrainer::Tensor::Initializer::NONE, false,
+    nntrainer::TensorLifespan::FORWARD_DERIV_LIFESPAN);
+
+  nntrainer::TensorDim bbox1_height_dim(
+    batch_size, grid_height_number * grid_width_number, NUM_ANCHOR, 1);
+  wt_idx[YoloV2LossParams::bbox1_height] = context.requestTensor(
+    bbox1_height_dim, "bbox1_height", nntrainer::Tensor::Initializer::NONE,
+    false, nntrainer::TensorLifespan::FORWARD_DERIV_LIFESPAN);
+
+  nntrainer::TensorDim is_xy_min_max_dim(
+    batch_size, grid_height_number * grid_width_number, NUM_ANCHOR, 4);
+  wt_idx[YoloV2LossParams::is_xy_min_max] = context.requestTensor(
+    is_xy_min_max_dim, "is_xy_min_max", nntrainer::Tensor::Initializer::NONE,
+    false, nntrainer::TensorLifespan::FORWARD_DERIV_LIFESPAN);
+
+  nntrainer::TensorDim intersection_width_dim(
+    batch_size, grid_height_number * grid_width_number, NUM_ANCHOR, 1);
+  wt_idx[YoloV2LossParams::intersection_width] =
+    context.requestTensor(intersection_width_dim, "intersection_width",
+                          nntrainer::Tensor::Initializer::NONE, false,
+                          nntrainer::TensorLifespan::FORWARD_DERIV_LIFESPAN);
+
+  nntrainer::TensorDim intersection_height_dim(
+    batch_size, grid_height_number * grid_width_number, NUM_ANCHOR, 1);
+  wt_idx[YoloV2LossParams::intersection_height] =
+    context.requestTensor(intersection_height_dim, "intersection_height",
+                          nntrainer::Tensor::Initializer::NONE, false,
+                          nntrainer::TensorLifespan::FORWARD_DERIV_LIFESPAN);
+
+  nntrainer::TensorDim unions_dim(
+    batch_size, grid_height_number * grid_width_number, NUM_ANCHOR, 1);
+  wt_idx[YoloV2LossParams::unions] = context.requestTensor(
+    unions_dim, "unions", nntrainer::Tensor::Initializer::NONE, false,
+    nntrainer::TensorLifespan::FORWARD_DERIV_LIFESPAN);
 }
 
 void YoloV2LossLayer::forwarding(nntrainer::RunLayerContext &context,
                                  bool training) {
-  /** NYI */
+  const unsigned int max_object_number =
+    std::get<props::MaxObjectNumber>(yolo_v2_loss_props).get();
+
+  nntrainer::Tensor &input = context.getInput(SINGLE_INOUT_IDX);
+
+  std::vector<nntrainer::Tensor> splited_input =
+    input.split({1, 1, 1, 1, 1, max_object_number}, 3);
+  nntrainer::Tensor bbox_x_pred_ = splited_input[0];
+  nntrainer::Tensor bbox_y_pred_ = splited_input[1];
+  nntrainer::Tensor bbox_w_pred_ = splited_input[2];
+  nntrainer::Tensor bbox_h_pred_ = splited_input[3];
+  nntrainer::Tensor confidence_pred_ = splited_input[4];
+  nntrainer::Tensor class_pred_ = splited_input[5];
+
+  nntrainer::Tensor &bbox_x_pred =
+    context.getTensor(wt_idx[YoloV2LossParams::bbox_x_pred]);
+  nntrainer::Tensor &bbox_y_pred =
+    context.getTensor(wt_idx[YoloV2LossParams::bbox_y_pred]);
+  nntrainer::Tensor &bbox_w_pred =
+    context.getTensor(wt_idx[YoloV2LossParams::bbox_w_pred]);
+  nntrainer::Tensor &bbox_h_pred =
+    context.getTensor(wt_idx[YoloV2LossParams::bbox_h_pred]);
+
+  nntrainer::Tensor &confidence_pred =
+    context.getTensor(wt_idx[YoloV2LossParams::confidence_pred]);
+  nntrainer::Tensor &class_pred =
+    context.getTensor(wt_idx[YoloV2LossParams::class_pred]);
+
+  nntrainer::Tensor &bbox_w_pred_anchor =
+    context.getTensor(wt_idx[YoloV2LossParams::bbox_w_pred_anchor]);
+  nntrainer::Tensor &bbox_h_pred_anchor =
+    context.getTensor(wt_idx[YoloV2LossParams::bbox_h_pred_anchor]);
+
+  bbox_x_pred.copyData(bbox_x_pred_);
+  bbox_y_pred.copyData(bbox_y_pred_);
+  bbox_w_pred.copyData(bbox_w_pred_);
+  bbox_h_pred.copyData(bbox_h_pred_);
+
+  confidence_pred.copyData(confidence_pred_);
+  class_pred.copyData(class_pred_);
+
+  nntrainer::Tensor &bbox_x_gt =
+    context.getTensor(wt_idx[YoloV2LossParams::bbox_x_gt]);
+  nntrainer::Tensor &bbox_y_gt =
+    context.getTensor(wt_idx[YoloV2LossParams::bbox_y_gt]);
+  nntrainer::Tensor &bbox_w_gt =
+    context.getTensor(wt_idx[YoloV2LossParams::bbox_w_gt]);
+  nntrainer::Tensor &bbox_h_gt =
+    context.getTensor(wt_idx[YoloV2LossParams::bbox_h_gt]);
+
+  nntrainer::Tensor &confidence_gt =
+    context.getTensor(wt_idx[YoloV2LossParams::confidence_gt]);
+  nntrainer::Tensor &class_gt =
+    context.getTensor(wt_idx[YoloV2LossParams::class_gt]);
+
+  nntrainer::Tensor &bbox_class_mask =
+    context.getTensor(wt_idx[YoloV2LossParams::bbox_class_mask]);
+  nntrainer::Tensor &iou_mask =
+    context.getTensor(wt_idx[YoloV2LossParams::iou_mask]);
+
+  // init mask
+  bbox_class_mask.setValue(0);
+  iou_mask.setValue(0.5);
+
+  // activate pred
+  sigmoid.run_fn(bbox_x_pred, bbox_x_pred);
+  sigmoid.run_fn(bbox_y_pred, bbox_y_pred);
+  bbox_w_pred.apply_i(nntrainer::exp_util);
+  bbox_h_pred.apply_i(nntrainer::exp_util);
+  sigmoid.run_fn(confidence_pred, confidence_pred);
+  softmax.run_fn(class_pred, class_pred);
+
+  bbox_w_pred_anchor.copyData(bbox_w_pred);
+  bbox_h_pred_anchor.copyData(bbox_h_pred);
+
+  // apply anchors to bounding box
+  bbox_w_pred_anchor.multiply_i(anchors_w);
+  bbox_h_pred_anchor.multiply_i(anchors_h);
+  bbox_w_pred_anchor.apply_i(nntrainer::sqrtFloat);
+  bbox_h_pred_anchor.apply_i(nntrainer::sqrtFloat);
+
+  generate_ground_truth(context);
+
+  nntrainer::Tensor bbox_pred = nntrainer::Tensor::cat(
+    {bbox_x_pred, bbox_y_pred, bbox_w_pred_anchor, bbox_h_pred_anchor}, 3);
+  nntrainer::Tensor masked_bbox_pred = bbox_pred.multiply(bbox_class_mask);
+  nntrainer::Tensor masked_confidence_pred = confidence_pred.multiply(iou_mask);
+  nntrainer::Tensor masked_class_pred = class_pred.multiply(bbox_class_mask);
+
+  nntrainer::Tensor bbox_gt =
+    nntrainer::Tensor::cat({bbox_x_gt, bbox_y_gt, bbox_w_gt, bbox_h_gt}, 3);
+  nntrainer::Tensor masked_bbox_gt = bbox_gt.multiply(bbox_class_mask);
+  nntrainer::Tensor masked_confidence_gt = confidence_gt.multiply(iou_mask);
+  nntrainer::Tensor masked_class_gt = class_gt.multiply(bbox_class_mask);
+
+  float bbox_loss = mse(masked_bbox_pred, masked_bbox_gt);
+  float confidence_loss = mse(masked_confidence_pred, masked_confidence_gt);
+  float class_loss = mse(masked_class_pred, masked_class_gt);
+
+  float loss = 5 * bbox_loss + confidence_loss + class_loss;
 }
 
 void YoloV2LossLayer::calcDerivative(nntrainer::RunLayerContext &context) {
@@ -72,27 +481,156 @@ void YoloV2LossLayer::calcDerivative(nntrainer::RunLayerContext &context) {
 
 void YoloV2LossLayer::exportTo(nntrainer::Exporter &exporter,
                                const ml::train::ExportMethods &method) const {
-  /** NYI */
+  exporter.saveResult(yolo_v2_loss_props, method, this);
 }
 
 void YoloV2LossLayer::setProperty(const std::vector<std::string> &values) {
-  /** NYI */
+  auto remain_props = loadProperties(values, yolo_v2_loss_props);
+  NNTR_THROW_IF(!remain_props.empty(), std::invalid_argument)
+    << "[YoloV2LossLayer] Unknown Layer Properties count " +
+         std::to_string(values.size());
+}
+
+void YoloV2LossLayer::setBatch(nntrainer::RunLayerContext &context,
+                               unsigned int batch) {
+  context.updateTensor(wt_idx[YoloV2LossParams::bbox_x_pred], batch);
+  context.updateTensor(wt_idx[YoloV2LossParams::bbox_y_pred], batch);
+  context.updateTensor(wt_idx[YoloV2LossParams::bbox_w_pred], batch);
+  context.updateTensor(wt_idx[YoloV2LossParams::bbox_h_pred], batch);
+  context.updateTensor(wt_idx[YoloV2LossParams::confidence_pred], batch);
+  context.updateTensor(wt_idx[YoloV2LossParams::class_pred], batch);
+  context.updateTensor(wt_idx[YoloV2LossParams::bbox_w_pred_anchor], batch);
+  context.updateTensor(wt_idx[YoloV2LossParams::bbox_h_pred_anchor], batch);
+
+  context.updateTensor(wt_idx[YoloV2LossParams::bbox_x_gt], batch);
+  context.updateTensor(wt_idx[YoloV2LossParams::bbox_y_gt], batch);
+  context.updateTensor(wt_idx[YoloV2LossParams::bbox_w_gt], batch);
+  context.updateTensor(wt_idx[YoloV2LossParams::bbox_h_gt], batch);
+  context.updateTensor(wt_idx[YoloV2LossParams::confidence_gt], batch);
+  context.updateTensor(wt_idx[YoloV2LossParams::class_gt], batch);
+  context.updateTensor(wt_idx[YoloV2LossParams::bbox_class_mask], batch);
+  context.updateTensor(wt_idx[YoloV2LossParams::iou_mask], batch);
+
+  context.updateTensor(wt_idx[YoloV2LossParams::bbox1_width], batch);
+  context.updateTensor(wt_idx[YoloV2LossParams::bbox1_height], batch);
+  context.updateTensor(wt_idx[YoloV2LossParams::is_xy_min_max], batch);
+  context.updateTensor(wt_idx[YoloV2LossParams::intersection_width], batch);
+  context.updateTensor(wt_idx[YoloV2LossParams::intersection_height], batch);
+  context.updateTensor(wt_idx[YoloV2LossParams::unions], batch);
 }
 
 unsigned int YoloV2LossLayer::find_responsible_anchors(float bbox_ratio) {
-  /** NYI */
-  return 0;
+  nntrainer::Tensor similarity = anchors_ratio.subtract(bbox_ratio);
+  similarity.apply_i(nntrainer::absFloat);
+  auto data = similarity.getData();
+
+  auto min_iter = std::min_element(data, data + NUM_ANCHOR);
+  return std::distance(data, min_iter);
 }
 
 void YoloV2LossLayer::generate_ground_truth(
-  nntrainer::Tensor &bbox_x_pred, nntrainer::Tensor &bbox_y_pred,
-  nntrainer::Tensor &bbox_w_pred, nntrainer::Tensor &bbox_h_pred,
-  nntrainer::Tensor &labels, nntrainer::Tensor &bbox_x_gt,
-  nntrainer::Tensor &bbox_y_gt, nntrainer::Tensor &bbox_w_gt,
-  nntrainer::Tensor &bbox_h_gt, nntrainer::Tensor &confidence_gt,
-  nntrainer::Tensor &class_gt, nntrainer::Tensor &bbox_class_mask,
-  nntrainer::Tensor &iou_mask) {
-  /** NYI */
+  nntrainer::RunLayerContext &context) {
+  const unsigned int max_object_number =
+    std::get<props::MaxObjectNumber>(yolo_v2_loss_props).get();
+  const unsigned int grid_height_number =
+    std::get<props::GridHeightNumber>(yolo_v2_loss_props).get();
+  const unsigned int grid_width_number =
+    std::get<props::GridWidthNumber>(yolo_v2_loss_props).get();
+
+  nntrainer::Tensor &label = context.getLabel(SINGLE_INOUT_IDX);
+
+  nntrainer::Tensor &bbox_x_pred =
+    context.getTensor(wt_idx[YoloV2LossParams::bbox_x_pred]);
+  nntrainer::Tensor &bbox_y_pred =
+    context.getTensor(wt_idx[YoloV2LossParams::bbox_y_pred]);
+  nntrainer::Tensor &bbox_w_pred_anchor =
+    context.getTensor(wt_idx[YoloV2LossParams::bbox_w_pred_anchor]);
+  nntrainer::Tensor &bbox_h_pred_anchor =
+    context.getTensor(wt_idx[YoloV2LossParams::bbox_h_pred_anchor]);
+
+  nntrainer::Tensor &bbox_x_gt =
+    context.getTensor(wt_idx[YoloV2LossParams::bbox_x_gt]);
+  nntrainer::Tensor &bbox_y_gt =
+    context.getTensor(wt_idx[YoloV2LossParams::bbox_y_gt]);
+  nntrainer::Tensor &bbox_w_gt =
+    context.getTensor(wt_idx[YoloV2LossParams::bbox_w_gt]);
+  nntrainer::Tensor &bbox_h_gt =
+    context.getTensor(wt_idx[YoloV2LossParams::bbox_h_gt]);
+
+  nntrainer::Tensor &confidence_gt =
+    context.getTensor(wt_idx[YoloV2LossParams::confidence_gt]);
+  nntrainer::Tensor &class_gt =
+    context.getTensor(wt_idx[YoloV2LossParams::class_gt]);
+
+  nntrainer::Tensor &bbox_class_mask =
+    context.getTensor(wt_idx[YoloV2LossParams::bbox_class_mask]);
+  nntrainer::Tensor &iou_mask =
+    context.getTensor(wt_idx[YoloV2LossParams::iou_mask]);
+
+  nntrainer::Tensor &bbox1_width =
+    context.getTensor(wt_idx[YoloV2LossParams::bbox1_width]);
+  nntrainer::Tensor &bbox1_height =
+    context.getTensor(wt_idx[YoloV2LossParams::bbox1_height]);
+  nntrainer::Tensor &is_xy_min_max =
+    context.getTensor(wt_idx[YoloV2LossParams::is_xy_min_max]);
+  nntrainer::Tensor &intersection_width =
+    context.getTensor(wt_idx[YoloV2LossParams::intersection_width]);
+  nntrainer::Tensor &intersection_height =
+    context.getTensor(wt_idx[YoloV2LossParams::intersection_height]);
+  nntrainer::Tensor &unions =
+    context.getTensor(wt_idx[YoloV2LossParams::unions]);
+
+  const unsigned int batch_size = bbox_x_pred.getDim().batch();
+
+  std::vector<nntrainer::Tensor> splited_label =
+    label.split({1, 1, 1, 1, 1}, 3);
+  nntrainer::Tensor bbox_x_label = splited_label[0];
+  nntrainer::Tensor bbox_y_label = splited_label[1];
+  nntrainer::Tensor bbox_w_label = splited_label[2];
+  nntrainer::Tensor bbox_h_label = splited_label[3];
+  nntrainer::Tensor class_label = splited_label[4];
+
+  bbox_x_label.multiply_i(grid_width_number);
+  bbox_y_label.multiply_i(grid_height_number);
+
+  for (unsigned int batch = 0; batch < batch_size; ++batch) {
+    for (unsigned int object = 0; object < max_object_number; ++object) {
+      if (!bbox_w_label.getValue(batch, 0, object, 0) &&
+          !bbox_h_label.getValue(batch, 0, object, 0)) {
+        break;
+      }
+      unsigned int grid_x_index = bbox_x_label.getValue(batch, 0, object, 0);
+      unsigned int grid_y_index = bbox_y_label.getValue(batch, 0, object, 0);
+      unsigned int grid_index = grid_y_index * grid_width_number + grid_x_index;
+      unsigned int responsible_anchor =
+        find_responsible_anchors(bbox_w_label.getValue(batch, 0, object, 0) /
+                                 bbox_h_label.getValue(batch, 0, object, 0));
+
+      bbox_x_gt.setValue(batch, grid_index, responsible_anchor, 0,
+                         bbox_x_label.getValue(batch, 0, object, 0) -
+                           grid_x_index);
+      bbox_y_gt.setValue(batch, grid_index, responsible_anchor, 0,
+                         bbox_y_label.getValue(batch, 0, object, 0) -
+                           grid_y_index);
+      bbox_w_gt.setValue(
+        batch, grid_index, responsible_anchor, 0,
+        nntrainer::sqrtFloat(bbox_w_label.getValue(batch, 0, object, 0)));
+      bbox_h_gt.setValue(
+        batch, grid_index, responsible_anchor, 0,
+        nntrainer::sqrtFloat(bbox_h_label.getValue(batch, 0, object, 0)));
+
+      class_gt.setValue(batch, grid_index, responsible_anchor,
+                        class_label.getValue(batch, 0, object, 0), 1);
+      bbox_class_mask.setValue(batch, grid_index, responsible_anchor, 0, 1);
+      iou_mask.setValue(batch, grid_index, responsible_anchor, 0, 1);
+    }
+  }
+
+  nntrainer::Tensor iou = calc_iou(
+    bbox_x_pred, bbox_y_pred, bbox_w_pred_anchor, bbox_h_pred_anchor, bbox_x_gt,
+    bbox_y_gt, bbox_w_gt, bbox_h_gt, bbox1_width, bbox1_height, is_xy_min_max,
+    intersection_width, intersection_height, unions);
+  confidence_gt.copyData(iou);
 }
 
 #ifdef PLUGGABLE
index 11237ea..4dde915 100644 (file)
@@ -69,28 +69,6 @@ public:
   using prop_tag = nntrainer::uint_prop_tag;
 };
 
-/**
- * @brief image height size
- *
- */
-class ImageHeightSize final : public nntrainer::PositiveIntegerProperty {
-public:
-  ImageHeightSize(const unsigned &value = 1);
-  static constexpr const char *key = "image_height_size";
-  using prop_tag = nntrainer::uint_prop_tag;
-};
-
-/**
- * @brief image width size
- *
- */
-class ImageWidthSize final : public nntrainer::PositiveIntegerProperty {
-public:
-  ImageWidthSize(const unsigned &value = 1);
-  static constexpr const char *key = "image_width_size";
-  using prop_tag = nntrainer::uint_prop_tag;
-};
-
 } // namespace props
 
 /**
@@ -138,6 +116,12 @@ public:
   void setProperty(const std::vector<std::string> &values) override;
 
   /**
+   * @copydoc Layer::setBatch(RunLayerContext &context, unsigned int batch)
+   */
+  void setBatch(nntrainer::RunLayerContext &context,
+                unsigned int batch) override;
+
+  /**
    * @copydoc bool supportBackwarding() const
    */
   bool supportBackwarding() const override { return true; };
@@ -168,8 +152,7 @@ private:
   nntrainer::ActiFunc softmax; /** softmax activation operation */
 
   std::tuple<props::MaxObjectNumber, props::ClassNumber,
-             props::GridHeightNumber, props::GridWidthNumber,
-             props::ImageHeightSize, props::ImageWidthSize>
+             props::GridHeightNumber, props::GridWidthNumber>
     yolo_v2_loss_props;
   std::array<unsigned int, 8> wt_idx; /**< indices of the weights */
 
@@ -181,14 +164,7 @@ private:
   /**
    * @brief generate ground truth, mask from labels
    */
-  void generate_ground_truth(
-    nntrainer::Tensor &bbox_x_pred, nntrainer::Tensor &bbox_y_pred,
-    nntrainer::Tensor &bbox_w_pred, nntrainer::Tensor &bbox_h_pred,
-    nntrainer::Tensor &labels, nntrainer::Tensor &bbox_x_gt,
-    nntrainer::Tensor &bbox_y_gt, nntrainer::Tensor &bbox_w_gt,
-    nntrainer::Tensor &bbox_h_gt, nntrainer::Tensor &confidence_gt,
-    nntrainer::Tensor &class_gt, nntrainer::Tensor &bbox_class_mask,
-    nntrainer::Tensor &iou_mask);
+  void generate_ground_truth(nntrainer::RunLayerContext &context);
 };
 
 } // namespace custom
index 2198933..acc3308 100644 (file)
@@ -28,6 +28,7 @@
 #include <fstream>
 #include <random>
 
+#include <acti_func.h>
 #include <nntrainer_log.h>
 #include <util_func.h>
 
@@ -39,6 +40,8 @@ float sqrtFloat(float x) { return sqrt(x); };
 
 double sqrtDouble(double x) { return sqrt(x); };
 
/**
 * @brief abs function for float type
 * @param[in] x float
 * @return absolute value of x
 * @note must use fabsf (or std::fabs): plain ::abs is the C `int abs(int)`
 *       overload from <cstdlib>, which would truncate the float argument.
 */
float absFloat(float x) { return fabsf(x); };
+
 float logFloat(float x) { return log(x + 1.0e-20); }
 
 float exp_util(float x) { return exp(x); }
@@ -60,12 +63,6 @@ Tensor rotate_180(Tensor in) {
   return output;
 }
 
-Tensor calculateIOU(Tensor &b1_x1, Tensor &b1_y1, Tensor &b1_w, Tensor &b1_h,
-                    Tensor &b2_x1, Tensor &b2_y1, Tensor &b2_w, Tensor &b2_h) {
-  /** NYI */
-  return Tensor();
-}
-
 bool isFileExist(std::string file_name) {
   std::ifstream infile(file_name);
   return infile.good();
index 8fff164..ed536d7 100644 (file)
@@ -92,6 +92,12 @@ float sqrtFloat(float x);
 double sqrtDouble(double x);
 
 /**
+ * @brief     abs function for float type
+ * @param[in] x float
+ */
+float absFloat(float x);
+
+/**
  * @brief     log function for float type
  * @param[in] x float
  */
@@ -105,9 +111,6 @@ template <typename T = float> T logFloat(T x) {
  */
 template <typename T = float> T exp_util(T x) { return static_cast<T>(exp(x)); }
 
-Tensor calculateIOU(Tensor &b1_x1, Tensor &b1_y1, Tensor &b1_x2, Tensor &b1_y2,
-                    Tensor &b2_x1, Tensor &b2_y1, Tensor &b2_x2, Tensor &b2_y2);
-
 /**
  * @brief     Check Existance of File
  * @param[in] file path of the file to be checked