Skip to content

Commit a887c9d

Browse files
committed
✨ add needs_import_cache_size configuration
1 parent dc3242a commit a887c9d

File tree

4 files changed

+98
-23
lines changed

4 files changed

+98
-23
lines changed

docs/configuration.rst

+11
Original file line numberDiff line numberDiff line change
@@ -1645,7 +1645,18 @@ keys:
16451645
The related CSS class definition must be done by the user, e.g. by :ref:`own_css`.
16461646
(*optional*) (*default*: ``external_link``)
16471647
1648+
.. _needs_import_cache_size:
16481649
1650+
needs_import_cache_size
1651+
~~~~~~~~~~~~~~~~~~~~~~~
1652+
1653+
.. versionadded:: 3.1.0
1654+
1655+
Sets the maximum number of needs cached by the :ref:`needimport` directive,
1656+
which is used to avoid multiple reads of the same file.
1657+
Note that setting this value too high may lead to high memory usage during the Sphinx build.
1658+
1659+
Default: :need_config_default:`import_cache_size`
16491660
16501661
.. _needs_needextend_strict:
16511662

docs/directives/needimport.rst

+5
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,11 @@ The directive also supports URL as argument to download ``needs.json`` from remo
3030
3131
.. needimport:: https://my_company.com/docs/remote-needs.json
3232
33+
.. seealso::
34+
35+
:ref:`needs_import_cache_size`,
36+
to control the cache size for imported needs.
37+
3338
Options
3439
-------
3540

sphinx_needs/config.py

+4
Original file line numberDiff line numberDiff line change
@@ -424,6 +424,10 @@ def __setattr__(self, name: str, value: Any) -> None:
424424
default_factory=list, metadata={"rebuild": "html", "types": (list,)}
425425
)
426426
"""List of external sources to load needs from."""
427+
import_cache_size: int = field(
428+
default=100, metadata={"rebuild": "", "types": (int,)}
429+
)
430+
"""Maximum number of imported needs to cache."""
427431
builder_filter: str = field(
428432
default="is_external==False", metadata={"rebuild": "html", "types": (str,)}
429433
)

sphinx_needs/directives/needimport.py

+78-23
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,9 @@
33
import json
44
import os
55
import re
6-
from typing import Sequence
6+
import threading
7+
from copy import deepcopy
8+
from typing import Any, OrderedDict, Sequence
79
from urllib.parse import urlparse
810

911
import requests
@@ -52,7 +54,8 @@ class NeedimportDirective(SphinxDirective):
5254

5355
@measure_time("needimport")
5456
def run(self) -> Sequence[nodes.Node]:
55-
# needs_list = {}
57+
needs_config = NeedsSphinxConfig(self.config)
58+
5659
version = self.options.get("version")
5760
filter_string = self.options.get("filter")
5861
id_prefix = self.options.get("id_prefix", "")
@@ -111,21 +114,34 @@ def run(self) -> Sequence[nodes.Node]:
111114
raise ReferenceError(
112115
f"Could not load needs import file {correct_need_import_path}"
113116
)
117+
mtime = os.path.getmtime(correct_need_import_path)
114118

115-
try:
116-
with open(correct_need_import_path) as needs_file:
117-
needs_import_list = json.load(needs_file)
118-
except (OSError, json.JSONDecodeError) as e:
119-
# TODO: Add exception handling
120-
raise SphinxNeedsFileException(correct_need_import_path) from e
121-
122-
errors = check_needs_data(needs_import_list)
123-
if errors.schema:
124-
logger.info(
125-
f"Schema validation errors detected in file {correct_need_import_path}:"
126-
)
127-
for error in errors.schema:
128-
logger.info(f' {error.message} -> {".".join(error.path)}')
119+
if (
120+
needs_import_list := _FileCache.get(correct_need_import_path, mtime)
121+
) is None:
122+
try:
123+
with open(correct_need_import_path) as needs_file:
124+
needs_import_list = json.load(needs_file)
125+
except (OSError, json.JSONDecodeError) as e:
126+
# TODO: Add exception handling
127+
raise SphinxNeedsFileException(correct_need_import_path) from e
128+
129+
errors = check_needs_data(needs_import_list)
130+
if errors.schema:
131+
logger.info(
132+
f"Schema validation errors detected in file {correct_need_import_path}:"
133+
)
134+
for error in errors.schema:
135+
logger.info(f' {error.message} -> {".".join(error.path)}')
136+
else:
137+
_FileCache.set(
138+
correct_need_import_path,
139+
mtime,
140+
needs_import_list,
141+
needs_config.import_cache_size,
142+
)
143+
144+
self.env.note_dependency(correct_need_import_path)
129145

