CommReducer combiner;
/*! \brief The source operand */
Array<PrimExpr> source;
+ /*! \brief The init operand */
+ Array<PrimExpr> init;
/*! \brief The reduction axis */
Array<IterVar> axis;
/*!
v->Visit("dtype", &dtype);
v->Visit("combiner", &combiner);
v->Visit("source", &source);
+ v->Visit("init", &init);
v->Visit("axis", &axis);
v->Visit("condition", &condition);
v->Visit("value_index", &value_index);
// check axis first so IterVars can define the necessary variables.
return equal(dtype, other->dtype) && equal(axis, other->axis) &&
equal(combiner, other->combiner) && equal(source, other->source) &&
- equal(condition, other->condition) && equal(value_index, other->value_index);
+ equal(init, other->init) && equal(condition, other->condition) &&
+ equal(value_index, other->value_index);
}
void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce(axis);
hash_reduce(combiner);
hash_reduce(source);
+ hash_reduce(init);
hash_reduce(condition);
hash_reduce(value_index);
}
/*! \brief Reference class paired with ReduceNode through TVM_DEFINE_OBJECT_REF_METHODS. */
class Reduce : public PrimExpr {
public:
/*!
 * Constructor taking the combiner, the source expressions (src), the reduction
 * domain (rdom), a condition and a value index. This change appends `init` as a
 * trailing Array<PrimExpr> argument without a default, so every caller must pass it.
 */
TVM_DLL Reduce(CommReducer combiner, Array<PrimExpr> src, Array<IterVar> rdom, PrimExpr condition,
- int value_index);
+ int value_index, Array<PrimExpr> init);
TVM_DEFINE_OBJECT_REF_METHODS(Reduce, PrimExpr, ReduceNode);
};
* \brief sum of source expression over axis
* \param source The source expression.
* \param axis List of iteration variables that will be used for reduction.
+ * \param init The value with which to initialize the output.
* \return The result.
*/
-TVM_DLL PrimExpr sum(PrimExpr source, Array<tir::IterVar> axis);
+TVM_DLL PrimExpr sum(PrimExpr source, Array<tir::IterVar> axis, Array<PrimExpr> init = {});
/*!
* \brief logical And of source expression over axis
* \param source The source expression.
* \param axis List of iteration variables that will be used for reduction.
+ * \param init The value with which to initialize the output.
+ * \return The result.
*/
-TVM_DLL PrimExpr all(PrimExpr source, Array<tir::IterVar> axis);
+TVM_DLL PrimExpr all(PrimExpr source, Array<tir::IterVar> axis, Array<PrimExpr> init = {});
/*!
* \brief logical Or of source expression over axis
* \param source The source expression.
* \param axis List of iteration variables that will be used for reduction.
+ * \param init The value with which to initialize the output.
* \return The result.
*/
-TVM_DLL PrimExpr any(PrimExpr source, Array<tir::IterVar> axis);
+TVM_DLL PrimExpr any(PrimExpr source, Array<tir::IterVar> axis, Array<PrimExpr> init = {});
/*!
* \brief max of source expression over axis
* \param source The source expression.
* \param axis List of iteration variables that will be used for reduction.
+ * \param init The value with which to initialize the output.
* \return The result.
*/
-TVM_DLL PrimExpr max(PrimExpr source, Array<tir::IterVar> axis);
+TVM_DLL PrimExpr max(PrimExpr source, Array<tir::IterVar> axis, Array<PrimExpr> init = {});
/*!
* \brief min of source expression over axis
* \param source The source expression.
* \param axis List of iteration variables that will be used for reduction.
+ * \param init The value with which to initialize the output.
* \return The result.
*/
-TVM_DLL PrimExpr min(PrimExpr source, Array<tir::IterVar> axis);
+TVM_DLL PrimExpr min(PrimExpr source, Array<tir::IterVar> axis, Array<PrimExpr> init = {});
/*!
* \brief product of source expression over axis
* \param source The source expression.
* \param axis List of iteration variables that will be used for reduction.
+ * \param init The value with which to initialize the output.
* \return The result.
*/
-TVM_DLL PrimExpr prod(PrimExpr source, Array<tir::IterVar> axis);
+TVM_DLL PrimExpr prod(PrimExpr source, Array<tir::IterVar> axis, Array<PrimExpr> init = {});
/*!
* \brief Calculate floor(x)
using namespace tvm::te;
/*!
 * \brief The operation to use for CommReduce
 *
 * A callable that receives the source expression, the reduction axes and an array of
 * init expressions, and returns a PrimExpr.
 */
