Remove proto import in header files for core/kernels/boosted_trees.
authorYifei Feng <yifeif@google.com>
Thu, 19 Apr 2018 08:26:07 +0000 (01:26 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Thu, 19 Apr 2018 08:29:24 +0000 (01:29 -0700)
Move implementations that requires declaration of TreeEnsemble to .cc files.

The goal is to make kernels mostly independent of proto headers, which will let us lock down our .so import

PiperOrigin-RevId: 193478404

tensorflow/core/kernels/boosted_trees/resources.cc
tensorflow/core/kernels/boosted_trees/resources.h

index 2ea12c5..c410748 100644 (file)
@@ -21,6 +21,35 @@ limitations under the License.
 
 namespace tensorflow {
 
+// Constructor.
+BoostedTreesEnsembleResource::BoostedTreesEnsembleResource()
+    : tree_ensemble_(
+          protobuf::Arena::CreateMessage<boosted_trees::TreeEnsemble>(
+              &arena_)) {}
+
+string BoostedTreesEnsembleResource::DebugString() {
+  return strings::StrCat("TreeEnsemble[size=", tree_ensemble_->trees_size(),
+                         "]");
+}
+
+bool BoostedTreesEnsembleResource::InitFromSerialized(const string& serialized,
+                                                      const int64 stamp_token) {
+  CHECK_EQ(stamp(), -1) << "Must Reset before Init.";
+  if (ParseProtoUnlimited(tree_ensemble_, serialized)) {
+    set_stamp(stamp_token);
+    return true;
+  }
+  return false;
+}
+
+string BoostedTreesEnsembleResource::SerializeAsString() const {
+  return tree_ensemble_->SerializeAsString();
+}
+
+int32 BoostedTreesEnsembleResource::num_trees() const {
+  return tree_ensemble_->trees_size();
+}
+
 int32 BoostedTreesEnsembleResource::next_node(
     const int32 tree_id, const int32 node_id, const int32 index_in_batch,
     const std::vector<TTypes<int32>::ConstVec>& bucketized_features) const {
@@ -49,6 +78,115 @@ float BoostedTreesEnsembleResource::node_value(const int32 tree_id,
   }
 }
 
+int32 BoostedTreesEnsembleResource::GetNumLayersGrown(
+    const int32 tree_id) const {
+  DCHECK_LT(tree_id, tree_ensemble_->trees_size());
+  return tree_ensemble_->tree_metadata(tree_id).num_layers_grown();
+}
+
+void BoostedTreesEnsembleResource::SetNumLayersGrown(
+    const int32 tree_id, int32 new_num_layers) const {
+  DCHECK_LT(tree_id, tree_ensemble_->trees_size());
+  tree_ensemble_->mutable_tree_metadata(tree_id)->set_num_layers_grown(
+      new_num_layers);
+}
+
+void BoostedTreesEnsembleResource::UpdateLastLayerNodesRange(
+    const int32 node_range_start, int32 node_range_end) const {
+  tree_ensemble_->mutable_growing_metadata()->set_last_layer_node_start(
+      node_range_start);
+  tree_ensemble_->mutable_growing_metadata()->set_last_layer_node_end(
+      node_range_end);
+}
+
+void BoostedTreesEnsembleResource::GetLastLayerNodesRange(
+    int32* node_range_start, int32* node_range_end) const {
+  *node_range_start =
+      tree_ensemble_->growing_metadata().last_layer_node_start();
+  *node_range_end = tree_ensemble_->growing_metadata().last_layer_node_end();
+}
+
+int64 BoostedTreesEnsembleResource::GetNumNodes(const int32 tree_id) {
+  DCHECK_LT(tree_id, tree_ensemble_->trees_size());
+  return tree_ensemble_->trees(tree_id).nodes_size();
+}
+
+int32 BoostedTreesEnsembleResource::GetNumLayersAttempted() {
+  return tree_ensemble_->growing_metadata().num_layers_attempted();
+}
+
+bool BoostedTreesEnsembleResource::is_leaf(const int32 tree_id,
+                                           const int32 node_id) const {
+  DCHECK_LT(tree_id, tree_ensemble_->trees_size());
+  DCHECK_LT(node_id, tree_ensemble_->trees(tree_id).nodes_size());
+  const auto& node = tree_ensemble_->trees(tree_id).nodes(node_id);
+  return node.node_case() == boosted_trees::Node::kLeaf;
+}
+
+int32 BoostedTreesEnsembleResource::feature_id(const int32 tree_id,
+                                               const int32 node_id) const {
+  const auto node = tree_ensemble_->trees(tree_id).nodes(node_id);
+  DCHECK_EQ(node.node_case(), boosted_trees::Node::kBucketizedSplit);
+  return node.bucketized_split().feature_id();
+}
+
+int32 BoostedTreesEnsembleResource::bucket_threshold(
+    const int32 tree_id, const int32 node_id) const {
+  const auto node = tree_ensemble_->trees(tree_id).nodes(node_id);
+  DCHECK_EQ(node.node_case(), boosted_trees::Node::kBucketizedSplit);
+  return node.bucketized_split().threshold();
+}
+
+int32 BoostedTreesEnsembleResource::left_id(const int32 tree_id,
+                                            const int32 node_id) const {
+  const auto node = tree_ensemble_->trees(tree_id).nodes(node_id);
+  DCHECK_EQ(node.node_case(), boosted_trees::Node::kBucketizedSplit);
+  return node.bucketized_split().left_id();
+}
+
+int32 BoostedTreesEnsembleResource::right_id(const int32 tree_id,
+                                             const int32 node_id) const {
+  const auto node = tree_ensemble_->trees(tree_id).nodes(node_id);
+  DCHECK_EQ(node.node_case(), boosted_trees::Node::kBucketizedSplit);
+  return node.bucketized_split().right_id();
+}
+
+std::vector<float> BoostedTreesEnsembleResource::GetTreeWeights() const {
+  return {tree_ensemble_->tree_weights().begin(),
+          tree_ensemble_->tree_weights().end()};
+}
+
+float BoostedTreesEnsembleResource::GetTreeWeight(const int32 tree_id) const {
+  return tree_ensemble_->tree_weights(tree_id);
+}
+
+float BoostedTreesEnsembleResource::IsTreeFinalized(const int32 tree_id) const {
+  DCHECK_LT(tree_id, tree_ensemble_->trees_size());
+  return tree_ensemble_->tree_metadata(tree_id).is_finalized();
+}
+
+float BoostedTreesEnsembleResource::IsTreePostPruned(
+    const int32 tree_id) const {
+  DCHECK_LT(tree_id, tree_ensemble_->trees_size());
+  return tree_ensemble_->tree_metadata(tree_id).post_pruned_nodes_meta_size() >
+         0;
+}
+
+void BoostedTreesEnsembleResource::SetIsFinalized(const int32 tree_id,
+                                                  const bool is_finalized) {
+  DCHECK_LT(tree_id, tree_ensemble_->trees_size());
+  return tree_ensemble_->mutable_tree_metadata(tree_id)->set_is_finalized(
+      is_finalized);
+}
+
+// Sets the weight of i'th tree.
+void BoostedTreesEnsembleResource::SetTreeWeight(const int32 tree_id,
+                                                 const float weight) {
+  DCHECK_GE(tree_id, 0);
+  DCHECK_LT(tree_id, num_trees());
+  tree_ensemble_->set_tree_weights(tree_id, weight);
+}
+
 void BoostedTreesEnsembleResource::UpdateGrowingMetadata() const {
   tree_ensemble_->mutable_growing_metadata()->set_num_layers_attempted(
       tree_ensemble_->growing_metadata().num_layers_attempted() + 1);
index 561ca3a..df78d3f 100644 (file)
@@ -17,12 +17,16 @@ limitations under the License.
 #define TENSORFLOW_CORE_KERNELS_BOOSTED_TREES_RESOURCES_H_
 
 #include "tensorflow/core/framework/resource_mgr.h"
-#include "tensorflow/core/kernels/boosted_trees/boosted_trees.pb.h"
 #include "tensorflow/core/platform/mutex.h"
 #include "tensorflow/core/platform/protobuf.h"
 
 namespace tensorflow {
 
+// Forward declaration for proto class TreeEnsemble
+namespace boosted_trees {
+class TreeEnsemble;
+}  // namespace boosted_trees
+
 // A StampedResource is a resource that has a stamp token associated with it.
 // Before reading from or applying updates to the resource, the stamp should
 // be checked to verify that the update is not stale.
@@ -42,31 +46,15 @@ class StampedResource : public ResourceBase {
 // Keep a tree ensemble in memory for efficient evaluation and mutation.
 class BoostedTreesEnsembleResource : public StampedResource {
  public:
-  // Constructor.
-  BoostedTreesEnsembleResource()
-      : tree_ensemble_(
-            protobuf::Arena::CreateMessage<boosted_trees::TreeEnsemble>(
-                &arena_)) {}
-
-  string DebugString() override {
-    return strings::StrCat("TreeEnsemble[size=", tree_ensemble_->trees_size(),
-                           "]");
-  }
-
-  bool InitFromSerialized(const string& serialized, const int64 stamp_token) {
-    CHECK_EQ(stamp(), -1) << "Must Reset before Init.";
-    if (ParseProtoUnlimited(tree_ensemble_, serialized)) {
-      set_stamp(stamp_token);
-      return true;
-    }
-    return false;
-  }
-
-  string SerializeAsString() const {
-    return tree_ensemble_->SerializeAsString();
-  }
-
-  int32 num_trees() const { return tree_ensemble_->trees_size(); }
+  BoostedTreesEnsembleResource();
+
+  string DebugString() override;
+
+  bool InitFromSerialized(const string& serialized, const int64 stamp_token);
+
+  string SerializeAsString() const;
+
+  int32 num_trees() const;
 
   // Find the next node to which the example (specified by index_in_batch)
   // traverses down from the current node indicated by tree_id and node_id.
@@ -82,73 +70,31 @@ class BoostedTreesEnsembleResource : public StampedResource {
 
   float node_value(const int32 tree_id, const int32 node_id) const;
 
-  int32 GetNumLayersGrown(const int32 tree_id) const {
-    DCHECK_LT(tree_id, tree_ensemble_->trees_size());
-    return tree_ensemble_->tree_metadata(tree_id).num_layers_grown();
-  }
+  int32 GetNumLayersGrown(const int32 tree_id) const;
 
-  void SetNumLayersGrown(const int32 tree_id, int32 new_num_layers) const {
-    DCHECK_LT(tree_id, tree_ensemble_->trees_size());
-    tree_ensemble_->mutable_tree_metadata(tree_id)->set_num_layers_grown(
-        new_num_layers);
-  }
+  void SetNumLayersGrown(const int32 tree_id, int32 new_num_layers) const;
 
   void UpdateLastLayerNodesRange(const int32 node_range_start,
-                                 int32 node_range_end) const {
-    tree_ensemble_->mutable_growing_metadata()->set_last_layer_node_start(
-        node_range_start);
-    tree_ensemble_->mutable_growing_metadata()->set_last_layer_node_end(
-        node_range_end);
-  }
+                                 int32 node_range_end) const;
 
   void GetLastLayerNodesRange(int32* node_range_start,
-                              int32* node_range_end) const {
-    *node_range_start =
-        tree_ensemble_->growing_metadata().last_layer_node_start();
-    *node_range_end = tree_ensemble_->growing_metadata().last_layer_node_end();
-  }
+                              int32* node_range_end) const;
 
-  int64 GetNumNodes(const int32 tree_id) {
-    DCHECK_LT(tree_id, tree_ensemble_->trees_size());
-    return tree_ensemble_->trees(tree_id).nodes_size();
-  }
+  int64 GetNumNodes(const int32 tree_id);
 
   void UpdateGrowingMetadata() const;
 
-  int32 GetNumLayersAttempted() {
-    return tree_ensemble_->growing_metadata().num_layers_attempted();
-  }
-
-  bool is_leaf(const int32 tree_id, const int32 node_id) const {
-    DCHECK_LT(tree_id, tree_ensemble_->trees_size());
-    DCHECK_LT(node_id, tree_ensemble_->trees(tree_id).nodes_size());
-    const auto& node = tree_ensemble_->trees(tree_id).nodes(node_id);
-    return node.node_case() == boosted_trees::Node::kLeaf;
-  }
-
-  int32 feature_id(const int32 tree_id, const int32 node_id) const {
-    const auto node = tree_ensemble_->trees(tree_id).nodes(node_id);
-    DCHECK_EQ(node.node_case(), boosted_trees::Node::kBucketizedSplit);
-    return node.bucketized_split().feature_id();
-  }
-
-  int32 bucket_threshold(const int32 tree_id, const int32 node_id) const {
-    const auto node = tree_ensemble_->trees(tree_id).nodes(node_id);
-    DCHECK_EQ(node.node_case(), boosted_trees::Node::kBucketizedSplit);
-    return node.bucketized_split().threshold();
-  }
-
-  int32 left_id(const int32 tree_id, const int32 node_id) const {
-    const auto node = tree_ensemble_->trees(tree_id).nodes(node_id);
-    DCHECK_EQ(node.node_case(), boosted_trees::Node::kBucketizedSplit);
-    return node.bucketized_split().left_id();
-  }
-
-  int32 right_id(const int32 tree_id, const int32 node_id) const {
-    const auto node = tree_ensemble_->trees(tree_id).nodes(node_id);
-    DCHECK_EQ(node.node_case(), boosted_trees::Node::kBucketizedSplit);
-    return node.bucketized_split().right_id();
-  }
+  int32 GetNumLayersAttempted();
+
+  bool is_leaf(const int32 tree_id, const int32 node_id) const;
+
+  int32 feature_id(const int32 tree_id, const int32 node_id) const;
+
+  int32 bucket_threshold(const int32 tree_id, const int32 node_id) const;
+
+  int32 left_id(const int32 tree_id, const int32 node_id) const;
+
+  int32 right_id(const int32 tree_id, const int32 node_id) const;
 
   // Add a tree to the ensemble and returns a new tree_id.
   int32 AddNewTree(const float weight);
@@ -163,38 +109,18 @@ class BoostedTreesEnsembleResource : public StampedResource {
   // Retrieves tree weights and returns as a vector.
   // It involves a copy, so should be called only sparingly (like once per
   // iteration, not per example).
-  std::vector<float> GetTreeWeights() const {
-    return {tree_ensemble_->tree_weights().begin(),
-            tree_ensemble_->tree_weights().end()};
-  }
-
-  float GetTreeWeight(const int32 tree_id) const {
-    return tree_ensemble_->tree_weights(tree_id);
-  }
-
-  float IsTreeFinalized(const int32 tree_id) const {
-    DCHECK_LT(tree_id, tree_ensemble_->trees_size());
-    return tree_ensemble_->tree_metadata(tree_id).is_finalized();
-  }
-
-  float IsTreePostPruned(const int32 tree_id) const {
-    DCHECK_LT(tree_id, tree_ensemble_->trees_size());
-    return tree_ensemble_->tree_metadata(tree_id)
-               .post_pruned_nodes_meta_size() > 0;
-  }
-
-  void SetIsFinalized(const int32 tree_id, const bool is_finalized) {
-    DCHECK_LT(tree_id, tree_ensemble_->trees_size());
-    return tree_ensemble_->mutable_tree_metadata(tree_id)->set_is_finalized(
-        is_finalized);
-  }
+  std::vector<float> GetTreeWeights() const;
+
+  float GetTreeWeight(const int32 tree_id) const;
+
+  float IsTreeFinalized(const int32 tree_id) const;
+
+  float IsTreePostPruned(const int32 tree_id) const;
+
+  void SetIsFinalized(const int32 tree_id, const bool is_finalized);
 
   // Sets the weight of i'th tree.
-  void SetTreeWeight(const int32 tree_id, const float weight) {
-    DCHECK_GE(tree_id, 0);
-    DCHECK_LT(tree_id, num_trees());
-    tree_ensemble_->set_tree_weights(tree_id, weight);
-  }
+  void SetTreeWeight(const int32 tree_id, const float weight);
 
   // Resets the resource and frees the protos in arena.
   // Caller needs to hold the mutex lock while calling this.