From cf83d50cee669d4c6d2ac2bcb5078b9bd1b18e1d Mon Sep 17 00:00:00 2001 From: Yizhi Liu Date: Thu, 14 Nov 2019 09:17:42 -0800 Subject: [PATCH] [Codegen] remove fp16 function override for cuda (#4331) * add volatile override back * [codegen] remove fp16 function override for cuda --- src/codegen/codegen_cuda.cc | 22 +++++++++++++--------- src/codegen/literal/cuda_half_t.h | 3 ++- 2 files changed, 15 insertions(+), 10 deletions(-) diff --git a/src/codegen/codegen_cuda.cc b/src/codegen/codegen_cuda.cc index 22e8d84..2a41282 100644 --- a/src/codegen/codegen_cuda.cc +++ b/src/codegen/codegen_cuda.cc @@ -58,15 +58,19 @@ std::string CodeGenCUDA::Finish() { << "{\n return __hgt(__half(a), __half(b)) ? a : b;\n}\n"; decl_stream << "__device__ half min(half a, half b)\n" << "{\n return __hlt(__half(a), __half(b)) ? a : b;\n}\n"; - decl_stream << "__device__ half operator<=" - << "(__half a, __half b)\n" - << "{\n return __hlt(a, b);\n}\n"; - decl_stream << "__device__ half operator+" - << "(__half a, __half &b)\n" - <<"{\n return __hadd(a, b);\n}\n"; - decl_stream << "__device__ half operator*" - << "(__half a, __half b)\n" - << "{\n return __hmul(a, b);\n}\n"; + // FIXME(tvm-team): "volatile" is used to enable cross thread reduction, + // which is needed by operations such as softmax. + // However, volatile overloading is not supported in NVRTC and CUDA < 9.2. + // We need to figure out a solution which can satisfy both scenario. + // decl_stream << "__device__ half operator<=" + // << "(const volatile __half &a, const volatile __half &b)\n" + // << "{\n return __hlt(a, b);\n}\n"; + // decl_stream << "__device__ half operator+" + // << "(const volatile __half &a, const volatile __half &b)\n" + // <<"{\n return __hadd(a, b);\n}\n"; + // decl_stream << "__device__ half operator*" + // << "(const volatile __half &a, const volatile __half &b)\n" + // << "{\n return __hmul(a, b);\n}\n"; // otherwise simulate computation via float32 decl_stream << "#else\n"; decl_stream << _cuda_half_t_def; diff --git a/src/codegen/literal/cuda_half_t.h b/src/codegen/literal/cuda_half_t.h index 23075b0..0889032 100644 --- a/src/codegen/literal/cuda_half_t.h +++ b/src/codegen/literal/cuda_half_t.h @@ -28,6 +28,7 @@ static constexpr const char* _cuda_half_t_def = R"( typedef unsigned short uint16_t; typedef unsigned char uint8_t; +typedef signed char int8_t; typedef int int32_t; typedef unsigned long long uint64_t; typedef unsigned int uint32_t; @@ -76,7 +77,7 @@ class TVM_ALIGNED(2) half { TVM_XINLINE explicit half(const uint8_t& value) { constructor(value); } TVM_XINLINE explicit half(const int32_t& value) { constructor(value); } TVM_XINLINE explicit half(const uint32_t& value) { constructor(value); } - TVM_XINLINE explicit half(const int64_t& value) { constructor(value); } + TVM_XINLINE explicit half(const long long& value) { constructor(value); } TVM_XINLINE explicit half(const uint64_t& value) { constructor(value); } TVM_XINLINE operator float() const { \ -- 2.7.4