import pandas as pd
import numpy as np
import dask.dataframe as dd
+ import dask.bag as db

from psutil import cpu_count
from astropy import units as u
from astropy.coordinates import SkyCoord
from astropy.io import fits
from astropy.wcs import WCS
from astropy.wcs.utils import skycoord_to_pixel
+ from typing import Dict

- from vast_pipeline.models import Image
+ from vast_pipeline.models import Image, Run
from vast_pipeline.utils.utils import StopWatch

logger = logging.getLogger(__name__)

- def check_primary_image(row):
+ def check_primary_image(row: pd.Series) -> bool:
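+     """Return True if the source's primary image is in its image list."""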
    return row['primary'] in row['img_list']


- def get_image_rms_measurements(group):
-     """
-     Take the coordinates provided from the group
-     and measure the array value in the provided image.
-     """
-     image = group.iloc[0]['img_diff_rms_path']
-
+ def extract_rms_data_from_img(image: str) -> Dict:
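+     """
+     Open the FITS image and return a dict with its 2D data array and WCS.
+     """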
    with fits.open(image) as hdul:
-         wcs = WCS(hdul[0].header, naxis=2)
+         wcs = WCS(hdul[0].header, naxis=2)
+         try:
+             # ASKAP tile images: strip the degenerate Stokes and frequency axes
+             data = hdul[0].data[0, 0, :, :]
+         except Exception:
+             # ASKAP SWarp images: the data is already 2D
+             data = hdul[0].data
+
+     return {'data': data, 'wcs': wcs}

-     try:
-         # ASKAP tile images
-         data = hdul[0].data[0, 0, :, :]
-     except Exception as e:
-         # ASKAP SWarp images
-         data = hdul[0].data
-

+ def get_coord_array(df: pd.DataFrame) -> SkyCoord:
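+     """
+     Build a SkyCoord array from the weighted-average source positions.
+     """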
    coords = SkyCoord(
-         group.wavg_ra, group.wavg_dec, unit=(u.deg, u.deg)
+         df['wavg_ra'].values,
+         df['wavg_dec'].values,
+         unit=(u.deg, u.deg)
    )

-     array_coords = wcs.world_to_array_index(coords)
+     return coords
+
+
+ def finalise_rms_calcs(rms: Dict, coords: SkyCoord,
+                        df: pd.DataFrame) -> pd.DataFrame:
+     """
+     Measure the rms image values at the provided coordinates and
+     attach them to the dataframe as a new column.
+     """
+     array_coords = rms['wcs'].world_to_array_index(coords)
    array_coords = np.array([
        np.array(array_coords[0]),
        np.array(array_coords[1]),
    ])

    # check for pixel wrapping
    x_valid = np.logical_or(
-         array_coords[0] >= data.shape[0],
+         array_coords[0] >= rms['data'].shape[0],
        array_coords[0] < 0
    )

    y_valid = np.logical_or(
-         array_coords[1] >= data.shape[1],
+         array_coords[1] >= rms['data'].shape[1],
        array_coords[1] < 0
    )

-     valid = ~np.logical_or(
-         x_valid, y_valid
-     )
-
-     valid_indexes = group[valid].index.values
+     # calculate the mask of valid (in-image) indices
+     valid = ~np.logical_or(x_valid, y_valid)

-     rms_values = data[
+     # create the column data; unmatched sources will be NaN
+     rms_values = np.full(valid.shape, np.NaN)
+     rms_values[valid] = rms['data'][
        array_coords[0][valid],
        array_coords[1][valid]
-     ]
+     ].astype(np.float64) * 1.e3
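+     # the 1.e3 factor presumably scales Jy/beam to mJy/beam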

-     # not matched ones will be NaN.
-     group.loc[
-         valid_indexes, 'img_diff_true_rms'
-     ] = rms_values.astype(np.float64) * 1.e3
+     # copy the df and create the rms column
+     df_out = df.copy()  # dask doesn't like inputs being modified in place
+     df_out['img_diff_true_rms'] = rms_values

-     return group
+     return df_out


- def parallel_get_rms_measurements(df):
+ def parallel_get_rms_measurements(df: dd.core.DataFrame) -> dd.core.DataFrame:
    """
-     Wrapper function to use 'get_image_rms_measurements'
-     in parallel with Dask.
+     Wrapper function to measure the true rms at the source
+     locations in parallel with Dask.
    """
+     # Use the Dask bag backend to work on the different image files in
+     # parallel: first compute the unique image_diff paths, then create the bag
+     uniq_img_diff = (
+         df['img_diff_rms_path'].unique()
+         .compute()
+         .tolist()
+     )
+     nr_uniq_img = len(uniq_img_diff)
+     # map the extract function over the bag to load the data from each image
+     img_data_bags = (
+         db.from_sequence(uniq_img_diff, npartitions=nr_uniq_img)
+         .map(extract_rms_data_from_img)
+     )
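+     # one partition per image, so each FITS file is read by a single task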

-     out = df[[
-         'source', 'wavg_ra', 'wavg_dec',
-         'img_diff_rms_path'
-     ]]
-
-     col_dtype = {
-         'source': 'i',
-         'wavg_ra': 'f',
-         'wavg_dec': 'f',
-         'img_diff_rms_path': 'U',
-         'img_diff_true_rms': 'f',
-     }
-
-     n_cpu = cpu_count() - 1
-
-     out = (
-         dd.from_pandas(out, n_cpu)
-         .groupby('img_diff_rms_path')
-         .apply(
-             get_image_rms_measurements,
-             meta=col_dtype
-         ).compute(num_workers=n_cpu, scheduler='processes')
+     # generate a bag of dataframes, one for each unique image_diff
+     cols = ['img_diff_rms_path', 'source', 'wavg_ra', 'wavg_dec']
+     df_bags = []
+     for elem in uniq_img_diff:
+         df_bags.append(df.loc[df['img_diff_rms_path'] == elem, cols])
+     df_bags = dd.compute(*df_bags)
+     df_bags = db.from_sequence(df_bags, npartitions=nr_uniq_img)
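+     # df_bags follows the ordering and partitioning of img_data_bags, so
+     # db.zip below pairs each image with its matching sources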
+
+     # map the coordinate-array and the column selection functions
+     arr_coords_bags = df_bags.map(get_coord_array)
+     col_sel_bags = df_bags.map(lambda onedf: onedf[['source']])
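+     # only the source id column is needed to attach the rms values later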
+
+     # combine the bags and apply the final operations; this produces a
+     # list of pandas dataframes
+     out_bag = (
+         db.zip(img_data_bags, arr_coords_bags, col_sel_bags)
+         .map(lambda tup: finalise_rms_calcs(*tup))
+         .persist()
    )
+
+     # gather the per-image results into a single dataframe for the merge
+     out = pd.concat(out_bag.compute())

    df = df.merge(
        out[['source', 'img_diff_true_rms']],
@@ -118,82 +147,79 @@ def parallel_get_rms_measurements(df):
    return df


- def new_sources(sources_df, missing_sources_df, min_sigma, p_run):
+ def new_sources(sources_df: dd.core.DataFrame,
+                 missing_sources_df: dd.core.DataFrame, min_sigma: float,
+                 p_run: Run) -> dd.core.DataFrame:
    """
    Process the new sources detected to see if they are
    valid.
    """

    timer = StopWatch()

-     logger.info("Starting new source analysis.")
+     logger.info('Starting new source analysis.')

    cols = [
-         'id', 'name', 'noise_path', 'datetime',
+         'name', 'noise_path', 'datetime',
        'rms_median', 'rms_min', 'rms_max'
    ]

    images_df = pd.DataFrame(list(
-         Image.objects.filter(
-             run=p_run
-         ).values(*tuple(cols))
+         Image.objects.filter(run=p_run)
+         .values(*tuple(cols))
    )).set_index('name')

    # Get rid of sources that are not 'new', i.e. sources for which the
    # first sky region image is not in the image list

    missing_sources_df['primary'] = missing_sources_df[
        'skyreg_img_list'
-     ].apply(lambda x: x[0])
+     ].apply(lambda x: x[0], meta=str)

    missing_sources_df['detection'] = missing_sources_df[
        'img_list'
-     ].apply(lambda x: x[0])
+     ].apply(lambda x: x[0], meta=str)

    missing_sources_df['in_primary'] = missing_sources_df[
        ['primary', 'img_list']
-     ].apply(
-         check_primary_image,
-         axis=1
-     )
+     ].apply(check_primary_image, axis=1, meta=bool)

    new_sources_df = missing_sources_df[
        missing_sources_df['in_primary'] == False
-     ].drop(
-         columns=['in_primary']
-     )
+     ].drop(columns=['in_primary'])
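+     # drop the reference so the intermediate dataframe can be garbage collected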
+     del missing_sources_df

    # Check if the previous sources would have actually been seen
    # i.e. are the previous images sensitive enough

-     # save the index before exploding
-     new_sources_df = new_sources_df.reset_index()
-
-     # Explode now to avoid two loops below
-     new_sources_df = new_sources_df.explode('img_diff')
-
+     # save the index and explode now to avoid two loops below
    # Merge the respective image information to the df
-     new_sources_df = new_sources_df.merge(
-         images_df[['datetime']],
-         left_on='detection',
-         right_on='name',
-         how='left'
-     ).rename(columns={'datetime': 'detection_time'})
-
-     new_sources_df = new_sources_df.merge(
-         images_df[[
-             'datetime', 'rms_min', 'rms_median',
-             'noise_path'
-         ]],
-         left_on='img_diff',
-         right_on='name',
-         how='left'
-     ).rename(columns={
-         'datetime': 'img_diff_time',
-         'rms_min': 'img_diff_rms_min',
-         'rms_median': 'img_diff_rms_median',
-         'noise_path': 'img_diff_rms_path'
-     })
+     new_sources_df = (
+         new_sources_df.reset_index()
+         .explode('img_diff')
+         .merge(
+             images_df[['datetime']],
+             left_on='detection',
+             right_on='name',
+             how='left'
+         )
+         .rename(columns={'datetime': 'detection_time'})
+         .merge(
+             images_df,
+             left_on='img_diff',
+             right_on='name',
+             how='left'
+         )
+         .rename(columns={
+             'datetime': 'img_diff_time',
+             'rms_min': 'img_diff_rms_min',
+             'rms_median': 'img_diff_rms_median',
+             'noise_path': 'img_diff_rms_path'
+         })
+     )

    # Select only those images that come before the detection image
    # in time.
@@ -202,11 +228,15 @@ def new_sources(sources_df, missing_sources_df, min_sigma, p_run):
    ]

    # merge the detection fluxes in
-     new_sources_df = pd.merge(
-         new_sources_df, sources_df[['source', 'image', 'flux_peak']],
-         left_on=['source', 'detection'], right_on=['source', 'image'],
-         how='left'
-     ).drop(columns=['image'])
+     new_sources_df = (
+         new_sources_df.merge(
+             sources_df,
+             left_on=['source', 'detection'],
+             right_on=['source', 'image'],
+             how='left'
+         )
+         .drop(columns=['image'])
+     )

    # calculate the sigma of the source if it was placed in the
    # minimum rms region of the previous images
@@ -236,9 +266,7 @@ def new_sources(sources_df, missing_sources_df, min_sigma, p_run):

    # measure the actual rms in the previous images at
    # the source location.
-     new_sources_df = parallel_get_rms_measurements(
-         new_sources_df
-     )
+     new_sources_df = parallel_get_rms_measurements(new_sources_df)

    # this removes those that are out of range
    new_sources_df['img_diff_true_rms'] = new_sources_df['img_diff_true_rms'].fillna(0.)
@@ -262,8 +290,8 @@ def new_sources(sources_df, missing_sources_df, min_sigma, p_run):
        columns={'true_sigma': 'new_high_sigma'}
    )

    logger.info(
        'Total new source analysis time: %.2f seconds', timer.reset_init()
    )

-     return new_sources_df.set_index('source')
+     return new_sources_df.set_index('source').persist()