From: XiaolongMeng Date: Mon, 25 Mar 2019 02:27:19 +0000 (+0800) Subject: fix prelu, now can use on 2d input and add one test (#2875) X-Git-Tag: upstream/0.7.0~2600 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=69758ed1eb9105a42996a077598becc5252c367b;p=platform%2Fupstream%2Ftvm.git fix prelu, now can use on 2d input and add one test (#2875) --- diff --git a/topi/include/topi/nn.h b/topi/include/topi/nn.h index 00c3f9998..653c0a5f7 100644 --- a/topi/include/topi/nn.h +++ b/topi/include/topi/nn.h @@ -97,7 +97,6 @@ inline tvm::Tensor prelu(const tvm::Tensor &x, const int axis = 1, std::string name = "tensor", std::string tag = kBroadcast) { - CHECK_EQ(4, x->shape.size()); CHECK((size_t)axis < x->shape.size()) << "Wrong axis (" << axis << ")value. "; CHECK(topi::detail::GetConstInt(slope->shape[0]) == diff --git a/topi/python/topi/nn/elemwise.py b/topi/python/topi/nn/elemwise.py index 14a747e67..6a2697795 100644 --- a/topi/python/topi/nn/elemwise.py +++ b/topi/python/topi/nn/elemwise.py @@ -69,7 +69,7 @@ def prelu(x, slope, axis=1): [http://arxiv.org/pdf/1502.01852v1.pdf] """ - assert len(x.shape) == 4 and len(slope.shape) == 1 + assert len(slope.shape) == 1 assert axis < len(x.shape) assert get_const_int(slope.shape[0]) == get_const_int(x.shape[axis]) diff --git a/topi/tests/python/test_topi_relu.py b/topi/tests/python/test_topi_relu.py index a7ff64f0f..5aa9c1ee5 100644 --- a/topi/tests/python/test_topi_relu.py +++ b/topi/tests/python/test_topi_relu.py @@ -83,6 +83,7 @@ def test_leaky_relu(): def test_prelu(): verify_prelu((1, 3, 2, 2), (3,), 1, (3, 1, 1)) verify_prelu((1, 3, 2, 2), (2,), 2, (2, 1)) + verify_prelu((1, 3), (3,), 1, (3, )) if __name__ == "__main__": test_schedule_big_array()