add a testcase for #5674 (#5677)
authorMatthew Brookhart <mbrookhart@octoml.ai>
Wed, 27 May 2020 01:14:58 +0000 (18:14 -0700)
committerGitHub <noreply@github.com>
Wed, 27 May 2020 01:14:58 +0000 (18:14 -0700)
tests/python/relay/test_dataflow_pattern.py

index ed90873..6a66f60 100644 (file)
@@ -464,6 +464,23 @@ def test_rewrite():
     out = rewrite(TestRewrite(), x + y)
     assert sub_pattern.match(out)
 
+def test_rewrite_func():
+    x = relay.var('x')
+    w = relay.var('w')
+    y = relay.var('y')
+    add_pattern = is_op('add')(wildcard(), wildcard())
+    sub_pattern = is_op('subtract')(wildcard(), wildcard())
+    class TestRewrite(DFPatternCallback):
+        def __init__(self):
+            self.pattern = add_pattern
+        def callback(self, pre, post, node_map):
+            return post.args[0] - post.args[1]
+    inpf = relay.var("input")
+    weightf = relay.var("weight")
+    func = relay.Function([inpf, weightf], relay.op.nn.relu(relay.op.nn.conv2d(inpf, weightf)), attrs=None)
+    out = rewrite(TestRewrite(), func(x,w) + y)
+    assert sub_pattern.match(out)
+
 def test_nested_rewrite():
     class PatternCallback(DFPatternCallback):
         def __init__(self, pattern):