From: Yong Tang Date: Sun, 13 May 2018 13:55:53 +0000 (+0000) Subject: Add additional shape validation to `compute_accidental_hits` X-Git-Tag: upstream/v1.9.0_rc1~80^2 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=356f360e8772a2697ec0d30036237342549803f5;p=platform%2Fupstream%2Ftensorflow.git Add additional shape validation to `compute_accidental_hits` In `compute_accidental_hits`, the `sampled_candidates` must be a vector, as is shown in the kernel implementation in `tensorflow/core/kernels/candidate_sampler_ops.cc`. This fix adds shape validation of `sampled_candidates` in the shape function whenever possible. Signed-off-by: Yong Tang --- diff --git a/tensorflow/core/ops/candidate_sampling_ops.cc b/tensorflow/core/ops/candidate_sampling_ops.cc index 6e4d100..6e589c8 100644 --- a/tensorflow/core/ops/candidate_sampling_ops.cc +++ b/tensorflow/core/ops/candidate_sampling_ops.cc @@ -145,12 +145,15 @@ REGISTER_OP("ComputeAccidentalHits") int64 num_true; TF_RETURN_IF_ERROR(c->GetAttr("num_true", &num_true)); - // Validate true_classes. + // Validate true_classes, must be a matrix. ShapeHandle true_classes; TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &true_classes)); DimensionHandle unused; TF_RETURN_IF_ERROR( c->WithValue(c->Dim(true_classes, 1), num_true, &unused)); + // Validate sampled_candidates, must be a vector. + ShapeHandle sampled_candidates; + TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &sampled_candidates)); // All three outputs are the same shape. ShapeHandle v = c->Vector(InferenceContext::kUnknownDim);