Allow using DNN to only train the embeddings and using the tree model for the final...
authorA. Unique TensorFlower <gardener@tensorflow.org>
Mon, 21 May 2018 21:47:37 +0000 (14:47 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Mon, 21 May 2018 21:50:13 +0000 (14:50 -0700)
PiperOrigin-RevId: 197462585

tensorflow/contrib/boosted_trees/estimator_batch/dnn_tree_combined_estimator.py

index 9994c84..758754f 100644 (file)
@@ -45,6 +45,7 @@ from tensorflow.python.training import training_util
 
 _DNN_LEARNING_RATE = 0.001
 
+
 def _get_optimizer(optimizer):
   if callable(optimizer):
     return optimizer()
@@ -73,6 +74,7 @@ def _dnn_tree_combined_model_fn(features,
                                 dnn_input_layer_partitioner=None,
                                 dnn_input_layer_to_tree=True,
                                 dnn_steps_to_train=10000,
+                                predict_with_tree_only=False,
                                 tree_feature_columns=None,
                                 tree_center_bias=False,
                                 use_core_versions=False):
@@ -108,6 +110,8 @@ def _dnn_tree_combined_model_fn(features,
     as a feature to the tree.
     dnn_steps_to_train: Number of steps to train dnn for before switching
       to gbdt.
+    predict_with_tree_only: Whether to use only the tree model output as the
+      final prediction.
     tree_feature_columns: An iterable containing all the feature columns
       used by the model's boosted trees. If dnn_input_layer_to_tree is
       set to True, these features are in addition to dnn_feature_columns.
@@ -132,8 +136,7 @@ def _dnn_tree_combined_model_fn(features,
   dnn_parent_scope = "dnn"
   dnn_partitioner = dnn_input_layer_partitioner or (
       partitioned_variables.min_max_variable_partitioner(
-          max_partitions=config.num_ps_replicas,
-          min_slice_size=64 << 20))
+          max_partitions=config.num_ps_replicas, min_slice_size=64 << 20))
 
   with variable_scope.variable_scope(
       dnn_parent_scope,
@@ -171,8 +174,7 @@ def _dnn_tree_combined_model_fn(features,
       _add_hidden_layer_summary(net, hidden_layer_scope.name)
       previous_layer = net
     with variable_scope.variable_scope(
-        "logits",
-        values=(previous_layer,)) as logits_scope:
+        "logits", values=(previous_layer,)) as logits_scope:
       dnn_logits = layers.fully_connected(
           previous_layer,
           head.logits_dimension,
@@ -190,8 +192,7 @@ def _dnn_tree_combined_model_fn(features,
           optimizer=_get_optimizer(dnn_optimizer),
           name=dnn_parent_scope,
           variables=ops.get_collection(
-              ops.GraphKeys.TRAINABLE_VARIABLES,
-              scope=dnn_parent_scope),
+              ops.GraphKeys.TRAINABLE_VARIABLES, scope=dnn_parent_scope),
           # Empty summaries to prevent optimizers from logging training_loss.
           summaries=[])
 
@@ -230,7 +231,10 @@ def _dnn_tree_combined_model_fn(features,
         update_op = state_ops.assign_add(global_step, 1).op
         return update_op
 
-  tree_train_logits = dnn_logits + tree_logits
+  if predict_with_tree_only:
+    tree_train_logits = tree_logits
+  else:
+    tree_train_logits = dnn_logits + tree_logits
 
   def _no_train_op_fn(loss):
     """Returns a no-op."""
@@ -288,10 +292,10 @@ def _dnn_tree_combined_model_fn(features,
   finalized_trees, attempted_trees = gbdt_model.get_number_of_trees_tensor()
 
   model_fn_ops.training_hooks.extend([
-      trainer_hooks.SwitchTrainOp(
-          dnn_train_op, dnn_steps_to_train, tree_train_op),
-      trainer_hooks.StopAfterNTrees(
-          num_trees, attempted_trees, finalized_trees)])
+      trainer_hooks.SwitchTrainOp(dnn_train_op, dnn_steps_to_train,
+                                  tree_train_op),
+      trainer_hooks.StopAfterNTrees(num_trees, attempted_trees, finalized_trees)
+  ])
 
   return model_fn_ops
 
@@ -318,6 +322,7 @@ class DNNBoostedTreeCombinedClassifier(estimator.Estimator):
                dnn_input_layer_partitioner=None,
                dnn_input_layer_to_tree=True,
                dnn_steps_to_train=10000,
+               predict_with_tree_only=False,
                tree_feature_columns=None,
                tree_center_bias=False,
                use_core_versions=False):
@@ -360,6 +365,8 @@ class DNNBoostedTreeCombinedClassifier(estimator.Estimator):
       as a feature to the tree.
       dnn_steps_to_train: Number of steps to train dnn for before switching
         to gbdt.
+      predict_with_tree_only: Whether to use only the tree model output as the
+        final prediction.
       tree_feature_columns: An iterable containing all the feature columns
         used by the model's boosted trees. If dnn_input_layer_to_tree is
         set to True, these features are in addition to dnn_feature_columns.
@@ -377,16 +384,32 @@ class DNNBoostedTreeCombinedClassifier(estimator.Estimator):
 
     def _model_fn(features, labels, mode, config):
       return _dnn_tree_combined_model_fn(
-          features, labels, mode, head, dnn_hidden_units, dnn_feature_columns,
-          tree_learner_config, num_trees, tree_examples_per_layer, config,
-          dnn_optimizer, dnn_activation_fn, dnn_dropout,
-          dnn_input_layer_partitioner, dnn_input_layer_to_tree,
-          dnn_steps_to_train, tree_feature_columns, tree_center_bias,
-          use_core_versions)
+          features=features,
+          labels=labels,
+          mode=mode,
+          head=head,
+          dnn_hidden_units=dnn_hidden_units,
+          dnn_feature_columns=dnn_feature_columns,
+          tree_learner_config=tree_learner_config,
+          num_trees=num_trees,
+          tree_examples_per_layer=tree_examples_per_layer,
+          config=config,
+          dnn_optimizer=dnn_optimizer,
+          dnn_activation_fn=dnn_activation_fn,
+          dnn_dropout=dnn_dropout,
+          dnn_input_layer_partitioner=dnn_input_layer_partitioner,
+          dnn_input_layer_to_tree=dnn_input_layer_to_tree,
+          dnn_steps_to_train=dnn_steps_to_train,
+          predict_with_tree_only=predict_with_tree_only,
+          tree_feature_columns=tree_feature_columns,
+          tree_center_bias=tree_center_bias,
+          use_core_versions=use_core_versions)
 
     super(DNNBoostedTreeCombinedClassifier, self).__init__(
-        model_fn=_model_fn, model_dir=model_dir,
-        config=config, feature_engineering_fn=feature_engineering_fn)
+        model_fn=_model_fn,
+        model_dir=model_dir,
+        config=config,
+        feature_engineering_fn=feature_engineering_fn)
 
 
 class DNNBoostedTreeCombinedRegressor(estimator.Estimator):
