|
5 | 5 | from more_itertools import windowed
|
6 | 6 |
|
7 | 7 | from delphi.epidata.server.utils import date_to_time_value, CovidcastRecords
|
8 |
| -from delphi.epidata.server.endpoints.covidcast_utils.smooth_diff import generate_row_diffs, generate_smooth_rows |
| 8 | +from delphi.epidata.server.endpoints.covidcast_utils.smooth_diff import generate_row_diffs, generate_smooth_rows, _smoother |
9 | 9 |
|
10 | 10 |
|
11 | 11 | class TestStreaming:
|
def test__smoother(self):
    """Check `_smoother` against two hand-computed expectations."""
    # Seven ascending values with a seven-element kernel of ones:
    # the expected result is the plain sum of the values.
    ascending = list(range(7))
    unit_kernel = [1] * 7
    assert _smoother(ascending, unit_kernel) == sum(range(7))

    # Six values with a seven-element kernel: the expected result is the
    # arithmetic mean of the values.
    # NOTE(review): presumably a kernel whose length differs from the data
    # is ignored in favour of a plain mean -- confirm against `_smoother`.
    six_ones = [1] * 6
    mismatched_kernel = list(range(7))
    assert _smoother(six_ones, mismatched_kernel) == sum([1] * 6) / 6
| 15 | + |
| 16 | + |
def test_generate_smooth_rows(self):
    """Exercise `generate_smooth_rows` on empty, single-row and ten-day inputs.

    Each scenario converts a dataframe to row dicts, feeds the rows to
    `generate_smooth_rows`, rebuilds a dataframe from the output and compares
    it with a hand-built expectation.
    """

    def smooth(frame, **smoother_args):
        # Round trip: frame -> row dicts -> smoothed rows -> frame.
        rows = frame.to_dict(orient='records')
        return DataFrame.from_records(generate_smooth_rows(rows, **smoother_args))

    def raw_values():
        # A chain can be consumed only once, so build a fresh one per use.
        return chain(range(7), [None, 2., 1.])

    def window_means(values, width):
        # Lazy per-window mean; a window containing None maps to None.
        return (
            None if None in window else sum(window) / len(window)
            for window in windowed(values, width)
        )

    # Empty input is expected to give empty output.
    empty_input = DataFrame({})
    actual = smooth(empty_input)
    expected = DataFrame({})
    assert_frame_equal(actual, expected)

    # A single row is expected to come back with missing value/stderr/sample_size.
    single_row = CovidcastRecords(
        time_values=[20210501],
        values=[1.0]
    ).as_dataframe()
    actual = smooth(single_row)
    expected = CovidcastRecords(
        time_values=[20210501],
        values=[None],
        stderrs=[None],
        sample_sizes=[None]
    ).as_dataframe()
    assert_frame_equal(actual, expected)

    # Ten consecutive days: 0..6, then a missing value, then 2.0 and 1.0.
    ten_days = CovidcastRecords(
        time_values=date_range("2021-05-01", "2021-05-10"),
        values=raw_values()
    ).as_dataframe()

    # Default arguments: windows of 7, missing values left as missing.
    actual = smooth(ten_days)
    expected = CovidcastRecords(
        time_values=date_range("2021-05-07", "2021-05-10"),
        values=window_means(raw_values(), 7),
        stderrs=[None] * 4,
        sample_sizes=[None] * 4,
    ).as_dataframe()
    assert_frame_equal(actual, expected)

    # Windows of 7, missing values filled with 0.
    actual = smooth(ten_days, nan_fill_value=0.)
    expected = CovidcastRecords(
        time_values=date_range("2021-05-07", "2021-05-10"),
        values=window_means(chain(range(7), [0., 2., 1.]), 7),
        stderrs=[None] * 4,
        sample_sizes=[None] * 4,
    ).as_dataframe()
    assert_frame_equal(actual, expected)

    # Explicit window length of 8.
    actual = smooth(ten_days, smoother_window_length=8)
    expected = CovidcastRecords(
        time_values=date_range("2021-05-08", "2021-05-10"),
        values=window_means(raw_values(), 8),
        stderrs=[None] * 3,
        sample_sizes=[None] * 3,
    ).as_dataframe()
    assert_frame_equal(actual, expected)

    # Explicit eight-element kernel with weights 0..7.
    actual = smooth(ten_days, smoother_kernel=list(range(8)))
    expected = CovidcastRecords(
        time_values=date_range("2021-05-08", "2021-05-10"),
        values=(
            None if None in window
            else sum([value * weight for value, weight in zip(window, range(8))]) / len(window)
            for window in windowed(raw_values(), 8)
        ),
        stderrs=[None] * 3,
        sample_sizes=[None] * 3,
    ).as_dataframe()
    assert_frame_equal(actual, expected)

    # Conflicting arguments: a seven-element kernel plus a window length of
    # 10. The expectation is built from windows of 7, i.e. the kernel's length.
    actual = smooth(ten_days, smoother_kernel=[1 / 7.] * 7, smoother_window_length=10)
    expected = CovidcastRecords(
        time_values=date_range("2021-05-07", "2021-05-10"),
        values=(
            None if None in window
            else sum([value * weight for value, weight in zip(window, [1 / 7.] * 7)])
            for window in windowed(raw_values(), 7)
        ),
        stderrs=[None] * 4,
        sample_sizes=[None] * 4,
    ).as_dataframe()
    assert_frame_equal(actual, expected)
12 | 92 |
|
13 | 93 |
|
14 | 94 | def test_generate_row_diffs(self):
|
|
0 commit comments