Skip to content

Commit

Permalink
pre-commit: replace linters + formatters with Ruff; fix some issues (#1300)
Browse files Browse the repository at this point in the history

* pre-commit: replace linters + formatters with Ruff

* Don't use bare except

* Clean up `noqa`s

* Enable Ruff UP; apply auto-fixes

* Enable Ruff B; apply fixes

* Enable Ruff T with exceptions

* Enable Ruff C (complexity); autofix

* Upgrade Ruff to 0.2.0
  • Loading branch information
akx authored Feb 15, 2024
1 parent 29f162b commit 9bc478e
Show file tree
Hide file tree
Showing 30 changed files with 113 additions and 149 deletions.
37 changes: 5 additions & 32 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -1,37 +1,10 @@
repos:
- repo: https://github.com/PyCQA/isort
rev: 5.12.0
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.2.0
hooks:
- id: isort
args:
- --profile=black
- --skip-glob=wandb/**/*
- --thirdparty=wandb
- repo: https://github.com/myint/autoflake
rev: v1.4
hooks:
- id: autoflake
args:
- -r
- --exclude=wandb,__init__.py
- --in-place
- --remove-unused-variables
- --remove-all-unused-imports
- repo: https://github.com/python/black
rev: 22.3.0
hooks:
- id: black
args:
- --line-length=119
- --target-version=py38
- --exclude=wandb
- repo: https://github.com/pycqa/flake8
rev: 6.0.0
hooks:
- id: flake8
args:
- --ignore=E203,E501,W503,E128
- --max-line-length=119
- id: ruff
args: [ --fix ]
- id: ruff-format

# - repo: https://github.com/codespell-project/codespell
# rev: v2.1.0
Expand Down
2 changes: 1 addition & 1 deletion examples/hello_world.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
"pad_token_id": tokenizer.eos_token_id,
"max_new_tokens": 20,
}
response_tensor = ppo_trainer.generate([item for item in query_tensor], return_prompt=False, **generation_kwargs)
response_tensor = ppo_trainer.generate(list(query_tensor), return_prompt=False, **generation_kwargs)
response_txt = tokenizer.decode(response_tensor[0])

# 5. define a reward for response
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
# coding=utf-8
# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
Expand Down Expand Up @@ -163,7 +162,7 @@ def preprocess_function(examples):


def collator(data):
return dict((key, [d[key] for d in data]) for key in data[0])
return {key: [d[key] for d in data] for key in data[0]}


# set seed before initializing value head for deterministic eval
Expand Down
3 changes: 1 addition & 2 deletions examples/research_projects/tools/calculator.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
Expand Down Expand Up @@ -107,7 +106,7 @@ def exact_match_reward(responses, answers=None):
)

# main training loop
for step in range(100):
for _step in range(100):
tasks, answers = generate_data(ppo_config.batch_size)
queries, responses, masks, rewards, histories = text_env.run(tasks, answers=answers)
train_stats = ppo_trainer.step(queries, responses, rewards, masks)
Expand Down
5 changes: 2 additions & 3 deletions examples/research_projects/tools/python_interpreter.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
Expand Down Expand Up @@ -61,9 +60,9 @@ def exact_match_reward(responses, answers=None):
if match_pattern:
predicted_number = float(match_pattern[0])
if predicted_number is not None:
if np.abs((predicted_number - float(answer))) < 0.1:
if np.abs(predicted_number - float(answer)) < 0.1:
reward += 1.0
except: # noqa
except Exception:
pass
rewards.append(torch.tensor(reward))
return rewards
Expand Down
17 changes: 9 additions & 8 deletions examples/research_projects/tools/triviaqa.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
Expand Down Expand Up @@ -114,7 +113,7 @@ class ScriptArguments:

def data_generator():
for i in range(len(dataset)):
yield dataset[i]["question"], [item for item in dataset[i]["answer"]["normalized_aliases"]]
yield dataset[i]["question"], list(dataset[i]["answer"]["normalized_aliases"])


gen = data_generator()
Expand All @@ -123,7 +122,7 @@ def data_generator():

