fixing trapezoid() comments for clarity (#64592)
authorKevin Tse <ktse@fb.com>
Wed, 8 Sep 2021 16:42:22 +0000 (09:42 -0700)
committerFacebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Wed, 8 Sep 2021 16:45:46 +0000 (09:45 -0700)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/64592

cc mruberry rgommers heitorschueroff

Test Plan: Imported from OSS

Reviewed By: mrshenli

Differential Revision: D30785663

Pulled By: NivekT

fbshipit-source-id: e968687fbb83a59bb46ce6858c6caafa5aa04412

aten/src/ATen/native/Integration.cpp

index e57dc45..32311d6 100644 (file)
@@ -78,12 +78,12 @@ Tensor trapezoid(const Tensor& y, const Tensor& x, int64_t dim) {
     }
     TORCH_CHECK(y.scalar_type() != kBool && x.scalar_type() != kBool, "trapezoid: received a bool input for `x` or `y`, but bool is not supported")
     Tensor x_viewed;
+    // Note that we explicitly choose not to broadcast 'x' to match the shape of 'y' here because
+    // we want to follow NumPy's behavior of broadcasting 'dx' and 'dy' together after the differences are taken.
     if (x.dim() == 1) {
         // This step takes 'x' with dimension (n,), and returns 'x_view' with
-        // dimension (1,1,...,n,...,1,1) based on dim and y.dim() so that 'x'
-        // can be broadcasted later to match 'y'.
-        // Note: This behavior differs from numpy in that numpy tries to
-        // broadcast 'dx', but this tries to broadcast 'x' to match 'y' instead.
+        // dimension (1,1,...,n,...,1,1) based on dim and y.dim() so that, later on, 'dx'
+        // can be broadcast to match 'dy' at the correct dimensions.
         TORCH_CHECK(x.size(0) == y.size(dim), "trapezoid: There must be one `x` value for each sample point");
         DimVector new_sizes(y.dim(), 1); // shape = [1] * y.
         new_sizes[dim] = x.size(0); // shape[axis] = d.shape[0]
@@ -97,8 +97,8 @@ Tensor trapezoid(const Tensor& y, const Tensor& x, int64_t dim) {
     } else {
         x_viewed = x;
     }
-    // Note the .slice operation reduces the dimension along 'dim' by 1.
-    // The sizes of other dimensions are untouched.
+    // Note the .slice operation reduces the dimension along 'dim' by 1,
+    // while the sizes of other dimensions are untouched.
     Tensor x_left = x_viewed.slice(dim, 0, -1);
     Tensor x_right = x_viewed.slice(dim, 1);
 
@@ -129,11 +129,13 @@ Tensor cumulative_trapezoid(const Tensor& y, const Tensor& x, int64_t dim) {
     TORCH_CHECK(y.scalar_type() != kBool && x.scalar_type() != kBool, "cumulative_trapezoid: received a bool input for `x` or `y`, but bool is not supported")
     Tensor x_viewed;
     if (x.dim() == 1) {
+        // See trapezoid for implementation notes
         TORCH_CHECK(x.size(0) == y.size(dim), "cumulative_trapezoid: There must be one `x` value for each sample point");
         DimVector new_sizes(y.dim(), 1); // shape = [1] * y.
         new_sizes[dim] = x.size(0); // shape[axis] = d.shape[0]
         x_viewed = x.view(new_sizes);
     } else if (x.dim() < y.dim()) {
+        // See trapezoid for implementation notes
         DimVector new_sizes = add_padding_to_shape(x.sizes(), y.dim());
         x_viewed = x.view(new_sizes);
     } else {