Adds inverse_link_fn argument in regression_head. This is used in generalized regression.
author A. Unique TensorFlower <gardener@tensorflow.org>
Wed, 21 Feb 2018 20:00:04 +0000 (12:00 -0800)
committer TensorFlower Gardener <gardener@tensorflow.org>
Wed, 21 Feb 2018 20:08:56 +0000 (12:08 -0800)
PiperOrigin-RevId: 186497656

tensorflow/contrib/estimator/python/estimator/head.py
tensorflow/python/estimator/canned/head.py
tensorflow/python/estimator/canned/head_test.py

index 238cf28..a45f693 100644 (file)
@@ -177,6 +177,7 @@ def regression_head(weight_column=None,
                     label_dimension=1,
                     loss_reduction=losses.Reduction.SUM,
                     loss_fn=None,
+                    inverse_link_fn=None,
                     name=None):
   """Creates a `_Head` for regression using the `mean_squared_error` loss.
 
@@ -195,10 +196,16 @@ def regression_head(weight_column=None,
   `[D0, D1, ... DN]`, `[D0, D1, ... DN, 1]` or
   `[D0, D1, ... DN, label_dimension]`.
 
-  Also supports custom `loss_fn`. `loss_fn` takes `(labels, logits)` or
+  Supports custom `loss_fn`. `loss_fn` takes `(labels, logits)` or
   `(labels, logits, features)` as arguments and returns unreduced loss with
   shape `[D0, D1, ... DN, label_dimension]`.
 
+  Also supports custom `inverse_link_fn`, also known as 'mean function'.
+  `inverse_link_fn` takes `logits` as argument and returns predicted values.
+  This function is the inverse of the link function defined in
+  https://en.wikipedia.org/wiki/Generalized_linear_model#Link_function
+  For example, for Poisson regression, set `inverse_link_fn=tf.exp`.
+
   Args:
     weight_column: A string or a `_NumericColumn` created by
       `tf.feature_column.numeric_column` defining feature column representing
@@ -209,7 +216,9 @@ def regression_head(weight_column=None,
       `[batch_size, label_dimension]`).
     loss_reduction: One of `tf.losses.Reduction` except `NONE`. Describes how to
       reduce training loss over batch. Defaults to `SUM`.
-    loss_fn: Optional loss function.
+    loss_fn: Optional loss function. Defaults to `mean_squared_error`.
+    inverse_link_fn: Optional inverse link function, also known as 'mean
+      function'. Defaults to identity.
     name: name of the head. If provided, summary and metrics keys will be
       suffixed by `"/" + name`. Also used as `name_scope` when creating ops.
 
@@ -224,6 +233,7 @@ def regression_head(weight_column=None,
       label_dimension=label_dimension,
       loss_reduction=loss_reduction,
       loss_fn=loss_fn,
+      inverse_link_fn=inverse_link_fn,
       name=name)
 
 
index cb9e3fc..8d742a2 100644 (file)
@@ -1156,6 +1156,7 @@ def _regression_head_with_mean_squared_error_loss(
     label_dimension=1,
     loss_reduction=losses.Reduction.SUM,
     loss_fn=None,
+    inverse_link_fn=None,
     name=None):
   """Creates a `_Head` for regression using the `mean_squared_error` loss.
 
@@ -1174,10 +1175,16 @@ def _regression_head_with_mean_squared_error_loss(
   `[D0, D1, ... DN]`, `[D0, D1, ... DN, 1]` or
   `[D0, D1, ... DN, label_dimension]`.
 
-  Also supports custom `loss_fn`. `loss_fn` takes `(labels, logits)` or
+  Supports custom `loss_fn`. `loss_fn` takes `(labels, logits)` or
   `(labels, logits, features)` as arguments and returns unreduced loss with
   shape `[D0, D1, ... DN, label_dimension]`.
 
+  Also supports custom `inverse_link_fn`, also known as 'mean function'.
+  `inverse_link_fn` takes `logits` as argument and returns predicted values.
+  This function is the inverse of the link function defined in
+  https://en.wikipedia.org/wiki/Generalized_linear_model#Link_function
+  For example, for Poisson regression, set `inverse_link_fn=tf.exp`.
+
   Args:
     weight_column: A string or a `_NumericColumn` created by
       `tf.feature_column.numeric_column` defining feature column representing
@@ -1188,7 +1195,9 @@ def _regression_head_with_mean_squared_error_loss(
       `[batch_size, label_dimension]`).
     loss_reduction: One of `tf.losses.Reduction` except `NONE`. Describes how to
       reduce training loss over batch. Defaults to `SUM`.
-    loss_fn: Optional loss function.
+    loss_fn: Optional loss function. Defaults to `mean_squared_error`.
+    inverse_link_fn: Optional inverse link function, also known as 'mean
+      function'. Defaults to identity.
     name: name of the head. If provided, summary and metrics keys will be
       suffixed by `"/" + name`. Also used as `name_scope` when creating ops.
 
@@ -1208,6 +1217,7 @@ def _regression_head_with_mean_squared_error_loss(
       label_dimension=label_dimension,
       loss_reduction=loss_reduction,
       loss_fn=loss_fn,
+      inverse_link_fn=inverse_link_fn,
       name=name)
 
 
@@ -1220,6 +1230,7 @@ class _RegressionHeadWithMeanSquaredErrorLoss(_Head):
       weight_column=None,
       loss_reduction=losses.Reduction.SUM,
       loss_fn=None,
+      inverse_link_fn=None,
       name=None):
     """`Head` for regression."""
     if label_dimension < 1:
