diff --git a/.pylintrc b/.pylintrc index 1d214f9..de65dda 100644 --- a/.pylintrc +++ b/.pylintrc @@ -62,7 +62,8 @@ confidence= # --disable=W". disable=fixme, no-else-return, - too-many-lines + too-many-lines, + similarities # Enable the message, report, category or checker with the given id(s). You can # either give multiple identifier separated by comma (,) or put this option @@ -484,6 +485,9 @@ preferred-modules= [DESIGN] +# Maximum number of positional arguments for function / method. +max-positional-arguments=10 + # Maximum number of arguments for function / method. max-args=10 diff --git a/pyei/__init__.py b/pyei/__init__.py index f58512c..cf6c905 100644 --- a/pyei/__init__.py +++ b/pyei/__init__.py @@ -1,6 +1,6 @@ """A package for rpv and ecological inference""" -__version__ = "1.1.1" +__version__ = "1.1.2" from .two_by_two import * from .goodmans_er import * from .plot_utils import * diff --git a/pyei/r_by_c.py b/pyei/r_by_c.py index d299f6c..bfb3855 100644 --- a/pyei/r_by_c.py +++ b/pyei/r_by_c.py @@ -9,7 +9,7 @@ """ import warnings -from pymc import sampling_jax +import pymc as pm import numpy as np from .plot_utils import ( plot_boxplots, @@ -200,8 +200,11 @@ def fit( # pylint: disable=too-many-branches "multinomial-dirichlet", ]: # for models whose sampling is w/ pycm with self.sim_model: # pylint: disable=not-context-manager - self.sim_trace = sampling_jax.sample_numpyro_nuts( - target_accept=target_accept, tune=tune, **other_sampling_args + self.sim_trace = pm.sample( + target_accept=target_accept, + tune=tune, + nuts_sampler="numpyro", + **other_sampling_args, ) elif self.model_name == "greiner-quinn": self.sim_trace = pyei_greiner_quinn_sample( diff --git a/pyei/two_by_two.py b/pyei/two_by_two.py index bcb37b6..c0646df 100644 --- a/pyei/two_by_two.py +++ b/pyei/two_by_two.py @@ -5,7 +5,6 @@ import warnings import pymc as pm -from pymc import sampling_jax import numpy as np import pytensor.tensor as at import pytensor @@ -847,9 +846,10 @@ def fit( **other_sampling_args, ) else: - self.sim_trace = sampling_jax.sample_numpyro_nuts( + self.sim_trace = pm.sample( target_accept=target_accept, tune=tune, + nuts_sampler="numpyro", **other_sampling_args, ) diff --git a/requirements.txt b/requirements.txt index b5f0778..27c3c30 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,4 @@ -pymc >= 5.10.0 +pymc >= 5.18.0 arviz scikit-learn matplotlib diff --git a/setup.py b/setup.py index f3f29d0..131aa59 100644 --- a/setup.py +++ b/setup.py @@ -21,7 +21,7 @@ def get_requirements(): with codecs.open(REQUIREMENTS_FILE, "r", encoding="utf-8") as buff: return buff.read().splitlines() except: - return """pymc >= 5.10.0 + return """pymc >= 5.18.0 arviz scikit-learn matplotlib