Commit 86043a6: WIP Dask Cluster
1 parent f726605

2 files changed: +138 -110 lines


requirements/base.txt (1 addition, 1 deletion)

@@ -1,7 +1,7 @@
 aplpy==2.0.3
 astropy==4.0
 astroquery==0.4
-cloudpickle==1.5.0
+cloudpickle==1.6.0
 dask[complete]==2.15.0
 dill==0.3.1.1
 distributed==2.26.0
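Note: the cloudpickle bump is not cosmetic. Dask and distributed use cloudpickle to serialise task functions, including the lambdas and closures mapped over bags later in this commit, so the pinned version needs to match what the scheduler and workers expect. A minimal sketch of the round-trip Dask performs under the hood (toy function, not from this commit):

import cloudpickle

# stdlib pickle cannot serialise lambdas/closures; cloudpickle can,
# which is what lets Dask ship them to worker processes
scale = 1.e3
to_mjy = lambda jy: jy * scale

restored = cloudpickle.loads(cloudpickle.dumps(to_mjy))
assert restored(0.002) == 2.0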

vast_pipeline/pipeline/new_sources.py (137 additions, 109 deletions)

@@ -3,13 +3,15 @@
 import pandas as pd
 import numpy as np
 import dask.dataframe as dd
+import dask.bag as db
 
 from psutil import cpu_count
 from astropy import units as u
 from astropy.coordinates import SkyCoord
 from astropy.io import fits
 from astropy.wcs import WCS
 from astropy.wcs.utils import skycoord_to_pixel
+from typing import Dict
 
 from vast_pipeline.models import Image
 from vast_pipeline.utils.utils import StopWatch
@@ -18,96 +20,123 @@
 logger = logging.getLogger(__name__)
 
 
-def check_primary_image(row):
+def check_primary_image(row: pd.Series) -> bool:
     return row['primary'] in row['img_list']
 
 
-def get_image_rms_measurements(group):
-    """
-    Take the coordinates provided from the group
-    and measure the array value in the provided image.
-    """
-    image = group.iloc[0]['img_diff_rms_path']
-
+def extract_rms_data_from_img(image: str) -> Dict:
     with fits.open(image) as hdul:
-        wcs = WCS(hdul[0].header, naxis=2)
+        wcs = WCS(hdul[0].header, naxis=2)
+        try:
+            # ASKAP tile images
+            data = hdul[0].data[0, 0, :, :]
+        except Exception as e:
+            # ASKAP SWarp images
+            data = hdul[0].data
+
+    return {'data': data, 'wcs': wcs}
 
-    try:
-        # ASKAP tile images
-        data = hdul[0].data[0, 0, :, :]
-    except Exception as e:
-        # ASKAP SWarp images
-        data = hdul[0].data
 
+def get_coord_array(df: pd.DataFrame) -> SkyCoord:
     coords = SkyCoord(
-        group.wavg_ra, group.wavg_dec, unit=(u.deg, u.deg)
+        df['wavg_ra'].values,
+        df['wavg_dec'].values,
+        unit=(u.deg, u.deg)
     )
 
-    array_coords = wcs.world_to_array_index(coords)
+    return coords
+
+
+def finalise_rms_calcs(rms: Dict, coords: np.array,
+                       df: pd.DataFrame) -> pd.DataFrame:
+    """
+    Take the coordinates provided from the group
+    and measure the array value in the provided image.
+    """
+    array_coords = rms['wcs'].world_to_array_index(coords)
     array_coords = np.array([
         np.array(array_coords[0]),
        np.array(array_coords[1]),
     ])
 
     # check for pixel wrapping
     x_valid = np.logical_or(
-        array_coords[0] >= data.shape[0],
+        array_coords[0] >= rms['data'].shape[0],
         array_coords[0] < 0
     )
 
     y_valid = np.logical_or(
-        array_coords[1] >= data.shape[1],
+        array_coords[1] >= rms['data'].shape[1],
         array_coords[1] < 0
     )
 
-    valid = ~np.logical_or(
-        x_valid, y_valid
-    )
-
-    valid_indexes = group[valid].index.values
+    # calculate the mask for indexes
+    valid = ~np.logical_or(x_valid, y_valid)
 
-    rms_values = data[
+    # create the column data, not matched ones will be NaN.
+    rms_values = np.full(valid.shape, np.NaN)
+    rms_values[valid] = rms['data'][
         array_coords[0][valid],
         array_coords[1][valid]
-    ]
+    ].astype(np.float64) * 1.e3
 
-    # not matched ones will be NaN.
-    group.loc[
-        valid_indexes, 'img_diff_true_rms'
-    ] = rms_values.astype(np.float64) * 1.e3
+    # copy the df and create the rms column
+    df_out = df.copy()  # dask doesn't like to modify inputs in place
+    df_out['img_diff_true_rms'] = rms_values
 
-    return group
+    return df_out
 
 
-def parallel_get_rms_measurements(df):
+def parallel_get_rms_measurements(df: dd.core.DataFrame) -> dd.core.DataFrame:
     """
     Wrapper function to use 'get_image_rms_measurements'
     in parallel with Dask.
     """
+    # Use the Dask bag backend to work on the different image files:
+    # first calculate the unique img_diff paths, then create the bag
+    uniq_img_diff = (
+        df['img_diff_rms_path'].unique()
+        .compute()
+        .tolist()
+    )
+    nr_uniq_img = len(uniq_img_diff)
+    # map the extract function to the bag to get data from images
+    img_data_bags = (
+        db.from_sequence(uniq_img_diff, npartitions=nr_uniq_img)
+        .map(extract_rms_data_from_img)
+    )
 
-    out = df[[
-        'source', 'wavg_ra', 'wavg_dec',
-        'img_diff_rms_path'
-    ]]
-
-    col_dtype = {
-        'source': 'i',
-        'wavg_ra': 'f',
-        'wavg_dec': 'f',
-        'img_diff_rms_path': 'U',
-        'img_diff_true_rms': 'f',
-    }
-
-    n_cpu = cpu_count() - 1
-
-    out = (
-        dd.from_pandas(out, n_cpu)
-        .groupby('img_diff_rms_path')
-        .apply(
-            get_image_rms_measurements,
-            meta=col_dtype
-        ).compute(num_workers=n_cpu, scheduler='processes')
+    # generate bags with dataframes for each unique image_diff
+    cols = ['img_diff_rms_path', 'source', 'wavg_ra', 'wavg_dec']
+    df_bags = []
+    for elem in uniq_img_diff:
+        df_bags.append(df.loc[df['img_diff_rms_path'] == elem, cols])
+    df_bags = dd.compute(*df_bags)
+    df_bags = db.from_sequence(df_bags, npartitions=nr_uniq_img)
+
+    # map the get_coord_array and column selection function
+    arr_coords_bags = df_bags.map(get_coord_array)
+    col_sel_bags = df_bags.map(lambda onedf: onedf[['source']])
+
+    # combine the bags and apply final operations, this will create a list
+    # of pandas dataframes
+    out_bag = (
+        db.zip(img_data_bags, arr_coords_bags, col_sel_bags)
+        .map(lambda tup: finalise_rms_calcs(*tup))
+        # .to_delayed()
+        .persist()
     )
+    import ipdb; ipdb.set_trace()  # breakpoint 3b4d58cc //
+
+    # out = (
+    #     # dd.from_pandas(out, n_cpu)
+    #     df[['source', 'wavg_ra', 'wavg_dec', 'img_diff_rms_path']]
+    #     .groupby('img_diff_rms_path')
+    #     .apply(
+    #         get_image_rms_measurements,
+    #         meta=col_dtype
+    #     )
+    # )
 
     df = df.merge(
         out[['source', 'img_diff_true_rms']],
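Note: the rewrite above replaces the old groupby-apply-per-image pattern with three dask.bag collections, one partition per unique RMS image, zipped element-wise so that each finalise_rms_calcs call receives the matching image data, coordinate array, and source subset. A toy sketch of that zip/map pattern (all names and data below are illustrative, not from the commit):

import dask.bag as db

# three bags with identical partitioning, standing in for the image
# data, coordinate arrays and per-image source frames built above
images = db.from_sequence(['rms_a.fits', 'rms_b.fits'], npartitions=2)
coords = db.from_sequence([(1.0, -30.0), (2.0, -31.0)], npartitions=2)
labels = db.from_sequence(['a', 'b'], npartitions=2)

def combine(image, coord, label):
    # stand-in for finalise_rms_calcs: one result per zipped triple
    return {'image': image, 'coord': coord, 'label': label}

out = db.zip(images, coords, labels).map(lambda tup: combine(*tup))
print(out.compute())
# [{'image': 'rms_a.fits', 'coord': (1.0, -30.0), 'label': 'a'}, ...]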
@@ -118,82 +147,79 @@ def parallel_get_rms_measurements(df):
     return df
 
 
-def new_sources(sources_df, missing_sources_df, min_sigma, p_run):
+def new_sources(sources_df: dd.core.DataFrame,
+    missing_sources_df: dd.core.DataFrame, min_sigma: float, p_run: Run
+) -> dd.core.DataFrame:
     """
     Process the new sources detected to see if they are
     valid.
     """
 
-    timer = StopWatch()
+    # timer = StopWatch()
 
-    logger.info("Starting new source analysis.")
+    logger.info('Starting new source analysis.')
 
     cols = [
-        'id', 'name', 'noise_path', 'datetime',
+        'name', 'noise_path', 'datetime',
         'rms_median', 'rms_min', 'rms_max'
     ]
 
     images_df = pd.DataFrame(list(
-        Image.objects.filter(
-            run=p_run
-        ).values(*tuple(cols))
+        Image.objects.filter(run=p_run)
+        .values(*tuple(cols))
     )).set_index('name')
 
     # Get rid of sources that are not 'new', i.e. sources for which the
     # first sky region image is not in the image list
 
     missing_sources_df['primary'] = missing_sources_df[
         'skyreg_img_list'
-    ].apply(lambda x: x[0])
+    ].apply(lambda x: x[0], meta=str)
 
     missing_sources_df['detection'] = missing_sources_df[
         'img_list'
-    ].apply(lambda x: x[0])
+    ].apply(lambda x: x[0], meta=str)
 
     missing_sources_df['in_primary'] = missing_sources_df[
         ['primary', 'img_list']
-    ].apply(
-        check_primary_image,
-        axis=1
-    )
+    ].apply(check_primary_image, axis=1, meta=bool)
 
     new_sources_df = missing_sources_df[
         missing_sources_df['in_primary'] == False
-    ].drop(
-        columns=['in_primary']
-    )
+    ].drop(columns=['in_primary'])
+    del missing_sources_df
 
     # Check if the previous sources would have actually been seen
     # i.e. are the previous images sensitive enough
 
-    # save the index before exploding
-    new_sources_df = new_sources_df.reset_index()
-
-    # Explode now to avoid two loops below
-    new_sources_df = new_sources_df.explode('img_diff')
-
+    # save the index and explode now to avoid two loops below
     # Merge the respective image information to the df
-    new_sources_df = new_sources_df.merge(
-        images_df[['datetime']],
-        left_on='detection',
-        right_on='name',
-        how='left'
-    ).rename(columns={'datetime':'detection_time'})
-
-    new_sources_df = new_sources_df.merge(
-        images_df[[
-            'datetime', 'rms_min', 'rms_median',
-            'noise_path'
-        ]],
-        left_on='img_diff',
-        right_on='name',
-        how='left'
-    ).rename(columns={
-        'datetime':'img_diff_time',
-        'rms_min': 'img_diff_rms_min',
-        'rms_median': 'img_diff_rms_median',
-        'noise_path': 'img_diff_rms_path'
-    })
+    new_sources_df = (
+        new_sources_df.reset_index()
+        .explode('img_diff')
+        .merge(
+            images_df[['datetime']],
+            left_on='detection',
+            right_on='name',
+            how='left'
+        )
+        .rename(columns={'datetime': 'detection_time'})
+        # )
+        # new_sources_df = (
+        #     new_sources_df.merge(
+        .merge(
+            images_df,
+            left_on='img_diff',
+            right_on='name',
+            how='left'
+        )
+        .rename(columns={
+            'datetime': 'img_diff_time',
+            'rms_min': 'img_diff_rms_min',
+            'rms_median': 'img_diff_rms_median',
+            'noise_path': 'img_diff_rms_path'
+        })
+    )
 
     # Select only those images that come before the detection image
     # in time.
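Note: the meta= arguments added to the .apply() calls above declare the output type up front; without meta, Dask infers it by running the callable on a small dummy sample and emits a warning. A minimal sketch of the idea with toy data (not the pipeline's):

import pandas as pd
import dask.dataframe as dd

ddf = dd.from_pandas(
    pd.DataFrame({'img_list': [['img1', 'img2'], ['img3']]}),
    npartitions=1
)

# meta names the result column and its dtype so the graph can be
# built without sampling; each list is reduced to its first element
first = ddf['img_list'].apply(lambda x: x[0], meta=('img_list', 'object'))
print(first.compute())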
@@ -202,11 +228,15 @@ def new_sources(sources_df, missing_sources_df, min_sigma, p_run):
     ]
 
     # merge the detection fluxes in
-    new_sources_df = pd.merge(
-        new_sources_df, sources_df[['source', 'image', 'flux_peak']],
-        left_on=['source', 'detection'], right_on=['source', 'image'],
-        how='left'
-    ).drop(columns=['image'])
+    new_sources_df = (
+        new_sources_df.merge(
+            sources_df,
+            left_on=['source', 'detection'],
+            right_on=['source', 'image'],
+            how='left'
+        )
+        .drop(columns=['image'])
+    )
 
     # calculate the sigma of the source if it was placed in the
     # minimum rms region of the previous images
@@ -236,9 +266,7 @@ def new_sources(sources_df, missing_sources_df, min_sigma, p_run):
 
     # measure the actual rms in the previous images at
     # the source location.
-    new_sources_df = parallel_get_rms_measurements(
-        new_sources_df
-    )
+    new_sources_df = parallel_get_rms_measurements(new_sources_df)
 
     # this removes those that are out of range
     new_sources_df['img_diff_true_rms'] = new_sources_df['img_diff_true_rms'].fillna(0.)
@@ -262,8 +290,8 @@ def new_sources(sources_df, missing_sources_df, min_sigma, p_run):
         columns={'true_sigma':'new_high_sigma'}
     )
 
-    logger.info(
-        'Total new source analysis time: %.2f seconds', timer.reset_init()
-    )
+    # logger.info(
+    #     'Total new source analysis time: %.2f seconds', timer.reset_init()
+    # )
 
-    return new_sources_df.set_index('source')
+    return new_sources_df.set_index('source').persist()
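Note: the new .persist() on the return value triggers computation but keeps the result as a Dask collection (on a distributed cluster the partitions stay in worker memory), whereas .compute() would pull everything back into a single pandas DataFrame. A toy sketch of the distinction:

import pandas as pd
import dask.dataframe as dd

ddf = dd.from_pandas(
    pd.DataFrame({'source': [1, 2, 3], 'sigma': [5.1, 7.3, 4.2]}),
    npartitions=2
)

kept_lazy = ddf[ddf['sigma'] > 5.0].persist()  # still a dask DataFrame
materialised = kept_lazy.compute()             # now a pandas DataFrame
print(type(kept_lazy), type(materialised))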
