[Relay][heterogeneous pass] remove on_device op after annotation (#3204)
author Zhi <5145158+zhiics@users.noreply.github.com>
Tue, 21 May 2019 19:53:58 +0000 (12:53 -0700)
committer Jared Roesch <roeschinc@gmail.com>
Tue, 21 May 2019 19:53:58 +0000 (12:53 -0700)
* remove on_device op after annotation

* Update src/relay/pass/device_annotation.cc

Co-Authored-By: MORINAGA <34588258+imorinaga@users.noreply.github.com>
src/relay/pass/device_annotation.cc
tests/python/relay/test_pass_annotation.py

index 0139cc9..8807f6d 100644 (file)
@@ -485,7 +485,52 @@ class DeviceInfo {
 
 Expr RewriteAnnotatedOps(const Expr& expr, int fallback_device) {
   RewriteAnnotation rewrote = RewriteAnnotation();
-  return rewrote.Rewrite(expr, fallback_device);
+  Expr new_expr = rewrote.Rewrite(expr, fallback_device);
+
+  // Remove OnDevice operators. After annotation they are assumed to appear only
+  // as direct fields of the top-level tuple (the function body or the expression
+  // itself), so the Function/Expr is rebuilt from the remaining fields.
+  if (const FunctionNode* fn = new_expr.as<FunctionNode>()) {
+    auto params = fn->params;
+    auto body = fn->body;
+    std::vector<Expr> new_body;
+    if (const TupleNode* tuple = body.as<TupleNode>()) {
+      for (const auto& field : tuple->fields) {
+        if (!IsOnDeviceNode(field.operator->())) {
+          new_body.push_back(field);
+        }
+      }
+      CHECK_GT(new_body.size(), 0U);
+      if (new_body.size() == 1) {
+        return FunctionNode::make(params, new_body[0], Type(nullptr),
+                                  fn->type_params, fn->attrs);
+      } else if (tuple->fields.size() == new_body.size()) {
+          return new_expr;
+      } else {
+        Tuple tuple_body = TupleNode::make(new_body);
+        return FunctionNode::make(params, tuple_body, Type(nullptr),
+                                  fn->type_params, fn->attrs);
+      }
+    } else {
+      return new_expr;
+    }
+  } else if (const TupleNode* tuple = new_expr.as<TupleNode>()) {
+    std::vector<Expr> new_fields;
+    for (const auto& field : tuple->fields) {
+      if (!IsOnDeviceNode(field.operator->())) {
+        new_fields.push_back(field);
+      }
+    }
+    CHECK_GT(new_fields.size(), 0U);
+    if (tuple->fields.size() == new_fields.size()) {
+      return new_fields.size() == 1 ? new_fields[0] : new_expr;
+    } else {
+      return new_fields.size() == 1 ? new_fields[0]
+                                    : TupleNode::make(new_fields);
+    }
+  } else {
+    return new_expr;
+  }
 }
 
 Map<Expr, Integer> CollectDeviceInfo(const Expr& expr) {
index 9a77d2f..98cf0f1 100644 (file)
@@ -42,9 +42,7 @@ def test_redundant_annotation():
         func = relay.ir_pass.infer_type(func)
         func = relay.ir_pass.rewrite_annotated_ops(func,
                                                    ctx1.device_type)
-        func = relay.ir_pass.infer_type(func)
-        return relay.Function(relay.ir_pass.free_vars(func.body[2]),
-                              func.body[2])
+        return func
 
     def expected():
         add = relay.add(x, y)
@@ -58,6 +56,35 @@ def test_redundant_annotation():
     assert relay.ir_pass.alpha_equal(annotated_func, expected_func)
 
 
+def test_annotate_expr():
+    ctx1 = tvm.context(1)
+    ctx2 = tvm.context(2)
+    x = relay.var("x", shape=(3,))
+    y = relay.var("y", shape=(3,))
+    z = relay.var("z", shape=(3,))
+
+    def annotated():
+        add = relay.add(x, y)
+        _add = relay.annotation.on_device(add, ctx1)
+        sub = relay.subtract(add, z)
+        _sub = relay.annotation.on_device(sub, ctx2)
+        expr = relay.Tuple([sub, _add, _sub])  # real output plus both on_device annotations
+        expr = relay.ir_pass.infer_type(expr)
+        expr = relay.ir_pass.rewrite_annotated_ops(expr,
+                                                   ctx1.device_type)
+        return expr
+
+    def expected():
+        add = relay.add(x, y)
+        copy_add_sub = relay.device_copy(add, ctx1, ctx2)  # add annotated on ctx1, sub on ctx2
+        sub = relay.subtract(copy_add_sub, z)
+        return sub  # expected result is the bare op, not a tuple
+
+    annotated_expr = relay.ir_pass.infer_type(annotated())
+    expected_expr = relay.ir_pass.infer_type(expected())
+    assert relay.ir_pass.graph_equal(annotated_expr, expected_expr)
+
+
 def test_annotate_all():
     ctx1 = tvm.context(1)
     ctx2 = tvm.context(2)
@@ -77,9 +104,7 @@ def test_annotate_all():
         func = relay.ir_pass.infer_type(func)
         func = relay.ir_pass.rewrite_annotated_ops(func,
                                                    ctx1.device_type)
-        func = relay.ir_pass.infer_type(func)
-        return relay.Function(relay.ir_pass.free_vars(func.body[2]),
-                              func.body[2])
+        return func
 
     def expected():
         add = relay.add(x, y)
@@ -91,6 +116,7 @@ def test_annotate_all():
     expected_func = relay.ir_pass.infer_type(expected())
     assert relay.ir_pass.alpha_equal(annotated_func, expected_func)
 
+
 def test_annotate_none():
     ctx1 = tvm.context(1)
     ctx2 = tvm.context(2)
@@ -174,9 +200,7 @@ def test_conv_network():
         func = relay.ir_pass.infer_type(func)
         func = relay.ir_pass.rewrite_annotated_ops(func,
                                                    tvm.context(3).device_type)
-        func = relay.ir_pass.infer_type(func)
-        return relay.Function(relay.ir_pass.free_vars(func.body[4]),
-                              func.body[4])
+        return func
 
     def expected():
         conv2d_1 = relay.nn.conv2d(
@@ -202,7 +226,7 @@ def test_conv_network():
             kernel_size=(3, 3),
             padding=(1, 1))
 
-        func = relay.Function([data1, weight, data2], conv2d_3)
+        func = relay.Function([data1, data2, weight], conv2d_3)
         return func
 
     def check_storage_and_device_types():
@@ -306,9 +330,7 @@ def run_fusible_network(dev, tgt):
             func = relay.ir_pass.infer_type(func)
             func = relay.ir_pass.rewrite_annotated_ops(func,
                                                        cpu_ctx.device_type)
-            func = relay.ir_pass.infer_type(func)
-            return relay.Function(relay.ir_pass.free_vars(func.body[2]),
-                                  func.body[2])
+            return func
 
         def expected():
             add = relay.add(x, y)
@@ -358,9 +380,7 @@ def run_fusible_network(dev, tgt):
             func = relay.ir_pass.infer_type(func)
             func = relay.ir_pass.rewrite_annotated_ops(func,
                                                        cpu_ctx.device_type)
-            func = relay.ir_pass.infer_type(func)
-            return relay.Function(relay.ir_pass.free_vars(func.body[5]),
-                                  func.body[5])
+            return func
 
         annotated_func = annotated()
         expected_func = get_func()
@@ -386,9 +406,7 @@ def run_fusible_network(dev, tgt):
             func = relay.ir_pass.infer_type(func)
             func = relay.ir_pass.rewrite_annotated_ops(func,
                                                        dev_ctx.device_type)
-            func = relay.ir_pass.infer_type(func)
-            return relay.Function(relay.ir_pass.free_vars(func.body[1]),
-                                  func.body[1])
+            return func
 
         def expected():
             add = relay.add(x, y)
@@ -462,9 +480,7 @@ def run_unpropagatable_graph(dev, tgt):
         func = relay.ir_pass.infer_type(func)
         func = relay.ir_pass.rewrite_annotated_ops(func,
                                                    dev_ctx.device_type)
-        func = relay.ir_pass.infer_type(func)
-        return relay.Function(relay.ir_pass.free_vars(func.body[3]),
-                              func.body[3])
+        return func
         
     def expected():    
         add = relay.add(a, b)
@@ -506,6 +522,7 @@ def test_check_run():
  
 if __name__ == "__main__":
     test_redundant_annotation()
+    test_annotate_expr()
     test_annotate_all()
     test_annotate_none()
     test_conv_network()