Skip to content

Commit 63c71e6

Browse files
committed
successfully create sources models column
1 parent 44a82f3 commit 63c71e6

File tree

2 files changed

+33
-20
lines changed

2 files changed

+33
-20
lines changed

vast_pipeline/pipeline/finalise.py

Lines changed: 32 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,15 @@
11
import os
22
import logging
33
import pandas as pd
4+
import dask.array as da
5+
import dask.bag as db
6+
import dask.dataframe as dd
47

8+
from typing import List
59
from astropy import units as u
610
from astropy.coordinates import SkyCoord
711

8-
from vast_pipeline.models import Association, RelatedSource
12+
from vast_pipeline.models import Association, RelatedSource, Run
913
from vast_pipeline.utils.utils import StopWatch
1014

1115
from .loading import (
@@ -18,58 +22,67 @@
1822

1923

2024
def final_operations(
21-
sources_df, first_img, p_run, meas_dj_obj, new_sources_df):
25+
sources_df: dd.DataFrame, first_img: str, p_run: Run, meas_dj_obj: List,
26+
new_sources_df: dd.DataFrame) -> int:
2227
timer = StopWatch()
2328

2429
# calculate source fields
2530
logger.info(
26-
'Calculating statistics for %i sources...',
27-
sources_df.source.unique().shape[0]
31+
'Calculating statistics for sources...'
32+
# 'Calculating statistics for %i sources...',
33+
# sources_df.source.unique().shape[0]
2834
)
2935

3036
timer.reset()
3137
srcs_df = parallel_groupby(sources_df)
32-
logger.info('Groupby-apply time: %.2f seconds', timer.reset())
38+
# logger.info('Groupby-apply time: %.2f seconds', timer.reset())
3339
# fill NaNs as resulted from calculated metrics with 0
3440
srcs_df = srcs_df.fillna(0.)
3541

3642
# add new sources
37-
srcs_df['new'] = srcs_df.index.isin(new_sources_df.index)
43+
srcs_df['new'] = srcs_df.index.isin(
44+
new_sources_df.index.values.compute()
45+
)
3846

39-
srcs_df = pd.merge(
40-
srcs_df,
41-
new_sources_df['new_high_sigma'],
42-
left_on='source', right_index=True, how='left'
47+
srcs_df = srcs_df.merge(
48+
new_sources_df[['new_high_sigma']],
49+
left_index=True, right_index=True, how='left'
4350
)
4451

4552
srcs_df['new_high_sigma'] = srcs_df['new_high_sigma'].fillna(0.)
4653

4754
# calculate nearest neighbour
48-
srcs_skycoord = SkyCoord(
49-
srcs_df['wavg_ra'],
50-
srcs_df['wavg_dec'],
51-
unit=(u.deg, u.deg)
52-
)
55+
ra, dec = dd.compute(srcs_df['wavg_ra'], srcs_df['wavg_dec'])
56+
srcs_skycoord = SkyCoord(ra, dec, unit=(u.deg, u.deg))
57+
del ra, dec
5358

5459
idx, d2d, _ = srcs_skycoord.match_to_catalog_sky(
5560
srcs_skycoord,
5661
nthneighbor=2
5762
)
58-
59-
srcs_df['n_neighbour_dist'] = d2d.deg
63+
# Dask doesn't like assignment of a column from an array; it needs to be
64+
# a Dask array with partitions, chunks in this case
65+
# TODO: check if using shape will be faster than len (FYI each partition is
66+
# a pandas dataframe)
67+
arr_chuncks = tuple(srcs_df.map_partitions(len).compute())
68+
srcs_df['n_neighbour_dist'] = da.from_array(d2d.deg, chunks=arr_chuncks)
69+
del arr_chuncks, idx, d2d, srcs_skycoord
6070

6171
# generate the source models
6272
srcs_df['src_dj'] = srcs_df.apply(
6373
get_source_models,
6474
pipeline_run=p_run,
65-
axis=1
75+
axis=1,
76+
meta=object
6677
)
78+
import ipdb; ipdb.set_trace() # breakpoint 0eaeae7c //
6779
# upload sources and related to DB
6880
upload_sources(p_run, srcs_df)
6981

7082
# get db ids for sources
71-
srcs_df['id'] = srcs_df['src_dj'].apply(lambda x: x.id)
83+
srcs_df['id'] = srcs_df['src_dj'].apply(lambda x: x.id, meta=int)
7284

85+
import ipdb; ipdb.set_trace() # breakpoint bcf8f142 //
7386
# gather the related df, upload to db and save to parquet file
7487
# the df will look like
7588
#

vast_pipeline/pipeline/new_sources.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -279,8 +279,8 @@ def new_sources(sources_df: dd.core.DataFrame,
279279
new_sources_df.set_index('true_sigma')
280280
.map_partitions(lambda x: x.sort_index())
281281
.drop_duplicates('source')
282-
.rename(columns={'true_sigma':'new_high_sigma'})
283282
.reset_index()
283+
.rename(columns={'true_sigma': 'new_high_sigma'})
284284
.set_index('source')
285285
)
286286

0 commit comments

Comments
 (0)