-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Create function for shape checks and add file for detection methods
- Loading branch information
1 parent
23d01cc
commit fc041c5
Showing
5 changed files
with
66 additions
and
71 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,7 +1,7 @@ | ||
"""On device LLM-Generated text detection in Pytorch.""" | ||
|
||
from .detect import detect_ai_text | ||
from .utils import log_likelihood, log_rank | ||
from .methods import log_likelihood, log_rank | ||
|
||
__version__ = "0.1.0" | ||
__all__ = ["detect_ai_text", "log_likelihood", "log_rank"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,51 @@ | ||
import torch | ||
import torch.nn.functional as F | ||
from .utils import validate_tensor_shapes | ||
|
||
|
||
def log_likelihood(labels: torch.Tensor, logits: torch.Tensor) -> float:
    """Compute the mean log-likelihood of `labels` under `logits`.

    The logits are normalized with a log-softmax over the vocabulary axis,
    the log-probability of each label token is selected, and the values are
    averaged over all positions.

    Args:
        labels (torch.Tensor): Ground truth labels of shape: (1, sequence_length).
        logits (torch.Tensor): Logits of shape: (1, sequence_length, vocab_size).

    Returns:
        float: The mean log-likelihood.

    Raises:
        ValueError: If the shapes of `labels` and `logits` are incompatible or batch size is > 1.
    """
    validate_tensor_shapes(labels, logits)

    # Flatten to (positions, vocab_size) and (positions,). `reshape` is used
    # rather than `view` so that non-contiguous inputs (e.g. transposed or
    # permuted tensors) are accepted instead of raising a RuntimeError; for
    # contiguous inputs it returns the same view as before.
    flat_logits = logits.reshape(-1, logits.shape[-1])
    flat_labels = labels.reshape(-1)

    log_probs = F.log_softmax(flat_logits, dim=-1)
    # Pick, for every position, the log-probability assigned to its label.
    label_log_probs = log_probs.gather(
        dim=-1, index=flat_labels.unsqueeze(-1)
    ).squeeze(-1)
    return label_log_probs.mean().item()
|
||
|
||
def log_rank(labels: torch.Tensor, logits: torch.Tensor) -> float:
    """Return the negated mean of the log rank of each label token.

    For every position the vocabulary is ordered by descending logit; the
    rank of the label is its 1-based position in that ordering.

    Args:
        labels (torch.Tensor): Ground truth labels of shape: (1, sequence_length).
        logits (torch.Tensor): Logits of shape: (1, sequence_length, vocab_size).

    Returns:
        float: The negative mean log-rank.

    Raises:
        ValueError: If the shapes of `labels` and `logits` are incompatible or batch size is > 1.
    """
    validate_tensor_shapes(labels, logits)

    # Token ids ordered from highest to lowest logit at each position.
    ordered_token_ids = torch.argsort(logits, dim=-1, descending=True)

    # Coordinates of the entries where the ordered token id equals the label;
    # the last coordinate is the label's zero-based rank.
    hit_coords = torch.nonzero(ordered_token_ids == labels.unsqueeze(-1))
    zero_based_ranks = hit_coords[:, -1]

    # Shift to 1-based ranks so the top-ranked token contributes log(1) = 0.
    mean_log_rank = torch.log(zero_based_ranks.float() + 1).mean()

    return -mean_log_rank.item()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters