-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
0 parents
commit 81e3702
Showing
28 changed files
with
5,666 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
## To replicate the experiments in the paper: | ||
|
||
### Step 0: Environment and Prerequisites | ||
Run the following commands to clone this repo and create the Conda environment: | ||
|
||
``` | ||
git clone https://github.com/hzhang0/CovidForecast.git | ||
cd CovidForecast | ||
conda env create -f environment.yml | ||
conda activate covidforecast | ||
``` | ||
|
||
### Step 1: Obtaining the Data | ||
Update timestamps as appropriate in `lib/Constants.py`, then run: | ||
``` | ||
python get_country_data.py | ||
``` |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
name: covidforecast | ||
dependencies: | ||
- python=3.6 | ||
- pip | ||
- pip: | ||
- -r file:requirements.txt |
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
#!/bin/bash | ||
set -e | ||
for i in "" "--use_infections"; do | ||
python evaluate_models.py 2021-03-08 2021-04-03 --forecast_hub_dir /scratch/hdd001/home/haoran/CovidProjections/covid19-forecast-hub --out_dir evaluations/ ${i} | ||
python evaluate_models.py 2021-03-08 2021-04-17 --forecast_hub_dir /scratch/hdd001/home/haoran/CovidProjections/covid19-forecast-hub --out_dir evaluations/ ${i} | ||
python evaluate_models.py 2020-12-28 2021-01-23 --forecast_hub_dir /scratch/hdd001/home/haoran/CovidProjections/covid19-forecast-hub --out_dir evaluations/ ${i} | ||
python evaluate_models.py 2020-12-28 2021-02-06 --forecast_hub_dir /scratch/hdd001/home/haoran/CovidProjections/covid19-forecast-hub --out_dir evaluations/ ${i} | ||
done |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,182 @@ | ||
import numpy as np | ||
from lib import Constants | ||
import pandas as pd | ||
from itertools import product | ||
|
||
def combinations(grid): | ||
return list(dict(zip(grid.keys(), values)) for values in product(*grid.values())) | ||
|
||
def get_script_name(experiment): | ||
if experiment not in globals(): | ||
raise NotImplementedError(experiment) | ||
return globals()[experiment].fname | ||
|
||
def add_common_args(parser): | ||
parser.add_argument('--model', type = str, choices = ['lode', 'gruode', 'gru', 'vgru'], required = True) | ||
parser.add_argument('--data_type', type = str, choices = ['all', 'debug', 'US'], required = True) | ||
parser.add_argument('--randomize_training', action = 'store_true') | ||
parser.add_argument('--concat_cond_ts', action = 'store_true', help = 'concatenate interventions as time series to z instead of statically to z0. \ | ||
Affects lode and vgru.') | ||
return parser | ||
|
||
|
||
#### Experiment Bases | ||
class ExpBase: | ||
def get_hparams(self): | ||
return combinations(self.hparams) | ||
|
||
class LatentODEBase(ExpBase): | ||
fname = "train.py" | ||
def __init__(self): | ||
grid = { | ||
'model': ['lode'], | ||
'latent_dim': [32], | ||
'n_layer': [2, 3], | ||
'n_units': [32, 64], | ||
'feature_set': ['condensed'], | ||
'final_activation': ['none'], | ||
'target_type': ['log'], | ||
'reconstr_weight': [0.0, 1.0], | ||
'ohe_features': [False], | ||
'lr': [1e-3], | ||
'batch_size': [1024], # full batch training | ||
'cond_features': ['all', 'stringency', 'none'], | ||
'smoothed': [True] | ||
} | ||
self.hparams = { **self.hparams, **grid} if 'hparams' in self.__dict__ else grid | ||
|
||
class VGRUBase(ExpBase): | ||
fname = "train.py" | ||
def __init__(self): | ||
grid = { | ||
'model': ['vgru'], | ||
'latent_dim': [8, 32], | ||
'n_layer': [2, 3], | ||
'n_units': [32, 64], | ||
'feature_set': ['condensed'], | ||
'final_activation': ['none'], | ||
'target_type': ['shifted_log'], | ||
'reconstr_weight': [0.0, 1.0], | ||
'ohe_features': [False], | ||
'lr': [1e-3], | ||
'batch_size': [256], | ||
'dropout_p': [0.0, 0.25], | ||
'cond_features': ['stringency'], | ||
'smoothed': [True], | ||
'include_counties': [True], | ||
} | ||
self.hparams = { **self.hparams, **grid} if 'hparams' in self.__dict__ else grid | ||
|
||
class GRUBase(ExpBase): | ||
fname = "train.py" | ||
def __init__(self): | ||
grid = { | ||
'model': ['gru'], | ||
'latent_dim': [8, 16, 32], | ||
'n_units': [32, 64], | ||
'feature_set': ['condensed'], | ||
'final_activation': ['none'], | ||
'target_type': ['shifted_log'], | ||
'reconstr_weight': [0.0], | ||
'ohe_features': [False], | ||
'lr': [1e-3], | ||
'batch_size': [256], | ||
'cond_features': ['stringency'], | ||
'smoothed': [True], | ||
'include_counties': [True], | ||
} | ||
self.hparams = { **self.hparams, **grid} if 'hparams' in self.__dict__ else grid | ||
|
||
|
||
class GRUODEBase(ExpBase): | ||
fname = "train.py" | ||
def __init__(self): | ||
grid = { | ||
'model': ['gruode'], | ||
'latent_dim': [8, 32], | ||
'n_layer': [2, 3], | ||
'n_units': [32, 64], | ||
'feature_set': ['condensed'], | ||
'final_activation': ['none'], | ||
'target_type': ['shifted_log'], | ||
'reconstr_weight': [0.0], | ||
'ohe_features': [False], | ||
'lr': [1e-3], | ||
'batch_size': [256], | ||
'cond_features': ['stringency'], | ||
'smoothed': [True], | ||
'include_counties': [True], | ||
} | ||
self.hparams = { **self.hparams, **grid} if 'hparams' in self.__dict__ else grid | ||
|
||
class WeekBase(ExpBase): | ||
def __init__(self, ndays, anchor_date): | ||
grid = { | ||
'date_cutoff': [str(pd.Timestamp(anchor_date))], | ||
'n_val_days': [ndays] | ||
} | ||
self.hparams = { **self.hparams, **grid} if 'hparams' in self.__dict__ else grid | ||
|
||
#### write experiments here | ||
def get_exp_name(args): | ||
return args.model + '_' + str(args.anchor_date) + '_' + str(args.n_weeks_ahead) + '_' + str(args.data_type) + ('_randomize' if args.randomize_training else '') \ | ||
+ ('_cond_ts' if args.concat_cond_ts else '') | ||
|
||
def get_hparams(args): | ||
model_bases = { | ||
'vgru': VGRUBase, | ||
'lode': LatentODEBase, | ||
'gru': GRUBase, | ||
'gruode': GRUODEBase | ||
} | ||
model_base_class = model_bases[args.model] | ||
|
||
class Experiment(model_base_class, WeekBase): | ||
def __init__(self): | ||
model_base_class.__init__(self) | ||
WeekBase.__init__(self, args.n_weeks_ahead * 7, args.anchor_date) | ||
|
||
grid = { | ||
'experiment_name': [get_exp_name(args)], | ||
'data_type': [args.data_type], | ||
'randomize_training': [args.randomize_training], | ||
'concat_cond_ts': [args.concat_cond_ts] | ||
} | ||
self.hparams = { **self.hparams, **grid} if 'hparams' in self.__dict__ else grid | ||
|
||
experiment = Experiment() | ||
return experiment.get_hparams() | ||
|
||
|
||
|
||
# class LatentODEBasicFourWeek(LatentODEBase, FourWeek): | ||
# def __init__(self): | ||
# self.hparams = { | ||
# 'data_type': ['US'], | ||
# 'smoothed': [True], | ||
# } | ||
# LatentODEBase.__init__(self) | ||
# FourWeek.__init__(self) | ||
|
||
# class RandomizedCondIWLatentODEFourWeek(LatentODEBase, FourWeek): | ||
# def __init__(self): | ||
# self.hparams = { | ||
# 'data_type': ['US'], | ||
# 'smoothed': [True], | ||
# 'noise_std': [1, 0.1], | ||
# 'cond_inds': ['[-1]'], | ||
# 'elbo_type': ['iwae'], | ||
# 'randomize_training': [True] | ||
# } | ||
# LatentODEBase.__init__(self) | ||
# FourWeek.__init__(self) | ||
|
||
# class LatentODEVAEHyperGrid(LatentODEBase, FourWeek): | ||
# def __init__(self): | ||
# self.hparams = { | ||
# 'data_type': ['US'], | ||
# 'smoothed': [True], | ||
# 'n_train_trajectories': [10, 25, 50], | ||
# } | ||
# LatentODEBase.__init__(self) | ||
# FourWeek.__init__(self) |
Oops, something went wrong.