[PYTHON] Make IntImm more like an integer (#5232)
authorTianqi Chen <tqchen@users.noreply.github.com>
Fri, 3 Apr 2020 22:24:19 +0000 (15:24 -0700)
committerGitHub <noreply@github.com>
Fri, 3 Apr 2020 22:24:19 +0000 (15:24 -0700)
python/tvm/tir/expr.py
tests/python/unittest/test_tir_nodes.py

index 20a3bca..4cbece3 100644 (file)
@@ -439,6 +439,7 @@ class FloatImm(ConstExpr):
         self.__init_handle_by_constructor__(
             tvm.ir._ffi_api.FloatImm, dtype, value)
 
+
 @tvm._ffi.register_object
 class IntImm(ConstExpr):
     """Int constant.
@@ -455,9 +456,24 @@ class IntImm(ConstExpr):
         self.__init_handle_by_constructor__(
             tvm.ir._ffi_api.IntImm, dtype, value)
 
+    def __hash__(self):
+        return self.value
+
     def __int__(self):
         return self.value
 
+    def __nonzero__(self):
+        return self.value != 0
+
+    def __eq__(self, other):
+        return _ffi_api._OpEQ(self, other)
+
+    def __ne__(self, other):
+        return _ffi_api._OpNE(self, other)
+
+    def __bool__(self):
+        return self.__nonzero__()
+
 
 @tvm._ffi.register_object
 class StringImm(ConstExpr):
index 2e23a61..00ac7e3 100644 (file)
@@ -302,7 +302,21 @@ def test_buffer_load_store():
     assert isinstance(s, tvm.tir.BufferStore)
 
 
+def test_intimm_cond():
+    x = tvm.runtime.convert(1)
+    y = tvm.runtime.convert(1)
+    s = {x}
+    assert y in s
+    assert x == y
+    assert x < 20
+    assert not (x >= 20)
+    assert x < 10 and y < 10
+    assert not tvm.runtime.convert(x != 1)
+    assert x == 1
+
+
 if __name__ == "__main__":
+    test_intimm_cond()
     test_buffer_load_store()
     test_vars()
     test_prim_func()