[Fix] Fix get_valid_count flaky test for cuda (#4901)
authorLeyuan Wang <laurawly@gmail.com>
Fri, 21 Feb 2020 22:31:04 +0000 (14:31 -0800)
committerGitHub <noreply@github.com>
Fri, 21 Feb 2020 22:31:04 +0000 (14:31 -0800)
* get_valid_count accuracy issue fixed for individual tests but not for all tests running together

* minor fix

* initialize valid_count and PrefixSum buffers

* test updated

* udpate relay test as well

* update document

* fix lint

* address comment

* fix lint

* correct atomicAdd identifier name

tests/python/relay/test_op_level5.py
topi/python/topi/cuda/nms.py
topi/tests/python/test_topi_vision.py

index 03e700b..e622a8a 100644 (file)
@@ -221,8 +221,6 @@ def test_get_valid_counts():
         func = relay.Function([x], z.astuple())
         func = run_infer_type(func)
         for target, ctx in ctx_list():
-            if target == 'cuda':
-                return
             intrp = relay.create_executor("debug", ctx=ctx, target=target)
             out = intrp.evaluate(func)(np_data)
             tvm.testing.assert_allclose(out[0].asnumpy(), np_out1, rtol=1e-3, atol=1e-04)
index 38f87a9..5485859 100644 (file)
@@ -21,29 +21,46 @@ import math
 import tvm
 
 from tvm import api
-from tvm.generic import cast
-from tvm.intrin import if_then_else, log, power
+from tvm.intrin import if_then_else
 from topi.vision import non_max_suppression, get_valid_counts
 from .sort import argsort
 from .. import tag
 
 
-def get_valid_counts_pre(data, flag, idx, score_threshold, id_index, score_index):
-    """Low level IR to Prepare get valid count of bounding boxes
-    given a score threshold. Also moves valid boxes to the
+def cuda_atomic_add_rule(op):
+    if op.dtype == "float32":
+        return tvm.call_pure_extern("float32", "atomicAdd", op.args[0], op.args[1])
+    if op.dtype == "float64":
+        return tvm.call_pure_extern("float64", "atomicAdd", op.args[0], op.args[1])
+    if op.dtype == "int32":
+        return tvm.call_pure_extern("int32", "atomicAdd", op.args[0], op.args[1])
+    raise RuntimeError("only support int32, float32 and float64")
+
+
+tvm.target.intrin.register_intrin_rule(
+    "cuda", "atomic_add", cuda_atomic_add_rule, override=True)
+
+
+def atomic_add(x, y):
+    return tvm.call_pure_intrin(y.dtype, "atomic_add", x, y)
+
+
+def get_valid_counts_ir(data, valid_count, flag, score_threshold, id_index, score_index):
+    """Low level IR to get valid count of bounding boxes
+    given a score threshold. Also prepares to move valid boxes to the
     top of input data.
 
     Parameters
     ----------
-    data: Buffer
-        3D Buffer with shape [batch_size, num_anchors, elem_length], output of nms.
+    data : Buffer
+        Input data. 3-D Buffer with shape [batch_size, num_anchors, elem_length].
+
+    valid_count : Buffer
+        1D buffer for valid number of boxes with shape [batch_size, ].
 
     flag : Buffer
         2D Buffer of flag indicating valid data with shape [batch_size, num_anchors].
 
-    idx : Buffer
-        2D Buffer of valid data indices with shape [batch_size, num_anchors].
-
     score_threshold : float32
         Lower limit of score for valid bounding boxes.
 
@@ -60,18 +77,24 @@ def get_valid_counts_pre(data, flag, idx, score_threshold, id_index, score_index
     """
     batch_size = data.shape[0]
     num_anchors = data.shape[1]
-    box_data_length = data.shape[2]
+    elem_length = data.shape[2]
 
     ib = tvm.ir_builder.create()
 
     data = ib.buffer_ptr(data)
+
+    valid_count = ib.buffer_ptr(valid_count)
     flag = ib.buffer_ptr(flag)
-    idx = ib.buffer_ptr(idx)
-    score_threshold = tvm.make.node("FloatImm", dtype="float32", value=score_threshold)
+    atomic_add_return = ib.allocate(
+        valid_count.dtype, (1,), name='atomic_add_return', scope='local')
+    one_count = tvm.const(1, dtype=valid_count.dtype)
+    score_threshold = tvm.make.node(
+        "FloatImm", dtype="float32", value=score_threshold)
     id_index = tvm.make.node("IntImm", dtype="int32", value=id_index)
     score_index = tvm.make.node("IntImm", dtype="int32", value=score_index)
 
-    max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads)
+    max_threads = int(tvm.target.Target.current(
+        allow_none=False).max_num_threads)
     nthread_tx = max_threads
     nthread_bx = batch_size * num_anchors // max_threads + 1
     tx = tvm.thread_axis("threadIdx.x")
@@ -79,163 +102,52 @@ def get_valid_counts_pre(data, flag, idx, score_threshold, id_index, score_index
     ib.scope_attr(tx, "thread_extent", nthread_tx)
     ib.scope_attr(bx, "thread_extent", nthread_bx)
     tid = bx * max_threads + tx
+    idxd = tvm.indexdiv
 
+    # initialize valid_count
+    with ib.if_scope(tid < batch_size):
+        valid_count[tid] = 0
+    # initialize flag
     with ib.if_scope(tid < batch_size * num_anchors):
-        with ib.if_scope(tvm.all(data[tid * box_data_length + score_index] > score_threshold, \
-            tvm.any(id_index < 0, data[tid * box_data_length + id_index] >= 0))):
+        flag[tid] = 0
+    with ib.if_scope(tid < batch_size * num_anchors):
+        i = idxd(tid, num_anchors)
+        with ib.if_scope(tvm.all(data[tid * elem_length + score_index] > score_threshold,
+                                 tvm.any(id_index < 0, data[tid * elem_length + id_index] >= 0))):
             flag[tid] = 1
-            idx[tid] = 1
-        with ib.else_scope():
-            flag[tid] = 0
-            idx[tid] = 0
+            atomic_add_return[0] = atomic_add(tvm.call_pure_intrin("handle", "tvm_address_of",
+                                                                 valid_count[i]), one_count)
 
     return ib.get()
 
-def get_valid_counts_upsweep(data, idx_in, idx, partial):
-    """Low level IR of first step of scan: unsweep.
-
-    Parameters
-    ----------
-    data: Buffer
-        3D Buffer with shape [batch_size, num_anchors, elem_length], output of nms.
-
-    idx_in : Buffer
-        2D Buffer of valid data indices with shape [batch_size, num_anchors].
-
-    idx : Buffer
-        2D Buffer of valid data indices with shape [batch_size, num_anchors].
-
-    partial : Buffer
-        2D Buffer of valid data indices with shape [batch_size, new_range].
-
-    Returns
-    -------
-    stmt : Stmt
-        The result IR statement.
-    """
-    batch_size = data.shape[0]
-    num_anchors = data.shape[1]
-    ib = tvm.ir_builder.create()
-    data = ib.buffer_ptr(data)
-    idx_in = ib.buffer_ptr(idx_in)
-    idx = ib.buffer_ptr(idx)
-    partial = ib.buffer_ptr(partial)
-    max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads)
-    elem_per_thread = num_anchors // max_threads + 1
-    nthread_tx = max_threads
-    nthread_bx = batch_size
-    tx = tvm.thread_axis("threadIdx.x")
-    bx = tvm.thread_axis("blockIdx.x")
-    ib.scope_attr(tx, "thread_extent", nthread_tx)
-    ib.scope_attr(bx, "thread_extent", nthread_bx)
-    new_range = num_anchors // elem_per_thread + 1
-    # Scan: Upsweep:
-    with ib.if_scope(tvm.all(bx < batch_size, tx < new_range)):
-        with ib.for_range(0, elem_per_thread) as i:
-            with ib.if_scope(bx * num_anchors + \
-                             tx * elem_per_thread + i < batch_size * num_anchors):
-                with ib.if_scope(i == 0):
-                    partial[bx * new_range + tx] = idx_in[bx * num_anchors + tx * elem_per_thread]
-                    idx[bx * num_anchors + tx * elem_per_thread] = \
-                    idx_in[bx * num_anchors + tx * elem_per_thread]
-                with ib.else_scope():
-                    partial[bx * new_range + tx] += \
-                    idx_in[bx * num_anchors + tx * elem_per_thread + i]
-                    idx[bx * num_anchors + tx * elem_per_thread + i] = \
-                    idx[bx * num_anchors + tx * elem_per_thread + i - 1] + \
-                    idx_in[bx * num_anchors + tx * elem_per_thread + i]
-            ib.emit(tvm.make.Call(None, 'tvm_storage_sync',
-                                  tvm.convert(['shared']),
-                                  tvm.expr.Call.Intrinsic, None, 0))
-    return ib.get()
 
-def get_valid_counts_scan(data, partial_in, partial):
-    """Low level IR to do scan.
+def flag_scan(flag, prefix_sum):
+    """Low level IR to calculate correct positions for valid boxes.
 
     Parameters
     ----------
-    data: Buffer
-        3D Buffer with shape [batch_size, num_anchors, elem_length], output of nms.
-
-    idx_in : Buffer
-        2D Buffer of valid data indices with shape [batch_size, num_anchors].
-
-    idx : Buffer
-        2D Buffer of valid data indices with shape [batch_size, num_anchors].
+    flag : Buffer
+        2D Buffer of flag indicating valid data with shape [batch_size, num_anchors].
 
-    partial : Buffer
-        2D Buffer of valid data indices with shape [batch_size, new_range].
+    prefix_sum : Buffer
+        2D Buffer of prefix sum of flags indicating new locations of valid boxes
+        with same shape as flag.
 
     Returns
     -------
     stmt : Stmt
         The result IR statement.
     """
-    batch_size = data.shape[0]
-    num_anchors = data.shape[1]
-    ib = tvm.ir_builder.create()
-    partial_in = ib.buffer_ptr(partial_in)
-    partial = ib.buffer_ptr(partial)
-    max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads)
-    elem_per_thread = num_anchors // max_threads + 1
-    nthread_tx = max_threads
-    nthread_bx = batch_size
-    tx = tvm.thread_axis("threadIdx.x")
-    bx = tvm.thread_axis("blockIdx.x")
-    ib.scope_attr(tx, "thread_extent", nthread_tx)
-    ib.scope_attr(bx, "thread_extent", nthread_bx)
-    var = tvm.make.node("FloatImm", dtype="float32", value=2)
-    new_range = num_anchors // elem_per_thread + 1
-    iteration = cast(log(cast(new_range, "float32")) / math.log(2), "int32")
-    # Scan: Kogge-Stone adder
-    with ib.if_scope(tvm.all(bx < batch_size, tx < tvm.min(new_range, num_anchors))):
-        with ib.for_range(0, iteration) as k:
-            with ib.if_scope(k == 0):
-                with ib.if_scope(tvm.all(tx > 0, tx < tvm.min(new_range, num_anchors))):
-                    partial[bx * new_range + tx] = \
-                    partial_in[bx * new_range + tx] + partial_in[bx * new_range + tx - 1]
-                with ib.else_scope():
-                    partial[bx * new_range] = partial_in[bx * new_range]
-            with ib.else_scope():
-                with ib.if_scope(tvm.all(tx >= cast(power(var, k), "int32"), \
-                                         tx < tvm.min(new_range, num_anchors))):
-                    partial[bx * new_range + tx] += \
-                    partial[bx * new_range + tx - cast(power(var, k), "int32")]
-            ib.emit(tvm.make.Call(None, 'tvm_storage_sync',
-                                  tvm.convert(['shared']),
-                                  tvm.expr.Call.Intrinsic, None, 0))
-    return ib.get()
-
-def get_valid_counts_downsweep(data, idx_in, partial, idx):
-    """Low level IR to do downsweep of scan.
-
-    Parameters
-    ----------
-    data: Buffer
-        3D Buffer with shape [batch_size, num_anchors, elem_length], output of nms.
-
-    idx_in : Buffer
-        2D Buffer of valid data indices with shape [batch_size, num_anchors].
+    batch_size = flag.shape[0]
+    num_anchors = flag.shape[1]
 
-    partial : Buffer
-        2D Buffer of valid data indices with shape [batch_size, new_range].
+    ib = tvm.ir_builder.create()
 
-    idx : Buffer
-        2D Buffer of valid data indices with shape [batch_size, num_anchors].
+    flag = ib.buffer_ptr(flag)
+    prefix_sum = ib.buffer_ptr(prefix_sum)
 
-    Returns
-    -------
-    stmt : Stmt
-        The result IR statement.
-    """
-    batch_size = data.shape[0]
-    num_anchors = data.shape[1]
-    ib = tvm.ir_builder.create()
-    idx_in = ib.buffer_ptr(idx_in)
-    idx = ib.buffer_ptr(idx)
-    partial = ib.buffer_ptr(partial)
-    max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads)
-    elem_per_thread = num_anchors // max_threads + 1
+    max_threads = int(tvm.target.Target.current(
+        allow_none=False).max_num_threads)
     nthread_tx = max_threads
     nthread_bx = batch_size * num_anchors // max_threads + 1
     tx = tvm.thread_axis("threadIdx.x")
