[TIR][ANALYSIS] Refine side effect analysis. (#5954)
authorTianqi Chen <tqchen@users.noreply.github.com>
Mon, 29 Jun 2020 15:25:46 +0000 (08:25 -0700)
committerGitHub <noreply@github.com>
Mon, 29 Jun 2020 15:25:46 +0000 (08:25 -0700)
include/tvm/tir/analysis.h
src/arith/canonical_simplify.cc
src/arith/ir_mutator_with_analyzer.cc
src/te/schedule/operation_inline.cc
src/te/schedule/schedule_ops.cc
src/tir/analysis/side_effect.cc
src/tir/transforms/remove_no_op.cc
src/tir/transforms/simplify.cc
src/tir/transforms/split_host_device.cc
tests/cpp/tir_analysis_side_effect.cc [moved from tests/cpp/simple_passes_test.cc with 68% similarity]

index 6e7ed41..cbc7a51 100644 (file)
@@ -28,6 +28,7 @@
 #include <tvm/ir/transform.h>
 #include <tvm/tir/expr.h>
 #include <tvm/tir/function.h>
+#include <tvm/tir/op_attr_types.h>
 #include <tvm/tir/stmt.h>
 
 #include <string>
@@ -64,11 +65,12 @@ struct ExprDeepEqual {
 TVM_DLL Array<Var> UndefinedVars(const Stmt& stmt, const Array<Var>& defs);
 
 /*!
- * \brief Whether the expression have side effect.
+ * \brief Analyze the side effect
  * \param expr The expression to be checked.
- * \return whether expression have side effect
+ *
+ * \return CallEffectKind, can be kPure, kReadState or kUpdateState
  */
-TVM_DLL bool HasSideEffect(const PrimExpr& expr);
+TVM_DLL CallEffectKind SideEffect(const PrimExpr& expr);
 
 /*!
  * \brief Whether e expression used any var in variable set..
index 6139f73..726289c 100644 (file)
@@ -1018,8 +1018,9 @@ PrimExpr CanonicalSimplifier::Impl::SimplifyReduceCombiner(const ReduceNode* op)
 
   // components which have side effects should also be preserved
   for (size_t i = 0; i < used.size(); ++i) {
-    if (HasSideEffect(op->source[i]) || HasSideEffect(op->combiner->identity_element[i]) ||
-        HasSideEffect(op->combiner->result[i])) {
+    if (SideEffect(op->source[i]) > CallEffectKind::kReadState ||
+        SideEffect(op->combiner->identity_element[i]) > CallEffectKind::kReadState ||
+        SideEffect(op->combiner->result[i]) > CallEffectKind::kReadState) {
       mark_used(i);
     }
   }
index 2a02661..8fb69b3 100644 (file)
@@ -37,7 +37,7 @@ Stmt IRMutatorWithAnalyzer::VisitStmt_(const ForNode* op) {
 
 Stmt IRMutatorWithAnalyzer::VisitStmt_(const LetStmtNode* op) {
   PrimExpr value = this->VisitExpr(op->value);
-  if (!tir::HasSideEffect(value)) {
+  if (SideEffect(value) <= CallEffectKind::kPure) {
     analyzer_->Bind(op->var, value);
   }
   // We keep the let-binding here
@@ -154,7 +154,7 @@ PrimExpr IRMutatorWithAnalyzer::VisitExpr_(const CallNode* op) {
 
 PrimExpr IRMutatorWithAnalyzer::VisitExpr_(const LetNode* op) {
   PrimExpr value = this->VisitExpr(op->value);
-  if (!tir::HasSideEffect(value)) {
+  if (SideEffect(value) <= CallEffectKind::kPure) {
     analyzer_->Bind(op->var, value);
   }
   // We keep the let-binding here
index fd613f4..aab30ed 100644 (file)
@@ -54,7 +54,7 @@ class OperationInliner final : public StmtExprMutator {
 
       bool has_side_effect = false;
       for (size_t i = 0; i < op->indices.size(); ++i) {
-        if (HasSideEffect(op->indices[i])) has_side_effect = true;
+        if (SideEffect(op->indices[i]) > CallEffectKind::kReadState) has_side_effect = true;
       }
       if (has_side_effect) {
         for (size_t i = 0; i < args_.size(); ++i) {
index f2955f3..e5124df 100644 (file)
@@ -147,7 +147,7 @@ class InjectScanStep : public StmtMutator {
 class SchedulePostProc : public StmtExprMutator {
  public:
   Stmt VisitStmt_(const LetStmtNode* op) final {
-    if (!HasSideEffect(op->value)) {
+    if (SideEffect(op->value) <= CallEffectKind::kPure) {
       var_value_[op->var.get()] = this->VisitExpr(op->value);
       return this->VisitStmt(op->body);
     } else {
index 923cda3..5613961 100644 (file)
@@ -33,34 +33,47 @@ namespace tir {
 class ExprSideEffect : public ExprVisitor {
  public:
   void VisitExpr(const PrimExpr& e) final {
-    if (has_side_effect_) return;
+    if (kind_ == CallEffectKind::kUpdateState) return;
     ExprVisitor::VisitExpr(e);
   }
 
+  void VisitExpr_(const LoadNode* op) final {
+    this->UpdateEffect(CallEffectKind::kReadState);
+    ExprVisitor::VisitExpr_(op);
+  }
+
+  void VisitExpr_(const BufferLoadNode* op) final {
+    this->UpdateEffect(CallEffectKind::kReadState);
+    ExprVisitor::VisitExpr_(op);
+  }
+
   void VisitExpr_(const CallNode* op) final {
     static auto op_call_effect = Op::GetAttrMap<TCallEffectKind>("TCallEffectKind");
 
     if (auto* ptr_op = op->op.as<OpNode>()) {
-      auto effect_kind = op_call_effect[GetRef<Op>(ptr_op)];
-      if (effect_kind != CallEffectKind::kPure && effect_kind != CallEffectKind::kExprAnnotation) {
-        has_side_effect_ = true;
-        return;
-      } else {
-        ExprVisitor::VisitExpr_(op);
-      }
+      this->UpdateEffect(static_cast<CallEffectKind>(op_call_effect[GetRef<Op>(ptr_op)]->value));
     } else {
-      has_side_effect_ = true;
-      return;
+      this->UpdateEffect(CallEffectKind::kOpaque);
+    }
+    ExprVisitor::VisitExpr_(op);
+  }
+
+  void UpdateEffect(CallEffectKind effect_kind) {
+    if (effect_kind > CallEffectKind::kUpdateState) {
+      effect_kind = CallEffectKind::kUpdateState;
+    }
+    if (effect_kind > kind_) {
+      kind_ = effect_kind;
     }
   }
 
-  bool has_side_effect_{false};
+  CallEffectKind kind_{CallEffectKind::kPure};
 };
 
-bool HasSideEffect(const PrimExpr& e) {
-  ExprSideEffect v;
-  v(e);
-  return v.has_side_effect_;
+CallEffectKind SideEffect(const PrimExpr& e) {
+  ExprSideEffect visitor;
+  visitor(e);
+  return visitor.kind_;
 }
 
 }  // namespace tir
index cd3a4b7..baa1c3c 100644 (file)
@@ -90,7 +90,7 @@ class NoOpRemover : public StmtMutator {
     return is_no_op(op->body) ? op->body : stmt;
   }
   Stmt VisitStmt_(const EvaluateNode* op) final {
-    if (HasSideEffect(op->value)) return GetRef<Stmt>(op);
+    if (SideEffect(op->value) > CallEffectKind::kReadState) return GetRef<Stmt>(op);
     return Evaluate(0);
   }
 
@@ -127,7 +127,7 @@ class NoOpRemover : public StmtMutator {
 
  private:
   Stmt MakeEvaluate(PrimExpr value) {
-    if (HasSideEffect(value)) {
+    if (SideEffect(value) > CallEffectKind::kReadState) {
       return Evaluate(value);
     } else {
       return Evaluate(0);
@@ -136,7 +136,7 @@ class NoOpRemover : public StmtMutator {
   Stmt MakeEvaluate(const Array<PrimExpr>& values) {
     Stmt stmt;
     for (PrimExpr e : values) {
-      if (HasSideEffect(e)) {
+      if (SideEffect(e) > CallEffectKind::kReadState) {
         if (stmt.defined()) {
           stmt = SeqStmt({stmt, Evaluate(e)});
         } else {
index 3088b6b..df8816c 100644 (file)
@@ -60,7 +60,7 @@ class StmtSimplifier : public IRMutatorWithAnalyzer {
     // Won't face the deep expression explosion problem as in Let expression.
     // attempt to inline as much as possible if the value integer type(can be index).
     if (!op->value.dtype().is_int()) return false;
-    return !tir::HasSideEffect(op->value);
+    return SideEffect(op->value) <= CallEffectKind::kPure;
   }
 
   Stmt VisitStmt_(const LetStmtNode* op) {
index 75ae743..169ac14 100644 (file)
@@ -70,7 +70,8 @@ class VarUseDefAnalysis : public StmtExprMutator {
     this->HandleDef(op->var.get());
     Stmt body = this->VisitStmt(op->body);
     // eliminate unreferenced let
-    if (use_count_.at(op->var.get()) == 0 && !HasSideEffect(op->value) && simplify_let_) {
+    if (use_count_.at(op->var.get()) == 0 && SideEffect(op->value) <= CallEffectKind::kReadState &&
+        simplify_let_) {
       return body;
     } else {
       PrimExpr value = this->VisitExpr(op->value);
@@ -101,7 +102,8 @@ class VarUseDefAnalysis : public StmtExprMutator {
     this->HandleDef(op->var.get());
     PrimExpr body = this->VisitExpr(op->body);
     // eliminate unreferenced let
-    if (use_count_.at(op->var.get()) == 0 && !HasSideEffect(op->value) && simplify_let_) {
+    if (use_count_.at(op->var.get()) == 0 && SideEffect(op->value) <= CallEffectKind::kReadState &&
+        simplify_let_) {
       return body;
     } else {
       PrimExpr value = this->VisitExpr(op->value);
similarity index 68%
rename from tests/cpp/simple_passes_test.cc
rename to tests/cpp/tir_analysis_side_effect.cc
index 36b3645..26dedab 100644 (file)
 #include <gtest/gtest.h>
 #include <tvm/te/operation.h>
 #include <tvm/tir/analysis.h>
+#include <tvm/tir/builtin.h>
 
-TEST(SimplePasses, HasSideEffect) {
+TEST(SimplePasses, SideEffect) {
   using namespace tvm;
-  auto n = te::var("n");
-  Array<PrimExpr> shape;
-  shape.push_back(n);
-
-  auto A = te::placeholder(shape, DataType::Float(32), "A");
-
-  CHECK(!tvm::tir::HasSideEffect(A[0]));
+  auto A = tir::Var("A", DataType::Handle());
+  auto i = tir::Var("i", DataType::Int(32));
+  CHECK(tir::SideEffect(tir::Load(DataType::Float(32), A, i, tir::const_true(1))) ==
+        tir::CallEffectKind::kReadState);
+  CHECK(tir::SideEffect(exp(tir::Cast(DataType::Float(32), i + 1))) == tir::CallEffectKind::kPure);
+  CHECK(tir::SideEffect(tir::Call(DataType::Handle(), tir::builtin::tvm_storage_sync(), {})) ==
+        tir::CallEffectKind::kUpdateState);
 }
 
 int main(int argc, char** argv) {