From 82175f31b47680a72ad7c00ef185d9a0c1ceb078 Mon Sep 17 00:00:00 2001 From: Elias Ellison Date: Tue, 27 Nov 2018 18:36:05 -0800 Subject: [PATCH] Move Affine grid to C++ (#14392) Summary: Port AffineGrid to C++, because script does not support compiling Function classes. Pull Request resolved: https://github.com/pytorch/pytorch/pull/14392 Differential Revision: D13219698 Pulled By: eellison fbshipit-source-id: 3ddad8a84c72010b5a6c6f7f9712be614202faa6 --- aten/src/ATen/core/aten_interned_strings.h | 2 + aten/src/ATen/native/AffineGridGenerator.cpp | 123 +++++++++++++++++++++++++++ aten/src/ATen/native/native_functions.yaml | 6 ++ test/test_jit.py | 1 - tools/autograd/derivatives.yaml | 3 + torch/csrc/jit/register_prim_ops.cpp | 10 +++ torch/csrc/jit/script/compiler.cpp | 1 + torch/jit/__init__.py | 2 + torch/nn/_functions/vision.py | 82 ++---------------- torch/nn/functional.py | 2 + 10 files changed, 155 insertions(+), 77 deletions(-) create mode 100644 aten/src/ATen/native/AffineGridGenerator.cpp diff --git a/aten/src/ATen/core/aten_interned_strings.h b/aten/src/ATen/core/aten_interned_strings.h index eb3f8bd..5f6896d 100644 --- a/aten/src/ATen/core/aten_interned_strings.h +++ b/aten/src/ATen/core/aten_interned_strings.h @@ -196,6 +196,8 @@ _(aten, addcmul) \ _(aten, addmm) \ _(aten, addmv) \ _(aten, addr) \ +_(aten, affine_grid_generator) \ +_(aten, affine_grid_generator_backward) \ _(aten, alias) \ _(aten, all) \ _(aten, allclose) \ diff --git a/aten/src/ATen/native/AffineGridGenerator.cpp b/aten/src/ATen/native/AffineGridGenerator.cpp new file mode 100644 index 0000000..a029069 --- /dev/null +++ b/aten/src/ATen/native/AffineGridGenerator.cpp @@ -0,0 +1,123 @@ +#include "ATen/ATen.h" +#include "ATen/NativeFunctions.h" + +namespace at { namespace native { + +at::Tensor linspace_from_neg_one(const Tensor& grid, int64_t num_steps) { + if (num_steps > 1) { + return at::linspace(-1, 1, num_steps, grid.options()); + } else { + return at::tensor(-1, grid.options()); + } +} + +Tensor make_base_grid_4D( + const Tensor& theta, + int64_t N, + int64_t C, + int64_t H, + int64_t W) { + auto base_grid = at::empty({N, H, W, 3}, theta.options()); + + base_grid.select(-1, 0).copy_(linspace_from_neg_one(theta, W)); + base_grid.select(-1, 1).copy_(linspace_from_neg_one(theta, H).unsqueeze_(-1)); + base_grid.select(-1, 2).fill_(1); + + return base_grid; +} + +Tensor make_base_grid_5D( + const Tensor& theta, + int64_t N, + int64_t C, + int64_t D, + int64_t H, + int64_t W) { + auto base_grid = at::empty({N, D, H, W, 4}, theta.options()); + + base_grid.select(-1, 0).copy_(linspace_from_neg_one(theta, W)); + base_grid.select(-1, 1).copy_(linspace_from_neg_one(theta, H).unsqueeze_(-1)); + base_grid.select(-1, 2).copy_(linspace_from_neg_one(theta, D).unsqueeze_(-1).unsqueeze_(-1)); + base_grid.select(-1, 3).fill_(1); + + return base_grid; +} + +Tensor affine_grid_generator_4D( + const Tensor& theta, + int64_t N, + int64_t C, + int64_t H, + int64_t W) { + Tensor base_grid = make_base_grid_4D(theta, N, C, H, W); + auto grid = base_grid.view({N, H * W, 3}).bmm(theta.transpose(1, 2)); + return grid.view({N, H, W, 2}); +} + +Tensor affine_grid_generator_5D( + const Tensor& theta, + int64_t N, + int64_t C, + int64_t D, + int64_t H, + int64_t W) { + Tensor base_grid = make_base_grid_5D(theta, N, C, D, H, W); + auto grid = base_grid.view({N, D * H * W, 4}).bmm(theta.transpose(1, 2)); + return grid.view({N, D, H, W, 3}); +} + +Tensor affine_grid_generator(const Tensor& theta, IntList size) { + AT_CHECK( + size.size() == 4 
|| size.size() == 5, + "AffineGridGenerator needs 4d (spatial) or 5d (volumetric) inputs."); + if (size.size() == 4) { + return affine_grid_generator_4D(theta, size[0], size[1], size[2], size[3]); + } else { + return affine_grid_generator_5D( + theta, size[0], size[1], size[2], size[3], size[4]); + } +} + +Tensor affine_grid_generator_4D_backward( + const Tensor& grad_grid, + int64_t N, + int64_t C, + int64_t H, + int64_t W) { + auto base_grid = make_base_grid_4D(grad_grid, N, C, H, W); + AT_ASSERT(grad_grid.sizes() == IntList({N, H, W, 2})); + auto grad_theta = base_grid.view({N, H * W, 3}) + .transpose(1, 2) + .bmm(grad_grid.view({N, H * W, 2})); + return grad_theta.transpose(1, 2); +} + +Tensor affine_grid_generator_5D_backward( + const Tensor& grad_grid, + int64_t N, + int64_t C, + int64_t D, + int64_t H, + int64_t W) { + auto base_grid = make_base_grid_5D(grad_grid, N, C, D, H, W); + AT_ASSERT(grad_grid.sizes() == IntList({N, D, H, W, 3})); + auto grad_theta = base_grid.view({N, D * H * W, 4}) + .transpose(1, 2) + .bmm(grad_grid.view({N, D * H * W, 3})); + return grad_theta.transpose(1, 2); +} + +Tensor affine_grid_generator_backward(const Tensor& grad, IntList size) { + AT_CHECK( + size.size() == 4 || size.size() == 5, + "AffineGridGenerator needs 4d (spatial) or 5d (volumetric) inputs."); + if (size.size() == 4) { + return affine_grid_generator_4D_backward( + grad, size[0], size[1], size[2], size[3]); + } else { + return affine_grid_generator_5D_backward( + grad, size[0], size[1], size[2], size[3], size[4]); + } +} + +}} // namespace at::native diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml index ed1a150..fe51e37 100644 --- a/aten/src/ATen/native/native_functions.yaml +++ b/aten/src/ATen/native/native_functions.yaml @@ -144,6 +144,12 @@ - func: addr_out(Tensor result, Tensor self, Tensor vec1, Tensor vec2, *, Scalar beta=1, Scalar alpha=1) -> Tensor +- func: affine_grid_generator(Tensor theta, IntList size) -> Tensor + variants: function + +- func: affine_grid_generator_backward(Tensor grad, IntList size) -> Tensor + variants: function + - func: all(Tensor self, int64_t dim, bool keepdim=false) -> Tensor variants: function, method diff --git a/test/test_jit.py b/test/test_jit.py index db2f9e1..6c81284 100644 --- a/test/test_jit.py +++ b/test/test_jit.py @@ -9358,7 +9358,6 @@ EXCLUDE_SCRIPT = { 'test_nn_max_unpool2d', # argument type not supported - 'test_nn_affine_grid', # unknown builtin op 'test_nn_binary_cross_entropy', diff --git a/tools/autograd/derivatives.yaml b/tools/autograd/derivatives.yaml index 05c5a51..95559a8 100644 --- a/tools/autograd/derivatives.yaml +++ b/tools/autograd/derivatives.yaml @@ -132,6 +132,9 @@ vec1: grad.mv(vec2) * alpha vec2: grad.t().mv(vec1) * alpha +- name: affine_grid_generator(Tensor theta, IntList size) + theta: affine_grid_generator_backward(grad, size) + - name: alias(Tensor self) self: grad diff --git a/torch/csrc/jit/register_prim_ops.cpp b/torch/csrc/jit/register_prim_ops.cpp index f40cf28..c96aeb1 100644 --- a/torch/csrc/jit/register_prim_ops.cpp +++ b/torch/csrc/jit/register_prim_ops.cpp @@ -209,6 +209,16 @@ RegisterOperators reg({ }; }), Operator( + "prim::is_cuda(Tensor a) -> bool", + [](const Node* node) -> Operation { + return [](Stack& stack) { + at::Tensor a; + pop(stack, a); + push(stack, a.is_cuda()); + return 0; + }; + }), + Operator( "prim::Undefined() -> Tensor", [](const Node* node) { return [](Stack& stack) { diff --git a/torch/csrc/jit/script/compiler.cpp 
b/torch/csrc/jit/script/compiler.cpp index ce3d49a..9328b89 100644 --- a/torch/csrc/jit/script/compiler.cpp +++ b/torch/csrc/jit/script/compiler.cpp @@ -2559,6 +2559,7 @@ std::shared_ptr SimpleValue::attr(SourceRange loc, Method & m, con "dtype", "device", "shape", + "is_cuda", }; if (fields.count(field)) { auto r = m.graph()->insert(Symbol::fromQualString("prim::"+field), {value}); diff --git a/torch/jit/__init__.py b/torch/jit/__init__.py index 5d9b227..0bb8fa4 100644 --- a/torch/jit/__init__.py +++ b/torch/jit/__init__.py @@ -3,6 +3,7 @@ from torch import Tensor from torch.autograd import Variable, function from torch.nn import Module, ModuleList, ParameterList, Parameter, Sequential from torch.jit.frontend import get_jit_ast, get_default_args +import torch.backends.cudnn as cudnn import torch.jit.annotations from torch._six import raise_from, with_metaclass, get_function_from_type from .._jit_internal import createResolutionCallback, _compiled_weak_fns, \ @@ -1380,6 +1381,7 @@ def _get_builtin_table(): _builtin_table[id(_quadruple)] = "aten::_quadruple" _builtin_table[id(_list_with_default)] = "aten::list_with_default" _builtin_table[id(_unwrap_optional)] = "aten::_unwrap_optional" + _builtin_table[id(cudnn.is_acceptable)] = "aten::cudnn_is_acceptable" return _builtin_table diff --git a/torch/nn/_functions/vision.py b/torch/nn/_functions/vision.py index 014658a..cccf011 100644 --- a/torch/nn/_functions/vision.py +++ b/torch/nn/_functions/vision.py @@ -1,83 +1,13 @@ import torch -from torch.autograd import Function -from torch.autograd.function import once_differentiable -from torch._thnn import type2backend -from .thnn.auto import function_by_name import torch.backends.cudnn as cudnn +@torch._jit_internal.weak_script def affine_grid_generator(theta, size): - if theta.data.is_cuda and cudnn.enabled and cudnn.is_acceptable(theta.data) and len(size) == 4: + # type: (Tensor, List[int]) -> Tensor + if theta.is_cuda and cudnn.enabled and cudnn.is_acceptable(theta) and len(size) == 4: N, C, H, W = size - return torch.cudnn_affine_grid_generator(theta, N, C, H, W) + ret = torch.cudnn_affine_grid_generator(theta, N, C, H, W) else: - return AffineGridGenerator.apply(theta, size) - - -# TODO: Port these completely into C++ - - -class AffineGridGenerator(Function): - @staticmethod - def forward(ctx, theta, size): - assert type(size) == torch.Size - - ctx.size = size - ctx.is_cuda = theta.is_cuda - - if len(size) == 5: - N, C, D, H, W = size - base_grid = theta.new(N, D, H, W, 4) - - base_grid[:, :, :, :, 0] = (torch.linspace(-1, 1, W) if W > 1 else torch.Tensor([-1])) - base_grid[:, :, :, :, 1] = (torch.linspace(-1, 1, H) if H > 1 else torch.Tensor([-1]))\ - .unsqueeze(-1) - base_grid[:, :, :, :, 2] = (torch.linspace(-1, 1, D) if D > 1 else torch.Tensor([-1]))\ - .unsqueeze(-1).unsqueeze(-1) - base_grid[:, :, :, :, 3] = 1 - - grid = torch.bmm(base_grid.view(N, D * H * W, 4), theta.transpose(1, 2)) - grid = grid.view(N, D, H, W, 3) - - elif len(size) == 4: - N, C, H, W = size - base_grid = theta.new(N, H, W, 3) - - base_grid[:, :, :, 0] = (torch.linspace(-1, 1, W) if W > 1 else torch.Tensor([-1])) - base_grid[:, :, :, 1] = (torch.linspace(-1, 1, H) if H > 1 else torch.Tensor([-1]))\ - .unsqueeze(-1) - base_grid[:, :, :, 2] = 1 - - grid = torch.bmm(base_grid.view(N, H * W, 3), theta.transpose(1, 2)) - grid = grid.view(N, H, W, 2) - else: - raise RuntimeError("AffineGridGenerator needs 4d (spatial) or 5d (volumetric) inputs.") - - ctx.base_grid = base_grid - - return grid - - @staticmethod - 
@once_differentiable - def backward(ctx, grad_grid): - assert ctx.is_cuda == grad_grid.is_cuda - base_grid = ctx.base_grid - - if len(ctx.size) == 5: - N, C, D, H, W = ctx.size - assert grad_grid.size() == torch.Size([N, D, H, W, 3]) - grad_theta = torch.bmm( - base_grid.view(N, D * H * W, 4).transpose(1, 2), - grad_grid.view(N, D * H * W, 3)) - elif len(ctx.size) == 4: - N, C, H, W = ctx.size - assert grad_grid.size() == torch.Size([N, H, W, 2]) - grad_theta = torch.bmm( - base_grid.view(N, H * W, 3).transpose(1, 2), - grad_grid.view(N, H * W, 2)) - else: - assert False - - grad_theta = grad_theta.transpose(1, 2) - - return grad_theta, None + ret = torch.affine_grid_generator(theta, size) + return ret diff --git a/torch/nn/functional.py b/torch/nn/functional.py index 91352fc..6c6b667 100644 --- a/torch/nn/functional.py +++ b/torch/nn/functional.py @@ -2380,7 +2380,9 @@ def grid_sample(input, grid, mode='bilinear', padding_mode='zeros'): return torch.grid_sampler(input, grid, mode_enum, padding_mode_enum) +@torch._jit_internal.weak_script def affine_grid(theta, size): + # type: (Tensor, List[int]) -> Tensor r"""Generates a 2d flow field, given a batch of affine matrices :attr:`theta` Generally used in conjunction with :func:`grid_sample` to implement Spatial Transformer Networks. -- 2.7.4
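
As a usage reference, a minimal sketch (illustrative only, not part of the patch) of how torch.nn.functional.affine_grid exercises the new aten::affine_grid_generator path; the shapes and values below are arbitrary:

    import torch
    import torch.nn.functional as F

    # Identity affine transform for a batch of one single-channel 4x5 image.
    theta = torch.tensor([[[1., 0., 0.],
                           [0., 1., 0.]]])        # (N, 2, 3)
    size = torch.Size((1, 1, 4, 5))               # (N, C, H, W)

    grid = F.affine_grid(theta, size)             # (N, H, W, 2) sampling grid
    out = F.grid_sample(torch.randn(size), grid)  # resample an input with the grid
    print(grid.shape, out.shape)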