Fix INT8 conversion bailing in case of unsupported TRT feature
authorSami Kama <skama@nvidia.com>
Thu, 5 Apr 2018 20:30:55 +0000 (13:30 -0700)
committerSami Kama <skama@nvidia.com>
Thu, 5 Apr 2018 20:30:55 +0000 (13:30 -0700)
tensorflow/contrib/tensorrt/convert/convert_graph.cc
tensorflow/contrib/tensorrt/convert/convert_nodes.cc

index ff8cc63..b412b29 100644 (file)
@@ -405,7 +405,13 @@ tensorflow::Status ConvertGraphDefToTensorRT(
                          max_mem_per_engine, static_graph_properties,
                          &output_edge_map, precision_mode);
     if (precision_mode == INT8MODE) {
-      TF_RETURN_IF_ERROR(GetCalibNode(&p));
+      tensorflow::Status status = GetCalibNode(&p);
+      if (status != tensorflow::Status::OK()) {
+        LOG(WARNING) << "subgraph conversion error for subgraph_index:" << count
+                     << " due to: \"" << status.ToString()
+                     << "\" SKIPPING......( " << subgraph_node_names.size()
+                     << " nodes)";
+      }
     } else {
       tensorflow::Status status = ConvertSubGraphToTensorRT(&p);
       if (status != tensorflow::Status::OK()) {
@@ -414,8 +420,8 @@ tensorflow::Status ConvertGraphDefToTensorRT(
                      << "\" SKIPPING......( " << subgraph_node_names.size()
                      << " nodes)";
       }
-      count++;
     }
+    count++;
   }
   graph.ToGraphDef(new_graph_def);
   return tensorflow::Status::OK();
