Skip to content

Commit

Permalink
analysis done
Browse files Browse the repository at this point in the history
  • Loading branch information
saminens committed May 11, 2020
1 parent 9f9eeee commit b9ca4be
Show file tree
Hide file tree
Showing 3 changed files with 209 additions and 24 deletions.
Empty file added src/__init__.py
Empty file.
106 changes: 106 additions & 0 deletions src/curve_fit.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
from scipy.optimize import curve_fit
from src.transition import transition_curve
import argparse
import pandas as pd
import numpy as np
from functools import partial
import matplotlib.pyplot as plt
import matplotlib.dates as mdates


# Box constraints handed to scipy.optimize.curve_fit in main().  Each list has
# one entry per fitted parameter of transition_curve, in positional order:
#   (initial_val, final_val, ct, t2p, tlag)
# The *_B pair bounds the fit before the mitigation date, the *_A pair the fit
# after it; they differ only in the lower bound on final_val.
MIN_BOUNDS_B = [0, 100_000, 0, 0, 0]
MAX_BOUNDS_B = [99_999, 9_999_999, 10.0, 2.0, 1.8]
MIN_BOUNDS_A = [0, 0, 0, 0, 0]
MAX_BOUNDS_A = [99_999, 9_999_999, 10.0, 2.0, 1.8]



def prepare_data(input_file_path, country, mitigation_start_date):
    """Load a confirmed-cases CSV and split one country's series at a date.

    The CSV is read in wide form (one column per date).  The 'Lat', 'Long'
    and 'Province/State' columns are dropped and the remaining rows are
    summed per 'Country/Region', giving one value per country and date.

    Args:
        input_file_path: path for input file
        country: 'Country/Region' value to extract
        mitigation_start_date: start date of mitigation effects/inflection
            date.  It is used as a label-based slice bound on both sides, so
            that date belongs to both the "before" and the "after" segment.

    Returns:
        Tuple ``(x1, y1, x2, y2, xdata, ydata)``: dates and values up to the
        mitigation date, dates and values from the mitigation date onwards,
        and dates and values of the full series.  The x entries are the date
        index objects, the y entries the underlying value arrays.

    Raises:
        KeyError: if ``country`` does not occur in the file.
    """
    raw = pd.read_csv(input_file_path, index_col=None, header=0)
    raw = raw.drop(['Lat', 'Long', 'Province/State'], axis=1)
    long_form = raw.melt(id_vars="Country/Region", var_name='date', value_name='Confirmed')
    long_form['date'] = pd.to_datetime(long_form['date'])

    # pivot_table groups by country and date itself and returns both axes in
    # sorted order, so no explicit sort is needed beforehand.  The string
    # 'sum' selects the same aggregation as np.sum without the FutureWarning
    # newer pandas versions emit for NumPy callables.
    per_country = long_form.pivot_table(index='Country/Region', columns='date',
                                        values='Confirmed', aggfunc='sum')

    if country not in per_country.index:
        raise KeyError(f"Country {country!r} not found in {input_file_path!r}")

    country_data = per_country.loc[country, :]
    country_before_m = country_data.loc[:mitigation_start_date]  # slope 1
    country_after_m = country_data.loc[mitigation_start_date:]  # slope 2

    return (
        country_before_m.index, country_before_m.values,
        country_after_m.index, country_after_m.values,
        country_data.index, country_data.values,
    )


def plot_curves(popt1, popt2, x2, xdata, ydata, country):
    """Draw both fitted transition curves together with the confirmed cases.

    Args:
        popt1: Optimized parameters before mitigation effects; the resulting
            curve is evaluated over the full date range ``xdata``.
        popt2: Optimized parameters after mitigation effects; the resulting
            curve is evaluated over ``x2`` only.
        x2: dates of the post-mitigation segment
        xdata: pivoted country data index
        ydata: pivoted country data values
        country: country of choice, shown in the plot title

    Returns:
        None; the figure is displayed with ``plt.show()``.
    """
    figure = plt.figure(figsize=(20, 10))
    axes = figure.add_subplot(111)

    # Both fits share the first date of the full series as their start date.
    fitted_curve = partial(transition_curve, tstart=xdata[0])

    # The order of the three plot calls must match the legend labels below.
    plt.plot(xdata, fitted_curve(xdata, *popt1), 'g--')
    plt.plot(x2, fitted_curve(x2, *popt2), 'b--')

    axes.xaxis.set_major_formatter(mdates.DateFormatter('%m-%d'))
    plt.xlabel('Date')
    plt.ylabel("COVID-19 confirmed cases")

    plt.plot(xdata, ydata)
    plt.grid(which='both')
    plt.title(f'Curve fit for {country}')
    plt.xticks(rotation=45)

    legend_labels = ["Cases without mitigation", "Cases after mitigation", "Confirmed Cases"]
    axes.legend(legend_labels, prop={'size': 14})
    plt.show()


