From: hyeonseok lee Date: Thu, 23 Mar 2023 10:38:20 +0000 (+0900) Subject: [Application] match nntrainer and pytorch yolo model X-Git-Tag: accepted/tizen/8.0/unified/20231005.093407~22 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=c2642986ef05a223f59522b22727c6d66b4079f3;p=platform%2Fcore%2Fml%2Fnntrainer.git [Application] match nntrainer and pytorch yolo model - Match option value like epsilon, momentum - This commit will match nntrainer yolo v2 output with pytorch yolo v2 Signed-off-by: hyeonseok lee --- diff --git a/Applications/YOLO/PyTorch/main.py b/Applications/YOLO/PyTorch/main.py index 1750c7b..c73f895 100644 --- a/Applications/YOLO/PyTorch/main.py +++ b/Applications/YOLO/PyTorch/main.py @@ -37,25 +37,27 @@ out_size = 13 num_classes = 92 num_anchors = 5 -epochs = 1000 -batch_size = 8 +epochs = 3 +batch_size = 4 -train_img_dir = './custom_dataset/images/*' -train_ann_dir = './custom_dataset/annotations/*' -valid_img_dir = './custom_dataset_val/images/*' -valid_ann_dir = './custom_dataset_val/annotations/*' +train_img_dir = 'TRAIN_DIR/images/*' +train_ann_dir = 'TRAIN_DIR/annotations/*' +valid_img_dir = 'VALIDATION_DIR/images/*' +valid_ann_dir = 'VALIDATION_DIR/annotations/*' # load data train_dataset = YOLODataset(train_img_dir, train_ann_dir) train_loader = DataLoader(train_dataset, batch_size=batch_size, collate_fn=collate_db, shuffle=True, drop_last=True) valid_dataset = YOLODataset(valid_img_dir, valid_ann_dir) -valid_loader = DataLoader(valid_dataset, batch_size=batch_size, collate_fn=collate_db, shuffle=False, drop_last=False) +valid_loader = DataLoader(valid_dataset, batch_size=batch_size, collate_fn=collate_db, shuffle=False, drop_last=True) # set model, loss and optimizer model = YoloV2(num_classes=num_classes) criterion = YoloV2_LOSS(num_classes=num_classes) optimizer = optim.Adam(model.parameters(), lr=1e-3) -scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=10, eta_min=0) +# scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=10, eta_min=0) + +save_bin(model, 'init_model') # train model best_loss = 1e+10 @@ -84,7 +86,7 @@ for epoch in range(epochs): # back prop loss.backward() optimizer.step() - scheduler.step() + # scheduler.step() epoch_train_loss += loss.item() for idx, (img, bbox, cls) in enumerate(valid_loader): diff --git a/Applications/YOLO/PyTorch/yolo_loss.py b/Applications/YOLO/PyTorch/yolo_loss.py index 4b73ee5..12f9557 100644 --- a/Applications/YOLO/PyTorch/yolo_loss.py +++ b/Applications/YOLO/PyTorch/yolo_loss.py @@ -160,20 +160,19 @@ class YoloV2_LOSS(nn.Module): torch.LongTensor(cls_gt[i]) ) - bbox_built.append(_bbox_built.numpy()) - bbox_mask.append(_bbox_mask.numpy()) - iou_built.append(_iou_built.numpy()) - iou_mask.append(_iou_mask.numpy()) - cls_built.append(_cls_built.numpy()) - cls_mask.append(_cls_mask.numpy()) + bbox_built.append(_bbox_built) + bbox_mask.append(_bbox_mask) + iou_built.append(_iou_built) + iou_mask.append(_iou_mask) + cls_built.append(_cls_built) + cls_mask.append(_cls_mask) - bbox_built, bbox_mask, iou_built, iou_mask, cls_built, cls_mask =\ - torch.FloatTensor(np.array(bbox_built)),\ - torch.FloatTensor(np.array(bbox_mask)),\ - torch.FloatTensor(np.array(iou_built)),\ - torch.FloatTensor(np.array(iou_mask)),\ - torch.FloatTensor(np.array(cls_built)),\ - torch.FloatTensor(np.array(cls_mask)) + bbox_built = torch.stack(bbox_built) + bbox_mask = torch.stack(bbox_mask) + iou_built = torch.stack(iou_built) + iou_mask = torch.stack(iou_mask) + cls_built = torch.stack(cls_built) + cls_mask = torch.stack(cls_mask) return bbox_built, iou_built, cls_built, bbox_mask, iou_mask, cls_mask @@ -221,7 +220,7 @@ class YoloV2_LOSS(nn.Module): _cls_mask[cell_idx, best_anchors, :] = 1 # set confidence score of gt - _iou_built = calculate_iou(_bbox_pred.reshape(-1, 4), _bbox_built.view(-1, 4)).detach() + _iou_built = calculate_iou(_bbox_pred.reshape(-1, 4), _bbox_built.view(-1, 4)).detach() _iou_built = _iou_built.view(hw, num_anchors, 1) _iou_mask[cell_idx, best_anchors, :] = 1 diff --git a/Applications/YOLO/jni/Android.mk b/Applications/YOLO/jni/Android.mk index 5fefab9..9f0dfb7 100644 --- a/Applications/YOLO/jni/Android.mk +++ b/Applications/YOLO/jni/Android.mk @@ -50,7 +50,7 @@ LOCAL_ARM_MODE := arm LOCAL_MODULE := nntrainer_yolo LOCAL_LDLIBS := -llog -landroid -fopenmp -LOCAL_SRC_FILES := main.cpp det_dataloader.cpp reorg_layer.cpp +LOCAL_SRC_FILES := main.cpp det_dataloader.cpp yolo_v2_loss.cpp reorg_layer.cpp LOCAL_SHARED_LIBRARIES := nntrainer ccapi-nntrainer LOCAL_C_INCLUDES += $(NNTRAINER_INCLUDES) diff --git a/Applications/YOLO/jni/main.cpp b/Applications/YOLO/jni/main.cpp index 92a76bf..7c470e8 100644 --- a/Applications/YOLO/jni/main.cpp +++ b/Applications/YOLO/jni/main.cpp @@ -19,11 +19,13 @@ #include #include +#include +#include #include #include #include -#include +#include "yolo_v2_loss.h" #include #include @@ -32,7 +34,19 @@ using LayerHandle = std::shared_ptr; using ModelHandle = std::unique_ptr; using UserDataType = std::unique_ptr; -const int num_classes = 92; +const unsigned int ANCHOR_NUMBER = 5; + +const unsigned int MAX_OBJECT_NUMBER = 4; +const unsigned int CLASS_NUMBER = 4; +const unsigned int GRID_HEIGHT_NUMBER = 13; +const unsigned int GRID_WIDTH_NUMBER = 13; +const unsigned int IMAGE_HEIGHT_SIZE = 416; +const unsigned int IMAGE_WIDTH_SIZE = 416; +const unsigned int BATCH_SIZE = 4; +const unsigned int EPOCHS = 3; +const char *TRAIN_DIR_PATH = "TRAIN_DIR_PATH"; +const char *VALIDATION_DIR_PATH = "VALIDATION_DIR_PATH"; +const std::string MODEL_INIT_BIN_PATH = "MODEL_INIT_BIN_PATH"; int trainData_cb(float **input, float **label, bool *last, void *user_data) { auto data = reinterpret_cast(user_data); @@ -135,9 +149,9 @@ std::vector yoloBlock(const std::string &block_name, LayerHandle a1 = createConv("a1", kernel_size, 1, "same", input_name); if (downsample) { - LayerHandle a2 = - createLayer("batch_normalization", - {with_name("a2"), withKey("activation", "leaky_relu")}); + LayerHandle a2 = createLayer("batch_normalization", + {with_name("a2"), withKey("momentum", "0.9"), + withKey("activation", "leaky_relu")}); LayerHandle a3 = createLayer( "pooling2d", {withKey("name", block_name), withKey("stride", {2, 2}), @@ -146,8 +160,9 @@ std::vector yoloBlock(const std::string &block_name, return {a1, a2, a3}; } else { LayerHandle a2 = - createLayer("batch_normalization", {withKey("name", block_name), - withKey("activation", "leaky_relu")}); + createLayer("batch_normalization", + {withKey("name", block_name), withKey("momentum", "0.9"), + withKey("activation", "leaky_relu")}); return {a1, a2}; } @@ -167,7 +182,10 @@ ModelHandle YOLO() { std::vector layers; layers.push_back(createLayer( - "input", {withKey("name", "input0"), withKey("input_shape", "3:416:416")})); + "input", + {withKey("name", "input0"), + withKey("input_shape", "3:" + std::to_string(IMAGE_HEIGHT_SIZE) + ":" + + std::to_string(IMAGE_WIDTH_SIZE))})); std::vector> blocks; @@ -212,7 +230,7 @@ ModelHandle YOLO() { blocks.push_back( {createLayer("conv2d", { withKey("name", "conv_out2"), - withKey("filters", (5 + num_classes) * 5), + withKey("filters", 5 * (5 + CLASS_NUMBER)), withKey("kernel_size", {1, 1}), withKey("stride", {1, 1}), withKey("padding", "same"), @@ -223,6 +241,30 @@ ModelHandle YOLO() { layers.insert(layers.end(), block.begin(), block.end()); } + layers.push_back(createLayer("permute", { + withKey("name", "permute"), + withKey("direction", {2, 3, 1}), + })); + + layers.push_back(createLayer( + "reshape", + { + withKey("name", "reshape"), + withKey("target_shape", + std::to_string(GRID_HEIGHT_NUMBER * GRID_WIDTH_NUMBER) + ":" + + std::to_string(ANCHOR_NUMBER) + ":" + + std::to_string(5 + CLASS_NUMBER)), + })); + + layers.push_back(createLayer( + "yolo_v2_loss", { + withKey("name", "yolo_v2_loss"), + withKey("max_object_number", MAX_OBJECT_NUMBER), + withKey("class_number", CLASS_NUMBER), + withKey("grid_height_number", GRID_HEIGHT_NUMBER), + withKey("grid_width_number", GRID_WIDTH_NUMBER), + })); + for (auto &layer : layers) { model->addLayer(layer); } @@ -237,69 +279,54 @@ int main(int argc, char *argv[]) { std::cout << "started computation at " << std::ctime(&start_time) << std::endl; - // set training config and print it - unsigned int data_size = 1; - unsigned int batch_size = 1; - unsigned int data_split = 1; - unsigned int epochs = 1; - std::cout << "batch_size: " << batch_size << " data_split: " << data_split - << " epoch: " << epochs << std::endl; - try { auto &app_context = nntrainer::AppContext::Global(); - app_context.registerFactory(nntrainer::createLayer); + app_context.registerFactory( + nntrainer::createLayer); } catch (std::invalid_argument &e) { std::cerr << "failed to register factory, reason: " << e.what() << std::endl; return 1; } - // create train and validation data - std::array user_datas; - try { - const char *train_dir = "./train_dir/"; - const char *valid_dir = "./valid_dir/"; - const int max_num_label = 5; - const int channel = 3; - const int width = 416; - const int height = 416; - user_datas = createDetDataGenerator(train_dir, valid_dir, max_num_label, - channel, width, height); - } catch (const std::exception &e) { - std::cerr << "uncaught error while creating data generator! details: " - << e.what() << std::endl; - return EXIT_FAILURE; - } - auto &[train_user_data, valid_user_data] = user_datas; + // set training config and print it + std::cout << "batch_size: " << BATCH_SIZE << " epochs: " << EPOCHS + << std::endl; try { - auto dataset_train = ml::train::createDataset( - ml::train::DatasetType::GENERATOR, trainData_cb, train_user_data.get()); - auto dataset_valid = ml::train::createDataset( - ml::train::DatasetType::GENERATOR, validData_cb, valid_user_data.get()); - // create YOLO v2 model ModelHandle model = YOLO(); - model->setProperty({withKey("batch_size", batch_size), - withKey("epochs", epochs), + model->setProperty({withKey("batch_size", BATCH_SIZE), + withKey("epochs", EPOCHS), withKey("save_path", "yolov2.bin")}); // create optimizer - auto optimizer = - ml::train::createOptimizer("adam", {"learning_rate=0.001"}); + auto optimizer = ml::train::createOptimizer( + "adam", {"learning_rate=0.001", "epsilon=1e-8", "torch_ref=true"}); model->setOptimizer(std::move(optimizer)); // compile and initialize model model->compile(); model->initialize(); + model->load(MODEL_INIT_BIN_PATH); + + // create train and validation data + std::array user_datas; + user_datas = createDetDataGenerator(TRAIN_DIR_PATH, VALIDATION_DIR_PATH, + MAX_OBJECT_NUMBER, 3, IMAGE_HEIGHT_SIZE, + IMAGE_WIDTH_SIZE); + auto &[train_user_data, valid_user_data] = user_datas; + + auto dataset_train = ml::train::createDataset( + ml::train::DatasetType::GENERATOR, trainData_cb, train_user_data.get()); + auto dataset_valid = ml::train::createDataset( + ml::train::DatasetType::GENERATOR, validData_cb, valid_user_data.get()); model->setDataset(ml::train::DatasetModeType::MODE_TRAIN, std::move(dataset_train)); model->setDataset(ml::train::DatasetModeType::MODE_VALID, std::move(dataset_valid)); - model->summarize(std::cout, ML_TRAIN_SUMMARY_MODEL); - model->train(); } catch (const std::exception &e) { std::cerr << "uncaught error while running! details: " << e.what() diff --git a/Applications/YOLO/jni/yolo_v2_loss.cpp b/Applications/YOLO/jni/yolo_v2_loss.cpp index f0a5b01..6eeb9ed 100644 --- a/Applications/YOLO/jni/yolo_v2_loss.cpp +++ b/Applications/YOLO/jni/yolo_v2_loss.cpp @@ -643,13 +643,14 @@ void YoloV2LossLayer::calcDerivative(nntrainer::RunLayerContext &context) { bbox_w_pred_grad.copyData(splitted_bbox_pred_grad[2]); bbox_h_pred_grad.copyData(splitted_bbox_pred_grad[3]); - std::vector bbox_pred_iou_grad = - calc_iou_grad(confidence_gt_grad, bbox1_width, bbox1_height, is_xy_min_max, - intersection_width, intersection_height, unions); - bbox_x_pred_grad.add_i(bbox_pred_iou_grad[0]); - bbox_y_pred_grad.add_i(bbox_pred_iou_grad[1]); - bbox_w_pred_grad.add_i(bbox_pred_iou_grad[2]); - bbox_h_pred_grad.add_i(bbox_pred_iou_grad[3]); + // std::vector bbox_pred_iou_grad = + // calc_iou_grad(confidence_gt_grad, bbox1_width, bbox1_height, + // is_xy_min_max, + // intersection_width, intersection_height, unions); + // bbox_x_pred_grad.add_i(bbox_pred_iou_grad[0]); + // bbox_y_pred_grad.add_i(bbox_pred_iou_grad[1]); + // bbox_w_pred_grad.add_i(bbox_pred_iou_grad[2]); + // bbox_h_pred_grad.add_i(bbox_pred_iou_grad[3]); /** * @brief calculate gradient for applying anchors to bounding box