From c1eb31566ac7321809f4b9734df97edf378573f6 Mon Sep 17 00:00:00 2001 From: shiwenloong <52487098+shiwenloong@users.noreply.github.com> Date: Fri, 7 Aug 2020 08:55:46 +0800 Subject: [PATCH] [PYTORCH]Std op without specified dimensions support (#6226) --- python/tvm/relay/frontend/pytorch.py | 11 ++++++++--- tests/python/frontend/pytorch/test_forward.py | 5 +++++ 2 files changed, 13 insertions(+), 3 deletions(-) diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index 3dfdb2f..bbc684e 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -1253,9 +1253,14 @@ def _frobenius_norm(): def _std(): def _impl(inputs, input_types): data = inputs[0] - axis = list(_infer_shape(inputs[1])) - keepdims = bool(inputs[3]) - unbiased = bool(inputs[2]) + if len(inputs) == 2: + axis = None + keepdims = False + unbiased = bool(inputs[1]) + else: + axis = list(_infer_shape(inputs[1])) + keepdims = bool(inputs[3]) + unbiased = bool(inputs[2]) if unbiased: msg = "Currently only supports standard-deviation calculated via the biased "\ diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py index e370cd5..3c9dfb1 100644 --- a/tests/python/frontend/pytorch/test_forward.py +++ b/tests/python/frontend/pytorch/test_forward.py @@ -1869,12 +1869,17 @@ def test_forward_std(): def forward(self, *args): return args[0].std(dim=(2,3), keepdim=False, unbiased=False) + class Std6(Module): + def forward(self, *args): + return args[0].std(unbiased=False) + input_data = torch.rand(input_shape).float() verify_model(Std1().float().eval(), input_data=input_data) verify_model(Std2().float().eval(), input_data=input_data) verify_model(Std3().float().eval(), input_data=input_data) verify_model(Std4().float().eval(), input_data=input_data) verify_model(Std5().float().eval(), input_data=input_data) + verify_model(Std6().float().eval(), input_data=input_data) def test_forward_variance(): -- 2.7.4