# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+# pylint: disable=unused-wildcard-import
+import numpy as np
+
import tvm
from tvm import relay
+from tvm.relay.build_module import bind_params_by_name
from tvm.relay.dataflow_pattern import *
from tvm.relay.testing import run_opt_pass
-import numpy as np
# NB: these values must stay in sync with the C++ TOpPattern enum;
# we lose type safety due to the Python/C++ calling convention.
K_ELEMWISE = 0
K_BROADCAST = 1
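# (kElemWise = 0 and kBroadcast = 1 in include/tvm/relay/op_attr_types.h)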
+
## NODE TESTS
def test_expr_pattern():
ep = ExprPattern(relay.var('x', shape=(4, 1)))
assert isinstance(ep, ExprPattern)
assert isinstance(ep.expr, relay.Var)
+
def test_var_pattern():
v = is_input("x")
assert isinstance(v, VarPattern)
assert v.name == "x"
+
+def test_constant_pattern():
+ c = ConstantPattern()
+ assert isinstance(c, ConstantPattern)
+
+
def test_wildcard_pattern():
wc = wildcard()
assert isinstance(wc, WildcardPattern)
+
def test_CallPattern():
wc1 = wildcard()
    wc2 = wildcard()
    c = is_op('add')(wc1, wc2)
    assert isinstance(c, CallPattern)
assert isinstance(c.args[0], WildcardPattern)
assert isinstance(c.args[1], WildcardPattern)
+
def test_TuplePattern():
wc1 = wildcard()
    wc2 = wildcard()
    t = TuplePattern((wc1, wc2))
    assert isinstance(t, TuplePattern)
assert isinstance(t.fields[0], WildcardPattern)
assert isinstance(t.fields[1], WildcardPattern)
+
def test_TupleGetItemPattern():
wc1 = wildcard()
    wc2 = wildcard()
    t = TuplePattern((wc1, wc2))
    tgi = TupleGetItemPattern(t, 1)
    assert isinstance(tgi, TupleGetItemPattern)
assert isinstance(tgi.tuple.fields[0], WildcardPattern)
assert isinstance(tgi.tuple.fields[1], WildcardPattern)
+
def test_AltPattern():
is_add_or_sub = is_op('add') | is_op('subtract')
assert isinstance(is_add_or_sub, AltPattern)
+
def test_TypePattern():
ttype = relay.TensorType((10, 10), "float32")
ty_pat = has_type(ttype)
assert isinstance(ty_pat, TypePattern)
assert ty_pat.type == ttype
+
def test_AttrPattern():
op = is_op('add').has_attr({"TOpPattern": K_ELEMWISE})
assert isinstance(op, AttrPattern)
assert op.attrs["TOpPattern"] == K_ELEMWISE
+
## MATCHER TESTS
+
def test_match_op():
assert is_op('add').match(relay.op.op.get("add"))
+
def test_no_match_op():
assert not is_op('add').match(relay.op.op.get("subtract"))
+
def test_match_op_or():
is_add_or_sub = is_op('add') | is_op('subtract')
assert is_add_or_sub.match(relay.op.op.get("add"))
assert is_add_or_sub.match(relay.op.op.get("subtract"))
+
def test_match_call_commutative():
    x = relay.var('x')
    y = relay.var('y')
    add_pattern = is_op('add')(is_input("x"), is_input("y"))
    assert add_pattern.match(x + y)
    assert add_pattern.match(y + x)
    mul_pattern = is_op('multiply')(is_input("x"), is_input("y"))
    assert mul_pattern.match(x * y)
    assert mul_pattern.match(y * x)
+
def test_no_match_call_commutative():
    x = relay.var('x')
    y = relay.var('y')
    # divide is not commutative, so the matcher must respect argument order
    div_pattern = is_op('divide')(is_input("x"), is_input("y"))
    assert div_pattern.match(x / y)
    assert not div_pattern.match(y / x)
+
def test_match_call():
x = relay.var('x')
y = relay.var('y')
add_pattern = is_op('add')(wildcard(), wildcard())
assert add_pattern.match(x + y)
+
def test_no_match_call():
x = relay.var('x')
y = relay.var('y')
add_pattern = is_op('add')(wildcard(), wildcard())
assert not add_pattern.match(x - y)
+
def test_match_option():
x = relay.var('x')
w = relay.var('w')
b = relay.var('b')
- pattern = is_op("nn.relu")(
- is_op("nn.conv2d")(wildcard(), wildcard()
- ).optional(lambda x: is_op("nn.bias_add")(x, wildcard()))
- )
+ pattern = is_op("nn.relu")(is_op("nn.conv2d")(
+ wildcard(), wildcard()).optional(lambda x: is_op("nn.bias_add")(x, wildcard())))
    conv2d = relay.op.nn.conv2d(x, w)
    relu = relay.op.nn.relu(conv2d)
    assert pattern.match(relu)

    bias_add = relay.op.nn.bias_add(conv2d, b)
    relu = relay.op.nn.relu(bias_add)
    assert pattern.match(relu)
+
def test_no_match_option():
x = relay.var('x')
w = relay.var('w')
b = relay.var('b')
- pattern = is_op("nn.relu")(
- is_op("nn.conv2d")(wildcard(), wildcard()
- ).optional(lambda x: is_op("nn.bias_add")(x, wildcard()))
- )
+ pattern = is_op("nn.relu")(is_op("nn.conv2d")(
+ wildcard(), wildcard()).optional(lambda x: is_op("nn.bias_add")(x, wildcard())))
    conv2d = relay.op.nn.conv2d(x, w)
    tanh = relay.op.tanh(conv2d)
    assert not pattern.match(tanh)

    bias_add = relay.op.nn.bias_add(tanh, b)
    relu = relay.op.nn.relu(bias_add)
    assert not pattern.match(relu)
-def test_match_tuple():
- x = relay.var('x')
- y = relay.var('y')
- z = relay.op.op.get("add")
- tuple_pattern = TuplePattern((is_input("x"), wildcard(), is_op("add")))
- assert tuple_pattern.match(relay.expr.Tuple((x,y,z)))
-def test_no_match_tuple():
- x = relay.var('x')
- y = relay.var('y')
- z = relay.op.op.get("add")
- tuple_pattern = TuplePattern((is_input('x'), wildcard(), is_op("add"), wildcard()))
- assert not tuple_pattern.match(relay.expr.Tuple((x,y,z)))
+def test_match_const():
+ conv2d = is_op('nn.conv2d')(wildcard(), ConstantPattern())
+ pattern = is_op('nn.bias_add')(conv2d, wildcard())
+
+ x = relay.var('x', shape=(1, 3, 224, 224))
+ w = relay.var('w', shape=(3, 3, 3, 3))
+ b = relay.var('b', shape=(3, ))
+ conv2d = relay.op.nn.conv2d(x, w)
+ out = relay.op.nn.bias_add(conv2d, b)
+ func = relay.Function([x, w, b], out)
+ mod = tvm.IRModule.from_expr(func)
+
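+    # 'w' is still a free Var here; ConstantPattern only matches after
+    # bind_params_by_name replaces it with a relay.Constant below.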
+ assert not pattern.match(mod['main'].body)
+ mod["main"] = bind_params_by_name(mod["main"],
+ {'w': tvm.nd.array(np.ones(shape=(3, 3, 3, 3)))})
+ assert pattern.match(mod['main'].body)
+
def test_match_tuple():
x = relay.var('x')
y = relay.var('y')
z = relay.op.op.get("add")
tuple_pattern = TuplePattern((is_input("x"), wildcard(), is_op("add")))
+ assert tuple_pattern.match(relay.expr.Tuple((x, y, z)))
+
+ tuple_pattern = TuplePattern((is_input("x"), wildcard(), is_op("add")))
tuple_get_item_pattern = TupleGetItemPattern(tuple_pattern, 1)
- assert tuple_get_item_pattern.match(relay.expr.TupleGetItem(relay.expr.Tuple((x,y,z)), 1))
+ assert tuple_get_item_pattern.match(relay.expr.TupleGetItem(relay.expr.Tuple((x, y, z)), 1))
+
def test_no_match_tuple():
x = relay.var('x')
y = relay.var('y')
z = relay.op.op.get("add")
+ tuple_pattern = TuplePattern((is_input('x'), wildcard(), is_op("add"), wildcard()))
+ assert not tuple_pattern.match(relay.expr.Tuple((x, y, z)))
+
tuple_pattern = TuplePattern((is_input('x'), wildcard(), is_op("add")))
tuple_get_item_pattern = TupleGetItemPattern(tuple_pattern, 1)
- assert not tuple_get_item_pattern.match(relay.expr.TupleGetItem(relay.expr.Tuple((x,y,z)), 2))
+ assert not tuple_get_item_pattern.match(relay.expr.TupleGetItem(relay.expr.Tuple(
+ (x, y, z)), 2))
+
def test_match_type():
x = relay.var('x', shape=(10, 10), dtype="float32")
ty_pat = has_type(relay.TensorType((10, 10), "float32"))
assert ty_pat.match(x)
+
def test_no_match_type():
x = relay.var('x', shape=(10, 10), dtype="int32")
ty_pat = has_type(relay.TensorType((10, 10), "float32"))
assert not ty_pat.match(x)
+
def test_match_op_attr():
op = is_op('add').has_attr({"TOpPattern": K_BROADCAST})
    op_pat = op(wildcard(), wildcard())
    x = relay.var('x')
    y = relay.var('y')
    assert op_pat.match(x + y)
+
def test_no_match_op_attr():
op = is_op('nn.dense').has_attr({"TOpPattern": K_ELEMWISE})
    op_pat = op(wildcard(), wildcard())
    x = relay.var('x')
    y = relay.var('y')
    assert not op_pat.match(x - y)
+
def test_match_func_attr():
pattern = wildcard().has_attr({"Composite": "add"})
    x = relay.var('x')
    y = relay.var('y')
    f = relay.Function([x, y], x + y).with_attr("Composite", "add")
assert pattern.match(f)
+
def test_no_match_func_attr():
pattern = wildcard().has_attr({"Composite": "add"})
    x = relay.var('x')
    y = relay.var('y')
    f = relay.Function([x, y], x + y).with_attr("Composite", "conv_bias")
assert not pattern.match(f)
+
def test_match_call_attr():
is_conv2d = is_op('nn.conv2d')(wildcard(), wildcard()).has_attr({"data_layout": "NCHW"})
x = relay.var('x')
y = relay.var('y')
assert is_conv2d.match(relay.op.nn.conv2d(x, y))
+
def test_no_match_call_attr():
x = relay.var('x')
y = relay.var('y')
is_conv2d = is_op('nn.conv2d')(wildcard(), wildcard()).has_attr({"RandomAttr": "NCHW"})
assert not is_conv2d.match(relay.op.nn.conv2d(x, y))
+
def test_match_diamond():
    # Pattern
    is_conv2d = is_op('nn.conv2d')(wildcard(), wildcard())
    path1 = is_op('nn.relu')(is_conv2d)
    path2 = is_op('nn.leaky_relu')(is_conv2d)
    diamond = is_op('add')(path1, path2)
    # Expr
    conv2d = relay.op.nn.conv2d(relay.var('input'), relay.var('weight'))
    out = relay.op.nn.relu(conv2d) + relay.op.nn.leaky_relu(conv2d, alpha=0)
    # Check
    assert diamond.match(out)
+
def test_no_match_diamond():
    # Pattern
    is_conv2d = is_op('nn.conv2d')(wildcard(), wildcard())
    path1 = is_op('nn.relu')(is_conv2d)
    path2 = is_op('nn.leaky_relu')(is_conv2d)
    diamond = is_op('add')(path1, path2)
    # Expr
    inp = relay.var('input')
    weight = relay.var('weight')
    conv2d = relay.op.nn.conv2d(inp, weight)
    relu = relay.op.nn.relu(conv2d)
    leaky_relu = relay.op.nn.leaky_relu(conv2d, alpha=0)
-    out = relu + leaky_relu
    # Check
    assert not diamond.match(leaky_relu)
    assert not diamond.match(relu)
+
def test_match_fake_diamond():
    # Pattern
    is_conv2d = is_op('nn.conv2d')(wildcard(), wildcard())
    path1 = is_op('nn.relu')(is_conv2d)
    path2 = is_op('nn.leaky_relu')(is_conv2d)
    diamond = is_op('add')(path1, path2)
    # Expr: two distinct conv2d calls, so this is not a true diamond
    conv2d1 = relay.op.nn.conv2d(relay.var('input1'), relay.var('weight1'))
    conv2d2 = relay.op.nn.conv2d(relay.var('input2'), relay.var('weight2'))
    out = relay.op.nn.relu(conv2d1) + relay.op.nn.leaky_relu(conv2d2, alpha=0)
    # Check: the shared is_conv2d pattern must map to a single conv2d node
    assert not diamond.match(out)
-
+
# Fuzzy path/nested Diamond
is_conv2d = is_op('nn.conv2d')(wildcard(), wildcard())
- is_unary_elemwise = (wildcard().has_attr({"TOpPattern": K_ELEMWISE}))(wildcard()) | is_op('add')(wildcard(), wildcard())
+ is_unary_elemwise = (wildcard().has_attr({"TOpPattern": K_ELEMWISE}))(
+ wildcard()) | is_op('add')(wildcard(), wildcard())
reduction = is_op('add')(wildcard(), wildcard())
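    # dominates(parent, path, child) matches the child if every dataflow path
    # from a match of the parent to the child passes through ops matching path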
diamond = dominates(is_conv2d, is_unary_elemwise, reduction)
assert diamond.match(out)
+
def test_not_match_dominator():
is_conv2d = is_op('nn.conv2d')(wildcard(), wildcard())
is_unary_elemwise = (wildcard().has_attr({"TOpPattern": K_ELEMWISE}))(wildcard())
# Check
assert not diamond.match(out)
+
def test_rewrite():
x = relay.var('x')
y = relay.var('y')
add_pattern = is_op('add')(wildcard(), wildcard())
sub_pattern = is_op('subtract')(wildcard(), wildcard())
+
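    # rewrite() scans the expression for self.pattern and replaces each match
    # with whatever the callback returns; here add is rewritten to subtract.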
class TestRewrite(DFPatternCallback):
def __init__(self):
self.pattern = add_pattern
+
def callback(self, pre, post, node_map):
return post.args[0] - post.args[1]
+
out = rewrite(TestRewrite(), x + y)
assert sub_pattern.match(out)
+
def test_rewrite_func():
x = relay.var('x')
w = relay.var('w')
y = relay.var('y')
add_pattern = is_op('add')(wildcard(), wildcard())
sub_pattern = is_op('subtract')(wildcard(), wildcard())
+
class TestRewrite(DFPatternCallback):
def __init__(self):
self.pattern = add_pattern
+
def callback(self, pre, post, node_map):
return post.args[0] - post.args[1]
+
inpf = relay.var("input")
weightf = relay.var("weight")
- func = relay.Function([inpf, weightf], relay.op.nn.relu(relay.op.nn.conv2d(inpf, weightf)), attrs=None)
- out = rewrite(TestRewrite(), func(x,w) + y)
+ func = relay.Function([inpf, weightf],
+ relay.op.nn.relu(relay.op.nn.conv2d(inpf, weightf)),
+ attrs=None)
+ out = rewrite(TestRewrite(), func(x, w) + y)
assert sub_pattern.match(out)
+
def test_nested_rewrite():
class PatternCallback(DFPatternCallback):
def __init__(self, pattern):
assert tvm.ir.structural_equal(out, new_out)
+
def test_not_fuse_multi_diamond():
# Pattern
is_conv2d = is_op('nn.conv2d')(wildcard(), wildcard())
# Check
assert not diamond.match(out)
+
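# BatchnormCallback recognizes the unfused batch-norm arithmetic
#   gamma * (x - mean) / sqrt(var + eps) + beta
# and rewrites every match into a single nn.batch_norm call.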
class BatchnormCallback(DFPatternCallback):
def __init__(self):
        self.x = wildcard()
        self.var = wildcard()
        self.mean = wildcard()
        self.beta = wildcard()
        self.gamma = wildcard()
        self.eps = wildcard()
-
-        self.pattern = self.gamma * (self.x - self.mean)/is_op("sqrt")(self.var + self.eps) + self.beta
+
+        self.pattern = (self.gamma * (self.x - self.mean) /
+                        is_op("sqrt")(self.var + self.eps) + self.beta)
+
def callback(self, pre, post, node_map):
        x = node_map[self.x][0]
        var = node_map[self.var][0]
        mean = node_map[self.mean][0]
        beta = node_map[self.beta][0]
        gamma = node_map[self.gamma][0]
        eps = node_map[self.eps][0]
-        return relay.op.nn.batch_norm(x, gamma, beta, mean, var, epsilon = eps.data.asnumpy().item())[0]
+        return relay.op.nn.batch_norm(x, gamma, beta, mean, var,
+                                      epsilon=eps.data.asnumpy().item())[0]
+
def test_fuse_batchnorm():
    x = relay.var('x')
    var = relay.var('var')
    mean = relay.var('mean')
beta = relay.var('beta')
gamma = relay.var('gamma')
-
- BN = gamma * (x - mean)/relay.op.sqrt(var + relay.const(1e-5)) + beta
+
+ BN = gamma * (x - mean) / relay.op.sqrt(var + relay.const(1e-5)) + beta
out = rewrite(BatchnormCallback(), BN)
- assert tvm.ir.structural_equal(out, relay.op.nn.batch_norm(x, gamma, beta, mean, var, epsilon = 1e-5)[0])
+ assert tvm.ir.structural_equal(
+ out,
+ relay.op.nn.batch_norm(x, gamma, beta, mean, var, epsilon=1e-5)[0])
+
def test_no_fuse_batchnorm():
    x = relay.var('x')
    var = relay.var('var')
    mean = relay.var('mean')
beta = relay.var('beta')
gamma = relay.var('gamma')
-
- fake_BN = gamma * (x - mean)/relay.op.sqrt(var + relay.const(1e-5)) - beta
+
+ fake_BN = gamma * (x - mean) / relay.op.sqrt(var + relay.const(1e-5)) - beta
out = rewrite(BatchnormCallback(), fake_BN)
assert tvm.ir.structural_equal(out, fake_BN)
+
def test_fuse_double_batchnorm():
x = relay.var('x')
var = relay.var('var')
mean = relay.var('mean')
beta = relay.var('beta')
gamma = relay.var('gamma')
-
- BN = gamma * (x - mean)/relay.op.sqrt(var + relay.const(1e-5)) + beta
- BN2 = gamma * (BN - mean)/relay.op.sqrt(var + relay.const(1e-5)) + beta
+
+ BN = gamma * (x - mean) / relay.op.sqrt(var + relay.const(1e-5)) + beta
+ BN2 = gamma * (BN - mean) / relay.op.sqrt(var + relay.const(1e-5)) + beta
out = rewrite(BatchnormCallback(), BN2)
- bn = relay.op.nn.batch_norm(x, gamma, beta, mean, var, epsilon = 1e-5)[0]
- bn2 = relay.op.nn.batch_norm(bn, gamma, beta, mean, var, epsilon = 1e-5)[0]
+ bn = relay.op.nn.batch_norm(x, gamma, beta, mean, var, epsilon=1e-5)[0]
+ bn2 = relay.op.nn.batch_norm(bn, gamma, beta, mean, var, epsilon=1e-5)[0]
assert tvm.ir.structural_equal(out, bn2)
+
def test_partial_fuse_double_batchnorm():
x = relay.var('x')
var = relay.var('var')
mean = relay.var('mean')
beta = relay.var('beta')
gamma = relay.var('gamma')
-
- BN = gamma * (x - mean)/relay.op.sqrt(var + relay.const(1e-5)) - beta
- BN2 = gamma * (BN - mean)/relay.op.sqrt(var + relay.const(1e-5)) + beta
+
+ BN = gamma * (x - mean) / relay.op.sqrt(var + relay.const(1e-5)) - beta
+ BN2 = gamma * (BN - mean) / relay.op.sqrt(var + relay.const(1e-5)) + beta
out = rewrite(BatchnormCallback(), BN2)
- bn2 = relay.op.nn.batch_norm(BN, gamma, beta, mean, var, epsilon = 1e-5)[0]
+ bn2 = relay.op.nn.batch_norm(BN, gamma, beta, mean, var, epsilon=1e-5)[0]
assert tvm.ir.structural_equal(out, bn2)
+
def test_fuse_batchnorm_commutation():
x = relay.var('x')
var = relay.var('var')
mean = relay.var('mean')
beta = relay.var('beta')
gamma = relay.var('gamma')
-
+
    # commute add
- BN = beta + gamma * (x - mean)/relay.op.sqrt(var + relay.const(1e-5))
+ BN = beta + gamma * (x - mean) / relay.op.sqrt(var + relay.const(1e-5))
out = rewrite(BatchnormCallback(), BN)
- assert tvm.ir.structural_equal(out, relay.op.nn.batch_norm(x, gamma, beta, mean, var, epsilon = 1e-5)[0])
+ assert tvm.ir.structural_equal(
+ out,
+ relay.op.nn.batch_norm(x, gamma, beta, mean, var, epsilon=1e-5)[0])
# associate divide/multiply
- BN = (gamma * (x - mean)) /relay.op.sqrt(var + relay.const(1e-5)) + beta
+ BN = (gamma * (x - mean)) / relay.op.sqrt(var + relay.const(1e-5)) + beta
out = rewrite(BatchnormCallback(), BN)
- assert tvm.ir.structural_equal(out, relay.op.nn.batch_norm(x, gamma, beta, mean, var, epsilon = 1e-5)[0])
+ assert tvm.ir.structural_equal(
+ out,
+ relay.op.nn.batch_norm(x, gamma, beta, mean, var, epsilon=1e-5)[0])
# associate multiply/divide
- BN = gamma * ((x - mean)/relay.op.sqrt(var + relay.const(1e-5))) + beta
+ BN = gamma * ((x - mean) / relay.op.sqrt(var + relay.const(1e-5))) + beta
out = rewrite(BatchnormCallback(), BN)
- assert tvm.ir.structural_equal(out, relay.op.nn.batch_norm(x, gamma, beta, mean, var, epsilon = 1e-5)[0])
+ assert tvm.ir.structural_equal(
+ out,
+ relay.op.nn.batch_norm(x, gamma, beta, mean, var, epsilon=1e-5)[0])
+
def test_quadruple_rewrite_dominator():
class DominatorRemovalCallback(DFPatternCallback):
def __init__(self):
self.inp = wildcard()
self.weight = wildcard()
-
is_conv2d = is_op('nn.conv2d')(self.inp, self.weight)
- is_unary_elemwise = (wildcard().has_attr({"TOpPattern": K_ELEMWISE}))(wildcard()) | is_op('add')(wildcard(), wildcard())
+ is_unary_elemwise = (wildcard().has_attr({"TOpPattern": K_ELEMWISE}))(
+ wildcard()) | is_op('add')(wildcard(), wildcard())
reduction = is_op('add')(wildcard(), wildcard())
self.pattern = dominates(is_conv2d, is_unary_elemwise, reduction)
tanh = relay.op.tanh(relu)
leaky_relu = relay.op.nn.leaky_relu(conv2d, alpha=0)
out = tanh + leaky_relu
-
    one = relay.op.nn.conv2d(inp, weight)
    two = relay.op.nn.conv2d(one, weight)
    three = relay.op.nn.conv2d(two, weight)
    four = relay.op.nn.conv2d(three, weight)
    assert tvm.ir.structural_equal(DominatorRemovalCallback().rewrite(out), four)
+
def algebraic_simplify(expr):
zero = (ExprPattern(relay.const(0)) | ExprPattern(relay.const(0.0)))
one = (ExprPattern(relay.const(1)) | ExprPattern(relay.const(1.0)))
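    # ExprPattern matches one specific literal expression, so the integer and
    # float forms of each constant need their own alternative.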
+
class ElwiseNullCallback(DFPatternCallback):
def callback(self, pre, post, node_map):
- return node_map[self.x][0]
+ return node_map[self.x][0] # pylint: disable=no-member

    class AddCallback(ElwiseNullCallback):
def __init__(self):
self.x = wildcard()
self.pattern = self.x + zero
-
+
    class SubCallback(ElwiseNullCallback):
        def __init__(self):
            self.x = wildcard()
            self.pattern = self.x - zero

    class ZeroDivCallback(ElwiseNullCallback):
        def __init__(self):
            self.x = zero
            self.pattern = self.x / wildcard()
- return rewrite([AddCallback(),
- SubCallback(),
- MulCallback(),
- DivCallback(),
- MulZeroCallback(),
- ZeroDivCallback()
- ], expr);
+ return rewrite([
+ AddCallback(),
+ SubCallback(),
+ MulCallback(),
+ DivCallback(),
+ MulZeroCallback(),
+ ZeroDivCallback()
+ ], expr)
+
def test_algebraic_simplify():
x = relay.Var('x')
- y = relay.Var('y')
+ y = relay.Var('y')
    one = relay.const(1)
    zero = relay.const(0)
    onef = relay.const(1.0)
    zerof = relay.const(0.0)
    assert algebraic_simplify(x + zero) == x
    assert algebraic_simplify(x + zerof) == x
assert algebraic_simplify(zero + x) == x
assert algebraic_simplify(zerof + x) == x
-
+
assert algebraic_simplify(x - zero) == x
assert algebraic_simplify(x - zerof) == x
-
+
assert algebraic_simplify(x * one) == x
assert algebraic_simplify(x * onef) == x
assert algebraic_simplify(one * x) == x
assert algebraic_simplify(onef * x) == x
assert algebraic_simplify(x * zero) == zero
assert algebraic_simplify(x * zerof) == zerof
-
+
assert algebraic_simplify(x / one) == x
assert algebraic_simplify(x / onef) == x
assert algebraic_simplify(zero / x) == zero
assert algebraic_simplify(zerof / x) == zerof
- assert tvm.ir.structural_equal(algebraic_simplify((x + zero * y) / one + (y * one) - zero / x), x + y)
+ assert tvm.ir.structural_equal(algebraic_simplify((x + zero * y) / one + (y * one) - zero / x),
+ x + y)
+
def test_double_partition():
# Pattern 1
for pat, label in [(relu_p, "conv_bias_relu"), (bias_add_p, "conv_bias")]:
partitioned = pat.partition(partitioned, {"Composite": label})
-
inpf = relay.var("input")
weightf = relay.var("weight")
biasf = relay.var("bias")
- func0 = relay.Function([inpf, weightf, biasf], relay.op.nn.relu(relay.op.nn.bias_add(relay.op.nn.conv2d(inpf, weightf), biasf))).with_attr("Composite", "conv_bias_relu").with_attr("PartitionedFromPattern","nn.conv2d_nn.bias_add_nn.relu_")
+ func0 = relay.Function(
+ [inpf, weightf, biasf],
+ relay.op.nn.relu(relay.op.nn.bias_add(
+ relay.op.nn.conv2d(inpf, weightf),
+ biasf))).with_attr("Composite",
+ "conv_bias_relu").with_attr("PartitionedFromPattern",
+ "nn.conv2d_nn.bias_add_nn.relu_")
inpf = relay.var("input")
weightf = relay.var("weight")
biasf = relay.var("bias")
- func1 = relay.Function([inpf, weightf, biasf], relay.op.nn.bias_add(relay.op.nn.conv2d(inpf, weightf), biasf)).with_attr("Composite", "conv_bias").with_attr("PartitionedFromPattern","nn.conv2d_nn.bias_add_")
+ func1 = relay.Function([inpf, weightf, biasf],
+ relay.op.nn.bias_add(relay.op.nn.conv2d(inpf, weightf),
+ biasf)).with_attr("Composite",
+ "conv_bias").with_attr(
+ "PartitionedFromPattern",
+ "nn.conv2d_nn.bias_add_")
expected = func1(func0(x, w, b), w2, b2)
assert tvm.ir.structural_equal(partitioned, expected)
+
def test_partition_dominator():
# Pattern
is_conv2d = is_op('nn.conv2d')(wildcard(), wildcard())
# Classic Diamond
inp = relay.var('input')
weight = relay.var('weight')
+
def generate_diamond(inp, weight):
conv2d = relay.op.nn.conv2d(inp, weight)
relu = relay.op.nn.relu(conv2d)
relu = relay.op.nn.relu(relu)
leaky_relu = relay.op.nn.leaky_relu(conv2d, alpha=0)
return relu + leaky_relu
- out = generate_diamond(inp*inp, weight*weight)
+
+ out = generate_diamond(inp * inp, weight * weight)
# Check
partitioned = diamond.partition(out)
-
+
i = relay.Var("input")
w = relay.Var("weight")
- f = relay.Function([i, w], generate_diamond(i, w)).with_attr("PartitionedFromPattern","nn.conv2d_nn.relu_nn.relu_nn.leaky_relu_add_")
- assert tvm.ir.structural_equal(partitioned, f(inp*inp, weight*weight))
+ f = relay.Function([i, w], generate_diamond(i, w)).with_attr(
+ "PartitionedFromPattern", "nn.conv2d_nn.relu_nn.relu_nn.leaky_relu_add_")
+ assert tvm.ir.structural_equal(partitioned, f(inp * inp, weight * weight))
+
def test_quadruple_partition_dominator():
# Pattern
is_conv2d = is_op('nn.conv2d')(wildcard(), wildcard())
- is_unary_elemwise = (wildcard().has_attr({"TOpPattern": K_ELEMWISE}))(wildcard()) | is_op('add')(wildcard(), wildcard())
+ is_unary_elemwise = (wildcard().has_attr({"TOpPattern": K_ELEMWISE}))(
+ wildcard()) | is_op('add')(wildcard(), wildcard())
reduction = is_op('add')(wildcard(), wildcard())
diamond = dominates(is_conv2d, is_unary_elemwise, reduction)
-
inp = relay.var('input')
weight = relay.var('weight')
+
# Classic Diamond
def classic_diamond(inp, weight):
conv2d = relay.op.nn.conv2d(inp, weight)
return tanh + leaky_relu
partitioned = diamond.partition(
- nested_diamond(
- single_branch(
- deeper_diamond(
- classic_diamond(inp, weight),
- weight),
- weight),
- weight
- )
- )
+ nested_diamond(single_branch(deeper_diamond(classic_diamond(inp, weight), weight), weight),
+ weight))
functions = []
partition_names = [
"nn.conv2d_nn.relu_nn.relu_nn.leaky_relu_add_",
- "nn.conv2d_nn.relu_nn.relu_tanh_nn.leaky_relu_add_",
- "nn.conv2d_nn.relu_nn.relu_tanh_add_",
+ "nn.conv2d_nn.relu_nn.relu_tanh_nn.leaky_relu_add_", "nn.conv2d_nn.relu_nn.relu_tanh_add_",
"nn.conv2d_nn.relu_add_tanh_nn.leaky_relu_add_"
]
for i, f in enumerate([classic_diamond, deeper_diamond, single_branch, nested_diamond]):
inpf = relay.var("input")
weightf = relay.var("weight")
- functions.append(relay.Function([inpf, weightf], f(inpf, weightf)).with_attr("PartitionedFromPattern", partition_names[i]))
-
- reference = functions[3](
- functions[2](
- functions[1](
- functions[0](inp, weight),
- weight),
- weight),
- weight
- )
+ functions.append(
+ relay.Function([inpf, weightf], f(inpf,
+ weightf)).with_attr("PartitionedFromPattern",
+ partition_names[i]))
+
+ reference = functions[3](functions[2](functions[1](functions[0](inp, weight), weight), weight),
+ weight)
assert tvm.ir.structural_equal(partitioned, reference)
+
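# Unfused batch-norm arithmetic, shared by the partition tests below to build
# both the input graphs and the expected reference functions.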
def get_BN(x, var, mean, beta, gamma, eps):
- return gamma * (x - mean)/relay.op.sqrt(var + eps) + beta
+ return gamma * (x - mean) / relay.op.sqrt(var + eps) + beta
+
def test_partition_batchnorm():
    x = relay.var('x')
    var = relay.var('var')
    mean = relay.var('mean')
    beta = relay.var('beta')
    gamma = relay.var('gamma')
    eps = relay.const(1e-5)
    BN = get_BN(x, var, mean, beta, gamma, eps)
-
xf = relay.var('xf')
varf = relay.var('varf')
    meanf = relay.var('meanf')
    betaf = relay.var('betaf')
    gammaf = relay.var('gammaf')
    epsf = relay.var('epsf')
    # Put the arguments in topological order for the reference
- f = relay.Function([gammaf, xf, meanf, varf, epsf, betaf], get_BN(xf, varf, meanf, betaf, gammaf, epsf)).with_attr("PartitionedFromPattern","subtract_multiply_add_sqrt_divide_add_")
+ f = relay.Function([gammaf, xf, meanf, varf, epsf, betaf],
+ get_BN(xf, varf, meanf, betaf, gammaf,
+ epsf)).with_attr("PartitionedFromPattern",
+ "subtract_multiply_add_sqrt_divide_add_")
partitioned = BatchnormCallback().pattern.partition(BN)
assert tvm.ir.structural_equal(partitioned, f(gamma, x, mean, var, eps, beta))
+
def test_partition_double_batchnorm():
    x = relay.var('x')
    var = relay.var('var')
    mean = relay.var('mean')
    beta = relay.var('beta')
gamma = relay.var('gamma')
eps = relay.const(1e-5)
-
- BN = gamma * (x - mean)/relay.op.sqrt(var + eps) + beta
- BN2 = gamma * (BN - mean)/relay.op.sqrt(var + eps) + beta
+
+ BN = gamma * (x - mean) / relay.op.sqrt(var + eps) + beta
+ BN2 = gamma * (BN - mean) / relay.op.sqrt(var + eps) + beta
xf = relay.var('xf')
    varf = relay.var('varf')
    meanf = relay.var('meanf')
    betaf = relay.var('betaf')
gammaf = relay.var('gammaf')
epsf = relay.var('epsf')
- f1 = relay.Function([gammaf, xf, meanf, varf, epsf, betaf], get_BN(xf, varf, meanf, betaf, gammaf, epsf)).with_attr("PartitionedFromPattern","subtract_multiply_add_sqrt_divide_add_")
+ f1 = relay.Function([gammaf, xf, meanf, varf, epsf, betaf],
+ get_BN(xf, varf, meanf, betaf, gammaf,
+ epsf)).with_attr("PartitionedFromPattern",
+ "subtract_multiply_add_sqrt_divide_add_")
# The partitioner doesn't replace duplicates, so we use two copies of the function
xf2 = relay.var('xf2')
    varf2 = relay.var('varf2')
    meanf2 = relay.var('meanf2')
    betaf2 = relay.var('betaf2')
gammaf2 = relay.var('gammaf2')
epsf2 = relay.var('epsf2')
- f2 = relay.Function([gammaf2, xf2, meanf2, varf2, epsf2, betaf2], get_BN(xf2, varf2, meanf2, betaf2, gammaf2, epsf2)).with_attr("PartitionedFromPattern","subtract_multiply_add_sqrt_divide_add_")
+ f2 = relay.Function([gammaf2, xf2, meanf2, varf2, epsf2, betaf2],
+ get_BN(xf2, varf2, meanf2, betaf2, gammaf2,
+ epsf2)).with_attr("PartitionedFromPattern",
+ "subtract_multiply_add_sqrt_divide_add_")
partitioned = BatchnormCallback().pattern.partition(BN2)
reference = f2(gamma, f1(gamma, x, mean, var, eps, beta), mean, var, eps, beta)
assert tvm.ir.structural_equal(partitioned, reference)
+
def test_partition_check():
pattern = is_op("nn.relu")(is_op("nn.conv2d")(wildcard(), wildcard()))
+
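    # partition() accepts an optional check callback: it receives the matched
    # sub-expression and can veto the partition by returning False.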
def check(pre):
return pre.args[0].attrs.data_layout == "NCHW"
    x = relay.var('x')
    w = relay.var('w')
    conv2d = relay.op.nn.conv2d(x, w)
    relu = relay.op.nn.relu(conv2d)

    xf = relay.var('x')
    wf = relay.var('weight')
conv2df = relay.op.nn.conv2d(xf, wf)
reluf = relay.op.nn.relu(conv2df)
- func = relay.Function([xf, wf], reluf).with_attr("PartitionedFromPattern", "nn.conv2d_nn.relu_")
+ func = relay.Function([xf, wf], reluf).with_attr("PartitionedFromPattern",
+ "nn.conv2d_nn.relu_")
reference = func(x, w)
    partitioned = pattern.partition(relu, check=check)
    assert tvm.ir.structural_equal(partitioned, reference)

    # A non-NCHW layout fails the check, so the expression is returned unchanged
    conv2d = relay.op.nn.conv2d(x, w, data_layout="NHWC")
    relu = relay.op.nn.relu(conv2d)
    assert relu == pattern.partition(relu, check=check)
+
def test_partition_check_types():
pattern = is_op("nn.relu")(is_op("nn.conv2d")(wildcard(), wildcard()))
+
def check(pre):
conv = pre.args[0]
return (conv.attrs.data_layout == "NCHW") and bool(conv.checked_type.shape[0] == 1)
    # A batch size other than 1 fails the check, so nothing is partitioned
    x = relay.var('x', shape=(2, 10, 10, 10))
    w = relay.var('w', shape=(10, 10, 3, 3))
    conv2d = relay.op.nn.conv2d(x, w)
    relu = relay.op.nn.relu(conv2d)
    relu = run_opt_pass(relu, relay.transform.InferType())
    assert relu == pattern.partition(relu, check=check)
+
def test_partition_option():
x = relay.var('x')
    w = relay.var('w')
    b = relay.var('b')

    conv2d_p = is_op('nn.conv2d')(wildcard(), wildcard())
    bias_p = conv2d_p.optional(lambda call: is_op('nn.bias_add')(call, wildcard()))
    pattern1 = is_op('nn.relu')(bias_p)

    conv2d_p = is_op('nn.conv2d')(wildcard(), wildcard())
    bias_p = is_op('nn.bias_add')(conv2d_p, wildcard())
    pattern2 = bias_p.optional(lambda call: is_op('nn.relu')(call))

    def conv_bias_relu(x, w, b):
        conv2d = relay.op.nn.conv2d(x, w)
        bias_add = relay.op.nn.bias_add(conv2d, b)
        relu = relay.op.nn.relu(bias_add)
        return relu
+
relu = conv_bias_relu(x, w, b)
xf = relay.var('x')
wf = relay.var('w')
bf = relay.var('b')
- func = relay.Function([xf, wf, bf], conv_bias_relu(xf, wf, bf)).with_attr("PartitionedFromPattern","nn.conv2d_nn.bias_add_nn.relu_")
+ func = relay.Function([xf, wf, bf],
+ conv_bias_relu(xf, wf, bf)).with_attr("PartitionedFromPattern",
+ "nn.conv2d_nn.bias_add_nn.relu_")
assert pattern1.match(relu)
assert tvm.ir.structural_equal(func(x, w, b), pattern1.partition(relu))
assert pattern2.match(relu)
assert tvm.ir.structural_equal(func(x, w, b), pattern2.partition(relu))
+
if __name__ == "__main__":
+ test_expr_pattern()
+ test_var_pattern()
+ test_constant_pattern()
+ test_wildcard_pattern()
+ test_CallPattern()
+ test_TuplePattern()
+ test_TupleGetItemPattern()
+ test_AltPattern()
+ test_TypePattern()
+ test_AttrPattern()
test_match_op()
test_no_match_op()
test_match_op_or()
- test_match_call()
- test_no_match_call()
    test_match_call_commutative()
    test_no_match_call_commutative()
+ test_match_call()
+ test_no_match_call()
+ test_match_option()
+ test_no_match_option()
+ test_match_const()
test_match_tuple()
test_no_match_tuple()
test_match_type()
test_no_match_type()
- test_match_attr()
- test_no_match_attr()
+ test_match_op_attr()
+ test_no_match_op_attr()
+ test_match_func_attr()
+ test_no_match_func_attr()
+ test_match_call_attr()
+ test_no_match_call_attr()
test_match_diamond()
test_no_match_diamond()
test_match_fake_diamond()
+ test_match_dominator()
+ test_not_match_dominator()
test_rewrite()
+ test_rewrite_func()
test_nested_rewrite()
+ test_not_fuse_multi_diamond()
test_fuse_batchnorm()
test_no_fuse_batchnorm()
test_fuse_double_batchnorm()
test_partial_fuse_double_batchnorm()
test_fuse_batchnorm_commutation()
- test_match_dominator()
- test_not_match_dominator()
+ test_quadruple_rewrite_dominator()
test_algebraic_simplify()
+ test_double_partition()
test_partition_dominator()
test_quadruple_partition_dominator()
test_partition_batchnorm()