Skip to content

Commit

Permalink
Merge pull request #11 from Dylan-Harden3/refactor
Browse files Browse the repository at this point in the history
add typehints everywhere
  • Loading branch information
Dylan-Harden3 authored Dec 10, 2024
2 parents ce17ae4 + fef0912 commit fd3d75f
Show file tree
Hide file tree
Showing 5 changed files with 88 additions and 62 deletions.
33 changes: 19 additions & 14 deletions pydetectgpt/detect.py
Original file line number Diff line number Diff line change
@@ -1,38 +1,39 @@
"""Implementations of detection algorithms."""

from typing import Literal
from typing import Literal, Dict, Callable
from .utils import load_model
from .methods import log_likelihood, log_rank, likelihood_logrank_ratio, fast_detect_gpt
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

DETECTION_FUNCS = {
DETECTION_FUNCS: Dict[str, Callable[[torch.Tensor, torch.Tensor], float]] = {
"loglikelihood": log_likelihood,
"logrank": log_rank,
"detectllm": likelihood_logrank_ratio,
"fastdetectgpt": fast_detect_gpt,
}
THRESHOLDS = {
THRESHOLDS: Dict[str, float] = {
"loglikelihood": -1.8,
"logrank": -0.8,
"detectllm": 2.14,
"fastdetectgpt": 1.9,
}

DetectionMethod = Literal["loglikelihood", "logrank", "detectllm", "fastdetectgpt"]


def detect_ai_text(
text: str,
method: Literal[
"loglikelihood", "logrank", "detectllm", "fastdetectgpt"
] = "fastdetectgpt",
method: DetectionMethod = "fastdetectgpt",
threshold: float = None,
detection_model: str = "Qwen/Qwen2.5-1.5B",
) -> int:
"""Detect if `text` is written by human or ai.
Args:
text (str): The text to check.
method (str, optional), default='fastdetectgpt': Detection method to use, must be one of ['loglikelihood', 'logrank', 'detectllm', 'fastdetectgpt'].
threshold (float, optional), default=None: Decision threshold for `method` to use. If not provided, a default value will be used based on `method`.
method (DetectionMethod, optional), default='fastdetectgpt': Detection method to use, must be one of ['loglikelihood', 'logrank', 'detectllm', 'fastdetectgpt'].
threshold (float | None, optional), default=None: Decision threshold for `method` to use. If not provided, a default value will be used based on `method`.
detection_model (str, optional), default='Qwen/Qwen2.5-1.5B': Hugging Face repo name for the model that `method` will use to generate logits.
Returns:
Expand All @@ -44,10 +45,12 @@ def detect_ai_text(
if not text:
return 0

device = "cuda" if torch.cuda.is_available() else "cpu"
device: str = "cuda" if torch.cuda.is_available() else "cpu"
model: AutoModelForCausalLM
tokenizer: AutoTokenizer
model, tokenizer = load_model(detection_model)

tokens = tokenizer(
tokens: torch.Tensor = tokenizer(
text,
return_tensors="pt",
padding=True,
Expand All @@ -59,13 +62,15 @@ def detect_ai_text(
f"In detect_ai_text `method` must be one of ['loglikelihood', 'logrank', 'detectllm', 'fastdetectgpt'], but got {method}"
)

method_func = DETECTION_FUNCS[method]
method_func: Callable[[torch.Tensor, torch.Tensor], float] = DETECTION_FUNCS[method]
if threshold is None:
threshold = THRESHOLDS[method]

labels = tokens.input_ids[:, 1:] # remove bos token
labels: torch.Tensor = tokens.input_ids[:, 1:] # remove bos token
with torch.no_grad():
logits = model(**tokens).logits[:, :-1] # remove next token logits
pred = method_func(labels, logits)
logits: torch.Tensor = model(**tokens).logits[
:, :-1
] # remove next token logits
pred: float = method_func(labels, logits)

return 0 if pred < threshold else 1
46 changes: 27 additions & 19 deletions pydetectgpt/methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,13 @@ def log_likelihood(labels: torch.Tensor, logits: torch.Tensor) -> float:
"""
validate_tensor_shapes(labels, logits)

logits = logits.view(-1, logits.shape[-1])
labels = labels.view(-1)
logits: torch.Tensor = logits.view(-1, logits.shape[-1])
labels: torch.Tensor = labels.view(-1)

log_probs = F.log_softmax(logits, dim=-1)
actual_token_probs = log_probs.gather(dim=-1, index=labels.unsqueeze(-1)).squeeze(
-1
)
log_probs: torch.Tensor = F.log_softmax(logits, dim=-1)
actual_token_probs: torch.Tensor = log_probs.gather(
dim=-1, index=labels.unsqueeze(-1)
).squeeze(-1)
return actual_token_probs.mean().item()


Expand All @@ -45,10 +45,12 @@ def log_rank(labels: torch.Tensor, logits: torch.Tensor) -> float:
"""
validate_tensor_shapes(labels, logits)

matches = (logits.argsort(-1, descending=True) == labels.unsqueeze(-1)).nonzero()
ranks = matches[:, -1]
matches: torch.Tensor = (
logits.argsort(-1, descending=True) == labels.unsqueeze(-1)
).nonzero()
ranks: torch.Tensor = matches[:, -1]

log_ranks = torch.log(ranks.float() + 1)
log_ranks: torch.Tensor = torch.log(ranks.float() + 1)

return -log_ranks.mean().item()

Expand All @@ -68,8 +70,8 @@ def likelihood_logrank_ratio(labels: torch.Tensor, logits: torch.Tensor) -> floa
"""
validate_tensor_shapes(labels, logits)

_log_likelihood = log_likelihood(labels, logits)
_log_rank = log_rank(labels, logits)
_log_likelihood: float = log_likelihood(labels, logits)
_log_rank: float = log_rank(labels, logits)

return _log_likelihood / _log_rank

Expand All @@ -90,18 +92,24 @@ def fast_detect_gpt(labels: torch.Tensor, logits: torch.Tensor) -> float:
validate_tensor_shapes(labels, logits)

# conditional sampling
log_probs = F.log_softmax(logits, dim=-1)
distribution = torch.distributions.categorical.Categorical(logits=log_probs)
x_tilde = distribution.sample([10000]).permute([1, 2, 0])
log_probs: torch.Tensor = F.log_softmax(logits, dim=-1)
distribution: torch.distributions.categorical.Categorical = (
torch.distributions.categorical.Categorical(logits=log_probs)
)
x_tilde: torch.Tensor = distribution.sample([10000]).permute([1, 2, 0])

log_likelihood_x = log_probs.gather(dim=-1, index=labels.unsqueeze(-1)).mean(dim=1)
log_likelihood_x_tilde = log_probs.gather(dim=-1, index=x_tilde).mean(dim=1)
log_likelihood_x: torch.Tensor = log_probs.gather(
dim=-1, index=labels.unsqueeze(-1)
).mean(dim=1)
log_likelihood_x_tilde: torch.Tensor = log_probs.gather(dim=-1, index=x_tilde).mean(
dim=1
)

# estimate the mean/variance
mu_tilde = log_likelihood_x_tilde.mean(dim=-1)
sigma_tilde = log_likelihood_x_tilde.std(dim=-1)
mu_tilde: torch.Tensor = log_likelihood_x_tilde.mean(dim=-1)
sigma_tilde: torch.Tensor = log_likelihood_x_tilde.std(dim=-1)

# estimate conditional probability curvature
dhat = (log_likelihood_x - mu_tilde) / sigma_tilde
dhat: torch.Tensor = (log_likelihood_x - mu_tilde) / sigma_tilde

return dhat.item()
8 changes: 5 additions & 3 deletions pydetectgpt/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,11 @@ def load_model(hf_repo: str) -> Tuple[AutoModelForCausalLM, AutoTokenizer]:
Raises:
ValueError: If there is an issue loading the model or tokenizer from HuggingFace.
"""
device = "cuda" if torch.cuda.is_available() else "cpu"
tokenizer = AutoTokenizer.from_pretrained(hf_repo)
model = AutoModelForCausalLM.from_pretrained(hf_repo).to(device)
device: str = "cuda" if torch.cuda.is_available() else "cpu"
tokenizer: AutoTokenizer = AutoTokenizer.from_pretrained(hf_repo)
model: AutoModelForCausalLM = AutoModelForCausalLM.from_pretrained(hf_repo).to(
device
)

if tokenizer.pad_token_id is None:
tokenizer.pad_token_id = tokenizer.eos_token_id
Expand Down
45 changes: 27 additions & 18 deletions tests/test_detector.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,15 @@
from pydetectgpt import detect_ai_text
import pytest
from pydetectgpt.detect import DetectionMethod

# I asked ChatGPT "Where is Texas A&M?"
AI_TEXT = "Texas A&M University is located in College Station, Texas, in the southeastern part of the state. It's about 90 miles northwest of Houston and around 150 miles south of Dallas. The university's full name is Texas Agricultural and Mechanical University, and it is one of the largest public universities in the United States."
AI_TEXT: str = (
"Texas A&M University is located in College Station, Texas, in the southeastern part of the state. It's about 90 miles northwest of Houston and around 150 miles south of Dallas. The university's full name is Texas Agricultural and Mechanical University, and it is one of the largest public universities in the United States."
)
# random paragraph from one of my assignments (written by a human)
HUMAN_TEXT = "The main problem the authors are trying to address is that Large Language Models require large computational resources to use. This means that as a common setup we see companies deploying GPU clusters which act as a cloud server to generate responses when a user presents a query. Aside from the vast resources needed to set up a GPU cluster this approach has 2 main downsides: sending queries over the internet via an API exposes users’ private data and results in additional latency when generating responses"
HUMAN_TEXT: str = (
"The main problem the authors are trying to address is that Large Language Models require large computational resources to use. This means that as a common setup we see companies deploying GPU clusters which act as a cloud server to generate responses when a user presents a query. Aside from the vast resources needed to set up a GPU cluster this approach has 2 main downsides: sending queries over the internet via an API exposes users’ private data and results in additional latency when generating responses"
)


def test_detect_ai_text():
Expand All @@ -23,44 +28,48 @@ def test_detect_ai_text():


def test_detect_ai_text_loglikelihood():
method: DetectionMethod = "loglikelihood"

assert detect_ai_text(AI_TEXT, method="loglikelihood") == 1
assert detect_ai_text(AI_TEXT, method=method) == 1

assert detect_ai_text(HUMAN_TEXT, method="loglikelihood") == 0
assert detect_ai_text(HUMAN_TEXT, method=method) == 0

assert detect_ai_text(AI_TEXT, method="loglikelihood", threshold=99999.9) == 0
assert detect_ai_text(AI_TEXT, method=method, threshold=99999.9) == 0

assert detect_ai_text(HUMAN_TEXT, method="loglikelihood", threshold=-99999.9) == 1
assert detect_ai_text(HUMAN_TEXT, method=method, threshold=-99999.9) == 1


def test_detect_ai_text_logrank():
method: DetectionMethod = "logrank"

assert detect_ai_text(AI_TEXT, method="logrank") == 1
assert detect_ai_text(AI_TEXT, method=method) == 1

assert detect_ai_text(HUMAN_TEXT, method="logrank") == 0
assert detect_ai_text(HUMAN_TEXT, method=method) == 0

assert detect_ai_text(AI_TEXT, method="logrank", threshold=99999.9) == 0
assert detect_ai_text(AI_TEXT, method=method, threshold=99999.9) == 0

assert detect_ai_text(HUMAN_TEXT, method="logrank", threshold=-99999.9) == 1
assert detect_ai_text(HUMAN_TEXT, method=method, threshold=-99999.9) == 1


def test_detect_ai_text_detectllm():
method: DetectionMethod = "detectllm"

assert detect_ai_text(AI_TEXT, method="detectllm") == 1
assert detect_ai_text(AI_TEXT, method=method) == 1

assert detect_ai_text(HUMAN_TEXT, method="detectllm") == 0
assert detect_ai_text(HUMAN_TEXT, method=method) == 0

assert detect_ai_text(AI_TEXT, method="detectllm", threshold=99999.9) == 0
assert detect_ai_text(AI_TEXT, method=method, threshold=99999.9) == 0

assert detect_ai_text(HUMAN_TEXT, method="detectllm", threshold=-99999.9) == 1
assert detect_ai_text(HUMAN_TEXT, method=method, threshold=-99999.9) == 1


def test_detect_ai_text_fastdetectgpt():
method: DetectionMethod = "fastdetectgpt"

assert detect_ai_text(AI_TEXT, method="fastdetectgpt") == 1
assert detect_ai_text(AI_TEXT, method=method) == 1

assert detect_ai_text(HUMAN_TEXT, method="fastdetectgpt") == 0
assert detect_ai_text(HUMAN_TEXT, method=method) == 0

assert detect_ai_text(AI_TEXT, method="fastdetectgpt", threshold=99999.9) == 0
assert detect_ai_text(AI_TEXT, method=method, threshold=99999.9) == 0

assert detect_ai_text(HUMAN_TEXT, method="fastdetectgpt", threshold=-99999.9) == 1
assert detect_ai_text(HUMAN_TEXT, method=method, threshold=-99999.9) == 1
18 changes: 10 additions & 8 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import torch
import pytest
from torch import Tensor

from pydetectgpt import (
log_likelihood,
log_rank,
Expand All @@ -10,8 +12,8 @@

def test_log_likelihood():
# shape mismatch
logits = torch.randn(1, 5, 10)
labels = torch.randint(0, 9, (1, 6))
logits: Tensor = torch.randn(1, 5, 10)
labels: Tensor = torch.randint(0, 9, (1, 6))

with pytest.raises(
ValueError, match="Labels and logits must have compatible shapes"
Expand Down Expand Up @@ -46,8 +48,8 @@ def test_log_likelihood():

def test_log_rank():
# shape mismatch
logits = torch.randn(1, 5, 10)
labels = torch.randint(0, 9, (1, 6))
logits: Tensor = torch.randn(1, 5, 10)
labels: Tensor = torch.randint(0, 9, (1, 6))

with pytest.raises(
ValueError, match="Labels and logits must have compatible shapes"
Expand Down Expand Up @@ -82,8 +84,8 @@ def test_log_rank():

def test_likelihood_logrank_ratio():
# shape mismatch
logits = torch.randn(1, 5, 10)
labels = torch.randint(0, 9, (1, 6))
logits: Tensor = torch.randn(1, 5, 10)
labels: Tensor = torch.randint(0, 9, (1, 6))

with pytest.raises(
ValueError, match="Labels and logits must have compatible shapes"
Expand Down Expand Up @@ -118,8 +120,8 @@ def test_likelihood_logrank_ratio():

def test_fast_detect_gpt():
# shape mismatch
logits = torch.randn(1, 5, 10)
labels = torch.randint(0, 9, (1, 6))
logits: Tensor = torch.randn(1, 5, 10)
labels: Tensor = torch.randint(0, 9, (1, 6))

with pytest.raises(
ValueError, match="Labels and logits must have compatible shapes"
Expand Down

0 comments on commit fd3d75f

Please sign in to comment.