def generate_data(n):
tasks, answers = [], []
for i in range(n):
for _i in range(n):
q, a = next(gen)
tasks.append(q)
answers.append(a)
Expand All @@ -143,10 +142,14 @@ def exact_match_reward(responses, answers=None):
return rewards


def tool_fn(x):
    """Query the module-level ``tool`` and return a trimmed slice of its output.

    The tool's output is split on newlines; only the second line (index 1)
    is kept, cut to at most 600 characters so the text handed back to the
    environment stays short.

    Note: the cap is applied to characters, not tokens. If the output has
    no newline, indexing the second line raises ``IndexError``.
    """
    output_lines = tool(x).split("\n")
    second_line = output_lines[1]
    # limit the amount of text passed on
    return second_line[:600]


# text env
tool = load_tool("vwxyzjn/pyserini-wikipedia-kilt-doc")
# limit the amount if tokens
tool_fn = lambda x: tool(x).split("\n")[1][:600] # noqa

text_env = TextEnvironment(
model,
tokenizer,
Expand Down Expand Up @@ -184,8 +187,6 @@ def print_trainable_parameters(model):
"answer": [", ".join(item) for item in answers],
}
all_rewards = ppo_trainer.accelerator.gather(torch.tensor(rewards, device=ppo_trainer.accelerator.device))
ppo_trainer.log_stats(
train_stats, texts, [item for item in all_rewards], columns_to_log=["query", "response", "answer"]
)
ppo_trainer.log_stats(train_stats, texts, list(all_rewards), columns_to_log=["query", "response", "answer"])
if i % 100 == 0:
ppo_trainer.save_pretrained(f"models/{args.model_name}_{args.seed}_{i}_triviaqa")
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
Expand Down Expand Up @@ -146,7 +145,7 @@ def tokenize(sample):


def collator(data):
return dict((key, [d[key] for d in data]) for key in data[0])
return {key: [d[key] for d in data] for key in data[0]}


# set seed before initializing value head for deterministic eval
Expand Down Expand Up @@ -218,7 +217,7 @@ def collator(data):
response_tensors.append(response.squeeze()[-gen_len:])
batch["response"] = [tokenizer.decode(r.squeeze()) for r in response_tensors]

# Compute sentiment score # noqa
# Compute sentiment score
texts = batch["response"]
toxicity_inputs = toxicity_tokenizer(texts, padding=True, truncation=True, return_tensors="pt").to(
ppo_trainer.accelerator.device
Expand Down
1 change: 0 additions & 1 deletion examples/scripts/dpo.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
Expand Down
5 changes: 2 additions & 3 deletions examples/scripts/ppo.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
Expand Down Expand Up @@ -95,7 +94,7 @@ def tokenize(sample):


def collator(data):
return dict((key, [d[key] for d in data]) for key in data[0])
return {key: [d[key] for d in data] for key in data[0]}


# set seed before initializing value head for deterministic eval
Expand Down Expand Up @@ -171,7 +170,7 @@ def collator(data):
"max_new_tokens": 32,
}

for epoch, batch in tqdm(enumerate(ppo_trainer.dataloader)):
for _epoch, batch in tqdm(enumerate(ppo_trainer.dataloader)):
query_tensors = batch["input_ids"]

# Get response from gpt2
Expand Down
5 changes: 2 additions & 3 deletions examples/scripts/ppo_multi_adapter.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
Expand Down Expand Up @@ -97,7 +96,7 @@ def tokenize(example):


def collator(data):
return dict((key, [d[key] for d in data]) for key in data[0])
return {key: [d[key] for d in data] for key in data[0]}


config = PPOConfig(
Expand Down Expand Up @@ -131,7 +130,7 @@ def collator(data):
"max_new_tokens": 32,
}

for epoch, batch in tqdm(enumerate(ppo_trainer.dataloader)):
for _epoch, batch in tqdm(enumerate(ppo_trainer.dataloader)):
question_tensors = batch["input_ids"]

