From: Vitaly Fedyunin
Date: Tue, 16 Apr 2019 17:50:48 +0000 (-0700)
Subject: Adding pin_memory kwarg to zeros, ones, empty, ... tensor constructors (#18952)
X-Git-Tag: accepted/tizen/6.5/unified/20211028.231830~214
X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=1c5073fb4bb7fdb73b4137e4fcf9bc7345ab4b9a;p=platform%2Fupstream%2Fpytorch.git

Adding pin_memory kwarg to zeros, ones, empty, ... tensor constructors (#18952)

Summary:
Make it possible to construct a pinned-memory tensor without creating a storage first and without calling the pin_memory() function. It is also faster, as the copy operation is unnecessary.

Supported functions:
```python
torch.rand_like(t, pin_memory=True)
torch.randn_like(t, pin_memory=True)
torch.empty_like(t, pin_memory=True)
torch.full_like(t, 4, pin_memory=True)
torch.zeros_like(t, pin_memory=True)
torch.ones_like(t, pin_memory=True)
torch.tensor([10, 11], pin_memory=True)
torch.randn(3, 5, pin_memory=True)
torch.rand(3, pin_memory=True)
torch.zeros(3, pin_memory=True)
torch.randperm(3, pin_memory=True)
torch.empty(6, pin_memory=True)
torch.ones(6, pin_memory=True)
torch.eye(6, pin_memory=True)
torch.arange(3, 5, pin_memory=True)
```

Part of the bigger `Remove Storage` plan.

Now compatible with both TorchScript call forms:
` _1 = torch.zeros([10], dtype=6, layout=0, device=torch.device("cpu"), pin_memory=False)`
and
` _1 = torch.zeros([10], dtype=6, layout=0, device=torch.device("cpu"))`

The same is verified for all similar functions: `rand_like`, `empty_like`, and others.

This is a fixed version of #18455.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/18952

Differential Revision: D14801792

Pulled By: VitalyFedyunin

fbshipit-source-id: 8dbc61078ff7a637d0ecdb95d4e98f704d5450ba
---
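Before the diff itself, a minimal sketch of the before/after usage this enables (assumes a CUDA-enabled build, since the pinned-memory allocator comes from the CUDA hooks, as `empty_cpu` below shows):

```python
import torch

# After this patch: allocate directly in pinned (page-locked) host memory.
x = torch.empty(1024, pin_memory=True)
assert x.is_pinned()

# Before: allocate pageable memory first, then copy into a freshly pinned
# buffer -- the extra allocation plus copy the summary refers to.
y = torch.empty(1024).pin_memory()
assert y.is_pinned()
```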
diff --git a/aten/src/ATen/native/TensorFactories.cpp b/aten/src/ATen/native/TensorFactories.cpp
index 96a8b04..d08676d 100644
--- a/aten/src/ATen/native/TensorFactories.cpp
+++ b/aten/src/ATen/native/TensorFactories.cpp
@@ -18,6 +18,7 @@
 #include
 #include
 #include
+#include <ATen/detail/CUDAHooksInterface.h>
 #include
 #include
@@ -93,7 +94,13 @@ Tensor empty_cpu(IntArrayRef size, const TensorOptions& options) {
   AT_ASSERT(!options.is_variable());  // is_variable should have been 'unpacked'  // TODO: remove this when Variable and Tensor are merged
   check_size_nonnegative(size);
-  auto* allocator = at::getCPUAllocator();
+  c10::Allocator* allocator;
+  if (options.pinned_memory()) {
+    allocator = detail::getCUDAHooks().getPinnedMemoryAllocator();
+  } else {
+    allocator = at::getCPUAllocator();
+  }
+
   int64_t nelements = prod_intlist(size);
   auto dtype = options.dtype();
   auto storage_impl = c10::make_intrusive<StorageImpl>(
diff --git a/aten/src/ATen/native/cuda/TensorFactories.cu b/aten/src/ATen/native/cuda/TensorFactories.cu
index 2da789a..90d0208 100644
--- a/aten/src/ATen/native/cuda/TensorFactories.cu
+++ b/aten/src/ATen/native/cuda/TensorFactories.cu
@@ -46,6 +46,7 @@ Tensor& eye_out_cuda(Tensor& result, int64_t n, int64_t m) {
 Tensor empty_cuda(IntArrayRef size, const TensorOptions& options) {
   AT_ASSERT(options.backend() == at::Backend::CUDA);
   AT_ASSERT(!options.is_variable());  // is_variable should have been 'unpacked'  // TODO: remove this when Variable and Tensor are merged
+  AT_CHECK(!options.pinned_memory(), "Only dense CPU tensors can be pinned");
   check_size_nonnegative(size);
   auto* allocator = at::cuda::getCUDADeviceAllocator();
diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index 609e179..1f860b2 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -46,7 +46,7 @@
   dispatch:
     CUDA: _cudnn_rnn_backward

-- func: _cudnn_init_dropout_state(float dropout, bool train, int dropout_seed, *, ScalarType dtype, Layout layout, Device device) -> Tensor
+- func: _cudnn_init_dropout_state(float dropout, bool train, int dropout_seed, *, ScalarType dtype, Layout layout, Device device, bool pin_memory=False) -> Tensor
   dispatch:
     CUDA: _cudnn_init_dropout_state
@@ -176,11 +176,11 @@
 - func: any(Tensor self, int dim, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!)

-- func: arange(Scalar end, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None) -> Tensor
+- func: arange(Scalar end, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor

-- func: arange(Scalar start, Scalar end, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None) -> Tensor
+- func: arange(Scalar start, Scalar end, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor

-- func: arange(Scalar start, Scalar end, Scalar step, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None) -> Tensor
+- func: arange(Scalar start, Scalar end, Scalar step, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor

 - func: arange(Scalar end, *, Tensor(a!) out) -> Tensor(a!)
@@ -259,9 +259,9 @@
     CPU: baddbmm_out_cpu
     CUDA: baddbmm_out_cuda

-- func: bartlett_window(int window_length, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None) -> Tensor
+- func: bartlett_window(int window_length, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor

-- func: bartlett_window(int window_length, bool periodic, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None) -> Tensor
+- func: bartlett_window(int window_length, bool periodic, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor

 - func: batch_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float momentum, float eps, bool cudnn_enabled) -> Tensor
@@ -308,9 +308,9 @@
     CPU: _bincount_cpu
     CUDA: _bincount_cuda

-- func: blackman_window(int window_length, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None) -> Tensor
+- func: blackman_window(int window_length, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor

-- func: blackman_window(int window_length, bool periodic, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None) -> Tensor
+- func: blackman_window(int window_length, bool periodic, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor

 - func: bmm(Tensor self, Tensor mat2) -> Tensor
   variants: function, method
@@ -668,7 +668,7 @@
     CPU: _embedding_bag_per_sample_weights_backward_cpu
     CUDA: _embedding_bag_per_sample_weights_backward_cuda

-- func: empty(int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None) -> Tensor
+- func: empty(int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
   cpu_half: True
   cpu_bool: True
   cuda_bool: True
@@ -694,10 +694,10 @@
 - func: empty_like(Tensor self) -> Tensor
   device_guard: False

-- func: empty_like(Tensor self, *, ScalarType dtype, Layout layout, Device device) -> Tensor
+- func: empty_like(Tensor self, *, ScalarType dtype, Layout layout, Device device, bool pin_memory=False) -> Tensor
   device_guard: False

-- func: empty_strided(int[] size, int[] stride, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None) -> Tensor
+- func: empty_strided(int[] size, int[] stride, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
   cpu_half: True
   cpu_bool: True
   cuda_bool: True
@@ -769,9 +769,9 @@
   variants: method  # This is method-only to match the previous tensor API. In the future we could make this a function too.
   device_guard: False

-- func: eye(int n, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None) -> Tensor
+- func: eye(int n, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor

-- func: eye(int n, int m, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None) -> Tensor
+- func: eye(int n, int m, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor

 - func: eye(int n, *, Tensor(a!) out) -> Tensor(a!)
   dispatch:
@@ -820,13 +820,13 @@
     CPU: _frac_out_cpu
     CUDA: _frac_out_cuda

-- func: full(int[] size, Scalar fill_value, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None) -> Tensor
+- func: full(int[] size, Scalar fill_value, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor

 - func: full(int[] size, Scalar fill_value, *, Tensor(a!) out) -> Tensor(a!)

 - func: full_like(Tensor self, Scalar fill_value) -> Tensor

-- func: full_like(Tensor self, Scalar fill_value, *, ScalarType dtype, Layout layout, Device device) -> Tensor
+- func: full_like(Tensor self, Scalar fill_value, *, ScalarType dtype, Layout layout, Device device, bool pin_memory=False) -> Tensor

 # NOTE [ grid_sampler Native Functions ]
 # `grid_sampler` does all the shape checking and then dispatches to one of
@@ -859,17 +859,17 @@
     CPU: grid_sampler_3d_backward_cpu
     CUDA: grid_sampler_3d_backward_cuda

-- func: hann_window(int window_length, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None) -> Tensor
+- func: hann_window(int window_length, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor

-- func: hann_window(int window_length, bool periodic, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None) -> Tensor
+- func: hann_window(int window_length, bool periodic, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor

-- func: hamming_window(int window_length, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None) -> Tensor
+- func: hamming_window(int window_length, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor

-- func: hamming_window(int window_length, bool periodic, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None) -> Tensor
+- func: hamming_window(int window_length, bool periodic, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor

-- func: hamming_window(int window_length, bool periodic, float alpha, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None) -> Tensor
+- func: hamming_window(int window_length, bool periodic, float alpha, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor

-- func: hamming_window(int window_length, bool periodic, float alpha, float beta, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None) -> Tensor
+- func: hamming_window(int window_length, bool periodic, float alpha, float beta, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor

 - func: hinge_embedding_loss(Tensor self, Tensor target, float margin=1.0, int reduction=Mean) -> Tensor
@@ -996,7 +996,7 @@
 - func: fbgemm_is_cpu_supported() -> bool

-- func: linspace(Scalar start, Scalar end, int steps=100, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None) -> Tensor
+- func: linspace(Scalar start, Scalar end, int steps=100, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor

 - func: linspace(Scalar start, Scalar end, int steps=100, *, Tensor(a!) out) -> Tensor(a!)
   dispatch:
@@ -1066,7 +1066,7 @@
 - func: logdet(Tensor self) -> Tensor
   variants: function, method

-- func: logspace(Scalar start, Scalar end, int steps=100, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None) -> Tensor
+- func: logspace(Scalar start, Scalar end, int steps=100, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor

 - func: logspace(Scalar start, Scalar end, int steps=100, *, Tensor(a!) out) -> Tensor(a!)
   dispatch:
@@ -1330,13 +1330,13 @@
 - func: _nnpack_spatial_convolution_backward_weight(Tensor input, int[] weightsize, Tensor grad_output, int[2] padding) -> Tensor
   variants: function

-- func: ones(int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None) -> Tensor
+- func: ones(int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor

 - func: ones(int[] size, *, Tensor(a!) out) -> Tensor(a!)

 - func: ones_like(Tensor self) -> Tensor

-- func: ones_like(Tensor self, *, ScalarType dtype, Layout layout, Device device) -> Tensor
+- func: ones_like(Tensor self, *, ScalarType dtype, Layout layout, Device device, bool pin_memory=False) -> Tensor

 - func: pairwise_distance(Tensor x1, Tensor x2, float p=2, float eps=1e-06, bool keepdim=False) -> Tensor
@@ -1364,11 +1364,11 @@
 - func: pinverse(Tensor self, float rcond=1e-15) -> Tensor
   variants: function, method

-- func: scalar_tensor(Scalar s, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None) -> Tensor
+- func: scalar_tensor(Scalar s, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor

-- func: rand(int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None) -> Tensor
+- func: rand(int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor

-- func: rand(int[] size, *, Generator? generator, ScalarType? dtype=None, Layout? layout=None, Device? device=None) -> Tensor
+- func: rand(int[] size, *, Generator? generator, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor

 - func: rand(int[] size, *, Tensor(a!) out) -> Tensor(a!)
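The schema pattern above repeats mechanically: functional variants with optional dtype/layout/device gain a nullable `bool? pin_memory=None`, while overloads with mandatory TensorOptions arguments gain a plain `bool pin_memory=False`. A couple of calls this makes legal, as a hedged sketch (assumes a CUDA build, since pinning needs the CUDA host allocator):

```python
import torch

# Factories declared in this file accept the new keyword directly.
e = torch.eye(6, pin_memory=True)
a = torch.arange(3, 5, pin_memory=True)
assert e.is_pinned() and a.is_pinned()

# Pinning is a property of dense CPU memory only; the CUDA factory
# rejects it (see the AT_CHECK added to empty_cuda above).
try:
    torch.empty(3, device='cuda', pin_memory=True)
except RuntimeError as err:
    print(err)  # "Only dense CPU tensors can be pinned"
```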
@@ -1376,15 +1376,15 @@
 - func: rand_like(Tensor self) -> Tensor

-- func: rand_like(Tensor self, *, ScalarType dtype, Layout layout, Device device) -> Tensor
+- func: rand_like(Tensor self, *, ScalarType dtype, Layout layout, Device device, bool pin_memory=False) -> Tensor

-- func: randint(int high, int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None) -> Tensor
+- func: randint(int high, int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor

-- func: randint(int high, int[] size, *, Generator? generator, ScalarType? dtype=None, Layout? layout=None, Device? device=None) -> Tensor
+- func: randint(int high, int[] size, *, Generator? generator, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor

-- func: randint(int low, int high, int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None) -> Tensor
+- func: randint(int low, int high, int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor

-- func: randint(int low, int high, int[] size, *, Generator? generator, ScalarType? dtype=None, Layout? layout=None, Device? device=None) -> Tensor
+- func: randint(int low, int high, int[] size, *, Generator? generator, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor

 - func: randint(int high, int[] size, *, Tensor(a!) out) -> Tensor(a!)
@@ -1398,13 +1398,13 @@
 - func: randint_like(Tensor self, int low, int high) -> Tensor

-- func: randint_like(Tensor self, int high, *, ScalarType dtype, Layout layout, Device device) -> Tensor
+- func: randint_like(Tensor self, int high, *, ScalarType dtype, Layout layout, Device device, bool pin_memory=False) -> Tensor

-- func: randint_like(Tensor self, int low, int high, *, ScalarType dtype, Layout layout, Device device) -> Tensor
+- func: randint_like(Tensor self, int low, int high, *, ScalarType dtype, Layout layout, Device device, bool pin_memory=False) -> Tensor

-- func: randn(int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None) -> Tensor
+- func: randn(int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor

-- func: randn(int[] size, *, Generator? generator, ScalarType? dtype=None, Layout? layout=None, Device? device=None) -> Tensor
+- func: randn(int[] size, *, Generator? generator, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor

 - func: randn(int[] size, *, Tensor(a!) out) -> Tensor(a!)
@@ -1412,11 +1412,11 @@
 - func: randn_like(Tensor self) -> Tensor

-- func: randn_like(Tensor self, *, ScalarType dtype, Layout layout, Device device) -> Tensor
+- func: randn_like(Tensor self, *, ScalarType dtype, Layout layout, Device device, bool pin_memory=False) -> Tensor

-- func: randperm(int n, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None) -> Tensor
+- func: randperm(int n, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor

-- func: randperm(int n, *, Generator? generator, ScalarType? dtype=None, Layout? layout=None, Device? device=None) -> Tensor
+- func: randperm(int n, *, Generator? generator, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor

 - func: randperm(int n, *, Tensor(a!) out) -> Tensor(a!)
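The `_like` overloads take the full dtype/layout/device/pin_memory group, so from Python the keyword composes with an exemplar tensor. A short sketch mirroring the test added later in this patch (test_torch.py):

```python
import torch

t = torch.empty(5, dtype=torch.float64)

# Each *_like factory inherits t's shape and dtype but allocates
# fresh pinned host memory when asked to.
for factory in (torch.rand_like, torch.randn_like, torch.empty_like,
                torch.zeros_like, torch.ones_like):
    assert factory(t, pin_memory=True).is_pinned()
```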
@@ -1425,9 +1425,9 @@
     CPU: randperm_out_cpu
     CUDA: randperm_out_cuda

-- func: range(Scalar start, Scalar end, Scalar step=1, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None) -> Tensor
+- func: range(Scalar start, Scalar end, Scalar step=1, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor

-- func: range(Scalar start, Scalar end, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None) -> Tensor
+- func: range(Scalar start, Scalar end, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor

 - func: range(Scalar start, Scalar end, Scalar step=1, *, Tensor(a!) out) -> Tensor(a!)
   dispatch:
@@ -1988,13 +1988,13 @@
 - func: _weight_norm_differentiable_backward(Tensor grad_w, Tensor saved_v, Tensor saved_g, Tensor saved_norms, int dim) -> (Tensor, Tensor)
   variants: function

-- func: zeros(int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None) -> Tensor
+- func: zeros(int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor

 - func: zeros(int[] size, *, Tensor(a!) out) -> Tensor(a!)

 - func: zeros_like(Tensor self) -> Tensor

-- func: zeros_like(Tensor self, *, ScalarType dtype, Layout layout, Device device) -> Tensor
+- func: zeros_like(Tensor self, *, ScalarType dtype, Layout layout, Device device, bool pin_memory=False) -> Tensor

 - func: _standard_gamma_grad(Tensor self, Tensor output) -> Tensor
   variants: function
@@ -2286,27 +2286,26 @@
 # FIXME: would be nicer if TensorOptions was optional based; not adding default arguments for options given
 # the default would never make sense.
-- func: sparse_coo_tensor(int[] size, *, ScalarType dtype, Layout layout, Device device) -> Tensor
+- func: sparse_coo_tensor(int[] size, *, ScalarType dtype, Layout layout, Device device, bool pin_memory=False) -> Tensor

-- func: sparse_coo_tensor(Tensor indices, Tensor values, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None) -> Tensor
+- func: sparse_coo_tensor(Tensor indices, Tensor values, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor

-- func: sparse_coo_tensor(Tensor indices, Tensor values, int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None) -> Tensor
+- func: sparse_coo_tensor(Tensor indices, Tensor values, int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor

-- func: _sparse_coo_tensor_unsafe(Tensor indices, Tensor values, int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None) -> Tensor
+- func: _sparse_coo_tensor_unsafe(Tensor indices, Tensor values, int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor

-- func: _sparse_coo_tensor_with_dims(int sparse_dim, int dense_dim, int[] size, *, ScalarType dtype, Layout layout, Device device) -> Tensor
+- func: _sparse_coo_tensor_with_dims(int sparse_dim, int dense_dim, int[] size, *, ScalarType dtype, Layout layout, Device device, bool pin_memory=False) -> Tensor
   dispatch:
     SparseCPU: new_with_dims_sparse
     SparseCUDA: new_with_dims_sparse
   requires_tensor: True

-- func: _sparse_coo_tensor_with_dims_and_tensors(int sparse_dim, int dense_dim, int[] size, Tensor indices, Tensor values, *, ScalarType dtype, Layout layout, Device device) -> Tensor
+- func: _sparse_coo_tensor_with_dims_and_tensors(int sparse_dim, int dense_dim, int[] size, Tensor indices, Tensor values, *, ScalarType dtype, Layout layout, Device device, bool pin_memory=False) -> Tensor
   dispatch:
     SparseCPU: new_with_dims_and_tensor_sparse
     SparseCUDA: new_with_dims_and_tensor_sparse
   requires_tensor: True
-
 - func: sparse_resize_(Tensor(a!) self, int[] size, int sparse_dim, int dense_dim) -> Tensor(a!)
   variants: method
   dispatch:
@@ -2510,7 +2509,7 @@
 # to(Device) must not exist because all constructors of Device also works for
 # TensorOptions. Otherwise, an ambiguity error is thrown.
 # See NOTE [ TensorOptions Constructors ].
-- func: to(Tensor self, *, ScalarType dtype, Layout layout, Device device, bool non_blocking=False, bool copy=False) -> Tensor
+- func: to(Tensor self, *, ScalarType dtype, Layout layout, Device device, bool pin_memory=False, bool non_blocking=False, bool copy=False) -> Tensor
   variants: method
   device_guard: False
diff --git a/aten/src/ATen/native/sparse/SparseTensor.cpp b/aten/src/ATen/native/sparse/SparseTensor.cpp
index d3278c6..67f3e22 100644
--- a/aten/src/ATen/native/sparse/SparseTensor.cpp
+++ b/aten/src/ATen/native/sparse/SparseTensor.cpp
@@ -108,6 +108,7 @@ SparseTensor new_with_dims_and_tensor_sparse(
 /** Empty init **/
 Tensor empty_sparse(IntArrayRef size, const TensorOptions& options) {
+  AT_CHECK(!options.pinned_memory(), "Only dense CPU tensors can be pinned");
   return new_with_dims_sparse(size.size(), 0, size, options);
 }
diff --git a/aten/src/ATen/native_parse.py b/aten/src/ATen/native_parse.py
index 08175c3..bdc4f1a 100644
--- a/aten/src/ATen/native_parse.py
+++ b/aten/src/ATen/native_parse.py
@@ -202,19 +202,15 @@ def parse_arguments(args, func_variants, declaration, func_return):
             {'name': 'dtype', 'type': 'ScalarType', 'is_nullable': False, 'annotation': None},
             {'name': 'layout', 'type': 'Layout', 'is_nullable': False, 'annotation': None},
             {'name': 'device', 'type': 'Device', 'is_nullable': False, 'annotation': None},
+            {'name': 'pin_memory', 'type': 'bool', 'is_nullable': False, 'annotation': None, 'default': False},
         ]
     ]
     supported_topt_arguments.append(copy.deepcopy(supported_topt_arguments[0]))
-    supported_topt_arguments[1][0]['kwarg_only'] = True
-    supported_topt_arguments[1][1]['kwarg_only'] = True
-    supported_topt_arguments[1][2]['kwarg_only'] = True
+    for arg in supported_topt_arguments[1]:
+        arg.update({'kwarg_only': True})
     supported_topt_arguments.append(copy.deepcopy(supported_topt_arguments[1]))
-    supported_topt_arguments[2][0]['default'] = 'c10::nullopt'
-    supported_topt_arguments[2][1]['default'] = 'c10::nullopt'
-    supported_topt_arguments[2][2]['default'] = 'c10::nullopt'
-    supported_topt_arguments[2][0]['is_nullable'] = True
-    supported_topt_arguments[2][1]['is_nullable'] = True
-    supported_topt_arguments[2][2]['is_nullable'] = True
+    for arg in supported_topt_arguments[2]:
+        arg.update({'default': 'c10::nullopt', 'is_nullable': True})

     corresponding_topts = [
         {'type': 'TensorOptions', 'name': 'options', 'is_nullable': False, 'annotation': None},
@@ -226,30 +222,28 @@ def parse_arguments(args, func_variants, declaration, func_return):

     def check_topt_representation(topt_representation):
         for idx, supported_topt in enumerate(supported_topt_arguments):
-            matches = True
-            matches = matches and topt_representation[0] == supported_topt[0]
-            matches = matches and topt_representation[1] == supported_topt[1]
-            matches = matches and topt_representation[2] == supported_topt[2]
+            matches = all(topt_representation[i] == topt for i, topt in enumerate(supported_topt))
             if matches:
                 return corresponding_topts[idx]
         return None

     def is_tensor_option(argument):
-        return argument['name'] in ['dtype', 'layout', 'device']
+        return argument['name'] in ['dtype', 'layout', 'device', 'pin_memory']

     new_arguments = []
     idx = 0
     while idx < len(arguments):
         argument = arguments[idx]
-        if is_tensor_option(argument) and len(arguments) - idx >= 3:
+        number_of_arguments = len(supported_topt_arguments[0])
+        if is_tensor_option(argument) and len(arguments) - idx >= number_of_arguments:
             topt_representation = []
-            for i in range(3):
+            for i in range(number_of_arguments):
                 argument = arguments[idx]
                 if not is_tensor_option(argument):
                     break
                 topt_representation.append(argument)
                 idx += 1
-            if len(topt_representation) == 3:
+            if len(topt_representation) == number_of_arguments:
                 merged_argument = check_topt_representation(topt_representation)
                 assert merged_argument, \
                     "Unsupported combination of TensorOptions {}, the only currently supported combinations are {}"\
diff --git a/test/test_torch.py b/test/test_torch.py
index 4c44cae..b7ba9f7 100644
--- a/test/test_torch.py
+++ b/test/test_torch.py
@@ -9984,6 +9984,40 @@ tensor([[[1., 1., 1., ..., 1., 1., 1.],
         self.assertEqual(pinned, x)
         self.assertNotEqual(pinned.data_ptr(), x.data_ptr())

+    @unittest.skipIf(not torch.cuda.is_available(), 'no CUDA')
+    def test_pin_memory_from_constructor(self):
+
+        def _get_like(t, **kwargs):
+            return [
+                torch.rand_like(t, **kwargs),
+                torch.randn_like(t, **kwargs),
+                torch.empty_like(t, **kwargs),
+                torch.full_like(t, 4, **kwargs),
+                torch.zeros_like(t, **kwargs),
+                torch.ones_like(t, **kwargs),
+            ]
+
+        def _get_tensors(**kwargs):
+            return [
+                torch.tensor([10, 11], **kwargs),
+                torch.randn(3, 5, **kwargs),
+                torch.rand(3, **kwargs),
+                # torch.randint(3, 5, **kwargs), // unsupported
+                torch.zeros(3, **kwargs),
+                torch.randperm(3, **kwargs),
+                torch.empty(6, **kwargs),
+                torch.ones(6, **kwargs),
+                torch.eye(6, **kwargs),
+                torch.arange(3, 5, **kwargs)]
+
+        pinned_tensors = _get_tensors(pin_memory=True) + _get_like(torch.empty(5, dtype=torch.float64), pin_memory=True)
+        for x in pinned_tensors:
+            self.assertTrue(x.is_pinned())
+
+        tensors = _get_tensors() + _get_like(torch.empty(5, dtype=torch.float64, pin_memory=True))
+        for x in tensors:
+            self.assertFalse(x.is_pinned())
+
     @unittest.skipIf(not TEST_NUMPY, "Numpy not found")
     def test_numpy_unresizable(self):
         x = np.zeros((2, 2))
diff --git a/tools/autograd/gen_python_functions.py b/tools/autograd/gen_python_functions.py
index 2e7fe79..821058e 100644
--- a/tools/autograd/gen_python_functions.py
+++ b/tools/autograd/gen_python_functions.py
@@ -157,7 +157,8 @@
 const auto options = TensorOptions()
     .dtype(${dtype})
     .device(${device})
     .layout(${layout}.layout)
-    .requires_grad(${requires_grad});
+    .requires_grad(${requires_grad})
+    .pinned_memory(${pin_memory});
 """)
@@ -429,9 +430,9 @@ def create_python_bindings(python_functions, has_self, is_module=False):
             arg_idx += 1

         if 'layout' in (a['name'] for a in python_binding_arguments):
-            layout_idx, device_idx, requires_grad_idx = (arg_idx, arg_idx + 1, arg_idx + 2)
+            layout_idx, device_idx, pin_memory_idx, requires_grad_idx = (arg_idx, arg_idx + 1, arg_idx + 2, arg_idx + 3)
         else:
-            device_idx, requires_grad_idx = (arg_idx, arg_idx + 1)
+            device_idx, pin_memory_idx, requires_grad_idx = (arg_idx, arg_idx + 1, arg_idx + 2)

         device = None
         for arg in python_binding_arguments:
@@ -459,9 +460,11 @@ def create_python_bindings(python_functions, has_self, is_module=False):
                 has_device_bind = True
             elif arg['name'] == 'requires_grad' and arg['simple_type'] == 'bool':
                 requires_grad = parse_arg(arg, requires_grad_idx)[0]
+            elif arg['name'] == 'pin_memory' and arg['simple_type'] == 'bool':
+                pin_memory = parse_arg(arg, pin_memory_idx)[0]
             else:
                 raise RuntimeError(("found {} in python_binding_arguments but only "
-                                    "\"bool requires_grad\", \"ScalarType dtype\", \"Layout layout\", "
+                                    "\"bool pin_memory\", \"bool requires_grad\", \"ScalarType dtype\", \"Layout layout\", "
                                     "\"Device device\" are supported".format(arg)))

         dtype = parsed_type_args[0] if parsed_type_args else None
@@ -470,7 +473,8 @@ def create_python_bindings(python_functions, has_self, is_module=False):
                 'dtype': dtype,
                 'layout': layout,
                 'device': device,
-                'requires_grad': requires_grad
+                'requires_grad': requires_grad,
+                'pin_memory': pin_memory,
             }))
             formal_args.append('const TensorOptions & options')
             actuals.append('options')
@@ -620,6 +624,15 @@ def create_python_bindings(python_functions, has_self, is_module=False):
                 'python_default_init': py_default_device
             }
             python_binding_arguments.append(device_arg)
+            pin_memory_arg = {
+                'default': False,
+                'dynamic_type': 'bool',
+                'kwarg_only': True,
+                'name': 'pin_memory',
+                'type': 'bool',
+                'simple_type': 'bool',
+            }
+            python_binding_arguments.append(pin_memory_arg)
         if is_factory_or_like_function:
             requires_grad_arg = {
                 'default': False,
diff --git a/tools/autograd/templates/python_torch_functions.cpp b/tools/autograd/templates/python_torch_functions.cpp
index 9eb5bb9..8b07b9e 100644
--- a/tools/autograd/templates/python_torch_functions.cpp
+++ b/tools/autograd/templates/python_torch_functions.cpp
@@ -96,11 +96,11 @@ static PyObject * THPVariable_arange(PyObject* self, PyObject* args, PyObject* kwargs)
 {
   HANDLE_TH_ERRORS
   static PythonArgParser parser({
-    "arange(Scalar end, *, Tensor out=None, ScalarType dtype=None, Layout layout=torch.strided, Device device=None, bool requires_grad=False)",
-    "arange(Scalar start, Scalar end, Scalar step=1, *, Tensor out=None, ScalarType dtype=None, Layout layout=torch.strided, Device device=None, bool requires_grad=False)",
+    "arange(Scalar end, *, Tensor out=None, ScalarType dtype=None, Layout layout=torch.strided, Device device=None, bool pin_memory=False, bool requires_grad=False)",
+    "arange(Scalar start, Scalar end, Scalar step=1, *, Tensor out=None, ScalarType dtype=None, Layout layout=torch.strided, Device device=None, bool pin_memory=False, bool requires_grad=False)",
   });

-  ParsedArgs<8> parsed_args;
+  ParsedArgs<9> parsed_args;
   auto r = parser.parse(args, kwargs, parsed_args);

   if (r.idx == 0) {
@@ -112,12 +112,14 @@ static PyObject * THPVariable_arange(PyObject* self, PyObject* args, PyObject* kwargs)
           .dtype(scalarType)
           .device(r.device(4))
           .layout(r.layout(3).layout)
-          .requires_grad(r.toBool(5));
+          .requires_grad(r.toBool(6))
+          .pinned_memory(r.toBool(5));
       return wrap(dispatch_arange(end, options));
     } else {
+      AT_CHECK(!r.toBool(5), " `pin_memory` and `out` parameters are incompatible");
       check_out_type_matches(r.tensor(1), r.scalartype(2), r.isNone(2), r.layout(3), r.isNone(3), r.device(4), r.isNone(4));
-      return wrap(dispatch_arange(r.scalar(0), r.tensor(1)).set_requires_grad(r.toBool(5)));
+      return wrap(dispatch_arange(r.scalar(0), r.tensor(1)).set_requires_grad(r.toBool(6)));
     }
   } else if (r.idx == 1) {
     if (r.isNone(3)) {
@@ -130,12 +132,14 @@ static PyObject * THPVariable_arange(PyObject* self, PyObject* args, PyObject* kwargs)
           .dtype(scalarType)
           .device(r.device(6))
           .layout(r.layout(5).layout)
-          .requires_grad(r.toBool(7));
+          .requires_grad(r.toBool(8))
+          .pinned_memory(r.toBool(7));
       return wrap(dispatch_arange(start, end, step, options));
     } else {
+      AT_CHECK(!r.toBool(7), " `pin_memory` and `out` parameters are incompatible");
       check_out_type_matches(r.tensor(3), r.scalartype(4), r.isNone(4), r.layout(5), r.isNone(5), r.device(6), r.isNone(6));
-      return wrap(dispatch_arange(r.scalar(0), r.scalar(1), r.scalar(2), r.tensor(3)).set_requires_grad(r.toBool(7)));
+      return wrap(dispatch_arange(r.scalar(0), r.scalar(1), r.scalar(2), r.tensor(3)).set_requires_grad(r.toBool(8)));
     }
   }
   Py_RETURN_NONE;
diff --git a/tools/jit/gen_jit_dispatch.py b/tools/jit/gen_jit_dispatch.py
index 14b0843..483fa5f 100644
--- a/tools/jit/gen_jit_dispatch.py
+++ b/tools/jit/gen_jit_dispatch.py
@@ -59,6 +59,7 @@ TYPE_MAP = {
     'int64_t?': 'int?',
     'double': 'float',
     'bool': 'bool',
+    'bool?': 'bool?',
     'Generator': 'Generator?',
 }
@@ -104,6 +105,7 @@ FROM_IVALUE = {
     'Tensor?[]': 'toListOfOptionalTensor({})',
     'TensorList': '{}.toTensorList()->elements()',
     'bool': '{}.toBool()',
+    'bool?': '{}.toOptional<bool>()',
     'double': '{}.toDouble()',
     'int64_t': '{}.toInt()',
     'int64_t?': '{}.toOptional<int64_t>()',
@@ -134,14 +136,16 @@
 CALL_NAMESPACE_WITH_TENSOR_OPTIONS = CodeTemplate("""\
 const auto options = TensorOptions()
         .dtype(${dtype})
         .layout(${layout})
-        .device(${device});
+        .device(${device})
+        .pinned_memory(${pin_memory});
 auto result_ = torch::${name}(${args_with_tensor_options});
 """)
 CALL_METHOD_WITH_TENSOR_OPTIONS = CodeTemplate("""\
 const auto options = TensorOptions()
         .dtype(${dtype})
         .layout(${layout})
-        .device(${device});
+        .device(${device})
+        .pinned_memory(${pin_memory});
 auto result_ = (${first}).${name}(${args_with_tensor_options});
 """)
@@ -242,15 +246,18 @@ def gen_jit_dispatch(declarations, out, template_path):
             dtype = args[tensor_options_arg_index]
             layout = args[tensor_options_arg_index + 1]
             device = args[tensor_options_arg_index + 2]
+            pin_memory = args[tensor_options_arg_index + 3]
             args_with_tensor_options = args[:tensor_options_arg_index] + \
-                ['options'] + args[(tensor_options_arg_index + 3):]
+                ['options'] + args[(tensor_options_arg_index + 4):]
             if is_namespace_function:
                 return CALL_NAMESPACE_WITH_TENSOR_OPTIONS.substitute(
-                    name=decl['name'], dtype=dtype, layout=layout, device=device,
+                    name=decl['name'], dtype=dtype, layout=layout,
+                    device=device, pin_memory=pin_memory,
                     args_with_tensor_options=pack_arguments(args_with_tensor_options))
             else:
                 return CALL_METHOD_WITH_TENSOR_OPTIONS.substitute(
-                    name=decl['name'], dtype=dtype, layout=layout, device=device,
+                    name=decl['name'], dtype=dtype, layout=layout,
+                    device=device, pin_memory=pin_memory,
                     args_with_tensor_options=pack_arguments(args_with_tensor_options[1:]),
                     first=args_with_tensor_options[0], num_inputs=num_inputs)
         else:
@@ -349,21 +356,19 @@ def gen_jit_dispatch(declarations, out, template_path):
         {'name': 'layout', 'simple_type': 'Layout'},
         # device is specified as an IntArrayRef of { at::Device::Type, device_id }
         {'name': 'device', 'simple_type': 'Device'},
+        # pin_memory is specified as a boolean
+        {'name': 'pin_memory', 'simple_type': 'bool', 'default': False},
     ]
     # TODO: Don't repack this into TensorOptions. Needs various changes in downstream code.
     if 'default' in arg:
-        tensor_options_expansion[0]['simple_type'] += '?'
-        tensor_options_expansion[1]['simple_type'] += '?'
-        tensor_options_expansion[2]['simple_type'] += '?'
-        tensor_options_expansion[0]['default'] = 'None'
-        tensor_options_expansion[1]['default'] = 'None'
-        tensor_options_expansion[2]['default'] = 'None'
+        for el in tensor_options_expansion:
+            el['simple_type'] += '?'
+            el['default'] = 'None'
     if 'default' in arg and arg['default'] == 'at::kLong':
         tensor_options_expansion[0]['default'] = 'long'
     if 'kwarg_only' in arg and arg['kwarg_only']:
-        tensor_options_expansion[0]['kwarg_only'] = True
-        tensor_options_expansion[1]['kwarg_only'] = True
-        tensor_options_expansion[2]['kwarg_only'] = True
+        for el in tensor_options_expansion:
+            el['kwarg_only'] = True
     return tensor_options_expansion

 additional_jit_decls = []
diff --git a/tools/pyi/gen_pyi.py b/tools/pyi/gen_pyi.py
index 0d5ad32..5c7eb4a 100644
--- a/tools/pyi/gen_pyi.py
+++ b/tools/pyi/gen_pyi.py
@@ -240,7 +240,11 @@ def generate_type_hints(fname, decls, is_tensor=False):
             if a.get('kwarg_only', False) and render_kw_only_separator:
                 python_args.append('*')
                 render_kw_only_separator = False
-            python_args.append(arg_to_type_hint(a))
+            try:
+                python_args.append(arg_to_type_hint(a))
+            except Exception:
+                print("Error while processing function {}".format(fname))
+                raise

         if is_tensor:
             if 'self: Tensor' in python_args:
diff --git a/torch/_tensor_docs.py b/torch/_tensor_docs.py
index 262998c..948b7c1 100644
--- a/torch/_tensor_docs.py
+++ b/torch/_tensor_docs.py
@@ -17,6 +17,8 @@ new_common_args = parse_kwargs("""
         Default: if None, same :class:`torch.device` as this tensor.
     requires_grad (bool, optional): If autograd should record operations on the
         returned tensor. Default: ``False``.
+    pin_memory (bool, optional): If set, the returned tensor will be allocated in
+        pinned memory. Works only for CPU tensors. Default: ``False``.
 """)

 add_docstr_all('new_tensor',
diff --git a/torch/_torch_docs.py b/torch/_torch_docs.py
index 98512af..703b189 100644
--- a/torch/_torch_docs.py
+++ b/torch/_torch_docs.py
@@ -63,6 +63,8 @@ factory_common_args = parse_kwargs("""
         for CPU tensor types and the current CUDA device for CUDA tensor types.
     requires_grad (bool, optional): If autograd should record operations on the
         returned tensor. Default: ``False``.
+    pin_memory (bool, optional): If set, the returned tensor will be allocated in
+        pinned memory. Works only for CPU tensors. Default: ``False``.
 """)

 factory_like_common_args = parse_kwargs("""
@@ -75,6 +77,8 @@
         Default: if ``None``, defaults to the device of :attr:`input`.
     requires_grad (bool, optional): If autograd should record operations on the
         returned tensor. Default: ``False``.
+    pin_memory (bool, optional): If set, the returned tensor will be allocated in
+        pinned memory. Works only for CPU tensors. Default: ``False``.
 """)

 factory_data_common_args = parse_kwargs("""
@@ -88,6 +92,8 @@
         for CPU tensor types and the current CUDA device for CUDA tensor types.
     requires_grad (bool, optional): If autograd should record operations on the
         returned tensor. Default: ``False``.
+    pin_memory (bool, optional): If set, the returned tensor will be allocated in
+        pinned memory. Works only for CPU tensors. Default: ``False``.
 """)

 add_docstr(torch.abs,
@@ -3990,7 +3996,7 @@ Example::

 add_docstr(torch.tensor,
            r"""
-tensor(data, dtype=None, device=None, requires_grad=False) -> Tensor
+tensor(data, dtype=None, device=None, requires_grad=False, pin_memory=False) -> Tensor

 Constructs a tensor with :attr:`data`.
@@ -4014,6 +4020,7 @@ Args:
     {dtype}
     {device}
     {requires_grad}
+    {pin_memory}

 Example::
@@ -5553,7 +5560,7 @@ Example::

 add_docstr(torch.empty,
            r"""
-empty(*sizes, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False) -> Tensor
+empty(*sizes, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False, pin_memory=False) -> Tensor

 Returns a tensor filled with uninitialized data. The shape of the tensor is
 defined by the variable argument :attr:`sizes`.
@@ -5566,6 +5573,7 @@ Args:
     {layout}
     {device}
     {requires_grad}
+    {pin_memory}

 Example::
diff --git a/torch/csrc/jit/ir.cpp b/torch/csrc/jit/ir.cpp
index 05fe6ff..e156aa0 100644
--- a/torch/csrc/jit/ir.cpp
+++ b/torch/csrc/jit/ir.cpp
@@ -800,19 +800,19 @@ bool Node::isNondeterministic() const {
       "aten::poisson(Tensor self, Generator? generator) -> Tensor",
       "aten::rrelu(Tensor self, Scalar lower, Scalar upper, bool training, Generator? generator) -> Tensor",
       "aten::rrelu_with_noise(Tensor self, Tensor noise, Scalar lower, Scalar upper, bool training, Generator? generator) -> Tensor",
-      "aten::rand(int[] size, *, int? dtype, int? layout, Device? device) -> Tensor",
+      "aten::rand(int[] size, *, int? dtype, int? layout, Device? device, bool? pin_memory) -> Tensor",
       "aten::rand_like(Tensor self) -> Tensor",
-      "aten::rand_like(Tensor self, *, int dtype, int layout, Device device) -> Tensor",
-      "aten::randint(int high, int[] size, *, int? dtype, int? layout, Device? device) -> Tensor",
-      "aten::randint(int low, int high, int[] size, *, int? dtype, int? layout, Device? device) -> Tensor",
+      "aten::rand_like(Tensor self, *, int dtype, int layout, Device device, bool pin_memory) -> Tensor",
+      "aten::randint(int high, int[] size, *, int? dtype, int? layout, Device? device, bool? pin_memory) -> Tensor",
+      "aten::randint(int low, int high, int[] size, *, int? dtype, int? layout, Device? device, bool? pin_memory) -> Tensor",
       "aten::randint_like(Tensor self, int high) -> Tensor",
       "aten::randint_like(Tensor self, int low, int high) -> Tensor",
-      "aten::randint_like(Tensor self, int high, *, int dtype, int layout, Device device) -> Tensor",
-      "aten::randint_like(Tensor self, int low, int high, *, int dtype, int layout, Device device) -> Tensor",
-      "aten::randn(int[] size, *, int? dtype, int? layout, Device? device) -> Tensor",
+      "aten::randint_like(Tensor self, int high, *, int dtype, int layout, Device device, bool pin_memory) -> Tensor",
+      "aten::randint_like(Tensor self, int low, int high, *, int dtype, int layout, Device device, bool pin_memory) -> Tensor",
+      "aten::randn(int[] size, *, int? dtype, int? layout, Device? device, bool? pin_memory) -> Tensor",
       "aten::randn_like(Tensor self) -> Tensor",
-      "aten::randn_like(Tensor self, *, int dtype, int layout, Device device) -> Tensor",
-      "aten::randperm(int n, *, int? dtype, int? layout, Device? device) -> Tensor"};
+      "aten::randn_like(Tensor self, *, int dtype, int layout, Device device, bool pin_memory) -> Tensor",
+      "aten::randperm(int n, *, int? dtype, int? layout, Device? device, bool? pin_memory) -> Tensor"};

   if (nondeterministic_ops.find(this) == nullptr) {
     return false;
diff --git a/torch/csrc/jit/operator.cpp b/torch/csrc/jit/operator.cpp
index 198b3ed..f8fbdc9 100644
--- a/torch/csrc/jit/operator.cpp
+++ b/torch/csrc/jit/operator.cpp
@@ -387,7 +387,6 @@ void registerOperator(Operator&& op) {
           ". File a bug to add a case for this operator.\n");
     }
   }
-
   getRegistry().registerOperator(std::move(op));
 }
@@ -467,8 +466,6 @@ bool Operator::matches(const Node* node) const {
   // too many inputs
   if (!schema().is_vararg() && actuals.size() != formals.size()) {
-    // std::cout << "not all inputs used\n" << input_i << " " << inputs_size <<
-    // "\n";
     return false;
   }
diff --git a/torch/csrc/jit/passes/shape_analysis.cpp b/torch/csrc/jit/passes/shape_analysis.cpp
index b1495d0..c221dd9 100644
--- a/torch/csrc/jit/passes/shape_analysis.cpp
+++ b/torch/csrc/jit/passes/shape_analysis.cpp
@@ -1197,14 +1197,14 @@ class ShapePropagator {
   //   - has ScalarType dtype, Layout layout and Device device arguments
   static const register_formula_for like_factories_with_options{
       {
-          "aten::empty_like(Tensor self, *, int dtype, int layout, Device device) -> Tensor",
-          "aten::full_like(Tensor self, Scalar fill_value, *, int dtype, int layout, Device device) -> Tensor",
-          "aten::ones_like(Tensor self, *, int dtype, int layout, Device device) -> Tensor",
-          "aten::rand_like(Tensor self, *, int dtype, int layout, Device device) -> Tensor",
-          "aten::randint_like(Tensor self, int high, *, int dtype, int layout, Device device) -> Tensor",
-          "aten::randint_like(Tensor self, int low, int high, *, int dtype, int layout, Device device) -> Tensor",
-          "aten::randn_like(Tensor self, *, int dtype, int layout, Device device) -> Tensor",
-          "aten::zeros_like(Tensor self, *, int dtype, int layout, Device device) -> Tensor",
+          "aten::empty_like(Tensor self, *, int dtype, int layout, Device device, bool pin_memory) -> Tensor",
+          "aten::full_like(Tensor self, Scalar fill_value, *, int dtype, int layout, Device device, bool pin_memory) -> Tensor",
+          "aten::ones_like(Tensor self, *, int dtype, int layout, Device device, bool pin_memory) -> Tensor",
+          "aten::rand_like(Tensor self, *, int dtype, int layout, Device device, bool pin_memory) -> Tensor",
+          "aten::randint_like(Tensor self, int high, *, int dtype, int layout, Device device, bool pin_memory) -> Tensor",
+          "aten::randint_like(Tensor self, int low, int high, *, int dtype, int layout, Device device, bool pin_memory) -> Tensor",
+          "aten::randn_like(Tensor self, *, int dtype, int layout, Device device, bool pin_memory) -> Tensor",
+          "aten::zeros_like(Tensor self, *, int dtype, int layout, Device device, bool pin_memory) -> Tensor",
       },
       [](Node* node) -> type_vec_t {
         if (auto type = node->namedInput(attr::self)
@@ -1226,14 +1226,14 @@ class ShapePropagator {
   //   arguments
   static const register_formula_for size_factories_with_options{
       {
-          "aten::empty(int[] size, *, int? dtype, int? layout, Device? device) -> Tensor",
-          "aten::full(int[] size, Scalar fill_value, *, int? dtype, int? layout, Device? device) -> Tensor",
-          "aten::ones(int[] size, *, int? dtype, int? layout, Device? device) -> Tensor",
-          "aten::rand(int[] size, *, int? dtype, int? layout, Device? device) -> Tensor",
-          "aten::randn(int[] size, *, int? dtype, int? layout, Device? device) -> Tensor",
-          "aten::zeros(int[] size, *, int? dtype, int? layout, Device? device) -> Tensor",
-          "aten::randint(int high, int[] size, *, int? dtype, int? layout, Device? device) -> Tensor",
-          "aten::randint(int low, int high, int[] size, *, int? dtype, int? layout, Device? device) -> Tensor",
+          "aten::empty(int[] size, *, int? dtype, int? layout, Device? device, bool? pin_memory) -> Tensor",
+          "aten::full(int[] size, Scalar fill_value, *, int? dtype, int? layout, Device? device, bool? pin_memory) -> Tensor",
+          "aten::ones(int[] size, *, int? dtype, int? layout, Device? device, bool? pin_memory) -> Tensor",
+          "aten::rand(int[] size, *, int? dtype, int? layout, Device? device, bool? pin_memory) -> Tensor",
+          "aten::randn(int[] size, *, int? dtype, int? layout, Device? device, bool? pin_memory) -> Tensor",
+          "aten::zeros(int[] size, *, int? dtype, int? layout, Device? device, bool? pin_memory) -> Tensor",
+          "aten::randint(int high, int[] size, *, int? dtype, int? layout, Device? device, bool? pin_memory) -> Tensor",
+          "aten::randint(int low, int high, int[] size, *, int? dtype, int? layout, Device? device, bool? pin_memory) -> Tensor",
       },
       [](Node* node) -> type_vec_t {
         if (auto maybe_size = node->get>(attr::size)) {
diff --git a/torch/csrc/jit/tracer.cpp b/torch/csrc/jit/tracer.cpp
index 1604faf..4680ae2 100644
--- a/torch/csrc/jit/tracer.cpp
+++ b/torch/csrc/jit/tracer.cpp
@@ -408,6 +408,7 @@ void addInputs(Node* n, const char* name, const at::TensorOptions& options) {
   addInputs(n, name, at::typeMetaToScalarType(options.dtype()));
   addInputs(n, name, options.layout());
   addInputs(n, name, options.device());
+  addInputs(n, name, options.pinned_memory());
 }

 void addInputs(Node* n, const char* name, at::IntArrayRef value) {
diff --git a/torch/csrc/utils/tensor_new.cpp b/torch/csrc/utils/tensor_new.cpp
index a538ffc..908a27f 100644
--- a/torch/csrc/utils/tensor_new.cpp
+++ b/torch/csrc/utils/tensor_new.cpp
@@ -194,12 +194,15 @@ Tensor internal_new_from_data(
     PyObject* data,
     bool copy_variables,
     bool copy_numpy,
-    bool type_inference) {
+    bool type_inference,
+    bool pin_memory = false) {
+
   if (THPUtils_checkString(data)) {
     throw TypeError("new(): invalid data type '%s'", Py_TYPE(data)->tp_name);
   }

   if (THPVariable_Check(data)) {
+    AT_CHECK(!pin_memory, "Can't pin tensor constructed from a variable");
     auto var = reinterpret_cast<THPVariable*>(data)->cdata;
     if (copy_variables) {
       var = var.detach();
@@ -215,6 +218,7 @@
 #ifdef USE_NUMPY
   if (PyArray_Check(data)) {
+    AT_CHECK(!pin_memory, "Can't pin tensor constructed from numpy");
     auto tensor = autograd::make_variable(tensor_from_numpy(data), /*requires_grad=*/false);
     const auto& inferred_scalar_type = type_inference ? tensor.scalar_type() : scalar_type;
     auto device = device_opt.has_value() ? *device_opt : at::Device(type.device_type());
@@ -226,7 +230,7 @@

   auto sizes = compute_sizes(data);
   ScalarType inferred_scalar_type = type_inference ? infer_scalar_type(data) : scalar_type;
-  auto tensor = autograd::make_variable(at::empty(sizes, at::initialTensorOptions().dtype(inferred_scalar_type)), /*requires_grad=*/false);
+  auto tensor = autograd::make_variable(at::empty(sizes, at::initialTensorOptions().dtype(inferred_scalar_type).pinned_memory(pin_memory)), /*requires_grad=*/false);
   recursive_store(
       (char*)tensor.data_ptr(), tensor.sizes(), tensor.strides(), 0,
       inferred_scalar_type, tensor.dtype().itemsize(), data);
@@ -508,10 +512,10 @@ Tensor sparse_coo_tensor_ctor(const Type& default_type, ScalarType scalar_type,

 Tensor tensor_ctor(const Type& type, ScalarType scalar_type, PyObject* args, PyObject* kwargs) {
   static PythonArgParser parser({
-    "tensor(PyObject* data, *, ScalarType dtype=None, Device? device=None, bool requires_grad=False)",
+    "tensor(PyObject* data, *, ScalarType dtype=None, Device? device=None, bool pin_memory=False, bool requires_grad=False)",
   });

-  ParsedArgs<4> parsed_args;
+  ParsedArgs<5> parsed_args;
   auto r = parser.parse(args, kwargs, parsed_args);
   if (r.idx == 0) {
     PyObject* data = r.pyobject(0);
@@ -522,7 +526,8 @@ Tensor tensor_ctor(const Type& type, ScalarType scalar_type, PyObject* args, PyObject* kwargs) {
     }

     bool type_inference = r.isNone(1);
-    bool args_requires_grad = r.toBool(3);
+    bool pin_memory = r.toBool(3);
+    bool args_requires_grad = r.toBool(4);
     auto new_tensor = internal_new_from_data(
                typeWithDefault(r, 1, 2, type, scalar_type),
                r.scalartypeWithDefault(1, scalar_type),
@@ -530,7 +535,8 @@
                data,
                true,
                true,
-               type_inference);
+               type_inference,
+               pin_memory);
     new_tensor.detach_(); // ensure new_tensor a leaf node
     new_tensor.set_requires_grad(args_requires_grad);
     return new_tensor;
@@ -590,14 +596,14 @@ Tensor new_tensor(const Type& type, ScalarType scalar_type, PyObject* args, PyObject* kwargs)

 Tensor new_empty(const Type& type, ScalarType scalar_type, PyObject* args, PyObject* kwargs) {
   static PythonArgParser parser({
-    "new_empty(IntArrayRef size, *, ScalarType dtype=None, Device? device=None, bool requires_grad=False)",
+    "new_empty(IntArrayRef size, *, ScalarType dtype=None, Device? device=None, bool pin_memory=False, bool requires_grad=False)",
   }, /*traceable=*/true);

-  ParsedArgs<4> parsed_args;
+  ParsedArgs<5> parsed_args;
   auto r = parser.parse(args, kwargs, parsed_args);
   if (r.idx == 0) {
     const auto& actual_type = typeWithDefault(r, 1, 2, type, scalar_type);
-    return new_with_sizes(actual_type, r.deviceOptional(2), r.intlist(0)).set_requires_grad(r.toBool(3));
+    return new_with_sizes(actual_type, r.deviceOptional(2), r.intlist(0)).set_requires_grad(r.toBool(4));
   }
   throw std::runtime_error("new_empty(): invalid arguments");
 }
diff --git a/torch/onnx/symbolic.py b/torch/onnx/symbolic.py
index bb656bd..75f3062 100644
--- a/torch/onnx/symbolic.py
+++ b/torch/onnx/symbolic.py
@@ -75,6 +75,8 @@ def _parse_arg(value, desc):
         return int(tval)
     elif desc == 'f':
         return float(tval)
+    elif desc == 'b':
+        return bool(tval)
     elif desc == 't':
         return tval
     elif desc == 'is':
@@ -1310,34 +1312,44 @@ scalar_type_to_onnx = [
 ]

-@parse_args('v', 'i', 'v', 'v')
-def zeros(g, sizes, dtype, layout, device):
+@parse_args('v', 'i', 'v', 'v', 'b')
+def zeros(g, sizes, dtype, layout, device, pin_memory=False):
+    if pin_memory:
+        raise RuntimeError("onnx pin_memory support is not implemented")
     # NOTE: no way to set device and layout in ONNX, so we ignore it
     return g.op("ConstantOfShape", sizes, value_t=torch.tensor([0], dtype=scalar_type_to_pytorch_type[dtype]))

-@parse_args('v', 'i', 'v', 'v')
-def zeros_like(g, input, dtype, layout, device):
+@parse_args('v', 'i', 'v', 'v', 'b')
+def zeros_like(g, input, dtype, layout, device, pin_memory=False):
+    if pin_memory:
+        raise RuntimeError("onnx pin_memory support is not implemented")
     shape = g.op("Shape", input)
     return g.op("ConstantOfShape", shape, value_t=torch.tensor([0], dtype=scalar_type_to_pytorch_type[dtype]))

-@parse_args('v', 'i', 'v', 'v')
-def ones(g, sizes, dtype, layout, device):
+@parse_args('v', 'i', 'v', 'v', 'b')
+def ones(g, sizes, dtype, layout, device, pin_memory=False):
+    if pin_memory:
+        raise RuntimeError("onnx pin_memory support is not implemented")
     return g.op("ConstantOfShape", sizes, value_t=torch.tensor([1], dtype=scalar_type_to_pytorch_type[dtype]))

-@parse_args('v', 'i', 'v', 'v')
-def ones_like(g, input, dtype, layout, device):
+@parse_args('v', 'i', 'v', 'v', 'b')
+def ones_like(g, input, dtype, layout, device, pin_memory=False):
+    if pin_memory:
+        raise RuntimeError("onnx pin_memory support is not implemented")
     shape = g.op("Shape", input)
     return g.op("ConstantOfShape", shape, value_t=torch.tensor([1], dtype=scalar_type_to_pytorch_type[dtype]))

-def full(g, sizes, value, dtype, layout, device):
+def full(g, sizes, value, dtype, layout, device, pin_memory=False):
+    if pin_memory and _parse_arg(pin_memory, 'b'):
+        raise RuntimeError("onnx pin_memory support is not implemented")
     const_value = _maybe_get_const(value, 't')
     if _is_value(const_value):
         tmp = zeros(sizes, dtype, layout, device)
@@ -1348,8 +1360,10 @@
                     value_t=torch.tensor([const_value], dtype=scalar_type_to_pytorch_type[dtype]))

-@parse_args('v', 'f', 'i', 'v', 'v')
-def full_like(g, input, fill_value, dtype, layout, device):
+@parse_args('v', 'f', 'i', 'v', 'v', 'b')
+def full_like(g, input, fill_value, dtype, layout, device, pin_memory=False):
+    if pin_memory:
+        raise RuntimeError("onnx pin_memory support is not implemented")
     shape = g.op("Shape", input)
     return g.op("ConstantOfShape", shape, value_t=torch.tensor([fill_value], dtype=scalar_type_to_pytorch_type[dtype]))
@@ -1415,6 +1429,11 @@ def to(g, self, *args):
         dtype = _get_const(args[0], 'i', 'dtype')
         # Layout and device are ignored
         return g.op("Cast", self, to_i=scalar_type_to_onnx[dtype])
+    elif len(args) == 6:
+        # aten::to(Tensor, ScalarType, Layout, Device, bool, bool, bool) -> Tensor
+        dtype = _get_const(args[0], 'i', 'dtype')
+        # Layout and device are ignored
+        return g.op("Cast", self, to_i=scalar_type_to_onnx[dtype])
     else:
         raise NotImplementedError("Unknown aten::to signature")
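For reference, the end-to-end payoff of the new keyword is the classic overlapped host-to-device transfer; a minimal sketch, assuming a CUDA device is present:

```python
import torch

# Pinned (page-locked) host memory is what allows an asynchronous
# host-to-device copy; this is the typical consumer of the new
# constructor keyword.
host = torch.randn(3, 5, pin_memory=True)
gpu = host.to('cuda', non_blocking=True)  # returns before the copy finishes
torch.cuda.synchronize()                  # wait for the transfer to complete
```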