Add Scatter to Topi/Relay/ONNX via hybrid script (#5619)
authorMatthew Brookhart <mbrookhart@octoml.ai>
Tue, 9 Jun 2020 22:33:57 +0000 (15:33 -0700)
committerGitHub <noreply@github.com>
Tue, 9 Jun 2020 22:33:57 +0000 (07:33 +0900)
* I can construct scatter but not embed it in a Relay Graph

* working 1-4 dimesion scatter

* add scatter to ONNX

fix lint

* isolate tests to cpu backend

* Fix i386 test

* fix gpu tolerance

* use elemwise_shape_func for scatter

* fix incorrect rebase

12 files changed:
include/tvm/relay/attrs/transform.h
python/tvm/relay/frontend/onnx.py
python/tvm/relay/op/_transform.py
python/tvm/relay/op/strategy/generic.py
python/tvm/relay/op/transform.py
src/relay/op/tensor/transform.cc
tests/python/frontend/onnx/test_forward.py
tests/python/relay/test_op_level3.py
tests/python/relay/test_op_level5.py
topi/python/topi/__init__.py
topi/python/topi/generic/search.py
topi/python/topi/scatter.py [new file with mode: 0644]

index d709ff2..b0d7de5 100644 (file)
@@ -93,6 +93,14 @@ struct ReshapeAttrs : public tvm::AttrsNode<ReshapeAttrs> {
   }
 };  // struct ReshapeAttrs
 
