fix prelu, now can use on 2d input and add one test (#2875)
authorXiaolongMeng <tony8078@126.com>
Mon, 25 Mar 2019 02:27:19 +0000 (10:27 +0800)
committerTianqi Chen <tqchen@users.noreply.github.com>
Mon, 25 Mar 2019 02:27:19 +0000 (19:27 -0700)
topi/include/topi/nn.h
topi/python/topi/nn/elemwise.py
topi/tests/python/test_topi_relu.py

index 00c3f999853d23308332c15c96280022818b1d89..653c0a5f70ce98ce2353d080c5031d45f14243d3 100644 (file)
@@ -97,7 +97,6 @@ inline tvm::Tensor prelu(const tvm::Tensor &x,
                          const int axis = 1,
                          std::string name = "tensor",
                          std::string tag = kBroadcast) {
-  CHECK_EQ(4, x->shape.size());
   CHECK((size_t)axis < x->shape.size()) <<
         "Wrong axis ("  << axis << ")value. ";
   CHECK(topi::detail::GetConstInt(slope->shape[0]) ==
index 14a747e67610fb97f2cf3c04edc2895776a80c93..6a2697795f4dd410e8b56acbf5dfe5b3f7678aee 100644 (file)
@@ -69,7 +69,7 @@ def prelu(x, slope, axis=1):
         [http://arxiv.org/pdf/1502.01852v1.pdf]
     """
 
-    assert len(x.shape) == 4 and len(slope.shape) == 1
+    assert len(slope.shape) == 1
     assert axis < len(x.shape)
     assert get_const_int(slope.shape[0]) == get_const_int(x.shape[axis])
 
index a7ff64f0f759300b34ff8c6babbed4afe51717b9..5aa9c1ee57a019b2984d2ba7d7034d5661e5e477 100644 (file)
@@ -83,6 +83,7 @@ def test_leaky_relu():
 def test_prelu():
     verify_prelu((1, 3, 2, 2), (3,), 1, (3, 1, 1))
     verify_prelu((1, 3, 2, 2), (2,), 2, (2, 1))
+    verify_prelu((1, 3), (3,), 1, (3, ))
 
 if __name__ == "__main__":
     test_schedule_big_array()