"trt_engine_op",
"trt_calib_op",
],
- deps = if_tensorrt([
- "@local_config_tensorrt//:nv_infer",
- ]),
)
tf_cuda_library(
] + if_tensorrt([
"@local_config_tensorrt//:nv_infer",
]),
+ kernels = [
+ ":trt_engine_op_kernel",
+ ":trt_engine_op_op_lib",
+ ":trt_calib_op_op_lib",
+ ":trt_shape_function",
+ ],
srcs_version = "PY2AND3",
deps = [
"//tensorflow/contrib/util:util_py",
LOG(FATAL) << "input data inconsistent batch size";
break;
}
- switch (trt_engine_ptr_->getBindingDataType(binding_index)) {
+ auto dtype = trt_engine_ptr_->getBindingDataType(binding_index);
+ switch (dtype) {
case nvinfer1::DataType::kFLOAT:
buffers[binding_index] = (void*)(input_tensor.flat<float>().data());
break;
case nvinfer1::DataType::kINT8:
LOG(FATAL) << "int8 is not supported yet!";
break;
+ default:
+ LOG(FATAL) << "Unknown data type: " << int(dtype);
+ break;
}
}
OP_REQUIRES_OK(context,
context->allocate_output(i, output_shape, &output_tensor));
- switch (trt_engine_ptr_->getBindingDataType(binding_index)) {
+ auto dtype = trt_engine_ptr_->getBindingDataType(binding_index);
+ switch (dtype) {
case nvinfer1::DataType::kFLOAT:
buffers[binding_index] =
reinterpret_cast<void*>(output_tensor->flat<float>().data());
case nvinfer1::DataType::kINT8:
LOG(FATAL) << "int8 is not supported yet!";
break;
+ default:
+ LOG(FATAL) << "Unknown data type: " << int(dtype);
+ break;
}
}
// copied from cuda_kernel_helper since it seems only valid in *.cu.cc files
}
// Executes the network.
-void Execute(nvinfer1::IExecutionContext& context, const float* input,
+void Execute(nvinfer1::IExecutionContext* context, const float* input,
float* output) {
- const nvinfer1::ICudaEngine& engine = context.getEngine();
+ const nvinfer1::ICudaEngine& engine = context->getEngine();
// We have two bindings: input and output.
ASSERT_EQ(engine.getNbBindings(), 2);
// could be removed.
ASSERT_EQ(0, cudaMemcpyAsync(buffers[input_index], input, sizeof(float),
cudaMemcpyHostToDevice, stream));
- context.enqueue(1, buffers, stream, nullptr);
+ context->enqueue(1, buffers, stream, nullptr);
ASSERT_EQ(0, cudaMemcpyAsync(output, buffers[output_index], sizeof(float),
cudaMemcpyDeviceToHost, stream));
cudaStreamSynchronize(stream);
// Execute the network.
float input = 1234;
float output;
- Execute(*context, &input, &output);
+ Execute(context, &input, &output);
EXPECT_EQ(output, input * 2 + 3);
// Destroy the engine.