Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] Getting started on a lightning model #164

Open
wants to merge 28 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
*.pyc
__pycache__
*.pkl

# testing
.pytest_cache/
Expand Down Expand Up @@ -31,6 +32,7 @@ data/analysis/*
data/interim/*
data/features/*
data/models/*
data/*models
Code
scripts/variables*.txt

Expand All @@ -43,3 +45,5 @@ switch_data_dirs.sh
.vscode
# vscode options
settings.json

lightning_logs
3 changes: 3 additions & 0 deletions environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -18,3 +18,6 @@ dependencies:
- pytest-cov
- mypy=0.720
- black
- pip
- pip:
- pytorch-lightning==0.7.1
224 changes: 224 additions & 0 deletions scripts/experiments/17_static_static_data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,224 @@
"""
# from drought TO RUNOFF
mv interim interim_; mv features features_; mv features__ features; mv interim__ interim

# from runoff TO DROUGHT
mv features features__; mv interim interim__; mv interim_ interim ; mv features_ features

# Experiment #7 is VCI3M results
# Experiment #9 is boku_VCI results

mv features 5_features; mv models 5_models;
mv 9_features features; mv 9_models models;
"""

import sys

sys.path.append("../..")

from scripts.utils import _rename_directory, get_data_path
from _base_models import regression, linear_nn, rnn, earnn, persistence, climatology
from src.engineer import Engineer
from pathlib import Path
from typing import Optional, List


def rename_features_dir(data_path: Path) -> None:
    """Archive the current ``features`` directory as ``<n>_features``.

    ``n`` is one greater than the highest existing archive index
    (``0_features``, ``1_features``, ...), or 0 if no archive exists yet.

    Args:
        data_path: root of the project data directory.
    """
    # Indices of previously archived dirs, e.g. "3_features" -> 3.
    # NOTE(review): assumes every "*_features*" dir starts with an integer
    # prefix — a dir like "old_features" would raise ValueError (same as the
    # previous implementation).
    indices = [int(p.name.split("_")[0]) for p in data_path.glob("*_features*")]
    # max(..., default=-1) + 1 gives 0 when no archives exist
    integer = max(indices, default=-1) + 1

    _rename_directory(
        from_path=data_path / "features",
        to_path=data_path / f"{integer}_features",
        with_datetime=False,
    )


def rename_models_dir(data_path: Path) -> None:
    """Archive the current ``models`` directory as ``<n>_models``.

    ``n`` is one greater than the highest existing archive index
    (``0_models``, ``1_models``, ...), or 0 if no archive exists yet.

    Args:
        data_path: root of the project data directory.
    """
    # Indices of previously archived dirs, e.g. "3_models" -> 3.
    # NOTE(review): assumes every "*_models*" dir starts with an integer
    # prefix — a non-numeric prefix would raise ValueError (same as the
    # previous implementation).
    indices = [int(p.name.split("_")[0]) for p in data_path.glob("*_models*")]
    # max(..., default=-1) + 1 gives 0 when no archives exist
    integer = max(indices, default=-1) + 1

    _rename_directory(
        from_path=data_path / "models",
        to_path=data_path / f"{integer}_models",
        with_datetime=False,
    )


def engineer(
    pred_months: int = 3,
    target_var: str = "boku_VCI",
    process_static: bool = False,
    global_means: bool = True,
    log_vars: Optional[List[str]] = None,
) -> None:
    """Run the ``Engineer`` to build train/test arrays for the
    ``one_month_forecast`` experiment (test years 2016-2018).

    Args:
        pred_months: number of months of predictor data per sample; also
            used as the expected sequence length.
        target_var: name of the variable to predict.
        process_static: whether to also engineer the static data.
        global_means: whether to normalise with global (not per-pixel) means.
        log_vars: variables to log-transform. NOTE(review): currently unused —
            either pass it through to ``Engineer.engineer`` or remove it.
    """
    # renamed local from `engineer` so it no longer shadows this function
    eng = Engineer(
        get_data_path(), experiment="one_month_forecast", process_static=process_static
    )
    eng.engineer(
        test_year=list(range(2016, 2019)),
        target_variable=target_var,
        pred_months=pred_months,
        expected_length=pred_months,
        global_means=global_means,
    )


if __name__ == "__main__":
    data_path = get_data_path()

    # ----------------------------------
    # Setup the experiment
    # ----------------------------------
    # Archive any existing features/models dirs (renamed to <n>_features /
    # <n>_models) so this run starts from a clean slate.
    if (data_path / "features").exists():
        rename_features_dir(data_path)
    if (data_path / "models").exists():
        rename_models_dir(data_path)

    # ----------------------------------
    # Run the Experiment
    # ----------------------------------
    # 1. Run the engineer (builds train/test arrays; 2016-2018 are test years)
    target_var = "boku_VCI"  # alternatives: "VCI3M", "boku_VCI"
    pred_months = 3
    engineer(
        pred_months=pred_months,
        target_var=target_var,
        process_static=True,
        global_means=True,
    )

    # NOTE: why have we downloaded 2 variables for ERA5 evaporation?
    # important_vars = ["VCI", "precip", "t2m", "pev", "p0005", "SMsurf", "SMroot"]
    # always_ignore_vars = ["ndvi", "p84.162", "sp", "tp", "Eb", "E", "p0001"]
    # NOTE(review): `important_vars` is never read below — presumably kept for
    # reference only; confirm before deleting.
    important_vars = ["boku_VCI", "precip", "t2m", "pev", "E", "SMsurf"]

    # Variables always excluded from the model inputs.
    # NOTE: if a line is commented out then that variable is INCLUDED in the model
    always_ignore_vars = [
        "VCI",
        "p84.162",
        "sp",
        "tp",
        "Eb",
        "VCI1M",
        "RFE1M",
        "VCI3M",
        # "boku_VCI",
        "modis_ndvi",
        "SMroot",
        "lc_class",  # remove for good clustering (?)
        "lc_class_group",  # remove for good clustering (?)
        "slt",  # remove for good clustering (?)
        "no_data_one_hot",
        "lichens_and_mosses_one_hot",
        "permanent_snow_and_ice_one_hot",
        "urban_areas_one_hot",
        "water_bodies_one_hot",
        "t2m",
        "SMsurf",
        # "pev",
        # "e",
        "E",
    ]

    # Sanity checks: the target must not be ignored, and the *other* candidate
    # target must be ignored so it cannot leak into the features.
    assert target_var not in always_ignore_vars
    other_target = "boku_VCI" if target_var == "VCI3M" else "VCI3M"
    assert other_target in always_ignore_vars

    # -------------
    # Model Parameters (shared by the NN models below)
    # -------------
    num_epochs = 50
    early_stopping = 10  # epochs without improvement before stopping
    hidden_size = 256
    static_size = 64  # static-embedding size (used by EALSTM only)
    # normalize_y = True

    # -------------
    # baseline models
    # -------------
    persistence()
    climatology()

    regression(
        ignore_vars=always_ignore_vars,
        experiment="one_month_forecast",
        include_pred_month=True,
        surrounding_pixels=None,
        explain=False,
    )

    # # gbdt(ignore_vars=always_ignore_vars)
    linear_nn(
        ignore_vars=always_ignore_vars,
        experiment="one_month_forecast",
        include_pred_month=True,
        surrounding_pixels=None,
        explain=False,
        num_epochs=num_epochs,
        early_stopping=early_stopping,
        layer_sizes=[hidden_size],
        include_latlons=True,
        include_yearly_aggs=False,
        clear_nans=True,
    )

    # -------------
    # LSTM
    # -------------
    rnn(
        experiment="one_month_forecast",
        include_pred_month=True,
        surrounding_pixels=None,
        explain=False,
        static="features",
        ignore_vars=always_ignore_vars,
        num_epochs=num_epochs,
        early_stopping=early_stopping,
        hidden_size=hidden_size,
        include_latlons=True,
        include_yearly_aggs=False,
        clear_nans=True,
        weight_observations=False,
    )

    # -------------
    # EALSTM
    # -------------
    earnn(
        experiment="one_month_forecast",
        include_pred_month=True,
        surrounding_pixels=None,
        pretrained=False,
        explain=False,
        static="features",
        ignore_vars=always_ignore_vars,
        num_epochs=num_epochs,
        early_stopping=early_stopping,
        hidden_size=hidden_size,
        static_embedding_size=static_size,
        include_latlons=True,
        include_yearly_aggs=False,
        clear_nans=True,
        weight_observations=False,
        pred_month_static=False,
    )

    # rename the output file
    # NOTE(review): this re-assignment of `data_path` is redundant (set at the
    # top of the block) and only matters if the commented rename below is
    # re-enabled.
    data_path = get_data_path()

    # _rename_directory(
    #     from_path=data_path / "models" / "one_month_forecast",
    #     to_path=data_path / "models" / "one_month_forecast_BASE_static_vars",
    #     with_datetime=True,
    # )
15 changes: 7 additions & 8 deletions scripts/preprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from src.preprocess.admin_boundaries import KenyaAdminPreprocessor

from scripts.utils import get_data_path
from typing import Optional


def process_vci_2018():
Expand Down Expand Up @@ -55,19 +56,16 @@ def process_era5POS_2018():
)


def process_era5_land(variable: str):
if Path(".").absolute().as_posix().split("/")[-1] == "ml_drought":
data_path = Path("data")
else:
data_path = Path("../data")
regrid_path = data_path / "interim/chirps_preprocessed/chirps_kenya.nc"
def process_era5_land(variable: Optional[str] = None):
data_path = get_data_path()
regrid_path = data_path / "interim/VCI_preprocessed/data_kenya.nc"
assert regrid_path.exists(), f"{regrid_path} not available"

processor = ERA5LandPreprocessor(data_path)

processor.preprocess(
subset_str="kenya",
regrid=None,
regrid=regrid_path,
resample_time="M",
upsampling=False,
variable=variable,
Expand Down Expand Up @@ -194,5 +192,6 @@ def preprocess_boku_ndvi():
# preprocess_kenya_boundaries(selection="level_2")
# preprocess_kenya_boundaries(selection="level_3")
# preprocess_era5_hourly()
preprocess_boku_ndvi()
# preprocess_boku_ndvi()
# preprocess_asal_mask()
process_era5_land()
81 changes: 81 additions & 0 deletions scripts/train_lightning.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
import sys

sys.path.append("..")

from src.models import (
load_model,
)
from src.lightning_models import LightningModel

from scripts.utils import get_data_path
from argparse import Namespace
import pytorch_lightning as pl


if __name__ == "__main__":
    # Variables always excluded from the model inputs.
    # NOTE: if a line is commented out then that variable is INCLUDED
    always_ignore_vars = [
        "VCI",
        "p84.162",
        "sp",
        "tp",
        "Eb",
        "VCI1M",
        "RFE1M",
        "VCI3M",
        # "boku_VCI",
        "modis_ndvi",
        "SMroot",
        "lc_class",  # remove for good clustering (?)
        "lc_class_group",  # remove for good clustering (?)
        "slt",  # remove for good clustering (?)
        "no_data_one_hot",
        "lichens_and_mosses_one_hot",
        "permanent_snow_and_ice_one_hot",
        "urban_areas_one_hot",
        "water_bodies_one_hot",
        "t2m",
        "SMsurf",
        # "pev",
        # "e",
        "E",
    ]
    # NOTE(review): target_vars, dynamic_vars and static are currently unused
    # below — presumably placeholders for the lightning data loader; confirm.
    target_vars = ["boku_VCI"]
    dynamic_vars = ["precip", "t2m", "pet", "E", "SMsurf"]
    static = True

    hparams = Namespace(
        **{
            "model_name": "EALSTM",
            "data_path": get_data_path(),
            "experiment": "one_month_forecast",
            "hidden_size": 64,
            "rnn_dropout": 0.3,
            "include_latlons": True,
            "static_embedding_size": 64,
            "include_prev_y": False,
            "include_yearly_aggs": False,
            "batch_size": 1,
            "include_pred_month": True,
            "pred_months": None,
            "ignore_vars": always_ignore_vars,
            "include_monthly_aggs": False,
            "surrounding_pixels": None,
            "predict_delta": False,
            "spatial_mask": None,
            "normalize_y": True,
            "dense_features": [128],
            "val_ratio": 0.3,
            # BUG FIX: was 1e3 (= 1000.0), an implausibly large learning
            # rate; 1e-3 is the intended conventional value.
            "learning_rate": 1e-3,
            "save_preds": True,
            # BUG FIX: "static" appeared twice in this dict ("features", then
            # False) — the duplicate silently overrode the first entry. Kept
            # the effective (last) value. TODO(review): static_embedding_size
            # and `static = True` above suggest this should be "features";
            # confirm intent.
            "static": False,
        }
    )

    model = LightningModel(hparams)
    # fast_dev_run runs a single batch through train/val for a quick smoke
    # test; add gpus=[0] to train on a GPU.
    kwargs = dict(fast_dev_run=True)  # , gpus=[0],
    model.fit(**kwargs)
    model.predict()

    # TODO: add list of static vars that are included to the ModelArrays
    # TODO: get the model running on real data (with gpu)

# TODO: add list of static vars that are included to the ModelArrays
# TODO: get the model running on real data (with gpu)
4 changes: 2 additions & 2 deletions src/exporters/era5_land.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ def _broken_export(
output_paths: List,
show_api_request: bool = True,
n_parallel_requests: int = 1,
pool: Optional[multiprocessing.pool.Pool] = None,
pool: Optional[multiprocessing.Pool] = None,
) -> List:
if n_parallel_requests > 1: # Run in parallel
assert pool is not None, (
Expand Down Expand Up @@ -184,7 +184,7 @@ def export(
if n_parallel_requests < 1:
n_parallel_requests = 1

p: Optional[multiprocessing.pool.Pool]
p: Optional[multiprocessing.Pool]
if n_parallel_requests > 1: # Run in parallel
p = multiprocessing.Pool(int(n_parallel_requests))
else:
Expand Down
6 changes: 6 additions & 0 deletions src/lightning_models/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
from .model_base import LightningModel


# Public API of the ``src.lightning_models`` package.
__all__ = ["LightningModel"]
Loading