Pre-allocate buffer for x86 roi_align (#3475)
authorYao Wang <kevinthesunwy@gmail.com>
Wed, 3 Jul 2019 17:08:04 +0000 (10:08 -0700)
committerYizhi Liu <liuyizhi@apache.org>
Wed, 3 Jul 2019 17:08:04 +0000 (10:08 -0700)
* Pre-allocate buffer for x86 roi_align

* Fix typo

topi/python/topi/x86/roi_align.py
topi/tests/python/test_topi_vision.py

index a8ad387..26b84be 100644 (file)
 # under the License.
 # pylint: disable=invalid-name, no-member, too-many-locals, too-many-arguments, undefined-variable, too-many-nested-blocks, too-many-branches, too-many-statements
 """Non-maximum suppression operator for intel cpu"""
+import math
 import tvm
 
 from tvm import hybrid
 from ..vision.rcnn import roi_align_nchw
+from ..tensor import full
+from ..util import get_const_tuple
 
 
 @hybrid.script
-def roi_align_nchw_ir(data, rois, pooled_size, spatial_scale, sample_ratio):
+def roi_align_nchw_ir(data, rois, w_pc, pos_pc, pooled_size, spatial_scale, sample_ratio):
     """Hybrid routing fo ROI align operator in NCHW layout.
 
     Parameters
@@ -35,6 +38,12 @@ def roi_align_nchw_ir(data, rois, pooled_size, spatial_scale, sample_ratio):
         2-D with shape [num_roi, 5]. The last dimension should be in format of
         [batch_index, w_start, h_start, w_end, h_end]
 
+    w_pc : tvm.Tensor or numpy NDArray
+        3-D weight pre-calculation buffer
+
+    pos_pc : tvm.Tensor or numpy NDArray
+        3-D position pre-calculation buffer
+
     pooled_size : tvm ConsExpr
         [out_height, out_width]
 
@@ -57,9 +66,6 @@ def roi_align_nchw_ir(data, rois, pooled_size, spatial_scale, sample_ratio):
     pooled_size_h = pooled_size[0]
     pooled_size_w = pooled_size[1]
     output = output_tensor((num_rois, channels, pooled_size_h, pooled_size_w), data.dtype)
-    max_num_pc_index = height * width * pooled_size_h * pooled_size_w
-    w_pc = allocate((num_rois, max_num_pc_index, 4), data.dtype)
-    pos_pc = allocate((num_rois, max_num_pc_index, 4), "int32")
 
     for n in parallel(num_rois):
         roi_batch_index = int32(rois[n, 0])
@@ -76,18 +82,16 @@ def roi_align_nchw_ir(data, rois, pooled_size, spatial_scale, sample_ratio):
 
         roi_bin_grid_h = sample_ratio
         roi_bin_grid_w = roi_bin_grid_h
-        div_h = roi_h / pooled_size_h
-        div_w = roi_w / pooled_size_w
-        rounded_div_h = int32(div_h) * 1.0
-        rounded_div_w = int32(div_w) * 1.0
+        rounded_bin_h = int32(bin_h) * 1.0
+        rounded_bin_w = int32(bin_w) * 1.0
         if sample_ratio <= 0:
             # Cannot use ceil function since hybrid script
             # doesn't support Call as indexing
-            roi_bin_grid_h = int32(div_h)
-            roi_bin_grid_w = int32(div_w)
-            if rounded_div_h < div_h:
+            roi_bin_grid_h = int32(bin_h)
+            roi_bin_grid_w = int32(bin_w)
+            if rounded_bin_h < bin_h:
                 roi_bin_grid_h += 1
-            if rounded_div_w < div_w:
+            if rounded_bin_w < bin_w:
                 roi_bin_grid_w += 1
 
         count = roi_bin_grid_h * roi_bin_grid_w
@@ -211,7 +215,21 @@ def roi_align_nchw_cpu(data, rois, pooled_size, spatial_scale, sample_ratio=-1):
     """
     if not isinstance(pooled_size, (tuple, list)):
         pooled_size = (pooled_size, pooled_size)
+
+    # Pre-allocate intermediate buffer
+    if sample_ratio > 0:
+        max_roi_bin_grid_w = max_roi_bin_grid_h = sample_ratio
+    else:
+        _, _, height, width = get_const_tuple(data.shape)
+        max_roi_bin_grid_h = math.ceil(height / pooled_size[0])
+        max_roi_bin_grid_w = math.ceil(width / pooled_size[1])
+    max_pc_shape = (rois.shape[0], max_roi_bin_grid_h * max_roi_bin_grid_w
+                    * pooled_size[0] * pooled_size[1], 4)
+    w_pc_buffer = full(max_pc_shape, data.dtype, 0)
+    pos_pc_buffer = full(max_pc_shape, "int32", 0)
+
     pooled_size = tvm.convert(pooled_size)
     spatial_scale = tvm.const(spatial_scale, "float32")
     sample_ratio = tvm.const(sample_ratio, "int32")
-    return roi_align_nchw_ir(data, rois, pooled_size, spatial_scale, sample_ratio)
+    return roi_align_nchw_ir(data, rois, w_pc_buffer, pos_pc_buffer,
+                             pooled_size, spatial_scale, sample_ratio)
index 3a0b134..08b6d2e 100644 (file)
@@ -306,6 +306,7 @@ def test_roi_align():
     verify_roi_align(1, 16, 32, 64, 7, 1.0, -1)
     verify_roi_align(4, 16, 32, 64, 7, 0.5, 2)
     verify_roi_align(1, 32, 32, 80, 8, 0.0625, 2)
+    verify_roi_align(1, 32, 500, 80, 8, 0.0625, 2)
 
 
 def verify_roi_pool(batch, in_channel, in_size, num_roi, pooled_size, spatial_scale):