const int axis = 1,
std::string name = "tensor",
std::string tag = kBroadcast) {
- CHECK_EQ(4, x->shape.size());
CHECK((size_t)axis < x->shape.size()) <<
"Wrong axis (" << axis << ")value. ";
CHECK(topi::detail::GetConstInt(slope->shape[0]) ==
[http://arxiv.org/pdf/1502.01852v1.pdf]
"""
- assert len(x.shape) == 4 and len(slope.shape) == 1
+ assert len(slope.shape) == 1
assert axis < len(x.shape)
assert get_const_int(slope.shape[0]) == get_const_int(x.shape[axis])
def test_prelu():
verify_prelu((1, 3, 2, 2), (3,), 1, (3, 1, 1))
verify_prelu((1, 3, 2, 2), (2,), 2, (2, 1))
+ verify_prelu((1, 3), (3,), 1, (3, ))
if __name__ == "__main__":
test_schedule_big_array()