From 801cf0e8d3427b12d017ba7bcf0a17b9da2d2408 Mon Sep 17 00:00:00 2001 From: Yizhi Liu Date: Sat, 9 Nov 2019 22:16:34 -0800 Subject: [PATCH] [Codegen][cuda-fp16] fallback to fp32 simulation when cuda arch < sm53 (#4268) --- src/codegen/codegen_cuda.cc | 8 ++ src/codegen/literal/cuda_half_t.h | 280 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 288 insertions(+) create mode 100644 src/codegen/literal/cuda_half_t.h diff --git a/src/codegen/codegen_cuda.cc b/src/codegen/codegen_cuda.cc index 5f04dd0..22e8d84 100644 --- a/src/codegen/codegen_cuda.cc +++ b/src/codegen/codegen_cuda.cc @@ -27,6 +27,7 @@ #include #include #include +#include "literal/cuda_half_t.h" #include "codegen_cuda.h" namespace tvm { @@ -50,6 +51,7 @@ void CodeGenCUDA::AddFunction(LoweredFunc f) { std::string CodeGenCUDA::Finish() { if (enable_fp16_) { + decl_stream << "#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530)\n"; decl_stream << "#include \n"; decl_stream << "__device__ half max" << "(half a, half b)\n" @@ -65,10 +67,16 @@ std::string CodeGenCUDA::Finish() { decl_stream << "__device__ half operator*" << "(__half a, __half b)\n" << "{\n return __hmul(a, b);\n}\n"; + // otherwise simulate computation via float32 + decl_stream << "#else\n"; + decl_stream << _cuda_half_t_def; + decl_stream << "#endif\n\n"; } if (enable_int8_) { + decl_stream << "#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 610)\n"; decl_stream << "#include \n"; + decl_stream << "#endif\n"; } if (need_math_constants_h_) { diff --git a/src/codegen/literal/cuda_half_t.h b/src/codegen/literal/cuda_half_t.h new file mode 100644 index 0000000..23075b0 --- /dev/null +++ b/src/codegen/literal/cuda_half_t.h @@ -0,0 +1,280 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * Copyright (c) 2019 by Contributors + * \file cuda_half_t.h + * \brief half_t (fp16) definition for cuda codegen. + */ +#ifndef TVM_CODEGEN_LITERAL_CUDA_HALF_T_H_ +#define TVM_CODEGEN_LITERAL_CUDA_HALF_T_H_ + +static constexpr const char* _cuda_half_t_def = R"( +typedef unsigned short uint16_t; +typedef unsigned char uint8_t; +typedef int int32_t; +typedef unsigned long long uint64_t; +typedef unsigned int uint32_t; + +#define TVM_FORCE_INLINE inline __attribute__((always_inline)) +#define TVM_XINLINE TVM_FORCE_INLINE __device__ __host__ +#define TVM_ALIGNED(x) __attribute__ ((aligned(x))) +#define TVM_HALF_OPERATOR(RTYPE, OP) \ + TVM_XINLINE RTYPE operator OP (half a, half b) { \ + return RTYPE(float(a) OP float(b)); \ + } \ + template \ + TVM_XINLINE RTYPE operator OP (half a, T b) { \ + return RTYPE(float(a) OP float(b)); \ + } \ + template \ + TVM_XINLINE RTYPE operator OP (T a, half b) { \ + return RTYPE(float(a) OP float(b)); \ + } + +#define TVM_HALF_ASSIGNOP(AOP, OP) \ + template \ + TVM_XINLINE half operator AOP (const T& a) { \ + return *this = half(float(*this) OP float(a)); \ + } \ + template \ + TVM_XINLINE half operator AOP (const volatile T& a) volatile { \ + return *this = half(float(*this) OP float(a)); \ + } + +class TVM_ALIGNED(2) half { + public: + uint16_t half_; + + static TVM_XINLINE half Binary(uint16_t value) { + half res; + res.half_ = value; + return res; + } + + TVM_XINLINE half() {} + + TVM_XINLINE half(const float& value) { constructor(value); } + TVM_XINLINE explicit half(const double& value) { constructor(value); } + TVM_XINLINE explicit half(const int8_t& value) { constructor(value); } + TVM_XINLINE explicit half(const uint8_t& value) { constructor(value); } + TVM_XINLINE explicit half(const int32_t& value) { constructor(value); } + TVM_XINLINE explicit half(const uint32_t& value) { constructor(value); } + TVM_XINLINE explicit half(const int64_t& value) { constructor(value); } + TVM_XINLINE explicit half(const uint64_t& value) { constructor(value); } + + TVM_XINLINE operator float() const { \ + return float(half2float(half_)); \ + } \ + TVM_XINLINE operator float() const volatile { \ + return float(half2float(half_)); \ + } + + + TVM_HALF_ASSIGNOP(+=, +) + TVM_HALF_ASSIGNOP(-=, -) + TVM_HALF_ASSIGNOP(*=, *) + TVM_HALF_ASSIGNOP(/=, /) + + TVM_XINLINE half operator+() { + return *this; + } + + TVM_XINLINE half operator-() { + return half(-float(*this)); + } + + TVM_XINLINE half operator=(const half& a) { + half_ = a.half_; + return a; + } + + template + TVM_XINLINE half operator=(const T& a) { + return *this = half(a); + } + + TVM_XINLINE half operator=(const half& a) volatile { + half_ = a.half_; + return a; + } + + template + TVM_XINLINE half operator=(const T& a) volatile { + return *this = half(a); + } + + private: + union Bits { + float f; + int32_t si; + uint32_t ui; + }; + + static int const fp16FractionBits = 10; + static int const fp32FractionBits = 23; + static int32_t const fp32FractionMask = ~(~0u << fp32FractionBits); // == 0x7fffff + static int32_t const fp32HiddenBit = 1 << fp32FractionBits; // == 0x800000 + static int const shift = fp32FractionBits - fp16FractionBits; // == 13 + static int const shiftSign = 16; + static int32_t const expAdjust = 127 - 15; // exp32-127 = exp16-15, so exp16 = exp32 - (127-15) + + static int32_t const infN = 0x7F800000; // flt32 infinity + static int32_t const maxN = 0x477FFFFF; // max flt32 that's a flt16 normal after >> by shift + static int32_t const minN = 0x38800000; // min flt16 normal as a flt32 + static int32_t const maxZ = 0x33000000; // max fp32 number that's still rounded to zero in fp16 + static int32_t const signN = 0x80000000; // flt32 sign bit + + static int32_t const infC = infN >> shift; + static int32_t const nanN = (infC + 1) << shift; // minimum flt16 nan as a flt32 + static int32_t const maxC = maxN >> shift; + static int32_t const minC = minN >> shift; + static int32_t const signC = signN >> shiftSign; // flt16 sign bit + + static int32_t const mulN = 0x52000000; // (1 << 23) / minN + static int32_t const mulC = 0x33800000; // minN / (1 << (23 - shift)) + + static int32_t const subC = 0x003FF; // max flt32 subnormal down shifted + static int32_t const norC = 0x00400; // min flt32 normal down shifted + + static int32_t const maxD = infC - maxC - 1; + static int32_t const minD = minC - subC - 1; + + TVM_XINLINE uint16_t float2half(const float& value) const { + Bits v; + v.f = value; + uint32_t sign = v.si & signN; // grab sign bit + v.si ^= sign; // clear sign bit from v + sign >>= shiftSign; // logical shift sign to fp16 position + + if (v.si <= maxZ) { + // Handle eventual zeros here to ensure + // vshift will not exceed 32 below. + v.ui = 0; + } else if (v.si < minN) { + // Handle denorms + uint32_t exp32 = v.ui >> fp32FractionBits; + int32_t exp16 = exp32 - expAdjust; + // If exp16 == 0 (just into the denorm range), then significant should be shifted right 1. + // Smaller (so negative) exp16 values should result in greater right shifts. + uint32_t vshift = 1 - exp16; + uint32_t significand = fp32HiddenBit | (v.ui & fp32FractionMask); + v.ui = significand >> vshift; + } else if (v.si <= maxN) { + // Handle norms + v.ui -= expAdjust << fp32FractionBits; + } else if (v.si <= infN) { + v.si = infN; + } else if (v.si < nanN) { + v.si = nanN; + } + + v.ui >>= shift; + return sign | (v.ui & 0x7fff); + } + + // Same as above routine, except for addition of volatile keyword + TVM_XINLINE uint16_t float2half( + const volatile float& value) const volatile { + Bits v; + v.f = value; + uint32_t sign = v.si & signN; // grab sign bit + v.si ^= sign; // clear sign bit from v + sign >>= shiftSign; // logical shift sign to fp16 position + + if (v.si <= maxZ) { + // Handle eventual zeros here to ensure + // vshift will not exceed 32 below. + v.ui = 0; + } else if (v.si < minN) { + // Handle denorms + uint32_t exp32 = v.ui >> fp32FractionBits; + int32_t exp16 = exp32 - expAdjust; + // If exp16 == 0 (just into the denorm range), then significant should be shifted right 1. + // Smaller (so negative) exp16 values should result in greater right shifts. + uint32_t vshift = 1 - exp16; + uint32_t significand = fp32HiddenBit | (v.ui & fp32FractionMask); + v.ui = significand >> vshift; + } else if (v.si <= maxN) { + // Handle norms + v.ui -= expAdjust << fp32FractionBits; + } else if (v.si <= infN) { + v.si = infN; + } else if (v.si < nanN) { + v.si = nanN; + } + + v.ui >>= shift; + return sign | (v.ui & 0x7fff); + } + + TVM_XINLINE float half2float(const uint16_t& value) const { + Bits v; + v.ui = value; + int32_t sign = v.si & signC; + v.si ^= sign; + sign <<= shiftSign; + v.si ^= ((v.si + minD) ^ v.si) & -(v.si > subC); + v.si ^= ((v.si + maxD) ^ v.si) & -(v.si > maxC); + Bits s; + s.si = mulC; + s.f *= v.si; + int32_t mask = -(norC > v.si); + v.si <<= shift; + v.si ^= (s.si ^ v.si) & mask; + v.si |= sign; + return v.f; + } + + TVM_XINLINE float half2float( + const volatile uint16_t& value) const volatile { + Bits v; + v.ui = value; + int32_t sign = v.si & signC; + v.si ^= sign; + sign <<= shiftSign; + v.si ^= ((v.si + minD) ^ v.si) & -(v.si > subC); + v.si ^= ((v.si + maxD) ^ v.si) & -(v.si > maxC); + Bits s; + s.si = mulC; + s.f *= v.si; + int32_t mask = -(norC > v.si); + v.si <<= shift; + v.si ^= (s.si ^ v.si) & mask; + v.si |= sign; + return v.f; + } + + template + TVM_XINLINE void constructor(const T& value) { + half_ = float2half(float(value)); + } +}; + +TVM_HALF_OPERATOR(half, +) +TVM_HALF_OPERATOR(half, -) +TVM_HALF_OPERATOR(half, *) +TVM_HALF_OPERATOR(half, /) +TVM_HALF_OPERATOR(bool, >) +TVM_HALF_OPERATOR(bool, <) +TVM_HALF_OPERATOR(bool, >=) +TVM_HALF_OPERATOR(bool, <=) +)"; + +#endif // TVM_CODEGEN_LITERAL_CUDA_HALF_T_H_ -- 2.7.4