[TOPI]Add op argwhere (#3994)
authorWei Chen <ipondering.weic@gmail.com>
Tue, 1 Oct 2019 20:09:21 +0000 (13:09 -0700)
committerHaichen Shen <shenhaichen@gmail.com>
Tue, 1 Oct 2019 20:09:21 +0000 (13:09 -0700)
* Add op argwhere

* Move shape func to _algorithm.py

* Add lint rule

* Raise exception if rank is not supportted

* move argwhere to transform

* Add argwhere example

* Fix lint

* Add 1-d support

* cleanup

* Add more dtype support

* CR comment

* Improve error message

* Docs

* raise exception

include/tvm/relay/attrs/transform.h
python/tvm/relay/op/_transform.py
python/tvm/relay/op/transform.py
src/relay/op/tensor/transform.cc
tests/python/relay/test_any.py
topi/python/topi/__init__.py
topi/python/topi/argwhere.py [new file with mode: 0644]
topi/python/topi/generic/__init__.py
topi/python/topi/generic/search.py [new file with mode: 0644]

index 5265687..ccdc871 100644 (file)
@@ -314,6 +314,12 @@ struct OneHotAttrs : public tvm::AttrsNode<OneHotAttrs> {
   }
 };  // struct OneHotAttrs
 
+/*! \brief Attributes for ArgWhere operator */
+struct ArgWhereAttrs : public tvm::AttrsNode<ArgWhereAttrs> {
+  TVM_DECLARE_ATTRS(ArgWhereAttrs, "relay.attrs.ArgWhereAttrs") {
+  }
+};  // struct ArgWhereAttrs
+
 }  // namespace relay
 }  // namespace tvm
 #endif  // TVM_RELAY_ATTRS_TRANSFORM_H_
index d1c0f09..687d5b4 100644 (file)
 # specific language governing permissions and limitations
 # under the License.
 """Backend compiler related feature registration"""
-# pylint: disable=invalid-name,unused-argument, len-as-condition
+# pylint: disable=invalid-name,unused-argument, len-as-condition, too-many-nested-blocks
 from __future__ import absolute_import
+import tvm
+import topi
 from topi.util import get_const_int, get_const_tuple
 from . import op as _reg
 from ._reduce import _schedule_reduce
@@ -204,3 +206,100 @@ def take_shape_func(attrs, inputs, out_ndims):
             axis += data_ndim
         assert 0 <= axis < data_ndim
         return [_take_with_axis_shape_func(*inputs, convert(axis), out_ndims[0])]
+
+@script
+def _argwhere_shape_func_1d(condition):
+    out = output_tensor((2, ), "int64")
+    out[0] = int64(0)
+    out[1] = int64(1)
+    for i1 in range(condition.shape[0]):
+        if condition[i1] != 0:
+            out[0] += int64(1)
+    return out
+
+@script
+def _argwhere_shape_func_2d(condition):
+    out = output_tensor((2, ), "int64")
+    out[0] = int64(0)
+    out[1] = int64(2)
+    for i1 in range(condition.shape[0]):
+        for i2 in range(condition.shape[1]):
+            if condition[i1, i2] != 0:
+                out[0] += int64(1)
+    return out
+
+@script
+def _argwhere_shape_func_3d(condition):
+    out = output_tensor((2, ), "int64")
+    out[0] = int64(0)
+    out[1] = int64(3)
+    for i1 in range(condition.shape[0]):
+        for i2 in range(condition.shape[1]):
+            for i3 in range(condition.shape[2]):
+                if condition[i1, i2, i3] != 0:
+                    out[0] += int64(1)
+    return out
+
+@script
+def _argwhere_shape_func_4d(condition):
+    out = output_tensor((2, ), "int64")
+    out[0] = int64(0)
+    out[1] = int64(4)
+    for i1 in range(condition.shape[0]):
+        for i2 in range(condition.shape[1]):
+            for i3 in range(condition.shape[2]):
+                for i4 in range(condition.shape[3]):
+                    if condition[i1, i2, i3, i4] != 0:
+                        out[0] += int64(1)
+    return out
+
+@script
+def _argwhere_shape_func_5d(condition):
+    out = output_tensor((2, ), "int64")
+    out[0] = int64(0)
+    out[1] = int64(5)
+    for i1 in range(condition.shape[0]):
+        for i2 in range(condition.shape[1]):
+            for i3 in range(condition.shape[2]):
+                for i4 in range(condition.shape[3]):
+                    for i5 in range(condition.shape[4]):
+                        if condition[i1, i2, i3, i4, i5] != 0:
+                            out[0] += int64(1)
+    return out
+
+@_reg.register_shape_func("argwhere", True)
+def argwhere_shape_func(attrs, inputs, out_ndims):
+    """
+    Shape function for argwhere.
+    """
+    if len(inputs[0].shape) == 1:
+        return [_argwhere_shape_func_1d(inputs[0])]
+    elif len(inputs[0].shape) == 2:
+        return [_argwhere_shape_func_2d(inputs[0])]
+    elif len(inputs[0].shape) == 3:
+        return [_argwhere_shape_func_3d(inputs[0])]
+    elif len(inputs[0].shape) == 4:
+        return [_argwhere_shape_func_4d(inputs[0])]
+    elif len(inputs[0].shape) == 5:
+        return [_argwhere_shape_func_5d(inputs[0])]
+    return ValueError("Does not support rank higher than 5 in argwhere")
+
+@_reg.register_schedule("argwhere")
+def schedule_argwhere(_, outs, target):
+    """Schedule definition of argwhere"""
+    with target:
+        return topi.generic.schedule_argwhere(outs)
+
+
+@_reg.register_compute("argwhere")
+def compute_argwhere(attrs, inputs, output_type, _):
+    """Compute definition of argwhere"""
+    output_shape = []
+    for s in output_type.shape:
+        if hasattr(s, "value"):
+            output_shape.append(s)
+        else:
+            # see Any, replace it with a var
+            output_shape.append(tvm.var("any_dim", "int32"))
+    new_output_type = tvm.relay.ty.TensorType(output_shape, "int32")
+    return [topi.argwhere(new_output_type, inputs[0])]
index 7f921d0..88d7a44 100644 (file)
@@ -144,7 +144,6 @@ def squeeze(data, axis=None):
     """
     return _make.squeeze(data, axis)
 
-
 def reshape(data, newshape):
     """Reshapes the input array.
 
