From d37636901ed1c65c1f8b68e36e37e59eb503c554 Mon Sep 17 00:00:00 2001
From: Kushashwa Ravi Shrimali
Date: Mon, 30 Aug 2021 12:16:23 -0700
Subject: [PATCH] [Doc] `make_tensor` to `torch.testing` module (#63925)

Summary:
This PR aims to add `make_tensor` to the `torch.testing` module in PyTorch docs.

TODOs:

* [x] Add examples

cc: pmeier mruberry brianjo
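
An illustrative usage sketch of the newly documented API (a sketch only; the
sampled values are random, so any outputs are representative, not exact):

    import torch
    from torch.testing import make_tensor

    # Floating dtypes default to values drawn uniformly from [-9, 9)
    a = make_tensor((2, 3), device='cpu', dtype=torch.float32)

    # low is inclusive, high is exclusive; out-of-range bounds are clamped
    # to the dtype's representable finite range
    b = make_tensor((4,), device='cpu', dtype=torch.int64, low=0, high=5)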

Pull Request resolved: https://github.com/pytorch/pytorch/pull/63925

Reviewed By: ngimel

Differential Revision: D30633487

Pulled By: mruberry

fbshipit-source-id: 8e5a1f880c6ece5925b4039fee8122bd739538af
---
 docs/source/testing.rst                     |   1 +
 test/test_autograd.py                       |   3 +-
 test/test_binary_ufuncs.py                  |   4 +-
 test/test_buffer_protocol.py                |   7 +-
 test/test_foreach.py                        |   4 +-
 test/test_indexing.py                       |   3 +-
 test/test_jit.py                            |   3 +-
 test/test_linalg.py                         |   4 +-
 test/test_ops.py                            |   4 +-
 test/test_reductions.py                     |   4 +-
 test/test_shape_ops.py                      |   3 +-
 test/test_sort_and_select.py                |   4 +-
 test/test_sparse.py                         |   3 +-
 test/test_sparse_csr.py                     |   4 +-
 test/test_tensor_creation_ops.py            |   3 +-
 test/test_testing.py                        |   3 +-
 test/test_torch.py                          |   3 +-
 test/test_unary_ufuncs.py                   |   4 +-
 test/test_view_ops.py                       |   3 +-
 torch/testing/__init__.py                   |   1 +
 torch/testing/_creation.py                  | 155 +++++++++++++++++++++
 .../_internal/common_methods_invocations.py |   4 +-
 torch/testing/_internal/common_modules.py   |  26 ++--
 torch/testing/_internal/common_utils.py     | 100 +------------
 24 files changed, 213 insertions(+), 140 deletions(-)
 create mode 100644 torch/testing/_creation.py

diff --git a/docs/source/testing.rst b/docs/source/testing.rst
index 981a636..9f1e2c3 100644
--- a/docs/source/testing.rst
+++ b/docs/source/testing.rst
@@ -9,3 +9,4 @@ torch.testing
 
 .. automodule:: torch.testing
 .. autofunction:: assert_close
+.. autofunction:: make_tensor
diff --git a/test/test_autograd.py b/test/test_autograd.py
index 4d41645..364d488 100644
--- a/test/test_autograd.py
+++ b/test/test_autograd.py
@@ -24,13 +24,14 @@ from torch.autograd.profiler import (profile, record_function, emit_nvtx)
 from torch.autograd.profiler_util import (_format_time, EventList, FunctionEvent, FunctionEventAvg)
 import torch.autograd.functional as autogradF
 from torch.utils.checkpoint import checkpoint
+from torch.testing import make_tensor
 from torch.testing._internal.common_cuda import TEST_CUDA
 from torch.testing._internal.common_utils import (TestCase, run_tests, skipIfNoLapack,
                                                   suppress_warnings, slowTest,
                                                   load_tests,
                                                   IS_WINDOWS, IS_MACOS, CudaMemoryLeakCheck,
                                                   TEST_WITH_ROCM, disable_gc,
-                                                  gradcheck, gradgradcheck, make_tensor)
+                                                  gradcheck, gradgradcheck)
 from torch.autograd import Variable, Function, detect_anomaly, kineto_available
 from torch.autograd.function import InplaceFunction
 import torch.autograd.forward_ad as fwAD
diff --git a/test/test_binary_ufuncs.py b/test/test_binary_ufuncs.py
index 4995e0d..1e9e804 100644
--- a/test/test_binary_ufuncs.py
+++ b/test/test_binary_ufuncs.py
@@ -13,12 +13,12 @@ from functools import partial
 from torch._six import inf, nan
 from torch.testing._internal.common_utils import (
     TestCase, iter_indices, TEST_WITH_ASAN, run_tests,
-    torch_to_numpy_dtype_dict, make_tensor, TEST_SCIPY, set_default_dtype)
+    torch_to_numpy_dtype_dict, TEST_SCIPY, set_default_dtype)
 from torch.testing._internal.common_device_type import (
     instantiate_device_type_tests, onlyCUDA, onlyCPU, dtypes, dtypesIfCUDA,
     dtypesIfCPU, deviceCountAtLeast, precisionOverride, onlyOnCPUAndCUDA,
     skipCUDAIfRocm, skipIf, ops)
-from torch.testing import all_types_and_complex_and, integral_types_and
+from torch.testing import all_types_and_complex_and, integral_types_and, make_tensor
 from torch.testing._internal.common_methods_invocations import binary_ufuncs
 
 if TEST_SCIPY:
diff --git a/test/test_buffer_protocol.py b/test/test_buffer_protocol.py
index c797b91..619386e 100644
--- a/test/test_buffer_protocol.py
+++ b/test/test_buffer_protocol.py
@@ -1,4 +1,5 @@
 import torch.testing._internal.common_utils as common
+from torch.testing import make_tensor
 from torch.testing._internal.common_device_type import (
     instantiate_device_type_tests,
     dtypes
@@ -23,7 +24,7 @@ class TestBufferProtocol(common.TestCase):
         if offset is None:
             offset = first * get_dtype_size(dtype)
 
-        numpy_original = common.make_tensor(shape, torch.device("cpu"), dtype).numpy()
+        numpy_original = make_tensor(shape, torch.device("cpu"), dtype).numpy()
         original = memoryview(numpy_original)
         # First call PyTorch's version in case of errors.
         # If this call exits successfully, the NumPy version must also do so.
@@ -125,7 +126,7 @@ class TestBufferProtocol(common.TestCase):
 
     @dtypes(*common.torch_to_numpy_dtype_dict.keys())
     def test_shared_buffer(self, device, dtype):
-        x = common.make_tensor((1,), device, dtype)
+        x = make_tensor((1,), device, dtype)
         # Modify the whole tensor
         arr, tensor = self._run_test(SHAPE, dtype)
         tensor[:] = x
@@ -158,7 +159,7 @@ class TestBufferProtocol(common.TestCase):
 
     @dtypes(*common.torch_to_numpy_dtype_dict.keys())
     def test_non_writable_buffer(self, device, dtype):
-        numpy_arr = common.make_tensor((1,), device, dtype).numpy()
+        numpy_arr = make_tensor((1,), device, dtype).numpy()
         byte_arr = numpy_arr.tobytes()
         with self.assertWarnsOnceRegex(UserWarning,
                                        r"The given buffer is not writable."):
diff --git a/test/test_foreach.py b/test/test_foreach.py
index ce9b0d7..123ef35 100644
--- a/test/test_foreach.py
+++ b/test/test_foreach.py
@@ -4,11 +4,13 @@ import random
 import re
 import torch
 import unittest
+
+from torch.testing import make_tensor
 from torch.testing._internal.common_utils import TestCase, run_tests, TEST_WITH_ROCM, TEST_WITH_SLOW
 from torch.testing._internal.common_device_type import \
     (instantiate_device_type_tests, dtypes, onlyCUDA, skipCUDAIfRocm, skipMeta, ops)
 from torch.testing._internal.common_methods_invocations import \
-    (foreach_unary_op_db, foreach_binary_op_db, foreach_pointwise_op_db, foreach_minmax_op_db, make_tensor)
+    (foreach_unary_op_db, foreach_binary_op_db, foreach_pointwise_op_db, foreach_minmax_op_db)
 
 # Includes some values such that N * N won't be a multiple of 4,
 # which should ensure we test the vectorized and non-vectorized
diff --git a/test/test_indexing.py b/test/test_indexing.py
index 6158091..8b8a2ea 100644
--- a/test/test_indexing.py
+++ b/test/test_indexing.py
@@ -8,7 +8,8 @@ from functools import reduce
 
 import numpy as np
 
-from torch.testing._internal.common_utils import TestCase, run_tests, make_tensor
+from torch.testing import make_tensor
+from torch.testing._internal.common_utils import TestCase, run_tests
 from torch.testing._internal.common_device_type import (
     instantiate_device_type_tests, onlyCUDA, dtypes, dtypesIfCPU,
     dtypesIfCUDA, onlyOnCPUAndCUDA)
diff --git a/test/test_jit.py b/test/test_jit.py
index 2595411..d1a170d 100644
--- a/test/test_jit.py
+++ b/test/test_jit.py
@@ -69,8 +69,7 @@ from torch._six import PY37
 from torch.autograd import Variable
 from torch.jit.annotations import BroadcastingList2, BroadcastingList3, Any  # noqa: F401
 from torch.nn.utils.rnn import PackedSequence
-from torch.testing import FileCheck
-from torch.testing._internal.common_utils import make_tensor
+from torch.testing import FileCheck, make_tensor
 import torch.autograd.profiler
 import torch.cuda
 import torch.jit
diff --git a/test/test_linalg.py b/test/test_linalg.py
index 8ba3373..f7ce392 100644
--- a/test/test_linalg.py
+++ b/test/test_linalg.py
@@ -14,14 +14,14 @@ from functools import reduce
 
 from torch.testing._internal.common_utils import \
     (TestCase, run_tests, TEST_SCIPY, IS_MACOS, IS_WINDOWS, slowTest,
-     TEST_WITH_ASAN, make_tensor, TEST_WITH_ROCM, IS_FBCODE, IS_REMOTE_GPU,
+     TEST_WITH_ASAN, TEST_WITH_ROCM, IS_FBCODE, IS_REMOTE_GPU,
      iter_indices, gradcheck, gradgradcheck)
 from torch.testing._internal.common_device_type import \
     (instantiate_device_type_tests, dtypes,
      onlyCPU, skipCUDAIf, skipCUDAIfNoMagma, skipCPUIfNoLapack, precisionOverride,
      skipCUDAIfNoMagmaAndNoCusolver, skipCUDAIfRocm, onlyOnCPUAndCUDA, dtypesIfCUDA,
      onlyCUDA, skipCUDAVersionIn, skipMeta, skipCUDAIfNoCusolver)
-from torch.testing import floating_and_complex_types, floating_types, all_types
+from torch.testing import floating_and_complex_types, floating_types, all_types, make_tensor
 from torch.testing._internal.common_cuda import SM53OrLater, tf32_on_and_off, CUDA11OrLater, CUDA9
 from torch.distributions.binomial import Binomial
 
diff --git a/test/test_ops.py b/test/test_ops.py
index a6baf8d..27aee72 100644
--- a/test/test_ops.py
+++ b/test/test_ops.py
@@ -5,9 +5,9 @@ import warnings
 import torch
 
 from torch.testing import \
-    (FileCheck, floating_and_complex_types_and, get_all_dtypes)
+    (FileCheck, floating_and_complex_types_and, get_all_dtypes, make_tensor)
 from torch.testing._internal.common_utils import \
-    (TestCase, is_iterable_of_tensors, run_tests, IS_SANDCASTLE, clone_input_helper, make_tensor,
+    (TestCase, is_iterable_of_tensors, run_tests, IS_SANDCASTLE, clone_input_helper,
      gradcheck, gradgradcheck, IS_IN_CI, suppress_warnings)
 from torch.testing._internal.common_methods_invocations import \
     (op_db, _NOTHING, UnaryUfuncInfo, ReductionOpInfo, SpectralFuncInfo)
diff --git a/test/test_reductions.py b/test/test_reductions.py
index eed7f73..ca3042b 100644
--- a/test/test_reductions.py
+++ b/test/test_reductions.py
@@ -10,10 +10,10 @@ import warnings
 
 from torch._six import inf, nan
 from torch.testing import (
-    integral_types_and, floating_and_complex_types_and, get_all_dtypes)
+    integral_types_and, floating_and_complex_types_and, get_all_dtypes, make_tensor)
 from torch.testing._internal.common_utils import (
     TestCase, run_tests, skipIfNoSciPy, slowTest, torch_to_numpy_dtype_dict,
-    IS_WINDOWS, make_tensor)
+    IS_WINDOWS)
 from torch.testing._internal.common_device_type import (
     OpDTypes, instantiate_device_type_tests, onlyCPU, dtypes, dtypesIfCUDA, dtypesIfCPU,
     onlyOnCPUAndCUDA, onlyCUDA, largeTensorTest, ops, precisionOverride)
diff --git a/test/test_shape_ops.py b/test/test_shape_ops.py
index 916adee..cb4ec3c 100644
--- a/test/test_shape_ops.py
+++ b/test/test_shape_ops.py
@@ -7,8 +7,9 @@ import random
 import warnings
 
 from torch._six import nan
+from torch.testing import make_tensor
 from torch.testing._internal.common_utils import (
-    TestCase, run_tests, make_tensor, torch_to_numpy_dtype_dict)
+    TestCase, run_tests, torch_to_numpy_dtype_dict)
 from torch.testing._internal.common_device_type import (
     instantiate_device_type_tests, onlyCPU, onlyCUDA, dtypes, onlyOnCPUAndCUDA,
     dtypesIfCPU, dtypesIfCUDA, largeTensorTest)
diff --git a/test/test_sort_and_select.py b/test/test_sort_and_select.py
index 564258a..e562e38 100644
--- a/test/test_sort_and_select.py
+++ b/test/test_sort_and_select.py
@@ -5,9 +5,9 @@ import random
 
 from torch._six import nan
 from itertools import permutations, product
-from torch.testing import all_types, all_types_and
+from torch.testing import all_types, all_types_and, make_tensor
 from torch.testing._internal.common_utils import \
-    (TEST_WITH_ROCM, TestCase, run_tests, make_tensor, slowTest)
+    (TEST_WITH_ROCM, TestCase, run_tests, slowTest)
 from torch.testing._internal.common_device_type import \
     (instantiate_device_type_tests, dtypes, onlyOnCPUAndCUDA,
      skipCUDAIfRocm, onlyCUDA, dtypesIfCUDA, dtypesIfCPU, onlyCPU, largeTensorTest)
diff --git a/test/test_sparse.py b/test/test_sparse.py
index abe5e93..333f29f 100644
--- a/test/test_sparse.py
+++ b/test/test_sparse.py
@@ -5,8 +5,9 @@ import operator
 import random
 from collections import defaultdict
 import unittest
+from torch.testing import make_tensor
 from torch.testing._internal.common_utils import TestCase, run_tests, skipIfRocm, do_test_dtypes, \
-    do_test_empty_full, load_tests, TEST_NUMPY, TEST_SCIPY, IS_WINDOWS, gradcheck, coalescedonoff, make_tensor, \
+    do_test_empty_full, load_tests, TEST_NUMPY, TEST_SCIPY, IS_WINDOWS, gradcheck, coalescedonoff, \
     DeterministicGuard
 from torch.testing._internal.common_cuda import TEST_CUDA, _get_torch_cuda_version
 from numbers import Number
diff --git a/test/test_sparse_csr.py b/test/test_sparse_csr.py
index b9f4885..fbb2b30 100644
--- a/test/test_sparse_csr.py
+++ b/test/test_sparse_csr.py
@@ -3,8 +3,10 @@ import warnings
 import unittest
 import random
 import itertools
+
+from torch.testing import make_tensor
 from torch.testing._internal.common_utils import \
-    (IS_MACOS, IS_WINDOWS, TestCase, run_tests, load_tests, coalescedonoff, make_tensor)
+    (IS_MACOS, IS_WINDOWS, TestCase, run_tests, load_tests, coalescedonoff)
 from torch.testing._internal.common_device_type import \
     (instantiate_device_type_tests, dtypes, onlyCPU, onlyCUDA)
 
diff --git a/test/test_tensor_creation_ops.py b/test/test_tensor_creation_ops.py
index 192e03f..9ef3742 100644
--- a/test/test_tensor_creation_ops.py
+++ b/test/test_tensor_creation_ops.py
@@ -8,9 +8,10 @@ import unittest
 from itertools import product, combinations, combinations_with_replacement, permutations
 import random
 
+from torch.testing import make_tensor
 from torch.testing._internal.common_utils import (
     TestCase, run_tests, do_test_empty_full, TEST_WITH_ROCM, suppress_warnings,
-    torch_to_numpy_dtype_dict, slowTest, make_tensor, TEST_SCIPY, IS_MACOS, IS_PPC,
+    torch_to_numpy_dtype_dict, slowTest, TEST_SCIPY, IS_MACOS, IS_PPC,
     IS_WINDOWS)
 from torch.testing._internal.common_device_type import (
     instantiate_device_type_tests, deviceCountAtLeast, onlyOnCPUAndCUDA,
diff --git a/test/test_testing.py b/test/test_testing.py
index 7e67569..f38183d 100644
--- a/test/test_testing.py
+++ b/test/test_testing.py
@@ -10,8 +10,9 @@ from typing import Any, Callable, Iterator, List, Tuple
 
 import torch
 
+from torch.testing import make_tensor
 from torch.testing._internal.common_utils import \
-    (IS_FBCODE, IS_SANDCASTLE, IS_WINDOWS, TestCase, make_tensor, run_tests, skipIfRocm, slowTest)
+    (IS_FBCODE, IS_SANDCASTLE, IS_WINDOWS, TestCase, run_tests, skipIfRocm, slowTest)
 from torch.testing._internal.common_device_type import \
     (PYTORCH_TESTING_DEVICE_EXCEPT_FOR_KEY, PYTORCH_TESTING_DEVICE_ONLY_FOR_KEY, dtypes,
      get_device_type_test_bases, instantiate_device_type_tests, onlyCUDA, onlyOnCPUAndCUDA,
diff --git a/test/test_torch.py b/test/test_torch.py
index 15e36c8..c50b7ca 100644
--- a/test/test_torch.py
+++ b/test/test_torch.py
@@ -27,13 +27,14 @@ from torch._six import inf, nan, string_classes
 from itertools import product, combinations, permutations
 from functools import partial
 from torch import multiprocessing as mp
+from torch.testing import make_tensor
 from torch.testing._internal.common_utils import (
     TestCase, TEST_WITH_ROCM, run_tests, IS_WINDOWS,
     IS_FILESYSTEM_UTF8_ENCODING,
     NO_MULTIPROCESSING_SPAWN, do_test_dtypes, IS_SANDCASTLE, IS_FBCODE, IS_REMOTE_GPU,
     load_tests, slowTest, skipCUDAMemoryLeakCheckIf, BytesIOContext, noarchTest,
     skipIfRocm, skipIfNoSciPy, TemporaryFileName, TemporaryDirectoryName,
-    wrapDeterministicFlagAPITest, DeterministicGuard, CudaSyncGuard, make_tensor)
+    wrapDeterministicFlagAPITest, DeterministicGuard, CudaSyncGuard)
 from multiprocessing.reduction import ForkingPickler
 from torch.testing._internal.common_device_type import (
     instantiate_device_type_tests,
diff --git a/test/test_unary_ufuncs.py b/test/test_unary_ufuncs.py
index e5b8c4a..22f6151 100644
--- a/test/test_unary_ufuncs.py
+++ b/test/test_unary_ufuncs.py
@@ -11,7 +11,7 @@ import unittest
 from torch._six import inf, nan
 from torch.testing._internal.common_utils import (
     TestCase, run_tests, torch_to_numpy_dtype_dict, numpy_to_torch_dtype_dict,
-    suppress_warnings, make_tensor, TEST_SCIPY, slowTest, skipIfNoSciPy, IS_WINDOWS)
+    suppress_warnings, TEST_SCIPY, slowTest, skipIfNoSciPy, IS_WINDOWS)
 from torch.testing._internal.common_methods_invocations import (
     unary_ufuncs, _NOTHING)
 from torch.testing._internal.common_device_type import (
@@ -19,7 +19,7 @@ from torch.testing._internal.common_device_type import (
     onlyCUDA, dtypesIfCUDA, precisionOverride, skipCUDAIfRocm, dtypesIfCPU,
     OpDTypes)
 from torch.testing import (
-    floating_types_and, all_types_and_complex_and, floating_and_complex_types_and)
+    floating_types_and, all_types_and_complex_and, floating_and_complex_types_and, make_tensor)
 
 if TEST_SCIPY:
     import scipy
diff --git a/test/test_view_ops.py b/test/test_view_ops.py
index 306c6cb..7bb6906 100644
--- a/test/test_view_ops.py
+++ b/test/test_view_ops.py
@@ -6,8 +6,9 @@ from itertools import product, permutations, combinations
 from functools import partial
 import random
 
+from torch.testing import make_tensor
 from torch.testing._internal.common_utils import \
-    (TestCase, run_tests, suppress_warnings, make_tensor)
+    (TestCase, run_tests, suppress_warnings)
 from torch.testing._internal.common_device_type import \
     (instantiate_device_type_tests, onlyCPU, dtypes, onlyOnCPUAndCUDA)
 
diff --git a/torch/testing/__init__.py b/torch/testing/__init__.py
index 526d02c..7ea18a4 100644
--- a/torch/testing/__init__.py
+++ b/torch/testing/__init__.py
@@ -1,4 +1,5 @@
 from ._core import *  # noqa: F403
 from ._asserts import *  # noqa: F403
+from ._creation import *  # noqa: F403
 from ._check_kernel_launches import *  # noqa: F403
 from ._deprecated import *  # noqa: F403
diff --git a/torch/testing/_creation.py b/torch/testing/_creation.py
new file mode 100644
index 0000000..4eb10d1
--- /dev/null
+++ b/torch/testing/_creation.py
@@ -0,0 +1,155 @@
+"""
+This module contains tensor creation utilities.
+""" + +import torch +from typing import Optional, List, Tuple, Union, cast +import math + +__all__ = [ + "make_tensor", +] + +def make_tensor( + shape: Union[torch.Size, List[int], Tuple[int, ...]], + device: Union[str, torch.device], + dtype: torch.dtype, + *, + low: Optional[float] = None, + high: Optional[float] = None, + requires_grad: bool = False, + noncontiguous: bool = False, + exclude_zero: bool = False +) -> torch.Tensor: + r"""Creates a tensor with the given :attr:`shape`, :attr:`device`, and :attr:`dtype`, and filled with + values uniformly drawn from ``[low, high)``. + + If :attr:`low` or :attr:`high` are specified and are outside the range of the :attr:`dtype`'s representable + finite values then they are clamped to the lowest or highest representable finite value, respectively. + If ``None``, then the following table describes the default values for :attr:`low` and :attr:`high`, + which depend on :attr:`dtype`. + + +---------------------------+------------+----------+ + | ``dtype`` | ``low`` | ``high`` | + +===========================+============+==========+ + | boolean type | ``0`` | ``2`` | + +---------------------------+------------+----------+ + | unsigned integral type | ``0`` | ``10`` | + +---------------------------+------------+----------+ + | signed integral types | ``-9`` | ``10`` | + +---------------------------+------------+----------+ + | floating types | ``-9`` | ``9`` | + +---------------------------+------------+----------+ + | complex types | ``-9`` | ``9`` | + +---------------------------+------------+----------+ + + Args: + shape (Tuple[int, ...]): A sequence of integers defining the shape of the output tensor. + device (Union[str, torch.device]): The device of the returned tensor. + dtype (:class:`torch.dtype`): The data type of the returned tensor. + low (Optional[Number]): Sets the lower limit (inclusive) of the given range. If a number is provided it is + clamped to the least representable finite value of the given dtype. When ``None`` (default), + this value is determined based on the :attr:`dtype` (see the table above). Default: ``None``. + high (Optional[Number]): Sets the upper limit (exclusive) of the given range. If a number is provided it is + clamped to the greatest representable finite value of the given dtype. When ``None`` (default) this value + is determined based on the :attr:`dtype` (see the table above). Default: ``None``. + requires_grad (Optional[bool]): If autograd should record operations on the returned tensor. Default: ``False``. + noncontiguous (Optional[bool]): If `True`, the returned tensor will be noncontiguous. This argument is + ignored if the constructed tensor has fewer than two elements. + exclude_zero (Optional[bool]): If ``True`` then zeros are replaced with the dtype's small positive value + depending on the :attr:`dtype`. For bool and integer types zero is replaced with one. For floating + point types it is replaced with the dtype's smallest positive normal number (the "tiny" value of the + :attr:`dtype`'s :func:`~torch.finfo` object), and for complex types it is replaced with a complex number + whose real and imaginary parts are both the smallest positive normal number representable by the complex + type. Default ``False``. + + Raises: + ValueError: If ``low > high``. + ValueError: If either :attr:`low` or :attr:`high` is ``nan``. + TypeError: If :attr:`dtype` isn't supported by this function. 
+
+    Examples:
+        >>> from torch.testing import make_tensor
+        >>> # Creates a float tensor with values in [-1, 1)
+        >>> make_tensor((3,), device='cpu', dtype=torch.float32, low=-1, high=1)
+        tensor([ 0.1205, 0.2282, -0.6380])
+        >>> # Creates a bool tensor on CUDA
+        >>> make_tensor((2, 2), device='cuda', dtype=torch.bool)
+        tensor([[False, False],
+                [False, True]], device='cuda:0')
+    """
+    def _modify_low_high(low, high, lowest, highest, default_low, default_high, dtype):
+        """
+        Modifies (and raises ValueError when appropriate) the low and high values given by the user if required.
+        """
+        def clamp(a, l, h):
+            return min(max(a, l), h)
+
+        low = low if low is not None else default_low
+        high = high if high is not None else default_high
+
+        # Checks for error cases
+        if low != low or high != high:
+            raise ValueError("make_tensor: one of low or high was NaN!")
+        if low > high:
+            raise ValueError("make_tensor: low must be weakly less than high!")
+
+        low = clamp(low, lowest, highest)
+        high = clamp(high, lowest, highest)
+
+        if dtype in [torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64]:
+            return math.floor(low), math.ceil(high)
+
+        return low, high
+
+    _integral_types = [torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64]
+    _floating_types = [torch.float16, torch.bfloat16, torch.float32, torch.float64]
+    _complex_types = [torch.cfloat, torch.cdouble]
+
+    if dtype is torch.bool:
+        result = torch.randint(0, 2, shape, device=device, dtype=dtype)
+    elif dtype is torch.uint8:
+        ranges = (torch.iinfo(dtype).min, torch.iinfo(dtype).max)
+        low, high = cast(Tuple[int, int], _modify_low_high(low, high, ranges[0], ranges[1], 0, 10, dtype))
+        result = torch.randint(low, high, shape, device=device, dtype=dtype)
+    elif dtype in _integral_types:
+        ranges = (torch.iinfo(dtype).min, torch.iinfo(dtype).max)
+        low, high = _modify_low_high(low, high, ranges[0], ranges[1], -9, 10, dtype)
+        result = torch.randint(low, high, shape, device=device, dtype=dtype)  # type: ignore[call-overload]
+    elif dtype in _floating_types:
+        ranges_floats = (torch.finfo(dtype).min, torch.finfo(dtype).max)
+        low, high = _modify_low_high(low, high, ranges_floats[0], ranges_floats[1], -9, 9, dtype)
+        rand_val = torch.rand(shape, device=device, dtype=dtype)
+        result = high * rand_val + low * (1 - rand_val)
+    elif dtype in _complex_types:
+        float_dtype = torch.float if dtype is torch.cfloat else torch.double
+        ranges_floats = (torch.finfo(float_dtype).min, torch.finfo(float_dtype).max)
+        low, high = _modify_low_high(low, high, ranges_floats[0], ranges_floats[1], -9, 9, dtype)
+        real_rand_val = torch.rand(shape, device=device, dtype=float_dtype)
+        imag_rand_val = torch.rand(shape, device=device, dtype=float_dtype)
+        real = high * real_rand_val + low * (1 - real_rand_val)
+        imag = high * imag_rand_val + low * (1 - imag_rand_val)
+        result = torch.complex(real, imag)
+    else:
+        raise TypeError(f"The requested dtype '{dtype}' is not supported by torch.testing.make_tensor()."
+ " To request support, file an issue at: https://github.com/pytorch/pytorch/issues") + + if noncontiguous and result.numel() > 1: + result = torch.repeat_interleave(result, 2, dim=-1) + result = result[..., ::2] + + if exclude_zero: + if dtype in _integral_types or dtype is torch.bool: + replace_with = torch.tensor(1, device=device, dtype=dtype) + elif dtype in _floating_types: + replace_with = torch.tensor(torch.finfo(dtype).tiny, device=device, dtype=dtype) + else: # dtype in _complex_types: + float_dtype = torch.float if dtype is torch.cfloat else torch.double + float_eps = torch.tensor(torch.finfo(float_dtype).tiny, device=device, dtype=float_dtype) + replace_with = torch.complex(float_eps, float_eps) + result[result == 0] = replace_with + + if dtype in _floating_types + _complex_types: + result.requires_grad = requires_grad + + return result diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py index 2230808..a3d61b4 100644 --- a/torch/testing/_internal/common_methods_invocations.py +++ b/torch/testing/_internal/common_methods_invocations.py @@ -19,7 +19,7 @@ from torch.testing import \ (make_non_contiguous, floating_types, floating_types_and, complex_types, floating_and_complex_types, floating_and_complex_types_and, all_types_and_complex_and, all_types_and, all_types_and_complex, - integral_types_and, all_types, double_types) + integral_types_and, all_types, double_types, make_tensor) from .._core import _dispatch_dtypes from torch.testing._internal.common_device_type import \ (onlyOnCPUAndCUDA, skipCUDAIfNoMagma, skipCUDAIfNoMagmaAndNoCusolver, skipCUDAIfNoCusolver, @@ -32,7 +32,7 @@ from torch.testing._internal.common_utils import \ random_symmetric_pd_matrix, make_symmetric_matrices, make_symmetric_pd_matrices, random_square_matrix_of_rank, random_fullrank_matrix_distinct_singular_value, - TEST_WITH_ROCM, IS_WINDOWS, IS_MACOS, make_tensor, TEST_SCIPY, + TEST_WITH_ROCM, IS_WINDOWS, IS_MACOS, TEST_SCIPY, torch_to_numpy_dtype_dict, TEST_WITH_ASAN, GRADCHECK_NONDET_TOL,) import torch.testing._internal.opinfo_helper as opinfo_helper diff --git a/torch/testing/_internal/common_modules.py b/torch/testing/_internal/common_modules.py index 99525a7..6ef4de3 100644 --- a/torch/testing/_internal/common_modules.py +++ b/torch/testing/_internal/common_modules.py @@ -2,11 +2,11 @@ import torch from copy import deepcopy from functools import wraps, partial from itertools import chain -from torch.testing import floating_types +from torch.testing import floating_types, make_tensor from torch.testing._internal.common_device_type import ( _TestParametrizer, _dtype_test_suffix, _update_param_kwargs, skipIf) from torch.testing._internal.common_nn import nllloss_reference, get_reduction -from torch.testing._internal.common_utils import make_tensor, freeze_rng_state +from torch.testing._internal.common_utils import freeze_rng_state from types import ModuleType from typing import List, Tuple, Type, Set, Dict @@ -225,7 +225,7 @@ def generate_regression_criterion_inputs(make_input): return [ ModuleInput( constructor_input=FunctionInput(reduction=reduction), - forward_input=FunctionInput(make_input(size=(4, )), make_input(size=4,)), + forward_input=FunctionInput(make_input(shape=(4, )), make_input(shape=4,)), reference_fn=no_batch_dim_reference_criterion_fn, desc='no_batch_dim_{}'.format(reduction) ) for reduction in ['none', 'mean', 'sum']] @@ -236,7 +236,7 @@ def module_inputs_torch_nn_AvgPool1d(module_info, device, dtype, requires_grad, 
diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py
index 2230808..a3d61b4 100644
--- a/torch/testing/_internal/common_methods_invocations.py
+++ b/torch/testing/_internal/common_methods_invocations.py
@@ -19,7 +19,7 @@ from torch.testing import \
     (make_non_contiguous, floating_types, floating_types_and, complex_types,
      floating_and_complex_types, floating_and_complex_types_and,
      all_types_and_complex_and, all_types_and, all_types_and_complex,
-     integral_types_and, all_types, double_types)
+     integral_types_and, all_types, double_types, make_tensor)
 from .._core import _dispatch_dtypes
 from torch.testing._internal.common_device_type import \
     (onlyOnCPUAndCUDA, skipCUDAIfNoMagma, skipCUDAIfNoMagmaAndNoCusolver, skipCUDAIfNoCusolver,
@@ -32,7 +32,7 @@ from torch.testing._internal.common_utils import \
     random_symmetric_pd_matrix, make_symmetric_matrices,
     make_symmetric_pd_matrices,
     random_square_matrix_of_rank,
     random_fullrank_matrix_distinct_singular_value,
-    TEST_WITH_ROCM, IS_WINDOWS, IS_MACOS, make_tensor, TEST_SCIPY,
+    TEST_WITH_ROCM, IS_WINDOWS, IS_MACOS, TEST_SCIPY,
     torch_to_numpy_dtype_dict, TEST_WITH_ASAN, GRADCHECK_NONDET_TOL,)
 import torch.testing._internal.opinfo_helper as opinfo_helper
diff --git a/torch/testing/_internal/common_modules.py b/torch/testing/_internal/common_modules.py
index 99525a7..6ef4de3 100644
--- a/torch/testing/_internal/common_modules.py
+++ b/torch/testing/_internal/common_modules.py
@@ -2,11 +2,11 @@ import torch
 from copy import deepcopy
 from functools import wraps, partial
 from itertools import chain
-from torch.testing import floating_types
+from torch.testing import floating_types, make_tensor
 from torch.testing._internal.common_device_type import (
     _TestParametrizer, _dtype_test_suffix, _update_param_kwargs, skipIf)
 from torch.testing._internal.common_nn import nllloss_reference, get_reduction
-from torch.testing._internal.common_utils import make_tensor, freeze_rng_state
+from torch.testing._internal.common_utils import freeze_rng_state
 from types import ModuleType
 from typing import List, Tuple, Type, Set, Dict
 
@@ -225,7 +225,7 @@ def generate_regression_criterion_inputs(make_input):
     return [
         ModuleInput(
             constructor_input=FunctionInput(reduction=reduction),
-            forward_input=FunctionInput(make_input(size=(4, )), make_input(size=4,)),
+            forward_input=FunctionInput(make_input(shape=(4, )), make_input(shape=4,)),
             reference_fn=no_batch_dim_reference_criterion_fn,
             desc='no_batch_dim_{}'.format(reduction)
         ) for reduction in ['none', 'mean', 'sum']]
@@ -236,7 +236,7 @@ def module_inputs_torch_nn_AvgPool1d(module_info, device, dtype, requires_grad,
 
     return [
         ModuleInput(constructor_input=FunctionInput(kernel_size=2),
-                    forward_input=FunctionInput(make_input(size=(3, 6))),
+                    forward_input=FunctionInput(make_input(shape=(3, 6))),
                     desc='no_batch_dim',
                     reference_fn=no_batch_dim_reference_fn)]
@@ -246,13 +246,13 @@ def module_inputs_torch_nn_ELU(module_info, device, dtype, requires_grad, **kwar
 
     return [
         ModuleInput(constructor_input=FunctionInput(alpha=2.),
-                    forward_input=FunctionInput(make_input(size=(3, 2, 5))),
+                    forward_input=FunctionInput(make_input(shape=(3, 2, 5))),
                     reference_fn=lambda m, p, i: torch.where(i >= 0, i, 2 * (i.exp() - 1))),
         ModuleInput(constructor_input=FunctionInput(alpha=2.),
-                    forward_input=FunctionInput(make_input(size=())),
+                    forward_input=FunctionInput(make_input(shape=())),
                     desc='scalar'),
         ModuleInput(constructor_input=FunctionInput(),
-                    forward_input=FunctionInput(make_input(size=(3,))),
+                    forward_input=FunctionInput(make_input(shape=(3,))),
                     desc='no_batch_dim',
                     reference_fn=no_batch_dim_reference_fn)]
@@ -262,14 +262,14 @@ def module_inputs_torch_nn_CELU(module_info, device, dtype, requires_grad, **kwa
 
     return [
         ModuleInput(constructor_input=FunctionInput(alpha=2.),
-                    forward_input=FunctionInput(make_input(size=(3, 2, 5))),
+                    forward_input=FunctionInput(make_input(shape=(3, 2, 5))),
                     reference_fn=lambda m, p, i: torch.where(i >= 0, i, 2. * ((.5 * i).exp() - 1))),
         ModuleInput(constructor_input=FunctionInput(alpha=2.),
-                    forward_input=FunctionInput(make_input(size=())),
+                    forward_input=FunctionInput(make_input(shape=())),
                     reference_fn=lambda m, p, i: torch.where(i >= 0, i, 2 * (i.exp() - 1)),
                     desc='scalar'),
         ModuleInput(constructor_input=FunctionInput(alpha=2.),
-                    forward_input=FunctionInput(make_input(size=(3,))),
+                    forward_input=FunctionInput(make_input(shape=(3,))),
                     desc='no_batch_dim',
                     reference_fn=no_batch_dim_reference_fn)]
@@ -279,12 +279,12 @@ def module_inputs_torch_nn_L1Loss(module_info, device, dtype, requires_grad, **k
 
     return [
         ModuleInput(constructor_input=FunctionInput(),
-                    forward_input=FunctionInput(make_input(size=(2, 3, 4)),
-                                                make_input(size=(2, 3, 4))),
+                    forward_input=FunctionInput(make_input(shape=(2, 3, 4)),
+                                                make_input(shape=(2, 3, 4))),
                     reference_fn=lambda m, p, i, t: 1. / i.numel() * sum((a - b).abs().sum()
                                                                          for a, b in zip(i, t))),
         ModuleInput(constructor_input=FunctionInput(),
-                    forward_input=FunctionInput(make_input(size=()), make_input(size=())),
+                    forward_input=FunctionInput(make_input(shape=()), make_input(shape=())),
                     reference_fn=lambda m, p, i, t: 1. / i.numel() * (i - t).abs().sum(),
                     desc='scalar')] + generate_regression_criterion_inputs(make_input)
 
diff --git a/torch/testing/_internal/common_utils.py b/torch/testing/_internal/common_utils.py
index b8e5b09..90f3551 100644
--- a/torch/testing/_internal/common_utils.py
+++ b/torch/testing/_internal/common_utils.py
@@ -43,13 +43,13 @@ from unittest.mock import MagicMock
 
 import numpy as np
 
-from torch.testing import floating_types_and, integral_types, complex_types, get_all_dtypes
 import expecttest
 
 from .._core import \
     (_compare_tensors_internal, _compare_scalars_internal, _compare_return_type)
 import torch
 import torch.cuda
+from torch.testing import make_tensor
 from torch._utils_internal import get_writable_path
 from torch._six import string_classes
 from torch import Tensor
@@ -1939,103 +1939,7 @@ def retry(ExceptionToCheck, tries=3, delay=3, skip_after_retries=False):
     return deco_retry
 
 
-# Methods for matrix and tensor generation
-
-def make_tensor(size, device: torch.device, dtype: torch.dtype, *, low=None, high=None,
-                requires_grad: bool = False, noncontiguous: bool = False,
-                exclude_zero: bool = False) -> torch.Tensor:
-    """ Creates a random tensor with the given size, device and dtype.
-
-        Default values for low and high:
-            * boolean type: low = 0, high = 2
-            * uint8 type: low = 0, high = 9
-            * floating and integral types: low = -9 and high = 9
-            * complex types, for each real and imaginary part: low = -9, high = 9
-        If low/high are specified and within dtype limits: low = low, high = high
-        If low/high are specified but exceed the limits: low = dtype_min, high = dtype_max
-        If low is -inf and/or high is inf: low = dtype_min, high = dtype_max
-        If low is inf or nan and/or high is -inf or nan: ValueError raised
-
-        If noncontiguous=True, a noncontiguous tensor with the given size will be returned unless the size
-        specifies a tensor with a 1 or 0 elements in which case the noncontiguous parameter is ignored because
-        it is not possible to create a noncontiguous Tensor with a single element.
-
-        If exclude_zero is passed with True (default is False), all the matching values (with zero) in
-        created tensor are replaced with a tiny (smallest positive representable number) value if floating type,
-        [`tiny` + `tiny`.j] if complex type and 1 if integer/boolean type.
-    """
-    def _modify_low_high(low, high, lowest, highest, default_low, default_high, dtype):
-        """
-        Modifies (and raises ValueError when appropriate) low and high values given by the user (input_low, input_high) if required.
- """ - def clamp(a, l, h): - return min(max(a, l), h) - - low = low if low is not None else default_low - high = high if high is not None else default_high - - # Checks for error cases - if low != low or high != high: - raise ValueError("make_tensor: one of low or high was NaN!") - if low > high: - raise ValueError("make_tensor: low must be weakly less than high!") - - low = clamp(low, lowest, highest) - high = clamp(high, lowest, highest) - - if dtype in integral_types(): - return math.floor(low), math.ceil(high) - - return low, high - - if dtype is torch.bool: - result = torch.randint(0, 2, size, device=device, dtype=dtype) - elif dtype is torch.uint8: - ranges = (torch.iinfo(dtype).min, torch.iinfo(dtype).max) - low, high = _modify_low_high(low, high, ranges[0], ranges[1], 0, 9, dtype) - result = torch.randint(low, high, size, device=device, dtype=dtype) - elif dtype in integral_types(): - ranges = (torch.iinfo(dtype).min, torch.iinfo(dtype).max) - low, high = _modify_low_high(low, high, ranges[0], ranges[1], -9, 9, dtype) - result = torch.randint(low, high, size, device=device, dtype=dtype) - elif dtype in floating_types_and(torch.half, torch.bfloat16): - ranges_floats = (torch.finfo(dtype).min, torch.finfo(dtype).max) - low, high = _modify_low_high(low, high, ranges_floats[0], ranges_floats[1], -9, 9, dtype) - rand_val = torch.rand(size, device=device, dtype=dtype) - result = high * rand_val + low * (1 - rand_val) - else: - assert dtype in complex_types() - float_dtype = torch.float if dtype is torch.cfloat else torch.double - ranges_floats = (torch.finfo(float_dtype).min, torch.finfo(float_dtype).max) - low, high = _modify_low_high(low, high, ranges_floats[0], ranges_floats[1], -9, 9, dtype) - real_rand_val = torch.rand(size, device=device, dtype=float_dtype) - imag_rand_val = torch.rand(size, device=device, dtype=float_dtype) - real = high * real_rand_val + low * (1 - real_rand_val) - imag = high * imag_rand_val + low * (1 - imag_rand_val) - result = torch.complex(real, imag) - - if noncontiguous and result.numel() > 1: - result = torch.repeat_interleave(result, 2, dim=-1) - result = result[..., ::2] - - if exclude_zero: - if dtype in integral_types() or dtype is torch.bool: - replace_with = torch.tensor(1, device=device, dtype=dtype) - elif dtype in floating_types_and(torch.half, torch.bfloat16): - replace_with = torch.tensor(torch.finfo(dtype).tiny, device=device, dtype=dtype) - elif dtype in complex_types(): - float_dtype = torch.float if dtype is torch.cfloat else torch.double - float_eps = torch.tensor(torch.finfo(float_dtype).tiny, device=device, dtype=float_dtype) - replace_with = torch.complex(float_eps, float_eps) - else: - raise ValueError(f"Invalid dtype passed, supported dtypes are: {get_all_dtypes()}") - result[result == 0] = replace_with - - if dtype in floating_types_and(torch.half, torch.bfloat16) or\ - dtype in complex_types(): - result.requires_grad = requires_grad - - return result +# Methods for matrix generation def random_square_matrix_of_rank(l, rank, dtype=torch.double, device='cpu'): assert rank <= l -- 2.7.4