RELAY_DEFINE_NODE_REF(PatternConstructor, PatternConstructorNode, Pattern);
+/*! \brief A tuple pattern. Matches a tuple, binds recursively. */
+class PatternTuple;
+/*! \brief PatternVar container node */
+class PatternTupleNode : public PatternNode {
+ public:
+ /*! Sub-patterns to match against each value of the tuple. */
+ tvm::Array<Pattern> patterns;
+
+ PatternTupleNode() {}
+
+ TVM_DLL static PatternTuple make(tvm::Array<Pattern> var);
+
+ void VisitAttrs(tvm::AttrVisitor* v) final {
+ v->Visit("patterns", &patterns);
+ v->Visit("span", &span);
+ }
+
+ static constexpr const char* _type_key = "relay.PatternTuple";
+ TVM_DECLARE_NODE_TYPE_INFO(PatternTupleNode, PatternNode);
+};
+
+RELAY_DEFINE_NODE_REF(PatternTuple, PatternTupleNode, Pattern);
+
/*!
* \brief Stores all data for an Algebraic Data Type (ADT).
*
Args... args) PATTERN_FUNCTOR_DEFAULT;
virtual R VisitPattern_(const PatternConstructorNode* op,
Args... args) PATTERN_FUNCTOR_DEFAULT;
+ virtual R VisitPattern_(const PatternTupleNode* op,
+ Args... args) PATTERN_FUNCTOR_DEFAULT;
virtual R VisitPatternDefault_(const Node* op, Args...) {
throw Error(std::string("Do not have a default for ") + op->type_key());
}
RELAY_PATTERN_FUNCTOR_DISPATCH(PatternWildcardNode);
RELAY_PATTERN_FUNCTOR_DISPATCH(PatternVarNode);
RELAY_PATTERN_FUNCTOR_DISPATCH(PatternConstructorNode);
+ RELAY_PATTERN_FUNCTOR_DISPATCH(PatternTupleNode);
return vtable;
}
};
void VisitPattern_(const PatternWildcardNode* op) override;
void VisitPattern_(const PatternVarNode* op) override;
void VisitPattern_(const PatternConstructorNode* op) override;
+ void VisitPattern_(const PatternTupleNode* op) override;
virtual void VisitType(const Type& t);
virtual void VisitVar(const Var& v);
virtual void VisitConstructor(const Constructor& c);
Pattern VisitPattern_(const PatternWildcardNode* op) override;
Pattern VisitPattern_(const PatternVarNode* op) override;
Pattern VisitPattern_(const PatternConstructorNode* op) override;
+ Pattern VisitPattern_(const PatternTupleNode* op) override;
/*! \brief Used to visit the types inside of patterns.
*
* Can be overloaded to transform the types in arbitrary
PatternWildcard = adt.PatternWildcard
PatternVar = adt.PatternVar
PatternConstructor = adt.PatternConstructor
+PatternTuple = adt.PatternTuple
Constructor = adt.Constructor
TypeData = adt.TypeData
Clause = adt.Clause
@register_relay_node
+class PatternTuple(Pattern):
+ """Constructor pattern in Relay: Matches a tuple, binds recursively."""
+
+ def __init__(self, patterns=None):
+ """Construct a tuple pattern.
+
+ Parameters
+ ----------
+ patterns: Optional[List[Pattern]]
+ Optional subpatterns: for each field of the constructor,
+ match to the given subpattern (treated as a variable pattern by default).
+
+ Returns
+ -------
+ wildcard: PatternWildcard
+ a wildcard pattern.
+ """
+ if patterns is None:
+ patterns = []
+ self.__init_handle_by_constructor__(_make.PatternTuple, patterns)
+
+
+@register_relay_node
class Constructor(Expr):
"""Relay ADT constructor."""
from .expr import Var, Function, GlobalVar, Let, If, Tuple, TupleGetItem, const
from .op.tensor import add, subtract, equal
from .adt import Constructor, TypeData, Clause, Match
-from .adt import PatternConstructor, PatternVar, PatternWildcard
+from .adt import PatternConstructor, PatternVar, PatternWildcard, PatternTuple
from .parser import fromtext
__PRELUDE_PATH__ = os.path.dirname(os.path.realpath(__file__))
from .module import Module
self.zip = GlobalVar("zip")
a = TypeVar("a")
b = TypeVar("b")
- nil_case = Clause(PatternConstructor(self.nil), self.nil())
l1 = Var("l1")
l2 = Var("l2")
h1 = Var("h1")
h2 = Var("h2")
t1 = Var("t1")
t2 = Var("t2")
- inner_cons_case = Clause(PatternConstructor(self.cons, [PatternVar(h2), PatternVar(t2)]),
- self.cons(Tuple([h1, h2]), self.zip(t1, t2)))
- outer_cons_case = Clause(PatternConstructor(self.cons, [PatternVar(h1), PatternVar(t1)]),
- Match(l2, [nil_case, inner_cons_case]))
- self.mod[self.zip] = Function([l1, l2], Match(l1, [nil_case, outer_cons_case]),
+ cons_case = Clause(PatternTuple([PatternConstructor(self.cons,
+ [PatternVar(h1), PatternVar(t1)]),
+ PatternConstructor(self.cons,
+ [PatternVar(h2), PatternVar(t2)])]),
+ self.cons(Tuple([h1, h2]), self.zip(t1, t2)))
+ nil_case = Clause(PatternWildcard(), self.nil())
+ self.mod[self.zip] = Function([l1, l2], Match(Tuple([l1, l2]), [cons_case, nil_case]),
self.l(TupleType([a, b])), [a, b])
if isinstance(pattern, (relay.PatternWildcard, relay.PatternVar)):
return NameConstant(True)
- # constructor patterns check whether the constructors match
- # and also the matches of any nested patterns
+ conds = []
- # equiv: (arg.tag == patern_constructor.tag)
- conds = [ast.Compare(ast.Attribute(data, 'tag', Load()),
- [ast.Eq()],
- [ast.Num(pattern.constructor.tag)])]
+ if isinstance(pattern, relay.PatternConstructor):
+ # constructor patterns check whether the constructors match
+ # and also the matches of any nested patterns
+ # equiv: (arg.tag == patern_constructor.tag)
+ conds.append(ast.Compare(ast.Attribute(data, 'tag', Load()),
+ [ast.Eq()],
+ [ast.Num(pattern.constructor.tag)]))
+
+ assert isinstance(pattern, (relay.PatternConstructor, relay.PatternTuple))
# now check for any nested patterns
for i in range(len(pattern.patterns)):
nested_pat = pattern.patterns[i]
*/
/*!
- * Copyright (c) 2018 by Contributors
+ * Copyright (c) 2019 by Contributors
* \file src/tvm/relay/interpreter.cc
* \brief An interpreter for the Relay IR.
*/
return false;
}
+ bool VisitPattern_(const PatternTupleNode* op, const Value& v) final {
+ const TupleValueNode* tvn = v.as<TupleValueNode>();
+ CHECK(tvn) << "need to be a tuple for match";
+ CHECK_EQ(op->patterns.size(), tvn->fields.size());
+ for (size_t i = 0; i < op->patterns.size(); ++i) {
+ if (!VisitPattern(op->patterns[i], tvn->fields[i])) {
+ return false;
+ }
+ }
+ return true;
+ }
+
bool VisitPattern_(const PatternWildcardNode* op, const Value& v) final {
return true;
}
auto pattern = GetRef<PatternVar>(pat);
auto cond = std::make_shared<VarBinding>(pattern->var, data);
return TreeBranchNode::Make(cond, then_branch, else_branch);
- } else {
- auto pat = pattern.as<PatternConstructorNode>();
- auto pattern = GetRef<PatternConstructor>(pat);
- auto tag = pattern->constructor->tag;
+ } else if (auto pcn = pattern.as<PatternConstructorNode>()) {
+ auto tag = pcn->constructor->tag;
size_t field_index = 0;
- for (auto& p : pattern->patterns) {
+ for (auto& p : pcn->patterns) {
auto d = std::make_shared<AccessField>(data, field_index);
then_branch = BuildDecisionTreeFromPattern(d, p, then_branch, else_branch);
field_index++;
}
auto cond = std::make_shared<TagCompare>(data, tag);
return TreeBranchNode::Make(cond, then_branch, else_branch);
+ } else {
+ auto pt = pattern.as<PatternTupleNode>();
+ CHECK(pt) << "unhandled case: " << pattern;
+ size_t field_index = 0;
+ for (auto& p : pt->patterns) {
+ auto d = std::make_shared<AccessField>(data, field_index);
+ then_branch = BuildDecisionTreeFromPattern(d, p, then_branch, else_branch);
+ field_index++;
+ }
+ return then_branch;
}
}
<< ", " << node->patterns << ")";
});
+PatternTuple PatternTupleNode::make(tvm::Array<Pattern> patterns) {
+ NodePtr<PatternTupleNode> n = make_node<PatternTupleNode>();
+ n->patterns = std::move(patterns);
+ return PatternTuple(n);
+}
+
+TVM_REGISTER_NODE_TYPE(PatternTupleNode);
+
+TVM_REGISTER_API("relay._make.PatternTuple")
+.set_body_typed(PatternTupleNode::make);
+
+TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
+.set_dispatch<PatternTupleNode>([](const PatternTupleNode* node,
+ tvm::IRPrinter* p) {
+ p->stream << "PatternTupleNode(" << node->patterns << ")";
+});
+
Constructor ConstructorNode::make(std::string name_hint,
tvm::Array<Type> inputs,
GlobalTypeVar belong_to) {
}
bool PatternEqual(const Pattern& lhs, const Pattern& rhs) {
- return VisitPattern(lhs, rhs);
+ return Compare(VisitPattern(lhs, rhs), lhs, rhs);
}
bool VisitPattern_(const PatternWildcardNode* lhs, const Pattern& other) final {
return true;
}
+ bool VisitPattern_(const PatternTupleNode* lhs, const Pattern& other) final {
+ const auto* rhs = other.as<PatternTupleNode>();
+ if (rhs == nullptr
+ || lhs->patterns.size() != rhs->patterns.size()) {
+ return false;
+ }
+
+ for (size_t i = 0; i < lhs->patterns.size(); i++) {
+ if (!PatternEqual(lhs->patterns[i], rhs->patterns[i])) {
+ return false;
+ }
+ }
+ return true;
+ }
+
bool VisitExpr_(const MatchNode* lhs, const Expr& other) final {
const MatchNode* rhs = other.as<MatchNode>();
return hash;
}
+ size_t VisitPattern_(const PatternTupleNode* ptn) final {
+ size_t hash = std::hash<std::string>()(PatternTupleNode::_type_key);
+ for (const auto& p : ptn->patterns) {
+ hash = Combine(hash, PatternHash(p));
+ }
+ return hash;
+ }
+
size_t VisitPattern_(const PatternVarNode* pvn) final {
size_t hash = std::hash<std::string>()(PatternVarNode::_type_key);
hash = Combine(hash, BindVar(pvn->var));
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
- *
+ *
* http://www.apache.org/licenses/LICENSE-2.0
- *
+ *
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
*/
/*!
- * Copyright (c) 2018 by Contributors
- * \file src/tvm/relay/pattern_functor.cc
+ * Copyright (c) 2019 by Contributors
+ * \file src/relay/ir/pattern_functor.cc
* \brief Implementations of visitors and mutators for ADT patterns.
*/
return PatternConstructorNode::make(VisitConstructor(op->constructor), pat);
}
+Pattern PatternMutator::VisitPattern_(const PatternTupleNode* op) {
+ std::vector<Pattern> pat;
+ for (const auto& p : op->patterns) {
+ pat.push_back(VisitPattern(p));
+ }
+ return PatternTupleNode::make(pat);
+}
+
Type PatternMutator::VisitType(const Type& t) {
return t;
}
}
}
+void PatternVisitor::VisitPattern_(const PatternTupleNode* op) {
+ for (const auto& p : op->patterns) {
+ VisitPattern(p);
+ }
+}
+
void PatternVisitor::VisitType(const Type& t) { }
void PatternVisitor::VisitVar(const Var& v) {
}
// now check that subpatterns match
- CHECK(op->patterns.size() == ctor_cand->patterns.size());
+ CHECK_EQ(op->patterns.size(), ctor_cand->patterns.size());
bool unspecified = false;
for (size_t i = 0; i < op->patterns.size(); i++) {
MatchResult submatch = this->Check(op->patterns[i], ctor_cand->patterns[i]);
return MatchResult::kMatch;
}
+ MatchResult VisitPattern_(const PatternTupleNode* op, const Pattern& cand) override {
+ auto* tuple_cand = cand.as<PatternTupleNode>();
+ // attempting to match non-tuple to constructor pattern: need to specify
+ if (tuple_cand == nullptr) {
+ return MatchResult::kUnspecified;
+ }
+
+ // now check that subpatterns match
+ CHECK_EQ(op->patterns.size(), tuple_cand->patterns.size());
+ bool unspecified = false;
+ for (size_t i = 0; i < op->patterns.size(); i++) {
+ MatchResult submatch = this->Check(op->patterns[i], tuple_cand->patterns[i]);
+ // if we have a clash anywhere, then we can return clash
+ if (submatch == MatchResult::kClash) {
+ return MatchResult::kClash;
+ }
+ if (submatch == MatchResult::kUnspecified) {
+ unspecified = true;
+ }
+ }
+ // only return unspecified if we have ruled out a clash
+ if (unspecified) {
+ return MatchResult::kUnspecified;
+ }
+ return MatchResult::kMatch;
+ }
+
// wildcard and var patterns always match
MatchResult VisitPattern_(const PatternWildcardNode*, const Pattern&) override {
return MatchResult::kMatch;
return ret;
}
-// Expands all wildcards in the candidate pattern once, using the pattern
-// to decide which constructors to insert. Returns a list of all possible expansions.
-Array<Pattern> ExpandWildcards(const Pattern& clause_pat, const Pattern& cand,
+Array<Pattern> ExpandWildcardsConstructor(const PatternConstructor& clause_ctor,
+ const Pattern& cand,
+ const Module& mod);
+
+Array<Pattern> ExpandWildcardsTuple(const PatternTuple& clause_tuple,
+ const Pattern& cand,
+ const Module& mod);
+
+// Expands all wildcards in the candidate pattern once
+// Returns a list of all possible expansions.
+Array<Pattern> ExpandWildcards(const Pattern& clause_pat,
+ const Pattern& cand,
const Module& mod) {
- auto ctor_cand = cand.as<PatternConstructorNode>();
- PatternConstructor clause_ctor = Downcast<PatternConstructor>(clause_pat);
+ if (auto clause_ctor = clause_pat.as<PatternConstructorNode>()) {
+ return ExpandWildcardsConstructor(GetRef<PatternConstructor>(clause_ctor), cand, mod);
+ } else {
+ return ExpandWildcardsTuple(Downcast<PatternTuple>(clause_pat), cand, mod);
+ }
+}
+
+// Expands all wildcards in the candidate pattern once.
+// Use the pattern to decide which constructors to insert.
+// Returns a list of all possible expansions.
+Array<Pattern> ExpandWildcardsConstructor(const PatternConstructor& clause_ctor,
+ const Pattern& cand,
+ const Module& mod) {
auto gtv = Downcast<GlobalTypeVar>(clause_ctor->constructor->belong_to);
- // for a wildcard node, create constructor nodes with wildcards for all args
- if (!ctor_cand) {
+ // for a wildcard node, create constructor nodes with wildcards for all args.
+ if (cand.as<PatternWildcardNode>()) {
TypeData td = mod->LookupDef(gtv);
- // for each constructor add a candidate
+ // for each constructor add a candidate.
Array<Pattern> ret;
for (auto constructor : td->constructors) {
Array<Pattern> args;
return ret;
}
- // for constructors, we will expand the wildcards in any field
- // that is an ADT
+ auto ctor_cand = Downcast<PatternConstructor>(cand);
+
+ // for constructors, we will expand the wildcards in any field that is an ADT.
Array<Array<Pattern>> values_by_field;
for (size_t i = 0; i < ctor_cand->constructor->inputs.size(); i++) {
- auto* subpattern = clause_ctor->patterns[i].as<PatternConstructorNode>();
- // for non-ADT fields, we can only have a wildcard for the value
+ bool subpattern =
+ clause_ctor->patterns[i].as<PatternConstructorNode>() ||
+ clause_ctor->patterns[i].as<PatternTupleNode>();
+ // for non-ADT fields, we can only have a wildcard for the value.
if (!subpattern) {
values_by_field.push_back({PatternWildcardNode::make()});
- continue;
+ } else {
+ // otherwise, recursively expand.
+ values_by_field.push_back(ExpandWildcards(clause_ctor->patterns[i],
+ ctor_cand->patterns[i],
+ mod));
}
+ }
- // otherwise, recursively expand
- values_by_field.push_back(ExpandWildcards(GetRef<Pattern>(subpattern),
- ctor_cand->patterns[i], mod));
+ // generate new candidates using a cartesian product.
+ auto all_subfields = CartesianProduct(values_by_field);
+ Array<Pattern> ret;
+ for (auto subfields : all_subfields) {
+ ret.push_back(PatternConstructorNode::make(ctor_cand->constructor, subfields));
+ }
+ return ret;
+}
+
+// Expands all wildcards in the candidate pattern once.
+// Returns a list of all possible expansions.
+Array<Pattern> ExpandWildcardsTuple(const PatternTuple& clause_tuple,
+ const Pattern& cand,
+ const Module& mod) {
+ // for a wildcard node, create constructor nodes with wildcards for all args.
+ if (cand.as<PatternWildcardNode>()) {
+ Array<Pattern> args;
+ for (auto inp : clause_tuple->patterns) {
+ args.push_back(PatternWildcardNode::make());
+ }
+ return {PatternTupleNode::make(args)};
+ }
+
+ auto tuple_cand = Downcast<PatternTuple>(cand);
+
+ // for constructors, we will expand the wildcards in any field that is an ADT.
+ Array<Array<Pattern>> values_by_field;
+ for (size_t i = 0; i < tuple_cand->patterns.size(); i++) {
+ bool subpattern =
+ clause_tuple->patterns[i].as<PatternConstructorNode>() ||
+ clause_tuple->patterns[i].as<PatternTupleNode>();
+ // for non-ADT fields, we can only have a wildcard for the value
+ if (!subpattern) {
+ values_by_field.push_back({PatternWildcardNode::make()});
+ } else {
+ // otherwise, recursively expand
+ values_by_field.push_back(ExpandWildcards(clause_tuple->patterns[i],
+ tuple_cand->patterns[i],
+ mod));
+ }
}
// generate new candidates using a cartesian product
auto all_subfields = CartesianProduct(values_by_field);
Array<Pattern> ret;
for (auto subfields : all_subfields) {
- ret.push_back(PatternConstructorNode::make(ctor_cand->constructor, subfields));
+ ret.push_back(PatternTupleNode::make(subfields));
}
return ret;
}
}
}
+ MatchStatus VisitPattern_(const PatternTupleNode* op, const PStatic& ps) final {
+ if (ps->pstatic.defined()) {
+ STuple stn = Downcast<STuple>(ps->pstatic);
+ CHECK_EQ(op->patterns.size(), stn->fields.size());
+ MatchStatus current_match_status = MatchStatus::Match;
+ for (size_t i = 0; i < op->patterns.size(); ++i) {
+ MatchStatus ms = VisitPattern(op->patterns[i], stn->fields[i]);
+ switch (ms) {
+ case MatchStatus::Match:
+ continue;
+ case MatchStatus::NoMatch:
+ return MatchStatus::NoMatch;
+ case MatchStatus::Unknown:
+ current_match_status = MatchStatus::Unknown;
+ }
+ }
+ return current_match_status;
+ } else {
+ return MatchStatus::Unknown;
+ }
+ }
+
void InitializeFuncId(const Expr& e) {
struct InitializeFuncIdVisitor : ExprVisitor, PatternVisitor {
PartialEvaluator* pe;
}
}
+ void VisitPattern_(const PatternTupleNode* tup, const Type& t) {
+ auto pt = GetRef<PatternTuple>(tup);
+
+ // we can expect a certain number of arguments
+ Array<Type> unknown_args;
+ for (size_t i = 0; i < tup->patterns.size(); i++) {
+ unknown_args.push_back(IncompleteTypeNode::make(Kind::kType));
+ }
+ Type expected = TupleTypeNode::make(unknown_args);
+ Type unified = Unify(t, expected, GetRef<NodeRef>(tup));
+
+ auto* tt = unified.as<TupleTypeNode>();
+ if (!tt) {
+ this->ReportFatalError(pt, RELAY_ERROR("Expected a tuple type, got " << unified));
+ }
+ CHECK(tup->patterns.size() == tt->fields.size()) << "not enough pattern";
+ for (size_t i = 0; i < tup->patterns.size(); ++i) {
+ VisitPattern(tup->patterns[i], tt->fields[i]);
+ }
+ }
+
void VisitPattern_(const PatternVarNode* pv, const Type& t) {
Type vt = GetType(pv->var);
Unify(vt, t, pv->span);
assert not analysis.structural_hash(func1) == analysis.structural_hash(func3)
+
+def test_tuple_match():
+ a = relay.Var("a")
+ b = relay.Var("b")
+ clause = relay.Clause(relay.PatternTuple([relay.PatternVar(a), relay.PatternVar(b)]), a + b)
+ x = relay.Match(relay.Tuple([relay.const(1), relay.const(1)]), [clause])
+
+ a = relay.Var("a")
+ b = relay.Var("b")
+ clause = relay.Clause(relay.PatternTuple([relay.PatternVar(a), relay.PatternVar(b)]), a + b)
+ y = relay.Match(relay.Tuple([relay.const(1), relay.const(1)]), [clause])
+ assert analysis.alpha_equal(x, y)
+ assert analysis.structural_hash(x) == analysis.structural_hash(y)
+
+
if __name__ == "__main__":
test_tensor_type_alpha_equal()
test_incomplete_type_alpha_equal()
transform.PartialEvaluate()(m)
+def test_tuple_match():
+ a = relay.Var("a")
+ b = relay.Var("b")
+ clause = relay.Clause(relay.PatternTuple([relay.PatternVar(a), relay.PatternVar(b)]), a + b)
+ x = relay.Match(relay.Tuple([relay.const(1), relay.const(1)]), [clause])
+ assert_alpha_equal(dcpe(x), const(2))
+
+
if __name__ == '__main__':
test_nat_update()
test_ref()
test_match_nat_id()
test_concat()
test_triangle_number()
+ test_tuple_match()
relay.Clause(relay.PatternConstructor(p.nil, []), v)
])
assert len(unmatched_cases(list_of_boxes_comp, mod)) == 0
+
+
+def test_tuple_match():
+ a = relay.Var("a")
+ b = relay.Var("b")
+ clause = relay.Clause(relay.PatternTuple([relay.PatternVar(a), relay.PatternVar(b)]), a + b)
+ x = relay.Match(relay.Tuple([relay.const(1), relay.const(1)]), [clause])
+ assert len(unmatched_cases(x)) == 0