response_tensors = ppo_trainer.generate(
Expand Down
1 change: 0 additions & 1 deletion examples/scripts/reward_modeling.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
Expand Down
1 change: 0 additions & 1 deletion examples/scripts/sft.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
Expand Down
26 changes: 16 additions & 10 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,16 +1,22 @@
[tool.black]
line-length = 119
target-version = ['py38']

[tool.ruff]
ignore = ["E501", "E741", "W605"]
select = ["E", "F", "I", "W"]
target-version = "py37"
line-length = 119

# Ignore import violations in all `__init__.py` files.
[tool.ruff.per-file-ignores]
"__init__.py" = ["E402", "F401", "F403", "F811"]
[tool.ruff.lint]
ignore = [
"B028", # warning without explicit stacklevel
"C408", # dict() calls (stylistic)
"C901", # function complexity
"E501",
]
extend-select = ["E", "F", "I", "W", "UP", "B", "T", "C"]

[tool.ruff.lint.per-file-ignores]
# Allow prints in auxiliary scripts
"benchmark/**.py" = ["T201"]
"examples/**.py" = ["T201"]
"scripts/**.py" = ["T201"]

[tool.ruff.isort]
[tool.ruff.lint.isort]
lines-after-imports = 2
known-first-party = ["trl"]
2 changes: 1 addition & 1 deletion scripts/log_example_reports.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def main(text_file_name, slack_channel_name=None):
if os.path.isfile(text_file_name):
final_results = {}

file = open(text_file_name, "r")
file = open(text_file_name)
lines = file.readlines()
for line in lines:
result, config_name = line.split(",")
Expand Down
2 changes: 1 addition & 1 deletion scripts/log_reports.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def main(slack_channel_name=None):
for log in Path().glob("*.log"):
section_num_failed = 0
i = 0
with open(log, "r") as f:
with open(log) as f:
for line in f:
line = json.loads(line)
i += 1
Expand Down
2 changes: 1 addition & 1 deletion scripts/stale.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def main():
open_issues = repo.get_issues(state="open")

for issue in open_issues:
comments = sorted([comment for comment in issue.get_comments()], key=lambda i: i.created_at, reverse=True)
comments = sorted(issue.get_comments(), key=lambda i: i.created_at, reverse=True)
last_comment = comments[0] if len(comments) > 0 else None
if (
last_comment is not None
Expand Down
9 changes: 0 additions & 9 deletions setup.cfg
Original file line number Diff line number Diff line change
@@ -1,11 +1,2 @@
[metadata]
license_file = LICENSE

[isort]
ensure_newline_before_comments = True
force_grid_wrap = 0
include_trailing_comma = True
line_length = 119
lines_after_imports = 2
multi_line_output = 3
use_parentheses = True
10 changes: 5 additions & 5 deletions tests/test_no_peft.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,14 +95,14 @@ def test_no_peft(self):

# Check that loading a model with `peft` will raise an error
with pytest.raises(ModuleNotFoundError):
import peft # noqa
import peft # noqa: F401

trl_model = AutoModelForCausalLMWithValueHead.from_pretrained(self.causal_lm_model_id) # noqa
trl_seq2seq_model = AutoModelForSeq2SeqLMWithValueHead.from_pretrained(self.seq_to_seq_model_id) # noqa
_trl_model = AutoModelForCausalLMWithValueHead.from_pretrained(self.causal_lm_model_id)
_trl_seq2seq_model = AutoModelForSeq2SeqLMWithValueHead.from_pretrained(self.seq_to_seq_model_id)

def test_imports_no_peft(self):
with patch.dict(sys.modules, {"peft": None}):
from trl import ( # noqa
from trl import ( # noqa: F401
AutoModelForCausalLMWithValueHead,
AutoModelForSeq2SeqLMWithValueHead,
PPOConfig,
Expand Down Expand Up @@ -141,7 +141,7 @@ def test_ppo_trainer_no_peft(self):
# (this could be any reward such as human feedback or output from another model)
reward = [torch.tensor(1.0), torch.tensor(0.0)]
# train model
train_stats = ppo_trainer.step([q for q in query_tensor], [r for r in response_tensor], reward)
train_stats = ppo_trainer.step(list(query_tensor), list(response_tensor), reward)
break

# check gradients are not None
Expand Down
Loading

0 comments on commit 9bc478e

Please sign in to comment.