diff --git a/src/transformers/models/gemma2/modular_gemma2.py b/src/transformers/models/gemma2/modular_gemma2.py index e05152ace4d7ce..0192e39bea4e0e 100644 --- a/src/transformers/models/gemma2/modular_gemma2.py +++ b/src/transformers/models/gemma2/modular_gemma2.py @@ -29,7 +29,6 @@ from ...utils import ( is_flash_attn_2_available, is_flash_attn_greater_or_equal, - is_flash_attn_greater_or_equal_2_10, is_torch_greater_or_equal, logging, ) @@ -209,118 +208,183 @@ def forward(self, x): return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) -class Gemma2Attention(GemmaAttention): +def eager_attention_forward(config, query, key, value, mask): + key_states = repeat_kv(key, config.num_key_value_groups) + value_states = repeat_kv(value, config.num_key_value_groups) + + attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * config.scaling + + if config.attn_logit_softcapping is not None: + attn_weights = attn_weights / config.attn_logit_softcapping + attn_weights = torch.tanh(attn_weights) + attn_weights = attn_weights * config.attn_logit_softcapping + if mask is not None: # no matter the length, we just slice it + causal_mask = mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + # upcast attention to fp32 + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=config.attention_dropout, training=config.training) + attn_output = torch.matmul(attn_weights, value_states) + return attn_output + + + +def flash_attention_forward(config, query, key, value, mask, target_dtype=torch.float16): + if mask is not None: + seq_len = mask.shape[1] + query = query[:, :, :seq_len] + value = value[:, :, :seq_len] + + # TODO: These transpose are quite inefficient but Flash Attention requires the layout + # [batch_size, sequence_length, num_heads, head_dim]. 
We would need to refactor rotary embedding + query_states = query.transpose(1, 2) + key_states = key.transpose(1, 2) + value_states = value.transpose(1, 2) + + dropout_rate = config.attention_dropout if config.training else 0.0 + + input_dtype = query_states.dtype + if input_dtype == torch.float32: + query_states = query_states.to(target_dtype) + key_states = key_states.to(target_dtype) + value_states = value_states.to(target_dtype) + + attn_output = _flash_attention_forward( + query_states, + key_states, + value_states, + mask, + seq_len, + dropout=dropout_rate, + softmax_scale=config.scaling, + is_causal=config.is_causal, + sliding_window=config.sliding_window, + use_top_left_mask=config._flash_attn_uses_top_left_mask, + softcap=config.attn_logit_softcapping if is_flash_attn_greater_or_equal("2.6.0") else None, + ) + + return attn_output + + +def flex_attention_forward(config, query, key, value, mask, output_attentions=False, target_dtype=torch.float16): + causal_mask = mask + if mask is not None: + causal_mask = causal_mask[:, :, :, : key.shape[-2]] + + def tanh_softcap(score, b, h, q_idx, kv_idx): + soft_cap = config.attn_logit_softcapping + return soft_cap * torch.tanh(score / soft_cap) + + attn_output = flex_attention( + query, + key, + value, + block_mask=causal_mask, + score_mod=tanh_softcap, + enable_gqa=True, + scale=config.scaling, + return_lse=output_attentions, + ) + return attn_output + + +def sdpa_attention_forward(config, query, key, value, mask, output_attentions=False, target_dtype=torch.float16): + key = repeat_kv(key, config.num_key_value_groups) + value = repeat_kv(value, config.num_key_value_groups) + + causal_mask = mask + if mask is not None: + causal_mask = causal_mask[:, :, :, : key.shape[-2]] + + # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask, + # Reference: https://github.com/pytorch/pytorch/issues/112577. + if query.device.type == "cuda" and causal_mask is not None: + query = query.contiguous() + key = key.contiguous() + value = value.contiguous() + + # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment + # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. 
+ is_causal = True if causal_mask is None and query.shape[1] > 1 else False + + attn_output = torch.nn.functional.scaled_dot_product_attention( + query, + key, + value, + attn_mask=causal_mask, + dropout_p=config.attention_dropout if config.training else 0.0, + is_causal=is_causal, + scale=config.scaling, + ) + return attn_output + + +GEMMA_ATTENTION_FUNCTION = { + "flash_attention": flash_attention_forward, + "flex_attention": flex_attention_forward, + "eager": eager_attention_forward, + "sdpa": sdpa_attention_forward, +} + + +class Gemma2Attention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" def __init__(self, config: Gemma2Config, layer_idx: Optional[int] = None): - super().__init__(config, layer_idx) - self.scaling = config.query_pre_attn_scalar**-0.5 - self.sliding_window = config.sliding_window if not bool(layer_idx % 2) else None - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, - output_attentions: bool = False, - use_cache: bool = False, - cache_position: Optional[torch.LongTensor] = None, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - bsz, q_len, _ = hidden_states.size() - - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) - - query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - - cos, sin = self.rotary_emb(value_states, position_ids) - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) - - if past_key_value is not None: - # sin and cos are specific to RoPE models; cache_position needed for the static cache - cache_kwargs = { - "sin": sin, - "cos": cos, - "sliding_window": self.sliding_window, - "cache_position": cache_position, - } - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - - key_states = repeat_kv(key_states, self.num_key_value_groups) - value_states = repeat_kv(value_states, self.num_key_value_groups) + super().__init__() + self.config = config + self.layer_idx = layer_idx - attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * self.scaling + self.attention_dropout = config.attention_dropout + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = config.head_dim + self.num_key_value_heads = config.num_key_value_heads + self.num_key_value_groups = self.num_heads // self.num_key_value_heads + self.max_position_embeddings = config.max_position_embeddings + self.rope_theta = config.rope_theta + self.is_causal = True + self.scaling = 1 / math.sqrt(config.head_dim) - if self.config.attn_logit_softcapping is not None: - attn_weights = attn_weights / self.config.attn_logit_softcapping - attn_weights = torch.tanh(attn_weights) - attn_weights = attn_weights * self.config.attn_logit_softcapping - if attention_mask is not None: # no matter the length, we just slice it - causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] - attn_weights = attn_weights + causal_mask + self.scaling = config.query_pre_attn_scalar**-0.5 + self.sliding_window = config.sliding_window if not bool(layer_idx 
% 2) else None + self.attention_type = config.attn_implementation + self.attention_function = GEMMA_ATTENTION_FUNCTION[config.attn_implementation] - # upcast attention to fp32 - attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) - attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) - attn_output = torch.matmul(attn_weights, value_states) - if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): + if self.hidden_size % self.num_heads != 0: raise ValueError( - f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" - f" {attn_output.size()}" + f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" + f" and `num_heads`: {self.num_heads})." ) - attn_output = attn_output.transpose(1, 2).contiguous() - - attn_output = attn_output.view(bsz, q_len, -1) - attn_output = self.o_proj(attn_output) - - if not output_attentions: - attn_weights = None - - return attn_output, attn_weights, past_key_value - - -class Gemma2FlashAttention2(Gemma2Attention): - """ - Gemma2 flash attention module. This module inherits from `Gemma2Attention` as the weights of the module stays - untouched. The only required change would be on the forward pass where it needs to correctly call the public API of - flash attention and deal with padding tokens in case the input contains any of them. - """ - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - - # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. - # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. - # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). 
- self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() + self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias) + self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) + self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) + self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias) + self.rotary_emb = Gemma2RotaryEmbedding( + self.head_dim, + max_position_embeddings=self.max_position_embeddings, + base=self.rope_theta, + ) def forward( self, hidden_states: torch.Tensor, - attention_mask: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, output_attentions: bool = False, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - output_attentions = False - bsz, q_len, _ = hidden_states.size() query_states = self.q_proj(hidden_states) key_states = self.k_proj(hidden_states) value_states = self.v_proj(hidden_states) - # Flash attention requires the input to have the shape - # batch_size x seq_length x head_dim x hidden_dim - # therefore we just need to keep the original shape query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) @@ -338,57 +402,8 @@ def forward( } key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - if attention_mask is not None: - seq_len = attention_mask.shape[1] - key_states = key_states[:, :, :seq_len] - value_states = value_states[:, :, :seq_len] - - # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache - # to be able to avoid many of these transpose/reshape/view. - query_states = query_states.transpose(1, 2) - key_states = key_states.transpose(1, 2) - value_states = value_states.transpose(1, 2) - - dropout_rate = self.attention_dropout if self.training else 0.0 - - # In PEFT, usually we cast the layer norms in float32 for training stability reasons - # therefore the input hidden states gets silently casted in float32. Hence, we need - # cast them back in the correct dtype just to be sure everything works as expected. - # This might slowdown training & inference so it is recommended to not cast the LayerNorms - # in fp32. (Gemma2RMSNorm handles it correctly) - - input_dtype = query_states.dtype - if input_dtype == torch.float32: - if torch.is_autocast_enabled(): - target_dtype = torch.get_autocast_gpu_dtype() - # Handle the case where the model is quantized - elif hasattr(self.config, "_pre_quantization_dtype"): - target_dtype = self.config._pre_quantization_dtype - else: - target_dtype = self.q_proj.weight.dtype - - logger.warning_once( - f"The input hidden states seems to be silently casted in float32, this might be related to" - f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" - f" {target_dtype}." 
- ) - - query_states = query_states.to(target_dtype) - key_states = key_states.to(target_dtype) - value_states = value_states.to(target_dtype) - - attn_output = _flash_attention_forward( - query_states, - key_states, - value_states, - attention_mask, - q_len, - dropout=dropout_rate, - softmax_scale=self.scaling, - is_causal=self.is_causal, - sliding_window=self.sliding_window, - use_top_left_mask=self._flash_attn_uses_top_left_mask, - softcap=self.config.attn_logit_softcapping if is_flash_attn_greater_or_equal("2.6.0") else None, + attn_output = self.attention_function( + self, query_states, key_states, value_states, attention_mask, self.config ) attn_output = attn_output.reshape(bsz, q_len, -1).contiguous() @@ -399,83 +414,18 @@ def forward( return attn_output, attn_weights, past_key_value - -class Gemma2SdpaAttention(Gemma2Attention): - """ - Gemma2 attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from - `Gemma2Attention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to - SDPA API. - """ - - # Adapted from Gemma2Attention.forward - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, - output_attentions: bool = False, - use_cache: bool = False, - cache_position: Optional[torch.LongTensor] = None, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - bsz, q_len, _ = hidden_states.size() - - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) - - query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - - cos, sin = self.rotary_emb(value_states, position_ids) - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) - - if past_key_value is not None: - # sin and cos are specific to RoPE models; cache_position needed for the static cache - cache_kwargs = { - "sin": sin, - "cos": cos, - "sliding_window": self.sliding_window, - "cache_position": cache_position, - } - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - - causal_mask = attention_mask - if attention_mask is not None: - causal_mask = causal_mask[:, :, :, : key_states.shape[-2]] - - def tanh_softcap(score, b, h, q_idx, kv_idx): - soft_cap = self.config.attn_logit_softcapping - return soft_cap * torch.tanh(score / soft_cap) - - attn_output = flex_attention( - query_states, - key_states, - value_states, - block_mask=causal_mask, - score_mod=tanh_softcap, - enable_gqa=True, - scale=self.scaling, - return_lse=output_attentions, - ) - if output_attentions: - attn_output, attention_scores = attn_output - - attn_output = attn_output.transpose(1, 2).contiguous() - attn_output = attn_output.view(bsz, q_len, -1) - - attn_output = self.o_proj(attn_output) - - return attn_output, attention_scores, past_key_value - - -class Gemma2DecoderLayer(GemmaDecoderLayer): +class Gemma2DecoderLayer(nn.Module): def __init__(self, config: Gemma2Config, layer_idx: int): - super().__init__(config, layer_idx) + super().__init__() + self.hidden_size = config.hidden_size self.config = config self.is_sliding 
= not bool(layer_idx % 2) + self.self_attn = Gemma2Attention(config=config, layer_idx=layer_idx) self.mlp = Gemma2MLP(config) + self.input_layernorm = Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + self.pre_feedforward_layernorm = Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.post_feedforward_layernorm = Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.sliding_window = config.sliding_window @@ -541,20 +491,6 @@ def forward( class Gemma2PreTrainedModel(GemmaPreTrainedModel): _supports_quantized_cache = False - @classmethod - def _check_and_enable_sdpa(cls, config, hard_check_only: bool = False): - """ - Overloads `PreTrainedModel._check_and_enable_sdpa` so as to DISABLE torch SDPA by default on Gemma2 models. - SDPA reduces the model performance on Gemma2 because of the logits softcapping. - """ - config = super()._check_and_enable_sdpa(config, hard_check_only=hard_check_only) - - # if using the default path -> swap sdpa by eager - if not hard_check_only and config._attn_implementation == "sdpa": - config._attn_implementation = "eager" - - return config - class Gemma2Model(GemmaModel, Gemma2PreTrainedModel): def __init__(self, config: Gemma2Config): diff --git a/utils/modular_model_converter.py b/utils/modular_model_converter.py index 95a21affbd7b8f..e226041f389185 100644 --- a/utils/modular_model_converter.py +++ b/utils/modular_model_converter.py @@ -714,7 +714,7 @@ class PostModularConverterCleaner(m.MatcherDecoratableTransformer): METADATA_DEPENDENCIES = (ParentNodeProvider,) - def __init__(self, added_dependencies: set, unused_imports:Dict[Union[cst.Import, cst.ImportFrom], Set[str]]): + def __init__(self, added_dependencies: set, unused_imports: Dict[Union[cst.Import, cst.ImportFrom], Set[str]]): super().__init__() self.top_level_functions_or_classes = {} self.all_used_functions_or_classes = set() @@ -755,8 +755,7 @@ def leave_Module(self, original_node: cst.Module, node): # Return a new module with the updated body return node.with_changes(body=new_body) - - def leave_If(self,original_node: cst.If,updated_node: cst.If): + def leave_If(self, original_node: cst.If, updated_node: cst.If): for stmt in original_node.body.body: if m.matches(stmt, m.SimpleStatementLine(body=[m.ImportFrom() | m.Import()])): if len(updated_node.body.body) == 0: @@ -776,8 +775,10 @@ def leave_import_alike(self, original_node, updated_node): return updated_node.with_changes(names=names_to_keep) - def get_unused_imports(source): + r""" + You have to use `isinstance` on assignments; `m.matches` apparently does not work here yet. + """ wrapper = cst.metadata.MetadataWrapper(source) scopes = set(wrapper.resolve(cst.metadata.ScopeProvider).values()) unused_imports: Dict[Union[cst.Import, cst.ImportFrom], Set[str]] = defaultdict(set) @@ -785,15 +786,18 @@ def get_unused_imports(source): for scope in scopes: for assignment in scope.assignments: node = assignment.node - if isinstance(assignment, cst.metadata.Assignment) and isinstance( - node, (cst.Import, cst.ImportFrom) - ): + if isinstance(assignment, cst.metadata.Assignment) and isinstance(node, (cst.Import, cst.ImportFrom)): if len(assignment.references) == 0: unused_imports[assignment.name].add(node) location = ranges[node].start print( f"Warning on line {location.line:2d}, column {location.column:2d}: Imported name `{assignment.name}` is unused."
) + if isinstance(scope, cst.metadata.GlobalScope): + for assignment in scope.assignments: + node = assignment.node + if len(assignment.references) == 0: + print(f"Warning: {assignment.name} is never referenced") return unused_imports @@ -814,6 +818,7 @@ def __init__(self, python_module, new_name, given_old_name=None, given_new_name= self.visited_module = {} # modules visited like "transformers.models.llama.modeling_llama" self.inserted_deps = [] # nodes inserted via super dependency self.all_imports = {} # just stores all of the imports + self.all_safe_imports = {} # stores the safe imports to place them at the end self.global_scope_index = 0 # fmt: on self.files = { # mapping for different component bodies @@ -834,6 +839,7 @@ def __init__(self, python_module, new_name, given_old_name=None, given_new_name= # Mapping from top-level functions to other top-level functions dependencies self.function_call_dependency_mapping = defaultdict(set) self.added_dependencies = set() + self.original_nodes: Dict[str, cst.ClassDef] = {} # Stores the original class def nodes def visit_ImportFrom(self, node: cst.ImportFrom) -> None: """When visiting imports from `transformers.models.xxx` we need to: @@ -867,6 +873,14 @@ def visit_ImportFrom(self, node: cst.ImportFrom) -> None: f"You are importing from {import_statement} directly using global imports. Import from the correct local path" ) + def leave_If(self, original_node, node): + parent_node = self.get_metadata(cst.metadata.ParentNodeProvider, original_node) + if m.matches(parent_node, m.Module()): + for k in node.body.body[0].body[0].names: + import_name = self.python_module.code_for_node(k.name) + self.all_safe_imports[import_name] = node + return node + def leave_SimpleStatementLine(self, original_node, updated_node): parent_node = self.get_metadata(cst.metadata.ParentNodeProvider, original_node) if m.matches(parent_node, m.Module()): @@ -914,6 +928,7 @@ def leave_ClassDef(self, original_node, updated_node): 3.
Replace the calls to `super().xxxx` merging parent code """ class_name = original_node.name.value + self.original_nodes[class_name] = original_node bases = [k.value.value for k in original_node.bases if k.value.value in self.imported_mapping] all_bases = [k.value.value for k in original_node.bases] self.global_scope_index += 100 @@ -997,7 +1012,6 @@ def leave_ClassDef(self, original_node, updated_node): list_dependencies = sorted(list_dependencies.items(), key=lambda x: x[1], reverse=True) start_insert_idx = self.global_scope_index file_to_update = self.files[file_type] - is_empty_node = self.python_module.code_for_node(original_node.body) == "pass\n" for dependency, _ in list_dependencies: # we can write to the correct body, using the source of the parent class node = class_finder.global_nodes.get(dependency, None) @@ -1008,6 +1022,8 @@ def leave_ClassDef(self, original_node, updated_node): file_to_update[dependency] = {"insert_idx": start_insert_idx, "node": node} self.added_dependencies.add(dependency) elif dependency not in self.inserted_deps: + # if the dependency is defined in the modular file, but is just `pass` + is_empty_node = self.python_module.code_for_node(self.original_nodes[dependency].body).strip(" \n") == "pass" # make sure the node is written after its dependencies start_insert_idx = file_to_update[dependency]["insert_idx"] - 1 if ( @@ -1075,14 +1091,6 @@ def visit_Assign(self, node: cst.Assign) -> None: "node": updated_node, } - def leave_If(self, original_node, node): - parent_node = self.get_metadata(cst.metadata.ParentNodeProvider, original_node) - if m.matches(parent_node, m.Module()): - for k in node.body.body[0].body[0].names: - import_name = self.python_module.code_for_node(k.name) - self.all_imports[import_name] = node - return node - def visit_Call(self, node: cst.Call): """This is used to create a mapping from functions to class calling them, and from top-level functions to functions called inside them. Important note: we only rely on direct Call to the functions here, not indirect mentions (such as assigning a variable with the function, @@ -1154,23 +1162,30 @@ def _recursively_add_all_new_needed_functions_in_files(self): ) def leave_Module(self, original_node: cst.Module, node): - imports = {self.python_module.code_for_node(k): k for k in self.all_imports.values()} - dependency_imports = {file_type: imports.copy() for file_type in self.files} + all_imports = list(self.all_imports.values()) + all_imports_keys = {self.python_module.code_for_node(k) for k in self.all_imports.values()} + dependency_imports = {file_type: all_imports.copy() for file_type in self.files} for super_file_name, visiter in self.visited_module.items(): file_type = re.search(r"models?\.\w*?\.(\w*?)_", super_file_name).groups()[0] - dependency_imports[file_type].update( - {self.python_module.code_for_node(k): k for k in visiter.imports.values()} - ) + dependency_imports[file_type] += [ + k for k in visiter.imports.values() if self.python_module.code_for_node(k) not in all_imports_keys + ] + all_imports_keys.update({self.python_module.code_for_node(k) for k in dependency_imports[file_type]}) + dependency_imports[file_type] += [ + k + for k in self.all_safe_imports.values() + if self.python_module.code_for_node(k) not in all_imports_keys + ] # Check if any new top-level function from the `modular_xxx.py` should be added to the different files # (if it is called in a class in the file, then it will be copy pasted from `modular.py` to that file). 
- self._recursively_add_all_new_needed_functions_in_files() + # self._recursively_add_all_new_needed_functions_in_files() for file, body in self.files.items(): new_body = [k[1]["node"] for k in sorted(body.items(), key=lambda x: x[1]["insert_idx"])] if len(new_body) > 0: if file in dependency_imports.keys(): - new_body = list(dependency_imports[file].values()) + new_body + new_body = dependency_imports[file] + new_body new_module = cst.Module(body=[*new_body], header=node.header) # Final cleanup unused_imports = get_unused_imports(new_module)
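For readers following the Gemma2 side of this patch, the sketch below is illustrative only (it is not code from the PR; names such as `eager_softcapped_attention` and `ATTENTION_FUNCTIONS` are made up here). It shows the two ideas the refactor combines: attention backends become plain functions looked up from a dict keyed by the configured implementation, and the eager path applies tanh logit softcapping to the attention logits before the softmax.

```python
import math

import torch


def eager_softcapped_attention(query, key, value, scaling, softcap=None, mask=None):
    # query/key/value: (batch, num_heads, seq_len, head_dim)
    attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
    if softcap is not None:
        # tanh soft-capping bounds the logits to (-softcap, softcap)
        # instead of hard-clipping them.
        attn_weights = softcap * torch.tanh(attn_weights / softcap)
    if mask is not None:
        attn_weights = attn_weights + mask[:, :, :, : key.shape[-2]]
    attn_weights = torch.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
    return torch.matmul(attn_weights, value)


# Dispatch table keyed by the configured attention implementation,
# mirroring the role of GEMMA_ATTENTION_FUNCTION in the diff above.
ATTENTION_FUNCTIONS = {"eager": eager_softcapped_attention}

q = k = v = torch.randn(1, 8, 4, 16)
out = ATTENTION_FUNCTIONS["eager"](q, k, v, scaling=1 / math.sqrt(16), softcap=50.0)
print(out.shape)  # torch.Size([1, 8, 4, 16])
```

The real `GEMMA_ATTENTION_FUNCTION` entries additionally handle grouped-query KV repetition, dropout, sliding-window masks and dtype casting, which are omitted here for brevity.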
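On the converter side, the unused-import handling in `get_unused_imports` builds on libcst scope metadata. The following standalone example is again a sketch of that mechanism rather than the converter code itself: resolve the scopes, then flag import assignments that have no references.

```python
import libcst as cst
from libcst.metadata import Assignment, MetadataWrapper, ScopeProvider

# A module with one unused import (`os`) and one used import (`sys`).
source = cst.parse_module(
    "import os\n"
    "import sys\n"
    "print(sys.argv)\n"
)
wrapper = MetadataWrapper(source)
scopes = set(wrapper.resolve(ScopeProvider).values())

for scope in scopes:
    for assignment in scope.assignments:
        node = assignment.node
        # Only import bindings matter here; note the `isinstance` check on the
        # assignment itself, as the docstring added in the diff points out.
        if isinstance(assignment, Assignment) and isinstance(node, (cst.Import, cst.ImportFrom)):
            if len(assignment.references) == 0:
                print(f"`{assignment.name}` is imported but never used")  # -> `os`
```

Note that `assignment.references` is a collection, which is why the emptiness check has to be `len(assignment.references) == 0` rather than comparing the collection itself to 0.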