@@ -243,23 +155,23 @@ def get_valid_counts_downsweep(data, idx_in, partial, idx):
     ib.scope_attr(tx, "thread_extent", nthread_tx)
     ib.scope_attr(bx, "thread_extent", nthread_bx)
     tid = bx * max_threads + tx
-    new_range = num_anchors // elem_per_thread + 1
     idxd = tvm.indexdiv
     idxm = tvm.indexmod
-    # Scan: Downsweep:
-    with ib. if_scope(tid < batch_size * num_anchors):
-        i = idxd(tid, num_anchors) # number of batches
-        j = idxm(tid, num_anchors) # number of anchors
-        with ib.if_scope(j < elem_per_thread):
-            idx[tid] = idx_in[tid]
-        with ib.else_scope():
-            idx[tid] = idx_in[tid] + partial[i * new_range + idxd(j, elem_per_thread) - 1]
+
+    # initialize prefix_sum
+    with ib.if_scope(tid < batch_size * num_anchors):
+        prefix_sum[tid] = 0
+    with ib.if_scope(tid < batch_size * num_anchors):
+        i = idxd(tid, num_anchors)
+        j = idxm(tid, num_anchors)
+        with ib.for_range(0, j) as r:
+            prefix_sum[tid] += flag[i * num_anchors + r]
 
     return ib.get()
 
