try to enable uncertainty for lr loss (#17236)
authorXing Wang <xingwang@fb.com>
Thu, 11 Apr 2019 14:27:46 +0000 (07:27 -0700)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Thu, 11 Apr 2019 14:35:19 +0000 (07:35 -0700)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/17236

Following the paper at https://papers.nips.cc/paper/7141-what-uncertainties-do-we-need-in-bayesian-deep-learning-for-computer-vision.pdf, approximate the classification case with the regression formulation: weight the LR loss by the inverse of a learned variance, and regularize the predicted log-variance with a tunable penalty parameter lambda.
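
In plain NumPy, the objective this diff implements is roughly the following (an illustrative sketch of the paper's learned-uncertainty loss, not the Caffe2 code itself; xent and log_variance stand in for the blobs of the same names):

    import numpy as np

    def uncertainty_lr_loss(xent, log_variance, uncertainty_penalty=1.0):
        # s = log(sigma^2); cap -s at 88 so exp(-s) stays inside float32 range
        neg_s = np.minimum(-log_variance, 88.0)
        # mean over the batch of 0.5 * exp(-s) * xent + 0.5 * penalty * s
        return np.mean(0.5 * np.exp(neg_s) * xent
                       + 0.5 * uncertainty_penalty * log_variance)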

Reviewed By: chocjy

Differential Revision: D14077106

fbshipit-source-id: 4405d8995cebdc7275a0dd07857d32a8915d78ef
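
As a usage sketch (hypothetical call site inside the layers test harness; field shapes mirror the new unit test below):

    input_record = self.new_record(schema.Struct(
        ('label', schema.Scalar((np.float64, (1,)))),
        ('logit', schema.Scalar((np.float32, (2,)))),
        ('log_variance', schema.Scalar((np.float64, (1,)))),
    ))
    # uncertainty_penalty weights the 0.5 * penalty * s regularization term
    loss = self.model.BatchLRLoss(input_record, uncertainty_penalty=1.0)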

caffe2/python/layers/batch_lr_loss.py
caffe2/python/layers_test.py

caffe2/python/layers/batch_lr_loss.py
index 2d9cc80..932aab1 100644
@@ -29,6 +29,7 @@ class BatchLRLoss(ModelLayer):
         homotopy_weighting=False,
         log_D_trick=False,
         unjoined_lr_loss=False,
+        uncertainty_penalty=1.0,
         **kwargs
     ):
         super(BatchLRLoss, self).__init__(model, name, input_record, **kwargs)
@@ -60,6 +61,8 @@ class BatchLRLoss(ModelLayer):
         assert not (log_D_trick and unjoined_lr_loss)
         self.log_D_trick = log_D_trick
         self.unjoined_lr_loss = unjoined_lr_loss
+        assert uncertainty_penalty >= 0
+        self.uncertainty_penalty = uncertainty_penalty
 
         self.tags.update([Tags.EXCLUDE_FROM_PREDICTION])
 
@@ -170,6 +173,49 @@ class BatchLRLoss(ModelLayer):
             )
         else:
             loss = xent
+
+        if 'log_variance' in self.input_record.fields:
+            # computes mean(0.5 * exp(-s) * loss + 0.5 * penalty * s), where s = log_variance
+            log_variance_blob = self.input_record.log_variance()
+
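+            # add a trailing dim so log_variance aligns elementwise with the loss blob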
+            log_variance_blob = net.ExpandDims(
+                log_variance_blob, net.NextScopedBlob('expanded_log_variance'),
+                dims=[1]
+            )
+
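+            # negate s so the Exp below computes exp(-s)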
+            neg_log_variance_blob = net.Negative(
+                [log_variance_blob],
+                net.NextScopedBlob('neg_log_variance')
+            )
+
+            # cap -s at 88 so the Exp below cannot overflow float32 (exp(88) ~ 1.7e38)
+            neg_log_variance_blob = net.Clip(
+                [neg_log_variance_blob],
+                net.NextScopedBlob('clipped_neg_log_variance'),
+                max=88.0
+            )
+
+            exp_neg_log_variance_blob = net.Exp(
+                [neg_log_variance_blob],
+                net.NextScopedBlob('exp_neg_log_variance')
+            )
+
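+            # first term: exp(-s) * cross-entropy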
+            exp_neg_log_variance_loss_blob = net.Mul(
+                [exp_neg_log_variance_blob, loss],
+                net.NextScopedBlob('exp_neg_log_variance_loss')
+            )
+
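+            # second term: penalty * s, regularizing the predicted variance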
+            penalized_uncertainty = net.Scale(
+                log_variance_blob, net.NextScopedBlob("penalized_uncertainty"),
+                scale=float(self.uncertainty_penalty)
+            )
+
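+            # sum the two terms, then halve: 0.5 * (exp(-s) * loss + penalty * s)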
+            loss_2x = net.Add(
+                [exp_neg_log_variance_loss_blob, penalized_uncertainty],
+                net.NextScopedBlob('loss')
+            )
+            loss = net.Scale(loss_2x, net.NextScopedBlob("loss"), scale=0.5)
+
         if 'weight' in self.input_record.fields:
             weight_blob = self.input_record.weight()
             if self.input_record.weight.field_type().base != np.float32:
caffe2/python/layers_test.py
index b429b43..0310603 100644
@@ -776,6 +776,16 @@ class TestLayers(LayersTestCase):
         loss = self.model.BatchLRLoss(input_record)
         self.assertEqual(schema.Scalar((np.float32, tuple())), loss)
 
+    def testBatchLRLossWithUncertainty(self):
+        input_record = self.new_record(schema.Struct(
+            ('label', schema.Scalar((np.float64, (1,)))),
+            ('logit', schema.Scalar((np.float32, (2,)))),
+            ('weight', schema.Scalar((np.float64, (1,)))),
+            ('log_variance', schema.Scalar((np.float64, (1,)))),
+        ))
+        loss = self.model.BatchLRLoss(input_record)
+        self.assertEqual(schema.Scalar((np.float32, tuple())), loss)
+
     def testMarginRankLoss(self):
         input_record = self.new_record(schema.Struct(
             ('pos_prediction', schema.Scalar((np.float32, (1,)))),