homotopy_weighting=False,
log_D_trick=False,
unjoined_lr_loss=False,
+ uncertainty_penalty=1.0,
**kwargs
):
super(BatchLRLoss, self).__init__(model, name, input_record, **kwargs)
assert not (log_D_trick and unjoined_lr_loss)
self.log_D_trick = log_D_trick
self.unjoined_lr_loss = unjoined_lr_loss
+ assert uncertainty_penalty >= 0
+ self.uncertainty_penalty = uncertainty_penalty
self.tags.update([Tags.EXCLUDE_FROM_PREDICTION])
)
else:
loss = xent
+
+ if 'log_variance' in self.input_record.fields:
+ # mean (0.5 * exp(-s) * loss + 0.5 * penalty * s)
+ log_variance_blob = self.input_record.log_variance()
+
+ log_variance_blob = net.ExpandDims(
+ log_variance_blob, net.NextScopedBlob('expanded_log_variance'),
+ dims=[1]
+ )
+
+ neg_log_variance_blob = net.Negative(
+ [log_variance_blob],
+ net.NextScopedBlob('neg_log_variance')
+ )
+
+ # enforce less than 88 to avoid OverflowError
+ neg_log_variance_blob = net.Clip(
+ [neg_log_variance_blob],
+ net.NextScopedBlob('clipped_neg_log_variance'),
+ max=88.0
+ )
+
+ exp_neg_log_variance_blob = net.Exp(
+ [neg_log_variance_blob],
+ net.NextScopedBlob('exp_neg_log_variance')
+ )
+
+ exp_neg_log_variance_loss_blob = net.Mul(
+ [exp_neg_log_variance_blob, loss],
+ net.NextScopedBlob('exp_neg_log_variance_loss')
+ )
+
+ penalized_uncertainty = net.Scale(
+ log_variance_blob, net.NextScopedBlob("penalized_unceratinty"),
+ scale=float(self.uncertainty_penalty)
+ )
+
+ loss_2x = net.Add(
+ [exp_neg_log_variance_loss_blob, penalized_uncertainty],
+ net.NextScopedBlob('loss')
+ )
+ loss = net.Scale(loss_2x, net.NextScopedBlob("loss"), scale=0.5)
+
if 'weight' in self.input_record.fields:
weight_blob = self.input_record.weight()
if self.input_record.weight.field_type().base != np.float32:
loss = self.model.BatchLRLoss(input_record)
self.assertEqual(schema.Scalar((np.float32, tuple())), loss)
+ def testBatchLRLossWithUncertainty(self):
+ input_record = self.new_record(schema.Struct(
+ ('label', schema.Scalar((np.float64, (1,)))),
+ ('logit', schema.Scalar((np.float32, (2,)))),
+ ('weight', schema.Scalar((np.float64, (1,)))),
+ ('log_variance', schema.Scalar((np.float64, (1,)))),
+ ))
+ loss = self.model.BatchLRLoss(input_record)
+ self.assertEqual(schema.Scalar((np.float32, tuple())), loss)
+
def testMarginRankLoss(self):
input_record = self.new_record(schema.Struct(
('pos_prediction', schema.Scalar((np.float32, (1,)))),