Gemma capping (#34282)
* softcapping

* soft cap before the mask

* style

* ...

* super nit

* update

* fixes

* update

* small issue with modular

* fix modular imports

* update

* fixup

* simplify a hell lot

* simplify cleaning imports

* finish fixing

* update our design

* nits

* use a deprecation cycle

* updates

* Fix modular (recursive deps need to always be computed after merges!)

* push

* fix

* update

* fix modular order

* make fix-copies

* updates

* update

* ?

* don't compile for now

* ?

* fix some stuff

* done!

* fix copies

* update

* fixup

* ?

* fix two tests

* fix?

* for now, don't use head info

* eager when outputting attentions and sdpa or flash, as it's the simplest behaviour (for our tests as well :))

* fix-copies

* revert sdpa check

* Apply suggestions from code review

Co-authored-by: Cyril Vallez <[email protected]>

* rebase, fix-copies and push

* add a slow integration test

* update the test

* fix left padding issue

* fix test

* remove duplicate scaling

* quality

* add a small test and make sure it works

* 2b

---------

Co-authored-by: Cyril Vallez <[email protected]>
Co-authored-by: Cyril Vallez <[email protected]>
3 people authored Nov 19, 2024
1 parent 54739a3 commit 4bff54f
Showing 8 changed files with 426 additions and 536 deletions.
1 change: 1 addition & 0 deletions src/transformers/modeling_utils.py
@@ -1519,6 +1519,7 @@ def _autoset_attn_implementation(
"eager",
"sdpa",
"flash_attention_2",
"flex_attention",
]:
message = f'Specified `attn_implementation="{config._attn_implementation}"` is not supported. The only possible arguments are `attn_implementation="eager"` (manual attention implementation)'
if cls._supports_flash_attn_2:
414 changes: 151 additions & 263 deletions src/transformers/models/gemma2/modeling_gemma2.py

Large diffs are not rendered by default.
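The bulk of this diff reworks Gemma 2's attention so that logit soft capping is applied before the attention mask (see the "soft cap before the mask" commit above). A minimal sketch of the capping itself, with `softcap` standing in for the model's `attn_logit_softcapping` config value (illustrative, not the verbatim diff):

import torch

def softcap_logits(attn_logits: torch.Tensor, softcap: float) -> torch.Tensor:
    # Squash raw attention logits into (-softcap, softcap) with a scaled tanh.
    # This must run BEFORE the causal mask is added: tanh would otherwise map
    # the mask's -inf entries to -softcap, and masked positions would leak.
    return torch.tanh(attn_logits / softcap) * softcap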

418 changes: 187 additions & 231 deletions src/transformers/models/gemma2/modular_gemma2.py

Large diffs are not rendered by default.
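One reason `flex_attention` is now accepted (see the modeling_utils.py change above) is that soft capping can be expressed as a FlexAttention `score_mod` and fused into the attention kernel instead of forcing eager. A hedged sketch, assuming torch >= 2.5 and a hypothetical `softcap` value:

import torch
from torch.nn.attention.flex_attention import flex_attention

softcap = 50.0  # hypothetical; the model would read config.attn_logit_softcapping

def softcap_score_mod(score, batch, head, q_idx, kv_idx):
    # Applied pointwise to every attention score inside the kernel.
    return torch.tanh(score / softcap) * softcap

q = k = v = torch.randn(1, 8, 16, 64)  # (batch, heads, seq_len, head_dim)
out = flex_attention(q, k, v, score_mod=softcap_score_mod)  # typically wrapped in torch.compile for real use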

1 change: 1 addition & 0 deletions src/transformers/utils/__init__.py
@@ -209,6 +209,7 @@
is_torch_fp16_available_on_device,
is_torch_fx_available,
is_torch_fx_proxy,
+ is_torch_greater_or_equal,
is_torch_mlu_available,
is_torch_mps_available,
is_torch_musa_available,
8 changes: 8 additions & 0 deletions src/transformers/utils/import_utils.py
@@ -929,6 +929,14 @@ def is_flash_attn_greater_or_equal(library_version: str):
return version.parse(importlib.metadata.version("flash_attn")) >= version.parse(library_version)


+ @lru_cache()
+ def is_torch_greater_or_equal(library_version: str):
+     if not _is_package_available("torch"):
+         return False
+
+     return version.parse(importlib.metadata.version("torch")) >= version.parse(library_version)


def is_torchdistx_available():
return _torchdistx_available

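A usage sketch for the new helper; gating FlexAttention on torch 2.5 (the release that ships `torch.nn.attention.flex_attention`) is shown as an assumed call site:

from transformers.utils import is_torch_greater_or_equal

if is_torch_greater_or_equal("2.5"):
    from torch.nn.attention.flex_attention import flex_attention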
2 changes: 1 addition & 1 deletion tests/generation/test_utils.py
@@ -1496,7 +1496,7 @@ def _prepare_model_kwargs(input_ids, attention_mask, signature):
next_logits_with_padding = model(**model_kwargs).logits[:, -1, :]

# They should result in very similar logits
- self.assertTrue(torch.allclose(next_logits_wo_padding, next_logits_with_padding, atol=1e-5))
+ torch.testing.assert_close(next_logits_wo_padding, next_logits_with_padding, atol=1e-5, rtol=1e-5)

@pytest.mark.generate
def test_past_key_values_format(self):
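The switch from `self.assertTrue(torch.allclose(...))` to `torch.testing.assert_close` is deliberate: `allclose` collapses the comparison to a bare boolean, while `assert_close` reports the greatest absolute and relative differences on failure. Note that `assert_close` requires `rtol` and `atol` to be passed together (or neither), hence the added `rtol=1e-5`. A minimal illustration:

import torch

a = torch.tensor([1.0, 2.0])
b = torch.tensor([1.0, 2.1])
torch.allclose(a, b, atol=1e-5)  # False -- and that is all you learn
torch.testing.assert_close(a, b, atol=1e-5, rtol=1e-5)  # AssertionError listing the offending difference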
60 changes: 44 additions & 16 deletions tests/models/gemma2/test_modeling_gemma2.py
@@ -199,19 +199,6 @@ def _check_past_key_values_for_generate(self, batch_size, past_key_values, seq_l
def test_sdpa_equivalence(self):
pass

- def test_eager_attention_loaded_by_default(self):
- """Gemma 2 + SDPA = inferior results, because of the logit softcapping. Eager is the default."""
- config, _ = self.model_tester.prepare_config_and_inputs_for_common()
-
- # Usually we enable SDPA by default, but not for Gemma2
- model = Gemma2Model(config)
- self.assertTrue(model.config._attn_implementation == "eager")
-
- # We can still force SDPA
- config._attn_implementation = "sdpa"
- model = Gemma2Model(config)
- self.assertTrue(model.config._attn_implementation == "sdpa")


@slow
@require_torch_gpu
@@ -277,9 +264,30 @@ def test_model_9b_pipeline_bf16(self):
"Hi today I'm going to be talking about the history of the United States. The United States of America",
]

- model = AutoModelForCausalLM.from_pretrained(model_id, low_cpu_mem_usage=True, torch_dtype=torch.bfloat16).to(
-     torch_device
- )
+ model = AutoModelForCausalLM.from_pretrained(
+     model_id, low_cpu_mem_usage=True, torch_dtype=torch.bfloat16, attn_implementation="flex_attention"
+ ).to(torch_device)
tokenizer = AutoTokenizer.from_pretrained(model_id)
pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)

output = pipe(self.input_text, max_new_tokens=20, do_sample=False, padding=True)

self.assertEqual(output[0][0]["generated_text"], EXPECTED_TEXTS[0])
self.assertEqual(output[1][0]["generated_text"], EXPECTED_TEXTS[1])

+ @require_read_token
+ def test_model_2b_pipeline_bf16_flex_attention(self):
+     # See https://github.com/huggingface/transformers/pull/31747 -- pipeline was broken for Gemma2 before this PR
+     model_id = "google/gemma-2-2b"
+     # EXPECTED_TEXTS should match the same non-pipeline test, minus the special tokens
+     EXPECTED_TEXTS = [
+         "Hello I am doing a project on the 1960s and I am trying to find out what the average",
+         "Hi today I'm going to be talking about the 10 best anime of all time.\n\n1",
+     ]
+
+     model = AutoModelForCausalLM.from_pretrained(
+         model_id, low_cpu_mem_usage=True, torch_dtype=torch.bfloat16, attn_implementation="flex_attention"
+     ).to(torch_device)
+     tokenizer = AutoTokenizer.from_pretrained(model_id)
+     pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)

@@ -365,3 +373,23 @@ def test_export_static_cache(self):
)
ep_generated_text = tokenizer.batch_decode(ep_generated_ids, skip_special_tokens=True)
self.assertEqual(EXPECTED_TEXT_COMPLETION, ep_generated_text)

+ @require_read_token
+ def test_model_9b_bf16_flex_attention(self):
+     model_id = "google/gemma-2-9b"
+     EXPECTED_TEXTS = [
+         "<bos>Hello I am doing a project on the 1918 flu pandemic and I am trying to find out how many",
+         "<pad><pad><bos>Hi today I'm going to be talking about the history of the United States. The United States of America",
+     ]
+
+     model = AutoModelForCausalLM.from_pretrained(
+         model_id, low_cpu_mem_usage=True, torch_dtype=torch.bfloat16, attn_implementation="flex_attention"
+     ).to(torch_device)
+
+     tokenizer = AutoTokenizer.from_pretrained(model_id)
+     inputs = tokenizer(self.input_text, return_tensors="pt", padding=True).to(torch_device)
+
+     output = model.generate(**inputs, max_new_tokens=20, do_sample=False)
+     output_text = tokenizer.batch_decode(output, skip_special_tokens=False)
+
+     self.assertEqual(output_text, EXPECTED_TEXTS)
58 changes: 33 additions & 25 deletions utils/modular_model_converter.py
@@ -153,37 +153,37 @@ def __init__(self, all_bases: Set[str]):
def leave_Attribute(self, original_node: cst.Attribute, updated_node: cst.Attribute) -> cst.CSTNode:
# Handle ClassB.call_to_method
if (
- isinstance(original_node.value, cst.Name)
+ m.matches(original_node.value, m.Name())
and original_node.value.value in self.all_bases
- and isinstance(original_node.attr, cst.Name)
+ and m.matches(original_node.attr, m.Name())
):
# Replace with super().call_to_method
return updated_node.with_changes(
value=cst.Call(cst.Name("super")),
)
# Handle ClassB().call_to_method
elif (
- isinstance(original_node.value, cst.Call)
- and isinstance(original_node.value.func, cst.Name)
+ m.matches(original_node.value, m.Call())
+ and m.matches(original_node.value.func, m.Name())
and original_node.value.func.value in self.all_bases
- and isinstance(original_node.attr, cst.Name)
+ and m.matches(original_node.attr, m.Name())
):
# Replace with super().call_to_method
return updated_node.with_changes(func=cst.Attribute(value=cst.Call(func=cst.Name("super"))))
return updated_node

def leave_Call(self, original_node: cst.Call, updated_node: cst.Call) -> cst.CSTNode:
# Check if the function being called is of the form ClassB().func_a or ClassB.func_a
- if isinstance(original_node.func, cst.Attribute) and (
+ if m.matches(original_node.func, m.Attribute()) and (
# Match ClassB().func_a(...)
(
- isinstance(original_node.func.value, cst.Call)
- and isinstance(original_node.func.value.func, cst.Name)
+ m.matches(original_node.func.value, m.Call())
+ and m.matches(original_node.func.value.func, m.Name())
and original_node.func.value.func.value in self.all_bases
)
or
# Match ClassB.func_a(...)
- (isinstance(original_node.func.value, cst.Name) and original_node.func.value.value in self.all_bases)
+ (m.matches(original_node.func.value, m.Name()) and original_node.func.value.value in self.all_bases)
):
# Check if the first argument is 'self', and remove it
if len(original_node.args) > 0 and m.matches(original_node.args[0].value, m.Name("self")):
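The changes above swap hand-rolled `isinstance` chains for `libcst.matchers`, which describe a whole CST shape in one declarative pattern. A small self-contained comparison:

import libcst as cst
import libcst.matchers as m

node = cst.parse_expression("ClassB.forward")

# isinstance style: inspect each level of the tree by hand
manual = isinstance(node, cst.Attribute) and isinstance(node.value, cst.Name)

# matcher style: one pattern for the whole shape
declarative = m.matches(node, m.Attribute(value=m.Name("ClassB"), attr=m.Name("forward")))

print(manual, declarative)  # True True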
@@ -632,8 +632,10 @@ def leave_Module(self, node):
for id, node in self.global_nodes.items():
self.start_lines[id] = self.get_metadata(cst.metadata.PositionProvider, node).start.line

- # Since we added every Name as part of `self.object_dependency_mapping`, we now remove those that
- # are not part of the recorded objects (i.e. built-in variables, imports, etc)
+ def _restrict_dependencies_to_known_entities(self):
+     """Since we added every Name as part of `self.object_dependency_mapping`, we need to remove those that
+     are not part of the recorded objects in `self.global_nodes` (i.e. built-in variables, imports, etc).
+     This should be called only after all merging operations have been finalized!!"""
global_objects = set(self.global_nodes.keys())
for object_name, dependencies in self.object_dependency_mapping.items():
self.object_dependency_mapping[object_name] = {dep for dep in dependencies if dep in global_objects}
@@ -814,6 +816,8 @@ def merge_modular_dependencies(self, classes, functions, assignments, object_map
# Correctly re-set the global nodes at this point
self.global_nodes.update(self.functions)
self.global_nodes.update(self.assignments)
+ # Restrict the dependency mappings to the known entities to avoid Python's built-ins
+ self._restrict_dependencies_to_known_entities()
# Create the global mapping of recursive dependencies for functions and assignments
self.object_recursive_dependency_mapping = self._compute_recursive_object_dependencies()
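The "recursive deps need to always be computed after merges" commit above is the crux of this hunk: the transitive closure must be taken over the merged dependency graph, or edges contributed by imported files are silently dropped. A sketch of the closure computation under that reading (illustrative, not the converter's exact code):

def transitive_dependencies(direct: dict[str, set[str]]) -> dict[str, set[str]]:
    # Transitive closure of a direct-dependency graph via iterative DFS.
    closure = {}
    for name in direct:
        seen: set[str] = set()
        stack = list(direct.get(name, ()))
        while stack:
            dep = stack.pop()
            if dep not in seen:
                seen.add(dep)
                stack.extend(direct.get(dep, ()))
        closure[name] = seen
    # Computing this before merging imported modules would freeze the closure
    # too early: edges added by the merge would never be reflected.
    return closure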

@@ -1142,22 +1146,20 @@ def visit_SimpleStatementLine(self, node):
if assigned_variable == "__all__":
self.all_all_to_add = split_all_assignment(node)
else:
self.current_assignment = assigned_variable
self.assignments[assigned_variable] = node

def leave_Module(self, node):
"""When we leave the modular file, we do the following in order:
- 1. compute the nested (recursive) function and assignment dependencies
- 2. for each modeling file found in the imports, rename it with the new model name, visit it, and update
+ 1. for each modeling file found in the imports, rename it with the new model name, visit it, and update
its dependency graph with the new function and assignment definitions found in the modular
- 3. update the modular dependency graph with the imported functions and assignments (found when visiting the matching files)
+ 2. update the modular dependency graph with the imported functions and assignments (found when visiting the matching files)
+ 3. compute the nested (recursive) function and assignment dependencies
"""
# Takes care of finalizing our visit
super().leave_Module(node)

- # 1. compute the nested (recursive) function and assignment dependencies
- self.object_recursive_dependency_mapping = self._compute_recursive_object_dependencies()

- # 2. for each modeling file found in the imports, rename it with the new model name, visit it, and update dependencies
+ # 1. for each modeling file found in the imports, rename it with the new model name, visit it, and update dependencies
self.visited_modules = {}
self.renamers = {}
for file, module in self.model_specific_modules.items():
@@ -1177,10 +1179,13 @@ def leave_Module(self, node):
# We record it so that we can rename classes later the exact same way
self.renamers[file] = renamer

- # 3. in turn, we need to add the imported functions/assignments to the dependencies of the modular mapper, using the
+ # 2. in turn, we need to add the imported functions/assignments to the dependencies of the modular mapper, using the
# definitions found in the visited files
self.merge_model_specific_imports(self.visited_modules)

+ # 3. compute the nested (recursive) function and assignment dependencies
+ self.object_recursive_dependency_mapping = self._compute_recursive_object_dependencies()

# We need to keep track of which objects were imported directly into which modeling file to not add them wrongly later
# Note that we may visit several of the same file types, thus we save them per file type, not file
self.imported_objects_per_file = defaultdict(set)
Expand All @@ -1200,9 +1205,9 @@ def merge_model_specific_imports(self, visited_modules):
if object_name in visited_module.functions and object_name not in self.functions:
self.functions[object_name] = visited_module.functions[object_name]
self.added_objects_file_mapping[object_name] = file
- dependencies = visited_module.object_recursive_dependency_mapping.get(object_name, None)
+ dependencies = visited_module.object_dependency_mapping.get(object_name, None)
if dependencies is not None:
- self.object_recursive_dependency_mapping[object_name] = dependencies
+ self.object_dependency_mapping[object_name] = dependencies
for dep in dependencies:
if dep not in self.global_nodes:
self.added_objects_file_mapping[dep] = file
elif object_name in visited_module.assignments and object_name not in self.assignments:
self.assignments[object_name] = visited_module.assignments[object_name]
self.added_objects_file_mapping[object_name] = file
- dependencies = visited_module.object_recursive_dependency_mapping.get(object_name, None)
+ dependencies = visited_module.object_dependency_mapping.get(object_name, None)
if dependencies is not None:
- self.object_recursive_dependency_mapping[object_name] = dependencies
+ self.object_dependency_mapping[object_name] = dependencies
for dep in dependencies:
if dep not in self.global_nodes:
self.added_objects_file_mapping[dep] = file
self.assignments[dep] = visited_module.global_nodes[dep]

# Do not forget to re-assign all nodes after the merge
self.global_nodes = {**self.assignments, **self.classes, **self.functions}
+ # And restrict dependencies to those nodes only
+ self._restrict_dependencies_to_known_entities()

def compute_relative_order(self, missing_dependencies: set) -> dict[str, int]:
"""Compute in which relative order the `missing_dependencies` should appear when the nodes are added to the final file that
Expand All @@ -1239,10 +1246,11 @@ def compute_relative_order(self, missing_dependencies: set) -> dict[str, int]:
else:
original_dependencies.append(dep)
# Sort all lists according to the order in their respective file
- all_dependencies = sorted(original_dependencies, key=lambda x: self.start_lines[x])
+ all_dependencies = []
for file, dependencies in other_files_dependencies.items():
sorted_dependencies = sorted(dependencies, key=lambda x: self.start_lines_file_mapping[file][x])
all_dependencies += sorted_dependencies
+ all_dependencies += sorted(original_dependencies, key=lambda x: self.start_lines[x])

# Add all original node first, then merged ones (one file at a time)
for dep in all_dependencies:
@@ -1485,7 +1493,7 @@ def save_modeling_file(modular_file, converted_file):
parser = argparse.ArgumentParser()
parser.add_argument(
"--files_to_parse",
default=["src/transformers/models/gemma/modular_gemma.py"],
default=["src/transformers/models/gemma2/modular_gemma2.py"],
nargs="+",
help="A list of `modular_xxxx` files that should be converted to single model file",
)
