const IRModule& mod) {
if (auto clause_ctor = clause_pat.as<PatternConstructorNode>()) {
return ExpandWildcardsConstructor(GetRef<PatternConstructor>(clause_ctor), cand, mod);
+ } else if (auto clause_tup = clause_pat.as<PatternTupleNode>()) {
+ return ExpandWildcardsTuple(GetRef<PatternTuple>(clause_tup), cand, mod);
} else {
- return ExpandWildcardsTuple(Downcast<PatternTuple>(clause_pat), cand, mod);
+ return {cand};
}
}
// for constructors, we will expand the wildcards in any field that is an ADT.
Array<Array<Pattern>> values_by_field;
for (size_t i = 0; i < ctor_cand->constructor->inputs.size(); i++) {
- bool subpattern =
- clause_ctor->patterns[i].as<PatternConstructorNode>() ||
- clause_ctor->patterns[i].as<PatternTupleNode>();
- // for non-ADT fields, we can only have a wildcard for the value.
- if (!subpattern) {
- values_by_field.push_back({PatternWildcardNode::make()});
- } else {
- // otherwise, recursively expand.
- values_by_field.push_back(ExpandWildcards(clause_ctor->patterns[i],
- ctor_cand->patterns[i],
- mod));
- }
+ values_by_field.push_back(ExpandWildcards(clause_ctor->patterns[i],
+ ctor_cand->patterns[i],
+ mod));
}
// generate new candidates using a cartesian product.
// for constructors, we will expand the wildcards in any field that is an ADT.
Array<Array<Pattern>> values_by_field;
for (size_t i = 0; i < tuple_cand->patterns.size(); i++) {
- bool subpattern =
- clause_tuple->patterns[i].as<PatternConstructorNode>() ||
- clause_tuple->patterns[i].as<PatternTupleNode>();
- // for non-ADT fields, we can only have a wildcard for the value
- if (!subpattern) {
- values_by_field.push_back({PatternWildcardNode::make()});
- } else {
- // otherwise, recursively expand
- values_by_field.push_back(ExpandWildcards(clause_tuple->patterns[i],
- tuple_cand->patterns[i],
- mod));
- }
+ values_by_field.push_back(ExpandWildcards(clause_tuple->patterns[i],
+ tuple_cand->patterns[i],
+ mod));
}
// generate new candidates using a cartesian product
from tvm import relay
from tvm.relay.prelude import Prelude
from tvm.relay.analysis import unmatched_cases
+import pytest
def test_empty_match_block():
# empty match block will not match anything, so it should return a wildcard pattern
clause = relay.Clause(relay.PatternTuple([relay.PatternVar(a), relay.PatternVar(b)]), a + b)
x = relay.Match(relay.Tuple([relay.const(1), relay.const(1)]), [clause])
assert len(unmatched_cases(x)) == 0
+
+
+def test_inf_loop_case():
+ code = """
+v0.0.4
+type Arith[A] {
+ Zero,
+ Const(A),
+ Plus(Arith[A], Arith[A])
+}
+
+def @shallow_opt[A](%a: Arith[A]) -> Arith[A] {
+ match (%a) {
+ Plus(Zero, %r) => %r,
+ Plus(%l, Zero) => %l,
+ _ => %a
+ }
+}
+"""
+ relay.fromtext(code)
+ # fromtext parse the module, then checked it (which include strictness checking).
+
+if __name__ == "__main__":
+ pytest.main([__file__])