Skip to content

Commit 8936c4f

Browse files
committed
add test for trim_sss_start from robust workflow
1 parent 6efc084 commit 8936c4f

2 files changed

Lines changed: 278 additions & 252 deletions

File tree

src/quends/base/data_stream.py

Lines changed: 276 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,10 @@
11
import math
22

3+
import matplotlib.pyplot as plt
34
import numpy as np
45
import pandas as pd
6+
import scipy.stats as sts
7+
import statsmodels.tsa.stattools as ststls
58
from scipy.optimize import curve_fit
69
from scipy.stats import norm, rankdata
710
from sklearn.preprocessing import MinMaxScaler
@@ -279,6 +282,279 @@ def trim(
279282
new_history.append({"operation": "trim", "options": options})
280283
return DataStream(self.df.iloc[0:0].copy(), _history=new_history)
281284

285+
def trim_sss_start(self, col, workflow):
286+
"""
287+
Identify and trim the signal to the start of the Statistical Steady State (SSS)
288+
289+
Parameters
290+
----------
291+
col : str
292+
The name of the column in `self.df` to analyze for steady state.
293+
workflow : object
294+
A configuration/workflow object containing parameters:
295+
- `_max_lag_frac`: Fraction of data used for autocorrelation lag.
296+
- `_verbosity`: Integer controlling plot and print output levels.
297+
- `_autocorr_sig_level`: Significance level for the Z-test on lags.
298+
- `_decor_multiplier`: Multiplier for the calculated decorrelation length.
299+
- `_std_dev_frac`: Fraction of standard deviation used for tolerance.
300+
- `_fudge_fac`: Constant to prevent zero-tolerance in noiseless signals.
301+
- `_smoothing_window_correction`: Factor to adjust for rolling mean lag.
302+
- `_final_smoothing_window`: Window size for smoothing the metric curves.
303+
304+
Returns
305+
-------
306+
DataStream
307+
A new DataStream object containing the DataFrame trimmed to the SSS start.
308+
Returns an empty DataFrame if no SSS is identified.
309+
"""
310+
# Get the decorrelation length (in number of points)
311+
# Note: this approach assumes signal points are spaced equally in time
312+
n_pts = len(self.df)
313+
max_lag = int(workflow._max_lag_frac * n_pts) # max lag for autocorrelation
314+
315+
acf_vals = ststls.acf(self.df[col].dropna().values, nlags=max_lag)
316+
317+
# plot the autocorrelation function
318+
if workflow._verbosity > 1:
319+
plt.figure(figsize=(10, 6))
320+
plt.stem(range(len(acf_vals)), acf_vals)
321+
plt.xlabel("Lag")
322+
plt.ylabel("Autocorrelation")
323+
plt.title("Autocorrelation Function")
324+
plt.grid()
325+
plt.show()
326+
plt.close()
327+
328+
# Use rigorous statistical measure for decorrelation length
329+
z_critical = sts.norm.ppf(1 - workflow._autocorr_sig_level / 2)
330+
conf_interval = z_critical / np.sqrt(n_pts)
331+
significant_lags = np.where(np.abs(acf_vals[1:]) > conf_interval)[0]
332+
acf_sum = np.sum(np.abs(acf_vals[1:][significant_lags]))
333+
decor_length = int(np.ceil(1 + 2 * acf_sum))
334+
335+
# Set smoothing window as multiple of decorrelation length, but not more than max_lag
336+
decor_index = min(int(workflow._decor_multiplier * decor_length), max_lag)
337+
338+
if workflow._verbosity > 0:
339+
print(
340+
f"stats decorrelation length {decor_length} gives smoothing window of {decor_index} points."
341+
)
342+
343+
# Smooth signal with rolling mean over window size based on decorrelation length
344+
rolling_window = max(3, decor_index) # at least 3 points in window
345+
col_smoothed = (
346+
self.df[col].rolling(window=rolling_window).mean()
347+
) # get smoothed column as Series
348+
col_sm_flld = col_smoothed.bfill() # fill initial NaNs with first valid value
349+
# create new DataFrame with time and smoothed flux
350+
df_smoothed = pd.DataFrame({"time": self.df["time"], col: col_sm_flld})
351+
352+
# Compute std dev of original signal from current location till end of signal
353+
std_dev_till_end = np.empty((n_pts,), dtype=float)
354+
for i in range(n_pts):
355+
std_dev_till_end[i] = np.std(self.df[col].iloc[i:])
356+
# turn this into a pandas series with same index as col_smoothed
357+
std_dev_till_end_series = pd.Series(std_dev_till_end, index=self.df.index)
358+
# Smooth this std dev to avoid it going to zero at end of signal
359+
std_dev_smoothed = std_dev_till_end_series.rolling(
360+
window=workflow._final_smoothing_window
361+
).mean()
362+
# Fill initial NaNs with the first valid smoothed std dev value
363+
std_dev_sm_flld = std_dev_smoothed.bfill()
364+
365+
# create new DataFrame with time and std dev till end of signal
366+
df_std_dev = pd.DataFrame(
367+
{"time": self.df["time"], col + "_std_till_end": std_dev_sm_flld}
368+
)
369+
370+
# start time of smoothed signal
371+
smoothed_start_time = df_smoothed["time"].iloc[rolling_window - 1]
372+
373+
# plot smoothed signal and related quantities
374+
if workflow._verbosity > 1:
375+
plt.figure(figsize=(10, 6))
376+
plt.plot(
377+
self.df["time"],
378+
self.df[col],
379+
label="Original Signal",
380+
alpha=0.5,
381+
)
382+
plt.plot(
383+
df_smoothed["time"],
384+
df_smoothed[col],
385+
label="Smoothed Signal",
386+
color="orange",
387+
)
388+
plt.plot(
389+
df_std_dev["time"],
390+
df_std_dev[col + "_std_till_end"],
391+
label="Smoothed Std Dev Till End",
392+
color="green",
393+
)
394+
plt.axvline(
395+
x=smoothed_start_time,
396+
color="g",
397+
linestyle="--",
398+
label="First smoothed point",
399+
)
400+
plt.xlabel("Time")
401+
plt.ylabel(col)
402+
plt.title("Original and Smoothed Signal")
403+
plt.legend()
404+
plt.grid()
405+
plt.show()
406+
plt.close()
407+
408+
if workflow._verbosity > 0:
409+
print("Getting start of SSS based on smoothed signal:")
410+
411+
# Get start of SSS based on where the value of the flux in the smoothed signal
412+
# is close to the mean of the remaining signal.
413+
414+
# At each location, compute the mean of the remaining smoothed signal
415+
n_pts_smoothed = len(df_smoothed)
416+
mean_vals = np.empty((n_pts_smoothed,), dtype=float)
417+
418+
for i in range(n_pts_smoothed):
419+
mean_vals[i] = np.mean(df_smoothed[col].iloc[i:])
420+
421+
# Check where the current value of the smoothed signal is within tol_fac of the mean of the remaining signal
422+
deviation_arr = np.abs(df_smoothed[col] - mean_vals)
423+
424+
# smooth this so the deviation does not go to zero at end of signal by construction
425+
# turn this into a pandas series with same index as col_smoothed
426+
deviation_series = pd.Series(deviation_arr, index=self.df.index)
427+
# Smooth this std dev to avoid it going to zero at end of signal
428+
deviation_smoothed = deviation_series.rolling(
429+
window=workflow._final_smoothing_window
430+
).mean()
431+
# Fill initial NaNs with the first valid smoothed std dev value
432+
deviation_sm_flld = deviation_smoothed.bfill()
433+
# Build a dataframe for the deviation
434+
deviation = pd.DataFrame(
435+
{"time": self.df["time"], col + "_deviation": deviation_sm_flld}
436+
)
437+
438+
# Compute tolerance on variation in the mean of the smoothed signal as
439+
# stdv_frac * (std dev till end + a fudge factor * mean value at start of smoothed signal)
440+
# fudge factor is for in case there is no noise (and to guard against the tolerance
441+
# factor going to zero when std dev gets very small at end of signal)
442+
tol_fac = workflow._std_dev_frac * (
443+
df_std_dev[col + "_std_till_end"] + workflow._fudge_fac * abs(mean_vals[0])
444+
)
445+
tolerance = tol_fac * np.abs(mean_vals)
446+
447+
within_tolerance_all = deviation[col + "_deviation"] <= tolerance
448+
# Only consider points after the smoothed signal has started
449+
within_tolerance = within_tolerance_all & (
450+
df_smoothed["time"] >= smoothed_start_time
451+
)
452+
# First index where we are within tolerance
453+
sss_index = np.where(within_tolerance)[0]
454+
455+
# See if there is a segment where ALL remaining points are within tolerance
456+
crit_met_index = None
457+
if len(sss_index) > 0:
458+
# find the segment where ALL remaining points are within tolerance
459+
for idx in sss_index:
460+
if np.all(within_tolerance[idx:]):
461+
crit_met_index = idx
462+
break
463+
464+
if crit_met_index is not None: # We have a SSS segment
465+
# Time where criterion has been met
466+
criterion_time = df_smoothed["time"].iloc[crit_met_index]
467+
# Take into account that the signal at the point where the criterion has been met is a result
468+
# of averaging over the rolling window. So set the start of SSS near the start of the rolling window
469+
# but not all the way at the beginning of the rolling window as there is usually still some transient.
470+
true_sss_start_index = max(
471+
0,
472+
int(
473+
crit_met_index
474+
- workflow._smoothing_window_correction * rolling_window
475+
),
476+
)
477+
sss_start_time = df_smoothed["time"].iloc[true_sss_start_index]
478+
479+
if workflow._verbosity > 0:
480+
print(f"Index where criterion is met: {crit_met_index}")
481+
print(f"Rolling window: {rolling_window}")
482+
print(f"time where criterion is met: {criterion_time}")
483+
print(
484+
f"time at start of SSS (adjusted for rolling window): {sss_start_time}"
485+
)
486+
487+
# Plot deviation and tolerance vs. time
488+
if workflow._verbosity > 1:
489+
plt.figure(figsize=(10, 6))
490+
plt.plot(
491+
df_smoothed["time"],
492+
deviation[col + "_deviation"],
493+
label="Deviation",
494+
color="blue",
495+
)
496+
plt.plot(
497+
df_smoothed["time"],
498+
tolerance,
499+
label="Tolerance",
500+
color="orange",
501+
)
502+
plt.axvline(
503+
x=criterion_time,
504+
color="g",
505+
linestyle="--",
506+
label="Small Change Criterion Met",
507+
)
508+
plt.axvline(
509+
x=sss_start_time, color="r", linestyle="--", label="Start SSS"
510+
)
511+
plt.xlabel("Time")
512+
plt.ylabel("Value")
513+
plt.title("Deviation and Tolerance vs. Time")
514+
plt.legend()
515+
plt.grid()
516+
plt.show()
517+
plt.close()
518+
519+
# Trim the original data frame to start at this location minus the smoothing window
520+
trimmed_df = self.df[self.df["time"] >= sss_start_time]
521+
# Reset the index so it starts at 0
522+
trimmed_df = trimmed_df.reset_index(drop=True)
523+
# Create new data stream from trimmed data frame
524+
trimmed_stream = DataStream(trimmed_df)
525+
526+
else:
527+
if workflow._verbosity > 0:
528+
print("No SSS found based on behavior of mean of smoothed signal.")
529+
trimmed_stream = pd.DataFrame(
530+
columns=["time", "flux"]
531+
) # Create empty DataFrame with same columns as original
532+
533+
# Plot deviation and tolerance vs. time
534+
if workflow._verbosity > 1:
535+
plt.figure(figsize=(10, 6))
536+
plt.plot(
537+
df_smoothed["time"],
538+
deviation[col + "_deviation"],
539+
label="Deviation",
540+
color="blue",
541+
)
542+
plt.plot(
543+
df_smoothed["time"],
544+
tolerance,
545+
label="Tolerance",
546+
color="orange",
547+
)
548+
plt.xlabel("Time")
549+
plt.ylabel("Value")
550+
plt.title("Deviation and Tolerance vs. Time")
551+
plt.legend()
552+
plt.grid()
553+
plt.show()
554+
plt.close()
555+
556+
return trimmed_stream
557+
282558
@staticmethod
283559
def find_steady_state_std(data, column_name, window_size=10, robust=True):
284560
"""

0 commit comments

Comments
 (0)