[Application] add re-organization layer to Yolo v2
authorSeungbaek Hong <sb92.hong@samsung.com>
Tue, 30 May 2023 07:04:34 +0000 (16:04 +0900)
committerJijoong Moon <jijoong.moon@samsung.com>
Tue, 12 Sep 2023 02:35:35 +0000 (11:35 +0900)
Added Re-organization layer to yolo v2 examples
of nntrainer and pytorch.

Signed-off-by: Seungbaek Hong <sb92.hong@samsung.com>
Applications/YOLO/PyTorch/dataset.py
Applications/YOLO/PyTorch/main.py
Applications/YOLO/PyTorch/yolo.py
Applications/YOLO/jni/Android.mk
Applications/YOLO/jni/main.cpp
Applications/YOLO/jni/meson.build
Applications/YOLO/jni/reorg_layer.cpp [new file with mode: 0644]
Applications/YOLO/jni/reorg_layer.h [new file with mode: 0644]

index 738c3e8..09d52e5 100644 (file)
@@ -23,29 +23,29 @@ class YOLODataset(Dataset):
         img_list = glob.glob(img_dir)
         ann_list = glob.glob(ann_dir)
         img_list.sort(), ann_list.sort()
-    
+
         self.length = len(img_list)
         self.input_images = []
         self.bbox_gt = []
         self.cls_gt = []
 
         for i in range(len(img_list)):
-            img = np.array(Image.open(img_list[i])) / 255
+            img = np.array(Image.open(img_list[i]).resize((416, 416))) / 255
             label_bbox = []
             label_cls = []
             with open(ann_list[i], 'rt') as f:
                 for line in f.readlines():
-                    line = [int(i) for i in line.split()]
-                    label_bbox.append(np.array(line[1:], dtype=np.float32) / 416)
-                    label_cls.append(line[0])
-                    
+                    line = [float(i) for i in line.split()]
+                    label_bbox.append(np.array(line[1:], dtype=np.float32))
+                    label_cls.append(int(line[0]))
+
             self.input_images.append(img)
             self.bbox_gt.append(label_bbox)
             self.cls_gt.append(label_cls)
-        
+
         self.input_images = np.array(self.input_images)
         self.input_images = torch.FloatTensor(self.input_images).permute((0, 3, 1, 2))
-        
+
     def __len__(self):
         return self.length
     
index 0c1b1be..1750c7b 100644 (file)
@@ -34,7 +34,7 @@ from torchconverter import save_bin
 
 # set config
 out_size = 13
-num_classes = 4
+num_classes = 92
 num_anchors = 5
 
 epochs = 1000
@@ -96,8 +96,8 @@ for epoch in range(epochs):
             # split each prediction(bbox, iou, class prob)
             bbox_pred_xy = torch.sigmoid(hypothesis[..., :2])
             bbox_pred_wh = torch.exp(hypothesis[..., 2:4])
-            bbox_pred = torch.cat((bbox_pred_xy, bbox_pred_wh), 3)        
-            iou_pred = torch.sigmoid(hypothesis[..., 4:5])        
+            bbox_pred = torch.cat((bbox_pred_xy, bbox_pred_wh), 3)
+            iou_pred = torch.sigmoid(hypothesis[..., 4:5])
             score_pred = hypothesis[..., 5:].contiguous()
             prob_pred = torch.softmax(score_pred.view(-1, num_classes), dim=1).view(score_pred.shape)
             # calc loss
index e31e772..53763f1 100644 (file)
@@ -15,7 +15,7 @@ import torch.nn as nn
 class YoloV2(nn.Module): 
     def __init__(self, num_classes, num_anchors=5):
         
