From 1fc821602d69e5812b854a61f09f163ce549641b Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower"
Date: Tue, 13 Feb 2018 08:46:26 -0800
Subject: [PATCH] Add gradient norm target arg to wass gradient penalty function.

This trick is used in the progressive GAN paper
https://arxiv.org/abs/1710.10196

PiperOrigin-RevId: 185535584
---
 .../gan/python/losses/python/losses_impl.py       |  5 ++++-
 .../gan/python/losses/python/losses_impl_test.py  | 23 ++++++++++++++++++++++
 tensorflow/contrib/gan/python/train.py            |  9 ++++++++-
 3 files changed, 35 insertions(+), 2 deletions(-)

diff --git a/tensorflow/contrib/gan/python/losses/python/losses_impl.py b/tensorflow/contrib/gan/python/losses/python/losses_impl.py
index 23a3b60..39588b7 100644
--- a/tensorflow/contrib/gan/python/losses/python/losses_impl.py
+++ b/tensorflow/contrib/gan/python/losses/python/losses_impl.py
@@ -305,6 +305,7 @@ def wasserstein_gradient_penalty(
     discriminator_fn,
     discriminator_scope,
     epsilon=1e-10,
+    target=1.0,
     weights=1.0,
     scope=None,
     loss_collection=ops.GraphKeys.LOSSES,
@@ -324,6 +325,8 @@ def wasserstein_gradient_penalty(
     discriminator_scope: If not `None`, reuse discriminators from this scope.
     epsilon: A small positive number added for numerical stability when
       computing the gradient norm.
+    target: Optional Python number or `Tensor` indicating the target value of
+      gradient norm. Defaults to 1.0.
     weights: Optional `Tensor` whose rank is either 0, or the same rank as
       `real_data` and `generated_data`, and must be broadcastable to them
       (i.e., all dimensions must be either `1`, or the same as the
@@ -374,7 +377,7 @@ def wasserstein_gradient_penalty(
     # For numerical stability, add epsilon to the sum before taking the square
     # root. Note tf.norm does not add epsilon.
     slopes = math_ops.sqrt(gradient_squares + epsilon)
-    penalties = math_ops.square(slopes - 1.0)
+    penalties = math_ops.square(slopes / target - 1.0)
     penalty = losses.compute_weighted_loss(
         penalties, weights, scope=scope, loss_collection=loss_collection,
         reduction=reduction)
diff --git a/tensorflow/contrib/gan/python/losses/python/losses_impl_test.py b/tensorflow/contrib/gan/python/losses/python/losses_impl_test.py
index 56ac455..dbaa624 100644
--- a/tensorflow/contrib/gan/python/losses/python/losses_impl_test.py
+++ b/tensorflow/contrib/gan/python/losses/python/losses_impl_test.py
@@ -481,6 +481,29 @@ class GradientPenaltyTest(test.TestCase, _PenaltyTest):
         })
     self.assertAlmostEqual(self._expected_loss, loss, 5)
 
+  def test_loss_with_gradient_norm_target(self):
+    """Test loss value with non-default gradient norm target."""
+    generated_data = array_ops.placeholder(dtypes.float32, shape=(None, None))
+    real_data = array_ops.placeholder(dtypes.float32, shape=(None, None))
+
+    loss = tfgan_losses.wasserstein_gradient_penalty(
+        generated_data,
+        real_data,
+        self._kwargs['generator_inputs'],
+        self._kwargs['discriminator_fn'],
+        self._kwargs['discriminator_scope'],
+        target=2.0)
+
+    with self.test_session() as sess:
+      variables.global_variables_initializer().run()
+      loss = sess.run(
+          loss,
+          feed_dict={
+              generated_data: self._generated_data_np,
+              real_data: self._real_data_np,
+          })
+      self.assertAlmostEqual(1.0, loss, 5)
+
   def test_reuses_scope(self):
     """Test that gradient penalty reuses discriminator scope."""
     num_vars = len(ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES))
diff --git a/tensorflow/contrib/gan/python/train.py b/tensorflow/contrib/gan/python/train.py
index 5d0ac93..776eb11 100644
--- a/tensorflow/contrib/gan/python/train.py
+++ b/tensorflow/contrib/gan/python/train.py
@@ -460,6 +460,7 @@ def gan_loss(
     # Auxiliary losses.
     gradient_penalty_weight=None,
     gradient_penalty_epsilon=1e-10,
+    gradient_penalty_target=1.0,
     mutual_information_penalty_weight=None,
     aux_cond_generator_weight=None,
     aux_cond_discriminator_weight=None,
@@ -481,6 +482,9 @@ def gan_loss(
       small positive value used by the gradient penalty function for numerical
      stability. Note some applications will need to increase this value to
      avoid NaNs.
+    gradient_penalty_target: If `gradient_penalty_weight` is not None, a Python
+      number or `Tensor` indicating the target value of gradient norm. See the
+      CIFAR10 section of https://arxiv.org/abs/1710.10196. Defaults to 1.0.
     mutual_information_penalty_weight: If not `None`, must be a non-negative
       Python number or Tensor indicating how much to weight the mutual
       information penalty. See https://arxiv.org/abs/1606.03657 for more
@@ -539,7 +543,10 @@ def gan_loss(
   # Add optional extra losses.
   if _use_aux_loss(gradient_penalty_weight):
     gp_loss = tfgan_losses.wasserstein_gradient_penalty(
-        model, epsilon=gradient_penalty_epsilon, add_summaries=add_summaries)
+        model,
+        epsilon=gradient_penalty_epsilon,
+        target=gradient_penalty_target,
+        add_summaries=add_summaries)
     dis_loss += gradient_penalty_weight * gp_loss
   if _use_aux_loss(mutual_information_penalty_weight):
    info_loss = tfgan_losses.mutual_information_penalty(
-- 
2.7.4
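
Note: the patch only rescales the penalty term, changing square(||grad|| - 1) to
square(||grad|| / target - 1). A minimal sketch of that arithmetic in plain NumPy
(the gradient-norm value below is made up for illustration and is not taken from
the unit test above):

    import numpy as np

    def gradient_penalty(slopes, target=1.0):
      # Mirrors the patched line: penalties = math_ops.square(slopes / target - 1.0)
      return float(np.mean(np.square(slopes / target - 1.0)))

    slopes = np.array([3.0])                     # hypothetical gradient norm
    print(gradient_penalty(slopes))              # default target=1.0: (3 - 1)^2 = 4.0
    print(gradient_penalty(slopes, target=2.0))  # (3 / 2 - 1)^2 = 0.25

As the train.py hunk shows, gan_loss simply forwards gradient_penalty_target to
wasserstein_gradient_penalty whenever gradient_penalty_weight is set, so callers
who want the progressive GAN setting only need to pass that one extra keyword
argument.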