* Remove the bias centering graph if it is turned off.
author     A. Unique TensorFlower <gardener@tensorflow.org>
Tue, 22 May 2018 18:02:30 +0000 (11:02 -0700)
committer  TensorFlower Gardener <gardener@tensorflow.org>
Tue, 22 May 2018 18:04:34 +0000 (11:04 -0700)
* Create consts once. Otherwise, each time the constant is passed to an Op, a new Const op is created (illustrated in the sketch below).
* Speed up graph construction by using functions to build splits.

PiperOrigin-RevId: 197590220
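
A minimal, illustrative sketch of the "create consts once" point, assuming TensorFlow 1.x graph mode (this snippet is not taken from the commit): passing a Python scalar to an op converts it into a fresh Const node on every call, whereas a single constant_op.constant is reused by all consumers, which is what the class_id and regularization changes in gbdt_batch.py below rely on.

import tensorflow as tf

g = tf.Graph()
with g.as_default():
  x = tf.placeholder(tf.float32, shape=[None])

  # Python scalar: every tf.add call converts -1.0 into a new Const op.
  for _ in range(3):
    tf.add(x, -1.0)
  num_consts_scalar = len(
      [op for op in g.get_operations() if op.type == "Const"])

  # Shared constant: a single Const op referenced by every consumer.
  minus_one = tf.constant(-1.0, dtype=tf.float32)
  for _ in range(3):
    tf.add(x, minus_one)
  num_consts_total = len(
      [op for op in g.get_operations() if op.type == "Const"])

print(num_consts_scalar)  # 3: one Const per add with a Python scalar
print(num_consts_total)   # 4: the three above plus the single shared constant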

tensorflow/contrib/boosted_trees/kernels/split_handler_ops.cc
tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler.py
tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler_test.py
tensorflow/contrib/boosted_trees/ops/split_handler_ops.cc
tensorflow/contrib/boosted_trees/python/ops/batch_ops_utils.py
tensorflow/contrib/boosted_trees/python/ops/quantile_ops.py
tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py

index 04e3226..401bec8 100644 (file)
@@ -43,47 +43,60 @@ namespace {
 const int32 DUMMY_FEATURE_DIMENSION = -1;
 }  // namespace
 
-class BaseBuildSplitOp : public OpKernel {
+class SplitBuilderState {
  public:
-  explicit BaseBuildSplitOp(OpKernelConstruction* const context)
-      : OpKernel(context) {
-    OP_REQUIRES_OK(context, context->GetAttr("feature_column_group_id",
-                                             &feature_column_group_id_));
+  explicit SplitBuilderState(OpKernelContext* const context) {
+    const Tensor* l1_regularization_t;
     OP_REQUIRES_OK(context,
-                   context->GetAttr("l1_regularization", &l1_regularization_));
+                   context->input("l1_regularization", &l1_regularization_t));
+    const Tensor* l2_regularization_t;
     OP_REQUIRES_OK(context,
-                   context->GetAttr("l2_regularization", &l2_regularization_));
-    OP_REQUIRES_OK(context, context->GetAttr("tree_complexity_regularization",
-                                             &tree_complexity_regularization_));
+                   context->input("l2_regularization", &l2_regularization_t));
+    const Tensor* tree_complexity_regularization_t;
+    OP_REQUIRES_OK(context, context->input("tree_complexity_regularization",
+                                           &tree_complexity_regularization_t));
+    const Tensor* min_node_weight_t;
     OP_REQUIRES_OK(context,
-                   context->GetAttr("min_node_weight", &min_node_weight_));
+                   context->input("min_node_weight", &min_node_weight_t));
 
-    int strategy;
-    OP_REQUIRES_OK(context, context->GetAttr("multiclass_strategy", &strategy));
+    const Tensor* feature_column_group_id_t;
+    OP_REQUIRES_OK(context, context->input("feature_column_group_id",
+                                           &feature_column_group_id_t));
+
+    const Tensor* multiclass_strategy_t;
+    OP_REQUIRES_OK(
+        context, context->input("multiclass_strategy", &multiclass_strategy_t));
+    int strategy = multiclass_strategy_t->scalar<int32>()();
     OP_REQUIRES(
         context,
         boosted_trees::learner::LearnerConfig_MultiClassStrategy_IsValid(
             strategy),
         errors::InvalidArgument("Wrong multiclass strategy passed."));
-    multiclass_strategy_ = LearnerConfig_MultiClassStrategy(strategy);
-  }
 
-  NodeStats ComputeNodeStats(const GradientStats& grad_stats) {
-    return NodeStats(l1_regularization_, l2_regularization_, min_node_weight_,
-                     multiclass_strategy_, grad_stats);
-  }
+    multiclass_strategy_ = LearnerConfig_MultiClassStrategy(strategy);
 
-  void ReadClassId(OpKernelContext* const context, int32* class_id) {
     const Tensor* class_id_t;
     OP_REQUIRES_OK(context, context->input("class_id", &class_id_t));
     OP_REQUIRES(context, TensorShapeUtils::IsScalar(class_id_t->shape()),
                 errors::InvalidArgument("class_id must be a scalar."));
-    *class_id = class_id_t->scalar<int32>()();
+    class_id_ = class_id_t->scalar<int32>()();
+
+    l1_regularization_ = l1_regularization_t->scalar<float>()();
+    l2_regularization_ = l2_regularization_t->scalar<float>()();
+    tree_complexity_regularization_ =
+        tree_complexity_regularization_t->scalar<float>()();
+    min_node_weight_ = min_node_weight_t->scalar<float>()();
+    feature_column_group_id_ = feature_column_group_id_t->scalar<int32>()();
+  }
+
+  NodeStats ComputeNodeStats(const GradientStats& grad_stats) {
+    return NodeStats(l1_regularization_, l2_regularization_, min_node_weight_,
+                     multiclass_strategy_, grad_stats);
   }
 
-  void FillLeaf(const int class_id, const NodeStats& best_node_stats,
+  void FillLeaf(const NodeStats& best_node_stats,
                 boosted_trees::trees::Leaf* leaf) const {
-    if (class_id == -1) {
+    if (class_id_ == -1) {
       // This would be the case either for TREE_PER_CLASS with only 2 classes,
       // or for other multiclass strategies.
       for (float f : best_node_stats.weight_contribution) {
@@ -93,25 +106,31 @@ class BaseBuildSplitOp : public OpKernel {
       CHECK(best_node_stats.weight_contribution.size() == 1)
           << "Weight contribution size = "
           << best_node_stats.weight_contribution.size();
-      leaf->mutable_sparse_vector()->add_index(class_id);
+      leaf->mutable_sparse_vector()->add_index(class_id_);
       leaf->mutable_sparse_vector()->add_value(
           best_node_stats.weight_contribution[0]);
     }
   }
 
- protected:
+  int32 feature_column_group_id() { return feature_column_group_id_; }
+  float tree_complexity_regularization() {
+    return tree_complexity_regularization_;
+  }
+
+ private:
   LearnerConfig_MultiClassStrategy multiclass_strategy_;
-  int32 feature_column_group_id_;
   float l1_regularization_;
   float l2_regularization_;
-  float min_node_weight_;
   float tree_complexity_regularization_;
+  float min_node_weight_;
+  int32 class_id_;
+  int32 feature_column_group_id_;
 };
 
-class BuildDenseInequalitySplitsOp : public BaseBuildSplitOp {
+class BuildDenseInequalitySplitsOp : public OpKernel {
  public:
   explicit BuildDenseInequalitySplitsOp(OpKernelConstruction* const context)
-      : BaseBuildSplitOp(context) {}
+      : OpKernel(context) {}
 
   void Compute(OpKernelContext* const context) override {
     const Tensor* num_minibatches_t;
@@ -139,9 +158,6 @@ class BuildDenseInequalitySplitsOp : public BaseBuildSplitOp {
     const Tensor* hessians_t;
     OP_REQUIRES_OK(context, context->input("hessians", &hessians_t));
 
-    int class_id;
-    ReadClassId(context, &class_id);
-
     // Find the number of unique partitions before we allocate the output.
     std::vector<int32> partition_boundaries;
     partition_boundaries.push_back(0);
@@ -185,6 +201,7 @@ class BuildDenseInequalitySplitsOp : public BaseBuildSplitOp {
                                 &output_splits_t));
     tensorflow::TTypes<string>::Vec output_splits =
         output_splits_t->vec<string>();
+    SplitBuilderState state(context);
     for (int root_idx = 0; root_idx < num_elements; ++root_idx) {
       float best_gain = std::numeric_limits<float>::lowest();
       int start_index = partition_boundaries[root_idx];
@@ -196,7 +213,7 @@ class BuildDenseInequalitySplitsOp : public BaseBuildSplitOp {
             GradientStats(*gradients_t, *hessians_t, bucket_idx);
       }
       root_gradient_stats *= normalizer_ratio;
-      NodeStats root_stats = ComputeNodeStats(root_gradient_stats);
+      NodeStats root_stats = state.ComputeNodeStats(root_gradient_stats);
       int32 best_bucket_idx = 0;
       NodeStats best_right_node_stats(0);
       NodeStats best_left_node_stats(0);
@@ -206,10 +223,10 @@ class BuildDenseInequalitySplitsOp : public BaseBuildSplitOp {
         GradientStats g(*gradients_t, *hessians_t, bucket_idx);
         g *= normalizer_ratio;
         left_gradient_stats += g;
-        NodeStats left_stats = ComputeNodeStats(left_gradient_stats);
+        NodeStats left_stats = state.ComputeNodeStats(left_gradient_stats);
         GradientStats right_gradient_stats =
             root_gradient_stats - left_gradient_stats;
-        NodeStats right_stats = ComputeNodeStats(right_gradient_stats);
+        NodeStats right_stats = state.ComputeNodeStats(right_gradient_stats);
         if (left_stats.gain + right_stats.gain > best_gain) {
           best_gain = left_stats.gain + right_stats.gain;
           best_left_node_stats = left_stats;
@@ -220,18 +237,18 @@ class BuildDenseInequalitySplitsOp : public BaseBuildSplitOp {
       SplitInfo split_info;
       auto* dense_split =
           split_info.mutable_split_node()->mutable_dense_float_binary_split();
-      dense_split->set_feature_column(feature_column_group_id_);
+      dense_split->set_feature_column(state.feature_column_group_id());
       dense_split->set_threshold(
           bucket_boundaries(bucket_ids(best_bucket_idx, 0)));
 
       auto* left_child = split_info.mutable_left_child();
       auto* right_child = split_info.mutable_right_child();
 
-      FillLeaf(class_id, best_left_node_stats, left_child);
-      FillLeaf(class_id, best_right_node_stats, right_child);
+      state.FillLeaf(best_left_node_stats, left_child);
+      state.FillLeaf(best_right_node_stats, right_child);
       split_info.SerializeToString(&output_splits(root_idx));
       gains(root_idx) =
-          best_gain - root_stats.gain - tree_complexity_regularization_;
+          best_gain - root_stats.gain - state.tree_complexity_regularization();
       output_partition_ids(root_idx) = partition_ids(start_index);
     }
   }
@@ -239,13 +256,10 @@ class BuildDenseInequalitySplitsOp : public BaseBuildSplitOp {
 REGISTER_KERNEL_BUILDER(Name("BuildDenseInequalitySplits").Device(DEVICE_CPU),
                         BuildDenseInequalitySplitsOp);
 
-class BuildSparseInequalitySplitsOp : public BaseBuildSplitOp {
+class BuildSparseInequalitySplitsOp : public OpKernel {
  public:
   explicit BuildSparseInequalitySplitsOp(OpKernelConstruction* const context)
-      : BaseBuildSplitOp(context) {
-    OP_REQUIRES_OK(context,
-                   context->GetAttr("bias_feature_id", &bias_feature_id_));
-  }
+      : OpKernel(context) {}
 
   void Compute(OpKernelContext* const context) override {
     const Tensor* num_minibatches_t;
@@ -275,8 +289,10 @@ class BuildSparseInequalitySplitsOp : public BaseBuildSplitOp {
     const Tensor* hessians_t;
     OP_REQUIRES_OK(context, context->input("hessians", &hessians_t));
 
-    int class_id;
-    ReadClassId(context, &class_id);
+    const Tensor* bias_feature_id_t;
+    OP_REQUIRES_OK(context,
+                   context->input("bias_feature_id", &bias_feature_id_t));
+    int64 bias_feature_id = bias_feature_id_t->scalar<int64>()();
 
     // For each partition (tree node), store starting index for each dimension.
     PartitionAndDimensionBoundaries partition_boundaries;
@@ -354,6 +370,7 @@ class BuildSparseInequalitySplitsOp : public BaseBuildSplitOp {
                                 &output_splits_t));
     tensorflow::TTypes<string>::Vec output_splits =
         output_splits_t->vec<string>();
+    SplitBuilderState state(context);
     // For each tree node that needs to be split.
     for (int root_idx = 0; root_idx < num_elements; ++root_idx) {
       const auto& dimension_boundaries =
@@ -372,7 +389,7 @@ class BuildSparseInequalitySplitsOp : public BaseBuildSplitOp {
 
       OP_REQUIRES(
           context,
-          bucket_ids_and_dimensions(bias_start_index, 0) == bias_feature_id_,
+          bucket_ids_and_dimensions(bias_start_index, 0) == bias_feature_id,
           errors::InvalidArgument("Bias feature ID missing."));
 
       // Dimension for bias feature is always 0
@@ -388,7 +405,7 @@ class BuildSparseInequalitySplitsOp : public BaseBuildSplitOp {
       GradientStats root_gradient_stats(*gradients_t, *hessians_t,
                                         bias_start_index);
       root_gradient_stats *= normalizer_ratio;
-      NodeStats root_stats = ComputeNodeStats(root_gradient_stats);
+      NodeStats root_stats = state.ComputeNodeStats(root_gradient_stats);
 
       // Iterate through dimensions.
       for (int j = 0; j < dimension_boundaries.size() - 1; ++j) {
@@ -408,7 +425,7 @@ class BuildSparseInequalitySplitsOp : public BaseBuildSplitOp {
             << bucket_ids_and_dimensions(start_index, 1) << " and for "
             << bucket_ids_and_dimensions(end_index - 1, 0) << " "
             << bucket_ids_and_dimensions(end_index - 1, 1);
-        if (bucket_ids_and_dimensions(start_index, 0) == bias_feature_id_) {
+        if (bucket_ids_and_dimensions(start_index, 0) == bias_feature_id) {
           // 0-dimension case which has a first bucket for catch all feature.
           CHECK(bucket_ids_and_dimensions(start_index, 1) == 0)
               << "Dimension of bias feature should be 0";
@@ -447,10 +464,10 @@ class BuildSparseInequalitySplitsOp : public BaseBuildSplitOp {
               present_gradient_stats - left_gradient_stats;
 
           {
-            NodeStats left_stats_default_left =
-                ComputeNodeStats(root_gradient_stats - right_gradient_stats);
+            NodeStats left_stats_default_left = state.ComputeNodeStats(
+                root_gradient_stats - right_gradient_stats);
             NodeStats right_stats_default_left =
-                ComputeNodeStats(right_gradient_stats);
+                state.ComputeNodeStats(right_gradient_stats);
             if (left_stats_default_left.gain + right_stats_default_left.gain >
                 best_gain) {
               best_gain =
@@ -466,9 +483,9 @@ class BuildSparseInequalitySplitsOp : public BaseBuildSplitOp {
           // enough missing examples.
           if (!fixed_default_direction) {
             NodeStats left_stats_default_right =
-                ComputeNodeStats(left_gradient_stats);
-            NodeStats right_stats_default_right =
-                ComputeNodeStats(root_gradient_stats - left_gradient_stats);
+                state.ComputeNodeStats(left_gradient_stats);
+            NodeStats right_stats_default_right = state.ComputeNodeStats(
+                root_gradient_stats - left_gradient_stats);
             if (left_stats_default_right.gain + right_stats_default_right.gain >
                 best_gain) {
               best_gain = left_stats_default_right.gain +
@@ -494,7 +511,7 @@ class BuildSparseInequalitySplitsOp : public BaseBuildSplitOp {
                           ->mutable_sparse_float_binary_split_default_left()
                           ->mutable_split();
       }
-      dense_split->set_feature_column(feature_column_group_id_);
+      dense_split->set_feature_column(state.feature_column_group_id());
       // Set the feature index for the best feature column.
       const int64 best_dimension_id =
           bucket_ids_and_dimensions(best_element_idx, 1);
@@ -505,11 +522,11 @@ class BuildSparseInequalitySplitsOp : public BaseBuildSplitOp {
 
       auto* left_child = split_info.mutable_left_child();
       auto* right_child = split_info.mutable_right_child();
-      FillLeaf(class_id, best_left_node_stats, left_child);
-      FillLeaf(class_id, best_right_node_stats, right_child);
+      state.FillLeaf(best_left_node_stats, left_child);
+      state.FillLeaf(best_right_node_stats, right_child);
       split_info.SerializeToString(&output_splits(root_idx));
       gains(root_idx) =
-          best_gain - root_stats.gain - tree_complexity_regularization_;
+          best_gain - root_stats.gain - state.tree_complexity_regularization();
       output_partition_ids(root_idx) = partition_ids(bias_start_index);
     }
   }
@@ -526,19 +543,14 @@ class BuildSparseInequalitySplitsOp : public BaseBuildSplitOp {
   // For each partition, store start indices of feature column dimensions.
   typedef std::vector<std::vector<DimensionBoundary>>
       PartitionAndDimensionBoundaries;
-
-  int64 bias_feature_id_;
 };
 REGISTER_KERNEL_BUILDER(Name("BuildSparseInequalitySplits").Device(DEVICE_CPU),
                         BuildSparseInequalitySplitsOp);
 
-class BuildCategoricalEqualitySplitsOp : public BaseBuildSplitOp {
+class BuildCategoricalEqualitySplitsOp : public OpKernel {
  public:
   explicit BuildCategoricalEqualitySplitsOp(OpKernelConstruction* const context)
-      : BaseBuildSplitOp(context) {
-    OP_REQUIRES_OK(context,
-                   context->GetAttr("bias_feature_id", &bias_feature_id_));
-  }
+      : OpKernel(context) {}
 
   void Compute(OpKernelContext* const context) override {
     const Tensor* num_minibatches_t;
@@ -561,8 +573,10 @@ class BuildCategoricalEqualitySplitsOp : public BaseBuildSplitOp {
     const Tensor* hessians_t;
     OP_REQUIRES_OK(context, context->input("hessians", &hessians_t));
 
-    int class_id;
-    ReadClassId(context, &class_id);
+    const Tensor* bias_feature_id_t;
+    OP_REQUIRES_OK(context,
+                   context->input("bias_feature_id", &bias_feature_id_t));
+    int64 bias_feature_id = bias_feature_id_t->scalar<int64>()();
 
     // Find the number of unique partitions before we allocate the output.
     std::vector<int32> partition_boundaries;
@@ -605,16 +619,17 @@ class BuildCategoricalEqualitySplitsOp : public BaseBuildSplitOp {
                                 &output_splits_t));
     tensorflow::TTypes<string>::Vec output_splits =
         output_splits_t->vec<string>();
+    SplitBuilderState state(context);
     for (int root_idx = 0; root_idx < num_elements; ++root_idx) {
       float best_gain = std::numeric_limits<float>::lowest();
       int start_index = partition_boundaries[non_empty_partitions[root_idx]];
       int end_index = partition_boundaries[non_empty_partitions[root_idx] + 1];
       // First feature ID in each partition should be the bias feature.
-      OP_REQUIRES(context, feature_ids(start_index, 0) == bias_feature_id_,
+      OP_REQUIRES(context, feature_ids(start_index, 0) == bias_feature_id,
                   errors::InvalidArgument("Bias feature ID missing."));
       GradientStats root_gradient_stats(*gradients_t, *hessians_t, start_index);
       root_gradient_stats *= normalizer_ratio;
-      NodeStats root_stats = ComputeNodeStats(root_gradient_stats);
+      NodeStats root_stats = state.ComputeNodeStats(root_gradient_stats);
       int32 best_feature_idx = 0;
       NodeStats best_right_node_stats(0);
       NodeStats best_left_node_stats(0);
@@ -625,8 +640,8 @@ class BuildCategoricalEqualitySplitsOp : public BaseBuildSplitOp {
         left_gradient_stats *= normalizer_ratio;
         GradientStats right_gradient_stats =
             root_gradient_stats - left_gradient_stats;
-        NodeStats left_stats = ComputeNodeStats(left_gradient_stats);
-        NodeStats right_stats = ComputeNodeStats(right_gradient_stats);
+        NodeStats left_stats = state.ComputeNodeStats(left_gradient_stats);
+        NodeStats right_stats = state.ComputeNodeStats(right_gradient_stats);
         if (left_stats.gain + right_stats.gain > best_gain) {
           best_gain = left_stats.gain + right_stats.gain;
           best_left_node_stats = left_stats;
@@ -637,21 +652,18 @@ class BuildCategoricalEqualitySplitsOp : public BaseBuildSplitOp {
       SplitInfo split_info;
       auto* equality_split = split_info.mutable_split_node()
                                  ->mutable_categorical_id_binary_split();
-      equality_split->set_feature_column(feature_column_group_id_);
+      equality_split->set_feature_column(state.feature_column_group_id());
       equality_split->set_feature_id(feature_ids(best_feature_idx, 0));
       auto* left_child = split_info.mutable_left_child();
       auto* right_child = split_info.mutable_right_child();
-      FillLeaf(class_id, best_left_node_stats, left_child);
-      FillLeaf(class_id, best_right_node_stats, right_child);
+      state.FillLeaf(best_left_node_stats, left_child);
+      state.FillLeaf(best_right_node_stats, right_child);
       split_info.SerializeToString(&output_splits(root_idx));
       gains(root_idx) =
-          best_gain - root_stats.gain - tree_complexity_regularization_;
+          best_gain - root_stats.gain - state.tree_complexity_regularization();
       output_partition_ids(root_idx) = partition_ids(start_index);
     }
   }
-
- private:
-  int64 bias_feature_id_;
 };
 
 REGISTER_KERNEL_BUILDER(
index f06b73c..23f4021 100644 (file)
@@ -64,6 +64,8 @@ from __future__ import print_function
 import re
 
 from tensorflow.contrib.boosted_trees.lib.learner.batch import base_split_handler
+from tensorflow.contrib.boosted_trees.python.ops import gen_quantile_ops
+from tensorflow.contrib.boosted_trees.python.ops import gen_stats_accumulator_ops
 from tensorflow.contrib.boosted_trees.python.ops import quantile_ops
 from tensorflow.contrib.boosted_trees.python.ops import split_handler_ops
 from tensorflow.contrib.boosted_trees.python.ops import stats_accumulator_ops
@@ -72,6 +74,7 @@ from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import function
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import sparse_tensor
+from tensorflow.python.framework import tensor_shape
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import control_flow_ops
 from tensorflow.python.ops import math_ops
@@ -327,9 +330,6 @@ class SparseSplitHandler(InequalitySplitHandler):
         multiclass_strategy=multiclass_strategy,
         init_stamp_token=init_stamp_token,
         name=name)
-    # Register sparse_make_stats_update function as an Op to the graph.
-    g = ops.get_default_graph()
-    sparse_make_stats_update.add_to_graph(g)
     self._sparse_float_column = sparse_float_column
 
   def scheduled_reads(self):
@@ -361,8 +361,8 @@ class SparseSplitHandler(InequalitySplitHandler):
     are_buckets_ready, buckets = scheduled_reads[0]
     with ops.name_scope(self._name, "SparseSplitHandler"):
       (quantile_indices, quantile_values, quantile_shapes, quantile_weights,
-       example_partition_ids,
-       feature_ids, gradients, hessians) = sparse_make_stats_update(
+       example_partition_ids, feature_ids, gradients,
+       hessians) = sparse_make_stats_update(
            is_active, are_buckets_ready, self._sparse_float_column.indices,
            self._sparse_float_column.values,
            self._sparse_float_column.dense_shape, buckets,
@@ -379,42 +379,104 @@ class SparseSplitHandler(InequalitySplitHandler):
 
   def make_splits(self, stamp_token, next_stamp_token, class_id):
     """Create the best split using the accumulated stats and flush the state."""
+    if (self._gradient_shape == tensor_shape.scalar() and
+        self._hessian_shape == tensor_shape.scalar()):
+      handler = make_sparse_split_scalar
+    else:
+      handler = make_sparse_split_tensor
+
+    are_splits_ready, partition_ids, gains, split_infos = (
+        handler(self._quantile_accumulator.resource(),
+                self._stats_accumulator.resource(), stamp_token,
+                next_stamp_token, self._multiclass_strategy, class_id,
+                self._feature_column_group_id, self._l1_regularization,
+                self._l2_regularization, self._tree_complexity_regularization,
+                self._min_node_weight))
+    return are_splits_ready, partition_ids, gains, split_infos
+
+
+def _specialize_sparse_split(is_multi_dimentional):
+  """Builds a specialized version of the function."""
+
+  def _make_sparse_split(quantile_accumulator_handle, stats_accumulator_handle,
+                         stamp_token, next_stamp_token, multiclass_strategy,
+                         class_id, feature_column_id, l1_regularization,
+                         l2_regularization, tree_complexity_regularization,
+                         min_node_weight, is_multi_dimentional):
+    """Function that builds splits for a sparse feature column."""
     # Get the bucket boundaries
     are_splits_ready, buckets = (
-        self._quantile_accumulator.get_buckets(stamp_token))
+        gen_quantile_ops.quantile_accumulator_get_buckets(
+            quantile_accumulator_handles=[quantile_accumulator_handle],
+            stamp_token=stamp_token))
 
     # After we receive the boundaries from previous iteration we can flush
     # the quantile accumulator.
-    with ops.control_dependencies([buckets]):
-      flush_quantiles = self._quantile_accumulator.flush(
-          stamp_token=stamp_token, next_stamp_token=next_stamp_token)
-
-    with ops.device(None):
-      with ops.device(self._stats_accumulator.resource().device):
-        num_minibatches, partition_ids, bucket_ids, gradients, hessians = (
-            self._stats_accumulator.flush(stamp_token, next_stamp_token))
-
-        # Put quantile and stats accumulator flushing in the dependency path.
-        are_splits_ready = control_flow_ops.with_dependencies(
-            [flush_quantiles, partition_ids], are_splits_ready)
-        partition_ids, gains, split_infos = (
-            split_handler_ops.build_sparse_inequality_splits(
-                num_minibatches=num_minibatches,
-                bucket_boundaries=buckets,
-                partition_ids=partition_ids,
-                bucket_ids=bucket_ids,
-                gradients=gradients,
-                hessians=hessians,
-                class_id=class_id,
-                feature_column_group_id=self._feature_column_group_id,
-                l1_regularization=self._l1_regularization,
-                l2_regularization=self._l2_regularization,
-                tree_complexity_regularization=self.
-                _tree_complexity_regularization,
-                min_node_weight=self._min_node_weight,
-                bias_feature_id=_BIAS_FEATURE_ID,
-                multiclass_strategy=self._multiclass_strategy))
-    return (are_splits_ready, partition_ids, gains, split_infos)
+    with ops.control_dependencies([buckets[0]]):
+      flush_quantiles = gen_quantile_ops.quantile_accumulator_flush(
+          quantile_accumulator_handle=quantile_accumulator_handle,
+          stamp_token=stamp_token,
+          next_stamp_token=next_stamp_token)
+
+    if is_multi_dimentional:
+      num_minibatches, partition_ids, bucket_ids, gradients, hessians = (
+          gen_stats_accumulator_ops.stats_accumulator_tensor_flush(
+              stats_accumulator_handle, stamp_token, next_stamp_token))
+    else:
+      num_minibatches, partition_ids, bucket_ids, gradients, hessians = (
+          gen_stats_accumulator_ops.stats_accumulator_scalar_flush(
+              stats_accumulator_handle, stamp_token, next_stamp_token))
+
+    # Put quantile and stats accumulator flushing in the dependency path.
+    with ops.control_dependencies([flush_quantiles, partition_ids]):
+      are_splits_ready = array_ops.identity(are_splits_ready)
+    partition_ids, gains, split_infos = (
+        split_handler_ops.build_sparse_inequality_splits(
+            num_minibatches=num_minibatches,
+            bucket_boundaries=buckets[0],
+            partition_ids=partition_ids,
+            bucket_ids=bucket_ids,
+            gradients=gradients,
+            hessians=hessians,
+            class_id=class_id,
+            feature_column_group_id=feature_column_id,
+            l1_regularization=l1_regularization,
+            l2_regularization=l2_regularization,
+            tree_complexity_regularization=tree_complexity_regularization,
+            min_node_weight=min_node_weight,
+            bias_feature_id=_BIAS_FEATURE_ID,
+            multiclass_strategy=multiclass_strategy))
+    return are_splits_ready, partition_ids, gains, split_infos
+
+  @function.Defun(
+      dtypes.resource,
+      dtypes.resource,
+      dtypes.int64,
+      dtypes.int64,
+      dtypes.int32,
+      dtypes.int32,
+      dtypes.int32,
+      dtypes.float32,
+      dtypes.float32,
+      dtypes.float32,
+      dtypes.float32,
+      noinline=True)
+  def f(quantile_accumulator_handle, stats_accumulator_handle, stamp_token,
+        next_stamp_token, multiclass_strategy, class_id, feature_column_id,
+        l1_regularization, l2_regularization, tree_complexity_regularization,
+        min_node_weight):
+    """Function that builds splits for a sparse feature column."""
+    return _make_sparse_split(
+        quantile_accumulator_handle, stats_accumulator_handle, stamp_token,
+        next_stamp_token, multiclass_strategy, class_id, feature_column_id,
+        l1_regularization, l2_regularization, tree_complexity_regularization,
+        min_node_weight, is_multi_dimentional)
+
+  return f
+
+
+make_sparse_split_scalar = _specialize_sparse_split(is_multi_dimentional=False)
+make_sparse_split_tensor = _specialize_sparse_split(is_multi_dimentional=True)
 
 
 @function.Defun(
@@ -540,8 +602,9 @@ def sparse_make_stats_update(
 
   empty_float = constant_op.constant([], dtype=dtypes.float32)
   handler_not_active = (constant_op.constant(
-      [], dtype=dtypes.int64, shape=[0, 2]), empty_float, constant_op.constant(
-          [0, 1], dtype=dtypes.int64), empty_float)
+      [], dtype=dtypes.int64, shape=[0, 2]), empty_float,
+                        constant_op.constant([0, 1], dtype=dtypes.int64),
+                        empty_float)
   handler_active = (sparse_column_indices, sparse_column_values,
                     sparse_column_shape, weights)
   quantile_indices, quantile_values, quantile_shape, quantile_weights = (
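
A minimal sketch of the Defun specialization pattern used in ordinal_split_handler.py above, assuming TensorFlow 1.x and tensorflow.python.framework.function; the function body below is only a stand-in for the real call to split_handler_ops.build_sparse_inequality_splits. Wrapping the split construction in a noinline Defun means the subgraph is defined once, and each handler adds a single function-call node instead of re-expanding the whole split-building graph.

import tensorflow as tf
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import function


def _specialize(is_multi_dimensional):
  """Builds a Defun specialized for the scalar or tensor accumulator case."""

  @function.Defun(dtypes.float32, dtypes.float32, noinline=True)
  def f(gradients, hessians):
    # Stand-in computation; the commit calls
    # split_handler_ops.build_sparse_inequality_splits here instead.
    if is_multi_dimensional:
      return tf.reduce_sum(gradients * hessians)
    return tf.reduce_sum(gradients) * tf.reduce_sum(hessians)

  return f


make_split_scalar = _specialize(is_multi_dimensional=False)
make_split_tensor = _specialize(is_multi_dimensional=True)

g = tf.Graph()
with g.as_default():
  grads = tf.constant([0.5, 1.5])
  hess = tf.constant([1.0, 2.0])
  # Each call adds one function-call op; with noinline=True the body is not
  # re-expanded into the surrounding graph.
  out_a = make_split_scalar(grads, hess)
  out_b = make_split_scalar(grads, hess)

with tf.Session(graph=g) as sess:
  print(sess.run([out_a, out_b]))  # [6.0, 6.0]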
index 54d0301..c081a3f 100644 (file)
@@ -18,6 +18,8 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+import numpy as np
+
 from tensorflow.contrib.boosted_trees.lib.learner.batch import ordinal_split_handler
 from tensorflow.contrib.boosted_trees.proto import learner_pb2
 from tensorflow.contrib.boosted_trees.proto import split_info_pb2
@@ -92,7 +94,9 @@ class DenseSplitHandlerTest(test_util.TensorFlowTestCase):
           example_weights,
           is_active=array_ops.constant([True, True]))
       with ops.control_dependencies([update_1]):
-        are_splits_ready = split_handler.make_splits(0, 1, class_id)[0]
+        are_splits_ready = split_handler.make_splits(
+            np.int64(0), np.int64(1), class_id)[0]
+
       with ops.control_dependencies([are_splits_ready]):
         update_2 = split_handler.update_stats_sync(
             1,
@@ -105,7 +109,7 @@ class DenseSplitHandlerTest(test_util.TensorFlowTestCase):
             is_active=array_ops.constant([True, True]))
       with ops.control_dependencies([update_2]):
         are_splits_ready2, partitions, gains, splits = (
-            split_handler.make_splits(1, 2, class_id))
+            split_handler.make_splits(np.int64(1), np.int64(2), class_id))
         are_splits_ready, are_splits_ready2, partitions, gains, splits = (
             sess.run([
                 are_splits_ready, are_splits_ready2, partitions, gains, splits
@@ -227,7 +231,9 @@ class DenseSplitHandlerTest(test_util.TensorFlowTestCase):
           example_weights,
           is_active=array_ops.constant([True, True]))
       with ops.control_dependencies([update_1]):
-        are_splits_ready = split_handler.make_splits(0, 1, class_id)[0]
+        are_splits_ready = split_handler.make_splits(
+            np.int64(0), np.int64(1), class_id)[0]
+
       with ops.control_dependencies([are_splits_ready]):
         update_2 = split_handler.update_stats_sync(
             1,
@@ -240,7 +246,7 @@ class DenseSplitHandlerTest(test_util.TensorFlowTestCase):
             is_active=array_ops.constant([True, True]))
       with ops.control_dependencies([update_2]):
         are_splits_ready2, partitions, gains, splits = (
-            split_handler.make_splits(1, 2, class_id))
+            split_handler.make_splits(np.int64(1), np.int64(2), class_id))
         are_splits_ready, are_splits_ready2, partitions, gains, splits = (
             sess.run([
                 are_splits_ready, are_splits_ready2, partitions, gains, splits
@@ -313,7 +319,8 @@ class DenseSplitHandlerTest(test_util.TensorFlowTestCase):
           example_weights,
           is_active=array_ops.constant([True, True]))
       with ops.control_dependencies([update_1]):
-        are_splits_ready = split_handler.make_splits(0, 1, class_id)[0]
+        are_splits_ready = split_handler.make_splits(
+            np.int64(0), np.int64(1), class_id)[0]
       with ops.control_dependencies([are_splits_ready]):
         update_2 = split_handler.update_stats_sync(
             1,
@@ -326,7 +333,7 @@ class DenseSplitHandlerTest(test_util.TensorFlowTestCase):
             is_active=array_ops.constant([True, True]))
       with ops.control_dependencies([update_2]):
         are_splits_ready2, partitions, gains, splits = (
-            split_handler.make_splits(1, 2, class_id))
+            split_handler.make_splits(np.int64(1), np.int64(2), class_id))
         are_splits_ready, are_splits_ready2, partitions, gains, splits = (
             sess.run([
                 are_splits_ready, are_splits_ready2, partitions, gains, splits
@@ -396,7 +403,8 @@ class DenseSplitHandlerTest(test_util.TensorFlowTestCase):
           example_weights,
           is_active=array_ops.constant([True, False]))
       with ops.control_dependencies([update_1]):
-        are_splits_ready = split_handler.make_splits(0, 1, class_id)[0]
+        are_splits_ready = split_handler.make_splits(
+            np.int64(0), np.int64(1), class_id)[0]
       with ops.control_dependencies([are_splits_ready]):
         update_2 = split_handler.update_stats_sync(
             1,
@@ -409,7 +417,7 @@ class DenseSplitHandlerTest(test_util.TensorFlowTestCase):
             is_active=array_ops.constant([False, True]))
       with ops.control_dependencies([update_2]):
         are_splits_ready2, partitions, gains, splits = (
-            split_handler.make_splits(1, 2, class_id))
+            split_handler.make_splits(np.int64(1), np.int64(2), class_id))
         are_splits_ready, are_splits_ready2, partitions, gains, splits = (
             sess.run([
                 are_splits_ready, are_splits_ready2, partitions, gains, splits
@@ -470,7 +478,8 @@ class DenseSplitHandlerTest(test_util.TensorFlowTestCase):
           example_weights,
           is_active=array_ops.constant([True, True]))
       with ops.control_dependencies([update_1]):
-        are_splits_ready = split_handler.make_splits(0, 1, class_id)[0]
+        are_splits_ready = split_handler.make_splits(
+            np.int64(0), np.int64(1), class_id)[0]
       with ops.control_dependencies([are_splits_ready]):
         update_2 = split_handler.update_stats_sync(
             1,
@@ -483,7 +492,7 @@ class DenseSplitHandlerTest(test_util.TensorFlowTestCase):
             is_active=array_ops.constant([True, True]))
       with ops.control_dependencies([update_2]):
         are_splits_ready2, partitions, gains, splits = (
-            split_handler.make_splits(1, 2, class_id))
+            split_handler.make_splits(np.int64(1), np.int64(2), class_id))
         are_splits_ready, are_splits_ready2, partitions, gains, splits = (
             sess.run([
                 are_splits_ready, are_splits_ready2, partitions, gains, splits
@@ -603,7 +612,8 @@ class DenseSplitHandlerTest(test_util.TensorFlowTestCase):
           example_weights,
           is_active=array_ops.constant([True, True]))
       with ops.control_dependencies([update_1]):
-        are_splits_ready = split_handler.make_splits(0, 1, class_id)[0]
+        are_splits_ready = split_handler.make_splits(
+            np.int64(0), np.int64(1), class_id)[0]
       with ops.control_dependencies([are_splits_ready]):
         update_2 = split_handler.update_stats_sync(
             1,
@@ -616,7 +626,7 @@ class DenseSplitHandlerTest(test_util.TensorFlowTestCase):
             is_active=array_ops.constant([True, True]))
       with ops.control_dependencies([update_2]):
         are_splits_ready2, partitions, gains, splits = (
-            split_handler.make_splits(1, 2, class_id))
+            split_handler.make_splits(np.int64(1), np.int64(2), class_id))
         are_splits_ready, are_splits_ready2, partitions, gains, splits = (
             sess.run([
                 are_splits_ready, are_splits_ready2, partitions, gains, splits
@@ -685,10 +695,10 @@ class SparseSplitHandlerTest(test_util.TensorFlowTestCase):
       class_id = -1
 
       split_handler = ordinal_split_handler.SparseSplitHandler(
-          l1_regularization=0,
-          l2_regularization=2,
-          tree_complexity_regularization=0,
-          min_node_weight=0,
+          l1_regularization=0.0,
+          l2_regularization=2.0,
+          tree_complexity_regularization=0.0,
+          min_node_weight=0.0,
           epsilon=0.01,
           num_quantiles=2,
           feature_column_group_id=0,
@@ -713,8 +723,8 @@ class SparseSplitHandlerTest(test_util.TensorFlowTestCase):
           example_weights,
           is_active=array_ops.constant([True, True]))
       with ops.control_dependencies([update_1]):
-        are_splits_ready = split_handler.make_splits(0, 1, class_id)[0]
-
+        are_splits_ready = split_handler.make_splits(
+            np.int64(0), np.int64(1), class_id)[0]
       with ops.control_dependencies([are_splits_ready]):
         update_2 = split_handler.update_stats_sync(
             1,
@@ -727,7 +737,7 @@ class SparseSplitHandlerTest(test_util.TensorFlowTestCase):
             is_active=array_ops.constant([True, True]))
       with ops.control_dependencies([update_2]):
         are_splits_ready2, partitions, gains, splits = (
-            split_handler.make_splits(1, 2, class_id))
+            split_handler.make_splits(np.int64(1), np.int64(2), class_id))
         are_splits_ready, are_splits_ready2, partitions, gains, splits = (
             sess.run([
                 are_splits_ready, are_splits_ready2, partitions, gains, splits
@@ -811,10 +821,10 @@ class SparseSplitHandlerTest(test_util.TensorFlowTestCase):
       class_id = -1
 
       split_handler = ordinal_split_handler.SparseSplitHandler(
-          l1_regularization=0,
-          l2_regularization=2,
-          tree_complexity_regularization=0,
-          min_node_weight=0,
+          l1_regularization=0.0,
+          l2_regularization=2.0,
+          tree_complexity_regularization=0.0,
+          min_node_weight=0.0,
           epsilon=0.01,
           num_quantiles=2,
           feature_column_group_id=0,
@@ -839,7 +849,8 @@ class SparseSplitHandlerTest(test_util.TensorFlowTestCase):
           example_weights,
           is_active=array_ops.constant([True, True]))
       with ops.control_dependencies([update_1]):
-        are_splits_ready = split_handler.make_splits(0, 1, class_id)[0]
+        are_splits_ready = split_handler.make_splits(
+            np.int64(0), np.int64(1), class_id)[0]
 
       with ops.control_dependencies([are_splits_ready]):
         update_2 = split_handler.update_stats_sync(
@@ -853,7 +864,7 @@ class SparseSplitHandlerTest(test_util.TensorFlowTestCase):
             is_active=array_ops.constant([True, True]))
       with ops.control_dependencies([update_2]):
         are_splits_ready2, partitions, gains, splits = (
-            split_handler.make_splits(1, 2, class_id))
+            split_handler.make_splits(np.int64(1), np.int64(2), class_id))
         are_splits_ready, are_splits_ready2, partitions, gains, splits = (
             sess.run([
                 are_splits_ready, are_splits_ready2, partitions, gains, splits
@@ -905,10 +916,10 @@ class SparseSplitHandlerTest(test_util.TensorFlowTestCase):
       class_id = -1
 
       split_handler = ordinal_split_handler.SparseSplitHandler(
-          l1_regularization=0,
-          l2_regularization=2,
-          tree_complexity_regularization=0,
-          min_node_weight=0,
+          l1_regularization=0.0,
+          l2_regularization=2.0,
+          tree_complexity_regularization=0.0,
+          min_node_weight=0.0,
           epsilon=0.01,
           num_quantiles=2,
           feature_column_group_id=0,
@@ -933,7 +944,8 @@ class SparseSplitHandlerTest(test_util.TensorFlowTestCase):
           example_weights,
           is_active=array_ops.constant([True, True]))
       with ops.control_dependencies([update_1]):
-        are_splits_ready = split_handler.make_splits(0, 1, class_id)[0]
+        are_splits_ready = split_handler.make_splits(
+            np.int64(0), np.int64(1), class_id)[0]
 
       with ops.control_dependencies([are_splits_ready]):
         update_2 = split_handler.update_stats_sync(
@@ -947,7 +959,7 @@ class SparseSplitHandlerTest(test_util.TensorFlowTestCase):
             is_active=array_ops.constant([True, True]))
       with ops.control_dependencies([update_2]):
         are_splits_ready2, partitions, gains, splits = (
-            split_handler.make_splits(1, 2, class_id))
+            split_handler.make_splits(np.int64(1), np.int64(2), class_id))
         are_splits_ready, are_splits_ready2, partitions, gains, splits = (
             sess.run([
                 are_splits_ready, are_splits_ready2, partitions, gains, splits
@@ -996,10 +1008,10 @@ class SparseSplitHandlerTest(test_util.TensorFlowTestCase):
       class_id = -1
 
       split_handler = ordinal_split_handler.SparseSplitHandler(
-          l1_regularization=0,
-          l2_regularization=2,
-          tree_complexity_regularization=0,
-          min_node_weight=0,
+          l1_regularization=0.0,
+          l2_regularization=2.0,
+          tree_complexity_regularization=0.0,
+          min_node_weight=0.0,
           epsilon=0.01,
           num_quantiles=2,
           feature_column_group_id=0,
@@ -1024,7 +1036,8 @@ class SparseSplitHandlerTest(test_util.TensorFlowTestCase):
           example_weights,
           is_active=array_ops.constant([True, False]))
       with ops.control_dependencies([update_1]):
-        are_splits_ready = split_handler.make_splits(0, 1, class_id)[0]
+        are_splits_ready = split_handler.make_splits(
+            np.int64(0), np.int64(1), class_id)[0]
 
       with ops.control_dependencies([are_splits_ready]):
         update_2 = split_handler.update_stats_sync(
@@ -1038,7 +1051,7 @@ class SparseSplitHandlerTest(test_util.TensorFlowTestCase):
             is_active=array_ops.constant([False, True]))
       with ops.control_dependencies([update_2]):
         are_splits_ready2, partitions, gains, splits = (
-            split_handler.make_splits(1, 2, class_id))
+            split_handler.make_splits(np.int64(1), np.int64(2), class_id))
         are_splits_ready, are_splits_ready2, partitions, gains, splits = (
             sess.run([
                 are_splits_ready, are_splits_ready2, partitions, gains, splits
@@ -1065,10 +1078,10 @@ class SparseSplitHandlerTest(test_util.TensorFlowTestCase):
       class_id = -1
 
       split_handler = ordinal_split_handler.SparseSplitHandler(
-          l1_regularization=0,
-          l2_regularization=2,
-          tree_complexity_regularization=0,
-          min_node_weight=0,
+          l1_regularization=0.0,
+          l2_regularization=2.0,
+          tree_complexity_regularization=0.0,
+          min_node_weight=0.0,
           epsilon=0.01,
           num_quantiles=2,
           feature_column_group_id=0,
@@ -1096,7 +1109,8 @@ class SparseSplitHandlerTest(test_util.TensorFlowTestCase):
           example_weights,
           is_active=array_ops.constant([True, True]))
       with ops.control_dependencies([update_1]):
-        are_splits_ready = split_handler.make_splits(0, 1, class_id)[0]
+        are_splits_ready = split_handler.make_splits(
+            np.int64(0), np.int64(1), class_id)[0]
 
       with ops.control_dependencies([are_splits_ready]):
         update_2 = split_handler.update_stats_sync(
@@ -1110,7 +1124,7 @@ class SparseSplitHandlerTest(test_util.TensorFlowTestCase):
             is_active=array_ops.constant([True, True]))
       with ops.control_dependencies([update_2]):
         are_splits_ready2, partitions, gains, splits = (
-            split_handler.make_splits(1, 2, class_id))
+            split_handler.make_splits(np.int64(1), np.int64(2), class_id))
         are_splits_ready, are_splits_ready2, partitions, gains, splits = (
             sess.run([
                 are_splits_ready, are_splits_ready2, partitions, gains, splits
@@ -1138,10 +1152,10 @@ class SparseSplitHandlerTest(test_util.TensorFlowTestCase):
       class_id = -1
 
       split_handler = ordinal_split_handler.SparseSplitHandler(
-          l1_regularization=0,
-          l2_regularization=2,
-          tree_complexity_regularization=0,
-          min_node_weight=0,
+          l1_regularization=0.0,
+          l2_regularization=2.0,
+          tree_complexity_regularization=0.0,
+          min_node_weight=0.0,
           epsilon=0.01,
           num_quantiles=2,
           feature_column_group_id=0,
@@ -1166,7 +1180,8 @@ class SparseSplitHandlerTest(test_util.TensorFlowTestCase):
           example_weights,
           is_active=array_ops.constant([True, True]))
       with ops.control_dependencies([update_1]):
-        are_splits_ready = split_handler.make_splits(0, 1, class_id)[0]
+        are_splits_ready = split_handler.make_splits(
+            np.int64(0), np.int64(1), class_id)[0]
 
       with ops.control_dependencies([are_splits_ready]):
         update_2 = split_handler.update_stats_sync(
@@ -1180,7 +1195,7 @@ class SparseSplitHandlerTest(test_util.TensorFlowTestCase):
             is_active=array_ops.constant([True, True]))
       with ops.control_dependencies([update_2]):
         are_splits_ready2, partitions, gains, splits = (
-            split_handler.make_splits(1, 2, class_id))
+            split_handler.make_splits(np.int64(1), np.int64(2), class_id))
         are_splits_ready, are_splits_ready2, partitions, gains, splits = (
             sess.run([
                 are_splits_ready, are_splits_ready2, partitions, gains, splits
index 5d0ebbf..ca5c7f3 100644 (file)
@@ -23,12 +23,6 @@ using shape_inference::InferenceContext;
 using shape_inference::ShapeHandle;
 
 REGISTER_OP("BuildDenseInequalitySplits")
-    .Attr("feature_column_group_id: int")
-    .Attr("l1_regularization: float")
-    .Attr("l2_regularization: float")
-    .Attr("tree_complexity_regularization: float")
-    .Attr("min_node_weight: float")
-    .Attr("multiclass_strategy: int")
     .Input("num_minibatches: int64")
     .Input("partition_ids: int32")
     .Input("bucket_ids: int64")
@@ -36,6 +30,12 @@ REGISTER_OP("BuildDenseInequalitySplits")
     .Input("hessians: float32")
     .Input("bucket_boundaries: float32")
     .Input("class_id: int32")
+    .Input("feature_column_group_id: int32")
+    .Input("l1_regularization: float")
+    .Input("l2_regularization: float")
+    .Input("tree_complexity_regularization: float")
+    .Input("min_node_weight: float")
+    .Input("multiclass_strategy: int32")
     .Output("output_partition_ids: int32")
     .Output("gains: float32")
     .Output("split_infos: string")
@@ -73,6 +73,17 @@ bucket_ids: A rank 2 tensor of buckets IDs and dimensions.
 gradients: A rank 1 tensor of gradients.
 hessians: A rank 1 tensor of hessians.
 bucket_boundaries: A rank 1 tensor, thresholds that were used for bucketization.
+class_id: A scalar, the class id for which we're building the splits.
+feature_column_group_id: A scalar, the index of the feature we are splitting on.
+l1_regularization: A scalar, which specifies the l1 regularization term.
+l2_regularization: A scalar, which specifies the l2 regularization term.
+tree_complexity_regularization: A scalar, which specifies the tree complexity
+    regularization term.
+min_node_weight: A scalar, minimum sum of example hessian needed in a child.
+    If a split results in a leaf node with a smaller value, the split will not
+    be considered.
+multiclass_strategy: A scalar, specifying the multiclass handling strategy.
+    See LearnerConfig.MultiClassStrategy for valid values.
 output_partition_ids: A rank 1 tensor, the partition IDs that we created splits
     for.
 gains: A rank 1 tensor, for the computed gain for the created splits.
@@ -81,13 +92,6 @@ split_infos: A rank 1 tensor of serialized protos which contains the
 )doc");
 
 REGISTER_OP("BuildSparseInequalitySplits")
-    .Attr("feature_column_group_id: int")
-    .Attr("bias_feature_id: int")
-    .Attr("l1_regularization: float")
-    .Attr("l2_regularization: float")
-    .Attr("tree_complexity_regularization: float")
-    .Attr("min_node_weight: float")
-    .Attr("multiclass_strategy: int")
     .Input("num_minibatches: int64")
     .Input("partition_ids: int32")
     .Input("bucket_ids: int64")
@@ -95,6 +99,13 @@ REGISTER_OP("BuildSparseInequalitySplits")
     .Input("hessians: float32")
     .Input("bucket_boundaries: float32")
     .Input("class_id: int32")
+    .Input("feature_column_group_id: int32")
+    .Input("bias_feature_id: int64")
+    .Input("l1_regularization: float")
+    .Input("l2_regularization: float")
+    .Input("tree_complexity_regularization: float")
+    .Input("min_node_weight: float")
+    .Input("multiclass_strategy: int32")
     .Output("output_partition_ids: int32")
     .Output("gains: float32")
     .Output("split_infos: string")
@@ -133,6 +144,17 @@ bucket_ids: A rank 2 tensor of buckets IDs and dimensions.
 gradients: A rank 1 tensor of gradients.
 hessians: A rank 1 tensor of hessians.
 bucket_boundaries: A rank 1 tensor, thresholds that were used for bucketization.
+class_id: A scalar, the class id for which we're building the splits.
+feature_column_group_id: A scalar, the index of the feature we are splitting on.
+l1_regularization: A scalar, which specifies the l1 regularization term.
+l2_regularization: A scalar, which specifies the l2 regularization term.
+tree_complexity_regularization: A scalar, which specifies the tree complexity
+    regularization term.
+min_node_weight: A scalar, minimum sum of example hessian needed in a child.
+    If a split results in a leaf node with a smaller value, the split will not
+    be considered.
+multiclass_strategy: A scalar, specifying the multiclass handling strategy.
+    See LearnerConfig.MultiClassStrategy for valid values.
 output_partition_ids: A rank 1 tensor, the partition IDs that we created splits
     for.
 gains: A rank 1 tensor, for the computed gain for the created splits.
@@ -141,19 +163,19 @@ split_infos: A rank 1 tensor of serialized protos which contains the
 )doc");
 
 REGISTER_OP("BuildCategoricalEqualitySplits")
-    .Attr("feature_column_group_id: int")
-    .Attr("bias_feature_id: int")
-    .Attr("l1_regularization: float")
-    .Attr("l2_regularization: float")
-    .Attr("tree_complexity_regularization: float")
-    .Attr("min_node_weight: float")
-    .Attr("multiclass_strategy: int")
     .Input("num_minibatches: int64")
     .Input("partition_ids: int32")
     .Input("feature_ids: int64")
     .Input("gradients: float32")
     .Input("hessians: float32")
     .Input("class_id: int32")
+    .Input("feature_column_group_id: int32")
+    .Input("bias_feature_id: int64")
+    .Input("l1_regularization: float")
+    .Input("l2_regularization: float")
+    .Input("tree_complexity_regularization: float")
+    .Input("min_node_weight: float")
+    .Input("multiclass_strategy: int32")
     .Output("output_partition_ids: int32")
     .Output("gains: float32")
     .Output("split_infos: string")
@@ -188,6 +210,17 @@ partition_ids: A rank 1 tensor of partition IDs.
 feature_ids: A rank 2 tensor of feature IDs and dimensions.
 gradients: A rank 1 tensor of gradients.
 hessians: A rank 1 tensor of hessians.
+class_id: A scalar, the class id for which we're building the splits.
+feature_column_group_id: A scalar, the index of the feature we are splitting on.
+l1_regularization: A scalar, which specifies the l1 regularization term.
+l2_regularization: A scalar, which specifies the l2 regularization term.
+tree_complexity_regularization: A scalar, which specifies the tree complexity
+    regularization term.
+min_node_weight: A scalar, minimum sum of example hessian needed in a child.
+    If a split results in a leaf node with a smaller value, the split will not
+    be considered.
+multiclass_strategy: A scalar, specifying the multiclass handling strategy.
+    See LearnerConfig.MultiClassStrategy for valid values.
 output_partition_ids: A rank 1 tensor, the partition IDs that we created splits
     for.
 gains: A rank 1 tensor, for the computed gain for the created splits.
@@ -196,4 +229,3 @@ split_infos: A rank 1 tensor of serialized protos which contains the
 )doc");
 
 }  // namespace tensorflow
-   // namespace tensorflow
index 7a5f329..8434209 100644 (file)
@@ -20,6 +20,8 @@ from __future__ import print_function
 import abc
 import collections
 
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import tensor_shape
 from tensorflow.python.ops import array_ops
@@ -60,6 +62,7 @@ def _move_tensors(tensors, device):
   """Moves a list of tensors to a device by concatenating/splitting them."""
   # Reset the device setting to avoid weird interactions with device merging
   # logic.
+  zero = constant_op.constant(0, dtype=dtypes.int32)
   with ops.device(None):
     if all(tensor.shape == tensor_shape.scalar() for tensor in tensors):
       with ops.device(tensors[0].device):
@@ -68,12 +71,11 @@ def _move_tensors(tensors, device):
         return array_ops.unstack(values)
     else:
       with ops.device(tensors[0].device):
-        sizes = array_ops.stack(
-            [array_ops.shape(tensor)[0] for tensor in tensors])
-        values = array_ops.concat(tensors, axis=0)
+        sizes = array_ops.stack(array_ops.shape_n(tensors))[:, 0]
+        values = array_ops.concat(tensors, axis=zero)
       with ops.device(device):
         sizes = array_ops.unstack(sizes)
-        return list(array_ops.split(values, sizes, axis=0))
+        return list(array_ops.split(values, sizes, axis=zero))
 
 
 def _scheduled_stamp_resource_op_runner(batch, stamp):
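
The batch_ops_utils.py change above applies the same consts-once idea: a single shape_n op computes every tensor's leading dimension at once, and one int32 zero constant is reused as the axis for both concat and split instead of each call materializing its own Const. A minimal sketch, assuming TensorFlow 1.x graph mode (not code from this commit):

import tensorflow as tf

g = tf.Graph()
with g.as_default():
  tensors = [tf.constant([1.0, 2.0]),
             tf.constant([3.0]),
             tf.constant([4.0, 5.0, 6.0])]
  zero = tf.constant(0, dtype=tf.int32)        # created once, reused as axis below
  sizes = tf.stack(tf.shape_n(tensors))[:, 0]  # leading dim of every tensor in one op
  values = tf.concat(tensors, axis=zero)
  parts = tf.split(values, tf.unstack(sizes), axis=zero)

with tf.Session(graph=g) as sess:
  print(sess.run(sizes))                        # [2 1 3]
  print([p.tolist() for p in sess.run(parts)])  # [[1.0, 2.0], [3.0], [4.0, 5.0, 6.0]]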
index 50cc00a..19b6b32 100644 (file)
@@ -201,3 +201,6 @@ class QuantileAccumulator(saver.BaseSaverBuilder.SaveableObject):
         stamp_token=stamp_token,
         next_stamp_token=next_stamp_token)
     return result
+
+  def resource(self):
+    return self._quantile_accumulator_handle
index 08c1dcd..c725f32 100644 (file)
@@ -180,8 +180,7 @@ def extract_features(features, feature_columns, use_core_columns):
         elif isinstance(fc, feature_column_lib._EmbeddingColumn):
           # pylint: enable=protected-access
           transformed_features[fc.name] = fc_core.input_layer(
-              features, [fc],
-              weight_collections=[scope])
+              features, [fc], weight_collections=[scope])
         else:
           result = feature_column_ops.transform_features(features, [fc])
           if len(result) > 1:
@@ -334,10 +333,12 @@ class GradientBoostedDecisionTreeModel(object):
     self._feature_columns = feature_columns
     self._learner_config_serialized = learner_config.SerializeToString()
     self._attempted_trees = variables.Variable(
-        initial_value=array_ops.zeros([], dtypes.int64), trainable=False,
+        initial_value=array_ops.zeros([], dtypes.int64),
+        trainable=False,
         name="attempted_trees")
     self._finalized_trees = variables.Variable(
-        initial_value=array_ops.zeros([], dtypes.int64), trainable=False,
+        initial_value=array_ops.zeros([], dtypes.int64),
+        trainable=False,
         name="finalized_trees")
     if not features:
       raise ValueError("Features dictionary must be specified.")
@@ -354,9 +355,10 @@ class GradientBoostedDecisionTreeModel(object):
     self._sparse_int_indices = sparse_int_indices
     self._sparse_int_values = sparse_int_values
     self._sparse_int_shapes = sparse_int_shapes
-    self._reduce_dim = (self._learner_config.multi_class_strategy ==
-                        learner_pb2.LearnerConfig.TREE_PER_CLASS and
-                        learner_config.num_classes == 2)
+    self._reduce_dim = (
+        self._learner_config.multi_class_strategy ==
+        learner_pb2.LearnerConfig.TREE_PER_CLASS and
+        learner_config.num_classes == 2)
 
   def _predict_and_return_dict(self, ensemble_handle, ensemble_stamp, mode):
     """Runs prediction and returns a dictionary of the prediction results.
@@ -374,8 +376,8 @@ class GradientBoostedDecisionTreeModel(object):
     ensemble_stats = training_ops.tree_ensemble_stats(ensemble_handle,
                                                       ensemble_stamp)
     num_handlers = (
-        len(self._dense_floats) + len(self._sparse_float_shapes) +
-        len(self._sparse_int_shapes))
+        len(self._dense_floats) + len(self._sparse_float_shapes) + len(
+            self._sparse_int_shapes))
     # Used during feature selection.
     used_handlers = model_ops.tree_ensemble_used_handlers(
         ensemble_handle, ensemble_stamp, num_all_handlers=num_handlers)
@@ -432,8 +434,9 @@ class GradientBoostedDecisionTreeModel(object):
     # Use the current ensemble to predict on the current batch of input.
     # For faster prediction we check if the inputs are on the same device
     # as the model. If not, we create a copy of the model on the worker.
-    input_deps = (self._dense_floats + self._sparse_float_indices +
-                  self._sparse_int_indices)
+    input_deps = (
+        self._dense_floats + self._sparse_float_indices +
+        self._sparse_int_indices)
     if not input_deps:
       raise ValueError("No input tensors for prediction.")
 
@@ -500,8 +503,9 @@ class GradientBoostedDecisionTreeModel(object):
       ValueError: if inputs are not valid.
     """
     # Get the worker device from input dependencies.
-    input_deps = (self._dense_floats + self._sparse_float_indices +
-                  self._sparse_int_indices)
+    input_deps = (
+        self._dense_floats + self._sparse_float_indices +
+        self._sparse_int_indices)
     worker_device = input_deps[0].device
 
     # Get tensors relevant for training and form the loss.
@@ -517,7 +521,7 @@ class GradientBoostedDecisionTreeModel(object):
         aggregation_method=None)[0]
     strategy = self._learner_config.multi_class_strategy
 
-    class_id = -1
+    class_id = constant_op.constant(-1, dtype=dtypes.int32)
     # Handle different multiclass strategies.
     if strategy == learner_pb2.LearnerConfig.TREE_PER_CLASS:
       # We build one vs rest trees.
@@ -571,31 +575,39 @@ class GradientBoostedDecisionTreeModel(object):
     # Get the weights for each example for quantiles calculation,
     weights = self._get_weights(hessian_shape, squeezed_hessians)
 
-    regularization_config = self._learner_config.regularization
-    min_node_weight = self._learner_config.constraints.min_node_weight
     # Create all handlers ensuring resources are evenly allocated across PS.
     fc_name_idx = 0
     handlers = []
     init_stamp_token = constant_op.constant(0, dtype=dtypes.int64)
+    l1_regularization = constant_op.constant(
+        self._learner_config.regularization.l1, dtypes.float32)
+    l2_regularization = constant_op.constant(
+        self._learner_config.regularization.l2, dtypes.float32)
+    tree_complexity_regularization = constant_op.constant(
+        self._learner_config.regularization.tree_complexity, dtypes.float32)
+    min_node_weight = constant_op.constant(
+        self._learner_config.constraints.min_node_weight, dtypes.float32)
+    epsilon = 0.01
+    num_quantiles = 100
+    strategy_tensor = constant_op.constant(strategy)
     with ops.device(self._get_replica_device_setter(worker_device)):
       # Create handlers for dense float columns
       for dense_float_column_idx in range(len(self._dense_floats)):
         fc_name = self._fc_names[fc_name_idx]
         handlers.append(
             ordinal_split_handler.DenseSplitHandler(
-                l1_regularization=regularization_config.l1,
-                l2_regularization=regularization_config.l2,
-                tree_complexity_regularization=(
-                    regularization_config.tree_complexity),
+                l1_regularization=l1_regularization,
+                l2_regularization=l2_regularization,
+                tree_complexity_regularization=tree_complexity_regularization,
                 min_node_weight=min_node_weight,
                 feature_column_group_id=dense_float_column_idx,
-                epsilon=0.01,
-                num_quantiles=100,
+                epsilon=epsilon,
+                num_quantiles=num_quantiles,
                 dense_float_column=self._dense_floats[dense_float_column_idx],
                 name=fc_name,
                 gradient_shape=gradient_shape,
                 hessian_shape=hessian_shape,
-                multiclass_strategy=strategy,
+                multiclass_strategy=strategy_tensor,
                 init_stamp_token=init_stamp_token))
         fc_name_idx += 1
 
@@ -604,14 +616,13 @@ class GradientBoostedDecisionTreeModel(object):
         fc_name = self._fc_names[fc_name_idx]
         handlers.append(
             ordinal_split_handler.SparseSplitHandler(
-                l1_regularization=regularization_config.l1,
-                l2_regularization=regularization_config.l2,
-                tree_complexity_regularization=(
-                    regularization_config.tree_complexity),
+                l1_regularization=l1_regularization,
+                l2_regularization=l2_regularization,
+                tree_complexity_regularization=tree_complexity_regularization,
                 min_node_weight=min_node_weight,
                 feature_column_group_id=sparse_float_column_idx,
-                epsilon=0.01,
-                num_quantiles=100,
+                epsilon=epsilon,
+                num_quantiles=num_quantiles,
                 sparse_float_column=sparse_tensor.SparseTensor(
                     self._sparse_float_indices[sparse_float_column_idx],
                     self._sparse_float_values[sparse_float_column_idx],
@@ -619,7 +630,7 @@ class GradientBoostedDecisionTreeModel(object):
                 name=fc_name,
                 gradient_shape=gradient_shape,
                 hessian_shape=hessian_shape,
-                multiclass_strategy=strategy,
+                multiclass_strategy=strategy_tensor,
                 init_stamp_token=init_stamp_token))
         fc_name_idx += 1
 
@@ -628,10 +639,9 @@ class GradientBoostedDecisionTreeModel(object):
         fc_name = self._fc_names[fc_name_idx]
         handlers.append(
             categorical_split_handler.EqualitySplitHandler(
-                l1_regularization=regularization_config.l1,
-                l2_regularization=regularization_config.l2,
-                tree_complexity_regularization=(
-                    regularization_config.tree_complexity),
+                l1_regularization=l1_regularization,
+                l2_regularization=l2_regularization,
+                tree_complexity_regularization=tree_complexity_regularization,
                 min_node_weight=min_node_weight,
                 feature_column_group_id=sparse_int_column_idx,
                 sparse_int_column=sparse_tensor.SparseTensor(
@@ -641,7 +651,7 @@ class GradientBoostedDecisionTreeModel(object):
                 name=fc_name,
                 gradient_shape=gradient_shape,
                 hessian_shape=hessian_shape,
-                multiclass_strategy=strategy,
+                multiclass_strategy=strategy_tensor,
                 init_stamp_token=init_stamp_token))
         fc_name_idx += 1
 
@@ -694,11 +704,11 @@ class GradientBoostedDecisionTreeModel(object):
         name="continue_centering",
         trainable=False)
     stats_update_ops.append(
-        control_flow_ops.cond(continue_centering,
-                              self._make_update_bias_stats_fn(
-                                  ensemble_stamp, predictions, gradients,
-                                  bias_stats_accumulator),
-                              control_flow_ops.no_op))
+        control_flow_ops.cond(
+            continue_centering,
+            self._make_update_bias_stats_fn(ensemble_stamp, predictions,
+                                            gradients, bias_stats_accumulator),
+            control_flow_ops.no_op))
 
     # Update handler stats.
     handler_reads = collections.OrderedDict()
@@ -720,8 +730,8 @@ class GradientBoostedDecisionTreeModel(object):
           shape=[len(handlers)], seed=[seed + 1, 1])
       active_handlers = array_ops.stack(
           [active_handlers_current_layer, active_handlers_next_layer], axis=1)
-      active_handlers = (active_handlers <
-                         self._learner_config.feature_fraction_per_level)
+      active_handlers = (
+          active_handlers < self._learner_config.feature_fraction_per_level)
     elif subsampling_type == "feature_fraction_per_tree":
       seed = predictions_dict[NUM_TREES_ATTEMPTED]
       active_handlers_current_layer = stateless.stateless_random_uniform(
@@ -729,9 +739,12 @@ class GradientBoostedDecisionTreeModel(object):
       active_handlers_current_layer = (
           active_handlers_current_layer <
           self._learner_config.feature_fraction_per_tree)
-      active_handlers = array_ops.stack([
-          active_handlers_current_layer,
-          array_ops.ones([len(handlers)], dtype=dtypes.bool)], axis=1)
+      active_handlers = array_ops.stack(
+          [
+              active_handlers_current_layer,
+              array_ops.ones([len(handlers)], dtype=dtypes.bool)
+          ],
+          axis=1)
     else:
       active_handlers = array_ops.ones([len(handlers), 2], dtype=dtypes.bool)
 
@@ -760,6 +773,7 @@ class GradientBoostedDecisionTreeModel(object):
     empty_hessians = constant_op.constant(
         [], dtype=dtypes.float32, shape=empty_hess_shape)
 
+    active_handlers = array_ops.unstack(active_handlers, axis=0)
     for handler_idx in range(len(handlers)):
       handler = handlers[handler_idx]
       is_active = active_handlers[handler_idx]
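
A note on the unstack added above: indexing a tensor from inside a Python loop emits one StridedSlice op per iteration, whereas a single array_ops.unstack emits one Unpack op and hands back a plain Python list, so the per-handler indexing in the loop costs nothing at graph-construction time. A minimal sketch of the difference (illustrative only; op_types is a made-up helper over the public graph API):

import collections
import tensorflow.compat.v1 as tf

def op_types(graph):
  # Count graph nodes by op type.
  return collections.Counter(op.type for op in graph.get_operations())

with tf.Graph().as_default() as g_slice:
  active = tf.ones([8, 2])
  rows = [active[i] for i in range(8)]        # one StridedSlice per iteration
  print(op_types(g_slice)["StridedSlice"])    # 8

with tf.Graph().as_default() as g_unstack:
  active = tf.ones([8, 2])
  rows = tf.unstack(active, axis=0)           # a single Unpack op, list of 8 tensors
  print(op_types(g_unstack)["Unpack"])        # 1
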
@@ -971,7 +985,7 @@ class GradientBoostedDecisionTreeModel(object):
       # This is a workaround for the slowness of graph building in tf.cond.
       # See (b/36554864).
       split_sizes = array_ops.reshape(
-          array_ops.shape_n(partition_ids_list), [-1])
+          array_ops.shape_n(partition_ids_list), [len(partition_ids_list)])
       partition_ids = array_ops.concat(partition_ids_list, axis=0)
       gains = array_ops.concat(gains_list, axis=0)
       split_infos = array_ops.concat(split_info_list, axis=0)
@@ -1036,8 +1050,11 @@ class GradientBoostedDecisionTreeModel(object):
 
       # Update ensemble.
       update_ops = [are_all_splits_ready]
-      update_model = control_flow_ops.cond(continue_centering, _center_bias_fn,
-                                           _grow_ensemble_fn)
+      if self._center_bias:
+        update_model = control_flow_ops.cond(continue_centering,
+                                             _center_bias_fn, _grow_ensemble_fn)
+      else:
+        update_model = _grow_ensemble_fn()
       update_ops.append(update_model)
 
       # Update ensemble stats.
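
The new guard on self._center_bias is what drops the bias-centering graph when centering is turned off: control_flow_ops.cond traces both branch functions at graph-construction time, so without the guard the centering subgraph would be built even when it can never run, and calling _grow_ensemble_fn() directly avoids that work. A stripped-down sketch of the pattern, with hypothetical stand-in branch builders (build_center_bias and build_grow_ensemble are not names from this patch):

import tensorflow.compat.v1 as tf

def build_center_bias():
  # Stand-in for the (potentially large) bias-centering subgraph.
  return tf.constant(True)

def build_grow_ensemble():
  # Stand-in for the tree-growing subgraph.
  return tf.constant(False)

def update_model(center_bias_enabled, continue_centering):
  if center_bias_enabled:
    # Python-level flag: build the cond (and both branch subgraphs) only if needed.
    return tf.cond(continue_centering, build_center_bias, build_grow_ensemble)
  # Centering is off: never construct the centering branch at all.
  return build_grow_ensemble()

with tf.Graph().as_default():
  update_model(center_bias_enabled=False,
               continue_centering=tf.constant(True))
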