Fix build by adding op_lib dependencies to trt_engine_op_loader, and remove unnecessary dependency from the tf_gen_op_libs.
author: Guangda Lai <laigd@google.com>
author date: Thu, 26 Apr 2018 20:12:04 +0000 (13:12 -0700)
committer: TensorFlower Gardener <gardener@tensorflow.org>
commit date: Thu, 26 Apr 2018 20:14:59 +0000 (13:14 -0700)

PiperOrigin-RevId: 194442728

tensorflow/contrib/tensorrt/BUILD
tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc
tensorflow/contrib/tensorrt/tensorrt_test.cc

index f80b4f1..742be7b 100644 (file)
@@ -102,9 +102,6 @@ tf_gen_op_libs(
         "trt_engine_op",
         "trt_calib_op",
     ],
-    deps = if_tensorrt([
-        "@local_config_tensorrt//:nv_infer",
-    ]),
 )
 
 tf_cuda_library(
@@ -138,6 +135,12 @@ tf_custom_op_py_library(
     ] + if_tensorrt([
         "@local_config_tensorrt//:nv_infer",
     ]),
+    kernels = [
+        ":trt_engine_op_kernel",
+        ":trt_engine_op_op_lib",
+        ":trt_calib_op_op_lib",
+        ":trt_shape_function",
+    ],
     srcs_version = "PY2AND3",
     deps = [
         "//tensorflow/contrib/util:util_py",
index 53ba7ba..b8f881c 100644 (file)
@@ -85,7 +85,8 @@ void TRTEngineOp::Compute(OpKernelContext* context) {
       LOG(FATAL) << "input data inconsistent batch size";
       break;
     }
-    switch (trt_engine_ptr_->getBindingDataType(binding_index)) {
+    auto dtype = trt_engine_ptr_->getBindingDataType(binding_index);
+    switch (dtype) {
       case nvinfer1::DataType::kFLOAT:
         buffers[binding_index] = (void*)(input_tensor.flat<float>().data());
         break;
@@ -95,6 +96,9 @@ void TRTEngineOp::Compute(OpKernelContext* context) {
       case nvinfer1::DataType::kINT8:
         LOG(FATAL) << "int8 is not supported yet!";
         break;
+      default:
+        LOG(FATAL) << "Unknown data type: " << int(dtype);
+        break;
     }
   }
 
@@ -120,7 +124,8 @@ void TRTEngineOp::Compute(OpKernelContext* context) {
 
     OP_REQUIRES_OK(context,
                    context->allocate_output(i, output_shape, &output_tensor));
-    switch (trt_engine_ptr_->getBindingDataType(binding_index)) {
+    auto dtype = trt_engine_ptr_->getBindingDataType(binding_index);
+    switch (dtype) {
       case nvinfer1::DataType::kFLOAT:
         buffers[binding_index] =
             reinterpret_cast<void*>(output_tensor->flat<float>().data());
@@ -131,6 +136,9 @@ void TRTEngineOp::Compute(OpKernelContext* context) {
       case nvinfer1::DataType::kINT8:
         LOG(FATAL) << "int8 is not supported yet!";
         break;
+      default:
+        LOG(FATAL) << "Unknown data type: " << int(dtype);
+        break;
     }
   }
   // copied from cuda_kernel_helper since it seems only valid in *.cu.cc files
index e11522e..3712a9a 100644 (file)
@@ -95,9 +95,9 @@ nvinfer1::IHostMemory* CreateNetwork() {
 }
 
 // Executes the network.
-void Execute(nvinfer1::IExecutionContext& context, const float* input,
+void Execute(nvinfer1::IExecutionContext* context, const float* input,
              float* output) {
-  const nvinfer1::ICudaEngine& engine = context.getEngine();
+  const nvinfer1::ICudaEngine& engine = context->getEngine();
 
   // We have two bindings: input and output.
   ASSERT_EQ(engine.getNbBindings(), 2);
@@ -118,7 +118,7 @@ void Execute(nvinfer1::IExecutionContext& context, const float* input,
   // could be removed.
   ASSERT_EQ(0, cudaMemcpyAsync(buffers[input_index], input, sizeof(float),
                                cudaMemcpyHostToDevice, stream));
-  context.enqueue(1, buffers, stream, nullptr);
+  context->enqueue(1, buffers, stream, nullptr);
   ASSERT_EQ(0, cudaMemcpyAsync(output, buffers[output_index], sizeof(float),
                                cudaMemcpyDeviceToHost, stream));
   cudaStreamSynchronize(stream);
@@ -143,7 +143,7 @@ TEST(TensorrtTest, BasicFunctions) {
   // Execute the network.
   float input = 1234;
   float output;
-  Execute(*context, &input, &output);
+  Execute(context, &input, &output);
   EXPECT_EQ(output, input * 2 + 3);
 
   // Destroy the engine.