From de0869de10bd353fc0daf6872a68677c6154483e Mon Sep 17 00:00:00 2001 From: pyjhzwh Date: Fri, 6 Mar 2020 19:39:33 -0500 Subject: [PATCH] Fix stride default value None in torch.nn.functional.avg_pool (#4984) * fix unordered dictionary problem for python version 3.5 * modify style * default value of stride in torch.nn.functional.avg_pool is None * delete prev modifications * add testcase for nn.functional.avg_pool2d --- python/tvm/relay/frontend/pytorch.py | 5 ++++- tests/python/frontend/pytorch/test_forward.py | 5 +++++ 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index 1bdcf0a..5716837 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -470,7 +470,10 @@ def _avg_pool2d(): data = inputs[0] pool_size = _infer_shape(inputs[1]) - strides = _infer_shape(inputs[2]) + if inputs[2]: + strides = _infer_shape(inputs[2]) + else: + strides = pool_size padding = _infer_shape(inputs[3]) ceil_mode = int(inputs[4]) diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py index 641f5c9..eed47ea 100644 --- a/tests/python/frontend/pytorch/test_forward.py +++ b/tests/python/frontend/pytorch/test_forward.py @@ -375,8 +375,13 @@ def test_forward_avgpool(): def forward(self, *args): return torch.nn.AvgPool2d(kernel_size=[10, 10])(args[0]) + class AvgPool2D2(Module): + def forward(self, *args): + return torch.nn.functional.avg_pool2d(args[0], kernel_size=[10, 10]) + input_data = torch.rand(input_shape).float() verify_model(AvgPool2D1().float().eval(), input_data=input_data) + verify_model(AvgPool2D2().float().eval(), input_data=input_data) def test_forward_hardtanh(): torch.set_grad_enabled(False) -- 2.7.4