def _std():
def _impl(inputs, input_types):
data = inputs[0]
- axis = list(_infer_shape(inputs[1]))
- keepdims = bool(inputs[3])
- unbiased = bool(inputs[2])
+ if len(inputs) == 2:
+ axis = None
+ keepdims = False
+ unbiased = bool(inputs[1])
+ else:
+ axis = list(_infer_shape(inputs[1]))
+ keepdims = bool(inputs[3])
+ unbiased = bool(inputs[2])
if unbiased:
msg = "Currently only supports standard-deviation calculated via the biased "\
def forward(self, *args):
return args[0].std(dim=(2,3), keepdim=False, unbiased=False)
+ class Std6(Module):
+ def forward(self, *args):
+ return args[0].std(unbiased=False)
+
input_data = torch.rand(input_shape).float()
verify_model(Std1().float().eval(), input_data=input_data)
verify_model(Std2().float().eval(), input_data=input_data)
verify_model(Std3().float().eval(), input_data=input_data)
verify_model(Std4().float().eval(), input_data=input_data)
verify_model(Std5().float().eval(), input_data=input_data)
+ verify_model(Std6().float().eval(), input_data=input_data)
def test_forward_variance():