From 10fe5291e97c2e7dd6cb428e9b27640bc94ec04c Mon Sep 17 00:00:00 2001
From: Seungbaek Hong
Date: Tue, 30 May 2023 16:04:34 +0900
Subject: [PATCH] [Application] add re-organization layer to Yolo v2

Added a re-organization layer to the YOLO v2 examples of nntrainer and PyTorch.

Signed-off-by: Seungbaek Hong
---
 Applications/YOLO/PyTorch/dataset.py  |  16 ++---
 Applications/YOLO/PyTorch/main.py     |   6 +-
 Applications/YOLO/PyTorch/yolo.py     |  16 ++---
 Applications/YOLO/jni/Android.mk      |   5 +-
 Applications/YOLO/jni/main.cpp        |  22 +++++--
 Applications/YOLO/jni/meson.build     |  23 ++++++-
 Applications/YOLO/jni/reorg_layer.cpp | 121 ++++++++++++++++++++++++++++++++++
 Applications/YOLO/jni/reorg_layer.h   |  83 +++++++++++++++++++++++
 8 files changed, 262 insertions(+), 30 deletions(-)
 create mode 100644 Applications/YOLO/jni/reorg_layer.cpp
 create mode 100644 Applications/YOLO/jni/reorg_layer.h

diff --git a/Applications/YOLO/PyTorch/dataset.py b/Applications/YOLO/PyTorch/dataset.py
index 738c3e8..09d52e5 100644
--- a/Applications/YOLO/PyTorch/dataset.py
+++ b/Applications/YOLO/PyTorch/dataset.py
@@ -23,29 +23,29 @@ class YOLODataset(Dataset):
         img_list = glob.glob(img_dir)
         ann_list = glob.glob(ann_dir)
         img_list.sort(), ann_list.sort()
-        
+
         self.length = len(img_list)
         self.input_images = []
         self.bbox_gt = []
         self.cls_gt = []
 
         for i in range(len(img_list)):
-            img = np.array(Image.open(img_list[i])) / 255
+            img = np.array(Image.open(img_list[i]).resize((416, 416))) / 255
             label_bbox = []
             label_cls = []
             with open(ann_list[i], 'rt') as f:
                 for line in f.readlines():
-                    line = [int(i) for i in line.split()]
-                    label_bbox.append(np.array(line[1:], dtype=np.float32) / 416)
-                    label_cls.append(line[0])
-
+                    line = [float(i) for i in line.split()]
+                    label_bbox.append(np.array(line[1:], dtype=np.float32))
+                    label_cls.append(int(line[0]))
+
             self.input_images.append(img)
             self.bbox_gt.append(label_bbox)
             self.cls_gt.append(label_cls)
-        
+
         self.input_images = np.array(self.input_images)
         self.input_images = torch.FloatTensor(self.input_images).permute((0, 3, 1, 2))
-    
+
     def __len__(self):
         return self.length
 
diff --git a/Applications/YOLO/PyTorch/main.py b/Applications/YOLO/PyTorch/main.py
index 0c1b1be..1750c7b 100644
--- a/Applications/YOLO/PyTorch/main.py
+++ b/Applications/YOLO/PyTorch/main.py
@@ -34,7 +34,7 @@ from torchconverter import save_bin
 
 # set config
 out_size = 13
-num_classes = 4
+num_classes = 92
 num_anchors = 5
 
 epochs = 1000
@@ -96,8 +96,8 @@ for epoch in range(epochs):
         # split each prediction(bbox, iou, class prob)
         bbox_pred_xy = torch.sigmoid(hypothesis[..., :2])
         bbox_pred_wh = torch.exp(hypothesis[..., 2:4])
-        bbox_pred = torch.cat((bbox_pred_xy, bbox_pred_wh), 3)
-        iou_pred = torch.sigmoid(hypothesis[..., 4:5])
+        bbox_pred = torch.cat((bbox_pred_xy, bbox_pred_wh), 3)
+        iou_pred = torch.sigmoid(hypothesis[..., 4:5])
         score_pred = hypothesis[..., 5:].contiguous()
         prob_pred = torch.softmax(score_pred.view(-1, num_classes), dim=1).view(score_pred.shape)
         # calc loss
diff --git a/Applications/YOLO/PyTorch/yolo.py b/Applications/YOLO/PyTorch/yolo.py
index e31e772..53763f1 100644
--- a/Applications/YOLO/PyTorch/yolo.py
+++ b/Applications/YOLO/PyTorch/yolo.py
@@ -15,7 +15,7 @@ import torch.nn as nn
 
 class YoloV2(nn.Module):
     def __init__(self, num_classes, num_anchors=5):
-        super(YoloV2, self).__init__()
+        super(YoloV2, self).__init__()
         self.num_classes = num_classes
         self.num_anchors = num_anchors
         self.conv1 = nn.Sequential(nn.Conv2d(3, 32, 3, 1, 1), nn.BatchNorm2d(32, eps=1e-3),
@@ -46,8 +46,7 @@ class YoloV2(nn.Module):
                                     nn.LeakyReLU())
         self.conv_b = nn.Sequential(nn.Conv2d(512, 64, 1, 1, 0), nn.BatchNorm2d(64, eps=1e-3),
-                                    nn.LeakyReLU())
-        self.avgpool_b = nn.AvgPool2d(2, 2)
+                                    nn.LeakyReLU())
 
         self.maxpool_a = nn.MaxPool2d(2, 2)
         self.conv_a1 = nn.Sequential(nn.Conv2d(512, 1024, 3, 1, 1), nn.BatchNorm2d(1024, eps=1e-3),
@@ -65,7 +64,7 @@ class YoloV2(nn.Module):
         self.conv_a7 = nn.Sequential(nn.Conv2d(1024, 1024, 3, 1, 1), nn.BatchNorm2d(1024, eps=1e-3),
                                      nn.LeakyReLU())
 
-        self.conv_out1 = nn.Sequential(nn.Conv2d(1088, 1024, 3, 1, 1), nn.BatchNorm2d(1024, eps=1e-3),
+        self.conv_out1 = nn.Sequential(nn.Conv2d(1280, 1024, 3, 1, 1), nn.BatchNorm2d(1024, eps=1e-3),
                                        nn.LeakyReLU())
 
         self.conv_out2 = nn.Conv2d(1024, self.num_anchors * (5 + num_classes), 1, 1, 0)
@@ -85,8 +84,6 @@ class YoloV2(nn.Module):
         output = self.conv12(output)
         output = self.conv13(output)
 
-        residual = output
-
         output_a = self.maxpool_a(output)
         output_a = self.conv_a1(output_a)
         output_a = self.conv_a2(output_a)
@@ -96,8 +93,11 @@ class YoloV2(nn.Module):
         output_a = self.conv_a6(output_a)
         output_a = self.conv_a7(output_a)
 
-        output_b = self.conv_b(residual)
-        output_b = self.avgpool_b(output_b)
+        output_b = self.conv_b(output)
+        b, c, h, w = output_b.size()
+        output_b = output_b.view(b, int(c / 4), h, 2, w, 2).contiguous()
+        output_b = output_b.permute(0, 3, 5, 1, 2, 4).contiguous()
+        output_b = output_b.view(b, -1, int(h / 2), int(w / 2))
 
         output = torch.cat((output_a, output_b), 1)
         output = self.conv_out1(output)
diff --git a/Applications/YOLO/jni/Android.mk b/Applications/YOLO/jni/Android.mk
index 8e057ba..5fefab9 100644
--- a/Applications/YOLO/jni/Android.mk
+++ b/Applications/YOLO/jni/Android.mk
@@ -24,7 +24,6 @@ NNTRAINER_INCLUDES := $(NNTRAINER_ROOT)/nntrainer \
 	$(NNTRAINER_ROOT)/api \
 	$(NNTRAINER_ROOT)/api/ccapi/include \
 	${ML_API_COMMON_INCLUDES}
-
 LOCAL_MODULE := nntrainer
 LOCAL_SRC_FILES := $(NNTRAINER_ROOT)/libs/$(TARGET_ARCH_ABI)/libnntrainer.so
 
@@ -51,9 +50,9 @@ LOCAL_ARM_MODE := arm
 LOCAL_MODULE := nntrainer_yolo
 LOCAL_LDLIBS := -llog -landroid -fopenmp
 
-LOCAL_SRC_FILES := main.cpp det_dataloader.cpp
+LOCAL_SRC_FILES := main.cpp det_dataloader.cpp reorg_layer.cpp
 LOCAL_SHARED_LIBRARIES := nntrainer ccapi-nntrainer
-LOCAL_C_INCLUDES += $(NNTRAINER_INCLUDES) $(NNTRAINER_ROOT)/Applications/YOLO/jni
+LOCAL_C_INCLUDES += $(NNTRAINER_INCLUDES)
 
 include $(BUILD_EXECUTABLE)
diff --git a/Applications/YOLO/jni/main.cpp b/Applications/YOLO/jni/main.cpp
index 30c5824..92a76bf 100644
--- a/Applications/YOLO/jni/main.cpp
+++ b/Applications/YOLO/jni/main.cpp
@@ -25,11 +25,14 @@
 
 #include <det_dataloader.h>
 
+#include <app_context.h>
+#include <reorg_layer.h>
+
 using LayerHandle = std::shared_ptr<ml::train::Layer>;
 using ModelHandle = std::unique_ptr<ml::train::Model>;
 using UserDataType = std::unique_ptr<nntrainer::util::DirDataLoader>;
 
-const int num_classes = 4;
+const int num_classes = 92;
 
 int trainData_cb(float **input, float **label, bool *last, void *user_data) {
   auto data = reinterpret_cast<nntrainer::util::DirDataLoader *>(user_data);
@@ -196,14 +199,12 @@ ModelHandle YOLO() {
   blocks.push_back(yoloBlock("conv_b", "conv13", 64, 1, false));
 
   // todo: conv_b_pool layer will be replaced with re-organization custom layer
-  blocks.push_back({createLayer(
-    "pooling2d", {withKey("name", "conv_b_pool"), withKey("stride", {2, 2}),
-                  withKey("pooling", "average"), withKey("pool_size", {2, 2}),
-                  withKey("input_layers", "conv_b")})});
+  blocks.push_back({createLayer("reorg", {withKey("name", "re_organization"),
+                                          withKey("input_layers", "conv_b")})});
 
   blocks.push_back(
     {createLayer("concat", {withKey("name", "concat"),
-                            withKey("input_layers", "conv_a7, conv_b_pool"),
re_organization"), withKey("axis", 1)})}); blocks.push_back(yoloBlock("conv_out1", "concat", 1024, 3, false)); @@ -244,6 +245,15 @@ int main(int argc, char *argv[]) { std::cout << "batch_size: " << batch_size << " data_split: " << data_split << " epoch: " << epochs << std::endl; + try { + auto &app_context = nntrainer::AppContext::Global(); + app_context.registerFactory(nntrainer::createLayer); + } catch (std::invalid_argument &e) { + std::cerr << "failed to register factory, reason: " << e.what() + << std::endl; + return 1; + } + // create train and validation data std::array user_datas; try { diff --git a/Applications/YOLO/jni/meson.build b/Applications/YOLO/jni/meson.build index f7a45db..ee6c383 100644 --- a/Applications/YOLO/jni/meson.build +++ b/Applications/YOLO/jni/meson.build @@ -1,11 +1,30 @@ +layer_reorg_src = files('reorg_layer.cpp') + +# build command for lib_reorg_layer.so +reorg_layer = shared_library('reorg_layer', + layer_reorg_src, + dependencies: [nntrainer_dep, nntrainer_ccapi_dep], + include_directories: include_directories('./'), + install: true, + install_dir: application_install_dir, + cpp_args: '-DPLUGGABLE' +) + +reorg_layer_dep = declare_dependency( + link_with: reorg_layer, + include_directories: include_directories('./') +) + yolo_sources = [ 'main.cpp', - 'det_dataloader.cpp' + 'det_dataloader.cpp', + layer_reorg_src ] yolo_dependencies = [app_utils_dep, nntrainer_dep, - nntrainer_ccapi_dep + nntrainer_ccapi_dep, + reorg_layer_dep ] e = executable('nntrainer_yolov2', diff --git a/Applications/YOLO/jni/reorg_layer.cpp b/Applications/YOLO/jni/reorg_layer.cpp new file mode 100644 index 0000000..9ccac8d --- /dev/null +++ b/Applications/YOLO/jni/reorg_layer.cpp @@ -0,0 +1,121 @@ +// SPDX-License-Identifier: Apache-2.0 +/** + * Copyright (C) 2023 Seungbaek Hong + * + * @file reorganization.cpp + * @date 06 April 2023 + * @todo support in-place operation. we can get channel, height, width + * coordinate from index of buffer memory. 
+ *       coordinates from the index of the buffer memory, and then use the
+ *       reorganizePos and restorePos functions.
+ * @brief This file contains the re-organization layer for the YOLO v2 example
+ * @see https://github.com/nnstreamer/nntrainer
+ * @author Seungbaek Hong
+ * @bug No known bugs except for NYI items
+ */
+
+#include <iostream>
+
+#include "reorg_layer.h"
+
+namespace custom {
+
+static constexpr size_t SINGLE_INOUT_IDX = 0;
+
+namespace ReorgOp {
+
+/**
+ * @brief re-organize tensor
+ * @return output coordinate of reorganized tensor
+ */
+int reorg(int b, int c, int h, int w, int batch, int channel, int height,
+          int width) {
+  int out_c = channel / 4;
+  int c2 = c % out_c;
+  int offset = c / out_c;
+  int w2 = w * 2 + offset % 2;
+  int h2 = h * 2 + offset / 2;
+  int out_index = w2 + width * 2 * (h2 + height * 2 * (c2 + out_c * b));
+  return out_index;
+}
+} // namespace ReorgOp
+
+void ReorgLayer::finalize(nntrainer::InitLayerContext &context) {
+  std::vector<nntrainer::TensorDim> dim = context.getInputDimensions();
+
+  for (unsigned int i = 0; i < dim.size(); ++i) {
+    if (dim[i].getDataLen() == 0) {
+      throw std::invalid_argument("Input dimension is not set");
+    } else {
+      dim[i].channel(dim[i].channel() * 4);
+      dim[i].height(dim[i].height() / 2);
+      dim[i].width(dim[i].width() / 2);
+    }
+  }
+
+  context.setOutputDimensions(dim);
+}
+
+void ReorgLayer::forwarding(nntrainer::RunLayerContext &context,
+                            bool training) {
+  nntrainer::Tensor &in = context.getInput(SINGLE_INOUT_IDX);
+  nntrainer::Tensor &out = context.getOutput(SINGLE_INOUT_IDX);
+
+  for (int b = 0; b < (int)in.batch(); b++) {
+    for (int c = 0; c < (int)in.channel(); c++) {
+      for (int h = 0; h < (int)in.height(); h++) {
+        for (int w = 0; w < (int)in.width(); w++) {
+          int out_idx =
+            w + in.width() * (h + in.height() * (c + in.channel() * b));
+          int in_idx = ReorgOp::reorg(b, c, h, w, in.batch(), in.channel(),
+                                      in.height(), in.width());
+          out.getData()[out_idx] = in.getValue(in_idx);
+        }
+      }
+    }
+  }
+}
+
+void ReorgLayer::calcDerivative(nntrainer::RunLayerContext &context) {
+  const nntrainer::Tensor &derivative_ =
+    context.getIncomingDerivative(SINGLE_INOUT_IDX);
+
+  nntrainer::Tensor &dx = context.getOutgoingDerivative(SINGLE_INOUT_IDX);
+
+  for (int b = 0; b < (int)derivative_.batch(); b++) {
+    for (int c = 0; c < (int)derivative_.channel(); c++) {
+      for (int h = 0; h < (int)derivative_.height(); h++) {
+        for (int w = 0; w < (int)derivative_.width(); w++) {
+          int in_idx =
+            w + derivative_.width() *
+                  (h + derivative_.height() * (c + derivative_.channel() * b));
+          int out_idx = ReorgOp::reorg(
+            b, c, h, w, derivative_.batch(), derivative_.channel(),
+            derivative_.height(), derivative_.width());
+          dx.getData()[out_idx] = derivative_.getValue(in_idx);
+        }
+      }
+    }
+  }
+}
+
+#ifdef PLUGGABLE
+
+nntrainer::Layer *create_reorg_layer() {
+  auto layer = new ReorgLayer();
+  std::cout << "reorg created\n";
+  return layer;
+}
+
+void destroy_reorg_layer(nntrainer::Layer *layer) {
+  std::cout << "reorg deleted\n";
+  delete layer;
+}
+
+extern "C" {
+nntrainer::LayerPluggable ml_train_layer_pluggable{create_reorg_layer,
+                                                   destroy_reorg_layer};
+}
+
+#endif
+
+} // namespace custom
diff --git a/Applications/YOLO/jni/reorg_layer.h b/Applications/YOLO/jni/reorg_layer.h
new file mode 100644
index 0000000..5d9f78b
--- /dev/null
+++ b/Applications/YOLO/jni/reorg_layer.h
@@ -0,0 +1,83 @@
+// SPDX-License-Identifier: Apache-2.0
+/**
+ * Copyright (C) 2023 Seungbaek Hong
+ *
+ * @file reorg_layer.h
+ * @date 4 April 2023
+ * @brief This file contains the re-organization layer for the YOLO v2 example
+ * @see https://github.com/nnstreamer/nntrainer
+ * @author Seungbaek Hong
+ * @bug No known bugs except for NYI items
+ *
+ */
+
+#ifndef __REORGANIZATION_LAYER_H__
+#define __REORGANIZATION_LAYER_H__
+
+#include <layer_context.h>
+#include <layer_devel.h>
+#include <node_exporter.h>
+#include <utility>
+
+namespace custom {
+
+/**
+ * @brief A re-organization layer for YOLO v2.
+ *
+ */
+class ReorgLayer final : public nntrainer::Layer {
+public:
+  /**
+   * @brief Construct a new Reorg Layer object
+   *
+   */
+  ReorgLayer() : Layer() {}
+
+  /**
+   * @brief Destroy the Reorg Layer object
+   *
+   */
+  ~ReorgLayer() {}
+
+  /**
+   * @copydoc Layer::finalize(InitLayerContext &context)
+   */
+  void finalize(nntrainer::InitLayerContext &context) override;
+
+  /**
+   * @copydoc Layer::forwarding(RunLayerContext &context, bool training)
+   */
+  void forwarding(nntrainer::RunLayerContext &context, bool training) override;
+
+  /**
+   * @copydoc Layer::calcDerivative(RunLayerContext &context)
+   */
+  void calcDerivative(nntrainer::RunLayerContext &context) override;
+
+  /**
+   * @copydoc bool supportBackwarding() const
+   */
+  bool supportBackwarding() const override { return true; };
+
+  /**
+   * @copydoc Layer::exportTo(Exporter &exporter, ExportMethods method)
+   */
+  void exportTo(nntrainer::Exporter &exporter,
+                const ml::train::ExportMethods &method) const override{};
+
+  /**
+   * @copydoc Layer::getType()
+   */
+  const std::string getType() const override { return ReorgLayer::type; };
+
+  /**
+   * @copydoc Layer::setProperty(const std::vector<std::string> &values)
+   */
+  void setProperty(const std::vector<std::string> &values) override{};
+
+  inline static const std::string type = "reorg";
+};
+
+} // namespace custom
+
+#endif /* __REORGANIZATION_LAYER_H__ */
-- 
2.7.4
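
Note (outside the patch above): the snippet below is a minimal Python sketch of the passthrough re-organization that YoloV2.forward performs after this change. It repeats the same view/permute/view sequence and only checks the shape bookkeeping: assuming a 416x416 input, the (B, 64, 26, 26) conv_b feature becomes (B, 256, 13, 13), so concatenating it with the (B, 1024, 13, 13) conv_a7 output gives the 1280 input channels that conv_out1 now expects (previously 1088 = 1024 + 64 with the average-pooling path). The tensors feature_a and feature_b are hypothetical stand-ins, not outputs of the real model.

import torch

def reorg(x: torch.Tensor) -> torch.Tensor:
    # Same view/permute/view sequence as the new passthrough branch in
    # YoloV2.forward: (B, C, H, W) -> (B, 4*C, H/2, W/2).
    b, c, h, w = x.size()
    x = x.view(b, c // 4, h, 2, w, 2).contiguous()
    x = x.permute(0, 3, 5, 1, 2, 4).contiguous()
    return x.view(b, -1, h // 2, w // 2)

feature_a = torch.randn(1, 1024, 13, 13)  # hypothetical conv_a7 output
feature_b = torch.randn(1, 64, 26, 26)    # hypothetical conv_b output
out_b = reorg(feature_b)
assert out_b.shape == (1, 256, 13, 13)
merged = torch.cat((feature_a, out_b), dim=1)
assert merged.shape == (1, 1280, 13, 13)  # matches nn.Conv2d(1280, 1024, ...) in conv_out1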