33import json
44import os
55import re
6- from typing import Sequence
6+ import threading
7+ from copy import deepcopy
8+ from typing import Any , OrderedDict , Sequence
79from urllib .parse import urlparse
810
911import requests
@@ -52,7 +54,8 @@ class NeedimportDirective(SphinxDirective):
5254
5355 @measure_time ("needimport" )
5456 def run (self ) -> Sequence [nodes .Node ]:
55- # needs_list = {}
57+ needs_config = NeedsSphinxConfig (self .config )
58+
5659 version = self .options .get ("version" )
5760 filter_string = self .options .get ("filter" )
5861 id_prefix = self .options .get ("id_prefix" , "" )
@@ -111,21 +114,34 @@ def run(self) -> Sequence[nodes.Node]:
111114 raise ReferenceError (
112115 f"Could not load needs import file { correct_need_import_path } "
113116 )
117+ mtime = os .path .getmtime (correct_need_import_path )
114118
115- try :
116- with open (correct_need_import_path ) as needs_file :
117- needs_import_list = json .load (needs_file )
118- except (OSError , json .JSONDecodeError ) as e :
119- # TODO: Add exception handling
120- raise SphinxNeedsFileException (correct_need_import_path ) from e
121-
122- errors = check_needs_data (needs_import_list )
123- if errors .schema :
124- logger .info (
125- f"Schema validation errors detected in file { correct_need_import_path } :"
126- )
127- for error in errors .schema :
128- logger .info (f' { error .message } -> { "." .join (error .path )} ' )
119+ if (
120+ needs_import_list := _FileCache .get (correct_need_import_path , mtime )
121+ ) is None :
122+ try :
123+ with open (correct_need_import_path ) as needs_file :
124+ needs_import_list = json .load (needs_file )
125+ except (OSError , json .JSONDecodeError ) as e :
126+ # TODO: Add exception handling
127+ raise SphinxNeedsFileException (correct_need_import_path ) from e
128+
129+ errors = check_needs_data (needs_import_list )
130+ if errors .schema :
131+ logger .info (
132+ f"Schema validation errors detected in file { correct_need_import_path } :"
133+ )
134+ for error in errors .schema :
135+ logger .info (f' { error .message } -> { "." .join (error .path )} ' )
136+ else :
137+ _FileCache .set (
138+ correct_need_import_path ,
139+ mtime ,
140+ needs_import_list ,
141+ needs_config .import_cache_size ,
142+ )
143+
144+ self .env .note_dependency (correct_need_import_path )
129145
130146 if version is None :
131147 try :
@@ -141,17 +157,17 @@ def run(self) -> Sequence[nodes.Node]:
141157 f"Version { version } not found in needs import file { correct_need_import_path } "
142158 )
143159
144- needs_config = NeedsSphinxConfig (self .config )
145160 data = needs_import_list ["versions" ][version ]
146161
162+ # TODO this is not exactly NeedsInfoType, because the export removes/adds some keys
163+ needs_list : dict [str , NeedsInfoType ] = data ["needs" ]
164+
147165 if ids := self .options .get ("ids" ):
148166 id_list = [i .strip () for i in ids .split ("," ) if i .strip ()]
149- data [ "needs" ] = {
167+ needs_list = {
150168 key : data ["needs" ][key ] for key in id_list if key in data ["needs" ]
151169 }
152170
153- # TODO this is not exactly NeedsInfoType, because the export removes/adds some keys
154- needs_list : dict [str , NeedsInfoType ] = data ["needs" ]
155171 if schema := data .get ("needs_schema" ):
156172 # Set defaults from schema
157173 defaults = {
@@ -160,7 +176,8 @@ def run(self) -> Sequence[nodes.Node]:
160176 if "default" in value
161177 }
162178 needs_list = {
163- key : {** defaults , ** value } for key , value in needs_list .items ()
179+ key : {** defaults , ** value } # type: ignore[typeddict-item]
180+ for key , value in needs_list .items ()
164181 }
165182
166183 # Filter imported needs
@@ -169,7 +186,8 @@ def run(self) -> Sequence[nodes.Node]:
169186 if filter_string is None :
170187 needs_list_filtered [key ] = need
171188 else :
172- filter_context = need .copy ()
189+ # we deepcopy here, to ensure that the original data is not modified
190+ filter_context = deepcopy (need )
173191
174192 # Support both ways of addressing the description, as "description" is used in json file, but
175193 # "content" is the sphinx internal name for this kind of information
@@ -185,7 +203,9 @@ def run(self) -> Sequence[nodes.Node]:
185203 location = (self .env .docname , self .lineno ),
186204 )
187205
188- needs_list = needs_list_filtered
206+ # note we need to deepcopy here, as we are going to modify the data,
207+ # but we want to ensure data referenced from the cache is not modified
208+ needs_list = deepcopy (needs_list_filtered )
189209
190210 # tags update
191211 if tags := [
@@ -265,6 +285,41 @@ def docname(self) -> str:
265285 return self .env .docname
266286
267287
288+ class _ImportCache :
289+ """A simple cache for imported needs,
290+ mapping a (path, mtime) to a dictionary of needs.
291+ that is thread safe,
292+ and has a maximum size when adding new items.
293+ """
294+
295+ def __init__ (self ) -> None :
296+ self ._cache : OrderedDict [tuple [str , float ], dict [str , Any ]] = OrderedDict ()
297+ self ._need_count = 0
298+ self ._lock = threading .Lock ()
299+
300+ def set (
301+ self , path : str , mtime : float , value : dict [str , Any ], max_size : int
302+ ) -> None :
303+ with self ._lock :
304+ self ._cache [(path , mtime )] = value
305+ self ._need_count += len (value )
306+ max_size = max (max_size , 0 )
307+ while self ._need_count > max_size :
308+ _ , value = self ._cache .popitem (last = False )
309+ self ._need_count -= len (value )
310+
311+ def get (self , path : str , mtime : float ) -> dict [str , Any ] | None :
312+ with self ._lock :
313+ return self ._cache .get ((path , mtime ), None )
314+
315+ def __repr__ (self ) -> str :
316+ with self ._lock :
317+ return f"{ self .__class__ .__name__ } ({ list (self ._cache )} )"
318+
319+
320+ _FileCache = _ImportCache ()
321+
322+
class VersionNotFound(Exception):
    """Raised when the requested version is missing from a needs import file.

    Derives from ``Exception`` rather than ``BaseException`` so that generic
    ``except Exception`` handlers (e.g. Sphinx's own error reporting) can
    catch it, per the Python tutorial's guidance for user-defined exceptions.
    ``except VersionNotFound`` call sites are unaffected.
    """
270325
0 commit comments