Add additional shape validation to `compute_accidental_hits`
authorYong Tang <yong.tang.github@outlook.com>
Sun, 13 May 2018 13:55:53 +0000 (13:55 +0000)
committerYong Tang <yong.tang.github@outlook.com>
Mon, 14 May 2018 00:32:50 +0000 (00:32 +0000)
In `compute_accidental_hits`, the `sampled_candidates` must
be a vector, as is shown in the kernel implementation in
`tensorflow/core/kernels/candidate_sampler_ops.cc`.

This fix adds shape validation of `sampled_candidates`
in the shape function whenever possible.

Signed-off-by: Yong Tang <yong.tang.github@outlook.com>
tensorflow/core/ops/candidate_sampling_ops.cc

index 6e4d100..6e589c8 100644 (file)
@@ -145,12 +145,15 @@ REGISTER_OP("ComputeAccidentalHits")
       int64 num_true;
       TF_RETURN_IF_ERROR(c->GetAttr("num_true", &num_true));
 
-      // Validate true_classes.
+      // Validate true_classes, must be a matrix.
       ShapeHandle true_classes;
       TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &true_classes));
       DimensionHandle unused;
       TF_RETURN_IF_ERROR(
           c->WithValue(c->Dim(true_classes, 1), num_true, &unused));
+      // Validate sampled_candidates, must be a vector.
+      ShapeHandle sampled_candidates;
+      TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &sampled_candidates));
 
       // All three outputs are the same shape.
       ShapeHandle v = c->Vector(InferenceContext::kUnknownDim);