-def get_valid_counts_ir(data, flag, idx, valid_count, out):
-    """Low level IR to get valid count of bounding boxes
-    given a score threshold. Also moves valid boxes to the
+
+def out_rewrite(data, flag, prefix_sum, valid_count, out):
+    """Low level IR to move valid boxes to the
     top of input data.
 
     Parameters
@@ -270,11 +182,12 @@ def get_valid_counts_ir(data, flag, idx, valid_count, out):
     flag : Buffer
         2D Buffer of flag indicating valid data with shape [batch_size, num_anchors].
 
-    idx : Buffer
-        2D Buffer of valid data indices with shape [batch_size, num_anchors].
+    prefix_sum : Buffer
+        2D Buffer of prefix sum of flags indicating new locations of valid boxes
+        with same shape as flag.
 
     valid_count : Buffer
-        1-D buffer for valid number of boxes.
+        1D buffer for valid number of boxes with shape [batch_size, ].
 
     out : Buffer
         Rearranged data buffer.
@@ -284,28 +197,28 @@ def get_valid_counts_ir(data, flag, idx, valid_count, out):
     stmt : Stmt
         The result IR statement.
     """
-    batch_size = data.shape[0]
-    num_anchors = data.shape[1]
-    elem_length = data.shape[2]
-    size = batch_size * num_anchors * elem_length
+    batch_size = out.shape[0]
+    num_anchors = out.shape[1]
+    elem_length = out.shape[2]
 
     ib = tvm.ir_builder.create()
 
+    one = tvm.const(1, dtype=out.dtype)
     data = ib.buffer_ptr(data)
     flag = ib.buffer_ptr(flag)
-    idx = ib.buffer_ptr(idx)
     valid_count = ib.buffer_ptr(valid_count)
+    prefix_sum = ib.buffer_ptr(prefix_sum)
     out = ib.buffer_ptr(out)
 
-    max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads)
+    max_threads = int(tvm.target.Target.current(
+        allow_none=False).max_num_threads)
     nthread_tx = max_threads
-    nthread_bx = batch_size * num_anchors * elem_length // max_threads + 1
+    nthread_bx = batch_size * num_anchors // max_threads + 1
     tx = tvm.thread_axis("threadIdx.x")
     bx = tvm.thread_axis("blockIdx.x")
     ib.scope_attr(tx, "thread_extent", nthread_tx)
     ib.scope_attr(bx, "thread_extent", nthread_bx)
     tid = bx * max_threads + tx
-
     idxd = tvm.indexdiv
     idxm = tvm.indexmod
 
@@ -313,17 +226,15 @@ def get_valid_counts_ir(data, flag, idx, valid_count, out):
         i = idxd(tid, num_anchors)
         j = idxm(tid, num_anchors)
         base_idx = i * num_anchors * elem_length
-        with ib.if_scope(flag[tid] > 0):
+        with ib.if_scope(tvm.all(flag[tid] > 0, prefix_sum[tid] >= 0,
+                                 prefix_sum[tid] < num_anchors)):
+            with ib.for_range(0, elem_length) as k:
+                out[base_idx + prefix_sum[tid] * elem_length +
+                    k] = data[tid * elem_length + k]
+        with ib.if_scope(j >= valid_count[i]):
             with ib.for_range(0, elem_length) as k:
-                with ib.if_scope(base_idx + (idx[tid] - 1) * elem_length + k < size):
-                    out[base_idx + (idx[tid] - 1) * elem_length + k] =\
-                    data[base_idx + j * elem_length + k]
-        with ib.if_scope(j == 0):
-            valid_count[i] = idx[tid + num_anchors - 1]
-        with ib.if_scope(j >= idx[i * num_anchors + num_anchors - 1]):
-            with ib.for_range(0, elem_length) as l:
-                with ib.if_scope(tid * elem_length + l < size):
-                    out[tid * elem_length + l] = -1.0
+                out[tid * elem_length + k] = -one
+
     return ib.get()
 
 
@@ -356,56 +267,47 @@ def get_valid_counts_gpu(data, score_threshold=0, id_index=0, score_index=1):
     """
     batch_size = data.shape[0]
     num_anchors = data.shape[1]
