From: Guangda Lai Date: Thu, 26 Apr 2018 20:12:04 +0000 (-0700) Subject: Fix build by adding op_lib dependencies to trt_engine_op_loader, and remove X-Git-Tag: upstream/v1.9.0_rc1~206^2~1^2~28 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=5f06514bff4061b839ee71847a299adbef9e7e03;p=platform%2Fupstream%2Ftensorflow.git Fix build by adding op_lib dependencies to trt_engine_op_loader, and remove unnecessary dependency from the tf_gen_op_libs. PiperOrigin-RevId: 194442728 --- diff --git a/tensorflow/contrib/tensorrt/BUILD b/tensorflow/contrib/tensorrt/BUILD index f80b4f1..742be7b 100644 --- a/tensorflow/contrib/tensorrt/BUILD +++ b/tensorflow/contrib/tensorrt/BUILD @@ -102,9 +102,6 @@ tf_gen_op_libs( "trt_engine_op", "trt_calib_op", ], - deps = if_tensorrt([ - "@local_config_tensorrt//:nv_infer", - ]), ) tf_cuda_library( @@ -138,6 +135,12 @@ tf_custom_op_py_library( ] + if_tensorrt([ "@local_config_tensorrt//:nv_infer", ]), + kernels = [ + ":trt_engine_op_kernel", + ":trt_engine_op_op_lib", + ":trt_calib_op_op_lib", + ":trt_shape_function", + ], srcs_version = "PY2AND3", deps = [ "//tensorflow/contrib/util:util_py", diff --git a/tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc b/tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc index 53ba7ba..b8f881c 100644 --- a/tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc +++ b/tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc @@ -85,7 +85,8 @@ void TRTEngineOp::Compute(OpKernelContext* context) { LOG(FATAL) << "input data inconsistent batch size"; break; } - switch (trt_engine_ptr_->getBindingDataType(binding_index)) { + auto dtype = trt_engine_ptr_->getBindingDataType(binding_index); + switch (dtype) { case nvinfer1::DataType::kFLOAT: buffers[binding_index] = (void*)(input_tensor.flat().data()); break; @@ -95,6 +96,9 @@ void TRTEngineOp::Compute(OpKernelContext* context) { case nvinfer1::DataType::kINT8: LOG(FATAL) << "int8 is not supported yet!"; break; + default: + LOG(FATAL) << "Unknown data type: " << int(dtype); + break; } } @@ -120,7 +124,8 @@ void TRTEngineOp::Compute(OpKernelContext* context) { OP_REQUIRES_OK(context, context->allocate_output(i, output_shape, &output_tensor)); - switch (trt_engine_ptr_->getBindingDataType(binding_index)) { + auto dtype = trt_engine_ptr_->getBindingDataType(binding_index); + switch (dtype) { case nvinfer1::DataType::kFLOAT: buffers[binding_index] = reinterpret_cast(output_tensor->flat().data()); @@ -131,6 +136,9 @@ void TRTEngineOp::Compute(OpKernelContext* context) { case nvinfer1::DataType::kINT8: LOG(FATAL) << "int8 is not supported yet!"; break; + default: + LOG(FATAL) << "Unknown data type: " << int(dtype); + break; } } // copied from cuda_kernel_helper since it seems only valid in *.cu.cc files diff --git a/tensorflow/contrib/tensorrt/tensorrt_test.cc b/tensorflow/contrib/tensorrt/tensorrt_test.cc index e11522e..3712a9a 100644 --- a/tensorflow/contrib/tensorrt/tensorrt_test.cc +++ b/tensorflow/contrib/tensorrt/tensorrt_test.cc @@ -95,9 +95,9 @@ nvinfer1::IHostMemory* CreateNetwork() { } // Executes the network. -void Execute(nvinfer1::IExecutionContext& context, const float* input, +void Execute(nvinfer1::IExecutionContext* context, const float* input, float* output) { - const nvinfer1::ICudaEngine& engine = context.getEngine(); + const nvinfer1::ICudaEngine& engine = context->getEngine(); // We have two bindings: input and output. ASSERT_EQ(engine.getNbBindings(), 2); @@ -118,7 +118,7 @@ void Execute(nvinfer1::IExecutionContext& context, const float* input, // could be removed. ASSERT_EQ(0, cudaMemcpyAsync(buffers[input_index], input, sizeof(float), cudaMemcpyHostToDevice, stream)); - context.enqueue(1, buffers, stream, nullptr); + context->enqueue(1, buffers, stream, nullptr); ASSERT_EQ(0, cudaMemcpyAsync(output, buffers[output_index], sizeof(float), cudaMemcpyDeviceToHost, stream)); cudaStreamSynchronize(stream); @@ -143,7 +143,7 @@ TEST(TensorrtTest, BasicFunctions) { // Execute the network. float input = 1234; float output; - Execute(*context, &input, &output); + Execute(context, &input, &output); EXPECT_EQ(output, input * 2 + 3); // Destroy the engine.