Expr RewriteAnnotatedOps(const Expr& expr, int fallback_device) {
RewriteAnnotation rewrote = RewriteAnnotation();
- return rewrote.Rewrite(expr, fallback_device);
+ Expr new_expr = rewrote.Rewrite(expr, fallback_device);
+
+ // Remove OnDevice operators. Note that these operators are only present at the
+ // leaves after annotation. Therefore, we can simply reconstruct the
+ // Function/Expr by removing them directly.
+ if (const FunctionNode* fn = new_expr.as<FunctionNode>()) {
+ auto params = fn->params;
+ auto body = fn->body;
+ std::vector<Expr> new_body;
+ if (const TupleNode* tuple = body.as<TupleNode>()) {
+ for (const auto& field : tuple->fields) {
+ if (!IsOnDeviceNode(field.operator->())) {
+ new_body.push_back(field);
+ }
+ }
+ CHECK_GT(new_body.size(), 0U);
+ if (new_body.size() == 1) {
+ return FunctionNode::make(params, new_body[0], Type(nullptr),
+ fn->type_params, fn->attrs);
+ } else if (tuple->fields.size() == new_body.size()) {
+ return new_expr;
+ } else {
+ Tuple tuple_body = TupleNode::make(new_body);
+ return FunctionNode::make(params, tuple_body, Type(nullptr),
+ fn->type_params, fn->attrs);
+ }
+ } else {
+ return new_expr;
+ }
+ } else if (const TupleNode* tuple = new_expr.as<TupleNode>()) {
+ std::vector<Expr> new_fields;
+ for (const auto& field : tuple->fields) {
+ if (!IsOnDeviceNode(field.operator->())) {
+ new_fields.push_back(field);
+ }
+ }
+ CHECK_GT(new_fields.size(), 0U);
+ if (tuple->fields.size() == new_fields.size()) {
+ return new_fields.size() == 1 ? new_fields[0] : new_expr;
+ } else {
+ return new_fields.size() == 1 ? new_fields[0]
+ : TupleNode::make(new_fields);
+ }
+ } else {
+ return new_expr;
+ }
}
Map<Expr, Integer> CollectDeviceInfo(const Expr& expr) {
func = relay.ir_pass.infer_type(func)
func = relay.ir_pass.rewrite_annotated_ops(func,
ctx1.device_type)
- func = relay.ir_pass.infer_type(func)
- return relay.Function(relay.ir_pass.free_vars(func.body[2]),
- func.body[2])
+ return func
def expected():
add = relay.add(x, y)
assert relay.ir_pass.alpha_equal(annotated_func, expected_func)
+def test_annotate_expr():
+ ctx1 = tvm.context(1)
+ ctx2 = tvm.context(2)
+ x = relay.var("x", shape=(3,))
+ y = relay.var("y", shape=(3,))
+ z = relay.var("z", shape=(3,))
+
+ def annotated():
+ add = relay.add(x, y)
+ _add = relay.annotation.on_device(add, ctx1)
+ sub = relay.subtract(add, z)
+ _sub = relay.annotation.on_device(sub, ctx2)
+ expr = relay.Tuple([sub, _add, _sub])
+ expr = relay.ir_pass.infer_type(expr)
+ expr = relay.ir_pass.rewrite_annotated_ops(expr,
+ ctx1.device_type)
+ return expr
+
+ def expected():
+ add = relay.add(x, y)
+ copy_add_sub = relay.device_copy(add, ctx1, ctx2)
+ sub = relay.subtract(copy_add_sub, z)
+ return sub
+
+ annotated_expr = relay.ir_pass.infer_type(annotated())
+ expected_expr = relay.ir_pass.infer_type(expected())
+ assert relay.ir_pass.graph_equal(annotated_expr, expected_expr)
+
+
def test_annotate_all():
ctx1 = tvm.context(1)
ctx2 = tvm.context(2)
func = relay.ir_pass.infer_type(func)
func = relay.ir_pass.rewrite_annotated_ops(func,
ctx1.device_type)
- func = relay.ir_pass.infer_type(func)
- return relay.Function(relay.ir_pass.free_vars(func.body[2]),
- func.body[2])
+ return func
def expected():
add = relay.add(x, y)
expected_func = relay.ir_pass.infer_type(expected())
assert relay.ir_pass.alpha_equal(annotated_func, expected_func)
+
def test_annotate_none():
ctx1 = tvm.context(1)
ctx2 = tvm.context(2)
func = relay.ir_pass.infer_type(func)
func = relay.ir_pass.rewrite_annotated_ops(func,
tvm.context(3).device_type)
- func = relay.ir_pass.infer_type(func)
- return relay.Function(relay.ir_pass.free_vars(func.body[4]),
- func.body[4])
+ return func
def expected():
conv2d_1 = relay.nn.conv2d(
kernel_size=(3, 3),
padding=(1, 1))
- func = relay.Function([data1, weight, data2], conv2d_3)
+ func = relay.Function([data1, data2, weight], conv2d_3)
return func
def check_storage_and_device_types():
func = relay.ir_pass.infer_type(func)
func = relay.ir_pass.rewrite_annotated_ops(func,
cpu_ctx.device_type)
- func = relay.ir_pass.infer_type(func)
- return relay.Function(relay.ir_pass.free_vars(func.body[2]),
- func.body[2])
+ return func
def expected():
add = relay.add(x, y)
func = relay.ir_pass.infer_type(func)
func = relay.ir_pass.rewrite_annotated_ops(func,
cpu_ctx.device_type)
- func = relay.ir_pass.infer_type(func)
- return relay.Function(relay.ir_pass.free_vars(func.body[5]),
- func.body[5])
+ return func
annotated_func = annotated()
expected_func = get_func()
func = relay.ir_pass.infer_type(func)
func = relay.ir_pass.rewrite_annotated_ops(func,
dev_ctx.device_type)
- func = relay.ir_pass.infer_type(func)
- return relay.Function(relay.ir_pass.free_vars(func.body[1]),
- func.body[1])
+ return func
def expected():
add = relay.add(x, y)
func = relay.ir_pass.infer_type(func)
func = relay.ir_pass.rewrite_annotated_ops(func,
dev_ctx.device_type)
- func = relay.ir_pass.infer_type(func)
- return relay.Function(relay.ir_pass.free_vars(func.body[3]),
- func.body[3])
+ return func
def expected():
add = relay.add(a, b)
if __name__ == "__main__":
test_redundant_annotation()
+ test_annotate_expr()
test_annotate_all()
test_annotate_none()
test_conv_network()