generator_optimizer=None,
discriminator_optimizer=None,
get_hooks_fn=None,
+ get_eval_metric_ops_fn=None,
add_summaries=None,
use_loss_summaries=True,
config=None):
list of hooks. These hooks are run on the generator and discriminator
train ops, and can be used to implement the GAN training scheme.
Defaults to `train.get_sequential_train_hooks()`.
+ get_eval_metric_ops_fn: A function that takes a `GANModel`, and returns a
+ dict of metric results keyed by name. The output of this function is
+ passed into `tf.estimator.EstimatorSpec` during evaluation.
add_summaries: `None`, a single `SummaryType`, or a list of `SummaryType`.
use_loss_summaries: If `True`, add loss summaries. If `False`, does not.
If `None`, uses defaults.
else discriminator_optimizer)
gan_head = head_lib.gan_head(
generator_loss_fn, discriminator_loss_fn, gopt, dopt,
- use_loss_summaries, get_hooks_fn=get_hooks_fn)
+ use_loss_summaries, get_hooks_fn=get_hooks_fn,
+ get_eval_metric_ops_fn=get_eval_metric_ops_fn)
return _gan_model_fn(
features, labels, mode, generator_fn, discriminator_fn, gan_head,
add_summaries)
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import metrics as metrics_lib
from tensorflow.python.ops import parsing_ops
from tensorflow.python.platform import test
from tensorflow.python.summary.writer import writer_cache
lr = learning_rate_decay.exponential_decay(1.0, gstep, 10, 0.9)
return training.GradientDescentOptimizer(lr)
+ def get_metrics(gan_model):
+ return {
+ 'mse_custom_metric': metrics_lib.mean_squared_error(
+ gan_model.real_data, gan_model.generated_data)
+ }
+
gopt = make_opt if lr_decay else training.GradientDescentOptimizer(1.0)
dopt = make_opt if lr_decay else training.GradientDescentOptimizer(1.0)
est = estimator.GANEstimator(
discriminator_loss_fn=losses.wasserstein_discriminator_loss,
generator_optimizer=gopt,
discriminator_optimizer=dopt,
+ get_eval_metric_ops_fn=get_metrics,
model_dir=self._model_dir)
# TRAIN
self.assertIn('loss', six.iterkeys(scores))
self.assertEqual(scores['discriminator_loss'] + scores['generator_loss'],
scores['loss'])
+ self.assertIn('mse_custom_metric', six.iterkeys(scores))
# PREDICT
predictions = np.array([x for x in est.predict(predict_input_fn)])
def gan_head(generator_loss_fn, discriminator_loss_fn, generator_optimizer,
discriminator_optimizer, use_loss_summaries=True,
get_hooks_fn=tfgan_train.get_sequential_train_hooks(),
- name=None):
+ get_eval_metric_ops_fn=None, name=None):
"""Creates a `GANHead`.
Args:
discriminator_optimizer: Same as `generator_optimizer`, but for the
discriminator updates.
use_loss_summaries: If `True`, add loss summaries. If `False`, does not.
- If `None`, uses defaults.
- get_hooks_fn: A function that takes a GANTrainOps tuple and returns a list
- of hooks.
+ If `None`, uses defaults.
+ get_hooks_fn: A function that takes a `GANTrainOps` tuple and returns a
+ list of hooks.
+ get_eval_metric_ops_fn: A function that takes a `GANModel`, and returns a
+ dict of metric results keyed by name. The output of this function is
+ passed into `tf.estimator.EstimatorSpec` during evaluation.
name: name of the head. If provided, summary and metrics keys will be
suffixed by `"/" + name`.
discriminator_optimizer=discriminator_optimizer,
use_loss_summaries=use_loss_summaries,
get_hooks_fn=get_hooks_fn,
+ get_eval_metric_ops_fn=get_eval_metric_ops_fn,
name=name)
generator_optimizer, discriminator_optimizer,
use_loss_summaries=True,
get_hooks_fn=None,
+ get_eval_metric_ops_fn=None,
name=None):
"""`Head` for GAN training.
discriminator updates.
use_loss_summaries: If `True`, add loss summaries. If `False`, does not.
If `None`, uses defaults.
- get_hooks_fn: A function that takes a GANTrainOps tuple and returns a list
- of hooks. Defaults to `train.get_sequential_train_hooks()`
+ get_hooks_fn: A function that takes a `GANTrainOps` tuple and returns a
+       list of hooks. Defaults to `train.get_sequential_train_hooks()`.
+ get_eval_metric_ops_fn: A function that takes a `GANModel`, and returns a
+ dict of metric results keyed by name. The output of this function is
+ passed into `tf.estimator.EstimatorSpec` during evaluation.
name: name of the head. If provided, summary and metrics keys will be
suffixed by `"/" + name`.
"""
self._generator_optimizer = generator_optimizer
self._discriminator_optimizer = discriminator_optimizer
self._get_hooks_fn = get_hooks_fn
+ self._get_eval_metric_ops_fn = get_eval_metric_ops_fn
self._name = name
@property
_summary_key(self._name, 'discriminator_loss'):
metrics_lib.mean(gan_loss.discriminator_loss)
}
+ if self._get_eval_metric_ops_fn is not None:
+ custom_eval_metric_ops = self._get_eval_metric_ops_fn(gan_model)
+ if not isinstance(custom_eval_metric_ops, dict):
+ raise TypeError('get_eval_metric_ops_fn must return a dict, '
+ 'received: {}'.format(custom_eval_metric_ops))
+ eval_metric_ops.update(custom_eval_metric_ops)
return model_fn_lib.EstimatorSpec(
mode=model_fn_lib.ModeKeys.EVAL,
predictions=gan_model.generated_data,
generator_loss_fn=dummy_loss,
discriminator_loss_fn=dummy_loss,
generator_optimizer=training.GradientDescentOptimizer(1.0),
- discriminator_optimizer=training.GradientDescentOptimizer(1.0))
+ discriminator_optimizer=training.GradientDescentOptimizer(1.0),
+ get_eval_metric_ops_fn=self.get_metrics)
self.assertTrue(isinstance(self.gan_head, head.GANHead))
+ def get_metrics(self, gan_model):
+ self.assertTrue(isinstance(gan_model, tfgan_tuples.GANModel))
+ return {}
+
def _test_modes_helper(self, mode):
self.gan_head.create_estimator_spec(
features=None,