[TOPI][Relay][TensorFlow] Add OneHot operator (#3781)
authorJon Soifer <soiferj@gmail.com>
Thu, 22 Aug 2019 20:45:45 +0000 (13:45 -0700)
committerHaichen Shen <shenhaichen@gmail.com>
Thu, 22 Aug 2019 20:45:45 +0000 (13:45 -0700)
* Add one-hot to Relay

* topi implementation

* Working

* add topi test

* Add TF test

* Fix check

* fix linting issues

* fix documentation

* Fix documentation

* Add support for on_value, off_value, axis, dtype

* Add full support for axis

* Fix compute and update test_forward

* Move on_value and off_value to inputs

* Add topi test

* Update tests

* Update docs

* Fix style

* re-enable tests

* Add one_hot to mxnet converter

17 files changed:
docs/api/python/topi.rst
docs/langref/relay_op.rst
include/tvm/relay/attrs/transform.h
python/tvm/relay/frontend/mxnet.py
python/tvm/relay/frontend/tensorflow.py
python/tvm/relay/op/_transform.py
python/tvm/relay/op/transform.py
src/relay/op/tensor/transform.cc
tests/python/frontend/mxnet/test_forward.py
tests/python/frontend/tensorflow/test_forward.py
tests/python/relay/test_op_level10.py
topi/include/topi/transform.h
topi/python/topi/testing/__init__.py
topi/python/topi/testing/one_hot.py [new file with mode: 0644]
topi/python/topi/transform.py
topi/src/topi.cc
topi/tests/python/test_topi_transform.py

index 8f59e08..123c1d0 100644 (file)
@@ -104,6 +104,7 @@ List of operators
    topi.argsort
    topi.topk
    topi.sequence_mask
+   topi.one_hot
 
 
 List of schedules
@@ -173,6 +174,7 @@ topi
 .. autofunction:: topi.argsort
 .. autofunction:: topi.topk
 .. autofunction:: topi.sequence_mask
+.. autofunction:: topi.one_hot
 
 topi.nn
 ~~~~~~~
index 6950ecc..4fad352 100644 (file)
@@ -200,6 +200,7 @@ This level support backpropagation of broadcast operators. It is temporary.
    tvm.relay.nn.batch_matmul
    tvm.relay.contrib.adaptive_max_pool2d
    tvm.relay.contrib.adaptive_avg_pool2d
+   tvm.relay.one_hot
 
 
 **Level 11: Dialect Operators**
@@ -350,6 +351,7 @@ Level 10 Definitions
 .. autofunction:: tvm.relay.nn.batch_matmul
 .. autofunction:: tvm.relay.contrib.adaptive_max_pool2d
 .. autofunction:: tvm.relay.contrib.adaptive_avg_pool2d
+.. autofunction:: tvm.relay.one_hot
 
 
 Level 11 Definitions
index e43fd5f..5265687 100644 (file)
@@ -298,6 +298,22 @@ struct NdarraySizeAttrs : public tvm::AttrsNode<NdarraySizeAttrs> {
   }
 };
 
+/*! \brief Attributes used in one-hot operator */
+struct OneHotAttrs : public tvm::AttrsNode<OneHotAttrs> {
+  int depth;
+  int axis;
+  DataType dtype;
+
+  TVM_DECLARE_ATTRS(OneHotAttrs, "relay.attrs.OneHotAttrs") {
+    TVM_ATTR_FIELD(depth).set_default(1)
+        .describe("Depth of the one hot dimension.");
+    TVM_ATTR_FIELD(axis).set_default(-1)
+        .describe("Axis to fill.");
+    TVM_ATTR_FIELD(dtype).set_default(NullValue<DataType>())
+        .describe("Output data type.");
+  }
+};  // struct OneHotAttrs
+
 }  // namespace relay
 }  // namespace tvm
 #endif  // TVM_RELAY_ATTRS_TRANSFORM_H_
