Fix serialization of inf float value (#5912)
authorlixiaoquan <radioheads@163.com>
Wed, 24 Jun 2020 15:49:13 +0000 (23:49 +0800)
committerGitHub <noreply@github.com>
Wed, 24 Jun 2020 15:49:13 +0000 (08:49 -0700)
src/node/serialization.cc
tests/python/unittest/test_node_reflection.py

index 4382579..42767c2 100644 (file)
@@ -352,9 +352,15 @@ class JSONAttrSetter : public AttrVisitor {
   template <typename T>
   void ParseValue(const char* key, T* value) const {
     std::istringstream is(GetValue(key));
-    is >> *value;
-    if (is.fail()) {
-      LOG(FATAL) << "Wrong value format for field " << key;
+    if (is.str() == "inf") {
+      *value = std::numeric_limits<T>::infinity();
+    } else if (is.str() == "-inf") {
+      *value = -std::numeric_limits<T>::infinity();
+    } else {
+      is >> *value;
+      if (is.fail()) {
+        LOG(FATAL) << "Wrong value format for field " << key;
+      }
     }
   }
   void Visit(const char* key, double* value) final { ParseValue(key, value); }
index 3a7318c..d375fa0 100644 (file)
@@ -28,6 +28,16 @@ def test_const_saveload_json():
     zz = tvm.ir.load_json(json_str)
     tvm.ir.assert_structural_equal(zz, z, map_free_vars=True)
 
+def _test_infinity_value(value, dtype):
+    x = tvm.tir.const(value, dtype)
+    json_str = tvm.ir.save_json(x)
+    tvm.ir.assert_structural_equal(x, tvm.ir.load_json(json_str))
+
+def test_infinity_value():
+    _test_infinity_value(float("inf"), 'float64')
+    _test_infinity_value(float("-inf"), 'float64')
+    _test_infinity_value(float("inf"), 'float32')
+    _test_infinity_value(float("-inf"), 'float32')
 
 def test_make_smap():
     # save load json
@@ -145,3 +155,4 @@ if __name__ == "__main__":
     test_make_sum()
     test_pass_config()
     test_dict()
+    test_infinity_value()