Fix wiring issues due to shared inputs and outputs
authorSami Kama <skama@nvidia.com>
Wed, 30 May 2018 03:59:21 +0000 (20:59 -0700)
committerSami Kama <skama@nvidia.com>
Wed, 30 May 2018 03:59:21 +0000 (20:59 -0700)
tensorflow/contrib/tensorrt/convert/convert_graph.cc
tensorflow/contrib/tensorrt/convert/convert_nodes.cc

index b7b26cf..5f79f6d 100644 (file)
@@ -91,8 +91,11 @@ void GetSubGraphIncomingEdges(const tensorflow::Graph& graph,
       if (!subgraph_node_ids.count(edge->src()->id()) &&
           !edge->src()->IsSource() && !edge->IsControlEdge()) {
         incoming_edges->insert(edge);
+        VLOG(2) << "INCOMING " << edge->src()->name() << " -> " << node->name()
+                << " Y, ";
       } else {
-        VLOG(2) << node->name() << " -> " << edge->src()->name() << " N, ";
+        VLOG(2) << "INCOMING " << edge->src()->name() << " -> " << node->name()
+                << " N, ";
       }
     }
   }
@@ -106,10 +109,12 @@ void GetSubGraphOutgoingEdges(const tensorflow::Graph& graph,
     for (const tensorflow::Edge* edge : node->out_edges()) {
       if (!subgraph_node_ids.count(edge->dst()->id()) &&
           !edge->dst()->IsSink() && !edge->IsControlEdge()) {
-        VLOG(2) << node->name() << " -> " << edge->dst()->name() << " Y, ";
+        VLOG(2) << "OUTGOING " << node->name() << " -> " << edge->dst()->name()
+                << " Y, ";
         outgoing_edges->insert(edge);
       } else {
-        VLOG(2) << node->name() << " -> " << edge->dst()->name() << " N, ";
+        VLOG(2) << "OUTGOING " << node->name() << " -> " << edge->dst()->name()
+                << " N, ";
       }
     }
   }
@@ -181,29 +186,21 @@ struct ConvertGraphParams {
 static tensorflow::Status FillSubGraphEdgeSets(ConvertGraphParams* p) {
   GetSubGraphIncomingEdges(p->graph, p->subgraph_node_ids,
                            &p->subgraph_incoming_edges);
+  std::set<std::pair<int, int>> unique_tensors;
   for (const tensorflow::Edge* edge : p->subgraph_incoming_edges) {
-    p->subgraph_inputs.push_back({edge->src()->id(), edge->src_output()});
-  }
-  auto output_name_to_index_map = BuildTensorNameMap(p->output_names);
-  std::set<std::pair<int, int>> subgraph_outputs_set;
-  // Collect outputs referenced from output_names
-  for (int node_id : p->subgraph_node_ids) {
-    tensorflow::Node* node = p->graph.FindNodeId(node_id);
-    if (output_name_to_index_map.count(node->name())) {
-      for (int index : output_name_to_index_map.at(node->name())) {
-        subgraph_outputs_set.insert({node_id, index});
-      }
-    }
+    unique_tensors.insert({edge->src()->id(), edge->src_output()});
   }
+  p->subgraph_inputs.insert(p->subgraph_inputs.begin(), unique_tensors.begin(),
+                            unique_tensors.end());
   GetSubGraphOutgoingEdges(p->graph, p->subgraph_node_ids,
                            &p->subgraph_outgoing_edges);
+  unique_tensors.clear();
   for (const tensorflow::Edge* edge : p->subgraph_outgoing_edges) {
-    subgraph_outputs_set.insert({edge->src()->id(), edge->src_output()});
+    unique_tensors.insert({edge->src()->id(), edge->src_output()});
   }
-  p->subgraph_outputs.reserve(subgraph_outputs_set.size());
+  p->subgraph_outputs.reserve(unique_tensors.size());
   p->subgraph_outputs.insert(p->subgraph_outputs.begin(),
-                             subgraph_outputs_set.begin(),
-                             subgraph_outputs_set.end());
+                             unique_tensors.begin(), unique_tensors.end());
   return tensorflow::Status::OK();
 }
 