+struct ScatterAttrs : public tvm::AttrsNode<ScatterAttrs> {
+  Integer axis;
+
+  TVM_DECLARE_ATTRS(ScatterAttrs, "relay.attrs.ScatterAttrs") {
+    TVM_ATTR_FIELD(axis).set_default(0).describe("The axis over which to select values.");
+  }
+};
+
 struct TakeAttrs : public tvm::AttrsNode<TakeAttrs> {
   Integer axis;
   std::string mode;
index 08027a2..42f28d4 100644 (file)
@@ -1058,6 +1058,16 @@ class GatherND(OnnxOpConverter):
         return _op.gather_nd(inputs[0], inputs[1])
 
 
+class Scatter(OnnxOpConverter):
+    """ Operator converter for Scatter.
+    """
+
+    @classmethod
+    def _impl_v1(cls, inputs, attr, params):
+        axis = attr.get('axis', 0)
+        return _op.scatter(inputs[0], inputs[1], inputs[2], axis)
+
+
 class Greater(OnnxOpConverter):
     """ Operator logical greater.
     """
@@ -1863,6 +1873,8 @@ def _get_convert_map(opset):
         'SpaceToDepth': SpaceToDepth.get_converter(opset),
         'Gather': Gather.get_converter(opset),
         'GatherND': GatherND.get_converter(opset),
+        'Scatter': Scatter.get_converter(opset),
+        'ScatterElements': Scatter.get_converter(opset),
         'Squeeze': AttrCvt('squeeze', {'axes': 'axis'}),
         'Unsqueeze': Unsqueeze.get_converter(opset),
         'Pad': Pad.get_converter(opset),
index 1d9253f..b1cfe50 100644 (file)
@@ -26,6 +26,7 @@ from topi.util import get_const_int, get_const_tuple
 from . import op as _reg
 from . import strategy
 from .op import OpPattern
+from ._tensor import elemwise_shape_func
 
 _reg.register_broadcast_schedule("broadcast_to")
 _reg.register_broadcast_schedule("broadcast_to_like")
@@ -88,6 +89,14 @@ def compute_argwhere(attrs, inputs, output_type):
 
 _reg.register_schedule("argwhere", strategy.schedule_argwhere)
 
+# scatter
+@_reg.register_compute("scatter")
+def compute_scatter(attrs, inputs, output_type):
+    """Compute definition of scatter"""
+    return [topi.scatter(inputs[0], inputs[1], inputs[2], attrs.axis)]
+
+_reg.register_schedule("scatter", strategy.schedule_scatter)
+
 #####################
 #  Shape functions  #
 #####################
@@ -453,6 +462,8 @@ def argwhere_shape_func(attrs, inputs, out_ndims):
         return [_argwhere_shape_func_5d(inputs[0])]
     return ValueError("Does not support rank higher than 5 in argwhere")
 
+_reg.register_shape_func("scatter", False, elemwise_shape_func)
+
 @script
 def _layout_transform_shape_func(data_shape,
                                  out_layout_len,
index de808d1..f523f66 100644 (file)
@@ -774,6 +774,13 @@ def schedule_argwhere(attrs, outs, target):
     with target:
         return topi.generic.schedule_argwhere(outs)
 
+# scatter
+@generic_func
+def schedule_scatter(attrs, outs, target):
+    """schedule scatter"""
+    with target:
+        return topi.generic.schedule_scatter(outs)
+
 # bitserial_conv2d
 def wrap_compute_bitserial_conv2d(topi_compute):
     """wrap bitserial_conv2d topi compute"""
index 1ee2bdb..e1b5627 100644 (file)
@@ -238,6 +238,30 @@ def argwhere(condition):
     """
     return _make.argwhere(condition)
 
+def scatter(data, indices, updates, axis):
+    """Update data at positions defined by indices with values in updates
+
+    Parameters
+    ----------
+    data : relay.Expr
+        The input data to the operator.
+
+    indices : relay.Expr
+        The index locations to update.
+
+    updates : relay.Expr
+        The values to update.
+
+    axis : int
+        The axis to scatter on
+
+    Returns
+    -------
+    ret : relay.Expr
+        The computed result.
+    """
+    return _make.scatter(data, indices, updates, axis)
+
 def reshape_like(data, shape_like):
     """Reshapes the input array by the size of another array.
     For an input array with shape ``(d1, d2, ..., dk)``, `reshape_like` operation reshapes
index 136ae00..6544468 100644 (file)
@@ -780,6 +780,53 @@ non-zero)doc" TVM_ADD_FILELINE)
     .set_attr<TOpPattern>("TOpPattern", kOpaque)
     .set_support_level(10);
 
+// Scatter
+TVM_REGISTER_NODE_TYPE(ScatterAttrs);
+
+// Scatter
+bool ScatterRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
+                const TypeReporter& reporter) {
+  CHECK_EQ(num_inputs, 3);
+  CHECK_EQ(types.size(), 4);
+  auto data = types[0].as<TensorTypeNode>();
+  if (data == nullptr) {
+    return false;
+  }
+  auto indices = types[1].as<TensorTypeNode>();
+  if (indices == nullptr) {
+    return false;
+  }
+  auto updates = types[2].as<TensorTypeNode>();
+  if (updates == nullptr) {
+    return false;
+  }
+  CHECK(indices->dtype.is_int()) << "indices of take must be tensor of integer";
+  const auto param = attrs.as<ScatterAttrs>();
+  CHECK(param != nullptr);
+  reporter->Assign(types[3], TensorType(data->shape, data->dtype));
+  return true;
+}
+
+TVM_REGISTER_GLOBAL("relay.op._make.scatter")
+    .set_body_typed([](Expr data, Expr indices, Expr updates, int axis) {
+      auto attrs = make_object<ScatterAttrs>();
+      attrs->axis = std::move(axis);
+      static const Op& op = Op::Get("scatter");
+      return Call(op, {data, indices, updates}, Attrs(attrs), {});
+    });
+
+RELAY_REGISTER_OP("scatter")
+    .describe(
+        R"doc(Update data at positions defined by indices with values in updates)doc" TVM_ADD_FILELINE)
+    .set_num_inputs(3)
+    .add_argument("data", "Tensor", "The input data tensor.")
+    .add_argument("indicies", "Tensor", "The indicies location tensor.")
+    .add_argument("updates", "Tensor", "The values to update the input with.")
+    .add_type_rel("Scatter", ScatterRel)
+    .set_attr<TOpIsStateful>("TOpIsStateful", false)
+    .set_attr<TOpPattern>("TOpPattern", kOpaque)
+    .set_support_level(10);
+
 // Take
 TVM_REGISTER_NODE_TYPE(TakeAttrs);
 
