[WIP] Fixing an Infinite Loop case in UnmatchedChecker. (#4881)
author雾雨魔理沙 <lolisa@marisa.moe>
Wed, 26 Feb 2020 01:35:49 +0000 (17:35 -0800)
committerGitHub <noreply@github.com>
Wed, 26 Feb 2020 01:35:49 +0000 (17:35 -0800)
* save

* save

* remove

* remove cerr

src/relay/pass/match_exhaustion.cc
tests/python/relay/test_pass_unmatched_cases.py

index 885c47e..14be6b7 100644 (file)
@@ -168,8 +168,10 @@ Array<Pattern> ExpandWildcards(const Pattern& clause_pat,
                                const IRModule& mod) {
   if (auto clause_ctor = clause_pat.as<PatternConstructorNode>()) {
     return ExpandWildcardsConstructor(GetRef<PatternConstructor>(clause_ctor), cand, mod);
+  } else if (auto clause_tup = clause_pat.as<PatternTupleNode>()) {
+    return ExpandWildcardsTuple(GetRef<PatternTuple>(clause_tup), cand, mod);
   } else {
-    return ExpandWildcardsTuple(Downcast<PatternTuple>(clause_pat), cand, mod);
+    return {cand};
   }
 }
 
@@ -201,18 +203,9 @@ Array<Pattern> ExpandWildcardsConstructor(const PatternConstructor& clause_ctor,
   // for constructors, we will expand the wildcards in any field that is an ADT.
   Array<Array<Pattern>> values_by_field;
   for (size_t i = 0; i < ctor_cand->constructor->inputs.size(); i++) {
-    bool subpattern =
-      clause_ctor->patterns[i].as<PatternConstructorNode>() ||
-      clause_ctor->patterns[i].as<PatternTupleNode>();
-    // for non-ADT fields, we can only have a wildcard for the value.
-    if (!subpattern) {
-      values_by_field.push_back({PatternWildcardNode::make()});
-    } else {
-      // otherwise, recursively expand.
-      values_by_field.push_back(ExpandWildcards(clause_ctor->patterns[i],
-                                                ctor_cand->patterns[i],
-                                                mod));
-    }
+    values_by_field.push_back(ExpandWildcards(clause_ctor->patterns[i],
+                                              ctor_cand->patterns[i],
+                                              mod));
   }
 
   // generate new candidates using a cartesian product.
@@ -243,18 +236,9 @@ Array<Pattern> ExpandWildcardsTuple(const PatternTuple& clause_tuple,
   // for constructors, we will expand the wildcards in any field that is an ADT.
   Array<Array<Pattern>> values_by_field;
   for (size_t i = 0; i < tuple_cand->patterns.size(); i++) {
-    bool subpattern =
-      clause_tuple->patterns[i].as<PatternConstructorNode>() ||
-      clause_tuple->patterns[i].as<PatternTupleNode>();
-    // for non-ADT fields, we can only have a wildcard for the value
-    if (!subpattern) {
-      values_by_field.push_back({PatternWildcardNode::make()});
-    } else {
-      // otherwise, recursively expand
-      values_by_field.push_back(ExpandWildcards(clause_tuple->patterns[i],
-                                                tuple_cand->patterns[i],
-                                                mod));
-    }
+    values_by_field.push_back(ExpandWildcards(clause_tuple->patterns[i],
+                                              tuple_cand->patterns[i],
+                                              mod));
   }
 
   // generate new candidates using a cartesian product
index 615d4e0..1ac99a6 100644 (file)
@@ -19,6 +19,7 @@ import tvm
 from tvm import relay
 from tvm.relay.prelude import Prelude
 from tvm.relay.analysis import unmatched_cases
+import pytest
 
 def test_empty_match_block():
     # empty match block will not match anything, so it should return a wildcard pattern
@@ -273,3 +274,27 @@ def test_tuple_match():
     clause = relay.Clause(relay.PatternTuple([relay.PatternVar(a), relay.PatternVar(b)]), a + b)
     x = relay.Match(relay.Tuple([relay.const(1), relay.const(1)]), [clause])
     assert len(unmatched_cases(x)) == 0
+
+
+def test_inf_loop_case():
+    code = """
+v0.0.4
+type Arith[A] {
+    Zero,
+    Const(A),
+    Plus(Arith[A], Arith[A])
+}
+
+def @shallow_opt[A](%a: Arith[A]) -> Arith[A] {
+    match (%a) {
+        Plus(Zero, %r) => %r,
+        Plus(%l, Zero) => %l,
+        _ => %a
+    }
+}
+"""
+    relay.fromtext(code)
+    # fromtext parse the module, then checked it (which include strictness checking).
+
+if __name__ == "__main__":
+    pytest.main([__file__])