From de0784510d06504d0825112e003370070ecdcd7d Mon Sep 17 00:00:00 2001 From: bddppq Date: Thu, 13 Dec 2018 15:41:55 -0800 Subject: [PATCH] Remove disabled_features in hipify Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/15098 Reviewed By: ezyang Differential Revision: D13453762 Pulled By: bddppq fbshipit-source-id: e177042c78f5bf393163d660c25b80285353853d --- aten/src/ATen/native/Distributions.cpp | 14 ++- aten/src/ATen/native/Distributions.h | 92 +++++++++------- aten/src/ATen/native/cuda/Distributions.cu | 15 +-- aten/src/THC/THCBlas.cu | 26 ++++- aten/src/THC/THCGenerator.hpp | 4 +- aten/src/THC/THCTensorRandom.cpp | 2 +- aten/src/THC/THCTensorRandom.cu | 16 +-- aten/src/THC/THCTensorRandom.h | 8 +- aten/src/THCUNN/generic/RReLU.cu | 2 +- c10/cuda/CUDAMathCompat.h | 47 +++++++-- test/test_distributions.py | 1 - tools/amd_build/disabled_features.json | 127 +---------------------- tools/amd_build/pyHIPIFY/cuda_to_hip_mappings.py | 8 +- 13 files changed, 165 insertions(+), 197 deletions(-) diff --git a/aten/src/ATen/native/Distributions.cpp b/aten/src/ATen/native/Distributions.cpp index c81554d..30ffb76 100644 --- a/aten/src/ATen/native/Distributions.cpp +++ b/aten/src/ATen/native/Distributions.cpp @@ -209,13 +209,17 @@ Tensor _s_gamma_cpu(const Tensor& alpha, Generator *gen) { std::lock_guard lock(generator->mutex); CPU_tensor_apply2(ret, alpha, [generator](scalar_t& ret_val, const scalar_t& alpha){ - BaseSampler standard_uniform([generator] () { + + auto uniform_lambda = [generator] () { return THRandom_standard_uniform(generator); - }); - BaseSampler standard_normal([generator] () { + }; + BaseSampler standard_uniform(uniform_lambda); + + auto normal_lambda = [generator] () { return THRandom_normal(generator, 0.0, 1.0); - }); - auto sample = sample_gamma(alpha, standard_uniform, standard_normal); + }; + BaseSampler standard_normal(normal_lambda); + auto sample = sample_gamma(alpha, standard_uniform, standard_normal); ret_val = std::max(std::numeric_limits::min(), (scalar_t) sample); } ); diff --git a/aten/src/ATen/native/Distributions.h b/aten/src/ATen/native/Distributions.h index 0fe382a..31167d5 100644 --- a/aten/src/ATen/native/Distributions.h +++ b/aten/src/ATen/native/Distributions.h @@ -1,9 +1,6 @@ #pragma once #include -#ifdef __CUDA_ARCH__ -#include -#endif #include #include @@ -11,6 +8,8 @@ #include #include +#include + namespace at {namespace native { static inline THGenerator* get_generator(at::Generator* gen) { @@ -21,24 +20,41 @@ static inline THGenerator* get_generator(at::Generator* gen) { }} // namespace at::native +// ROCM hcc doesn't work well with using std:: in kernel functions +#if defined(__CUDA_ARCH__) || defined(__HIP_PLATFORM_HCC__) +#include +#define compat_exp c10::cuda::compat::exp +#define compat_floor c10::cuda::compat::floor +#define compat_log c10::cuda::compat::log +#define compat_pow c10::cuda::compat::pow +#define compat_sqrt c10::cuda::compat::sqrt +#define compat_tan c10::cuda::compat::tan +#else +#define compat_exp std::exp +#define compat_floor std::floor +#define compat_log std::log +#define compat_pow std::pow +#define compat_sqrt std::sqrt +#define compat_tan std::tan +#endif + namespace { -#ifdef __CUDA_ARCH__ -#define nvfunction_or_function nvstd::function -#define deviceforcuda __device__ -#else -#define nvfunction_or_function std::function -#define deviceforcuda +#if !defined(__CUDA_ARCH__) && !defined(__HIP_PLATFORM_HCC__) // we cannot use std::isnan directly due to some incompatibility of // gcc constexpr'ing and nvcc #define isnan std::isnan #endif -template +// Here sampler_t should be function type scalar_t(void). For gpu +// "sampler" is a device function, but since ROCM doesn't have +// equivalent to nvstd::function, we use a template type parameter to +// capture it. +template struct BaseSampler { - nvfunction_or_function sampler; - deviceforcuda BaseSampler(nvfunction_or_function sampler): sampler(sampler) {} - deviceforcuda scalar_t sample() { + sampler_t sampler; + C10_DEVICE BaseSampler(const sampler_t& sampler): sampler(sampler) {} + C10_DEVICE scalar_t sample() { return sampler(); } }; @@ -69,21 +85,21 @@ struct BaseSampler { * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. */ -template -deviceforcuda scalar_t sample_gamma(scalar_t alpha, BaseSampler& standard_uniform, BaseSampler& standard_normal) { +template +C10_DEVICE scalar_t sample_gamma(scalar_t alpha, BaseSampler& standard_uniform, BaseSampler& standard_normal) { accscalar_t scale = 1.0f; // Boost alpha for higher acceptance probability. if (alpha < 1.0f) { if (alpha == 0.f) return 0.f; - scale *= std::pow(1 - standard_uniform.sample(), 1.0f / alpha); + scale *= compat_pow(1 - standard_uniform.sample(), 1.0f / alpha); alpha += 1.0f; } // This implements the acceptance-rejection method of Marsaglia and Tsang (2000) // doi:10.1145/358407.358414 const accscalar_t d = alpha - 1.0f / 3.0f; - const accscalar_t c = 1.0f / std::sqrt(9.0f * d); + const accscalar_t c = 1.0f / compat_sqrt(9.0f * d); for (;;) { accscalar_t x, y; do { @@ -95,13 +111,13 @@ deviceforcuda scalar_t sample_gamma(scalar_t alpha, BaseSampler& st const accscalar_t xx = x * x; if (u < 1.0f - 0.0331f * xx * xx) return static_cast(scale * d * v); - if (std::log(u) < 0.5f * xx + d * (1.0f - v + std::log(v))) + if (compat_log(u) < 0.5f * xx + d * (1.0f - v + compat_log(v))) return static_cast(scale * d * v); } } template -deviceforcuda static inline scalar_t polevl(const scalar_t x, const scalar_t A[], size_t len) { +C10_DEVICE static inline scalar_t polevl(const scalar_t x, const scalar_t A[], size_t len) { scalar_t result = 0; for (size_t i = 0; i <= len; i++) { result = result * x + A[i]; @@ -118,20 +134,21 @@ deviceforcuda static inline scalar_t polevl(const scalar_t x, const scalar_t A[ * Copyright 1984, 1987, 1992, 2000 by Stephen L. Moshier */ template -deviceforcuda static inline scalar_t digamma_one(scalar_t x) { +C10_DEVICE static inline scalar_t digamma_one(scalar_t x) { constexpr accscalar_t PSI_10 = 2.25175258906672110764; if (x == 0) { return INFINITY; } accscalar_t additional_summand = 0; - int x_is_integer = x == std::floor(x); + int x_is_integer = x == compat_floor(x); if (x < 0) { if (x_is_integer) { return INFINITY; } // it is more standard to write this as recursion, but // nvcc does not like that - additional_summand = - static_cast(M_PI) / std::tan(static_cast(M_PI) * x); + additional_summand = -static_cast(M_PI) / + compat_tan(static_cast(M_PI) * x); x = 1 - x; } @@ -161,13 +178,14 @@ deviceforcuda static inline scalar_t digamma_one(scalar_t x) { accscalar_t z = 1.0 / (x * x); y = z * polevl(z, A, 6); } - return static_cast(result + std::log(x) - (0.5f / x) - y + additional_summand); + return static_cast( + result + compat_log(x) - (0.5f / x) - y + additional_summand); } // Computes the reparameterized gradient -(d/dalpha cdf(x;alpha)) / pdf(x;alpha) // for random number x drawn from a standard Gamma distribution Gamma(alpha). template -deviceforcuda scalar_t standard_gamma_grad_one(scalar_t alpha_, scalar_t x_) { +C10_DEVICE scalar_t standard_gamma_grad_one(scalar_t alpha_, scalar_t x_) { // Use a Taylor series expansion for small x. accscalar_t x = static_cast(x_); accscalar_t alpha = static_cast(alpha_); @@ -182,11 +200,13 @@ deviceforcuda scalar_t standard_gamma_grad_one(scalar_t alpha_, scalar_t x_) { series1 += numer / denom; series2 += numer / (denom * denom); } - const auto pow_x_alpha = std::pow(x, alpha); - const auto gamma_pdf = std::pow(x, alpha - 1) * std::exp(-x); + const auto pow_x_alpha = compat_pow(x, alpha); + const auto gamma_pdf = compat_pow(x, alpha - 1) * compat_exp(-x); const auto gamma_cdf = pow_x_alpha * series1; - const auto gamma_cdf_alpha = (std::log(x) - digamma_one(alpha)) * gamma_cdf - - pow_x_alpha * series2; + const auto gamma_cdf_alpha = + (compat_log(x) - digamma_one(alpha)) * + gamma_cdf - + pow_x_alpha * series2; const auto result = -gamma_cdf_alpha / gamma_pdf; return isnan(result) ? static_cast( 0.f ) : static_cast(result); } @@ -200,20 +220,22 @@ deviceforcuda scalar_t standard_gamma_grad_one(scalar_t alpha_, scalar_t x_) { const auto denom = 1244160 * (alpha * alpha) * (alpha * alpha); return static_cast(numer_1 * numer_2 / denom); } - const auto denom = std::sqrt(8 * alpha); + const auto denom = compat_sqrt(8 * alpha); const auto term2 = denom / (alpha - x); - const auto term3 = std::pow(x - alpha - alpha * std::log(x / alpha), static_cast(-1.5)); + const auto term3 = compat_pow( + x - alpha - alpha * compat_log(x / alpha), + static_cast(-1.5)); const auto term23 = (x < alpha) ? term2 - term3 : term2 + term3; - const auto term1 = std::log(x / alpha) * term23 - - std::sqrt(2 / alpha) * (alpha + x) / ((alpha - x) * (alpha - x)); + const auto term1 = compat_log(x / alpha) * term23 - + compat_sqrt(2 / alpha) * (alpha + x) / ((alpha - x) * (alpha - x)); const auto stirling = 1 + 1 / (12 * alpha) * (1 + 1 / (24 * alpha)); const auto numer = x * term1; return static_cast(-stirling * numer / denom); } // Use a bivariate rational approximation to the reparameterized gradient. - const auto u = std::log(x / alpha); - const auto v = std::log(alpha); + const auto u = compat_log(x / alpha); + const auto v = compat_log(alpha); static const accscalar_t coef_uv[3][8] = { {0.16009398, -0.094634809, 0.025146376, -0.0030648343, 1, 0.32668115, 0.10406089, 0.0014179084}, @@ -228,7 +250,7 @@ deviceforcuda scalar_t standard_gamma_grad_one(scalar_t alpha_, scalar_t x_) { } const auto p = coef_v[0] + v * (coef_v[1] + v * (coef_v[2] + v * coef_v[3])); const auto q = coef_v[4] + v * (coef_v[5] + v * (coef_v[6] + v * coef_v[7])); - return static_cast(std::exp(p / q)); + return static_cast(compat_exp(p / q)); } } // namespace diff --git a/aten/src/ATen/native/cuda/Distributions.cu b/aten/src/ATen/native/cuda/Distributions.cu index 4ee3bb9..00a1b34 100644 --- a/aten/src/ATen/native/cuda/Distributions.cu +++ b/aten/src/ATen/native/cuda/Distributions.cu @@ -9,7 +9,6 @@ #include #include #include -#include #include @@ -72,13 +71,17 @@ void gamma_cuda_kernel( blockIdx.x * blockDim.x + threadIdx.x, seeds.second, &state); - BaseSampler standard_uniform([&state] __device__ () { + + auto uniform_lambda = [&state] __device__ () { return curand_uniform(&state); - }); - BaseSampler standard_normal([&state] __device__ () { + }; + BaseSampler standard_uniform(uniform_lambda); + + auto normal_lambda = [&state] __device__ () { return curand_normal(&state); - }); - auto sample = sample_gamma(alpha, standard_uniform, standard_normal); + }; + BaseSampler standard_normal(normal_lambda); + auto sample = sample_gamma(alpha, standard_uniform, standard_normal); auto min_value = std::numeric_limits::lowest(); ret_val = (min_value > sample) ? min_value : sample; }); diff --git a/aten/src/THC/THCBlas.cu b/aten/src/THC/THCBlas.cu index 73c3b00..51ae225 100644 --- a/aten/src/THC/THCBlas.cu +++ b/aten/src/THC/THCBlas.cu @@ -509,6 +509,7 @@ void THCudaBlas_DgemmStridedBatched(THCState *state, char transa, char transb, i /* Inverse */ void THCudaBlas_Sgetrf(THCState *state, int n, float **a, int lda, int *pivot, int *info, int batchSize) { +#ifndef __HIP_PLATFORM_HCC__ if( (n >= INT_MAX) || (lda >= INT_MAX) || (batchSize >= INT_MAX) ) { THError("Cublas_Sgetrf only supports n, lda, batchSize" @@ -517,9 +518,13 @@ void THCudaBlas_Sgetrf(THCState *state, int n, float **a, int lda, int *pivot, i cublasHandle_t handle = THCState_getCurrentBlasHandle(state); cublasSetStream(handle, THCState_getCurrentStream(state)); THCublasCheck(cublasSgetrfBatched(handle, n, a, lda, pivot, info, batchSize)); +#else + THError("THCudaBlas_Sgetrf not supported in ROCM."); +#endif } void THCudaBlas_Dgetrf(THCState *state, int n, double **a, int lda, int *pivot, int *info, int batchSize) { +#ifndef __HIP_PLATFORM_HCC__ if( (n >= INT_MAX) || (lda >= INT_MAX) || (batchSize >= INT_MAX) ) { THError("Cublas_Dgetrf only supports n, lda, batchSize" @@ -528,10 +533,14 @@ void THCudaBlas_Dgetrf(THCState *state, int n, double **a, int lda, int *pivot, cublasHandle_t handle = THCState_getCurrentBlasHandle(state); cublasSetStream(handle, THCState_getCurrentStream(state)); THCublasCheck(cublasDgetrfBatched(handle, n, a, lda, pivot, info, batchSize)); +#else + THError("THCudaBlas_Dgetrf not supported in ROCM."); +#endif } void THCudaBlas_Sgetrs(THCState *state, char transa, int n, int nrhs, const float **a, int lda, int *pivot, float **b, int ldb, int *info, int batchSize) { +#ifndef __HIP_PLATFORM_HCC__ if( (n >= INT_MAX) || (nrhs >= INT_MAX) || (lda >= INT_MAX) || (ldb >= INT_MAX) || (batchSize >= INT_MAX) ) { THError("Cublas_Dgetrs only supports n, nrhs, lda, ldb, batchSize" @@ -544,11 +553,15 @@ void THCudaBlas_Sgetrs(THCState *state, char transa, int n, int nrhs, const floa cublasHandle_t handle = THCState_getCurrentBlasHandle(state); cublasSetStream(handle, THCState_getCurrentStream(state)); THCublasCheck(cublasSgetrsBatched(handle, opa, n, nrhs, a, lda, pivot, b, ldb, info, batchSize)); +#else + THError("THCudaBlas_Sgetrs not supported in ROCM."); +#endif } void THCudaBlas_Dgetrs(THCState *state, char transa, int n, int nrhs, const double **a, int lda, int *pivot, double **b, int ldb, int *info, int batchSize) { +#ifndef __HIP_PLATFORM_HCC__ if( (n >= INT_MAX) || (nrhs >= INT_MAX) || (lda >= INT_MAX) || (ldb >= INT_MAX) || (batchSize >= INT_MAX) ) { THError("Cublas_Dgetrs only supports n, nrhs, lda, ldb, batchSize" @@ -561,10 +574,13 @@ void THCudaBlas_Dgetrs(THCState *state, char transa, int n, int nrhs, const doub cublasHandle_t handle = THCState_getCurrentBlasHandle(state); cublasSetStream(handle, THCState_getCurrentStream(state)); THCublasCheck(cublasDgetrsBatched(handle, opa, n, nrhs, a, lda, pivot, b, ldb, info, batchSize)); +#else + THError("THCudaBlas_Dgetrs not supported in ROCM."); +#endif } void THCudaBlas_Sgetri(THCState *state, int n, const float **a, int lda, int *pivot, float **c, int ldc, int *info, int batchSize) { - +#ifndef __HIP_PLATFORM_HCC__ if( (n >= INT_MAX) || (lda >= INT_MAX)|| (ldc >= INT_MAX) || (batchSize >= INT_MAX) ) { THError("Cublas_Sgetri only supports n, lda, ldc, batchSize" @@ -573,10 +589,13 @@ void THCudaBlas_Sgetri(THCState *state, int n, const float **a, int lda, int *pi cublasHandle_t handle = THCState_getCurrentBlasHandle(state); cublasSetStream(handle, THCState_getCurrentStream(state)); THCublasCheck(cublasSgetriBatched(handle, n, a, lda, pivot, c, ldc, info, batchSize)); +#else + THError("THCudaBlas_Sgetri not supported in ROCM."); +#endif } void THCudaBlas_Dgetri(THCState *state, int n, const double **a, int lda, int *pivot, double **c, int ldc, int *info, int batchSize) { - +#ifndef __HIP_PLATFORM_HCC__ if( (n >= INT_MAX) || (lda >= INT_MAX)|| (ldc >= INT_MAX) || (batchSize >= INT_MAX) ) { THError("Cublas_Dgetri only supports n, lda, ldc, batchSize" @@ -585,4 +604,7 @@ void THCudaBlas_Dgetri(THCState *state, int n, const double **a, int lda, int *p cublasHandle_t handle = THCState_getCurrentBlasHandle(state); cublasSetStream(handle, THCState_getCurrentStream(state)); THCublasCheck(cublasDgetriBatched(handle, n, a, lda, pivot, c, ldc, info, batchSize)); +#else + THError("THCudaBlas_Dgetri not supported in ROCM."); +#endif } diff --git a/aten/src/THC/THCGenerator.hpp b/aten/src/THC/THCGenerator.hpp index ea5d1ba..f1c4119 100644 --- a/aten/src/THC/THCGenerator.hpp +++ b/aten/src/THC/THCGenerator.hpp @@ -7,8 +7,8 @@ #include typedef struct THCGeneratorState { - struct curandStateMtgp32* gen_states; - struct mtgp32_kernel_params *kernel_params; + curandStateMtgp32* gen_states; + mtgp32_kernel_params *kernel_params; int initf; uint64_t initial_seed; std::atomic philox_seed_offset; diff --git a/aten/src/THC/THCTensorRandom.cpp b/aten/src/THC/THCTensorRandom.cpp index 5853d9c..e3cf5d9 100644 --- a/aten/src/THC/THCTensorRandom.cpp +++ b/aten/src/THC/THCTensorRandom.cpp @@ -87,7 +87,7 @@ THCGenerator* THCRandom_getGenerator(THCState* state) return gen; } -struct curandStateMtgp32* THCRandom_generatorStates(struct THCState* state) +curandStateMtgp32* THCRandom_generatorStates(THCState* state) { THCGenerator* gen = THCRandom_getGenerator(state); return gen->state.gen_states; diff --git a/aten/src/THC/THCTensorRandom.cu b/aten/src/THC/THCTensorRandom.cu index 69228ae..58bbabc 100644 --- a/aten/src/THC/THCTensorRandom.cu +++ b/aten/src/THC/THCTensorRandom.cu @@ -11,8 +11,6 @@ #include #include -#include -#include #define MAX_NUM_BLOCKS 200 #define BLOCK_SIZE 256 @@ -23,7 +21,7 @@ THCGenerator* THCRandom_getGenerator(THCState* state); /* Sets up generator. Allocates but does not create the generator states. Not thread-safe. */ __host__ void initializeGenerator(THCState *state, THCGenerator* gen) { - gen->state.gen_states = static_cast(THCudaMalloc(state, MAX_NUM_BLOCKS * sizeof(curandStateMtgp32))); + gen->state.gen_states = static_cast(THCudaMalloc(state, MAX_NUM_BLOCKS * sizeof(curandStateMtgp32))); gen->state.kernel_params = static_cast(THCudaMalloc(state, sizeof(mtgp32_kernel_params))); } @@ -44,7 +42,7 @@ __host__ void createGeneratorState(THCGenerator* gen, uint64_t seed) gen->state.philox_seed_offset = 0; } -__host__ void THCRandom_getRNGState(THCState* state, THByteTensor *rng_state) +THC_API __host__ void THCRandom_getRNGState(THCState* state, THByteTensor *rng_state) { THCGenerator* gen = THCRandom_getGenerator(state); std::lock_guard lock(gen->mutex); @@ -65,10 +63,14 @@ __host__ void THCRandom_getRNGState(THCState* state, THByteTensor *rng_state) __global__ void set_rngstate_kernel(curandStateMtgp32 *state, mtgp32_kernel_params *kernel) { +#ifndef __HIP_PLATFORM_HCC__ state[threadIdx.x].k = kernel; +#else + state[threadIdx.x].set_params(kernel); +#endif } -__host__ void THCRandom_setRNGState(THCState* state, THByteTensor *rng_state) +THC_API __host__ void THCRandom_setRNGState(THCState* state, THByteTensor *rng_state) { THCGenerator* gen = THCRandom_getGenerator(state); std::lock_guard lock(gen->mutex); @@ -118,7 +120,7 @@ __device__ inline at::Half half_uniform_scale_and_shift(float x, double a, doubl } #define GENERATE_KERNEL1(NAME, T, ARG1, CURAND_T, CURAND_FUNC, TRANSFORM) \ -__global__ void NAME(curandStateMtgp32 *state, int size, T *result, ARG1) \ +__global__ void NAME(curandStateMtgp32 *state, int size, T *result, ARG1) \ { \ int idx = blockIdx.x * BLOCK_SIZE + threadIdx.x; \ int rounded_size = THCCeilDiv(size, BLOCK_SIZE) * BLOCK_SIZE; \ @@ -132,7 +134,7 @@ __global__ void NAME(curandStateMtgp32 *state, int size, T *result, ARG1) \ } #define GENERATE_KERNEL2(NAME, T, ARG1, ARG2, CURAND_T, CURAND_FUNC, TRANSFORM) \ -__global__ void NAME(curandStateMtgp32 *state, int size, T *result, ARG1, ARG2) \ +__global__ void NAME(curandStateMtgp32 *state, int size, T *result, ARG1, ARG2) \ { \ int idx = blockIdx.x * BLOCK_SIZE + threadIdx.x; \ int rounded_size = THCCeilDiv(size, BLOCK_SIZE) * BLOCK_SIZE; \ diff --git a/aten/src/THC/THCTensorRandom.h b/aten/src/THC/THCTensorRandom.h index 0dc8c2a..265b7ce 100644 --- a/aten/src/THC/THCTensorRandom.h +++ b/aten/src/THC/THCTensorRandom.h @@ -5,9 +5,9 @@ #include #include -#ifdef __HIP_PLATFORM_HCC__ -#include -#endif + +#include +#include typedef struct THCGenerator THCGenerator; @@ -29,6 +29,6 @@ THC_API uint64_t THCRandom_initialSeed(struct THCState *state); THC_API void THCRandom_getRNGState(struct THCState *state, THByteTensor *rng_state); THC_API void THCRandom_setRNGState(struct THCState *state, THByteTensor *rng_state); -THC_API struct curandStateMtgp32* THCRandom_generatorStates(struct THCState* state); +THC_API curandStateMtgp32* THCRandom_generatorStates(struct THCState* state); #endif diff --git a/aten/src/THCUNN/generic/RReLU.cu b/aten/src/THCUNN/generic/RReLU.cu index 2cbc4b9..654ea14 100644 --- a/aten/src/THCUNN/generic/RReLU.cu +++ b/aten/src/THCUNN/generic/RReLU.cu @@ -16,7 +16,7 @@ void THNN_(RReLU_updateOutput)( void *generator) { THCUNN_assertSameGPU(state, 3, input, output, noise); - struct curandStateMtgp32* gen_states = THCRandom_generatorStates(state); + curandStateMtgp32* gen_states = THCRandom_generatorStates(state); if (train) { diff --git a/c10/cuda/CUDAMathCompat.h b/c10/cuda/CUDAMathCompat.h index 35176e7..6356515 100644 --- a/c10/cuda/CUDAMathCompat.h +++ b/c10/cuda/CUDAMathCompat.h @@ -22,33 +22,68 @@ namespace cuda { namespace compat { __MATH_FUNCTIONS_DECL__ float abs(float x) { - return fabsf(x); + return ::fabsf(x); } __MATH_FUNCTIONS_DECL__ double abs(double x) { - return fabs(x); + return ::fabs(x); +} + +__MATH_FUNCTIONS_DECL__ float exp(float x) { + return ::expf(x); +} +__MATH_FUNCTIONS_DECL__ double exp(double x) { + return ::exp(x); +} + +__MATH_FUNCTIONS_DECL__ float floor(float x) { + return ::floorf(x); +} +__MATH_FUNCTIONS_DECL__ double floor(double x) { + return ::floor(x); +} + +__MATH_FUNCTIONS_DECL__ float log(float x) { + return ::logf(x); +} +__MATH_FUNCTIONS_DECL__ double log(double x) { + return ::log(x); } __MATH_FUNCTIONS_DECL__ float max(float x, float y) { - return fmaxf(x, y); + return ::fmaxf(x, y); } __MATH_FUNCTIONS_DECL__ double max(double x, double y) { - return fmax(x, y); + return ::fmax(x, y); } __MATH_FUNCTIONS_DECL__ float pow(float x, float y) { - return powf(x, y); + return ::powf(x, y); } __MATH_FUNCTIONS_DECL__ double pow(double x, double y) { return ::pow(x, y); } __MATH_FUNCTIONS_DECL__ void sincos(float x, float* sptr, float* cptr) { - return sincosf(x, sptr, cptr); + return ::sincosf(x, sptr, cptr); } __MATH_FUNCTIONS_DECL__ void sincos(double x, double* sptr, double* cptr) { return ::sincos(x, sptr, cptr); } +__MATH_FUNCTIONS_DECL__ float sqrt(float x) { + return ::sqrtf(x); +} +__MATH_FUNCTIONS_DECL__ double sqrt(double x) { + return ::sqrt(x); +} + +__MATH_FUNCTIONS_DECL__ float tan(float x) { + return ::tanf(x); +} +__MATH_FUNCTIONS_DECL__ double tan(double x) { + return ::tan(x); +} + } // namespace compat } // namespace cuda } // namespace c10 diff --git a/test/test_distributions.py b/test/test_distributions.py index de1c43e..993e5d9 100644 --- a/test/test_distributions.py +++ b/test/test_distributions.py @@ -1975,7 +1975,6 @@ class TestDistributions(TestCase): @unittest.skipIf(not TEST_CUDA, "CUDA not found") @unittest.skipIf(not TEST_NUMPY, "Numpy not found") - @skipIfRocm def test_gamma_gpu_sample(self): set_rng_seed(0) for alpha, beta in product([0.1, 1.0, 5.0], [0.1, 1.0, 10.0]): diff --git a/tools/amd_build/disabled_features.json b/tools/amd_build/disabled_features.json index f7864dc..c7228b7 100644 --- a/tools/amd_build/disabled_features.json +++ b/tools/amd_build/disabled_features.json @@ -1,131 +1,8 @@ { - "disable_unsupported_hip_calls": - [ - { - "path": "aten/src/THC/THCBlas.cu", - "functions": { - "cublasSgemmEx": "rocblas_status_internal_error", - "cublasSgetrfBatched": "rocblas_status_internal_error", - "cublasDgetrfBatched": "rocblas_status_internal_error", - "cublasSgetrsBatched": "rocblas_status_internal_error", - "cublasDgetrsBatched": "rocblas_status_internal_error", - "cublasSgetriBatched": "rocblas_status_internal_error", - "cublasDgetriBatched": "rocblas_status_internal_error" - } - }, - { - "path": "aten/src/THC/THCStream.cpp", - "functions": { - "cudaStreamCreateWithFlags": "hipSuccess", - "cudaStreamCreateWithPriority": "hipSuccess" - } - }, - { - "path": "aten/src/THC/THCAllocator.cpp", - "functions": { - "cudaMallocManaged": "hipSuccess" - } - }, - { - "path": "aten/src/ATen/native/cuda/Distributions.cu", - "s_constants": { - "#include ": "" - } - }, - { - "path": "aten/src/ATen/native/cuda/RoiPooling.cu", - "s_constants": { - "RoiPooling2d_forward_kernel<<<": "RoiPooling2d_forward_kernel<<<" - } - }, - { - "path": "aten/src/THC/THCTensorRandom.cpp", - "s_constants": { - "struct curandStateMtgp32*": "curandStateMtgp32*" - } - }, - { - "path": "aten/src/THC/THCTensorRandom.cu", - "s_constants": { - "struct curandStateMtgp32*": "curandStateMtgp32*", - "__host__ void THCRandom_getRNGState": "extern \"C\" __host__ void THCRandom_getRNGState", - "__host__ void THCRandom_setRNGState": "extern \"C\" __host__ void THCRandom_setRNGState", - "state[threadIdx.x].k = kernel;" : "state[threadIdx.x].set_params(kernel);" - } - }, - { - "path": "aten/src/THC/THCTensorRandom.h", - "s_constants": { - "struct curandStateMtgp32*": "curandStateMtgp32*" - } - }, - { - "path": "aten/src/THCUNN/generic/RReLU.cu", - "s_constants": { - "struct curandStateMtgp32*": "curandStateMtgp32*" - } - }, - { - "path": "aten/src/THC/THCGenerator.hpp", - "s_constants": { - "struct curandStateMtgp32*": "curandStateMtgp32*", - "struct mtgp32_kernel_params": "mtgp32_kernel_params" - } - }, - { - "path": "aten/src/ATen/native/cuda/RoiPooling.cu", - "s_constants": { - "RoiPooling2d_backward_kernel<<<": "RoiPooling2d_backward_kernel<<<" - } - }, - { - "path": "aten/src/ATen/native/cuda/Unique.cu", - "s_constants": { - "inverse_indices_kernel<<<": "inverse_indices_kernel<<<" - } - } - ], + "disable_unsupported_hip_calls": [ + ], "disabled_modules": [ ], "disabled_functions": [ - { - "path": "aten/src/ATen/cuda/CUDAApplyUtils.cuh", - "functions": [ - "kernelPointwiseApply4" - ] - }, - { - "path": "aten/src/THCUNN/LookupTable.cu", - "functions": [ - "warpHasCollision" - ] - }, - { - "path": "aten/src/ATen/native/cuda/Distributions.cu", - "functions": [ - "gamma_cuda_kernel", - "gamma_grad_cuda_kernel" - ] - }, - { - "path": "aten/src/THCUNN/generic/SparseLinear.cu", - "functions": [ - "THNN_(SparseLinear_updateOutput)", - "THNN_(SparseLinear_accGradParameters)" - ] - }, - { - "path": "aten/src/THCUNN/generic/LookupTable.cu", - "functions": [ - "THNN_(LookupTable_accGradParameters)", - "THNN_(LookupTable_renorm)" - ] - }, - { - "path": "aten/src/THC/generic/THCTensor.cu", - "functions": [ - "THCTensor_(getTextureObject)" - ] - } ] } diff --git a/tools/amd_build/pyHIPIFY/cuda_to_hip_mappings.py b/tools/amd_build/pyHIPIFY/cuda_to_hip_mappings.py index 53cfeb3..87931f1 100644 --- a/tools/amd_build/pyHIPIFY/cuda_to_hip_mappings.py +++ b/tools/amd_build/pyHIPIFY/cuda_to_hip_mappings.py @@ -276,7 +276,6 @@ CUDA_INCLUDE_MAP = collections.OrderedDict([ ("cusparse.h", ("hipsparse.h", CONV_INCLUDE, API_RAND)), ("cufft.h", ("hipfft.h", CONV_INCLUDE, API_BLAS)), ("cufftXt.h", ("hipfft.h", CONV_INCLUDE, API_BLAS)), - ("#include ", ("", CONV_INCLUDE, API_RAND, HIP_UNSUPPORTED)), ]) CUDA_IDENTIFIER_MAP = collections.OrderedDict([ @@ -2179,13 +2178,16 @@ CUDA_SPARSE_MAP = collections.OrderedDict([ ("cusparseStatus_t", ("hipsparseStatus_t", CONV_MATH_FUNC, API_SPARSE)), ("cusparseHandle_t", ("hipsparseHandle_t", CONV_MATH_FUNC, API_SPARSE)), ("cusparseOperation_t", ("hipsparseOperation_t", CONV_TYPE, API_SPARSE)), + ("cusparseCreateMatDescr", ("hipsparseCreateMatDescr", CONV_MATH_FUNC, API_SPARSE)), ("cusparseCreate", ("hipsparseCreate", CONV_MATH_FUNC, API_SPARSE)), + ("cusparseDestroyMatDescr", ("hipsparseDestroyMatDescr", CONV_MATH_FUNC, API_SPARSE)), ("cusparseDestroy", ("hipsparseDestroy", CONV_MATH_FUNC, API_SPARSE)), ("cusparseXcoo2csr", ("hipsparseXcoo2csr", CONV_MATH_FUNC, API_SPARSE)), ("cusparseMatDescr_t", ("hipsparseMatDescr_t", CONV_MATH_FUNC, API_SPARSE)), - ("cusparseCreateMatDescr", ("hipsparseCreateMatDescr", CONV_MATH_FUNC, API_SPARSE)), ("cusparseScsrmm2", ("hipsparseScsrmm2", CONV_MATH_FUNC, API_SPARSE)), ("cusparseDcsrmm2", ("hipsparseDcsrmm2", CONV_MATH_FUNC, API_SPARSE)), + ("cusparseScsrmm", ("hipsparseScsrmm", CONV_MATH_FUNC, API_SPARSE)), + ("cusparseDcsrmm", ("hipsparseDcsrmm", CONV_MATH_FUNC, API_SPARSE)), ("cusparseXcsrsort_bufferSizeExt", ("hipsparseXcsrsort_bufferSizeExt", CONV_MATH_FUNC, API_SPARSE)), ("cusparseXcsrsort", ("hipsparseXcsrsort", CONV_MATH_FUNC, API_SPARSE)), ("cusparseXcoosort_bufferSizeExt", ("hipsparseXcoosort_bufferSizeExt", CONV_MATH_FUNC, API_SPARSE)), @@ -2193,6 +2195,7 @@ CUDA_SPARSE_MAP = collections.OrderedDict([ ("cusparseSetStream", ("hipsparseSetStream", CONV_MATH_FUNC, API_SPARSE)), ("cusparseCreateIdentityPermutation", ("hipsparseCreateIdentityPermutation", CONV_MATH_FUNC, API_SPARSE)), ("cusparseSetMatIndexBase", ("hipsparseSetMatIndexBase", CONV_MATH_FUNC, API_SPARSE)), + ("cusparseSetMatType", ("hipsparseSetMatType", CONV_MATH_FUNC, API_SPARSE)), ("CUSPARSE_STATUS_SUCCESS", ("HIPSPARSE_STATUS_SUCCESS", CONV_NUMERIC_LITERAL, API_SPARSE)), ("CUSPARSE_STATUS_NOT_INITIALIZED", ("HIPSPARSE_STATUS_NOT_INITIALIZED", CONV_NUMERIC_LITERAL, API_SPARSE)), ("CUSPARSE_STATUS_ALLOC_FAILED", ("HIPSPARSE_STATUS_ALLOC_FAILED", CONV_NUMERIC_LITERAL, API_SPARSE)), @@ -2208,6 +2211,7 @@ CUDA_SPARSE_MAP = collections.OrderedDict([ ("CUSPARSE_OPERATION_CONJUGATE_TRANSPOSE", ("HIPSPARSE_OPERATION_CONJUGATE_TRANSPOSE", CONV_NUMERIC_LITERAL, API_SPARSE)), ("CUSPARSE_INDEX_BASE_ZERO", ("HIPSPARSE_INDEX_BASE_ZERO", CONV_NUMERIC_LITERAL, API_SPARSE)), ("CUSPARSE_INDEX_BASE_ONE", ("HIPSPARSE_INDEX_BASE_ONE", CONV_NUMERIC_LITERAL, API_SPARSE)), + ("CUSPARSE_MATRIX_TYPE_GENERAL", ("HIPSPARSE_MATRIX_TYPE_GENERAL", CONV_NUMERIC_LITERAL, API_SPARSE)), ]) PYTORCH_SPECIFIC_MAPPINGS = collections.OrderedDict([ -- 2.7.4