index 80c7253..178f059 100644 (file)
@@ -408,6 +408,41 @@ def test_gather():
     verify_gather((4, 3, 5, 6), [[2, 1, 0, 0]], 0, 'float32')
 
 
+def verify_scatter(in_shape, indices, axis):
+    x = np.random.uniform(size=in_shape).astype("float32")
+    indices = np.array(indices, dtype="int32")
+    updates = np.random.uniform(size=indices.shape).astype("float32")
+
+    y = helper.make_node("ScatterElements", ['data', 'indices', 'updates'], ['output'], axis=axis)
+
+    graph = helper.make_graph([y],
+                              'scatter_test',
+                              inputs=[helper.make_tensor_value_info("data",
+                                                                    TensorProto.FLOAT, list(in_shape)),
+                                      helper.make_tensor_value_info("indices",
+                                                                    TensorProto.INT32, list(indices.shape)),
+                                      helper.make_tensor_value_info("updates",
+                                                                    TensorProto.FLOAT, list(indices.shape))],
+                              outputs=[helper.make_tensor_value_info("output",
+                                                                     TensorProto.FLOAT, list(in_shape))])
+    model = helper.make_model(graph, producer_name='scatter_test')
+    onnx_out = get_onnxruntime_output(model, [x, indices, updates])
+
+    for target, ctx in ctx_list():
+        tvm_out = get_tvm_output(
+            model, [x, indices, updates], target, ctx, onnx_out[0].shape)
+        tvm.testing.assert_allclose(onnx_out[0], tvm_out)
+
+
+def test_scatter():
+    verify_scatter((4,), [1], 0)
+    verify_scatter((1, 4), [[0]], 0)
+    verify_scatter((4,), [2, 3], 0)
+    verify_scatter((2, 2), [[1, 0], [0, 1]], 1)
+    verify_scatter((3, 3, 3), [[[-1, -3]]], -1)
+    verify_scatter((4, 3, 5, 6), [[[[2, 1, 0, 0]]]], 0)
+
+
 def _test_slice_iteration_v1(indata, outdata, starts, ends, axes=None):
     if axes:
         y = helper.make_node(
@@ -2823,6 +2858,7 @@ if __name__ == '__main__':
     test_batch_matmul()
     test_gather()
     test_gather_nd()
+    test_scatter()
     test_lrn()
     test_instance_norm()
     test_upsample()
index 52ff45b..d778312 100644 (file)
@@ -663,6 +663,54 @@ def test_reverse():
     verify_reverse((2, 3, 4), -1)
 
 
+def test_scatter():
+
+    def ref_scatter(data, indices, updates, axis=0):
+        idx = np.indices(indices.shape).reshape(indices.ndim, -1)
+
+        updated_idx = np.copy(idx)
+        indices = indices.reshape(-1)
+        for i in range(len(indices)):
+            updated_idx[axis, i] = indices[i]
+        scattered = np.copy(data)
+        scattered[tuple(updated_idx)] = updates[tuple(idx)]
+        return scattered
+
+    def verify_scatter(dshape, ishape, axis=0):
+        d = relay.var("d", relay.TensorType(dshape, "float32"))
+        i = relay.var("i", relay.TensorType(ishape, "int64"))
+        u = relay.var("u", relay.TensorType(ishape, "float32"))
+        z = relay.op.scatter(d, i, u, axis)
+
+        func = relay.Function([d, i, u], z)
+
+        data_np = np.random.uniform(size=dshape).astype("float32")
+        updates_np = np.random.uniform(size=ishape).astype("float32")
+        indices_np = np.random.randint(-dshape[axis], dshape[axis] - 1, ishape).astype("int64")
+
+        ref_res = ref_scatter(data_np, indices_np, updates_np, axis)
+        # TODO(mbrookhart): expand testing when adding more backend schedules
+        for target, ctx in [("llvm", tvm.cpu())]:
+            for kind in ["graph", "debug"]:
+                intrp = relay.create_executor(kind, ctx=ctx, target=target)
+                op_res = intrp.evaluate(func)(data_np, indices_np, updates_np)
+                tvm.testing.assert_allclose(
+                    op_res.asnumpy(), ref_res, rtol=1e-5)
+
+    verify_scatter((10, ), (10, ), 0)
+    verify_scatter((10, 5), (10, 5), -2)
+    verify_scatter((10, 5), (10, 5), -1)
+    verify_scatter((10, 5), (3, 5), 0)
+    verify_scatter((12, 4), (7, 2), 1)
+    verify_scatter((2, 3, 4), (1, 3, 4), 0)
+    verify_scatter((2, 3, 4), (2, 1, 4), 1)
+    verify_scatter((2, 3, 4), (2, 3, 1), 2)
+    verify_scatter((2, 3, 4, 5), (1, 3, 4, 5), 0)
+    verify_scatter((6, 3, 4, 5), (2, 3, 4, 5), 1)
+    verify_scatter((2, 3, 8, 5), (2, 3, 1, 1), 2)
+    verify_scatter((16, 16, 4, 5), (16, 16, 4, 5), 3)
+
+
 def test_gather_nd():
     def verify_gather_nd(xshape, yshape, y_data):
         x = relay.var("x", relay.TensorType(xshape, "float32"))
index 40842eb..14d43c0 100644 (file)
@@ -63,7 +63,7 @@ def test_resize():
             for kind in ["graph", "debug"]:
                 intrp = relay.create_executor(kind, ctx=ctx, target=target)
                 op_res = intrp.evaluate(func)(x_data)
-                tvm.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=1e-4)
+                tvm.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=1e-4, atol=1e-6)
     for method in ["bilinear", "nearest_neighbor"]:
         for layout in ["NHWC", "NCHW"]:
             verify_resize((1, 4, 4, 4), 2, method, layout)
