[Relay] add Tuple pattern (#3596)
author雾雨魔理沙 <lolisa@marisa.moe>
Thu, 5 Sep 2019 23:41:44 +0000 (16:41 -0700)
committerJared Roesch <roeschinc@gmail.com>
Thu, 5 Sep 2019 23:41:44 +0000 (16:41 -0700)
* implement tuple pattern

* add tuple pattern

* lint;

* lint

* lint

* fix error

* fix

* add test

18 files changed:
include/tvm/relay/adt.h
include/tvm/relay/pattern_functor.h
python/tvm/relay/__init__.py
python/tvm/relay/adt.py
python/tvm/relay/prelude.py
python/tvm/relay/testing/py_converter.py
src/relay/backend/interpreter.cc
src/relay/backend/vm/compiler.cc
src/relay/ir/adt.cc
src/relay/ir/alpha_equal.cc
src/relay/ir/hash.cc
src/relay/ir/pattern_functor.cc
src/relay/pass/match_exhaustion.cc
src/relay/pass/partial_eval.cc
src/relay/pass/type_infer.cc
tests/python/relay/test_pass_alpha_equal.py
tests/python/relay/test_pass_partial_eval.py
tests/python/relay/test_pass_unmatched_cases.py

index b634361..0a5adfa 100644 (file)
@@ -163,6 +163,29 @@ class PatternConstructorNode : public PatternNode {
 
 RELAY_DEFINE_NODE_REF(PatternConstructor, PatternConstructorNode, Pattern);
 
+/*! \brief A tuple pattern. Matches a tuple, binds recursively. */
+class PatternTuple;
+/*! \brief PatternVar container node */
+class PatternTupleNode : public PatternNode {
+ public:
+  /*! Sub-patterns to match against each value of the tuple. */
+  tvm::Array<Pattern> patterns;
+
+  PatternTupleNode() {}
+
+  TVM_DLL static PatternTuple make(tvm::Array<Pattern> var);
+
+  void VisitAttrs(tvm::AttrVisitor* v) final {
+    v->Visit("patterns", &patterns);
+    v->Visit("span", &span);
+  }
+
+  static constexpr const char* _type_key = "relay.PatternTuple";
+  TVM_DECLARE_NODE_TYPE_INFO(PatternTupleNode, PatternNode);
+};
+
+RELAY_DEFINE_NODE_REF(PatternTuple, PatternTupleNode, Pattern);
+
 /*!
  * \brief Stores all data for an Algebraic Data Type (ADT).
  *
index 0ced3ea..611b743 100644 (file)
@@ -100,6 +100,8 @@ class PatternFunctor<R(const Pattern& n, Args...)> {
                           Args... args) PATTERN_FUNCTOR_DEFAULT;
   virtual R VisitPattern_(const PatternConstructorNode* op,
                           Args... args) PATTERN_FUNCTOR_DEFAULT;
+  virtual R VisitPattern_(const PatternTupleNode* op,
+                          Args... args) PATTERN_FUNCTOR_DEFAULT;
   virtual R VisitPatternDefault_(const Node* op, Args...) {
     throw Error(std::string("Do not have a default for ") + op->type_key());
   }
@@ -112,6 +114,7 @@ class PatternFunctor<R(const Pattern& n, Args...)> {
     RELAY_PATTERN_FUNCTOR_DISPATCH(PatternWildcardNode);
     RELAY_PATTERN_FUNCTOR_DISPATCH(PatternVarNode);
     RELAY_PATTERN_FUNCTOR_DISPATCH(PatternConstructorNode);
+    RELAY_PATTERN_FUNCTOR_DISPATCH(PatternTupleNode);
     return vtable;
   }
 };
@@ -127,6 +130,7 @@ class PatternVisitor : public ::tvm::relay::PatternFunctor<void(const Pattern& n
   void VisitPattern_(const PatternWildcardNode* op) override;
   void VisitPattern_(const PatternVarNode* op) override;
   void VisitPattern_(const PatternConstructorNode* op) override;
+  void VisitPattern_(const PatternTupleNode* op) override;
   virtual void VisitType(const Type& t);
   virtual void VisitVar(const Var& v);
   virtual void VisitConstructor(const Constructor& c);
@@ -144,6 +148,7 @@ class PatternMutator
   Pattern VisitPattern_(const PatternWildcardNode* op) override;
   Pattern VisitPattern_(const PatternVarNode* op) override;
   Pattern VisitPattern_(const PatternConstructorNode* op) override;
+  Pattern VisitPattern_(const PatternTupleNode* op) override;
   /*! \brief Used to visit the types inside of patterns.
    *
    * Can be overloaded to transform the types in arbitrary
index b56ef65..092cd01 100644 (file)
@@ -105,6 +105,7 @@ RefWrite = expr.RefWrite
 PatternWildcard = adt.PatternWildcard
 PatternVar = adt.PatternVar
 PatternConstructor = adt.PatternConstructor
+PatternTuple = adt.PatternTuple
 Constructor = adt.Constructor
 TypeData = adt.TypeData
 Clause = adt.Clause
index 0b1edc9..30db22c 100644 (file)
@@ -90,6 +90,29 @@ class PatternConstructor(Pattern):
 
 
 @register_relay_node
+class PatternTuple(Pattern):
+    """Constructor pattern in Relay: Matches a tuple, binds recursively."""
+
+    def __init__(self, patterns=None):
+        """Construct a tuple pattern.
+
+        Parameters
+        ----------
+        patterns: Optional[List[Pattern]]
+            Optional subpatterns: for each field of the constructor,
+            match to the given subpattern (treated as a variable pattern by default).
+
+        Returns
+        -------
+        wildcard: PatternWildcard
+            a wildcard pattern.
+        """
+        if patterns is None:
+            patterns = []
+        self.__init_handle_by_constructor__(_make.PatternTuple, patterns)
+
+
+@register_relay_node
 class Constructor(Expr):
     """Relay ADT constructor."""
 
index b5eac75..f9a7d3d 100644 (file)
@@ -21,7 +21,7 @@ from .ty import GlobalTypeVar, TypeVar, FuncType, TupleType, scalar_type
 from .expr import Var, Function, GlobalVar, Let, If, Tuple, TupleGetItem, const
 from .op.tensor import add, subtract, equal
 from .adt import Constructor, TypeData, Clause, Match
-from .adt import PatternConstructor, PatternVar, PatternWildcard
+from .adt import PatternConstructor, PatternVar, PatternWildcard, PatternTuple
 from .parser import fromtext
 __PRELUDE_PATH__ = os.path.dirname(os.path.realpath(__file__))
 from .module import Module
@@ -239,18 +239,19 @@ class Prelude:
         self.zip = GlobalVar("zip")
         a = TypeVar("a")
         b = TypeVar("b")
-        nil_case = Clause(PatternConstructor(self.nil), self.nil())
         l1 = Var("l1")
         l2 = Var("l2")
         h1 = Var("h1")
         h2 = Var("h2")
         t1 = Var("t1")
         t2 = Var("t2")
-        inner_cons_case = Clause(PatternConstructor(self.cons, [PatternVar(h2), PatternVar(t2)]),
-                                 self.cons(Tuple([h1, h2]), self.zip(t1, t2)))
-        outer_cons_case = Clause(PatternConstructor(self.cons, [PatternVar(h1), PatternVar(t1)]),
-                                 Match(l2, [nil_case, inner_cons_case]))
-        self.mod[self.zip] = Function([l1, l2], Match(l1, [nil_case, outer_cons_case]),
+        cons_case = Clause(PatternTuple([PatternConstructor(self.cons,
+                                                            [PatternVar(h1), PatternVar(t1)]),
+                                         PatternConstructor(self.cons,
+                                                            [PatternVar(h2), PatternVar(t2)])]),
+                           self.cons(Tuple([h1, h2]), self.zip(t1, t2)))
+        nil_case = Clause(PatternWildcard(), self.nil())
+        self.mod[self.zip] = Function([l1, l2], Match(Tuple([l1, l2]), [cons_case, nil_case]),
                                       self.l(TupleType([a, b])), [a, b])
 
 
index c003fe7..d661be7 100644 (file)
@@ -311,14 +311,18 @@ class PythonConverter(ExprFunctor):
         if isinstance(pattern, (relay.PatternWildcard, relay.PatternVar)):
             return NameConstant(True)
 
-        # constructor patterns check whether the constructors match
-        # and also the matches of any nested patterns
+        conds = []
 
-        # equiv: (arg.tag == patern_constructor.tag)
-        conds = [ast.Compare(ast.Attribute(data, 'tag', Load()),
-                             [ast.Eq()],
-                             [ast.Num(pattern.constructor.tag)])]
+        if isinstance(pattern, relay.PatternConstructor):
+            # constructor patterns check whether the constructors match
+            # and also the matches of any nested patterns
 
+            # equiv: (arg.tag == patern_constructor.tag)
+            conds.append(ast.Compare(ast.Attribute(data, 'tag', Load()),
+                                     [ast.Eq()],
+                                     [ast.Num(pattern.constructor.tag)]))
+
+        assert isinstance(pattern, (relay.PatternConstructor, relay.PatternTuple))
         # now check for any nested patterns
         for i in range(len(pattern.patterns)):
             nested_pat = pattern.patterns[i]
index dedff7a..e77d6a8 100644 (file)
@@ -18,7 +18,7 @@
  */
 
 /*!
- *  Copyright (c) 2018 by Contributors
+ *  Copyright (c) 2019 by Contributors
  * \file src/tvm/relay/interpreter.cc
  * \brief An interpreter for the Relay IR.
  */
@@ -708,6 +708,18 @@ class Interpreter :
     return false;
   }
 
+  bool VisitPattern_(const PatternTupleNode* op, const Value& v) final {
+    const TupleValueNode* tvn = v.as<TupleValueNode>();
+    CHECK(tvn) << "need to be a tuple for match";
+    CHECK_EQ(op->patterns.size(), tvn->fields.size());
+    for (size_t i = 0; i < op->patterns.size(); ++i) {
+      if (!VisitPattern(op->patterns[i], tvn->fields[i])) {
+        return false;
+      }
+    }
+    return true;
+  }
+
   bool VisitPattern_(const PatternWildcardNode* op, const Value& v) final {
     return true;
   }
index 5e5bc1a..39508d2 100644 (file)
@@ -152,19 +152,27 @@ TreeNodePtr BuildDecisionTreeFromPattern(MatchValuePtr data,
     auto pattern = GetRef<PatternVar>(pat);
     auto cond = std::make_shared<VarBinding>(pattern->var, data);
     return TreeBranchNode::Make(cond, then_branch, else_branch);
-  } else {
-    auto pat = pattern.as<PatternConstructorNode>();
-    auto pattern = GetRef<PatternConstructor>(pat);
-    auto tag = pattern->constructor->tag;
+  } else if (auto pcn = pattern.as<PatternConstructorNode>()) {
+    auto tag = pcn->constructor->tag;
 
     size_t field_index = 0;
-    for (auto& p : pattern->patterns) {
+    for (auto& p : pcn->patterns) {
       auto d = std::make_shared<AccessField>(data, field_index);
       then_branch = BuildDecisionTreeFromPattern(d, p, then_branch, else_branch);
       field_index++;
     }
     auto cond = std::make_shared<TagCompare>(data, tag);
     return TreeBranchNode::Make(cond, then_branch, else_branch);
+  } else {
+    auto pt = pattern.as<PatternTupleNode>();
+    CHECK(pt) << "unhandled case: " << pattern;
+    size_t field_index = 0;
+    for (auto& p : pt->patterns) {
+      auto d = std::make_shared<AccessField>(data, field_index);
+      then_branch = BuildDecisionTreeFromPattern(d, p, then_branch, else_branch);
+      field_index++;
+    }
+    return then_branch;
   }
 }
 
index b17bf41..9c670bf 100644 (file)
@@ -81,6 +81,23 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
             << ", " << node->patterns << ")";
 });
 
