Prepare variance to be exported for serving with the servo library.
authorA. Unique TensorFlower <gardener@tensorflow.org>
Tue, 30 Jan 2018 18:18:36 +0000 (10:18 -0800)
committerTensorFlower Gardener <gardener@tensorflow.org>
Tue, 30 Jan 2018 20:32:21 +0000 (12:32 -0800)
PiperOrigin-RevId: 183851026

tensorflow/contrib/tensor_forest/client/random_forest.py

index a998ac1..4abcc20 100644 (file)
@@ -18,7 +18,7 @@ from __future__ import division
 from __future__ import print_function
 
 from tensorflow.contrib import layers
-
+from tensorflow.contrib.learn.python.learn.estimators import constants
 from tensorflow.contrib.learn.python.learn.estimators import estimator
 from tensorflow.contrib.learn.python.learn.estimators import head as head_lib
 from tensorflow.contrib.learn.python.learn.estimators import model_fn as model_fn_lib
@@ -43,8 +43,8 @@ from tensorflow.python.training import training_util
 KEYS_NAME = 'keys'
 LOSS_NAME = 'rf_training_loss'
 TREE_PATHS_PREDICTION_KEY = 'tree_paths'
-VARIANCE_PREDICTION_KEY = 'regression_variance'
-
+VARIANCE_PREDICTION_KEY = 'prediction_variance'
+ALL_SERVING_KEY = 'tensorforest_all'
 EPSILON = 0.000001
 
 
@@ -134,7 +134,8 @@ def get_model_fn(params,
                  trainer_id=0,
                  report_feature_importances=False,
                  local_eval=False,
-                 head_scope=None):
+                 head_scope=None,
+                 include_all_in_serving=False):
   """Return a model function given a way to construct a graph builder."""
   if model_head is None:
     model_head = get_default_head(params, weights_name)
@@ -238,7 +239,13 @@ def get_model_fn(params,
       model_ops.predictions[TREE_PATHS_PREDICTION_KEY] = tree_paths
 
     model_ops.predictions[VARIANCE_PREDICTION_KEY] = regression_variance
-
+    if include_all_in_serving:
+      # In order to serve the variance we need to add the prediction dict
+      # to output_alternatives dict.
+      if not model_ops.output_alternatives:
+        model_ops.output_alternatives = {}
+      model_ops.output_alternatives[ALL_SERVING_KEY] = (
+          constants.ProblemType.UNSPECIFIED, model_ops.predictions)
     return model_ops
 
   return _model_fn
@@ -293,7 +300,8 @@ class TensorForestEstimator(estimator.Estimator):
                report_feature_importances=False,
                local_eval=False,
                version=None,
-               head=None):
+               head=None,
+               include_all_in_serving=False):
     """Initializes a TensorForestEstimator instance.
 
     Args:
@@ -339,6 +347,23 @@ class TensorForestEstimator(estimator.Estimator):
       version: Unused.
       head: A heads_lib.Head object that calculates losses and such. If None,
         one will be automatically created based on params.
+      include_all_in_serving: if True, allow preparation of the complete
+        prediction dict including the variance to be exported for serving with
+        the Servo lib; and it also requires calling export_savedmodel with
+        default_output_alternative_key=ALL_SERVING_KEY, i.e.
+        estimator.export_savedmodel(export_dir_base=your_export_dir,
+          serving_input_fn=your_export_input_fn,
+          default_output_alternative_key=ALL_SERVING_KEY)
+        if False, resort to default behavior, i.e. export scores and
+          probabilities but no variances. In this case
+          default_output_alternative_key should be None while calling
+          export_savedmodel().
+        Note, that due to backward compatibility we cannot always set
+        include_all_in_serving to True because in this case calling
+        export_saved_model() without
+        default_output_alternative_key=ALL_SERVING_KEY (legacy behavior) the
+        saved_model_export_utils.get_output_alternatives() would raise
+        ValueError.
 
     Returns:
       A `TensorForestEstimator` instance.
@@ -357,7 +382,9 @@ class TensorForestEstimator(estimator.Estimator):
             num_trainers=num_trainers,
             trainer_id=trainer_id,
             report_feature_importances=report_feature_importances,
-            local_eval=local_eval),
+            local_eval=local_eval,
+            include_all_in_serving=include_all_in_serving,
+        ),
         model_dir=model_dir,
         config=config,
         feature_engineering_fn=feature_engineering_fn)