if (value->getType() == target) return value;
if (to.is_handle()) {
return builder_->CreateBitCast(value, target);
+ } else if (to.is_uint() && to.bits() == 1) {
+ if (from.is_float()) {
+ llvm::Constant* zero = llvm::ConstantFP::get(LLVMType(from), 0.);
+ return builder_->CreateFCmpONE(value, zero);
+ } else {
+ llvm::Constant* zero = llvm::ConstantInt::get(LLVMType(from), 0);
+ return builder_->CreateICmpNE(value, zero);
+ }
} else if (!from.is_float() && !to.is_float()) {
return builder_->CreateIntCast(value, target, from.is_int());
} else if (from.is_float() && to.is_int()) {
import topi
import topi.testing
from topi import util
+from common import get_all_backend
def test_util():
foo(a, b)
tvm.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5, atol=1e-5)
- for device in ['cuda', 'opencl', 'metal', 'rocm', 'vulkan', 'llvm', 'nvptx', 'sdaccel',
- 'aocl_sw_emu']:
+ for device in get_all_backend():
check_device(device)
test_apply(topi.sqrt, "sqrt", np.sqrt, 0, 100)
test_apply(topi.rsqrt, "rsqrt", lambda x:np.ones_like(x)/np.sqrt(x), 0, 100, skip_name_check=True)
+
+def test_cast():
+ def verify(from_dtype, to_dtype, low=-100, high=100):
+ shape = (5, 4)
+ A = tvm.placeholder(shape, dtype=from_dtype, name="A")
+ B = topi.cast(A, to_dtype)
+
+ if from_dtype == "bool":
+ a_np = np.random.choice([True, False], size=shape)
+ else:
+ a_np = np.random.uniform(low, high, size=shape).astype(from_dtype)
+ if to_dtype == "bool":
+ a_np = a_np - a_np[2, 3]
+ b_np = a_np.astype(to_dtype)
+
+ for device in get_all_backend():
+ ctx = tvm.context(device, 0)
+ if not ctx.exist:
+ print("Skip because %s is not enabled" % device)
+ continue
+ print("Running on target: %s" % device)
+ with tvm.target.create(device):
+ s = topi.generic.schedule_injective(B)
+ foo = tvm.build(s, [A, B], device)
+ a = tvm.nd.array(a_np, ctx)
+ b = tvm.nd.empty(shape=shape, dtype=to_dtype, ctx=ctx)
+ foo(a, b)
+ tvm.testing.assert_allclose(b.asnumpy(), b_np)
+
+ verify("int32", "float32")
+ verify("int32", "float64")
+ verify("int32", "bool")
+ verify("float32", "int32")
+ verify("float32", "float64")
+ verify("float32", "bool")
+ verify("bool", "float32")
+ verify("bool", "int32")
+
+
if __name__ == "__main__":
test_util()
test_ewise()
+ test_cast()