[RELAY][PYTORCH]Resize3d, Upsample3d op support (#5633)
author Samuel <siju.samuel@huawei.com>
Thu, 21 May 2020 00:35:28 +0000 (06:05 +0530)
committer GitHub <noreply@github.com>
Thu, 21 May 2020 00:35:28 +0000 (09:35 +0900)
include/tvm/relay/attrs/image.h
python/tvm/relay/frontend/pytorch.py
python/tvm/relay/op/image/_image.py
python/tvm/relay/op/image/image.py
src/relay/op/image/resize.cc
tests/python/frontend/pytorch/test_forward.py
tests/python/relay/test_op_level5.py

index b927c98..58fd44b 100644 (file)
@@ -65,6 +65,37 @@ struct ResizeAttrs : public tvm::AttrsNode<ResizeAttrs> {
   }
 };
 
+/*! \brief Attributes used in image resize3d operator */
+struct Resize3dAttrs : public tvm::AttrsNode<Resize3dAttrs> {
+  Array<IndexExpr> size;                  // output extents, in (depth, height, width) order
+  String layout;                          // dimension ordering of the input, e.g. "NCDHW"
+  String method;                          // name of the interpolation method
+  String coordinate_transformation_mode;  // output-to-input coordinate mapping
+  DataType out_dtype;                     // element type of the output
+  // NOTE: adjacent string literals are concatenated verbatim, hence the trailing spaces below.
+  TVM_DECLARE_ATTRS(Resize3dAttrs, "relay.attrs.Resize3dAttrs") {
+    TVM_ATTR_FIELD(size).set_default(NullValue<Array<IndexExpr> >()).describe("Output Size.");
+    TVM_ATTR_FIELD(layout).set_default("NCDHW").describe(
+        "Dimension ordering of input data. Can be 'NCDHW', 'NDHWC', etc. "
+        "'N', 'C', 'D', 'H', 'W' stands for batch, channel, depth, height, and width "
+        "dimensions respectively. Resize3d is applied on the 'D', 'H' and "
+        "'W' dimensions.");
+    TVM_ATTR_FIELD(method)
+        .set_default("trilinear")
+        .describe(
+            "Specify the mode to use for scaling. "
+            "nearest_neighbor - Nearest Neighbor, "
+            "trilinear - Trilinear Interpolation");
+    TVM_ATTR_FIELD(coordinate_transformation_mode)
+        .set_default("half_pixel")
+        .describe(
+            "Describes how to transform the coordinate in the resized tensor "
+            "to the coordinate in the original tensor. "
+            "Available options are half_pixel, align_corners and asymmetric");
+    TVM_ATTR_FIELD(out_dtype).set_default(NullValue<DataType>()).describe("Output data type.");
+  }
+};
+
 /*! \brief Attributes used in image crop_and_resize operator */
 struct CropAndResizeAttrs : public tvm::AttrsNode<CropAndResizeAttrs> {
   Array<IndexExpr> crop_size;
index 46b5cec..28703da 100644 (file)
@@ -1426,6 +1426,32 @@ def _upsample(method):
 
     return _impl
 
+
+def _upsample3d(method):
+    """Build the converter for aten::upsample_*3d ops, lowering to image.resize3d in NCDHW."""
+    def _impl(inputs, input_types):
+        if isinstance(inputs[1], _expr.Var):
+            out_size = _infer_shape(inputs[1])
+        elif isinstance(inputs[1], list):
+            # ndarray.item() stands in for np.asscalar / np.int, both removed from NumPy.
+            out_size = [int(_infer_value(size, {}).asnumpy().item())
+                        for size in inputs[1]]
+        else:
+            # Fail with a clear message instead of an unbound out_size further down.
+            raise TypeError("Unsupported output size for upsample3d: %s" % type(inputs[1]))
+
+        # Nearest mode has no align_corners argument in PyTorch; when inputs[2] is
+        # present for it, that slot is a scale factor and must not be read as a flag.
+        if method == "trilinear" and len(inputs) > 2:
+            align_corners = inputs[2]
+        else:
+            align_corners = False
+
+        coord_trans = "align_corners" if align_corners else "half_pixel"
+        return _op.image.resize3d(inputs[0], out_size, "NCDHW", method, coord_trans)
+    return _impl
+
+
 def _expand_as():
     def _impl(inputs, input_types):
         # TODO: maybe fix this
@@ -1796,6 +1822,8 @@ def _get_convert_map(prelude):
         "aten::detach"                          : _identity(),
         "aten::upsample_bilinear2d"             : _upsample("bilinear"),
         "aten::upsample_nearest2d"              : _upsample("nearest_neighbor"),
+        "aten::upsample_trilinear3d"            : _upsample3d("trilinear"),
+        "aten::upsample_nearest3d"              : _upsample3d("nearest_neighbor"),
         "aten::expand_as"                       : _expand_as(),
         "aten::lt"                              : _elemwise("less"),
         "aten::gt"                              : _elemwise("greater"),
index ba9d62a..290c0a2 100644 (file)
@@ -37,6 +37,18 @@ def compute_resize(attrs, inputs, out_type):
 reg.register_injective_schedule("image.resize")
 
 
+@reg.register_compute("image.resize3d")
+def compute_resize3d(attrs, inputs, out_type):
+    """Compute for image.resize3d: forwards the op attributes to topi.image.resize3d."""
+    # Gathered in the positional order the topi call takes after the input tensor.
+    resize_args = (attrs.size, attrs.layout, attrs.method,
+                   attrs.coordinate_transformation_mode,
+                   attrs.out_dtype)
+    return [topi.image.resize3d(inputs[0], *resize_args)]
+
+reg.register_injective_schedule("image.resize3d")
+
+
 # crop and resize
 @reg.register_compute("image.crop_and_resize")
 def compute_crop_and_resize(attrs, inputs, out_type):
index 097322c..49b35d8 100644 (file)
@@ -64,6 +64,52 @@ def resize(data,
     return _make.resize(data, size, layout, method, coordinate_transformation_mode, out_dtype)
 
 
+def resize3d(data,
+             size,
+             layout="NCDHW",
+             method="trilinear",
+             coordinate_transformation_mode="half_pixel",
+             out_dtype=None):
+    """Resize a 5-D volume along its depth, height and width axes.
+
+    The three spatial axes of `data` are rescaled to the extents given in `size`.
+    With the default `NCDHW` layout, an input of shape (n, c, d, h, w)
+    produces an output of shape (n, c, size[0], size[1], size[2]).
+
+    `method` selects the algorithm used to compute the output values
+    and can be one of ("trilinear", "nearest_neighbor").
+
+    Parameters
+    ----------
+    data : relay.Expr
+        The input data to the operator.
+
+    size : Tuple of Expr
+        Output extents, ordered (depth, height, width).
+
+    layout : str, optional
+        Layout of the input.
+
+    method : str, optional
+        Scaling method to use, one of [nearest_neighbor, trilinear].
+
+    coordinate_transformation_mode : str, optional
+        Mapping from a coordinate in the resized tensor to the original tensor.
+        One of [half_pixel, align_corners, asymmetric].
+
+    out_dtype : str, optional
+        Type to return. If left None, the input's type is used.
+
+    Returns
+    -------
+    result : relay.Expr
+        The resized result.
+    """
+    return _make.resize3d(
+        data, size, layout, method, coordinate_transformation_mode, out_dtype)
+
+
 def crop_and_resize(data,
                     boxes,
                     box_indices,
index 7ad96b4..7bddb29 100644 (file)
@@ -99,6 +99,77 @@ RELAY_REGISTER_OP("image.resize")
     .add_type_rel("Resize", ResizeRel)
     .set_attr<TOpPattern>("TOpPattern", kInjective);
 
+TVM_REGISTER_NODE_TYPE(Resize3dAttrs);
+
+// Type relation of image.resize3d: types[0] is the input, types[1] the output.
+bool Resize3dRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
+                 const TypeReporter& reporter) {
+  CHECK_EQ(types.size(), 2);
+  const auto* data = types[0].as<TensorTypeNode>();
+  if (data == nullptr) return false;
+
+  static const Layout kNCDHW("NCDHW");
+  const Resize3dAttrs* param = attrs.as<Resize3dAttrs>();
+  CHECK(param != nullptr);
+  CHECK(param->size.defined() && param->size.size() == 3)
+      << "Resize3d expects size to hold exactly 3 values (depth, height, width).";
+  const Layout in_layout(param->layout);
+  auto layout_converter = tir::BijectiveLayout(in_layout, kNCDHW);
+  CHECK(layout_converter.defined())
+      << "Resize3d only support input layouts that are convertible from NCDHW."
+      << " But got " << in_layout;
+
+  auto oshape = layout_converter.ForwardShape(data->shape);
+  oshape.Set(2, param->size[0]);
+  oshape.Set(3, param->size[1]);
+  oshape.Set(4, param->size[2]);
+
+  // An unset out_dtype (zero bits) means "same element type as the input".
+  const DataType out_dtype = param->out_dtype.bits() == 0 ? data->dtype : param->out_dtype;
+
+  // assign output type
+  reporter->Assign(types[1], TensorType(layout_converter.BackwardShape(oshape), out_dtype));
+  return true;
+}
+
+// Positional relay function that builds an image.resize3d call;
+// used by frontend FFI.
+Expr MakeResize3d(Expr data, Array<IndexExpr> size, String layout, String method,
+                  String coordinate_transformation_mode, DataType out_dtype) {
+  auto attrs = make_object<Resize3dAttrs>();
+  attrs->size = std::move(size);
+  attrs->layout = std::move(layout);
+  attrs->method = std::move(method);
+  attrs->coordinate_transformation_mode = std::move(coordinate_transformation_mode);
+  attrs->out_dtype = out_dtype;
+  static const Op& op = Op::Get("image.resize3d");
+  return Call(op, {data}, Attrs(attrs), {});
+}
+
+TVM_REGISTER_GLOBAL("relay.op.image._make.resize3d").set_body_typed(MakeResize3d);
+
+RELAY_REGISTER_OP("image.resize3d")
+    .describe(R"code(
+Perform resize3d to input array with nearest neighbour or trilinear interpolation.
+
+- **data**: data is 5D array of shape
+            (batch_size, channels, in_depth, in_height, in_width) for NCDHW
+            (batch_size, in_depth, in_height, in_width, channels) for NDHWC
+
+- **out**: Output is 5D array of shape
+           for layout NCDHW
+           (batch_size, channels, size[0], size[1], size[2])
+
+           for layout NDHWC
+           (batch_size, size[0], size[1], size[2], channels)
+)code" TVM_ADD_FILELINE)
+    .set_attrs_type<Resize3dAttrs>()
+    .set_num_inputs(1)
+    .add_argument("data", "Tensor", "The input tensor.")
+    .set_support_level(5)
+    .add_type_rel("Resize3d", Resize3dRel)
+    .set_attr<TOpPattern>("TOpPattern", kInjective);
+
 TVM_REGISTER_NODE_TYPE(CropAndResizeAttrs);
 
 bool CropAndResizeRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
index 50c3ede..f1543f0 100644 (file)
@@ -1029,6 +1029,13 @@ def test_forward_reflection_pad2d():
     verify_model(torch.nn.ReflectionPad2d((1, 3, 2, 4)).eval(), inp)
 
 
+def test_forward_upsample3d():
+    """Upsample a 1x1x2x2x2 volume by 2 in nearest and trilinear modes."""
+    inp = torch.arange(1, 9, dtype=torch.float32).view(1, 1, 2, 2, 2)
+    for mode, align in (('nearest', None), ('trilinear', None), ('trilinear', True)):
+        verify_model(torch.nn.Upsample(scale_factor=2, mode=mode, align_corners=align).eval(), inp)
+
+
 def test_conv3d():
     for ishape in [(1, 32, 16, 16, 16),
                    (1, 32, 9, 15, 15),
@@ -2191,6 +2198,7 @@ if __name__ == "__main__":
     test_forward_chunk()
     test_forward_split()
     test_upsample()
+    test_forward_upsample3d()
     test_to()
     test_forward_reflection_pad2d()
     test_adaptive_pool3d()
index b29b696..c9d7d42 100644 (file)
@@ -68,6 +68,48 @@ def test_resize():
         for layout in ["NHWC", "NCHW"]:
             verify_resize((1, 4, 4, 4), 2, method, layout)
 
+def test_resize3d_infer_type():
+    """Type inference of resize3d with a symbolic and with a constant target size."""
+    n, c, d, h, w = (te.size_var(name) for name in "ncdhw")
+    in_type = relay.TensorType((n, c, d, h, w), "int8")
+    td, th, tw = (te.var(name) for name in ("td", "th", "tw"))
+    sym_out = run_infer_type(relay.image.resize3d(relay.var("x", in_type), (td, th, tw)))
+    assert sym_out.checked_type == relay.TensorType((n, c, td, th, tw), "int8")
+
+    # constant size, with layout, method and coordinate mode given explicitly
+    x = relay.var("x", in_type)
+    call = relay.image.resize3d(x, (10, 10, 20), "NCDHW", "trilinear", "align_corners")
+    assert "size=" in call.astext()
+    assert run_infer_type(call).checked_type == relay.TensorType((n, c, 10, 10, 20), "int8")
+
+def test_resize3d():
+    """Compare relay.image.resize3d with the topi reference for each method and layout."""
+    def verify_resize(dshape, scale, method, layout):
+        # the three spatial axes start at index 1 for NDHWC and at index 2 otherwise
+        first = 1 if layout == "NDHWC" else 2
+        size = tuple(extent * scale for extent in dshape[first:first + 3])
+
+        x_data = np.random.uniform(size=dshape).astype("float32")
+        if method != "trilinear":
+            ref_res = topi.testing.upsampling3d_python(x_data, (scale,) * 3, layout)
+        else:
+            ref_res = topi.testing.trilinear_resize3d_python(x_data, size, layout)
+
+        x = relay.var("x", relay.TensorType(dshape, "float32"))
+        call = relay.image.resize3d(x, size, layout, method, "align_corners")
+        assert "size=" in call.astext()
+        assert run_infer_type(call).checked_type == relay.TensorType(ref_res.shape, "float32")
+        func = relay.Function([x], call)
+
+        for target, ctx in ctx_list():
+            for kind in ("graph", "debug"):
+                executor = relay.create_executor(kind, ctx=ctx, target=target)
+                result = executor.evaluate(func)(x_data)
+                tvm.testing.assert_allclose(result.asnumpy(), ref_res, rtol=1e-4)
+    for method in ("trilinear", "nearest_neighbor"):
+        for layout in ("NDHWC", "NCDHW"):
+            verify_resize((1, 4, 4, 4, 4), 2, method, layout)
+
 def test_crop_and_resize():
     def verify_crop_and_resize(img_shape, boxes, box_indices, crop_size,
                                layout, method, extrapolation_value=0.0):
@@ -784,6 +826,8 @@ def test_dilation2d_run():
 if __name__ == "__main__":
     test_resize_infer_type()
     test_resize()
+    test_resize3d_infer_type()
+    test_resize3d()
     test_crop_and_resize()
     test_multibox_prior()
     test_multibox_transform_loc()