[Doc] `make_tensor` to `torch.testing` module (#63925)
authorKushashwa Ravi Shrimali <kushashwaravishrimali@gmail.com>
Mon, 30 Aug 2021 19:16:23 +0000 (12:16 -0700)
committerFacebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Mon, 30 Aug 2021 19:25:40 +0000 (12:25 -0700)
Summary:
This PR aims to add `make_tensor` to the `torch.testing` module in PyTorch docs.

TODOs:

* [x] Add examples

cc: pmeier mruberry brianjo

Pull Request resolved: https://github.com/pytorch/pytorch/pull/63925

Reviewed By: ngimel

Differential Revision: D30633487

Pulled By: mruberry

fbshipit-source-id: 8e5a1f880c6ece5925b4039fee8122bd739538af

24 files changed:
docs/source/testing.rst
test/test_autograd.py
test/test_binary_ufuncs.py
test/test_buffer_protocol.py
test/test_foreach.py
test/test_indexing.py
test/test_jit.py
test/test_linalg.py
test/test_ops.py
test/test_reductions.py
test/test_shape_ops.py
test/test_sort_and_select.py
test/test_sparse.py
test/test_sparse_csr.py
test/test_tensor_creation_ops.py
test/test_testing.py
test/test_torch.py
test/test_unary_ufuncs.py
test/test_view_ops.py
torch/testing/__init__.py
torch/testing/_creation.py [new file with mode: 0644]
torch/testing/_internal/common_methods_invocations.py
torch/testing/_internal/common_modules.py
torch/testing/_internal/common_utils.py

index 981a636..9f1e2c3 100644 (file)
@@ -9,3 +9,4 @@ torch.testing
 .. automodule:: torch.testing
 
 .. autofunction:: assert_close
+.. autofunction:: make_tensor
index 4d41645..364d488 100644 (file)
@@ -24,13 +24,14 @@ from torch.autograd.profiler import (profile, record_function, emit_nvtx)
 from torch.autograd.profiler_util import (_format_time, EventList, FunctionEvent, FunctionEventAvg)
 import torch.autograd.functional as autogradF
 from torch.utils.checkpoint import checkpoint
+from torch.testing import make_tensor
 from torch.testing._internal.common_cuda import TEST_CUDA
 from torch.testing._internal.common_utils import (TestCase, run_tests, skipIfNoLapack,
                                                   suppress_warnings, slowTest,
                                                   load_tests,
                                                   IS_WINDOWS, IS_MACOS, CudaMemoryLeakCheck,
                                                   TEST_WITH_ROCM, disable_gc,
-                                                  gradcheck, gradgradcheck, make_tensor)
+                                                  gradcheck, gradgradcheck)
 from torch.autograd import Variable, Function, detect_anomaly, kineto_available
 from torch.autograd.function import InplaceFunction
 import torch.autograd.forward_ad as fwAD
index 4995e0d..1e9e804 100644 (file)
@@ -13,12 +13,12 @@ from functools import partial
 from torch._six import inf, nan
 from torch.testing._internal.common_utils import (
     TestCase, iter_indices, TEST_WITH_ASAN, run_tests,
-    torch_to_numpy_dtype_dict, make_tensor, TEST_SCIPY, set_default_dtype)
+    torch_to_numpy_dtype_dict, TEST_SCIPY, set_default_dtype)
 from torch.testing._internal.common_device_type import (
     instantiate_device_type_tests, onlyCUDA, onlyCPU, dtypes, dtypesIfCUDA,
     dtypesIfCPU, deviceCountAtLeast, precisionOverride, onlyOnCPUAndCUDA,
     skipCUDAIfRocm, skipIf, ops)
-from torch.testing import all_types_and_complex_and, integral_types_and
+from torch.testing import all_types_and_complex_and, integral_types_and, make_tensor
 from torch.testing._internal.common_methods_invocations import binary_ufuncs
 
 if TEST_SCIPY:
index c797b91..619386e 100644 (file)
@@ -1,4 +1,5 @@
 import torch.testing._internal.common_utils as common
