def make_splits(self, stamp_token, next_stamp_token, class_id):
"""Create the best split using the accumulated stats and flush the state."""
- # Get the bucket boundaries
- are_splits_ready, buckets = (
- self._quantile_accumulator.get_buckets(stamp_token))
- # After we receive the boundaries from previous iteration we can flush
- # the quantile accumulator.
- with ops.control_dependencies([buckets]):
- flush_quantiles = self._quantile_accumulator.flush(
- stamp_token=stamp_token, next_stamp_token=next_stamp_token)
-
- # Get the aggregated gradients and hessians per <partition_id, feature_id>
- # pair.
- # In order to distribute the computation on all the PSs we use the PS that
- # had the stats accumulator on.
- with ops.device(None):
- with ops.device(self._stats_accumulator.resource().device):
- num_minibatches, partition_ids, bucket_ids, gradients, hessians = (
- self._stats_accumulator.flush(stamp_token, next_stamp_token))
-
- # Put quantile and stats accumulator flushing in the dependency path.
- are_splits_ready = control_flow_ops.with_dependencies(
- [flush_quantiles, partition_ids], are_splits_ready)
-
- partition_ids, gains, split_infos = (
- split_handler_ops.build_dense_inequality_splits(
- num_minibatches=num_minibatches,
- bucket_boundaries=buckets,
- partition_ids=partition_ids,
- bucket_ids=bucket_ids,
- gradients=gradients,
- hessians=hessians,
- class_id=class_id,
- feature_column_group_id=self._feature_column_group_id,
- l1_regularization=self._l1_regularization,
- l2_regularization=self._l2_regularization,
- tree_complexity_regularization=self.
- _tree_complexity_regularization,
- min_node_weight=self._min_node_weight,
- multiclass_strategy=self._multiclass_strategy))
- return (are_splits_ready, partition_ids, gains, split_infos)
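+ # Pick the Defun-specialized builder: the scalar variant when both the
+ # gradients and hessians are scalars, and the tensor variant otherwise
+ # (e.g. multiclass strategies accumulate vector-valued stats).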
+ if (self._gradient_shape == tensor_shape.scalar() and
+ self._hessian_shape == tensor_shape.scalar()):
+ handler = make_dense_split_scalar
+ else:
+ handler = make_dense_split_tensor
+
+ are_splits_ready, partition_ids, gains, split_infos = (
+ handler(self._quantile_accumulator.resource(),
+ self._stats_accumulator.resource(), stamp_token,
+ next_stamp_token, self._multiclass_strategy, class_id,
+ self._feature_column_group_id, self._l1_regularization,
+ self._l2_regularization, self._tree_complexity_regularization,
+ self._min_node_weight))
+ return are_splits_ready, partition_ids, gains, split_infos
+
+
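+# The builders below are module-level functions so that _specialize_make_split
+# (defined further down) can wrap a single body in @function.Defun once per
+# (dense/sparse, scalar/tensor) combination; the resulting specializations are
+# shared by all handler instances.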
+def _make_dense_split(quantile_accumulator_handle, stats_accumulator_handle,
+ stamp_token, next_stamp_token, multiclass_strategy,
+ class_id, feature_column_id, l1_regularization,
+ l2_regularization, tree_complexity_regularization,
+ min_node_weight, is_multi_dimentional):
+ """Function that builds splits for a dense feature column."""
+ # Get the bucket boundaries
+ are_splits_ready, buckets = (
+ gen_quantile_ops.quantile_accumulator_get_buckets(
+ quantile_accumulator_handles=[quantile_accumulator_handle],
+ stamp_token=stamp_token))
+ # quantile_accumulator_get_buckets returns one result per handle passed to
+ # it; since we query a single resource here, take the first element of each.
+ are_splits_ready = are_splits_ready[0]
+ buckets = buckets[0]
+
+ # After we receive the boundaries from the previous iteration, we can flush
+ # the quantile accumulator.
+ with ops.control_dependencies([buckets]):
+ flush_quantiles = gen_quantile_ops.quantile_accumulator_flush(
+ quantile_accumulator_handle=quantile_accumulator_handle,
+ stamp_token=stamp_token,
+ next_stamp_token=next_stamp_token)
+
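+ # Tensor accumulators hold multi-dimensional (e.g. per-class) gradients and
+ # hessians; scalar accumulators hold a single float per statistic.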
+ if is_multi_dimentional:
+ num_minibatches, partition_ids, bucket_ids, gradients, hessians = (
+ gen_stats_accumulator_ops.stats_accumulator_tensor_flush(
+ stats_accumulator_handle, stamp_token, next_stamp_token))
+ else:
+ num_minibatches, partition_ids, bucket_ids, gradients, hessians = (
+ gen_stats_accumulator_ops.stats_accumulator_scalar_flush(
+ stats_accumulator_handle, stamp_token, next_stamp_token))
+
+ # Put quantile and stats accumulator flushing in the dependency path.
+ with ops.control_dependencies([flush_quantiles, partition_ids]):
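+ # identity() re-creates are_splits_ready inside this control-dependency
+ # scope, so its consumers also wait for both flushes to complete.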
+ are_splits_ready = array_ops.identity(are_splits_ready)
+ partition_ids, gains, split_infos = (
+ split_handler_ops.build_dense_inequality_splits(
+ num_minibatches=num_minibatches,
+ bucket_boundaries=buckets,
+ partition_ids=partition_ids,
+ bucket_ids=bucket_ids,
+ gradients=gradients,
+ hessians=hessians,
+ class_id=class_id,
+ feature_column_group_id=feature_column_id,
+ l1_regularization=l1_regularization,
+ l2_regularization=l2_regularization,
+ tree_complexity_regularization=tree_complexity_regularization,
+ min_node_weight=min_node_weight,
+ multiclass_strategy=multiclass_strategy))
+ return are_splits_ready, partition_ids, gains, split_infos
class SparseSplitHandler(InequalitySplitHandler):
return are_splits_ready, partition_ids, gains, split_infos
-def _specialize_sparse_split(is_multi_dimentional):
+def _make_sparse_split(quantile_accumulator_handle, stats_accumulator_handle,
+ stamp_token, next_stamp_token, multiclass_strategy,
+ class_id, feature_column_id, l1_regularization,
+ l2_regularization, tree_complexity_regularization,
+ min_node_weight, is_multi_dimentional):
+ """Function that builds splits for a sparse feature column."""
+ # Get the bucket boundaries
+ are_splits_ready, buckets = (
+ gen_quantile_ops.quantile_accumulator_get_buckets(
+ quantile_accumulator_handles=[quantile_accumulator_handle],
+ stamp_token=stamp_token))
+ # quantile_accumulator_get_buckets returns one result per handle passed to
+ # it; since we query a single resource here, take the first element of each.
+ are_splits_ready = are_splits_ready[0]
+ buckets = buckets[0]
+
+ # After we receive the boundaries from the previous iteration, we can flush
+ # the quantile accumulator.
+ with ops.control_dependencies([buckets]):
+ flush_quantiles = gen_quantile_ops.quantile_accumulator_flush(
+ quantile_accumulator_handle=quantile_accumulator_handle,
+ stamp_token=stamp_token,
+ next_stamp_token=next_stamp_token)
+
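+ # As in the dense case, flush tensor stats for multi-dimensional
+ # (e.g. multiclass) handlers and scalar stats otherwise.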
+ if is_multi_dimentional:
+ num_minibatches, partition_ids, bucket_ids, gradients, hessians = (
+ gen_stats_accumulator_ops.stats_accumulator_tensor_flush(
+ stats_accumulator_handle, stamp_token, next_stamp_token))
+ else:
+ num_minibatches, partition_ids, bucket_ids, gradients, hessians = (
+ gen_stats_accumulator_ops.stats_accumulator_scalar_flush(
+ stats_accumulator_handle, stamp_token, next_stamp_token))
+
+ # Put quantile and stats accumulator flushing in the dependency path.
+ with ops.control_dependencies([flush_quantiles, partition_ids]):
+ are_splits_ready = array_ops.identity(are_splits_ready)
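+ # Unlike the dense op, the sparse op is also given _BIAS_FEATURE_ID, under
+ # which stats for all examples in a partition are accumulated; this lets the
+ # kernel derive stats for examples where the feature is missing.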
+ partition_ids, gains, split_infos = (
+ split_handler_ops.build_sparse_inequality_splits(
+ num_minibatches=num_minibatches,
+ bucket_boundaries=buckets,
+ partition_ids=partition_ids,
+ bucket_ids=bucket_ids,
+ gradients=gradients,
+ hessians=hessians,
+ class_id=class_id,
+ feature_column_group_id=feature_column_id,
+ l1_regularization=l1_regularization,
+ l2_regularization=l2_regularization,
+ tree_complexity_regularization=tree_complexity_regularization,
+ min_node_weight=min_node_weight,
+ bias_feature_id=_BIAS_FEATURE_ID,
+ multiclass_strategy=multiclass_strategy))
+ return are_splits_ready, partition_ids, gains, split_infos
+
+
+def _specialize_make_split(func, is_multi_dimentional):
"""Builds a specialized version of the function."""
- def _make_sparse_split(quantile_accumulator_handle, stats_accumulator_handle,
- stamp_token, next_stamp_token, multiclass_strategy,
- class_id, feature_column_id, l1_regularization,
- l2_regularization, tree_complexity_regularization,
- min_node_weight, is_multi_dimentional):
- """Function that builds splits for a sparse feature column."""
- # Get the bucket boundaries
- are_splits_ready, buckets = (
- gen_quantile_ops.quantile_accumulator_get_buckets(
- quantile_accumulator_handles=[quantile_accumulator_handle],
- stamp_token=stamp_token))
- # quantile_accumulator_get_buckets returns a list of results per handle that
- # we pass to it. In this case we're getting results just for one resource.
- are_splits_ready = are_splits_ready[0]
- buckets = buckets[0]
-
- # After we receive the boundaries from previous iteration we can flush
- # the quantile accumulator.
- with ops.control_dependencies([buckets]):
- flush_quantiles = gen_quantile_ops.quantile_accumulator_flush(
- quantile_accumulator_handle=quantile_accumulator_handle,
- stamp_token=stamp_token,
- next_stamp_token=next_stamp_token)
-
- if is_multi_dimentional:
- num_minibatches, partition_ids, bucket_ids, gradients, hessians = (
- gen_stats_accumulator_ops.stats_accumulator_tensor_flush(
- stats_accumulator_handle, stamp_token, next_stamp_token))
- else:
- num_minibatches, partition_ids, bucket_ids, gradients, hessians = (
- gen_stats_accumulator_ops.stats_accumulator_scalar_flush(
- stats_accumulator_handle, stamp_token, next_stamp_token))
-
- # Put quantile and stats accumulator flushing in the dependency path.
- with ops.control_dependencies([flush_quantiles, partition_ids]):
- are_splits_ready = array_ops.identity(are_splits_ready)
- partition_ids, gains, split_infos = (
- split_handler_ops.build_sparse_inequality_splits(
- num_minibatches=num_minibatches,
- bucket_boundaries=buckets,
- partition_ids=partition_ids,
- bucket_ids=bucket_ids,
- gradients=gradients,
- hessians=hessians,
- class_id=class_id,
- feature_column_group_id=feature_column_id,
- l1_regularization=l1_regularization,
- l2_regularization=l2_regularization,
- tree_complexity_regularization=tree_complexity_regularization,
- min_node_weight=min_node_weight,
- bias_feature_id=_BIAS_FEATURE_ID,
- multiclass_strategy=multiclass_strategy))
- return are_splits_ready, partition_ids, gains, split_infos
-
@function.Defun(
dtypes.resource,
dtypes.resource,
dtypes.int64,
dtypes.int64,
dtypes.int32,
dtypes.int32,
dtypes.int32,
dtypes.float32,
dtypes.float32,
dtypes.float32,
dtypes.float32,
noinline=True)
def f(quantile_accumulator_handle, stats_accumulator_handle, stamp_token,
next_stamp_token, multiclass_strategy, class_id, feature_column_id,
l1_regularization, l2_regularization, tree_complexity_regularization,
min_node_weight):
"""Function that builds splits for a sparse feature column."""
- return _make_sparse_split(
+ return func(
quantile_accumulator_handle, stats_accumulator_handle, stamp_token,
next_stamp_token, multiclass_strategy, class_id, feature_column_id,
l1_regularization, l2_regularization, tree_complexity_regularization,
min_node_weight, is_multi_dimentional)

return f
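+# Pre-built specializations, created once at module import: one per
+# (dense/sparse) x (scalar/tensor stats) combination.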
+make_dense_split_scalar = _specialize_make_split(_make_dense_split,
+ is_multi_dimentional=False)
+make_dense_split_tensor = _specialize_make_split(_make_dense_split,
+ is_multi_dimentional=True)
-make_sparse_split_scalar = _specialize_sparse_split(is_multi_dimentional=False)
-make_sparse_split_tensor = _specialize_sparse_split(is_multi_dimentional=True)
+make_sparse_split_scalar = _specialize_make_split(_make_sparse_split,
+ is_multi_dimentional=False)
+make_sparse_split_tensor = _specialize_make_split(_make_sparse_split,
+ is_multi_dimentional=True)
@function.Defun(