From c3deec194f4df209a1ea379f689f24d3f927493b Mon Sep 17 00:00:00 2001 From: optima2005 <56945758+optima2005@users.noreply.github.com> Date: Fri, 27 Dec 2019 22:25:25 +0800 Subject: [PATCH] [TOPI] add 3D upsampling Op. (#4584) * [TOPI] add 3D upsampling Op. * fix lint issues * change align_corners to coordinate_transformation_mode * fix resize3d half_pixel * make a simple function and clean up trilinear_resize3d_python * fix doc --- docs/langref/relay_op.rst | 2 + include/tvm/relay/attrs/nn.h | 33 ++++ python/tvm/relay/op/nn/_nn.py | 19 +++ python/tvm/relay/op/nn/nn.py | 52 +++++++ python/tvm/relay/op/op_attrs.py | 4 + src/relay/op/nn/upsampling.cc | 90 ++++++++++- tests/python/relay/test_op_level2.py | 62 ++++++++ topi/python/topi/image/resize.py | 169 +++++++++++++++++++++ topi/python/topi/nn/upsampling.py | 55 +++++++ topi/python/topi/testing/__init__.py | 3 +- .../topi/testing/trilinear_resize3d_python.py | 105 +++++++++++++ topi/python/topi/testing/upsampling_python.py | 42 +++++ topi/tests/python/test_topi_resize.py | 63 ++++++++ topi/tests/python/test_topi_upsampling.py | 68 +++++++++ 14 files changed, 763 insertions(+), 4 deletions(-) create mode 100644 topi/python/topi/testing/trilinear_resize3d_python.py diff --git a/docs/langref/relay_op.rst b/docs/langref/relay_op.rst index 54163ac..1fabd70 100644 --- a/docs/langref/relay_op.rst +++ b/docs/langref/relay_op.rst @@ -77,6 +77,7 @@ This level enables typical convnet models. tvm.relay.nn.global_max_pool2d tvm.relay.nn.global_avg_pool2d tvm.relay.nn.upsampling + tvm.relay.nn.upsampling3d tvm.relay.nn.batch_flatten tvm.relay.nn.pad tvm.relay.nn.lrn @@ -254,6 +255,7 @@ Level 2 Definitions .. autofunction:: tvm.relay.nn.global_max_pool2d .. autofunction:: tvm.relay.nn.global_avg_pool2d .. autofunction:: tvm.relay.nn.upsampling +.. autofunction:: tvm.relay.nn.upsampling3d .. autofunction:: tvm.relay.nn.batch_flatten .. autofunction:: tvm.relay.nn.pad .. 
autofunction:: tvm.relay.nn.lrn diff --git a/include/tvm/relay/attrs/nn.h b/include/tvm/relay/attrs/nn.h index 4e061bd..d724f81 100644 --- a/include/tvm/relay/attrs/nn.h +++ b/include/tvm/relay/attrs/nn.h @@ -589,6 +589,39 @@ struct UpSamplingAttrs : public tvm::AttrsNode { } }; +/*! \brief Attributes for upsampling3d operator */ +struct UpSampling3DAttrs : public tvm::AttrsNode { + double scale_d; + double scale_h; + double scale_w; + std::string layout; + std::string method; + std::string coordinate_transformation_mode; + + TVM_DECLARE_ATTRS(UpSampling3DAttrs, "relay.attrs.UpSampling3DAttrs") { + TVM_ATTR_FIELD(scale_d) + .describe("The upsampling factor for depth"); + TVM_ATTR_FIELD(scale_h) + .describe("The upsampling factor for height"); + TVM_ATTR_FIELD(scale_w) + .describe("The upsampling factor for width"); + TVM_ATTR_FIELD(layout).set_default("NCDHW") + .describe("Dimension ordering of input data. Can be 'NCDHW', 'NDHWC', etc." + "'N', 'C', 'D', 'H', 'W' stands for batch, channel, depth, height, and width" + "dimensions respectively. Upsampling is applied on the 'D', 'H' and" + "'W' dimensions."); + TVM_ATTR_FIELD(method).set_default("nearest_neighbor") + .describe("Specify the mode to use for scaling." + "nearest_neighbor - Nearest Neighbor" + "trilinear - Trilinear Interpolation"); + TVM_ATTR_FIELD(coordinate_transformation_mode).set_default("half_pixel") + .describe("Describes how to transform the coordinate in the resized tensor" + "to the coordinate in the original tensor." + "Refer to the ONNX Resize operator specification for details" + "Available options are half_pixel, align_corners and asymmetric"); + } +}; + /*! 
\brief Attributes used for the padding operator */ struct PadAttrs : public tvm::AttrsNode { double pad_value; diff --git a/python/tvm/relay/op/nn/_nn.py b/python/tvm/relay/op/nn/_nn.py index 761abc7..3223258 100644 --- a/python/tvm/relay/op/nn/_nn.py +++ b/python/tvm/relay/op/nn/_nn.py @@ -582,6 +582,25 @@ def compute_upsampling(attrs, inputs, out_dtype, target): align_corners = attrs.align_corners return [topi.nn.upsampling(inputs[0], scale_h, scale_w, layout, method, align_corners)] +# upsampling3d +reg.register_schedule("nn.upsampling3d", reg.schedule_injective) + +def schedule_upsampling3d(_, outs, target): + """Schedule definition of upsampling3d""" + with target: + return topi.generic.schedule_injective(outs) + +@reg.register_compute("nn.upsampling3d") +def compute_upsampling3d(attrs, inputs, out_dtype, target): + scale_d = attrs.scale_d + scale_h = attrs.scale_h + scale_w = attrs.scale_w + layout = attrs.layout + method = attrs.method + coordinate_transformation_mode = attrs.coordinate_transformation_mode + return [topi.nn.upsampling3d(inputs[0], scale_d, scale_h, scale_w, layout, method,\ + coordinate_transformation_mode)] + # pad reg.register_schedule("nn.pad", schedule_broadcast) diff --git a/python/tvm/relay/op/nn/nn.py b/python/tvm/relay/op/nn/nn.py index 5b6e174..ec360af 100644 --- a/python/tvm/relay/op/nn/nn.py +++ b/python/tvm/relay/op/nn/nn.py @@ -771,6 +771,58 @@ def upsampling(data, return _make.upsampling(data, scale_h, scale_w, layout, method, align_corners) +def upsampling3d(data, + scale_d=1, + scale_h=1, + scale_w=1, + layout="NCDHW", + method="nearest_neighbor", + coordinate_transformation_mode="half_pixel"): + """3D Upsampling. + + This operator takes data as input and does 3D scaling to the given scale factor. 
+ In the default case, where the layout is `NCDHW` + with data of shape (n, c, d, h, w) + out will have a shape (n, c, d*scale_d, h*scale_h, w*scale_w) + + method indicates the algorithm to be used while calculating the out value + and method can be one of ("trilinear", "nearest_neighbor") + + Parameters + ---------- + data : tvm.relay.Expr + The input data to the operator. + + scale_d : tvm.relay.Expr + The scale factor for depth upsampling. + + scale_h : tvm.relay.Expr + The scale factor for height upsampling. + + scale_w : tvm.relay.Expr + The scale factor for width upsampling. + + layout : str, optional + Layout of the input. + + method : str, optional + Scale method to use [nearest_neighbor, trilinear]. + + coordinate_transformation_mode: string, optional + Describes how to transform the coordinate in the resized tensor + to the coordinate in the original tensor. + Refer to the ONNX Resize operator specification for details. + Available options are "half_pixel", "align_corners" and "asymmetric". + + Returns + ------- + result : tvm.relay.Expr + The computed result. + """ + return _make.upsampling3d(data, scale_d, scale_h, scale_w, layout, method, + coordinate_transformation_mode) + + def batch_flatten(data): """BatchFlatten. 
diff --git a/python/tvm/relay/op/op_attrs.py b/python/tvm/relay/op/op_attrs.py index a0ce40b..e0887e5 100644 --- a/python/tvm/relay/op/op_attrs.py +++ b/python/tvm/relay/op/op_attrs.py @@ -64,6 +64,10 @@ class UpSamplingAttrs(Attrs): """Attributes for nn.upsampling""" @register_relay_attr_node +class UpSampling3DAttrs(Attrs): + """Attributes for nn.upsampling3d""" + +@register_relay_attr_node class PadAttrs(Attrs): """Attributes for nn.pad""" diff --git a/src/relay/op/nn/upsampling.cc b/src/relay/op/nn/upsampling.cc index 34cd9f9..61b4058 100644 --- a/src/relay/op/nn/upsampling.cc +++ b/src/relay/op/nn/upsampling.cc @@ -33,6 +33,7 @@ namespace tvm { namespace relay { TVM_REGISTER_NODE_TYPE(UpSamplingAttrs); +TVM_REGISTER_NODE_TYPE(UpSampling3DAttrs); template Array > UpsamplingInferCorrectLayout( @@ -50,8 +51,11 @@ Array > UpsamplingInferCorrectLayout( Layout input = new_in_layouts[0]; if (input.IndexOf(LayoutAxis::Get('W')) == raw_layout.IndexOf(LayoutAxis::Get('W')) && input.IndexOf(LayoutAxis::Get('H')) == raw_layout.IndexOf(LayoutAxis::Get('H')) && - !input.Contains(LayoutAxis::Get('w')) && !input.Contains(LayoutAxis::Get('h'))) { - params->layout = input.name(); // modify self to follow the input layout + !input.Contains(LayoutAxis::Get('w')) && !input.Contains(LayoutAxis::Get('h'))&& + (input.IndexOf(LayoutAxis::Get('D')) == -1 || + (input.IndexOf(LayoutAxis::Get('D')) == raw_layout.IndexOf(LayoutAxis::Get('D')) && + !input.Contains(LayoutAxis::Get('d'))))) { + params->layout = input.name(); // modify self to follow the input layout } } @@ -108,7 +112,6 @@ Expr MakeUpSampling(Expr data, return CallNode::make(op, {data}, Attrs(attrs), {}); } - TVM_REGISTER_API("relay.op.nn._make.upsampling") .set_body_typed(MakeUpSampling); @@ -138,5 +141,86 @@ RELAY_REGISTER_OP("nn.upsampling") .set_attr("TOpPattern", kInjective); +// UpSampling3D +bool UpSampling3DRel(const Array& types, + int num_inputs, + const Attrs& attrs, + const TypeReporter& reporter) { + 
CHECK_EQ(types.size(), 2); + const auto* data = types[0].as(); + if (data == nullptr) return false; + + static const Layout kNCDHW("NCDHW"); + + const UpSampling3DAttrs* param = attrs.as(); + CHECK(param != nullptr); + const Layout in_layout(param->layout); + + auto layout_converter = BijectiveLayoutNode::make(in_layout, kNCDHW); + CHECK(layout_converter.defined()) + << "UpSampling3D only support input layouts that are convertible from NCDHW." + << " But got " << in_layout; + + auto oshape = layout_converter.ForwardShape(data->shape); + oshape.Set(2, ir::Cast::make(oshape[2].dtype(), tvm::round(oshape[2] * param->scale_d))); + oshape.Set(3, ir::Cast::make(oshape[3].dtype(), tvm::round(oshape[3] * param->scale_h))); + oshape.Set(4, ir::Cast::make(oshape[4].dtype(), tvm::round(oshape[4] * param->scale_w))); + + // assign output type + reporter->Assign(types[1], + TensorTypeNode::make(layout_converter.BackwardShape(oshape), + data->dtype)); + return true; +} + +// Positional relay function to create upsampling3d operator +// used by frontend FFI. +Expr MakeUpSampling3D(Expr data, + double scale_d, + double scale_h, + double scale_w, + std::string layout, + std::string method, + std::string coordinate_transformation_mode) { + auto attrs = make_node(); + attrs->layout = std::move(layout); + attrs->method = std::move(method); + attrs->scale_d = scale_d; + attrs->scale_h = scale_h; + attrs->scale_w = scale_w; + attrs->coordinate_transformation_mode = coordinate_transformation_mode; + static const Op& op = Op::Get("nn.upsampling3d"); + return CallNode::make(op, {data}, Attrs(attrs), {}); +} + +TVM_REGISTER_API("relay.op.nn._make.upsampling3d") +.set_body_typed(MakeUpSampling3D); + + +RELAY_REGISTER_OP("nn.upsampling3d") +.describe(R"code(Perform upsampling on input array with nearest neighbour or +bilinear interpolation. 
+ +- **data**: data is 5D array of shape + (batch_size, channels, in_depth, in_height, in_width) for NCDHW + (batch_size, in_depth, in_height, in_width, channels) for NDHWC + +- **out**: Output is 5D array of shape + for layout NCDHW + (batch_size, channels, in_depth*scale, in_height*scale, in_width*scale) + + for layout NDHWC + (batch_size, in_depth*scale, in_height*scale, in_width*scale, channels) + +)code" TVM_ADD_FILELINE) +.set_attrs_type() +.set_num_inputs(1) +.add_argument("data", "Tensor", "The input tensor.") +.set_support_level(2) +.add_type_rel("UpSampling3D", UpSampling3DRel) +.set_attr("FInferCorrectLayout", + UpsamplingInferCorrectLayout) +.set_attr("TOpPattern", kInjective); + } // namespace relay } // namespace tvm diff --git a/tests/python/relay/test_op_level2.py b/tests/python/relay/test_op_level2.py index 722c31f..2f19f7a 100644 --- a/tests/python/relay/test_op_level2.py +++ b/tests/python/relay/test_op_level2.py @@ -456,6 +456,22 @@ def test_upsampling_infer_type(): yy = run_infer_type(y) assert yy.checked_type == relay.TensorType((n, c, 200, 400), "float32") +def test_upsampling3d_infer_type(): + n, c, d, h, w = tvm.var("n"), tvm.var("c"), tvm.var("d"), tvm.var("h"), tvm.var("w") + scale = tvm.const(2.0, "float64") + x = relay.var("x", relay.TensorType((n, c, d, h, w), "float32")) + y = relay.nn.upsampling3d(x, scale_d=2, scale_h=2, scale_w=2, layout="NCDHW", method="trilinear") + + yy = run_infer_type(y) + assert yy.checked_type == relay.TensorType((n, c, tvm.expr.Cast("int32", tvm.round(d*scale)), + tvm.expr.Cast("int32", tvm.round(h*scale)), + tvm.expr.Cast("int32", tvm.round(w*scale))), + "float32") + n, c = tvm.var("n"), tvm.var("c") + x = relay.var("x", relay.TensorType((n, c, 100, 100, 200), "float32")) + y = relay.nn.upsampling3d(x, scale_d=2, scale_h=2, scale_w=2, layout="NCDHW", method="trilinear") + yy = run_infer_type(y) + assert yy.checked_type == relay.TensorType((n, c, 200, 200, 400), "float32") def _test_pool2d(opfunc, reffunc): 
n, c, h, w = tvm.var("n"), 10, 224, 224 @@ -782,6 +798,50 @@ def test_upsampling(): _test_upsampling("NHWC", "nearest_neighbor") _test_upsampling("NHWC", "bilinear", True) +def _test_upsampling3d(layout, method, coordinate_transformation_mode="half_pixel"): + n, c, d, h, w = tvm.var("n"), 8, 16, 16, 16 + scale_d = 2.0 + scale_h = 2.0 + scale_w = 2.0 + dtype = "float32" + def get_shape(): + if layout == "NCDHW": + return (c, d, h, w), (c, int(round(d*scale_d)), int(round(h*scale_h)),\ + int(round(w*scale_w))) + else: + return (d, h, w, c), (int(round(d*scale_d)), int(round(h*scale_h)),\ + int(round(w*scale_w)), c) + ishape, oshape = get_shape() + x = relay.var("x", relay.TensorType((n,) + ishape, dtype)) + y = relay.nn.upsampling3d(x, scale_d=scale_d, scale_h=scale_h, scale_w=scale_w,\ + layout=layout, method=method,\ + coordinate_transformation_mode=coordinate_transformation_mode) + + yy = run_infer_type(y) + assert yy.checked_type == relay.TensorType((n,) + oshape, dtype) + dshape = (1,) + ishape + x = relay.var("x", shape=dshape) + y = relay.nn.upsampling3d(x, scale_d=scale_d, scale_h=scale_h, scale_w=scale_w,\ + layout=layout, method=method,\ + coordinate_transformation_mode=coordinate_transformation_mode) + func = relay.Function([x], y) + data = np.random.uniform(size=dshape).astype(dtype) + if method == "nearest_neighbor": + ref = topi.testing.upsampling3d_python(data, (scale_d, scale_h, scale_w), layout) + else: + ref = topi.testing.trilinear_resize3d_python(data, (int(round(d*scale_d)),\ + int(round(h*scale_h)),\ + int(round(w*scale_w))), layout) + for target, ctx in ctx_list(): + executor = relay.create_executor("graph", ctx=ctx, target=target) + out = executor.evaluate(func)(data) + tvm.testing.assert_allclose(out.asnumpy(), ref, rtol=1e-5, atol=1e-5) + +def test_upsampling3d(): + _test_upsampling3d("NCDHW", "nearest_neighbor") + _test_upsampling3d("NCDHW", "trilinear", "align_corners") + _test_upsampling3d("NDHWC", "nearest_neighbor") + 
_test_upsampling3d("NDHWC", "trilinear", "align_corners") def test_conv2d_int8_intrinsics(): def _compile(ic, oc, target, data_layout, kernel_layout, dtypes): @@ -935,6 +995,7 @@ if __name__ == "__main__": test_conv2d_infer_type() test_bitpack_infer_type() test_upsampling_infer_type() + test_upsampling3d_infer_type() test_flatten_infer_type() test_pad_infer_type() test_pad_run() @@ -948,4 +1009,5 @@ if __name__ == "__main__": test_bitserial_conv2d_infer_type() test_batch_flatten() test_upsampling() + test_upsampling3d() test_conv2d_int8_intrinsics() diff --git a/topi/python/topi/image/resize.py b/topi/python/topi/image/resize.py index 7a24990..27bea94 100644 --- a/topi/python/topi/image/resize.py +++ b/topi/python/topi/image/resize.py @@ -210,3 +210,172 @@ def resize(data, size, layout="NCHW", method="bilinear", align_corners=True, out raise ValueError('%s method is not supported.' % method) return tvm.compute(output_shape, compute_func, name='resize', tag=tag.INJECTIVE) + +def resize3d(data, size, layout="NCDHW", method="nearest_neighbor", + coordinate_transformation_mode="align_corners", out_dtype=None): + """Perform resize operation on the data. + + Parameters + ---------- + data: tvm.Tensor + data is a 5-D tensor with shape + [batch, channel, in_depth, in_height, in_width] + or [batch, in_depth, in_height, in_width, channel] + + size: Tuple + Output resolution to scale to, as (out_depth, out_height, out_width) + + layout: string, optional + "NCDHW", "NDHWC", or "NCDHWc". + + coordinate_transformation_mode: string, optional + Describes how to transform the coordinate in the resized tensor + to the coordinate in the original tensor. + Refer to the ONNX Resize operator specification for details. + Available options are "half_pixel", "align_corners" and "asymmetric". + + method: {"trilinear", "nearest_neighbor"} + Method to be used for resizing. + + out_dtype: string, optional + Type to return. If left None will be same as input type. 
+ + Returns + ------- + output : tvm.Tensor + 5-D with shape [batch, channel, in_depth*scale, in_height*scale, in_width*scale] + or [batch, in_depth*scale, in_height*scale, in_width*scale, channel] + or 6-D with shape [batch, channel-major, in_depth*scale, in_height*scale, in_width*scale, + channel-minor] + """ + method = method.lower() + + if layout == 'NDHWC': + in_n, in_d, in_h, in_w, in_c = data.shape + output_shape = [in_n, size[0], size[1], size[2], in_c] + elif layout == 'NCDHW': + in_n, in_c, in_d, in_h, in_w = data.shape + output_shape = [in_n, in_c, size[0], size[1], size[2]] + # Otherwise layout must be NCDHWxc + else: + in_n, in_c, in_d, in_h, in_w, in_cc = data.shape + output_shape = [in_n, in_c, size[0], size[1], size[2], in_cc] + + if coordinate_transformation_mode == "align_corners": + z_ratio = (in_d - 1).astype('float') / (size[0] - 1) + y_ratio = (in_h - 1).astype('float') / (size[1] - 1) + x_ratio = (in_w - 1).astype('float') / (size[2] - 1) + elif coordinate_transformation_mode in ["asymmetric", "half_pixel"]: + z_ratio = (in_d).astype('float') / (size[0]) + y_ratio = (in_h).astype('float') / (size[1]) + x_ratio = (in_w).astype('float') / (size[2]) + else: + raise ValueError("Unsupported coordinate_transformation_mode: {}".format( + coordinate_transformation_mode)) + + def _get_pixel(n, c, z, y, x, cc): + z = tvm.max(tvm.min(z, in_d - 1), 0) + y = tvm.max(tvm.min(y, in_h - 1), 0) + x = tvm.max(tvm.min(x, in_w - 1), 0) + if layout == 'NDHWC': + return data(n, z, y, x, c).astype('float') + if layout == 'NCDHW': + return data(n, c, z, y, x).astype('float') + # else must be NCDHWxc + return data(n, c, z, y, x, cc).astype('float') + + def _get_indices(*indices): + if layout == 'NDHWC': + n, z, y, x, c = indices + cc = None + elif layout == 'NCDHW': + n, c, z, y, x = indices + cc = None + else: + n, c, z, y, x, cc = indices + + return n, c, z, y, x, cc + + def _cast_output(value): + if out_dtype: + dtype = out_dtype + else: + dtype = data.dtype + 
return value.astype(dtype) + + # Nearest neighbor computation + def _nearest_neighbor(*indices): + n, c, z, y, x, cc = _get_indices(*indices) + + in_z = z_ratio * z + in_y = y_ratio * y + in_x = x_ratio * x + + if coordinate_transformation_mode == "align_corners": + zint = tvm.round(in_z).astype('int32') + yint = tvm.round(in_y).astype('int32') + xint = tvm.round(in_x).astype('int32') + elif coordinate_transformation_mode in ["asymmetric", "half_pixel"]: + # Add epsilon to floor to prevent gpu rounding errors. + epsilon = 1e-5 + zint = tvm.floor(in_z + epsilon).astype('int32') + yint = tvm.floor(in_y + epsilon).astype('int32') + xint = tvm.floor(in_x + epsilon).astype('int32') + else: + raise ValueError("Unsupported coordinate_transformation_mode: {}".format( + coordinate_transformation_mode)) + + return _cast_output(_get_pixel(n, c, zint, yint, xint, cc)) + + # Trilinear helper functions and computation. + def _lerp(A, B, t): + return A * (1.0 - t) + B * t + + def _trilinear(*indices): + n, c, z, y, x, cc = _get_indices(*indices) + + if coordinate_transformation_mode == "half_pixel": + in_z = z_ratio * (z + 0.5) - 0.5 + in_y = y_ratio * (y + 0.5) - 0.5 + in_x = x_ratio * (x + 0.5) - 0.5 + else: + in_z = z_ratio * z + in_y = y_ratio * y + in_x = x_ratio * x + + zint = tvm.floor(in_z).astype('int32') + zfract = in_z - tvm.floor(in_z) + + xint = tvm.floor(in_x).astype('int32') + xfract = in_x - tvm.floor(in_x) + + yint = tvm.floor(in_y).astype('int32') + yfract = in_y - tvm.floor(in_y) + + p000 = _get_pixel(n, c, zint, yint, xint, cc) + p001 = _get_pixel(n, c, zint, yint, xint + 1, cc) + p010 = _get_pixel(n, c, zint, yint + 1, xint, cc) + p011 = _get_pixel(n, c, zint, yint + 1, xint + 1, cc) + p100 = _get_pixel(n, c, zint + 1, yint, xint, cc) + p101 = _get_pixel(n, c, zint + 1, yint, xint + 1, cc) + p110 = _get_pixel(n, c, zint + 1, yint + 1, xint, cc) + p111 = _get_pixel(n, c, zint + 1, yint + 1, xint + 1, cc) + + dep00 = _lerp(p000, p100, zfract) + dep01 = 
_lerp(p001, p101, zfract) + dep10 = _lerp(p010, p110, zfract) + dep11 = _lerp(p011, p111, zfract) + col0 = _lerp(dep00, dep01, xfract) + col1 = _lerp(dep10, dep11, xfract) + value = _lerp(col0, col1, yfract) + return _cast_output(value) + + # Determine which interpolation method to use then run it. + if method == "nearest_neighbor": + compute_func = _nearest_neighbor + elif method == "trilinear": + compute_func = _trilinear + else: + raise ValueError('%s method is not supported.' % method) + + return tvm.compute(output_shape, compute_func, name='resize3d', tag=tag.INJECTIVE) diff --git a/topi/python/topi/nn/upsampling.py b/topi/python/topi/nn/upsampling.py index 771c9e2..fe63e47 100644 --- a/topi/python/topi/nn/upsampling.py +++ b/topi/python/topi/nn/upsampling.py @@ -63,3 +63,58 @@ def upsampling(data, scale_h, scale_w, layout="NCHW", method='nearest_neighbor', raise ValueError("not support this layout {} yet".format(layout)) return topi.image.resize(data, out_shape, layout=layout, method=method, align_corners=align_corners) + + +def upsampling3d(data, scale_d, scale_h, scale_w, layout="NCDHW", method='nearest_neighbor', + coordinate_transformation_mode="half_pixel"): + """Perform upsampling on the data. + Nearest neighbor and trilinear upsampling are supported. + + Parameters + ---------- + data : tvm.Tensor + data is a 5-D tensor with shape + [batch, channel, in_depth, in_height, in_width] + or [batch, in_depth, in_height, in_width, channel] + + scale_d : float + Scaling factor for depth + + scale_h : float + Scaling factor for height + + scale_w : float + Scaling factor for width + + layout : string, optional + either "NCDHW" or "NDHWC" + + method : {"trilinear", "nearest_neighbor"} + Method to be used for upsampling. + + coordinate_transformation_mode: string, optional + Describes how to transform the coordinate in the resized tensor + to the coordinate in the original tensor. + Refer to the ONNX Resize operator specification for details. 
+ Available options are "half_pixel", "align_corners" and "asymmetric". + + Returns + ------- + output : tvm.Tensor + 5-D with shape [batch, channel, in_depth*scale, in_height*scale, in_width*scale] + or [batch, in_depth*scale, in_height*scale, in_width*scale, channel] + """ + base_layout = layout[0:5] + if base_layout == "NCDHW": + out_shape = (simplify(topi.cast(tvm.round(data.shape[2] * scale_d), data.shape[2].dtype)), + simplify(topi.cast(tvm.round(data.shape[3] * scale_h), data.shape[3].dtype)), + simplify(topi.cast(tvm.round(data.shape[4] * scale_w), data.shape[4].dtype))) + elif layout == "NDHWC": + out_shape = (simplify(topi.cast(tvm.round(data.shape[1] * scale_d), data.shape[1].dtype)), + simplify(topi.cast(tvm.round(data.shape[2] * scale_h), data.shape[2].dtype)), + simplify(topi.cast(tvm.round(data.shape[3] * scale_w), data.shape[3].dtype))) + + else: + raise ValueError("not support this layout {} yet".format(layout)) + return topi.image.resize3d(data, out_shape, layout=layout, method=method, + coordinate_transformation_mode=coordinate_transformation_mode) diff --git a/topi/python/topi/testing/__init__.py b/topi/python/topi/testing/__init__.py index a3e2241..2826a2b 100644 --- a/topi/python/topi/testing/__init__.py +++ b/topi/python/topi/testing/__init__.py @@ -31,8 +31,9 @@ from .deformable_conv2d_nchw_python import deformable_conv2d_nchw_python from .depthwise_conv2d_python import depthwise_conv2d_python_nchw, depthwise_conv2d_python_nhwc from .dilate_python import dilate_python from .softmax_python import softmax_python, log_softmax_python -from .upsampling_python import upsampling_python +from .upsampling_python import upsampling_python, upsampling3d_python from .bilinear_resize_python import bilinear_resize_python +from .trilinear_resize3d_python import trilinear_resize3d_python from .reorg_python import reorg_python from .roi_align_python import roi_align_nchw_python from .roi_pool_python import roi_pool_nchw_python diff --git 
a/topi/python/topi/testing/trilinear_resize3d_python.py b/topi/python/topi/testing/trilinear_resize3d_python.py new file mode 100644 index 0000000..cc8fdd6 --- /dev/null +++ b/topi/python/topi/testing/trilinear_resize3d_python.py @@ -0,0 +1,105 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+# pylint: disable=invalid-name, line-too-long, unused-variable, too-many-locals, too-many-nested-blocks +"""Trilinear 3D resize in python""" +import math +import numpy as np + +def trilinear_resize3d_python(data_in, out_size, layout, + coordinate_transformation_mode="align_corners"): + """ Trilinear 3d scaling using python""" + (new_d, new_h, new_w) = out_size + + if layout == 'NDHWC': + (batch, d, h, w, channel) = data_in.shape + data_out = np.ones((batch, new_d, new_h, new_w, channel)) + else: + (batch, channel, d, h, w) = data_in.shape + data_out = np.ones((batch, channel, new_d, new_h, new_w)) + + if coordinate_transformation_mode == "align_corners": + depth_scale = np.float32(d-1) / np.float32(out_size[0]-1) + height_scale = np.float32(h-1) / np.float32(out_size[1]-1) + width_scale = np.float32(w-1) / np.float32(out_size[2]-1) + elif coordinate_transformation_mode in ["asymmetric", "half_pixel"]: + depth_scale = np.float32(d) / np.float32(out_size[0]) + height_scale = np.float32(h) / np.float32(out_size[1]) + width_scale = np.float32(w) / np.float32(out_size[2]) + else: + raise ValueError("Unsupported coordinate_transformation_mode: {}".format( + coordinate_transformation_mode)) + + def _lerp(A, B, t): + return A * (1.0 - t) + B * t + + def _in_coord(new_coord, scale, shape, mode): + if mode == "half_pixel": + in_coord = (new_coord + 0.5) * scale - 0.5 + else: + in_coord = new_coord * scale + coord0 = int(math.floor(in_coord)) + coord1 = max(min(coord0 + 1, shape - 1), 0) + coord0 = max(coord0, 0) + coord_lerp = in_coord - math.floor(in_coord) + return coord0, coord1, coord_lerp + + for b in range(batch): + for i in range(channel): + for m in range(new_d): + for j in range(new_h): + for k in range(new_w): + z0, z1, z_lerp = _in_coord(m, depth_scale, d,\ + coordinate_transformation_mode) + y0, y1, y_lerp = _in_coord(j, height_scale, h,\ + coordinate_transformation_mode) + x0, x1, x_lerp = _in_coord(k, width_scale, w,\ + coordinate_transformation_mode) + + if 
layout == 'NDHWC': + A0 = data_in[b][z0][y0][x0][i] + B0 = data_in[b][z0][y0][x1][i] + C0 = data_in[b][z0][y1][x0][i] + D0 = data_in[b][z0][y1][x1][i] + A1 = data_in[b][z1][y0][x0][i] + B1 = data_in[b][z1][y0][x1][i] + C1 = data_in[b][z1][y1][x0][i] + D1 = data_in[b][z1][y1][x1][i] + else: + A0 = data_in[b][i][z0][y0][x0] + B0 = data_in[b][i][z0][y0][x1] + C0 = data_in[b][i][z0][y1][x0] + D0 = data_in[b][i][z0][y1][x1] + A1 = data_in[b][i][z1][y0][x0] + B1 = data_in[b][i][z1][y0][x1] + C1 = data_in[b][i][z1][y1][x0] + D1 = data_in[b][i][z1][y1][x1] + + A = _lerp(A0, A1, z_lerp) + B = _lerp(B0, B1, z_lerp) + C = _lerp(C0, C1, z_lerp) + D = _lerp(D0, D1, z_lerp) + top = _lerp(A, B, x_lerp) + bottom = _lerp(C, D, x_lerp) + + pixel = np.float32(_lerp(top, bottom, y_lerp)) + + if layout == 'NDHWC': + data_out[b][m][j][k][i] = pixel + else: + data_out[b][i][m][j][k] = pixel + + return data_out diff --git a/topi/python/topi/testing/upsampling_python.py b/topi/python/topi/testing/upsampling_python.py index 6ea7d6a..a34e541 100644 --- a/topi/python/topi/testing/upsampling_python.py +++ b/topi/python/topi/testing/upsampling_python.py @@ -53,3 +53,45 @@ def upsampling_python(data, scale, layout='NCHW'): output_np[b, :, :, c] = upsample_nearest(data[b, :, :, c], scale) return output_np raise ValueError("not support this layout {} yet".format(layout)) + +def upsample3d_nearest(arr, scale): + """ Populate the array by scale factor""" + d, h, w = arr.shape + out_d = int(round(d * scale[0])) + out_h = int(round(h * scale[1])) + out_w = int(round(w * scale[2])) + out = np.empty((out_d, out_h, out_w)) + for z in range(out_d): + for y in range(out_h): + for x in range(out_w): + in_z = math.floor(z / scale[0]) + in_y = math.floor(y / scale[1]) + in_x = math.floor(x / scale[2]) + out[z, y, x] = arr[in_z, in_y, in_x] + return out + +def upsampling3d_python(data, scale, layout='NCDHW'): + """ Python version of 3D scaling using nearest neighbour """ + + ishape = data.shape + if layout == 
'NCDHW': + oshape = (ishape[0], ishape[1], + int(round(ishape[2]*scale[0])), + int(round(ishape[3]*scale[1])), + int(round(ishape[4]*scale[2]))) + output_np = np.zeros(oshape, dtype=data.dtype) + for b in range(oshape[0]): + for c in range(oshape[1]): + output_np[b, c, :, :, :] = upsample3d_nearest(data[b, c, :, :, :], scale) + return output_np + if layout == 'NDHWC': + oshape = (ishape[0], + int(round(ishape[1]*scale[0])), + int(round(ishape[2]*scale[1])), + int(round(ishape[3]*scale[2])), ishape[4]) + output_np = np.zeros(oshape, dtype=data.dtype) + for b in range(oshape[0]): + for c in range(oshape[4]): + output_np[b, :, :, :, c] = upsample3d_nearest(data[b, :, :, :, c], scale) + return output_np + raise ValueError("not support this layout {} yet".format(layout)) diff --git a/topi/tests/python/test_topi_resize.py b/topi/tests/python/test_topi_resize.py index 7c33526..10678a0 100644 --- a/topi/tests/python/test_topi_resize.py +++ b/topi/tests/python/test_topi_resize.py @@ -79,5 +79,68 @@ def test_resize(): verify_resize(4, 16, 32, 32, 50, 50, 'NCHW', method="nearest_neighbor", align_corners=False) verify_resize(4, 16, 32, 32, 50, 50, 'NHWC', method="nearest_neighbor", align_corners=False) + +def verify_resize3d(batch, in_channel, in_depth, in_height, in_width, out_depth, out_height, out_width, + layout='NCDHW', coordinate_transformation_mode="half_pixel", method="trilinear"): + if layout == 'NCDHW': + A = tvm.placeholder((batch, in_channel, in_depth, in_height, in_width), name='A', dtype='float32') + dtype = A.dtype + out_shape = (batch, in_channel, out_depth, out_height, out_width) + a_np = np.random.uniform(size=(batch, in_channel, in_depth, in_height, in_width)).astype(dtype) + elif layout == 'NDHWC': + A = tvm.placeholder((batch, in_depth, in_height, in_width, in_channel), name='A', dtype='float32') + dtype = A.dtype + out_shape = (batch, out_depth, out_height, out_width, in_channel) + a_np = np.random.uniform(size=(batch, in_depth, in_height, in_width, 
in_channel)).astype(dtype) + else: + raise NotImplementedError( + 'Layout not supported {} '.format(layout)) + + B = topi.image.resize3d(A, (out_depth, out_height, out_width), layout=layout, + coordinate_transformation_mode=coordinate_transformation_mode, method=method) + + if method == "trilinear": + b_np = topi.testing.trilinear_resize3d_python(a_np, (out_depth, out_height, out_width), layout, + coordinate_transformation_mode) + else: + scale_d = out_depth / in_depth + scale_h = out_height / in_height + scale_w = out_width / in_width + b_np = topi.testing.upsampling3d_python(a_np, (scale_d, scale_h, scale_w), layout) + + def check_device(device): + ctx = tvm.context(device, 0) + if not ctx.exist: + print("Skip because %s is not enabled" % device) + return + print("Running on target: %s" % device) + with tvm.target.create(device): + s = topi.generic.schedule_injective(B) + a = tvm.nd.array(a_np, ctx) + b = tvm.nd.array(np.zeros(out_shape, dtype=dtype), ctx) + f = tvm.build(s, [A, B], device) + f(a, b) + + tvm.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-3, atol=1e-3) + + for device in get_all_backend(): + check_device(device) + + +def test_resize3d(): + # Trilinear + verify_resize3d(4, 8, 16, 16, 16, 25, 25, 25, 'NCDHW') + verify_resize3d(1, 8, 16, 16, 16, 25, 25, 25, "NDHWC") + verify_resize3d(3, 16, 32, 32, 32, 10, 10, 10, 'NCDHW', "align_corners") + verify_resize3d(3, 16, 32, 32, 32, 10, 10, 10, 'NDHWC', "align_corners") + verify_resize3d(3, 16, 32, 32, 32, 10, 10, 10, 'NCDHW', "asymmetric") + verify_resize3d(3, 16, 32, 32, 32, 10, 10, 10, 'NDHWC', "asymmetric") + + # Nearest neighbor + verify_resize3d(4, 8, 16, 16, 16, 25, 25, 25, 'NCDHW', method="nearest_neighbor") + verify_resize3d(4, 8, 16, 16, 16, 25, 25, 25, 'NDHWC', method="nearest_neighbor") + + if __name__ == "__main__": test_resize() + test_resize3d() diff --git a/topi/tests/python/test_topi_upsampling.py b/topi/tests/python/test_topi_upsampling.py index 83909c0..f5b77b1 100644 --- 
a/topi/tests/python/test_topi_upsampling.py +++ b/topi/tests/python/test_topi_upsampling.py @@ -86,5 +86,73 @@ def test_upsampling(): verify_upsampling(2, 2, 32, 32, 3.0, 3.0, layout="NHWC", method="bilinear") verify_upsampling(1, 64, 22, 32, 3.0, 3.0, layout="NHWC", method="bilinear") +def verify_upsampling3d(batch, in_channel, in_depth, in_height, in_width, scale_d, scale_h, scale_w, + layout='NCDHW', method="nearest_neighbor"): + if layout == 'NCDHW': + A = tvm.placeholder((batch, in_channel, in_depth, in_height, in_width), name='A') + dtype = A.dtype + out_shape = (batch, in_channel, int(round(in_depth*scale_d)), int(round(in_height*scale_h)), + int(round(in_width*scale_w))) + a_np = np.random.uniform(size=(batch, in_channel, in_depth, in_height, in_width)).astype(dtype) + elif layout == 'NDHWC': + A = tvm.placeholder((batch, in_depth, in_height, in_width, in_channel), name='A') + dtype = A.dtype + out_shape = (batch, int(round(in_depth*scale_d)), int(round(in_height*scale_h)), + int(round(in_width*scale_w)), in_channel) + a_np = np.random.uniform(size=(batch, in_depth, in_height, in_width, in_channel)).astype(dtype) + else: + raise NotImplementedError( + 'Layout not supported {} '.format(layout)) + + B = topi.nn.upsampling3d(A, scale_d, scale_h, scale_w, layout=layout, method=method, + coordinate_transformation_mode="half_pixel") + + if method == "trilinear": + out_size = (int(round(in_depth*scale_d)), int(round(in_height*scale_h)), int(round(in_width*scale_w))) + b_np = topi.testing.trilinear_resize3d_python(a_np, out_size, layout, + coordinate_transformation_mode="half_pixel") + else: + b_np = topi.testing.upsampling3d_python(a_np, (scale_d, scale_h, scale_w), layout) + + def check_device(device): + ctx = tvm.context(device, 0) + if not ctx.exist: + print("Skip because %s is not enabled" % device) + return + print("Running on target: %s" % device) + with tvm.target.create(device): + s = topi.generic.schedule_injective(B) + a = tvm.nd.array(a_np, ctx) + b = 
tvm.nd.array(np.zeros(out_shape, dtype=dtype), ctx) + f = tvm.build(s, [A, B], device) + f(a, b) + + tvm.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5, atol=1e-5) + + for device in get_all_backend(): + check_device(device) + +def test_upsampling3d(): + # nearest_neighbor - NCDHW + verify_upsampling3d(8, 8, 16, 16, 16, 2.0, 2.0, 2.0) + verify_upsampling3d(2, 16, 32, 32, 32, 3.0, 3.0, 3.0) + verify_upsampling3d(1, 8, 11, 16, 6, 1.954545497894287, 2.0, 1.5) + + ## nearest_neighbor - NDHWC + verify_upsampling3d(8, 8, 16, 16, 16, 2.0, 2.0, 2.0, layout="NDHWC") + verify_upsampling3d(2, 16, 32, 32, 32, 3.0, 3.0, 3.0, layout="NDHWC") + verify_upsampling3d(1, 8, 11, 16, 6, 1.954545497894287, 2.0, 1.5, layout="NDHWC") + + # trilinear - NCDHW + verify_upsampling3d(2, 2, 16, 16, 16, 2.0, 2.0, 2.0, method="trilinear") + verify_upsampling3d(2, 2, 32, 32, 32, 3.0, 3.0, 3.0, method="trilinear") + verify_upsampling3d(1, 2, 11, 16, 6, 1.954545497894287, 2.0, 1.5, method="trilinear") + + # trilinear - NDHWC + verify_upsampling3d(2, 2, 16, 16, 16, 2.0, 2.0, 2.0, layout="NDHWC", method="trilinear") + verify_upsampling3d(2, 2, 32, 32, 32, 3.0, 3.0, 3.0, layout="NDHWC", method="trilinear") + verify_upsampling3d(1, 2, 11, 16, 6, 1.954545497894287, 2.0, 1.5, layout="NDHWC", method="trilinear") + if __name__ == "__main__": test_upsampling() + test_upsampling3d() -- 2.7.4