diff --git a/server/tests/models/test_bloom.py b/server/tests/models/test_bloom.py
index 32ee6686b6b..6e9e5205edb 100644
--- a/server/tests/models/test_bloom.py
+++ b/server/tests/models/test_bloom.py
@@ -5,7 +5,7 @@
 from transformers import AutoTokenizer
 
 from text_generation_server.pb import generate_pb2
-from text_generation_server.models.causal_lm import CausalLMBatch
+from text_generation_server.models.transformers_causal_lm import CausalLMBatch
 from text_generation_server.utils import weight_hub_files, download_weights
 from text_generation_server.models.bloom import BloomCausalLMBatch, BLOOMSharded
 
diff --git a/server/tests/models/test_causal_lm.py b/server/tests/models/test_causal_lm.py
index 6e6463bc948..7d674947f3f 100644
--- a/server/tests/models/test_causal_lm.py
+++ b/server/tests/models/test_causal_lm.py
@@ -5,7 +5,10 @@
 from transformers import AutoTokenizer
 
 from text_generation_server.pb import generate_pb2
-from text_generation_server.models.causal_lm import CausalLM, CausalLMBatch
+from text_generation_server.models.transformers_causal_lm import (
+    TransformersCausalLM,
+    CausalLMBatch,
+)
 
 
 @pytest.fixture(scope="session")
diff --git a/server/tests/models/test_santacoder.py b/server/tests/models/test_santacoder.py
index cb2622d9b53..19152659d9a 100644
--- a/server/tests/models/test_santacoder.py
+++ b/server/tests/models/test_santacoder.py
@@ -1,7 +1,7 @@
 import pytest
 
 from text_generation_server.pb import generate_pb2
-from text_generation_server.models.causal_lm import CausalLMBatch
+from text_generation_server.models.transformers_causal_lm import CausalLMBatch
 from text_generation_server.models.santacoder import SantaCoder
 
 
diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py
index 648fcee953c..5615de656c8 100644
--- a/server/text_generation_server/models/__init__.py
+++ b/server/text_generation_server/models/__init__.py
@@ -8,10 +8,13 @@
 from huggingface_hub import hf_hub_download, HfApi
 from typing import Optional, List
 from pathlib import Path
-
+import transformers
 from text_generation_server.utils.speculate import get_speculate, set_speculate
 from text_generation_server.models.model import Model
-from text_generation_server.models.causal_lm import CausalLM
+from text_generation_server.models.transformers_causal_lm import TransformersCausalLM
+from text_generation_server.models.transformers_flash_causal_lm import (
+    TransformersFlashCausalLM,
+)
 from text_generation_server.models.flash_causal_lm import FlashCausalLM
 from text_generation_server.models.bloom import BLOOMSharded
 from text_generation_server.models.mpt import MPTSharded
@@ -24,6 +27,8 @@
 from text_generation_server.models.gpt_neox import GPTNeoxSharded
 from text_generation_server.models.phi import Phi
 
+from text_generation_server.models.globals import USE_CUSTOM_MODELING
+
 from text_generation_server.utils.import_utils import SYSTEM
 
 # The flag below controls whether to allow TF32 on matmul. This flag defaults to False
@@ -288,6 +293,31 @@ def get_model(
         )
     model_type = config_dict.get("model_type", None)
 
+    transformers_causal_lm_class = TransformersCausalLM
+    if (
+        not USE_CUSTOM_MODELING
+        and model_type in modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
+    ):
+        logger.info(
+            "TGI's flash-enabled models could either not be loaded or are disabled, using Transformers fallback."
+        )
+        transformers_model_class = getattr(
+            transformers, modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES[model_type]
+        )
+
+        if (
+            transformers_model_class._supports_flash_attn_2
+            and transformers_model_class._supports_cache_class
+        ):
+            logger.info(
+                f"Transformers' {model_type} implementation supports custom cache and flash/paged attention. Using TransformersFlashCausalLM with ragged tensors (single dimension for batch and sequence length)."
+            )
+            transformers_causal_lm_class = TransformersFlashCausalLM
+        else:
+            logger.info(
+                f"Transformers' {model_type} implementation does not support custom cache and flash/paged attention. Using TransformersCausalLM with classic tensors with padding (two dimensions for batch size and sequence length)."
+            )
+
     speculator = None
     if "medusa_num_heads" in config_dict:
         medusa_model_id = model_id
@@ -449,7 +479,7 @@ def get_model(
         or model_type == GPT2
         and model_id.startswith("bigcode/")
     ):
-        if FLASH_ATTENTION:
+        if FLASH_ATTENTION and USE_CUSTOM_MODELING:
             return FlashSantacoderSharded(
                 model_id,
                 revision,
@@ -491,7 +521,7 @@ def get_model(
                 trust_remote_code=trust_remote_code,
             )
     elif model_type == GPT2:
-        if FLASH_ATTENTION:
+        if FLASH_ATTENTION and USE_CUSTOM_MODELING:
             try:
                 return FlashGPT2(
                     model_id,
@@ -504,7 +534,8 @@ def get_model(
             except RuntimeError as e:
                 # Lots of legacy models with various weight names.
                 logger.warning(f"Couldn't load flash gpt2 variant: {e}")
-            return CausalLM(
+
+            return transformers_causal_lm_class(
                 model_id,
                 revision,
                 quantize=quantize,
@@ -515,7 +546,7 @@ def get_model(
         elif sharded:
             raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded GPT-2"))
         else:
-            return CausalLM(
+            return transformers_causal_lm_class(
                 model_id,
                 revision,
                 quantize=quantize,
@@ -524,7 +555,7 @@ def get_model(
                 trust_remote_code=trust_remote_code,
             )
     elif model_type == GPT_NEOX:
-        if FLASH_ATTENTION:
+        if FLASH_ATTENTION and USE_CUSTOM_MODELING:
             return FlashNeoXSharded(
                 model_id,
                 revision,
@@ -543,7 +574,7 @@ def get_model(
                 trust_remote_code=trust_remote_code,
             )
         else:
-            return CausalLM(
+            return transformers_causal_lm_class(
                 model_id,
                 revision,
                 quantize=quantize,
@@ -553,7 +584,7 @@ def get_model(
             )
 
     elif model_type == PHI:
-        if FLASH_ATTENTION:
+        if FLASH_ATTENTION and USE_CUSTOM_MODELING:
             return FlashPhi(
                 model_id,
                 revision,
@@ -563,7 +594,7 @@ def get_model(
                 trust_remote_code=trust_remote_code,
             )
         else:
-            return CausalLM(
+            return transformers_causal_lm_class(
                 model_id,
                 revision,
                 quantize=quantize,
@@ -573,7 +604,7 @@ def get_model(
             )
 
     elif model_type == "phi-msft":
-        if FLASH_ATTENTION:
+        if FLASH_ATTENTION and USE_CUSTOM_MODELING:
             raise NotImplementedError(
                 "Legacy phi-msft is not supported with Flash Attention"
             )
@@ -588,7 +619,7 @@ def get_model(
             )
 
     elif model_type == LLAMA or model_type == BAICHUAN or model_type == PHI3:
-        if FLASH_ATTENTION:
+        if FLASH_ATTENTION and USE_CUSTOM_MODELING:
             return FlashLlama(
                 model_id,
                 revision,
@@ -601,7 +632,7 @@ def get_model(
         elif sharded:
             raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Llama"))
         else:
-            return CausalLM(
+            return transformers_causal_lm_class(
                 model_id,
                 revision,
                 quantize=quantize,
@@ -610,7 +641,7 @@ def get_model(
                 trust_remote_code=trust_remote_code,
             )
     if model_type == GEMMA:
-        if FLASH_ATTENTION:
+        if FLASH_ATTENTION and USE_CUSTOM_MODELING:
             return FlashGemma(
                 model_id,
                 revision,
@@ -622,7 +653,7 @@ def get_model(
         elif sharded:
             raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Gemma"))
         else:
-            return CausalLM(
+            return transformers_causal_lm_class(
                 model_id,
                 revision,
                 quantize=quantize,
@@ -632,7 +663,7 @@ def get_model(
             )
 
     if model_type == COHERE:
-        if FLASH_ATTENTION:
+        if FLASH_ATTENTION and USE_CUSTOM_MODELING:
             return FlashCohere(
                 model_id,
                 revision,
@@ -644,7 +675,7 @@ def get_model(
         elif sharded:
             raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Cohere"))
         else:
-            return CausalLM(
+            return transformers_causal_lm_class(
                 model_id,
                 revision,
                 quantize=quantize,
@@ -654,7 +685,7 @@ def get_model(
             )
 
     if model_type == DBRX:
-        if FLASH_ATTENTION:
+        if FLASH_ATTENTION and USE_CUSTOM_MODELING:
             return FlashDbrx(
                 model_id,
                 revision,
@@ -666,7 +697,7 @@ def get_model(
         elif sharded:
             raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded DBRX"))
         else:
-            return CausalLM(
+            return transformers_causal_lm_class(
                 model_id,
                 revision,
                 quantize=quantize,
@@ -677,7 +708,7 @@ def get_model(
 
     if model_type in ["RefinedWeb", "RefinedWebModel", FALCON]:
         if sharded:
-            if FLASH_ATTENTION:
+            if FLASH_ATTENTION and USE_CUSTOM_MODELING:
                 if config_dict.get("alibi", False):
                     raise NotImplementedError("sharded is not supported for this model")
                 return FlashRWSharded(
@@ -710,7 +741,7 @@ def get_model(
             )
 
     if model_type == MISTRAL:
-        if FLASH_ATTENTION:
+        if FLASH_ATTENTION and USE_CUSTOM_MODELING:
             return FlashMistral(
                 model_id,
                 revision,
@@ -722,7 +753,7 @@ def get_model(
         elif sharded:
             raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Mistral"))
         else:
-            return CausalLM(
+            return transformers_causal_lm_class(
                 model_id,
                 revision,
                 quantize=quantize,
@@ -732,7 +763,7 @@ def get_model(
             )
 
     if model_type == MIXTRAL:
-        if FLASH_ATTENTION:
+        if FLASH_ATTENTION and USE_CUSTOM_MODELING:
             return FlashMixtral(
                 model_id,
                 revision,
@@ -744,7 +775,7 @@ def get_model(
         elif sharded:
             raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Mixtral"))
         else:
-            return CausalLM(
+            return transformers_causal_lm_class(
                 model_id,
                 revision,
                 quantize=quantize,
@@ -754,7 +785,7 @@ def get_model(
             )
 
     if model_type == STARCODER2:
-        if FLASH_ATTENTION:
+        if FLASH_ATTENTION and USE_CUSTOM_MODELING:
             return FlashStarcoder2(
                 model_id,
                 revision,
@@ -767,7 +798,7 @@ def get_model(
                 FLASH_ATT_ERROR_MESSAGE.format("Sharded Starcoder2")
             )
         else:
-            return CausalLM(
+            return transformers_causal_lm_class(
                 model_id,
                 revision,
                 quantize=quantize,
@@ -777,7 +808,7 @@ def get_model(
             )
 
     if model_type == QWEN2:
-        if FLASH_ATTENTION:
+        if FLASH_ATTENTION and USE_CUSTOM_MODELING:
             return FlashQwen2(
                 model_id,
                 revision,
@@ -788,7 +819,7 @@ def get_model(
         elif sharded:
             raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Qwen2"))
         else:
-            return CausalLM(
+            return transformers_causal_lm_class(
                 model_id,
                 revision,
                 quantize=quantize,
@@ -817,7 +848,7 @@ def get_model(
             trust_remote_code=trust_remote_code,
         )
     if model_type == IDEFICS:
-        if FLASH_ATTENTION:
+        if FLASH_ATTENTION and USE_CUSTOM_MODELING:
             return IDEFICSSharded(
                 model_id,
                 revision,
@@ -829,7 +860,7 @@ def get_model(
         else:
             raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics"))
     if model_type == IDEFICS2:
-        if FLASH_ATTENTION:
+        if FLASH_ATTENTION and USE_CUSTOM_MODELING:
            return Idefics2(
                 model_id,
                 revision,
@@ -841,7 +872,7 @@ def get_model(
         else:
             raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics"))
     if model_type == "paligemma":
-        if FLASH_ATTENTION:
+        if FLASH_ATTENTION and USE_CUSTOM_MODELING:
             return PaliGemma(
                 model_id,
                 revision,
@@ -854,7 +885,7 @@ def get_model(
             raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics"))
 
     if model_type == LLAVA_NEXT:
-        if FLASH_ATTENTION:
+        if FLASH_ATTENTION and USE_CUSTOM_MODELING:
             return LlavaNext(
                 model_id,
                 revision,
@@ -881,7 +912,7 @@ def get_model(
     elif quantize == "exl2":
         raise NotImplementedError("exl2 quantization is not supported for AutoModel")
     if model_type in modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
-        return CausalLM(
+        return transformers_causal_lm_class(
             model_id,
             revision,
             quantize=quantize,
@@ -902,7 +933,7 @@ def get_model(
     auto_map = config_dict.get("auto_map", None)
     if trust_remote_code and auto_map is not None:
         if "AutoModelForCausalLM" in auto_map.keys():
-            return CausalLM(
+            return transformers_causal_lm_class(
                 model_id,
                 revision,
                 quantize=quantize,
diff --git a/server/text_generation_server/models/bloom.py b/server/text_generation_server/models/bloom.py
index 17aa12e84dc..88cb2bdf09c 100644
--- a/server/text_generation_server/models/bloom.py
+++ b/server/text_generation_server/models/bloom.py
@@ -12,8 +12,8 @@
 from text_generation_server.models.custom_modeling.bloom_modeling import (
     BloomForCausalLM,
 )
-from text_generation_server.models import CausalLM
-from text_generation_server.models.causal_lm import CausalLMBatch
+from text_generation_server.models import TransformersCausalLM
+from text_generation_server.models.transformers_causal_lm import CausalLMBatch
 from text_generation_server.pb import generate_pb2
 from text_generation_server.utils import (
     initialize_torch_distributed,
@@ -36,7 +36,7 @@ def from_pb(
         return batch
 
 
-class BLOOMSharded(CausalLM):
+class BLOOMSharded(TransformersCausalLM):
     def __init__(
         self,
         model_id: str,
@@ -89,7 +89,7 @@ def __init__(
         model = BloomForCausalLM(config, weights)
 
         torch.distributed.barrier(group=self.process_group)
-        super(CausalLM, self).__init__(
+        super().__init__(
             model_id=model_id,
             model=model,
             tokenizer=tokenizer,
diff --git a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
index c48ed26883f..8cb8c0a9713 100644
--- a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
@@ -299,6 +299,7 @@ def __init__(self, prefix, config, weights, index):
     def forward(self, hidden_states, adapter_data):
         if (
             SYSTEM == "rocm"
+            and False
             and self.hidden_act == "silu"
             and hidden_states.shape[0] == 1
             and not self.quantize
diff --git a/server/text_generation_server/models/galactica.py b/server/text_generation_server/models/galactica.py
index 30c92d90e27..0f9ffd3b6aa 100644
--- a/server/text_generation_server/models/galactica.py
+++ b/server/text_generation_server/models/galactica.py
@@ -9,8 +9,8 @@
     AutoConfig,
     PreTrainedTokenizerBase,
 )
-from text_generation_server.models import CausalLM
-from text_generation_server.models.causal_lm import CausalLMBatch
+from text_generation_server.models import TransformersCausalLM
+from text_generation_server.models.transformers_causal_lm import CausalLMBatch
 from text_generation_server.pb import generate_pb2
 from text_generation_server.models.custom_modeling.opt_modeling import OPTForCausalLM
 from text_generation_server.utils import (
@@ -164,7 +164,7 @@ def from_pb(
         )
 
 
-class GalacticaSharded(CausalLM):
+class GalacticaSharded(TransformersCausalLM):
     def __init__(
         self,
         model_id: str,
@@ -211,7 +211,7 @@ def __init__(
         model = OPTForCausalLM(config, weights)
 
         torch.distributed.barrier(group=self.process_group)
-        super(CausalLM, self).__init__(
+        super().__init__(
             model_id=model_id,
             model=model,
             tokenizer=tokenizer,
diff --git a/server/text_generation_server/models/globals.py b/server/text_generation_server/models/globals.py
index cc2f172ad15..157c88b9e04 100644
--- a/server/text_generation_server/models/globals.py
+++ b/server/text_generation_server/models/globals.py
@@ -44,3 +44,7 @@ def set_model_id(model_id: str):
 def set_adapter_to_index(adapter_to_index: Dict[str, int]):
     global ADAPTER_TO_INDEX
     ADAPTER_TO_INDEX = adapter_to_index
+
+
+USE_CUSTOM_MODELING = os.getenv("USE_CUSTOM_MODELING", "true")
+USE_CUSTOM_MODELING = USE_CUSTOM_MODELING == "true" or USE_CUSTOM_MODELING == "1"
diff --git a/server/text_generation_server/models/gpt_neox.py b/server/text_generation_server/models/gpt_neox.py
index c37cfb7da72..a707c833c05 100644
--- a/server/text_generation_server/models/gpt_neox.py
+++ b/server/text_generation_server/models/gpt_neox.py
@@ -7,7 +7,7 @@
     AutoTokenizer,
     AutoConfig,
 )
-from text_generation_server.models import CausalLM
+from text_generation_server.models import TransformersCausalLM
 from text_generation_server.models.custom_modeling.neox_modeling import (
     GPTNeoxForCausalLM,
 )
@@ -18,7 +18,7 @@
 )
 
 
-class GPTNeoxSharded(CausalLM):
+class GPTNeoxSharded(TransformersCausalLM):
     def __init__(
         self,
         model_id: str,
@@ -64,7 +64,7 @@ def __init__(
         model = GPTNeoxForCausalLM(config, weights)
 
         torch.distributed.barrier(group=self.process_group)
-        super(CausalLM, self).__init__(
+        super().__init__(
             model_id=model_id,
             model=model,
             tokenizer=tokenizer,
diff --git a/server/text_generation_server/models/mpt.py b/server/text_generation_server/models/mpt.py
index 1e79b25f263..355c257fcd3 100644
--- a/server/text_generation_server/models/mpt.py
+++ b/server/text_generation_server/models/mpt.py
@@ -8,8 +8,8 @@
 from huggingface_hub import hf_hub_download
 import json
 
-from text_generation_server.models import CausalLM
-from text_generation_server.models.causal_lm import CausalLMBatch
+from text_generation_server.models import TransformersCausalLM
+from text_generation_server.models.transformers_causal_lm import CausalLMBatch
 from text_generation_server.pb import generate_pb2
 from text_generation_server.models.custom_modeling.mpt_modeling import (
     MPTForCausalLM,
@@ -37,7 +37,7 @@ def from_pb(
         return batch
 
 
-class MPTSharded(CausalLM):
+class MPTSharded(TransformersCausalLM):
     def __init__(
         self,
         model_id: str,
@@ -89,7 +89,7 @@ def __init__(
         model = MPTForCausalLM(config, weights)
 
         torch.distributed.barrier(group=self.process_group)
-        super(CausalLM, self).__init__(
+        super().__init__(
             model_id=model_id,
             model=model,
             tokenizer=tokenizer,
diff --git a/server/text_generation_server/models/opt.py b/server/text_generation_server/models/opt.py
index 6d7d07f59c3..4f53faafb2d 100644
--- a/server/text_generation_server/models/opt.py
+++ b/server/text_generation_server/models/opt.py
@@ -8,7 +8,7 @@
     AutoConfig,
 )
 from text_generation_server.models.custom_modeling.opt_modeling import OPTForCausalLM
-from text_generation_server.models import CausalLM
+from text_generation_server.models import TransformersCausalLM
 from text_generation_server.utils import (
     initialize_torch_distributed,
     weight_files,
@@ -16,7 +16,7 @@
 )
 
 
-class OPTSharded(CausalLM):
+class OPTSharded(TransformersCausalLM):
     def __init__(
         self,
         model_id: str,
@@ -62,7 +62,7 @@ def __init__(
         model = OPTForCausalLM(config, weights)
 
         torch.distributed.barrier(group=self.process_group)
-        super(CausalLM, self).__init__(
+        super().__init__(
             model_id=model_id,
             model=model,
             tokenizer=tokenizer,
diff --git a/server/text_generation_server/models/phi.py b/server/text_generation_server/models/phi.py
index 93d42b2b8dc..92aab9fb3af 100644
--- a/server/text_generation_server/models/phi.py
+++ b/server/text_generation_server/models/phi.py
@@ -4,7 +4,7 @@
 from transformers import AutoConfig, AutoTokenizer
 from typing import Optional, List, Tuple
 
-from text_generation_server.models import CausalLM
+from text_generation_server.models import TransformersCausalLM
 from text_generation_server.models.custom_modeling.phi_modeling import (
     PhiConfig,
     PhiForCausalLM,
@@ -16,7 +16,7 @@
 )
 
 
-class Phi(CausalLM):
+class Phi(TransformersCausalLM):
     def __init__(
         self,
         model_id: str,
@@ -59,7 +59,7 @@ def __init__(
         weights = Weights(filenames, device, dtype, process_group=self.process_group)
         model = PhiForCausalLM(config, weights)
         torch.distributed.barrier(group=self.process_group)
-        super(CausalLM, self).__init__(
+        super().__init__(
             model_id=model_id,
             model=model,
             tokenizer=tokenizer,
diff --git a/server/text_generation_server/models/rw.py b/server/text_generation_server/models/rw.py
index 37ca277b7e0..785137605ec 100644
--- a/server/text_generation_server/models/rw.py
+++ b/server/text_generation_server/models/rw.py
@@ -3,10 +3,10 @@
 from transformers import AutoTokenizer, AutoModelForCausalLM
 from typing import List, Optional, Tuple
 
-from text_generation_server.models import CausalLM
+from text_generation_server.models import TransformersCausalLM
 
 
-class RW(CausalLM):
+class RW(TransformersCausalLM):
     def __init__(
         self,
         model_id: str,
@@ -61,7 +61,7 @@ def __init__(
         else:
             tokenizer.add_special_tokens({"pad_token": "[PAD]"})
 
-        super(CausalLM, self).__init__(
+        super().__init__(
             model_id=model_id,
             model=model,
             tokenizer=tokenizer,
diff --git a/server/text_generation_server/models/santacoder.py b/server/text_generation_server/models/santacoder.py
index caddbe191b3..b595718d88b 100644
--- a/server/text_generation_server/models/santacoder.py
+++ b/server/text_generation_server/models/santacoder.py
@@ -4,7 +4,7 @@
 from typing import Optional, List
 from transformers import AutoTokenizer, AutoModelForCausalLM
 
-from text_generation_server.models import CausalLM
+from text_generation_server.models import TransformersCausalLM
 
 FIM_PREFIX = "<fim-prefix>"
 FIM_MIDDLE = "<fim-middle>"
@@ -13,7 +13,7 @@
 EOD = "<|endoftext|>"
 
 
-class SantaCoder(CausalLM):
+class SantaCoder(TransformersCausalLM):
     def __init__(
         self,
         model_id: str,
@@ -61,7 +61,7 @@ def __init__(
             trust_remote_code=trust_remote_code,
         )
 
-        super(CausalLM, self).__init__(
+        super().__init__(
             model_id=model_id,
             model=model,
             tokenizer=tokenizer,
diff --git a/server/text_generation_server/models/causal_lm.py b/server/text_generation_server/models/transformers_causal_lm.py
similarity index 99%
rename from server/text_generation_server/models/causal_lm.py
rename to server/text_generation_server/models/transformers_causal_lm.py
index 10c64c6611f..dfe3caf637d 100644
--- a/server/text_generation_server/models/causal_lm.py
+++ b/server/text_generation_server/models/transformers_causal_lm.py
@@ -478,7 +478,7 @@ def __len__(self):
         return len(self.requests)
 
 
-class CausalLM(Model):
+class TransformersCausalLM(Model):
     def __init__(
         self,
         model_id: str,
diff --git a/server/text_generation_server/models/transformers_flash_causal_lm.py b/server/text_generation_server/models/transformers_flash_causal_lm.py
new file mode 100644
index 00000000000..13f5118ec48
--- /dev/null
+++ b/server/text_generation_server/models/transformers_flash_causal_lm.py
@@ -0,0 +1,359 @@
+import torch
+import time
+import sys
+from dataclasses import dataclass
+from opentelemetry import trace
+from transformers import AutoTokenizer, AutoModelForCausalLM, PreTrainedTokenizerBase
+from typing import Optional, Tuple, List, Type, Dict, Any
+from text_generation_server.utils.import_utils import SYSTEM
+from text_generation_server.models import Model
+from text_generation_server.utils.chunks import concat_text_chunks
+from text_generation_server.utils.tokens import batch_top_tokens
+from text_generation_server.models.types import (
+    Batch,
+    Tokens,
+    Generation,
+    GeneratedText,
+)
+from text_generation_server.pb import generate_pb2
+from text_generation_server.utils import NextTokenChooser, StoppingCriteria, Sampling
+from text_generation_server.models.flash_causal_lm import (
+    FlashCausalLMBatch,
+    FlashCausalLM,
+)
+
+from text_generation_server.utils.import_utils import (
+    empty_cache,
+    synchronize,
+    get_free_memory,
+)
+from text_generation_server.utils.speculate import get_speculate
+from text_generation_server.utils.dist import MEMORY_FRACTION
+
+tracer = trace.get_tracer(__name__)
+
+from text_generation_server.adapters import AdapterBatchData
+from text_generation_server.layers.attention import reshape_and_cache
+from transformers.cache_utils import Cache
+from transformers.flash_attention_utils import _flash_supports_window_size
+from flash_attn import flash_attn_varlen_func
+from text_generation_server.layers.attention import paged_attention
+
+from loguru import logger
+
+# Why define it here?
+BLOCK_SIZE: int = 16
+
+
+def patch_everywhere(
+    attribute_name: str, patch: Any, module_name_prefix: Optional[str] = None
+):
+    """
+    Finds all occurrences of `attribute_name` in the loaded modules and patches them with `patch`.
+
+    Args:
+        attribute_name (`str`):
+            The name of attribute to patch.
+        patch (`Any`):
+            The patch for the attribute.
+        module_name_prefix (`Optional[str]`, defaults to `None`):
+            If set, only module names starting with this prefix will be considered for patching.
+    """
+    # sys.modules may be updated while being iterated over, hence the list copy.
+    for name in list(sys.modules):
+        module = sys.modules[name]
+        if module_name_prefix is not None and not name.startswith(module_name_prefix):
+            continue
+        if hasattr(module, attribute_name):
+            setattr(module, attribute_name, patch)
+
+
+def _flash_attention_forward_patched(
+    query_states,
+    key_states,
+    value_states,
+    attention_mask,
+    query_length,
+    layer_idx: int,
+    dropout=0.0,
+    softmax_scale=None,
+    is_causal=False,
+    _flash_attn_uses_top_left_mask=False,
+    sliding_window=None,
+    cache_position=0,
+    **kwargs,  #: Unpack[ExtraKwargs],
+):
+    _flash_attn_uses_top_left_mask = True  # TODO felix: fix rocm
+
+    if not _flash_attn_uses_top_left_mask:
+        causal = is_causal
+    else:
+        # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__.
+        causal = is_causal and query_length != 1
+
+    print(f"causal: {causal}")
+
+    use_sliding_windows = (
+        _flash_supports_window_size
+        and sliding_window is not None
+        and cache_position > sliding_window
+    )
+    flash_kwargs = (
+        {"window_size": (sliding_window, sliding_window)} if use_sliding_windows else {}
+    )
+
+    print(f"kwargs {kwargs.keys()}")
+
+    cu_seqlen_prefill = kwargs.get("cu_seqlen_prefill")
+    max_seq_lens = kwargs.get("max_seq_lens")
+
+    if cu_seqlen_prefill is not None:
+        attn_output = flash_attn_varlen_func(
+            query_states,
+            key_states,
+            value_states,
+            cu_seqlens_q=cu_seqlen_prefill,
+            cu_seqlens_k=cu_seqlen_prefill,
+            max_seqlen_q=kwargs["max_s"],
+            max_seqlen_k=kwargs["max_s"],
+            dropout_p=dropout,
+            softmax_scale=softmax_scale,
+            causal=causal,
+            # **kwargs,
+            **flash_kwargs,
+        )
+    else:
+        attn_output = torch.empty_like(query_states)
+
+        paged_attention(
+            attn_output,
+            query_states,
+            kwargs["kv_cache"][layer_idx][0],
+            kwargs["kv_cache"][layer_idx][1],
+            kwargs["kv_head_mapping"],
+            softmax_scale,
+            kwargs["block_tables"],
+            kwargs["input_lengths"],
+            kwargs["max_s"],
+        )
+
+    attn_output = attn_output.view(attn_output.shape[0], -1)
+
+    return attn_output
+
+
+class PagedCache(Cache):
+    def __init__(self) -> None:
+        pass
+
+    def update(
+        self,
+        key_states: torch.Tensor,
+        value_states: torch.Tensor,
+        layer_idx: int,
+        cache_kwargs: Optional[Dict[str, Any]] = None,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+
+        kv_cache = cache_kwargs["kv_cache"]
+        reshape_and_cache(
+            key_states,
+            value_states,
+            kv_cache[layer_idx][0],
+            kv_cache[layer_idx][1],
+            cache_kwargs["slots"],
+        )
+
+        if cache_kwargs["cu_seqlen_prefill"] is not None:
+            return key_states, value_states
+        else:
+            return None, None
+
+    def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
+        raise ValueError(
+            "PagedCache.get_seq_length should never be called, please open an issue."
+        )
+
+    def get_max_length(self) -> Optional[int]:
+        raise ValueError(
+            "PagedCache.get_max_length should never be called, please open an issue."
+        )
+
+
+class TransformersFlashCausalLM(FlashCausalLM):
+    def __init__(
+        self,
+        model_id: str,
+        revision: Optional[str] = None,
+        quantize: Optional[str] = None,
+        speculator: Optional[str] = None,
+        dtype: Optional[torch.dtype] = None,
+        trust_remote_code: bool = False,
+    ):
+        if speculator:
+            raise RuntimeError("Speculator decoding is not enabled for AutoModel")
+
+        if torch.cuda.is_available():
+            device = torch.device("cuda:0")  # TODO felix: fix support for accelerate
+            dtype = torch.float16 if dtype is None else dtype
+        else:
+            if quantize:
+                raise ValueError("quantization is not available on CPU")
+
+            device = torch.device("cpu")
+            dtype = torch.float32 if dtype is None else dtype
+
+        tokenizer = AutoTokenizer.from_pretrained(
+            model_id,
+            revision=revision,
+            padding_side="left",
+            truncation_side="left",
+            trust_remote_code=trust_remote_code,
+        )
+        model = AutoModelForCausalLM.from_pretrained(
+            model_id,
+            revision=revision,
+            torch_dtype=dtype,
+            device_map=None,
+            load_in_8bit=quantize == "bitsandbytes",
+            trust_remote_code=trust_remote_code,
+            attn_implementation="flash_attention_2",
+        )
+        if (
+            torch.cuda.is_available()
+            and torch.cuda.device_count() == 1
+            and quantize != "bitsandbytes"
+        ):
+            model = model.cuda()
+
+        self.kv_cache = []
+
+        # TODO felix: make this more general.
+        self.num_layers = len(model.model.layers)
+        self.num_kv_heads = model.config.num_key_value_heads
+        self.head_size = model.config.hidden_size // model.config.num_attention_heads
+
+        if tokenizer.pad_token_id is None:
+            if model.config.pad_token_id is not None:
+                tokenizer.pad_token_id = model.config.pad_token_id
+            elif model.config.eos_token_id is not None:
+                tokenizer.pad_token_id = model.config.eos_token_id
+            elif tokenizer.eos_token_id is not None:
+                tokenizer.pad_token_id = tokenizer.eos_token_id
+            else:
+                tokenizer.add_special_tokens({"pad_token": "[PAD]"})
+
+        # Skip FlashCausalLM init.
+        super(FlashCausalLM, self).__init__(
+            model_id=model_id,
+            model=model,
+            tokenizer=tokenizer,
+            requires_padding=False,
+            dtype=dtype,
+            device=device,
+        )
+
+    def warmup(self, batch: FlashCausalLMBatch):
+        # The warmup batch is the biggest batch we could ever receive
+        empty_cache()
+
+        patch_everywhere("_flash_attention_forward", _flash_attention_forward_patched)
+
+        try:
+            self.init_kv_cache(
+                batch.num_blocks,
+                self.num_layers,
+                self.num_kv_heads,
+                self.head_size,
+                self.dtype,
+                self.device,
+            )
+            max_bt = batch.max_blocks
+            max_s = max_bt * BLOCK_SIZE
+
+            _, batch, _ = self.generate_token(batch)
+        except torch.cuda.OutOfMemoryError as e:
+            raise RuntimeError(
+                f"Not enough memory to handle {len(batch.input_ids)} prefill tokens. "
+                f"You need to decrease `--max-batch-prefill-tokens`"
+            ) from e
+
+        synchronize(self.device)
+
+        # Inspired by the original implementation in [vllm](https://github.com/vllm-project/vllm)
+        # Calculate the number of blocks that can be allocated with the free memory
+        dtype_size = torch.tensor([], dtype=self.dtype).element_size()
+        cache_block_size = BLOCK_SIZE * self.num_kv_heads * self.head_size
+        total_cache_size = self.num_layers * cache_block_size * 2 * dtype_size
+
+        free_memory = get_free_memory(self.device, MEMORY_FRACTION)
+        batch_num_blocks = batch.num_blocks if batch is not None else 0
+
+        num_blocks = (
+            # Leave 5% for some wiggle room
+            int((free_memory * 0.95) // total_cache_size)
+            # Add batch.num_blocks as we allocated it above, so it is included in the peak memory.
+            + batch_num_blocks
+        )
+
+        del batch
+
+        self.init_kv_cache(
+            num_blocks,
+            self.num_layers,
+            self.num_kv_heads,
+            self.head_size,
+            self.dtype,
+            self.device,
+        )
+
+        return int(num_blocks * BLOCK_SIZE)
+
+    def forward(
+        self, batch: FlashCausalLMBatch, adapter_data: AdapterBatchData
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+        # NOTE: adapter_data: not supported
+
+        input_ids = batch.input_ids
+        position_ids = batch.position_ids
+        cu_seqlen_prefill = batch.cu_seqlen_prefill
+        kv_cache = self.kv_cache
+        block_tables = batch.block_tables_tensor
+        slots = batch.slots[batch.slot_indices]
+        input_lengths = batch.input_lengths_tensor
+        max_s = batch.max_seqlen
+        lm_head_indices = batch.prefill_head_indices
+
+        # TODO felix: support window attention
+        # if cu_seqlen_prefill is None and self.max_past() is not None:
+        #     # In decode, not prefill, we're actually overwriting the KV-cache
+        #     # in a circular buffer mode.
+        #     # This makes sure the max_s for the decode pass is correct.
+        #     max_s = min(self.max_past(), max_s)
+
+        bs = input_ids.shape[0]
+
+        logits = self.model.forward(
+            input_ids=input_ids,
+            position_ids=position_ids,
+            past_key_values=PagedCache(),
+            cu_seqlen_prefill=cu_seqlen_prefill,
+            kv_cache=kv_cache,
+            block_tables=block_tables,
+            slots=slots,
+            input_lengths=input_lengths,
+            max_s=max_s,
+            prefill_cache_indices=batch.prefill_cache_indices,
+            lm_head_indices=lm_head_indices,
+            cache_position=False,
+            return_dict=False,
+        )[0]
+
+        if lm_head_indices is not None:
+            logits = logits[lm_head_indices]
+
+        if batch.prefill_cache_indices is not None:
+            batch.prefill_cache_indices = None
+
+        speculative_logits = None
+
+        return logits, speculative_logits
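The core mechanism in the new transformers_flash_causal_lm.py is patch_everywhere: at warmup it swaps an attribute in every already-imported module, which is how the paged-attention-aware _flash_attention_forward is injected into Transformers' modeling code. A minimal, self-contained sketch of that patching idea follows; the toy module and attribute names are illustrative only and are not part of this patch.

import sys
import types
from typing import Any, Optional


def patch_everywhere(
    attribute_name: str, patch: Any, module_name_prefix: Optional[str] = None
):
    # Same logic as the helper added above: walk every loaded module and
    # overwrite the attribute wherever it is present.
    for name in list(sys.modules):
        module = sys.modules[name]
        if module_name_prefix is not None and not name.startswith(module_name_prefix):
            continue
        if hasattr(module, attribute_name):
            setattr(module, attribute_name, patch)


# Toy stand-in for a transformers modeling module (illustrative only).
fake_modeling = types.ModuleType("fake_modeling")
fake_modeling._flash_attention_forward = lambda *args, **kwargs: "original"
sys.modules["fake_modeling"] = fake_modeling

patch_everywhere("_flash_attention_forward", lambda *args, **kwargs: "patched")
assert sys.modules["fake_modeling"]._flash_attention_forward() == "patched"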