More numerically stable lerp (#18871)
authorMarek Kolodziej <mkolod@gmail.com>
Fri, 5 Apr 2019 19:43:02 +0000 (12:43 -0700)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Fri, 5 Apr 2019 19:51:20 +0000 (12:51 -0700)
Summary:
The C++ and CUDA implementations of lerp are not numerically stable. This is discussed on Wikipedia [here](https://en.wikipedia.org/wiki/Linear_interpolation#Programming_language_support). I checked the GPU SASS output and there is no overhead from using the more precise implementation, from Kepler all the way to Turing. I haven't inspected the CPU assembly, though.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/18871

Differential Revision: D14793438

Pulled By: ezyang

fbshipit-source-id: 2ddc2e026c5285466cae7d1b4101174253100445

aten/src/ATen/native/Lerp.cpp
aten/src/ATen/native/cuda/Lerp.cu

index 96f9534..37dfa05 100644 (file)
@@ -13,7 +13,8 @@ void lerp_cpu(at::Tensor& ret, const at::Tensor& self, const at::Tensor& end, co
          const scalar_t& self_val,
          const scalar_t& end_val,
          const scalar_t& weight_val) {
-        ret_val = self_val + weight_val * (end_val - self_val);
+        ret_val = (weight_val < 0.5) ?
+            self_val + weight_val * (end_val - self_val) : end_val - (end_val - self_val) * (1 - weight_val);
       });
 }
 
@@ -24,7 +25,8 @@ void lerp_cpu(at::Tensor& ret, const at::Tensor& self, const at::Tensor& end, sc
       [=](scalar_t& ret_val,
          const scalar_t& self_val,
          const scalar_t& end_val) {
-        ret_val = self_val + weight_val * (end_val - self_val);
+        ret_val = (weight_val < 0.5) ?
+            self_val + weight_val * (end_val - self_val) : end_val - (end_val - self_val) * (1 - weight_val);
       });
 }
 
index 1946427..d7a660e 100644 (file)
@@ -13,7 +13,8 @@ void lerp_cuda(at::Tensor& ret, const at::Tensor& self, const at::Tensor& end, c
          const scalar_t& self_val,
          const scalar_t& end_val,
          const scalar_t& weight_val) {
-        ret_val = self_val + weight_val * (end_val - self_val);
+        ret_val = (weight_val < 0.5) ?
+            self_val + weight_val * (end_val - self_val) : end_val - (end_val - self_val) * (1 - weight_val);
       });
 }
 
@@ -25,7 +26,8 @@ void lerp_cuda(at::Tensor& ret, const at::Tensor& self, const at::Tensor& end, s
          scalar_t& ret_val,
          const scalar_t& self_val,
          const scalar_t& end_val) {
-        ret_val = self_val + weight_val * (end_val - self_val);
+        ret_val = (weight_val < 0.5) ?
+            self_val + weight_val * (end_val - self_val) : end_val - (end_val - self_val) * (1 - weight_val);
       });
 }
 } // namespace