Fix python lint errors internally. One important change is to rename CreateInferenceG...
authorgracehoney <31743510+aaroey@users.noreply.github.com>
Sun, 11 Feb 2018 06:44:01 +0000 (22:44 -0800)
committergracehoney <31743510+aaroey@users.noreply.github.com>
Sun, 11 Feb 2018 06:44:01 +0000 (22:44 -0800)
tensorflow/contrib/tensorrt/README.md
tensorflow/contrib/tensorrt/__init__.py
tensorflow/contrib/tensorrt/convert/convert_nodes.cc
tensorflow/contrib/tensorrt/python/__init__.py
tensorflow/contrib/tensorrt/python/ops/trt_engine_op.py
tensorflow/contrib/tensorrt/python/trt_convert.py
tensorflow/contrib/tensorrt/segment/segment_test.cc
tensorflow/contrib/tensorrt/test/test_tftrt.py

index 1e9524c..dfcce0f 100644 (file)
@@ -29,7 +29,7 @@ import tensorflow as tf
 import tensorflow.contrib.tensorrt as trt
 #... create and train or load model
 gdef = sess.graph.as_graph_def()
-trt_gdef = trt.CreateInferenceGraph(
+trt_gdef = trt.create_inference_graph(
     gdef, #original graph_def
     ["output"], #name of output node(s)
     max_batch_size, #maximum batch size to run the inference
index 5072ab1..fd551d7 100644 (file)
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # =============================================================================
+"""Exposes the python wrapper for TensorRT graph transforms."""
+
 from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+# pylint: disable=unused-import,wildcard-import
 from tensorflow.contrib.tensorrt.python import *
+# pylint: enable=unused-import,wildcard-import
index 5c22c62..9ee717d 100644 (file)
@@ -25,6 +25,7 @@ limitations under the License.
 #include <vector>
 
 #include "tensorflow/core/framework/node_def_builder.h"
+#include "tensorflow/core/framework/tensor_shape.pb.h"  // NOLINT
 #include "tensorflow/core/framework/types.h"
 #include "tensorflow/core/graph/algorithm.h"
 #include "tensorflow/core/graph/graph.h"
index 4aeea48..7e050a7 100644 (file)
@@ -1,8 +1,24 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# =============================================================================
+"""Exposes the python wrapper for TensorRT graph transforms."""
+
 from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
-# pylint: disable=unused-import,wildcard-import
+# pylint: disable=unused-import,line-too-long
 from tensorflow.contrib.tensorrt.python.ops import trt_engine_op
-from tensorflow.contrib.tensorrt.python.trt_convert import CreateInferenceGraph
-# pylint: enable=unused-import,wildcard-import
+from tensorflow.contrib.tensorrt.python.trt_convert import create_inference_graph
+# pylint: enable=unused-import,line-too-long
index 97db237..31a3131 100644 (file)
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # =============================================================================
+"""Exposes the Python wrapper of TRTEngineOp."""
 
 from __future__ import absolute_import
 from __future__ import division
index 5161831..69bbf45 100644 (file)
 # limitations under the License.
 # =============================================================================
 """Exposes the Python wrapper conversion to trt_graph."""
+
 from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
-# pylint: disable=unused-import,wildcard-import, line-too-long
+# pylint: disable=unused-import,line-too-long
+import six as _six
+from tensorflow.contrib.tensorrt.wrap_conversion import trt_convert
 from tensorflow.core.framework import graph_pb2
 from tensorflow.python.framework import errors
 from tensorflow.python.framework import errors_impl as _impl
-from tensorflow.contrib.tensorrt.wrap_conversion import trt_convert
 from tensorflow.python.framework import ops
-import six as _six
+
 
 # TODO(skama): get outputs from session when implemented as c++
 # optimization pass
-def CreateInferenceGraph(input_graph_def,
-                         outputs,
-                         max_batch_size=1,
-                         max_workspace_size_bytes=2 << 20):
+def create_inference_graph(input_graph_def,
+                           outputs,
+                           max_batch_size=1,
+                           max_workspace_size_bytes=2 << 20):
   """Python wrapper for the TRT transormation.
 
 
@@ -42,11 +44,17 @@ def CreateInferenceGraph(input_graph_def,
 
   Returns:
     New GraphDef with TRTEngineOps placed in graph replacing subgraphs.
+
+  Raises:
+    RuntimeError: if the returned status message is malformed.
   """
+
   def py2bytes(inp):
     return inp
+
   def py3bytes(inp):
-    return inp.encode('utf-8', errors='surrogateescape')
+    return inp.encode("utf-8", errors="surrogateescape")
+
   if _six.PY2:
     to_bytes = py2bytes
   else:
@@ -70,16 +78,18 @@ def CreateInferenceGraph(input_graph_def,
                     max_workspace_size_bytes)
   status = out[0]
   output_graph_def_string = to_bytes(out[1])
-  del input_graph_def_str  #save some memory
+  del input_graph_def_str  # Save some memory
   if len(status) < 2:
     raise _impl.UnknownError(None, None, status)
   if status[:2] != "OK":
     msg = status.split(";")
     if len(msg) == 1:
       raise RuntimeError("Status message is malformed {}".format(status))
+    # pylint: disable=protected-access
     raise _impl._make_specific_exception(None, None, ";".join(msg[1:]),
                                          int(msg[0]))
+    # pylint: enable=protected-access
   output_graph_def = graph_pb2.GraphDef()
   output_graph_def.ParseFromString(output_graph_def_string)
-  del output_graph_def_string  #save some memory
+  del output_graph_def_string  # Save some memory
   return output_graph_def
index d7e10c1..93c113e 100644 (file)
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
-#include "tensorflow/c/c_api.h"
 #include "tensorflow/contrib/tensorrt/segment/segment.h"
+#include "tensorflow/c/c_api.h"
 #include "tensorflow/core/framework/graph.pb.h"
 #include "tensorflow/core/framework/node_def.pb.h"
 #include "tensorflow/core/lib/core/errors.h"
index ad7a85c..927a3e4 100644 (file)
@@ -63,8 +63,8 @@ if "__main__" in __name__:
   inpDims = (100, 24, 24, 2)
   dummy_input = np.random.random_sample(inpDims)
   gdef = getSimpleGraphDef()
-  trt_graph = trt.CreateInferenceGraph(gdef, ["output"],
-                                       inpDims[0])  # Get optimized graph
+  trt_graph = trt.create_inference_graph(gdef, ["output"],
+                                         inpDims[0])  # Get optimized graph
   o1 = runGraph(gdef, dummy_input)
   o2 = runGraph(trt_graph, dummy_input)
   assert (np.array_equal(o1, o2))