[TensorExpr] Don't rely on exceptions in Vectorizer. (#64609)
authorMikhail Zolotukhin <mvz@fb.com>
Wed, 8 Sep 2021 07:22:05 +0000 (00:22 -0700)
committerFacebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Wed, 8 Sep 2021 07:25:34 +0000 (00:25 -0700)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/64609

We've been using exceptions to indicate whether vectorization succeeded
or not, but that posed some problems with (e.g. we spent too much time
symbolicazing these exceptions). This change converts this mechanism to
a standard error return code.

Test Plan: Imported from OSS

Reviewed By: bertmaher

Differential Revision: D30795342

Pulled By: ZolotukhinM

fbshipit-source-id: 16e38b37bcdd78ceb438ac814cc377f35b058e17

test/cpp/tensorexpr/test_kernel.cpp
torch/csrc/jit/tensorexpr/loopnest.cpp

index f4d3b16..c4bf777 100644 (file)
@@ -1418,5 +1418,76 @@ TEST_F(Kernel, CustomLowering) {
   torch::jit::testing::FileCheck().check("isnan")->run(oss.str());
 }
 
+TEST_F(Kernel, Vectorize) {
+#ifdef TORCH_ENABLE_LLVM
+  const auto graph_string = R"IR(
+      graph(%0 : Float(100, 16, strides=[16, 1], device=cpu),
+            %1 : Float(100, 16, strides=[16, 1], device=cpu)):
+        %2 : Float(100, 16, strides=[16, 1]) = aten::mul(%0, %1)
+        %3 : Float(100, 16, strides=[16, 1]) = aten::mul(%0, %2)
+        return (%3))IR";
+  auto graph = std::make_shared<Graph>();
+  parseIR(graph_string, &*graph);
+
+  auto a = at::rand({100, 16}, TensorOptions(kCPU).dtype(at::kFloat));
+  auto b = at::rand({100, 16}, TensorOptions(kCPU).dtype(at::kFloat));
+  auto o = at::zeros({100, 16}, TensorOptions(kCPU).dtype(at::kFloat));
+  auto ref = a * (a * b);
+  TensorExprKernel k(graph);
+  std::vector<at::Tensor> inputs = {a, b};
+  StmtPtr s = k.getCodeGenStmt();
+
+  std::ostringstream oss;
+  oss << *s;
+
+  // Check the IR we produced
+  const std::string& verification_pattern = R"IR(# CHECK: Ramp)IR";
+  torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
+
+  std::vector<IValue> stack = fmap<IValue>(inputs);
+  k.run(stack);
+  o = stack[0].toTensor();
+  for (size_t i = 0; i < 100 * 16; i++) {
+    CHECK_EQ(((float*)o.data_ptr())[i], ((float*)ref.data_ptr())[i]);
+  }
+#endif
+}
+
+// TODO: To vectorize loopnest for 100x3 case, we need to flatten loops first.
+TEST_F(Kernel, DISABLED_FlattenVectorize) {
+#ifdef TORCH_ENABLE_LLVM
+  const auto graph_string = R"IR(
+      graph(%0 : Float(100, 3, strides=[3, 1], device=cpu),
+            %1 : Float(100, 3, strides=[3, 1], device=cpu)):
+        %2 : Float(100, 3, strides=[3, 1]) = aten::mul(%0, %1)
+        %3 : Float(100, 3, strides=[3, 1]) = aten::mul(%0, %2)
+        return (%3))IR";
+  auto graph = std::make_shared<Graph>();
+  parseIR(graph_string, &*graph);
+
+  auto a = at::rand({100, 3}, TensorOptions(kCPU).dtype(at::kFloat));
+  auto b = at::rand({100, 3}, TensorOptions(kCPU).dtype(at::kFloat));
+  auto o = at::zeros({100, 3}, TensorOptions(kCPU).dtype(at::kFloat));
+  auto ref = a * (a * b);
+  TensorExprKernel k(graph);
+  std::vector<at::Tensor> inputs = {a, b};
+  StmtPtr s = k.getCodeGenStmt();
+
+  std::ostringstream oss;
+  oss << *s;
+
+  // Check the IR we produced
+  const std::string& verification_pattern = R"IR(# CHECK: Ramp)IR";
+  torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
+
+  std::vector<IValue> stack = fmap<IValue>(inputs);
+  k.run(stack);
+  o = stack[0].toTensor();
+  for (size_t i = 0; i < 100 * 3; i++) {
+    CHECK_EQ(((float*)o.data_ptr())[i], ((float*)ref.data_ptr())[i]);
+  }
+#endif
+}
+
 } // namespace jit
 } // namespace torch
