4
4
import logging
5
5
import os
6
6
import time
7
+ from datetime import datetime
7
8
from functools import cmp_to_key
8
9
9
10
import requests
@@ -156,6 +157,7 @@ def __init__(
156
157
keys = None ,
157
158
source = "" ,
158
159
cache_time = 300 ,
160
+ ignore_errors_period = 0 ,
159
161
fileformat = "jwks" ,
160
162
keytype = "RSA" ,
161
163
keyusage = None ,
@@ -188,6 +190,8 @@ def __init__(
188
190
self .remote = False
189
191
self .local = False
190
192
self .cache_time = cache_time
193
+ self .ignore_errors_period = ignore_errors_period
194
+ self .ignore_errors_until = None # UNIX timestamp of last error
191
195
self .time_out = 0
192
196
self .etag = ""
193
197
self .source = None
@@ -314,7 +318,11 @@ def do_local_jwk(self, filename):
314
318
Load a JWKS from a local file
315
319
316
320
:param filename: Name of the file from which the JWKS should be loaded
321
+ :return: True if load was successful or False if file hasn't been modified
317
322
"""
323
+ if not self ._local_update_required ():
324
+ return False
325
+
318
326
LOGGER .info ("Reading local JWKS from %s" , filename )
319
327
with open (filename ) as input_file :
320
328
_info = json .load (input_file )
@@ -324,6 +332,7 @@ def do_local_jwk(self, filename):
324
332
self .do_keys ([_info ])
325
333
self .last_local = time .time ()
326
334
self .time_out = self .last_local + self .cache_time
335
+ return True
327
336
328
337
def do_local_der (self , filename , keytype , keyusage = None , kid = "" ):
329
338
"""
@@ -332,7 +341,11 @@ def do_local_der(self, filename, keytype, keyusage=None, kid=""):
332
341
:param filename: Name of the file
333
342
:param keytype: Presently 'rsa' and 'ec' supported
334
343
:param keyusage: encryption ('enc') or signing ('sig') or both
344
+ :return: True if load was successful or False if file hasn't been modified
335
345
"""
346
+ if not self ._local_update_required ():
347
+ return False
348
+
336
349
LOGGER .info ("Reading local DER from %s" , filename )
337
350
key_args = {}
338
351
_kty = keytype .lower ()
@@ -355,16 +368,25 @@ def do_local_der(self, filename, keytype, keyusage=None, kid=""):
355
368
self .do_keys ([key_args ])
356
369
self .last_local = time .time ()
357
370
self .time_out = self .last_local + self .cache_time
371
+ return True
358
372
359
373
def do_remote (self ):
360
374
"""
361
375
Load a JWKS from a webpage.
362
376
363
- :return: True or False if load was successful
377
+ :return: True if load was successful or False if remote hasn't been modified
364
378
"""
365
379
# if self.verify_ssl is not None:
366
380
# self.httpc_params["verify"] = self.verify_ssl
367
381
382
+ if self .ignore_errors_until and time .time () < self .ignore_errors_until :
383
+ LOGGER .warning (
384
+ "Not reading remote JWKS from %s (in error holddown until %s)" ,
385
+ self .source ,
386
+ datetime .fromtimestamp (self .ignore_errors_until ),
387
+ )
388
+ return False
389
+
368
390
LOGGER .info ("Reading remote JWKS from %s" , self .source )
369
391
try :
370
392
LOGGER .debug ("KeyBundle fetch keys from: %s" , self .source )
@@ -378,7 +400,10 @@ def do_remote(self):
378
400
LOGGER .error (err )
379
401
raise UpdateFailed (REMOTE_FAILED .format (self .source , str (err )))
380
402
381
- if _http_resp .status_code == 200 : # New content
403
+ load_successful = _http_resp .status_code == 200
404
+ not_modified = _http_resp .status_code == 304
405
+
406
+ if load_successful :
382
407
self .time_out = time .time () + self .cache_time
383
408
384
409
self .imp_jwks = self ._parse_remote_response (_http_resp )
@@ -390,25 +415,27 @@ def do_remote(self):
390
415
self .do_keys (self .imp_jwks ["keys" ])
391
416
except KeyError :
392
417
LOGGER .error ("No 'keys' keyword in JWKS" )
418
+ self .ignore_errors_until = time .time () + self .ignore_errors_period
393
419
raise UpdateFailed (MALFORMED .format (self .source ))
394
420
395
421
if hasattr (_http_resp , "headers" ):
396
422
headers = getattr (_http_resp , "headers" )
397
423
self .last_remote = headers .get ("last-modified" ) or headers .get ("date" )
398
-
399
- elif _http_resp .status_code == 304 : # Not modified
424
+ elif not_modified :
400
425
LOGGER .debug ("%s not modified since %s" , self .source , self .last_remote )
401
426
self .time_out = time .time () + self .cache_time
402
-
403
427
else :
404
428
LOGGER .warning (
405
429
"HTTP status %d reading remote JWKS from %s" ,
406
430
_http_resp .status_code ,
407
431
self .source ,
408
432
)
433
+ self .ignore_errors_until = time .time () + self .ignore_errors_period
409
434
raise UpdateFailed (REMOTE_FAILED .format (self .source , _http_resp .status_code ))
435
+
410
436
self .last_updated = time .time ()
411
- return True
437
+ self .ignore_errors_until = None
438
+ return load_successful
412
439
413
440
def _parse_remote_response (self , response ):
414
441
"""
@@ -433,23 +460,20 @@ def _parse_remote_response(self, response):
433
460
return None
434
461
435
462
def _uptodate (self ):
436
- res = False
437
463
if self .remote or self .local :
438
464
if time .time () > self .time_out :
439
- if self .local and not self ._local_update_required ():
440
- res = True
441
- elif self .update ():
442
- res = True
443
- return res
465
+ return self .update ()
466
+ return False
444
467
445
468
def update (self ):
446
469
"""
447
470
Reload the keys if necessary.
448
471
449
472
This is a forced update, will happen even if cache time has not elapsed.
450
473
Replaced keys will be marked as inactive and not removed.
474
+
475
+ :return: True if update was ok or False if we encountered an error during update.
451
476
"""
452
- res = True # An update was successful
453
477
if self .source :
454
478
_old_keys = self ._keys # just in case
455
479
@@ -459,24 +483,27 @@ def update(self):
459
483
try :
460
484
if self .local :
461
485
if self .fileformat in ["jwks" , "jwk" ]:
462
- self .do_local_jwk (self .source )
486
+ updated = self .do_local_jwk (self .source )
463
487
elif self .fileformat == "der" :
464
- self .do_local_der (self .source , self .keytype , self .keyusage )
488
+ updated = self .do_local_der (self .source , self .keytype , self .keyusage )
465
489
elif self .remote :
466
- res = self .do_remote ()
490
+ updated = self .do_remote ()
467
491
except Exception as err :
468
492
LOGGER .error ("Key bundle update failed: %s" , err )
469
493
self ._keys = _old_keys # restore
470
494
return False
471
495
472
- now = time .time ()
473
- for _key in _old_keys :
474
- if _key not in self ._keys :
475
- if not _key .inactive_since : # If already marked don't mess
476
- _key .inactive_since = now
477
- self ._keys .append (_key )
496
+ if updated :
497
+ now = time .time ()
498
+ for _key in _old_keys :
499
+ if _key not in self ._keys :
500
+ if not _key .inactive_since : # If already marked don't mess
501
+ _key .inactive_since = now
502
+ self ._keys .append (_key )
503
+ else :
504
+ self ._keys = _old_keys
478
505
479
- return res
506
+ return True
480
507
481
508
def get (self , typ = "" , only_active = True ):
482
509
"""
0 commit comments