From cf4d066720bbf9c3c79aa62d9b1057939af5ff63 Mon Sep 17 00:00:00 2001 From: gracehoney <31743510+aaroey@users.noreply.github.com> Date: Sat, 10 Feb 2018 22:44:01 -0800 Subject: [PATCH] Fix python lint errors internally. One important change is to rename CreateInferenceGraph to create_inference_graph. --- tensorflow/contrib/tensorrt/README.md | 2 +- tensorflow/contrib/tensorrt/__init__.py | 4 +++ .../contrib/tensorrt/convert/convert_nodes.cc | 1 + tensorflow/contrib/tensorrt/python/__init__.py | 22 +++++++++++++--- .../contrib/tensorrt/python/ops/trt_engine_op.py | 1 + tensorflow/contrib/tensorrt/python/trt_convert.py | 30 ++++++++++++++-------- .../contrib/tensorrt/segment/segment_test.cc | 2 +- tensorflow/contrib/tensorrt/test/test_tftrt.py | 4 +-- 8 files changed, 49 insertions(+), 17 deletions(-) diff --git a/tensorflow/contrib/tensorrt/README.md b/tensorflow/contrib/tensorrt/README.md index 1e9524c..dfcce0f 100644 --- a/tensorflow/contrib/tensorrt/README.md +++ b/tensorflow/contrib/tensorrt/README.md @@ -29,7 +29,7 @@ import tensorflow as tf import tensorflow.contrib.tensorrt as trt #... create and train or load model gdef = sess.graph.as_graph_def() -trt_gdef = trt.CreateInferenceGraph( +trt_gdef = trt.create_inference_graph( gdef, #original graph_def ["output"], #name of output node(s) max_batch_size, #maximum batch size to run the inference diff --git a/tensorflow/contrib/tensorrt/__init__.py b/tensorflow/contrib/tensorrt/__init__.py index 5072ab1..fd551d70 100644 --- a/tensorflow/contrib/tensorrt/__init__.py +++ b/tensorflow/contrib/tensorrt/__init__.py @@ -12,8 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================= +"""Exposes the python wrapper for TensorRT graph transforms.""" + from __future__ import absolute_import from __future__ import division from __future__ import print_function +# pylint: disable=unused-import,wildcard-import from tensorflow.contrib.tensorrt.python import * +# pylint: enable=unused-import,wildcard-import diff --git a/tensorflow/contrib/tensorrt/convert/convert_nodes.cc b/tensorflow/contrib/tensorrt/convert/convert_nodes.cc index 5c22c62..9ee717d 100644 --- a/tensorflow/contrib/tensorrt/convert/convert_nodes.cc +++ b/tensorflow/contrib/tensorrt/convert/convert_nodes.cc @@ -25,6 +25,7 @@ limitations under the License. #include #include "tensorflow/core/framework/node_def_builder.h" +#include "tensorflow/core/framework/tensor_shape.pb.h" // NOLINT #include "tensorflow/core/framework/types.h" #include "tensorflow/core/graph/algorithm.h" #include "tensorflow/core/graph/graph.h" diff --git a/tensorflow/contrib/tensorrt/python/__init__.py b/tensorflow/contrib/tensorrt/python/__init__.py index 4aeea48..7e050a7 100644 --- a/tensorflow/contrib/tensorrt/python/__init__.py +++ b/tensorflow/contrib/tensorrt/python/__init__.py @@ -1,8 +1,24 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================= +"""Exposes the python wrapper for TensorRT graph transforms.""" + from __future__ import absolute_import from __future__ import division from __future__ import print_function -# pylint: disable=unused-import,wildcard-import +# pylint: disable=unused-import,line-too-long from tensorflow.contrib.tensorrt.python.ops import trt_engine_op -from tensorflow.contrib.tensorrt.python.trt_convert import CreateInferenceGraph -# pylint: enable=unused-import,wildcard-import +from tensorflow.contrib.tensorrt.python.trt_convert import create_inference_graph +# pylint: enable=unused-import,line-too-long diff --git a/tensorflow/contrib/tensorrt/python/ops/trt_engine_op.py b/tensorflow/contrib/tensorrt/python/ops/trt_engine_op.py index 97db237..31a3131 100644 --- a/tensorflow/contrib/tensorrt/python/ops/trt_engine_op.py +++ b/tensorflow/contrib/tensorrt/python/ops/trt_engine_op.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================= +"""Exposes the Python wrapper of TRTEngineOp.""" from __future__ import absolute_import from __future__ import division diff --git a/tensorflow/contrib/tensorrt/python/trt_convert.py b/tensorflow/contrib/tensorrt/python/trt_convert.py index 5161831..69bbf45 100644 --- a/tensorflow/contrib/tensorrt/python/trt_convert.py +++ b/tensorflow/contrib/tensorrt/python/trt_convert.py @@ -13,24 +13,26 @@ # limitations under the License. # ============================================================================= """Exposes the Python wrapper conversion to trt_graph.""" + from __future__ import absolute_import from __future__ import division from __future__ import print_function -# pylint: disable=unused-import,wildcard-import, line-too-long +# pylint: disable=unused-import,line-too-long +import six as _six +from tensorflow.contrib.tensorrt.wrap_conversion import trt_convert from tensorflow.core.framework import graph_pb2 from tensorflow.python.framework import errors from tensorflow.python.framework import errors_impl as _impl -from tensorflow.contrib.tensorrt.wrap_conversion import trt_convert from tensorflow.python.framework import ops -import six as _six + # TODO(skama): get outputs from session when implemented as c++ # optimization pass -def CreateInferenceGraph(input_graph_def, - outputs, - max_batch_size=1, - max_workspace_size_bytes=2 << 20): +def create_inference_graph(input_graph_def, + outputs, + max_batch_size=1, + max_workspace_size_bytes=2 << 20): """Python wrapper for the TRT transormation. @@ -42,11 +44,17 @@ def CreateInferenceGraph(input_graph_def, Returns: New GraphDef with TRTEngineOps placed in graph replacing subgraphs. + + Raises: + RuntimeError: if the returned status message is malformed. """ + def py2bytes(inp): return inp + def py3bytes(inp): - return inp.encode('utf-8', errors='surrogateescape') + return inp.encode("utf-8", errors="surrogateescape") + if _six.PY2: to_bytes = py2bytes else: @@ -70,16 +78,18 @@ def CreateInferenceGraph(input_graph_def, max_workspace_size_bytes) status = out[0] output_graph_def_string = to_bytes(out[1]) - del input_graph_def_str #save some memory + del input_graph_def_str # Save some memory if len(status) < 2: raise _impl.UnknownError(None, None, status) if status[:2] != "OK": msg = status.split(";") if len(msg) == 1: raise RuntimeError("Status message is malformed {}".format(status)) + # pylint: disable=protected-access raise _impl._make_specific_exception(None, None, ";".join(msg[1:]), int(msg[0])) + # pylint: enable=protected-access output_graph_def = graph_pb2.GraphDef() output_graph_def.ParseFromString(output_graph_def_string) - del output_graph_def_string #save some memory + del output_graph_def_string # Save some memory return output_graph_def diff --git a/tensorflow/contrib/tensorrt/segment/segment_test.cc b/tensorflow/contrib/tensorrt/segment/segment_test.cc index d7e10c1..93c113e 100644 --- a/tensorflow/contrib/tensorrt/segment/segment_test.cc +++ b/tensorflow/contrib/tensorrt/segment/segment_test.cc @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/c/c_api.h" #include "tensorflow/contrib/tensorrt/segment/segment.h" +#include "tensorflow/c/c_api.h" #include "tensorflow/core/framework/graph.pb.h" #include "tensorflow/core/framework/node_def.pb.h" #include "tensorflow/core/lib/core/errors.h" diff --git a/tensorflow/contrib/tensorrt/test/test_tftrt.py b/tensorflow/contrib/tensorrt/test/test_tftrt.py index ad7a85c..927a3e4 100644 --- a/tensorflow/contrib/tensorrt/test/test_tftrt.py +++ b/tensorflow/contrib/tensorrt/test/test_tftrt.py @@ -63,8 +63,8 @@ if "__main__" in __name__: inpDims = (100, 24, 24, 2) dummy_input = np.random.random_sample(inpDims) gdef = getSimpleGraphDef() - trt_graph = trt.CreateInferenceGraph(gdef, ["output"], - inpDims[0]) # Get optimized graph + trt_graph = trt.create_inference_graph(gdef, ["output"], + inpDims[0]) # Get optimized graph o1 = runGraph(gdef, dummy_input) o2 = runGraph(trt_graph, dummy_input) assert (np.array_equal(o1, o2)) -- 2.7.4