From 7e3e661d35a80afd075db80d0dc7ba5c5f9911a1 Mon Sep 17 00:00:00 2001
From: gracehoney <31743510+aaroey@users.noreply.github.com>
Date: Mon, 14 May 2018 08:27:42 -0700
Subject: [PATCH] Fix various formatting and build issues.

---
 tensorflow/contrib/tensorrt/BUILD                   |   2 +
 .../contrib/tensorrt/convert/convert_nodes.cc       |   4 +-
 .../contrib/tensorrt/custom_plugin_examples/BUILD   |  12 ++-
 .../tensorrt/custom_plugin_examples/__init__.py     |   2 +-
 .../tensorrt/custom_plugin_examples/inc_op.py       |   1 +
 .../custom_plugin_examples/inc_op_kernel.cu.cc      |   3 +-
 .../custom_plugin_examples/inc_op_kernel.h          |   6 +-
 .../custom_plugin_examples/inc_op_plugin.cc         |   3 +-
 .../custom_plugin_examples/inc_op_plugin.h          |   6 +-
 .../tensorrt/custom_plugin_examples/ops/inc_op.cc   |   2 +-
 .../tensorrt/custom_plugin_examples/plugin_test.py  | 102 +++++++++++----------
 .../contrib/tensorrt/kernels/trt_engine_op.cc       |   3 +-
 tensorflow/contrib/tensorrt/plugin/trt_plugin.h     |  10 +-
 .../contrib/tensorrt/plugin/trt_plugin_factory.cc   |   8 +-
 .../contrib/tensorrt/plugin/trt_plugin_factory.h    |  32 +++----
 .../tensorrt/plugin/trt_plugin_factory_test.cc      |  17 ++--
 .../contrib/tensorrt/plugin/trt_plugin_utils.h      |   6 +-
 17 files changed, 115 insertions(+), 104 deletions(-)

