[TOPI] upsample operator 'NCHWinic' format support. (#4791)
author: Hua Jiang <huaj@xilinx.com>
Mon, 3 Feb 2020 19:36:53 +0000 (11:36 -0800)
committer: GitHub <noreply@github.com>
Mon, 3 Feb 2020 19:36:53 +0000 (11:36 -0800)
* [TOPI] upsample operator 'NCHWinic' format support.

Some hardware accelerators require data in a packed format such as NCHWinic
to fit their hardware resources. This change adds NCHWinic layout support to
the upsample operator to meet that requirement.

* address review comments, add assert for 'else must be NCHWxc' logic.

topi/python/topi/image/resize.py
topi/python/topi/testing/bilinear_resize_python.py
topi/python/topi/testing/upsampling_python.py
topi/python/topi/util.py
topi/tests/python/test_topi_upsampling.py

index 00ae5d6..0c02867 100644 (file)
 """TVM operator input resize compute."""
 from __future__ import absolute_import
 import tvm
+from topi.util import nchw_pack_layout, nchw_xc_layout
 from .. import tag
 
+def get_2d_indices(indices, layout='NCHW'):
+    """Unpack an index tuple into named per-axis components.
+
+    Parameters
+    ----------
+    indices : tuple
+        Index tuple whose ordering follows ``layout``.
+
+    layout : str
+        'NHWC', 'NCHW', a packed NCHWinic layout, or an NCHWxc layout.
+
+    Returns
+    -------
+    tuple
+        ``(n, c, y, x, cc, inum, ic)``. ``cc`` is None for 'NHWC' and 'NCHW'
+        and 0 for packed layouts; ``inum`` and ``ic`` are 0 unless the layout
+        is packed.
+    """
+    if layout == 'NHWC':
+        n, y, x, c = indices
+        return n, c, y, x, None, 0, 0
+    if layout == 'NCHW':
+        n, c, y, x = indices
+        return n, c, y, x, None, 0, 0
+    if nchw_pack_layout(layout):
+        n, c, y, x, inum, ic = indices
+        return n, c, y, x, 0, inum, ic
+
+    # Anything that reaches this point must be NCHWxc.
+    assert nchw_xc_layout(layout)
+    n, c, y, x, cc = indices
+    return n, c, y, x, cc, 0, 0
+
+def get_2d_pixel(data, layout, boxes, image_height, image_width, n, c, y, x, cc, ib, ic):
+    """Read one element of ``data`` at the given coordinates, cast to 'float'.
+
+    When ``boxes`` is None, ``y`` is first clamped to ``[0, image_height - 1]``
+    and ``x`` to ``[0, image_width - 1]``. The order in which the coordinates
+    are passed to ``data`` depends on ``layout``:
+
+    - 'NHWC'            -> ``data(n, y, x, c)``
+    - 'NCHW'            -> ``data(n, c, y, x)``
+    - packed (NCHWinic) -> ``data(n, c, y, x, ib, ic)``
+    - NCHWxc            -> ``data(n, c, y, x, cc)``
+    """
+    if boxes is None:
+        # No crop boxes supplied: keep the sample position inside the image.
+        y = tvm.max(tvm.min(y, image_height - 1), 0)
+        x = tvm.max(tvm.min(x, image_width - 1), 0)
+
+    if layout == 'NHWC':
+        coords = (n, y, x, c)
+    elif layout == 'NCHW':
+        coords = (n, c, y, x)
+    elif nchw_pack_layout(layout):
+        coords = (n, c, y, x, ib, ic)
+    else:
+        # Anything that reaches this point must be NCHWxc.
+        assert nchw_xc_layout(layout)
+        coords = (n, c, y, x, cc)
+
+    return data(*coords).astype('float')
 
 def resize_nearest_neighbor(indices, data, image_height, image_width,
                             target_height, target_width, boxes=None,
@@ -89,29 +123,7 @@ def resize_nearest_neighbor(indices, data, image_height, image_width,
             dtype = data_dtype
         return value.astype(dtype)
 
-    def _get_indices(indices, layout='NCHW'):
-        if layout == 'NHWC':
-            n, y, x, c = indices
-            cc = None
-        elif layout == 'NCHW':
-            n, c, y, x = indices
-            cc = None
-        else:
-            n, c, y, x, cc = indices
-        return n, c, y, x, cc
-
-    def _get_pixel(data, layout, n, c, y, x, cc):
-        if boxes is None:
-            y = tvm.max(tvm.min(y, image_height - 1), 0)
-            x = tvm.max(tvm.min(x, image_width - 1), 0)
-        if layout == 'NHWC':
-            return data(n, y, x, c).astype('float')
-        if layout == 'NCHW':
-            return data(n, c, y, x).astype('float')
-        # else must be NCHWxc
-        return data(n, c, y, x, cc).astype('float')
-
-    n, c, y, x, cc = _get_indices(indices, layout)
+    n, c, y, x, cc, inum, ic = get_2d_indices(indices, layout)
     box_idx = box_indices(n) if box_indices is not None else n
     if boxes is not None:
         y1, x1 = boxes(n, 0), boxes(n, 1)
@@ -146,7 +158,8 @@ def resize_nearest_neighbor(indices, data, image_height, image_width,
         closest_y_index = tvm.floor(in_y + epsilon).astype('int32')
         closest_x_index = tvm.floor(in_x + epsilon).astype('int32')
 
-    value = _get_pixel(data, layout, box_idx, c, closest_y_index, closest_x_index, cc)
+    value = get_2d_pixel(data, layout, boxes, image_height, image_width,
+                         box_idx, c, closest_y_index, closest_x_index, cc, inum, ic)
 
     if extrapolation_value is not None:
         out = tvm.if_then_else(in_y < 0,
@@ -234,29 +247,7 @@ def resize_bilinear(indices, data, image_height, image_width,
     def _lerp(A, B, t):
         return A * (1.0 - t) + B * t
 
-    def _get_indices(indices, layout='NCHW'):
-        if layout == 'NHWC':
-            n, y, x, c = indices
-            cc = None
-        elif layout == 'NCHW':
-            n, c, y, x = indices
-            cc = None
-        else:
-            n, c, y, x, cc = indices
-        return n, c, y, x, cc
-
-    def _get_pixel(data, layout, n, c, y, x, cc):
-        if boxes is None:
-            y = tvm.max(tvm.min(y, image_height - 1), 0)
-            x = tvm.max(tvm.min(x, image_width - 1), 0)
-        if layout == 'NHWC':
-            return data(n, y, x, c).astype('float')
-        if layout == 'NCHW':
-            return data(n, c, y, x).astype('float')
-        # else must be NCHWxc
-        return data(n, c, y, x, cc).astype('float')
-
-    n, c, y, x, cc = _get_indices(indices, layout=layout)
+    n, c, y, x, cc, inum, ic = get_2d_indices(indices, layout=layout)
     box_idx = box_indices(n) if box_indices is not None else n
 
     if boxes is not None:
@@ -296,10 +287,14 @@ def resize_bilinear(indices, data, image_height, image_width,
     right_x_index = tvm.ceil(in_x).astype('int32')
     x_lerp = in_x - left_x_index
 
-    top_left = _get_pixel(data, layout, box_idx, c, top_y_index, left_x_index, cc)
-    top_right = _get_pixel(data, layout, box_idx, c, top_y_index, right_x_index, cc)
-    bottom_left = _get_pixel(data, layout, box_idx, c, bottom_y_index, left_x_index, cc)
-    bottom_right = _get_pixel(data, layout, box_idx, c, bottom_y_index, right_x_index, cc)
+    top_left = get_2d_pixel(data, layout, boxes, image_height, image_width,
+                            box_idx, c, top_y_index, left_x_index, cc, inum, ic)
+    top_right = get_2d_pixel(data, layout, boxes, image_height, image_width,
+                             box_idx, c, top_y_index, right_x_index, cc, inum, ic)
+    bottom_left = get_2d_pixel(data, layout, boxes, image_height, image_width,
+                               box_idx, c, bottom_y_index, left_x_index, cc, inum, ic)
+    bottom_right = get_2d_pixel(data, layout, boxes, image_height, image_width,
+                                box_idx, c, bottom_y_index, right_x_index, cc, inum, ic)
 
     top = _lerp(top_left, top_right, x_lerp)
     bottom = _lerp(bottom_left, bottom_right, x_lerp)
@@ -394,29 +389,7 @@ def resize_bicubic(indices, data, image_height, image_width,
             dtype = data_dtype
         return value.astype(dtype)
 
-    def _get_indices(indices, layout='NCHW'):
-        if layout == 'NHWC':
-            n, y, x, c = indices
-            cc = None
-        elif layout == 'NCHW':
-            n, c, y, x = indices
-            cc = None
-        else:
-            n, c, y, x, cc = indices
-        return n, c, y, x, cc
-
-    def _get_pixel(data, layout, n, c, y, x, cc):
-        if boxes is None:
-            y = tvm.max(tvm.min(y, image_height - 1), 0)
-            x = tvm.max(tvm.min(x, image_width - 1), 0)
-        if layout == 'NHWC':
-            return data(n, y, x, c).astype('float')
-        if layout == 'NCHW':
-            return data(n, c, y, x).astype('float')
-        # else must be NCHWxc
-        return data(n, c, y, x, cc).astype('float')
-
-    n, c, y, x, cc = _get_indices(indices, layout)
+    n, c, y, x, cc, inum, ic = get_2d_indices(indices, layout)
     box_idx = box_indices(n) if box_indices is not None else n
 
     if boxes is not None:
@@ -455,28 +428,44 @@ def resize_bicubic(indices, data, image_height, image_width,
     yfract = in_y - tvm.floor(in_y)
 
     # 1st row
-    p00 = _get_pixel(data, layout, box_idx, c, yint - 1, xint - 1, cc)
-    p10 = _get_pixel(data, layout, box_idx, c, yint - 1, xint + 0, cc)
-    p20 = _get_pixel(data, layout, box_idx, c, yint - 1, xint + 1, cc)
-    p30 = _get_pixel(data, layout, box_idx, c, yint - 1, xint + 2, cc)
+    p00 = get_2d_pixel(data, layout, boxes, image_height, image_width,
+                       box_idx, c, yint - 1, xint - 1, cc, inum, ic)
+    p10 = get_2d_pixel(data, layout, boxes, image_height, image_width,
+                       box_idx, c, yint - 1, xint + 0, cc, inum, ic)
+    p20 = get_2d_pixel(data, layout, boxes, image_height, image_width,
+                       box_idx, c, yint - 1, xint + 1, cc, inum, ic)
+    p30 = get_2d_pixel(data, layout, boxes, image_height, image_width,
+                       box_idx, c, yint - 1, xint + 2, cc, inum, ic)
 
     # 2nd row
-    p01 = _get_pixel(data, layout, box_idx, c, yint + 0, xint - 1, cc)
-    p11 = _get_pixel(data, layout, box_idx, c, yint + 0, xint + 0, cc)
-    p21 = _get_pixel(data, layout, box_idx, c, yint + 0, xint + 1, cc)
-    p31 = _get_pixel(data, layout, box_idx, c, yint + 0, xint + 2, cc)
+    p01 = get_2d_pixel(data, layout, boxes, image_height, image_width,
+                       box_idx, c, yint + 0, xint - 1, cc, inum, ic)
+    p11 = get_2d_pixel(data, layout, boxes, image_height, image_width,
+                       box_idx, c, yint + 0, xint + 0, cc, inum, ic)
+    p21 = get_2d_pixel(data, layout, boxes, image_height, image_width,
+                       box_idx, c, yint + 0, xint + 1, cc, inum, ic)
+    p31 = get_2d_pixel(data, layout, boxes, image_height, image_width,
+                       box_idx, c, yint + 0, xint + 2, cc, inum, ic)
 
     # 3rd row
-    p02 = _get_pixel(data, layout, box_idx, c, yint + 1, xint - 1, cc)
-    p12 = _get_pixel(data, layout, box_idx, c, yint + 1, xint + 0, cc)
-    p22 = _get_pixel(data, layout, box_idx, c, yint + 1, xint + 1, cc)
-    p32 = _get_pixel(data, layout, box_idx, c, yint + 1, xint + 2, cc)
+    p02 = get_2d_pixel(data, layout, boxes, image_height, image_width,
+                       box_idx, c, yint + 1, xint - 1, cc, inum, ic)
+    p12 = get_2d_pixel(data, layout, boxes, image_height, image_width,
+                       box_idx, c, yint + 1, xint + 0, cc, inum, ic)
+    p22 = get_2d_pixel(data, layout, boxes, image_height, image_width,
+                       box_idx, c, yint + 1, xint + 1, cc, inum, ic)
+    p32 = get_2d_pixel(data, layout, boxes, image_height, image_width,
+                       box_idx, c, yint + 1, xint + 2, cc, inum, ic)
 
     # 4th row
-    p03 = _get_pixel(data, layout, box_idx, c, yint + 2, xint - 1, cc)
-    p13 = _get_pixel(data, layout, box_idx, c, yint + 2, xint + 0, cc)
-    p23 = _get_pixel(data, layout, box_idx, c, yint + 2, xint + 1, cc)
-    p33 = _get_pixel(data, layout, box_idx, c, yint + 2, xint + 2, cc)
+    p03 = get_2d_pixel(data, layout, boxes, image_height, image_width,
+                       box_idx, c, yint + 2, xint - 1, cc, inum, ic)
+    p13 = get_2d_pixel(data, layout, boxes, image_height, image_width,
+                       box_idx, c, yint + 2, xint + 0, cc, inum, ic)
+    p23 = get_2d_pixel(data, layout, boxes, image_height, image_width,
+                       box_idx, c, yint + 2, xint + 1, cc, inum, ic)
+    p33 = get_2d_pixel(data, layout, boxes, image_height, image_width,
+                       box_idx, c, yint + 2, xint + 2, cc, inum, ic)
 
     # Interpolate bicubically
     col0 = _cubic_kernel(p00, p10, p20, p30, xfract)
@@ -536,6 +525,7 @@ def resize(data, size, layout="NCHW", method="bilinear",
         or [batch, in_height*scale, in_width*scale, channel]
         or 5-D with shape [batch, channel-major, in_height*scale, in_width*scale, channel-minor]
     """
+
     method = method.lower()
 
     if layout == 'NHWC':
@@ -544,7 +534,10 @@ def resize(data, size, layout="NCHW", method="bilinear",
     elif layout == 'NCHW':
         in_n, in_c, in_h, in_w = data.shape
         output_shape = [in_n, in_c, size[0], size[1]]
-    elif layout.startswith("NCHW"):# for NCHWxc
+    elif nchw_pack_layout(layout):# for NCHWinic
+        in_n, in_c, in_h, in_w, in_inum, in_ic = data.shape
+        output_shape = [in_n, in_c, size[0], size[1], in_inum, in_ic]
+    elif nchw_xc_layout(layout):# for NCHWxc
         in_n, in_c, in_h, in_w, in_cc = data.shape
         output_shape = [in_n, in_c, size[0], size[1], in_cc]
     else:
index d324e29..4d12d39 100644 (file)
 """Bilinear Scale in python"""
 import math
 import numpy as np
+from topi.util import nchw_pack_layout
 
 def bilinear_resize_python(image, out_size, layout, coordinate_transformation_mode="align_corners"):
     """ Bilinear scaling using python"""
     (new_h, new_w) = out_size
+    (ib, ic) = (1, 1)
 
     if layout == 'NHWC':
         (batch, h, w, channel) = image.shape
         scaled_image = np.ones((batch, new_h, new_w, channel))
+    # NCHWinic
+    elif nchw_pack_layout(layout):
+        (batch, channel, h, w, ib, ic) = image.shape
+        scaled_image = np.ones((batch, channel, new_h, new_w, ib, ic))
     else:
         (batch, channel, h, w) = image.shape
         scaled_image = np.ones((batch, channel, new_h, new_w))
@@ -40,47 +46,59 @@ def bilinear_resize_python(image, out_size, layout, coordinate_transformation_mo
     def _lerp(A, B, t):
         return A * (1.0 - t) + B * t
 
-    for b in range(batch):
-        for i in range(channel):
-            for j in range(new_h):
-                for k in range(new_w):
-                    if coordinate_transformation_mode == "half_pixel":
-                        in_y = (j + 0.5) * height_scale - 0.5
-                    else:
-                        in_y = j * height_scale
-                    y0 = int(math.floor(in_y))
-                    y1 = max(min(y0 + 1, h - 1), 0)
-                    y0 = max(y0, 0)
-                    y_lerp = in_y - math.floor(in_y)
+    def _img_scale(b, m, i, n):
+        for j in range(new_h):
+            for k in range(new_w):
+                if coordinate_transformation_mode == "half_pixel":
+                    in_y = (j + 0.5) * height_scale - 0.5
+                else:
+                    in_y = j * height_scale
+                y0 = int(math.floor(in_y))
+                y1 = max(min(y0 + 1, h - 1), 0)
+                y0 = max(y0, 0)
+                y_lerp = in_y - math.floor(in_y)
+
+                if coordinate_transformation_mode == "half_pixel":
+                    in_x = (k + 0.5) * width_scale - 0.5
+                else:
+                    in_x = k * width_scale
+                x0 = int(math.floor(in_x))
+                x1 = max(min(x0 + 1, w - 1), 0)
+                x0 = max(x0, 0)
+                x_lerp = in_x - math.floor(in_x)
 
-                    if coordinate_transformation_mode == "half_pixel":
-                        in_x = (k + 0.5) * width_scale - 0.5
-                    else:
-                        in_x = k * width_scale
-                    x0 = int(math.floor(in_x))
-                    x1 = max(min(x0 + 1, w - 1), 0)
-                    x0 = max(x0, 0)
-                    x_lerp = in_x - math.floor(in_x)
+                if layout == 'NHWC':
+                    A = image[b][y0][x0][i]
+                    B = image[b][y0][x1][i]
+                    C = image[b][y1][x0][i]
+                    D = image[b][y1][x1][i]
+                elif nchw_pack_layout(layout):
+                    A = image[b][i][y0][x0][m][n]
+                    B = image[b][i][y0][x1][m][n]
+                    C = image[b][i][y1][x0][m][n]
+                    D = image[b][i][y1][x1][m][n]
+                else:
+                    A = image[b][i][y0][x0]
+                    B = image[b][i][y0][x1]
+                    C = image[b][i][y1][x0]
+                    D = image[b][i][y1][x1]
 
-                    if layout == 'NHWC':
-                        A = image[b][y0][x0][i]
-                        B = image[b][y0][x1][i]
-                        C = image[b][y1][x0][i]
-                        D = image[b][y1][x1][i]
-                    else:
-                        A = image[b][i][y0][x0]
-                        B = image[b][i][y0][x1]
-                        C = image[b][i][y1][x0]
-                        D = image[b][i][y1][x1]
+                top = _lerp(A, B, x_lerp)
+                bottom = _lerp(C, D, x_lerp)
 
-                    top = _lerp(A, B, x_lerp)
-                    bottom = _lerp(C, D, x_lerp)
+                pixel = np.float32(_lerp(top, bottom, y_lerp))
 
-                    pixel = np.float32(_lerp(top, bottom, y_lerp))
+                if layout == 'NHWC':
+                    scaled_image[b][j][k][i] = pixel
+                elif nchw_pack_layout(layout):
+                    scaled_image[b][i][j][k][m][n] = pixel
+                else:
+                    scaled_image[b][i][j][k] = pixel
 
-                    if layout == 'NHWC':
-                        scaled_image[b][j][k][i] = pixel
-                    else:
-                        scaled_image[b][i][j][k] = pixel
+    for b in range(batch):
+        for m in range(ib):
+            for i in range(channel):
+                for n in range(ic):
+                    _img_scale(b, m, i, n)
 
     return scaled_image
index a34e541..f2fa80f 100644 (file)
@@ -18,6 +18,8 @@
 """Upsampling in python"""
 import math
 import numpy as np
+from topi.util import nchw_pack_layout
+
 
 def upsample_nearest(arr, scale):
     """ Populate the array by scale factor"""
@@ -44,6 +46,18 @@ def upsampling_python(data, scale, layout='NCHW'):
             for c in range(oshape[1]):
                 output_np[b, c, :, :] = upsample_nearest(data[b, c, :, :], scale)
         return output_np
+    # NCHWinic
+    if nchw_pack_layout(layout):
+        oshape = (ishape[0], ishape[1], int(round(ishape[2]*scale[0])),
+                  int(round(ishape[3]*scale[1])), ishape[4], ishape[5])
+        output_np = np.zeros(oshape, dtype=data.dtype)
+        for b in range(oshape[0]):
+            for ib in range(oshape[4]):
+                for c in range(oshape[1]):
+                    for ic in range(oshape[5]):
+                        output_np[b, c, :, :, ib, ic] = upsample_nearest(data[b, c, :, :, ib, ic], scale)
+        return output_np
+
     if layout == 'NHWC':
         oshape = (ishape[0], int(round(ishape[1]*scale[0])),
                   int(round(ishape[2]*scale[1])), ishape[3])
index 02d082b..4c4aabf 100644 (file)
@@ -27,6 +27,14 @@ class InvalidShapeError(ValueError):
     """Invalid shape for a topi function. i.e. call winograd template for non-3x3 kernel)"""
     pass
 
+def nchw_pack_layout(layout_info):
+    """Check whether the layout type is NCHWinic.
+
+    Returns True when ``layout_info`` starts with 'NCHW' and contains both a
+    lowercase 'n' and a lowercase 'c' (e.g. 'NCHW1n16c'); False otherwise.
+    """
+    if layout_info[:4] != 'NCHW':
+        return False
+    has_channel_block = 'c' in layout_info
+    has_batch_block = 'n' in layout_info
+    return has_channel_block and has_batch_block
+
+def nchw_xc_layout(layout_info):
+    """Check whether the layout type is NCHWxc.
+
+    Returns True when ``layout_info`` starts with 'NCHW', contains a lowercase
+    'c', and everything between that prefix and the final character is
+    numeric (e.g. 'NCHW8c'); False otherwise.
+    """
+    if layout_info[:4] != 'NCHW' or 'c' not in layout_info:
+        return False
+    # Text between the 'NCHW' prefix and the last character is the block size.
+    block_size = layout_info[4:-1]
+    return block_size.isnumeric()
+
 def traverse_inline(s, final_op, callback):
     """Traverse computation graph and do auto inline
 
index 3aa67a5..875b2f7 100644 (file)
@@ -20,16 +20,26 @@ import tvm
 import topi
 import topi.testing
 import math
+from topi.util import nchw_pack_layout
 
 from common import get_all_backend
 
 def verify_upsampling(batch, in_channel, in_height, in_width, scale_h, scale_w,
-                      layout='NCHW', method="nearest_neighbor"):
+                      layout='NCHW', method="nearest_neighbor",
+                      in_batch_block = 0, in_channel_block = 0):
     if layout == 'NCHW':
         A = tvm.placeholder((batch, in_channel, in_height, in_width), name='A')
         dtype = A.dtype
         out_shape = (batch, in_channel, int(round(in_height*scale_h)), int(round(in_width*scale_w)))
         a_np = np.random.uniform(size=(batch, in_channel, in_height, in_width)).astype(dtype)
+    elif nchw_pack_layout(layout):
+        A = tvm.placeholder((batch, in_channel, in_height, in_width, in_batch_block, in_channel_block),
+                             name='A')
+        dtype = A.dtype
+        out_shape = (batch, in_channel, int(round(in_height*scale_h)), int(round(in_width*scale_w)),
+                     in_batch_block, in_channel_block)
+        a_np = np.random.uniform(size=(batch, in_channel, in_height, in_width,
+                                 in_batch_block, in_channel_block)).astype(dtype)
     elif layout == 'NHWC':
         A = tvm.placeholder((batch, in_height, in_width, in_channel), name='A')
         dtype = A.dtype
@@ -81,6 +91,22 @@ def test_upsampling():
     verify_upsampling(2, 2, 32, 32, 3.0, 3.0, method="bilinear")
     verify_upsampling(1, 64, 22, 32, 1.954545497894287, 2.0, method="bilinear")
 
+    # nearest_neighbor - NCHWinic
+    verify_upsampling(2, 2, 32, 32, in_batch_block=4, in_channel_block=8,
+                      scale_h=2.0, scale_w=2.0)
+    verify_upsampling(2, 2, 64, 64, in_batch_block=1, in_channel_block=16,
+                      scale_h=3.0, scale_w=3.0)
+    verify_upsampling(1, 4, 22, 32, in_batch_block=1, in_channel_block=16,
+                      scale_h=1.954545497894287, scale_w=2.0)
+
+    # bilinear - NCHWinic
+    verify_upsampling(2, 2, 32, 32, in_batch_block=1, in_channel_block=1,
+                      scale_h=2.0, scale_w=2.0, method="bilinear")
+    verify_upsampling(2, 2, 32, 32, in_batch_block=1, in_channel_block=1,
+                      scale_h=3.0, scale_w=3.0, method="bilinear")
+    verify_upsampling(2, 4, 22, 32, in_batch_block=1, in_channel_block=16,
+                      scale_h=1.954545497894287, scale_w=2.0, layout="NCHW1n16c", method="bilinear")
+
     # bilinear - NHWC
     verify_upsampling(2, 2, 32, 32, 2.0, 2.0, layout="NHWC", method="bilinear")
     verify_upsampling(2, 2, 32, 32, 3.0, 3.0, layout="NHWC", method="bilinear")