index e920a79..ee1273d 100644 (file)
@@ -2262,6 +2262,7 @@ tensorflow::Status InjectCalibrationNode(tensorrt::convert::SubGraphParams& s) {
   auto ws = new tensorflow::tensorrt::TRTWeightStore();
   TF_CHECK_OK(weight_rmgr->Create(calib_op_name, calib_op_name, ws));
   Converter converter(op_res->network_, ws, s.precision_mode == FP16MODE);
+
   std::vector<string> input_names;
   std::vector<tensorflow::DataType> input_dtypes;
   for (const std::pair<int, int>& input : s.input_inds) {
@@ -2270,20 +2271,41 @@ tensorflow::Status InjectCalibrationNode(tensorrt::convert::SubGraphParams& s) {
     int output_idx = input.second;
     tensorflow::Node* node = s.graph.FindNodeId(node_id);
     auto node_name = node->name();
-    input_names.push_back(node_name);  // insert original node name without port
-    // TODO(jie): alternative :)
-    if (!s.graph_properties.HasOutputProperties(node_name))
+    // input_names should use the node name in the graph
+    // here it should be the input tensor name -> matching the binding
+    // insert original node name without port
+    auto tensor_name = node_name;
+    if (output_idx != 0) {
+      tensor_name = StrCat(tensor_name, ":", output_idx);
+    }
+
+    VLOG(2) << "input name: " << node_name << " tensor_name: " << tensor_name
+            << " idx: " << output_idx;
+
+    auto shape_inference_node_name = node_name;
+    auto shape_inference_output_idx = output_idx;
+    // rewire the shape inference to original node in the graph
+    if (s.output_edge_map->count(tensor_name)) {
+      shape_inference_node_name = s.output_edge_map->at(tensor_name).second;
+      shape_inference_output_idx = s.output_edge_map->at(tensor_name).first;
+    }
+    if (shape_inference_output_idx < 0) continue;
+    VLOG(2) << "shapeinference name: " << shape_inference_node_name
+            << " idx: " << shape_inference_output_idx;
+
+    if (!s.graph_properties.HasOutputProperties(shape_inference_node_name))
       return tensorflow::errors::Internal("failed to find input node: " +
-                                          node_name);
+                                          shape_inference_node_name);
 
-    auto op_info_vec = s.graph_properties.GetOutputProperties(node_name);
-    if (static_cast<int>(op_info_vec.size()) < output_idx)
+    auto op_info_vec =
+        s.graph_properties.GetOutputProperties(shape_inference_node_name);
+    if (static_cast<int>(op_info_vec.size()) <= shape_inference_output_idx)
       return tensorflow::errors::Internal(
-          "accessing output index of: ", output_idx, ", at node: ", node_name,
-          "with output entry from shape_map: ", op_info_vec.size());
-
-    auto op_info = op_info_vec.at(output_idx);
+          "accessing output index of: ", shape_inference_output_idx,
+          ", at node: ", shape_inference_node_name,
+          " with output entry from shape_map: ", op_info_vec.size());
 
+    auto op_info = op_info_vec.at(shape_inference_output_idx);
     tensorflow::DataType tf_dtype = op_info.dtype();
     input_dtypes.push_back(tf_dtype);
 
@@ -2294,16 +2316,23 @@ tensorflow::Status InjectCalibrationNode(tensorrt::convert::SubGraphParams& s) {
                    << "' failed";
       return type_status;
     }
-    TF_CHECK_OK(ConvertDType(tf_dtype, &dtype));
 
     VLOG(2) << "accessing output index of: " << output_idx
             << ", at node: " << node_name
             << "with output entry from shape_map: " << op_info_vec.size();
-
     // TODO(ben,jie): update TRT input format/dimension
     nvinfer1::DimsCHW input_dim_psuedo_chw;
     for (int i = 0; i < 3; i++) input_dim_psuedo_chw.d[i] = 1;
 
+    // TODO(jie): TRT 3.x only support 4 dimensional input tensor.
+    //            update the code once TRT 4.0 comes out.
+    if (op_info.shape().dim_size() != 4) {
+      string err_str = "Require 4 dimensional input.";
+      StrAppend(&err_str, " Got ", op_info.shape().dim_size(), " ",
+                shape_inference_node_name);
+      return tensorflow::errors::Unimplemented(err_str);
+    }
+
     for (int i = 1; i < op_info.shape().dim_size(); i++) {
       VLOG(2) << "dimension: " << i
               << " , size: " << op_info.shape().dim(i).size();
@@ -2312,8 +2341,11 @@ tensorflow::Status InjectCalibrationNode(tensorrt::convert::SubGraphParams& s) {
 
     // TODO(ben,jie): proper way to restore input tensor name?
     auto input_tensor_name = node_name;
-    if (output_idx != 0) input_tensor_name = StrCat(node_name, ":", output_idx);
+    if (output_idx != 0) {
+      input_tensor_name = StrCat(node_name, ":", output_idx);
+    }
 
+    input_names.push_back(input_tensor_name);
     nvinfer1::ITensor* input_tensor = converter.network()->addInput(
         input_tensor_name.c_str(), dtype, input_dim_psuedo_chw);
 
@@ -2377,11 +2409,13 @@ tensorflow::Status InjectCalibrationNode(tensorrt::convert::SubGraphParams& s) {
     tensor->setType(trt_dtype);
   }
 
-  VLOG(2) << "finished output";
+  VLOG(2) << "Finished processing outputs";
 
   // Build the engine
   op_res->builder_->setMaxBatchSize(s.max_batch_size);
   op_res->builder_->setMaxWorkspaceSize(s.max_workspace_size_bytes);
+  VLOG(0) << "Max batch size= " << s.max_batch_size
+          << " max workspace size= " << s.max_workspace_size_bytes;
 
   // Build the TRT op
   // TODO(sami,ben,jie): proper naming!
@@ -2475,7 +2509,7 @@ tensorflow::Status ConvertSubGraphToTensorRTNodeDef(
   std::vector<string> input_names;
   std::vector<tensorflow::DataType> input_dtypes;
   for (const std::pair<int, int>& input : s.input_inds) {
-    VLOG(2) << "parsing input!!!!!";
+    VLOG(2) << "parsing input. Node id= " << input.first ;
     int node_id = input.first;
     int output_idx = input.second;
     tensorflow::Node* node = s.graph.FindNodeId(node_id);