From 9b274cbb7754c9ca2a0ed3678efd00dd6652570a Mon Sep 17 00:00:00 2001 From: Tianqi Chen Date: Fri, 3 Apr 2020 15:24:19 -0700 Subject: [PATCH] [PYTHON] Make IntImm more like an integer (#5232) --- python/tvm/tir/expr.py | 16 ++++++++++++++++ tests/python/unittest/test_tir_nodes.py | 14 ++++++++++++++ 2 files changed, 30 insertions(+) diff --git a/python/tvm/tir/expr.py b/python/tvm/tir/expr.py index 20a3bca..4cbece3 100644 --- a/python/tvm/tir/expr.py +++ b/python/tvm/tir/expr.py @@ -439,6 +439,7 @@ class FloatImm(ConstExpr): self.__init_handle_by_constructor__( tvm.ir._ffi_api.FloatImm, dtype, value) + @tvm._ffi.register_object class IntImm(ConstExpr): """Int constant. @@ -455,9 +456,24 @@ class IntImm(ConstExpr): self.__init_handle_by_constructor__( tvm.ir._ffi_api.IntImm, dtype, value) + def __hash__(self): + return self.value + def __int__(self): return self.value + def __nonzero__(self): + return self.value != 0 + + def __eq__(self, other): + return _ffi_api._OpEQ(self, other) + + def __ne__(self, other): + return _ffi_api._OpNE(self, other) + + def __bool__(self): + return self.__nonzero__() + @tvm._ffi.register_object class StringImm(ConstExpr): diff --git a/tests/python/unittest/test_tir_nodes.py b/tests/python/unittest/test_tir_nodes.py index 2e23a61..00ac7e3 100644 --- a/tests/python/unittest/test_tir_nodes.py +++ b/tests/python/unittest/test_tir_nodes.py @@ -302,7 +302,21 @@ def test_buffer_load_store(): assert isinstance(s, tvm.tir.BufferStore) +def test_intimm_cond(): + x = tvm.runtime.convert(1) + y = tvm.runtime.convert(1) + s = {x} + assert y in s + assert x == y + assert x < 20 + assert not (x >= 20) + assert x < 10 and y < 10 + assert not tvm.runtime.convert(x != 1) + assert x == 1 + + if __name__ == "__main__": + test_intimm_cond() test_buffer_load_store() test_vars() test_prim_func() -- 2.7.4