def main(input_file_path, country, mitigation_start_date, p0_before, p0_after):
    """Fit one transition curve before and one after mitigation, then plot.

    Args:
        input_file_path: path for input file
        country: country of choice
        mitigation_start_date: start date of mitigation effects/inflection date
        p0_before: initial parameter guess for the pre-mitigation fit, or
            None.  Entries may be numbers or numeric strings (as delivered by
            the command line); they are converted to floats.
        p0_after: initial parameter guess for the post-mitigation fit, same
            conventions as ``p0_before``.

    Returns:
        None; the result is shown as a plot.

    Raises:
        ValueError: if an initial guess contains a non-numeric entry.
    """
    # Command-line values arrive as strings; curve_fit needs numeric guesses.
    if p0_before is not None:
        p0_before = np.asarray(p0_before, dtype=float)
    if p0_after is not None:
        p0_after = np.asarray(p0_after, dtype=float)

    x1, y1, x2, y2, xdata, ydata = prepare_data(input_file_path, country, mitigation_start_date)

    # Both fits measure time from the first date of the full series.
    part_transition = partial(transition_curve, tstart=xdata[0])
    popt1, _ = curve_fit(part_transition, x1, y1, bounds=(MIN_BOUNDS_B, MAX_BOUNDS_B),
                         p0=p0_before)
    popt2, _ = curve_fit(part_transition, x2, y2, bounds=(MIN_BOUNDS_A, MAX_BOUNDS_A),
                         p0=p0_after)
    plot_curves(popt1, popt2, x2, xdata, ydata, country)



if __name__ == "__main__":
    # Command-line entry point: parse the options and run the analysis.
    parser = argparse.ArgumentParser()
    parser.add_argument("--input-fpth", dest='input_fpth', type=str, required=True,
                        help="Provide the path to the input file")
    parser.add_argument("--mitigation-start-date", dest='mitigation_start_date', type=str, required=True,
                        help="Mitigation start date")
    parser.add_argument("--country", dest='country', type=str, required=True,
                        help="Name of the country to do analysis")
    # type=float: without it the guesses reach curve_fit as strings.
    parser.add_argument('--p0-before', dest='p0_before', nargs='+', type=float,
                        help="Initial parameters before mitigation curve")
    parser.add_argument('--p0-after', dest='p0_after', nargs='+', type=float,
                        help="Initial parameters after mitigation curve")
    args = parser.parse_args()
    # main() only plots and returns None, so its result is not printed.
    main(args.input_fpth, args.country, args.mitigation_start_date, args.p0_before, args.p0_after)

127 changes: 103 additions & 24 deletions src/transition.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,23 @@
import numpy as np
from datetime import datetime, timedelta, date
import pandas as pd
import matplotlib.pyplot as plt

def transition(curve_type,dmax,time_to_max,time_in_lag,launch,time,interval):

def transition(curve_type, dmax, time_to_max, time_in_lag, launch, time, interval):
"""
Design a transition curve with given choice of inputs; playing with the parameters p,q
For more info https://businessperspectives.org/journals/Parametric analysis of the Bass model - Business Perspectives.pdf
Args:
curve_type:
dmax:
time_to_max:
time_in_lag:
launch:
time:
interval:
curve_type: slow uptake - 0, quick uptake - 10
dmax: max difference between initial value and final
time_to_max: time to peak
time_in_lag: the inflection duration
launch: date of start
time: end date/inflection date
interval: for monthly 1/12
Returns:
Returns: transition
"""
# Limit curve type from 0-10
Expand Down Expand Up @@ -43,14 +47,15 @@ def S_shape_curve(dmax, time_to_max, time_lag, launch, time, interval):
"""
Args:
dmax:
time_to_max:
time_lag:
launch:
time:
interval:
dmax: max difference between initial value and final
time_to_max: time to peak
time_lag: the inflection duration
launch: date of start
time: end date/inflection date
interval: for monthly 1/12
Returns: S_shape_curve
Returns:
"""
launch = time_index(launch)
Expand All @@ -71,17 +76,17 @@ def S_shape_curve(dmax, time_to_max, time_lag, launch, time, interval):
return S_shape_curve


