#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/stream_executor.h"
#include "tensorflow/core/platform/types.h"
-#include "tensorflow/stream_executor/cuda/cuda_activation.h"
#if GOOGLE_CUDA
#if GOOGLE_TENSORRT
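// Everything below is compiled only when both CUDA and TensorRT are enabled.
// GetAttr reads the "output_nodes" attribute recorded in the GraphDef;
// OP_REQUIRES_OK fails kernel construction if the attribute is missing.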
OP_REQUIRES_OK(context, context->GetAttr("output_nodes", &output_nodes_));
// TODO(samikama) runtime should be taken from a resourcemanager as well.
- // Only engine should be in the op and context and runtime should be taken
- // from resourcemanager
- // TODO(jie): Relying on TF scheme to limit gpu scope for device placement
- // cannot have dependency on //tensorflow/core:gpu_runtimeo
- // Copied the function here.
+ // Only engine should be in the op and context and runtime should be taken
+ // from resourcemanager
+ // TODO(jie): cudaSetDevice makes sure the TRT engine is allocated on the
+ // same GPU where its inputs/outputs are located.
int gpu_id = context->device()->tensorflow_gpu_device_info()->gpu_id;
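// tensorflow_gpu_device_info() is non-null only for GPU devices; gpu_id is
// the CUDA device ordinal this kernel instance was placed on.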
- auto result = gpu::MultiPlatformManager::PlatformWithName("CUDA");
- if (!result.ok()) {
- LOG(FATAL) << "Could not find Platform with name CUDA";
- }
- gpu::Platform* gpu_machine_manager = result.ValueOrDie();
- gpu::cuda::ScopedActivateExecutorContext scoped_activation{
- gpu_machine_manager->ExecutorForDevice(gpu_id).ValueOrDie()};
+ cudaError_t err = cudaSetDevice(gpu_id);
+ if (err != cudaSuccess) {
+   LOG(FATAL) << "cudaSetDevice(" << gpu_id
+              << ") failed: " << cudaGetErrorString(err);
+ }
+ int device = -1;
+ cudaGetDevice(&device);
+ if (gpu_id != device) LOG(FATAL) << "set device failed!";
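For reference, the set-and-verify idiom above can be expressed as a minimal standalone helper. This sketch assumes only the CUDA runtime API; the name ActivateCudaDevice is illustrative and not part of the TensorFlow tree.

#include <cuda_runtime_api.h>

#include <cstdio>
#include <cstdlib>

// Illustrative helper: bind the calling host thread to `gpu_id` and verify
// that the CUDA runtime actually switched devices before touching buffers.
static void ActivateCudaDevice(int gpu_id) {
  cudaError_t err = cudaSetDevice(gpu_id);
  if (err != cudaSuccess) {
    fprintf(stderr, "cudaSetDevice(%d) failed: %s\n", gpu_id,
            cudaGetErrorString(err));
    abort();
  }
  int device = -1;
  cudaGetDevice(&device);  // Read the binding back to confirm the switch.
  if (device != gpu_id) {
    fprintf(stderr, "expected device %d, got %d\n", gpu_id, device);
    abort();
  }
}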