16
16
17
17
18
18
# Transformations applied to a base signal's rows to derive another signal.
# These are module-level constants compared by *identity* elsewhere (e.g.
# `DIFF in transform_types` in get_pad_length), so the names must not change.
# Each takes an iterable of row dicts plus keyword args and returns rows.
# PEP 8 (E731): use `def` rather than assigning lambdas to names, so the
# callables get proper __name__ values and tracebacks.
def IDENTITY(rows, **kwargs):
    """Pass rows through unchanged."""
    return rows


def DIFF(rows, **kwargs):
    """Difference consecutive rows (turn a cumulative series into diffs)."""
    return generate_row_diffs(rows, **kwargs)


def SMOOTH(rows, **kwargs):
    """Smooth rows with a rolling window (window length set by caller kwargs)."""
    return generate_smooth_rows(rows, **kwargs)


def DIFF_SMOOTH(rows, **kwargs):
    """Difference the rows, then smooth the result."""
    return generate_smooth_rows(generate_row_diffs(rows, **kwargs), **kwargs)
19
22
20
23
21
24
class HighValuesAre (str , Enum ):
@@ -245,7 +248,6 @@ def _load_data_signals(sources: List[DataSource]):
245
248
data_signals_by_key [(source .db_source , d .signal )] = d
246
249
247
250
248
-
249
251
def get_related_signals(signal: DataSignal) -> List[DataSignal]:
    """Return every other known DataSignal that shares this signal's basename."""
    related = []
    for candidate in data_signals:
        if candidate == signal:
            continue
        if candidate.signal_basename == signal.signal_basename:
            related.append(candidate)
    return related
251
253
def _resolve_all_signals(source_signals: Union[SourceSignalPair, List[SourceSignalPair]], data_sources_by_id: Dict[str, DataSource] = data_sources_by_id) -> Union[SourceSignalPair, List[SourceSignalPair]]:
    """Expand an "all signals" request into an explicit list of signal names.

    Example: SourceSignalPair("jhu-csse", signal=True) would return SourceSignalPair("jhu-csse", [<list of all JHU signals>]).
    """
    # A list is resolved element-wise.
    if isinstance(source_signals, list):
        return [_resolve_all_signals(pair, data_sources_by_id) for pair in source_signals]
    if isinstance(source_signals, SourceSignalPair):
        # `signal` is either a list of names or the literal flag True meaning
        # "all signals of this source"; only the flag triggers expansion.
        if source_signals.signal == True:
            source = data_sources_by_id.get(source_signals.source)
            if source:
                return SourceSignalPair(source.source, [s.signal for s in source.signals])
        # Unknown source or an explicit signal list: pass through unchanged.
        return source_signals
    raise TypeError("source_signals is not Union[SourceSignalPair, List[SourceSignalPair]].")
