Fix python formatting and add missing docstrings
authorgracehoney <31743510+aaroey@users.noreply.github.com>
Tue, 6 Mar 2018 23:33:21 +0000 (15:33 -0800)
committergracehoney <31743510+aaroey@users.noreply.github.com>
Tue, 6 Mar 2018 23:33:21 +0000 (15:33 -0800)
tensorflow/contrib/tensorrt/python/__init__.py
tensorflow/contrib/tensorrt/python/trt_convert.py
tensorflow/contrib/tensorrt/test/test_tftrt.py

index 3941d15..0b2321b 100644 (file)
@@ -20,6 +20,6 @@ from __future__ import print_function
 
 # pylint: disable=unused-import,line-too-long
 from tensorflow.contrib.tensorrt.python.ops import trt_engine_op
-from tensorflow.contrib.tensorrt.python.trt_convert import create_inference_graph
 from tensorflow.contrib.tensorrt.python.trt_convert import calib_graph_to_infer_graph
+from tensorflow.contrib.tensorrt.python.trt_convert import create_inference_graph
 # pylint: enable=unused-import,line-too-long
index 861b316..666220d 100644 (file)
@@ -20,15 +20,17 @@ from __future__ import print_function
 
 # pylint: disable=unused-import,line-too-long
 import six as _six
+from tensorflow.contrib.tensorrt.wrap_conversion import calib_convert
+from tensorflow.contrib.tensorrt.wrap_conversion import trt_convert
 from tensorflow.core.framework import graph_pb2
+from tensorflow.core.protobuf import rewriter_config_pb2
 from tensorflow.python.framework import errors
 from tensorflow.python.framework import errors_impl as _impl
-from tensorflow.contrib.tensorrt.wrap_conversion import trt_convert, calib_convert
-from tensorflow.python.util import compat
-from tensorflow.python.grappler import tf_optimizer
-from tensorflow.core.protobuf import rewriter_config_pb2
 from tensorflow.python.framework import meta_graph
 from tensorflow.python.framework import ops
+from tensorflow.python.grappler import tf_optimizer
+from tensorflow.python.util import compat
+# pylint: enable=unused-import,line-too-long
 
 
 # TODO(skama): get outputs from session when implemented as c++
@@ -41,17 +43,20 @@ def create_inference_graph(input_graph_def,
                            minimum_segment_size=3):
   """Python wrapper for the TRT transormation.
 
-
   Args:
     input_graph_def: GraphDef object containing a model to be transformed.
-    outputs: List of tensors or node names for the model outputs.
+    outputs: list of tensors or node names for the model outputs.
     max_batch_size: max size for the input batch
     max_workspace_size_bytes: parameter to control memory allocation (in Bytes)
+    precision_mode: one of 'FP32', 'FP16' and 'INT8'
+    minimum_segment_size: the minimum number of nodes required for a subgraph to
+      be replaced by TRTEngineOp.
 
   Returns:
     New GraphDef with TRTEngineOps placed in graph replacing subgraphs.
 
   Raises:
+    ValueError: if the provided precision mode is invalid.
     RuntimeError: if the returned status message is malformed.
   """
   supported_precision_modes = {"FP32": 0, "FP16": 1, "INT8": 2}
@@ -116,8 +121,15 @@ def create_inference_graph(input_graph_def,
 
 
 def calib_graph_to_infer_graph(calibration_graph_def):
-  """Convert an existing calibration graph containing calibration data
-  to inference graph"""
+  """Convert an existing calibration graph to inference graph.
+
+  Args:
+    calibration_graph_def: the calibration GraphDef object with calibration data
+  Returns:
+    New GraphDef with TRTEngineOps placed in graph replacing calibration nodes.
+  Raises:
+    RuntimeError: if the returned status message is malformed.
+  """
 
   def py2string(inp):
     return inp
@@ -134,16 +146,18 @@ def calib_graph_to_infer_graph(calibration_graph_def):
   out = calib_convert(graph_str)
   status = to_string(out[0])
   output_graph_def_string = out[1]
-  del graph_str  #save some memory
+  del graph_str  # Save some memory
   if len(status) < 2:
     raise _impl.UnknownError(None, None, status)
   if status[:2] != "OK":
     msg = status.split(";")
     if len(msg) == 1:
       raise RuntimeError("Status message is malformed {}".format(status))
+    # pylint: disable=protected-access
     raise _impl._make_specific_exception(None, None, ";".join(msg[1:]),
                                          int(msg[0]))
+    # pylint: enable=protected-access
   output_graph_def = graph_pb2.GraphDef()
   output_graph_def.ParseFromString(output_graph_def_string)
-  del output_graph_def_string  #save some memory
+  del output_graph_def_string  # Save some memory
   return output_graph_def
index a5cfb9b..0b661bd 100644 (file)
@@ -60,7 +60,7 @@ def get_simple_graph_def():
 
 
 def run_graph(gdef, dumm_inp):
-  """Run given graphdef once"""
+  """Run given graphdef once."""
   gpu_options = cpb2.GPUOptions(per_process_gpu_memory_fraction=0.50)
   ops.reset_default_graph()
   g = ops.Graph()
@@ -76,11 +76,9 @@ def run_graph(gdef, dumm_inp):
 
 
 # Use real data that is representatitive of the inference dataset
-# for calibration. For this test script it is random data
-
-
+# for calibration. For this test script it is random data.
 def run_calibration(gdef, dumm_inp):
-  """Run given calibration graph multiple times"""
+  """Run given calibration graph multiple times."""
   gpu_options = cpb2.GPUOptions(per_process_gpu_memory_fraction=0.50)
   ops.reset_default_graph()
   g = ops.Graph()