[RELAY] Fix ops in packed layout (#2472)
author  Wuwei Lin <vincentl13x@gmail.com>
Thu, 24 Jan 2019 04:50:21 +0000 (12:50 +0800)
committer  Yizhi Liu <liuyizhi@apache.org>
Thu, 24 Jan 2019 04:50:21 +0000 (20:50 -0800)
* [RELAY] Fix ops in packed layout

  Pool2DRel hard-coded a 4-D output shape, which broke pooling on packed
  5-D layouts such as NCHW16c; the output shape is now copied
  dimension-by-dimension from the input. cast is also registered with
  ElemwiseArbitraryLayout so AlterOpLayout can propagate a packed layout
  through it instead of inserting layout_transform ops around it.

* Fix style

src/relay/op/nn/pooling.cc
src/relay/op/tensor/transform.cc
tests/python/relay/test_pass_alter_op_layout.py

diff --git a/src/relay/op/nn/pooling.cc b/src/relay/op/nn/pooling.cc
index 6cf37668cab59ed2f3d3f7f399e2e069b1e29e62..8fd33e1f3cdca6e8f99e5374a55ec49bd157dec5 100644
@@ -83,7 +83,11 @@ bool Pool2DRel(const Array<Type>& types,
     return false;
   }
 
-  std::vector<IndexExpr> oshape({dshape[0], dshape[1], dshape[2], dshape[3]});
+  std::vector<IndexExpr> oshape;
+  for (const auto& e : dshape) {
+    oshape.push_back(e);
+  }
+
   if (param->ceil_mode) {
     oshape[hidx] = ((dshape[hidx] + pad_h - param->pool_size[0] +
                     param->strides[0] - 1) / param->strides[0]) + 1;
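
The removed line fixed the output rank at four (dshape[0]..dshape[3]), which dropped the inner packed-channel dimension of a 5-D layout such as NCHW16c; copying every entry of dshape preserves the full rank, while hidx/widx (resolved from the layout string elsewhere in Pool2DRel) still address the spatial dimensions. A minimal sketch of what this enables, assuming the 2019-era Relay Python API (the shapes are illustrative):

    from tvm import relay

    # NCHW16c packs channels into blocks of 16, giving a 5-D tensor.
    x = relay.var("x", shape=(1, 4, 56, 56, 16), dtype="float32")
    y = relay.nn.max_pool2d(x, pool_size=(2, 2), strides=(2, 2),
                            layout="NCHW16c")
    # With this fix, type inference yields the full 5-D output shape
    # (1, 4, 28, 28, 16) instead of a truncated 4-D one.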
diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc
index 7043245331856759ee442acee8f11015b812d112..6d583bfd6636f460c54e6646115d418990a282bb 100644
@@ -76,7 +76,8 @@ RELAY_REGISTER_OP("cast")
 .set_support_level(3)
 .add_type_rel("Cast", CastRel)
 .set_attr<FTVMCompute>("FTVMCompute", CastCompute)
-.set_attr<TOpPattern>("TOpPattern", kElemWise);
+.set_attr<TOpPattern>("TOpPattern", kElemWise)
+.set_attr<FInferCorrectLayout>("FInferCorrectLayout", ElemwiseArbitraryLayout);
 
 // relay.expand_dims
 TVM_REGISTER_NODE_TYPE(ExpandDimsAttrs);
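
ElemwiseArbitraryLayout is the stock layout-inference function for elementwise ops: it reports that the op accepts whatever layout its input already carries and produces that same layout. With it registered, AlterOpLayout can leave a cast sitting in a packed layout rather than bracketing it with layout_transform ops. A minimal sketch, with illustrative shapes:

    from tvm import relay

    # After AlterOpLayout, the tensor feeding cast may be NCHW16c (5-D).
    x = relay.var("x", shape=(1, 4, 28, 28, 16), dtype="float32")
    # cast is elementwise, so it consumes the packed tensor directly;
    # no layout_transform back to NCHW is needed around it.
    y = relay.cast(x, "int32")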
diff --git a/tests/python/relay/test_pass_alter_op_layout.py b/tests/python/relay/test_pass_alter_op_layout.py
index 48ab2ba271f7319441db7f20c931077d7f70bd0e..975973d2b952259f832ca385608b0fed45fca874 100644
@@ -82,6 +82,8 @@ def test_alter_layout():
         # a useless tuple, which will be eliminated
         y = relay.Tuple([y])[0]
         y = relay.nn.relu(y)
+        y = relay.nn.max_pool2d(y, pool_size=(2, 2))
+        y = relay.cast(y, 'int32')
         y = relay.nn.batch_flatten(y)
         y = relay.Function(free_vars(y), y)
         return y
@@ -112,6 +114,8 @@ def test_alter_layout():
         y = relay.add(y, b)
 
         y = relay.nn.relu(y)
+        y = relay.nn.max_pool2d(y, pool_size=(2, 2), layout="NCHW16c")
+        y = relay.cast(y, 'int32')
         y = relay.layout_transform(y, "NCHW16c", "NCHW")
         y = relay.nn.batch_flatten(y)
         y = relay.Function(free_vars(y), y)
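
The two test hunks exercise both halves of the fix: the source graph runs max_pool2d and cast in plain NCHW, while the expected graph keeps max_pool2d in NCHW16c with cast immediately after it, deferring the lone layout_transform back to NCHW until just before batch_flatten. A sketch of how such a test is typically concluded, assuming the nested builders are named before() and expected() and using the 2019-era pass API (relay.ir_pass.*), none of which is shown in these hunks:

    a = before()
    a = relay.ir_pass.infer_type(a)
    a = relay.ir_pass.alter_op_layout(a)
    a = relay.ir_pass.infer_type(a)

    b = expected()
    b = relay.ir_pass.infer_type(b)

    # The rewritten graph must be alpha-equal to the hand-written one.
    assert relay.ir_pass.alpha_equal(a, b)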