```
Raises:
- TypeError: If `sp_ids` is not a `SparseTensor`, or if `sp_weights` is
+ TypeError: If `sp_ids` is not a `SparseTensor`, or if `sp_weights` is
neither `None` nor `SparseTensor`.
ValueError: If `combiner` is not one of {"mean", "sqrtn", "sum"}.
"""
segment_ids = math_ops.cast(segment_ids, dtypes.int32)
ids = sp_ids.values
- if ignore_weights:
- ids, idx = array_ops.unique(ids)
- else:
- idx = None
+ ids, idx = array_ops.unique(ids)
embeddings = embedding_lookup(
params, ids, partition_strategy=partition_strategy, max_norm=max_norm)
if weights.dtype != embeddings.dtype:
weights = math_ops.cast(weights, embeddings.dtype)
+ embeddings = array_ops.gather(embeddings, idx)
+
# Reshape weights to allow broadcast
ones = array_ops.fill(
array_ops.expand_dims(array_ops.rank(embeddings) - 1, 0), 1)