}
}
+TEST(Simplify, SimplifyDivWithLoopContext0) {
+ // Stmt to simplify:
+ // for (int i = 0; i < 100; i++) {
+ // A[i] = i / 100;
+ //}
+ VarHandle i("i", kInt);
+ BufHandle a_buf("A", {100}, kInt);
+ auto for_stmt = For::make(i, 0, 100, Store::make(a_buf, {i}, (i / 100)));
+
+ const StmtPtr simplified = IRSimplifier::simplify(for_stmt);
+
+ std::ostringstream oss;
+ oss << *(simplified);
+ const std::string& verification_pattern =
+ R"IR(
+# CHECK: for (int i
+# CHECK-NEXT: A[i] = 0;
+ )IR";
+ torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
+}
+
TEST(Simplify, SimplifyDivWithLoopContext1) {
// Stmt to simplify:
// for (int i = 0; i < 6; i++) {
std::ostringstream oss;
if (auto ret = distributeDiv(lhs, rhs, var_bound_info_)) {
+ GRAPH_DEBUG("SimplifierUnderContext: ", *v, " => ", *ret);
// NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks)
- oss << "SimplifierUnderContext: " << *v << " => " << *ret << "\n";
- GRAPH_DEBUG(oss.str());
return ret->accept_mutator(this);
}
+ // i / N -> 0 if the range of i's values is a subset of [0, N)
+ // where N is an integer constant
+ auto lhsVar = to<Var>(lhs);
+ ExprPtr rhsScalar = rhs->isConstant() ? rhs : nullptr;
+ if (lhsVar && rhsScalar && !rhsScalar->dtype().is_floating_point()) {
+ auto got = var_bound_info_.find(lhsVar);
+ if (got != var_bound_info_.end()) {
+ auto start = got->second.first;
+ auto end = got->second.second;
+ ExprPtr check_start = IRSimplifier::simplify(
+ alloc<CompareSelect>(start, immLike(start, 0), kGE));
+ ExprPtr check_end =
+ IRSimplifier::simplify(alloc<CompareSelect>(end, rhsScalar, kLE));
+ if (check_start->isConstant() && check_end->isConstant() &&
+ immediateEquals(check_start, 1) && immediateEquals(check_end, 1)) {
+ GRAPH_DEBUG(
+ "SimplifierUnderContext: ", *v, " => ", *immLike(lhsVar, 0));
+ return immLike(lhsVar, 0);
+ }
+ }
+ }
+
ExprPtr lhs_new = lhs->accept_mutator(this);
ExprPtr rhs_new = rhs->accept_mutator(this);
if (lhs == lhs_new && rhs == rhs_new) {
std::ostringstream oss;
if (auto ret = distributeMod(lhs, rhs, var_bound_info_)) {
- oss << "SimplifierUnderContext: " << *v << " => " << *ret << "\n";
- GRAPH_DEBUG(oss.str());
+ GRAPH_DEBUG("SimplifierUnderContext: ", *v, " => ", *ret);
// NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks)
return ret->accept_mutator(this);
}
IRSimplifier::simplify(alloc<CompareSelect>(end, rhsScalar, kLE));
if (check_start->isConstant() && check_end->isConstant() &&
immediateEquals(check_start, 1) && immediateEquals(check_end, 1)) {
- oss << "SimplifierUnderContext: " << *v << " => " << *lhsVar << "\n";
- GRAPH_DEBUG(oss.str());
+ GRAPH_DEBUG("SimplifierUnderContext: ", *v, " => ", *lhsVar);
return lhsVar;
}
}