From 76c239269935288e51fbce14f135d75ad9742b2a Mon Sep 17 00:00:00 2001 From: Logan Weber <36520469+weberlo@users.noreply.github.com> Date: Mon, 7 Oct 2019 08:20:59 -0700 Subject: [PATCH] Fix match case in Python-side expr functor (#4037) --- python/tvm/relay/expr_functor.py | 5 ++++- tests/python/relay/test_expr_functor.py | 11 +++++++++++ 2 files changed, 15 insertions(+), 1 deletion(-) diff --git a/python/tvm/relay/expr_functor.py b/python/tvm/relay/expr_functor.py index 77970d1..f492c74 100644 --- a/python/tvm/relay/expr_functor.py +++ b/python/tvm/relay/expr_functor.py @@ -249,7 +249,10 @@ class ExprMutator(ExprFunctor): return con def visit_match(self, m): - return Match(self.visit(m.data), [Clause(c.lhs, self.visit(c.rhs)) for c in m.clauses]) + return Match( + self.visit(m.data), + [Clause(c.lhs, self.visit(c.rhs)) for c in m.clauses], + complete=m.complete) def visit_ref_create(self, r): return RefCreate(self.visit(r.value)) diff --git a/tests/python/relay/test_expr_functor.py b/tests/python/relay/test_expr_functor.py index 4b0adff..5c92365 100644 --- a/tests/python/relay/test_expr_functor.py +++ b/tests/python/relay/test_expr_functor.py @@ -125,6 +125,16 @@ def test_match(): p = relay.prelude.Prelude() check_visit(p.mod[p.map]) + +def test_match_completeness(): + p = relay.prelude.Prelude() + for completeness in [True, False]: + match_expr = relay.adt.Match(p.nil, [], complete=completeness) + result_expr = ExprMutator().visit(match_expr) + # ensure the mutator doesn't mangle the completeness flag + assert result_expr.complete == completeness + + if __name__ == "__main__": test_constant() test_tuple() @@ -139,3 +149,4 @@ if __name__ == "__main__": test_ref_write() test_memo() test_match() + test_match_completeness() -- 2.7.4