#include <tvm/ir/transform.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/function.h>
+#include <tvm/tir/op_attr_types.h>
#include <tvm/tir/stmt.h>
#include <string>
TVM_DLL Array<Var> UndefinedVars(const Stmt& stmt, const Array<Var>& defs);
/*!
- * \brief Whether the expression have side effect.
+ * \brief Analyze the side effect
* \param expr The expression to be checked.
- * \return whether expression have side effect
+ *
+ * \return CallEffectKind, can be kPure, kReadState or kUpdateState
*/
-TVM_DLL bool HasSideEffect(const PrimExpr& expr);
+TVM_DLL CallEffectKind SideEffect(const PrimExpr& expr);
/*!
* \brief Whether e expression used any var in variable set..
// components which have side effects should also be preserved
for (size_t i = 0; i < used.size(); ++i) {
- if (HasSideEffect(op->source[i]) || HasSideEffect(op->combiner->identity_element[i]) ||
- HasSideEffect(op->combiner->result[i])) {
+ if (SideEffect(op->source[i]) > CallEffectKind::kReadState ||
+ SideEffect(op->combiner->identity_element[i]) > CallEffectKind::kReadState ||
+ SideEffect(op->combiner->result[i]) > CallEffectKind::kReadState) {
mark_used(i);
}
}
Stmt IRMutatorWithAnalyzer::VisitStmt_(const LetStmtNode* op) {
PrimExpr value = this->VisitExpr(op->value);
- if (!tir::HasSideEffect(value)) {
+ if (SideEffect(value) <= CallEffectKind::kPure) {
analyzer_->Bind(op->var, value);
}
// We keep the let-binding here
PrimExpr IRMutatorWithAnalyzer::VisitExpr_(const LetNode* op) {
PrimExpr value = this->VisitExpr(op->value);
- if (!tir::HasSideEffect(value)) {
+ if (SideEffect(value) <= CallEffectKind::kPure) {
analyzer_->Bind(op->var, value);
}
// We keep the let-binding here
bool has_side_effect = false;
for (size_t i = 0; i < op->indices.size(); ++i) {
- if (HasSideEffect(op->indices[i])) has_side_effect = true;
+ if (SideEffect(op->indices[i]) > CallEffectKind::kReadState) has_side_effect = true;
}
if (has_side_effect) {
for (size_t i = 0; i < args_.size(); ++i) {
class SchedulePostProc : public StmtExprMutator {
public:
Stmt VisitStmt_(const LetStmtNode* op) final {
- if (!HasSideEffect(op->value)) {
+ if (SideEffect(op->value) <= CallEffectKind::kPure) {
var_value_[op->var.get()] = this->VisitExpr(op->value);
return this->VisitStmt(op->body);
} else {
class ExprSideEffect : public ExprVisitor {
public:
void VisitExpr(const PrimExpr& e) final {
- if (has_side_effect_) return;
+ if (kind_ == CallEffectKind::kUpdateState) return;
ExprVisitor::VisitExpr(e);
}
+ void VisitExpr_(const LoadNode* op) final {
+ this->UpdateEffect(CallEffectKind::kReadState);
+ ExprVisitor::VisitExpr_(op);
+ }
+
+ void VisitExpr_(const BufferLoadNode* op) final {
+ this->UpdateEffect(CallEffectKind::kReadState);
+ ExprVisitor::VisitExpr_(op);
+ }
+
void VisitExpr_(const CallNode* op) final {
static auto op_call_effect = Op::GetAttrMap<TCallEffectKind>("TCallEffectKind");
if (auto* ptr_op = op->op.as<OpNode>()) {
- auto effect_kind = op_call_effect[GetRef<Op>(ptr_op)];
- if (effect_kind != CallEffectKind::kPure && effect_kind != CallEffectKind::kExprAnnotation) {
- has_side_effect_ = true;
- return;
- } else {
- ExprVisitor::VisitExpr_(op);
- }
+ this->UpdateEffect(static_cast<CallEffectKind>(op_call_effect[GetRef<Op>(ptr_op)]->value));
} else {
- has_side_effect_ = true;
- return;
+ this->UpdateEffect(CallEffectKind::kOpaque);
+ }
+ ExprVisitor::VisitExpr_(op);
+ }
+
+ void UpdateEffect(CallEffectKind effect_kind) {
+ if (effect_kind > CallEffectKind::kUpdateState) {
+ effect_kind = CallEffectKind::kUpdateState;
+ }
+ if (effect_kind > kind_) {
+ kind_ = effect_kind;
}
}
- bool has_side_effect_{false};
+ CallEffectKind kind_{CallEffectKind::kPure};
};
-bool HasSideEffect(const PrimExpr& e) {
- ExprSideEffect v;
- v(e);
- return v.has_side_effect_;
+CallEffectKind SideEffect(const PrimExpr& e) {
+ ExprSideEffect visitor;
+ visitor(e);
+ return visitor.kind_;
}
} // namespace tir
return is_no_op(op->body) ? op->body : stmt;
}
Stmt VisitStmt_(const EvaluateNode* op) final {
- if (HasSideEffect(op->value)) return GetRef<Stmt>(op);
+ if (SideEffect(op->value) > CallEffectKind::kReadState) return GetRef<Stmt>(op);
return Evaluate(0);
}
private:
Stmt MakeEvaluate(PrimExpr value) {
- if (HasSideEffect(value)) {
+ if (SideEffect(value) > CallEffectKind::kReadState) {
return Evaluate(value);
} else {
return Evaluate(0);
Stmt MakeEvaluate(const Array<PrimExpr>& values) {
Stmt stmt;
for (PrimExpr e : values) {
- if (HasSideEffect(e)) {
+ if (SideEffect(e) > CallEffectKind::kReadState) {
if (stmt.defined()) {
stmt = SeqStmt({stmt, Evaluate(e)});
} else {
// Won't face the deep expression explosion problem as in Let expression.
// attempt to inline as much as possible if the value integer type(can be index).
if (!op->value.dtype().is_int()) return false;
- return !tir::HasSideEffect(op->value);
+ return SideEffect(op->value) <= CallEffectKind::kPure;
}
Stmt VisitStmt_(const LetStmtNode* op) {
this->HandleDef(op->var.get());
Stmt body = this->VisitStmt(op->body);
// eliminate unreferenced let
- if (use_count_.at(op->var.get()) == 0 && !HasSideEffect(op->value) && simplify_let_) {
+ if (use_count_.at(op->var.get()) == 0 && SideEffect(op->value) <= CallEffectKind::kReadState &&
+ simplify_let_) {
return body;
} else {
PrimExpr value = this->VisitExpr(op->value);
this->HandleDef(op->var.get());
PrimExpr body = this->VisitExpr(op->body);
// eliminate unreferenced let
- if (use_count_.at(op->var.get()) == 0 && !HasSideEffect(op->value) && simplify_let_) {
+ if (use_count_.at(op->var.get()) == 0 && SideEffect(op->value) <= CallEffectKind::kReadState &&
+ simplify_let_) {
return body;
} else {
PrimExpr value = this->VisitExpr(op->value);
#include <gtest/gtest.h>
#include <tvm/te/operation.h>
#include <tvm/tir/analysis.h>
+#include <tvm/tir/builtin.h>
-TEST(SimplePasses, HasSideEffect) {
+TEST(SimplePasses, SideEffect) {
using namespace tvm;
- auto n = te::var("n");
- Array<PrimExpr> shape;
- shape.push_back(n);
-
- auto A = te::placeholder(shape, DataType::Float(32), "A");
-
- CHECK(!tvm::tir::HasSideEffect(A[0]));
+ auto A = tir::Var("A", DataType::Handle());
+ auto i = tir::Var("i", DataType::Int(32));
+ CHECK(tir::SideEffect(tir::Load(DataType::Float(32), A, i, tir::const_true(1))) ==
+ tir::CallEffectKind::kReadState);
+ CHECK(tir::SideEffect(exp(tir::Cast(DataType::Float(32), i + 1))) == tir::CallEffectKind::kPure);
+ CHECK(tir::SideEffect(tir::Call(DataType::Handle(), tir::builtin::tvm_storage_sync(), {})) ==
+ tir::CallEffectKind::kUpdateState);
}
int main(int argc, char** argv) {