[TOPI][OP] Support Faster-RCNN Proposal OP on CPU (#4297)
authorZhao Wu <wuzhaozju@gmail.com>
Wed, 13 Nov 2019 06:11:38 +0000 (14:11 +0800)
committerWuwei Lin <wuwei@apache.org>
Wed, 13 Nov 2019 06:11:38 +0000 (01:11 -0500)
* Support Proposal operator on CPU.

* PyLint space issue

* PyLint space issue

* Pylint singleton-comparison issue

tests/python/relay/test_op_level5.py
topi/python/topi/vision/rcnn/proposal.py
topi/tests/python/test_topi_vision.py

index fb5dbcc..f744746 100644 (file)
@@ -424,7 +424,7 @@ def test_proposal():
 
         func = relay.Function([cls_prob, bbox_pred, im_info], z)
         func = run_infer_type(func)
-        for target in ['cuda']:
+        for target in ['llvm', 'cuda']:
             if not tvm.module.enabled(target):
                 print("Skip test because %s is not enabled." % target)
                 continue
index 1df25a0..507d464 100644 (file)
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-# pylint: disable=invalid-name
+# pylint: disable=invalid-name, singleton-comparison
 """Proposal operator"""
 import math
 import tvm
-
+from ...util import get_const_tuple, get_const_int
+from ...sort import argsort
 
 def generate_anchor(ratio, scale, base_size):
     """Generate anchor"""
@@ -60,6 +61,261 @@ def reg_iou(x1, y1, x2, y2, dx1, dy1, dx2, dy2):
     pred_y2 = y2 + dy2
     return pred_x1, pred_y1, pred_x2, pred_y2
 
+def predict_bbox_ir(cls_prob_buf, bbox_pred_buf, im_info_buf, out_buf, scales, ratios,
+                    feature_stride, rpn_min_size, iou_loss):
+    """Predict bounding boxes based on anchors, scores and deltas.
+
+    Parameters
+    ----------
+    cls_prob_buf : tvm.schedule.Buffer
+        4-D with shape [batch, 2 * num_anchors, height, width]
+
+    bbox_pred_buf : tvm.schedule.Buffer
+        4-D with shape [batch, 4 * num_anchors, height, width]
+
+    im_info_buf : tvm.schedule.Buffer
+        2-D with shape [batch, 3]
+
+    out_buf : tvm.schedule.Buffer
+        3-D with shape [batch, num_bbox, 5]
+        The last dimension is in format of [w_start, h_start, w_end, h_end, score]
+
+    scales : list/tuple of float
+        Scales of anchor windoes.
+
+    ratios : list/tuple of float
+        Ratios of anchor windoes.
+
+    feature_stride : int
+        The size of the receptive field each unit in the convolution layer of the rpn, for example
+        the product of all stride's prior to this layer.
+
+    rpn_min_size : int
+        Minimum height or width in proposal.
+
+    iou_loss : bool
+        Usage of IoU loss.
+
+    Returns
+    -------
+    stmt : Stmt
+        The result IR statement.
+    """
+    batch, num_anchors, height, width = get_const_tuple(cls_prob_buf.shape)
+    num_anchors //= 2
+    ib = tvm.ir_builder.create()
+
+    p_score = ib.buffer_ptr(cls_prob_buf)
+    p_delta = ib.buffer_ptr(bbox_pred_buf)
+    p_im_info = ib.buffer_ptr(im_info_buf)
+    p_out = ib.buffer_ptr(out_buf)
+
+    idxm = tvm.indexmod
+    idxd = tvm.indexdiv
+
+    with ib.for_range(0, batch * height * width) as tid:
+        w = idxm(tid, width)
+        h = idxm(idxd(tid, width), height)
+        b = idxd(idxd(tid, width), height)
+
+        for k in range(num_anchors):
+            out_index = tid * num_anchors + k
+            ratio = ratios[k // len(scales)]
+            scale = scales[k % len(scales)]
+            anchor = generate_anchor(ratio, scale, feature_stride)
+            im_height = p_im_info[b * 3]
+            im_width = p_im_info[b * 3 + 1]
+            x1 = anchor[0] + w * feature_stride
+            y1 = anchor[1] + h * feature_stride
+            x2 = anchor[2] + w * feature_stride
+            y2 = anchor[3] + h * feature_stride
+
+            delta = [p_delta[((((b * num_anchors + k) * 4 + i) * height + h) * width + w)]
+                     for i in range(4)]
+            regression_func = reg_iou if iou_loss else reg_bbox
+            pred_x1, pred_y1, pred_x2, pred_y2 = regression_func(x1, y1, x2, y2, *delta)
+
+            pred_x1 = tvm.max(tvm.min(pred_x1, im_width - 1.0), 0.0)
+            pred_y1 = tvm.max(tvm.min(pred_y1, im_height - 1.0), 0.0)
+            pred_x2 = tvm.max(tvm.min(pred_x2, im_width - 1.0), 0.0)
+            pred_y2 = tvm.max(tvm.min(pred_y2, im_height - 1.0), 0.0)
+
+            real_height = (im_height / feature_stride).astype('int32')
+            real_width = (im_width / feature_stride).astype('int32')
+
+            bbox_w = pred_x2 - pred_x1 + 1.0
+            bbox_h = pred_y2 - pred_y1 + 1.0
+            min_size = p_im_info[b * 3 + 2] * rpn_min_size
+
+            pred_score = p_score[((b * num_anchors * 2 + num_anchors + k) * height + h) * width + w]
+            pred_score = tvm.expr.Select(tvm.any(h >= real_height, w >= real_width),
+                                         -1.0, pred_score)
+            p_out[out_index * 5 + 0] = pred_x1
+            p_out[out_index * 5 + 1] = pred_y1
+            p_out[out_index * 5 + 2] = pred_x2
+            p_out[out_index * 5 + 3] = pred_y2
+            p_out[out_index * 5 + 4] = pred_score
+
+            with ib.if_scope(tvm.any(bbox_w < min_size, bbox_h < min_size)):
+                p_out[out_index * 5 + 0] -= min_size / 2.0
+                p_out[out_index * 5 + 1] -= min_size / 2.0
+                p_out[out_index * 5 + 2] += min_size / 2.0
+                p_out[out_index * 5 + 3] += min_size / 2.0
+                p_out[out_index * 5 + 4] = -1.0
+
+    return ib.get()
+
+
+def argsort_ir(data_buf, out_index_buf):
+    """Batched odd-even transposition sort.
+
+    Parameters
+    ----------
+    data_buf : tvm.schedule.Buffer
+        2-D with shape [batch, num_bbox]
+
+    out_index_buf : tvm.schedule.Buffer
+        2-D with shape [batch, num_bbox]. Indices of data in sorted order.
+
+    Returns
+    -------
+    stmt : Stmt
+        The result IR statement.
+    """
+    batch, num_bbox = get_const_tuple(data_buf.shape)
+    ib = tvm.ir_builder.create()
+    p_data = ib.buffer_ptr(data_buf)
+    index_out = ib.buffer_ptr(out_index_buf)
+    temp_data = ib.allocate("float32", (1,), name="temp_data", scope="local")
+    temp_index = ib.allocate("int32", (1,), name="temp_index", scope="local")
+    idxm = tvm.indexmod
+    with ib.for_range(0, batch, for_type="unroll") as b:
+        start = b * num_bbox
+        for i in range(2):
+            with ib.for_range(0, (num_bbox + 1) // 2) as tid:
+                bbox_id = tid * 2 + i
+                with ib.if_scope(bbox_id < num_bbox):
+                    index_out[start + bbox_id] = bbox_id
+        with ib.for_range(0, num_bbox) as k:
+            with ib.for_range(0, (num_bbox + 1) // 2) as tid:
+                offset = start + 2 * tid + idxm(k, 2)
+                with ib.if_scope(tvm.all(offset + 1 < num_bbox,
+                                         p_data[offset] < p_data[offset + 1])):
+                    temp_data[0] = p_data[offset]
+                    p_data[offset] = p_data[offset + 1]
+                    p_data[offset + 1] = temp_data[0]
+                    temp_index[0] = index_out[offset]
+                    index_out[offset] = index_out[offset + 1]
+                    index_out[offset + 1] = temp_index[0]
+    return ib.get()
+
+
+def nms_ir(sorted_bbox_buf, out_buf, nms_threshold):
+    """Non-maximum supression.
+
+    Parameters
+    ----------
+    sorted_bbox_buf : tvm.schedule.Buffer
+        3-D with shape [batch, num_bbox, 5]. The last dimension is in format of
+        [w_start, h_start, w_end, h_end, score].
+
+    out_buf : tvm.schedule.Buffer
+        2-D with shape [batch, num_bbox]. Boolean mask of whether a bounding box should be removed.
+
+    nms_threshold : float
+        Non-maximum suppression threshold.
+
+    Returns
+    -------
+    stmt : Stmt
+        The result IR statement.
+    """
+    def calculate_overlap(out_tensor, box_a_idx, box_b_idx):
+        """Calculate overlap of two boxes.
+        """
+        w = tvm.max(0.0, tvm.min(out_tensor[box_a_idx + 2], out_tensor[box_b_idx + 2])
+                    - tvm.max(out_tensor[box_a_idx], out_tensor[box_b_idx]) + 1.0)
+        h = tvm.max(0.0, tvm.min(out_tensor[box_a_idx + 3], out_tensor[box_b_idx + 3])
+                    - tvm.max(out_tensor[box_a_idx + 1], out_tensor[box_b_idx + 1]) + 1.0)
+        i = w * h
+        u = (out_tensor[box_a_idx + 2] - out_tensor[box_a_idx] + 1.0) * \
+            (out_tensor[box_a_idx + 3] - out_tensor[box_a_idx + 1] + 1.0) + \
+            (out_tensor[box_b_idx + 2] - out_tensor[box_b_idx] + 1.0) * \
+            (out_tensor[box_b_idx + 3] - out_tensor[box_b_idx + 1] + 1.0) - i
+        return i / u
+
+    batch, num_bbox = get_const_tuple(out_buf.shape)
+    ib = tvm.ir_builder.create()
+    p_data = ib.buffer_ptr(sorted_bbox_buf)
+    p_out = ib.buffer_ptr(out_buf)
+    with ib.for_range(0, batch, for_type="unroll", name="n") as b:
+        base_idx = b * num_bbox
+        for i in range(num_bbox):
+            p_out[base_idx + i] = False
+        with ib.for_range(0, num_bbox - 1) as l:
+            with ib.for_range(0, num_bbox) as i:
+                with ib.if_scope(tvm.all(i < num_bbox, i > l, p_out[base_idx + l] == False)):
+                    iou = calculate_overlap(p_data, (base_idx + l) * 5, (base_idx + i) * 5)
+                    with ib.if_scope(iou > nms_threshold):
+                        p_out[base_idx + i] = True
+    return ib.get()
+
+
+def prepare_output_ir(sorted_bbox_buf, remove_mask_buf, out_buf):
+    """Copy output after applying nms to continuous memory.
+
+    Parameters
+    ----------
+    sorted_bbox_buf : tvm.schedule.Buffer
+        3-D with shape [batch, num_bbox, 5]. The last dimension is in format of
+        [w_start, h_start, w_end, h_end, score].
+
+    remove_mask_buf : tvm.schedule.Buffer
+        2-D with shape [batch, num_bbox]. Boolean mask of whether a bounding box should be removed.
+
+    out_buf : tvm.schedule.Buffer
+        2-D with shape [batch * rpn_post_nms_top_n, 5]. The last dimension is in format of
+        [batch_index, w_start, h_start, w_end, h_end].
+
+    Returns
+    -------
+    stmt : Stmt
+        The result IR statement.
+    """
+    batch, num_bbox, _ = get_const_tuple(sorted_bbox_buf.shape)
+    rpn_post_nms_top_n = get_const_int(out_buf.shape[0]) // batch
+    ib = tvm.ir_builder.create()
+    i = ib.allocate('int32', (batch,), 'i', scope='local')
+    p_sorted_bbox = ib.buffer_ptr(sorted_bbox_buf)
+    p_remove = ib.buffer_ptr(remove_mask_buf)
+    p_out = ib.buffer_ptr(out_buf)
+
+    nkeep = ib.allocate('int32', (batch,), 'nkeep', scope='local')
+
+    with ib.for_range(0, batch) as b:
+        nkeep[b] = 0
+        i[b] = 0
+
+    with ib.for_range(0, num_bbox) as j:
+        with ib.for_range(0, batch) as b:
+            with ib.if_scope(p_remove[b * num_bbox + j] == False):
+                nkeep[b] += 1
+    with ib.for_range(0, batch) as b:
+        with ib.if_scope(nkeep[b] > 0):
+            with ib.for_range(0, tvm.ceil(
+                tvm.const(rpn_post_nms_top_n, 'float32') / nkeep[b]).astype('int32')):
+                with ib.for_range(0, num_bbox) as j:
+                    offset_j = (b * num_bbox + j) * 5
+                    offset_i = (b * rpn_post_nms_top_n + i[b]) * 5
+                    with ib.if_scope(tvm.all(i[b] < rpn_post_nms_top_n,
+                                             p_remove[(b*num_bbox+j)] == False)):
+                        p_out[offset_i] = tvm.expr.Cast('float32', b)
+                        with ib.for_range(0, 4, for_type='unroll') as k:
+                            p_out[offset_i + k + 1] = p_sorted_bbox[offset_j + k]
+                        i[b] = i[b] + 1
+
+    body = ib.get()
+    return body
 
 @tvm.target.generic_func
 def proposal(cls_prob, bbox_pred, im_info, scales, ratios, feature_stride, threshold,
@@ -109,4 +365,25 @@ def proposal(cls_prob, bbox_pred, im_info, scales, ratios, feature_stride, thres
         [batch_index, w_start, h_start, w_end, h_end].
     """
     # pylint: disable=unused-argument
-    raise ValueError("missing register for topi.vision.rcnn.proposal")
+    batch, _, height, width = get_const_tuple(cls_prob.shape)
+    num_anchors = len(scales) * len(ratios)
+    num_bbox = height * width * num_anchors
+    rpn_pre_nms_top_n = min(rpn_pre_nms_top_n, num_bbox) if rpn_pre_nms_top_n > 0 else num_bbox
+
+    bbox = tvm.extern((batch, num_bbox, 5), [cls_prob, bbox_pred, im_info], lambda ins, outs:
+                      predict_bbox_ir(ins[0], ins[1], ins[2], outs[0], scales, ratios,
+                                      feature_stride, rpn_min_size, iou_loss),
+                      dtype=bbox_pred.dtype)
+    score = tvm.compute((batch, num_bbox), lambda b, i: bbox[b, i, 4], tag='bbox_score')
+    valid_count_shape = (1,)
+    valid_count = tvm.compute(valid_count_shape, lambda i: num_bbox)
+    sorted_index = argsort(score, valid_count=valid_count, axis=1, is_ascend=False)
+    sorted_bbox = tvm.compute((batch, rpn_pre_nms_top_n, 5),
+                              lambda b, i, j: bbox[b, sorted_index[b, i], j], tag='sorted_bbox')
+    nms_remove_mask = tvm.extern((batch, rpn_pre_nms_top_n), [sorted_bbox],
+                                 lambda ins, outs: nms_ir(ins[0], outs[0], threshold),
+                                 dtype='bool')
+    nms_out = tvm.extern((batch * rpn_post_nms_top_n, 5), [sorted_bbox, nms_remove_mask],
+                         lambda ins, outs: prepare_output_ir(ins[0], ins[1], outs[0]),
+                         dtype=sorted_bbox.dtype)
+    return nms_out
index 08b6d2e..a081f07 100644 (file)
@@ -378,7 +378,7 @@ def verify_proposal(np_cls_prob, np_bbox_pred, np_im_info, np_out, attrs):
             f(tvm_cls_prob, tvm_bbox_pred, tvm_im_info, tvm_out)
             tvm.testing.assert_allclose(tvm_out.asnumpy(), np_out, rtol=1e-4)
 
-    for device in ['cuda']:
+    for device in ['llvm', 'cuda']:
         check_device(device)