--- /dev/null
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * Copyright (c) 2019 by Contributors
+ * \file cuda_half_t.h
+ * \brief half_t (fp16) definition for cuda codegen.
+ */
+#ifndef TVM_CODEGEN_LITERAL_CUDA_HALF_T_H_
+#define TVM_CODEGEN_LITERAL_CUDA_HALF_T_H_
+
+static constexpr const char* _cuda_half_t_def = R"(
+typedef unsigned short uint16_t;
+typedef unsigned char uint8_t;
+typedef int int32_t;
+typedef unsigned long long uint64_t;
+typedef unsigned int uint32_t;
+
+#define TVM_FORCE_INLINE inline __attribute__((always_inline))
+#define TVM_XINLINE TVM_FORCE_INLINE __device__ __host__
+#define TVM_ALIGNED(x) __attribute__ ((aligned(x)))
+#define TVM_HALF_OPERATOR(RTYPE, OP) \
+ TVM_XINLINE RTYPE operator OP (half a, half b) { \
+ return RTYPE(float(a) OP float(b)); \
+ } \
+ template<typename T> \
+ TVM_XINLINE RTYPE operator OP (half a, T b) { \
+ return RTYPE(float(a) OP float(b)); \
+ } \
+ template<typename T> \
+ TVM_XINLINE RTYPE operator OP (T a, half b) { \
+ return RTYPE(float(a) OP float(b)); \
+ }
+
+#define TVM_HALF_ASSIGNOP(AOP, OP) \
+ template<typename T> \
+ TVM_XINLINE half operator AOP (const T& a) { \
+ return *this = half(float(*this) OP float(a)); \
+ } \
+ template<typename T> \
+ TVM_XINLINE half operator AOP (const volatile T& a) volatile { \
+ return *this = half(float(*this) OP float(a)); \
+ }
+
+class TVM_ALIGNED(2) half {
+ public:
+ uint16_t half_;
+
+ static TVM_XINLINE half Binary(uint16_t value) {
+ half res;
+ res.half_ = value;
+ return res;
+ }
+
+ TVM_XINLINE half() {}
+
+ TVM_XINLINE half(const float& value) { constructor(value); }
+ TVM_XINLINE explicit half(const double& value) { constructor(value); }
+ TVM_XINLINE explicit half(const int8_t& value) { constructor(value); }
+ TVM_XINLINE explicit half(const uint8_t& value) { constructor(value); }
+ TVM_XINLINE explicit half(const int32_t& value) { constructor(value); }
+ TVM_XINLINE explicit half(const uint32_t& value) { constructor(value); }
+ TVM_XINLINE explicit half(const int64_t& value) { constructor(value); }
+ TVM_XINLINE explicit half(const uint64_t& value) { constructor(value); }
+
+ TVM_XINLINE operator float() const { \
+ return float(half2float(half_)); \
+ } \
+ TVM_XINLINE operator float() const volatile { \
+ return float(half2float(half_)); \
+ }
+
+
+ TVM_HALF_ASSIGNOP(+=, +)
+ TVM_HALF_ASSIGNOP(-=, -)
+ TVM_HALF_ASSIGNOP(*=, *)
+ TVM_HALF_ASSIGNOP(/=, /)
+
+ TVM_XINLINE half operator+() {
+ return *this;
+ }
+
+ TVM_XINLINE half operator-() {
+ return half(-float(*this));
+ }
+
+ TVM_XINLINE half operator=(const half& a) {
+ half_ = a.half_;
+ return a;
+ }
+
+ template<typename T>
+ TVM_XINLINE half operator=(const T& a) {
+ return *this = half(a);
+ }
+
+ TVM_XINLINE half operator=(const half& a) volatile {
+ half_ = a.half_;
+ return a;
+ }
+
+ template<typename T>
+ TVM_XINLINE half operator=(const T& a) volatile {
+ return *this = half(a);
+ }
+
+ private:
+ union Bits {
+ float f;
+ int32_t si;
+ uint32_t ui;
+ };
+
+ static int const fp16FractionBits = 10;
+ static int const fp32FractionBits = 23;
+ static int32_t const fp32FractionMask = ~(~0u << fp32FractionBits); // == 0x7fffff
+ static int32_t const fp32HiddenBit = 1 << fp32FractionBits; // == 0x800000
+ static int const shift = fp32FractionBits - fp16FractionBits; // == 13
+ static int const shiftSign = 16;
+ static int32_t const expAdjust = 127 - 15; // exp32-127 = exp16-15, so exp16 = exp32 - (127-15)
+
+ static int32_t const infN = 0x7F800000; // flt32 infinity
+ static int32_t const maxN = 0x477FFFFF; // max flt32 that's a flt16 normal after >> by shift
+ static int32_t const minN = 0x38800000; // min flt16 normal as a flt32
+ static int32_t const maxZ = 0x33000000; // max fp32 number that's still rounded to zero in fp16
+ static int32_t const signN = 0x80000000; // flt32 sign bit
+
+ static int32_t const infC = infN >> shift;
+ static int32_t const nanN = (infC + 1) << shift; // minimum flt16 nan as a flt32
+ static int32_t const maxC = maxN >> shift;
+ static int32_t const minC = minN >> shift;
+ static int32_t const signC = signN >> shiftSign; // flt16 sign bit
+
+ static int32_t const mulN = 0x52000000; // (1 << 23) / minN
+ static int32_t const mulC = 0x33800000; // minN / (1 << (23 - shift))
+
+ static int32_t const subC = 0x003FF; // max flt32 subnormal down shifted
+ static int32_t const norC = 0x00400; // min flt32 normal down shifted
+
+ static int32_t const maxD = infC - maxC - 1;
+ static int32_t const minD = minC - subC - 1;
+
+ TVM_XINLINE uint16_t float2half(const float& value) const {
+ Bits v;
+ v.f = value;
+ uint32_t sign = v.si & signN; // grab sign bit
+ v.si ^= sign; // clear sign bit from v
+ sign >>= shiftSign; // logical shift sign to fp16 position
+
+ if (v.si <= maxZ) {
+ // Handle eventual zeros here to ensure
+ // vshift will not exceed 32 below.
+ v.ui = 0;
+ } else if (v.si < minN) {
+ // Handle denorms
+ uint32_t exp32 = v.ui >> fp32FractionBits;
+ int32_t exp16 = exp32 - expAdjust;
+ // If exp16 == 0 (just into the denorm range), then significant should be shifted right 1.
+ // Smaller (so negative) exp16 values should result in greater right shifts.
+ uint32_t vshift = 1 - exp16;
+ uint32_t significand = fp32HiddenBit | (v.ui & fp32FractionMask);
+ v.ui = significand >> vshift;
+ } else if (v.si <= maxN) {
+ // Handle norms
+ v.ui -= expAdjust << fp32FractionBits;
+ } else if (v.si <= infN) {
+ v.si = infN;
+ } else if (v.si < nanN) {
+ v.si = nanN;
+ }
+
+ v.ui >>= shift;
+ return sign | (v.ui & 0x7fff);
+ }
+
+ // Same as above routine, except for addition of volatile keyword
+ TVM_XINLINE uint16_t float2half(
+ const volatile float& value) const volatile {
+ Bits v;
+ v.f = value;
+ uint32_t sign = v.si & signN; // grab sign bit
+ v.si ^= sign; // clear sign bit from v
+ sign >>= shiftSign; // logical shift sign to fp16 position
+
+ if (v.si <= maxZ) {
+ // Handle eventual zeros here to ensure
+ // vshift will not exceed 32 below.
+ v.ui = 0;
+ } else if (v.si < minN) {
+ // Handle denorms
+ uint32_t exp32 = v.ui >> fp32FractionBits;
+ int32_t exp16 = exp32 - expAdjust;
+ // If exp16 == 0 (just into the denorm range), then significant should be shifted right 1.
+ // Smaller (so negative) exp16 values should result in greater right shifts.
+ uint32_t vshift = 1 - exp16;
+ uint32_t significand = fp32HiddenBit | (v.ui & fp32FractionMask);
+ v.ui = significand >> vshift;
+ } else if (v.si <= maxN) {
+ // Handle norms
+ v.ui -= expAdjust << fp32FractionBits;
+ } else if (v.si <= infN) {
+ v.si = infN;
+ } else if (v.si < nanN) {
+ v.si = nanN;
+ }
+
+ v.ui >>= shift;
+ return sign | (v.ui & 0x7fff);
+ }
+
+ TVM_XINLINE float half2float(const uint16_t& value) const {
+ Bits v;
+ v.ui = value;
+ int32_t sign = v.si & signC;
+ v.si ^= sign;
+ sign <<= shiftSign;
+ v.si ^= ((v.si + minD) ^ v.si) & -(v.si > subC);
+ v.si ^= ((v.si + maxD) ^ v.si) & -(v.si > maxC);
+ Bits s;
+ s.si = mulC;
+ s.f *= v.si;
+ int32_t mask = -(norC > v.si);
+ v.si <<= shift;
+ v.si ^= (s.si ^ v.si) & mask;
+ v.si |= sign;
+ return v.f;
+ }
+
+ TVM_XINLINE float half2float(
+ const volatile uint16_t& value) const volatile {
+ Bits v;
+ v.ui = value;
+ int32_t sign = v.si & signC;
+ v.si ^= sign;
+ sign <<= shiftSign;
+ v.si ^= ((v.si + minD) ^ v.si) & -(v.si > subC);
+ v.si ^= ((v.si + maxD) ^ v.si) & -(v.si > maxC);
+ Bits s;
+ s.si = mulC;
+ s.f *= v.si;
+ int32_t mask = -(norC > v.si);
+ v.si <<= shift;
+ v.si ^= (s.si ^ v.si) & mask;
+ v.si |= sign;
+ return v.f;
+ }
+
+ template<typename T>
+ TVM_XINLINE void constructor(const T& value) {
+ half_ = float2half(float(value));
+ }
+};
+
+TVM_HALF_OPERATOR(half, +)
+TVM_HALF_OPERATOR(half, -)
+TVM_HALF_OPERATOR(half, *)
+TVM_HALF_OPERATOR(half, /)
+TVM_HALF_OPERATOR(bool, >)
+TVM_HALF_OPERATOR(bool, <)
+TVM_HALF_OPERATOR(bool, >=)
+TVM_HALF_OPERATOR(bool, <=)
+)";
+
+#endif // TVM_CODEGEN_LITERAL_CUDA_HALF_T_H_