Fixes for PR
authorSami Kama <skama@nvidia.com>
Fri, 2 Mar 2018 20:59:28 +0000 (12:59 -0800)
committerSami Kama <skama@nvidia.com>
Fri, 2 Mar 2018 20:59:28 +0000 (12:59 -0800)
tensorflow/contrib/tensorrt/convert/convert_nodes.cc
tensorflow/contrib/tensorrt/python/trt_convert.py
tensorflow/contrib/tensorrt/test/test_tftrt.py

index a36851a..a7287e4 100644 (file)
@@ -2067,7 +2067,6 @@ void Converter::register_op_converters() {
   // This could be really handled as ConvertBinary
   op_registry_["BiasAdd"] = ConvertScale;
   op_registry_["Const"] = ConvertConst;
-  // op_registry_["MatMul"] = ConvertFullyConnected; // not used in vgg
   // TODO(ben,jie): this is a temp hack.
   op_registry_["Identity"] = ConvertIdentity;  // Identity should be removed
 
index 071f09d..d1f9f8a 100644 (file)
@@ -23,7 +23,7 @@ import six as _six
 from tensorflow.core.framework import graph_pb2
 from tensorflow.python.framework import errors
 from tensorflow.python.framework import errors_impl as _impl
-from tensorflow.contrib.tensorrt.wrap_conversion import trt_convert,calib_convert
+from tensorflow.contrib.tensorrt.wrap_conversion import trt_convert, calib_convert
 from tensorflow.python.util import compat
 import tensorflow as tf
 from tensorflow.python.grappler import tf_optimizer
@@ -32,9 +32,6 @@ from tensorflow.python.framework import meta_graph
 from tensorflow.python.framework import ops
 
 
-from tensorflow.python.framework import ops
-
-
 # TODO(skama): get outputs from session when implemented as c++
 # optimization pass
 def create_inference_graph(input_graph_def,
@@ -58,13 +55,14 @@ def create_inference_graph(input_graph_def,
   Raises:
     RuntimeError: if the returned status message is malformed.
   """
-  supported_precision_modes={"FP32":0,
-                             "FP16":1,
-                             "INT8":2}
+  supported_precision_modes = {"FP32": 0,
+                               "FP16": 1,
+                               "INT8": 2}
   if precision_mode.upper() not in supported_precision_modes:
     raise ValueError(("precision mode '{}' is not supported."
-    "It should be one of {}").format(precision_mode,"{'FP32','FP16','INT8'}"))
-  mode=supported_precision_modes[precision_mode.upper()]
+                      "It should be one of {}").format(precision_mode,
+                      "{'FP32', 'FP16', 'INT8'}"))
+  mode = supported_precision_modes[precision_mode.upper()]
   def py2bytes(inp):
     return inp
 
@@ -99,7 +97,7 @@ def create_inference_graph(input_graph_def,
   # pair or strings where first one is encoded status and the second
   # one is the transformed graphs protobuf string.
   out = trt_convert(input_graph_def_str, out_names, max_batch_size,
-                    max_workspace_size_bytes,mode,minimum_segment_size)
+                    max_workspace_size_bytes, mode, minimum_segment_size)
   status = to_string(out[0])
   output_graph_def_string = out[1]
   del input_graph_def_str  # Save some memory
@@ -119,6 +117,8 @@ def create_inference_graph(input_graph_def,
   return output_graph_def
 
 def calib_graph_to_infer_graph(calibration_graph_def):
+  """Convert an existing calibration graph containing calibration data
+  to an inference graph."""
   def py2bytes(inp):
     return inp
 
@@ -132,21 +132,19 @@ def calib_graph_to_infer_graph(calibration_graph_def):
     return inp.decode("utf-8")
 
   if _six.PY2:
-    to_bytes = py2bytes
     to_string = py2string
   else:
-    to_bytes = py3bytes
     to_string = py3string
 
-  graph_str=calibration_graph_def.SerializeToString()
-  out=calib_convert(graph_str)
-  status=to_string(out[0])
+  graph_str = calibration_graph_def.SerializeToString()
+  out = calib_convert(graph_str)
+  status = to_string(out[0])
   output_graph_def_string = out[1]
   del graph_str #save some memory
   if len(status) < 2:
-    raise _impl.UnknownError(None,None,status)
+    raise _impl.UnknownError(None, None, status)
   if status[:2] != "OK":
-    msg=status.split(";")
+    msg = status.split(";")
     if len(msg) == 1:
       raise RuntimeError("Status message is malformed {}".format(status))
     raise _impl._make_specific_exception(None,None,";".join(msg[1:]), int(msg[0]))
index cfa18ab..385a9f7 100644 (file)
@@ -89,7 +89,9 @@ def run_calibration(gdef, dumm_inp):
     out = out.outputs[0]
   with csess.Session(
       config=cpb2.ConfigProto(gpu_options=gpu_options), graph=g) as sess:
-    for _ in range(30):
    # Run over real calibration data here; we are mimicking a
    # calibration set of 30 different batches. Use as much calibration data as you want.
    for _ in range(30):
       val = sess.run(out, {inp: dumm_inp})
   return val
 
@@ -122,7 +124,7 @@ if "__main__" in __name__:
                                                outputs=["output"],
                                                max_batch_size=inp_dims[0],
                                                max_workspace_size_bytes=1 << 25,
-                                               precision_mode="INt8",  # TRT Engine precision "FP32","FP16" or "INT8"
+                                               precision_mode="INT8",  # TRT Engine precision "FP32","FP16" or "INT8"
                                                minimum_segment_size=2  # minimum number of nodes in an engine
                                               )
   o4 = run_graph(fp16_graph, dummy_input)