//
// To call:
// stub(kCPU, tensor);
+//
+// TODO: CPU instruction set selection should be folded into whatever
+// the main dispatch mechanism is.
// ignore warnings about DispatchStub::DEFAULT, AVX, AVX2 defined elsewhere
#if defined(__clang__)
} else if (device_type == DeviceType::CUDA) {
AT_ASSERTM(cuda_dispatch_ptr, "DispatchStub: missing CUDA kernel");
return (*cuda_dispatch_ptr)(std::forward<ArgTypes>(args)...);
+ } else if (device_type == DeviceType::HIP) {
+ AT_ASSERTM(hip_dispatch_ptr, "DispatchStub: missing HIP kernel");
+ return (*hip_dispatch_ptr)(std::forward<ArgTypes>(args)...);
} else {
AT_ERROR("DispatchStub: unsupported device type", device_type);
}
FnPtr cpu_dispatch_ptr = nullptr;
FnPtr cuda_dispatch_ptr = nullptr;
+ FnPtr hip_dispatch_ptr = nullptr;
static FnPtr DEFAULT;
#ifdef HAVE_AVX_CPU_DEFINITION
static FnPtr AVX;
namespace {
template <typename FnPtr, typename T>
-struct RegisterDispatch {
- RegisterDispatch(DispatchStub<FnPtr, T>& stub, FnPtr value) {
+struct RegisterCUDADispatch {
+ RegisterCUDADispatch(DispatchStub<FnPtr, T>& stub, FnPtr value) {
stub.cuda_dispatch_ptr = value;
}
};
+
+template <typename FnPtr, typename T>
+struct RegisterHIPDispatch {
+ RegisterHIPDispatch(DispatchStub<FnPtr, T>& stub, FnPtr value) {
+ stub.hip_dispatch_ptr = value;
+ }
+};
} // anonymous namespace
// Compiler will complain if you put things like std::tuple<Tensor, Tensor> in
REGISTER_AVX2_DISPATCH(name, static_cast<fn_type>(nullptr))
#define REGISTER_CUDA_DISPATCH(name, fn) \
- static RegisterDispatch<decltype(fn), struct name> name ## __register(name, fn);
+ static RegisterCUDADispatch<decltype(fn), struct name> name ## __register(name, fn);
+
+#define REGISTER_HIP_DISPATCH(name, fn) \
+ static RegisterHIPDispatch<decltype(fn), struct name> name ## __register(name, fn);
+// NB: This macro must be used in an actual 'cu' file; if you try using
+// it from a 'cpp' file it will not work!
#if defined(__CUDACC__)
#define REGISTER_DISPATCH(name, fn) REGISTER_CUDA_DISPATCH(name, fn)
+#elif defined(__HIPCC__)
+#define REGISTER_DISPATCH(name, fn) REGISTER_HIP_DISPATCH(name, fn)
#elif defined(CPU_CAPABILITY)
#define REGISTER_DISPATCH(name, fn) REGISTER_ARCH_DISPATCH(name, CPU_CAPABILITY, fn)
#endif