From 554df2118931cede5c0783c5a41d4da0be46ffe6 Mon Sep 17 00:00:00 2001 From: Jon Soifer Date: Thu, 22 Aug 2019 13:45:45 -0700 Subject: [PATCH] [TOPI][Relay][TensorFlow] Add OneHot operator (#3781) * Add one-hot to Relay * topi implementation * Working * add topi test * Add TF test * Fix check * fix linting issues * fix documentation * Fix documentation * Add support for on_value, off_value, axis, dtype * Add full support for axis * Fix compute and update test_forward * Move on_value and off_value to inputs * Add topi test * Update tests * Update docs * Fix style * re-enable tests * Add one_hot to mxnet converter --- docs/api/python/topi.rst | 2 + docs/langref/relay_op.rst | 2 + include/tvm/relay/attrs/transform.h | 16 +++++ python/tvm/relay/frontend/mxnet.py | 9 +++ python/tvm/relay/frontend/tensorflow.py | 16 +++++ python/tvm/relay/op/_transform.py | 1 + python/tvm/relay/op/transform.py | 44 ++++++++++++ src/relay/op/tensor/transform.cc | 89 ++++++++++++++++++++++++ tests/python/frontend/mxnet/test_forward.py | 20 ++++++ tests/python/frontend/tensorflow/test_forward.py | 19 +++++ tests/python/relay/test_op_level10.py | 40 +++++++++++ topi/include/topi/transform.h | 50 +++++++++++++ topi/python/topi/testing/__init__.py | 1 + topi/python/topi/testing/one_hot.py | 79 +++++++++++++++++++++ topi/python/topi/transform.py | 44 ++++++++++++ topi/src/topi.cc | 8 +++ topi/tests/python/test_topi_transform.py | 33 +++++++++ 17 files changed, 473 insertions(+) create mode 100644 topi/python/topi/testing/one_hot.py diff --git a/docs/api/python/topi.rst b/docs/api/python/topi.rst index 8f59e08..123c1d0 100644 --- a/docs/api/python/topi.rst +++ b/docs/api/python/topi.rst @@ -104,6 +104,7 @@ List of operators topi.argsort topi.topk topi.sequence_mask + topi.one_hot List of schedules @@ -173,6 +174,7 @@ topi .. autofunction:: topi.argsort .. autofunction:: topi.topk .. autofunction:: topi.sequence_mask +.. autofunction:: topi.one_hot topi.nn ~~~~~~~ diff --git a/docs/langref/relay_op.rst b/docs/langref/relay_op.rst index 6950ecc..4fad352 100644 --- a/docs/langref/relay_op.rst +++ b/docs/langref/relay_op.rst @@ -200,6 +200,7 @@ This level support backpropagation of broadcast operators. It is temporary. tvm.relay.nn.batch_matmul tvm.relay.contrib.adaptive_max_pool2d tvm.relay.contrib.adaptive_avg_pool2d + tvm.relay.one_hot **Level 11: Dialect Operators** @@ -350,6 +351,7 @@ Level 10 Definitions .. autofunction:: tvm.relay.nn.batch_matmul .. autofunction:: tvm.relay.contrib.adaptive_max_pool2d .. autofunction:: tvm.relay.contrib.adaptive_avg_pool2d +.. autofunction:: tvm.relay.one_hot Level 11 Definitions diff --git a/include/tvm/relay/attrs/transform.h b/include/tvm/relay/attrs/transform.h index e43fd5f..5265687 100644 --- a/include/tvm/relay/attrs/transform.h +++ b/include/tvm/relay/attrs/transform.h @@ -298,6 +298,22 @@ struct NdarraySizeAttrs : public tvm::AttrsNode { } }; +/*! \brief Attributes used in one-hot operator */ +struct OneHotAttrs : public tvm::AttrsNode { + int depth; + int axis; + DataType dtype; + + TVM_DECLARE_ATTRS(OneHotAttrs, "relay.attrs.OneHotAttrs") { + TVM_ATTR_FIELD(depth).set_default(1) + .describe("Depth of the one hot dimension."); + TVM_ATTR_FIELD(axis).set_default(-1) + .describe("Axis to fill."); + TVM_ATTR_FIELD(dtype).set_default(NullValue()) + .describe("Output data type."); + } +}; // struct OneHotAttrs + } // namespace relay } // namespace tvm #endif // TVM_RELAY_ATTRS_TRANSFORM_H_ diff --git a/python/tvm/relay/frontend/mxnet.py b/python/tvm/relay/frontend/mxnet.py index 9d82671..36c4fb8 100644 --- a/python/tvm/relay/frontend/mxnet.py +++ b/python/tvm/relay/frontend/mxnet.py @@ -896,6 +896,14 @@ def _mx_rnn_layer(inputs, attrs): ret.append(_op.stack(inputs, axis=0)) return ret +def _mx_one_hot(inputs, attrs): + indices = inputs[0].astype('int32') + depth = attrs.get_int('depth', 0) + dtype = attrs.get_str('dtype', 'int32') + on_value = tvm.relay.const(attrs.get_float('on_value', 1.0), dtype) + off_value = tvm.relay.const(attrs.get_float('off_value', 0.0), dtype) + return _op.one_hot(indices, on_value, off_value, depth, -1, dtype) + # Note: due to attribute conversion constraint # ops in the identity set must be attribute free @@ -1041,6 +1049,7 @@ _convert_map = { "LinearRegressionOutput" : _mx_linear_regression_output, "smooth_l1" : _mx_smooth_l1, "_contrib_div_sqrt_dim": _mx_contrib_div_sqrt_dim, + "one_hot" : _mx_one_hot, # vision "_contrib_BilinearResize2D" : _mx_resize, "_contrib_MultiBoxPrior" : _mx_multibox_prior, diff --git a/python/tvm/relay/frontend/tensorflow.py b/python/tvm/relay/frontend/tensorflow.py index 54724b5..6c2ad0d 100644 --- a/python/tvm/relay/frontend/tensorflow.py +++ b/python/tvm/relay/frontend/tensorflow.py @@ -1212,6 +1212,21 @@ def _log1p(): return get_relay_op('log')(add_out) return _impl +def _one_hot(): + def _impl(inputs, attr, params): + depth = int(_get_num_param(params, inputs[1])) + dtype = attr['T'].name + + on_value = _get_num_param(params, inputs[2]) + off_value = _get_num_param(params, inputs[3]) + new_inputs = [inputs[0], \ + tvm.relay.const(on_value, dtype), \ + tvm.relay.const(off_value, dtype)] + return AttrCvt('one_hot', + ignores=['TI'], + extras={'depth' : depth, 'dtype' : dtype})(new_inputs, attr) + return _impl + # compatible operators that do NOT require any conversion. _identity_list = [] @@ -1284,6 +1299,7 @@ _convert_map = { 'Mul' : _elemwise('multiply'), 'Neg' : AttrCvt('negative'), 'NotEqual' : _broadcast('not_equal'), + 'OneHot' : _one_hot(), 'Pack' : _pack(), 'Pad' : _pad('Pad'), 'PadV2' : _pad('PadV2'), diff --git a/python/tvm/relay/op/_transform.py b/python/tvm/relay/op/_transform.py index 51e7615..a4c9375 100644 --- a/python/tvm/relay/op/_transform.py +++ b/python/tvm/relay/op/_transform.py @@ -52,6 +52,7 @@ _reg.register_schedule("concatenate", schedule_concatenate) _reg.register_schedule("_contrib_reverse_reshape", schedule_injective) _reg.register_schedule("gather_nd", schedule_injective) _reg.register_schedule("sequence_mask", schedule_injective) +_reg.register_schedule("one_hot", schedule_injective) # layout_transform diff --git a/python/tvm/relay/op/transform.py b/python/tvm/relay/op/transform.py index 5d8d280..38ce653 100644 --- a/python/tvm/relay/op/transform.py +++ b/python/tvm/relay/op/transform.py @@ -748,3 +748,47 @@ def sequence_mask(data, valid_length, mask_value=0, axis=0): [[ 0.1, 0.1, 0.1], [ 16., 17., 18.]]] """ return _make.sequence_mask(data, valid_length, mask_value, axis) + +def one_hot(indices, on_value, off_value, depth, axis, dtype): + """ + Returns a one-hot tensor where the locations repsented by indices take value on_value, + other locations take value off_value. + Final dimension is x depth x . + + Parameters + ---------- + indices : relay.Expr + Locations to set to on_value. + + on_value : relay.Expr + Value to fill at indices. + + off_value : relay.Expr + Value to fill at all other positions besides indices. + + depth : int + Depth of the one-hot dimension. + + axis : int + Axis to fill. + + dtype : str + Data type of the output tensor. + + Returns + ------- + ret : relay.Expr + The one-hot tensor. + + Examples + -------- + .. code-block:: python + + indices = [0, 1, 2] + + relay.one_hot(indices, 3) = + [[1, 0, 0], + [0, 1, 0], + [0, 0, 1]] + """ + return _make.one_hot(indices, on_value, off_value, depth, axis, dtype) diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc index 03a92b3..b39c282 100644 --- a/src/relay/op/tensor/transform.cc +++ b/src/relay/op/tensor/transform.cc @@ -2482,5 +2482,94 @@ Examples:: .set_attr("FTVMCompute", SequenceMaskCompute) .set_attr("TOpPattern", kInjective); +// relay.one_hot +TVM_REGISTER_NODE_TYPE(OneHotAttrs); + +bool OneHotRel(const Array& types, + int num_inputs, + const Attrs& attrs, + const TypeReporter& reporter) { + // `types` contains: [indices, on_value, off_value, result] + CHECK_EQ(types.size(), 4); + const auto* indices = types[0].as(); + CHECK(indices); + + const auto param = attrs.as(); + CHECK_GT(param->depth, 0); + + Array oshape; + int ndim = indices->shape.size() + 1; + int indices_index = 0; + int true_axis = (param->axis == -1) ? indices->shape.size() : param->axis; + for (int i = 0; i < ndim; i++) { + if (i == true_axis) { + oshape.push_back(Integer(param->depth)); + } else { + oshape.push_back(indices->shape[indices_index++]); + } + } + + reporter->Assign(types[3], TensorTypeNode::make(oshape, param->dtype)); + return true; +} + +Array OneHotCompute(const Attrs& attrs, + const Array& inputs, + const Type& out_type, + const Target& target) { + const auto* param = attrs.as(); + CHECK(param != nullptr); + return Array { + topi::one_hot(inputs[0], + inputs[1](), + inputs[2](), + param->depth, + param->axis, + param->dtype) + }; +} + +Expr MakeOneHot(Expr indices, + Expr on_value, + Expr off_value, + int depth, + int axis, + DataType dtype) { + auto attrs = make_node(); + attrs->depth = std::move(depth); + attrs->axis = axis; + attrs->dtype = dtype; + static const Op& op = Op::Get("one_hot"); + return CallNode::make(op, {indices, on_value, off_value}, Attrs(attrs), {}); +} + +TVM_REGISTER_API("relay.op._make.one_hot") +.set_body_typed(MakeOneHot); + +RELAY_REGISTER_OP("one_hot") +.describe(R"code(Returns a one-hot tensor where the locations repsented by indices take value 1, + other locations take value 0. Final dimension is x depth. + + **indices** Locations to set to 1. + + **on_value** Value to fill at indices. + + **off_value** Value to fill at all other positions besides indices. + + **depth** Depth of the one-hot dimension. + + **axis** Axis to fill. + + **dtype**)code" TVM_ADD_FILELINE) +.set_attrs_type_key("relay.attrs.OneHotAttrs") +.set_num_inputs(3) +.add_argument("indices", "Tensor", "Locations to set to on_value.") +.add_argument("on_value", "Expr", "Value to fill at indices.") +.add_argument("off_value", "Expr", "Value to fill at all other positions besides indices.") +.set_support_level(10) +.add_type_rel("OneHot", OneHotRel) +.set_attr("FTVMCompute", OneHotCompute) +.set_attr("TOpPattern", kOutEWiseFusable); + } // namespace relay } // namespace tvm diff --git a/tests/python/frontend/mxnet/test_forward.py b/tests/python/frontend/mxnet/test_forward.py index a4a514e..90b425f 100644 --- a/tests/python/frontend/mxnet/test_forward.py +++ b/tests/python/frontend/mxnet/test_forward.py @@ -778,6 +778,25 @@ def test_forward_layer_norm(): verify((2, 5), axis=0) verify((2, 5, 6)) +def test_forward_one_hot(): + def verify(indices_shape, depth, on_value, off_value, dtype): + x = np.random.randint(0, 5, size=indices_shape) + ref_res = mx.nd.one_hot(mx.nd.array(x), depth, on_value, off_value, dtype) + mx_sym = mx.sym.one_hot(mx.sym.var("x"), depth, on_value, off_value, dtype) + shape_dict = {"x": x.shape} + mod, _ = relay.frontend.from_mxnet(mx_sym, shape_dict) + for target, ctx in ctx_list(): + for kind in ["graph", "debug"]: + intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target) + op_res = intrp.evaluate()(x.astype("float32")) + tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy(), rtol=1e-3) + verify((3,), 3, 1, 0, "int32") + verify((3,), 3, 1.0, 0.0, "float32") + verify((2, 2), 5, 2, -2, "int32") + verify((2, 2), 5, 0.5, -0.5, "float32") + verify((3, 2, 4, 5), 6, 1, 0, "int32") + verify((3, 2, 4, 5), 6, 1.0, 0.0, "float32") + if __name__ == '__main__': test_forward_mlp() test_forward_vgg() @@ -825,3 +844,4 @@ if __name__ == '__main__': test_forward_contrib_div_sqrt_dim() test_forward_batch_norm() test_forward_layer_norm() + test_forward_one_hot() diff --git a/tests/python/frontend/tensorflow/test_forward.py b/tests/python/frontend/tensorflow/test_forward.py index eb8e27e..16d7ba5 100644 --- a/tests/python/frontend/tensorflow/test_forward.py +++ b/tests/python/frontend/tensorflow/test_forward.py @@ -2158,6 +2158,24 @@ def test_placeholder(): compare_tf_with_tvm([in_data1, in_data2], ['place1:0', 'in2:0'], 'out2:0', init_global_variables=True) +####################################################################### +# OneHot +# ---------------------- +def _test_forward_one_hot(indices_shape, depth, on_value, off_value, axis, out_dtype): + inp_array1 = np.random.randint(0, 5, size=indices_shape) + with tf.Graph().as_default(): + in1 = tf.placeholder(shape=inp_array1.shape, dtype=inp_array1.dtype) + out = tf.one_hot(in1, depth, on_value, off_value, axis, dtype=out_dtype) + compare_tf_with_tvm(inp_array1, in1.name, out.name) + +def test_forward_one_hot(): + _test_forward_one_hot((3,), 3, 1, 0, -1, "int32") + _test_forward_one_hot((3,), 3, 1.0, 0.0, -1, "float32") + _test_forward_one_hot((2, 2), 5, 2, -2, 0, "int32") + _test_forward_one_hot((2, 2), 5, 0.5, -0.5, 1, "float32") + _test_forward_one_hot((3, 2, 4, 5), 6, 1, 0, 1, "int32") + _test_forward_one_hot((3, 2, 4, 5), 6, 1.0, 0.0, 0, "float32") + ####################################################################### # Main @@ -2193,6 +2211,7 @@ if __name__ == '__main__': test_forward_right_shift() test_forward_left_shift() test_forward_truncatemod() + test_forward_one_hot() # Activations test_forward_sigmoid() diff --git a/tests/python/relay/test_op_level10.py b/tests/python/relay/test_op_level10.py index f3520f3..e828fa3 100644 --- a/tests/python/relay/test_op_level10.py +++ b/tests/python/relay/test_op_level10.py @@ -296,6 +296,45 @@ def test_sequence_mask(): _verify((2, 3, 5, 3), 0.0, 0, 'float32', 'int64') _verify((5, 8, 3), 0.1, 1, 'float64', 'float32') +def test_one_hot(): + def _get_oshape(indices_shape, depth, axis): + oshape = [] + true_axis = len(indices_shape) if axis == -1 else axis + ndim = len(indices_shape) + 1 + indices_index = 0 + for i in range(0, ndim): + if i == true_axis: + oshape.append(depth) + else: + oshape.append(indices_shape[indices_index]) + indices_index += 1 + + return oshape + + def _verify(indices_shape, depth, on_value, off_value, axis, dtype): + indices = relay.var("indices", relay.TensorType(indices_shape, "int32")) + on_value_const = relay.const(on_value) + off_value_const = relay.const(off_value) + out = relay.one_hot(indices, on_value_const, off_value_const, depth, axis, dtype) + checked = run_infer_type(out) + assert checked.checked_type == relay.ty.TensorType(_get_oshape(indices_shape, depth, axis), dtype) + func = relay.Function([indices], out) + indices_np = np.random.randint(0, depth, size=indices_shape).astype("int32") + out_np = topi.testing.one_hot(indices_np, on_value, off_value, depth, axis, dtype) + + for target, ctx in ctx_list(): + for kind in ["graph", "debug"]: + intrp = relay.create_executor(kind, ctx=ctx, target=target) + out_relay = intrp.evaluate(func)(indices_np) + tvm.testing.assert_allclose(out_relay.asnumpy(), out_np) + + _verify((3,), 3, 1, 0, -1, "int32") + _verify((3,), 3, 1.0, 0.0, -1, "float32") + _verify((2, 2), 5, 2, -2, 0, "int32") + _verify((2, 2), 5, 0.5, -0.5, 1, "float32") + _verify((3, 2, 4, 5), 6, 1, 0, 1, "int32") + _verify((3, 2, 4, 5), 6, 1.0, 0.0, 0, "float32") + if __name__ == "__main__": test_adaptive_pool2d() test_collapse_sum_like() @@ -306,4 +345,5 @@ if __name__ == "__main__": test_shape_of() test_sequence_mask() test_ndarray_size() + test_one_hot() diff --git a/topi/include/topi/transform.h b/topi/include/topi/transform.h index e8a65b0..1622b20 100644 --- a/topi/include/topi/transform.h +++ b/topi/include/topi/transform.h @@ -1247,5 +1247,55 @@ inline Tensor ndarray_size(const Tensor& src, }, name, tag); } +/*! + * \brief Returns a one-hot tensor where the locations repsented by indices take value on_value, + other locations take value off_value. + * \param indices locations to set to on_value. + * \param on_value value that locations represented by indices take on. + * \param off_value value that other locations take on. + * \param depth depth of the one-hot dimension. + * \param axis axis to fill. + * \param dtype data type of the output tensor. + * \param name output tensor name. + * \param tag output tensor tag. + * \return one-hot tensor. + */ +inline Tensor one_hot(const Tensor& indices, + const Expr on_value, + const Expr off_value, + int depth, + int axis, + const Type& dtype, + const std::string name = "T_one_hot", + const std::string tag = kInjective) { + Array oshape; + int ndim = indices->shape.size() + 1; + int indices_index = 0; + int true_axis = (axis == -1) ? indices->shape.size() : axis; + for (int i = 0; i < ndim; i++) { + if (i == true_axis) { + oshape.push_back(Integer(depth)); + } else { + oshape.push_back(indices->shape[indices_index++]); + } + } + + Expr on_value_cast = cast(dtype, on_value); + Expr off_value_cast = cast(dtype, off_value); + return compute(oshape, [&](const Array& iter_vars) { + Array indices_indices; + for (size_t i = 0; i < iter_vars.size(); i++) { + if (static_cast(i) == true_axis) { + continue; + } + + indices_indices.push_back(iter_vars[i]); + } + + auto idx = iter_vars[true_axis]; + return ir::Select::make(indices(indices_indices) == idx, on_value_cast, off_value_cast); + }, name, tag); +} + } // namespace topi #endif // TOPI_TRANSFORM_H_ diff --git a/topi/python/topi/testing/__init__.py b/topi/python/topi/testing/__init__.py index 57a9c26..d607c28 100644 --- a/topi/python/topi/testing/__init__.py +++ b/topi/python/topi/testing/__init__.py @@ -25,3 +25,4 @@ from .batch_matmul import batch_matmul from .slice_axis_python import slice_axis_python from .sequence_mask_python import sequence_mask from .pool_grad_python import pool_grad_nchw +from .one_hot import one_hot diff --git a/topi/python/topi/testing/one_hot.py b/topi/python/topi/testing/one_hot.py new file mode 100644 index 0000000..99c52be --- /dev/null +++ b/topi/python/topi/testing/one_hot.py @@ -0,0 +1,79 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name +"""OneHot in python""" +import numpy as np + +def one_hot(indices, on_value, off_value, depth, axis, dtype): + """one_hot operator implemented in numpy. + + Returns a one-hot tensor where the locations repsented by indices take value on_value, + other locations take value off_value. + Final dimension is x depth x . + + Parameters + ---------- + indices : numpy.ndarray + Locations to set to on_value. + + on_value : int/float + Value to fill at indices. + + off_value : int/float + Value to fill at all other positions besides indices. + + depth : int + Depth of the one-hot dimension. + + axis : int + Axis to fill. + + dtype : str + Data type of the output tensor. + + Returns + ------- + ret : relay.Expr + The one-hot tensor. + """ + oshape = [] + true_axis = len(indices.shape) if axis == -1 else axis + ndim = len(indices.shape) + 1 + indices_index = 0 + for i in range(0, ndim): + if i == true_axis: + oshape.append(depth) + else: + oshape.append(indices.shape[indices_index]) + indices_index += 1 + + out = np.empty(oshape) + output_indices = [index for index in np.ndindex(out.shape)] + for output_index in output_indices: + indices_indices = [] + for i, out_idx in enumerate(output_index): + if i == true_axis: + continue + indices_indices.append(out_idx) + + index = output_index[true_axis] + if indices[tuple(indices_indices)] == index: + out[output_index] = on_value + else: + out[output_index] = off_value + + return out.astype(dtype) diff --git a/topi/python/topi/transform.py b/topi/python/topi/transform.py index 5e87933..3c7fc9c 100644 --- a/topi/python/topi/transform.py +++ b/topi/python/topi/transform.py @@ -518,3 +518,47 @@ def where(condition, x, y): A Tensor selected from x or y depending on condition. """ return cpp.where(condition, x, y) + +def one_hot(indices, on_value, off_value, depth, axis, dtype): + """ + Returns a one-hot tensor where the locations repsented by indices take value on_value, + other locations take value off_value. + Final dimension is x depth x . + + Parameters + ---------- + indices : tvm.Tensor + Locations to set to on_value. + + on_value : tvm.Tensor + Value to fill at indices. + + off_value : tvm.Tensor + Value to fill at all other positions besides indices. + + depth : int + Depth of the one-hot dimension. + + axis : int + Axis to fill. + + dtype : relay.DataType + Data type of the output tensor. + + Returns + ------- + ret : relay.Expr + The one-hot tensor. + + Examples + -------- + .. code-block:: python + + indices = [0, 1, 2] + + relay.one_hot(indices, 3) = + [[1, 0, 0], + [0, 1, 0], + [0, 0, 1]] + """ + return cpp.one_hot(indices, on_value, off_value, depth, axis, dtype) diff --git a/topi/src/topi.cc b/topi/src/topi.cc index 799b660..7e47b62 100644 --- a/topi/src/topi.cc +++ b/topi/src/topi.cc @@ -417,6 +417,14 @@ TVM_REGISTER_GLOBAL("topi.strided_slice") *rv = strided_slice(args[0], args[1], args[2], args[3]); }); +TVM_REGISTER_GLOBAL("topi.one_hot") +.set_body([](TVMArgs args, TVMRetValue *rv) { + int depth = args[3]; + int axis = args[4]; + DataType dtype = args[5]; + *rv = one_hot(args[0], args[1], args[2], depth, axis, dtype); + }); + /* Ops from nn/upsampling.h */ TVM_REGISTER_GLOBAL("topi.nn.upsampling") .set_body([](TVMArgs args, TVMRetValue *rv) { diff --git a/topi/tests/python/test_topi_transform.py b/topi/tests/python/test_topi_transform.py index 64305b4..b1aa20e 100644 --- a/topi/tests/python/test_topi_transform.py +++ b/topi/tests/python/test_topi_transform.py @@ -473,6 +473,31 @@ def verify_where(in_shape): for device in get_all_backend(): check_device(device) +def verify_one_hot(indices_shape, depth, on_value, off_value, axis, dtype): + indices = tvm.placeholder(shape=indices_shape, name="indices", dtype="int32") + on_value_const = tvm.const(on_value, dtype) + off_value_const = tvm.const(off_value, dtype) + one_hot_result = topi.transform.one_hot(indices, on_value_const, off_value_const, depth, axis, dtype) + def check_device(device): + ctx = tvm.context(device, 0) + if not ctx.exist: + print("Skip because %s is not enabled" % device) + return + print("Running on target: %s" % device) + with tvm.target.create(device): + s = topi.generic.schedule_injective(one_hot_result) + fn = tvm.build(s, [indices, one_hot_result], device, name="one_hot") + indices_npy = np.random.randint(0, depth, size=indices_shape).astype(indices.dtype) + out_npy = topi.testing.one_hot(indices_npy, on_value, off_value, depth, axis, dtype) + indices_nd = tvm.nd.array(indices_npy, ctx) + out_nd = tvm.nd.array(np.empty(out_npy.shape).astype(one_hot_result.dtype), ctx) + fn(indices_nd, out_nd) + out_topi = out_nd.asnumpy() + tvm.testing.assert_allclose(out_topi, out_npy) + + for device in get_all_backend(): + check_device(device) + def test_strided_slice(): verify_strided_slice((3, 4, 3), [0, 0, 0], [4, -5, 4], [1, -1, 2]) verify_strided_slice((3, 4, 3), [1, 1, 0], [4, 4, 3], [2, 1, 1]) @@ -770,6 +795,13 @@ def test_where_fusion(): for backend in get_all_backend(): check_device(backend) +def test_one_hot(): + verify_one_hot((3,), 3, 1, 0, -1, "int32") + verify_one_hot((3,), 3, 1.0, 0.0, -1, "float32") + verify_one_hot((2, 2), 5, 2, -2, 0, "int32") + verify_one_hot((2, 2), 5, 0.5, -0.5, 1, "float32") + verify_one_hot((3, 2, 4, 5), 6, 1, 0, 1, "int32") + verify_one_hot((3, 2, 4, 5), 6, 1.0, 0.0, 0, "float32") if __name__ == "__main__": test_strided_slice() @@ -793,3 +825,4 @@ if __name__ == "__main__": test_sequence_mask() test_ndarray_size() test_where_fusion() + test_one_hot() -- 2.7.4