fix cuda half math function is undefined: hpow, htanh (#6253)
authorcloud-mxd <maxiandi@bytedance.com>
Thu, 13 Aug 2020 15:36:59 +0000 (23:36 +0800)
committerGitHub <noreply@github.com>
Thu, 13 Aug 2020 15:36:59 +0000 (08:36 -0700)
src/target/source/literal/cuda_half_t.h

index baf4ba7..f8e92d5 100644 (file)
@@ -293,6 +293,22 @@ __pack_half2(const half x, const half y) {
   unsigned v1 = *((unsigned short *)&y);
   return (v1 << 16) | v0;
 }
+
+// fix undefined fp16 match function
+#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530)
+static inline __device__ __host__ half hpow(half x, half y) {
+  float tmp_x = __half2float(x);
+  float tmp_y = __half2float(y);
+  float result = powf(tmp_x, tmp_y);
+  return __float2half(result);
+}
+
+static inline __device__ __host__ half htanh(half x) {
+  float tmp_x = __half2float(x);
+  float result = tanhf(tmp_x);
+  return __float2half(result);
+}
+#endif
 )";
 
 static constexpr const char* _cuda_warp_intrinsic_util = R"(