Add __float2half_rn for cuda compute capabilities less than 53 (#4489)
authorreminisce <wujun.nju@gmail.com>
Tue, 10 Dec 2019 22:05:52 +0000 (14:05 -0800)
committerYizhi Liu <liuyizhi@apache.org>
Tue, 10 Dec 2019 22:05:52 +0000 (14:05 -0800)
* Fix

* clean up

src/codegen/literal/cuda_half_t.h
tests/python/unittest/test_codegen_cuda.py

index 94e9528..630a741 100644 (file)
@@ -176,8 +176,10 @@ class TVM_ALIGNED(2) half {
       uint32_t vshift = 1 - exp16;
       uint32_t significand = fp32HiddenBit | (v.ui & fp32FractionMask);
       v.ui = significand >> vshift;
+      v.ui += (v.ui & 0x3fff) != 0x1000 || (significand & 0x7ff) ? 0x1000 : 0;
     } else if (v.si <= maxN) {
       // Handle norms
+      v.ui += (v.ui & 0x3fff) != 0x1000 ? 0x1000 : 0;
       v.ui -= expAdjust << fp32FractionBits;
     } else if (v.si <= infN) {
       v.si = infN;
@@ -211,8 +213,10 @@ class TVM_ALIGNED(2) half {
       uint32_t vshift = 1 - exp16;
       uint32_t significand = fp32HiddenBit | (v.ui & fp32FractionMask);
       v.ui = significand >> vshift;
+      v.ui += (v.ui & 0x3fff) != 0x1000 || (significand & 0x7ff) ? 0x1000 : 0;
     } else if (v.si <= maxN) {
       // Handle norms
+      v.ui += (v.ui & 0x3fff) != 0x1000 ? 0x1000 : 0;
       v.ui -= expAdjust << fp32FractionBits;
     } else if (v.si <= infN) {
       v.si = infN;
@@ -275,6 +279,10 @@ TVM_HALF_OPERATOR(bool, >)
 TVM_HALF_OPERATOR(bool, <)
 TVM_HALF_OPERATOR(bool, >=)
 TVM_HALF_OPERATOR(bool, <=)
+
+TVM_XINLINE half __float2half_rn(const float a) {
+  return half(a);
+}
 )";
 
 #endif  // TVM_CODEGEN_LITERAL_CUDA_HALF_T_H_
index 7991c60..27a8d87 100644 (file)
@@ -17,6 +17,7 @@
 # under the License.
 import tvm
 import numpy as np
+import unittest
 from tvm.contrib.nvcc import have_fp16, have_int8
 from tvm.contrib import nvcc
 
@@ -263,6 +264,32 @@ def test_rfactor_predicates():
     fcuda = tvm.build(s, [A, B], "cuda")
 
 
+@unittest.skipIf(not tvm.gpu(0).exist or not tvm.module.enabled("cuda"), "skip because cuda is not enabled..")
+def test_cuda_const_float_to_half():
+    # This import is required to use nvcc to perform code gen;
+    # otherwise it is found that the code gen is done by nvrtc.
+    from tvm import autotvm
+    shape = (2, 3, 4)
+    a = tvm.placeholder(shape, dtype='float16', name='a')
+    b = tvm.const(0.5, dtype='float16')
+    c = tvm.compute(shape, lambda i, j, k: a[i, j, k] > b, name='c')
+    s = tvm.create_schedule(c.op)
+    axes = [axis for axis in c.op.axis]
+    fused = s[c].fuse(*axes)
+    bx, tx = s[c].split(fused, factor=64)
+    s[c].bind(bx, tvm.thread_axis('blockIdx.x'))
+    s[c].bind(tx, tvm.thread_axis('threadIdx.x'))
+
+    func = tvm.build(s, [a, c], 'cuda')
+    ctx = tvm.gpu(0)
+    a_np = np.random.uniform(size=shape).astype(a.dtype)
+    c_np = np.zeros(shape=shape, dtype=c.dtype)
+    a = tvm.nd.array(a_np, ctx)
+    c = tvm.nd.array(c_np, ctx)
+    func(a, c)
+    np.testing.assert_equal(c.asnumpy(), a_np > b.value)
+
+
 if __name__ == "__main__":
     test_cuda_vectorize_add()
     test_cuda_multiply_add()
@@ -272,3 +299,4 @@ if __name__ == "__main__":
     test_cuda_shuffle()
     test_cuda_reducition_binding()
     test_rfactor_predicates()
+    test_cuda_const_float_to_half()