[nnc] Provide helpful error messages about turning off the fuser (#64516)
authorBert Maher <bertrand@fb.com>
Wed, 8 Sep 2021 15:07:19 +0000 (08:07 -0700)
committerFacebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Wed, 8 Sep 2021 15:10:22 +0000 (08:10 -0700)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/64516

If fuser compilation fails due to a bug (which should be highly
unlikely at this point) we want to direct the user how to unblock themselves by
disabling fusion, in addition to requesting that they report a bug.
ghstack-source-id: 137398537

Test Plan: existing tests

Reviewed By: ZolotukhinM

Differential Revision: D30758051

fbshipit-source-id: 98be89f1b1d4fb3bc816f5b2634c618b9297930e

torch/csrc/jit/tensorexpr/analysis.h
torch/csrc/jit/tensorexpr/bounds_inference.cpp
torch/csrc/jit/tensorexpr/bounds_overlap.cpp
torch/csrc/jit/tensorexpr/cuda_codegen.cpp
torch/csrc/jit/tensorexpr/exceptions.h
torch/csrc/jit/tensorexpr/external_functions.cpp
torch/csrc/jit/tensorexpr/kernel.cpp

index 351eb87..6f02144 100644 (file)
@@ -266,7 +266,7 @@ class CreateBufferMap : public IRVisitor {
       auto add_node = to<Add>(v->value());
       auto mul_node = to<Mul>(v->value());
       // This means for now, v->value() can be Add or Mul
-      TORCH_INTERNAL_ASSERT((add_node || mul_node));
+      TORCH_INTERNAL_ASSERT(add_node || mul_node, buildErrorMessage());
       map_input_to_tensor_bufs_.emplace(v->buf()->name_hint(), v->buf());
     }
     v->value()->accept(this);
index 649fd0e..c47acd8 100644 (file)
@@ -29,7 +29,7 @@ BoundsInfo mergeTensorAccesses(
     }
 
     auto vtbIt = varToBuf.find(access->var());
-    TORCH_INTERNAL_ASSERT(vtbIt != varToBuf.end());
+    TORCH_INTERNAL_ASSERT(vtbIt != varToBuf.end(), buildErrorMessage());
     BufPtr buf = vtbIt->second;
     std::vector<TensorAccessBoundsInfo>& infos = ret[buf];
 
@@ -38,8 +38,10 @@ BoundsInfo mergeTensorAccesses(
     for (auto& TABI : infos) {
       TensorAccessKind kind = access->isWrite() ? kStore : kLoad;
       if (!distinctAccessKinds || kind == TABI.kind) {
-        TORCH_INTERNAL_ASSERT(TABI.start.size() == access->bounds().size());
-        TORCH_INTERNAL_ASSERT(TABI.stop.size() == access->bounds().size());
+        TORCH_INTERNAL_ASSERT(
+            TABI.start.size() == access->bounds().size(), buildErrorMessage());
+        TORCH_INTERNAL_ASSERT(
+            TABI.stop.size() == access->bounds().size(), buildErrorMessage());
         for (size_t i = 0; i < TABI.start.size(); ++i) {
           TABI.start[i] = IRSimplifier::simplify(
               alloc<Min>(TABI.start[i], access->bounds()[i].start, true));
@@ -275,7 +277,8 @@ HazardKind getPotentialHazards(
 }
 
 IndexBounds getIndexBounds(const TensorAccessBoundsInfo& tabi) {
-  TORCH_INTERNAL_ASSERT(tabi.start.size() == tabi.stop.size());
+  TORCH_INTERNAL_ASSERT(
+      tabi.start.size() == tabi.stop.size(), buildErrorMessage());
   IndexBounds ret(tabi.start.size());
   if (tabi.start.empty()) {
     return ret;
index fdfff12..ae58265 100644 (file)
@@ -194,7 +194,7 @@ std::vector<IndexBounds> subtractIndicesBounds(
     return {};
   }
   // All accesses to a buf must have the same dimensionality.
-  TORCH_INTERNAL_ASSERT(A.size() == B.size());
+  TORCH_INTERNAL_ASSERT(A.size() == B.size(), buildErrorMessage());
 
   // Each dimension can be sliced into multiple bound segments.
   std::vector<IndexBounds> boundSlices;
@@ -208,7 +208,8 @@ std::vector<IndexBounds> subtractIndicesBounds(
     for (auto slice : slices) {
       IndexBounds newRegion;
       newRegion.reserve(A.size());
-      TORCH_INTERNAL_ASSERT(remainingOuterBounds.size() == i);
+      TORCH_INTERNAL_ASSERT(
+          remainingOuterBounds.size() == i, buildErrorMessage());
 
       for (size_t j = 0; j < i; ++j) {
         newRegion.push_back(remainingOuterBounds[j]);
@@ -224,7 +225,7 @@ std::vector<IndexBounds> subtractIndicesBounds(
         remaining = A[i];
       } else {
         auto remainingSlices = subtractBound(remaining, slice);
-        TORCH_INTERNAL_ASSERT(remainingSlices.size() == 1);
+        TORCH_INTERNAL_ASSERT(remainingSlices.size() == 1, buildErrorMessage());
         remaining = remainingSlices[0];
       }
     }
index c23eda3..af0b014 100644 (file)
@@ -821,7 +821,7 @@ StmtPtr GPUMetaVarRewriter::mutate(BlockPtr v) {
     bool need_sync = false;
     // We never mask loops, they'll mask their contents.
     if (!segment.mask()) {
-      TORCH_INTERNAL_ASSERT(segment.stmts().size() == 1);
+      TORCH_INTERNAL_ASSERT(segment.stmts().size() == 1, buildErrorMessage());
       stmts.push_back(segment.stmts()[0]);
       continue;
     }
index 7194dfe..b6e97f7 100644 (file)
@@ -84,7 +84,7 @@ class malformed_ir : public std::runtime_error {
             "MALFORMED IR: " + err + " - " + std::to_string(stmt)) {}
 };
 
-TORCH_API std::string buildErrorMessage(const std::string& s);
+TORCH_API std::string buildErrorMessage(const std::string& s = "");
 
 } // namespace tensorexpr
 } // namespace jit
index a21455a..4809c41 100644 (file)
@@ -5,6 +5,7 @@
 #include <ATen/core/dispatch/Dispatcher.h>
 #include <ATen/native/xnnpack/OpContext.h>
 #include <c10/util/irange.h>
+#include <torch/csrc/jit/tensorexpr/exceptions.h>
 #include <torch/csrc/jit/tensorexpr/external_functions_registry.h>
 
 namespace torch {
@@ -65,7 +66,7 @@ void nnc_aten_conv2d(
   if (args_num > 0) {
     // Check that if the extra arguments are provided, then the bias tensor is
     // also present
-    TORCH_INTERNAL_ASSERT(args_num == 7 && bufs_num == 4);
+    TORCH_INTERNAL_ASSERT(args_num == 7 && bufs_num == 4, buildErrorMessage());
     const at::Tensor& b = tensors[3];
 
     int64_t strideH = extra_args[0];
index a86cb33..8a8aee7 100644 (file)
@@ -69,7 +69,7 @@ namespace tensorexpr {
 std::string buildErrorMessage(const std::string& s) {
   static const std::string generic_error_message =
       "This error occured in the fuser. You can turn off the fuser with "
-      "torch._C._jit_override_can_fuse_on_cpu(False)";
+      "torch.jit.enable_fusion(False).";
   if (s.empty()) {
     return generic_error_message;
   }