321
330
322
331
323
332
def _reindex_iterable (iterable : Iterable [Dict ], time_pairs : List [TimePair ], fill_value : Optional [int ] = None ) -> Iterable :
@@ -331,17 +340,47 @@ def _reindex_iterable(iterable: Iterable[Dict], time_pairs: List[TimePair], fill
331
340
if time_pairs is None :
332
341
return iterable
333
342
343
+ _iterable = peekable (iterable )
344
+ first_item = _iterable .peek ()
334
345
day_range_index = get_day_range (time_pairs )
335
346
for day in day_range_index .time_values :
336
- if day in next_value (iterable ):
337
- return next_value (iterable )
347
+ index_item = first_item .copy ()
348
+ index_item .update ({
349
+ "time_value" : day ,
350
+ "value" : fill_value ,
351
+ "stderr" : None ,
352
+ "sample_size" : None ,
353
+ "missing_value" : Nans .NOT_MISSING if fill_value is not None else Nans .NOT_APPLICABLE ,
354
+ "missing_stderr" : Nans .NOT_APPLICABLE ,
355
+ "missing_sample_size" : Nans .NOT_APPLICABLE
356
+ })
357
+ new_item = _iterable .peek (default = index_item )
358
+ if day == new_item .get ("time_value" ):
359
+ yield next (_iterable , index_item )
338
360
else :
339
- yield updated_default_value
361
+ yield index_item
340
362
341
363
342
364
def _get_base_signal_transform(signal: Union[DataSignal, Tuple[str, str]], data_signals_by_key: Dict[Tuple[str, str], DataSignal] = data_signals_by_key) -> Callable:
    """Given a DataSignal, return the transformation that needs to be applied to its base signal to derive the signal."""
    if isinstance(signal, DataSignal):
        base = data_signals_by_key.get((signal.source, signal.signal_basename))
        # Only raw/count-style signals flagged compute_from_base, with a known
        # base signal, can be derived; everything else is passed through.
        derivable = (
            signal.format in (SignalFormat.raw, SignalFormat.raw_count, SignalFormat.count)
            and signal.compute_from_base
            and base
        )
        if not derivable:
            return IDENTITY
        if signal.is_cumulative:
            # A cumulative signal needs smoothing only when it is smoothed.
            return SMOOTH if signal.is_smoothed else IDENTITY
        # Non-cumulative: differencing is needed only when the base series is
        # itself cumulative.
        if signal.is_smoothed:
            return DIFF_SMOOTH if base.is_cumulative else SMOOTH
        return DIFF if base.is_cumulative else IDENTITY

    if isinstance(signal, tuple):
        # Resolve a (source, signal) key to its DataSignal and recurse.
        resolved = data_signals_by_key.get(signal)
        return _get_base_signal_transform(resolved, data_signals_by_key) if resolved else IDENTITY

    raise TypeError("signal must be either Tuple[str, str] or DataSignal.")
345
384
346
385
347
386
def get_transform_types(source_signal_pairs: List[SourceSignalPair], data_sources_by_id: Dict[str, DataSource] = data_sources_by_id, data_signals_by_key: Dict[Tuple[str, str], DataSignal] = data_signals_by_key) -> Set[Callable]:
    """Return the set of base-signal transformations the requested signals need.

    Used to pad the user DB query with extra days.
    """
    resolved_pairs = _resolve_all_signals(source_signal_pairs, data_sources_by_id)

    transforms: Set[Callable] = set()
    for pair in resolved_pairs:
        if isinstance(pair.signal, bool):
            # An unresolved "all signals" flag carries no concrete names.
            continue
        for signal_name in pair.signal:
            transforms.add(_get_base_signal_transform((pair.source, signal_name), data_signals_by_key=data_signals_by_key))

    return transforms
356
405
357
406
358
407
def get_pad_length(source_signal_pairs: List[SourceSignalPair], smoother_window_length: int, data_sources_by_id: Dict[str, DataSource] = data_sources_by_id, data_signals_by_key: Dict[Tuple[str, str], DataSignal] = data_signals_by_key):
    """Return how many extra leading days the requested transformations need.

    Used to pad the user DB query with extra days.
    """
    transform_types = get_transform_types(source_signal_pairs, data_sources_by_id=data_sources_by_id, data_signals_by_key=data_signals_by_key)
    # Days of history each transform requires: DIFF needs one prior day,
    # SMOOTH needs a window minus the current day, DIFF_SMOOTH needs both.
    required_days = {
        DIFF_SMOOTH: smoother_window_length,
        SMOOTH: smoother_window_length - 1,
        DIFF: 1,
    }
    return max((days for transform, days in required_days.items() if transform in transform_types), default=0)
371
424
372
425
373
426
def pad_time_pairs(time_pairs: List[TimePair], pad_length: int) -> List[TimePair]:
    """Return time_pairs plus one extra pair covering `pad_length` days before the earliest day.

    Used to pad the user DB query with extra days.
    """
    if pad_length < 0:
        raise ValueError("pad_length should non-negative.")
    if pad_length == 0:
        return time_pairs.copy()

    def earliest(time_value):
        # A time value is either a single day int or an (inclusive) range tuple.
        return time_value if isinstance(time_value, int) else time_value[0]

    min_time = min(
        earliest(time_value)
        for time_pair in time_pairs
        if not isinstance(time_pair.time_values, bool)
        for time_value in time_pair.time_values
    )
    padding = TimePair("day", [(shift_time_value(min_time, -pad_length), min_time)])
    return [*time_pairs, padding]
382
443
383
444
384
445
def pad_time_window(time_window: Tuple[int, int], pad_length: int) -> Tuple[int, int]:
    """Extend a (start, end) day window backwards by `pad_length` days.

    Used to pad the user DB query with extra days.
    """
    if pad_length < 0:
        raise ValueError("pad_length should non-negative.")
    if pad_length == 0:
        return time_window
    start, end = time_window
    return shift_time_value(start, -pad_length), end
393
459
394
460
395
461
def get_day_range(time_pairs: Union[TimePair, List[TimePair]]) -> TimePair:
    """Merge TimePair(s) into one TimePair holding a sorted, deduplicated list of days.

    Used to produce a contiguous time index for time series operations.
    """
    if isinstance(time_pairs, TimePair):
        days = set()
        for time_value in time_pairs.time_values:
            # A tuple denotes an inclusive (start, end) range; expand it.
            if isinstance(time_value, tuple):
                days.update(time_value_range(time_value))
            else:
                days.add(time_value)
        ordered_days = sorted(days)
        if True in ordered_days:
            raise ValueError("TimePair.time_value should not be a bool when calling get_day_range.")
        return TimePair(time_pairs.time_type, ordered_days)
    if isinstance(time_pairs, list):
        if any(time_pair.time_type != "day" for time_pair in time_pairs):
            raise ValueError("get_day_range only supports day time_type pairs.")
        merged = set()
        for time_pair in time_pairs:
            merged.update(get_day_range(time_pair).time_values)
        return TimePair("day", sorted(merged))
    raise ValueError("get_day_range received an unsupported type as input.")
405
483
406
484
407
485
def _generate_transformed_rows (
408
- parsed_rows : Iterable [Dict ], time_pairs : Optional [List [TimePair ]] = None , transform_dict : Optional [Dict [str , List [Tuple [str , str ]]]]= None , transform_args : Optional [Dict ] = None , group_keyfunc : Optional [Callable ] = None , data_signals_by_key : Dict [Tuple [str , str ], DataSignal ] = data_signals_by_key ,
486
+ parsed_rows : Iterable [Dict ], time_pairs : Optional [List [TimePair ]] = None , transform_dict : Optional [Dict [Tuple [ str , str ], List [Tuple [str , str ]]]] = None , transform_args : Optional [Dict ] = None , group_keyfunc : Optional [Callable ] = None , data_signals_by_key : Dict [Tuple [str , str ], DataSignal ] = data_signals_by_key ,
409
487
) -> Iterable [Dict ]:
410
488
"""Applies time-series transformations to streamed rows from a database.
411
489
@@ -441,13 +519,19 @@ def _generate_transformed_rows(
441
519
for key , group in groupby (parsed_rows , group_keyfunc ):
442
520
_ , _ , source_name , signal_name = key
443
521
# Extract the list of derived signals.
522
+ derived_signals : List [Tuple [str , str ]] = transform_dict .get ((source_name , signal_name ), [(source_name , signal_name )])
444
523
# Create a list of source-signal pairs along with the transformation required for the signal.
524
+ source_signal_pairs_and_group_transforms : List [Tuple [Tuple [str , str ], Callable ]] = [((derived_source , derived_signal ), _get_base_signal_transform ((derived_source , derived_signal ), data_signals_by_key )) for (derived_source , derived_signal ) in derived_signals ]
445
525
# Put the current time series on a contiguous time index.
526
+ group_continguous_time = _reindex_iterable (group , time_pairs , fill_value = transform_args .get ("pad_fill_value" )) if time_pairs else group
446
527
# Create copies of the iterable, with smart memory usage.
528
+ group_iter_copies : Iterable [Iterable [Dict ]] = tee (group_continguous_time , len (source_signal_pairs_and_group_transforms ))
447
529
# Create a list of transformed group iterables, remembering their derived name as needed.
530
+ transformed_group_rows : Iterable [Iterable [Dict ]] = (zip (transform (rows , ** transform_args ), repeat (key )) for (key , transform ), rows in zip (source_signal_pairs_and_group_transforms , group_iter_copies ))
448
531
# Traverse through the transformed iterables in an interleaved fashion, which makes sure that only a small window
449
532
# of the original iterable (group) is stored in memory.
450
- for row in transform_group (group ):
533
+ for row , (_ , derived_signal ) in interleave_longest (* transformed_group_rows ):
534
+ row ["signal" ] = derived_signal
451
535
yield row
452
536
except Exception as e :
453
537
print (f"Tranformation encountered error of type { type (e )} , with message { e } . Yielding None and stopping." )
@@ -462,6 +546,28 @@ def get_basename_signals(source_signal_pairs: List[SourceSignalPair], data_sourc
462
546
SourceSignalPair("src", signal=["sig_base", "sig_smoothed"]) would return SourceSignalPair("src", signal=["sig_base"]) and a transformation function
463
547
that will take the returned database query for "sig_base" and return both the base time series and the smoothed time series.
464
548
"""
465
- transform_dict = {("src" , "sig_base" ): [("src" , "sig_base" ), ("src" , "sig_smoothed" )]}
549
+ source_signal_pairs = _resolve_all_signals (source_signal_pairs , data_sources_by_id )
550
+ base_signal_pairs : List [SourceSignalPair ] = []
551
+ transform_dict : Dict [Tuple [str , str ], List [Tuple [str , str ]]] = dict ()
552
+
553
+ for pair in source_signal_pairs :
554
+ if isinstance (pair .signal , bool ):
555
+ base_signal_pairs .append (pair )
556
+ continue
557
+
558
+ source_name = pair .source
559
+ signal_names = pair .signal
560
+ signals = []
561
+ for signal_name in signal_names :
562
+ signal = data_signals_by_key .get ((source_name , signal_name ))
563
+ if not signal or not signal .compute_from_base :
564
+ signals .append (signal_name )
565
+ transform_dict .setdefault ((source_name , signal_name ), []).append ((source_name , signal_name ))
566
+ else :
567
+ signals .append (signal .signal_basename )
568
+ transform_dict .setdefault ((source_name , signal .signal_basename ), []).append ((source_name , signal_name ))
569
+ base_signal_pairs .append (SourceSignalPair (pair .source , signals ))
570
+
466
571
row_transform_generator = partial (_generate_transformed_rows , transform_dict = transform_dict , data_signals_by_key = data_signals_by_key )
467
- return SourceSignalPair ("src" , signal = ["sig_base" ]), row_transform_generator
572
+
573
+ return base_signal_pairs , row_transform_generator
0 commit comments