+PatternTuple PatternTupleNode::make(tvm::Array<Pattern> patterns) {
+  NodePtr<PatternTupleNode> n = make_node<PatternTupleNode>();
+  n->patterns = std::move(patterns);
+  return PatternTuple(n);
+}
+
+TVM_REGISTER_NODE_TYPE(PatternTupleNode);
+
+TVM_REGISTER_API("relay._make.PatternTuple")
+.set_body_typed(PatternTupleNode::make);
+
+TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
+.set_dispatch<PatternTupleNode>([](const PatternTupleNode* node,
+                                   tvm::IRPrinter* p) {
+  p->stream << "PatternTupleNode(" << node->patterns << ")";
+});
+
 Constructor ConstructorNode::make(std::string name_hint,
                                   tvm::Array<Type> inputs,
                                   GlobalTypeVar belong_to) {
index 2c23f0f..515db37 100644 (file)
@@ -493,7 +493,7 @@ class AlphaEqualHandler:
   }
 
   bool PatternEqual(const Pattern& lhs, const Pattern& rhs) {
-    return VisitPattern(lhs, rhs);
+    return Compare(VisitPattern(lhs, rhs), lhs, rhs);
   }
 
   bool VisitPattern_(const PatternWildcardNode* lhs, const Pattern& other) final {
@@ -523,6 +523,21 @@ class AlphaEqualHandler:
     return true;
   }
 
