[tensorexpr] Simplify x/100 -> 0 if x is a non-negative integer less than 100. (...
authorHui Guo <huiguo@fb.com>
Sat, 11 Sep 2021 03:30:06 +0000 (20:30 -0700)
committerFacebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Sat, 11 Sep 2021 03:33:02 +0000 (20:33 -0700)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/64763

Simplification pattern:
  x/N -> 0; N is a constant positive integer and x is a for-loop index whose range is a subset of [0, N).

Test Plan: Imported from OSS

Reviewed By: bertmaher

Differential Revision: D30845854

Pulled By: huiguoo

fbshipit-source-id: 814d69ed4be05e57405c222183cc1c6c526721cd

test/cpp/tensorexpr/test_simplify.cpp
torch/csrc/jit/tensorexpr/ir_simplifier.cpp

index 48983c8..9de5713 100644 (file)
@@ -1109,6 +1109,27 @@ TEST(Simplify, SimplifyDiv) {
   }
 }
 
+TEST(Simplify, SimplifyDivWithLoopContext0) {
+  // Stmt to simplify:
+  // for (int i = 0; i < 100; i++) {
+  //  A[i] = i / 100;
+  //}
+  VarHandle i("i", kInt);
+  BufHandle a_buf("A", {100}, kInt);
+  auto for_stmt = For::make(i, 0, 100, Store::make(a_buf, {i}, (i / 100)));
+
+  const StmtPtr simplified = IRSimplifier::simplify(for_stmt);
+
+  std::ostringstream oss;
+  oss << *(simplified);
+  const std::string& verification_pattern =
+      R"IR(
+# CHECK: for (int i
+# CHECK-NEXT:   A[i] = 0;
+      )IR";
+  torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
+}
+
 TEST(Simplify, SimplifyDivWithLoopContext1) {
   // Stmt to simplify:
   // for (int i = 0; i < 6; i++) {
index 3ce1943..3ed51b9 100644 (file)
@@ -2875,12 +2875,33 @@ ExprPtr SimplifierUnderContext::mutate(DivPtr v) {
 
   std::ostringstream oss;
   if (auto ret = distributeDiv(lhs, rhs, var_bound_info_)) {
+    GRAPH_DEBUG("SimplifierUnderContext: ", *v, " => ", *ret);
     // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks)
-    oss << "SimplifierUnderContext: " << *v << " => " << *ret << "\n";
-    GRAPH_DEBUG(oss.str());
     return ret->accept_mutator(this);
   }
 
+  // i / N -> 0 if the range of i's values is a subset of [0, N)
+  // where N is an integer constant
+  auto lhsVar = to<Var>(lhs);
+  ExprPtr rhsScalar = rhs->isConstant() ? rhs : nullptr;
+  if (lhsVar && rhsScalar && !rhsScalar->dtype().is_floating_point()) {
+    auto got = var_bound_info_.find(lhsVar);
+    if (got != var_bound_info_.end()) {
+      auto start = got->second.first;
+      auto end = got->second.second;
+      ExprPtr check_start = IRSimplifier::simplify(
+          alloc<CompareSelect>(start, immLike(start, 0), kGE));
+      ExprPtr check_end =
+          IRSimplifier::simplify(alloc<CompareSelect>(end, rhsScalar, kLE));
+      if (check_start->isConstant() && check_end->isConstant() &&
+          immediateEquals(check_start, 1) && immediateEquals(check_end, 1)) {
+        GRAPH_DEBUG(
+            "SimplifierUnderContext: ", *v, " => ", *immLike(lhsVar, 0));
+        return immLike(lhsVar, 0);
+      }
+    }
+  }
+
   ExprPtr lhs_new = lhs->accept_mutator(this);
   ExprPtr rhs_new = rhs->accept_mutator(this);
   if (lhs == lhs_new && rhs == rhs_new) {
@@ -2895,8 +2916,7 @@ ExprPtr SimplifierUnderContext::mutate(ModPtr v) {
 
   std::ostringstream oss;
   if (auto ret = distributeMod(lhs, rhs, var_bound_info_)) {
-    oss << "SimplifierUnderContext: " << *v << " => " << *ret << "\n";
-    GRAPH_DEBUG(oss.str());
+    GRAPH_DEBUG("SimplifierUnderContext: ", *v, " => ", *ret);
     // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks)
     return ret->accept_mutator(this);
   }
@@ -2916,8 +2936,7 @@ ExprPtr SimplifierUnderContext::mutate(ModPtr v) {
           IRSimplifier::simplify(alloc<CompareSelect>(end, rhsScalar, kLE));
       if (check_start->isConstant() && check_end->isConstant() &&
           immediateEquals(check_start, 1) && immediateEquals(check_end, 1)) {
-        oss << "SimplifierUnderContext: " << *v << " => " << *lhsVar << "\n";
-        GRAPH_DEBUG(oss.str());
+        GRAPH_DEBUG("SimplifierUnderContext: ", *v, " => ", *lhsVar);
         return lhsVar;
       }
     }