Skip to content

Commit

Permalink
fix flake8 lints
Browse files Browse the repository at this point in the history
  • Loading branch information
ochafik committed Sep 26, 2024
1 parent 1b62801 commit 76d2938
Showing 1 changed file with 17 additions and 11 deletions.
28 changes: 17 additions & 11 deletions tests/update_jinja_goldens.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,15 +66,19 @@
"mistralai/Mixtral-8x7B-Instruct-v0.1",
]


def raise_exception(message: str):
raise ValueError(message)


def tojson(x, ensure_ascii=False, indent=None, separators=None, sort_keys=False):
return json.dumps(x, ensure_ascii=ensure_ascii, indent=indent, separators=separators, sort_keys=sort_keys)


def strftime_now(format):
return datetime.now().strftime(format)

Check failure on line 79 in tests/update_jinja_goldens.py

View workflow job for this annotation

GitHub Actions / pyright type-check

"now" is not a known attribute of module "datetime" (reportAttributeAccessIssue)


def handle_chat_template(model_id, variant, template_src):
print(f"# {model_id} @ {variant}", flush=True)
model_name = model_id.replace("/", "-")
Expand All @@ -87,12 +91,12 @@ def handle_chat_template(model_id, variant, template_src):
print(f"- {template_file}", flush=True)

env = jinja2.Environment(
trim_blocks=True,
lstrip_blocks=True,
# keep_trailing_newline=False,
extensions=[
jinja2.ext.loopcontrols
])
trim_blocks=True,
lstrip_blocks=True,
# keep_trailing_newline=False,
extensions=[
jinja2.ext.loopcontrols
])
env.filters['tojson'] = tojson
env.globals['raise_exception'] = raise_exception
env.globals['strftime_now'] = strftime_now
Expand All @@ -118,23 +122,24 @@ def handle_chat_template(model_id, variant, template_src):
print(f"- {output_file}", flush=True)
try:
output = template.render(**context)
except:
except Exception as e1:
# Some templates (e.g. Phi-3-medium-128k's) expect a non-null "content" key in each message.
for message in context["messages"]:
if message.get("content") is None:
message["content"] = ""

try:
output = template.render(**context)
except Exception as e:
print(f" ERROR: {e}", flush=True)
output = f"ERROR: {e}"
except Exception as e2:
print(f" ERROR: {e2} (after first error: {e1})", flush=True)
output = f"ERROR: {e2}"

with open(output_file, 'w') as f:
f.write(output)

print()


def main():
for dir in ['tests/chat/templates', 'tests/chat/goldens']:
if not os.path.isdir(dir):
Expand All @@ -149,7 +154,7 @@ def main():

try:
config = json.loads(config_str)
except json.JSONDecodeError as e:
except json.JSONDecodeError:
# Fix https://huggingface.co/NousResearch/Meta-Llama-3-8B-Instruct/blob/main/tokenizer_config.json
# (Remove extra '}' near the end of the file)
config = json.loads(re.sub(r'\}([\n\s]*\}[\n\s]*\],[\n\s]*"clean_up_tokenization_spaces")', r'\1', config_str))
Expand All @@ -161,5 +166,6 @@ def main():
for ct in chat_template:
handle_chat_template(model_id, ct['name'], ct['template'])


if __name__ == '__main__':
main()

0 comments on commit 76d2938

Please sign in to comment.