index e67d094..4a70700 100644 (file)
@@ -130,13 +130,15 @@ class Vectorizer : public IRMutator {
     auto start_imm = intValue(start);
     auto stop_imm = intValue(stop);
     if (!start_imm) {
-      throw std::runtime_error(
-          "Can't vectorize due to non-constant loop start!");
+      // Can't vectorize due to non-constant loop start!
+      success_ = false;
+      return v;
     }
 
     if (!stop_imm) {
-      throw std::runtime_error(
-          "Can't vectorize due to non-constant loop stop!");
+      // Can't vectorize due to non-constant loop stop!
+      success_ = false;
+      return v;
     }
 
     var_ = var;
@@ -145,12 +147,18 @@ class Vectorizer : public IRMutator {
 
     StmtPtr new_body = body->accept_mutator(this);
     if (new_body == body) {
-      throw std::runtime_error("Vectorization failed!");
+      // Vectorization failed!
+      success_ = false;
+      return v;
     }
 
     return new_body;
   }
 
+  bool success() const {
+    return success_;
+  }
+
   ExprPtr mutate(AddPtr v) override {
     std::vector<ExprPtr> inputs = {v->lhs(), v->rhs()};
     return try_vectorize(v, inputs, [&]() {
@@ -269,7 +277,9 @@ class Vectorizer : public IRMutator {
 
   ExprPtr mutate(VarPtr v) override {
     if (v == var_) {
-      return Ramp::make(ExprHandle(start_), 1, lanes_).node();
+      return Ramp::make(
+                 ExprHandle(start_), ExprHandle(immLike(start_, 1)), lanes_)
+          .node();
     }
 
     return v;
@@ -286,7 +296,9 @@ class Vectorizer : public IRMutator {
       return v;
     }
 
-    throw std::runtime_error("Can't vectorize a Ramp!");
+    // Can't vectorize a Ramp!
+    success_ = false;
+    return v;
   }
 
   ExprPtr mutate(LoadPtr v) override {
@@ -317,14 +329,18 @@ class Vectorizer : public IRMutator {
       return v;
     }
 
-    throw std::runtime_error("Can't vectorize a Broadcast!");
+    // Can't vectorize a Broadcast!
+    success_ = false;
+    return v;
   }
 
   ExprPtr mutate(IfThenElsePtr v) override {
     ExprPtr condition = v->condition();
     ExprPtr new_condition = condition->accept_mutator(this);
     if (new_condition != condition) {
-      throw std::runtime_error("Can't vectorize an IfThenElse condition!");
+      // Can't vectorize an IfThenElse condition!
+      success_ = false;
+      return v;
     }
 
     std::vector<ExprPtr> inputs = {v->true_value(), v->false_value()};
@@ -360,8 +376,9 @@ class Vectorizer : public IRMutator {
     ExprPtr new_stop = stop->accept_mutator(this);
 
     if (new_start != start || new_stop != stop) {
-      throw std::runtime_error(
-          "Can't vectorize nested For with dependent loop bounds!");
+      // Can't vectorize nested For with dependent loop bounds!
+      success_ = false;
+      return v;
     }
 
     StmtPtr body = v->body();
@@ -452,6 +469,7 @@ class Vectorizer : public IRMutator {
   VarPtr var_ = nullptr;
   int lanes_ = 0;
   ExprPtr start_ = nullptr;
+  bool success_ = true;
 };
 
 bool LoopNest::vectorize(ForPtr f) {
@@ -471,12 +489,11 @@ bool LoopNest::vectorize(ForPtr f) {
 
   Vectorizer v;
   StmtPtr new_f = nullptr;
-  try {
-    new_f = Stmt::clone(f);
-    normalize(to<For>(new_f));
-    new_f = FlattenIndexes(new_f);
-    new_f = v.vectorize(to<For>(new_f));
-  } catch (std::runtime_error& e) {
+  new_f = Stmt::clone(f);
+  normalize(to<For>(new_f));
+  new_f = FlattenIndexes(new_f);
+  new_f = v.vectorize(to<For>(new_f));
+  if (!v.success()) {
     // We clone f before vectorizing. So, any partial vectorization will
     // have modified the clone. In case of an exception, we can continue
     // using f.