Skip to content

Commit

Permalink
Merge branch 'main' of github.com:huggingface/transformers into chang…
Browse files Browse the repository at this point in the history
…e-ci
  • Loading branch information
ArthurZucker committed Apr 26, 2024
2 parents 330f3b8 + 20081c7 commit 7638069
Show file tree
Hide file tree
Showing 15 changed files with 532 additions and 28 deletions.
5 changes: 5 additions & 0 deletions .circleci/create_circleci_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,8 +58,10 @@ class CircleCIJob:
marker: Optional[str] = None
parallelism: Optional[int] = 1
pytest_num_workers: int = 12
pytest_num_workers: int = 12
pytest_options: Dict[str, Any] = None
resource_class: Optional[str] = "2xlarge"
resource_class: Optional[str] = "2xlarge"
tests_to_run: Optional[List[str]] = None
working_directory: str = "~/transformers"
# This should be only used for doctest job!
Expand Down Expand Up @@ -260,6 +262,7 @@ def job_name(self):
install_steps=["uv venv", "uv pip install -e ."],
parallelism=1,
pytest_num_workers=12,
pytest_num_workers=12,
)


Expand All @@ -286,6 +289,7 @@ def job_name(self):
install_steps=["uv venv", "uv pip install -e ."],
marker="is_pipeline_test",
pytest_num_workers=12,
pytest_num_workers=12,
)


Expand Down Expand Up @@ -367,6 +371,7 @@ def job_name(self):
"tests/models/*nat",
"tests/models/deta",
"tests/models/udop",
"tests/models/udop",
"tests/models/nougat",
],
pytest_num_workers=1,
Expand Down
2 changes: 2 additions & 0 deletions docs/source/en/_toctree.yml
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@
title: Agents
- local: llm_tutorial
title: Generation with LLMs
- local: conversations
title: Chatting with Transformers
title: Tutorials
- sections:
- isExpanded: false
Expand Down
290 changes: 290 additions & 0 deletions docs/source/en/conversations.md

Large diffs are not rendered by default.

2 changes: 2 additions & 0 deletions docs/source/es/_toctree.yml
Original file line number Diff line number Diff line change
Expand Up @@ -100,4 +100,6 @@
title: BERTología
- local: perplexity
title: Perplejidad de los modelos de longitud fija
- local: pipeline_webserver
title: Flujo de trabajo para la inferencia de los servidores web
title: Guías conceptuales
134 changes: 134 additions & 0 deletions docs/source/es/pipeline_webserver.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
<!--⚠️ Note that this file is in Markdown but contain specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.
-->

# Uso de un flujo de trabajo para un servidor web

<Tip>
Crear un motor de inferencia es un tema complejo, y la "mejor" solución probablemente dependerá de tu caso de uso. ¿Estás en CPU o en GPU? ¿Quieres la latencia más baja, el rendimiento más alto, soporte para muchos modelos o simplemente optimizar altamente un modelo específico? Hay muchas formas de abordar este tema, así que lo que vamos a presentar es un buen valor predeterminado para comenzar, que no necesariamente será la solución más óptima para ti.
</Tip>


