Hotfix for issue #3641. (#3644)
authorBalint Cristian <cristian.balint@gmail.com>
Sun, 28 Jul 2019 08:05:37 +0000 (11:05 +0300)
committerThierry Moreau <moreau@uw.edu>
Sun, 28 Jul 2019 08:05:37 +0000 (01:05 -0700)
topi/tests/python/test_topi_conv2d_winograd.py

index cf176a8..a42d61d 100644 (file)
@@ -81,7 +81,12 @@ def verify_conv2d_nchw(batch, in_channel, in_size, num_filter, kernel, stride, p
         else:
             func = tvm.build(s, [A, W, C], device, name="relu_%d_%d_%d_%d_%d_%d_%d_%d" % (batch, in_channel, in_size, num_filter, kernel, stride, padding, dilation))
             func(a, w, c)
-        tvm.testing.assert_allclose(c.asnumpy(), c_np, rtol=1e-5)
+
+        rtol = 1e-5
+        if (kernel > 3):
+          rtol = 2e-5
+
+        tvm.testing.assert_allclose(c.asnumpy(), c_np, rtol=rtol)
 
 
     for device in devices: