From 1abbee0f8e088d2c99481a05672f07947916fd75 Mon Sep 17 00:00:00 2001 From: Zachary DeVito Date: Wed, 10 Apr 2019 18:12:38 -0700 Subject: [PATCH] Allow Tensor lists to show up in symbolic differentiable graphs. (#16784) Summary: It is done by flattening all tensor lists that are inputs/outputs to the graph into the inputs/outputs list in the autograd graph. This is less desirable than simply allowing IValues to exist in the inputs/outputs of autograd::Function but it is substantially less intrusive. CaptureList describes the variables captured for backward in a single class. UnpackInstructs describes how the flattened inputs to backwards are re-packed into lists. ailzhang This PR is also part 2 of covering maskrcnn & bert AD formulas, following #16689. Ops added in this PR: ``` cat index meshgrid reshape split split_with_sizes stack unbind ``` I will also add a few perf numbers here. Pull Request resolved: https://github.com/pytorch/pytorch/pull/16784 Differential Revision: D14104063 Pulled By: ailzhang fbshipit-source-id: 5ceadadfd67ccaac60c5fd6740786c5354e252b9 --- test/common_methods_invocations.py | 23 +- test/test_jit.py | 95 ++++++- torch/csrc/jit/graph_executor.cpp | 299 ++++++++++++++++------ torch/csrc/jit/passes/specialize_autogradzero.cpp | 3 +- torch/csrc/jit/symbolic_script.cpp | 86 ++++++- 5 files changed, 418 insertions(+), 88 deletions(-) diff --git a/test/common_methods_invocations.py b/test/common_methods_invocations.py index 738f680..6202d5b 100644 --- a/test/common_methods_invocations.py +++ b/test/common_methods_invocations.py @@ -181,11 +181,11 @@ def method_tests(): ('view', (S,), (S,), '1d', (True,)), ('view', (), (dont_convert(()),), 'scalar_to_scalar', (True,)), ('view', (), (1,), 'scalar_to_1d', (True,)), - ('reshape', (S, S, S), (S * S, S),), - ('reshape', (S, S, S), (torch.Size([S * S, S]),), 'size'), - ('reshape', (S,), (S,), '1d'), - ('reshape', (), (dont_convert(()),), 'scalar_to_scalar'), - ('reshape', (), (1,), 'scalar_to_1d'), + ('reshape', (S, S, S), (S * S, S), '', (True,)), + ('reshape', (S, S, S), (torch.Size([S * S, S]),), 'size', (True,)), + ('reshape', (S,), (S,), '1d', (True,)), + ('reshape', (), (dont_convert(()),), 'scalar_to_scalar', (True,)), + ('reshape', (), (1,), 'scalar_to_1d', (True,)), ('reshape_as', (S, S, S), (non_differentiable(torch.rand(S * S, S)),)), ('reshape_as', (), (non_differentiable(torch.tensor(42.)),), 'scalar'), ('reshape_as', (), (non_differentiable(torch.rand(1, 1)),), 'scalar_to_dims'), @@ -726,10 +726,15 @@ def method_tests(): ('unsqueeze', (), (0,), 'scalar', (True,), [0]), ('chunk', (S, S, S), (2,), '', (True, 'prim::ConstantChunk')), ('chunk', (S, S, S), (S, 1), 'dim', (True, 'prim::ConstantChunk'), [1]), - ('split', (S, S, S), (2,)), - ('split', (S, S, S), (S, 1), 'dim', (), [1]), - ('split', (S, S, S), ([int(S / 3), S - int(S / 3) * 2, int(S / 3)],), 'size_list'), - ('split', (S, S, S), ([int(S / 2), S - int(S / 2) * 2, int(S / 2)], 2), 'size_list_dim', (), [1]), + ('split', (S, S, S), (2,), '', (True,)), + ('split', (S, S, S), (S, 1), 'dim', (True,), [1]), + ('split', (S, S, S), ([int(S / 3), S - int(S / 3) * 2, int(S / 3)],), 'size_list', + (True, 'aten::split_with_sizes')), + ('split', (S, S, S), ([int(S / 2), S - int(S / 2) * 2, int(S / 2)], 2), 'size_list_dim', + (True, 'aten::split_with_sizes'), [1]), + ('split_with_sizes', (S, S, S), ([int(S / 3), S - int(S / 3) * 2, int(S / 3)],), '', (True,)), + ('split_with_sizes', (S, S, S), ([int(S / 3), S - int(S / 3), 0],), 'size_0', (True, )), + 
('split_with_sizes', (S, S, S), ([int(S / 3), S - int(S / 3) * 2, int(S / 3)],), 'dim', (True, ), [1]), ('gather', (M, S), (0, gather_variable((S, S), 1, M, True)), 'dim0', (), [0]), ('gather', (M, S), (1, gather_variable((M, S // 2), 0, S, True)), 'dim1', (), [0]), ('gather', (), (0, torch.tensor([0], dtype=torch.int64)), 'scalar_input', (), [0]), diff --git a/test/test_jit.py b/test/test_jit.py index b7e4499..26160ad 100644 --- a/test/test_jit.py +++ b/test/test_jit.py @@ -253,6 +253,14 @@ def enable_cpu_fuser(fn): return wrapper +# helper function to get sum of List[Tensor] +def _sum_of_list(tensorlist): + s = 0 + for t in tensorlist: + s += t.sum() + return s + + class JitTestCase(TestCase): _do_cuda_memory_leak_check = True _restored_warnings = False @@ -3736,10 +3744,20 @@ a") @torch.jit.script def func2(x, y): return torch.cat((x, x), y) + func2.debug_disable_autodiff_subgraph_inlining() - x = torch.rand([2, 2]) + x = torch.rand([2, 2]).requires_grad_() y = torch.tensor(1) - self.assertEqual(func2(x, y), torch.cat((x, x), y)) + + output = func2(x, y) + output_ref = torch.cat((x, x), y) + self.assertEqual(output, output_ref) + + self.assertAutodiffNode(func2.graph_for(x, y), True, ['aten::cat'], []) + + grad = torch.autograd.grad(output.sum(), x) + grad_ref = torch.autograd.grad(output_ref.sum(), x) + self.assertEqual(grad, grad_ref) def test_cat_lifts(self): @torch.jit.script @@ -3757,6 +3775,73 @@ a") for g in [foo.graph, foo2.graph, foo3.graph]: FileCheck().check("int =").check("ListConstruct").check("aten::cat").run(str(g)) + @unittest.skipIf(PY2, "Requires python 3") + def test_stack(self): + @torch.jit.script + def func(x): + return torch.stack((x, x), dim=1) + x = torch.rand(10, 10) + self.assertEqual(func(x), torch.stack((x, x), dim=1)) + + @torch.jit.script + def func2(x, y): + return torch.stack((x, y), dim=0) + + func2.debug_disable_autodiff_subgraph_inlining() + + x = torch.randn([2, 2]).requires_grad_() + y = torch.randn([2, 2]).requires_grad_() + + output = func2(x, y) + output_ref = torch.stack((x, y), 0) + self.assertEqual(output, output_ref) + + self.assertAutodiffNode(func2.graph_for(x, y), True, ['aten::stack'], []) + + grads = torch.autograd.grad(output.sum(), (x, y)) + grads_ref = torch.autograd.grad(output_ref.sum(), (x, y)) + self.assertEqual(grads, grads_ref) + + def test_unbind(self): + @torch.jit.script + def func(x, y): + # type: (Tensor, int) -> List[Tensor] + return torch.unbind(x, y) + func.debug_disable_autodiff_subgraph_inlining() + + x = torch.rand([2, 2]).requires_grad_() + y = 0 + outputs = func(x, y) + outputs_ref = torch.unbind(x, dim=y) + self.assertEqual(outputs, outputs_ref) + + self.assertAutodiffNode(func.graph_for(x, y), True, ['aten::unbind'], []) + + grad = torch.autograd.grad(_sum_of_list(outputs), x) + grad_ref = torch.autograd.grad(_sum_of_list(outputs_ref), x) + self.assertEqual(grad, grad_ref) + + def test_meshgrid(self): + @torch.jit.script + def func(a): + # type: (List[Tensor]) -> List[Tensor] + return torch.meshgrid(a) + func.debug_disable_autodiff_subgraph_inlining() + + a = torch.tensor([1.0, 2, 3]).requires_grad_() + b = torch.tensor([1.0, 2, 3, 4]).requires_grad_() + inputs = [a, b] + + outputs_ref = torch.meshgrid(inputs) + outputs = func(inputs) + self.assertEqual(outputs, outputs_ref) + + self.assertAutodiffNode(func.graph_for(inputs), True, ['aten::meshgrid'], []) + + grads = torch.autograd.grad(_sum_of_list(outputs), inputs) + grads_ref = torch.autograd.grad(_sum_of_list(outputs_ref), inputs) + self.assertEqual(grads, 
grads_ref) + def test_list_literal(self): def reassign(): x = [1] @@ -11933,10 +12018,16 @@ EXCLUDE_SCRIPT = { # chunk returns a list in scripting and we don't unpack the list, # Thus it won't be replaced by ConstantChunk and run AD. # It's explicitly checked in test_chunk_constant_script_ad +# Similary for split, it's replaced by split_with_sizes in tracing, +# but we don't have AD formula for aten::split(Tensor, int[], int), +# an op registered in JIT so AD is not triggered in scripting. EXCLUDE_SCRIPT_AD_CHECK = { 'test_chunk', 'test_chunk_dim', 'test_chunk_dim_neg0', + 'test_split_size_list', + 'test_split_size_list_dim', + 'test_split_size_list_dim_neg0', } EXCLUDE_PYTHON_PRINT = { diff --git a/torch/csrc/jit/graph_executor.cpp b/torch/csrc/jit/graph_executor.cpp index da088a8..31ef09b 100644 --- a/torch/csrc/jit/graph_executor.cpp +++ b/torch/csrc/jit/graph_executor.cpp @@ -75,53 +75,157 @@ struct ExecutionPlan { std::shared_ptr graph; }; -struct DifferentiableGraphBackward : public autograd::Function { - DifferentiableGraphBackward(GraphExecutor executor, size_t capture_size) - : executor(std::move(executor)) { - is_var_capture.reserve(capture_size); - var_captures.reserve(capture_size); - ivalue_captures.reserve(capture_size); +struct CaptureList { + CaptureList(size_t capture_size) { + capture_types_.reserve(capture_size); + var_captures_.reserve(capture_size); // var_captures_.size() might be greater than capture_size + ivalue_captures_.reserve(capture_size); } - variable_list apply(variable_list&& inputs) override { - Stack stack; - stack.reserve(is_var_capture.size() + inputs.size()); - stack.insert( - stack.end(), - std::make_move_iterator(inputs.begin()), - std::make_move_iterator(inputs.end())); - auto var_capture_it = var_captures.begin(); - auto ivalue_capture_it = ivalue_captures.begin(); - for (bool is_var : is_var_capture) { - if (is_var) { - stack.emplace_back(var_capture_it->unpack(this->shared_from_this())); - ++var_capture_it; - } else { - stack.push_back(*ivalue_capture_it); - ++ivalue_capture_it; + void captureTensor(const at::Tensor& tensor, bool is_output) { + var_captures_.emplace_back(Variable(tensor), is_output); + } + + void capture(const IValue& val, bool is_output) { + if (val.isTensor()) { + capture_types_.emplace_back(CAPTURE_TENSOR); + captureTensor(val.toTensor(), is_output); + } else if (val.isTensorList()) { + // For TensorList, we have to flatten it to Tensors during saving and + // unflatten it back to TensorList when using it in backward apply(). + // This is to avoid any implicit mutation to TensorList happened + // between forward & backward. 
+ capture_types_.emplace_back(CAPTURE_LIST); + const std::vector& tensors = val.toTensorListRef(); + sizes_.push_back(tensors.size()); + + for (const at::Tensor& tensor: tensors) { + captureTensor(tensor, is_output); + } + } else { + capture_types_.emplace_back(CAPTURE_IVALUE); + ivalue_captures_.push_back(val); + } + } + + size_t size() const { + return capture_types_.size(); + } + + void unpack(Stack & stack, const std::shared_ptr& saved_for) { + auto var_capture_it = var_captures_.begin(); + auto ivalue_capture_it = ivalue_captures_.begin(); + auto size_it = sizes_.begin(); + for (Capture capture_type : capture_types_) { + switch(capture_type) { + case CAPTURE_TENSOR: { + stack.emplace_back(var_capture_it->unpack(saved_for)); + ++var_capture_it; + } break; + case CAPTURE_LIST: { + std::vector lst; + auto size = *size_it++; + for (size_t i = 0; i < size; i++) { + lst.emplace_back(var_capture_it->unpack(saved_for)); + var_capture_it++; + } + stack.emplace_back(TensorList::create(std::move(lst))); + } break; + case CAPTURE_IVALUE: { + stack.push_back(*ivalue_capture_it++); + } break; + } + } + } +private: + enum Capture: uint8_t { + CAPTURE_TENSOR, + CAPTURE_LIST, + CAPTURE_IVALUE, + }; + + std::vector capture_types_; + std::vector var_captures_; + std::vector ivalue_captures_; + std::vector sizes_; +}; + +// how do we turn a flattened list of tensors back into the ivalues that +// the DifferentiableGraphBackward expects +struct UnpackInstructions { + UnpackInstructions(size_t num_inputs) { + insts_.reserve(num_inputs); + } + void pushTensor() { + insts_.emplace_back(PUSH_TENSOR); + } + void pushTensorList(size_t size) { + insts_.emplace_back(PUSH_LIST); + sizes_.push_back(size); + } + void unpack(variable_list&& inputs, Stack& stack) { + auto input_it = std::make_move_iterator(inputs.begin()); + auto sizes_it = sizes_.begin(); + for(Inst inst : insts_) { + switch(inst) { + case PUSH_TENSOR: { + at::Tensor t = *input_it++; + stack.emplace_back(std::move(t)); + } break; + case PUSH_LIST: { + std::vector lst(input_it, input_it + *sizes_it++); + stack.emplace_back(TensorList::create(std::move(lst))); + } break; } } + } +private: + enum Inst : uint8_t { + PUSH_TENSOR, + PUSH_LIST, // consumes one size + }; + std::vector insts_; + std::vector sizes_; +}; +struct DifferentiableGraphBackward : public autograd::Function { + DifferentiableGraphBackward(GraphExecutor executor, size_t input_size, size_t capture_size) + : executor(std::move(executor)) + , captures_(capture_size) + , input_instructions_(input_size) {} + + variable_list apply(variable_list&& inputs) override { + Stack stack; + stack.reserve(captures_.size() + inputs.size()); + + input_instructions_.unpack(std::move(inputs), stack); + captures_.unpack(stack, shared_from_this()); executor.run(stack); - AT_ASSERT(stack.size() == num_outputs()); + // NB: stack.size() == num_outputs() is not always true + // after we added TensorList support. + // Example: aten::stack(Tensor[] tensors, int) where + // tensors = [x, x] + // Here stack.size()[=1] with a TensorList IValue of + // backward graph output. + // num_outputs()[=2], however, is the number of outputs of + // grad_fn (an autograd::Function). grad_fn's outputs are + // grads with regard to Tensor/Variables `x`, but not + // graph input TensorList [x, x]. These two grads will + // be accumulated to x.grad later using autograd::InputBuffer. 
variable_list outputs; outputs.reserve(num_outputs()); - for (size_t i = 0; i < num_outputs(); ++i) { - // Input grad can also be None even if it requires grad - // Example: `other` in expand_as(self, other) - if (should_compute_output(i) && !stack[i].isNone()) { - auto output = std::move(stack[i]).toTensor(); - const auto& edge = next_edge(i); - if (output.defined()) { - outputs.emplace_back(std::move(output)); - } else if (edge.is_valid()) { - outputs.emplace_back( - edge.function->input_metadata(edge.input_nr).zeros_like()); - } else { - outputs.emplace_back(); + size_t output_index = 0; + for (IValue& v : stack) { + if (v.isTensorList()) { + for(at::Tensor tensor : v.toTensorListRef()) { + produceOutput(output_index++, std::move(tensor), outputs); } + } else if (v.isTensor()) { + produceOutput(output_index++, std::move(v).toTensor(), outputs); } else { + // Input grad can also be None even if it requires grad + // Example: `other` in expand_as(self, other) outputs.emplace_back(); } } @@ -129,24 +233,72 @@ struct DifferentiableGraphBackward : public autograd::Function { } void capture(const IValue& val, bool is_output) { - const bool is_tensor = val.isTensor(); - is_var_capture.push_back(is_tensor); - if (is_tensor) { - var_captures.emplace_back(Variable(val.toTensor()), is_output); + captures_.capture(val, is_output); + } + + + void addOutputForTensor(const at::Tensor& tensor) { + auto v = Variable(tensor); + add_next_edge( + v.defined() ? v.gradient_edge() : autograd::Edge{}); + } + void addOutputForIValue(const IValue& value) { + if (value.isTensorList()){ + for(const at::Tensor& tensor : value.toTensorListRef()) { + addOutputForTensor(tensor); + } } else { - ivalue_captures.push_back(val); + addOutputForTensor(value.toTensor()); + } + } + + void addInputVariable(Variable output) { + // NB: since our requires_grad setting is only a heuristic we might end + // up wanting to differentiate through integral tensors, which is + // generally a hard error in autograd. 
+ if (at::isFloatingType(output.type().scalarType())) { + autograd::create_gradient_edge(output, shared_from_this()); + output.set_requires_grad(true); + } else { + add_input_metadata(autograd::Function::undefined_input{}); + } + } + + void addInputIValue(const IValue& v) { + if (v.isTensorList()) { + const std::vector& tensors = v.toTensorListRef(); + input_instructions_.pushTensorList(tensors.size()); + for (const at::Tensor& tensor : tensors) { + addInputVariable(tensor); + } + } else if (v.isTensor()) { + input_instructions_.pushTensor(); + addInputVariable(v.toTensor()); + } + } + +private: + + void produceOutput(size_t i, at::Tensor output, variable_list& outputs) { + if (should_compute_output(i)) { + const auto& edge = next_edge(i); + if (output.defined()) { + outputs.emplace_back(std::move(output)); + } else if (edge.is_valid()) { + outputs.emplace_back( + edge.function->input_metadata(edge.input_nr).zeros_like()); + } else { + outputs.emplace_back(); + } + } else { + outputs.emplace_back(); } } - private: friend struct ExecutionPlan; GraphExecutor executor; - - // INVARIANT: is_var_capture.size() == var_captures.size() + - // ivalue_captures.size() - std::vector is_var_capture; - std::vector var_captures; - std::vector ivalue_captures; + CaptureList captures_; + UnpackInstructions input_instructions_; }; // an optimized way of executing the subgraph computed directly on @@ -166,6 +318,7 @@ struct DifferentiableGraphOp { int operator()(Stack& stack) const { auto grad_fn = std::make_shared( grad_executor, + grad.df_input_vjps.size(), grad.df_input_captured_inputs.size() + grad.df_input_captured_outputs.size()); @@ -174,9 +327,7 @@ struct DifferentiableGraphOp { // hook up the outputs of df to the gradient functions of the inputs that // require gradients for (auto idx : grad.df_output_vjps) { - auto v = Variable(inputs[idx].toTensor()); - grad_fn->add_next_edge( - v.defined() ? v.gradient_edge() : autograd::Edge{}); + grad_fn->addOutputForIValue(inputs[idx]); } captureInputs(*grad_fn, inputs); } @@ -194,24 +345,7 @@ struct DifferentiableGraphOp { // this is currently intentionally not done here so we can get an idea of // our perf before introducing overhead for correctness for (auto idx : grad.df_input_vjps) { - // Note: we have to set this up in place, or we have to throw away and - // reallocate variables that were already created in wrapTensors. We - // should add an API for this. - - // XXX: undefined tensor syntax in autograd - Variable output; - if (!outputs[idx].isNone()) { - output = outputs[idx].toTensor(); - } - // NB: since our requires_grad setting is only a heuristic we might end - // up wanting to differentiate through integral tensors, which is - // generally a hard error in autograd. 
- if (at::isFloatingType(output.scalar_type())) { - autograd::create_gradient_edge(output, grad_fn); - output.set_requires_grad(true); - } else { - grad_fn->add_input_metadata(autograd::Function::undefined_input{}); - } + grad_fn->addInputIValue(outputs[idx]); } captureOutputs(*grad_fn, outputs); // drop the temporary outputs so that we return the same number of @@ -225,6 +359,26 @@ struct DifferentiableGraphOp { private: friend GraphExecutor* detail::getGradExecutor(Operation& op); + void detach(at::Tensor& t) const { + if (t.defined()) { + t = autograd::as_variable_ref(t).detach(); + } + } + + void detach(IValue& v) const { + if(v.isTensor()) { + auto t = std::move(v).toTensor(); + detach(t); + v = IValue{t}; + } else if(v.isTensorList()) { + std::vector lst = v.toTensorListRef(); + for(at::Tensor& t : lst) { + detach(t); + } + v = TensorList::create(std::move(lst)); + } + } + void detachVariables(Stack& stack) const { // It would be nice to use an ArrayRef here, but unfortunately those can // only return const references, so we need to do a bunch of indexing @@ -232,12 +386,7 @@ struct DifferentiableGraphOp { const int64_t stack_size = stack.size(); const int64_t stack_offset = stack_size - num_inputs; for (int64_t i = stack_offset; i < stack_size; ++i) { - auto& v = stack[i]; - if (!v.isTensor()) - continue; - auto t = std::move(v).toTensor(); - v = IValue{t.defined() ? autograd::as_variable_ref(t).detach() - : std::move(t)}; + detach(stack[i]); } } // Capture (save) inputs that would be required to subsequently run backwards diff --git a/torch/csrc/jit/passes/specialize_autogradzero.cpp b/torch/csrc/jit/passes/specialize_autogradzero.cpp index 06f252d..245668c 100644 --- a/torch/csrc/jit/passes/specialize_autogradzero.cpp +++ b/torch/csrc/jit/passes/specialize_autogradzero.cpp @@ -18,7 +18,8 @@ void specializeAutogradZero(Graph& g) { const auto& tp = input->type(); if (tp->isSubtypeOf(AutogradZeroTensorType::get())) { state[input] = State::Zero; - } else if (tp->isSubtypeOf(TensorType::get())) { + } else if (tp->isSubtypeOf(TensorType::get()) + || tp->isSubtypeOf(ListType::ofTensors())) { state[input] = State::Nonzero; } else { state[input] = State::Unknown; diff --git a/torch/csrc/jit/symbolic_script.cpp b/torch/csrc/jit/symbolic_script.cpp index 55a795f..2a99f72 100644 --- a/torch/csrc/jit/symbolic_script.cpp +++ b/torch/csrc/jit/symbolic_script.cpp @@ -6,7 +6,6 @@ namespace { std::mutex lock; const std::vector functions = { R"( - #### HELPER FUNCTIONS ### #### PREFIX: AD_ ### #### SCHEMA NOT SAVED IN CACHE ### @@ -475,6 +474,91 @@ const std::vector functions = { return self * other, backward + def reshape(self, + shape: List[int]): + self_size = self.size() + def backward(grad_output): + grad_self = grad_output.reshape(self_size) + return grad_self, None + + return torch.reshape(self, shape), backward + + def split(self, + split_size: int, + dim: int): + def backward(grad_outputs: List[Tensor]): + grad_self = torch.cat(grad_outputs, dim) + return grad_self, None, None + + return torch.split(self, split_size, dim), backward + + def split_with_sizes(self, + split_sizes: List[int], + dim: int=0): + def backward(grad_outputs: List[Tensor]): + size = len(grad_outputs) + grad_self = torch.cat(grad_outputs, dim) + return grad_self, None, None + + return torch.split_with_sizes(self, split_sizes, dim), backward + + def stack(tensors: List[Tensor], + dim: int=0): + def backward(grad_output): + grad_tensors = torch.unbind(grad_output, dim) + return grad_tensors, None + + return 
torch.stack(tensors, dim), backward + + def unbind(self, + dim: int=0): + def backward(grad_outputs: List[Tensor]): + grad_self = torch.stack(grad_outputs, dim) + return grad_self, None + + return torch.unbind(self, dim), backward + + def cat(tensors: List[Tensor], + dim: int=0): + size = len(tensors) + split_sizes = [0] * size + for i in range(size): + if tensors[i].numel() > 0: + split_sizes[i] = tensors[i].size()[dim] + + def backward(grad_output): + grad_tensors = torch.split_with_sizes(grad_output, split_sizes, dim) + return grad_tensors, None + + return torch.cat(tensors, dim), backward + + def index(self, + indices: List[Tensor]): + def backward(grad_output): + grad_self = torch.zeros_like(self).index_put_(indices, grad_output, True) + return grad_self, None + + return torch.index(self, indices), backward + + def meshgrid(tensors: List[Tensor]): + size = len(tensors) + sizes = [0] * size + for i in range(size): + if tensors[i].dim() != 0: + sizes[i] = tensors[i].size()[0] + def backward(grad_outputs: List[Tensor]): + grads_tensors = [] + for i in range(size): + view_shape = [1] * size + if sizes[i] == 0: + view_shape[i] = 1 + grads_tensors.append((grad_outputs[i]._grad_sum_to_size(view_shape)).reshape(())) + else: + view_shape[i] = sizes[i] + grads_tensors.append((grad_outputs[i]._grad_sum_to_size(view_shape)).reshape([sizes[i]])) + return grads_tensors + return torch.meshgrid(tensors), backward + def mv(self, vec): def backward(grad_output): return grad_output.ger(vec), self.t().mv(grad_output) -- 2.7.4
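
A rough sketch of the flatten/re-pack scheme that CaptureList and UnpackInstructions implement in graph_executor.cpp, written as plain eager Python for illustration only (the helper names below are made up and are not part of the patch): a TensorList is flattened into individual tensors when captured, together with an instruction recording how many tensors to re-group, and the backward pass replays those instructions to rebuild the List[Tensor] values the backward graph expects.

```python
import torch

def flatten_inputs(values):
    """values: list of Tensor or list-of-Tensor; returns (flat tensors, instructions)."""
    flat, insts = [], []
    for v in values:
        if isinstance(v, (list, tuple)):          # a TensorList input
            insts.append(("PUSH_LIST", len(v)))   # remember how many tensors to re-group
            flat.extend(v)
        else:                                     # a plain Tensor input
            insts.append(("PUSH_TENSOR", None))
            flat.append(v)
    return flat, insts

def unpack(flat, insts):
    """Rebuild the original Tensor / List[Tensor] structure from the flat tensor list."""
    it = iter(flat)
    repacked = []
    for op, size in insts:
        if op == "PUSH_TENSOR":
            repacked.append(next(it))
        else:  # PUSH_LIST: consume `size` tensors and re-group them into a list
            repacked.append([next(it) for _ in range(size)])
    return repacked

x, y = torch.randn(2, 2), torch.randn(2, 2)
flat, insts = flatten_inputs([x, [x, y]])
assert len(flat) == 3                      # one plain tensor + two list elements
repacked = unpack(flat, insts)
assert torch.equal(repacked[1][1], y)      # list structure is restored
```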
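The cat formula added to symbolic_script.cpp records each input's size along `dim` during the forward and splits `grad_output` with those sizes in the backward. A standalone eager-mode approximation of that formula (illustrative only; the real version is the TorchScript string above, and the method call here mirrors the torch.split_with_sizes used there):

```python
import torch

def cat_with_backward(tensors, dim=0):
    # record each input's extent along dim; empty tensors contribute 0,
    # mirroring the numel() > 0 check in the symbolic formula
    split_sizes = [t.size(dim) if t.numel() > 0 else 0 for t in tensors]

    def backward(grad_output):
        # one gradient per forward tensor input, plus None for the dim argument
        return grad_output.split_with_sizes(split_sizes, dim), None

    return torch.cat(tensors, dim), backward

a, b = torch.randn(2, 3), torch.randn(4, 3)
out, backward = cat_with_backward([a, b], dim=0)
grads, _ = backward(torch.ones_like(out))
assert [g.shape for g in grads] == [a.shape, b.shape]
```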
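stack and unbind are each other's backward in the formulas above. A quick eager-mode check against autograd (not going through the JIT), assuming a simple sum loss:

```python
import torch

x = torch.randn(2, 3, requires_grad=True)
y = torch.randn(2, 3, requires_grad=True)

# backward of stack along dim 0 is unbind of grad_output along dim 0
out = torch.stack((x, y), dim=0)
gx, gy = torch.autograd.grad(out.sum(), (x, y))
ux, uy = torch.unbind(torch.ones_like(out), dim=0)
assert torch.equal(gx, ux) and torch.equal(gy, uy)

# backward of unbind is stack of the per-slice grads
pieces = torch.unbind(x, dim=0)
g = torch.autograd.grad(sum(p.sum() for p in pieces), x)[0]
assert torch.equal(g, torch.stack([torch.ones(3), torch.ones(3)], dim=0))
```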