from tensorflow.core.framework import graph_pb2
from tensorflow.python.framework import errors
from tensorflow.python.framework import errors_impl as _impl
-from tensorflow.contrib.tensorrt.wrap_conversion import trt_convert,calib_convert
+from tensorflow.contrib.tensorrt.wrap_conversion import trt_convert, calib_convert
from tensorflow.python.util import compat
import tensorflow as tf
from tensorflow.python.grappler import tf_optimizer
from tensorflow.python.framework import ops
-from tensorflow.python.framework import ops
-
-
# TODO(skama): get outputs from session when implemented as c++
# optimization pass
def create_inference_graph(input_graph_def,
Raises:
RuntimeError: if the returned status message is malformed.
"""
- supported_precision_modes={"FP32":0,
- "FP16":1,
- "INT8":2}
+ supported_precision_modes = {"FP32": 0,
+ "FP16": 1,
+ "INT8": 2}
if precision_mode.upper() not in supported_precision_modes:
raise ValueError(("precision mode '{}' is not supported."
- "It should be one of {}").format(precision_mode,"{'FP32','FP16','INT8'}"))
- mode=supported_precision_modes[precision_mode.upper()]
+ "It should be one of {}").format(precision_mode,
+ "{'FP32', 'FP16', 'INT8'}"))
+ mode = supported_precision_modes[precision_mode.upper()]
  def py2bytes(inp):
    # Python 2: str objects are already byte strings, so this is an identity
    # conversion.
    return inp
  # pair of strings where the first one is the encoded status and the second
  # one is the transformed graph's protobuf string.
out = trt_convert(input_graph_def_str, out_names, max_batch_size,
- max_workspace_size_bytes,mode,minimum_segment_size)
+ max_workspace_size_bytes, mode, minimum_segment_size)
status = to_string(out[0])
output_graph_def_string = out[1]
del input_graph_def_str # Save some memory
return output_graph_def
def calib_graph_to_infer_graph(calibration_graph_def):
+ """Convert an existing calibration graph containing calibration data
+ to an inference graph."""
  def py2bytes(inp):
    # Python 2: str objects are already byte strings, so this is an identity
    # conversion.
    return inp
return inp.decode("utf-8")
if _six.PY2:
- to_bytes = py2bytes
to_string = py2string
else:
- to_bytes = py3bytes
to_string = py3string
- graph_str=calibration_graph_def.SerializeToString()
- out=calib_convert(graph_str)
- status=to_string(out[0])
+ graph_str = calibration_graph_def.SerializeToString()
+ out = calib_convert(graph_str)
+ status = to_string(out[0])
output_graph_def_string = out[1]
  del graph_str  # Save some memory.
if len(status) < 2:
- raise _impl.UnknownError(None,None,status)
+ raise _impl.UnknownError(None, None, status)
if status[:2] != "OK":
- msg=status.split(";")
+ msg = status.split(";")
if len(msg) == 1:
raise RuntimeError("Status message is malformed {}".format(status))
raise _impl._make_specific_exception(None,None,";".join(msg[1:]), int(msg[0]))
out = out.outputs[0]
with csess.Session(
config=cpb2.ConfigProto(gpu_options=gpu_options), graph=g) as sess:
- for _ in range(30):
+ # Run over real calibration data here; we are mimicking a calibration
+ # set of 30 different batches. Use as much calibration data as you want.
+ for _ in range(30):
val = sess.run(out, {inp: dumm_inp})
return val
outputs=["output"],
max_batch_size=inp_dims[0],
max_workspace_size_bytes=1 << 25,
- precision_mode="INt8", # TRT Engine precision "FP32","FP16" or "INT8"
+ precision_mode="INT8", # TRT Engine precision "FP32","FP16" or "INT8"
minimum_segment_size=2 # minimum number of nodes in an engine
)
o4 = run_graph(fp16_graph, dummy_input)