Add `init` member to ReduceNode (#6138)
authorquic-sanirudh <63797228+quic-sanirudh@users.noreply.github.com>
Thu, 27 Aug 2020 02:11:24 +0000 (07:41 +0530)
committerGitHub <noreply@github.com>
Thu, 27 Aug 2020 02:11:24 +0000 (19:11 -0700)
- This patch adds a new member to ReduceNode called init which allows
  initialization with a custom ProducerLoad or a Float/Int immediate.
- This allows initialization of the output Tensor of a reduction with
  another Tensor instead of the `identity_element` defined in the
  CommReducer
- One example use case for this node is to initialize the Output of a
  convolution reduction with the Bias values thereby saving the
  Bias-add computation.

20 files changed:
include/tvm/tir/expr.h
include/tvm/tir/op.h
include/tvm/topi/reduction.h
python/tvm/tir/expr.py
python/tvm/tir/op.py
src/arith/canonical_simplify.cc
src/printer/tir_text_printer.cc
src/te/autodiff/ad_simplify.cc
src/te/autodiff/ad_util.cc
src/te/autodiff/jacobian.cc
src/te/operation/compute_op.cc
src/te/operation/cross_thread_reduction.cc
src/te/operation/tensorize.cc
src/te/schedule/schedule_dataflow_rewrite.cc
src/tir/ir/expr.cc
src/tir/ir/expr_functor.cc
src/tir/op/op.cc
tests/python/integration/test_reduce.py
tests/python/unittest/test_arith_canonical_simplify.py
tests/python/unittest/test_te_autodiff.py

index 100d163..9e6f440 100644 (file)
@@ -1026,6 +1026,8 @@ class ReduceNode : public PrimExprNode {
   CommReducer combiner;
   /*! \brief The source operand */
   Array<PrimExpr> source;
+  /*! \brief The init operand */
+  Array<PrimExpr> init;
   /*! \brief The reduction axis */
   Array<IterVar> axis;
   /*!
@@ -1040,6 +1042,7 @@ class ReduceNode : public PrimExprNode {
     v->Visit("dtype", &dtype);
     v->Visit("combiner", &combiner);
     v->Visit("source", &source);
+    v->Visit("init", &init);
     v->Visit("axis", &axis);
     v->Visit("condition", &condition);
     v->Visit("value_index", &value_index);
@@ -1049,7 +1052,8 @@ class ReduceNode : public PrimExprNode {
     // check axis first so IterVars can define the necessary variables.
     return equal(dtype, other->dtype) && equal(axis, other->axis) &&
            equal(combiner, other->combiner) && equal(source, other->source) &&
-           equal(condition, other->condition) && equal(value_index, other->value_index);
+           equal(init, other->init) && equal(condition, other->condition) &&
+           equal(value_index, other->value_index);
   }
 
   void SHashReduce(SHashReducer hash_reduce) const {
@@ -1057,6 +1061,7 @@ class ReduceNode : public PrimExprNode {
     hash_reduce(axis);
     hash_reduce(combiner);
     hash_reduce(source);
+    hash_reduce(init);
     hash_reduce(condition);
     hash_reduce(value_index);
   }
@@ -1072,7 +1077,7 @@ class ReduceNode : public PrimExprNode {
 class Reduce : public PrimExpr {
  public:
   TVM_DLL Reduce(CommReducer combiner, Array<PrimExpr> src, Array<IterVar> rdom, PrimExpr condition,
-                 int value_index);
+                 int value_index, Array<PrimExpr> init);
 
   TVM_DEFINE_OBJECT_REF_METHODS(Reduce, PrimExpr, ReduceNode);
 };
index 93a54b0..9e53e97 100644 (file)
@@ -464,48 +464,54 @@ TVM_DLL PrimExpr isinf(PrimExpr x);
  * \brief sum of of source expression over axis
  * \param source The source expression.
  * \param axis List of iteration variables that will be used for reduction.
+ * \param init The value with which to initialize the output.
  * \return The result.
  */
-TVM_DLL PrimExpr sum(PrimExpr source, Array<tir::IterVar> axis);
+TVM_DLL PrimExpr sum(PrimExpr source, Array<tir::IterVar> axis, Array<PrimExpr> init = {});
 
 /*!
  * \brief logical And of of source expression over axis
  * \param source The source expression.
  * \param axis List of iteration variables that will be used for reduction.
+ * \param init The value with which to initialize the output.
  */
-TVM_DLL PrimExpr all(PrimExpr source, Array<tir::IterVar> axis);
+TVM_DLL PrimExpr all(PrimExpr source, Array<tir::IterVar> axis, Array<PrimExpr> init = {});
 
 /*!
  * \brief logical Or of of source expression over axis
  * \param source The source expression.
  * \param axis List of iteration variables that will be used for reduction.
+ * \param init The value with which to initialize the output.
  * \return The result.
  */
-TVM_DLL PrimExpr any(PrimExpr source, Array<tir::IterVar> axis);
+TVM_DLL PrimExpr any(PrimExpr source, Array<tir::IterVar> axis, Array<PrimExpr> init = {});
 
 /*!
  * \brief max of of source expression over axis
  * \param source The source expression.
  * \param axis List of iteration variables that will be used for reduction.
+ * \param init The value with which to initialize the output.
  * \return The result.
  */
-TVM_DLL PrimExpr max(PrimExpr source, Array<tir::IterVar> axis);
+TVM_DLL PrimExpr max(PrimExpr source, Array<tir::IterVar> axis, Array<PrimExpr> init = {});
 
 /*!
  * \brief max of of source expression over axis
  * \param source The source expression.
  * \param axis List of iteration variables that will be used for reduction.
+ * \param init The value with which to initialize the output.
  * \return The result.
  */
-TVM_DLL PrimExpr min(PrimExpr source, Array<tir::IterVar> axis);
+TVM_DLL PrimExpr min(PrimExpr source, Array<tir::IterVar> axis, Array<PrimExpr> init = {});
 
 /*!
  * \brief product of of source expression over axis
  * \param source The source expression.
  * \param axis List of iteration variables that will be used for reduction.
+ * \param init The value with which to initialize the output.
  * \return The result.
  */
-TVM_DLL PrimExpr prod(PrimExpr source, Array<tir::IterVar> axis);
+TVM_DLL PrimExpr prod(PrimExpr source, Array<tir::IterVar> axis, Array<PrimExpr> init = {});
 
 /*!
  * \brief Calculate floor(x)
index 8a8a947..75c8265 100644 (file)
@@ -43,7 +43,8 @@ namespace topi {
 using namespace tvm::te;
 
 /*! \brief The operation to use for CommReduce */
-using FReduce = std::function<PrimExpr(PrimExpr source, const Array<IterVar>& axis)>;
+using FReduce =
+    std::function<PrimExpr(PrimExpr source, const Array<IterVar>& axis, Array<PrimExpr> init)>;
 
 /*! \brief The operation to use for CommReduceIdx */
 using FCommReduce = std::function<Array<PrimExpr>(Array<PrimExpr> exprs, const Array<IterVar>& axis,
@@ -158,7 +159,7 @@ inline Tensor DoCommReduce(const Tensor& data, FReduce func, const Array<PrimExp
       arg_counter++;
     }
 
-    return func(data(eval_range), r_axes);
+    return func(data(eval_range), r_axes, {});
   };
 
   return tvm::te::compute(target_shape, compute, data->op->name + "_red", kCommReduce);
@@ -284,23 +285,25 @@ inline FCommReduce MakeCommReducer(FCombine fcombine, FIdentity fidentity,
     auto combiner = tvm::tir::CommReducer(lhs, rhs, result, id_elem);
     Array<PrimExpr> outputs;
     for (size_t i = 0; i < exprs.size(); ++i) {
-      outputs.push_back(tvm::tir::Reduce(combiner, exprs, axis, cond, static_cast<int>(i)));
+      outputs.push_back(tvm::tir::Reduce(combiner, exprs, axis, cond, static_cast<int>(i), {}));
     }
     return outputs;
   };
 }
 
 /*! \brief Wrap tvm::min to ensure we get the correct overload */
-inline PrimExpr MinOp(PrimExpr source, Array<IterVar> axis) { return tvm::min(source, axis); }
+inline PrimExpr MinOp(PrimExpr source, Array<IterVar> axis, Array<PrimExpr> init = {}) {
+  return tvm::min(source, axis, init);
+}
 
 /*! \brief Wrap tvm::max to ensure we get the correct overload */
-inline PrimExpr MaxOp(PrimExpr source, Array<IterVar> axis) {
-  return tvm::max(source, axis);  // NOLINT(*)
+inline PrimExpr MaxOp(PrimExpr source, Array<IterVar> axis, Array<PrimExpr> init = {}) {
+  return tvm::max(source, axis, init);  // NOLINT(*)
 }
 
 /*! \brief Wrap tvm::prod to ensure we get the correct overload */
-inline PrimExpr ProdOp(PrimExpr source, Array<IterVar> axis) {
-  return tvm::prod(source, axis);  // NOLINT(*)
+inline PrimExpr ProdOp(PrimExpr source, Array<IterVar> axis, Array<PrimExpr> init = {}) {
+  return tvm::prod(source, axis, init);  // NOLINT(*)
 }
 
 /*!
index c8f151e..6f3d550 100644 (file)
@@ -433,11 +433,14 @@ class Reduce(PrimExprWithOp):
 
     value_index : int
         The value index.
+
+    init : list of Expr
+        The initial value for output. This can be an int, float or ProducerLoad
     """
-    def __init__(self, combiner, src, rdom, condition, value_index):
+    def __init__(self, combiner, src, rdom, condition, value_index, init=None):
         self.__init_handle_by_constructor__(
             _ffi_api.Reduce, combiner, src, rdom,
-            condition, value_index)
+            condition, value_index, init)
 
 
 @tvm._ffi.register_object
index b62d6a3..9592e6e 100644 (file)
@@ -1239,10 +1239,12 @@ def comm_reducer(fcombine, fidentity, name="reduce"):
             res = fcombine(res, args[i+1])
         return res
 
-    def _make_reduce(expr, axis, where=None):
+    def _make_reduce(expr, axis, where=None, init=None):
         code = fcombine.__code__
         assert fcombine.__code__.co_argcount == 2
         expr = convert(expr)
+        if init is not None:
+            init = convert(init)
         if isinstance(expr, Array):
             size = len(expr)
             larr = []
@@ -1255,6 +1257,16 @@ def comm_reducer(fcombine, fidentity, name="reduce"):
                 larr.append(Var(lname, dtype))
                 rname = code.co_varnames[1] + "_" + str(i)
                 rarr.append(Var(rname, dtype))
+            if init is not None:
+                init = convert(init)
+                assert isinstance(init, Array)
+                assert len(init) == size
+                for init_i in range(size):
+                    init_i = convert(init_i)
+                    assert isinstance(init_i,
+                                      (tvm.tir.ProducerLoad, tvm.tir.IntImm, tvm.tir.FloatImm))
+            else:
+                init = convert([])
             lhs = convert(larr)
             rhs = convert(rarr)
             result = fcombine(lhs, rhs)
@@ -1270,21 +1282,28 @@ def comm_reducer(fcombine, fidentity, name="reduce"):
             lhs = convert([lvar])
             rhs = convert([rvar])
             expr = convert([expr])
+            if init is not None:
+                assert isinstance(init, (tvm.tir.ProducerLoad, tvm.tir.IntImm, tvm.tir.FloatImm))
+                init = convert([init])
         result = convert(result)
         id_elem = convert(id_elem)
         combiner = CommReducer(lhs, rhs, result, id_elem)
         axis = convert(axis if isinstance(axis, (list, tuple)) else [axis])
         if where is None:
             where = convert(True)
-        outputs = tuple(tvm.tir.Reduce(combiner, expr, axis, where, i)
-                        for i in range(size))
+        if init is None:
+            outputs = tuple(tvm.tir.Reduce(combiner, expr, axis, where, i, convert([]))
+                            for i in range(size))
+        else:
+            outputs = tuple(tvm.tir.Reduce(combiner, expr, axis, where, i, init)
+                            for i in range(size))
         return outputs[0] if size == 1 else outputs
 
     # pylint: disable=keyword-arg-before-vararg
-    def reducer(expr, axis, where=None, *args):
+    def reducer(expr, axis, where=None, init=None, *args):
         if isinstance(axis, (tvm.tir.IterVar, list, tuple)):
             assert not args
-            return _make_reduce(expr, axis, where)
+            return _make_reduce(expr, axis, where, init)
         if where is None:
             assert not args
             return _reduce_directly(expr, axis)
index a8ef6a1..a88849b 100644 (file)
@@ -1013,7 +1013,8 @@ PrimExpr CanonicalSimplifier::Impl::SimplifyReduceCombiner(const ReduceNode* op)
   for (size_t i = 0; i < used.size(); ++i) {
     if (SideEffect(op->source[i]) > CallEffectKind::kReadState ||
         SideEffect(op->combiner->identity_element[i]) > CallEffectKind::kReadState ||
-        SideEffect(op->combiner->result[i]) > CallEffectKind::kReadState) {
+        SideEffect(op->combiner->result[i]) > CallEffectKind::kReadState ||
+        (!op->init.empty() && SideEffect(op->init[i]) > CallEffectKind::kReadState)) {
       mark_used(i);
     }
   }
@@ -1024,6 +1025,7 @@ PrimExpr CanonicalSimplifier::Impl::SimplifyReduceCombiner(const ReduceNode* op)
   Array<Var> new_lhs;
   Array<Var> new_rhs;
   Array<PrimExpr> new_source;
+  Array<PrimExpr> new_init;
 
   // new stuff is old stuff which is used
   for (size_t i = 0; i < used.size(); ++i) {
@@ -1034,6 +1036,7 @@ PrimExpr CanonicalSimplifier::Impl::SimplifyReduceCombiner(const ReduceNode* op)
       new_lhs.push_back(op->combiner->lhs[i]);
       new_rhs.push_back(op->combiner->rhs[i]);
       new_source.push_back(op->source[i]);
+      if (!op->init.empty()) new_init.push_back(op->init[i]);
     } else if (static_cast<int>(i) < op->value_index) {
       // value_index should also be adjusted
       new_value_index--;
@@ -1041,7 +1044,7 @@ PrimExpr CanonicalSimplifier::Impl::SimplifyReduceCombiner(const ReduceNode* op)
   }
 
   CommReducer new_combiner = CommReducer(new_lhs, new_rhs, new_result, new_identity);
-  return Reduce(new_combiner, new_source, op->axis, op->condition, new_value_index);
+  return Reduce(new_combiner, new_source, op->axis, op->condition, new_value_index, new_init);
 }
 
 PrimExpr CanonicalSimplifier::Impl::VisitExpr_(const ReduceNode* op) {
@@ -1051,6 +1054,11 @@ PrimExpr CanonicalSimplifier::Impl::VisitExpr_(const ReduceNode* op) {
   // already been simplified by const reduction axis removal
   if (op == nullptr) return ret;
   if (op->axis.empty()) {
+    if (!op->init.empty()) {
+      return this->VisitExpr(Select(op->condition,
+                                    (*op->combiner.get())(op->init, op->source)[op->value_index],
+                                    op->init[op->value_index]));
+    }
     // Note that here we assume that the identity element is indeed identity. Without this
     // assumption we would have to perform a single iteration of the loop, i.e. use
     // `(*op->combiner.get())(op->combineop->identity_element, op->source)[op->value_index]`
index 8e86fb5..132b12c 100644 (file)
@@ -373,7 +373,7 @@ Doc TIRTextPrinter::VisitExpr_(const ShuffleNode* op) {
 Doc TIRTextPrinter::VisitExpr_(const ReduceNode* op) {
   Doc doc;
   doc << "reduce(" << Print(op->combiner) << ", " << Print(op->source) << ", " << Print(op->axis)
-      << ", " << op->value_index << ")";
+      << ", " << op->value_index << ", " << Print(op->init) << ")";
   return doc;
 }
 
index 874e8ca..9ce9597 100644 (file)
@@ -806,9 +806,9 @@ PrimExpr SimplifyReductionDomain(const PrimExpr& expr, const Map<Var, Range>& ou
 
     // Perform simplification mainly to remove a possibly empty reduction.
     arith::Analyzer analyzer;
-    return analyzer.Simplify(
-        Reduce(red->combiner, new_source, new_axis, All(res->dst->relations), red->value_index),
-        kSimplifyRewriteCanonicalRewrite);
+    return analyzer.Simplify(Reduce(red->combiner, new_source, new_axis, All(res->dst->relations),
+                                    red->value_index, red->init),
+                             kSimplifyRewriteCanonicalRewrite);
   } else {
     return expr;
   }
@@ -938,6 +938,7 @@ class RemoveRedundantInequalitiesMutator : public ExprMutator {
 
   virtual PrimExpr VisitExpr_(const ReduceNode* op) {
     Array<PrimExpr> known_with_axes = known_;
+    CHECK(op->init.empty()) << "Derivative of Reduction with initialization is not implemented";
     for (const PrimExpr& axis_cond : IterVarsToInequalities(op->axis)) {
       known_with_axes.push_back(axis_cond);
     }
@@ -956,7 +957,7 @@ class RemoveRedundantInequalitiesMutator : public ExprMutator {
       new_source.push_back(new_mutator(src));
     }
 
-    return Reduce(op->combiner, new_source, op->axis, new_cond, op->value_index);
+    return Reduce(op->combiner, new_source, op->axis, new_cond, op->value_index, op->init);
   }
 
   virtual PrimExpr VisitExpr_(const EQNode* op) { return MutateAtomic_(GetRef<PrimExpr>(op)); }
@@ -1068,13 +1069,14 @@ class ReductionAsTensorAccessMutator : public ExprMutator {
     ReductionAsTensorAccessMutator new_mutator(Concat(IterVarsToVars(op->axis), outer_axis_),
                                                Merge(vranges_, IterVarsToMap(op->axis)), name_);
 
+    CHECK(op->init.empty()) << "Derivative of Reduction with initialization is not implemented";
     Array<PrimExpr> new_source;
     for (const PrimExpr& src : op->source) {
       new_source.push_back(new_mutator(src));
     }
 
     PrimExpr new_reduce =
-        Reduce(op->combiner, new_source, op->axis, op->condition, op->value_index);
+        Reduce(op->combiner, new_source, op->axis, op->condition, op->value_index, op->init);
 
     Array<Var> undefined_vars = UndefinedVars(new_reduce);
     std::unordered_set<const VarNode*> undefined_var_set;
@@ -1133,7 +1135,7 @@ PrimExpr LiftReductions(const PrimExpr& expr, const Array<Var>& outer_axis,
     }
     PrimExpr new_condition = ReductionAsTensorAccess(red->condition, new_outer_axis, new_vranges);
 
-    return Reduce(red->combiner, new_source, red->axis, new_condition, red->value_index);
+    return Reduce(red->combiner, new_source, red->axis, new_condition, red->value_index, red->init);
   } else {
     return ReductionAsTensorAccess(expr, outer_axis, vranges);
   }
@@ -1150,6 +1152,7 @@ PrimExpr RemoveJacobianAndLiftNonzeroCondImpl(const PrimExpr& expr_orig, const A
   PrimExpr expr = analyzer.Simplify(expr_orig, kSimplifyRewriteCanonicalRewrite);
 
   if (const ReduceNode* red = expr.as<ReduceNode>()) {
+    CHECK(red->init.empty()) << "Derivative of Reduction with initialization is not implemented";
     // TODO(sgrechanik-h): There are some other operations which behave like sum
     bool is_sum = IsSumCombiner(red->combiner, vranges);
     if (is_sum || CanFactorZeroFromCombiner(red->combiner, red->value_index, vranges)) {
@@ -1167,7 +1170,7 @@ PrimExpr RemoveJacobianAndLiftNonzeroCondImpl(const PrimExpr& expr_orig, const A
         source.Set(0, nz.value);
       }
 
-      new_red = Reduce(red->combiner, source, red->axis, cond, red->value_index);
+      new_red = Reduce(red->combiner, source, red->axis, cond, red->value_index, red->init);
       new_red = SimplifyReductionDomain(new_red, combined_vranges);
       // If the reduction disappears completely then transform the result as a non-reduction
       if (!new_red.as<ReduceNode>()) {
@@ -1193,8 +1196,8 @@ PrimExpr RemoveJacobianAndLiftNonzeroCondImpl(const PrimExpr& expr_orig, const A
         new_source.Set(red->value_index, Select(nz_cond, nz_source, make_zero(nz_source.dtype())));
       }
 
-      PrimExpr new_reduce =
-          Reduce(red->combiner, new_source, red->axis, new_reduce_cond, red->value_index);
+      PrimExpr new_reduce = Reduce(red->combiner, new_source, red->axis, new_reduce_cond,
+                                   red->value_index, red->init);
       new_reduce =
           TrySimplifyCompute(new_reduce, new_outer_cond, IterVarsToVars(axis), combined_vranges);
       result = Select(new_outer_cond, new_reduce, make_zero(new_reduce.dtype()));
index 995c8e0..024015a 100644 (file)
@@ -55,9 +55,13 @@ PrimExpr CloneReduction(const PrimExpr& expr) {
     for (const auto& src : red->source) {
       src_with_newaxis.push_back(tir::Substitute(src, vmap));
     }
+    Array<PrimExpr> init_with_newaxis;
+    for (const auto& init : red->init) {
+      init_with_newaxis.push_back(tir::Substitute(init, vmap));
+    }
 
     return Reduce(red->combiner, src_with_newaxis, new_axis, tir::Substitute(red->condition, vmap),
-                  red->value_index);
+                  red->value_index, init_with_newaxis);
   } else {
     return expr;
   }
@@ -82,7 +86,8 @@ Operation ComputeOpFromExprs(const Array<PrimExpr>& exprs, const Array<IterVar>&
   // If this is a reduction then we have to replicate it
   if (const ReduceNode* red = exprs[0].as<ReduceNode>()) {
     for (size_t i = 0; i < red->source.size(); ++i) {
-      PrimExpr ith_red = Reduce(red->combiner, red->source, red->axis, red->condition, i);
+      PrimExpr ith_red =
+          Reduce(red->combiner, red->source, red->axis, red->condition, i, red->init);
       new_exprs.push_back(ith_red);
     }
   } else {
index e769e54..3724af5 100644 (file)
@@ -181,6 +181,8 @@ class JacobianMutator : public ExprMutator {
     PrimExpr expr_with_new_axes = te::CloneReduction(GetRef<PrimExpr>(op));
     const ReduceNode* new_op = expr_with_new_axes.as<ReduceNode>();
 
+    CHECK(new_op->init.empty()) << "Derivative of Reduction with initialization is not implemented";
+
     // New lhs and rhs variables of the new combiner consist of
     // variables representing derivatives (which are later derived from new_op->source)
     // followed by the original variables.
@@ -245,8 +247,8 @@ class JacobianMutator : public ExprMutator {
     CommReducer new_combiner = CommReducer(new_lhs, new_rhs, new_result, new_identity);
     // Also simplify the resulting combiner
     // (mostly to get rid of unused components, e.g., the original expressions)
-    return analyzer_.Simplify(
-        Reduce(new_combiner, new_source, new_op->axis, new_op->condition, new_op->value_index));
+    return analyzer_.Simplify(Reduce(new_combiner, new_source, new_op->axis, new_op->condition,
+                                     new_op->value_index, new_op->init));
   }
 
   PrimExpr VisitExpr_(const CastNode* op) {
@@ -342,7 +344,8 @@ Tensor Jacobian(const Tensor& output, const Tensor& input) {
   if (const ReduceNode* red = new_body.as<ReduceNode>()) {
     value_index = red->value_index;
     for (size_t idx = 0; idx < red->source.size(); ++idx) {
-      new_bodies.push_back(Reduce(red->combiner, red->source, red->axis, red->condition, idx));
+      new_bodies.push_back(
+          Reduce(red->combiner, red->source, red->axis, red->condition, idx, red->init));
     }
   } else {
     new_bodies.push_back(new_body);
index 62369f9..c3b2a0b 100644 (file)
@@ -56,7 +56,8 @@ static void VerifyComputeOp(const ComputeOpNode* op);
 
 inline bool ReduceEqual(const tir::ReduceNode* a, const tir::ReduceNode* b) {
   return (a->combiner.same_as(b->combiner)) && (a->source.same_as(b->source)) &&
-         (a->axis.same_as(b->axis)) && (a->condition.same_as(b->condition));
+         (a->axis.same_as(b->axis)) && (a->condition.same_as(b->condition)) &&
+         ((a->init.empty() && b->init.empty()) || (a->init.same_as(b->init)));
 }
 
 int ComputeOpNode::num_outputs() const { return body.size(); }
@@ -307,6 +308,13 @@ void MakeReduction(const ComputeOpNode* op, const Array<Tensor>& tensors, Stmt*
   }
   Array<PrimExpr> init_value = combiner->identity_element;
   Array<PrimExpr> update_value = (*combiner)(lhs, reduce->source);
+
+  // If an init was passed to ReduceNode, use that for initialization
+  // instead of combiner->identity_element
+  Array<PrimExpr> reduce_init = reduce->init;
+  if (!reduce_init.empty()) {
+    init_value = reduce_init;
+  }
   for (size_t i = 0; i < size; ++i) {
     Tensor t = tensors[i];
     inits.emplace_back(ProducerStore(t, init_value[i], args));
index 427be32..6369ecb 100644 (file)
@@ -97,6 +97,7 @@ Stmt MakeCrossThreadReduction(const ComputeOpNode* self, const Stage& stage,
   for (size_t i = 0; i < size; ++i) {
     const ReduceNode* reduce = self->body[i].as<ReduceNode>();
     CHECK(reduce);
+    CHECK(reduce->init.empty()) << "Cannot perform cross_thread_reduction for reductions with init";
     reduces[i] = reduce;
   }
 
index 1d72345..ab96ae8 100644 (file)
@@ -192,7 +192,7 @@ class TensorIntrinMatcher final : public StmtExprMutator {
         axis.push_back(it->second);
       }
     }
-    return Reduce(op->combiner, op->source, axis, op->condition, op->value_index);
+    return Reduce(op->combiner, op->source, axis, op->condition, op->value_index, op->init);
   }
 
   void Init(const ComputeOpNode* self, const Stage& stage,
index 52c6757..78f0608 100644 (file)
@@ -76,7 +76,7 @@ class VarReplacer : public tir::StmtExprMutator {
       return new_e;
     } else {
       return tir::Reduce(new_combiner, new_reduce->source, new_reduce->axis, new_reduce->condition,
-                         new_reduce->value_index);
+                         new_reduce->value_index, new_reduce->init);
     }
   }
 
@@ -123,7 +123,8 @@ void ReplaceDataFlow(const Array<Stage>& stages, std::unordered_map<Tensor, Tens
 
 inline bool ReduceEqual(const tir::ReduceNode* a, const tir::ReduceNode* b) {
   return (a->combiner.same_as(b->combiner)) && (a->source.same_as(b->source)) &&
-         (a->axis.same_as(b->axis)) && (a->condition.same_as(b->condition));
+         (a->axis.same_as(b->axis)) && (a->condition.same_as(b->condition)) &&
+         ((a->init.empty() && b->init.empty()) || (a->init.same_as(b->init)));
 }
 
 Tensor Schedule::cache_read(const Tensor& tensor, const std::string& scope,
@@ -301,7 +302,7 @@ Array<Tensor> CacheWriteWithReLayout(Schedule sch, const Array<Tensor>& tensor_a
       if (first_reduce != nullptr) {
         CHECK(ReduceEqual(reduce_body, first_reduce));
         body = tir::Reduce(first_reduce->combiner, first_reduce->source, first_reduce->axis,
-                           first_reduce->condition, reduce_body->value_index);
+                           first_reduce->condition, reduce_body->value_index, reduce_body->init);
       } else {
         first_reduce = reduce_body;
       }
@@ -812,7 +813,7 @@ Array<Tensor> Schedule::rfactor(const Tensor& tensor, const IterVar& axis, int f
 
   std::vector<PrimExpr> body;
   for (size_t idx = 0; idx < reduce->source.size(); ++idx) {
-    body.emplace_back(Reduce(reduce->combiner, new_source, n->reduce_axis, new_pred, idx));
+    body.emplace_back(Reduce(reduce->combiner, new_source, n->reduce_axis, new_pred, idx, {}));
   }
   n->body = Array<PrimExpr>(body);
   // refresh relations, keep the un-touched relations.
@@ -865,6 +866,23 @@ Array<Tensor> Schedule::rfactor(const Tensor& tensor, const IterVar& axis, int f
           }
           indices.push_back(i[idx]);
         }
+        Array<PrimExpr> new_init = reduce->init;
+        if (!reduce->init.empty()) {
+          std::unordered_map<const VarNode*, PrimExpr> init_vsub;
+          for (const auto& init : reduce->init) {
+            if (init->IsInstance<ProducerLoadNode>()) {
+              CHECK_EQ(compute_op->axis.size(), idx_size)
+                  << "'init' should have the number of dimensions as output when using with "
+                     "rfactor";
+              for (int idx = 0; idx < idx_size; idx++) {
+                init_vsub[compute_op->axis[idx]->var.get()] = i[idx];
+              }
+            }
+          }
+          VarReplacer init_replacer(init_vsub);
+          new_init = tir::UpdateArray(
+              reduce->init, [&init_replacer](const PrimExpr& e) { return init_replacer(e); });
+        }
         if (factor_axis_pos == idx_size) {
           indices.push_back(repl_red_axis->var);
         }
@@ -876,7 +894,7 @@ Array<Tensor> Schedule::rfactor(const Tensor& tensor, const IterVar& axis, int f
         Array<IterVar> axis = {repl_red_axis};
         PrimExpr cond = const_true();
         for (int idx = 0; idx < size; ++idx) {
-          reductions.push_back(Reduce(reduce->combiner, factor_exprs, axis, cond, idx));
+          reductions.push_back(Reduce(reduce->combiner, factor_exprs, axis, cond, idx, new_init));
         }
         return reductions;
       },
index b4bb984..687dfd6 100644 (file)
@@ -857,7 +857,7 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
 
 // Reduce
 Reduce::Reduce(CommReducer combiner, Array<PrimExpr> source, Array<IterVar> axis,
-               PrimExpr condition, int value_index) {
+               PrimExpr condition, int value_index, Array<PrimExpr> init) {
   for (size_t i = 0; i < axis.size(); ++i) {
     CHECK_EQ(axis[i]->iter_type, kCommReduce) << "Can only take axis created by reduce_axis";
   }
@@ -869,9 +869,18 @@ Reduce::Reduce(CommReducer combiner, Array<PrimExpr> source, Array<IterVar> axis
   for (size_t i = 0; i < axis.size(); ++i) {
     CHECK(axis[i].defined());
   }
+  if (!init.empty()) {
+    CHECK_EQ(init.size(), source.size()) << "Number of inits should match number of exprs";
+    for (size_t i = 0; i < init.size(); i++) {
+      CHECK(init[i]->IsInstance<ProducerLoadNode>() || init[i]->IsInstance<IntImmNode>() ||
+            init[i]->IsInstance<FloatImmNode>())
+          << "init can only be a IntImm, FloatImm or ProducerLoad";
+    }
+  }
   n->dtype = source[value_index].dtype();
   n->combiner = std::move(combiner);
   n->source = std::move(source);
+  n->init = std::move(init);
   n->axis = std::move(axis);
   n->condition = condition;
   n->value_index = value_index;
@@ -880,8 +889,8 @@ Reduce::Reduce(CommReducer combiner, Array<PrimExpr> source, Array<IterVar> axis
 
 TVM_REGISTER_GLOBAL("tir.Reduce")
     .set_body_typed([](CommReducer combiner, Array<PrimExpr> source, Array<IterVar> axis,
-                       PrimExpr condition, int value_index) {
-      return Reduce(combiner, source, axis, condition, value_index);
+                       PrimExpr condition, int value_index, Array<PrimExpr> init) {
+      return Reduce(combiner, source, axis, condition, value_index, init);
     });
 
 TVM_REGISTER_NODE_TYPE(ReduceNode);
@@ -891,6 +900,7 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
       auto* op = static_cast<const ReduceNode*>(node.get());
       p->stream << "reduce(combiner=" << op->combiner;
       p->stream << ", source=" << op->source;
+      p->stream << ", init=" << op->init;
       p->stream << ", axis=" << op->axis;
       p->stream << ", where=" << op->condition;
       p->stream << ", value_index=" << op->value_index;
index 166f950..4c5ea5b 100644 (file)
@@ -90,6 +90,9 @@ void ExprVisitor::VisitExpr_(const ReduceNode* op) {
     this->VisitExpr(r->dom->extent);
   });
   VisitArray(op->source, [this](const PrimExpr& e) { this->VisitExpr(e); });
+  if (!op->init.empty()) {
+    VisitArray(op->init, [this](const PrimExpr& e) { this->VisitExpr(e); });
+  }
   this->VisitExpr(op->condition);
 }
 
@@ -225,13 +228,15 @@ PrimExpr ExprMutator::VisitExpr_(const ReduceNode* op) {
 
   auto fexpr = [this](const PrimExpr& e) { return this->VisitExpr(e); };
   Array<PrimExpr> source = MutateArray(op->source, fexpr);
+  Array<PrimExpr> init = MutateArray(op->init, fexpr);
 
   PrimExpr condition = this->VisitExpr(op->condition);
 
-  if (axis.same_as(op->axis) && source.same_as(op->source) && condition.same_as(op->condition)) {
+  if (axis.same_as(op->axis) && source.same_as(op->source) && condition.same_as(op->condition) &&
+      init.same_as(op->init)) {
     return GetRef<PrimExpr>(op);
   } else {
-    return Reduce(op->combiner, source, axis, condition, op->value_index);
+    return Reduce(op->combiner, source, axis, condition, op->value_index, init);
   }
 }
 
index 75a483c..6dc485f 100644 (file)
@@ -632,54 +632,54 @@ PrimExpr isinf(PrimExpr x) {
 // isfinite
 PrimExpr isfinite(PrimExpr x) { return !isinf(x) && !isnan(x); }
 
-PrimExpr sum(PrimExpr source, Array<IterVar> rdom) {
+PrimExpr sum(PrimExpr source, Array<IterVar> rdom, Array<PrimExpr> init) {
   Var x("x", source.dtype()), y("y", source.dtype());
   PrimExpr result = tir::Add(x, y);
   PrimExpr identity_element = make_zero(source.dtype());
   tir::CommReducer combiner = tir::CommReducer({x}, {y}, {result}, {identity_element});
-  return tir::Reduce(combiner, {source}, rdom, make_const(DataType::Bool(1), true), 0);
+  return tir::Reduce(combiner, {source}, rdom, make_const(DataType::Bool(1), true), 0, init);
 }
 
-PrimExpr all(PrimExpr source, Array<IterVar> rdom) {
+PrimExpr all(PrimExpr source, Array<IterVar> rdom, Array<PrimExpr> init) {
   CHECK(source.dtype().is_bool());
   Var x("x", source.dtype()), y("y", source.dtype());
   PrimExpr result = tir::And(x, y);
   PrimExpr identity_element = make_const(source.dtype(), true);
   tir::CommReducer combiner = tir::CommReducer({x}, {y}, {result}, {identity_element});
-  return tir::Reduce(combiner, {source}, rdom, make_const(DataType::Bool(1), true), 0);
+  return tir::Reduce(combiner, {source}, rdom, make_const(DataType::Bool(1), true), 0, init);
 }
 
-PrimExpr any(PrimExpr source, Array<IterVar> rdom) {
+PrimExpr any(PrimExpr source, Array<IterVar> rdom, Array<PrimExpr> init) {
   CHECK(source.dtype().is_bool());
   Var x("x", source.dtype()), y("y", source.dtype());
   PrimExpr result = tir::Or(x, y);
   PrimExpr identity_element = make_const(source.dtype(), false);
   tir::CommReducer combiner = tir::CommReducer({x}, {y}, {result}, {identity_element});
-  return tir::Reduce(combiner, {source}, rdom, make_const(DataType::Bool(1), true), 0);
+  return tir::Reduce(combiner, {source}, rdom, make_const(DataType::Bool(1), true), 0, init);
 }
 
-PrimExpr max(PrimExpr source, Array<IterVar> rdom) {
+PrimExpr max(PrimExpr source, Array<IterVar> rdom, Array<PrimExpr> init) {
   Var x("x", source.dtype()), y("y", source.dtype());
   PrimExpr result = tir::Max(x, y);
   PrimExpr identity_element = min_value(source.dtype());
   tir::CommReducer combiner = tir::CommReducer({x}, {y}, {result}, {identity_element});
-  return tir::Reduce(combiner, {source}, rdom, make_const(DataType::Bool(1), true), 0);
+  return tir::Reduce(combiner, {source}, rdom, make_const(DataType::Bool(1), true), 0, init);
 }
 
-PrimExpr min(PrimExpr source, Array<IterVar> rdom) {
+PrimExpr min(PrimExpr source, Array<IterVar> rdom, Array<PrimExpr> init) {
   Var x("x", source.dtype()), y("y", source.dtype());
   PrimExpr result = tir::Min(x, y);
   PrimExpr identity_element = max_value(source.dtype());
   tir::CommReducer combiner = tir::CommReducer({x}, {y}, {result}, {identity_element});
-  return tir::Reduce(combiner, {source}, rdom, make_const(DataType::Bool(1), true), 0);
+  return tir::Reduce(combiner, {source}, rdom, make_const(DataType::Bool(1), true), 0, init);
 }
 
-PrimExpr prod(PrimExpr source, Array<IterVar> rdom) {
+PrimExpr prod(PrimExpr source, Array<IterVar> rdom, Array<PrimExpr> init) {
   Var x("x", source.dtype()), y("y", source.dtype());
   PrimExpr result = tir::Mul(x, y);
   PrimExpr identity_element = make_const(source.dtype(), 1);
   tir::CommReducer combiner = tir::CommReducer({x}, {y}, {result}, {identity_element});
-  return tir::Reduce(combiner, {source}, rdom, make_const(DataType::Bool(1), true), 0);
+  return tir::Reduce(combiner, {source}, rdom, make_const(DataType::Bool(1), true), 0, init);
 }
 
 // fmod
index c5d9d08..67f6dcf 100644 (file)
@@ -70,6 +70,63 @@ def test_reduce_prims():
     test_prim(tvm.te.min, np.amin)
     test_prim(tvm.te.max, np.amax)
 
+def test_init_imm():
+    n = tvm.runtime.convert(1027)
+    A = te.placeholder((n,), name='A')
+    k = te.reduce_axis((0, n))
+    B = te.compute((1,), lambda i: te.sum(A[k], axis=k, init=10.0), name='B')
+    # schedule
+    s = te.create_schedule(B.op)
+    # one line to build the function.
+    def check_target(target="llvm"):
+        if not tvm.runtime.enabled(target):
+            return
+        ctx = tvm.cpu(0)
+        fapi = tvm.lower(s, args=[A, B])
+        fsum = tvm.build(fapi,
+                         target=target,
+                         name="mysum")
+        # launch the kernel.
+        n = 1027
+        a = tvm.nd.array(np.random.uniform(size=(n,)).astype(A.dtype), ctx)
+        b  = tvm.nd.array(np.zeros(1, dtype=B.dtype), ctx)
+        fsum(a, b)
+        res = 10.0 + np.sum(a.asnumpy(), axis=0)
+        tvm.testing.assert_allclose(
+            b.asnumpy(), res, rtol=1e-4)
+
+    check_target()
+
+def test_init():
+    n = tvm.runtime.convert(1027)
+    A = te.placeholder((n,n), name='A')
+    C = te.placeholder((n,n), name='C')
+    I = te.placeholder((n,n), name='I')
+    k = te.reduce_axis((0, n))
+    B = te.compute((n,n), lambda i,j: te.sum(A[i,k]*C[k,j], axis=k, init=I[i,j]), name='B')
+
+    # schedule
+    s = te.create_schedule(B.op)
+    # one line to build the function.
+    def check_target(target="llvm"):
+        if not tvm.runtime.enabled(target):
+            return
+        ctx = tvm.cpu(0)
+        fapi = tvm.lower(s, args=[A, C, I, B])
+        print(fapi)
+        mmult = tvm.build(fapi, target=target, name="mmult")
+        # launch the kernel.
+        n = 1027
+        a = tvm.nd.array(np.random.uniform(size=(n,n)).astype(A.dtype), ctx)
+        c = tvm.nd.array(np.random.uniform(size=(n,n)).astype(C.dtype), ctx)
+        ii = tvm.nd.array(np.random.uniform(size=(n,n)).astype(B.dtype), ctx)
+        b  = tvm.nd.array(np.zeros((n,n), dtype=B.dtype), ctx)
+        mmult(a, c, ii, b)
+        res = ii.asnumpy() + np.matmul(a.asnumpy(),c.asnumpy())
+        tvm.testing.assert_allclose(
+            b.asnumpy(), res, rtol=1e-4)
+
+    check_target()
 
 def test_rfactor():
     n = tvm.runtime.convert(1027)
@@ -101,6 +158,40 @@ def test_rfactor():
 
     check_target()
 
+def test_rfactor_init():
+    n = tvm.runtime.convert(1027)
+    A = te.placeholder((n,n), name='A')
+    C = te.placeholder((n,n), name='C')
+    I = te.placeholder((n,n), name='I')
+    k = te.reduce_axis((0, n))
+    B = te.compute((n,n), lambda i,j: te.sum(A[i,k]*C[k,j], axis=k, init=I[i,j]), name='B')
+
+    # schedule
+    s = te.create_schedule(B.op)
+    kf, ki = s[B].split(k, nparts=4)
+    BF = s.rfactor(B, kf, 1)
+    s[BF].parallel(BF.op.axis[0])
+    # one line to build the function.
+    def check_target(target="llvm"):
+        if not tvm.runtime.enabled(target):
+            return
+        ctx = tvm.cpu(0)
+        fapi = tvm.lower(s, args=[A, C, I, B])
+        print(fapi)
+        mmult = tvm.build(fapi, target=target, name="mmult")
+        # launch the kernel.
+        n = 1027
+        a = tvm.nd.array(np.random.uniform(size=(n,n)).astype(A.dtype), ctx)
+        c = tvm.nd.array(np.random.uniform(size=(n,n)).astype(C.dtype), ctx)
+        ii = tvm.nd.array(np.random.uniform(size=(n,n)).astype(B.dtype), ctx)
+        b  = tvm.nd.array(np.zeros((n,n), dtype=B.dtype), ctx)
+        mmult(a, c, ii, b)
+        res = ii.asnumpy() + np.matmul(a.asnumpy(),c.asnumpy())
+        tvm.testing.assert_allclose(
+            b.asnumpy(), res, rtol=1e-4)
+
+    check_target()
+
 def test_rfactor_factor_axis():
     n = tvm.runtime.convert(1027)
     A = te.placeholder((n,), name='A')
@@ -454,3 +545,6 @@ if __name__ == "__main__":
     test_rfactor_argmax()
     test_warp_reduction1()
     test_warp_reduction2()
+    test_init()
+    test_init_imm()
+    test_rfactor_init()
index e12f970..39d5d61 100644 (file)
@@ -181,8 +181,10 @@ def test_reduce_combiner_simplify():
     # Test that SimplifyCombiner makes use of vranges
     ck.analyzer.update(dummy, tvm.arith.ConstIntBound(-10, -4))
     ck.verify(sum_or_prod(A[k], k), te.sum(A[k], k))
+    ck.verify(sum_or_prod(A[k], k, init=1), te.sum(A[k], k, init=1))
     ck.analyzer.update(dummy, tvm.arith.ConstIntBound(5, 9), True)
     ck.verify(sum_or_prod(A[k], k), prod(A[k], k))
+    ck.verify(sum_or_prod(A[k], k, init=1), prod(A[k], k, init=1))
     ck.analyzer.update(dummy, tvm.arith.ConstIntBound(-10, 100), True)
     ck.verify(sum_and_prod((A[k], A[10-k]), k)[0], te.sum(A[k], k))
     ck.verify(sum_and_prod((A[k], A[10-k]), k)[1], prod(A[10-k], k))
@@ -219,6 +221,7 @@ def test_reduce_simplify():
     ck.verify(te.sum(tvm.tir.Select(k + j < 12, k + j, 0), [k, j]),
               te.sum(k + j, [k, j]))
     ck.verify(te.sum(A[3], []), A[3])
+    ck.verify(te.sum(A[3], [], where=k > 12, init=1.0), tvm.tir.const(1.0, dtype='float32'))
     # The rule below is not typical, removed for now
     ck.verify(te.sum(te.div(k, 10), k), te.sum(tvm.tir.const(0, "int32"), k))
 
index 25accde..5b4a309 100644 (file)
@@ -20,6 +20,7 @@ from tvm import te
 from tvm.testing import check_numerical_grads, assert_allclose
 from tvm import topi
 from tvm.topi.util import get_const_tuple
+import pytest
 
 import numpy as np
 
@@ -324,6 +325,15 @@ def test_stride_dilation():
     Y = topi.nn.pool(X, [3, 3], [3, 3], [0, 0, 0, 0], 'max')
     check_grad(Y, [X])
 
+@pytest.mark.xfail
+def test_reduction_init():
+    np.random.seed(0)
+    shape = (10, 10)
+    k = te.reduce_axis((0, 10), name="k")
+    A0 = te.placeholder(shape, name='A0')
+
+    B = te.compute((10,), lambda i: te.sum(A0[i, k]*A0[k, i], axis=k, init=0.0), name='B')
+    check_grad(B, A0)
 
 if __name__ == "__main__":
     test_basic_operation()