From 35ba5d8dc8899e28ac789dc493f0dd205e169c74 Mon Sep 17 00:00:00 2001 From: Sami Kama Date: Mon, 12 Feb 2018 13:35:23 -0800 Subject: [PATCH] Fix Py3 byte and string issue after swig update. Clarify failure message on finding libnvinfer in configure.py --- configure.py | 21 +++++++++++++++------ tensorflow/contrib/tensorrt/python/trt_convert.py | 12 ++++++++++-- tensorflow/contrib/tensorrt/test/test_tftrt.py | 2 ++ 3 files changed, 27 insertions(+), 8 deletions(-) diff --git a/configure.py b/configure.py index 9cf5902..3aa1a3e 100644 --- a/configure.py +++ b/configure.py @@ -1078,12 +1078,21 @@ def set_tf_tensorrt_install_path(environ_cp): break # Reset and Retry - print('Invalid path to TensorRT. None of the following files can be found:') - print(trt_install_path) - print(os.path.join(trt_install_path, 'lib')) - print(os.path.join(trt_install_path, 'lib64')) - if search_result: - print(libnvinfer_path_from_ldconfig) + if len(possible_files): + print('TensorRT libraries found in one the following directories', + 'are not compatible with selected cuda and cudnn installations') + print(trt_install_path) + print(os.path.join(trt_install_path, 'lib')) + print(os.path.join(trt_install_path, 'lib64')) + if search_result: + print(libnvinfer_path_from_ldconfig) + else: + print('Invalid path to TensorRT. None of the following files can be found:') + print(trt_install_path) + print(os.path.join(trt_install_path, 'lib')) + print(os.path.join(trt_install_path, 'lib64')) + if search_result: + print(libnvinfer_path_from_ldconfig) else: raise UserInputError('Invalid TF_TENSORRT setting was provided %d ' diff --git a/tensorflow/contrib/tensorrt/python/trt_convert.py b/tensorflow/contrib/tensorrt/python/trt_convert.py index 69bbf45..9454862 100644 --- a/tensorflow/contrib/tensorrt/python/trt_convert.py +++ b/tensorflow/contrib/tensorrt/python/trt_convert.py @@ -55,10 +55,18 @@ def create_inference_graph(input_graph_def, def py3bytes(inp): return inp.encode("utf-8", errors="surrogateescape") + def py2string(inp): + return inp + + def py3string(inp): + return inp.decode("utf-8") + if _six.PY2: to_bytes = py2bytes + to_string = py2string else: to_bytes = py3bytes + to_string = py3string out_names = [] for i in outputs: @@ -76,8 +84,8 @@ def create_inference_graph(input_graph_def, # one is the transformed graphs protobuf string. out = trt_convert(input_graph_def_str, out_names, max_batch_size, max_workspace_size_bytes) - status = out[0] - output_graph_def_string = to_bytes(out[1]) + status = to_string(out[0]) + output_graph_def_string = out[1] del input_graph_def_str # Save some memory if len(status) < 2: raise _impl.UnknownError(None, None, status) diff --git a/tensorflow/contrib/tensorrt/test/test_tftrt.py b/tensorflow/contrib/tensorrt/test/test_tftrt.py index 927a3e4..adf3438 100644 --- a/tensorflow/contrib/tensorrt/test/test_tftrt.py +++ b/tensorflow/contrib/tensorrt/test/test_tftrt.py @@ -67,5 +67,7 @@ if "__main__" in __name__: inpDims[0]) # Get optimized graph o1 = runGraph(gdef, dummy_input) o2 = runGraph(trt_graph, dummy_input) + o3 = runGraph(trt_graph, dummy_input) assert (np.array_equal(o1, o2)) + assert (np.array_equal(o2, o3)) print("Pass") -- 2.7.4