--- a/torch/csrc/jit/autodiff.cpp
+++ b/torch/csrc/jit/autodiff.cpp
@@ ... @@
bool needTrimGrad(Node* n) {
static OperatorSet need_trim_grad_ops = {
"aten::kthvalue(Tensor self, int k, int dim, bool keepdim) -> (Tensor, Tensor)",
- "aten::max_pool2d_with_indices(Tensor self, int[] kernel_size, int[] stride, int[] padding, int[] dilation, bool ceil_mode) -> (Tensor, Tensor)",
- "aten::thnn_conv2d_forward(Tensor self, Tensor weight, int[] kernel_size, Tensor? bias, int[] stride, int[] padding) -> (Tensor, Tensor, Tensor)",
"aten::topk(Tensor self, int k, int dim, bool largest, bool sorted) -> (Tensor, Tensor)",
};
  if (need_trim_grad_ops.find(n)) {
    return true;
  }
  return false;
}

@@ ... @@
      "aten::ne(Tensor self, Scalar other) -> Tensor",
      "aten::fmod(Tensor self, Scalar other) -> Tensor",
      "aten::remainder(Tensor self, Scalar other) -> Tensor",
+     "aten::max_pool2d_with_indices(Tensor self, int[] kernel_size, int[] stride, int[] padding, int[] dilation, bool ceil_mode) -> (Tensor, Tensor)",
+     "aten::thnn_conv2d_forward(Tensor self, Tensor weight, int[] kernel_size, Tensor? bias, int[] stride, int[] padding) -> (Tensor, Tensor, Tensor)",
      "aten::native_batch_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float momentum, float eps) -> (Tensor, Tensor, Tensor)",
  };
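The two ops leave need_trim_grad_ops and join differentiable_ops: instead of having their gradient trimmed, they now get explicit symbolic gradients (built in the branches below). In both cases only the first output is differentiable; the auxiliary outputs are buffers or integer indices. A minimal eager-mode sketch of the pooling case, using only the public functional API, shows why the extra output carries no gradient:

    import torch
    import torch.nn.functional as F

    x = torch.randn(1, 1, 4, 4, requires_grad=True)
    out, indices = F.max_pool2d(x, kernel_size=2, return_indices=True)
    print(out.requires_grad)      # True: gradients flow through the pooled values
    print(indices.requires_grad)  # False: integer indices are non-differentiable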
@@ ... @@
  } else if (
node->matches(
+ "aten::max_pool2d_with_indices(Tensor self, int[] kernel_size, int[] stride, int[] padding, int[] dilation, bool ceil_mode) -> (Tensor, Tensor)")) {
+ AT_ASSERT(grads.size() == 2);
+ auto graph = node->owningGraph();
+ auto backward_value = graph->insert(
+ aten::max_pool2d_with_indices_backward,
+ {grads.at(0).value(),
+ node->namedInput(attr::self),
+ node->namedInput(attr::kernel_size),
+ node->namedInput(attr::stride),
+ node->namedInput(attr::padding),
+ node->namedInput(attr::dilation),
+ node->namedInput(attr::ceil_mode),
+ outputs.at(1).value()});
+ return {backward_value->node()->output(0),
+ nullptr,
+ nullptr,
+ nullptr,
+ nullptr,
+ nullptr};
+
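The new branch maps the pooling gradient onto aten::max_pool2d_with_indices_backward, feeding it the indices the forward saved (outputs.at(1)); the five trailing nullptrs line up with the non-differentiable kernel_size, stride, padding, dilation, and ceil_mode inputs. A quick eager-mode parity check of that backward op against autograd (a sketch; it assumes the ATen ops are reachable through torch.ops.aten):

    import torch

    x = torch.randn(1, 1, 4, 4, requires_grad=True)
    out, idx = torch.ops.aten.max_pool2d_with_indices(
        x, [2, 2], [2, 2], [0, 0], [1, 1], False)
    go = torch.randn_like(out)

    # Gradient from the fused backward op, given the saved indices...
    manual = torch.ops.aten.max_pool2d_with_indices_backward(
        go, x, [2, 2], [2, 2], [0, 0], [1, 1], False, idx)
    # ...should match what autograd derives for the same forward.
    auto_grad, = torch.autograd.grad(out, x, go)
    assert torch.allclose(manual, auto_grad)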
+ } else if (
+ node->matches(
+ "aten::thnn_conv2d_forward(Tensor self, Tensor weight, int[] kernel_size, Tensor? bias, int[] stride, int[] padding) -> (Tensor, Tensor, Tensor)")) {
+ auto graph = node->owningGraph();
+ auto backward_value = graph->insert(
+ aten::thnn_conv2d_backward,
+ {grads.at(0).value(),
+ inputs.at(0).value(),
+ inputs.at(1).value(),
+ node->namedInput(attr::kernel_size),
+ node->namedInput(attr::stride),
+ node->namedInput(attr::padding),
+ outputs.at(1).value(),
+ outputs.at(2).value(),
+ graph->insertConstant(std::vector<bool>{true, true, true})});
+ // graph->insert returns a tuple automatically if multiple outputs are
+ // returned. So unpack them again.
+ Node* tuple_unpack_node =
+ graph->insertNode(graph->createTupleUnpack(backward_value));
+ auto tuple_outputs = tuple_unpack_node->outputs();
+ AT_ASSERT(tuple_outputs.size() == size_t(3));
+ return {tuple_outputs[0],
+ tuple_outputs[1],
+ nullptr,
+ tuple_outputs[2],
+ nullptr,
+ nullptr};
+
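The convolution branch works the same way: thnn_conv2d_backward consumes the two buffers the forward saved (outputs 1 and 2) plus an output mask selecting which of grad_input, grad_weight, and grad_bias to compute; {true, true, true} requests all three, and the resulting tuple is unpacked back into the gradient slots. A parity sketch, with the caveat that the thnn_conv2d_* ops belong to PyTorch builds of this vintage and were later renamed, so the op names are version-dependent:

    import torch

    x = torch.randn(1, 2, 5, 5, requires_grad=True)
    w = torch.randn(3, 2, 3, 3, requires_grad=True)
    b = torch.randn(3, requires_grad=True)

    out, finput, fgrad_input = torch.ops.aten.thnn_conv2d_forward(
        x, w, [3, 3], b, [1, 1], [0, 0])
    go = torch.randn_like(out)

    # output_mask [True, True, True] mirrors the
    # insertConstant(std::vector<bool>{true, true, true}) above.
    gi, gw, gb = torch.ops.aten.thnn_conv2d_backward(
        go, x, w, [3, 3], [1, 1], [0, 0], finput, fgrad_input,
        [True, True, True])
    for manual, expected in zip((gi, gw, gb),
                                torch.autograd.grad(out, (x, w, b), go)):
        assert torch.allclose(manual, expected, atol=1e-5)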
+ } else if (
+ node->matches(
"aten::native_batch_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float momentum, float eps) -> (Tensor, Tensor, Tensor)")) {
auto graph = node->owningGraph();
auto backward_value = graph->insert(

--- a/torch/csrc/jit/symbolic_script.cpp
+++ b/torch/csrc/jit/symbolic_script.cpp
@@ ... @@
            return result, backward
- def max_pool2d_with_indices(self,
- kernel_size: List[int],
- stride: List[int],
- padding: List[int],
- dilation: List[int],
- ceil_mode: bool):
- result0, result1 = torch.max_pool2d_with_indices(self, kernel_size, stride, padding, dilation, ceil_mode)
- def backward(grad_output):
- grad_self = torch.max_pool2d_with_indices_backward(grad_output, self, kernel_size, stride, padding, dilation, ceil_mode, result1)
- return grad_self, None, None, None, None, None
-
- return result0, result1, backward
-
def nll_loss(self, target, weight: Optional[Tensor], reduction: int, ignore_index: int):
result, total_weight = torch.nll_loss_forward(self, target, weight, reduction, ignore_index)
def backward(grad):
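With the script definitions deleted, the C++ branches above become the only differentiation path for these ops, so scripted and eager gradients should stay in lockstep. A small end-to-end check (a sketch; how the JIT chooses to differentiate varies across versions, but backward() must agree either way):

    import torch
    import torch.nn.functional as F

    def f(x):
        return F.max_pool2d(x, 2).sum()

    scripted = torch.jit.script(f)

    x = torch.randn(1, 1, 4, 4)
    grads = []
    for fn in (f, scripted):
        xi = x.clone().requires_grad_()
        fn(xi).backward()
        grads.append(xi.grad)
    assert torch.allclose(grads[0], grads[1])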
@@ ... @@
            return torch.softmax(self, dim, dtype), backward
- def thnn_conv2d_forward(self, weight,
- kernel_size: List[int],
- bias: Optional[Tensor],
- stride: List[int],
- padding: List[int]):
- result0, result1, result2 = torch.thnn_conv2d_forward(self, weight, kernel_size, bias, stride, padding)
- def backward(grad_output):
- grad_self, grad_weight, grad_bias = torch.thnn_conv2d_backward(grad_output, self, weight, kernel_size, stride, padding, result1, result2, [True, True, True])
- return grad_self, grad_weight, None, grad_bias, None, None
-
- return result0, result1, result2, backward
-
def AD_interpolate_backward(grad,
input,
mode: str,