From 4481c87ac408560cb03aa3c81d0dcc1a77acdd4c Mon Sep 17 00:00:00 2001 From: Hui Guo Date: Fri, 10 Sep 2021 20:30:06 -0700 Subject: [PATCH] [tensorexpr] Simplify x/100 -> 0 if x is a non-negative integer less than 100. (#64763) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/64763 Simplification pattern: x/N -> 0; N is a constant positive integer and x is a for-loop index whose range is a subset of [0, N). Test Plan: Imported from OSS Reviewed By: bertmaher Differential Revision: D30845854 Pulled By: huiguoo fbshipit-source-id: 814d69ed4be05e57405c222183cc1c6c526721cd --- test/cpp/tensorexpr/test_simplify.cpp | 21 +++++++++++++++++++ torch/csrc/jit/tensorexpr/ir_simplifier.cpp | 31 +++++++++++++++++++++++------ 2 files changed, 46 insertions(+), 6 deletions(-) diff --git a/test/cpp/tensorexpr/test_simplify.cpp b/test/cpp/tensorexpr/test_simplify.cpp index 48983c8..9de5713 100644 --- a/test/cpp/tensorexpr/test_simplify.cpp +++ b/test/cpp/tensorexpr/test_simplify.cpp @@ -1109,6 +1109,27 @@ TEST(Simplify, SimplifyDiv) { } } +TEST(Simplify, SimplifyDivWithLoopContext0) { + // Stmt to simplify: + // for (int i = 0; i < 100; i++) { + // A[i] = i / 100; + //} + VarHandle i("i", kInt); + BufHandle a_buf("A", {100}, kInt); + auto for_stmt = For::make(i, 0, 100, Store::make(a_buf, {i}, (i / 100))); + + const StmtPtr simplified = IRSimplifier::simplify(for_stmt); + + std::ostringstream oss; + oss << *(simplified); + const std::string& verification_pattern = + R"IR( +# CHECK: for (int i +# CHECK-NEXT: A[i] = 0; + )IR"; + torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); +} + TEST(Simplify, SimplifyDivWithLoopContext1) { // Stmt to simplify: // for (int i = 0; i < 6; i++) { diff --git a/torch/csrc/jit/tensorexpr/ir_simplifier.cpp b/torch/csrc/jit/tensorexpr/ir_simplifier.cpp index 3ce1943..3ed51b9 100644 --- a/torch/csrc/jit/tensorexpr/ir_simplifier.cpp +++ b/torch/csrc/jit/tensorexpr/ir_simplifier.cpp @@ -2875,12 +2875,33 @@ ExprPtr SimplifierUnderContext::mutate(DivPtr v) { std::ostringstream oss; if (auto ret = distributeDiv(lhs, rhs, var_bound_info_)) { + GRAPH_DEBUG("SimplifierUnderContext: ", *v, " => ", *ret); // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks) - oss << "SimplifierUnderContext: " << *v << " => " << *ret << "\n"; - GRAPH_DEBUG(oss.str()); return ret->accept_mutator(this); } + // i / N -> 0 if the range of i's values is a subset of [0, N) + // where N is an integer constant + auto lhsVar = to(lhs); + ExprPtr rhsScalar = rhs->isConstant() ? rhs : nullptr; + if (lhsVar && rhsScalar && !rhsScalar->dtype().is_floating_point()) { + auto got = var_bound_info_.find(lhsVar); + if (got != var_bound_info_.end()) { + auto start = got->second.first; + auto end = got->second.second; + ExprPtr check_start = IRSimplifier::simplify( + alloc(start, immLike(start, 0), kGE)); + ExprPtr check_end = + IRSimplifier::simplify(alloc(end, rhsScalar, kLE)); + if (check_start->isConstant() && check_end->isConstant() && + immediateEquals(check_start, 1) && immediateEquals(check_end, 1)) { + GRAPH_DEBUG( + "SimplifierUnderContext: ", *v, " => ", *immLike(lhsVar, 0)); + return immLike(lhsVar, 0); + } + } + } + ExprPtr lhs_new = lhs->accept_mutator(this); ExprPtr rhs_new = rhs->accept_mutator(this); if (lhs == lhs_new && rhs == rhs_new) { @@ -2895,8 +2916,7 @@ ExprPtr SimplifierUnderContext::mutate(ModPtr v) { std::ostringstream oss; if (auto ret = distributeMod(lhs, rhs, var_bound_info_)) { - oss << "SimplifierUnderContext: " << *v << " => " << *ret << "\n"; - GRAPH_DEBUG(oss.str()); + GRAPH_DEBUG("SimplifierUnderContext: ", *v, " => ", *ret); // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks) return ret->accept_mutator(this); } @@ -2916,8 +2936,7 @@ ExprPtr SimplifierUnderContext::mutate(ModPtr v) { IRSimplifier::simplify(alloc(end, rhsScalar, kLE)); if (check_start->isConstant() && check_end->isConstant() && immediateEquals(check_start, 1) && immediateEquals(check_end, 1)) { - oss << "SimplifierUnderContext: " << *v << " => " << *lhsVar << "\n"; - GRAPH_DEBUG(oss.str()); + GRAPH_DEBUG("SimplifierUnderContext: ", *v, " => ", *lhsVar); return lhsVar; } } -- 2.7.4