diff --git a/.circleci/create_circleci_config.py b/.circleci/create_circleci_config.py
index 5573114298fbbc..7e1414b2d8cac9 100644
--- a/.circleci/create_circleci_config.py
+++ b/.circleci/create_circleci_config.py
@@ -58,8 +58,10 @@ class CircleCIJob:
     marker: Optional[str] = None
     parallelism: Optional[int] = 1
     pytest_num_workers: int = 12
     pytest_options: Dict[str, Any] = None
     resource_class: Optional[str] = "2xlarge"
     tests_to_run: Optional[List[str]] = None
     working_directory: str = "~/transformers"
     # This should be only used for doctest job!
@@ -260,6 +262,7 @@ def job_name(self):
     install_steps=["uv venv", "uv pip install -e ."],
     parallelism=1,
     pytest_num_workers=12,
 )
@@ -286,6 +289,7 @@ def job_name(self):
     install_steps=["uv venv", "uv pip install -e ."],
     marker="is_pipeline_test",
     pytest_num_workers=12,
 )
@@ -367,6 +371,7 @@ def job_name(self):
         "tests/models/*nat",
         "tests/models/deta",
         "tests/models/udop",
         "tests/models/nougat",
     ],
     pytest_num_workers=1,
diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml
index e725e1705c1657..afc997a7aa54b7 100644
--- a/docs/source/en/_toctree.yml
+++ b/docs/source/en/_toctree.yml
@@ -27,6 +27,8 @@
     title: Agents
   - local: llm_tutorial
     title: Generation with LLMs
+  - local: conversations
+    title: Chatting with Transformers
   title: Tutorials
 - sections:
   - isExpanded: false
diff --git a/docs/source/en/conversations.md b/docs/source/en/conversations.md
new file mode 100644
index 00000000000000..9336503ad7cb8c
--- /dev/null
+++ b/docs/source/en/conversations.md
@@ -0,0 +1,290 @@

# Chatting with Transformers

If you're reading this article, you're almost certainly aware of **chat models**. Chat models are conversational
AIs that you can send and receive messages with. The most famous of these is the proprietary ChatGPT, but there are
now many open-source chat models which match or even substantially exceed its performance. These models are free to
download and run on a local machine. Although the largest and most capable models require high-powered hardware
and lots of memory to run, there are smaller models that will run perfectly well on a single consumer GPU, or even
an ordinary desktop or notebook CPU.

This guide will help you get started with chat models. We'll start with a brief quickstart guide that uses a convenient,
high-level "pipeline". This is all you need if you just want to start running a chat model
immediately. After the quickstart, we'll move on to more detailed information about
what exactly chat models are, how to choose an appropriate one, and a low-level breakdown of each of the
steps involved in talking to a chat model. We'll also give some tips on optimizing the performance and memory usage
of your chat models.


## Quickstart

If you have no time for details, here's the brief summary: Chat models continue chats. This means that you pass them
a conversation history, which can be as short as a single user message, and the model will continue the conversation
by adding its response. Let's see this in action.
First, let's build a chat: + +```python +chat = [ + {"role": "system", "content": "You are a sassy, wise-cracking robot as imagined by Hollywood circa 1986."}, + {"role": "user", "content": "Hey, can you tell me any fun things to do in New York?"} +] +``` + +Notice that in addition to the user's message, we added a **system** message at the start of the conversation. Not all +chat models support system messages, but when they do, they represent high-level directives about how the model +should behave in the conversation. You can use this to guide the model - whether you want short or long responses, +lighthearted or serious ones, and so on. If you want the model to do useful work instead of +practicing its improv routine, you can either omit the system message or try a terse one such as "You are a helpful and intelligent +AI assistant who responds to user queries." + +Once you have a chat, the quickest way to continue it is using the [`TextGenerationPipeline`]. +Let's see this in action with `LLaMA-3`. Note that `LLaMA-3` is a gated model, which means you will need to +[apply for access](https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct) and log in with your Hugging Face +account to use it. We'll also use `device_map="auto"`, which will load the model on GPU if there's enough memory +for it, and set the dtype to `torch.bfloat16` to save memory: + +```python +import torch +from transformers import pipeline + +pipe = pipeline("text-generation", "meta-llama/Meta-Llama-3-8B-Instruct", torch_dtype=torch.bfloat16, device_map="auto") +response = pipe(chat, max_new_tokens=512) +print(response[0]['generated_text'][-1]['content']) +``` + +And you'll get: + +```text +(sigh) Oh boy, you're asking me for advice? You're gonna need a map, pal! Alright, +alright, I'll give you the lowdown. But don't say I didn't warn you, I'm a robot, not a tour guide! + +So, you wanna know what's fun to do in the Big Apple? Well, let me tell you, there's a million +things to do, but I'll give you the highlights. First off, you gotta see the sights: the Statue of +Liberty, Central Park, Times Square... you know, the usual tourist traps. But if you're lookin' for +something a little more... unusual, I'd recommend checkin' out the Museum of Modern Art. It's got +some wild stuff, like that Warhol guy's soup cans and all that jazz. + +And if you're feelin' adventurous, take a walk across the Brooklyn Bridge. Just watch out for +those pesky pigeons, they're like little feathered thieves! (laughs) Get it? Thieves? Ah, never mind. + +Now, if you're lookin' for some serious fun, hit up the comedy clubs in Greenwich Village. You might +even catch a glimpse of some up-and-coming comedians... or a bunch of wannabes tryin' to make it big. (winks) + +And finally, if you're feelin' like a real New Yorker, grab a slice of pizza from one of the many amazing +pizzerias around the city. Just don't try to order a "robot-sized" slice, trust me, it won't end well. (laughs) + +So, there you have it, pal! That's my expert advice on what to do in New York. Now, if you'll +excuse me, I've got some oil changes to attend to. (winks) +``` + +You can continue the chat by appending your own response to it. 
The +`response` object returned by the pipeline actually contains the entire chat so far, so we can simply append +a message and pass it back: + +```python +chat = response[0]['generated_text'] +chat.append( + {"role": "user", "content": "Wait, what's so wild about soup cans?"} +) +response = pipe(chat, max_new_tokens=512) +print(response[0]['generated_text'][-1]['content']) +``` + +And you'll get: + +```text +(laughs) Oh, you're killin' me, pal! You don't get it, do you? Warhol's soup cans are like, art, man! +It's like, he took something totally mundane, like a can of soup, and turned it into a masterpiece. It's +like, "Hey, look at me, I'm a can of soup, but I'm also a work of art!" +(sarcastically) Oh, yeah, real original, Andy. + +But, you know, back in the '60s, it was like, a big deal. People were all about challenging the +status quo, and Warhol was like, the king of that. He took the ordinary and made it extraordinary. +And, let me tell you, it was like, a real game-changer. I mean, who would've thought that a can of soup could be art? (laughs) + +But, hey, you're not alone, pal. I mean, I'm a robot, and even I don't get it. (winks) +But, hey, that's what makes art, art, right? (laughs) +``` + +The remainder of this tutorial will cover specific topics such +as performance and memory, or how to select a chat model for your needs. + +## Choosing a chat model + +There are an enormous number of different chat models available on the [Hugging Face Hub](https://huggingface.co/models?pipeline_tag=text-generation&sort=trending), +and new users often feel very overwhelmed by the selection offered. Don't be, though! You really need to just focus on +two important considerations: +- The model's size, which will determine if you can fit it in memory and how quickly it will +run. +- The quality of the model's chat output. + +In general, these are correlated - bigger models tend to be +more capable, but even so there's a lot of variation at a given size point! + +### Size and model naming +The size of a model is easy to spot - it's the number in the model name, like "8B" or "70B". This is the number of +**parameters** in the model. Without quantization, you should expect to need about 2 bytes of memory per parameter. +This means that an "8B" model with 8 billion parameters will need about 16GB of memory just to fit the parameters, +plus a little extra for other overhead. It's a good fit for a high-end consumer GPU with 24GB of memory, such as a 3090 +or 4090. + +Some chat models are "Mixture of Experts" models. These may list their sizes in different ways, such as "8x7B" or +"141B-A35B". The numbers are a little fuzzier here, but in general you can read this as saying that the model +has approximately 56 (8x7) billion parameters in the first case, or 141 billion parameters in the second case. + +Note that it is very common to use quantization techniques to reduce the memory usage per parameter to 8 bits, 4 bits, +or even less. This topic is discussed in more detail in the [Memory considerations](#memory-considerations) section below. + +### But which chat model is best? +Even once you know the size of chat model you can run, there's still a lot of choice out there. One way to sift through +it all is to consult **leaderboards**. Two of the most popular leaderboards are the [OpenLLM Leaderboard](https://huggingface.co/spaces/HuggingFaceH4/open_llm_leaderboard) +and the [LMSys Chatbot Arena Leaderboard](https://chat.lmsys.org/?leaderboard). 
Note that the LMSys leaderboard +also includes proprietary models - look at the `licence` column to identify open-source ones that you can download, then +search for them on the [Hugging Face Hub](https://huggingface.co/models?pipeline_tag=text-generation&sort=trending). + +### Specialist domains +Some models may be specialized for certain domains, such as medical or legal text, or non-English languages. +If you're working in these domains, you may find that a specialized model will give you big performance benefits. +Don't automatically assume that, though! Particularly when specialized models are smaller or older than the current +cutting-edge, a top-end general-purpose model may still outclass them. Thankfully, we are beginning to see +[domain-specific leaderboards](https://huggingface.co/blog/leaderboard-medicalllm) that should make it easier to locate +the best models for specialized domains. + +## What happens inside the pipeline? + +The quickstart above used a high-level pipeline to chat with a chat model, which is convenient, but not the +most flexible. Let's take a more low-level approach, to see each of the steps involved in chat. Let's start with +a code sample, and then break it down: + +```python +from transformers import AutoModelForCausalLM, AutoTokenizer +import torch + +# Prepare the input as before +chat = [ + {"role": "system", "content": "You are a sassy, wise-cracking robot as imagined by Hollywood circa 1986."}, + {"role": "user", "content": "Hey, can you tell me any fun things to do in New York?"} +] + +# 1: Load the model and tokenizer +model = AutoModelForCausalLM.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct", device_map="auto", torch_dtype=torch.bfloat16) +tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct") + +# 2: Apply the chat template +formatted_chat = tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=True) +print("Formatted chat:\n", formatted_chat) + +# 3: Tokenize the chat (This can be combined with the previous step using tokenize=True) +inputs = tokenizer(formatted_chat, return_tensors="pt", add_special_tokens=False) +# Move the tokenized inputs to the same device the model is on (GPU/CPU) +inputs = {key: tensor.to(model.device) for key, tensor in inputs.items()} +print("Tokenized inputs:\n", inputs) + +# 4: Generate text from the model +outputs = model.generate(**inputs, max_new_tokens=512, temperature=0.) +print("Generated tokens:\n", outputs) + +# 5: Decode the output back to a string +decoded_output = tokenizer.decode(outputs[0][inputs['input_ids'].size(1):], skip_special_tokens=True) +print("Decoded output:\n", decoded_output) +``` + +There's a lot in here, each piece of which could be its own document! Rather than going into too much detail, I'll cover +the broad ideas, and leave the details for the linked documents. The key steps are: + +1. [Models](https://huggingface.co/learn/nlp-course/en/chapter2/3) and [Tokenizers](https://huggingface.co/learn/nlp-course/en/chapter2/4?fw=pt) are loaded from the Hugging Face Hub. +2. The chat is formatted using the tokenizer's [chat template](https://huggingface.co/docs/transformers/main/en/chat_templating) +3. The formatted chat is [tokenized](https://huggingface.co/learn/nlp-course/en/chapter2/4) using the tokenizer. +4. We [generate](https://huggingface.co/docs/transformers/en/llm_tutorial) a response from the model. +5. 
The tokens output by the model are decoded back to a string + +## Performance, memory and hardware + +You probably know by now that most machine learning tasks are run on GPUs. However, it is entirely possible +to generate text from a chat model or language model on a CPU, albeit somewhat more slowly. If you can fit +the model in GPU memory, though, this will usually be the preferable option. + +### Memory considerations + +By default, Hugging Face classes like [`TextGenerationPipeline`] or [`AutoModelForCausalLM`] will load the model in +`float32` precision. This means that it will need 4 bytes (32 bits) per parameter, so an "8B" model with 8 billion +parameters will need ~32GB of memory. However, this can be wasteful! Most modern language models are trained in +"bfloat16" precision, which uses only 2 bytes per parameter. If your hardware supports it (Nvidia 30xx/Axxx +or newer), you can load the model in `bfloat16` precision, using the `torch_dtype` argument as we did above. + +It is possible to go even lower than 16-bits using "quantization", a method to lossily compress model weights. This +allows each parameter to be squeezed down to 8 bits, 4 bits or even less. Note that, especially at 4 bits, +the model's outputs may be negatively affected, but often this is a tradeoff worth making to fit a larger and more +capable chat model in memory. Let's see this in action with `bitsandbytes`: + +```python +from transformers import AutoModelForCausalLM, BitsAndBytesConfig + +quantization_config = BitsAndBytesConfig(load_in_8bit=True) # You can also try load_in_4bit +model = AutoModelForCausalLM.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct", device_map="auto", quantization_config=quantization_config) +``` + +Or we can do the same thing using the `pipeline` API: + +```python +from transformers import pipeline, BitsAndBytesConfig + +quantization_config = BitsAndBytesConfig(load_in_8bit=True) # You can also try load_in_4bit +pipe = pipeline("text-generation", "meta-llama/Meta-Llama-3-8B-Instruct", device_map="auto", model_kwargs={"quantization_config": quantization_config}) +``` + +There are several other options for quantizing models besides `bitsandbytes` - please see the [Quantization guide](./quantization) +for more information. + +### Performance considerations + + + +For a more extensive guide on language model performance and optimization, check out [LLM Inference Optimization](./llm_optims) . + + + + +As a general rule, larger chat models will be slower in addition to requiring more memory. It's possible to be +more concrete about this, though: Generating text from a chat model is unusual in that it is bottlenecked by +**memory bandwidth** rather than compute power, because every active parameter must be read from memory for each +token that the model generates. This means that number of tokens per second you can generate from a chat +model is generally proportional to the total bandwidth of the memory it resides in, divided by the size of the model. + +In our quickstart example above, our model was ~16GB in size when loaded in `bfloat16` precision. +This means that 16GB must be read from memory for every token generated by the model. Total memory bandwidth can +vary from 20-100GB/sec for consumer CPUs to 200-900GB/sec for consumer GPUs, specialized CPUs like +Intel Xeon, AMD Threadripper/Epyc or high-end Apple silicon, and finally up to 2-3TB/sec for data center GPUs like +the Nvidia A100 or H100. 
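
To make this rule of thumb concrete, here is a minimal back-of-envelope sketch. The bandwidth figures below are
illustrative assumptions picked from the ranges above, not measurements of any particular device:

```python
# Rough upper bound: tokens/sec ≈ memory bandwidth / bytes read per generated token.
# For a dense model, the bytes read per token is roughly the size of the loaded weights.
model_size_gb = 16  # an "8B" model loaded in bfloat16 (~2 bytes per parameter)

# Illustrative bandwidth assumptions in GB/sec (not benchmarks)
illustrative_bandwidth_gb_s = {
    "consumer CPU": 50,
    "consumer GPU": 600,
    "data center GPU": 2000,
}

for hardware, bandwidth in illustrative_bandwidth_gb_s.items():
    tokens_per_second = bandwidth / model_size_gb
    print(f"{hardware}: ~{tokens_per_second:.0f} tokens/sec (upper bound)")
```

Real-world throughput will be lower than this simple estimate because of overheads such as attention over the
growing context, but the relative ordering between hardware tiers holds.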
This should give you a good idea of the generation speed you can expect from these different +hardware types. + +Therefore, if you want to improve the speed of text generation, the easiest solution is to either reduce the +size of the model in memory (usually by quantization), or get hardware with higher memory bandwidth. For advanced users, +several other techniques exist to get around this bandwidth bottleneck. The most common are variants on +[assisted generation](https://huggingface.co/blog/assisted-generation), also known as "speculative +sampling". These techniques try to guess multiple future tokens at once, often using a smaller "draft model", and then +confirm these generations with the chat model. If the guesses are validated by the chat model, more than one token can +be generated per forward pass, which greatly alleviates the bandwidth bottleneck and improves generation speed. + +Finally, we should also note the impact of "Mixture of Experts" (MoE) models here. Several popular chat models, +such as Mixtral, Qwen-MoE and DBRX, are MoE models. In these models, not every parameter is active for every token generated. +As a result, MoE models generally have much lower memory bandwidth requirements, even though their total size +can be quite large. They can therefore be several times faster than a normal "dense" model of the same size. However, +techniques like assisted generation are generally ineffective for these models because more parameters will become +active with each new speculated token, which will negate the bandwidth and speed benefits that the MoE architecture +provides. + diff --git a/docs/source/es/_toctree.yml b/docs/source/es/_toctree.yml index 4506dbd06f96b9..cf1ae39c03077e 100644 --- a/docs/source/es/_toctree.yml +++ b/docs/source/es/_toctree.yml @@ -100,4 +100,6 @@ title: BERTología - local: perplexity title: Perplejidad de los modelos de longitud fija + - local: pipeline_webserver + title: Flujo de trabajo para la inferencia de los servidores web title: Guías conceptuales diff --git a/docs/source/es/pipeline_webserver.md b/docs/source/es/pipeline_webserver.md new file mode 100644 index 00000000000000..e77e620f58b78b --- /dev/null +++ b/docs/source/es/pipeline_webserver.md @@ -0,0 +1,134 @@ + + +# Uso de un flujo de trabajo para un servidor web + + +Crear un motor de inferencia es un tema complejo, y la "mejor" solución probablemente dependerá de tu caso de uso. ¿Estás en CPU o en GPU? ¿Quieres la latencia más baja, el rendimiento más alto, soporte para muchos modelos o simplemente optimizar altamente un modelo específico? Hay muchas formas de abordar este tema, así que lo que vamos a presentar es un buen valor predeterminado para comenzar, que no necesariamente será la solución más óptima para ti. + + + +Lo fundamental para entender es que podemos usar un iterador, tal como [en un conjunto de datos](https://huggingface.co/docs/transformers/pipeline_tutorial#using-pipelines-on-a-dataset), ya que un servidor web es básicamente un sistema que espera solicitudes y las trata a medida que llegan. + + + +Por lo general, los servidores web están multiplexados (multihilo, asíncrono, etc.) para manejar varias solicitudes simultáneamente. Por otro lado, los flujos de trabajo (y principalmente los modelos subyacentes) no son realmente ideales para el paralelismo; consumen mucha RAM, por lo que es mejor darles todos los recursos disponibles cuando se están ejecutando o es un trabajo intensivo en cómputo. 
+ +Vamos a resolver esto haciendo que el servidor web maneje la carga ligera de recibir y enviar solicitudes, y que un único hilo maneje el trabajo real. Este ejemplo va a utilizar `starlette`. El marco de trabajo no es realmente importante, pero es posible que debas ajustar o cambiar el código si estás utilizando otro para lograr el mismo efecto. + +Crear `server.py`: + +```py +from starlette.applications import Starlette +from starlette.responses import JSONResponse +from starlette.routing import Route +from transformers import pipeline +import asyncio + + +async def homepage(request): + payload = await request.body() + string = payload.decode("utf-8") + response_q = asyncio.Queue() + await request.app.model_queue.put((string, response_q)) + output = await response_q.get() + return JSONResponse(output) + + +async def server_loop(q): + pipe = pipeline(model="google-bert/bert-base-uncased") + while True: + (string, response_q) = await q.get() + out = pipe(string) + await response_q.put(out) + + +app = Starlette( + routes=[ + Route("/", homepage, methods=["POST"]), + ], +) + + +@app.on_event("startup") +async def startup_event(): + q = asyncio.Queue() + app.model_queue = q + asyncio.create_task(server_loop(q)) +``` + +Ahora puedes empezar con: +```bash +uvicorn server:app +``` + +Y puedes consultarlo con: +```bash +curl -X POST -d "test [MASK]" http://localhost:8000/ +#[{"score":0.7742936015129089,"token":1012,"token_str":".","sequence":"test."},...] +``` + +¡Y listo, ahora tienes una buena idea de cómo crear un servidor web! + +Lo realmente importante es cargar el modelo solo **una vez**, de modo que no haya copias del modelo en el servidor web. De esta manera, no se utiliza RAM innecesariamente. Luego, el mecanismo de queuing (colas) te permite hacer cosas sofisticadas como acumular algunos elementos antes de inferir para usar el agrupamiento dinámico: + + + +El ejemplo de código a continuación está escrito intencionalmente como pseudocódigo para facilitar la lectura. +¡No lo ejecutes sin verificar si tiene sentido para los recursos de tu sistema! + + + +```py +(string, rq) = await q.get() +strings = [] +queues = [] +while True: + try: + (string, rq) = await asyncio.wait_for(q.get(), timeout=0.001) # 1ms + except asyncio.exceptions.TimeoutError: + break + strings.append(string) + queues.append(rq) +strings +outs = pipe(strings, batch_size=len(strings)) +for rq, out in zip(queues, outs): + await rq.put(out) +``` + +Nuevamente, el código propuesto está optimizado para la legibilidad, no para ser el mejor código. +En primer lugar, no hay límite de tamaño de lote, lo cual generalmente no es una buena idea. Luego, el tiempo de espera se restablece en cada obtención de la cola, lo que significa que podrías esperar mucho más de 1ms antes de ejecutar la inferencia (retrasando la primera solicitud en esa cantidad). + +Sería mejor tener un único plazo de 1ms. + +Esto siempre esperará 1ms incluso si la cola está vacía, lo que podría no ser lo mejor ya que probablemente quieras comenzar a hacer inferencias si no hay nada en la cola. Pero tal vez tenga sentido si el agrupamiento es realmente crucial para tu caso de uso. Nuevamente, no hay una solución única y mejor. 
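
Como referencia, aquí hay un boceto mínimo de cómo podría verse un único plazo de 1ms combinado con un límite de
tamaño de lote. Sigue siendo pseudocódigo orientativo: nombres como `MAX_BATCH_SIZE` y `collect_batch` son
ilustrativos y no forman parte de la biblioteca:

```py
import asyncio
import time

MAX_BATCH_SIZE = 16  # límite ilustrativo de tamaño de lote
MAX_WAIT = 0.001  # plazo único de 1ms para todo el lote


async def collect_batch(q):
    # Espera la primera solicitud sin plazo: si la cola está vacía no se gasta tiempo en esperas de 1ms
    (string, rq) = await q.get()
    strings, queues = [string], [rq]
    deadline = time.monotonic() + MAX_WAIT
    # Acumula más solicitudes hasta agotar el plazo único o llenar el lote
    while len(strings) < MAX_BATCH_SIZE:
        restante = deadline - time.monotonic()
        if restante <= 0:
            break
        try:
            (string, rq) = await asyncio.wait_for(q.get(), timeout=restante)
        except asyncio.TimeoutError:
            break
        strings.append(string)
        queues.append(rq)
    return strings, queues
```

El bucle del servidor llamaría a `strings, queues = await collect_batch(q)`, ejecutaría
`outs = pipe(strings, batch_size=len(strings))` y repartiría cada resultado en su cola de respuesta, igual que antes.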
+ + +## Algunas cosas que podrías considerar + +### Comprobación de errores + +Hay muchas cosas que pueden salir mal en producción: falta de memoria, falta de espacio, cargar el modelo podría fallar, la consulta podría ser incorrecta, la consulta podría ser correcta pero aún así fallar debido a una mala configuración del modelo, y así sucesivamente. + +Generalmente, es bueno que el servidor muestre los errores al usuario, por lo que agregar muchos bloques `try..except` para mostrar esos errores es una buena idea. Pero ten en cuenta que también puede ser un riesgo de seguridad revelar todos esos errores dependiendo de tu contexto de seguridad. + +### Interrupción de circuito + +Los servidores web suelen verse mejor cuando hacen interrupciones de circuitos. Significa que devuelven errores adecuados cuando están sobrecargados en lugar de simplemente esperar la consulta indefinidamente. Devolver un error 503 en lugar de esperar un tiempo muy largo o un error 504 después de mucho tiempo. + +Esto es relativamente fácil de implementar en el código propuesto ya que hay una sola cola. Mirar el tamaño de la cola es una forma básica de empezar a devolver errores antes de que tu servidor web falle bajo carga. + +### Bloqueo del hilo principal + +Actualmente, PyTorch no es consciente de la asincronía, y el cálculo bloqueará el hilo principal mientras se ejecuta. Esto significa que sería mejor si PyTorch se viera obligado a ejecutarse en su propio hilo/proceso. Esto no se hizo aquí porque el código es mucho más complejo (principalmente porque los hilos, la asincronía y las colas no se llevan bien juntos). Pero en última instancia, hace lo mismo. + +Esto sería importante si la inferencia de elementos individuales fuera larga (> 1s) porque en este caso, significa que cada consulta durante la inferencia tendría que esperar 1s antes de recibir incluso un error. + +### Procesamiento por lotes dinámico + +En general, el procesamiento por lotes no es necesariamente una mejora respecto a pasar 1 elemento a la vez (ver [procesamiento por lotes](https://huggingface.co/docs/transformers/main_classes/pipelines#pipeline-batching) para más información). Pero puede ser muy efectivo cuando se usa en el entorno correcto. En la API, no hay procesamiento por lotes dinámico por defecto (demasiada oportunidad para una desaceleración). Pero para la inferencia de BLOOM - que es un modelo muy grande - el procesamiento por lotes dinámico es **esencial** para proporcionar una experiencia decente para todos. diff --git a/src/transformers/dynamic_module_utils.py b/src/transformers/dynamic_module_utils.py index a89105029868a1..f9abade2809166 100644 --- a/src/transformers/dynamic_module_utils.py +++ b/src/transformers/dynamic_module_utils.py @@ -15,6 +15,7 @@ """Utilities to dynamically load objects from the Hub.""" import filecmp import importlib +import importlib.util import os import re import shutil @@ -196,9 +197,15 @@ def get_class_in_module(class_name: str, module_path: Union[str, os.PathLike]) - Returns: `typing.Type`: The class looked for. 
""" - name = os.path.normpath(module_path).replace(".py", "").replace(os.path.sep, ".") - module_path = str(Path(HF_MODULES_CACHE) / module_path) - module = importlib.machinery.SourceFileLoader(name, module_path).load_module() + name = os.path.normpath(module_path).rstrip(".py").replace(os.path.sep, ".") + module_spec = importlib.util.spec_from_file_location(name, location=Path(HF_MODULES_CACHE) / module_path) + module = sys.modules.get(name) + if module is None: + module = importlib.util.module_from_spec(module_spec) + # insert it into sys.modules before any loading begins + sys.modules[name] = module + # reload in both cases + module_spec.loader.exec_module(module) return getattr(module, class_name) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index be164e8e2c0c00..1ed8040f88c5ef 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -324,7 +324,7 @@ def dtype_byte_size(dtype): """ if dtype == torch.bool: return 1 / 8 - bit_search = re.search(r"[^\d](\d+)$", str(dtype)) + bit_search = re.search(r"[^\d](\d+)_?", str(dtype)) if bit_search is None: raise ValueError(f"`dtype` is not a valid dtype: {dtype}.") bit_size = int(bit_search.groups()[0]) @@ -3170,6 +3170,9 @@ def from_pretrained( torch_dtype = hf_quantizer.update_torch_dtype(torch_dtype) device_map = hf_quantizer.update_device_map(device_map) + # In order to ensure popular quantization methods are supported. Can be disable with `disable_telemetry` + user_agent["quant"] = hf_quantizer.quantization_config.quant_method.value + # Force-set to `True` for more mem efficiency if low_cpu_mem_usage is None: low_cpu_mem_usage = True diff --git a/src/transformers/models/llava/modeling_llava.py b/src/transformers/models/llava/modeling_llava.py index 4cf5d98f77f114..5a6d49752daf58 100644 --- a/src/transformers/models/llava/modeling_llava.py +++ b/src/transformers/models/llava/modeling_llava.py @@ -327,8 +327,11 @@ def _merge_input_ids_with_image_features(self, image_features, inputs_embeds, in if labels is not None: final_labels[batch_indices, text_to_overwrite] = labels[batch_indices, non_image_indices] - # 5. Fill the embeddings corresponding to the images. Anything that is still zeros needs filling - image_to_overwrite = torch.all(final_embedding == 0, dim=-1) + # 5. Fill the embeddings corresponding to the images. Anything that is not `text_positions` needs filling (#29835) + image_to_overwrite = torch.full( + (batch_size, max_embed_dim), True, dtype=torch.bool, device=inputs_embeds.device + ) + image_to_overwrite[batch_indices, text_to_overwrite] = False image_to_overwrite &= image_to_overwrite.cumsum(-1) - 1 >= nb_image_pad[:, None].to(target_device) if image_to_overwrite.sum() != image_features.shape[:-1].numel(): diff --git a/src/transformers/models/llava_next/modeling_llava_next.py b/src/transformers/models/llava_next/modeling_llava_next.py index 155d9e3e6abf40..085999933b6533 100644 --- a/src/transformers/models/llava_next/modeling_llava_next.py +++ b/src/transformers/models/llava_next/modeling_llava_next.py @@ -403,8 +403,11 @@ def _merge_input_ids_with_image_features(self, image_features, inputs_embeds, in if labels is not None: final_labels[batch_indices, text_to_overwrite] = labels[batch_indices, non_image_indices] - # 5. Fill the embeddings corresponding to the images. Anything that is still zeros needs filling - image_to_overwrite = torch.all(final_embedding == 0, dim=-1) + # 5. Fill the embeddings corresponding to the images. 
+        image_to_overwrite = torch.full(
+            (batch_size, max_embed_dim), True, dtype=torch.bool, device=inputs_embeds.device
+        )
+        image_to_overwrite[batch_indices, text_to_overwrite] = False
         image_to_overwrite &= image_to_overwrite.cumsum(-1) - 1 >= nb_image_pad[:, None].to(target_device)
 
         if image_to_overwrite.sum() != image_features.shape[:-1].numel():
diff --git a/src/transformers/models/vipllava/modeling_vipllava.py b/src/transformers/models/vipllava/modeling_vipllava.py
index 1b20353410c895..aaffc19bd5e36a 100644
--- a/src/transformers/models/vipllava/modeling_vipllava.py
+++ b/src/transformers/models/vipllava/modeling_vipllava.py
@@ -331,8 +331,11 @@ def _merge_input_ids_with_image_features(self, image_features, inputs_embeds, in
         if labels is not None:
             final_labels[batch_indices, text_to_overwrite] = labels[batch_indices, non_image_indices]
 
-        # 5. Fill the embeddings corresponding to the images. Anything that is still zeros needs filling
-        image_to_overwrite = torch.all(final_embedding == 0, dim=-1)
+        # 5. Fill the embeddings corresponding to the images. Anything that is not `text_positions` needs filling (#29835)
+        image_to_overwrite = torch.full(
+            (batch_size, max_embed_dim), True, dtype=torch.bool, device=inputs_embeds.device
+        )
+        image_to_overwrite[batch_indices, text_to_overwrite] = False
         image_to_overwrite &= image_to_overwrite.cumsum(-1) - 1 >= nb_image_pad[:, None].to(target_device)
 
         if image_to_overwrite.sum() != image_features.shape[:-1].numel():
diff --git a/src/transformers/quantizers/quantizer_bnb_4bit.py b/src/transformers/quantizers/quantizer_bnb_4bit.py
index 112cfd644f1570..764475e7b1db01 100644
--- a/src/transformers/quantizers/quantizer_bnb_4bit.py
+++ b/src/transformers/quantizers/quantizer_bnb_4bit.py
@@ -84,14 +84,12 @@ def validate_environment(self, *args, **kwargs):
         }
         if "cpu" in device_map_without_lm_head.values() or "disk" in device_map_without_lm_head.values():
             raise ValueError(
-                """
-                Some modules are dispatched on the CPU or the disk. Make sure you have enough GPU RAM to fit the
-                quantized model. If you want to dispatch the model on the CPU or the disk while keeping these modules
-                in 32-bit, you need to set `llm_int8_enable_fp32_cpu_offload=True` and pass a custom `device_map` to
-                `from_pretrained`. Check
-                https://huggingface.co/docs/transformers/main/en/main_classes/quantization#offload-between-cpu-and-gpu
-                for more details.
-                """
+                "Some modules are dispatched on the CPU or the disk. Make sure you have enough GPU RAM to fit the "
+                "quantized model. If you want to dispatch the model on the CPU or the disk while keeping these modules "
+                "in 32-bit, you need to set `llm_int8_enable_fp32_cpu_offload=True` and pass a custom `device_map` to "
+                "`from_pretrained`. Check "
+                "https://huggingface.co/docs/transformers/main/en/main_classes/quantization#offload-between-cpu-and-gpu "
+                "for more details. "
             )
 
         if version.parse(importlib.metadata.version("bitsandbytes")) < version.parse("0.39.0"):
diff --git a/src/transformers/quantizers/quantizer_bnb_8bit.py b/src/transformers/quantizers/quantizer_bnb_8bit.py
index b80e9bd3a1dfa2..8016194f9d86c6 100644
--- a/src/transformers/quantizers/quantizer_bnb_8bit.py
+++ b/src/transformers/quantizers/quantizer_bnb_8bit.py
@@ -84,14 +84,12 @@ def validate_environment(self, *args, **kwargs):
         }
         if "cpu" in device_map_without_lm_head.values() or "disk" in device_map_without_lm_head.values():
             raise ValueError(
-                """
-                Some modules are dispatched on the CPU or the disk. Make sure you have enough GPU RAM to fit the
-                quantized model. If you want to dispatch the model on the CPU or the disk while keeping these modules
-                in 32-bit, you need to set `llm_int8_enable_fp32_cpu_offload=True` and pass a custom `device_map` to
-                `from_pretrained`. Check
-                https://huggingface.co/docs/transformers/main/en/main_classes/quantization#offload-between-cpu-and-gpu
-                for more details.
-                """
+                "Some modules are dispatched on the CPU or the disk. Make sure you have enough GPU RAM to fit the "
+                "quantized model. If you want to dispatch the model on the CPU or the disk while keeping these modules "
+                "in 32-bit, you need to set `llm_int8_enable_fp32_cpu_offload=True` and pass a custom `device_map` to "
+                "`from_pretrained`. Check "
+                "https://huggingface.co/docs/transformers/main/en/main_classes/quantization#offload-between-cpu-and-gpu "
+                "for more details. "
             )
 
         if version.parse(importlib.metadata.version("bitsandbytes")) < version.parse("0.37.2"):
diff --git a/src/transformers/quantizers/quantizer_eetq.py b/src/transformers/quantizers/quantizer_eetq.py
index 547037a5978ad4..7be0a6bd9e033d 100644
--- a/src/transformers/quantizers/quantizer_eetq.py
+++ b/src/transformers/quantizers/quantizer_eetq.py
@@ -167,4 +167,4 @@ def is_serializable(self):
 
     @property
     def is_trainable(self) -> bool:
-        return False
+        return True
diff --git a/tests/models/llava_next/test_modeling_llava_next.py b/tests/models/llava_next/test_modeling_llava_next.py
index 1c7e3200904379..3656bb65051bde 100644
--- a/tests/models/llava_next/test_modeling_llava_next.py
+++ b/tests/models/llava_next/test_modeling_llava_next.py
@@ -459,3 +459,29 @@ def test_small_model_integration_test_batch(self):
         EXPECTED_DECODED_TEXT = ['[INST] \nWhat is shown in this image? [/INST] The image appears to be a radar chart, which is a type of multi-dimensional plot that displays', '[INST] \nWhat is shown in this image? [/INST] The image shows two cats lying on a pink surface, which appears to be a couch or a cush'] # fmt: skip
 
         self.assertEqual(self.processor.batch_decode(output, skip_special_tokens=True), EXPECTED_DECODED_TEXT)
+
+    @slow
+    @require_bitsandbytes
+    def test_small_model_integration_test_unk_token(self):
+        # related to (#29835)
+        model = LlavaNextForConditionalGeneration.from_pretrained(
+            "llava-hf/llava-v1.6-mistral-7b-hf",
+            load_in_4bit=True,
+        )
+
+        prompt_with_unk = "[INST] <image>\nWhat is shown in this image?<unk> [/INST]"
+        inputs = self.processor(prompt_with_unk, self.image, return_tensors="pt")
+
+        # verify single forward pass
+        inputs = inputs.to(torch_device)
+        with torch.no_grad():
+            output = model(**inputs)
+
+        # verify generation
+        output = model.generate(**inputs, max_new_tokens=40)
+        EXPECTED_DECODED_TEXT = '[INST] \nWhat is shown in this image? [/INST] The image appears to be a radar chart, which is a type of multi-dimensional plot that displays values for multiple quantitative variables represented on axes starting from the same point.
This particular radar chart' # fmt: skip + + self.assertEqual( + self.processor.decode(output[0], skip_special_tokens=True), + EXPECTED_DECODED_TEXT, + ) diff --git a/tests/test_modeling_utils.py b/tests/test_modeling_utils.py index ba0bf8e6b27ebb..16d8e9e1293d75 100755 --- a/tests/test_modeling_utils.py +++ b/tests/test_modeling_utils.py @@ -101,7 +101,12 @@ _prepare_4d_attention_mask, _prepare_4d_causal_attention_mask, ) - from transformers.modeling_utils import _find_disjoint, _find_identical, shard_checkpoint + from transformers.modeling_utils import ( + _find_disjoint, + _find_identical, + dtype_byte_size, + shard_checkpoint, + ) # Fake pretrained models for tests class BaseModel(PreTrainedModel): @@ -465,6 +470,31 @@ def test_model_from_pretrained_attn_implementation(self): module.__class__.__name__, mistral_attention_classes[requested_attn_implementation] ) + def test_torch_dtype_byte_sizes(self): + torch_dtypes_and_bytes = [ + (torch.double, 8), + (torch.float64, 8), + (torch.float, 4), + (torch.float32, 4), + (torch.half, 2), + (torch.float16, 2), + (torch.bfloat16, 2), + (torch.long, 8), + (torch.int64, 8), + (torch.int, 4), + (torch.int32, 4), + (torch.short, 2), + (torch.int16, 2), + (torch.uint8, 1), + (torch.int8, 1), + (torch.float8_e4m3fn, 1), + (torch.float8_e5m2, 1), + (torch.bool, 0.125), + ] + + for torch_dtype, bytes_per_element in torch_dtypes_and_bytes: + self.assertEqual(dtype_byte_size(torch_dtype), bytes_per_element) + def test_no_super_init_config_and_model(self): config = NoSuperInitConfig(attribute=32) model = NoSuperInitModel(config)