[PYTORCH] aten::norm support added (#5776)
author Samuel <siju.samuel@huawei.com>
Fri, 12 Jun 2020 17:16:52 +0000 (22:46 +0530)
committer GitHub <noreply@github.com>
Fri, 12 Jun 2020 17:16:52 +0000 (02:16 +0900)
python/tvm/relay/frontend/pytorch.py
tests/python/frontend/pytorch/test_forward.py

index 2113d7d..a9f4a7b 100644
--- a/python/tvm/relay/frontend/pytorch.py
+++ b/python/tvm/relay/frontend/pytorch.py
@@ -1184,6 +1184,44 @@ def _reduce(name):
 
     return _impl
 
+def _norm():
+    def _impl(inputs, input_types):
+        data = inputs[0]
+        axis = None
+        keepdims = False
+        if len(inputs) > 3:
+            axis = list(_infer_shape(inputs[2]))
+            keepdims = bool(inputs[3])
+
+        order = inputs[1]
+        if order == np.inf:
+            return _op.reduce.max(_op.abs(data), axis=axis, keepdims=keepdims)
+        elif order == -np.inf:
+            return _op.reduce.min(_op.abs(data), axis=axis, keepdims=keepdims)
+        else:
+            reci_order = _expr.const(1.0 / order)
+            order = _expr.const(order)
+            return _op.power(_op.reduce.sum(_op.power(_op.abs(data), order),
+                                            axis=axis,
+                                            keepdims=keepdims),
+                             reci_order)
+    return _impl
+
+
+def _frobenius_norm():
+    def _impl(inputs, input_types):
+        data = inputs[0]
+        axis = None
+        keepdims = False
+        if len(inputs) > 2:
+            axis = list(_infer_shape(inputs[1]))
+            keepdims = bool(inputs[2])
+
+        return _op.sqrt(_op.reduce.sum((data * data), axis=axis, keepdims=keepdims))
+
+    return _impl
+
+
 def _std():
     def _impl(inputs, input_types):
         data = inputs[0]
@@ -1853,6 +1891,8 @@ def _get_convert_map(prelude):
         "aten::prod"                            : _reduce("prod"),
         "aten::argmin"                          : _reduce("argmin"),
         "aten::argmax"                          : _reduce("argmax"),
+        "aten::norm"                            : _norm(),
+        "aten::frobenius_norm"                  : _frobenius_norm(),
         "aten::std"                             : _std(),
         "aten::var"                             : _variance(),
         "aten::abs"                             : _unary("abs"),
index c9c76be..86fb409 100644
--- a/tests/python/frontend/pytorch/test_forward.py
+++ b/tests/python/frontend/pytorch/test_forward.py
@@ -892,6 +892,91 @@ def test_forward_logsoftmax():
     input_data = torch.rand(input_shape).float()
     verify_model(LogSoftmax1().float().eval(), input_data=input_data)
 
+
+def test_forward_norm():
+    torch.set_grad_enabled(False)
+    input_shape = [1, 3, 10, 10]
+
+    class Norm1(Module):
+        def forward(self, *args):
+            return torch.norm(args[0], p=float('inf'), dim=None, keepdim=False)
+
+    class Norm2(Module):
+        def forward(self, *args):
+            return torch.norm(args[0], p=float('-inf'), dim=None, keepdim=False)
+
+    class Norm3(Module):
+        def forward(self, *args):
+            return torch.norm(args[0], p=float('-inf'), dim=None, keepdim=True)
+
+    class Norm4(Module):
+        def forward(self, *args):
+            return torch.norm(args[0], p=float('inf'), dim=(1, 2), keepdim=False)
+
+    class Norm5(Module):
+        def forward(self, *args):
+            return torch.norm(args[0], p=float('inf'), dim=1, keepdim=True)
+
+    class Norm6(Module):
+        def forward(self, *args):
+            return torch.norm(args[0], p=0.5, dim=1, keepdim=True)
+
+    class Norm7(Module):
+        def forward(self, *args):
+            return torch.norm(args[0], p=1.0, dim=None, keepdim=False)
+
+    class Norm8(Module):
+        def forward(self, *args):
+            return torch.norm(args[0], p=2.0, dim=1, keepdim=True)
+
+    class Norm9(Module):
+        def forward(self, *args):
+            return torch.norm(args[0], p=-0.5, dim=(1, 2), keepdim=True)
+
+    class Norm10(Module):
+        def forward(self, *args):
+            return torch.norm(args[0], p=-2.0, dim=1, keepdim=False)
+
+    input_data = torch.rand(input_shape).float()
+    verify_model(Norm1().float().eval(), input_data=input_data)
+    verify_model(Norm2().float().eval(), input_data=input_data)
+    verify_model(Norm3().float().eval(), input_data=input_data)
+    verify_model(Norm4().float().eval(), input_data=input_data)
+    verify_model(Norm5().float().eval(), input_data=input_data)
+    verify_model(Norm6().float().eval(), input_data=input_data)
+    verify_model(Norm7().float().eval(), input_data=input_data)
+    verify_model(Norm8().float().eval(), input_data=input_data)
+    verify_model(Norm9().float().eval(), input_data=input_data)
+    verify_model(Norm10().float().eval(), input_data=input_data)
+
+
+def test_forward_frobenius_norm():
+    torch.set_grad_enabled(False)
+    input_shape = [1, 3, 10, 10]
+
+    class FroNorm1(Module):
+        def forward(self, *args):
+            return torch.norm(args[0])
+
+    class FroNorm2(Module):
+        def forward(self, *args):
+            return torch.norm(args[0], p='fro', dim=None, keepdim=True)
+
+    class FroNorm3(Module):
+        def forward(self, *args):
+            return torch.norm(args[0], p='fro', dim=1, keepdim=True)
+
+    class FroNorm4(Module):
+        def forward(self, *args):
+            return torch.norm(args[0], dim=None, keepdim=False)
+
+    input_data = torch.rand(input_shape).float()
+    verify_model(FroNorm1().float().eval(), input_data=input_data)
+    verify_model(FroNorm2().float().eval(), input_data=input_data)
+    verify_model(FroNorm3().float().eval(), input_data=input_data)
+    verify_model(FroNorm4().float().eval(), input_data=input_data)
+
+
 def test_forward_sigmoid():
     torch.set_grad_enabled(False)
     input_shape = [1, 3, 10, 10]
@@ -2421,6 +2506,8 @@ if __name__ == "__main__":
     test_forward_reduce_prod()
     test_forward_argmin()
     test_forward_argmax()
+    test_forward_norm()
+    test_forward_frobenius_norm()
     test_forward_std()
     test_forward_variance()
     test_forward_relu()
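
Note: which converter fires depends on how PyTorch dispatches torch.norm.
On the 1.x builds this patch targets, a bare torch.norm(x) (default
p='fro') traces to aten::frobenius_norm, while an explicit numeric p
traces to aten::norm, which is why both converters are registered. A
quick, illustrative way to confirm the dispatch on a given build
(version-dependent, not part of the patch):

    import torch

    x = torch.rand(1, 3, 10, 10)
    fro = torch.jit.trace(lambda t: torch.norm(t), x)
    print(fro.graph)   # expect an aten::frobenius_norm node
    p2 = torch.jit.trace(lambda t: torch.norm(t, p=2.0, dim=1), x)
    print(p2.graph)    # expect an aten::norm node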