|
1 | 1 | import math |
2 | 2 |
|
| 3 | +import matplotlib.pyplot as plt |
3 | 4 | import numpy as np |
4 | 5 | import pandas as pd |
| 6 | +import scipy.stats as sts |
| 7 | +import statsmodels.tsa.stattools as ststls |
5 | 8 | from scipy.optimize import curve_fit |
6 | 9 | from scipy.stats import norm, rankdata |
7 | 10 | from sklearn.preprocessing import MinMaxScaler |
@@ -279,6 +282,279 @@ def trim( |
279 | 282 | new_history.append({"operation": "trim", "options": options}) |
280 | 283 | return DataStream(self.df.iloc[0:0].copy(), _history=new_history) |
281 | 284 |
|
| 285 | + def trim_sss_start(self, col, workflow): |
| 286 | + """ |
| 287 | + Identify and trim the signal to the start of the Statistical Steady State (SSS) |
| 288 | +
|
| 289 | + Parameters |
| 290 | + ---------- |
| 291 | + col : str |
| 292 | + The name of the column in `self.df` to analyze for steady state. |
| 293 | + workflow : object |
| 294 | + A configuration/workflow object containing parameters: |
| 295 | + - `_max_lag_frac`: Fraction of data used for autocorrelation lag. |
| 296 | + - `_verbosity`: Integer controlling plot and print output levels. |
| 297 | + - `_autocorr_sig_level`: Significance level for the Z-test on lags. |
| 298 | + - `_decor_multiplier`: Multiplier for the calculated decorrelation length. |
| 299 | + - `_std_dev_frac`: Fraction of standard deviation used for tolerance. |
| 300 | + - `_fudge_fac`: Constant to prevent zero-tolerance in noiseless signals. |
| 301 | + - `_smoothing_window_correction`: Factor to adjust for rolling mean lag. |
| 302 | + - `_final_smoothing_window`: Window size for smoothing the metric curves. |
| 303 | +
|
| 304 | + Returns |
| 305 | + ------- |
| 306 | + DataStream |
| 307 | + A new DataStream object containing the DataFrame trimmed to the SSS start. |
| 308 | + Returns an empty DataFrame if no SSS is identified. |
| 309 | + """ |
| 310 | + # Get the decorrelation length (in number of points) |
| 311 | + # Note: this approach assumes signal points are spaced equally in time |
| 312 | + n_pts = len(self.df) |
| 313 | + max_lag = int(workflow._max_lag_frac * n_pts) # max lag for autocorrelation |
| 314 | + |
| 315 | + acf_vals = ststls.acf(self.df[col].dropna().values, nlags=max_lag) |
| 316 | + |
| 317 | + # plot the autocorrelation function |
| 318 | + if workflow._verbosity > 1: |
| 319 | + plt.figure(figsize=(10, 6)) |
| 320 | + plt.stem(range(len(acf_vals)), acf_vals) |
| 321 | + plt.xlabel("Lag") |
| 322 | + plt.ylabel("Autocorrelation") |
| 323 | + plt.title("Autocorrelation Function") |
| 324 | + plt.grid() |
| 325 | + plt.show() |
| 326 | + plt.close() |
| 327 | + |
| 328 | + # Use rigorous statistical measure for decorrelation length |
| 329 | + z_critical = sts.norm.ppf(1 - workflow._autocorr_sig_level / 2) |
| 330 | + conf_interval = z_critical / np.sqrt(n_pts) |
| 331 | + significant_lags = np.where(np.abs(acf_vals[1:]) > conf_interval)[0] |
| 332 | + acf_sum = np.sum(np.abs(acf_vals[1:][significant_lags])) |
| 333 | + decor_length = int(np.ceil(1 + 2 * acf_sum)) |
| 334 | + |
| 335 | + # Set smoothing window as multiple of decorrelation length, but not more than max_lag |
| 336 | + decor_index = min(int(workflow._decor_multiplier * decor_length), max_lag) |
| 337 | + |
| 338 | + if workflow._verbosity > 0: |
| 339 | + print( |
| 340 | + f"stats decorrelation length {decor_length} gives smoothing window of {decor_index} points." |
| 341 | + ) |
| 342 | + |
| 343 | + # Smooth signal with rolling mean over window size based on decorrelation length |
| 344 | + rolling_window = max(3, decor_index) # at least 3 points in window |
| 345 | + col_smoothed = ( |
| 346 | + self.df[col].rolling(window=rolling_window).mean() |
| 347 | + ) # get smoothed column as Series |
| 348 | + col_sm_flld = col_smoothed.bfill() # fill initial NaNs with first valid value |
| 349 | + # create new DataFrame with time and smoothed flux |
| 350 | + df_smoothed = pd.DataFrame({"time": self.df["time"], col: col_sm_flld}) |
| 351 | + |
| 352 | + # Compute std dev of original signal from current location till end of signal |
| 353 | + std_dev_till_end = np.empty((n_pts,), dtype=float) |
| 354 | + for i in range(n_pts): |
| 355 | + std_dev_till_end[i] = np.std(self.df[col].iloc[i:]) |
| 356 | + # turn this into a pandas series with same index as col_smoothed |
| 357 | + std_dev_till_end_series = pd.Series(std_dev_till_end, index=self.df.index) |
| 358 | + # Smooth this std dev to avoid it going to zero at end of signal |
| 359 | + std_dev_smoothed = std_dev_till_end_series.rolling( |
| 360 | + window=workflow._final_smoothing_window |
| 361 | + ).mean() |
| 362 | + # Fill initial NaNs with the first valid smoothed std dev value |
| 363 | + std_dev_sm_flld = std_dev_smoothed.bfill() |
| 364 | + |
| 365 | + # create new DataFrame with time and std dev till end of signal |
| 366 | + df_std_dev = pd.DataFrame( |
| 367 | + {"time": self.df["time"], col + "_std_till_end": std_dev_sm_flld} |
| 368 | + ) |
| 369 | + |
| 370 | + # start time of smoothed signal |
| 371 | + smoothed_start_time = df_smoothed["time"].iloc[rolling_window - 1] |
| 372 | + |
| 373 | + # plot smoothed signal and related quantities |
| 374 | + if workflow._verbosity > 1: |
| 375 | + plt.figure(figsize=(10, 6)) |
| 376 | + plt.plot( |
| 377 | + self.df["time"], |
| 378 | + self.df[col], |
| 379 | + label="Original Signal", |
| 380 | + alpha=0.5, |
| 381 | + ) |
| 382 | + plt.plot( |
| 383 | + df_smoothed["time"], |
| 384 | + df_smoothed[col], |
| 385 | + label="Smoothed Signal", |
| 386 | + color="orange", |
| 387 | + ) |
| 388 | + plt.plot( |
| 389 | + df_std_dev["time"], |
| 390 | + df_std_dev[col + "_std_till_end"], |
| 391 | + label="Smoothed Std Dev Till End", |
| 392 | + color="green", |
| 393 | + ) |
| 394 | + plt.axvline( |
| 395 | + x=smoothed_start_time, |
| 396 | + color="g", |
| 397 | + linestyle="--", |
| 398 | + label="First smoothed point", |
| 399 | + ) |
| 400 | + plt.xlabel("Time") |
| 401 | + plt.ylabel(col) |
| 402 | + plt.title("Original and Smoothed Signal") |
| 403 | + plt.legend() |
| 404 | + plt.grid() |
| 405 | + plt.show() |
| 406 | + plt.close() |
| 407 | + |
| 408 | + if workflow._verbosity > 0: |
| 409 | + print("Getting start of SSS based on smoothed signal:") |
| 410 | + |
| 411 | + # Get start of SSS based on where the value of the flux in the smoothed signal |
| 412 | + # is close to the mean of the remaining signal. |
| 413 | + |
| 414 | + # At each location, compute the mean of the remaining smoothed signal |
| 415 | + n_pts_smoothed = len(df_smoothed) |
| 416 | + mean_vals = np.empty((n_pts_smoothed,), dtype=float) |
| 417 | + |
| 418 | + for i in range(n_pts_smoothed): |
| 419 | + mean_vals[i] = np.mean(df_smoothed[col].iloc[i:]) |
| 420 | + |
| 421 | + # Check where the current value of the smoothed signal is within tol_fac of the mean of the remaining signal |
| 422 | + deviation_arr = np.abs(df_smoothed[col] - mean_vals) |
| 423 | + |
| 424 | + # smooth this so the deviation does not go to zero at end of signal by construction |
| 425 | + # turn this into a pandas series with same index as col_smoothed |
| 426 | + deviation_series = pd.Series(deviation_arr, index=self.df.index) |
| 427 | + # Smooth this std dev to avoid it going to zero at end of signal |
| 428 | + deviation_smoothed = deviation_series.rolling( |
| 429 | + window=workflow._final_smoothing_window |
| 430 | + ).mean() |
| 431 | + # Fill initial NaNs with the first valid smoothed std dev value |
| 432 | + deviation_sm_flld = deviation_smoothed.bfill() |
| 433 | + # Build a dataframe for the deviation |
| 434 | + deviation = pd.DataFrame( |
| 435 | + {"time": self.df["time"], col + "_deviation": deviation_sm_flld} |
| 436 | + ) |
| 437 | + |
| 438 | + # Compute tolerance on variation in the mean of the smoothed signal as |
| 439 | + # stdv_frac * (std dev till end + a fudge factor * mean value at start of smoothed signal) |
| 440 | + # fudge factor is for in case there is no noise (and to guard against the tolerance |
| 441 | + # factor going to zero when std dev gets very small at end of signal) |
| 442 | + tol_fac = workflow._std_dev_frac * ( |
| 443 | + df_std_dev[col + "_std_till_end"] + workflow._fudge_fac * abs(mean_vals[0]) |
| 444 | + ) |
| 445 | + tolerance = tol_fac * np.abs(mean_vals) |
| 446 | + |
| 447 | + within_tolerance_all = deviation[col + "_deviation"] <= tolerance |
| 448 | + # Only consider points after the smoothed signal has started |
| 449 | + within_tolerance = within_tolerance_all & ( |
| 450 | + df_smoothed["time"] >= smoothed_start_time |
| 451 | + ) |
| 452 | + # First index where we are within tolerance |
| 453 | + sss_index = np.where(within_tolerance)[0] |
| 454 | + |
| 455 | + # See if there is a segment where ALL remaining points are within tolerance |
| 456 | + crit_met_index = None |
| 457 | + if len(sss_index) > 0: |
| 458 | + # find the segment where ALL remaining points are within tolerance |
| 459 | + for idx in sss_index: |
| 460 | + if np.all(within_tolerance[idx:]): |
| 461 | + crit_met_index = idx |
| 462 | + break |
| 463 | + |
| 464 | + if crit_met_index is not None: # We have a SSS segment |
| 465 | + # Time where criterion has been met |
| 466 | + criterion_time = df_smoothed["time"].iloc[crit_met_index] |
| 467 | + # Take into account that the signal at the point where the criterion has been met is a result |
| 468 | + # of averaging over the rolling window. So set the start of SSS near the start of the rolling window |
| 469 | + # but not all the way at the beginning of the rolling window as there is usually still some transient. |
| 470 | + true_sss_start_index = max( |
| 471 | + 0, |
| 472 | + int( |
| 473 | + crit_met_index |
| 474 | + - workflow._smoothing_window_correction * rolling_window |
| 475 | + ), |
| 476 | + ) |
| 477 | + sss_start_time = df_smoothed["time"].iloc[true_sss_start_index] |
| 478 | + |
| 479 | + if workflow._verbosity > 0: |
| 480 | + print(f"Index where criterion is met: {crit_met_index}") |
| 481 | + print(f"Rolling window: {rolling_window}") |
| 482 | + print(f"time where criterion is met: {criterion_time}") |
| 483 | + print( |
| 484 | + f"time at start of SSS (adjusted for rolling window): {sss_start_time}" |
| 485 | + ) |
| 486 | + |
| 487 | + # Plot deviation and tolerance vs. time |
| 488 | + if workflow._verbosity > 1: |
| 489 | + plt.figure(figsize=(10, 6)) |
| 490 | + plt.plot( |
| 491 | + df_smoothed["time"], |
| 492 | + deviation[col + "_deviation"], |
| 493 | + label="Deviation", |
| 494 | + color="blue", |
| 495 | + ) |
| 496 | + plt.plot( |
| 497 | + df_smoothed["time"], |
| 498 | + tolerance, |
| 499 | + label="Tolerance", |
| 500 | + color="orange", |
| 501 | + ) |
| 502 | + plt.axvline( |
| 503 | + x=criterion_time, |
| 504 | + color="g", |
| 505 | + linestyle="--", |
| 506 | + label="Small Change Criterion Met", |
| 507 | + ) |
| 508 | + plt.axvline( |
| 509 | + x=sss_start_time, color="r", linestyle="--", label="Start SSS" |
| 510 | + ) |
| 511 | + plt.xlabel("Time") |
| 512 | + plt.ylabel("Value") |
| 513 | + plt.title("Deviation and Tolerance vs. Time") |
| 514 | + plt.legend() |
| 515 | + plt.grid() |
| 516 | + plt.show() |
| 517 | + plt.close() |
| 518 | + |
| 519 | + # Trim the original data frame to start at this location minus the smoothing window |
| 520 | + trimmed_df = self.df[self.df["time"] >= sss_start_time] |
| 521 | + # Reset the index so it starts at 0 |
| 522 | + trimmed_df = trimmed_df.reset_index(drop=True) |
| 523 | + # Create new data stream from trimmed data frame |
| 524 | + trimmed_stream = DataStream(trimmed_df) |
| 525 | + |
| 526 | + else: |
| 527 | + if workflow._verbosity > 0: |
| 528 | + print("No SSS found based on behavior of mean of smoothed signal.") |
| 529 | + trimmed_stream = pd.DataFrame( |
| 530 | + columns=["time", "flux"] |
| 531 | + ) # Create empty DataFrame with same columns as original |
| 532 | + |
| 533 | + # Plot deviation and tolerance vs. time |
| 534 | + if workflow._verbosity > 1: |
| 535 | + plt.figure(figsize=(10, 6)) |
| 536 | + plt.plot( |
| 537 | + df_smoothed["time"], |
| 538 | + deviation[col + "_deviation"], |
| 539 | + label="Deviation", |
| 540 | + color="blue", |
| 541 | + ) |
| 542 | + plt.plot( |
| 543 | + df_smoothed["time"], |
| 544 | + tolerance, |
| 545 | + label="Tolerance", |
| 546 | + color="orange", |
| 547 | + ) |
| 548 | + plt.xlabel("Time") |
| 549 | + plt.ylabel("Value") |
| 550 | + plt.title("Deviation and Tolerance vs. Time") |
| 551 | + plt.legend() |
| 552 | + plt.grid() |
| 553 | + plt.show() |
| 554 | + plt.close() |
| 555 | + |
| 556 | + return trimmed_stream |
| 557 | + |
282 | 558 | @staticmethod |
283 | 559 | def find_steady_state_std(data, column_name, window_size=10, robust=True): |
284 | 560 | """ |
|
0 commit comments