Use functions to build dense splits. Tensorflow Function invocations share the same...
authorA. Unique TensorFlower <gardener@tensorflow.org>
Fri, 25 May 2018 19:58:55 +0000 (12:58 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Fri, 25 May 2018 20:03:31 +0000 (13:03 -0700)
PiperOrigin-RevId: 198090110

tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler.py
tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler_test.py

index 8225318..409a2d8 100644 (file)
@@ -243,45 +243,74 @@ class DenseSplitHandler(InequalitySplitHandler):
 
   def make_splits(self, stamp_token, next_stamp_token, class_id):
     """Create the best split using the accumulated stats and flush the state."""
-    # Get the bucket boundaries
-    are_splits_ready, buckets = (
-        self._quantile_accumulator.get_buckets(stamp_token))
-    # After we receive the boundaries from previous iteration we can flush
-    # the quantile accumulator.
-    with ops.control_dependencies([buckets]):
-      flush_quantiles = self._quantile_accumulator.flush(
-          stamp_token=stamp_token, next_stamp_token=next_stamp_token)
-
-    # Get the aggregated gradients and hessians per <partition_id, feature_id>
-    # pair.
-    # In order to distribute the computation on all the PSs we use the PS that
-    # had the stats accumulator on.
-    with ops.device(None):
-      with ops.device(self._stats_accumulator.resource().device):
-        num_minibatches, partition_ids, bucket_ids, gradients, hessians = (
-            self._stats_accumulator.flush(stamp_token, next_stamp_token))
-
-        # Put quantile and stats accumulator flushing in the dependency path.
-        are_splits_ready = control_flow_ops.with_dependencies(
-            [flush_quantiles, partition_ids], are_splits_ready)
-
-        partition_ids, gains, split_infos = (
-            split_handler_ops.build_dense_inequality_splits(
-                num_minibatches=num_minibatches,
-                bucket_boundaries=buckets,
-                partition_ids=partition_ids,
-                bucket_ids=bucket_ids,
-                gradients=gradients,
-                hessians=hessians,
-                class_id=class_id,
-                feature_column_group_id=self._feature_column_group_id,
-                l1_regularization=self._l1_regularization,
-                l2_regularization=self._l2_regularization,
-                tree_complexity_regularization=self.
-                _tree_complexity_regularization,
-                min_node_weight=self._min_node_weight,
-                multiclass_strategy=self._multiclass_strategy))
-    return (are_splits_ready, partition_ids, gains, split_infos)
+    if (self._gradient_shape == tensor_shape.scalar() and
+        self._hessian_shape == tensor_shape.scalar()):
+      handler = make_dense_split_scalar
+    else:
+      handler = make_dense_split_tensor
+
+    are_splits_ready, partition_ids, gains, split_infos = (
+        handler(self._quantile_accumulator.resource(),
+                self._stats_accumulator.resource(), stamp_token,
+                next_stamp_token, self._multiclass_strategy, class_id,
+                self._feature_column_group_id, self._l1_regularization,
+                self._l2_regularization, self._tree_complexity_regularization,
+                self._min_node_weight))
+    return are_splits_ready, partition_ids, gains, split_infos
+
+
+def _make_dense_split(quantile_accumulator_handle, stats_accumulator_handle,
+                      stamp_token, next_stamp_token, multiclass_strategy,
+                      class_id, feature_column_id, l1_regularization,
+                      l2_regularization, tree_complexity_regularization,
+                      min_node_weight, is_multi_dimentional):
+  """Function that builds splits for a dense feature column."""
+  # Get the bucket boundaries
+  are_splits_ready, buckets = (
+      gen_quantile_ops.quantile_accumulator_get_buckets(
+          quantile_accumulator_handles=[quantile_accumulator_handle],
+          stamp_token=stamp_token))
+  # quantile_accumulator_get_buckets returns a list of results per handle that
+  # we pass to it. In this case we're getting results just for one resource.
+  are_splits_ready = are_splits_ready[0]
+  buckets = buckets[0]
+
+  # After we receive the boundaries from previous iteration we can flush
+  # the quantile accumulator.
+  with ops.control_dependencies([buckets]):
+    flush_quantiles = gen_quantile_ops.quantile_accumulator_flush(
+        quantile_accumulator_handle=quantile_accumulator_handle,
+        stamp_token=stamp_token,
+        next_stamp_token=next_stamp_token)
+
+  if is_multi_dimentional:
+    num_minibatches, partition_ids, bucket_ids, gradients, hessians = (
+        gen_stats_accumulator_ops.stats_accumulator_tensor_flush(
+            stats_accumulator_handle, stamp_token, next_stamp_token))
+  else:
+    num_minibatches, partition_ids, bucket_ids, gradients, hessians = (
+        gen_stats_accumulator_ops.stats_accumulator_scalar_flush(
+            stats_accumulator_handle, stamp_token, next_stamp_token))
+
+  # Put quantile and stats accumulator flushing in the dependency path.
+  with ops.control_dependencies([flush_quantiles, partition_ids]):
+    are_splits_ready = array_ops.identity(are_splits_ready)
+  partition_ids, gains, split_infos = (
+      split_handler_ops.build_dense_inequality_splits(
+          num_minibatches=num_minibatches,
+          bucket_boundaries=buckets,
+          partition_ids=partition_ids,
+          bucket_ids=bucket_ids,
+          gradients=gradients,
+          hessians=hessians,
+          class_id=class_id,
+          feature_column_group_id=feature_column_id,
+          l1_regularization=l1_regularization,
+          l2_regularization=l2_regularization,
+          tree_complexity_regularization=tree_complexity_regularization,
+          min_node_weight=min_node_weight,
+          multiclass_strategy=multiclass_strategy))
+  return are_splits_ready, partition_ids, gains, split_infos
 
 
 class SparseSplitHandler(InequalitySplitHandler):
@@ -399,63 +428,64 @@ class SparseSplitHandler(InequalitySplitHandler):
     return are_splits_ready, partition_ids, gains, split_infos
 
 
-def _specialize_sparse_split(is_multi_dimentional):
+def _make_sparse_split(quantile_accumulator_handle, stats_accumulator_handle,
+                       stamp_token, next_stamp_token, multiclass_strategy,
+                       class_id, feature_column_id, l1_regularization,
+                       l2_regularization, tree_complexity_regularization,
+                       min_node_weight, is_multi_dimentional):
+  """Function that builds splits for a sparse feature column."""
+  # Get the bucket boundaries
+  are_splits_ready, buckets = (
+      gen_quantile_ops.quantile_accumulator_get_buckets(
+          quantile_accumulator_handles=[quantile_accumulator_handle],
+          stamp_token=stamp_token))
+  # quantile_accumulator_get_buckets returns a list of results per handle that
+  # we pass to it. In this case we're getting results just for one resource.
+  are_splits_ready = are_splits_ready[0]
+  buckets = buckets[0]
+
+  # After we receive the boundaries from previous iteration we can flush
+  # the quantile accumulator.
+  with ops.control_dependencies([buckets]):
+    flush_quantiles = gen_quantile_ops.quantile_accumulator_flush(
+        quantile_accumulator_handle=quantile_accumulator_handle,
+        stamp_token=stamp_token,
+        next_stamp_token=next_stamp_token)
+
+  if is_multi_dimentional:
+    num_minibatches, partition_ids, bucket_ids, gradients, hessians = (
+        gen_stats_accumulator_ops.stats_accumulator_tensor_flush(
+            stats_accumulator_handle, stamp_token, next_stamp_token))
+  else:
+    num_minibatches, partition_ids, bucket_ids, gradients, hessians = (
+        gen_stats_accumulator_ops.stats_accumulator_scalar_flush(
+            stats_accumulator_handle, stamp_token, next_stamp_token))
+
+  # Put quantile and stats accumulator flushing in the dependency path.
+  with ops.control_dependencies([flush_quantiles, partition_ids]):
+    are_splits_ready = array_ops.identity(are_splits_ready)
+  partition_ids, gains, split_infos = (
+      split_handler_ops.build_sparse_inequality_splits(
+          num_minibatches=num_minibatches,
+          bucket_boundaries=buckets,
+          partition_ids=partition_ids,
+          bucket_ids=bucket_ids,
+          gradients=gradients,
+          hessians=hessians,
+          class_id=class_id,
+          feature_column_group_id=feature_column_id,
+          l1_regularization=l1_regularization,
+          l2_regularization=l2_regularization,
+          tree_complexity_regularization=tree_complexity_regularization,
+          min_node_weight=min_node_weight,
+          bias_feature_id=_BIAS_FEATURE_ID,
+          multiclass_strategy=multiclass_strategy))
+  return are_splits_ready, partition_ids, gains, split_infos
+
+
+def _specialize_make_split(func, is_multi_dimentional):
   """Builds a specialized version of the function."""
 
-  def _make_sparse_split(quantile_accumulator_handle, stats_accumulator_handle,
-                         stamp_token, next_stamp_token, multiclass_strategy,
-                         class_id, feature_column_id, l1_regularization,
-                         l2_regularization, tree_complexity_regularization,
-                         min_node_weight, is_multi_dimentional):
-    """Function that builds splits for a sparse feature column."""
-    # Get the bucket boundaries
-    are_splits_ready, buckets = (
-        gen_quantile_ops.quantile_accumulator_get_buckets(
-            quantile_accumulator_handles=[quantile_accumulator_handle],
-            stamp_token=stamp_token))
-    # quantile_accumulator_get_buckets returns a list of results per handle that
-    # we pass to it. In this case we're getting results just for one resource.
-    are_splits_ready = are_splits_ready[0]
-    buckets = buckets[0]
-
-    # After we receive the boundaries from previous iteration we can flush
-    # the quantile accumulator.
-    with ops.control_dependencies([buckets]):
-      flush_quantiles = gen_quantile_ops.quantile_accumulator_flush(
-          quantile_accumulator_handle=quantile_accumulator_handle,
-          stamp_token=stamp_token,
-          next_stamp_token=next_stamp_token)
-
-    if is_multi_dimentional:
-      num_minibatches, partition_ids, bucket_ids, gradients, hessians = (
-          gen_stats_accumulator_ops.stats_accumulator_tensor_flush(
-              stats_accumulator_handle, stamp_token, next_stamp_token))
-    else:
-      num_minibatches, partition_ids, bucket_ids, gradients, hessians = (
-          gen_stats_accumulator_ops.stats_accumulator_scalar_flush(
-              stats_accumulator_handle, stamp_token, next_stamp_token))
-
-    # Put quantile and stats accumulator flushing in the dependency path.
-    with ops.control_dependencies([flush_quantiles, partition_ids]):
-      are_splits_ready = array_ops.identity(are_splits_ready)
-    partition_ids, gains, split_infos = (
-        split_handler_ops.build_sparse_inequality_splits(
-            num_minibatches=num_minibatches,
-            bucket_boundaries=buckets,
-            partition_ids=partition_ids,
-            bucket_ids=bucket_ids,
-            gradients=gradients,
-            hessians=hessians,
-            class_id=class_id,
-            feature_column_group_id=feature_column_id,
-            l1_regularization=l1_regularization,
-            l2_regularization=l2_regularization,
-            tree_complexity_regularization=tree_complexity_regularization,
-            min_node_weight=min_node_weight,
-            bias_feature_id=_BIAS_FEATURE_ID,
-            multiclass_strategy=multiclass_strategy))
-    return are_splits_ready, partition_ids, gains, split_infos
-
   @function.Defun(
       dtypes.resource,
       dtypes.resource,
@@ -474,7 +504,7 @@ def _specialize_sparse_split(is_multi_dimentional):
         l1_regularization, l2_regularization, tree_complexity_regularization,
         min_node_weight):
     """Function that builds splits for a sparse feature column."""
-    return _make_sparse_split(
+    return func(
         quantile_accumulator_handle, stats_accumulator_handle, stamp_token,
         next_stamp_token, multiclass_strategy, class_id, feature_column_id,
         l1_regularization, l2_regularization, tree_complexity_regularization,
@@ -482,9 +512,15 @@ def _specialize_sparse_split(is_multi_dimentional):
 
   return f
 
+make_dense_split_scalar = _specialize_make_split(_make_dense_split,
+                                                 is_multi_dimentional=False)
+make_dense_split_tensor = _specialize_make_split(_make_dense_split,
+                                                 is_multi_dimentional=True)
 
-make_sparse_split_scalar = _specialize_sparse_split(is_multi_dimentional=False)
-make_sparse_split_tensor = _specialize_sparse_split(is_multi_dimentional=True)
+make_sparse_split_scalar = _specialize_make_split(_make_sparse_split,
+                                                  is_multi_dimentional=False)
+make_sparse_split_tensor = _specialize_make_split(_make_sparse_split,
+                                                  is_multi_dimentional=True)
 
 
 @function.Defun(
index c081a3f..2f2c230 100644 (file)
@@ -67,9 +67,9 @@ class DenseSplitHandlerTest(test_util.TensorFlowTestCase):
       hessian_shape = tensor_shape.scalar()
       split_handler = ordinal_split_handler.DenseSplitHandler(
           l1_regularization=0.1,
-          l2_regularization=1,
-          tree_complexity_regularization=0,
-          min_node_weight=0,
+          l2_regularization=1.,
+          tree_complexity_regularization=0.,
+          min_node_weight=0.,
           epsilon=0.001,
           num_quantiles=10,
           feature_column_group_id=0,
@@ -203,10 +203,10 @@ class DenseSplitHandlerTest(test_util.TensorFlowTestCase):
       hessian_shape = tensor_shape.TensorShape([2, 2])
 
       split_handler = ordinal_split_handler.DenseSplitHandler(
-          l1_regularization=0,
-          l2_regularization=1,
-          tree_complexity_regularization=0,
-          min_node_weight=0,
+          l1_regularization=0.,
+          l2_regularization=1.,
+          tree_complexity_regularization=0.,
+          min_node_weight=0.,
           epsilon=0.001,
           num_quantiles=3,
           feature_column_group_id=0,
@@ -291,10 +291,10 @@ class DenseSplitHandlerTest(test_util.TensorFlowTestCase):
       hessian_shape = tensor_shape.TensorShape([2])
 
       split_handler = ordinal_split_handler.DenseSplitHandler(
-          l1_regularization=0,
-          l2_regularization=1,
-          tree_complexity_regularization=0,
-          min_node_weight=0,
+          l1_regularization=0.,
+          l2_regularization=1.,
+          tree_complexity_regularization=0.,
+          min_node_weight=0.,
           epsilon=0.001,
           num_quantiles=3,
           feature_column_group_id=0,
@@ -376,9 +376,9 @@ class DenseSplitHandlerTest(test_util.TensorFlowTestCase):
 
       split_handler = ordinal_split_handler.DenseSplitHandler(
           l1_regularization=0.1,
-          l2_regularization=1,
-          tree_complexity_regularization=0,
-          min_node_weight=0,
+          l2_regularization=1.,
+          tree_complexity_regularization=0.,
+          min_node_weight=0.,
           epsilon=0.001,
           num_quantiles=10,
           feature_column_group_id=0,
@@ -451,9 +451,9 @@ class DenseSplitHandlerTest(test_util.TensorFlowTestCase):
 
       split_handler = ordinal_split_handler.DenseSplitHandler(
           l1_regularization=0.1,
-          l2_regularization=1,
+          l2_regularization=1.,
           tree_complexity_regularization=0.5,
-          min_node_weight=0,
+          min_node_weight=0.,
           epsilon=0.001,
           num_quantiles=10,
           feature_column_group_id=0,
@@ -585,7 +585,7 @@ class DenseSplitHandlerTest(test_util.TensorFlowTestCase):
 
       split_handler = ordinal_split_handler.DenseSplitHandler(
           l1_regularization=0.1,
-          l2_regularization=1,
+          l2_regularization=1.,
           tree_complexity_regularization=0.5,
           min_node_weight=1.5,
           epsilon=0.001,