index 2f06f4e..56c3a74 100644 (file)
@@ -39,6 +39,7 @@ from .reduction import *
 from .transform import *
 from .broadcast import *
 from .sort import *
+from .scatter import *
 from .argwhere import *
 from . import generic
 from . import nn
index 91b7635..895dadb 100644 (file)
@@ -34,3 +34,19 @@ def schedule_argwhere(outs):
       The computation schedule for the op.
     """
     return _default_schedule(outs, False)
+
+
+def schedule_scatter(outs):
+    """Schedule for scatter operator.
+
+    Parameters
+    ----------
+    outs: Array of Tensor
+      The computation graph description of scatter.
+
+    Returns
+    -------
+    s: Schedule
+      The computation schedule for the op.
+    """
+    return _default_schedule(outs, False)
diff --git a/topi/python/topi/scatter.py b/topi/python/topi/scatter.py
new file mode 100644 (file)
index 0000000..e4e9886
--- /dev/null
@@ -0,0 +1,165 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name, too-many-arguments, too-many-nested-blocks
+"""Scatter operator"""
+from tvm.te import hybrid
+
+
+@hybrid.script
+def _scatter_1d(data, indices, updates):
+    out = output_tensor(data.shape, data.dtype)
+    for i in range(data.shape[0]):
+        out[i] = data[i]
+    for i in range(indices.shape[0]):
+        out[indices[i] if indices[i] >= 0 else indices[i] +
+            data.shape[0]] = updates[i]
+    return out
+
+
+@hybrid.script
+def _scatter_2d(data, indices, updates, axis):
+    out = output_tensor(data.shape, data.dtype)
+    for i in const_range(data.shape[0]):
+        for j in const_range(data.shape[1]):
+            out[i, j] = data[i, j]
+    if axis == 0:
+        for i in range(indices.shape[0]):
+            for j in range(indices.shape[1]):
+                out[indices[i, j] if indices[i, j] >=
+                    0 else indices[i, j] + data.shape[axis], j] = updates[i, j]
+    else:
+        for i in range(indices.shape[0]):
+            for j in range(indices.shape[1]):
+                out[i, indices[i, j] if indices[i, j] >=
+                    0 else indices[i, j] + data.shape[axis]] = updates[i, j]
+
+    return out
+
+
+@hybrid.script
+def _scatter_3d(data, indices, updates, axis):
+    out = output_tensor(data.shape, data.dtype)
+    for i in const_range(data.shape[0]):
+        for j in const_range(data.shape[1]):
+            for k in const_range(data.shape[2]):
+                out[i, j, k] = data[i, j, k]
+    if axis == 0:
+        for i in range(indices.shape[0]):
+            for j in range(indices.shape[1]):
+                for k in const_range(indices.shape[2]):
+                    out[indices[i, j, k] if indices[i, j, k] >=
+                        0 else indices[i, j, k] + data.shape[axis], j, k] = updates[i, j, k]
+    elif axis == 1:
+        for i in range(indices.shape[0]):
+            for j in range(indices.shape[1]):
+                for k in const_range(indices.shape[2]):
+                    out[i, indices[i, j, k] if indices[i, j, k] >=
+                        0 else indices[i, j, k] + data.shape[axis], k] = updates[i, j, k]
+    else:
+        for i in range(indices.shape[0]):
+            for j in range(indices.shape[1]):
+                for k in const_range(indices.shape[2]):
+                    out[i, j, indices[i, j, k] if indices[i, j, k] >=
+                        0 else indices[i, j, k] + data.shape[axis]] = updates[i, j, k]
+
+    return out
+
+
+@hybrid.script
+def _scatter_4d(data, indices, updates, axis):
+    out = output_tensor(data.shape, data.dtype)
+    for i in const_range(data.shape[0]):
+        for j in const_range(data.shape[1]):
+            for k in const_range(data.shape[2]):
+                for l in const_range(data.shape[3]):
+                    out[i, j, k, l] = data[i, j, k, l]
+
+    if axis == 0:
+        for i in range(indices.shape[0]):
+            for j in range(indices.shape[1]):
+                for k in const_range(indices.shape[2]):
+                    for l in const_range(indices.shape[3]):
+                        out[indices[i, j, k, l] if indices[i, j, k, l] >=
+                            0 else indices[i, j, k, l] + data.shape[axis],
+                            j, k, l] = updates[i, j, k, l]
+    elif axis == 1:
+        for i in range(indices.shape[0]):
+            for j in range(indices.shape[1]):
+                for k in const_range(indices.shape[2]):
+                    for l in const_range(indices.shape[3]):
+                        out[i,
+                            indices[i, j, k, l] if indices[i, j, k, l] >=
+                            0 else indices[i, j, k, l] + data.shape[axis],
+                            k, l] = updates[i, j, k, l]
+    elif axis == 2:
+        for i in range(indices.shape[0]):
+            for j in range(indices.shape[1]):
+                for k in const_range(indices.shape[2]):
+                    for l in const_range(indices.shape[3]):
+                        out[i, j,
+                            indices[i, j, k, l] if indices[i, j, k, l] >=
+                            0 else indices[i, j, k, l] + data.shape[axis],
+                            l] = updates[i, j, k, l]
+    else:
+        for i in range(indices.shape[0]):
+            for j in range(indices.shape[1]):
+                for k in const_range(indices.shape[2]):
+                    for l in const_range(indices.shape[3]):
+                        out[i, j, k,
+                            indices[i, j, k, l] if indices[i, j, k, l] >=
+                            0 else indices[i, j, k, l] + data.shape[axis]
+                            ] = updates[i, j, k, l]
+
+    return out
+
+
+def scatter(data, indices, updates, axis=0):
+    """Update data at positions defined by indices with values in updates
+
+    Parameters
+    ----------
+    data : relay.Expr
+        The input data to the operator.
+
+    indices : relay.Expr
+        The index locations to update.
+
+    updates : relay.Expr
+        The values to update.
+
+    axis : int
+        The axis to scatter on
+
+    Returns
+    -------
+    ret : relay.Expr
+        The computed result.
+    """
+    if axis < 0:
+        axis += len(data.shape)
+    assert axis >= 0
+    assert axis < len(data.shape)
+
+    if len(data.shape) == 1:
+        return _scatter_1d(data, indices, updates)
+    if len(data.shape) == 2:
+        return _scatter_2d(data, indices, updates, axis)
+    if len(data.shape) == 3:
+        return _scatter_3d(data, indices, updates, axis)
+    if len(data.shape) == 4:
+        return _scatter_4d(data, indices, updates, axis)
+    raise ValueError("scatter only support for 1-4 dimensions")