[RELAY][BYOC] Preserve type information in Merge Composite (#5640)
authormbaret <55580676+mbaret@users.noreply.github.com>
Fri, 22 May 2020 02:24:23 +0000 (03:24 +0100)
committerGitHub <noreply@github.com>
Fri, 22 May 2020 02:24:23 +0000 (11:24 +0900)
Keep the type information when extracting patterns
so that it can be used as part of 'check' functions.

Change-Id: I16cc70c3d013a794d2ceefb5bec815129c7b8825

src/relay/transforms/merge_composite.cc
tests/python/relay/test_pass_merge_composite.py

index 596e2a1..027e512 100644 (file)
@@ -46,7 +46,8 @@ class MergeCompositeWrapper : public ExprMutator {
     if (var_map->find(pattern->name_hint()) == var_map->end()) {
       // if we haven't encountered this var yet, make a new free var and associate
       // it with the value at 'root'
-      auto free_var = Var(pattern->name_hint(), Type());
+      auto free_var = Var(pattern->name_hint(), root->checked_type());
+      free_var->checked_type_ = root->checked_type();
       var_map->Set(pattern->name_hint(), Array<Expr>({free_var, root}));
       return std::move(free_var);
     } else {
@@ -147,7 +148,9 @@ class MergeCompositeWrapper : public ExprMutator {
       new_args.push_back(new_arg);
       i++;
     }
-    return Call(root->op, new_args, root->attrs);
+    Call new_call = Call(root->op, new_args, root->attrs);
+    new_call->checked_type_ = root->checked_type();
+    return std::move(new_call);
   }
 
   Expr VisitExpr_(const CallNode* cn) {
@@ -163,12 +166,15 @@ class MergeCompositeWrapper : public ExprMutator {
           auto new_e = this->Mutate(arg);
           new_args.push_back(new_e);
         }
-        return Call(call->op, new_args, call->attrs);
+        Call new_call = Call(call->op, new_args, call->attrs);
+        new_call->checked_type_ = call->checked_type();
+        return std::move(new_call);
       }
     }
 
     Expr expr = ExprMutator::VisitExpr_(cn);
     call = Downcast<Call>(expr);
+    call->checked_type_ = cn->checked_type();
     if (!call->op->IsInstance<OpNode>()) return std::move(call);
 
     // only call patterns are supported
@@ -189,6 +195,7 @@ class MergeCompositeWrapper : public ExprMutator {
         args.push_back(args_map[free_var->name_hint()][1]);
       }
       auto new_call = Call(f, args);
+      new_call->checked_type_ = call->checked_type();
       return std::move(new_call);
     }
     return std::move(call);
index 317bb42..3a79f6a 100644 (file)
@@ -803,6 +803,46 @@ def test_diamond_not_merge():
     assert tvm.ir.structural_equal(result, expected, map_free_vars=True)
 
 
+def test_type_check():
+    """Test that we can query tensor types in the 'check' function."""
+    def before():
+        x = relay.var('x', shape=(1, 10, 10, 10))
+        w = relay.var('w', shape=(10, 10, 3, 3))
+        b = relay.var('b', shape=(8,))
+        conv = relay.nn.conv2d(x,
+                               w,
+                               kernel_size=(3, 3),
+                               kernel_layout="OIHW",
+                               data_layout="NHWC")
+        bias = relay.nn.bias_add(conv, b)
+        relu = relay.nn.relu(bias)
+        return relay.Function([x, w, b], relu)
+
+    def _check_type_true(extract):
+        conv = extract.args[0].args[0]
+        typ = conv.checked_type
+        return bool(typ.shape[0] == 1)
+
+    def _check_type_false(extract):
+        conv = extract.args[0].args[0]
+        typ = conv.checked_type
+        return bool(typ.shape[0] != 1)
+
+    pattern_table_true = [
+        ("conv_bias_relu", make_conv_bias_relu_pattern(), _check_type_true)
+    ]
+    pattern_table_false = [
+        ("conv_bias_relu", make_conv_bias_relu_pattern(), _check_type_false)
+    ]
+
+    result = run_opt_pass(before(), relay.transform.MergeComposite(pattern_table_false))
+    expected = run_opt_pass(before(), relay.transform.InferType())
+    assert tvm.ir.structural_equal(result, expected, map_free_vars=True)
+
+    result = run_opt_pass(before(), relay.transform.MergeComposite(pattern_table_true))
+    assert result.body.op.attrs["Composite"] == "conv_bias_relu"
+
+
 if __name__ == "__main__":
     test_simple_merge()
     test_branch_merge()
@@ -814,3 +854,4 @@ if __name__ == "__main__":
     test_tuple_get_item_merge()
     test_pattern_with_check()
     test_diamond_not_merge()
+    test_type_check()