from torch._six import inf, nan
from torch.testing import (
- integral_types_and, floating_and_complex_types_and)
+ integral_types_and, floating_and_complex_types_and, get_all_dtypes)
from torch.testing._internal.common_utils import (
TestCase, run_tests, skipIfNoSciPy, slowTest, torch_to_numpy_dtype_dict,
IS_WINDOWS, make_tensor)
result = op(t, *args, dim=dim, **kwargs)
self.assertEqual(result.shape, _reduced_shape(t.shape, dim))
+ def _test_noncontiguous(self, op: ReductionOpInfo, t: torch.Tensor, **reduction_kwargs):
+ """Helper method to test noncontiguous input tensors."""
+ assert not t.is_contiguous()
+
+ t_contig = t.contiguous()
+ for args, kwargs in op.generate_args_kwargs(t_contig, **reduction_kwargs):
+ kwargs.update(reduction_kwargs)
+ result = op(t, *args, **kwargs)
+ expected = op(t_contig, *args, **kwargs)
+ self.assertEqual(result, expected)
+
+ @ops(reduction_ops)
+ def test_noncontiguous_innermost(self, device, dtype, op: ReductionOpInfo):
+ """Tests reducing along noncontiguous innermost dimension."""
+ t = make_tensor((10, 10), device, dtype)
+ self._test_noncontiguous(op, t[:, ::2], dim=1)
+
+ @ops(reduction_ops)
+ def test_noncontiguous_outermost(self, device, dtype, op: ReductionOpInfo):
+ """Tests reducing along noncontiguous outermost dimension."""
+ t = make_tensor((10, 10), device, dtype)
+ self._test_noncontiguous(op, t[::2, :], dim=0)
+
+ @ops(reduction_ops)
+ def test_noncontiguous_all(self, device, dtype, op: ReductionOpInfo):
+ """Tests reducing all dimensions of a noncontiguous tensor."""
+ t = make_tensor((5, 5, 5), device, dtype)
+ self._test_noncontiguous(op, t[::2, ::3, 1:-1:2])
+
+ @ops(reduction_ops)
+ def test_noncontiguous_transposed(self, device, dtype, op: ReductionOpInfo):
+ """Tests reducing a transposed tensor."""
+ t = make_tensor((5, 5), device, dtype)
+ self._test_noncontiguous(op, t.T)
+
+ @ops(reduction_ops)
+ def test_noncontiguous_expanded(self, device, dtype, op: ReductionOpInfo):
+ """Tests reducing a tensor with expanded singleton dimensions."""
+ t = make_tensor((2, 3), device, dtype)
+ self._test_noncontiguous(op, t.unsqueeze(1).expand(-1, 5, -1))
+
+ # NumPy does not support BFloat16 so we don't test that against reference
+ # implementations. We also don't compare dtypes or test for different
+ # keepdim because we already have other tests covering those.
+ # The test_reference_testing in test_ops.py only uses the samples from
+ # sample_inputs_func which do not test as exhaustively as these tests.
+
+ def _test_ref(self, op: ReductionOpInfo, t: torch.Tensor, **reduction_kwargs):
+ """Compares op against op.ref for the given input and reduction kwargs"""
+ for args, kwargs in op.generate_args_kwargs(t, **reduction_kwargs):
+ kwargs.update(reduction_kwargs)
+ result = op(t, *args, **kwargs)
+ expected = op.ref(t.detach().cpu().numpy(), *args, **kwargs)
+ self.assertEqual(result, expected, exact_dtype=False)
+
+ @ops(filter(lambda op: op.ref is not None, reduction_ops),
+ allowed_dtypes=get_all_dtypes(include_bfloat16=False))
+ def test_ref_scalar_input(self, device, dtype, op: ReductionOpInfo):
+ """Compares op against reference for scalar input tensors"""
+ self._test_ref(op, make_tensor([], device, dtype))
+
+ @ops(filter(lambda op: op.ref is not None, reduction_ops),
+ allowed_dtypes=get_all_dtypes(include_bfloat16=False))
+ def test_ref_small_input(self, device, dtype, op: ReductionOpInfo):
+ """Compares op against reference for small input tensors"""
+ t = make_tensor((5, 3, 4, 2), device, dtype, exclude_zero=True)
+ self._test_ref(op, t)
+ for dim in [0, 1, 3] + ([[0, 2], [1, 3]] if op.supports_multiple_dims else []):
+ self._test_ref(op, t, dim=dim)
+
+ @ops(filter(lambda op: op.ref is not None, reduction_ops),
+ allowed_dtypes=[torch.float32])
+ def test_ref_large_input_1D(self, device, dtype, op: ReductionOpInfo):
+ """Compares op against reference for a large 1D input tensor to check stability"""
+ self._test_ref(op, make_tensor((2 ** 20,), device, dtype, low=-1, high=2, exclude_zero=True))
+
+ @ops(filter(lambda op: op.ref is not None, reduction_ops),
+ allowed_dtypes=[torch.float32])
+ def test_ref_large_input_2D(self, device, dtype, op: ReductionOpInfo):
+ """Compares op against reference for a large 2D input tensor to test parallelism"""
+ t = make_tensor((32, 2 ** 16), device, dtype, low=-1, high=2, exclude_zero=True)
+ self._test_ref(op, t, dim=1)
+
+ @ops(filter(lambda op: op.ref is not None, reduction_ops),
+ allowed_dtypes=[torch.float32])
+ def test_ref_large_input_64bit_indexing(self, device, dtype, op: ReductionOpInfo):
+ """Compares op against reference for a very large input tensor that requires 64 bit indexing"""
+ self._test_ref(op, make_tensor((275000000,), device, dtype, low=-1, high=2, exclude_zero=True))
+
+ @ops(filter(lambda op: op.ref is not None, reduction_ops),
+ allowed_dtypes=get_all_dtypes(include_bfloat16=False))
+ def test_ref_duplicate_values(self, device, dtype, op: ReductionOpInfo):
+ """Compares op against reference for input tensors with duplicate values"""
+ t = make_tensor((8, 8), device, dtype, exclude_zero=True)
+ t[::2, ::2] = t[1::2, 1::2]
+ self._test_ref(op, t)
+ self._test_ref(op, t, dim=0)
+ self._test_ref(op, t, dim=1)
+
+ @ops(filter(lambda op: op.ref is not None, reduction_ops),
+ allowed_dtypes=[torch.float32, torch.complex64])
+ def test_ref_extremal_values(self, device, dtype, op: ReductionOpInfo):
+ """Compares op against reference for input tensors with extremal values"""
+ t = make_tensor((10,), device, dtype, exclude_zero=True)
+ extremals = [0, 1] + [nan, inf, -inf] if torch.is_floating_point(t) else []
+ for extremal in extremals:
+ t[5] = extremal
+ self._test_ref(op, t)
+
###########################################################################
# TODO: Legacy tests - port to ReductionOpInfo
###########################################################################
return tuple(map(to_numpy, x))
elif isinstance(x, dict):
return {k: to_numpy(v) for k, v in x.items()}
+ elif isinstance(x, torch.dtype):
+ return torch.empty(0, dtype=x).numpy().dtype
elif isinstance(x, (numbers.Number, bool, str)):
return x
"""Generates input tensors for testing reduction operators"""
yield make_tensor([], device, dtype, requires_grad=requires_grad)
yield make_tensor([2], device, dtype, requires_grad=requires_grad)
- yield make_tensor([2, 3], device, dtype, requires_grad=requires_grad, noncontiguous=True)
- yield make_tensor([3, 2, 1, 5], device, dtype, requires_grad=requires_grad)
+ yield make_tensor([3, 5], device, dtype, requires_grad=requires_grad, noncontiguous=True)
+ yield make_tensor([3, 2, 1, 2], device, dtype, requires_grad=requires_grad)
def _generate_reduction_kwargs(ndim, supports_multiple_dims=True):
# Override OpInfo defaults and call base class __init__
kwargs.setdefault('inplace_variant', None)
kwargs.setdefault('sample_inputs_func', sample_inputs_func)
+ kwargs.setdefault('default_test_dtypes', (
+ torch.uint8, torch.int64, torch.float16, torch.bfloat16, torch.float32, torch.complex64))
super(ReductionOpInfo, self).__init__(name, **kwargs)
self.identity = identity
return list(generator())
-def sample_inputs_prod(op_info, device, dtype, requires_grad):
- def make_arg(shape):
- # shrink values to be in the interval [-1, +1] for better precision in gradgradcheck
- return make_tensor(shape, device, dtype, low=-1, high=+1, requires_grad=requires_grad)
-
- def prod_single_zero():
- result = make_arg(2 * (S,))
- with torch.no_grad():
- result[0, 1] = 0
- return result
-
- # will not be needed once OpInfo tests support Iterables
- def sample_generator():
- for sample in sample_inputs_cumprod(op_info, device, dtype, requires_grad):
- yield SampleInput(sample.input) # only Tensor, ignore other inputs
- yield sample
- sample.kwargs['keepdim'] = True
- yield sample
- yield SampleInput(prod_single_zero())
- yield SampleInput(make_arg((3, 3, 3)), args=(1,))
- yield SampleInput(make_arg((3, 3, 3)), args=(1,), kwargs={'keepdim': True})
-
- # test zero scalar tensor
- zero = make_arg(())
- with torch.no_grad():
- zero.zero_()
- yield SampleInput(zero)
- yield SampleInput(zero, args=(0,))
- yield SampleInput(zero, args=(0,), kwargs={'keepdim': True})
-
- return list(sample_generator())
-
def sample_inputs_nextafter(op_info, device, dtype, requires_grad, **kwargs):
make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad)
return op(input.triu() if upper else input.tril(), upper)
+def reference_reduction_numpy(f, supports_keepdims=True):
+ """Wraps a NumPy reduction operator.
+
+ The wrapper function will forward dim and keepdim kwargs to the wrapped
+ function as the NumPy equivalent axis and keepdims kwargs.
+
+ Args:
+ f: NumPy reduction operator to wrap
+ supports_keepdims (bool, optional): Whether the NumPy operator accepts
+ keepdims parameter. If it does not, the wrapper will manually unsqueeze
+ the reduced dimensions if it was called with keepdim=True. Defaults to True.
+
+ Returns:
+ Wrapped function
+ """
+ @wraps(f)
+ def wrapper(x: np.ndarray, *args, **kwargs):
+ # Copy keys into a set
+ keys = set(kwargs.keys())
+
+ dim = kwargs.pop('dim', None)
+ keepdim = kwargs.pop('keepdim', False)
+
+ if 'dim' in keys:
+ if x.ndim == 0:
+ # NumPy reductions don't accept dim=0 for scalar inputs
+ for i in dim if isinstance(dim, tuple) else (dim,):
+ assert i in {0, -1}
+ kwargs['axis'] = None
+ else:
+ kwargs['axis'] = tuple(dim) if isinstance(dim, Sequence) else dim
+
+ if 'keepdim' in keys and supports_keepdims:
+ kwargs['keepdims'] = keepdim
+
+ result = f(x, *args, **kwargs)
+
+ # Unsqueeze reduced dimensions if NumPy does not support keepdims
+ if keepdim and not supports_keepdims and x.ndim > 0:
+ dim = list(range(x.ndim)) if dim is None else dim
+ result = np.expand_dims(result, dim)
+
+ return result
+
+ return wrapper
+
+
# Operator database (sorted alphabetically)
op_db: List[OpInfo] = [
UnaryUfuncInfo('abs',
supports_out=False,
supports_forward_ad=True,
sample_inputs_func=sample_inputs_max_min_reduction_no_dim,),
- # TODO(@heitorschueroff) Add test for dtype kwarg
- OpInfo('mean',
- dtypes=floating_and_complex_types_and(torch.float16, torch.bfloat16),
- assert_autodiffed=True,
- supports_forward_ad=True,
- sample_inputs_func=sample_inputs_reduction,
- # Need to skip out test because one of the overload for mean does not support it
- # TODO(@heitorschueroff) fix this when implementing ReductionInfo
- skips=(SkipInfo('TestCommon', 'test_out'),)),
OpInfo('quantile',
dtypes=floating_types(),
sample_inputs_func=sample_inputs_reduction_quantile),
supports_autograd=False,
result_dtype=torch.bool,
dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16),
+ ref=reference_reduction_numpy(np.all),
skips=(
# FIXME: does not support passing keepdim without dim
SkipInfo('TestReductions', 'test_dim_default_keepdim'),
SkipInfo('TestReductions', 'test_dim_none'),
SkipInfo('TestReductions', 'test_dim_none_keepdim'),
# FIXME: uint8 input returns uint8 instead of bool
- SkipInfo('TestReductions', 'test_result_dtype', dtypes=[torch.uint8]),
+ SkipInfo('TestReductions', 'test_result_dtype',
+ dtypes=[torch.uint8]),
),
),
ReductionOpInfo(
supports_autograd=False,
result_dtype=torch.bool,
dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16),
+ ref=reference_reduction_numpy(np.any),
skips=(
# FIXME: does not support passing keepdim without dim
SkipInfo('TestReductions', 'test_dim_default_keepdim'),
SkipInfo('TestReductions', 'test_dim_none'),
SkipInfo('TestReductions', 'test_dim_none_keepdim'),
# FIXME: uint8 input returns uint8 instead of bool
- SkipInfo('TestReductions', 'test_result_dtype', dtypes=[torch.uint8]),
+ SkipInfo('TestReductions', 'test_result_dtype',
+ dtypes=[torch.uint8]),
),
),
ReductionOpInfo(
'amax',
nan_policy='propagate',
dtypes=all_types_and(torch.float16, torch.bfloat16, torch.bool),
- ref=lambda a, dim=None, keepdim=False, **kwargs: np.amax(a, axis=dim, keepdims=keepdim, **kwargs),
+ ref=reference_reduction_numpy(np.amax),
skips=(
# FIXME: sum reduces all dimensions when dim=[]
SkipInfo('TestReductions', 'test_dim_empty'),
'amin',
nan_policy='propagate',
dtypes=all_types_and(torch.float16, torch.bfloat16, torch.bool),
- ref=lambda a, dim=None, keepdim=False, **kwargs: np.amin(a, axis=dim, keepdims=keepdim, **kwargs),
+ ref=reference_reduction_numpy(np.amin),
skips=(
# FIXME: sum reduces all dimensions when dim=[]
SkipInfo('TestReductions', 'test_dim_empty'),
supports_autograd=False,
result_dtype=torch.int64,
dtypes=all_types_and(torch.float16, torch.bfloat16),
+ ref=reference_reduction_numpy(np.argmax, supports_keepdims=False),
skips=(
# FIXME: keepdim parameter is ignored when dim=None
SkipInfo('TestReductions', 'test_dim_default_keepdim'),
supports_autograd=False,
result_dtype=torch.int64,
dtypes=all_types_and(torch.float16, torch.bfloat16),
+ ref=reference_reduction_numpy(np.argmin, supports_keepdims=False),
skips=(
# FIXME: keepdim parameter is ignored when dim=None
SkipInfo('TestReductions', 'test_dim_default_keepdim'),
result_dtype=torch.int64,
dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16),
sample_inputs_func=sample_inputs_reduction_count_nonzero,
+ ref=reference_reduction_numpy(np.count_nonzero),
skips=(
# FIXME: count_nonzero does not accept keepdim kwarg
SkipInfo('TestReductions', 'test_dim_default_keepdim'),
),
),
ReductionOpInfo(
+ 'mean',
+ nan_policy='propagate',
+ supports_out=False,
+ supports_forward_ad=True,
+ assert_autodiffed=True,
+ promotes_int_to_float=True,
+ dtypes=floating_and_complex_types_and(torch.float16, torch.bfloat16),
+ ref=reference_reduction_numpy(np.mean),
+ decorators=(
+ # FIXME: fix precision
+ DecorateInfo(toleranceOverride({
+ torch.float16: tol(atol=1e-05, rtol=1e-02),
+ }), 'TestReductions', 'test_noncontiguous_all'),
+ DecorateInfo(toleranceOverride({
+ torch.float16: tol(atol=1e-05, rtol=1e-02),
+ }), 'TestReductions', 'test_ref_small_input'),
+ ),
+ skips=(
+ # FIXME: prod does not support passing keepdim without passing dim
+ SkipInfo('TestReductions', 'test_dim_default_keepdim'),
+ # FIXME: prod reduces all dimensions when dim=[]
+ SkipInfo('TestReductions', 'test_dim_empty'),
+ SkipInfo('TestReductions', 'test_dim_empty_keepdim'),
+ # FIXME: prod does not support passing None to dim
+ SkipInfo('TestReductions', 'test_dim_none'),
+ SkipInfo('TestReductions', 'test_dim_none_keepdim'),
+ ),
+ ),
+ ReductionOpInfo(
'prod',
identity=1,
nan_policy='propagate',
gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
dtypes=all_types_and_complex_and(torch.bool),
dtypesIfCUDA=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16),
- sample_inputs_func=sample_inputs_prod,
+ ref=reference_reduction_numpy(np.prod),
skips=(
# FIXME: prod does not support passing keepdim without passing dim
SkipInfo('TestReductions', 'test_dim_default_keepdim'),
# FIXME: prod does not support passing None to dim
SkipInfo('TestReductions', 'test_dim_none'),
SkipInfo('TestReductions', 'test_dim_none_keepdim'),
+ # FIXME: improve precision, failing with nan != inf
+ SkipInfo('TestReductions', 'test_ref_small_input',
+ dtypes=[torch.float16, torch.complex64]),
+ SkipInfo('TestReductions', 'test_ref_duplicate_values',
+ dtypes=[torch.uint8, torch.float16, torch.complex64]),
),
),
ReductionOpInfo(
supports_forward_ad=True,
promotes_int_to_int64=True,
dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16),
+ ref=reference_reduction_numpy(np.sum),
+ decorators=(
+ # FIXME: fix precision
+ DecorateInfo(toleranceOverride({
+ torch.float16: tol(atol=1e-05, rtol=1e-02),
+ }), 'TestReductions', 'test_noncontiguous_all'),
+ DecorateInfo(toleranceOverride({
+ torch.float16: tol(atol=1e-03, rtol=1e-02),
+ }), 'TestReductions', 'test_ref_small_input'),
+ DecorateInfo(toleranceOverride({
+ torch.float32: tol(atol=1e-03, rtol=1e-03),
+ }), 'TestReductions', 'test_ref_large_input_64bit_indexing'),
+ DecorateInfo(toleranceOverride({
+ torch.float16: tol(atol=1e-05, rtol=1e-02),
+ }), 'TestReductions', 'test_ref_duplicate_values'),
+ ),
skips=(
# FIXME: sum does not support passing keepdim without passing dim
SkipInfo('TestReductions', 'test_dim_default_keepdim'),
supports_out=False,
promotes_int_to_int64=True,
dtypes=all_types_and(torch.bool, torch.float16, torch.bfloat16),
+ ref=reference_reduction_numpy(np.nansum),
+ decorators=(
+ # FIXME: fix precision
+ DecorateInfo(toleranceOverride({
+ torch.float16: tol(atol=1e-05, rtol=1e-02),
+ }), 'TestReductions', 'test_noncontiguous_all'),
+ DecorateInfo(toleranceOverride({
+ torch.float16: tol(atol=1e-03, rtol=1e-02),
+ }), 'TestReductions', 'test_ref_small_input'),
+ DecorateInfo(toleranceOverride({
+ torch.float32: tol(atol=1e-03, rtol=1e-03),
+ }), 'TestReductions', 'test_ref_large_input_64bit_indexing'),
+ DecorateInfo(toleranceOverride({
+ torch.float16: tol(atol=1e-05, rtol=1e-02),
+ }), 'TestReductions', 'test_ref_duplicate_values'),
+ ),
skips=(
# FIXME: nansum does not support passing keepdim without passing dim
SkipInfo('TestReductions', 'test_dim_default_keepdim'),