-using FReduce = std::function<PrimExpr(PrimExpr source, const Array<IterVar>& axis)>;
+using FReduce =
+ std::function<PrimExpr(PrimExpr source, const Array<IterVar>& axis, Array<PrimExpr> init)>;
/*! \brief The operation to use for CommReduceIdx */
using FCommReduce = std::function<Array<PrimExpr>(Array<PrimExpr> exprs, const Array<IterVar>& axis,
arg_counter++;
}
- return func(data(eval_range), r_axes);
+ return func(data(eval_range), r_axes, {});
};
return tvm::te::compute(target_shape, compute, data->op->name + "_red", kCommReduce);
auto combiner = tvm::tir::CommReducer(lhs, rhs, result, id_elem);
Array<PrimExpr> outputs;
for (size_t i = 0; i < exprs.size(); ++i) {
- outputs.push_back(tvm::tir::Reduce(combiner, exprs, axis, cond, static_cast<int>(i)));
+ outputs.push_back(tvm::tir::Reduce(combiner, exprs, axis, cond, static_cast<int>(i), {}));
}
return outputs;
};
}
/*!
 * \brief Wrap tvm::min to ensure we get the correct overload
 * \param source The expression to reduce.
 * \param axis The reduction axes.
 * \param init Init expressions forwarded unchanged to tvm::min; defaults to an empty array.
 * \return The expression returned by tvm::min.
 */
-inline PrimExpr MinOp(PrimExpr source, Array<IterVar> axis) { return tvm::min(source, axis); }
+inline PrimExpr MinOp(PrimExpr source, Array<IterVar> axis, Array<PrimExpr> init = {}) {
+ return tvm::min(source, axis, init);
+}
/*!
 * \brief Wrap tvm::max to ensure we get the correct overload
 * \param source The expression to reduce.
 * \param axis The reduction axes.
 * \param init Init expressions forwarded unchanged to tvm::max; defaults to an empty array.
 * \return The expression returned by tvm::max.
 */
