* Improved uncommon case of floormod and floordiv. Removed dependence on np floor_div and fmod.
* Fixed clang-format complaints
* Streamlined floormod and floordiv lowering logic
* Improved build times by expressing int64 case of tir FloorMod and FloorDiv using let nodes
* Updated use-def analysis and llvm codegen to support duplicated letnodes.
* Corrected misuse of var_map_ in llvm codegen
* Updated backends that support LetNode
* Changed floormod and div lowering logic to avoid using FP on systems that don't support it.
* Fixed formatting
Co-authored-by: pankratz <pankratz@ualberta.ca>
}
llvm::Value* CodeGenLLVM::VisitExpr_(const LetNode* op) {
- CHECK(!var_map_.count(op->var.get()));
+ auto it = let_binding_.find(op->var);
+ if (it != let_binding_.end()) {
+ CHECK(deep_equal_(it->second->value, op->value))
+ << "Let cannot bind the same var to two different values";
+ } else {
+ let_binding_[op->var] = op;
+ }
var_map_[op->var.get()] = MakeValue(op->value);
analyzer_->Bind(op->var, op->value);
return MakeValue(op->body);
#include <tvm/ir/module.h>
#include <tvm/runtime/container.h>
#include <tvm/target/codegen.h>
+#include <tvm/tir/analysis.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/function.h>
#include <tvm/tir/op.h>
std::unordered_set<const VarNode*> alias_var_set_;
// set of volatile buffer.
std::unordered_set<const VarNode*> volatile_buf_;
+ // deep comparison of PrimExpr
+ ExprDeepEqual deep_equal_;
+ // binding of let variables. Enables duplicate var defs that map to same value
+ std::unordered_map<Var, const LetNode*, ObjectPtrHash, ObjectPtrEqual> let_binding_;
// Cache potential common path ops to slightly improve lookup time.
// global symbol table.
OpAttrMap<TGlobalSymbol> op_attr_global_symbol_ = Op::GetAttrMap<TGlobalSymbol>("TGlobalSymbol");
}
void CodeGenC::VisitExpr_(const LetNode* op, std::ostream& os) { // NOLINT(*)
+ auto it = let_binding_.find(op->var);
+ if (it != let_binding_.end()) {
+ CHECK(deep_equal_(it->second->value, op->value))
+ << "Let cannot bind the same var to two different values";
+ } else {
+ let_binding_[op->var] = op;
+ }
std::string value = PrintExpr(op->value);
- CHECK(!var_idmap_.count(op->var.get()));
var_idmap_[op->var.get()] = value;
os << PrintExpr(op->body);
}
#include <tvm/ir/op.h>
#include <tvm/runtime/container.h>
#include <tvm/target/codegen.h>
+#include <tvm/tir/analysis.h>
#include <tvm/tir/builtin.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/function.h>
bool print_ssa_form_{false};
/*! \brief set of volatile buf access */
std::unordered_set<const VarNode*> volatile_buf_;
+ // deep comparison of PrimExpr
+ ExprDeepEqual deep_equal_;
+ // binding of let variables. Enables duplicate var defs that map to same value
+ std::unordered_map<Var, const LetNode*, ObjectPtrHash, ObjectPtrEqual> let_binding_;
};
} // namespace codegen
}
spirv::Value CodeGenSPIRV::VisitExpr_(const LetNode* op) {
- CHECK(!var_map_.count(op->var.get()));
+ auto it = let_binding_.find(op->var);
+ if (it != let_binding_.end()) {
+ CHECK(deep_equal_(it->second->value, op->value))
+ << "Let cannot bind the same var to two different values";
+ } else {
+ let_binding_[op->var] = op;
+ }
var_map_[op->var.get()] = MakeValue(op->value);
analyzer_->Bind(op->var, op->value);
return MakeValue(op->body);
#define TVM_TARGET_SPIRV_CODEGEN_SPIRV_H_
#include <tvm/arith/analyzer.h>
+#include <tvm/tir/analysis.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/function.h>
#include <tvm/tir/stmt_functor.h>
std::unordered_map<const VarNode*, spirv::Value> var_map_;
// The analyzer.
std::unique_ptr<arith::Analyzer> analyzer_;
+ // deep comparison of PrimExpr
+ ExprDeepEqual deep_equal_;
+ // binding of let variables. Enables duplicate var defs that map to same value
+ std::unordered_map<Var, const LetNode*, ObjectPtrHash, ObjectPtrEqual> let_binding_;
};
} // namespace codegen
}
}
} else {
- // uncommon case
- DLOG(INFO) << "LowerFloorDiv: Cannot decide the sign of divisor";
- // b >= 0 => (rmod >=0 ? rdiv : rdiv - 1)
- // b < 0 => (rmod <= 0 ? rdiv : rdiv - 1)
- PrimExpr rdiv = truncdiv(op->a, op->b);
- PrimExpr rmod = truncmod(op->a, op->b);
- return tir::Select((op->b >= 0 && rmod >= 0) || (op->b < 0 && rmod <= 0), rdiv,
- rdiv - make_const(dtype, 1));
+ if (dtype.is_float()) {
+ // floor(a / b)
+ return VisitExpr_(tvm::floor(op->a / op->b).as<CallNode>());
+ } else {
+ // uncommon case
+ DLOG(INFO) << "LowerFloorDiv: Cannot decide the sign of divisor";
+ auto rmod = tir::Var("rmod", dtype);
+ auto rdiv = tir::Var("rdiv", dtype);
+ // b >= 0 => (rmod >=0 ? rdiv : rdiv - 1)
+ // b < 0 => (rmod <= 0 ? rdiv : rdiv - 1)
+ PrimExpr let_rdiv =
+ tir::Let(rdiv, truncdiv(op->a, op->b),
+ tir::Select((op->b >= 0 && rmod >= 0) || (op->b < 0 && rmod <= 0), rdiv,
+ rdiv - make_const(dtype, 1)));
+ return Let(rmod, truncmod(op->a, op->b), let_rdiv);
+ }
}
}
}
}
} else {
- // uncommon case
- DLOG(INFO) << "LowerFloorMod: Cannot decide the sign of divsor and divident";
- PrimExpr rmod = truncmod(op->a, op->b);
- // b > 0 && rmod >= 0 -> rmod
- // b > 0 && rmod < 0 -> rmod + b
- // b < 0 && rmod < 0 -> rmod
- // b < 0 && rmod > 0 -> rmod + b
- return tir::Select((op->b >= 0 && rmod >= 0) || (op->b < 0 && rmod <= 0), rmod, rmod + op->b);
+ if (dtype.is_float()) {
+ // a - floor(a / b) * b
+ return op->a - (VisitExpr_(tvm::floor(op->a / op->b).as<CallNode>()) * op->b);
+ } else {
+ // uncommon case
+ DLOG(INFO) << "LowerFloorMod: Cannot decide the sign of divsor and divident";
+ auto rmod = tir::Var("rmod", dtype);
+ // b > 0 && rmod >= 0 -> rmod
+ // b > 0 && rmod < 0 -> rmod + b
+ // b < 0 && rmod < 0 -> rmod
+ // b < 0 && rmod > 0 -> rmod + b
+ return Let(
+ rmod, truncmod(op->a, op->b),
+ Select((op->b >= 0 && rmod >= 0) || (op->b < 0 && rmod <= 0), rmod, rmod + op->b));
+ }
}
}
}
PrimExpr VisitExpr_(const LetNode* op) final {
- this->HandleDef(op->var.get());
+ // Weaker SSA condition
+ // A single var can be binded in multiple lets
+ // but they have to bind to the same value.
+ // This is used to allow cases when we reuse a single let
+ // expression to construct a nested expr.
+ // (let x = 1 in x + 1) * (let x = 1 in x + 1)
+ auto it = let_binding_.find(op->var);
+ PrimExpr value = this->VisitExpr(op->value);
+ if (it != let_binding_.end()) {
+ CHECK(deep_equal_(it->second->value, value))
+ << "Let cannot bind the same var to two different values";
+ return GetRef<PrimExpr>(it->second);
+ } else {
+ this->HandleDef(op->var.get());
+ let_binding_[op->var] = op;
+ }
PrimExpr body = this->VisitExpr(op->body);
// eliminate unreferenced let
if (use_count_.at(op->var.get()) == 0 && SideEffect(op->value) <= CallEffectKind::kReadState &&
simplify_let_) {
return body;
} else {
- PrimExpr value = this->VisitExpr(op->value);
if (body.same_as(op->body) && value.same_as(op->value)) {
return GetRef<PrimExpr>(op);
} else {
Array<PrimExpr> thread_extent_;
std::unordered_map<const VarNode*, int> use_count_;
std::unordered_map<const VarNode*, int> def_count_;
+
+ private:
+ ExprDeepEqual deep_equal_;
+ std::unordered_map<Var, const LetNode*, ObjectPtrHash, ObjectPtrEqual> let_binding_;
};
Array<Var> UndefinedVars(const Stmt& stmt, const Array<Var>& args) {
rhs_npy, rhs_nd = gen_operand(rhs_shape, rhs_min, rhs_max, ctx)
out_npy = fnumpy(lhs_npy, rhs_npy)
- if fnumpy == np.floor_divide:
- # avoid check too close to X.5 and X.0
- # FIXME: floor_divide(94.90735, 0.6731018) behaves as floor(div(94.90735, 0.6731018))
- # However the result is somehow incorrect - need to further investigate.
- # And looks like numpy's floor_div(a,b) is implemented different from floor(div(a,b))
- mask = np.logical_or(np.abs(np.abs(np.fmod(lhs_npy / rhs_npy, 1)) - 0.5) < 1e-6,
- np.abs(np.fmod(lhs_npy / rhs_npy, 1)) < 1e-6)
- if mask.any():
- lhs_npy = lhs_npy + mask * 1e-3 * rhs_npy
- lhs_npy = lhs_npy.astype(dtype)
- lhs_nd = tvm.nd.array(lhs_npy, ctx) if lhs_shape is not None else lhs_npy.item()
- out_npy = fnumpy(lhs_npy, rhs_npy)
-
out_nd = tvm.nd.array(np.empty(out_npy.shape).astype(C.dtype), ctx)
foo(lhs_nd, rhs_nd, out_nd)
tvm.testing.assert_allclose(out_nd.asnumpy(), out_npy, rtol=1E-4, atol=1E-4)
(2, 3, 1, 32), (64, 32), topi.divide, np.divide, rhs_min=0.0001)
def test_floor_divide():
+ def _canonical_floor_div(a,b):
+ return np.floor(a / b)
verify_broadcast_binary_ele(
- None, (10,), topi.floor_divide, np.floor_divide, rhs_min=0.0001)
+ None, (10,), topi.floor_divide, _canonical_floor_div, rhs_min=0.0001)
verify_broadcast_binary_ele(
- (), None, topi.floor_divide, np.floor_divide, rhs_min=0.0001)
+ (), None, topi.floor_divide, _canonical_floor_div, rhs_min=0.0001)
verify_broadcast_binary_ele(
- (2, 3, 64, 32), (64, 32), topi.floor_divide, np.floor_divide, rhs_min=0.0001)
+ (2, 3, 64, 32), (64, 32), topi.floor_divide, _canonical_floor_div, rhs_min=0.0001)
def test_maximum_minmum():
verify_broadcast_binary_ele(
(1, 2, 2), (2,), topi.mod, np.mod, lhs_min=0.001, rhs_min=1, dtype="int32")
def test_floor_mod():
+ def _canonical_floor_mod(a,b):
+ return a - np.floor(a / b) * b
verify_broadcast_binary_ele(
- (1, 2, 2), (2,), topi.floor_mod, np.fmod, lhs_min=0.001, rhs_min=1, dtype="int32")
+ (1, 2, 2), (2,), topi.floor_mod, _canonical_floor_mod, lhs_min=0.001, rhs_min=1, dtype="int32")
verify_broadcast_binary_ele(
- (3, 4, 5), (3, 4, 5), topi.floor_mod, np.fmod, lhs_min=0.001, rhs_min=1, dtype="float32")
+ (3, 4, 5), (3, 4, 5), topi.floor_mod, _canonical_floor_mod, lhs_min=0.001, rhs_min=1, dtype="float32")
def test_cmp():
# explicit specify the output type