Improve error messages for memory verifier and gpu memory verifier (#6281)
authorTristan Konolige <tkonolige@octoml.ai>
Sat, 15 Aug 2020 03:06:24 +0000 (20:06 -0700)
committerGitHub <noreply@github.com>
Sat, 15 Aug 2020 03:06:24 +0000 (20:06 -0700)
* [FIX] Print exactly what issues the GPU memory verifier encountered.

* [FIX] Print exactly why memory verifier failed.

src/tir/analysis/verify_gpu_code.cc
src/tir/analysis/verify_memory.cc

index cce0823..5ef755a 100644 (file)
@@ -35,9 +35,10 @@ namespace tir {
 
 class GPUCodeVerifier : public StmtExprVisitor {
  public:
-  bool Verify(Stmt stmt, int64_t max_local_memory_per_block, int64_t max_shared_memory_per_block,
-              int64_t max_threads_per_block, int64_t max_thread_x, int64_t max_thread_y,
-              int64_t max_thread_z, int64_t max_vthread, int64_t max_vector_bytes) {
+  std::vector<String> Verify(Stmt stmt, int64_t max_local_memory_per_block,
+                             int64_t max_shared_memory_per_block, int64_t max_threads_per_block,
+                             int64_t max_thread_x, int64_t max_thread_y, int64_t max_thread_z,
+                             int64_t max_vthread, int64_t max_vector_bytes) {
     max_local_memory_per_block_ = static_cast<size_t>(max_local_memory_per_block);
     max_shared_memory_per_block_ = static_cast<size_t>(max_shared_memory_per_block);
     max_threads_per_block_ = static_cast<size_t>(max_threads_per_block);
@@ -52,7 +53,7 @@ class GPUCodeVerifier : public StmtExprVisitor {
     // TODO(jcf94): Add support of detecting CUDA Misaligned Address error
     this->VisitStmt(stmt);
 
-    return valid_;
+    return errors_;
   }
 
   void VisitStmt_(const AllocateNode* op) final {
@@ -66,7 +67,13 @@ class GPUCodeVerifier : public StmtExprVisitor {
       shared_memory_per_block_ += size * op->dtype.bytes() * op->dtype.lanes();
     }
     if (op->dtype.lanes() > 1) {
-      valid_ &= static_cast<size_t>(op->dtype.lanes() * op->dtype.bytes()) <= max_vector_bytes_;
+      if (static_cast<size_t>(op->dtype.lanes() * op->dtype.bytes()) > max_vector_bytes_) {
+        std::stringstream s;
+        s << "Number of lanes (" << op->dtype.lanes() << ") times number of bytes ("
+          << op->dtype.bytes() << ") for dtype " << op->dtype
+          << " is greater than the maximum number of vector bytes (" << max_vector_bytes_ << ")";
+        errors_.push_back(s.str());
+      }
     }
   }
 
@@ -98,27 +105,39 @@ class GPUCodeVerifier : public StmtExprVisitor {
           visited_threads_.insert(name);
           thread_per_block_ *= length;
 
+          auto err = [this](std::string id, size_t ext, size_t m) {
+            if (ext > m) {
+              std::stringstream s;
+              s << "Extent of " << id << " (" << ext << ") is greater than maximum allowed (" << m
+                << ");";
+              errors_.push_back(s.str());
+            }
+          };
+
           if (name == "threadIdx.x") {
-            valid_ &= length <= max_thread_x_;
+            err("threadIdx.x", length, max_thread_x_);
             thread_x_extent_ = length;
           } else if (name == "threadIdx.y") {
-            valid_ &= length <= max_thread_y_;
+            err("threadIdx.y", length, max_thread_y_);
             thread_y_extent_ = length;
           } else if (name == "threadIdx.z") {
-            valid_ &= length <= max_thread_z_;
+            err("threadIdx.z", length, max_thread_z_);
             thread_z_extent_ = length;
           } else if (name == "vthread") {
-            valid_ &= length <= max_vthread_;
+            err("vthread", length, max_vthread_);
           }
         } else {
           // the thread should be bound to axes with the same length
-          if (name == "threadIdx.x") {
-            valid_ &= length == thread_x_extent_;
-          } else if (name == "threadIdx.y") {
-            valid_ &= length == thread_y_extent_;
-          } else if (name == "threadIdx.z") {
-            valid_ &= length == thread_z_extent_;
-          }
+          auto err = [this, name](std::string id, size_t ext, size_t m) {
+            if (name == id && ext != m) {
+              std::stringstream s;
+              s << "Extent of " << id << " (" << ext << ") does not match the bound " << m;
+              errors_.push_back(s.str());
+            }
+          };
+          err("threadIdx.x", length, thread_x_extent_);
+          err("threadIdx.y", length, thread_y_extent_);
+          err("threadIdx.z", length, thread_z_extent_);
         }
       }
 
@@ -128,10 +147,17 @@ class GPUCodeVerifier : public StmtExprVisitor {
 
       if (nest_level_ == 0) {
         // exit a kernel, check the validity
-        valid_ &= thread_per_block_ <= max_threads_per_block_;
-
-        valid_ &= local_memory_per_block_ <= max_local_memory_per_block_;
-        valid_ &= shared_memory_per_block_ <= max_shared_memory_per_block_;
+        auto err = [this](std::string id, size_t num, size_t m) {
+          if (num > m) {
+            std::stringstream s;
+            s << "Used " << id << " (" << num << ") is greater than the allowed maximum (" << m
+              << ")";
+            errors_.push_back(s.str());
+          }
+        };
+        err("threads per block", thread_per_block_, max_threads_per_block_);
+        err("local memory per block", local_memory_per_block_, max_local_memory_per_block_);
+        err("shared memory per block", shared_memory_per_block_, max_shared_memory_per_block_);
       }
     } else {
       StmtVisitor::VisitStmt_(op);
@@ -143,7 +169,13 @@ class GPUCodeVerifier : public StmtExprVisitor {
       const auto* extent = op->extent.as<IntImmNode>();
       CHECK(extent);
 
-      valid_ &= static_cast<size_t>(extent->value) <= max_vthread_;
+      size_t num_vthread = static_cast<size_t>(extent->value);
+      if (num_vthread > max_vthread_) {
+        std::stringstream s;
+        s << "Number of vthreads (" << num_vthread << ") is greater than the allowed maximum ("
+          << max_vthread_ << ")";
+        errors_.push_back(s.str());
+      }
     }
 
     StmtVisitor::VisitStmt_(op);
@@ -151,15 +183,27 @@ class GPUCodeVerifier : public StmtExprVisitor {
 
   void VisitExpr_(const LoadNode* op) {
     if (op->dtype.lanes() > 1) {
-      valid_ &= static_cast<size_t>(op->dtype.lanes() * op->dtype.bytes()) <= max_vector_bytes_;
+      if (static_cast<size_t>(op->dtype.lanes() * op->dtype.bytes()) > max_vector_bytes_) {
+        std::stringstream s;
+        s << "Number of lanes (" << op->dtype.lanes() << ") times number of bytes ("
+          << op->dtype.bytes() << ") for dtype " << op->dtype
+          << " is greater than the maximum number of vector bytes (" << max_vector_bytes_ << ")";
+        errors_.push_back(s.str());
+      }
     }
     ExprVisitor::VisitExpr_(op);
   }
 
   void VisitStmt_(const StoreNode* op) {
     if (op->index->dtype.lanes() > 1) {
-      valid_ &= static_cast<size_t>(op->index->dtype.lanes() * op->index->dtype.bytes()) <=
-                max_vector_bytes_;
+      if (static_cast<size_t>(op->index->dtype.lanes() * op->index->dtype.bytes()) >
+          max_vector_bytes_) {
+        std::stringstream s;
+        s << "Number of lanes (" << op->index->dtype.lanes() << ") times number of bytes ("
+          << op->index->dtype.bytes() << ") for dtype " << op->index->dtype
+          << " is greater than the maximum number of vector bytes (" << max_vector_bytes_ << ")";
+        errors_.push_back(s.str());
+      }
     }
     StmtVisitor::VisitStmt_(op);
   }
@@ -183,7 +227,7 @@ class GPUCodeVerifier : public StmtExprVisitor {
   size_t max_thread_x_, max_thread_y_, max_thread_z_, max_vthread_;
   size_t max_vector_bytes_;
 
-  bool valid_{true};
+  std::vector<String> errors_;
 
   void Reset_() {
     visited_local_buffers_.clear();
@@ -196,7 +240,7 @@ class GPUCodeVerifier : public StmtExprVisitor {
   }
 };
 
-bool VerifyGPUCode(const PrimFunc& func, Map<String, PrimExpr> constraints) {
+std::vector<String> VerifyGPUCode_(const PrimFunc& func, Map<String, PrimExpr> constraints) {
   GPUCodeVerifier verifier;
 
   int64_t max_local_memory_per_block = INT64_MAX;
@@ -236,6 +280,11 @@ bool VerifyGPUCode(const PrimFunc& func, Map<String, PrimExpr> constraints) {
                          max_vthread, max_vector_bytes);
 }
 
+bool VerifyGPUCode(const PrimFunc& func, Map<String, PrimExpr> constraints) {
+  auto errs = VerifyGPUCode_(func, constraints);
+  return errs.size() == 0;
+}
+
 TVM_REGISTER_GLOBAL("tir.analysis.verify_gpu_code").set_body_typed(VerifyGPUCode);
 
 namespace transform {
@@ -245,7 +294,16 @@ Pass VerifyGPUCode(Map<String, PrimExpr> constraints) {
     for (auto kv : mod->functions) {
       if (auto* n = kv.second.as<PrimFuncNode>()) {
         auto func = GetRef<PrimFunc>(n);
-        CHECK(VerifyGPUCode(func, constraints)) << "RuntimeError: GPU constraint violated" << func;
+        auto errs = VerifyGPUCode_(func, constraints);
+        if (errs.size() != 0) {
+          std::stringstream s;
+          for (auto& err : errs) {
+            s << "    " << err << std::endl;
+          }
+          LOG(FATAL) << "RuntimeError: GPU constraint(s) violated:\n"
+                     << s.str() << "  In function\n"
+                     << func;
+        }
       }
     }
     return mod;
index dfad549..64097e1 100644 (file)
@@ -62,20 +62,14 @@ class MemoryAccessVerifier final : protected StmtExprVisitor {
   }
 
   /// Verification result
-  bool Failed() const { return failure_; }
+  std::vector<String> Errors() const { return errs_; }
 
  protected:
   /// Visitor implementation
   //@{
-  void VisitExpr(const PrimExpr& n) final {
-    if (Failed()) return;
-    StmtExprVisitor::VisitExpr(n);
-  }
+  void VisitExpr(const PrimExpr& n) final { StmtExprVisitor::VisitExpr(n); }
 
-  void VisitStmt(const Stmt& n) final {
-    if (Failed()) return;
-    StmtExprVisitor::VisitStmt(n);
-  }
+  void VisitStmt(const Stmt& n) final { StmtExprVisitor::VisitStmt(n); }
 
   void VisitStmt_(const LetStmtNode* op) final {
     // Book keep definitions
@@ -139,7 +133,11 @@ class MemoryAccessVerifier final : protected StmtExprVisitor {
     if (!IsFromFunctionArgs(var.get())) return;
 
     // The verification fails in this case.
-    SetFailure();
+    std::stringstream s;
+    s << "Variable `" << var
+      << "` is directly accessed by host memory (it is not contained in a thread environment or in "
+         "the function arguments.";
+    errs_.push_back(s.str());
   }
 
   /// Status getter/setter
@@ -147,7 +145,6 @@ class MemoryAccessVerifier final : protected StmtExprVisitor {
   bool InThreadEnv() const { return in_thread_env_; }
   void EnterThreadEnv() { in_thread_env_ = true; }
   void ExitThreadEnv() { in_thread_env_ = false; }
-  void SetFailure() { failure_ = true; }
   //@}
 
   /// Check if a given DLDeviceType/TVMDeviceExtType value denotes GPU device.
@@ -162,7 +159,7 @@ class MemoryAccessVerifier final : protected StmtExprVisitor {
   /// Status of visitor
   //@{
   bool in_thread_env_{false};
-  bool failure_{false};  ///< If the verification fails (i.e. has illegal access)
+  std::vector<String> errs_;
   //@}
   tir::PrimFunc func_{nullptr};                        ///< Function to be verified.
   int dev_type_{kDLCPU};                               ///< Device type
@@ -171,7 +168,7 @@ class MemoryAccessVerifier final : protected StmtExprVisitor {
 }  // namespace
 
 /// Interface of VerifyMemory pass
-bool VerifyMemory(const PrimFunc& func) {
+std::vector<String> VerifyMemory_(const PrimFunc& func) {
   auto target = func->GetAttr<Target>(tvm::attr::kTarget);
   CHECK(target.defined()) << "LowerWarpMemory: Require the target attribute";
 
@@ -179,30 +176,37 @@ bool VerifyMemory(const PrimFunc& func) {
       CallingConv::kDefault) {
     MemoryAccessVerifier v(func, target.value()->kind->device_type);
     v.Run();
-    return !v.Failed();
+    return v.Errors();
   } else {
-    return true;
+    return {};
   }
 }
 
+bool VerifyMemory(const PrimFunc& func) { return VerifyMemory_(func).size() == 0; }
+
 TVM_REGISTER_GLOBAL("tir.analysis.verify_memory").set_body_typed(VerifyMemory);
 
 namespace transform {
 
 Pass VerifyMemory() {
-  auto pass_func =
-      [=](IRModule mod, PassContext ctx) {
-        for (auto kv : mod->functions) {
-          if (auto* n = kv.second.as<PrimFuncNode>()) {
-            auto func = GetRef<PrimFunc>(n);
-            CHECK(VerifyMemory(func))
-                << "RuntimeError: Direct host side access to device memory is detected."
-                << " Did you forget to bind?\n"
-                << func;
+  auto pass_func = [=](IRModule mod, PassContext ctx) {
+    for (auto kv : mod->functions) {
+      if (auto* n = kv.second.as<PrimFuncNode>()) {
+        auto func = GetRef<PrimFunc>(n);
+        auto errs = VerifyMemory_(func);
+        if (errs.size() > 0) {
+          std::stringstream s;
+          for (auto& err : errs) {
+            s << "    " << err << "\n";
           }
+          LOG(FATAL) << "RuntimeError: Memory verification failed with the following errors:\n"
+                     << s.str() << "  Did you forget to bind?\n"
+                     << func;
         }
-        return mod;
-      };
+      }
+    }
+    return mod;
+  };
   return tvm::transform::CreateModulePass(pass_func, 0, "tir.VerifyMemory", {});
 }