}
TORCH_CHECK(y.scalar_type() != kBool && x.scalar_type() != kBool, "trapezoid: received a bool input for `x` or `y`, but bool is not supported")
Tensor x_viewed;
+ // Note that we explicitly choose not to broadcast 'x' to match the shape of 'y' here because
+ // we want to follow NumPy's behavior of broadcasting 'dx' and 'dy' together after the differences are taken.
if (x.dim() == 1) {
// This step takes 'x' with dimension (n,), and returns 'x_view' with
- // dimension (1,1,...,n,...,1,1) based on dim and y.dim() so that 'x'
- // can be broadcasted later to match 'y'.
- // Note: This behavior differs from numpy in that numpy tries to
- // broadcast 'dx', but this tries to broadcast 'x' to match 'y' instead.
+ // dimension (1,1,...,n,...,1,1) based on dim and y.dim() so that, later on, 'dx'
+ // can be broadcast to match 'dy' at the correct dimensions.
TORCH_CHECK(x.size(0) == y.size(dim), "trapezoid: There must be one `x` value for each sample point");
DimVector new_sizes(y.dim(), 1); // shape = [1] * y.
new_sizes[dim] = x.size(0); // shape[axis] = d.shape[0]
} else {
x_viewed = x;
}
- // Note the .slice operation reduces the dimension along 'dim' by 1.
- // The sizes of other dimensions are untouched.
+ // Note the .slice operation reduces the dimension along 'dim' by 1,
+ // while the sizes of other dimensions are untouched.
Tensor x_left = x_viewed.slice(dim, 0, -1);
Tensor x_right = x_viewed.slice(dim, 1);
TORCH_CHECK(y.scalar_type() != kBool && x.scalar_type() != kBool, "cumulative_trapezoid: received a bool input for `x` or `y`, but bool is not supported")
Tensor x_viewed;
if (x.dim() == 1) {
+ // See trapezoid for implementation notes
TORCH_CHECK(x.size(0) == y.size(dim), "cumulative_trapezoid: There must be one `x` value for each sample point");
DimVector new_sizes(y.dim(), 1); // shape = [1] * y.
new_sizes[dim] = x.size(0); // shape[axis] = d.shape[0]
x_viewed = x.view(new_sizes);
} else if (x.dim() < y.dim()) {
+ // See trapezoid for implementation notes
DimVector new_sizes = add_padding_to_shape(x.sizes(), y.dim());
x_viewed = x.view(new_sizes);
} else {