out = rewrite(TestRewrite(), x + y)
assert sub_pattern.match(out)
+def test_rewrite_func():
+ x = relay.var('x')
+ w = relay.var('w')
+ y = relay.var('y')
+ add_pattern = is_op('add')(wildcard(), wildcard())
+ sub_pattern = is_op('subtract')(wildcard(), wildcard())
+ class TestRewrite(DFPatternCallback):
+ def __init__(self):
+ self.pattern = add_pattern
+ def callback(self, pre, post, node_map):
+ return post.args[0] - post.args[1]
+ inpf = relay.var("input")
+ weightf = relay.var("weight")
+ func = relay.Function([inpf, weightf], relay.op.nn.relu(relay.op.nn.conv2d(inpf, weightf)), attrs=None)
+ out = rewrite(TestRewrite(), func(x,w) + y)
+ assert sub_pattern.match(out)
+
def test_nested_rewrite():
class PatternCallback(DFPatternCallback):
def __init__(self, pattern):