Skip to content

Commit 1f10d62

Browse files
committed
✨ add needs_import_cache_size configuration
1 parent 6776e98 commit 1f10d62

File tree

4 files changed

+96
-23
lines changed

4 files changed

+96
-23
lines changed

docs/configuration.rst

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1645,7 +1645,18 @@ keys:
16451645
The related CSS class definition must be done by the user, e.g. by :ref:`own_css`.
16461646
(*optional*) (*default*: ``external_link``)
16471647
1648+
.. _needs_import_cache_size:
16481649
1650+
needs_import_cache_size
1651+
~~~~~~~~~~~~~~~~~~~~~~~
1652+
1653+
.. versionadded:: 3.1.0
1654+
1655+
Sets the maximum number of needs cached by the :ref:`needimport` directive,
1656+
which is used to avoid multiple reads of the same file.
1657+
Note that setting this value too high may lead to high memory usage during the Sphinx build.
1658+
1659+
Default: :need_config_default:`import_cache_size`
16491660
16501661
.. _needs_needextend_strict:
16511662

docs/directives/needimport.rst

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,11 @@ The directive also supports URL as argument to download ``needs.json`` from remo
3030
3131
.. needimport:: https://my_company.com/docs/remote-needs.json
3232
33+
.. seealso::
34+
35+
:ref:`needs_import_cache_size`,
36+
to control the cache size for imported needs.
37+
3338
Options
3439
-------
3540

sphinx_needs/config.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -424,6 +424,10 @@ def __setattr__(self, name: str, value: Any) -> None:
424424
default_factory=list, metadata={"rebuild": "html", "types": (list,)}
425425
)
426426
"""List of external sources to load needs from."""
427+
import_cache_size: int = field(
428+
default=100, metadata={"rebuild": "html", "types": (int,)}
429+
)
430+
"""Maximum number of imported needs to cache."""
427431
builder_filter: str = field(
428432
default="is_external==False", metadata={"rebuild": "html", "types": (str,)}
429433
)

sphinx_needs/directives/needimport.py

Lines changed: 76 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,9 @@
33
import json
44
import os
55
import re
6-
from typing import Sequence
6+
import threading
7+
from copy import deepcopy
8+
from typing import Any, OrderedDict, Sequence
79
from urllib.parse import urlparse
810

911
import requests
@@ -52,7 +54,8 @@ class NeedimportDirective(SphinxDirective):
5254

5355
@measure_time("needimport")
5456
def run(self) -> Sequence[nodes.Node]:
55-
# needs_list = {}
57+
needs_config = NeedsSphinxConfig(self.config)
58+
5659
version = self.options.get("version")
5760
filter_string = self.options.get("filter")
5861
id_prefix = self.options.get("id_prefix", "")
@@ -111,21 +114,32 @@ def run(self) -> Sequence[nodes.Node]:
111114
raise ReferenceError(
112115
f"Could not load needs import file {correct_need_import_path}"
113116
)
117+
mtime = os.path.getmtime(correct_need_import_path)
114118

115-
try:
116-
with open(correct_need_import_path) as needs_file:
117-
needs_import_list = json.load(needs_file)
118-
except json.JSONDecodeError as e:
119-
# TODO: Add exception handling
120-
raise SphinxNeedsFileException(correct_need_import_path) from e
121-
122-
errors = check_needs_data(needs_import_list)
123-
if errors.schema:
124-
logger.info(
125-
f"Schema validation errors detected in file {correct_need_import_path}:"
126-
)
127-
for error in errors.schema:
128-
logger.info(f' {error.message} -> {".".join(error.path)}')
119+
if (
120+
needs_import_list := _FileCache.get(correct_need_import_path, mtime)
121+
) is None:
122+
try:
123+
with open(correct_need_import_path) as needs_file:
124+
needs_import_list = json.load(needs_file)
125+
except json.JSONDecodeError as e:
126+
# TODO: Add exception handling
127+
raise SphinxNeedsFileException(correct_need_import_path) from e
128+
129+
errors = check_needs_data(needs_import_list)
130+
if errors.schema:
131+
logger.info(
132+
f"Schema validation errors detected in file {correct_need_import_path}:"
133+
)
134+
for error in errors.schema:
135+
logger.info(f' {error.message} -> {".".join(error.path)}')
136+
else:
137+
_FileCache.set(
138+
correct_need_import_path,
139+
mtime,
140+
needs_import_list,
141+
needs_config.import_cache_size,
142+
)
129143

