From 287078c33db85d4f312d8d2457a064442d9d18c3 Mon Sep 17 00:00:00 2001 From: Yao Wang Date: Wed, 3 Jul 2019 10:08:04 -0700 Subject: [PATCH] Pre-allocate buffer for x86 roi_align (#3475) * Pre-allocate buffer for x86 roi_align * Fix typo --- topi/python/topi/x86/roi_align.py | 44 ++++++++++++++++++++++++----------- topi/tests/python/test_topi_vision.py | 1 + 2 files changed, 32 insertions(+), 13 deletions(-) diff --git a/topi/python/topi/x86/roi_align.py b/topi/python/topi/x86/roi_align.py index a8ad387..26b84be 100644 --- a/topi/python/topi/x86/roi_align.py +++ b/topi/python/topi/x86/roi_align.py @@ -16,14 +16,17 @@ # under the License. # pylint: disable=invalid-name, no-member, too-many-locals, too-many-arguments, undefined-variable, too-many-nested-blocks, too-many-branches, too-many-statements """Non-maximum suppression operator for intel cpu""" +import math import tvm from tvm import hybrid from ..vision.rcnn import roi_align_nchw +from ..tensor import full +from ..util import get_const_tuple @hybrid.script -def roi_align_nchw_ir(data, rois, pooled_size, spatial_scale, sample_ratio): +def roi_align_nchw_ir(data, rois, w_pc, pos_pc, pooled_size, spatial_scale, sample_ratio): """Hybrid routing fo ROI align operator in NCHW layout. Parameters @@ -35,6 +38,12 @@ def roi_align_nchw_ir(data, rois, pooled_size, spatial_scale, sample_ratio): 2-D with shape [num_roi, 5]. The last dimension should be in format of [batch_index, w_start, h_start, w_end, h_end] + w_pc : tvm.Tensor or numpy NDArray + 3-D weight pre-calculation buffer + + pos_pc : tvm.Tensor or numpy NDArray + 3-D position pre-calculation buffer + pooled_size : tvm ConsExpr [out_height, out_width] @@ -57,9 +66,6 @@ def roi_align_nchw_ir(data, rois, pooled_size, spatial_scale, sample_ratio): pooled_size_h = pooled_size[0] pooled_size_w = pooled_size[1] output = output_tensor((num_rois, channels, pooled_size_h, pooled_size_w), data.dtype) - max_num_pc_index = height * width * pooled_size_h * pooled_size_w - w_pc = allocate((num_rois, max_num_pc_index, 4), data.dtype) - pos_pc = allocate((num_rois, max_num_pc_index, 4), "int32") for n in parallel(num_rois): roi_batch_index = int32(rois[n, 0]) @@ -76,18 +82,16 @@ def roi_align_nchw_ir(data, rois, pooled_size, spatial_scale, sample_ratio): roi_bin_grid_h = sample_ratio roi_bin_grid_w = roi_bin_grid_h - div_h = roi_h / pooled_size_h - div_w = roi_w / pooled_size_w - rounded_div_h = int32(div_h) * 1.0 - rounded_div_w = int32(div_w) * 1.0 + rounded_bin_h = int32(bin_h) * 1.0 + rounded_bin_w = int32(bin_w) * 1.0 if sample_ratio <= 0: # Cannot use ceil function since hybrid script # doesn't support Call as indexing - roi_bin_grid_h = int32(div_h) - roi_bin_grid_w = int32(div_w) - if rounded_div_h < div_h: + roi_bin_grid_h = int32(bin_h) + roi_bin_grid_w = int32(bin_w) + if rounded_bin_h < bin_h: roi_bin_grid_h += 1 - if rounded_div_w < div_w: + if rounded_bin_w < bin_w: roi_bin_grid_w += 1 count = roi_bin_grid_h * roi_bin_grid_w @@ -211,7 +215,21 @@ def roi_align_nchw_cpu(data, rois, pooled_size, spatial_scale, sample_ratio=-1): """ if not isinstance(pooled_size, (tuple, list)): pooled_size = (pooled_size, pooled_size) + + # Pre-allocate intermediate buffer + if sample_ratio > 0: + max_roi_bin_grid_w = max_roi_bin_grid_h = sample_ratio + else: + _, _, height, width = get_const_tuple(data.shape) + max_roi_bin_grid_h = math.ceil(height / pooled_size[0]) + max_roi_bin_grid_w = math.ceil(width / pooled_size[1]) + max_pc_shape = (rois.shape[0], max_roi_bin_grid_h * max_roi_bin_grid_w + * pooled_size[0] * pooled_size[1], 4) + w_pc_buffer = full(max_pc_shape, data.dtype, 0) + pos_pc_buffer = full(max_pc_shape, "int32", 0) + pooled_size = tvm.convert(pooled_size) spatial_scale = tvm.const(spatial_scale, "float32") sample_ratio = tvm.const(sample_ratio, "int32") - return roi_align_nchw_ir(data, rois, pooled_size, spatial_scale, sample_ratio) + return roi_align_nchw_ir(data, rois, w_pc_buffer, pos_pc_buffer, + pooled_size, spatial_scale, sample_ratio) diff --git a/topi/tests/python/test_topi_vision.py b/topi/tests/python/test_topi_vision.py index 3a0b134..08b6d2e 100644 --- a/topi/tests/python/test_topi_vision.py +++ b/topi/tests/python/test_topi_vision.py @@ -306,6 +306,7 @@ def test_roi_align(): verify_roi_align(1, 16, 32, 64, 7, 1.0, -1) verify_roi_align(4, 16, 32, 64, 7, 0.5, 2) verify_roi_align(1, 32, 32, 80, 8, 0.0625, 2) + verify_roi_align(1, 32, 500, 80, 8, 0.0625, 2) def verify_roi_pool(batch, in_channel, in_size, num_roi, pooled_size, spatial_scale): -- 2.7.4