Dropping support for CUDA < 8.
authorA. Unique TensorFlower <gardener@tensorflow.org>
Fri, 18 May 2018 13:31:20 +0000 (06:31 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Fri, 18 May 2018 13:33:41 +0000 (06:33 -0700)
PiperOrigin-RevId: 197137612

tensorflow/core/kernels/batch_matmul_op_real.cc
tensorflow/core/kernels/matmul_op.cc
tensorflow/core/kernels/relu_op_gpu.cu.cc
tensorflow/core/util/cuda_kernel_helper.h
tensorflow/core/util/port.cc
tensorflow/stream_executor/cuda/cuda_blas.cc
tensorflow/stream_executor/cuda/cuda_driver.cc
tensorflow/stream_executor/cuda/cuda_driver.h

index 7e1e2aa..97cec3a 100644 (file)
@@ -27,9 +27,8 @@ TF_CALL_int32(REGISTER_BATCH_MATMUL_CPU);
 #if GOOGLE_CUDA
 TF_CALL_float(REGISTER_BATCH_MATMUL_GPU);
 TF_CALL_double(REGISTER_BATCH_MATMUL_GPU);
-#if CUDA_VERSION >= 7050
-TF_CALL_half(REGISTER_BATCH_MATMUL_GPU);
-#endif
+// TODO(csigg): Implement Stream::ThenBlasGemv for Eigen::half and uncomment.
+// TF_CALL_half(REGISTER_BATCH_MATMUL_GPU);
 #endif  // GOOGLE_CUDA
 
 #ifdef TENSORFLOW_USE_SYCL
index 3664f95..f9c15ce 100644 (file)
@@ -577,9 +577,7 @@ TF_CALL_float(REGISTER_GPU);
 TF_CALL_double(REGISTER_GPU);
 TF_CALL_complex64(REGISTER_GPU);
 TF_CALL_complex128(REGISTER_GPU);
-#if CUDA_VERSION >= 7050
 TF_CALL_half(REGISTER_GPU);
-#endif
 #endif  // GOOGLE_CUDA
 
 #ifdef TENSORFLOW_USE_SYCL
index 6e46c97..089ca8e 100644 (file)
@@ -31,8 +31,6 @@ namespace tensorflow {
 typedef Eigen::GpuDevice GPUDevice;
 
 namespace functor {
-#ifdef TF_HAS_CUDA_FP16
-
 // This kernel computes ReluGrad by processing one half2, two fp16, at a time.
 // It effectively does: backdrops = (feature > 0) ? gradient : 0
 // It also tries to use native half2 primitives as much as possible.
@@ -113,8 +111,6 @@ struct ReluGrad<Device, Eigen::half> {
                                        backprop.data(), count);
   }
 };
-
-#endif  // TF_HAS_CUDA_FP16
 }  // namespace functor
 
 // Definition of the GPU implementations declared in relu_op.cc.
index 0ab8756..540adb5 100644 (file)
@@ -21,10 +21,7 @@ limitations under the License.
 #include "tensorflow/core/util/cuda_device_functions.h"
 #include "tensorflow/core/util/cuda_launch_config.h"
 
-#if CUDA_VERSION >= 7050
 #include "cuda/include/cuda_fp16.h"
-#define TF_HAS_CUDA_FP16
-#endif
 
 // Deprecated, use 'for(int i : CudaGridRangeX(n))' instead.
 #define CUDA_1D_KERNEL_LOOP(i, n) \
