From: Matthew Brookhart
Date: Thu, 6 Aug 2020 15:46:58 +0000 (-0700)
Subject: [Relay][Dynamic] OneHot operation (#6209)
X-Git-Tag: upstream/0.7.0~298
X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=da75d85cdce6fa189f3662793e0a68e0f84309f1;p=platform%2Fupstream%2Ftvm.git

[Relay][Dynamic] OneHot operation (#6209)

* Dynamic OneHot Op

* refactor dynamic_to_static

* add onehot to dynamic_to_static pass
---

diff --git a/include/tvm/topi/transform.h b/include/tvm/topi/transform.h
index cd19436..19b2ef4 100644
--- a/include/tvm/topi/transform.h
+++ b/include/tvm/topi/transform.h
@@ -1421,22 +1421,25 @@ inline Tensor ndarray_size(const Tensor& src, const DataType& dtype,
  * \param depth depth of the one-hot dimension.
  * \param axis axis to fill.
  * \param dtype data type of the output tensor.
+ * \param oshape shape of the output tensor.
  * \param name output tensor name.
  * \param tag output tensor tag.
  * \return one-hot tensor.
  */
 inline Tensor one_hot(const Tensor& indices, const PrimExpr on_value, const PrimExpr off_value,
                       int depth, int axis, const DataType& dtype,
+                      Array<PrimExpr> oshape = Array<PrimExpr>(),
                       const std::string name = "T_one_hot", const std::string tag = kInjective) {
-  Array<PrimExpr> oshape;
-  int ndim = indices->shape.size() + 1;
-  int indices_index = 0;
   int true_axis = (axis == -1) ? indices->shape.size() : axis;
-  for (int i = 0; i < ndim; i++) {
-    if (i == true_axis) {
-      oshape.push_back(Integer(depth));
-    } else {
-      oshape.push_back(indices->shape[indices_index++]);
+  if (oshape.size() == 0) {
+    int ndim = indices->shape.size() + 1;
+    int indices_index = 0;
+    for (int i = 0; i < ndim; i++) {
+      if (i == true_axis) {
+        oshape.push_back(Integer(depth));
+      } else {
+        oshape.push_back(indices->shape[indices_index++]);
+      }
     }
   }
 
diff --git a/python/tvm/relay/op/dyn/_transform.py b/python/tvm/relay/op/dyn/_transform.py
index e2704bc..3a80f5a 100644
--- a/python/tvm/relay/op/dyn/_transform.py
+++ b/python/tvm/relay/op/dyn/_transform.py
@@ -25,11 +25,13 @@ from .. import op as _reg
 _reg.register_broadcast_schedule("dyn.broadcast_to")
 _reg.register_injective_schedule("dyn.reshape")
 _reg.register_broadcast_schedule("dyn.tile")
+_reg.register_injective_schedule("dyn.one_hot")
+
 
 @script
 def _reshape_shape_func_input_data(data, newshape, ndim):
-    out = output_tensor((ndim,), "int64")
-    data_shape = allocate((len(data.shape),), "int64")
+    out = output_tensor((ndim, ), "int64")
+    data_shape = allocate((len(data.shape), ), "int64")
     for x in const_range(len(data.shape)):
         data_shape[x] = int64(data.shape[x])
     src_idx = 0
@@ -59,7 +61,7 @@ def _reshape_shape_func_input_data(data, newshape, ndim):
         elif newshape[i] == -3:
             assert data_shape.shape[0] - src_idx > 1, \
                 "Not enough dims in input shape for -3"
-            out[dst_idx] = data_shape[src_idx] * data_shape[src_idx+1]
+            out[dst_idx] = data_shape[src_idx] * data_shape[src_idx + 1]
             src_idx += 2
             dst_idx += 1
         elif newshape[i] == -4:
@@ -82,6 +84,7 @@ def _reshape_shape_func_input_data(data, newshape, ndim):
         out[infer_idx] = old_size // new_size
     return out
 
+
 @_reg.register_shape_func("dyn.reshape", True)
 def dynamic_reshape_shape_func(attrs, inputs, out_ndims):
     return [_reshape_shape_func_input_data(*inputs, out_ndims[0])]
@@ -89,7 +92,7 @@ def dynamic_reshape_shape_func(attrs, inputs, out_ndims):
 
 @script
 def _tile_shape_func(data, reps, ndim, tndim, rndim):
-    out = output_tensor((tndim,), "int64")
+    out = output_tensor((tndim, ), "int64")
 
     if ndim == rndim:
         for i in const_range(tndim):
@@ -120,5 +123,25 @@ def tile_shape_func(attrs, inputs, _):
     ndim = len(inputs[0].shape)
     rndim = inputs[1].shape[0].value
     tndim = ndim if ndim > rndim else rndim
-    return [_tile_shape_func(inputs[0], reps, convert(ndim),
-                             convert(tndim), convert(rndim))]
+    return [_tile_shape_func(inputs[0], reps, convert(ndim), convert(tndim), convert(rndim))]
+
+
+@script
+def _onehot_shape_func(dshape, k, axis):
+    ndim = len(dshape) + 1
+    out = output_tensor((ndim, ), "int64")
+    for i in const_range(axis):
+        out[i] = int64(dshape[i])
+    out[axis] = int64(k[0])
+    for j in const_range(axis + 1, ndim):
+        out[j] = int64(dshape[j - 1])
+    return out
+
+
+@_reg.register_shape_func("dyn.one_hot", True)
+def one_hot_shape_func(attrs, inputs, _):
+    """
+    Shape function for dyn.one_hot op.
+    """
+    axis = len(inputs[0].shape) if attrs.axis == -1 else attrs.axis
+    return [_onehot_shape_func(inputs[0].shape, inputs[3], convert(axis))]
diff --git a/python/tvm/relay/op/transform.py b/python/tvm/relay/op/transform.py
index 6f23af2..5e5b867 100644
--- a/python/tvm/relay/op/transform.py
+++ b/python/tvm/relay/op/transform.py
@@ -148,6 +148,7 @@ def squeeze(data, axis=None):
     """
     return _make.squeeze(data, axis)
 
+
 def reshape(data, newshape):
     """Reshape the input array.
 
@@ -228,6 +229,7 @@ def reshape(data, newshape):
             newshape = tempshape
     return _make.reshape(data, list(newshape))
 
+
 def argwhere(condition):
     """Find the indices of elements of a tensor that are non-zero.
 
@@ -251,6 +253,7 @@ def argwhere(condition):
     """
     return _make.argwhere(condition)
 
+
 def scatter(data, indices, updates, axis):
     """Update data at positions defined by indices with values in updates
 
@@ -275,6 +278,7 @@ def scatter(data, indices, updates, axis):
     """
     return _make.scatter(data, indices, updates, axis)
 
+
 def scatter_add(data, indices, updates, axis):
     """Update data by adding values in updates at positions defined by indices
 
@@ -299,6 +303,7 @@ def scatter_add(data, indices, updates, axis):
     """
     return _make.scatter_add(data, indices, updates, axis)
 
+
 def reshape_like(data, shape_like):
     """Reshapes the input array by the size of another array.
     For an input array with shape ``(d1, d2, ..., dk)``, `reshape_like` operation reshapes
@@ -442,6 +447,7 @@ def arange(start, stop=None, step=None, dtype="float32"):
 
     return _make.arange(start, stop, step, dtype)
 
+
 def meshgrid(data, indexing="ij"):
     """Create coordinate matrices from coordinate vectors.
 
@@ -482,6 +488,7 @@ def meshgrid(data, indexing="ij"):
     ret_size = len(data)
     return TupleWrapper(_make.meshgrid(Tuple(data), indexing), ret_size)
 
+
 def repeat(data, repeats, axis):
     """Repeats elements of an array.
     By default, repeat flattens the input array into 1-D and then repeats the elements.
@@ -668,6 +675,7 @@ def where(condition, x, y):
     """
     return _make.where(condition, x, y)
 
+
 def broadcast_to(data, shape):
     """Return a scalar value array with the same type, broadcast to
     the provided shape.
@@ -693,6 +701,7 @@ def broadcast_to(data, shape):
         shape = list(shape)
     return _make.broadcast_to(data, shape)
 
+
 def broadcast_to_like(data, broadcast_type):
     """Return a scalar value array with the same shape and type as the input array.
 
@@ -1053,6 +1062,7 @@ def sequence_mask(data, valid_length, mask_value=0, axis=0):
     """
     return _make.sequence_mask(data, valid_length, mask_value, axis)
 
+
 def one_hot(indices, on_value, off_value, depth, axis, dtype):
     """
     Returns a one-hot tensor where the locations represented by indices take value on_value,
@@ -1070,7 +1080,7 @@ def one_hot(indices, on_value, off_value, depth, axis, dtype):
     off_value : relay.Expr
         Value to fill at all other positions besides indices.
 
-    depth : int
+    depth : int or relay.Expr
         Depth of the one-hot dimension.
 
     axis : int
@@ -1095,6 +1105,8 @@ def one_hot(indices, on_value, off_value, depth, axis, dtype):
             [0, 1, 0],
             [0, 0, 1]]
     """
+    if isinstance(depth, Expr):
+        return _dyn_make.one_hot(indices, on_value, off_value, depth, axis, dtype)
     return _make.one_hot(indices, on_value, off_value, depth, axis, dtype)
 
 
@@ -1120,6 +1132,7 @@ def unravel_index(indices, shape):
 
     return _make.unravel_index(indices, shape)
 
+
 def sparse_to_dense(sparse_indices, output_shape, sparse_values, default_value=0):
     """Converts a sparse representation into a dense tensor.
 
diff --git a/src/relay/op/dyn/tensor/transform.cc b/src/relay/op/dyn/tensor/transform.cc
index 2bb87ac..d2d6d69 100644
--- a/src/relay/op/dyn/tensor/transform.cc
+++ b/src/relay/op/dyn/tensor/transform.cc
@@ -304,6 +304,76 @@ RELAY_REGISTER_OP("dyn.ones")
     .set_support_level(3)
     .add_type_rel("DynamicInitOp", InitOpRel);
 
+bool OneHotRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
+               const TypeReporter& reporter) {
+  // `types` contains: [indices, on_value, off_value, depth, result]
+  CHECK_EQ(types.size(), 5);
+  const auto* indices = types[0].as<TensorTypeNode>();
+  CHECK(indices);
+
+  const auto param = attrs.as<OneHotAttrs>();
+
+  Array<IndexExpr> oshape;
+  int ndim = indices->shape.size() + 1;
+  int indices_index = 0;
+  int true_axis = (param->axis == -1) ? indices->shape.size() : param->axis;
+  for (int i = 0; i < ndim; i++) {
+    if (i == true_axis) {
+      oshape.push_back(Any());
+    } else {
+      oshape.push_back(indices->shape[indices_index++]);
+    }
+  }
+
+  reporter->Assign(types[4], TensorType(oshape, param->dtype));
+  return true;
+}
+
+Array<te::Tensor> OneHotCompute(const Attrs& attrs, const Array<te::Tensor>& inputs,
+                                const Type& out_type) {
+  const auto* param = attrs.as<OneHotAttrs>();
+  CHECK(param != nullptr);
+  const auto* out_ttype = out_type.as<TensorTypeNode>();
+  return Array<te::Tensor>{topi::one_hot(inputs[0], inputs[1](), inputs[2](), -1, param->axis,
+                                         param->dtype, out_ttype->shape)};
+}
+
+Expr MakeOneHot(Expr indices, Expr on_value, Expr off_value, Expr depth, int axis, DataType dtype) {
+  auto attrs = make_object<OneHotAttrs>();
+  attrs->axis = axis;
+  attrs->dtype = dtype;
+  static const Op& op = Op::Get("dyn.one_hot");
+  return Call(op, {indices, on_value, off_value, depth}, Attrs(attrs), {});
+}
+
+TVM_REGISTER_GLOBAL("relay.op.dyn._make.one_hot").set_body_typed(MakeOneHot);
+
+RELAY_REGISTER_OP("dyn.one_hot")
+    .describe(R"code(Returns a one-hot tensor where the locations represented by indices take value 1,
+    other locations take value 0. Final dimension is <indices outer dimensions> x depth.
+
+    **indices** Locations to set to 1.
+
+    **on_value** Value to fill at indices.
+
+    **off_value** Value to fill at all other positions besides indices.
+
+    **depth** Depth of the one-hot dimension.
+
+    **axis** Axis to fill.
+
+    **dtype**)code" TVM_ADD_FILELINE)
+    .set_attrs_type<OneHotAttrs>()
+    .set_num_inputs(4)
+    .add_argument("indices", "Tensor", "Locations to set to on_value.")
+    .add_argument("on_value", "Expr", "Value to fill at indices.")
+    .add_argument("off_value", "Expr", "Value to fill at all other positions besides indices.")
+    .add_argument("depth", "Expr", "Depth of the one-hot dimension.")
+    .set_support_level(10)
+    .add_type_rel("DynOneHot", OneHotRel)
+    .set_attr<FTVMCompute>("FTVMCompute", OneHotCompute)
+    .set_attr<TOpPattern>("TOpPattern", kOutEWiseFusable);
+
 }  // namespace dyn
 }  // namespace relay
 }  // namespace tvm
diff --git a/src/relay/op/make_op.h b/src/relay/op/make_op.h
index 3b5e9a1..d2c170d 100644
--- a/src/relay/op/make_op.h
+++ b/src/relay/op/make_op.h
@@ -78,6 +78,8 @@ Expr MakeVariance(Expr data, Expr mean, Array<Integer> axis, bool keepdims, bool
 
 Expr MakeZeros(Array<Integer> shape, DataType dtype);
 
+Expr MakeOneHot(Expr indices, Expr on_value, Expr off_value, int depth, int axis, DataType dtype);
+
 }  // namespace relay
 }  // namespace tvm
 #endif  // TVM_RELAY_OP_MAKE_OP_H_
diff --git a/src/relay/transforms/dynamic_to_static.cc b/src/relay/transforms/dynamic_to_static.cc
index d4de15c..8501ee5 100644
--- a/src/relay/transforms/dynamic_to_static.cc
+++ b/src/relay/transforms/dynamic_to_static.cc
@@ -33,44 +33,82 @@ namespace relay {
 
 class DynamicToStaticMutator : public MixedModeMutator {
  public:
-  DynamicToStaticMutator() {}
+  DynamicToStaticMutator() {
+    op_map_ = {
+        {Op::Get("dyn.reshape"),
+         [](const CallNode* call_node) {
+           if (const ConstantNode* shape = call_node->args[1].as<ConstantNode>()) {
+             CHECK_EQ(shape->data->ndim, 1);
+             return MakeReshape(call_node->args[0], ToVector(shape->data));
+           }
+           return Expr(nullptr);
+         }},
+        {Op::Get("dyn.tile"),
+         [](const CallNode* call_node) {
+           if (const ConstantNode* reps = call_node->args[1].as<ConstantNode>()) {
+             CHECK_EQ(reps->data->ndim, 1);
+             return MakeTile(call_node->args[0], ToVector(reps->data));
+           }
+           return Expr(nullptr);
+         }},
+        {Op::Get("dyn.topk"),
+         [](const CallNode* call_node) {
+           if (const ConstantNode* k = call_node->args[1].as<ConstantNode>()) {
+             const TopKAttrs* param = call_node->attrs.as<TopKAttrs>();
+             CHECK(param);
+             return MakeTopK(call_node->args[0], static_cast<int>(ToScalar(k->data, 0)),
+                             param->axis, param->ret_type, param->is_ascend, param->dtype);
+           }
+           return Expr(nullptr);
+         }},
+        {Op::Get("dyn.broadcast_to"),
+         [](const CallNode* call_node) {
+           if (const ConstantNode* shape = call_node->args[1].as<ConstantNode>()) {
+             CHECK_EQ(shape->data->ndim, 1);
+             return MakeBroadCastTo(call_node->args[0], ToVector(shape->data));
+           }
+           return Expr(nullptr);
+         }},
+        {Op::Get("dyn.zeros"),
+         [](const CallNode* call_node) {
+           if (const ConstantNode* shape = call_node->args[0].as<ConstantNode>()) {
+             const InitOpAttrs* param = call_node->attrs.as<InitOpAttrs>();
+             CHECK(param);
+             return MakeZeros(ToVector(shape->data), param->dtype);
+           }
+           return Expr(nullptr);
+         }},
+        {Op::Get("dyn.ones"),
+         [](const CallNode* call_node) {
+           if (const ConstantNode* shape = call_node->args[0].as<ConstantNode>()) {
+             const InitOpAttrs* param = call_node->attrs.as<InitOpAttrs>();
+             CHECK(param);
+             return MakeOnes(ToVector(shape->data), param->dtype);
+           }
+           return Expr(nullptr);
+         }},
+        {Op::Get("dyn.one_hot"),
+         [](const CallNode* call_node) {
+           if (const ConstantNode* depth = call_node->args[3].as<ConstantNode>()) {
+             const OneHotAttrs* param = call_node->attrs.as<OneHotAttrs>();
+             CHECK(param);
+             return MakeOneHot(call_node->args[0], call_node->args[1], call_node->args[2],
+                               static_cast<int>(ToScalar(depth->data, 0)), param->axis,
+                               param->dtype);
+           }
+           return Expr(nullptr);
+         }},
+    };
+  }
 
  private:
   Expr Rewrite_(const CallNode* pre, const Expr& post) override {
-    const CallNode* call_node = post.as<CallNode>();
-    if (call_node->op == Op::Get("dyn.reshape")) {
-      if (const ConstantNode* shape = call_node->args[1].as<ConstantNode>()) {
-        CHECK_EQ(shape->data->ndim, 1);
-        return MakeReshape(call_node->args[0], ToVector(shape->data));
-      }
-    } else if (call_node->op == Op::Get("dyn.tile")) {
-      if (const ConstantNode* reps = call_node->args[1].as<ConstantNode>()) {
-        CHECK_EQ(reps->data->ndim, 1);
-        return MakeTile(call_node->args[0], ToVector(reps->data));
-      }
-    } else if (call_node->op == Op::Get("dyn.topk")) {
-      if (const ConstantNode* k = call_node->args[1].as<ConstantNode>()) {
-        const TopKAttrs* param = call_node->attrs.as<TopKAttrs>();
-        CHECK(param);
-        return MakeTopK(call_node->args[0], static_cast<int>(ToScalar(k->data, 0)), param->axis,
-                        param->ret_type, param->is_ascend, param->dtype);
-      }
-    } else if (call_node->op == Op::Get("dyn.broadcast_to")) {
-      if (const ConstantNode* shape = call_node->args[1].as<ConstantNode>()) {
-        CHECK_EQ(shape->data->ndim, 1);
-        return MakeBroadCastTo(call_node->args[0], ToVector(shape->data));
-      }
-    } else if (call_node->op == Op::Get("dyn.zeros")) {
-      if (const ConstantNode* shape = call_node->args[0].as<ConstantNode>()) {
-        const InitOpAttrs* param = call_node->attrs.as<InitOpAttrs>();
-        CHECK(param);
-        return MakeZeros(ToVector(shape->data), param->dtype);
-      }
-    } else if (call_node->op == Op::Get("dyn.ones")) {
-      if (const ConstantNode* shape = call_node->args[0].as<ConstantNode>()) {
-        const InitOpAttrs* param = call_node->attrs.as<InitOpAttrs>();
-        CHECK(param);
-        return MakeOnes(ToVector(shape->data), param->dtype);
+    if (const CallNode* call_node = post.as<CallNode>()) {
+      if (op_map_.count(call_node->op)) {
+        auto out = op_map_[call_node->op](call_node);
+        if (out.defined()) {
+          return out;
+        }
       }
     }
     return post;
@@ -83,6 +121,8 @@ class DynamicToStaticMutator : public MixedModeMutator {
     }
     return post;
   }
+  std::unordered_map<Op, std::function<Expr(const CallNode*)>, ObjectPtrHash, ObjectPtrEqual>
+      op_map_;
 };
 
 Expr DynamicToStatic(Function f, IRModule m) {
@@ -90,6 +130,7 @@ Expr DynamicToStatic(Function f, IRModule m) {
   Expr expr = f;
   auto fold_const = transform::FoldConstant();
   auto infer_type = transform::InferType();
+  DynamicToStaticMutator mutator;
   Map<BaseFunc, GlobalVar> vars;
   for (auto kv : m->functions) {
     vars.Set(kv.second, kv.first);
   }
@@ -101,7 +142,7 @@ Expr DynamicToStatic(Function f, IRModule m) {
     // TODO(mbrookhart): Is it possible to run these passes JUST on the current function?
     m = infer_type(m);
     m = fold_const(m);
-    expr = DynamicToStaticMutator().Mutate(m->functions[gv]);
+    expr = mutator.Mutate(m->functions[gv]);
     m->Update(gv, Downcast<Function>(expr));
     i += 1;
   } while (pre != expr && i < 1000);
diff --git a/tests/python/relay/dyn/test_dynamic_op_level10.py b/tests/python/relay/dyn/test_dynamic_op_level10.py
index d9b23a7..95a030f 100644
--- a/tests/python/relay/dyn/test_dynamic_op_level10.py
+++ b/tests/python/relay/dyn/test_dynamic_op_level10.py
@@ -19,36 +19,80 @@ Support level10 operator test cases.
 
 """
-
 import numpy as np
 import tvm
 from tvm import relay
 from tvm.relay.testing import ctx_list, run_infer_type
+import tvm.topi.testing
 import random
 
+
 def test_dyn_broadcast_to():
     dtype = 'uint8'
     rank = 3
     shape_type = 'int64'
-    dyn_shape = relay.Var("shape", relay.ty.TensorType((rank,), shape_type))
-    x_shape = (1,)
+    dyn_shape = relay.Var("shape", relay.ty.TensorType((rank, ), shape_type))
+    x_shape = (1, )
     x = relay.Var("x", relay.ty.TensorType(x_shape, dtype))
     z = relay.broadcast_to(x, dyn_shape)
     zz = run_infer_type(z)
-
-    assert zz.checked_type == relay.ty.TensorType((relay.Any(),) * rank, dtype)
+
+    assert zz.checked_type == relay.ty.TensorType((relay.Any(), ) * rank, dtype)
 
     func = relay.Function([x, dyn_shape], z)
-
+
     x = np.random.uniform(size=x_shape).astype(dtype)
-    dyn_shape = (1,)*rank
+    dyn_shape = (1, ) * rank
     ref_res = np.broadcast_to(x, dyn_shape)
     for target, ctx in ctx_list():
-        if (target != 'cuda'): #skip cuda because we don't have dynamic support for GPU
+        if (target != 'cuda'):  #skip cuda because we don't have dynamic support for GPU
             for kind in ["vm", "debug"]:
                 mod = tvm.ir.IRModule.from_expr(func)
                 intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target)
-                op_res = intrp.evaluate(func)(x,np.array(dyn_shape).astype(shape_type))
+                op_res = intrp.evaluate(func)(x, np.array(dyn_shape).astype(shape_type))
                 tvm.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=1e-5)
 
-test_dyn_broadcast_to()
+
+def test_dyn_one_hot():
+    def _get_oshape(indices_shape, depth, axis):
+        oshape = []
+        true_axis = len(indices_shape) if axis == -1 else axis
+        ndim = len(indices_shape) + 1
+        indices_index = 0
+        for i in range(0, ndim):
+            if i == true_axis:
+                oshape.append(depth)
+            else:
+                oshape.append(indices_shape[indices_index])
+                indices_index += 1
+
+        return oshape
+
+    def _verify(indices_shape, depth, on_value, off_value, axis, dtype):
+        indices = relay.var("indices", relay.TensorType(indices_shape, "int32"))
+        depth_var = relay.var("depth", relay.TensorType((), "int32"))
+        on_value_const = relay.const(on_value)
+        off_value_const = relay.const(off_value)
+        out = relay.one_hot(indices, on_value_const, off_value_const, depth_var, axis, dtype)
+        func = relay.Function([indices, depth_var], out)
+        indices_np = np.random.randint(0, depth, size=indices_shape).astype("int32")
+        out_np = tvm.topi.testing.one_hot(indices_np, on_value, off_value, depth, axis, dtype)
+        for target, ctx in ctx_list():
+            if (target != 'cuda'):  #skip cuda because we don't have dynamic support for GPU
+                for kind in ["vm", "debug"]:
+                    mod = tvm.ir.IRModule.from_expr(func)
+                    intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target)
+                    out_relay = intrp.evaluate()(indices_np, np.array(depth).astype("int32"))
+                    tvm.testing.assert_allclose(out_relay.asnumpy(), out_np)
+
+    _verify((3, ), 3, 1, 0, -1, "int32")
+    _verify((3, ), 3, 1.0, 0.0, -1, "float32")
+    _verify((2, 2), 5, 2, -2, 0, "int32")
+    _verify((2, 2), 5, 0.5, -0.5, 1, "float32")
+    _verify((3, 2, 4, 5), 6, 1, 0, 1, "int32")
+    _verify((3, 2, 4, 5), 6, 1.0, 0.0, 0, "float32")
+
+
+if __name__ == "__main__":
+    test_dyn_broadcast_to()
+    test_dyn_one_hot()
diff --git a/tests/python/relay/test_pass_dynamic_to_static.py b/tests/python/relay/test_pass_dynamic_to_static.py
index 8ca7882..a50c9df 100644
--- a/tests/python/relay/test_pass_dynamic_to_static.py
+++ b/tests/python/relay/test_pass_dynamic_to_static.py
@@ -22,6 +22,8 @@ from tvm.relay import transform
 from tvm.relay.build_module import bind_params_by_name
 from tvm.relay.testing import run_infer_type, create_workload, ctx_list
+import tvm.topi.testing
+
 
 
 def run_opt_pass(expr, opt_pass):
     assert isinstance(opt_pass, tvm.transform.Pass)
@@ -222,6 +224,32 @@ def test_dynamic_to_static_zeros_ones():
     verify_ones_zeros((1, 2, 3), 'int64')
     verify_ones_zeros((9, 8, 3, 4), 'float32')
 
+def test_dynamic_to_static_one_hot():
+    def _verify(indices_shape, depth, on_value, off_value, axis, dtype):
+        indices = relay.var("indices", relay.TensorType(indices_shape, "int32"))
+        depth_var = relay.const(depth)
+        on_value_const = relay.const(on_value)
+        off_value_const = relay.const(off_value)
+        out = relay.one_hot(indices, on_value_const, off_value_const, depth_var, axis, dtype)
+        func = relay.Function([indices], out)
+
+        func2 = run_opt_pass(run_opt_pass(func, transform.DynamicToStatic()), transform.InferType())
+
+        zz = func2.body
+        assert isinstance(zz, relay.Call)
+        assert zz.op == relay.op.get("one_hot")
+
+        indices_np = np.random.randint(0, depth, size=indices_shape).astype("int32")
+        out_np = tvm.topi.testing.one_hot(indices_np, on_value, off_value, depth, axis, dtype)
+        verify_func(func2, [indices_np], out_np)
+
+    _verify((3, ), 3, 1, 0, -1, "int32")
+    _verify((3, ), 3, 1.0, 0.0, -1, "float32")
+    _verify((2, 2), 5, 2, -2, 0, "int32")
+    _verify((2, 2), 5, 0.5, -0.5, 1, "float32")
+    _verify((3, 2, 4, 5), 6, 1, 0, 1, "int32")
+    _verify((3, 2, 4, 5), 6, 1.0, 0.0, 0, "float32")
+
 if __name__=="__main__":
     test_dynamic_to_static_reshape()
     test_dynamic_to_static_double_reshape()
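
For reviewers who want to exercise the new op outside the test suite, the following is a minimal usage sketch (not part of the patch). It assumes a TVM build that contains this change and uses only APIs that already appear in the diff: passing an `Expr` as `depth` makes `relay.one_hot` emit `dyn.one_hot`, and `DynamicToStatic` folds it back to the static `one_hot` when the depth is a constant.

```python
# Minimal usage sketch (not part of the patch); assumes a TVM build containing this change.
import numpy as np
import tvm
from tvm import relay
from tvm.relay import transform

# Dynamic path: depth is a relay Expr (a free variable), so relay.one_hot emits dyn.one_hot.
indices = relay.var("indices", relay.TensorType((3,), "int32"))
depth = relay.var("depth", relay.TensorType((), "int32"))
out = relay.one_hot(indices, relay.const(1), relay.const(0), depth, -1, "int32")
func = relay.Function([indices, depth], out)

# Run with the VM executor (dynamic shapes are not supported on CUDA, per the tests above).
mod = tvm.ir.IRModule.from_expr(func)
intrp = relay.create_executor("vm", mod=mod, ctx=tvm.cpu(), target="llvm")
res = intrp.evaluate()(np.array([0, 1, 2], dtype="int32"), np.array(3, dtype="int32"))
print(res.asnumpy())  # 3x3 one-hot matrix

# Static fold: with a constant depth, DynamicToStatic rewrites dyn.one_hot back to one_hot.
const_out = relay.one_hot(indices, relay.const(1), relay.const(0), relay.const(3), -1, "int32")
const_func = relay.Function([indices], const_out)
print(transform.DynamicToStatic()(tvm.ir.IRModule.from_expr(const_func)))
```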