move 2 ops back to autodiff (#18969)
author  Ailing Zhang <ailzhang@fb.com>
Sun, 7 Apr 2019 04:36:22 +0000 (21:36 -0700)
committer  Facebook Github Bot <facebook-github-bot@users.noreply.github.com>
Sun, 7 Apr 2019 04:41:25 +0000 (21:41 -0700)
Summary:
Move these two ops back to autodiff to unblock the XLA CI.
I will leave the cleanup of symbolic_variable to my next PR.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/18969

Differential Revision: D14816811

Pulled By: ailzhang

fbshipit-source-id: dd8a7e133dcad29560d3d1d25691883960117299
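
For illustration only, a rough sketch (not part of this change) of how the moved
max_pool2d_with_indices formula can be exercised from Python. The function name
`pool` and the shapes below are made up, and whether the backward pass actually
routes through the JIT's symbolic gradient depends on graph-executor details:

    import torch
    import torch.nn.functional as F

    @torch.jit.script
    def pool(x):
        # shows up as aten::max_pool2d_with_indices in the scripted graph
        out, indices = F.max_pool2d_with_indices(x, kernel_size=2)
        return out

    x = torch.randn(1, 1, 4, 4, requires_grad=True)
    pool(x).sum().backward()
    print(x.grad.shape)  # torch.Size([1, 1, 4, 4])

thnn_conv2d_forward is moved in the same way; per the diff below, its autodiff
entry returns gradients for self, weight, and bias and nullptr for the
non-tensor arguments.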

torch/csrc/jit/autodiff.cpp
torch/csrc/jit/symbolic_script.cpp

torch/csrc/jit/autodiff.cpp
index 094732c..922681e 100644
@@ -36,8 +36,6 @@ void wrapDim(int64_t& dim, const std::vector<int64_t>& sizes) {
 bool needTrimGrad(Node* n) {
   static OperatorSet need_trim_grad_ops = {
       "aten::kthvalue(Tensor self, int k, int dim, bool keepdim) -> (Tensor, Tensor)",
-      "aten::max_pool2d_with_indices(Tensor self, int[] kernel_size, int[] stride, int[] padding, int[] dilation, bool ceil_mode) -> (Tensor, Tensor)",
-      "aten::thnn_conv2d_forward(Tensor self, Tensor weight, int[] kernel_size, Tensor? bias, int[] stride, int[] padding) -> (Tensor, Tensor, Tensor)",
       "aten::topk(Tensor self, int k, int dim, bool largest, bool sorted) -> (Tensor, Tensor)",
   };
   if (need_trim_grad_ops.find(n)) {
@@ -66,6 +64,8 @@ bool isDifferentiable(Node* n) {
       "aten::ne(Tensor self, Scalar other) -> Tensor",
       "aten::fmod(Tensor self, Scalar other) -> Tensor",
       "aten::remainder(Tensor self, Scalar other) -> Tensor",
+      "aten::max_pool2d_with_indices(Tensor self, int[] kernel_size, int[] stride, int[] padding, int[] dilation, bool ceil_mode) -> (Tensor, Tensor)",
+      "aten::thnn_conv2d_forward(Tensor self, Tensor weight, int[] kernel_size, Tensor? bias, int[] stride, int[] padding) -> (Tensor, Tensor, Tensor)",
       "aten::native_batch_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float momentum, float eps) -> (Tensor, Tensor, Tensor)",
   };
 
@@ -347,6 +347,56 @@ class GradientHelper {
 
     } else if (
         node->matches(
+            "aten::max_pool2d_with_indices(Tensor self, int[] kernel_size, int[] stride, int[] padding, int[] dilation, bool ceil_mode) -> (Tensor, Tensor)")) {
+      AT_ASSERT(grads.size() == 2);
+      auto graph = node->owningGraph();
+      auto backward_value = graph->insert(
+          aten::max_pool2d_with_indices_backward,
+          {grads.at(0).value(),
+           node->namedInput(attr::self),
+           node->namedInput(attr::kernel_size),
+           node->namedInput(attr::stride),
+           node->namedInput(attr::padding),
+           node->namedInput(attr::dilation),
+           node->namedInput(attr::ceil_mode),
+           outputs.at(1).value()});
+      return {backward_value->node()->output(0),
+              nullptr,
+              nullptr,
+              nullptr,
+              nullptr,
+              nullptr};
+
+    } else if (
+        node->matches(
+            "aten::thnn_conv2d_forward(Tensor self, Tensor weight, int[] kernel_size, Tensor? bias, int[] stride, int[] padding) -> (Tensor, Tensor, Tensor)")) {
+      auto graph = node->owningGraph();
+      auto backward_value = graph->insert(
+          aten::thnn_conv2d_backward,
+          {grads.at(0).value(),
+           inputs.at(0).value(),
+           inputs.at(1).value(),
+           node->namedInput(attr::kernel_size),
+           node->namedInput(attr::stride),
+           node->namedInput(attr::padding),
+           outputs.at(1).value(),
+           outputs.at(2).value(),
+           graph->insertConstant(std::vector<bool>{true, true, true})});
+      // graph->insert returns a tuple automatically if multiple outputs are
+      // returned. So unpack them again.
+      Node* tuple_unpack_node =
+          graph->insertNode(graph->createTupleUnpack(backward_value));
+      auto tuple_outputs = tuple_unpack_node->outputs();
+      AT_ASSERT(tuple_outputs.size() == size_t(3));
+      return {tuple_outputs[0],
+              tuple_outputs[1],
+              nullptr,
+              tuple_outputs[2],
+              nullptr,
+              nullptr};
+
+    } else if (
+        node->matches(
             "aten::native_batch_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float momentum, float eps) -> (Tensor, Tensor, Tensor)")) {
       auto graph = node->owningGraph();
       auto backward_value = graph->insert(
torch/csrc/jit/symbolic_script.cpp
index 170471c..55a795f 100644
@@ -1055,19 +1055,6 @@ const std::vector<std::string> functions = {
 
             return result, backward
 
-        def max_pool2d_with_indices(self,
-                                    kernel_size: List[int],
-                                    stride: List[int],
-                                    padding: List[int],
-                                    dilation: List[int],
-                                    ceil_mode: bool):
-            result0, result1 = torch.max_pool2d_with_indices(self, kernel_size, stride, padding, dilation, ceil_mode)
-            def backward(grad_output):
-                grad_self = torch.max_pool2d_with_indices_backward(grad_output, self, kernel_size, stride, padding, dilation, ceil_mode, result1)
-                return grad_self, None, None, None, None, None
-
-            return result0, result1, backward
-
         def nll_loss(self, target, weight: Optional[Tensor], reduction: int, ignore_index: int):
             result, total_weight = torch.nll_loss_forward(self, target, weight, reduction, ignore_index)
             def backward(grad):
@@ -1090,18 +1077,6 @@ const std::vector<std::string> functions = {
 
             return torch.softmax(self, dim, dtype), backward
 
-        def thnn_conv2d_forward(self, weight,
-                                kernel_size: List[int],
-                                bias: Optional[Tensor],
-                                stride: List[int],
-                                padding: List[int]):
-            result0, result1, result2 = torch.thnn_conv2d_forward(self, weight, kernel_size, bias, stride, padding)
-            def backward(grad_output):
-                grad_self, grad_weight, grad_bias = torch.thnn_conv2d_backward(grad_output, self, weight, kernel_size, stride, padding, result1, result2, [True, True, True])
-                return grad_self, grad_weight, None, grad_bias, None, None
-
-            return result0, result1, result2, backward
-
         def AD_interpolate_backward(grad,
                                     input,
                                     mode: str,