@@ -214,6 +213,28 @@ def reshape(data, newshape):
         newshape = [newshape]
     return _make.reshape(data, list(newshape))
 
+def argwhere(condition):
+    """Find the indices of elements of a tensor that are
+    non-zero.
+
+    Parameters
+    ----------
+    condition : relay.Expr
+        The input condition tensor.
+
+    Returns
+    -------
+    out : relay.Expr
+        Tensor with the indices of elements that are non-zero.
+
+    Examples
+    --------
+    .. code-block:: python
+
+        condition = [[True, False], [False, True]]
+        relay.argwhere(condition) = [[0, 0], [1, 1]]
+    """
+    return _make.argwhere(condition)
 
 def reshape_like(data, shape_like):
     """Reshapes the input array by the size of another array.
index b920057..0002390 100644 (file)
@@ -817,6 +817,40 @@ the input array into an output array with the same shape as the second input arr
 .set_attr<FTVMCompute>("FTVMCompute", ReshapeCompute)
 .set_attr<TOpPattern>("TOpPattern", kInjective);
 
+// ArgWhere
+bool ArgWhereRel(const Array<Type>& types,
+                 int num_inputs,
+                 const Attrs& attrs,
+                 const TypeReporter& reporter) {
+  CHECK_EQ(num_inputs, 1);
+  auto tt = types[0].as<TensorTypeNode>();
+  CHECK(tt != nullptr);
+  const auto& input_shape = tt->shape;
+  const auto& input_rank = input_shape.size();
+  std::vector<IndexExpr> result_shape;
+  result_shape.push_back(Any::make());
+  result_shape.push_back(IntImm::make(Int(32), input_rank));
+  reporter->Assign(types[1], TensorTypeNode::make(result_shape, Int(32)));
+  return true;
+}
+
+TVM_REGISTER_API("relay.op._make.argwhere")
+.set_body_typed<Expr(Expr)>([](Expr data) {
+  static const Op& op = Op::Get("argwhere");
+  auto attrs = make_node<ArgWhereAttrs>();
+  return CallNode::make(op, {data}, Attrs(attrs), {});
+});
+
+RELAY_REGISTER_OP("argwhere")
+.describe(R"doc(Find the indices of elements of a tensor that are
+non-zero)doc" TVM_ADD_FILELINE)
+.set_num_inputs(1)
+.set_attrs_type_key("relay.attrs.ArgWhereAttrs")
+.add_argument("condition", "Tensor", "The input condition tensor.")
+.add_type_rel("ArgWhere", ArgWhereRel)
+.set_attr<TOpIsStateful>("TOpIsStateful", false)
+.set_attr<TOpPattern>("TOpPattern", kOpaque)
+.set_support_level(10);
 
 // Take
 TVM_REGISTER_NODE_TYPE(TakeAttrs);
index 214b88f..d02dcd0 100644 (file)
@@ -92,6 +92,36 @@ def test_any_reshape():
     verify_any_reshape(any_dims(3), (-4, 2, -1, -2), (6, 3, 4), (2, 3, 3, 4))
     verify_any_reshape(any_dims(3), (-4, -1, 2, -3), (6, 3, 4), (3, 2, 12))
 
+def verify_any_argwhere(x_shape, x_np_shape, dtype="bool"):
+    x = relay.var('x', shape=x_shape, dtype=dtype)
+    y = relay.argwhere(x)
+    mod = relay.module.Module()
+    mod["main"] = relay.Function([x], y)
+    data = np.random.choice([0, 1, 2, 3], size=x_np_shape).astype(dtype)
+    for kind in ["debug", "vm"]:
+        ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target="llvm")
+        result = ex.evaluate()(data).asnumpy()
+        expected = np.argwhere(data)
+        assert result.shape == expected.shape
+        tvm.testing.assert_allclose(result.flatten(), expected.flatten())
+
+def test_any_argwhere():
+    verify_any_argwhere(any_dims(1), (5,))
+    verify_any_argwhere(any_dims(2), (5, 5))
+    verify_any_argwhere(any_dims(3), (5, 5, 5))
+    verify_any_argwhere(any_dims(4), (5, 5, 5, 5))
+    verify_any_argwhere(any_dims(5), (5, 5, 5, 5, 5))
+    verify_any_argwhere(any_dims(1), (5,), "int32")
+    verify_any_argwhere(any_dims(2), (5, 5), "int32")
+    verify_any_argwhere(any_dims(3), (5, 5, 5), "int32")
+    verify_any_argwhere(any_dims(4), (5, 5, 5, 5), "int32")
+    verify_any_argwhere(any_dims(5), (5, 5, 5, 5, 5), "int32")
+    verify_any_argwhere(any_dims(1), (5,), "int8")
+    verify_any_argwhere(any_dims(2), (5, 5), "int8")
+    verify_any_argwhere(any_dims(3), (5, 5, 5), "int8")
+    verify_any_argwhere(any_dims(4), (5, 5, 5, 5), "int8")
+    verify_any_argwhere(any_dims(5), (5, 5, 5, 5, 5), "int8")
+
 def verify_any_take(data_shape, indices_shape, axis, data_np_shape, indices_np_shape):
     mod = relay.Module()
     data = relay.var('data', shape=data_shape, dtype='float32')
index ac855d1..fd293a0 100644 (file)
@@ -22,6 +22,7 @@ from .reduction import *
 from .transform import *
 from .broadcast import *
 from .sort import *
+from .argwhere import *
 from . import nn
 from . import x86
 from . import cuda
diff --git a/topi/python/topi/argwhere.py b/topi/python/topi/argwhere.py
new file mode 100644 (file)
index 0000000..32f4e87
--- /dev/null
@@ -0,0 +1,191 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name, too-many-arguments, too-many-nested-blocks
+"""Argwhere operator"""
+import tvm
+from tvm import hybrid
+
+@hybrid.script
+def hybrid_argwhere_1d(output_shape, condition):
+    """Find the indices of elements of a 1-D tensor that are non-zero.
+
+    Parameters
+    ----------
+    condition : tvm.Tensor
+        1-D tensor with boolean values.
+
+    Returns
+    -------
+    out : tvm.Tensor
+        Indices of non-zero elements.
+    """
+    a = output_tensor(output_shape, "int32")
+    a1 = condition.shape[0]
+    valid_index = 0
+    for i1 in range(a1):
+        if condition[i1] != 0:
+            a[valid_index, 0] = i1
+            valid_index += 1
+    return a
+
+@hybrid.script
+def hybrid_argwhere_2d(output_shape, condition):
+    """Find the indices of elements of a 2-D tensor that are non-zero.
+
+    Parameters
+    ----------
+    condition : tvm.Tensor
+        2-D tensor with boolean values.
+
+    Returns
+    -------
+    out : tvm.Tensor
+        Indices of non-zero elements.
+    """
+    a = output_tensor(output_shape, "int32")
+    a1 = condition.shape[0]
+    a2 = condition.shape[1]
+    valid_index = 0
+    for i1 in range(a1):
+        for i2 in range(a2):
+            if condition[i1, i2] != 0:
+                a[valid_index, 0] = i1
+                a[valid_index, 1] = i2
+                valid_index += 1
+    return a
+
+@hybrid.script
+def hybrid_argwhere_3d(output_shape, condition):
+    """Find the indices of elements of a 3-D tensor that are non-zero.
+
+    Parameters
+    ----------
+    condition : tvm.Tensor
+        3-D tensor with boolean values.
+
+    Returns
+    -------
+    out : tvm.Tensor
+        Indices of non-zero elements.
+    """
+    a = output_tensor(output_shape, "int32")
+    a1 = condition.shape[0]
+    a2 = condition.shape[1]
+    a3 = condition.shape[2]
+    valid_index = 0
+    for i1 in range(a1):
+        for i2 in range(a2):
+            for i3 in range(a3):
+                if condition[i1, i2, i3] != 0:
+                    a[valid_index, 0] = i1
+                    a[valid_index, 1] = i2
+                    a[valid_index, 2] = i3
+                    valid_index += 1
+    return a
+
+@hybrid.script
+def hybrid_argwhere_4d(output_shape, condition):
+    """Find the indices of elements of a 4-D tensor that are non-zero.
+
+    Parameters
+    ----------
+    condition : tvm.Tensor
+        4-D tensor with boolean values.
+
+    Returns
+    -------
+    out : tvm.Tensor
+        Indices of non-zero elements.
+    """
+    a = output_tensor(output_shape, "int32")
+    a1 = condition.shape[0]
+    a2 = condition.shape[1]
+    a3 = condition.shape[2]
+    a4 = condition.shape[3]
+    valid_index = 0
+    for i1 in range(a1):
+        for i2 in range(a2):
+            for i3 in range(a3):
+                for i4 in range(a4):
+                    if condition[i1, i2, i3, i4] != 0:
+                        a[valid_index, 0] = i1
+                        a[valid_index, 1] = i2
+                        a[valid_index, 2] = i3
+                        a[valid_index, 3] = i4
+                        valid_index += 1
+    return a
+
+@hybrid.script
+def hybrid_argwhere_5d(output_shape, condition):
+    """Find the indices of elements of a 5-D tensor that are non-zero.
+
+    Parameters
+    ----------
+    condition : tvm.Tensor
+        5-D tensor with boolean values.
+
+    Returns
+    -------
+    out : tvm.Tensor
+        Indices of non-zero elements.
+    """
+    a = output_tensor(output_shape, "int32")
+    a1 = condition.shape[0]
+    a2 = condition.shape[1]
+    a3 = condition.shape[2]
+    a4 = condition.shape[3]
+    a5 = condition.shape[4]
+    valid_index = 0
+    for i1 in range(a1):
+        for i2 in range(a2):
+            for i3 in range(a3):
+                for i4 in range(a4):
+                    for i5 in range(a5):
+                        if condition[i1, i2, i3, i4, i5] != 0:
+                            a[valid_index, 0] = i1
+                            a[valid_index, 1] = i2
+                            a[valid_index, 2] = i3
+                            a[valid_index, 3] = i4
+                            a[valid_index, 4] = i5
+                            valid_index += 1
+    return a
+
+@tvm.target.generic_func
+def argwhere(output_shape, condition):
+    """Find the indices of elements of a tensor that are non-zero.
+
+    Parameters
+    ----------
+    condition : tvm.Tensor
+        Tensor with boolean values.
+
+    Returns
+    -------
+    out : tvm.Tensor
+        Indices of non-zero elements.
+    """
+    if len(condition.shape) == 1:
+        return hybrid_argwhere_1d(output_shape.shape, condition)
+    if len(condition.shape) == 2:
+        return hybrid_argwhere_2d(output_shape.shape, condition)
+    if len(condition.shape) == 3:
+        return hybrid_argwhere_3d(output_shape.shape, condition)
+    if len(condition.shape) == 4:
+        return hybrid_argwhere_4d(output_shape.shape, condition)
+    if len(condition.shape) == 5:
+        return hybrid_argwhere_5d(output_shape.shape, condition)
+    raise ValueError("Does not support rank higher than 5 in argwhere")
index 6bf5f3a..18af0e3 100644 (file)
@@ -20,3 +20,4 @@ from .injective import *
 from .extern import *
 from .vision import *
 from .sort import *
+from .search import *
diff --git a/topi/python/topi/generic/search.py b/topi/python/topi/generic/search.py
new file mode 100644 (file)
index 0000000..41045e4
--- /dev/null
@@ -0,0 +1,37 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name, no-member
+"""Generic search operators"""
+from __future__ import absolute_import as _abs
+import tvm
+from .vision import _default_schedule
+
+@tvm.target.generic_func
+def schedule_argwhere(outs):
+    """Schedule for argwhere operator.
+
+    Parameters
+    ----------
+    outs: Array of Tensor
+      The computation graph description of argwhere.
+
+    Returns
+    -------
+    s: Schedule
+      The computation schedule for the op.
+    """
+    return _default_schedule(outs, False)