index 9d82671..36c4fb8 100644 (file)
@@ -896,6 +896,14 @@ def _mx_rnn_layer(inputs, attrs):
             ret.append(_op.stack(inputs, axis=0))
     return ret
 
+def _mx_one_hot(inputs, attrs):
+    indices = inputs[0].astype('int32')
+    depth = attrs.get_int('depth', 0)
+    dtype = attrs.get_str('dtype', 'int32')
+    on_value = tvm.relay.const(attrs.get_float('on_value', 1.0), dtype)
+    off_value = tvm.relay.const(attrs.get_float('off_value', 0.0), dtype)
+    return _op.one_hot(indices, on_value, off_value, depth, -1, dtype)
+
 
 # Note: due to attribute conversion constraint
 # ops in the identity set must be attribute free
@@ -1041,6 +1049,7 @@ _convert_map = {
     "LinearRegressionOutput" : _mx_linear_regression_output,
     "smooth_l1"     : _mx_smooth_l1,
     "_contrib_div_sqrt_dim": _mx_contrib_div_sqrt_dim,
+    "one_hot"           : _mx_one_hot,
     # vision
     "_contrib_BilinearResize2D" : _mx_resize,
     "_contrib_MultiBoxPrior" : _mx_multibox_prior,
index 54724b5..6c2ad0d 100644 (file)
@@ -1212,6 +1212,21 @@ def _log1p():
         return get_relay_op('log')(add_out)
     return _impl
 
+def _one_hot():
+    def _impl(inputs, attr, params):
+        depth = int(_get_num_param(params, inputs[1]))
+        dtype = attr['T'].name
+
+        on_value = _get_num_param(params, inputs[2])
+        off_value = _get_num_param(params, inputs[3])
+        new_inputs = [inputs[0], \
+                      tvm.relay.const(on_value, dtype), \
+                      tvm.relay.const(off_value, dtype)]
+        return AttrCvt('one_hot',
+                       ignores=['TI'],
+                       extras={'depth' : depth, 'dtype' : dtype})(new_inputs, attr)
+    return _impl
+
 # compatible operators that do NOT require any conversion.
 _identity_list = []
 
@@ -1284,6 +1299,7 @@ _convert_map = {
     'Mul'                               : _elemwise('multiply'),
     'Neg'                               : AttrCvt('negative'),
     'NotEqual'                          : _broadcast('not_equal'),
+    'OneHot'                            : _one_hot(),
     'Pack'                              : _pack(),
     'Pad'                               : _pad('Pad'),
     'PadV2'                             : _pad('PadV2'),
index 51e7615..a4c9375 100644 (file)
@@ -52,6 +52,7 @@ _reg.register_schedule("concatenate", schedule_concatenate)
 _reg.register_schedule("_contrib_reverse_reshape", schedule_injective)
 _reg.register_schedule("gather_nd", schedule_injective)
 _reg.register_schedule("sequence_mask", schedule_injective)
+_reg.register_schedule("one_hot", schedule_injective)
 
 
 # layout_transform
index 5d8d280..38ce653 100644 (file)
@@ -748,3 +748,47 @@ def sequence_mask(data, valid_length, mask_value=0, axis=0):
              [[  0.1,  0.1,  0.1], [  16.,  17.,  18.]]]
     """
     return _make.sequence_mask(data, valid_length, mask_value, axis)
+
+def one_hot(indices, on_value, off_value, depth, axis, dtype):
+    """
+    Returns a one-hot tensor where the locations repsented by indices take value on_value,
+    other locations take value off_value.
+    Final dimension is <indices outer dimensions> x depth x <indices inner dimensions>.
+
+    Parameters
+    ----------
+    indices : relay.Expr
+        Locations to set to on_value.
+
+    on_value : relay.Expr
+        Value to fill at indices.
+
+    off_value : relay.Expr
+        Value to fill at all other positions besides indices.
+
+    depth : int
+        Depth of the one-hot dimension.
+
+    axis : int
+        Axis to fill.
+
+    dtype : str
+        Data type of the output tensor.
+
+    Returns
+    -------
+    ret : relay.Expr
+        The one-hot tensor.
+
+    Examples
+    --------
+    .. code-block:: python
+
+        indices = [0, 1, 2]
+
+        relay.one_hot(indices, 3) =
+            [[1, 0, 0],
+             [0, 1, 0],
+             [0, 0, 1]]
+    """
+    return _make.one_hot(indices, on_value, off_value, depth, axis, dtype)
index 03a92b3..b39c282 100644 (file)
@@ -2482,5 +2482,94 @@ Examples::
 .set_attr<FTVMCompute>("FTVMCompute", SequenceMaskCompute)
 .set_attr<TOpPattern>("TOpPattern", kInjective);
 
+// relay.one_hot
+TVM_REGISTER_NODE_TYPE(OneHotAttrs);
+
+bool OneHotRel(const Array<Type>& types,
+               int num_inputs,
+               const Attrs& attrs,
+               const TypeReporter& reporter) {
+  // `types` contains: [indices, on_value, off_value, result]
+  CHECK_EQ(types.size(), 4);
+  const auto* indices = types[0].as<TensorTypeNode>();
+  CHECK(indices);
+
+  const auto param = attrs.as<OneHotAttrs>();
+  CHECK_GT(param->depth, 0);
+
+  Array<IndexExpr> oshape;
+  int ndim = indices->shape.size() + 1;
+  int indices_index = 0;
+  int true_axis = (param->axis == -1) ? indices->shape.size() : param->axis;
+  for (int i = 0; i < ndim; i++) {
+    if (i == true_axis) {
+      oshape.push_back(Integer(param->depth));
+    } else {
+      oshape.push_back(indices->shape[indices_index++]);
+    }
+  }
+
+  reporter->Assign(types[3], TensorTypeNode::make(oshape, param->dtype));
+  return true;
+}
+
+Array<Tensor> OneHotCompute(const Attrs& attrs,
+                            const Array<Tensor>& inputs,
+                            const Type& out_type,
+                            const Target& target) {
+  const auto* param = attrs.as<OneHotAttrs>();
+  CHECK(param != nullptr);
+  return Array<Tensor> {
+    topi::one_hot(inputs[0],
+                  inputs[1](),
+                  inputs[2](),
+                  param->depth,
+                  param->axis,
+                  param->dtype)
+  };
+}
+
+Expr MakeOneHot(Expr indices,
+                Expr on_value,
+                Expr off_value,
+                int depth,
+                int axis,
+                DataType dtype) {
+  auto attrs = make_node<OneHotAttrs>();
+  attrs->depth = std::move(depth);
+  attrs->axis = axis;
+  attrs->dtype = dtype;
+  static const Op& op = Op::Get("one_hot");
+  return CallNode::make(op, {indices, on_value, off_value}, Attrs(attrs), {});
+}
+
+TVM_REGISTER_API("relay.op._make.one_hot")
+.set_body_typed(MakeOneHot);
+
+RELAY_REGISTER_OP("one_hot")
+.describe(R"code(Returns a one-hot tensor where the locations repsented by indices take value 1, 
+    other locations take value 0. Final dimension is <indices dimensions> x depth.
+
+    **indices** Locations to set to 1.
+
+    **on_value** Value to fill at indices.
+
+    **off_value** Value to fill at all other positions besides indices.
+
+    **depth** Depth of the one-hot dimension.
+    
+    **axis** Axis to fill.
+    
+    **dtype**)code" TVM_ADD_FILELINE)
+.set_attrs_type_key("relay.attrs.OneHotAttrs")
+.set_num_inputs(3)
+.add_argument("indices", "Tensor", "Locations to set to on_value.")
+.add_argument("on_value", "Expr", "Value to fill at indices.")
+.add_argument("off_value", "Expr", "Value to fill at all other positions besides indices.")
+.set_support_level(10)
+.add_type_rel("OneHot", OneHotRel)
+.set_attr<FTVMCompute>("FTVMCompute", OneHotCompute)
+.set_attr<TOpPattern>("TOpPattern", kOutEWiseFusable);
+
 }  // namespace relay
 }  // namespace tvm
index a4a514e..90b425f 100644 (file)
@@ -778,6 +778,25 @@ def test_forward_layer_norm():
     verify((2, 5), axis=0)
     verify((2, 5, 6))
 
+def test_forward_one_hot():
+    def verify(indices_shape, depth, on_value, off_value, dtype):
+        x = np.random.randint(0, 5, size=indices_shape)
+        ref_res = mx.nd.one_hot(mx.nd.array(x), depth, on_value, off_value, dtype)
+        mx_sym = mx.sym.one_hot(mx.sym.var("x"), depth, on_value, off_value, dtype)
+        shape_dict = {"x": x.shape}
+        mod, _ = relay.frontend.from_mxnet(mx_sym, shape_dict)
+        for target, ctx in ctx_list():
+            for kind in ["graph", "debug"]:
+                intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target)
+                op_res = intrp.evaluate()(x.astype("float32"))
+                tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy(), rtol=1e-3)
+    verify((3,), 3, 1, 0, "int32")
+    verify((3,), 3, 1.0, 0.0, "float32")
+    verify((2, 2), 5, 2, -2, "int32")
+    verify((2, 2), 5, 0.5, -0.5, "float32")
+    verify((3, 2, 4, 5), 6, 1, 0, "int32")
+    verify((3, 2, 4, 5), 6, 1.0, 0.0, "float32")
+
 if __name__ == '__main__':
     test_forward_mlp()
     test_forward_vgg()
@@ -825,3 +844,4 @@ if __name__ == '__main__':
     test_forward_contrib_div_sqrt_dim()
     test_forward_batch_norm()
     test_forward_layer_norm()
+    test_forward_one_hot()
index eb8e27e..16d7ba5 100644 (file)
@@ -2158,6 +2158,24 @@ def test_placeholder():
         compare_tf_with_tvm([in_data1, in_data2], ['place1:0', 'in2:0'], 'out2:0',
                             init_global_variables=True)
 
+#######################################################################
+# OneHot
+# ----------------------
+def _test_forward_one_hot(indices_shape, depth, on_value, off_value, axis, out_dtype):
+    inp_array1 = np.random.randint(0, 5, size=indices_shape)
+    with tf.Graph().as_default():
+        in1 = tf.placeholder(shape=inp_array1.shape, dtype=inp_array1.dtype)
+        out = tf.one_hot(in1, depth, on_value, off_value, axis, dtype=out_dtype)
+        compare_tf_with_tvm(inp_array1, in1.name, out.name)
+
+def test_forward_one_hot():
+    _test_forward_one_hot((3,), 3, 1, 0, -1, "int32")
+    _test_forward_one_hot((3,), 3, 1.0, 0.0, -1, "float32")
+    _test_forward_one_hot((2, 2), 5, 2, -2, 0, "int32")
+    _test_forward_one_hot((2, 2), 5, 0.5, -0.5, 1, "float32")
+    _test_forward_one_hot((3, 2, 4, 5), 6, 1, 0, 1, "int32")
+    _test_forward_one_hot((3, 2, 4, 5), 6, 1.0, 0.0, 0, "float32")
+
 
 #######################################################################
 # Main
@@ -2193,6 +2211,7 @@ if __name__ == '__main__':
     test_forward_right_shift()
     test_forward_left_shift()
     test_forward_truncatemod()
+    test_forward_one_hot()
 
     # Activations
     test_forward_sigmoid()
index f3520f3..e828fa3 100644 (file)
@@ -296,6 +296,45 @@ def test_sequence_mask():
     _verify((2, 3, 5, 3), 0.0, 0, 'float32', 'int64')
     _verify((5, 8, 3), 0.1, 1, 'float64', 'float32')
 
+def test_one_hot():
+    def _get_oshape(indices_shape, depth, axis):
+        oshape = []
+        true_axis = len(indices_shape) if axis == -1 else axis
+        ndim = len(indices_shape) + 1
+        indices_index = 0
+        for i in range(0, ndim):
+            if i == true_axis:
+                oshape.append(depth)
+            else:
+                oshape.append(indices_shape[indices_index])
+                indices_index += 1
+        
+        return oshape
+
+    def _verify(indices_shape, depth, on_value, off_value, axis, dtype):
+        indices = relay.var("indices", relay.TensorType(indices_shape, "int32"))
+        on_value_const = relay.const(on_value)
+        off_value_const = relay.const(off_value)
+        out = relay.one_hot(indices, on_value_const, off_value_const, depth, axis, dtype)
+        checked = run_infer_type(out)
+        assert checked.checked_type == relay.ty.TensorType(_get_oshape(indices_shape, depth, axis), dtype)
+        func = relay.Function([indices], out)
+        indices_np = np.random.randint(0, depth, size=indices_shape).astype("int32")
+        out_np = topi.testing.one_hot(indices_np, on_value, off_value, depth, axis, dtype)
+
+        for target, ctx in ctx_list():
+            for kind in ["graph", "debug"]:
+                intrp = relay.create_executor(kind, ctx=ctx, target=target)
+                out_relay = intrp.evaluate(func)(indices_np)
+                tvm.testing.assert_allclose(out_relay.asnumpy(), out_np)
+    
+    _verify((3,), 3, 1, 0, -1, "int32")
+    _verify((3,), 3, 1.0, 0.0, -1, "float32")
+    _verify((2, 2), 5, 2, -2, 0, "int32")
+    _verify((2, 2), 5, 0.5, -0.5, 1, "float32")
+    _verify((3, 2, 4, 5), 6, 1, 0, 1, "int32")
+    _verify((3, 2, 4, 5), 6, 1.0, 0.0, 0, "float32")
+
 if __name__ == "__main__":
     test_adaptive_pool2d()
     test_collapse_sum_like()
@@ -306,4 +345,5 @@ if __name__ == "__main__":
     test_shape_of()
     test_sequence_mask()
     test_ndarray_size()
+    test_one_hot()
 
index e8a65b0..1622b20 100644 (file)
@@ -1247,5 +1247,55 @@ inline Tensor ndarray_size(const Tensor& src,
   }, name, tag);
 }
 
+/*!
+ * \brief Returns a one-hot tensor where the locations repsented by indices take value on_value, 
+    other locations take value off_value.
+ * \param indices locations to set to on_value.
+ * \param on_value value that locations represented by indices take on.
+ * \param off_value value that other locations take on.
+ * \param depth depth of the one-hot dimension.
+ * \param axis axis to fill.
+ * \param dtype data type of the output tensor.
+ * \param name output tensor name.
+ * \param tag output tensor tag.
+ * \return one-hot tensor.
+ */
+inline Tensor one_hot(const Tensor& indices,
+                      const Expr on_value,
+                      const Expr off_value,
+                      int depth,
+                      int axis,
+                      const Type& dtype,
+                      const std::string name = "T_one_hot",
+                      const std::string tag = kInjective) {
+  Array<Expr> oshape;
+  int ndim = indices->shape.size() + 1;
+  int indices_index = 0;
+  int true_axis = (axis == -1) ? indices->shape.size() : axis;
+  for (int i = 0; i < ndim; i++) {
+    if (i == true_axis) {
+      oshape.push_back(Integer(depth));
+    } else {
+      oshape.push_back(indices->shape[indices_index++]);
+    }
+  }
+
+  Expr on_value_cast = cast(dtype, on_value);
+  Expr off_value_cast = cast(dtype, off_value);
+  return compute(oshape, [&](const Array<Var>& iter_vars) {
+    Array<Var> indices_indices;
+    for (size_t i = 0; i < iter_vars.size(); i++) {
+      if (static_cast<int>(i) == true_axis) {
+        continue;
+      }
+
+      indices_indices.push_back(iter_vars[i]);
+    }
+
+    auto idx = iter_vars[true_axis];
+    return ir::Select::make(indices(indices_indices) == idx, on_value_cast, off_value_cast);
+  }, name, tag);
+}
+
 }  // namespace topi
 #endif  // TOPI_TRANSFORM_H_
index 57a9c26..d607c28 100644 (file)
@@ -25,3 +25,4 @@ from .batch_matmul import batch_matmul
 from .slice_axis_python import slice_axis_python
 from .sequence_mask_python import sequence_mask
 from .pool_grad_python import pool_grad_nchw
+from .one_hot import one_hot
diff --git a/topi/python/topi/testing/one_hot.py b/topi/python/topi/testing/one_hot.py
new file mode 100644 (file)
index 0000000..99c52be
--- /dev/null
@@ -0,0 +1,79 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name
+"""OneHot in python"""
+import numpy as np
+
+def one_hot(indices, on_value, off_value, depth, axis, dtype):
+    """one_hot operator implemented in numpy.
+
+    Returns a one-hot tensor where the locations repsented by indices take value on_value,
+    other locations take value off_value.
+    Final dimension is <indices outer dimensions> x depth x <indices inner dimensions>.
+
+    Parameters
+    ----------
+    indices : numpy.ndarray
+        Locations to set to on_value.
+
+    on_value : int/float
+        Value to fill at indices.
+
+    off_value : int/float
+        Value to fill at all other positions besides indices.
+
+    depth : int
+        Depth of the one-hot dimension.
+
+    axis : int
+        Axis to fill.
+
+    dtype : str
+        Data type of the output tensor.
+
+    Returns
+    -------
+    ret : relay.Expr
+        The one-hot tensor.
+    """
+    oshape = []
+    true_axis = len(indices.shape) if axis == -1 else axis
+    ndim = len(indices.shape) + 1
+    indices_index = 0
+    for i in range(0, ndim):
+        if i == true_axis:
+            oshape.append(depth)
+        else:
+            oshape.append(indices.shape[indices_index])
+            indices_index += 1
+
+    out = np.empty(oshape)
+    output_indices = [index for index in np.ndindex(out.shape)]
+    for output_index in output_indices:
+        indices_indices = []
+        for i, out_idx in enumerate(output_index):
+            if i == true_axis:
+                continue
+            indices_indices.append(out_idx)
+
+        index = output_index[true_axis]
+        if indices[tuple(indices_indices)] == index:
+            out[output_index] = on_value
+        else:
+            out[output_index] = off_value
+
+    return out.astype(dtype)
index 5e87933..3c7fc9c 100644 (file)
@@ -518,3 +518,47 @@ def where(condition, x, y):
         A Tensor selected from x or y depending on condition.
     """
     return cpp.where(condition, x, y)
+
+def one_hot(indices, on_value, off_value, depth, axis, dtype):
+    """
+    Returns a one-hot tensor where the locations repsented by indices take value on_value,
+    other locations take value off_value.
+    Final dimension is <indices outer dimensions> x depth x <indices inner dimensions>.
+
+    Parameters
+    ----------
+    indices : tvm.Tensor
+        Locations to set to on_value.
+
+    on_value : tvm.Tensor
+        Value to fill at indices.
+
+    off_value : tvm.Tensor
+        Value to fill at all other positions besides indices.
+
+    depth : int
+        Depth of the one-hot dimension.
+
+    axis : int
+        Axis to fill.
+
+    dtype : relay.DataType
+        Data type of the output tensor.
+
+    Returns
+    -------
+    ret : relay.Expr
+        The one-hot tensor.
+
+    Examples
+    --------
+    .. code-block:: python
+
+        indices = [0, 1, 2]
+
+        relay.one_hot(indices, 3) =
+            [[1, 0, 0],
+             [0, 1, 0],
+             [0, 0, 1]]
+    """
+    return cpp.one_hot(indices, on_value, off_value, depth, axis, dtype)
index 799b660..7e47b62 100644 (file)
@@ -417,6 +417,14 @@ TVM_REGISTER_GLOBAL("topi.strided_slice")
   *rv = strided_slice(args[0], args[1], args[2], args[3]);
   });
 