@@ -1228,6 +1239,7 @@ class _RegressionHeadWithMeanSquaredErrorLoss(_Head):
     self._weight_column = weight_column
     self._loss_reduction = loss_reduction
     self._loss_fn = loss_fn
+    self._inverse_link_fn = inverse_link_fn
     self._name = name
 
   @property
@@ -1294,9 +1306,19 @@ class _RegressionHeadWithMeanSquaredErrorLoss(_Head):
     # Predict.
     with ops.name_scope(self._name, 'head'):
       logits = _check_logits_final_dim(logits, self._logits_dimension)
-      predictions = {prediction_keys.PredictionKeys.PREDICTIONS: logits}
+      if self._inverse_link_fn:
+        predicted_value = self._inverse_link_fn(logits)
+        predictions = {
+            prediction_keys.PredictionKeys.PREDICTIONS: predicted_value,
+            prediction_keys.PredictionKeys.LOGITS: logits,
+        }
+      else:
+        predicted_value = logits
+        predictions = {
+            prediction_keys.PredictionKeys.PREDICTIONS: predicted_value}
       if mode == model_fn.ModeKeys.PREDICT:
-        regression_output = export_output.RegressionOutput(value=logits)
+        regression_output = export_output.RegressionOutput(
+            value=predicted_value)
         return model_fn.EstimatorSpec(
             mode=model_fn.ModeKeys.PREDICT,
             predictions=predictions,
index c09f882..a300f31 100644 (file)
@@ -2703,10 +2703,9 @@ class RegressionHeadWithMeanSquaredErrorLossTest(test.TestCase):
     self.assertIsNone(spec.loss)
     self.assertEqual({}, spec.eval_metric_ops)
     self.assertIsNone(spec.train_op)
+    default_serving_key = signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY
     self.assertItemsEqual(
-        (signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY,
-         'predict',
-         'regression'),
+        (default_serving_key, 'predict', 'regression'),
         spec.export_outputs.keys())
     _assert_no_hooks(self, spec)
 
@@ -2714,6 +2713,54 @@ class RegressionHeadWithMeanSquaredErrorLossTest(test.TestCase):
     with self.test_session():
       _initialize_variables(self, spec.scaffold)
       self.assertAllClose(logits, spec.predictions[prediction_key].eval())
+      self.assertAllClose(
+          logits, spec.export_outputs[default_serving_key].value.eval())
+      self.assertAllClose(
+          logits, spec.export_outputs['regression'].value.eval())
+      self.assertAllClose(
+          logits, spec.export_outputs['predict'].outputs['predictions'].eval())
+
+  def test_predict_with_inverse_link_fn(self):
+    def _inverse_link_fn(logits):
+      return logits - 10.
+    head = head_lib._regression_head_with_mean_squared_error_loss(
+        inverse_link_fn=_inverse_link_fn)
+
+    # Create estimator spec.
+    logits = np.array(((45,), (41,),), dtype=np.int32)
+    expected_predictions = np.array(((35,), (31,),), dtype=np.int32)
+    spec = head.create_estimator_spec(
+        features={'x': np.array(((42.,),), dtype=np.int32)},
+        mode=model_fn.ModeKeys.PREDICT,
+        logits=logits)
+
+    # Assert spec contains expected tensors.
+    keys = prediction_keys.PredictionKeys
+    self.assertItemsEqual(
+        (keys.PREDICTIONS, keys.LOGITS), spec.predictions.keys())
+    self.assertEqual(dtypes.float32, spec.predictions[keys.PREDICTIONS].dtype)
+    self.assertEqual(dtypes.float32, spec.predictions[keys.LOGITS].dtype)
+    default_serving_key = signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY
+    self.assertItemsEqual(
+        (default_serving_key, 'predict', 'regression'),
+        spec.export_outputs.keys())
+
+    # Assert predictions.
+    with self.test_session():
+      _initialize_variables(self, spec.scaffold)
+      self.assertAllClose(
+          expected_predictions, spec.predictions[keys.PREDICTIONS].eval())
+      self.assertAllClose(logits, spec.predictions[keys.LOGITS].eval())
+      self.assertAllClose(
+          expected_predictions,
+          spec.export_outputs[default_serving_key].value.eval())
+      self.assertAllClose(
+          expected_predictions, spec.export_outputs['regression'].value.eval())
+      self.assertAllClose(
+          expected_predictions,
+          spec.export_outputs['predict'].outputs['predictions'].eval())
+      self.assertAllClose(
+          logits, spec.export_outputs['predict'].outputs['logits'].eval())
 
   def test_eval_create_loss(self):
     head = head_lib._regression_head_with_mean_squared_error_loss()