Add support for HIP to DispatchStub. (#14413)
authorEdward Yang <ezyang@fb.com>
Wed, 28 Nov 2018 19:05:36 +0000 (11:05 -0800)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Wed, 28 Nov 2018 19:07:45 +0000 (11:07 -0800)
Summary:
I feel a bit bad writing this patch, because there isn't really
any reason not to use the normal dispatch mechanism for CUDA
and HIP here (so we have *yet another dispatcher*), but I don't
really want to sign up to rewrite DispatchStub to deduplicate the
dispatcher right now.

Need to natively add support for HIP here, as I don't want to
have to HIPify files which are not in a CUDA directory.

Signed-off-by: Edward Z. Yang <ezyang@fb.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/14413

Differential Revision: D13220358

Pulled By: ezyang

fbshipit-source-id: cc61218322589a1dc2ab8eb9d5ddd3c616f6b712

aten/src/ATen/native/DispatchStub.h

index f12c98c..fca92da 100644 (file)
@@ -31,6 +31,9 @@
 //
 // To call:
 //   stub(kCPU, tensor);
+//
+// TODO: CPU instruction set selection should be folded into whatever
+// the main dispatch mechanism is.
 
 // ignore warnings about DispatchStub::DEFAULT, AVX, AVX2 defined elsewhere
 #if defined(__clang__)
@@ -66,6 +69,9 @@ struct CAFFE2_API DispatchStub<rT (*)(Args...), T> {
     } else if (device_type == DeviceType::CUDA) {
       AT_ASSERTM(cuda_dispatch_ptr, "DispatchStub: missing CUDA kernel");
       return (*cuda_dispatch_ptr)(std::forward<ArgTypes>(args)...);
+    } else if (device_type == DeviceType::HIP) {
+      AT_ASSERTM(hip_dispatch_ptr, "DispatchStub: missing HIP kernel");
+      return (*hip_dispatch_ptr)(std::forward<ArgTypes>(args)...);
     } else {
       AT_ERROR("DispatchStub: unsupported device type", device_type);
     }
@@ -92,6 +98,7 @@ struct CAFFE2_API DispatchStub<rT (*)(Args...), T> {
 
   FnPtr cpu_dispatch_ptr = nullptr;
   FnPtr cuda_dispatch_ptr = nullptr;
+  FnPtr hip_dispatch_ptr = nullptr;
   static FnPtr DEFAULT;
 #ifdef HAVE_AVX_CPU_DEFINITION
   static FnPtr AVX;
@@ -103,11 +110,18 @@ struct CAFFE2_API DispatchStub<rT (*)(Args...), T> {
 
 namespace {
 template <typename FnPtr, typename T>
-struct RegisterDispatch {
-  RegisterDispatch(DispatchStub<FnPtr, T>& stub, FnPtr value) {
+struct RegisterCUDADispatch {
+  RegisterCUDADispatch(DispatchStub<FnPtr, T>& stub, FnPtr value) {
     stub.cuda_dispatch_ptr = value;
   }
 };
+
+template <typename FnPtr, typename T>
+struct RegisterHIPDispatch {
+  RegisterHIPDispatch(DispatchStub<FnPtr, T>& stub, FnPtr value) {
+    stub.hip_dispatch_ptr = value;
+  }
+};
 } // anonymous namespace
 
 // Compiler will complain if you put things like std::tuple<Tensor, Tensor> in
@@ -142,10 +156,17 @@ struct RegisterDispatch {
   REGISTER_AVX2_DISPATCH(name, static_cast<fn_type>(nullptr))
 
 #define REGISTER_CUDA_DISPATCH(name, fn) \
-  static RegisterDispatch<decltype(fn), struct name> name ## __register(name, fn);
+  static RegisterCUDADispatch<decltype(fn), struct name> name ## __register(name, fn);
+
+#define REGISTER_HIP_DISPATCH(name, fn) \
+  static RegisterHIPDispatch<decltype(fn), struct name> name ## __register(name, fn);
 
+// NB: This macro must be used in an actual 'cu' file; if you try using
+// it from a 'cpp' file it will not work!
 #if defined(__CUDACC__)
 #define REGISTER_DISPATCH(name, fn) REGISTER_CUDA_DISPATCH(name, fn)
+#elif defined(__HIPCC__)
+#define REGISTER_DISPATCH(name, fn) REGISTER_HIP_DISPATCH(name, fn)
 #elif defined(CPU_CAPABILITY)
 #define REGISTER_DISPATCH(name, fn) REGISTER_ARCH_DISPATCH(name, CPU_CAPABILITY, fn)
 #endif