From 72274e2a2fd55019ec860e1743dbdc5b0c5a5624 Mon Sep 17 00:00:00 2001
From: Mikhail Zolotukhin
Date: Wed, 8 Sep 2021 00:22:05 -0700
Subject: [PATCH] [TensorExpr] Don't rely on exceptions in Vectorizer. (#64609)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/64609

We've been using exceptions to indicate whether vectorization succeeded or
not, but that posed some problems (e.g. we spent too much time symbolizing
these exceptions). This change converts the mechanism to a standard error
return code.

Test Plan: Imported from OSS

Reviewed By: bertmaher

Differential Revision: D30795342

Pulled By: ZolotukhinM

fbshipit-source-id: 16e38b37bcdd78ceb438ac814cc377f35b058e17
---
 test/cpp/tensorexpr/test_kernel.cpp    | 71 ++++++++++++++++++++++++++++++++++
 torch/csrc/jit/tensorexpr/loopnest.cpp | 51 ++++++++++++++++--------
 2 files changed, 105 insertions(+), 17 deletions(-)

diff --git a/test/cpp/tensorexpr/test_kernel.cpp b/test/cpp/tensorexpr/test_kernel.cpp
index f4d3b16..c4bf777 100644
--- a/test/cpp/tensorexpr/test_kernel.cpp
+++ b/test/cpp/tensorexpr/test_kernel.cpp
@@ -1418,5 +1418,76 @@ TEST_F(Kernel, CustomLowering) {
   torch::jit::testing::FileCheck().check("isnan")->run(oss.str());
 }
 
+TEST_F(Kernel, Vectorize) {
+#ifdef TORCH_ENABLE_LLVM
+  const auto graph_string = R"IR(
+      graph(%0 : Float(100, 16, strides=[16, 1], device=cpu),
+            %1 : Float(100, 16, strides=[16, 1], device=cpu)):
+        %2 : Float(100, 16, strides=[16, 1]) = aten::mul(%0, %1)
+        %3 : Float(100, 16, strides=[16, 1]) = aten::mul(%0, %2)
+        return (%3))IR";
+  auto graph = std::make_shared<Graph>();
+  parseIR(graph_string, &*graph);
+
+  auto a = at::rand({100, 16}, TensorOptions(kCPU).dtype(at::kFloat));
+  auto b = at::rand({100, 16}, TensorOptions(kCPU).dtype(at::kFloat));
+  auto o = at::zeros({100, 16}, TensorOptions(kCPU).dtype(at::kFloat));
+  auto ref = a * (a * b);
+  TensorExprKernel k(graph);
+  std::vector<at::Tensor> inputs = {a, b};
+  StmtPtr s = k.getCodeGenStmt();
+
+  std::ostringstream oss;
+  oss << *s;
+
+  // Check the IR we produced
+  const std::string& verification_pattern = R"IR(# CHECK: Ramp)IR";
+  torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
+
+  std::vector<IValue> stack = fmap<IValue>(inputs);
+  k.run(stack);
+  o = stack[0].toTensor();
+  for (size_t i = 0; i < 100 * 16; i++) {
+    CHECK_EQ(((float*)o.data_ptr())[i], ((float*)ref.data_ptr())[i]);
+  }
+#endif
+}
+
+// TODO: To vectorize loopnest for 100x3 case, we need to flatten loops first.
+TEST_F(Kernel, DISABLED_FlattenVectorize) {
+#ifdef TORCH_ENABLE_LLVM
+  const auto graph_string = R"IR(
+      graph(%0 : Float(100, 3, strides=[3, 1], device=cpu),
+            %1 : Float(100, 3, strides=[3, 1], device=cpu)):
+        %2 : Float(100, 3, strides=[3, 1]) = aten::mul(%0, %1)
+        %3 : Float(100, 3, strides=[3, 1]) = aten::mul(%0, %2)
+        return (%3))IR";
+  auto graph = std::make_shared<Graph>();
+  parseIR(graph_string, &*graph);
+
+  auto a = at::rand({100, 3}, TensorOptions(kCPU).dtype(at::kFloat));
+  auto b = at::rand({100, 3}, TensorOptions(kCPU).dtype(at::kFloat));
+  auto o = at::zeros({100, 3}, TensorOptions(kCPU).dtype(at::kFloat));
+  auto ref = a * (a * b);
+  TensorExprKernel k(graph);
+  std::vector<at::Tensor> inputs = {a, b};
+  StmtPtr s = k.getCodeGenStmt();
+
+  std::ostringstream oss;
+  oss << *s;
+
+  // Check the IR we produced
+  const std::string& verification_pattern = R"IR(# CHECK: Ramp)IR";
+  torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
+
+  std::vector<IValue> stack = fmap<IValue>(inputs);
+  k.run(stack);
+  o = stack[0].toTensor();
+  for (size_t i = 0; i < 100 * 3; i++) {
+    CHECK_EQ(((float*)o.data_ptr())[i], ((float*)ref.data_ptr())[i]);
+  }
+#endif
+}
+
 } // namespace jit
 } // namespace torch
diff --git a/torch/csrc/jit/tensorexpr/loopnest.cpp b/torch/csrc/jit/tensorexpr/loopnest.cpp
index e67d094..4a70700 100644
--- a/torch/csrc/jit/tensorexpr/loopnest.cpp
+++ b/torch/csrc/jit/tensorexpr/loopnest.cpp
@@ -130,13 +130,15 @@ class Vectorizer : public IRMutator {
     auto start_imm = intValue(start);
     auto stop_imm = intValue(stop);
     if (!start_imm) {
-      throw std::runtime_error(
-          "Can't vectorize due to non-constant loop start!");
+      // Can't vectorize due to non-constant loop start!
+      success_ = false;
+      return v;
     }
 
     if (!stop_imm) {
-      throw std::runtime_error(
-          "Can't vectorize due to non-constant loop stop!");
+      // Can't vectorize due to non-constant loop stop!
+      success_ = false;
+      return v;
     }
 
     var_ = var;
@@ -145,12 +147,18 @@ class Vectorizer : public IRMutator {
 
     StmtPtr new_body = body->accept_mutator(this);
     if (new_body == body) {
-      throw std::runtime_error("Vectorization failed!");
+      // Vectorization failed!
+      success_ = false;
+      return v;
     }
 
     return new_body;
   }
 
+  bool success() const {
+    return success_;
+  }
+
   ExprPtr mutate(AddPtr v) override {
     std::vector<ExprPtr> inputs = {v->lhs(), v->rhs()};
     return try_vectorize(v, inputs, [&]() {
@@ -269,7 +277,9 @@ class Vectorizer : public IRMutator {
 
   ExprPtr mutate(VarPtr v) override {
     if (v == var_) {
-      return Ramp::make(ExprHandle(start_), 1, lanes_).node();
+      return Ramp::make(
+                 ExprHandle(start_), ExprHandle(immLike(start_, 1)), lanes_)
+          .node();
     }
 
     return v;
@@ -286,7 +296,9 @@ class Vectorizer : public IRMutator {
       return v;
     }
 
-    throw std::runtime_error("Can't vectorize a Ramp!");
+    // Can't vectorize a Ramp!
+    success_ = false;
+    return v;
   }
 
   ExprPtr mutate(LoadPtr v) override {
@@ -317,14 +329,18 @@ class Vectorizer : public IRMutator {
       return v;
     }
 
-    throw std::runtime_error("Can't vectorize a Broadcast!");
+    // Can't vectorize a Broadcast!
+    success_ = false;
+    return v;
   }
 
   ExprPtr mutate(IfThenElsePtr v) override {
     ExprPtr condition = v->condition();
     ExprPtr new_condition = condition->accept_mutator(this);
     if (new_condition != condition) {
-      throw std::runtime_error("Can't vectorize an IfThenElse condition!");
+      // Can't vectorize an IfThenElse condition!
+      success_ = false;
+      return v;
     }
 
     std::vector<ExprPtr> inputs = {v->true_value(), v->false_value()};
@@ -360,8 +376,9 @@ class Vectorizer : public IRMutator {
     ExprPtr new_stop = stop->accept_mutator(this);
 
     if (new_start != start || new_stop != stop) {
-      throw std::runtime_error(
-          "Can't vectorize nested For with dependent loop bounds!");
+      // Can't vectorize nested For with dependent loop bounds!
+      success_ = false;
+      return v;
     }
 
     StmtPtr body = v->body();
@@ -452,6 +469,7 @@ class Vectorizer : public IRMutator {
   VarPtr var_ = nullptr;
   int lanes_ = 0;
   ExprPtr start_ = nullptr;
+  bool success_ = true;
 };
 
 bool LoopNest::vectorize(ForPtr f) {
@@ -471,12 +489,11 @@ bool LoopNest::vectorize(ForPtr f) {
 
   Vectorizer v;
   StmtPtr new_f = nullptr;
-  try {
-    new_f = Stmt::clone(f);
-    normalize(to<For>(new_f));
-    new_f = FlattenIndexes(new_f);
-    new_f = v.vectorize(to<For>(new_f));
-  } catch (std::runtime_error& e) {
+  new_f = Stmt::clone(f);
+  normalize(to<For>(new_f));
+  new_f = FlattenIndexes(new_f);
+  new_f = v.vectorize(to<For>(new_f));
+  if (!v.success()) {
     // We clone f before vectorizing. So, any partial vectorization will
     // have modified the clone. In case of an exception, we can continue
    // using f.
-- 
2.7.4
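The conversion in this patch follows a standard pattern for turning
exception-based control flow into an error return code: the mutator records
failure in a member flag, returns its input unchanged, and the caller checks
the flag instead of catching. Below is a minimal, self-contained sketch of
that pattern; the Node and FlagMutator names are invented for illustration
and are not the real TensorExpr classes, which the diff above modifies.

#include <iostream>

// Hypothetical stand-ins for the real IR node and mutator types in
// torch/csrc/jit/tensorexpr; used only to illustrate the pattern.
struct Node {
  int loop_start = -1; // -1 models a non-constant loop start
};

class FlagMutator {
 public:
  // Instead of throwing on failure, record it in a flag and return the
  // input unchanged so the caller can decide how to proceed.
  Node* mutate(Node* v) {
    if (v->loop_start < 0) {
      // Can't transform: flip the flag rather than throw.
      success_ = false;
      return v;
    }
    v->loop_start *= 8; // stand-in for the actual rewrite
    return v;
  }
  bool success() const {
    return success_;
  }

 private:
  bool success_ = true; // stays true only if every mutation succeeds
};

int main() {
  Node n; // non-constant start: the transformation must bail out
  FlagMutator m;
  Node* result = m.mutate(&n);
  if (!m.success()) {
    // Mirrors LoopNest::vectorize above: on failure, keep the original.
    result = &n;
    std::cout << "transformation failed, keeping original\n";
  }
  (void)result;
  return 0;
}

The trade-off is that every mutate() override must set and respect the flag
explicitly, but failure becomes a cheap, ordinary return path, avoiding the
cost the summary mentions of constructing and symbolizing exceptions on hot
compilation paths.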