Added kernels and estimators for the Gradient Boosted Trees algorithm.
author    Younghee Kwon <youngheek@google.com>
          Thu, 29 Mar 2018 16:43:19 +0000 (09:43 -0700)
committer TensorFlower Gardener <gardener@tensorflow.org>
          Thu, 29 Mar 2018 16:48:47 +0000 (09:48 -0700)
BoostedTreesClassifier and BoostedTreesRegressor are added to tf.estimator.
Some training utility functions are also added to tf.contrib.estimator.

PiperOrigin-RevId: 190942599
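For orientation, here is a minimal usage sketch of the new canned estimator. It is illustrative only: the feature name, data, and hyperparameters are made up, and boosted trees require bucketized (discretized) inputs.

```python
import numpy as np
import tensorflow as tf

fc = tf.feature_column
feature_1 = fc.bucketized_column(fc.numeric_column('feature_1'), [0., 1., 2.])

def input_fn():
  features = {'feature_1': np.array([0.5, 1.5, 2.5, -1.0], dtype=np.float32)}
  labels = np.array([[0], [1], [1], [0]], dtype=np.int32)
  dataset = tf.data.Dataset.from_tensor_slices((features, labels))
  return dataset.repeat().batch(2).make_one_shot_iterator().get_next()

classifier = tf.estimator.BoostedTreesClassifier(
    feature_columns=[feature_1],
    n_batches_per_layer=2,  # batches of statistics to accumulate per layer
    n_trees=10,
    max_depth=4)
classifier.train(input_fn=input_fn, max_steps=100)
```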

49 files changed:
tensorflow/contrib/cmake/python_modules.txt
tensorflow/contrib/cmake/python_protos.txt
tensorflow/contrib/cmake/tf_core_ops.cmake
tensorflow/contrib/cmake/tf_python.cmake
tensorflow/contrib/estimator/BUILD
tensorflow/contrib/estimator/__init__.py
tensorflow/contrib/estimator/python/estimator/boosted_trees.py [new file with mode: 0644]
tensorflow/contrib/estimator/python/estimator/boosted_trees_test.py [new file with mode: 0644]
tensorflow/contrib/makefile/tf_op_files.txt
tensorflow/contrib/makefile/tf_proto_files.txt
tensorflow/core/BUILD
tensorflow/core/api_def/base_api/api_def_BoostedTreesCalculateBestGainsPerFeature.pbtxt [new file with mode: 0644]
tensorflow/core/api_def/base_api/api_def_BoostedTreesCreateEnsemble.pbtxt [new file with mode: 0644]
tensorflow/core/api_def/base_api/api_def_BoostedTreesDeserializeEnsemble.pbtxt [new file with mode: 0644]
tensorflow/core/api_def/base_api/api_def_BoostedTreesEnsembleResourceHandleOp.pbtxt [new file with mode: 0644]
tensorflow/core/api_def/base_api/api_def_BoostedTreesGetEnsembleStates.pbtxt [new file with mode: 0644]
tensorflow/core/api_def/base_api/api_def_BoostedTreesMakeStatsSummary.pbtxt [new file with mode: 0644]
tensorflow/core/api_def/base_api/api_def_BoostedTreesPredict.pbtxt [new file with mode: 0644]
tensorflow/core/api_def/base_api/api_def_BoostedTreesSerializeEnsemble.pbtxt [new file with mode: 0644]
tensorflow/core/api_def/base_api/api_def_BoostedTreesTrainingPredict.pbtxt [new file with mode: 0644]
tensorflow/core/api_def/base_api/api_def_BoostedTreesUpdateEnsemble.pbtxt [new file with mode: 0644]
tensorflow/core/api_def/base_api/api_def_IsBoostedTreesEnsembleInitialized.pbtxt [new file with mode: 0644]
tensorflow/core/kernels/BUILD
tensorflow/core/kernels/boosted_trees/BUILD [new file with mode: 0644]
tensorflow/core/kernels/boosted_trees/boosted_trees.proto [new file with mode: 0644]
tensorflow/core/kernels/boosted_trees/prediction_ops.cc [new file with mode: 0644]
tensorflow/core/kernels/boosted_trees/resource_ops.cc [new file with mode: 0644]
tensorflow/core/kernels/boosted_trees/resources.cc [new file with mode: 0644]
tensorflow/core/kernels/boosted_trees/resources.h [new file with mode: 0644]
tensorflow/core/kernels/boosted_trees/stats_ops.cc [new file with mode: 0644]
tensorflow/core/kernels/boosted_trees/training_ops.cc [new file with mode: 0644]
tensorflow/core/ops/boosted_trees_ops.cc [new file with mode: 0644]
tensorflow/python/BUILD
tensorflow/python/__init__.py
tensorflow/python/estimator/BUILD
tensorflow/python/estimator/canned/boosted_trees.py [new file with mode: 0644]
tensorflow/python/estimator/canned/boosted_trees_test.py [new file with mode: 0644]
tensorflow/python/estimator/estimator_lib.py
tensorflow/python/kernel_tests/boosted_trees/BUILD [new file with mode: 0644]
tensorflow/python/kernel_tests/boosted_trees/__init__.py [new file with mode: 0644]
tensorflow/python/kernel_tests/boosted_trees/prediction_ops_test.py [new file with mode: 0644]
tensorflow/python/kernel_tests/boosted_trees/resource_ops_test.py [new file with mode: 0644]
tensorflow/python/kernel_tests/boosted_trees/stats_ops_test.py [new file with mode: 0644]
tensorflow/python/kernel_tests/boosted_trees/training_ops_test.py [new file with mode: 0644]
tensorflow/python/ops/boosted_trees_ops.py [new file with mode: 0644]
tensorflow/python/training/device_setter.py
tensorflow/tools/api/golden/tensorflow.estimator.-boosted-trees-classifier.pbtxt [new file with mode: 0644]
tensorflow/tools/api/golden/tensorflow.estimator.-boosted-trees-regressor.pbtxt [new file with mode: 0644]
tensorflow/tools/api/golden/tensorflow.estimator.pbtxt

index 112b690..cc7d791 100644 (file)
@@ -79,6 +79,7 @@ tensorflow/python/keras/_impl/keras/preprocessing
 tensorflow/python/keras/_impl/keras/utils
 tensorflow/python/keras/_impl/keras/wrappers
 tensorflow/python/kernel_tests
+tensorflow/python/kernel_tests/boosted_trees
 tensorflow/python/kernel_tests/distributions
 tensorflow/python/kernel_tests/linalg
 tensorflow/python/kernel_tests/random
index c03c0c8..0c80d52 100644 (file)
@@ -1,4 +1,5 @@
 tensorflow/core
+tensorflow/core/kernels/boosted_trees
 tensorflow/core/profiler
 tensorflow/python
 tensorflow/contrib/boosted_trees/proto
index d6712aa..092a48b 100644 (file)
@@ -15,8 +15,9 @@
 set(tf_op_lib_names
     "audio_ops"
     "array_ops"
-               "batch_ops"
+    "batch_ops"
     "bitwise_ops"
+    "boosted_trees_ops"
     "candidate_sampling_ops"
     "checkpoint_ops"
     "control_flow_ops"
@@ -28,7 +29,7 @@ set(tf_op_lib_names
     "image_ops"
     "io_ops"
     "linalg_ops"
-               "list_ops"
+    "list_ops"
     "lookup_ops"
     "logging_ops"
     "manip_ops"
@@ -48,7 +49,7 @@ set(tf_op_lib_names
     "state_ops"
     "stateless_random_ops"
     "string_ops"
-               "summary_ops"
+    "summary_ops"
     "training_ops"
 )
 
index 31e715b..b776307 100755 (executable)
@@ -319,6 +319,7 @@ GENERATE_PYTHON_OP_LIB("audio_ops")
 GENERATE_PYTHON_OP_LIB("array_ops")
 GENERATE_PYTHON_OP_LIB("batch_ops")
 GENERATE_PYTHON_OP_LIB("bitwise_ops")
+GENERATE_PYTHON_OP_LIB("boosted_trees_ops")
 GENERATE_PYTHON_OP_LIB("math_ops")
 GENERATE_PYTHON_OP_LIB("functional_ops")
 GENERATE_PYTHON_OP_LIB("candidate_sampling_ops")
index d125e40..2be62c9 100644 (file)
@@ -14,6 +14,7 @@ py_library(
     srcs = ["__init__.py"],
     srcs_version = "PY2AND3",
     deps = [
+        ":boosted_trees",
         ":dnn",
         ":dnn_linear_combined",
         ":extenders",
@@ -27,6 +28,36 @@ py_library(
 )
 
 py_library(
+    name = "boosted_trees",
+    srcs = ["python/estimator/boosted_trees.py"],
+    srcs_version = "PY2AND3",
+    deps = [
+        "//tensorflow/python/estimator",
+        "//tensorflow/python/estimator:boosted_trees",
+    ],
+)
+
+py_test(
+    name = "boosted_trees_test",
+    size = "medium",
+    srcs = ["python/estimator/boosted_trees_test.py"],
+    srcs_version = "PY2AND3",
+    tags = [
+        "no_pip",
+        "notsan",
+    ],
+    deps = [
+        ":boosted_trees",
+        "//tensorflow/python:dtypes",
+        "//tensorflow/python:framework_test_lib",
+        "//tensorflow/python:training",
+        "//tensorflow/python/estimator:numpy_io",
+        "//tensorflow/python/feature_column",
+        "//third_party/py/numpy",
+    ],
+)
+
+py_library(
     name = "dnn",
     srcs = ["python/estimator/dnn.py"],
     srcs_version = "PY2AND3",
index 6b9f957..d2fc2c4 100644 (file)
@@ -19,6 +19,7 @@ from __future__ import division
 from __future__ import print_function
 
 # pylint: disable=unused-import,line-too-long,wildcard-import
+from tensorflow.contrib.estimator.python.estimator.boosted_trees import *
 from tensorflow.contrib.estimator.python.estimator.dnn import *
 from tensorflow.contrib.estimator.python.estimator.dnn_linear_combined import *
 from tensorflow.contrib.estimator.python.estimator.extenders import *
@@ -44,6 +45,8 @@ _allowed_symbols = [
     'DNNEstimator',
     'DNNLinearCombinedEstimator',
     'LinearEstimator',
+    'boosted_trees_classifier_train_in_memory',
+    'boosted_trees_regressor_train_in_memory',
     'call_logit_fn',
     'dnn_logit_fn_builder',
     'linear_logit_fn_builder',
diff --git a/tensorflow/contrib/estimator/python/estimator/boosted_trees.py b/tensorflow/contrib/estimator/python/estimator/boosted_trees.py
new file mode 100644 (file)
index 0000000..5880164
--- /dev/null
@@ -0,0 +1,323 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Boosted Trees estimators."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.estimator import estimator
+from tensorflow.python.estimator.canned import boosted_trees as canned_boosted_trees
+
+
+class _BoostedTreesEstimator(estimator.Estimator):
+  """An Estimator for Tensorflow Boosted Trees models."""
+
+  def __init__(self,
+               feature_columns,
+               n_batches_per_layer,
+               head,
+               model_dir=None,
+               weight_column=None,
+               n_trees=100,
+               max_depth=6,
+               learning_rate=0.1,
+               l1_regularization=0.,
+               l2_regularization=0.,
+               tree_complexity=0.,
+               config=None):
+    """Initializes a `BoostedTreesEstimator` instance.
+
+    Args:
+      feature_columns: An iterable containing all the feature columns used by
+        the model. All items in the set should be instances of classes derived
+        from `FeatureColumn`.
+      n_batches_per_layer: the number of batches to collect statistics per
+        layer.
+      head: the `Head` instance defined for Estimator.
+      model_dir: Directory to save model parameters, graph, etc. This can
+        also be used to load checkpoints from the directory into an estimator
+        to continue training a previously saved model.
+      weight_column: A string or a `_NumericColumn` created by
+        `tf.feature_column.numeric_column` defining the feature column that
+        represents weights. It is used to downweight or boost examples during
+        training, and is multiplied by the loss of the example. If it is a
+        string, it is used as a key to fetch the weight tensor from the
+        `features`. If it is a `_NumericColumn`, the raw tensor is fetched by
+        the key `weight_column.key`, and then `weight_column.normalizer_fn` is
+        applied to it to produce the weight tensor.
+      n_trees: number of trees to be created.
+      max_depth: maximum depth of the tree to grow.
+      learning_rate: shrinkage parameter to be used when a tree is added to
+        the model.
+      l1_regularization: regularization multiplier applied to the absolute
+        weights of the tree leaves.
+      l2_regularization: regularization multiplier applied to the squared
+        weights of the tree leaves.
+      tree_complexity: regularization factor to penalize trees with more leaves.
+      config: `RunConfig` object to configure the runtime settings.
+    """
+    # TODO(youngheek): param validations.
+
+    # HParams for the model.
+    tree_hparams = canned_boosted_trees.TreeHParams(
+        n_trees, max_depth, learning_rate, l1_regularization, l2_regularization,
+        tree_complexity)
+
+    def _model_fn(features, labels, mode, config):
+      return canned_boosted_trees._bt_model_fn(  # pylint: disable=protected-access
+          features, labels, mode, head, feature_columns, tree_hparams,
+          n_batches_per_layer, config)
+
+    super(_BoostedTreesEstimator, self).__init__(
+        model_fn=_model_fn, model_dir=model_dir, config=config)
+
+
+def boosted_trees_classifier_train_in_memory(
+    train_input_fn,
+    feature_columns,
+    model_dir=None,
+    n_classes=canned_boosted_trees._HOLD_FOR_MULTI_CLASS_SUPPORT,
+    weight_column=None,
+    label_vocabulary=None,
+    n_trees=100,
+    max_depth=6,
+    learning_rate=0.1,
+    l1_regularization=0.,
+    l2_regularization=0.,
+    tree_complexity=0.,
+    config=None,
+    train_hooks=None):
+  """Trains a boosted tree classifier with in memory dataset.
+
+  Example:
+
+  ```python
+  bucketized_feature_1 = bucketized_column(
+    numeric_column('feature_1'), BUCKET_BOUNDARIES_1)
+  bucketized_feature_2 = bucketized_column(
+    numeric_column('feature_2'), BUCKET_BOUNDARIES_2)
+
+  def input_fn_train():
+    dataset = ...  # Create a dataset from the training data.
+    # Don't use repeat or cache, since it is assumed to be one epoch.
+    # This is either a tf.data.Dataset, or a tuple of feature dict and label.
+    return dataset
+
+  classifier = boosted_trees_classifier_train_in_memory(
+      train_input_fn,
+      feature_columns=[bucketized_feature_1, bucketized_feature_2],
+      n_trees=100,
+      ... <some other params>
+  )
+
+  def input_fn_eval():
+    ...
+    return dataset
+
+  metrics = classifier.evaluate(input_fn=input_fn_eval, steps=10)
+  ```
+
+  Args:
+    train_input_fn: an input function that returns a dataset containing a
+      single epoch of *unbatched* features and labels.
+    feature_columns: An iterable containing all the feature columns used by
+      the model. All items in the set should be instances of classes derived
+      from `FeatureColumn`.
+    model_dir: Directory to save model parameters, graph, etc. This can
+      also be used to load checkpoints from the directory into an estimator
+      to continue training a previously saved model.
+    n_classes: number of label classes. Default is binary classification.
+      Multiclass support is not yet implemented.
+    weight_column: A string or a `_NumericColumn` created by
+      `tf.feature_column.numeric_column` defining the feature column that
+      represents weights. It is used to downweight or boost examples during
+      training, and is multiplied by the loss of the example. If it is a
+      string, it is used as a key to fetch the weight tensor from the
+      `features`. If it is a `_NumericColumn`, the raw tensor is fetched by
+      the key `weight_column.key`, and then `weight_column.normalizer_fn` is
+      applied to it to produce the weight tensor.
+    label_vocabulary: A list of strings representing the possible label
+      values. If given, labels must be of string type and take values from
+      `label_vocabulary`. If it is not given, labels must already be encoded
+      as integers or floats within [0, 1] for `n_classes=2`, or as integer
+      values in {0, 1, ..., n_classes-1} for `n_classes>2`. An error is
+      raised if the vocabulary is not provided and the labels are strings.
+    n_trees: number of trees to be created.
+    max_depth: maximum depth of the tree to grow.
+    learning_rate: shrinkage parameter to be used when a tree is added to
+      the model.
+    l1_regularization: regularization multiplier applied to the absolute
+      weights of the tree leaves.
+    l2_regularization: regularization multiplier applied to the squared
+      weights of the tree leaves.
+    tree_complexity: regularization factor to penalize trees with more leaves.
+    config: `RunConfig` object to configure the runtime settings.
+    train_hooks: a list of Hook instances to be passed to estimator.train().
+
+  Returns:
+    a `BoostedTreesClassifier` instance created with the given arguments and
+      trained with the data loaded into memory from the input_fn.
+
+  Raises:
+    ValueError: when wrong arguments are given or unsupported functionalities
+       are requested.
+  """
+  # pylint: disable=protected-access
+  # TODO(nponomareva): Support multi-class cases.
+  if n_classes == canned_boosted_trees._HOLD_FOR_MULTI_CLASS_SUPPORT:
+    n_classes = 2
+  head, closed_form = (
+      canned_boosted_trees._create_classification_head_and_closed_form(
+          n_classes, weight_column, label_vocabulary=label_vocabulary))
+
+  # HParams for the model.
+  tree_hparams = canned_boosted_trees.TreeHParams(
+      n_trees, max_depth, learning_rate, l1_regularization, l2_regularization,
+      tree_complexity)
+
+  def _model_fn(features, labels, mode, config):
+    return canned_boosted_trees._bt_model_fn(
+        features,
+        labels,
+        mode,
+        head,
+        feature_columns,
+        tree_hparams,
+        n_batches_per_layer=1,
+        config=config,
+        closed_form_grad_and_hess_fn=closed_form,
+        train_in_memory=True)
+
+  in_memory_classifier = estimator.Estimator(
+      model_fn=_model_fn, model_dir=model_dir, config=config)
+
+  in_memory_classifier.train(input_fn=train_input_fn, hooks=train_hooks)
+
+  return in_memory_classifier
+  # pylint: enable=protected-access
+
+
+def boosted_trees_regressor_train_in_memory(
+    train_input_fn,
+    feature_columns,
+    model_dir=None,
+    label_dimension=canned_boosted_trees._HOLD_FOR_MULTI_DIM_SUPPORT,
+    weight_column=None,
+    n_trees=100,
+    max_depth=6,
+    learning_rate=0.1,
+    l1_regularization=0.,
+    l2_regularization=0.,
+    tree_complexity=0.,
+    config=None,
+    train_hooks=None):
+  """Trains a boosted tree regressor with in memory dataset.
+
+  Example:
+
+  ```python
+  bucketized_feature_1 = bucketized_column(
+    numeric_column('feature_1'), BUCKET_BOUNDARIES_1)
+  bucketized_feature_2 = bucketized_column(
+    numeric_column('feature_2'), BUCKET_BOUNDARIES_2)
+
+  def input_fn_train():
+    dataset = ...  # Create a dataset from the training data.
+    # Don't use repeat or cache, since it is assumed to be one epoch.
+    # This is either a tf.data.Dataset, or a tuple of feature dict and label.
+    return dataset
+
+  regressor = boosted_trees_regressor_train_in_memory(
+      train_input_fn,
+      feature_columns=[bucketized_feature_1, bucketized_feature_2],
+      n_trees=100,
+      ... <some other params>
+  )
+
+  def input_fn_eval():
+    ...
+    return dataset
+
+  metrics = regressor.evaluate(input_fn=input_fn_eval, steps=10)
+  ```
+
+  Args:
+    train_input_fn: an input function that returns a dataset containing a
+      single epoch of *unbatched* features and labels.
+    feature_columns: An iterable containing all the feature columns used by
+      the model. All items in the set should be instances of classes derived
+      from `FeatureColumn`.
+    model_dir: Directory to save model parameters, graph, etc. This can
+      also be used to load checkpoints from the directory into an estimator
+      to continue training a previously saved model.
+    label_dimension: Number of regression targets per example.
+      Multi-dimensional support is not yet implemented.
+    weight_column: A string or a `_NumericColumn` created by
+      `tf.feature_column.numeric_column` defining the feature column that
+      represents weights. It is used to downweight or boost examples during
+      training, and is multiplied by the loss of the example. If it is a
+      string, it is used as a key to fetch the weight tensor from the
+      `features`. If it is a `_NumericColumn`, the raw tensor is fetched by
+      the key `weight_column.key`, and then `weight_column.normalizer_fn` is
+      applied to it to produce the weight tensor.
+    n_trees: number of trees to be created.
+    max_depth: maximum depth of the tree to grow.
+    learning_rate: shrinkage parameter to be used when a tree is added to
+      the model.
+    l1_regularization: regularization multiplier applied to the absolute
+      weights of the tree leaves.
+    l2_regularization: regularization multiplier applied to the squared
+      weights of the tree leaves.
+    tree_complexity: regularization factor to penalize trees with more leaves.
+    config: `RunConfig` object to configure the runtime settings.
+    train_hooks: a list of Hook instances to be passed to estimator.train().
+
+  Returns:
+    a `BoostedTreesRegressor` instance created with the given arguments and
+      trained with the data loaded into memory from the input_fn.
+
+  Raises:
+    ValueError: when wrong arguments are given or unsupported functionalities
+       are requested.
+  """
+  # pylint: disable=protected-access
+  # TODO(nponomareva): Extend it to multi-dimension cases.
+  if label_dimension == canned_boosted_trees._HOLD_FOR_MULTI_DIM_SUPPORT:
+    label_dimension = 1
+  head = canned_boosted_trees._create_regression_head(label_dimension,
+                                                      weight_column)
+
+  # HParams for the model.
+  tree_hparams = canned_boosted_trees.TreeHParams(
+      n_trees, max_depth, learning_rate, l1_regularization, l2_regularization,
+      tree_complexity)
+
+  def _model_fn(features, labels, mode, config):
+    return canned_boosted_trees._bt_model_fn(
+        features,
+        labels,
+        mode,
+        head,
+        feature_columns,
+        tree_hparams,
+        n_batches_per_layer=1,
+        config=config,
+        train_in_memory=True)
+
+  in_memory_regressor = estimator.Estimator(
+      model_fn=_model_fn, model_dir=model_dir, config=config)
+
+  in_memory_regressor.train(input_fn=train_input_fn, hooks=train_hooks)
+
+  return in_memory_regressor
+  # pylint: enable=protected-access
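For reference, a condensed end-to-end call of the regressor helper above, loosely mirroring the tests in the next file (the feature name, boundaries, and data are illustrative):

```python
import numpy as np
from tensorflow.contrib.estimator.python.estimator import boosted_trees
from tensorflow.python.feature_column import feature_column

def train_input_fn():
  # One epoch of unbatched data: no repeat(), no batch().
  features = {'f_0': np.array([12.5, 1.0, -2.001], dtype=np.float32)}
  labels = np.array([[1.5], [0.3], [0.2]], dtype=np.float32)
  return features, labels

bucketized = feature_column.bucketized_column(
    feature_column.numeric_column('f_0'), [-2., .5, 12.])
est = boosted_trees.boosted_trees_regressor_train_in_memory(
    train_input_fn=train_input_fn,
    feature_columns=[bucketized],
    n_trees=1,
    max_depth=5)
```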
diff --git a/tensorflow/contrib/estimator/python/estimator/boosted_trees_test.py b/tensorflow/contrib/estimator/python/estimator/boosted_trees_test.py
new file mode 100644 (file)
index 0000000..e99a87f
--- /dev/null
@@ -0,0 +1,207 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests boosted_trees estimators."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.contrib.estimator.python.estimator import boosted_trees
+from tensorflow.python.estimator.canned import boosted_trees as canned_boosted_trees
+from tensorflow.python.estimator.inputs import numpy_io
+from tensorflow.python.feature_column import feature_column
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import test_util
+from tensorflow.python.platform import googletest
+from tensorflow.python.training import checkpoint_utils
+
+NUM_FEATURES = 3
+
+BUCKET_BOUNDARIES = [-2., .5, 12.]  # Boundaries for all the features.
+INPUT_FEATURES = np.array(
+    [
+        [12.5, 1.0, -2.001, -2.0001, -1.999],  # feature_0 quantized:[3,2,0,0,1]
+        [2.0, -3.0, 0.5, 0.0, 0.4995],         # feature_1 quantized:[2,0,2,1,1]
+        [3.0, 20.0, 50.0, -100.0, 102.75],     # feature_2 quantized:[2,3,3,0,3]
+    ],
+    dtype=np.float32)
+CLASSIFICATION_LABELS = [[0.], [1.], [1.], [0.], [0.]]
+REGRESSION_LABELS = [[1.5], [0.3], [0.2], [2.], [5.]]
+FEATURES_DICT = {'f_%d' % i: INPUT_FEATURES[i] for i in range(NUM_FEATURES)}
+
+
+def _make_train_input_fn(is_classification):
+  """Makes train input_fn for classification/regression."""
+
+  def _input_fn():
+    features = dict(FEATURES_DICT)
+    if is_classification:
+      labels = CLASSIFICATION_LABELS
+    else:
+      labels = REGRESSION_LABELS
+    return features, labels
+
+  return _input_fn
+
+
+class BoostedTreesEstimatorTest(test_util.TensorFlowTestCase):
+
+  def setUp(self):
+    self._head = canned_boosted_trees._create_regression_head(label_dimension=1)
+    self._feature_columns = {
+        feature_column.bucketized_column(
+            feature_column.numeric_column('f_%d' % i, dtype=dtypes.float32),
+            BUCKET_BOUNDARIES)
+        for i in range(NUM_FEATURES)
+    }
+
+  def _assert_checkpoint(self, model_dir, expected_global_step):
+    self.assertEqual(expected_global_step,
+                     checkpoint_utils.load_variable(model_dir,
+                                                    ops.GraphKeys.GLOBAL_STEP))
+
+  def testTrainAndEvaluateEstimator(self):
+    input_fn = _make_train_input_fn(is_classification=False)
+
+    est = boosted_trees._BoostedTreesEstimator(
+        feature_columns=self._feature_columns,
+        n_batches_per_layer=1,
+        n_trees=2,
+        head=self._head,
+        max_depth=5)
+
+    # It will stop after 10 steps because of the max depth and num trees.
+    num_steps = 100
+    # Train for a few steps, and validate final checkpoint.
+    est.train(input_fn, steps=num_steps)
+    self._assert_checkpoint(est.model_dir, 11)
+    eval_res = est.evaluate(input_fn=input_fn, steps=1)
+    self.assertAllClose(eval_res['average_loss'], 0.913176)
+
+  def testInferEstimator(self):
+    train_input_fn = _make_train_input_fn(is_classification=False)
+    predict_input_fn = numpy_io.numpy_input_fn(
+        x=FEATURES_DICT, y=None, batch_size=1, num_epochs=1, shuffle=False)
+
+    est = boosted_trees._BoostedTreesEstimator(
+        feature_columns=self._feature_columns,
+        n_batches_per_layer=1,
+        n_trees=1,
+        max_depth=5,
+        head=self._head)
+
+    # It will stop after 5 steps because of the max depth and num trees.
+    num_steps = 100
+    # Train for a few steps, and validate final checkpoint.
+    est.train(train_input_fn, steps=num_steps)
+    self._assert_checkpoint(est.model_dir, 6)
+
+    predictions = list(est.predict(input_fn=predict_input_fn))
+    self.assertEquals(5, len(predictions))
+    self.assertAllClose([0.703549], predictions[0]['predictions'])
+    self.assertAllClose([0.266539], predictions[1]['predictions'])
+    self.assertAllClose([0.256479], predictions[2]['predictions'])
+    self.assertAllClose([1.088732], predictions[3]['predictions'])
+    self.assertAllClose([1.901732], predictions[4]['predictions'])
+
+
+class BoostedTreesClassifierTrainInMemoryTest(test_util.TensorFlowTestCase):
+
+  def setUp(self):
+    self._feature_columns = {
+        feature_column.bucketized_column(
+            feature_column.numeric_column('f_%d' % i, dtype=dtypes.float32),
+            BUCKET_BOUNDARIES)
+        for i in range(NUM_FEATURES)
+    }
+
+  def _assert_checkpoint(self, model_dir, expected_global_step):
+    self.assertEqual(expected_global_step,
+                     checkpoint_utils.load_variable(model_dir,
+                                                    ops.GraphKeys.GLOBAL_STEP))
+
+  def testBinaryClassifierTrainInMemoryAndEvalAndInfer(self):
+    train_input_fn = _make_train_input_fn(is_classification=True)
+    predict_input_fn = numpy_io.numpy_input_fn(
+        x=FEATURES_DICT, y=None, batch_size=1, num_epochs=1, shuffle=False)
+
+    est = boosted_trees.boosted_trees_classifier_train_in_memory(
+        train_input_fn=train_input_fn,
+        feature_columns=self._feature_columns,
+        n_trees=1,
+        max_depth=5)
+    # It will stop after 5 steps because of the max depth and num trees.
+    self._assert_checkpoint(est.model_dir, 6)
+
+    # Check eval.
+    eval_res = est.evaluate(input_fn=train_input_fn, steps=1)
+    self.assertAllClose(eval_res['accuracy'], 1.0)
+
+    # Check predict that all labels are correct.
+    predictions = list(est.predict(input_fn=predict_input_fn))
+    self.assertEquals(5, len(predictions))
+    self.assertAllClose([0], predictions[0]['class_ids'])
+    self.assertAllClose([1], predictions[1]['class_ids'])
+    self.assertAllClose([1], predictions[2]['class_ids'])
+    self.assertAllClose([0], predictions[3]['class_ids'])
+    self.assertAllClose([0], predictions[4]['class_ids'])
+
+
+class BoostedTreesRegressorTrainInMemoryTest(test_util.TensorFlowTestCase):
+
+  def setUp(self):
+    self._feature_columns = {
+        feature_column.bucketized_column(
+            feature_column.numeric_column('f_%d' % i, dtype=dtypes.float32),
+            BUCKET_BOUNDARIES)
+        for i in range(NUM_FEATURES)
+    }
+
+  def _assert_checkpoint(self, model_dir, expected_global_step):
+    self.assertEqual(expected_global_step,
+                     checkpoint_utils.load_variable(model_dir,
+                                                    ops.GraphKeys.GLOBAL_STEP))
+
+  def testRegressorTrainInMemoryAndEvalAndInfer(self):
+    train_input_fn = _make_train_input_fn(is_classification=False)
+    predict_input_fn = numpy_io.numpy_input_fn(
+        x=FEATURES_DICT, y=None, batch_size=1, num_epochs=1, shuffle=False)
+
+    est = boosted_trees.boosted_trees_regressor_train_in_memory(
+        train_input_fn=train_input_fn,
+        feature_columns=self._feature_columns,
+        n_trees=1,
+        max_depth=5)
+    # It will stop after 5 steps because of the max depth and num trees.
+    self._assert_checkpoint(est.model_dir, 6)
+
+    # Check eval.
+    eval_res = est.evaluate(input_fn=train_input_fn, steps=1)
+    self.assertAllClose(eval_res['average_loss'], 2.2136638)
+
+    # Validate predictions.
+    predictions = list(est.predict(input_fn=predict_input_fn))
+    self.assertEquals(5, len(predictions))
+    self.assertAllClose([0.703549], predictions[0]['predictions'])
+    self.assertAllClose([0.266539], predictions[1]['predictions'])
+    self.assertAllClose([0.256479], predictions[2]['predictions'])
+    self.assertAllClose([1.088732], predictions[3]['predictions'])
+    self.assertAllClose([1.901732], predictions[4]['predictions'])
+
+
+if __name__ == '__main__':
+  googletest.main()
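A side note on the test fixtures above: the quantized bucket ids noted in the INPUT_FEATURES comments follow from BUCKET_BOUNDARIES with left-inclusive buckets, which numpy can reproduce directly (a verification sketch, not part of the change):

```python
import numpy as np

BUCKET_BOUNDARIES = [-2., .5, 12.]
feature_0 = np.array([12.5, 1.0, -2.001, -2.0001, -1.999], dtype=np.float32)
# np.digitize uses the same left-inclusive convention as bucketized_column,
# reproducing the ids from the comment: [3, 2, 0, 0, 1].
print(np.digitize(feature_0, BUCKET_BOUNDARIES))  # -> [3 2 0 0 1]
```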
index 5a812af..1578629 100644 (file)
@@ -228,6 +228,11 @@ tensorflow/core/kernels/cast_op_impl_int64.cc
 tensorflow/core/kernels/cast_op_impl_int8.cc
 tensorflow/core/kernels/cast_op_impl_uint16.cc
 tensorflow/core/kernels/cast_op_impl_uint8.cc
+tensorflow/core/kernels/boosted_trees/prediction_ops.cc
+tensorflow/core/kernels/boosted_trees/resource_ops.cc
+tensorflow/core/kernels/boosted_trees/resources.cc
+tensorflow/core/kernels/boosted_trees/stats_ops.cc
+tensorflow/core/kernels/boosted_trees/training_ops.cc
 tensorflow/core/kernels/bias_op.cc
 tensorflow/core/kernels/bcast_ops.cc
 tensorflow/core/kernels/batch_norm_op.cc
@@ -285,6 +290,7 @@ tensorflow/core/ops/data_flow_ops.cc
 tensorflow/core/ops/ctc_ops.cc
 tensorflow/core/ops/control_flow_ops.cc
 tensorflow/core/ops/candidate_sampling_ops.cc
+tensorflow/core/ops/boosted_trees_ops.cc
 tensorflow/core/ops/array_ops.cc
 tensorflow/core/ops/array_grad.cc
 tensorflow/core/kernels/spacetobatch_functor.cc
index d569bde..1f25469 100644 (file)
@@ -18,6 +18,7 @@ tensorflow/core/protobuf/device_properties.proto
 tensorflow/core/protobuf/rewriter_config.proto
 tensorflow/core/protobuf/tensor_bundle.proto
 tensorflow/core/lib/core/error_codes.proto
+tensorflow/core/kernels/boosted_trees/boosted_trees.proto
 tensorflow/core/framework/versions.proto
 tensorflow/core/framework/variable.proto
 tensorflow/core/framework/types.proto
index b8dbd90..614e06c 100644 (file)
@@ -629,6 +629,7 @@ tf_gen_op_libs(
     op_lib_names = [
         "batch_ops",
         "bitwise_ops",
+        "boosted_trees_ops",
         "candidate_sampling_ops",
         "checkpoint_ops",
         "control_flow_ops",
@@ -741,6 +742,7 @@ cc_library(
         ":audio_ops_op_lib",
         ":batch_ops_op_lib",
         ":bitwise_ops_op_lib",
+        ":boosted_trees_ops_op_lib",
         ":candidate_sampling_ops_op_lib",
         ":checkpoint_ops_op_lib",
         ":control_flow_ops_op_lib",
@@ -882,6 +884,7 @@ cc_library(
         "//tensorflow/core/kernels:audio",
         "//tensorflow/core/kernels:batch_kernels",
         "//tensorflow/core/kernels:bincount_op",
+        "//tensorflow/core/kernels:boosted_trees_ops",
         "//tensorflow/core/kernels:candidate_sampler_ops",
         "//tensorflow/core/kernels:checkpoint_ops",
         "//tensorflow/core/kernels:control_flow_ops",
diff --git a/tensorflow/core/api_def/base_api/api_def_BoostedTreesCalculateBestGainsPerFeature.pbtxt b/tensorflow/core/api_def/base_api/api_def_BoostedTreesCalculateBestGainsPerFeature.pbtxt
new file mode 100644 (file)
index 0000000..b1921e3
--- /dev/null
@@ -0,0 +1,87 @@
+op {
+  graph_op_name: "BoostedTreesCalculateBestGainsPerFeature"
+  visibility: HIDDEN
+  in_arg {
+    name: "node_id_range"
+    description: <<END
+A Rank 1 tensor (shape=[2]) specifying the range [first, last] of node ids to process within `stats_summary_list`. The nodes are iterated over as in `for node_id in range(node_id_range[0], node_id_range[1] + 1)`; note that the last index, node_id_range[1], is inclusive.
+END
+  }
+  in_arg {
+    name: "stats_summary_list"
+    description: <<END
+A list of Rank 3 tensors (shape=[max_splits, bucket, 2]) for the accumulated stats summary (gradient/hessian) per node, per bucket, for each feature. The first dimension of each tensor is the maximum number of splits, so not all of its elements are used; only the indexes specified by node_ids are.
+END
+  }
+  out_arg {
+    name: "node_ids_list"
+    description: <<END
+An output list of Rank 1 tensors indicating possible split node ids for each feature. The length of the list is num_features, but each tensor can have a different size, since each feature provides a different set of possible nodes. See above for details like shapes and sizes.
+END
+  }
+  out_arg {
+    name: "gains_list"
+    description: <<END
+An output list of Rank 1 tensors indicating the best gains for each feature to split for certain nodes. See above for details like shapes and sizes.
+END
+  }
+  out_arg {
+    name: "thresholds_list"
+    description: <<END
+An output list of Rank 1 tensors indicating the bucket id to compare with (as a threshold) for split in each node. See above for details like shapes and sizes.
+END
+  }
+  out_arg {
+    name: "left_node_contribs_list"
+    description: <<END
+A list of Rank 2 tensors indicating the contribution of the left nodes when branching from parent nodes (given by the tensor element in the output node_ids_list) to the left direction by the given threshold for each feature. This value will be used to make the left node value by adding to the parent node value. Second dimension size is 1 for 1-dimensional logits, but would be larger for multi-class problems. See above for details like shapes and sizes.
+END
+  }
+  out_arg {
+    name: "right_node_contribs_list"
+    description: <<END
+A list of Rank 2 tensors with the same shape/conditions as left_node_contribs_list, except that the values are for the right node.
+END
+  }
+  attr {
+    name: "l1"
+    description: <<END
+l1 regularization factor on leaf weights, per instance based.
+END
+  }
+  attr {
+    name: "l2"
+    description: <<END
+l2 regularization factor on leaf weights, per instance based.
+END
+  }
+  attr {
+    name: "tree_complexity"
+    description: <<END
+adjustment to the gain, per leaf based.
+END
+  }
+  attr {
+    name: "max_splits"
+    description: <<END
+the number of nodes that can be split in the whole tree. Used as a dimension of output tensors.
+END
+  }
+  attr {
+    name: "num_features"
+    description: <<END
+inferred from the size of `stats_summary_list`; the total number of features.
+END
+  }
+  summary: "Calculates gains for each feature and returns the best possible split information for the feature."
+  description: <<END
+The split information is the best threshold (bucket id), gains and left/right node contributions per node for each feature.
+
+It is possible that not all nodes can be split on each feature. Hence, the list of possible nodes can differ between the features. Therefore, we return `node_ids_list` for each feature, containing the nodes that can be split using that feature.
+
+In this manner, the output is the best split per feature and per node; these need to be combined later to produce the best split for each node (among all possible features).
+
+All output lists have the same length, `num_features`.
+The output shapes are compatible in that the first dimension of all tensors across all lists is the same, and equal to the number of possible split nodes for each feature.
+END
+}
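The gain computation itself lives in stats_ops.cc (added below). As a rough, hedged illustration of what a second-order split gain of this kind typically looks like — the exact kernel math, including how l1, l2, and tree_complexity enter, may differ in detail:

```python
import numpy as np

def _leaf_score(g, h, l1, l2):
  # Illustrative second-order leaf objective with L1 soft-thresholding.
  g = np.sign(g) * np.maximum(np.abs(g) - l1, 0.0)
  return np.square(g) / (h + l2)

def best_gain_for_node(grad_per_bucket, hess_per_bucket, l1, l2,
                       tree_complexity):
  """Best bucket threshold for one node, from per-bucket grad/hess sums."""
  total_g, total_h = grad_per_bucket.sum(), hess_per_bucket.sum()
  left_g = np.cumsum(grad_per_bucket)[:-1]   # candidate left partitions
  left_h = np.cumsum(hess_per_bucket)[:-1]
  right_g, right_h = total_g - left_g, total_h - left_h
  gains = (_leaf_score(left_g, left_h, l1, l2)
           + _leaf_score(right_g, right_h, l1, l2)
           - _leaf_score(total_g, total_h, l1, l2)
           - tree_complexity)
  threshold = int(np.argmax(gains))  # bucket id to compare with
  return gains[threshold], threshold
```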
diff --git a/tensorflow/core/api_def/base_api/api_def_BoostedTreesCreateEnsemble.pbtxt b/tensorflow/core/api_def/base_api/api_def_BoostedTreesCreateEnsemble.pbtxt
new file mode 100644 (file)
index 0000000..aee73b9
--- /dev/null
@@ -0,0 +1,23 @@
+op {
+  graph_op_name: "BoostedTreesCreateEnsemble"
+  visibility: HIDDEN
+  in_arg {
+    name: "tree_ensemble_handle"
+    description: <<END
+Handle to the tree ensemble resource to be created.
+END
+  }
+  in_arg {
+    name: "stamp_token"
+    description: <<END
+Token to use as the initial value of the resource stamp.
+END
+  }
+  in_arg {
+    name: "tree_ensemble_serialized"
+    description: <<END
+Serialized proto of the tree ensemble.
+END
+  }
+  summary: "Creates a tree ensemble model and returns a handle to it."
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_BoostedTreesDeserializeEnsemble.pbtxt b/tensorflow/core/api_def/base_api/api_def_BoostedTreesDeserializeEnsemble.pbtxt
new file mode 100644 (file)
index 0000000..b1602ba
--- /dev/null
@@ -0,0 +1,26 @@
+op {
+  graph_op_name: "BoostedTreesDeserializeEnsemble"
+  visibility: HIDDEN
+  in_arg {
+    name: "tree_ensemble_handle"
+    description: <<END
+Handle to the tree ensemble.
+END
+  }
+  in_arg {
+    name: "stamp_token"
+    description: <<END
+Token to use as the new value of the resource stamp.
+END
+  }
+  in_arg {
+    name: "tree_ensemble_serialized"
+    description: <<END
+Serialized proto of the ensemble.
+END
+  }
+  summary: "Deserializes a serialized tree ensemble config and replaces current tree"
+  description: <<END
+ensemble.
+END
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_BoostedTreesEnsembleResourceHandleOp.pbtxt b/tensorflow/core/api_def/base_api/api_def_BoostedTreesEnsembleResourceHandleOp.pbtxt
new file mode 100644 (file)
index 0000000..1bce563
--- /dev/null
@@ -0,0 +1,5 @@
+op {
+  graph_op_name: "BoostedTreesEnsembleResourceHandleOp"
+  visibility: HIDDEN
+  summary: "Creates a handle to a BoostedTreesEnsembleResource"
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_BoostedTreesGetEnsembleStates.pbtxt b/tensorflow/core/api_def/base_api/api_def_BoostedTreesGetEnsembleStates.pbtxt
new file mode 100644 (file)
index 0000000..ef45a92
--- /dev/null
@@ -0,0 +1,35 @@
+op {
+  graph_op_name: "BoostedTreesGetEnsembleStates"
+  visibility: HIDDEN
+  in_arg {
+    name: "tree_ensemble_handle"
+    description: <<END
+Handle to the tree ensemble.
+END
+  }
+  out_arg {
+    name: "stamp_token"
+    description: <<END
+Stamp token of the tree ensemble resource.
+END
+  }
+  out_arg {
+    name: "num_trees"
+    description: <<END
+The number of trees in the tree ensemble resource.
+END
+  }
+  out_arg {
+    name: "num_finalized_trees"
+    description: <<END
+The number of trees that were finalized successfully.
+END
+  }
+  out_arg {
+    name: "num_attempted_layers"
+    description: <<END
+The number of layers we attempted to build (but not necessarily succeeded).
+END
+  }
+  summary: "Retrieves the tree ensemble resource stamp token."
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_BoostedTreesMakeStatsSummary.pbtxt b/tensorflow/core/api_def/base_api/api_def_BoostedTreesMakeStatsSummary.pbtxt
new file mode 100644 (file)
index 0000000..dc0856c
--- /dev/null
@@ -0,0 +1,56 @@
+op {
+  graph_op_name: "BoostedTreesMakeStatsSummary"
+  visibility: HIDDEN
+  in_arg {
+    name: "node_ids"
+    description: <<END
+int32; Rank 1 Tensor containing the node id that each example falls into for the requested layer.
+END
+  }
+  in_arg {
+    name: "gradients"
+    description: <<END
+float32; Rank 2 Tensor (shape=[#examples, 1]) for gradients.
+END
+  }
+  in_arg {
+    name: "hessians"
+    description: <<END
+float32; Rank 2 Tensor (shape=[#examples, 1]) for hessians.
+END
+  }
+  in_arg {
+    name: "bucketized_features_list"
+    description: <<END
+int32 list of Rank 1 Tensors, each containing the bucketized feature (for each feature column).
+END
+  }
+  out_arg {
+    name: "stats_summary"
+    description: <<END
+output Rank 4 Tensor (shape=[#features, #splits, #buckets, 2]) containing accumulated stats put into the corresponding node and bucket. The first index of 4th dimension refers to gradients, and the second to hessians.
+END
+  }
+  attr {
+    name: "max_splits"
+    description: <<END
+int; the maximum number of splits possible in the whole tree.
+END
+  }
+  attr {
+    name: "num_buckets"
+    description: <<END
+int; equal to the maximum possible value of the bucketized feature.
+END
+  }
+  attr {
+    name: "num_features"
+    description: <<END
+int; inferred from the size of bucketized_features_list; the number of features.
+END
+  }
+  summary: "Makes the summary of accumulated stats for the batch."
+  description: <<END
+The summary stats contains gradients and hessians accumulated into the corresponding node and bucket for each example.
+END
+}
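To make the accumulation semantics concrete, here is a small numpy sketch of the stats summary described above (a model of the op's contract, not the kernel itself):

```python
import numpy as np

def make_stats_summary(node_ids, gradients, hessians,
                       bucketized_features_list, max_splits, num_buckets):
  """Accumulates per-example grad/hess into [feature, node, bucket, 2]."""
  num_features = len(bucketized_features_list)
  summary = np.zeros((num_features, max_splits, num_buckets, 2), np.float32)
  for f, buckets in enumerate(bucketized_features_list):
    for i, (node, bucket) in enumerate(zip(node_ids, buckets)):
      summary[f, node, bucket, 0] += gradients[i, 0]  # gradient slot
      summary[f, node, bucket, 1] += hessians[i, 0]   # hessian slot
  return summary
```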
diff --git a/tensorflow/core/api_def/base_api/api_def_BoostedTreesPredict.pbtxt b/tensorflow/core/api_def/base_api/api_def_BoostedTreesPredict.pbtxt
new file mode 100644 (file)
index 0000000..b23e77a
--- /dev/null
@@ -0,0 +1,41 @@
+op {
+  graph_op_name: "BoostedTreesPredict"
+  visibility: HIDDEN
+  in_arg {
+    name: "bucketized_features"
+    description: <<END
+A list of rank 1 Tensors containing the bucket id for each
+feature.
+END
+  }
+  out_arg {
+    name: "logits"
+    description: <<END
+Output rank 2 Tensor containing logits for each example.
+END
+  }
+  attr {
+    name: "num_bucketized_features"
+    description: <<END
+Inferred.
+END
+  }
+  attr {
+    name: "logits_dimension"
+    description: <<END
+scalar, dimension of the logits, to be used for partial logits
+shape.
+END
+  }
+  attr {
+    name: "max_depth"
+    description: <<END
+scalar, max depth of trees. To be used for parallelization costs.
+END
+  }
+  summary: "Runs multiple additive regression ensemble predictors on input instances and"
+  description: <<END
+computes the logits. It is designed to be used during prediction.
+It traverses all the trees and calculates the final score for each instance.
+END
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_BoostedTreesSerializeEnsemble.pbtxt b/tensorflow/core/api_def/base_api/api_def_BoostedTreesSerializeEnsemble.pbtxt
new file mode 100644 (file)
index 0000000..c0b3688
--- /dev/null
@@ -0,0 +1,23 @@
+op {
+  graph_op_name: "BoostedTreesSerializeEnsemble"
+  visibility: HIDDEN
+  in_arg {
+    name: "tree_ensemble_handle"
+    description: <<END
+Handle to the tree ensemble.
+END
+  }
+  out_arg {
+    name: "stamp_token"
+    description: <<END
+Stamp token of the tree ensemble resource.
+END
+  }
+  out_arg {
+    name: "tree_ensemble_serialized"
+    description: <<END
+Serialized proto of the ensemble.
+END
+  }
+  summary: "Serializes the tree ensemble to a proto."
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_BoostedTreesTrainingPredict.pbtxt b/tensorflow/core/api_def/base_api/api_def_BoostedTreesTrainingPredict.pbtxt
new file mode 100644 (file)
index 0000000..7203d3c
--- /dev/null
@@ -0,0 +1,69 @@
+op {
+  graph_op_name: "BoostedTreesTrainingPredict"
+  visibility: HIDDEN
+  in_arg {
+    name: "cached_tree_ids"
+    description: <<END
+Rank 1 Tensor containing cached tree ids, which are the starting
+trees for prediction.
+END
+  }
+  in_arg {
+    name: "cached_node_ids"
+    description: <<END
+Rank 1 Tensor containing cached node ids, which are the starting
+nodes for prediction.
+END
+  }
+  in_arg {
+    name: "bucketized_features"
+    description: <<END
+A list of rank 1 Tensors containing the bucket id for each
+feature.
+END
+  }
+  out_arg {
+    name: "partial_logits"
+    description: <<END
+Rank 2 Tensor containing logits update (with respect to cached
+values stored) for each example.
+END
+  }
+  out_arg {
+    name: "tree_ids"
+    description: <<END
+Rank 1 Tensor containing new tree ids for each example.
+END
+  }
+  out_arg {
+    name: "node_ids"
+    description: <<END
+Rank 1 Tensor containing new node ids in the new tree_ids.
+END
+  }
+  attr {
+    name: "num_bucketized_features"
+    description: <<END
+Inferred.
+END
+  }
+  attr {
+    name: "logits_dimension"
+    description: <<END
+scalar, dimension of the logits, to be used for partial logits
+shape.
+END
+  }
+  attr {
+    name: "max_depth"
+    description: <<END
+scalar, max depth of trees. To be used for parallelization costs.
+END
+  }
+  summary: "Runs multiple additive regression ensemble predictors on input instances and"
+  description: <<END
+computes the update to cached logits. It is designed to be used during training.
+It traverses the trees starting from cached tree id and cached node id and
+calculates the updates to be pushed to the cache.
+END
+}
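A tiny sketch of the caching contract implied above: the op returns only the logits *update* relative to the cached position, so the training-side caller maintains full logits incrementally (the values here are made up):

```python
import numpy as np

cached_logits = np.array([[0.1], [0.3]], dtype=np.float32)    # from last step
partial_logits = np.array([[0.05], [-0.2]], dtype=np.float32)  # op output
logits = cached_logits + partial_logits  # full per-example prediction
# The returned tree_ids/node_ids become the cache for the next step.
```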
diff --git a/tensorflow/core/api_def/base_api/api_def_BoostedTreesUpdateEnsemble.pbtxt b/tensorflow/core/api_def/base_api/api_def_BoostedTreesUpdateEnsemble.pbtxt
new file mode 100644 (file)
index 0000000..00f8953
--- /dev/null
@@ -0,0 +1,82 @@
+op {
+  graph_op_name: "BoostedTreesUpdateEnsemble"
+  visibility: HIDDEN
+  in_arg {
+    name: "tree_ensemble_handle"
+    description: <<END
+Handle to the ensemble variable.
+END
+  }
+  in_arg {
+    name: "feature_ids"
+    description: <<END
+Rank 1 tensor with ids for each feature. This is the real id of
+the feature that will be used in the split.
+END
+  }
+  in_arg {
+    name: "node_ids"
+    description: <<END
+List of rank 1 tensors representing the nodes for which this feature
+has a split.
+END
+  }
+  in_arg {
+    name: "gains"
+    description: <<END
+List of rank 1 tensors representing the gains for each of the feature's
+split.
+END
+  }
+  in_arg {
+    name: "thresholds"
+    description: <<END
+List of rank 1 tensors representing the thresholds for each of the
+feature's split.
+END
+  }
+  in_arg {
+    name: "left_node_contribs"
+    description: <<END
+List of rank 2 tensors with left leaf contribs for each of
+the feature's splits. Will be added to the previous node values to constitute
+the values of the left nodes.
+END
+  }
+  in_arg {
+    name: "right_node_contribs"
+    description: <<END
+List of rank 2 tensors with right leaf contribs for each
+of the feature's splits. Will be added to the previous node values to constitute
+the values of the right nodes.
+END
+  }
+  attr {
+    name: "max_depth"
+    description: <<END
+Max depth of the tree to build.
+END
+  }
+  attr {
+    name: "learning_rate"
+    description: <<END
+Shrinkage constant for each new tree.
+END
+  }
+  attr {
+    name: "pruning_mode"
+    description: <<END
+0-No pruning, 1-Pre-pruning, 2-Post-pruning.
+END
+  }
+  attr {
+    name: "num_features"
+    description: <<END
+Number of features that have best splits returned. INFERRED.
+END
+  }
+  summary: "Updates the tree ensemble by either adding a layer to the last tree being grown"
+  description: <<END
+or by starting a new tree.
+END
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_IsBoostedTreesEnsembleInitialized.pbtxt b/tensorflow/core/api_def/base_api/api_def_IsBoostedTreesEnsembleInitialized.pbtxt
new file mode 100644 (file)
index 0000000..d54b7ef
--- /dev/null
@@ -0,0 +1,17 @@
+op {
+  graph_op_name: "IsBoostedTreesEnsembleInitialized"
+  visibility: HIDDEN
+  in_arg {
+    name: "tree_ensemble_handle"
+    description: <<END
+Handle to the tree ensemble resource.
+END
+  }
+  out_arg {
+    name: "is_initialized"
+    description: <<END
+Output boolean indicating whether the tree ensemble is initialized.
+END
+  }
+  summary: "Checks whether a tree ensemble has been initialized."
+}
index ca54978..d2a2cdd 100644 (file)
@@ -6096,6 +6096,13 @@ cc_library(
     ],
 )
 
+tf_kernel_library(
+    name = "boosted_trees_ops",
+    deps = [
+        "//tensorflow/core/kernels/boosted_trees:boosted_trees_ops",
+    ],
+)
+
 cc_library(
     name = "captured_function",
     hdrs = ["captured_function.h"],
diff --git a/tensorflow/core/kernels/boosted_trees/BUILD b/tensorflow/core/kernels/boosted_trees/BUILD
new file mode 100644 (file)
index 0000000..62327df
--- /dev/null
@@ -0,0 +1,89 @@
+# Description:
+#   OpKernels for boosted trees ops.
+
+package(
+    default_visibility = ["//tensorflow:internal"],
+)
+
+licenses(["notice"])  # Apache 2.0
+
+load("//tensorflow:tensorflow.bzl", "tf_kernel_library")
+load(
+    "//tensorflow/core:platform/default/build_config.bzl",
+    "tf_proto_library",
+)
+
+tf_proto_library(
+    name = "boosted_trees_proto",
+    srcs = ["boosted_trees.proto"],
+    cc_api_version = 2,
+    visibility = ["//visibility:public"],
+)
+
+tf_kernel_library(
+    name = "prediction_ops",
+    srcs = ["prediction_ops.cc"],
+    deps = [
+        ":resource_ops",
+        ":resources",
+        "//tensorflow/core:boosted_trees_ops_op_lib",
+        "//tensorflow/core:framework",
+        "//tensorflow/core:lib",
+        "//tensorflow/core:lib_internal",
+    ],
+)
+
+cc_library(
+    name = "resources",
+    srcs = ["resources.cc"],
+    hdrs = ["resources.h"],
+    deps = [
+        "//tensorflow/core:framework",
+        "//tensorflow/core:lib",
+        "//tensorflow/core/kernels/boosted_trees:boosted_trees_proto_cc",
+    ],
+)
+
+tf_kernel_library(
+    name = "resource_ops",
+    srcs = ["resource_ops.cc"],
+    deps = [
+        ":resources",
+        "//tensorflow/core:boosted_trees_ops_op_lib",
+        "//tensorflow/core:framework",
+        "//tensorflow/core:lib",
+        "//tensorflow/core/kernels/boosted_trees:boosted_trees_proto_cc",
+    ],
+)
+
+tf_kernel_library(
+    name = "stats_ops",
+    srcs = ["stats_ops.cc"],
+    deps = [
+        "//tensorflow/core:boosted_trees_ops_op_lib",
+        "//tensorflow/core:framework",
+        "//tensorflow/core:lib",
+    ],
+)
+
+tf_kernel_library(
+    name = "training_ops",
+    srcs = ["training_ops.cc"],
+    deps = [
+        ":resources",
+        "//tensorflow/core:boosted_trees_ops_op_lib",
+        "//tensorflow/core:framework",
+        "//tensorflow/core:lib",
+        "//tensorflow/core/kernels/boosted_trees:boosted_trees_proto_cc",
+    ],
+)
+
+tf_kernel_library(
+    name = "boosted_trees_ops",
+    deps = [
+        ":prediction_ops",
+        ":resource_ops",
+        ":stats_ops",
+        ":training_ops",
+    ],
+)
diff --git a/tensorflow/core/kernels/boosted_trees/boosted_trees.proto b/tensorflow/core/kernels/boosted_trees/boosted_trees.proto
new file mode 100644 (file)
index 0000000..106ceed
--- /dev/null
@@ -0,0 +1,113 @@
+syntax = "proto3";
+
+package tensorflow.boosted_trees;
+option cc_enable_arenas = true;
+option java_outer_classname = "BoostedTreesProtos";
+option java_multiple_files = true;
+option java_package = "org.tensorflow.framework";
+
+// Node describes a node in a tree.
+message Node {
+  oneof node {
+    Leaf leaf = 1;
+    BucketizedSplit bucketized_split = 2;
+  }
+  NodeMetadata metadata = 777;
+}
+
+// NodeMetadata encodes metadata associated with each node in a tree.
+message NodeMetadata {
+  // The gain associated with this node.
+  float gain = 1;
+
+  // The original leaf node before this node was split.
+  Leaf original_leaf = 2;
+}
+
+// Leaves can either hold dense or sparse information.
+message Leaf {
+  oneof leaf {
+    // See third_party/tensorflow/contrib/decision_trees/
+    // proto/generic_tree_model.proto
+    // for a description of how vector and sparse_vector might be used.
+    Vector vector = 1;
+    SparseVector sparse_vector = 2;
+  }
+  float scalar = 3;
+}
+
+message Vector {
+  repeated float value = 1;
+}
+
+message SparseVector {
+  repeated int32 index = 1;
+  repeated float value = 2;
+}
+
+message BucketizedSplit {
+  // Float feature column and split threshold describing
+  // the rule feature <= threshold.
+  int32 feature_id = 1;
+  int32 threshold = 2;
+
+  // Node children indexing into a contiguous
+  // vector of nodes starting from the root.
+  int32 left_id = 3;
+  int32 right_id = 4;
+}
+
+// Tree describes a list of connected nodes.
+// Node 0 must be the root and can carry any payload including a leaf
+// in the case of representing the bias.
+// Note that each node id is implicitly its index in the list of nodes.
+message Tree {
+  repeated Node nodes = 1;
+}
+
+message TreeMetadata {
+  // Number of layers grown for this tree.
+  int32 num_layers_grown = 2;
+
+  // Whether the tree is finalized in that no more layers can be grown.
+  bool is_finalized = 3;
+
+  // If tree was finalized and post pruning happened, it is possible that cache
+  // still refers to some nodes that were deleted or that the node ids changed
+  // (e.g. node id 5 became node id 2 due to pruning of the other branch).
+  // The mapping below allows us to understand where the old ids now map to and
+  // how the values should be adjusted due to post-pruning.
+  // The size of the list should be equal to the number of nodes in the tree
+  // before post-pruning happened.
+  // If the node was pruned, it will have new_node_id equal to the id of a node
+  // that this node was collapsed into. For a node that didn't get pruned, it is
+  // possible that its id still changed, so new_node_id will have the
+  // corresponding id in the pruned tree.
+  // If post-pruning didn't happen, or it did and it had no effect (e.g. no
+  // nodes got pruned), this list will be empty.
+  repeated PostPruneNodeUpdate post_pruned_nodes_meta = 4;
+
+  message PostPruneNodeUpdate {
+    int32 new_node_id = 1;
+    float logit_change = 2;
+  }
+}
+
+message GrowingMetadata {
+  // Number of trees that we have attempted to build. After pruning, these
+  // trees might have been removed.
+  int64 num_trees_attempted = 1;
+  // Number of layers that we have attempted to build. After pruning, these
+  // layers might have been removed.
+  int64 num_layers_attempted = 2;
+}
+
+// TreeEnsemble describes an ensemble of decision trees.
+message TreeEnsemble {
+  repeated Tree trees = 1;
+  repeated float tree_weights = 2;
+
+  repeated TreeMetadata tree_metadata = 3;
+  // Metadata that is used during the training.
+  GrowingMetadata growing_metadata = 4;
+}
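For orientation, a small sketch of building a one-tree ensemble with this schema. It assumes the generated Python module for boosted_trees.proto is importable under the path below, which is an assumption about the build, not something this change adds:

```python
# Hypothetical import path for the generated proto module.
from tensorflow.core.kernels.boosted_trees import boosted_trees_pb2

ensemble = boosted_trees_pb2.TreeEnsemble()
tree = ensemble.trees.add()
root = tree.nodes.add()      # node 0 must be the root
root.leaf.scalar = 0.5       # a bias-only tree: the root carries a leaf
ensemble.tree_weights.append(1.0)
meta = ensemble.tree_metadata.add()
meta.num_layers_grown = 0
meta.is_finalized = True
serialized = ensemble.SerializeToString()  # feeds BoostedTreesCreateEnsemble
```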
diff --git a/tensorflow/core/kernels/boosted_trees/prediction_ops.cc b/tensorflow/core/kernels/boosted_trees/prediction_ops.cc
new file mode 100644 (file)
index 0000000..b13a450
--- /dev/null
@@ -0,0 +1,263 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <algorithm>
+#include <string>
+#include <vector>
+
+#include "tensorflow/core/framework/device_base.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/resource_mgr.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/tensor_shape.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/kernels/boosted_trees/resources.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/core/refcount.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/lib/core/threadpool.h"
+#include "tensorflow/core/platform/mutex.h"
+#include "tensorflow/core/platform/protobuf.h"
+#include "tensorflow/core/platform/types.h"
+#include "tensorflow/core/util/work_sharder.h"
+
+namespace tensorflow {
+
+// The Op used during training time to get the predictions so far with the
+// current ensemble being built.
+// Some logits are expected to be cached from the previous step and passed
+// through to be reused.
+class BoostedTreesTrainingPredictOp : public OpKernel {
+ public:
+  explicit BoostedTreesTrainingPredictOp(OpKernelConstruction* const context)
+      : OpKernel(context) {
+    OP_REQUIRES_OK(context, context->GetAttr("num_bucketized_features",
+                                             &num_bucketized_features_));
+    OP_REQUIRES_OK(context,
+                   context->GetAttr("logits_dimension", &logits_dimension_));
+    OP_REQUIRES(context, logits_dimension_ == 1,
+                errors::InvalidArgument(
+                    "Currently only one dimensional outputs are supported."));
+    OP_REQUIRES_OK(context, context->GetAttr("max_depth", &max_depth_));
+  }
+
+  void Compute(OpKernelContext* const context) override {
+    BoostedTreesEnsembleResource* resource;
+    // Get the resource.
+    OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
+                                           &resource));
+    // Release the reference to the resource once we're done using it.
+    core::ScopedUnref unref_me(resource);
+
+    // Get the inputs.
+    OpInputList bucketized_features_list;
+    OP_REQUIRES_OK(context, context->input_list("bucketized_features",
+                                                &bucketized_features_list));
+    std::vector<tensorflow::TTypes<int32>::ConstVec> batch_bucketized_features;
+    batch_bucketized_features.reserve(bucketized_features_list.size());
+    for (const Tensor& tensor : bucketized_features_list) {
+      batch_bucketized_features.emplace_back(tensor.vec<int32>());
+    }
+    const int batch_size = batch_bucketized_features[0].size();
+
+    const Tensor* cached_tree_ids_t;
+    OP_REQUIRES_OK(context,
+                   context->input("cached_tree_ids", &cached_tree_ids_t));
+    const auto cached_tree_ids = cached_tree_ids_t->vec<int32>();
+
+    const Tensor* cached_node_ids_t;
+    OP_REQUIRES_OK(context,
+                   context->input("cached_node_ids", &cached_node_ids_t));
+    const auto cached_node_ids = cached_node_ids_t->vec<int32>();
+
+    // Allocate outputs.
+    Tensor* output_partial_logits_t = nullptr;
+    OP_REQUIRES_OK(context,
+                   context->allocate_output("partial_logits",
+                                            {batch_size, logits_dimension_},
+                                            &output_partial_logits_t));
+    auto output_partial_logits = output_partial_logits_t->matrix<float>();
+
+    Tensor* output_tree_ids_t = nullptr;
+    OP_REQUIRES_OK(context, context->allocate_output("tree_ids", {batch_size},
+                                                     &output_tree_ids_t));
+    auto output_tree_ids = output_tree_ids_t->vec<int32>();
+
+    Tensor* output_node_ids_t = nullptr;
+    OP_REQUIRES_OK(context, context->allocate_output("node_ids", {batch_size},
+                                                     &output_node_ids_t));
+    auto output_node_ids = output_node_ids_t->vec<int32>();
+
+    // Index of the latest tree in the ensemble (-1 if the ensemble is empty).
+    const int32 latest_tree = resource->num_trees() - 1;
+
+    if (latest_tree < 0) {
+      // Ensemble was empty. Nothing changes.
+      output_node_ids = cached_node_ids;
+      output_tree_ids = cached_tree_ids;
+      // All the predictions are zeros.
+      output_partial_logits.setZero();
+    } else {
+      output_tree_ids.setConstant(latest_tree);
+      auto do_work = [&resource, &batch_bucketized_features, &cached_tree_ids,
+                      &cached_node_ids, &output_partial_logits,
+                      &output_node_ids, batch_size,
+                      latest_tree](int32 start, int32 end) {
+        for (int32 i = start; i < end; ++i) {
+          int32 tree_id = cached_tree_ids(i);
+          int32 node_id = cached_node_ids(i);
+          float partial_tree_logit = 0.0;
+
+          // If the tree was pruned, this sets node_id to the node into which
+          // the cached node was collapsed, as well as the correction to the
+          // cached logit prediction.
+          resource->GetPostPruneCorrection(tree_id, node_id, &node_id,
+                                           &partial_tree_logit);
+
+          // The traversal loop below adds the node value back once it
+          // reaches a leaf. If the cached node is no longer a leaf, its old
+          // value must not be counted, so subtracting it up front handles
+          // both cases.
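+          // E.g. if the cached node is still a leaf with value v, the net
+          // tree contribution is -v + v = 0 (plus any pruning correction);
+          // if it was split since, -v cancels the stale cached value and the
+          // newly reached leaf's value is added instead.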
+          partial_tree_logit -= resource->node_value(tree_id, node_id);
+          float partial_all_logit = 0.0;
+          while (true) {
+            if (resource->is_leaf(tree_id, node_id)) {
+              partial_tree_logit += resource->node_value(tree_id, node_id);
+
+              // Tree is done
+              partial_all_logit +=
+                  resource->GetTreeWeight(tree_id) * partial_tree_logit;
+              partial_tree_logit = 0.0;
+              // Stop if it was the latest tree.
+              if (tree_id == latest_tree) {
+                break;
+              }
+              // Move onto other trees.
+              ++tree_id;
+              node_id = 0;
+            } else {
+              node_id = resource->next_node(tree_id, node_id, i,
+                                            batch_bucketized_features);
+            }
+          }
+          output_node_ids(i) = node_id;
+          output_partial_logits(i, 0) = partial_all_logit;
+        }
+      };
+      // Assume each example will not traverse more than about one full
+      // tree's depth; 4 is a magic (heuristic) multiplier.
+      const int64 cost = 4 * max_depth_;
+      thread::ThreadPool* const worker_threads =
+          context->device()->tensorflow_cpu_worker_threads()->workers;
+      Shard(worker_threads->NumThreads(), worker_threads, batch_size,
+            /*cost_per_unit=*/cost, do_work);
+    }
+  }
+
+ private:
+  int32 logits_dimension_;         // the size of the output prediction vector.
+  int32 num_bucketized_features_;  // Indicates the number of features.
+  int32 max_depth_;
+};
+
+REGISTER_KERNEL_BUILDER(Name("BoostedTreesTrainingPredict").Device(DEVICE_CPU),
+                        BoostedTreesTrainingPredictOp);
+
+// The Op to get the predictions at evaluation/inference time.
+class BoostedTreesPredictOp : public OpKernel {
+ public:
+  explicit BoostedTreesPredictOp(OpKernelConstruction* const context)
+      : OpKernel(context) {
+    OP_REQUIRES_OK(context, context->GetAttr("num_bucketized_features",
+                                             &num_bucketized_features_));
+    OP_REQUIRES_OK(context,
+                   context->GetAttr("logits_dimension", &logits_dimension_));
+    OP_REQUIRES(context, logits_dimension_ == 1,
+                errors::InvalidArgument(
+                    "Currently only one dimensional outputs are supported."));
+    OP_REQUIRES_OK(context, context->GetAttr("max_depth", &max_depth_));
+  }
+
+  void Compute(OpKernelContext* const context) override {
+    BoostedTreesEnsembleResource* resource;
+    // Get the resource.
+    OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
+                                           &resource));
+    // Release the reference to the resource once we're done using it.
+    core::ScopedUnref unref_me(resource);
+
+    // Get the inputs.
+    OpInputList bucketized_features_list;
+    OP_REQUIRES_OK(context, context->input_list("bucketized_features",
+                                                &bucketized_features_list));
+    std::vector<tensorflow::TTypes<int32>::ConstVec> batch_bucketized_features;
+    batch_bucketized_features.reserve(bucketized_features_list.size());
+    for (const Tensor& tensor : bucketized_features_list) {
+      batch_bucketized_features.emplace_back(tensor.vec<int32>());
+    }
+    const int batch_size = batch_bucketized_features[0].size();
+
+    // Allocate outputs.
+    Tensor* output_logits_t = nullptr;
+    OP_REQUIRES_OK(context, context->allocate_output(
+                                "logits", {batch_size, logits_dimension_},
+                                &output_logits_t));
+    auto output_logits = output_logits_t->matrix<float>();
+
+    const int32 latest_tree = resource->num_trees() - 1;
+
+    auto do_work = [&resource, &batch_bucketized_features, &output_logits,
+                    batch_size, latest_tree](int32 start, int32 end) {
+      for (int32 i = start; i < end; ++i) {
+        float tree_logit = 0.0;
+        int32 tree_id = 0;
+        int32 node_id = 0;
+        while (true) {
+          if (resource->is_leaf(tree_id, node_id)) {
+            tree_logit += resource->GetTreeWeight(tree_id) *
+                          resource->node_value(tree_id, node_id);
+
+            // Stop if it was the latest tree.
+            if (tree_id == latest_tree) {
+              break;
+            }
+            // Move onto other trees.
+            ++tree_id;
+            node_id = 0;
+          } else {
+            node_id = resource->next_node(tree_id, node_id, i,
+                                          batch_bucketized_features);
+          }
+        }
+        output_logits(i, 0) = tree_logit;
+      }
+    };
+    const int64 cost = (latest_tree + 1) * max_depth_;
+    thread::ThreadPool* const worker_threads =
+        context->device()->tensorflow_cpu_worker_threads()->workers;
+    Shard(worker_threads->NumThreads(), worker_threads, batch_size,
+          /*cost_per_unit=*/cost, do_work);
+  }
+
+ private:
+  int32 logits_dimension_;  // The size of the output prediction vector.
+  int32 num_bucketized_features_;  // Indicates the number of features.
+  int32 max_depth_;
+};
+
+REGISTER_KERNEL_BUILDER(Name("BoostedTreesPredict").Device(DEVICE_CPU),
+                        BoostedTreesPredictOp);
+
+}  // namespace tensorflow
diff --git a/tensorflow/core/kernels/boosted_trees/resource_ops.cc b/tensorflow/core/kernels/boosted_trees/resource_ops.cc
new file mode 100644 (file)
index 0000000..f49242d
--- /dev/null
@@ -0,0 +1,189 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <memory>
+#include <string>
+
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/resource_mgr.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/tensor_shape.h"
+#include "tensorflow/core/framework/tensor_types.h"
+#include "tensorflow/core/kernels/boosted_trees/resources.h"
+
+namespace tensorflow {
+
+REGISTER_RESOURCE_HANDLE_KERNEL(BoostedTreesEnsembleResource);
+
+REGISTER_KERNEL_BUILDER(
+    Name("IsBoostedTreesEnsembleInitialized").Device(DEVICE_CPU),
+    IsResourceInitialized<BoostedTreesEnsembleResource>);
+
+// Creates a tree ensemble resource.
+class BoostedTreesCreateEnsembleOp : public OpKernel {
+ public:
+  explicit BoostedTreesCreateEnsembleOp(OpKernelConstruction* context)
+      : OpKernel(context) {}
+
+  void Compute(OpKernelContext* context) override {
+    // Get the stamp token.
+    const Tensor* stamp_token_t;
+    OP_REQUIRES_OK(context, context->input("stamp_token", &stamp_token_t));
+    int64 stamp_token = stamp_token_t->scalar<int64>()();
+
+    // Get the tree ensemble proto.
+    const Tensor* tree_ensemble_serialized_t;
+    OP_REQUIRES_OK(context, context->input("tree_ensemble_serialized",
+                                           &tree_ensemble_serialized_t));
+    std::unique_ptr<BoostedTreesEnsembleResource> result(
+        new BoostedTreesEnsembleResource());
+    if (!result->InitFromSerialized(
+            tree_ensemble_serialized_t->scalar<string>()(), stamp_token)) {
+      result->Unref();
+      OP_REQUIRES(
+          context, false,
+          errors::InvalidArgument("Unable to parse tree ensemble proto."));
+    }
+
+    // Only create the resource if it does not already exist. Report the
+    // status for all other errors.
+    auto status =
+        CreateResource(context, HandleFromInput(context, 0), result.release());
+    if (status.code() != tensorflow::error::ALREADY_EXISTS) {
+      OP_REQUIRES_OK(context, status);
+    }
+  }
+};
+
+REGISTER_KERNEL_BUILDER(Name("BoostedTreesCreateEnsemble").Device(DEVICE_CPU),
+                        BoostedTreesCreateEnsembleOp);
+
+// Op for retrieving some model states (needed for training).
+class BoostedTreesGetEnsembleStatesOp : public OpKernel {
+ public:
+  explicit BoostedTreesGetEnsembleStatesOp(OpKernelConstruction* context)
+      : OpKernel(context) {}
+
+  void Compute(OpKernelContext* context) override {
+    // Looks up the resource.
+    BoostedTreesEnsembleResource* tree_ensemble_resource;
+    OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
+                                           &tree_ensemble_resource));
+    tf_shared_lock l(*tree_ensemble_resource->get_mutex());
+    core::ScopedUnref unref_me(tree_ensemble_resource);
+
+    // Sets the outputs.
+    const int num_trees = tree_ensemble_resource->num_trees();
+    const int num_finalized_trees =
+        (num_trees <= 0 ||
+         tree_ensemble_resource->IsTreeFinalized(num_trees - 1))
+            ? num_trees
+            : num_trees - 1;
+    const int num_attempted_layers =
+        tree_ensemble_resource->GetNumLayersAttempted();
+
+    // Allocate the scalar output tensors.
+    Tensor* output_stamp_token_t = nullptr;
+    Tensor* output_num_trees_t = nullptr;
+    Tensor* output_num_finalized_trees_t = nullptr;
+    Tensor* output_num_attempted_layers_t = nullptr;
+
+    OP_REQUIRES_OK(context, context->allocate_output(0, TensorShape(),
+                                                     &output_stamp_token_t));
+    OP_REQUIRES_OK(context, context->allocate_output(1, TensorShape(),
+                                                     &output_num_trees_t));
+    OP_REQUIRES_OK(context,
+                   context->allocate_output(2, TensorShape(),
+                                            &output_num_finalized_trees_t));
+    OP_REQUIRES_OK(context,
+                   context->allocate_output(3, TensorShape(),
+                                            &output_num_attempted_layers_t));
+
+    output_stamp_token_t->scalar<int64>()() = tree_ensemble_resource->stamp();
+    output_num_trees_t->scalar<int32>()() = num_trees;
+    output_num_finalized_trees_t->scalar<int32>()() = num_finalized_trees;
+    output_num_attempted_layers_t->scalar<int32>()() = num_attempted_layers;
+  }
+};
+
+REGISTER_KERNEL_BUILDER(
+    Name("BoostedTreesGetEnsembleStates").Device(DEVICE_CPU),
+    BoostedTreesGetEnsembleStatesOp);
+
+// Op for serializing a model.
+class BoostedTreesSerializeEnsembleOp : public OpKernel {
+ public:
+  explicit BoostedTreesSerializeEnsembleOp(OpKernelConstruction* context)
+      : OpKernel(context) {}
+
+  void Compute(OpKernelContext* context) override {
+    BoostedTreesEnsembleResource* tree_ensemble_resource;
+    OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
+                                           &tree_ensemble_resource));
+    tf_shared_lock l(*tree_ensemble_resource->get_mutex());
+    core::ScopedUnref unref_me(tree_ensemble_resource);
+    Tensor* output_stamp_token_t = nullptr;
+    OP_REQUIRES_OK(context, context->allocate_output(0, TensorShape(),
+                                                     &output_stamp_token_t));
+    output_stamp_token_t->scalar<int64>()() = tree_ensemble_resource->stamp();
+    Tensor* output_proto_t = nullptr;
+    OP_REQUIRES_OK(context,
+                   context->allocate_output(1, TensorShape(), &output_proto_t));
+    output_proto_t->scalar<string>()() =
+        tree_ensemble_resource->SerializeAsString();
+  }
+};
+
+REGISTER_KERNEL_BUILDER(
+    Name("BoostedTreesSerializeEnsemble").Device(DEVICE_CPU),
+    BoostedTreesSerializeEnsembleOp);
+
+// Op for deserializing a tree ensemble variable from a checkpoint.
+class BoostedTreesDeserializeEnsembleOp : public OpKernel {
+ public:
+  explicit BoostedTreesDeserializeEnsembleOp(OpKernelConstruction* context)
+      : OpKernel(context) {}
+
+  void Compute(OpKernelContext* context) override {
+    BoostedTreesEnsembleResource* tree_ensemble_resource;
+    OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
+                                           &tree_ensemble_resource));
+    mutex_lock l(*tree_ensemble_resource->get_mutex());
+    core::ScopedUnref unref_me(tree_ensemble_resource);
+
+    // Get the stamp token.
+    const Tensor* stamp_token_t;
+    OP_REQUIRES_OK(context, context->input("stamp_token", &stamp_token_t));
+    int64 stamp_token = stamp_token_t->scalar<int64>()();
+
+    // Get the tree ensemble proto.
+    const Tensor* tree_ensemble_serialized_t;
+    OP_REQUIRES_OK(context, context->input("tree_ensemble_serialized",
+                                           &tree_ensemble_serialized_t));
+    // Deallocate all the previous objects on the resource.
+    tree_ensemble_resource->Reset();
+    OP_REQUIRES(
+        context,
+        tree_ensemble_resource->InitFromSerialized(
+            tree_ensemble_serialized_t->scalar<string>()(), stamp_token),
+        errors::InvalidArgument("Unable to parse tree ensemble proto."));
+  }
+};
+
+REGISTER_KERNEL_BUILDER(
+    Name("BoostedTreesDeserializeEnsemble").Device(DEVICE_CPU),
+    BoostedTreesDeserializeEnsembleOp);
+
+}  // namespace tensorflow
diff --git a/tensorflow/core/kernels/boosted_trees/resources.cc b/tensorflow/core/kernels/boosted_trees/resources.cc
new file mode 100644 (file)
index 0000000..2ea12c5
--- /dev/null
@@ -0,0 +1,301 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/kernels/boosted_trees/resources.h"
+#include "tensorflow/core/framework/resource_mgr.h"
+#include "tensorflow/core/kernels/boosted_trees/boosted_trees.pb.h"
+#include "tensorflow/core/platform/mutex.h"
+#include "tensorflow/core/platform/protobuf.h"
+
+namespace tensorflow {
+
+int32 BoostedTreesEnsembleResource::next_node(
+    const int32 tree_id, const int32 node_id, const int32 index_in_batch,
+    const std::vector<TTypes<int32>::ConstVec>& bucketized_features) const {
+  DCHECK_LT(tree_id, tree_ensemble_->trees_size());
+  DCHECK_LT(node_id, tree_ensemble_->trees(tree_id).nodes_size());
+  const auto& node = tree_ensemble_->trees(tree_id).nodes(node_id);
+  DCHECK_EQ(node.node_case(), boosted_trees::Node::kBucketizedSplit);
+  const auto& split = node.bucketized_split();
+  if (bucketized_features[split.feature_id()](index_in_batch) <=
+      split.threshold()) {
+    return split.left_id();
+  } else {
+    return split.right_id();
+  }
+}
+
+float BoostedTreesEnsembleResource::node_value(const int32 tree_id,
+                                               const int32 node_id) const {
+  DCHECK_LT(tree_id, tree_ensemble_->trees_size());
+  DCHECK_LT(node_id, tree_ensemble_->trees(tree_id).nodes_size());
+  const auto& node = tree_ensemble_->trees(tree_id).nodes(node_id);
+  if (node.node_case() == boosted_trees::Node::kLeaf) {
+    return node.leaf().scalar();
+  } else {
+    return node.metadata().original_leaf().scalar();
+  }
+}
+
+void BoostedTreesEnsembleResource::UpdateGrowingMetadata() const {
+  tree_ensemble_->mutable_growing_metadata()->set_num_layers_attempted(
+      tree_ensemble_->growing_metadata().num_layers_attempted() + 1);
+
+  const int n_trees = num_trees();
+
+  if (n_trees <= 0 ||
+      // Checks if we are building the first layer of the dummy empty tree
+      ((n_trees == 1 || IsTreeFinalized(n_trees - 2)) &&
+       (tree_ensemble_->trees(n_trees - 1).nodes_size() == 1))) {
+    tree_ensemble_->mutable_growing_metadata()->set_num_trees_attempted(
+        tree_ensemble_->growing_metadata().num_trees_attempted() + 1);
+  }
+}
+
+// Adds a new tree to the ensemble and returns its tree_id.
+int32 BoostedTreesEnsembleResource::AddNewTree(const float weight) {
+  const int32 new_tree_id = tree_ensemble_->trees_size();
+  auto* node = tree_ensemble_->add_trees()->add_nodes();
+  node->mutable_leaf()->set_scalar(0.0);
+  tree_ensemble_->add_tree_weights(weight);
+  tree_ensemble_->add_tree_metadata();
+
+  return new_tree_id;
+}
+
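+// Turns the leaf at node_id into a bucketized split and appends the two new
+// leaf children at the end of the node list. E.g. splitting the root of a
+// fresh one-node tree produces nodes [split, left leaf, right leaf] with ids
+// 0, 1 and 2.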
+void BoostedTreesEnsembleResource::AddBucketizedSplitNode(
+    const int32 tree_id, const int32 node_id, const int32 feature_id,
+    const int32 threshold, const float gain, const float left_contrib,
+    const float right_contrib, int32* left_node_id, int32* right_node_id) {
+  auto* tree = tree_ensemble_->mutable_trees(tree_id);
+  auto* node = tree->mutable_nodes(node_id);
+  DCHECK_EQ(node->node_case(), boosted_trees::Node::kLeaf);
+  float prev_node_value = node->leaf().scalar();
+  *left_node_id = tree->nodes_size();
+  *right_node_id = *left_node_id + 1;
+  auto* left_node = tree->add_nodes();
+  auto* right_node = tree->add_nodes();
+  if (node_id != 0) {
+    // Save previous leaf value if it is not the first leaf in the tree.
+    node->mutable_metadata()->mutable_original_leaf()->Swap(
+        node->mutable_leaf());
+  }
+  node->mutable_metadata()->set_gain(gain);
+  auto* new_split = node->mutable_bucketized_split();
+  new_split->set_feature_id(feature_id);
+  new_split->set_threshold(threshold);
+  new_split->set_left_id(*left_node_id);
+  new_split->set_right_id(*right_node_id);
+  // TODO(npononareva): this is LAYER-BY-LAYER boosting; add WHOLE-TREE.
+  left_node->mutable_leaf()->set_scalar(prev_node_value + left_contrib);
+  right_node->mutable_leaf()->set_scalar(prev_node_value + right_contrib);
+}
+
+void BoostedTreesEnsembleResource::Reset() {
+  // Reset stamp.
+  set_stamp(-1);
+
+  // Clear the tree ensemble.
+  arena_.Reset();
+  CHECK_EQ(0, arena_.SpaceAllocated());
+  tree_ensemble_ =
+      protobuf::Arena::CreateMessage<boosted_trees::TreeEnsemble>(&arena_);
+}
+
+void BoostedTreesEnsembleResource::PostPruneTree(const int32 current_tree) {
+  // No-op if tree is empty.
+  auto* tree = tree_ensemble_->mutable_trees(current_tree);
+  int32 num_nodes = tree->nodes_size();
+  if (num_nodes == 0) {
+    return;
+  }
+
+  std::vector<int32> nodes_to_delete;
+  // If a node was pruned, we need to save the change of the prediction from
+  // this node to its parent, as well as the parent id.
+  std::vector<std::pair<int32, float>> nodes_changes;
+  nodes_changes.reserve(num_nodes);
+  for (int32 i = 0; i < num_nodes; ++i) {
+    nodes_changes.emplace_back(i, 0.0);
+  }
+  // Traverse the tree recursively starting from the root. Each split node
+  // that has negative gain and, after pruning below it, only leaf children
+  // is collapsed back into a leaf, bottom-up. This fills in the list of node
+  // ids to prune and, for each pruned node, its parent id and the logit
+  // change relative to that parent.
+  RecursivelyDoPostPrunePreparation(current_tree, 0, &nodes_to_delete,
+                                    &nodes_changes);
+
+  if (nodes_to_delete.empty()) {
+    // No pruning happened, and no post-processing needed.
+    return;
+  }
+
+  // Sort node ids so they are in asc order.
+  std::sort(nodes_to_delete.begin(), nodes_to_delete.end());
+
+  // We need to
+  // - update split left and right children ids with the new indices,
+  // - actually remove the nodes that need to be removed,
+  // - save the information about pruned nodes so we can recover the
+  //   predictions from the cache.
+  // Build a map from old node index to new node index. nodes_to_delete
+  // contains the nodes whose indices should be skipped, in ascending order.
+  // Save the information about the new indices into the tree metadata.
+  std::map<int32, int32> old_to_new_ids;
+  int32 new_index = 0;
+  int32 index_for_deleted = 0;
+  auto* post_prune_meta = tree_ensemble_->mutable_tree_metadata(current_tree)
+                              ->mutable_post_pruned_nodes_meta();
+
+  for (int32 i = 0; i < num_nodes; ++i) {
+    if (index_for_deleted < nodes_to_delete.size() &&
+        i == nodes_to_delete[index_for_deleted]) {
+      // Node i will get removed.
+      ++index_for_deleted;
+      // Update meta info that will allow us to use cached predictions from
+      // those nodes.
+      int32 new_id;
+      float logit_change;
+      CalculateParentAndLogitUpdate(i, nodes_changes, &new_id, &logit_change);
+      auto* meta = post_prune_meta->Add();
+      meta->set_new_node_id(old_to_new_ids[new_id]);
+      meta->set_logit_change(logit_change);
+    } else {
+      old_to_new_ids[i] = new_index++;
+      auto* meta = post_prune_meta->Add();
+      // Update meta info that will allow us to use cached predictions from
+      // those nodes.
+      meta->set_new_node_id(old_to_new_ids[i]);
+      meta->set_logit_change(0.0);
+    }
+  }
+  index_for_deleted = 0;
+  int32 i = 0;
+  protobuf::RepeatedPtrField<boosted_trees::Node> new_nodes;
+  new_nodes.Reserve(old_to_new_ids.size());
+  for (auto node : *(tree->mutable_nodes())) {
+    if (index_for_deleted < nodes_to_delete.size() &&
+        i == nodes_to_delete[index_for_deleted]) {
+      ++index_for_deleted;
+      ++i;
+      continue;
+    } else {
+      if (node.node_case() == boosted_trees::Node::kBucketizedSplit) {
+        node.mutable_bucketized_split()->set_left_id(
+            old_to_new_ids[node.bucketized_split().left_id()]);
+        node.mutable_bucketized_split()->set_right_id(
+            old_to_new_ids[node.bucketized_split().right_id()]);
+      }
+      *new_nodes.Add() = std::move(node);
+    }
+    ++i;
+  }
+  // Replace all the nodes in a tree with the ones we keep.
+  *tree->mutable_nodes() = std::move(new_nodes);
+
+  // Note that if the whole tree got pruned, we will end up with one node.
+  // We can't remove that tree because it will cause problems with cache.
+}
+
+void BoostedTreesEnsembleResource::GetPostPruneCorrection(
+    const int32 tree_id, const int32 initial_node_id, int32* current_node_id,
+    float* logit_update) const {
+  DCHECK_LT(tree_id, tree_ensemble_->trees_size());
+  if (IsTreeFinalized(tree_id) && IsTreePostPruned(tree_id)) {
+    DCHECK_LT(
+        initial_node_id,
+        tree_ensemble_->tree_metadata(tree_id).post_pruned_nodes_meta_size());
+    const auto& meta =
+        tree_ensemble_->tree_metadata(tree_id).post_pruned_nodes_meta(
+            initial_node_id);
+    *current_node_id = meta.new_node_id();
+    *logit_update += meta.logit_change();
+  }
+}
+
+bool BoostedTreesEnsembleResource::IsTerminalSplitNode(
+    const int32 tree_id, const int32 node_id) const {
+  const auto& node = tree_ensemble_->trees(tree_id).nodes(node_id);
+  DCHECK_EQ(node.node_case(), boosted_trees::Node::kBucketizedSplit);
+  const int32 left_id = node.bucketized_split().left_id();
+  const int32 right_id = node.bucketized_split().right_id();
+  return is_leaf(tree_id, left_id) && is_leaf(tree_id, right_id);
+}
+
+// For each pruned node, finds the leaf where it finally ended up and
+// calculates the total update from that pruned node prediction.
+void BoostedTreesEnsembleResource::CalculateParentAndLogitUpdate(
+    const int32 start_node_id,
+    const std::vector<std::pair<int32, float>>& nodes_change, int32* parent_id,
+    float* change) const {
+  *change = 0.0;
+  int32 node_id = start_node_id;
+  int32 parent = nodes_change[node_id].first;
+
+  while (parent != node_id) {
+    (*change) += nodes_change[node_id].second;
+    node_id = parent;
+    parent = nodes_change[node_id].first;
+  }
+  *parent_id = parent;
+}
+
+void BoostedTreesEnsembleResource::RecursivelyDoPostPrunePreparation(
+    const int32 tree_id, const int32 node_id,
+    std::vector<int32>* nodes_to_delete,
+    std::vector<std::pair<int32, float>>* nodes_meta) {
+  auto* node = tree_ensemble_->mutable_trees(tree_id)->mutable_nodes(node_id);
+  DCHECK_NE(node->node_case(), boosted_trees::Node::NODE_NOT_SET);
+  // Base case when we reach a leaf.
+  if (node->node_case() == boosted_trees::Node::kLeaf) {
+    return;
+  }
+
+  // Traverse node children first and recursively prune their sub-trees.
+  RecursivelyDoPostPrunePreparation(tree_id, node->bucketized_split().left_id(),
+                                    nodes_to_delete, nodes_meta);
+  RecursivelyDoPostPrunePreparation(tree_id,
+                                    node->bucketized_split().right_id(),
+                                    nodes_to_delete, nodes_meta);
+
+  // Two conditions must be satisfied to prune the node:
+  // 1- The split gain is negative.
+  // 2- After depth-first pruning, the node only has leaf children.
+  const auto& node_metadata = node->metadata();
+  if (node_metadata.gain() < 0 && IsTerminalSplitNode(tree_id, node_id)) {
+    const int32 left_id = node->bucketized_split().left_id();
+    const int32 right_id = node->bucketized_split().right_id();
+
+    // Save children that need to be deleted.
+    nodes_to_delete->push_back(left_id);
+    nodes_to_delete->push_back(right_id);
+
+    // Change node back into leaf.
+    *node->mutable_leaf() = node_metadata.original_leaf();
+    const float parent_value = node_value(tree_id, node_id);
+
+    // Save the old values of weights of children.
+    (*nodes_meta)[left_id].first = node_id;
+    (*nodes_meta)[left_id].second = parent_value - node_value(tree_id, left_id);
+
+    (*nodes_meta)[right_id].first = node_id;
+    (*nodes_meta)[right_id].second =
+        parent_value - node_value(tree_id, right_id);
+
+    // Clear gain for leaf node.
+    node->clear_metadata();
+  }
+}
+
+}  // namespace tensorflow
diff --git a/tensorflow/core/kernels/boosted_trees/resources.h b/tensorflow/core/kernels/boosted_trees/resources.h
new file mode 100644 (file)
index 0000000..c82588b
--- /dev/null
@@ -0,0 +1,221 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_KERNELS_BOOSTED_TREES_RESOURCES_H_
+#define TENSORFLOW_CORE_KERNELS_BOOSTED_TREES_RESOURCES_H_
+
+#include "tensorflow/core/framework/resource_mgr.h"
+#include "tensorflow/core/kernels/boosted_trees/boosted_trees.pb.h"
+#include "tensorflow/core/platform/mutex.h"
+#include "tensorflow/core/platform/protobuf.h"
+
+namespace tensorflow {
+
+// A StampedResource is a resource that has a stamp token associated with it.
+// Before reading from or applying updates to the resource, the stamp should
+// be checked to verify that the update is not stale.
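+//
+// E.g. BoostedTreesUpdateEnsembleOp increments the stamp on every update, so
+// a reader that still holds the old stamp token can detect via
+// is_stamp_valid() that its cached state is stale.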
+class StampedResource : public ResourceBase {
+ public:
+  StampedResource() : stamp_(-1) {}
+
+  bool is_stamp_valid(int64 stamp) const { return stamp_ == stamp; }
+
+  int64 stamp() const { return stamp_; }
+  void set_stamp(int64 stamp) { stamp_ = stamp; }
+
+ private:
+  int64 stamp_;
+};
+
+// Keep a tree ensemble in memory for efficient evaluation and mutation.
+class BoostedTreesEnsembleResource : public StampedResource {
+ public:
+  // Constructor.
+  BoostedTreesEnsembleResource()
+      : tree_ensemble_(
+            protobuf::Arena::CreateMessage<boosted_trees::TreeEnsemble>(
+                &arena_)) {}
+
+  string DebugString() override {
+    return strings::StrCat("TreeEnsemble[size=", tree_ensemble_->trees_size(),
+                           "]");
+  }
+
+  bool InitFromSerialized(const string& serialized, const int64 stamp_token) {
+    CHECK_EQ(stamp(), -1) << "Must Reset before Init.";
+    if (ParseProtoUnlimited(tree_ensemble_, serialized)) {
+      set_stamp(stamp_token);
+      return true;
+    }
+    return false;
+  }
+
+  string SerializeAsString() const {
+    return tree_ensemble_->SerializeAsString();
+  }
+
+  int32 num_trees() const { return tree_ensemble_->trees_size(); }
+
+  // Find the next node to which the example (specified by index_in_batch)
+  // traverses down from the current node indicated by tree_id and node_id.
+  // Args:
+  //   tree_id: the index of the tree in the ensemble.
+  //   node_id: the index of the node within the tree.
+  //   index_in_batch: the index of the example within the batch (relevant to
+  //       the index of the row to read in each bucketized_features).
+  //   bucketized_features: one vector of bucketized values per feature.
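+  // The example goes to the left child iff its bucketized value for the
+  // split feature is <= the node's threshold.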
+  int32 next_node(
+      const int32 tree_id, const int32 node_id, const int32 index_in_batch,
+      const std::vector<TTypes<int32>::ConstVec>& bucketized_features) const;
+
+  float node_value(const int32 tree_id, const int32 node_id) const;
+
+  int32 GetNumLayersGrown(const int32 tree_id) const {
+    DCHECK_LT(tree_id, tree_ensemble_->trees_size());
+    return tree_ensemble_->tree_metadata(tree_id).num_layers_grown();
+  }
+
+  void SetNumLayersGrown(const int32 tree_id, int32 new_num_layers) const {
+    DCHECK_LT(tree_id, tree_ensemble_->trees_size());
+    tree_ensemble_->mutable_tree_metadata(tree_id)->set_num_layers_grown(
+        new_num_layers);
+  }
+
+  void UpdateGrowingMetadata() const;
+
+  int32 GetNumLayersAttempted() {
+    return tree_ensemble_->growing_metadata().num_layers_attempted();
+  }
+
+  bool is_leaf(const int32 tree_id, const int32 node_id) const {
+    DCHECK_LT(tree_id, tree_ensemble_->trees_size());
+    DCHECK_LT(node_id, tree_ensemble_->trees(tree_id).nodes_size());
+    const auto& node = tree_ensemble_->trees(tree_id).nodes(node_id);
+    return node.node_case() == boosted_trees::Node::kLeaf;
+  }
+
+  int32 feature_id(const int32 tree_id, const int32 node_id) const {
+    const auto node = tree_ensemble_->trees(tree_id).nodes(node_id);
+    DCHECK_EQ(node.node_case(), boosted_trees::Node::kBucketizedSplit);
+    return node.bucketized_split().feature_id();
+  }
+
+  int32 bucket_threshold(const int32 tree_id, const int32 node_id) const {
+    const auto node = tree_ensemble_->trees(tree_id).nodes(node_id);
+    DCHECK_EQ(node.node_case(), boosted_trees::Node::kBucketizedSplit);
+    return node.bucketized_split().threshold();
+  }
+
+  int32 left_id(const int32 tree_id, const int32 node_id) const {
+    const auto node = tree_ensemble_->trees(tree_id).nodes(node_id);
+    DCHECK_EQ(node.node_case(), boosted_trees::Node::kBucketizedSplit);
+    return node.bucketized_split().left_id();
+  }
+
+  int32 right_id(const int32 tree_id, const int32 node_id) const {
+    const auto node = tree_ensemble_->trees(tree_id).nodes(node_id);
+    DCHECK_EQ(node.node_case(), boosted_trees::Node::kBucketizedSplit);
+    return node.bucketized_split().right_id();
+  }
+
+  // Adds a new tree to the ensemble and returns its tree_id.
+  int32 AddNewTree(const float weight);
+
+  // Grows the tree by adding a split and leaves.
+  void AddBucketizedSplitNode(const int32 tree_id, const int32 node_id,
+                              const int32 feature_id, const int32 threshold,
+                              const float gain, const float left_contrib,
+                              const float right_contrib, int32* left_node_id,
+                              int32* right_node_id);
+
+  // Retrieves tree weights and returns as a vector.
+  // It involves a copy, so should be called only sparingly (like once per
+  // iteration, not per example).
+  std::vector<float> GetTreeWeights() const {
+    return {tree_ensemble_->tree_weights().begin(),
+            tree_ensemble_->tree_weights().end()};
+  }
+
+  float GetTreeWeight(const int32 tree_id) const {
+    return tree_ensemble_->tree_weights(tree_id);
+  }
+
+  bool IsTreeFinalized(const int32 tree_id) const {
+    DCHECK_LT(tree_id, tree_ensemble_->trees_size());
+    return tree_ensemble_->tree_metadata(tree_id).is_finalized();
+  }
+
+  bool IsTreePostPruned(const int32 tree_id) const {
+    DCHECK_LT(tree_id, tree_ensemble_->trees_size());
+    return tree_ensemble_->tree_metadata(tree_id)
+               .post_pruned_nodes_meta_size() > 0;
+  }
+
+  void SetIsFinalized(const int32 tree_id, const bool is_finalized) {
+    DCHECK_LT(tree_id, tree_ensemble_->trees_size());
+    return tree_ensemble_->mutable_tree_metadata(tree_id)->set_is_finalized(
+        is_finalized);
+  }
+
+  // Sets the weight of i'th tree.
+  void SetTreeWeight(const int32 tree_id, const float weight) {
+    DCHECK_GE(tree_id, 0);
+    DCHECK_LT(tree_id, num_trees());
+    tree_ensemble_->set_tree_weights(tree_id, weight);
+  }
+
+  // Resets the resource and frees the protos in arena.
+  // Caller needs to hold the mutex lock while calling this.
+  virtual void Reset();
+
+  void PostPruneTree(const int32 current_tree);
+
+  // For a given node, returns the id in a pruned tree, as well as correction
+  // to the cached prediction that should be applied. If tree was not
+  // post-pruned, current_node_id will be equal to initial_node_id and logit
+  // update will be equal to zero.
+  void GetPostPruneCorrection(const int32 tree_id, const int32 initial_node_id,
+                              int32* current_node_id,
+                              float* logit_update) const;
+  mutex* get_mutex() { return &mu_; }
+
+ private:
+  // Helper method to check whether a split node is terminal, i.e. it only
+  // has leaf nodes as children.
+  bool IsTerminalSplitNode(const int32 tree_id, const int32 node_id) const;
+
+  // For each pruned node, finds the leaf where it finally ended up and
+  // calculates the total update from that pruned node prediction.
+  void CalculateParentAndLogitUpdate(
+      const int32 start_node_id,
+      const std::vector<std::pair<int32, float>>& nodes_change,
+      int32* parent_id, float* change) const;
+
+  // Helper method to collect the information to be used to prune some nodes in
+  // the tree.
+  void RecursivelyDoPostPrunePreparation(
+      const int32 tree_id, const int32 node_id,
+      std::vector<int32>* nodes_to_delete,
+      std::vector<std::pair<int32, float>>* nodes_meta);
+
+ protected:
+  protobuf::Arena arena_;
+  mutex mu_;
+  boosted_trees::TreeEnsemble* tree_ensemble_;
+};
+
+}  // namespace tensorflow
+
+#endif  // TENSORFLOW_CORE_KERNELS_BOOSTED_TREES_RESOURCES_H_
diff --git a/tensorflow/core/kernels/boosted_trees/stats_ops.cc b/tensorflow/core/kernels/boosted_trees/stats_ops.cc
new file mode 100644 (file)
index 0000000..33fdab6
--- /dev/null
@@ -0,0 +1,296 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <limits>
+#include <vector>
+
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/tensor.h"
+
+namespace tensorflow {
+
+namespace {
+const float kEps = 1e-15;
+}  // namespace
+
+class BoostedTreesCalculateBestGainsPerFeatureOp : public OpKernel {
+ public:
+  explicit BoostedTreesCalculateBestGainsPerFeatureOp(
+      OpKernelConstruction* const context)
+      : OpKernel(context) {
+    OP_REQUIRES_OK(context, context->GetAttr("l1", &l1_));
+    OP_REQUIRES_OK(context, context->GetAttr("l2", &l2_));
+    OP_REQUIRES_OK(context,
+                   context->GetAttr("tree_complexity", &tree_complexity_));
+    OP_REQUIRES_OK(context, context->GetAttr("max_splits", &max_splits_));
+    OP_REQUIRES_OK(context, context->GetAttr("num_features", &num_features_));
+  }
+
+  void Compute(OpKernelContext* const context) override {
+    // node_id_range
+    const Tensor* node_id_range_t;
+    OP_REQUIRES_OK(context, context->input("node_id_range", &node_id_range_t));
+    const auto node_id_range = node_id_range_t->vec<int32>();
+    int32 node_id_first = node_id_range(0);
+    int32 node_id_last = node_id_range(1);  // inclusive.
+    // stats_summary_list
+    OpInputList stats_summary_list;
+    OP_REQUIRES_OK(context, context->input_list("stats_summary_list",
+                                                &stats_summary_list));
+    const int64 num_buckets = stats_summary_list[0].dim_size(1);
+    std::vector<TTypes<float, 3>::ConstTensor> stats_summary;
+    stats_summary.reserve(stats_summary_list.size());
+    for (const auto& tensor : stats_summary_list) {
+      stats_summary.emplace_back(tensor.tensor<float, 3>());
+    }
+
+    // Allocate output lists of tensors:
+    OpOutputList output_node_ids_list;
+    OP_REQUIRES_OK(
+        context, context->output_list("node_ids_list", &output_node_ids_list));
+    OpOutputList output_gains_list;
+    OP_REQUIRES_OK(context,
+                   context->output_list("gains_list", &output_gains_list));
+    OpOutputList output_thresholds_list;
+    OP_REQUIRES_OK(context, context->output_list("thresholds_list",
+                                                 &output_thresholds_list));
+    OpOutputList output_left_node_contribs_list;
+    OP_REQUIRES_OK(context,
+                   context->output_list("left_node_contribs_list",
+                                        &output_left_node_contribs_list));
+    OpOutputList output_right_node_contribs_list;
+    OP_REQUIRES_OK(context,
+                   context->output_list("right_node_contribs_list",
+                                        &output_right_node_contribs_list));
+
+    // Get the best split info per node for each feature.
+    for (int feature_idx = 0; feature_idx < num_features_; ++feature_idx) {
+      std::vector<float> cum_grad;
+      std::vector<float> cum_hess;
+      cum_grad.reserve(num_buckets);
+      cum_hess.reserve(num_buckets);
+
+      std::vector<int32> output_node_ids;
+      std::vector<float> output_gains;
+      std::vector<int32> output_thresholds;
+      std::vector<float> output_left_node_contribs;
+      std::vector<float> output_right_node_contribs;
+      for (int node_id = node_id_first; node_id <= node_id_last; ++node_id) {
+        // Calculate gains.
+        cum_grad.clear();
+        cum_hess.clear();
+        float total_grad = 0.0;
+        float total_hess = 0.0;
+        for (int bucket = 0; bucket < num_buckets; ++bucket) {
+          // TODO(nponomareva): Consider multi-dimensional gradients/hessians.
+          total_grad += stats_summary[feature_idx](node_id, bucket, 0);
+          total_hess += stats_summary[feature_idx](node_id, bucket, 1);
+          cum_grad.push_back(total_grad);
+          cum_hess.push_back(total_hess);
+        }
+        float best_gain = std::numeric_limits<float>::lowest();
+        int32 best_bucket = 0;
+        float best_contrib_for_left = 0.0;
+        float best_contrib_for_right = 0.0;
+        // Parent gain.
+        float parent_gain;
+        float unused;
+        CalculateWeightsAndGains(total_grad, total_hess, &unused, &parent_gain);
+
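+        // A split at bucket b sends buckets [0, b] to the left child and
+        // the remaining buckets to the right, so the cumulative sums up to
+        // and including b are exactly the left-child statistics.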
+        for (int bucket = 0; bucket < num_buckets; ++bucket) {
+          const float cum_grad_bucket = cum_grad[bucket];
+          const float cum_hess_bucket = cum_hess[bucket];
+          // Left child.
+          float contrib_for_left;
+          float gain_for_left;
+          CalculateWeightsAndGains(cum_grad_bucket, cum_hess_bucket,
+                                   &contrib_for_left, &gain_for_left);
+          // Right child.
+          float contrib_for_right;
+          float gain_for_right;
+          CalculateWeightsAndGains(total_grad - cum_grad_bucket,
+                                   total_hess - cum_hess_bucket,
+                                   &contrib_for_right, &gain_for_right);
+
+          if (gain_for_left + gain_for_right > best_gain) {
+            best_gain = gain_for_left + gain_for_right;
+            best_bucket = bucket;
+            best_contrib_for_left = contrib_for_left;
+            best_contrib_for_right = contrib_for_right;
+          }
+        }  // for bucket
+        output_node_ids.push_back(node_id);
+        // Remove the parent gain for the parent node.
+        output_gains.push_back(best_gain - parent_gain);
+        output_thresholds.push_back(best_bucket);
+        output_left_node_contribs.push_back(best_contrib_for_left);
+        output_right_node_contribs.push_back(best_contrib_for_right);
+      }  // for node_id
+      const int num_nodes = output_node_ids.size();
+      // output_node_ids
+      Tensor* output_node_ids_t;
+      OP_REQUIRES_OK(context,
+                     output_node_ids_list.allocate(feature_idx, {num_nodes},
+                                                   &output_node_ids_t));
+      auto output_node_ids_vec = output_node_ids_t->vec<int32>();
+      // output_gains
+      Tensor* output_gains_t;
+      OP_REQUIRES_OK(context, output_gains_list.allocate(
+                                  feature_idx, {num_nodes}, &output_gains_t));
+      auto output_gains_vec = output_gains_t->vec<float>();
+      // output_thresholds
+      Tensor* output_thresholds_t;
+      OP_REQUIRES_OK(context,
+                     output_thresholds_list.allocate(feature_idx, {num_nodes},
+                                                     &output_thresholds_t));
+      auto output_thresholds_vec = output_thresholds_t->vec<int32>();
+      // output_left_node_contribs
+      Tensor* output_left_node_contribs_t;
+      OP_REQUIRES_OK(context, output_left_node_contribs_list.allocate(
+                                  feature_idx, {num_nodes, 1},
+                                  &output_left_node_contribs_t));
+      auto output_left_node_contribs_matrix =
+          output_left_node_contribs_t->matrix<float>();
+      // output_right_node_contribs
+      Tensor* output_right_node_contribs_t;
+      OP_REQUIRES_OK(context, output_right_node_contribs_list.allocate(
+                                  feature_idx, {num_nodes, 1},
+                                  &output_right_node_contribs_t));
+      auto output_right_node_contribs_matrix =
+          output_right_node_contribs_t->matrix<float>();
+      // Sets output tensors from vectors.
+      for (int i = 0; i < num_nodes; ++i) {
+        output_node_ids_vec(i) = output_node_ids[i];
+        // Adjust the gains to penalize by tree complexity.
+        output_gains_vec(i) = output_gains[i] - tree_complexity_;
+        output_thresholds_vec(i) = output_thresholds[i];
+        // Logits are 1-dimensional for now.
+        // TODO(nponomareva): Consider multi-dimensional logits.
+        output_left_node_contribs_matrix(i, 0) = output_left_node_contribs[i];
+        output_right_node_contribs_matrix(i, 0) = output_right_node_contribs[i];
+      }
+    }  // for f
+  }
+
+ private:
+  void CalculateWeightsAndGains(const float g, const float h, float* weight,
+                                float* gain) {
+    // The formula for the weight is -(g+l1*sgn(w))/(h+l2), and for the gain
+    // it is (g+l1*sgn(w))^2/(h+l2).
+    // This is because for each leaf we minimize
+    // 1/2*(h+l2)*w^2 + g*w + l1*|w|.
+    float g_with_l1 = g;
+    // Apply L1 regularization.
+    // 1) Assume w>0 => w=-(g+l1)/(h+l2) => g+l1 < 0 => g < -l1.
+    // 2) Assume w<0 => w=-(g-l1)/(h+l2) => g-l1 > 0 => g > l1.
+    // For g in (-l1, l1) there is thus no solution => set the weight to 0.
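+    // E.g. (illustrative numbers) with g=2, h=1, l1=0.5, l2=1:
+    // g_with_l1 = 1.5, weight = -1.5/2 = -0.75, gain = -1.5*(-0.75) = 1.125.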
+    if (l1_ > 0) {
+      if (g > l1_) {
+        g_with_l1 -= l1_;
+      } else if (g < -l1_) {
+        g_with_l1 += l1_;
+      } else {
+        *weight = 0.0;
+        *gain = 0.0;
+        return;
+      }
+    }
+    // Apply L2 regularization.
+    if (h + l2_ <= kEps) {
+      // Avoid division by 0 or infinitesimal.
+      *weight = 0;
+      *gain = 0;
+    } else {
+      *weight = -g_with_l1 / (h + l2_);
+      *gain = -g_with_l1 * (*weight);
+    }
+  }
+
+  float l1_;
+  float l2_;
+  float tree_complexity_;
+  int max_splits_;
+  int num_features_;
+};
+
+REGISTER_KERNEL_BUILDER(
+    Name("BoostedTreesCalculateBestGainsPerFeature").Device(DEVICE_CPU),
+    BoostedTreesCalculateBestGainsPerFeatureOp);
+
+class BoostedTreesMakeStatsSummaryOp : public OpKernel {
+ public:
+  explicit BoostedTreesMakeStatsSummaryOp(OpKernelConstruction* const context)
+      : OpKernel(context) {
+    OP_REQUIRES_OK(context, context->GetAttr("max_splits", &max_splits_));
+    OP_REQUIRES_OK(context, context->GetAttr("num_buckets", &num_buckets_));
+    OP_REQUIRES_OK(context, context->GetAttr("num_features", &num_features_));
+  }
+
+  void Compute(OpKernelContext* const context) override {
+    // node_ids
+    const Tensor* node_ids_t;
+    OP_REQUIRES_OK(context, context->input("node_ids", &node_ids_t));
+    const auto node_ids = node_ids_t->vec<int32>();
+    // gradients
+    const Tensor* gradients_t;
+    OP_REQUIRES_OK(context, context->input("gradients", &gradients_t));
+    const auto gradients = gradients_t->matrix<float>();
+    // hessians
+    const Tensor* hessians_t;
+    OP_REQUIRES_OK(context, context->input("hessians", &hessians_t));
+    const auto hessians = hessians_t->matrix<float>();
+    // bucketized_features
+    OpInputList bucketized_features_list;
+    OP_REQUIRES_OK(context, context->input_list("bucketized_features_list",
+                                                &bucketized_features_list));
+    std::vector<tensorflow::TTypes<int32>::ConstVec> bucketized_features;
+    bucketized_features.reserve(num_features_);
+    for (const Tensor& tensor : bucketized_features_list) {
+      bucketized_features.emplace_back(tensor.vec<int32>());
+    }
+
+    // Infer batch size.
+    const int64 batch_size = node_ids_t->dim_size(0);
+    // Allocate output stats tensor (Rank 4).
+    Tensor* output_stats_summary_t = nullptr;
+    OP_REQUIRES_OK(context, context->allocate_output(
+                                "stats_summary",
+                                {num_features_, max_splits_, num_buckets_, 2},
+                                &output_stats_summary_t));
+    auto output_stats_summary = output_stats_summary_t->tensor<float, 4>();
+    output_stats_summary.setZero();
+
+    // Partition by node, and then bucketize.
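+    // E.g. an example i routed to node n whose feature f value falls in
+    // bucket b adds gradients(i, 0) to stats_summary(f, n, b, 0) and
+    // hessians(i, 0) to stats_summary(f, n, b, 1).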
+    for (int feature_idx = 0; feature_idx < num_features_; ++feature_idx) {
+      const auto& features = bucketized_features[feature_idx];
+      for (int i = 0; i < batch_size; ++i) {
+        const int32 node = node_ids(i);
+        const int32 bucket = features(i);
+        output_stats_summary(feature_idx, node, bucket, 0) += gradients(i, 0);
+        output_stats_summary(feature_idx, node, bucket, 1) += hessians(i, 0);
+      }
+    }
+  }
+
+ private:
+  int max_splits_;
+  int num_buckets_;
+  int num_features_;
+};
+
+REGISTER_KERNEL_BUILDER(Name("BoostedTreesMakeStatsSummary").Device(DEVICE_CPU),
+                        BoostedTreesMakeStatsSummaryOp);
+
+}  // namespace tensorflow
diff --git a/tensorflow/core/kernels/boosted_trees/training_ops.cc b/tensorflow/core/kernels/boosted_trees/training_ops.cc
new file mode 100644 (file)
index 0000000..b9ded40
--- /dev/null
@@ -0,0 +1,219 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/tensor_shape.h"
+#include "tensorflow/core/kernels/boosted_trees/resources.h"
+
+namespace tensorflow {
+
+namespace {
+constexpr float kLayerByLayerTreeWeight = 1.0;
+
+// TODO(nponomareva, youngheek): consider using vector.
+struct SplitCandidate {
+  SplitCandidate() {}
+
+  // Index in the list of the feature ids.
+  int64 feature_idx;
+
+  // Index in the tensor of node_ids for the feature with idx feature_idx.
+  int64 candidate_idx;
+
+  float gain;
+};
+
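+// Pruning modes, matching the op's 'pruning_mode' attr: pre-pruning skips
+// negative-gain splits while growing; post-pruning prunes negative-gain
+// subtrees once the tree is finalized.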
+enum PruningMode { kNoPruning = 0, kPrePruning = 1, kPostPruning = 2 };
+
+}  // namespace
+
+class BoostedTreesUpdateEnsembleOp : public OpKernel {
+ public:
+  explicit BoostedTreesUpdateEnsembleOp(OpKernelConstruction* const context)
+      : OpKernel(context) {
+    OP_REQUIRES_OK(context, context->GetAttr("max_depth", &max_depth_));
+    OP_REQUIRES_OK(context, context->GetAttr("learning_rate", &learning_rate_));
+    OP_REQUIRES_OK(context, context->GetAttr("num_features", &num_features_));
+
+    int32 pruning_index;
+    OP_REQUIRES_OK(context, context->GetAttr("pruning_mode", &pruning_index));
+    pruning_mode_ = static_cast<PruningMode>(pruning_index);
+  }
+
+  void Compute(OpKernelContext* const context) override {
+    // Get decision tree ensemble.
+    BoostedTreesEnsembleResource* ensemble_resource;
+    OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
+                                           &ensemble_resource));
+    core::ScopedUnref unref_me(ensemble_resource);
+    mutex_lock l(*ensemble_resource->get_mutex());
+    // Increase the ensemble stamp.
+    ensemble_resource->set_stamp(ensemble_resource->stamp() + 1);
+
+    // Read node ids, gains, thresholds and node contribs.
+    OpInputList node_ids_list;
+    OpInputList gains_list;
+    OpInputList thresholds_list;
+    OpInputList left_node_contribs;
+    OpInputList right_node_contribs;
+    OP_REQUIRES_OK(context, context->input_list("node_ids", &node_ids_list));
+    OP_REQUIRES_OK(context, context->input_list("gains", &gains_list));
+    OP_REQUIRES_OK(context,
+                   context->input_list("thresholds", &thresholds_list));
+    OP_REQUIRES_OK(context, context->input_list("left_node_contribs",
+                                                &left_node_contribs));
+    OP_REQUIRES_OK(context, context->input_list("right_node_contribs",
+                                                &right_node_contribs));
+
+    const Tensor* feature_ids_t;
+    OP_REQUIRES_OK(context, context->input("feature_ids", &feature_ids_t));
+
+    auto feature_ids = feature_ids_t->vec<int32>();
+
+    // Find best splits for each active node.
+    std::map<int32, SplitCandidate> best_splits;
+    FindBestSplitsPerNode(context, node_ids_list, gains_list, &best_splits);
+
+    int32 current_tree =
+        UpdateGlobalAttemptsAndRetrieveGrowableTree(ensemble_resource);
+
+    // No-op if no new splits can be considered.
+    if (best_splits.empty()) {
+      LOG(WARNING) << "Not growing tree ensemble as no good splits were found.";
+      return;
+    }
+
+    const int32 new_num_layers =
+        ensemble_resource->GetNumLayersGrown(current_tree) + 1;
+    VLOG(1) << "Adding layer #" << new_num_layers - 1 << " to tree #"
+            << current_tree << " of ensemble of " << current_tree + 1
+            << " trees.";
+    bool split_happened = false;
+    // Add the splits to the tree.
+    for (auto& split_entry : best_splits) {
+      const int32 node_id = split_entry.first;
+      const SplitCandidate& candidate = split_entry.second;
+
+      const int64 feature_idx = candidate.feature_idx;
+      const int64 candidate_idx = candidate.candidate_idx;
+
+      const int32 feature_id = feature_ids(feature_idx);
+      const int32 threshold =
+          thresholds_list[feature_idx].vec<int32>()(candidate_idx);
+      const float gain = gains_list[feature_idx].vec<float>()(candidate_idx);
+
+      if (pruning_mode_ == kPrePruning) {
+        // Don't consider negative splits if we're pre-pruning the tree.
+        // Note that zero-gain splits are acceptable.
+        if (gain < 0) {
+          continue;
+        }
+      }
+      // For now assume that the weights vectors are one dimensional.
+      // TODO(nponomareva): change here for multiclass.
+      const float left_contrib =
+          learning_rate_ *
+          left_node_contribs[feature_idx].matrix<float>()(candidate_idx, 0);
+      const float right_contrib =
+          learning_rate_ *
+          right_node_contribs[feature_idx].matrix<float>()(candidate_idx, 0);
+
+      // Unused outputs of AddBucketizedSplitNode.
+      int32 left_node_id;
+      int32 right_node_id;
+
+      ensemble_resource->AddBucketizedSplitNode(
+          current_tree, node_id, feature_id, threshold, gain, left_contrib,
+          right_contrib, &left_node_id, &right_node_id);
+      split_happened = true;
+    }
+    if (split_happened) {
+      // Update growable tree metadata.
+      ensemble_resource->SetNumLayersGrown(current_tree, new_num_layers);
+      // Finalize the tree if needed.
+      if (ensemble_resource->GetNumLayersGrown(current_tree) >= max_depth_) {
+        ensemble_resource->SetIsFinalized(current_tree, true);
+        if (pruning_mode_ == kPostPruning) {
+          ensemble_resource->PostPruneTree(current_tree);
+        }
+        if (ensemble_resource->num_trees() > 0) {
+          // Create a dummy new tree with an empty node.
+          ensemble_resource->AddNewTree(kLayerByLayerTreeWeight);
+        }
+      }
+    }
+  }
+
+ private:
+  int32 UpdateGlobalAttemptsAndRetrieveGrowableTree(
+      BoostedTreesEnsembleResource* const ensemble_resource) {
+    int32 num_trees = ensemble_resource->num_trees();
+    int32 current_tree = num_trees - 1;
+
+    // Increment global attempt stats.
+    ensemble_resource->UpdateGrowingMetadata();
+
+    // Note that we don't set the tree weight equal to the learning rate:
+    // when doing layer-by-layer boosting, the learning rate is applied to
+    // the leaf weights instead.
+    if (num_trees <= 0) {
+      // Create a new tree with a no-op leaf.
+      current_tree = ensemble_resource->AddNewTree(kLayerByLayerTreeWeight);
+    }
+    return current_tree;
+  }
+
+  // Helper method which effectively does a reduce over all split candidates
+  // and finds the best split for each node.
+  void FindBestSplitsPerNode(
+      OpKernelContext* const context, const OpInputList& node_ids_list,
+      const OpInputList& gains_list,
+      std::map<int32, SplitCandidate>* best_split_per_node) {
+    // Find best split per node going through every feature candidate.
+    for (int64 feature_idx = 0; feature_idx < num_features_; ++feature_idx) {
+      const auto& node_ids = node_ids_list[feature_idx].vec<int32>();
+      const auto& gains = gains_list[feature_idx].vec<float>();
+
+      for (size_t candidate_idx = 0; candidate_idx < node_ids.size();
+           ++candidate_idx) {
+        // Get current split candidate.
+        const auto& node_id = node_ids(candidate_idx);
+        const auto& gain = gains(candidate_idx);
+
+        auto best_split_it = best_split_per_node->find(node_id);
+        SplitCandidate candidate;
+        candidate.feature_idx = feature_idx;
+        candidate.candidate_idx = candidate_idx;
+        candidate.gain = gain;
+
+        if (best_split_it == best_split_per_node->end() ||
+            gain > best_split_it->second.gain) {
+          (*best_split_per_node)[node_id] = candidate;
+        }
+      }
+    }
+  }
+
+ private:
+  int32 num_features_;
+  float learning_rate_;
+  int32 max_depth_;
+  PruningMode pruning_mode_;
+};
+
+REGISTER_KERNEL_BUILDER(Name("BoostedTreesUpdateEnsemble").Device(DEVICE_CPU),
+                        BoostedTreesUpdateEnsembleOp);
+
+}  // namespace tensorflow
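
The reduce in `FindBestSplitsPerNode` above is a per-node argmax over all (feature, candidate) pairs, keeping the first candidate on ties. A minimal NumPy sketch with made-up inputs:

```python
import numpy as np

# Hypothetical per-feature candidates, mirroring node_ids_list/gains_list.
node_ids_list = [np.array([0, 1, 1]), np.array([0, 0, 1])]
gains_list = [np.array([0.5, 0.2, 0.9]), np.array([0.7, 0.1, 0.3])]

best = {}  # node_id -> (feature_idx, candidate_idx, gain)
for f, (node_ids, gains) in enumerate(zip(node_ids_list, gains_list)):
  for c, (node_id, gain) in enumerate(zip(node_ids, gains)):
    # Strict comparison keeps the first candidate seen, as in the kernel.
    if node_id not in best or gain > best[node_id][2]:
      best[node_id] = (f, c, gain)

print(best)  # {0: (1, 0, 0.7), 1: (0, 2, 0.9)}
```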
diff --git a/tensorflow/core/ops/boosted_trees_ops.cc b/tensorflow/core/ops/boosted_trees_ops.cc
new file mode 100644 (file)
index 0000000..297e946
--- /dev/null
@@ -0,0 +1,319 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <vector>
+
+#include "tensorflow/core/framework/common_shape_fns.h"
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/resource_mgr.h"
+#include "tensorflow/core/framework/shape_inference.h"
+
+namespace tensorflow {
+
+REGISTER_RESOURCE_HANDLE_OP(BoostedTreesEnsembleResource);
+
+REGISTER_OP("IsBoostedTreesEnsembleInitialized")
+    .Input("tree_ensemble_handle: resource")
+    .Output("is_initialized: bool")
+    .SetShapeFn([](shape_inference::InferenceContext* c) {
+      shape_inference::ShapeHandle unused_input;
+      TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused_input));
+      c->set_output(0, c->Scalar());
+      return Status::OK();
+    });
+
+REGISTER_OP("BoostedTreesCalculateBestGainsPerFeature")
+    .Input("node_id_range: int32")
+    .Input("stats_summary_list: num_features * float32")
+    .Attr("l1: float")
+    .Attr("l2: float")
+    .Attr("tree_complexity: float")
+    .Attr("max_splits: int >= 1")
+    .Attr("num_features: int >= 1")  // not passed but populated automatically.
+    .Output("node_ids_list: num_features * int32")
+    .Output("gains_list: num_features * float32")
+    .Output("thresholds_list: num_features * int32")
+    .Output("left_node_contribs_list: num_features * float32")
+    .Output("right_node_contribs_list: num_features * float32")
+    .SetShapeFn([](shape_inference::InferenceContext* c) {
+      // Confirms the rank of the inputs and sets the shape of the outputs.
+      int max_splits;
+      int num_features;
+      float l1, l2, tree_complexity;
+      TF_RETURN_IF_ERROR(c->GetAttr("l1", &l1));
+      if (l1 < 0) {
+        return errors::InvalidArgument("l1 must be non-negative.");
+      }
+      TF_RETURN_IF_ERROR(c->GetAttr("l2", &l2));
+      if (l2 < 0) {
+        return errors::InvalidArgument("l2 must be non-negative.");
+      }
+      TF_RETURN_IF_ERROR(c->GetAttr("tree_complexity", &tree_complexity));
+      if (tree_complexity < 0) {
+        return errors::InvalidArgument("Tree complexity must be non-negative.");
+      }
+      TF_RETURN_IF_ERROR(c->GetAttr("max_splits", &max_splits));
+      TF_RETURN_IF_ERROR(c->GetAttr("num_features", &num_features));
+      shape_inference::ShapeHandle node_id_range_shape;
+      shape_inference::ShapeHandle unused_shape;
+      TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &node_id_range_shape));
+      TF_RETURN_IF_ERROR(
+          c->Merge(node_id_range_shape, c->MakeShape({2}), &unused_shape));
+      // Checks that all stats summary entries are of the same shape.
+      shape_inference::ShapeHandle summary_shape_base;
+      TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 3, &summary_shape_base));
+      TF_RETURN_IF_ERROR(c->Merge(summary_shape_base,
+                                  c->MakeShape({max_splits, -1, 2}),
+                                  &unused_shape));
+      for (int i = 1; i < num_features; ++i) {
+        shape_inference::ShapeHandle summary_shape;
+        TF_RETURN_IF_ERROR(c->WithRank(c->input(1 + i), 3, &summary_shape));
+        TF_RETURN_IF_ERROR(
+            c->Merge(summary_shape_base, summary_shape, &unused_shape));
+      }
+      // Sets the output lists.
+      std::vector<shape_inference::ShapeHandle> output_shapes_vec(
+          num_features, c->MakeShape({-1}));
+      TF_RETURN_IF_ERROR(c->set_output("node_ids_list", output_shapes_vec));
+      TF_RETURN_IF_ERROR(c->set_output("gains_list", output_shapes_vec));
+      TF_RETURN_IF_ERROR(c->set_output("thresholds_list", output_shapes_vec));
+      std::vector<shape_inference::ShapeHandle> output_shapes_contribs(
+          num_features, c->MakeShape({-1, 1}));
+      TF_RETURN_IF_ERROR(
+          c->set_output("left_node_contribs_list", output_shapes_contribs));
+      TF_RETURN_IF_ERROR(
+          c->set_output("right_node_contribs_list", output_shapes_contribs));
+      return Status::OK();
+    });
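
The shape function above pins down the contract: `num_features` stats summaries of shape `[max_splits, num_buckets, 2]` go in, and five per-feature lists come out. A hedged sketch of calling the op through the Python wrapper added later in this change; `node_ids`, `stats_summary_list`, and `max_splits` are assumed to already be in scope, and the regularization values are illustrative:

```python
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import boosted_trees_ops
from tensorflow.python.ops import math_ops

(node_ids_list, gains_list, thresholds_list,
 left_node_contribs_list, right_node_contribs_list) = (
     boosted_trees_ops.calculate_best_gains_per_feature(
         node_id_range=array_ops.stack(
             [math_ops.reduce_min(node_ids),
              math_ops.reduce_max(node_ids)]),
         stats_summary_list=stats_summary_list,
         l1=0.0, l2=1.0, tree_complexity=0.0,  # illustrative values
         max_splits=max_splits))
# Each output is a list of num_features tensors: node_ids/gains/thresholds
# are rank 1; both contribs lists are rank 2 with a trailing dimension of 1.
```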
+
+REGISTER_OP("BoostedTreesCreateEnsemble")
+    .Input("tree_ensemble_handle: resource")
+    .Input("stamp_token: int64")
+    .Input("tree_ensemble_serialized: string")
+    .SetShapeFn([](shape_inference::InferenceContext* c) {
+      shape_inference::ShapeHandle unused_input;
+      TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused_input));
+      TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused_input));
+      TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused_input));
+      return Status::OK();
+    });
+
+REGISTER_OP("BoostedTreesDeserializeEnsemble")
+    .Input("tree_ensemble_handle: resource")
+    .Input("stamp_token: int64")
+    .Input("tree_ensemble_serialized: string")
+    .SetShapeFn([](shape_inference::InferenceContext* c) {
+      shape_inference::ShapeHandle unused_input;
+      TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused_input));
+      TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused_input));
+      TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused_input));
+      return Status::OK();
+    });
+
+REGISTER_OP("BoostedTreesGetEnsembleStates")
+    .Input("tree_ensemble_handle: resource")
+    .Output("stamp_token: int64")
+    .Output("num_trees: int32")
+    .Output("num_finalized_trees: int32")
+    .Output("num_attempted_layers: int32")
+    .SetShapeFn([](shape_inference::InferenceContext* c) {
+      shape_inference::ShapeHandle unused_input;
+      TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused_input));
+      c->set_output(0, c->Scalar());
+      c->set_output(1, c->Scalar());
+      c->set_output(2, c->Scalar());
+      c->set_output(3, c->Scalar());
+      return Status::OK();
+    });
+
+REGISTER_OP("BoostedTreesMakeStatsSummary")
+    .Input("node_ids: int32")
+    .Input("gradients: float")
+    .Input("hessians: float")
+    .Input("bucketized_features_list: num_features * int32")
+    .Attr("max_splits: int >= 1")
+    .Attr("num_buckets: int >= 1")
+    .Attr("num_features: int >= 1")
+    .Output("stats_summary: float")
+    .SetShapeFn([](shape_inference::InferenceContext* c) {
+      // Sets the shape of the output as a Rank 4 Tensor.
+      int max_splits;
+      int num_buckets;
+      int num_features;
+      TF_RETURN_IF_ERROR(c->GetAttr("max_splits", &max_splits));
+      TF_RETURN_IF_ERROR(c->GetAttr("num_buckets", &num_buckets));
+      TF_RETURN_IF_ERROR(c->GetAttr("num_features", &num_features));
+      shape_inference::ShapeHandle node_ids_shape;
+      shape_inference::ShapeHandle gradients_shape;
+      shape_inference::ShapeHandle hessians_shape;
+      shape_inference::ShapeHandle bucketized_feature_shape;
+      shape_inference::ShapeHandle unused_shape;
+      shape_inference::DimensionHandle unused_dim;
+      TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &node_ids_shape));
+      TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 2, &gradients_shape));
+      TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 2, &hessians_shape));
+      TF_RETURN_IF_ERROR(c->Merge(c->Dim(node_ids_shape, 0),
+                                  c->Dim(gradients_shape, 0), &unused_dim));
+      TF_RETURN_IF_ERROR(
+          c->Merge(gradients_shape, hessians_shape, &unused_shape));
+      for (int f = 0; f < num_features; ++f) {
+        TF_RETURN_IF_ERROR(
+            c->WithRank(c->input(3 + f), 1, &bucketized_feature_shape));
+        TF_RETURN_IF_ERROR(c->Merge(c->Dim(node_ids_shape, 0),
+                                    c->Dim(bucketized_feature_shape, 0),
+                                    &unused_dim));
+      }
+      c->set_output(0,
+                    c->MakeShape({num_features, max_splits, num_buckets, 2}));
+      return Status::OK();
+    });
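
The output layout is `[num_features, max_splits, num_buckets, 2]`, with the trailing dimension holding (gradient, hessian), as the accumulator comment in `boosted_trees.py` below confirms. A NumPy sketch with toy values, assuming the usual sum-per-(node, bucket) accumulation (the kernel body is not shown in this excerpt):

```python
import numpy as np

node_ids = np.array([0, 0, 1])                    # per-example leaf ids
gradients = np.array([[0.1], [0.3], [-0.2]])      # [batch_size, 1]
hessians = np.array([[1.0], [1.0], [1.0]])        # [batch_size, 1]
bucketized_features_list = [np.array([0, 1, 1])]  # one feature

num_features, max_splits, num_buckets = 1, 3, 2
summary = np.zeros([num_features, max_splits, num_buckets, 2])
for f, buckets in enumerate(bucketized_features_list):
  for i in range(len(node_ids)):
    summary[f, node_ids[i], buckets[i], 0] += gradients[i, 0]
    summary[f, node_ids[i], buckets[i], 1] += hessians[i, 0]

# summary[0, 0, 0] == [0.1, 1.], summary[0, 0, 1] == [0.3, 1.],
# summary[0, 1, 1] == [-0.2, 1.]; everything else stays zero.
```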
+
+REGISTER_OP("BoostedTreesPredict")
+    .Input("tree_ensemble_handle: resource")
+    .Input("bucketized_features: num_bucketized_features * int32")
+    .Attr("num_bucketized_features: int >= 1")
+    .Attr("logits_dimension: int")
+    .Attr("max_depth: int >= 1")
+    .Output("logits: float")
+    .SetShapeFn([](shape_inference::InferenceContext* c) {
+      shape_inference::ShapeHandle feature_shape;
+      int num_bucketized_features;
+      TF_RETURN_IF_ERROR(
+          c->GetAttr("num_bucketized_features", &num_bucketized_features));
+      shape_inference::ShapeHandle unused_input;
+      for (int i = 0; i < num_bucketized_features; ++i) {
+        TF_RETURN_IF_ERROR(c->WithRank(c->input(i + 1), 1, &feature_shape));
+        // Check that the shapes of all bucketized features are the same.
+        TF_RETURN_IF_ERROR(c->Merge(c->input(1), feature_shape, &unused_input));
+      }
+
+      int logits_dimension;
+      TF_RETURN_IF_ERROR(c->GetAttr("logits_dimension", &logits_dimension));
+      auto logits_shape =
+          c->MakeShape({c->Dim(feature_shape, 0), logits_dimension});
+      // Logits.
+      c->set_output(0, logits_shape);
+      return Status::OK();
+    });
+
+REGISTER_OP("BoostedTreesSerializeEnsemble")
+    .Input("tree_ensemble_handle: resource")
+    .Output("stamp_token: int64")
+    .Output("tree_ensemble_serialized: string")
+    .SetShapeFn([](shape_inference::InferenceContext* c) {
+      shape_inference::ShapeHandle unused_input;
+      TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused_input));
+      c->set_output(0, c->Scalar());
+      c->set_output(1, c->Scalar());
+      return Status::OK();
+    });
+
+REGISTER_OP("BoostedTreesTrainingPredict")
+    .Input("tree_ensemble_handle: resource")
+    .Input("cached_tree_ids: int32")
+    .Input("cached_node_ids: int32")
+    .Input("bucketized_features: num_bucketized_features * int32")
+    .Attr("num_bucketized_features: int >= 1")
+    .Attr("logits_dimension: int")
+    .Attr("max_depth: int >= 1")
+    .Output("partial_logits: float")
+    .Output("tree_ids: int32")
+    .Output("node_ids: int32")
+    .SetShapeFn([](shape_inference::InferenceContext* c) {
+      shape_inference::ShapeHandle feature_shape;
+      int num_bucketized_features;
+      TF_RETURN_IF_ERROR(
+          c->GetAttr("num_bucketized_features", &num_bucketized_features));
+
+      int max_depth;
+      TF_RETURN_IF_ERROR(c->GetAttr("max_depth", &max_depth));
+
+      shape_inference::ShapeHandle unused_input;
+      for (int i = 0; i < num_bucketized_features; ++i) {
+        TF_RETURN_IF_ERROR(c->WithRank(c->input(i + 3), 1, &feature_shape));
+        TF_RETURN_IF_ERROR(
+            c->Merge(c->input(i + 3), feature_shape, &unused_input));
+      }
+      // All inputs/outputs except logits should have the same shape.
+      TF_RETURN_IF_ERROR(c->Merge(c->input(1), feature_shape, &unused_input));
+      TF_RETURN_IF_ERROR(c->Merge(c->input(2), feature_shape, &unused_input));
+
+      int logits_dimension;
+      TF_RETURN_IF_ERROR(c->GetAttr("logits_dimension", &logits_dimension));
+      auto logits_shape =
+          c->MakeShape({c->Dim(feature_shape, 0), logits_dimension});
+      // Partial logits.
+      c->set_output(0, logits_shape);
+      // Tree ids.
+      c->set_output(1, c->MakeShape({c->Dim(feature_shape, 0)}));
+      // Node ids.
+      c->set_output(2, c->MakeShape({c->Dim(feature_shape, 0)}));
+      return Status::OK();
+    });
+
+REGISTER_OP("BoostedTreesUpdateEnsemble")
+    .Input("tree_ensemble_handle: resource")
+    .Input("feature_ids: int32")
+    .Input("node_ids: num_features * int32")
+    .Input("gains: num_features * float")
+    .Input("thresholds: num_features * int32")
+    .Input("left_node_contribs: num_features * float")
+    .Input("right_node_contribs: num_features * float")
+    .Attr("max_depth: int >= 1")
+    .Attr("learning_rate: float")
+    .Attr("pruning_mode: int >=0")
+    .Attr("num_features: int >= 0")  // Inferred.
+    .SetShapeFn([](shape_inference::InferenceContext* c) {
+      shape_inference::ShapeHandle shape_handle;
+      int num_features;
+      TF_RETURN_IF_ERROR(c->GetAttr("num_features", &num_features));
+
+      // feature_ids: one id per feature.
+      shape_inference::ShapeHandle feature_ids_shape;
+      TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &feature_ids_shape));
+      TF_RETURN_IF_ERROR(
+          c->Merge(c->input(1), c->Vector(num_features), &shape_handle));
+
+      for (int i = 0; i < num_features; ++i) {
+        // Node ids.
+        TF_RETURN_IF_ERROR(c->WithRank(c->input(i + 2), 1, &shape_handle));
+        auto shape_rank_1 = c->MakeShape({c->Dim(shape_handle, 0)});
+        auto shape_rank_2 = c->MakeShape({c->Dim(shape_handle, 0), 1});
+
+        // Gains.
+        TF_RETURN_IF_ERROR(
+            c->WithRank(c->input(i + num_features + 2), 1, &shape_handle));
+        // TODO(nponomareva): replace this with input("name",vector of shapes).
+        TF_RETURN_IF_ERROR(c->Merge(c->input(i + num_features + 2),
+                                    shape_rank_1, &shape_handle));
+        // Thresholds.
+        TF_RETURN_IF_ERROR(
+            c->WithRank(c->input(i + num_features * 2 + 2), 1, &shape_handle));
+        TF_RETURN_IF_ERROR(c->Merge(c->input(i + num_features * 2 + 2),
+                                    shape_rank_1, &shape_handle));
+        // Left and right node contribs.
+        TF_RETURN_IF_ERROR(
+            c->WithRank(c->input(i + num_features * 3 + 2), 2, &shape_handle));
+        TF_RETURN_IF_ERROR(c->Merge(c->input(i + num_features * 3 + 2),
+                                    shape_rank_2, &shape_handle));
+        TF_RETURN_IF_ERROR(
+            c->WithRank(c->input(i + num_features * 4 + 2), 2, &shape_handle));
+        TF_RETURN_IF_ERROR(c->Merge(c->input(i + num_features * 4 + 2),
+                                    shape_rank_2, &shape_handle));
+      }
+      return Status::OK();
+    });
+
+}  // namespace tensorflow
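
The index arithmetic in the `BoostedTreesUpdateEnsemble` shape function flattens the inputs as: input 0 the handle, input 1 `feature_ids`, then five `num_features`-sized groups in registration order. As a tiny sketch:

```python
def input_index(group, i, num_features):
  # group: 0=node_ids, 1=gains, 2=thresholds,
  #        3=left_node_contribs, 4=right_node_contribs
  return 2 + group * num_features + i

# Matches the `i + num_features * k + 2` expressions above, e.g. the first
# "gains" input with num_features=3:
assert input_index(1, 0, num_features=3) == 5
```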
diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD
index 0c3c3c4..2cc3c48 100644 (file)
@@ -73,6 +73,7 @@ py_library(
     deps = [
         ":array_ops",
         ":bitwise_ops",
+        ":boosted_trees_ops",
         ":check_ops",
         ":client",
         ":client_testlib",
@@ -1375,6 +1376,14 @@ tf_gen_op_wrapper_private_py(
 )
 
 tf_gen_op_wrapper_private_py(
+    name = "boosted_trees_ops_gen",
+    visibility = ["//tensorflow:internal"],
+    deps = [
+        "//tensorflow/core:boosted_trees_ops_op_lib",
+    ],
+)
+
+tf_gen_op_wrapper_private_py(
     name = "summary_ops_gen",
     visibility = ["//tensorflow:__subpackages__"],
     deps = ["//tensorflow/core:summary_ops_op_lib"],
@@ -1624,6 +1633,19 @@ py_library(
 )
 
 py_library(
+    name = "boosted_trees_ops",
+    srcs = ["ops/boosted_trees_ops.py"],
+    srcs_version = "PY2AND3",
+    deps = [
+        ":boosted_trees_ops_gen",
+        ":framework",
+        ":ops",
+        ":training",
+        "//tensorflow/core/kernels/boosted_trees:boosted_trees_proto_py",
+    ],
+)
+
+py_library(
     name = "sets",
     srcs = [
         "ops/sets.py",
diff --git a/tensorflow/python/__init__.py b/tensorflow/python/__init__.py
index 3346937..ab1d01a 100644 (file)
@@ -98,6 +98,8 @@ from tensorflow.python.summary import summary
 from tensorflow.python.user_ops import user_ops
 from tensorflow.python.util import compat
 
+# Import boosted trees ops to make sure the ops are registered (but unused).
+from tensorflow.python.ops import gen_boosted_trees_ops as _gen_boosted_trees_ops
 
 # Import cudnn rnn ops to make sure their ops are registered.
 from tensorflow.python.ops import gen_cudnn_rnn_ops as _
diff --git a/tensorflow/python/estimator/BUILD b/tensorflow/python/estimator/BUILD
index 1fcff18..f93bc22 100644 (file)
@@ -15,6 +15,7 @@ py_library(
     srcs_version = "PY2AND3",
     deps = [
         ":baseline",
+        ":boosted_trees",
         ":dnn",
         ":dnn_linear_combined",
         ":estimator",
@@ -240,6 +241,53 @@ py_test(
 )
 
 py_library(
+    name = "boosted_trees",
+    srcs = ["canned/boosted_trees.py"],
+    srcs_version = "PY2AND3",
+    deps = [
+        ":estimator",
+        ":head",
+        ":model_fn",
+        "//tensorflow/python:array_ops",
+        "//tensorflow/python:boosted_trees_ops",
+        "//tensorflow/python:data_flow_ops",
+        "//tensorflow/python:dtypes",
+        "//tensorflow/python:framework_ops",
+        "//tensorflow/python:lookup_ops",
+        "//tensorflow/python:math_ops",
+        "//tensorflow/python:state_ops",
+        "//tensorflow/python:summary",
+        "//tensorflow/python:training",
+        "//tensorflow/python:variable_scope",
+        "//tensorflow/python/feature_column",
+        "//tensorflow/python/ops/losses",
+    ],
+)
+
+py_test(
+    name = "boosted_trees_test",
+    size = "medium",
+    srcs = ["canned/boosted_trees_test.py"],
+    srcs_version = "PY2AND3",
+    deps = [
+        ":boosted_trees",
+        "//tensorflow/core/kernels/boosted_trees:boosted_trees_proto_py",
+        "//tensorflow/python:array_ops",
+        "//tensorflow/python:constant_op",
+        "//tensorflow/python:control_flow_ops",
+        "//tensorflow/python:dtypes",
+        "//tensorflow/python:framework_test_lib",
+        "//tensorflow/python:math_ops",
+        "//tensorflow/python:platform",
+        "//tensorflow/python:platform_test",
+        "//tensorflow/python:resources",
+        "//tensorflow/python:training",
+        "//tensorflow/python/estimator:numpy_io",
+        "//tensorflow/python/feature_column",
+    ],
+)
+
+py_library(
     name = "dnn",
     srcs = ["canned/dnn.py"],
     srcs_version = "PY2AND3",
diff --git a/tensorflow/python/estimator/canned/boosted_trees.py b/tensorflow/python/estimator/canned/boosted_trees.py
new file mode 100644 (file)
index 0000000..a9bbabd
--- /dev/null
@@ -0,0 +1,736 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Estimator classes for BoostedTrees."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import collections
+
+from tensorflow.python.estimator import estimator
+from tensorflow.python.estimator import model_fn
+from tensorflow.python.estimator.canned import head as head_lib
+from tensorflow.python.feature_column import feature_column as feature_column_lib
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import boosted_trees_ops
+from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import data_flow_ops
+from tensorflow.python.ops import gradients_impl
+from tensorflow.python.ops import lookup_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import state_ops
+from tensorflow.python.ops import variable_scope
+from tensorflow.python.ops.losses import losses
+from tensorflow.python.summary import summary
+from tensorflow.python.training import session_run_hook
+from tensorflow.python.training import training_util
+from tensorflow.python.util.tf_export import tf_export
+
+TreeHParams = collections.namedtuple(
+    'TreeHParams',
+    ['n_trees', 'max_depth', 'learning_rate', 'l1', 'l2', 'tree_complexity'])
+
+_HOLD_FOR_MULTI_CLASS_SUPPORT = object()
+_HOLD_FOR_MULTI_DIM_SUPPORT = object()
+
+
+def _get_transformed_features(features, feature_columns):
+  """Gets the transformed features from features/feature_columns pair.
+
+  Args:
+    features: a dictionary of name to Tensor.
+    feature_columns: a list/set of tf.feature_column.
+
+  Returns:
+    result_features: a list of the transformed features, sorted by the name.
+    num_buckets: the maximum number of buckets across bucketized_columns.
+
+  Raises:
+    ValueError: when unsupported features/columns are given.
+  """
+  num_buckets = 1
+  # pylint:disable=protected-access
+  for fc in feature_columns:
+    if isinstance(fc, feature_column_lib._BucketizedColumn):
+      # N boundaries create (N+1) buckets.
+      num_buckets = max(num_buckets, len(fc.boundaries) + 1)
+    else:
+      raise ValueError('For now, only bucketized_column is supported but '
+                       'got: {}'.format(fc))
+  transformed = feature_column_lib._transform_features(features,
+                                                       feature_columns)
+  # pylint:enable=protected-access
+  result_features = []
+  for column in sorted(transformed, key=lambda tc: tc.name):
+    source_name = column.source_column.name
+    squeezed_tensor = array_ops.squeeze(transformed[column], axis=1)
+    if len(squeezed_tensor.shape) > 1:
+      raise ValueError('For now, only features equivalent to rank 1 are '
+                       'supported, but column `{}` got: {}'.format(
+                           source_name, features[source_name].shape))
+    result_features.append(squeezed_tensor)
+  return result_features, num_buckets
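
Only `bucketized_column` is accepted here: N boundaries yield N+1 buckets. A concrete sketch mirroring the constants in the unit tests below, with `np.digitize` standing in for the column's bucket assignment:

```python
import numpy as np

boundaries = [-2., .5, 12.]  # 3 boundaries -> 4 buckets (ids 0..3)
values = np.array([12.5, 1.0, -2.001, -2.0001, -1.999])
print(np.digitize(values, boundaries))  # [3 2 0 0 1]
```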
+
+
+def _keep_as_local_variable(tensor, name=None):
+  """Stores a tensor as a local Variable for faster read."""
+  return variable_scope.variable(
+      initial_value=tensor,
+      trainable=False,
+      collections=[ops.GraphKeys.LOCAL_VARIABLES],
+      validate_shape=False,
+      name=name)
+
+
+class _CacheTrainingStatesUsingHashTable(object):
+  """Caching logits, etc. using MutableHashTable."""
+
+  def __init__(self, example_ids, logits_dimension):
+    """Creates a cache with the given configuration.
+
+    It maintains a MutableDenseHashTable for all values.
+    The API lookup() and insert() would have those specs,
+      tree_ids: shape=[batch_size], dtype=int32
+      node_ids: shape=[batch_size], dtype=int32
+      logits: shape=[batch_size, logits_dimension], dtype=float32
+    However, in the MutableDenseHashTable the ids are bitcast to float32 and
+    all values are concatenated into a single float32 tensor.
+
+    Hence conversion happens internally before inserting into the HashTable
+    and after looking up from it.
+
+    Args:
+      example_ids: a Rank 1 tensor to be used as a key of the cache.
+      logits_dimension: a constant (int) for the dimension of logits.
+
+    Raises:
+      ValueError: if example_ids is other than int64 or string.
+    """
+    if dtypes.as_dtype(dtypes.int64).is_compatible_with(example_ids.dtype):
+      empty_key = -1 << 62
+    elif dtypes.as_dtype(dtypes.string).is_compatible_with(example_ids.dtype):
+      empty_key = ''
+    else:
+      raise ValueError('Unsupported example_id_feature dtype %s.' %
+                       example_ids.dtype)
+    # Cache holds latest <tree_id, node_id, logits> for each example.
+    # tree_id and node_id are both int32 but logits is a float32.
+    # To reduce the overhead, we store all of them together as float32 and
+    # bitcast the ids to int32.
+    self._table_ref = lookup_ops.mutable_dense_hash_table_v2(
+        empty_key=empty_key, value_dtype=dtypes.float32, value_shape=[3])
+    self._example_ids = example_ids
+    self._logits_dimension = logits_dimension
+
+  def lookup(self):
+    """Returns cached_tree_ids, cached_node_ids, cached_logits."""
+    cached_tree_ids, cached_node_ids, cached_logits = array_ops.split(
+        lookup_ops.lookup_table_find_v2(
+            self._table_ref, self._example_ids, default_value=[0.0, 0.0, 0.0]),
+        [1, 1, self._logits_dimension],
+        axis=1)
+    cached_tree_ids = array_ops.squeeze(
+        array_ops.bitcast(cached_tree_ids, dtypes.int32))
+    cached_node_ids = array_ops.squeeze(
+        array_ops.bitcast(cached_node_ids, dtypes.int32))
+    return (cached_tree_ids, cached_node_ids, cached_logits)
+
+  def insert(self, tree_ids, node_ids, logits):
+    """Inserts values and returns the op."""
+    insert_op = lookup_ops.lookup_table_insert_v2(
+        self._table_ref, self._example_ids,
+        array_ops.concat(
+            [
+                array_ops.expand_dims(
+                    array_ops.bitcast(tree_ids, dtypes.float32), 1),
+                array_ops.expand_dims(
+                    array_ops.bitcast(node_ids, dtypes.float32), 1),
+                logits,
+            ],
+            axis=1,
+            name='value_concat_for_cache_insert'))
+    return insert_op
+
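A quick sketch of the bitcast packing described in the docstring: the int32 ids are reinterpreted bit-for-bit as float32 (not numerically converted), so the round trip is exact. TF 1.x session style, matching the rest of this change:

```python
import tensorflow as tf

tree_ids = tf.constant([3, 7], dtype=tf.int32)
packed = tf.bitcast(tree_ids, tf.float32)  # same bits, float32 dtype
unpacked = tf.bitcast(packed, tf.int32)    # exact round trip
with tf.Session() as sess:
  print(sess.run(unpacked))  # [3 7]
```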
+
+class _CacheTrainingStatesUsingVariables(object):
+  """Caching logits, etc. using Variables."""
+
+  def __init__(self, batch_size, logits_dimension):
+    """Creates a cache with the given configuration.
+
+    It maintains three variables, tree_ids, node_ids, logits, for caching.
+      tree_ids: shape=[batch_size], dtype=int32
+      node_ids: shape=[batch_size], dtype=int32
+      logits: shape=[batch_size, logits_dimension], dtype=float32
+
+    Note: this can only be used in the in-memory data setting.
+
+    Args:
+      batch_size: `int`, the size of the cache.
+      logits_dimension: a constant (int) for the dimension of logits.
+    """
+    self._logits_dimension = logits_dimension
+    self._tree_ids = _keep_as_local_variable(
+        array_ops.zeros([batch_size], dtype=dtypes.int32),
+        name='tree_ids_cache')
+    self._node_ids = _keep_as_local_variable(
+        array_ops.zeros([batch_size], dtype=dtypes.int32),
+        name='node_ids_cache')
+    self._logits = _keep_as_local_variable(
+        array_ops.zeros([batch_size, logits_dimension], dtype=dtypes.float32),
+        name='logits_cache')
+
+  def lookup(self):
+    """Returns cached_tree_ids, cached_node_ids, cached_logits."""
+    return (self._tree_ids, self._node_ids, self._logits)
+
+  def insert(self, tree_ids, node_ids, logits):
+    """Inserts values and returns the op."""
+    return control_flow_ops.group(
+        [
+            self._tree_ids.assign(tree_ids),
+            self._node_ids.assign(node_ids),
+            self._logits.assign(logits)
+        ],
+        name='cache_insert')
+
+
+class StopAtAttemptsHook(session_run_hook.SessionRunHook):
+  """Hook that requests stop at the number of trees."""
+
+  def __init__(self, num_finalized_trees_tensor, num_attempted_layers_tensor,
+               max_trees, max_depth):
+    self._num_finalized_trees_tensor = num_finalized_trees_tensor
+    self._num_attempted_layers_tensor = num_attempted_layers_tensor
+    self._max_trees = max_trees
+    self._max_depth = max_depth
+
+  def before_run(self, run_context):
+    return session_run_hook.SessionRunArgs(
+        [self._num_finalized_trees_tensor, self._num_attempted_layers_tensor])
+
+  def after_run(self, run_context, run_values):
+    num_finalized_trees, num_attempted_layers = run_values.results
+    if (num_finalized_trees >= self._max_trees or
+        1.0 * num_attempted_layers / self._max_depth > 2 * self._max_trees):
+      run_context.request_stop()
+
+
+class StopAtNumTreesHook(session_run_hook.SessionRunHook):
+  """Hook that requests stop at the number of trees."""
+
+  def __init__(self, num_trees_tensor, max_trees):
+    self._num_trees_tensor = num_trees_tensor
+    self._max_trees = max_trees
+
+  def before_run(self, run_context):
+    return session_run_hook.SessionRunArgs(self._num_trees_tensor)
+
+  def after_run(self, run_context, run_values):
+    num_trees = run_values.results
+    if num_trees > self._max_trees:
+      run_context.request_stop()
+
+
+def _bt_model_fn(
+    features,
+    labels,
+    mode,
+    head,
+    feature_columns,
+    tree_hparams,
+    n_batches_per_layer,
+    config,
+    closed_form_grad_and_hess_fn=None,
+    example_id_column_name=None,
+    # TODO(youngheek): replace this later using other options.
+    train_in_memory=False,
+    name='TreeEnsembleModel'):
+  """Gradient Boosted Decision Tree model_fn.
+
+  Args:
+    features: dict of `Tensor`.
+    labels: `Tensor` of shape [batch_size, 1] or [batch_size] labels of
+      dtype `int32` or `int64` in the range `[0, n_classes)`.
+    mode: Defines whether this is training, evaluation or prediction.
+      See `ModeKeys`.
+    head: A `head_lib._Head` instance.
+    feature_columns: Iterable of `feature_column._FeatureColumn` model inputs.
+    tree_hparams: TODO. A TreeHParams namedtuple of hyperparameters.
+    n_batches_per_layer: A `Tensor` of `int64`. Each layer is built after at
+      least n_batches_per_layer accumulations.
+    config: `RunConfig` object to configure the runtime settings.
+    closed_form_grad_and_hess_fn: a function that accepts logits and labels
+      and returns gradients and hessians. By default, they are created by
+      tf.gradients() from the loss.
+    example_id_column_name: Name of the feature for a unique ID per example.
+      Currently experimental -- not exposed to public API.
+    train_in_memory: `bool`, when true, it assumes the dataset is in memory,
+      i.e., input_fn should return the entire dataset as a single batch, and
+      n_batches_per_layer should be set to 1.
+    name: Name to use for the model.
+
+  Returns:
+      An `EstimatorSpec` instance.
+
+  Raises:
+    ValueError: mode or params are invalid, or features has the wrong type.
+  """
+  is_single_machine = (config.num_worker_replicas == 1)
+  if train_in_memory:
+    assert n_batches_per_layer == 1, (
+        'When train_in_memory is enabled, input_fn should return the entire '
+        'dataset as a single batch, and n_batches_per_layer should be set as '
+        '1.')
+  worker_device = control_flow_ops.no_op().device
+  # Maximum number of splits possible in the whole tree = 2^max_depth - 1.
+  # TODO(youngheek): perhaps storage could be optimized by storing stats with
+  # the dimension max_splits_per_layer, instead of max_splits (for the entire
+  # tree).
+  max_splits = (1 << tree_hparams.max_depth) - 1
+  with ops.name_scope(name) as name:
+    # Prepare.
+    global_step = training_util.get_or_create_global_step()
+    input_feature_list, num_buckets = _get_transformed_features(
+        features, feature_columns)
+    if train_in_memory and mode == model_fn.ModeKeys.TRAIN:
+      input_feature_list = [
+          _keep_as_local_variable(feature) for feature in input_feature_list
+      ]
+    num_features = len(input_feature_list)
+
+    cache = None
+    if mode == model_fn.ModeKeys.TRAIN:
+      if train_in_memory and is_single_machine:  # maybe just train_in_memory?
+        batch_size = array_ops.shape(input_feature_list[0])[0]
+        cache = _CacheTrainingStatesUsingVariables(batch_size,
+                                                   head.logits_dimension)
+      elif example_id_column_name:
+        example_ids = features[example_id_column_name]
+        cache = _CacheTrainingStatesUsingHashTable(example_ids,
+                                                   head.logits_dimension)
+
+    # Create Ensemble resources.
+    if is_single_machine:
+      tree_ensemble = boosted_trees_ops.TreeEnsemble(name=name)
+      local_tree_ensemble = tree_ensemble
+      ensemble_reload = control_flow_ops.no_op()
+    else:
+      tree_ensemble = boosted_trees_ops.TreeEnsemble(name=name)
+      with ops.device(worker_device):
+        local_tree_ensemble = boosted_trees_ops.TreeEnsemble(
+            name=name + '_local', is_local=True)
+      # TODO(soroush): Do partial updates if this becomes a bottleneck.
+      ensemble_reload = local_tree_ensemble.deserialize(
+          *tree_ensemble.serialize())
+
+    # Create logits.
+    if mode != model_fn.ModeKeys.TRAIN:
+      logits = boosted_trees_ops.predict(
+          tree_ensemble_handle=local_tree_ensemble.resource_handle,
+          bucketized_features=input_feature_list,
+          logits_dimension=head.logits_dimension,
+          max_depth=tree_hparams.max_depth)
+    else:
+      if cache:
+        cached_tree_ids, cached_node_ids, cached_logits = cache.lookup()
+      else:
+        # Always start from the beginning when no cache is set up.
+        batch_size = array_ops.shape(input_feature_list[0])[0]
+        cached_tree_ids, cached_node_ids, cached_logits = (
+            array_ops.zeros([batch_size], dtype=dtypes.int32),
+            array_ops.zeros([batch_size], dtype=dtypes.int32),
+            array_ops.zeros(
+                [batch_size, head.logits_dimension], dtype=dtypes.float32))
+      with ops.control_dependencies([ensemble_reload]):
+        (stamp_token, num_trees, num_finalized_trees,
+         num_attempted_layers) = local_tree_ensemble.get_states()
+        summary.scalar('ensemble/num_trees', num_trees)
+        summary.scalar('ensemble/num_finalized_trees', num_finalized_trees)
+        summary.scalar('ensemble/num_attempted_layers', num_attempted_layers)
+
+        partial_logits, tree_ids, node_ids = boosted_trees_ops.training_predict(
+            tree_ensemble_handle=local_tree_ensemble.resource_handle,
+            cached_tree_ids=cached_tree_ids,
+            cached_node_ids=cached_node_ids,
+            bucketized_features=input_feature_list,
+            logits_dimension=head.logits_dimension,
+            max_depth=tree_hparams.max_depth)
+      logits = cached_logits + partial_logits
+
+    # Create training graph.
+    def _train_op_fn(loss):
+      """Run one training iteration."""
+      train_op = []
+      if cache:
+        train_op.append(cache.insert(tree_ids, node_ids, logits))
+      if closed_form_grad_and_hess_fn:
+        gradients, hessians = closed_form_grad_and_hess_fn(logits, labels)
+      else:
+        gradients = gradients_impl.gradients(loss, logits, name='Gradients')[0]
+        hessians = gradients_impl.gradients(
+            gradients, logits, name='Hessians')[0]
+      stats_summary_list = [
+          array_ops.squeeze(
+              boosted_trees_ops.make_stats_summary(
+                  node_ids=node_ids,
+                  gradients=gradients,
+                  hessians=hessians,
+                  bucketized_features_list=[input_feature_list[f]],
+                  max_splits=max_splits,
+                  num_buckets=num_buckets),
+              axis=0) for f in range(num_features)
+      ]
+
+      def grow_tree_from_stats_summaries(stats_summary_list):
+        """Updates ensemble based on the best gains from stats summaries."""
+        (node_ids_per_feature, gains_list, thresholds_list,
+         left_node_contribs_list, right_node_contribs_list) = (
+             boosted_trees_ops.calculate_best_gains_per_feature(
+                 node_id_range=array_ops.stack([
+                     math_ops.reduce_min(node_ids),
+                     math_ops.reduce_max(node_ids)
+                 ]),
+                 stats_summary_list=stats_summary_list,
+                 l1=tree_hparams.l1,
+                 l2=tree_hparams.l2,
+                 tree_complexity=tree_hparams.tree_complexity,
+                 max_splits=max_splits))
+        grow_op = boosted_trees_ops.update_ensemble(
+            # Confirm if local_tree_ensemble or tree_ensemble should be used.
+            tree_ensemble.resource_handle,
+            feature_ids=math_ops.range(0, num_features, dtype=dtypes.int32),
+            node_ids=node_ids_per_feature,
+            gains=gains_list,
+            thresholds=thresholds_list,
+            left_node_contribs=left_node_contribs_list,
+            right_node_contribs=right_node_contribs_list,
+            learning_rate=tree_hparams.learning_rate,
+            max_depth=tree_hparams.max_depth,
+            pruning_mode=boosted_trees_ops.PruningMode.NO_PRUNING)
+        return grow_op
+
+      if train_in_memory and is_single_machine:
+        train_op.append(state_ops.assign_add(global_step, 1))
+        train_op.append(grow_tree_from_stats_summaries(stats_summary_list))
+      else:
+        summary_accumulator = data_flow_ops.ConditionalAccumulator(
+            dtype=dtypes.float32,
+            # The stats consist of gradients and hessians (the last dimension).
+            shape=[num_features, max_splits, num_buckets, 2],
+            shared_name='stats_summary_accumulator')
+        apply_grad = summary_accumulator.apply_grad(
+            array_ops.stack(stats_summary_list, axis=0), stamp_token)
+
+        def grow_tree_from_accumulated_summaries_fn():
+          """Updates the tree with the best layer from accumulated summaries."""
+          # Take out the accumulated summaries from the accumulator and grow.
+          stats_summary_list = array_ops.unstack(
+              summary_accumulator.take_grad(1), axis=0)
+          grow_op = grow_tree_from_stats_summaries(stats_summary_list)
+          return grow_op
+
+        with ops.control_dependencies([apply_grad]):
+          train_op.append(state_ops.assign_add(global_step, 1))
+          if config.is_chief:
+            train_op.append(
+                control_flow_ops.cond(
+                    math_ops.greater_equal(
+                        summary_accumulator.num_accumulated(),
+                        n_batches_per_layer),
+                    grow_tree_from_accumulated_summaries_fn,
+                    control_flow_ops.no_op,
+                    name='wait_until_n_batches_accumulated'))
+
+      return control_flow_ops.group(train_op, name='train_op')
+
+  estimator_spec = head.create_estimator_spec(
+      features=features,
+      mode=mode,
+      labels=labels,
+      train_op_fn=_train_op_fn,
+      logits=logits)
+  if mode == model_fn.ModeKeys.TRAIN:
+    # Add an early stop hook.
+    estimator_spec = estimator_spec._replace(
+        training_hooks=estimator_spec.training_hooks +
+        (StopAtNumTreesHook(num_trees, tree_hparams.n_trees),))
+  return estimator_spec
+
+
+def _create_classification_head(n_classes,
+                                weight_column=None,
+                                label_vocabulary=None):
+  """Creates a classification head. Refer to canned.head for details on args."""
+  # TODO(nponomareva): Support multi-class cases.
+  if n_classes == 2:
+    # pylint: disable=protected-access
+    return head_lib._binary_logistic_head_with_sigmoid_cross_entropy_loss(
+        weight_column=weight_column,
+        label_vocabulary=label_vocabulary,
+        loss_reduction=losses.Reduction.SUM_OVER_BATCH_SIZE)
+    # pylint: enable=protected-access
+  else:
+    raise ValueError('For now, only binary classification is supported. '
+                     'n_classes given as {}'.format(n_classes))
+
+
+def _create_classification_head_and_closed_form(n_classes, weight_column,
+                                                label_vocabulary):
+  """Creates a head for classifier and the closed form gradients/hessians."""
+  head = _create_classification_head(n_classes, weight_column, label_vocabulary)
+  if n_classes == 2 and weight_column is None and label_vocabulary is None:
+    # Use the closed-form gradients/hessians for 2 class.
+    def _grad_and_hess_for_logloss(logits, labels):
+      # TODO(youngheek): add weights handling.
+      predictions = math_ops.reciprocal(math_ops.exp(-logits) + 1.0)
+      normalizer = math_ops.reciprocal(
+          math_ops.cast(array_ops.size(predictions), dtypes.float32))
+      gradients = (predictions - labels) * normalizer
+      hessians = predictions * (1.0 - predictions) * normalizer
+      return gradients, hessians
+
+    closed_form = _grad_and_hess_for_logloss
+  else:
+    closed_form = None
+  return (head, closed_form)
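
For reference, the closed form above is the standard logistic-loss derivative pair, with `predictions` playing the role of `p = sigmoid(logits)` and `normalizer` the `1/N` factor matching `SUM_OVER_BATCH_SIZE`:

```latex
p = \sigma(z) = \frac{1}{1 + e^{-z}}, \qquad
\ell(z, y) = -\bigl[ y \log p + (1 - y) \log (1 - p) \bigr],

\frac{\partial \ell}{\partial z} = p - y, \qquad
\frac{\partial^2 \ell}{\partial z^2} = p (1 - p).
```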
+
+
+def _create_regression_head(label_dimension, weight_column=None):
+  if label_dimension != 1:
+    raise ValueError('For now, only 1-dimensional regression is supported. '
+                     'label_dimension given as {}'.format(label_dimension))
+  # pylint: disable=protected-access
+  return head_lib._regression_head_with_mean_squared_error_loss(
+      label_dimension=label_dimension,
+      weight_column=weight_column,
+      loss_reduction=losses.Reduction.SUM_OVER_BATCH_SIZE)
+  # pylint: enable=protected-access
+
+
+@tf_export('estimator.BoostedTreesClassifier')
+class BoostedTreesClassifier(estimator.Estimator):
+  """A Classifier for Tensorflow Boosted Trees models."""
+
+  def __init__(
+      self,
+      feature_columns,
+      n_batches_per_layer,
+      model_dir=None,
+      n_classes=_HOLD_FOR_MULTI_CLASS_SUPPORT,
+      weight_column=None,
+      label_vocabulary=None,
+      n_trees=100,
+      max_depth=6,
+      learning_rate=0.1,
+      l1_regularization=0.,
+      l2_regularization=0.,
+      tree_complexity=0.,
+      config=None):
+    """Initializes a `BoostedTreesClassifier` instance.
+
+    Example:
+
+    ```python
+    bucketized_feature_1 = bucketized_column(
+      numeric_column('feature_1'), BUCKET_BOUNDARIES_1)
+    bucketized_feature_2 = bucketized_column(
+      numeric_column('feature_2'), BUCKET_BOUNDARIES_2)
+
+    classifier = estimator.BoostedTreesClassifier(
+        feature_columns=[bucketized_feature_1, bucketized_feature_2],
+        n_trees=100,
+        ... <some other params>
+    )
+
+    def input_fn_train():
+      ...
+      return dataset
+
+    classifier.train(input_fn=input_fn_train)
+
+    def input_fn_eval():
+      ...
+      return dataset
+
+    metrics = classifier.evaluate(input_fn=input_fn_eval)
+    ```
+
+    Args:
+      feature_columns: An iterable containing all the feature columns used by
+        the model. All items in the set should be instances of classes derived
+        from `FeatureColumn`.
+      n_batches_per_layer: the number of batches to collect statistics per
+        layer.
+      model_dir: Directory to save model parameters, graph, etc. This can
+        also be used to load checkpoints from the directory into an estimator
+        to continue training a previously saved model.
+      n_classes: number of label classes. Default is binary classification.
+        Multiclass support is not yet implemented.
+      weight_column: A string or a `_NumericColumn` created by
+        `tf.feature_column.numeric_column` defining feature column representing
+        weights. It is used to downweight or boost examples during training. It
+        will be multiplied by the loss of the example. If it is a string, it is
+        used as a key to fetch weight tensor from the `features`. If it is a
+        `_NumericColumn`, raw tensor is fetched by key `weight_column.key`,
+        then weight_column.normalizer_fn is applied on it to get weight tensor.
+      label_vocabulary: A list of strings representing possible label values.
+        If given, labels must be of string type and take values in
+        `label_vocabulary`. If it is not given, labels must already be encoded
+        as integer or float within [0, 1] for `n_classes=2`, or as integer
+        values in {0, 1,..., n_classes-1} for `n_classes` > 2. An error is
+        raised if the vocabulary is not provided and the labels are strings.
+      n_trees: number of trees to be created.
+      max_depth: maximum depth of the tree to grow.
+      learning_rate: shrinkage parameter applied when a tree is added to the
+        model.
+      l1_regularization: regularization multiplier applied to the absolute
+        weights of the tree leaves.
+      l2_regularization: regularization multiplier applied to the squared
+        weights of the tree leaves.
+      tree_complexity: regularization factor to penalize trees with more leaves.
+      config: `RunConfig` object to configure the runtime settings.
+
+    Raises:
+      ValueError: when wrong arguments are given or unsupported functionalities
+         are requested.
+    """
+    # TODO(nponomareva): Support multi-class cases.
+    if n_classes == _HOLD_FOR_MULTI_CLASS_SUPPORT:
+      n_classes = 2
+    head, closed_form = _create_classification_head_and_closed_form(
+        n_classes, weight_column, label_vocabulary=label_vocabulary)
+
+    # HParams for the model.
+    tree_hparams = TreeHParams(
+        n_trees, max_depth, learning_rate, l1_regularization, l2_regularization,
+        tree_complexity)
+
+    def _model_fn(features, labels, mode, config):
+      return _bt_model_fn(  # pylint: disable=protected-access
+          features,
+          labels,
+          mode,
+          head,
+          feature_columns,
+          tree_hparams,
+          n_batches_per_layer,
+          config,
+          closed_form_grad_and_hess_fn=closed_form)
+
+    super(BoostedTreesClassifier, self).__init__(
+        model_fn=_model_fn, model_dir=model_dir, config=config)
+
+
+@tf_export('estimator.BoostedTreesRegressor')
+class BoostedTreesRegressor(estimator.Estimator):
+  """A Regressor for Tensorflow Boosted Trees models."""
+
+  def __init__(
+      self,
+      feature_columns,
+      n_batches_per_layer,
+      model_dir=None,
+      label_dimension=_HOLD_FOR_MULTI_DIM_SUPPORT,
+      weight_column=None,
+      n_trees=100,
+      max_depth=6,
+      learning_rate=0.1,
+      l1_regularization=0.,
+      l2_regularization=0.,
+      tree_complexity=0.,
+      config=None):
+    """Initializes a `BoostedTreesRegressor` instance.
+
+    Example:
+
+    ```python
+    bucketized_feature_1 = bucketized_column(
+      numeric_column('feature_1'), BUCKET_BOUNDARIES_1)
+    bucketized_feature_2 = bucketized_column(
+      numeric_column('feature_2'), BUCKET_BOUNDARIES_2)
+
+    regressor = estimator.BoostedTreesRegressor(
+        feature_columns=[bucketized_feature_1, bucketized_feature_2],
+        n_trees=100,
+        ... <some other params>
+    )
+
+    def input_fn_train():
+      ...
+      return dataset
+
+    regressor.train(input_fn=input_fn_train)
+
+    def input_fn_eval():
+      ...
+      return dataset
+
+    metrics = regressor.evaluate(input_fn=input_fn_eval)
+    ```
+
+    Args:
+      feature_columns: An iterable containing all the feature columns used by
+        the model. All items in the set should be instances of classes derived
+        from `FeatureColumn`.
+      n_batches_per_layer: the number of batches to collect statistics per
+        layer.
+      model_dir: Directory to save model parameters, graph, etc. This can
+        also be used to load checkpoints from the directory into an estimator
+        to continue training a previously saved model.
+      label_dimension: Number of regression targets per example.
+        Multi-dimensional support is not yet implemented.
+      weight_column: A string or a `_NumericColumn` created by
+        `tf.feature_column.numeric_column` defining feature column representing
+        weights. It is used to downweight or boost examples during training. It
+        will be multiplied by the loss of the example. If it is a string, it is
+        used as a key to fetch weight tensor from the `features`. If it is a
+        `_NumericColumn`, raw tensor is fetched by key `weight_column.key`,
+        then weight_column.normalizer_fn is applied on it to get weight tensor.
+      n_trees: number of trees to be created.
+      max_depth: maximum depth of the tree to grow.
+      learning_rate: shrinkage parameter applied when a tree is added to the
+        model.
+      l1_regularization: regularization multiplier applied to the absolute
+        weights of the tree leaves.
+      l2_regularization: regularization multiplier applied to the squared
+        weights of the tree leaves.
+      tree_complexity: regularization factor to penalize trees with more leaves.
+      config: `RunConfig` object to configure the runtime settings.
+
+    Raises:
+      ValueError: when wrong arguments are given or unsupported functionalities
+         are requested.
+    """
+    # TODO(nponomareva): Extend it to multi-dimension cases.
+    if label_dimension == _HOLD_FOR_MULTI_DIM_SUPPORT:
+      label_dimension = 1
+    head = _create_regression_head(label_dimension, weight_column)
+
+    # HParams for the model.
+    tree_hparams = TreeHParams(
+        n_trees, max_depth, learning_rate, l1_regularization, l2_regularization,
+        tree_complexity)
+
+    def _model_fn(features, labels, mode, config):
+      return _bt_model_fn(  # pylint: disable=protected-access
+          features, labels, mode, head, feature_columns, tree_hparams,
+          n_batches_per_layer, config)
+
+    super(BoostedTreesRegressor, self).__init__(
+        model_fn=_model_fn, model_dir=model_dir, config=config)
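
A note on `n_batches_per_layer`, which both estimators require: a layer is grown only once that many stats batches have accumulated (see the ConditionalAccumulator path in `_bt_model_fn` above). A hedged sketch of a common setting, one layer per pass over the data; `NUM_EXAMPLES` and `BATCH_SIZE` are illustrative assumptions:

```python
import math

NUM_EXAMPLES = 10000  # assumed dataset size
BATCH_SIZE = 256      # assumed input_fn batch size

# Grow one layer per epoch: accumulate stats from every batch in the epoch.
n_batches_per_layer = int(math.ceil(NUM_EXAMPLES / float(BATCH_SIZE)))

# At the model_fn level, train_in_memory=True instead requires input_fn to
# return the whole dataset as one batch and n_batches_per_layer == 1.
```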
diff --git a/tensorflow/python/estimator/canned/boosted_trees_test.py b/tensorflow/python/estimator/canned/boosted_trees_test.py
new file mode 100644 (file)
index 0000000..9276fba
--- /dev/null
@@ -0,0 +1,799 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests boosted_trees estimators and model_fn."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.core.kernels.boosted_trees import boosted_trees_pb2
+from tensorflow.python.estimator import model_fn
+from tensorflow.python.estimator import run_config
+from tensorflow.python.estimator.canned import boosted_trees
+from tensorflow.python.estimator.inputs import numpy_io
+from tensorflow.python.feature_column import feature_column
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import test_util
+from tensorflow.python.ops import gen_boosted_trees_ops
+from tensorflow.python.ops import resources
+from tensorflow.python.ops import variables
+from tensorflow.python.platform import googletest
+from tensorflow.python.training import checkpoint_utils
+
+NUM_FEATURES = 3
+
+BUCKET_BOUNDARIES = [-2., .5, 12.]  # Boundaries for all the features.
+INPUT_FEATURES = np.array(
+    [
+        [12.5, 1.0, -2.001, -2.0001, -1.999],  # feature_0 quantized:[3,2,0,0,1]
+        [2.0, -3.0, 0.5, 0.0, 0.4995],         # feature_1 quantized:[2,0,2,1,1]
+        [3.0, 20.0, 50.0, -100.0, 102.75],     # feature_2 quantized:[2,3,3,0,3]
+    ],
+    dtype=np.float32)
+CLASSIFICATION_LABELS = [[0.], [1.], [1.], [0.], [0.]]
+REGRESSION_LABELS = [[1.5], [0.3], [0.2], [2.], [5.]]
+FEATURES_DICT = {'f_%d' % i: INPUT_FEATURES[i] for i in range(NUM_FEATURES)}
+
+# EXAMPLE_ID is not exposed to Estimator yet, but supported at model_fn level.
+EXAMPLE_IDS = np.array([0, 1, 2, 3, 4], dtype=np.int64)
+EXAMPLE_ID_COLUMN = '__example_id__'
+
+
+def _make_train_input_fn(is_classification):
+  """Makes train input_fn for classification/regression."""
+
+  def _input_fn():
+    features = dict(FEATURES_DICT)
+    features[EXAMPLE_ID_COLUMN] = constant_op.constant(EXAMPLE_IDS)
+    if is_classification:
+      labels = CLASSIFICATION_LABELS
+    else:
+      labels = REGRESSION_LABELS
+    return features, labels
+
+  return _input_fn
+
+
+class BoostedTreesClassifierTest(test_util.TensorFlowTestCase):
+
+  def setUp(self):
+    self._feature_columns = {
+        feature_column.bucketized_column(
+            feature_column.numeric_column('f_%d' % i, dtype=dtypes.float32),
+            BUCKET_BOUNDARIES)
+        for i in range(NUM_FEATURES)
+    }
+
+  def _assert_checkpoint(self, model_dir, expected_global_step):
+    self.assertEqual(expected_global_step,
+                     checkpoint_utils.load_variable(model_dir,
+                                                    ops.GraphKeys.GLOBAL_STEP))
+
+  def testTrainAndEvaluateBinaryClassifier(self):
+    input_fn = _make_train_input_fn(is_classification=True)
+
+    est = boosted_trees.BoostedTreesClassifier(
+        feature_columns=self._feature_columns,
+        n_batches_per_layer=1,
+        n_trees=1,
+        max_depth=5)
+
+    # It will stop after 5 steps because of the max depth and num trees.
+    num_steps = 100
+    # Train for a few steps, and validate final checkpoint.
+    est.train(input_fn, steps=num_steps)
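+    # The final global step is 6: presumably 5 layer-growing steps for the
+    # single depth-5 tree, plus one step that detects the finalized ensemble.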
+    self._assert_checkpoint(est.model_dir, 6)
+    eval_res = est.evaluate(input_fn=input_fn, steps=1)
+    self.assertAllClose(eval_res['accuracy'], 1.0)
+
+  def testInferBinaryClassifier(self):
+    train_input_fn = _make_train_input_fn(is_classification=True)
+    predict_input_fn = numpy_io.numpy_input_fn(
+        x=FEATURES_DICT, y=None, batch_size=1, num_epochs=1, shuffle=False)
+
+    est = boosted_trees.BoostedTreesClassifier(
+        feature_columns=self._feature_columns,
+        n_batches_per_layer=1,
+        n_trees=1,
+        max_depth=5)
+
+    # It will stop after 5 steps because of the max depth and num trees.
+    num_steps = 100
+    # Train for a few steps, and validate final checkpoint.
+    est.train(train_input_fn, steps=num_steps)
+
+    predictions = list(est.predict(input_fn=predict_input_fn))
+    self.assertEqual(5, len(predictions))
+    # All labels are correct.
+    self.assertAllClose([0], predictions[0]['class_ids'])
+    self.assertAllClose([1], predictions[1]['class_ids'])
+    self.assertAllClose([1], predictions[2]['class_ids'])
+    self.assertAllClose([0], predictions[3]['class_ids'])
+    self.assertAllClose([0], predictions[4]['class_ids'])
+
+
+class BoostedTreesRegressionTest(test_util.TensorFlowTestCase):
+
+  def setUp(self):
+    self._feature_columns = {
+        feature_column.bucketized_column(
+            feature_column.numeric_column('f_%d' % i, dtype=dtypes.float32),
+            BUCKET_BOUNDARIES)
+        for i in range(NUM_FEATURES)
+    }
+
+  def _assert_checkpoint(self, model_dir, expected_global_step):
+    self.assertEqual(expected_global_step,
+                     checkpoint_utils.load_variable(model_dir,
+                                                    ops.GraphKeys.GLOBAL_STEP))
+
+  def testTrainAndEvaluateRegressor(self):
+    input_fn = _make_train_input_fn(is_classification=False)
+
+    est = boosted_trees.BoostedTreesRegressor(
+        feature_columns=self._feature_columns,
+        n_batches_per_layer=1,
+        n_trees=2,
+        max_depth=5)
+
+    # It will stop after 10 steps because of the max depth and num trees.
+    num_steps = 100
+    # Train for a few steps, and validate final checkpoint.
+    est.train(input_fn, steps=num_steps)
+    self._assert_checkpoint(est.model_dir, 11)
+    eval_res = est.evaluate(input_fn=input_fn, steps=1)
+    self.assertAllClose(eval_res['average_loss'], 0.913176)
+
+  def testInferRegressor(self):
+    train_input_fn = _make_train_input_fn(is_classification=False)
+    predict_input_fn = numpy_io.numpy_input_fn(
+        x=FEATURES_DICT, y=None, batch_size=1, num_epochs=1, shuffle=False)
+
+    est = boosted_trees.BoostedTreesRegressor(
+        feature_columns=self._feature_columns,
+        n_batches_per_layer=1,
+        n_trees=1,
+        max_depth=5)
+
+    # It will stop after 5 steps because of the max depth and num trees.
+    num_steps = 100
+    # Train for a few steps, and validate final checkpoint.
+    est.train(train_input_fn, steps=num_steps)
+    self._assert_checkpoint(est.model_dir, 6)
+
+    predictions = list(est.predict(input_fn=predict_input_fn))
+
+    self.assertEqual(5, len(predictions))
+    self.assertAllClose([0.703549], predictions[0]['predictions'])
+    self.assertAllClose([0.266539], predictions[1]['predictions'])
+    self.assertAllClose([0.256479], predictions[2]['predictions'])
+    self.assertAllClose([1.088732], predictions[3]['predictions'])
+    self.assertAllClose([1.901732], predictions[4]['predictions'])
+
+
+class ModelFnTests(test_util.TensorFlowTestCase):
+  """Tests bt_model_fn including unexposed internal functionalities."""
+
+  def setUp(self):
+    self._feature_columns = {
+        feature_column.bucketized_column(
+            feature_column.numeric_column('f_%d' % i, dtype=dtypes.float32),
+            BUCKET_BOUNDARIES) for i in range(NUM_FEATURES)
+    }
+    self._tree_hparams = boosted_trees.TreeHParams(
+        n_trees=2,
+        max_depth=2,
+        learning_rate=0.1,
+        l1=0.,
+        l2=0.01,
+        tree_complexity=0.)
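+    # learning_rate here acts as the usual GBDT shrinkage on leaf scores, and
+    # l1/l2/tree_complexity as regularizers; the expected ensembles below
+    # assume these hparams.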
+
+  def _get_expected_ensembles_for_classification(self):
+    first_round = """
+        trees {
+          nodes {
+            bucketized_split {
+              feature_id: 2
+              threshold: 2
+              left_id: 1
+              right_id: 2
+            }
+            metadata {
+              gain: 0.387675
+            }
+          }
+          nodes {
+            leaf {
+              scalar: -0.181818
+            }
+          }
+          nodes {
+            leaf {
+              scalar: 0.0625
+            }
+          }
+        }
+        tree_weights: 1.0
+        tree_metadata {
+          num_layers_grown: 1
+          is_finalized: false
+        }
+        growing_metadata {
+          num_trees_attempted: 1
+          num_layers_attempted: 1
+        }
+        """
+    second_round = """
+        trees {
+          nodes {
+            bucketized_split {
+              feature_id: 2
+              threshold: 2
+              left_id: 1
+              right_id: 2
+            }
+            metadata {
+              gain: 0.387675
+            }
+          }
+          nodes {
+            bucketized_split {
+              feature_id: 0
+              threshold: 3
+              left_id: 3
+              right_id: 4
+            }
+            metadata {
+              gain: 0.0
+              original_leaf {
+                scalar: -0.181818
+              }
+            }
+          }
+          nodes {
+            bucketized_split {
+              feature_id: 0
+              threshold: 0
+              left_id: 5
+              right_id: 6
+            }
+            metadata {
+              gain: 0.105518
+              original_leaf {
+                scalar: 0.0625
+              }
+            }
+          }
+          nodes {
+            leaf {
+              scalar: -0.348397
+            }
+          }
+          nodes {
+            leaf {
+              scalar: -0.181818
+            }
+          }
+          nodes {
+            leaf {
+              scalar: 0.224091
+            }
+          }
+          nodes {
+            leaf {
+              scalar: 0.056815
+            }
+          }
+        }
+        trees {
+          nodes {
+            leaf {
+              scalar: 0.0
+            }
+          }
+        }
+        tree_weights: 1.0
+        tree_weights: 1.0
+        tree_metadata {
+          num_layers_grown: 2
+          is_finalized: true
+        }
+        tree_metadata {
+          num_layers_grown: 0
+          is_finalized: false
+        }
+        growing_metadata {
+          num_trees_attempted: 1
+          num_layers_attempted: 2
+        }
+        """
+    third_round = """
+        trees {
+          nodes {
+            bucketized_split {
+              feature_id: 2
+              threshold: 2
+              left_id: 1
+              right_id: 2
+            }
+            metadata {
+              gain: 0.387675
+            }
+          }
+          nodes {
+            bucketized_split {
+              feature_id: 0
+              threshold: 3
+              left_id: 3
+              right_id: 4
+            }
+            metadata {
+              gain: 0.0
+              original_leaf {
+                scalar: -0.181818
+              }
+            }
+          }
+          nodes {
+            bucketized_split {
+              feature_id: 0
+              threshold: 0
+              left_id: 5
+              right_id: 6
+            }
+            metadata {
+              gain: 0.105518
+              original_leaf {
+                scalar: 0.0625
+              }
+            }
+          }
+          nodes {
+            leaf {
+              scalar: -0.348397
+            }
+          }
+          nodes {
+            leaf {
+              scalar: -0.181818
+            }
+          }
+          nodes {
+            leaf {
+              scalar: 0.224091
+            }
+          }
+          nodes {
+            leaf {
+              scalar: 0.056815
+            }
+          }
+        }
+        trees {
+          nodes {
+            bucketized_split {
+              feature_id: 1
+              threshold: 0
+              left_id: 1
+              right_id: 2
+            }
+            metadata {
+              gain: 0.287131
+            }
+          }
+          nodes {
+            leaf {
+              scalar: 0.162042
+            }
+          }
+          nodes {
+            leaf {
+              scalar: -0.086986
+            }
+          }
+        }
+        tree_weights: 1.0
+        tree_weights: 1.0
+        tree_metadata {
+          num_layers_grown: 2
+          is_finalized: true
+        }
+        tree_metadata {
+          num_layers_grown: 1
+          is_finalized: false
+        }
+        growing_metadata {
+          num_trees_attempted: 2
+          num_layers_attempted: 3
+        }
+        """
+    return (first_round, second_round, third_round)
+
+  def _get_expected_ensembles_for_regression(self):
+    first_round = """
+        trees {
+          nodes {
+            bucketized_split {
+              feature_id: 1
+              threshold: 1
+              left_id: 1
+              right_id: 2
+            }
+            metadata {
+              gain: 1.169714
+            }
+          }
+          nodes {
+            leaf {
+              scalar: 0.241322
+            }
+          }
+          nodes {
+            leaf {
+              scalar: 0.083951
+            }
+          }
+        }
+        tree_weights: 1.0
+        tree_metadata {
+          num_layers_grown: 1
+          is_finalized: false
+        }
+        growing_metadata {
+          num_trees_attempted: 1
+          num_layers_attempted: 1
+        }
+        """
+    second_round = """
+        trees {
+          nodes {
+            bucketized_split {
+              feature_id: 1
+              threshold: 1
+              left_id: 1
+              right_id: 2
+            }
+            metadata {
+              gain: 1.169714
+            }
+          }
+          nodes {
+            bucketized_split {
+              feature_id: 0
+              threshold: 1
+              left_id: 3
+              right_id: 4
+            }
+            metadata {
+              gain: 2.673407
+              original_leaf {
+                scalar: 0.241322
+              }
+            }
+          }
+          nodes {
+            bucketized_split {
+              feature_id: 0
+              threshold: 0
+              left_id: 5
+              right_id: 6
+            }
+            metadata {
+              gain: 0.324102
+              original_leaf {
+                scalar: 0.083951
+              }
+            }
+          }
+          nodes {
+            leaf {
+              scalar: 0.563167
+            }
+          }
+          nodes {
+            leaf {
+              scalar: 0.247047
+            }
+          }
+          nodes {
+            leaf {
+              scalar: 0.095273
+            }
+          }
+          nodes {
+            leaf {
+              scalar: 0.222102
+            }
+          }
+        }
+        trees {
+          nodes {
+            leaf {
+              scalar: 0.0
+            }
+          }
+        }
+        tree_weights: 1.0
+        tree_weights: 1.0
+        tree_metadata {
+          num_layers_grown: 2
+          is_finalized: true
+        }
+        tree_metadata {
+          num_layers_grown: 0
+          is_finalized: false
+        }
+        growing_metadata {
+          num_trees_attempted: 1
+          num_layers_attempted: 2
+        }
+        """
+    third_round = """
+        trees {
+          nodes {
+            bucketized_split {
+              feature_id: 1
+              threshold: 1
+              left_id: 1
+              right_id: 2
+            }
+            metadata {
+              gain: 1.169714
+            }
+          }
+          nodes {
+            bucketized_split {
+              feature_id: 0
+              threshold: 1
+              left_id: 3
+              right_id: 4
+            }
+            metadata {
+              gain: 2.673407
+              original_leaf {
+                scalar: 0.241322
+              }
+            }
+          }
+          nodes {
+            bucketized_split {
+              feature_id: 0
+              threshold: 0
+              left_id: 5
+              right_id: 6
+            }
+            metadata {
+              gain: 0.324102
+              original_leaf {
+                scalar: 0.083951
+              }
+            }
+          }
+          nodes {
+            leaf {
+              scalar: 0.563167
+            }
+          }
+          nodes {
+            leaf {
+              scalar: 0.247047
+            }
+          }
+          nodes {
+            leaf {
+              scalar: 0.095273
+            }
+          }
+          nodes {
+            leaf {
+              scalar: 0.222102
+            }
+          }
+        }
+        trees {
+          nodes {
+            bucketized_split {
+              feature_id: 1
+              threshold: 0
+              left_id: 1
+              right_id: 2
+            }
+            metadata {
+              gain: 0.981026
+            }
+          }
+          nodes {
+            leaf {
+              scalar: 0.005166
+            }
+          }
+          nodes {
+            leaf {
+              scalar: 0.180281
+            }
+          }
+        }
+        tree_weights: 1.0
+        tree_weights: 1.0
+        tree_metadata {
+          num_layers_grown: 2
+          is_finalized: true
+        }
+        tree_metadata {
+          num_layers_grown: 1
+          is_finalized: false
+        }
+        growing_metadata {
+          num_trees_attempted: 2
+          num_layers_attempted: 3
+        }
+        """
+    return (first_round, second_round, third_round)
+
+  def _get_train_op_and_ensemble(self, head, config, is_classification,
+                                 train_in_memory):
+    """Calls bt_model_fn() and returns the train_op and ensemble_serialzed."""
+    features, labels = _make_train_input_fn(is_classification)()
+    estimator_spec = boosted_trees._bt_model_fn(  # pylint:disable=protected-access
+        features=features,
+        labels=labels,
+        mode=model_fn.ModeKeys.TRAIN,
+        head=head,
+        feature_columns=self._feature_columns,
+        tree_hparams=self._tree_hparams,
+        example_id_column_name=EXAMPLE_ID_COLUMN,
+        n_batches_per_layer=1,
+        config=config,
+        train_in_memory=train_in_memory)
+    resources.initialize_resources(resources.shared_resources()).run()
+    variables.global_variables_initializer().run()
+    variables.local_variables_initializer().run()
+
+    # Gets the train_op and serialized proto of the ensemble.
+    shared_resources = resources.shared_resources()
+    self.assertEqual(1, len(shared_resources))
+    train_op = estimator_spec.train_op
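+    # Serializing under a control dependency on train_op ensures the proto
+    # reflects the ensemble state after this training step.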
+    with ops.control_dependencies([train_op]):
+      _, ensemble_serialized = (
+          gen_boosted_trees_ops.boosted_trees_serialize_ensemble(
+              shared_resources[0].handle))
+    return train_op, ensemble_serialized
+
+  def testTrainClassifierInMemory(self):
+    ops.reset_default_graph()
+    expected_first, expected_second, expected_third = (
+        self._get_expected_ensembles_for_classification())
+    with self.test_session() as sess:
+      # Train with train_in_memory mode.
+      with sess.graph.as_default():
+        train_op, ensemble_serialized = self._get_train_op_and_ensemble(
+            boosted_trees._create_classification_head(n_classes=2),
+            run_config.RunConfig(),
+            is_classification=True,
+            train_in_memory=True)
+      _, serialized = sess.run([train_op, ensemble_serialized])
+      # Validate the trained ensemble.
+      ensemble_proto = boosted_trees_pb2.TreeEnsemble()
+      ensemble_proto.ParseFromString(serialized)
+      self.assertProtoEquals(expected_first, ensemble_proto)
+
+      # Run one more time and validate the trained ensemble.
+      _, serialized = sess.run([train_op, ensemble_serialized])
+      ensemble_proto = boosted_trees_pb2.TreeEnsemble()
+      ensemble_proto.ParseFromString(serialized)
+      self.assertProtoEquals(expected_second, ensemble_proto)
+
+      # Third round training and validation.
+      _, serialized = sess.run([train_op, ensemble_serialized])
+      ensemble_proto = boosted_trees_pb2.TreeEnsemble()
+      ensemble_proto.ParseFromString(serialized)
+      self.assertProtoEquals(expected_third, ensemble_proto)
+
+  def testTrainClassifierNonInMemory(self):
+    ops.reset_default_graph()
+    expected_first, expected_second, expected_third = (
+        self._get_expected_ensembles_for_classification())
+    with self.test_session() as sess:
+      # Train without train_in_memory mode.
+      with sess.graph.as_default():
+        train_op, ensemble_serialized = self._get_train_op_and_ensemble(
+            boosted_trees._create_classification_head(n_classes=2),
+            run_config.RunConfig(),
+            is_classification=True,
+            train_in_memory=False)
+      _, serialized = sess.run([train_op, ensemble_serialized])
+      # Validate the trained ensemble.
+      ensemble_proto = boosted_trees_pb2.TreeEnsemble()
+      ensemble_proto.ParseFromString(serialized)
+      self.assertProtoEquals(expected_first, ensemble_proto)
+
+      # Run one more time and validate the trained ensemble.
+      _, serialized = sess.run([train_op, ensemble_serialized])
+      ensemble_proto = boosted_trees_pb2.TreeEnsemble()
+      ensemble_proto.ParseFromString(serialized)
+      self.assertProtoEquals(expected_second, ensemble_proto)
+
+      # Third round training and validation.
+      _, serialized = sess.run([train_op, ensemble_serialized])
+      ensemble_proto = boosted_trees_pb2.TreeEnsemble()
+      ensemble_proto.ParseFromString(serialized)
+      self.assertProtoEquals(expected_third, ensemble_proto)
+
+  def testTrainRegressorInMemory(self):
+    ops.reset_default_graph()
+    expected_first, expected_second, expected_third = (
+        self._get_expected_ensembles_for_regression())
+    with self.test_session() as sess:
+      # Train with train_in_memory mode.
+      with sess.graph.as_default():
+        train_op, ensemble_serialized = self._get_train_op_and_ensemble(
+            boosted_trees._create_regression_head(label_dimension=1),
+            run_config.RunConfig(),
+            is_classification=False,
+            train_in_memory=True)
+      _, serialized = sess.run([train_op, ensemble_serialized])
+      # Validate the trained ensemble.
+      ensemble_proto = boosted_trees_pb2.TreeEnsemble()
+      ensemble_proto.ParseFromString(serialized)
+      self.assertProtoEquals(expected_first, ensemble_proto)
+
+      # Run one more time and validate the trained ensemble.
+      _, serialized = sess.run([train_op, ensemble_serialized])
+      ensemble_proto = boosted_trees_pb2.TreeEnsemble()
+      ensemble_proto.ParseFromString(serialized)
+      self.assertProtoEquals(expected_second, ensemble_proto)
+
+      # Third round training and validation.
+      _, serialized = sess.run([train_op, ensemble_serialized])
+      ensemble_proto = boosted_trees_pb2.TreeEnsemble()
+      ensemble_proto.ParseFromString(serialized)
+      self.assertProtoEquals(expected_third, ensemble_proto)
+
+  def testTrainRegressorNonInMemory(self):
+    ops.reset_default_graph()
+    expected_first, expected_second, expected_third = (
+        self._get_expected_ensembles_for_regression())
+    with self.test_session() as sess:
+      # Train without train_in_memory mode.
+      with sess.graph.as_default():
+        train_op, ensemble_serialized = self._get_train_op_and_ensemble(
+            boosted_trees._create_regression_head(label_dimension=1),
+            run_config.RunConfig(),
+            is_classification=False,
+            train_in_memory=False)
+      _, serialized = sess.run([train_op, ensemble_serialized])
+      # Validate the trained ensemble.
+      ensemble_proto = boosted_trees_pb2.TreeEnsemble()
+      ensemble_proto.ParseFromString(serialized)
+      self.assertProtoEquals(expected_first, ensemble_proto)
+
+      # Run one more time and validate the trained ensemble.
+      _, serialized = sess.run([train_op, ensemble_serialized])
+      ensemble_proto = boosted_trees_pb2.TreeEnsemble()
+      ensemble_proto.ParseFromString(serialized)
+      self.assertProtoEquals(expected_second, ensemble_proto)
+
+      # Third round training and validation.
+      _, serialized = sess.run([train_op, ensemble_serialized])
+      ensemble_proto = boosted_trees_pb2.TreeEnsemble()
+      ensemble_proto.ParseFromString(serialized)
+      self.assertProtoEquals(expected_third, ensemble_proto)
+
+
+if __name__ == '__main__':
+  googletest.main()
index be8930b..60c59cb 100644 (file)
@@ -21,6 +21,8 @@ from __future__ import print_function
 # pylint: disable=unused-import,line-too-long,wildcard-import
 from tensorflow.python.estimator.canned.baseline import BaselineClassifier
 from tensorflow.python.estimator.canned.baseline import BaselineRegressor
+from tensorflow.python.estimator.canned.boosted_trees import BoostedTreesClassifier
+from tensorflow.python.estimator.canned.boosted_trees import BoostedTreesRegressor
 from tensorflow.python.estimator.canned.dnn import DNNClassifier
 from tensorflow.python.estimator.canned.dnn import DNNRegressor
 from tensorflow.python.estimator.canned.dnn_linear_combined import DNNLinearCombinedClassifier
@@ -52,6 +54,8 @@ _allowed_symbols = [
     # Canned Estimators
     'BaselineClassifier',
     'BaselineRegressor',
+    'BoostedTreesClassifier',
+    'BoostedTreesRegressor',
     'DNNClassifier',
     'DNNRegressor',
     'DNNLinearCombinedClassifier',
diff --git a/tensorflow/python/kernel_tests/boosted_trees/BUILD b/tensorflow/python/kernel_tests/boosted_trees/BUILD
new file mode 100644 (file)
index 0000000..30e6289
--- /dev/null
@@ -0,0 +1,76 @@
+# Description:
+#   Kernel tests for Boosted Trees.
+
+package(
+    default_visibility = ["//tensorflow:internal"],
+)
+
+licenses(["notice"])  # Apache 2.0
+
+load("//tensorflow:tensorflow.bzl", "tf_py_test")
+
+filegroup(
+    name = "all_files",
+    srcs = glob(
+        ["**/*"],
+        exclude = [
+            "**/METADATA",
+            "**/OWNERS",
+        ],
+    ),
+    visibility = ["//tensorflow:__subpackages__"],
+)
+
+tf_py_test(
+    name = "resource_ops_test",
+    size = "small",
+    srcs = ["resource_ops_test.py"],
+    additional_deps = [
+        "//tensorflow/core/kernels/boosted_trees:boosted_trees_proto_py",
+        "//tensorflow/python:boosted_trees_ops",
+        "//tensorflow/python:framework_ops",
+        "//tensorflow/python:framework_test_lib",
+        "//tensorflow/python:resources",
+        "//tensorflow/python:training",
+        "//tensorflow/python:variables",
+    ],
+)
+
+tf_py_test(
+    name = "prediction_ops_test",
+    size = "small",
+    srcs = ["prediction_ops_test.py"],
+    additional_deps = [
+        "//tensorflow/core/kernels/boosted_trees:boosted_trees_proto_py",
+        "//tensorflow/python:array_ops",
+        "//tensorflow/python:boosted_trees_ops",
+        "//tensorflow/python:framework_test_lib",
+        "//tensorflow/python:constant_op",
+        "//tensorflow/python:resources",
+    ],
+)
+
+tf_py_test(
+    name = "stats_ops_test",
+    size = "small",
+    srcs = ["stats_ops_test.py"],
+    additional_deps = [
+        "//tensorflow/python:boosted_trees_ops",
+        "//tensorflow/python:framework",
+        "//tensorflow/python:framework_test_lib",
+    ],
+)
+
+tf_py_test(
+    name = "training_ops_test",
+    size = "small",
+    srcs = ["training_ops_test.py"],
+    additional_deps = [
+        "//tensorflow/core/kernels/boosted_trees:boosted_trees_proto_py",
+        "//tensorflow/python:array_ops",
+        "//tensorflow/python:boosted_trees_ops",
+        "//tensorflow/python:constant_op",
+        "//tensorflow/python:framework_test_lib",
+        "//tensorflow/python:resources",
+    ],
+)
diff --git a/tensorflow/python/kernel_tests/boosted_trees/__init__.py b/tensorflow/python/kernel_tests/boosted_trees/__init__.py
new file mode 100644 (file)
index 0000000..e69de29
diff --git a/tensorflow/python/kernel_tests/boosted_trees/prediction_ops_test.py b/tensorflow/python/kernel_tests/boosted_trees/prediction_ops_test.py
new file mode 100644 (file)
index 0000000..d132f15
--- /dev/null
@@ -0,0 +1,926 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests boosted_trees prediction kernels."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from google.protobuf import text_format
+from tensorflow.core.kernels.boosted_trees import boosted_trees_pb2
+from tensorflow.python.framework import test_util
+from tensorflow.python.ops import boosted_trees_ops
+from tensorflow.python.ops import resources
+from tensorflow.python.platform import googletest
+
+
+class TrainingPredictionOpsTest(test_util.TensorFlowTestCase):
+  """Tests prediction ops for training."""
+
+  def testCachedPredictionOnEmptyEnsemble(self):
+    """Tests that prediction on a dummy ensemble does not fail."""
+    with self.test_session() as session:
+      # Create a dummy ensemble.
+      tree_ensemble = boosted_trees_ops.TreeEnsemble(
+          'ensemble', serialized_proto='')
+      tree_ensemble_handle = tree_ensemble.resource_handle
+      resources.initialize_resources(resources.shared_resources()).run()
+
+      # No previous cached values.
+      cached_tree_ids = [0, 0]
+      cached_node_ids = [0, 0]
+
+      # We have two features: 0 and 1. Values don't matter here on a dummy
+      # ensemble.
+      feature_0_values = [67, 5]
+      feature_1_values = [9, 17]
+
+      # Grow tree ensemble.
+      predict_op = boosted_trees_ops.training_predict(
+          tree_ensemble_handle,
+          max_depth=2,
+          cached_tree_ids=cached_tree_ids,
+          cached_node_ids=cached_node_ids,
+          bucketized_features=[feature_0_values, feature_1_values],
+          logits_dimension=1)
+
+      logits_updates, new_tree_ids, new_node_ids = session.run(predict_op)
+
+      # Nothing changed.
+      self.assertAllClose(cached_tree_ids, new_tree_ids)
+      self.assertAllClose(cached_node_ids, new_node_ids)
+      self.assertAllClose([[0], [0]], logits_updates)
+
+  def testNoCachedPredictionButTreeExists(self):
+    """Tests that predictions are updated once trees are added."""
+    with self.test_session() as session:
+      tree_ensemble_config = boosted_trees_pb2.TreeEnsemble()
+      text_format.Merge("""
+        trees {
+          nodes {
+            bucketized_split {
+              feature_id: 0
+              threshold: 15
+              left_id: 1
+              right_id: 2
+            }
+            metadata {
+              gain: 7.62
+            }
+          }
+          nodes {
+            leaf {
+              scalar: 1.14
+            }
+          }
+          nodes {
+            leaf {
+              scalar: 8.79
+            }
+          }
+        }
+        tree_weights: 0.1
+        tree_metadata {
+          is_finalized: true
+          num_layers_grown: 1
+        }
+        growing_metadata {
+          num_trees_attempted: 1
+          num_layers_attempted: 2
+        }
+      """, tree_ensemble_config)
+
+      # Create existing ensemble with one root split
+      tree_ensemble = boosted_trees_ops.TreeEnsemble(
+          'ensemble', serialized_proto=tree_ensemble_config.SerializeToString())
+      tree_ensemble_handle = tree_ensemble.resource_handle
+      resources.initialize_resources(resources.shared_resources()).run()
+
+      # Two examples, none were cached before.
+      cached_tree_ids = [0, 0]
+      cached_node_ids = [0, 0]
+
+      feature_0_values = [67, 5]
+
+      # Grow tree ensemble.
+      predict_op = boosted_trees_ops.training_predict(
+          tree_ensemble_handle,
+          max_depth=2,
+          cached_tree_ids=cached_tree_ids,
+          cached_node_ids=cached_node_ids,
+          bucketized_features=[feature_0_values],
+          logits_dimension=1)
+
+      logits_updates, new_tree_ids, new_node_ids = session.run(predict_op)
+
+      # We are in the first tree.
+      self.assertAllClose([0, 0], new_tree_ids)
+      self.assertAllClose([2, 1], new_node_ids)
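+      # A bucketized value that is <= the split threshold goes left: f0=5
+      # lands in node 1 (leaf 1.14) and f0=67 in node 2 (leaf 8.79), and the
+      # logit updates are those leaf values scaled by the tree weight (0.1).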
+      self.assertAllClose([[0.1 * 8.79], [0.1 * 1.14]], logits_updates)
+
+  def testCachedPredictionIsCurrent(self):
+    """Tests that prediction based on previous node in the tree works."""
+    with self.test_session() as session:
+      tree_ensemble_config = boosted_trees_pb2.TreeEnsemble()
+      text_format.Merge("""
+        trees {
+          nodes {
+            bucketized_split {
+              feature_id: 1
+              threshold: 15
+              left_id: 1
+              right_id: 2
+            }
+            metadata {
+              gain: 7.62
+              original_leaf {
+                scalar: -2
+              }
+            }
+          }
+          nodes {
+            leaf {
+              scalar: 1.14
+            }
+          }
+          nodes {
+            leaf {
+              scalar: 8.79
+            }
+          }
+        }
+        tree_weights: 0.1
+        tree_metadata {
+          is_finalized: true
+          num_layers_grown: 2
+        }
+        growing_metadata {
+          num_trees_attempted: 1
+          num_layers_attempted: 2
+        }
+      """, tree_ensemble_config)
+
+      # Create existing ensemble with one root split
+      tree_ensemble = boosted_trees_ops.TreeEnsemble(
+          'ensemble', serialized_proto=tree_ensemble_config.SerializeToString())
+      tree_ensemble_handle = tree_ensemble.resource_handle
+      resources.initialize_resources(resources.shared_resources()).run()
+
+      # Two examples: one was cached in node 1, the other in node 2.
+      cached_tree_ids = [0, 0]
+      cached_node_ids = [1, 2]
+
+      # We have two features: 0 and 1. Values don't matter because trees didn't
+      # change.
+      feature_0_values = [67, 5]
+      feature_1_values = [9, 17]
+
+      # Grow tree ensemble.
+      predict_op = boosted_trees_ops.training_predict(
+          tree_ensemble_handle,
+          max_depth=4,
+          cached_tree_ids=cached_tree_ids,
+          cached_node_ids=cached_node_ids,
+          bucketized_features=[feature_0_values, feature_1_values],
+          logits_dimension=1)
+
+      logits_updates, new_tree_ids, new_node_ids = session.run(predict_op)
+
+      # Nothing changed.
+      self.assertAllClose(cached_tree_ids, new_tree_ids)
+      self.assertAllClose(cached_node_ids, new_node_ids)
+      self.assertAllClose([[0], [0]], logits_updates)
+
+  def testCachedPredictionFromTheSameTree(self):
+    """Tests that prediction based on previous node in the tree works."""
+    with self.test_session() as session:
+      tree_ensemble_config = boosted_trees_pb2.TreeEnsemble()
+      text_format.Merge("""
+        trees {
+          nodes {
+            bucketized_split {
+              feature_id: 1
+              threshold: 15
+              left_id: 1
+              right_id: 2
+            }
+            metadata {
+              gain: 7.62
+              original_leaf {
+                scalar: -2
+              }
+            }
+          }
+          nodes {
+            bucketized_split {
+              feature_id: 1
+              threshold: 7
+              left_id: 3
+              right_id: 4
+            }
+            metadata {
+              gain: 1.4
+              original_leaf {
+                scalar: 7.14
+              }
+            }
+          }
+          nodes {
+            bucketized_split {
+              feature_id: 0
+              threshold: 7
+              left_id: 5
+              right_id: 6
+            }
+            metadata {
+              gain: 2.7
+              original_leaf {
+                scalar: -4.375
+              }
+            }
+          }
+          nodes {
+            leaf {
+              scalar: 1.14
+            }
+          }
+          nodes {
+            leaf {
+              scalar: 8.79
+            }
+          }
+          nodes {
+            leaf {
+              scalar: -5.875
+            }
+          }
+          nodes {
+            leaf {
+              scalar: -2.075
+            }
+          }
+        }
+        tree_weights: 0.1
+        tree_metadata {
+          is_finalized: true
+          num_layers_grown: 2
+        }
+        growing_metadata {
+          num_trees_attempted: 1
+          num_layers_attempted: 2
+        }
+      """, tree_ensemble_config)
+
+      # Create an existing ensemble with a single depth-2 tree.
+      tree_ensemble = boosted_trees_ops.TreeEnsemble(
+          'ensemble', serialized_proto=tree_ensemble_config.SerializeToString())
+      tree_ensemble_handle = tree_ensemble.resource_handle
+      resources.initialize_resources(resources.shared_resources()).run()
+
+      # Two examples: one was cached in node 1, the other in node 0 (the root).
+      cached_tree_ids = [0, 0]
+      cached_node_ids = [1, 0]
+
+      # We have two features: 0 and 1.
+      feature_0_values = [67, 5]
+      feature_1_values = [9, 17]
+
+      # Grow tree ensemble.
+      predict_op = boosted_trees_ops.training_predict(
+          tree_ensemble_handle,
+          max_depth=4,
+          cached_tree_ids=cached_tree_ids,
+          cached_node_ids=cached_node_ids,
+          bucketized_features=[feature_0_values, feature_1_values],
+          logits_dimension=1)
+
+      logits_updates, new_tree_ids, new_node_ids = session.run(predict_op)
+
+      # We are still in the same tree.
+      self.assertAllClose([0, 0], new_tree_ids)
+      # When using the full tree, the first example will end up in node 4,
+      # the second in node 5.
+      self.assertAllClose([4, 5], new_node_ids)
+      # Full predictions for the two examples are 8.79 and -5.875, while the
+      # cached leaf values were 7.14 and -2, so the raw updates are
+      # 8.79 - 7.14 = 1.65 and -5.875 - (-2) = -3.875, each then scaled by
+      # the tree weight (0.1).
+      self.assertAllClose([[0.1 * 1.65], [0.1 * -3.875]], logits_updates)
+
+  def testCachedPredictionFromPreviousTree(self):
+    """Tests the predictions work when we have cache from previous trees."""
+    with self.test_session() as session:
+      tree_ensemble_config = boosted_trees_pb2.TreeEnsemble()
+      text_format.Merge("""
+        trees {
+          nodes {
+            bucketized_split {
+              feature_id: 1
+              threshold: 28
+              left_id: 1
+              right_id: 2
+            }
+            metadata {
+              gain: 7.62
+            }
+          }
+          nodes {
+            leaf {
+              scalar: 1.14
+            }
+          }
+          nodes {
+            leaf {
+              scalar: 8.79
+            }
+          }
+        }
+        trees {
+          nodes {
+            bucketized_split {
+              feature_id: 1
+              threshold: 26
+              left_id: 1
+              right_id: 2
+            }
+          }
+          nodes {
+            bucketized_split {
+              feature_id: 0
+              threshold: 50
+              left_id: 3
+              right_id: 4
+            }
+          }
+          nodes {
+            leaf {
+              scalar: 7
+            }
+          }
+          nodes {
+            leaf {
+              scalar: 5
+            }
+          }
+          nodes {
+            leaf {
+              scalar: 6
+            }
+          }
+        }
+        trees {
+          nodes {
+            bucketized_split {
+              feature_id: 0
+              threshold: 34
+              left_id: 1
+              right_id: 2
+            }
+          }
+          nodes {
+            leaf {
+              scalar: -7.0
+            }
+          }
+          nodes {
+            leaf {
+              scalar: 5.0
+            }
+          }
+        }
+        tree_metadata {
+          is_finalized: true
+        }
+        tree_metadata {
+          is_finalized: true
+        }
+        tree_metadata {
+          is_finalized: false
+        }
+        tree_weights: 0.1
+        tree_weights: 0.1
+        tree_weights: 0.1
+      """, tree_ensemble_config)
+
+      # Create an existing ensemble with three trees.
+      tree_ensemble = boosted_trees_ops.TreeEnsemble(
+          'ensemble', serialized_proto=tree_ensemble_config.SerializeToString())
+      tree_ensemble_handle = tree_ensemble.resource_handle
+      resources.initialize_resources(resources.shared_resources()).run()
+
+      # Two examples: one was cached in node 1, the other in node 0 (the root).
+      cached_tree_ids = [0, 0]
+      cached_node_ids = [1, 0]
+
+      # We have two features: 0 and 1.
+      feature_0_values = [36, 32]
+      feature_1_values = [11, 27]
+
+      # Grow tree ensemble.
+      predict_op = boosted_trees_ops.training_predict(
+          tree_ensemble_handle,
+          max_depth=2,
+          cached_tree_ids=cached_tree_ids,
+          cached_node_ids=cached_node_ids,
+          bucketized_features=[feature_0_values, feature_1_values],
+          logits_dimension=1)
+
+      logits_updates, new_tree_ids, new_node_ids = session.run(predict_op)
+      # Example 1 will get to node 3 in tree 1 and node 2 of tree 2
+      # Example 2 will get to node 2 in tree 1 and node 1 of tree 2
+
+      # We are in the last tree.
+      self.assertAllClose([2, 2], new_tree_ids)
+      # In the last tree, the first example ends up in node 2 and the second
+      # in node 1.
+      self.assertAllClose([2, 1], new_node_ids)
+      # Example 1 was cached at the leaf of tree 0, so only trees 1 and 2
+      # contribute: change = 0.1*(5.0+5.0).
+      # Example 2 was cached at the root of tree 0, so tree 0's leaf (1.14)
+      # contributes too: change = 0.1*(1.14+7.0-7.0).
+      self.assertAllClose([[1], [0.114]], logits_updates)
+
+  def testCachedPredictionFromTheSameTreeWithPostPrunedNodes(self):
+    """Tests that prediction based on previous node in the tree works."""
+    with self.test_session() as session:
+      tree_ensemble_config = boosted_trees_pb2.TreeEnsemble()
+      text_format.Merge("""
+        trees {
+          nodes {
+            bucketized_split {
+              feature_id:0
+              threshold: 33
+              left_id: 1
+              right_id: 2
+            }
+            metadata {
+              gain: -0.2
+            }
+          }
+          nodes {
+            leaf {
+              scalar: 0.01
+            }
+          }
+          nodes {
+            bucketized_split {
+              feature_id: 1
+              threshold: 5
+              left_id: 3
+              right_id: 4
+            }
+            metadata {
+              gain: 0.5
+              original_leaf {
+                scalar: 0.0143
+              }
+            }
+          }
+          nodes {
+            leaf {
+              scalar: 0.0553
+            }
+          }
+          nodes {
+            leaf {
+              scalar: 0.0783
+            }
+          }
+        }
+        tree_weights: 1.0
+        tree_metadata {
+          num_layers_grown: 3
+          is_finalized: true
+          post_pruned_nodes_meta {
+            new_node_id: 0
+            logit_change: 0.0
+          }
+          post_pruned_nodes_meta {
+            new_node_id: 1
+            logit_change: 0.0
+          }
+          post_pruned_nodes_meta {
+            new_node_id: 2
+            logit_change: 0.0
+          }
+          post_pruned_nodes_meta {
+            new_node_id: 1
+            logit_change: -0.07
+          }
+          post_pruned_nodes_meta {
+            new_node_id: 1
+            logit_change: -0.083
+          }
+          post_pruned_nodes_meta {
+            new_node_id: 3
+            logit_change: 0.0
+          }
+          post_pruned_nodes_meta {
+            new_node_id: 4
+            logit_change: 0.0
+          }
+          post_pruned_nodes_meta {
+            new_node_id: 1
+            logit_change: -0.22
+          }
+          post_pruned_nodes_meta {
+            new_node_id: 1
+            logit_change: -0.57
+          }
+        }
+        growing_metadata {
+          num_trees_attempted: 1
+          num_layers_attempted: 3
+        }
+      """, tree_ensemble_config)
+
+      # Create existing ensemble.
+      tree_ensemble = boosted_trees_ops.TreeEnsemble(
+          'ensemble', serialized_proto=tree_ensemble_config.SerializeToString())
+      tree_ensemble_handle = tree_ensemble.resource_handle
+      resources.initialize_resources(resources.shared_resources()).run()
+
+      cached_tree_ids = [0, 0, 0, 0, 0, 0]
+      # Leaves 3, 4, 7 and 8 were removed during post-pruning; leaves 5 and 6
+      # were renumbered to 3 and 4 respectively.
+      cached_node_ids = [3, 4, 5, 6, 7, 8]
+
+      # We have two features: 0 and 1.
+      feature_0_values = [12, 17, 35, 36, 23, 11]
+      feature_1_values = [12, 12, 17, 18, 123, 24]
+
+      # Grow tree ensemble.
+      predict_op = boosted_trees_ops.training_predict(
+          tree_ensemble_handle,
+          max_depth=3,
+          cached_tree_ids=cached_tree_ids,
+          cached_node_ids=cached_node_ids,
+          bucketized_features=[feature_0_values, feature_1_values],
+          logits_dimension=1)
+
+      logits_updates, new_tree_ids, new_node_ids = session.run(predict_op)
+
+      # We are still in the same tree.
+      self.assertAllClose([0, 0, 0, 0, 0, 0], new_tree_ids)
+      # Examples cached in old leaves 3, 4, 7 and 8 map to leaf 1; those from
+      # old leaves 5 and 6 map to leaves 3 and 4.
+      self.assertAllClose([1, 1, 3, 4, 1, 1], new_node_ids)
+
+      cached_values = [[0.08], [0.093], [0.0553], [0.0783], [0.15 + 0.08],
+                       [0.5 + 0.08]]
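+      # Each cached logit reconstructs as the post-pruned node's value minus
+      # its recorded logit_change, e.g. old node 3: 0.01 - (-0.07) = 0.08,
+      # and old node 7: 0.01 - (-0.22) = 0.15 + 0.08.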
+      self.assertAllClose([[0.01], [0.01], [0.0553], [0.0783], [0.01], [0.01]],
+                          logits_updates + cached_values)
+
+  def testCachedPredictionFromThePreviousTreeWithPostPrunedNodes(self):
+    """Tests that prediction based on previous node in the tree works."""
+    with self.test_session() as session:
+      tree_ensemble_config = boosted_trees_pb2.TreeEnsemble()
+      text_format.Merge("""
+        trees {
+          nodes {
+            bucketized_split {
+              feature_id:0
+              threshold: 33
+              left_id: 1
+              right_id: 2
+            }
+            metadata {
+              gain: -0.2
+            }
+          }
+          nodes {
+            leaf {
+              scalar: 0.01
+            }
+          }
+          nodes {
+            bucketized_split {
+              feature_id: 1
+              threshold: 5
+              left_id: 3
+              right_id: 4
+            }
+            metadata {
+              gain: 0.5
+              original_leaf {
+                scalar: 0.0143
+              }
+            }
+          }
+          nodes {
+            leaf {
+              scalar: 0.0553
+            }
+          }
+          nodes {
+            leaf {
+              scalar: 0.0783
+            }
+          }
+        }
+        trees {
+          nodes {
+            leaf {
+              scalar: 0.55
+            }
+          }
+        }
+        tree_weights: 1.0
+        tree_weights: 1.0
+        tree_metadata {
+          num_layers_grown: 3
+          is_finalized: true
+          post_pruned_nodes_meta {
+            new_node_id: 0
+            logit_change: 0.0
+          }
+          post_pruned_nodes_meta {
+            new_node_id: 1
+            logit_change: 0.0
+          }
+          post_pruned_nodes_meta {
+            new_node_id: 2
+            logit_change: 0.0
+          }
+          post_pruned_nodes_meta {
+            new_node_id: 1
+            logit_change: -0.07
+          }
+          post_pruned_nodes_meta {
+            new_node_id: 1
+            logit_change: -0.083
+          }
+          post_pruned_nodes_meta {
+            new_node_id: 3
+            logit_change: 0.0
+          }
+          post_pruned_nodes_meta {
+            new_node_id: 4
+            logit_change: 0.0
+          }
+          post_pruned_nodes_meta {
+            new_node_id: 1
+            logit_change: -0.22
+          }
+          post_pruned_nodes_meta {
+            new_node_id: 1
+            logit_change: -0.57
+          }
+        }
+        tree_metadata {
+          num_layers_grown: 1
+          is_finalized: false
+        }
+        growing_metadata {
+          num_trees_attempted: 2
+          num_layers_attempted: 4
+        }
+      """, tree_ensemble_config)
+
+      # Create existing ensemble.
+      tree_ensemble = boosted_trees_ops.TreeEnsemble(
+          'ensemble', serialized_proto=tree_ensemble_config.SerializeToString())
+      tree_ensemble_handle = tree_ensemble.resource_handle
+      resources.initialize_resources(resources.shared_resources()).run()
+
+      cached_tree_ids = [0, 0, 0, 0, 0, 0]
+      # Leaves 3, 4, 7 and 8 were removed during post-pruning; leaves 5 and 6
+      # were renumbered to 3 and 4 respectively.
+      cached_node_ids = [3, 4, 5, 6, 7, 8]
+
+      # We have two features: 0 and 1.
+      feature_0_values = [12, 17, 35, 36, 23, 11]
+      feature_1_values = [12, 12, 17, 18, 123, 24]
+
+      # Grow tree ensemble.
+      predict_op = boosted_trees_ops.training_predict(
+          tree_ensemble_handle,
+          max_depth=3,
+          cached_tree_ids=cached_tree_ids,
+          cached_node_ids=cached_node_ids,
+          bucketized_features=[feature_0_values, feature_1_values],
+          logits_dimension=1)
+
+      logits_updates, new_tree_ids, new_node_ids = session.run(predict_op)
+
+      # We are in the last tree.
+      self.assertAllClose([1, 1, 1, 1, 1, 1], new_tree_ids)
+      # In tree 0, examples cached in old leaves 3, 4, 7 and 8 map to leaf 1,
+      # and those from old leaves 5 and 6 map to leaves 3 and 4. In tree 1,
+      # every example sits in the root node.
+      self.assertAllClose([0, 0, 0, 0, 0, 0], new_node_ids)
+
+      cached_values = [[0.08], [0.093], [0.0553], [0.0783], [0.15 + 0.08],
+                       [0.5 + 0.08]]
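+      # Cached logits reconstruct from post_pruned_nodes_meta as in the
+      # previous test; tree 1's root leaf (0.55) is then added per example.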
+      root = 0.55
+      self.assertAllClose([[root + 0.01], [root + 0.01], [root + 0.0553],
+                           [root + 0.0783], [root + 0.01], [root + 0.01]],
+                          logits_updates + cached_values)
+
+  def testCachedPredictionTheWholeTreeWasPruned(self):
+    """Tests that prediction based on previous node in the tree works."""
+    with self.test_session() as session:
+      tree_ensemble_config = boosted_trees_pb2.TreeEnsemble()
+      text_format.Merge("""
+        trees {
+          nodes {
+            leaf {
+              scalar: 0.00
+            }
+          }
+        }
+        tree_weights: 1.0
+        tree_metadata {
+          num_layers_grown: 1
+          is_finalized: true
+          post_pruned_nodes_meta {
+            new_node_id: 0
+            logit_change: 0.0
+          }
+          post_pruned_nodes_meta {
+            new_node_id: 0
+            logit_change: -6.0
+          }
+          post_pruned_nodes_meta {
+            new_node_id: 0
+            logit_change: 5.0
+          }
+        }
+        growing_metadata {
+          num_trees_attempted: 1
+          num_layers_attempted: 1
+        }
+      """, tree_ensemble_config)
+
+      # Create existing ensemble.
+      tree_ensemble = boosted_trees_ops.TreeEnsemble(
+          'ensemble', serialized_proto=tree_ensemble_config.SerializeToString())
+      tree_ensemble_handle = tree_ensemble.resource_handle
+      resources.initialize_resources(resources.shared_resources()).run()
+
+      cached_tree_ids = [0, 0]
+      # The predictions were cached in 1 and 2, both were pruned to the root.
+      cached_node_ids = [1, 2]
+
+      # We have two features: 0 and 1. These are not going to be used anywhere.
+      feature_0_values = [12, 17]
+      feature_1_values = [12, 12]
+
+      # Grow tree ensemble.
+      predict_op = boosted_trees_ops.training_predict(
+          tree_ensemble_handle,
+          max_depth=1,
+          cached_tree_ids=cached_tree_ids,
+          cached_node_ids=cached_node_ids,
+          bucketized_features=[feature_0_values, feature_1_values],
+          logits_dimension=1)
+
+      logits_updates, new_tree_ids, new_node_ids = session.run(predict_op)
+
+      # We are in the last tree.
+      self.assertAllClose([0, 0], new_tree_ids)
+      self.assertAllClose([0, 0], new_node_ids)
+
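+      # The updates are exactly the stored logit_change values (-6.0 and 5.0),
+      # undoing the pruned leaves' cached contributions back to the root.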
+      self.assertAllClose([[-6.0], [5.0]], logits_updates)
+
+
+class PredictionOpsTest(test_util.TensorFlowTestCase):
+  """Tests prediction ops for inference."""
+
+  def testPredictionMultipleTree(self):
+    """Tests the predictions work when we have multiple trees."""
+    with self.test_session() as session:
+      tree_ensemble_config = boosted_trees_pb2.TreeEnsemble()
+      text_format.Merge("""
+        trees {
+          nodes {
+            bucketized_split {
+              feature_id: 1
+              threshold: 28
+              left_id: 1
+              right_id: 2
+            }
+            metadata {
+              gain: 7.62
+            }
+          }
+          nodes {
+            leaf {
+              scalar: 1.14
+            }
+          }
+          nodes {
+            leaf {
+              scalar: 8.79
+            }
+          }
+        }
+        trees {
+          nodes {
+            bucketized_split {
+              feature_id: 1
+              threshold: 26
+              left_id: 1
+              right_id: 2
+            }
+          }
+          nodes {
+            bucketized_split {
+              feature_id: 0
+              threshold: 50
+              left_id: 3
+              right_id: 4
+            }
+          }
+          nodes {
+            leaf {
+              scalar: 7.0
+            }
+          }
+          nodes {
+            leaf {
+              scalar: 5.0
+            }
+          }
+          nodes {
+            leaf {
+              scalar: 6.0
+            }
+          }
+        }
+        trees {
+          nodes {
+            bucketized_split {
+              feature_id: 0
+              threshold: 34
+              left_id: 1
+              right_id: 2
+            }
+          }
+          nodes {
+            leaf {
+              scalar: -7.0
+            }
+          }
+          nodes {
+            leaf {
+              scalar: 5.0
+            }
+          }
+        }
+        tree_weights: 0.1
+        tree_weights: 0.2
+        tree_weights: 1.0
+      """, tree_ensemble_config)
+
+      # Create an existing ensemble with three trees.
+      tree_ensemble = boosted_trees_ops.TreeEnsemble(
+          'ensemble', serialized_proto=tree_ensemble_config.SerializeToString())
+      tree_ensemble_handle = tree_ensemble.resource_handle
+      resources.initialize_resources(resources.shared_resources()).run()
+
+      feature_0_values = [36, 32]
+      feature_1_values = [11, 27]
+
+      # Example 1: tree 0: 1.14, tree 1: 5.0, tree 2: 5.0 =>
+      #            logit = 0.1*1.14 + 0.2*5.0 + 1.0*5.0 = 6.114
+      # Example 2: tree 0: 1.14, tree 1: 7.0, tree 2: -7.0 =>
+      #            logit = 0.1*1.14 + 0.2*7.0 - 1.0*7.0 = -5.486
+      expected_logits = [[6.114], [-5.486]]
+
+      # Do with parallelization, e.g. EVAL
+      predict_op = boosted_trees_ops.predict(
+          tree_ensemble_handle,
+          bucketized_features=[feature_0_values, feature_1_values],
+          logits_dimension=1,
+          max_depth=2)
+
+      logits = session.run(predict_op)
+      self.assertAllClose(expected_logits, logits)
+
+      # Do without parallelization, e.g. INFER; the result is the same.
+      predict_op = boosted_trees_ops.predict(
+          tree_ensemble_handle,
+          bucketized_features=[feature_0_values, feature_1_values],
+          logits_dimension=1,
+          max_depth=2)
+
+      logits = session.run(predict_op)
+      self.assertAllClose(expected_logits, logits)
+
+
+if __name__ == '__main__':
+  googletest.main()
diff --git a/tensorflow/python/kernel_tests/boosted_trees/resource_ops_test.py b/tensorflow/python/kernel_tests/boosted_trees/resource_ops_test.py
new file mode 100644 (file)
index 0000000..a223241
--- /dev/null
@@ -0,0 +1,228 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for boosted_trees resource kernels."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from google.protobuf import text_format
+
+from tensorflow.core.kernels.boosted_trees import boosted_trees_pb2
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import test_util
+from tensorflow.python.ops import boosted_trees_ops
+from tensorflow.python.ops import resources
+from tensorflow.python.platform import googletest
+
+
+class ResourceOpsTest(test_util.TensorFlowTestCase):
+  """Tests resource_ops."""
+
+  def testCreate(self):
+    with self.test_session():
+      ensemble = boosted_trees_ops.TreeEnsemble('ensemble')
+      resources.initialize_resources(resources.shared_resources()).run()
+      stamp_token = ensemble.get_stamp_token()
+      self.assertEqual(0, stamp_token.eval())
+      (_, num_trees, num_finalized_trees,
+       num_attempted_layers) = ensemble.get_states()
+      self.assertEqual(0, num_trees.eval())
+      self.assertEqual(0, num_finalized_trees.eval())
+      self.assertEqual(0, num_attempted_layers.eval())
+
+  def testCreateWithProto(self):
+    with self.test_session():
+      ensemble_proto = boosted_trees_pb2.TreeEnsemble()
+      text_format.Merge("""
+        trees {
+          nodes {
+            bucketized_split {
+              feature_id: 4
+              left_id: 1
+              right_id: 2
+            }
+            metadata {
+              gain: 7.62
+            }
+          }
+          nodes {
+            bucketized_split {
+              threshold: 21
+              left_id: 3
+              right_id: 4
+            }
+            metadata {
+              gain: 1.4
+              original_leaf {
+                scalar: 7.14
+              }
+            }
+          }
+          nodes {
+            bucketized_split {
+              feature_id: 1
+              threshold: 7
+              left_id: 5
+              right_id: 6
+            }
+            metadata {
+              gain: 2.7
+              original_leaf {
+                scalar: -4.375
+              }
+            }
+          }
+          nodes {
+            leaf {
+              scalar: 6.54
+            }
+          }
+          nodes {
+            leaf {
+              scalar: 7.305
+            }
+          }
+          nodes {
+            leaf {
+              scalar: -4.525
+            }
+          }
+          nodes {
+            leaf {
+              scalar: -4.145
+            }
+          }
+        }
+        trees {
+          nodes {
+            bucketized_split {
+              feature_id: 75
+              threshold: 21
+              left_id: 1
+              right_id: 2
+            }
+            metadata {
+              gain: -1.4
+            }
+          }
+          nodes {
+            leaf {
+              scalar: -0.6
+            }
+          }
+          nodes {
+            leaf {
+              scalar: 0.165
+            }
+          }
+        }
+        tree_weights: 0.15
+        tree_weights: 1.0
+        tree_metadata {
+          num_layers_grown: 2
+          is_finalized: true
+        }
+        tree_metadata {
+          num_layers_grown: 1
+          is_finalized: false
+        }
+        growing_metadata {
+          num_trees_attempted: 2
+          num_layers_attempted: 6
+        }
+      """, ensemble_proto)
+      ensemble = boosted_trees_ops.TreeEnsemble(
+          'ensemble',
+          stamp_token=7,
+          serialized_proto=ensemble_proto.SerializeToString())
+      resources.initialize_resources(resources.shared_resources()).run()
+      (stamp_token, num_trees, num_finalized_trees,
+       num_attempted_layers) = ensemble.get_states()
+      self.assertEqual(7, stamp_token.eval())
+      self.assertEqual(2, num_trees.eval())
+      self.assertEqual(1, num_finalized_trees.eval())
+      self.assertEqual(6, num_attempted_layers.eval())
+
+  def testSerializeDeserialize(self):
+    with self.test_session():
+      # Initialize.
+      ensemble = boosted_trees_ops.TreeEnsemble('ensemble', stamp_token=5)
+      resources.initialize_resources(resources.shared_resources()).run()
+      (stamp_token, num_trees, num_finalized_trees,
+       num_attempted_layers) = ensemble.get_states()
+      self.assertEqual(5, stamp_token.eval())
+      self.assertEqual(0, num_trees.eval())
+      self.assertEqual(0, num_finalized_trees.eval())
+      self.assertEqual(0, num_attempted_layers.eval())
+
+      # Deserialize.
+      ensemble_proto = boosted_trees_pb2.TreeEnsemble()
+      text_format.Merge("""
+        trees {
+          nodes {
+            bucketized_split {
+              feature_id: 75
+              threshold: 21
+              left_id: 1
+              right_id: 2
+            }
+            metadata {
+              gain: -1.4
+            }
+          }
+          nodes {
+            leaf {
+              scalar: -0.6
+            }
+          }
+          nodes {
+            leaf {
+              scalar: 0.165
+            }
+          }
+        }
+        tree_weights: 0.5
+        tree_metadata {
+          num_layers_grown: 4  # Intentionally fake; read back from metadata as-is.
+          is_finalized: false
+        }
+        growing_metadata {
+          num_trees_attempted: 1
+          num_layers_attempted: 5
+        }
+      """, ensemble_proto)
+      with ops.control_dependencies([
+          ensemble.deserialize(
+              stamp_token=3,
+              serialized_proto=ensemble_proto.SerializeToString())
+      ]):
+        (stamp_token, num_trees, num_finalized_trees,
+         num_attempted_layers) = ensemble.get_states()
+      self.assertEqual(3, stamp_token.eval())
+      self.assertEqual(1, num_trees.eval())
+      # This reads from metadata, not really counting the layers.
+      self.assertEqual(5, num_attempted_layers.eval())
+      self.assertEqual(0, num_finalized_trees.eval())
+
+      # Serialize.
+      new_ensemble_proto = boosted_trees_pb2.TreeEnsemble()
+      new_stamp_token, new_serialized = ensemble.serialize()
+      self.assertEqual(3, new_stamp_token.eval())
+      new_ensemble_proto.ParseFromString(new_serialized.eval())
+      self.assertProtoEquals(ensemble_proto, new_ensemble_proto)
+
+
+if __name__ == '__main__':
+  googletest.main()
diff --git a/tensorflow/python/kernel_tests/boosted_trees/stats_ops_test.py b/tensorflow/python/kernel_tests/boosted_trees/stats_ops_test.py
new file mode 100644 (file)
index 0000000..a54cc43
--- /dev/null
@@ -0,0 +1,289 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for boosted_trees stats kernels."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.framework import test_util
+from tensorflow.python.ops import boosted_trees_ops
+from tensorflow.python.platform import googletest
+
+
+class StatsOpsTest(test_util.TensorFlowTestCase):
+  """Tests stats_ops."""
+
+  def testCalculateBestGainsWithoutRegularization(self):
+    """Testing Gain calculation without any regularization."""
+    with self.test_session() as sess:
+      max_splits = 7
+      node_id_range = [1, 2]  # node 1 through 2 will be processed.
+      stats_summary_list = [
+          [
+              [[0., 0.], [.08, .09], [0., 0.], [0., 0.]],  # node 0; ignored
+              [[0., 0.], [.15, .36], [.06, .07], [.1, .2]],  # node 1
+              [[0., 0.], [-.33, .58], [0., 0.], [.3, .4]],  # node 2
+              [[0., 0.], [0., 0.], [0., 0.], [0., 0.]],  # node 3; ignored
+              [[0., 0.], [0., 0.], [0., 0.], [0., 0.]],  # node 4; ignored
+              [[0., 0.], [0., 0.], [0., 0.], [0., 0.]],  # node 5; ignored
+              [[0., 0.], [0., 0.], [0., 0.], [0., 0.]],  # node 6; ignored
+          ],  # feature 0
+          [
+              [[0., 0.], [0., 0.], [.08, .09], [0., 0.]],  # node 0; ignored
+              [[0., 0.], [.3, .5], [-.05, .06], [.06, .07]],  # node 1
+              [[.1, .1], [.2, .3], [-.4, .5], [.07, .08]],  # node 2
+              [[0., 0.], [0., 0.], [0., 0.], [0., 0.]],  # node 3; ignored
+              [[0., 0.], [0., 0.], [0., 0.], [0., 0.]],  # node 4; ignored
+              [[0., 0.], [0., 0.], [0., 0.], [0., 0.]],  # node 5; ignored
+              [[0., 0.], [0., 0.], [0., 0.], [0., 0.]],  # node 6; ignored
+          ],  # feature 1
+      ]  # num_features * shape=[max_splits, num_buckets, 2]
+
+      (node_ids_list, gains_list, thresholds_list, left_node_contribs_list,
+       right_node_contribs_list
+      ) = boosted_trees_ops.calculate_best_gains_per_feature(
+          node_id_range,
+          stats_summary_list,
+          l1=0.0,
+          l2=0.0,
+          tree_complexity=0.0,
+          max_splits=max_splits)
+
+      self.assertAllEqual([[1, 2], [1, 2]], sess.run(node_ids_list))
+      self.assertAllClose([[0.004775, 0.41184], [0.02823, 0.41184]],
+                          sess.run(gains_list))
+      self.assertAllEqual([[1, 1], [1, 1]], sess.run(thresholds_list))
+      # The left node contrib will later be added to the previous node value to
+      # produce the left node value, and likewise for the right node contrib.
+      self.assertAllClose([[[-.416667], [.568966]], [[-.6], [-.75]]],
+                          sess.run(left_node_contribs_list))
+      self.assertAllClose([[[-.592593], [-.75]], [[-.076923], [.568966]]],
+                          sess.run(right_node_contribs_list))
+
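The expected values above match the standard second-order split gain: with cumulative gradient/hessian sums (g_l, h_l) and (g_r, h_r) on each side of the threshold, each contrib is -g/h and the gain is g_l^2/h_l + g_r^2/h_r - (g_l + g_r)^2/(h_l + h_r). A minimal sketch, assuming that formula, reproducing the node-1/feature-0 entry (threshold 1, gain 0.004775, contribs -0.416667 and -0.592593):

    # Sketch only: cumulative sums for node 1 of feature 0 at threshold 1.
    g_l, h_l = 0. + .15, 0. + .36    # buckets 0..1 go left
    g_r, h_r = .06 + .1, .07 + .2    # buckets 2..3 go right
    left_contrib = -g_l / h_l        # -> -0.416667
    right_contrib = -g_r / h_r       # -> -0.592593
    gain = (g_l**2 / h_l + g_r**2 / h_r
            - (g_l + g_r)**2 / (h_l + h_r))  # -> 0.004775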
+  def testCalculateBestGainsWithL2(self):
+    """Testing Gain calculation with L2."""
+    with self.test_session() as sess:
+      max_splits = 7
+      node_id_range = [1, 2]  # node 1 through 2 will be processed.
+      stats_summary_list = [
+          [
+              [[0., 0.], [.08, .09], [0., 0.], [0., 0.]],  # node 0; ignored
+              [[0., 0.], [.15, .36], [.06, .07], [.1, .2]],  # node 1
+              [[0., 0.], [-.33, .58], [0., 0.], [.3, .4]],  # node 2
+              [[0., 0.], [0., 0.], [0., 0.], [0., 0.]],  # node 3; ignored
+              [[0., 0.], [0., 0.], [0., 0.], [0., 0.]],  # node 4; ignored
+              [[0., 0.], [0., 0.], [0., 0.], [0., 0.]],  # node 5; ignored
+              [[0., 0.], [0., 0.], [0., 0.], [0., 0.]],  # node 6; ignored
+          ],  # feature 0
+          [
+              [[0., 0.], [0., 0.], [.08, .09], [0., 0.]],  # node 0; ignored
+              [[0., 0.], [.3, .5], [-.05, .06], [.06, .07]],  # node 1
+              [[.1, .1], [.2, .3], [-.4, .5], [.07, .08]],  # node 2
+              [[0., 0.], [0., 0.], [0., 0.], [0., 0.]],  # node 3; ignored
+              [[0., 0.], [0., 0.], [0., 0.], [0., 0.]],  # node 4; ignored
+              [[0., 0.], [0., 0.], [0., 0.], [0., 0.]],  # node 5; ignored
+              [[0., 0.], [0., 0.], [0., 0.], [0., 0.]],  # node 6; ignored
+          ],  # feature 1
+      ]  # num_features * shape=[max_splits, num_buckets, 2]
+
+      (node_ids_list, gains_list, thresholds_list, left_node_contribs_list,
+       right_node_contribs_list
+      ) = boosted_trees_ops.calculate_best_gains_per_feature(
+          node_id_range,
+          stats_summary_list,
+          l1=0.0,
+          l2=0.1,
+          tree_complexity=0.0,
+          max_splits=max_splits)
+
+      self.assertAllEqual([[1, 2], [1, 2]], sess.run(node_ids_list))
+      self.assertAllClose([[0., 0.33931375], [0.01879096, 0.33931375]],
+                          sess.run(gains_list))
+      self.assertAllEqual([[0, 1], [1, 1]], sess.run(thresholds_list))
+      # The left node contrib will later be added to the previous node value to
+      # produce the left node value, and likewise for the right node contrib.
+      self.assertAllClose([[[0.], [.485294]], [[-.5], [-.6]]],
+                          sess.run(left_node_contribs_list))
+      self.assertAllClose([[[-.424658], [-.6]], [[-.043478], [.485294]]],
+                          sess.run(right_node_contribs_list))
+
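With L2 the same arithmetic holds with l2 added to every hessian sum, i.e. contrib = -g/(h + l2) and gain terms g^2/(h + l2). A sketch, under that assumption, for the node-2/feature-0 entry (threshold 1, contribs 0.485294 and -0.6, gain 0.339314):

    l2 = 0.1
    g_l, h_l = -.33, .58               # buckets 0..1 go left (bucket 0 has zero stats)
    g_r, h_r = .3, .4                  # buckets 2..3 go right (bucket 2 has zero stats)
    left_contrib = -g_l / (h_l + l2)   # ->  0.485294
    right_contrib = -g_r / (h_r + l2)  # -> -0.6
    gain = (g_l**2 / (h_l + l2) + g_r**2 / (h_r + l2)
            - (g_l + g_r)**2 / (h_l + h_r + l2))  # -> 0.339314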
+  def testCalculateBestGainsWithL1(self):
+    """Testing Gain calculation with L1."""
+    with self.test_session() as sess:
+      max_splits = 7
+      node_id_range = [1, 2]  # node 1 through 2 will be processed.
+      stats_summary_list = [
+          [
+              [[0., 0.], [.08, .09], [0., 0.], [0., 0.]],  # node 0; ignored
+              [[0., 0.], [.15, .36], [.06, .07], [.1, .2]],  # node 1
+              [[0., 0.], [-.33, .58], [0., 0.], [.3, .4]],  # node 2
+              [[0., 0.], [0., 0.], [0., 0.], [0., 0.]],  # node 3; ignored
+              [[0., 0.], [0., 0.], [0., 0.], [0., 0.]],  # node 4; ignored
+              [[0., 0.], [0., 0.], [0., 0.], [0., 0.]],  # node 5; ignored
+              [[0., 0.], [0., 0.], [0., 0.], [0., 0.]],  # node 6; ignored
+          ],  # feature 0
+          [
+              [[0., 0.], [0., 0.], [.08, .09], [0., 0.]],  # node 0; ignored
+              [[0., 0.], [.3, .5], [-.05, .06], [.06, .07]],  # node 1
+              [[.1, .1], [.2, .3], [-.4, .5], [.07, .08]],  # node 2
+              [[0., 0.], [0., 0.], [0., 0.], [0., 0.]],  # node 3; ignored
+              [[0., 0.], [0., 0.], [0., 0.], [0., 0.]],  # node 4; ignored
+              [[0., 0.], [0., 0.], [0., 0.], [0., 0.]],  # node 5; ignored
+              [[0., 0.], [0., 0.], [0., 0.], [0., 0.]],  # node 6; ignored
+          ],  # feature 1
+      ]  # num_features * shape=[max_splits, num_buckets, 2]
+
+      l1 = 0.1
+      (node_ids_list, gains_list, thresholds_list, left_node_contribs_list,
+       right_node_contribs_list
+      ) = boosted_trees_ops.calculate_best_gains_per_feature(
+          node_id_range,
+          stats_summary_list,
+          l1=l1,
+          l2=0.0,
+          tree_complexity=0.0,
+          max_splits=max_splits)
+
+      self.assertAllEqual([[0, 1], [1, 1]], sess.run(thresholds_list))
+
+      self.assertAllEqual([[1, 2], [1, 2]], sess.run(node_ids_list))
+      self.assertAllClose([[[0.0], [0.3965517]], [[-0.4], [-0.5]]],
+                          sess.run(left_node_contribs_list))
+
+      self.assertAllClose([[[-0.3333333], [-0.5]], [[0.0], [0.396552]]],
+                          sess.run(right_node_contribs_list))
+
+      # Gain should also include an adjustment of the gradient by l1.
+      self.assertAllClose([[0.0, 0.191207], [0.01, 0.191207]],
+                          sess.run(gains_list))
+
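The l1 adjustment noted above behaves like soft-thresholding: each cumulative gradient sum is shrunk toward zero by l1 before the usual contrib and gain arithmetic. A sketch, under that assumption (inferred from the expected values), for the node-2/feature-0 entry:

    l1 = 0.1

    def shrink(g):
      # Soft-threshold: shrink the gradient sum toward zero by l1.
      return max(abs(g) - l1, 0.) * (1. if g > 0 else -1.)

    g_l, h_l = shrink(-.33), .58  # -> -0.23
    g_r, h_r = shrink(.3), .4     # ->  0.2
    left_contrib = -g_l / h_l     # ->  0.396552
    right_contrib = -g_r / h_r    # -> -0.5
    gain = g_l**2 / h_l + g_r**2 / h_r - shrink(-.03)**2 / (.58 + .4)
    # -> 0.191207; the parent term vanishes since |-.33 + .3| < l1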
+  def testCalculateBestGainsWithTreeComplexity(self):
+    """Testing Gain calculation with L2."""
+    with self.test_session() as sess:
+      max_splits = 7
+      node_id_range = [1, 2]  # node 1 through 2 will be processed.
+      stats_summary_list = [
+          [
+              [[0., 0.], [.08, .09], [0., 0.], [0., 0.]],  # node 0; ignored
+              [[0., 0.], [.15, .36], [.06, .07], [.1, .2]],  # node 1
+              [[0., 0.], [-.33, .58], [0., 0.], [.3, .4]],  # node 2
+              [[0., 0.], [0., 0.], [0., 0.], [0., 0.]],  # node 3; ignored
+              [[0., 0.], [0., 0.], [0., 0.], [0., 0.]],  # node 4; ignored
+              [[0., 0.], [0., 0.], [0., 0.], [0., 0.]],  # node 5; ignored
+              [[0., 0.], [0., 0.], [0., 0.], [0., 0.]],  # node 6; ignored
+          ],  # feature 0
+          [
+              [[0., 0.], [0., 0.], [.08, .09], [0., 0.]],  # node 0; ignored
+              [[0., 0.], [.3, .5], [-.05, .06], [.06, .07]],  # node 1
+              [[.1, .1], [.2, .3], [-.4, .5], [.07, .08]],  # node 2
+              [[0., 0.], [0., 0.], [0., 0.], [0., 0.]],  # node 3; ignored
+              [[0., 0.], [0., 0.], [0., 0.], [0., 0.]],  # node 4; ignored
+              [[0., 0.], [0., 0.], [0., 0.], [0., 0.]],  # node 5; ignored
+              [[0., 0.], [0., 0.], [0., 0.], [0., 0.]],  # node 6; ignored
+          ],  # feature 1
+      ]  # num_features * shape=[max_splits, num_buckets, 2]
+
+      l2 = 0.1
+      tree_complexity = 3.
+      (node_ids_list, gains_list, thresholds_list, left_node_contribs_list,
+       right_node_contribs_list
+      ) = boosted_trees_ops.calculate_best_gains_per_feature(
+          node_id_range,
+          stats_summary_list,
+          l1=0.0,
+          l2=l2,
+          tree_complexity=tree_complexity,
+          max_splits=max_splits)
+
+      self.assertAllEqual([[1, 2], [1, 2]], sess.run(node_ids_list))
+
+      self.assertAllClose([[-3., -2.66068625], [-2.98120904, -2.66068625]],
+                          sess.run(gains_list))
+
+      self.assertAllEqual([[0, 1], [1, 1]], sess.run(thresholds_list))
+      # The left node contrib will later be added to the previous node value to
+      # produce the left node value, and likewise for the right node contrib.
+      self.assertAllClose([[[0.], [.485294]], [[-.5], [-.6]]],
+                          sess.run(left_node_contribs_list))
+      self.assertAllClose([[[-.424658], [-.6]], [[-.043478], [.485294]]],
+                          sess.run(right_node_contribs_list))
+
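Comparing with the L2 test above, every gain here is exactly the L2-only gain minus tree_complexity (0.33931375 - 3 = -2.66068625, and so on), so tree_complexity appears to act as a flat per-split penalty subtracted from the gain while leaving the contribs untouched:

    # Sketch: the L2-only gains shifted by tree_complexity = 3.
    l2_only_gains = [[0., 0.33931375], [0.01879096, 0.33931375]]
    penalized = [[g - 3. for g in row] for row in l2_only_gains]
    # -> [[-3.0, -2.66068625], [-2.98120904, -2.66068625]]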
+  def testMakeStatsSummarySimple(self):
+    """Simple test for MakeStatsSummary."""
+    with self.test_session():
+      self.assertAllClose([[[[1., 5.], [2., 6.]], [[3., 7.], [4., 8.]]]],
+                          boosted_trees_ops.make_stats_summary(
+                              node_ids=[0, 0, 1, 1],
+                              gradients=[[1.], [2.], [3.], [4.]],
+                              hessians=[[5.], [6.], [7.], [8.]],
+                              bucketized_features_list=[[0, 1, 0, 1]],
+                              max_splits=2,
+                              num_buckets=2).eval())
+
+  def testMakeStatsSummaryAccumulate(self):
+    """Tests that Summary actually accumulates."""
+    with self.test_session():
+      max_splits = 3
+      num_buckets = 4
+      node_ids = [1, 1, 2, 2, 1, 1, 2, 0]
+      gradients = [[.1], [.2], [.3], [-.4], [-.05], [.06], [.07], [.08]]
+      hessians = [[.2], [.3], [.4], [.5], [.06], [.07], [.08], [.09]]
+
+      # Tests a single feature.
+      bucketized_features = [[3, 1, 2, 0, 1, 2, 0, 1]]
+      result = boosted_trees_ops.make_stats_summary(
+          node_ids, gradients, hessians, bucketized_features, max_splits,
+          num_buckets)  # shape=[num_features, max_splits, num_buckets, 2]
+      self.assertAllClose(
+          [[
+              [[0., 0.], [.08, .09], [0., 0.], [0., 0.]],  # node 0
+              [[0., 0.], [.15, .36], [.06, .07], [.1, .2]],  # node 1
+              [[-.33, .58], [0., 0.], [.3, .4], [0., 0.]],  # node 2
+          ]],
+          result.eval())
+
+  def testMakeStatsSummaryMultipleFeatures(self):
+    """Tests that MakeStatsSummary works for multiple features."""
+    with self.test_session():
+      max_splits = 3
+      num_buckets = 4
+      node_ids = [1, 1, 2, 2, 1, 1, 2, 0]
+      gradients = [[.1], [.2], [.3], [-.4], [-.05], [.06], [.07], [.08]]
+      hessians = [[.2], [.3], [.4], [.5], [.06], [.07], [.08], [.09]]
+
+      # Tests multiple features.
+      # The output for each feature is stacked along the first dimension.
+      bucketized_features = [[3, 1, 2, 0, 1, 2, 0, 1], [0, 0, 0, 2, 2, 3, 3, 2]]
+      result = boosted_trees_ops.make_stats_summary(
+          node_ids, gradients, hessians, bucketized_features, max_splits,
+          num_buckets)  # shape=[num_features, max_splits, num_buckets, 2]
+      self.assertAllClose(
+          [
+              [
+                  [[0., 0.], [.08, .09], [0., 0.], [0., 0.]],  # node 0
+                  [[0., 0.], [.15, .36], [.06, .07], [.1, .2]],  # node 1
+                  [[-.33, .58], [0., 0.], [.3, .4], [0., 0.]],  # node 2
+              ],  # feature 0
+              [
+                  [[0., 0.], [0., 0.], [.08, .09], [0., 0.]],  # node 0
+                  [[.3, .5], [0., 0.], [-.05, .06], [.06, .07]],  # node 1
+                  [[.3, .4], [0., 0.], [-.4, .5], [.07, .08]],  # node 2
+              ],  # feature 1
+          ],
+          result.eval())
+
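As a reading aid, the accumulation these MakeStatsSummary tests assert can be mirrored in a few lines of numpy; this is a sketch of the op's semantics, not the kernel implementation. For every feature, each example's (gradient, hessian) pair is added into the cell indexed by its node id and bucket:

    import numpy as np

    def make_stats_summary_ref(node_ids, gradients, hessians,
                               bucketized_features_list, max_splits,
                               num_buckets):
      # Output shape: [num_features, max_splits, num_buckets, 2].
      num_features = len(bucketized_features_list)
      out = np.zeros((num_features, max_splits, num_buckets, 2))
      for f in range(num_features):
        for i, node in enumerate(node_ids):
          bucket = bucketized_features_list[f][i]
          out[f, node, bucket, 0] += gradients[i][0]  # gradient slot
          out[f, node, bucket, 1] += hessians[i][0]   # hessian slot
      return out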
+
+if __name__ == '__main__':
+  googletest.main()
diff --git a/tensorflow/python/kernel_tests/boosted_trees/training_ops_test.py b/tensorflow/python/kernel_tests/boosted_trees/training_ops_test.py
new file mode 100644 (file)
index 0000000..4226ff7
--- /dev/null
@@ -0,0 +1,1465 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for boosted_trees training kernels."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from google.protobuf import text_format
+from tensorflow.core.kernels.boosted_trees import boosted_trees_pb2
+from tensorflow.python.framework import test_util
+from tensorflow.python.ops import boosted_trees_ops
+from tensorflow.python.ops import resources
+from tensorflow.python.platform import googletest
+
+
+class UpdateTreeEnsembleOpTest(test_util.TensorFlowTestCase):
+  """Tests for growing tree ensemble from split candidates."""
+
+  def testGrowWithEmptyEnsemble(self):
+    """Test growing an empty ensemble."""
+    with self.test_session() as session:
+      # Create empty ensemble.
+      tree_ensemble = boosted_trees_ops.TreeEnsemble('ensemble')
+      tree_ensemble_handle = tree_ensemble.resource_handle
+      resources.initialize_resources(resources.shared_resources()).run()
+
+      feature_ids = [0, 2, 6]
+
+      # Prepare feature inputs.
+      # Note that features 1 & 3 have the same gain but different splits.
+      feature1_nodes = np.array([0], dtype=np.int32)
+      feature1_gains = np.array([7.62], dtype=np.float32)
+      feature1_thresholds = np.array([52], dtype=np.int32)
+      feature1_left_node_contribs = np.array([[-4.375]], dtype=np.float32)
+      feature1_right_node_contribs = np.array([[7.143]], dtype=np.float32)
+
+      feature2_nodes = np.array([0], dtype=np.int32)
+      feature2_gains = np.array([0.63], dtype=np.float32)
+      feature2_thresholds = np.array([23], dtype=np.int32)
+      feature2_left_node_contribs = np.array([[-0.6]], dtype=np.float32)
+      feature2_right_node_contribs = np.array([[0.24]], dtype=np.float32)
+
+      # Feature split with the highest gain.
+      feature3_nodes = np.array([0], dtype=np.int32)
+      feature3_gains = np.array([7.65], dtype=np.float32)
+      feature3_thresholds = np.array([7], dtype=np.int32)
+      feature3_left_node_contribs = np.array([[-4.89]], dtype=np.float32)
+      feature3_right_node_contribs = np.array([[5.3]], dtype=np.float32)
+
+      # Grow tree ensemble.
+      grow_op = boosted_trees_ops.update_ensemble(
+          tree_ensemble_handle,
+          learning_rate=0.1,
+          pruning_mode=boosted_trees_ops.PruningMode.NO_PRUNING,
+          # Tree will be finalized now, since we will reach depth 1.
+          max_depth=1,
+          feature_ids=feature_ids,
+          node_ids=[feature1_nodes, feature2_nodes, feature3_nodes],
+          gains=[feature1_gains, feature2_gains, feature3_gains],
+          thresholds=[
+              feature1_thresholds, feature2_thresholds, feature3_thresholds
+          ],
+          left_node_contribs=[
+              feature1_left_node_contribs, feature2_left_node_contribs,
+              feature3_left_node_contribs
+          ],
+          right_node_contribs=[
+              feature1_right_node_contribs, feature2_right_node_contribs,
+              feature3_right_node_contribs
+          ])
+      session.run(grow_op)
+
+      new_stamp, serialized = session.run(tree_ensemble.serialize())
+
+      tree_ensemble = boosted_trees_pb2.TreeEnsemble()
+      tree_ensemble.ParseFromString(serialized)
+
+      # Note that since the tree is finalized, we added a new dummy tree.
+      expected_result = """
+        trees {
+          nodes {
+            bucketized_split {
+              feature_id: 6
+              threshold: 7
+              left_id: 1
+              right_id: 2
+            }
+            metadata {
+              gain: 7.65
+            }
+          }
+          nodes {
+            leaf {
+              scalar: -0.489
+            }
+          }
+          nodes {
+            leaf {
+              scalar: 0.53
+            }
+          }
+        }
+        trees {
+          nodes {
+            leaf {
+              scalar: 0.0
+            }
+          }
+        }
+        tree_weights: 1.0
+        tree_weights: 1.0
+        tree_metadata {
+          num_layers_grown: 1
+          is_finalized: true
+        }
+        tree_metadata {
+        }
+        growing_metadata {
+          num_trees_attempted: 1
+          num_layers_attempted: 1
+        }
+      """
+      self.assertEqual(new_stamp, 1)
+      self.assertProtoEquals(expected_result, tree_ensemble)
+
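The leaf values in the expected proto are the winning feature-3 contributions scaled by the learning rate:

    # Sketch: leaf value = contribution * learning_rate.
    assert round(-4.89 * 0.1, 6) == -0.489
    assert round(5.3 * 0.1, 6) == 0.53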
+  def testGrowExistingEnsembleTreeNotFinalized(self):
+    """Test growing an existing ensemble with the last tree not finalized."""
+    with self.test_session() as session:
+      tree_ensemble_config = boosted_trees_pb2.TreeEnsemble()
+      text_format.Merge("""
+        trees {
+          nodes {
+            bucketized_split {
+              feature_id: 4
+              left_id: 1
+              right_id: 2
+            }
+            metadata {
+              gain: 7.62
+            }
+          }
+          nodes {
+            leaf {
+              scalar: 0.714
+            }
+          }
+          nodes {
+            leaf {
+              scalar: -0.4375
+            }
+          }
+        }
+        tree_weights: 1.0
+        tree_metadata {
+          num_layers_grown: 1
+          is_finalized: false
+        }
+        growing_metadata {
+          num_trees_attempted: 1
+          num_layers_attempted: 1
+        }
+      """, tree_ensemble_config)
+
+      # Create existing ensemble with one root split
+      tree_ensemble = boosted_trees_ops.TreeEnsemble(
+          'ensemble', serialized_proto=tree_ensemble_config.SerializeToString())
+      tree_ensemble_handle = tree_ensemble.resource_handle
+      resources.initialize_resources(resources.shared_resources()).run()
+
+      # Prepare feature inputs.
+      # Feature 1 only has a candidate for node 1, feature 2 has candidates
+      # for both nodes, and feature 3 only has a candidate for node 2.
+
+      feature_ids = [0, 1, 0]
+
+      feature1_nodes = np.array([1], dtype=np.int32)
+      feature1_gains = np.array([1.4], dtype=np.float32)
+      feature1_thresholds = np.array([21], dtype=np.int32)
+      feature1_left_node_contribs = np.array([[-6.0]], dtype=np.float32)
+      feature1_right_node_contribs = np.array([[1.65]], dtype=np.float32)
+
+      feature2_nodes = np.array([1, 2], dtype=np.int32)
+      feature2_gains = np.array([0.63, 2.7], dtype=np.float32)
+      feature2_thresholds = np.array([23, 7], dtype=np.int32)
+      feature2_left_node_contribs = np.array([[-0.6], [-1.5]], dtype=np.float32)
+      feature2_right_node_contribs = np.array([[0.24], [2.3]], dtype=np.float32)
+
+      feature3_nodes = np.array([2], dtype=np.int32)
+      feature3_gains = np.array([1.7], dtype=np.float32)
+      feature3_thresholds = np.array([3], dtype=np.int32)
+      feature3_left_node_contribs = np.array([[-0.75]], dtype=np.float32)
+      feature3_right_node_contribs = np.array([[1.93]], dtype=np.float32)
+
+      # Grow tree ensemble.
+      grow_op = boosted_trees_ops.update_ensemble(
+          tree_ensemble_handle,
+          learning_rate=0.1,
+          pruning_mode=boosted_trees_ops.PruningMode.NO_PRUNING,
+          # tree is going to be finalized now, since we reach depth 2.
+          max_depth=2,
+          feature_ids=feature_ids,
+          node_ids=[feature1_nodes, feature2_nodes, feature3_nodes],
+          gains=[feature1_gains, feature2_gains, feature3_gains],
+          thresholds=[
+              feature1_thresholds, feature2_thresholds, feature3_thresholds
+          ],
+          left_node_contribs=[
+              feature1_left_node_contribs, feature2_left_node_contribs,
+              feature3_left_node_contribs
+          ],
+          right_node_contribs=[
+              feature1_right_node_contribs, feature2_right_node_contribs,
+              feature3_right_node_contribs
+          ])
+      session.run(grow_op)
+
+      # Expect the split for node 1 to be chosen from feature 1 and
+      # the split for node 2 to be chosen from feature 2.
+      # The grown tree should be finalized as max tree depth is 2 and we have
+      # grown 2 layers.
+      new_stamp, serialized = session.run(tree_ensemble.serialize())
+      tree_ensemble = boosted_trees_pb2.TreeEnsemble()
+      tree_ensemble.ParseFromString(serialized)
+
+      expected_result = """
+        trees {
+          nodes {
+            bucketized_split {
+              feature_id: 4
+              left_id: 1
+              right_id: 2
+            }
+            metadata {
+              gain: 7.62
+            }
+          }
+          nodes {
+            bucketized_split {
+              threshold: 21
+              left_id: 3
+              right_id: 4
+            }
+            metadata {
+              gain: 1.4
+              original_leaf {
+                scalar: 0.714
+              }
+            }
+          }
+          nodes {
+            bucketized_split {
+              feature_id: 1
+              threshold: 7
+              left_id: 5
+              right_id: 6
+            }
+            metadata {
+              gain: 2.7
+              original_leaf {
+                scalar: -0.4375
+              }
+            }
+          }
+          nodes {
+            leaf {
+              scalar: 0.114
+            }
+          }
+          nodes {
+            leaf {
+              scalar: 0.879
+            }
+          }
+          nodes {
+            leaf {
+              scalar: -0.5875
+            }
+          }
+          nodes {
+            leaf {
+              scalar: -0.2075
+            }
+          }
+        }
+        trees {
+          nodes {
+            leaf {
+              scalar: 0.0
+            }
+          }
+        }
+        tree_weights: 1.0
+        tree_weights: 1.0
+        tree_metadata {
+          is_finalized: true
+          num_layers_grown: 2
+        }
+        tree_metadata {
+        }
+        growing_metadata {
+          num_trees_attempted: 1
+          num_layers_attempted: 2
+        }
+      """
+      self.assertEqual(new_stamp, 1)
+      self.assertProtoEquals(expected_result, tree_ensemble)
+
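Here the winning contributions are applied on top of the existing leaves, so each new leaf is original_leaf + contribution * learning_rate:

    # Sketch: children of node 1 (original leaf 0.714) and node 2 (-0.4375).
    assert round(0.714 + (-6.0) * 0.1, 6) == 0.114      # node 3
    assert round(0.714 + 1.65 * 0.1, 6) == 0.879        # node 4
    assert round(-0.4375 + (-1.5) * 0.1, 6) == -0.5875  # node 5
    assert round(-0.4375 + 2.3 * 0.1, 6) == -0.2075     # node 6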
+  def testGrowExistingEnsembleTreeFinalized(self):
+    """Test growing an existing ensemble with the last tree finalized."""
+    with self.test_session() as session:
+      tree_ensemble_config = boosted_trees_pb2.TreeEnsemble()
+      text_format.Merge("""
+        trees {
+          nodes {
+            bucketized_split {
+              feature_id: 4
+              left_id: 1
+              right_id: 2
+            }
+            metadata {
+              gain: 7.62
+            }
+          }
+          nodes {
+            leaf {
+              scalar: 7.14
+            }
+          }
+          nodes {
+            leaf {
+              scalar: -4.375
+            }
+          }
+        }
+        trees {
+          nodes {
+            leaf {
+              scalar: 0.0
+            }
+          }
+        }
+        tree_weights: 0.15
+        tree_weights: 1.0
+        tree_metadata {
+          num_layers_grown: 1
+          is_finalized: true
+        }
+        tree_metadata {
+        }
+        growing_metadata {
+          num_trees_attempted: 1
+          num_layers_attempted: 1
+        }
+      """, tree_ensemble_config)
+
+      # Create existing ensemble with one root split
+      tree_ensemble = boosted_trees_ops.TreeEnsemble(
+          'ensemble', serialized_proto=tree_ensemble_config.SerializeToString())
+      tree_ensemble_handle = tree_ensemble.resource_handle
+      resources.initialize_resources(resources.shared_resources()).run()
+
+      # Prepare feature inputs.
+
+      feature_ids = [75]
+
+      feature1_nodes = np.array([0], dtype=np.int32)
+      feature1_gains = np.array([-1.4], dtype=np.float32)
+      feature1_thresholds = np.array([21], dtype=np.int32)
+      feature1_left_node_contribs = np.array([[-6.0]], dtype=np.float32)
+      feature1_right_node_contribs = np.array([[1.65]], dtype=np.float32)
+
+      # Grow tree ensemble.
+      grow_op = boosted_trees_ops.update_ensemble(
+          tree_ensemble_handle,
+          pruning_mode=boosted_trees_ops.PruningMode.NO_PRUNING,
+          learning_rate=0.1,
+          max_depth=2,
+          feature_ids=feature_ids,
+          node_ids=[feature1_nodes],
+          gains=[feature1_gains],
+          thresholds=[feature1_thresholds],
+          left_node_contribs=[feature1_left_node_contribs],
+          right_node_contribs=[feature1_right_node_contribs])
+      session.run(grow_op)
+
+      # Expect a new tree added, with a split on feature 75
+      new_stamp, serialized = session.run(tree_ensemble.serialize())
+      tree_ensemble = boosted_trees_pb2.TreeEnsemble()
+      tree_ensemble.ParseFromString(serialized)
+
+      expected_result = """
+        trees {
+          nodes {
+            bucketized_split {
+              feature_id: 4
+              left_id: 1
+              right_id: 2
+            }
+            metadata {
+              gain: 7.62
+            }
+          }
+          nodes {
+            leaf {
+              scalar: 7.14
+            }
+          }
+          nodes {
+            leaf {
+              scalar: -4.375
+            }
+          }
+        }
+        trees {
+          nodes {
+            bucketized_split {
+              feature_id: 75
+              threshold: 21
+              left_id: 1
+              right_id: 2
+            }
+            metadata {
+              gain: -1.4
+            }
+          }
+          nodes {
+            leaf {
+              scalar: -0.6
+            }
+          }
+          nodes {
+            leaf {
+              scalar: 0.165
+            }
+          }
+        }
+        tree_weights: 0.15
+        tree_weights: 1.0
+        tree_metadata {
+          num_layers_grown: 1
+          is_finalized: true
+        }
+        tree_metadata {
+          num_layers_grown: 1
+          is_finalized: false
+        }
+        growing_metadata {
+          num_trees_attempted: 2
+          num_layers_attempted: 2
+        }
+      """
+      self.assertEqual(new_stamp, 1)
+      self.assertProtoEquals(expected_result, tree_ensemble)
+
+  def testPrePruning(self):
+    """Test growing an existing ensemble with pre-pruning."""
+    with self.test_session() as session:
+      tree_ensemble_config = boosted_trees_pb2.TreeEnsemble()
+      text_format.Merge("""
+        trees {
+          nodes {
+            bucketized_split {
+              feature_id: 4
+              left_id: 1
+              right_id: 2
+            }
+            metadata {
+              gain: 7.62
+            }
+          }
+          nodes {
+            leaf {
+              scalar: 7.14
+            }
+          }
+          nodes {
+            leaf {
+              scalar: -4.375
+            }
+          }
+        }
+        tree_weights: 0.1
+        tree_metadata {
+          num_layers_grown: 1
+          is_finalized: false
+        }
+        growing_metadata {
+          num_trees_attempted: 1
+          num_layers_attempted: 1
+        }
+      """, tree_ensemble_config)
+
+      # Create existing ensemble with one root split
+      tree_ensemble = boosted_trees_ops.TreeEnsemble(
+          'ensemble', serialized_proto=tree_ensemble_config.SerializeToString())
+      tree_ensemble_handle = tree_ensemble.resource_handle
+      resources.initialize_resources(resources.shared_resources()).run()
+
+      # Prepare feature inputs.
+      # For node 1, the best split is on feature 2 (gain -0.63), but the gain
+      # is negative so node 1 will not be split.
+      # For node 2, the best split is on feature 3, gain is positive.
+
+      feature_ids = [0, 1, 0]
+
+      feature1_nodes = np.array([1], dtype=np.int32)
+      feature1_gains = np.array([-1.4], dtype=np.float32)
+      feature1_thresholds = np.array([21], dtype=np.int32)
+      feature1_left_node_contribs = np.array([[-6.0]], dtype=np.float32)
+      feature1_right_node_contribs = np.array([[1.65]], dtype=np.float32)
+
+      feature2_nodes = np.array([1, 2], dtype=np.int32)
+      feature2_gains = np.array([-0.63, 2.7], dtype=np.float32)
+      feature2_thresholds = np.array([23, 7], dtype=np.int32)
+      feature2_left_node_contribs = np.array([[-0.6], [-1.5]], dtype=np.float32)
+      feature2_right_node_contribs = np.array([[0.24], [2.3]], dtype=np.float32)
+
+      feature3_nodes = np.array([2], dtype=np.int32)
+      feature3_gains = np.array([2.8], dtype=np.float32)
+      feature3_thresholds = np.array([3], dtype=np.int32)
+      feature3_left_node_contribs = np.array([[-0.75]], dtype=np.float32)
+      feature3_right_node_contribs = np.array([[1.93]], dtype=np.float32)
+
+      # Grow tree ensemble.
+      grow_op = boosted_trees_ops.update_ensemble(
+          tree_ensemble_handle,
+          learning_rate=0.1,
+          pruning_mode=boosted_trees_ops.PruningMode.PRE_PRUNING,
+          max_depth=3,
+          feature_ids=feature_ids,
+          node_ids=[feature1_nodes, feature2_nodes, feature3_nodes],
+          gains=[feature1_gains, feature2_gains, feature3_gains],
+          thresholds=[
+              feature1_thresholds, feature2_thresholds, feature3_thresholds
+          ],
+          left_node_contribs=[
+              feature1_left_node_contribs, feature2_left_node_contribs,
+              feature3_left_node_contribs
+          ],
+          right_node_contribs=[
+              feature1_right_node_contribs, feature2_right_node_contribs,
+              feature3_right_node_contribs
+          ])
+      session.run(grow_op)
+
+      # Expect no split for node 1, since all of its candidate gains are
+      # negative, and the split for node 2 to be chosen from feature 3.
+      # The grown tree should not be finalized, as max tree depth is 3 and
+      # it has only grown 2 layers.
+      new_stamp, serialized = session.run(tree_ensemble.serialize())
+      tree_ensemble = boosted_trees_pb2.TreeEnsemble()
+      tree_ensemble.ParseFromString(serialized)
+
+      expected_result = """
+        trees {
+          nodes {
+            bucketized_split {
+              feature_id: 4
+              left_id: 1
+              right_id: 2
+            }
+            metadata {
+              gain: 7.62
+            }
+          }
+          nodes {
+            leaf {
+              scalar: 7.14
+            }
+          }
+          nodes {
+            bucketized_split {
+              feature_id: 0
+              threshold: 3
+              left_id: 3
+              right_id: 4
+            }
+            metadata {
+              gain: 2.8
+              original_leaf {
+                scalar: -4.375
+              }
+            }
+          }
+          nodes {
+            leaf {
+              scalar: -4.45
+            }
+          }
+          nodes {
+            leaf {
+              scalar: -4.182
+            }
+          }
+        }
+        tree_weights: 0.1
+        tree_metadata {
+          is_finalized: false
+          num_layers_grown: 2
+        }
+        growing_metadata {
+          num_trees_attempted: 1
+          num_layers_attempted: 2
+        }
+      """
+      self.assertEqual(new_stamp, 1)
+      self.assertProtoEquals(expected_result, tree_ensemble)
+
+  def testMetadataWhenCantSplitDueToEmptySplits(self):
+    """Test that the metadata is updated even though we can't split."""
+    with self.test_session() as session:
+      tree_ensemble_config = boosted_trees_pb2.TreeEnsemble()
+      text_format.Merge("""
+        trees {
+          nodes {
+            bucketized_split {
+              feature_id: 4
+              left_id: 1
+              right_id: 2
+            }
+            metadata {
+              gain: 7.62
+            }
+          }
+          nodes {
+            leaf {
+              scalar: 0.714
+            }
+          }
+          nodes {
+            leaf {
+              scalar: -0.4375
+            }
+          }
+        }
+        tree_weights: 1.0
+        tree_metadata {
+          num_layers_grown: 1
+          is_finalized: false
+        }
+        growing_metadata {
+          num_trees_attempted: 1
+          num_layers_attempted: 1
+        }
+      """, tree_ensemble_config)
+
+      # Create existing ensemble with one root split
+      tree_ensemble = boosted_trees_ops.TreeEnsemble(
+          'ensemble', serialized_proto=tree_ensemble_config.SerializeToString())
+      tree_ensemble_handle = tree_ensemble.resource_handle
+      resources.initialize_resources(resources.shared_resources()).run()
+
+      # No feature inputs are prepared, as there are no split candidates.
+
+      # Grow tree ensemble.
+      grow_op = boosted_trees_ops.update_ensemble(
+          tree_ensemble_handle,
+          learning_rate=0.1,
+          pruning_mode=boosted_trees_ops.PruningMode.NO_PRUNING,
+          max_depth=2,
+          # No splits are available.
+          feature_ids=[],
+          node_ids=[],
+          gains=[],
+          thresholds=[],
+          left_node_contribs=[],
+          right_node_contribs=[])
+      session.run(grow_op)
+
+      # Expect no new splits created, but attempted (global) stats updated.
+      # Metadata for this tree should not be updated, since we did not succeed
+      # in building a layer.
+      new_stamp, serialized = session.run(tree_ensemble.serialize())
+      tree_ensemble = boosted_trees_pb2.TreeEnsemble()
+      tree_ensemble.ParseFromString(serialized)
+
+      expected_result = """
+        trees {
+          nodes {
+            bucketized_split {
+              feature_id: 4
+              left_id: 1
+              right_id: 2
+            }
+            metadata {
+              gain: 7.62
+            }
+          }
+          nodes {
+            leaf {
+              scalar: 0.714
+            }
+          }
+          nodes {
+            leaf {
+              scalar: -0.4375
+            }
+          }
+        }
+        tree_weights: 1.0
+        tree_metadata {
+          num_layers_grown: 1
+          is_finalized: false
+        }
+        growing_metadata {
+          num_trees_attempted: 1
+          num_layers_attempted: 2
+        }
+      """
+      self.assertEqual(new_stamp, 1)
+      self.assertProtoEquals(expected_result, tree_ensemble)
+
+  def testMetadataWhenCantSplitDuePrePruning(self):
+    """Test metadata is updated correctly when no split due to prepruning."""
+    with self.test_session() as session:
+      tree_ensemble_config = boosted_trees_pb2.TreeEnsemble()
+      text_format.Merge("""
+        trees {
+          nodes {
+            bucketized_split {
+              feature_id: 4
+              left_id: 1
+              right_id: 2
+            }
+            metadata {
+              gain: 7.62
+            }
+          }
+          nodes {
+            leaf {
+              scalar: 7.14
+            }
+          }
+          nodes {
+            leaf {
+              scalar: -4.375
+            }
+          }
+        }
+        tree_weights: 1.0
+        tree_metadata {
+          num_layers_grown: 1
+          is_finalized: false
+        }
+        growing_metadata {
+          num_trees_attempted: 1
+          num_layers_attempted: 1
+        }
+      """, tree_ensemble_config)
+
+      # Create existing ensemble with one root split
+      tree_ensemble = boosted_trees_ops.TreeEnsemble(
+          'ensemble', serialized_proto=tree_ensemble_config.SerializeToString())
+      tree_ensemble_handle = tree_ensemble.resource_handle
+      resources.initialize_resources(resources.shared_resources()).run()
+
+      # Prepare feature inputs.
+      feature_ids = [0, 1, 0]
+
+      # All the gains are negative.
+      feature1_nodes = np.array([1], dtype=np.int32)
+      feature1_gains = np.array([-1.4], dtype=np.float32)
+      feature1_thresholds = np.array([21], dtype=np.int32)
+      feature1_left_node_contribs = np.array([[-6.0]], dtype=np.float32)
+      feature1_right_node_contribs = np.array([[1.65]], dtype=np.float32)
+
+      feature2_nodes = np.array([1, 2], dtype=np.int32)
+      feature2_gains = np.array([-0.63, -2.7], dtype=np.float32)
+      feature2_thresholds = np.array([23, 7], dtype=np.int32)
+      feature2_left_node_contribs = np.array([[-0.6], [-1.5]], dtype=np.float32)
+      feature2_right_node_contribs = np.array([[0.24], [2.3]], dtype=np.float32)
+
+      feature3_nodes = np.array([2], dtype=np.int32)
+      feature3_gains = np.array([-2.8], dtype=np.float32)
+      feature3_thresholds = np.array([3], dtype=np.int32)
+      feature3_left_node_contribs = np.array([[-0.75]], dtype=np.float32)
+      feature3_right_node_contribs = np.array([[1.93]], dtype=np.float32)
+
+      # Grow tree ensemble.
+      grow_op = boosted_trees_ops.update_ensemble(
+          tree_ensemble_handle,
+          learning_rate=0.1,
+          pruning_mode=boosted_trees_ops.PruningMode.PRE_PRUNING,
+          max_depth=3,
+          feature_ids=feature_ids,
+          node_ids=[feature1_nodes, feature2_nodes, feature3_nodes],
+          gains=[feature1_gains, feature2_gains, feature3_gains],
+          thresholds=[
+              feature1_thresholds, feature2_thresholds, feature3_thresholds
+          ],
+          left_node_contribs=[
+              feature1_left_node_contribs, feature2_left_node_contribs,
+              feature3_left_node_contribs
+          ],
+          right_node_contribs=[
+              feature1_right_node_contribs, feature2_right_node_contribs,
+              feature3_right_node_contribs
+          ])
+      session.run(grow_op)
+
+      # Expect that no new split was created because all the gains were
+      # negative. Global metadata should be updated; tree metadata should not.
+      new_stamp, serialized = session.run(tree_ensemble.serialize())
+      tree_ensemble = boosted_trees_pb2.TreeEnsemble()
+      tree_ensemble.ParseFromString(serialized)
+
+      expected_result = """
+        trees {
+          nodes {
+            bucketized_split {
+              feature_id: 4
+              left_id: 1
+              right_id: 2
+            }
+            metadata {
+              gain: 7.62
+            }
+          }
+          nodes {
+            leaf {
+              scalar: 7.14
+            }
+          }
+          nodes {
+            leaf {
+              scalar: -4.375
+            }
+          }
+        }
+        tree_weights: 1.0
+        tree_metadata {
+          num_layers_grown: 1
+          is_finalized: false
+        }
+        growing_metadata {
+          num_trees_attempted: 1
+          num_layers_attempted: 2
+        }
+      """
+      self.assertEqual(new_stamp, 1)
+      self.assertProtoEquals(expected_result, tree_ensemble)
+
+  def testPostPruningOfSomeNodes(self):
+    """Test growing an ensemble with post-pruning."""
+    with self.test_session() as session:
+      # Create empty ensemble.
+      tree_ensemble_config = boosted_trees_pb2.TreeEnsemble()
+      tree_ensemble = boosted_trees_ops.TreeEnsemble(
+          'ensemble', serialized_proto=tree_ensemble_config.SerializeToString())
+      tree_ensemble_handle = tree_ensemble.resource_handle
+
+      resources.initialize_resources(resources.shared_resources()).run()
+
+      # Prepare inputs.
+      # The second feature has a larger (but still negative) gain.
+      feature_ids = [0, 1]
+
+      feature1_nodes = np.array([0], dtype=np.int32)
+      feature1_gains = np.array([-1.3], dtype=np.float32)
+      feature1_thresholds = np.array([7], dtype=np.int32)
+      feature1_left_node_contribs = np.array([[0.013]], dtype=np.float32)
+      feature1_right_node_contribs = np.array([[0.0143]], dtype=np.float32)
+
+      feature2_nodes = np.array([0], dtype=np.int32)
+      feature2_gains = np.array([-0.2], dtype=np.float32)
+      feature2_thresholds = np.array([33], dtype=np.int32)
+      feature2_left_node_contribs = np.array([[0.01]], dtype=np.float32)
+      feature2_right_node_contribs = np.array([[0.0143]], dtype=np.float32)
+
+      # Grow tree ensemble.
+      grow_op = boosted_trees_ops.update_ensemble(
+          tree_ensemble_handle,
+          learning_rate=1.0,
+          pruning_mode=boosted_trees_ops.PruningMode.POST_PRUNING,
+          max_depth=3,
+          feature_ids=feature_ids,
+          node_ids=[feature1_nodes, feature2_nodes],
+          gains=[feature1_gains, feature2_gains],
+          thresholds=[feature1_thresholds, feature2_thresholds],
+          left_node_contribs=[
+              feature1_left_node_contribs, feature2_left_node_contribs
+          ],
+          right_node_contribs=[
+              feature1_right_node_contribs, feature2_right_node_contribs
+          ])
+
+      session.run(grow_op)
+
+      # Expect the split from second features to be chosen despite the negative
+      # gain.
+      # No pruning happened just yet.
+      new_stamp, serialized = session.run(tree_ensemble.serialize())
+      res_ensemble = boosted_trees_pb2.TreeEnsemble()
+      res_ensemble.ParseFromString(serialized)
+
+      expected_result = """
+        trees {
+          nodes {
+            bucketized_split {
+              feature_id: 1
+              threshold: 33
+              left_id: 1
+              right_id: 2
+            }
+            metadata {
+              gain: -0.2
+            }
+          }
+          nodes {
+            leaf {
+              scalar: 0.01
+            }
+          }
+          nodes {
+            leaf {
+              scalar: 0.0143
+            }
+          }
+        }
+        tree_weights: 1.0
+        tree_metadata {
+          num_layers_grown: 1
+          is_finalized: false
+        }
+        growing_metadata {
+          num_trees_attempted: 1
+          num_layers_attempted: 1
+        }
+      """
+      self.assertEqual(new_stamp, 1)
+      self.assertProtoEquals(expected_result, res_ensemble)
+
+      # Prepare the second layer.
+      # Note that node 1 gain is negative and node 2 gain is positive.
+      feature_ids = [3]
+      feature1_nodes = np.array([1, 2], dtype=np.int32)
+      feature1_gains = np.array([-0.2, 0.5], dtype=np.float32)
+      feature1_thresholds = np.array([7, 5], dtype=np.int32)
+      feature1_left_node_contribs = np.array(
+          [[0.07], [0.041]], dtype=np.float32)
+      feature1_right_node_contribs = np.array(
+          [[0.083], [0.064]], dtype=np.float32)
+
+      # Grow tree ensemble.
+      grow_op = boosted_trees_ops.update_ensemble(
+          tree_ensemble_handle,
+          learning_rate=1.0,
+          pruning_mode=boosted_trees_ops.PruningMode.POST_PRUNING,
+          max_depth=3,
+          feature_ids=feature_ids,
+          node_ids=[feature1_nodes],
+          gains=[feature1_gains],
+          thresholds=[feature1_thresholds],
+          left_node_contribs=[feature1_left_node_contribs],
+          right_node_contribs=[feature1_right_node_contribs])
+
+      session.run(grow_op)
+
+      # After adding this layer, the tree will not be finalized
+      new_stamp, serialized = session.run(tree_ensemble.serialize())
+      res_ensemble = boosted_trees_pb2.TreeEnsemble()
+      res_ensemble.ParseFromString(serialized)
+      expected_result = """
+        trees {
+          nodes {
+            bucketized_split {
+              feature_id:1
+              threshold: 33
+              left_id: 1
+              right_id: 2
+            }
+            metadata {
+              gain: -0.2
+            }
+          }
+          nodes {
+            bucketized_split {
+              feature_id: 3
+              threshold: 7
+              left_id: 3
+              right_id: 4
+            }
+            metadata {
+              gain: -0.2
+              original_leaf {
+                scalar: 0.01
+              }
+            }
+          }
+          nodes {
+            bucketized_split {
+              feature_id: 3
+              threshold: 5
+              left_id: 5
+              right_id: 6
+            }
+            metadata {
+              gain: 0.5
+              original_leaf {
+                scalar: 0.0143
+              }
+            }
+          }
+          nodes {
+            leaf {
+              scalar: 0.08
+            }
+          }
+          nodes {
+            leaf {
+              scalar: 0.093
+            }
+          }
+          nodes {
+            leaf {
+              scalar: 0.0553
+            }
+          }
+          nodes {
+            leaf {
+              scalar: 0.0783
+            }
+          }
+        }
+        tree_weights: 1.0
+        tree_metadata {
+          num_layers_grown: 2
+          is_finalized: false
+        }
+        growing_metadata {
+          num_trees_attempted: 1
+          num_layers_attempted: 2
+        }
+      """
+      self.assertEqual(new_stamp, 2)
+
+      self.assertProtoEquals(expected_result, res_ensemble)
+      # Now split leaf 3, again with a negative gain. After this layer, the
+      # tree will be finalized and post-pruning happens: nodes 3, 4, 7 and 8
+      # will be pruned out.
+
+      # Prepare the third layer.
+      feature_ids = [92]
+      feature1_nodes = np.array([3], dtype=np.int32)
+      feature1_gains = np.array([-0.45], dtype=np.float32)
+      feature1_thresholds = np.array([11], dtype=np.int32)
+      feature1_left_node_contribs = np.array([[0.15]], dtype=np.float32)
+      feature1_right_node_contribs = np.array([[0.5]], dtype=np.float32)
+
+      # Grow tree ensemble.
+      grow_op = boosted_trees_ops.update_ensemble(
+          tree_ensemble_handle,
+          learning_rate=1.0,
+          pruning_mode=boosted_trees_ops.PruningMode.POST_PRUNING,
+          max_depth=3,
+          feature_ids=feature_ids,
+          node_ids=[feature1_nodes],
+          gains=[feature1_gains],
+          thresholds=[feature1_thresholds],
+          left_node_contribs=[feature1_left_node_contribs],
+          right_node_contribs=[feature1_right_node_contribs])
+
+      session.run(grow_op)
+      # After adding this layer, the tree will be finalized
+      new_stamp, serialized = session.run(tree_ensemble.serialize())
+      res_ensemble = boosted_trees_pb2.TreeEnsemble()
+      res_ensemble.ParseFromString(serialized)
+      # Note that nodes 3, 4, 7 and 8 got deleted, so the metadata stores their
+      # ids mapped to their parent node 1, with the respective change in logits.
+      expected_result = """
+        trees {
+          nodes {
+            bucketized_split {
+              feature_id:1
+              threshold: 33
+              left_id: 1
+              right_id: 2
+            }
+            metadata {
+              gain: -0.2
+            }
+          }
+          nodes {
+            leaf {
+              scalar: 0.01
+            }
+          }
+          nodes {
+            bucketized_split {
+              feature_id: 3
+              threshold: 5
+              left_id: 3
+              right_id: 4
+            }
+            metadata {
+              gain: 0.5
+              original_leaf {
+                scalar: 0.0143
+              }
+            }
+          }
+          nodes {
+            leaf {
+              scalar: 0.0553
+            }
+          }
+          nodes {
+            leaf {
+              scalar: 0.0783
+            }
+          }
+        }
+        trees {
+          nodes {
+            leaf {
+            }
+          }
+        }
+        tree_weights: 1.0
+        tree_weights: 1.0
+        tree_metadata {
+          num_layers_grown: 3
+          is_finalized: true
+          post_pruned_nodes_meta {
+            new_node_id: 0
+            logit_change: 0.0
+          }
+          post_pruned_nodes_meta {
+            new_node_id: 1
+            logit_change: 0.0
+          }
+          post_pruned_nodes_meta {
+            new_node_id: 2
+            logit_change: 0.0
+          }
+          post_pruned_nodes_meta {
+            new_node_id: 1
+            logit_change: -0.07
+          }
+          post_pruned_nodes_meta {
+            new_node_id: 1
+            logit_change: -0.083
+          }
+          post_pruned_nodes_meta {
+            new_node_id: 3
+            logit_change: 0.0
+          }
+          post_pruned_nodes_meta {
+            new_node_id: 4
+            logit_change: 0.0
+          }
+          post_pruned_nodes_meta {
+            new_node_id: 1
+            logit_change: -0.22
+          }
+          post_pruned_nodes_meta {
+            new_node_id: 1
+            logit_change: -0.57
+          }
+        }
+        tree_metadata {
+        }
+        growing_metadata {
+          num_trees_attempted: 1
+          num_layers_attempted: 3
+        }
+      """
+      self.assertEqual(new_stamp, 3)
+      self.assertProtoEquals(expected_result, res_ensemble)
+
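Each post_pruned_nodes_meta entry can be read as the remap for a deleted node plus the correction to apply to cached predictions: logit_change = -(pruned node's value - surviving node's value). A sketch using node 1's surviving value of 0.01:

    # Sketch: values the pruned nodes held before pruning (layers 2 and 3).
    pruned_values = {3: 0.08, 4: 0.093, 7: 0.23, 8: 0.58}
    logit_changes = {n: round(0.01 - v, 6) for n, v in pruned_values.items()}
    # -> {3: -0.07, 4: -0.083, 7: -0.22, 8: -0.57}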
+  def testPostPruningOfAllNodes(self):
+    """Test growing an ensemble with post-pruning, with all nodes are pruned."""
+    with self.test_session() as session:
+      # Create empty ensemble.
+      tree_ensemble_config = boosted_trees_pb2.TreeEnsemble()
+      tree_ensemble = boosted_trees_ops.TreeEnsemble(
+          'ensemble', serialized_proto=tree_ensemble_config.SerializeToString())
+      tree_ensemble_handle = tree_ensemble.resource_handle
+
+      resources.initialize_resources(resources.shared_resources()).run()
+
+      # Prepare inputs. All have negative gains.
+      feature_ids = [0, 1]
+
+      feature1_nodes = np.array([0], dtype=np.int32)
+      feature1_gains = np.array([-1.3], dtype=np.float32)
+      feature1_thresholds = np.array([7], dtype=np.int32)
+      feature1_left_node_contribs = np.array([[0.013]], dtype=np.float32)
+      feature1_right_node_contribs = np.array([[0.0143]], dtype=np.float32)
+
+      feature2_nodes = np.array([0], dtype=np.int32)
+      feature2_gains = np.array([-0.62], dtype=np.float32)
+      feature2_thresholds = np.array([33], dtype=np.int32)
+      feature2_left_node_contribs = np.array([[0.01]], dtype=np.float32)
+      feature2_right_node_contribs = np.array([[0.0143]], dtype=np.float32)
+
+      # Grow tree ensemble.
+      grow_op = boosted_trees_ops.update_ensemble(
+          tree_ensemble_handle,
+          learning_rate=1.0,
+          pruning_mode=boosted_trees_ops.PruningMode.POST_PRUNING,
+          max_depth=2,
+          feature_ids=feature_ids,
+          node_ids=[feature1_nodes, feature2_nodes],
+          gains=[feature1_gains, feature2_gains],
+          thresholds=[feature1_thresholds, feature2_thresholds],
+          left_node_contribs=[
+              feature1_left_node_contribs, feature2_left_node_contribs
+          ],
+          right_node_contribs=[
+              feature1_right_node_contribs, feature2_right_node_contribs
+          ])
+
+      session.run(grow_op)
+
+      # Expect the split from feature 2 to be chosen despite the negative
+      # gain, since its gain (-0.62) is larger than feature 1's (-1.3). The
+      # tree is not finalized because max depth is 2, so no pruning occurs yet.
+      new_stamp, serialized = session.run(tree_ensemble.serialize())
+      res_ensemble = boosted_trees_pb2.TreeEnsemble()
+      res_ensemble.ParseFromString(serialized)
+
+      expected_result = """
+        trees {
+          nodes {
+            bucketized_split {
+              feature_id: 1
+              threshold: 33
+              left_id: 1
+              right_id: 2
+            }
+            metadata {
+              gain: -0.62
+            }
+          }
+          nodes {
+            leaf {
+              scalar: 0.01
+            }
+          }
+          nodes {
+            leaf {
+              scalar: 0.0143
+            }
+          }
+        }
+        tree_weights: 1.0
+        tree_metadata {
+          num_layers_grown: 1
+        }
+        growing_metadata {
+          num_trees_attempted: 1
+          num_layers_attempted: 1
+        }
+      """
+      self.assertEqual(new_stamp, 1)
+      self.assertProtoEquals(expected_result, res_ensemble)
+
+      # Prepare inputs.
+      # All have negative gain.
+      feature_ids = [3]
+      feature1_nodes = np.array([1, 2], dtype=np.int32)
+      feature1_gains = np.array([-0.2, -0.5], dtype=np.float32)
+      feature1_thresholds = np.array([77, 79], dtype=np.int32)
+      feature1_left_node_contribs = np.array([[0.023], [0.3]], dtype=np.float32)
+      feature1_right_node_contribs = np.array(
+          [[0.012343], [24]], dtype=np.float32)
+
+      grow_op = boosted_trees_ops.update_ensemble(
+          tree_ensemble_handle,
+          learning_rate=1.0,
+          pruning_mode=boosted_trees_ops.PruningMode.POST_PRUNING,
+          max_depth=2,
+          feature_ids=feature_ids,
+          node_ids=[feature1_nodes],
+          gains=[feature1_gains],
+          thresholds=[feature1_thresholds],
+          left_node_contribs=[feature1_left_node_contribs],
+          right_node_contribs=[feature1_right_node_contribs])
+
+      session.run(grow_op)
+
+      # Expect the split from feature 1 to be chosen despite the negative gain.
+      # The grown tree should be finalized. Since all nodes have negative gain,
+      # the whole tree is pruned.
+      new_stamp, serialized = session.run(tree_ensemble.serialize())
+      res_ensemble = boosted_trees_pb2.TreeEnsemble()
+      res_ensemble.ParseFromString(serialized)
+
+      # Expect the finalized tree to be pruned down to a single empty leaf, as
+      # every node had negative gain, and a new empty tree to be started.
+      self.assertEqual(new_stamp, 2)
+      self.assertProtoEquals("""
+      trees {
+        nodes {
+          leaf {
+          }
+        }
+      }
+      trees {
+        nodes {
+          leaf {
+          }
+        }
+      }
+      tree_weights: 1.0
+      tree_weights: 1.0
+      tree_metadata {
+        num_layers_grown: 2
+        is_finalized: true
+        post_pruned_nodes_meta {
+          new_node_id: 0
+          logit_change: 0.0
+        }
+        post_pruned_nodes_meta {
+          new_node_id: 0
+          logit_change: -0.01
+        }
+        post_pruned_nodes_meta {
+          new_node_id: 0
+          logit_change: -0.0143
+        }
+        post_pruned_nodes_meta {
+          new_node_id: 0
+          logit_change: -0.033
+        }
+        post_pruned_nodes_meta {
+          new_node_id: 0
+          logit_change: -0.022343
+        }
+        post_pruned_nodes_meta {
+          new_node_id: 0
+          logit_change: -0.3143
+        }
+        post_pruned_nodes_meta {
+          new_node_id: 0
+          logit_change: -24.0143
+        }
+      }
+      tree_metadata {
+      }
+      growing_metadata {
+        num_trees_attempted: 1
+        num_layers_attempted: 2
+      }
+      """, res_ensemble)
+
+  def testPostPruningChangesNothing(self):
+    """Test growing an ensemble with post-pruning with all gains >0."""
+    with self.test_session() as session:
+      # Create empty ensemble.
+      tree_ensemble_config = boosted_trees_pb2.TreeEnsemble()
+      tree_ensemble = boosted_trees_ops.TreeEnsemble(
+          'ensemble', serialized_proto=tree_ensemble_config.SerializeToString())
+      tree_ensemble_handle = tree_ensemble.resource_handle
+
+      resources.initialize_resources(resources.shared_resources()).run()
+
+      # Prepare inputs.
+      # Both gains are positive; the first feature has the larger gain.
+      feature_ids = [3, 4]
+
+      feature1_nodes = np.array([0], dtype=np.int32)
+      feature1_gains = np.array([7.62], dtype=np.float32)
+      feature1_thresholds = np.array([52], dtype=np.int32)
+      feature1_left_node_contribs = np.array([[-4.375]], dtype=np.float32)
+      feature1_right_node_contribs = np.array([[7.143]], dtype=np.float32)
+
+      feature2_nodes = np.array([0], dtype=np.int32)
+      feature2_gains = np.array([0.63], dtype=np.float32)
+      feature2_thresholds = np.array([23], dtype=np.int32)
+      feature2_left_node_contribs = np.array([[-0.6]], dtype=np.float32)
+      feature2_right_node_contribs = np.array([[0.24]], dtype=np.float32)
+
+      # Grow tree ensemble.
+      grow_op = boosted_trees_ops.update_ensemble(
+          tree_ensemble_handle,
+          learning_rate=1.0,
+          pruning_mode=boosted_trees_ops.PruningMode.POST_PRUNING,
+          max_depth=1,
+          feature_ids=feature_ids,
+          node_ids=[feature1_nodes, feature2_nodes],
+          gains=[feature1_gains, feature2_gains],
+          thresholds=[feature1_thresholds, feature2_thresholds],
+          left_node_contribs=[
+              feature1_left_node_contribs, feature2_left_node_contribs
+          ],
+          right_node_contribs=[
+              feature1_right_node_contribs, feature2_right_node_contribs
+          ])
+
+      session.run(grow_op)
+
+      # Expect the split from the first feature to be chosen.
+      # Pruning is triggered but changes nothing, since all gains are positive.
+      new_stamp, serialized = session.run(tree_ensemble.serialize())
+      res_ensemble = boosted_trees_pb2.TreeEnsemble()
+      res_ensemble.ParseFromString(serialized)
+
+      expected_result = """
+        trees {
+          nodes {
+            bucketized_split {
+              feature_id: 3
+              threshold: 52
+              left_id: 1
+              right_id: 2
+            }
+            metadata {
+              gain: 7.62
+            }
+          }
+          nodes {
+            leaf {
+              scalar: -4.375
+            }
+          }
+          nodes {
+            leaf {
+              scalar: 7.143
+            }
+          }
+        }
+        trees {
+          nodes {
+            leaf {
+            }
+          }
+        }
+        tree_weights: 1.0
+        tree_weights: 1.0
+        tree_metadata {
+          num_layers_grown: 1
+          is_finalized: true
+        }
+        tree_metadata {
+        }
+        growing_metadata {
+          num_trees_attempted: 1
+          num_layers_attempted: 1
+        }
+      """
+      self.assertEqual(new_stamp, 1)
+      self.assertProtoEquals(expected_result, res_ensemble)
+
+
+if __name__ == '__main__':
+  googletest.main()
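
The post-pruning tests above hinge on the post_pruned_nodes_meta bookkeeping:
every node that existed before pruning keeps an entry mapping its old id to the
node it collapsed into, plus the logit correction to apply to any cached
prediction for it. A minimal sketch of that remapping (the helper and the
list-of-pairs layout are illustrative, not the actual kernel API):

    # Sketch: remap a cached (node_id, logit) pair after post-pruning, using
    # entries shaped like the post_pruned_nodes_meta messages above.
    def remap_cached_prediction(node_id, cached_logit, meta):
      new_node_id, logit_change = meta[node_id]
      # Adding logit_change undoes the pruned subtree's contribution, so the
      # cached logit matches what the pruned tree now produces at new_node_id.
      return new_node_id, cached_logit + logit_change

    # Entries from the first post-pruning test above: old nodes 3, 4, 7 and 8
    # collapse into node 1.
    meta = [(0, 0.0), (1, 0.0), (2, 0.0), (1, -0.07), (1, -0.083),
            (3, 0.0), (4, 0.0), (1, -0.22), (1, -0.57)]
    # Node 4's pre-pruning logit of 0.093 collapses back to node 1's leaf
    # value of 0.01 (up to float rounding).
    print(remap_cached_prediction(4, 0.093, meta))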
diff --git a/tensorflow/python/ops/boosted_trees_ops.py b/tensorflow/python/ops/boosted_trees_ops.py
new file mode 100644 (file)
index 0000000..174d009
--- /dev/null
@@ -0,0 +1,160 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Ops for boosted_trees."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import gen_boosted_trees_ops
+from tensorflow.python.ops import resources
+
+# Re-exporting ops used by other modules.
+# pylint: disable=unused-import
+from tensorflow.python.ops.gen_boosted_trees_ops import boosted_trees_calculate_best_gains_per_feature as calculate_best_gains_per_feature
+from tensorflow.python.ops.gen_boosted_trees_ops import boosted_trees_make_stats_summary as make_stats_summary
+from tensorflow.python.ops.gen_boosted_trees_ops import boosted_trees_predict as predict
+from tensorflow.python.ops.gen_boosted_trees_ops import boosted_trees_training_predict as training_predict
+from tensorflow.python.ops.gen_boosted_trees_ops import boosted_trees_update_ensemble as update_ensemble
+# pylint: enable=unused-import
+
+from tensorflow.python.training import saver
+
+
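+# Integer pruning modes, passed as the `pruning_mode` argument of
+# update_ensemble: NO_PRUNING (0), PRE_PRUNING (1), POST_PRUNING (2).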
+class PruningMode(object):
+  NO_PRUNING, PRE_PRUNING, POST_PRUNING = range(0, 3)
+
+
+class _TreeEnsembleSavable(saver.BaseSaverBuilder.SaveableObject):
+  """SaveableObject implementation for TreeEnsemble."""
+
+  def __init__(self, resource_handle, create_op, name):
+    """Creates a _TreeEnsembleSavable object.
+
+    Args:
+      resource_handle: handle to the decision tree ensemble variable.
+      create_op: the op to initialize the variable.
+      name: the name to save the tree ensemble variable under.
+    """
+    stamp_token, serialized = (
+        gen_boosted_trees_ops.boosted_trees_serialize_ensemble(resource_handle))
+    # slice_spec is useful for saving a slice from a variable.
+    # It's not meaningful for the tree ensemble variable, so we just pass an
+    # empty value.
+    slice_spec = ''
+    specs = [
+        saver.BaseSaverBuilder.SaveSpec(stamp_token, slice_spec,
+                                        name + '_stamp'),
+        saver.BaseSaverBuilder.SaveSpec(serialized, slice_spec,
+                                        name + '_serialized'),
+    ]
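+    # The saver hands these tensors back to restore() in the same order:
+    # restored_tensors[0] is the stamp token, restored_tensors[1] the
+    # serialized proto.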
+    super(_TreeEnsembleSavable, self).__init__(resource_handle, specs, name)
+    self._resource_handle = resource_handle
+    self._create_op = create_op
+
+  def restore(self, restored_tensors, unused_restored_shapes):
+    """Restores the associated tree ensemble from 'restored_tensors'.
+
+    Args:
+      restored_tensors: the tensors that were loaded from a checkpoint.
+      unused_restored_shapes: the shapes this object should conform to after
+        restore. Not meaningful for trees.
+
+    Returns:
+      The operation that restores the state of the tree ensemble variable.
+    """
+    with ops.control_dependencies([self._create_op]):
+      return gen_boosted_trees_ops.boosted_trees_deserialize_ensemble(
+          self._resource_handle,
+          stamp_token=restored_tensors[0],
+          tree_ensemble_serialized=restored_tensors[1])
+
+
+class TreeEnsemble(object):
+  """Creates TreeEnsemble resource."""
+
+  def __init__(self, name, stamp_token=0, is_local=False, serialized_proto=''):
+    with ops.name_scope(name, 'TreeEnsemble') as name:
+      self._resource_handle = (
+          gen_boosted_trees_ops.boosted_trees_ensemble_resource_handle_op(
+              container='', shared_name=name, name=name))
+      create_op = gen_boosted_trees_ops.boosted_trees_create_ensemble(
+          self.resource_handle,
+          stamp_token,
+          tree_ensemble_serialized=serialized_proto)
+      is_initialized_op = (
+          gen_boosted_trees_ops.is_boosted_trees_ensemble_initialized(
+              self._resource_handle))
+      # Adds the variable to the savable list.
+      if not is_local:
+        saveable = _TreeEnsembleSavable(self.resource_handle, create_op,
+                                        self.resource_handle.name)
+        ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, saveable)
+      resources.register_resource(
+          self.resource_handle,
+          create_op,
+          is_initialized_op,
+          is_shared=not is_local)
+
+  @property
+  def resource_handle(self):
+    return self._resource_handle
+
+  def get_stamp_token(self):
+    """Returns the current stamp token of the resource."""
+    stamp_token, _, _, _ = (
+        gen_boosted_trees_ops.boosted_trees_get_ensemble_states(
+            self.resource_handle))
+    return stamp_token
+
+  def get_states(self):
+    """Returns states of the tree ensemble.
+
+    Returns:
+      stamp_token, num_trees, num_finalized_trees, num_attempted_layers.
+    """
+    stamp_token, num_trees, num_finalized_trees, num_attempted_layers = (
+        gen_boosted_trees_ops.boosted_trees_get_ensemble_states(
+            self.resource_handle))
+    # Use identity to give names.
+    return (array_ops.identity(stamp_token, name='stamp_token'),
+            array_ops.identity(num_trees, name='num_trees'),
+            array_ops.identity(num_finalized_trees, name='num_finalized_trees'),
+            array_ops.identity(
+                num_attempted_layers, name='num_attempted_layers'))
+
+  def serialize(self):
+    """Serializes the ensemble into proto and returns the serialized proto.
+
+    Returns:
+      stamp_token: int64 scalar Tensor to denote the stamp of the resource.
+      serialized_proto: string scalar Tensor of the serialized proto.
+    """
+    return gen_boosted_trees_ops.boosted_trees_serialize_ensemble(
+        self.resource_handle)
+
+  def deserialize(self, stamp_token, serialized_proto):
+    """Deserialize the input proto and resets the ensemble from it.
+
+    Args:
+      stamp_token: int64 scalar Tensor to denote the stamp of the resource.
+      serialized_proto: string scalar Tensor of the serialized proto.
+
+    Returns:
+      Operation (for dependencies).
+    """
+    return gen_boosted_trees_ops.boosted_trees_deserialize_ensemble(
+        self.resource_handle, stamp_token, serialized_proto)
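
For reference, a minimal round trip through the resource, in the same
graph-mode style the tests in this change use (a sketch, not part of the
patch):

    from tensorflow.core.kernels.boosted_trees import boosted_trees_pb2
    from tensorflow.python.client import session
    from tensorflow.python.framework import ops
    from tensorflow.python.ops import boosted_trees_ops
    from tensorflow.python.ops import resources

    with ops.Graph().as_default():
      # Create an empty ensemble resource; its create op is registered as a
      # shared-resource initializer by TreeEnsemble.__init__.
      ensemble = boosted_trees_ops.TreeEnsemble('ensemble')
      stamp_t, serialized_t = ensemble.serialize()
      with session.Session() as sess:
        resources.initialize_resources(resources.shared_resources()).run()
        stamp, serialized = sess.run([stamp_t, serialized_t])
        proto = boosted_trees_pb2.TreeEnsemble()
        proto.ParseFromString(serialized)
        # A freshly created ensemble starts at stamp 0 with no trees.
        print(stamp, len(proto.trees))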
index d31c375..be80c36 100644 (file)
@@ -25,14 +25,13 @@ from tensorflow.python.platform import tf_logging as logging
 from tensorflow.python.training import server_lib
 from tensorflow.python.util.tf_export import tf_export
 
-# This is a tuple of PS ops used by tf.estimator.Esitmator which should work in
+# This is a tuple of PS ops used by tf.estimator.Estimator which should work in
 # almost all of cases.
-STANDARD_PS_OPS = (
-    "Variable", "VariableV2", "AutoReloadVariable", "MutableHashTable",
-    "MutableHashTableV2", "MutableHashTableOfTensors",
-    "MutableHashTableOfTensorsV2", "MutableDenseHashTable",
-    "MutableDenseHashTableV2", "VarHandleOp"
-)
+STANDARD_PS_OPS = ("Variable", "VariableV2", "AutoReloadVariable",
+                   "MutableHashTable", "MutableHashTableV2",
+                   "MutableHashTableOfTensors", "MutableHashTableOfTensorsV2",
+                   "MutableDenseHashTable", "MutableDenseHashTableV2",
+                   "VarHandleOp", "BoostedTreesEnsembleResourceHandleOp")
 
 
 class _RoundRobinStrategy(object):
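
With BoostedTreesEnsembleResourceHandleOp in STANDARD_PS_OPS, the ensemble
resource is placed on a parameter server by tf.train.replica_device_setter,
like any other variable. A sketch of the effect (the cluster spec here is
illustrative):

    import tensorflow as tf
    from tensorflow.python.ops import boosted_trees_ops

    cluster = {'ps': ['ps0:2222', 'ps1:2222'], 'worker': ['worker0:2222']}
    with tf.device(tf.train.replica_device_setter(cluster=cluster)):
      # The handle op now matches STANDARD_PS_OPS, so the resource is
      # assigned to a ps task rather than staying on the worker.
      ensemble = boosted_trees_ops.TreeEnsemble('ensemble')
    print(ensemble.resource_handle.device)  # e.g. '/job:ps/task:0'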
diff --git a/tensorflow/tools/api/golden/tensorflow.estimator.-boosted-trees-classifier.pbtxt b/tensorflow/tools/api/golden/tensorflow.estimator.-boosted-trees-classifier.pbtxt
new file mode 100644 (file)
index 0000000..fd9be8c
--- /dev/null
@@ -0,0 +1,54 @@
+path: "tensorflow.estimator.BoostedTreesClassifier"
+tf_class {
+  is_instance: "<class \'tensorflow.python.estimator.canned.boosted_trees.BoostedTreesClassifier\'>"
+  is_instance: "<class \'tensorflow.python.estimator.estimator.Estimator\'>"
+  is_instance: "<type \'object\'>"
+  member {
+    name: "config"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "model_dir"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "model_fn"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "params"
+    mtype: "<type \'property\'>"
+  }
+  member_method {
+    name: "__init__"
+    argspec: "args=[\'self\', \'feature_columns\', \'n_batches_per_layer\', \'model_dir\', \'n_classes\', \'weight_column\', \'label_vocabulary\', \'n_trees\', \'max_depth\', \'learning_rate\', \'l1_regularization\', \'l2_regularization\', \'tree_complexity\', \'config\'], varargs=None, keywords=None, defaults=[\'None\', \'<object object instance>\', \'None\', \'None\', \'100\', \'6\', \'0.1\', \'0.0\', \'0.0\', \'0.0\', \'None\'], "
+  }
+  member_method {
+    name: "evaluate"
+    argspec: "args=[\'self\', \'input_fn\', \'steps\', \'hooks\', \'checkpoint_path\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
+  }
+  member_method {
+    name: "export_savedmodel"
+    argspec: "args=[\'self\', \'export_dir_base\', \'serving_input_receiver_fn\', \'assets_extra\', \'as_text\', \'checkpoint_path\', \'strip_default_attrs\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\', \'False\'], "
+  }
+  member_method {
+    name: "get_variable_names"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_variable_value"
+    argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "latest_checkpoint"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "predict"
+    argspec: "args=[\'self\', \'input_fn\', \'predict_keys\', \'hooks\', \'checkpoint_path\', \'yield_single_examples\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\'], "
+  }
+  member_method {
+    name: "train"
+    argspec: "args=[\'self\', \'input_fn\', \'hooks\', \'steps\', \'max_steps\', \'saving_listeners\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
+  }
+}
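
The argspec above pins the public constructor. A minimal training example
against it (toy data; the column set is illustrative, and the canned estimator
at this point expects bucketized feature columns):

    import tensorflow as tf

    feature_columns = [
        tf.feature_column.bucketized_column(
            tf.feature_column.numeric_column('x'),
            boundaries=[0.0, 1.0, 2.0]),
    ]

    def input_fn():
      features = {'x': tf.constant([[-0.5], [0.5], [1.5], [2.5]])}
      labels = tf.constant([[0], [0], [1], [1]])
      return features, labels

    # n_batches_per_layer=1: stats from a single batch grow each layer.
    classifier = tf.estimator.BoostedTreesClassifier(
        feature_columns=feature_columns,
        n_batches_per_layer=1,
        n_trees=5,
        max_depth=3)
    classifier.train(input_fn, max_steps=20)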
diff --git a/tensorflow/tools/api/golden/tensorflow.estimator.-boosted-trees-regressor.pbtxt b/tensorflow/tools/api/golden/tensorflow.estimator.-boosted-trees-regressor.pbtxt
new file mode 100644 (file)
index 0000000..6b305be
--- /dev/null
@@ -0,0 +1,54 @@
+path: "tensorflow.estimator.BoostedTreesRegressor"
+tf_class {
+  is_instance: "<class \'tensorflow.python.estimator.canned.boosted_trees.BoostedTreesRegressor\'>"
+  is_instance: "<class \'tensorflow.python.estimator.estimator.Estimator\'>"
+  is_instance: "<type \'object\'>"
+  member {
+    name: "config"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "model_dir"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "model_fn"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "params"
+    mtype: "<type \'property\'>"
+  }
+  member_method {
+    name: "__init__"
+    argspec: "args=[\'self\', \'feature_columns\', \'n_batches_per_layer\', \'model_dir\', \'label_dimension\', \'weight_column\', \'n_trees\', \'max_depth\', \'learning_rate\', \'l1_regularization\', \'l2_regularization\', \'tree_complexity\', \'config\'], varargs=None, keywords=None, defaults=[\'None\', \'<object object instance>\', \'None\', \'100\', \'6\', \'0.1\', \'0.0\', \'0.0\', \'0.0\', \'None\'], "
+  }
+  member_method {
+    name: "evaluate"
+    argspec: "args=[\'self\', \'input_fn\', \'steps\', \'hooks\', \'checkpoint_path\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
+  }
+  member_method {
+    name: "export_savedmodel"
+    argspec: "args=[\'self\', \'export_dir_base\', \'serving_input_receiver_fn\', \'assets_extra\', \'as_text\', \'checkpoint_path\', \'strip_default_attrs\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\', \'False\'], "
+  }
+  member_method {
+    name: "get_variable_names"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_variable_value"
+    argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "latest_checkpoint"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "predict"
+    argspec: "args=[\'self\', \'input_fn\', \'predict_keys\', \'hooks\', \'checkpoint_path\', \'yield_single_examples\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\'], "
+  }
+  member_method {
+    name: "train"
+    argspec: "args=[\'self\', \'input_fn\', \'hooks\', \'steps\', \'max_steps\', \'saving_listeners\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
+  }
+}
index a7a6cc1..4946f2c 100644 (file)
@@ -9,6 +9,14 @@ tf_module {
     mtype: "<type \'type\'>"
   }
   member {
+    name: "BoostedTreesClassifier"
+    mtype: "<type \'type\'>"
+  }
+  member {
+    name: "BoostedTreesRegressor"
+    mtype: "<type \'type\'>"
+  }
+  member {
     name: "DNNClassifier"
     mtype: "<type \'type\'>"
   }