[mlir][sparse] Use the correct ABI on x86 and re-enable tests
authorBenjamin Kramer <benny.kra@googlemail.com>
Thu, 11 Aug 2022 08:38:12 +0000 (10:38 +0200)
committerBenjamin Kramer <benny.kra@googlemail.com>
Thu, 11 Aug 2022 08:41:53 +0000 (10:41 +0200)
c7ec6e19d5446a448f888b33f66316cf2ec6ecae made LLVM adhere to the x86
psABI and pass bf16 in SSE registers instead of GPRs. This breaks the
custom versions of runtime functions we have for bf16 conversion. A
great fix for this would be to use __bf16 types instead which carry the
right ABI, but that type isn't widely available.

Instead just pretend it's a 32 bit float on the ABI boundary and
carefully cast it to the right type.

Fixes #57042

mlir/lib/ExecutionEngine/Float16bits.cpp
mlir/test/Integration/Dialect/SparseTensor/CPU/dense_output_bf16.mlir
mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_sum_bf16.mlir

index c163316..5c3fb61 100644 (file)
@@ -13,6 +13,7 @@
 
 #include "mlir/ExecutionEngine/Float16bits.h"
 #include <cmath>
+#include <cstring>
 
 namespace {
 
@@ -146,30 +147,46 @@ std::ostream &operator<<(std::ostream &os, const bf16 &d) {
   return os;
 }
 
-// Provide a float->bfloat conversion routine in case the runtime doesn't have
-// one.
-extern "C" uint16_t
+// Mark these symbols as weak so they don't conflict when compiler-rt also
+// defines them.
+#define ATTR_WEAK
 #ifdef __has_attribute
 #if __has_attribute(weak) && !defined(__MINGW32__) && !defined(__CYGWIN__) &&  \
     !defined(_WIN32)
-    __attribute__((__weak__))
+#undef ATTR_WEAK
+#define ATTR_WEAK __attribute__((__weak__))
+#endif
 #endif
+
+#if defined(__x86_64__)
+// On x86 bfloat16 is passed in SSE2 registers. Since both float and _Float16
+// are passed in the same register we can use the wider type and careful casting
+// to conform to x86_64 psABI. This only works with the assumption that we're
+// dealing with little-endian values passed in wider registers.
+using BF16ABIType = float;
+#else
+// Default to uint16_t if we have nothing else.
+using BF16ABIType = uint16_t;
 #endif
-    __truncsfbf2(float f) {
-  return float2bfloat(f);
+
+// Provide a float->bfloat conversion routine in case the runtime doesn't have
+// one.
+extern "C" BF16ABIType ATTR_WEAK __truncsfbf2(float f) {
+  uint16_t bf = float2bfloat(f);
+  // The output can be a float type, bitcast it from uint16_t.
+  BF16ABIType ret = 0;
+  std::memcpy(&ret, &bf, sizeof(bf));
+  return ret;
 }
 
 // Provide a double->bfloat conversion routine in case the runtime doesn't have
 // one.
-extern "C" uint16_t
-#ifdef __has_attribute
-#if __has_attribute(weak) && !defined(__MINGW32__) && !defined(__CYGWIN__) &&  \
-    !defined(_WIN32)
-    __attribute__((__weak__))
-#endif
-#endif
-    __truncdfbf2(double d) {
+extern "C" BF16ABIType ATTR_WEAK __truncdfbf2(double d) {
   // This does a double rounding step, but it's precise enough for our use
   // cases.
-  return __truncsfbf2(static_cast<float>(d));
+  uint16_t bf = __truncsfbf2(static_cast<float>(d));
+  // The output can be a float type, bitcast it from uint16_t.
+  BF16ABIType ret = 0;
+  std::memcpy(&ret, &bf, sizeof(bf));
+  return ret;
 }
index fb39236..f776a3d 100644 (file)
@@ -1,6 +1,3 @@
-// FIXME: see #57042
-// UNSUPPORTED: i386, x86_64
-
 // RUN: mlir-opt %s --sparse-compiler | \
 // RUN: mlir-cpu-runner \
 // RUN:  -e entry -entry-point-result=void  \
index 6f6a478..f3307c5 100644 (file)
@@ -1,6 +1,3 @@
-// FIXME: see #57042
-// UNSUPPORTED: i386, x86_64
-
 // RUN: mlir-opt %s --sparse-compiler | \
 // RUN: mlir-cpu-runner \
 // RUN:  -e entry -entry-point-result=void  \