import "tensorflow/core/framework/types.proto";
+message GraphTransferNodeInput {
+ int32 node_id = 1;
+ int32 output_port = 2;
+}
+message GraphTransferNodeInfo {
+ string name = 1;
+ int32 node_id = 2;
+ string type_name = 3;
+ int32 soc_op_id = 4;
+ int32 padding_id = 5;
+ int32 input_count = 6;
+ int32 output_count = 7;
+};
+message GraphTransferConstNodeInfo {
+ string name = 1;
+ int32 node_id = 2;
+ repeated int64 shape = 3;
+ bytes data = 4;
+ DataType dtype = 5;
+};
+message GraphTransferNodeInputInfo {
+ int32 node_id = 1;
+ repeated GraphTransferNodeInput node_input = 2;
+};
+message GraphTransferNodeOutputInfo {
+ int32 node_id = 1;
+ repeated int32 max_byte_size = 2;
+};
+message GraphTransferGraphInputNodeInfo {
+ string name = 1;
+ repeated int64 shape = 2;
+ DataType dtype = 3;
+}
+
+message GraphTransferGraphOutputNodeInfo {
+ string name = 1;
+ repeated int64 shape = 2;
+ DataType dtype = 3;
+}
+
// Protocol buffer representing a handle to a tensorflow resource. Handles are
// not valid across executions, but can be serialized back and forth from within
// a single run.
NOP = 0;
HEXAGON = 1;
}
- message NodeInput {
- int32 node_id = 1;
- int32 output_port = 2;
- }
- message NodeInfo {
- string name = 1;
- int32 node_id = 2;
- string type_name = 3;
- int32 soc_op_id = 4;
- int32 padding_id = 5;
- int32 input_count = 6;
- int32 output_count = 7;
- };
- message ConstNodeInfo {
- string name = 1;
- int32 node_id = 2;
- repeated int64 shape = 3;
- bytes data = 4;
- DataType dtype = 5;
- };
- message NodeInputInfo {
- int32 node_id = 1;
- repeated NodeInput node_input = 2;
- };
- message NodeOutputInfo {
- int32 node_id = 1;
- repeated int32 max_byte_size = 2;
- };
- message GraphInputNodeInfo {
- string name = 1;
- repeated int64 shape = 2;
- DataType dtype = 3;
- }
-
- message GraphOutputNodeInfo {
- string name = 1;
- repeated int64 shape = 2;
- DataType dtype = 3;
- }
- repeated NodeInfo node_info = 1;
- repeated ConstNodeInfo const_node_info = 2;
- repeated NodeInputInfo node_input_info = 3;
- repeated NodeOutputInfo node_output_info = 4;
+ repeated GraphTransferNodeInfo node_info = 1;
+ repeated GraphTransferConstNodeInfo const_node_info = 2;
+ repeated GraphTransferNodeInputInfo node_input_info = 3;
+ repeated GraphTransferNodeOutputInfo node_output_info = 4;
// Input Node parameters of transferred graph
- repeated GraphInputNodeInfo graph_input_node_info = 5;
- repeated GraphOutputNodeInfo graph_output_node_info = 6;
+ repeated GraphTransferGraphInputNodeInfo graph_input_node_info = 5;
+ repeated GraphTransferGraphOutputNodeInfo graph_output_node_info = 6;
// Destination of graph transfer
Destination destination = 7;
};
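The hunk above flattens GraphTransferInfo's nested message types into top-level messages with a GraphTransfer prefix, and the repeated fields of GraphTransferInfo now reference those top-level names. A minimal C++ sketch of populating the flattened protos (the FillExample helper and its values are hypothetical; the message, field, and accessor names follow the hunk above):

    #include "tensorflow/core/framework/graph_transfer_info.pb.h"

    // Hypothetical helper: registers one op node and its input entry using the
    // flattened, top-level message names.
    void FillExample(tensorflow::GraphTransferInfo* info) {
      tensorflow::GraphTransferNodeInfo* node = info->add_node_info();
      node->set_name("conv");
      node->set_node_id(1);
      tensorflow::GraphTransferNodeInputInfo* inputs = info->add_node_input_info();
      inputs->set_node_id(1);
      tensorflow::GraphTransferNodeInput* in = inputs->add_node_input();
      in->set_node_id(0);      // source node id
      in->set_output_port(0);  // source output port
    }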
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/kernels:remote_fused_graph_execute_utils",
"//third_party/eigen3",
+ "@com_google_absl//absl/memory",
],
)
==============================================================================*/
#include "tensorflow/core/kernels/hexagon/graph_transfer_utils.h"
+#include "tensorflow/core/framework/graph.pb.h"
+#include "tensorflow/core/framework/remote_fused_graph_execute_info.pb.h"
#include "tensorflow/cc/framework/scope.h"
#include "tensorflow/cc/ops/const_op.h"
#include <utility>
#include <vector>
-#include "tensorflow/core/framework/graph.pb.h"
-#include "tensorflow/core/framework/remote_fused_graph_execute_info.pb.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/kernels/hexagon/graph_transferer.h"
#include "tensorflow/core/platform/macros.h"
namespace tensorflow {
+class RemoteFusedGraphExecuteInfo;
+
class GraphTransferUtils {
public:
static std::priority_queue<std::tuple<float, int, string>>
#include <algorithm>
#include <cinttypes>
+#include "tensorflow/core/framework/graph.pb.h"
+#include "tensorflow/core/framework/graph_transfer_info.pb.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/graph/graph_constructor.h"
return nullptr;
}
+GraphTransferer::GraphTransferer() {
+ graph_transfer_info_ = new GraphTransferInfo();
+}
+
+GraphTransferer::~GraphTransferer() { delete graph_transfer_info_; }
+
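The constructor/destructor pair above goes with the header changes further down: graph_transferer.h stops including graph_transfer_info.pb.h, forward-declares GraphTransferInfo instead, and turns the member into a heap-allocated pointer owned by the .cc file. A stripped-down sketch of that pattern, with hypothetical names (Info, Holder), for reference:

    // holder.h (hypothetical): only a forward declaration, no .pb.h include.
    class Info;
    class Holder {
     public:
      Holder();
      ~Holder();      // out-of-line, since Info is an incomplete type here
     private:
      Info* info_;    // owned; allocated and freed in holder.cc
    };

    // holder.cc (hypothetical): the full definition is visible where
    // allocation and deletion happen.
    #include "info.pb.h"
    Holder::Holder() : info_(new Info()) {}
    Holder::~Holder() { delete info_; }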
/**
* graph loading functions
* - LoadGraphFromProto
for (const std::pair<string, Tensor>& input_node_info :
input_node_info_list) {
- GraphTransferInfo::GraphInputNodeInfo& graph_input_node_info =
- *graph_transfer_info_.add_graph_input_node_info();
+ GraphTransferGraphInputNodeInfo& graph_input_node_info =
+ *graph_transfer_info_->add_graph_input_node_info();
graph_input_node_info.set_name(input_node_info.first);
graph_input_node_info.set_dtype(input_node_info.second.dtype());
for (const int64 dim : ToTensorShapeArray(input_node_info.second.shape())) {
const Node* node = node_name_cache_list_.at(node_id);
CHECK_NOTNULL(node);
- GraphTransferInfo::GraphOutputNodeInfo& graph_output_node_info =
- *graph_transfer_info_.add_graph_output_node_info();
+ GraphTransferGraphOutputNodeInfo& graph_output_node_info =
+ *graph_transfer_info_->add_graph_output_node_info();
graph_output_node_info.set_name(strings::StrCat(node_name, ":", port));
// Get output tensor shape type
void GraphTransferer::SortParams(const std::vector<string>& output_node_names) {
// TODO(satok): optimize complexity
- std::unordered_map<int, GraphTransferInfo::NodeInputInfo*> input_map;
- for (GraphTransferInfo::NodeInputInfo& input :
- *graph_transfer_info_.mutable_node_input_info()) {
+ std::unordered_map<int, GraphTransferNodeInputInfo*> input_map;
+ for (GraphTransferNodeInputInfo& input :
+ *graph_transfer_info_->mutable_node_input_info()) {
input_map.emplace(input.node_id(), &input);
}
// Setup dependency map placeholder
std::vector<int> output_node_ids;
std::unordered_map<int, std::unordered_set<int>> dependency_map;
- for (const GraphTransferInfo::NodeInfo& params :
- graph_transfer_info_.node_info()) {
+ for (const GraphTransferNodeInfo& params :
+ graph_transfer_info_->node_info()) {
const int node_id = params.node_id();
for (const string& output_node_name : output_node_names) {
if (params.name() == output_node_name) {
continue;
}
CHECK_EQ(input_map.count(node_id), 1);
- for (const GraphTransferInfo::NodeInput& node_input :
+ for (const GraphTransferNodeInput& node_input :
input_map.at(node_id)->node_input()) {
dependency_map.at(node_id).emplace(node_input.node_id());
}
FillDependencyRec(output_node_id, dependency_map, completed);
}
- std::sort(graph_transfer_info_.mutable_node_info()->begin(),
- graph_transfer_info_.mutable_node_info()->end(),
+ std::sort(graph_transfer_info_->mutable_node_info()->begin(),
+ graph_transfer_info_->mutable_node_info()->end(),
TransferParamsComparator(dependency_map));
}
void GraphTransferer::SetSerializedGraphTransferInfo(
const string& serialized_proto) {
- graph_transfer_info_.ParseFromString(serialized_proto);
+ graph_transfer_info_->ParseFromString(serialized_proto);
}
const GraphTransferInfo& GraphTransferer::GetGraphTransferInfo() const {
- return graph_transfer_info_;
+ return *graph_transfer_info_;
}
GraphTransferInfo& GraphTransferer::GetMutableGraphTransferInfo() {
- return graph_transfer_info_;
+ return *graph_transfer_info_;
}
void GraphTransferer::CacheNode(const Node& node) {
data_size = max_bytes_per_data * num_output_elements;
shape_array = BuildShapeArray(shape_handle, context);
- GraphTransferInfo::ConstNodeInfo& const_node_info =
- *graph_transfer_info_.add_const_node_info();
+ GraphTransferConstNodeInfo& const_node_info =
+ *graph_transfer_info_->add_const_node_info();
const_node_info.set_name(node.name());
const_node_info.set_node_id(id);
// TODO(satok): Make this generic. Never assume rank is 4.
node_name_cache_list_.emplace_back(nullptr);
const int id = node_name_cache_list_.size() - 1;
node_name_to_id_cache_map_.emplace(shape_name, id);
- GraphTransferInfo::ConstNodeInfo& const_node_info =
- *graph_transfer_info_.add_const_node_info();
+ GraphTransferConstNodeInfo& const_node_info =
+ *graph_transfer_info_->add_const_node_info();
const_node_info.set_name(shape_name);
const_node_info.set_node_id(id);
// TODO(satok): Make this generic. Never assume rank is 5.
node_name_cache_list_.emplace_back(nullptr);
const int id = node_name_cache_list_.size() - 1;
node_name_to_id_cache_map_.emplace(node_name, id);
- GraphTransferInfo::ConstNodeInfo& const_node_info =
- *graph_transfer_info_.add_const_node_info();
+ GraphTransferConstNodeInfo& const_node_info =
+ *graph_transfer_info_->add_const_node_info();
const_node_info.set_name(node_name);
const_node_info.set_node_id(id);
CHECK_EQ(4, SHAPE_ARRAY_SIZE);
node_name_cache_list_.emplace_back(nullptr);
const int id = node_name_cache_list_.size() - 1;
node_name_to_id_cache_map_.emplace(val_name, id);
- GraphTransferInfo::ConstNodeInfo& const_node_info =
- *graph_transfer_info_.add_const_node_info();
+ GraphTransferConstNodeInfo& const_node_info =
+ *graph_transfer_info_->add_const_node_info();
const_node_info.set_name(val_name);
const_node_info.set_node_id(id);
// TODO(satok): Do not assume rank is 4 here.
CHECK_EQ(2, node.num_inputs());
- GraphTransferInfo::NodeInputInfo& node_input_info =
- *graph_transfer_info_.add_node_input_info();
+ GraphTransferNodeInputInfo& node_input_info =
+ *graph_transfer_info_->add_node_input_info();
node_input_info.set_node_id(id);
AddNodeInputByInputIndex(node, 0, &node_input_info);
new_const_tensor,
strings::StrCat(input_node->name(), "_", node.name(), "_1"));
- GraphTransferInfo::NodeInput& node_input =
- *node_input_info.add_node_input();
+ GraphTransferNodeInput& node_input = *node_input_info.add_node_input();
node_input.set_node_id(id);
node_input.set_output_port(0);
} else {
const int padding, const int inputs_size,
const std::vector<int>& extra_inputs,
const int outputs_size) {
- GraphTransferInfo::NodeInfo& node_info =
- *graph_transfer_info_.add_node_info();
+ GraphTransferNodeInfo& node_info = *graph_transfer_info_->add_node_info();
node_info.set_name(name);
node_info.set_node_id(id);
node_info.set_type_name(type);
void GraphTransferer::AddNodeInputByInputIndex(
const Node& node, const int idx,
- GraphTransferInfo::NodeInputInfo* node_input_info) {
+ GraphTransferNodeInputInfo* node_input_info) {
const Edge* edge = nullptr;
TF_CHECK_OK(node.input_edge(idx, &edge));
const Node* input_node = edge->src();
const std::string& op_name = input_node->name();
CHECK_GT(node_name_to_id_cache_map_.count(op_name), 0) << op_name;
const int src_id = node_name_to_id_cache_map_[op_name];
- GraphTransferInfo::NodeInput& node_input = *node_input_info->add_node_input();
+ GraphTransferNodeInput& node_input = *node_input_info->add_node_input();
node_input.set_node_id(src_id);
node_input.set_output_port(port);
}
const int id, const Node& node, const std::vector<int>& extra_inputs) {
VLOG(1) << "Append input params: " << node.name() << ", " << node.num_inputs()
<< ", " << extra_inputs.size();
- GraphTransferInfo::NodeInputInfo& node_input_info =
- *graph_transfer_info_.add_node_input_info();
+ GraphTransferNodeInputInfo& node_input_info =
+ *graph_transfer_info_->add_node_input_info();
node_input_info.set_node_id(id);
for (int i = 0; i < node.num_inputs(); ++i) {
AddNodeInputByInputIndex(node, i, &node_input_info);
}
for (const int extra_input : extra_inputs) {
- GraphTransferInfo::NodeInput& node_input =
- *node_input_info.add_node_input();
+ GraphTransferNodeInput& node_input = *node_input_info.add_node_input();
node_input.set_node_id(extra_input);
node_input.set_output_port(0);
}
const int id, const Node& node) {
VLOG(1) << "Append output params: " << node.name() << ", "
<< node.num_outputs();
- GraphTransferInfo::NodeOutputInfo& node_output_info =
- *graph_transfer_info_.add_node_output_info();
+ GraphTransferNodeOutputInfo& node_output_info =
+ *graph_transfer_info_->add_node_output_info();
node_output_info.set_node_id(id);
std::vector<DataType> data_types;
: dependency_map_(dep_map) {}
bool GraphTransferer::TransferParamsComparator::operator()(
- const GraphTransferInfo::NodeInfo& obj0,
- const GraphTransferInfo::NodeInfo& obj1) {
+ const GraphTransferNodeInfo& obj0, const GraphTransferNodeInfo& obj1) {
const int node_id0 = obj0.node_id();
const int node_id1 = obj1.node_id();
bool obj0_uses_obj1 = false;
void GraphTransferer::DumpNodeTransferParams() const {
LOG(INFO) << "*** Const Nodes ***";
- for (const GraphTransferInfo::ConstNodeInfo& params :
- graph_transfer_info_.const_node_info()) {
+ for (const GraphTransferConstNodeInfo& params :
+ graph_transfer_info_->const_node_info()) {
// TODO(satok): Stop assuming shape size is 4.
CHECK_EQ(params.shape_size(), 4);
LOG(INFO) << "[ " << params.node_id() << " \"" << params.name()
}
LOG(INFO) << "******\n";
LOG(INFO) << "*** Op Nodes ***";
- for (const GraphTransferInfo::NodeInfo& params :
- graph_transfer_info_.node_info()) {
+ for (const GraphTransferNodeInfo& params :
+ graph_transfer_info_->node_info()) {
LOG(INFO) << "[ " << params.node_id() << " \"" << params.name();
LOG(INFO) << " type: " << params.type_name();
LOG(INFO) << " padding: " << ToPaddingDebugString(params.padding_id());
}
LOG(INFO) << "******\n";
LOG(INFO) << "*** Node input params ***";
- for (const GraphTransferInfo::NodeInputInfo& params :
- graph_transfer_info_.node_input_info()) {
+ for (const GraphTransferNodeInputInfo& params :
+ graph_transfer_info_->node_input_info()) {
LOG(INFO) << "[ " << params.node_id() << " ]";
- for (const GraphTransferInfo::NodeInput& node_input : params.node_input()) {
+ for (const GraphTransferNodeInput& node_input : params.node_input()) {
LOG(INFO) << " src node id = " << node_input.node_id()
<< ", output port = " << node_input.output_port();
}
}
LOG(INFO) << "******\n";
LOG(INFO) << "*** Node output params ***";
- for (const GraphTransferInfo::NodeOutputInfo& params :
- graph_transfer_info_.node_output_info()) {
+ for (const GraphTransferNodeOutputInfo& params :
+ graph_transfer_info_->node_output_info()) {
LOG(INFO) << "[ " << params.node_id() << " ]";
for (const int max_size : params.max_byte_size()) {
LOG(INFO) << " max_size = " << max_size;
}
void GraphTransferer::DumpVerificationStringOfNodeTransferParams() const {
- for (const GraphTransferInfo::ConstNodeInfo& params :
- graph_transfer_info_.const_node_info()) {
+ for (const GraphTransferConstNodeInfo& params :
+ graph_transfer_info_->const_node_info()) {
std::stringstream sstream;
// TODO(satok): Stop assuming shape size is 4.
CHECK_EQ(params.shape_size(), 4);
LOG(INFO) << sstream.str();
}
LOG(INFO) << "Const node count = "
- << graph_transfer_info_.const_node_info_size();
- for (const GraphTransferInfo::NodeInfo& params :
- graph_transfer_info_.node_info()) {
+ << graph_transfer_info_->const_node_info_size();
+ for (const GraphTransferNodeInfo& params :
+ graph_transfer_info_->node_info()) {
std::stringstream sstream;
sstream << "---(OP) [" << params.name().c_str() << "," << std::hex
<< params.node_id() << std::dec << "," << params.soc_op_id() << ","
<< "," << params.output_count() << "," << params.type_name() << "]";
LOG(INFO) << sstream.str();
}
- LOG(INFO) << "Op node count = " << graph_transfer_info_.node_info_size();
- for (const GraphTransferInfo::NodeInputInfo& params :
- graph_transfer_info_.node_input_info()) {
+ LOG(INFO) << "Op node count = " << graph_transfer_info_->node_info_size();
+ for (const GraphTransferNodeInputInfo& params :
+ graph_transfer_info_->node_input_info()) {
std::stringstream sstream;
sstream << "---(INPUT) [" << std::hex << params.node_id() << std::dec;
- for (const GraphTransferInfo::NodeInput& node_input : params.node_input()) {
+ for (const GraphTransferNodeInput& node_input : params.node_input()) {
sstream << "," << std::hex << node_input.node_id() << std::dec << ","
<< node_input.output_port();
}
LOG(INFO) << sstream.str();
}
LOG(INFO) << "Input params count = "
- << graph_transfer_info_.node_input_info_size();
- for (const GraphTransferInfo::NodeOutputInfo& params :
- graph_transfer_info_.node_output_info()) {
+ << graph_transfer_info_->node_input_info_size();
+ for (const GraphTransferNodeOutputInfo& params :
+ graph_transfer_info_->node_output_info()) {
std::stringstream sstream;
sstream << "---(OUTPUT) [" << std::hex << params.node_id() << std::dec;
for (const int max_size : params.max_byte_size()) {
LOG(INFO) << sstream.str();
}
LOG(INFO) << "Output params count = "
- << graph_transfer_info_.node_output_info_size();
+ << graph_transfer_info_->node_output_info_size();
}
} // namespace tensorflow
#include <vector>
#include "tensorflow/core/common_runtime/shape_refiner.h"
-#include "tensorflow/core/framework/graph.pb.h"
-#include "tensorflow/core/framework/graph_transfer_info.pb.h"
#include "tensorflow/core/framework/shape_inference.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/kernels/i_remote_fused_graph_ops_definitions.h"
namespace tensorflow {
+class GraphTransferInfo;
+class GraphTransferNodeInfo;
+class GraphTransferNodeInputInfo;
+
// GraphTransferer transfers graph definitions into SoC memory.
// This functionality is effective only if the SoC is capable of running
// the graph on that chip.
static constexpr int SHAPE_ARRAY_SIZE = MAX_SUPPORTED_RANK;
using TensorShapeMap = RemoteFusedGraphExecuteUtils::TensorShapeMap;
- GraphTransferer() = default;
+ GraphTransferer();
+
+ ~GraphTransferer();
// Load graph structure into GraphTransferer
// TODO(satok): Pass a pair of TensorShape and DataType instead of
public:
TransferParamsComparator(
const std::unordered_map<int, std::unordered_set<int>>& dep_map);
- bool operator()(const GraphTransferInfo::NodeInfo& obj0,
- const GraphTransferInfo::NodeInfo& obj1);
+ bool operator()(const GraphTransferNodeInfo& obj0,
+ const GraphTransferNodeInfo& obj1);
const std::unordered_map<int, std::unordered_set<int>>& dependency_map_;
};
const std::vector<int>& extra_inputs,
const int outputs_size);
- void AddNodeInputByInputIndex(
- const Node& node, const int idx,
- GraphTransferInfo::NodeInputInfo* node_input_info);
+ void AddNodeInputByInputIndex(const Node& node, const int idx,
+ GraphTransferNodeInputInfo* node_input_info);
void AppendNodeInputParams(const int id, const Node& node,
const std::vector<int>& extra_inputs);
// Dump pretty print of parameters
void DumpNodeTransferParams() const;
- GraphTransferInfo graph_transfer_info_{};
+ GraphTransferInfo* graph_transfer_info_;
std::vector<const Node*> node_name_cache_list_{};
std::unordered_map<string, int> node_name_to_id_cache_map_{};
return def;
}
-static const GraphTransferInfo::ConstNodeInfo* FindConstNodeInfo(
+static const GraphTransferConstNodeInfo* FindConstNodeInfo(
const GraphTransferer& gt, const string& name) {
- for (const GraphTransferInfo::ConstNodeInfo& params :
+ for (const GraphTransferConstNodeInfo& params :
gt.GetGraphTransferInfo().const_node_info()) {
if (params.name() == name) {
return &params;
return nullptr;
}
-static const GraphTransferInfo::NodeInfo* FindNodeInfo(
- const GraphTransferer& gt, const string& name) {
- for (const GraphTransferInfo::NodeInfo& params :
+static const GraphTransferNodeInfo* FindNodeInfo(const GraphTransferer& gt,
+ const string& name) {
+ for (const GraphTransferNodeInfo& params :
gt.GetGraphTransferInfo().node_info()) {
if (params.name() == name) {
return &params;
return nullptr;
}
-static const GraphTransferInfo::NodeInputInfo* FindNodeInputInfo(
+static const GraphTransferNodeInputInfo* FindNodeInputInfo(
const GraphTransferer& gt, const int node_id) {
- for (const GraphTransferInfo::NodeInputInfo& params :
+ for (const GraphTransferNodeInputInfo& params :
gt.GetGraphTransferInfo().node_input_info()) {
if (params.node_id() == node_id) {
return &params;
return nullptr;
}
-static const GraphTransferInfo::NodeOutputInfo* FindNodeOutputInfo(
+static const GraphTransferNodeOutputInfo* FindNodeOutputInfo(
const GraphTransferer& gt, const int node_id) {
- for (const GraphTransferInfo::NodeOutputInfo& params :
+ for (const GraphTransferNodeOutputInfo& params :
gt.GetGraphTransferInfo().node_output_info()) {
if (params.node_id() == node_id) {
return &params;
}
static void SanityCheckNodes(const GraphTransferer& gt) {
- for (const GraphTransferInfo::NodeInfo& params :
+ for (const GraphTransferNodeInfo& params :
gt.GetGraphTransferInfo().node_info()) {
if (params.input_count() > 0) {
- const GraphTransferInfo::NodeInputInfo* input_params =
+ const GraphTransferNodeInputInfo* input_params =
FindNodeInputInfo(gt, params.node_id());
ASSERT_NE(nullptr, input_params);
EXPECT_EQ(params.input_count(), input_params->node_input_size());
EXPECT_EQ(params.node_id(), input_params->node_id());
- for (const GraphTransferInfo::NodeInput& node_input :
+ for (const GraphTransferNodeInput& node_input :
input_params->node_input()) {
EXPECT_GE(node_input.output_port(), 0);
}
}
if (params.output_count() > 0) {
- const GraphTransferInfo::NodeOutputInfo* output_params =
+ const GraphTransferNodeOutputInfo* output_params =
FindNodeOutputInfo(gt, params.node_id());
ASSERT_NE(nullptr, output_params);
EXPECT_EQ(params.output_count(), output_params->max_byte_size_size());
const int const_node_count =
gt_.GetGraphTransferInfo().const_node_info_size();
ASSERT_EQ(2, const_node_count);
- const GraphTransferInfo::ConstNodeInfo* params_a =
- FindConstNodeInfo(gt_, NAME_A);
+ const GraphTransferConstNodeInfo* params_a = FindConstNodeInfo(gt_, NAME_A);
ASSERT_TRUE(params_a != nullptr);
EXPECT_EQ(NAME_A, params_a->name());
ASSERT_EQ(4, params_a->shape_size());
EXPECT_EQ(1, params_a->shape(3));
EXPECT_EQ(4, params_a->data().length());
- const GraphTransferInfo::ConstNodeInfo* params_b =
- FindConstNodeInfo(gt_, NAME_B);
+ const GraphTransferConstNodeInfo* params_b = FindConstNodeInfo(gt_, NAME_B);
ASSERT_TRUE(params_b != nullptr);
ASSERT_EQ(4, params_b->shape_size());
EXPECT_EQ(1, params_b->shape(0));
ASSERT_EQ(2, const_node_count);
const int op_node_count = gt_.GetGraphTransferInfo().node_info_size();
ASSERT_EQ(4, op_node_count);
- const GraphTransferInfo::NodeInfo* params_conv = FindNodeInfo(gt_, "conv");
+ const GraphTransferNodeInfo* params_conv = FindNodeInfo(gt_, "conv");
ASSERT_TRUE(params_conv != nullptr);
const int id = params_conv->node_id();
EXPECT_GE(id, 0);
ASSERT_EQ(2, const_node_count);
const int op_node_count = gt_.GetGraphTransferInfo().node_info_size();
ASSERT_EQ(4, op_node_count);
- const GraphTransferInfo::NodeInfo* params_max_pool =
- FindNodeInfo(gt_, "maxpool");
+ const GraphTransferNodeInfo* params_max_pool = FindNodeInfo(gt_, "maxpool");
ASSERT_TRUE(params_max_pool != nullptr);
const int id = params_max_pool->node_id();
EXPECT_GE(id, 0);
#include "tensorflow/core/kernels/hexagon/hexagon_control_wrapper.h"
+#include "tensorflow/core/framework/graph_transfer_info.pb.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/kernels/hexagon/hexagon_ops_definitions.h"
#include "tensorflow/core/kernels/hexagon/soc_interface.h"
return data_ptr;
}
-/* static */ GraphTransferInfo::NodeInfo* HexagonControlWrapper::FindNodeInfo(
+/* static */ GraphTransferNodeInfo* HexagonControlWrapper::FindNodeInfo(
const string& name, GraphTransferInfo* graph_transfer_info) {
- for (GraphTransferInfo::NodeInfo& node_info :
+ for (GraphTransferNodeInfo& node_info :
*graph_transfer_info->mutable_node_info()) {
if (node_info.name() == name) {
return &node_info;
graph_transferer_.GetMutableGraphTransferInfo();
// Overwrite op type of input nodes for hexagon
- for (const GraphTransferInfo::GraphInputNodeInfo& graph_input :
+ for (const GraphTransferGraphInputNodeInfo& graph_input :
graph_transfer_info.graph_input_node_info()) {
- GraphTransferInfo::NodeInfo* node_info =
+ GraphTransferNodeInfo* node_info =
FindNodeInfo(graph_input.name(), &graph_transfer_info);
CHECK_NE(node_info, nullptr);
}
// Generate a new output node which is connected to graph output node
// TODO(satok): Support multiple output nodes
CHECK_EQ(graph_transfer_info.graph_output_node_info_size(), 1);
- for (const GraphTransferInfo::GraphOutputNodeInfo& graph_output :
+ for (const GraphTransferGraphOutputNodeInfo& graph_output :
graph_transfer_info.graph_output_node_info()) {
const int new_output_node_id = graph_transfer_info.node_info_size() +
graph_transfer_info.const_node_info_size() +
2 /* offset for ids */;
// Register a new output node
- GraphTransferInfo::NodeInfo& new_output_node_info =
+ GraphTransferNodeInfo& new_output_node_info =
*graph_transfer_info.add_node_info();
new_output_node_info.set_name(OUTPUT_OP_NAME);
new_output_node_info.set_node_id(new_output_node_id);
const string node_name = tid.first.ToString();
const int port = tid.second;
// Register node input for the new output node
- const GraphTransferInfo::NodeInfo* node_info =
+ const GraphTransferNodeInfo* node_info =
FindNodeInfo(node_name, &graph_transfer_info);
CHECK_NE(node_info, nullptr);
- GraphTransferInfo::NodeInputInfo& node_input_info =
+ GraphTransferNodeInputInfo& node_input_info =
*graph_transfer_info.add_node_input_info();
node_input_info.set_node_id(new_output_node_id);
- GraphTransferInfo::NodeInput& node_input =
- *node_input_info.add_node_input();
+ GraphTransferNodeInput& node_input = *node_input_info.add_node_input();
node_input.set_node_id(node_info->node_id());
node_input.set_output_port(port);
}
int inputs_count = 0;
int outputs_count = 0;
- for (const GraphTransferInfo::NodeInputInfo& input_params :
+ for (const GraphTransferNodeInputInfo& input_params :
graph_transfer_info.node_input_info()) {
inputs_count += input_params.node_input_size();
}
- for (const GraphTransferInfo::NodeOutputInfo& output_params :
+ for (const GraphTransferNodeOutputInfo& output_params :
graph_transfer_info.node_output_info()) {
outputs_count += output_params.max_byte_size_size();
}
// Construct node input parameters
std::unordered_map<int, std::tuple<void*, int>> inputs_map;
- for (const GraphTransferInfo::NodeInputInfo& input_params :
+ for (const GraphTransferNodeInputInfo& input_params :
graph_transfer_info.node_input_info()) {
const int count = input_params.node_input_size();
CHECK(count <= MAX_IN_OUT_COUNT);
int node_ids[MAX_IN_OUT_COUNT];
int ports[MAX_IN_OUT_COUNT];
for (int i = 0; i < count; ++i) {
- const GraphTransferInfo::NodeInput& node_input =
- input_params.node_input(i);
+ const GraphTransferNodeInput& node_input = input_params.node_input(i);
node_ids[i] = node_input.node_id() + NODE_ID_OFFSET;
ports[i] = node_input.output_port();
}
// Construct node output parameters
std::unordered_map<int, std::tuple<void*, int>> outputs_map;
- for (const GraphTransferInfo::NodeOutputInfo& output_params :
+ for (const GraphTransferNodeOutputInfo& output_params :
graph_transfer_info.node_output_info()) {
const int count = output_params.max_byte_size_size();
CHECK(count <= MAX_IN_OUT_COUNT);
// Initialize graph
// 1. Setup const nodes
- for (const GraphTransferInfo::ConstNodeInfo& params :
+ for (const GraphTransferConstNodeInfo& params :
graph_transfer_info.const_node_info()) {
const int node_id = params.node_id();
// TODO(satok): Stop assuming shape size is 4.
}
// 2. Setup op nodes
- for (const GraphTransferInfo::NodeInfo& params :
- graph_transfer_info.node_info()) {
+ for (const GraphTransferNodeInfo& params : graph_transfer_info.node_info()) {
const int node_id = params.node_id();
const int op_id = params.soc_op_id();
CHECK(inputs_map.count(node_id) == 1);
// CAVEAT: Need offset as HVX library reserves some ids
static constexpr int NODE_ID_OFFSET = 0x10000;
- static GraphTransferInfo::NodeInfo* FindNodeInfo(
- const string& node_name, GraphTransferInfo* graph_transfer_info);
+ static GraphTransferNodeInfo* FindNodeInfo(
+ const string& name, GraphTransferInfo* graph_transfer_info);
const RemoteFusedGraphExecuteInfo* execute_info_{};
GraphTransferer graph_transferer_{};
#include <memory>
+#include "tensorflow/core/framework/graph_transfer_info.pb.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/kernels/hexagon/graph_transfer_utils.h"
const GraphTransferInfo& graph_transfer_info) {
RemoteFusedGraphExecuteInfo execute_info;
execute_info.set_executor_name("build_hexagon_remote_fused_graph_executor");
- for (const GraphTransferInfo::GraphInputNodeInfo& input :
+ for (const GraphTransferGraphInputNodeInfo& input :
graph_transfer_info.graph_input_node_info()) {
execute_info.add_graph_input_node_name(input.name());
RemoteFusedGraphExecuteInfo::TensorShapeTypeProto& tensor_shape_type =
}
}
- for (const GraphTransferInfo::GraphOutputNodeInfo& output :
+ for (const GraphTransferGraphOutputNodeInfo& output :
graph_transfer_info.graph_output_node_info()) {
execute_info.add_graph_output_node_name(output.name());
RemoteFusedGraphExecuteInfo::TensorShapeTypeProto& tensor_shape_type =
// 1. check node_info
ASSERT_EQ(gfi0.node_info_size(), gfi1.node_info_size());
for (int i = 0; i < gfi0.node_info_size(); ++i) {
- const GraphTransferInfo::NodeInfo& ni0 = gfi0.node_info(i);
- const GraphTransferInfo::NodeInfo& ni1 = gfi1.node_info(i);
+ const GraphTransferNodeInfo& ni0 = gfi0.node_info(i);
+ const GraphTransferNodeInfo& ni1 = gfi1.node_info(i);
EXPECT_EQ(ni0.DebugString(), ni1.DebugString());
EXPECT_EQ(ni0.ByteSizeLong(), ni1.ByteSizeLong());
}
// 2. check const_node_info
ASSERT_EQ(gfi0.const_node_info_size(), gfi1.const_node_info_size());
for (int i = 0; i < gfi0.const_node_info_size(); ++i) {
- const GraphTransferInfo::ConstNodeInfo& cni0 = gfi0.const_node_info(i);
- const GraphTransferInfo::ConstNodeInfo& cni1 = gfi1.const_node_info(i);
+ const GraphTransferConstNodeInfo& cni0 = gfi0.const_node_info(i);
+ const GraphTransferConstNodeInfo& cni1 = gfi1.const_node_info(i);
ASSERT_EQ(cni0.shape_size(), cni1.shape_size());
for (int j = 0; j < cni0.shape_size(); ++j) {
EXPECT_EQ(cni0.shape(j), cni1.shape(j));
// 3. check node_input_info
ASSERT_EQ(gfi0.node_input_info_size(), gfi1.node_input_info_size());
for (int i = 0; i < gfi0.node_input_info_size(); ++i) {
- const GraphTransferInfo::NodeInputInfo& nii0 = gfi0.node_input_info(i);
- const GraphTransferInfo::NodeInputInfo& nii1 = gfi1.node_input_info(i);
+ const GraphTransferNodeInputInfo& nii0 = gfi0.node_input_info(i);
+ const GraphTransferNodeInputInfo& nii1 = gfi1.node_input_info(i);
EXPECT_EQ(nii0.ByteSizeLong(), nii1.ByteSizeLong());
EXPECT_EQ(nii0.DebugString(), nii1.DebugString());
}
// 4. check node_output_info
ASSERT_EQ(gfi0.node_output_info_size(), gfi1.node_output_info_size());
for (int i = 0; i < gfi0.node_output_info_size(); ++i) {
- const GraphTransferInfo::NodeOutputInfo& noi0 = gfi0.node_output_info(i);
- const GraphTransferInfo::NodeOutputInfo& noi1 = gfi1.node_output_info(i);
+ const GraphTransferNodeOutputInfo& noi0 = gfi0.node_output_info(i);
+ const GraphTransferNodeOutputInfo& noi1 = gfi1.node_output_info(i);
ASSERT_EQ(noi0.max_byte_size_size(), noi1.max_byte_size_size());
for (int j = 0; j < noi0.max_byte_size_size(); ++j) {
EXPECT_EQ(noi0.max_byte_size(j), noi1.max_byte_size(j));
ASSERT_EQ(gfi0.graph_input_node_info_size(),
gfi1.graph_input_node_info_size());
for (int i = 0; i < gfi0.graph_input_node_info_size(); ++i) {
- const GraphTransferInfo::GraphInputNodeInfo& gini0 =
+ const GraphTransferGraphInputNodeInfo& gini0 =
gfi0.graph_input_node_info(i);
- const GraphTransferInfo::GraphInputNodeInfo& gini1 =
+ const GraphTransferGraphInputNodeInfo& gini1 =
gfi1.graph_input_node_info(i);
EXPECT_EQ(gini0.ByteSizeLong(), gini1.ByteSizeLong());
EXPECT_EQ(gini0.DebugString(), gini1.DebugString());
ASSERT_EQ(gfi0.graph_output_node_info_size(),
gfi1.graph_output_node_info_size());
for (int i = 0; i < gfi0.graph_output_node_info_size(); ++i) {
- const GraphTransferInfo::GraphOutputNodeInfo& goni0 =
+ const GraphTransferGraphOutputNodeInfo& goni0 =
gfi0.graph_output_node_info(i);
- const GraphTransferInfo::GraphOutputNodeInfo& goni1 =
+ const GraphTransferGraphOutputNodeInfo& goni1 =
gfi1.graph_output_node_info(i);
EXPECT_EQ(goni0.ByteSizeLong(), goni1.ByteSizeLong());
EXPECT_EQ(goni0.DebugString(), goni1.DebugString());