[Codegen][cuda-fp16] fallback to fp32 simulation when cuda arch < sm53 (#4268)
authorYizhi Liu <liuyizhi@apache.org>
Sun, 10 Nov 2019 06:16:34 +0000 (22:16 -0800)
committerTianqi Chen <tqchen@users.noreply.github.com>
Sun, 10 Nov 2019 06:16:34 +0000 (22:16 -0800)
src/codegen/codegen_cuda.cc
src/codegen/literal/cuda_half_t.h [new file with mode: 0644]

index 5f04dd0..22e8d84 100644 (file)
@@ -27,6 +27,7 @@
 #include <cmath>
 #include <vector>
 #include <string>
+#include "literal/cuda_half_t.h"
 #include "codegen_cuda.h"
 
 namespace tvm {
@@ -50,6 +51,7 @@ void CodeGenCUDA::AddFunction(LoweredFunc f) {
 
 std::string CodeGenCUDA::Finish() {
   if (enable_fp16_) {
+    decl_stream << "#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530)\n";
     decl_stream << "#include <cuda_fp16.h>\n";
     decl_stream << "__device__ half max"
                 << "(half a, half b)\n"
@@ -65,10 +67,16 @@ std::string CodeGenCUDA::Finish() {
     decl_stream << "__device__ half operator*"
                 << "(__half a, __half b)\n"
                 <<   "{\n  return __hmul(a, b);\n}\n";
+    // otherwise simulate computation via float32
+    decl_stream << "#else\n";
+    decl_stream << _cuda_half_t_def;
+    decl_stream << "#endif\n\n";
   }
 
   if (enable_int8_) {
+    decl_stream << "#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 610)\n";
     decl_stream << "#include <sm_61_intrinsics.h>\n";
+    decl_stream << "#endif\n";
   }
 
   if (need_math_constants_h_) {
diff --git a/src/codegen/literal/cuda_half_t.h b/src/codegen/literal/cuda_half_t.h
new file mode 100644 (file)
index 0000000..23075b0
--- /dev/null
@@ -0,0 +1,280 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ *  Copyright (c) 2019 by Contributors
+ * \file cuda_half_t.h
+ * \brief half_t (fp16) definition for cuda codegen.
+ */
+#ifndef TVM_CODEGEN_LITERAL_CUDA_HALF_T_H_
+#define TVM_CODEGEN_LITERAL_CUDA_HALF_T_H_
+
+static constexpr const char* _cuda_half_t_def = R"(
+typedef unsigned short uint16_t;
+typedef unsigned char uint8_t;
+typedef int int32_t;
+typedef unsigned long long uint64_t;
+typedef unsigned int uint32_t;
+
+#define TVM_FORCE_INLINE inline __attribute__((always_inline))
+#define TVM_XINLINE TVM_FORCE_INLINE __device__ __host__
+#define TVM_ALIGNED(x) __attribute__ ((aligned(x)))
+#define TVM_HALF_OPERATOR(RTYPE, OP)                              \
+  TVM_XINLINE RTYPE operator OP (half a, half b) {                \
+    return RTYPE(float(a) OP float(b));                           \
+  }                                                               \
+  template<typename T>                                            \
+  TVM_XINLINE RTYPE operator OP (half a, T b) {                   \
+    return RTYPE(float(a) OP float(b));                           \
+  }                                                               \
+  template<typename T>                                            \
+  TVM_XINLINE RTYPE operator OP (T a, half b) {                   \
+    return RTYPE(float(a) OP float(b));                           \
+  }
+
+#define TVM_HALF_ASSIGNOP(AOP, OP)                                \
+  template<typename T>                                            \
+  TVM_XINLINE half operator AOP (const T& a) {                    \
+    return *this = half(float(*this) OP float(a));                \
+  }                                                               \
+  template<typename T>                                            \
+  TVM_XINLINE half operator AOP (const volatile T& a) volatile {  \
+    return *this = half(float(*this) OP float(a));                \
+  }
+
+class TVM_ALIGNED(2) half {
+ public:
+  uint16_t half_;
+
+  static TVM_XINLINE half Binary(uint16_t value) {
+    half res;
+    res.half_ = value;
+    return res;
+  }
+
+  TVM_XINLINE half() {}
+
+  TVM_XINLINE half(const float& value) { constructor(value); }
+  TVM_XINLINE explicit half(const double& value) { constructor(value); }
+  TVM_XINLINE explicit half(const int8_t& value) { constructor(value); }
+  TVM_XINLINE explicit half(const uint8_t& value) { constructor(value); }
+  TVM_XINLINE explicit half(const int32_t& value) { constructor(value); }
+  TVM_XINLINE explicit half(const uint32_t& value) { constructor(value); }
+  TVM_XINLINE explicit half(const int64_t& value) { constructor(value); }
+  TVM_XINLINE explicit half(const uint64_t& value) { constructor(value); }
+
+  TVM_XINLINE operator float() const {                          \
+    return float(half2float(half_));                            \
+  }                                                             \
+  TVM_XINLINE operator float() const volatile {                 \
+    return float(half2float(half_));                            \
+  }
+
+
+  TVM_HALF_ASSIGNOP(+=, +)
+  TVM_HALF_ASSIGNOP(-=, -)
+  TVM_HALF_ASSIGNOP(*=, *)
+  TVM_HALF_ASSIGNOP(/=, /)
+
+  TVM_XINLINE half operator+() {
+    return *this;
+  }
+
+  TVM_XINLINE half operator-() {
+    return half(-float(*this));
+  }
+
+  TVM_XINLINE half operator=(const half& a) {
+    half_ = a.half_;
+    return a;
+  }
+
+  template<typename T>
+  TVM_XINLINE half operator=(const T& a) {
+    return *this = half(a);
+  }
+
+  TVM_XINLINE half operator=(const half& a) volatile {
+    half_ = a.half_;
+    return a;
+  }
+
+  template<typename T>
+  TVM_XINLINE half operator=(const T& a) volatile {
+    return *this = half(a);
+  }
+
+ private:
+  union Bits {
+    float f;
+    int32_t si;
+    uint32_t ui;
+  };
+
+  static int const fp16FractionBits = 10;
+  static int const fp32FractionBits = 23;
+  static int32_t const fp32FractionMask = ~(~0u << fp32FractionBits);   // == 0x7fffff
+  static int32_t const fp32HiddenBit = 1 << fp32FractionBits;   // == 0x800000
+  static int const shift = fp32FractionBits - fp16FractionBits;   // == 13
+  static int const shiftSign = 16;
+  static int32_t const expAdjust = 127 - 15;   // exp32-127 = exp16-15, so exp16 = exp32 - (127-15)
+
+  static int32_t const infN = 0x7F800000;   // flt32 infinity
+  static int32_t const maxN = 0x477FFFFF;   // max flt32 that's a flt16 normal after >> by shift
+  static int32_t const minN = 0x38800000;   // min flt16 normal as a flt32
+  static int32_t const maxZ = 0x33000000;   // max fp32 number that's still rounded to zero in fp16
+  static int32_t const signN = 0x80000000;  // flt32 sign bit
+
+  static int32_t const infC = infN >> shift;
+  static int32_t const nanN = (infC + 1) << shift;   // minimum flt16 nan as a flt32
+  static int32_t const maxC = maxN >> shift;
+  static int32_t const minC = minN >> shift;
+  static int32_t const signC = signN >> shiftSign;  // flt16 sign bit
+
+  static int32_t const mulN = 0x52000000;  // (1 << 23) / minN
+  static int32_t const mulC = 0x33800000;  // minN / (1 << (23 - shift))
+
+  static int32_t const subC = 0x003FF;  // max flt32 subnormal down shifted
+  static int32_t const norC = 0x00400;  // min flt32 normal down shifted
+
+  static int32_t const maxD = infC - maxC - 1;
+  static int32_t const minD = minC - subC - 1;
+
+  TVM_XINLINE uint16_t float2half(const float& value) const {
+    Bits v;
+    v.f = value;
+    uint32_t sign = v.si & signN;    // grab sign bit
+    v.si ^= sign;                    // clear sign bit from v
+    sign >>= shiftSign;              // logical shift sign to fp16 position
+
+    if (v.si <= maxZ) {
+      // Handle eventual zeros here to ensure
+      // vshift will not exceed 32 below.
+      v.ui = 0;
+    } else if (v.si < minN) {
+      // Handle denorms
+      uint32_t exp32 = v.ui >> fp32FractionBits;
+      int32_t exp16 = exp32 - expAdjust;
+      // If exp16 == 0 (just into the denorm range), then significant should be shifted right 1.
+      // Smaller (so negative) exp16 values should result in greater right shifts.
+      uint32_t vshift = 1 - exp16;
+      uint32_t significand = fp32HiddenBit | (v.ui & fp32FractionMask);
+      v.ui = significand >> vshift;
+    } else if (v.si <= maxN) {
+      // Handle norms
+      v.ui -= expAdjust << fp32FractionBits;
+    } else if (v.si <= infN) {
+      v.si = infN;
+    } else if (v.si < nanN) {
+      v.si = nanN;
+    }
+
+    v.ui >>= shift;
+    return sign | (v.ui & 0x7fff);
+  }
+
+  // Same as above routine, except for addition of volatile keyword
+  TVM_XINLINE uint16_t float2half(
+    const volatile float& value) const volatile {
+    Bits v;
+    v.f = value;
+    uint32_t sign = v.si & signN;    // grab sign bit
+    v.si ^= sign;                    // clear sign bit from v
+    sign >>= shiftSign;              // logical shift sign to fp16 position
+
+    if (v.si <= maxZ) {
+      // Handle eventual zeros here to ensure
+      // vshift will not exceed 32 below.
+      v.ui = 0;
+    } else if (v.si < minN) {
+      // Handle denorms
+      uint32_t exp32 = v.ui >> fp32FractionBits;
+      int32_t exp16 = exp32 - expAdjust;
+      // If exp16 == 0 (just into the denorm range), then significant should be shifted right 1.
+      // Smaller (so negative) exp16 values should result in greater right shifts.
+      uint32_t vshift = 1 - exp16;
+      uint32_t significand = fp32HiddenBit | (v.ui & fp32FractionMask);
+      v.ui = significand >> vshift;
+    } else if (v.si <= maxN) {
+      // Handle norms
+      v.ui -= expAdjust << fp32FractionBits;
+    } else if (v.si <= infN) {
+      v.si = infN;
+    } else if (v.si < nanN) {
+      v.si = nanN;
+    }
+
+    v.ui >>= shift;
+    return sign | (v.ui & 0x7fff);
+  }
+
+  TVM_XINLINE float half2float(const uint16_t& value) const {
+    Bits v;
+    v.ui = value;
+    int32_t sign = v.si & signC;
+    v.si ^= sign;
+    sign <<= shiftSign;
+    v.si ^= ((v.si + minD) ^ v.si) & -(v.si > subC);
+    v.si ^= ((v.si + maxD) ^ v.si) & -(v.si > maxC);
+    Bits s;
+    s.si = mulC;
+    s.f *= v.si;
+    int32_t mask = -(norC > v.si);
+    v.si <<= shift;
+    v.si ^= (s.si ^ v.si) & mask;
+    v.si |= sign;
+    return v.f;
+  }
+
+  TVM_XINLINE float half2float(
+    const volatile uint16_t& value) const volatile {
+    Bits v;
+    v.ui = value;
+    int32_t sign = v.si & signC;
+    v.si ^= sign;
+    sign <<= shiftSign;
+    v.si ^= ((v.si + minD) ^ v.si) & -(v.si > subC);
+    v.si ^= ((v.si + maxD) ^ v.si) & -(v.si > maxC);
+    Bits s;
+    s.si = mulC;
+    s.f *= v.si;
+    int32_t mask = -(norC > v.si);
+    v.si <<= shift;
+    v.si ^= (s.si ^ v.si) & mask;
+    v.si |= sign;
+    return v.f;
+  }
+
+  template<typename T>
+  TVM_XINLINE void constructor(const T& value) {
+    half_ = float2half(float(value));
+  }
+};
+
+TVM_HALF_OPERATOR(half, +)
+TVM_HALF_OPERATOR(half, -)
+TVM_HALF_OPERATOR(half, *)
+TVM_HALF_OPERATOR(half, /)
+TVM_HALF_OPERATOR(bool, >)
+TVM_HALF_OPERATOR(bool, <)
+TVM_HALF_OPERATOR(bool, >=)
+TVM_HALF_OPERATOR(bool, <=)
+)";
+
+#endif  // TVM_CODEGEN_LITERAL_CUDA_HALF_T_H_