[tfgan] Allow to add custom eval metrics to GANEstimator (#19133)
authorLukas Geiger <lgeiger@users.noreply.github.com>
Tue, 8 May 2018 23:31:39 +0000 (01:31 +0200)
committerRasmus Munk Larsen <rmlarsen@google.com>
Tue, 8 May 2018 23:31:39 +0000 (16:31 -0700)
tensorflow/contrib/gan/python/estimator/python/gan_estimator_impl.py
tensorflow/contrib/gan/python/estimator/python/gan_estimator_test.py
tensorflow/contrib/gan/python/estimator/python/head_impl.py
tensorflow/contrib/gan/python/estimator/python/head_test.py

index e3fc6bf..4092b32 100644 (file)
@@ -112,6 +112,7 @@ class GANEstimator(estimator.Estimator):
                generator_optimizer=None,
                discriminator_optimizer=None,
                get_hooks_fn=None,
+               get_eval_metric_ops_fn=None,
                add_summaries=None,
                use_loss_summaries=True,
                config=None):
@@ -146,6 +147,9 @@ class GANEstimator(estimator.Estimator):
         list of hooks. These hooks are run on the generator and discriminator
         train ops, and can be used to implement the GAN training scheme.
         Defaults to `train.get_sequential_train_hooks()`.
+      get_eval_metric_ops_fn: A function that takes a `GANModel`, and returns a
+        dict of metric results keyed by name. The output of this function is
+        passed into `tf.estimator.EstimatorSpec` during evaluation.
       add_summaries: `None`, a single `SummaryType`, or a list of `SummaryType`.
       use_loss_summaries: If `True`, add loss summaries. If `False`, does not.
         If `None`, uses defaults.
@@ -160,7 +164,8 @@ class GANEstimator(estimator.Estimator):
               else discriminator_optimizer)
       gan_head = head_lib.gan_head(
           generator_loss_fn, discriminator_loss_fn, gopt, dopt,
-          use_loss_summaries, get_hooks_fn=get_hooks_fn)
+          use_loss_summaries, get_hooks_fn=get_hooks_fn,
+          get_eval_metric_ops_fn=get_eval_metric_ops_fn)
       return _gan_model_fn(
           features, labels, mode, generator_fn, discriminator_fn, gan_head,
           add_summaries)
index 6bbd173..9554825 100644 (file)
@@ -38,6 +38,7 @@ from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import metrics as metrics_lib
 from tensorflow.python.ops import parsing_ops
 from tensorflow.python.platform import test
 from tensorflow.python.summary.writer import writer_cache
@@ -194,6 +195,12 @@ class GANEstimatorIntegrationTest(test.TestCase):
       lr = learning_rate_decay.exponential_decay(1.0, gstep, 10, 0.9)
       return training.GradientDescentOptimizer(lr)
 
+    def get_metrics(gan_model):
+      return {
+          'mse_custom_metric': metrics_lib.mean_squared_error(
+              gan_model.real_data, gan_model.generated_data)
+      }
+
     gopt = make_opt if lr_decay else training.GradientDescentOptimizer(1.0)
     dopt = make_opt if lr_decay else training.GradientDescentOptimizer(1.0)
     est = estimator.GANEstimator(
@@ -203,6 +210,7 @@ class GANEstimatorIntegrationTest(test.TestCase):
         discriminator_loss_fn=losses.wasserstein_discriminator_loss,
         generator_optimizer=gopt,
         discriminator_optimizer=dopt,
+        get_eval_metric_ops_fn=get_metrics,
         model_dir=self._model_dir)
 
     # TRAIN
@@ -215,6 +223,7 @@ class GANEstimatorIntegrationTest(test.TestCase):
     self.assertIn('loss', six.iterkeys(scores))
     self.assertEqual(scores['discriminator_loss'] + scores['generator_loss'],
                      scores['loss'])
+    self.assertIn('mse_custom_metric', six.iterkeys(scores))
 
     # PREDICT
     predictions = np.array([x for x in est.predict(predict_input_fn)])
