Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve NumPy implementation #29

Open
wants to merge 8 commits into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 25 additions & 41 deletions bfast/monitor/python/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
import numpy as np
np.warnings.filterwarnings('ignore')
np.set_printoptions(suppress=True)
from sklearn import linear_model

from bfast.base import BFASTMonitorBase
from bfast.monitor.utils import compute_end_history, compute_lam, map_indices
Expand Down Expand Up @@ -102,7 +101,7 @@ def __init__(self,
self._timers = {}
self.use_mp = use_mp

def fit(self, data, dates, nan_value=0):
def fit(self, data, dates, nan_value=0, **kwargs):
""" Fits the models for the ndarray 'data'

Parameters
Expand Down Expand Up @@ -151,30 +150,14 @@ def fit(self, data, dates, nan_value=0):
self.magnitudes = rval[:,:,2].astype(np.float32)
self.valids = rval[:,:,3].astype(np.int32)
else:
means_global = np.zeros((data.shape[1], data.shape[2]), dtype=np.float32)
magnitudes_global = np.zeros((data.shape[1], data.shape[2]), dtype=np.float32)
breaks_global = np.zeros((data.shape[1], data.shape[2]), dtype=np.int32)
valids_global = np.zeros((data.shape[1], data.shape[2]), dtype=np.int32)

for i in range(data.shape[1]):
if self.verbose > 0:
print("Processing row {}".format(i))

for j in range(data.shape[2]):
y = data[:,i,j]
(pix_break,
pix_mean,
pix_magnitude,
pix_num_valid) = self.fit_single(y)
breaks_global[i,j] = pix_break
means_global[i,j] = pix_mean
magnitudes_global[i,j] = pix_magnitude
valids_global[i,j] = pix_num_valid

self.breaks = breaks_global
self.means = means_global
self.magnitudes = magnitudes_global
self.valids = valids_global

rval = np.apply_along_axis(self.fit_single, 0, data)

#print(rval.shape)
self.breaks = rval[0].astype(np.int32)
self.means = rval[1].astype(np.float32)
self.magnitudes = rval[2].astype(np.float32)
self.valids = rval[3].astype(np.int32)

return self

Expand Down Expand Up @@ -210,7 +193,10 @@ def fit_single(self, y):
magnitude = 0.0
if self.verbose > 1:
print("WARNING: Not enough observations: ns={ns}, Ns={Ns}".format(ns=ns, Ns=Ns))
return brk, mean, magnitude, Ns

rval = np.array([brk, mean, magnitude, Ns])

return rval

val_inds = val_inds[ns:]
val_inds -= self.n
Expand All @@ -224,30 +210,26 @@ def fit_single(self, y):
X_nn_m = X_nn[:, ns:]
y_nn_h = y_nn[:ns]
y_nn_m = y_nn[ns:]

# (1) fit linear regression model for history period
model = linear_model.LinearRegression(fit_intercept=False)
model.fit(X_nn_h.T, y_nn_h)
coef = np.linalg.pinv(X_nn_h@X_nn_h.T)@X_nn_h@y_nn_h

if self.verbose > 1:
column_names = np.array(["Intercept",
"trend",
"harmonsin1",
column_names = np.array(["harmonsin1",
"harmoncos1",
"harmonsin2",
"harmoncos2",
"harmonsin3",
"harmoncos3"])
if self.trend:
indxs = np.array([0, 1, 3, 5, 7, 2, 4, 6])
indxs = np.array([1, 3, 5, 7, 2, 4, 6])
else:
indxs = np.array([0, 2, 4, 6, 1, 3, 5])
# print(column_names[indxs])
indxs = np.array([2, 4, 6, 1, 3, 5])
print(column_names[indxs])
print(model.coef_[indxs])
print(coef[indxs])

# get predictions for all non-nan points
y_pred = model.predict(X_nn.T)
y_pred = X_nn.T@coef
y_error = y_nn - y_pred

# (2) evaluate model on monitoring period mosum_nn process
Expand Down Expand Up @@ -277,14 +259,16 @@ def fit_single(self, y):
print("bounds", bounds)

breaks = np.abs(mosum) > bounds
first_break = np.where(breaks)[0]
first_break = np.nonzero(breaks)[0]

if first_break.shape[0] > 0:
first_break = first_break[0]
first_break = first_break[0].item()
else:
first_break = -1

return first_break, mean, magnitude, Ns
rval = np.array([first_break, mean.item(), magnitude.item(), Ns.item()])

return rval

def get_timers(self):
""" Returns runtime measurements for the
Expand Down
13 changes: 6 additions & 7 deletions bfast/monitor/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -384,7 +384,7 @@
__critval_h = np.array([0.25, 0.5, 1])
__critval_period = np.arange(2, 12, 2)
__critval_level = np.arange(0.95, 0.999, 0.001)
__critval_mr = np.array(["max", "range"])
__critval_mr = ["max", "range"]

def _check_par(val, name, arr, fun=lambda x: x):
if not val in arr:
Expand All @@ -403,13 +403,12 @@ def get_critval(h, period, level, mr):
index = np.zeros(4, dtype=np.int)

# Get index into table from arguments
index[0] = np.where(mr == __critval_mr)[0][0]
index[1] = np.where(level == __critval_level)[0][0]
# index[2] = np.where(period == __critval_period)[0][0]
# print((np.abs(__critval_period - period)).argmin())
index[0] = next(i for i, v in enumerate(__critval_mr) if v == mr)
index[1] = np.nonzero(level == __critval_level)[0][0]
index[2] = (np.abs(__critval_period - period)).argmin()
index[3] = np.where(h == __critval_h)[0][0]
# For historical reasons, the critvals are scaled by sqrt(2)
index[3] = np.nonzero(h == __critval_h)[0][0]

# For legacy reasons, the critvals are scaled by sqrt(2)
return __critvals[tuple(index)] * np.sqrt(2)

def _find_index_date(dates, t):
Expand Down