if (var_map->find(pattern->name_hint()) == var_map->end()) {
// if we haven't encountered this var yet, make a new free var and associate
// it with the value at 'root'
- auto free_var = Var(pattern->name_hint(), Type());
+ auto free_var = Var(pattern->name_hint(), root->checked_type());
+ free_var->checked_type_ = root->checked_type();
var_map->Set(pattern->name_hint(), Array<Expr>({free_var, root}));
return std::move(free_var);
} else {
new_args.push_back(new_arg);
i++;
}
- return Call(root->op, new_args, root->attrs);
+ Call new_call = Call(root->op, new_args, root->attrs);
+ new_call->checked_type_ = root->checked_type();
+ return std::move(new_call);
}
Expr VisitExpr_(const CallNode* cn) {
auto new_e = this->Mutate(arg);
new_args.push_back(new_e);
}
- return Call(call->op, new_args, call->attrs);
+ Call new_call = Call(call->op, new_args, call->attrs);
+ new_call->checked_type_ = call->checked_type();
+ return std::move(new_call);
}
}
Expr expr = ExprMutator::VisitExpr_(cn);
call = Downcast<Call>(expr);
+ call->checked_type_ = cn->checked_type();
if (!call->op->IsInstance<OpNode>()) return std::move(call);
// only call patterns are supported
args.push_back(args_map[free_var->name_hint()][1]);
}
auto new_call = Call(f, args);
+ new_call->checked_type_ = call->checked_type();
return std::move(new_call);
}
return std::move(call);
assert tvm.ir.structural_equal(result, expected, map_free_vars=True)
+def test_type_check():
+    """Test that we can query tensor types in the 'check' function."""
+    def before():
+        # conv2d -> bias_add -> relu chain, the exact shape matched by
+        # make_conv_bias_relu_pattern().
+        x = relay.var('x', shape=(1, 10, 10, 10))
+        w = relay.var('w', shape=(10, 10, 3, 3))
+        b = relay.var('b', shape=(8,))
+        conv = relay.nn.conv2d(x,
+                               w,
+                               kernel_size=(3, 3),
+                               kernel_layout="OIHW",
+                               data_layout="NHWC")
+        bias = relay.nn.bias_add(conv, b)
+        relu = relay.nn.relu(bias)
+        return relay.Function([x, w, b], relu)
+
+    def _check_type_true(extract):
+        # 'extract' is the matched expression (the relu call); walk two call
+        # arguments down (relu -> bias_add -> conv2d) to reach the conv node
+        # and query its inferred type.
+        conv = extract.args[0].args[0]
+        typ = conv.checked_type
+        # The batch dimension of the input is 1, so this check accepts.
+        return bool(typ.shape[0] == 1)
+
+    def _check_type_false(extract):
+        # Same traversal as _check_type_true, but with the predicate
+        # inverted so the check rejects the match.
+        conv = extract.args[0].args[0]
+        typ = conv.checked_type
+        return bool(typ.shape[0] != 1)
+
+    pattern_table_true = [
+        ("conv_bias_relu", make_conv_bias_relu_pattern(), _check_type_true)
+    ]
+    pattern_table_false = [
+        ("conv_bias_relu", make_conv_bias_relu_pattern(), _check_type_false)
+    ]
+
+    # Failing check: nothing is merged, so the result should be structurally
+    # identical to the (type-inferred) input.
+    result = run_opt_pass(before(), relay.transform.MergeComposite(pattern_table_false))
+    expected = run_opt_pass(before(), relay.transform.InferType())
+    assert tvm.ir.structural_equal(result, expected, map_free_vars=True)
+
+    # Passing check: the chain is merged into a composite function.
+    result = run_opt_pass(before(), relay.transform.MergeComposite(pattern_table_true))
+    assert result.body.op.attrs["Composite"] == "conv_bias_relu"
+
+
if __name__ == "__main__":
    test_simple_merge()
    test_branch_merge()
    test_tuple_get_item_merge()
    test_pattern_with_check()
    test_diamond_not_merge()
+    # New test: type queries (checked_type) inside pattern 'check' functions.
+    test_type_check()