discriminator_fn,
discriminator_scope,
epsilon=1e-10,
+ target=1.0,
weights=1.0,
scope=None,
loss_collection=ops.GraphKeys.LOSSES,
discriminator_scope: If not `None`, reuse discriminators from this scope.
epsilon: A small positive number added for numerical stability when
computing the gradient norm.
+ target: Optional Python number or `Tensor` indicating the target value of
+ the gradient norm. Defaults to 1.0.
weights: Optional `Tensor` whose rank is either 0, or the same rank as
`real_data` and `generated_data`, and must be broadcastable to
them (i.e., all dimensions must be either `1`, or the same as the
# For numerical stability, add epsilon to the sum before taking the square
# root. Note tf.norm does not add epsilon.
slopes = math_ops.sqrt(gradient_squares + epsilon)
- penalties = math_ops.square(slopes - 1.0)
+ penalties = math_ops.square(slopes / target - 1.0)
penalty = losses.compute_weighted_loss(
penalties, weights, scope=scope, loss_collection=loss_collection,
reduction=reduction)
})
self.assertAlmostEqual(self._expected_loss, loss, 5)
+ def test_loss_with_gradient_norm_target(self):
+ """Test loss value with a non-default gradient norm target.
+
+ Passes `target=2.0` to `wasserstein_gradient_penalty`; per the penalty
+ formula `square(slopes / target - 1.0)`, the penalty is rescaled
+ relative to the default `target=1.0` case.
+ """
+ generated_data = array_ops.placeholder(dtypes.float32, shape=(None, None))
+ real_data = array_ops.placeholder(dtypes.float32, shape=(None, None))
+
+ # Reuse the discriminator wiring from the shared test fixture kwargs;
+ # only the `target` argument differs from the default-loss test.
+ loss = tfgan_losses.wasserstein_gradient_penalty(
+ generated_data,
+ real_data,
+ self._kwargs['generator_inputs'],
+ self._kwargs['discriminator_fn'],
+ self._kwargs['discriminator_scope'],
+ target=2.0)
+
+ with self.test_session() as sess:
+ variables.global_variables_initializer().run()
+ loss = sess.run(
+ loss,
+ feed_dict={
+ generated_data: self._generated_data_np,
+ real_data: self._real_data_np,
+ })
+ # NOTE(review): 1.0 is presumably the hand-computed penalty for the
+ # fixture data (self._generated_data_np / self._real_data_np) with
+ # target=2.0 — confirm against self._expected_loss for target=1.0.
+ self.assertAlmostEqual(1.0, loss, 5)
+
def test_reuses_scope(self):
"""Test that gradient penalty reuses discriminator scope."""
num_vars = len(ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES))
# Auxiliary losses.
gradient_penalty_weight=None,
gradient_penalty_epsilon=1e-10,
+ gradient_penalty_target=1.0,
mutual_information_penalty_weight=None,
aux_cond_generator_weight=None,
aux_cond_discriminator_weight=None,
small positive value used by the gradient penalty function for numerical
stability. Note some applications will need to increase this value to
avoid NaNs.
+ gradient_penalty_target: If `gradient_penalty_weight` is not None, a Python
+ number or `Tensor` indicating the target value of the gradient norm. See
+ the CIFAR10 section of https://arxiv.org/abs/1710.10196. Defaults to 1.0.
mutual_information_penalty_weight: If not `None`, must be a non-negative
Python number or Tensor indicating how much to weight the mutual
information penalty. See https://arxiv.org/abs/1606.03657 for more
# Add optional extra losses.
if _use_aux_loss(gradient_penalty_weight):
gp_loss = tfgan_losses.wasserstein_gradient_penalty(
- model, epsilon=gradient_penalty_epsilon, add_summaries=add_summaries)
+ model,
+ epsilon=gradient_penalty_epsilon,
+ target=gradient_penalty_target,
+ add_summaries=add_summaries)
dis_loss += gradient_penalty_weight * gp_loss
if _use_aux_loss(mutual_information_penalty_weight):
info_loss = tfgan_losses.mutual_information_penalty(