[RELAY][PYTORCH]isNan, isinf, isfinite, ceil, clamp, round ops (#5316)
authorSamuel <siju.samuel@huawei.com>
Tue, 14 Apr 2020 09:45:02 +0000 (15:15 +0530)
committerGitHub <noreply@github.com>
Tue, 14 Apr 2020 09:45:02 +0000 (18:45 +0900)
* [RELAY][PYTORCH]isNan, isinf, isfinite, ceil, clamp, round ops

* Review comments

docs/frontend/tensorflow.rst
python/tvm/relay/frontend/pytorch.py
python/tvm/relay/op/_tensor.py
python/tvm/relay/op/tensor.py
src/relay/op/tensor/unary.cc
tests/python/frontend/pytorch/test_forward.py

index 45db9e4..a158db9 100644 (file)
@@ -162,6 +162,7 @@ Supported Ops
 - Identity
 - IsFinite
 - IsInf
+- IsNan
 - LeakyRelu
 - LeftShift
 - Less
index 18868cf..38a811d 100644 (file)
@@ -1118,12 +1118,45 @@ def _sqrt():
         return _op.tensor.sqrt(data)
     return _impl
 
+
+def _rsqrt():
+    def _impl(inputs, input_types):
+        data = inputs[0]
+        return _op.tensor.rsqrt(data)
+    return _impl
+
+
+def _ceil():
+    def _impl(inputs, input_types):
+        data = inputs[0]
+        return _op.ceil(data)
+    return _impl
+
+
+def _clamp():
+    def _impl(inputs, input_types):
+        print(inputs, input_types)
+        data = inputs[0]
+        amin = inputs[1] if inputs[1] else np.finfo(np.float32).min
+        amax = inputs[2] if inputs[2] else np.finfo(np.float32).max
+        return _op.clip(data, amin, amax)
+    return _impl
+
+
 def _floor():
     def _impl(inputs, input_types):
         data = inputs[0]
         return _op.floor(data)
     return _impl
 
+
+def _round():
+    def _impl(inputs, input_types):
+        data = inputs[0]
+        return _op.round(data)
+    return _impl
+
+
 def _to():
     def _impl(inputs, input_types):
         data = inputs[0]
@@ -1232,6 +1265,18 @@ def _mm():
     return _impl
 
 
+def _isfinite():
+    def _impl(inputs, input_types):
+        return _op.isfinite(inputs[0])
+    return _impl
+
+
+def _isnan():
+    def _impl(inputs, input_types):
+        return _op.isnan(inputs[0])
+    return _impl
+
+
 def _list_getitem(prelude):
     def _impl(inputs, input_types):
         return prelude.nth(inputs[0], _wrap_const(inputs[1]))
@@ -1429,7 +1474,11 @@ def _get_convert_map(prelude):
         "aten::std"                             : _std(),
         "aten::var"                             : _variance(),
         "aten::sqrt"                            : _sqrt(),
-        'aten::floor'                           : _floor(),
+        "aten::rsqrt"                           : _rsqrt(),
+        "aten::ceil"                            : _ceil(),
+        "aten::clamp"                           : _clamp(),
+        "aten::floor"                           : _floor(),
+        "aten::round"                           : _round(),
         "aten::detach"                          : _identity(),
         "aten::upsample_bilinear2d"             : _upsample("bilinear"),
         "aten::upsample_nearest2d"              : _upsample("nearest_neighbor"),
@@ -1439,6 +1488,9 @@ def _get_convert_map(prelude):
         "aten::le"                              : _elemwise("less_equal"),
         "aten::ge"                              : _elemwise("greater_equal"),
         "aten::ne"                              : _elemwise("not_equal"),
+        "aten::eq"                              : _elemwise("equal"),
+        "aten::isfinite"                        : _isfinite(),
+        "aten::isnan"                           : _isnan(),
         "aten::Bool"                            : _Bool(),
         "aten::Float"                           : _Float(),
         "aten::neg"                             : _neg(),
index a607a47..79a623d 100644 (file)
@@ -66,6 +66,7 @@ register_broadcast_schedule("less")
 register_broadcast_schedule("less_equal")
 register_broadcast_schedule("greater")
 register_broadcast_schedule("greater_equal")
+register_broadcast_schedule("isnan")
 register_broadcast_schedule("isfinite")
 register_broadcast_schedule("isinf")
 register_injective_schedule("maximum")
index 1f481ee..f602407 100644 (file)
@@ -1010,6 +1010,22 @@ def ndarray_size(data, dtype="int32"):
     return _make.ndarray_size(data, dtype)
 
 
+def isnan(data):
+    """Check nan in input data element-wise.
+
+    Parameters
+    ----------
+    data : relay.Expr
+        The input data
+
+    Returns
+    -------
+    result : relay.Expr
+        The computed result.
+    """
+    return _make.isnan(data)
+
+
 def isfinite(data):
     """Compute element-wise finiteness of data.
 
index 4cca8b0..10da11d 100644 (file)
@@ -426,6 +426,15 @@ ElemwiseArbitraryLayout)
 .set_support_level(10)
 .set_attr<FTVMCompute>("FTVMCompute", NdarraySizeCompute);
 
+RELAY_REGISTER_UNARY_OP("isnan")
+.describe(R"code(Returns whether the input contains any NaN, computed element-wise.
+.. math::
+   isnan(x)
+)code" TVM_ADD_FILELINE)
+.set_support_level(3)
+.add_type_rel("IdentityCompRel", IdentityCompRel)
+.set_attr<FTVMCompute>("FTVMCompute", RELAY_UNARY_COMPUTE(topi::isnan));
+
 RELAY_REGISTER_UNARY_OP("isfinite")
 .describe(R"code(Returns the finiteness of input, computed element-wise.
 .. math::
@@ -438,7 +447,7 @@ RELAY_REGISTER_UNARY_OP("isfinite")
 RELAY_REGISTER_UNARY_OP("isinf")
 .describe(R"code(Returns the infiniteness of input, computed element-wise.
 .. math::
-   isfinite(x)
+   isinf(x)
 )code" TVM_ADD_FILELINE)
 .set_support_level(3)
 .add_type_rel("IdentityCompRel", IdentityCompRel)
index 91e14c6..d9d280f 100644 (file)
@@ -1441,6 +1441,110 @@ def test_forward_variance():
     verify_model(Variance5().float().eval(), input_data=input_data)
 
 
+
+def test_forward_isfinite():
+    torch.set_grad_enabled(False)
+
+    class IsFinite1(Module):
+        def forward(self, *args):
+            return torch.isfinite(args[0])
+
+    input_data = torch.tensor([1, float('inf'), 2, float('-inf'), float('nan')]).float()
+    verify_model(IsFinite1().float().eval(), input_data=input_data)
+
+
+def test_forward_isnan():
+    torch.set_grad_enabled(False)
+
+    class IsNan1(Module):
+        def forward(self, *args):
+            return torch.isnan(args[0])
+
+    input_data = torch.tensor([1, float('inf'), 2, float('-inf'), float('nan')]).float()
+    verify_model(IsNan1().float().eval(), input_data=input_data)
+
+
+def test_forward_isinf():
+    torch.set_grad_enabled(False)
+
+    class IsInf1(Module):
+        def forward(self, *args):
+            return torch.isinf(args[0])
+
+    input_data = torch.tensor([1, float('inf'), 2, float('-inf'), float('nan')]).float()
+    verify_model(IsInf1().float().eval(), input_data=input_data)
+
+
+def test_forward_rsqrt():
+    torch.set_grad_enabled(False)
+    input_shape = [1, 3, 10, 10]
+
+    class Rsqrt1(Module):
+        def forward(self, *args):
+            return torch.rsqrt(args[0])
+
+    input_data = torch.rand(input_shape).float()
+    verify_model(Rsqrt1().float().eval(), input_data=input_data)
+
+
+def test_forward_ceil():
+    torch.set_grad_enabled(False)
+    input_shape = [1, 3, 10, 10]
+
+    class Ceil1(Module):
+        def forward(self, *args):
+            return torch.ceil(args[0])
+
+    input_data = torch.rand(input_shape).float()
+    verify_model(Ceil1().float().eval(), input_data=input_data)
+
+
+def test_forward_clamp():
+    torch.set_grad_enabled(False)
+    input_shape = [1, 3, 10, 10]
+
+    class Clamp1(Module):
+        def forward(self, *args):
+            return torch.clamp(args[0], min=-0.5, max=0.5)
+
+    class Clamp2(Module):
+        def forward(self, *args):
+            return torch.clamp(args[0], min=-0.3)
+
+    class Clamp3(Module):
+        def forward(self, *args):
+            return torch.clamp(args[0], max=1.0)
+
+    input_data = torch.rand(input_shape).float()
+    verify_model(Clamp1().float().eval(), input_data=input_data)
+    verify_model(Clamp2().float().eval(), input_data=input_data)
+    verify_model(Clamp3().float().eval(), input_data=input_data)
+
+
+def test_forward_floor():
+    torch.set_grad_enabled(False)
+    input_shape = [1, 3, 10, 10]
+
+    class Floor1(Module):
+        def forward(self, *args):
+            return torch.floor(args[0])
+
+    input_data = torch.rand(input_shape).float()
+    verify_model(Floor1().float().eval(), input_data=input_data)
+
+
+def test_forward_round():
+    torch.set_grad_enabled(False)
+    input_shape = [1, 3, 10, 10]
+
+    class Round1(Module):
+        def forward(self, *args):
+            return torch.round(args[0])
+
+    input_data = torch.rand(input_shape).float()
+    verify_model(Round1().float().eval(), input_data=input_data)
+
+
 if __name__ == "__main__":
     # Single operator tests
     test_forward_add()
@@ -1497,6 +1601,14 @@ if __name__ == "__main__":
     test_forward_expand()
     test_forward_pow()
     test_forward_abs()
+    test_forward_rsqrt()
+    test_forward_ceil()
+    test_forward_clamp()
+    test_forward_floor()
+    test_forward_round()
+    test_forward_isfinite()
+    test_forward_isnan()
+    test_forward_isinf()
     test_forward_arange()
     test_forward_chunk()
     test_forward_split()