Skip to content

Commit

Permalink
Fix RequestCounter to make it more future-proof (huggingface#27406)
Browse files Browse the repository at this point in the history
* Fix RequestCounter to make it more future-proof

* code quality
  • Loading branch information
Wauplin authored and EduardoPach committed Nov 19, 2023
1 parent da1ab02 commit 8c0ffd3
Show file tree
Hide file tree
Showing 5 changed files with 48 additions and 45 deletions.
51 changes: 30 additions & 21 deletions src/transformers/testing_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,14 +29,15 @@
import tempfile
import time
import unittest
from collections import defaultdict
from collections.abc import Mapping
from io import StringIO
from pathlib import Path
from typing import Callable, Dict, Iterable, Iterator, List, Optional, Union
from unittest import mock
from unittest.mock import patch

import huggingface_hub
import requests
import urllib3

from transformers import logging as transformers_logging

Expand Down Expand Up @@ -1983,32 +1984,40 @@ def run_command(command: List[str], return_stdout=False):
class RequestCounter:
"""
Helper class that will count all requests made online.
Might not be robust if urllib3 changes its logging format but should be good enough for us.
Usage:
```py
with RequestCounter() as counter:
_ = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-bert")
assert counter["GET"] == 0
assert counter["HEAD"] == 1
assert counter.total_calls == 1
```
"""

def __enter__(self):
self.head_request_count = 0
self.get_request_count = 0
self.other_request_count = 0

# Mock `get_session` to count HTTP calls.
self.old_get_session = huggingface_hub.utils._http.get_session
self.session = requests.Session()
self.session.request = self.new_request
huggingface_hub.utils._http.get_session = lambda: self.session
self._counter = defaultdict(int)
self.patcher = patch.object(urllib3.connectionpool.log, "debug", wraps=urllib3.connectionpool.log.debug)
self.mock = self.patcher.start()
return self

def __exit__(self, *args, **kwargs):
huggingface_hub.utils._http.get_session = self.old_get_session
def __exit__(self, *args, **kwargs) -> None:
for call in self.mock.call_args_list:
log = call.args[0] % call.args[1:]
for method in ("HEAD", "GET", "POST", "PUT", "DELETE", "CONNECT", "OPTIONS", "TRACE", "PATCH"):
if method in log:
self._counter[method] += 1
break
self.patcher.stop()

def new_request(self, method, **kwargs):
if method == "GET":
self.get_request_count += 1
elif method == "HEAD":
self.head_request_count += 1
else:
self.other_request_count += 1
def __getitem__(self, key: str) -> int:
return self._counter[key]

return requests.request(method=method, **kwargs)
@property
def total_calls(self) -> int:
return sum(self._counter.values())


def is_flaky(max_attempts: int = 5, wait_before_retry: Optional[float] = None, description: Optional[str] = None):
Expand Down
15 changes: 6 additions & 9 deletions tests/models/auto/test_modeling_auto.py
Original file line number Diff line number Diff line change
Expand Up @@ -482,25 +482,22 @@ def test_model_from_flax_suggestion(self):
with self.assertRaisesRegex(EnvironmentError, "Use `from_flax=True` to load this model"):
_ = AutoModel.from_pretrained("hf-internal-testing/tiny-bert-flax-only")

@unittest.skip(
"Currently failing with new huggingface_hub release. See: https://github.com/huggingface/transformers/pull/27389"
)
def test_cached_model_has_minimum_calls_to_head(self):
# Make sure we have cached the model.
_ = AutoModel.from_pretrained("hf-internal-testing/tiny-random-bert")
with RequestCounter() as counter:
_ = AutoModel.from_pretrained("hf-internal-testing/tiny-random-bert")
self.assertEqual(counter.get_request_count, 0)
self.assertEqual(counter.head_request_count, 1)
self.assertEqual(counter.other_request_count, 0)
self.assertEqual(counter["GET"], 0)
self.assertEqual(counter["HEAD"], 1)
self.assertEqual(counter.total_calls, 1)

# With a sharded checkpoint
_ = AutoModel.from_pretrained("hf-internal-testing/tiny-random-bert-sharded")
with RequestCounter() as counter:
_ = AutoModel.from_pretrained("hf-internal-testing/tiny-random-bert-sharded")
self.assertEqual(counter.get_request_count, 0)
self.assertEqual(counter.head_request_count, 1)
self.assertEqual(counter.other_request_count, 0)
self.assertEqual(counter["GET"], 0)
self.assertEqual(counter["HEAD"], 1)
self.assertEqual(counter.total_calls, 1)

def test_attr_not_existing(self):
from transformers.models.auto.auto_factory import _LazyAutoMapping
Expand Down
12 changes: 6 additions & 6 deletions tests/models/auto/test_modeling_tf_auto.py
Original file line number Diff line number Diff line change
Expand Up @@ -301,14 +301,14 @@ def test_cached_model_has_minimum_calls_to_head(self):
_ = TFAutoModel.from_pretrained("hf-internal-testing/tiny-random-bert")
with RequestCounter() as counter:
_ = TFAutoModel.from_pretrained("hf-internal-testing/tiny-random-bert")
self.assertEqual(counter.get_request_count, 0)
self.assertEqual(counter.head_request_count, 1)
self.assertEqual(counter.other_request_count, 0)
self.assertEqual(counter["GET"], 0)
self.assertEqual(counter["HEAD"], 1)
self.assertEqual(counter.total_calls, 1)

# With a sharded checkpoint
_ = TFAutoModel.from_pretrained("ArthurZ/tiny-random-bert-sharded")
with RequestCounter() as counter:
_ = TFAutoModel.from_pretrained("ArthurZ/tiny-random-bert-sharded")
self.assertEqual(counter.get_request_count, 0)
self.assertEqual(counter.head_request_count, 1)
self.assertEqual(counter.other_request_count, 0)
self.assertEqual(counter["GET"], 0)
self.assertEqual(counter["HEAD"], 1)
self.assertEqual(counter.total_calls, 1)
9 changes: 3 additions & 6 deletions tests/models/auto/test_tokenization_auto.py
Original file line number Diff line number Diff line change
Expand Up @@ -419,14 +419,11 @@ def test_revision_not_found(self):
):
_ = AutoTokenizer.from_pretrained(DUMMY_UNKNOWN_IDENTIFIER, revision="aaaaaa")