@@ -410,6 +433,7 @@ class DNNBoostedTreeCombinedRegressor(estimator.Estimator):
                dnn_input_layer_partitioner=None,
                dnn_input_layer_to_tree=True,
                dnn_steps_to_train=10000,
+               predict_with_tree_only=False,
                tree_feature_columns=None,
                tree_center_bias=False,
                use_core_versions=False):
@@ -452,6 +476,8 @@ class DNNBoostedTreeCombinedRegressor(estimator.Estimator):
       as a feature to the tree.
       dnn_steps_to_train: Number of steps to train dnn for before switching
         to gbdt.
+      predict_with_tree_only: Whether to use only the tree model output as the
+        final prediction.
       tree_feature_columns: An iterable containing all the feature columns
         used by the model's boosted trees. If dnn_input_layer_to_tree is
         set to True, these features are in addition to dnn_feature_columns.
@@ -474,16 +500,32 @@ class DNNBoostedTreeCombinedRegressor(estimator.Estimator):
 
     def _model_fn(features, labels, mode, config):
       return _dnn_tree_combined_model_fn(
-          features, labels, mode, head, dnn_hidden_units, dnn_feature_columns,
-          tree_learner_config, num_trees, tree_examples_per_layer, config,
-          dnn_optimizer, dnn_activation_fn, dnn_dropout,
-          dnn_input_layer_partitioner, dnn_input_layer_to_tree,
-          dnn_steps_to_train, tree_feature_columns, tree_center_bias,
-          use_core_versions)
+          features=features,
+          labels=labels,
+          mode=mode,
+          head=head,
+          dnn_hidden_units=dnn_hidden_units,
+          dnn_feature_columns=dnn_feature_columns,
+          tree_learner_config=tree_learner_config,
+          num_trees=num_trees,
+          tree_examples_per_layer=tree_examples_per_layer,
+          config=config,
+          dnn_optimizer=dnn_optimizer,
+          dnn_activation_fn=dnn_activation_fn,
+          dnn_dropout=dnn_dropout,
+          dnn_input_layer_partitioner=dnn_input_layer_partitioner,
+          dnn_input_layer_to_tree=dnn_input_layer_to_tree,
+          dnn_steps_to_train=dnn_steps_to_train,
+          predict_with_tree_only=predict_with_tree_only,
+          tree_feature_columns=tree_feature_columns,
+          tree_center_bias=tree_center_bias,
+          use_core_versions=use_core_versions)
 
     super(DNNBoostedTreeCombinedRegressor, self).__init__(
-        model_fn=_model_fn, model_dir=model_dir,
-        config=config, feature_engineering_fn=feature_engineering_fn)
+        model_fn=_model_fn,
+        model_dir=model_dir,
+        config=config,
+        feature_engineering_fn=feature_engineering_fn)
 
 
 class DNNBoostedTreeCombinedEstimator(estimator.Estimator):
