Revert the changes of ScopedActivateExecutorContext, which requires depending on...
authorgracehoney <31743510+aaroey@users.noreply.github.com>
Tue, 6 Mar 2018 22:59:33 +0000 (14:59 -0800)
committergracehoney <31743510+aaroey@users.noreply.github.com>
Tue, 6 Mar 2018 22:59:33 +0000 (14:59 -0800)
tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc

index 3f98e64..b32371b 100644 (file)
@@ -18,7 +18,6 @@ limitations under the License.
 #include "tensorflow/core/platform/logging.h"
 #include "tensorflow/core/platform/stream_executor.h"
 #include "tensorflow/core/platform/types.h"
-#include "tensorflow/stream_executor/cuda/cuda_activation.h"
 
 #if GOOGLE_CUDA
 #if GOOGLE_TENSORRT
@@ -43,19 +42,15 @@ TRTEngineOp::TRTEngineOp(OpKernelConstruction* context) : OpKernel(context) {
   OP_REQUIRES_OK(context, context->GetAttr("output_nodes", &output_nodes_));
 
   // TODO(samikama) runtime should be taken from a resourcemanager as well.
-  //  Only engine should be in the op and context and runtime should be taken
-  //  from resourcemanager
-  // TODO(jie): Relying on TF scheme to limit gpu scope for device placement
-  //  cannot have dependency on //tensorflow/core:gpu_runtimeo
-  //  Copied the function here.
+  // Only engine should be in the op and context and runtime should be taken
+  // from resourcemanager
+  // TODO(jie): cudaSetDevice make sure trt engine is allocated on the same
+  // gpu where the input/output is also located.
   int gpu_id = context->device()->tensorflow_gpu_device_info()->gpu_id;
-  auto result = gpu::MultiPlatformManager::PlatformWithName("CUDA");
-  if (!result.ok()) {
-    LOG(FATAL) << "Could not find Platform with name CUDA";
-  }
-  gpu::Platform* gpu_machine_manager = result.ValueOrDie();
-  gpu::cuda::ScopedActivateExecutorContext scoped_activation{
-      gpu_machine_manager->ExecutorForDevice(gpu_id).ValueOrDie()};
+  cudaSetDevice(gpu_id);
+  int device;
+  cudaGetDevice(&device);
+  if (gpu_id != device) LOG(FATAL) << "set device failed!";
 
   // TODO(samikama) runtime should be taken from a resourcemanager as well.
   // Only engine should be in the op and context and runtime should be taken