const Array<Layout>& old_in_layouts,
const Array<Array<IndexExpr> > &old_in_shapes) {
static auto finfer_layout = Op::GetAttr<FInferCorrectLayout>("FInferCorrectLayout");
+ if (!call->op.as<OpNode>()) {
+ return std::make_tuple<>(Array<Layout>(nullptr), Array<Layout>(nullptr), false);
+ }
Op op = Downcast<Op>(call->op);
if (finfer_layout.count(op)) {
assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)
+def test_alter_op_with_global_var():
+ """Test directly replacing an operator with a new one"""
+ def before():
+ x = relay.var("x", shape=(1, 64, 56, 56))
+ weight = relay.var('weight', shape=(64, 64, 3, 3))
+ y = relay.nn.conv2d(x, weight,
+ channels=64,
+ kernel_size=(3, 3),
+ padding=(1, 1))
+ y = relay.nn.relu(y)
+ mod = relay.Module()
+ foo = relay.GlobalVar('foo')
+ mod[foo] = relay.Function([x, weight], y)
+ mod["main"] = relay.Function([x, weight], foo(x, weight))
+ return mod
+
+ def alter_conv2d(attrs, inputs, tinfos):
+ data, weight = inputs
+ weight = relay.multiply(weight, relay.const(2.0, "float32"))
+ return relay.nn.conv2d(data, weight, **attrs)
+
+ def expected():
+ x = relay.var("x", shape=(1, 64, 56, 56))
+ weight = relay.var('weight', shape=(64, 64, 3, 3))
+ y = relay.nn.conv2d(x, relay.multiply(weight, relay.const(2.0, "float32")),
+ channels=64,
+ kernel_size=(3, 3),
+ padding=(1, 1))
+ y = relay.nn.relu(y)
+ mod = relay.Module()
+ foo = relay.GlobalVar('foo')
+ mod[foo] = relay.Function([x, weight], y)
+ mod["main"] = relay.Function([x, weight], foo(x, weight))
+ return mod
+
+ with TempOpAttr("nn.conv2d", "FTVMAlterOpLayout", alter_conv2d):
+ a = before()
+ a = transform.AlterOpLayout()(a)
+ b = transform.InferType()(expected())
+
+ assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)
if __name__ == "__main__":
test_alter_op()
test_alter_layout_pool()
test_alter_layout_sum()
test_alter_layout_nhwc_nchw_arm()
+ test_alter_op_with_global_var()