The metric learning version of Recall@K.
Inherits From: RetrievalMetric, ABC
TFSimilarity.retrieval_metrics.RecallAtK(
    name: str = 'recall', k: int = 5, **kwargs
) -> None
A query is counted as a positive (1) when ANY lookup in the top K matches the query class, and as 0 otherwise.
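The counting rule above can be illustrated with a minimal pure-Python sketch of the per-query indicator. The function name query_is_positive and the assumption that neighbor labels arrive already sorted by increasing distance are illustrative only; they are not part of the TFSimilarity API.

```python
from typing import Sequence


def query_is_positive(query_label: int, neighbor_labels: Sequence[int], k: int) -> int:
    """Hypothetical helper: 1 if ANY of the first k neighbor labels matches the query, else 0.

    Assumes neighbor_labels is ordered from nearest to farthest neighbor.
    """
    top_k = neighbor_labels[:k]
    return int(any(label == query_label for label in top_k))


# A query of class 3 and the labels of its nearest neighbors, nearest first.
neighbors = [7, 1, 3, 3, 9, 3]
print(query_is_positive(3, neighbors, k=2))  # 0: no class-3 neighbor in the top 2
print(query_is_positive(3, neighbors, k=5))  # 1: a class-3 neighbor appears at rank 3
```

Note that extra matches in the top K do not raise the score above 1: a single hit is enough for the query to count as a positive.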
Args | |
---|---|
name | Name associated with the metric object, e.g., recall@5 |
canonical_name | The canonical name associated with metric, e.g., recall@K |
k | The number of nearest neighbors over which the metric is computed. |
distance_threshold | The max distance below which a nearest neighbor is considered a valid match. |
average | Determines the type of averaging performed over the queries. Accepted value: 'micro'. |
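One way to picture how distance_threshold interacts with label matching is a small NumPy sketch that builds a match mask in which a neighbor counts only if it has the query's label and lies within the threshold. The helper thresholded_match_mask is hypothetical, and treating a distance exactly equal to the threshold as valid (using <=) is an assumption, not confirmed library behavior.

```python
import numpy as np


def thresholded_match_mask(query_labels, neighbor_labels, neighbor_distances, distance_threshold):
    """Hypothetical helper: bool mask [num_queries, num_neighbors].

    An entry is True only where the neighbor shares the query's label AND its
    distance is <= distance_threshold (the inclusive bound is an assumption).
    """
    query_labels = np.asarray(query_labels).reshape(-1, 1)  # [num_queries, 1] for broadcasting
    neighbor_labels = np.asarray(neighbor_labels)           # [num_queries, num_neighbors]
    neighbor_distances = np.asarray(neighbor_distances)     # [num_queries, num_neighbors]
    same_class = neighbor_labels == query_labels
    within_threshold = neighbor_distances <= distance_threshold
    return same_class & within_threshold


mask = thresholded_match_mask(
    query_labels=[1, 2],
    neighbor_labels=[[1, 1, 3], [4, 2, 2]],
    neighbor_distances=[[0.1, 0.6, 0.2], [0.1, 0.3, 0.9]],
    distance_threshold=0.5,
)
print(mask.tolist())  # [[True, False, False], [False, True, False]]
```

In the first row the second neighbor has the right label but is too far away, so it no longer counts as a match.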
Attributes | |
---|---|
name | Name associated with the metric object. |
compute(
    *,
    query_labels: TFSimilarity.callbacks.IntTensor,
    match_mask: TFSimilarity.utils.BoolTensor,
    **kwargs
) -> TFSimilarity.callbacks.FloatTensor
Compute the metric.
Args | |
---|---|
query_labels | A 1D tensor of the labels associated with the embedding queries. |
match_mask | A 2D mask with one row per query and one column per neighbor, where a 1 at row i, column j indicates that the jth neighbor matches the ith query and a 0 indicates a mismatch. |
**kwargs | Additional compute args. |
Returns | |
---|---|
A rank 0 tensor containing the metric. |
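As a rough illustration of what compute returns, here is a NumPy sketch that mirrors the keyword-only signature and reduces the match mask to a single scalar. It is not the library's implementation: it uses NumPy arrays instead of TensorFlow tensors, takes k as an explicit argument, assumes 'micro' averaging means a plain mean of the per-query 0/1 indicators, and uses query_labels only for a shape check.

```python
import numpy as np


def recall_at_k(*, query_labels, match_mask, k=5):
    """Sketch of Recall@K with 'micro' averaging (assumed to be a plain mean over queries)."""
    query_labels = np.asarray(query_labels)
    match_mask = np.asarray(match_mask, dtype=bool)
    if match_mask.ndim != 2 or match_mask.shape[0] != query_labels.shape[0]:
        raise ValueError("match_mask must be 2D with one row per query label")
    # Per-query indicator: 1.0 if ANY of the first k neighbors matches, else 0.0.
    hits = match_mask[:, :k].any(axis=1).astype(np.float32)
    # Average over all queries; np.asarray turns the NumPy scalar into a rank 0 array.
    return np.asarray(hits.mean())


labels = [0, 1, 2, 3]
mask = [
    [1, 0, 0],  # match at rank 1
    [0, 0, 1],  # match only at rank 3
    [0, 0, 0],  # no match
    [0, 1, 0],  # match at rank 2
]
print(float(recall_at_k(query_labels=labels, match_mask=mask, k=2)))  # 0.5
print(float(recall_at_k(query_labels=labels, match_mask=mask, k=3)))  # 0.75
```

Raising k can only keep recall the same or increase it, since each query gains more chances to find a matching neighbor.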
get_config()