[Wait for #2177,#2213][Application] rebase for yolo v2
authorSeungbaek Hong <sb92.hong@samsung.com>
Wed, 31 May 2023 06:14:12 +0000 (15:14 +0900)
committerJijoong Moon <jijoong.moon@samsung.com>
Tue, 12 Sep 2023 15:21:01 +0000 (00:21 +0900)
I've rebased #2177(loss for yolo) and #2213(custom layer for yolo).
(Because the author of PR #2177 is absent for now.

If someone needs to use yolo v2, then use this pr.

I'll update document for this later.

**Self evaluation:**
1. Build test:  [X]Passed [ ]Failed [ ]Skipped
2. Run test:  [X]Passed [ ]Failed [ ]Skipped

Signed-off-by: Seungbaek Hong <sb92.hong@samsung.com>
Applications/YOLO/PyTorch/dataset.py
Applications/YOLO/PyTorch/main.py
Applications/YOLO/jni/main.cpp
Applications/YOLO/jni/meson.build
Applications/YOLO/jni/yolo_v2_loss.cpp
Applications/YOLO/jni/yolo_v2_loss.h

index 09d52e5e74314770131a301c80cde8a9fadb7f74..a02971ae87222cfcda178de6eced6c898adbb394 100644 (file)
@@ -36,7 +36,7 @@ class YOLODataset(Dataset):
             with open(ann_list[i], 'rt') as f:
                 for line in f.readlines():
                     line = [float(i) for i in line.split()]
-                    label_bbox.append(np.array(line[1:], dtype=np.float32))
+                    label_bbox.append(np.array(line[1:], dtype=np.float32) / 416)
                     label_cls.append(int(line[0]))
 
             self.input_images.append(img)
index c73f895a14efa9dd284ba91f0a022ca59216016e..b831e1ebb1e92a14ee4d2fd8789e816933ddfd5a 100644 (file)
@@ -34,16 +34,16 @@ from torchconverter import save_bin
 
 # set config
 out_size = 13
-num_classes = 92
+num_classes = 4
 num_anchors = 5
 
 epochs = 3
 batch_size = 4
 
-train_img_dir = 'TRAIN_DIR/images/*'
-train_ann_dir = 'TRAIN_DIR/annotations/*'
-valid_img_dir = 'VALIDATION_DIR/images/*'
-valid_ann_dir = 'VALIDATION_DIR/annotations/*'
+train_img_dir = '/home/user/TRAIN_DIR/images/*'
+train_ann_dir = '/home/user/TRAIN_DIR/annotations/*'
+valid_img_dir = '/home/user/VALID_DIR/images/*'
+valid_ann_dir = '/home/user/VALID_DIR/annotations/*'
 
 # load data
 train_dataset = YOLODataset(train_img_dir, train_ann_dir)
@@ -57,7 +57,9 @@ criterion = YoloV2_LOSS(num_classes=num_classes)
 optimizer = optim.Adam(model.parameters(), lr=1e-3)
 # scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=10, eta_min=0)
 
+# save init model
 save_bin(model, 'init_model')
+torch.save(model.state_dict(), './init_model.pt')
 
 # train model
 best_loss = 1e+10
index 7c470e80f34adb7f1a0c756921b3a491c7e6f112..2d1af0c7cfb90032e5dff09acc7d1dfb85695611 100644 (file)
@@ -27,7 +27,6 @@
 
 #include "yolo_v2_loss.h"
 
-#include <app_context.h>
 #include <reorg_layer.h>
 
 using LayerHandle = std::shared_ptr<ml::train::Layer>;
@@ -44,9 +43,9 @@ const unsigned int IMAGE_HEIGHT_SIZE = 416;
 const unsigned int IMAGE_WIDTH_SIZE = 416;
 const unsigned int BATCH_SIZE = 4;
 const unsigned int EPOCHS = 3;
-const char *TRAIN_DIR_PATH = "TRAIN_DIR_PATH";
-const char *VALIDATION_DIR_PATH = "VALIDATION_DIR_PATH";
-const std::string MODEL_INIT_BIN_PATH = "MODEL_INIT_BIN_PATH";
+const char *TRAIN_DIR_PATH = "/TRAIN_DIR/";
+const char *VALIDATION_DIR_PATH = "/VALID_DIR/";
+const std::string MODEL_INIT_BIN_PATH = "/home/user/MODEL_INIT_BIN_PATH.bin";
 
 int trainData_cb(float **input, float **label, bool *last, void *user_data) {
   auto data = reinterpret_cast<nntrainer::util::DirDataLoader *>(user_data);
@@ -176,8 +175,7 @@ std::vector<LayerHandle> yoloBlock(const std::string &block_name,
 ModelHandle YOLO() {
   using ml::train::createLayer;
 
-  ModelHandle model = ml::train::createModel(ml::train::ModelType::NEURAL_NET,
-                                             {withKey("loss", "mse")});
+  ModelHandle model = ml::train::createModel(ml::train::ModelType::NEURAL_NET);
 
   std::vector<LayerHandle> layers;
 
@@ -216,7 +214,7 @@ ModelHandle YOLO() {
   blocks.push_back(yoloBlock("conv_a7", "conv_a6", 1024, 3, false));
 
   blocks.push_back(yoloBlock("conv_b", "conv13", 64, 1, false));
-  // todo: conv_b_pool layer will be replaced with re-organization custom layer
+
   blocks.push_back({createLayer("reorg", {withKey("name", "re_organization"),
                                           withKey("input_layers", "conv_b")})});
 
@@ -279,12 +277,21 @@ int main(int argc, char *argv[]) {
   std::cout << "started computation at " << std::ctime(&start_time)
             << std::endl;
 
+  auto &app_context = nntrainer::AppContext::Global();
+
+  try {
+    app_context.registerFactory(nntrainer::createLayer<custom::ReorgLayer>);
+  } catch (std::invalid_argument &e) {
+    std::cerr << "failed to register reorg layer, reason: " << e.what()
+              << std::endl;
+    return 1;
+  }
+
   try {
-    auto &app_context = nntrainer::AppContext::Global();
     app_context.registerFactory(
       nntrainer::createLayer<custom::YoloV2LossLayer>);
   } catch (std::invalid_argument &e) {
-    std::cerr << "failed to register factory, reason: " << e.what()
+    std::cerr << "failed to register loss layer, reason: " << e.what()
               << std::endl;
     return 1;
   }
@@ -308,7 +315,7 @@ int main(int argc, char *argv[]) {
     // compile and initialize model
     model->compile();
     model->initialize();
-    model->load(MODEL_INIT_BIN_PATH);
+    // model->load(MODEL_INIT_BIN_PATH);
 
     // create train and validation data
     std::array<UserDataType, 2> user_datas;
index 113bdc1f41419682a77a91192797004a488e59e1..cc0ac96dc9fd4e20937063fd3ec66796a8c94e84 100644 (file)
@@ -18,8 +18,8 @@ reorg_layer_dep = declare_dependency(
 yolo_sources = [
   'main.cpp',
   'det_dataloader.cpp',
-  layer_reorg_src
-  'yolo_v2_loss.cpp'
+  'yolo_v2_loss.cpp',
+  'reorg_layer.cpp',
 ]
 
 yolo_dependencies = [app_utils_dep,
index 6eeb9ed57194952acd6c57fbf3926f550f0c0d3d..88234aef76db200626dbf6e368e0a49971518a97 100644 (file)
@@ -12,6 +12,7 @@
  */
 
 #include "yolo_v2_loss.h"
+#include <iostream>
 
 namespace custom {
 
@@ -172,9 +173,30 @@ calc_iou(nntrainer::Tensor &bbox1_x1, nntrainer::Tensor &bbox1_y1,
   is_xy_min_max.copyData(is_bbox_min_max);
 
   intersection_x2.subtract(intersection_x1, intersection_width);
-  intersection_width.apply_i(nntrainer::ActiFunc::relu);
+
+  auto type_intersection_width = intersection_width.getDataType();
+  if (type_intersection_width == ml::train::TensorDim::DataType::FP32) {
+    intersection_width.apply_i<float>(nntrainer::ActiFunc::relu<float>);
+  } else if (type_intersection_width == ml::train::TensorDim::DataType::FP16) {
+#ifdef ENABLE_FP16
+    intersection_width.apply_i<_FP16>(nntrainer::ActiFunc::relu<float>);
+#else
+    throw std::runtime_error("Not supported data type");
+#endif
+  }
+
   intersection_y2.subtract(intersection_y1, intersection_height);
-  intersection_height.apply_i(nntrainer::ActiFunc::relu);
+
+  auto type_intersection_height = intersection_height.getDataType();
+  if (type_intersection_height == ml::train::TensorDim::DataType::FP32) {
+    intersection_height.apply_i<float>(nntrainer::ActiFunc::relu<float>);
+  } else if (type_intersection_height == ml::train::TensorDim::DataType::FP16) {
+#ifdef ENABLE_FP16
+    intersection_height.apply_i<_FP16>(nntrainer::ActiFunc::relu<_FP16>);
+#else
+    throw std::runtime_error("Not supported data type");
+#endif
+  }
 
   nntrainer::Tensor intersection =
     intersection_width.multiply(intersection_height);
@@ -209,10 +231,20 @@ std::vector<nntrainer::Tensor> calc_iou_grad(
     intersection_width.multiply(intersection_height);
 
   // 1. calculate intersection local gradient [f'(x)]
-  nntrainer::Tensor intersection_width_relu_prime =
-    intersection_width.apply(nntrainer::ActiFunc::reluPrime);
-  nntrainer::Tensor intersection_height_relu_prime =
-    intersection_height.apply(nntrainer::ActiFunc::reluPrime);
+  nntrainer::Tensor intersection_width_relu_prime;
+  nntrainer::Tensor intersection_height_relu_prime;
+  auto type_intersection_width = intersection_width.getDataType();
+  if (type_intersection_width == ml::train::TensorDim::DataType::FP32) {
+    intersection_width_relu_prime =
+      intersection_width.apply<float>(nntrainer::ActiFunc::reluPrime<float>);
+  } else if (type_intersection_width == ml::train::TensorDim::DataType::FP16) {
+#ifdef ENABLE_FP16
+    intersection_height_relu_prime =
+      intersection_height.apply<_FP16>(nntrainer::ActiFunc::reluPrime<_FP16>);
+#else
+    throw std::runtime_error("Not supported data type");
+#endif
+  }
 
   nntrainer::Tensor intersection_x2_local_grad =
     intersection_width_relu_prime.multiply(intersection_height);
@@ -502,8 +534,29 @@ void YoloV2LossLayer::forwarding(nntrainer::RunLayerContext &context,
   // activate pred
   sigmoid.run_fn(bbox_x_pred, bbox_x_pred);
   sigmoid.run_fn(bbox_y_pred, bbox_y_pred);
-  bbox_w_pred.apply_i(nntrainer::exp_util);
-  bbox_h_pred.apply_i(nntrainer::exp_util);
+
+  auto type_bbox_w_pred = bbox_w_pred.getDataType();
+  if (type_bbox_w_pred == ml::train::TensorDim::DataType::FP32) {
+    bbox_w_pred.apply_i<float>(nntrainer::exp_util<float>);
+  } else if (type_bbox_w_pred == ml::train::TensorDim::DataType::FP16) {
+#ifdef ENABLE_FP16
+    bbox_w_pred.apply_i<_FP16>(nntrainer::exp_util<_FP16>);
+#else
+    throw std::runtime_error("Not supported data type");
+#endif
+  }
+
+  auto type_bbox_h_pred = bbox_h_pred.getDataType();
+  if (type_bbox_h_pred == ml::train::TensorDim::DataType::FP32) {
+    bbox_h_pred.apply_i<float>(nntrainer::exp_util<float>);
+  } else if (type_bbox_h_pred == ml::train::TensorDim::DataType::FP16) {
+#ifdef ENABLE_FP16
+    bbox_h_pred.apply_i<_FP16>(nntrainer::exp_util<_FP16>);
+#else
+    throw std::runtime_error("Not supported data type");
+#endif
+  }
+
   sigmoid.run_fn(confidence_pred, confidence_pred);
   softmax.run_fn(class_pred, class_pred);
 
@@ -512,9 +565,28 @@ void YoloV2LossLayer::forwarding(nntrainer::RunLayerContext &context,
 
   // apply anchors to bounding box
   bbox_w_pred_anchor.multiply_i(anchors_w);
+  auto type_bbox_w_pred_anchor = bbox_w_pred_anchor.getDataType();
+  if (type_bbox_w_pred_anchor == ml::train::TensorDim::DataType::FP32) {
+    bbox_w_pred_anchor.apply_i<float>(nntrainer::sqrtFloat);
+  } else if (type_bbox_w_pred_anchor == ml::train::TensorDim::DataType::FP16) {
+#ifdef ENABLE_FP16
+    bbox_w_pred_anchor.apply_i<_FP16>(nntrainer::sqrtFloat);
+#else
+    throw std::runtime_error("Not supported data type");
+#endif
+  }
+
   bbox_h_pred_anchor.multiply_i(anchors_h);
-  bbox_w_pred_anchor.apply_i(nntrainer::sqrtFloat);
-  bbox_h_pred_anchor.apply_i(nntrainer::sqrtFloat);
+  auto type_bbox_h_pred_anchor = bbox_h_pred_anchor.getDataType();
+  if (type_bbox_h_pred_anchor == ml::train::TensorDim::DataType::FP32) {
+    bbox_h_pred_anchor.apply_i<float>(nntrainer::sqrtFloat);
+  } else if (type_bbox_h_pred_anchor == ml::train::TensorDim::DataType::FP16) {
+#ifdef ENABLE_FP16
+    bbox_h_pred_anchor.apply_i<_FP16>(nntrainer::sqrtFloat);
+#else
+    throw std::runtime_error("Not supported data type");
+#endif
+  }
 
   generate_ground_truth(context);
 
@@ -535,6 +607,7 @@ void YoloV2LossLayer::forwarding(nntrainer::RunLayerContext &context,
   float class_loss = mse(masked_class_pred, masked_class_gt);
 
   float loss = 5 * bbox_loss + confidence_loss + class_loss;
+  std::cout << "\nCurrent iteration loss: " << loss << std::endl;
 }
 
 void YoloV2LossLayer::calcDerivative(nntrainer::RunLayerContext &context) {
@@ -733,7 +806,16 @@ void YoloV2LossLayer::setBatch(nntrainer::RunLayerContext &context,
 
 unsigned int YoloV2LossLayer::find_responsible_anchors(float bbox_ratio) {
   nntrainer::Tensor similarity = anchors_ratio.subtract(bbox_ratio);
-  similarity.apply_i(nntrainer::absFloat);
+  auto data_type = similarity.getDataType();
+  if (data_type == ml::train::TensorDim::DataType::FP32) {
+    similarity.apply_i<float>(nntrainer::absFloat);
+  } else if (data_type == ml::train::TensorDim::DataType::FP16) {
+#ifdef ENABLE_FP16
+    similarity.apply_i<_FP16>(nntrainer::absFloat);
+#else
+    throw std::runtime_error("Not supported data type");
+#endif
+  }
   auto data = similarity.getData();
 
   auto min_iter = std::min_element(data, data + NUM_ANCHOR);
index 4dde9152d5a80ed417926c5ed28244b254cdb484..fd1f2fa2eff13296d512da81acc0cd5e40fc8f09 100644 (file)
@@ -154,7 +154,7 @@ private:
   std::tuple<props::MaxObjectNumber, props::ClassNumber,
              props::GridHeightNumber, props::GridWidthNumber>
     yolo_v2_loss_props;
-  std::array<unsigned int, 8> wt_idx; /**< indices of the weights */
+  std::array<unsigned int, 22> wt_idx; /**< indices of the weights */
 
   /**
    * @brief find responsible anchors per object