auto matches = node_map[node->ref_];
for (auto match : matches) {
if (fuzzy_matches.count(match) == 0 && match.as<OpNode>() == nullptr &&
- match.as<FunctionNode>() == nullptr) {
+ match.as<FunctionNode>() == nullptr && !EmbedConst(match, node->ref_)) {
inputs[match] = Var(
"FunctionVar_" + std::to_string(graph_number_) + "_" + std::to_string(var_number),
NullValue<Type>());
auto extractor = MatchExtractor(inputs);
auto body = extractor.Mutate(expr);
- // Verify the pattern still holds, no longer valid if we're not embedding constants in the
- // graph, keep here for future debug CHECK(DFPatternMatcher(body).Match(pattern_, body));
+ // Verify the pattern still holds
+ CHECK(DFPatternMatcher(body).Match(pattern_, body));
group.function = Function(params, body, NullValue<Type>(), Array<TypeVar>());
group.name = extractor.GetName();
// Check to make sure we aren't overlapping with another group
CHECK_EQ(groups_[gid_].gid, gid_);
}
+ /*!
+  * \brief Decide whether a matched constant is embedded in the partitioned function body
+  *        or lifted into the arguments of the partitioned function.
+  *
+  * The decision depends on which pattern the expression matched:
+  *  - ConstantPattern(): embed.
+  *  - ExprPattern whose expression is itself a constant (e.g. relay.const(*)): embed.
+  *  - AltPattern: recurse into the left alternative if it matches the expression,
+  *    otherwise into the right alternative.
+  *  - Any other pattern (wildcard, VarPattern, etc.), or a non-constant expression: lift.
+  *
+  * \param expr The matched expression.
+  * \param pattern The pattern that expr was matched against.
+  * \return true if the constant should be embedded, false if it should be lifted.
+  */
+ bool EmbedConst(const Expr& expr, const DFPattern& pattern) {
+   // Only constants are candidates for embedding.
+   if (expr.as<ConstantNode>() == nullptr) {
+     return false;
+   }
+   if (pattern.as<ConstantPatternNode>() != nullptr) {
+     return true;
+   }
+   if (const auto* expr_pat = pattern.as<ExprPatternNode>()) {
+     return expr_pat->expr.as<ConstantNode>() != nullptr;
+   }
+   if (const auto* alt_pat = pattern.as<AltPatternNode>()) {
+     // Follow whichever side of the alternative actually matched the constant.
+     const bool left_matched = matcher_->Match(alt_pat->left, expr);
+     return EmbedConst(expr, left_matched ? alt_pat->left : alt_pat->right);
+   }
+   return false;
+ }
// Internal State
DFPattern pattern_;
std::vector<Group> groups_;
self.mean = wildcard()
self.beta = wildcard()
self.gamma = wildcard()
- self.eps = wildcard()
+ self.eps = ConstantPattern()
self.pattern = self.gamma * (self.x - self.mean) / is_op("sqrt")(self.var + self.eps) + \
self.beta
class ElwiseNullCallback(DFPatternCallback):
def callback(self, pre, post, node_map):
- return node_map[self.x][0] # pylint: disable=no-member
+ return node_map[self.x][0] # pylint: disable=no-member
class AddCallback(ElwiseNullCallback):
def __init__(self):
meanf = relay.var('meanf')
betaf = relay.var('betaf')
gammaf = relay.var('gammaf')
- epsf = relay.var('epsf')
# Put the arguments in topological order for the reference
- f = relay.Function([gammaf, xf, meanf, varf, epsf, betaf],
+ f = relay.Function([gammaf, xf, meanf, varf, betaf],
get_BN(xf, varf, meanf, betaf, gammaf,
- epsf)).with_attr("PartitionedFromPattern",
- "subtract_multiply_add_sqrt_divide_add_")
+ eps)).with_attr("PartitionedFromPattern",
+ "subtract_multiply_add_sqrt_divide_add_")
partitioned = BatchnormCallback().pattern.partition(BN)
- assert tvm.ir.structural_equal(partitioned, f(gamma, x, mean, var, eps, beta))
+ reference = f(gamma, x, mean, var, beta)
+ assert tvm.ir.structural_equal(partitioned, reference)
def test_partition_double_batchnorm():
meanf = relay.var('meanf')
betaf = relay.var('betaf')
gammaf = relay.var('gammaf')
- epsf = relay.var('epsf')
- f1 = relay.Function([gammaf, xf, meanf, varf, epsf, betaf],
+ f1 = relay.Function([gammaf, xf, meanf, varf, betaf],
get_BN(xf, varf, meanf, betaf, gammaf,
- epsf)).with_attr("PartitionedFromPattern",
- "subtract_multiply_add_sqrt_divide_add_")
+ eps)).with_attr("PartitionedFromPattern",
+ "subtract_multiply_add_sqrt_divide_add_")
# The partitioner doesn't replace duplicates, so we use two copies of the function
xf2 = relay.var('xf2')
varf2 = relay.var('varf2')
meanf2 = relay.var('meanf2')
betaf2 = relay.var('betaf2')
gammaf2 = relay.var('gammaf2')
- epsf2 = relay.var('epsf2')
- f2 = relay.Function([gammaf2, xf2, meanf2, varf2, epsf2, betaf2],
+ f2 = relay.Function([gammaf2, xf2, meanf2, varf2, betaf2],
get_BN(xf2, varf2, meanf2, betaf2, gammaf2,
- epsf2)).with_attr("PartitionedFromPattern",
- "subtract_multiply_add_sqrt_divide_add_")
+ eps)).with_attr("PartitionedFromPattern",
+ "subtract_multiply_add_sqrt_divide_add_")
partitioned = BatchnormCallback().pattern.partition(BN2)
- reference = f2(gamma, f1(gamma, x, mean, var, eps, beta), mean, var, eps, beta)
+ reference = f2(gamma, f1(gamma, x, mean, var, beta), mean, var, beta)
assert tvm.ir.structural_equal(partitioned, reference)
assert relu == pattern.partition(relu, check=check)
+def conv_bias_relu(x, w, b):
+ conv2d = relay.op.nn.conv2d(x, w)
+ bias_add = relay.op.nn.bias_add(conv2d, b)
+ relu = relay.op.nn.relu(bias_add)
+ return relu
+
+
def test_partition_option():
x = relay.var('x')
w = relay.var('w')
bias = is_op('nn.bias_add')(conv2d, wildcard())
pattern2 = bias.optional(lambda x: is_op('nn.relu')(x))
- def conv_bias_relu(x, w, b):
- conv2d = relay.op.nn.conv2d(x, w)
- bias_add = relay.op.nn.bias_add(conv2d, b)
- relu = relay.op.nn.relu(bias_add)
- return relu
-
relu = conv_bias_relu(x, w, b)
xf = relay.var('x')
out = rewrite(TestRewrite(), mod['tensor_concatenate_int64'])
assert tvm.ir.structural_equal(mod['tensor_concatenate_int64'], out)
+def test_partition_constant_embedding():
+ x = relay.var('x')
+ w = relay.var('w')
+ wc = relay.const(1)
+ b = relay.var('b')
+
+ xf = relay.var('x')
+ wf = relay.var('w')
+ bf = relay.var('b')
+ embeded_func = relay.Function([xf, bf],
+ conv_bias_relu(xf, wc,
+ bf)).with_attr("PartitionedFromPattern",
+ "nn.conv2d_nn.bias_add_nn.relu_")
+ xf = relay.var('x')
+ wf = relay.var('w')
+ bf = relay.var('b')
+ lifted_func = relay.Function([xf, wf, bf],
+ conv_bias_relu(xf, wf,
+ bf)).with_attr("PartitionedFromPattern",
+ "nn.conv2d_nn.bias_add_nn.relu_")
+ relu = conv_bias_relu(x, w, b)
+ reluc = conv_bias_relu(x, wc, b)
+
+ # Check lifting of wildcard matches
+ pattern = is_op('nn.relu')(is_op('nn.bias_add')(is_op('nn.conv2d')(wildcard(), wildcard()),
+ wildcard()))
+ assert tvm.ir.structural_equal(lifted_func(x, w, b), pattern.partition(relu))
+ assert tvm.ir.structural_equal(lifted_func(x, wc, b), pattern.partition(reluc))
+
+ # Check lifting of input matches
+ pattern = is_op('nn.relu')(is_op('nn.bias_add')(is_op('nn.conv2d')(wildcard(), is_input()),
+ wildcard()))
+ assert tvm.ir.structural_equal(lifted_func(x, w, b), pattern.partition(relu))
+ assert tvm.ir.structural_equal(reluc, pattern.partition(reluc)) #Constants are not Inputs
+
+ # Check embedding of constant matches
+ pattern = is_op('nn.relu')(is_op('nn.bias_add')(is_op('nn.conv2d')(wildcard(),
+ ConstantPattern()),
+ wildcard()))
+ assert tvm.ir.structural_equal(relu, pattern.partition(relu))
+ assert tvm.ir.structural_equal(embeded_func(x, b), pattern.partition(reluc))
+
+ # Check embedding of constant ExprPatterns
+ pattern = is_op('nn.relu')(is_op('nn.bias_add')(is_op('nn.conv2d')(wildcard(),
+ ExprPattern(wc)),
+ wildcard()))
+ assert tvm.ir.structural_equal(relu, pattern.partition(relu))
+ assert tvm.ir.structural_equal(embeded_func(x, b), pattern.partition(reluc))
+
+ # Check lifting/embedding of Alt matches
+ pattern = is_op('nn.relu')(is_op('nn.bias_add')(is_op('nn.conv2d')(wildcard(), is_input()
+ | ConstantPattern()),
+ wildcard()))
+ assert tvm.ir.structural_equal(lifted_func(x, w, b), pattern.partition(relu))
+ assert tvm.ir.structural_equal(embeded_func(x, b), pattern.partition(reluc))
+
+ # Check lifting/embedding of Alt matches with the other ordering
+ pattern = is_op('nn.relu')(is_op('nn.bias_add')(is_op('nn.conv2d')(
+ wildcard(), ConstantPattern() | is_input()), wildcard()))
+ assert tvm.ir.structural_equal(lifted_func(x, w, b), pattern.partition(relu))
+ assert tvm.ir.structural_equal(embeded_func(x, b), pattern.partition(reluc))
+
+
if __name__ == "__main__":
test_expr_pattern()
test_var_pattern()
test_partition_check_types()
test_partition_option()
test_match_match()
+ test_partition_constant_embedding()