Skip to content

Commit

Permalink
Merge pull request #35 from datamol-io/fix/encdec
Browse files Browse the repository at this point in the history
Fix/encdec
  • Loading branch information
maclandrol authored Mar 28, 2024
2 parents bb5b909 + f9213a2 commit 80ccca7
Show file tree
Hide file tree
Showing 6 changed files with 14 additions and 13 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/code-check.yml
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ jobs:

- name: Install black
run: |
pip install black>=23
pip install black>=24
- name: Lint
run: black --check .
Expand Down
3 changes: 2 additions & 1 deletion env.yml
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ dependencies:

# Scientific
- datamol
- pandas <=2.1.1
- numpy
- pytorch >=2.0
- transformers
Expand All @@ -25,7 +26,7 @@ dependencies:
- deepspeed

# dev
- black >=23
- black >=24
- ruff
- pytest >=6.0
- nbconvert
Expand Down
10 changes: 4 additions & 6 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -81,24 +81,22 @@ minversion = "6.0"
addopts = "--verbose --color yes"
testpaths = ["tests"]

[tool.ruff.pycodestyle]
max-doc-length = 150

[tool.ruff]
line-length = 120
# Enable Pyflakes `E` and `F` codes by default.
select = [
lint.select = [
"E",
"W", # see: https://pypi.org/project/pycodestyle
"F", # see: https://pypi.org/project/pyflakes
]
extend-select = [
lint.extend-select = [
"C4", # see: https://pypi.org/project/flake8-comprehensions
"SIM", # see: https://pypi.org/project/flake8-simplify
"RET", # see: https://pypi.org/project/flake8-return
"PT", # see: https://pypi.org/project/flake8-pytest-style
]
ignore = [
lint.ignore = [
"E731", # Do not assign a lambda expression, use a def
"S108",
"F401",
Expand All @@ -108,4 +106,4 @@ ignore = [
]
# Exclude a variety of commonly ignored directories.
exclude = [".git", "docs", "_notebooks"]
ignore-init-module-imports = true
lint.ignore-init-module-imports = true
4 changes: 2 additions & 2 deletions safe/sample.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,10 +42,10 @@ def __init__(
safe_encoder: custom safe encoder to use
verbose: whether to print out logging information during generation
"""
if isinstance(model, os.PathLike):
if isinstance(model, (str, os.PathLike)):
model = SAFEDoubleHeadsModel.from_pretrained(model)

if isinstance(tokenizer, os.PathLike):
if isinstance(tokenizer, (str, os.PathLike)):
tokenizer = SAFETokenizer.load(tokenizer)

model.eval()
Expand Down
6 changes: 3 additions & 3 deletions safe/trainer/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -342,9 +342,9 @@ def compute_metrics(eval_preds):
prop_loss_coeff=model_args.prop_loss_coeff,
compute_metrics=compute_metrics if training_args.do_eval else None,
data_collator=data_collator,
preprocess_logits_for_metrics=preprocess_logits_for_metrics
if training_args.do_eval
else None,
preprocess_logits_for_metrics=(
preprocess_logits_for_metrics if training_args.do_eval else None
),
)

if training_args.do_train:
Expand Down
2 changes: 2 additions & 0 deletions safe/trainer/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,7 @@ def forward(
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
inputs: Optional[Any] = None, # do not remove because of trainer
encoder_hidden_states: Optional[torch.Tensor] = None,
**kwargs,
) -> Union[Tuple, GPT2DoubleHeadsModelOutput]:
r"""
Expand Down Expand Up @@ -164,6 +165,7 @@ def forward(
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
encoder_hidden_states=encoder_hidden_states,
)

hidden_states = transformer_outputs[0]
Expand Down

0 comments on commit 80ccca7

Please sign in to comment.