Lo fundamental para entender es que podemos usar un iterador, tal como [en un conjunto de datos](https://huggingface.co/docs/transformers/pipeline_tutorial#using-pipelines-on-a-dataset), ya que un servidor web es básicamente un sistema que espera solicitudes y las trata a medida que llegan.

<!--
To do:
* Check the content of es/pipeline_tutorial.md
* And update the link [en un conjunto de datos] -> (pipeline_tutorial#pipelines-en-un-conjunto-de-datos)
-->

Por lo general, los servidores web están multiplexados (multihilo, asíncrono, etc.) para manejar varias solicitudes simultáneamente. Por otro lado, los flujos de trabajo (y principalmente los modelos subyacentes) no son realmente ideales para el paralelismo; consumen mucha RAM, por lo que es mejor darles todos los recursos disponibles cuando se están ejecutando o es un trabajo intensivo en cómputo.

Vamos a resolver esto haciendo que el servidor web maneje la carga ligera de recibir y enviar solicitudes, y que un único hilo maneje el trabajo real. Este ejemplo va a utilizar `starlette`. El marco de trabajo no es realmente importante, pero es posible que debas ajustar o cambiar el código si estás utilizando otro para lograr el mismo efecto.

Crear `server.py`:

```py
from starlette.applications import Starlette
from starlette.responses import JSONResponse
from starlette.routing import Route
from transformers import pipeline
import asyncio


async def homepage(request):
payload = await request.body()
string = payload.decode("utf-8")
response_q = asyncio.Queue()
await request.app.model_queue.put((string, response_q))
output = await response_q.get()
return JSONResponse(output)


async def server_loop(q):
pipe = pipeline(model="google-bert/bert-base-uncased")
while True:
(string, response_q) = await q.get()
out = pipe(string)
await response_q.put(out)


app = Starlette(
routes=[
Route("/", homepage, methods=["POST"]),
],
)


@app.on_event("startup")
async def startup_event():
q = asyncio.Queue()
app.model_queue = q
asyncio.create_task(server_loop(q))
```

Ahora puedes empezar con:
```bash
uvicorn server:app
```

Y puedes consultarlo con:
```bash
curl -X POST -d "test [MASK]" http://localhost:8000/
#[{"score":0.7742936015129089,"token":1012,"token_str":".","sequence":"test."},...]
```

¡Y listo, ahora tienes una buena idea de cómo crear un servidor web!

Lo realmente importante es cargar el modelo solo **una vez**, de modo que no haya copias del modelo en el servidor web. De esta manera, no se utiliza RAM innecesariamente. Luego, el mecanismo de queuing (colas) te permite hacer cosas sofisticadas como acumular algunos elementos antes de inferir para usar el agrupamiento dinámico:

<Tip warning={true}>

El ejemplo de código a continuación está escrito intencionalmente como pseudocódigo para facilitar la lectura.
¡No lo ejecutes sin verificar si tiene sentido para los recursos de tu sistema!

</Tip>

```py
(string, rq) = await q.get()
strings = []
queues = []
while True:
try:
(string, rq) = await asyncio.wait_for(q.get(), timeout=0.001) # 1ms
except asyncio.exceptions.TimeoutError:
break
strings.append(string)
queues.append(rq)
strings
outs = pipe(strings, batch_size=len(strings))
for rq, out in zip(queues, outs):
await rq.put(out)
```

Nuevamente, el código propuesto está optimizado para la legibilidad, no para ser el mejor código.
En primer lugar, no hay límite de tamaño de lote, lo cual generalmente no es una buena idea. Luego, el tiempo de espera se restablece en cada obtención de la cola, lo que significa que podrías esperar mucho más de 1ms antes de ejecutar la inferencia (retrasando la primera solicitud en esa cantidad).

Sería mejor tener un único plazo de 1ms.

Esto siempre esperará 1ms incluso si la cola está vacía, lo que podría no ser lo mejor ya que probablemente quieras comenzar a hacer inferencias si no hay nada en la cola. Pero tal vez tenga sentido si el agrupamiento es realmente crucial para tu caso de uso. Nuevamente, no hay una solución única y mejor.


## Algunas cosas que podrías considerar

### Comprobación de errores

Hay muchas cosas que pueden salir mal en producción: falta de memoria, falta de espacio, cargar el modelo podría fallar, la consulta podría ser incorrecta, la consulta podría ser correcta pero aún así fallar debido a una mala configuración del modelo, y así sucesivamente.

Generalmente, es bueno que el servidor muestre los errores al usuario, por lo que agregar muchos bloques `try..except` para mostrar esos errores es una buena idea. Pero ten en cuenta que también puede ser un riesgo de seguridad revelar todos esos errores dependiendo de tu contexto de seguridad.

### Interrupción de circuito

Los servidores web suelen verse mejor cuando hacen interrupciones de circuitos. Significa que devuelven errores adecuados cuando están sobrecargados en lugar de simplemente esperar la consulta indefinidamente. Devolver un error 503 en lugar de esperar un tiempo muy largo o un error 504 después de mucho tiempo.

Esto es relativamente fácil de implementar en el código propuesto ya que hay una sola cola. Mirar el tamaño de la cola es una forma básica de empezar a devolver errores antes de que tu servidor web falle bajo carga.

### Bloqueo del hilo principal

Actualmente, PyTorch no es consciente de la asincronía, y el cálculo bloqueará el hilo principal mientras se ejecuta. Esto significa que sería mejor si PyTorch se viera obligado a ejecutarse en su propio hilo/proceso. Esto no se hizo aquí porque el código es mucho más complejo (principalmente porque los hilos, la asincronía y las colas no se llevan bien juntos). Pero en última instancia, hace lo mismo.

Esto sería importante si la inferencia de elementos individuales fuera larga (> 1s) porque en este caso, significa que cada consulta durante la inferencia tendría que esperar 1s antes de recibir incluso un error.

### Procesamiento por lotes dinámico

En general, el procesamiento por lotes no es necesariamente una mejora respecto a pasar 1 elemento a la vez (ver [procesamiento por lotes](https://huggingface.co/docs/transformers/main_classes/pipelines#pipeline-batching) para más información). Pero puede ser muy efectivo cuando se usa en el entorno correcto. En la API, no hay procesamiento por lotes dinámico por defecto (demasiada oportunidad para una desaceleración). Pero para la inferencia de BLOOM - que es un modelo muy grande - el procesamiento por lotes dinámico es **esencial** para proporcionar una experiencia decente para todos.
13 changes: 10 additions & 3 deletions src/transformers/dynamic_module_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
"""Utilities to dynamically load objects from the Hub."""
import filecmp
import importlib
import importlib.util
import os
import re
import shutil
Expand Down Expand Up @@ -196,9 +197,15 @@ def get_class_in_module(class_name: str, module_path: Union[str, os.PathLike]) -
Returns:
`typing.Type`: The class looked for.
"""
name = os.path.normpath(module_path).replace(".py", "").replace(os.path.sep, ".")
module_path = str(Path(HF_MODULES_CACHE) / module_path)
module = importlib.machinery.SourceFileLoader(name, module_path).load_module()
name = os.path.normpath(module_path).rstrip(".py").replace(os.path.sep, ".")
module_spec = importlib.util.spec_from_file_location(name, location=Path(HF_MODULES_CACHE) / module_path)
module = sys.modules.get(name)
if module is None:
module = importlib.util.module_from_spec(module_spec)
# insert it into sys.modules before any loading begins
sys.modules[name] = module
# reload in both cases
module_spec.loader.exec_module(module)
return getattr(module, class_name)


Expand Down
5 changes: 4 additions & 1 deletion src/transformers/modeling_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -324,7 +324,7 @@ def dtype_byte_size(dtype):
"""
if dtype == torch.bool:
return 1 / 8
bit_search = re.search(r"[^\d](\d+)$", str(dtype))
bit_search = re.search(r"[^\d](\d+)_?", str(dtype))
if bit_search is None:
raise ValueError(f"`dtype` is not a valid dtype: {dtype}.")
bit_size = int(bit_search.groups()[0])
Expand Down Expand Up @@ -3170,6 +3170,9 @@ def from_pretrained(
torch_dtype = hf_quantizer.update_torch_dtype(torch_dtype)
device_map = hf_quantizer.update_device_map(device_map)

# In order to ensure popular quantization methods are supported. Can be disable with `disable_telemetry`
user_agent["quant"] = hf_quantizer.quantization_config.quant_method.value

# Force-set to `True` for more mem efficiency
if low_cpu_mem_usage is None:
low_cpu_mem_usage = True
Expand Down
7 changes: 5 additions & 2 deletions src/transformers/models/llava/modeling_llava.py
Original file line number Diff line number Diff line change
Expand Up @@ -327,8 +327,11 @@ def _merge_input_ids_with_image_features(self, image_features, inputs_embeds, in
if labels is not None:
final_labels[batch_indices, text_to_overwrite] = labels[batch_indices, non_image_indices]

# 5. Fill the embeddings corresponding to the images. Anything that is still zeros needs filling
image_to_overwrite = torch.all(final_embedding == 0, dim=-1)
# 5. Fill the embeddings corresponding to the images. Anything that is not `text_positions` needs filling (#29835)
image_to_overwrite = torch.full(
(batch_size, max_embed_dim), True, dtype=torch.bool, device=inputs_embeds.device
)
image_to_overwrite[batch_indices, text_to_overwrite] = False
image_to_overwrite &= image_to_overwrite.cumsum(-1) - 1 >= nb_image_pad[:, None].to(target_device)

if image_to_overwrite.sum() != image_features.shape[:-1].numel():
Expand Down
7 changes: 5 additions & 2 deletions src/transformers/models/llava_next/modeling_llava_next.py
Original file line number Diff line number Diff line change
Expand Up @@ -403,8 +403,11 @@ def _merge_input_ids_with_image_features(self, image_features, inputs_embeds, in
if labels is not None:
final_labels[batch_indices, text_to_overwrite] = labels[batch_indices, non_image_indices]

# 5. Fill the embeddings corresponding to the images. Anything that is still zeros needs filling
image_to_overwrite = torch.all(final_embedding == 0, dim=-1)
# 5. Fill the embeddings corresponding to the images. Anything that is not `text_positions` needs filling (#29835)
image_to_overwrite = torch.full(
(batch_size, max_embed_dim), True, dtype=torch.bool, device=inputs_embeds.device
)
image_to_overwrite[batch_indices, text_to_overwrite] = False
image_to_overwrite &= image_to_overwrite.cumsum(-1) - 1 >= nb_image_pad[:, None].to(target_device)

if image_to_overwrite.sum() != image_features.shape[:-1].numel():
Expand Down
7 changes: 5 additions & 2 deletions src/transformers/models/vipllava/modeling_vipllava.py
Original file line number Diff line number Diff line change
Expand Up @@ -331,8 +331,11 @@ def _merge_input_ids_with_image_features(self, image_features, inputs_embeds, in
if labels is not None:
final_labels[batch_indices, text_to_overwrite] = labels[batch_indices, non_image_indices]

# 5. Fill the embeddings corresponding to the images. Anything that is still zeros needs filling
image_to_overwrite = torch.all(final_embedding == 0, dim=-1)
# 5. Fill the embeddings corresponding to the images. Anything that is not `text_positions` needs filling (#29835)
image_to_overwrite = torch.full(
(batch_size, max_embed_dim), True, dtype=torch.bool, device=inputs_embeds.device
)
image_to_overwrite[batch_indices, text_to_overwrite] = False
image_to_overwrite &= image_to_overwrite.cumsum(-1) - 1 >= nb_image_pad[:, None].to(target_device)

if image_to_overwrite.sum() != image_features.shape[:-1].numel():
Expand Down
14 changes: 6 additions & 8 deletions src/transformers/quantizers/quantizer_bnb_4bit.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,14 +84,12 @@ def validate_environment(self, *args, **kwargs):
}
if "cpu" in device_map_without_lm_head.values() or "disk" in device_map_without_lm_head.values():
raise ValueError(
"""
Some modules are dispatched on the CPU or the disk. Make sure you have enough GPU RAM to fit the
quantized model. If you want to dispatch the model on the CPU or the disk while keeping these modules
in 32-bit, you need to set `llm_int8_enable_fp32_cpu_offload=True` and pass a custom `device_map` to
`from_pretrained`. Check
https://huggingface.co/docs/transformers/main/en/main_classes/quantization#offload-between-cpu-and-gpu
for more details.
"""
"Some modules are dispatched on the CPU or the disk. Make sure you have enough GPU RAM to fit the "
"quantized model. If you want to dispatch the model on the CPU or the disk while keeping these modules "
"in 32-bit, you need to set `load_in_8bit_fp32_cpu_offload=True` and pass a custom `device_map` to "
"`from_pretrained`. Check "
"https://huggingface.co/docs/transformers/main/en/main_classes/quantization#offload-between-cpu-and-gpu "
"for more details. "
)

if version.parse(importlib.metadata.version("bitsandbytes")) < version.parse("0.39.0"):
Expand Down
14 changes: 6 additions & 8 deletions src/transformers/quantizers/quantizer_bnb_8bit.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,14 +84,12 @@ def validate_environment(self, *args, **kwargs):
}
if "cpu" in device_map_without_lm_head.values() or "disk" in device_map_without_lm_head.values():
raise ValueError(
"""
Some modules are dispatched on the CPU or the disk. Make sure you have enough GPU RAM to fit the
quantized model. If you want to dispatch the model on the CPU or the disk while keeping these modules
in 32-bit, you need to set `llm_int8_enable_fp32_cpu_offload=True` and pass a custom `device_map` to
`from_pretrained`. Check
https://huggingface.co/docs/transformers/main/en/main_classes/quantization#offload-between-cpu-and-gpu
for more details.
"""
"Some modules are dispatched on the CPU or the disk. Make sure you have enough GPU RAM to fit the "
"quantized model. If you want to dispatch the model on the CPU or the disk while keeping these modules "
"in 32-bit, you need to set `load_in_8bit_fp32_cpu_offload=True` and pass a custom `device_map` to "
"`from_pretrained`. Check "
"https://huggingface.co/docs/transformers/main/en/main_classes/quantization#offload-between-cpu-and-gpu "
"for more details. "
)

if version.parse(importlib.metadata.version("bitsandbytes")) < version.parse("0.37.2"):
Expand Down
2 changes: 1 addition & 1 deletion src/transformers/quantizers/quantizer_eetq.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,4 +167,4 @@ def is_serializable(self):

@property
def is_trainable(self) -> bool:
return False
return True
26 changes: 26 additions & 0 deletions tests/models/llava_next/test_modeling_llava_next.py
Original file line number Diff line number Diff line change
Expand Up @@ -459,3 +459,29 @@ def test_small_model_integration_test_batch(self):

EXPECTED_DECODED_TEXT = ['[INST] \nWhat is shown in this image? [/INST] The image appears to be a radar chart, which is a type of multi-dimensional plot that displays', '[INST] \nWhat is shown in this image? [/INST] The image shows two cats lying on a pink surface, which appears to be a couch or a cush'] # fmt: skip
self.assertEqual(self.processor.batch_decode(output, skip_special_tokens=True), EXPECTED_DECODED_TEXT)

@slow
@require_bitsandbytes
def test_small_model_integration_test_unk_token(self):
# related to (#29835)
model = LlavaNextForConditionalGeneration.from_pretrained(
"llava-hf/llava-v1.6-mistral-7b-hf",
load_in_4bit=True,
)

prompt_with_unk = "[INST] <image>\nWhat is shown in this <unk> image? [/INST]"
inputs = self.processor(prompt_with_unk, self.image, return_tensors="pt")

# verify single forward pass
inputs = inputs.to(torch_device)
with torch.no_grad():
output = model(**inputs)

# verify generation
output = model.generate(**inputs, max_new_tokens=40)
EXPECTED_DECODED_TEXT = '[INST] \nWhat is shown in this image? [/INST] The image appears to be a radar chart, which is a type of multi-dimensional plot that displays values for multiple quantitative variables represented on axes starting from the same point. This particular radar chart' # fmt: skip

self.assertEqual(
self.processor.decode(output[0], skip_special_tokens=True),
EXPECTED_DECODED_TEXT,
)
32 changes: 31 additions & 1 deletion tests/test_modeling_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,12 @@
_prepare_4d_attention_mask,
_prepare_4d_causal_attention_mask,
)
from transformers.modeling_utils import _find_disjoint, _find_identical, shard_checkpoint
from transformers.modeling_utils import (
_find_disjoint,
_find_identical,
dtype_byte_size,
shard_checkpoint,
)

# Fake pretrained models for tests
class BaseModel(PreTrainedModel):
Expand Down Expand Up @@ -465,6 +470,31 @@ def test_model_from_pretrained_attn_implementation(self):
module.__class__.__name__, mistral_attention_classes[requested_attn_implementation]
)

def test_torch_dtype_byte_sizes(self):
torch_dtypes_and_bytes = [
(torch.double, 8),
(torch.float64, 8),
(torch.float, 4),
(torch.float32, 4),
(torch.half, 2),
(torch.float16, 2),
(torch.bfloat16, 2),
(torch.long, 8),
(torch.int64, 8),
(torch.int, 4),
(torch.int32, 4),
(torch.short, 2),
(torch.int16, 2),
(torch.uint8, 1),
(torch.int8, 1),
(torch.float8_e4m3fn, 1),
(torch.float8_e5m2, 1),
(torch.bool, 0.125),
]

for torch_dtype, bytes_per_element in torch_dtypes_and_bytes:
self.assertEqual(dtype_byte_size(torch_dtype), bytes_per_element)

def test_no_super_init_config_and_model(self):
config = NoSuperInitConfig(attribute=32)
model = NoSuperInitModel(config)
Expand Down

0 comments on commit 7638069

Please sign in to comment.