index 490c584..c081cea 100644 (file)
@@ -31,9 +31,7 @@ bool IsGoogleCudaEnabled() {
 
 bool CudaSupportsHalfMatMulAndConv() {
 #if GOOGLE_CUDA
-  // NOTE: We check compile-time and not runtime, since the check for
-  // whether we include the fp16 kernels or not is compile-time.
-  return CUDA_VERSION >= 7050;
+  return true;
 #else
   return false;
 #endif
index dcc3f7a..3e9a23c 100644 (file)
@@ -16,11 +16,7 @@ limitations under the License.
 #include "cuda/include/cublas_v2.h"
 #include "cuda/include/cuda.h"
 
-#if CUDA_VERSION >= 8000
 #define SE_CUDA_DATA_HALF CUDA_R_16F
-#else
-#define SE_CUDA_DATA_HALF CUBLAS_DATA_HALF
-#endif
 
 #include "tensorflow/stream_executor/cuda/cuda_blas.h"
 
@@ -45,10 +41,8 @@ limitations under the License.
 // approach when the issue is fixed.
 #if CUDA_VERSION < 9000
 #include "cuda/include/cuda_fp16.h"
-#if CUDA_VERSION >= 7050
 #define EIGEN_HAS_CUDA_FP16
 #endif
-#endif
 
 #include "third_party/eigen3/Eigen/Core"
 
@@ -543,9 +537,7 @@ cublasSideMode_t CUDABlasSide(blas::Side side) {
 // blas::ComputationType to a cudaDataType_t.
 //
 // These are used to build the argument type and computation type args to
-// cublasGemmEx.  cublasGemmEx and cudaDataType_t are available only on
-// CUDA >= 8.0.
-#if CUDA_VERSION >= 8000
+// cublasGemmEx.
 template <typename T>
 struct CUDADataType;
 
@@ -620,8 +612,6 @@ cudaDataType_t CUDAComputationType(blas::ComputationType ty) {
       return CUDA_C_64F;
   }
 }
-#endif
-
 }  // namespace
 
 template <typename FuncT, typename... Args>
@@ -2229,7 +2219,6 @@ bool CUDABlas::GetBlasGemmAlgorithms(
 // Note that when CUDA version and compute capability is not sufficient, we
 // still return the out_algorithms. Caller needs to make sure that in this case,
 // the returned vector is empty.
-#if CUDA_VERSION >= 8000
   for (cublasGemmAlgo_t algo : {
          CUBLAS_GEMM_DFALT, CUBLAS_GEMM_ALGO0, CUBLAS_GEMM_ALGO1,
              CUBLAS_GEMM_ALGO2, CUBLAS_GEMM_ALGO3, CUBLAS_GEMM_ALGO4,
@@ -2245,7 +2234,6 @@ bool CUDABlas::GetBlasGemmAlgorithms(
        }) {
     out_algorithms->push_back(algo);
   }
-#endif
   return true;
 }
 
index e7e4192..273ed83 100644 (file)
@@ -26,16 +26,16 @@ limitations under the License.
 #include "tensorflow/stream_executor/lib/env.h"
 #include "tensorflow/stream_executor/lib/error.h"
 #include "tensorflow/stream_executor/lib/human_readable.h"
+#include "tensorflow/stream_executor/lib/inlined_vector.h"
 #include "tensorflow/stream_executor/lib/notification.h"
-#include "tensorflow/stream_executor/lib/threadpool.h"
 #include "tensorflow/stream_executor/lib/stacktrace.h"
 #include "tensorflow/stream_executor/lib/static_threadlocal.h"
 #include "tensorflow/stream_executor/lib/strcat.h"
 #include "tensorflow/stream_executor/lib/stringprintf.h"
+#include "tensorflow/stream_executor/lib/threadpool.h"
 #include "tensorflow/stream_executor/platform/logging.h"
 #include "tensorflow/stream_executor/platform/mutex.h"
 #include "tensorflow/stream_executor/platform/port.h"
-#include "tensorflow/stream_executor/lib/inlined_vector.h"
 
 bool FLAGS_gpuexec_cuda_driver_inject_init_error = false;
 bool FLAGS_gpuexec_cuda_sync_around_driver_calls = false;
@@ -204,11 +204,11 @@ string ToString(CUresult result) {
     case 719:
       return "CUDA_ERROR_LAUNCH_FAILED";
 
-    OSTREAM_CUDA_ERROR(CONTEXT_ALREADY_IN_USE)
-    OSTREAM_CUDA_ERROR(PEER_ACCESS_UNSUPPORTED)
-    OSTREAM_CUDA_ERROR(NOT_PERMITTED)
-    OSTREAM_CUDA_ERROR(NOT_SUPPORTED)
-    OSTREAM_CUDA_ERROR(UNKNOWN)  // Unknown internal error to CUDA.
+      OSTREAM_CUDA_ERROR(CONTEXT_ALREADY_IN_USE)
+      OSTREAM_CUDA_ERROR(PEER_ACCESS_UNSUPPORTED)
+      OSTREAM_CUDA_ERROR(NOT_PERMITTED)
+      OSTREAM_CUDA_ERROR(NOT_SUPPORTED)
+      OSTREAM_CUDA_ERROR(UNKNOWN)  // Unknown internal error to CUDA.
     default:
       return port::StrCat("CUresult(", static_cast<int>(result), ")");
   }
@@ -470,7 +470,8 @@ bool DeviceOptionsToContextFlags(const DeviceOptions &device_options,
 }
 
 /* static */ port::Status CUDADriver::CreateContext(
-    CUdevice device, DeviceOptions device_options, CudaContext** context) {
+    CUdevice device, const DeviceOptions &device_options,
+    CudaContext **context) {
   *context = nullptr;
 
   int flags = 0;
@@ -481,62 +482,45 @@ bool DeviceOptionsToContextFlags(const DeviceOptions &device_options,
   CUresult res;
   CUcontext former_context;
   CUcontext new_context;
-  {
-    // TODO(leary) Need to see if NVIDIA can expunge the leakiness in their
-    // context creation: see http://b/13248943
 
-#if CUDA_VERSION >= 7000
-    {
-      unsigned int former_primary_context_flags;
-      int former_primary_context_is_active;
-      CHECK_EQ(CUDA_SUCCESS,
-               cuDevicePrimaryCtxGetState(device, &former_primary_context_flags,
-                                          &former_primary_context_is_active));
-      if (former_primary_context_flags != flags) {
-        if (former_primary_context_is_active) {
-          LOG(ERROR)
-              << "The primary context is active and has a different flag set ("
-              << former_primary_context_flags << ") than the desired flag set ("
-              << flags << ").";
-        } else {
-          CHECK_EQ(CUDA_SUCCESS, cuDevicePrimaryCtxSetFlags(device, flags));
-        }
-      }
+  unsigned int former_primary_context_flags;
+  int former_primary_context_is_active;
+  CHECK_EQ(CUDA_SUCCESS,
+           cuDevicePrimaryCtxGetState(device, &former_primary_context_flags,
+                                      &former_primary_context_is_active));
+  if (former_primary_context_flags != flags) {
+    if (former_primary_context_is_active) {
+      LOG(ERROR)
+          << "The primary context is active and has a different flag set ("
+          << former_primary_context_flags << ") than the desired flag set ("
+          << flags << ").";
+    } else {
+      CHECK_EQ(CUDA_SUCCESS, cuDevicePrimaryCtxSetFlags(device, flags));
     }
+  }
 
-    former_context = CUDADriver::CurrentContextOrDie();
-    res = cuDevicePrimaryCtxRetain(&new_context, device);
-    if (former_context != nullptr) {
-      CUdevice former_device;
-      if (cuCtxGetDevice(&former_device) == CUDA_SUCCESS) {
-        if (former_device == device) {
-          if (former_context == new_context) {
-            VLOG(2) << "The primary context " << former_context
-                    << " for device " << device
-                    << " exists before initializing the StreamExecutor.";
-          } else {
-            LOG(WARNING)
-                << "A non-primary context " << former_context << " for device "
-                << device
-                << " exists before initializing the StreamExecutor. The "
-                << "primary context is now " << new_context << ". We "
-                << "haven't verified StreamExecutor works with that.";
-          }
+  former_context = CUDADriver::CurrentContextOrDie();
+  res = cuDevicePrimaryCtxRetain(&new_context, device);
+  if (former_context != nullptr) {
+    CUdevice former_device;
+    if (cuCtxGetDevice(&former_device) == CUDA_SUCCESS) {
+      if (former_device == device) {
+        if (former_context == new_context) {
+          VLOG(2) << "The primary context " << former_context << " for device "
+                  << device
+                  << " exists before initializing the StreamExecutor.";
+        } else {
+          LOG(WARNING) << "A non-primary context " << former_context
+                       << " for device " << device
+                       << " exists before initializing the StreamExecutor. The "
+                       << "primary context is now " << new_context << ". We "
+                       << "haven't verified StreamExecutor works with that.";
         }
-      } else {
-        LOG(ERROR) << "Failed to get the device of the current context "
-                   << former_context;
       }
+    } else {
+      LOG(ERROR) << "Failed to get the device of the current context "
+                 << former_context;
     }
-#else
-    former_context = CurrentContext();
-    if (former_context != nullptr) {
-      LOG(WARNING)
-          << "creating context when one is currently active; existing: "
-          << former_context;
-    }
-    res = cuCtxCreate(&new_context, flags, device);
-#endif
   }
   CHECK_EQ(CUDA_SUCCESS, cuCtxSetCurrent(former_context));
 
@@ -548,11 +532,7 @@ bool DeviceOptionsToContextFlags(const DeviceOptions &device_options,
     return port::Status::OK();
   }
 
-#if CUDA_VERSION >= 7000
   string message = "failed call to cuDevicePrimaryCtxRetain: " + ToString(res);
-#else
-  string message = "failed call to cuCtxCreate: " + ToString(res);
-#endif
   if (res == CUDA_ERROR_OUT_OF_MEMORY) {
     uint64 total_memory;
     if (GetDeviceTotalMemory(device, &total_memory)) {
@@ -569,7 +549,6 @@ bool DeviceOptionsToContextFlags(const DeviceOptions &device_options,
   if (context == nullptr) {
     return;
   }
-#if CUDA_VERSION >= 7000
   CUcontext former_context = CurrentContext();
   CUresult res = cuCtxSetCurrent(context->context());
   CUdevice device;
@@ -577,9 +556,6 @@ bool DeviceOptionsToContextFlags(const DeviceOptions &device_options,
   cuCtxSetCurrent(former_context);
 
   res = cuDevicePrimaryCtxRelease(device);
-#else
-  CUresult res = cuCtxDestroy(context->context());
-#endif
 
   if (res != CUDA_SUCCESS) {
     LOG(ERROR) << "failed to release CUDA context; leaking: " << ToString(res);
index a9969e2..b952cfa 100644 (file)
@@ -147,7 +147,7 @@ class CUDADriver {
   // userspace processes is given here:
   // http://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__CTX.html#group__CUDA__CTX_1g65dc0012348bc84810e2103a40d8e2cf
   static port::Status CreateContext(CUdevice device,
-                                    DeviceOptions device_options,
+                                    const DeviceOptions& device_options,
                                     CudaContext** context);
 
   // Destroys the provided context via cuCtxDestroy.