+  bool VisitPattern_(const PatternTupleNode* lhs, const Pattern& other) final {
+    const auto* rhs = other.as<PatternTupleNode>();
+    if (rhs == nullptr
+        || lhs->patterns.size() != rhs->patterns.size()) {
+      return false;
+    }
+
+    for (size_t i = 0; i < lhs->patterns.size(); i++) {
+      if (!PatternEqual(lhs->patterns[i], rhs->patterns[i])) {
+        return false;
+      }
+    }
+    return true;
+  }
+
   bool VisitExpr_(const MatchNode* lhs, const Expr& other) final {
     const MatchNode* rhs = other.as<MatchNode>();
 
index 9287f38..d392533 100644 (file)
@@ -389,6 +389,14 @@ class RelayHashHandler:
     return hash;
   }
 
+  size_t VisitPattern_(const PatternTupleNode* ptn) final {
+    size_t hash = std::hash<std::string>()(PatternTupleNode::_type_key);
+    for (const auto& p : ptn->patterns) {
+      hash = Combine(hash, PatternHash(p));
+    }
+    return hash;
+  }
+
   size_t VisitPattern_(const PatternVarNode* pvn) final {
     size_t hash = std::hash<std::string>()(PatternVarNode::_type_key);
     hash = Combine(hash, BindVar(pvn->var));
index 29171bf..5095373 100644 (file)
@@ -6,9 +6,9 @@
  * to you under the Apache License, Version 2.0 (the
  * "License"); you may not use this file except in compliance
  * with the License.  You may obtain a copy of the License at
- * 
+ *
  *   http://www.apache.org/licenses/LICENSE-2.0
- * 
+ *
  * Unless required by applicable law or agreed to in writing,
  * software distributed under the License is distributed on an
  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
@@ -18,8 +18,8 @@
  */
 
 /*!
- *  Copyright (c) 2018 by Contributors
- * \file src/tvm/relay/pattern_functor.cc
+ *  Copyright (c) 2019 by Contributors
+ * \file src/relay/ir/pattern_functor.cc
  * \brief Implementations of visitors and mutators for ADT patterns.
  */
 
@@ -48,6 +48,14 @@ Pattern PatternMutator::VisitPattern_(const PatternConstructorNode* op) {
   return PatternConstructorNode::make(VisitConstructor(op->constructor), pat);
 }
 
+Pattern PatternMutator::VisitPattern_(const PatternTupleNode* op) {
+  std::vector<Pattern> pat;
+  for (const auto& p : op->patterns) {
+    pat.push_back(VisitPattern(p));
+  }
+  return PatternTupleNode::make(pat);
+}
+
 Type PatternMutator::VisitType(const Type& t) {
   return t;
 }
@@ -78,6 +86,12 @@ void PatternVisitor::VisitPattern_(const PatternConstructorNode* op) {
   }
 }
 
