PrimExpr::PrimExpr(float value)
: PrimExpr(FloatImm(DataType::Float(32), value)) {}
-PrimExpr::PrimExpr(runtime::String value)
- : PrimExpr(tir::StringImmNode::make(value)) {}
-
PrimExpr PrimExpr::FromObject_(ObjectRef ref) {
using runtime::ObjectTypeChecker;
if (auto* ptr = ref.as<tir::IterVarNode>()) {
TVM_REGISTER_GLOBAL("tir.AttrStmt")
.set_body_typed(AttrStmtNode::make);
-
Stmt AssertStmtNode::make(PrimExpr condition, PrimExpr message, Stmt body) {
CHECK(condition.defined());
CHECK(message.dtype() == DataType::Int(32) ||
}
TVM_REGISTER_GLOBAL("tir.AssertStmt")
-.set_body_typed(AssertStmtNode::make);
-
+.set_body_typed([](PrimExpr condition, ObjectRef message, Stmt body) {
+ if (const auto* str = message.as<StringObj>()) {
+ auto msg = StringImmNode::make(str->data);
+ return AssertStmtNode::make(condition, msg, body);
+ } else {
+ return AssertStmtNode::make(condition, Downcast<PrimExpr>(message), body);
+ }
+});
Stmt ProducerConsumerNode::make(FunctionRef func, bool is_producer, Stmt body) {
CHECK(body.defined());
Stmt ForNode::make(Var loop_var,
- PrimExpr min,
- PrimExpr extent,
- ForType for_type,
- DeviceAPI device_api,
- Stmt body) {
+ PrimExpr min,
+ PrimExpr extent,
+ ForType for_type,
+ DeviceAPI device_api,
+ Stmt body) {
CHECK(min.defined());
CHECK(extent.defined());
CHECK(min.dtype().is_scalar());
Var loop_var, PrimExpr min, PrimExpr extent,
int for_type, int device_api, Stmt body) {
return ForNode::make(loop_var,
- min,
- extent,
- static_cast<ForType>(for_type),
- static_cast<DeviceAPI>(device_api),
- body);
+ min,
+ extent,
+ static_cast<ForType>(for_type),
+ static_cast<DeviceAPI>(device_api),
+ body);
});
Stmt AllocateNode::make(Var buffer_var,
- DataType dtype,
- Array<PrimExpr> extents,
- PrimExpr condition,
- Stmt body,
- PrimExpr new_expr,
- std::string free_function) {
+ DataType dtype,
+ Array<PrimExpr> extents,
+ PrimExpr condition,
+ Stmt body,
+ PrimExpr new_expr,
+ std::string free_function) {
for (size_t i = 0; i < extents.size(); ++i) {
CHECK(extents[i].defined());
CHECK(extents[i].dtype().is_scalar());
{ { n, m } }, { lhs->dtype }, { lhs, rhs },
[&](Array<Buffer> ins, Array<Buffer> outs) {
return call_packed({
- runtime::String("tvm.contrib.cublas.matmul"),
+ StringImmNode::make("tvm.contrib.cublas.matmul"),
pack_buffer(ins[0]),
pack_buffer(ins[1]),
pack_buffer(outs[0]),
{ { b, n, m } }, { lhs->dtype }, { lhs, rhs },
[&](Array<Buffer> ins, Array<Buffer> outs) {
return call_packed({
- runtime::String("tvm.contrib.cublas.batch_matmul"),
+ StringImmNode::make("tvm.contrib.cublas.batch_matmul"),
pack_buffer(ins[0]),
pack_buffer(ins[1]),
pack_buffer(outs[0]),