Shape Propagation Pass: Fix AdaptiveAveragePooling2d (#63629)
author Priya Ramani <priyaramani@fb.com>
Wed, 25 Aug 2021 20:08:12 +0000 (13:08 -0700)
committer Facebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Wed, 25 Aug 2021 20:13:41 +0000 (13:13 -0700)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/63629
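
The adaptive_avg_pool2d shape function used to return only the two
requested spatial dimensions ([out[0], out[1]]), dropping the leading
batch/channel dimensions of the input. It now asserts that the input
is 3D or 4D with non-zero non-batch dimensions, keeps the leading
dimensions, and appends the requested output size. A rough Python
equivalent of the new rule (an illustrative sketch only, not the
registered TorchScript shape function):

    from typing import List

    def adaptive_avg_pool2d_shape(self: List[int], out: List[int]) -> List[int]:
        assert len(out) == 2                    # output size is always (H, W)
        assert len(self) in (3, 4)              # CHW or NCHW input
        for i in range(1, len(self)):
            assert self[i] != 0                 # non-batch dims must be non-zero
        # keep leading (batch/channel) dims, swap in the output size
        return self[:-2] + out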

Test Plan: Imported from OSS

Reviewed By: ZolotukhinM

Differential Revision: D30461727

Pulled By: priyaramani

fbshipit-source-id: 3873d1d636f79185680b82de06174d8de288c941

test/jit/test_symbolic_shape_analysis.py
torch/csrc/jit/runtime/symbolic_shape_registry.cpp

diff --git a/test/jit/test_symbolic_shape_analysis.py b/test/jit/test_symbolic_shape_analysis.py
index 33dc515..6d4e33c 100644
--- a/test/jit/test_symbolic_shape_analysis.py
+++ b/test/jit/test_symbolic_shape_analysis.py
@@ -3,7 +3,6 @@ from torch.testing._internal.jit_utils import JitTestCase
 import operator
 
 from torch.testing import FileCheck
-from typing import List
 
 
 if __name__ == '__main__':
@@ -60,15 +59,6 @@ class TestSymbolicShapeAnalysis(JitTestCase):
         self.assertEqual(output_shape[1], sym2)
         self.assertEqual(output_shape[2], sym3)
 
-    def test_sharing_of_list_len(self):
-        @torch.jit.script
-        def foo(x, out: List[int]):
-            return torch.nn.functional.adaptive_avg_pool2d(x, out)
-
-        self.run_pass("inline", foo.graph)
-        torch._C._jit_pass_propagate_shapes_on_graph(foo.graph)
-        FileCheck().check("Tensor(*, *)").check_same("adaptive_avg_pool2d").run(foo.graph)
-
     def test_shared_shape_graph(self):
         @torch.jit.script
         def foo(x, y):
@@ -165,3 +155,25 @@ class TestSymbolicShapeAnalysis(JitTestCase):
             inputs[1].setType(inputs[1].type().with_sizes([5, 8, sym1]))
             torch._C._jit_pass_propagate_shapes_on_graph(graph)
             self.assertEqual(next(graph.outputs()).type().symbolic_sizes(), [5, 8, sym1])
+
+    def test_adaptive_avg_pool2d(self):
+        inps = [
+            [(1, 64, 8, 9), (5, 7)],
+            [(1, 64, 10, 9), (7)],  # (7) is a bare int, giving a square (7, 7) output
+            [(1, 64, 10, 9), (5, None)],
+            [(1, 8, 4, 3), (None, None)],
+            [(1, 8, 4, 3), (None, 5)],
+        ]
+
+        for inp in inps:
+            t = torch.randn(*inp[0])
+            out_size = torch.nn.functional.adaptive_avg_pool2d(t, inp[1]).size()
+
+            def foo(x):
+                return torch.nn.functional.adaptive_avg_pool2d(x, inp[1])
+
+            fn = torch.jit.trace(foo, (t,))
+            torch._C._jit_erase_non_input_shape_information(fn.graph)
+            torch._C._jit_pass_peephole(fn.graph)
+            torch._C._jit_pass_constant_propagation(fn.graph)
+            self.checkShapeAnalysis(out_size, fn.graph, assert_propagation=True)
diff --git a/torch/csrc/jit/runtime/symbolic_shape_registry.cpp b/torch/csrc/jit/runtime/symbolic_shape_registry.cpp
index ffc2f44..d447199 100644
--- a/torch/csrc/jit/runtime/symbolic_shape_registry.cpp
+++ b/torch/csrc/jit/runtime/symbolic_shape_registry.cpp
@@ -36,10 +36,19 @@ const std::string shape_compute_functions =
           return expandedSizes
 
         def adaptive_avg_pool2d(self: List[int], out: List[int]):
-          # TODO: return out directly, list len refiner would need to
-          # annotate the List Type with len directly in IR
           assert len(out) == 2
-          return [out[0], out[1]]
+          assert len(self) == 3 or len(self) == 4
+          # non-batch input dims must be non-zero
+          for i in range(1, len(self)):
+            assert self[i] != 0
+
+          # keep leading (batch/channel) dims, append the output size
+          shape: List[int] = []
+          for i in range(0, len(self) - 2):
+            shape.append(self[i])
+          for elem in out:
+            shape.append(elem)
+          return shape
 
         # TODO: maybe make it customary that extra arguments are unused ?
         # TODO: return self directly
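
For reference, the new rule can be sanity-checked against eager mode
(a minimal sketch; the shapes are taken from the test cases added
above):

    import torch

    t = torch.randn(1, 64, 8, 9)
    out = torch.nn.functional.adaptive_avg_pool2d(t, (5, 7))
    # leading batch/channel dims are kept, spatial dims are replaced
    assert list(out.size()) == [1, 64, 5, 7]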