+void PatternVisitor::VisitPattern_(const PatternTupleNode* op) {
+  for (const auto& p : op->patterns) {
+    VisitPattern(p);
+  }
+}
+
 void PatternVisitor::VisitType(const Type& t) { }
 
 void PatternVisitor::VisitVar(const Var& v) {
index cc00a54..07331a1 100644 (file)
@@ -68,7 +68,7 @@ class CandidateChecker : public PatternFunctor<MatchResult(const Pattern&, const
     }
 
     // now check that subpatterns match
-    CHECK(op->patterns.size() == ctor_cand->patterns.size());
+    CHECK_EQ(op->patterns.size(), ctor_cand->patterns.size());
     bool unspecified = false;
     for (size_t i = 0; i < op->patterns.size(); i++) {
       MatchResult submatch = this->Check(op->patterns[i], ctor_cand->patterns[i]);
@@ -87,6 +87,33 @@ class CandidateChecker : public PatternFunctor<MatchResult(const Pattern&, const
     return MatchResult::kMatch;
   }
 
+  MatchResult VisitPattern_(const PatternTupleNode* op, const Pattern& cand) override {
+    auto* tuple_cand = cand.as<PatternTupleNode>();
+    // attempting to match non-tuple to constructor pattern: need to specify
+    if (tuple_cand == nullptr) {
+      return MatchResult::kUnspecified;
+    }
+
+    // now check that subpatterns match
+    CHECK_EQ(op->patterns.size(), tuple_cand->patterns.size());
+    bool unspecified = false;
+    for (size_t i = 0; i < op->patterns.size(); i++) {
+      MatchResult submatch = this->Check(op->patterns[i], tuple_cand->patterns[i]);
+      // if we have a clash anywhere, then we can return clash
+      if (submatch == MatchResult::kClash) {
+        return MatchResult::kClash;
+      }
+      if (submatch == MatchResult::kUnspecified) {
+        unspecified = true;
+      }
+    }
+    // only return unspecified if we have ruled out a clash
+    if (unspecified) {
+      return MatchResult::kUnspecified;
+    }
+    return MatchResult::kMatch;
+  }
+
   // wildcard and var patterns always match
   MatchResult VisitPattern_(const PatternWildcardNode*, const Pattern&) override {
     return MatchResult::kMatch;
@@ -127,18 +154,38 @@ Array<Array<Pattern>> CartesianProduct(Array<Array<Pattern>> fields) {
   return ret;
 }
 
-// Expands all wildcards in the candidate pattern once, using the pattern
-// to decide which constructors to insert. Returns a list of all possible expansions.
-Array<Pattern> ExpandWildcards(const Pattern& clause_pat, const Pattern& cand,
+Array<Pattern> ExpandWildcardsConstructor(const PatternConstructor& clause_ctor,
+                                          const Pattern& cand,
+                                          const Module& mod);
+
+Array<Pattern> ExpandWildcardsTuple(const PatternTuple& clause_tuple,
+                                    const Pattern& cand,
+                                    const Module& mod);
+
+// Expands all wildcards in the candidate pattern once
+// Returns a list of all possible expansions.
+Array<Pattern> ExpandWildcards(const Pattern& clause_pat,
+                               const Pattern& cand,
                                const Module& mod) {
-  auto ctor_cand = cand.as<PatternConstructorNode>();
-  PatternConstructor clause_ctor = Downcast<PatternConstructor>(clause_pat);
+  if (auto clause_ctor = clause_pat.as<PatternConstructorNode>()) {
+    return ExpandWildcardsConstructor(GetRef<PatternConstructor>(clause_ctor), cand, mod);
+  } else {
+    return ExpandWildcardsTuple(Downcast<PatternTuple>(clause_pat), cand, mod);
+  }
+}
+
+// Expands all wildcards in the candidate pattern once.
+// Use the pattern to decide which constructors to insert.
+// Returns a list of all possible expansions.
+Array<Pattern> ExpandWildcardsConstructor(const PatternConstructor& clause_ctor,
+                                          const Pattern& cand,
+                                          const Module& mod) {
   auto gtv = Downcast<GlobalTypeVar>(clause_ctor->constructor->belong_to);
 
-  // for a wildcard node, create constructor nodes with wildcards for all args
-  if (!ctor_cand) {
+  // for a wildcard node, create constructor nodes with wildcards for all args.
+  if (cand.as<PatternWildcardNode>()) {
     TypeData td = mod->LookupDef(gtv);
-    // for each constructor add a candidate
+    // for each constructor add a candidate.
     Array<Pattern> ret;
     for (auto constructor : td->constructors) {
       Array<Pattern> args;
@@ -150,27 +197,72 @@ Array<Pattern> ExpandWildcards(const Pattern& clause_pat, const Pattern& cand,
     return ret;
   }
 
-  // for constructors, we will expand the wildcards in any field
-  // that is an ADT
+  auto ctor_cand = Downcast<PatternConstructor>(cand);
+
+  // for constructors, we will expand the wildcards in any field that is an ADT.
   Array<Array<Pattern>> values_by_field;
   for (size_t i = 0; i < ctor_cand->constructor->inputs.size(); i++) {
-    auto* subpattern = clause_ctor->patterns[i].as<PatternConstructorNode>();
-    // for non-ADT fields, we can only have a wildcard for the value
+    bool subpattern =
+      clause_ctor->patterns[i].as<PatternConstructorNode>() ||
+      clause_ctor->patterns[i].as<PatternTupleNode>();
+    // for non-ADT fields, we can only have a wildcard for the value.
     if (!subpattern) {
       values_by_field.push_back({PatternWildcardNode::make()});
-      continue;
+    } else {
+      // otherwise, recursively expand.
+      values_by_field.push_back(ExpandWildcards(clause_ctor->patterns[i],
+                                                ctor_cand->patterns[i],
+                                                mod));
     }
+  }
 
-    // otherwise, recursively expand
-    values_by_field.push_back(ExpandWildcards(GetRef<Pattern>(subpattern),
-                                              ctor_cand->patterns[i], mod));
+  // generate new candidates using a cartesian product.
+  auto all_subfields = CartesianProduct(values_by_field);
+  Array<Pattern> ret;
+  for (auto subfields : all_subfields) {
+    ret.push_back(PatternConstructorNode::make(ctor_cand->constructor, subfields));
+  }
+  return ret;
+}
+
+// Expands all wildcards in the candidate pattern once.
+// Returns a list of all possible expansions.
+Array<Pattern> ExpandWildcardsTuple(const PatternTuple& clause_tuple,
+                                    const Pattern& cand,
+                                    const Module& mod) {
+  // for a wildcard node, create constructor nodes with wildcards for all args.
+  if (cand.as<PatternWildcardNode>()) {
+    Array<Pattern> args;
+    for (auto inp : clause_tuple->patterns) {
+      args.push_back(PatternWildcardNode::make());
+    }
+    return {PatternTupleNode::make(args)};
+  }
+
+  auto tuple_cand = Downcast<PatternTuple>(cand);
+
+  // for constructors, we will expand the wildcards in any field that is an ADT.
+  Array<Array<Pattern>> values_by_field;
+  for (size_t i = 0; i < tuple_cand->patterns.size(); i++) {
+    bool subpattern =
+      clause_tuple->patterns[i].as<PatternConstructorNode>() ||
+      clause_tuple->patterns[i].as<PatternTupleNode>();
+    // for non-ADT fields, we can only have a wildcard for the value
+    if (!subpattern) {
+      values_by_field.push_back({PatternWildcardNode::make()});
+    } else {
+      // otherwise, recursively expand
+      values_by_field.push_back(ExpandWildcards(clause_tuple->patterns[i],
+                                                tuple_cand->patterns[i],
+                                                mod));
+    }
   }
 
   // generate new candidates using a cartesian product
   auto all_subfields = CartesianProduct(values_by_field);
   Array<Pattern> ret;
   for (auto subfields : all_subfields) {
-    ret.push_back(PatternConstructorNode::make(ctor_cand->constructor, subfields));
+    ret.push_back(PatternTupleNode::make(subfields));
   }
   return ret;
 }
index 3f92d7a..906d245 100644 (file)
@@ -1051,6 +1051,28 @@ class PartialEvaluator : public ExprFunctor<PStatic(const Expr& e, LetList* ll)>
     }
   }
 
+  MatchStatus VisitPattern_(const PatternTupleNode* op, const PStatic& ps) final {
+    if (ps->pstatic.defined()) {
+      STuple stn = Downcast<STuple>(ps->pstatic);
+      CHECK_EQ(op->patterns.size(), stn->fields.size());
+      MatchStatus current_match_status = MatchStatus::Match;
+      for (size_t i = 0; i < op->patterns.size(); ++i) {
+        MatchStatus ms = VisitPattern(op->patterns[i], stn->fields[i]);
+        switch (ms) {
+        case MatchStatus::Match:
+          continue;
+        case MatchStatus::NoMatch:
+          return MatchStatus::NoMatch;
+        case MatchStatus::Unknown:
+          current_match_status = MatchStatus::Unknown;
+        }
+      }
+      return current_match_status;
+    } else {
+      return MatchStatus::Unknown;
+    }
+  }
+
   void InitializeFuncId(const Expr& e) {
     struct InitializeFuncIdVisitor : ExprVisitor, PatternVisitor {
       PartialEvaluator* pe;
index 1d10651..d0f1b7a 100644 (file)
@@ -276,6 +276,27 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)>,
     }
   }
 
+  void VisitPattern_(const PatternTupleNode* tup, const Type& t) {
+    auto pt = GetRef<PatternTuple>(tup);
+
+    // we can expect a certain number of arguments
+    Array<Type> unknown_args;
+    for (size_t i = 0; i < tup->patterns.size(); i++) {
+      unknown_args.push_back(IncompleteTypeNode::make(Kind::kType));
+    }
+    Type expected = TupleTypeNode::make(unknown_args);
+    Type unified = Unify(t, expected, GetRef<NodeRef>(tup));
+
+    auto* tt = unified.as<TupleTypeNode>();
+    if (!tt) {
+      this->ReportFatalError(pt, RELAY_ERROR("Expected a tuple type, got " << unified));
+    }
+    CHECK(tup->patterns.size() == tt->fields.size()) << "not enough pattern";
+    for (size_t i = 0; i < tup->patterns.size(); ++i) {
+      VisitPattern(tup->patterns[i], tt->fields[i]);
+    }
+  }
+
   void VisitPattern_(const PatternVarNode* pv, const Type& t) {
     Type vt = GetType(pv->var);
     Unify(vt, t, pv->span);
index 9d07fb1..b240daf 100644 (file)
@@ -611,6 +611,21 @@ def test_hash_unequal():
 
     assert not analysis.structural_hash(func1) == analysis.structural_hash(func3)
 
+
+def test_tuple_match():
+    a = relay.Var("a")
+    b = relay.Var("b")
+    clause = relay.Clause(relay.PatternTuple([relay.PatternVar(a), relay.PatternVar(b)]), a + b)
+    x = relay.Match(relay.Tuple([relay.const(1), relay.const(1)]), [clause])
+
+    a = relay.Var("a")
+    b = relay.Var("b")
+    clause = relay.Clause(relay.PatternTuple([relay.PatternVar(a), relay.PatternVar(b)]), a + b)
+    y = relay.Match(relay.Tuple([relay.const(1), relay.const(1)]), [clause])
+    assert analysis.alpha_equal(x, y)
+    assert analysis.structural_hash(x) == analysis.structural_hash(y)
+
+
 if __name__ == "__main__":
     test_tensor_type_alpha_equal()
     test_incomplete_type_alpha_equal()
index f914f18..cf4f8f6 100644 (file)
@@ -331,6 +331,14 @@ def test_nat_update():
     transform.PartialEvaluate()(m)
 
 
+def test_tuple_match():
+    a = relay.Var("a")
+    b = relay.Var("b")
+    clause = relay.Clause(relay.PatternTuple([relay.PatternVar(a), relay.PatternVar(b)]), a + b)
+    x = relay.Match(relay.Tuple([relay.const(1), relay.const(1)]), [clause])
+    assert_alpha_equal(dcpe(x), const(2))
+
+
 if __name__ == '__main__':
     test_nat_update()
     test_ref()
@@ -351,3 +359,4 @@ if __name__ == '__main__':
     test_match_nat_id()
     test_concat()
     test_triangle_number()
+    test_tuple_match()
index 776f5a0..b06de4c 100644 (file)
@@ -265,3 +265,11 @@ def test_mixed_adt_constructors():
         relay.Clause(relay.PatternConstructor(p.nil, []), v)
     ])
     assert len(unmatched_cases(list_of_boxes_comp, mod)) == 0
+
+
+def test_tuple_match():
+    a = relay.Var("a")
+    b = relay.Var("b")
+    clause = relay.Clause(relay.PatternTuple([relay.PatternVar(a), relay.PatternVar(b)]), a + b)
+    x = relay.Match(relay.Tuple([relay.const(1), relay.const(1)]), [clause])
+    assert len(unmatched_cases(x)) == 0