|
4 | 4 | import pandas as pd |
5 | 5 | import pytest |
6 | 6 |
|
7 | | -from quends import DataStream |
| 7 | +from quends import DataStream, RobustWorkflow |
8 | 8 |
|
9 | 9 |
|
10 | 10 | # === Fixtures === |
@@ -428,10 +428,6 @@ def test_cumulative_stats_empty(nan_data): |
428 | 428 | assert ds.cumulative_statistics(window_size=1) == expected |
429 | 429 |
|
430 | 430 |
|
431 | | -# === Additional Data === |
432 | | -import pytest |
433 | | - |
434 | | - |
435 | 431 | def assert_nested_approx(a, b, rel=1e-8): |
436 | 432 | if isinstance(a, dict) and isinstance(b, dict): |
437 | 433 | assert a.keys() == b.keys() |
@@ -656,7 +652,7 @@ def test_find_steady_state_trim_data(trim_data): |
656 | 652 |
|
657 | 653 |
|
def test_find_steady_state_with_start_time(long_data):
    """Placeholder smoke test: only builds a DataStream from ``long_data``.

    Nothing is asserted; the test passes as long as construction does not
    raise.
    TODO: presumably meant to exercise steady-state detection with an
    explicit start time (going by the test name) -- add the call and
    assertions.
    """
    DataStream(long_data)
661 | 657 |
|
662 | 658 |
|
@@ -793,3 +789,103 @@ def test_effective_sample_size_missing_col(long_data): |
793 | 789 | ], |
794 | 790 | } |
795 | 791 | assert result == expected |
| 792 | + |
| 793 | + |
def test_make_stationary_with_stationary_data(stationary_data, workflow):
    """Check the constant-input error path of ``DataStream.make_stationary``.

    For the ``stationary_data`` fixture the test expects the second return
    value to be the "x is constant" error message (a string, not a boolean
    flag) and expects the returned stream to keep every original row.
    """
    ds = DataStream(stationary_data)
    n_pts_orig = len(stationary_data)

    result_ds, stationary = ds.make_stationary("A", n_pts_orig, workflow)

    # The status slot carries the error text rather than a stationarity flag.
    assert (
        stationary == "Error: Invalid input, x is constant"
    ), "Expected an error message for constant input."
    # No points may be dropped on this path.
    assert len(result_ds.df) == n_pts_orig
| 807 | + |
| 808 | + |
@pytest.fixture
def workflow():
    """Build the RobustWorkflow shared by the stationarity tests.

    The constructor arguments and the private attributes are pinned to
    fixed values so every test sees the same configuration.
    """
    robust = RobustWorkflow(
        operate_safe=False,
        verbosity=0,
        smoothing_window_correction=0.3,
    )
    # Private settings are overwritten directly on the instance, in the
    # same order as listed here.
    pinned = {
        "_drop_fraction": 0.2,
        "_n_pts_min": 50,
        "_n_pts_frac_min": 0.2,
    }
    for attr_name, attr_value in pinned.items():
        setattr(robust, attr_name, attr_value)
    return robust
| 823 | + |
| 824 | + |
@pytest.fixture
def stationary_noise_df():
    """Return 300 rows of seeded standard-normal noise.

    Columns: ``time`` (0..299) and ``A`` (draws from N(0, 1)). Intended as
    an already stationary signal.
    """
    n_samples = 300
    # Fixed global seed so the draw is identical on every run.
    np.random.seed(0)
    noise = np.random.normal(0, 1, n_samples)
    time_axis = np.arange(n_samples)
    return pd.DataFrame({"time": time_axis, "A": noise})
| 837 | + |
| 838 | + |
@pytest.fixture
def slope_to_stationary_df():
    """Return a linear ramp followed by seeded normal noise.

    ``A`` is 100 points of ``2 * i`` followed by 400 draws from a normal
    distribution (mean 0, std 5); ``time`` is the running index. This is
    the primary success case: a non-stationary start, then a stationary
    tail.
    """
    # Fixed global seed so the noise segment is reproducible.
    np.random.seed(42)

    ramp = 2 * np.arange(100)
    noise = np.random.normal(0, 5, 400)
    values = np.concatenate((ramp, noise))

    return pd.DataFrame({"time": np.arange(values.size), "A": values})
| 858 | + |
| 859 | + |
@pytest.fixture
def pure_trend_df():
    """Return a noise-free straight line, ``A = 3 * time + 10``, over 300 rows.

    Meant as a purely non-stationary trend that dropping leading points
    cannot fix.
    """
    time_axis = np.arange(300)
    line = 3 * time_axis + 10
    return pd.DataFrame({"time": time_axis, "A": line})
| 872 | + |
| 873 | + |
def test_make_stationary_already_stationary(stationary_noise_df, workflow):
    """Seeded white noise is reported stationary and no rows are dropped."""
    ds = DataStream(stationary_noise_df)
    n_pts_orig = len(ds.df)

    result_ds, stationary = ds.make_stationary("A", n_pts_orig, workflow)

    # Compare against True explicitly: a non-empty error string would be
    # truthy, and calling ``.any()`` on a plain bool or str raises
    # AttributeError instead of producing a readable assertion failure.
    assert np.any(
        stationary == np.True_
    ), f"make_stationary reported {stationary!r}"
    assert len(result_ds.df) == n_pts_orig
| 882 | + |
| 883 | + |
def test_make_stationary_drops_trend(slope_to_stationary_df, workflow):
    """A ramp-then-noise signal ends up flagged stationary and shorter."""
    stream = DataStream(slope_to_stationary_df)
    original_length = len(stream.df)

    trimmed, is_stationary = stream.make_stationary(
        "A", original_length, workflow
    )

    # Explicit comparison with True, so an error string cannot pass as truthy.
    assert is_stationary == np.True_
    # At least one point must have been removed.
    assert len(trimmed.df) < original_length
0 commit comments