130146
if version is None:
131147
try:
@@ -141,17 +157,17 @@ def run(self) -> Sequence[nodes.Node]:
141157
f"Version {version} not found in needs import file {correct_need_import_path}"
142158
)
143159

144-
needs_config = NeedsSphinxConfig(self.config)
145160
data = needs_import_list["versions"][version]
146161

162+
# TODO this is not exactly NeedsInfoType, because the export removes/adds some keys
163+
needs_list: dict[str, NeedsInfoType] = data["needs"]
164+
147165
if ids := self.options.get("ids"):
148166
id_list = [i.strip() for i in ids.split(",") if i.strip()]
149-
data["needs"] = {
167+
needs_list = {
150168
key: data["needs"][key] for key in id_list if key in data["needs"]
151169
}
152170

153-
# TODO this is not exactly NeedsInfoType, because the export removes/adds some keys
154-
needs_list: dict[str, NeedsInfoType] = data["needs"]
155171
if schema := data.get("needs_schema"):
156172
# Set defaults from schema
157173
defaults = {
@@ -160,7 +176,8 @@ def run(self) -> Sequence[nodes.Node]:
160176
if "default" in value
161177
}
162178
needs_list = {
163-
key: {**defaults, **value} for key, value in needs_list.items()
179+
key: {**defaults, **value} # type: ignore[typeddict-item]
180+
for key, value in needs_list.items()
164181
}
165182

166183
# Filter imported needs
@@ -169,7 +186,8 @@ def run(self) -> Sequence[nodes.Node]:
169186
if filter_string is None:
170187
needs_list_filtered[key] = need
171188
else:
172-
filter_context = need.copy()
189+
# we deepcopy here, to ensure that the original data is not modified
190+
filter_context = deepcopy(need)
173191

174192
# Support both ways of addressing the description, as "description" is used in json file, but
175193
# "content" is the sphinx internal name for this kind of information
@@ -185,7 +203,9 @@ def run(self) -> Sequence[nodes.Node]:
185203
location=(self.env.docname, self.lineno),
186204
)
187205

188-
needs_list = needs_list_filtered
206+
# note we need to deepcopy here, as we are going to modify the data,
207+
# but we want to ensure data referenced from the cache is not modified
208+
needs_list = deepcopy(needs_list_filtered)
189209

190210
# tags update
191211
if tags := [
@@ -265,6 +285,41 @@ def docname(self) -> str:
265285
return self.env.docname
266286

267287

288+
class _ImportCache:
    """A thread-safe, bounded cache for imported needs.

    Maps a ``(path, mtime)`` key to the dictionary of needs loaded from that
    file, so repeated ``needimport`` directives avoid re-reading and
    re-validating the same file.  The bound is expressed as the *total number
    of cached needs* across all entries (not the number of files); when it is
    exceeded, whole entries are evicted oldest-first until the count fits.
    """

    def __init__(self) -> None:
        # insertion-ordered, so eviction can pop the oldest entry first
        self._cache: OrderedDict[tuple[str, float], dict[str, Any]] = OrderedDict()
        # running total of needs across all cached entries
        self._need_count = 0
        self._lock = threading.Lock()

    def set(
        self, path: str, mtime: float, value: dict[str, Any], max_size: int
    ) -> None:
        """Store ``value`` under ``(path, mtime)``, then evict oldest entries
        until at most ``max_size`` (clamped to >= 0) needs remain cached."""
        key = (path, mtime)
        with self._lock:
            # If the key is already present (e.g. two threads both missed the
            # cache for the same file), discount the entry being replaced first,
            # otherwise the running count drifts upwards and triggers spurious
            # evictions of live entries.
            if key in self._cache:
                self._need_count -= len(self._cache[key])
            self._cache[key] = value
            self._need_count += len(value)
            max_size = max(max_size, 0)
            while self._need_count > max_size:
                _, evicted = self._cache.popitem(last=False)
                self._need_count -= len(evicted)

    def get(self, path: str, mtime: float) -> dict[str, Any] | None:
        """Return the needs cached for ``(path, mtime)``, or ``None`` on a miss."""
        with self._lock:
            return self._cache.get((path, mtime), None)

    def __repr__(self) -> str:
        with self._lock:
            return f"{self.__class__.__name__}({list(self._cache)})"


#: Module-level singleton shared by all ``needimport`` directive runs.
_FileCache = _ImportCache()
268323
class VersionNotFound(BaseException):
    """Raised when the requested ``version`` is not present in a needs import file."""

    pass
270325

0 commit comments

Comments
 (0)