Remove PrimExpr from String (#5311)
authorZhi <5145158+zhiics@users.noreply.github.com>
Sun, 12 Apr 2020 16:12:23 +0000 (09:12 -0700)
committerGitHub <noreply@github.com>
Sun, 12 Apr 2020 16:12:23 +0000 (09:12 -0700)
include/tvm/ir/expr.h
src/ir/expr.cc
src/target/target.cc
src/tir/ir/stmt.cc
topi/include/topi/contrib/cublas.h
topi/include/topi/contrib/rocblas.h

index 4e0a301..859a134 100644 (file)
@@ -108,12 +108,6 @@ class PrimExpr : public BaseExpr {
    */
   TVM_DLL PrimExpr(float value);  // NOLINT(*)
 
-  /*!
-   * \brief construct from runtime String.
-   * \param value The value to be constructed.
-   */
-  TVM_DLL PrimExpr(runtime::String value);  // NOLINT(*)
-
   /*! \return the data type of this expression. */
   DataType dtype() const {
     return static_cast<const PrimExprNode*>(get())->dtype;
index e08d832..7272213 100644 (file)
@@ -40,9 +40,6 @@ PrimExpr::PrimExpr(int32_t value)
 PrimExpr::PrimExpr(float value)
     : PrimExpr(FloatImm(DataType::Float(32), value)) {}
 
-PrimExpr::PrimExpr(runtime::String value)
-    : PrimExpr(tir::StringImmNode::make(value)) {}
-
 PrimExpr PrimExpr::FromObject_(ObjectRef ref) {
   using runtime::ObjectTypeChecker;
   if (auto* ptr = ref.as<tir::IterVarNode>()) {
index 61d5f6f..50856d6 100644 (file)
@@ -137,7 +137,7 @@ Target CreateTarget(const std::string& target_name,
   } else if (target_name == "hybrid") {
     t->device_type = kDLCPU;
   } else if (target_name == "hexagon") {
-    t->keys_array.push_back(runtime::String("hexagon"));
+    t->keys_array.push_back("hexagon");
     t->device_type = kDLHexagon;
   } else {
     LOG(ERROR) << "Unknown target name " << target_name;
index 64e7ef5..705fe7b 100644 (file)
@@ -58,7 +58,6 @@ Stmt AttrStmtNode::make(ObjectRef node,
 TVM_REGISTER_GLOBAL("tir.AttrStmt")
 .set_body_typed(AttrStmtNode::make);
 
-
 Stmt AssertStmtNode::make(PrimExpr condition, PrimExpr message, Stmt body) {
   CHECK(condition.defined());
   CHECK(message.dtype() == DataType::Int(32) ||
@@ -74,8 +73,14 @@ Stmt AssertStmtNode::make(PrimExpr condition, PrimExpr message, Stmt body) {
 }
 
 TVM_REGISTER_GLOBAL("tir.AssertStmt")
-.set_body_typed(AssertStmtNode::make);
-
+.set_body_typed([](PrimExpr condition, ObjectRef message, Stmt body) {
+  if (const auto* str = message.as<StringObj>()) {
+    auto msg = StringImmNode::make(str->data);
+    return AssertStmtNode::make(condition, msg, body);
+  } else {
+    return AssertStmtNode::make(condition, Downcast<PrimExpr>(message), body);
+  }
+});
 
 Stmt ProducerConsumerNode::make(FunctionRef func, bool is_producer, Stmt body) {
   CHECK(body.defined());
@@ -92,11 +97,11 @@ TVM_REGISTER_GLOBAL("tir.ProducerConsumer")
 
 
 Stmt ForNode::make(Var loop_var,
-               PrimExpr min,
-               PrimExpr extent,
-               ForType for_type,
-               DeviceAPI device_api,
-               Stmt body) {
+                   PrimExpr min,
+                   PrimExpr extent,
+                   ForType for_type,
+                   DeviceAPI device_api,
+                   Stmt body) {
   CHECK(min.defined());
   CHECK(extent.defined());
   CHECK(min.dtype().is_scalar());
@@ -119,11 +124,11 @@ TVM_REGISTER_GLOBAL("tir.For")
   Var loop_var, PrimExpr min, PrimExpr extent,
   int for_type, int device_api, Stmt body) {
   return ForNode::make(loop_var,
-                   min,
-                   extent,
-                   static_cast<ForType>(for_type),
-                   static_cast<DeviceAPI>(device_api),
-                   body);
+                       min,
+                       extent,
+                       static_cast<ForType>(for_type),
+                       static_cast<DeviceAPI>(device_api),
+                       body);
 });
 
 
@@ -176,12 +181,12 @@ TVM_REGISTER_GLOBAL("tir.Provide")
 
 
 Stmt AllocateNode::make(Var buffer_var,
-                    DataType dtype,
-                    Array<PrimExpr> extents,
-                    PrimExpr condition,
-                    Stmt body,
-                    PrimExpr new_expr,
-                    std::string free_function) {
+                        DataType dtype,
+                        Array<PrimExpr> extents,
+                        PrimExpr condition,
+                        Stmt body,
+                        PrimExpr new_expr,
+                        std::string free_function) {
     for (size_t i = 0; i < extents.size(); ++i) {
       CHECK(extents[i].defined());
       CHECK(extents[i].dtype().is_scalar());
index ee18dea..f2ed029 100644 (file)
@@ -53,7 +53,7 @@ inline Tensor cublas_matmul(const Tensor& lhs,
     { { n, m } }, { lhs->dtype }, { lhs, rhs },
     [&](Array<Buffer> ins, Array<Buffer> outs) {
       return call_packed({
-        runtime::String("tvm.contrib.cublas.matmul"),
+        StringImmNode::make("tvm.contrib.cublas.matmul"),
         pack_buffer(ins[0]),
         pack_buffer(ins[1]),
         pack_buffer(outs[0]),
@@ -85,7 +85,7 @@ inline Tensor cublas_batch_matmul(const Tensor& lhs,
     { { b, n, m } }, { lhs->dtype }, { lhs, rhs },
     [&](Array<Buffer> ins, Array<Buffer> outs) {
       return call_packed({
-        runtime::String("tvm.contrib.cublas.batch_matmul"),
+        StringImmNode::make("tvm.contrib.cublas.batch_matmul"),
         pack_buffer(ins[0]),
         pack_buffer(ins[1]),
         pack_buffer(outs[0]),
index 9fe1825..f0bf926 100644 (file)
@@ -52,7 +52,7 @@ inline Tensor rocblas_matmul(const Tensor& lhs,
     { { n, m } }, { lhs->dtype }, { lhs, rhs },
     [&](Array<Buffer> ins, Array<Buffer> outs) {
       return call_packed({
-        runtime::String("tvm.contrib.rocblas.matmul"),
+        StringImmNode::make("tvm.contrib.rocblas.matmul"),
         pack_buffer(ins[0]),
         pack_buffer(ins[1]),
         pack_buffer(outs[0]),