diff --git a/doc/user-guide/computation.rst b/doc/user-guide/computation.rst index dc9748af80b..215fc76d5bf 100644 --- a/doc/user-guide/computation.rst +++ b/doc/user-guide/computation.rst @@ -542,7 +542,7 @@ two gaussian peaks: coords=["x", "y"], func=multi_peak, param_names=names, - kwargs={"maxfev": 10000}, + maxfev=10000, ) .. note:: diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 10b59ba1c0c..6af6914e5e6 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -75,6 +75,9 @@ Bug fixes - Allow writing NetCDF files including only dimensionless variables using the distributed or multiprocessing scheduler (:issue:`7013`, :pull:`7040`). By `Francesco Nattino `_. +- Allow passing additional kwargs directly to ``scipy`` in + :py:meth:`Dataset.curvefit`. (:issue:`6891`, :pull:`6978`) + By `Sam Levang `_. Documentation ~~~~~~~~~~~~~ diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py index 302562243b2..d7323e7c871 100644 --- a/xarray/core/dataarray.py +++ b/xarray/core/dataarray.py @@ -5068,7 +5068,7 @@ def curvefit( p0: dict[str, Any] | None = None, bounds: dict[str, Any] | None = None, param_names: Sequence[str] | None = None, - kwargs: dict[str, Any] | None = None, + **kwargs: Any, ) -> Dataset: """ Curve fitting optimization for arbitrary functions. @@ -5109,7 +5109,7 @@ def curvefit( should be manually supplied when fitting a function that takes a variable number of parameters. **kwargs : optional - Additional keyword arguments to passed to scipy curve_fit. + Additional keyword arguments passed to scipy curve_fit. Returns ------- @@ -5134,7 +5134,7 @@ def curvefit( p0=p0, bounds=bounds, param_names=param_names, - kwargs=kwargs, + **kwargs, ) def drop_duplicates( diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index b5b694c776a..9158ebf4e03 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -8206,7 +8206,7 @@ def curvefit( p0: dict[str, Any] | None = None, bounds: dict[str, Any] | None = None, param_names: Sequence[str] | None = None, - kwargs: dict[str, Any] | None = None, + **kwargs: Any, ) -> T_Dataset: """ Curve fitting optimization for arbitrary functions. @@ -8247,7 +8247,7 @@ def curvefit( should be manually supplied when fitting a function that takes a variable number of parameters. **kwargs : optional - Additional keyword arguments to passed to scipy curve_fit. + Additional keyword arguments passed to scipy curve_fit. Returns ------- @@ -8276,7 +8276,8 @@ def curvefit( bounds = {} if kwargs is None: kwargs = {} - + elif "kwargs" in kwargs: + kwargs = {**kwargs.pop("kwargs"), **kwargs} if not reduce_dims: reduce_dims_ = [] elif isinstance(reduce_dims, str) or not isinstance(reduce_dims, Iterable): diff --git a/xarray/tests/test_dataarray.py b/xarray/tests/test_dataarray.py index ab6e5763248..69d9581ba0b 100644 --- a/xarray/tests/test_dataarray.py +++ b/xarray/tests/test_dataarray.py @@ -4219,12 +4219,22 @@ def exp_decay(t, n0, tau=1): da = da.chunk({"x": 1}) fit = da.curvefit( - coords=[da.t], func=exp_decay, p0={"n0": 4}, bounds={"tau": [2, 6]} + coords=[da.t], + func=exp_decay, + p0={"n0": 4}, + bounds={"tau": [2, 6]}, + maxfev=1000, ) assert_allclose(fit.curvefit_coefficients, expected, rtol=1e-3) da = da.compute() - fit = da.curvefit(coords="t", func=np.power, reduce_dims="x", param_names=["a"]) + fit = da.curvefit( + coords="t", + func=np.power, + reduce_dims="x", + param_names=["a"], + kwargs={"maxfev": 1000}, + ) assert "a" in fit.param assert "x" not in fit.dims