@@ -257,19 +254,24 @@ tensorflow::Status ConvertSubGraphToTensorRT(ConvertGraphParams* params) {
   for (size_t i = 0; i < params->subgraph_inputs.size(); ++i) {
     subgraph_edge_to_input_map.insert({params->subgraph_inputs.at(i), i});
   }
+  std::set<std::pair<int, int>> unique_tensors;
   for (const tensorflow::Edge* edge : params->subgraph_incoming_edges) {
     std::pair<int, int> old_src = {edge->src()->id(), edge->src_output()};
+    if (unique_tensors.count(old_src)) continue;
+    unique_tensors.insert(old_src);
     int new_src_output = subgraph_edge_to_input_map.at(old_src);
     params->graph.AddEdge(edge->src(), edge->src_output(), trt_node,
                           new_src_output);
+    VLOG(1) << "Wire " << edge->src()->name() << ":" << edge->src_output()
+            << " -> " << trt_node->name() << ":" << new_src_output;
     params->graph.RemoveEdge(edge);
   }
-
-  VLOG(2) << "new wiring edges: " << trt_node->in_edges().size();
-  for (const tensorflow::Edge* edge : trt_node->in_edges()) {
-    VLOG(2) << edge->src()->name() << " port: " << edge->src_output();
+  if (VLOG_IS_ON(2)) {
+    VLOG(2) << "new edge count: " << trt_node->in_edges().size();
+    for (const tensorflow::Edge* edge : trt_node->in_edges()) {
+      VLOG(2) << edge->src()->name() << " port: " << edge->src_output();
+    }
   }
-
   TF_RETURN_IF_ERROR(status);
 
   // Re-map outgoing edges to use the new TRT node instead of the orig subgraph
@@ -278,11 +280,14 @@ tensorflow::Status ConvertSubGraphToTensorRT(ConvertGraphParams* params) {
     subgraph_edge_to_output_map.insert({params->subgraph_outputs.at(i), i});
   }
   TF_RETURN_IF_ERROR(status);
+  unique_tensors.clear();
   for (const tensorflow::Edge* edge : params->subgraph_outgoing_edges) {
     std::pair<int, int> old_src = {edge->src()->id(), edge->src_output()};
     int new_src_output = subgraph_edge_to_output_map.at(old_src);
     TF_RETURN_IF_ERROR(params->graph.UpdateEdge(
         trt_node, new_src_output, edge->dst(), edge->dst_input()));
+    VLOG(1) << "Wire " << trt_node->name() << ":" << new_src_output << " -> "
+            << edge->dst()->name() << ":" << edge->dst_input();
   }
   // Remove the original subgraph
   for (int node_id : params->subgraph_node_ids) {
@@ -317,9 +322,12 @@ tensorflow::Status ConvertCalibGraphToInferGraph(
       tensorflow::GraphConstructorOptions(), graph_def, &graph));
   //  get calib nodes
   std::vector<tensorflow::Node*> calib_nodes;
-  for (auto node : graph.op_nodes()) {
+  std::vector<tensorflow::Node*> topo_order;
+  tensorflow::GetPostOrder(graph, &topo_order);
+  for (auto rit = topo_order.rbegin(); rit != topo_order.rend(); ++rit) {
+    auto node = *rit;
     if (node->type_string() == "TRTCalibOp") {
-      VLOG(1) << "Found Calib Node";
+      VLOG(1) << "Found Calib Node " << node->name();
       calib_nodes.push_back(node);
     }
   }
index 32b211d..16bfcc3 100644 (file)
@@ -362,10 +362,11 @@ void ReorderCKtoKC(const TRT_ShapedWeights& iweights,
       break;
     }
     case tensorflow::DataType::DT_HALF: {
-      Reorder2({k, c}, static_cast<Eigen::half const*>(iweights.GetValues()),
-               istrides, static_cast<Eigen::half*>(
-                             const_cast<void*>(oweights->GetValues())),
-               ostrides);
+      Reorder2(
+          {k, c}, static_cast<Eigen::half const*>(iweights.GetValues()),
+          istrides,
+          static_cast<Eigen::half*>(const_cast<void*>(oweights->GetValues())),
+          ostrides);
       break;
     }
     default:
@@ -1179,9 +1180,9 @@ tensorflow::Status BinaryTensorOpTensor(
   CHECK_EQ_TYPE(tensor_r->getType(), dtype);
   auto op_pair = ops.find(node_def.op());
   if (op_pair == ops.end())
-    return tensorflow::errors::Unimplemented("binary op: " + node_def.op() +
-                                             " not supported at: " +
-                                             node_def.name());
+    return tensorflow::errors::Unimplemented(
+        "binary op: " + node_def.op() +
+        " not supported at: " + node_def.name());
 
   nvinfer1::IElementWiseLayer* layer = ctx.network()->addElementWise(
       *const_cast<nvinfer1::ITensor*>(tensor_l),
@@ -2138,9 +2139,7 @@ void Converter::register_op_converters() {
 }
 
 }  // namespace
-tensorflow::Status GetTensorRTGraph(tensorrt::convert::SubGraphParams& s) {
-  return tensorflow::errors::Unimplemented("Not implemented yet");
-}
+
 tensorflow::Status ConvertCalibrationNodeToEngineNode(
     tensorflow::Graph& graph, tensorflow::Node* c_node) {
   const auto ndef = c_node->def();
@@ -2164,9 +2163,23 @@ tensorflow::Status ConvertCalibrationNodeToEngineNode(
   for (auto n : graph.op_nodes()) {
     node_maps.insert({n->name(), n});
   }
-  VLOG(1) << "Output Nodes:";
+  std::set<int> subgraph_ids;
+  for (const auto internal_node : segment_nodes) {
+    subgraph_ids.insert(node_maps.at(internal_node)->id());
+  }
+  if (VLOG_IS_ON(2)) {
+    string node_names = StrCat(c_node->name(), " segment nodes= ");
+
+    for (const auto& node_name : segment_nodes) {
+      StrAppend(&node_names, node_name, ", ");
+    }
+    VLOG(2) << node_names;
+  }
+
+  VLOG(0) << "Output Nodes:";
   std::vector<tensorflow::DataType> out_types;
   std::vector<const tensorflow::Edge*> out_edges;
+
   for (auto& i : output_nodes) {
     auto node_port = tensorflow::str_util::Split(i, ":");
     VLOG(1) << " " << i << " in graph " << node_maps.count(i);
@@ -2186,9 +2199,13 @@ tensorflow::Status ConvertCalibrationNodeToEngineNode(
         out_types.push_back(out_node->output_type(0));
       }
       for (auto out_edge : out_node->out_edges()) {
+        if (subgraph_ids.count(out_edge->dst()->id()))
+          continue;  // skip internal edges;
         if (out_edge->src_output() == port) {
           out_edges.push_back(out_edge);
-          break;
+          VLOG(1) << "OUTPUT EDGE " << out_edge->src()->name() << ":"
+                  << out_edge->src_output() << " -> " << out_edge->dst()->name()
+                  << ":" << out_edge->dst_input();
         }
       }
     } else {
@@ -2255,13 +2272,18 @@ tensorflow::Status ConvertCalibrationNodeToEngineNode(
   }
   auto trt_engine_node = graph.AddNode(engine_node, &status);
   TF_RETURN_IF_ERROR(status);
-  for (size_t i = 0; i < out_edges.size(); i++) {
-    VLOG(1) << "Connecting trt_engine_node output " << i << " with "
-            << out_edges.at(i)->dst()->name() << " port "
-            << out_edges.at(i)->dst_input();
-    TF_RETURN_IF_ERROR(graph.UpdateEdge(trt_engine_node, i,
-                                        out_edges.at(i)->dst(),
-                                        out_edges.at(i)->dst_input()));
+  std::map<string, int> port_map;
+  for (size_t t = 0; t < output_nodes.size(); t++) {
+    port_map.insert({output_nodes.at(t), t});
+  }
+  for (auto& i : out_edges) {
+    string s(i->src()->name());
+    if (i->src_output()) StrAppend(&s, ":", i->src_output());
+    int out_port = port_map.at(s);
+    VLOG(1) << "Connecting " << trt_engine_node->name() << " port " << out_port
+            << " with " << i->dst()->name() << " port " << i->dst_input();
+    TF_RETURN_IF_ERROR(
+        graph.UpdateEdge(trt_engine_node, out_port, i->dst(), i->dst_input()));
   }
   VLOG(1) << "Segment nodes:";
   for (auto& i : segment_nodes) {
@@ -2332,6 +2354,7 @@ tensorflow::Status ConvertSubgraph(
     std::vector<string>* output_names,
     std::vector<tensorflow::DataType>* output_dtypes,
     const string& engine_name) {
+  std::set<string> added_tensors;
   for (const std::pair<int, int>& input : s.input_inds) {
     VLOG(2) << "parsing input. Node id= " << input.first;
     int node_id = input.first;
@@ -2374,7 +2397,6 @@ tensorflow::Status ConvertSubgraph(
 
     auto op_info = op_info_vec.at(shape_inference_output_idx);
     tensorflow::DataType tf_dtype = op_info.dtype();
-    input_dtypes->push_back(tf_dtype);
 
     nvinfer1::DataType dtype(nvinfer1::DataType::kFLOAT);
     auto type_status = ConvertDType(tf_dtype, &dtype);
@@ -2410,8 +2432,10 @@ tensorflow::Status ConvertSubgraph(
     if (output_idx != 0) {
       input_tensor_name = StrCat(node_name, ":", output_idx);
     }
-
+    if (added_tensors.count(input_tensor_name)) continue;
+    added_tensors.insert(input_tensor_name);
     input_names->push_back(input_tensor_name);
+    input_dtypes->push_back(tf_dtype);
     nvinfer1::ITensor* input_tensor = converter.network()->addInput(
         input_tensor_name.c_str(), dtype, input_dim_pseudo_chw);
 
@@ -2435,6 +2459,7 @@ tensorflow::Status ConvertSubgraph(
 
   // Gather output metadata
   int trt_engine_op_output_idx = 0;
+  added_tensors.clear();
   for (const std::pair<int, int>& output : s.output_inds) {
     int node_id = output.first;
     int output_idx = output.second;
@@ -2451,6 +2476,8 @@ tensorflow::Status ConvertSubgraph(
     if (output_idx != 0)
       tensorflow::strings::StrAppend(&tensor_name, ":", output_idx);
     VLOG(2) << "Output tensor name: " << tensor_name;
+    if (added_tensors.count(tensor_name)) continue;
+    added_tensors.insert(tensor_name);
     output_names->push_back(tensor_name);
     auto tensor_or_weights = converter.get_tensor(tensor_name);
     if (!tensor_or_weights.is_tensor()) {