[Relay][Fix] Fix alter op layout when calling a global var (#4454)
author Haichen Shen <shenhaichen@gmail.com>
Tue, 10 Dec 2019 19:09:23 +0000 (11:09 -0800)
committer Yao Wang <kevinthesunwy@gmail.com>
Tue, 10 Dec 2019 19:09:23 +0000 (11:09 -0800)
* [Relay][Fix] Fix alter op layout when calling a global var

* add test case
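
* Background: CallInfer previously ran Downcast<Op>(call->op) on every call,
  so AlterOpLayout crashed whenever a function invoked a GlobalVar instead of
  an operator. A minimal sketch of the failing scenario before this fix
  (illustrative names, mirroring the new test below):

      from tvm import relay

      x = relay.var("x", shape=(1, 64, 56, 56))
      weight = relay.var("weight", shape=(64, 64, 3, 3))
      y = relay.nn.relu(relay.nn.conv2d(x, weight, channels=64,
                                        kernel_size=(3, 3), padding=(1, 1)))
      mod = relay.Module()
      foo = relay.GlobalVar("foo")
      mod[foo] = relay.Function([x, weight], y)
      # The call foo(x, weight) has a GlobalVar as its op, which CallInfer
      # used to downcast to Op unconditionally.
      mod["main"] = relay.Function([x, weight], foo(x, weight))
      mod = relay.transform.AlterOpLayout()(mod)  # crashed before this fix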

src/relay/pass/alter_op_layout.cc
tests/python/relay/test_pass_alter_op_layout.py

src/relay/pass/alter_op_layout.cc
index bbfb97c..d893d94 100644 (file)
@@ -161,6 +161,10 @@ std::tuple<Array<Layout>, Array<Layout>, bool> CallInfer(
     const Array<Layout>& old_in_layouts,
     const Array<Array<IndexExpr> > &old_in_shapes) {
   static auto finfer_layout = Op::GetAttr<FInferCorrectLayout>("FInferCorrectLayout");
+  // Skip layout inference when the callee is not an operator (e.g., a GlobalVar).
+  if (!call->op.as<OpNode>()) {
+    return std::make_tuple(Array<Layout>(nullptr), Array<Layout>(nullptr), false);
+  }
 
   Op op = Downcast<Op>(call->op);
   if (finfer_layout.count(op)) {
tests/python/relay/test_pass_alter_op_layout.py
index c1941c9..9ab582d 100644 (file)
@@ -931,6 +931,47 @@ def test_alter_layout_nhwc_nchw_arm():
 
     assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)
 
+def test_alter_op_with_global_var():
+    """Test directly replacing an operator with a new one"""
+    def before():
+        x = relay.var("x", shape=(1, 64, 56, 56))
+        weight = relay.var('weight', shape=(64, 64, 3, 3))
+        y = relay.nn.conv2d(x, weight,
+                            channels=64,
+                            kernel_size=(3, 3),
+                            padding=(1, 1))
+        y = relay.nn.relu(y)
+        mod = relay.Module()
+        foo = relay.GlobalVar('foo')
+        mod[foo] = relay.Function([x, weight], y)
+        mod["main"] = relay.Function([x, weight], foo(x, weight))
+        return mod
+
+    def alter_conv2d(attrs, inputs, tinfos):
+        data, weight = inputs
+        weight = relay.multiply(weight, relay.const(2.0, "float32"))
+        return relay.nn.conv2d(data, weight, **attrs)
+
+    def expected():
+        x = relay.var("x", shape=(1, 64, 56, 56))
+        weight = relay.var('weight', shape=(64, 64, 3, 3))
+        y = relay.nn.conv2d(x, relay.multiply(weight, relay.const(2.0, "float32")),
+                            channels=64,
+                            kernel_size=(3, 3),
+                            padding=(1, 1))
+        y = relay.nn.relu(y)
+        mod = relay.Module()
+        foo = relay.GlobalVar('foo')
+        mod[foo] = relay.Function([x, weight], y)
+        mod["main"] = relay.Function([x, weight], foo(x, weight))
+        return mod
+
+    with TempOpAttr("nn.conv2d", "FTVMAlterOpLayout", alter_conv2d):
+        a = before()
+        a = transform.AlterOpLayout()(a)
+        b = transform.InferType()(expected())
+
+    assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)
 
 if __name__ == "__main__":
     test_alter_op()
@@ -949,3 +990,4 @@ if __name__ == "__main__":
     test_alter_layout_pool()
     test_alter_layout_sum()
     test_alter_layout_nhwc_nchw_arm()
+    test_alter_op_with_global_var()
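
(With the __main__ block above, the new test can also be run standalone,
e.g. python tests/python/relay/test_pass_alter_op_layout.py, assuming a
TVM build from this era is importable.)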