Adding pin_memory kwarg to zeros, ones, empty, ... tensor constructors (#18952)
authorVitaly Fedyunin <vitalyf@fb.com>
Tue, 16 Apr 2019 17:50:48 +0000 (10:50 -0700)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Tue, 16 Apr 2019 18:06:15 +0000 (11:06 -0700)
Summary:
Make it possible to construct a pinned memory tensor without creating a storage first and without calling pin_memory() function. It is also faster, as copy operation is unnecessary.

Supported functions:
```python
torch.rand_like(t, pin_memory=True)
torch.randn_like(t, pin_memory=True)
torch.empty_like(t, pin_memory=True)
torch.full_like(t, 4, pin_memory=True)
torch.zeros_like(t, pin_memory=True)
torch.ones_like(t, pin_memory=True)
torch.tensor([10,11], pin_memory=True)
torch.randn(3, 5, pin_memory=True)
torch.rand(3, pin_memory=True)
torch.zeros(3, pin_memory=True)
torch.randperm(3, pin_memory=True)
torch.empty(6, pin_memory=True)
torch.ones(6, pin_memory=True)
torch.eye(6, pin_memory=True)
torch.arange(3, 5, pin_memory=True)
```

Part of the bigger: `Remove Storage` plan.

Now compatible with both torch scripts:
 `  _1 = torch.zeros([10], dtype=6, layout=0, device=torch.device("cpu"), pin_memory=False)`
and
`  _1 = torch.zeros([10], dtype=6, layout=0, device=torch.device("cpu"))`

Same checked for all similar functions `rand_like`, `empty_like` and others

It is fixed version of #18455
Pull Request resolved: https://github.com/pytorch/pytorch/pull/18952

Differential Revision: D14801792

Pulled By: VitalyFedyunin

fbshipit-source-id: 8dbc61078ff7a637d0ecdb95d4e98f704d5450ba

18 files changed:
aten/src/ATen/native/TensorFactories.cpp
aten/src/ATen/native/cuda/TensorFactories.cu
aten/src/ATen/native/native_functions.yaml
aten/src/ATen/native/sparse/SparseTensor.cpp
aten/src/ATen/native_parse.py
test/test_torch.py
tools/autograd/gen_python_functions.py
tools/autograd/templates/python_torch_functions.cpp
tools/jit/gen_jit_dispatch.py
tools/pyi/gen_pyi.py
torch/_tensor_docs.py
torch/_torch_docs.py
torch/csrc/jit/ir.cpp
torch/csrc/jit/operator.cpp
torch/csrc/jit/passes/shape_analysis.cpp
torch/csrc/jit/tracer.cpp
torch/csrc/utils/tensor_new.cpp
torch/onnx/symbolic.py

index 96a8b04..d08676d 100644 (file)
@@ -18,6 +18,7 @@
 #include <c10/core/TensorOptions.h>
 #include <TH/THRandom.h>
 #include <TH/THGenerator.hpp>
+#include <ATen/detail/CUDAHooksInterface.h>
 #include <c10/util/Exception.h>
 
 #include <algorithm>
@@ -93,7 +94,13 @@ Tensor empty_cpu(IntArrayRef size, const TensorOptions& options) {
   AT_ASSERT(!options.is_variable());  // is_variable should have been 'unpacked'  // TODO: remove this when Variable and Tensor are merged
   check_size_nonnegative(size);
 
-  auto* allocator = at::getCPUAllocator();
+  c10::Allocator* allocator;
+  if (options.pinned_memory()) {
+    allocator = detail::getCUDAHooks().getPinnedMemoryAllocator();
+  } else {
+    allocator = at::getCPUAllocator();
+  }
+
   int64_t nelements = prod_intlist(size);
   auto dtype = options.dtype();
   auto storage_impl = c10::make_intrusive<StorageImpl>(
index 2da789a..90d0208 100644 (file)
@@ -46,6 +46,7 @@ Tensor& eye_out_cuda(Tensor& result, int64_t n, int64_t m) {
 Tensor empty_cuda(IntArrayRef size, const TensorOptions& options) {
   AT_ASSERT(options.backend() == at::Backend::CUDA);
   AT_ASSERT(!options.is_variable());  // is_variable should have been 'unpacked'  // TODO: remove this when Variable and Tensor are merged
+  AT_CHECK(!options.pinned_memory(), "Only dense CPU tensors can be pinned");
   check_size_nonnegative(size);
 
   auto* allocator = at::cuda::getCUDADeviceAllocator();
index 609e179..1f860b2 100644 (file)
@@ -46,7 +46,7 @@
   dispatch:
     CUDA: _cudnn_rnn_backward
 
-- func: _cudnn_init_dropout_state(float dropout, bool train, int dropout_seed, *, ScalarType dtype, Layout layout, Device device) -> Tensor
+- func: _cudnn_init_dropout_state(float dropout, bool train, int dropout_seed, *, ScalarType dtype, Layout layout, Device device, bool pin_memory=False) -> Tensor
   dispatch:
     CUDA: _cudnn_init_dropout_state
 
 
 - func: any(Tensor self, int dim, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!)
 
-- func: arange(Scalar end, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None) -> Tensor
+- func: arange(Scalar end, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
 
-- func: arange(Scalar start, Scalar end, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None) -> Tensor
+- func: arange(Scalar start, Scalar end, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
 
-- func: arange(Scalar start, Scalar end, Scalar step, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None) -> Tensor
+- func: arange(Scalar start, Scalar end, Scalar step, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
 
 - func: arange(Scalar end, *, Tensor(a!) out) -> Tensor(a!)
 
     CPU: baddbmm_out_cpu
     CUDA: baddbmm_out_cuda
 
-- func: bartlett_window(int window_length, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None) -> Tensor
+- func: bartlett_window(int window_length, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
 
-- func: bartlett_window(int window_length, bool periodic, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None) -> Tensor
+- func: bartlett_window(int window_length, bool periodic, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
 
 - func: batch_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float momentum, float eps, bool cudnn_enabled) -> Tensor
 
     CPU: _bincount_cpu
     CUDA: _bincount_cuda
 
-- func: blackman_window(int window_length, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None) -> Tensor
+- func: blackman_window(int window_length, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
 
-- func: blackman_window(int window_length, bool periodic, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None) -> Tensor
+- func: blackman_window(int window_length, bool periodic, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
 
 - func: bmm(Tensor self, Tensor mat2) -> Tensor
   variants: function, method
     CPU: _embedding_bag_per_sample_weights_backward_cpu
     CUDA: _embedding_bag_per_sample_weights_backward_cuda
 
-- func: empty(int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None) -> Tensor
+- func: empty(int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
   cpu_half: True
   cpu_bool: True
   cuda_bool: True
 - func: empty_like(Tensor self) -> Tensor
   device_guard: False
 
-- func: empty_like(Tensor self, *, ScalarType dtype, Layout layout, Device device) -> Tensor
+- func: empty_like(Tensor self, *, ScalarType dtype, Layout layout, Device device, bool pin_memory=False) -> Tensor
   device_guard: False
 
-- func: empty_strided(int[] size, int[] stride, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None) -> Tensor
+- func: empty_strided(int[] size, int[] stride, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
   cpu_half: True
   cpu_bool: True
   cuda_bool: True
   variants: method  # This is method-only to match the previous tensor API. In the future we could make this a function too.
   device_guard: False
 
-- func: eye(int n, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None) -> Tensor
+- func: eye(int n, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
 
-- func: eye(int n, int m, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None) -> Tensor
+- func: eye(int n, int m, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
 
 - func: eye(int n, *, Tensor(a!) out) -> Tensor(a!)
   dispatch:
     CPU: _frac_out_cpu
     CUDA: _frac_out_cuda
 
-- func: full(int[] size, Scalar fill_value, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None) -> Tensor
+- func: full(int[] size, Scalar fill_value, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
 
 - func: full(int[] size, Scalar fill_value, *, Tensor(a!) out) -> Tensor(a!)
 
 - func: full_like(Tensor self, Scalar fill_value) -> Tensor
 
-- func: full_like(Tensor self, Scalar fill_value, *, ScalarType dtype, Layout layout, Device device) -> Tensor
+- func: full_like(Tensor self, Scalar fill_value, *, ScalarType dtype, Layout layout, Device device, bool pin_memory=False) -> Tensor
 
 # NOTE [ grid_sampler Native Functions ]
 # `grid_sampler` does all the shape checking and then dispatches to one of
     CPU: grid_sampler_3d_backward_cpu
     CUDA: grid_sampler_3d_backward_cuda
 
-- func: hann_window(int window_length, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None) -> Tensor
+- func: hann_window(int window_length, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
 
-- func: hann_window(int window_length, bool periodic, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None) -> Tensor
+- func: hann_window(int window_length, bool periodic, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
 
-- func: hamming_window(int window_length, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None) -> Tensor
+- func: hamming_window(int window_length, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
 
-- func: hamming_window(int window_length, bool periodic, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None) -> Tensor
+- func: hamming_window(int window_length, bool periodic, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
 
-- func: hamming_window(int window_length, bool periodic, float alpha, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None) -> Tensor
+- func: hamming_window(int window_length, bool periodic, float alpha, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
 
-- func: hamming_window(int window_length, bool periodic, float alpha, float beta, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None) -> Tensor
+- func: hamming_window(int window_length, bool periodic, float alpha, float beta, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
 
 - func: hinge_embedding_loss(Tensor self, Tensor target, float margin=1.0, int reduction=Mean) -> Tensor
 
 
 - func: fbgemm_is_cpu_supported() -> bool
 
-- func: linspace(Scalar start, Scalar end, int steps=100, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None) -> Tensor
+- func: linspace(Scalar start, Scalar end, int steps=100, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
 
 - func: linspace(Scalar start, Scalar end, int steps=100, *, Tensor(a!) out) -> Tensor(a!)
   dispatch:
 - func: logdet(Tensor self) -> Tensor
   variants: function, method
 
-- func: logspace(Scalar start, Scalar end, int steps=100, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None) -> Tensor
+- func: logspace(Scalar start, Scalar end, int steps=100, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
 
 - func: logspace(Scalar start, Scalar end, int steps=100, *, Tensor(a!) out) -> Tensor(a!)
   dispatch:
 - func: _nnpack_spatial_convolution_backward_weight(Tensor input, int[] weightsize, Tensor grad_output, int[2] padding) -> Tensor
   variants: function
 
-- func: ones(int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None) -> Tensor
+- func: ones(int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
 
 - func: ones(int[] size, *, Tensor(a!) out) -> Tensor(a!)
 
 - func: ones_like(Tensor self) -> Tensor
 
-- func: ones_like(Tensor self, *, ScalarType dtype, Layout layout, Device device) -> Tensor
+- func: ones_like(Tensor self, *, ScalarType dtype, Layout layout, Device device, bool pin_memory=False) -> Tensor
 
 - func: pairwise_distance(Tensor x1, Tensor x2, float p=2, float eps=1e-06, bool keepdim=False) -> Tensor
 
 - func: pinverse(Tensor self, float rcond=1e-15) -> Tensor
   variants: function, method
 
-- func: scalar_tensor(Scalar s, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None) -> Tensor
+- func: scalar_tensor(Scalar s, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
 
-- func: rand(int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None) -> Tensor
+- func: rand(int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
 
-- func: rand(int[] size, *, Generator? generator, ScalarType? dtype=None, Layout? layout=None, Device? device=None) -> Tensor
+- func: rand(int[] size, *, Generator? generator, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
 
 - func: rand(int[] size, *, Tensor(a!) out) -> Tensor(a!)
 
 
 - func: rand_like(Tensor self) -> Tensor
 
-- func: rand_like(Tensor self, *, ScalarType dtype, Layout layout, Device device) -> Tensor
+- func: rand_like(Tensor self, *, ScalarType dtype, Layout layout, Device device, bool pin_memory=False) -> Tensor
 
-- func: randint(int high, int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None) -> Tensor
+- func: randint(int high, int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
 
-- func: randint(int high, int[] size, *, Generator? generator, ScalarType? dtype=None, Layout? layout=None, Device? device=None) -> Tensor
+- func: randint(int high, int[] size, *, Generator? generator, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
 
-- func: randint(int low, int high, int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None) -> Tensor
+- func: randint(int low, int high, int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
 
-- func: randint(int low, int high, int[] size, *, Generator? generator, ScalarType? dtype=None, Layout? layout=None, Device? device=None) -> Tensor
+- func: randint(int low, int high, int[] size, *, Generator? generator, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
 
 - func: randint(int high, int[] size, *, Tensor(a!) out) -> Tensor(a!)
 
 
 - func: randint_like(Tensor self, int low, int high) -> Tensor
 
-- func: randint_like(Tensor self, int high, *, ScalarType dtype, Layout layout, Device device) -> Tensor
+- func: randint_like(Tensor self, int high, *, ScalarType dtype, Layout layout, Device device, bool pin_memory=False) -> Tensor
 
-- func: randint_like(Tensor self, int low, int high, *, ScalarType dtype, Layout layout, Device device) -> Tensor
+- func: randint_like(Tensor self, int low, int high, *, ScalarType dtype, Layout layout, Device device, bool pin_memory=False) -> Tensor
 
-- func: randn(int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None) -> Tensor
+- func: randn(int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
 
-- func: randn(int[] size, *, Generator? generator, ScalarType? dtype=None, Layout? layout=None, Device? device=None) -> Tensor
+- func: randn(int[] size, *, Generator? generator, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
 
 - func: randn(int[] size, *, Tensor(a!) out) -> Tensor(a!)
 
 
 - func: randn_like(Tensor self) -> Tensor
 
-- func: randn_like(Tensor self, *, ScalarType dtype, Layout layout, Device device) -> Tensor
+- func: randn_like(Tensor self, *, ScalarType dtype, Layout layout, Device device, bool pin_memory=False) -> Tensor
 
-- func: randperm(int n, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None) -> Tensor
+- func: randperm(int n, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
 
-- func: randperm(int n, *, Generator? generator, ScalarType? dtype=None, Layout? layout=None, Device? device=None) -> Tensor
+- func: randperm(int n, *, Generator? generator, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
 
 - func: randperm(int n, *, Tensor(a!) out) -> Tensor(a!)
 
     CPU: randperm_out_cpu
     CUDA: randperm_out_cuda
 
-- func: range(Scalar start, Scalar end, Scalar step=1, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None) -> Tensor
+- func: range(Scalar start, Scalar end, Scalar step=1, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
 
-- func: range(Scalar start, Scalar end, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None) -> Tensor
+- func: range(Scalar start, Scalar end, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
 
 - func: range(Scalar start, Scalar end, Scalar step=1, *, Tensor(a!) out) -> Tensor(a!)
   dispatch:
 - func: _weight_norm_differentiable_backward(Tensor grad_w, Tensor saved_v, Tensor saved_g, Tensor saved_norms, int dim) -> (Tensor, Tensor)
   variants: function
 
-- func: zeros(int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None) -> Tensor
+- func: zeros(int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
 
 - func: zeros(int[] size, *, Tensor(a!) out) -> Tensor(a!)
 
 - func: zeros_like(Tensor self) -> Tensor
 
-- func: zeros_like(Tensor self, *, ScalarType dtype, Layout layout, Device device) -> Tensor
+- func: zeros_like(Tensor self, *, ScalarType dtype, Layout layout, Device device, bool pin_memory=False) -> Tensor
 
 - func: _standard_gamma_grad(Tensor self, Tensor output) -> Tensor
   variants: function
 
 # FIXME: would be nicer if TensorOptions was optional based; not adding default arguments for options given
 # the default would never make sense.
-- func: sparse_coo_tensor(int[] size, *, ScalarType dtype, Layout layout, Device device) -> Tensor
+- func: sparse_coo_tensor(int[] size, *, ScalarType dtype, Layout layout, Device device, bool pin_memory=False) -> Tensor
 
-- func: sparse_coo_tensor(Tensor indices, Tensor values, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None) -> Tensor
+- func: sparse_coo_tensor(Tensor indices, Tensor values, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
 
-- func: sparse_coo_tensor(Tensor indices, Tensor values, int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None) -> Tensor
+- func: sparse_coo_tensor(Tensor indices, Tensor values, int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
 
-- func: _sparse_coo_tensor_unsafe(Tensor indices, Tensor values, int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None) -> Tensor
+- func: _sparse_coo_tensor_unsafe(Tensor indices, Tensor values, int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
 
-- func: _sparse_coo_tensor_with_dims(int sparse_dim, int dense_dim, int[] size, *, ScalarType dtype, Layout layout, Device device) -> Tensor
+- func: _sparse_coo_tensor_with_dims(int sparse_dim, int dense_dim, int[] size, *, ScalarType dtype, Layout layout, Device device, bool pin_memory=False) -> Tensor
   dispatch:
     SparseCPU: new_with_dims_sparse
     SparseCUDA: new_with_dims_sparse
   requires_tensor: True
 
-- func: _sparse_coo_tensor_with_dims_and_tensors(int sparse_dim, int dense_dim, int[] size, Tensor indices, Tensor values, *, ScalarType dtype, Layout layout, Device device) -> Tensor
+- func: _sparse_coo_tensor_with_dims_and_tensors(int sparse_dim, int dense_dim, int[] size, Tensor indices, Tensor values, *, ScalarType dtype, Layout layout, Device device, bool pin_memory=False) -> Tensor
   dispatch:
     SparseCPU: new_with_dims_and_tensor_sparse
     SparseCUDA: new_with_dims_and_tensor_sparse
   requires_tensor: True
 
-
 - func: sparse_resize_(Tensor(a!) self, int[] size, int sparse_dim, int dense_dim) -> Tensor(a!)
   variants: method
   dispatch:
 # to(Device) must not exist because all constructors of Device also works for
 # TensorOptions. Otherwise, an ambiguity error is thrown.
 # See NOTE [ TensorOptions Constructors ].
-- func: to(Tensor self, *, ScalarType dtype, Layout layout, Device device, bool non_blocking=False, bool copy=False) -> Tensor
+- func: to(Tensor self, *, ScalarType dtype, Layout layout, Device device, bool pin_memory=False, bool non_blocking=False, bool copy=False) -> Tensor
   variants: method
   device_guard: False
 
index d3278c6..67f3e22 100644 (file)
@@ -108,6 +108,7 @@ SparseTensor new_with_dims_and_tensor_sparse(
 
 /** Empty init **/
 Tensor empty_sparse(IntArrayRef size, const TensorOptions& options) {
+  AT_CHECK(!options.pinned_memory(), "Only dense CPU tensors can be pinned");
   return new_with_dims_sparse(size.size(), 0, size, options);
 }
 
index 08175c3..bdc4f1a 100644 (file)
@@ -202,19 +202,15 @@ def parse_arguments(args, func_variants, declaration, func_return):
             {'name': 'dtype', 'type': 'ScalarType', 'is_nullable': False, 'annotation': None},
             {'name': 'layout', 'type': 'Layout', 'is_nullable': False, 'annotation': None},
             {'name': 'device', 'type': 'Device', 'is_nullable': False, 'annotation': None},
+            {'name': 'pin_memory', 'type': 'bool', 'is_nullable': False, 'annotation': None, 'default': False},
         ]
     ]
     supported_topt_arguments.append(copy.deepcopy(supported_topt_arguments[0]))
-    supported_topt_arguments[1][0]['kwarg_only'] = True
-    supported_topt_arguments[1][1]['kwarg_only'] = True
-    supported_topt_arguments[1][2]['kwarg_only'] = True
+    for arg in supported_topt_arguments[1]:
+        arg.update({'kwarg_only': True})
     supported_topt_arguments.append(copy.deepcopy(supported_topt_arguments[1]))
-    supported_topt_arguments[2][0]['default'] = 'c10::nullopt'
-    supported_topt_arguments[2][1]['default'] = 'c10::nullopt'
-    supported_topt_arguments[2][2]['default'] = 'c10::nullopt'
-    supported_topt_arguments[2][0]['is_nullable'] = True
-    supported_topt_arguments[2][1]['is_nullable'] = True
-    supported_topt_arguments[2][2]['is_nullable'] = True
+    for arg in supported_topt_arguments[2]:
+        arg.update({'default': 'c10::nullopt', 'is_nullable': True})
 
     corresponding_topts = [
         {'type': 'TensorOptions', 'name': 'options', 'is_nullable': False, 'annotation': None},
@@ -226,30 +222,28 @@ def parse_arguments(args, func_variants, declaration, func_return):
 
     def check_topt_representation(topt_representation):
         for idx, supported_topt in enumerate(supported_topt_arguments):
-            matches = True
-            matches = matches and topt_representation[0] == supported_topt[0]
-            matches = matches and topt_representation[1] == supported_topt[1]
-            matches = matches and topt_representation[2] == supported_topt[2]
+            matches = all(topt_representation[i] == topt for i, topt in enumerate(supported_topt))
             if matches:
                 return corresponding_topts[idx]
         return None
 
     def is_tensor_option(argument):
-        return argument['name'] in ['dtype', 'layout', 'device']
+        return argument['name'] in ['dtype', 'layout', 'device', 'pin_memory']
 
     new_arguments = []
     idx = 0
     while idx < len(arguments):
         argument = arguments[idx]
-        if is_tensor_option(argument) and len(arguments) - idx >= 3:
+        number_of_arguments = len(supported_topt_arguments[0])
+        if is_tensor_option(argument) and len(arguments) - idx >= number_of_arguments:
             topt_representation = []
-            for i in range(3):
+            for i in range(number_of_arguments):
                 argument = arguments[idx]
                 if not is_tensor_option(argument):
                     break
                 topt_representation.append(argument)
                 idx += 1
-            if len(topt_representation) == 3:
+            if len(topt_representation) == number_of_arguments:
                 merged_argument = check_topt_representation(topt_representation)
                 assert merged_argument, \
                     "Unsupported combination of TensorOptions {}, the only currently supported combinations are {}"\
index 4c44cae..b7ba9f7 100644 (file)
@@ -9984,6 +9984,40 @@ tensor([[[1., 1., 1.,  ..., 1., 1., 1.],
         self.assertEqual(pinned, x)
         self.assertNotEqual(pinned.data_ptr(), x.data_ptr())
 
+    @unittest.skipIf(not torch.cuda.is_available(), 'no CUDA')
+    def test_pin_memory_from_constructor(self):
+
+        def _get_like(t, **kwargs):
+            return [
+                torch.rand_like(t, **kwargs),
+                torch.randn_like(t, **kwargs),
+                torch.empty_like(t, **kwargs),
+                torch.full_like(t, 4, **kwargs),
+                torch.zeros_like(t, **kwargs),
+                torch.ones_like(t, **kwargs),
+            ]
+
+        def _get_tensors(**kwargs):
+            return [
+                torch.tensor([10, 11], **kwargs),
+                torch.randn(3, 5, **kwargs),
+                torch.rand(3, **kwargs),
+                # torch.randint(3, 5, **kwargs), // unsupported
+                torch.zeros(3, **kwargs),
+                torch.randperm(3, **kwargs),
+                torch.empty(6, **kwargs),
+                torch.ones(6, **kwargs),
+                torch.eye(6, **kwargs),
+                torch.arange(3, 5, **kwargs)]
+
+        pinned_tensors = _get_tensors(pin_memory=True) + _get_like(torch.empty(5, dtype=torch.float64), pin_memory=True)
+        for x in pinned_tensors:
+            self.assertTrue(x.is_pinned())
+
+        tensors = _get_tensors() + _get_like(torch.empty(5, dtype=torch.float64, pin_memory=True))
+        for x in tensors:
+            self.assertFalse(x.is_pinned())
+
     @unittest.skipIf(not TEST_NUMPY, "Numpy not found")
     def test_numpy_unresizable(self):
         x = np.zeros((2, 2))
index 2e7fe79..821058e 100644 (file)
@@ -157,7 +157,8 @@ const auto options = TensorOptions()
     .dtype(${dtype})
     .device(${device})
     .layout(${layout}.layout)
-    .requires_grad(${requires_grad});
+    .requires_grad(${requires_grad})
+    .pinned_memory(${pin_memory});
 """)
 
 
@@ -429,9 +430,9 @@ def create_python_bindings(python_functions, has_self, is_module=False):
                 arg_idx += 1
 
         if 'layout' in (a['name'] for a in python_binding_arguments):
-            layout_idx, device_idx, requires_grad_idx = (arg_idx, arg_idx + 1, arg_idx + 2)
+            layout_idx, device_idx, pin_memory_idx, requires_grad_idx = (arg_idx, arg_idx + 1, arg_idx + 2, arg_idx + 3)
         else:
-            device_idx, requires_grad_idx = (arg_idx, arg_idx + 1)
+            device_idx, pin_memory_idx, requires_grad_idx = (arg_idx, arg_idx + 1, arg_idx + 2)
 
         device = None
         for arg in python_binding_arguments:
@@ -459,9 +460,11 @@ def create_python_bindings(python_functions, has_self, is_module=False):
                     has_device_bind = True
             elif arg['name'] == 'requires_grad' and arg['simple_type'] == 'bool':
                 requires_grad = parse_arg(arg, requires_grad_idx)[0]
+            elif arg['name'] == 'pin_memory' and arg['simple_type'] == 'bool':
+                pin_memory = parse_arg(arg, pin_memory_idx)[0]
             else:
                 raise RuntimeError(("found {} in python_binding_arguments but only "
-                                    "\"bool requires_grad\", \"ScalarType dtype\", \"Layout layout\", "
+                                    "\"bool pin_memory\", \"bool requires_grad\", \"ScalarType dtype\", \"Layout layout\", "
                                     "\"Device device\" are supported".format(arg)))
 
         dtype = parsed_type_args[0] if parsed_type_args else None
@@ -470,7 +473,8 @@ def create_python_bindings(python_functions, has_self, is_module=False):
                 'dtype': dtype,
                 'layout': layout,
                 'device': device,
-                'requires_grad': requires_grad
+                'requires_grad': requires_grad,
+                'pin_memory': pin_memory,
             }))
             formal_args.append('const TensorOptions & options')
             actuals.append('options')
@@ -620,6 +624,15 @@ def create_python_bindings(python_functions, has_self, is_module=False):
                 'python_default_init': py_default_device
             }
             python_binding_arguments.append(device_arg)
+            pin_memory_arg = {
+                'default': False,
+                'dynamic_type': 'bool',
+                'kwarg_only': True,
+                'name': 'pin_memory',
+                'type': 'bool',
+                'simple_type': 'bool',
+            }
+            python_binding_arguments.append(pin_memory_arg)
         if is_factory_or_like_function:
             requires_grad_arg = {
                 'default': False,
index 9eb5bb9..8b07b9e 100644 (file)
@@ -96,11 +96,11 @@ static PyObject * THPVariable_arange(PyObject* self, PyObject* args, PyObject* k
 {
   HANDLE_TH_ERRORS
   static PythonArgParser parser({
-    "arange(Scalar end, *, Tensor out=None, ScalarType dtype=None, Layout layout=torch.strided, Device device=None, bool requires_grad=False)",
-    "arange(Scalar start, Scalar end, Scalar step=1, *, Tensor out=None, ScalarType dtype=None, Layout layout=torch.strided, Device device=None, bool requires_grad=False)",
+    "arange(Scalar end, *, Tensor out=None, ScalarType dtype=None, Layout layout=torch.strided, Device device=None, bool pin_memory=False, bool requires_grad=False)",
+    "arange(Scalar start, Scalar end, Scalar step=1, *, Tensor out=None, ScalarType dtype=None, Layout layout=torch.strided, Device device=None, bool pin_memory=False, bool requires_grad=False)",
   });
 
-  ParsedArgs<8> parsed_args;
+  ParsedArgs<9> parsed_args;
   auto r = parser.parse(args, kwargs, parsed_args);
 
   if (r.idx == 0) {
@@ -112,12 +112,14 @@ static PyObject * THPVariable_arange(PyObject* self, PyObject* args, PyObject* k
           .dtype(scalarType)
           .device(r.device(4))
           .layout(r.layout(3).layout)
-          .requires_grad(r.toBool(5));
+          .requires_grad(r.toBool(6))
+          .pinned_memory(r.toBool(5));
       return wrap(dispatch_arange(end, options));
     } else {
+      AT_CHECK(!r.toBool(5), " `pin_memory` and `out` parameters are incompatible");
       check_out_type_matches(r.tensor(1), r.scalartype(2), r.isNone(2), r.layout(3), r.isNone(3),
                              r.device(4), r.isNone(4));
-      return wrap(dispatch_arange(r.scalar(0), r.tensor(1)).set_requires_grad(r.toBool(5)));
+      return wrap(dispatch_arange(r.scalar(0), r.tensor(1)).set_requires_grad(r.toBool(6)));
     }
   } else if (r.idx == 1) {
     if (r.isNone(3)) {
@@ -130,12 +132,14 @@ static PyObject * THPVariable_arange(PyObject* self, PyObject* args, PyObject* k
           .dtype(scalarType)
           .device(r.device(6))
           .layout(r.layout(5).layout)
-          .requires_grad(r.toBool(7));
+          .requires_grad(r.toBool(8))
+          .pinned_memory(r.toBool(7));
       return wrap(dispatch_arange(start, end, step, options));
     } else {
+      AT_CHECK(!r.toBool(7), " `pin_memory` and `out` parameters are incompatible");
       check_out_type_matches(r.tensor(3), r.scalartype(4), r.isNone(4), r.layout(5), r.isNone(5),
                                r.device(6), r.isNone(6));
-      return wrap(dispatch_arange(r.scalar(0), r.scalar(1), r.scalar(2), r.tensor(3)).set_requires_grad(r.toBool(7)));
+      return wrap(dispatch_arange(r.scalar(0), r.scalar(1), r.scalar(2), r.tensor(3)).set_requires_grad(r.toBool(8)));
     }
   }
   Py_RETURN_NONE;
index 14b0843..483fa5f 100644 (file)
@@ -59,6 +59,7 @@ TYPE_MAP = {
     'int64_t?': 'int?',
     'double': 'float',
     'bool': 'bool',
+    'bool?': 'bool?',
     'Generator': 'Generator?',
 }
 
@@ -104,6 +105,7 @@ FROM_IVALUE = {
     'Tensor?[]': 'toListOfOptionalTensor({})',
     'TensorList': '{}.toTensorList()->elements()',
     'bool': '{}.toBool()',
+    'bool?': '{}.toOptional<bool>()',
     'double': '{}.toDouble()',
     'int64_t': '{}.toInt()',
     'int64_t?': '{}.toOptional<int64_t>()',
@@ -134,14 +136,16 @@ CALL_NAMESPACE_WITH_TENSOR_OPTIONS = CodeTemplate("""\
 const auto options = TensorOptions()
         .dtype(${dtype})
         .layout(${layout})
-        .device(${device});
+        .device(${device})
+        .pinned_memory(${pin_memory});
 auto result_ = torch::${name}(${args_with_tensor_options});
 """)
 CALL_METHOD_WITH_TENSOR_OPTIONS = CodeTemplate("""\
 const auto options = TensorOptions()
         .dtype(${dtype})
         .layout(${layout})
-        .device(${device});
+        .device(${device})
+        .pinned_memory(${pin_memory});;
 auto result_ = (${first}).${name}(${args_with_tensor_options});
 """)
 
@@ -242,15 +246,18 @@ def gen_jit_dispatch(declarations, out, template_path):
             dtype = args[tensor_options_arg_index]
             layout = args[tensor_options_arg_index + 1]
             device = args[tensor_options_arg_index + 2]
+            pin_memory = args[tensor_options_arg_index + 3]
             args_with_tensor_options = args[:tensor_options_arg_index] + \
-                ['options'] + args[(tensor_options_arg_index + 3):]
+                ['options'] + args[(tensor_options_arg_index + 4):]
             if is_namespace_function:
                 return CALL_NAMESPACE_WITH_TENSOR_OPTIONS.substitute(
-                    name=decl['name'], dtype=dtype, layout=layout, device=device,
+                    name=decl['name'], dtype=dtype, layout=layout,
+                    device=device, pin_memory=pin_memory,
                     args_with_tensor_options=pack_arguments(args_with_tensor_options))
             else:
                 return CALL_METHOD_WITH_TENSOR_OPTIONS.substitute(
-                    name=decl['name'], dtype=dtype, layout=layout, device=device,
+                    name=decl['name'], dtype=dtype, layout=layout,
+                    device=device, pin_memory=pin_memory,
                     args_with_tensor_options=pack_arguments(args_with_tensor_options[1:]),
                     first=args_with_tensor_options[0], num_inputs=num_inputs)
         else:
@@ -349,21 +356,19 @@ def gen_jit_dispatch(declarations, out, template_path):
             {'name': 'layout', 'simple_type': 'Layout'},
             # device is specified as an IntArrayRef of { at::Device::Type, device_id }
             {'name': 'device', 'simple_type': 'Device'},
+            # pin_memory is specified as a boolean
+            {'name': 'pin_memory', 'simple_type': 'bool', 'default': False},
         ]
         # TODO: Don't repack this into TensorOptions. Needs various changes in downstream code.
         if 'default' in arg:
-            tensor_options_expansion[0]['simple_type'] += '?'
-            tensor_options_expansion[1]['simple_type'] += '?'
-            tensor_options_expansion[2]['simple_type'] += '?'
-            tensor_options_expansion[0]['default'] = 'None'
-            tensor_options_expansion[1]['default'] = 'None'
-            tensor_options_expansion[2]['default'] = 'None'
+            for el in tensor_options_expansion:
+                el['simple_type'] += '?'
+                el['default'] = 'None'
         if 'default' in arg and arg['default'] == 'at::kLong':
             tensor_options_expansion[0]['default'] = 'long'
         if 'kwarg_only' in arg and arg['kwarg_only']:
-            tensor_options_expansion[0]['kwarg_only'] = True
-            tensor_options_expansion[1]['kwarg_only'] = True
-            tensor_options_expansion[2]['kwarg_only'] = True
+            for el in tensor_options_expansion:
+                el['kwarg_only'] = True
         return tensor_options_expansion
 
     additional_jit_decls = []
index 0d5ad32..5c7eb4a 100644 (file)
@@ -240,7 +240,11 @@ def generate_type_hints(fname, decls, is_tensor=False):
                 if a.get('kwarg_only', False) and render_kw_only_separator:
                     python_args.append('*')
                     render_kw_only_separator = False
-                python_args.append(arg_to_type_hint(a))
+                try:
+                    python_args.append(arg_to_type_hint(a))
+                except Exception:
+                    print("Error while processing function {}".format(fname))
+                    raise
 
         if is_tensor:
             if 'self: Tensor' in python_args:
index 262998c..948b7c1 100644 (file)
@@ -17,6 +17,8 @@ new_common_args = parse_kwargs("""
         Default: if None, same :class:`torch.device` as this tensor.
     requires_grad (bool, optional): If autograd should record operations on the
         returned tensor. Default: ``False``.
+    pin_memory (bool, optional): If set, returned tensor would be allocated in
+        the pinned memory. Works only for CPU tensors. Default: ``False``.
 """)
 
 add_docstr_all('new_tensor',
index 98512af..703b189 100644 (file)
@@ -63,6 +63,8 @@ factory_common_args = parse_kwargs("""
         for CPU tensor types and the current CUDA device for CUDA tensor types.
     requires_grad (bool, optional): If autograd should record operations on the
         returned tensor. Default: ``False``.
+    pin_memory (bool, optional): If set, returned tensor would be allocated in
+        the pinned memory. Works only for CPU tensors. Default: ``False``.
 """)
 
 factory_like_common_args = parse_kwargs("""
@@ -75,6 +77,8 @@ factory_like_common_args = parse_kwargs("""
         Default: if ``None``, defaults to the device of :attr:`input`.
     requires_grad (bool, optional): If autograd should record operations on the
         returned tensor. Default: ``False``.
+    pin_memory (bool, optional): If set, returned tensor would be allocated in
+        the pinned memory. Works only for CPU tensors. Default: ``False``.
 """)
 
 factory_data_common_args = parse_kwargs("""
@@ -88,6 +92,8 @@ factory_data_common_args = parse_kwargs("""
         for CPU tensor types and the current CUDA device for CUDA tensor types.
     requires_grad (bool, optional): If autograd should record operations on the
         returned tensor. Default: ``False``.
+    pin_memory (bool, optional): If set, returned tensor would be allocated in
+        the pinned memory. Works only for CPU tensors. Default: ``False``.
 """)
 
 add_docstr(torch.abs,
@@ -3990,7 +3996,7 @@ Example::
 
 add_docstr(torch.tensor,
            r"""
-tensor(data, dtype=None, device=None, requires_grad=False) -> Tensor
+tensor(data, dtype=None, device=None, requires_grad=False, pin_memory=False) -> Tensor
 
 Constructs a tensor with :attr:`data`.
 
@@ -4014,6 +4020,7 @@ Args:
     {dtype}
     {device}
     {requires_grad}
+    {pin_memory}
 
 
 Example::
@@ -5553,7 +5560,7 @@ Example::
 
 add_docstr(torch.empty,
            r"""
-empty(*sizes, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False) -> Tensor
+empty(*sizes, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False, pin_memory=False) -> Tensor
 
 Returns a tensor filled with uninitialized data. The shape of the tensor is
 defined by the variable argument :attr:`sizes`.
@@ -5566,6 +5573,7 @@ Args:
     {layout}
     {device}
     {requires_grad}
+    {pin_memory}
 
 Example::
 
index 05fe6ff..e156aa0 100644 (file)
@@ -800,19 +800,19 @@ bool Node::isNondeterministic() const {
       "aten::poisson(Tensor self, Generator? generator) -> Tensor",
       "aten::rrelu(Tensor self, Scalar lower, Scalar upper, bool training, Generator? generator) -> Tensor",
       "aten::rrelu_with_noise(Tensor self, Tensor noise, Scalar lower, Scalar upper, bool training, Generator? generator) -> Tensor",
-      "aten::rand(int[] size, *, int? dtype, int? layout, Device? device) -> Tensor",
+      "aten::rand(int[] size, *, int? dtype, int? layout, Device? device, bool? pin_memory) -> Tensor",
       "aten::rand_like(Tensor self) -> Tensor",
-      "aten::rand_like(Tensor self, *, int dtype, int layout, Device device) -> Tensor",
-      "aten::randint(int high, int[] size, *, int? dtype, int? layout, Device? device) -> Tensor",
-      "aten::randint(int low, int high, int[] size, *, int? dtype, int? layout, Device? device) -> Tensor",
+      "aten::rand_like(Tensor self, *, int dtype, int layout, Device device, bool pin_memory) -> Tensor",
+      "aten::randint(int high, int[] size, *, int? dtype, int? layout, Device? device, bool? pin_memory) -> Tensor",
+      "aten::randint(int low, int high, int[] size, *, int? dtype, int? layout, Device? device, bool? pin_memory) -> Tensor",
       "aten::randint_like(Tensor self, int high) -> Tensor",
       "aten::randint_like(Tensor self, int low, int high) -> Tensor",
-      "aten::randint_like(Tensor self, int high, *, int dtype, int layout, Device device) -> Tensor",
-      "aten::randint_like(Tensor self, int low, int high, *, int dtype, int layout, Device device) -> Tensor",
-      "aten::randn(int[] size, *, int? dtype, int? layout, Device? device) -> Tensor",
+      "aten::randint_like(Tensor self, int high, *, int dtype, int layout, Device device, bool pin_memory) -> Tensor",
+      "aten::randint_like(Tensor self, int low, int high, *, int dtype, int layout, Device device, bool pin_memory) -> Tensor",
+      "aten::randn(int[] size, *, int? dtype, int? layout, Device? device, bool? pin_memory) -> Tensor",
       "aten::randn_like(Tensor self) -> Tensor",
-      "aten::randn_like(Tensor self, *, int dtype, int layout, Device device) -> Tensor",
-      "aten::randperm(int n, *, int? dtype, int? layout, Device? device) -> Tensor"};
+      "aten::randn_like(Tensor self, *, int dtype, int layout, Device device, bool pin_memory) -> Tensor",
+      "aten::randperm(int n, *, int? dtype, int? layout, Device? device, bool? pin_memory) -> Tensor"};
 
   if (nondeterministic_ops.find(this) == nullptr) {
     return false;
index 198b3ed..f8fbdc9 100644 (file)
@@ -387,7 +387,6 @@ void registerOperator(Operator&& op) {
           ". File a bug to add a case for this operator.\n");
     }
   }
-
   getRegistry().registerOperator(std::move(op));
 }
 
@@ -467,8 +466,6 @@ bool Operator::matches(const Node* node) const {
 
   // too many inputs
   if (!schema().is_vararg() && actuals.size() != formals.size()) {
-    // std::cout << "not all inputs used\n" << input_i << " " << inputs_size <<
-    // "\n";
     return false;
   }
 
index b1495d0..c221dd9 100644 (file)
@@ -1197,14 +1197,14 @@ class ShapePropagator {
     //   - has ScalarType dtype, Layeout layout and Device device arguments
     static const register_formula_for like_factories_with_options{
         {
-            "aten::empty_like(Tensor self, *, int dtype, int layout, Device device) -> Tensor",
-            "aten::full_like(Tensor self, Scalar fill_value, *, int dtype, int layout, Device device) -> Tensor",
-            "aten::ones_like(Tensor self, *, int dtype, int layout, Device device) -> Tensor",
-            "aten::rand_like(Tensor self, *, int dtype, int layout, Device device) -> Tensor",
-            "aten::randint_like(Tensor self, int high, *, int dtype, int layout, Device device) -> Tensor",
-            "aten::randint_like(Tensor self, int low, int high, *, int dtype, int layout, Device device) -> Tensor",
-            "aten::randn_like(Tensor self, *, int dtype, int layout, Device device) -> Tensor",
-            "aten::zeros_like(Tensor self, *, int dtype, int layout, Device device) -> Tensor",
+            "aten::empty_like(Tensor self, *, int dtype, int layout, Device device, bool pin_memory) -> Tensor",
+            "aten::full_like(Tensor self, Scalar fill_value, *, int dtype, int layout, Device device, bool pin_memory) -> Tensor",
+            "aten::ones_like(Tensor self, *, int dtype, int layout, Device device, bool pin_memory) -> Tensor",
+            "aten::rand_like(Tensor self, *, int dtype, int layout, Device device, bool pin_memory) -> Tensor",
+            "aten::randint_like(Tensor self, int high, *, int dtype, int layout, Device device, bool pin_memory) -> Tensor",
+            "aten::randint_like(Tensor self, int low, int high, *, int dtype, int layout, Device device, bool pin_memory) -> Tensor",
+            "aten::randn_like(Tensor self, *, int dtype, int layout, Device device, bool pin_memory) -> Tensor",
+            "aten::zeros_like(Tensor self, *, int dtype, int layout, Device device, bool pin_memory) -> Tensor",
         },
         [](Node* node) -> type_vec_t {
           if (auto type = node->namedInput(attr::self)
@@ -1226,14 +1226,14 @@ class ShapePropagator {
     //   arguments
     static const register_formula_for size_factories_with_options{
         {
-            "aten::empty(int[] size, *, int? dtype, int? layout, Device? device) -> Tensor",
-            "aten::full(int[] size, Scalar fill_value, *, int? dtype, int? layout, Device? device) -> Tensor",
-            "aten::ones(int[] size, *, int? dtype, int? layout, Device? device) -> Tensor",
-            "aten::rand(int[] size, *, int? dtype, int? layout, Device? device) -> Tensor",
-            "aten::randn(int[] size, *, int? dtype, int? layout, Device? device) -> Tensor",
-            "aten::zeros(int[] size, *, int? dtype, int? layout, Device? device) -> Tensor",
-            "aten::randint(int high, int[] size, *, int? dtype, int? layout, Device? device) -> Tensor",
-            "aten::randint(int low, int high, int[] size, *, int? dtype, int? layout, Device? device) -> Tensor",
+            "aten::empty(int[] size, *, int? dtype, int? layout, Device? device, bool? pin_memory) -> Tensor",
+            "aten::full(int[] size, Scalar fill_value, *, int? dtype, int? layout, Device? device, bool? pin_memory) -> Tensor",
+            "aten::ones(int[] size, *, int? dtype, int? layout, Device? device, bool? pin_memory) -> Tensor",
+            "aten::rand(int[] size, *, int? dtype, int? layout, Device? device, bool? pin_memory) -> Tensor",
+            "aten::randn(int[] size, *, int? dtype, int? layout, Device? device, bool? pin_memory) -> Tensor",
+            "aten::zeros(int[] size, *, int? dtype, int? layout, Device? device, bool? pin_memory) -> Tensor",
+            "aten::randint(int high, int[] size, *, int? dtype, int? layout, Device? device, bool? pin_memory) -> Tensor",
+            "aten::randint(int low, int high, int[] size, *, int? dtype, int? layout, Device? device, bool? pin_memory) -> Tensor",
         },
         [](Node* node) -> type_vec_t {
           if (auto maybe_size = node->get<std::vector<int64_t>>(attr::size)) {
index 1604faf..4680ae2 100644 (file)
@@ -408,6 +408,7 @@ void addInputs(Node* n, const char* name, const at::TensorOptions& options) {
   addInputs(n, name, at::typeMetaToScalarType(options.dtype()));
   addInputs(n, name, options.layout());
   addInputs(n, name, options.device());
+  addInputs(n, name, options.pinned_memory());
 }
 
 void addInputs(Node* n, const char* name, at::IntArrayRef value) {
index a538ffc..908a27f 100644 (file)
@@ -194,12 +194,15 @@ Tensor internal_new_from_data(
     PyObject* data,
     bool copy_variables,
     bool copy_numpy,
-    bool type_inference) {
+    bool type_inference,
+    bool pin_memory = false) {
+
   if (THPUtils_checkString(data)) {
     throw TypeError("new(): invalid data type '%s'", Py_TYPE(data)->tp_name);
   }
 
   if (THPVariable_Check(data)) {
+    AT_CHECK(!pin_memory, "Can't pin tensor constructed from a variable");
     auto var = reinterpret_cast<THPVariable*>(data)->cdata;
     if (copy_variables) {
       var = var.detach();
@@ -215,6 +218,7 @@ Tensor internal_new_from_data(
 
 #ifdef USE_NUMPY
   if (PyArray_Check(data)) {
+    AT_CHECK(!pin_memory, "Can't pin tensor constructed from numpy");
     auto tensor = autograd::make_variable(tensor_from_numpy(data), /*requires_grad=*/false);
     const auto& inferred_scalar_type = type_inference ? tensor.scalar_type() : scalar_type;
     auto device = device_opt.has_value() ? *device_opt : at::Device(type.device_type());
@@ -226,7 +230,7 @@ Tensor internal_new_from_data(
 
   auto sizes = compute_sizes(data);
   ScalarType inferred_scalar_type = type_inference ? infer_scalar_type(data) : scalar_type;
-  auto tensor = autograd::make_variable(at::empty(sizes, at::initialTensorOptions().dtype(inferred_scalar_type)), /*requires_grad=*/false);
+  auto tensor = autograd::make_variable(at::empty(sizes, at::initialTensorOptions().dtype(inferred_scalar_type).pinned_memory(pin_memory)), /*requires_grad=*/false);
   recursive_store(
       (char*)tensor.data_ptr(), tensor.sizes(), tensor.strides(), 0,
       inferred_scalar_type, tensor.dtype().itemsize(), data);
@@ -508,10 +512,10 @@ Tensor sparse_coo_tensor_ctor(const Type& default_type, ScalarType scalar_type,
 
 Tensor tensor_ctor(const Type& type, ScalarType scalar_type, PyObject* args, PyObject* kwargs) {
   static PythonArgParser parser({
-    "tensor(PyObject* data, *, ScalarType dtype=None, Device? device=None, bool requires_grad=False)",
+    "tensor(PyObject* data, *, ScalarType dtype=None, Device? device=None, bool pin_memory=False, bool requires_grad=False)",
   });
 
-  ParsedArgs<4> parsed_args;
+  ParsedArgs<5> parsed_args;
   auto r = parser.parse(args, kwargs, parsed_args);
   if (r.idx == 0) {
     PyObject* data = r.pyobject(0);
@@ -522,7 +526,8 @@ Tensor tensor_ctor(const Type& type, ScalarType scalar_type, PyObject* args, PyO
     }
 
     bool type_inference = r.isNone(1);
-    bool args_requires_grad = r.toBool(3);
+    bool pin_memory = r.toBool(3);
+    bool args_requires_grad = r.toBool(4);
     auto new_tensor = internal_new_from_data(
                typeWithDefault(r, 1, 2, type, scalar_type),
                r.scalartypeWithDefault(1, scalar_type),
@@ -530,7 +535,8 @@ Tensor tensor_ctor(const Type& type, ScalarType scalar_type, PyObject* args, PyO
                data,
                true,
                true,
-               type_inference);
+               type_inference,
+               pin_memory);
     new_tensor.detach_(); // ensure new_tensor a leaf node
     new_tensor.set_requires_grad(args_requires_grad);
     return new_tensor;
@@ -590,14 +596,14 @@ Tensor new_tensor(const Type& type, ScalarType scalar_type, PyObject* args, PyOb
 
 Tensor new_empty(const Type& type, ScalarType scalar_type, PyObject* args, PyObject* kwargs) {
   static PythonArgParser parser({
-    "new_empty(IntArrayRef size, *, ScalarType dtype=None, Device? device=None, bool requires_grad=False)",
+    "new_empty(IntArrayRef size, *, ScalarType dtype=None, Device? device=None, bool pin_memory=False, bool requires_grad=False)",
   }, /*traceable=*/true);
 
-  ParsedArgs<4> parsed_args;
+  ParsedArgs<5> parsed_args;
   auto r = parser.parse(args, kwargs, parsed_args);
   if (r.idx == 0) {
     const auto& actual_type = typeWithDefault(r, 1, 2, type, scalar_type);
-    return new_with_sizes(actual_type, r.deviceOptional(2), r.intlist(0)).set_requires_grad(r.toBool(3));
+    return new_with_sizes(actual_type, r.deviceOptional(2), r.intlist(0)).set_requires_grad(r.toBool(4));
   }
   throw std::runtime_error("new_empty(): invalid arguments");
 }
index bb656bd..75f3062 100644 (file)
@@ -75,6 +75,8 @@ def _parse_arg(value, desc):
             return int(tval)
         elif desc == 'f':
             return float(tval)
+        elif desc == 'b':
+            return bool(tval)
         elif desc == 't':
             return tval
         elif desc == 'is':
@@ -1310,34 +1312,44 @@ scalar_type_to_onnx = [
 ]
 
 
-@parse_args('v', 'i', 'v', 'v')
-def zeros(g, sizes, dtype, layout, device):
+@parse_args('v', 'i', 'v', 'v', 'b')
+def zeros(g, sizes, dtype, layout, device, pin_memory=False):
+    if pin_memory:
+        raise RuntimeError("onnx pin_memory support is not implemented")
     # NOTE: no way to set device and layout in ONNX, so we ignore it
     return g.op("ConstantOfShape", sizes,
                 value_t=torch.tensor([0], dtype=scalar_type_to_pytorch_type[dtype]))
 
 
-@parse_args('v', 'i', 'v', 'v')
-def zeros_like(g, input, dtype, layout, device):
+@parse_args('v', 'i', 'v', 'v', 'b')
+def zeros_like(g, input, dtype, layout, device, pin_memory=False):
+    if pin_memory:
+        raise RuntimeError("onnx pin_memory support is not implemented")
     shape = g.op("Shape", input)
     return g.op("ConstantOfShape", shape,
                 value_t=torch.tensor([0], dtype=scalar_type_to_pytorch_type[dtype]))
 
 
-@parse_args('v', 'i', 'v', 'v')
-def ones(g, sizes, dtype, layout, device):
+@parse_args('v', 'i', 'v', 'v', 'b')
+def ones(g, sizes, dtype, layout, device, pin_memory=False):
+    if pin_memory:
+        raise RuntimeError("onnx pin_memory support is not implemented")
     return g.op("ConstantOfShape", sizes,
                 value_t=torch.tensor([1], dtype=scalar_type_to_pytorch_type[dtype]))
 
 
-@parse_args('v', 'i', 'v', 'v')
-def ones_like(g, input, dtype, layout, device):
+@parse_args('v', 'i', 'v', 'v', 'b')
+def ones_like(g, input, dtype, layout, device, pin_memory=False):
+    if pin_memory:
+        raise RuntimeError("onnx pin_memory support is not implemented")
     shape = g.op("Shape", input)
     return g.op("ConstantOfShape", shape,
                 value_t=torch.tensor([1], dtype=scalar_type_to_pytorch_type[dtype]))
 
 
-def full(g, sizes, value, dtype, layout, device):
+def full(g, sizes, value, dtype, layout, device, pin_memory=False):
+    if pin_memory and _parse_arg(pin_memory,'b'):
+        raise RuntimeError("onnx pin_memory support is not implemented")
     const_value = _maybe_get_const(value, 't')
     if _is_value(const_value):
         tmp = zeros(sizes, dtype, layout, device)
@@ -1348,8 +1360,10 @@ def full(g, sizes, value, dtype, layout, device):
                     value_t=torch.tensor([const_value], dtype=scalar_type_to_pytorch_type[dtype]))
 
 
-@parse_args('v', 'f', 'i', 'v', 'v')
-def full_like(g, input, fill_value, dtype, layout, device):
+@parse_args('v', 'f', 'i', 'v', 'v', 'b')
+def full_like(g, input, fill_value, dtype, layout, device, pin_memory=False):
+    if pin_memory:
+        raise RuntimeError("onnx pin_memory support is not implemented")
     shape = g.op("Shape", input)
     return g.op("ConstantOfShape", shape,
                 value_t=torch.tensor([fill_value], dtype=scalar_type_to_pytorch_type[dtype]))
@@ -1415,6 +1429,11 @@ def to(g, self, *args):
         dtype = _get_const(args[0], 'i', 'dtype')
         # Layout and device are ignored
         return g.op("Cast", self, to_i=scalar_type_to_onnx[dtype])
+    elif len(args) == 6:
+        # aten::to(Tensor, ScalarType, Layout, Device, bool, bool, bool) -> Tensor
+        dtype = _get_const(args[0], 'i', 'dtype')
+        # Layout and device are ignored
+        return g.op("Cast", self, to_i=scalar_type_to_onnx[dtype])
     else:
         raise NotImplementedError("Unknown aten::to signature")