[nnc] Updated internal asserts to include more detailed error messages (#64118)
authorRaghavan Raman <raghavanr@fb.com>
Mon, 30 Aug 2021 11:38:00 +0000 (04:38 -0700)
committerFacebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Mon, 30 Aug 2021 11:40:51 +0000 (04:40 -0700)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/64118

Test Plan: Imported from OSS

Reviewed By: ZolotukhinM

Differential Revision: D30616944

Pulled By: navahgar

fbshipit-source-id: 35289696cc0e7faa01599304243b86f0febc6daf

torch/csrc/jit/tensorexpr/kernel.cpp
torch/csrc/jit/tensorexpr/kernel.h

index 0d0d19e..e4136d8 100644 (file)
@@ -34,7 +34,10 @@ static bool checkTypes(const ScalarType highType, const int typeConstraints) {
   }
 
   // assume JIT not supporting complex and qint yet
-  TORCH_INTERNAL_ASSERT((typeConstraints & (kQintTypes | kComplexTypes)) == 0);
+  TORCH_INTERNAL_ASSERT(
+      (typeConstraints & (kQintTypes | kComplexTypes)) == 0,
+      buildErrorMessage(
+          "Qint and Complex types are not supported in the fuser."));
   return false;
 }
 
@@ -63,6 +66,13 @@ namespace torch {
 namespace jit {
 namespace tensorexpr {
 
+std::string buildErrorMessage(const std::string& s) {
+  // TODO: Update this generic error message to include details regarding
+  // turning off the fuser.
+  static const std::string generic_error_message = "";
+  return s + " " + generic_error_message;
+}
+
 static int te_cuda_pointwise_loop_levels = -1;
 static int te_cuda_pointwise_block_count = -1;
 static int te_cuda_pointwise_block_size = -1;
@@ -164,13 +174,18 @@ c10::optional<at::Device> pickDeviceType(const std::shared_ptr<Graph>& graph) {
     for (auto const& input : node->inputs()) {
       if (auto tt = input->type()->cast<TensorType>()) {
         if (auto inputDevice = tt->device()) {
-          TORCH_INTERNAL_ASSERT(!device || *device == *inputDevice);
+          TORCH_INTERNAL_ASSERT(
+              !device || *device == *inputDevice,
+              buildErrorMessage(
+                  "Different devices specified for inputs to the fuser."));
           device = inputDevice;
         }
       }
     }
   }
-  TORCH_INTERNAL_ASSERT(device);
+  TORCH_INTERNAL_ASSERT(
+      device,
+      buildErrorMessage("Could not find device in fuser graph inputs."));
   return device;
 }
 
@@ -356,7 +371,9 @@ bool matmulIsSupported(const torch::jit::Node* node) {
 void annotateInputShapes(
     const std::shared_ptr<Graph>& graph,
     const std::vector<c10::optional<at::Tensor>>& example_inputs) {
-  TORCH_INTERNAL_ASSERT(graph->inputs().size() == example_inputs.size());
+  TORCH_INTERNAL_ASSERT(
+      graph->inputs().size() == example_inputs.size(),
+      buildErrorMessage("Given inputs do not match the fuser graph inputs."));
   for (size_t idx = 0; idx < example_inputs.size(); idx++) {
     if (auto t = example_inputs[idx]) {
       auto concrete_tensor_type = tensorTypeInCurrentExecutionContext(*t);
@@ -820,7 +837,10 @@ std::vector<ExprHandle> TensorExprKernel::inferSizesForValue(
         throw std::runtime_error("Empty input list is passed to aten::cat");
       }
 
-      TORCH_INTERNAL_ASSERT(n->input(1)->node()->kind() == prim::Constant);
+      TORCH_INTERNAL_ASSERT(
+          n->input(1)->node()->kind() == prim::Constant,
+          buildErrorMessage(
+              "aten::cat op's dim input is not constant in fuser."));
       int64_t dim = n->input(1)->node()->i(attr::value);
       auto shape = sizesForValue(inputs[0]);
       auto norm_dim = normalizeAndCheckIndex(dim, shape.size());
@@ -2689,7 +2709,11 @@ StmtPtr TensorExprKernel::transformLoops(BackendType backendType, StmtPtr st) {
         blockSize = default_uint8_blocksize;
       }
       std::vector<ForPtr> loops = l.getLoopStmtsFor(buf);
-      TORCH_INTERNAL_ASSERT(!loops.empty(), "loops should not be empty");
+      TORCH_INTERNAL_ASSERT(
+          !loops.empty(),
+          buildErrorMessage(
+              "No loops found for the buffer " + buf->name_hint() +
+              " in the fuser."));
       ForPtr flattened = nullptr;
       LoopNest::flatten(loops, &flattened);
       assert(flattened);
index 4b92b02..bdb9802 100644 (file)
@@ -300,6 +300,8 @@ TORCH_API void annotateInputShapes(
 TORCH_API std::shared_ptr<Graph> removeUnusedSelfArgument(
     const std::shared_ptr<Graph>& graph);
 
+TORCH_API std::string buildErrorMessage(const std::string& s);
+
 } // namespace tensorexpr
 } // namespace jit
 } // namespace torch