Use return names in JIT operators
authorChristian Puhrsch <cpuhrsch@fb.com>
Fri, 8 Mar 2019 07:31:00 +0000 (23:31 -0800)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Fri, 8 Mar 2019 07:34:42 +0000 (23:34 -0800)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/17638

Differential Revision: D14295606

Pulled By: cpuhrsch

fbshipit-source-id: 62040ac65434411357808735f0fe6cd33cc1c30f

aten/src/ATen/native/native_functions.yaml
tools/jit/gen_jit_dispatch.py

index 912448b..6ce0eb5 100644 (file)
   matches_jit_signature: True
 
 - func: cudnn_affine_grid_generator(Tensor theta, int N, int C, int H, int W) -> Tensor grid
+  matches_jit_signature: True
   dispatch:
     CUDA: cudnn_affine_grid_generator_forward
 
 # TODO: Why do I have to call this grad?!
 - func: cudnn_affine_grid_generator_backward(Tensor grad, int N, int C, int H, int W) -> Tensor grad_theta
+  matches_jit_signature: True
   dispatch:
     CUDA: cudnn_affine_grid_generator_backward
 
 
 # NB: input is special cased in a way I don't quite understand
 - func: cudnn_grid_sampler(Tensor self, Tensor grid) -> Tensor output
+  matches_jit_signature: True
   dispatch:
     CUDA: cudnn_grid_sampler_forward
 
   variants: method, function
 
 - func: _gather_sparse_backward(Tensor self, int dim, Tensor index, Tensor grad) -> Tensor
+  matches_jit_signature: True
 
 - func: addcmul(Tensor self, Tensor tensor1, Tensor tensor2, *, Scalar value=1, Tensor(a!) out) -> Tensor(a!)
   matches_jit_signature: True
index a47ca47..d0d2829 100644 (file)
@@ -481,6 +481,9 @@ def signature(decl, should_match_schema=True):
     arg_list = ', '.join(args)
     if len(decl['returns']) == 1:
         ret_list = jit_type_of(decl['returns'][0])
+        # Adding output name if it exists
+        if decl['returns'][0].get('field_name'):
+            ret_list += ' ' + decl['returns'][0]['field_name']
     else:
         def type_maybe_field(r):
             return '{} {}'.format(jit_type_of(r), r['field_name']) if 'field_name' in r else jit_type_of(r)