+from torch.testing import make_tensor
 from torch.testing._internal.common_device_type import (
     instantiate_device_type_tests,
     dtypes
@@ -23,7 +24,7 @@ class TestBufferProtocol(common.TestCase):
         if offset is None:
             offset = first * get_dtype_size(dtype)
 
-        numpy_original = common.make_tensor(shape, torch.device("cpu"), dtype).numpy()
+        numpy_original = make_tensor(shape, torch.device("cpu"), dtype).numpy()
         original = memoryview(numpy_original)
         # First call PyTorch's version in case of errors.
         # If this call exits successfully, the NumPy version must also do so.
@@ -125,7 +126,7 @@ class TestBufferProtocol(common.TestCase):
 
     @dtypes(*common.torch_to_numpy_dtype_dict.keys())
     def test_shared_buffer(self, device, dtype):
-        x = common.make_tensor((1,), device, dtype)
+        x = make_tensor((1,), device, dtype)
         # Modify the whole tensor
         arr, tensor = self._run_test(SHAPE, dtype)
         tensor[:] = x
@@ -158,7 +159,7 @@ class TestBufferProtocol(common.TestCase):
 
     @dtypes(*common.torch_to_numpy_dtype_dict.keys())
     def test_non_writable_buffer(self, device, dtype):
-        numpy_arr = common.make_tensor((1,), device, dtype).numpy()
+        numpy_arr = make_tensor((1,), device, dtype).numpy()
         byte_arr = numpy_arr.tobytes()
         with self.assertWarnsOnceRegex(UserWarning,
                                        r"The given buffer is not writable."):
index ce9b0d7..123ef35 100644 (file)
@@ -4,11 +4,13 @@ import random
 import re
 import torch
 import unittest
+
+from torch.testing import make_tensor
 from torch.testing._internal.common_utils import TestCase, run_tests, TEST_WITH_ROCM, TEST_WITH_SLOW
 from torch.testing._internal.common_device_type import \
     (instantiate_device_type_tests, dtypes, onlyCUDA, skipCUDAIfRocm, skipMeta, ops)
 from torch.testing._internal.common_methods_invocations import \
-    (foreach_unary_op_db, foreach_binary_op_db, foreach_pointwise_op_db, foreach_minmax_op_db, make_tensor)
+    (foreach_unary_op_db, foreach_binary_op_db, foreach_pointwise_op_db, foreach_minmax_op_db)
 
 # Includes some values such that N * N won't be a multiple of 4,
 # which should ensure we test the vectorized and non-vectorized
index 6158091..8b8a2ea 100644 (file)
@@ -8,7 +8,8 @@ from functools import reduce
 
 import numpy as np
 
-from torch.testing._internal.common_utils import TestCase, run_tests, make_tensor
+from torch.testing import make_tensor
+from torch.testing._internal.common_utils import TestCase, run_tests
 from torch.testing._internal.common_device_type import (
     instantiate_device_type_tests, onlyCUDA, dtypes, dtypesIfCPU, dtypesIfCUDA,
     onlyOnCPUAndCUDA)
index 2595411..d1a170d 100644 (file)
@@ -69,8 +69,7 @@ from torch._six import PY37
 from torch.autograd import Variable
 from torch.jit.annotations import BroadcastingList2, BroadcastingList3, Any  # noqa: F401
 from torch.nn.utils.rnn import PackedSequence
-from torch.testing import FileCheck
-from torch.testing._internal.common_utils import make_tensor
+from torch.testing import FileCheck, make_tensor
 import torch.autograd.profiler
 import torch.cuda
 import torch.jit
index 8ba3373..f7ce392 100644 (file)
@@ -14,14 +14,14 @@ from functools import reduce
 
 from torch.testing._internal.common_utils import \
     (TestCase, run_tests, TEST_SCIPY, IS_MACOS, IS_WINDOWS, slowTest,
-     TEST_WITH_ASAN, make_tensor, TEST_WITH_ROCM, IS_FBCODE, IS_REMOTE_GPU,
+     TEST_WITH_ASAN, TEST_WITH_ROCM, IS_FBCODE, IS_REMOTE_GPU,
      iter_indices, gradcheck, gradgradcheck)
 from torch.testing._internal.common_device_type import \
     (instantiate_device_type_tests, dtypes,
      onlyCPU, skipCUDAIf, skipCUDAIfNoMagma, skipCPUIfNoLapack, precisionOverride,
      skipCUDAIfNoMagmaAndNoCusolver, skipCUDAIfRocm, onlyOnCPUAndCUDA, dtypesIfCUDA,
      onlyCUDA, skipCUDAVersionIn, skipMeta, skipCUDAIfNoCusolver)
-from torch.testing import floating_and_complex_types, floating_types, all_types
+from torch.testing import floating_and_complex_types, floating_types, all_types, make_tensor
 from torch.testing._internal.common_cuda import SM53OrLater, tf32_on_and_off, CUDA11OrLater, CUDA9
 from torch.distributions.binomial import Binomial
 
index a6baf8d..27aee72 100644 (file)
@@ -5,9 +5,9 @@ import warnings
 import torch
 
 from torch.testing import \
-    (FileCheck, floating_and_complex_types_and, get_all_dtypes)
+    (FileCheck, floating_and_complex_types_and, get_all_dtypes, make_tensor)
 from torch.testing._internal.common_utils import \
-    (TestCase, is_iterable_of_tensors, run_tests, IS_SANDCASTLE, clone_input_helper, make_tensor,
+    (TestCase, is_iterable_of_tensors, run_tests, IS_SANDCASTLE, clone_input_helper,
      gradcheck, gradgradcheck, IS_IN_CI, suppress_warnings)
 from torch.testing._internal.common_methods_invocations import \
     (op_db, _NOTHING, UnaryUfuncInfo, ReductionOpInfo, SpectralFuncInfo)
index eed7f73..ca3042b 100644 (file)
@@ -10,10 +10,10 @@ import warnings
 
 from torch._six import inf, nan
 from torch.testing import (
-    integral_types_and, floating_and_complex_types_and, get_all_dtypes)
+    integral_types_and, floating_and_complex_types_and, get_all_dtypes, make_tensor)
 from torch.testing._internal.common_utils import (
     TestCase, run_tests, skipIfNoSciPy, slowTest, torch_to_numpy_dtype_dict,
-    IS_WINDOWS, make_tensor)
+    IS_WINDOWS)
 from torch.testing._internal.common_device_type import (
     OpDTypes, instantiate_device_type_tests, onlyCPU, dtypes, dtypesIfCUDA, dtypesIfCPU,
     onlyOnCPUAndCUDA, onlyCUDA, largeTensorTest, ops, precisionOverride)
