Don't call cudaStreamDestroy at destruction time (#15692)
author Dmytro Dzhulgakov <dzhulgakov@fb.com>
Fri, 11 Jan 2019 20:32:50 +0000 (12:32 -0800)
committer Facebook Github Bot <facebook-github-bot@users.noreply.github.com>
Fri, 11 Jan 2019 20:36:41 +0000 (12:36 -0800)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/15692

It was leading to occasional crashes with dynamically linked CUDA, because by the time the stream destructors ran during static teardown, the CUDA runtime had already been destroyed.
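
A minimal sketch of the resulting pattern (simplified names; not the exact c10 code — see LeakyStreamInternals in the CUDAStream.cpp diff below):

    #include <cuda_runtime.h>

    struct LeakyStream {
      // Filled in lazily by cudaStreamCreateWithPriority (not shown).
      cudaStream_t stream = nullptr;

      LeakyStream() = default;
      LeakyStream(const LeakyStream&) = delete;
      LeakyStream& operator=(const LeakyStream&) = delete;

      ~LeakyStream() {
        // Intentionally NOT calling cudaStreamDestroy(stream): this
        // destructor runs only during static destruction, when the
        // (dynamically linked) CUDA runtime may have already shut down,
        // and the call can crash.
      }
    };

    // Safe to tear down even after the CUDA runtime is gone.
    static LeakyStream default_stream;

Leaking here is deliberate: the streams live for the whole process anyway, and the OS reclaims the resources at exit, so skipping cudaStreamDestroy trades a benign "leak" for shutdown safety.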

Also, because the pooled bookkeeping types (std::once_flag, std::atomic counters) are neither movable nor copyable, fixed-size arrays (or unique_ptr<T[]>) are more suitable than deque<T> for the purpose.
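
As a quick illustration of that constraint (a standalone sketch, not code from this diff):

    #include <atomic>
    #include <cstdint>
    #include <memory>

    int main() {
      // std::vector<std::atomic<std::uint32_t>> v;
      // v.resize(8);  // ill-formed: resize() requires movable elements,
      //               // and std::atomic has no move/copy operations.
      //               // deque<T> worked around this by never relocating
      //               // elements, at the cost of extra indirection.

      // unique_ptr<T[]> value-initializes elements in place and never
      // relocates them, so non-movable element types are fine:
      auto counters = std::make_unique<std::atomic<std::uint32_t>[]>(8);
      ++counters[0];  // elements start zero-initialized and never move
      return 0;
    }

A plain T[N] array, as the diff below uses for device_flags and the priority counters, has the same property with zero indirection, at the cost of a compile-time capacity (C10_COMPILE_TIME_MAX_GPUS).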

Reviewed By: Yangqing

Differential Revision: D13571988

fbshipit-source-id: 37eb26dfbe361c49160367b53f87bd037c6c0e46

c10/cuda/CUDAFunctions.h
c10/cuda/CUDAMacros.h
c10/cuda/CUDAStream.cpp
caffe2/contrib/nccl/cuda_nccl_op_gpu.cc
caffe2/core/common_gpu.h
caffe2/core/context_gpu.cu
caffe2/core/context_gpu.h
caffe2/core/cudnn_wrappers.h
caffe2/core/hip/miopen_wrapper.h

c10/cuda/CUDAFunctions.h
index a572516..a01ae1a 100644
@@ -11,6 +11,7 @@
 
 #include <c10/macros/Macros.h>
 #include <c10/core/Device.h>
+#include <c10/cuda/CUDAException.h>
 
 namespace c10 {
 namespace cuda {

c10/cuda/CUDAMacros.h
index 9178425..cd63cec 100644
@@ -31,3 +31,8 @@
 #else
 #define C10_CUDA_API C10_CUDA_IMPORT
 #endif
+
+/**
+ * The maximum number of GPUs that we recognize.
+ */
+#define C10_COMPILE_TIME_MAX_GPUS 16

c10/cuda/CUDAStream.cpp
index 13cc0df..393826f 100644
@@ -1,26 +1,35 @@
 #include <c10/cuda/CUDAStream.h>
-#include <c10/cuda/CUDAGuard.h>
 #include <c10/cuda/CUDAFunctions.h>
+#include <c10/cuda/CUDAGuard.h>
 #include <c10/util/Exception.h>
 
-#include <mutex>
+#include <array>
 #include <atomic>
 #include <cstdint>
-#include <deque>
+#include <mutex>
 #include <vector>
-#include <array>
+
+#include <iostream>
 
 namespace c10 {
 namespace cuda {
 
 namespace {
 
-// Internal implementation is entirely hidden
-struct CUDAStreamInternals {
-  CUDAStreamInternals() = default;
+// Internal implementation that leaks the stream. It's not intended to be used
+// outside of this file.
+struct LeakyStreamInternals {
+  LeakyStreamInternals() = default;
+  C10_DISABLE_COPY_AND_ASSIGN(LeakyStreamInternals);
+
+  ~LeakyStreamInternals() {
+    // NB: this code is invoked only during the destruction of global
+    // variables (since we never shrink the corresponding containers). At this
+    // point the CUDA runtime might already be destroyed, and invoking
+    // cudaStreamDestroy leads to a crash. It's likely an issue in CUDA, but
+    // to be safe, let's just "forget" the destruction.
 
-  ~CUDAStreamInternals() {
-    if (stream) cudaStreamDestroy(stream);
+    // if (stream) cudaStreamDestroy(stream);
   }
 
   DeviceIndex device_index = -1;
@@ -37,13 +46,13 @@ static constexpr unsigned int kDefaultFlags = cudaStreamNonBlocking;
 // Note: stream priority is not supported by HIP
 // Note: lower numbers are higher priorities, zero is default priority
 #ifndef __HIP_PLATFORM_HCC__
-  static int kHighPriority = -1;
-  static int kLowPriority = 0;
+static int kHighPriority = -1;
+static int kLowPriority = 0;
 #endif // __HIP_PLATFORM_HCC__
 
 // Default streams
 static std::once_flag init_flag;
-static std::vector<CUDAStreamInternals> default_streams;
+static LeakyStreamInternals default_streams[C10_COMPILE_TIME_MAX_GPUS];
 
 // Non-default streams
 // Note: the number of CUDA devices is determined at run time,
@@ -53,11 +62,16 @@ static std::vector<CUDAStreamInternals> default_streams;
 // the low and high priority counters track, for each device, the next stream
 // in the pool to be returned when a stream is requested (round-robin fashion
 // , see the note in CUDAStream.h).
-static std::deque<std::once_flag> device_flags;
-static std::deque<std::atomic<uint32_t>> low_priority_counters;
-static std::deque<std::atomic<uint32_t>> high_priority_counters;
-static std::vector<std::array<CUDAStreamInternals, kStreamsPerPool>> low_priority_streams;
-static std::vector<std::array<CUDAStreamInternals, kStreamsPerPool>> high_priority_streams;
+//
+// Fixed-size arrays are used here instead of vector<T> (unique_ptr<T[]>
+// would also work) because T might be non-movable and non-copyable.
+static std::once_flag device_flags[C10_COMPILE_TIME_MAX_GPUS];
+static std::atomic<uint32_t> low_priority_counters[C10_COMPILE_TIME_MAX_GPUS];
+static std::atomic<uint32_t> high_priority_counters[C10_COMPILE_TIME_MAX_GPUS];
+static std::array<LeakyStreamInternals, kStreamsPerPool>
+    low_priority_streams[C10_COMPILE_TIME_MAX_GPUS];
+static std::array<LeakyStreamInternals, kStreamsPerPool>
+    high_priority_streams[C10_COMPILE_TIME_MAX_GPUS];
 
 // Note [StreamId assignment]
 // ~~~~~~~~~~~~~~~~~~~~~~~~~~
@@ -88,8 +102,8 @@ static std::vector<std::array<CUDAStreamInternals, kStreamsPerPool>> high_priori
 
 enum class StreamIdType : uint8_t {
   DEFAULT = 0x0,
-  LOW     = 0x1,
-  HIGH    = 0x2,
+  LOW = 0x1,
+  HIGH = 0x2,
 };
 
 std::ostream& operator<<(std::ostream& stream, StreamIdType s) {
@@ -123,15 +137,17 @@ static inline size_t streamIdIndex(StreamId s) {
 }
 
 StreamId makeStreamId(StreamIdType st, size_t si) {
-  return (static_cast<StreamId>(st) << kStreamsPerPoolBits) | static_cast<StreamId>(si);
+  return (static_cast<StreamId>(st) << kStreamsPerPoolBits) |
+      static_cast<StreamId>(si);
 }
 
 template <typename T, typename A>
 static bool pointer_within(const T* ptr, const A& arr) {
-  return std::greater_equal<const T*>()(ptr, arr.data()) && std::less<const T*>()(ptr, arr.data() + arr.size());
+  return std::greater_equal<const T*>()(ptr, arr.data()) &&
+      std::less<const T*>()(ptr, arr.data() + arr.size());
 }
 
-static StreamId CUDAStream_getStreamId(const CUDAStreamInternals* ptr) {
+static StreamId CUDAStream_getStreamId(const LeakyStreamInternals* ptr) {
   // Hypothetically, we could store the stream ID in the stream.  But that
   // introduces a degree of freedom which could lead to bugs (where we
   // misnumber streams in the pool, or overwrite the number).  Better
@@ -149,21 +165,30 @@ static StreamId CUDAStream_getStreamId(const CUDAStreamInternals* ptr) {
   // NB: Because ptr may not necessarily lie within the array, we must use
   // std::less and similar templates to avoid UB that arises when
   // doing an operator< comparison.
-  if (pointer_within<CUDAStreamInternals>(ptr, low_priority_streams[device_index])) {
-    return makeStreamId(StreamIdType::LOW, ptr - low_priority_streams[device_index].data());
+  if (pointer_within<LeakyStreamInternals>(
+          ptr, low_priority_streams[device_index])) {
+    return makeStreamId(
+        StreamIdType::LOW, ptr - low_priority_streams[device_index].data());
   }
 
   // Check if it's a high priority stream
-  if (pointer_within<CUDAStreamInternals>(ptr, high_priority_streams[device_index])) {
-    return makeStreamId(StreamIdType::HIGH, ptr - high_priority_streams[device_index].data());
+  if (pointer_within<LeakyStreamInternals>(
+          ptr, high_priority_streams[device_index])) {
+    return makeStreamId(
+        StreamIdType::HIGH, ptr - high_priority_streams[device_index].data());
   }
 
-  AT_ASSERTM(0, "Could not compute stream ID for ", ptr, " on device ", device_index,
-                " (something has gone horribly wrong!)");
+  AT_ASSERTM(
+      0,
+      "Could not compute stream ID for ",
+      ptr,
+      " on device ",
+      device_index,
+      " (something has gone horribly wrong!)");
 }
 
 // Thread-local current streams
-static thread_local CUDAStreamInternals** current_streams = nullptr;
+static thread_local LeakyStreamInternals** current_streams = nullptr;
 
 // Populates global values and creates a default stream for each device.
 // Note: the default stream on each device is signified by a nullptr,
@@ -173,14 +198,14 @@ static thread_local CUDAStreamInternals** current_streams = nullptr;
 // Warning: this function must only be called once!
 static void initGlobalStreamState() {
   num_gpus = device_count();
-
-  // Resizes deques and vectors
-  default_streams.resize(num_gpus);
-  device_flags.resize(num_gpus);
-  low_priority_counters.resize(num_gpus);
-  high_priority_counters.resize(num_gpus);
-  low_priority_streams.resize(num_gpus);
-  high_priority_streams.resize(num_gpus);
+  // Check if the number of GPUs matches the expected compile-time max number
+  // of GPUs.
+  AT_ASSERTM(
+      num_gpus <= C10_COMPILE_TIME_MAX_GPUS,
+      "Number of CUDA devices on the machine is larger than the compiled "
+      "max number of gpus expected (",
+      C10_COMPILE_TIME_MAX_GPUS,
+      "). Increase that and recompile.");
 
   // Initializes default streams
   for (auto i = decltype(num_gpus){0}; i < num_gpus; ++i) {
@@ -204,23 +229,17 @@ static void initDeviceStreamState(DeviceIndex device_index) {
     lowpri_stream.device_index = device_index;
     hipri_stream.device_index = device_index;
 
-    #ifndef __HIP_PLATFORM_HCC__
-      C10_CUDA_CHECK(cudaStreamCreateWithPriority(
-        &lowpri_stream.stream
-      , kDefaultFlags
-      , kLowPriority));
-      C10_CUDA_CHECK(cudaStreamCreateWithPriority(
-        &hipri_stream.stream
-      , kDefaultFlags
-      , kHighPriority));
-    #else
-      C10_CUDA_CHECK(cudaStreamCreateWithFlags(
-        &lowpri_stream.stream
-      , kDefaultFlags));
-      C10_CUDA_CHECK(cudaStreamCreateWithFlags(
-        &hipri_stream.stream
-      , kDefaultFlags));
-    #endif // __HIP_PLATFORM_HCC__
+#ifndef __HIP_PLATFORM_HCC__
+    C10_CUDA_CHECK(cudaStreamCreateWithPriority(
+        &lowpri_stream.stream, kDefaultFlags, kLowPriority));
+    C10_CUDA_CHECK(cudaStreamCreateWithPriority(
+        &hipri_stream.stream, kDefaultFlags, kHighPriority));
+#else
+    C10_CUDA_CHECK(
+        cudaStreamCreateWithFlags(&lowpri_stream.stream, kDefaultFlags));
+    C10_CUDA_CHECK(
+        cudaStreamCreateWithFlags(&hipri_stream.stream, kDefaultFlags));
+#endif // __HIP_PLATFORM_HCC__
   }
 }
 
@@ -229,10 +248,13 @@ static void initCUDAStreamsOnce() {
   // Inits default streams (once, globally)
   std::call_once(init_flag, initGlobalStreamState);
 
-  if (current_streams) return;
+  if (current_streams) {
+    return;
+  }
 
   // Inits current streams (thread local) to default streams
-  current_streams = (CUDAStreamInternals**) malloc(num_gpus * sizeof(CUDAStreamInternals*));
+  current_streams =
+      (LeakyStreamInternals**)malloc(num_gpus * sizeof(LeakyStreamInternals*));
   for (auto i = decltype(num_gpus){0}; i < num_gpus; ++i) {
     current_streams[i] = &default_streams[i];
   }
@@ -245,37 +267,50 @@ static inline void check_gpu(DeviceIndex device_index) {
 
 // Helper to determine the index of the stream to return
 // Note: Streams are returned round-robin (see note in CUDAStream.h)
-static uint32_t get_idx(std::atomic<uint32_t> &counter) {
+static uint32_t get_idx(std::atomic<uint32_t>& counter) {
   auto raw_idx = counter++;
   return raw_idx % kStreamsPerPool;
 }
 
 // See Note [StreamId assignment]
-CUDAStreamInternals* CUDAStream_internals(CUDAStream s) {
+LeakyStreamInternals* CUDAStream_internals(CUDAStream s) {
   c10::DeviceIndex device_index = s.device_index();
   StreamIdType st = streamIdType(s.unwrap().id());
   size_t si = streamIdIndex(s.unwrap().id());
   switch (st) {
     case StreamIdType::DEFAULT:
-      AT_ASSERTM(si == 0, "Unrecognized stream ", s.unwrap(),
-                          " (I think this should be the default stream, but I got a non-zero index ", si, ").",
-                          " Did you manufacture the StreamId yourself?  Don't do that; use the",
-                          " official API like c10::cuda::getStreamFromPool() to get a new stream.");
+      AT_ASSERTM(
+          si == 0,
+          "Unrecognized stream ",
+          s.unwrap(),
+          " (I think this should be the default stream, but I got a non-zero index ",
+          si,
+          ").",
+          " Did you manufacture the StreamId yourself?  Don't do that; use the",
+          " official API like c10::cuda::getStreamFromPool() to get a new stream.");
       return &default_streams[device_index];
     case StreamIdType::LOW:
       return &low_priority_streams[device_index][si];
     case StreamIdType::HIGH:
       return &high_priority_streams[device_index][si];
     default:
-      AT_ASSERTM(0, "Unrecognized stream ", s.unwrap(), " (I didn't recognize the stream type, ", st, ")");
+      AT_ASSERTM(
+          0,
+          "Unrecognized stream ",
+          s.unwrap(),
+          " (I didn't recognize the stream type, ",
+          st,
+          ")");
   }
 }
 
-CUDAStream CUDAStream_fromInternals(const CUDAStreamInternals* ptr) {
-  return CUDAStream(CUDAStream::UNCHECKED,
-                    Stream(Stream::UNSAFE,
-                           c10::Device(DeviceType::CUDA, ptr->device_index),
-                           CUDAStream_getStreamId(ptr)));
+CUDAStream CUDAStream_fromInternals(const LeakyStreamInternals* ptr) {
+  return CUDAStream(
+      CUDAStream::UNCHECKED,
+      Stream(
+          Stream::UNSAFE,
+          c10::Device(DeviceType::CUDA, ptr->device_index),
+          CUDAStream_getStreamId(ptr)));
 }
 
 } // anonymous namespace
@@ -290,14 +325,16 @@ cudaStream_t CUDAStream::stream() const {
 // Note: when called the first time on a device, this will create the
 // stream pools for that device.
 CUDAStream getStreamFromPool(
-  const bool isHighPriority
-, DeviceIndex device_index) {
+    const bool isHighPriority,
+    DeviceIndex device_index) {
   initCUDAStreamsOnce();
-  if (device_index == -1) device_index = current_device();
+  if (device_index == -1)
+    device_index = current_device();
   check_gpu(device_index);
 
   // Initializes the stream pools (once)
-  std::call_once(device_flags[device_index], initDeviceStreamState, device_index);
+  std::call_once(
+      device_flags[device_index], initDeviceStreamState, device_index);
 
   if (isHighPriority) {
     const auto idx = get_idx(high_priority_counters[device_index]);
@@ -310,13 +347,17 @@ CUDAStream getStreamFromPool(
 
 CUDAStream getDefaultCUDAStream(DeviceIndex device_index) {
   initCUDAStreamsOnce();
-  if (device_index == -1) device_index = current_device();
+  if (device_index == -1) {
+    device_index = current_device();
+  }
   check_gpu(device_index);
   return CUDAStream_fromInternals(&default_streams[device_index]);
 }
 CUDAStream getCurrentCUDAStream(DeviceIndex device_index) {
   initCUDAStreamsOnce();
-  if (device_index == -1) device_index = current_device();
+  if (device_index == -1) {
+    device_index = current_device();
+  }
   check_gpu(device_index);
   return CUDAStream_fromInternals(current_streams[device_index]);
 }
@@ -333,4 +374,4 @@ std::ostream& operator<<(std::ostream& stream, const CUDAStream& s) {
 }
 
 } // namespace cuda
-} // namespace at
+} // namespace c10

caffe2/contrib/nccl/cuda_nccl_op_gpu.cc
index 4c5313f..0db4530 100644
@@ -212,8 +212,8 @@ std::pair<std::vector<DeviceOption>, std::vector<DeviceOption>> ncclOpDevInfer(
 
 REGISTER_CUDA_OPERATOR(NCCLAllreduce, NCCLAllreduceOp);
 OPERATOR_SCHEMA(NCCLAllreduce)
-    .NumInputs(1, CAFFE2_COMPILE_TIME_MAX_GPUS)
-    .NumOutputs(1, CAFFE2_COMPILE_TIME_MAX_GPUS)
+    .NumInputs(1, C10_COMPILE_TIME_MAX_GPUS)
+    .NumOutputs(1, C10_COMPILE_TIME_MAX_GPUS)
     .CostInferenceFunction(NCCLAllreduceOp::CostInference)
     .TensorInferenceFunction(NCCLAllreduceOp::ShapeInference)
     .IdenticalTypeAndShape()
@@ -224,8 +224,8 @@ SHOULD_NOT_DO_GRADIENT(NCCLAllreduce);
 
 REGISTER_CUDA_OPERATOR(NCCLBroadcast, NCCLBroadcastOp);
 OPERATOR_SCHEMA(NCCLBroadcast)
-    .NumInputs(1, CAFFE2_COMPILE_TIME_MAX_GPUS)
-    .NumOutputs(1, CAFFE2_COMPILE_TIME_MAX_GPUS)
+    .NumInputs(1, C10_COMPILE_TIME_MAX_GPUS)
+    .NumOutputs(1, C10_COMPILE_TIME_MAX_GPUS)
     .IdenticalTypeAndShape()
     .InputsCanCrossDevices()
     .EnforceOneToOneInplace()
@@ -235,7 +235,7 @@ SHOULD_NOT_DO_GRADIENT(NCCLBroadcast);
 
 REGISTER_CUDA_OPERATOR(NCCLReduce, NCCLReduceOp);
 OPERATOR_SCHEMA(NCCLReduce)
-    .NumInputs(1, CAFFE2_COMPILE_TIME_MAX_GPUS)
+    .NumInputs(1, C10_COMPILE_TIME_MAX_GPUS)
     .NumOutputs(1)
     .IdenticalTypeAndShapeOfInput(0)
     .InputsCanCrossDevices()
@@ -245,16 +245,16 @@ SHOULD_NOT_DO_GRADIENT(NCCLReduce);
 
 REGISTER_CUDA_OPERATOR(NCCLAllGather, NCCLAllGatherOp);
 OPERATOR_SCHEMA(NCCLAllGather)
-    .NumInputs(1, CAFFE2_COMPILE_TIME_MAX_GPUS)
-    .NumOutputs(1, CAFFE2_COMPILE_TIME_MAX_GPUS)
+    .NumInputs(1, C10_COMPILE_TIME_MAX_GPUS)
+    .NumOutputs(1, C10_COMPILE_TIME_MAX_GPUS)
     .InputsCanCrossDevices()
     .DeviceInferenceFunction(ncclOpDevInfer);
 SHOULD_NOT_DO_GRADIENT(NCCLAllGather);
 
 REGISTER_CUDA_OPERATOR(NCCLReduceScatter, NCCLReduceScatterOp);
 OPERATOR_SCHEMA(NCCLReduceScatter)
-    .NumInputs(1, CAFFE2_COMPILE_TIME_MAX_GPUS)
-    .NumOutputs(1, CAFFE2_COMPILE_TIME_MAX_GPUS)
+    .NumInputs(1, C10_COMPILE_TIME_MAX_GPUS)
+    .NumOutputs(1, C10_COMPILE_TIME_MAX_GPUS)
     .InputsCanCrossDevices()
     .DeviceInferenceFunction(ncclOpDevInfer);
 SHOULD_NOT_DO_GRADIENT(NCCLReduceScatter);

caffe2/core/common_gpu.h
index 1c2b487..4a69fc2 100644
@@ -25,6 +25,7 @@
 #include "caffe2/core/common.h"
 #include "caffe2/core/logging.h"
 
+#include "c10/cuda/CUDAMacros.h"
 #include "c10/cuda/CUDAMathCompat.h"
 
 // Defines CAFFE2_CUDA_EXPORT and CAFFE2_CUDA_IMPORT. On Windows, this
@@ -95,10 +96,6 @@ constexpr int kFp16CUDADevicePropMajor = 3;
 #endif // CUDA_VERSION >= 9000
 
 /**
- * The maximum number of GPUs that caffe2 recognizes.
- */
-#define CAFFE2_COMPILE_TIME_MAX_GPUS 16
-/**
  * The maximum number of peers that each gpu can have when doing p2p setup.
  * Currently, according to NVidia documentation, each device can support a
  * system-wide maximum of eight peer connections.

caffe2/core/context_gpu.cu
index 287e17e..8f3b65a 100644
@@ -178,8 +178,8 @@ static std::unordered_map<void*, uint8_t> g_cuda_device_affiliation;
 // Data structures for optional memory tracking. Access to these structures
 // is guarded by the CUDAContext::mutex.
 static std::unordered_map<void*, long> g_size_map;
-static std::vector<long> g_total_by_gpu_map(CAFFE2_COMPILE_TIME_MAX_GPUS, 0);
-static std::vector<long> g_max_by_gpu_map(CAFFE2_COMPILE_TIME_MAX_GPUS, 0);
+static std::vector<long> g_total_by_gpu_map(C10_COMPILE_TIME_MAX_GPUS, 0);
+static std::vector<long> g_max_by_gpu_map(C10_COMPILE_TIME_MAX_GPUS, 0);
 
 static long g_total_mem = 0;
 static long g_last_rep = 0;
@@ -207,11 +207,11 @@ static void Caffe2InitializeCuda() {
   // of GPUs.
   CAFFE_ENFORCE_LE(
       NumCudaDevices(),
-      CAFFE2_COMPILE_TIME_MAX_GPUS,
+      C10_COMPILE_TIME_MAX_GPUS,
       "Number of CUDA devices on the machine is larger than the compiled "
       "max number of gpus expected (",
-      CAFFE2_COMPILE_TIME_MAX_GPUS,
-      "). Increase that and recompile the caffe binary.");
+      C10_COMPILE_TIME_MAX_GPUS,
+      "). Increase that and recompile.");
 
   for (DeviceIndex i = 0; i < NumCudaDevices(); ++i) {
     DeviceGuard g(i);

caffe2/core/context_gpu.h
index fb7e59e..ff6f613 100644
@@ -57,7 +57,7 @@ class CAFFE2_CUDA_API ThreadLocalCUDAObjects {
 
  private:
   ThreadLocalCUDAObjects() {
-    for (DeviceIndex i = 0; i < CAFFE2_COMPILE_TIME_MAX_GPUS; ++i) {
+    for (DeviceIndex i = 0; i < C10_COMPILE_TIME_MAX_GPUS; ++i) {
       cuda_streams_[i] = vector<c10::cuda::CUDAStream>();
     }
   }
@@ -153,7 +153,7 @@ class CAFFE2_CUDA_API ThreadLocalCUDAObjects {
   // WARNING: mapping from logical stream ID to c10::cuda::CUDAStream
   // is NOT bijective; multiple logical stream IDs may map to the
   // same underlying stream ID.
-  vector<c10::cuda::CUDAStream> cuda_streams_[CAFFE2_COMPILE_TIME_MAX_GPUS];
+  vector<c10::cuda::CUDAStream> cuda_streams_[C10_COMPILE_TIME_MAX_GPUS];
   std::unordered_map<c10::cuda::CUDAStream, cublasHandle_t> cublas_handles_;
 #ifdef CAFFE2_USE_CUDNN
   std::unordered_map<c10::cuda::CUDAStream, cudnnHandle_t> cudnn_handles_;

caffe2/core/cudnn_wrappers.h
index e9e9d4a..8e4809a 100644
@@ -149,7 +149,7 @@ class CuDNNWrapper {
 
   using PerGPUCuDNNStates = std::array<
       std::array<SyncedCuDNNState, CAFFE2_COMPILE_TIME_MAX_CUDNN_STATES>,
-      CAFFE2_COMPILE_TIME_MAX_GPUS>;
+      C10_COMPILE_TIME_MAX_GPUS>;
   static PerGPUCuDNNStates& cudnn_states();
 
   C10_DISABLE_COPY_AND_ASSIGN(CuDNNWrapper);

caffe2/core/hip/miopen_wrapper.h
index 0074471..4eefb5e 100644
@@ -151,9 +151,9 @@ class MIOPENWrapper
         std::unique_ptr<MIOPENState> state;
     };
 
-    using PerGPUMIOPENStates =
-        std::array<std::array<SyncedMIOPENState, CAFFE2_COMPILE_TIME_MAX_MIOPEN_STATES>,
-                   CAFFE2_COMPILE_TIME_MAX_GPUS>;
+    using PerGPUMIOPENStates = std::array<
+        std::array<SyncedMIOPENState, CAFFE2_COMPILE_TIME_MAX_MIOPEN_STATES>,
+        C10_COMPILE_TIME_MAX_GPUS>;
     static PerGPUMIOPENStates& miopen_states();
 
     C10_DISABLE_COPY_AND_ASSIGN(MIOPENWrapper);