From f397fead73bae20ca9abd2fc69df5723e065acde Mon Sep 17 00:00:00 2001
From: Wuwei Lin
Date: Thu, 24 Jan 2019 12:50:21 +0800
Subject: [PATCH] [RELAY] Fix ops in packed layout (#2472)

* [RELAY] Fix ops in packed layout

* Fix style
---
 src/relay/op/nn/pooling.cc                      | 6 +++++-
 src/relay/op/tensor/transform.cc                | 3 ++-
 tests/python/relay/test_pass_alter_op_layout.py | 4 ++++
 3 files changed, 11 insertions(+), 2 deletions(-)

diff --git a/src/relay/op/nn/pooling.cc b/src/relay/op/nn/pooling.cc
index 6cf37668c..8fd33e1f3 100644
--- a/src/relay/op/nn/pooling.cc
+++ b/src/relay/op/nn/pooling.cc
@@ -83,7 +83,11 @@ bool Pool2DRel(const Array<Type>& types,
     return false;
   }
 
-  std::vector<IndexExpr> oshape({dshape[0], dshape[1], dshape[2], dshape[3]});
+  std::vector<IndexExpr> oshape;
+  for (const auto& e : dshape) {
+    oshape.push_back(e);
+  }
+
   if (param->ceil_mode) {
     oshape[hidx] = ((dshape[hidx] + pad_h - param->pool_size[0] +
                      param->strides[0] - 1) / param->strides[0]) + 1;
diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc
index 704324533..6d583bfd6 100644
--- a/src/relay/op/tensor/transform.cc
+++ b/src/relay/op/tensor/transform.cc
@@ -76,7 +76,8 @@ RELAY_REGISTER_OP("cast")
 .set_support_level(3)
 .add_type_rel("Cast", CastRel)
 .set_attr<FTVMCompute>("FTVMCompute", CastCompute)
-.set_attr<TOpPattern>("TOpPattern", kElemWise);
+.set_attr<TOpPattern>("TOpPattern", kElemWise)
+.set_attr<FInferCorrectLayout>("FInferCorrectLayout", ElemwiseArbitraryLayout);
 
 // relay.expand_dims
 TVM_REGISTER_NODE_TYPE(ExpandDimsAttrs);
diff --git a/tests/python/relay/test_pass_alter_op_layout.py b/tests/python/relay/test_pass_alter_op_layout.py
index 48ab2ba27..975973d2b 100644
--- a/tests/python/relay/test_pass_alter_op_layout.py
+++ b/tests/python/relay/test_pass_alter_op_layout.py
@@ -82,6 +82,8 @@ def test_alter_layout():
         # a useless tuple, which will be eliminated
         y = relay.Tuple([y])[0]
         y = relay.nn.relu(y)
+        y = relay.nn.max_pool2d(y, pool_size=(2, 2))
+        y = relay.cast(y, 'int32')
         y = relay.nn.batch_flatten(y)
         y = relay.Function(free_vars(y), y)
         return y
@@ -112,6 +114,8 @@ def test_alter_layout():
         y = relay.add(y, b)
 
         y = relay.nn.relu(y)
+        y = relay.nn.max_pool2d(y, pool_size=(2, 2), layout="NCHW16c")
+        y = relay.cast(y, 'int32')
         y = relay.layout_transform(y, "NCHW16c", "NCHW")
         y = relay.nn.batch_flatten(y)
         y = relay.Function(free_vars(y), y)
-- 
2.34.1