130144
if version is None:
131145
try:
@@ -141,17 +155,17 @@ def run(self) -> Sequence[nodes.Node]:
141155
f"Version {version} not found in needs import file {correct_need_import_path}"
142156
)
143157

144-
needs_config = NeedsSphinxConfig(self.config)
145158
data = needs_import_list["versions"][version]
146159

160+
# TODO this is not exactly NeedsInfoType, because the export removes/adds some keys
161+
needs_list: dict[str, NeedsInfoType] = data["needs"]
162+
147163
if ids := self.options.get("ids"):
148164
id_list = [i.strip() for i in ids.split(",") if i.strip()]
149-
data["needs"] = {
165+
needs_list = {
150166
key: data["needs"][key] for key in id_list if key in data["needs"]
151167
}
152168

153-
# TODO this is not exactly NeedsInfoType, because the export removes/adds some keys
154-
needs_list: dict[str, NeedsInfoType] = data["needs"]
155169
if schema := data.get("needs_schema"):
156170
# Set defaults from schema
157171
defaults = {
@@ -160,7 +174,8 @@ def run(self) -> Sequence[nodes.Node]:
160174
if "default" in value
161175
}
162176
needs_list = {
163-
key: {**defaults, **value} for key, value in needs_list.items()
177+
key: {**defaults, **value} # type: ignore[typeddict-item]
178+
for key, value in needs_list.items()
164179
}
165180

166181
# Filter imported needs
@@ -169,7 +184,8 @@ def run(self) -> Sequence[nodes.Node]:
169184
if filter_string is None:
170185
needs_list_filtered[key] = need
171186
else:
172-
filter_context = need.copy()
187+
# we deepcopy here, to ensure that the original data is not modified
188+
filter_context = deepcopy(need)
173189

174190
# Support both ways of addressing the description, as "description" is used in json file, but
175191
# "content" is the sphinx internal name for this kind of information
@@ -185,7 +201,9 @@ def run(self) -> Sequence[nodes.Node]:
185201
location=(self.env.docname, self.lineno),
186202
)
187203

188-
needs_list = needs_list_filtered
204+
# note we need to deepcopy here, as we are going to modify the data,
205+
# but we want to ensure data referenced from the cache is not modified
206+
needs_list = deepcopy(needs_list_filtered)
189207

190208
# If we need to set an id prefix, we also need to manipulate all used ids in the imported data.
191209
extra_links = needs_config.extra_links
@@ -283,6 +301,41 @@ def docname(self) -> str:
283301
return self.env.docname
284302

285303

304+
class _ImportCache:
305+
"""A simple cache for imported needs,
306+
mapping a (path, mtime) to a dictionary of needs.
307+
that is thread safe,
308+
and has a maximum size when adding new items.
309+
"""
310+
311+
def __init__(self) -> None:
312+
self._cache: OrderedDict[tuple[str, float], dict[str, Any]] = OrderedDict()
313+
self._need_count = 0
314+
self._lock = threading.Lock()
315+
316+
def set(
317+
self, path: str, mtime: float, value: dict[str, Any], max_size: int
318+
) -> None:
319+
with self._lock:
320+
self._cache[(path, mtime)] = value
321+
self._need_count += len(value)
322+
max_size = max(max_size, 0)
323+
while self._need_count > max_size:
324+
_, value = self._cache.popitem(last=False)
325+
self._need_count -= len(value)
326+
327+
def get(self, path: str, mtime: float) -> dict[str, Any] | None:
328+
with self._lock:
329+
return self._cache.get((path, mtime), None)
330+
331+
def __repr__(self) -> str:
332+
with self._lock:
333+
return f"{self.__class__.__name__}({list(self._cache)})"
334+
335+
336+
_FileCache = _ImportCache()
337+
338+
286339
class VersionNotFound(BaseException):
287340
pass
288341

0 commit comments

Comments
 (0)