index 916adee..cb4ec3c 100644 (file)
@@ -7,8 +7,9 @@ import random
 import warnings
 
 from torch._six import nan
+from torch.testing import make_tensor
 from torch.testing._internal.common_utils import (
-    TestCase, run_tests, make_tensor, torch_to_numpy_dtype_dict)
+    TestCase, run_tests, torch_to_numpy_dtype_dict)
 from torch.testing._internal.common_device_type import (
     instantiate_device_type_tests, onlyCPU, onlyCUDA, dtypes, onlyOnCPUAndCUDA,
     dtypesIfCPU, dtypesIfCUDA, largeTensorTest)
index 564258a..e562e38 100644 (file)
@@ -5,9 +5,9 @@ import random
 from torch._six import nan
 from itertools import permutations, product
 
-from torch.testing import all_types, all_types_and
+from torch.testing import all_types, all_types_and, make_tensor
 from torch.testing._internal.common_utils import \
-    (TEST_WITH_ROCM, TestCase, run_tests, make_tensor, slowTest)
+    (TEST_WITH_ROCM, TestCase, run_tests, slowTest)
 from torch.testing._internal.common_device_type import \
     (instantiate_device_type_tests, dtypes, onlyOnCPUAndCUDA,
      skipCUDAIfRocm, onlyCUDA, dtypesIfCUDA, dtypesIfCPU, onlyCPU, largeTensorTest)
index abe5e93..333f29f 100644 (file)
@@ -5,8 +5,9 @@ import operator
 import random
 from collections import defaultdict
 import unittest
+from torch.testing import make_tensor
 from torch.testing._internal.common_utils import TestCase, run_tests, skipIfRocm, do_test_dtypes, \
-    do_test_empty_full, load_tests, TEST_NUMPY, TEST_SCIPY, IS_WINDOWS, gradcheck, coalescedonoff, make_tensor, \
+    do_test_empty_full, load_tests, TEST_NUMPY, TEST_SCIPY, IS_WINDOWS, gradcheck, coalescedonoff, \
     DeterministicGuard
 from torch.testing._internal.common_cuda import TEST_CUDA, _get_torch_cuda_version
 from numbers import Number
index b9f4885..fbb2b30 100644 (file)
@@ -3,8 +3,10 @@ import warnings
 import unittest
 import random
 import itertools
+
+from torch.testing import make_tensor
 from torch.testing._internal.common_utils import \
-    (IS_MACOS, IS_WINDOWS, TestCase, run_tests, load_tests, coalescedonoff, make_tensor)
+    (IS_MACOS, IS_WINDOWS, TestCase, run_tests, load_tests, coalescedonoff)
 from torch.testing._internal.common_device_type import \
     (instantiate_device_type_tests, dtypes, onlyCPU, onlyCUDA)
 
index 192e03f..9ef3742 100644 (file)
@@ -8,9 +8,10 @@ import unittest
 from itertools import product, combinations, combinations_with_replacement, permutations
 import random
 