@unittest.skip(
"Currently failing with new huggingface_hub release. See: https://github.com/huggingface/transformers/pull/27389"
)
def test_cached_tokenizer_has_minimum_calls_to_head(self):
# Make sure we have cached the tokenizer.
_ = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-bert")
with RequestCounter() as counter:
_ = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-bert")
self.assertEqual(counter.get_request_count, 0)
self.assertEqual(counter.head_request_count, 1)
self.assertEqual(counter.other_request_count, 0)
self.assertEqual(counter["GET"], 0)
self.assertEqual(counter["HEAD"], 1)
self.assertEqual(counter.total_calls, 1)
6 changes: 3 additions & 3 deletions tests/pipelines/test_pipelines_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -763,9 +763,9 @@ def test_cached_pipeline_has_minimum_calls_to_head(self):
_ = pipeline("text-classification", model="hf-internal-testing/tiny-random-bert")
with RequestCounter() as counter:
_ = pipeline("text-classification", model="hf-internal-testing/tiny-random-bert")
self.assertEqual(counter.get_request_count, 0)
self.assertEqual(counter.head_request_count, 1)
self.assertEqual(counter.other_request_count, 0)
self.assertEqual(counter["GET"], 0)
self.assertEqual(counter["HEAD"], 1)
self.assertEqual(counter.total_calls, 1)

@require_torch
def test_chunk_pipeline_batching_single_file(self):
Expand Down

0 comments on commit 8c0ffd3

Please sign in to comment.