<< "{\n return __hgt(__half(a), __half(b)) ? a : b;\n}\n";
decl_stream << "__device__ half min(half a, half b)\n"
<< "{\n return __hlt(__half(a), __half(b)) ? a : b;\n}\n";
- decl_stream << "__device__ half operator<="
- << "(__half a, __half b)\n"
- << "{\n return __hlt(a, b);\n}\n";
- decl_stream << "__device__ half operator+"
- << "(__half a, __half &b)\n"
- <<"{\n return __hadd(a, b);\n}\n";
- decl_stream << "__device__ half operator*"
- << "(__half a, __half b)\n"
- << "{\n return __hmul(a, b);\n}\n";
+ // FIXME(tvm-team): "volatile" is used to enable cross-thread reduction,
+ // which is needed by operations such as softmax.
+ // However, volatile overloading is not supported in NVRTC or in CUDA < 9.2.
+ // We need to figure out a solution that satisfies both scenarios.
+ // decl_stream << "__device__ half operator<="
+ // << "(const volatile __half &a, const volatile __half &b)\n"
+ // << "{\n return __hlt(a, b);\n}\n";
+ // decl_stream << "__device__ half operator+"
+ // << "(const volatile __half &a, const volatile __half &b)\n"
+ // <<"{\n return __hadd(a, b);\n}\n";
+ // decl_stream << "__device__ half operator*"
+ // << "(const volatile __half &a, const volatile __half &b)\n"
+ // << "{\n return __hmul(a, b);\n}\n";
// otherwise simulate computation via float32
decl_stream << "#else\n";
decl_stream << _cuda_half_t_def;
static constexpr const char* _cuda_half_t_def = R"(
typedef unsigned short uint16_t;
typedef unsigned char uint8_t;
+typedef signed char int8_t;
typedef int int32_t;
typedef unsigned long long uint64_t;
typedef unsigned int uint32_t;
TVM_XINLINE explicit half(const uint8_t& value) { constructor(value); }
TVM_XINLINE explicit half(const int32_t& value) { constructor(value); }
TVM_XINLINE explicit half(const uint32_t& value) { constructor(value); }
- TVM_XINLINE explicit half(const int64_t& value) { constructor(value); }
+ TVM_XINLINE explicit half(const long long& value) { constructor(value); }
TVM_XINLINE explicit half(const uint64_t& value) { constructor(value); }
TVM_XINLINE operator float() const { \