@@ -508,6 +550,7 @@ class DNNBoostedTreeCombinedEstimator(estimator.Estimator):
                dnn_input_layer_partitioner=None,
                dnn_input_layer_to_tree=True,
                dnn_steps_to_train=10000,
+               predict_with_tree_only=False,
                tree_feature_columns=None,
                tree_center_bias=False,
                use_core_versions=False):
@@ -545,6 +588,8 @@ class DNNBoostedTreeCombinedEstimator(estimator.Estimator):
       as a feature to the tree.
       dnn_steps_to_train: Number of steps to train dnn for before switching
         to gbdt.
+      predict_with_tree_only: Whether to use only the tree model output as the
+        final prediction.
       tree_feature_columns: An iterable containing all the feature columns
         used by the model's boosted trees. If dnn_input_layer_to_tree is
         set to True, these features are in addition to dnn_feature_columns.
@@ -553,15 +598,32 @@ class DNNBoostedTreeCombinedEstimator(estimator.Estimator):
       use_core_versions: Whether feature columns and loss are from the core (as
         opposed to contrib) version of tensorflow.
     """
+
     def _model_fn(features, labels, mode, config):
       return _dnn_tree_combined_model_fn(
-          features, labels, mode, head, dnn_hidden_units, dnn_feature_columns,
-          tree_learner_config, num_trees, tree_examples_per_layer, config,
-          dnn_optimizer, dnn_activation_fn, dnn_dropout,
-          dnn_input_layer_partitioner, dnn_input_layer_to_tree,
-          dnn_steps_to_train, tree_feature_columns, tree_center_bias,
-          use_core_versions)
+          features=features,
+          labels=labels,
+          mode=mode,
+          head=head,
+          dnn_hidden_units=dnn_hidden_units,
+          dnn_feature_columns=dnn_feature_columns,
+          tree_learner_config=tree_learner_config,
+          num_trees=num_trees,
+          tree_examples_per_layer=tree_examples_per_layer,
+          config=config,
+          dnn_optimizer=dnn_optimizer,
+          dnn_activation_fn=dnn_activation_fn,
+          dnn_dropout=dnn_dropout,
+          dnn_input_layer_partitioner=dnn_input_layer_partitioner,
+          dnn_input_layer_to_tree=dnn_input_layer_to_tree,
+          dnn_steps_to_train=dnn_steps_to_train,
+          predict_with_tree_only=predict_with_tree_only,
+          tree_feature_columns=tree_feature_columns,
+          tree_center_bias=tree_center_bias,
+          use_core_versions=use_core_versions)
 
     super(DNNBoostedTreeCombinedEstimator, self).__init__(
-        model_fn=_model_fn, model_dir=model_dir,
-        config=config, feature_engineering_fn=feature_engineering_fn)
+        model_fn=_model_fn,
+        model_dir=model_dir,
+        config=config,
+        feature_engineering_fn=feature_engineering_fn)