1 | 1 | import os
2 | 2 | import logging
3 | 3 | import pandas as pd
| 4 | +import dask.array as da
| 5 | +import dask.bag as db
| 6 | +import dask.dataframe as dd
4 | 7
| 8 | +from typing import List
5 | 9 | from astropy import units as u
6 | 10 | from astropy.coordinates import SkyCoord
7 | 11
8 | | -from vast_pipeline.models import Association, RelatedSource
| 12 | +from vast_pipeline.models import Association, RelatedSource, Run
9 | 13 | from vast_pipeline.utils.utils import StopWatch
10 | 14
11 | 15 | from .loading import ( |
18 | 22 |
19 | 23 |
20 | 24 | def final_operations( |
21 | | - sources_df, first_img, p_run, meas_dj_obj, new_sources_df): |
| 25 | + sources_df: dd.DataFrame, first_img: str, p_run: Run, meas_dj_obj: List, |
| 26 | + new_sources_df: dd.DataFrame) -> int: |
22 | 27 | timer = StopWatch() |
23 | 28 |
24 | 29 | # calculate source fields |
25 | 30 | logger.info( |
26 | | - 'Calculating statistics for %i sources...', |
27 | | - sources_df.source.unique().shape[0] |
| 31 | + 'Calculating statistics for sources...' |
| 32 | + # 'Calculating statistics for %i sources...', |
| 33 | + # sources_df.source.unique().shape[0] |
28 | 34 | ) |
29 | 35 |
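The unique-source count was taken out of the log message above because `sources_df.source.unique().shape[0]` is lazy on a Dask dataframe and only yields a number after a computation. A minimal toy sketch (made-up data, assuming the pipeline's `source` id column) of how the count could still be logged with an explicit compute:

    import pandas as pd
    import dask.dataframe as dd

    # toy stand-in for sources_df with a 'source' id column
    sources_df = dd.from_pandas(
        pd.DataFrame({'source': [1, 1, 2, 3, 3, 3]}), npartitions=2
    )

    # nunique() builds a lazy scalar; .compute() materialises it in one pass over the data
    n_src = sources_df['source'].nunique().compute()
    print('Calculating statistics for %i sources...' % n_src)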
30 | 36 | timer.reset() |
31 | 37 | srcs_df = parallel_groupby(sources_df) |
32 | | - logger.info('Groupby-apply time: %.2f seconds', timer.reset()) |
| 38 | + # logger.info('Groupby-apply time: %.2f seconds', timer.reset()) |
33 | 39 | # fill NaNs as resulted from calculated metrics with 0 |
34 | 40 | srcs_df = srcs_df.fillna(0.) |
35 | 41 |
36 | 42 | # add new sources |
37 | | - srcs_df['new'] = srcs_df.index.isin(new_sources_df.index) |
| 43 | + srcs_df['new'] = srcs_df.index.isin( |
| 44 | + new_sources_df.index.values.compute() |
| 45 | + ) |
38 | 46 |
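`isin` needs concrete values rather than another lazy collection, which is why the new-source index is materialised with `.compute()` before the membership test above. A self-contained sketch of the same pattern on made-up data (assuming a recent Dask where the boolean result can be assigned back as a column):

    import pandas as pd
    import dask.dataframe as dd

    srcs_df = dd.from_pandas(
        pd.DataFrame({'wavg_ra': [10.0, 10.1, 10.2]}, index=[1, 2, 3]),
        npartitions=2,
    )
    new_sources_df = dd.from_pandas(
        pd.DataFrame({'new_high_sigma': [5.5]}, index=[2]), npartitions=1
    )

    # materialise the new-source ids into a NumPy array before calling isin()
    new_ids = new_sources_df.index.values.compute()
    srcs_df['new'] = srcs_df.index.isin(new_ids)
    print(srcs_df.compute())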
39 | | - srcs_df = pd.merge( |
40 | | - srcs_df, |
41 | | - new_sources_df['new_high_sigma'], |
42 | | - left_on='source', right_index=True, how='left' |
| 47 | + srcs_df = srcs_df.merge( |
| 48 | + new_sources_df[['new_high_sigma']], |
| 49 | + left_index=True, right_index=True, how='left' |
43 | 50 | ) |
44 | 51 |
45 | 52 | srcs_df['new_high_sigma'] = srcs_df['new_high_sigma'].fillna(0.) |
46 | 53 |
47 | 54 | # calculate nearest neighbour |
48 | | - srcs_skycoord = SkyCoord( |
49 | | - srcs_df['wavg_ra'], |
50 | | - srcs_df['wavg_dec'], |
51 | | - unit=(u.deg, u.deg) |
52 | | - ) |
| 55 | + ra, dec = dd.compute(srcs_df['wavg_ra'], srcs_df['wavg_dec']) |
| 56 | + srcs_skycoord = SkyCoord(ra, dec, unit=(u.deg, u.deg)) |
| 57 | + del ra, dec |
53 | 58 |
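Passing both columns to a single `dd.compute` call evaluates them in one shared pass over the task graph instead of two separate `.compute()` calls, and hands plain pandas Series to astropy. A toy sketch with made-up coordinates:

    import pandas as pd
    import dask.dataframe as dd
    from astropy import units as u
    from astropy.coordinates import SkyCoord

    df = dd.from_pandas(
        pd.DataFrame({'wavg_ra': [10.0, 10.1], 'wavg_dec': [-45.0, -45.2]}),
        npartitions=2,
    )

    # one shared evaluation for both columns
    ra, dec = dd.compute(df['wavg_ra'], df['wavg_dec'])
    coords = SkyCoord(ra, dec, unit=(u.deg, u.deg))
    print(coords)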
54 | 59 | idx, d2d, _ = srcs_skycoord.match_to_catalog_sky( |
55 | 60 | srcs_skycoord, |
56 | 61 | nthneighbor=2 |
57 | 62 | ) |
58 | | - |
59 | | - srcs_df['n_neighbour_dist'] = d2d.deg |
| 63 | + # Dask doesn't accept assigning a column from a plain NumPy array; it
| 64 | + # needs to be a Dask array whose chunks match the dataframe partitions
| 65 | + # TODO: check whether using shape is faster than len (FYI each partition
| 66 | + # is a pandas dataframe)
| 67 | + arr_chunks = tuple(srcs_df.map_partitions(len).compute())
| 68 | + srcs_df['n_neighbour_dist'] = da.from_array(d2d.deg, chunks=arr_chunks)
| 69 | + del arr_chunks, idx, d2d, srcs_skycoord
60 | 70 |
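As the comment above notes, a plain NumPy result (here the nearest-neighbour separations from astropy) cannot be assigned to a Dask dataframe column directly; it has to be wrapped as a Dask array whose chunk sizes line up with the dataframe partition lengths. A self-contained sketch of the pattern on made-up values, assuming a Dask version that accepts a 1-D Dask array with matching chunks in column assignment:

    import numpy as np
    import pandas as pd
    import dask.array as da
    import dask.dataframe as dd

    df = dd.from_pandas(pd.DataFrame({'wavg_ra': np.arange(5.0)}), npartitions=2)

    distances = np.linspace(0.1, 0.5, 5)              # plain NumPy array, like d2d.deg
    sizes = tuple(df.map_partitions(len).compute())   # partition lengths, e.g. (3, 2)

    # chunk the array so each block matches one dataframe partition
    df['n_neighbour_dist'] = da.from_array(distances, chunks=(sizes,))
    print(df.compute())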
61 | 71 | # generate the source models |
62 | 72 | srcs_df['src_dj'] = srcs_df.apply( |
63 | 73 | get_source_models, |
64 | 74 | pipeline_run=p_run, |
65 | | - axis=1 |
| 75 | + axis=1, |
| 76 | + meta=object |
66 | 77 | ) |
67 | 79 | # upload sources and related to DB |
68 | 80 | upload_sources(p_run, srcs_df) |
69 | 81 |
70 | 82 | # get db ids for sources |
71 | | - srcs_df['id'] = srcs_df['src_dj'].apply(lambda x: x.id) |
| 83 | + srcs_df['id'] = srcs_df['src_dj'].apply(lambda x: x.id, meta=int) |
72 | 84 |
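Dask cannot cheaply infer what an arbitrary row-wise function returns without running it, so both `apply` calls above pass `meta` to declare the expected output (a Series of Django model instances and a Series of integer ids respectively). A toy sketch of the idea with a made-up function and column names:

    import pandas as pd
    import dask.dataframe as dd

    df = dd.from_pandas(
        pd.DataFrame({'a': [1, 2, 3], 'b': [4.0, 5.0, 6.0]}), npartitions=2
    )

    # meta declares the output: a Series named 'summary' holding Python objects
    summary = df.apply(
        lambda row: {'total': row['a'] + row['b']},
        axis=1,
        meta=('summary', 'object'),
    )
    print(summary.compute())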
73 | 86 | # gather the related df, upload to db and save to parquet file |
74 | 87 | # the df will look like |
75 | 88 | # |