Skip to content

Commit

Permalink
Create function for shape checks and add file for detection methods
Browse files Browse the repository at this point in the history
  • Loading branch information
Dylan-Harden3 committed Dec 9, 2024
1 parent 23d01cc commit fc041c5
Show file tree
Hide file tree
Showing 5 changed files with 66 additions and 71 deletions.
2 changes: 1 addition & 1 deletion pydetectgpt/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""On device LLM-Generated text detection in Pytorch."""

from .detect import detect_ai_text
from .utils import log_likelihood, log_rank
from .methods import log_likelihood, log_rank

__version__ = "0.1.0"
__all__ = ["detect_ai_text", "log_likelihood", "log_rank"]
3 changes: 2 additions & 1 deletion pydetectgpt/detect.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
"""Implementations of detection algorithms."""

from typing import Literal
from .utils import load_model, log_likelihood, log_rank
from .utils import load_model
from .methods import log_likelihood, log_rank
import torch

DETECTION_FUNCS = {"loglikelihood": log_likelihood, "logrank": log_rank}
Expand Down
51 changes: 51 additions & 0 deletions pydetectgpt/methods.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
import torch
import torch.nn.functional as F
from .utils import validate_tensor_shapes


def log_likelihood(labels: torch.Tensor, logits: torch.Tensor) -> float:
    """Return the mean log-probability that the logits assign to the labels.

    Args:
        labels (torch.Tensor): Ground truth labels of shape: (1, sequence_length).
        logits (torch.Tensor): Logits of shape: (1, sequence_length, vocab_size).

    Returns:
        float: The mean loglikelihood.

    Raises:
        ValueError: If the shapes of `labels` and `logits` are incompatible or batch size is > 1.
    """
    validate_tensor_shapes(labels, logits)

    # Collapse every leading axis so each row is one position's vocabulary scores.
    vocab_size = logits.shape[-1]
    flat_logits = logits.view(-1, vocab_size)
    token_index = labels.view(-1).unsqueeze(-1)

    # Normalise over the vocabulary, then pick out the entry for each label.
    per_token = F.log_softmax(flat_logits, dim=-1).gather(dim=-1, index=token_index)
    return per_token.squeeze(-1).mean().item()


def log_rank(labels: torch.Tensor, logits: torch.Tensor) -> float:
    """Return the negated mean log-rank of the labels under the logits.

    Args:
        labels (torch.Tensor): Ground truth labels of shape: (1, sequence_length).
        logits (torch.Tensor): Logits of shape: (1, sequence_length, vocab_size).

    Returns:
        float: The negative mean logrank.

    Raises:
        ValueError: If the shapes of `labels` and `logits` are incompatible or batch size is > 1.
    """
    validate_tensor_shapes(labels, logits)

    # Token ids ordered from highest to lowest logit at every position.
    ordering = torch.argsort(logits, dim=-1, descending=True)
    hits = torch.nonzero(ordering == labels.unsqueeze(-1))
    # The last column of each hit is the label's 0-based position in that ordering.
    zero_based_ranks = hits[:, -1]

    # Shift to 1-based ranks so the best-scoring token contributes log(1) == 0.
    mean_log_rank = torch.log(zero_based_ranks.float() + 1).mean()
    return -mean_log_rank.item()
69 changes: 6 additions & 63 deletions pydetectgpt/utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
"""Utils used throughout source code."""

import torch
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModelForCausalLM
from typing import Tuple

Expand All @@ -28,91 +27,35 @@ def load_model(hf_repo: str) -> Tuple[AutoModelForCausalLM, AutoTokenizer]:
return model, tokenizer


def validate_tensor_shapes(labels: torch.Tensor, logits: torch.Tensor) -> None:
    """Validate the compatibility of labels and logits.

    Args:
        labels (torch.Tensor): Ground truth labels of shape: (1, sequence_length).
        logits (torch.Tensor): Logits of shape: (1, sequence_length, vocab_size).

    Raises:
        ValueError: If the shapes of `labels` and `logits` are incompatible,
            batch size is > 1, `labels` is empty, or any label lies outside
            the range [0, vocab_size).
    """
    # Rank is checked first so the shape[0] lookups below cannot raise IndexError.
    if logits.dim() < 2:
        raise ValueError(
            f"Logits must have at least 2 dimensions, but got shape {logits.shape}"
        )

    # A 0-dim labels tensor has no batch axis to inspect; the shape comparison
    # further down reports it as incompatible instead.
    if labels.dim() > 0 and (logits.shape[0] != 1 or labels.shape[0] != 1):
        raise ValueError(
            f"Batch size must be 1, but got logits batch size {logits.shape[0]} "
            f"and labels batch size {labels.shape[0]}"
        )

    if labels.shape != logits.shape[:-1]:
        raise ValueError(
            "Labels and logits must have compatible shapes. "
            f"Got labels shape {labels.shape} and logits shape {logits.shape[:-1]}"
        )

    # min()/max() raise RuntimeError on empty tensors, so reject those explicitly.
    if labels.numel() == 0:
        raise ValueError("Labels must not be empty")

    vocab_size = logits.shape[-1]
    lowest = labels.min().item()
    highest = labels.max().item()
    # Negative labels must be rejected too: log_rank would otherwise find no
    # match for them and silently leave them out of the mean.
    if lowest < 0 or highest >= vocab_size:
        offending = lowest if lowest < 0 else highest
        raise ValueError(
            f"Labels must be in vocab size ({vocab_size}), "
            f"but got label {offending}"
        )


def log_likelihood(labels: torch.Tensor, logits: torch.Tensor) -> float:
    """Compute the loglikelihood of labels in logits.

    Args:
        labels (torch.Tensor): Ground truth labels of shape: (1, sequence_length).
        logits (torch.Tensor): Logits of shape: (1, sequence_length, vocab_size).

    Returns:
        float: The mean loglikelihood.

    Raises:
        ValueError: If the shapes of `labels` and `logits` are incompatible or batch size is > 1.
    """
    validate_tensor_shapes(labels, logits)

    logits = logits.view(-1, logits.shape[-1])
    labels = labels.view(-1)

    # Local import keeps this block self-contained regardless of the file header.
    import torch.nn.functional as F

    log_probs = F.log_softmax(logits, dim=-1)
    actual_token_probs = log_probs.gather(dim=-1, index=labels.unsqueeze(-1)).squeeze(
        -1
    )
    return actual_token_probs.mean().item()


def log_rank(labels: torch.Tensor, logits: torch.Tensor) -> float:
    """Compute the negative average log rank of labels in logits.

    Args:
        labels (torch.Tensor): Ground truth labels of shape: (1, sequence_length).
        logits (torch.Tensor): Logits of shape: (1, sequence_length, vocab_size).

    Returns:
        float: The negative mean logrank.

    Raises:
        ValueError: If the shapes of `labels` and `logits` are incompatible or batch size is > 1.
    """
    # Shared validation, so errors are no longer mislabelled "In log_likelihood".
    validate_tensor_shapes(labels, logits)

    # Sort token ids by descending logit; the column where the label appears is
    # its 0-based rank.
    matches = (logits.argsort(-1, descending=True) == labels.unsqueeze(-1)).nonzero()
    ranks = matches[:, -1]

    # +1 converts to 1-based ranks so the top-ranked token contributes log(1) == 0.
    log_ranks = torch.log(ranks.float() + 1)

    return -log_ranks.mean().item()
12 changes: 6 additions & 6 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,22 +9,22 @@ def test_log_likelihood():
labels = torch.randint(0, 9, (1, 6))

with pytest.raises(
ValueError, match="labels and logits must have compatible shapes"
ValueError, match="Labels and logits must have compatible shapes"
):
log_likelihood(labels, logits)

# batch size > 1
logits = torch.randn(2, 5, 10)
labels = torch.randint(0, 9, (2, 5))

with pytest.raises(ValueError, match="batch size must be 1"):
with pytest.raises(ValueError, match="Batch size must be 1"):
log_likelihood(labels, logits)

# label > vocab size
logits = torch.randn(1, 3, 10)
labels = torch.tensor([[2, 5, 10]])

with pytest.raises(ValueError, match="labels must be in vocab size"):
with pytest.raises(ValueError, match="Labels must be in vocab size"):
log_likelihood(labels, logits)

# some simple tests I calculated manually
Expand All @@ -45,22 +45,22 @@ def test_log_rank():
labels = torch.randint(0, 9, (1, 6))

with pytest.raises(
ValueError, match="labels and logits must have compatible shapes"
ValueError, match="Labels and logits must have compatible shapes"
):
log_rank(labels, logits)

# batch size > 1
logits = torch.randn(2, 5, 10)
labels = torch.randint(0, 9, (2, 5))

with pytest.raises(ValueError, match="batch size must be 1"):
with pytest.raises(ValueError, match="Batch size must be 1"):
log_rank(labels, logits)

# label > vocab size
logits = torch.randn(1, 3, 10)
labels = torch.tensor([[2, 5, 10]])

with pytest.raises(ValueError, match="labels must be in vocab size"):
with pytest.raises(ValueError, match="Labels must be in vocab size"):
log_rank(labels, logits)

# some simple tests I calculated manually
Expand Down

0 comments on commit fc041c5

Please sign in to comment.