# pylint: disable=unused-import,line-too-long
import six as _six
+from tensorflow.contrib.tensorrt.wrap_conversion import calib_convert
+from tensorflow.contrib.tensorrt.wrap_conversion import trt_convert
from tensorflow.core.framework import graph_pb2
+from tensorflow.core.protobuf import rewriter_config_pb2
from tensorflow.python.framework import errors
from tensorflow.python.framework import errors_impl as _impl
-from tensorflow.contrib.tensorrt.wrap_conversion import trt_convert, calib_convert
-from tensorflow.python.util import compat
-from tensorflow.python.grappler import tf_optimizer
-from tensorflow.core.protobuf import rewriter_config_pb2
from tensorflow.python.framework import meta_graph
from tensorflow.python.framework import ops
+from tensorflow.python.grappler import tf_optimizer
+from tensorflow.python.util import compat
+# pylint: enable=unused-import,line-too-long
# TODO(skama): get outputs from session when implemented as C++
minimum_segment_size=3):
"""Python wrapper for the TRT transformation.
-
Args:
input_graph_def: GraphDef object containing a model to be transformed.
- outputs: List of tensors or node names for the model outputs.
+ outputs: list of tensors or node names for the model outputs.
max_batch_size: max size for the input batch
max_workspace_size_bytes: parameter to control memory allocation (in Bytes)
+ precision_mode: one of 'FP32', 'FP16' and 'INT8'
+ minimum_segment_size: the minimum number of nodes required for a subgraph to
+ be replaced by TRTEngineOp.
Returns:
New GraphDef with TRTEngineOps placed in graph replacing subgraphs.
Raises:
+ ValueError: if the provided precision mode is invalid.
RuntimeError: if the returned status message is malformed.
"""
supported_precision_modes = {"FP32": 0, "FP16": 1, "INT8": 2}
def calib_graph_to_infer_graph(calibration_graph_def):
- """Convert an existing calibration graph containing calibration data
- to inference graph"""
+ """Convert an existing calibration graph to inference graph.
+
+ Args:
+ calibration_graph_def: the calibration GraphDef object with calibration data
+ Returns:
+ New GraphDef with TRTEngineOps placed in graph replacing calibration nodes.
+ Raises:
+ RuntimeError: if the returned status message is malformed.
+ """
def py2string(inp):
"""Return `inp` unchanged (identity conversion)."""
# NOTE(review): presumably the no-op variant used when the status value is
# already a native string; the code that selects this helper is not visible
# here -- confirm against the full file.
return inp
out = calib_convert(graph_str)
status = to_string(out[0])
output_graph_def_string = out[1]
- del graph_str #save some memory
+ del graph_str # Save some memory
if len(status) < 2:
raise _impl.UnknownError(None, None, status)
if status[:2] != "OK":
msg = status.split(";")
if len(msg) == 1:
raise RuntimeError("Status message is malformed {}".format(status))
+ # pylint: disable=protected-access
raise _impl._make_specific_exception(None, None, ";".join(msg[1:]),
int(msg[0]))
+ # pylint: enable=protected-access
output_graph_def = graph_pb2.GraphDef()
output_graph_def.ParseFromString(output_graph_def_string)
- del output_graph_def_string #save some memory
+ del output_graph_def_string # Save some memory
return output_graph_def
def run_graph(gdef, dumm_inp):
- """Run given graphdef once"""
+ """Run given graphdef once."""
gpu_options = cpb2.GPUOptions(per_process_gpu_memory_fraction=0.50)
ops.reset_default_graph()
g = ops.Graph()
# Use real data that is representative of the inference dataset
-# for calibration. For this test script it is random data
-
-
+# for calibration. For this test script it is random data.
def run_calibration(gdef, dumm_inp):
- """Run given calibration graph multiple times"""
+ """Run given calibration graph multiple times."""
gpu_options = cpb2.GPUOptions(per_process_gpu_memory_fraction=0.50)
ops.reset_default_graph()
g = ops.Graph()