# specific language governing permissions and limitations
# under the License.
# pylint: disable=no-else-return, invalid-name, unused-argument, too-many-arguments, consider-using-in
-"""Backend compiler related feature registration"""
+"""Backend compiler related feature registration for dynamic relay ops in nn namespace"""
from __future__ import absolute_import
+from tvm import topi
+
+from tvm.runtime import convert
from tvm.te.hybrid import script
-from ...op import register_shape_func
-from ...op import register_broadcast_schedule
+from ...op import register_shape_func, register_compute
+from ...op import register_injective_schedule, register_broadcast_schedule
-# pad
+# upsampling
+@register_compute("dyn.nn.upsampling")
+def compute_upsampling(attrs, inputs, out_dtype):
+ data = inputs[0]
+ scale_h = inputs[1]
+ scale_w = inputs[2]
+ layout = attrs.layout
+ method = attrs.method
+ align_corners = attrs.align_corners
+ return [topi.nn.upsampling(data, scale_h, scale_w, layout,
+ method, align_corners, out_dtype.shape)]
+
+register_injective_schedule("dyn.nn.upsampling")
register_broadcast_schedule("dyn.nn.pad")
#####################
# Shape functions #
#####################
+# upsampling
+@script
+def _upsampling_shape_func(dshape, scale_h, scale_w, height_axis, width_axis, channel_axis):
+ out = output_tensor((4,), "int64")
+ out[0] = int64(dshape[0])
+ out[height_axis] = int64(round(dshape[height_axis] * scale_h[0]))
+ out[width_axis] = int64(round(dshape[width_axis] * scale_w[0]))
+ out[channel_axis] = int64(dshape[channel_axis])
+ return out
+
+@register_shape_func("dyn.nn.upsampling", True)
+def upsampling_shape_func(attrs, inputs, _):
+ """Shape function for upsampling. Supports NCHW and NHWC layouts."""
+ layout = attrs.layout
+ height_axis = width_axis = 1
+ for i, letter in enumerate(layout):
+ if letter == "H":
+ height_axis = i
+ if letter == "W":
+ width_axis = i
+ if letter == "C":
+ channel_axis = i
+ return [_upsampling_shape_func(inputs[0].shape, inputs[1], inputs[2],
+ convert(height_axis), convert(width_axis),
+ convert(channel_axis))]
+# pad
@script
def _dyn_pad_shape_func(data, pad_width):
ndim = len(data.shape)
data : tvm.relay.Expr
The input data to the operator.
- scale_h : tvm.relay.Expr
+ scale_h : tvm.relay.Expr or int or float
The scale factor for height upsampling.
- scale_w : tvm.relay.Expr
+ scale_w : tvm.relay.Expr or int or float
The scale factor for width upsampling.
layout : str, optional
result : tvm.relay.Expr
The computed result.
"""
+ if isinstance(scale_h, Expr) or isinstance(scale_w, Expr):
+ if not isinstance(scale_h, Expr):
+ scale_h = const(scale_h, "float64")
+ if not isinstance(scale_w, Expr):
+ scale_w = const(scale_w, "float64")
+ return _dyn_make.upsampling(data, scale_h, scale_w, layout, method, align_corners)
return _make.upsampling(data, scale_h, scale_w, layout, method, align_corners)
def upsampling(data, scale_h, scale_w, layout="NCHW", method='nearest_neighbor',
- align_corners=False):
+ align_corners=False, output_shape=None):
"""Perform upsampling on the data.
Nearest neighbor and bilinear upsampling are supported.
"""
base_layout = layout[0:4]
if base_layout == "NCHW":
- out_shape = (simplify(topi.cast(te.round(data.shape[2] * scale_h), data.shape[2].dtype)),
- simplify(topi.cast(te.round(data.shape[3] * scale_w), data.shape[3].dtype)))
+ if not output_shape: #static case
+ scaled_h = data.shape[2] * scale_h
+ scaled_w = data.shape[3] * scale_w
+ reshape_size = (simplify(topi.cast(te.round(scaled_h), data.shape[2].dtype)),
+ simplify(topi.cast(te.round(scaled_w), data.shape[3].dtype)))
+ else: #dynamic case -- we don't need to scale; already done in shape func
+ reshape_size = (simplify(topi.cast(te.round(output_shape[2]), output_shape[2].dtype)),
+ simplify(topi.cast(te.round(output_shape[3]), output_shape[3].dtype)))
elif layout == "NHWC":
- out_shape = (simplify(topi.cast(te.round(data.shape[1] * scale_h), data.shape[1].dtype)),
- simplify(topi.cast(te.round(data.shape[2] * scale_w), data.shape[2].dtype)))
+ if not output_shape: #static case
+ scaled_h = data.shape[1] * scale_h
+ scaled_w = data.shape[2] * scale_w
+ reshape_size = (simplify(topi.cast(te.round(scaled_h), data.shape[1].dtype)),
+ simplify(topi.cast(te.round(scaled_w), data.shape[2].dtype)))
+ else: #dynamic case
+ reshape_size = (simplify(topi.cast(te.round(output_shape[1]), output_shape[1].dtype)),
+ simplify(topi.cast(te.round(output_shape[2]), output_shape[2].dtype)))
+
else:
raise ValueError("not support this layout {} yet".format(layout))
coord_trans = "align_corners" if align_corners else "asymmetric"
- return topi.image.resize(data, out_shape, layout=layout,
- method=method, coordinate_transformation_mode=coord_trans)
+ return topi.image.resize(data, reshape_size, layout=layout,
+ method=method, coordinate_transformation_mode=coord_trans,
+ output_shape=output_shape)
def upsampling3d(data, scale_d, scale_h, scale_w, layout="NCDHW", method='nearest_neighbor',
--- /dev/null
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file upsampling.cc
+ * \brief upsampling operator
+ */
+
+#include "../../nn/upsampling.h"
+
+#include <tvm/relay/attrs/nn.h>
+#include <tvm/relay/op.h>
+#include <tvm/relay/op_attr_types.h>
+#include <tvm/tir/data_layout.h>
+
+#include <vector>
+
+#include "../../op_common.h"
+
+namespace tvm {
+namespace relay {
+namespace dyn {
+
+bool UpSamplingRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
+ const TypeReporter& reporter) {
+ // types = [data_type, scale_h_type, scale_w_type, ret_type]
+ CHECK_EQ(types.size(), 4);
+ const auto* data = types[0].as<TensorTypeNode>();
+ const auto* scale_h = types[1].as<TensorTypeNode>();
+ const auto* scale_w = types[2].as<TensorTypeNode>();
+ if (data == nullptr) return false;
+ if (scale_h == nullptr) return false;
+ if (scale_w == nullptr) return false;
+
+ CHECK_EQ(data->shape.size(), 4);
+ CHECK_EQ(scale_h->shape.size(), 0);
+ CHECK_EQ(scale_w->shape.size(), 0);
+ static const Layout kNCHW("NCHW");
+
+ const UpSamplingAttrs* param = attrs.as<UpSamplingAttrs>();
+ CHECK(param);
+ const Layout in_layout(param->layout);
+
+ auto layout_converter = tir::BijectiveLayout(in_layout, kNCHW);
+ CHECK(layout_converter.defined())
+ << "UpSampling only supports input layouts that are convertible from NCHW."
+ << " But got " << in_layout;
+
+ auto nchw_oshape = layout_converter.ForwardShape(data->shape);
+
+ nchw_oshape.Set(2, Any());
+ nchw_oshape.Set(3, Any());
+ auto oshape = layout_converter.BackwardShape(nchw_oshape);
+
+ reporter->Assign(types[3], TensorType(oshape, data->dtype));
+ return true;
+}
+
+// Positional relay function to create upsampling operator
+// used by frontend FFI.
+Expr MakeUpSampling(Expr data, Expr scale_h, Expr scale_w, String layout, String method,
+ bool align_corners) {
+ auto attrs = make_object<UpSamplingAttrs>();
+ attrs->layout = std::move(layout);
+ attrs->method = std::move(method);
+ attrs->align_corners = align_corners;
+
+ static const Op& op = Op::Get("dyn.nn.upsampling");
+ return Call(op, {data, scale_h, scale_w}, Attrs(attrs), {});
+}
+
+TVM_REGISTER_GLOBAL("relay.op.dyn.nn._make.upsampling").set_body_typed(MakeUpSampling);
+
+RELAY_REGISTER_OP("dyn.nn.upsampling")
+ .describe(
+ R"code(Perform upsampling on input array with nearest neighbour or bilinear interpolation.
+
+- **data**: data is 4D array of shape
+ (batch_size, channels, in_height, in_width) for NCHW
+ (batch_size, in_height, in_width, channels) for NHWC
+
+- **scale_h**: scale_h is an integer of the amount to scale height by
+
+- **scale_w**: scale_w is an integer of the amount to scale width by
+
+- **out**: Output is 4D array of shape
+ for layout NCHW
+ (batch_size, channels, in_height*scale, in_width*scale)
+
+ for layout NHWC
+ (batch_size, in_height*scale, in_width*scale, channels)
+
+)code" TVM_ADD_FILELINE)
+ .set_attrs_type<UpSamplingAttrs>()
+ .set_num_inputs(3)
+ .add_argument("data", "Tensor", "The input tensor.")
+ .add_argument("scale_h", "double", "The scale for the height.")
+ .add_argument("scale_w", "double", "The scale for the width.")
+ .set_support_level(2)
+ .add_type_rel("DynamicUpSampling", UpSamplingRel)
+ .set_attr<FInferCorrectLayout>("FInferCorrectLayout",
+ UpsamplingInferCorrectLayout<UpSamplingAttrs>)
+ .set_attr<TOpPattern>("TOpPattern", kInjective);
+
+} // namespace dyn
+} // namespace relay
+} // namespace tvm
Expr MakeTopK(Expr data, int k, int axis, String ret_type, bool is_ascend, DataType dtype);
+Expr MakeUpSampling(Expr data, double scale_h, double scale_w, String layout, String method,
+ bool align_corners);
+
Expr MakeVariance(Expr data, Expr mean, Array<Integer> axis, bool keepdims, bool exclude,
bool unbiased);
* \file upsampling.cc
* \brief upsampling operator
*/
+
+#include "upsampling.h"
+
#include <tvm/relay/attrs/nn.h>
#include <tvm/relay/op.h>
#include <tvm/relay/op_attr_types.h>
#include <tvm/tir/data_layout.h>
+#include <utility>
#include <vector>
#include "../op_common.h"
TVM_REGISTER_NODE_TYPE(UpSamplingAttrs);
TVM_REGISTER_NODE_TYPE(UpSampling3DAttrs);
-template <typename T>
-Array<Array<Layout> > UpsamplingInferCorrectLayout(const Attrs& attrs,
- const Array<Layout>& new_in_layouts,
- const Array<Layout>& old_in_layouts,
- const Array<tvm::relay::Type>& old_in_types) {
- // NOTE: Discard "const" qualifier here.
- T* params = const_cast<T*>(attrs.as<T>());
-
- if (new_in_layouts.defined()) {
- CHECK_EQ(new_in_layouts.size(), 1);
-
- Layout raw_layout(params->layout);
- Layout input = new_in_layouts[0];
- if (input.IndexOf(LayoutAxis::Get('W')) == raw_layout.IndexOf(LayoutAxis::Get('W')) &&
- input.IndexOf(LayoutAxis::Get('H')) == raw_layout.IndexOf(LayoutAxis::Get('H')) &&
- !input.Contains(LayoutAxis::Get('w')) && !input.Contains(LayoutAxis::Get('h')) &&
- (input.IndexOf(LayoutAxis::Get('D')) == -1 ||
- (input.IndexOf(LayoutAxis::Get('D')) == raw_layout.IndexOf(LayoutAxis::Get('D')) &&
- !input.Contains(LayoutAxis::Get('d'))))) {
- params->layout = input.name(); // modify self to follow the input layout
- }
- }
-
- Layout inferred_layout(params->layout);
- return Array<Array<Layout> >{{inferred_layout}, {inferred_layout}};
-}
-
bool UpSamplingRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
const TypeReporter& reporter) {
CHECK_EQ(types.size(), 2);
--- /dev/null
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ *
+ * \file src/relay/op/nn/upsampling.h
+ * \brief implementation of the InferCorrectLayout pass for upsampling
+ */
+
+#ifndef TVM_RELAY_OP_NN_UPSAMPLING_H_
+#define TVM_RELAY_OP_NN_UPSAMPLING_H_
+
+#include <tvm/relay/attrs/nn.h>
+#include <tvm/tir/data_layout.h>
+
+#include "../op_common.h"
+
+namespace tvm {
+namespace relay {
+
+template <typename T>
+Array<Array<Layout> > UpsamplingInferCorrectLayout(const Attrs& attrs,
+ const Array<Layout>& new_in_layouts,
+ const Array<Layout>& old_in_layouts,
+ const Array<tvm::relay::Type>& old_in_types) {
+ // NOTE: Discard "const" qualifier here.
+ T* params = const_cast<T*>(attrs.as<T>());
+
+ if (new_in_layouts.defined()) {
+ CHECK_EQ(new_in_layouts.size(), 1);
+
+ Layout raw_layout(params->layout);
+ Layout input = new_in_layouts[0];
+ if (input.IndexOf(LayoutAxis::Get('W')) == raw_layout.IndexOf(LayoutAxis::Get('W')) &&
+ input.IndexOf(LayoutAxis::Get('H')) == raw_layout.IndexOf(LayoutAxis::Get('H')) &&
+ !input.Contains(LayoutAxis::Get('w')) && !input.Contains(LayoutAxis::Get('h')) &&
+ (input.IndexOf(LayoutAxis::Get('D')) == -1 ||
+ (input.IndexOf(LayoutAxis::Get('D')) == raw_layout.IndexOf(LayoutAxis::Get('D')) &&
+ !input.Contains(LayoutAxis::Get('d'))))) {
+ params->layout = input.name(); // modify self to follow the input layout
+ }
+ }
+
+ Layout inferred_layout(params->layout);
+ return Array<Array<Layout> >{{inferred_layout}, {inferred_layout}};
+}
+
+} // namespace relay
+} // namespace tvm
+
+#endif // TVM_RELAY_OP_NN_UPSAMPLING_H_
}
return Expr(nullptr);
}},
+ {Op::Get("dyn.nn.upsampling"),
+ [](const CallNode* call_node) {
+ const ConstantNode* scale_h = call_node->args[1].as<ConstantNode>();
+ const ConstantNode* scale_w = call_node->args[2].as<ConstantNode>();
+ if (scale_h && scale_w) {
+ CHECK_EQ(scale_h->data->ndim, 0);
+ CHECK_EQ(scale_w->data->ndim, 0);
+ const UpSamplingAttrs* param = call_node->attrs.as<UpSamplingAttrs>();
+ CHECK(param);
+ return MakeUpSampling(call_node->args[0], ToScalar(scale_h->data),
+ ToScalar(scale_w->data), param->layout, param->method,
+ param->align_corners);
+ }
+ return Expr(nullptr);
+ }},
{Op::Get("dyn.nn.pad"),
[](const CallNode* call_node) {
const ConstantNode* pad_width = call_node->args[1].as<ConstantNode>();
import tvm.topi.testing
from tvm.relay.testing import run_infer_type
+def test_dyn_upsampling_run():
+ def verify_upsampling(dshape, scale_h, scale_w, layout, method, align_corners=False):
+
+ if layout == "NCHW":
+ (n, c, h, w) = dshape
+ x_data = np.random.uniform(size=(n, c, h, w)).astype("float32")
+
+ elif layout == "NHWC":
+ (n, h, w, c) = dshape
+ x_data = np.random.uniform(size=(n, h, w, c)).astype("float32")
+
+ if method == "nearest_neighbor":
+ ref_res = tvm.topi.testing.upsampling_python(x_data, (scale_h, scale_w), layout)
+ else:
+ ref_res = tvm.topi.testing.bilinear_resize_python(x_data, (int(round(h*scale_h)),
+ int(round(w*scale_w))), layout)
+ x = relay.Var("x", relay.TensorType(dshape, "float32"))
+ scale_h_var = relay.var("scale_h", relay.TensorType((), "float32"))
+ scale_w_var = relay.var("scale_h", relay.TensorType((), "float32"))
+
+ z = relay.nn.upsampling(x, scale_h_var, scale_w_var, method=method, layout=layout, align_corners=align_corners)
+ zz = run_infer_type(z)
+ func = relay.Function([x, scale_h_var, scale_w_var], z)
+
+ for target, ctx in ctx_list():
+ if "llvm" not in target: continue
+ for kind in ["vm", "debug"]:
+ mod = tvm.ir.IRModule.from_expr(func)
+ intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target)
+ op_res = intrp.evaluate()(x_data, np.array(scale_h).astype("float32"), np.array(scale_w).astype("float32"))
+ tvm.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=1e-4, atol=1e-6)
+
+ verify_upsampling((1, 16, 32, 32), 2.0, 2.0,"NCHW", "nearest_neighbor")
+ verify_upsampling((1, 16, 32, 32), 2.0, 2.0, "NCHW", "bilinear", True)
+ verify_upsampling((1, 16, 32, 32), 2.0, 2.0, "NHWC", "nearest_neighbor")
+ verify_upsampling((1, 16, 32, 32), 2.0, 2.0,"NHWC", "bilinear", True)
+
+#tests upsampling type inference with scale_h passed in as a constant and scale_w as a variable
+def test_dyn_upsampling_infer_type_const():
+ n, c, h, w = te.size_var("n"), te.size_var("c"), te.size_var("h"), te.size_var("w")
+
+ data = relay.var("data", relay.TensorType((n, c, h, w), "int8"))
+ scale_h = relay.Var("scale_h", relay.TensorType((), "float32"))
+ scale_w = relay.Var("scale_w", relay.TensorType((), "float32"))
+
+ z = relay.nn.upsampling(data, 2.0, scale_w)
+ zz = run_infer_type(z)
+ assert zz.checked_type == relay.TensorType((n, c, relay.Any(), relay.Any()), "int8")
+
def test_dyn_pad():
def verify_pad(dshape, pad_width, pad_val, dtype):
x = relay.var("x", relay.TensorType(dshape, dtype))
if __name__ == "__main__":
test_dyn_pad()
+ test_dyn_upsampling_infer_type_const()
+ test_dyn_upsampling_run()
verify_full(4, (1, 2, 3, 4), 'int32')
verify_full(4.0, (1, 2, 8, 10), 'float32')
+def test_dynamic_to_static_upsampling():
+ def verify_upsampling(data_shape, scale_h_val, scale_w_val, dtype):
+ x = relay.var("x", relay.TensorType(data_shape, dtype))
+ scale_h = relay.const(scale_h_val)
+ scale_w = relay.const(scale_w_val)
+ z = relay.nn.upsampling(x, scale_h, scale_w)
+
+ func = run_infer_type(relay.Function([x], z))
+ func2 = run_opt_pass(run_opt_pass(func, transform.DynamicToStatic()), transform.InferType())
+
+ zz = func2.body
+ assert isinstance(zz, relay.Call)
+ assert zz.op == relay.op.get("nn.upsampling")
+
+ x_data = np.random.uniform(size=data_shape).astype(dtype)
+ ref_res = tvm.topi.testing.upsampling_python(x_data, (scale_h_val, scale_w_val), "NCHW")
+ verify_func(func2, [x_data], ref_res)
+
+ verify_upsampling((1, 16, 32, 32), 2, 2, 'int8')
+ verify_upsampling((1, 16, 32, 32), 4, 4, 'int32')
+
def test_dynamic_to_static_pad():
def verify_pad(data_shape, pad_width, pad_val, dtype):
x = relay.var("x", relay.TensorType(data_shape, dtype))
verify_pad((4, 10, 7, 7), ((1, 1), (2, 2), (3, 3), (4, 4)), 2.0, "int32")
verify_pad((2, 7), ((1, 4), (2, 2)), 4.0, "float64")
-
if __name__ == "__main__":
test_dynamic_to_static_reshape()
test_dynamic_to_static_double_reshape()
test_dynamic_to_static_resize()
test_dynamic_to_static_one_hot()
test_dynamic_to_static_full()
+ test_dynamic_to_static_upsampling()
test_dynamic_to_static_pad()