Replaces most uses of ComputeReduce with foldl.
* \brief sum of source expression over axis
* \param source The source expression.
* \param axis List of iteration variables that will be used for reduction.
+ * \return The result.
*/
TVM_DLL PrimExpr sum(PrimExpr source, Array<tir::IterVar> axis);
* \brief logical Or of source expression over axis
* \param source The source expression.
* \param axis List of iteration variables that will be used for reduction.
+ * \return The result.
*/
TVM_DLL PrimExpr any(PrimExpr source, Array<tir::IterVar> axis);
* \brief max of source expression over axis
* \param source The source expression.
* \param axis List of iteration variables that will be used for reduction.
+ * \return The result.
*/
TVM_DLL PrimExpr max(PrimExpr source, Array<tir::IterVar> axis);
* \brief min of source expression over axis
* \param source The source expression.
* \param axis List of iteration variables that will be used for reduction.
+ * \return The result.
*/
TVM_DLL PrimExpr min(PrimExpr source, Array<tir::IterVar> axis);
* \brief product of source expression over axis
* \param source The source expression.
* \param axis List of iteration variables that will be used for reduction.
+ * \return The result.
*/
TVM_DLL PrimExpr prod(PrimExpr source, Array<tir::IterVar> axis);
inline bool is_const(const PrimExpr& x);
/*!
+ * \brief Left fold.
+ * \param freduce The reduction function.
+ * \param init_value The initial value.
+ * \param values The values to be folded.
+ * \return The result.
+ * \tparam FReduce The type of the reduction.
+ */
+template <typename FReduce>
+inline PrimExpr foldl(FReduce freduce, PrimExpr init_value, const Array<PrimExpr>& values);
+
+/*!
* \brief Check whether x is a constant power of two
* If x is power of two, write the power to the shift.
*
}
return make_const(t, 0);
}
+
+template <typename FReduce>
+inline PrimExpr foldl(FReduce freduce, PrimExpr init_value, const Array<PrimExpr>& values) {
+ for (PrimExpr val : values) {  // visit the values in iteration order
+ init_value = freduce(init_value, val);  // accumulator is always passed as the left operand
+ }
+ return init_value;  // still the caller's init_value when values is empty
+}
+
} // namespace tir
// additional const expression overloading
+++ /dev/null
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-/*!
- * \file compute_expr.h
- * \brief Utility to invoke certan compute operations.
- */
-#ifndef TVM_ARITH_COMPUTE_EXPR_H_
-#define TVM_ARITH_COMPUTE_EXPR_H_
-
-#include <tvm/tir/expr.h>
-#include <tvm/tir/op.h>
-
-#include <algorithm>
-#include <limits>
-
-namespace tvm {
-namespace arith {
-
-/*!
- * \brief Compute the expression with the given binary op.
- * \param lhs The left operand
- * \param rhs The right operand
- * \tparam Op the computation operator
- * \return The result.
- */
-template <typename OP>
-inline PrimExpr Compute(PrimExpr lhs, PrimExpr rhs) {
- return OP::make(lhs, rhs);
-}
-
-/*!
- * \brief Compute an reduction with Op
- * \param values The input values.
- * \param empty_value The value when return if it is empty, can be Expr()
- * which will cause an error to be rasied.
- * \tparam Op The computation operator
- * \return The result.
- */
-template <typename Op>
-inline PrimExpr ComputeReduce(const Array<PrimExpr>& values, PrimExpr empty_value);
-
-template <>
-inline PrimExpr Compute<tir::AddNode>(PrimExpr a, PrimExpr b) {
- return a + b;
-}
-
-template <>
-inline PrimExpr Compute<tir::SubNode>(PrimExpr a, PrimExpr b) {
- return a - b;
-}
-
-template <>
-inline PrimExpr Compute<tir::MulNode>(PrimExpr a, PrimExpr b) {
- return a * b;
-}
-
-template <>
-inline PrimExpr Compute<tir::DivNode>(PrimExpr a, PrimExpr b) {
- return truncdiv(a, b);
-}
-
-template <>
-inline PrimExpr Compute<tir::ModNode>(PrimExpr a, PrimExpr b) {
- return truncmod(a, b);
-}
-
-template <>
-inline PrimExpr Compute<tir::MaxNode>(PrimExpr a, PrimExpr b) {
- return max(a, b);
-}
-
-template <>
-inline PrimExpr Compute<tir::MinNode>(PrimExpr a, PrimExpr b) {
- return min(a, b);
-}
-
-template <typename Op>
-inline PrimExpr ComputeReduce(const Array<PrimExpr>& values, PrimExpr empty_value) {
- if (values.size() == 0U) {
- CHECK(empty_value.defined());
- return empty_value;
- }
- PrimExpr res = values[0];
- for (size_t i = 1; i < values.size(); ++i) {
- res = Compute<Op>(res, values[i]);
- }
- return res;
-}
-
-} // namespace arith
-} // namespace tvm
-#endif // TVM_ARITH_COMPUTE_EXPR_H_
#include <vector>
-#include "../../../arith/compute_expr.h"
#include "../../transforms/infer_layout_util.h"
#include "../../transforms/pattern_util.h"
#include "../op_common.h"
#include <utility>
#include <vector>
-#include "../../arith/compute_expr.h"
#include "../../runtime/thread_storage_scope.h"
#include "../../tir/transforms/ir_util.h"
#include "llvm_common.h"
#include <cctype>
#include <iomanip>
-#include "../../arith/compute_expr.h"
#include "../../arith/pattern_match.h"
namespace tvm {
#include <tvm/runtime/container.h>
#include <tvm/tir/expr.h>
+#include <tvm/tir/op.h>
#include <string>
-#include "../../arith/compute_expr.h"
-
namespace tvm {
namespace codegen {
#include <unordered_set>
#include <utility>
-#include "../../arith/compute_expr.h"
#include "../../arith/interval_set.h"
#include "../schedule/message_passing.h"
#include "op_util.h"
}
}
- return IfThenElseNode::make(arith::ComputeReduce<tir::OrNode>(conds, const_true(1)), update,
- body);
+ auto cond = foldl([](PrimExpr a, PrimExpr b) { return a || b; }, const_false(1), conds);
+ return IfThenElseNode::make(cond, update, body);
}
} // namespace te
#include <string>
-#include "../../arith/compute_expr.h"
#include "../../runtime/thread_storage_scope.h"
#include "../schedule/message_passing.h"
#include <unordered_set>
-#include "../../arith/compute_expr.h"
#include "./compute_op.h"
#include "./op_util.h"
#include <tvm/arith/analyzer.h>
#include <tvm/tir/expr.h>
-#include "../../arith/compute_expr.h"
-
namespace tvm {
namespace te {
*/
#include <tvm/te/operation.h>
#include <tvm/te/schedule.h>
+#include <tvm/tir/op.h>
#include <tvm/tir/stmt_functor.h>
#include <unordered_set>
-#include "../../arith/compute_expr.h"
#include "../../tir/transforms/ir_util.h"
#include "message_passing.h"
#include "operation_inline.h"
using tir::SelectNode;
if (predicates.size() == 0) return body;
const ReduceNode* reduce = body.as<ReduceNode>();
+ auto fand = [](PrimExpr a, PrimExpr b) { return a && b; };
+
if (reduce) {
auto n = make_object<ReduceNode>(*reduce);
- n->condition = n->condition && arith::ComputeReduce<tir::AndNode>(predicates, PrimExpr());
+ n->condition = foldl(fand, n->condition, predicates);
return PrimExpr(n);
}
- return SelectNode::make(arith::ComputeReduce<tir::AndNode>(predicates, PrimExpr()), body,
- make_zero(body.dtype()));
+ return SelectNode::make(foldl(fand, const_true(1), predicates), body, make_zero(body.dtype()));
}
// Replace data flow appears in all stages given the tensor change.
const ReduceNode* reduce = compute_op->body[idx].as<ReduceNode>();
CHECK(reduce) << "Can only rfactor non-inline reductions";
predicates.push_back(reduce->condition);
- PrimExpr predicate = likely(arith::ComputeReduce<tir::AndNode>(predicates, PrimExpr()));
+ auto fand = [](PrimExpr a, PrimExpr b) { return a && b; };
+
+ PrimExpr predicate = likely(foldl(fand, const_true(1), predicates));
std::unordered_map<const VarNode*, PrimExpr> vsub;
#include <unordered_map>
-#include "../../arith/compute_expr.h"
#include "../../runtime/thread_storage_scope.h"
namespace tvm {
#include <tvm/tir/analysis.h>
#include <tvm/tir/buffer.h>
#include <tvm/tir/expr.h>
+#include <tvm/tir/op.h>
#include <iterator>
#include <stack>
-#include "../../arith/compute_expr.h"
-
namespace tvm {
namespace tir {
int highest_dim = 0;
extent = self->strides[highest_dim] * self->shape[highest_dim] - offset;
} else {
- extent = arith::ComputeReduce<tir::MulNode>(self->shape, PrimExpr()) - offset;
+ auto fmul = [](PrimExpr a, PrimExpr b) { return a * b; };
+ extent = foldl(fmul, make_const(DataType::Int(32), 1), self->shape) - offset;
}
PrimExpr elem_offset = self->elem_offset + offset;
if (content_lanes > 1) {
#include <tvm/runtime/device_api.h>
#include <tvm/tir/expr.h>
+#include <tvm/tir/op.h>
-#include "../../arith/compute_expr.h"
#include "ir_util.h"
namespace tvm {
<< " expected to be compact array";
if (conds.size() != 0) {
auto stride_msg = tvm::tir::StringImmNode::make(stride_err_msg.str());
- Stmt check = AssertStmtNode::make(arith::ComputeReduce<tir::AndNode>(conds, PrimExpr()),
- stride_msg, EvaluateNode::make(0));
+ auto fand = [](PrimExpr a, PrimExpr b) { return a && b; };
+ Stmt check = AssertStmtNode::make(foldl(fand, const_true(1), conds), stride_msg,
+ EvaluateNode::make(0));
check = IfThenElseNode::make(NotNode::make(is_null), check, Stmt());
asserts_.emplace_back(SeqStmt({check, EvaluateNode::make(0)}));
}
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>
-#include "../../arith/compute_expr.h"
#include "ir_util.h"
namespace tvm {
Stmt VisitStmt_(const AllocateNode* op) final {
auto it = dbuffer_info_.find(op->buffer_var.get());
if (it != dbuffer_info_.end()) {
+ auto fmul = [](PrimExpr a, PrimExpr b) { return a * b; };
it->second.stride =
- arith::ComputeReduce<MulNode>(op->extents, PrimExpr()) * op->dtype.lanes();
+ foldl(fmul, make_const(DataType::Int(32), 1), op->extents) * op->dtype.lanes();
Stmt stmt = StmtExprMutator::VisitStmt_(op);
op = stmt.as<AllocateNode>();
Array<PrimExpr> new_extents{make_const(op->extents[0].dtype(), 2)};
#include <unordered_set>
-#include "../../arith/compute_expr.h"
#include "ir_util.h"
namespace tvm {
// always rewrite if not allow sharing.
if (touched_var_.count(op->buffer_var.get()) || !allow_share_) {
// place v on highest dimension.
- PrimExpr stride = arith::ComputeReduce<MulNode>(op->extents, PrimExpr()) * op->dtype.lanes();
+ auto fmul = [](PrimExpr a, PrimExpr b) { return a * b; };
+ PrimExpr stride =
+ foldl(fmul, make_const(DataType::Int(32), 1), op->extents) * op->dtype.lanes();
Array<PrimExpr> other;
other.push_back(make_const(op->extents[0].dtype(), num_threads_));
for (PrimExpr e : extents) {
#include <unordered_set>
-#include "../../arith/compute_expr.h"
#include "../../runtime/thread_storage_scope.h"
#include "ir_util.h"
#include <tvm/target/target.h>
#include <tvm/tir/analysis.h>
#include <tvm/tir/expr.h>
+#include <tvm/tir/op.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>
#include <unordered_set>
-#include "../../arith/compute_expr.h"
#include "../../arith/pattern_match.h"
#include "../../runtime/thread_storage_scope.h"
#include "storage_access.h"
#include <tvm/target/target_info.h>
+#include <tvm/tir/op.h>
#include <string>
#include <utility>
-#include "../../arith/compute_expr.h"
#include "ir_util.h"
namespace tvm {
#include <unordered_map>
-#include "../../arith/compute_expr.h"
#include "../../arith/ir_visitor_with_analyzer.h"
#include "../../runtime/thread_storage_scope.h"
#include "arg_binder.h"
#include <unordered_map>
#include <unordered_set>
-#include "../../arith/compute_expr.h"
#include "../../runtime/thread_storage_scope.h"
#include "ir_util.h"
alloc_type = op->dtype;
}
}
+
+ auto fmul = [](PrimExpr a, PrimExpr b) { return a * b; };
+
if (e->allocs.size() == 1) {
// simply use the original allocation.
- PrimExpr sz = arith::ComputeReduce<MulNode>(e->allocs[0]->extents,
- make_const(DataType::Int(32), 1));
+ PrimExpr sz = foldl(fmul, make_const(DataType::Int(32), 1), e->allocs[0]->extents);
e->new_alloc = AllocateNode::make(e->alloc_var, alloc_type, {sz}, e->allocs[0]->condition,
EvaluateNode::make(0));
if (e->scope.tag.length() != 0) {
// Build a merged allocation
PrimExpr combo_size;
for (const AllocateNode* op : e->allocs) {
- PrimExpr sz =
- arith::ComputeReduce<MulNode>(op->extents, make_const(DataType::Int(32), 1));
+ PrimExpr sz = foldl(fmul, make_const(DataType::Int(32), 1), op->extents);
auto nbits = op->dtype.bits() * op->dtype.lanes();
if (const auto* imm = sz.as<IntImmNode>()) {
if (imm->value > std::numeric_limits<int>::max() / nbits) {
#include <unordered_set>
#include <vector>
-#include "../../arith/compute_expr.h"
#include "ir_util.h"
namespace tvm {
#include <tvm/arith/analyzer.h>
#include <tvm/runtime/registry.h>
#include <tvm/tir/expr.h>
+#include <tvm/tir/op.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>
#include <unordered_set>
#include <vector>
-#include "../../arith/compute_expr.h"
-
namespace tvm {
namespace tir {
}
}
- PrimExpr VisitExpr_(const AddNode* op) final { return AddSubVec(op); }
- PrimExpr VisitExpr_(const SubNode* op) final { return AddSubVec(op); }
+ PrimExpr VisitExpr_(const AddNode* op) final {  // Add: delegate to AddSubVec
+ return AddSubVec(op, [](PrimExpr a, PrimExpr b) { return a + b; });  // lambda supplies the operator
+ }
+
+ PrimExpr VisitExpr_(const SubNode* op) final {  // Sub: delegate to AddSubVec
+ return AddSubVec(op, [](PrimExpr a, PrimExpr b) { return a - b; });  // lambda supplies the operator
+ }
+
PrimExpr VisitExpr_(const MulNode* op) final {
PrimExpr a = this->VisitExpr(op->a);
PrimExpr b = this->VisitExpr(op->b);
return T::make(BroadcastTo(a, lanes), BroadcastTo(b, lanes));
}
}
- template <typename T>
- PrimExpr AddSubVec(const T* op) {
+ template <typename T, typename FCompute>
+ PrimExpr AddSubVec(const T* op, FCompute fcompute) {
PrimExpr a = this->VisitExpr(op->a);
PrimExpr b = this->VisitExpr(op->b);
if (a.same_as(op->a) && b.same_as(op->b)) {
const RampNode* b_ramp = b.as<RampNode>();
const RampNode* a_ramp = a.as<RampNode>();
if (a.dtype().lanes() == 1 && b_ramp) {
- return RampNode::make(
- arith::Compute<T>(a, b_ramp->base),
- arith::Compute<T>(make_zero(b_ramp->stride.dtype()), b_ramp->stride), b_ramp->lanes);
+ return RampNode::make(fcompute(a, b_ramp->base),
+ fcompute(make_zero(b_ramp->stride.dtype()), b_ramp->stride),
+ b_ramp->lanes);
}
if (b.dtype().lanes() == 1 && a_ramp) {
- return RampNode::make(arith::Compute<T>(a_ramp->base, b), a_ramp->stride, a_ramp->lanes);
+ return RampNode::make(fcompute(a_ramp->base, b), a_ramp->stride, a_ramp->lanes);
}
}
return T::make(BroadcastTo(a, lanes), BroadcastTo(b, lanes));