+from torch.testing import make_tensor
 from torch.testing._internal.common_utils import (
     TestCase, run_tests, do_test_empty_full, TEST_WITH_ROCM, suppress_warnings,
-    torch_to_numpy_dtype_dict, slowTest, make_tensor, TEST_SCIPY, IS_MACOS, IS_PPC,
+    torch_to_numpy_dtype_dict, slowTest, TEST_SCIPY, IS_MACOS, IS_PPC,
     IS_WINDOWS)
 from torch.testing._internal.common_device_type import (
     instantiate_device_type_tests, deviceCountAtLeast, onlyOnCPUAndCUDA,
index 7e67569..f38183d 100644 (file)
@@ -10,8 +10,9 @@ from typing import Any, Callable, Iterator, List, Tuple
 
 import torch
 
+from torch.testing import make_tensor
 from torch.testing._internal.common_utils import \
-    (IS_FBCODE, IS_SANDCASTLE, IS_WINDOWS, TestCase, make_tensor, run_tests, skipIfRocm, slowTest)
+    (IS_FBCODE, IS_SANDCASTLE, IS_WINDOWS, TestCase, run_tests, skipIfRocm, slowTest)
 from torch.testing._internal.common_device_type import \
     (PYTORCH_TESTING_DEVICE_EXCEPT_FOR_KEY, PYTORCH_TESTING_DEVICE_ONLY_FOR_KEY, dtypes,
      get_device_type_test_bases, instantiate_device_type_tests, onlyCUDA, onlyOnCPUAndCUDA,
index 15e36c8..c50b7ca 100644 (file)
@@ -27,13 +27,14 @@ from torch._six import inf, nan, string_classes
 from itertools import product, combinations, permutations
 from functools import partial
 from torch import multiprocessing as mp
+from torch.testing import make_tensor
 from torch.testing._internal.common_utils import (
     TestCase, TEST_WITH_ROCM, run_tests,
     IS_WINDOWS, IS_FILESYSTEM_UTF8_ENCODING, NO_MULTIPROCESSING_SPAWN,
     do_test_dtypes, IS_SANDCASTLE, IS_FBCODE, IS_REMOTE_GPU, load_tests, slowTest,
     skipCUDAMemoryLeakCheckIf, BytesIOContext, noarchTest,
     skipIfRocm, skipIfNoSciPy, TemporaryFileName, TemporaryDirectoryName,
-    wrapDeterministicFlagAPITest, DeterministicGuard, CudaSyncGuard, make_tensor)
+    wrapDeterministicFlagAPITest, DeterministicGuard, CudaSyncGuard)
 from multiprocessing.reduction import ForkingPickler
 from torch.testing._internal.common_device_type import (
     instantiate_device_type_tests,
index e5b8c4a..22f6151 100644 (file)
@@ -11,7 +11,7 @@ import unittest
 from torch._six import inf, nan
 from torch.testing._internal.common_utils import (
     TestCase, run_tests, torch_to_numpy_dtype_dict, numpy_to_torch_dtype_dict,
-    suppress_warnings, make_tensor, TEST_SCIPY, slowTest, skipIfNoSciPy, IS_WINDOWS)
+    suppress_warnings, TEST_SCIPY, slowTest, skipIfNoSciPy, IS_WINDOWS)
 from torch.testing._internal.common_methods_invocations import (
     unary_ufuncs, _NOTHING)
 from torch.testing._internal.common_device_type import (
@@ -19,7 +19,7 @@ from torch.testing._internal.common_device_type import (
     onlyCUDA, dtypesIfCUDA, precisionOverride, skipCUDAIfRocm, dtypesIfCPU,
     OpDTypes)
 from torch.testing import (
-    floating_types_and, all_types_and_complex_and, floating_and_complex_types_and)
+    floating_types_and, all_types_and_complex_and, floating_and_complex_types_and, make_tensor)
 
 if TEST_SCIPY:
     import scipy
index 306c6cb..7bb6906 100644 (file)
@@ -6,8 +6,9 @@ from itertools import product, permutations, combinations
 from functools import partial
 import random
 
+from torch.testing import make_tensor
 from torch.testing._internal.common_utils import \
-    (TestCase, run_tests, suppress_warnings, make_tensor)
+    (TestCase, run_tests, suppress_warnings)
 from torch.testing._internal.common_device_type import \
     (instantiate_device_type_tests, onlyCPU, dtypes, onlyOnCPUAndCUDA)
 
index 526d02c..7ea18a4 100644 (file)
@@ -1,4 +1,5 @@
 from ._core import *  # noqa: F403
 from ._asserts import *  # noqa: F403
+from ._creation import *  # noqa: F403
 from ._check_kernel_launches import *  # noqa: F403
 from ._deprecated import *  # noqa: F403
diff --git a/torch/testing/_creation.py b/torch/testing/_creation.py
new file mode 100644 (file)
index 0000000..4eb10d1
--- /dev/null
@@ -0,0 +1,155 @@
+"""
+This module contains tensor creation utilities.
+"""
+
+import torch
+from typing import Optional, List, Tuple, Union, cast
+import math
+
+__all__ = [
+    "make_tensor",
+]
+
+def make_tensor(
+    shape: Union[torch.Size, List[int], Tuple[int, ...]],
+    device: Union[str, torch.device],
+    dtype: torch.dtype,
+    *,
+    low: Optional[float] = None,
+    high: Optional[float] = None,
+    requires_grad: bool = False,
+    noncontiguous: bool = False,
+    exclude_zero: bool = False
+) -> torch.Tensor:
+    r"""Creates a tensor with the given :attr:`shape`, :attr:`device`, and :attr:`dtype`, and filled with
+    values uniformly drawn from ``[low, high)``.
+
+    If :attr:`low` or :attr:`high` are specified and are outside the range of the :attr:`dtype`'s representable
+    finite values then they are clamped to the lowest or highest representable finite value, respectively.
+    If ``None``, then the following table describes the default values for :attr:`low` and :attr:`high`,
+    which depend on :attr:`dtype`.
+
+    +---------------------------+------------+----------+
+    | ``dtype``                 | ``low``    | ``high`` |
+    +===========================+============+==========+
+    | boolean type              | ``0``      | ``2``    |
+    +---------------------------+------------+----------+
+    | unsigned integral type    | ``0``      | ``10``   |
+    +---------------------------+------------+----------+
+    | signed integral types     | ``-9``     | ``10``   |
+    +---------------------------+------------+----------+
+    | floating types            | ``-9``     | ``9``    |
+    +---------------------------+------------+----------+
+    | complex types             | ``-9``     | ``9``    |
+    +---------------------------+------------+----------+
+
+    Args:
+        shape (Tuple[int, ...]): A sequence of integers defining the shape of the output tensor.
+        device (Union[str, torch.device]): The device of the returned tensor.
+        dtype (:class:`torch.dtype`): The data type of the returned tensor.
+        low (Optional[Number]): Sets the lower limit (inclusive) of the given range. If a number is provided it is
+            clamped to the least representable finite value of the given dtype. When ``None`` (default),
+            this value is determined based on the :attr:`dtype` (see the table above). Default: ``None``.
+        high (Optional[Number]): Sets the upper limit (exclusive) of the given range. If a number is provided it is
+            clamped to the greatest representable finite value of the given dtype. When ``None`` (default) this value
+            is determined based on the :attr:`dtype` (see the table above). Default: ``None``.
+        requires_grad (Optional[bool]): If autograd should record operations on the returned tensor. Default: ``False``.
+        noncontiguous (Optional[bool]): If `True`, the returned tensor will be noncontiguous. This argument is
+            ignored if the constructed tensor has fewer than two elements.
+        exclude_zero (Optional[bool]): If ``True`` then zeros are replaced with the dtype's small positive value
+            depending on the :attr:`dtype`. For bool and integer types zero is replaced with one. For floating
+            point types it is replaced with the dtype's smallest positive normal number (the "tiny" value of the
+            :attr:`dtype`'s :func:`~torch.finfo` object), and for complex types it is replaced with a complex number
+            whose real and imaginary parts are both the smallest positive normal number representable by the complex
+            type. Default ``False``.
+
+    Raises:
+        ValueError: If ``low > high``.
+        ValueError: If either :attr:`low` or :attr:`high` is ``nan``.
+        TypeError: If :attr:`dtype` isn't supported by this function.
+
+    Examples:
+        >>> from torch.testing import make_tensor
+        >>> # Creates a float tensor with values in [-1, 1)
+        >>> make_tensor((3,), device='cpu', dtype=torch.float32, low=-1, high=1)
+        tensor([ 0.1205, 0.2282, -0.6380])
+        >>> # Creates a bool tensor on CUDA
+        >>> make_tensor((2, 2), device='cuda', dtype=torch.bool)
+        tensor([[False, False],
+                [False, True]], device='cuda:0')
+    """
+    def _modify_low_high(low, high, lowest, highest, default_low, default_high, dtype):
+        """
+        Modifies (and raises ValueError when appropriate) low and high values given by the user (input_low, input_high) if required.
+        """
+        def clamp(a, l, h):
+            return min(max(a, l), h)
+
+        low = low if low is not None else default_low
+        high = high if high is not None else default_high
+
+        # Checks for error cases
+        if low != low or high != high:
+            raise ValueError("make_tensor: one of low or high was NaN!")
+        if low > high:
+            raise ValueError("make_tensor: low must be weakly less than high!")
+
+        low = clamp(low, lowest, highest)
+        high = clamp(high, lowest, highest)
+
+        if dtype in [torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64]:
+            return math.floor(low), math.ceil(high)
+
+        return low, high
+
+    _integral_types = [torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64]
+    _floating_types = [torch.float16, torch.bfloat16, torch.float32, torch.float64]
+    _complex_types = [torch.cfloat, torch.cdouble]
+
+    if dtype is torch.bool:
+        result = torch.randint(0, 2, shape, device=device, dtype=dtype)
+    elif dtype is torch.uint8:
+        ranges = (torch.iinfo(dtype).min, torch.iinfo(dtype).max)
+        low, high = cast(Tuple[int, int], _modify_low_high(low, high, ranges[0], ranges[1], 0, 10, dtype))
+        result = torch.randint(low, high, shape, device=device, dtype=dtype)
+    elif dtype in _integral_types:
+        ranges = (torch.iinfo(dtype).min, torch.iinfo(dtype).max)
+        low, high = _modify_low_high(low, high, ranges[0], ranges[1], -9, 10, dtype)
+        result = torch.randint(low, high, shape, device=device, dtype=dtype)  # type: ignore[call-overload]
+    elif dtype in _floating_types:
+        ranges_floats = (torch.finfo(dtype).min, torch.finfo(dtype).max)
+        low, high = _modify_low_high(low, high, ranges_floats[0], ranges_floats[1], -9, 9, dtype)
+        rand_val = torch.rand(shape, device=device, dtype=dtype)
+        result = high * rand_val + low * (1 - rand_val)
+    elif dtype in _complex_types:
+        float_dtype = torch.float if dtype is torch.cfloat else torch.double
+        ranges_floats = (torch.finfo(float_dtype).min, torch.finfo(float_dtype).max)
+        low, high = _modify_low_high(low, high, ranges_floats[0], ranges_floats[1], -9, 9, dtype)
+        real_rand_val = torch.rand(shape, device=device, dtype=float_dtype)
+        imag_rand_val = torch.rand(shape, device=device, dtype=float_dtype)
+        real = high * real_rand_val + low * (1 - real_rand_val)
+        imag = high * imag_rand_val + low * (1 - imag_rand_val)
+        result = torch.complex(real, imag)
+    else:
+        raise TypeError(f"The requested dtype '{dtype}' is not supported by torch.testing.make_tensor()."
+                        " To request support, file an issue at: https://github.com/pytorch/pytorch/issues")
+
+    if noncontiguous and result.numel() > 1:
+        result = torch.repeat_interleave(result, 2, dim=-1)
+        result = result[..., ::2]
+
+    if exclude_zero:
+        if dtype in _integral_types or dtype is torch.bool:
+            replace_with = torch.tensor(1, device=device, dtype=dtype)
+        elif dtype in _floating_types:
+            replace_with = torch.tensor(torch.finfo(dtype).tiny, device=device, dtype=dtype)
+        else:  # dtype in _complex_types:
+            float_dtype = torch.float if dtype is torch.cfloat else torch.double
+            float_eps = torch.tensor(torch.finfo(float_dtype).tiny, device=device, dtype=float_dtype)
+            replace_with = torch.complex(float_eps, float_eps)
+        result[result == 0] = replace_with
+
+    if dtype in _floating_types + _complex_types:
+        result.requires_grad = requires_grad
+
+    return result
index 2230808..a3d61b4 100644 (file)
@@ -19,7 +19,7 @@ from torch.testing import \
     (make_non_contiguous, floating_types, floating_types_and, complex_types,
      floating_and_complex_types, floating_and_complex_types_and,
      all_types_and_complex_and, all_types_and, all_types_and_complex,
-     integral_types_and, all_types, double_types)
+     integral_types_and, all_types, double_types, make_tensor)
 from .._core import _dispatch_dtypes
 from torch.testing._internal.common_device_type import \
     (onlyOnCPUAndCUDA, skipCUDAIfNoMagma, skipCUDAIfNoMagmaAndNoCusolver, skipCUDAIfNoCusolver,
@@ -32,7 +32,7 @@ from torch.testing._internal.common_utils import \
      random_symmetric_pd_matrix, make_symmetric_matrices,
      make_symmetric_pd_matrices, random_square_matrix_of_rank,
      random_fullrank_matrix_distinct_singular_value,
-     TEST_WITH_ROCM, IS_WINDOWS, IS_MACOS, make_tensor, TEST_SCIPY,
+     TEST_WITH_ROCM, IS_WINDOWS, IS_MACOS, TEST_SCIPY,
      torch_to_numpy_dtype_dict, TEST_WITH_ASAN,
      GRADCHECK_NONDET_TOL,)
 import torch.testing._internal.opinfo_helper as opinfo_helper
