Making ids unique in nn.embedding_lookup_sparse. This helps to reduce RPC calls for...
author: A. Unique TensorFlower <gardener@tensorflow.org>
Tue, 1 May 2018 22:00:20 +0000 (15:00 -0700)
committer: TensorFlower Gardener <gardener@tensorflow.org>
Tue, 1 May 2018 22:03:02 +0000 (15:03 -0700)
PiperOrigin-RevId: 195002785

tensorflow/python/ops/embedding_ops.py

index 6f2a34c..bcc717b 100644 (file)
@@ -385,7 +385,7 @@ def embedding_lookup_sparse(params,
       ```
 
   Raises:
-    TypeError: If `sp_ids` is not a `SparseTensor`, or if `sp_weights` is 
+    TypeError: If `sp_ids` is not a `SparseTensor`, or if `sp_weights` is
       neither `None` nor `SparseTensor`.
     ValueError: If `combiner` is not one of {"mean", "sqrtn", "sum"}.
   """
@@ -421,10 +421,7 @@ def embedding_lookup_sparse(params,
       segment_ids = math_ops.cast(segment_ids, dtypes.int32)
 
     ids = sp_ids.values
-    if ignore_weights:
-      ids, idx = array_ops.unique(ids)
-    else:
-      idx = None
+    ids, idx = array_ops.unique(ids)
 
     embeddings = embedding_lookup(
         params, ids, partition_strategy=partition_strategy, max_norm=max_norm)
@@ -433,6 +430,8 @@ def embedding_lookup_sparse(params,
       if weights.dtype != embeddings.dtype:
         weights = math_ops.cast(weights, embeddings.dtype)
 
+      embeddings = array_ops.gather(embeddings, idx)
+
       # Reshape weights to allow broadcast
       ones = array_ops.fill(
           array_ops.expand_dims(array_ops.rank(embeddings) - 1, 0), 1)