* Relay operators map to a single external operator.
*/
-#include <tvm/te/operation.h>
#include <tvm/relay/analysis.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/op_attr_types.h>
#include <tvm/relay/transform.h>
+#include <tvm/te/operation.h>
namespace tvm {
namespace relay {
class MergeCompositeWrapper : public ExprMutator {
public:
- explicit MergeCompositeWrapper(const std::string& pattern_name, const Expr& pattern)
- : pattern_name_(pattern_name), pattern_(pattern) {}
+ explicit MergeCompositeWrapper(const std::string& pattern_name, const Expr& pattern,
+ const PackedFunc& check)
+ : pattern_name_(pattern_name), pattern_(pattern), check_(check) {}
Expr ExtractPattern(const Var& pattern, const Expr& root,
- Map<std::string, Array<Expr>>* var_map) {
+ Map<std::string, Array<Expr>>* var_map) {
if (var_map->find(pattern->name_hint()) == var_map->end()) {
// if we haven't encountered this var yet, make a new free var and associate
// it with the value at 'root'
}
Expr ExtractPattern(const Constant& pattern, const Expr& root,
- Map<std::string, Array<Expr>>* var_map) {
+ Map<std::string, Array<Expr>>* var_map) {
return root;
}
Expr ExtractPattern(const TupleGetItem& pattern, const Expr& root,
- Map<std::string, Array<Expr>>* var_map, Map<Expr, Expr>* call_map) {
+ Map<std::string, Array<Expr>>* var_map, Map<Expr, Expr>* call_map) {
if (!root->IsInstance<TupleGetItemNode>()) {
return Expr();
}
if (pattern->index != root_node->index) {
return Expr();
}
- if (pattern->tuple->IsInstance<CallNode>() &&
- root_node->tuple->IsInstance<CallNode>()) {
+ if (pattern->tuple->IsInstance<CallNode>() && root_node->tuple->IsInstance<CallNode>()) {
Expr new_arg;
if (call_map->find(pattern->tuple) != call_map->end()) {
new_arg = (*call_map)[pattern->tuple];
} else {
- new_arg = ExtractPattern(Downcast<Call>(pattern->tuple),
- Downcast<Call>(root_node->tuple),
+ new_arg = ExtractPattern(Downcast<Call>(pattern->tuple), Downcast<Call>(root_node->tuple),
var_map, call_map);
call_map->Set(pattern->tuple, new_arg);
}
* and free variables. The free variables indicate where the pattern can 'attach' in your
* graph. This function takes the final call node of the pattern and the call node currently
* being traversed in the Relay graph. It traverses through the pattern in lockstep with call node
- * from the graph (referred to as the 'root' node here) to check they're identical. If at any point
- * they differ, an empty expression is returned to signify the extract failed. If a free var is
- * reached in the pattern, the corresponding value in the root is associated with the name of the
- * free var (via the var_map) so that when we construct the composite function, the inputs match
- * up correctly with the rest of the graph. The return value of this function when successful is
- * a new Relay expression ready to be wrapped into a composite function.
+ * from the graph (referred to as the 'root' node here) to check they're identical. If at any
+ * point they differ, an empty expression is returned to signify the extract failed. If a free var
+ * is reached in the pattern, the corresponding value in the root is associated with the name of
+ * the free var (via the var_map) so that when we construct the composite function, the inputs
+ * match up correctly with the rest of the graph. The return value of this function when
+ * successful is a new Relay expression ready to be wrapped into a composite function.
*/
- Expr ExtractPattern(const Call& pattern, const Call& root,
- Map<std::string, Array<Expr>>* var_map, Map<Expr, Expr>* call_map) {
+ Expr ExtractPattern(const Call& pattern, const Call& root, Map<std::string, Array<Expr>>* var_map,
+ Map<Expr, Expr>* call_map) {
// check to make sure both calls are to operators (not functions)
- if (!pattern->op->IsInstance<OpNode>() || !root->op->IsInstance<OpNode>())
- return Expr();
- if (pattern->op.as<OpNode>()->name != root->op.as<OpNode>()->name)
- return Expr();
+ if (!pattern->op->IsInstance<OpNode>() || !root->op->IsInstance<OpNode>()) return Expr();
+ if (pattern->op.as<OpNode>()->name != root->op.as<OpNode>()->name) return Expr();
unsigned int i = 0;
Array<Expr> new_args;
return Expr();
}
// if it's a call node, recursively call this function
- new_arg = ExtractPattern(Downcast<Call>(arg),
- Downcast<Call>(root->args[i]),
- var_map, call_map);
+ new_arg =
+ ExtractPattern(Downcast<Call>(arg), Downcast<Call>(root->args[i]), var_map, call_map);
call_map->Set(arg, new_arg);
}
} else if (arg->IsInstance<VarNode>()) {
// if there's a var in the pattern, it must be a free var
// so call the function to update the var_map
- new_arg = ExtractPattern(Downcast<Var>(arg),
- root->args[i],
- var_map);
+ new_arg = ExtractPattern(Downcast<Var>(arg), root->args[i], var_map);
} else if (arg->IsInstance<ConstantNode>()) {
// if there's a constant, simply get the corresponding
// value of the constant from the root
- new_arg = ExtractPattern(Downcast<Constant>(arg),
- root->args[i],
- var_map);
+ new_arg = ExtractPattern(Downcast<Constant>(arg), root->args[i], var_map);
} else if (arg->IsInstance<TupleGetItemNode>()) {
- new_arg = ExtractPattern(Downcast<TupleGetItem>(arg),
- root->args[i],
- var_map, call_map);
+ new_arg = ExtractPattern(Downcast<TupleGetItem>(arg), root->args[i], var_map, call_map);
}
if (!new_arg.defined()) {
return Expr();
if (call->op->IsInstance<FunctionNode>()) {
Function func = Downcast<Function>(call->op);
CHECK(func.defined());
- const auto name_node =
- func->GetAttr<tir::StringImm>(attr::kComposite);
+ const auto name_node = func->GetAttr<tir::StringImm>(attr::kComposite);
// don't step into existing composite functions
if (name_node.defined() && name_node->value != "") {
tvm::Array<tvm::relay::Expr> new_args;
Expr expr = ExprMutator::VisitExpr_(cn);
call = Downcast<Call>(expr);
- if (!call->op->IsInstance<OpNode>())
- return std::move(call);
+ if (!call->op->IsInstance<OpNode>()) return std::move(call);
// only call patterns are supported
Call pattern = Downcast<Call>(pattern_);
Map<std::string, Array<Expr>> args_map;
Map<Expr, Expr> call_map;
auto extract = ExtractPattern(pattern, call, &args_map, &call_map);
- if (extract.defined()) {
+ if (extract.defined() && static_cast<bool>(check_(extract))) {
auto free_vars = FreeVars(extract);
// make the composite function
auto f = Function(free_vars, extract, call->checked_type_, {}, DictAttrs());
std::string pattern_name_;
/*! \brief The pattern to match */
Expr pattern_;
+ /*! \brief The function to check whether an extract is supported */
+ PackedFunc check_;
};
-Expr MergeComposite(const Expr& expr,
- const Array<tir::StringImm>& pattern_names, const Array<Expr>& patterns) {
+Expr MergeComposite(const Expr& expr, const Array<tir::StringImm>& pattern_names,
+ const Array<Expr>& patterns, const std::vector<PackedFunc>& checks) {
CHECK_EQ(pattern_names.size(), patterns.size());
Expr merged_expr = expr;
// merge the patterns one-by-one in order
for (size_t i = 0; i < patterns.size(); i++) {
std::string pattern_name = pattern_names[i]->value;
Expr pattern = patterns[i];
- merged_expr = MergeCompositeWrapper(pattern_name, pattern).Mutate(merged_expr);
+ PackedFunc check = checks[i];
+ merged_expr = MergeCompositeWrapper(pattern_name, pattern, check).Mutate(merged_expr);
}
return merged_expr;
}
namespace transform {
Pass MergeComposite(const tvm::Array<tir::StringImm>& pattern_names,
- const tvm::Array<Expr>& patterns) {
+ const tvm::Array<Expr>& patterns, const std::vector<PackedFunc>& checks) {
runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
[=](Function f, IRModule m, PassContext pc) {
return Downcast<Function>(
- relay::merge_composite::MergeComposite(f, pattern_names, patterns));
+ relay::merge_composite::MergeComposite(f, pattern_names, patterns, checks));
};
auto func_pass = CreateFunctionPass(pass_func, 0, "MergeComposite", {});
return func_pass;
}
-TVM_REGISTER_GLOBAL("relay._transform.MergeComposite")
-.set_body_typed(MergeComposite);
+TVM_REGISTER_GLOBAL("relay._transform.MergeComposite").set_body([](TVMArgs args, TVMRetValue* rv) {
+ tvm::Array<tir::StringImm> pattern_names = args[0];
+ tvm::Array<Expr> patterns = args[1];
+ std::vector<PackedFunc> checks;
+ for (int i = 2; i < args.size(); i++) {
+ checks.push_back(args[i]);
+ }
+ *rv = MergeComposite(pattern_names, patterns, checks);
+});
} // namespace transform
assert tvm.ir.structural_equal(expected, result)
+def test_composite_function():
+ def before():
+ a = relay.var('a', shape=(10, 10))
+ b = relay.var('b', shape=(10, 10))
+
+ # add_relu function
+ in_1 = relay.var('in_1', shape=(10, 10))
+ in_2 = relay.var('in_2', shape=(10, 10))
+ add_node = relay.add(in_1, in_2)
+ relu_node = relay.nn.relu(add_node)
+ add_relu = relay.Function([in_1, in_2], relu_node)
+ add_relu = add_relu.with_attr("Composite", tvm.tir.StringImm("test.add_relu"))
+
+ # merged function
+ r = relay.Call(add_relu, [a, b])
+ f = relay.Function([a, b], r)
+ mod = tvm.IRModule.from_expr(f)
+ return mod
+
+ def after():
+ a = relay.var('a', shape=(10, 10))
+ b = relay.var('b', shape=(10, 10))
+
+ # add_relu function
+ in_1 = relay.var('in_1', shape=(10, 10))
+ in_2 = relay.var('in_2', shape=(10, 10))
+ add_node = relay.add(in_1, in_2)
+ relu_node = relay.nn.relu(add_node)
+ add_relu = relay.Function([in_1, in_2], relu_node)
+ add_relu = add_relu.with_attr("Composite", tvm.tir.StringImm("test.add_relu"))
+
+ # merged function
+ cb_1 = relay.annotation.compiler_begin(a, "test")
+ cb_2 = relay.annotation.compiler_begin(b, "test")
+ r = relay.Call(add_relu, [cb_1, cb_2])
+ ce_1 = relay.annotation.compiler_end(r, "test")
+ f = relay.Function([a, b], ce_1)
+ mod = tvm.IRModule.from_expr(f)
+ return mod
+
+ result = transform.AnnotateTarget("test")(before())
+ expected = transform.InferType()(after())
+ assert tvm.ir.structural_equal(expected, result)
+
+
if __name__ == "__main__":
test_multiple_ends()
test_extern_dnnl()
test_extern_dnnl_mobilenet()
+ test_composite_function()
assert tvm.ir.structural_equal(result, expected, map_free_vars=True)
+def test_pattern_with_check():
+ def before():
+ x = relay.var('x', shape=(1, 10, 10, 10))
+ w = relay.var('w', shape=(10, 10, 3, 3))
+ b = relay.var('b', shape=(8,))
+ conv = relay.nn.conv2d(x,
+ w,
+ kernel_size=(3, 3),
+ kernel_layout="OIHW",
+ data_layout="NHWC")
+ bias = relay.nn.bias_add(conv, b)
+ relu = relay.nn.relu(bias)
+ return relay.Function([x, w, b], relu)
+
+ def _check_true(extract):
+ conv = extract.args[0].args[0]
+ return conv.attrs.data_layout == "NHWC"
+
+ def _check_false(extract):
+ conv = extract.args[0].args[0]
+ return conv.attrs.data_layout == "NCHW"
+
+ pattern_table_true = [
+ ("conv_bias_relu", make_conv_bias_relu_pattern(), _check_true)
+ ]
+ pattern_table_false = [
+ ("conv_bias_relu", make_conv_bias_relu_pattern(), _check_false)
+ ]
+
+ result = run_opt_pass(before(), relay.transform.MergeComposite(pattern_table_false))
+ expected = run_opt_pass(before(), relay.transform.InferType())
+ assert tvm.ir.structural_equal(result, expected, map_free_vars=True)
+
+ result = run_opt_pass(before(), relay.transform.MergeComposite(pattern_table_true))
+ assert result.body.op.attrs["Composite"] == "conv_bias_relu"
+
+
if __name__ == "__main__":
test_simple_merge()
test_branch_merge()
test_multiple_input_subgraphs()
test_reuse_call_merge()
test_tuple_get_item_merge()
+ test_pattern_with_check()