}
}; // struct ReshapeAttrs
+struct ScatterAttrs : public tvm::AttrsNode<ScatterAttrs> {
+ Integer axis;
+
+ TVM_DECLARE_ATTRS(ScatterAttrs, "relay.attrs.ScatterAttrs") {
+ TVM_ATTR_FIELD(axis).set_default(0).describe("The axis over which to select values.");
+ }
+};
+
struct TakeAttrs : public tvm::AttrsNode<TakeAttrs> {
Integer axis;
std::string mode;
return _op.gather_nd(inputs[0], inputs[1])
+class Scatter(OnnxOpConverter):
+ """ Operator converter for Scatter.
+ """
+
+ @classmethod
+ def _impl_v1(cls, inputs, attr, params):
+ axis = attr.get('axis', 0)
+ return _op.scatter(inputs[0], inputs[1], inputs[2], axis)
+
+
class Greater(OnnxOpConverter):
""" Operator logical greater.
"""
'SpaceToDepth': SpaceToDepth.get_converter(opset),
'Gather': Gather.get_converter(opset),
'GatherND': GatherND.get_converter(opset),
+ 'Scatter': Scatter.get_converter(opset),
+ 'ScatterElements': Scatter.get_converter(opset),
'Squeeze': AttrCvt('squeeze', {'axes': 'axis'}),
'Unsqueeze': Unsqueeze.get_converter(opset),
'Pad': Pad.get_converter(opset),
from . import op as _reg
from . import strategy
from .op import OpPattern
+from ._tensor import elemwise_shape_func
_reg.register_broadcast_schedule("broadcast_to")
_reg.register_broadcast_schedule("broadcast_to_like")
_reg.register_schedule("argwhere", strategy.schedule_argwhere)
+# scatter
+@_reg.register_compute("scatter")
+def compute_scatter(attrs, inputs, output_type):
+ """Compute definition of scatter"""
+ return [topi.scatter(inputs[0], inputs[1], inputs[2], attrs.axis)]
+
+_reg.register_schedule("scatter", strategy.schedule_scatter)
+
#####################
# Shape functions #
#####################
return [_argwhere_shape_func_5d(inputs[0])]
return ValueError("Does not support rank higher than 5 in argwhere")
+_reg.register_shape_func("scatter", False, elemwise_shape_func)
+
@script
def _layout_transform_shape_func(data_shape,
out_layout_len,
with target:
return topi.generic.schedule_argwhere(outs)
+# scatter
+@generic_func
+def schedule_scatter(attrs, outs, target):
+ """schedule scatter"""
+ with target:
+ return topi.generic.schedule_scatter(outs)
+
# bitserial_conv2d
def wrap_compute_bitserial_conv2d(topi_compute):
"""wrap bitserial_conv2d topi compute"""
"""
return _make.argwhere(condition)
+def scatter(data, indices, updates, axis):
+ """Update data at positions defined by indices with values in updates
+
+ Parameters
+ ----------
+ data : relay.Expr
+ The input data to the operator.
+
+ indices : relay.Expr
+ The index locations to update.
+
+ updates : relay.Expr
+ The values to update.
+
+ axis : int
+ The axis to scatter on
+
+ Returns
+ -------
+ ret : relay.Expr
+ The computed result.
+ """
+ return _make.scatter(data, indices, updates, axis)
+
def reshape_like(data, shape_like):
"""Reshapes the input array by the size of another array.
For an input array with shape ``(d1, d2, ..., dk)``, `reshape_like` operation reshapes
.set_attr<TOpPattern>("TOpPattern", kOpaque)
.set_support_level(10);
+// Scatter
+TVM_REGISTER_NODE_TYPE(ScatterAttrs);
+
+// Scatter
+bool ScatterRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
+ const TypeReporter& reporter) {
+ CHECK_EQ(num_inputs, 3);
+ CHECK_EQ(types.size(), 4);
+ auto data = types[0].as<TensorTypeNode>();
+ if (data == nullptr) {
+ return false;
+ }
+ auto indices = types[1].as<TensorTypeNode>();
+ if (indices == nullptr) {
+ return false;
+ }
+ auto updates = types[2].as<TensorTypeNode>();
+ if (updates == nullptr) {
+ return false;
+ }
+ CHECK(indices->dtype.is_int()) << "indices of take must be tensor of integer";
+ const auto param = attrs.as<ScatterAttrs>();
+ CHECK(param != nullptr);
+ reporter->Assign(types[3], TensorType(data->shape, data->dtype));
+ return true;
+}
+
+TVM_REGISTER_GLOBAL("relay.op._make.scatter")
+ .set_body_typed([](Expr data, Expr indices, Expr updates, int axis) {
+ auto attrs = make_object<ScatterAttrs>();
+ attrs->axis = std::move(axis);
+ static const Op& op = Op::Get("scatter");
+ return Call(op, {data, indices, updates}, Attrs(attrs), {});
+ });
+
+RELAY_REGISTER_OP("scatter")
+ .describe(
+ R"doc(Update data at positions defined by indices with values in updates)doc" TVM_ADD_FILELINE)
+ .set_num_inputs(3)
+ .add_argument("data", "Tensor", "The input data tensor.")
+ .add_argument("indicies", "Tensor", "The indicies location tensor.")
+ .add_argument("updates", "Tensor", "The values to update the input with.")
+ .add_type_rel("Scatter", ScatterRel)
+ .set_attr<TOpIsStateful>("TOpIsStateful", false)
+ .set_attr<TOpPattern>("TOpPattern", kOpaque)
+ .set_support_level(10);
+
// Take
TVM_REGISTER_NODE_TYPE(TakeAttrs);
verify_gather((4, 3, 5, 6), [[2, 1, 0, 0]], 0, 'float32')
+def verify_scatter(in_shape, indices, axis):
+ x = np.random.uniform(size=in_shape).astype("float32")
+ indices = np.array(indices, dtype="int32")
+ updates = np.random.uniform(size=indices.shape).astype("float32")
+
+ y = helper.make_node("ScatterElements", ['data', 'indices', 'updates'], ['output'], axis=axis)
+
+ graph = helper.make_graph([y],
+ 'scatter_test',
+ inputs=[helper.make_tensor_value_info("data",
+ TensorProto.FLOAT, list(in_shape)),
+ helper.make_tensor_value_info("indices",
+ TensorProto.INT32, list(indices.shape)),
+ helper.make_tensor_value_info("updates",
+ TensorProto.FLOAT, list(indices.shape))],
+ outputs=[helper.make_tensor_value_info("output",
+ TensorProto.FLOAT, list(in_shape))])
+ model = helper.make_model(graph, producer_name='scatter_test')
+ onnx_out = get_onnxruntime_output(model, [x, indices, updates])
+
+ for target, ctx in ctx_list():
+ tvm_out = get_tvm_output(
+ model, [x, indices, updates], target, ctx, onnx_out[0].shape)
+ tvm.testing.assert_allclose(onnx_out[0], tvm_out)
+
+
+def test_scatter():
+ verify_scatter((4,), [1], 0)
+ verify_scatter((1, 4), [[0]], 0)
+ verify_scatter((4,), [2, 3], 0)
+ verify_scatter((2, 2), [[1, 0], [0, 1]], 1)
+ verify_scatter((3, 3, 3), [[[-1, -3]]], -1)
+ verify_scatter((4, 3, 5, 6), [[[[2, 1, 0, 0]]]], 0)
+
+
def _test_slice_iteration_v1(indata, outdata, starts, ends, axes=None):
if axes:
y = helper.make_node(
test_batch_matmul()
test_gather()
test_gather_nd()
+ test_scatter()
test_lrn()
test_instance_norm()
test_upsample()
verify_reverse((2, 3, 4), -1)
+def test_scatter():
+
+ def ref_scatter(data, indices, updates, axis=0):
+ idx = np.indices(indices.shape).reshape(indices.ndim, -1)
+
+ updated_idx = np.copy(idx)
+ indices = indices.reshape(-1)
+ for i in range(len(indices)):
+ updated_idx[axis, i] = indices[i]
+ scattered = np.copy(data)
+ scattered[tuple(updated_idx)] = updates[tuple(idx)]
+ return scattered
+
+ def verify_scatter(dshape, ishape, axis=0):
+ d = relay.var("d", relay.TensorType(dshape, "float32"))
+ i = relay.var("i", relay.TensorType(ishape, "int64"))
+ u = relay.var("u", relay.TensorType(ishape, "float32"))
+ z = relay.op.scatter(d, i, u, axis)
+
+ func = relay.Function([d, i, u], z)
+
+ data_np = np.random.uniform(size=dshape).astype("float32")
+ updates_np = np.random.uniform(size=ishape).astype("float32")
+ indices_np = np.random.randint(-dshape[axis], dshape[axis] - 1, ishape).astype("int64")
+
+ ref_res = ref_scatter(data_np, indices_np, updates_np, axis)
+ # TODO(mbrookhart): expand testing when adding more backend schedules
+ for target, ctx in [("llvm", tvm.cpu())]:
+ for kind in ["graph", "debug"]:
+ intrp = relay.create_executor(kind, ctx=ctx, target=target)
+ op_res = intrp.evaluate(func)(data_np, indices_np, updates_np)
+ tvm.testing.assert_allclose(
+ op_res.asnumpy(), ref_res, rtol=1e-5)
+
+ verify_scatter((10, ), (10, ), 0)
+ verify_scatter((10, 5), (10, 5), -2)
+ verify_scatter((10, 5), (10, 5), -1)
+ verify_scatter((10, 5), (3, 5), 0)
+ verify_scatter((12, 4), (7, 2), 1)
+ verify_scatter((2, 3, 4), (1, 3, 4), 0)
+ verify_scatter((2, 3, 4), (2, 1, 4), 1)
+ verify_scatter((2, 3, 4), (2, 3, 1), 2)
+ verify_scatter((2, 3, 4, 5), (1, 3, 4, 5), 0)
+ verify_scatter((6, 3, 4, 5), (2, 3, 4, 5), 1)
+ verify_scatter((2, 3, 8, 5), (2, 3, 1, 1), 2)
+ verify_scatter((16, 16, 4, 5), (16, 16, 4, 5), 3)
+
+
def test_gather_nd():
def verify_gather_nd(xshape, yshape, y_data):
x = relay.var("x", relay.TensorType(xshape, "float32"))
for kind in ["graph", "debug"]:
intrp = relay.create_executor(kind, ctx=ctx, target=target)
op_res = intrp.evaluate(func)(x_data)
- tvm.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=1e-4)
+ tvm.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=1e-4, atol=1e-6)
for method in ["bilinear", "nearest_neighbor"]:
for layout in ["NHWC", "NCHW"]:
verify_resize((1, 4, 4, 4), 2, method, layout)
from .transform import *
from .broadcast import *
from .sort import *
+from .scatter import *
from .argwhere import *
from . import generic
from . import nn
The computation schedule for the op.
"""
return _default_schedule(outs, False)
+
+
+def schedule_scatter(outs):
+ """Schedule for scatter operator.
+
+ Parameters
+ ----------
+ outs: Array of Tensor
+ The computation graph description of scatter.
+
+ Returns
+ -------
+ s: Schedule
+ The computation schedule for the op.
+ """
+ return _default_schedule(outs, False)
--- /dev/null
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name, too-many-arguments, too-many-nested-blocks
+"""Scatter operator"""
+from tvm.te import hybrid
+
+
+@hybrid.script
+def _scatter_1d(data, indices, updates):
+ out = output_tensor(data.shape, data.dtype)
+ for i in range(data.shape[0]):
+ out[i] = data[i]
+ for i in range(indices.shape[0]):
+ out[indices[i] if indices[i] >= 0 else indices[i] +
+ data.shape[0]] = updates[i]
+ return out
+
+
+@hybrid.script
+def _scatter_2d(data, indices, updates, axis):
+ out = output_tensor(data.shape, data.dtype)
+ for i in const_range(data.shape[0]):
+ for j in const_range(data.shape[1]):
+ out[i, j] = data[i, j]
+ if axis == 0:
+ for i in range(indices.shape[0]):
+ for j in range(indices.shape[1]):
+ out[indices[i, j] if indices[i, j] >=
+ 0 else indices[i, j] + data.shape[axis], j] = updates[i, j]
+ else:
+ for i in range(indices.shape[0]):
+ for j in range(indices.shape[1]):
+ out[i, indices[i, j] if indices[i, j] >=
+ 0 else indices[i, j] + data.shape[axis]] = updates[i, j]
+
+ return out
+
+
+@hybrid.script
+def _scatter_3d(data, indices, updates, axis):
+ out = output_tensor(data.shape, data.dtype)
+ for i in const_range(data.shape[0]):
+ for j in const_range(data.shape[1]):
+ for k in const_range(data.shape[2]):
+ out[i, j, k] = data[i, j, k]
+ if axis == 0:
+ for i in range(indices.shape[0]):
+ for j in range(indices.shape[1]):
+ for k in const_range(indices.shape[2]):
+ out[indices[i, j, k] if indices[i, j, k] >=
+ 0 else indices[i, j, k] + data.shape[axis], j, k] = updates[i, j, k]
+ elif axis == 1:
+ for i in range(indices.shape[0]):
+ for j in range(indices.shape[1]):
+ for k in const_range(indices.shape[2]):
+ out[i, indices[i, j, k] if indices[i, j, k] >=
+ 0 else indices[i, j, k] + data.shape[axis], k] = updates[i, j, k]
+ else:
+ for i in range(indices.shape[0]):
+ for j in range(indices.shape[1]):
+ for k in const_range(indices.shape[2]):
+ out[i, j, indices[i, j, k] if indices[i, j, k] >=
+ 0 else indices[i, j, k] + data.shape[axis]] = updates[i, j, k]
+
+ return out
+
+
+@hybrid.script
+def _scatter_4d(data, indices, updates, axis):
+ out = output_tensor(data.shape, data.dtype)
+ for i in const_range(data.shape[0]):
+ for j in const_range(data.shape[1]):
+ for k in const_range(data.shape[2]):
+ for l in const_range(data.shape[3]):
+ out[i, j, k, l] = data[i, j, k, l]
+
+ if axis == 0:
+ for i in range(indices.shape[0]):
+ for j in range(indices.shape[1]):
+ for k in const_range(indices.shape[2]):
+ for l in const_range(indices.shape[3]):
+ out[indices[i, j, k, l] if indices[i, j, k, l] >=
+ 0 else indices[i, j, k, l] + data.shape[axis],
+ j, k, l] = updates[i, j, k, l]
+ elif axis == 1:
+ for i in range(indices.shape[0]):
+ for j in range(indices.shape[1]):
+ for k in const_range(indices.shape[2]):
+ for l in const_range(indices.shape[3]):
+ out[i,
+ indices[i, j, k, l] if indices[i, j, k, l] >=
+ 0 else indices[i, j, k, l] + data.shape[axis],
+ k, l] = updates[i, j, k, l]
+ elif axis == 2:
+ for i in range(indices.shape[0]):
+ for j in range(indices.shape[1]):
+ for k in const_range(indices.shape[2]):
+ for l in const_range(indices.shape[3]):
+ out[i, j,
+ indices[i, j, k, l] if indices[i, j, k, l] >=
+ 0 else indices[i, j, k, l] + data.shape[axis],
+ l] = updates[i, j, k, l]
+ else:
+ for i in range(indices.shape[0]):
+ for j in range(indices.shape[1]):
+ for k in const_range(indices.shape[2]):
+ for l in const_range(indices.shape[3]):
+ out[i, j, k,
+ indices[i, j, k, l] if indices[i, j, k, l] >=
+ 0 else indices[i, j, k, l] + data.shape[axis]
+ ] = updates[i, j, k, l]
+
+ return out
+
+
+def scatter(data, indices, updates, axis=0):
+ """Update data at positions defined by indices with values in updates
+
+ Parameters
+ ----------
+ data : relay.Expr
+ The input data to the operator.
+
+ indices : relay.Expr
+ The index locations to update.
+
+ updates : relay.Expr
+ The values to update.
+
+ axis : int
+ The axis to scatter on
+
+ Returns
+ -------
+ ret : relay.Expr
+ The computed result.
+ """
+ if axis < 0:
+ axis += len(data.shape)
+ assert axis >= 0
+ assert axis < len(data.shape)
+
+ if len(data.shape) == 1:
+ return _scatter_1d(data, indices, updates)
+ if len(data.shape) == 2:
+ return _scatter_2d(data, indices, updates, axis)
+ if len(data.shape) == 3:
+ return _scatter_3d(data, indices, updates, axis)
+ if len(data.shape) == 4:
+ return _scatter_4d(data, indices, updates, axis)
+ raise ValueError("scatter only support for 1-4 dimensions")