Remove python_default_init from ATen and use Optional (#15234)
authorWanchao Liang <wanchaol@users.noreply.github.com>
Thu, 20 Dec 2018 05:35:01 +0000 (21:35 -0800)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Thu, 20 Dec 2018 05:38:50 +0000 (21:38 -0800)
Summary:
Optional clean up. This PR remove python_default_init from the yaml files, and the code-gen, and utilize optional type to do the work.

This also fix the bug in the #13149 to correctly adopt as_strided backward.

Fixes #9941
Pull Request resolved: https://github.com/pytorch/pytorch/pull/15234

Differential Revision: D13502044

Pulled By: wanchaol

fbshipit-source-id: 774b61fc4414482cf11d56e22bd0275aefb352a4

18 files changed:
aten/src/ATen/Declarations.cwrap
aten/src/ATen/core/Tensor.h
aten/src/ATen/core/TensorMethods.h
aten/src/ATen/core/Type.h
aten/src/ATen/function_wrapper.py
aten/src/ATen/native/ReduceOps.cpp
aten/src/ATen/native/SpectralOps.cpp
aten/src/ATen/native/TensorShape.cpp
aten/src/ATen/native/native_functions.yaml
aten/src/ATen/native_parse.py
tools/autograd/derivatives.yaml
tools/autograd/gen_python_functions.py
tools/autograd/templates/Functions.cpp
tools/jit/gen_jit_dispatch.py
torch/csrc/jit/passes/shape_analysis.cpp
torch/csrc/jit/tracer.cpp
torch/csrc/jit/tracer.h
torch/csrc/utils/python_arg_parser.h

index 252cc96..4975dc4 100644 (file)
         - arg: THTensor* result
           output: True
         - THTensor* self
-        - arg: real p
-          python_default_init: AS_REAL(2)
+        - real p
         - arg: long dim
           wrap_dim: self
         - arg: bool keepdim
index c8d03e0..936a881 100644 (file)
@@ -298,10 +298,8 @@ public:
   Tensor argmax() const;
   Tensor argmin(int64_t dim, bool keepdim=false) const;
   Tensor argmin() const;
-  Tensor as_strided(IntList size, IntList stride) const;
-  Tensor & as_strided_(IntList size, IntList stride);
-  Tensor as_strided(IntList size, IntList stride, int64_t storage_offset) const;
-  Tensor & as_strided_(IntList size, IntList stride, int64_t storage_offset);
+  Tensor as_strided(IntList size, IntList stride, c10::optional<int64_t> storage_offset=c10::nullopt) const;
+  Tensor & as_strided_(IntList size, IntList stride, c10::optional<int64_t> storage_offset=c10::nullopt);
   Tensor asin() const;
   Tensor & asin_();
   Tensor atan() const;
@@ -449,7 +447,7 @@ public:
   Tensor & squeeze_();
   Tensor & squeeze_(int64_t dim);
   Tensor sspaddmm(const Tensor & mat1, const Tensor & mat2, Scalar beta=1, Scalar alpha=1) const;
-  Tensor stft(int64_t n_fft, int64_t hop_length, int64_t win_length, const Tensor & window={}, bool normalized=false, bool onesided=true) const;
+  Tensor stft(int64_t n_fft, c10::optional<int64_t> hop_length=c10::nullopt, c10::optional<int64_t> win_length=c10::nullopt, const Tensor & window={}, bool normalized=false, bool onesided=true) const;
   int64_t stride(int64_t dim) const;
   Tensor sum(ScalarType dtype) const;
   Tensor sum() const;
@@ -487,7 +485,7 @@ public:
   Tensor view_as(const Tensor & other) const;
   Tensor where(const Tensor & condition, const Tensor & other) const;
   Tensor norm(Scalar p=2) const;
-  Tensor norm(Scalar p, int64_t dim, bool keepdim=false) const;
+  Tensor norm(c10::optional<Scalar> p, int64_t dim, bool keepdim=false) const;
   Tensor clone() const;
   Tensor & resize_as_(const Tensor & the_template);
   Tensor pow(Scalar exponent) const;
index 17b74bc..b44ca0f 100644 (file)
@@ -115,16 +115,10 @@ inline Tensor Tensor::argmin(int64_t dim, bool keepdim) const {
 inline Tensor Tensor::argmin() const {
     return type().argmin(*this);
 }
-inline Tensor Tensor::as_strided(IntList size, IntList stride) const {
-    return type().as_strided(*this, size, stride);
-}
-inline Tensor & Tensor::as_strided_(IntList size, IntList stride) {
-    return type().as_strided_(*this, size, stride);
-}
-inline Tensor Tensor::as_strided(IntList size, IntList stride, int64_t storage_offset) const {
+inline Tensor Tensor::as_strided(IntList size, IntList stride, c10::optional<int64_t> storage_offset) const {
     return type().as_strided(*this, size, stride, storage_offset);
 }
-inline Tensor & Tensor::as_strided_(IntList size, IntList stride, int64_t storage_offset) {
+inline Tensor & Tensor::as_strided_(IntList size, IntList stride, c10::optional<int64_t> storage_offset) {
     return type().as_strided_(*this, size, stride, storage_offset);
 }
 inline Tensor Tensor::asin() const {
@@ -568,7 +562,7 @@ inline Tensor & Tensor::squeeze_(int64_t dim) {
 inline Tensor Tensor::sspaddmm(const Tensor & mat1, const Tensor & mat2, Scalar beta, Scalar alpha) const {
     return type().sspaddmm(*this, mat1, mat2, beta, alpha);
 }
-inline Tensor Tensor::stft(int64_t n_fft, int64_t hop_length, int64_t win_length, const Tensor & window, bool normalized, bool onesided) const {
+inline Tensor Tensor::stft(int64_t n_fft, c10::optional<int64_t> hop_length, c10::optional<int64_t> win_length, const Tensor & window, bool normalized, bool onesided) const {
     return type().stft(*this, n_fft, hop_length, win_length, window, normalized, onesided);
 }
 inline int64_t Tensor::stride(int64_t dim) const {
@@ -682,7 +676,7 @@ inline Tensor Tensor::where(const Tensor & condition, const Tensor & other) cons
 inline Tensor Tensor::norm(Scalar p) const {
     return type().norm(*this, p);
 }
-inline Tensor Tensor::norm(Scalar p, int64_t dim, bool keepdim) const {
+inline Tensor Tensor::norm(c10::optional<Scalar> p, int64_t dim, bool keepdim) const {
     return type().norm(*this, p, dim, keepdim);
 }
 inline Tensor Tensor::clone() const {
index 55667e0..df41133 100644 (file)
@@ -205,10 +205,8 @@ struct CAFFE2_API Type {
   virtual Tensor argmax(const Tensor & self) const = 0;
   virtual Tensor argmin(const Tensor & self, int64_t dim, bool keepdim) const = 0;
   virtual Tensor argmin(const Tensor & self) const = 0;
-  virtual Tensor as_strided(const Tensor & self, IntList size, IntList stride) const = 0;
-  virtual Tensor & as_strided_(Tensor & self, IntList size, IntList stride) const = 0;
-  virtual Tensor as_strided(const Tensor & self, IntList size, IntList stride, int64_t storage_offset) const = 0;
-  virtual Tensor & as_strided_(Tensor & self, IntList size, IntList stride, int64_t storage_offset) const = 0;
+  virtual Tensor as_strided(const Tensor & self, IntList size, IntList stride, c10::optional<int64_t> storage_offset) const = 0;
+  virtual Tensor & as_strided_(Tensor & self, IntList size, IntList stride, c10::optional<int64_t> storage_offset) const = 0;
   virtual Tensor asin(const Tensor & self) const = 0;
   virtual Tensor & asin_(Tensor & self) const = 0;
   virtual Tensor atan(const Tensor & self) const = 0;
@@ -356,7 +354,7 @@ struct CAFFE2_API Type {
   virtual Tensor & squeeze_(Tensor & self) const = 0;
   virtual Tensor & squeeze_(Tensor & self, int64_t dim) const = 0;
   virtual Tensor sspaddmm(const Tensor & self, const Tensor & mat1, const Tensor & mat2, Scalar beta, Scalar alpha) const = 0;
-  virtual Tensor stft(const Tensor & self, int64_t n_fft, int64_t hop_length, int64_t win_length, const Tensor & window, bool normalized, bool onesided) const = 0;
+  virtual Tensor stft(const Tensor & self, int64_t n_fft, c10::optional<int64_t> hop_length, c10::optional<int64_t> win_length, const Tensor & window, bool normalized, bool onesided) const = 0;
   virtual int64_t stride(const Tensor & self, int64_t dim) const = 0;
   virtual Tensor sum(const Tensor & self, ScalarType dtype) const = 0;
   virtual Tensor sum(const Tensor & self) const = 0;
@@ -394,7 +392,7 @@ struct CAFFE2_API Type {
   virtual Tensor view_as(const Tensor & self, const Tensor & other) const = 0;
   virtual Tensor where(const Tensor & condition, const Tensor & self, const Tensor & other) const = 0;
   virtual Tensor norm(const Tensor & self, Scalar p) const = 0;
-  virtual Tensor norm(const Tensor & self, Scalar p, int64_t dim, bool keepdim) const = 0;
+  virtual Tensor norm(const Tensor & self, c10::optional<Scalar> p, int64_t dim, bool keepdim) const = 0;
   virtual Tensor clone(const Tensor & self) const = 0;
   virtual Tensor & resize_as_(Tensor & self, const Tensor & the_template) const = 0;
   virtual Tensor pow(const Tensor & self, Scalar exponent) const = 0;
index fd6a377..c56f7b8 100644 (file)
@@ -395,7 +395,6 @@ THFormal = TypedDict('THFormal', {
     'is_nullable': bool,
     'default': str,
     'default_init': str,
-    'python_default_init': str,
     'output': bool,
     'size': int,
     'declared_type': str,
@@ -422,7 +421,6 @@ AtFormal = TypedDict('AtFormal', {
     'is_nullable': bool,
     'default': str,
     'default_init': str,
-    'python_default_init': str,
     'output': bool,
     'size': int,
 }, total=False)
@@ -619,10 +617,6 @@ def create_generic(top_env, declarations):
             default = translate_default(argument, type_str, argument['default'])
             translated['default'] = default
             translated['default_init'] = argument.get('default_init', default)
-        if 'python_default_init' in argument:
-            assert 'default' not in argument
-            default = translate_default(argument, type_str, argument['python_default_init'])
-            translated['python_default_init'] = default
         if argument.get('output'):
             translated['output'] = True
         if argument.get('size'):
index e4b3f89..4282d37 100644 (file)
@@ -103,7 +103,7 @@ static std::unique_ptr<TensorIterator> make_reduction(
   // efficiency.
   // not generalize this to common mismatched input/output types to avoid cross
   // product of templated kernel launches.
-  if (self.type().scalarType() == dtype || 
+  if (self.type().scalarType() == dtype ||
       (self.is_cuda() && self.type().scalarType() == kHalf && dtype == kFloat)) {
     return TensorIterator::reduce_op(viewed_result, self);
   }
@@ -401,10 +401,11 @@ Tensor& _norm_out_cpu(Tensor& result, const Tensor& self, Scalar p, int64_t dim_
   }
 }
 
-Tensor& norm_out(Tensor &result, const Tensor &self, Scalar p, int64_t dim, bool keepdim) {
+Tensor& norm_out(Tensor &result, const Tensor &self, optional<Scalar> pOpt, int64_t dim, bool keepdim) {
   AT_CHECK(self.type().backend() == Backend::CPU || self.type().backend() == Backend::CUDA,
            "norm only supports CPU AND CUDA backend, got: ", toString(self.type().backend()));
   AT_CHECK(at::isFloatingType(self.type().scalarType()), "norm only supports floating-point dtypes");
+  auto p = pOpt.value_or(2.0);
   dim = maybe_wrap_dim(dim, self.dim());
   if (_dimreduce_return_trivial(result, self, 0, dim, keepdim)) {
     return result;
@@ -438,7 +439,7 @@ Tensor _norm(const Tensor &self, Scalar p) {
   }
 }
 
-Tensor norm(const Tensor& self, Scalar p, int64_t dim, bool keepdim) {
+Tensor norm(const Tensor& self, optional<Scalar> p, int64_t dim, bool keepdim) {
   Tensor result = at::empty({0}, self.options());
   return at::native::norm_out(result, self, p, dim, keepdim);
 }
index 4062fc2..5681079 100644 (file)
@@ -173,8 +173,8 @@ Tensor irfft(const Tensor& self, const int64_t signal_ndim, const bool normalize
 }
 
 
-Tensor stft(const Tensor& self, const int64_t n_fft, const int64_t hop_length,
-            const int64_t win_length, const Tensor& window,
+Tensor stft(const Tensor& self, const int64_t n_fft, const optional<int64_t> hop_lengthOpt,
+            const optional<int64_t> win_lengthOpt, const Tensor& window,
             const bool normalized, const bool onesided) {
   #define REPR(SS) \
     SS << "stft(" << self.type() << self.sizes() << ", n_fft=" << n_fft \
@@ -187,6 +187,10 @@ Tensor stft(const Tensor& self, const int64_t n_fft, const int64_t hop_length,
     } \
     SS << ", normalized=" << normalized << ", onesided=" << onesided << ")"
 
+  // default_init hop_length and win_length
+  auto hop_length = hop_lengthOpt.value_or(n_fft >> 2);
+  auto win_length = win_lengthOpt.value_or(n_fft);
+
   if (!at::isFloatingType(self.type().scalarType()) || self.dim() > 2 || self.dim() < 1) {
     std::ostringstream ss;
     REPR(ss) << ": expected a 1D or 2D tensor of floating types";
index 02a96ae..52320c2 100644 (file)
@@ -298,7 +298,8 @@ Tensor sum_to_size(const Tensor& self, IntList size) {
   return sum_to(self, size);
 }
 
-Tensor as_strided(const Tensor& self, IntList size, IntList stride, int64_t storage_offset) {
+Tensor as_strided(const Tensor& self, IntList size, IntList stride, optional<int64_t> storage_offset_) {
+  auto storage_offset = storage_offset_.value_or(self.storage_offset());
   auto tid = self.type_id();
   AT_CHECK(
       tid == CPUTensorId() || tid == CUDATensorId(),
@@ -308,19 +309,12 @@ Tensor as_strided(const Tensor& self, IntList size, IntList stride, int64_t stor
   return result;
 }
 
-Tensor &as_strided_(Tensor& self, IntList size, IntList stride, int64_t storage_offset) {
+Tensor &as_strided_(Tensor& self, IntList size, IntList stride, optional<int64_t> storage_offset_) {
+  auto storage_offset = storage_offset_.value_or(self.storage_offset());
   setStrided(self, size, stride, storage_offset);
   return self;
 }
 
-Tensor as_strided(const Tensor& self, IntList size, IntList stride) {
-  return at::as_strided(self, size, stride, self.storage_offset());
-}
-
-Tensor &as_strided_(Tensor& self, IntList size, IntList stride) {
-  return self.as_strided_(size, stride, self.storage_offset());
-}
-
 Tensor narrow_copy_sparse(const Tensor& self, int64_t dim, int64_t start, int64_t length) {
   int64_t allDim = self.dim();
   int64_t end = start+length;
index 6f527f4..6c8d98f 100644 (file)
 - func: _argmin(Tensor self, int64_t dim, bool keepdim=false) -> Tensor
   variants: function
 
-- func: as_strided(Tensor self, IntList size, IntList stride) -> Tensor
+- func: as_strided(Tensor self, IntList size, IntList stride, int64_t? storage_offset=None) -> Tensor
   variants: function, method
   device_guard: false
 
-- func: as_strided_(Tensor self, IntList size, IntList stride) -> Tensor
+- func: as_strided_(Tensor self, IntList size, IntList stride, int64_t? storage_offset=None) -> Tensor
   variants: function, method
   device_guard: false
 
-- func: as_strided(Tensor self, IntList size, IntList stride, int64_t storage_offset) -> Tensor
-  variants: function, method
-  device_guard: false
-  python_default_init:
-    storage_offset: self.storage_offset()
-
-- func: as_strided_(Tensor self, IntList size, IntList stride, int64_t storage_offset) -> Tensor
-  variants: function, method
-  device_guard: false
-  python_default_init:
-    storage_offset: self.storage_offset()
-
 - func: asin(Tensor self) -> Tensor
   variants: function, method
 
 # missing the `pad_mode` and `center` arguments, which are taken care of at
 # `torch.functional.py`. They shall be moved here once we have mapping between
 # Python strings and C++ Enum in codegen.
-- func: stft(Tensor self, int64_t n_fft, int64_t hop_length, int64_t win_length, Tensor? window={}, bool normalized=false, bool onesided=true) -> Tensor
+- func: stft(Tensor self, int64_t n_fft, int64_t? hop_length=None, int64_t? win_length=None, Tensor? window={}, bool normalized=false, bool onesided=true) -> Tensor
   variants: function, method
-  python_default_init:
-    hop_length: n_fft >> 2
-    win_length: n_fft
 
 - func: stride(Tensor self, int64_t dim) -> int64_t
   variants: function, method
 - func: norm(Tensor self, Scalar p=2) -> Tensor
   variants: function, method
 
-- func: norm(Tensor self, Scalar p, int64_t dim, bool keepdim=false) -> Tensor
+- func: norm(Tensor self, Scalar? p, int64_t dim, bool keepdim=false) -> Tensor
   variants: function, method
-  python_default_init:
-    p: 2
 
-- func: norm_out(Tensor result, Tensor self, Scalar p, int64_t dim, bool keepdim=false) -> Tensor
-  python_default_init:
-    p: 2
+- func: norm_out(Tensor result, Tensor self, Scalar? p, int64_t dim, bool keepdim=false) -> Tensor
 
 - func: frobenius_norm(Tensor self) -> Tensor
   variants: function
index 7d2f29d..aacb235 100644 (file)
@@ -48,7 +48,6 @@ def sanitize_types(types):
 
 def parse_arguments(args, func_decl, func_name, func_return):
     arguments = []
-    python_default_inits = func_decl.get('python_default_init', {})
     is_out_fn = func_name.endswith('_out')
     if is_out_fn and func_decl.get('variants', []) not in [[], 'function', ['function']]:
         raise RuntimeError("Native functions suffixed with _out MUST be declared with only the function variant; "
@@ -71,16 +70,11 @@ def parse_arguments(args, func_decl, func_name, func_return):
 
         t, name = type_and_name
         default = None
-        python_default_init = None
 
         if '=' in name:
             ns = name.split('=', 1)
             name, default = ns[0], parse_default(ns[1])
 
-        if name in python_default_inits:
-            assert default is None
-            python_default_init = python_default_inits[name]
-
         typ = sanitize_types(t)
         assert len(typ) == 1
         argument_dict = {'type': typ[0].rstrip('?'), 'name': name, 'is_nullable': typ[0].endswith('?')}
@@ -90,8 +84,6 @@ def parse_arguments(args, func_decl, func_name, func_return):
             argument_dict['size'] = int(match.group(1))
         if default is not None:
             argument_dict['default'] = default
-        if python_default_init is not None:
-            argument_dict['python_default_init'] = python_default_init
         # TODO: convention is that the ith-argument correspond to the i-th return, but it would
         # be better if we just named everything and matched by name.
         if is_out_fn and arg_idx < len(func_return):
index 1cb5287..500dfb8 100644 (file)
 - name: alias(Tensor self)
   self: grad
 
-- name: as_strided(Tensor self, IntList size, IntList stride, int64_t storage_offset)
+- name: as_strided(Tensor self, IntList size, IntList stride, int64_t? storage_offset)
   self: as_strided_backward(grad, TensorGeometry(self), size, stride, storage_offset)
 
 - name: asin(Tensor self)
 - name: norm(Tensor self, Scalar p)
   self: norm_backward(grad, self, p, result)
 
-- name: norm(Tensor self, Scalar p, int64_t dim, bool keepdim)
+- name: norm(Tensor self, Scalar? p, int64_t dim, bool keepdim)
   self: norm_backward(grad, self, p, result, dim, keepdim)
 
 - name: _pdist_forward(Tensor self, double p)
index b1e794e..2c9a9e7 100644 (file)
@@ -251,6 +251,9 @@ def create_python_bindings(python_functions, has_self, is_module=False):
         'const Type &': 'scalartype',
         'const THPLayout &': 'layout',
         'const Device &': 'device',
+        'c10::optional<ScalarType>': 'scalartypeOptional',
+        'c10::optional<Scalar>': 'scalarOptional',
+        'c10::optional<int64_t>': 'toInt64Optional',
         'int64_t': 'toInt64',
         'bool': 'toBool',
         'double': 'toDouble',
@@ -320,11 +323,7 @@ def create_python_bindings(python_functions, has_self, is_module=False):
                     default_expr += '.scalarType()'
                 expr = 'r.{}({}, {})'.format(unpack_with_default, arg_index, default_expr)
             else:
-                opt_match = re.match(r'c10::optional<(.+)>', typename)
-                if (opt_match):
-                    unpack = opt_match.group(1).lower() + 'Optional'
-                else:
-                    unpack = unpack_methods.get(typename, typename.lower())
+                unpack = unpack_methods.get(typename, typename.lower())
                 expr = 'r.{}({})'.format(unpack, arg_index)
 
             if unpack_args:
@@ -349,7 +348,7 @@ def create_python_bindings(python_functions, has_self, is_module=False):
             formal_args.append(formal)
 
         # We always want to unpack when we have TensorOptions.
-        unpack = any(arg.get('python_default_init') for arg in inputs) or has_tensor_options
+        unpack = has_tensor_options
         for arg in inputs:
             if arg['simple_type'] in ['Type', 'TensorOptions']:
                 continue
@@ -397,7 +396,7 @@ def create_python_bindings(python_functions, has_self, is_module=False):
             elif arg['name'] == 'layout' and arg['simple_type'] == 'Layout':
                 # out(s) determines the type and layout if it is present, so only use this if there are no outputs.
                 if len(outputs) == 0:
-                    layout = parse_arg(arg, layout_idx, arg.get('python_default_init'))[0]
+                    layout = parse_arg(arg, layout_idx)[0]
             elif arg['name'] == 'device' and arg['simple_type'] == 'Device':
                 if len(outputs) == 0:
                     assert parsed_type_args
@@ -770,8 +769,6 @@ def get_python_signature(declaration, include_out):
             default = arg['default']
             if default == 'nullptr' or default == 'nullopt' or default == '{}':
                 default = 'None'
-        if arg.get('python_default_init') is not None:
-            default = 'None'
         if default is not None:
             param += '=' + str(default)
         return param
index 308865f..776f238 100644 (file)
@@ -88,8 +88,8 @@ int64_t _safe_size(IntList sizes, IntList dim) {
   return size;
 }
 
-Tensor norm_backward(const Tensor & grad, const Tensor & self, const Scalar & p_, const Tensor & norm) {
-  double p = p_.toDouble();
+Tensor norm_backward(const Tensor & grad, const Tensor & self, const optional<Scalar> & p_, const Tensor & norm) {
+  double p = p_.value_or(2.0).toDouble();
   Tensor self_scaled;
   Tensor scale_v;
   if (p == 0.0) {
@@ -114,7 +114,7 @@ Tensor norm_backward(const Tensor & grad, const Tensor & self, const Scalar & p_
   return self_scaled * scale_v;
 }
 
-Tensor norm_backward(Tensor grad, const Tensor & self, const Scalar & p_, Tensor norm, int64_t dim, bool keepdim) {
+Tensor norm_backward(Tensor grad, const Tensor & self, const optional<Scalar> & p_, Tensor norm, int64_t dim, bool keepdim) {
   if (!keepdim && self.dim() != 0) {
     grad = grad.unsqueeze(dim);
     norm = norm.unsqueeze(dim);
@@ -1371,7 +1371,7 @@ static inline int64_t _min_storage_size(IntList sizes, IntList strides, int64_t
 }
 
 // See NOTE [ as_strided Backward and layout-aware/agnostic autograd ] for explanation
-Tensor as_strided_backward(Tensor grad, TensorGeometry input_geometry, IntList sizes, IntList strides, int64_t storage_offset) {
+Tensor as_strided_backward(Tensor grad, TensorGeometry input_geometry, IntList sizes, IntList strides, optional<int64_t> storage_offset_) {
   // For output geometry,
   //   check for size 0 dimensions,
   //   skip size 1 dimensions,
@@ -1379,6 +1379,7 @@ Tensor as_strided_backward(Tensor grad, TensorGeometry input_geometry, IntList s
   // Step (0)     for the algorithm in NOTE [ as_strided Backward and layout-aware/agnostic autograd ]
   // Step (0)~(1) for the algorithm in NOTE [ Detecting Memory Overlap Within A Strided Tensor ]
   //              on output geometry
+  auto storage_offset = storage_offset_.value_or(input_geometry.storage_offset());
   auto odim = grad.dim();
   std::vector<int64_t> out_sizes_, out_strides_;
   out_sizes_.reserve(odim);
index 66371ad..2f53e67 100644 (file)
@@ -56,6 +56,7 @@ TYPE_MAP = {
     'ScalarType': 'ScalarType',
     'ScalarType?': 'ScalarType?',
     'int64_t': 'int',
+    'int64_t?': 'int?',
     'double': 'float',
     'bool': 'bool',
     'Generator': 'Generator',
@@ -91,6 +92,7 @@ FROM_IVALUE = {
     'bool': '{}.toBool()',
     'double': '{}.toDouble()',
     'int64_t': '{}.toInt()',
+    'int64_t?': '{}.toOptional<int64_t>()',
     'std::string': '{}.toString()->string()',
     'Generator': 'nullptr',
     'std::array<bool,2>': 'as_bool_array<2>({}.toIntList()->elements())',
index 85465d5..69e3170 100644 (file)
@@ -902,7 +902,7 @@ class ShapePropagator {
             "aten::argmin(Tensor self, int dim, bool keepdim) -> Tensor",
             "aten::max_values(Tensor self, int dim, bool keepdim) -> Tensor",
             "aten::min_values(Tensor self, int dim, bool keepdim) -> Tensor",
-            "aten::norm(Tensor self, Scalar p, int dim, bool keepdim) -> Tensor",
+            "aten::norm(Tensor self, Scalar? p, int dim, bool keepdim) -> Tensor",
             "aten::var(Tensor self, int dim, bool unbiased, bool keepdim) -> Tensor",
             "aten::logsumexp(Tensor self, int dim, bool keepdim) -> Tensor",
             "aten::all(Tensor self, int dim, bool keepdim) -> Tensor",
@@ -1262,9 +1262,7 @@ class ShapePropagator {
           node->matches(
               "aten::expand(Tensor self, int[] size, *, bool implicit) -> Tensor") ||
           node->matches(
-              "aten::as_strided(Tensor self, int[] size, int[] stride) -> Tensor") ||
-          node->matches(
-              "aten::as_strided(Tensor self, int[] size, int[] stride, int storage_offset) -> Tensor")) {
+              "aten::as_strided(Tensor self, int[] size, int[] stride, int? storage_offset) -> Tensor")) {
         return reshape_prop(node, attr::size, tensor_types);
       } else if (node->matches(
                      "aten::reshape(Tensor self, int[] shape) -> Tensor")) {
index f2fc0c0..874fb4d 100644 (file)
@@ -73,6 +73,18 @@ void addInputs(Node *n, const char * name, int64_t value) {
     detail::genericAddInput(n, value);
   }
 }
+
+void addInputs(Node *n, const char * name, c10::optional<int64_t> value)     {
+  if(value) {
+    detail::genericAddInput(n, *value);
+  } else {
+    Graph * g = n->owningGraph();
+    Value* none =
+        g->insertNode(g->createNone(IntType::get()))
+            ->output();
+    n->addInput(none);
+  }
+}
 void addInputs(Node *n, const char * name, bool value)               { detail::genericAddInput(n, value); }
 void addInputs(Node *n, const char * name, double value)             { detail::genericAddInput(n, value); }
 void addInputs(Node *n, const char * name, const at::Scalar& value)  { detail::genericAddInput(n, value); }
index cc4bb6b..bbbca8d 100644 (file)
@@ -200,6 +200,7 @@ inline void abandon() {
 // NB: those serve both as an intermediate steps in addInputs below,
 // as well as the overloads that terminate template recursion
 TORCH_API void addInputs(Node *n, const char * name, int64_t value);
+TORCH_API void addInputs(Node *n, const char * name, c10::optional<int64_t> value);
 TORCH_API void addInputs(Node *n, const char * name, bool value);
 TORCH_API void addInputs(Node *n, const char * name, double value);
 TORCH_API void addInputs(Node *n, const char * name, const at::Scalar& value);
index 432383b..215c0b4 100644 (file)
@@ -127,6 +127,7 @@ struct PythonArgs {
   inline at::ScalarType scalartypeWithDefault(int i, at::ScalarType default_scalartype);
   inline c10::optional<at::ScalarType> scalartypeOptional(int i);
   inline c10::optional<at::Scalar> scalarOptional(int i);
+  inline c10::optional<int64_t> toInt64Optional(int i);
   inline const THPLayout& layout(int i);
   inline const THPLayout& layoutWithDefault(int i, const THPLayout& default_layout);
   inline at::Device device(int i);
@@ -405,6 +406,12 @@ inline int64_t PythonArgs::toInt64WithDefault(int i, int64_t default_int) {
   return toInt64(i);
 }
 
+inline c10::optional<int64_t> PythonArgs::toInt64Optional(int i) {
+  if (!args[i])
+    return c10::nullopt;
+  return toInt64(i);
+}
+
 inline double PythonArgs::toDouble(int i) {
   if (!args[i]) return signature.params[i].default_double;
   return THPUtils_unpackDouble(args[i]);