From: A. Unique TensorFlower Date: Tue, 22 May 2018 18:02:30 +0000 (-0700) Subject: * Remove the bias centering graph if it is turned off. X-Git-Tag: upstream/v1.9.0_rc1~38^2~4^2~208 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=dd2f3ebe3ede1e7b89819f40f53fdfb6c0433af0;p=platform%2Fupstream%2Ftensorflow.git * Remove the bias centering graph if it is turned off. * Create consts once. Otherwise each time the constant is passed to an Op, a new Const op is created. * Speed up the graph construction by using a functions to build splits. PiperOrigin-RevId: 197590220 --- diff --git a/tensorflow/contrib/boosted_trees/kernels/split_handler_ops.cc b/tensorflow/contrib/boosted_trees/kernels/split_handler_ops.cc index 04e3226..401bec8 100644 --- a/tensorflow/contrib/boosted_trees/kernels/split_handler_ops.cc +++ b/tensorflow/contrib/boosted_trees/kernels/split_handler_ops.cc @@ -43,47 +43,60 @@ namespace { const int32 DUMMY_FEATURE_DIMENSION = -1; } // namespace -class BaseBuildSplitOp : public OpKernel { +class SplitBuilderState { public: - explicit BaseBuildSplitOp(OpKernelConstruction* const context) - : OpKernel(context) { - OP_REQUIRES_OK(context, context->GetAttr("feature_column_group_id", - &feature_column_group_id_)); + explicit SplitBuilderState(OpKernelContext* const context) { + const Tensor* l1_regularization_t; OP_REQUIRES_OK(context, - context->GetAttr("l1_regularization", &l1_regularization_)); + context->input("l1_regularization", &l1_regularization_t)); + const Tensor* l2_regularization_t; OP_REQUIRES_OK(context, - context->GetAttr("l2_regularization", &l2_regularization_)); - OP_REQUIRES_OK(context, context->GetAttr("tree_complexity_regularization", - &tree_complexity_regularization_)); + context->input("l2_regularization", &l2_regularization_t)); + const Tensor* tree_complexity_regularization_t; + OP_REQUIRES_OK(context, context->input("tree_complexity_regularization", + &tree_complexity_regularization_t)); + const Tensor* min_node_weight_t; OP_REQUIRES_OK(context, - context->GetAttr("min_node_weight", &min_node_weight_)); + context->input("min_node_weight", &min_node_weight_t)); - int strategy; - OP_REQUIRES_OK(context, context->GetAttr("multiclass_strategy", &strategy)); + const Tensor* feature_column_group_id_t; + OP_REQUIRES_OK(context, context->input("feature_column_group_id", + &feature_column_group_id_t)); + + const Tensor* multiclass_strategy_t; + OP_REQUIRES_OK( + context, context->input("multiclass_strategy", &multiclass_strategy_t)); + int strategy = multiclass_strategy_t->scalar()(); OP_REQUIRES( context, boosted_trees::learner::LearnerConfig_MultiClassStrategy_IsValid( strategy), errors::InvalidArgument("Wrong multiclass strategy passed.")); - multiclass_strategy_ = LearnerConfig_MultiClassStrategy(strategy); - } - NodeStats ComputeNodeStats(const GradientStats& grad_stats) { - return NodeStats(l1_regularization_, l2_regularization_, min_node_weight_, - multiclass_strategy_, grad_stats); - } + multiclass_strategy_ = LearnerConfig_MultiClassStrategy(strategy); - void ReadClassId(OpKernelContext* const context, int32* class_id) { const Tensor* class_id_t; OP_REQUIRES_OK(context, context->input("class_id", &class_id_t)); OP_REQUIRES(context, TensorShapeUtils::IsScalar(class_id_t->shape()), errors::InvalidArgument("class_id must be a scalar.")); - *class_id = class_id_t->scalar()(); + class_id_ = class_id_t->scalar()(); + + l1_regularization_ = l1_regularization_t->scalar()(); + l2_regularization_ = l2_regularization_t->scalar()(); + tree_complexity_regularization_ = + tree_complexity_regularization_t->scalar()(); + min_node_weight_ = min_node_weight_t->scalar()(); + feature_column_group_id_ = feature_column_group_id_t->scalar()(); + } + + NodeStats ComputeNodeStats(const GradientStats& grad_stats) { + return NodeStats(l1_regularization_, l2_regularization_, min_node_weight_, + multiclass_strategy_, grad_stats); } - void FillLeaf(const int class_id, const NodeStats& best_node_stats, + void FillLeaf(const NodeStats& best_node_stats, boosted_trees::trees::Leaf* leaf) const { - if (class_id == -1) { + if (class_id_ == -1) { // This would be the case either for TREE_PER_CLASS with only 2 classes, // or for other multiclass strategies. for (float f : best_node_stats.weight_contribution) { @@ -93,25 +106,31 @@ class BaseBuildSplitOp : public OpKernel { CHECK(best_node_stats.weight_contribution.size() == 1) << "Weight contribution size = " << best_node_stats.weight_contribution.size(); - leaf->mutable_sparse_vector()->add_index(class_id); + leaf->mutable_sparse_vector()->add_index(class_id_); leaf->mutable_sparse_vector()->add_value( best_node_stats.weight_contribution[0]); } } - protected: + int32 feature_column_group_id() { return feature_column_group_id_; } + float tree_complexity_regularization() { + return tree_complexity_regularization_; + } + + private: LearnerConfig_MultiClassStrategy multiclass_strategy_; - int32 feature_column_group_id_; float l1_regularization_; float l2_regularization_; - float min_node_weight_; float tree_complexity_regularization_; + float min_node_weight_; + int32 class_id_; + int32 feature_column_group_id_; }; -class BuildDenseInequalitySplitsOp : public BaseBuildSplitOp { +class BuildDenseInequalitySplitsOp : public OpKernel { public: explicit BuildDenseInequalitySplitsOp(OpKernelConstruction* const context) - : BaseBuildSplitOp(context) {} + : OpKernel(context) {} void Compute(OpKernelContext* const context) override { const Tensor* num_minibatches_t; @@ -139,9 +158,6 @@ class BuildDenseInequalitySplitsOp : public BaseBuildSplitOp { const Tensor* hessians_t; OP_REQUIRES_OK(context, context->input("hessians", &hessians_t)); - int class_id; - ReadClassId(context, &class_id); - // Find the number of unique partitions before we allocate the output. std::vector partition_boundaries; partition_boundaries.push_back(0); @@ -185,6 +201,7 @@ class BuildDenseInequalitySplitsOp : public BaseBuildSplitOp { &output_splits_t)); tensorflow::TTypes::Vec output_splits = output_splits_t->vec(); + SplitBuilderState state(context); for (int root_idx = 0; root_idx < num_elements; ++root_idx) { float best_gain = std::numeric_limits::lowest(); int start_index = partition_boundaries[root_idx]; @@ -196,7 +213,7 @@ class BuildDenseInequalitySplitsOp : public BaseBuildSplitOp { GradientStats(*gradients_t, *hessians_t, bucket_idx); } root_gradient_stats *= normalizer_ratio; - NodeStats root_stats = ComputeNodeStats(root_gradient_stats); + NodeStats root_stats = state.ComputeNodeStats(root_gradient_stats); int32 best_bucket_idx = 0; NodeStats best_right_node_stats(0); NodeStats best_left_node_stats(0); @@ -206,10 +223,10 @@ class BuildDenseInequalitySplitsOp : public BaseBuildSplitOp { GradientStats g(*gradients_t, *hessians_t, bucket_idx); g *= normalizer_ratio; left_gradient_stats += g; - NodeStats left_stats = ComputeNodeStats(left_gradient_stats); + NodeStats left_stats = state.ComputeNodeStats(left_gradient_stats); GradientStats right_gradient_stats = root_gradient_stats - left_gradient_stats; - NodeStats right_stats = ComputeNodeStats(right_gradient_stats); + NodeStats right_stats = state.ComputeNodeStats(right_gradient_stats); if (left_stats.gain + right_stats.gain > best_gain) { best_gain = left_stats.gain + right_stats.gain; best_left_node_stats = left_stats; @@ -220,18 +237,18 @@ class BuildDenseInequalitySplitsOp : public BaseBuildSplitOp { SplitInfo split_info; auto* dense_split = split_info.mutable_split_node()->mutable_dense_float_binary_split(); - dense_split->set_feature_column(feature_column_group_id_); + dense_split->set_feature_column(state.feature_column_group_id()); dense_split->set_threshold( bucket_boundaries(bucket_ids(best_bucket_idx, 0))); auto* left_child = split_info.mutable_left_child(); auto* right_child = split_info.mutable_right_child(); - FillLeaf(class_id, best_left_node_stats, left_child); - FillLeaf(class_id, best_right_node_stats, right_child); + state.FillLeaf(best_left_node_stats, left_child); + state.FillLeaf(best_right_node_stats, right_child); split_info.SerializeToString(&output_splits(root_idx)); gains(root_idx) = - best_gain - root_stats.gain - tree_complexity_regularization_; + best_gain - root_stats.gain - state.tree_complexity_regularization(); output_partition_ids(root_idx) = partition_ids(start_index); } } @@ -239,13 +256,10 @@ class BuildDenseInequalitySplitsOp : public BaseBuildSplitOp { REGISTER_KERNEL_BUILDER(Name("BuildDenseInequalitySplits").Device(DEVICE_CPU), BuildDenseInequalitySplitsOp); -class BuildSparseInequalitySplitsOp : public BaseBuildSplitOp { +class BuildSparseInequalitySplitsOp : public OpKernel { public: explicit BuildSparseInequalitySplitsOp(OpKernelConstruction* const context) - : BaseBuildSplitOp(context) { - OP_REQUIRES_OK(context, - context->GetAttr("bias_feature_id", &bias_feature_id_)); - } + : OpKernel(context) {} void Compute(OpKernelContext* const context) override { const Tensor* num_minibatches_t; @@ -275,8 +289,10 @@ class BuildSparseInequalitySplitsOp : public BaseBuildSplitOp { const Tensor* hessians_t; OP_REQUIRES_OK(context, context->input("hessians", &hessians_t)); - int class_id; - ReadClassId(context, &class_id); + const Tensor* bias_feature_id_t; + OP_REQUIRES_OK(context, + context->input("bias_feature_id", &bias_feature_id_t)); + int64 bias_feature_id = bias_feature_id_t->scalar()(); // For each partition (tree node), store starting index for each dimension. PartitionAndDimensionBoundaries partition_boundaries; @@ -354,6 +370,7 @@ class BuildSparseInequalitySplitsOp : public BaseBuildSplitOp { &output_splits_t)); tensorflow::TTypes::Vec output_splits = output_splits_t->vec(); + SplitBuilderState state(context); // For each tree node that needs to be split. for (int root_idx = 0; root_idx < num_elements; ++root_idx) { const auto& dimension_boundaries = @@ -372,7 +389,7 @@ class BuildSparseInequalitySplitsOp : public BaseBuildSplitOp { OP_REQUIRES( context, - bucket_ids_and_dimensions(bias_start_index, 0) == bias_feature_id_, + bucket_ids_and_dimensions(bias_start_index, 0) == bias_feature_id, errors::InvalidArgument("Bias feature ID missing.")); // Dimension for bias feature is always 0 @@ -388,7 +405,7 @@ class BuildSparseInequalitySplitsOp : public BaseBuildSplitOp { GradientStats root_gradient_stats(*gradients_t, *hessians_t, bias_start_index); root_gradient_stats *= normalizer_ratio; - NodeStats root_stats = ComputeNodeStats(root_gradient_stats); + NodeStats root_stats = state.ComputeNodeStats(root_gradient_stats); // Iterate through dimensions. for (int j = 0; j < dimension_boundaries.size() - 1; ++j) { @@ -408,7 +425,7 @@ class BuildSparseInequalitySplitsOp : public BaseBuildSplitOp { << bucket_ids_and_dimensions(start_index, 1) << " and for " << bucket_ids_and_dimensions(end_index - 1, 0) << " " << bucket_ids_and_dimensions(end_index - 1, 1); - if (bucket_ids_and_dimensions(start_index, 0) == bias_feature_id_) { + if (bucket_ids_and_dimensions(start_index, 0) == bias_feature_id) { // 0-dimension case which has a first bucket for catch all feature. CHECK(bucket_ids_and_dimensions(start_index, 1) == 0) << "Dimension of bias feature should be 0"; @@ -447,10 +464,10 @@ class BuildSparseInequalitySplitsOp : public BaseBuildSplitOp { present_gradient_stats - left_gradient_stats; { - NodeStats left_stats_default_left = - ComputeNodeStats(root_gradient_stats - right_gradient_stats); + NodeStats left_stats_default_left = state.ComputeNodeStats( + root_gradient_stats - right_gradient_stats); NodeStats right_stats_default_left = - ComputeNodeStats(right_gradient_stats); + state.ComputeNodeStats(right_gradient_stats); if (left_stats_default_left.gain + right_stats_default_left.gain > best_gain) { best_gain = @@ -466,9 +483,9 @@ class BuildSparseInequalitySplitsOp : public BaseBuildSplitOp { // enough missing examples. if (!fixed_default_direction) { NodeStats left_stats_default_right = - ComputeNodeStats(left_gradient_stats); - NodeStats right_stats_default_right = - ComputeNodeStats(root_gradient_stats - left_gradient_stats); + state.ComputeNodeStats(left_gradient_stats); + NodeStats right_stats_default_right = state.ComputeNodeStats( + root_gradient_stats - left_gradient_stats); if (left_stats_default_right.gain + right_stats_default_right.gain > best_gain) { best_gain = left_stats_default_right.gain + @@ -494,7 +511,7 @@ class BuildSparseInequalitySplitsOp : public BaseBuildSplitOp { ->mutable_sparse_float_binary_split_default_left() ->mutable_split(); } - dense_split->set_feature_column(feature_column_group_id_); + dense_split->set_feature_column(state.feature_column_group_id()); // Set the feature index for the best feature column. const int64 best_dimension_id = bucket_ids_and_dimensions(best_element_idx, 1); @@ -505,11 +522,11 @@ class BuildSparseInequalitySplitsOp : public BaseBuildSplitOp { auto* left_child = split_info.mutable_left_child(); auto* right_child = split_info.mutable_right_child(); - FillLeaf(class_id, best_left_node_stats, left_child); - FillLeaf(class_id, best_right_node_stats, right_child); + state.FillLeaf(best_left_node_stats, left_child); + state.FillLeaf(best_right_node_stats, right_child); split_info.SerializeToString(&output_splits(root_idx)); gains(root_idx) = - best_gain - root_stats.gain - tree_complexity_regularization_; + best_gain - root_stats.gain - state.tree_complexity_regularization(); output_partition_ids(root_idx) = partition_ids(bias_start_index); } } @@ -526,19 +543,14 @@ class BuildSparseInequalitySplitsOp : public BaseBuildSplitOp { // For each partition, store start indices of feature column dimensions. typedef std::vector> PartitionAndDimensionBoundaries; - - int64 bias_feature_id_; }; REGISTER_KERNEL_BUILDER(Name("BuildSparseInequalitySplits").Device(DEVICE_CPU), BuildSparseInequalitySplitsOp); -class BuildCategoricalEqualitySplitsOp : public BaseBuildSplitOp { +class BuildCategoricalEqualitySplitsOp : public OpKernel { public: explicit BuildCategoricalEqualitySplitsOp(OpKernelConstruction* const context) - : BaseBuildSplitOp(context) { - OP_REQUIRES_OK(context, - context->GetAttr("bias_feature_id", &bias_feature_id_)); - } + : OpKernel(context) {} void Compute(OpKernelContext* const context) override { const Tensor* num_minibatches_t; @@ -561,8 +573,10 @@ class BuildCategoricalEqualitySplitsOp : public BaseBuildSplitOp { const Tensor* hessians_t; OP_REQUIRES_OK(context, context->input("hessians", &hessians_t)); - int class_id; - ReadClassId(context, &class_id); + const Tensor* bias_feature_id_t; + OP_REQUIRES_OK(context, + context->input("bias_feature_id", &bias_feature_id_t)); + int64 bias_feature_id = bias_feature_id_t->scalar()(); // Find the number of unique partitions before we allocate the output. std::vector partition_boundaries; @@ -605,16 +619,17 @@ class BuildCategoricalEqualitySplitsOp : public BaseBuildSplitOp { &output_splits_t)); tensorflow::TTypes::Vec output_splits = output_splits_t->vec(); + SplitBuilderState state(context); for (int root_idx = 0; root_idx < num_elements; ++root_idx) { float best_gain = std::numeric_limits::lowest(); int start_index = partition_boundaries[non_empty_partitions[root_idx]]; int end_index = partition_boundaries[non_empty_partitions[root_idx] + 1]; // First feature ID in each partition should be the bias feature. - OP_REQUIRES(context, feature_ids(start_index, 0) == bias_feature_id_, + OP_REQUIRES(context, feature_ids(start_index, 0) == bias_feature_id, errors::InvalidArgument("Bias feature ID missing.")); GradientStats root_gradient_stats(*gradients_t, *hessians_t, start_index); root_gradient_stats *= normalizer_ratio; - NodeStats root_stats = ComputeNodeStats(root_gradient_stats); + NodeStats root_stats = state.ComputeNodeStats(root_gradient_stats); int32 best_feature_idx = 0; NodeStats best_right_node_stats(0); NodeStats best_left_node_stats(0); @@ -625,8 +640,8 @@ class BuildCategoricalEqualitySplitsOp : public BaseBuildSplitOp { left_gradient_stats *= normalizer_ratio; GradientStats right_gradient_stats = root_gradient_stats - left_gradient_stats; - NodeStats left_stats = ComputeNodeStats(left_gradient_stats); - NodeStats right_stats = ComputeNodeStats(right_gradient_stats); + NodeStats left_stats = state.ComputeNodeStats(left_gradient_stats); + NodeStats right_stats = state.ComputeNodeStats(right_gradient_stats); if (left_stats.gain + right_stats.gain > best_gain) { best_gain = left_stats.gain + right_stats.gain; best_left_node_stats = left_stats; @@ -637,21 +652,18 @@ class BuildCategoricalEqualitySplitsOp : public BaseBuildSplitOp { SplitInfo split_info; auto* equality_split = split_info.mutable_split_node() ->mutable_categorical_id_binary_split(); - equality_split->set_feature_column(feature_column_group_id_); + equality_split->set_feature_column(state.feature_column_group_id()); equality_split->set_feature_id(feature_ids(best_feature_idx, 0)); auto* left_child = split_info.mutable_left_child(); auto* right_child = split_info.mutable_right_child(); - FillLeaf(class_id, best_left_node_stats, left_child); - FillLeaf(class_id, best_right_node_stats, right_child); + state.FillLeaf(best_left_node_stats, left_child); + state.FillLeaf(best_right_node_stats, right_child); split_info.SerializeToString(&output_splits(root_idx)); gains(root_idx) = - best_gain - root_stats.gain - tree_complexity_regularization_; + best_gain - root_stats.gain - state.tree_complexity_regularization(); output_partition_ids(root_idx) = partition_ids(start_index); } } - - private: - int64 bias_feature_id_; }; REGISTER_KERNEL_BUILDER( diff --git a/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler.py b/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler.py index f06b73c..23f4021 100644 --- a/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler.py +++ b/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler.py @@ -64,6 +64,8 @@ from __future__ import print_function import re from tensorflow.contrib.boosted_trees.lib.learner.batch import base_split_handler +from tensorflow.contrib.boosted_trees.python.ops import gen_quantile_ops +from tensorflow.contrib.boosted_trees.python.ops import gen_stats_accumulator_ops from tensorflow.contrib.boosted_trees.python.ops import quantile_ops from tensorflow.contrib.boosted_trees.python.ops import split_handler_ops from tensorflow.contrib.boosted_trees.python.ops import stats_accumulator_ops @@ -72,6 +74,7 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import function from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor +from tensorflow.python.framework import tensor_shape from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops @@ -327,9 +330,6 @@ class SparseSplitHandler(InequalitySplitHandler): multiclass_strategy=multiclass_strategy, init_stamp_token=init_stamp_token, name=name) - # Register sparse_make_stats_update function as an Op to the graph. - g = ops.get_default_graph() - sparse_make_stats_update.add_to_graph(g) self._sparse_float_column = sparse_float_column def scheduled_reads(self): @@ -361,8 +361,8 @@ class SparseSplitHandler(InequalitySplitHandler): are_buckets_ready, buckets = scheduled_reads[0] with ops.name_scope(self._name, "SparseSplitHandler"): (quantile_indices, quantile_values, quantile_shapes, quantile_weights, - example_partition_ids, - feature_ids, gradients, hessians) = sparse_make_stats_update( + example_partition_ids, feature_ids, gradients, + hessians) = sparse_make_stats_update( is_active, are_buckets_ready, self._sparse_float_column.indices, self._sparse_float_column.values, self._sparse_float_column.dense_shape, buckets, @@ -379,42 +379,104 @@ class SparseSplitHandler(InequalitySplitHandler): def make_splits(self, stamp_token, next_stamp_token, class_id): """Create the best split using the accumulated stats and flush the state.""" + if (self._gradient_shape == tensor_shape.scalar() and + self._hessian_shape == tensor_shape.scalar()): + handler = make_sparse_split_scalar + else: + handler = make_sparse_split_tensor + + are_splits_ready, partition_ids, gains, split_infos = ( + handler(self._quantile_accumulator.resource(), + self._stats_accumulator.resource(), stamp_token, + next_stamp_token, self._multiclass_strategy, class_id, + self._feature_column_group_id, self._l1_regularization, + self._l2_regularization, self._tree_complexity_regularization, + self._min_node_weight)) + return are_splits_ready, partition_ids, gains, split_infos + + +def _specialize_sparse_split(is_multi_dimentional): + """Builds a specialized version of the function.""" + + def _make_sparse_split(quantile_accumulator_handle, stats_accumulator_handle, + stamp_token, next_stamp_token, multiclass_strategy, + class_id, feature_column_id, l1_regularization, + l2_regularization, tree_complexity_regularization, + min_node_weight, is_multi_dimentional): + """Function that builds splits for a sparse feature column.""" # Get the bucket boundaries are_splits_ready, buckets = ( - self._quantile_accumulator.get_buckets(stamp_token)) + gen_quantile_ops.quantile_accumulator_get_buckets( + quantile_accumulator_handles=[quantile_accumulator_handle], + stamp_token=stamp_token)) # After we receive the boundaries from previous iteration we can flush # the quantile accumulator. - with ops.control_dependencies([buckets]): - flush_quantiles = self._quantile_accumulator.flush( - stamp_token=stamp_token, next_stamp_token=next_stamp_token) - - with ops.device(None): - with ops.device(self._stats_accumulator.resource().device): - num_minibatches, partition_ids, bucket_ids, gradients, hessians = ( - self._stats_accumulator.flush(stamp_token, next_stamp_token)) - - # Put quantile and stats accumulator flushing in the dependency path. - are_splits_ready = control_flow_ops.with_dependencies( - [flush_quantiles, partition_ids], are_splits_ready) - partition_ids, gains, split_infos = ( - split_handler_ops.build_sparse_inequality_splits( - num_minibatches=num_minibatches, - bucket_boundaries=buckets, - partition_ids=partition_ids, - bucket_ids=bucket_ids, - gradients=gradients, - hessians=hessians, - class_id=class_id, - feature_column_group_id=self._feature_column_group_id, - l1_regularization=self._l1_regularization, - l2_regularization=self._l2_regularization, - tree_complexity_regularization=self. - _tree_complexity_regularization, - min_node_weight=self._min_node_weight, - bias_feature_id=_BIAS_FEATURE_ID, - multiclass_strategy=self._multiclass_strategy)) - return (are_splits_ready, partition_ids, gains, split_infos) + with ops.control_dependencies([buckets[0]]): + flush_quantiles = gen_quantile_ops.quantile_accumulator_flush( + quantile_accumulator_handle=quantile_accumulator_handle, + stamp_token=stamp_token, + next_stamp_token=next_stamp_token) + + if is_multi_dimentional: + num_minibatches, partition_ids, bucket_ids, gradients, hessians = ( + gen_stats_accumulator_ops.stats_accumulator_tensor_flush( + stats_accumulator_handle, stamp_token, next_stamp_token)) + else: + num_minibatches, partition_ids, bucket_ids, gradients, hessians = ( + gen_stats_accumulator_ops.stats_accumulator_scalar_flush( + stats_accumulator_handle, stamp_token, next_stamp_token)) + + # Put quantile and stats accumulator flushing in the dependency path. + with ops.control_dependencies([flush_quantiles, partition_ids]): + are_splits_ready = array_ops.identity(are_splits_ready) + partition_ids, gains, split_infos = ( + split_handler_ops.build_sparse_inequality_splits( + num_minibatches=num_minibatches, + bucket_boundaries=buckets[0], + partition_ids=partition_ids, + bucket_ids=bucket_ids, + gradients=gradients, + hessians=hessians, + class_id=class_id, + feature_column_group_id=feature_column_id, + l1_regularization=l1_regularization, + l2_regularization=l2_regularization, + tree_complexity_regularization=tree_complexity_regularization, + min_node_weight=min_node_weight, + bias_feature_id=_BIAS_FEATURE_ID, + multiclass_strategy=multiclass_strategy)) + return are_splits_ready, partition_ids, gains, split_infos + + @function.Defun( + dtypes.resource, + dtypes.resource, + dtypes.int64, + dtypes.int64, + dtypes.int32, + dtypes.int32, + dtypes.int32, + dtypes.float32, + dtypes.float32, + dtypes.float32, + dtypes.float32, + noinline=True) + def f(quantile_accumulator_handle, stats_accumulator_handle, stamp_token, + next_stamp_token, multiclass_strategy, class_id, feature_column_id, + l1_regularization, l2_regularization, tree_complexity_regularization, + min_node_weight): + """Function that builds splits for a sparse feature column.""" + return _make_sparse_split( + quantile_accumulator_handle, stats_accumulator_handle, stamp_token, + next_stamp_token, multiclass_strategy, class_id, feature_column_id, + l1_regularization, l2_regularization, tree_complexity_regularization, + min_node_weight, is_multi_dimentional) + + return f + + +make_sparse_split_scalar = _specialize_sparse_split(is_multi_dimentional=False) +make_sparse_split_tensor = _specialize_sparse_split(is_multi_dimentional=True) @function.Defun( @@ -540,8 +602,9 @@ def sparse_make_stats_update( empty_float = constant_op.constant([], dtype=dtypes.float32) handler_not_active = (constant_op.constant( - [], dtype=dtypes.int64, shape=[0, 2]), empty_float, constant_op.constant( - [0, 1], dtype=dtypes.int64), empty_float) + [], dtype=dtypes.int64, shape=[0, 2]), empty_float, + constant_op.constant([0, 1], dtype=dtypes.int64), + empty_float) handler_active = (sparse_column_indices, sparse_column_values, sparse_column_shape, weights) quantile_indices, quantile_values, quantile_shape, quantile_weights = ( diff --git a/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler_test.py b/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler_test.py index 54d0301..c081a3f 100644 --- a/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler_test.py +++ b/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler_test.py @@ -18,6 +18,8 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import numpy as np + from tensorflow.contrib.boosted_trees.lib.learner.batch import ordinal_split_handler from tensorflow.contrib.boosted_trees.proto import learner_pb2 from tensorflow.contrib.boosted_trees.proto import split_info_pb2 @@ -92,7 +94,9 @@ class DenseSplitHandlerTest(test_util.TensorFlowTestCase): example_weights, is_active=array_ops.constant([True, True])) with ops.control_dependencies([update_1]): - are_splits_ready = split_handler.make_splits(0, 1, class_id)[0] + are_splits_ready = split_handler.make_splits( + np.int64(0), np.int64(1), class_id)[0] + with ops.control_dependencies([are_splits_ready]): update_2 = split_handler.update_stats_sync( 1, @@ -105,7 +109,7 @@ class DenseSplitHandlerTest(test_util.TensorFlowTestCase): is_active=array_ops.constant([True, True])) with ops.control_dependencies([update_2]): are_splits_ready2, partitions, gains, splits = ( - split_handler.make_splits(1, 2, class_id)) + split_handler.make_splits(np.int64(1), np.int64(2), class_id)) are_splits_ready, are_splits_ready2, partitions, gains, splits = ( sess.run([ are_splits_ready, are_splits_ready2, partitions, gains, splits @@ -227,7 +231,9 @@ class DenseSplitHandlerTest(test_util.TensorFlowTestCase): example_weights, is_active=array_ops.constant([True, True])) with ops.control_dependencies([update_1]): - are_splits_ready = split_handler.make_splits(0, 1, class_id)[0] + are_splits_ready = split_handler.make_splits( + np.int64(0), np.int64(1), class_id)[0] + with ops.control_dependencies([are_splits_ready]): update_2 = split_handler.update_stats_sync( 1, @@ -240,7 +246,7 @@ class DenseSplitHandlerTest(test_util.TensorFlowTestCase): is_active=array_ops.constant([True, True])) with ops.control_dependencies([update_2]): are_splits_ready2, partitions, gains, splits = ( - split_handler.make_splits(1, 2, class_id)) + split_handler.make_splits(np.int64(1), np.int64(2), class_id)) are_splits_ready, are_splits_ready2, partitions, gains, splits = ( sess.run([ are_splits_ready, are_splits_ready2, partitions, gains, splits @@ -313,7 +319,8 @@ class DenseSplitHandlerTest(test_util.TensorFlowTestCase): example_weights, is_active=array_ops.constant([True, True])) with ops.control_dependencies([update_1]): - are_splits_ready = split_handler.make_splits(0, 1, class_id)[0] + are_splits_ready = split_handler.make_splits( + np.int64(0), np.int64(1), class_id)[0] with ops.control_dependencies([are_splits_ready]): update_2 = split_handler.update_stats_sync( 1, @@ -326,7 +333,7 @@ class DenseSplitHandlerTest(test_util.TensorFlowTestCase): is_active=array_ops.constant([True, True])) with ops.control_dependencies([update_2]): are_splits_ready2, partitions, gains, splits = ( - split_handler.make_splits(1, 2, class_id)) + split_handler.make_splits(np.int64(1), np.int64(2), class_id)) are_splits_ready, are_splits_ready2, partitions, gains, splits = ( sess.run([ are_splits_ready, are_splits_ready2, partitions, gains, splits @@ -396,7 +403,8 @@ class DenseSplitHandlerTest(test_util.TensorFlowTestCase): example_weights, is_active=array_ops.constant([True, False])) with ops.control_dependencies([update_1]): - are_splits_ready = split_handler.make_splits(0, 1, class_id)[0] + are_splits_ready = split_handler.make_splits( + np.int64(0), np.int64(1), class_id)[0] with ops.control_dependencies([are_splits_ready]): update_2 = split_handler.update_stats_sync( 1, @@ -409,7 +417,7 @@ class DenseSplitHandlerTest(test_util.TensorFlowTestCase): is_active=array_ops.constant([False, True])) with ops.control_dependencies([update_2]): are_splits_ready2, partitions, gains, splits = ( - split_handler.make_splits(1, 2, class_id)) + split_handler.make_splits(np.int64(1), np.int64(2), class_id)) are_splits_ready, are_splits_ready2, partitions, gains, splits = ( sess.run([ are_splits_ready, are_splits_ready2, partitions, gains, splits @@ -470,7 +478,8 @@ class DenseSplitHandlerTest(test_util.TensorFlowTestCase): example_weights, is_active=array_ops.constant([True, True])) with ops.control_dependencies([update_1]): - are_splits_ready = split_handler.make_splits(0, 1, class_id)[0] + are_splits_ready = split_handler.make_splits( + np.int64(0), np.int64(1), class_id)[0] with ops.control_dependencies([are_splits_ready]): update_2 = split_handler.update_stats_sync( 1, @@ -483,7 +492,7 @@ class DenseSplitHandlerTest(test_util.TensorFlowTestCase): is_active=array_ops.constant([True, True])) with ops.control_dependencies([update_2]): are_splits_ready2, partitions, gains, splits = ( - split_handler.make_splits(1, 2, class_id)) + split_handler.make_splits(np.int64(1), np.int64(2), class_id)) are_splits_ready, are_splits_ready2, partitions, gains, splits = ( sess.run([ are_splits_ready, are_splits_ready2, partitions, gains, splits @@ -603,7 +612,8 @@ class DenseSplitHandlerTest(test_util.TensorFlowTestCase): example_weights, is_active=array_ops.constant([True, True])) with ops.control_dependencies([update_1]): - are_splits_ready = split_handler.make_splits(0, 1, class_id)[0] + are_splits_ready = split_handler.make_splits( + np.int64(0), np.int64(1), class_id)[0] with ops.control_dependencies([are_splits_ready]): update_2 = split_handler.update_stats_sync( 1, @@ -616,7 +626,7 @@ class DenseSplitHandlerTest(test_util.TensorFlowTestCase): is_active=array_ops.constant([True, True])) with ops.control_dependencies([update_2]): are_splits_ready2, partitions, gains, splits = ( - split_handler.make_splits(1, 2, class_id)) + split_handler.make_splits(np.int64(1), np.int64(2), class_id)) are_splits_ready, are_splits_ready2, partitions, gains, splits = ( sess.run([ are_splits_ready, are_splits_ready2, partitions, gains, splits @@ -685,10 +695,10 @@ class SparseSplitHandlerTest(test_util.TensorFlowTestCase): class_id = -1 split_handler = ordinal_split_handler.SparseSplitHandler( - l1_regularization=0, - l2_regularization=2, - tree_complexity_regularization=0, - min_node_weight=0, + l1_regularization=0.0, + l2_regularization=2.0, + tree_complexity_regularization=0.0, + min_node_weight=0.0, epsilon=0.01, num_quantiles=2, feature_column_group_id=0, @@ -713,8 +723,8 @@ class SparseSplitHandlerTest(test_util.TensorFlowTestCase): example_weights, is_active=array_ops.constant([True, True])) with ops.control_dependencies([update_1]): - are_splits_ready = split_handler.make_splits(0, 1, class_id)[0] - + are_splits_ready = split_handler.make_splits( + np.int64(0), np.int64(1), class_id)[0] with ops.control_dependencies([are_splits_ready]): update_2 = split_handler.update_stats_sync( 1, @@ -727,7 +737,7 @@ class SparseSplitHandlerTest(test_util.TensorFlowTestCase): is_active=array_ops.constant([True, True])) with ops.control_dependencies([update_2]): are_splits_ready2, partitions, gains, splits = ( - split_handler.make_splits(1, 2, class_id)) + split_handler.make_splits(np.int64(1), np.int64(2), class_id)) are_splits_ready, are_splits_ready2, partitions, gains, splits = ( sess.run([ are_splits_ready, are_splits_ready2, partitions, gains, splits @@ -811,10 +821,10 @@ class SparseSplitHandlerTest(test_util.TensorFlowTestCase): class_id = -1 split_handler = ordinal_split_handler.SparseSplitHandler( - l1_regularization=0, - l2_regularization=2, - tree_complexity_regularization=0, - min_node_weight=0, + l1_regularization=0.0, + l2_regularization=2.0, + tree_complexity_regularization=0.0, + min_node_weight=0.0, epsilon=0.01, num_quantiles=2, feature_column_group_id=0, @@ -839,7 +849,8 @@ class SparseSplitHandlerTest(test_util.TensorFlowTestCase): example_weights, is_active=array_ops.constant([True, True])) with ops.control_dependencies([update_1]): - are_splits_ready = split_handler.make_splits(0, 1, class_id)[0] + are_splits_ready = split_handler.make_splits( + np.int64(0), np.int64(1), class_id)[0] with ops.control_dependencies([are_splits_ready]): update_2 = split_handler.update_stats_sync( @@ -853,7 +864,7 @@ class SparseSplitHandlerTest(test_util.TensorFlowTestCase): is_active=array_ops.constant([True, True])) with ops.control_dependencies([update_2]): are_splits_ready2, partitions, gains, splits = ( - split_handler.make_splits(1, 2, class_id)) + split_handler.make_splits(np.int64(1), np.int64(2), class_id)) are_splits_ready, are_splits_ready2, partitions, gains, splits = ( sess.run([ are_splits_ready, are_splits_ready2, partitions, gains, splits @@ -905,10 +916,10 @@ class SparseSplitHandlerTest(test_util.TensorFlowTestCase): class_id = -1 split_handler = ordinal_split_handler.SparseSplitHandler( - l1_regularization=0, - l2_regularization=2, - tree_complexity_regularization=0, - min_node_weight=0, + l1_regularization=0.0, + l2_regularization=2.0, + tree_complexity_regularization=0.0, + min_node_weight=0.0, epsilon=0.01, num_quantiles=2, feature_column_group_id=0, @@ -933,7 +944,8 @@ class SparseSplitHandlerTest(test_util.TensorFlowTestCase): example_weights, is_active=array_ops.constant([True, True])) with ops.control_dependencies([update_1]): - are_splits_ready = split_handler.make_splits(0, 1, class_id)[0] + are_splits_ready = split_handler.make_splits( + np.int64(0), np.int64(1), class_id)[0] with ops.control_dependencies([are_splits_ready]): update_2 = split_handler.update_stats_sync( @@ -947,7 +959,7 @@ class SparseSplitHandlerTest(test_util.TensorFlowTestCase): is_active=array_ops.constant([True, True])) with ops.control_dependencies([update_2]): are_splits_ready2, partitions, gains, splits = ( - split_handler.make_splits(1, 2, class_id)) + split_handler.make_splits(np.int64(1), np.int64(2), class_id)) are_splits_ready, are_splits_ready2, partitions, gains, splits = ( sess.run([ are_splits_ready, are_splits_ready2, partitions, gains, splits @@ -996,10 +1008,10 @@ class SparseSplitHandlerTest(test_util.TensorFlowTestCase): class_id = -1 split_handler = ordinal_split_handler.SparseSplitHandler( - l1_regularization=0, - l2_regularization=2, - tree_complexity_regularization=0, - min_node_weight=0, + l1_regularization=0.0, + l2_regularization=2.0, + tree_complexity_regularization=0.0, + min_node_weight=0.0, epsilon=0.01, num_quantiles=2, feature_column_group_id=0, @@ -1024,7 +1036,8 @@ class SparseSplitHandlerTest(test_util.TensorFlowTestCase): example_weights, is_active=array_ops.constant([True, False])) with ops.control_dependencies([update_1]): - are_splits_ready = split_handler.make_splits(0, 1, class_id)[0] + are_splits_ready = split_handler.make_splits( + np.int64(0), np.int64(1), class_id)[0] with ops.control_dependencies([are_splits_ready]): update_2 = split_handler.update_stats_sync( @@ -1038,7 +1051,7 @@ class SparseSplitHandlerTest(test_util.TensorFlowTestCase): is_active=array_ops.constant([False, True])) with ops.control_dependencies([update_2]): are_splits_ready2, partitions, gains, splits = ( - split_handler.make_splits(1, 2, class_id)) + split_handler.make_splits(np.int64(1), np.int64(2), class_id)) are_splits_ready, are_splits_ready2, partitions, gains, splits = ( sess.run([ are_splits_ready, are_splits_ready2, partitions, gains, splits @@ -1065,10 +1078,10 @@ class SparseSplitHandlerTest(test_util.TensorFlowTestCase): class_id = -1 split_handler = ordinal_split_handler.SparseSplitHandler( - l1_regularization=0, - l2_regularization=2, - tree_complexity_regularization=0, - min_node_weight=0, + l1_regularization=0.0, + l2_regularization=2.0, + tree_complexity_regularization=0.0, + min_node_weight=0.0, epsilon=0.01, num_quantiles=2, feature_column_group_id=0, @@ -1096,7 +1109,8 @@ class SparseSplitHandlerTest(test_util.TensorFlowTestCase): example_weights, is_active=array_ops.constant([True, True])) with ops.control_dependencies([update_1]): - are_splits_ready = split_handler.make_splits(0, 1, class_id)[0] + are_splits_ready = split_handler.make_splits( + np.int64(0), np.int64(1), class_id)[0] with ops.control_dependencies([are_splits_ready]): update_2 = split_handler.update_stats_sync( @@ -1110,7 +1124,7 @@ class SparseSplitHandlerTest(test_util.TensorFlowTestCase): is_active=array_ops.constant([True, True])) with ops.control_dependencies([update_2]): are_splits_ready2, partitions, gains, splits = ( - split_handler.make_splits(1, 2, class_id)) + split_handler.make_splits(np.int64(1), np.int64(2), class_id)) are_splits_ready, are_splits_ready2, partitions, gains, splits = ( sess.run([ are_splits_ready, are_splits_ready2, partitions, gains, splits @@ -1138,10 +1152,10 @@ class SparseSplitHandlerTest(test_util.TensorFlowTestCase): class_id = -1 split_handler = ordinal_split_handler.SparseSplitHandler( - l1_regularization=0, - l2_regularization=2, - tree_complexity_regularization=0, - min_node_weight=0, + l1_regularization=0.0, + l2_regularization=2.0, + tree_complexity_regularization=0.0, + min_node_weight=0.0, epsilon=0.01, num_quantiles=2, feature_column_group_id=0, @@ -1166,7 +1180,8 @@ class SparseSplitHandlerTest(test_util.TensorFlowTestCase): example_weights, is_active=array_ops.constant([True, True])) with ops.control_dependencies([update_1]): - are_splits_ready = split_handler.make_splits(0, 1, class_id)[0] + are_splits_ready = split_handler.make_splits( + np.int64(0), np.int64(1), class_id)[0] with ops.control_dependencies([are_splits_ready]): update_2 = split_handler.update_stats_sync( @@ -1180,7 +1195,7 @@ class SparseSplitHandlerTest(test_util.TensorFlowTestCase): is_active=array_ops.constant([True, True])) with ops.control_dependencies([update_2]): are_splits_ready2, partitions, gains, splits = ( - split_handler.make_splits(1, 2, class_id)) + split_handler.make_splits(np.int64(1), np.int64(2), class_id)) are_splits_ready, are_splits_ready2, partitions, gains, splits = ( sess.run([ are_splits_ready, are_splits_ready2, partitions, gains, splits diff --git a/tensorflow/contrib/boosted_trees/ops/split_handler_ops.cc b/tensorflow/contrib/boosted_trees/ops/split_handler_ops.cc index 5d0ebbf..ca5c7f3 100644 --- a/tensorflow/contrib/boosted_trees/ops/split_handler_ops.cc +++ b/tensorflow/contrib/boosted_trees/ops/split_handler_ops.cc @@ -23,12 +23,6 @@ using shape_inference::InferenceContext; using shape_inference::ShapeHandle; REGISTER_OP("BuildDenseInequalitySplits") - .Attr("feature_column_group_id: int") - .Attr("l1_regularization: float") - .Attr("l2_regularization: float") - .Attr("tree_complexity_regularization: float") - .Attr("min_node_weight: float") - .Attr("multiclass_strategy: int") .Input("num_minibatches: int64") .Input("partition_ids: int32") .Input("bucket_ids: int64") @@ -36,6 +30,12 @@ REGISTER_OP("BuildDenseInequalitySplits") .Input("hessians: float32") .Input("bucket_boundaries: float32") .Input("class_id: int32") + .Input("feature_column_group_id: int32") + .Input("l1_regularization: float") + .Input("l2_regularization: float") + .Input("tree_complexity_regularization: float") + .Input("min_node_weight: float") + .Input("multiclass_strategy: int32") .Output("output_partition_ids: int32") .Output("gains: float32") .Output("split_infos: string") @@ -73,6 +73,17 @@ bucket_ids: A rank 2 tensor of buckets IDs and dimensions. gradients: A rank 1 tensor of gradients. hessians: A rank 1 tensor of hessians. bucket_boundaries: A rank 1 tensor, thresholds that were used for bucketization. +class_id: A scalar, the class id for which we're building the splits. +feature_column_group_id: A scalar, the index of the feature we are spiltting on. +l1_regularization: A scalar, which specifies the l1 regularization term. +l2_regularization: A scalar, which specifies the l2 regularization term. +tree_complexity_regularization: A scalar, which specifies the tree complexity + regularization term. +min_node_weight: A scalar, minimum sum of example hessian needed in a child. + If a split results in a leaf node with a smaller value, the split will not + be considered. +multiclass_strategy: A scalar, specifying the multiclass handling strategy. + See LearnerConfig.MultiClassStrategy for valid values. output_partition_ids: A rank 1 tensor, the partition IDs that we created splits for. gains: A rank 1 tensor, for the computed gain for the created splits. @@ -81,13 +92,6 @@ split_infos: A rank 1 tensor of serialized protos which contains the )doc"); REGISTER_OP("BuildSparseInequalitySplits") - .Attr("feature_column_group_id: int") - .Attr("bias_feature_id: int") - .Attr("l1_regularization: float") - .Attr("l2_regularization: float") - .Attr("tree_complexity_regularization: float") - .Attr("min_node_weight: float") - .Attr("multiclass_strategy: int") .Input("num_minibatches: int64") .Input("partition_ids: int32") .Input("bucket_ids: int64") @@ -95,6 +99,13 @@ REGISTER_OP("BuildSparseInequalitySplits") .Input("hessians: float32") .Input("bucket_boundaries: float32") .Input("class_id: int32") + .Input("feature_column_group_id: int32") + .Input("bias_feature_id: int64") + .Input("l1_regularization: float") + .Input("l2_regularization: float") + .Input("tree_complexity_regularization: float") + .Input("min_node_weight: float") + .Input("multiclass_strategy: int32") .Output("output_partition_ids: int32") .Output("gains: float32") .Output("split_infos: string") @@ -133,6 +144,17 @@ bucket_ids: A rank 2 tensor of buckets IDs and dimensions. gradients: A rank 1 tensor of gradients. hessians: A rank 1 tensor of hessians. bucket_boundaries: A rank 1 tensor, thresholds that were used for bucketization. +class_id: A scalar, the class id for which we're building the splits. +feature_column_group_id: A scalar, the index of the feature we are spiltting on. +l1_regularization: A scalar, which specifies the l1 regularization term. +l2_regularization: A scalar, which specifies the l2 regularization term. +tree_complexity_regularization: A scalar, which specifies the tree complexity + regularization term. +min_node_weight: A scalar, minimum sum of example hessian needed in a child. + If a split results in a leaf node with a smaller value, the split will not + be considered. +multiclass_strategy: A scalar, specifying the multiclass handling strategy. + See LearnerConfig.MultiClassStrategy for valid values. output_partition_ids: A rank 1 tensor, the partition IDs that we created splits for. gains: A rank 1 tensor, for the computed gain for the created splits. @@ -141,19 +163,19 @@ split_infos: A rank 1 tensor of serialized protos which contains the )doc"); REGISTER_OP("BuildCategoricalEqualitySplits") - .Attr("feature_column_group_id: int") - .Attr("bias_feature_id: int") - .Attr("l1_regularization: float") - .Attr("l2_regularization: float") - .Attr("tree_complexity_regularization: float") - .Attr("min_node_weight: float") - .Attr("multiclass_strategy: int") .Input("num_minibatches: int64") .Input("partition_ids: int32") .Input("feature_ids: int64") .Input("gradients: float32") .Input("hessians: float32") .Input("class_id: int32") + .Input("feature_column_group_id: int32") + .Input("bias_feature_id: int64") + .Input("l1_regularization: float") + .Input("l2_regularization: float") + .Input("tree_complexity_regularization: float") + .Input("min_node_weight: float") + .Input("multiclass_strategy: int32") .Output("output_partition_ids: int32") .Output("gains: float32") .Output("split_infos: string") @@ -188,6 +210,17 @@ partition_ids: A rank 1 tensor of partition IDs. feature_ids: A rank 2 tensor of feature IDs and dimensions. gradients: A rank 1 tensor of gradients. hessians: A rank 1 tensor of hessians. +class_id: A scalar, the class id for which we're building the splits. +feature_column_group_id: A scalar, the index of the feature we are spiltting on. +l1_regularization: A scalar, which specifies the l1 regularization term. +l2_regularization: A scalar, which specifies the l2 regularization term. +tree_complexity_regularization: A scalar, which specifies the tree complexity + regularization term. +min_node_weight: A scalar, minimum sum of example hessian needed in a child. + If a split results in a leaf node with a smaller value, the split will not + be considered. +multiclass_strategy: A scalar, specifying the multiclass handling strategy. + See LearnerConfig.MultiClassStrategy for valid values. output_partition_ids: A rank 1 tensor, the partition IDs that we created splits for. gains: A rank 1 tensor, for the computed gain for the created splits. @@ -196,4 +229,3 @@ split_infos: A rank 1 tensor of serialized protos which contains the )doc"); } // namespace tensorflow - // namespace tensorflow diff --git a/tensorflow/contrib/boosted_trees/python/ops/batch_ops_utils.py b/tensorflow/contrib/boosted_trees/python/ops/batch_ops_utils.py index 7a5f329..8434209 100644 --- a/tensorflow/contrib/boosted_trees/python/ops/batch_ops_utils.py +++ b/tensorflow/contrib/boosted_trees/python/ops/batch_ops_utils.py @@ -20,6 +20,8 @@ from __future__ import print_function import abc import collections +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape from tensorflow.python.ops import array_ops @@ -60,6 +62,7 @@ def _move_tensors(tensors, device): """Moves a list of tensors to a device by concatenating/splitting them.""" # Reset the device setting to avoid weird interactions with device merging # logic. + zero = constant_op.constant(0, dtype=dtypes.int32) with ops.device(None): if all(tensor.shape == tensor_shape.scalar() for tensor in tensors): with ops.device(tensors[0].device): @@ -68,12 +71,11 @@ def _move_tensors(tensors, device): return array_ops.unstack(values) else: with ops.device(tensors[0].device): - sizes = array_ops.stack( - [array_ops.shape(tensor)[0] for tensor in tensors]) - values = array_ops.concat(tensors, axis=0) + sizes = array_ops.stack(array_ops.shape_n(tensors))[:, 0] + values = array_ops.concat(tensors, axis=zero) with ops.device(device): sizes = array_ops.unstack(sizes) - return list(array_ops.split(values, sizes, axis=0)) + return list(array_ops.split(values, sizes, axis=zero)) def _scheduled_stamp_resource_op_runner(batch, stamp): diff --git a/tensorflow/contrib/boosted_trees/python/ops/quantile_ops.py b/tensorflow/contrib/boosted_trees/python/ops/quantile_ops.py index 50cc00a..19b6b32 100644 --- a/tensorflow/contrib/boosted_trees/python/ops/quantile_ops.py +++ b/tensorflow/contrib/boosted_trees/python/ops/quantile_ops.py @@ -201,3 +201,6 @@ class QuantileAccumulator(saver.BaseSaverBuilder.SaveableObject): stamp_token=stamp_token, next_stamp_token=next_stamp_token) return result + + def resource(self): + return self._quantile_accumulator_handle diff --git a/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py b/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py index 08c1dcd..c725f32 100644 --- a/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py +++ b/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py @@ -180,8 +180,7 @@ def extract_features(features, feature_columns, use_core_columns): elif isinstance(fc, feature_column_lib._EmbeddingColumn): # pylint: enable=protected-access transformed_features[fc.name] = fc_core.input_layer( - features, [fc], - weight_collections=[scope]) + features, [fc], weight_collections=[scope]) else: result = feature_column_ops.transform_features(features, [fc]) if len(result) > 1: @@ -334,10 +333,12 @@ class GradientBoostedDecisionTreeModel(object): self._feature_columns = feature_columns self._learner_config_serialized = learner_config.SerializeToString() self._attempted_trees = variables.Variable( - initial_value=array_ops.zeros([], dtypes.int64), trainable=False, + initial_value=array_ops.zeros([], dtypes.int64), + trainable=False, name="attempted_trees") self._finalized_trees = variables.Variable( - initial_value=array_ops.zeros([], dtypes.int64), trainable=False, + initial_value=array_ops.zeros([], dtypes.int64), + trainable=False, name="finalized_trees") if not features: raise ValueError("Features dictionary must be specified.") @@ -354,9 +355,10 @@ class GradientBoostedDecisionTreeModel(object): self._sparse_int_indices = sparse_int_indices self._sparse_int_values = sparse_int_values self._sparse_int_shapes = sparse_int_shapes - self._reduce_dim = (self._learner_config.multi_class_strategy == - learner_pb2.LearnerConfig.TREE_PER_CLASS and - learner_config.num_classes == 2) + self._reduce_dim = ( + self._learner_config.multi_class_strategy == + learner_pb2.LearnerConfig.TREE_PER_CLASS and + learner_config.num_classes == 2) def _predict_and_return_dict(self, ensemble_handle, ensemble_stamp, mode): """Runs prediction and returns a dictionary of the prediction results. @@ -374,8 +376,8 @@ class GradientBoostedDecisionTreeModel(object): ensemble_stats = training_ops.tree_ensemble_stats(ensemble_handle, ensemble_stamp) num_handlers = ( - len(self._dense_floats) + len(self._sparse_float_shapes) + - len(self._sparse_int_shapes)) + len(self._dense_floats) + len(self._sparse_float_shapes) + len( + self._sparse_int_shapes)) # Used during feature selection. used_handlers = model_ops.tree_ensemble_used_handlers( ensemble_handle, ensemble_stamp, num_all_handlers=num_handlers) @@ -432,8 +434,9 @@ class GradientBoostedDecisionTreeModel(object): # Use the current ensemble to predict on the current batch of input. # For faster prediction we check if the inputs are on the same device # as the model. If not, we create a copy of the model on the worker. - input_deps = (self._dense_floats + self._sparse_float_indices + - self._sparse_int_indices) + input_deps = ( + self._dense_floats + self._sparse_float_indices + + self._sparse_int_indices) if not input_deps: raise ValueError("No input tensors for prediction.") @@ -500,8 +503,9 @@ class GradientBoostedDecisionTreeModel(object): ValueError: if inputs are not valid. """ # Get the worker device from input dependencies. - input_deps = (self._dense_floats + self._sparse_float_indices + - self._sparse_int_indices) + input_deps = ( + self._dense_floats + self._sparse_float_indices + + self._sparse_int_indices) worker_device = input_deps[0].device # Get tensors relevant for training and form the loss. @@ -517,7 +521,7 @@ class GradientBoostedDecisionTreeModel(object): aggregation_method=None)[0] strategy = self._learner_config.multi_class_strategy - class_id = -1 + class_id = constant_op.constant(-1, dtype=dtypes.int32) # Handle different multiclass strategies. if strategy == learner_pb2.LearnerConfig.TREE_PER_CLASS: # We build one vs rest trees. @@ -571,31 +575,39 @@ class GradientBoostedDecisionTreeModel(object): # Get the weights for each example for quantiles calculation, weights = self._get_weights(hessian_shape, squeezed_hessians) - regularization_config = self._learner_config.regularization - min_node_weight = self._learner_config.constraints.min_node_weight # Create all handlers ensuring resources are evenly allocated across PS. fc_name_idx = 0 handlers = [] init_stamp_token = constant_op.constant(0, dtype=dtypes.int64) + l1_regularization = constant_op.constant( + self._learner_config.regularization.l1, dtypes.float32) + l2_regularization = constant_op.constant( + self._learner_config.regularization.l2, dtypes.float32) + tree_complexity_regularization = constant_op.constant( + self._learner_config.regularization.tree_complexity, dtypes.float32) + min_node_weight = constant_op.constant( + self._learner_config.constraints.min_node_weight, dtypes.float32) + epsilon = 0.01 + num_quantiles = 100 + strategy_tensor = constant_op.constant(strategy) with ops.device(self._get_replica_device_setter(worker_device)): # Create handlers for dense float columns for dense_float_column_idx in range(len(self._dense_floats)): fc_name = self._fc_names[fc_name_idx] handlers.append( ordinal_split_handler.DenseSplitHandler( - l1_regularization=regularization_config.l1, - l2_regularization=regularization_config.l2, - tree_complexity_regularization=( - regularization_config.tree_complexity), + l1_regularization=l1_regularization, + l2_regularization=l2_regularization, + tree_complexity_regularization=tree_complexity_regularization, min_node_weight=min_node_weight, feature_column_group_id=dense_float_column_idx, - epsilon=0.01, - num_quantiles=100, + epsilon=epsilon, + num_quantiles=num_quantiles, dense_float_column=self._dense_floats[dense_float_column_idx], name=fc_name, gradient_shape=gradient_shape, hessian_shape=hessian_shape, - multiclass_strategy=strategy, + multiclass_strategy=strategy_tensor, init_stamp_token=init_stamp_token)) fc_name_idx += 1 @@ -604,14 +616,13 @@ class GradientBoostedDecisionTreeModel(object): fc_name = self._fc_names[fc_name_idx] handlers.append( ordinal_split_handler.SparseSplitHandler( - l1_regularization=regularization_config.l1, - l2_regularization=regularization_config.l2, - tree_complexity_regularization=( - regularization_config.tree_complexity), + l1_regularization=l1_regularization, + l2_regularization=l2_regularization, + tree_complexity_regularization=tree_complexity_regularization, min_node_weight=min_node_weight, feature_column_group_id=sparse_float_column_idx, - epsilon=0.01, - num_quantiles=100, + epsilon=epsilon, + num_quantiles=num_quantiles, sparse_float_column=sparse_tensor.SparseTensor( self._sparse_float_indices[sparse_float_column_idx], self._sparse_float_values[sparse_float_column_idx], @@ -619,7 +630,7 @@ class GradientBoostedDecisionTreeModel(object): name=fc_name, gradient_shape=gradient_shape, hessian_shape=hessian_shape, - multiclass_strategy=strategy, + multiclass_strategy=strategy_tensor, init_stamp_token=init_stamp_token)) fc_name_idx += 1 @@ -628,10 +639,9 @@ class GradientBoostedDecisionTreeModel(object): fc_name = self._fc_names[fc_name_idx] handlers.append( categorical_split_handler.EqualitySplitHandler( - l1_regularization=regularization_config.l1, - l2_regularization=regularization_config.l2, - tree_complexity_regularization=( - regularization_config.tree_complexity), + l1_regularization=l1_regularization, + l2_regularization=l2_regularization, + tree_complexity_regularization=tree_complexity_regularization, min_node_weight=min_node_weight, feature_column_group_id=sparse_int_column_idx, sparse_int_column=sparse_tensor.SparseTensor( @@ -641,7 +651,7 @@ class GradientBoostedDecisionTreeModel(object): name=fc_name, gradient_shape=gradient_shape, hessian_shape=hessian_shape, - multiclass_strategy=strategy, + multiclass_strategy=strategy_tensor, init_stamp_token=init_stamp_token)) fc_name_idx += 1 @@ -694,11 +704,11 @@ class GradientBoostedDecisionTreeModel(object): name="continue_centering", trainable=False) stats_update_ops.append( - control_flow_ops.cond(continue_centering, - self._make_update_bias_stats_fn( - ensemble_stamp, predictions, gradients, - bias_stats_accumulator), - control_flow_ops.no_op)) + control_flow_ops.cond( + continue_centering, + self._make_update_bias_stats_fn(ensemble_stamp, predictions, + gradients, bias_stats_accumulator), + control_flow_ops.no_op)) # Update handler stats. handler_reads = collections.OrderedDict() @@ -720,8 +730,8 @@ class GradientBoostedDecisionTreeModel(object): shape=[len(handlers)], seed=[seed + 1, 1]) active_handlers = array_ops.stack( [active_handlers_current_layer, active_handlers_next_layer], axis=1) - active_handlers = (active_handlers < - self._learner_config.feature_fraction_per_level) + active_handlers = ( + active_handlers < self._learner_config.feature_fraction_per_level) elif subsampling_type == "feature_fraction_per_tree": seed = predictions_dict[NUM_TREES_ATTEMPTED] active_handlers_current_layer = stateless.stateless_random_uniform( @@ -729,9 +739,12 @@ class GradientBoostedDecisionTreeModel(object): active_handlers_current_layer = ( active_handlers_current_layer < self._learner_config.feature_fraction_per_tree) - active_handlers = array_ops.stack([ - active_handlers_current_layer, - array_ops.ones([len(handlers)], dtype=dtypes.bool)], axis=1) + active_handlers = array_ops.stack( + [ + active_handlers_current_layer, + array_ops.ones([len(handlers)], dtype=dtypes.bool) + ], + axis=1) else: active_handlers = array_ops.ones([len(handlers), 2], dtype=dtypes.bool) @@ -760,6 +773,7 @@ class GradientBoostedDecisionTreeModel(object): empty_hessians = constant_op.constant( [], dtype=dtypes.float32, shape=empty_hess_shape) + active_handlers = array_ops.unstack(active_handlers, axis=0) for handler_idx in range(len(handlers)): handler = handlers[handler_idx] is_active = active_handlers[handler_idx] @@ -971,7 +985,7 @@ class GradientBoostedDecisionTreeModel(object): # This is a workaround for the slowness of graph building in tf.cond. # See (b/36554864). split_sizes = array_ops.reshape( - array_ops.shape_n(partition_ids_list), [-1]) + array_ops.shape_n(partition_ids_list), [len(partition_ids_list)]) partition_ids = array_ops.concat(partition_ids_list, axis=0) gains = array_ops.concat(gains_list, axis=0) split_infos = array_ops.concat(split_info_list, axis=0) @@ -1036,8 +1050,11 @@ class GradientBoostedDecisionTreeModel(object): # Update ensemble. update_ops = [are_all_splits_ready] - update_model = control_flow_ops.cond(continue_centering, _center_bias_fn, - _grow_ensemble_fn) + if self._center_bias: + update_model = control_flow_ops.cond(continue_centering, + _center_bias_fn, _grow_ensemble_fn) + else: + update_model = _grow_ensemble_fn() update_ops.append(update_model) # Update ensemble stats.