-    max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads)
-    elem_per_thread = num_anchors // max_threads + 1
-    new_range = num_anchors // elem_per_thread + 1
+    data_buf = api.decl_buffer(
+        data.shape, data.dtype, "data_buf", data_alignment=8)
+    valid_count_buf = api.decl_buffer(
+        (batch_size,), "int32", "valid_count_buf", data_alignment=8)
     temp_flag_buf = api.decl_buffer(
         (batch_size, num_anchors,), "int32", "temp_flag", data_alignment=8)
-    temp_idx_buf = api.decl_buffer(
-        (batch_size, num_anchors,), "int32", "temp_idx", data_alignment=8)
     temp_partial_buf = api.decl_buffer(
-        (batch_size, new_range), "int32", "temp_partial", data_alignment=8)
-    data_buf = api.decl_buffer(
-        data.shape, data.dtype, "data_buf", data_alignment=8)
+        (batch_size, num_anchors), "int32", "temp_partial", data_alignment=8)
+    out_buf = api.decl_buffer(
+        data.shape, data.dtype, "out_buf", data_alignment=8)
 
-    temp_flag, temp_idx = \
-        tvm.extern([(batch_size, num_anchors,), (batch_size, num_anchors,)], [data],
-                   lambda ins, outs: get_valid_counts_pre(
-                       ins[0], outs[0], outs[1], score_threshold, id_index, score_index),
-                   dtype=["int32", "int32"],
-                   out_buffers=[temp_flag_buf, temp_idx_buf],
-                   name="get_valid_counts_phase_one")
-    temp_idx_new, temp_partial = \
-        tvm.extern([(batch_size, num_anchors,), (batch_size, new_range)], [data, temp_idx],
-                   lambda ins, outs: get_valid_counts_upsweep(
-                       ins[0], ins[1], outs[0], outs[1]),
-                   dtype=["int32", "int32"],
-                   out_buffers=[temp_idx_buf, temp_partial_buf],
-                   name="get_valid_counts_phase_two")
-    temp_partial_new = \
-        tvm.extern([(batch_size, new_range)], [data, temp_partial],
-                   lambda ins, outs: get_valid_counts_scan(
-                       ins[0], ins[1], outs[0]),
-                   dtype=["int32"],
-                   out_buffers=[temp_partial_buf],
-                   name="get_valid_counts_phase_three")
-    temp_idx_final = \
-        tvm.extern([(batch_size, num_anchors)], [data, temp_idx_new, temp_partial_new],
-                   lambda ins, outs: get_valid_counts_downsweep(
-                       ins[0], ins[1], ins[2], outs[0]),
-                   dtype=["int32"],
-                   out_buffers=[temp_idx_buf],
-                   name="get_valid_counts_phase_four")
-    valid_count, out_tensor = \
-       tvm.extern([(batch_size,), data.shape], [data, temp_flag, temp_idx_final],
-               lambda ins, outs: get_valid_counts_ir(
-                ins[0], ins[1], ins[2], outs[0], outs[1]),
-            dtype=["int32", data.dtype],
-            in_buffers=[data_buf, temp_flag_buf, temp_idx_buf],
-            name="get_valid_counts_phase_five",
+    valid_count, temp_flag = \
+        tvm.extern([(batch_size,), (batch_size, num_anchors)], [data],
+                   lambda ins, outs: get_valid_counts_ir(
+            ins[0], outs[0], outs[1], score_threshold, id_index, score_index),
+            dtype=["int32", "int32"],
+            in_buffers=[data_buf],
+            out_buffers=[valid_count_buf, temp_flag_buf],
+            name="get_valid_counts",
             tag="get_valid_counts_gpu")
 
-    return [valid_count, out_tensor]
+    temp_partial = \
+        tvm.extern([(batch_size, num_anchors)], [temp_flag],
+                   lambda ins, outs: flag_scan(
+            ins[0], outs[0]),
+            dtype=["int32"],
+            in_buffers=[temp_flag_buf],
+            out_buffers=[temp_partial_buf],
+            name="flag_scan")
+
+    out = \
+        tvm.extern([data.shape], [data, temp_flag, temp_partial, valid_count],
+                   lambda ins, outs: out_rewrite(
+            ins[0], ins[1], ins[2], ins[3], outs[0]),
+            dtype=[data.dtype],
+            in_buffers=[data_buf, temp_flag_buf,
+                        temp_partial_buf, valid_count_buf],
+            out_buffers=[out_buf],
+            name="out_rewrite")
+
+    return [valid_count, out]
 
 
 def nms_ir(data, sorted_index, valid_count, out, box_indices,
@@ -479,7 +381,8 @@ def nms_ir(data, sorted_index, valid_count, out, box_indices,
     valid_count = ib.buffer_ptr(valid_count)
     out = ib.buffer_ptr(out)
     box_indices = ib.buffer_ptr(box_indices)
-    num_valid_boxes = ib.allocate("int32", (1,), name="num_valid_boxes", scope="local")
+    num_valid_boxes = ib.allocate(
+        "int32", (1,), name="num_valid_boxes", scope="local")
 
     max_threads = int(
         tvm.target.Target.current(allow_none=False).max_num_threads)
@@ -491,26 +394,29 @@ def nms_ir(data, sorted_index, valid_count, out, box_indices,
     ib.scope_attr(bx, "thread_extent", nthread_bx)
     j = bx * max_threads + tx
 
-    iou_threshold = tvm.make.node("FloatImm", dtype="float32", value=iou_threshold)
+    iou_threshold = tvm.make.node(
+        "FloatImm", dtype="float32", value=iou_threshold)
     top_k = tvm.make.node("IntImm", dtype="int32", value=top_k)
     coord_start = tvm.make.node("IntImm", dtype="int32", value=coord_start)
     id_index = tvm.make.node("IntImm", dtype="int32", value=id_index)
     score_index = tvm.make.node("IntImm", dtype="int32", value=score_index)
-    force_suppress = tvm.make.node("IntImm", dtype="int32", value=1 if force_suppress else 0)
+    force_suppress = tvm.make.node(
+        "IntImm", dtype="int32", value=1 if force_suppress else 0)
 
     with ib.for_range(0, batch_size, for_type="unroll") as i:
         base_idx = i * num_anchors * box_data_length
         with ib.if_scope(tvm.all(iou_threshold > 0, valid_count[i] > 0)):
             # Reorder output
-            nkeep = if_then_else( \
-                    tvm.all(top_k > 0, top_k < valid_count[i]),
-                    top_k, valid_count[i])
+            nkeep = if_then_else(
+                tvm.all(top_k > 0, top_k < valid_count[i]),
+                top_k, valid_count[i])
             with ib.if_scope(j < nkeep):
                 with ib.for_range(0, box_data_length) as k:
                     out[(base_idx + j * box_data_length + k)] = \
-                    data[(base_idx + sorted_index[i * num_anchors + j] \
-                    * box_data_length + k)]
-                box_indices[i * num_anchors + j] = sorted_index[i * num_anchors + j]
+                        data[(base_idx + sorted_index[i * num_anchors + j]
+                              * box_data_length + k)]
+                box_indices[i * num_anchors +
+                            j] = sorted_index[i * num_anchors + j]
             with ib.if_scope(tvm.all(top_k > 0, top_k < valid_count[i])):
                 with ib.if_scope(j < valid_count[i] - nkeep):
                     with ib.for_range(0, box_data_length) as k:
@@ -519,16 +425,18 @@ def nms_ir(data, sorted_index, valid_count, out, box_indices,
             # Apply nms
             with ib.for_range(0, valid_count[i]) as k:
                 offset_k = k * box_data_length
-                with ib.if_scope(tvm.all(out[base_idx + offset_k + score_index] > 0, \
-                    tvm.any(id_index < 0, out[base_idx + offset_k + id_index] >= 0))):
+                with ib.if_scope(tvm.all(out[base_idx + offset_k + score_index] > 0,
+                                         tvm.any(id_index < 0, out[base_idx +
+                                                                   offset_k + id_index] >= 0))):
                     with ib.if_scope(j < valid_count[i]):
                         offset_j = j * box_data_length
-                        with ib.if_scope(tvm.all(j > k, \
-                            out[base_idx + offset_j + score_index] > 0, \
-                                                 tvm.any(id_index < 0, \
-                                                    out[base_idx + offset_j + id_index] >= 0), \
-                                                tvm.any(force_suppress > 0, id_index < 0, \
-                                                         out[base_idx + offset_k + id_index] == \
+                        with ib.if_scope(tvm.all(j > k,
+                                                 out[base_idx + offset_j +
+                                                     score_index] > 0,
+                                                 tvm.any(id_index < 0,
+                                                         out[base_idx + offset_j + id_index] >= 0),
+                                                 tvm.any(force_suppress > 0, id_index < 0,
+                                                         out[base_idx + offset_k + id_index] ==
                                                          out[base_idx + offset_j + id_index]))):
                             iou = calculate_overlap(out, base_idx + offset_j + coord_start,
                                                     base_idx + offset_k + coord_start)
@@ -541,12 +449,14 @@ def nms_ir(data, sorted_index, valid_count, out, box_indices,
             with ib.if_scope(j < valid_count[i]):
                 offset_j = j * box_data_length
                 with ib.for_range(0, box_data_length) as k:
-                    out[(base_idx + offset_j + k)] = data[base_idx + offset_j + k]
+                    out[(base_idx + offset_j + k)
+                        ] = data[base_idx + offset_j + k]
                 box_indices[i * num_anchors + j] = j
         # Set invalid entry to be -1
         with ib.if_scope(j < num_anchors - valid_count[i]):
             with ib.for_range(0, box_data_length) as k:
-                out[base_idx + (j + valid_count[i]) * box_data_length + k] = -1.0
+                out[base_idx + (j + valid_count[i]) *
+                    box_data_length + k] = -1.0
             box_indices[i * num_anchors + j + valid_count[i]] = -1
         # Only return max_output_size number of valid boxes
         num_valid_boxes[0] = 0
@@ -671,7 +581,7 @@ def invalid_to_bottom_ir(data, flag, idx, out):
             with ib.if_scope(flag[i * num_anchors + j] > 0):
                 with ib.for_range(0, elem_length) as k:
                     out[base_idx + (idx[i * num_anchors + j] - 1) * elem_length + k] \
-                    = data[base_idx + j * elem_length + k]
+                        = data[base_idx + j * elem_length + k]
     return ib.get()
 
 
@@ -756,8 +666,10 @@ def non_max_suppression_gpu(data, valid_count, max_output_size=-1,
                                       "valid_count_buf", data_alignment=4)
     score_axis = score_index
     score_shape = (batch_size, num_anchors)
-    score_tensor = tvm.compute(score_shape, lambda i, j: data[i, j, score_axis], tag=tag.ELEMWISE)
-    sort_tensor = argsort(score_tensor, valid_count=valid_count, axis=1, is_ascend=False)
+    score_tensor = tvm.compute(
+        score_shape, lambda i, j: data[i, j, score_axis], tag=tag.ELEMWISE)
+    sort_tensor = argsort(
+        score_tensor, valid_count=valid_count, axis=1, is_ascend=False)
 
     sort_tensor_buf = api.decl_buffer(sort_tensor.shape, sort_tensor.dtype,
                                       "sort_tensor_buf", data_alignment=8)
@@ -795,7 +707,8 @@ def non_max_suppression_gpu(data, valid_count, max_output_size=-1,
                                              ins[0], outs[0], outs[1]),
                                          dtype=["int32", "int32"],
                                          in_buffers=[out_buf],
-                                         out_buffers=[temp_flag_buf, temp_idx_buf],
+                                         out_buffers=[
+                                             temp_flag_buf, temp_idx_buf],
                                          name="invalid_to_bottom_phase_one")
 
         output = tvm.extern([data.shape], [out, temp_flag, temp_idx],
index a081f07..85e4180 100644 (file)
@@ -67,8 +67,8 @@ def verify_get_valid_counts(dshape, score_threshold, id_index, score_index):
         tvm.testing.assert_allclose(tvm_out2.asnumpy(), np_out2, rtol=1e-3)
 
     for device in ['llvm', 'cuda', 'opencl']:
-        # Disable gpu test for now
-        if device != "llvm":
+        # Disable opencl test for now
+        if device != "llvm" and device != "cuda":
             continue
         check_device(device)