From b9af4308064dc560c4501523a5508de553000fb0 Mon Sep 17 00:00:00 2001 From: Yong Tang Date: Sat, 27 Jan 2018 19:48:35 +0000 Subject: [PATCH] Enable multi-dimensional and axis support for tf.unique_with_counts This fix tries to address the issue raised in 16499 to bring multi-dimensional and axis support for `unique_with_counts`. When `UniqueV2` kernel was added in 12952, it actually supports multi-dimensional and axis support for `unique_with_counts` as well, just not registered. This fix: 1. Register `UniqueWithCountsV2` kernel to have axis support. 2. Hide both `UniqueWithCounts` and `UniqueWithCountsV2` 3. Add python unique_with_counts wrapper to call `gen_array_ops._unique_with_counts` 4. If APi review passes and the PR merges, `unique_with_counts` will switch to `gen_array_ops._unique_with_counts_v2` (in 3 weeks). This fix fixes 16499. Signed-off-by: Yong Tang --- tensorflow/core/kernels/unique_op.cc | 10 ++++++++++ tensorflow/core/ops/array_ops.cc | 17 +++++++++++++++++ 2 files changed, 27 insertions(+) diff --git a/tensorflow/core/kernels/unique_op.cc b/tensorflow/core/kernels/unique_op.cc index 0ef8724..31388e4 100644 --- a/tensorflow/core/kernels/unique_op.cc +++ b/tensorflow/core/kernels/unique_op.cc @@ -223,6 +223,16 @@ class UniqueOp : public OpKernel { .Device(DEVICE_CPU) \ .TypeConstraint("T") \ .TypeConstraint("out_idx"), \ + UniqueOp); \ + REGISTER_KERNEL_BUILDER(Name("UniqueWithCountsV2") \ + .Device(DEVICE_CPU) \ + .TypeConstraint("T") \ + .TypeConstraint("out_idx"), \ + UniqueOp) \ + REGISTER_KERNEL_BUILDER(Name("UniqueWithCountsV2") \ + .Device(DEVICE_CPU) \ + .TypeConstraint("T") \ + .TypeConstraint("out_idx"), \ UniqueOp) TF_CALL_REAL_NUMBER_TYPES(REGISTER_UNIQUE); REGISTER_UNIQUE(string) diff --git a/tensorflow/core/ops/array_ops.cc b/tensorflow/core/ops/array_ops.cc index 267ce88..2fab62e 100644 --- a/tensorflow/core/ops/array_ops.cc +++ b/tensorflow/core/ops/array_ops.cc @@ -1201,6 +1201,23 @@ REGISTER_OP("UniqueWithCounts") return Status::OK(); }); +REGISTER_OP("UniqueWithCountsV2") + .Input("x: T") + .Input("axis: Taxis") + .Output("y: T") + .Output("idx: out_idx") + .Output("count: out_idx") + .Attr("T: type") + .Attr("Taxis: {int32,int64} = DT_INT64") + .Attr("out_idx: {int32, int64} = DT_INT32") + .SetShapeFn([](InferenceContext* c) { + auto uniq = c->Vector(InferenceContext::kUnknownDim); + c->set_output(0, uniq); + c->set_output(1, c->input(0)); + c->set_output(2, uniq); + return Status::OK(); + }); + namespace { Status ShapeShapeFn(InferenceContext* c) { -- 2.7.4