std::function<void()> exit_;
};
-/*!
- * \brief Analyzer that contains bunch of sub-analyzers.
- *
- * Each sub-analyzer can make use of another sub-analyzer
- * by weak reference of this.
- *
- * NOTE for sub-analyzer developers:
- * If the analyzer uses memoization, we need to clear the internal
- * cache when information about a Var has been overrideen.
- */
-class Analyzer {
- public:
- /*! \brief sub-analyzer: const integer bound */
- ConstIntBoundAnalyzer const_int_bound;
- /*! \brief sub-analyzer: modular set */
- ModularSetAnalyzer modular_set;
- /*! \brief sub-analyzer rewrite simplify */
- RewriteSimplifier rewrite_simplify;
- /*! \brief sub-analyzer canonical simplify */
- CanonicalSimplifier canonical_simplify;
- /*! \brief constructor */
- Analyzer();
- /*!
- * \brief Notify all the sub-analyzers that var
- * is created and binded to expr.
- *
- * Each var can only be binded once.
- *
- * \param var The variable.
- * \param expr The expression we bind to.
- */
- void Bind(const VarExpr& var, const Expr& expr);
- /*!
- * \brief Notify all the sub-analyzers that var
- * is created and binded to a range.
- *
- * Each var can only be binded once.
- *
- * \param var The variable.
- * \param range The range we bind to.
- */
- void Bind(const VarExpr& var, const Range& range);
- /*!
- * \brief Whether can we proof expr >= val.
-
- * Non-negative proof is very useful in integer analysis
- * to lower divisions and mods given difference in trunc and ceil mode.
- *
- * \param expr The expression.
- * \param lower_bound The lower bound.
- * \return Whether we can proof it.
- *
- * \note Analyzer will call into sub-analyzers to get the result.
- */
- bool CanProveGreaterEqual(const Expr& expr, int64_t lower_bound);
-};
-
//-----------------------------------------------
-// Integer set abstraction API.
+// Integer set data structure.
//
// This is a API build on top of the base
// integer analysis API to provide set analysis.
//------------------------------------------------
/*!
- * \brief Sign of an expression or set.
+ * \brief Sign type of an integer expression.
*/
enum SignType {
kPositive,
kUnknown
};
-// internal node container of int set.
-struct IntSetNode;
+/*!
+ * \brief Base class of all IntSet containers.
+ */
+struct IntSetNode : public Node {
+ static constexpr const char* _type_key = "IntSet";
+ TVM_DECLARE_BASE_NODE_INFO(IntSetNode, Node);
+};
/*!
* \brief Integer set class, represent a set of integers in one dimension.
* \return The covering range.
*/
Range cover_range(Range max_range) const;
- /*!
- * \brief find an interval that covers the set.
- * \return The covering interval set.
- */
- IntSet cover_interval() const;
/*! \return Lower bound of the set */
Expr min() const;
/*! \return upper bound of the set */
};
/*!
- * \brief Base class of all IntSet containers.
+ * \brief Integer set analyzer.
*/
-struct IntSetNode : public Node {
- static constexpr const char* _type_key = "IntSet";
- TVM_DECLARE_BASE_NODE_INFO(IntSetNode, Node);
+class IntSetAnalyzer {
+ public:
+ /*!
+ * \brief Find a symbolic integer set that contains all possible values of
+ * expr given the domain of each variables.
+ *
+ * \param expr The expression of interest.
+ * \param dom_map The domain map to indicate which variable to relax.
+ * \return the result of the analysis.
+ */
+ IntSet operator()(const Expr& expr, const Map<Var, IntSet>& dom_map);
+
+ private:
+ friend class Analyzer;
+ explicit IntSetAnalyzer(Analyzer* parent);
+ ~IntSetAnalyzer();
+ class Impl;
+ /*! \brief Internal impl */
+ Impl* impl_;
};
/*!
- * \brief Detect if e can be rewritten as e = sum_{i=0}^{n-1} var[i] * coeff[i] + coeff[n]
- * Where coeff[i] and base are invariant of var[j] for all i and j.
+ * \brief Analyzer that contains bunch of sub-analyzers.
*
- * \param e The expression to be detected.
- * \param vars List of variables to be used in detection.
- * \return [coeff[i]] if it is possible, empty array if it is not.
- */
-Array<Expr> DetectLinearEquation(const Expr& e, const Array<Var>& vars);
-
-/*!
- * \brief Detect if expression corresponds to clip bound of the vars
+ * Each sub-analyzer can make use of another sub-analyzer
+ * by weak reference of this.
*
- * \param e The expression to be detected.
- * \param vars List of variables to be used in detection.
- * \return concat([min_value[i], max_value[i]]), None is returned if there is no min or max value
- * return empty if the e does not match the pattern.
+ * NOTE for sub-analyzer developers:
+ * If the analyzer uses memoization, we need to clear the internal
+ * cache when information about a Var has been overridden.
*/
-Array<Expr> DetectClipBound(const Expr& e, const Array<Var>& vars);
+class Analyzer {
+ public:
+ /*! \brief sub-analyzer: const integer bound */
+ ConstIntBoundAnalyzer const_int_bound;
+ /*! \brief sub-analyzer: modular set */
+ ModularSetAnalyzer modular_set;
+ /*! \brief sub-analyzer rewrite simplify */
+ RewriteSimplifier rewrite_simplify;
+ /*! \brief sub-analyzer canonical simplify */
+ CanonicalSimplifier canonical_simplify;
+ /*! \brief sub-analyzer: int set */
+ IntSetAnalyzer int_set;
+ /*! \brief constructor */
+ Analyzer();
+ /*!
+ * \brief Notify all the sub-analyzers that var
+ * is created and binded to expr.
+ *
+ * Each var can only be binded once.
+ *
+ * \param var The variable.
+ * \param expr The expression we bind to.
+ */
+ void Bind(const VarExpr& var, const Expr& expr);
+ /*!
+ * \brief Notify all the sub-analyzers that var
+ * is created and binded to a range.
+ *
+ * Each var can only be binded once.
+ *
+ * \param var The variable.
+ * \param range The range we bind to.
+ */
+ void Bind(const VarExpr& var, const Range& range);
+ /*!
+ * \brief Whether can we prove expr >= val.
+
+ * Non-negative proof is very useful in integer analysis
+ * to lower divisions and mods given difference in trunc and ceil mode.
+ *
+ * \param expr The expression.
+ * \param lower_bound The lower bound.
+ * \return Whether we can prove it.
+ *
+ * \note Analyzer will call into sub-analyzers to get the result.
+ */
+ bool CanProveGreaterEqual(const Expr& expr, int64_t lower_bound);
+};
+//-----------------------------------------------
+// Integer set legacy API.
+//------------------------------------------------
/*!
* \brief Find an symbolic integer set that contains all possible values of
* e given the domain of each iteration variables.
*/
Domain DomainTouched(Stmt body, const Tensor &tensor, bool consider_calls, bool consider_provides);
+// Expression pattern detector.
+/*!
+ * \brief Detect if e can be rewritten as e = sum_{i=0}^{n-1} var[i] * coeff[i] + coeff[n]
+ * Where coeff[i] and base are invariant of var[j] for all i and j.
+ *
+ * \param e The expression to be detected.
+ * \param vars List of variables to be used in detection.
+ * \return [coeff[i]] if it is possible, empty array if it is not.
+ */
+Array<Expr> DetectLinearEquation(const Expr& e,
+ const Array<Var>& vars);
+
+/*!
+ * \brief Detect if expression corresponds to clip bound of the vars
+ *
+ * \param e The expression to be detected.
+ * \param vars List of variables to be used in detection.
+ * \return concat([min_value[i], max_value[i]]), None is returned if there is no min or max value
+ * return empty if the e does not match the pattern.
+ */
+Array<Expr> DetectClipBound(const Expr& e,
+ const Array<Var>& vars);
+
// implementation
inline const IntSetNode* IntSet::operator->() const {
return static_cast<const IntSetNode*>(node_.get());
return _api_internal._IntSetIsEverything(self)
-@register_node
+@register_node("arith.IntervalSet")
class IntervalSet(IntSet):
- """Represent set of continuous interval"""
- def min(self):
- """get the minimum value"""
- return _api_internal._IntervalSetGetMin(self)
-
- def max(self):
- """get the maximum value"""
- return _api_internal._IntervalSetGetMax(self)
+ """Represent set of continuous interval [min_value, max_value]
+ Parameters
+ ----------
+ min_value : Expr
+ The minimum value in the interval.
-@register_node
-class StrideSet(IntSet):
- """Represent set of strided integers"""
+ max_value : Expr
+ The maximum value in the interval.
+ """
+ def __init__(self, min_value, max_value):
+ self.__init_handle_by_constructor__(
+ _make_IntervalSet, min_value, max_value)
@register_node("arith.ModularSet")
self._modular_set = _mod("modular_set")
self._rewrite_simplify = _mod("rewrite_simplify")
self._canonical_simplify = _mod("canonical_simplify")
+ self._int_set = _mod("int_set")
self._enter_constraint_context = _mod("enter_constraint_context")
def const_int_bound(self, expr):
"""
return self._canonical_simplify(expr)
+ def int_set(self, expr, dom_map):
+ """Compute a symbolic IntSet that covers expr for all values in dom_map.
+
+ Parameters
+ ----------
+ expr : tvm.Expr
+ The expression.
+
+ dom_map : Dict[Var, tvm.arith.IntSet]
+ The domain for variables to be relaxed.
+
+ Returns
+ -------
+ result : IntSet
+ The result.
+ """
+ return self._int_set(expr, dom_map)
+
def bind(self, var, expr):
"""Bind a variable to the expression.
TVM_REGISTER_API("arith.intset_interval")
.set_body_typed(IntSet::interval);
+
TVM_REGISTER_API("arith.DetectLinearEquation")
.set_body_typed(DetectLinearEquation);
return PackedFunc([self](TVMArgs args, TVMRetValue *ret) {
*ret = self->canonical_simplify(args[0]);
});
+ } else if (name == "int_set") {
+ return PackedFunc([self](TVMArgs args, TVMRetValue *ret) {
+ *ret = self->int_set(args[0], args[1]);
+ });
} else if (name == "bind") {
return PackedFunc([self](TVMArgs args, TVMRetValue *ret) {
auto& sptr = args[1].node_sptr();
: const_int_bound(this),
modular_set(this),
rewrite_simplify(this),
- canonical_simplify(this) {
+ canonical_simplify(this),
+ int_set(this) {
}
void Analyzer::Bind(const VarExpr& v, const Expr& expr) {
bool Analyzer::CanProveGreaterEqual(const Expr& expr, int64_t lower_bound) {
if (const auto* ptr = expr.as<ir::IntImm>()) {
- return ptr->value > lower_bound;
+ return ptr->value >= lower_bound;
}
auto bd = this->const_int_bound(this->rewrite_simplify(expr));
if (bd->min_value >= lower_bound) return true;
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
- *
+ *
* http://www.apache.org/licenses/LICENSE-2.0
- *
+ *
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
#include <unordered_set>
#include <unordered_map>
+#include "int_set.h"
namespace tvm {
namespace arith {
using namespace ir;
-using HalideIR::Internal::Interval;
// a visitor to find the path to the target variable
// from a expression.
BoundDeducer d(v, e, hint_map, relax_map);
d.Deduce();
if (!d.success) return IntSet::nothing();
- Expr min = Interval::neg_inf, max = Interval::pos_inf;
+ Expr min = neg_inf(), max = pos_inf();
if (d.is_greater) {
min = d.result;
} else {
*/
/*!
- * Copyright (c) 2019 by Contributors
* \file canonical_simplify.cc
* \brief Canonical form based simplification.
*/
if (TryCompare(temp, cval) == kLT) {
return temp;
} else {
- return SplitModConst(ToSplitExpr(temp), cval);
+ // contonue to use logic below.
+ a = extra;
+ psum = a.as<SumExprNode>();
+ CHECK(psum != nullptr);
}
}
}
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
- *
+ *
* http://www.apache.org/licenses/LICENSE-2.0
- *
+ *
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
#define TVM_ARITHMETIC_COMPUTE_EXPR_H_
#include <tvm/ir.h>
-#include <arithmetic/Interval.h>
#include <limits>
+#include <algorithm>
namespace tvm {
namespace arith {
template<>
inline Expr ComputeExpr<ir::Max>(Expr a, Expr b) {
- return HalideIR::Internal::Interval::make_max(a, b);
+ return max(a, b);
}
template<>
inline Expr ComputeExpr<ir::Min>(Expr a, Expr b) {
- return HalideIR::Internal::Interval::make_min(a, b);
+ return min(a, b);
}
template<typename Op>
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
- *
+ *
* http://www.apache.org/licenses/LICENSE-2.0
- *
+ *
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
if (pa && pb) return IntImm::make(rtype, std::min(pa->value, pb->value));
if (fa && fb) return FloatImm::make(rtype, std::min(fa->value, fb->value));
});
+ if (a.same_as(b)) return a;
return Expr();
}
if (pa && pb) return IntImm::make(rtype, std::max(pa->value, pb->value));
if (fa && fb) return FloatImm::make(rtype, std::max(fa->value, fb->value));
});
+ if (a.same_as(b)) return a;
return Expr();
}
return Expr();
}
+/*! \brief Helper namespace for symbolic value limits */
+struct SymbolicLimits {
+ /*! \brief positive infinity */
+ static Expr pos_inf_;
+ /*! \brief negative infinity */
+ static Expr neg_inf_;
+};
+
+/*!
+ * \brief Opaque expression representing positive infinity.
+ *
+ * It can can only be used as parameter of by min/max
+ * for integer analysis and cannot be used in normal expressions.
+ *
+ * \return positive infinity.
+ */
+inline Expr pos_inf() {
+ return SymbolicLimits::pos_inf_;
+}
+
+/*!
+ * \brief Check if value is positive infinity.
+ * \param value The value to be checked.
+ *
+ * \return The check result.
+ */
+inline bool is_pos_inf(const Expr& value) {
+ return value.same_as(SymbolicLimits::pos_inf_);
+}
+
+/*!
+ * \brief Opaque expression representing negative infinity.
+ *
+ * It can can only be used as parameter of by min/max
+ * for integer analysis and cannot be used in normal expressions.
+ *
+ * \return negative infinity.
+ */
+inline Expr neg_inf() {
+ return SymbolicLimits::neg_inf_;
+}
+
+/*!
+ * \brief Check if value is negative infinity.
+ * \param value The value to be checked.
+ *
+ * \return The check result.
+ */
+inline bool is_neg_inf(const Expr& value) {
+ return value.same_as(SymbolicLimits::neg_inf_);
+}
+
} // namespace arith
} // namespace tvm
#endif // TVM_ARITHMETIC_CONST_FOLD_H_
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
- *
+ *
* http://www.apache.org/licenses/LICENSE-2.0
- *
+ *
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
/*!
* Copyright (c) 2017 by Contributors
- * \file bound_deducer.cc
- * \brief Utility to deduce bound of expression
+ * \file detect_linear_equation.cc
+ * \brief Utility to detect patterns in the expression.
*/
#include <tvm/expr.h>
#include <tvm/ir_pass.h>
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
- *
+ *
* http://www.apache.org/licenses/LICENSE-2.0
- *
+ *
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
- *
+ *
* http://www.apache.org/licenses/LICENSE-2.0
- *
+ *
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
*/
/*!
- * Copyright (c) 2017 by Contributors
* \file int_set.cc
* \brief The integer set functions
*/
#include <tvm/ir.h>
-#include <tvm/ir_pass.h>
-#include <tvm/arithmetic.h>
#include <tvm/ir_functor_ext.h>
-#include <arithmetic/Interval.h>
+#include <tvm/api_registry.h>
+#include <utility>
+#include <algorithm>
#include <unordered_map>
-#include "compute_expr.h"
-#include "int_set_internal.h"
+#include "int_set.h"
+#include "pattern_match.h"
namespace tvm {
namespace arith {
-using HalideIR::Internal::Interval;
-using namespace ir;
-
-inline IntSet IntSet::cover_interval() const {
- if ((*this).as<IntervalSet>()) return *this;
- const StrideSet* s = (*this).as<StrideSet>();
- if (s) {
- CHECK_NE(s->extents.size(), 0U);
- Expr max = s->base.max;
- for (size_t i = 0; i < s->extents.size(); ++i) {
- max = max + s->extents[i] * s->strides[i] - s->strides[i];
- }
- return IntervalSet::make(s->base.min, Simplify(max));
- }
- LOG(FATAL) << "cannot convert set " << (*this)->type_key() << " to interval";
- return IntSet::everything();
-}
-
-Range IntSet::cover_range(Range max_range) const {
- IntSet temp;
- const IntervalSet* s_int = (*this).as<IntervalSet>();
- if (s_int == nullptr) {
- temp = this->cover_interval();
- s_int = temp.as<IntervalSet>();
- }
- if (s_int->i.is_bounded()) {
- return Range::make_by_min_extent(
- s_int->i.min, Simplify(s_int->i.max + 1 - s_int->i.min));
- }
- return max_range;
-}
-
-Expr IntSet::min() const {
- const IntervalSet* s_int = (*this).as<IntervalSet>();
- CHECK(s_int);
- return s_int->i.min;
-}
-
-Expr IntSet::max() const {
- const IntervalSet* s_int = (*this).as<IntervalSet>();
- CHECK(s_int);
- return s_int->i.max;
-}
-
-bool IntSet::is_nothing() const {
- const IntervalSet* s_int = (*this).as<IntervalSet>();
- return (s_int && s_int->i.is_empty());
-}
-
-bool IntSet::is_everything() const {
- const IntervalSet* s_int = (*this).as<IntervalSet>();
- return (s_int && s_int->i.is_everything());
-}
+Expr SymbolicLimits::pos_inf_ = Var("pos_inf", Handle());
+Expr SymbolicLimits::neg_inf_ = Var("neg_inf", Handle());
-bool IntSet::is_single_point() const {
- const IntervalSet* s_int = (*this).as<IntervalSet>();
- return (s_int && s_int->i.is_single_point());
+IntervalSet::IntervalSet(Expr min_value, Expr max_value) {
+ auto node = make_node<IntervalSetNode>();
+ node->min_value = std::move(min_value);
+ node->max_value = std::move(max_value);
+ node_ = std::move(node);
}
-bool IntSet::can_prove_positive() const {
- const IntervalSet* s_int = (*this).as<IntervalSet>();
- return (s_int && is_positive_const(ir::Simplify(s_int->i.min)));
+IntervalSet MakeIntervalSet(Expr min_value, Expr max_value) {
+ return IntervalSet(min_value, max_value);
}
-bool IntSet::can_prove_negative() const {
- const IntervalSet* s_int = (*this).as<IntervalSet>();
- return (s_int && is_negative_const(ir::Simplify(s_int->i.max)));
-}
+TVM_REGISTER_API("arith._make_IntervalSet")
+.set_body_typed(MakeIntervalSet);
-bool IntSet::can_prove_non_positive() const {
- if (const IntervalSet* s_int = (*this).as<IntervalSet>()) {
- auto max = ir::Simplify(s_int->i.max);
- return is_zero(max) || is_negative_const(max);
- }
- return false;
-}
-bool IntSet::can_prove_non_negative() const {
- if (const IntervalSet* s_int = (*this).as<IntervalSet>()) {
- // Any reason why we should or should not use can_prove() to implement
- // these functions?
- auto min = ir::Simplify(s_int->i.min);
- return is_zero(min) || is_positive_const(min);
- }
- return false;
-}
-
-
-SignType IntSet::sign_type() const {
- if (can_prove_positive()) {
- return kPositive;
- } else if (can_prove_negative()) {
- return kNegative;
- } else if (is_single_point() && is_zero(point_value())) {
- return kZero;
+IntervalSet Intersect(Analyzer* analyzer, IntervalSet a, IntervalSet b) {
+ Expr max_value = min(a->max_value, b->max_value);
+ Expr min_value = max(a->min_value, b->min_value);
+ if ((max_value.type().is_int() || max_value.type().is_uint()) &&
+ (min_value.type().is_int() || min_value.type().is_uint()) &&
+ analyzer->CanProveGreaterEqual(min_value - max_value, 1)) {
+ return IntervalSet::Empty();
} else {
- return kUnknown;
- }
-}
-Expr IntSet::point_value() const {
- const IntervalSet* s_int = (*this).as<IntervalSet>();
- CHECK(s_int && s_int->i.is_single_point());
- return s_int->i.min;
-}
-
-IntSet IntSet::nothing() {
- return IntervalSet::make(Interval::nothing());
-}
-
-IntSet IntSet::everything() {
- return IntervalSet::make(Interval::everything());
-}
-
-IntSet IntSet::single_point(Expr x) {
- return IntervalSet::make(Interval::single_point(x));
-}
-
-IntSet IntSet::range(Range r) {
- // must make sure it can be matched back by MatchRange.
- if (is_one(r->extent)) {
- return IntSet::single_point(r->min);
- }
- if (is_positive_const(r->extent) && is_const(r->min)) {
- return IntervalSet::make(
- r->min, ComputeExpr<Sub>(ComputeExpr<Add>(r->extent, r->min), 1));
- }
- return IntervalSet::make(r->min, (r->extent + r->min) - 1);
-}
-
-IntSet IntSet::interval(Expr min, Expr max) {
- if (min.same_as(max)) {
- return IntSet::single_point(min);
- }
- return IntervalSet::make(min, max);
-}
-
-inline bool prove_equal(Expr lhs, Expr rhs) {
- return is_zero(ir::Simplify(lhs - rhs));
-}
-
-// Check if a is created from b.
-bool IntSet::match_range(const Range& b) const {
- const IntSet& a = *this;
- const IntervalSet* a_int = a.as<IntervalSet>();
- if (!a_int) return false;
- const Interval& i = a_int->i;
- return prove_equal(i.min, b->min) &&
- prove_equal(i.max, ComputeExpr<Sub>(ComputeExpr<Add>(b->extent, b->min), 1));
-}
-
-inline bool MatchPoint(const IntSet& a,
- const Expr& b) {
- const IntervalSet* a_int = a.as<IntervalSet>();
- if (!a_int) return false;
- const Interval& i = a_int->i;
- return i.is_single_point() && i.min.same_as(b);
-}
-
-IntSet Union(const Array<IntSet>& sets) {
- if (sets.size() == 0) return IntSet::nothing();
- if (sets.size() == 1) return sets[0];
- Interval x = sets[0].cover_interval().as<IntervalSet>()->i;
- for (size_t i = 1; i < sets.size(); ++i) {
- IntSet s = sets[i].cover_interval();
- const Interval& y = s.as<IntervalSet>()->i;
- x.include(y);
+ return IntervalSet(min_value, max_value);
}
- x.max = ir::Simplify(x.max);
- x.min = ir::Simplify(x.min);
- return IntervalSet::make(x);
}
-IntSet Intersect(const Array<IntSet>& sets) {
- Interval x = sets[0].cover_interval().as<IntervalSet>()->i;
- for (size_t i = 1; i < sets.size(); ++i) {
- Interval y = sets[i].cover_interval().as<IntervalSet>()->i;
- x = Interval::make_intersection(x, y);
- }
- return IntervalSet::make(x);
+IntervalSet Union(Analyzer* analyzer, IntervalSet a, IntervalSet b) {
+ Expr max_value = max(a->max_value, b->max_value);
+ Expr min_value = min(a->min_value, b->min_value);
+ return IntervalSet(min_value, max_value);
}
// type traits
static const bool value = true; \
};
-// interval related.
-template<typename OP>
-inline IntSet CombineInterval(Interval a, Interval b) {
- if (a.is_single_point() && b.is_single_point()) {
- return IntSet::single_point(ComputeExpr<OP>(a.min, b.min));
- }
- LOG(WARNING) << "Return Everything in CombineInterval " << OP::_type_key;
- return IntSet::everything();
+TVM_DECLARE_LOGICAL_OP(And);
+TVM_DECLARE_LOGICAL_OP(Or);
+TVM_DECLARE_LOGICAL_OP(EQ);
+TVM_DECLARE_LOGICAL_OP(NE);
+TVM_DECLARE_LOGICAL_OP(GE);
+TVM_DECLARE_LOGICAL_OP(GT);
+TVM_DECLARE_LOGICAL_OP(LE);
+TVM_DECLARE_LOGICAL_OP(LT);
+TVM_DECLARE_LOGICAL_OP(Not);
+
+/*!
+ * \brief Combine two interval set under arithmetic operations.
+ * \note this can possibly relax the set.
+ */
+template<typename Op>
+inline IntervalSet Combine(Analyzer* analyzer,
+ IntervalSet a,
+ IntervalSet b) {
+ if (a->IsSinglePoint() && b->IsSinglePoint()) {
+ Expr res = TryConstFold<Op>(a->min_value, b->min_value);
+ if (!res.defined()) res = Op::make(a->min_value, b->min_value);
+ return IntervalSet::SinglePoint(res);
+ }
+ if (is_logical_op<Op>::value) {
+ return IntervalSet(make_const(a->min_value.type(), 0),
+ make_const(a->min_value.type(), 1));
+ }
+ if (a->IsEmpty()) return a;
+ if (b->IsEmpty()) return b;
+ if (a->IsEverything()) return a;
+ if (b->IsEverything()) return b;
+ return IntervalSet::Everything();
}
template<>
-inline IntSet CombineInterval<Add>(Interval a, Interval b) {
- if (a.is_single_point() && b.is_single_point()) {
- return IntSet::single_point(ComputeExpr<Add>(a.min, b.min));
- }
- Interval r = Interval::everything();
- if (a.has_lower_bound() && b.has_lower_bound()) {
- r.min = ComputeExpr<Add>(a.min, b.min);
- }
- if (a.has_upper_bound() && b.has_upper_bound()) {
- r.max = ComputeExpr<Add>(a.max, b.max);
- }
- return IntervalSet::make(r);
+inline IntervalSet Combine<ir::Add>(Analyzer* analyer,
+ IntervalSet a,
+ IntervalSet b) {
+ if (a->IsSinglePoint() && b->IsSinglePoint()) {
+ return IntervalSet::SinglePoint(a->min_value + b->min_value);
+ }
+ if (a->IsEmpty()) return a;
+ if (b->IsEmpty()) return b;
+ Expr min_value =
+ a->HasLowerBound() && b->HasLowerBound() ?
+ a->min_value + b->min_value : neg_inf();
+ Expr max_value =
+ a->HasUpperBound() && b->HasUpperBound() ?
+ a->max_value + b->max_value : pos_inf();
+ return IntervalSet(min_value, max_value);
}
template<>
-inline IntSet CombineInterval<Sub>(Interval a, Interval b) {
- if (a.is_single_point() && b.is_single_point()) {
- return IntSet::single_point(ComputeExpr<Sub>(a.min, b.min));
+inline IntervalSet Combine<ir::Sub>(Analyzer* analyer,
+ IntervalSet a,
+ IntervalSet b) {
+ if (a->IsSinglePoint() && b->IsSinglePoint()) {
+ return IntervalSet::SinglePoint(a->min_value - b->min_value);
}
- Interval r = Interval::everything();
- if (a.has_lower_bound() && b.has_upper_bound()) {
- r.min = ComputeExpr<Sub>(a.min, b.max);
- }
- if (a.has_upper_bound() && b.has_lower_bound()) {
- r.max = ComputeExpr<Sub>(a.max, b.min);
- }
- return IntervalSet::make(r);
+ if (a->IsEmpty()) return a;
+ if (b->IsEmpty()) return b;
+ Expr min_value =
+ a->HasLowerBound() && b->HasUpperBound() ?
+ a->min_value - b->max_value : neg_inf();
+ Expr max_value =
+ a->HasUpperBound() && b->HasLowerBound() ?
+ a->max_value - b->min_value : pos_inf();
+ return IntervalSet(min_value, max_value);
}
+
template<>
-inline IntSet CombineInterval<Mul>(Interval a, Interval b) {
- if (a.is_single_point() && b.is_single_point()) {
- return IntSet::single_point(ComputeExpr<Mul>(a.min, b.min));
- }
- if (a.is_single_point() && !b.is_single_point()) {
+inline IntervalSet Combine<ir::Mul>(Analyzer* analyzer,
+ IntervalSet a,
+ IntervalSet b) {
+ if (a->IsSinglePoint() && b->IsSinglePoint()) {
+ return IntervalSet::SinglePoint(a->min_value * b->min_value);
+ }
+ if (a->IsEmpty()) return a;
+ if (b->IsEmpty()) return b;
+ if (a->IsSinglePoint()) {
std::swap(a, b);
}
- if (b.is_single_point()) {
- if (is_zero(b.min)) return IntSet::single_point(0);
- if (is_one(b.min)) return IntervalSet::make(a);
- Expr e1 = a.has_lower_bound() ? ComputeExpr<Mul>(a.min, b.min) : a.min;
- Expr e2 = a.has_upper_bound() ? ComputeExpr<Mul>(a.max, b.min) : a.max;
- // no relaxation is needed in here due to set is inclusive
- // TODO(tqchen): consider convert to StrideSet.
- if (is_positive_const(b.min)) {
- return IntervalSet::make(e1, e2);
- } else if (is_negative_const(b.min)) {
- return IntervalSet::make(e2, e1);
- } else if (a.is_bounded()) {
+ if (b->IsSinglePoint()) {
+ if (is_zero(b->min_value)) return b;
+ if (is_one(b->min_value)) return a;
+ if (analyzer->CanProveGreaterEqual(b->min_value, 0)) {
+ Expr min_value = a->HasLowerBound() ? a->min_value * b->min_value : neg_inf();
+ Expr max_value = a->HasUpperBound() ? a->max_value * b->min_value : pos_inf();
+ return IntervalSet(min_value, max_value);
+ } else if (analyzer->CanProveGreaterEqual(-b->min_value, 1)) {
+ Expr min_value = a->HasUpperBound() ? a->max_value * b->min_value : neg_inf();
+ Expr max_value = a->HasLowerBound() ? a->min_value * b->min_value : pos_inf();
+ return IntervalSet(min_value, max_value);
+ } else if (a->HasUpperBound() && a->HasLowerBound()) {
using ir::Select;
- Expr cmp = b.min >= make_zero(b.min.type().element_of());
- return IntervalSet::make(Select::make(cmp, e1, e2), Select::make(cmp, e2, e1));
+ Expr sign = b->min_value >= make_zero(b->min_value.type().element_of());
+ Expr e1 = a->min_value * b->min_value;
+ Expr e2 = a->max_value * b->min_value;
+ return IntervalSet(Select::make(sign, e1, e2), Select::make(sign, e2, e1));
}
}
- LOG(WARNING) << "Return Everything in CombineInterval Mul";
- return IntSet::everything();
+ DLOG(WARNING) << "Return Everything in CombineInterval Mul";
+ return IntervalSet::Everything();
}
template<>
-inline IntSet CombineInterval<Div>(Interval a, Interval b) {
- if (a.is_single_point() && b.is_single_point()) {
- return IntSet::single_point(ComputeExpr<Div>(a.min, b.min));
- }
- if (b.is_single_point()) {
- if (is_zero(b.min)) {
+inline IntervalSet Combine<ir::Div>(Analyzer* analyzer,
+ IntervalSet a,
+ IntervalSet b) {
+ if (a->IsSinglePoint() && b->IsSinglePoint()) {
+ return IntervalSet::SinglePoint(a->min_value / b->min_value);
+ }
+ if (a->IsEmpty()) return a;
+ if (b->IsEmpty()) return b;
+ if (b->IsSinglePoint()) {
+ if (is_zero(b->min_value)) {
LOG(FATAL) << "Divide by zero in CombineInterval Div";
}
- if (is_one(b.min)) return IntervalSet::make(a);
- Expr e1 = a.has_lower_bound() ? ComputeExpr<Div>(a.min, b.min) : a.min;
- Expr e2 = a.has_upper_bound() ? ComputeExpr<Div>(a.max, b.min) : a.max;
+ if (is_one(b->min_value)) return a;
// no relaxation is needed in here due to set is inclusive
- if (is_positive_const(b.min)) {
- return IntervalSet::make(e1, e2);
- } else if (is_negative_const(b.min)) {
- return IntervalSet::make(e2, e1);
- } else if (a.is_bounded()) {
+ if (analyzer->CanProveGreaterEqual(b->min_value, 0)) {
+ Expr min_value = a->HasLowerBound() ? a->min_value / b->min_value : neg_inf();
+ Expr max_value = a->HasUpperBound() ? a->max_value / b->min_value : pos_inf();
+ return IntervalSet(min_value, max_value);
+ } else if (analyzer->CanProveGreaterEqual(-b->min_value, 1)) {
+ Expr min_value = a->HasUpperBound() ? a->max_value / b->min_value : neg_inf();
+ Expr max_value = a->HasLowerBound() ? a->min_value / b->min_value : pos_inf();
+ return IntervalSet(min_value, max_value);
+ } else if (a->HasUpperBound() && a->HasLowerBound()) {
using ir::Select;
- Expr cmp = b.min >= make_zero(b.min.type().element_of());
- return IntervalSet::make(Select::make(cmp, e1, e2), Select::make(cmp, e2, e1));
+ Expr sign = b->min_value >= make_zero(b->min_value.type().element_of());
+ Expr e1 = a->min_value / b->min_value;
+ Expr e2 = a->max_value / b->min_value;
+ return IntervalSet(Select::make(sign, e1, e2), Select::make(sign, e2, e1));
}
}
- LOG(WARNING) << "Return Everything in CombineInterval Div";
- return IntSet::everything();
+ DLOG(WARNING) << "Return Everything in CombineInterval Div";
+ return IntervalSet::Everything();
}
template<>
-inline IntSet CombineInterval<Mod>(Interval a, Interval b) {
- if (a.is_single_point() && b.is_single_point()) {
- return IntSet::single_point(ComputeExpr<Mod>(a.min, b.min));
+inline IntervalSet Combine<ir::Mod>(Analyzer* analyzer,
+ IntervalSet a,
+ IntervalSet b) {
+ if (a->IsSinglePoint() && b->IsSinglePoint()) {
+ return IntervalSet::SinglePoint(a->min_value % b->min_value);
}
- if (b.is_single_point()) {
- Expr divisor = b.min;
+ if (a->IsEmpty()) return a;
+ if (b->IsEmpty()) return b;
+
+ if (b->IsSinglePoint()) {
+ const Expr& divisor = b->min_value;
if (is_zero(divisor)) {
LOG(FATAL) << "Modular by zero in CombineInterval Mod";
}
- return IntervalSet::make(make_zero(divisor.type()), divisor - 1);
+ // We need to add more bound constraints throughout the code.
+ // The logic below assumes a is non-negative, which usually
+ // is the case of our application.
+ // TODO(tqchen): add bound constraints for a.
+ if (analyzer->CanProveGreaterEqual(divisor, 0)) {
+ return IntervalSet(make_zero(divisor.type()), divisor - 1);
+ } else {
+ Expr bound = abs(divisor) - 1;
+ return IntervalSet(-bound, bound);
+ }
}
-
- LOG(WARNING) << "Return Everything in CombineInterval Mod";
- return IntSet::everything();
+ DLOG(WARNING) << "Return Everything in CombineInterval Mod";
+ return IntervalSet::Everything();
}
template<>
-inline IntSet CombineInterval<Max>(Interval a, Interval b) {
- if (a.is_single_point() && b.is_single_point()) {
- return IntSet::single_point(ComputeExpr<Max>(a.min, b.min));
+inline IntervalSet Combine<ir::Max>(Analyzer* analzyer,
+ IntervalSet a,
+ IntervalSet b) {
+ if (a->IsSinglePoint() && b->IsSinglePoint()) {
+ return IntervalSet::SinglePoint(max(a->min_value, b->min_value));
}
- return IntervalSet::make(Interval::make_max(a.min, b.min),
- Interval::make_max(a.max, b.max));
+ if (a->IsEmpty()) return a;
+ if (b->IsEmpty()) return b;
+ return IntervalSet(max(a->min_value, b->min_value),
+ max(a->max_value, b->max_value));
}
template<>
-inline IntSet CombineInterval<Min>(Interval a, Interval b) {
- if (a.is_single_point() && b.is_single_point()) {
- return IntSet::single_point(ComputeExpr<Min>(a.min, b.min));
+inline IntervalSet Combine<ir::Min>(Analyzer* analzyer,
+ IntervalSet a,
+ IntervalSet b) {
+ if (a->IsSinglePoint() && b->IsSinglePoint()) {
+ return IntervalSet::SinglePoint(min(a->min_value, b->min_value));
}
- return IntervalSet::make(Interval::make_min(a.min, b.min),
- Interval::make_min(a.max, b.max));
-}
-
-template<typename OP>
-inline IntSet CombineInterval_(IntSet a, IntSet b) {
- return CombineInterval<OP>(
- a.as<IntervalSet>()->i, b.as<IntervalSet>()->i);
-}
-
-// stride related
-inline IntSet AsStrideSet(IntSet a) {
- if (a.as<StrideSet>()) return a;
- const IntervalSet* s = a.as<IntervalSet>();
- CHECK(s->i.is_bounded());
- NodePtr<StrideSet> n = make_node<StrideSet>();
- n->base = s->i;
- return IntSet(n);
-}
-template<typename OP>
-inline IntSet CombineSets(IntSet a, IntSet b) {
- return CombineInterval_<OP>(a.cover_interval(), b.cover_interval());
+ if (a->IsEmpty()) return a;
+ if (b->IsEmpty()) return b;
+ return IntervalSet(min(a->min_value, b->min_value),
+ min(a->max_value, b->max_value));
}
-template<>
-inline IntSet CombineSets<Add>(IntSet a, IntSet b) {
- const IntervalSet* a_int = a.as<IntervalSet>();
- const IntervalSet* b_int = b.as<IntervalSet>();
- if (a_int && is_zero(a_int->i.min)) return b;
- if (b_int && is_zero(b_int->i.min)) return a;
- a = AsStrideSet(a);
- b = AsStrideSet(b);
- const StrideSet* a_stride = a.as<StrideSet>();
- const StrideSet* b_stride = b.as<StrideSet>();
- auto n = make_node<StrideSet>(*a_stride);
- for (size_t i = 0; i < b_stride->extents.size(); ++i) {
- n->extents.push_back(b_stride->extents[i]);
- n->strides.push_back(b_stride->strides[i]);
- }
- n->base = CombineInterval<Add>(
- a_stride->base, b_stride->base).as<IntervalSet>()->i;
- return IntSet(n);
-}
-
-inline IntSet NegateSet(IntSet a) {
- const IntervalSet* a_int = a.as<IntervalSet>();
- if (a_int) {
- if (a_int->i.is_single_point()) {
- return IntSet::single_point(-a_int->i.min);
- } else {
- Interval r = Interval::everything();
- if (a_int->i.has_upper_bound()) {
- r.min = -(a_int->i.max);
- }
- if (a_int->i.has_lower_bound()) {
- r.max = -(a_int->i.min);
- }
- return IntervalSet::make(r);
- }
- } else {
- return NegateSet(a.cover_interval());
+// internal helper function to get an interval set
+IntervalSet ToIntervalSet(IntSet set) {
+ if (auto* node = set.as<IntervalSetNode>()) {
+ return GetRef<IntervalSet>(node);
}
+ DLOG(INFO) << "cannot resolve int set " << set;
+ return IntervalSet::Everything();
}
-template<>
-inline IntSet CombineSets<Sub>(IntSet a, IntSet b) {
- return CombineSets<Add>(a, NegateSet(b));
-}
-
-TVM_DECLARE_LOGICAL_OP(And);
-TVM_DECLARE_LOGICAL_OP(Or);
-TVM_DECLARE_LOGICAL_OP(EQ);
-TVM_DECLARE_LOGICAL_OP(NE);
-TVM_DECLARE_LOGICAL_OP(GE);
-TVM_DECLARE_LOGICAL_OP(GT);
-TVM_DECLARE_LOGICAL_OP(LE);
-TVM_DECLARE_LOGICAL_OP(LT);
-TVM_DECLARE_LOGICAL_OP(Not);
+using namespace ir;
-// generic combine operations of two sets
-template<typename OP>
-inline IntSet Combine(const IntSet& a, const IntSet &b) {
- if (is_logical_op<OP>::value) {
- return IntervalSet::make(0, 1);
+// Simplified version of int set evaluator that operates on IntervalSet
+// We might use better set analysis in the future to replace the intervalset.
+class IntervalSetEvaluator :
+ public ExprFunctor<IntervalSet(const Expr&)> {
+ public:
+ IntervalSetEvaluator(Analyzer* analyzer,
+ const Map<Var, IntSet>& dom_map,
+ bool eval_vec = false)
+ : analyzer_(analyzer),
+ dom_map_(dom_map),
+ eval_vec_(eval_vec) {
}
- const IntervalSet* a_int = a.as<IntervalSet>();
- const IntervalSet* b_int = b.as<IntervalSet>();
- if (a_int && a_int->i.is_everything()) return a;
- if (b_int && b_int->i.is_everything()) return b;
- if (a_int && b_int) {
- return CombineInterval<OP>(a_int->i, b_int->i);
+
+ IntervalSet Eval(const Expr& val) {
+ return this->VisitExpr(val);
}
- if (a_int && !(a_int->i.is_bounded())) {
- return CombineInterval_<OP>(a, b.cover_interval());
+
+ IntervalSet VisitExpr_(const IntImm* op) final {
+ return IntervalSet::SinglePoint(GetRef<Expr>(op));
}
- if (b_int && !(b_int->i.is_bounded())) {
- return CombineInterval_<OP>(a.cover_interval(), b);
+
+ IntervalSet VisitExpr_(const UIntImm* op) final {
+ return IntervalSet::SinglePoint(GetRef<Expr>(op));
}
- return CombineSets<OP>(a, b);
-}
-class IntSetEvaluator :
- public ExprFunctor<IntSet(const Expr&, const Expr&)> {
- public:
- explicit IntSetEvaluator(
- const std::unordered_map<const Variable*, IntSet>& dom_map,
- bool eval_vec = false)
- : dom_map_(dom_map), eval_vec_(eval_vec) {}
- // Evaluate.
- IntSet Eval(const Expr& e) {
- return this->VisitExpr(e, e);
- }
- IntSet VisitExpr_(const IntImm* op, const Expr& e) final {
- return IntSet::single_point(e);
- }
- IntSet VisitExpr_(const UIntImm* op, const Expr& e) final {
- return IntSet::single_point(e);
- }
- IntSet VisitExpr_(const Variable* op, const Expr& e) final {
- auto it = dom_map_.find(op);
+ IntervalSet VisitExpr_(const Variable* op) final {
+ Var var = GetRef<Var>(op);
+ auto it = dom_map_.find(var);
if (it != dom_map_.end()) {
- return it->second;
+ return ToIntervalSet((*it).second);
} else {
- return IntSet::single_point(e);
+ return IntervalSet::SinglePoint(var);
}
}
- IntSet VisitExpr_(const Add* op, const Expr& e) final {
- return Binary(op, e);
+
+ IntervalSet VisitExpr_(const Add* op) final {
+ return VisitBinaryExpr_(op);
}
- IntSet VisitExpr_(const Sub* op, const Expr& e) final {
- return Binary(op, e);
+
+ IntervalSet VisitExpr_(const Sub* op) final {
+ return VisitBinaryExpr_(op);
}
- IntSet VisitExpr_(const Mul* op, const Expr& e) final {
- return Binary(op, e);
+
+ IntervalSet VisitExpr_(const Mul* op) final {
+ return VisitBinaryExpr_(op);
}
- IntSet VisitExpr_(const Div* op, const Expr& e) final {
- return Binary(op, e);
+
+ IntervalSet VisitExpr_(const Div* op) final {
+ return VisitBinaryExpr_(op);
}
- IntSet VisitExpr_(const Mod* op, const Expr& e) final {
- return Binary(op, e);
+
+ IntervalSet VisitExpr_(const Mod* op) final {
+ return VisitBinaryExpr_(op);
}
- IntSet VisitExpr_(const Min* op, const Expr& e) final {
- return Binary(op, e);
+
+ IntervalSet VisitExpr_(const Min* op) final {
+ return VisitBinaryExpr_(op);
}
- IntSet VisitExpr_(const Max* op, const Expr& e) final {
- return Binary(op, e);
+
+ IntervalSet VisitExpr_(const Max* op) final {
+ return VisitBinaryExpr_(op);
}
- IntSet VisitExpr_(const EQ* op, const Expr& e) final {
- return Binary(op, e);
+
+ IntervalSet VisitExpr_(const EQ* op) final {
+ return VisitBinaryExpr_(op);
}
- IntSet VisitExpr_(const NE* op, const Expr& e) final {
- return Binary(op, e);
+
+ IntervalSet VisitExpr_(const NE* op) final {
+ return VisitBinaryExpr_(op);
}
- IntSet VisitExpr_(const LT* op, const Expr& e) final {
- return Binary(op, e);
+
+ IntervalSet VisitExpr_(const LT* op) final {
+ return VisitBinaryExpr_(op);
}
- IntSet VisitExpr_(const LE* op, const Expr& e) final {
- return Binary(op, e);
+
+ IntervalSet VisitExpr_(const LE* op) final {
+ return VisitBinaryExpr_(op);
}
- IntSet VisitExpr_(const GT* op, const Expr& e) final {
- return Binary(op, e);
+
+ IntervalSet VisitExpr_(const GT* op) final {
+ return VisitBinaryExpr_(op);
}
- IntSet VisitExpr_(const GE* op, const Expr& e) final {
- return Binary(op, e);
+
+ IntervalSet VisitExpr_(const GE* op) final {
+ return VisitBinaryExpr_(op);
}
- IntSet VisitExpr_(const And* op, const Expr& e) final {
- return Binary(op, e);
+
+ IntervalSet VisitExpr_(const And* op) final {
+ return VisitBinaryExpr_(op);
}
- IntSet VisitExpr_(const Or* op, const Expr& e) final {
- return Binary(op, e);
+
+ IntervalSet VisitExpr_(const Or* op) final {
+ return VisitBinaryExpr_(op);
}
- IntSet VisitExpr_(const Ramp* op, const Expr& e) final {
+
+ IntervalSet VisitExpr_(const Ramp* op) final {
CHECK(eval_vec_);
- IntSet base = Eval(op->base);
- int vstride;
- if (GetConstInt(op->stride, &vstride)) {
+ IntervalSet base = Eval(op->base);
+ PVar<Integer> stride;
+ if (stride.Match(op->stride)) {
Type t = op->base.type();
- if (vstride > 0) {
+ int64_t vstride = stride.Eval()->value;
+ if (vstride> 0) {
return Combine<Add>(
+ analyzer_,
base,
- IntSet::interval(make_zero(t),
- make_const(t, vstride * op->lanes -1)));
+ IntervalSet(make_zero(t), make_const(t, vstride * op->lanes - 1)));
} else {
return Combine<Add>(
+ analyzer_,
base,
- IntSet::interval(make_const(t, vstride * op->lanes + 1),
- make_zero(t)));
+ IntervalSet(make_const(t, vstride * op->lanes + 1), make_zero(t)));
}
}
- LOG(WARNING) << "cannot evaluate set on expression " << e;
- return IntSet::everything();
+ DLOG(WARNING) << "cannot evaluate set on expression " << GetRef<Expr>(op);
+ return IntervalSet::Everything();
}
- IntSet VisitExpr_(const Broadcast* op, const Expr& e) final {
+
+ IntervalSet VisitExpr_(const Broadcast* op) final {
CHECK(eval_vec_);
- return Eval(op->value);
+ return VisitExpr(op->value);
}
- IntSet VisitExpr_(const Select* op, const Expr& e) final {
- IntSet true_set = this->Eval(op->true_value);
- IntSet false_set = this->Eval(op->false_value);
- return Union({false_set, true_set});
+
+ IntervalSet VisitExpr_(const Select* op) final {
+ IntervalSet true_set = this->Eval(op->true_value);
+ IntervalSet false_set = this->Eval(op->false_value);
+ return Union(analyzer_, false_set, true_set);
}
- IntSet VisitExprDefault_(const Node* op, const Expr& e) final {
- LOG(WARNING) << "cannot evaluate set type " << e->type_key();
- return IntSet::everything();
+
+ IntervalSet VisitExprDefault_(const Node* op) final {
+ DLOG(WARNING) << "cannot evaluate set type " << op->type_key();
+ return IntervalSet::Everything();
}
private:
+ // whether set is exactly single point that equals value.
+ bool MatchPoint(const IntervalSet& set,
+ const Expr& value) const {
+ return set->min_value.same_as(value) && set->max_value.same_as(value);
+ }
+
template<typename T>
- inline IntSet Binary(const T* op, const Expr& e) {
- IntSet a = this->Eval(op->a);
- IntSet b = this->Eval(op->b);
+ inline IntervalSet VisitBinaryExpr_(const T* op) {
+ IntervalSet a = this->Eval(op->a);
+ IntervalSet b = this->Eval(op->b);
if (MatchPoint(a, op->a) && MatchPoint(b, op->b)) {
- return IntSet::single_point(e);
+ return IntervalSet::SinglePoint(GetRef<Expr>(op));
}
- return Combine<T>(a, b);
+ return Combine<T>(analyzer_, a, b);
}
- const std::unordered_map<const Variable*, IntSet>& dom_map_;
+ Analyzer* analyzer_;
+ const Map<Var, IntSet>& dom_map_;
bool eval_vec_{false};
};
+class IntSetAnalyzer::Impl {
+ public:
+ explicit Impl(Analyzer* analyzer)
+ : analyzer_(analyzer) {
+ }
+
+ IntSet Eval(const Expr& expr, const Map<Var, IntSet>& dom_map) const {
+ return IntervalSetEvaluator(analyzer_, dom_map).Eval(expr);
+ }
+
+ private:
+ Analyzer* analyzer_;
+};
+
+IntSetAnalyzer::IntSetAnalyzer(Analyzer* parent)
+ : impl_(new Impl(parent)) {
+}
+
+IntSetAnalyzer::~IntSetAnalyzer() {
+ delete impl_;
+}
+
+IntSet IntSetAnalyzer::operator()(const Expr& expr,
+ const Map<Var, IntSet>& dom_map) {
+ return impl_->Eval(expr, dom_map);
+}
+
+// Quickly adapt to IntSet interface
+// TODO(tqchen): revisit IntSet interface as well.
+Range IntSet::cover_range(Range max_range) const {
+ IntSet temp;
+ const IntervalSetNode* s_int = (*this).as<IntervalSetNode>();
+ CHECK(s_int != nullptr);
+ if (s_int->HasUpperBound() && s_int->HasLowerBound()) {
+ return Range::make_by_min_extent(
+ s_int->min_value, Simplify(s_int->max_value + 1 - s_int->min_value));
+ }
+ return max_range;
+}
+
+Expr IntSet::min() const {
+ const IntervalSetNode* s_int = (*this).as<IntervalSetNode>();
+ CHECK(s_int);
+ return s_int->min_value;
+}
+
+Expr IntSet::max() const {
+ const IntervalSetNode* s_int = (*this).as<IntervalSetNode>();
+ CHECK(s_int);
+ return s_int->max_value;
+}
+
+bool IntSet::is_nothing() const {
+ const IntervalSetNode* s_int = (*this).as<IntervalSetNode>();
+ return (s_int && s_int->IsEmpty());
+}
+
+bool IntSet::is_everything() const {
+ const IntervalSetNode* s_int = (*this).as<IntervalSetNode>();
+ return (s_int && s_int->IsEverything());
+}
+
+bool IntSet::is_single_point() const {
+ const IntervalSetNode* s_int = (*this).as<IntervalSetNode>();
+ return (s_int && s_int->IsSinglePoint());
+}
+
+bool IntSet::can_prove_positive() const {
+ const IntervalSetNode* s_int = (*this).as<IntervalSetNode>();
+ return (s_int && is_positive_const(ir::Simplify(s_int->min_value)));
+}
+
+bool IntSet::can_prove_negative() const {
+ const IntervalSetNode* s_int = (*this).as<IntervalSetNode>();
+ return (s_int && is_negative_const(ir::Simplify(s_int->max_value)));
+}
+
+bool IntSet::can_prove_non_positive() const {
+ if (const auto* s_int = (*this).as<IntervalSetNode>()) {
+ auto max = ir::Simplify(s_int->max_value);
+ return is_zero(max) || is_negative_const(max);
+ }
+ return false;
+}
+
+bool IntSet::can_prove_non_negative() const {
+ if (const IntervalSetNode* s_int = (*this).as<IntervalSetNode>()) {
+ auto min = ir::Simplify(s_int->min_value);
+ return is_zero(min) || is_positive_const(min);
+ }
+ return false;
+}
+
+SignType IntSet::sign_type() const {
+ if (can_prove_positive()) {
+ return kPositive;
+ } else if (can_prove_negative()) {
+ return kNegative;
+ } else if (is_single_point() && is_zero(point_value())) {
+ return kZero;
+ } else {
+ return kUnknown;
+ }
+}
+Expr IntSet::point_value() const {
+ const IntervalSetNode* s_int = (*this).as<IntervalSetNode>();
+ CHECK(s_int && s_int->IsSinglePoint());
+ return s_int->min_value;
+}
+
+IntSet IntSet::nothing() {
+ return IntervalSet::Empty();
+}
+
+IntSet IntSet::everything() {
+ return IntervalSet::Everything();
+}
+
+IntSet IntSet::single_point(Expr x) {
+ return IntervalSet::SinglePoint(x);
+}
+
+IntSet IntSet::interval(Expr min, Expr max) {
+ if (min.same_as(max)) {
+ return IntSet::single_point(min);
+ }
+ return IntervalSet(min, max);
+}
+
+// Range related code
+inline bool ProveEqual(Expr lhs, Expr rhs) {
+ return is_zero(ir::Simplify(lhs - rhs));
+}
+
+IntSet IntSet::range(Range r) {
+ // must make sure it can be matched back by MatchRange.
+ if (is_one(r->extent)) {
+ return IntSet::single_point(r->min);
+ }
+ return IntervalSet(r->min, r->extent + r->min - 1);
+}
+
+bool IntSet::match_range(const Range& b) const {
+ const IntSet& a = *this;
+ const IntervalSetNode* a_int = a.as<IntervalSetNode>();
+ if (!a_int) return false;
+ return ProveEqual(a_int->min_value, b->min) &&
+ ProveEqual(a_int->max_value, b->extent + b->min - 1);
+}
+
+IntSet Union(const Array<IntSet>& sets) {
+ if (sets.size() == 0) return IntSet::nothing();
+ if (sets.size() == 1) return sets[0];
+ Analyzer ana;
+ IntervalSet x = ToIntervalSet(sets[0]);
+ for (size_t i = 1; i < sets.size(); ++i) {
+ x = Union(&ana, x, ToIntervalSet(sets[i]));
+ }
+ return IntervalSet(ir::Simplify(x->min_value),
+ ir::Simplify(x->max_value));
+}
+
+IntSet Intersect(const Array<IntSet>& sets) {
+ if (sets.size() == 0) return IntSet::nothing();
+ if (sets.size() == 1) return sets[0];
+ Analyzer ana;
+ IntervalSet x = ToIntervalSet(sets[0]);
+ for (size_t i = 1; i < sets.size(); ++i) {
+ x = Intersect(&ana, x, ToIntervalSet(sets[i]));
+ }
+ return IntervalSet(ir::Simplify(x->min_value),
+ ir::Simplify(x->max_value));
+}
+
+Map<Var, IntSet> ConvertDomMap(const Map<IterVar, IntSet>& dom_map) {
+ Map<Var, IntSet> dmap;
+ for (auto kv : dom_map) {
+ dmap.Set(kv.first->var, kv.second);
+ }
+ return dmap;
+}
+
+Map<Var, IntSet> ConvertDomMap(
+ const std::unordered_map<const Variable*, IntSet>& dom_map) {
+ Map<Var, IntSet> dmap;
+ for (auto kv : dom_map) {
+ dmap.Set(GetRef<Var>(kv.first), kv.second);
+ }
+ return dmap;
+}
+
IntSet EvalSet(Expr e,
- const std::unordered_map<const Variable*, IntSet>& dom_map) {
- return IntSetEvaluator(dom_map, false).Eval(e);
+ const Map<Var, IntSet>& dom_map) {
+ Analyzer ana;
+ return IntervalSetEvaluator(&ana, dom_map, false).Eval(e);
}
IntSet IntSet::vector(Expr x) {
- std::unordered_map<const Variable*, IntSet> dmap;
- return IntSetEvaluator(dmap, true).Eval(x);
+ Analyzer ana;
+ Map<Var, IntSet> dmap;
+ return IntervalSetEvaluator(&ana, dmap, true).Eval(x);
}
IntSet EvalSet(Expr e,
const Map<IterVar, IntSet>& dom_map) {
- std::unordered_map<const Variable*, IntSet> dmap;
- for (auto kv : dom_map) {
- dmap[kv.first->var.as<Variable>()] = kv.second;
- }
- return EvalSet(e, dmap);
+ return EvalSet(e, ConvertDomMap(dom_map));
}
-IntSet EvalSet(Range r,
+IntSet EvalSet(Expr e,
const std::unordered_map<const Variable*, IntSet>& dom_map) {
- IntSetEvaluator m(dom_map);
- IntSet min_set = m.Eval(r->min).cover_interval();
+ return EvalSet(e, ConvertDomMap(dom_map));
+}
+
+IntSet EvalSet(Range r,
+ const Map<Var, IntSet>& dom_map) {
+ Analyzer ana;
+ IntervalSetEvaluator m(&ana, dom_map);
+ IntervalSet min_set = m.Eval(r->min);
// Simplifying first can give tighter bounds if r->min and r->extent share variables
- Expr sum = ComputeExpr<Sub>(ComputeExpr<Add>(r->min, r->extent), 1);
- IntSet max_set = m.Eval(Simplify(sum)).cover_interval();
- const Interval& ni = min_set.as<IntervalSet>()->i;
- const Interval& xi = max_set.as<IntervalSet>()->i;
- if (!ni.has_lower_bound()) return IntSet::everything();
- if (!xi.has_upper_bound()) return IntSet::everything();
- return IntervalSet::make(ni.min, xi.max);
+ Expr sum = r->min + r->extent - 1;
+ IntervalSet max_set = m.Eval(Simplify(sum));
+ if (!min_set->HasLowerBound()) return IntSet::everything();
+ if (!max_set->HasUpperBound()) return IntSet::everything();
+ return IntervalSet(min_set->min_value, max_set->max_value);
}
-IntSet EvalSet(IntSet s,
+IntSet EvalSet(Range r,
const std::unordered_map<const Variable*, IntSet>& dom_map) {
- IntSetEvaluator m(dom_map);
- s = s.cover_interval();
- const IntervalSet* s_int = s.as<IntervalSet>();
- Expr vmax = s_int->i.has_upper_bound() ?
- m.Eval(s_int->i.max).cover_interval().max() : s_int->i.max;
- Expr vmin = s_int->i.has_lower_bound() ?
- m.Eval(s_int->i.min).cover_interval().min() : s_int->i.min;
- return IntervalSet::make(vmin, vmax);
+ return EvalSet(r, ConvertDomMap(dom_map));
}
-class SubExprIntSetEvaluator : public IntSetEvaluator {
+IntSet EvalSet(IntSet s,
+ const std::unordered_map<const Variable*, IntSet>& dom_map) {
+ Analyzer ana;
+ auto dmap = ConvertDomMap(dom_map);
+ IntervalSetEvaluator m(&ana, dmap);
+ const IntervalSetNode* s_int = s.as<IntervalSetNode>();
+ Expr vmax = s_int->HasUpperBound() ?
+ m.Eval(s_int->max_value).max() : s_int->max_value;
+ Expr vmin = s_int->HasLowerBound() ?
+ m.Eval(s_int->min_value).min() : s_int->min_value;
+ return IntervalSet(vmin, vmax);
+}
+
+class SubExprIntervalSetEvaluator : public IntervalSetEvaluator {
public:
- explicit SubExprIntSetEvaluator(
- const std::unordered_map<const Variable*, IntSet>& dom_map)
- : IntSetEvaluator(dom_map) {}
+ explicit SubExprIntervalSetEvaluator(
+ Analyzer* analyzer,
+ const Map<Var, IntSet>& dom_map)
+ : IntervalSetEvaluator(analyzer, dom_map) {}
- IntSet VisitExpr(const Expr& n, const Expr& e) final {
- IntSet ret = IntSetEvaluator::VisitExpr(n, e);
+ IntervalSet VisitExpr(const Expr& n) final {
+ IntervalSet ret = IntervalSetEvaluator::VisitExpr(n);
expr_map[n] = ret;
return ret;
}
ExprIntSetMap expr_map;
};
-ExprIntSetMap EvalSetForEachSubExpr(Expr e,
+ExprIntSetMap EvalSetForEachSubExpr(
+ Expr e,
const std::unordered_map<const Variable*, IntSet>& dom_map) {
- SubExprIntSetEvaluator m(dom_map);
+ Analyzer ana;
+ auto dmap = ConvertDomMap(dom_map);
+ SubExprIntervalSetEvaluator m(&ana, dmap);
m.Eval(e);
return m.expr_map;
}
IntSet EvalSet(Range r,
const Map<IterVar, IntSet>& dom_map) {
- std::unordered_map<const Variable*, IntSet> dmap;
- for (auto kv : dom_map) {
- dmap[kv.first->var.as<Variable>()] = kv.second;
- }
- return EvalSet(r, dmap);
+ return EvalSet(r, ConvertDomMap(dom_map));
}
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
-.set_dispatch<IntervalSet>([](const IntervalSet *op, IRPrinter *p) {
- p->stream << "interval-set"
- << "[" << op->i.min << ", "
- << op->i.max << ']';
+.set_dispatch<IntervalSetNode>([](const IntervalSetNode *op, IRPrinter *p) {
+ p->stream << "IntervalSet"
+ << "[" << op->min_value << ", "
+ << op->max_value << ']';
});
-
} // namespace arith
} // namespace tvm
--- /dev/null
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file int_set.h
+ * \brief Internal data structure for integer set.
+ */
+#ifndef TVM_ARITHMETIC_INT_SET_H_
+#define TVM_ARITHMETIC_INT_SET_H_
+
+#include <tvm/arithmetic.h>
+#include <tvm/expr_operator.h>
+#include <limits>
+#include "const_fold.h"
+
+namespace tvm {
+namespace arith {
+
+/*!
+ * \brief Symbolic interval set.
+ *
+ * \note We intentionally keep the internal of IntSet private,
+ as we might change it later.
+ */
+class IntervalSetNode : public IntSetNode {
+ public:
+ /*! \brief Minimum value in the interval. */
+ Expr min_value;
+ /*! \brief Maximum value in the interval. */
+ Expr max_value;
+
+ // visitor overload.
+ void VisitAttrs(tvm::AttrVisitor* v) final {
+ v->Visit("min_value", &min_value);
+ v->Visit("max_value", &max_value);
+ }
+
+ /*! \return Whether the interval has upper bound. */
+ bool HasUpperBound() const {
+ return !is_pos_inf(max_value) && !IsEmpty();
+ }
+ /*! \return Whether the interval has lower bound. */
+ bool HasLowerBound() const {
+ return !is_neg_inf(min_value) && !IsEmpty();
+ }
+ /*! \return Whether the interval is a single point. */
+ bool IsSinglePoint() const {
+ return min_value.same_as(max_value);
+ }
+ /*! \return whether interval represent nothing */
+ bool IsEmpty() const {
+ // during computations, either extreme could occur.
+ return is_pos_inf(min_value) || is_neg_inf(max_value);
+ }
+ /*! \return whether interval represent everything */
+ bool IsEverything() const {
+ return is_neg_inf(min_value) && is_pos_inf(max_value);
+ }
+
+ static constexpr const char* _type_key = "arith.IntervalSet";
+ TVM_DECLARE_NODE_TYPE_INFO(IntervalSetNode, IntSetNode);
+};
+
+/*!
+ * \brief Interval set used for symbolic integer analysis.
+ * \sa IntervalSetNode
+ */
+class IntervalSet : public IntSet {
+ public:
+ /*!
+ * \brief Make a new instance of interval set.
+ * \param min_value The minimum value in the interval.
+ * \param max_value The maximum value in the interval.
+ * \return The created set.
+ */
+ TVM_DLL IntervalSet(Expr min_value, Expr max_value);
+
+ /*!
+ * \brief Create an IntervalSet that represents a single point.
+ * \param value The value to be represented.
+ * \return The result set.
+ */
+ static IntervalSet SinglePoint(Expr value) {
+ return IntervalSet(value, value);
+ }
+ /*!
+ * \brief Create an IntervalSet that represents everything.
+ * \param value The value to be represented.
+ * \return The result set.
+ */
+ static IntervalSet Everything() {
+ return IntervalSet(neg_inf(), pos_inf());
+ }
+ /*!
+ * \brief Create an empty eet.
+ * \return The result set.
+ */
+ static IntervalSet Empty() {
+ return IntervalSet(pos_inf(), neg_inf());
+ }
+
+ TVM_DEFINE_NODE_REF_COW(IntervalSetNode);
+ TVM_DEFINE_NODE_REF_METHODS(IntervalSet, IntSet, IntervalSetNode);
+};
+
+/*!
+ * \brief Create union of two IntervalSets.
+ * \param analyzer The analyzer for simplification analysis.
+ * \param a The first set.
+ * \param b The second set.
+ * \return The result set.
+ */
+TVM_DLL IntervalSet Union(Analyzer* analyzer, IntervalSet a, IntervalSet b);
+
+/*!
+ * \brief Create insersection of two IntervalSets.
+ * \param analzyer The analyzer for simplification analysis.
+ * \param a The first set.
+ * \param b The second set.
+ * \return The result set.
+ */
+TVM_DLL IntervalSet Intersect(Analyzer *analzyer, IntervalSet a, IntervalSet b);
+
+} // namespace arith
+} // namespace tvm
+
+#endif // TVM_ARITHMETIC_INT_SET_H_
+++ /dev/null
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-/*!
- * Copyright (c) 2017 by Contributors
- * \file int_set_internal.h
- * \brief Implementations of integer set
- */
-#ifndef TVM_ARITHMETIC_INT_SET_INTERNAL_H_
-#define TVM_ARITHMETIC_INT_SET_INTERNAL_H_
-
-#include <tvm/ir.h>
-#include <tvm/ir_pass.h>
-#include <tvm/arithmetic.h>
-
-namespace tvm {
-namespace arith {
-
-using HalideIR::Internal::Interval;
-
-/*! \brief Set of continuous interval */
-struct IntervalSet : public IntSetNode {
- /*! \brief the internal interval*/
- Interval i;
-
- static IntSet make(Interval i) {
- NodePtr<IntervalSet> n =
- make_node<IntervalSet>();
- n->i = i;
- return IntSet(n);
- }
- static IntSet make(Expr min, Expr max) {
- NodePtr<IntervalSet> n =
- make_node<IntervalSet>();
- n->i.min = min;
- n->i.max = max;
- return IntSet(n);
- }
-
- static constexpr const char* _type_key = "IntervalSet";
- TVM_DECLARE_NODE_TYPE_INFO(IntervalSet, IntSetNode);
-};
-
-/*!
- * \brief set represented by strided integers
- * Reserved for cases where strided access is supported.
- */
-struct StrideSet : public IntSetNode {
- /*! \brief the base inetrval */
- Interval base;
- /*! \brief additional extents in positive number */
- Array<Expr> extents;
- /*! \brief additional strides in positive number */
- Array<Expr> strides;
-
- static constexpr const char* _type_key = "StrideSet";
- TVM_DECLARE_NODE_TYPE_INFO(StrideSet, IntSetNode);
-};
-
-} // namespace arith
-} // namespace tvm
-
-#endif // TVM_ARITHMETIC_INT_SET_INTERNAL_H_
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
- *
+ *
* http://www.apache.org/licenses/LICENSE-2.0
- *
+ *
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
return ir::Mod::make(a, b);
}
+
Expr min(Expr a, Expr b) {
+ // inf-aware simplificaiton
+ using arith::is_pos_inf;
+ using arith::is_neg_inf;
+ if (is_pos_inf(a)) return b;
+ if (is_neg_inf(a)) return a;
+ if (is_pos_inf(b)) return a;
+ if (is_neg_inf(b)) return b;
BinaryOpMatchTypes(a, b);
Expr ret = arith::TryConstFold<ir::Min>(a, b);
if (ret.defined()) return ret;
}
Expr max(Expr a, Expr b) {
+ // inf-aware simplificaiton
+ using arith::is_pos_inf;
+ using arith::is_neg_inf;
+ if (is_pos_inf(a)) return a;
+ if (is_neg_inf(a)) return b;
+ if (is_pos_inf(b)) return b;
+ if (is_neg_inf(b)) return a;
BinaryOpMatchTypes(a, b);
Expr ret = arith::TryConstFold<ir::Max>(a, b);
if (ret.defined()) return ret;
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
- *
+ *
* http://www.apache.org/licenses/LICENSE-2.0
- *
+ *
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
#include <tvm/arithmetic.h>
#include <unordered_map>
#include <unordered_set>
-#include "../arithmetic/int_set_internal.h"
+#include "../arithmetic/int_set.h"
#include "../runtime/thread_storage_scope.h"
namespace tvm {
std::pair<IntSet, std::unordered_set<const Node*>>
GetIntervalAndCondset(const Partition &partitions,
- const arith::Interval &for_interval,
+ const arith::IntervalSet &for_interval,
bool cond_value);
inline Stmt MakeFor(const Node* op, Expr extent, Stmt body);
/* Candidate IRs that may be partitioned potentially */
std::unordered_map<const Variable*, IntSet> hint_map_;
std::unordered_map<const Variable*, IntSet> relax_map_;
+ arith::Analyzer analyzer_;
CandidateSelector selector;
};
// given in the second component provably have value given by cond_value
std::pair<IntSet, std::unordered_set<const Node*>>
LoopPartitioner::GetIntervalAndCondset(const Partition &partitions,
- const arith::Interval &for_interval,
+ const arith::IntervalSet &for_interval,
bool cond_value) {
Array<IntSet> sets;
std::unordered_set<const Node*> cond_set;
for (const auto &kv : partitions) {
if (kv.first.second == cond_value) {
- arith::Interval interval = kv.second.as<arith::IntervalSet>()->i;
- arith::Interval intersection = arith::Interval::make_intersection(interval, for_interval);
- if (!intersection.is_empty()) {
+ arith::IntervalSet interval = Downcast<arith::IntervalSet>(kv.second);
+ arith::IntervalSet intersection = arith::Intersect(
+ &analyzer_, interval, for_interval);
+ if (!intersection->IsEmpty()) {
sets.push_back(kv.second);
cond_set.insert(kv.first.first);
}
Expr max,
Stmt body,
bool partition_thread_scope) {
+ using namespace arith;
PartitionFinder finder(var, hint_map_, relax_map_);
finder.Visit(body);
if (finder.partitions.empty()) return Stmt();
- arith::Interval for_interval(min, max);
+ arith::IntervalSet for_interval(min, max);
bool cond_value;
IntSet middle_interval;
std::unordered_set<const Node*> cond_set;
// if such interval doesn't exist, find an interval in which all
// conditions on var are false
std::tie(middle_interval, cond_set) =
- GetIntervalAndCondset(finder.partitions, for_interval, false);
+ GetIntervalAndCondset(finder.partitions, for_interval, false);
if (middle_interval.is_nothing())
// we couldn't find an interval in which the condintions are provably true or false
// Therefore, we can't partition the loop based on those conds
cond_value = true;
}
- arith::Interval middle_interval_i = middle_interval.as<arith::IntervalSet>()->i;
+ IntervalSet middle_interval_i = Downcast<IntervalSet>(middle_interval);
// middle_interval is the subrange of the loop variable range for which a
// set of conditions are true (or false resp.)
// The part of the loop variable range that is before (after resp.) that
Expr body_begin;
Stmt pre_stmt;
bool pre_stmt_recurse = true;
- if (middle_interval_i.has_lower_bound()) {
+ if (middle_interval_i->HasLowerBound()) {
body_begin = ir::Simplify(middle_interval.min());
if (!can_prove(body_begin == min)) {
Expr cond = (body_begin - min >= 0);
Expr post_doubt_begin;
Stmt post_stmt;
bool post_stmt_recurse = true;
- if (middle_interval_i.has_upper_bound()) {
+ if (middle_interval_i->HasUpperBound()) {
post_doubt_begin = ir::Simplify(middle_interval.max() + 1);
if (!can_prove(middle_interval.max() == max)) {
// require the extent to be non-negative
--- /dev/null
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import tvm
+
+def test_deduce():
+ a = tvm.var('a')
+ b = tvm.var('b')
+ c = tvm.var('c')
+ d = tvm.var('d')
+
+ b_s = tvm.arith.IntervalSet(2, 3)
+ c_s = tvm.arith.IntervalSet(10, 15)
+ d_s = tvm.arith.IntervalSet(-3, -1)
+ zero = tvm.const(0, "int32")
+
+ e0 = (-b)*a+c-d
+ res0 = tvm.arith.DeduceBound(a, e0>=0, {b: b_s, c: c_s, d: d_s}, {})
+ ans0 = ((d - c) /(b*-1))
+ assert str(tvm.ir_pass.Simplify(res0.max_value)) == str(ans0)
+
+ # expression containing variable a is on rhs
+ res0 = tvm.arith.DeduceBound(a, zero <= e0, {b: b_s, c: c_s, d: d_s}, {})
+ assert str(tvm.ir_pass.Simplify(res0.max_value)) == str(ans0)
+
+ e0 = d*a+c-d
+ res0 = tvm.arith.DeduceBound(a, e0>=0, {b: b_s, c: c_s, d: d_s}, {})
+ ans0 = ((0-c)/d + 1)
+ assert str(tvm.ir_pass.Simplify(res0.max_value)) == str(ans0)
+
+ # expression containing variable a is on rhs
+ res0 = tvm.arith.DeduceBound(a, zero <= e0, {b: b_s, c: c_s, d: d_s}, {})
+ assert str(tvm.ir_pass.Simplify(res0.max_value)) == str(ans0)
+
+ e1 = (a*4+b < c)
+ res1 = tvm.arith.DeduceBound(a, e1, {b: b_s, c: c_s, d: d_s}, {})
+ ans1 = (((c - b) + -1)/4)
+ assert str(tvm.ir_pass.Simplify(res1.max_value)) == str(ans1)
+
+ # expression containing variable a is on rhs
+ e1 = (c > a*4+b)
+ res1 = tvm.arith.DeduceBound(a, e1, {b: b_s, c: c_s, d: d_s}, {})
+ assert str(tvm.ir_pass.Simplify(res1.max_value)) == str(ans1)
+
+ e2 = (tvm.max(5, a * 4) < 0)
+ res2 = tvm.arith.DeduceBound(a, e2, {b: b_s, c: c_s, d: d_s}, {})
+ assert str(res2.max_value) == "neg_inf"
+ assert str(res2.min_value) == "pos_inf"
+
+ # expression containing variable a is on rhs
+ e2 = (zero < tvm.max(5, a * 4))
+ res2 = tvm.arith.DeduceBound(a, e2, {b: b_s, c: c_s, d: d_s}, {})
+ assert str(res2.max_value) == "neg_inf"
+ assert str(res2.min_value) == "pos_inf"
+
+
+ e3 = (-b)+a*c-d
+ res3 = tvm.arith.DeduceBound(a, e3>=0, {b: b_s, c: c_s, d: d_s}, {b: b_s, d: d_s})
+ ans3 = 2/c+1
+ assert str(tvm.ir_pass.Simplify(res3.min_value)) == str(ans3)
+
+ res3 = tvm.arith.DeduceBound(a, zero <= e3, {b: b_s, c: c_s, d: d_s}, {b: b_s, d: d_s})
+ assert str(tvm.ir_pass.Simplify(res3.min_value)) == str(ans3)
+
+def test_check():
+ a = tvm.var('a')
+ b = tvm.var('b')
+ c = tvm.var('c')
+ d = tvm.var('d')
+
+ b_s = tvm.arith.IntervalSet(2, 3)
+ c_s = tvm.arith.IntervalSet(5, 7)
+ d_s = tvm.arith.IntervalSet(-3, -1)
+
+ # no compare operator
+ res1 = tvm.arith.DeduceBound(a, a+b, {b: b_s}, {})
+ assert res1.is_nothing()
+
+ # multiple compare operators
+ res2 = tvm.arith.DeduceBound(a, (a+b>3).astype(c.dtype)>c , {b: b_s, c: c_s}, {})
+ assert res2.is_nothing()
+
+ # multiple target variable
+ res2 = tvm.arith.DeduceBound(a, a*2-a>b, {b: b_s}, {})
+ assert res2.is_nothing()
+
+def test_deduce_basic():
+ def test_basic(a1, a2, coff):
+ a = tvm.var('a')
+ b = tvm.var('b')
+ b_s = tvm.arith.IntervalSet(a1, a2)
+ e0 = b + a*coff + 3
+
+ res1 = tvm.arith.DeduceBound(a, e0<17, {b: b_s}, {b: b_s})
+ [x, y] = [res1.max_value, b_s.max_value] if coff > 0 else [res1.min_value, b_s.min_value]
+ assert (tvm.ir_pass.Simplify((x * coff + 3 + y) < 17)).value == 1
+
+ # expression containing variable a is on rhs
+ res1 = tvm.arith.DeduceBound(a, tvm.const(17, "int32") < e0, {b: b_s}, {b: b_s})
+ [x, y] = [res1.max_value, b_s.max_value] if coff < 0 else [res1.min_value, b_s.min_value]
+ assert (tvm.ir_pass.Simplify((x * coff + 3 + y) > 17)).value == 1
+
+ # expression containing variable a is on rhs
+ res1 = tvm.arith.DeduceBound(a, tvm.const(17, "int32")>= e0, {b: b_s}, {b: b_s})
+ [x, y] = [res1.max_value, b_s.max_value] if coff > 0 else [res1.min_value, b_s.min_value]
+ assert (tvm.ir_pass.Simplify((x * coff + 3 + y) <= 17)).value == 1
+
+ res1 = tvm.arith.DeduceBound(a, e0>=17, {b: b_s}, {b: b_s})
+ [x, y] = [res1.max_value, b_s.max_value] if coff < 0 else [res1.min_value, b_s.min_value]
+ assert (tvm.ir_pass.Simplify((x * coff + 3 + y) >= 17)).value == 1
+
+ test_basic(0, 4, 4)
+ test_basic(1, 5, 4)
+ test_basic(2, 6, 4)
+ test_basic(0, 4, -4)
+ test_basic(1, 5, -4)
+ test_basic(2, 6, -4)
+
+def test_deduce_complex():
+ def test_complex(a1, a2, coff):
+ a = tvm.var('a')
+ b = tvm.var('b')
+ b_s = tvm.arith.IntervalSet(a1, a2)
+ e0 = (b*3 + a* coff) * 4
+
+ res1 = tvm.arith.DeduceBound(a, e0<63, {b: b_s}, {b: b_s})
+ [t, x] = [res1.max_value, b_s.max_value] if coff > 0 else [res1.min_value, b_s.min_value]
+ assert (tvm.ir_pass.Simplify(((x*3 + t* coff) * 4) < 63)).value == 1
+
+ # expression containing variable a is on rhs
+ res1 = tvm.arith.DeduceBound(a, tvm.const(63, "int32")>= e0, {b: b_s}, {b: b_s})
+ [t, x] = [res1.max_value, b_s.max_value] if coff > 0 else [res1.min_value, b_s.min_value]
+ assert (tvm.ir_pass.Simplify(((x*3 + t* coff) * 4) <= 63)).value == 1
+
+ res1 = tvm.arith.DeduceBound(a, e0>63, {b: b_s}, {b: b_s})
+ [t, x] = [res1.max_value, b_s.max_value] if coff < 0 else [res1.min_value, b_s.min_value]
+ assert (tvm.ir_pass.Simplify(((x*3 + t* coff) * 4) > 63)).value == 1
+
+ # expression containing variable a is on rhs
+ res1 = tvm.arith.DeduceBound(a, tvm.const(63, "int32") <= e0, {b: b_s}, {b: b_s})
+ [t, x] = [res1.max_value, b_s.max_value] if coff < 0 else [res1.min_value, b_s.min_value]
+ assert (tvm.ir_pass.Simplify(((x*3 + t* coff) * 4) >= 63)).value == 1
+
+ test_complex(0, 4, 4)
+ test_complex(0, 4, -4)
+ test_complex(2, 6, 4)
+ test_complex(0, 4, -4)
+ test_complex(1, 5, -4)
+ test_complex(2, 6, -4)
+
+
+if __name__ == "__main__":
+ test_check()
+ test_deduce_basic()
+ test_deduce_complex()
# under the License.
import tvm
+
+class IntSetChecker:
+ def __init__(self):
+ self.analyzer = tvm.arith.Analyzer()
+
+ def verify(self, data, dmap, expected):
+ res = self.analyzer.int_set(data, dmap)
+ def err_msg():
+ return "\ndata={}\ndmap={}\nres={}\nexpected={}".format(data, dmap, res, expected)
+ def equal(x, y):
+ res = self.analyzer.canonical_simplify(x - y)
+ return tvm.ir_pass.Equal(res, 0)
+ assert equal(res.min_value, expected[0]), err_msg()
+ assert equal(res.max_value, expected[1]), err_msg()
+
def test_basic():
- s = tvm.arith.intset_interval(2, 3)
- assert s.min().value == 2
- assert s.max().value == 3
+ s = tvm.arith.IntervalSet(2, 3)
+ assert s.min_value.value == 2
+ assert s.max_value.value == 3
+
def test_vector():
base = 10
stride = 3
lanes = 2
s = tvm.arith.intset_vector(tvm.make.Ramp(base, stride, lanes))
- assert s.min().value == base
- assert s.max().value == base + stride * lanes - 1
-
-def test_deduce():
- a = tvm.var('a')
- b = tvm.var('b')
- c = tvm.var('c')
- d = tvm.var('d')
-
- b_s = tvm.arith.intset_interval(2, 3)
- c_s = tvm.arith.intset_interval(10, 15)
- d_s = tvm.arith.intset_interval(-3, -1)
- zero = tvm.const(0, "int32")
-
- e0 = (-b)*a+c-d
- res0 = tvm.arith.DeduceBound(a, e0>=0, {b: b_s, c: c_s, d: d_s}, {})
- ans0 = ((d - c) /(b*-1))
- assert str(tvm.ir_pass.Simplify(res0.max())) == str(ans0)
-
- # expression containing variable a is on rhs
- res0 = tvm.arith.DeduceBound(a, zero <= e0, {b: b_s, c: c_s, d: d_s}, {})
- assert str(tvm.ir_pass.Simplify(res0.max())) == str(ans0)
-
- e0 = d*a+c-d
- res0 = tvm.arith.DeduceBound(a, e0>=0, {b: b_s, c: c_s, d: d_s}, {})
- ans0 = ((0-c)/d + 1)
- assert str(tvm.ir_pass.Simplify(res0.max())) == str(ans0)
-
- # expression containing variable a is on rhs
- res0 = tvm.arith.DeduceBound(a, zero <= e0, {b: b_s, c: c_s, d: d_s}, {})
- assert str(tvm.ir_pass.Simplify(res0.max())) == str(ans0)
-
- e1 = (a*4+b < c)
- res1 = tvm.arith.DeduceBound(a, e1, {b: b_s, c: c_s, d: d_s}, {})
- ans1 = (((c - b) + -1)/4)
- assert str(tvm.ir_pass.Simplify(res1.max())) == str(ans1)
-
- # expression containing variable a is on rhs
- e1 = (c > a*4+b)
- res1 = tvm.arith.DeduceBound(a, e1, {b: b_s, c: c_s, d: d_s}, {})
- assert str(tvm.ir_pass.Simplify(res1.max())) == str(ans1)
-
- e2 = (tvm.max(5, a * 4) < 0)
- res2 = tvm.arith.DeduceBound(a, e2, {b: b_s, c: c_s, d: d_s}, {})
- assert str(res2.max()) == "neg_inf"
- assert str(res2.min()) == "pos_inf"
-
- # expression containing variable a is on rhs
- e2 = (zero < tvm.max(5, a * 4))
- res2 = tvm.arith.DeduceBound(a, e2, {b: b_s, c: c_s, d: d_s}, {})
- assert str(res2.max()) == "neg_inf"
- assert str(res2.min()) == "pos_inf"
-
-
- e3 = (-b)+a*c-d
- res3 = tvm.arith.DeduceBound(a, e3>=0, {b: b_s, c: c_s, d: d_s}, {b: b_s, d: d_s})
- ans3 = 2/c+1
- assert str(tvm.ir_pass.Simplify(res3.min())) == str(ans3)
-
- res3 = tvm.arith.DeduceBound(a, zero <= e3, {b: b_s, c: c_s, d: d_s}, {b: b_s, d: d_s})
- assert str(tvm.ir_pass.Simplify(res3.min())) == str(ans3)
-
-def test_check():
- a = tvm.var('a')
- b = tvm.var('b')
- c = tvm.var('c')
- d = tvm.var('d')
-
- b_s = tvm.arith.intset_interval(2, 3)
- c_s = tvm.arith.intset_interval(5, 7)
- d_s = tvm.arith.intset_interval(-3, -1)
-
- # no compare operator
- res1 = tvm.arith.DeduceBound(a, a+b, {b: b_s}, {})
- assert res1.is_nothing()
-
- # multiple compare operators
- res2 = tvm.arith.DeduceBound(a, (a+b>3).astype(c.dtype)>c , {b: b_s, c: c_s}, {})
- assert res2.is_nothing()
-
- # multiple target variable
- res2 = tvm.arith.DeduceBound(a, a*2-a>b, {b: b_s}, {})
- assert res2.is_nothing()
-
-def test_deduce_basic():
- def test_basic(a1, a2, coff):
- a = tvm.var('a')
- b = tvm.var('b')
- b_s = tvm.arith.intset_interval(a1, a2)
- e0 = b + a*coff + 3
-
- res1 = tvm.arith.DeduceBound(a, e0<17, {b: b_s}, {b: b_s})
- [x, y] = [res1.max(), b_s.max()] if coff > 0 else [res1.min(), b_s.min()]
- assert (tvm.ir_pass.Simplify((x * coff + 3 + y) < 17)).value == 1
-
- # expression containing variable a is on rhs
- res1 = tvm.arith.DeduceBound(a, tvm.const(17, "int32") < e0, {b: b_s}, {b: b_s})
- [x, y] = [res1.max(), b_s.max()] if coff < 0 else [res1.min(), b_s.min()]
- assert (tvm.ir_pass.Simplify((x * coff + 3 + y) > 17)).value == 1
-
- # expression containing variable a is on rhs
- res1 = tvm.arith.DeduceBound(a, tvm.const(17, "int32")>= e0, {b: b_s}, {b: b_s})
- [x, y] = [res1.max(), b_s.max()] if coff > 0 else [res1.min(), b_s.min()]
- assert (tvm.ir_pass.Simplify((x * coff + 3 + y) <= 17)).value == 1
-
- res1 = tvm.arith.DeduceBound(a, e0>=17, {b: b_s}, {b: b_s})
- [x, y] = [res1.max(), b_s.max()] if coff < 0 else [res1.min(), b_s.min()]
- assert (tvm.ir_pass.Simplify((x * coff + 3 + y) >= 17)).value == 1
-
- test_basic(0, 4, 4)
- test_basic(1, 5, 4)
- test_basic(2, 6, 4)
- test_basic(0, 4, -4)
- test_basic(1, 5, -4)
- test_basic(2, 6, -4)
-
-def test_deduce_complex():
- def test_complex(a1, a2, coff):
- a = tvm.var('a')
- b = tvm.var('b')
- b_s = tvm.arith.intset_interval(a1, a2)
- e0 = (b*3 + a* coff) * 4
-
- res1 = tvm.arith.DeduceBound(a, e0<63, {b: b_s}, {b: b_s})
- [t, x] = [res1.max(), b_s.max()] if coff > 0 else [res1.min(), b_s.min()]
- assert (tvm.ir_pass.Simplify(((x*3 + t* coff) * 4) < 63)).value == 1
-
- # expression containing variable a is on rhs
- res1 = tvm.arith.DeduceBound(a, tvm.const(63, "int32")>= e0, {b: b_s}, {b: b_s})
- [t, x] = [res1.max(), b_s.max()] if coff > 0 else [res1.min(), b_s.min()]
- assert (tvm.ir_pass.Simplify(((x*3 + t* coff) * 4) <= 63)).value == 1
-
- res1 = tvm.arith.DeduceBound(a, e0>63, {b: b_s}, {b: b_s})
- [t, x] = [res1.max(), b_s.max()] if coff < 0 else [res1.min(), b_s.min()]
- assert (tvm.ir_pass.Simplify(((x*3 + t* coff) * 4) > 63)).value == 1
-
- # expression containing variable a is on rhs
- res1 = tvm.arith.DeduceBound(a, tvm.const(63, "int32") <= e0, {b: b_s}, {b: b_s})
- [t, x] = [res1.max(), b_s.max()] if coff < 0 else [res1.min(), b_s.min()]
- assert (tvm.ir_pass.Simplify(((x*3 + t* coff) * 4) >= 63)).value == 1
-
- test_complex(0, 4, 4)
- test_complex(0, 4, -4)
- test_complex(2, 6, 4)
- test_complex(0, 4, -4)
- test_complex(1, 5, -4)
- test_complex(2, 6, -4)
+ assert s.min_value.value == base
+ assert s.max_value.value == base + stride * lanes - 1
+
+
+def test_add_sub():
+ ck = IntSetChecker()
+ x, y = tvm.var("x"), tvm.var("y")
+ ck.verify(x + y, {x : tvm.arith.IntervalSet(0, 10)}, (y, 10 + y))
+ ck.verify(x + y,
+ {x : tvm.arith.IntervalSet(0, 10), y : tvm.arith.IntervalSet(1, 11)},
+ (1, 21))
+ ck.verify(x - y,
+ {x : tvm.arith.IntervalSet(0, 10), y : tvm.arith.IntervalSet(1, 11)},
+ (-11, 9))
+
+def test_mul_div():
+ ck = IntSetChecker()
+ x, y = tvm.var("x"), tvm.var("y")
+ ck.analyzer.update(y, tvm.arith.ConstIntBound(1, 100), override=True)
+ ck.verify(x * y, {x : tvm.arith.IntervalSet(0, 10)}, (0, 10 * y))
+ ck.verify(x * 2, {x : tvm.arith.IntervalSet(1, 10)}, (2, 20))
+ ck.verify(x * -2, {x : tvm.arith.IntervalSet(1, 10)}, (-20, -2))
+ ck.verify(x / y, {x : tvm.arith.IntervalSet(0, 10)}, (0, 10 / y))
+ ck.verify(x / 2, {x : tvm.arith.IntervalSet(1, 10)}, (0, 5))
+
+
+def test_mod():
+ ck = IntSetChecker()
+ x, y = tvm.var("x"), tvm.var("y")
+ ck.analyzer.update(y, tvm.arith.ConstIntBound(1, 100), override=True)
+ ck.verify(x % y, {x : tvm.arith.IntervalSet(0, 10)}, (0, y - 1))
+ ck.verify(x % 10, {x : tvm.arith.IntervalSet(1, 10)}, (0, 9))
+
+def test_max_min():
+ ck = IntSetChecker()
+ x, y = tvm.var("x"), tvm.var("y")
+ ck.verify(tvm.max(x, x + 1), {x : tvm.arith.IntervalSet(0, 10)}, (1, 11))
+ ck.verify(tvm.min(x - 1, x + 1), {x : tvm.arith.IntervalSet(0, 10)}, (-1, 9))
+ ck.verify(tvm.min(x, y), {}, (tvm.min(x, y), tvm.min(x, y)))
+ ck.verify(tvm.max(x, y), {}, (tvm.max(x, y), tvm.max(x, y)))
+
+
+def test_select():
+ ck = IntSetChecker()
+ x, y = tvm.var("x"), tvm.var("y")
+ ck.verify(tvm.expr.Select(x > 0, x - 1, x + 1),
+ {x : tvm.arith.IntervalSet(0, 10)}, (-1, 11))
+
if __name__ == "__main__":
test_basic()
test_vector()
- test_deduce()
- test_check()
- test_deduce_basic()
- test_deduce_complex()
+ test_add_sub()
+ test_mul_div()
+ test_max_min()
+ test_select()
+ test_mod()
+