diff --git a/tensorflow/contrib/tensorrt/BUILD b/tensorflow/contrib/tensorrt/BUILD
index 467c962..7a8a71a 100644
--- a/tensorflow/contrib/tensorrt/BUILD
+++ b/tensorflow/contrib/tensorrt/BUILD
@@ -302,6 +302,7 @@ tf_cuda_library(
         "plugin/trt_plugin_utils.h",
     ],
     deps = [
+        "//tensorflow/core:framework_lite",
        "//tensorflow/core:platform_base",
     ] + if_tensorrt([
         "@local_config_tensorrt//:nv_infer",
@@ -318,6 +319,7 @@ tf_cuda_cc_test(
     ],
     deps = [
         ":trt_plugins",
+        "//tensorflow/core:lib",
         "//tensorflow/core:test",
         "//tensorflow/core:test_main",
     ] + if_tensorrt([
diff --git a/tensorflow/contrib/tensorrt/convert/convert_nodes.cc b/tensorflow/contrib/tensorrt/convert/convert_nodes.cc
index f043237..32b211d 100644
--- a/tensorflow/contrib/tensorrt/convert/convert_nodes.cc
+++ b/tensorflow/contrib/tensorrt/convert/convert_nodes.cc
@@ -1223,8 +1223,8 @@ tensorflow::Status ConvertPlugin(Converter& ctx,
     }
   }
 
-  nvinfer1::IPluginLayer* layer =
-      ctx.network()->addPlugin(&all_inputs[0], int(inputs.size()), *plugin);
+  nvinfer1::IPluginLayer* layer = ctx.network()->addPlugin(
+      &all_inputs[0], static_cast<int>(inputs.size()), *plugin);
 
   for (int i = 0; i < layer->getNbOutputs(); i++) {
     nvinfer1::ITensor* output_tensor = layer->getOutput(i);
diff --git a/tensorflow/contrib/tensorrt/custom_plugin_examples/BUILD b/tensorflow/contrib/tensorrt/custom_plugin_examples/BUILD
index 6f81ac2..a89cf3a 100644
--- a/tensorflow/contrib/tensorrt/custom_plugin_examples/BUILD
+++ b/tensorflow/contrib/tensorrt/custom_plugin_examples/BUILD
@@ -6,18 +6,18 @@
 
 package(default_visibility = ["//tensorflow:__subpackages__"])
 
+licenses(["notice"])  # Apache 2.0
+
 load(
     "//tensorflow:tensorflow.bzl",
-    "tf_copts",
     "tf_custom_op_library",
     "tf_custom_op_library_additional_deps",
     "tf_gen_op_libs",
     "tf_gen_op_wrapper_py",
     "tf_kernel_library",
 )
+load("//tensorflow:tensorflow.bzl", "cuda_py_test")
 load("//tensorflow:tensorflow.bzl", "tf_custom_op_py_library")
-load("//tensorflow:tensorflow.bzl", "tf_py_test")
-load("//tensorflow:tensorflow.bzl", "tf_py_wrap_cc")
 load(
     "@local_config_tensorrt//:build_defs.bzl",
     "if_tensorrt",
@@ -46,6 +46,7 @@ tf_custom_op_library(
     ],
     deps = [
         "//tensorflow/contrib/tensorrt:trt_plugins",
+        "//tensorflow/core:framework_lite",
     ] + if_tensorrt([
         "@local_config_tensorrt//:nv_infer",
     ]),
@@ -55,6 +56,7 @@ tf_kernel_library(
     name = "inc_op_plugin_kernel",
     srcs = ["inc_op_plugin.cc"],
     hdrs = [
+        "inc_op_kernel.h",
         "inc_op_plugin.h",
     ],
     gpu_srcs = [
@@ -63,6 +65,7 @@ tf_kernel_library(
     ],
     deps = [
         "//tensorflow/contrib/tensorrt:trt_plugins",
+        "//tensorflow/core:stream_executor_headers_lib",
     ] + if_tensorrt([
         "@local_config_tensorrt//:nv_infer",
     ]) + tf_custom_op_library_additional_deps(),
@@ -95,7 +98,7 @@ py_library(
     ],
 )
 
-tf_py_test(
+cuda_py_test(
     name = "plugin_test",
     size = "small",
     srcs = ["plugin_test.py"],
@@ -109,6 +112,7 @@ tf_py_test(
     ],
     tags = [
         "manual",
+        "noguitar",
         "notap",
     ],
 )
diff --git a/tensorflow/contrib/tensorrt/custom_plugin_examples/__init__.py b/tensorflow/contrib/tensorrt/custom_plugin_examples/__init__.py
index e06904a..363edab 100644
--- a/tensorflow/contrib/tensorrt/custom_plugin_examples/__init__.py
+++ b/tensorflow/contrib/tensorrt/custom_plugin_examples/__init__.py
@@ -18,7 +18,7 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
-from tensorflow.contrib.tensorrt.custom_plugin_examples.ops import gen_inc_op
 from tensorflow.contrib.tensorrt.custom_plugin_examples import inc_op as import_inc_op_so
+from tensorflow.contrib.tensorrt.custom_plugin_examples.ops import gen_inc_op
 
 inc_op = gen_inc_op.inc_plugin_trt
diff --git a/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op.py b/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op.py
index 47fd55e..a007c3f 100644
--- a/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op.py
+++ b/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op.py
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # =============================================================================
+"""Loader for the custom inc_op."""
 
 from __future__ import absolute_import
 from __future__ import division
diff --git a/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_kernel.cu.cc b/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_kernel.cu.cc
index abbc0c5..988b35f 100644
--- a/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_kernel.cu.cc
+++ b/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_kernel.cu.cc
@@ -18,12 +18,11 @@ limitations under the License.
 #include 
 
 #include "tensorflow/core/framework/op_kernel.h"
-#include "tensorflow/core/platform/stream_executor.h"
 
 #if GOOGLE_CUDA
 #if GOOGLE_TENSORRT
 #include "cuda/include/cuda_runtime_api.h"
-
+#include "tensorflow/core/platform/stream_executor.h"
 
 namespace tensorflow {
 namespace tensorrt {
diff --git a/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_kernel.h b/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_kernel.h
index 1d0ec0b..c35955e 100644
--- a/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_kernel.h
+++ b/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_kernel.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
-#ifndef TENSORFLOW_CONTRIB_TENSORRT_INC_OP
-#define TENSORFLOW_CONTRIB_TENSORRT_INC_OP
+#ifndef TENSORFLOW_CONTRIB_TENSORRT_CUSTOM_PLUGIN_EXAMPLES_INC_OP_KERNEL_H_
+#define TENSORFLOW_CONTRIB_TENSORRT_CUSTOM_PLUGIN_EXAMPLES_INC_OP_KERNEL_H_
 
 #if GOOGLE_CUDA
 #if GOOGLE_TENSORRT
@@ -32,4 +32,4 @@ void IncrementKernel(const float* d_input, float inc, float* d_output,
 #endif  // GOOGLE_TENSORRT
 #endif  // GOOGLE_CUDA
 
-#endif  // TENSORFLOW_CONTRIB_TENSORRT_INC_OP
+#endif  // TENSORFLOW_CONTRIB_TENSORRT_CUSTOM_PLUGIN_EXAMPLES_INC_OP_KERNEL_H_
diff --git a/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_plugin.cc b/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_plugin.cc
index d56aedc..8d4c893 100644
--- a/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_plugin.cc
+++ b/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_plugin.cc
@@ -13,8 +13,9 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
-#include "tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_kernel.h"
 #include "tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_plugin.h"
+
+#include "tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_kernel.h"
 #include "tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.h"
 
 #if GOOGLE_CUDA
diff --git a/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_plugin.h b/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_plugin.h
index 6015354..189e9c9 100644
--- a/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_plugin.h
+++ b/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_plugin.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
-#ifndef TENSORFLOW_CONTRIB_TENSORRT_INC_OP_PLUGIN
-#define TENSORFLOW_CONTRIB_TENSORRT_INC_OP_PLUGIN
+#ifndef TENSORFLOW_CONTRIB_TENSORRT_CUSTOM_PLUGIN_EXAMPLES_INC_OP_PLUGIN_H_
+#define TENSORFLOW_CONTRIB_TENSORRT_CUSTOM_PLUGIN_EXAMPLES_INC_OP_PLUGIN_H_
 
 #include 
 #include 
@@ -99,4 +99,4 @@ class IncOpPlugin : public PluginTensorRT {
 #endif  // GOOGLE_TENSORRT
 #endif  // GOOGLE_CUDA
 
-#endif  // TENSORFLOW_CONTRIB_TENSORRT_INC_OP_PLUGIN
+#endif  // TENSORFLOW_CONTRIB_TENSORRT_CUSTOM_PLUGIN_EXAMPLES_INC_OP_PLUGIN_H_
diff --git a/tensorflow/contrib/tensorrt/custom_plugin_examples/ops/inc_op.cc b/tensorflow/contrib/tensorrt/custom_plugin_examples/ops/inc_op.cc
index 7466e59..d0eb0d2 100644
--- a/tensorflow/contrib/tensorrt/custom_plugin_examples/ops/inc_op.cc
+++ b/tensorflow/contrib/tensorrt/custom_plugin_examples/ops/inc_op.cc
@@ -30,7 +30,7 @@ REGISTER_OP("IncPluginTRT")
       return Status::OK();
     });
 
-} // namespace tensorflow
+}  // namespace tensorflow
 
 #endif  // GOOGLE_CUDA
 #endif  // GOOGLE_TENSORRT
diff --git a/tensorflow/contrib/tensorrt/custom_plugin_examples/plugin_test.py b/tensorflow/contrib/tensorrt/custom_plugin_examples/plugin_test.py
index aedfb16..bc4d270 100644
--- a/tensorflow/contrib/tensorrt/custom_plugin_examples/plugin_test.py
+++ b/tensorflow/contrib/tensorrt/custom_plugin_examples/plugin_test.py
@@ -27,65 +27,69 @@ from tensorflow.python.client import session
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import importer
 from tensorflow.python.framework import ops
+from tensorflow.python.framework import test_util
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import nn
 from tensorflow.python.ops import nn_ops
+from tensorflow.python.platform import test
 
 
-def get_plugin_graph_def():
-  """Create a simple graph and return its graph_def."""
-  g = ops.Graph()
-  with g.as_default():
-    a = array_ops.placeholder(
-        dtype=dtypes.float32, shape=(None, 24, 24, 2), name="input")
-    relu = nn.relu(a, "relu")
-    v = nn_ops.max_pool(
-        relu, [1, 2, 2, 1], [1, 2, 2, 1], "VALID", name="max_pool")
+class TrtPluginTest(test_util.TensorFlowTestCase):
 
-    # insert custom_op in the graph
-    v = custom_plugin_examples.inc_op(v, inc=[16.5], name="plugin_test")
+  def _get_plugin_graph_def(self):
+    """Create a simple graph and return its graph_def."""
+    g = ops.Graph()
+    with g.as_default():
+      a = array_ops.placeholder(
+          dtype=dtypes.float32, shape=(None, 24, 24, 2), name="input")
+      relu = nn.relu(a, "relu")
+      v = nn_ops.max_pool(
+          relu, [1, 2, 2, 1], [1, 2, 2, 1], "VALID", name="max_pool")
 
-    v = v * 2.0
-    v = nn.relu(v)
-    v = nn.relu(v)
-    array_ops.squeeze(v, name="output")
-  return g.as_graph_def()
+      # insert custom_op in the graph
+      v = custom_plugin_examples.inc_op(v, inc=[16.5], name="plugin_test")
+      v *= 2.0
+      v = nn.relu(v)
+      v = nn.relu(v)
+      array_ops.squeeze(v, name="output")
+    return g.as_graph_def()
 
-def run_graph(gdef, dumm_inp):
-  """Run given graphdef once."""
-  gpu_options = config_pb2.GPUOptions(per_process_gpu_memory_fraction=0.50)
-  ops.reset_default_graph()
-  g = ops.Graph()
-  with g.as_default():
-    inp, out = importer.import_graph_def(
-        graph_def=gdef, return_elements=["input", "output"])
-    inp = inp.outputs[0]
-    out = out.outputs[0]
+  def _run_graph(self, gdef, dumm_inp):
+    """Run given graphdef once."""
+    gpu_options = config_pb2.GPUOptions(per_process_gpu_memory_fraction=0.50)
+    ops.reset_default_graph()
+    g = ops.Graph()
+    with g.as_default():
+      inp, out = importer.import_graph_def(
+          graph_def=gdef, return_elements=["input", "output"])
+      inp = inp.outputs[0]
+      out = out.outputs[0]
 
-  with session.Session(
-      config=config_pb2.ConfigProto(gpu_options=gpu_options), graph=g) as sess:
-    val = sess.run(out, {inp: dumm_inp})
-  return val
+    with session.Session(
+        config=config_pb2.ConfigProto(gpu_options=gpu_options),
+        graph=g) as sess:
+      val = sess.run(out, {inp: dumm_inp})
+    return val
 
+  def testIncOpPlugin(self):
+    inp_dims = (5, 24, 24, 2)
+    dummy_input = numpy.ones(inp_dims).astype(numpy.float32)
+    orig_graph = self._get_plugin_graph_def()  # graph with plugin node
 
-if "__main__" in __name__:
-  inp_dims = (5, 24, 24, 2)
-  dummy_input = numpy.ones(inp_dims).astype(numpy.float32)
-  orig_graph = get_plugin_graph_def()  # graph with plugin node
+    # trigger conversion.
+    # plugin nodes have been registered during import, converter will be able to
+    # create corresponding plugin layer during conversion.
+    trt_graph = tensorrt.create_inference_graph(
+        input_graph_def=orig_graph,
+        outputs=["output"],
+        max_batch_size=inp_dims[0],
+        max_workspace_size_bytes=1 << 25,
+        precision_mode="FP32",
+        minimum_segment_size=2)
+    o2 = self._run_graph(trt_graph, dummy_input)
+    self.assertEqual(35, o2.reshape([-1])[0])
 
-  # trigger conversion.
-  # plugin nodes have been registered during import, converter will be able to
-  # create corresponding plugin layer during conversion.
-  trt_graph = tensorrt.create_inference_graph(
-      input_graph_def=orig_graph,
-      outputs=["output"],
-      max_batch_size=inp_dims[0],
-      max_workspace_size_bytes=1 << 25,
-      precision_mode="FP32",
-      minimum_segment_size=2)
-  o2 = run_graph(trt_graph, dummy_input)
-  if o2.reshape([-1])[0] == 35:
-    print("pass")
-  else:
-    raise RuntimeError("contrib/tensorrt/custom_plugin_examples wrong result")
+
+if __name__ == "__main__":
+  test.main()
diff --git a/tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc b/tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc
index d84fc8a..9ac8047 100644
--- a/tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc
+++ b/tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc
@@ -60,7 +60,8 @@ void TRTEngineOp::Compute(OpKernelContext* context) {
   infer->setGpuAllocator(allocator_.get());
 #endif
   trt_engine_ptr_.reset(infer->deserializeCudaEngine(
-      serialized_engine_.c_str(), serialized_engine_.size(), PluginFactoryTensorRT::GetInstance()));
+      serialized_engine_.c_str(), serialized_engine_.size(),
+      PluginFactoryTensorRT::GetInstance()));
   trt_execution_context_ptr_.reset(trt_engine_ptr_->createExecutionContext());
   // Runtime is safe to delete after engine creation
   infer->destroy();
diff --git a/tensorflow/contrib/tensorrt/plugin/trt_plugin.h b/tensorflow/contrib/tensorrt/plugin/trt_plugin.h
index d80ec44..754920b 100644
--- a/tensorflow/contrib/tensorrt/plugin/trt_plugin.h
+++ b/tensorflow/contrib/tensorrt/plugin/trt_plugin.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
-#ifndef TENSORFLOW_CONTRIB_TENSORRT_PLUGIN_TRT_PLUGIN
-#define TENSORFLOW_CONTRIB_TENSORRT_PLUGIN_TRT_PLUGIN
+#ifndef TENSORFLOW_CONTRIB_TENSORRT_PLUGIN_TRT_PLUGIN_H_
+#define TENSORFLOW_CONTRIB_TENSORRT_PLUGIN_TRT_PLUGIN_H_
 
 #include 
 #include 
@@ -55,9 +55,9 @@ class PluginTensorRT : public nvinfer1::IPlugin {
   virtual bool StoreAttribute(const string& key, const void* ptr,
                               const size_t size);
 
-  virtual size_t getSerializationSize() override;
+  size_t getSerializationSize() override;
 
-  virtual void serialize(void* buffer) override;
+  void serialize(void* buffer) override;
 
  protected:
   std::unordered_map<string, std::vector<char> > attr_map_;
@@ -71,4 +71,4 @@ class PluginTensorRT : public nvinfer1::IPlugin {
 #endif  // GOOGLE_TENSORRT
 #endif  // GOOGLE_CUDA
 
-#endif  // TENSORFLOW_CONTRIB_TENSORRT_PLUGIN_TRT_PLUGIN
+#endif  // TENSORFLOW_CONTRIB_TENSORRT_PLUGIN_TRT_PLUGIN_H_
diff --git a/tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.cc b/tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.cc
index 736a132..2bc5914 100644
--- a/tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.cc
+++ b/tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.cc
@@ -33,7 +33,7 @@ PluginTensorRT* PluginFactoryTensorRT::createPlugin(const char* layer_name,
     return nullptr;
   }
 
-  std::lock_guard<std::mutex> lock(instance_m_);
+  tensorflow::mutex_lock lock(instance_m_);
   auto plugin_ptr =
       plugin_registry_[encoded_op_name].first(serial_data, serial_length);
   owned_plugins_.emplace_back(plugin_ptr);
@@ -44,7 +44,7 @@ PluginTensorRT* PluginFactoryTensorRT::CreatePlugin(const string& op_name) {
   if (!IsPlugin(op_name)) return nullptr;
 
-  std::lock_guard<std::mutex> lock(instance_m_);
+  tensorflow::mutex_lock lock(instance_m_);
   auto plugin_ptr = plugin_registry_[op_name].second();
   owned_plugins_.emplace_back(plugin_ptr);
@@ -56,7 +56,7 @@ bool PluginFactoryTensorRT::RegisterPlugin(
     PluginConstructFunc construct_func) {
   if (IsPlugin(op_name)) return false;
 
-  std::lock_guard<std::mutex> lock(instance_m_);
+  tensorflow::mutex_lock lock(instance_m_);
   auto ret = plugin_registry_.emplace(
       op_name, std::make_pair(deserialize_func, construct_func));
 
@@ -64,7 +64,7 @@ bool PluginFactoryTensorRT::RegisterPlugin(
 }
 
 void PluginFactoryTensorRT::DestroyPlugins() {
-  std::lock_guard<std::mutex> lock(instance_m_);
+  tensorflow::mutex_lock lock(instance_m_);
   for (auto& owned_plugin_ptr : owned_plugins_) {
     owned_plugin_ptr.release();
   }
diff --git a/tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.h b/tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.h
index 0eee705..bbae9fb 100644
--- a/tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.h
+++ b/tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.h
@@ -13,17 +13,17 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
-#ifndef TENSORFLOW_CONTRIB_TENSORRT_PLUGIN_TRT_PLUGIN_FACTORY
-#define TENSORFLOW_CONTRIB_TENSORRT_PLUGIN_TRT_PLUGIN_FACTORY
+#ifndef TENSORFLOW_CONTRIB_TENSORRT_PLUGIN_TRT_PLUGIN_FACTORY_H_
+#define TENSORFLOW_CONTRIB_TENSORRT_PLUGIN_TRT_PLUGIN_FACTORY_H_
 
 #include <memory>
-#include <mutex>
 #include <unordered_map>
 
 #include "tensorflow/contrib/tensorrt/plugin/trt_plugin.h"
 #include "tensorflow/contrib/tensorrt/plugin/trt_plugin_utils.h"
 #include "tensorflow/core/platform/logging.h"
 #include "tensorflow/core/platform/macros.h"
+#include "tensorflow/core/platform/mutex.h"
 
 #if GOOGLE_CUDA
 #if GOOGLE_TENSORRT
@@ -69,13 +69,12 @@ class PluginFactoryTensorRT : public nvinfer1::IPluginFactory {
   // TODO(jie): Owned plugin should be associated with different sessions;
   // should really hand ownership of plugins to resource management;
   std::vector<std::unique_ptr<PluginTensorRT>> owned_plugins_;
-  std::mutex instance_m_;
+  tensorflow::mutex instance_m_;
 };
 
 class TrtPluginRegistrar {
  public:
-  TrtPluginRegistrar(const string& name,
-                     PluginDeserializeFunc deserialize_func,
+  TrtPluginRegistrar(const string& name, PluginDeserializeFunc deserialize_func,
                      PluginConstructFunc construct_func) {
     auto factory = PluginFactoryTensorRT::GetInstance();
     QCHECK(factory->RegisterPlugin(name, deserialize_func, construct_func))
@@ -83,17 +82,16 @@ class TrtPluginRegistrar {
   }
 };
 
-#define REGISTER_TRT_PLUGIN(name, deserialize_func, construct_func) \
-  REGISTER_TRT_PLUGIN_UNIQ_HELPER(                                  \
-      __COUNTER__, name, deserialize_func, construct_func)
-#define REGISTER_TRT_PLUGIN_UNIQ_HELPER(                            \
-    ctr, name, deserialize_func, construct_func)                    \
-  REGISTER_TRT_PLUGIN_UNIQ(ctr, name, deserialize_func, construct_func)
+#define REGISTER_TRT_PLUGIN(name, deserialize_func, construct_func)    \
+  REGISTER_TRT_PLUGIN_UNIQ_HELPER(__COUNTER__, name, deserialize_func, \
+                                  construct_func)
+#define REGISTER_TRT_PLUGIN_UNIQ_HELPER(ctr, name, deserialize_func, \
+                                        construct_func)              \
+  REGISTER_TRT_PLUGIN_UNIQ(ctr, name, deserialize_func, construct_func)
 #define REGISTER_TRT_PLUGIN_UNIQ(ctr, name, deserialize_func, construct_func) \
-  static ::tensorflow::tensorrt::TrtPluginRegistrar \
-      trt_plugin_registrar##ctr TF_ATTRIBUTE_UNUSED = \
-          ::tensorflow::tensorrt::TrtPluginRegistrar( \
-              name, deserialize_func, construct_func)
+  static ::tensorflow::tensorrt::TrtPluginRegistrar trt_plugin_registrar##ctr \
+      TF_ATTRIBUTE_UNUSED = ::tensorflow::tensorrt::TrtPluginRegistrar(       \
+          name, deserialize_func, construct_func)
 
 }  // namespace tensorrt
 }  // namespace tensorflow
@@ -101,4 +99,4 @@ class TrtPluginRegistrar {
 #endif  // GOOGLE_TENSORRT
 #endif  // GOOGLE_CUDA
 
-#endif  // TENSORFLOW_CONTRIB_TENSORRT_PLUGIN_TRT_PLUGIN_FACTORY
+#endif  // TENSORFLOW_CONTRIB_TENSORRT_PLUGIN_TRT_PLUGIN_FACTORY_H_
diff --git a/tensorflow/contrib/tensorrt/plugin/trt_plugin_factory_test.cc b/tensorflow/contrib/tensorrt/plugin/trt_plugin_factory_test.cc
index c5b0e75..129bdcd 100644
--- a/tensorflow/contrib/tensorrt/plugin/trt_plugin_factory_test.cc
+++ b/tensorflow/contrib/tensorrt/plugin/trt_plugin_factory_test.cc
@@ -13,8 +13,9 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
-#include "tensorflow/contrib/tensorrt/plugin/trt_plugin.h"
 #include "tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.h"
+
+#include "tensorflow/contrib/tensorrt/plugin/trt_plugin.h"
 #include "tensorflow/core/lib/core/errors.h"
 #include "tensorflow/core/lib/core/status.h"
 #include "tensorflow/core/platform/test.h"
@@ -37,16 +38,17 @@ class StubPlugin : public PluginTensorRT {
   StubPlugin(const void* serialized_data, size_t length)
       : PluginTensorRT(serialized_data, length) {}
 
-  const string& GetPluginName() override { return plugin_name_; }
+  const string& GetPluginName() const override { return plugin_name_; }
 
-  virtual bool Finalize() { return true; }
+  bool Finalize() override { return true; }
 
-  virtual bool SetAttribute(const string& key, const void* ptr,
-                            const size_t size) {
+  bool SetAttribute(const string& key, const void* ptr,
+                    const size_t size) override {
     return true;
   }
 
-  virtual bool GetAttribute(const string& key, const void* ptr, size_t& size) {
+  bool GetAttribute(const string& key, const void** ptr,
+                    size_t* size) const override {
     return true;
   }
 
@@ -89,8 +91,7 @@ class TrtPluginFactoryTest : public ::testing::Test {
       return true;
     }
     return PluginFactoryTensorRT::GetInstance()->RegisterPlugin(
-        StubPlugin::kPluginName, CreateStubPluginDeserialize,
-        CreateStubPlugin);
+        StubPlugin::kPluginName, CreateStubPluginDeserialize, CreateStubPlugin);
   }
 };
diff --git a/tensorflow/contrib/tensorrt/plugin/trt_plugin_utils.h b/tensorflow/contrib/tensorrt/plugin/trt_plugin_utils.h
index 4ff6fbe..274ce42 100644
--- a/tensorflow/contrib/tensorrt/plugin/trt_plugin_utils.h
+++ b/tensorflow/contrib/tensorrt/plugin/trt_plugin_utils.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
-#ifndef TENSORFLOW_CONTRIB_TENSORRT_PLUGIN_TRT_PLUGIN_UTILS
-#define TENSORFLOW_CONTRIB_TENSORRT_PLUGIN_TRT_PLUGIN_UTILS
+#ifndef TENSORFLOW_CONTRIB_TENSORRT_PLUGIN_TRT_PLUGIN_UTILS_H_
+#define TENSORFLOW_CONTRIB_TENSORRT_PLUGIN_TRT_PLUGIN_UTILS_H_
 
 #include 
 
@@ -43,4 +43,4 @@ string ExtractOpName(const void* serial_data, size_t serial_length,
 #endif  // GOOGLE_TENSORRT
 #endif  // GOOGLE_CUDA
 
-#endif  // TENSORFLOW_CONTRIB_TENSORRT_PLUGIN_TRT_PLUGIN_UTILS
+#endif  // TENSORFLOW_CONTRIB_TENSORRT_PLUGIN_TRT_PLUGIN_UTILS_H_
-- 
2.7.4