img_list = glob.glob(img_dir)
ann_list = glob.glob(ann_dir)
img_list.sort(), ann_list.sort()
-
+
self.length = len(img_list)
self.input_images = []
self.bbox_gt = []
self.cls_gt = []
for i in range(len(img_list)):
- img = np.array(Image.open(img_list[i])) / 255
+ img = np.array(Image.open(img_list[i]).resize((416, 416))) / 255
label_bbox = []
label_cls = []
with open(ann_list[i], 'rt') as f:
for line in f.readlines():
- line = [int(i) for i in line.split()]
- label_bbox.append(np.array(line[1:], dtype=np.float32) / 416)
- label_cls.append(line[0])
-
+ line = [float(i) for i in line.split()]
+ label_bbox.append(np.array(line[1:], dtype=np.float32))
+ label_cls.append(int(line[0]))
+
self.input_images.append(img)
self.bbox_gt.append(label_bbox)
self.cls_gt.append(label_cls)
-
+
self.input_images = np.array(self.input_images)
self.input_images = torch.FloatTensor(self.input_images).permute((0, 3, 1, 2))
-
+
def __len__(self):
return self.length
# set config
out_size = 13
-num_classes = 4
+num_classes = 92
num_anchors = 5
epochs = 1000
# split each prediction(bbox, iou, class prob)
bbox_pred_xy = torch.sigmoid(hypothesis[..., :2])
bbox_pred_wh = torch.exp(hypothesis[..., 2:4])
- bbox_pred = torch.cat((bbox_pred_xy, bbox_pred_wh), 3)
- iou_pred = torch.sigmoid(hypothesis[..., 4:5])
+ bbox_pred = torch.cat((bbox_pred_xy, bbox_pred_wh), 3)
+ iou_pred = torch.sigmoid(hypothesis[..., 4:5])
score_pred = hypothesis[..., 5:].contiguous()
prob_pred = torch.softmax(score_pred.view(-1, num_classes), dim=1).view(score_pred.shape)
# calc loss
class YoloV2(nn.Module):
def __init__(self, num_classes, num_anchors=5):
- super(YoloV2, self).__init__()
+ super(YoloV2, self).__init__()
self.num_classes = num_classes
self.num_anchors = num_anchors
self.conv1 = nn.Sequential(nn.Conv2d(3, 32, 3, 1, 1), nn.BatchNorm2d(32, eps=1e-3),
nn.LeakyReLU())
self.conv_b = nn.Sequential(nn.Conv2d(512, 64, 1, 1, 0), nn.BatchNorm2d(64, eps=1e-3),
- nn.LeakyReLU())
- self.avgpool_b = nn.AvgPool2d(2, 2)
+ nn.LeakyReLU())
self.maxpool_a = nn.MaxPool2d(2, 2)
self.conv_a1 = nn.Sequential(nn.Conv2d(512, 1024, 3, 1, 1), nn.BatchNorm2d(1024, eps=1e-3),
self.conv_a7 = nn.Sequential(nn.Conv2d(1024, 1024, 3, 1, 1), nn.BatchNorm2d(1024, eps=1e-3),
nn.LeakyReLU())
- self.conv_out1 = nn.Sequential(nn.Conv2d(1088, 1024, 3, 1, 1), nn.BatchNorm2d(1024, eps=1e-3),
+ self.conv_out1 = nn.Sequential(nn.Conv2d(1280, 1024, 3, 1, 1), nn.BatchNorm2d(1024, eps=1e-3),
nn.LeakyReLU())
self.conv_out2 = nn.Conv2d(1024, self.num_anchors * (5 + num_classes), 1, 1, 0)
output = self.conv12(output)
output = self.conv13(output)
- residual = output
-
output_a = self.maxpool_a(output)
output_a = self.conv_a1(output_a)
output_a = self.conv_a2(output_a)
output_a = self.conv_a6(output_a)
output_a = self.conv_a7(output_a)
- output_b = self.conv_b(residual)
- output_b = self.avgpool_b(output_b)
+ output_b = self.conv_b(output)
+ b, c, h, w = output_b.size()
+ output_b = output_b.view(b, int(c / 4), h, 2, w, 2).contiguous()
+ output_b = output_b.permute(0, 3, 5, 1, 2, 4).contiguous()
+ output_b = output_b.view(b, -1, int(h / 2), int(w / 2))
output = torch.cat((output_a, output_b), 1)
output = self.conv_out1(output)
$(NNTRAINER_ROOT)/api \
$(NNTRAINER_ROOT)/api/ccapi/include \
${ML_API_COMMON_INCLUDES}
-
LOCAL_MODULE := nntrainer
LOCAL_SRC_FILES := $(NNTRAINER_ROOT)/libs/$(TARGET_ARCH_ABI)/libnntrainer.so
LOCAL_MODULE := nntrainer_yolo
LOCAL_LDLIBS := -llog -landroid -fopenmp
-LOCAL_SRC_FILES := main.cpp det_dataloader.cpp
+LOCAL_SRC_FILES := main.cpp det_dataloader.cpp reorg_layer.cpp
LOCAL_SHARED_LIBRARIES := nntrainer ccapi-nntrainer
-LOCAL_C_INCLUDES += $(NNTRAINER_INCLUDES) $(NNTRAINER_ROOT)/Applications/YOLO/jni
+LOCAL_C_INCLUDES += $(NNTRAINER_INCLUDES)
include $(BUILD_EXECUTABLE)
#include <det_dataloader.h>
+#include <app_context.h>
+#include <reorg_layer.h>
+
using LayerHandle = std::shared_ptr<ml::train::Layer>;
using ModelHandle = std::unique_ptr<ml::train::Model>;
using UserDataType = std::unique_ptr<nntrainer::util::DirDataLoader>;
-const int num_classes = 4;
+const int num_classes = 92;
int trainData_cb(float **input, float **label, bool *last, void *user_data) {
auto data = reinterpret_cast<nntrainer::util::DirDataLoader *>(user_data);
blocks.push_back(yoloBlock("conv_b", "conv13", 64, 1, false));
// todo: conv_b_pool layer will be replaced with re-organization custom layer
- blocks.push_back({createLayer(
- "pooling2d", {withKey("name", "conv_b_pool"), withKey("stride", {2, 2}),
- withKey("pooling", "average"), withKey("pool_size", {2, 2}),
- withKey("input_layers", "conv_b")})});
+ blocks.push_back({createLayer("reorg", {withKey("name", "re_organization"),
+ withKey("input_layers", "conv_b")})});
blocks.push_back(
{createLayer("concat", {withKey("name", "concat"),
- withKey("input_layers", "conv_a7, conv_b_pool"),
+ withKey("input_layers", "conv_a7, re_organization"),
withKey("axis", 1)})});
blocks.push_back(yoloBlock("conv_out1", "concat", 1024, 3, false));
std::cout << "batch_size: " << batch_size << " data_split: " << data_split
<< " epoch: " << epochs << std::endl;
+ try {
+ auto &app_context = nntrainer::AppContext::Global();
+ app_context.registerFactory(nntrainer::createLayer<custom::ReorgLayer>);
+ } catch (std::invalid_argument &e) {
+ std::cerr << "failed to register factory, reason: " << e.what()
+ << std::endl;
+ return 1;
+ }
+
// create train and validation data
std::array<UserDataType, 2> user_datas;
try {
+layer_reorg_src = files('reorg_layer.cpp')
+
+# build command for lib_reorg_layer.so
+reorg_layer = shared_library('reorg_layer',
+ layer_reorg_src,
+ dependencies: [nntrainer_dep, nntrainer_ccapi_dep],
+ include_directories: include_directories('./'),
+ install: true,
+ install_dir: application_install_dir,
+ cpp_args: '-DPLUGGABLE'
+)
+
+reorg_layer_dep = declare_dependency(
+ link_with: reorg_layer,
+ include_directories: include_directories('./')
+)
+
yolo_sources = [
'main.cpp',
- 'det_dataloader.cpp'
+ 'det_dataloader.cpp',
+ layer_reorg_src
]
yolo_dependencies = [app_utils_dep,
nntrainer_dep,
- nntrainer_ccapi_dep
+ nntrainer_ccapi_dep,
+ reorg_layer_dep
]
e = executable('nntrainer_yolov2',
--- /dev/null
+// SPDX-License-Identifier: Apache-2.0
+/**
+ * Copyright (C) 2023 Seungbaek Hong <sb92.hong@samsung.com>
+ *
+ * @file reorganization.cpp
+ * @date 06 April 2023
+ * @todo support in-place operation. we can get channel, height, width
+ * coordinate from index of buffer memory. then we can use reorganizePos and
+ * restorePos func
+ * @brief This file contains the mean absoulte error loss as a sample layer
+ * @see https://github.com/nnstreamer/nntrainer
+ * @author Seungbaek Hong <sb92.hong@samsung.com>
+ * @bug No known bugs except for NYI items
+ */
+
+#include <iostream>
+
+#include "reorg_layer.h"
+
+namespace custom {
+
+static constexpr size_t SINGLE_INOUT_IDX = 0;
+
+namespace ReorgOp {
+
+/**
+ * @brief re-organize tensor
+ * @return output coordinate of reorganized tensor
+ */
+int reorg(int b, int c, int h, int w, int batch, int channel, int height,
+ int width) {
+ int out_c = channel / 4;
+ int c2 = c % out_c;
+ int offset = c / out_c;
+ int w2 = w * 2 + offset % 2;
+ int h2 = h * 2 + offset / 2;
+ int out_index = w2 + width * 2 * (h2 + height * 2 * (c2 + out_c * b));
+ return out_index;
+}
+} // namespace ReorgOp
+
+void ReorgLayer::finalize(nntrainer::InitLayerContext &context) {
+ std::vector<nntrainer::TensorDim> dim = context.getInputDimensions();
+
+ for (unsigned int i = 0; i < dim.size(); ++i) {
+ if (dim[i].getDataLen() == 0) {
+ throw std::invalid_argument("Input dimension is not set");
+ } else {
+ dim[i].channel(dim[i].channel() * 4);
+ dim[i].height(dim[i].height() / 2);
+ dim[i].width(dim[i].width() / 2);
+ }
+ }
+
+ context.setOutputDimensions(dim);
+}
+
+void ReorgLayer::forwarding(nntrainer::RunLayerContext &context,
+ bool training) {
+ nntrainer::Tensor &in = context.getInput(SINGLE_INOUT_IDX);
+ nntrainer::Tensor &out = context.getOutput(SINGLE_INOUT_IDX);
+
+ for (int b = 0; b < (int)in.batch(); b++) {
+ for (int c = 0; c < (int)in.channel(); c++) {
+ for (int h = 0; h < (int)in.height(); h++) {
+ for (int w = 0; w < (int)in.width(); w++) {
+ int out_idx =
+ w + in.width() * (h + in.height() * (c + in.channel() * b));
+ int in_idx = ReorgOp::reorg(b, c, h, w, in.batch(), in.channel(),
+ in.height(), in.width());
+ out.getData()[out_idx] = in.getValue(in_idx);
+ }
+ }
+ }
+ }
+}
+
+void ReorgLayer::calcDerivative(nntrainer::RunLayerContext &context) {
+ const nntrainer::Tensor &derivative_ =
+ context.getIncomingDerivative(SINGLE_INOUT_IDX);
+
+ nntrainer::Tensor &dx = context.getOutgoingDerivative(SINGLE_INOUT_IDX);
+
+ for (int b = 0; b < (int)derivative_.batch(); b++) {
+ for (int c = 0; c < (int)derivative_.channel(); c++) {
+ for (int h = 0; h < (int)derivative_.height(); h++) {
+ for (int w = 0; w < (int)derivative_.width(); w++) {
+ int in_idx =
+ w + derivative_.width() *
+ (h + derivative_.height() * (c + derivative_.channel() * b));
+ int out_idx = ReorgOp::reorg(
+ b, c, h, w, derivative_.batch(), derivative_.channel(),
+ derivative_.height(), derivative_.width());
+ dx.getData()[out_idx] = derivative_.getValue(in_idx);
+ }
+ }
+ }
+ }
+}
+
+#ifdef PLUGGABLE
+
+nntrainer::Layer *create_reorg_layer() {
+ auto layer = new ReorgLayer();
+ std::cout << "reorg created\n";
+ return layer;
+}
+
+void destroy_reorg_layer(nntrainer::Layer *layer) {
+ std::cout << "reorg deleted\n";
+ delete layer;
+}
+
+extern "C" {
+nntrainer::LayerPluggable ml_train_layer_pluggable{create_reorg_layer,
+ destroy_reorg_layer};
+}
+
+#endif
+
+} // namespace custom
--- /dev/null
+// SPDX-License-Identifier: Apache-2.0
+/**
+ * Copyright (C) 2023 Seungbaek Hong <sb92.hong@samsung.com>
+ *
+ * @file reorganization.h
+ * @date 4 April 2023
+ * @brief This file contains the mean absoulte error loss as a sample layer
+ * @see https://github.com/nnstreamer/nntrainer
+ * @author Seungbaek Hong <sb92.hong@samsung.com>
+ * @bug No known bugs except for NYI items
+ *
+ */
+
+#ifndef __REORGANIZATION_LAYER_H__
+#define __REORGANIZATION_LAYER_H__
+
+#include <layer_context.h>
+#include <layer_devel.h>
+#include <node_exporter.h>
+#include <utility>
+
+namespace custom {
+
+/**
+ * @brief A Re-orginazation layer for yolo v2.
+ *
+ */
+class ReorgLayer final : public nntrainer::Layer {
+public:
+ /**
+ * @brief Construct a new Reorg Layer object
+ *
+ */
+ ReorgLayer() : Layer() {}
+
+ /**
+ * @brief Destroy the Reorg Layer object
+ *
+ */
+ ~ReorgLayer() {}
+
+ /**
+ * @copydoc Layer::finalize(InitLayerContext &context)
+ */
+ void finalize(nntrainer::InitLayerContext &context) override;
+
+ /**
+ * @copydoc Layer::forwarding(RunLayerContext &context, bool training)
+ */
+ void forwarding(nntrainer::RunLayerContext &context, bool training) override;
+
+ /**
+ * @copydoc Layer::calcDerivative(RunLayerContext &context)
+ */
+ void calcDerivative(nntrainer::RunLayerContext &context) override;
+
+ /**
+ * @copydoc bool supportBackwarding() const
+ */
+ bool supportBackwarding() const override { return true; };
+
+ /**
+ * @copydoc Layer::exportTo(Exporter &exporter, ExportMethods method)
+ */
+ void exportTo(nntrainer::Exporter &exporter,
+ const ml::train::ExportMethods &method) const override{};
+
+ /**
+ * @copydoc Layer::getType()
+ */
+ const std::string getType() const override { return ReorgLayer::type; };
+
+ /**
+ * @copydoc Layer::setProperty(const std::vector<std::string> &values)
+ */
+ void setProperty(const std::vector<std::string> &values) override{};
+
+ inline static const std::string type = "reorg";
+};
+
+} // namespace custom
+
+#endif /* __REORGANIZATION_LAYER_H__ */