##
# @brief calculate iou between two boxes list
-def calculate_iou(bbox1, bbox2):
+def calculate_iou(bbox1, bbox2):
"""
@param bbox1 shape(numb_of_bbox, 4), it contains x, y, w, h
@param bbox2 shape(numb_of_bbox, 4), it contains x, y, w, h
self.num_classes = num_classes
self.img_shape = img_shape
self.outsize = outsize
+ self.hook = dict()
self.anchors = torch.FloatTensor(
[(1.3221, 1.73145),
)
self.mse = nn.MSELoss()
- self.bbox_loss, self.iou_loss, self.cls_loss = None, None, None
+ self.bbox_loss, self.iou_loss, self.cls_loss = None, None, None
+
+ ##
+ # @brief function to track gradients of non-leaf variables.
+ def hook_variable(self, name, var):
+     """Debugging helper: store ``var`` under ``name`` and retain its gradient.
+
+     Do not use this function when training. It is for debugging.
+
+     @param name key under which the tensor is stored in ``self.hook``
+     @param var tensor that is marked as requiring grad and asked to retain
+                its gradient
+     """
+     self.hook[name] = var
+     # requires_grad_() returns the tensor itself, so ``tracked`` is the same
+     # object that was just stored in self.hook.
+     tracked = var.requires_grad_()
+     tracked.retain_grad()
+
+ ##
+ # @brief function to print gradients of non-leaf variables.
+ def print_hook_variables(self):
+     """Print every hooked gradient entry whose magnitude is at least 1e-3.
+
+     Do not use this function when training. It is for debugging.
+
+     A hooked variable whose ``grad`` is still None (for example because no
+     backward pass has run since it was hooked) is reported as such instead
+     of raising AttributeError.
+
+     @raise ValueError if a hooked gradient is not 4-dimensional
+     """
+     for k, var in self.hook.items():
+         print("gradients of variable {}:".format(k))
+         grad = var.grad
+         if grad is None:
+             print("(no gradient available)")
+         elif grad.dim() != 4:
+             raise ValueError(
+                 "expected a 4-dimensional gradient for {}, got shape {}"
+                 .format(k, tuple(grad.shape)))
+         else:
+             # Select the significant entries with one tensor operation
+             # rather than calling .item() on every element from nested
+             # Python loops. torch.nonzero returns the indices in row-major
+             # order, i.e. the order nested b/c/h/w loops would visit them.
+             significant = torch.nonzero(grad.abs() >= 1e-3)
+             for b, c, h, w in significant.tolist():
+                 print("(b: {}, c: {}, h: {}, w: {}) = {}"
+                       .format(b, c, h, w, grad[b, c, h, w]))
+         print("=" * 20)
def forward(self, bbox_pred, iou_pred, prob_pred, bbox_gt, cls_gt):
"""
@param cls_gt shape(batch_size, num_bbox, 1)
@return loss shape(1,)
"""
+ self.hook_variable("bbox_pred", bbox_pred)
bbox_pred = self.apply_anchors_to_bbox(bbox_pred)
-
+
bbox_built, iou_built, cls_built, bbox_mask, iou_mask, cls_mask =\
self._build_target(bbox_pred, bbox_gt, cls_gt)
return self.bbox_loss * 5 + self.iou_loss + self.cls_loss
- def apply_anchors_to_bbox(self, bbox_pred):
+ def apply_anchors_to_bbox(self, bbox_pred):
"""
@param bbox_pred shape(batch_size, cell_h x cell_w, num_anchors, 4)
@return bbox_pred shape(batch_size, cell_h x cell_w, num_anchors, 4)
"""
+ # Column 0 and column 1 of self.anchors, each reshaped to (N, 1).
anchor_w = self.anchors[:, 0].contiguous().view(-1, 1)
anchor_h = self.anchors[:, 1].contiguous().view(-1, 1)
- bbox_pred[:, :, :, 2:3] = torch.sqrt(bbox_pred[:, :, :, 2:3] * anchor_w)
- bbox_pred[:, :, :, 3:4] = torch.sqrt(bbox_pred[:, :, :, 3:4] * anchor_h)
- return bbox_pred
+ # Fill a clone instead of assigning into bbox_pred, so the tensor passed in
+ # by the caller is left unmodified; only the clone is returned.
+ bbox_pred_tmp = bbox_pred.clone()
+ bbox_pred_tmp[:, :, :, 2:3] = torch.sqrt(bbox_pred[:, :, :, 2:3] * anchor_w)
+ bbox_pred_tmp[:, :, :, 3:4] = torch.sqrt(bbox_pred[:, :, :, 3:4] * anchor_h)
+ return bbox_pred_tmp
def _build_target(self, bbox_pred, bbox_gt, cls_gt):
"""