diff --git a/docs/configuration.rst b/docs/configuration.rst index 9783c9f25..5f093b828 100644 --- a/docs/configuration.rst +++ b/docs/configuration.rst @@ -1645,7 +1645,18 @@ keys: The related CSS class definition must be done by the user, e.g. by :ref:`own_css`. (*optional*) (*default*: ``external_link``) +.. _needs_import_cache_size: +needs_import_cache_size +~~~~~~~~~~~~~~~~~~~~~~~ + +.. versionadded:: 3.1.0 + +Sets the maximum number of needs cached by the :ref:`needimport` directive, +which is used to avoid multiple reads of the same file. +Note, setting this value too high may lead to high memory usage during the sphinx build. + +Default: :need_config_default:`import_cache_size` .. _needs_needextend_strict: diff --git a/docs/directives/needimport.rst b/docs/directives/needimport.rst index e1372a03a..969c65f48 100644 --- a/docs/directives/needimport.rst +++ b/docs/directives/needimport.rst @@ -30,6 +30,11 @@ The directive also supports URL as argument to download ``needs.json`` from remo .. needimport:: https://my_company.com/docs/remote-needs.json +.. seealso:: + + :ref:`needs_import_cache_size`, + to control the cache size for imported needs. + Options ------- diff --git a/sphinx_needs/config.py b/sphinx_needs/config.py index 53ee2853c..957164662 100644 --- a/sphinx_needs/config.py +++ b/sphinx_needs/config.py @@ -440,6 +440,10 @@ def __setattr__(self, name: str, value: Any) -> None: default_factory=list, metadata={"rebuild": "html", "types": (list,)} ) """List of external sources to load needs from.""" + import_cache_size: int = field( + default=100, metadata={"rebuild": "", "types": (int,)} + ) + """Maximum number of imported needs to cache.""" builder_filter: str = field( default="is_external==False", metadata={"rebuild": "html", "types": (str,)} ) diff --git a/sphinx_needs/directives/needimport.py b/sphinx_needs/directives/needimport.py index 435239ff7..8e85d258e 100644 --- a/sphinx_needs/directives/needimport.py +++ b/sphinx_needs/directives/needimport.py @@ -3,7 +3,9 @@ import json import os import re -from typing import Sequence +import threading +from copy import deepcopy +from typing import Any, OrderedDict, Sequence from urllib.parse import urlparse import requests @@ -52,7 +54,8 @@ class NeedimportDirective(SphinxDirective): @measure_time("needimport") def run(self) -> Sequence[nodes.Node]: - # needs_list = {} + needs_config = NeedsSphinxConfig(self.config) + version = self.options.get("version") filter_string = self.options.get("filter") id_prefix = self.options.get("id_prefix", "") @@ -111,21 +114,34 @@ def run(self) -> Sequence[nodes.Node]: raise ReferenceError( f"Could not load needs import file {correct_need_import_path}" ) + mtime = os.path.getmtime(correct_need_import_path) - try: - with open(correct_need_import_path) as needs_file: - needs_import_list = json.load(needs_file) - except (OSError, json.JSONDecodeError) as e: - # TODO: Add exception handling - raise SphinxNeedsFileException(correct_need_import_path) from e - - errors = check_needs_data(needs_import_list) - if errors.schema: - logger.info( - f"Schema validation errors detected in file {correct_need_import_path}:" - ) - for error in errors.schema: - logger.info(f' {error.message} -> {".".join(error.path)}') + if ( + needs_import_list := _FileCache.get(correct_need_import_path, mtime) + ) is None: + try: + with open(correct_need_import_path) as needs_file: + needs_import_list = json.load(needs_file) + except (OSError, json.JSONDecodeError) as e: + # TODO: Add exception handling + raise SphinxNeedsFileException(correct_need_import_path) from e + + errors = check_needs_data(needs_import_list) + if errors.schema: + logger.info( + f"Schema validation errors detected in file {correct_need_import_path}:" + ) + for error in errors.schema: + logger.info(f' {error.message} -> {".".join(error.path)}') + else: + _FileCache.set( + correct_need_import_path, + mtime, + needs_import_list, + needs_config.import_cache_size, + ) + + self.env.note_dependency(correct_need_import_path) if version is None: try: @@ -141,17 +157,17 @@ def run(self) -> Sequence[nodes.Node]: f"Version {version} not found in needs import file {correct_need_import_path}" ) - needs_config = NeedsSphinxConfig(self.config) data = needs_import_list["versions"][version] + # TODO this is not exactly NeedsInfoType, because the export removes/adds some keys + needs_list: dict[str, NeedsInfoType] = data["needs"] + if ids := self.options.get("ids"): id_list = [i.strip() for i in ids.split(",") if i.strip()] - data["needs"] = { + needs_list = { key: data["needs"][key] for key in id_list if key in data["needs"] } - # TODO this is not exactly NeedsInfoType, because the export removes/adds some keys - needs_list: dict[str, NeedsInfoType] = data["needs"] if schema := data.get("needs_schema"): # Set defaults from schema defaults = { @@ -160,7 +176,8 @@ def run(self) -> Sequence[nodes.Node]: if "default" in value } needs_list = { - key: {**defaults, **value} for key, value in needs_list.items() + key: {**defaults, **value} # type: ignore[typeddict-item] + for key, value in needs_list.items() } # Filter imported needs @@ -169,7 +186,8 @@ def run(self) -> Sequence[nodes.Node]: if filter_string is None: needs_list_filtered[key] = need else: - filter_context = need.copy() + # we deepcopy here, to ensure that the original data is not modified + filter_context = deepcopy(need) # Support both ways of addressing the description, as "description" is used in json file, but # "content" is the sphinx internal name for this kind of information @@ -185,7 +203,9 @@ def run(self) -> Sequence[nodes.Node]: location=(self.env.docname, self.lineno), ) - needs_list = needs_list_filtered + # note we need to deepcopy here, as we are going to modify the data, + # but we want to ensure data referenced from the cache is not modified + needs_list = deepcopy(needs_list_filtered) # tags update if tags := [ @@ -265,6 +285,41 @@ def docname(self) -> str: return self.env.docname +class _ImportCache: + """A simple cache for imported needs, + mapping a (path, mtime) to a dictionary of needs. + It's thread safe, + and has a maximum size when adding new items. + """ + + def __init__(self) -> None: + self._cache: OrderedDict[tuple[str, float], dict[str, Any]] = OrderedDict() + self._need_count = 0 + self._lock = threading.Lock() + + def set( + self, path: str, mtime: float, value: dict[str, Any], max_size: int + ) -> None: + with self._lock: + self._cache[(path, mtime)] = value + self._need_count += len(value) + max_size = max(max_size, 0) + while self._need_count > max_size: + _, value = self._cache.popitem(last=False) + self._need_count -= len(value) + + def get(self, path: str, mtime: float) -> dict[str, Any] | None: + with self._lock: + return self._cache.get((path, mtime), None) + + def __repr__(self) -> str: + with self._lock: + return f"{self.__class__.__name__}({list(self._cache)})" + + +_FileCache = _ImportCache() + + class VersionNotFound(BaseException): pass