From 1d522598fb97e26f9e518116a9c7fa0f22735d8d Mon Sep 17 00:00:00 2001 From: Jongsoo Park Date: Thu, 7 Mar 2019 02:17:42 -0800 Subject: [PATCH] use fp16<->fp32 intrinsic (#17496) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/17496 As title. Reviewed By: hyuen Differential Revision: D14222907 fbshipit-source-id: d5d6c032e725ca8b52aca2be7401ec3c59f6a242 --- caffe2/perfkernels/adagrad_avx.cc | 9 ++++++--- caffe2/perfkernels/cvtsh_ss_bugfix.h | 16 ++++++++++++++++ 2 files changed, 22 insertions(+), 3 deletions(-) diff --git a/caffe2/perfkernels/adagrad_avx.cc b/caffe2/perfkernels/adagrad_avx.cc index c1de220..8631d42 100644 --- a/caffe2/perfkernels/adagrad_avx.cc +++ b/caffe2/perfkernels/adagrad_avx.cc @@ -104,9 +104,12 @@ void adagrad_fp16_update_prefetch__avx_f16c( for (; i < N; ++i) { float gi = g[i]; - float hi = h[i] + gi * gi; - nh[i] = hi; - nw[i] = w[i] + lr * gi / (std::sqrt(hi) + epsilon); + float nhi = + _cvtsh_ss(reinterpret_cast(h)[i]) + gi * gi; + reinterpret_cast(nh)[i] = _cvtss_sh(nhi, 0); + float nwi = _cvtsh_ss(reinterpret_cast(w)[i]) + + lr * gi / (std::sqrt(nhi) + epsilon); + reinterpret_cast(nw)[i] = _cvtss_sh(nwi, 0); } } diff --git a/caffe2/perfkernels/cvtsh_ss_bugfix.h b/caffe2/perfkernels/cvtsh_ss_bugfix.h index 825e266..d5b02f5 100644 --- a/caffe2/perfkernels/cvtsh_ss_bugfix.h +++ b/caffe2/perfkernels/cvtsh_ss_bugfix.h @@ -12,6 +12,7 @@ #if __APPLE_NEED_FIX || __CLANG_NEED_FIX +#include #include // This version of clang has a bug that _cvtsh_ss is not defined, see @@ -25,6 +26,14 @@ _cvtsh_ss(unsigned short a) return r[0]; } +static __inline unsigned short + __attribute__((__always_inline__, __nodebug__, __target__("f16c"))) +_cvtss_sh(float a, int imm8) { + unsigned short ret; + *reinterpret_cast(&ret) = a; + return ret; +} + #endif // __APPLE_NEED_FIX || __CLANG_NEED_FIX #undef __APPLE_NEED_FIX @@ -32,6 +41,7 @@ _cvtsh_ss(unsigned short a) #ifdef _MSC_VER +#include #include // It seems that microsoft msvc does not have a _cvtsh_ss implementation so @@ -54,4 +64,10 @@ static inline float _cvtsh_ss(unsigned short x) { return t1.floatval; } +static inline unsigned short _cvtss_sh(float x, int imm8) { + unsigned short ret; + *reinterpret_cast(&ret) = x; + return ret; +} + #endif // _MSC_VER -- 2.7.4