+TVM_REGISTER_GLOBAL("topi.one_hot")
+.set_body([](TVMArgs args, TVMRetValue *rv) {
+  int depth = args[3];
+  int axis = args[4];
+  DataType dtype = args[5];
+  *rv = one_hot(args[0], args[1], args[2], depth, axis, dtype);
+  });
+
 /* Ops from nn/upsampling.h */
 TVM_REGISTER_GLOBAL("topi.nn.upsampling")
 .set_body([](TVMArgs args, TVMRetValue *rv) {
index 64305b4..b1aa20e 100644 (file)
@@ -473,6 +473,31 @@ def verify_where(in_shape):
     for device in get_all_backend():
         check_device(device)
 
+def verify_one_hot(indices_shape, depth, on_value, off_value, axis, dtype):
+    indices = tvm.placeholder(shape=indices_shape, name="indices", dtype="int32")
+    on_value_const = tvm.const(on_value, dtype)
+    off_value_const = tvm.const(off_value, dtype)
+    one_hot_result = topi.transform.one_hot(indices, on_value_const, off_value_const, depth, axis, dtype)
+    def check_device(device):
+        ctx = tvm.context(device, 0)
+        if not ctx.exist:
+            print("Skip because %s is not enabled" % device)
+            return
+        print("Running on target: %s" % device)
+        with tvm.target.create(device):
+            s = topi.generic.schedule_injective(one_hot_result)
+        fn = tvm.build(s, [indices, one_hot_result], device, name="one_hot")
+        indices_npy = np.random.randint(0, depth, size=indices_shape).astype(indices.dtype)
+        out_npy = topi.testing.one_hot(indices_npy, on_value, off_value, depth, axis, dtype)
+        indices_nd = tvm.nd.array(indices_npy, ctx)
+        out_nd = tvm.nd.array(np.empty(out_npy.shape).astype(one_hot_result.dtype), ctx)
+        fn(indices_nd, out_nd)
+        out_topi = out_nd.asnumpy()
+        tvm.testing.assert_allclose(out_topi, out_npy)
+
+    for device in get_all_backend():
+        check_device(device)
+
 def test_strided_slice():
     verify_strided_slice((3, 4, 3), [0, 0, 0], [4, -5, 4], [1, -1, 2])
     verify_strided_slice((3, 4, 3), [1, 1, 0], [4, 4, 3], [2, 1, 1])
@@ -770,6 +795,13 @@ def test_where_fusion():
     for backend in get_all_backend():
         check_device(backend)
 
+def test_one_hot():
+    verify_one_hot((3,), 3, 1, 0, -1, "int32")
+    verify_one_hot((3,), 3, 1.0, 0.0, -1, "float32")
+    verify_one_hot((2, 2), 5, 2, -2, 0, "int32")
+    verify_one_hot((2, 2), 5, 0.5, -0.5, 1, "float32")
+    verify_one_hot((3, 2, 4, 5), 6, 1, 0, 1, "int32")
+    verify_one_hot((3, 2, 4, 5), 6, 1.0, 0.0, 0, "float32")
 
 if __name__ == "__main__":
     test_strided_slice()
@@ -793,3 +825,4 @@ if __name__ == "__main__":
     test_sequence_mask()
     test_ndarray_size()
     test_where_fusion()
+    test_one_hot()