import tvm
from tvm import api
-from tvm.generic import cast
-from tvm.intrin import if_then_else, log, power
+from tvm.intrin import if_then_else
from topi.vision import non_max_suppression, get_valid_counts
from .sort import argsort
from .. import tag
-def get_valid_counts_pre(data, flag, idx, score_threshold, id_index, score_index):
- """Low level IR to Prepare get valid count of bounding boxes
- given a score threshold. Also moves valid boxes to the
+def cuda_atomic_add_rule(op):
+ if op.dtype == "float32":
+ return tvm.call_pure_extern("float32", "atomicAdd", op.args[0], op.args[1])
+ if op.dtype == "float64":
+ return tvm.call_pure_extern("float64", "atomicAdd", op.args[0], op.args[1])
+ if op.dtype == "int32":
+ return tvm.call_pure_extern("int32", "atomicAdd", op.args[0], op.args[1])
+ raise RuntimeError("only support int32, float32 and float64")
+
+
+tvm.target.intrin.register_intrin_rule(  # route the "atomic_add" intrinsic on the "cuda" target through cuda_atomic_add_rule
+ "cuda", "atomic_add", cuda_atomic_add_rule, override=True)
+
+
+def atomic_add(x, y):
+ return tvm.call_pure_intrin(y.dtype, "atomic_add", x, y)
+
+
+def get_valid_counts_ir(data, valid_count, flag, score_threshold, id_index, score_index):
+ """Low level IR to get valid count of bounding boxes
+ given a score threshold. Also prepares to move valid boxes to the
top of input data.
Parameters
----------
- data: Buffer
- 3D Buffer with shape [batch_size, num_anchors, elem_length], output of nms.
+ data : Buffer
+ Input data. 3-D Buffer with shape [batch_size, num_anchors, elem_length].
+
+ valid_count : Buffer
+ 1D buffer for valid number of boxes with shape [batch_size, ].
flag : Buffer
2D Buffer of flag indicating valid data with shape [batch_size, num_anchors].
- idx : Buffer
- 2D Buffer of valid data indices with shape [batch_size, num_anchors].
-
score_threshold : float32
Lower limit of score for valid bounding boxes.
"""
batch_size = data.shape[0]
num_anchors = data.shape[1]
- box_data_length = data.shape[2]
+ elem_length = data.shape[2]
ib = tvm.ir_builder.create()
data = ib.buffer_ptr(data)
+
+ valid_count = ib.buffer_ptr(valid_count)
flag = ib.buffer_ptr(flag)
- idx = ib.buffer_ptr(idx)
- score_threshold = tvm.make.node("FloatImm", dtype="float32", value=score_threshold)
+ atomic_add_return = ib.allocate(  # one-element 'local'-scope slot that receives the value returned by atomic_add
+ valid_count.dtype, (1,), name='atomic_add_return', scope='local')
+ one_count = tvm.const(1, dtype=valid_count.dtype)  # increment applied once per valid box
+ score_threshold = tvm.make.node(
+ "FloatImm", dtype="float32", value=score_threshold)
id_index = tvm.make.node("IntImm", dtype="int32", value=id_index)
score_index = tvm.make.node("IntImm", dtype="int32", value=score_index)
- max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads)
+ max_threads = int(tvm.target.Target.current(
+ allow_none=False).max_num_threads)
nthread_tx = max_threads
nthread_bx = batch_size * num_anchors // max_threads + 1
tx = tvm.thread_axis("threadIdx.x")
ib.scope_attr(tx, "thread_extent", nthread_tx)
ib.scope_attr(bx, "thread_extent", nthread_bx)
tid = bx * max_threads + tx
+ idxd = tvm.indexdiv
+ # initialize valid_count (NOTE(review): no barrier is visible between this zeroing and the atomic_add below, which can target the same entry from a different thread -- confirm the ordering is safe)
+ with ib.if_scope(tid < batch_size):
+ valid_count[tid] = 0
+ # initialize flag
with ib.if_scope(tid < batch_size * num_anchors):
- with ib.if_scope(tvm.all(data[tid * box_data_length + score_index] > score_threshold, \
- tvm.any(id_index < 0, data[tid * box_data_length + id_index] >= 0))):
+ flag[tid] = 0
+ with ib.if_scope(tid < batch_size * num_anchors):
+ i = idxd(tid, num_anchors)  # batch index of the box handled by this thread
+ with ib.if_scope(tvm.all(data[tid * elem_length + score_index] > score_threshold,  # score above threshold, and class id non-negative unless id_index < 0
+ tvm.any(id_index < 0, data[tid * elem_length + id_index] >= 0))):
flag[tid] = 1
- idx[tid] = 1
- with ib.else_scope():
- flag[tid] = 0
- idx[tid] = 0
+ atomic_add_return[0] = atomic_add(tvm.call_pure_intrin("handle", "tvm_address_of",  # add one_count to this batch's valid_count entry
+ valid_count[i]), one_count)
return ib.get()
-def get_valid_counts_upsweep(data, idx_in, idx, partial):
- """Low level IR of first step of scan: unsweep.
-
- Parameters
- ----------
- data: Buffer
- 3D Buffer with shape [batch_size, num_anchors, elem_length], output of nms.
-
- idx_in : Buffer
- 2D Buffer of valid data indices with shape [batch_size, num_anchors].
-
- idx : Buffer
- 2D Buffer of valid data indices with shape [batch_size, num_anchors].
-
- partial : Buffer
- 2D Buffer of valid data indices with shape [batch_size, new_range].
-
- Returns
- -------
- stmt : Stmt
- The result IR statement.
- """
- batch_size = data.shape[0]
- num_anchors = data.shape[1]
- ib = tvm.ir_builder.create()
- data = ib.buffer_ptr(data)
- idx_in = ib.buffer_ptr(idx_in)
- idx = ib.buffer_ptr(idx)
- partial = ib.buffer_ptr(partial)
- max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads)
- elem_per_thread = num_anchors // max_threads + 1
- nthread_tx = max_threads
- nthread_bx = batch_size
- tx = tvm.thread_axis("threadIdx.x")
- bx = tvm.thread_axis("blockIdx.x")
- ib.scope_attr(tx, "thread_extent", nthread_tx)
- ib.scope_attr(bx, "thread_extent", nthread_bx)
- new_range = num_anchors // elem_per_thread + 1
- # Scan: Upsweep:
- with ib.if_scope(tvm.all(bx < batch_size, tx < new_range)):
- with ib.for_range(0, elem_per_thread) as i:
- with ib.if_scope(bx * num_anchors + \
- tx * elem_per_thread + i < batch_size * num_anchors):
- with ib.if_scope(i == 0):
- partial[bx * new_range + tx] = idx_in[bx * num_anchors + tx * elem_per_thread]
- idx[bx * num_anchors + tx * elem_per_thread] = \
- idx_in[bx * num_anchors + tx * elem_per_thread]
- with ib.else_scope():
- partial[bx * new_range + tx] += \
- idx_in[bx * num_anchors + tx * elem_per_thread + i]
- idx[bx * num_anchors + tx * elem_per_thread + i] = \
- idx[bx * num_anchors + tx * elem_per_thread + i - 1] + \
- idx_in[bx * num_anchors + tx * elem_per_thread + i]
- ib.emit(tvm.make.Call(None, 'tvm_storage_sync',
- tvm.convert(['shared']),
- tvm.expr.Call.Intrinsic, None, 0))
- return ib.get()
-def get_valid_counts_scan(data, partial_in, partial):
- """Low level IR to do scan.
+def flag_scan(flag, prefix_sum):
+ """Low level IR to calculate correct positions for valid boxes.
Parameters
----------
- data: Buffer
- 3D Buffer with shape [batch_size, num_anchors, elem_length], output of nms.
-
- idx_in : Buffer
- 2D Buffer of valid data indices with shape [batch_size, num_anchors].
-
- idx : Buffer
- 2D Buffer of valid data indices with shape [batch_size, num_anchors].
+ flag : Buffer
+ 2D Buffer of flag indicating valid data with shape [batch_size, num_anchors].
- partial : Buffer
- 2D Buffer of valid data indices with shape [batch_size, new_range].
+ prefix_sum : Buffer
+ 2D Buffer of prefix sum of flags indicating new locations of valid boxes
+ with same shape as flag.
Returns
-------
stmt : Stmt
The result IR statement.
"""
- batch_size = data.shape[0]
- num_anchors = data.shape[1]
- ib = tvm.ir_builder.create()
- partial_in = ib.buffer_ptr(partial_in)
- partial = ib.buffer_ptr(partial)
- max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads)
- elem_per_thread = num_anchors // max_threads + 1
- nthread_tx = max_threads
- nthread_bx = batch_size
- tx = tvm.thread_axis("threadIdx.x")
- bx = tvm.thread_axis("blockIdx.x")
- ib.scope_attr(tx, "thread_extent", nthread_tx)
- ib.scope_attr(bx, "thread_extent", nthread_bx)
- var = tvm.make.node("FloatImm", dtype="float32", value=2)
- new_range = num_anchors // elem_per_thread + 1
- iteration = cast(log(cast(new_range, "float32")) / math.log(2), "int32")
- # Scan: Kogge-Stone adder
- with ib.if_scope(tvm.all(bx < batch_size, tx < tvm.min(new_range, num_anchors))):
- with ib.for_range(0, iteration) as k:
- with ib.if_scope(k == 0):
- with ib.if_scope(tvm.all(tx > 0, tx < tvm.min(new_range, num_anchors))):
- partial[bx * new_range + tx] = \
- partial_in[bx * new_range + tx] + partial_in[bx * new_range + tx - 1]
- with ib.else_scope():
- partial[bx * new_range] = partial_in[bx * new_range]
- with ib.else_scope():
- with ib.if_scope(tvm.all(tx >= cast(power(var, k), "int32"), \
- tx < tvm.min(new_range, num_anchors))):
- partial[bx * new_range + tx] += \
- partial[bx * new_range + tx - cast(power(var, k), "int32")]
- ib.emit(tvm.make.Call(None, 'tvm_storage_sync',
- tvm.convert(['shared']),
- tvm.expr.Call.Intrinsic, None, 0))
- return ib.get()
-
-def get_valid_counts_downsweep(data, idx_in, partial, idx):
- """Low level IR to do downsweep of scan.
-
- Parameters
- ----------
- data: Buffer
- 3D Buffer with shape [batch_size, num_anchors, elem_length], output of nms.
-
- idx_in : Buffer
- 2D Buffer of valid data indices with shape [batch_size, num_anchors].
+ batch_size = flag.shape[0]
+ num_anchors = flag.shape[1]
- partial : Buffer
- 2D Buffer of valid data indices with shape [batch_size, new_range].
+ ib = tvm.ir_builder.create()
- idx : Buffer
- 2D Buffer of valid data indices with shape [batch_size, num_anchors].
+ flag = ib.buffer_ptr(flag)
+ prefix_sum = ib.buffer_ptr(prefix_sum)
- Returns
- -------
- stmt : Stmt
- The result IR statement.
- """
- batch_size = data.shape[0]
- num_anchors = data.shape[1]
- ib = tvm.ir_builder.create()
- idx_in = ib.buffer_ptr(idx_in)
- idx = ib.buffer_ptr(idx)
- partial = ib.buffer_ptr(partial)
- max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads)
- elem_per_thread = num_anchors // max_threads + 1
+ max_threads = int(tvm.target.Target.current(
+ allow_none=False).max_num_threads)
nthread_tx = max_threads
nthread_bx = batch_size * num_anchors // max_threads + 1
tx = tvm.thread_axis("threadIdx.x")
ib.scope_attr(tx, "thread_extent", nthread_tx)
ib.scope_attr(bx, "thread_extent", nthread_bx)
tid = bx * max_threads + tx
- new_range = num_anchors // elem_per_thread + 1
idxd = tvm.indexdiv
idxm = tvm.indexmod
- # Scan: Downsweep:
- with ib. if_scope(tid < batch_size * num_anchors):
- i = idxd(tid, num_anchors) # number of batches
- j = idxm(tid, num_anchors) # number of anchors
- with ib.if_scope(j < elem_per_thread):
- idx[tid] = idx_in[tid]
- with ib.else_scope():
- idx[tid] = idx_in[tid] + partial[i * new_range + idxd(j, elem_per_thread) - 1]
+
+ # initialize prefix_sum
+ with ib.if_scope(tid < batch_size * num_anchors):
+ prefix_sum[tid] = 0
+ with ib.if_scope(tid < batch_size * num_anchors):
+ i = idxd(tid, num_anchors)  # batch row of this element
+ j = idxm(tid, num_anchors)  # anchor position inside that row
+ with ib.for_range(0, j) as r:  # each thread sums the flags of anchors 0..j-1 of its row; its own flag is not included
+ prefix_sum[tid] += flag[i * num_anchors + r]
return ib.get()
-def get_valid_counts_ir(data, flag, idx, valid_count, out):
- """Low level IR to get valid count of bounding boxes
- given a score threshold. Also moves valid boxes to the
+
+def out_rewrite(data, flag, prefix_sum, valid_count, out):
+ """Low level IR to move valid boxes to the
top of input data.
Parameters
flag : Buffer
2D Buffer of flag indicating valid data with shape [batch_size, num_anchors].
- idx : Buffer
- 2D Buffer of valid data indices with shape [batch_size, num_anchors].
+ prefix_sum : Buffer
+ 2D Buffer of prefix sum of flags indicating new locations of valid boxes
+ with same shape as flag.
valid_count : Buffer
- 1-D buffer for valid number of boxes.
+ 1D buffer for valid number of boxes with shape [batch_size, ].
out : Buffer
Rearranged data buffer.
stmt : Stmt
The result IR statement.
"""
- batch_size = data.shape[0]
- num_anchors = data.shape[1]
- elem_length = data.shape[2]
- size = batch_size * num_anchors * elem_length
+ batch_size = out.shape[0]
+ num_anchors = out.shape[1]
+ elem_length = out.shape[2]
ib = tvm.ir_builder.create()
+ one = tvm.const(1, dtype=out.dtype)  # constant in the output dtype; negated below as the fill value
data = ib.buffer_ptr(data)
flag = ib.buffer_ptr(flag)
- idx = ib.buffer_ptr(idx)
valid_count = ib.buffer_ptr(valid_count)
+ prefix_sum = ib.buffer_ptr(prefix_sum)
out = ib.buffer_ptr(out)
- max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads)
+ max_threads = int(tvm.target.Target.current(
+ allow_none=False).max_num_threads)
nthread_tx = max_threads
- nthread_bx = batch_size * num_anchors * elem_length // max_threads + 1
+ nthread_bx = batch_size * num_anchors // max_threads + 1  # one thread per box; the "+ 1" rounds the block count up
tx = tvm.thread_axis("threadIdx.x")
bx = tvm.thread_axis("blockIdx.x")
ib.scope_attr(tx, "thread_extent", nthread_tx)
ib.scope_attr(bx, "thread_extent", nthread_bx)
tid = bx * max_threads + tx
-
idxd = tvm.indexdiv
idxm = tvm.indexmod
i = idxd(tid, num_anchors)
j = idxm(tid, num_anchors)
base_idx = i * num_anchors * elem_length
- with ib.if_scope(flag[tid] > 0):
+ with ib.if_scope(tvm.all(flag[tid] > 0, prefix_sum[tid] >= 0,  # NOTE(review): tid is not compared with batch_size * num_anchors here although the launch rounds the thread count up -- confirm these reads stay in bounds
+ prefix_sum[tid] < num_anchors)):
+ with ib.for_range(0, elem_length) as k:
+ out[base_idx + prefix_sum[tid] * elem_length +  # copy every element of the box to the row given by its prefix sum
+ k] = data[tid * elem_length + k]
+ with ib.if_scope(j >= valid_count[i]):  # rows at or beyond the batch's valid count are overwritten with -1
with ib.for_range(0, elem_length) as k:
- with ib.if_scope(base_idx + (idx[tid] - 1) * elem_length + k < size):
- out[base_idx + (idx[tid] - 1) * elem_length + k] =\
- data[base_idx + j * elem_length + k]
- with ib.if_scope(j == 0):
- valid_count[i] = idx[tid + num_anchors - 1]
- with ib.if_scope(j >= idx[i * num_anchors + num_anchors - 1]):
- with ib.for_range(0, elem_length) as l:
- with ib.if_scope(tid * elem_length + l < size):
- out[tid * elem_length + l] = -1.0
+ out[tid * elem_length + k] = -one
+
return ib.get()
"""
batch_size = data.shape[0]
num_anchors = data.shape[1]
- max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads)
- elem_per_thread = num_anchors // max_threads + 1
- new_range = num_anchors // elem_per_thread + 1
+ data_buf = api.decl_buffer(
+ data.shape, data.dtype, "data_buf", data_alignment=8)
+ valid_count_buf = api.decl_buffer(
+ (batch_size,), "int32", "valid_count_buf", data_alignment=8)
temp_flag_buf = api.decl_buffer(
(batch_size, num_anchors,), "int32", "temp_flag", data_alignment=8)
- temp_idx_buf = api.decl_buffer(
- (batch_size, num_anchors,), "int32", "temp_idx", data_alignment=8)
temp_partial_buf = api.decl_buffer(
- (batch_size, new_range), "int32", "temp_partial", data_alignment=8)
- data_buf = api.decl_buffer(
- data.shape, data.dtype, "data_buf", data_alignment=8)
+ (batch_size, num_anchors), "int32", "temp_partial", data_alignment=8)
+ out_buf = api.decl_buffer(
+ data.shape, data.dtype, "out_buf", data_alignment=8)
- temp_flag, temp_idx = \
- tvm.extern([(batch_size, num_anchors,), (batch_size, num_anchors,)], [data],
- lambda ins, outs: get_valid_counts_pre(
- ins[0], outs[0], outs[1], score_threshold, id_index, score_index),
- dtype=["int32", "int32"],
- out_buffers=[temp_flag_buf, temp_idx_buf],
- name="get_valid_counts_phase_one")
- temp_idx_new, temp_partial = \
- tvm.extern([(batch_size, num_anchors,), (batch_size, new_range)], [data, temp_idx],
- lambda ins, outs: get_valid_counts_upsweep(
- ins[0], ins[1], outs[0], outs[1]),
- dtype=["int32", "int32"],
- out_buffers=[temp_idx_buf, temp_partial_buf],
- name="get_valid_counts_phase_two")
- temp_partial_new = \
- tvm.extern([(batch_size, new_range)], [data, temp_partial],
- lambda ins, outs: get_valid_counts_scan(
- ins[0], ins[1], outs[0]),
- dtype=["int32"],
- out_buffers=[temp_partial_buf],
- name="get_valid_counts_phase_three")
- temp_idx_final = \
- tvm.extern([(batch_size, num_anchors)], [data, temp_idx_new, temp_partial_new],
- lambda ins, outs: get_valid_counts_downsweep(
- ins[0], ins[1], ins[2], outs[0]),
- dtype=["int32"],
- out_buffers=[temp_idx_buf],
- name="get_valid_counts_phase_four")
- valid_count, out_tensor = \
- tvm.extern([(batch_size,), data.shape], [data, temp_flag, temp_idx_final],
- lambda ins, outs: get_valid_counts_ir(
- ins[0], ins[1], ins[2], outs[0], outs[1]),
- dtype=["int32", data.dtype],
- in_buffers=[data_buf, temp_flag_buf, temp_idx_buf],
- name="get_valid_counts_phase_five",
+ valid_count, temp_flag = \
+ tvm.extern([(batch_size,), (batch_size, num_anchors)], [data],
+ lambda ins, outs: get_valid_counts_ir(
+ ins[0], outs[0], outs[1], score_threshold, id_index, score_index),
+ dtype=["int32", "int32"],
+ in_buffers=[data_buf],
+ out_buffers=[valid_count_buf, temp_flag_buf],
+ name="get_valid_counts",
tag="get_valid_counts_gpu")
- return [valid_count, out_tensor]
+ temp_partial = \
+ tvm.extern([(batch_size, num_anchors)], [temp_flag],
+ lambda ins, outs: flag_scan(
+ ins[0], outs[0]),
+ dtype=["int32"],
+ in_buffers=[temp_flag_buf],
+ out_buffers=[temp_partial_buf],
+ name="flag_scan")
+
+ out = \
+ tvm.extern([data.shape], [data, temp_flag, temp_partial, valid_count],
+ lambda ins, outs: out_rewrite(
+ ins[0], ins[1], ins[2], ins[3], outs[0]),
+ dtype=[data.dtype],
+ in_buffers=[data_buf, temp_flag_buf,
+ temp_partial_buf, valid_count_buf],
+ out_buffers=[out_buf],
+ name="out_rewrite")
+
+ return [valid_count, out]
def nms_ir(data, sorted_index, valid_count, out, box_indices,
valid_count = ib.buffer_ptr(valid_count)
out = ib.buffer_ptr(out)
box_indices = ib.buffer_ptr(box_indices)
- num_valid_boxes = ib.allocate("int32", (1,), name="num_valid_boxes", scope="local")
+ num_valid_boxes = ib.allocate(
+ "int32", (1,), name="num_valid_boxes", scope="local")
max_threads = int(
tvm.target.Target.current(allow_none=False).max_num_threads)
ib.scope_attr(bx, "thread_extent", nthread_bx)
j = bx * max_threads + tx
- iou_threshold = tvm.make.node("FloatImm", dtype="float32", value=iou_threshold)
+ iou_threshold = tvm.make.node(
+ "FloatImm", dtype="float32", value=iou_threshold)
top_k = tvm.make.node("IntImm", dtype="int32", value=top_k)
coord_start = tvm.make.node("IntImm", dtype="int32", value=coord_start)
id_index = tvm.make.node("IntImm", dtype="int32", value=id_index)
score_index = tvm.make.node("IntImm", dtype="int32", value=score_index)
- force_suppress = tvm.make.node("IntImm", dtype="int32", value=1 if force_suppress else 0)
+ force_suppress = tvm.make.node(
+ "IntImm", dtype="int32", value=1 if force_suppress else 0)
with ib.for_range(0, batch_size, for_type="unroll") as i:
base_idx = i * num_anchors * box_data_length
with ib.if_scope(tvm.all(iou_threshold > 0, valid_count[i] > 0)):
# Reorder output
- nkeep = if_then_else( \
- tvm.all(top_k > 0, top_k < valid_count[i]),
- top_k, valid_count[i])
+ nkeep = if_then_else(
+ tvm.all(top_k > 0, top_k < valid_count[i]),
+ top_k, valid_count[i])
with ib.if_scope(j < nkeep):
with ib.for_range(0, box_data_length) as k:
out[(base_idx + j * box_data_length + k)] = \
- data[(base_idx + sorted_index[i * num_anchors + j] \
- * box_data_length + k)]
- box_indices[i * num_anchors + j] = sorted_index[i * num_anchors + j]
+ data[(base_idx + sorted_index[i * num_anchors + j]
+ * box_data_length + k)]
+ box_indices[i * num_anchors +
+ j] = sorted_index[i * num_anchors + j]
with ib.if_scope(tvm.all(top_k > 0, top_k < valid_count[i])):
with ib.if_scope(j < valid_count[i] - nkeep):
with ib.for_range(0, box_data_length) as k:
# Apply nms
with ib.for_range(0, valid_count[i]) as k:
offset_k = k * box_data_length
- with ib.if_scope(tvm.all(out[base_idx + offset_k + score_index] > 0, \
- tvm.any(id_index < 0, out[base_idx + offset_k + id_index] >= 0))):
+ with ib.if_scope(tvm.all(out[base_idx + offset_k + score_index] > 0,
+ tvm.any(id_index < 0, out[base_idx +
+ offset_k + id_index] >= 0))):
with ib.if_scope(j < valid_count[i]):
offset_j = j * box_data_length
- with ib.if_scope(tvm.all(j > k, \
- out[base_idx + offset_j + score_index] > 0, \
- tvm.any(id_index < 0, \
- out[base_idx + offset_j + id_index] >= 0), \
- tvm.any(force_suppress > 0, id_index < 0, \
- out[base_idx + offset_k + id_index] == \
+ with ib.if_scope(tvm.all(j > k,
+ out[base_idx + offset_j +
+ score_index] > 0,
+ tvm.any(id_index < 0,
+ out[base_idx + offset_j + id_index] >= 0),
+ tvm.any(force_suppress > 0, id_index < 0,
+ out[base_idx + offset_k + id_index] ==
out[base_idx + offset_j + id_index]))):
iou = calculate_overlap(out, base_idx + offset_j + coord_start,
base_idx + offset_k + coord_start)
with ib.if_scope(j < valid_count[i]):
offset_j = j * box_data_length
with ib.for_range(0, box_data_length) as k:
- out[(base_idx + offset_j + k)] = data[base_idx + offset_j + k]
+ out[(base_idx + offset_j + k)
+ ] = data[base_idx + offset_j + k]
box_indices[i * num_anchors + j] = j
# Set invalid entry to be -1
with ib.if_scope(j < num_anchors - valid_count[i]):
with ib.for_range(0, box_data_length) as k:
- out[base_idx + (j + valid_count[i]) * box_data_length + k] = -1.0
+ out[base_idx + (j + valid_count[i]) *
+ box_data_length + k] = -1.0
box_indices[i * num_anchors + j + valid_count[i]] = -1
# Only return max_output_size number of valid boxes
num_valid_boxes[0] = 0
with ib.if_scope(flag[i * num_anchors + j] > 0):
with ib.for_range(0, elem_length) as k:
out[base_idx + (idx[i * num_anchors + j] - 1) * elem_length + k] \
- = data[base_idx + j * elem_length + k]
+ = data[base_idx + j * elem_length + k]
return ib.get()
"valid_count_buf", data_alignment=4)
score_axis = score_index
score_shape = (batch_size, num_anchors)
- score_tensor = tvm.compute(score_shape, lambda i, j: data[i, j, score_axis], tag=tag.ELEMWISE)
- sort_tensor = argsort(score_tensor, valid_count=valid_count, axis=1, is_ascend=False)
+ score_tensor = tvm.compute(
+ score_shape, lambda i, j: data[i, j, score_axis], tag=tag.ELEMWISE)
+ sort_tensor = argsort(
+ score_tensor, valid_count=valid_count, axis=1, is_ascend=False)
sort_tensor_buf = api.decl_buffer(sort_tensor.shape, sort_tensor.dtype,
"sort_tensor_buf", data_alignment=8)
ins[0], outs[0], outs[1]),
dtype=["int32", "int32"],
in_buffers=[out_buf],
- out_buffers=[temp_flag_buf, temp_idx_buf],
+ out_buffers=[
+ temp_flag_buf, temp_idx_buf],
name="invalid_to_bottom_phase_one")
output = tvm.extern([data.shape], [out, temp_flag, temp_idx],