diff --git a/ai_edge_torch/generative/examples/stable_diffusion/clip.py b/ai_edge_torch/generative/examples/stable_diffusion/clip.py index e929c701..4a109a40 100644 --- a/ai_edge_torch/generative/examples/stable_diffusion/clip.py +++ b/ai_edge_torch/generative/examples/stable_diffusion/clip.py @@ -15,65 +15,99 @@ import torch from torch import nn -from torch._prims_common import mask_tensor -from torch._prims_common.wrappers import out_wrapper -from ai_edge_torch.generative.examples.stable_diffusion.attention import SelfAttention # NOQA +from ai_edge_torch.generative.layers.attention import TransformerBlock +import ai_edge_torch.generative.layers.attention_utils as attention_utils +import ai_edge_torch.generative.layers.builder as builder +import ai_edge_torch.generative.layers.model_config as cfg +import ai_edge_torch.generative.utilities.loader as loading_utils + +TENSOR_NAMES = loading_utils.ModelLoader.TensorNames( + ff_up_proj="layers.{}.linear_1", + ff_down_proj="layers.{}.linear_2", + ff_gate_proj="layers.{}.linear_1", + attn_fused_qkv_proj="layers.{}.attention.in_proj", + attn_output_proj="layers.{}.attention.out_proj", + pre_attn_norm="layers.{}.layernorm_1", + pre_ff_norm="layers.{}.layernorm_2", + embedding="embedding.token_embedding", + embedding_position="embedding.position_value", + final_norm="layernorm", + lm_head=None, +) -class CLIPEmbedding(nn.Module): - - def __init__(self, n_vocab: int, n_embd: int, n_token: int): - super().__init__() - self.token_embedding = nn.Embedding(n_vocab, n_embd) - self.position_value = nn.Parameter(torch.zeros((n_token, n_embd))) - - def forward(self, tokens): - x = self.token_embedding(tokens) - x += self.position_value - return x - - -class CLIPLayer(nn.Module): +class CLIP(nn.Module): + """CLIP text encoder + For details, see https://arxiv.org/abs/2103.00020 + """ - def __init__(self, n_head: int, n_embd: int): + def __init__(self, config: cfg.ModelConfig): super().__init__() - self.layernorm_1 = nn.LayerNorm(n_embd) - self.attention = SelfAttention(n_head, n_embd) - self.layernorm_2 = nn.LayerNorm(n_embd) - self.linear_1 = nn.Linear(n_embd, 4 * n_embd) - self.linear_2 = nn.Linear(4 * n_embd, n_embd) - - def forward(self, x): - residue = x - x = self.layernorm_1(x) - x = self.attention(x, causal_mask=True) - x += residue + self.tok_embedding = nn.Embedding(config.vocab_size, config.embedding_dim) + self.tok_embedding_position = nn.Parameter( + torch.zeros((config.max_seq_len, config.embedding_dim)) + ) - residue = x - x = self.layernorm_2(x) - x = self.linear_1(x) - x = x * torch.sigmoid(1.702 * x) # QuickGELU activation function - x = self.linear_2(x) - x += residue + self.config = config + self.transformer_blocks = nn.ModuleList( + TransformerBlock(config) for _ in range(config.num_layers) + ) + self.final_norm = builder.build_norm(config.embedding_dim, config.final_norm_config) - return x - - -class CLIP(nn.Module): - - def __init__(self): - super().__init__() - self.embedding = CLIPEmbedding(49408, 768, 77) - self.layers = nn.ModuleList([CLIPLayer(12, 768) for i in range(12)]) - self.layernorm = nn.LayerNorm(768) + self.mask_cache = attention_utils.build_causal_mask_cache( + size=config.max_seq_len, dtype=torch.float32 + ) @torch.inference_mode def forward(self, tokens: torch.LongTensor) -> torch.FloatTensor: tokens = tokens.type(torch.long) - state = self.embedding(tokens) - for layer in self.layers: - state = layer(state) - output = self.layernorm(state) + state = self.tok_embedding(tokens) + self.tok_embedding_position + for layer in self.transformer_blocks: + state = layer(state, mask=self.mask_cache) + output = self.final_norm(state) return output + + +def get_model_config() -> cfg.ModelConfig: + max_seq_len = 77 + vocab_size = 49408 + num_layers = 12 + num_heads = 12 + num_query_groups = 12 + embedding_dim = 768 + + attn_config = cfg.AttentionConfig( + num_heads=num_heads, + num_query_groups=num_query_groups, + rotary_percentage=0.0, + qkv_use_bias=True, + qkv_transpose_before_split=True, + output_proj_use_bias=True, + enable_kv_cache=False, + ) + + ff_config = cfg.FeedForwardConfig( + type=cfg.FeedForwardType.SEQUENTIAL, + activation=cfg.ActivationType.GELU_QUICK, + intermediate_size=embedding_dim * 4, + use_bias=True, + ) + + norm_config = cfg.NormalizationConfig(type=cfg.NormalizationType.LAYER_NORM) + + config = cfg.ModelConfig( + vocab_size=vocab_size, + num_layers=num_layers, + max_seq_len=max_seq_len, + embedding_dim=embedding_dim, + attn_config=attn_config, + ff_config=ff_config, + pre_attention_norm_config=norm_config, + pre_ff_norm_config=norm_config, + final_norm_config=norm_config, + enable_hlfb=True, + ) + + return config diff --git a/ai_edge_torch/generative/examples/stable_diffusion/convert_to_tflite.py b/ai_edge_torch/generative/examples/stable_diffusion/convert_to_tflite.py index bb1b4108..318c15c6 100644 --- a/ai_edge_torch/generative/examples/stable_diffusion/convert_to_tflite.py +++ b/ai_edge_torch/generative/examples/stable_diffusion/convert_to_tflite.py @@ -19,11 +19,12 @@ import torch import ai_edge_torch -from ai_edge_torch.generative.examples.stable_diffusion.clip import CLIP +import ai_edge_torch.generative.examples.stable_diffusion.clip as clip from ai_edge_torch.generative.examples.stable_diffusion.decoder import Decoder from ai_edge_torch.generative.examples.stable_diffusion.diffusion import Diffusion # NOQA from ai_edge_torch.generative.examples.stable_diffusion.encoder import Encoder import ai_edge_torch.generative.examples.stable_diffusion.util as util +import ai_edge_torch.generative.utilities.loader as loading_utils @torch.inference_mode @@ -36,8 +37,9 @@ def convert_stable_diffusion_to_tflite( image_width: int = 512, ): - clip = CLIP() - clip.load_state_dict(torch.load(clip_ckpt_path)) + clip_model = clip.CLIP(clip.get_model_config()) + loader = loading_utils.ModelLoader(clip_ckpt_path, clip.TENSOR_NAMES) + loader.load(clip_model, strict=False) encoder = Encoder() encoder.load_state_dict(torch.load(encoder_ckpt_path)) @@ -59,13 +61,13 @@ def convert_stable_diffusion_to_tflite( ) input_latents = encoder(input_image, noise) - context_cond = clip(prompt_tokens) + context_cond = clip_model(prompt_tokens) context_uncond = torch.zeros_like(context_cond) context = torch.cat([context_cond, context_uncond], axis=0) time_embedding = util.get_time_embedding(timestamp) # CLIP text encoder - ai_edge_torch.signature('encode', clip, (prompt_tokens,)).convert().export( + ai_edge_torch.signature('encode', clip_model, (prompt_tokens,)).convert().export( '/tmp/stable_diffusion/clip.tflite' ) diff --git a/ai_edge_torch/generative/examples/stable_diffusion/diffusion.py b/ai_edge_torch/generative/examples/stable_diffusion/diffusion.py index 2992f3c3..be8ee0e2 100644 --- a/ai_edge_torch/generative/examples/stable_diffusion/diffusion.py +++ b/ai_edge_torch/generative/examples/stable_diffusion/diffusion.py @@ -202,11 +202,6 @@ def forward(self, x, context, time): x = self.bottleneck(x, context, time) - # print('x shape:') - # print(list(x.shape)) - # print('time shape:') - # print(list(time.shape)) - for layers in self.decoders: x = torch.cat((x, skip_connections.pop()), dim=1) x = layers(x, context, time) @@ -214,199 +209,6 @@ def forward(self, x, context, time): return x -# The encoder component. -class UNetEncoder(nn.Module): - - def __init__(self): - super().__init__() - self.time_embedding = TimeEmbedding(320) - self.encoders = nn.ModuleList( - [ - SwitchSequential(nn.Conv2d(4, 320, kernel_size=3, padding=1)), - SwitchSequential(ResidualBlock(320, 320), AttentionBlock(8, 40)), - SwitchSequential(ResidualBlock(320, 320), AttentionBlock(8, 40)), - SwitchSequential(nn.Conv2d(320, 320, kernel_size=3, stride=2, padding=1)), - SwitchSequential(ResidualBlock(320, 640), AttentionBlock(8, 80)), - SwitchSequential(ResidualBlock(640, 640), AttentionBlock(8, 80)), - SwitchSequential(nn.Conv2d(640, 640, kernel_size=3, stride=2, padding=1)), - SwitchSequential(ResidualBlock(640, 1280), AttentionBlock(8, 160)), - SwitchSequential(ResidualBlock(1280, 1280), AttentionBlock(8, 160)), - SwitchSequential(nn.Conv2d(1280, 1280, kernel_size=3, stride=2, padding=1)), - SwitchSequential(ResidualBlock(1280, 1280)), - SwitchSequential(ResidualBlock(1280, 1280)), - ] - ) - - def forward(self, x, context, time): - time_embedding = self.time_embedding(time) - skip_connections = [] - for layers in self.encoders: - x = layers(x, context, time_embedding) - skip_connections.append(x) - - return x, skip_connections, time_embedding - - -class UNetBottleNeck(nn.Module): - - def __init__(self): - super().__init__() - self.bottleneck = SwitchSequential( - ResidualBlock(1280, 1280), - AttentionBlock(8, 160), - ResidualBlock(1280, 1280), - ) - - def forward(self, x, context, time): - x = self.bottleneck(x, context, time) - # print('shape') - # print(list(x.shape)) - return x - - -# Unet decoder. -class UNetDecoder1(nn.Module): - - def __init__(self): - super().__init__() - self.decoders = nn.ModuleList( - [ - SwitchSequential(ResidualBlock(2560, 1280)), - SwitchSequential(ResidualBlock(2560, 1280)), - SwitchSequential(ResidualBlock(2560, 1280), Upsample(1280)), - SwitchSequential(ResidualBlock(2560, 1280), AttentionBlock(8, 160)), - ] - ) - - def forward(self, x, context, time, s9, s10, s11, s12): - x = torch.cat((x, s12), dim=1) - x = self.decoders[0](x, context, time) - x = torch.cat((x, s11), dim=1) - x = self.decoders[1](x, context, time) - x = torch.cat((x, s10), dim=1) - x = self.decoders[2](x, context, time) - x = torch.cat((x, s9), dim=1) - x = self.decoders[3](x, context, time) - - return x - - -class UNetDecoder2(nn.Module): - - def __init__(self): - super().__init__() - self.decoders = nn.ModuleList( - [ - SwitchSequential(ResidualBlock(2560, 1280), AttentionBlock(8, 160)), - SwitchSequential( - ResidualBlock(1920, 1280), AttentionBlock(8, 160), Upsample(1280) - ), - SwitchSequential(ResidualBlock(1920, 640), AttentionBlock(8, 80)), - SwitchSequential(ResidualBlock(1280, 640), AttentionBlock(8, 80)), - ] - ) - - def forward(self, x, context, time, s5, s6, s7, s8): - x = torch.cat((x, s8), dim=1) - x = self.decoders[0](x, context, time) - x = torch.cat((x, s7), dim=1) - x = self.decoders[1](x, context, time) - x = torch.cat((x, s6), dim=1) - x = self.decoders[2](x, context, time) - x = torch.cat((x, s5), dim=1) - x = self.decoders[3](x, context, time) - return x - - -class UNetDecoder3(nn.Module): - - def __init__(self): - super().__init__() - self.decoders = nn.ModuleList( - [ - SwitchSequential( - ResidualBlock(960, 640), AttentionBlock(8, 80), Upsample(640) - ), - SwitchSequential(ResidualBlock(960, 320), AttentionBlock(8, 40)), - SwitchSequential(ResidualBlock(640, 320), AttentionBlock(8, 40)), - SwitchSequential(ResidualBlock(640, 320), AttentionBlock(8, 40)), - ] - ) - self.final = FinalLayer(320, 4) - - def forward(self, x, context, time, s1, s2, s3, s4): - x = torch.cat((x, s4), dim=1) - x = self.decoders[0](x, context, time) - x = torch.cat((x, s3), dim=1) - x = self.decoders[1](x, context, time) - x = torch.cat((x, s2), dim=1) - x = self.decoders[2](x, context, time) - x = torch.cat((x, s1), dim=1) - x = self.decoders[3](x, context, time) - - x = self.final(x) - return x - - -class UNetDecoder(nn.Module): - - def __init__(self): - super().__init__() - self.decoders = nn.ModuleList( - [ - SwitchSequential(ResidualBlock(2560, 1280)), - SwitchSequential(ResidualBlock(2560, 1280)), - SwitchSequential(ResidualBlock(2560, 1280), Upsample(1280)), - SwitchSequential(ResidualBlock(2560, 1280), AttentionBlock(8, 160)), - SwitchSequential(ResidualBlock(2560, 1280), AttentionBlock(8, 160)), - SwitchSequential( - ResidualBlock(1920, 1280), AttentionBlock(8, 160), Upsample(1280) - ), - SwitchSequential(ResidualBlock(1920, 640), AttentionBlock(8, 80)), - SwitchSequential(ResidualBlock(1280, 640), AttentionBlock(8, 80)), - SwitchSequential( - ResidualBlock(960, 640), AttentionBlock(8, 80), Upsample(640) - ), - SwitchSequential(ResidualBlock(960, 320), AttentionBlock(8, 40)), - SwitchSequential(ResidualBlock(640, 320), AttentionBlock(8, 40)), - SwitchSequential(ResidualBlock(640, 320), AttentionBlock(8, 40)), - ] - ) - self.final = FinalLayer(320, 4) - - def forward( - self, x, context, time, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11, s12 - ): - x = torch.cat((x, s12), dim=1) - x = self.decoders[0](x, context, time) - x = torch.cat((x, s11), dim=1) - x = self.decoders[1](x, context, time) - x = torch.cat((x, s10), dim=1) - x = self.decoders[2](x, context, time) - x = torch.cat((x, s9), dim=1) - x = self.decoders[3](x, context, time) - x = torch.cat((x, s8), dim=1) - x = self.decoders[4](x, context, time) - x = torch.cat((x, s7), dim=1) - x = self.decoders[5](x, context, time) - x = torch.cat((x, s6), dim=1) - x = self.decoders[6](x, context, time) - x = torch.cat((x, s5), dim=1) - x = self.decoders[7](x, context, time) - x = torch.cat((x, s4), dim=1) - x = self.decoders[0](x, context, time) - x = torch.cat((x, s3), dim=1) - x = self.decoders[1](x, context, time) - x = torch.cat((x, s2), dim=1) - x = self.decoders[2](x, context, time) - x = torch.cat((x, s1), dim=1) - x = self.decoders[3](x, context, time) - - x = self.final(x) - - return x - - class FinalLayer(nn.Module): def __init__(self, in_channels, out_channels): @@ -432,68 +234,6 @@ def __init__(self): @torch.inference_mode def forward(self, latent, context, time): time = self.time_embedding(time) - # print('time:') - # print(list(time.shape)) output = self.unet(latent, context, time) output = self.final(output) return output - - -# Calling code as if Diffusion is splitted into two parts. -class DiffusionSplitted(nn.Module): - - def __init__(self): - super().__init__() - self.unet_encoder = UNetEncoder() - self.bottleneck = UNetBottleNeck() - self.unet_decoder1 = UNetDecoder1() - self.unet_decoder2 = UNetDecoder2() - self.unet_decoder3 = UNetDecoder3() - - def get_skip_connections(self, latent, context, time): - _, skip_connections, _ = self.unet_encoder(latent, context, time) - return skip_connections - - def forward(self, latent, context, time): - output, skip_connections, time = self.unet_encoder(latent, context, time) - # print("output shape of unet encoder...") - # print(list(output.shape)) - # print("output shape of time...") - # print(list(time.shape)) - output = self.bottleneck(output, context, time) - # print("output shape of bn") - # print(list(output.shape)) - output = self.unet_decoder1( - output, - context, - time, - skip_connections[8], - skip_connections[9], - skip_connections[10], - skip_connections[11], - ) - # print("output shape of d1:") - # print(list(output.shape)) - - output = self.unet_decoder2( - output, - context, - time, - skip_connections[4], - skip_connections[5], - skip_connections[6], - skip_connections[7], - ) - - # print("output shape of d2:") - # print(list(output.shape)) - output = self.unet_decoder3( - output, - context, - time, - skip_connections[0], - skip_connections[1], - skip_connections[2], - skip_connections[3], - ) - return output diff --git a/ai_edge_torch/generative/layers/attention.py b/ai_edge_torch/generative/layers/attention.py index d161c3fa..6c320a0d 100644 --- a/ai_edge_torch/generative/layers/attention.py +++ b/ai_edge_torch/generative/layers/attention.py @@ -57,7 +57,7 @@ def __init__(self, config: cfg.ModelConfig) -> None: def forward( self, x: torch.Tensor, - rope: Tuple[torch.Tensor, torch.Tensor], + rope: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, mask: Optional[torch.Tensor] = None, input_pos: Optional[torch.Tensor] = None, ) -> torch.Tensor: @@ -134,7 +134,7 @@ def __init__( def forward( self, x: torch.Tensor, - rope: Tuple[torch.Tensor, torch.Tensor], + rope: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, mask: Optional[torch.Tensor] = None, input_pos: Optional[torch.Tensor] = None, ) -> torch.Tensor: @@ -159,28 +159,35 @@ def forward( # Assemble into a number of query groups to support MHA, MQA and GQA. q_per_kv = self.config.num_heads // self.config.num_query_groups total_qkv = q_per_kv + 2 # Each group has >=1 queries, 1 key, and 1 value. - qkv = qkv.view( - B, T, self.config.num_query_groups, total_qkv, self.head_dim - ) # (B, T, num_query_groups, total_qkv, head_dim) + if self.config.qkv_transpose_before_split: + qkv = qkv.view( + B, T, total_qkv, self.config.num_query_groups, self.head_dim + ) # (B, T, total_qkv, num_query_groups, head_dim) + qkv_axis = -3 + else: + qkv = qkv.view( + B, T, self.config.num_query_groups, total_qkv, self.head_dim + ) # (B, T, num_query_groups, total_qkv, head_dim) + qkv_axis = -2 # Split batched computation into three. - q, k, v = qkv.split((q_per_kv, 1, 1), dim=-2) - + q, k, v = qkv.split((q_per_kv, 1, 1), dim=qkv_axis) q = q.reshape(B, T, -1, self.head_dim) k = k.reshape(B, T, -1, self.head_dim) v = v.reshape(B, T, -1, self.head_dim) # Compute rotary positional embedding for query and key. n_elem = int(self.config.rotary_percentage * self.head_dim) - cos, sin = rope - q_roped = rotary_pos_emb.apply_rope( - q[..., :n_elem], cos.repeat(1, 2), sin.repeat(1, 2) - ) - k_roped = rotary_pos_emb.apply_rope( - k[..., :n_elem], cos.repeat(1, 2), sin.repeat(1, 2) - ) - q = torch.cat((q_roped, q[..., n_elem:]), dim=-1) - k = torch.cat((k_roped, k[..., n_elem:]), dim=-1) + if n_elem > 0: + cos, sin = rope + q_roped = rotary_pos_emb.apply_rope( + q[..., :n_elem], cos.repeat(1, 2), sin.repeat(1, 2) + ) + k_roped = rotary_pos_emb.apply_rope( + k[..., :n_elem], cos.repeat(1, 2), sin.repeat(1, 2) + ) + q = torch.cat((q_roped, q[..., n_elem:]), dim=-1) + k = torch.cat((k_roped, k[..., n_elem:]), dim=-1) if self.kv_cache is not None: # TODO(haoliang): Handle when execeeding max sequence length. diff --git a/ai_edge_torch/generative/layers/builder.py b/ai_edge_torch/generative/layers/builder.py index 6b12a274..55720b3a 100644 --- a/ai_edge_torch/generative/layers/builder.py +++ b/ai_edge_torch/generative/layers/builder.py @@ -97,6 +97,10 @@ def _get_activation(type_: cfg.ActivationType): return F.gelu elif type_ == cfg.ActivationType.GELU_TANH: return lambda x: F.gelu(x, approximate="tanh") + elif type_ == cfg.ActivationType.GELU_QUICK: + # GELU approximation that is fast but somewhat inaccurate. + # See: https://github.com/hendrycks/GELUs + return lambda x: x * F.sigmoid(1.702 * x) elif type_ == cfg.ActivationType.RELU: return F.relu else: diff --git a/ai_edge_torch/generative/layers/model_config.py b/ai_edge_torch/generative/layers/model_config.py index f8796bc8..59b5fde1 100644 --- a/ai_edge_torch/generative/layers/model_config.py +++ b/ai_edge_torch/generative/layers/model_config.py @@ -27,6 +27,7 @@ class ActivationType(enum.Enum): SILU = enum.auto() GELU = enum.auto() GELU_TANH = enum.auto() + GELU_QUICK = enum.auto() RELU = enum.auto() @@ -46,7 +47,7 @@ class FeedForwardType(enum.Enum): # `output = linear(act(linear(x)))`. SEQUENTIAL = enum.auto() - # `output = linear(act(linear(x)) * lienar(x))`. + # `output = linear_2(act(linear_1(x)) * lienar_3(x))`. GATED = enum.auto() @@ -60,6 +61,9 @@ class AttentionConfig: num_query_groups: Optional[int] # Percentage of Rotary Positional Embedding added Q and K projections. rotary_percentage: Optional[float] = None + # Whether to transpose the query groups of qkv bundled tensor before + # splitting into separated tensors. + qkv_transpose_before_split: bool = False # Whether to use bias with Query, Key, and Value projection. qkv_use_bias: bool = False # Whether to use bias with attention output projection. diff --git a/ai_edge_torch/generative/utilities/loader.py b/ai_edge_torch/generative/utilities/loader.py index 020f2489..a1280773 100644 --- a/ai_edge_torch/generative/utilities/loader.py +++ b/ai_edge_torch/generative/utilities/loader.py @@ -69,10 +69,16 @@ def load_pytorch_statedict(full_path: str): Raises: ValueError: If no tensors are loaded from the provided directory or file. """ - pattern = os.path.join(full_path, "*.bin") if os.path.isdir(full_path) else full_path files = [] - for file in glob.glob(pattern): - files.append(file) + patterns = [] + if os.path.isdir(full_path): + patterns.append(os.path.join(full_path, "*.bin")) + patterns.append(os.path.join(full_path, "*.pt")) + else: + patterns.append(full_path) + for pattern in patterns: + for file in glob.glob(pattern): + files.append(file) tensors = {} for file in files: @@ -93,18 +99,20 @@ class ModelLoader: @dataclass class TensorNames: - attn_query_proj: str - attn_key_proj: str - attn_value_proj: str - attn_output_proj: str - - ff_up_proj: str - ff_down_proj: str + attn_query_proj: str = None + attn_key_proj: str = None + attn_value_proj: str = None + attn_fused_qkv_proj: str = None + attn_output_proj: str = None + + ff_up_proj: str = None + ff_down_proj: str = None ff_gate_proj: str = None pre_attn_norm: str = None pre_ff_norm: str = None embedding: str = None + embedding_position: str = None final_norm: str = None lm_head: str = None @@ -129,6 +137,10 @@ def load(self, model: torch.nn.Module, strict: bool = True): strict (bool, optional): Whether the converted keys are strictly matched. Defaults to True. + Returns: + missing_keys (List[str]): a list of str containing the missing keys + unexpected_keys (List[str]): a list of str containing the unexpected keys + Raises: ValueError: If conversion results in unmapped tensors and strict mode is enabled. @@ -139,6 +151,10 @@ def load(self, model: torch.nn.Module, strict: bool = True): converted_state["tok_embedding.weight"] = state.pop( f"{self._names.embedding}.weight" ) + if self._names.embedding_position is not None: + converted_state["tok_embedding_position"] = state.pop( + f"{self._names.embedding_position}" + ) if self._names.lm_head is not None: converted_state["lm_head.weight"] = state.pop(f"{self._names.lm_head}.weight") if model.config.lm_head_use_bias: @@ -158,7 +174,7 @@ def load(self, model: torch.nn.Module, strict: bool = True): raise ValueError( f"Failed to map all tensor. Remaing tensor are: {list(state.keys())}" ) - model.load_state_dict(converted_state, strict=strict) + return model.load_state_dict(converted_state, strict=strict) def _get_loader(self) -> Callable[[str], Dict[str, torch.Tensor]]: """A best effort method for finding appropriate state loader. @@ -172,13 +188,15 @@ def _get_loader(self) -> Callable[[str], Dict[str, torch.Tensor]]: if os.path.isdir(self._file_name): if glob.glob(os.path.join(self._file_name, "*.safetensors")): return load_safetensors - if glob.glob(os.path.join(self._file_name, "*.bin")): + if glob.glob(os.path.join(self._file_name, "*.bin")) or glob.glob( + os.path.join(self._file_name, "*.pt") + ): return load_pytorch_statedict if self._file_name.endswith(".safetensors"): return load_safetensors - if self._file_name.endswith(".bin"): + if self._file_name.endswith(".bin") or self._file_name.endswith(".pt"): return load_pytorch_statedict raise ValueError(f"File format not supported.") @@ -225,22 +243,33 @@ def _map_attention( converted_state: Dict[str, torch.Tensor], ): prefix = f"transformer_blocks.{idx}" - q_name = self._names.attn_query_proj.format(idx) - k_name = self._names.attn_key_proj.format(idx) - v_name = self._names.attn_value_proj.format(idx) - converted_state[f"{prefix}.atten_func.qkv_projection.weight"] = self._fuse_qkv( - config, - state.pop(f"{q_name}.weight"), - state.pop(f"{k_name}.weight"), - state.pop(f"{v_name}.weight"), - ) - if config.attn_config.qkv_use_bias: - converted_state[f"{prefix}.atten_func.qkv_projection.bias"] = self._fuse_qkv( + if self._names.attn_fused_qkv_proj: + fused_qkv_name = self._names.attn_fused_qkv_proj.format(idx) + converted_state[f"{prefix}.atten_func.qkv_projection.weight"] = state.pop( + f"{fused_qkv_name}.weight" + ) + else: + q_name = self._names.attn_query_proj.format(idx) + k_name = self._names.attn_key_proj.format(idx) + v_name = self._names.attn_value_proj.format(idx) + converted_state[f"{prefix}.atten_func.qkv_projection.weight"] = self._fuse_qkv( config, - state.pop(f"{q_name}.bias"), - state.pop(f"{k_name}.bias"), - state.pop(f"{v_name}.bias"), + state.pop(f"{q_name}.weight"), + state.pop(f"{k_name}.weight"), + state.pop(f"{v_name}.weight"), ) + if config.attn_config.qkv_use_bias: + if self._names.attn_fused_qkv_proj: + converted_state[f"{prefix}.atten_func.qkv_projection.bias"] = state.pop( + f"{fused_qkv_name}.bias" + ) + else: + converted_state[f"{prefix}.atten_func.qkv_projection.bias"] = self._fuse_qkv( + config, + state.pop(f"{q_name}.bias"), + state.pop(f"{k_name}.bias"), + state.pop(f"{v_name}.bias"), + ) o_name = self._names.attn_output_proj.format(idx) converted_state[f"{prefix}.atten_func.output_projection.weight"] = state.pop(