index 99525a7..6ef4de3 100644 (file)
@@ -2,11 +2,11 @@ import torch
 from copy import deepcopy
 from functools import wraps, partial
 from itertools import chain
-from torch.testing import floating_types
+from torch.testing import floating_types, make_tensor
 from torch.testing._internal.common_device_type import (
     _TestParametrizer, _dtype_test_suffix, _update_param_kwargs, skipIf)
 from torch.testing._internal.common_nn import nllloss_reference, get_reduction
-from torch.testing._internal.common_utils import make_tensor, freeze_rng_state
+from torch.testing._internal.common_utils import freeze_rng_state
 from types import ModuleType
 from typing import List, Tuple, Type, Set, Dict
 
@@ -225,7 +225,7 @@ def generate_regression_criterion_inputs(make_input):
     return [
         ModuleInput(
             constructor_input=FunctionInput(reduction=reduction),
-            forward_input=FunctionInput(make_input(size=(4, )), make_input(size=4,)),
+            forward_input=FunctionInput(make_input(shape=(4, )), make_input(shape=4,)),
             reference_fn=no_batch_dim_reference_criterion_fn,
             desc='no_batch_dim_{}'.format(reduction)
         ) for reduction in ['none', 'mean', 'sum']]
@@ -236,7 +236,7 @@ def module_inputs_torch_nn_AvgPool1d(module_info, device, dtype, requires_grad,
 
     return [
         ModuleInput(constructor_input=FunctionInput(kernel_size=2),
-                    forward_input=FunctionInput(make_input(size=(3, 6))),
+                    forward_input=FunctionInput(make_input(shape=(3, 6))),
                     desc='no_batch_dim',
                     reference_fn=no_batch_dim_reference_fn)]
 
@@ -246,13 +246,13 @@ def module_inputs_torch_nn_ELU(module_info, device, dtype, requires_grad, **kwar
 
     return [
         ModuleInput(constructor_input=FunctionInput(alpha=2.),
-                    forward_input=FunctionInput(make_input(size=(3, 2, 5))),
+                    forward_input=FunctionInput(make_input(shape=(3, 2, 5))),
                     reference_fn=lambda m, p, i: torch.where(i >= 0, i, 2 * (i.exp() - 1))),
         ModuleInput(constructor_input=FunctionInput(alpha=2.),
-                    forward_input=FunctionInput(make_input(size=())),
+                    forward_input=FunctionInput(make_input(shape=())),
                     desc='scalar'),
         ModuleInput(constructor_input=FunctionInput(),
-                    forward_input=FunctionInput(make_input(size=(3,))),
+                    forward_input=FunctionInput(make_input(shape=(3,))),
                     desc='no_batch_dim',
                     reference_fn=no_batch_dim_reference_fn)]
 
@@ -262,14 +262,14 @@ def module_inputs_torch_nn_CELU(module_info, device, dtype, requires_grad, **kwa
 
     return [
         ModuleInput(constructor_input=FunctionInput(alpha=2.),
-                    forward_input=FunctionInput(make_input(size=(3, 2, 5))),
+                    forward_input=FunctionInput(make_input(shape=(3, 2, 5))),
                     reference_fn=lambda m, p, i: torch.where(i >= 0, i, 2. * ((.5 * i).exp() - 1))),
         ModuleInput(constructor_input=FunctionInput(alpha=2.),
-                    forward_input=FunctionInput(make_input(size=())),
+                    forward_input=FunctionInput(make_input(shape=())),
                     reference_fn=lambda m, p, i: torch.where(i >= 0, i, 2 * (i.exp() - 1)),
                     desc='scalar'),
         ModuleInput(constructor_input=FunctionInput(alpha=2.),
-                    forward_input=FunctionInput(make_input(size=(3,))),
+                    forward_input=FunctionInput(make_input(shape=(3,))),
                     desc='no_batch_dim',
                     reference_fn=no_batch_dim_reference_fn)]
 
@@ -279,12 +279,12 @@ def module_inputs_torch_nn_L1Loss(module_info, device, dtype, requires_grad, **k
 
     return [
         ModuleInput(constructor_input=FunctionInput(),
-                    forward_input=FunctionInput(make_input(size=(2, 3, 4)),
-                                                make_input(size=(2, 3, 4))),
+                    forward_input=FunctionInput(make_input(shape=(2, 3, 4)),
+                                                make_input(shape=(2, 3, 4))),
                     reference_fn=lambda m, p, i, t: 1. / i.numel() * sum((a - b).abs().sum()
                                                                          for a, b in zip(i, t))),
         ModuleInput(constructor_input=FunctionInput(),
-                    forward_input=FunctionInput(make_input(size=()), make_input(size=())),
+                    forward_input=FunctionInput(make_input(shape=()), make_input(shape=())),
                     reference_fn=lambda m, p, i, t: 1. / i.numel() * (i - t).abs().sum(),
                     desc='scalar')] + generate_regression_criterion_inputs(make_input)
 
index b8e5b09..90f3551 100644 (file)
@@ -43,13 +43,13 @@ from unittest.mock import MagicMock
 
 import numpy as np
 
-from torch.testing import floating_types_and, integral_types, complex_types, get_all_dtypes
 import expecttest
 from .._core import \
     (_compare_tensors_internal, _compare_scalars_internal, _compare_return_type)
 
 import torch
 import torch.cuda
+from torch.testing import make_tensor
 from torch._utils_internal import get_writable_path
 from torch._six import string_classes
 from torch import Tensor
@@ -1939,103 +1939,7 @@ def retry(ExceptionToCheck, tries=3, delay=3, skip_after_retries=False):
     return deco_retry
 
 
-# Methods for matrix and tensor generation
-
-def make_tensor(size, device: torch.device, dtype: torch.dtype, *, low=None, high=None,
-                requires_grad: bool = False, noncontiguous: bool = False,
-                exclude_zero: bool = False) -> torch.Tensor:
-    """ Creates a random tensor with the given size, device and dtype.
-
-        Default values for low and high:
-            * boolean type: low = 0, high = 2
-            * uint8 type: low = 0, high = 9
-            * floating and integral types: low = -9 and high = 9
-            * complex types, for each real and imaginary part: low = -9, high = 9
-        If low/high are specified and within dtype limits: low = low, high = high
-        If low/high are specified but exceed the limits: low = dtype_min, high = dtype_max
-        If low is -inf and/or high is inf: low = dtype_min, high = dtype_max
-        If low is inf or nan and/or high is -inf or nan: ValueError raised
-
-        If noncontiguous=True, a noncontiguous tensor with the given size will be returned unless the size
-        specifies a tensor with a 1 or 0 elements in which case the noncontiguous parameter is ignored because
-        it is not possible to create a noncontiguous Tensor with a single element.
-
-        If exclude_zero is passed with True (default is False), all the matching values (with zero) in
-        created tensor are replaced with a tiny (smallest positive representable number) value if floating type,
-        [`tiny` + `tiny`.j] if complex type and 1 if integer/boolean type.
-    """
-    def _modify_low_high(low, high, lowest, highest, default_low, default_high, dtype):
-        """
-        Modifies (and raises ValueError when appropriate) low and high values given by the user (input_low, input_high) if required.
-        """
-        def clamp(a, l, h):
-            return min(max(a, l), h)
-
-        low = low if low is not None else default_low
-        high = high if high is not None else default_high
-
-        # Checks for error cases
-        if low != low or high != high:
-            raise ValueError("make_tensor: one of low or high was NaN!")
-        if low > high:
-            raise ValueError("make_tensor: low must be weakly less than high!")
-
-        low = clamp(low, lowest, highest)
-        high = clamp(high, lowest, highest)
-
-        if dtype in integral_types():
-            return math.floor(low), math.ceil(high)
-
-        return low, high
-
-    if dtype is torch.bool:
-        result = torch.randint(0, 2, size, device=device, dtype=dtype)
-    elif dtype is torch.uint8:
-        ranges = (torch.iinfo(dtype).min, torch.iinfo(dtype).max)
-        low, high = _modify_low_high(low, high, ranges[0], ranges[1], 0, 9, dtype)
-        result = torch.randint(low, high, size, device=device, dtype=dtype)
-    elif dtype in integral_types():
-        ranges = (torch.iinfo(dtype).min, torch.iinfo(dtype).max)
-        low, high = _modify_low_high(low, high, ranges[0], ranges[1], -9, 9, dtype)
-        result = torch.randint(low, high, size, device=device, dtype=dtype)
-    elif dtype in floating_types_and(torch.half, torch.bfloat16):
-        ranges_floats = (torch.finfo(dtype).min, torch.finfo(dtype).max)
-        low, high = _modify_low_high(low, high, ranges_floats[0], ranges_floats[1], -9, 9, dtype)
-        rand_val = torch.rand(size, device=device, dtype=dtype)
-        result = high * rand_val + low * (1 - rand_val)
-    else:
-        assert dtype in complex_types()
-        float_dtype = torch.float if dtype is torch.cfloat else torch.double
-        ranges_floats = (torch.finfo(float_dtype).min, torch.finfo(float_dtype).max)
-        low, high = _modify_low_high(low, high, ranges_floats[0], ranges_floats[1], -9, 9, dtype)
-        real_rand_val = torch.rand(size, device=device, dtype=float_dtype)
-        imag_rand_val = torch.rand(size, device=device, dtype=float_dtype)
-        real = high * real_rand_val + low * (1 - real_rand_val)
-        imag = high * imag_rand_val + low * (1 - imag_rand_val)
-        result = torch.complex(real, imag)
-
-    if noncontiguous and result.numel() > 1:
-        result = torch.repeat_interleave(result, 2, dim=-1)
-        result = result[..., ::2]
-
-    if exclude_zero:
-        if dtype in integral_types() or dtype is torch.bool:
-            replace_with = torch.tensor(1, device=device, dtype=dtype)
-        elif dtype in floating_types_and(torch.half, torch.bfloat16):
-            replace_with = torch.tensor(torch.finfo(dtype).tiny, device=device, dtype=dtype)
-        elif dtype in complex_types():
-            float_dtype = torch.float if dtype is torch.cfloat else torch.double
-            float_eps = torch.tensor(torch.finfo(float_dtype).tiny, device=device, dtype=float_dtype)
-            replace_with = torch.complex(float_eps, float_eps)
-        else:
-            raise ValueError(f"Invalid dtype passed, supported dtypes are: {get_all_dtypes()}")
-        result[result == 0] = replace_with
-
-    if dtype in floating_types_and(torch.half, torch.bfloat16) or\
-       dtype in complex_types():
-        result.requires_grad = requires_grad
-
-    return result
+# Methods for matrix generation
 
 def random_square_matrix_of_rank(l, rank, dtype=torch.double, device='cpu'):
     assert rank <= l