[Relay][Dynamic] OneHot operation (#6209)
authorMatthew Brookhart <mbrookhart@octoml.ai>
Thu, 6 Aug 2020 15:46:58 +0000 (08:46 -0700)
committerGitHub <noreply@github.com>
Thu, 6 Aug 2020 15:46:58 +0000 (08:46 -0700)
* Dynamic OneHot Op

* refactor dynamic_to_static

* add onehot to dynamic_to_static pass
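
A quick usage sketch of what this enables (mirroring the new tests below): relay.one_hot now accepts `depth` as a relay.Expr, in which case it lowers to the new dyn.one_hot operator; when depth is instead a constant, the dynamic_to_static pass folds the call back to the static one_hot.

    from tvm import relay

    indices = relay.var("indices", relay.TensorType((3,), "int32"))
    depth = relay.var("depth", relay.TensorType((), "int32"))  # dynamic depth -> dyn.one_hot
    out = relay.one_hot(indices, relay.const(1), relay.const(0), depth, -1, "int32")
    func = relay.Function([indices, depth], out)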

include/tvm/topi/transform.h
python/tvm/relay/op/dyn/_transform.py
python/tvm/relay/op/transform.py
src/relay/op/dyn/tensor/transform.cc
src/relay/op/make_op.h
src/relay/transforms/dynamic_to_static.cc
tests/python/relay/dyn/test_dynamic_op_level10.py
tests/python/relay/test_pass_dynamic_to_static.py

index cd19436..19b2ef4 100644 (file)
@@ -1421,22 +1421,25 @@ inline Tensor ndarray_size(const Tensor& src, const DataType& dtype,
  * \param depth depth of the one-hot dimension.
  * \param axis axis to fill.
  * \param dtype data type of the output tensor.
+ * \param oshape shape of the output tensor.
  * \param name output tensor name.
  * \param tag output tensor tag.
  * \return one-hot tensor.
  */
 inline Tensor one_hot(const Tensor& indices, const PrimExpr on_value, const PrimExpr off_value,
                       int depth, int axis, const DataType& dtype,
+                      Array<PrimExpr> oshape = Array<PrimExpr>(),
                       const std::string name = "T_one_hot", const std::string tag = kInjective) {
-  Array<PrimExpr> oshape;
-  int ndim = indices->shape.size() + 1;
-  int indices_index = 0;
   int true_axis = (axis == -1) ? indices->shape.size() : axis;
-  for (int i = 0; i < ndim; i++) {
-    if (i == true_axis) {
-      oshape.push_back(Integer(depth));
-    } else {
-      oshape.push_back(indices->shape[indices_index++]);
+  if (oshape.size() == 0) {
+    int ndim = indices->shape.size() + 1;
+    int indices_index = 0;
+    for (int i = 0; i < ndim; i++) {
+      if (i == true_axis) {
+        oshape.push_back(Integer(depth));
+      } else {
+        oshape.push_back(indices->shape[indices_index++]);
+      }
     }
   }
 
index e2704bc..3a80f5a 100644 (file)
@@ -25,11 +25,13 @@ from .. import op as _reg
 _reg.register_broadcast_schedule("dyn.broadcast_to")
 _reg.register_injective_schedule("dyn.reshape")
 _reg.register_broadcast_schedule("dyn.tile")
+_reg.register_injective_schedule("dyn.one_hot")
+
 
 @script
 def _reshape_shape_func_input_data(data, newshape, ndim):
-    out = output_tensor((ndim,), "int64")
-    data_shape = allocate((len(data.shape),), "int64")
+    out = output_tensor((ndim, ), "int64")
+    data_shape = allocate((len(data.shape), ), "int64")
     for x in const_range(len(data.shape)):
         data_shape[x] = int64(data.shape[x])
     src_idx = 0
@@ -59,7 +61,7 @@ def _reshape_shape_func_input_data(data, newshape, ndim):
         elif newshape[i] == -3:
             assert data_shape.shape[0] - src_idx > 1, \
                 "Not enough dims in input shape for -3"
-            out[dst_idx] = data_shape[src_idx] * data_shape[src_idx+1]
+            out[dst_idx] = data_shape[src_idx] * data_shape[src_idx + 1]
             src_idx += 2
             dst_idx += 1
         elif newshape[i] == -4:
@@ -82,6 +84,7 @@ def _reshape_shape_func_input_data(data, newshape, ndim):
             out[infer_idx] = old_size // new_size
     return out
 
+
 @_reg.register_shape_func("dyn.reshape", True)
 def dynamic_reshape_shape_func(attrs, inputs, out_ndims):
     return [_reshape_shape_func_input_data(*inputs, out_ndims[0])]
@@ -89,7 +92,7 @@ def dynamic_reshape_shape_func(attrs, inputs, out_ndims):
 
 @script
 def _tile_shape_func(data, reps, ndim, tndim, rndim):
-    out = output_tensor((tndim,), "int64")
+    out = output_tensor((tndim, ), "int64")
 
     if ndim == rndim:
         for i in const_range(tndim):
@@ -120,5 +123,25 @@ def tile_shape_func(attrs, inputs, _):
     ndim = len(inputs[0].shape)
     rndim = inputs[1].shape[0].value
     tndim = ndim if ndim > rndim else rndim
-    return [_tile_shape_func(inputs[0], reps, convert(ndim),
-                             convert(tndim), convert(rndim))]
+    return [_tile_shape_func(inputs[0], reps, convert(ndim), convert(tndim), convert(rndim))]
+
+
+@script
+def _onehot_shape_func(dshape, k, axis):
+    ndim = len(dshape) + 1
+    out = output_tensor((ndim, ), "int64")
+    for i in const_range(axis):
+        out[i] = int64(dshape[i])
+    out[axis] = int64(k[0])
+    for j in const_range(axis + 1, ndim):
+        out[j] = int64(dshape[j - 1])
+    return out
+
+
+@_reg.register_shape_func("dyn.one_hot", True)
+def one_hot_shape_func(attrs, inputs, _):
+    """
+    Shape function for dyn.one_hot op.
+    """
+    axis = len(inputs[0].shape) if attrs.axis == -1 else attrs.axis
+    return [_onehot_shape_func(inputs[0].shape, inputs[3], convert(axis))]
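
A plain-Python sketch (helper name is illustrative, not part of the patch) of the shape that _onehot_shape_func computes: the output shape is the indices shape with depth inserted at axis.

    def onehot_out_shape(dshape, depth, axis):
        axis = len(dshape) if axis == -1 else axis
        return dshape[:axis] + [depth] + dshape[axis:]

    assert onehot_out_shape([3, 2, 4, 5], 6, 1) == [3, 6, 2, 4, 5]
    assert onehot_out_shape([3], 3, -1) == [3, 3]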
index 6f23af2..5e5b867 100644 (file)
@@ -148,6 +148,7 @@ def squeeze(data, axis=None):
     """
     return _make.squeeze(data, axis)
 
+
 def reshape(data, newshape):
     """Reshape the input array.
 
@@ -228,6 +229,7 @@ def reshape(data, newshape):
         newshape = tempshape
     return _make.reshape(data, list(newshape))
 
+
 def argwhere(condition):
     """Find the indices of elements of a tensor that are
     non-zero.
@@ -251,6 +253,7 @@ def argwhere(condition):
     """
     return _make.argwhere(condition)
 
+
 def scatter(data, indices, updates, axis):
     """Update data at positions defined by indices with values in updates
 
@@ -275,6 +278,7 @@ def scatter(data, indices, updates, axis):
     """
     return _make.scatter(data, indices, updates, axis)
 
+
 def scatter_add(data, indices, updates, axis):
     """Update data by adding values in updates at positions defined by indices
 
@@ -299,6 +303,7 @@ def scatter_add(data, indices, updates, axis):
     """
     return _make.scatter_add(data, indices, updates, axis)
 
+
 def reshape_like(data, shape_like):
     """Reshapes the input array by the size of another array.
     For an input array with shape ``(d1, d2, ..., dk)``, `reshape_like` operation reshapes
@@ -442,6 +447,7 @@ def arange(start, stop=None, step=None, dtype="float32"):
 
     return _make.arange(start, stop, step, dtype)
 
+
 def meshgrid(data, indexing="ij"):
     """Create coordinate matrices from coordinate vectors.
 
@@ -482,6 +488,7 @@ def meshgrid(data, indexing="ij"):
     ret_size = len(data)
     return TupleWrapper(_make.meshgrid(Tuple(data), indexing), ret_size)
 
+
 def repeat(data, repeats, axis):
     """Repeats elements of an array.
     By default, repeat flattens the input array into 1-D and then repeats the elements.
@@ -668,6 +675,7 @@ def where(condition, x, y):
     """
     return _make.where(condition, x, y)
 
+
 def broadcast_to(data, shape):
     """Return a scalar value array with the same type, broadcast to
     the provided shape.
@@ -693,6 +701,7 @@ def broadcast_to(data, shape):
         shape = list(shape)
     return _make.broadcast_to(data, shape)
 
+
 def broadcast_to_like(data, broadcast_type):
     """Return a scalar value array with the same shape and type as the input array.
 
@@ -1053,6 +1062,7 @@ def sequence_mask(data, valid_length, mask_value=0, axis=0):
     """
     return _make.sequence_mask(data, valid_length, mask_value, axis)
 
+
 def one_hot(indices, on_value, off_value, depth, axis, dtype):
     """
     Returns a one-hot tensor where the locations represented by indices take value on_value,
@@ -1070,7 +1080,7 @@ def one_hot(indices, on_value, off_value, depth, axis, dtype):
     off_value : relay.Expr
         Value to fill at all other positions besides indices.
 
-    depth : int
+    depth : int or relay.Expr
         Depth of the one-hot dimension.
 
     axis : int
@@ -1095,6 +1105,8 @@ def one_hot(indices, on_value, off_value, depth, axis, dtype):
              [0, 1, 0],
              [0, 0, 1]]
     """
+    if isinstance(depth, Expr):
+        return _dyn_make.one_hot(indices, on_value, off_value, depth, axis, dtype)
     return _make.one_hot(indices, on_value, off_value, depth, axis, dtype)
 
 
@@ -1120,6 +1132,7 @@ def unravel_index(indices, shape):
 
     return _make.unravel_index(indices, shape)
 
+
 def sparse_to_dense(sparse_indices, output_shape, sparse_values, default_value=0):
     """Converts a sparse representation into a dense tensor.
 
index 2bb87ac..d2d6d69 100644 (file)
@@ -304,6 +304,76 @@ RELAY_REGISTER_OP("dyn.ones")
     .set_support_level(3)
     .add_type_rel("DynamicInitOp", InitOpRel);
 
+bool OneHotRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
+               const TypeReporter& reporter) {
+  // `types` contains: [indices, on_value, off_value, depth, result]
+  CHECK_EQ(types.size(), 5);
+  const auto* indices = types[0].as<TensorTypeNode>();
+  CHECK(indices);
+
+  const auto param = attrs.as<OneHotAttrs>();
+
+  Array<IndexExpr> oshape;
+  int ndim = indices->shape.size() + 1;
+  int indices_index = 0;
+  int true_axis = (param->axis == -1) ? indices->shape.size() : param->axis;
+  for (int i = 0; i < ndim; i++) {
+    if (i == true_axis) {
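+      // The one-hot axis depends on the runtime depth input, so its extent is unknown (Any).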
+      oshape.push_back(Any());
+    } else {
+      oshape.push_back(indices->shape[indices_index++]);
+    }
+  }
+
+  reporter->Assign(types[4], TensorType(oshape, param->dtype));
+  return true;
+}
+
+Array<te::Tensor> OneHotCompute(const Attrs& attrs, const Array<te::Tensor>& inputs,
+                                const Type& out_type) {
+  const auto* param = attrs.as<OneHotAttrs>();
+  CHECK(param != nullptr);
+  const auto* out_ttype = out_type.as<TensorTypeNode>();
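+  // depth is a runtime input here, so pass depth = -1 and supply the inferred dynamic output shape instead.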
+  return Array<te::Tensor>{topi::one_hot(inputs[0], inputs[1](), inputs[2](), -1, param->axis,
+                                         param->dtype, out_ttype->shape)};
+}
+
+Expr MakeOneHot(Expr indices, Expr on_value, Expr off_value, Expr depth, int axis, DataType dtype) {
+  auto attrs = make_object<OneHotAttrs>();
+  attrs->axis = axis;
+  attrs->dtype = dtype;
+  static const Op& op = Op::Get("dyn.one_hot");
+  return Call(op, {indices, on_value, off_value, depth}, Attrs(attrs), {});
+}
+
+TVM_REGISTER_GLOBAL("relay.op.dyn._make.one_hot").set_body_typed(MakeOneHot);
+
+RELAY_REGISTER_OP("dyn.one_hot")
+    .describe(R"code(Returns a one-hot tensor where the locations represented by indices take value on_value,
+    other locations take value off_value. Final dimension is <indices dimensions> x depth.
+
+    **indices** Locations to set to on_value.
+
+    **on_value** Value to fill at indices.
+
+    **off_value** Value to fill at all other positions besides indices.
+
+    **depth** Depth of the one-hot dimension.
+
+    **axis** Axis to fill.
+
+    **dtype** Output data type.)code" TVM_ADD_FILELINE)
+    .set_attrs_type<OneHotAttrs>()
+    .set_num_inputs(4)
+    .add_argument("indices", "Tensor", "Locations to set to on_value.")
+    .add_argument("on_value", "Expr", "Value to fill at indices.")
+    .add_argument("off_value", "Expr", "Value to fill at all other positions besides indices.")
+    .add_argument("depth", "Expr", "Depth of the one-hot dimension.")
+    .set_support_level(10)
+    .add_type_rel("DynOneHot", OneHotRel)
+    .set_attr<FTVMCompute>("FTVMCompute", OneHotCompute)
+    .set_attr<TOpPattern>("TOpPattern", kOutEWiseFusable);
+
 }  // namespace dyn
 }  // namespace relay
 }  // namespace tvm
index 3b5e9a1..d2c170d 100644 (file)
@@ -78,6 +78,8 @@ Expr MakeVariance(Expr data, Expr mean, Array<Integer> axis, bool keepdims, bool
 
 Expr MakeZeros(Array<Integer> shape, DataType dtype);
 
+Expr MakeOneHot(Expr indices, Expr on_value, Expr off_value, int depth, int axis, DataType dtype);
+
 }  // namespace relay
 }  // namespace tvm
 #endif  // TVM_RELAY_OP_MAKE_OP_H_
index d4de15c..8501ee5 100644 (file)
@@ -33,44 +33,82 @@ namespace relay {
 
 class DynamicToStaticMutator : public MixedModeMutator {
  public:
-  DynamicToStaticMutator() {}
+  DynamicToStaticMutator() {
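+    // Map each dynamic op to a handler that returns an equivalent static call when its
+    // shape-determining arguments are constant, or an undefined Expr to leave the call unchanged.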
+    op_map_ = {
+        {Op::Get("dyn.reshape"),
+         [](const CallNode* call_node) {
+           if (const ConstantNode* shape = call_node->args[1].as<ConstantNode>()) {
+             CHECK_EQ(shape->data->ndim, 1);
+             return MakeReshape(call_node->args[0], ToVector(shape->data));
+           }
+           return Expr(nullptr);
+         }},
+        {Op::Get("dyn.tile"),
+         [](const CallNode* call_node) {
+           if (const ConstantNode* reps = call_node->args[1].as<ConstantNode>()) {
+             CHECK_EQ(reps->data->ndim, 1);
+             return MakeTile(call_node->args[0], ToVector(reps->data));
+           }
+           return Expr(nullptr);
+         }},
+        {Op::Get("dyn.topk"),
+         [](const CallNode* call_node) {
+           if (const ConstantNode* k = call_node->args[1].as<ConstantNode>()) {
+             const TopKAttrs* param = call_node->attrs.as<TopKAttrs>();
+             CHECK(param);
+             return MakeTopK(call_node->args[0], static_cast<int>(ToScalar(k->data, 0)),
+                             param->axis, param->ret_type, param->is_ascend, param->dtype);
+           }
+           return Expr(nullptr);
+         }},
+        {Op::Get("dyn.broadcast_to"),
+         [](const CallNode* call_node) {
+           if (const ConstantNode* shape = call_node->args[1].as<ConstantNode>()) {
+             CHECK_EQ(shape->data->ndim, 1);
+             return MakeBroadCastTo(call_node->args[0], ToVector(shape->data));
+           }
+           return Expr(nullptr);
+         }},
+        {Op::Get("dyn.zeros"),
+         [](const CallNode* call_node) {
+           if (const ConstantNode* shape = call_node->args[0].as<ConstantNode>()) {
+             const InitOpAttrs* param = call_node->attrs.as<InitOpAttrs>();
+             CHECK(param);
+             return MakeZeros(ToVector(shape->data), param->dtype);
+           }
+           return Expr(nullptr);
+         }},
+        {Op::Get("dyn.ones"),
+         [](const CallNode* call_node) {
+           if (const ConstantNode* shape = call_node->args[0].as<ConstantNode>()) {
+             const InitOpAttrs* param = call_node->attrs.as<InitOpAttrs>();
+             CHECK(param);
+             return MakeOnes(ToVector(shape->data), param->dtype);
+           }
+           return Expr(nullptr);
+         }},
+        {Op::Get("dyn.one_hot"),
+         [](const CallNode* call_node) {
+           if (const ConstantNode* depth = call_node->args[3].as<ConstantNode>()) {
+             const OneHotAttrs* param = call_node->attrs.as<OneHotAttrs>();
+             CHECK(param);
+             return MakeOneHot(call_node->args[0], call_node->args[1], call_node->args[2],
+                               static_cast<int>(ToScalar(depth->data, 0)), param->axis,
+                               param->dtype);
+           }
+           return Expr(nullptr);
+         }},
+    };
+  }
 
  private:
   Expr Rewrite_(const CallNode* pre, const Expr& post) override {
-    const CallNode* call_node = post.as<CallNode>();
-    if (call_node->op == Op::Get("dyn.reshape")) {
-      if (const ConstantNode* shape = call_node->args[1].as<ConstantNode>()) {
-        CHECK_EQ(shape->data->ndim, 1);
-        return MakeReshape(call_node->args[0], ToVector(shape->data));
-      }
-    } else if (call_node->op == Op::Get("dyn.tile")) {
-      if (const ConstantNode* reps = call_node->args[1].as<ConstantNode>()) {
-        CHECK_EQ(reps->data->ndim, 1);
-        return MakeTile(call_node->args[0], ToVector(reps->data));
-      }
-    } else if (call_node->op == Op::Get("dyn.topk")) {
-      if (const ConstantNode* k = call_node->args[1].as<ConstantNode>()) {
-        const TopKAttrs* param = call_node->attrs.as<TopKAttrs>();
-        CHECK(param);
-        return MakeTopK(call_node->args[0], static_cast<int>(ToScalar(k->data, 0)), param->axis,
-                        param->ret_type, param->is_ascend, param->dtype);
-      }
-    } else if (call_node->op == Op::Get("dyn.broadcast_to")) {
-      if (const ConstantNode* shape = call_node->args[1].as<ConstantNode>()) {
-        CHECK_EQ(shape->data->ndim, 1);
-        return MakeBroadCastTo(call_node->args[0], ToVector(shape->data));
-      }
-    } else if (call_node->op == Op::Get("dyn.zeros")) {
-      if (const ConstantNode* shape = call_node->args[0].as<ConstantNode>()) {
-        const InitOpAttrs* param = call_node->attrs.as<InitOpAttrs>();
-        CHECK(param);
-        return MakeZeros(ToVector(shape->data), param->dtype);
-      }
-    } else if (call_node->op == Op::Get("dyn.ones")) {
-      if (const ConstantNode* shape = call_node->args[0].as<ConstantNode>()) {
-        const InitOpAttrs* param = call_node->attrs.as<InitOpAttrs>();
-        CHECK(param);
-        return MakeOnes(ToVector(shape->data), param->dtype);
+    if (const CallNode* call_node = post.as<CallNode>()) {
+      if (op_map_.count(call_node->op)) {
+        auto out = op_map_[call_node->op](call_node);
+        if (out.defined()) {
+          return out;
+        }
       }
     }
     return post;
@@ -83,6 +121,8 @@ class DynamicToStaticMutator : public MixedModeMutator {
     }
     return post;
   }
+  std::unordered_map<Expr, std::function<Expr(const CallNode*)>, ObjectPtrHash, ObjectPtrEqual>
+      op_map_;
 };
 
 Expr DynamicToStatic(Function f, IRModule m) {
@@ -90,6 +130,7 @@ Expr DynamicToStatic(Function f, IRModule m) {
   Expr expr = f;
   auto fold_const = transform::FoldConstant();
   auto infer_type = transform::InferType();
+  DynamicToStaticMutator mutator;
   Map<BaseFunc, GlobalVar> vars;
   for (auto kv : m->functions) {
     vars.Set(kv.second, kv.first);
@@ -101,7 +142,7 @@ Expr DynamicToStatic(Function f, IRModule m) {
     // TODO(mbrookhart): Is it possible to run these passes JUST on the current function?
     m = infer_type(m);
     m = fold_const(m);
-    expr = DynamicToStaticMutator().Mutate(m->functions[gv]);
+    expr = mutator.Mutate(m->functions[gv]);
     m->Update(gv, Downcast<BaseFunc>(expr));
     i += 1;
   } while (pre != expr && i < 1000);
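
A minimal sketch of invoking the refactored pass (assumes a Relay function `func` containing dyn.* calls whose shape/depth arguments are constants); each such call is rewritten by the handler registered for its op in op_map_:

    import tvm
    from tvm import relay

    mod = tvm.ir.IRModule.from_expr(func)
    mod = relay.transform.DynamicToStatic()(mod)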
index d9b23a7..95a030f 100644 (file)
@@ -19,36 +19,80 @@ Support level10 operator test cases.
 
 """
 
-
 import numpy as np
 import tvm
 from tvm import relay
 from tvm.relay.testing import ctx_list, run_infer_type
+import tvm.topi.testing
 import random
 
+
 def test_dyn_broadcast_to():
     dtype = 'uint8'
     rank = 3
     shape_type = 'int64'
-    dyn_shape = relay.Var("shape", relay.ty.TensorType((rank,), shape_type))
-    x_shape = (1,)
+    dyn_shape = relay.Var("shape", relay.ty.TensorType((rank, ), shape_type))
+    x_shape = (1, )
     x = relay.Var("x", relay.ty.TensorType(x_shape, dtype))
     z = relay.broadcast_to(x, dyn_shape)
     zz = run_infer_type(z)
-    
-    assert zz.checked_type == relay.ty.TensorType((relay.Any(),) * rank, dtype)
+
+    assert zz.checked_type == relay.ty.TensorType((relay.Any(), ) * rank, dtype)
 
     func = relay.Function([x, dyn_shape], z)
-    
+
     x = np.random.uniform(size=x_shape).astype(dtype)
-    dyn_shape = (1,)*rank
+    dyn_shape = (1, ) * rank
     ref_res = np.broadcast_to(x, dyn_shape)
     for target, ctx in ctx_list():
-        if (target != 'cuda'): #skip cuda because we don't have dynamic support for GPU
+        if (target != 'cuda'):  #skip cuda because we don't have dynamic support for GPU
             for kind in ["vm", "debug"]:
                 mod = tvm.ir.IRModule.from_expr(func)
                 intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target)
-                op_res = intrp.evaluate(func)(x,np.array(dyn_shape).astype(shape_type))
+                op_res = intrp.evaluate(func)(x, np.array(dyn_shape).astype(shape_type))
                 tvm.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=1e-5)
 
-test_dyn_broadcast_to()
+
+def test_dyn_one_hot():
+    def _get_oshape(indices_shape, depth, axis):
+        oshape = []
+        true_axis = len(indices_shape) if axis == -1 else axis
+        ndim = len(indices_shape) + 1
+        indices_index = 0
+        for i in range(0, ndim):
+            if i == true_axis:
+                oshape.append(depth)
+            else:
+                oshape.append(indices_shape[indices_index])
+                indices_index += 1
+
+        return oshape
+
+    def _verify(indices_shape, depth, on_value, off_value, axis, dtype):
+        indices = relay.var("indices", relay.TensorType(indices_shape, "int32"))
+        depth_var = relay.var("depth", relay.TensorType((), "int32"))
+        on_value_const = relay.const(on_value)
+        off_value_const = relay.const(off_value)
+        out = relay.one_hot(indices, on_value_const, off_value_const, depth_var, axis, dtype)
+        func = relay.Function([indices, depth_var], out)
+        indices_np = np.random.randint(0, depth, size=indices_shape).astype("int32")
+        out_np = tvm.topi.testing.one_hot(indices_np, on_value, off_value, depth, axis, dtype)
+        for target, ctx in ctx_list():
+            if (target != 'cuda'):  #skip cuda because we don't have dynamic support for GPU
+                for kind in ["vm", "debug"]:
+                    mod = tvm.ir.IRModule.from_expr(func)
+                    intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target)
+                    out_relay = intrp.evaluate()(indices_np, np.array(depth).astype("int32"))
+                    tvm.testing.assert_allclose(out_relay.asnumpy(), out_np)
+
+    _verify((3, ), 3, 1, 0, -1, "int32")
+    _verify((3, ), 3, 1.0, 0.0, -1, "float32")
+    _verify((2, 2), 5, 2, -2, 0, "int32")
+    _verify((2, 2), 5, 0.5, -0.5, 1, "float32")
+    _verify((3, 2, 4, 5), 6, 1, 0, 1, "int32")
+    _verify((3, 2, 4, 5), 6, 1.0, 0.0, 0, "float32")
+
+
+if __name__ == "__main__":
+    test_dyn_broadcast_to()
+    test_dyn_one_hot()
index 8ca7882..a50c9df 100644 (file)
@@ -22,6 +22,8 @@ from tvm.relay import transform
 from tvm.relay.build_module import bind_params_by_name
 from tvm.relay.testing import run_infer_type, create_workload, ctx_list
 
+import tvm.topi.testing
+
 
 def run_opt_pass(expr, opt_pass):
     assert isinstance(opt_pass, tvm.transform.Pass)
@@ -222,6 +224,32 @@ def test_dynamic_to_static_zeros_ones():
     verify_ones_zeros((1, 2, 3), 'int64')
     verify_ones_zeros((9, 8, 3, 4), 'float32')
 
+def test_dynamic_to_static_one_hot():
+    def _verify(indices_shape, depth, on_value, off_value, axis, dtype):
+        indices = relay.var("indices", relay.TensorType(indices_shape, "int32"))
+        depth_var = relay.const(depth)
+        on_value_const = relay.const(on_value)
+        off_value_const = relay.const(off_value)
+        out = relay.one_hot(indices, on_value_const, off_value_const, depth_var, axis, dtype)
+        func = relay.Function([indices], out)
+
+        func2 = run_opt_pass(run_opt_pass(func, transform.DynamicToStatic()), transform.InferType())
+
+        zz = func2.body
+        assert isinstance(zz, relay.Call)
+        assert zz.op == relay.op.get("one_hot")
+
+        indices_np = np.random.randint(0, depth, size=indices_shape).astype("int32")
+        out_np = tvm.topi.testing.one_hot(indices_np, on_value, off_value, depth, axis, dtype)
+        verify_func(func2, [indices_np], out_np)
+
+    _verify((3, ), 3, 1, 0, -1, "int32")
+    _verify((3, ), 3, 1.0, 0.0, -1, "float32")
+    _verify((2, 2), 5, 2, -2, 0, "int32")
+    _verify((2, 2), 5, 0.5, -0.5, 1, "float32")
+    _verify((3, 2, 4, 5), 6, 1, 0, 1, "int32")
+    _verify((3, 2, 4, 5), 6, 1.0, 0.0, 0, "float32")
+
 if __name__=="__main__":
     test_dynamic_to_static_reshape()
     test_dynamic_to_static_double_reshape()