[Codegen] remove fp16 function override for cuda (#4331)
authorYizhi Liu <liuyizhi@apache.org>
Thu, 14 Nov 2019 17:17:42 +0000 (09:17 -0800)
committerWuwei Lin <wuwei@apache.org>
Thu, 14 Nov 2019 17:17:42 +0000 (12:17 -0500)
* add volatile override back

* [codegen] remove fp16 function override for cuda

src/codegen/codegen_cuda.cc
src/codegen/literal/cuda_half_t.h

index 22e8d84..2a41282 100644 (file)
@@ -58,15 +58,19 @@ std::string CodeGenCUDA::Finish() {
                 << "{\n  return __hgt(__half(a), __half(b)) ? a : b;\n}\n";
     decl_stream << "__device__ half min(half a, half b)\n"
                 << "{\n  return __hlt(__half(a), __half(b)) ? a : b;\n}\n";
-    decl_stream << "__device__ half operator<="
-                << "(__half a,  __half b)\n"
-                << "{\n  return __hlt(a, b);\n}\n";
-    decl_stream << "__device__ half operator+"
-                << "(__half a,  __half &b)\n"
-                <<"{\n  return __hadd(a, b);\n}\n";
-    decl_stream << "__device__ half operator*"
-                << "(__half a, __half b)\n"
-                <<   "{\n  return __hmul(a, b);\n}\n";
+    // FIXME(tvm-team): "volatile" is used to enable cross thread reduction,
+    // which is needed by operations such as softmax.
+    // However, volatile overloading is not supported in NVRTC and CUDA < 9.2.
+    // We need to figure out a solution which can satisfy both scenario.
+    // decl_stream << "__device__ half operator<="
+    //             << "(const volatile __half &a,  const volatile __half &b)\n"
+    //             << "{\n  return __hlt(a, b);\n}\n";
+    // decl_stream << "__device__ half operator+"
+    //             << "(const volatile __half &a,  const volatile __half &b)\n"
+    //             <<"{\n  return __hadd(a, b);\n}\n";
+    // decl_stream << "__device__ half operator*"
+    //             << "(const volatile __half &a, const volatile __half &b)\n"
+    //             <<   "{\n  return __hmul(a, b);\n}\n";
     // otherwise simulate computation via float32
     decl_stream << "#else\n";
     decl_stream << _cuda_half_t_def;
index 23075b0..0889032 100644 (file)
@@ -28,6 +28,7 @@
 static constexpr const char* _cuda_half_t_def = R"(
 typedef unsigned short uint16_t;
 typedef unsigned char uint8_t;
+typedef signed char int8_t;
 typedef int int32_t;
 typedef unsigned long long uint64_t;
 typedef unsigned int uint32_t;
@@ -76,7 +77,7 @@ class TVM_ALIGNED(2) half {
   TVM_XINLINE explicit half(const uint8_t& value) { constructor(value); }
   TVM_XINLINE explicit half(const int32_t& value) { constructor(value); }
   TVM_XINLINE explicit half(const uint32_t& value) { constructor(value); }
-  TVM_XINLINE explicit half(const int64_t& value) { constructor(value); }
+  TVM_XINLINE explicit half(const long long& value) { constructor(value); }
   TVM_XINLINE explicit half(const uint64_t& value) { constructor(value); }
 
   TVM_XINLINE operator float() const {                          \