}
}; // struct OneHotAttrs
+/*! \brief Attributes for ArgWhere operator */
+struct ArgWhereAttrs : public tvm::AttrsNode<ArgWhereAttrs> {
+ TVM_DECLARE_ATTRS(ArgWhereAttrs, "relay.attrs.ArgWhereAttrs") {
+ }
+}; // struct ArgWhereAttrs
+
} // namespace relay
} // namespace tvm
#endif // TVM_RELAY_ATTRS_TRANSFORM_H_
# specific language governing permissions and limitations
# under the License.
"""Backend compiler related feature registration"""
-# pylint: disable=invalid-name,unused-argument, len-as-condition
+# pylint: disable=invalid-name,unused-argument, len-as-condition, too-many-nested-blocks
from __future__ import absolute_import
+import tvm
+import topi
from topi.util import get_const_int, get_const_tuple
from . import op as _reg
from ._reduce import _schedule_reduce
axis += data_ndim
assert 0 <= axis < data_ndim
return [_take_with_axis_shape_func(*inputs, convert(axis), out_ndims[0])]
+
+@script
+def _argwhere_shape_func_1d(condition):
+ out = output_tensor((2, ), "int64")
+ out[0] = int64(0)
+ out[1] = int64(1)
+ for i1 in range(condition.shape[0]):
+ if condition[i1] != 0:
+ out[0] += int64(1)
+ return out
+
+@script
+def _argwhere_shape_func_2d(condition):
+ out = output_tensor((2, ), "int64")
+ out[0] = int64(0)
+ out[1] = int64(2)
+ for i1 in range(condition.shape[0]):
+ for i2 in range(condition.shape[1]):
+ if condition[i1, i2] != 0:
+ out[0] += int64(1)
+ return out
+
+@script
+def _argwhere_shape_func_3d(condition):
+ out = output_tensor((2, ), "int64")
+ out[0] = int64(0)
+ out[1] = int64(3)
+ for i1 in range(condition.shape[0]):
+ for i2 in range(condition.shape[1]):
+ for i3 in range(condition.shape[2]):
+ if condition[i1, i2, i3] != 0:
+ out[0] += int64(1)
+ return out
+
+@script
+def _argwhere_shape_func_4d(condition):
+ out = output_tensor((2, ), "int64")
+ out[0] = int64(0)
+ out[1] = int64(4)
+ for i1 in range(condition.shape[0]):
+ for i2 in range(condition.shape[1]):
+ for i3 in range(condition.shape[2]):
+ for i4 in range(condition.shape[3]):
+ if condition[i1, i2, i3, i4] != 0:
+ out[0] += int64(1)
+ return out
+
+@script
+def _argwhere_shape_func_5d(condition):
+ out = output_tensor((2, ), "int64")
+ out[0] = int64(0)
+ out[1] = int64(5)
+ for i1 in range(condition.shape[0]):
+ for i2 in range(condition.shape[1]):
+ for i3 in range(condition.shape[2]):
+ for i4 in range(condition.shape[3]):
+ for i5 in range(condition.shape[4]):
+ if condition[i1, i2, i3, i4, i5] != 0:
+ out[0] += int64(1)
+ return out
+
+@_reg.register_shape_func("argwhere", True)
+def argwhere_shape_func(attrs, inputs, out_ndims):
+ """
+ Shape function for argwhere.
+ """
+ if len(inputs[0].shape) == 1:
+ return [_argwhere_shape_func_1d(inputs[0])]
+ elif len(inputs[0].shape) == 2:
+ return [_argwhere_shape_func_2d(inputs[0])]
+ elif len(inputs[0].shape) == 3:
+ return [_argwhere_shape_func_3d(inputs[0])]
+ elif len(inputs[0].shape) == 4:
+ return [_argwhere_shape_func_4d(inputs[0])]
+ elif len(inputs[0].shape) == 5:
+ return [_argwhere_shape_func_5d(inputs[0])]
+ return ValueError("Does not support rank higher than 5 in argwhere")
+
+@_reg.register_schedule("argwhere")
+def schedule_argwhere(_, outs, target):
+ """Schedule definition of argwhere"""
+ with target:
+ return topi.generic.schedule_argwhere(outs)
+
+
+@_reg.register_compute("argwhere")
+def compute_argwhere(attrs, inputs, output_type, _):
+ """Compute definition of argwhere"""
+ output_shape = []
+ for s in output_type.shape:
+ if hasattr(s, "value"):
+ output_shape.append(s)
+ else:
+ # see Any, replace it with a var
+ output_shape.append(tvm.var("any_dim", "int32"))
+ new_output_type = tvm.relay.ty.TensorType(output_shape, "int32")
+ return [topi.argwhere(new_output_type, inputs[0])]
"""
return _make.squeeze(data, axis)
-
def reshape(data, newshape):
"""Reshapes the input array.
newshape = [newshape]
return _make.reshape(data, list(newshape))
+def argwhere(condition):
+ """Find the indices of elements of a tensor that are
+ non-zero.
+
+ Parameters
+ ----------
+ condition : relay.Expr
+ The input condition tensor.
+
+ Returns
+ -------
+ out : relay.Expr
+ Tensor with the indices of elements that are non-zero.
+
+ Examples
+ --------
+ .. code-block:: python
+
+ condition = [[True, False], [False, True]]
+ relay.argwhere(condition) = [[0, 0], [1, 1]]
+ """
+ return _make.argwhere(condition)
def reshape_like(data, shape_like):
"""Reshapes the input array by the size of another array.
.set_attr<FTVMCompute>("FTVMCompute", ReshapeCompute)
.set_attr<TOpPattern>("TOpPattern", kInjective);
+// ArgWhere
+bool ArgWhereRel(const Array<Type>& types,
+ int num_inputs,
+ const Attrs& attrs,
+ const TypeReporter& reporter) {
+ CHECK_EQ(num_inputs, 1);
+ auto tt = types[0].as<TensorTypeNode>();
+ CHECK(tt != nullptr);
+ const auto& input_shape = tt->shape;
+ const auto& input_rank = input_shape.size();
+ std::vector<IndexExpr> result_shape;
+ result_shape.push_back(Any::make());
+ result_shape.push_back(IntImm::make(Int(32), input_rank));
+ reporter->Assign(types[1], TensorTypeNode::make(result_shape, Int(32)));
+ return true;
+}
+
+TVM_REGISTER_API("relay.op._make.argwhere")
+.set_body_typed<Expr(Expr)>([](Expr data) {
+ static const Op& op = Op::Get("argwhere");
+ auto attrs = make_node<ArgWhereAttrs>();
+ return CallNode::make(op, {data}, Attrs(attrs), {});
+});
+
+RELAY_REGISTER_OP("argwhere")
+.describe(R"doc(Find the indices of elements of a tensor that are
+non-zero)doc" TVM_ADD_FILELINE)
+.set_num_inputs(1)
+.set_attrs_type_key("relay.attrs.ArgWhereAttrs")
+.add_argument("condition", "Tensor", "The input condition tensor.")
+.add_type_rel("ArgWhere", ArgWhereRel)
+.set_attr<TOpIsStateful>("TOpIsStateful", false)
+.set_attr<TOpPattern>("TOpPattern", kOpaque)
+.set_support_level(10);
// Take
TVM_REGISTER_NODE_TYPE(TakeAttrs);
verify_any_reshape(any_dims(3), (-4, 2, -1, -2), (6, 3, 4), (2, 3, 3, 4))
verify_any_reshape(any_dims(3), (-4, -1, 2, -3), (6, 3, 4), (3, 2, 12))
+def verify_any_argwhere(x_shape, x_np_shape, dtype="bool"):
+ x = relay.var('x', shape=x_shape, dtype=dtype)
+ y = relay.argwhere(x)
+ mod = relay.module.Module()
+ mod["main"] = relay.Function([x], y)
+ data = np.random.choice([0, 1, 2, 3], size=x_np_shape).astype(dtype)
+ for kind in ["debug", "vm"]:
+ ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target="llvm")
+ result = ex.evaluate()(data).asnumpy()
+ expected = np.argwhere(data)
+ assert result.shape == expected.shape
+ tvm.testing.assert_allclose(result.flatten(), expected.flatten())
+
+def test_any_argwhere():
+ verify_any_argwhere(any_dims(1), (5,))
+ verify_any_argwhere(any_dims(2), (5, 5))
+ verify_any_argwhere(any_dims(3), (5, 5, 5))
+ verify_any_argwhere(any_dims(4), (5, 5, 5, 5))
+ verify_any_argwhere(any_dims(5), (5, 5, 5, 5, 5))
+ verify_any_argwhere(any_dims(1), (5,), "int32")
+ verify_any_argwhere(any_dims(2), (5, 5), "int32")
+ verify_any_argwhere(any_dims(3), (5, 5, 5), "int32")
+ verify_any_argwhere(any_dims(4), (5, 5, 5, 5), "int32")
+ verify_any_argwhere(any_dims(5), (5, 5, 5, 5, 5), "int32")
+ verify_any_argwhere(any_dims(1), (5,), "int8")
+ verify_any_argwhere(any_dims(2), (5, 5), "int8")
+ verify_any_argwhere(any_dims(3), (5, 5, 5), "int8")
+ verify_any_argwhere(any_dims(4), (5, 5, 5, 5), "int8")
+ verify_any_argwhere(any_dims(5), (5, 5, 5, 5, 5), "int8")
+
def verify_any_take(data_shape, indices_shape, axis, data_np_shape, indices_np_shape):
mod = relay.Module()
data = relay.var('data', shape=data_shape, dtype='float32')
from .transform import *
from .broadcast import *
from .sort import *
+from .argwhere import *
from . import nn
from . import x86
from . import cuda
--- /dev/null
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name, too-many-arguments, too-many-nested-blocks
+"""Argwhere operator"""
+import tvm
+from tvm import hybrid
+
+@hybrid.script
+def hybrid_argwhere_1d(output_shape, condition):
+ """Find the indices of elements of a 1-D tensor that are non-zero.
+
+ Parameters
+ ----------
+ condition : tvm.Tensor
+ 1-D tensor with boolean values.
+
+ Returns
+ -------
+ out : tvm.Tensor
+ Indices of non-zero elements.
+ """
+ a = output_tensor(output_shape, "int32")
+ a1 = condition.shape[0]
+ valid_index = 0
+ for i1 in range(a1):
+ if condition[i1] != 0:
+ a[valid_index, 0] = i1
+ valid_index += 1
+ return a
+
+@hybrid.script
+def hybrid_argwhere_2d(output_shape, condition):
+ """Find the indices of elements of a 2-D tensor that are non-zero.
+
+ Parameters
+ ----------
+ condition : tvm.Tensor
+ 2-D tensor with boolean values.
+
+ Returns
+ -------
+ out : tvm.Tensor
+ Indices of non-zero elements.
+ """
+ a = output_tensor(output_shape, "int32")
+ a1 = condition.shape[0]
+ a2 = condition.shape[1]
+ valid_index = 0
+ for i1 in range(a1):
+ for i2 in range(a2):
+ if condition[i1, i2] != 0:
+ a[valid_index, 0] = i1
+ a[valid_index, 1] = i2
+ valid_index += 1
+ return a
+
+@hybrid.script
+def hybrid_argwhere_3d(output_shape, condition):
+ """Find the indices of elements of a 3-D tensor that are non-zero.
+
+ Parameters
+ ----------
+ condition : tvm.Tensor
+ 3-D tensor with boolean values.
+
+ Returns
+ -------
+ out : tvm.Tensor
+ Indices of non-zero elements.
+ """
+ a = output_tensor(output_shape, "int32")
+ a1 = condition.shape[0]
+ a2 = condition.shape[1]
+ a3 = condition.shape[2]
+ valid_index = 0
+ for i1 in range(a1):
+ for i2 in range(a2):
+ for i3 in range(a3):
+ if condition[i1, i2, i3] != 0:
+ a[valid_index, 0] = i1
+ a[valid_index, 1] = i2
+ a[valid_index, 2] = i3
+ valid_index += 1
+ return a
+
+@hybrid.script
+def hybrid_argwhere_4d(output_shape, condition):
+ """Find the indices of elements of a 4-D tensor that are non-zero.
+
+ Parameters
+ ----------
+ condition : tvm.Tensor
+ 4-D tensor with boolean values.
+
+ Returns
+ -------
+ out : tvm.Tensor
+ Indices of non-zero elements.
+ """
+ a = output_tensor(output_shape, "int32")
+ a1 = condition.shape[0]
+ a2 = condition.shape[1]
+ a3 = condition.shape[2]
+ a4 = condition.shape[3]
+ valid_index = 0
+ for i1 in range(a1):
+ for i2 in range(a2):
+ for i3 in range(a3):
+ for i4 in range(a4):
+ if condition[i1, i2, i3, i4] != 0:
+ a[valid_index, 0] = i1
+ a[valid_index, 1] = i2
+ a[valid_index, 2] = i3
+ a[valid_index, 3] = i4
+ valid_index += 1
+ return a
+
+@hybrid.script
+def hybrid_argwhere_5d(output_shape, condition):
+ """Find the indices of elements of a 5-D tensor that are non-zero.
+
+ Parameters
+ ----------
+ condition : tvm.Tensor
+ 5-D tensor with boolean values.
+
+ Returns
+ -------
+ out : tvm.Tensor
+ Indices of non-zero elements.
+ """
+ a = output_tensor(output_shape, "int32")
+ a1 = condition.shape[0]
+ a2 = condition.shape[1]
+ a3 = condition.shape[2]
+ a4 = condition.shape[3]
+ a5 = condition.shape[4]
+ valid_index = 0
+ for i1 in range(a1):
+ for i2 in range(a2):
+ for i3 in range(a3):
+ for i4 in range(a4):
+ for i5 in range(a5):
+ if condition[i1, i2, i3, i4, i5] != 0:
+ a[valid_index, 0] = i1
+ a[valid_index, 1] = i2
+ a[valid_index, 2] = i3
+ a[valid_index, 3] = i4
+ a[valid_index, 4] = i5
+ valid_index += 1
+ return a
+
+@tvm.target.generic_func
+def argwhere(output_shape, condition):
+ """Find the indices of elements of a tensor that are non-zero.
+
+ Parameters
+ ----------
+ condition : tvm.Tensor
+ Tensor with boolean values.
+
+ Returns
+ -------
+ out : tvm.Tensor
+ Indices of non-zero elements.
+ """
+ if len(condition.shape) == 1:
+ return hybrid_argwhere_1d(output_shape.shape, condition)
+ if len(condition.shape) == 2:
+ return hybrid_argwhere_2d(output_shape.shape, condition)
+ if len(condition.shape) == 3:
+ return hybrid_argwhere_3d(output_shape.shape, condition)
+ if len(condition.shape) == 4:
+ return hybrid_argwhere_4d(output_shape.shape, condition)
+ if len(condition.shape) == 5:
+ return hybrid_argwhere_5d(output_shape.shape, condition)
+ raise ValueError("Does not support rank higher than 5 in argwhere")
from .extern import *
from .vision import *
from .sort import *
+from .search import *
--- /dev/null
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name, no-member
+"""Generic search operators"""
+from __future__ import absolute_import as _abs
+import tvm
+from .vision import _default_schedule
+
+@tvm.target.generic_func
+def schedule_argwhere(outs):
+ """Schedule for argwhere operator.
+
+ Parameters
+ ----------
+ outs: Array of Tensor
+ The computation graph description of argwhere.
+
+ Returns
+ -------
+ s: Schedule
+ The computation schedule for the op.
+ """
+ return _default_schedule(outs, False)