[Application] Update yolo example of torch for tracking gradients
authorSeungbaek Hong <sb92.hong@samsung.com>
Wed, 29 Mar 2023 09:41:22 +0000 (18:41 +0900)
committerjijoong.moon <jijoong.moon@samsung.com>
Tue, 4 Apr 2023 01:39:05 +0000 (10:39 +0900)
For tracking the gradients of Loss class,
I removed in-place operation in the Loss class.

And, I added hook_variable function for tracking specific tensors.

We can register specific variable with name using hook_variable
function, then we can check the gradient values using
print_hook_variable function after backwarding.

Signed-off-by: Seungbaek Hong <sb92.hong@samsung.com>
Applications/YOLO/PyTorch/yolo_loss.py

index c3754cc..4b73ee5 100644 (file)
@@ -14,7 +14,7 @@ import numpy as np
 
 ##
 # @brief calculate iou between two boxes list
-def calculate_iou(bbox1, bbox2):    
+def calculate_iou(bbox1, bbox2):
     """
     @param bbox1 shape(numb_of_bbox, 4), it contains x, y, w, h
     @param bbox2 shape(numb_of_bbox, 4), it contains x, y, w, h
@@ -66,6 +66,7 @@ class YoloV2_LOSS(nn.Module):
         self.num_classes = num_classes
         self.img_shape = img_shape
         self.outsize = outsize
+        self.hook = dict()
         
         self.anchors = torch.FloatTensor(
             [(1.3221, 1.73145),
@@ -76,7 +77,30 @@ class YoloV2_LOSS(nn.Module):
         )
                 
         self.mse = nn.MSELoss()
-        self.bbox_loss, self.iou_loss, self.cls_loss = None, None, None       
+        self.bbox_loss, self.iou_loss, self.cls_loss = None, None, None
+    
+    ##
+    # @brief function to track gradients of non-leaf varibles.    
+    def hook_variable(self, name, var):
+        """ Do not use this function when training. It is for debugging. """
+        self.hook[name] = var
+        self.hook[name].requires_grad_().retain_grad()
+
+    ##
+    # @brief function to print gradients of non-leaf varibles.
+    def print_hook_variables(self):
+        """ Do not use this function when training. It is for debugging. """
+        for k, var in self.hook.items():
+            print("gradients of variable {}:".format(k))
+            batch, channel, height, width = var.grad.shape
+            for b in range(batch):
+                for c in range(channel):
+                    for h in range(height):
+                        for w in range(width):
+                            if torch.abs(var.grad[b, c, h, w]).item() >= 1e-3:
+                                print("(b: {}, c: {}, h: {}, w: {}) = {}"\
+                                      .format(b, c, h, w, var.grad[b, c, h, w]))
+            print("=" * 20)
         
     def forward(self, bbox_pred, iou_pred, prob_pred, bbox_gt, cls_gt):        
         """
@@ -87,8 +111,9 @@ class YoloV2_LOSS(nn.Module):
         @param cls_gt shape(batch_size, num_bbox, 1)
         @return loss shape(1,)
         """
+        self.hook_variable("bbox_pred", bbox_pred)
         bbox_pred = self.apply_anchors_to_bbox(bbox_pred)
-        
+
         bbox_built, iou_built, cls_built, bbox_mask, iou_mask, cls_mask =\
             self._build_target(bbox_pred, bbox_gt, cls_gt)
         
@@ -101,16 +126,17 @@ class YoloV2_LOSS(nn.Module):
         
         return self.bbox_loss * 5 + self.iou_loss + self.cls_loss
         
-    def apply_anchors_to_bbox(self, bbox_pred):        
+    def apply_anchors_to_bbox(self, bbox_pred):
         """
         @param bbox_pred shape(batch_size, cell_h x cell_w, num_anchors, 4)
         @return bbox_pred shape(batch_size, cell_h x cell_w, num_anchors, 4)    
         """
         anchor_w = self.anchors[:, 0].contiguous().view(-1, 1)
         anchor_h = self.anchors[:, 1].contiguous().view(-1, 1)
-        bbox_pred[:, :, :, 2:3] = torch.sqrt(bbox_pred[:, :, :, 2:3] * anchor_w)
-        bbox_pred[:, :, :, 3:4] = torch.sqrt(bbox_pred[:, :, :, 3:4] * anchor_h)        
-        return bbox_pred
+        bbox_pred_tmp = bbox_pred.clone()
+        bbox_pred_tmp[:, :, :, 2:3] = torch.sqrt(bbox_pred[:, :, :, 2:3] * anchor_w)
+        bbox_pred_tmp[:, :, :, 3:4] = torch.sqrt(bbox_pred[:, :, :, 3:4] * anchor_h)
+        return bbox_pred_tmp
     
     def _build_target(self, bbox_pred, bbox_gt, cls_gt):
         """