Commit

Init
Signed-off-by: Jee Jee Li <[email protected]>
jeejeelee committed Dec 30, 2024
1 parent 5e1480b commit 32cc3a0
Showing 2 changed files with 8 additions and 22 deletions.
17 changes: 6 additions & 11 deletions vllm/model_executor/models/molmo.py
@@ -23,7 +23,11 @@
from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData,
InputContext, token_inputs)
from vllm.model_executor import SamplingMetadata
from vllm.model_executor.layers.activation import QuickGELU, SiluAndMul
from vllm.model_executor.layers.activation import (
    MulAndSilu,
    QuickGELU,
    SiluAndMul,
)
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
MergedColumnParallelLinear,
@@ -461,15 +465,6 @@ def forward(
        return output


class SwiGLU(nn.Module):

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x, gate = x.chunk(2, dim=-1)
        # Note that the order is reversed compared to
        # SiluAndMul.
        return x * F.silu(gate)


class LanuageModelMLP(nn.Module):
"""Molmo's LLM mlp."""

@@ -488,7 +483,7 @@ def __init__(self,
            quant_config=quant_config,
        )
        # Activation function.
        self.act_fn = SwiGLU()
        self.act_fn = MulAndSilu()
        # Feed-forward output projection.
        self.down_proj = RowParallelLinear(
            self.intermediate_size,
13 changes: 2 additions & 11 deletions vllm/model_executor/models/ultravox.py
@@ -18,7 +18,7 @@
from vllm.attention import AttentionMetadata
from vllm.config import VllmConfig
from vllm.inputs import InputContext
from vllm.model_executor.layers.activation import SiluAndMul, get_act_fn
from vllm.model_executor.layers.activation import MulAndSilu, get_act_fn
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.model_loader.loader import DefaultModelLoader
@@ -226,15 +226,6 @@ def forward(self, audio_embeds: torch.Tensor) -> torch.Tensor:
        return audio_embeds


class FlippedSiluAndMul(SiluAndMul):
"""Ultravox is trained with SwiGLU with flipped halves."""

def forward(self, x: torch.Tensor):
a, b = x.chunk(2, dim=-1)
flipped = torch.cat((b, a), dim=-1)
return super().forward(flipped)


class UltravoxProjector(nn.Module):

def __init__(self, config: UltravoxConfig):
@@ -247,7 +238,7 @@ def __init__(self, config: UltravoxConfig):
        dim = self.hidden_dim

        if config.projector_act == "swiglu":
            self.act = FlippedSiluAndMul()
            self.act = MulAndSilu()
            dim = dim // 2
        else:
            self.act = get_act_fn(config.projector_act)
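Both removed helpers compute the same flipped SwiGLU: the fused projection is split into two halves, and the first half is multiplied by SiLU of the second half (the reverse of SiluAndMul, which applies SiLU to the first half). Below is a minimal sketch of that computation, assuming the shared MulAndSilu activation imported above follows the same chunk-then-multiply convention; the function name mul_and_silu is only illustrative, not vLLM's implementation.

import torch
import torch.nn.functional as F


def mul_and_silu(x: torch.Tensor) -> torch.Tensor:
    # Split the fused projection into the linear half and the gate half,
    # then scale the linear half by SiLU of the gate. This matches the
    # removed SwiGLU (molmo.py) and FlippedSiluAndMul (ultravox.py) helpers.
    a, gate = x.chunk(2, dim=-1)
    return a * F.silu(gate)

Because Molmo's LanuageModelMLP and Ultravox's UltravoxProjector now use the same shared activation, the per-model copies deleted in this diff are no longer needed.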
