From e039c8755fddfdd14da16a7c3c3c2424f68cddd4 Mon Sep 17 00:00:00 2001 From: wjliu Date: Thu, 6 Aug 2020 04:02:36 +0800 Subject: [PATCH] match pytorch 1.6 googlenet pretrained model (#6201) (#6212) --- tests/python/frontend/pytorch/test_forward.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py index ab9cca1..e370cd5 100644 --- a/tests/python/frontend/pytorch/test_forward.py +++ b/tests/python/frontend/pytorch/test_forward.py @@ -68,7 +68,11 @@ def load_torchvision(model_name): for channel in range(3): input_data[:, channel] -= mean[channel] input_data[:, channel] /= std[channel] - model = getattr(torchvision.models, model_name)(pretrained=True) + + if model_name.startswith("googlenet"): + model = getattr(torchvision.models, model_name)(pretrained=True, aux_logits=True) + else: + model = getattr(torchvision.models, model_name)(pretrained=True) model = model.float().eval() return model, [input_data] -- 2.7.4