Migrate simplifier to new infra. (#3368)
authorTianqi Chen <tqchen@users.noreply.github.com>
Mon, 1 Jul 2019 21:07:45 +0000 (14:07 -0700)
committerGitHub <noreply@github.com>
Mon, 1 Jul 2019 21:07:45 +0000 (14:07 -0700)
19 files changed:
CMakeLists.txt
include/tvm/arithmetic.h
include/tvm/ir_pass.h
src/arithmetic/analyzer.cc
src/arithmetic/bound_deducer.cc
src/arithmetic/const_fold.h
src/arithmetic/rewrite_simplify.cc
src/arithmetic/stmt_simplify.cc
src/lang/buffer.cc
src/op/scan_op.cc
src/pass/loop_partition.cc
src/pass/narrow_channel_access.cc
src/pass/storage_rewrite.cc
src/pass/vectorize_loop.cc
src/schedule/message_passing.cc
src/schedule/schedule_dataflow_rewrite.cc
tests/cpp/ir_simplify_test.cc
tests/python/unittest/test_arith_deduce_bound.py
tests/python/unittest/test_pass_basic.py

index 6500ba0..c23d403 100644 (file)
@@ -154,7 +154,11 @@ file(GLOB_RECURSE NNVM_COMPILER_SRCS
 file(GLOB TOPI_SRCS
     topi/src/*.cc
 )
-file(GLOB_RECURSE HALIDEIR_SRCS 3rdparty/HalideIR/src/*.cpp)
+file(GLOB_RECURSE HALIDEIR_SRCS
+  3rdparty/HalideIR/src/base/*.cpp
+  3rdparty/HalideIR/src/ir/*.cpp
+  3rdparty/HalideIR/src/tvm/*.cpp
+)
 list(APPEND COMPILER_SRCS ${HALIDEIR_SRCS})
 file(GLOB RUNTIME_SRCS
   src/runtime/*.cc
index 92f7399..446c4c0 100644 (file)
@@ -623,12 +623,15 @@ IntSet Intersect(const Array<IntSet>& sets);
  *  give the domain of each variables. Return undefined IntSet to
  *  represent failure.
  *
+ * \note The returned set may be smaller than set that
+ *       contains all possible values of v that satisfies the bound.
+ *
  * \param v The target variable to be deduced.
  * \param cond The conditional expression.
  * \param hint_map The domain of variable, used to help deduce.
  * \param relax_map The domain of each variable, used to relax the domain,
- *        The deduce bound mush implies e for all value in relax_map
- * \return An integer set that can cover all the possible values.
+ *        The deduce bound must implies e for all value in relax_map
+ * \return An integer set that always satisfies the condition.
  */
 IntSet DeduceBound(Expr v, Expr cond,
                    const Map<Var, IntSet>& hint_map,
@@ -641,7 +644,7 @@ IntSet DeduceBound(Expr v, Expr cond,
  * \param hint_map The domain of variable, used to help deduce.
  * \param relax_map The domain of each variable, used to relax the domain,
  *        The deduce bound mush implies e for all value in relax_map
- * \return An integer set that can cover all the possible values.
+ * \return An integer set that always satisfies the condition.
  */
 IntSet DeduceBound(Expr v, Expr cond,
                    const std::unordered_map<const Variable*, IntSet>& hint_map,
index e1c92e5..98dbf6b 100644 (file)
@@ -27,7 +27,6 @@
 #ifndef TVM_IR_PASS_H_
 #define TVM_IR_PASS_H_
 
-#include <arithmetic/Simplify.h>
 #include <unordered_map>
 #include <unordered_set>
 #include <vector>
index 2198aee..626fc18 100644 (file)
@@ -106,6 +106,7 @@ bool Analyzer::CanProve(const Expr& expr) {
 Expr Analyzer::Simplify(const Expr& expr) {
   if (is_const(expr)) return expr;
   auto res = this->rewrite_simplify(expr);
+  if (is_const(res)) return res;
   res = this->canonical_simplify(res);
   return res;
 }
index 395a371..003ba8d 100644 (file)
@@ -84,11 +84,11 @@ class BoundDeducer: public IRVisitor {
   void Deduce();
 
   void Visit(const NodeRef& e) final {
-    if (!success) return;
+    if (!success_) return;
     if (e.get() == path_[iter_++]) {
       IRVisitor::Visit(e);
     } else {
-      success = false;
+      success_ = false;
       return;
     }
   }
@@ -111,18 +111,18 @@ class BoundDeducer: public IRVisitor {
 
   void Visit_(const Add* op) final {
     bool left = op->a.get() == path_[iter_];
-    result -= left ? op->b : op->a;
+    result_ -= left ? op->b : op->a;
     Visit(left ? op->a : op->b);
   }
 
   void Visit_(const Sub* op) final {
     bool left = op->a.get() == path_[iter_];
     if (left) {
-      result += op->b;
+      result_ += op->b;
     } else {
-      result -= op->a;
-      result = - result;
-      is_greater = !is_greater;
+      result_ -= op->a;
+      result_ = - result_;
+      is_greater_ = !is_greater_;
     }
     Visit(left ? op->a : op->b);
   }
@@ -130,43 +130,65 @@ class BoundDeducer: public IRVisitor {
   void Visit_(const Mul* op) final {
     bool left = op->a.get() == path_[iter_];
     Expr operand = left ? op->b : op->a;
+    Expr target_var = left ? op->a : op->b;
 
-    SignType sign;
+    SignType sign_operand;
     if (operand.type().is_uint()) {
-      sign = kPositive;
+      sign_operand = kPositive;
     } else {
-      sign = expr_map_[operand].sign_type();
+      sign_operand = expr_map_[operand].sign_type();
     }
 
-    if (sign == SignType::kNegative) {
-      is_greater = !is_greater;
-    } else if (sign == SignType::kUnknown) {
+    if (sign_operand == SignType::kNegative) {
+      is_greater_ = !is_greater_;
+    } else if (sign_operand == SignType::kUnknown) {
       // unable to get the sign of operand
-      success = false;
+      success_ = false;
       return;
     }
-
     // always use relax bound
-    bool divided = can_prove(result % operand == 0);
-    result = result / operand;
-    // since system will round down when not divided
-    // eg. 2/4 -> 0; -2/4 -> -1
-    // no need fix for !is_greater:
-    // eg. a <= 2/4 -> a <= 0
-    // eg. a <= 0/4 -> a <= 0
-    // so just fix for not divided and is_greater
-    // eg. a >= 2/4 -> a >= 0 + 1
-    // eg. a >= 0/4 -> a >= 0
-    if (is_greater && !divided) {
-       result += 1;
+    bool divided = analyzer_.CanProve(result_ % operand == 0);
+
+    result_ = result_ / operand;
+
+    if (!divided) {
+      // Handle non-divisible case
+      // NOTE: this accounts for truc div behavior.
+      bool target_is_non_neg = expr_map_[target_var].can_prove_non_negative();
+
+      if (is_greater_) {
+        result_ += 1;
+      } else {
+        // NOTE: this is a bit sutble hack.
+        //
+        // condition:
+        // - x * operand <= result
+        // - operand > 0
+        // - x >= 0
+        //
+        // Then it is fine to deduce that x <= result / operand.
+        // - if result > 0,  this division round down
+        // - if result < 0, (result / operand) rounds up and may violate the constraint
+        //   however, given that x is always non-negative,
+        //   it is fine to have this relaxed bound, given that the user of deduce bound
+        //   will respect the bound of x
+        //
+        // TODO(tvm-team): think about a better API to incorporate constraint of x.
+        //                 e.g. specify an interval of x and return a bound
+        //                 that is in the interval and satisfies the condition.
+        if (target_is_non_neg && sign_operand == kPositive) {
+          // do nothing
+        } else {
+          result_ -= 1;
+        }
+      }
     }
-
     Visit(left ? op->a : op->b);
   }
 
-  Expr result;
-  bool is_greater{true};
-  bool success{true};
+  Expr result_;
+  bool is_greater_{true};
+  bool success_{true};
 
  private:
   void Init();
@@ -180,6 +202,8 @@ class BoundDeducer: public IRVisitor {
   ExprIntSetMap expr_map_;
   std::vector<const Node*> path_;
   size_t iter_{0};
+  // internal analzyer
+  Analyzer analyzer_;
 };
 
 class BoundDeduceInputChecker: public IRVisitor {
@@ -202,7 +226,7 @@ class BoundDeduceInputChecker: public IRVisitor {
 
 void BoundDeducer::Init() {
   BoundDeduceInputChecker checker;
-  if (!checker.Check(this)) success = false;
+  if (!checker.Check(this)) success_ = false;
   Transform();
 }
 
@@ -211,66 +235,65 @@ void BoundDeducer::Transform() {
   if (const LT* op = expr_.as<LT>()) {
     if (GetPath(target_, op->a).empty()) {
       // a < b -> b >= a + 1
-      is_greater = true;
+      is_greater_ = true;
       expr_ = op->b;
-      result = op->a + 1;
+      result_ = op->a + 1;
     } else {
       // a < b -> a <= b - 1
-      is_greater = false;
+      is_greater_ = false;
       expr_ = op->a;
-      result = op->b - 1;
+      result_ = op->b - 1;
     }
   } else if (const LE* op = expr_.as<LE>()) {
     if (GetPath(target_, op->a).empty()) {
       // a <= b -> b >= a
-      is_greater = true;
+      is_greater_ = true;
       expr_ = op->b;
-      result = op->a;
+      result_ = op->a;
     } else {
-      is_greater = false;
+      is_greater_ = false;
       expr_ = op->a;
-      result = op->b;
+      result_ = op->b;
     }
   } else if (const GT* op = expr_.as<GT>()) {
     if (GetPath(target_, op->a).empty()) {
       // a > b -> b <= a - 1
-      is_greater = false;
+      is_greater_ = false;
       expr_ = op->b;
-      result = op->a - 1;
+      result_ = op->a - 1;
     } else {
       // a > b -> a >= b + 1
-      is_greater = true;
+      is_greater_ = true;
       expr_ = op->a;
-      result = op->b + 1;
+      result_ = op->b + 1;
     }
   } else if (const GE* op = expr_.as<GE>()) {
     if (GetPath(target_, op->a).empty()) {
       // a >= b -> b <= a
-      is_greater = false;
+      is_greater_ = false;
       expr_ = op->b;
-      result = op->a;
+      result_ = op->a;
     } else {
-      is_greater = true;
+      is_greater_ = true;
       expr_ = op->a;
-      result = op->b;
+      result_ = op->b;
     }
   } else {
-    success = false;
+    success_ = false;
   }
 }
 
 void BoundDeducer::Deduce() {
   Init();
-  if (!success) return;
+  if (!success_) return;
   Relax();
-  if (!success) return;
+  if (!success_) return;
   // get the path
   path_ = GetPath(target_, expr_);
   if (!path_.size()) {
-    success = false;
+    success_ = false;
     return;
   }
-
   expr_map_ = EvalSetForEachSubExpr(expr_, hint_map_);
 
   Visit(expr_);
@@ -278,13 +301,13 @@ void BoundDeducer::Deduce() {
 
 void BoundDeducer::Relax() {
   IntSet a = EvalSet(expr_, relax_map_);
-  IntSet b = EvalSet(result, relax_map_);
+  IntSet b = EvalSet(result_, relax_map_);
   if (a.is_everything() || b.is_everything()) {
-    success = false;
+    success_ = false;
     return;
   }
-  expr_  = is_greater ? a.min() : a.max();
-  result = is_greater ? b.max() : b.min();
+  expr_  = is_greater_ ? a.min() : a.max();
+  result_ = is_greater_ ? b.max() : b.min();
 }
 
 IntSet DeduceBound(Expr v, Expr e,
@@ -292,12 +315,12 @@ IntSet DeduceBound(Expr v, Expr e,
   const std::unordered_map<const Variable*, IntSet>& relax_map) {
   BoundDeducer d(v, e, hint_map, relax_map);
   d.Deduce();
-  if (!d.success) return IntSet::nothing();
+  if (!d.success_) return IntSet::nothing();
   Expr min = neg_inf(), max = pos_inf();
-  if (d.is_greater) {
-    min = d.result;
+  if (d.is_greater_) {
+    min = d.result_;
   } else {
-    max = d.result;
+    max = d.result_;
   }
   return IntSet::interval(min, max);
 }
index ec50aef..dc6b80a 100644 (file)
@@ -155,9 +155,10 @@ template<>
 inline Expr TryConstFold<ir::Div>(Expr a, Expr b) {
   TVM_ARITH_CONST_PROPAGATION({
       const Type& rtype = a.type();
-      // due to division and mod can have different modes
-      // only constant fold positive number where rule is fixed.
-      if (pa && pb && pa->value >= 0 && pb->value > 0) {
+      if (pa && pb) {
+        // due to division and mod can have different modes
+        // NOTE: this will assumes truc div.
+        CHECK_NE(pb->value, 0) << "Divide by zero";
         return IntImm::make(rtype, pa->value / pb->value);
       }
       if (pa) {
index bc8666e..6cc829d 100644 (file)
@@ -155,7 +155,6 @@ Mutate_(const Add* op, const Expr& self) {
     TVM_TRY_REWRITE(max(x, y - z) + z, max(x + z, y));
     TVM_TRY_REWRITE(max(x - z, y) + z, max(x, y + z));
 
-
     TVM_TRY_REWRITE_IF(min(x, y + z * c1) + z * c2, min(x + z * c2, y),
                        c1.Eval()->value == -c2.Eval()->value);
     TVM_TRY_REWRITE_IF(max(x, y + z * c1) + z * c2, max(x + z * c2, y),
index 01cb96e..162cb1e 100644 (file)
@@ -28,7 +28,6 @@
 #include <tvm/ir_mutator.h>
 #include <tvm/expr_operator.h>
 #include <tvm/arithmetic.h>
-#include "arithmetic/Simplify.h"
 
 namespace tvm {
 namespace arith {
@@ -158,42 +157,18 @@ Expr CanonicalSimplify(Expr expr, Map<Var, Range> vrange) {
   return analyzer.canonical_simplify(expr);
 }
 
-template<typename T>
-T Simplify_(T a, Map<Var, Range> vrange) {
-  using namespace HalideIR::Internal;
-  Scope<Interval> rscope;
+Expr Simplify(Expr expr, Map<Var, Range> vrange) {
+  arith::Analyzer analyzer;
   for (auto kv : vrange) {
-    Range r = kv.second;
-    rscope.push(
-        kv.first.get(),
-        Interval(r->min,
-                 simplify(r->min + r->extent - make_const(r->min.type(), 1))));
-  }
-  return HalideIR::Internal::simplify(a, true, rscope);
-}
-
-
-Expr Simplify(Expr a, Map<Var, Range> vrange) {
-  // Simplify top level reduce.
-  if (const Reduce* r = a.as<Reduce>()) {
-    Array<Expr> new_source;
-    for (auto& e : r->source) {
-      new_source.push_back(Simplify_(e, vrange));
-    }
-    Expr new_condition = Simplify_(r->condition, vrange);
-    if (r->source.same_as(new_source) &&
-        r->condition.same_as(new_condition)) {
-      return a;
-    } else {
-      return Reduce::make(
-              r->combiner, new_source, r->axis, new_condition, r->value_index);
-    }
+    analyzer.Bind(kv.first, kv.second);
   }
-  return Simplify_(a, vrange);
+  expr = analyzer.Simplify(expr);
+  return expr;
 }
 
-Stmt Simplify(Stmt a, Map<Var, Range> vrange) {
-  return Simplify_(a, vrange);
+Stmt Simplify(Stmt stmt, Map<Var, Range> vrange) {
+  return arith::CanonicalStmtSimplifier().CanonicalSimplify(
+      stmt, vrange);
 }
 }  // namespace ir
 }  // namespace tvm
index 8c584c5..3e06151 100644 (file)
@@ -6,9 +6,9 @@
  * to you under the Apache License, Version 2.0 (the
  * "License"); you may not use this file except in compliance
  * with the License.  You may obtain a copy of the License at
- * 
+ *
  *   http://www.apache.org/licenses/LICENSE-2.0
- * 
+ *
  * Unless required by applicable law or agreed to in writing,
  * software distributed under the License is distributed on an
  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
@@ -26,6 +26,7 @@
 #include <tvm/ir.h>
 #include <tvm/ir_pass.h>
 #include <iterator>
+#include <stack>
 #include "../arithmetic/compute_expr.h"
 
 namespace tvm {
index 42b1331..78f8c82 100644 (file)
@@ -6,9 +6,9 @@
  * to you under the Apache License, Version 2.0 (the
  * "License"); you may not use this file except in compliance
  * with the License.  You may obtain a copy of the License at
- * 
+ *
  *   http://www.apache.org/licenses/LICENSE-2.0
- * 
+ *
  * Unless required by applicable law or agreed to in writing,
  * software distributed under the License is distributed on an
  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
@@ -80,7 +80,7 @@ Operation ScanOpNode::make(std::string name,
   for (size_t i = 0; i < init.size(); ++i) {
     CHECK_EQ(init[i]->dtype, state_placeholder[i]->dtype);
     CHECK_EQ(init[i]->dtype, update[i]->dtype);
-    CHECK(can_prove(init[i]->shape[0] == axis->dom->min))
+    CHECK(prove_equal(init[i]->shape[0], axis->dom->min))
         << "init.shape[0] need to match scan_axis.dom.min";
     CHECK(prove_equal(
         state_placeholder[i]->shape[0], axis->dom->min + axis->dom->extent))
index 0a5b741..33dbaed 100644 (file)
@@ -466,8 +466,13 @@ Stmt LoopPartitioner::TryPartition(const Node* node,
                                    Stmt body,
                                    bool partition_thread_scope) {
   using namespace arith;
+  // include hint of var.
+  hint_map_.insert({var.get(), IntSet::interval(min, max)});
+
   PartitionFinder finder(var, hint_map_, relax_map_);
   finder.Visit(body);
+
+  hint_map_.erase(var.get());
   if (finder.partitions.empty()) return Stmt();
 
   arith::IntervalSet for_interval(min, max);
@@ -504,9 +509,9 @@ Stmt LoopPartitioner::TryPartition(const Node* node,
   bool pre_stmt_recurse = true;
   if (middle_interval_i->HasLowerBound()) {
     body_begin = ir::Simplify(middle_interval.min());
-    if (!can_prove(body_begin == min)) {
+    if (!analyzer_.CanProve(body_begin == min)) {
       Expr cond = (body_begin - min >= 0);
-      if (!can_prove(cond)) {
+      if (!analyzer_.CanProve(cond)) {
         LOG(WARNING) << "Cannot prove: " << cond
                      << ", when generating the pre doubt loop";
         body_begin = Max::make(body_begin, min);
@@ -529,10 +534,10 @@ Stmt LoopPartitioner::TryPartition(const Node* node,
   bool post_stmt_recurse = true;
   if (middle_interval_i->HasUpperBound()) {
     post_doubt_begin = ir::Simplify(middle_interval.max() + 1);
-    if (!can_prove(middle_interval.max() == max)) {
+    if (!analyzer_.CanProve(middle_interval.max() == max)) {
       // require the extent to be non-negative
       Expr cond = (max - post_doubt_begin + 1 >= 0);
-      if (!can_prove(cond)) {
+      if (!analyzer_.CanProve(cond)) {
         LOG(WARNING) << "Cannot prove: " << cond
                      << ", when generating the post doubt loop";
         post_doubt_begin = Min::make(post_doubt_begin, max);
@@ -554,7 +559,7 @@ Stmt LoopPartitioner::TryPartition(const Node* node,
   // Generating code for middle subrange
   if (!partition_thread_scope) {
     Stmt mid_stmt;
-    if (!can_prove(body_begin >= post_doubt_begin)) {
+    if (!analyzer_.CanProve(body_begin >= post_doubt_begin)) {
       // [body_begin, post_doubt_begin)
       Stmt simplified_body = ConditionEliminator(cond_set, cond_value).Mutate(body);
       Stmt new_body = Substitute(simplified_body, {{Var{var}, var + body_begin}});
@@ -576,8 +581,8 @@ Stmt LoopPartitioner::TryPartition(const Node* node,
     s = AppendStmts(s, post_stmt);
   } else {
     Expr cond = const_true();
-    if (!can_prove(body_begin == min)) cond = cond && (var >= body_begin);
-    if (!can_prove(post_doubt_begin == (max + 1))) cond = cond && (var < post_doubt_begin);
+    if (!analyzer_.CanProve(body_begin == min)) cond = cond && (var >= body_begin);
+    if (!analyzer_.CanProve(post_doubt_begin == (max + 1))) cond = cond && (var < post_doubt_begin);
     s = ThreadPartitionInserter(cond_set, cond).Mutate(stmt);
   }
   s = ConvertSSA(s);
@@ -587,7 +592,7 @@ Stmt LoopPartitioner::TryPartition(const Node* node,
 inline Stmt LoopPartitioner::MakeFor(const Node *node, Expr extent, Stmt body) {
   const For *for_node = static_cast<const For*>(node);
   CHECK(for_node);
-  if (can_prove(extent == make_const(Int(32), 1))) {
+  if (analyzer_.CanProve(extent == make_const(Int(32), 1))) {
     // If the loop extent is 1, do not create the loop anymore
     return Substitute(body, {{Var{for_node->loop_var}, make_const(Int(32), 0)}});
   } else {
index 731064e..57f3baf 100644 (file)
@@ -6,9 +6,9 @@
  * to you under the Apache License, Version 2.0 (the
  * "License"); you may not use this file except in compliance
  * with the License.  You may obtain a copy of the License at
- * 
+ *
  *   http://www.apache.org/licenses/LICENSE-2.0
- * 
+ *
  * Unless required by applicable law or agreed to in writing,
  * software distributed under the License is distributed on an
  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
@@ -200,7 +200,7 @@ class ChannelAccessRewriter : public IRMutator {
     Expr base = linear_eq[1];
     if (!is_zero(base)) return body;
     Expr left = ir::Simplify(adv_op->value - coeff * for_op->extent);
-    if (!can_prove(left >= 0)) return body;
+    if (!analyzer_.CanProve(left >= 0)) return body;
     // rewrite access index.
     ChannelAccessIndexRewriter rw(
         ch->handle_var.get(), var * coeff, read_access);
@@ -233,6 +233,7 @@ class ChannelAccessRewriter : public IRMutator {
     return body;
   }
 
+  arith::Analyzer analyzer_;
   std::vector<RewriteEntry> tasks_;
 };
 
index 806a80a..eba1cee 100644 (file)
@@ -6,9 +6,9 @@
  * to you under the Apache License, Version 2.0 (the
  * "License"); you may not use this file except in compliance
  * with the License.  You may obtain a copy of the License at
- * 
+ *
  *   http://www.apache.org/licenses/LICENSE-2.0
- * 
+ *
  * Unless required by applicable law or agreed to in writing,
  * software distributed under the License is distributed on an
  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
@@ -606,7 +606,7 @@ class StoragePlanRewriter : public IRMutator {
           }
           // transform to alloc bytes
           auto type_bits = alloc_type.bits() * alloc_type.lanes();
-          bool divided = can_prove(combo_size % type_bits == 0);
+          bool divided = analyzer_.CanProve(combo_size % type_bits == 0);
           combo_size = combo_size / type_bits;
           // round up for can not divided
           if (!divided) {
@@ -920,6 +920,8 @@ class StoragePlanRewriter : public IRMutator {
   std::unordered_map<const Variable*, StorageEntry*> alloc_map_;
   // The allocations
   std::vector<std::unique_ptr<StorageEntry> > alloc_vec_;
+  // analyzer
+  arith::Analyzer analyzer_;
 };
 
 // Turn alloc into vector alloc
index 8c3d383..a48e8b4 100644 (file)
@@ -25,6 +25,7 @@
 #include <tvm/ir.h>
 #include <tvm/ir_pass.h>
 #include <tvm/ir_mutator.h>
+#include <tvm/arithmetic.h>
 #include <unordered_set>
 #include <unordered_map>
 #include <vector>
@@ -132,11 +133,11 @@ class Vectorizer : public IRMutator {
       if (lanes != 1) {
         const Ramp* b_ramp = b.as<Ramp>();
         const Ramp* a_ramp = a.as<Ramp>();
-        if (a_ramp && b.type().lanes() == 1 && can_prove(b > 0)) {
+        if (a_ramp && b.type().lanes() == 1 && analyzer_.CanProve(b > 0)) {
           return Ramp::make(
               a_ramp->base * b, a_ramp->stride * b, a_ramp->lanes);
         }
-        if (b_ramp && a.type().lanes() == 1 && can_prove(a > 0)) {
+        if (b_ramp && a.type().lanes() == 1 && analyzer_.CanProve(a > 0)) {
           return Ramp::make(
               b_ramp->base * a, b_ramp->stride * a, b_ramp->lanes);
         }
@@ -186,7 +187,7 @@ class Vectorizer : public IRMutator {
     Expr stride = this->Mutate(op->stride);
     if (base.type().lanes() > 1 && stride.type().lanes() == 1) {
       const Ramp* base_ramp = base.as<Ramp>();
-      if (can_prove(base_ramp->stride == stride * make_const(stride.type(), op->lanes))) {
+      if (analyzer_.CanProve(base_ramp->stride == stride * make_const(stride.type(), op->lanes))) {
         return Ramp::make(base_ramp->base, stride, op->lanes * base_ramp->lanes);
       }
     }
@@ -423,6 +424,8 @@ class Vectorizer : public IRMutator {
   }
 
  private:
+  // analyzer
+  arith::Analyzer analyzer_;
   // variable to be replaced
   Var var_;
   // the lanes.
index a7f9746..0dc82ab 100644 (file)
@@ -432,9 +432,9 @@ void PassDownBitMaskOr(const Stage& stage,
  */
 void PassUpBoundCheck(const Stage& s,
                       const Map<IterVar, Range>& dom_map,
-                      std::unordered_map<IterVar, bool>* p_state) {
+                      std::unordered_map<IterVar, bool>* p_state,
+                      arith::Analyzer* analyzer) {
   auto& state = *p_state;
-  using HalideIR::Internal::can_prove;
   for (size_t i = s->relations.size(); i != 0; --i) {
     IterVarRelation rel = s->relations[i - 1];
     if (const SplitNode* s = rel.as<SplitNode>()) {
@@ -447,7 +447,7 @@ void PassUpBoundCheck(const Stage& s,
         if (outer || inner) {
           state[s->parent] = true;
         } else {
-          if (can_prove(dom_map.at(s->parent)->extent == factor * step)) {
+          if (analyzer->CanProve(dom_map.at(s->parent)->extent == factor * step)) {
             state[s->parent] = false;
           } else {
             state[s->parent] = true;
@@ -476,11 +476,13 @@ std::vector<Expr> MakeBoundCheck(
     const std::unordered_map<IterVar, Expr>& value_map,
     bool skip_ivar_domain,
     const std::unordered_set<IterVar>& skip_iter) {
+  Analyzer analyzer;
+
   std::unordered_map<IterVar, bool> bound_state;
   for (IterVar iv : stage->leaf_iter_vars) {
     bound_state[iv] = false;
   }
-  PassUpBoundCheck(stage, dom_map, &bound_state);
+  PassUpBoundCheck(stage, dom_map, &bound_state, &analyzer);
 
   std::vector<Expr> preds;
   std::unordered_map<const Variable*, IntSet> iset_dmap;
@@ -496,7 +498,7 @@ std::vector<Expr> MakeBoundCheck(
       Range dom = dom_map.at(iv);
       Expr value = ComputeExpr<Sub>(value_map.at(iv), dom->min);
       Expr vmax = EvalSet(value, iset_dmap).max();
-      if (vmax.type() != value.type() || !can_prove(vmax < dom->extent)) {
+      if (vmax.type() != value.type() || !analyzer.CanProve(vmax < dom->extent)) {
         preds.emplace_back(value < dom->extent);
       }
     }
@@ -511,10 +513,10 @@ std::vector<Expr> MakeBoundCheck(
       Expr vmin = s.min();
       Expr vmax = s.max();
       // The range of `value` resides in [vmin, vmax]
-      if (vmin.type() != value.type() || !can_prove(vmin >= 0)) {
+      if (vmin.type() != value.type() || !analyzer.CanProve(vmin >= 0)) {
         preds.emplace_back(value >= 0);
       }
-      if (vmax.type() != value.type() || !can_prove(vmax < iv->dom->extent)) {
+      if (vmax.type() != value.type() || !analyzer.CanProve(vmax < iv->dom->extent)) {
         preds.emplace_back(value < iv->dom->extent);
       }
     }
index c5f1b16..760ed0f 100644 (file)
@@ -740,7 +740,7 @@ Array<Tensor> Schedule::rfactor(const Tensor& tensor,
   const Reduce* reduce = compute_op->body[idx].as<Reduce>();
   CHECK(reduce) << "Can only rfactor non-inline reductions";
   predicates.push_back(reduce->condition);
-  Expr predicate = likely(simplify(arith::ComputeReduce<ir::And>(predicates, Expr())));
+  Expr predicate = likely(arith::ComputeReduce<ir::And>(predicates, Expr()));
 
   std::unordered_map<const Variable*, Expr> vsub;
 
index 35968f8..5a5dc03 100644 (file)
@@ -6,9 +6,9 @@
  * to you under the Apache License, Version 2.0 (the
  * "License"); you may not use this file except in compliance
  * with the License.  You may obtain a copy of the License at
- * 
+ *
  *   http://www.apache.org/licenses/LICENSE-2.0
- * 
+ *
  * Unless required by applicable law or agreed to in writing,
  * software distributed under the License is distributed on an
  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 #include <gtest/gtest.h>
 #include <tvm/ir_pass.h>
 #include <tvm/tvm.h>
-#include <arithmetic/Simplify.h>
-
-TEST(IRSIMPLIFY, Basic) {
-  using namespace HalideIR::Internal;
-  simplify_test();
-}
 
 TEST(IRSIMPLIFY, MinMax) {
   auto x = tvm::var("x");
index 7fe6f56..d26b508 100644 (file)
 # under the License.
 import tvm
 
+
+def assert_expr_equal(a, b):
+    res =  tvm.ir_pass.Simplify(a - b)
+    equal = isinstance(res, tvm.expr.IntImm) and res.value == 0
+    if not equal:
+        raise ValueError("{} and {} are not equal".format(a, b))
+
+
 def test_deduce():
     a = tvm.var('a')
     b = tvm.var('b')
@@ -29,31 +37,34 @@ def test_deduce():
 
     e0 = (-b)*a+c-d
     res0 = tvm.arith.DeduceBound(a, e0>=0, {b: b_s, c: c_s, d: d_s}, {})
-    ans0 = ((d - c) /(b*-1))
-    assert str(tvm.ir_pass.Simplify(res0.max_value)) == str(ans0)
+    ans0 = ((d - c) /(b*-1) + (-1))
+    assert_expr_equal(res0.max_value, ans0)
 
     # expression containing variable a is on rhs
     res0 = tvm.arith.DeduceBound(a, zero <= e0, {b: b_s, c: c_s, d: d_s}, {})
-    assert str(tvm.ir_pass.Simplify(res0.max_value)) == str(ans0)
+    assert_expr_equal(res0.max_value, ans0)
 
     e0 = d*a+c-d
     res0 = tvm.arith.DeduceBound(a, e0>=0, {b: b_s, c: c_s, d: d_s}, {})
-    ans0 = ((0-c)/d + 1)
-    assert str(tvm.ir_pass.Simplify(res0.max_value)) == str(ans0)
+    ans0 = ((d-c)/d - 1)
+    assert_expr_equal(res0.max_value, ans0)
 
     # expression containing variable a is on rhs
     res0 = tvm.arith.DeduceBound(a, zero <= e0, {b: b_s, c: c_s, d: d_s}, {})
-    assert str(tvm.ir_pass.Simplify(res0.max_value)) == str(ans0)
+    assert_expr_equal(res0.max_value, ans0)
+
 
     e1 = (a*4+b < c)
     res1 = tvm.arith.DeduceBound(a, e1, {b: b_s, c: c_s, d: d_s}, {})
-    ans1 = (((c - b) + -1)/4)
-    assert str(tvm.ir_pass.Simplify(res1.max_value)) == str(ans1)
+    ans1 = (((c - b) + -1)/4 -1)
+    assert_expr_equal(res1.max_value, ans1)
+
 
     # expression containing variable a is on rhs
     e1 = (c > a*4+b)
     res1 = tvm.arith.DeduceBound(a, e1, {b: b_s, c: c_s, d: d_s}, {})
-    assert str(tvm.ir_pass.Simplify(res1.max_value)) == str(ans1)
+    assert_expr_equal(res1.max_value, ans1)
+
 
     e2 = (tvm.max(5, a * 4) < 0)
     res2 = tvm.arith.DeduceBound(a, e2, {b: b_s, c: c_s, d: d_s}, {})
@@ -66,7 +77,6 @@ def test_deduce():
     assert str(res2.max_value) == "neg_inf"
     assert str(res2.min_value) == "pos_inf"
 
-
     e3 = (-b)+a*c-d
     res3 = tvm.arith.DeduceBound(a, e3>=0, {b: b_s, c: c_s, d: d_s}, {b: b_s, d: d_s})
     ans3 = 2/c+1
@@ -75,6 +85,7 @@ def test_deduce():
     res3 = tvm.arith.DeduceBound(a, zero <= e3, {b: b_s, c: c_s, d: d_s}, {b: b_s, d: d_s})
     assert str(tvm.ir_pass.Simplify(res3.min_value)) == str(ans3)
 
+
 def test_check():
     a = tvm.var('a')
     b = tvm.var('b')
index fc76c30..b05d75a 100644 (file)
@@ -24,9 +24,6 @@ def test_simplify():
   assert(tvm.ir_pass.Equal(e2, x * 8))
   e3 = tvm.ir_pass.Simplify(x - x / 3 * 3)
   assert(tvm.ir_pass.Equal(e3, tvm.make.Mod(x, 3)))
-  let = tvm.make.Let(x, 1, x + 3)
-  e4 = tvm.ir_pass.Simplify(let)
-  assert(tvm.ir_pass.Equal(e4, 4))
 
 
 def test_verify_ssa():