import operator
from torch.testing import FileCheck
-from typing import List
if __name__ == '__main__':
self.assertEqual(output_shape[1], sym2)
self.assertEqual(output_shape[2], sym3)
- def test_sharing_of_list_len(self):
- @torch.jit.script
- def foo(x, out: List[int]):
- return torch.nn.functional.adaptive_avg_pool2d(x, out)
-
- self.run_pass("inline", foo.graph)
- torch._C._jit_pass_propagate_shapes_on_graph(foo.graph)
- FileCheck().check("Tensor(*, *)").check_same("adaptive_avg_pool2d").run(foo.graph)
-
def test_shared_shape_graph(self):
@torch.jit.script
def foo(x, y):
inputs[1].setType(inputs[1].type().with_sizes([5, 8, sym1]))
torch._C._jit_pass_propagate_shapes_on_graph(graph)
self.assertEqual(next(graph.outputs()).type().symbolic_sizes(), [5, 8, sym1])
+
+ def test_adaptive_avg_pool2d(self):
+ inps = [
+ [(1, 64, 8, 9), (5, 7)],
+ [(1, 64, 10, 9), (7)],
+ [(1, 64, 10, 9), (5, None)],
+ [(1, 8, 4, 3), (None, None)],
+ [(1, 8, 4, 3), (None, 5)],
+ ]
+
+ for inp in inps:
+ t = torch.randn(*inp[0])
+ out_size = torch.nn.functional.adaptive_avg_pool2d(t, inp[1]).size()
+
+ def foo(x):
+ return torch.nn.functional.adaptive_avg_pool2d(x, inp[1])
+
+ fn = torch.jit.trace(foo, (t,))
+ torch._C._jit_erase_non_input_shape_information(fn.graph)
+ torch._C._jit_pass_peephole(fn.graph)
+ torch._C._jit_pass_constant_propagation(fn.graph)
+ self.checkShapeAnalysis(out_size, fn.graph, assert_propagation=True)
return expandedSizes
def adaptive_avg_pool2d(self: List[int], out: List[int]):
    """Shape function for adaptive_avg_pool2d.

    `self` is the input tensor's size list (3-D CHW or 4-D NCHW); `out` is
    the requested 2-D output size.  Returns the output size list: the
    leading (batch/channel) dims of `self` followed by the two entries of
    `out`.
    """
    assert len(out) == 2
    rank = len(self)
    assert rank == 3 or rank == 4
    # Every non-batch input dimension must be non-empty.
    for d in range(1, rank):
        assert self[d] != 0

    shape: List[int] = []
    for d in range(0, rank - 2):
        shape.append(self[d])
    for size in out:
        shape.append(size)
    return shape
# TODO: maybe make it customary that extra arguments are unused?
# TODO: return self directly