diff --git a/opensora/models/diffusion/opensora/modules.py b/opensora/models/diffusion/opensora/modules.py index 8904bc0c..49815f25 100644 --- a/opensora/models/diffusion/opensora/modules.py +++ b/opensora/models/diffusion/opensora/modules.py @@ -873,12 +873,12 @@ def __call__( else: raise NotImplementedError(f'Found attention_mode: {self.attention_mode}') - hidden_states = rearrange(hidden_states, 'b h s d -> s b h d') - - hidden_states = hidden_states.reshape(-1, attn.heads // sp_size, head_dim) - + hidden_states = rearrange(hidden_states, 'b h s d -> s h b d').contiguous() + + hidden_states = all_to_all_SBH(hidden_states, scatter_dim=0, gather_dim=1) + # [s * b, h // sp, d] -> [s // sp * b, h, d] -> [s // sp, b, h * d] - hidden_states = all_to_all_SBH(hidden_states, scatter_dim=0, gather_dim=1).reshape(-1, batch_size, h_size) + hidden_states = rearrange(hidden_states, 's h b d -> s b (h d)').contiguous() else: query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)