From: 雾雨魔理沙 Date: Thu, 5 Sep 2019 23:41:44 +0000 (-0700) Subject: [Relay] add Tuple pattern (#3596) X-Git-Tag: upstream/0.7.0~1949 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=08d92203f794eebcb159f8cb7309c7d769aaf813;p=platform%2Fupstream%2Ftvm.git [Relay] add Tuple pattern (#3596) * implement tuple pattern * add tuple pattern * lint; * lint * lint * fix error * fix * add test --- diff --git a/include/tvm/relay/adt.h b/include/tvm/relay/adt.h index b634361..0a5adfa 100644 --- a/include/tvm/relay/adt.h +++ b/include/tvm/relay/adt.h @@ -163,6 +163,29 @@ class PatternConstructorNode : public PatternNode { RELAY_DEFINE_NODE_REF(PatternConstructor, PatternConstructorNode, Pattern); +/*! \brief A tuple pattern. Matches a tuple, binds recursively. */ +class PatternTuple; +/*! \brief PatternVar container node */ +class PatternTupleNode : public PatternNode { + public: + /*! Sub-patterns to match against each value of the tuple. */ + tvm::Array patterns; + + PatternTupleNode() {} + + TVM_DLL static PatternTuple make(tvm::Array var); + + void VisitAttrs(tvm::AttrVisitor* v) final { + v->Visit("patterns", &patterns); + v->Visit("span", &span); + } + + static constexpr const char* _type_key = "relay.PatternTuple"; + TVM_DECLARE_NODE_TYPE_INFO(PatternTupleNode, PatternNode); +}; + +RELAY_DEFINE_NODE_REF(PatternTuple, PatternTupleNode, Pattern); + /*! * \brief Stores all data for an Algebraic Data Type (ADT). * diff --git a/include/tvm/relay/pattern_functor.h b/include/tvm/relay/pattern_functor.h index 0ced3ea..611b743 100644 --- a/include/tvm/relay/pattern_functor.h +++ b/include/tvm/relay/pattern_functor.h @@ -100,6 +100,8 @@ class PatternFunctor { Args... args) PATTERN_FUNCTOR_DEFAULT; virtual R VisitPattern_(const PatternConstructorNode* op, Args... args) PATTERN_FUNCTOR_DEFAULT; + virtual R VisitPattern_(const PatternTupleNode* op, + Args... args) PATTERN_FUNCTOR_DEFAULT; virtual R VisitPatternDefault_(const Node* op, Args...) { throw Error(std::string("Do not have a default for ") + op->type_key()); } @@ -112,6 +114,7 @@ class PatternFunctor { RELAY_PATTERN_FUNCTOR_DISPATCH(PatternWildcardNode); RELAY_PATTERN_FUNCTOR_DISPATCH(PatternVarNode); RELAY_PATTERN_FUNCTOR_DISPATCH(PatternConstructorNode); + RELAY_PATTERN_FUNCTOR_DISPATCH(PatternTupleNode); return vtable; } }; @@ -127,6 +130,7 @@ class PatternVisitor : public ::tvm::relay::PatternFunctor(); + CHECK(tvn) << "need to be a tuple for match"; + CHECK_EQ(op->patterns.size(), tvn->fields.size()); + for (size_t i = 0; i < op->patterns.size(); ++i) { + if (!VisitPattern(op->patterns[i], tvn->fields[i])) { + return false; + } + } + return true; + } + bool VisitPattern_(const PatternWildcardNode* op, const Value& v) final { return true; } diff --git a/src/relay/backend/vm/compiler.cc b/src/relay/backend/vm/compiler.cc index 5e5bc1a..39508d2 100644 --- a/src/relay/backend/vm/compiler.cc +++ b/src/relay/backend/vm/compiler.cc @@ -152,19 +152,27 @@ TreeNodePtr BuildDecisionTreeFromPattern(MatchValuePtr data, auto pattern = GetRef(pat); auto cond = std::make_shared(pattern->var, data); return TreeBranchNode::Make(cond, then_branch, else_branch); - } else { - auto pat = pattern.as(); - auto pattern = GetRef(pat); - auto tag = pattern->constructor->tag; + } else if (auto pcn = pattern.as()) { + auto tag = pcn->constructor->tag; size_t field_index = 0; - for (auto& p : pattern->patterns) { + for (auto& p : pcn->patterns) { auto d = std::make_shared(data, field_index); then_branch = BuildDecisionTreeFromPattern(d, p, then_branch, else_branch); field_index++; } auto cond = std::make_shared(data, tag); return TreeBranchNode::Make(cond, then_branch, else_branch); + } else { + auto pt = pattern.as(); + CHECK(pt) << "unhandled case: " << pattern; + size_t field_index = 0; + for (auto& p : pt->patterns) { + auto d = std::make_shared(data, field_index); + then_branch = BuildDecisionTreeFromPattern(d, p, then_branch, else_branch); + field_index++; + } + return then_branch; } } diff --git a/src/relay/ir/adt.cc b/src/relay/ir/adt.cc index b17bf41..9c670bf 100644 --- a/src/relay/ir/adt.cc +++ b/src/relay/ir/adt.cc @@ -81,6 +81,23 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) << ", " << node->patterns << ")"; }); +PatternTuple PatternTupleNode::make(tvm::Array patterns) { + NodePtr n = make_node(); + n->patterns = std::move(patterns); + return PatternTuple(n); +} + +TVM_REGISTER_NODE_TYPE(PatternTupleNode); + +TVM_REGISTER_API("relay._make.PatternTuple") +.set_body_typed(PatternTupleNode::make); + +TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) +.set_dispatch([](const PatternTupleNode* node, + tvm::IRPrinter* p) { + p->stream << "PatternTupleNode(" << node->patterns << ")"; +}); + Constructor ConstructorNode::make(std::string name_hint, tvm::Array inputs, GlobalTypeVar belong_to) { diff --git a/src/relay/ir/alpha_equal.cc b/src/relay/ir/alpha_equal.cc index 2c23f0f..515db37 100644 --- a/src/relay/ir/alpha_equal.cc +++ b/src/relay/ir/alpha_equal.cc @@ -493,7 +493,7 @@ class AlphaEqualHandler: } bool PatternEqual(const Pattern& lhs, const Pattern& rhs) { - return VisitPattern(lhs, rhs); + return Compare(VisitPattern(lhs, rhs), lhs, rhs); } bool VisitPattern_(const PatternWildcardNode* lhs, const Pattern& other) final { @@ -523,6 +523,21 @@ class AlphaEqualHandler: return true; } + bool VisitPattern_(const PatternTupleNode* lhs, const Pattern& other) final { + const auto* rhs = other.as(); + if (rhs == nullptr + || lhs->patterns.size() != rhs->patterns.size()) { + return false; + } + + for (size_t i = 0; i < lhs->patterns.size(); i++) { + if (!PatternEqual(lhs->patterns[i], rhs->patterns[i])) { + return false; + } + } + return true; + } + bool VisitExpr_(const MatchNode* lhs, const Expr& other) final { const MatchNode* rhs = other.as(); diff --git a/src/relay/ir/hash.cc b/src/relay/ir/hash.cc index 9287f38..d392533 100644 --- a/src/relay/ir/hash.cc +++ b/src/relay/ir/hash.cc @@ -389,6 +389,14 @@ class RelayHashHandler: return hash; } + size_t VisitPattern_(const PatternTupleNode* ptn) final { + size_t hash = std::hash()(PatternTupleNode::_type_key); + for (const auto& p : ptn->patterns) { + hash = Combine(hash, PatternHash(p)); + } + return hash; + } + size_t VisitPattern_(const PatternVarNode* pvn) final { size_t hash = std::hash()(PatternVarNode::_type_key); hash = Combine(hash, BindVar(pvn->var)); diff --git a/src/relay/ir/pattern_functor.cc b/src/relay/ir/pattern_functor.cc index 29171bf..5095373 100644 --- a/src/relay/ir/pattern_functor.cc +++ b/src/relay/ir/pattern_functor.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -18,8 +18,8 @@ */ /*! - * Copyright (c) 2018 by Contributors - * \file src/tvm/relay/pattern_functor.cc + * Copyright (c) 2019 by Contributors + * \file src/relay/ir/pattern_functor.cc * \brief Implementations of visitors and mutators for ADT patterns. */ @@ -48,6 +48,14 @@ Pattern PatternMutator::VisitPattern_(const PatternConstructorNode* op) { return PatternConstructorNode::make(VisitConstructor(op->constructor), pat); } +Pattern PatternMutator::VisitPattern_(const PatternTupleNode* op) { + std::vector pat; + for (const auto& p : op->patterns) { + pat.push_back(VisitPattern(p)); + } + return PatternTupleNode::make(pat); +} + Type PatternMutator::VisitType(const Type& t) { return t; } @@ -78,6 +86,12 @@ void PatternVisitor::VisitPattern_(const PatternConstructorNode* op) { } } +void PatternVisitor::VisitPattern_(const PatternTupleNode* op) { + for (const auto& p : op->patterns) { + VisitPattern(p); + } +} + void PatternVisitor::VisitType(const Type& t) { } void PatternVisitor::VisitVar(const Var& v) { diff --git a/src/relay/pass/match_exhaustion.cc b/src/relay/pass/match_exhaustion.cc index cc00a54..07331a1 100644 --- a/src/relay/pass/match_exhaustion.cc +++ b/src/relay/pass/match_exhaustion.cc @@ -68,7 +68,7 @@ class CandidateChecker : public PatternFunctorpatterns.size() == ctor_cand->patterns.size()); + CHECK_EQ(op->patterns.size(), ctor_cand->patterns.size()); bool unspecified = false; for (size_t i = 0; i < op->patterns.size(); i++) { MatchResult submatch = this->Check(op->patterns[i], ctor_cand->patterns[i]); @@ -87,6 +87,33 @@ class CandidateChecker : public PatternFunctor(); + // attempting to match non-tuple to constructor pattern: need to specify + if (tuple_cand == nullptr) { + return MatchResult::kUnspecified; + } + + // now check that subpatterns match + CHECK_EQ(op->patterns.size(), tuple_cand->patterns.size()); + bool unspecified = false; + for (size_t i = 0; i < op->patterns.size(); i++) { + MatchResult submatch = this->Check(op->patterns[i], tuple_cand->patterns[i]); + // if we have a clash anywhere, then we can return clash + if (submatch == MatchResult::kClash) { + return MatchResult::kClash; + } + if (submatch == MatchResult::kUnspecified) { + unspecified = true; + } + } + // only return unspecified if we have ruled out a clash + if (unspecified) { + return MatchResult::kUnspecified; + } + return MatchResult::kMatch; + } + // wildcard and var patterns always match MatchResult VisitPattern_(const PatternWildcardNode*, const Pattern&) override { return MatchResult::kMatch; @@ -127,18 +154,38 @@ Array> CartesianProduct(Array> fields) { return ret; } -// Expands all wildcards in the candidate pattern once, using the pattern -// to decide which constructors to insert. Returns a list of all possible expansions. -Array ExpandWildcards(const Pattern& clause_pat, const Pattern& cand, +Array ExpandWildcardsConstructor(const PatternConstructor& clause_ctor, + const Pattern& cand, + const Module& mod); + +Array ExpandWildcardsTuple(const PatternTuple& clause_tuple, + const Pattern& cand, + const Module& mod); + +// Expands all wildcards in the candidate pattern once +// Returns a list of all possible expansions. +Array ExpandWildcards(const Pattern& clause_pat, + const Pattern& cand, const Module& mod) { - auto ctor_cand = cand.as(); - PatternConstructor clause_ctor = Downcast(clause_pat); + if (auto clause_ctor = clause_pat.as()) { + return ExpandWildcardsConstructor(GetRef(clause_ctor), cand, mod); + } else { + return ExpandWildcardsTuple(Downcast(clause_pat), cand, mod); + } +} + +// Expands all wildcards in the candidate pattern once. +// Use the pattern to decide which constructors to insert. +// Returns a list of all possible expansions. +Array ExpandWildcardsConstructor(const PatternConstructor& clause_ctor, + const Pattern& cand, + const Module& mod) { auto gtv = Downcast(clause_ctor->constructor->belong_to); - // for a wildcard node, create constructor nodes with wildcards for all args - if (!ctor_cand) { + // for a wildcard node, create constructor nodes with wildcards for all args. + if (cand.as()) { TypeData td = mod->LookupDef(gtv); - // for each constructor add a candidate + // for each constructor add a candidate. Array ret; for (auto constructor : td->constructors) { Array args; @@ -150,27 +197,72 @@ Array ExpandWildcards(const Pattern& clause_pat, const Pattern& cand, return ret; } - // for constructors, we will expand the wildcards in any field - // that is an ADT + auto ctor_cand = Downcast(cand); + + // for constructors, we will expand the wildcards in any field that is an ADT. Array> values_by_field; for (size_t i = 0; i < ctor_cand->constructor->inputs.size(); i++) { - auto* subpattern = clause_ctor->patterns[i].as(); - // for non-ADT fields, we can only have a wildcard for the value + bool subpattern = + clause_ctor->patterns[i].as() || + clause_ctor->patterns[i].as(); + // for non-ADT fields, we can only have a wildcard for the value. if (!subpattern) { values_by_field.push_back({PatternWildcardNode::make()}); - continue; + } else { + // otherwise, recursively expand. + values_by_field.push_back(ExpandWildcards(clause_ctor->patterns[i], + ctor_cand->patterns[i], + mod)); } + } - // otherwise, recursively expand - values_by_field.push_back(ExpandWildcards(GetRef(subpattern), - ctor_cand->patterns[i], mod)); + // generate new candidates using a cartesian product. + auto all_subfields = CartesianProduct(values_by_field); + Array ret; + for (auto subfields : all_subfields) { + ret.push_back(PatternConstructorNode::make(ctor_cand->constructor, subfields)); + } + return ret; +} + +// Expands all wildcards in the candidate pattern once. +// Returns a list of all possible expansions. +Array ExpandWildcardsTuple(const PatternTuple& clause_tuple, + const Pattern& cand, + const Module& mod) { + // for a wildcard node, create constructor nodes with wildcards for all args. + if (cand.as()) { + Array args; + for (auto inp : clause_tuple->patterns) { + args.push_back(PatternWildcardNode::make()); + } + return {PatternTupleNode::make(args)}; + } + + auto tuple_cand = Downcast(cand); + + // for constructors, we will expand the wildcards in any field that is an ADT. + Array> values_by_field; + for (size_t i = 0; i < tuple_cand->patterns.size(); i++) { + bool subpattern = + clause_tuple->patterns[i].as() || + clause_tuple->patterns[i].as(); + // for non-ADT fields, we can only have a wildcard for the value + if (!subpattern) { + values_by_field.push_back({PatternWildcardNode::make()}); + } else { + // otherwise, recursively expand + values_by_field.push_back(ExpandWildcards(clause_tuple->patterns[i], + tuple_cand->patterns[i], + mod)); + } } // generate new candidates using a cartesian product auto all_subfields = CartesianProduct(values_by_field); Array ret; for (auto subfields : all_subfields) { - ret.push_back(PatternConstructorNode::make(ctor_cand->constructor, subfields)); + ret.push_back(PatternTupleNode::make(subfields)); } return ret; } diff --git a/src/relay/pass/partial_eval.cc b/src/relay/pass/partial_eval.cc index 3f92d7a..906d245 100644 --- a/src/relay/pass/partial_eval.cc +++ b/src/relay/pass/partial_eval.cc @@ -1051,6 +1051,28 @@ class PartialEvaluator : public ExprFunctor } } + MatchStatus VisitPattern_(const PatternTupleNode* op, const PStatic& ps) final { + if (ps->pstatic.defined()) { + STuple stn = Downcast(ps->pstatic); + CHECK_EQ(op->patterns.size(), stn->fields.size()); + MatchStatus current_match_status = MatchStatus::Match; + for (size_t i = 0; i < op->patterns.size(); ++i) { + MatchStatus ms = VisitPattern(op->patterns[i], stn->fields[i]); + switch (ms) { + case MatchStatus::Match: + continue; + case MatchStatus::NoMatch: + return MatchStatus::NoMatch; + case MatchStatus::Unknown: + current_match_status = MatchStatus::Unknown; + } + } + return current_match_status; + } else { + return MatchStatus::Unknown; + } + } + void InitializeFuncId(const Expr& e) { struct InitializeFuncIdVisitor : ExprVisitor, PatternVisitor { PartialEvaluator* pe; diff --git a/src/relay/pass/type_infer.cc b/src/relay/pass/type_infer.cc index 1d10651..d0f1b7a 100644 --- a/src/relay/pass/type_infer.cc +++ b/src/relay/pass/type_infer.cc @@ -276,6 +276,27 @@ class TypeInferencer : private ExprFunctor, } } + void VisitPattern_(const PatternTupleNode* tup, const Type& t) { + auto pt = GetRef(tup); + + // we can expect a certain number of arguments + Array unknown_args; + for (size_t i = 0; i < tup->patterns.size(); i++) { + unknown_args.push_back(IncompleteTypeNode::make(Kind::kType)); + } + Type expected = TupleTypeNode::make(unknown_args); + Type unified = Unify(t, expected, GetRef(tup)); + + auto* tt = unified.as(); + if (!tt) { + this->ReportFatalError(pt, RELAY_ERROR("Expected a tuple type, got " << unified)); + } + CHECK(tup->patterns.size() == tt->fields.size()) << "not enough pattern"; + for (size_t i = 0; i < tup->patterns.size(); ++i) { + VisitPattern(tup->patterns[i], tt->fields[i]); + } + } + void VisitPattern_(const PatternVarNode* pv, const Type& t) { Type vt = GetType(pv->var); Unify(vt, t, pv->span); diff --git a/tests/python/relay/test_pass_alpha_equal.py b/tests/python/relay/test_pass_alpha_equal.py index 9d07fb1..b240daf 100644 --- a/tests/python/relay/test_pass_alpha_equal.py +++ b/tests/python/relay/test_pass_alpha_equal.py @@ -611,6 +611,21 @@ def test_hash_unequal(): assert not analysis.structural_hash(func1) == analysis.structural_hash(func3) + +def test_tuple_match(): + a = relay.Var("a") + b = relay.Var("b") + clause = relay.Clause(relay.PatternTuple([relay.PatternVar(a), relay.PatternVar(b)]), a + b) + x = relay.Match(relay.Tuple([relay.const(1), relay.const(1)]), [clause]) + + a = relay.Var("a") + b = relay.Var("b") + clause = relay.Clause(relay.PatternTuple([relay.PatternVar(a), relay.PatternVar(b)]), a + b) + y = relay.Match(relay.Tuple([relay.const(1), relay.const(1)]), [clause]) + assert analysis.alpha_equal(x, y) + assert analysis.structural_hash(x) == analysis.structural_hash(y) + + if __name__ == "__main__": test_tensor_type_alpha_equal() test_incomplete_type_alpha_equal() diff --git a/tests/python/relay/test_pass_partial_eval.py b/tests/python/relay/test_pass_partial_eval.py index f914f18..cf4f8f6 100644 --- a/tests/python/relay/test_pass_partial_eval.py +++ b/tests/python/relay/test_pass_partial_eval.py @@ -331,6 +331,14 @@ def test_nat_update(): transform.PartialEvaluate()(m) +def test_tuple_match(): + a = relay.Var("a") + b = relay.Var("b") + clause = relay.Clause(relay.PatternTuple([relay.PatternVar(a), relay.PatternVar(b)]), a + b) + x = relay.Match(relay.Tuple([relay.const(1), relay.const(1)]), [clause]) + assert_alpha_equal(dcpe(x), const(2)) + + if __name__ == '__main__': test_nat_update() test_ref() @@ -351,3 +359,4 @@ if __name__ == '__main__': test_match_nat_id() test_concat() test_triangle_number() + test_tuple_match() diff --git a/tests/python/relay/test_pass_unmatched_cases.py b/tests/python/relay/test_pass_unmatched_cases.py index 776f5a0..b06de4c 100644 --- a/tests/python/relay/test_pass_unmatched_cases.py +++ b/tests/python/relay/test_pass_unmatched_cases.py @@ -265,3 +265,11 @@ def test_mixed_adt_constructors(): relay.Clause(relay.PatternConstructor(p.nil, []), v) ]) assert len(unmatched_cases(list_of_boxes_comp, mod)) == 0 + + +def test_tuple_match(): + a = relay.Var("a") + b = relay.Var("b") + clause = relay.Clause(relay.PatternTuple([relay.PatternVar(a), relay.PatternVar(b)]), a + b) + x = relay.Match(relay.Tuple([relay.const(1), relay.const(1)]), [clause]) + assert len(unmatched_cases(x)) == 0