Remove proto imports in header files for core/kernels/hexagon.
authorYifei Feng <yifeif@google.com>
Tue, 17 Apr 2018 01:41:28 +0000 (18:41 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Tue, 17 Apr 2018 01:44:28 +0000 (18:44 -0700)
The goal is to make kernels mostly independent of proto headers, which will let us lock down our .so imports.

PiperOrigin-RevId: 193134710

tensorflow/core/framework/graph_transfer_info.proto
tensorflow/core/kernels/hexagon/BUILD
tensorflow/core/kernels/hexagon/graph_transfer_utils.cc
tensorflow/core/kernels/hexagon/graph_transfer_utils.h
tensorflow/core/kernels/hexagon/graph_transferer.cc
tensorflow/core/kernels/hexagon/graph_transferer.h
tensorflow/core/kernels/hexagon/graph_transferer_test.cc
tensorflow/core/kernels/hexagon/hexagon_control_wrapper.cc
tensorflow/core/kernels/hexagon/hexagon_control_wrapper.h
tensorflow/core/kernels/hexagon/hexagon_graph_execution_test.cc

index 016259d..41dd54d 100644 (file)
@@ -8,6 +8,46 @@ option java_package = "org.tensorflow.framework";
 
 import "tensorflow/core/framework/types.proto";
 
+message GraphTransferNodeInput {
+  int32 node_id = 1;
+  int32 output_port = 2;
+}
+message GraphTransferNodeInfo {
+  string name = 1;
+  int32 node_id = 2;
+  string type_name = 3;
+  int32 soc_op_id = 4;
+  int32 padding_id = 5;
+  int32 input_count = 6;
+  int32 output_count = 7;
+};
+message GraphTransferConstNodeInfo {
+  string name = 1;
+  int32 node_id = 2;
+  repeated int64 shape = 3;
+  bytes data = 4;
+  DataType dtype = 5;
+};
+message GraphTransferNodeInputInfo {
+  int32 node_id = 1;
+  repeated GraphTransferNodeInput node_input = 2;
+};
+message GraphTransferNodeOutputInfo {
+  int32 node_id = 1;
+  repeated int32 max_byte_size = 2;
+};
+message GraphTransferGraphInputNodeInfo {
+  string name = 1;
+  repeated int64 shape = 2;
+  DataType dtype = 3;
+}
+
+message GraphTransferGraphOutputNodeInfo {
+  string name = 1;
+  repeated int64 shape = 2;
+  DataType dtype = 3;
+}
+
 // Protocol buffer representing a handle to a tensorflow resource. Handles are
 // not valid across executions, but can be serialized back and forth from within
 // a single run.
@@ -16,53 +56,14 @@ message GraphTransferInfo {
     NOP = 0;
     HEXAGON = 1;
   }
-  message NodeInput {
-    int32 node_id = 1;
-    int32 output_port = 2;
-  }
-  message NodeInfo {
-    string name = 1;
-    int32 node_id = 2;
-    string type_name = 3;
-    int32 soc_op_id = 4;
-    int32 padding_id = 5;
-    int32 input_count = 6;
-    int32 output_count = 7;
-  };
-  message ConstNodeInfo {
-    string name = 1;
-    int32 node_id = 2;
-    repeated int64 shape = 3;
-    bytes data = 4;
-    DataType dtype = 5;
-  };
-  message NodeInputInfo {
-    int32 node_id = 1;
-    repeated NodeInput node_input = 2;
-  };
-  message NodeOutputInfo {
-    int32 node_id = 1;
-    repeated int32 max_byte_size = 2;
-  };
-  message GraphInputNodeInfo {
-    string name = 1;
-    repeated int64 shape = 2;
-    DataType dtype = 3;
-  }
-
-  message GraphOutputNodeInfo {
-    string name = 1;
-    repeated int64 shape = 2;
-    DataType dtype = 3;
-  }
 
-  repeated NodeInfo node_info = 1;
-  repeated ConstNodeInfo const_node_info = 2;
-  repeated NodeInputInfo node_input_info = 3;
-  repeated NodeOutputInfo node_output_info = 4;
+  repeated GraphTransferNodeInfo node_info = 1;
+  repeated GraphTransferConstNodeInfo const_node_info = 2;
+  repeated GraphTransferNodeInputInfo node_input_info = 3;
+  repeated GraphTransferNodeOutputInfo node_output_info = 4;
   // Input Node parameters of transferred graph
-  repeated GraphInputNodeInfo graph_input_node_info = 5;
-  repeated GraphOutputNodeInfo graph_output_node_info = 6;
+  repeated GraphTransferGraphInputNodeInfo graph_input_node_info = 5;
+  repeated GraphTransferGraphOutputNodeInfo graph_output_node_info = 6;
   // Destination of graph transfer
   Destination destination = 7;
 };
index 4870d9a..66aeec5 100644 (file)
@@ -70,6 +70,7 @@ tf_kernel_library(
         "//tensorflow/core:protos_all_cc",
         "//tensorflow/core/kernels:remote_fused_graph_execute_utils",
         "//third_party/eigen3",
+        "@com_google_absl//absl/memory",
     ],
 )
 
index 4040bf5..40bf5a4 100644 (file)
@@ -14,6 +14,8 @@ limitations under the License.
 ==============================================================================*/
 
 #include "tensorflow/core/kernels/hexagon/graph_transfer_utils.h"
+#include "tensorflow/core/framework/graph.pb.h"
+#include "tensorflow/core/framework/remote_fused_graph_execute_info.pb.h"
 
 #include "tensorflow/cc/framework/scope.h"
 #include "tensorflow/cc/ops/const_op.h"
index 352d548..ada96ae 100644 (file)
@@ -20,14 +20,14 @@ limitations under the License.
 #include <utility>
 #include <vector>
 