-        super(YoloV2, self).__init__()              
+        super(YoloV2, self).__init__()
         self.num_classes = num_classes
         self.num_anchors = num_anchors
         self.conv1 = nn.Sequential(nn.Conv2d(3, 32, 3, 1, 1), nn.BatchNorm2d(32, eps=1e-3),
@@ -46,8 +46,7 @@ class YoloV2(nn.Module):
                                            nn.LeakyReLU())
 
         self.conv_b = nn.Sequential(nn.Conv2d(512, 64, 1, 1, 0), nn.BatchNorm2d(64, eps=1e-3),
-                                    nn.LeakyReLU())
-        self.avgpool_b = nn.AvgPool2d(2, 2)
+                                    nn.LeakyReLU())        
 
         self.maxpool_a = nn.MaxPool2d(2, 2)
         self.conv_a1 = nn.Sequential(nn.Conv2d(512, 1024, 3, 1, 1), nn.BatchNorm2d(1024, eps=1e-3),
@@ -65,7 +64,7 @@ class YoloV2(nn.Module):
         self.conv_a7 = nn.Sequential(nn.Conv2d(1024, 1024, 3, 1, 1), nn.BatchNorm2d(1024, eps=1e-3),
                                             nn.LeakyReLU())
 
-        self.conv_out1 = nn.Sequential(nn.Conv2d(1088, 1024, 3, 1, 1), nn.BatchNorm2d(1024, eps=1e-3),
+        self.conv_out1 = nn.Sequential(nn.Conv2d(1280, 1024, 3, 1, 1), nn.BatchNorm2d(1024, eps=1e-3),
                                           nn.LeakyReLU())
 
         self.conv_out2 = nn.Conv2d(1024, self.num_anchors * (5 + num_classes), 1, 1, 0)
@@ -85,8 +84,6 @@ class YoloV2(nn.Module):
         output = self.conv12(output)
         output = self.conv13(output)
 
-        residual = output
-
         output_a = self.maxpool_a(output)
         output_a = self.conv_a1(output_a)
         output_a = self.conv_a2(output_a)
@@ -96,8 +93,11 @@ class YoloV2(nn.Module):
         output_a = self.conv_a6(output_a)
         output_a = self.conv_a7(output_a)
 
-        output_b = self.conv_b(residual)
-        output_b = self.avgpool_b(output_b)
+        output_b = self.conv_b(output)
+        b, c, h, w = output_b.size()
+        output_b = output_b.view(b, int(c / 4), h, 2, w, 2).contiguous()
+        output_b = output_b.permute(0, 3, 5, 1, 2, 4).contiguous()
+        output_b = output_b.view(b, -1, int(h / 2), int(w / 2))
 
         output = torch.cat((output_a, output_b), 1)
         output = self.conv_out1(output)
index 8e057ba..5fefab9 100644 (file)
@@ -24,7 +24,6 @@ NNTRAINER_INCLUDES := $(NNTRAINER_ROOT)/nntrainer \
        $(NNTRAINER_ROOT)/api \
        $(NNTRAINER_ROOT)/api/ccapi/include \
        ${ML_API_COMMON_INCLUDES}
-       
 
 LOCAL_MODULE := nntrainer
 LOCAL_SRC_FILES := $(NNTRAINER_ROOT)/libs/$(TARGET_ARCH_ABI)/libnntrainer.so
@@ -51,9 +50,9 @@ LOCAL_ARM_MODE := arm
 LOCAL_MODULE := nntrainer_yolo
 LOCAL_LDLIBS := -llog -landroid -fopenmp
 
-LOCAL_SRC_FILES := main.cpp det_dataloader.cpp
+LOCAL_SRC_FILES := main.cpp det_dataloader.cpp reorg_layer.cpp
 LOCAL_SHARED_LIBRARIES := nntrainer ccapi-nntrainer
 
-LOCAL_C_INCLUDES += $(NNTRAINER_INCLUDES) $(NNTRAINER_ROOT)/Applications/YOLO/jni
+LOCAL_C_INCLUDES += $(NNTRAINER_INCLUDES)
 
 include $(BUILD_EXECUTABLE)
index 30c5824..92a76bf 100644 (file)
 
 #include <det_dataloader.h>
 
+#include <app_context.h>
+#include <reorg_layer.h>
+
 using LayerHandle = std::shared_ptr<ml::train::Layer>;
 using ModelHandle = std::unique_ptr<ml::train::Model>;
 using UserDataType = std::unique_ptr<nntrainer::util::DirDataLoader>;
 
-const int num_classes = 4;
+const int num_classes = 92;
 
 int trainData_cb(float **input, float **label, bool *last, void *user_data) {
   auto data = reinterpret_cast<nntrainer::util::DirDataLoader *>(user_data);
@@ -196,14 +199,12 @@ ModelHandle YOLO() {
 
   blocks.push_back(yoloBlock("conv_b", "conv13", 64, 1, false));
   // conv_b output is rearranged by the re-organization custom layer
-  blocks.push_back({createLayer(
-    "pooling2d", {withKey("name", "conv_b_pool"), withKey("stride", {2, 2}),
-                  withKey("pooling", "average"), withKey("pool_size", {2, 2}),
-                  withKey("input_layers", "conv_b")})});
+  blocks.push_back({createLayer("reorg", {withKey("name", "re_organization"),
+                                          withKey("input_layers", "conv_b")})});
 
   blocks.push_back(
     {createLayer("concat", {withKey("name", "concat"),
-                            withKey("input_layers", "conv_a7, conv_b_pool"),
+                            withKey("input_layers", "conv_a7, re_organization"),
                             withKey("axis", 1)})});
 
   blocks.push_back(yoloBlock("conv_out1", "concat", 1024, 3, false));
@@ -244,6 +245,15 @@ int main(int argc, char *argv[]) {
   std::cout << "batch_size: " << batch_size << " data_split: " << data_split
             << " epoch: " << epochs << std::endl;
 
+  try {
+    auto &app_context = nntrainer::AppContext::Global();
+    app_context.registerFactory(nntrainer::createLayer<custom::ReorgLayer>);
+  } catch (std::invalid_argument &e) {
+    std::cerr << "failed to register factory, reason: " << e.what()
+              << std::endl;
+    return 1;
+  }
+
   // create train and validation data
   std::array<UserDataType, 2> user_datas;
   try {
index f7a45db..ee6c383 100644 (file)
@@ -1,11 +1,30 @@
+layer_reorg_src = files('reorg_layer.cpp')
+
+# build command for lib_reorg_layer.so
+reorg_layer = shared_library('reorg_layer',
+  layer_reorg_src,
+  dependencies: [nntrainer_dep, nntrainer_ccapi_dep],
+  include_directories: include_directories('./'),
+  install: true,
+  install_dir: application_install_dir,
+  cpp_args: '-DPLUGGABLE'
+)
+
+reorg_layer_dep = declare_dependency(
+  link_with: reorg_layer,
+  include_directories: include_directories('./')
+)
+
 yolo_sources = [
   'main.cpp',
-  'det_dataloader.cpp'
+  'det_dataloader.cpp',
+  layer_reorg_src
 ]
 
 yolo_dependencies = [app_utils_dep,
   nntrainer_dep,
-  nntrainer_ccapi_dep
+  nntrainer_ccapi_dep,
+  reorg_layer_dep
 ]
 
 e = executable('nntrainer_yolov2',
diff --git a/Applications/YOLO/jni/reorg_layer.cpp b/Applications/YOLO/jni/reorg_layer.cpp
new file mode 100644 (file)
index 0000000..9ccac8d
--- /dev/null
@@ -0,0 +1,121 @@
+// SPDX-License-Identifier: Apache-2.0
+/**
+ * Copyright (C) 2023 Seungbaek Hong <sb92.hong@samsung.com>
+ *
+ * @file   reorg_layer.cpp
+ * @date   06 April 2023
+ * @todo support in-place operation. we can get channel, height, width
+ * coordinate from index of buffer memory. then we can use reorganizePos and
+ * restorePos func
+ * @brief  This file contains the re-organization layer for YOLO v2
+ * @see    https://github.com/nnstreamer/nntrainer
+ * @author Seungbaek Hong <sb92.hong@samsung.com>
+ * @bug    No known bugs except for NYI items
+ */
+
+#include <iostream>
+
+#include "reorg_layer.h"
+
+namespace custom {
+
+static constexpr size_t SINGLE_INOUT_IDX = 0;
+
+namespace ReorgOp {
+
+/**
+ * @brief re-organize tensor
+ * @return output coordinate of reorganized tensor
+ */
+int reorg(int b, int c, int h, int w, int batch, int channel, int height,
+          int width) {
+  int out_c = channel / 4;
+  int c2 = c % out_c;
+  int offset = c / out_c;
+  int w2 = w * 2 + offset % 2;
+  int h2 = h * 2 + offset / 2;
+  int out_index = w2 + width * 2 * (h2 + height * 2 * (c2 + out_c * b));
+  return out_index;
+}
+} // namespace ReorgOp
+
+void ReorgLayer::finalize(nntrainer::InitLayerContext &context) {
+  std::vector<nntrainer::TensorDim> dim = context.getInputDimensions();
+
+  for (unsigned int i = 0; i < dim.size(); ++i) {
+    if (dim[i].getDataLen() == 0) {
+      throw std::invalid_argument("Input dimension is not set");
+    } else {
+      dim[i].channel(dim[i].channel() * 4);
+      dim[i].height(dim[i].height() / 2);
+      dim[i].width(dim[i].width() / 2);
+    }
+  }
+
+  context.setOutputDimensions(dim);
+}
+
+void ReorgLayer::forwarding(nntrainer::RunLayerContext &context,
+                            bool training) {
+  nntrainer::Tensor &in = context.getInput(SINGLE_INOUT_IDX);
+  nntrainer::Tensor &out = context.getOutput(SINGLE_INOUT_IDX);
+
+  for (int b = 0; b < (int)in.batch(); b++) {
+    for (int c = 0; c < (int)in.channel(); c++) {
+      for (int h = 0; h < (int)in.height(); h++) {
+        for (int w = 0; w < (int)in.width(); w++) {
+          int out_idx =
+            w + in.width() * (h + in.height() * (c + in.channel() * b));
+          int in_idx = ReorgOp::reorg(b, c, h, w, in.batch(), in.channel(),
+                                      in.height(), in.width());
+          out.getData()[out_idx] = in.getValue(in_idx);
+        }
+      }
+    }
+  }
+}
+
+void ReorgLayer::calcDerivative(nntrainer::RunLayerContext &context) {
+  const nntrainer::Tensor &derivative_ =
+    context.getIncomingDerivative(SINGLE_INOUT_IDX);
+
+  nntrainer::Tensor &dx = context.getOutgoingDerivative(SINGLE_INOUT_IDX);
+
+  for (int b = 0; b < (int)derivative_.batch(); b++) {
+    for (int c = 0; c < (int)derivative_.channel(); c++) {
+      for (int h = 0; h < (int)derivative_.height(); h++) {
+        for (int w = 0; w < (int)derivative_.width(); w++) {
+          int in_idx =
+            w + derivative_.width() *
+                  (h + derivative_.height() * (c + derivative_.channel() * b));
+          int out_idx = ReorgOp::reorg(
+            b, c, h, w, derivative_.batch(), derivative_.channel(),
+            derivative_.height(), derivative_.width());
+          dx.getData()[out_idx] = derivative_.getValue(in_idx);
+        }
+      }
+    }
+  }
+}
+
+#ifdef PLUGGABLE
+
+/**
+ * @brief Allocate a ReorgLayer and log its creation to stdout
+ * @return pointer to the newly allocated layer; it is released by
+ *         destroy_reorg_layer()
+ */
+nntrainer::Layer *create_reorg_layer() {
+  auto layer = new ReorgLayer();
+  std::cout << "reorg created\n";
+  return layer;
+}
+
+/**
+ * @brief Log the deletion to stdout and delete the given layer
+ * @param layer layer to delete
+ */
+void destroy_reorg_layer(nntrainer::Layer *layer) {
+  std::cout << "reorg deleted\n";
+  delete layer;
+}
+
+/**
+ * @brief Pluggable descriptor pairing the create and destroy functions of the
+ * reorg layer; exported with C linkage
+ */
+extern "C" {
+nntrainer::LayerPluggable ml_train_layer_pluggable{create_reorg_layer,
+                                                   destroy_reorg_layer};
+}
+
+#endif
+
+} // namespace custom
diff --git a/Applications/YOLO/jni/reorg_layer.h b/Applications/YOLO/jni/reorg_layer.h
new file mode 100644 (file)
index 0000000..5d9f78b
--- /dev/null
@@ -0,0 +1,83 @@
+// SPDX-License-Identifier: Apache-2.0
+/**
+ * Copyright (C) 2023 Seungbaek Hong <sb92.hong@samsung.com>
+ *
+ * @file   reorg_layer.h
+ * @date   4 April 2023
+ * @brief  This file contains the re-organization layer for YOLO v2
+ * @see    https://github.com/nnstreamer/nntrainer
+ * @author Seungbaek Hong <sb92.hong@samsung.com>
+ * @bug    No known bugs except for NYI items
+ *
+ */
+
+#ifndef __REORGANIZATION_LAYER_H__
+#define __REORGANIZATION_LAYER_H__
+
+#include <layer_context.h>
+#include <layer_devel.h>
+#include <node_exporter.h>
+#include <utility>
+
+namespace custom {
+
+/**
+ * @brief A Re-orginazation layer for yolo v2.
+ *
+ */
+class ReorgLayer final : public nntrainer::Layer {
+public:
+  /**
+   * @brief Construct a new Reorg Layer object
+   *
+   */
+  ReorgLayer() : Layer() {}
+
+  /**
+   * @brief Destroy the Reorg Layer object
+   *
+   */
+  ~ReorgLayer() {}
+
+  /**
+   * @copydoc Layer::finalize(InitLayerContext &context)
+   */
+  void finalize(nntrainer::InitLayerContext &context) override;
+
+  /**
+   * @copydoc Layer::forwarding(RunLayerContext &context, bool training)
+   */
+  void forwarding(nntrainer::RunLayerContext &context, bool training) override;
+
+  /**
+   * @copydoc Layer::calcDerivative(RunLayerContext &context)
+   */
+  void calcDerivative(nntrainer::RunLayerContext &context) override;
+
+  /**
+   * @copydoc bool supportBackwarding() const
+   */
+  bool supportBackwarding() const override { return true; };
+
+  /**
+   * @copydoc Layer::exportTo(Exporter &exporter, ExportMethods method)
+   */
+  void exportTo(nntrainer::Exporter &exporter,
+                const ml::train::ExportMethods &method) const override{};
+
+  /**
+   * @copydoc Layer::getType()
+   */
+  const std::string getType() const override { return ReorgLayer::type; };
+
+  /**
+   * @copydoc Layer::setProperty(const std::vector<std::string> &values)
+   */
+  void setProperty(const std::vector<std::string> &values) override{};
+
+  inline static const std::string type = "reorg";
+};
+
+} // namespace custom
+
+#endif /* __REORGANIZATION_LAYER_H__ */