Skip to content
100 changes: 98 additions & 2 deletions synthpop/synthesizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
import numpy as np
import pandas as pd
from scipy.stats import chisquare
from tqdm import tqdm
from concurrent.futures import ProcessPoolExecutor

from . import categorizer as cat
from . import draw
Expand Down Expand Up @@ -75,6 +77,7 @@ def synthesize(h_marg, p_marg, h_jd, p_jd, h_pums, p_pums, geography, ignore_max
ignore_max_iters)
logger.info("Time to run ipu: %.3fs" % (time.time()-t1))

logger.info("Time to run ipu: %.3fs" % (time.time() - t1))
logger.debug("IPU weights:")
logger.debug(best_weights.describe())
logger.debug("Fit quality:")
Expand All @@ -91,6 +94,100 @@ def synthesize(h_marg, p_marg, h_jd, p_jd, h_pums, p_pums, geography, ignore_max
num_households, h_pums, p_pums, household_freq, h_constraint,
p_constraint, best_weights, hh_index_start=hh_index_start)

def geog_preprocessing(geog_id, recipe, marginal_zero_sub, jd_zero_sub,
hh_index_start):
h_marg = recipe.get_household_marginal_for_geography(geog_id)
logger.debug("Household marginal")
logger.debug(h_marg)

p_marg = recipe.get_person_marginal_for_geography(geog_id)
logger.debug("Person marginal")
logger.debug(p_marg)

h_pums, h_jd = recipe.\
get_household_joint_dist_for_geography(geog_id)
logger.debug("Household joint distribution")
logger.debug(h_jd)

p_pums, p_jd = recipe.get_person_joint_dist_for_geography(geog_id)
logger.debug("Person joint distribution")
logger.debug(p_jd)

return h_marg, p_marg, h_jd, p_jd, h_pums, p_pums, marginal_zero_sub,\
jd_zero_sub, hh_index_start

def synthesize_all_in_parallel(
recipe, num_geogs=None, indexes=None, marginal_zero_sub=.01,
jd_zero_sub=.001, max_workers=5, hh_index_start=0):
"""
Returns
-------
households, people : pandas.DataFrame
fit_quality : dict of FitQuality
Keys are geographic IDs, values are namedtuples with attributes
``.household_chisq``, ``household_p``, ``people_chisq``,
and ``people_p``.
"""
with ProcessPoolExecutor(max_workers) as ex:
if indexes is None:
indexes = recipe.get_available_geography_ids()

hh_list = []
people_list = []
cnt = 0
fit_quality = {}
geog_synth_args = []
finished_args = []
geog_ids = []
futures = []

print('Submitting function args for parallel processing:')
for i, geog_id in enumerate(indexes):
geog_synth_args.append(ex.submit(
geog_preprocessing, geog_id, recipe, marginal_zero_sub,
jd_zero_sub, hh_index_start))
geog_ids.append(geog_id)
cnt += 1
if num_geogs is not None and cnt >= num_geogs:
break


with ProcessPoolExecutor(max_workers=5) as ex:
futures = [
ex.submit(synthesize, *geog_args.result()) for geog_args in geog_synth_args]

print('Processing results:')
for i, future in tqdm(enumerate(futures), total=len(futures)):
geog_id = geog_ids[i]
print ('Processing results for: ', geog_id)
try:
households, people, people_chisq, people_p = future.result()
except:
raise ValueError('The synthesis failed for geog_id: {}'.format(geog_id))
else:
# Append location identifiers to the synthesized households
for geog_cat in geog_id.keys():
households[geog_cat] = geog_id[geog_cat]

# update the household_ids since we can't do it in the call to
# synthesize when we execute in parallel
households.index += hh_index_start
people.hh_id += hh_index_start

hh_list.append(households)
people_list.append(people)
key = BlockGroupID(
geog_id['state'], geog_id['county'], geog_id['tract'],
geog_id['block group'])
fit_quality[key] = FitQuality(people_chisq, people_p)

if len(households) > 0:
hh_index_start = households.index.values[-1] + 1

all_households = pd.concat(hh_list)
all_persons = pd.concat(people_list, ignore_index=True)

return (all_households, all_persons, fit_quality)

def synthesize_all(recipe, num_geogs=None, indexes=None, ignore_max_iters=False,
marginal_zero_sub=.01, jd_zero_sub=.001):
Expand All @@ -116,8 +213,7 @@ def synthesize_all(recipe, num_geogs=None, indexes=None, ignore_max_iters=False,
fit_quality = {}
hh_index_start = 0

# TODO will parallelization work here?
for geog_id in indexes:
for geog_id in tqdm(indexes, total=num_geogs):
print("Synthesizing geog id:\n", geog_id)

h_marg = recipe.get_household_marginal_for_geography(geog_id)
Expand Down