-inline PrimExpr MaxOp(PrimExpr source, Array<IterVar> axis) {
- return tvm::max(source, axis); // NOLINT(*)
+inline PrimExpr MaxOp(PrimExpr source, Array<IterVar> axis, Array<PrimExpr> init = {}) {
+ return tvm::max(source, axis, init); // NOLINT(*)
}
/*!
 * \brief Wrap tvm::prod to ensure we get the correct overload
 * \param source The expression to reduce.
 * \param axis The reduction axes.
 * \param init Init expressions forwarded unchanged to tvm::prod; defaults to an empty array.
 * \return The expression returned by tvm::prod.
 */
-inline PrimExpr ProdOp(PrimExpr source, Array<IterVar> axis) {
- return tvm::prod(source, axis); // NOLINT(*)
+inline PrimExpr ProdOp(PrimExpr source, Array<IterVar> axis, Array<PrimExpr> init = {}) {
+ return tvm::prod(source, axis, init); // NOLINT(*)
}
/*!
value_index : int
The value index.
+
+ init : list of Expr
+ The initial value for output. This can be an int, float or ProducerLoad
"""
# Builds the node by handing every argument, in order, to the _ffi_api.Reduce constructor.
# NOTE(review): the default init=None is forwarded as-is, whereas _make_reduce passes
# convert([]) when no init is given -- confirm the FFI turns None into an empty init.
- def __init__(self, combiner, src, rdom, condition, value_index):
+ def __init__(self, combiner, src, rdom, condition, value_index, init=None):
self.__init_handle_by_constructor__(
_ffi_api.Reduce, combiner, src, rdom,
- condition, value_index)
+ condition, value_index, init)
@tvm._ffi.register_object
res = fcombine(res, args[i+1])
return res
- def _make_reduce(expr, axis, where=None):
+ def _make_reduce(expr, axis, where=None, init=None):
code = fcombine.__code__
assert fcombine.__code__.co_argcount == 2
expr = convert(expr)
+ if init is not None:
+ init = convert(init)
if isinstance(expr, Array):
size = len(expr)
larr = []
larr.append(Var(lname, dtype))
rname = code.co_varnames[1] + "_" + str(i)
rarr.append(Var(rname, dtype))
+ if init is not None:
+ init = convert(init)
+ assert isinstance(init, Array)
+ assert len(init) == size
+ for init_i in range(size):
+ init_i = convert(init_i)
+ assert isinstance(init_i,
+ (tvm.tir.ProducerLoad, tvm.tir.IntImm, tvm.tir.FloatImm))
+ else:
+ init = convert([])
lhs = convert(larr)
rhs = convert(rarr)
result = fcombine(lhs, rhs)
lhs = convert([lvar])
rhs = convert([rvar])
expr = convert([expr])
+ if init is not None:
+ assert isinstance(init, (tvm.tir.ProducerLoad, tvm.tir.IntImm, tvm.tir.FloatImm))
+ init = convert([init])
result = convert(result)
id_elem = convert(id_elem)
combiner = CommReducer(lhs, rhs, result, id_elem)
axis = convert(axis if isinstance(axis, (list, tuple)) else [axis])
if where is None:
where = convert(True)
- outputs = tuple(tvm.tir.Reduce(combiner, expr, axis, where, i)
- for i in range(size))
+ if init is None:
+ outputs = tuple(tvm.tir.Reduce(combiner, expr, axis, where, i, convert([]))
+ for i in range(size))
+ else:
+ outputs = tuple(tvm.tir.Reduce(combiner, expr, axis, where, i, init)
+ for i in range(size))
return outputs[0] if size == 1 else outputs
# pylint: disable=keyword-arg-before-vararg
- def reducer(expr, axis, where=None, *args):
+ def reducer(expr, axis, where=None, init=None, *args):
if isinstance(axis, (tvm.tir.IterVar, list, tuple)):
assert not args
- return _make_reduce(expr, axis, where)
+ return _make_reduce(expr, axis, where, init)
if where is None:
assert not args
return _reduce_directly(expr, axis)
for (size_t i = 0; i < used.size(); ++i) {
if (SideEffect(op->source[i]) > CallEffectKind::kReadState ||
SideEffect(op->combiner->identity_element[i]) > CallEffectKind::kReadState ||
- SideEffect(op->combiner->result[i]) > CallEffectKind::kReadState) {
+ SideEffect(op->combiner->result[i]) > CallEffectKind::kReadState ||
+ (!op->init.empty() && SideEffect(op->init[i]) > CallEffectKind::kReadState)) {
mark_used(i);
}
}
Array<Var> new_lhs;
Array<Var> new_rhs;
Array<PrimExpr> new_source;
+ Array<PrimExpr> new_init;
// new stuff is old stuff which is used
for (size_t i = 0; i < used.size(); ++i) {
new_lhs.push_back(op->combiner->lhs[i]);
new_rhs.push_back(op->combiner->rhs[i]);
new_source.push_back(op->source[i]);
+ if (!op->init.empty()) new_init.push_back(op->init[i]);
} else if (static_cast<int>(i) < op->value_index) {
// value_index should also be adjusted
new_value_index--;
}
CommReducer new_combiner = CommReducer(new_lhs, new_rhs, new_result, new_identity);
- return Reduce(new_combiner, new_source, op->axis, op->condition, new_value_index);
+ return Reduce(new_combiner, new_source, op->axis, op->condition, new_value_index, new_init);
}
PrimExpr CanonicalSimplifier::Impl::VisitExpr_(const ReduceNode* op) {
// already been simplified by const reduction axis removal
if (op == nullptr) return ret;
if (op->axis.empty()) {
+ if (!op->init.empty()) {
+ return this->VisitExpr(Select(op->condition,
+ (*op->combiner.get())(op->init, op->source)[op->value_index],
+ op->init[op->value_index]));
+ }
// Note that here we assume that the identity element is indeed identity. Without this
// assumption we would have to perform a single iteration of the loop, i.e. use
// `(*op->combiner.get())(op->combiner->identity_element, op->source)[op->value_index]`
// Prints a ReduceNode as reduce(<combiner>, <source>, <axis>, <value_index>, <init>).
// Note that op->condition is not written by this printer.
Doc TIRTextPrinter::VisitExpr_(const ReduceNode* op) {
Doc doc;
doc << "reduce(" << Print(op->combiner) << ", " << Print(op->source) << ", " << Print(op->axis)
- << ", " << op->value_index << ")";
+ << ", " << op->value_index << ", " << Print(op->init) << ")";
return doc;
}
// Perform simplification mainly to remove a possibly empty reduction.
arith::Analyzer analyzer;
- return analyzer.Simplify(
- Reduce(red->combiner, new_source, new_axis, All(res->dst->relations), red->value_index),
- kSimplifyRewriteCanonicalRewrite);
+ return analyzer.Simplify(Reduce(red->combiner, new_source, new_axis, All(res->dst->relations),
+ red->value_index, red->init),
+ kSimplifyRewriteCanonicalRewrite);
} else {
return expr;
}
virtual PrimExpr VisitExpr_(const ReduceNode* op) {
Array<PrimExpr> known_with_axes = known_;
+ CHECK(op->init.empty()) << "Derivative of Reduction with initialization is not implemented";
for (const PrimExpr& axis_cond : IterVarsToInequalities(op->axis)) {
known_with_axes.push_back(axis_cond);
}
new_source.push_back(new_mutator(src));
}
- return Reduce(op->combiner, new_source, op->axis, new_cond, op->value_index);
+ return Reduce(op->combiner, new_source, op->axis, new_cond, op->value_index, op->init);
}
virtual PrimExpr VisitExpr_(const EQNode* op) { return MutateAtomic_(GetRef<PrimExpr>(op)); }
ReductionAsTensorAccessMutator new_mutator(Concat(IterVarsToVars(op->axis), outer_axis_),
Merge(vranges_, IterVarsToMap(op->axis)), name_);
+ CHECK(op->init.empty()) << "Derivative of Reduction with initialization is not implemented";
Array<PrimExpr> new_source;
for (const PrimExpr& src : op->source) {
new_source.push_back(new_mutator(src));
}
PrimExpr new_reduce =
- Reduce(op->combiner, new_source, op->axis, op->condition, op->value_index);
+ Reduce(op->combiner, new_source, op->axis, op->condition, op->value_index, op->init);
Array<Var> undefined_vars = UndefinedVars(new_reduce);
std::unordered_set<const VarNode*> undefined_var_set;
}
PrimExpr new_condition = ReductionAsTensorAccess(red->condition, new_outer_axis, new_vranges);
- return Reduce(red->combiner, new_source, red->axis, new_condition, red->value_index);
+ return Reduce(red->combiner, new_source, red->axis, new_condition, red->value_index, red->init);
} else {
return ReductionAsTensorAccess(expr, outer_axis, vranges);
}
PrimExpr expr = analyzer.Simplify(expr_orig, kSimplifyRewriteCanonicalRewrite);
if (const ReduceNode* red = expr.as<ReduceNode>()) {
+ CHECK(red->init.empty()) << "Derivative of Reduction with initialization is not implemented";
// TODO(sgrechanik-h): There are some other operations which behave like sum
bool is_sum = IsSumCombiner(red->combiner, vranges);
if (is_sum || CanFactorZeroFromCombiner(red->combiner, red->value_index, vranges)) {
source.Set(0, nz.value);
}
- new_red = Reduce(red->combiner, source, red->axis, cond, red->value_index);
+ new_red = Reduce(red->combiner, source, red->axis, cond, red->value_index, red->init);
new_red = SimplifyReductionDomain(new_red, combined_vranges);
// If the reduction disappears completely then transform the result as a non-reduction
if (!new_red.as<ReduceNode>()) {
new_source.Set(red->value_index, Select(nz_cond, nz_source, make_zero(nz_source.dtype())));
}
- PrimExpr new_reduce =
- Reduce(red->combiner, new_source, red->axis, new_reduce_cond, red->value_index);
+ PrimExpr new_reduce = Reduce(red->combiner, new_source, red->axis, new_reduce_cond,
+ red->value_index, red->init);
new_reduce =
TrySimplifyCompute(new_reduce, new_outer_cond, IterVarsToVars(axis), combined_vranges);
result = Select(new_outer_cond, new_reduce, make_zero(new_reduce.dtype()));
for (const auto& src : red->source) {
src_with_newaxis.push_back(tir::Substitute(src, vmap));
}
+ Array<PrimExpr> init_with_newaxis;
+ for (const auto& init : red->init) {
+ init_with_newaxis.push_back(tir::Substitute(init, vmap));
+ }
return Reduce(red->combiner, src_with_newaxis, new_axis, tir::Substitute(red->condition, vmap),
- red->value_index);
+ red->value_index, init_with_newaxis);
} else {
return expr;
}
// If this is a reduction then we have to replicate it
if (const ReduceNode* red = exprs[0].as<ReduceNode>()) {
for (size_t i = 0; i < red->source.size(); ++i) {
- PrimExpr ith_red = Reduce(red->combiner, red->source, red->axis, red->condition, i);
+ PrimExpr ith_red =
+ Reduce(red->combiner, red->source, red->axis, red->condition, i, red->init);
new_exprs.push_back(ith_red);
}
} else {
PrimExpr expr_with_new_axes = te::CloneReduction(GetRef<PrimExpr>(op));
const ReduceNode* new_op = expr_with_new_axes.as<ReduceNode>();
+ CHECK(new_op->init.empty()) << "Derivative of Reduction with initialization is not implemented";
+
// New lhs and rhs variables of the new combiner consist of
// variables representing derivatives (which are later derived from new_op->source)
// followed by the original variables.
CommReducer new_combiner = CommReducer(new_lhs, new_rhs, new_result, new_identity);
// Also simplify the resulting combiner
// (mostly to get rid of unused components, e.g., the original expressions)
- return analyzer_.Simplify(
- Reduce(new_combiner, new_source, new_op->axis, new_op->condition, new_op->value_index));
+ return analyzer_.Simplify(Reduce(new_combiner, new_source, new_op->axis, new_op->condition,
+ new_op->value_index, new_op->init));
}
PrimExpr VisitExpr_(const CastNode* op) {
if (const ReduceNode* red = new_body.as<ReduceNode>()) {
value_index = red->value_index;
for (size_t idx = 0; idx < red->source.size(); ++idx) {
- new_bodies.push_back(Reduce(red->combiner, red->source, red->axis, red->condition, idx));
+ new_bodies.push_back(
+ Reduce(red->combiner, red->source, red->axis, red->condition, idx, red->init));
}
} else {
new_bodies.push_back(new_body);
// Returns true when a and b agree, under same_as, on combiner, source, axis and
// condition, and their init arrays are either both empty or agree under same_as.
// value_index is deliberately not part of the comparison.
inline bool ReduceEqual(const tir::ReduceNode* a, const tir::ReduceNode* b) {
return (a->combiner.same_as(b->combiner)) && (a->source.same_as(b->source)) &&
- (a->axis.same_as(b->axis)) && (a->condition.same_as(b->condition));
+ (a->axis.same_as(b->axis)) && (a->condition.same_as(b->condition)) &&
+ ((a->init.empty() && b->init.empty()) || (a->init.same_as(b->init)));
}
int ComputeOpNode::num_outputs() const { return body.size(); }
}
Array<PrimExpr> init_value = combiner->identity_element;
Array<PrimExpr> update_value = (*combiner)(lhs, reduce->source);
+
+ // If an init was passed to ReduceNode, use that for initialization
+ // instead of combiner->identity_element
+ Array<PrimExpr> reduce_init = reduce->init;
+ if (!reduce_init.empty()) {
+ init_value = reduce_init;
+ }
for (size_t i = 0; i < size; ++i) {
Tensor t = tensors[i];
inits.emplace_back(ProducerStore(t, init_value[i], args));
for (size_t i = 0; i < size; ++i) {
const ReduceNode* reduce = self->body[i].as<ReduceNode>();
CHECK(reduce);
+ CHECK(reduce->init.empty()) << "Cannot perform cross_thread_reduction for reductions with init";
reduces[i] = reduce;
}
axis.push_back(it->second);
}
}
- return Reduce(op->combiner, op->source, axis, op->condition, op->value_index);
+ return Reduce(op->combiner, op->source, axis, op->condition, op->value_index, op->init);
}
void Init(const ComputeOpNode* self, const Stage& stage,
return new_e;
} else {
return tir::Reduce(new_combiner, new_reduce->source, new_reduce->axis, new_reduce->condition,
- new_reduce->value_index);
+ new_reduce->value_index, new_reduce->init);
}
}
// Returns true when a and b agree, under same_as, on combiner, source, axis and
// condition, and their init arrays are either both empty or agree under same_as.
// value_index is deliberately not part of the comparison.
inline bool ReduceEqual(const tir::ReduceNode* a, const tir::ReduceNode* b) {
return (a->combiner.same_as(b->combiner)) && (a->source.same_as(b->source)) &&
- (a->axis.same_as(b->axis)) && (a->condition.same_as(b->condition));
+ (a->axis.same_as(b->axis)) && (a->condition.same_as(b->condition)) &&
+ ((a->init.empty() && b->init.empty()) || (a->init.same_as(b->init)));
}
Tensor Schedule::cache_read(const Tensor& tensor, const std::string& scope,
if (first_reduce != nullptr) {
CHECK(ReduceEqual(reduce_body, first_reduce));
body = tir::Reduce(first_reduce->combiner, first_reduce->source, first_reduce->axis,
- first_reduce->condition, reduce_body->value_index);
+ first_reduce->condition, reduce_body->value_index, reduce_body->init);
} else {
first_reduce = reduce_body;
}
std::vector<PrimExpr> body;
for (size_t idx = 0; idx < reduce->source.size(); ++idx) {
- body.emplace_back(Reduce(reduce->combiner, new_source, n->reduce_axis, new_pred, idx));
+ body.emplace_back(Reduce(reduce->combiner, new_source, n->reduce_axis, new_pred, idx, {}));
}
n->body = Array<PrimExpr>(body);
// refresh relations, keep the un-touched relations.
}
indices.push_back(i[idx]);
}
+ Array<PrimExpr> new_init = reduce->init;
+ if (!reduce->init.empty()) {
+ std::unordered_map<const VarNode*, PrimExpr> init_vsub;
+ for (const auto& init : reduce->init) {
+ if (init->IsInstance<ProducerLoadNode>()) {
+ CHECK_EQ(compute_op->axis.size(), idx_size)
+ << "'init' should have the number of dimensions as output when using with "
+ "rfactor";
+ for (int idx = 0; idx < idx_size; idx++) {
+ init_vsub[compute_op->axis[idx]->var.get()] = i[idx];
+ }
+ }
+ }
+ VarReplacer init_replacer(init_vsub);
+ new_init = tir::UpdateArray(
+ reduce->init, [&init_replacer](const PrimExpr& e) { return init_replacer(e); });
+ }
if (factor_axis_pos == idx_size) {
indices.push_back(repl_red_axis->var);
}
Array<IterVar> axis = {repl_red_axis};
PrimExpr cond = const_true();
for (int idx = 0; idx < size; ++idx) {
- reductions.push_back(Reduce(reduce->combiner, factor_exprs, axis, cond, idx));
+ reductions.push_back(Reduce(reduce->combiner, factor_exprs, axis, cond, idx, new_init));
}
return reductions;
},
// Reduce
Reduce::Reduce(CommReducer combiner, Array<PrimExpr> source, Array<IterVar> axis,
- PrimExpr condition, int value_index) {
+ PrimExpr condition, int value_index, Array<PrimExpr> init) {
for (size_t i = 0; i < axis.size(); ++i) {
CHECK_EQ(axis[i]->iter_type, kCommReduce) << "Can only take axis created by reduce_axis";
}
for (size_t i = 0; i < axis.size(); ++i) {
CHECK(axis[i].defined());
}
+ if (!init.empty()) {
+ CHECK_EQ(init.size(), source.size()) << "Number of inits should match number of exprs";
+ for (size_t i = 0; i < init.size(); i++) {
+ CHECK(init[i]->IsInstance<ProducerLoadNode>() || init[i]->IsInstance<IntImmNode>() ||
+ init[i]->IsInstance<FloatImmNode>())
+ << "init can only be a IntImm, FloatImm or ProducerLoad";
+ }
+ }
n->dtype = source[value_index].dtype();
n->combiner = std::move(combiner);
n->source = std::move(source);
+ n->init = std::move(init);
n->axis = std::move(axis);
n->condition = condition;
n->value_index = value_index;
// Registers the global function "tir.Reduce"; the lambda forwards its arguments, in
// order, to the Reduce constructor. `init` becomes the sixth argument with this change.
TVM_REGISTER_GLOBAL("tir.Reduce")
.set_body_typed([](CommReducer combiner, Array<PrimExpr> source, Array<IterVar> axis,
- PrimExpr condition, int value_index) {
- return Reduce(combiner, source, axis, condition, value_index);
+ PrimExpr condition, int value_index, Array<PrimExpr> init) {
+ return Reduce(combiner, source, axis, condition, value_index, init);
});
TVM_REGISTER_NODE_TYPE(ReduceNode);
auto* op = static_cast<const ReduceNode*>(node.get());
p->stream << "reduce(combiner=" << op->combiner;
p->stream << ", source=" << op->source;
+ p->stream << ", init=" << op->init;
p->stream << ", axis=" << op->axis;
p->stream << ", where=" << op->condition;
p->stream << ", value_index=" << op->value_index;
this->VisitExpr(r->dom->extent);
});
VisitArray(op->source, [this](const PrimExpr& e) { this->VisitExpr(e); });
+ if (!op->init.empty()) {
+ VisitArray(op->init, [this](const PrimExpr& e) { this->VisitExpr(e); });
+ }
this->VisitExpr(op->condition);
}
auto fexpr = [this](const PrimExpr& e) { return this->VisitExpr(e); };
Array<PrimExpr> source = MutateArray(op->source, fexpr);
+ Array<PrimExpr> init = MutateArray(op->init, fexpr);
PrimExpr condition = this->VisitExpr(op->condition);
- if (axis.same_as(op->axis) && source.same_as(op->source) && condition.same_as(op->condition)) {
+ if (axis.same_as(op->axis) && source.same_as(op->source) && condition.same_as(op->condition) &&
+ init.same_as(op->init)) {
return GetRef<PrimExpr>(op);
} else {
- return Reduce(op->combiner, source, axis, condition, op->value_index);
+ return Reduce(op->combiner, source, axis, condition, op->value_index, init);
}
}
// isfinite
// Builds the expression !isinf(x) && !isnan(x).
PrimExpr isfinite(PrimExpr x) { return !isinf(x) && !isnan(x); }
// Builds a single-output Reduce over rdom whose combiner is tir::Add(x, y) with identity
// element make_zero(source.dtype()). The condition is a constant true and the value
// index is 0; `init` is passed through to tir::Reduce unchanged.
-PrimExpr sum(PrimExpr source, Array<IterVar> rdom) {
+PrimExpr sum(PrimExpr source, Array<IterVar> rdom, Array<PrimExpr> init) {
Var x("x", source.dtype()), y("y", source.dtype());
PrimExpr result = tir::Add(x, y);
PrimExpr identity_element = make_zero(source.dtype());
tir::CommReducer combiner = tir::CommReducer({x}, {y}, {result}, {identity_element});
- return tir::Reduce(combiner, {source}, rdom, make_const(DataType::Bool(1), true), 0);
+ return tir::Reduce(combiner, {source}, rdom, make_const(DataType::Bool(1), true), 0, init);
}
// Builds a single-output Reduce over rdom whose combiner is tir::And(x, y) with identity
// element `true`. Requires a boolean source (enforced by the CHECK). The condition is a
// constant true and the value index is 0; `init` is passed through to tir::Reduce unchanged.
-PrimExpr all(PrimExpr source, Array<IterVar> rdom) {
+PrimExpr all(PrimExpr source, Array<IterVar> rdom, Array<PrimExpr> init) {
CHECK(source.dtype().is_bool());
Var x("x", source.dtype()), y("y", source.dtype());
PrimExpr result = tir::And(x, y);
PrimExpr identity_element = make_const(source.dtype(), true);
tir::CommReducer combiner = tir::CommReducer({x}, {y}, {result}, {identity_element});
- return tir::Reduce(combiner, {source}, rdom, make_const(DataType::Bool(1), true), 0);
+ return tir::Reduce(combiner, {source}, rdom, make_const(DataType::Bool(1), true), 0, init);
}
// Builds a single-output Reduce over rdom whose combiner is tir::Or(x, y) with identity
// element `false`. Requires a boolean source (enforced by the CHECK). The condition is a
// constant true and the value index is 0; `init` is passed through to tir::Reduce unchanged.
-PrimExpr any(PrimExpr source, Array<IterVar> rdom) {
+PrimExpr any(PrimExpr source, Array<IterVar> rdom, Array<PrimExpr> init) {
CHECK(source.dtype().is_bool());
Var x("x", source.dtype()), y("y", source.dtype());
PrimExpr result = tir::Or(x, y);
PrimExpr identity_element = make_const(source.dtype(), false);
tir::CommReducer combiner = tir::CommReducer({x}, {y}, {result}, {identity_element});
- return tir::Reduce(combiner, {source}, rdom, make_const(DataType::Bool(1), true), 0);
+ return tir::Reduce(combiner, {source}, rdom, make_const(DataType::Bool(1), true), 0, init);
}
// Builds a single-output Reduce over rdom whose combiner is tir::Max(x, y) with identity
// element min_value(source.dtype()). The condition is a constant true and the value
// index is 0; `init` is passed through to tir::Reduce unchanged.
-PrimExpr max(PrimExpr source, Array<IterVar> rdom) {
+PrimExpr max(PrimExpr source, Array<IterVar> rdom, Array<PrimExpr> init) {
Var x("x", source.dtype()), y("y", source.dtype());
PrimExpr result = tir::Max(x, y);
PrimExpr identity_element = min_value(source.dtype());
tir::CommReducer combiner = tir::CommReducer({x}, {y}, {result}, {identity_element});
- return tir::Reduce(combiner, {source}, rdom, make_const(DataType::Bool(1), true), 0);
+ return tir::Reduce(combiner, {source}, rdom, make_const(DataType::Bool(1), true), 0, init);
}
// Builds a single-output Reduce over rdom whose combiner is tir::Min(x, y) with identity
// element max_value(source.dtype()). The condition is a constant true and the value
// index is 0; `init` is passed through to tir::Reduce unchanged.
-PrimExpr min(PrimExpr source, Array<IterVar> rdom) {
+PrimExpr min(PrimExpr source, Array<IterVar> rdom, Array<PrimExpr> init) {
Var x("x", source.dtype()), y("y", source.dtype());
PrimExpr result = tir::Min(x, y);
PrimExpr identity_element = max_value(source.dtype());
tir::CommReducer combiner = tir::CommReducer({x}, {y}, {result}, {identity_element});
- return tir::Reduce(combiner, {source}, rdom, make_const(DataType::Bool(1), true), 0);
+ return tir::Reduce(combiner, {source}, rdom, make_const(DataType::Bool(1), true), 0, init);
}
// Builds a single-output Reduce over rdom whose combiner is tir::Mul(x, y) with identity
// element make_const(source.dtype(), 1). The condition is a constant true and the value
// index is 0; `init` is passed through to tir::Reduce unchanged.
-PrimExpr prod(PrimExpr source, Array<IterVar> rdom) {
+PrimExpr prod(PrimExpr source, Array<IterVar> rdom, Array<PrimExpr> init) {
Var x("x", source.dtype()), y("y", source.dtype());
PrimExpr result = tir::Mul(x, y);
PrimExpr identity_element = make_const(source.dtype(), 1);
tir::CommReducer combiner = tir::CommReducer({x}, {y}, {result}, {identity_element});
- return tir::Reduce(combiner, {source}, rdom, make_const(DataType::Bool(1), true), 0);
+ return tir::Reduce(combiner, {source}, rdom, make_const(DataType::Bool(1), true), 0, init);
}
// fmod
test_prim(tvm.te.min, np.amin)
test_prim(tvm.te.max, np.amax)
# Sum reduction with an immediate init value: builds B = te.sum(A[k], init=10.0) for llvm
# and checks the result against 10.0 + np.sum(a).
+def test_init_imm():
+ n = tvm.runtime.convert(1027)
+ A = te.placeholder((n,), name='A')
+ k = te.reduce_axis((0, n))
+ B = te.compute((1,), lambda i: te.sum(A[k], axis=k, init=10.0), name='B')
+ # schedule
+ s = te.create_schedule(B.op)
+ # one line to build the function.
+ def check_target(target="llvm"):
+ if not tvm.runtime.enabled(target):
+ return
+ ctx = tvm.cpu(0)
+ fapi = tvm.lower(s, args=[A, B])
+ fsum = tvm.build(fapi,
+ target=target,
+ name="mysum")
+ # launch the kernel.
+ n = 1027
+ a = tvm.nd.array(np.random.uniform(size=(n,)).astype(A.dtype), ctx)
+ b = tvm.nd.array(np.zeros(1, dtype=B.dtype), ctx)
+ fsum(a, b)
+ res = 10.0 + np.sum(a.asnumpy(), axis=0)
+ tvm.testing.assert_allclose(
+ b.asnumpy(), res, rtol=1e-4)
+
+ check_target()
+
+def test_init():
+ n = tvm.runtime.convert(1027)
+ A = te.placeholder((n,n), name='A')
+ C = te.placeholder((n,n), name='C')
+ I = te.placeholder((n,n), name='I')
+ k = te.reduce_axis((0, n))
+ B = te.compute((n,n), lambda i,j: te.sum(A[i,k]*C[k,j], axis=k, init=I[i,j]), name='B')
+
+ # schedule
+ s = te.create_schedule(B.op)
+ # one line to build the function.
+ def check_target(target="llvm"):
+ if not tvm.runtime.enabled(target):
+ return
+ ctx = tvm.cpu(0)
+ fapi = tvm.lower(s, args=[A, C, I, B])
+ print(fapi)
+ mmult = tvm.build(fapi, target=target, name="mmult")
+ # launch the kernel.
+ n = 1027
+ a = tvm.nd.array(np.random.uniform(size=(n,n)).astype(A.dtype), ctx)
+ c = tvm.nd.array(np.random.uniform(size=(n,n)).astype(C.dtype), ctx)
+ ii = tvm.nd.array(np.random.uniform(size=(n,n)).astype(B.dtype), ctx)
+ b = tvm.nd.array(np.zeros((n,n), dtype=B.dtype), ctx)
+ mmult(a, c, ii, b)
+ res = ii.asnumpy() + np.matmul(a.asnumpy(),c.asnumpy())
+ tvm.testing.assert_allclose(
+ b.asnumpy(), res, rtol=1e-4)
+
+ check_target()
def test_rfactor():
n = tvm.runtime.convert(1027)
check_target()
+def test_rfactor_init():
+ n = tvm.runtime.convert(1027)
+ A = te.placeholder((n,n), name='A')
+ C = te.placeholder((n,n), name='C')
+ I = te.placeholder((n,n), name='I')
+ k = te.reduce_axis((0, n))
+ B = te.compute((n,n), lambda i,j: te.sum(A[i,k]*C[k,j], axis=k, init=I[i,j]), name='B')
+
+ # schedule
+ s = te.create_schedule(B.op)
+ kf, ki = s[B].split(k, nparts=4)
+ BF = s.rfactor(B, kf, 1)
+ s[BF].parallel(BF.op.axis[0])
+ # one line to build the function.
+ def check_target(target="llvm"):
+ if not tvm.runtime.enabled(target):
+ return
+ ctx = tvm.cpu(0)
+ fapi = tvm.lower(s, args=[A, C, I, B])
+ print(fapi)
+ mmult = tvm.build(fapi, target=target, name="mmult")
+ # launch the kernel.
+ n = 1027
+ a = tvm.nd.array(np.random.uniform(size=(n,n)).astype(A.dtype), ctx)
+ c = tvm.nd.array(np.random.uniform(size=(n,n)).astype(C.dtype), ctx)
+ ii = tvm.nd.array(np.random.uniform(size=(n,n)).astype(B.dtype), ctx)
+ b = tvm.nd.array(np.zeros((n,n), dtype=B.dtype), ctx)
+ mmult(a, c, ii, b)
+ res = ii.asnumpy() + np.matmul(a.asnumpy(),c.asnumpy())
+ tvm.testing.assert_allclose(
+ b.asnumpy(), res, rtol=1e-4)
+
+ check_target()
+
def test_rfactor_factor_axis():
n = tvm.runtime.convert(1027)
A = te.placeholder((n,), name='A')
test_rfactor_argmax()
test_warp_reduction1()
test_warp_reduction2()
+ test_init()
+ test_init_imm()
+ test_rfactor_init()
# Test that SimplifyCombiner makes use of vranges
ck.analyzer.update(dummy, tvm.arith.ConstIntBound(-10, -4))
ck.verify(sum_or_prod(A[k], k), te.sum(A[k], k))
+ ck.verify(sum_or_prod(A[k], k, init=1), te.sum(A[k], k, init=1))
ck.analyzer.update(dummy, tvm.arith.ConstIntBound(5, 9), True)
ck.verify(sum_or_prod(A[k], k), prod(A[k], k))
+ ck.verify(sum_or_prod(A[k], k, init=1), prod(A[k], k, init=1))
ck.analyzer.update(dummy, tvm.arith.ConstIntBound(-10, 100), True)
ck.verify(sum_and_prod((A[k], A[10-k]), k)[0], te.sum(A[k], k))
ck.verify(sum_and_prod((A[k], A[10-k]), k)[1], prod(A[10-k], k))
ck.verify(te.sum(tvm.tir.Select(k + j < 12, k + j, 0), [k, j]),
te.sum(k + j, [k, j]))
ck.verify(te.sum(A[3], []), A[3])
+ ck.verify(te.sum(A[3], [], where=k > 12, init=1.0), tvm.tir.const(1.0, dtype='float32'))
# The rule below is not typical, removed for now
ck.verify(te.sum(te.div(k, 10), k), te.sum(tvm.tir.const(0, "int32"), k))
from tvm.testing import check_numerical_grads, assert_allclose
from tvm import topi
from tvm.topi.util import get_const_tuple
+import pytest
import numpy as np
Y = topi.nn.pool(X, [3, 3], [3, 3], [0, 0, 0, 0], 'max')
check_grad(Y, [X])
# Gradient check for a sum reduction that carries init=0.0.
# NOTE(review): marked xfail presumably because the autodiff passes reject such reductions
# ("Derivative of Reduction with initialization is not implemented") -- confirm.
+@pytest.mark.xfail
+def test_reduction_init():
+ np.random.seed(0)
+ shape = (10, 10)
+ k = te.reduce_axis((0, 10), name="k")
+ A0 = te.placeholder(shape, name='A0')
+
+ B = te.compute((10,), lambda i: te.sum(A0[i, k]*A0[k, i], axis=k, init=0.0), name='B')
+ check_grad(B, A0)
if __name__ == "__main__":
test_basic_operation()