const int32 DUMMY_FEATURE_DIMENSION = -1;
} // namespace
-class BaseBuildSplitOp : public OpKernel {
+class SplitBuilderState {
public:
- explicit BaseBuildSplitOp(OpKernelConstruction* const context)
- : OpKernel(context) {
- OP_REQUIRES_OK(context, context->GetAttr("feature_column_group_id",
- &feature_column_group_id_));
+ explicit SplitBuilderState(OpKernelContext* const context) {  // Reads every split hyperparameter from the op's named inputs instead of attrs.
+ const Tensor* l1_regularization_t;  // NOTE(review): presumably OP_REQUIRES_OK returns early on failure, which would leave later members unset while the caller keeps using this object - confirm.
OP_REQUIRES_OK(context,
- context->GetAttr("l1_regularization", &l1_regularization_));
+ context->input("l1_regularization", &l1_regularization_t));
+ const Tensor* l2_regularization_t;
OP_REQUIRES_OK(context,
- context->GetAttr("l2_regularization", &l2_regularization_));
- OP_REQUIRES_OK(context, context->GetAttr("tree_complexity_regularization",
- &tree_complexity_regularization_));
+ context->input("l2_regularization", &l2_regularization_t));
+ const Tensor* tree_complexity_regularization_t;
+ OP_REQUIRES_OK(context, context->input("tree_complexity_regularization",
+ &tree_complexity_regularization_t));
+ const Tensor* min_node_weight_t;
OP_REQUIRES_OK(context,
- context->GetAttr("min_node_weight", &min_node_weight_));
+ context->input("min_node_weight", &min_node_weight_t));
- int strategy;
- OP_REQUIRES_OK(context, context->GetAttr("multiclass_strategy", &strategy));
+ const Tensor* feature_column_group_id_t;
+ OP_REQUIRES_OK(context, context->input("feature_column_group_id",
+ &feature_column_group_id_t));
+
+ const Tensor* multiclass_strategy_t;
+ OP_REQUIRES_OK(
+ context, context->input("multiclass_strategy", &multiclass_strategy_t));
+ int strategy = multiclass_strategy_t->scalar<int32>()();  // NOTE(review): read without an IsScalar shape check, unlike class_id below.
OP_REQUIRES(
context,
boosted_trees::learner::LearnerConfig_MultiClassStrategy_IsValid(
strategy),
errors::InvalidArgument("Wrong multiclass strategy passed."));
- multiclass_strategy_ = LearnerConfig_MultiClassStrategy(strategy);
- }
- NodeStats ComputeNodeStats(const GradientStats& grad_stats) {
- return NodeStats(l1_regularization_, l2_regularization_, min_node_weight_,
- multiclass_strategy_, grad_stats);
- }
+ multiclass_strategy_ = LearnerConfig_MultiClassStrategy(strategy);  // Converted to the enum only after the IsValid check above.
- void ReadClassId(OpKernelContext* const context, int32* class_id) {
const Tensor* class_id_t;
OP_REQUIRES_OK(context, context->input("class_id", &class_id_t));
OP_REQUIRES(context, TensorShapeUtils::IsScalar(class_id_t->shape()),
errors::InvalidArgument("class_id must be a scalar."));
- *class_id = class_id_t->scalar<int32>()();
+ class_id_ = class_id_t->scalar<int32>()();  // class_id is now stored on the state object instead of being returned through an out-parameter.
+
+ l1_regularization_ = l1_regularization_t->scalar<float>()();  // NOTE(review): the four float inputs and the group id below are also read without shape checks.
+ l2_regularization_ = l2_regularization_t->scalar<float>()();
+ tree_complexity_regularization_ =
+ tree_complexity_regularization_t->scalar<float>()();
+ min_node_weight_ = min_node_weight_t->scalar<float>()();
+ feature_column_group_id_ = feature_column_group_id_t->scalar<int32>()();
+ }
+
+ NodeStats ComputeNodeStats(const GradientStats& grad_stats) {  // Builds NodeStats for grad_stats from the regularization settings captured by the constructor.
+ return NodeStats(l1_regularization_, l2_regularization_, min_node_weight_,
+ multiclass_strategy_, grad_stats);
}
- void FillLeaf(const int class_id, const NodeStats& best_node_stats,
+ void FillLeaf(const NodeStats& best_node_stats,
boosted_trees::trees::Leaf* leaf) const {
- if (class_id == -1) {
+ if (class_id_ == -1) {
// This would be the case either for TREE_PER_CLASS with only 2 classes,
// or for other multiclass strategies.
for (float f : best_node_stats.weight_contribution) {
CHECK(best_node_stats.weight_contribution.size() == 1)
<< "Weight contribution size = "
<< best_node_stats.weight_contribution.size();
- leaf->mutable_sparse_vector()->add_index(class_id);
+ leaf->mutable_sparse_vector()->add_index(class_id_);
leaf->mutable_sparse_vector()->add_value(
best_node_stats.weight_contribution[0]);
}
}
- protected:
+ // Returns the feature column group id read from the op input.
+ int32 feature_column_group_id() const { return feature_column_group_id_; }
+ // Returns the tree complexity regularization term read from the op input.
+ float tree_complexity_regularization() const {
+ return tree_complexity_regularization_;
+ }
+
+ private:
LearnerConfig_MultiClassStrategy multiclass_strategy_;
- int32 feature_column_group_id_;
float l1_regularization_;
float l2_regularization_;
- float min_node_weight_;
float tree_complexity_regularization_;
+ float min_node_weight_;
+ int32 class_id_;
+ int32 feature_column_group_id_;
};
-class BuildDenseInequalitySplitsOp : public BaseBuildSplitOp {
+class BuildDenseInequalitySplitsOp : public OpKernel {
public:
explicit BuildDenseInequalitySplitsOp(OpKernelConstruction* const context)
- : BaseBuildSplitOp(context) {}
+ : OpKernel(context) {}  // Nothing is read at construction; Compute builds a SplitBuilderState from the op inputs.
void Compute(OpKernelContext* const context) override {
const Tensor* num_minibatches_t;
const Tensor* hessians_t;
OP_REQUIRES_OK(context, context->input("hessians", &hessians_t));
- int class_id;
- ReadClassId(context, &class_id);
-
// Find the number of unique partitions before we allocate the output.
std::vector<int32> partition_boundaries;
partition_boundaries.push_back(0);
&output_splits_t));
tensorflow::TTypes<string>::Vec output_splits =
output_splits_t->vec<string>();
+ SplitBuilderState state(context);
for (int root_idx = 0; root_idx < num_elements; ++root_idx) {
float best_gain = std::numeric_limits<float>::lowest();
int start_index = partition_boundaries[root_idx];
GradientStats(*gradients_t, *hessians_t, bucket_idx);
}
root_gradient_stats *= normalizer_ratio;
- NodeStats root_stats = ComputeNodeStats(root_gradient_stats);
+ NodeStats root_stats = state.ComputeNodeStats(root_gradient_stats);
int32 best_bucket_idx = 0;
NodeStats best_right_node_stats(0);
NodeStats best_left_node_stats(0);
GradientStats g(*gradients_t, *hessians_t, bucket_idx);
g *= normalizer_ratio;
left_gradient_stats += g;
- NodeStats left_stats = ComputeNodeStats(left_gradient_stats);
+ NodeStats left_stats = state.ComputeNodeStats(left_gradient_stats);
GradientStats right_gradient_stats =
root_gradient_stats - left_gradient_stats;
- NodeStats right_stats = ComputeNodeStats(right_gradient_stats);
+ NodeStats right_stats = state.ComputeNodeStats(right_gradient_stats);
if (left_stats.gain + right_stats.gain > best_gain) {
best_gain = left_stats.gain + right_stats.gain;
best_left_node_stats = left_stats;
SplitInfo split_info;
auto* dense_split =
split_info.mutable_split_node()->mutable_dense_float_binary_split();
- dense_split->set_feature_column(feature_column_group_id_);
+ dense_split->set_feature_column(state.feature_column_group_id());
dense_split->set_threshold(
bucket_boundaries(bucket_ids(best_bucket_idx, 0)));
auto* left_child = split_info.mutable_left_child();
auto* right_child = split_info.mutable_right_child();
- FillLeaf(class_id, best_left_node_stats, left_child);
- FillLeaf(class_id, best_right_node_stats, right_child);
+ state.FillLeaf(best_left_node_stats, left_child);
+ state.FillLeaf(best_right_node_stats, right_child);
split_info.SerializeToString(&output_splits(root_idx));
gains(root_idx) =
- best_gain - root_stats.gain - tree_complexity_regularization_;
+ best_gain - root_stats.gain - state.tree_complexity_regularization();
output_partition_ids(root_idx) = partition_ids(start_index);
}
}
REGISTER_KERNEL_BUILDER(Name("BuildDenseInequalitySplits").Device(DEVICE_CPU),
BuildDenseInequalitySplitsOp);
-class BuildSparseInequalitySplitsOp : public BaseBuildSplitOp {
+class BuildSparseInequalitySplitsOp : public OpKernel {
public:
explicit BuildSparseInequalitySplitsOp(OpKernelConstruction* const context)
- : BaseBuildSplitOp(context) {
- OP_REQUIRES_OK(context,
- context->GetAttr("bias_feature_id", &bias_feature_id_));
- }
+ : OpKernel(context) {}  // bias_feature_id is no longer an attr read here; Compute reads it from an input.
void Compute(OpKernelContext* const context) override {
const Tensor* num_minibatches_t;
const Tensor* hessians_t;
OP_REQUIRES_OK(context, context->input("hessians", &hessians_t));
- int class_id;
- ReadClassId(context, &class_id);
+ const Tensor* bias_feature_id_t;
+ OP_REQUIRES_OK(context,
+ context->input("bias_feature_id", &bias_feature_id_t));
+ int64 bias_feature_id = bias_feature_id_t->scalar<int64>()();
// For each partition (tree node), store starting index for each dimension.
PartitionAndDimensionBoundaries partition_boundaries;
&output_splits_t));
tensorflow::TTypes<string>::Vec output_splits =
output_splits_t->vec<string>();
+ SplitBuilderState state(context);
// For each tree node that needs to be split.
for (int root_idx = 0; root_idx < num_elements; ++root_idx) {
const auto& dimension_boundaries =
OP_REQUIRES(
context,
- bucket_ids_and_dimensions(bias_start_index, 0) == bias_feature_id_,
+ bucket_ids_and_dimensions(bias_start_index, 0) == bias_feature_id,
errors::InvalidArgument("Bias feature ID missing."));
// Dimension for bias feature is always 0
GradientStats root_gradient_stats(*gradients_t, *hessians_t,
bias_start_index);
root_gradient_stats *= normalizer_ratio;
- NodeStats root_stats = ComputeNodeStats(root_gradient_stats);
+ NodeStats root_stats = state.ComputeNodeStats(root_gradient_stats);
// Iterate through dimensions.
for (int j = 0; j < dimension_boundaries.size() - 1; ++j) {
<< bucket_ids_and_dimensions(start_index, 1) << " and for "
<< bucket_ids_and_dimensions(end_index - 1, 0) << " "
<< bucket_ids_and_dimensions(end_index - 1, 1);
- if (bucket_ids_and_dimensions(start_index, 0) == bias_feature_id_) {
+ if (bucket_ids_and_dimensions(start_index, 0) == bias_feature_id) {
// 0-dimension case which has a first bucket for catch all feature.
CHECK(bucket_ids_and_dimensions(start_index, 1) == 0)
<< "Dimension of bias feature should be 0";
present_gradient_stats - left_gradient_stats;
{
- NodeStats left_stats_default_left =
- ComputeNodeStats(root_gradient_stats - right_gradient_stats);
+ NodeStats left_stats_default_left = state.ComputeNodeStats(
+ root_gradient_stats - right_gradient_stats);
NodeStats right_stats_default_left =
- ComputeNodeStats(right_gradient_stats);
+ state.ComputeNodeStats(right_gradient_stats);
if (left_stats_default_left.gain + right_stats_default_left.gain >
best_gain) {
best_gain =
// enough missing examples.
if (!fixed_default_direction) {
NodeStats left_stats_default_right =
- ComputeNodeStats(left_gradient_stats);
- NodeStats right_stats_default_right =
- ComputeNodeStats(root_gradient_stats - left_gradient_stats);
+ state.ComputeNodeStats(left_gradient_stats);
+ NodeStats right_stats_default_right = state.ComputeNodeStats(
+ root_gradient_stats - left_gradient_stats);
if (left_stats_default_right.gain + right_stats_default_right.gain >
best_gain) {
best_gain = left_stats_default_right.gain +
->mutable_sparse_float_binary_split_default_left()
->mutable_split();
}
- dense_split->set_feature_column(feature_column_group_id_);
+ dense_split->set_feature_column(state.feature_column_group_id());
// Set the feature index for the best feature column.
const int64 best_dimension_id =
bucket_ids_and_dimensions(best_element_idx, 1);
auto* left_child = split_info.mutable_left_child();
auto* right_child = split_info.mutable_right_child();
- FillLeaf(class_id, best_left_node_stats, left_child);
- FillLeaf(class_id, best_right_node_stats, right_child);
+ state.FillLeaf(best_left_node_stats, left_child);
+ state.FillLeaf(best_right_node_stats, right_child);
split_info.SerializeToString(&output_splits(root_idx));
gains(root_idx) =
- best_gain - root_stats.gain - tree_complexity_regularization_;
+ best_gain - root_stats.gain - state.tree_complexity_regularization();
output_partition_ids(root_idx) = partition_ids(bias_start_index);
}
}
// For each partition, store start indices of feature column dimensions.
typedef std::vector<std::vector<DimensionBoundary>>
PartitionAndDimensionBoundaries;
-
- int64 bias_feature_id_;
};
REGISTER_KERNEL_BUILDER(Name("BuildSparseInequalitySplits").Device(DEVICE_CPU),
BuildSparseInequalitySplitsOp);
-class BuildCategoricalEqualitySplitsOp : public BaseBuildSplitOp {
+class BuildCategoricalEqualitySplitsOp : public OpKernel {
public:
explicit BuildCategoricalEqualitySplitsOp(OpKernelConstruction* const context)
- : BaseBuildSplitOp(context) {
- OP_REQUIRES_OK(context,
- context->GetAttr("bias_feature_id", &bias_feature_id_));
- }
+ : OpKernel(context) {}  // bias_feature_id is no longer an attr read here; Compute reads it from an input.
void Compute(OpKernelContext* const context) override {
const Tensor* num_minibatches_t;
const Tensor* hessians_t;
OP_REQUIRES_OK(context, context->input("hessians", &hessians_t));
- int class_id;
- ReadClassId(context, &class_id);
+ const Tensor* bias_feature_id_t;
+ OP_REQUIRES_OK(context,
+ context->input("bias_feature_id", &bias_feature_id_t));
+ int64 bias_feature_id = bias_feature_id_t->scalar<int64>()();
// Find the number of unique partitions before we allocate the output.
std::vector<int32> partition_boundaries;
&output_splits_t));
tensorflow::TTypes<string>::Vec output_splits =
output_splits_t->vec<string>();
+ SplitBuilderState state(context);
for (int root_idx = 0; root_idx < num_elements; ++root_idx) {
float best_gain = std::numeric_limits<float>::lowest();
int start_index = partition_boundaries[non_empty_partitions[root_idx]];
int end_index = partition_boundaries[non_empty_partitions[root_idx] + 1];
// First feature ID in each partition should be the bias feature.
- OP_REQUIRES(context, feature_ids(start_index, 0) == bias_feature_id_,
+ OP_REQUIRES(context, feature_ids(start_index, 0) == bias_feature_id,
errors::InvalidArgument("Bias feature ID missing."));
GradientStats root_gradient_stats(*gradients_t, *hessians_t, start_index);
root_gradient_stats *= normalizer_ratio;
- NodeStats root_stats = ComputeNodeStats(root_gradient_stats);
+ NodeStats root_stats = state.ComputeNodeStats(root_gradient_stats);
int32 best_feature_idx = 0;
NodeStats best_right_node_stats(0);
NodeStats best_left_node_stats(0);
left_gradient_stats *= normalizer_ratio;
GradientStats right_gradient_stats =
root_gradient_stats - left_gradient_stats;
- NodeStats left_stats = ComputeNodeStats(left_gradient_stats);
- NodeStats right_stats = ComputeNodeStats(right_gradient_stats);
+ NodeStats left_stats = state.ComputeNodeStats(left_gradient_stats);
+ NodeStats right_stats = state.ComputeNodeStats(right_gradient_stats);
if (left_stats.gain + right_stats.gain > best_gain) {
best_gain = left_stats.gain + right_stats.gain;
best_left_node_stats = left_stats;
SplitInfo split_info;
auto* equality_split = split_info.mutable_split_node()
->mutable_categorical_id_binary_split();
- equality_split->set_feature_column(feature_column_group_id_);
+ equality_split->set_feature_column(state.feature_column_group_id());
equality_split->set_feature_id(feature_ids(best_feature_idx, 0));
auto* left_child = split_info.mutable_left_child();
auto* right_child = split_info.mutable_right_child();
- FillLeaf(class_id, best_left_node_stats, left_child);
- FillLeaf(class_id, best_right_node_stats, right_child);
+ state.FillLeaf(best_left_node_stats, left_child);
+ state.FillLeaf(best_right_node_stats, right_child);
split_info.SerializeToString(&output_splits(root_idx));
gains(root_idx) =
- best_gain - root_stats.gain - tree_complexity_regularization_;
+ best_gain - root_stats.gain - state.tree_complexity_regularization();
output_partition_ids(root_idx) = partition_ids(start_index);
}
}
-
- private:
- int64 bias_feature_id_;
};
REGISTER_KERNEL_BUILDER(
import re
from tensorflow.contrib.boosted_trees.lib.learner.batch import base_split_handler
+from tensorflow.contrib.boosted_trees.python.ops import gen_quantile_ops
+from tensorflow.contrib.boosted_trees.python.ops import gen_stats_accumulator_ops
from tensorflow.contrib.boosted_trees.python.ops import quantile_ops
from tensorflow.contrib.boosted_trees.python.ops import split_handler_ops
from tensorflow.contrib.boosted_trees.python.ops import stats_accumulator_ops
from tensorflow.python.framework import function
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
+from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
multiclass_strategy=multiclass_strategy,
init_stamp_token=init_stamp_token,
name=name)
- # Register sparse_make_stats_update function as an Op to the graph.
- g = ops.get_default_graph()
- sparse_make_stats_update.add_to_graph(g)
self._sparse_float_column = sparse_float_column
def scheduled_reads(self):
are_buckets_ready, buckets = scheduled_reads[0]
with ops.name_scope(self._name, "SparseSplitHandler"):
(quantile_indices, quantile_values, quantile_shapes, quantile_weights,
- example_partition_ids,
- feature_ids, gradients, hessians) = sparse_make_stats_update(
+ example_partition_ids, feature_ids, gradients,
+ hessians) = sparse_make_stats_update(
is_active, are_buckets_ready, self._sparse_float_column.indices,
self._sparse_float_column.values,
self._sparse_float_column.dense_shape, buckets,
def make_splits(self, stamp_token, next_stamp_token, class_id):
"""Create the best split using the accumulated stats and flush the state."""
+ if (self._gradient_shape == tensor_shape.scalar() and
+ self._hessian_shape == tensor_shape.scalar()):
+ handler = make_sparse_split_scalar  # Scalar gradient and hessian shapes: use the scalar-accumulator specialization.
+ else:
+ handler = make_sparse_split_tensor  # Any other shape: use the tensor-accumulator specialization.
+
+ are_splits_ready, partition_ids, gains, split_infos = (
+ handler(self._quantile_accumulator.resource(),  # Arguments are positional and follow the Defun's parameter order.
+ self._stats_accumulator.resource(), stamp_token,
+ next_stamp_token, self._multiclass_strategy, class_id,
+ self._feature_column_group_id, self._l1_regularization,
+ self._l2_regularization, self._tree_complexity_regularization,
+ self._min_node_weight))
+ return are_splits_ready, partition_ids, gains, split_infos
+
+
+def _specialize_sparse_split(is_multi_dimentional):  # Returns a Defun `f` with the accumulator flavour fixed through the closure.
+ """Builds a specialized version of the function."""
+
+ def _make_sparse_split(quantile_accumulator_handle, stats_accumulator_handle,
+ stamp_token, next_stamp_token, multiclass_strategy,
+ class_id, feature_column_id, l1_regularization,
+ l2_regularization, tree_complexity_regularization,
+ min_node_weight, is_multi_dimentional):
+ """Function that builds splits for a sparse feature column."""
# Get the bucket boundaries
are_splits_ready, buckets = (
- self._quantile_accumulator.get_buckets(stamp_token))
+ gen_quantile_ops.quantile_accumulator_get_buckets(
+ quantile_accumulator_handles=[quantile_accumulator_handle],
+ stamp_token=stamp_token))
# After we receive the boundaries from previous iteration we can flush
# the quantile accumulator.
- with ops.control_dependencies([buckets]):
- flush_quantiles = self._quantile_accumulator.flush(
- stamp_token=stamp_token, next_stamp_token=next_stamp_token)
-
- with ops.device(None):
- with ops.device(self._stats_accumulator.resource().device):
- num_minibatches, partition_ids, bucket_ids, gradients, hessians = (
- self._stats_accumulator.flush(stamp_token, next_stamp_token))
-
- # Put quantile and stats accumulator flushing in the dependency path.
- are_splits_ready = control_flow_ops.with_dependencies(
- [flush_quantiles, partition_ids], are_splits_ready)
- partition_ids, gains, split_infos = (
- split_handler_ops.build_sparse_inequality_splits(
- num_minibatches=num_minibatches,
- bucket_boundaries=buckets,
- partition_ids=partition_ids,
- bucket_ids=bucket_ids,
- gradients=gradients,
- hessians=hessians,
- class_id=class_id,
- feature_column_group_id=self._feature_column_group_id,
- l1_regularization=self._l1_regularization,
- l2_regularization=self._l2_regularization,
- tree_complexity_regularization=self.
- _tree_complexity_regularization,
- min_node_weight=self._min_node_weight,
- bias_feature_id=_BIAS_FEATURE_ID,
- multiclass_strategy=self._multiclass_strategy))
- return (are_splits_ready, partition_ids, gains, split_infos)
+ with ops.control_dependencies([buckets[0]]):  # A single handle was passed in a list, so the only bucket tensor is element 0.
+ flush_quantiles = gen_quantile_ops.quantile_accumulator_flush(
+ quantile_accumulator_handle=quantile_accumulator_handle,
+ stamp_token=stamp_token,
+ next_stamp_token=next_stamp_token)
+
+ if is_multi_dimentional:  # Selects the tensor or the scalar stats-accumulator flush op.
+ num_minibatches, partition_ids, bucket_ids, gradients, hessians = (
+ gen_stats_accumulator_ops.stats_accumulator_tensor_flush(
+ stats_accumulator_handle, stamp_token, next_stamp_token))
+ else:
+ num_minibatches, partition_ids, bucket_ids, gradients, hessians = (
+ gen_stats_accumulator_ops.stats_accumulator_scalar_flush(
+ stats_accumulator_handle, stamp_token, next_stamp_token))
+
+ # Put quantile and stats accumulator flushing in the dependency path.
+ with ops.control_dependencies([flush_quantiles, partition_ids]):
+ are_splits_ready = array_ops.identity(are_splits_ready)  # Re-emitted under the control dependencies listed on the enclosing `with`.
+ partition_ids, gains, split_infos = (
+ split_handler_ops.build_sparse_inequality_splits(
+ num_minibatches=num_minibatches,
+ bucket_boundaries=buckets[0],
+ partition_ids=partition_ids,
+ bucket_ids=bucket_ids,
+ gradients=gradients,
+ hessians=hessians,
+ class_id=class_id,
+ feature_column_group_id=feature_column_id,
+ l1_regularization=l1_regularization,
+ l2_regularization=l2_regularization,
+ tree_complexity_regularization=tree_complexity_regularization,
+ min_node_weight=min_node_weight,
+ bias_feature_id=_BIAS_FEATURE_ID,
+ multiclass_strategy=multiclass_strategy))
+ return are_splits_ready, partition_ids, gains, split_infos
+
+ @function.Defun(  # One dtype per parameter of `f`, in the same order.
+ dtypes.resource,
+ dtypes.resource,
+ dtypes.int64,
+ dtypes.int64,
+ dtypes.int32,
+ dtypes.int32,
+ dtypes.int32,
+ dtypes.float32,
+ dtypes.float32,
+ dtypes.float32,
+ dtypes.float32,
+ noinline=True)
+ def f(quantile_accumulator_handle, stats_accumulator_handle, stamp_token,
+ next_stamp_token, multiclass_strategy, class_id, feature_column_id,
+ l1_regularization, l2_regularization, tree_complexity_regularization,
+ min_node_weight):
+ """Function that builds splits for a sparse feature column."""
+ return _make_sparse_split(
+ quantile_accumulator_handle, stats_accumulator_handle, stamp_token,
+ next_stamp_token, multiclass_strategy, class_id, feature_column_id,
+ l1_regularization, l2_regularization, tree_complexity_regularization,
+ min_node_weight, is_multi_dimentional)  # is_multi_dimentional comes from the enclosing call, not from a Defun input.
+
+ return f
+
+
+make_sparse_split_scalar = _specialize_sparse_split(is_multi_dimentional=False)  # Variant that flushes with stats_accumulator_scalar_flush.
+make_sparse_split_tensor = _specialize_sparse_split(is_multi_dimentional=True)  # Variant that flushes with stats_accumulator_tensor_flush.
@function.Defun(
empty_float = constant_op.constant([], dtype=dtypes.float32)
handler_not_active = (constant_op.constant(
- [], dtype=dtypes.int64, shape=[0, 2]), empty_float, constant_op.constant(
- [0, 1], dtype=dtypes.int64), empty_float)
+ [], dtype=dtypes.int64, shape=[0, 2]), empty_float,
+ constant_op.constant([0, 1], dtype=dtypes.int64),
+ empty_float)
handler_active = (sparse_column_indices, sparse_column_values,
sparse_column_shape, weights)
quantile_indices, quantile_values, quantile_shape, quantile_weights = (
from __future__ import division
from __future__ import print_function
+import numpy as np
+
from tensorflow.contrib.boosted_trees.lib.learner.batch import ordinal_split_handler
from tensorflow.contrib.boosted_trees.proto import learner_pb2
from tensorflow.contrib.boosted_trees.proto import split_info_pb2
example_weights,
is_active=array_ops.constant([True, True]))
with ops.control_dependencies([update_1]):
- are_splits_ready = split_handler.make_splits(0, 1, class_id)[0]
+ are_splits_ready = split_handler.make_splits(
+ np.int64(0), np.int64(1), class_id)[0]
+
with ops.control_dependencies([are_splits_ready]):
update_2 = split_handler.update_stats_sync(
1,
is_active=array_ops.constant([True, True]))
with ops.control_dependencies([update_2]):
are_splits_ready2, partitions, gains, splits = (
- split_handler.make_splits(1, 2, class_id))
+ split_handler.make_splits(np.int64(1), np.int64(2), class_id))
are_splits_ready, are_splits_ready2, partitions, gains, splits = (
sess.run([
are_splits_ready, are_splits_ready2, partitions, gains, splits
example_weights,
is_active=array_ops.constant([True, True]))
with ops.control_dependencies([update_1]):
- are_splits_ready = split_handler.make_splits(0, 1, class_id)[0]
+ are_splits_ready = split_handler.make_splits(
+ np.int64(0), np.int64(1), class_id)[0]
+
with ops.control_dependencies([are_splits_ready]):
update_2 = split_handler.update_stats_sync(
1,
is_active=array_ops.constant([True, True]))
with ops.control_dependencies([update_2]):
are_splits_ready2, partitions, gains, splits = (
- split_handler.make_splits(1, 2, class_id))
+ split_handler.make_splits(np.int64(1), np.int64(2), class_id))
are_splits_ready, are_splits_ready2, partitions, gains, splits = (
sess.run([
are_splits_ready, are_splits_ready2, partitions, gains, splits
example_weights,
is_active=array_ops.constant([True, True]))
with ops.control_dependencies([update_1]):
- are_splits_ready = split_handler.make_splits(0, 1, class_id)[0]
+ are_splits_ready = split_handler.make_splits(
+ np.int64(0), np.int64(1), class_id)[0]
with ops.control_dependencies([are_splits_ready]):
update_2 = split_handler.update_stats_sync(
1,
is_active=array_ops.constant([True, True]))
with ops.control_dependencies([update_2]):
are_splits_ready2, partitions, gains, splits = (
- split_handler.make_splits(1, 2, class_id))
+ split_handler.make_splits(np.int64(1), np.int64(2), class_id))
are_splits_ready, are_splits_ready2, partitions, gains, splits = (
sess.run([
are_splits_ready, are_splits_ready2, partitions, gains, splits
example_weights,
is_active=array_ops.constant([True, False]))
with ops.control_dependencies([update_1]):
- are_splits_ready = split_handler.make_splits(0, 1, class_id)[0]
+ are_splits_ready = split_handler.make_splits(
+ np.int64(0), np.int64(1), class_id)[0]
with ops.control_dependencies([are_splits_ready]):
update_2 = split_handler.update_stats_sync(
1,
is_active=array_ops.constant([False, True]))
with ops.control_dependencies([update_2]):
are_splits_ready2, partitions, gains, splits = (
- split_handler.make_splits(1, 2, class_id))
+ split_handler.make_splits(np.int64(1), np.int64(2), class_id))
are_splits_ready, are_splits_ready2, partitions, gains, splits = (
sess.run([
are_splits_ready, are_splits_ready2, partitions, gains, splits
example_weights,
is_active=array_ops.constant([True, True]))
with ops.control_dependencies([update_1]):
- are_splits_ready = split_handler.make_splits(0, 1, class_id)[0]
+ are_splits_ready = split_handler.make_splits(
+ np.int64(0), np.int64(1), class_id)[0]
with ops.control_dependencies([are_splits_ready]):
update_2 = split_handler.update_stats_sync(
1,
is_active=array_ops.constant([True, True]))
with ops.control_dependencies([update_2]):
are_splits_ready2, partitions, gains, splits = (
- split_handler.make_splits(1, 2, class_id))
+ split_handler.make_splits(np.int64(1), np.int64(2), class_id))
are_splits_ready, are_splits_ready2, partitions, gains, splits = (
sess.run([
are_splits_ready, are_splits_ready2, partitions, gains, splits
example_weights,
is_active=array_ops.constant([True, True]))
with ops.control_dependencies([update_1]):
- are_splits_ready = split_handler.make_splits(0, 1, class_id)[0]
+ are_splits_ready = split_handler.make_splits(
+ np.int64(0), np.int64(1), class_id)[0]
with ops.control_dependencies([are_splits_ready]):
update_2 = split_handler.update_stats_sync(
1,
is_active=array_ops.constant([True, True]))
with ops.control_dependencies([update_2]):
are_splits_ready2, partitions, gains, splits = (
- split_handler.make_splits(1, 2, class_id))
+ split_handler.make_splits(np.int64(1), np.int64(2), class_id))
are_splits_ready, are_splits_ready2, partitions, gains, splits = (
sess.run([
are_splits_ready, are_splits_ready2, partitions, gains, splits
class_id = -1
split_handler = ordinal_split_handler.SparseSplitHandler(
- l1_regularization=0,
- l2_regularization=2,
- tree_complexity_regularization=0,
- min_node_weight=0,
+ l1_regularization=0.0,
+ l2_regularization=2.0,
+ tree_complexity_regularization=0.0,
+ min_node_weight=0.0,
epsilon=0.01,
num_quantiles=2,
feature_column_group_id=0,
example_weights,
is_active=array_ops.constant([True, True]))
with ops.control_dependencies([update_1]):
- are_splits_ready = split_handler.make_splits(0, 1, class_id)[0]
-
+ are_splits_ready = split_handler.make_splits(
+ np.int64(0), np.int64(1), class_id)[0]
with ops.control_dependencies([are_splits_ready]):
update_2 = split_handler.update_stats_sync(
1,
is_active=array_ops.constant([True, True]))
with ops.control_dependencies([update_2]):
are_splits_ready2, partitions, gains, splits = (
- split_handler.make_splits(1, 2, class_id))
+ split_handler.make_splits(np.int64(1), np.int64(2), class_id))
are_splits_ready, are_splits_ready2, partitions, gains, splits = (
sess.run([
are_splits_ready, are_splits_ready2, partitions, gains, splits
class_id = -1
split_handler = ordinal_split_handler.SparseSplitHandler(
- l1_regularization=0,
- l2_regularization=2,
- tree_complexity_regularization=0,
- min_node_weight=0,
+ l1_regularization=0.0,
+ l2_regularization=2.0,
+ tree_complexity_regularization=0.0,
+ min_node_weight=0.0,
epsilon=0.01,
num_quantiles=2,
feature_column_group_id=0,
example_weights,
is_active=array_ops.constant([True, True]))
with ops.control_dependencies([update_1]):
- are_splits_ready = split_handler.make_splits(0, 1, class_id)[0]
+ are_splits_ready = split_handler.make_splits(
+ np.int64(0), np.int64(1), class_id)[0]
with ops.control_dependencies([are_splits_ready]):
update_2 = split_handler.update_stats_sync(
is_active=array_ops.constant([True, True]))
with ops.control_dependencies([update_2]):
are_splits_ready2, partitions, gains, splits = (
- split_handler.make_splits(1, 2, class_id))
+ split_handler.make_splits(np.int64(1), np.int64(2), class_id))
are_splits_ready, are_splits_ready2, partitions, gains, splits = (
sess.run([
are_splits_ready, are_splits_ready2, partitions, gains, splits
class_id = -1
split_handler = ordinal_split_handler.SparseSplitHandler(
- l1_regularization=0,
- l2_regularization=2,
- tree_complexity_regularization=0,
- min_node_weight=0,
+ l1_regularization=0.0,
+ l2_regularization=2.0,
+ tree_complexity_regularization=0.0,
+ min_node_weight=0.0,
epsilon=0.01,
num_quantiles=2,
feature_column_group_id=0,
example_weights,
is_active=array_ops.constant([True, True]))
with ops.control_dependencies([update_1]):
- are_splits_ready = split_handler.make_splits(0, 1, class_id)[0]
+ are_splits_ready = split_handler.make_splits(
+ np.int64(0), np.int64(1), class_id)[0]
with ops.control_dependencies([are_splits_ready]):
update_2 = split_handler.update_stats_sync(
is_active=array_ops.constant([True, True]))
with ops.control_dependencies([update_2]):
are_splits_ready2, partitions, gains, splits = (
- split_handler.make_splits(1, 2, class_id))
+ split_handler.make_splits(np.int64(1), np.int64(2), class_id))
are_splits_ready, are_splits_ready2, partitions, gains, splits = (
sess.run([
are_splits_ready, are_splits_ready2, partitions, gains, splits
class_id = -1
split_handler = ordinal_split_handler.SparseSplitHandler(
- l1_regularization=0,
- l2_regularization=2,
- tree_complexity_regularization=0,
- min_node_weight=0,
+ l1_regularization=0.0,
+ l2_regularization=2.0,
+ tree_complexity_regularization=0.0,
+ min_node_weight=0.0,
epsilon=0.01,
num_quantiles=2,
feature_column_group_id=0,
example_weights,
is_active=array_ops.constant([True, False]))
with ops.control_dependencies([update_1]):
- are_splits_ready = split_handler.make_splits(0, 1, class_id)[0]
+ are_splits_ready = split_handler.make_splits(
+ np.int64(0), np.int64(1), class_id)[0]
with ops.control_dependencies([are_splits_ready]):
update_2 = split_handler.update_stats_sync(
is_active=array_ops.constant([False, True]))
with ops.control_dependencies([update_2]):
are_splits_ready2, partitions, gains, splits = (
- split_handler.make_splits(1, 2, class_id))
+ split_handler.make_splits(np.int64(1), np.int64(2), class_id))
are_splits_ready, are_splits_ready2, partitions, gains, splits = (
sess.run([
are_splits_ready, are_splits_ready2, partitions, gains, splits
class_id = -1
split_handler = ordinal_split_handler.SparseSplitHandler(
- l1_regularization=0,
- l2_regularization=2,
- tree_complexity_regularization=0,
- min_node_weight=0,
+ l1_regularization=0.0,
+ l2_regularization=2.0,
+ tree_complexity_regularization=0.0,
+ min_node_weight=0.0,
epsilon=0.01,
num_quantiles=2,
feature_column_group_id=0,
example_weights,
is_active=array_ops.constant([True, True]))
with ops.control_dependencies([update_1]):
- are_splits_ready = split_handler.make_splits(0, 1, class_id)[0]
+ are_splits_ready = split_handler.make_splits(
+ np.int64(0), np.int64(1), class_id)[0]
with ops.control_dependencies([are_splits_ready]):
update_2 = split_handler.update_stats_sync(
is_active=array_ops.constant([True, True]))
with ops.control_dependencies([update_2]):
are_splits_ready2, partitions, gains, splits = (
- split_handler.make_splits(1, 2, class_id))
+ split_handler.make_splits(np.int64(1), np.int64(2), class_id))
are_splits_ready, are_splits_ready2, partitions, gains, splits = (
sess.run([
are_splits_ready, are_splits_ready2, partitions, gains, splits
class_id = -1
split_handler = ordinal_split_handler.SparseSplitHandler(
- l1_regularization=0,
- l2_regularization=2,
- tree_complexity_regularization=0,
- min_node_weight=0,
+ l1_regularization=0.0,
+ l2_regularization=2.0,
+ tree_complexity_regularization=0.0,
+ min_node_weight=0.0,
epsilon=0.01,
num_quantiles=2,
feature_column_group_id=0,
example_weights,
is_active=array_ops.constant([True, True]))
with ops.control_dependencies([update_1]):
- are_splits_ready = split_handler.make_splits(0, 1, class_id)[0]
+ are_splits_ready = split_handler.make_splits(
+ np.int64(0), np.int64(1), class_id)[0]
with ops.control_dependencies([are_splits_ready]):
update_2 = split_handler.update_stats_sync(
is_active=array_ops.constant([True, True]))
with ops.control_dependencies([update_2]):
are_splits_ready2, partitions, gains, splits = (
- split_handler.make_splits(1, 2, class_id))
+ split_handler.make_splits(np.int64(1), np.int64(2), class_id))
are_splits_ready, are_splits_ready2, partitions, gains, splits = (
sess.run([
are_splits_ready, are_splits_ready2, partitions, gains, splits
using shape_inference::ShapeHandle;
REGISTER_OP("BuildDenseInequalitySplits")
- .Attr("feature_column_group_id: int")
- .Attr("l1_regularization: float")
- .Attr("l2_regularization: float")
- .Attr("tree_complexity_regularization: float")
- .Attr("min_node_weight: float")
- .Attr("multiclass_strategy: int")
.Input("num_minibatches: int64")
.Input("partition_ids: int32")
.Input("bucket_ids: int64")
.Input("hessians: float32")
.Input("bucket_boundaries: float32")
.Input("class_id: int32")
+ .Input("feature_column_group_id: int32")
+ .Input("l1_regularization: float")
+ .Input("l2_regularization: float")
+ .Input("tree_complexity_regularization: float")
+ .Input("min_node_weight: float")
+ .Input("multiclass_strategy: int32")
.Output("output_partition_ids: int32")
.Output("gains: float32")
.Output("split_infos: string")
gradients: A rank 1 tensor of gradients.
hessians: A rank 1 tensor of hessians.
bucket_boundaries: A rank 1 tensor, thresholds that were used for bucketization.
+class_id: A scalar, the class id for which we're building the splits.
+feature_column_group_id: A scalar, the index of the feature we are splitting on.
+l1_regularization: A scalar, which specifies the l1 regularization term.
+l2_regularization: A scalar, which specifies the l2 regularization term.
+tree_complexity_regularization: A scalar, which specifies the tree complexity
+ regularization term.
+min_node_weight: A scalar, minimum sum of example hessian needed in a child.
+ If a split results in a leaf node with a smaller value, the split will not
+ be considered.
+multiclass_strategy: A scalar, specifying the multiclass handling strategy.
+ See LearnerConfig.MultiClassStrategy for valid values.
output_partition_ids: A rank 1 tensor, the partition IDs that we created splits
for.
gains: A rank 1 tensor, for the computed gain for the created splits.
)doc");
REGISTER_OP("BuildSparseInequalitySplits")
- .Attr("feature_column_group_id: int")
- .Attr("bias_feature_id: int")
- .Attr("l1_regularization: float")
- .Attr("l2_regularization: float")
- .Attr("tree_complexity_regularization: float")
- .Attr("min_node_weight: float")
- .Attr("multiclass_strategy: int")
.Input("num_minibatches: int64")
.Input("partition_ids: int32")
.Input("bucket_ids: int64")
.Input("hessians: float32")
.Input("bucket_boundaries: float32")
.Input("class_id: int32")
+ .Input("feature_column_group_id: int32")
+ .Input("bias_feature_id: int64")
+ .Input("l1_regularization: float")
+ .Input("l2_regularization: float")
+ .Input("tree_complexity_regularization: float")
+ .Input("min_node_weight: float")
+ .Input("multiclass_strategy: int32")
.Output("output_partition_ids: int32")
.Output("gains: float32")
.Output("split_infos: string")
gradients: A rank 1 tensor of gradients.
hessians: A rank 1 tensor of hessians.
bucket_boundaries: A rank 1 tensor, thresholds that were used for bucketization.
+class_id: A scalar, the class id for which we're building the splits.
+feature_column_group_id: A scalar, the index of the feature we are splitting on.
+l1_regularization: A scalar, which specifies the l1 regularization term.
+l2_regularization: A scalar, which specifies the l2 regularization term.
+tree_complexity_regularization: A scalar, which specifies the tree complexity
+ regularization term.
+min_node_weight: A scalar, minimum sum of example hessian needed in a child.
+ If a split results in a leaf node with a smaller value, the split will not
+ be considered.
+multiclass_strategy: A scalar, specifying the multiclass handling strategy.
+ See LearnerConfig.MultiClassStrategy for valid values.
output_partition_ids: A rank 1 tensor, the partition IDs that we created splits
for.
gains: A rank 1 tensor, for the computed gain for the created splits.
)doc");
REGISTER_OP("BuildCategoricalEqualitySplits")
- .Attr("feature_column_group_id: int")
- .Attr("bias_feature_id: int")
- .Attr("l1_regularization: float")
- .Attr("l2_regularization: float")
- .Attr("tree_complexity_regularization: float")
- .Attr("min_node_weight: float")
- .Attr("multiclass_strategy: int")
.Input("num_minibatches: int64")
.Input("partition_ids: int32")
.Input("feature_ids: int64")
.Input("gradients: float32")
.Input("hessians: float32")
.Input("class_id: int32")
+ .Input("feature_column_group_id: int32")
+ .Input("bias_feature_id: int64")
+ .Input("l1_regularization: float")
+ .Input("l2_regularization: float")
+ .Input("tree_complexity_regularization: float")
+ .Input("min_node_weight: float")
+ .Input("multiclass_strategy: int32")
.Output("output_partition_ids: int32")
.Output("gains: float32")
.Output("split_infos: string")
feature_ids: A rank 2 tensor of feature IDs and dimensions.
gradients: A rank 1 tensor of gradients.
hessians: A rank 1 tensor of hessians.
+class_id: A scalar, the class id for which we're building the splits.
+feature_column_group_id: A scalar, the index of the feature we are splitting on.
+l1_regularization: A scalar, which specifies the l1 regularization term.
+l2_regularization: A scalar, which specifies the l2 regularization term.
+tree_complexity_regularization: A scalar, which specifies the tree complexity
+ regularization term.
+min_node_weight: A scalar, minimum sum of example hessian needed in a child.
+ If a split results in a leaf node with a smaller value, the split will not
+ be considered.
+multiclass_strategy: A scalar, specifying the multiclass handling strategy.
+ See LearnerConfig.MultiClassStrategy for valid values.
output_partition_ids: A rank 1 tensor, the partition IDs that we created splits
for.
gains: A rank 1 tensor, for the computed gain for the created splits.
)doc");
} // namespace tensorflow
- // namespace tensorflow
import abc
import collections
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
"""Moves a list of tensors to a device by concatenating/splitting them."""
# Reset the device setting to avoid weird interactions with device merging
# logic.
+ zero = constant_op.constant(0, dtype=dtypes.int32)
with ops.device(None):
if all(tensor.shape == tensor_shape.scalar() for tensor in tensors):
with ops.device(tensors[0].device):
return array_ops.unstack(values)
else:
with ops.device(tensors[0].device):
- sizes = array_ops.stack(
- [array_ops.shape(tensor)[0] for tensor in tensors])
- values = array_ops.concat(tensors, axis=0)
+ sizes = array_ops.stack(array_ops.shape_n(tensors))[:, 0]
+ values = array_ops.concat(tensors, axis=zero)
with ops.device(device):
sizes = array_ops.unstack(sizes)
- return list(array_ops.split(values, sizes, axis=0))
+ return list(array_ops.split(values, sizes, axis=zero))
def _scheduled_stamp_resource_op_runner(batch, stamp):
stamp_token=stamp_token,
next_stamp_token=next_stamp_token)
return result
+
+ def resource(self):
+ return self._quantile_accumulator_handle
elif isinstance(fc, feature_column_lib._EmbeddingColumn):
# pylint: enable=protected-access
transformed_features[fc.name] = fc_core.input_layer(
- features, [fc],
- weight_collections=[scope])
+ features, [fc], weight_collections=[scope])
else:
result = feature_column_ops.transform_features(features, [fc])
if len(result) > 1:
self._feature_columns = feature_columns
self._learner_config_serialized = learner_config.SerializeToString()
self._attempted_trees = variables.Variable(
- initial_value=array_ops.zeros([], dtypes.int64), trainable=False,
+ initial_value=array_ops.zeros([], dtypes.int64),
+ trainable=False,
name="attempted_trees")
self._finalized_trees = variables.Variable(
- initial_value=array_ops.zeros([], dtypes.int64), trainable=False,
+ initial_value=array_ops.zeros([], dtypes.int64),
+ trainable=False,
name="finalized_trees")
if not features:
raise ValueError("Features dictionary must be specified.")
self._sparse_int_indices = sparse_int_indices
self._sparse_int_values = sparse_int_values
self._sparse_int_shapes = sparse_int_shapes
- self._reduce_dim = (self._learner_config.multi_class_strategy ==
- learner_pb2.LearnerConfig.TREE_PER_CLASS and
- learner_config.num_classes == 2)
+ self._reduce_dim = (
+ self._learner_config.multi_class_strategy ==
+ learner_pb2.LearnerConfig.TREE_PER_CLASS and
+ learner_config.num_classes == 2)
def _predict_and_return_dict(self, ensemble_handle, ensemble_stamp, mode):
"""Runs prediction and returns a dictionary of the prediction results.
ensemble_stats = training_ops.tree_ensemble_stats(ensemble_handle,
ensemble_stamp)
num_handlers = (
- len(self._dense_floats) + len(self._sparse_float_shapes) +
- len(self._sparse_int_shapes))
+ len(self._dense_floats) + len(self._sparse_float_shapes) + len(
+ self._sparse_int_shapes))
# Used during feature selection.
used_handlers = model_ops.tree_ensemble_used_handlers(
ensemble_handle, ensemble_stamp, num_all_handlers=num_handlers)
# Use the current ensemble to predict on the current batch of input.
# For faster prediction we check if the inputs are on the same device
# as the model. If not, we create a copy of the model on the worker.
- input_deps = (self._dense_floats + self._sparse_float_indices +
- self._sparse_int_indices)
+ input_deps = (
+ self._dense_floats + self._sparse_float_indices +
+ self._sparse_int_indices)
if not input_deps:
raise ValueError("No input tensors for prediction.")
ValueError: if inputs are not valid.
"""
# Get the worker device from input dependencies.
- input_deps = (self._dense_floats + self._sparse_float_indices +
- self._sparse_int_indices)
+ input_deps = (
+ self._dense_floats + self._sparse_float_indices +
+ self._sparse_int_indices)
worker_device = input_deps[0].device
# Get tensors relevant for training and form the loss.
aggregation_method=None)[0]
strategy = self._learner_config.multi_class_strategy
- class_id = -1
+ class_id = constant_op.constant(-1, dtype=dtypes.int32)
# Handle different multiclass strategies.
if strategy == learner_pb2.LearnerConfig.TREE_PER_CLASS:
# We build one vs rest trees.
# Get the weights for each example for quantiles calculation,
weights = self._get_weights(hessian_shape, squeezed_hessians)
- regularization_config = self._learner_config.regularization
- min_node_weight = self._learner_config.constraints.min_node_weight
# Create all handlers ensuring resources are evenly allocated across PS.
fc_name_idx = 0
handlers = []
init_stamp_token = constant_op.constant(0, dtype=dtypes.int64)
+ l1_regularization = constant_op.constant(
+ self._learner_config.regularization.l1, dtypes.float32)
+ l2_regularization = constant_op.constant(
+ self._learner_config.regularization.l2, dtypes.float32)
+ tree_complexity_regularization = constant_op.constant(
+ self._learner_config.regularization.tree_complexity, dtypes.float32)
+ min_node_weight = constant_op.constant(
+ self._learner_config.constraints.min_node_weight, dtypes.float32)
+ epsilon = 0.01
+ num_quantiles = 100
+ strategy_tensor = constant_op.constant(strategy)
with ops.device(self._get_replica_device_setter(worker_device)):
# Create handlers for dense float columns
for dense_float_column_idx in range(len(self._dense_floats)):
fc_name = self._fc_names[fc_name_idx]
handlers.append(
ordinal_split_handler.DenseSplitHandler(
- l1_regularization=regularization_config.l1,
- l2_regularization=regularization_config.l2,
- tree_complexity_regularization=(
- regularization_config.tree_complexity),
+ l1_regularization=l1_regularization,
+ l2_regularization=l2_regularization,
+ tree_complexity_regularization=tree_complexity_regularization,
min_node_weight=min_node_weight,
feature_column_group_id=dense_float_column_idx,
- epsilon=0.01,
- num_quantiles=100,
+ epsilon=epsilon,
+ num_quantiles=num_quantiles,
dense_float_column=self._dense_floats[dense_float_column_idx],
name=fc_name,
gradient_shape=gradient_shape,
hessian_shape=hessian_shape,
- multiclass_strategy=strategy,
+ multiclass_strategy=strategy_tensor,
init_stamp_token=init_stamp_token))
fc_name_idx += 1
fc_name = self._fc_names[fc_name_idx]
handlers.append(
ordinal_split_handler.SparseSplitHandler(
- l1_regularization=regularization_config.l1,
- l2_regularization=regularization_config.l2,
- tree_complexity_regularization=(
- regularization_config.tree_complexity),
+ l1_regularization=l1_regularization,
+ l2_regularization=l2_regularization,
+ tree_complexity_regularization=tree_complexity_regularization,
min_node_weight=min_node_weight,
feature_column_group_id=sparse_float_column_idx,
- epsilon=0.01,
- num_quantiles=100,
+ epsilon=epsilon,
+ num_quantiles=num_quantiles,
sparse_float_column=sparse_tensor.SparseTensor(
self._sparse_float_indices[sparse_float_column_idx],
self._sparse_float_values[sparse_float_column_idx],
name=fc_name,
gradient_shape=gradient_shape,
hessian_shape=hessian_shape,
- multiclass_strategy=strategy,
+ multiclass_strategy=strategy_tensor,
init_stamp_token=init_stamp_token))
fc_name_idx += 1
fc_name = self._fc_names[fc_name_idx]
handlers.append(
categorical_split_handler.EqualitySplitHandler(
- l1_regularization=regularization_config.l1,
- l2_regularization=regularization_config.l2,
- tree_complexity_regularization=(
- regularization_config.tree_complexity),
+ l1_regularization=l1_regularization,
+ l2_regularization=l2_regularization,
+ tree_complexity_regularization=tree_complexity_regularization,
min_node_weight=min_node_weight,
feature_column_group_id=sparse_int_column_idx,
sparse_int_column=sparse_tensor.SparseTensor(
name=fc_name,
gradient_shape=gradient_shape,
hessian_shape=hessian_shape,
- multiclass_strategy=strategy,
+ multiclass_strategy=strategy_tensor,
init_stamp_token=init_stamp_token))
fc_name_idx += 1
name="continue_centering",
trainable=False)
stats_update_ops.append(
- control_flow_ops.cond(continue_centering,
- self._make_update_bias_stats_fn(
- ensemble_stamp, predictions, gradients,
- bias_stats_accumulator),
- control_flow_ops.no_op))
+ control_flow_ops.cond(
+ continue_centering,
+ self._make_update_bias_stats_fn(ensemble_stamp, predictions,
+ gradients, bias_stats_accumulator),
+ control_flow_ops.no_op))
# Update handler stats.
handler_reads = collections.OrderedDict()
shape=[len(handlers)], seed=[seed + 1, 1])
active_handlers = array_ops.stack(
[active_handlers_current_layer, active_handlers_next_layer], axis=1)
- active_handlers = (active_handlers <
- self._learner_config.feature_fraction_per_level)
+ active_handlers = (
+ active_handlers < self._learner_config.feature_fraction_per_level)
elif subsampling_type == "feature_fraction_per_tree":
seed = predictions_dict[NUM_TREES_ATTEMPTED]
active_handlers_current_layer = stateless.stateless_random_uniform(
active_handlers_current_layer = (
active_handlers_current_layer <
self._learner_config.feature_fraction_per_tree)
- active_handlers = array_ops.stack([
- active_handlers_current_layer,
- array_ops.ones([len(handlers)], dtype=dtypes.bool)], axis=1)
+ active_handlers = array_ops.stack(
+ [
+ active_handlers_current_layer,
+ array_ops.ones([len(handlers)], dtype=dtypes.bool)
+ ],
+ axis=1)
else:
active_handlers = array_ops.ones([len(handlers), 2], dtype=dtypes.bool)
empty_hessians = constant_op.constant(
[], dtype=dtypes.float32, shape=empty_hess_shape)
+ active_handlers = array_ops.unstack(active_handlers, axis=0)
for handler_idx in range(len(handlers)):
handler = handlers[handler_idx]
is_active = active_handlers[handler_idx]
# This is a workaround for the slowness of graph building in tf.cond.
# See (b/36554864).
split_sizes = array_ops.reshape(
- array_ops.shape_n(partition_ids_list), [-1])
+ array_ops.shape_n(partition_ids_list), [len(partition_ids_list)])
partition_ids = array_ops.concat(partition_ids_list, axis=0)
gains = array_ops.concat(gains_list, axis=0)
split_infos = array_ops.concat(split_info_list, axis=0)
# Update ensemble.
update_ops = [are_all_splits_ready]
- update_model = control_flow_ops.cond(continue_centering, _center_bias_fn,
- _grow_ensemble_fn)
+ if self._center_bias:
+ update_model = control_flow_ops.cond(continue_centering,
+ _center_bias_fn, _grow_ensemble_fn)
+ else:
+ update_model = _grow_ensemble_fn()
update_ops.append(update_model)
# Update ensemble stats.