diff --git a/fooof/core/funcs.py b/fooof/core/funcs.py index e4751c467..32be4deaa 100644 --- a/fooof/core/funcs.py +++ b/fooof/core/funcs.py @@ -7,8 +7,12 @@ - They are left available for easy swapping back in, if desired. """ +from inspect import isfunction + import numpy as np +from scipy.stats import norm + from fooof.core.errors import InconsistentDataError ################################################################################################### @@ -41,6 +45,43 @@ def gaussian_function(xs, *params): return ys +def skewed_gaussian_function(xs, *params): + """Skewed gaussian fitting function. + + Parameters + ---------- + xs : 1d array + Input x-axis values. + *params : float + Parameters that define the skewed gaussian function (center, height, width, alpha). + + Returns + ------- + ys : 1d array + Output values for skewed gaussian function. + """ + + ys = np.zeros_like(xs) + + for ii in range(0, len(params), 4): + + ctr, hgt, wid, alpha = params[ii:ii+4] + + # Gaussian distribution + ys = gaussian_function(xs, ctr, hgt, wid) + + # Skewed cumulative distribution function + cdf = norm.cdf(alpha * ((xs - ctr) / wid)) + + # Skew the gaussian + ys = ys * cdf + + # Rescale height + ys = (ys / np.max(ys)) * hgt + + return ys + + def expo_function(xs, *params): """Exponential fitting function, for fitting aperiodic component with a 'knee'. @@ -167,7 +208,9 @@ def get_pe_func(periodic_mode): """ - if periodic_mode == 'gaussian': + if isfunction(periodic_mode): + pe_func = periodic_mode + elif periodic_mode == 'gaussian': pe_func = gaussian_function else: raise ValueError("Requested periodic mode not understood.") @@ -194,7 +237,9 @@ def get_ap_func(aperiodic_mode): If the specified aperiodic mode label is not understood. """ - if aperiodic_mode == 'fixed': + if isfunction(aperiodic_mode): + ap_func = aperiodic_mode + elif aperiodic_mode == 'fixed': ap_func = expo_nk_function elif aperiodic_mode == 'knee': ap_func = expo_function diff --git a/fooof/objs/fit.py b/fooof/objs/fit.py index fba745d27..68b7418fa 100644 --- a/fooof/objs/fit.py +++ b/fooof/objs/fit.py @@ -67,7 +67,7 @@ from fooof.core.reports import save_report_fm from fooof.core.modutils import copy_doc_func_to_method from fooof.core.utils import group_three, check_array_dim -from fooof.core.funcs import gaussian_function, get_ap_func, infer_ap_func +from fooof.core.funcs import get_pe_func, get_ap_func, infer_ap_func from fooof.core.errors import (FitError, NoModelError, DataError, NoDataError, InconsistentDataError) from fooof.core.strings import (gen_settings_str, gen_results_fm_str, @@ -154,8 +154,9 @@ class FOOOF(): """ # pylint: disable=attribute-defined-outside-init - def __init__(self, peak_width_limits=(0.5, 12.0), max_n_peaks=np.inf, min_peak_height=0.0, - peak_threshold=2.0, aperiodic_mode='fixed', verbose=True): + def __init__(self, peak_width_limits=(0.5, 12.0), max_n_peaks=np.inf, + min_peak_height=0.0, peak_threshold=2.0, aperiodic_mode='fixed', + periodic_mode='gaussian', verbose=True): """Initialize object with desired settings.""" # Set input settings @@ -164,6 +165,7 @@ def __init__(self, peak_width_limits=(0.5, 12.0), max_n_peaks=np.inf, min_peak_h self.min_peak_height = min_peak_height self.peak_threshold = peak_threshold self.aperiodic_mode = aperiodic_mode + self.periodic_mode = periodic_mode self.verbose = verbose ## PRIVATE SETTINGS @@ -439,6 +441,9 @@ def fit(self, freqs=None, power_spectrum=None, freq_range=None): if self.verbose: self._check_width_limits() + # Determine the aperiodic and periodic fit funcs + self._set_fit_funcs() + # In rare cases, the model fails to fit, and so uses try / except try: @@ -715,6 +720,11 @@ def set_check_data_mode(self, check_data): self._check_data = check_data + def _set_fit_funcs(self): + """Set the requested aperiodic and periodic fit functions.""" + + self._pe_func = get_pe_func(self.periodic_mode) + self._ap_func = get_ap_func(self.aperiodic_mode) def _check_width_limits(self): """Check and warn about peak width limits / frequency resolution interaction.""" @@ -762,8 +772,7 @@ def _simple_ap_fit(self, freqs, power_spectrum): try: with warnings.catch_warnings(): warnings.simplefilter("ignore") - aperiodic_params, _ = curve_fit(get_ap_func(self.aperiodic_mode), - freqs, power_spectrum, p0=guess, + aperiodic_params, _ = curve_fit(self._ap_func, freqs, power_spectrum, p0=guess, maxfev=self._maxfev, bounds=ap_bounds) except RuntimeError: raise FitError("Model fitting failed due to not finding parameters in " @@ -818,9 +827,8 @@ def _robust_ap_fit(self, freqs, power_spectrum): try: with warnings.catch_warnings(): warnings.simplefilter("ignore") - aperiodic_params, _ = curve_fit(get_ap_func(self.aperiodic_mode), - freqs_ignore, spectrum_ignore, p0=popt, - maxfev=self._maxfev, bounds=ap_bounds) + aperiodic_params, _ = curve_fit(self._ap_func, freqs_ignore, spectrum_ignore, + p0=popt, maxfev=self._maxfev, bounds=ap_bounds) except RuntimeError: raise FitError("Model fitting failed due to not finding " "parameters in the robust aperiodic fit.") @@ -904,7 +912,7 @@ def _fit_peaks(self, flat_iter): # Collect guess parameters and subtract this guess gaussian from the data guess = np.vstack((guess, (guess_freq, guess_height, guess_std))) - peak_gauss = gaussian_function(self.freqs, guess_freq, guess_height, guess_std) + peak_gauss = self._pe_func(self.freqs, guess_freq, guess_height, guess_std) flat_iter = flat_iter - peak_gauss # Check peaks based on edges, and on overlap, dropping any that violate requirements @@ -963,7 +971,7 @@ def _fit_peak_guess(self, guess): # Fit the peaks try: - gaussian_params, _ = curve_fit(gaussian_function, self.freqs, self._spectrum_flat, + gaussian_params, _ = curve_fit(self._pe_func, self.freqs, self._spectrum_flat, p0=guess, maxfev=self._maxfev, bounds=gaus_param_bounds) except RuntimeError: raise FitError("Model fitting failed due to not finding " diff --git a/fooof/tests/core/test_funcs.py b/fooof/tests/core/test_funcs.py index f9e05dcb9..6afe44d61 100644 --- a/fooof/tests/core/test_funcs.py +++ b/fooof/tests/core/test_funcs.py @@ -26,6 +26,20 @@ def test_gaussian_function(): assert max(ys) == hgt assert np.allclose([i/sum(ys) for i in ys], norm.pdf(xs, ctr, wid)) +def test_skewed_gaussian_function(): + + ctr, hgt, wid, alpha = 50, 5, 10, 4 + + xs = np.arange(1, 100) + ys_gaussian = gaussian_function(xs, ctr, hgt, wid) + ys = skewed_gaussian_function(xs, ctr, hgt, wid, alpha) + + assert np.all(ys) + + # Positive alphas shift the max to the right + assert np.argmax(ys) >= np.argmax(ys_gaussian) + assert np.max(ys) == np.max(ys_gaussian) == hgt + def test_expo_function(): off, knee, exp = 10, 5, 2