From b290a16b2d07c9230f17fde3a9cf2c28ca2f7505 Mon Sep 17 00:00:00 2001 From: Christian Puhrsch Date: Thu, 7 Mar 2019 23:31:00 -0800 Subject: [PATCH] Use return names in JIT operators Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/17638 Differential Revision: D14295606 Pulled By: cpuhrsch fbshipit-source-id: 62040ac65434411357808735f0fe6cd33cc1c30f --- aten/src/ATen/native/native_functions.yaml | 4 ++++ tools/jit/gen_jit_dispatch.py | 3 +++ 2 files changed, 7 insertions(+) diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml index 912448b..6ce0eb5 100644 --- a/aten/src/ATen/native/native_functions.yaml +++ b/aten/src/ATen/native/native_functions.yaml @@ -607,11 +607,13 @@ matches_jit_signature: True - func: cudnn_affine_grid_generator(Tensor theta, int N, int C, int H, int W) -> Tensor grid + matches_jit_signature: True dispatch: CUDA: cudnn_affine_grid_generator_forward # TODO: Why do I have to call this grad?! - func: cudnn_affine_grid_generator_backward(Tensor grad, int N, int C, int H, int W) -> Tensor grad_theta + matches_jit_signature: True dispatch: CUDA: cudnn_affine_grid_generator_backward @@ -680,6 +682,7 @@ # NB: input is special cased in a way I don't quite understand - func: cudnn_grid_sampler(Tensor self, Tensor grid) -> Tensor output + matches_jit_signature: True dispatch: CUDA: cudnn_grid_sampler_forward @@ -3629,6 +3632,7 @@ variants: method, function - func: _gather_sparse_backward(Tensor self, int dim, Tensor index, Tensor grad) -> Tensor + matches_jit_signature: True - func: addcmul(Tensor self, Tensor tensor1, Tensor tensor2, *, Scalar value=1, Tensor(a!) out) -> Tensor(a!) matches_jit_signature: True diff --git a/tools/jit/gen_jit_dispatch.py b/tools/jit/gen_jit_dispatch.py index a47ca47..d0d2829 100644 --- a/tools/jit/gen_jit_dispatch.py +++ b/tools/jit/gen_jit_dispatch.py @@ -481,6 +481,9 @@ def signature(decl, should_match_schema=True): arg_list = ', '.join(args) if len(decl['returns']) == 1: ret_list = jit_type_of(decl['returns'][0]) + # Adding output name if it exists + if decl['returns'][0].get('field_name'): + ret_list += ' ' + decl['returns'][0]['field_name'] else: def type_maybe_field(r): return '{} {}'.format(jit_type_of(r), r['field_name']) if 'field_name' in r else jit_type_of(r) -- 2.7.4