[PYTORCH]Std op without specified dimensions support (#6226)
authorshiwenloong <52487098+shiwenloong@users.noreply.github.com>
Fri, 7 Aug 2020 00:55:46 +0000 (08:55 +0800)
committerGitHub <noreply@github.com>
Fri, 7 Aug 2020 00:55:46 +0000 (09:55 +0900)
python/tvm/relay/frontend/pytorch.py
tests/python/frontend/pytorch/test_forward.py

index 3dfdb2f..bbc684e 100644 (file)
@@ -1253,9 +1253,14 @@ def _frobenius_norm():
 def _std():
     def _impl(inputs, input_types):
         data = inputs[0]
-        axis = list(_infer_shape(inputs[1]))
-        keepdims = bool(inputs[3])
-        unbiased = bool(inputs[2])
+        if len(inputs) == 2:
+            axis = None
+            keepdims = False
+            unbiased = bool(inputs[1])
+        else:
+            axis = list(_infer_shape(inputs[1]))
+            keepdims = bool(inputs[3])
+            unbiased = bool(inputs[2])
 
         if unbiased:
             msg = "Currently only supports standard-deviation calculated via the biased "\
index e370cd5..3c9dfb1 100644 (file)
@@ -1869,12 +1869,17 @@ def test_forward_std():
         def forward(self, *args):
             return args[0].std(dim=(2,3), keepdim=False, unbiased=False)
 
+    class Std6(Module):
+        def forward(self, *args):
+            return args[0].std(unbiased=False)
+
     input_data = torch.rand(input_shape).float()
     verify_model(Std1().float().eval(), input_data=input_data)
     verify_model(Std2().float().eval(), input_data=input_data)
     verify_model(Std3().float().eval(), input_data=input_data)
     verify_model(Std4().float().eval(), input_data=input_data)
     verify_model(Std5().float().eval(), input_data=input_data)
+    verify_model(Std6().float().eval(), input_data=input_data)
 
 
 def test_forward_variance():