-#include "tensorflow/core/framework/graph.pb.h"
-#include "tensorflow/core/framework/remote_fused_graph_execute_info.pb.h"
 #include "tensorflow/core/framework/types.h"
 #include "tensorflow/core/kernels/hexagon/graph_transferer.h"
 #include "tensorflow/core/platform/macros.h"
 
 namespace tensorflow {
 
+class RemoteFusedGraphExecuteInfo;
+
 class GraphTransferUtils {
  public:
   static std::priority_queue<std::tuple<float, int, string>>
index 0963dff..7960cb4 100644 (file)
@@ -18,6 +18,8 @@ limitations under the License.
 #include <algorithm>
 #include <cinttypes>
 
+#include "tensorflow/core/framework/graph.pb.h"
+#include "tensorflow/core/framework/graph_transfer_info.pb.h"
 #include "tensorflow/core/framework/op.h"
 #include "tensorflow/core/graph/algorithm.h"
 #include "tensorflow/core/graph/graph_constructor.h"
@@ -73,6 +75,12 @@ static Node* FindMutableNodeByName(const string& name, Graph* graph) {
   return nullptr;
 }
 
+GraphTransferer::GraphTransferer() {
+  graph_transfer_info_ = new GraphTransferInfo();
+}
+
+GraphTransferer::~GraphTransferer() { delete graph_transfer_info_; }
+
 /**
  * graph loading functions
  * - LoadGraphFromProto
@@ -142,8 +150,8 @@ Status GraphTransferer::LoadGraphFromProto(
 
   for (const std::pair<string, Tensor>& input_node_info :
        input_node_info_list) {
-    GraphTransferInfo::GraphInputNodeInfo& graph_input_node_info =
-        *graph_transfer_info_.add_graph_input_node_info();
+    GraphTransferGraphInputNodeInfo& graph_input_node_info =
+        *graph_transfer_info_->add_graph_input_node_info();
     graph_input_node_info.set_name(input_node_info.first);
     graph_input_node_info.set_dtype(input_node_info.second.dtype());
     for (const int64 dim : ToTensorShapeArray(input_node_info.second.shape())) {
@@ -159,8 +167,8 @@ Status GraphTransferer::LoadGraphFromProto(
     const Node* node = node_name_cache_list_.at(node_id);
     CHECK_NOTNULL(node);
 
-    GraphTransferInfo::GraphOutputNodeInfo& graph_output_node_info =
-        *graph_transfer_info_.add_graph_output_node_info();
+    GraphTransferGraphOutputNodeInfo& graph_output_node_info =
+        *graph_transfer_info_->add_graph_output_node_info();
     graph_output_node_info.set_name(strings::StrCat(node_name, ":", port));
 
     // Get output tensor shape type
@@ -231,17 +239,17 @@ Status GraphTransferer::LoadGraphFromProtoFile(
 
 void GraphTransferer::SortParams(const std::vector<string>& output_node_names) {
   // TODO(satok): optimize complexity
-  std::unordered_map<int, GraphTransferInfo::NodeInputInfo*> input_map;
-  for (GraphTransferInfo::NodeInputInfo& input :
-       *graph_transfer_info_.mutable_node_input_info()) {
+  std::unordered_map<int, GraphTransferNodeInputInfo*> input_map;
+  for (GraphTransferNodeInputInfo& input :
+       *graph_transfer_info_->mutable_node_input_info()) {
     input_map.emplace(input.node_id(), &input);
   }
 
   // Setup dependency map placeholder
   std::vector<int> output_node_ids;
   std::unordered_map<int, std::unordered_set<int>> dependency_map;
-  for (const GraphTransferInfo::NodeInfo& params :
-       graph_transfer_info_.node_info()) {
+  for (const GraphTransferNodeInfo& params :
+       graph_transfer_info_->node_info()) {
     const int node_id = params.node_id();
     for (const string& output_node_name : output_node_names) {
       if (params.name() == output_node_name) {
@@ -255,7 +263,7 @@ void GraphTransferer::SortParams(const std::vector<string>& output_node_names) {
       continue;
     }
     CHECK_EQ(input_map.count(node_id), 1);
-    for (const GraphTransferInfo::NodeInput& node_input :
+    for (const GraphTransferNodeInput& node_input :
          input_map.at(node_id)->node_input()) {
       dependency_map.at(node_id).emplace(node_input.node_id());
     }
@@ -267,8 +275,8 @@ void GraphTransferer::SortParams(const std::vector<string>& output_node_names) {
     FillDependencyRec(output_node_id, dependency_map, completed);
   }
 
-  std::sort(graph_transfer_info_.mutable_node_info()->begin(),
-            graph_transfer_info_.mutable_node_info()->end(),
+  std::sort(graph_transfer_info_->mutable_node_info()->begin(),
+            graph_transfer_info_->mutable_node_info()->end(),
             TransferParamsComparator(dependency_map));
 }
 
@@ -278,15 +286,15 @@ void GraphTransferer::EnableStrictCheckMode(const bool enable) {
 
 void GraphTransferer::SetSerializedGraphTransferInfo(
     const string& serialized_proto) {
-  graph_transfer_info_.ParseFromString(serialized_proto);
+  graph_transfer_info_->ParseFromString(serialized_proto);
 }
 
 const GraphTransferInfo& GraphTransferer::GetGraphTransferInfo() const {
-  return graph_transfer_info_;
+  return *graph_transfer_info_;
 }
 
 GraphTransferInfo& GraphTransferer::GetMutableGraphTransferInfo() {
-  return graph_transfer_info_;
+  return *graph_transfer_info_;
 }
 
 void GraphTransferer::CacheNode(const Node& node) {
@@ -473,8 +481,8 @@ void GraphTransferer::RegisterConstantNode(const ShapeRefiner& shape_refiner,
   data_size = max_bytes_per_data * num_output_elements;
   shape_array = BuildShapeArray(shape_handle, context);
 
-  GraphTransferInfo::ConstNodeInfo& const_node_info =
-      *graph_transfer_info_.add_const_node_info();
+  GraphTransferConstNodeInfo& const_node_info =
+      *graph_transfer_info_->add_const_node_info();
   const_node_info.set_name(node.name());
   const_node_info.set_node_id(id);
   // TODO(satok): Make this generic. Never assume rank is 4.
@@ -505,8 +513,8 @@ int GraphTransferer::RegisterConstantShape(const std::vector<int>& shape) {
     node_name_cache_list_.emplace_back(nullptr);
     const int id = node_name_cache_list_.size() - 1;
     node_name_to_id_cache_map_.emplace(shape_name, id);
-    GraphTransferInfo::ConstNodeInfo& const_node_info =
-        *graph_transfer_info_.add_const_node_info();
+    GraphTransferConstNodeInfo& const_node_info =
+        *graph_transfer_info_->add_const_node_info();
     const_node_info.set_name(shape_name);
     const_node_info.set_node_id(id);
     // TODO(satok): Make this generic. Never assume rank is 5.
@@ -528,8 +536,8 @@ int GraphTransferer::RegisterConstTensor(const Tensor& tensor,
     node_name_cache_list_.emplace_back(nullptr);
     const int id = node_name_cache_list_.size() - 1;
     node_name_to_id_cache_map_.emplace(node_name, id);
-    GraphTransferInfo::ConstNodeInfo& const_node_info =
-        *graph_transfer_info_.add_const_node_info();
+    GraphTransferConstNodeInfo& const_node_info =
+        *graph_transfer_info_->add_const_node_info();
     const_node_info.set_name(node_name);
     const_node_info.set_node_id(id);
     CHECK_EQ(4, SHAPE_ARRAY_SIZE);
@@ -558,8 +566,8 @@ int GraphTransferer::RegisterConstScalar(const DataType dt, const int val,
     node_name_cache_list_.emplace_back(nullptr);
     const int id = node_name_cache_list_.size() - 1;
     node_name_to_id_cache_map_.emplace(val_name, id);
-    GraphTransferInfo::ConstNodeInfo& const_node_info =
-        *graph_transfer_info_.add_const_node_info();
+    GraphTransferConstNodeInfo& const_node_info =
+        *graph_transfer_info_->add_const_node_info();
     const_node_info.set_name(val_name);
     const_node_info.set_node_id(id);
     // TODO(satok): Do not assume rank is 4 here.
@@ -715,8 +723,8 @@ void GraphTransferer::RegisterPadNode(
 
   CHECK_EQ(2, node.num_inputs());
 
-  GraphTransferInfo::NodeInputInfo& node_input_info =
-      *graph_transfer_info_.add_node_input_info();
+  GraphTransferNodeInputInfo& node_input_info =
+      *graph_transfer_info_->add_node_input_info();
   node_input_info.set_node_id(id);
 
   AddNodeInputByInputIndex(node, 0, &node_input_info);
@@ -761,8 +769,7 @@ void GraphTransferer::RegisterPadNode(
         new_const_tensor,
         strings::StrCat(input_node->name(), "_", node.name(), "_1"));
 
-    GraphTransferInfo::NodeInput& node_input =
-        *node_input_info.add_node_input();
+    GraphTransferNodeInput& node_input = *node_input_info.add_node_input();
     node_input.set_node_id(id);
     node_input.set_output_port(0);
   } else {
@@ -849,8 +856,7 @@ void GraphTransferer::AppendNodeParams(const string& name, const int id,
                                        const int padding, const int inputs_size,
                                        const std::vector<int>& extra_inputs,
                                        const int outputs_size) {
-  GraphTransferInfo::NodeInfo& node_info =
-      *graph_transfer_info_.add_node_info();
+  GraphTransferNodeInfo& node_info = *graph_transfer_info_->add_node_info();
   node_info.set_name(name);
   node_info.set_node_id(id);
   node_info.set_type_name(type);
@@ -863,7 +869,7 @@ void GraphTransferer::AppendNodeParams(const string& name, const int id,
 
 void GraphTransferer::AddNodeInputByInputIndex(
     const Node& node, const int idx,
-    GraphTransferInfo::NodeInputInfo* node_input_info) {
+    GraphTransferNodeInputInfo* node_input_info) {
   const Edge* edge = nullptr;
   TF_CHECK_OK(node.input_edge(idx, &edge));
   const Node* input_node = edge->src();
@@ -873,7 +879,7 @@ void GraphTransferer::AddNodeInputByInputIndex(
   const std::string& op_name = input_node->name();
   CHECK_GT(node_name_to_id_cache_map_.count(op_name), 0) << op_name;
   const int src_id = node_name_to_id_cache_map_[op_name];
-  GraphTransferInfo::NodeInput& node_input = *node_input_info->add_node_input();
+  GraphTransferNodeInput& node_input = *node_input_info->add_node_input();
   node_input.set_node_id(src_id);
   node_input.set_output_port(port);
 }
@@ -882,15 +888,14 @@ void GraphTransferer::AppendNodeInputParams(
     const int id, const Node& node, const std::vector<int>& extra_inputs) {
   VLOG(1) << "Append input params: " << node.name() << ", " << node.num_inputs()
           << ", " << extra_inputs.size();
-  GraphTransferInfo::NodeInputInfo& node_input_info =
-      *graph_transfer_info_.add_node_input_info();
+  GraphTransferNodeInputInfo& node_input_info =
+      *graph_transfer_info_->add_node_input_info();
   node_input_info.set_node_id(id);
   for (int i = 0; i < node.num_inputs(); ++i) {
     AddNodeInputByInputIndex(node, i, &node_input_info);
   }
   for (const int extra_input : extra_inputs) {
-    GraphTransferInfo::NodeInput& node_input =
-        *node_input_info.add_node_input();
+    GraphTransferNodeInput& node_input = *node_input_info.add_node_input();
     node_input.set_node_id(extra_input);
     node_input.set_output_port(0);
   }
@@ -900,8 +905,8 @@ void GraphTransferer::AppendNodeOutputParams(const ShapeRefiner& shape_refiner,
                                              const int id, const Node& node) {
   VLOG(1) << "Append output params: " << node.name() << ", "
           << node.num_outputs();
-  GraphTransferInfo::NodeOutputInfo& node_output_info =
-      *graph_transfer_info_.add_node_output_info();
+  GraphTransferNodeOutputInfo& node_output_info =
+      *graph_transfer_info_->add_node_output_info();
   node_output_info.set_node_id(id);
 
   std::vector<DataType> data_types;
@@ -1030,8 +1035,7 @@ GraphTransferer::TransferParamsComparator::TransferParamsComparator(
     : dependency_map_(dep_map) {}
 
 bool GraphTransferer::TransferParamsComparator::operator()(
-    const GraphTransferInfo::NodeInfo& obj0,
-    const GraphTransferInfo::NodeInfo& obj1) {
+    const GraphTransferNodeInfo& obj0, const GraphTransferNodeInfo& obj1) {
   const int node_id0 = obj0.node_id();
   const int node_id1 = obj1.node_id();
   bool obj0_uses_obj1 = false;
@@ -1114,8 +1118,8 @@ void GraphTransferer::ClearCache() {
 
 void GraphTransferer::DumpNodeTransferParams() const {
   LOG(INFO) << "*** Const Nodes ***";
-  for (const GraphTransferInfo::ConstNodeInfo& params :
-       graph_transfer_info_.const_node_info()) {
+  for (const GraphTransferConstNodeInfo& params :
+       graph_transfer_info_->const_node_info()) {
     // TODO(satok): Stop assuming shape size is 4.
     CHECK_EQ(params.shape_size(), 4);
     LOG(INFO) << "[ " << params.node_id() << " \"" << params.name()
@@ -1131,8 +1135,8 @@ void GraphTransferer::DumpNodeTransferParams() const {
   }
   LOG(INFO) << "******\n";
   LOG(INFO) << "*** Op Nodes ***";
-  for (const GraphTransferInfo::NodeInfo& params :
-       graph_transfer_info_.node_info()) {
+  for (const GraphTransferNodeInfo& params :
+       graph_transfer_info_->node_info()) {
     LOG(INFO) << "[ " << params.node_id() << " \"" << params.name();
     LOG(INFO) << "  type: " << params.type_name();
     LOG(INFO) << "  padding: " << ToPaddingDebugString(params.padding_id());
@@ -1146,18 +1150,18 @@ void GraphTransferer::DumpNodeTransferParams() const {
   }
   LOG(INFO) << "******\n";
   LOG(INFO) << "*** Node input params ***";
-  for (const GraphTransferInfo::NodeInputInfo& params :
-       graph_transfer_info_.node_input_info()) {
+  for (const GraphTransferNodeInputInfo& params :
+       graph_transfer_info_->node_input_info()) {
     LOG(INFO) << "[ " << params.node_id() << " ]";
-    for (const GraphTransferInfo::NodeInput& node_input : params.node_input()) {
+    for (const GraphTransferNodeInput& node_input : params.node_input()) {
       LOG(INFO) << "    src node id = " << node_input.node_id()
                 << ", output port = " << node_input.output_port();
     }
   }
   LOG(INFO) << "******\n";
   LOG(INFO) << "*** Node output params ***";
-  for (const GraphTransferInfo::NodeOutputInfo& params :
-       graph_transfer_info_.node_output_info()) {
+  for (const GraphTransferNodeOutputInfo& params :
+       graph_transfer_info_->node_output_info()) {
     LOG(INFO) << "[ " << params.node_id() << " ]";
     for (const int max_size : params.max_byte_size()) {
       LOG(INFO) << "    max_size = " << max_size;
@@ -1167,8 +1171,8 @@ void GraphTransferer::DumpNodeTransferParams() const {
 }
 
 void GraphTransferer::DumpVerificationStringOfNodeTransferParams() const {
-  for (const GraphTransferInfo::ConstNodeInfo& params :
-       graph_transfer_info_.const_node_info()) {
+  for (const GraphTransferConstNodeInfo& params :
+       graph_transfer_info_->const_node_info()) {
     std::stringstream sstream;
     // TODO(satok): Stop assuming shape size is 4.
     CHECK_EQ(params.shape_size(), 4);
@@ -1182,9 +1186,9 @@ void GraphTransferer::DumpVerificationStringOfNodeTransferParams() const {
     LOG(INFO) << sstream.str();
   }
   LOG(INFO) << "Const node count = "
-            << graph_transfer_info_.const_node_info_size();
-  for (const GraphTransferInfo::NodeInfo& params :
-       graph_transfer_info_.node_info()) {
+            << graph_transfer_info_->const_node_info_size();
+  for (const GraphTransferNodeInfo& params :
+       graph_transfer_info_->node_info()) {
     std::stringstream sstream;
     sstream << "---(OP) [" << params.name().c_str() << "," << std::hex
             << params.node_id() << std::dec << "," << params.soc_op_id() << ","
@@ -1197,12 +1201,12 @@ void GraphTransferer::DumpVerificationStringOfNodeTransferParams() const {
             << "," << params.output_count() << "," << params.type_name() << "]";
     LOG(INFO) << sstream.str();
   }
-  LOG(INFO) << "Op node count = " << graph_transfer_info_.node_info_size();
-  for (const GraphTransferInfo::NodeInputInfo& params :
-       graph_transfer_info_.node_input_info()) {
+  LOG(INFO) << "Op node count = " << graph_transfer_info_->node_info_size();
+  for (const GraphTransferNodeInputInfo& params :
+       graph_transfer_info_->node_input_info()) {
     std::stringstream sstream;
     sstream << "---(INPUT) [" << std::hex << params.node_id() << std::dec;
-    for (const GraphTransferInfo::NodeInput& node_input : params.node_input()) {
+    for (const GraphTransferNodeInput& node_input : params.node_input()) {
       sstream << "," << std::hex << node_input.node_id() << std::dec << ","
               << node_input.output_port();
     }
@@ -1210,9 +1214,9 @@ void GraphTransferer::DumpVerificationStringOfNodeTransferParams() const {
     LOG(INFO) << sstream.str();
   }
   LOG(INFO) << "Input params count = "
-            << graph_transfer_info_.node_input_info_size();
-  for (const GraphTransferInfo::NodeOutputInfo& params :
-       graph_transfer_info_.node_output_info()) {
+            << graph_transfer_info_->node_input_info_size();
+  for (const GraphTransferNodeOutputInfo& params :
+       graph_transfer_info_->node_output_info()) {
     std::stringstream sstream;
     sstream << "---(OUTPUT) [" << std::hex << params.node_id() << std::dec;
     for (const int max_size : params.max_byte_size()) {
@@ -1222,7 +1226,7 @@ void GraphTransferer::DumpVerificationStringOfNodeTransferParams() const {
     LOG(INFO) << sstream.str();
   }
   LOG(INFO) << "Output params count = "
-            << graph_transfer_info_.node_output_info_size();
+            << graph_transfer_info_->node_output_info_size();
 }
 
 }  // namespace tensorflow
index 0d43d02..86c1c56 100644 (file)
@@ -22,8 +22,6 @@ limitations under the License.
 #include <vector>
 
 #include "tensorflow/core/common_runtime/shape_refiner.h"
-#include "tensorflow/core/framework/graph.pb.h"
-#include "tensorflow/core/framework/graph_transfer_info.pb.h"
 #include "tensorflow/core/framework/shape_inference.h"
 #include "tensorflow/core/graph/graph.h"
 #include "tensorflow/core/kernels/i_remote_fused_graph_ops_definitions.h"
@@ -34,6 +32,10 @@ limitations under the License.
 
 namespace tensorflow {
 
+class GraphTransferInfo;
+class GraphTransferNodeInfo;
+class GraphTransferNodeInputInfo;
+
 // GraphTransferer transfers graph definitions into SoC memory.
 // This functionality is effective if SoC is capable to run
 // the graph on that chip.
@@ -47,7 +49,9 @@ class GraphTransferer {
   static constexpr int SHAPE_ARRAY_SIZE = MAX_SUPPORTED_RANK;
   using TensorShapeMap = RemoteFusedGraphExecuteUtils::TensorShapeMap;
 
-  GraphTransferer() = default;
+  GraphTransferer();
+
+  ~GraphTransferer();
 
   // Load graph structure into GraphTransferer
   // TODO(satok): Pass a pair of TensorShape and DataType instead of
@@ -96,8 +100,8 @@ class GraphTransferer {
    public:
     TransferParamsComparator(
         const std::unordered_map<int, std::unordered_set<int>>& dep_map);
-    bool operator()(const GraphTransferInfo::NodeInfo& obj0,
-                    const GraphTransferInfo::NodeInfo& obj1);
+    bool operator()(const GraphTransferNodeInfo& obj0,
+                    const GraphTransferNodeInfo& obj1);
     const std::unordered_map<int, std::unordered_set<int>>& dependency_map_;
   };
 
@@ -174,9 +178,8 @@ class GraphTransferer {
                         const std::vector<int>& extra_inputs,
                         const int outputs_size);
 
-  void AddNodeInputByInputIndex(
-      const Node& node, const int idx,
-      GraphTransferInfo::NodeInputInfo* node_input_info);
+  void AddNodeInputByInputIndex(const Node& node, const int idx,
+                                GraphTransferNodeInputInfo* node_input_info);
 
   void AppendNodeInputParams(const int id, const Node& node,
                              const std::vector<int>& extra_inputs);
@@ -211,7 +214,7 @@ class GraphTransferer {
   // Dump pretty print of parameters
   void DumpNodeTransferParams() const;
 
-  GraphTransferInfo graph_transfer_info_{};
+  GraphTransferInfo* graph_transfer_info_;
 
   std::vector<const Node*> node_name_cache_list_{};
   std::unordered_map<string, int> node_name_to_id_cache_map_{};
index 20b09f1..765795b 100644 (file)
@@ -191,9 +191,9 @@ static GraphDef CreatePoolGraphDef() {
   return def;
 }
 
-static const GraphTransferInfo::ConstNodeInfo* FindConstNodeInfo(
+static const GraphTransferConstNodeInfo* FindConstNodeInfo(
     const GraphTransferer& gt, const string& name) {
-  for (const GraphTransferInfo::ConstNodeInfo& params :
+  for (const GraphTransferConstNodeInfo& params :
        gt.GetGraphTransferInfo().const_node_info()) {
     if (params.name() == name) {
       return &params;
@@ -202,9 +202,9 @@ static const GraphTransferInfo::ConstNodeInfo* FindConstNodeInfo(
   return nullptr;
 }
 
-static const GraphTransferInfo::NodeInfo* FindNodeInfo(
-    const GraphTransferer& gt, const string& name) {
-  for (const GraphTransferInfo::NodeInfo& params :
+static const GraphTransferNodeInfo* FindNodeInfo(const GraphTransferer& gt,
+                                                 const string& name) {
+  for (const GraphTransferNodeInfo& params :
        gt.GetGraphTransferInfo().node_info()) {
     if (params.name() == name) {
       return &params;
@@ -213,9 +213,9 @@ static const GraphTransferInfo::NodeInfo* FindNodeInfo(
   return nullptr;
 }
 
-static const GraphTransferInfo::NodeInputInfo* FindNodeInputInfo(
+static const GraphTransferNodeInputInfo* FindNodeInputInfo(
     const GraphTransferer& gt, const int node_id) {
-  for (const GraphTransferInfo::NodeInputInfo& params :
+  for (const GraphTransferNodeInputInfo& params :
        gt.GetGraphTransferInfo().node_input_info()) {
     if (params.node_id() == node_id) {
       return &params;
@@ -224,9 +224,9 @@ static const GraphTransferInfo::NodeInputInfo* FindNodeInputInfo(
   return nullptr;
 }
 
-static const GraphTransferInfo::NodeOutputInfo* FindNodeOutputInfo(
+static const GraphTransferNodeOutputInfo* FindNodeOutputInfo(
     const GraphTransferer& gt, const int node_id) {
-  for (const GraphTransferInfo::NodeOutputInfo& params :
+  for (const GraphTransferNodeOutputInfo& params :
        gt.GetGraphTransferInfo().node_output_info()) {
     if (params.node_id() == node_id) {
       return &params;
@@ -236,21 +236,21 @@ static const GraphTransferInfo::NodeOutputInfo* FindNodeOutputInfo(
 }
 
 static void SanityCheckNodes(const GraphTransferer& gt) {
-  for (const GraphTransferInfo::NodeInfo& params :
+  for (const GraphTransferNodeInfo& params :
        gt.GetGraphTransferInfo().node_info()) {
     if (params.input_count() > 0) {
-      const GraphTransferInfo::NodeInputInfo* input_params =
+      const GraphTransferNodeInputInfo* input_params =
           FindNodeInputInfo(gt, params.node_id());
       ASSERT_NE(nullptr, input_params);
       EXPECT_EQ(params.input_count(), input_params->node_input_size());
       EXPECT_EQ(params.node_id(), input_params->node_id());
-      for (const GraphTransferInfo::NodeInput& node_input :
+      for (const GraphTransferNodeInput& node_input :
            input_params->node_input()) {
         EXPECT_GE(node_input.output_port(), 0);
       }
     }
     if (params.output_count() > 0) {
-      const GraphTransferInfo::NodeOutputInfo* output_params =
+      const GraphTransferNodeOutputInfo* output_params =
           FindNodeOutputInfo(gt, params.node_id());
       ASSERT_NE(nullptr, output_params);
       EXPECT_EQ(params.output_count(), output_params->max_byte_size_size());
@@ -273,8 +273,7 @@ TEST_F(GraphTransfererTest, LoadAddGraph) {
   const int const_node_count =
       gt_.GetGraphTransferInfo().const_node_info_size();
   ASSERT_EQ(2, const_node_count);
-  const GraphTransferInfo::ConstNodeInfo* params_a =
-      FindConstNodeInfo(gt_, NAME_A);
+  const GraphTransferConstNodeInfo* params_a = FindConstNodeInfo(gt_, NAME_A);
   ASSERT_TRUE(params_a != nullptr);
   EXPECT_EQ(NAME_A, params_a->name());
   ASSERT_EQ(4, params_a->shape_size());
@@ -284,8 +283,7 @@ TEST_F(GraphTransfererTest, LoadAddGraph) {
   EXPECT_EQ(1, params_a->shape(3));
   EXPECT_EQ(4, params_a->data().length());
 
-  const GraphTransferInfo::ConstNodeInfo* params_b =
-      FindConstNodeInfo(gt_, NAME_B);
+  const GraphTransferConstNodeInfo* params_b = FindConstNodeInfo(gt_, NAME_B);
   ASSERT_TRUE(params_b != nullptr);
   ASSERT_EQ(4, params_b->shape_size());
   EXPECT_EQ(1, params_b->shape(0));
@@ -328,7 +326,7 @@ TEST_F(GraphTransfererTest, LoadConvGraph) {
   ASSERT_EQ(2, const_node_count);
   const int op_node_count = gt_.GetGraphTransferInfo().node_info_size();
   ASSERT_EQ(4, op_node_count);
-  const GraphTransferInfo::NodeInfo* params_conv = FindNodeInfo(gt_, "conv");
+  const GraphTransferNodeInfo* params_conv = FindNodeInfo(gt_, "conv");
   ASSERT_TRUE(params_conv != nullptr);
   const int id = params_conv->node_id();
   EXPECT_GE(id, 0);
@@ -354,8 +352,7 @@ TEST_F(GraphTransfererTest, LoadMaxPoolGraph) {
   ASSERT_EQ(2, const_node_count);
   const int op_node_count = gt_.GetGraphTransferInfo().node_info_size();
   ASSERT_EQ(4, op_node_count);
-  const GraphTransferInfo::NodeInfo* params_max_pool =
-      FindNodeInfo(gt_, "maxpool");
+  const GraphTransferNodeInfo* params_max_pool = FindNodeInfo(gt_, "maxpool");
   ASSERT_TRUE(params_max_pool != nullptr);
   const int id = params_max_pool->node_id();
   EXPECT_GE(id, 0);
index 9c2e1e1..66d24d1 100644 (file)
@@ -15,6 +15,7 @@ limitations under the License.
 
 #include "tensorflow/core/kernels/hexagon/hexagon_control_wrapper.h"
 
+#include "tensorflow/core/framework/graph_transfer_info.pb.h"
 #include "tensorflow/core/framework/tensor_shape.pb.h"
 #include "tensorflow/core/kernels/hexagon/hexagon_ops_definitions.h"
 #include "tensorflow/core/kernels/hexagon/soc_interface.h"
@@ -54,9 +55,9 @@ static uint8* FindAlignedPointer(uint8* ptr) {
   return data_ptr;
 }
 
-/* static */ GraphTransferInfo::NodeInfo* HexagonControlWrapper::FindNodeInfo(
+/* static */ GraphTransferNodeInfo* HexagonControlWrapper::FindNodeInfo(
     const string& name, GraphTransferInfo* graph_transfer_info) {
-  for (GraphTransferInfo::NodeInfo& node_info :
+  for (GraphTransferNodeInfo& node_info :
        *graph_transfer_info->mutable_node_info()) {
     if (node_info.name() == name) {
       return &node_info;
@@ -138,9 +139,9 @@ bool HexagonControlWrapper::SetupGraph() {
       graph_transferer_.GetMutableGraphTransferInfo();
 
   // Overwrite op type of input nodes for hexagon
-  for (const GraphTransferInfo::GraphInputNodeInfo& graph_input :
+  for (const GraphTransferGraphInputNodeInfo& graph_input :
        graph_transfer_info.graph_input_node_info()) {
-    GraphTransferInfo::NodeInfo* node_info =
+    GraphTransferNodeInfo* node_info =
         FindNodeInfo(graph_input.name(), &graph_transfer_info);
     CHECK_NE(node_info, nullptr);
   }
@@ -148,13 +149,13 @@ bool HexagonControlWrapper::SetupGraph() {
   // Generate a new output node which is connected to graph output node
   // TODO(satok): Support multiple output nodes
   CHECK_EQ(graph_transfer_info.graph_output_node_info_size(), 1);
-  for (const GraphTransferInfo::GraphOutputNodeInfo& graph_output :
+  for (const GraphTransferGraphOutputNodeInfo& graph_output :
        graph_transfer_info.graph_output_node_info()) {
     const int new_output_node_id = graph_transfer_info.node_info_size() +
                                    graph_transfer_info.const_node_info_size() +
                                    2 /* offset for ids */;
     // Register a new output node
-    GraphTransferInfo::NodeInfo& new_output_node_info =
+    GraphTransferNodeInfo& new_output_node_info =
         *graph_transfer_info.add_node_info();
     new_output_node_info.set_name(OUTPUT_OP_NAME);
     new_output_node_info.set_node_id(new_output_node_id);
@@ -169,14 +170,13 @@ bool HexagonControlWrapper::SetupGraph() {
     const string node_name = tid.first.ToString();
     const int port = tid.second;
     // Register node input for the new output node
-    const GraphTransferInfo::NodeInfo* node_info =
+    const GraphTransferNodeInfo* node_info =
         FindNodeInfo(node_name, &graph_transfer_info);
     CHECK_NE(node_info, nullptr);
-    GraphTransferInfo::NodeInputInfo& node_input_info =
+    GraphTransferNodeInputInfo& node_input_info =
         *graph_transfer_info.add_node_input_info();
     node_input_info.set_node_id(new_output_node_id);
-    GraphTransferInfo::NodeInput& node_input =
-        *node_input_info.add_node_input();
+    GraphTransferNodeInput& node_input = *node_input_info.add_node_input();
     node_input.set_node_id(node_info->node_id());
     node_input.set_output_port(port);
   }
@@ -189,12 +189,12 @@ bool HexagonControlWrapper::SetupGraph() {
 
   int inputs_count = 0;
   int outputs_count = 0;
-  for (const GraphTransferInfo::NodeInputInfo& input_params :
+  for (const GraphTransferNodeInputInfo& input_params :
        graph_transfer_info.node_input_info()) {
     inputs_count += input_params.node_input_size();
   }
 
-  for (const GraphTransferInfo::NodeOutputInfo& output_params :
+  for (const GraphTransferNodeOutputInfo& output_params :
        graph_transfer_info.node_output_info()) {
     outputs_count += output_params.max_byte_size_size();
   }
@@ -204,15 +204,14 @@ bool HexagonControlWrapper::SetupGraph() {
 
   // Construct node input parameters
   std::unordered_map<int, std::tuple<void*, int>> inputs_map;
-  for (const GraphTransferInfo::NodeInputInfo& input_params :
+  for (const GraphTransferNodeInputInfo& input_params :
        graph_transfer_info.node_input_info()) {
     const int count = input_params.node_input_size();
     CHECK(count <= MAX_IN_OUT_COUNT);
     int node_ids[MAX_IN_OUT_COUNT];
     int ports[MAX_IN_OUT_COUNT];
     for (int i = 0; i < count; ++i) {
-      const GraphTransferInfo::NodeInput& node_input =
-          input_params.node_input(i);
+      const GraphTransferNodeInput& node_input = input_params.node_input(i);
       node_ids[i] = node_input.node_id() + NODE_ID_OFFSET;
       ports[i] = node_input.output_port();
     }
@@ -224,7 +223,7 @@ bool HexagonControlWrapper::SetupGraph() {
 
   // Construct node output parameters
   std::unordered_map<int, std::tuple<void*, int>> outputs_map;
-  for (const GraphTransferInfo::NodeOutputInfo& output_params :
+  for (const GraphTransferNodeOutputInfo& output_params :
        graph_transfer_info.node_output_info()) {
     const int count = output_params.max_byte_size_size();
     CHECK(count <= MAX_IN_OUT_COUNT);
@@ -244,7 +243,7 @@ bool HexagonControlWrapper::SetupGraph() {
 
   // Initialize graph
   // 1. Setup const nodes
-  for (const GraphTransferInfo::ConstNodeInfo& params :
+  for (const GraphTransferConstNodeInfo& params :
        graph_transfer_info.const_node_info()) {
     const int node_id = params.node_id();
     // TODO(satok): Stop assuming shape size is 4.
@@ -267,8 +266,7 @@ bool HexagonControlWrapper::SetupGraph() {
   }
 
   // 2. Setup op nodes
-  for (const GraphTransferInfo::NodeInfo& params :
-       graph_transfer_info.node_info()) {
+  for (const GraphTransferNodeInfo& params : graph_transfer_info.node_info()) {
     const int node_id = params.node_id();
     const int op_id = params.soc_op_id();
     CHECK(inputs_map.count(node_id) == 1);
index dca1f94..132cfde 100644 (file)
@@ -67,8 +67,8 @@ class HexagonControlWrapper final : public IRemoteFusedGraphExecutor {
   // CAVEAT: Need offset as HVX library reserves some ids
   static constexpr int NODE_ID_OFFSET = 0x10000;
 
-  static GraphTransferInfo::NodeInfo* FindNodeInfo(
-      const string& node_name, GraphTransferInfo* graph_transfer_info);
+  static GraphTransferNodeInfo* FindNodeInfo(
+      const string& name, GraphTransferInfo* graph_transfer_info);
 
   const RemoteFusedGraphExecuteInfo* execute_info_{};
   GraphTransferer graph_transferer_{};
index 3f794df..5fb6b92 100644 (file)
@@ -29,6 +29,7 @@ adb push /tmp/imagenet_comp_graph_label_strings.txt /data/local/tmp
 
 #include <memory>
 
+#include "tensorflow/core/framework/graph_transfer_info.pb.h"
 #include "tensorflow/core/framework/tensor_shape.pb.h"
 #include "tensorflow/core/framework/tensor_testutil.h"
 #include "tensorflow/core/kernels/hexagon/graph_transfer_utils.h"
@@ -209,7 +210,7 @@ BuildRemoteFusedGraphExecuteInfoWithGraphTransferInfo(
     const GraphTransferInfo& graph_transfer_info) {
   RemoteFusedGraphExecuteInfo execute_info;
   execute_info.set_executor_name("build_hexagon_remote_fused_graph_executor");
-  for (const GraphTransferInfo::GraphInputNodeInfo& input :
+  for (const GraphTransferGraphInputNodeInfo& input :
        graph_transfer_info.graph_input_node_info()) {
     execute_info.add_graph_input_node_name(input.name());
     RemoteFusedGraphExecuteInfo::TensorShapeTypeProto& tensor_shape_type =
@@ -221,7 +222,7 @@ BuildRemoteFusedGraphExecuteInfoWithGraphTransferInfo(
     }
   }
 
-  for (const GraphTransferInfo::GraphOutputNodeInfo& output :
+  for (const GraphTransferGraphOutputNodeInfo& output :
        graph_transfer_info.graph_output_node_info()) {
     execute_info.add_graph_output_node_name(output.name());
     RemoteFusedGraphExecuteInfo::TensorShapeTypeProto& tensor_shape_type =
@@ -325,8 +326,8 @@ static void CompareGraphTransferInfo(const GraphTransferInfo& gfi0,
   // 1. check node_info
   ASSERT_EQ(gfi0.node_info_size(), gfi1.node_info_size());
   for (int i = 0; i < gfi0.node_info_size(); ++i) {
-    const GraphTransferInfo::NodeInfo& ni0 = gfi0.node_info(i);
-    const GraphTransferInfo::NodeInfo& ni1 = gfi1.node_info(i);
+    const GraphTransferNodeInfo& ni0 = gfi0.node_info(i);
+    const GraphTransferNodeInfo& ni1 = gfi1.node_info(i);
     EXPECT_EQ(ni0.DebugString(), ni1.DebugString());
     EXPECT_EQ(ni0.ByteSizeLong(), ni1.ByteSizeLong());
   }
@@ -334,8 +335,8 @@ static void CompareGraphTransferInfo(const GraphTransferInfo& gfi0,
   // 2. check const_node_info
   ASSERT_EQ(gfi0.const_node_info_size(), gfi1.const_node_info_size());
   for (int i = 0; i < gfi0.const_node_info_size(); ++i) {
-    const GraphTransferInfo::ConstNodeInfo& cni0 = gfi0.const_node_info(i);
-    const GraphTransferInfo::ConstNodeInfo& cni1 = gfi1.const_node_info(i);
+    const GraphTransferConstNodeInfo& cni0 = gfi0.const_node_info(i);
+    const GraphTransferConstNodeInfo& cni1 = gfi1.const_node_info(i);
     ASSERT_EQ(cni0.shape_size(), cni1.shape_size());
     for (int j = 0; j < cni0.shape_size(); ++j) {
       EXPECT_EQ(cni0.shape(j), cni1.shape(j));
@@ -347,8 +348,8 @@ static void CompareGraphTransferInfo(const GraphTransferInfo& gfi0,
   // 3. check node_input_info
   ASSERT_EQ(gfi0.node_input_info_size(), gfi1.node_input_info_size());
   for (int i = 0; i < gfi0.node_input_info_size(); ++i) {
-    const GraphTransferInfo::NodeInputInfo& nii0 = gfi0.node_input_info(i);
-    const GraphTransferInfo::NodeInputInfo& nii1 = gfi1.node_input_info(i);
+    const GraphTransferNodeInputInfo& nii0 = gfi0.node_input_info(i);
+    const GraphTransferNodeInputInfo& nii1 = gfi1.node_input_info(i);
     EXPECT_EQ(nii0.ByteSizeLong(), nii1.ByteSizeLong());
     EXPECT_EQ(nii0.DebugString(), nii1.DebugString());
   }
@@ -356,8 +357,8 @@ static void CompareGraphTransferInfo(const GraphTransferInfo& gfi0,
   // 4. check node_output_info
   ASSERT_EQ(gfi0.node_output_info_size(), gfi1.node_output_info_size());
   for (int i = 0; i < gfi0.node_output_info_size(); ++i) {
-    const GraphTransferInfo::NodeOutputInfo& noi0 = gfi0.node_output_info(i);
-    const GraphTransferInfo::NodeOutputInfo& noi1 = gfi1.node_output_info(i);
+    const GraphTransferNodeOutputInfo& noi0 = gfi0.node_output_info(i);
+    const GraphTransferNodeOutputInfo& noi1 = gfi1.node_output_info(i);
     ASSERT_EQ(noi0.max_byte_size_size(), noi1.max_byte_size_size());
     for (int j = 0; j < noi0.max_byte_size_size(); ++j) {
       EXPECT_EQ(noi0.max_byte_size(j), noi1.max_byte_size(j));
@@ -370,9 +371,9 @@ static void CompareGraphTransferInfo(const GraphTransferInfo& gfi0,
   ASSERT_EQ(gfi0.graph_input_node_info_size(),
             gfi1.graph_input_node_info_size());
   for (int i = 0; i < gfi0.graph_input_node_info_size(); ++i) {
-    const GraphTransferInfo::GraphInputNodeInfo& gini0 =
+    const GraphTransferGraphInputNodeInfo& gini0 =
         gfi0.graph_input_node_info(i);
-    const GraphTransferInfo::GraphInputNodeInfo& gini1 =
+    const GraphTransferGraphInputNodeInfo& gini1 =
         gfi0.graph_input_node_info(i);
     EXPECT_EQ(gini0.ByteSizeLong(), gini1.ByteSizeLong());
     EXPECT_EQ(gini0.DebugString(), gini1.DebugString());
@@ -382,9 +383,9 @@ static void CompareGraphTransferInfo(const GraphTransferInfo& gfi0,
   ASSERT_EQ(gfi0.graph_output_node_info_size(),
             gfi1.graph_output_node_info_size());
   for (int i = 0; i < gfi0.graph_output_node_info_size(); ++i) {
-    const GraphTransferInfo::GraphOutputNodeInfo& goni0 =
+    const GraphTransferGraphOutputNodeInfo& goni0 =
         gfi0.graph_output_node_info(i);
-    const GraphTransferInfo::GraphOutputNodeInfo& goni1 =
+    const GraphTransferGraphOutputNodeInfo& goni1 =
         gfi0.graph_output_node_info(i);
     EXPECT_EQ(goni0.ByteSizeLong(), goni1.ByteSizeLong());
     EXPECT_EQ(goni0.DebugString(), goni1.DebugString());