From 8c0ffd3f7ed6ec956044456406156a294059a7f3 Mon Sep 17 00:00:00 2001 From: Lucain Date: Thu, 9 Nov 2023 18:53:26 +0100 Subject: [PATCH] Fix RequestCounter to make it more future-proof (#27406) * Fix RequestCounter to make it more future-proof * code quality --- src/transformers/testing_utils.py | 51 ++++++++++++--------- tests/models/auto/test_modeling_auto.py | 15 +++--- tests/models/auto/test_modeling_tf_auto.py | 12 ++--- tests/models/auto/test_tokenization_auto.py | 9 ++-- tests/pipelines/test_pipelines_common.py | 6 +-- 5 files changed, 48 insertions(+), 45 deletions(-) diff --git a/src/transformers/testing_utils.py b/src/transformers/testing_utils.py index 815a13c9e96daa..eb21cbac2303e6 100644 --- a/src/transformers/testing_utils.py +++ b/src/transformers/testing_utils.py @@ -29,14 +29,15 @@ import tempfile import time import unittest +from collections import defaultdict from collections.abc import Mapping from io import StringIO from pathlib import Path from typing import Callable, Dict, Iterable, Iterator, List, Optional, Union from unittest import mock +from unittest.mock import patch -import huggingface_hub -import requests +import urllib3 from transformers import logging as transformers_logging @@ -1983,32 +1984,40 @@ def run_command(command: List[str], return_stdout=False): class RequestCounter: """ Helper class that will count all requests made online. + + Might not be robust if urllib3 changes its logging format but should be good enough for us. + + Usage: + ```py + with RequestCounter() as counter: + _ = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-bert") + assert counter["GET"] == 0 + assert counter["HEAD"] == 1 + assert counter.total_calls == 1 + ``` """ def __enter__(self): - self.head_request_count = 0 - self.get_request_count = 0 - self.other_request_count = 0 - - # Mock `get_session` to count HTTP calls. - self.old_get_session = huggingface_hub.utils._http.get_session - self.session = requests.Session() - self.session.request = self.new_request - huggingface_hub.utils._http.get_session = lambda: self.session + self._counter = defaultdict(int) + self.patcher = patch.object(urllib3.connectionpool.log, "debug", wraps=urllib3.connectionpool.log.debug) + self.mock = self.patcher.start() return self - def __exit__(self, *args, **kwargs): - huggingface_hub.utils._http.get_session = self.old_get_session + def __exit__(self, *args, **kwargs) -> None: + for call in self.mock.call_args_list: + log = call.args[0] % call.args[1:] + for method in ("HEAD", "GET", "POST", "PUT", "DELETE", "CONNECT", "OPTIONS", "TRACE", "PATCH"): + if method in log: + self._counter[method] += 1 + break + self.patcher.stop() - def new_request(self, method, **kwargs): - if method == "GET": - self.get_request_count += 1 - elif method == "HEAD": - self.head_request_count += 1 - else: - self.other_request_count += 1 + def __getitem__(self, key: str) -> int: + return self._counter[key] - return requests.request(method=method, **kwargs) + @property + def total_calls(self) -> int: + return sum(self._counter.values()) def is_flaky(max_attempts: int = 5, wait_before_retry: Optional[float] = None, description: Optional[str] = None): diff --git a/tests/models/auto/test_modeling_auto.py b/tests/models/auto/test_modeling_auto.py index 55bc3f3999ff85..41f52517483cf1 100644 --- a/tests/models/auto/test_modeling_auto.py +++ b/tests/models/auto/test_modeling_auto.py @@ -482,25 +482,22 @@ def test_model_from_flax_suggestion(self): with self.assertRaisesRegex(EnvironmentError, "Use `from_flax=True` to load this model"): _ = AutoModel.from_pretrained("hf-internal-testing/tiny-bert-flax-only") - @unittest.skip( - "Currently failing with new huggingface_hub release. See: https://github.com/huggingface/transformers/pull/27389" - ) def test_cached_model_has_minimum_calls_to_head(self): # Make sure we have cached the model. _ = AutoModel.from_pretrained("hf-internal-testing/tiny-random-bert") with RequestCounter() as counter: _ = AutoModel.from_pretrained("hf-internal-testing/tiny-random-bert") - self.assertEqual(counter.get_request_count, 0) - self.assertEqual(counter.head_request_count, 1) - self.assertEqual(counter.other_request_count, 0) + self.assertEqual(counter["GET"], 0) + self.assertEqual(counter["HEAD"], 1) + self.assertEqual(counter.total_calls, 1) # With a sharded checkpoint _ = AutoModel.from_pretrained("hf-internal-testing/tiny-random-bert-sharded") with RequestCounter() as counter: _ = AutoModel.from_pretrained("hf-internal-testing/tiny-random-bert-sharded") - self.assertEqual(counter.get_request_count, 0) - self.assertEqual(counter.head_request_count, 1) - self.assertEqual(counter.other_request_count, 0) + self.assertEqual(counter["GET"], 0) + self.assertEqual(counter["HEAD"], 1) + self.assertEqual(counter.total_calls, 1) def test_attr_not_existing(self): from transformers.models.auto.auto_factory import _LazyAutoMapping diff --git a/tests/models/auto/test_modeling_tf_auto.py b/tests/models/auto/test_modeling_tf_auto.py index 2f6fe476158f65..537d48a57e48e5 100644 --- a/tests/models/auto/test_modeling_tf_auto.py +++ b/tests/models/auto/test_modeling_tf_auto.py @@ -301,14 +301,14 @@ def test_cached_model_has_minimum_calls_to_head(self): _ = TFAutoModel.from_pretrained("hf-internal-testing/tiny-random-bert") with RequestCounter() as counter: _ = TFAutoModel.from_pretrained("hf-internal-testing/tiny-random-bert") - self.assertEqual(counter.get_request_count, 0) - self.assertEqual(counter.head_request_count, 1) - self.assertEqual(counter.other_request_count, 0) + self.assertEqual(counter["GET"], 0) + self.assertEqual(counter["HEAD"], 1) + self.assertEqual(counter.total_calls, 1) # With a sharded checkpoint _ = TFAutoModel.from_pretrained("ArthurZ/tiny-random-bert-sharded") with RequestCounter() as counter: _ = TFAutoModel.from_pretrained("ArthurZ/tiny-random-bert-sharded") - self.assertEqual(counter.get_request_count, 0) - self.assertEqual(counter.head_request_count, 1) - self.assertEqual(counter.other_request_count, 0) + self.assertEqual(counter["GET"], 0) + self.assertEqual(counter["HEAD"], 1) + self.assertEqual(counter.total_calls, 1) diff --git a/tests/models/auto/test_tokenization_auto.py b/tests/models/auto/test_tokenization_auto.py index 40dc99cd136887..597c995b6e3227 100644 --- a/tests/models/auto/test_tokenization_auto.py +++ b/tests/models/auto/test_tokenization_auto.py @@ -419,14 +419,11 @@ def test_revision_not_found(self): ): _ = AutoTokenizer.from_pretrained(DUMMY_UNKNOWN_IDENTIFIER, revision="aaaaaa") - @unittest.skip( - "Currently failing with new huggingface_hub release. See: https://github.com/huggingface/transformers/pull/27389" - ) def test_cached_tokenizer_has_minimum_calls_to_head(self): # Make sure we have cached the tokenizer. _ = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-bert") with RequestCounter() as counter: _ = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-bert") - self.assertEqual(counter.get_request_count, 0) - self.assertEqual(counter.head_request_count, 1) - self.assertEqual(counter.other_request_count, 0) + self.assertEqual(counter["GET"], 0) + self.assertEqual(counter["HEAD"], 1) + self.assertEqual(counter.total_calls, 1) diff --git a/tests/pipelines/test_pipelines_common.py b/tests/pipelines/test_pipelines_common.py index 56467bdc4b8b8d..e760d279014640 100644 --- a/tests/pipelines/test_pipelines_common.py +++ b/tests/pipelines/test_pipelines_common.py @@ -763,9 +763,9 @@ def test_cached_pipeline_has_minimum_calls_to_head(self): _ = pipeline("text-classification", model="hf-internal-testing/tiny-random-bert") with RequestCounter() as counter: _ = pipeline("text-classification", model="hf-internal-testing/tiny-random-bert") - self.assertEqual(counter.get_request_count, 0) - self.assertEqual(counter.head_request_count, 1) - self.assertEqual(counter.other_request_count, 0) + self.assertEqual(counter["GET"], 0) + self.assertEqual(counter["HEAD"], 1) + self.assertEqual(counter.total_calls, 1) @require_torch def test_chunk_pipeline_batching_single_file(self):