From b6ae98b4ac1ec3051d81f3133b827d6bb305aa2b Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower" <gardener@tensorflow.org>
Date: Fri, 25 May 2018 12:58:55 -0700
Subject: [PATCH] Use functions to build dense splits.

TensorFlow Function invocations share the same graph, so using them
reduces the graph-construction overhead.

PiperOrigin-RevId: 198090110
---
 .../lib/learner/batch/ordinal_split_handler.py     | 230 ++++++++++++---------
 .../learner/batch/ordinal_split_handler_test.py    |  34 +--
 2 files changed, 150 insertions(+), 114 deletions(-)

diff --git a/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler.py b/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler.py
index 8225318..409a2d8 100644
--- a/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler.py
+++ b/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler.py
@@ -243,45 +243,74 @@ class DenseSplitHandler(InequalitySplitHandler):
 
   def make_splits(self, stamp_token, next_stamp_token, class_id):
     """Create the best split using the accumulated stats and flush the state."""
-    # Get the bucket boundaries
-    are_splits_ready, buckets = (
-        self._quantile_accumulator.get_buckets(stamp_token))
-    # After we receive the boundaries from previous iteration we can flush
-    # the quantile accumulator.
-    with ops.control_dependencies([buckets]):
-      flush_quantiles = self._quantile_accumulator.flush(
-          stamp_token=stamp_token, next_stamp_token=next_stamp_token)
-
-    # Get the aggregated gradients and hessians per <partition_id, feature_id>
-    # pair.
-    # In order to distribute the computation on all the PSs we use the PS that
-    # had the stats accumulator on.
-    with ops.device(None):
-      with ops.device(self._stats_accumulator.resource().device):
-        num_minibatches, partition_ids, bucket_ids, gradients, hessians = (
-            self._stats_accumulator.flush(stamp_token, next_stamp_token))
-
-        # Put quantile and stats accumulator flushing in the dependency path.
-        are_splits_ready = control_flow_ops.with_dependencies(
-            [flush_quantiles, partition_ids], are_splits_ready)
-
-        partition_ids, gains, split_infos = (
-            split_handler_ops.build_dense_inequality_splits(
-                num_minibatches=num_minibatches,
-                bucket_boundaries=buckets,
-                partition_ids=partition_ids,
-                bucket_ids=bucket_ids,
-                gradients=gradients,
-                hessians=hessians,
-                class_id=class_id,
-                feature_column_group_id=self._feature_column_group_id,
-                l1_regularization=self._l1_regularization,
-                l2_regularization=self._l2_regularization,
-                tree_complexity_regularization=self.
-                _tree_complexity_regularization,
-                min_node_weight=self._min_node_weight,
-                multiclass_strategy=self._multiclass_strategy))
-    return (are_splits_ready, partition_ids, gains, split_infos)
+    if (self._gradient_shape == tensor_shape.scalar() and
+        self._hessian_shape == tensor_shape.scalar()):
+      handler = make_dense_split_scalar
+    else:
+      handler = make_dense_split_tensor
+
+    are_splits_ready, partition_ids, gains, split_infos = (
+        handler(self._quantile_accumulator.resource(),
+                self._stats_accumulator.resource(), stamp_token,
+                next_stamp_token, self._multiclass_strategy, class_id,
+                self._feature_column_group_id, self._l1_regularization,
+                self._l2_regularization, self._tree_complexity_regularization,
+                self._min_node_weight))
+    return are_splits_ready, partition_ids, gains, split_infos
+
+
+def _make_dense_split(quantile_accumulator_handle, stats_accumulator_handle,
+                      stamp_token, next_stamp_token, multiclass_strategy,
+                      class_id, feature_column_id, l1_regularization,
+                      l2_regularization, tree_complexity_regularization,
+                      min_node_weight, is_multi_dimentional):
+  """Function that builds splits for a dense feature column."""
+  # Get the bucket boundaries
+  are_splits_ready, buckets = (
+      gen_quantile_ops.quantile_accumulator_get_buckets(
+          quantile_accumulator_handles=[quantile_accumulator_handle],
+          stamp_token=stamp_token))
+  # quantile_accumulator_get_buckets returns a list of results per handle that
+  # we pass to it. In this case we're getting results just for one resource.
+  are_splits_ready = are_splits_ready[0]
+  buckets = buckets[0]
+
+  # After we receive the boundaries from previous iteration we can flush
+  # the quantile accumulator.
+  with ops.control_dependencies([buckets]):
+    flush_quantiles = gen_quantile_ops.quantile_accumulator_flush(
+        quantile_accumulator_handle=quantile_accumulator_handle,
+        stamp_token=stamp_token,
+        next_stamp_token=next_stamp_token)
+
+  if is_multi_dimentional:
+    num_minibatches, partition_ids, bucket_ids, gradients, hessians = (
+        gen_stats_accumulator_ops.stats_accumulator_tensor_flush(
+            stats_accumulator_handle, stamp_token, next_stamp_token))
+  else:
+    num_minibatches, partition_ids, bucket_ids, gradients, hessians = (
+        gen_stats_accumulator_ops.stats_accumulator_scalar_flush(
+            stats_accumulator_handle, stamp_token, next_stamp_token))
+
+  # Put quantile and stats accumulator flushing in the dependency path.
+  with ops.control_dependencies([flush_quantiles, partition_ids]):
+    are_splits_ready = array_ops.identity(are_splits_ready)
+  partition_ids, gains, split_infos = (
+      split_handler_ops.build_dense_inequality_splits(
+          num_minibatches=num_minibatches,
+          bucket_boundaries=buckets,
+          partition_ids=partition_ids,
+          bucket_ids=bucket_ids,
+          gradients=gradients,
+          hessians=hessians,
+          class_id=class_id,
+          feature_column_group_id=feature_column_id,
+          l1_regularization=l1_regularization,
+          l2_regularization=l2_regularization,
+          tree_complexity_regularization=tree_complexity_regularization,
+          min_node_weight=min_node_weight,
+          multiclass_strategy=multiclass_strategy))
+  return are_splits_ready, partition_ids, gains, split_infos
 
 
 class SparseSplitHandler(InequalitySplitHandler):
@@ -399,63 +428,64 @@ class SparseSplitHandler(InequalitySplitHandler):
     return are_splits_ready, partition_ids, gains, split_infos
 
 
-def _specialize_sparse_split(is_multi_dimentional):
+def _make_sparse_split(quantile_accumulator_handle, stats_accumulator_handle,
+                       stamp_token, next_stamp_token, multiclass_strategy,
+                       class_id, feature_column_id, l1_regularization,
+                       l2_regularization, tree_complexity_regularization,
+                       min_node_weight, is_multi_dimentional):
+  """Function that builds splits for a sparse feature column."""
+  # Get the bucket boundaries
+  are_splits_ready, buckets = (
+      gen_quantile_ops.quantile_accumulator_get_buckets(
+          quantile_accumulator_handles=[quantile_accumulator_handle],
+          stamp_token=stamp_token))
+  # quantile_accumulator_get_buckets returns a list of results per handle that
+  # we pass to it. In this case we're getting results just for one resource.
+  are_splits_ready = are_splits_ready[0]
+  buckets = buckets[0]
+
+  # After we receive the boundaries from previous iteration we can flush
+  # the quantile accumulator.
+  with ops.control_dependencies([buckets]):
+    flush_quantiles = gen_quantile_ops.quantile_accumulator_flush(
+        quantile_accumulator_handle=quantile_accumulator_handle,
+        stamp_token=stamp_token,
+        next_stamp_token=next_stamp_token)
+
+  if is_multi_dimentional:
+    num_minibatches, partition_ids, bucket_ids, gradients, hessians = (
+        gen_stats_accumulator_ops.stats_accumulator_tensor_flush(
+            stats_accumulator_handle, stamp_token, next_stamp_token))
+  else:
+    num_minibatches, partition_ids, bucket_ids, gradients, hessians = (
+        gen_stats_accumulator_ops.stats_accumulator_scalar_flush(
+            stats_accumulator_handle, stamp_token, next_stamp_token))
+
+  # Put quantile and stats accumulator flushing in the dependency path.
+  with ops.control_dependencies([flush_quantiles, partition_ids]):
+    are_splits_ready = array_ops.identity(are_splits_ready)
+  partition_ids, gains, split_infos = (
+      split_handler_ops.build_sparse_inequality_splits(
+          num_minibatches=num_minibatches,
+          bucket_boundaries=buckets,
+          partition_ids=partition_ids,
+          bucket_ids=bucket_ids,
+          gradients=gradients,
+          hessians=hessians,
+          class_id=class_id,
+          feature_column_group_id=feature_column_id,
+          l1_regularization=l1_regularization,
+          l2_regularization=l2_regularization,
+          tree_complexity_regularization=tree_complexity_regularization,
+          min_node_weight=min_node_weight,
+          bias_feature_id=_BIAS_FEATURE_ID,
+          multiclass_strategy=multiclass_strategy))
+  return are_splits_ready, partition_ids, gains, split_infos
+
+
+def _specialize_make_split(func, is_multi_dimentional):
   """Builds a specialized version of the function."""
 
-  def _make_sparse_split(quantile_accumulator_handle, stats_accumulator_handle,
-                         stamp_token, next_stamp_token, multiclass_strategy,
-                         class_id, feature_column_id, l1_regularization,
-                         l2_regularization, tree_complexity_regularization,
-                         min_node_weight, is_multi_dimentional):
-    """Function that builds splits for a sparse feature column."""
-    # Get the bucket boundaries
-    are_splits_ready, buckets = (
-        gen_quantile_ops.quantile_accumulator_get_buckets(
-            quantile_accumulator_handles=[quantile_accumulator_handle],
-            stamp_token=stamp_token))
-    # quantile_accumulator_get_buckets returns a list of results per handle that
-    # we pass to it. In this case we're getting results just for one resource.
-    are_splits_ready = are_splits_ready[0]
-    buckets = buckets[0]
-
-    # After we receive the boundaries from previous iteration we can flush
-    # the quantile accumulator.
-    with ops.control_dependencies([buckets]):
-      flush_quantiles = gen_quantile_ops.quantile_accumulator_flush(
-          quantile_accumulator_handle=quantile_accumulator_handle,
-          stamp_token=stamp_token,
-          next_stamp_token=next_stamp_token)
-
-    if is_multi_dimentional:
-      num_minibatches, partition_ids, bucket_ids, gradients, hessians = (
-          gen_stats_accumulator_ops.stats_accumulator_tensor_flush(
-              stats_accumulator_handle, stamp_token, next_stamp_token))
-    else:
-      num_minibatches, partition_ids, bucket_ids, gradients, hessians = (
-          gen_stats_accumulator_ops.stats_accumulator_scalar_flush(
-              stats_accumulator_handle, stamp_token, next_stamp_token))
-
-    # Put quantile and stats accumulator flushing in the dependency path.
-    with ops.control_dependencies([flush_quantiles, partition_ids]):
-      are_splits_ready = array_ops.identity(are_splits_ready)
-    partition_ids, gains, split_infos = (
-        split_handler_ops.build_sparse_inequality_splits(
-            num_minibatches=num_minibatches,
-            bucket_boundaries=buckets,
-            partition_ids=partition_ids,
-            bucket_ids=bucket_ids,
-            gradients=gradients,
-            hessians=hessians,
-            class_id=class_id,
-            feature_column_group_id=feature_column_id,
-            l1_regularization=l1_regularization,
-            l2_regularization=l2_regularization,
-            tree_complexity_regularization=tree_complexity_regularization,
-            min_node_weight=min_node_weight,
-            bias_feature_id=_BIAS_FEATURE_ID,
-            multiclass_strategy=multiclass_strategy))
-    return are_splits_ready, partition_ids, gains, split_infos
-
   @function.Defun(
       dtypes.resource,
       dtypes.resource,
@@ -474,7 +504,7 @@ def _specialize_sparse_split(is_multi_dimentional):
         l1_regularization, l2_regularization, tree_complexity_regularization,
         min_node_weight):
     """Function that builds splits for a sparse feature column."""
-    return _make_sparse_split(
+    return func(
         quantile_accumulator_handle, stats_accumulator_handle, stamp_token,
         next_stamp_token, multiclass_strategy, class_id, feature_column_id,
         l1_regularization, l2_regularization, tree_complexity_regularization,
@@ -482,9 +512,15 @@ def _specialize_sparse_split(is_multi_dimentional):
   return f
 
 
+make_dense_split_scalar = _specialize_make_split(_make_dense_split,
+                                                 is_multi_dimentional=False)
+make_dense_split_tensor = _specialize_make_split(_make_dense_split,
+                                                 is_multi_dimentional=True)
 
-make_sparse_split_scalar = _specialize_sparse_split(is_multi_dimentional=False)
-make_sparse_split_tensor = _specialize_sparse_split(is_multi_dimentional=True)
+make_sparse_split_scalar = _specialize_make_split(_make_sparse_split,
+                                                  is_multi_dimentional=False)
+make_sparse_split_tensor = _specialize_make_split(_make_sparse_split,
+                                                  is_multi_dimentional=True)
 
 
 @function.Defun(
diff --git a/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler_test.py b/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler_test.py
index c081a3f..2f2c230 100644
--- a/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler_test.py
+++ b/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler_test.py
@@ -67,9 +67,9 @@ class DenseSplitHandlerTest(test_util.TensorFlowTestCase):
     hessian_shape = tensor_shape.scalar()
     split_handler = ordinal_split_handler.DenseSplitHandler(
         l1_regularization=0.1,
-        l2_regularization=1,
-        tree_complexity_regularization=0,
-        min_node_weight=0,
+        l2_regularization=1.,
+        tree_complexity_regularization=0.,
+        min_node_weight=0.,
         epsilon=0.001,
         num_quantiles=10,
         feature_column_group_id=0,
@@ -203,10 +203,10 @@ class DenseSplitHandlerTest(test_util.TensorFlowTestCase):
     hessian_shape = tensor_shape.TensorShape([2, 2])
 
     split_handler = ordinal_split_handler.DenseSplitHandler(
-        l1_regularization=0,
-        l2_regularization=1,
-        tree_complexity_regularization=0,
-        min_node_weight=0,
+        l1_regularization=0.,
+        l2_regularization=1.,
+        tree_complexity_regularization=0.,
+        min_node_weight=0.,
         epsilon=0.001,
         num_quantiles=3,
         feature_column_group_id=0,
@@ -291,10 +291,10 @@ class DenseSplitHandlerTest(test_util.TensorFlowTestCase):
     hessian_shape = tensor_shape.TensorShape([2])
 
     split_handler = ordinal_split_handler.DenseSplitHandler(
-        l1_regularization=0,
-        l2_regularization=1,
-        tree_complexity_regularization=0,
-        min_node_weight=0,
+        l1_regularization=0.,
+        l2_regularization=1.,
+        tree_complexity_regularization=0.,
+        min_node_weight=0.,
         epsilon=0.001,
         num_quantiles=3,
         feature_column_group_id=0,
@@ -376,9 +376,9 @@ class DenseSplitHandlerTest(test_util.TensorFlowTestCase):
 
     split_handler = ordinal_split_handler.DenseSplitHandler(
         l1_regularization=0.1,
-        l2_regularization=1,
-        tree_complexity_regularization=0,
-        min_node_weight=0,
+        l2_regularization=1.,
+        tree_complexity_regularization=0.,
+        min_node_weight=0.,
         epsilon=0.001,
         num_quantiles=10,
         feature_column_group_id=0,
@@ -451,9 +451,9 @@ class DenseSplitHandlerTest(test_util.TensorFlowTestCase):
 
     split_handler = ordinal_split_handler.DenseSplitHandler(
         l1_regularization=0.1,
-        l2_regularization=1,
+        l2_regularization=1.,
         tree_complexity_regularization=0.5,
-        min_node_weight=0,
+        min_node_weight=0.,
         epsilon=0.001,
         num_quantiles=10,
         feature_column_group_id=0,
@@ -585,7 +585,7 @@ class DenseSplitHandlerTest(test_util.TensorFlowTestCase):
 
     split_handler = ordinal_split_handler.DenseSplitHandler(
        l1_regularization=0.1,
-        l2_regularization=1,
+        l2_regularization=1.,
         tree_complexity_regularization=0.5,
         min_node_weight=1.5,
         epsilon=0.001,
-- 
2.7.4
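
A minimal sketch of the Defun specialization pattern this patch relies on, assuming a TensorFlow 1.x graph-mode environment. The helper names below (_make_offset_fn, add_one, add_two) are hypothetical and exist only for illustration; function.Defun is the real API used by _specialize_make_split above. Because a Defun-wrapped function is added to the graph's function library once per specialization, each later invocation emits a single function-call op instead of rebuilding the whole subgraph:

    import tensorflow as tf
    from tensorflow.python.framework import function


    def _make_offset_fn(offset):
      """Hypothetical helper: specializes a Defun on a Python-level constant,
      the same way _specialize_make_split bakes in `func` and
      `is_multi_dimentional` at definition time."""

      @function.Defun(tf.float32)
      def f(x):
        # `offset` is captured when the function is defined, so each
        # specialization becomes its own graph function.
        return x + offset

      return f


    add_one = _make_offset_fn(1.0)
    add_two = _make_offset_fn(2.0)

    with tf.Graph().as_default():
      # Each call site adds one function-call op; the function body is stored
      # once per specialization in the graph's function library.
      a = add_one(tf.constant(3.0))
      b = add_one(tf.constant(5.0))
      c = add_two(tf.constant(3.0))
      with tf.Session() as sess:
        print(sess.run([a, b, c]))  # [4.0, 6.0, 5.0]

The patch applies the same idea: _specialize_make_split closes over the builder (func) and is_multi_dimentional at definition time, producing the four module-level specializations, and make_splits then dispatches to one of them with resource handles and regularization parameters, so repeated split construction reuses one graph function rather than duplicating the split-building ops.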