bool is_float() const {
return code() == DataType::kFloat;
}
+ /*! \return whether type is a float16 type. */
+ bool is_float16() const {
+ return is_float() && bits() == 16;
+ }
/*! \return whether type is an int type. */
bool is_int() const {
return code() == DataType::kInt;
decl_stream << "#else\n";
decl_stream << _cuda_half_t_def;
decl_stream << "#endif\n\n";
+ decl_stream << _cuda_half_util;
}
if (enable_fp16_) {
if (lanes == 1) {
os << "half";
} else if (lanes <= 8) {
+ // Emit CUDA code to access fp16 vector elements.
+ //
+ // half4 is stored as uint2
+ //
+ // h4.x is emitted as ((half2*)(&(u2.x)))->x
+ // h4.y is emitted as ((half2*)(&(u2.x)))->y
+ // h4.z is emitted as ((half2*)(&(u2.y)))->x
+ // h4.w is emitted as ((half2*)(&(u2.y)))->y
+ //
CHECK_EQ(lanes % 2, 0) << "only support even lanes for half type";
- os << "float" << lanes / 2;
+ os << "uint" << lanes / 2;
} else {
fail = true;
}
void CodeGenCUDA::PrintVecElemLoad(
const std::string& vec, DataType t, int i, std::ostream& os) { // NOLINT(*)
static const char access[] = {'x', 'y', 'z', 'w'};
- CHECK(i >= 0 && i < 4);
+ CHECK(i >= 0 && i < (t.is_float16() ? 8 : 4));
if (t.is_int() && t.bits() == 8) {
os << "(0x000000ff & (" << vec << " >> " << i * 8 << "))";
+ } else if (t.is_float16()) {
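+ // Lane i of an fp16 vector lives in half (i % 2) of 32-bit word (i / 2);
+ // e.g. loading lane 3 of an illustrative float16x4 vector v emits
+ // ((half2*)(&(v.y)))->y.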
+ os << "((half2*)(&(" << vec << "." << access[i / 2] << ")))->"
+ << access[i % 2];
} else {
os << vec << "." << access[i];
}
const std::string& vec, DataType t, int i, const std::string& value) {
this->PrintIndent();
static const char access[] = {'x', 'y', 'z', 'w'};
- CHECK(i >= 0 && i < 4);
+ CHECK(i >= 0 && i < (t.is_float16() ? 8 : 4));
if (t.is_int() && t.bits() == 8) {
- stream << vec << "=" << vec << " & ~(0x000000ff << " << i * 8 << ") | ("
- << value << " << " << i * 8 << ");\n";
+ stream << vec << "=";
+ // For lane 0 the vector is still entirely undefined, so skip the
+ // read-modify-write and just write the value.
+ if (i != 0) {
+ stream << vec << " & ~(0x000000ff << " << i * 8 << ") |";
+ }
+ stream << "(" << value << " << " << i * 8 << ");\n";
+ } else if (t.is_float16()) {
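+ // Mirror of the load path; e.g. storing lane 5 of an illustrative
+ // float16x8 vector v emits ((half2*)(&(v.z)))->y = value;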
+ stream << "((half2*)(&(" << vec << "." << access[i / 2] << ")))->"
+ << access[i % 2] << " = " << value << ";\n";
} else {
stream << vec << "." << access[i] << " = " << value << ";\n";
}
return;
}
+ if (op->dtype.is_float16()) {
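+ // Pack the scalar into each pair of lanes; e.g. broadcasting an
+ // illustrative value v to float16x4 emits
+ // make_uint2(__pack_half2(v, v), __pack_half2(v, v)).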
+ std::string v = PrintExpr(op->value);
+ os << "make_";
+ PrintType(op->dtype, os);
+ os << '(';
+ for (int i = 0; i < op->lanes / 2; ++i) {
+ if (i != 0) os << ", ";
+ os << "__pack_half2(" << v << ", " << v << ")";
+ }
+ os << ')';
+ return;
+ }
+
std::string v = PrintExpr(op->value);
os << "make_";
PrintType(op->dtype, os);
}
)";
+static constexpr const char* _cuda_half_util = R"(
+// Pack two half values.
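+// x goes to the low 16 bits and y to the high 16 bits, matching half2's
+// element order assumed by PrintVecElemLoad/PrintVecElemStore.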
+static inline __device__ __host__ unsigned
+__pack_half2(const half x, const half y) {
+ unsigned v0 = *((unsigned short *)&x);
+ unsigned v1 = *((unsigned short *)&y);
+ return (v1 << 16) | v0;
+}
+)";
+
#endif // TVM_CODEGEN_LITERAL_CUDA_HALF_T_H_
import tvm
import numpy as np
import unittest
-from tvm.contrib.nvcc import have_fp16, have_int8
+from tvm.contrib.nvcc import parse_compute_version, have_int8
from tvm.contrib import nvcc
tx = tvm.thread_axis("threadIdx.x")
bx = tvm.thread_axis("blockIdx.x")
-
def test_cuda_vectorize_add():
num_thread = 8
def check_cuda(dtype, n, lanes):
if not tvm.gpu(0).exist or not tvm.module.enabled("cuda"):
print("skip because cuda is not enabled..")
return
- if dtype == "float16" and not have_fp16(tvm.gpu(0).compute_version):
- print("skip because gpu does not support fp16")
+ if dtype == "float16":
+ major, minor = parse_compute_version(tvm.gpu(0).compute_version)
+ # fp16 starts from 5.3
+ if major < 5 or (major == 5 and minor < 3):
+ print("skip because gpu does not support fp16")
return
if dtype == "int8" and not have_int8(tvm.gpu(0).compute_version):
print("skip because gpu does not support int8")
tvm.testing.assert_allclose(c.asnumpy(), a.asnumpy() + 1)
check_cuda("float32", 64, 2)
- check_cuda("int8", 64, 4)
- # check_cuda("float16", 64, 2)
-
- # TODO(tvm-team) fix fp16 codegen here
- # or hit an error if it is less frequently used.
- # check_cuda("float16", 64, 2)
-
+ check_cuda("float32", 64, 3)
+ check_cuda("float32", 64, 4)
+ check_cuda("int8", 64, 4)
+ check_cuda("float16", 64, 2)
+ check_cuda("float16", 64, 4)
+ check_cuda("float16", 64, 6)
+ check_cuda("float16", 64, 8)
def test_cuda_multiply_add():
num_thread = 8