[BugFix] Fix bug in cast to bool (#3207)
authorHaichen Shen <shenhaichen@gmail.com>
Mon, 20 May 2019 17:07:01 +0000 (10:07 -0700)
committerTianqi Chen <tqchen@users.noreply.github.com>
Mon, 20 May 2019 17:07:01 +0000 (10:07 -0700)
src/codegen/llvm/codegen_llvm.cc
topi/tests/python/test_topi_math.py

index 7946f90..bedcdc7 100644 (file)
@@ -537,6 +537,14 @@ llvm::Value* CodeGenLLVM::CreateCast(Type from, Type to, llvm::Value* value) {
   if (value->getType() == target) return value;
   if (to.is_handle()) {
     return builder_->CreateBitCast(value, target);
+  } else if (to.is_uint() && to.bits() == 1) {
+    if (from.is_float()) {
+      llvm::Constant* zero = llvm::ConstantFP::get(LLVMType(from), 0.);
+      return builder_->CreateFCmpONE(value, zero);
+    } else {
+      llvm::Constant* zero = llvm::ConstantInt::get(LLVMType(from), 0);
+      return builder_->CreateICmpNE(value, zero);
+    }
   } else if (!from.is_float() && !to.is_float()) {
     return builder_->CreateIntCast(value, target, from.is_int());
   } else if (from.is_float() && to.is_int()) {
index c180bc7..d6df450 100644 (file)
@@ -19,6 +19,7 @@ import tvm
 import topi
 import topi.testing
 from topi import util
+from common import get_all_backend
 
 
 def test_util():
@@ -59,8 +60,7 @@ def test_ewise():
             foo(a, b)
             tvm.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5, atol=1e-5)
 
-        for device in ['cuda', 'opencl', 'metal', 'rocm', 'vulkan', 'llvm', 'nvptx', 'sdaccel',
-                       'aocl_sw_emu']:
+        for device in get_all_backend():
             check_device(device)
 
 
@@ -77,6 +77,46 @@ def test_ewise():
     test_apply(topi.sqrt, "sqrt", np.sqrt, 0, 100)
     test_apply(topi.rsqrt, "rsqrt", lambda x:np.ones_like(x)/np.sqrt(x), 0, 100, skip_name_check=True)
 
+
+def test_cast():
+    def verify(from_dtype, to_dtype, low=-100, high=100):
+        shape = (5, 4)
+        A = tvm.placeholder(shape, dtype=from_dtype, name="A")
+        B = topi.cast(A, to_dtype)
+
+        if from_dtype == "bool":
+            a_np = np.random.choice([True, False], size=shape)
+        else:
+            a_np = np.random.uniform(low, high, size=shape).astype(from_dtype)
+        if to_dtype == "bool":
+            a_np = a_np - a_np[2, 3]
+        b_np = a_np.astype(to_dtype)
+
+        for device in get_all_backend():
+            ctx = tvm.context(device,  0)
+            if not ctx.exist:
+                print("Skip because %s is not enabled" % device)
+                continue
+            print("Running on target: %s" % device)
+            with tvm.target.create(device):
+                s = topi.generic.schedule_injective(B)
+            foo = tvm.build(s, [A, B], device)
+            a = tvm.nd.array(a_np, ctx)
+            b = tvm.nd.empty(shape=shape, dtype=to_dtype, ctx=ctx)
+            foo(a, b)
+            tvm.testing.assert_allclose(b.asnumpy(), b_np)
+
+    verify("int32", "float32")
+    verify("int32", "float64")
+    verify("int32", "bool")
+    verify("float32", "int32")
+    verify("float32", "float64")
+    verify("float32", "bool")
+    verify("bool", "float32")
+    verify("bool", "int32")
+
+
 if __name__ == "__main__":
     test_util()
     test_ewise()
+    test_cast()