feat: mixtral #1328

Merged 10 commits on Dec 11, 2023
Changes from all commits

Dockerfile (7 additions, 2 deletions)

@@ -154,6 +154,11 @@ COPY server/Makefile-vllm Makefile
# Build specific version of vllm
RUN make build-vllm-cuda

# Build megablocks
FROM kernel-builder as megablocks-builder

RUN pip install git+https://github.com/OlivierDehaene/megablocks@181709df192de9a941fdf3a641cdc65a0462996e

# Text Generation Inference base image
FROM nvidia/cuda:12.1.0-base-ubuntu20.04 as base

@@ -175,8 +180,8 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-ins
curl \
&& rm -rf /var/lib/apt/lists/*

# Copy conda with PyTorch installed
COPY --from=pytorch-install /opt/conda /opt/conda
# Copy conda with PyTorch and Megablocks installed
COPY --from=megablocks-builder /opt/conda /opt/conda

# Copy build artifacts from flash attention builder
COPY --from=flash-att-builder /usr/src/flash-attention/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
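
The new megablocks-builder stage installs the MegaBlocks MoE kernels into the shared conda environment, and the final image now copies /opt/conda from that stage instead of from pytorch-install, so the runtime picks up PyTorch and MegaBlocks together. A minimal in-container smoke test (a hypothetical check, not part of this PR) could look like:

# Hypothetical smoke test: run inside the final image to confirm the conda
# environment copied from the megablocks-builder stage carries both PyTorch
# and the MegaBlocks package.
import torch
import megablocks

print("torch", torch.__version__, "| CUDA available:", torch.cuda.is_available())
print("megablocks imported from", megablocks.__file__)
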
router/src/server.rs (5 additions, 0 deletions)

@@ -629,6 +629,9 @@ pub async fn run(
// Batch size buckets
let batch_size_matcher = Matcher::Full(String::from("tgi_batch_next_size"));
let batch_size_buckets: Vec<f64> = (0..1024).map(|x| (x + 1) as f64).collect();
// Speculated tokens buckets
let skipped_matcher = Matcher::Full(String::from("tgi_request_skipped_tokens"));
let skipped_buckets: Vec<f64> = (0..shard_info.speculate + 1).map(|x| x as f64).collect();

// Prometheus handler
let builder = PrometheusBuilder::new()
@@ -641,6 +644,8 @@
.set_buckets_for_metric(max_new_tokens_matcher, &max_new_tokens_buckets)
.unwrap()
.set_buckets_for_metric(batch_size_matcher, &batch_size_buckets)
.unwrap()
.set_buckets_for_metric(skipped_matcher, &skipped_buckets)
.unwrap();
let prom_handle = builder
.install_recorder()
server/Makefile (3 additions, 0 deletions)

@@ -16,6 +16,9 @@ gen-server:
find text_generation_server/pb/ -type f -name "*.py" -print0 -exec sed -i -e 's/^\(import.*pb2\)/from . \1/g' {} \;
touch text_generation_server/pb/__init__.py

install-megablocks:
pip install git+https://github.com/OlivierDehaene/megablocks@181709df192de9a941fdf3a641cdc65a0462996e

install: gen-server
pip install pip --upgrade
pip install -r requirements_cuda.txt
server/text_generation_server/models/__init__.py (24 additions, 3 deletions)

@@ -1,4 +1,3 @@
import os
import torch

from loguru import logger
@@ -78,6 +77,18 @@
if MISTRAL:
__all__.append(FlashMistral)

MIXTRAL = True
try:
from text_generation_server.models.flash_mixtral import FlashMixtral
except ImportError as e:
logger.warning(f"Could not import Mixtral model: {e}")
MIXTRAL = False

if MIXTRAL:
__all__.append(FlashMixtral)



def get_model(
model_id: str,
revision: Optional[str],
@@ -141,7 +152,6 @@ def get_model(
use_medusa = None
if "medusa_num_heads" in config_dict:
use_medusa = model_id
medusa_config = config_dict
model_id = config_dict["base_model_name_or_path"]
revision = "main"
speculate_medusa = config_dict["medusa_num_heads"]
@@ -292,7 +302,18 @@ def get_model(
dtype=dtype,
trust_remote_code=trust_remote_code,
)
raise NotImplementedError("Mistral model requires flash attention v2")
raise NotImplementedError("Mistral models requires flash attention v2")

if model_type == "mixtral":
if MIXTRAL:
return FlashMixtral(
model_id,
revision,
quantize=quantize,
dtype=dtype,
trust_remote_code=trust_remote_code,
)
raise NotImplementedError("Mixtral models requires flash attention v2, stk and megablocks")

if model_type == "opt":
return OPTSharded(
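
The Mixtral registration follows the same optional-import pattern as Mistral: FlashMixtral is imported behind a try/except, a MIXTRAL flag records whether the import succeeded, and get_model only dispatches to it when the flag is set, raising otherwise. A quick way to see which optional dependency is missing when MIXTRAL ends up False is sketched below; the compiled-extension module names are assumptions inferred from the error message ("flash attention v2, stk and megablocks"), not part of this PR.

# Diagnostic sketch: report which optional Mixtral dependencies resolve.
# Module names are assumptions based on the error message above.
import importlib.util

for module in ("flash_attn_2_cuda", "stk", "megablocks"):
    spec = importlib.util.find_spec(module)
    print(f"{module}: {'found' if spec is not None else 'missing'}")
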
server/text_generation_server/models/custom_modeling/flash_llama_modeling.py

@@ -34,14 +34,8 @@
PositionRotaryEmbedding,
TensorParallelHead,
get_linear,
FastRMSNorm
)
from text_generation_server.utils.import_utils import IS_CUDA_SYSTEM, IS_ROCM_SYSTEM

if IS_CUDA_SYSTEM:
import dropout_layer_norm
elif IS_ROCM_SYSTEM:
from vllm import layernorm_ops


class LlamaConfig(PretrainedConfig):
def __init__(
@@ -95,75 +89,6 @@ def __init__(
)


class LlamaRMSNorm(nn.Module):
def __init__(self, prefix, weights, eps=1e-6):
"""
LlamaRMSNorm is equivalent to T5LayerNorm
"""
super().__init__()

weight = weights.get_tensor(f"{prefix}.weight")
self.weight = nn.Parameter(weight)
self.variance_epsilon = eps

def forward(self, hidden_states, residual=None):
if hidden_states.shape[-1] > 8192:
if residual is not None:
hidden_states += residual
residual = hidden_states

hidden_states = hidden_states.to(torch.float32)
variance = hidden_states.pow(2).mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(
variance + self.variance_epsilon
)

# convert into half-precision if necessary
if self.weight.dtype in [torch.float16, torch.bfloat16]:
hidden_states = hidden_states.to(self.weight.dtype)

return self.weight * hidden_states, residual
elif IS_CUDA_SYSTEM:
# faster post attention rms norm
normed_hidden_states, res, *rest = dropout_layer_norm.dropout_add_ln_fwd(
hidden_states,
residual,
self.weight,
None,
None,
None,
None,
None,
0.0,
self.variance_epsilon,
1.0,
0,
None,
False,
True, # Activate RMSNorm
)
if res is None:
res = hidden_states

return normed_hidden_states, res
elif IS_ROCM_SYSTEM:
# We use VLLM RMSNorm kernel that can be compiled for RoCm, instead of Flash Attention ones that can not.
if residual is not None:
hidden_states += residual
residual = hidden_states

out = torch.empty_like(hidden_states)
layernorm_ops.rms_norm(
out,
hidden_states,
self.weight.data,
self.variance_epsilon,
)
return out, residual
else:
raise ValueError("Your system seem to be not supported. Please check your install or open an issue at https://github.com/huggingface/text-generation-inference/issues with a clear reproduction.")


def load_attention(config, prefix, weights):
if config.num_attention_heads != config.num_key_value_heads:
return _load_gqa(config, prefix, weights)
@@ -363,10 +288,8 @@ def __init__(self, layer_id, config, weights):
)
self.mlp = LlamaMLP(prefix=f"{prefix}.mlp", config=config, weights=weights)

self.input_layernorm = LlamaRMSNorm(
prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.rms_norm_eps
)
self.post_attention_layernorm = LlamaRMSNorm(
self.input_layernorm = FastRMSNorm.load(prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.rms_norm_eps)
self.post_attention_layernorm = FastRMSNorm.load(
prefix=f"{prefix}.post_attention_layernorm",
weights=weights,
eps=config.rms_norm_eps,
@@ -430,7 +353,7 @@ def __init__(self, config, weights):
for layer_id in range(config.num_hidden_layers)
]
)
self.norm = LlamaRMSNorm(
self.norm = FastRMSNorm.load(
prefix="model.norm", weights=weights, eps=config.rms_norm_eps
)

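
The per-model LlamaRMSNorm class, with its CUDA dropout_layer_norm and ROCm layernorm_ops branches, is removed in favour of the shared FastRMSNorm.load layer. For reference, the computation the removed class performed on its fallback path, which the shared layer is expected to match numerically, boils down to the following sketch (transcribed from the deleted code, with residual handling omitted; this is not the FastRMSNorm implementation itself):

# Reference RMSNorm, transcribed from the removed LlamaRMSNorm fallback path:
# normalize in float32 for stability, then scale by the learned weight.
import torch

def rms_norm(hidden_states: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    hidden_states = hidden_states.to(torch.float32)
    variance = hidden_states.pow(2).mean(-1, keepdim=True)
    hidden_states = hidden_states * torch.rsqrt(variance + eps)
    # convert back to the weight's (possibly half) precision before scaling
    return weight * hidden_states.to(weight.dtype)
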
server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py

@@ -35,13 +35,9 @@
PositionRotaryEmbedding,
TensorParallelHead,
get_linear,
FastRMSNorm
)
from text_generation_server.utils.import_utils import IS_CUDA_SYSTEM, IS_ROCM_SYSTEM

if IS_CUDA_SYSTEM:
import dropout_layer_norm
elif IS_ROCM_SYSTEM:
from vllm import layernorm_ops

if not HAS_FLASH_ATTN_V2_CUDA and not HAS_FLASH_ATTN_V2_ROCM:
raise ImportError("Mistral model requires flash attn v2")
@@ -100,76 +96,6 @@ def __init__(
**kwargs,
)


class MistralRMSNorm(nn.Module):
def __init__(self, prefix, weights, eps=1e-6):
"""
LlamaRMSNorm is equivalent to T5LayerNorm
"""
super().__init__()

weight = weights.get_tensor(f"{prefix}.weight")
self.weight = nn.Parameter(weight)
self.variance_epsilon = eps

def forward(self, hidden_states, residual=None):
if hidden_states.shape[-1] > 8192:
if residual is not None:
hidden_states += residual
residual = hidden_states

hidden_states = hidden_states.to(torch.float32)
variance = hidden_states.pow(2).mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(
variance + self.variance_epsilon
)

# convert into half-precision if necessary
if self.weight.dtype in [torch.float16, torch.bfloat16]:
hidden_states = hidden_states.to(self.weight.dtype)

return self.weight * hidden_states, residual
elif IS_CUDA_SYSTEM:
# faster post attention rms norm
normed_hidden_states, res, *rest = dropout_layer_norm.dropout_add_ln_fwd(
hidden_states,
residual,
self.weight,
None,
None,
None,
None,
None,
0.0,
self.variance_epsilon,
1.0,
0,
None,
False,
True, # Activate RMSNorm
)
if res is None:
res = hidden_states

return normed_hidden_states, res
elif IS_ROCM_SYSTEM:
# We use VLLM RMSNorm kernel that can be compiled for RoCm, instead of Flash Attention ones that can not.
if residual is not None:
hidden_states += residual
residual = hidden_states

out = torch.empty_like(hidden_states)
layernorm_ops.rms_norm(
out,
hidden_states,
self.weight.data,
self.variance_epsilon,
)
return out, residual
else:
raise ValueError("Your system seem to be not supported. Please check your install or open an issue at https://github.com/huggingface/text-generation-inference/issues with a clear reproduction.")


def load_attention(config, prefix, weights):
if config.num_attention_heads != config.num_key_value_heads:
return _load_gqa(config, prefix, weights)
@@ -371,10 +297,10 @@ def __init__(self, layer_id, config, weights):
)
self.mlp = MistralMLP(prefix=f"{prefix}.mlp", config=config, weights=weights)

self.input_layernorm = MistralRMSNorm(
self.input_layernorm = FastRMSNorm.load(
prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.rms_norm_eps
)
self.post_attention_layernorm = MistralRMSNorm(
self.post_attention_layernorm = FastRMSNorm.load(
prefix=f"{prefix}.post_attention_layernorm",
weights=weights,
eps=config.rms_norm_eps,
@@ -440,7 +366,7 @@ def __init__(self, config, weights):
for layer_id in range(config.num_hidden_layers)
]
)
self.norm = MistralRMSNorm(
self.norm = FastRMSNorm.load(
prefix="model.norm", weights=weights, eps=config.rms_norm_eps
)