index d174cb3..ff903a7 100644 (file)
@@ -39,7 +39,7 @@ def _summary_key(head_name, val):
 def gan_head(generator_loss_fn, discriminator_loss_fn, generator_optimizer,
              discriminator_optimizer, use_loss_summaries=True,
              get_hooks_fn=tfgan_train.get_sequential_train_hooks(),
-             name=None):
+             get_eval_metric_ops_fn=None, name=None):
   """Creates a `GANHead`.
 
   Args:
@@ -51,9 +51,12 @@ def gan_head(generator_loss_fn, discriminator_loss_fn, generator_optimizer,
     discriminator_optimizer: Same as `generator_optimizer`, but for the
       discriminator updates.
     use_loss_summaries: If `True`, add loss summaries. If `False`, does not.
-        If `None`, uses defaults.
-    get_hooks_fn: A function that takes a GANTrainOps tuple and returns a list
-        of hooks.
+      If `None`, uses defaults.
+    get_hooks_fn: A function that takes a `GANTrainOps` tuple and returns a
+      list of hooks.
+    get_eval_metric_ops_fn: A function that takes a `GANModel`, and returns a
+      dict of metric results keyed by name. The output of this function is
+      passed into `tf.estimator.EstimatorSpec` during evaluation.
     name: name of the head. If provided, summary and metrics keys will be
       suffixed by `"/" + name`.
 
@@ -66,6 +69,7 @@ def gan_head(generator_loss_fn, discriminator_loss_fn, generator_optimizer,
                  discriminator_optimizer=discriminator_optimizer,
                  use_loss_summaries=use_loss_summaries,
                  get_hooks_fn=get_hooks_fn,
+                 get_eval_metric_ops_fn=get_eval_metric_ops_fn,
                  name=name)
 
 
@@ -76,6 +80,7 @@ class GANHead(head._Head):  # pylint: disable=protected-access
                generator_optimizer, discriminator_optimizer,
                use_loss_summaries=True,
                get_hooks_fn=None,
+               get_eval_metric_ops_fn=None,
                name=None):
     """`Head` for GAN training.
 
@@ -89,8 +94,11 @@ class GANHead(head._Head):  # pylint: disable=protected-access
         discriminator updates.
       use_loss_summaries: If `True`, add loss summaries. If `False`, does not.
         If `None`, uses defaults.
-      get_hooks_fn: A function that takes a GANTrainOps tuple and returns a list
-        of hooks. Defaults to `train.get_sequential_train_hooks()`
+      get_hooks_fn: A function that takes a `GANTrainOps` tuple and returns a
+        list of hooks. Defaults to `train.get_sequential_train_hooks()`
+      get_eval_metric_ops_fn: A function that takes a `GANModel`, and returns a
+        dict of metric results keyed by name. The output of this function is
+        passed into `tf.estimator.EstimatorSpec` during evaluation.
       name: name of the head. If provided, summary and metrics keys will be
         suffixed by `"/" + name`.
     """
@@ -108,6 +116,7 @@ class GANHead(head._Head):  # pylint: disable=protected-access
     self._generator_optimizer = generator_optimizer
     self._discriminator_optimizer = discriminator_optimizer
     self._get_hooks_fn = get_hooks_fn
+    self._get_eval_metric_ops_fn = get_eval_metric_ops_fn
     self._name = name
 
   @property
@@ -187,6 +196,12 @@ class GANHead(head._Head):  # pylint: disable=protected-access
               _summary_key(self._name, 'discriminator_loss'):
                   metrics_lib.mean(gan_loss.discriminator_loss)
           }
+          if self._get_eval_metric_ops_fn is not None:
+            custom_eval_metric_ops = self._get_eval_metric_ops_fn(gan_model)
+            if not isinstance(custom_eval_metric_ops, dict):
+              raise TypeError('get_eval_metric_ops_fn must return a dict, '
+                              'received: {}'.format(custom_eval_metric_ops))
+            eval_metric_ops.update(custom_eval_metric_ops)
         return model_fn_lib.EstimatorSpec(
             mode=model_fn_lib.ModeKeys.EVAL,
             predictions=gan_model.generated_data,
index 8168f00..6587f1f 100644 (file)
@@ -62,9 +62,14 @@ class GANHeadTest(test.TestCase):
         generator_loss_fn=dummy_loss,
         discriminator_loss_fn=dummy_loss,
         generator_optimizer=training.GradientDescentOptimizer(1.0),
-        discriminator_optimizer=training.GradientDescentOptimizer(1.0))
+        discriminator_optimizer=training.GradientDescentOptimizer(1.0),
+        get_eval_metric_ops_fn=self.get_metrics)
     self.assertTrue(isinstance(self.gan_head, head.GANHead))
 
+  def get_metrics(self, gan_model):
+    self.assertTrue(isinstance(gan_model, tfgan_tuples.GANModel))
+    return {}
+
   def _test_modes_helper(self, mode):
     self.gan_head.create_estimator_spec(
         features=None,