boosted_trees: accept integer labels properly now the same as float labels; added...
author Younghee Kwon <youngheek@google.com>
Wed, 16 May 2018 16:44:48 +0000 (09:44 -0700)
committer TensorFlower Gardener <gardener@tensorflow.org>
Wed, 16 May 2018 16:47:26 +0000 (09:47 -0700)
PiperOrigin-RevId: 196841265

tensorflow/python/estimator/canned/boosted_trees.py
tensorflow/python/estimator/canned/boosted_trees_test.py

index 6d7a329..6e4a19f 100644 (file)
@@ -57,7 +57,7 @@ def _get_transformed_features(features, sorted_feature_columns):
 
   Args:
     features: a dictionary of name to Tensor.
-    feature_columns: a list/set of tf.feature_column.
+    sorted_feature_columns: a list/set of tf.feature_column, sorted by name.
 
   Returns:
     result_features: a list of the transformed features, sorted by the name.
@@ -256,7 +256,7 @@ class _CacheTrainingStatesUsingHashTable(object):
     elif dtypes.as_dtype(dtypes.string).is_compatible_with(example_ids.dtype):
       empty_key = ''
     else:
-      raise ValueError('Unsupported example_id_feature dtype %s.',
+      raise ValueError('Unsupported example_id_feature dtype %s.' %
                        example_ids.dtype)
     # Cache holds latest <tree_id, node_id, logits> for each example.
     # tree_id and node_id are both int32 but logits is a float32.
@@ -675,6 +675,7 @@ def _create_classification_head_and_closed_form(n_classes, weight_column,
       predictions = math_ops.reciprocal(math_ops.exp(-logits) + 1.0)
       normalizer = math_ops.reciprocal(
           math_ops.cast(array_ops.size(predictions), dtypes.float32))
+      labels = math_ops.cast(labels, dtypes.float32)
       gradients = (predictions - labels) * normalizer
       hessians = predictions * (1.0 - predictions) * normalizer
       return gradients, hessians
index 95bb9b5..13595d4 100644 (file)
@@ -160,6 +160,49 @@ class BoostedTreesEstimatorTest(test_util.TensorFlowTestCase):
     self.assertAllClose([[0], [1], [1], [0], [0]],
                         [pred['class_ids'] for pred in predictions])
 
+  def testTrainClassifierWithLabelVocabulary(self):
+    apple, banana = 'apple', 'banana'
+    def _input_fn_with_label_vocab():
+      return FEATURES_DICT, [[apple], [banana], [banana], [apple], [apple]]
+    predict_input_fn = numpy_io.numpy_input_fn(
+        x=FEATURES_DICT, y=None, batch_size=1, num_epochs=1, shuffle=False)
+
+    est = boosted_trees.BoostedTreesClassifier(
+        feature_columns=self._feature_columns,
+        n_batches_per_layer=1,
+        n_trees=1,
+        max_depth=5,
+        label_vocabulary=[apple, banana])
+    est.train(input_fn=_input_fn_with_label_vocab, steps=5)
+    self._assert_checkpoint(
+        est.model_dir, global_step=5, finalized_trees=1, attempted_layers=5)
+    eval_res = est.evaluate(input_fn=_input_fn_with_label_vocab, steps=1)
+    self.assertAllClose(eval_res['accuracy'], 1.0)
+    predictions = list(est.predict(input_fn=predict_input_fn))
+    self.assertAllClose([[0], [1], [1], [0], [0]],
+                        [pred['class_ids'] for pred in predictions])
+
+  def testTrainClassifierWithIntegerLabel(self):
+    def _input_fn_with_integer_label():
+      return (FEATURES_DICT,
+              constant_op.constant([[0], [1], [1], [0], [0]], dtypes.int32))
+    predict_input_fn = numpy_io.numpy_input_fn(
+        x=FEATURES_DICT, y=None, batch_size=1, num_epochs=1, shuffle=False)
+
+    est = boosted_trees.BoostedTreesClassifier(
+        feature_columns=self._feature_columns,
+        n_batches_per_layer=1,
+        n_trees=1,
+        max_depth=5)
+    est.train(input_fn=_input_fn_with_integer_label, steps=5)
+    self._assert_checkpoint(
+        est.model_dir, global_step=5, finalized_trees=1, attempted_layers=5)
+    eval_res = est.evaluate(input_fn=_input_fn_with_integer_label, steps=1)
+    self.assertAllClose(eval_res['accuracy'], 1.0)
+    predictions = list(est.predict(input_fn=predict_input_fn))
+    self.assertAllClose([[0], [1], [1], [0], [0]],
+                        [pred['class_ids'] for pred in predictions])
+
   def testTrainClassifierWithDataset(self):
     train_input_fn = _make_train_input_fn_dataset(is_classification=True)
     predict_input_fn = numpy_io.numpy_input_fn(