Add gradient norm target arg to wass gradient penalty function. This trick is usd...
authorA. Unique TensorFlower <gardener@tensorflow.org>
Tue, 13 Feb 2018 16:46:26 +0000 (08:46 -0800)
committerTensorFlower Gardener <gardener@tensorflow.org>
Tue, 13 Feb 2018 16:50:17 +0000 (08:50 -0800)
PiperOrigin-RevId: 185535584

tensorflow/contrib/gan/python/losses/python/losses_impl.py
tensorflow/contrib/gan/python/losses/python/losses_impl_test.py
tensorflow/contrib/gan/python/train.py

index 23a3b60..39588b7 100644 (file)
@@ -305,6 +305,7 @@ def wasserstein_gradient_penalty(
     discriminator_fn,
     discriminator_scope,
     epsilon=1e-10,
+    target=1.0,
     weights=1.0,
     scope=None,
     loss_collection=ops.GraphKeys.LOSSES,
@@ -324,6 +325,8 @@ def wasserstein_gradient_penalty(
     discriminator_scope: If not `None`, reuse discriminators from this scope.
     epsilon: A small positive number added for numerical stability when
       computing the gradient norm.
+    target: Optional Python number or `Tensor` indicating the target value of
+      gradient norm. Defaults to 1.0.
     weights: Optional `Tensor` whose rank is either 0, or the same rank as
       `real_data` and `generated_data`, and must be broadcastable to
       them (i.e., all dimensions must be either `1`, or the same as the
@@ -374,7 +377,7 @@ def wasserstein_gradient_penalty(
     # For numerical stability, add epsilon to the sum before taking the square
     # root. Note tf.norm does not add epsilon.
     slopes = math_ops.sqrt(gradient_squares + epsilon)
-    penalties = math_ops.square(slopes - 1.0)
+    penalties = math_ops.square(slopes / target - 1.0)
     penalty = losses.compute_weighted_loss(
         penalties, weights, scope=scope, loss_collection=loss_collection,
         reduction=reduction)
index 56ac455..dbaa624 100644 (file)
@@ -481,6 +481,29 @@ class GradientPenaltyTest(test.TestCase, _PenaltyTest):
                       })
       self.assertAlmostEqual(self._expected_loss, loss, 5)
 
+  def test_loss_with_gradient_norm_target(self):
+    """Test loss value with non default gradient norm target."""
+    generated_data = array_ops.placeholder(dtypes.float32, shape=(None, None))
+    real_data = array_ops.placeholder(dtypes.float32, shape=(None, None))
+
+    loss = tfgan_losses.wasserstein_gradient_penalty(
+        generated_data,
+        real_data,
+        self._kwargs['generator_inputs'],
+        self._kwargs['discriminator_fn'],
+        self._kwargs['discriminator_scope'],
+        target=2.0)
+
+    with self.test_session() as sess:
+      variables.global_variables_initializer().run()
+      loss = sess.run(
+          loss,
+          feed_dict={
+              generated_data: self._generated_data_np,
+              real_data: self._real_data_np,
+          })
+      self.assertAlmostEqual(1.0, loss, 5)
+
   def test_reuses_scope(self):
     """Test that gradient penalty reuses discriminator scope."""
     num_vars = len(ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES))
index 5d0ac93..776eb11 100644 (file)
@@ -460,6 +460,7 @@ def gan_loss(
     # Auxiliary losses.
     gradient_penalty_weight=None,
     gradient_penalty_epsilon=1e-10,
+    gradient_penalty_target=1.0,
     mutual_information_penalty_weight=None,
     aux_cond_generator_weight=None,
     aux_cond_discriminator_weight=None,
@@ -481,6 +482,9 @@ def gan_loss(
       small positive value used by the gradient penalty function for numerical
       stability. Note some applications will need to increase this value to
       avoid NaNs.
+    gradient_penalty_target: If `gradient_penalty_weight` is not None, a Python
+      number or `Tensor` indicating the target value of gradient norm. See the
+      CIFAR10 section of https://arxiv.org/abs/1710.10196. Defaults to 1.0.
     mutual_information_penalty_weight: If not `None`, must be a non-negative
       Python number or Tensor indicating how much to weight the mutual
       information penalty. See https://arxiv.org/abs/1606.03657 for more
@@ -539,7 +543,10 @@ def gan_loss(
   # Add optional extra losses.
   if _use_aux_loss(gradient_penalty_weight):
     gp_loss = tfgan_losses.wasserstein_gradient_penalty(
-        model, epsilon=gradient_penalty_epsilon, add_summaries=add_summaries)
+        model,
+        epsilon=gradient_penalty_epsilon,
+        target=gradient_penalty_target,
+        add_summaries=add_summaries)
     dis_loss += gradient_penalty_weight * gp_loss
   if _use_aux_loss(mutual_information_penalty_weight):
     info_loss = tfgan_losses.mutual_information_penalty(