From 58caef1358ca09842e93c48e73fcef9d81af8b68 Mon Sep 17 00:00:00 2001
From: foreverpiano
Date: Sun, 8 Sep 2024 12:50:56 +0000
Subject: [PATCH] fix sp bug when global_bs = 1

---
 opensora/models/diffusion/opensora/modules.py | 10 +++++-----
 1 file changed, 5 insertions(+), 5 deletions(-)

diff --git a/opensora/models/diffusion/opensora/modules.py b/opensora/models/diffusion/opensora/modules.py
index 8904bc0c..49815f25 100644
--- a/opensora/models/diffusion/opensora/modules.py
+++ b/opensora/models/diffusion/opensora/modules.py
@@ -873,12 +873,12 @@ def __call__(
             else:
                 raise NotImplementedError(f'Found attention_mode: {self.attention_mode}')
 
-            hidden_states = rearrange(hidden_states, 'b h s d -> s b h d')
-
-            hidden_states = hidden_states.reshape(-1, attn.heads // sp_size, head_dim)
-
+            hidden_states = rearrange(hidden_states, 'b h s d -> s h b d').contiguous()
+
+            hidden_states = all_to_all_SBH(hidden_states, scatter_dim=0, gather_dim=1)
+
             # [s * b, h // sp, d] -> [s // sp * b, h, d] -> [s // sp, b, h * d]
-            hidden_states = all_to_all_SBH(hidden_states, scatter_dim=0, gather_dim=1).reshape(-1, batch_size, h_size)
+            hidden_states = rearrange(hidden_states, 's h b d -> s b (h d)').contiguous()
         else:
             query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
             key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
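
For reference, the sketch below walks through the tensor shapes this hunk produces. It runs on a single process: fake_all_to_all is a hypothetical stand-in for all_to_all_SBH, assumed here to split scatter_dim into sp_size chunks and concatenate the received chunks along gather_dim (as torch.distributed.all_to_all does); the stand-in reproduces only the per-rank shapes, not the cross-rank data exchange. The point of the patched layout is that sequence, heads, and batch remain separate axes, so the scatter/gather can never fold the batch axis into the sequence axis the way the old flattened [s * b, h // sp, d] view could.

    # Single-process shape check for the patched code path.
    # fake_all_to_all is a hypothetical stand-in for all_to_all_SBH that
    # reproduces only the shapes each rank would see, not the real exchange.
    import torch
    from einops import rearrange

    sp_size = 4            # sequence-parallel world size
    b, h, d = 1, 16, 64    # global_bs = 1 is the case this patch fixes
    s_full = 128           # full sequence length held per rank after SP attention

    def fake_all_to_all(x, scatter_dim, gather_dim, world_size):
        # Split scatter_dim into world_size chunks and concatenate them
        # along gather_dim; on one process this yields the same shapes
        # each rank sees after the real all-to-all.
        chunks = torch.chunk(x, world_size, dim=scatter_dim)
        return torch.cat(chunks, dim=gather_dim)

    # After attention each rank holds [b, h // sp, s, d]:
    # heads sharded across ranks, full sequence gathered.
    hidden_states = torch.randn(b, h // sp_size, s_full, d)

    # Patched path: keep s, h, and b as separate axes while scattering
    # sequence and gathering heads.
    x = rearrange(hidden_states, 'b h s d -> s h b d').contiguous()
    x = fake_all_to_all(x, scatter_dim=0, gather_dim=1, world_size=sp_size)
    x = rearrange(x, 's h b d -> s b (h d)').contiguous()
    print(x.shape)  # torch.Size([32, 1, 1024]) == [s // sp, b, h * d]

Under the new layout the shape flow is [s, h // sp, b, d] -> [s // sp, h, b, d] -> [s // sp, b, h * d]; the unchanged context comment in the hunk still describes the old flattened flow.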