fix resize bug (#61166)
authorBBuf <1182563586@qq.com>
Fri, 27 Aug 2021 17:42:24 +0000 (10:42 -0700)
committerFacebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Fri, 27 Aug 2021 17:49:31 +0000 (10:49 -0700)
Summary:
I think the original intention here is to only take effect in the case of align_corners (because output_size = 1 and the divisor will be 0), but it affects non-align_corners too. For example:

```python
input = torch.tensor(
        np.arange(1, 5, dtype=np.int32).reshape((1, 1, 2, 2)) )
m = torch.nn.Upsample(scale_factor=0.5, mode="bilinear")
of_out = m(input)
```

The result we expect should be [[[[2.5]]]]

but pytorch get [[[[1.0]]]] which is different from OpenCV  and PIL, this pr try to fixed it。

Pull Request resolved: https://github.com/pytorch/pytorch/pull/61166

Reviewed By: malfet

Differential Revision: D30543178

Pulled By: heitorschueroff

fbshipit-source-id: 21a4035483981986b0ae4a401ef0efbc565ccaf1

aten/src/ATen/native/UpSample.h
aten/src/ATen/native/cuda/UpSample.cuh
test/test_nn.py

index e50b053..602abce 100644 (file)
@@ -251,12 +251,16 @@ static inline scalar_t area_pixel_compute_scale(
     bool align_corners,
     const c10::optional<double> scale) {
   // see Note [area_pixel_compute_scale]
-  if (output_size > 1) {
-    return align_corners
-        ? static_cast<scalar_t>(input_size - 1) / (output_size - 1)
-        : compute_scales_value<scalar_t>(scale, input_size, output_size);
-  } else {
-    return scalar_t(0);
+  if(align_corners){
+    if(output_size > 1) {
+      return static_cast<scalar_t>(input_size - 1) / (output_size - 1);
+    }
+    else {
+      return static_cast<scalar_t>(0);
+    }
+  }
+  else{
+    return compute_scales_value<scalar_t>(scale, input_size, output_size);
   }
 }
 
index 71443e1..c69a259 100644 (file)
@@ -94,11 +94,16 @@ __host__ __forceinline__ static accscalar_t area_pixel_compute_scale(
     int output_size,
     bool align_corners,
     const c10::optional<double> scale) {
-  if (output_size > 1) {
-    return align_corners ? (accscalar_t)(input_size - 1) / (output_size - 1)
-                         :  compute_scales_value<accscalar_t>(scale, input_size, output_size);
-  } else {
-    return static_cast<accscalar_t>(0);
+  if(align_corners) {
+    if(output_size > 1) {
+      return (accscalar_t)(input_size - 1) / (output_size - 1);
+    }
+    else {
+      return static_cast<accscalar_t>(0);
+    }
+  }
+  else{
+    return compute_scales_value<accscalar_t>(scale, input_size, output_size);
   }
 }
 
index c6fe0b2..4e01c94 100644 (file)
@@ -10475,6 +10475,13 @@ class TestNN(NNTestCase):
             out_t_5 = m(in_t_9[:, :, :5, :5, :5])
         self.assertEqual(out_t_9[:, :, :15, :15, :15], out_t_5)
 
+    def test_upsampling_small_scale(self):
+        m = torch.nn.Upsample(scale_factor=0.5, mode="bilinear")
+        in_t = torch.arange(1, 5, dtype=torch.float64).reshape(1, 1, 2, 2)
+        out_t = m(in_t)
+        expected_out_t = torch.tensor([[[[2.5]]]])
+        self.assertEqual(expected_out_t, out_t)
+
     @unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
     def test_interpolate_illegal_memory_access(self):
         in_s = 45