def rapid_curve(dmax,time_to_max,launch,time,interval):
def rapid_curve(dmax, time_to_max,launch,time,interval):
"""
Args:
dmax:
time_to_max:
launch:
time:
interval:
dmax: max difference between initial value and final
time_to_max: time to peak
launch: date of start
time: end date/inflection date
interval: for monthly 1/12
Returns:
Returns: rapid_curve
"""
launch=time_index(launch)
Expand All @@ -97,5 +102,79 @@ def rapid_curve(dmax,time_to_max,launch,time,interval):
rapid_curve=(upper_limit-lower_limit)/interval
return rapid_curve


def time_index(time):
    """Convert a date or a number into a fractional-year time index.

    Args:
        time: end date/inflection date.  A ``datetime.date`` (including
            subclasses such as ``datetime.datetime``) is mapped to
            ``year + (month - 1) / 12 + (day - 1) / 365``.  A number is
            returned unchanged, and a string that parses as a number is
            returned as a float.

    Returns:
        The time index as a number.

    Raises:
        ValueError: if ``time`` is not a date, a number or a numeric string.
    """
    if isinstance(time, date):
        return time.year + (time.month - 1) / 12 + (time.day - 1) / 365

    # Plain and NumPy numbers are already time indices.
    if isinstance(time, (int, float, np.integer, np.floating)):
        return time

    if isinstance(time, str):
        try:
            # Convert so that callers can do arithmetic on the result.
            return float(time)
        except ValueError:
            raise ValueError(f"Cannot interpret {time!r} as a time index") from None

    raise ValueError(f"Cannot interpret {time!r} as a time index")


def step_function(start, end, launch, time, interval):
    """Evaluate a step from ``start`` to ``end`` over one time interval.

    The interval considered is ``[time, time + interval]`` and the step
    happens at ``launch``.

    Args:
        start: start value (applies before ``launch``)
        end: end value (applies from ``launch`` onwards)
        launch: date of start, i.e. where the step occurs
        time: end date/inflection date; beginning of the evaluated interval
        interval: for monthly 1/12

    Returns:
        ``start * interval`` when the interval ends at or before ``launch``,
        ``end * interval`` when it begins at or after ``launch``, otherwise
        a blend of ``start`` and ``end`` weighted by the share of the
        interval lying before ``launch``.
    """
    if time + interval <= launch:
        return start * interval
    if time >= launch:
        return end * interval

    # The interval straddles the launch: weight is the share before it.
    weight = (launch - time) / interval
    # NOTE(review): unlike the two branches above, this blend is not scaled
    # by `interval`; kept as in the original - confirm whether that is
    # intended.
    return weight * start + (1 - weight) * end


def transition_curve(datelist, initial_val, final_val, ct, t2p, tlag, tstart):
    """Evaluate the transition model at every date in ``datelist``.

    For each date the result is ``initial_val`` plus the value of
    ``transition`` called with ``final_val - initial_val`` as its maximum
    difference and a fixed interval of ``1/365``.

    Args:
        datelist: dates before/after inflection point
        initial_val: value added to every transition value
        final_val: peak value; only its difference to ``initial_val`` is
            passed on to ``transition``
        ct: a.k.a. curve type; slow uptake - 0, quick uptake - 10
        t2p: a.k.a. time to peak
        tlag: a.k.a. time_in_lag; the inflection duration
        tstart: start date

    Returns:
        List with one transition value per entry of ``datelist``.
    """
    span = final_val - initial_val
    return [
        initial_val + transition(ct, span, t2p, tlag, tstart, moment, 1/365)
        for moment in datelist
    ]


if __name__ == '__main__':
    # Demo: plot the transition curve for three curve types (0, 5 and 10)
    # over month-end dates between 2019-12-01 and 2020-12-01.
    datelist = pd.date_range('2019-12-01', end='2020-12-01', freq='M', name='str').date.tolist()
    plt.plot(transition_curve(datelist, 0, 1, 0, 11 / 12, 0, datelist[0]))
    plt.plot(transition_curve(datelist, 0, 1, 5, 11 / 12, 0, datelist[0]))
    plt.plot(transition_curve(datelist, 0, 1, 10, 11 / 12, 0, datelist[0]))
    plt.legend(['Slow uptake', 'Linear uptake', 'Fast uptake'])
    plt.title("Variation of uptakes with respect to curve type")
    plt.xlabel("Time(in Months)")
    # The on/off flag is passed positionally: its keyword was renamed from
    # `b` to `visible` in Matplotlib, so `b=True` fails on current releases.
    plt.grid(True, which='major')
    plt.show()
    print("Sucessful")

0 comments on commit b9ca4be

Please sign in to comment.