import tensorflow.contrib.tensorrt as trt
#... create and train or load model
gdef = sess.graph.as_graph_def()
-trt_gdef = trt.CreateInferenceGraph(
+trt_gdef = trt.create_inference_graph(
gdef, #original graph_def
["output"], #name of output node(s)
max_batch_size, #maximum batch size to run the inference
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
+"""Exposes the python wrapper for TensorRT graph transforms."""
+
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+# pylint: disable=unused-import,wildcard-import
from tensorflow.contrib.tensorrt.python import *
+# pylint: enable=unused-import,wildcard-import
#include <vector>
#include "tensorflow/core/framework/node_def_builder.h"
+#include "tensorflow/core/framework/tensor_shape.pb.h" // NOLINT
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/graph/graph.h"
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# =============================================================================
+"""Exposes the python wrapper for TensorRT graph transforms."""
+
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-# pylint: disable=unused-import,wildcard-import
+# pylint: disable=unused-import,line-too-long
from tensorflow.contrib.tensorrt.python.ops import trt_engine_op
-from tensorflow.contrib.tensorrt.python.trt_convert import CreateInferenceGraph
-# pylint: enable=unused-import,wildcard-import
+from tensorflow.contrib.tensorrt.python.trt_convert import create_inference_graph
+# pylint: enable=unused-import,line-too-long
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
+"""Exposes the Python wrapper of TRTEngineOp."""
from __future__ import absolute_import
from __future__ import division
# limitations under the License.
# =============================================================================
"""Exposes the Python wrapper conversion to trt_graph."""
+
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-# pylint: disable=unused-import,wildcard-import, line-too-long
+# pylint: disable=unused-import,line-too-long
+import six as _six
+from tensorflow.contrib.tensorrt.wrap_conversion import trt_convert
from tensorflow.core.framework import graph_pb2
from tensorflow.python.framework import errors
from tensorflow.python.framework import errors_impl as _impl
-from tensorflow.contrib.tensorrt.wrap_conversion import trt_convert
from tensorflow.python.framework import ops
-import six as _six
+
# TODO(skama): get outputs from session when implemented as c++
# optimization pass
-def CreateInferenceGraph(input_graph_def,
- outputs,
- max_batch_size=1,
- max_workspace_size_bytes=2 << 20):
+def create_inference_graph(input_graph_def,
+ outputs,
+ max_batch_size=1,
+ max_workspace_size_bytes=2 << 20):
  """Python wrapper for the TRT transformation.
Returns:
New GraphDef with TRTEngineOps placed in graph replacing subgraphs.
+
+ Raises:
+ RuntimeError: if the returned status message is malformed.
"""
+
def py2bytes(inp):
return inp
+
def py3bytes(inp):
- return inp.encode('utf-8', errors='surrogateescape')
+ return inp.encode("utf-8", errors="surrogateescape")
+
if _six.PY2:
to_bytes = py2bytes
else:
max_workspace_size_bytes)
status = out[0]
output_graph_def_string = to_bytes(out[1])
- del input_graph_def_str #save some memory
+ del input_graph_def_str # Save some memory
if len(status) < 2:
raise _impl.UnknownError(None, None, status)
if status[:2] != "OK":
msg = status.split(";")
if len(msg) == 1:
raise RuntimeError("Status message is malformed {}".format(status))
+ # pylint: disable=protected-access
raise _impl._make_specific_exception(None, None, ";".join(msg[1:]),
int(msg[0]))
+ # pylint: enable=protected-access
output_graph_def = graph_pb2.GraphDef()
output_graph_def.ParseFromString(output_graph_def_string)
- del output_graph_def_string #save some memory
+ del output_graph_def_string # Save some memory
return output_graph_def
limitations under the License.
==============================================================================*/
-#include "tensorflow/c/c_api.h"
#include "tensorflow/contrib/tensorrt/segment/segment.h"
+#include "tensorflow/c/c_api.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/lib/core/errors.h"
inpDims = (100, 24, 24, 2)
dummy_input = np.random.random_sample(inpDims)
gdef = getSimpleGraphDef()
- trt_graph = trt.CreateInferenceGraph(gdef, ["output"],
- inpDims[0]) # Get optimized graph
+ trt_graph = trt.create_inference_graph(gdef, ["output"],
+ inpDims[0]) # Get optimized graph
o1 = runGraph(gdef, dummy_input)
o2 = runGraph(trt_graph, dummy_input)
assert (np.array_equal(o1, o2))