TFSimilarity.retrieval_metrics.RecallAtK

The metric learning version of Recall@K.

Inherits From: RetrievalMetric, ABC

TFSimilarity.retrieval_metrics.RecallAtK(
    name: str = 'recall', k: int = 5, **kwargs
) -> None

A query scores 1 (a positive) when ANY of its top K lookups matches the query class, and 0 otherwise. The per-query scores are then averaged as specified by the average argument.
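The per-query rule can be sketched in plain Python. This is a minimal illustration of the definition, not the library's implementation: recall_at_k is a hypothetical helper, match_mask is assumed to be a list with one row per query, and each row is assumed to be ordered nearest neighbor first. The sketch uses micro averaging (every query weighs the same).

```python
from typing import List, Sequence


def recall_at_k(match_mask: Sequence[Sequence[bool]], k: int) -> float:
    """Micro-averaged Recall@K over per-query match rows.

    Hypothetical helper for illustration; rows are assumed to be
    sorted nearest neighbor first.
    """
    if not match_mask:
        raise ValueError("match_mask must contain at least one query")
    # A query scores 1 if ANY of its first k neighbors matches, else 0.
    hits: List[float] = [1.0 if any(row[:k]) else 0.0 for row in match_mask]
    # Average the 0/1 scores over all queries.
    return sum(hits) / len(hits)


# Three queries, four neighbors each (True = same class as the query).
mask = [
    [False, True, False, False],   # first match at rank 2
    [False, False, False, True],   # first match at rank 4
    [True, False, False, False],   # first match at rank 1
]
print(recall_at_k(mask, k=2))
```

With k=2 the second query has no match in its top two neighbors, so the result is 2/3; with k=4 every query has a match and the result is 1.0.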

Args

name: Name associated with the metric object, e.g., recall@5.
canonical_name: The canonical name associated with the metric, e.g., recall@K.
k: The number of nearest neighbors over which the metric is computed.
distance_threshold: The max distance below which a nearest neighbor is considered a valid match.
average: Determines the type of averaging performed over the queries. Defaults to 'micro'.
  • 'micro': Calculates metrics globally over all queries.
  • 'macro': Calculates metrics for each label and takes the unweighted mean.
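The effect of distance_threshold and of the two averaging modes can be sketched as follows. This is an illustration under stated assumptions, not the library's code: apply_distance_threshold and averaged_recall_at_k are hypothetical helpers, the distances are assumed to be aligned with match_mask, and the threshold is assumed to use a strict '<' comparison, following the wording "max distance below which".

```python
import math
from collections import defaultdict
from typing import Dict, List, Sequence


def apply_distance_threshold(
    match_mask: Sequence[Sequence[bool]],
    distances: Sequence[Sequence[float]],
    distance_threshold: float = math.inf,
) -> List[List[bool]]:
    """Keep a match only if the neighbor's distance is below the threshold.

    Hypothetical helper; assumes a strict '<' comparison.
    """
    return [
        [is_match and dist < distance_threshold for is_match, dist in zip(mrow, drow)]
        for mrow, drow in zip(match_mask, distances)
    ]


def averaged_recall_at_k(
    query_labels: Sequence[int],
    match_mask: Sequence[Sequence[bool]],
    k: int,
    average: str = "micro",
) -> float:
    """Recall@K with 'micro' or 'macro' averaging (hypothetical helper)."""
    # Per-query 0/1 score: does any of the top-k neighbors match?
    hits = [1.0 if any(row[:k]) else 0.0 for row in match_mask]
    if average == "micro":
        # Every query carries the same weight.
        return sum(hits) / len(hits)
    if average == "macro":
        # Average within each class, then take the unweighted mean over classes.
        per_class: Dict[int, List[float]] = defaultdict(list)
        for label, hit in zip(query_labels, hits):
            per_class[label].append(hit)
        class_means = [sum(v) / len(v) for v in per_class.values()]
        return sum(class_means) / len(class_means)
    raise ValueError(f"unknown average: {average!r}")


labels = [0, 0, 0, 1]
mask = [
    [True, False, False],
    [True, False, False],
    [False, True, False],
    [False, False, True],
]
distances = [
    [0.1, 0.4, 0.8],
    [0.2, 0.3, 0.7],
    [0.3, 0.9, 1.0],
    [0.2, 0.6, 0.9],
]
print(averaged_recall_at_k(labels, mask, k=2, average="micro"))
print(averaged_recall_at_k(labels, mask, k=2, average="macro"))
thresholded = apply_distance_threshold(mask, distances, distance_threshold=0.5)
print(averaged_recall_at_k(labels, thresholded, k=2, average="micro"))
```

With k=2, micro averaging gives 0.75 and macro averaging gives 0.5, because the single class-1 query weighs as much as all three class-0 queries combined. Applying distance_threshold=0.5 turns the third query's rank-2 match (distance 0.9) into a miss, which lowers the micro score to 0.5.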

Attributes

name

Methods

compute

compute(
    *,
    query_labels: TFSimilarity.callbacks.IntTensor,
    match_mask: TFSimilarity.utils.BoolTensor,
    **kwargs
) -> TFSimilarity.callbacks.FloatTensor

Compute the metric.

Args
query_labels: A 1D tensor of the labels associated with the embedding queries.
match_mask: A 2D mask where a 1 indicates a match between the jth query and the kth neighbor, and a 0 indicates a mismatch.
**kwargs: Additional compute args.

Returns
A rank 0 tensor containing the metric.
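The match_mask argument has one row per query and one column per neighbor. A minimal sketch of how such a mask could be derived from neighbor labels, assuming each query's neighbor labels are available in nearest-first order; build_match_mask and lookup_labels are hypothetical names for illustration, not part of the TF Similarity API.

```python
from typing import List, Sequence


def build_match_mask(
    query_labels: Sequence[int],
    lookup_labels: Sequence[Sequence[int]],
) -> List[List[bool]]:
    """Entry [j][i] is True when neighbor i of query j has query j's label.

    Hypothetical helper mirroring the documented match_mask layout.
    """
    # The 2D mask needs exactly one row of neighbor labels per query.
    if len(query_labels) != len(lookup_labels):
        raise ValueError("need one row of neighbor labels per query")
    # Every query must have the same number of neighbors for a 2D mask.
    widths = {len(row) for row in lookup_labels}
    if len(widths) > 1:
        raise ValueError("every query needs the same number of neighbors")
    return [
        [neighbor == label for neighbor in row]
        for label, row in zip(query_labels, lookup_labels)
    ]


query_labels = [1, 2]
lookup_labels = [[1, 3, 1], [3, 3, 2]]  # nearest neighbor first
print(build_match_mask(query_labels, lookup_labels))
```

Here the first query matches at ranks 1 and 3 and the second only at rank 3, so Recall@1 would score the two queries 1 and 0, while Recall@3 would score both as 1.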

get_config

get_config()