diff --git a/src/nessai/flowmodel/importance.py b/src/nessai/flowmodel/importance.py index 34c873d9..3d56fa84 100644 --- a/src/nessai/flowmodel/importance.py +++ b/src/nessai/flowmodel/importance.py @@ -7,6 +7,7 @@ import glob import logging import os +import re from typing import Optional from warnings import warn @@ -24,7 +25,7 @@ class ImportanceFlowModel(FlowModel): """Flow Model that contains multiple flows for importance sampler.""" models: torch.nn.ModuleList = None - _resume_n_models: int = None + _resume_model_keys: list = None def __init__( self, @@ -39,8 +40,10 @@ def __init__( output=output, rng=rng, ) - self.weights_files = [] - self.models = torch.nn.ModuleList() + self.weights_files = {} + self.models = torch.nn.ModuleDict() + self._current_model = -1 + self._model = None @property def model(self): @@ -48,16 +51,30 @@ def model(self): Returns None if the no models have been added. """ - if self.models: - return self.models[-1] - else: - logger.warning("Model not defined yet!") - return None + return self._model @model.setter def model(self, model): if model is not None: - self.models.append(model) + raise ValueError("Cannot set model directly, use add_model()") + self._model = None + + def add_model(self, model, key=None): + """Add a model to the dictionary of models. + + Sets the current model to the new model. + """ + if key is None: + self._current_model += 1 + key = str(self._current_model) + else: + self._current_model = max(int(key), self._current_model) + self.models[key] = model + self._model = model + + @property + def current_model_key(self): + return str(self._current_model) @property def n_models(self) -> int: @@ -81,7 +98,7 @@ def reset_optimiser(self) -> None: def add_new_flow(self, reset=False): """Add a new flow""" logger.debug("Add a new flow") - if reset or not self.models: + if reset or not self.models or self.model is None: new_flow = configure_model(self.flow_config) else: new_flow = copy.deepcopy(self.model) @@ -95,10 +112,17 @@ def add_new_flow(self, reset=False): ) logger.debug(f"Inference device: {self.inference_device}") self.models.eval() - self.models.append(new_flow) + self.add_model(new_flow) self.reset_optimiser() - def log_prob_ith(self, x, i): + def remove_flow(self, i: str) -> None: + """Remove the i'th flow""" + logger.debug(f"Removing {i}'th flow") + model = self.models.pop(i) + if self.model is model: + self.model = None + + def log_prob_ith(self, x: np.ndarray, i: str) -> np.ndarray: """Compute the log-prob for the ith flow""" x = ( torch.from_numpy(x) @@ -114,6 +138,10 @@ def log_prob_ith(self, x, i): def log_prob_all(self, x): """Compute the log probability using all of the stored models.""" + # if self.model is None: + # if not self.models: + # raise RuntimeError("Models are not initialised yet!") + # self._model = next(iter(self.models.values())) x = ( torch.from_numpy(x) .type(torch.get_default_dtype()) @@ -124,12 +152,12 @@ def log_prob_all(self, x): n = self.n_models log_prob = torch.empty(x.shape[0], n) with torch.no_grad(): - for i, m in enumerate(self.models[:n]): + for i, m in enumerate(list(self.models.values())[:n]): log_prob[:, i] = m.log_prob(x) log_prob = log_prob.cpu().numpy().astype(np.float64) return log_prob - def sample_ith(self, i, N=1): + def sample_ith(self, i: str, N: int = 1): """Draw samples from the ith flow""" if self.models is None: raise RuntimeError("Models are not initialised yet!") @@ -145,67 +173,72 @@ def sample_ith(self, i, N=1): def save_weights(self, weights_file) -> None: """Save the weights file.""" super().save_weights(weights_file) - self.weights_files.append(self.weights_file) + self.weights_files[self.current_model_key] = self.weights_file def load_all_weights(self) -> None: """Load all of the weights files for each flow. Resets any existing models. """ - self.models = torch.nn.ModuleList() + self.models = torch.nn.ModuleDict() logger.debug(f"Loading weights from {self.weights_files}") self.device = torch.device( self.training_config.get("device_tag", "cpu") ) - for wf in self.weights_files: + self._current_model = -1 + for key, wf in self.weights_files.items(): new_flow = configure_model(self.flow_config) new_flow.device = self.device new_flow.load_state_dict(torch.load(wf, weights_only=True)) - self.models.append(new_flow) + self.add_model(new_flow, key=key) self.models.eval() def update_weights_path( - self, weights_path: str, n: Optional[int] = None + self, weights_path: str, keys: Optional[list[str]] = None ) -> None: """Update the weights path. Searches in the specified directory for weights files. + .. versionchanged:: 0.15.0 + Replaced the :code:`n` parameter with :code:`keys` to specify + which models to load. + Parameters ---------- weights_path : str Path to the directory that contains the weights files. - n : Optional[int] - The number of files to load. If not specified, :code:`n_models` is - used instead. Must be specified when resuming since the models list - is not saved. + keys : list[str], optional + The keys (IDs) of the models to load. """ - all_weights_files = glob.glob( + weights_files = {} + for wf in glob.glob( os.path.join(weights_path, "", "level_*", "model.pt") - ) - - if n is None: - if self.n_models: - n = self.n_models - else: + ): + # Extract the level number from the path + level = re.search(r"level_(\d+)", wf) + if level is None: raise RuntimeError( - "n is None and no models are defined, cannot update " - "weights path." + f"Cannot find level number in weights file: {wf}" ) + weights_files[level.group(1)] = wf - logger.debug(f"Loading weights from: {all_weights_files}") - if len(all_weights_files) < n: - raise RuntimeError( - f"Cannot use weights from: {weights_path}. Not enough files." - ) - elif len(all_weights_files) > n: - logger.warning( - "More weights files than expected. Some files will be skipped." + if keys is None: + keys = weights_files.keys() + else: + keys = [str(k) for k in keys] + + keys = set(keys) + known_keys = set(weights_files.keys()) + if not keys.issubset(known_keys): + raise ValueError( + f"Keys {keys - known_keys} not found in weights files." ) - self.weights_files = [ - os.path.join(weights_path, f"level_{i}", "model.pt") - for i in range(n) - ] + + # Only keep the specified keys + self.weights_files = { + k: v for k, v in weights_files.items() if k in keys + } def resume( self, @@ -224,7 +257,7 @@ def resume( "Not weights path specified, looking in output directory" ) weights_path = self.output - self.update_weights_path(weights_path, n=self._resume_n_models) + self.update_weights_path(weights_path, keys=self._resume_model_keys) self.load_all_weights() self.initialise() @@ -232,9 +265,9 @@ def __getstate__(self): d = self.__dict__ # Avoid making a copy because models can be large and this doubles the # memory usage. - exclude = {"models", "_optimiser", "flow_config"} + exclude = {"models", "_optimiser", "flow_config", "_model"} state = {k: d[k] for k in d.keys() - exclude} state["initialised"] = False state["models"] = None - state["_resume_n_models"] = len(d["models"]) + state["_resume_model_keys"] = list(d["models"].keys()) return state diff --git a/src/nessai/livepoint.py b/src/nessai/livepoint.py index b4b25156..c3f9f980 100644 --- a/src/nessai/livepoint.py +++ b/src/nessai/livepoint.py @@ -17,6 +17,7 @@ def add_extra_parameters_to_live_points( parameters, default_values=None, + dtypes=None, ): """Add extra parameters to the live points dtype. @@ -27,6 +28,9 @@ def add_extra_parameters_to_live_points( ---------- parameters: list List of parameters to add. + dtypes: Optional[Union[List, Tuple]] + List of dtypes for each parameter. If not specified, default + dtypes will be set to :code:`config.livepoints.default_float_dtype` in default_values: Optional[Union[List, Tuple]] List of default values for each parameters. If not specified, default values will be set to based on :code: `DEFAULT_FLOAT_VALUE` in @@ -38,14 +42,22 @@ def add_extra_parameters_to_live_points( ) else: default_values = tuple(default_values) - for p, dv in zip(parameters, default_values): + if dtypes is None: + dtypes = len(parameters) * (config.livepoints.default_float_dtype,) + else: + dtypes = tuple(dtypes) + if not (len(parameters) == len(default_values) == len(dtypes)): + raise ValueError( + "Length mismatch between parameters, default_values, and dtypes." + ) + for p, dtype, dv in zip(parameters, dtypes, default_values): if p not in config.livepoints.extra_parameters: config.livepoints.extra_parameters.append(p) config.livepoints.extra_parameters_defaults = ( config.livepoints.extra_parameters_defaults + (dv,) ) config.livepoints.extra_parameters_dtype.append( - config.livepoints.default_float_dtype + dtype, ) else: logger.warning( diff --git a/src/nessai/plot.py b/src/nessai/plot.py index e1c69023..9f58e611 100644 --- a/src/nessai/plot.py +++ b/src/nessai/plot.py @@ -143,7 +143,10 @@ def plot_live_points( df = pd.DataFrame(live_points) df = df.dropna(axis="columns", how="all") - df = df[np.isfinite(df).all(1)] + numeric_df = df.select_dtypes(include=[np.number]) + if not numeric_df.empty: + goodmask = np.isfinite(numeric_df).all(axis=1) + df = df[goodmask] if c is not None: hue = df[c] diff --git a/src/nessai/proposal/importance.py b/src/nessai/proposal/importance.py index 0f1e7030..c8ae540d 100644 --- a/src/nessai/proposal/importance.py +++ b/src/nessai/proposal/importance.py @@ -5,9 +5,11 @@ import logging import os +from functools import partial from typing import Callable, Optional, Tuple, Union import numpy as np +import numpy.lib.recfunctions as rfn from scipy.special import logsumexp from nessai.plot import plot_1d_comparison, plot_histogram, plot_live_points @@ -78,7 +80,7 @@ def __init__( clip: bool = False, plot_training: bool = False, ) -> None: - self.level_count = -1 + self._proposal_count = -1 self._initialised = False self.model = model @@ -94,7 +96,7 @@ def __init__( self.reparameterisation = reparameterisation self.weighted_kl = weighted_kl self.clip = clip - self._weights = {-1: 1.0} + self._weights = {"-1": 1.0} self.dtype = get_dtype(self.model.names) @@ -103,11 +105,6 @@ def weights(self) -> dict: """Dictionary containing the weights for each proposal""" return self._weights - @property - def weights_array(self) -> np.ndarray: - """Array of weights for each proposal""" - return np.fromiter(self._weights.values(), dtype=float) - @property def n_proposals(self) -> int: """Current number of proposals in the meta proposal""" @@ -129,11 +126,44 @@ def flow_config(self, config: dict) -> None: @property def _reset_flow(self) -> bool: """Boolean to indicate if the flow should be reset""" - if not self.reset_flow or self.level_count % self.reset_flow: + if not self.reset_flow or self._proposal_count % self.reset_flow: return False else: return True + @property + def proposal_id(self) -> str: + """The current proposal id""" + return str(self._proposal_count) + + @property + def log_q_dtype(self) -> np.dtype: + """Dtype for the log_q field in the samples array + + Returns + ------- + np.dtype + Dtype for the log_q field in the samples array. Each dtype is + a float64 and the field names are the proposal ids. + """ + return np.dtype([(qid, "f8") for qid in self.weights.keys()]) + + @property + def qid_dtype(self) -> Optional[np.dtype]: + """Dtype for the qID field in live points.""" + try: + idx = config.livepoints.non_sampling_parameters.index("qID") + except ValueError: + return None + return np.dtype(config.livepoints.non_sampling_dtype[idx]) + + def cast_qid(self, qid): + """Cast qID to the dtype defined in live points.""" + dtype = self.qid_dtype + if dtype is None: + return qid + return np.asarray(qid, dtype=dtype).item() + @staticmethod def _check_fields(): """Check that the logQ and logW fields have been added.""" @@ -149,6 +179,18 @@ def _check_fields(): raise RuntimeError( "logU field missing in non-sampling parameters." ) + if "qID" not in config.livepoints.non_sampling_parameters: + raise RuntimeError("qid field missing in non-sampling parameters.") + + def weights_array(self, keys: Optional[list[str]] = None) -> np.ndarray: + """Array of weights for each proposal""" + if keys is None: + keys = self.log_q_dtype.names + elif set(keys) != set(self.log_q_dtype.names): + raise ValueError( + f"Keys must be the same as the log_q dtype names: {keys}" + ) + return np.array([self.weights[k] for k in keys], dtype=float) def initialise(self): """Initialise the proposal""" @@ -276,16 +318,20 @@ def inverse_rescale(self, x_prime: np.ndarray) -> np.ndarray: x = numpy_array_to_live_points(x, self.model.names) return x, log_j - def update_proposal_weights(self, weights: dict) -> None: + def update_weights(self, weights: dict) -> None: """Method to update the proposal weights dictionary. Raises ------ RuntimeError - If the weights do not sum to 1 are the update. + If the weights do not sum to 1 after the update. + ValueError + If the keys in the weights dictionary are not strings. """ + if any(not isinstance(k, str) for k in weights.keys()): + raise ValueError("Keys in weights must be strings") self._weights.update(weights) - w_sum = np.sum(np.fromiter(self._weights.values(), float)) + w_sum = np.sum(self.weights_array()) if not np.isclose(w_sum, 1.0): raise RuntimeError(f"Weights must sum to 1! Actual value: {w_sum}") @@ -296,7 +342,7 @@ def train( output: Union[str, None] = None, weights: np.ndarray = None, **kwargs, - ) -> None: + ) -> str: """Train the proposal with a set of samples. Parameters @@ -311,11 +357,16 @@ def train( kwargs : Key-word arguments passed to \ :py:meth:`nessai.flowmodel.FlowModel.train`. + + Returns + ------- + str + The proposal id. """ - self.level_count += 1 - self._weights[self.level_count] = np.nan + self._proposal_count += 1 + self._weights[self.proposal_id] = np.nan output = self.output if output is None else output - level_output = os.path.join(output, f"level_{self.level_count}", "") + level_output = os.path.join(output, f"level_{self.proposal_id}", "") if not os.path.exists(level_output): os.makedirs(level_output, exist_ok=True) @@ -378,8 +429,9 @@ def train( test_samples, filename=os.path.join(level_output, "generated_samples.png"), ) + return self.proposal_id - def compute_log_Q( + def log_prob_meta_proposal( self, x_prime: np.ndarray, log_j: Optional[np.ndarray] = None, @@ -412,38 +464,41 @@ def compute_log_Q( if any(np.isnan(w) for w in self.weights.values()): raise RuntimeError("Some weights are not set!") - log_q_all = np.zeros([x_prime.shape[0], self.n_proposals]) - n_flows = self.flow.n_models + # Structured array for log_q + log_q = np.empty(len(x_prime), dtype=self.log_q_dtype) if self.n_proposals > 1 and log_j is None: raise RuntimeError( "Must specify log_j! Meta-proposal includes flows" ) - if any([flow.training for flow in self.flow.models]): + if any([flow.training for flow in self.flow.models.values()]): raise RuntimeError("One or more flows are in training mode!") - if n_flows >= 1: - log_q_all[:, 1 : (n_flows + 1)] = ( - self.flow.log_prob_all(x_prime) + log_j[:, np.newaxis] - ) - assert log_q_all.shape[0] == x_prime.shape[0] + for name in log_q.dtype.names: + log_prob_fn = self.get_proposal_log_prob(name, log_j=log_j) + log_q[name] = log_prob_fn(x_prime) - logger.debug(f"log_q is nan: {np.isnan(log_q_all).any()}") - logger.debug( - f"Mean log q for each each flow: {log_q_all.mean(axis=0)}" - ) - log_Q = logsumexp(log_q_all, b=self.weights_array, axis=1) + log_prob = self.log_prob_meta_proposal_from_log_q(log_q) + + if np.isnan(log_prob).any(): + raise ValueError("Log-prob meta proposal is NaN") - if np.isnan(log_Q).any(): - raise ValueError("There is a NaN in log g!") + return log_prob, log_q - return log_Q, log_q_all + def log_prob_meta_proposal_from_log_q(self, log_q): + """Compute the meta-proposal from an array of proposal \ + log-probabilities + """ + return rfn.apply_along_fields( + partial(logsumexp, b=self.weights_array(log_q.dtype.names)), + log_q, + ) def draw( self, n: int, - flow_number: Optional[int] = None, + flow_id: Optional[int] = None, ) -> Tuple[np.ndarray, np.ndarray]: """Draw n new points. @@ -451,7 +506,7 @@ def draw( ---------- n : int Number of points to draw. - flow_number : Optional[int] + flow_id : Optional[int] Specifies which flow to use. If not specified the last flow will be used. @@ -462,21 +517,21 @@ def draw( np.ndarray : Log-proposal probabilities (log_q) """ - if flow_number is None: - flow_number = self.level_count + if flow_id is None: + flow_id = self.proposal_id + flow_id = str(flow_id) # Draw a few more samples in case some are not accepted. n_draw = int(1.01 * n) logger.debug(f"Drawing {n} points") samples = np.zeros(0, dtype=self.dtype) - log_q_samples = np.empty([0, self.n_proposals]) + log_q_samples = np.empty(0, dtype=self.log_q_dtype) n_accepted = 0 while n_accepted < n and n_draw > 0: logger.debug(f"Drawing batch of {n_draw} samples") # x_prime, log_q = self.flow.sample_and_log_prob(N=n_draw) - x_prime = self.flow.sample_ith(i=flow_number, N=n_draw) - log_q = np.ones(n_draw) + x_prime = self.flow.sample_ith(i=flow_id, N=n_draw) x, log_j_inv = self.inverse_rescale(x_prime) # Rescaling can sometimes produce infs that don't appear in samples x_check, log_j = self.rescale(x) @@ -487,40 +542,35 @@ def draw( & np.isfinite(x_prime).all(axis=1) & np.isfinite(log_j) & np.isfinite(log_j_inv) - & np.isfinite(log_q) ) logger.debug(f"Rejected {n_draw - acc.sum()} points") if not np.any(acc): continue - x, x_prime, log_j, log_q = get_subset_arrays( - acc, x, x_prime, log_j, log_q - ) + x, x_prime, log_j = get_subset_arrays(acc, x, x_prime, log_j) - x["logQ"], log_q_all = self.compute_log_Q(x_prime, log_j=log_j) + x["logQ"], log_q = self.log_prob_meta_proposal( + x_prime, log_j=log_j + ) x["logP"] = self.model.batch_evaluate_log_prior( x, unit_hypercube=True ) x["logU"] = self.model.batch_evaluate_log_prior_unit_hypercube(x) x["logW"] = x["logU"] - x["logQ"] - accept = ( - np.isfinite(x["logP"]) - & ~np.isposinf(x["logW"]) - & ~np.isnan(log_q_all).all(axis=1) - & ~np.isposinf(log_q_all).all(axis=1) - ) + accept = np.isfinite(x["logP"]) & ~np.isposinf(x["logW"]) if not np.any(accept): continue - x, log_q_all = get_subset_arrays(accept, x, log_q_all) + x, log_q = get_subset_arrays(accept, x, log_q) samples = np.concatenate([samples, x]) - log_q_samples = np.concatenate([log_q_samples, log_q_all], axis=0) + log_q_samples = np.concatenate([log_q_samples, log_q]) n_accepted += x.size logger.debug(f"Accepted: {n_accepted}") samples = samples[:n] log_q_samples = log_q_samples[:n] + samples["qID"] = self.cast_qid(flow_id) logger.debug(f"Returning {samples.size} samples") return samples, log_q_samples @@ -531,26 +581,18 @@ def update_log_q( log_q: np.ndarray, ) -> np.ndarray: """Update the array of proposal probabilities for a set of samples""" - if log_q.shape[1] == self.n_proposals: + if self.proposal_id in log_q.dtype.names: raise ValueError("log_q array already contains current proposal") x, log_j = self.rescale(samples) - log_prob_fn = self.get_proposal_log_prob(self.level_count) + log_prob_fn = self.get_proposal_log_prob(self.proposal_id, log_j=log_j) log_q_current = log_prob_fn(x) - log_q = np.concatenate( - [log_q, log_q_current[:, np.newaxis] + log_j[:, np.newaxis]], - axis=1, - ) - return log_q - - def compute_meta_proposal_from_log_q(self, log_q): - """Compute the meta-proposal from an array of proposal \ - log-probabilities - """ - return logsumexp( + log_q = rfn.append_fields( log_q, - b=self.weights_array, - axis=1, + self.proposal_id, + log_q_current, + usemask=False, ) + return log_q def compute_meta_proposal_samples(self, samples: np.ndarray) -> np.ndarray: """Compute the meta proposal Q for a set of samples. @@ -564,15 +606,15 @@ def compute_meta_proposal_samples(self, samples: np.ndarray) -> np.ndarray: log_q : numpy.ndarray Array of log q for each flow. """ - if self.level_count not in self.weights or np.isnan( - self.weights[self.level_count] + if self.proposal_id not in self.weights or np.isnan( + self.weights[self.proposal_id] ): raise RuntimeError( "Weight(s) missing or not set. " f"Current weights: {self.weights}." ) x, log_j = self.rescale(samples) - return self.compute_log_Q(x, log_j=log_j) + return self.log_prob_meta_proposal(x, log_j=log_j) def _log_prob_initial(self, x: np.ndarray) -> np.ndarray: """Helper function that returns the log-probability for the initial @@ -580,12 +622,18 @@ def _log_prob_initial(self, x: np.ndarray) -> np.ndarray: """ return np.zeros(x.shape[0]) - def get_proposal_log_prob(self, it: int) -> Callable: + def get_proposal_log_prob( + self, it: str, log_j: np.ndarray = None + ) -> Callable: """Get a pointer to the function for ith proposal.""" - if it == -1: + it = str(it) + if it == "-1": return self._log_prob_initial - elif it < len(self.flow.models): - return lambda x: self.flow.log_prob_ith(x, it) + elif it in self.flow.models: + if log_j is not None: + return lambda x: self.flow.log_prob_ith(x, it) + log_j + else: + return lambda x: self.flow.log_prob_ith(x, it) else: raise ValueError @@ -601,15 +649,27 @@ def compute_kl_between_proposals( current and previous proposals are used. """ x_prime, log_j = self.rescale(x) + flow_keys = list(self.flow.models.keys()) + if flow_keys: + flow_keys = sorted(flow_keys, key=int) if p_it is None: - p_it = self.flow.n_models - 1 - + if not flow_keys: + raise ValueError("No flow models available for p_it") + p_it = flow_keys[-1] if q_it is None: - q_it = self.flow.n_models - 2 + if len(flow_keys) >= 2: + q_it = flow_keys[-2] + else: + q_it = "-1" + p_it = str(p_it) + q_it = str(q_it) if p_it == q_it: raise ValueError("p and q must be different") - elif p_it < -1 or q_it < -1: + elif not ( + (p_it == "-1" or p_it.isdigit()) + and (q_it == "-1" or q_it.isdigit()) + ): raise ValueError(f"Invalid p_it or q_it: {p_it}, {q_it}") log_p_f = self.get_proposal_log_prob(p_it) @@ -618,9 +678,9 @@ def compute_kl_between_proposals( log_p = log_p_f(x_prime) log_q = log_q_f(x_prime) - if p_it > -1: + if p_it != "-1": log_p += log_j - if q_it > -1: + if q_it != "-1": log_q += log_j kl = np.mean(log_p - log_q) @@ -634,7 +694,7 @@ def draw_from_prior(self, n: int) -> Tuple[np.ndarray, np.ndarray]: samples ) prime_samples, log_j = self.rescale(samples) - log_Q, log_q = self.compute_log_Q(prime_samples, log_j=log_j) + log_Q, log_q = self.log_prob_meta_proposal(prime_samples, log_j=log_j) samples["logQ"] = log_Q samples["logW"] = samples["logU"] - log_Q return samples, log_q @@ -693,7 +753,7 @@ def draw_from_flows( )[0] else: prime_samples[count : (count + m)] = self.flow.sample_ith( - id, N=m + str(id), N=m ) sample_its[count : (count + m)] = id count += m @@ -714,21 +774,29 @@ def draw_from_flows( finite, samples, prime_samples, log_j ) - log_q = np.zeros((samples.size, self.n_proposals)) + log_q = np.empty(samples.size, dtype=self.log_q_dtype) logger.debug("Computing log_q") - if self.n_proposals > 1: - log_q[:, 1:] = ( - self.flow.log_prob_all(prime_samples) + log_j[:, np.newaxis] - ) + for name in log_q.dtype.names: + log_prob_fn = self.get_proposal_log_prob(name, log_j=log_j) + log_q[name] = log_prob_fn(prime_samples) # -inf is okay since this is just zero, so only remove +inf or NaN - finite = ~np.isnan(log_q).all(axis=1) & ~np.isposinf(log_q).all(axis=1) + nan_all = np.ones(samples.size, dtype=bool) + posinf_all = np.ones(samples.size, dtype=bool) + for name in log_q.dtype.names: + nan_all &= np.isnan(log_q[name]) + posinf_all &= np.isposinf(log_q[name]) + finite = ~nan_all & ~posinf_all samples, log_q = get_subset_arrays(finite, samples, log_q) logger.debug( - f"Mean g for each each flow: {np.exp(log_q).mean(axis=0)}" + "Mean g for each each flow: " + f"{[np.exp(log_q[n]).mean() for n in log_q.dtype.names]}" + ) + logger.debug( + "Mean log_q for each each flow: " + f"{[log_q[n].mean() for n in log_q.dtype.names]}" ) - logger.debug(f"Mean log_q for each each flow: {log_q.mean(axis=0)}") samples["logP"] = self.model.batch_evaluate_log_prior( samples, unit_hypercube=True diff --git a/src/nessai/samplers/importancesampler.py b/src/nessai/samplers/importancesampler.py index f08e6fef..b2f6c158 100644 --- a/src/nessai/samplers/importancesampler.py +++ b/src/nessai/samplers/importancesampler.py @@ -6,11 +6,13 @@ import datetime import logging import os +from functools import partial from typing import Any, Callable, List, Literal, Optional, Union import matplotlib import matplotlib.pyplot as plt import numpy as np +import numpy.lib.recfunctions as rfn from scipy.special import logsumexp from .. import config @@ -57,6 +59,7 @@ def __init__( strict_threshold: bool = False, replace_all: bool = False, save_log_q: bool = False, + proposal: Optional[ImportanceFlowProposal] = None, ) -> None: self.samples = None self.log_q = None @@ -67,6 +70,7 @@ def __init__( self.state = _INSIntegralState() self.log_likelihood_threshold = None self.save_log_q = save_log_q + self.proposal = proposal @property def live_points(self) -> np.ndarray: @@ -229,23 +233,35 @@ def compute_importance(self, importance_ratio: float = 0.5): Dictionary containing the total, posterior and evidence importance as a function of iteration. """ - log_imp_post = -np.inf * np.ones(self.log_q.shape[1]) - log_imp_z = -np.inf * np.ones(self.log_q.shape[1]) - for i, it in enumerate(range(-1, self.log_q.shape[-1] - 1)): - sidx = np.where(self.samples["it"] == it)[0] - zidx = np.where(self.samples["it"] >= it)[0] - if len(sidx): - log_imp_post[i] = logsumexp( - self.samples["logL"][sidx] + self.samples["logW"][sidx] - ) - np.log(len(sidx)) - if len(zidx): - log_imp_z[i] = logsumexp( - self.samples["logL"][zidx] + self.samples["logW"][zidx] - ) - np.log(len(zidx)) - imp_z = np.exp(log_imp_z - logsumexp(log_imp_z)) - imp_post = np.exp(log_imp_post - logsumexp(log_imp_post)) - imp = (1 - importance_ratio) * imp_z + importance_ratio * imp_post - return {"total": imp, "posterior": imp_post, "evidence": imp_z} + qIDs = self.log_q.dtype.names + log_imp_post = {} + log_imp_post_indv = {} + for qID in qIDs: + qid_value = self.proposal.cast_qid(qID) + idx = self.samples["qID"] == qid_value + log_imp_post[qID] = log_evidence_from_ins_samples( + self.samples[idx] + ) + log_imp_post_indv[qID] = logsumexp( + self.samples["logL"] + self.samples["logU"] - self.log_q[qID] + ) - np.log(self.samples.size) + + imp_post = { + k: np.exp(v - logsumexp(np.fromiter(log_imp_post.values(), float))) + for k, v in log_imp_post.items() + } + imp_post_indv = { + k: np.exp( + v - logsumexp(np.fromiter(log_imp_post_indv.values(), float)) + ) + for k, v in log_imp_post_indv.items() + } + return { + "total": np.nan, + "posterior": imp_post, + "posterior_indv": imp_post_indv, + "evidence": np.nan, + } def compute_evidence_ratio( self, threshold: Optional[float] = None @@ -459,12 +475,14 @@ def __init__( strict_threshold=self.strict_threshold, replace_all=self.replace_all, save_log_q=self.save_log_q, + proposal=self.proposal, ) if self.draw_iid_live: self.iid_samples = OrderedSamples( strict_threshold=self.strict_threshold, replace_all=self.replace_all, save_log_q=self.save_log_q, + proposal=self.proposal, ) else: self.iid_samples = None @@ -645,8 +663,12 @@ def stopping_criteria(self) -> List[str]: @staticmethod def add_fields(): - """Add extra fields logW, logQ, logU""" - add_extra_parameters_to_live_points(["logW", "logQ", "logU"]) + """Add extra fields logW, logQ, logU, qID""" + add_extra_parameters_to_live_points( + parameters=["logW", "logQ", "logU", "qID"], + dtypes=["f8", "f8", "f8", "U8"], + default_values=[np.nan, np.nan, np.nan, "NULL"], + ) def configure_stopping_criterion( self, @@ -764,7 +786,7 @@ def populate_live_points(self) -> None: self.model.batch_evaluate_log_prior_unit_hypercube(live_points) ) live_points["logW"] = live_points["logU"] - live_points["logQ"] - log_q = np.zeros([live_points.size, 1]) + log_q = np.zeros(live_points.size, dtype=[("-1", "f8")]) if self.draw_iid_live: live_points, iid_samples = ( @@ -772,13 +794,13 @@ def populate_live_points(self) -> None: live_points[self.n_initial :], ) log_q, iid_log_q = ( - log_q[: self.n_initial, ...], - log_q[self.n_initial :, ...], + log_q[: self.n_initial], + log_q[self.n_initial :], ) self.iid_samples.add_initial_samples(iid_samples, iid_log_q) self.training_samples.add_initial_samples(live_points, log_q) - self.sample_counts[-1] = self.n_initial + self.sample_counts["-1"] = self.n_initial def initialise(self) -> None: """Initialise the nested sampler. @@ -1070,7 +1092,7 @@ def add_new_proposal(self): n_train: ].copy() self.current_training_log_q = self.training_samples.log_q[ - n_train:, : + n_train: ].copy() logger.info( @@ -1085,7 +1107,7 @@ def add_new_proposal(self): ) if self.replace_all: - weights = -np.exp(self.current_training_log_q[:, -1]) + weights = -np.exp(self.current_training_log_q[self.proposal_id]) elif self.weighted_kl: log_w = self.current_training_samples["logW"].copy() log_w -= logsumexp(log_w) @@ -1093,12 +1115,13 @@ def add_new_proposal(self): else: weights = None - self.proposal.train( + proposal_id = self.proposal.train( self.current_training_samples, plot=self.plot_training_data, weights=weights, ) self.training_time += datetime.datetime.now() - st + return proposal_id def draw_n_samples(self, n: int, **kwargs): """Draw n samples from the current proposal @@ -1184,7 +1207,9 @@ def add_and_update_points(self, n: int): self.compute_leakage(new_samples) ) - self._current_proposal_entropy = differential_entropy(-log_q[:, -1]) + self._current_proposal_entropy = differential_entropy( + -log_q[self.proposal_id] + ) logger.debug( f"New samples ESS: {effective_sample_size(new_samples['logW'])}" @@ -1204,7 +1229,7 @@ def add_and_update_points(self, n: int): ) self.training_samples.samples["logQ"] = ( - self.proposal.compute_meta_proposal_from_log_q( + self.proposal.log_prob_meta_proposal_from_log_q( self.training_samples.log_q ) ) @@ -1223,7 +1248,7 @@ def add_and_update_points(self, n: int): self.iid_samples.samples, self.iid_samples.log_q ) self.iid_samples.samples["logQ"] = ( - self.proposal.compute_meta_proposal_from_log_q( + self.proposal.log_prob_meta_proposal_from_log_q( self.iid_samples.log_q ) ) @@ -1296,7 +1321,7 @@ def adjust_final_samples(self, n_batches=5): else: new_samples, new_log_q = proposal.draw( n=(nc - c), - flow_number=it, + flow_id=it, update_counts=False, ) new_samples["it"] = it @@ -1327,7 +1352,14 @@ def adjust_final_samples(self, n_batches=5): batch_log_q = log_q[idx_keep] assert batch_samples.size == orig_n_total - log_Q = logsumexp(batch_log_q, b=norm_weight, axis=1) + weight_by_id = {str(it): w for it, w in zip(its, norm_weight)} + weights = np.array( + [weight_by_id[name] for name in batch_log_q.dtype.names], + dtype=float, + ) + log_Q = rfn.apply_along_fields( + partial(logsumexp, b=weights), batch_log_q + ) # Weights are normalised because the total number of samples is the # same. batch_samples["logQ"] = log_Q @@ -1433,7 +1465,7 @@ def log_state(self): """Log the state of the sampler""" logger.info( f"Update {self.iteration} - " - f"log Z: {self.log_evidence_error:.3f} +/- " + f"log Z: {self.log_evidence:.3f} +/- " f"{self.log_evidence_error:.3f} " f"ESS: {self.state.ess:.1f} " f"logL min: {self.live_points_unit['logL'].min():.3f} " @@ -1462,7 +1494,7 @@ def update_proposal_weights(self): """ n_total = len(self.samples_unit) new_weights = {k: v / n_total for k, v in self.sample_counts.items()} - self.proposal.update_proposal_weights(new_weights) + self.proposal.update_weights(new_weights) def update_sample_counts(self) -> None: """Update the sample counts for each proposal based on the current @@ -1472,28 +1504,28 @@ def update_sample_counts(self) -> None: See also: :code:`update_proposal_weights`. """ - counts = np.bincount( - self.samples_unit["it"] + 1, - minlength=(self.proposal.n_proposals), - ) - self.sample_counts = {it - 1: c for it, c in enumerate(counts)} + sample_counts = {} + for qid in self.proposal.weights.keys(): + qid_value = self.proposal.cast_qid(qid) if self.proposal else qid + sample_counts[qid] = np.sum(self.samples_unit["qID"] == qid_value) + self.sample_counts = sample_counts - def add_new_proposal_weight(self, iteration: int, n_new: int) -> None: + def add_new_proposal_weight(self, proposal_id: str, n_new: int) -> None: """Set the weights for a new proposal. Samples cannot have been drawn from the proposal already. """ if ( - iteration in self.sample_counts - and self.sample_counts[iteration] != 0 + proposal_id in self.sample_counts + and self.sample_counts[proposal_id] != 0 ): raise RuntimeError( - f"Samples already drawn from proposal {iteration}" + f"Samples already drawn from proposal {proposal_id}" ) n_total = len(self.samples_unit) + n_new - self.sample_counts[iteration] = n_new + self.sample_counts[proposal_id] = n_new new_weights = {k: v / n_total for k, v in self.sample_counts.items()} - self.proposal.update_proposal_weights(new_weights) + self.proposal.update_weights(new_weights) def nested_sampling_loop(self): """Main nested sampling loop.""" @@ -1523,14 +1555,14 @@ def nested_sampling_loop(self): n_removed = self.remove_samples() - self.add_new_proposal() + self.proposal_id = self.add_new_proposal() if self.draw_constant or self.replace_all: n_add = self.nlive else: n_add = n_removed - self.add_new_proposal_weight(self.iteration, n_add) + self.add_new_proposal_weight(self.proposal_id, n_add) self.add_and_update_points(n_add) @@ -1697,7 +1729,7 @@ def draw_final_samples( elif self.iid_samples: logger.warning("Already have i.i.d samples") - final_samples = OrderedSamples() + final_samples = OrderedSamples(proposal=self.proposal) eff = ( self.state.effective_n_posterior_samples @@ -1769,7 +1801,7 @@ def draw_final_samples( n_models = self.proposal.n_proposals samples = np.empty([0], dtype=self.proposal.dtype) - log_q = np.empty([0, n_models]) + log_q = np.empty(0, dtype=self.proposal.log_q_dtype) counts = np.zeros(n_models) it = 0 @@ -1814,7 +1846,9 @@ def draw_final_samples( samples = np.concatenate([samples, it_samples]) - log_Q = logsumexp(log_q, b=weights, axis=1) + log_Q = rfn.apply_along_fields( + partial(logsumexp, b=weights), log_q + ) if np.isposinf(log_Q).any(): logger.warning("Log meta proposal contains +inf") @@ -1970,10 +2004,10 @@ def plot_state( m += 1 - ax[m].plot(its, self.importance["total"][1:], label="Total") - ax[m].plot(its, self.importance["posterior"][1:], label="Posterior") - ax[m].plot(its, self.importance["evidence"][1:], label="Evidence") - ax[m].legend() + # ax[m].plot(its, self.importance["total"][1:], label="Total") + # ax[m].plot(its, self.importance["posterior"][1:], label="Posterior") + # ax[m].plot(its, self.importance["evidence"][1:], label="Evidence") + # ax[m].legend() ax[m].set_ylabel("Level importance") m += 1 diff --git a/src/nessai/utils/io.py b/src/nessai/utils/io.py index dcf44272..eb5f2b14 100644 --- a/src/nessai/utils/io.py +++ b/src/nessai/utils/io.py @@ -143,8 +143,22 @@ def encode_for_hdf5(value): Any Encoded value. """ + import h5py + if value is None: output = "__none__" + elif isinstance(value, np.ndarray): + if value.dtype.names is not None: + output = { + field: encode_for_hdf5(value[field]) + for field in value.dtype.names + } + elif value.dtype.char == "U": + output = value.astype( + h5py.string_dtype(encoding="utf-8", length=None) + ) + else: + output = value else: output = value return output @@ -165,10 +179,11 @@ def add_dict_to_hdf5_file(hdf5_file, path, d): The dictionary to save. """ for key, value in d.items(): - if isinstance(value, dict): - add_dict_to_hdf5_file(hdf5_file, path + key + "/", value) + encoded_value = encode_for_hdf5(value) + if isinstance(encoded_value, dict): + add_dict_to_hdf5_file(hdf5_file, path + key + "/", encoded_value) else: - hdf5_file[path + key] = encode_for_hdf5(value) + hdf5_file[path + key] = encoded_value def save_dict_to_hdf5(d, filename): diff --git a/src/nessai/utils/testing.py b/src/nessai/utils/testing.py index fb6d287d..aa7a7475 100644 --- a/src/nessai/utils/testing.py +++ b/src/nessai/utils/testing.py @@ -73,11 +73,14 @@ def assert_structured_arrays_equal(x, y, atol=0.0, rtol=0.0): valid = {f: False for f in x.dtype.names} max_diff = {f: np.nan for f in x.dtype.names} - for field in valid.keys(): - valid[field] = np.allclose( - x[field], y[field], equal_nan=True, atol=atol, rtol=rtol - ) - max_diff[field] = np.nanmax(x[field] - y[field]) + for field in x.dtype.names: + if x.dtype[field].char in ["U", "S"]: + valid[field] = (x[field] == y[field]).all() + else: + valid[field] = np.allclose( + x[field], y[field], equal_nan=True, atol=atol, rtol=rtol + ) + max_diff[field] = np.nanmax(x[field] - y[field]) if not all(valid.values()): mismatched = [k for k, v in valid.items() if v is False] diff --git a/tests/conftest.py b/tests/conftest.py index 3db0619a..22342211 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -112,9 +112,24 @@ def mp_context(request): def ins_parameters(): """Add (and remove) the standard INS parameters for the tests.""" # Before every test - add_extra_parameters_to_live_points(["logQ", "logW", "logU"]) - yield - reset_extra_live_points_parameters() + add_extra_parameters_to_live_points( + ["logQ", "logW", "logU", "qID"], + dtypes=["f8", "f8", "f8", "U8"], + default_values=[np.nan, np.nan, np.nan, "NULL"], + ) + try: + yield + finally: + reset_extra_live_points_parameters() + + +@pytest.fixture() +def reset_ins_parameters(): + """Reset the INS parameters after the test.""" + try: + yield + finally: + reset_extra_live_points_parameters() def pytest_configure(config): diff --git a/tests/test_flowmodel/test_flowmodel_importance.py b/tests/test_flowmodel/test_flowmodel_importance.py index d152cc13..93905430 100644 --- a/tests/test_flowmodel/test_flowmodel_importance.py +++ b/tests/test_flowmodel/test_flowmodel_importance.py @@ -40,41 +40,26 @@ def test_init(ifm, rng): output=output, rng=rng, ) - assert ifm.weights_files == [] + assert ifm.weights_files == {} assert len(ifm.models) == 0 def test_model_property_with_models(ifm): """Assert last model is returned""" - models = [DummyFlow(), DummyFlow(), DummyFlow()] - ifm.models = models - assert IFM.model.__get__(ifm) is models[-1] - - -def test_model_property_without_models(ifm): - """Assert None is return if no models have been defined""" - ifm.models = [] - assert IFM.model.__get__(ifm) is None + ifm._model = DummyFlow() + assert IFM.model.__get__(ifm) is ifm._model def test_model_setter(ifm): """Assert the model setter appends the model""" - new_model = DummyFlow() - models = [DummyFlow(), DummyFlow()] - ifm.models = models - IFM.model.__set__(ifm, new_model) - assert len(ifm.models) == 3 - assert ifm.models[-1] is new_model + IFM.model.__set__(ifm, None) + assert ifm._model is None -def test_model_setter_none(ifm): +def test_model_setter_error(ifm): """Assert none is not added to the models""" - new_model = None - models = [DummyFlow(), DummyFlow()] - ifm.models = models - IFM.model.__set__(ifm, new_model) - assert len(ifm.models) == 2 - assert None not in ifm.models + with pytest.raises(ValueError, match=r"Cannot set model directly"): + IFM.model.__set__(ifm, DummyFlow()) def test_n_models(ifm): @@ -113,10 +98,10 @@ def test_reset_optimiser(ifm): def test_add_new_flow_reset(ifm): """Assert a new flow is created when reset=True.""" flow = DummyFlow() - ifm.models = torch.nn.ModuleList([DummyFlow(), DummyFlow()]) ifm.training_config = dict(device="cpu", inference_device_tag=None) ifm.flow_config = dict(n_neurons=4) ifm.reset_optimiser = MagicMock() + ifm.add_model = MagicMock() with patch( "nessai.flowmodel.importance.configure_model", @@ -124,8 +109,7 @@ def test_add_new_flow_reset(ifm): ) as mock_configure: IFM.add_new_flow(ifm, reset=True) - assert len(ifm.models) == 3 - assert ifm.models[-1] is flow + ifm.add_model.assert_called_once_with(flow) assert ifm.models.training is False mock_configure.assert_called_once_with(ifm.flow_config) @@ -146,10 +130,10 @@ def test_add_new_flow_no_reset(ifm): ifm.device = device ifm.model = flow_to_copy - ifm.models = torch.nn.ModuleList([DummyFlow(), flow_to_copy]) ifm.flow_config = dict(n_neurons=4) ifm.training_config = dict(patience=20) ifm.reset_optimiser = MagicMock() + ifm.add_model = MagicMock() with patch( "copy.deepcopy", @@ -157,8 +141,7 @@ def test_add_new_flow_no_reset(ifm): ) as mock_copy: IFM.add_new_flow(ifm, reset=False) - assert len(ifm.models) == 3 - assert ifm.models[-1] is flow + ifm.add_model.assert_called_once_with(flow) assert ifm.models.training is False mock_copy.assert_called_once_with(ifm.models[-2]) @@ -176,6 +159,7 @@ def test_add_new_flow_first_flow(ifm): ifm.training_config = dict(device="cpu", inference_device_tag=None) ifm.flow_config = dict(n_neurons=4) ifm.reset_optimiser = MagicMock() + ifm.add_model = MagicMock() with patch( "nessai.flowmodel.importance.configure_model", @@ -183,8 +167,7 @@ def test_add_new_flow_first_flow(ifm): ) as mock_configure: IFM.add_new_flow(ifm, reset=False) - assert len(ifm.models) == 1 - assert ifm.models[-1] is flow + ifm.add_model.assert_called_once_with(flow) assert ifm.models.training is False mock_configure.assert_called_once_with(ifm.flow_config) @@ -351,53 +334,41 @@ def test_load_all_weights(ifm): model.load_state_dict.assert_called_once_with(w) -@pytest.mark.parametrize("n", [None, 10, 16]) -def test_update_weights_path(ifm, tmp_path, n): +@pytest.mark.parametrize("keys", [None, [1, 4, 6, 7, 9]]) +def test_update_weights_path(ifm, tmp_path, keys): """Assert the list of weights files is correctly updated""" path = tmp_path / "outdir" path.mkdir() - n_models = 15 - n_total = 16 - expected_files = [] + n_total = 10 + expected_files = {} for i in range(n_total): d = path / f"level_{i}" d.mkdir() file = d / "model.pt" - file.write_text("data") + file.touch() file = str(file) - expected_files.append(file) - - ifm.n_models = n_models - IFM.update_weights_path(ifm, str(path), n=n) - n_expeceted = n or n_models - assert ifm.weights_files == expected_files[:n_expeceted] + if keys and i not in keys: + continue + expected_files[str(i)] = file - -def test_update_weights_path_cannot_update(ifm): - """Assert the list of weights files is correctly updated""" - ifm.n_models = 0 - with pytest.raises(RuntimeError, match=r"n is None and .*"): - IFM.update_weights_path(ifm, ".", n=None) + IFM.update_weights_path(ifm, str(path), keys=keys) + assert ifm.weights_files == expected_files -def test_update_weights_path_not_enough_files(ifm, tmp_path): +def test_update_weights_path_invalid_keys(ifm, tmp_path): """Assert the list of weights files is correctly updated""" path = tmp_path / "outdir" path.mkdir() - n_models = 5 n_total = 4 - expected_files = [] for i in range(n_total): d = path / f"level_{i}" d.mkdir() file = d / "model.pt" - file.write_text("data") + file.touch() file = str(file) - expected_files.append(file) - ifm.n_models = n_models with pytest.raises(RuntimeError, match=r".* Not enough files."): - IFM.update_weights_path(ifm, str(path)) + IFM.update_weights_path(ifm, str(path), keys=[1, 4, 6, 7, 9]) @pytest.mark.parametrize("weights_path", [None, "weights_directory"]) diff --git a/tests/test_flowsampler.py b/tests/test_flowsampler.py index 28b1909a..a35c9e39 100644 --- a/tests/test_flowsampler.py +++ b/tests/test_flowsampler.py @@ -1141,6 +1141,7 @@ def test_save_results_integration( @pytest.mark.integration_test @pytest.mark.parametrize("ins", [True, False]) +@pytest.mark.usefixtures("reset_ins_parameters") def test_resume_from_data_integration( integration_model, tmp_path, caplog, ins ): diff --git a/tests/test_livepoint.py b/tests/test_livepoint.py index f12baf32..fe750315 100644 --- a/tests/test_livepoint.py +++ b/tests/test_livepoint.py @@ -46,12 +46,26 @@ def non_sampling_parameters(request): return request.param -@pytest.fixture(autouse=True, params=[[], ["logQ", "logW", "logU"]]) +@pytest.fixture( + autouse=True, + params=[ + ([], None, None), + ( + ["logQ", "logW", "logU", "qID"], + ["f8", "f8", "f8", "U8"], + [np.nan, np.nan, np.nan, "NULL"], + ), + ], +) def extra_parameters(request): """Add (and remove) extra parameters for the tests.""" # Before every test lp.reset_extra_live_points_parameters() - lp.add_extra_parameters_to_live_points(request.param) + lp.add_extra_parameters_to_live_points( + request.param[0], + dtypes=request.param[1], + default_values=request.param[2], + ) global EXTRA_PARAMS_DTYPE EXTRA_PARAMS_DTYPE = [ (nsp, d) @@ -62,17 +76,18 @@ def extra_parameters(request): ] # Test happens here - yield - - # Called after every test - lp.reset_extra_live_points_parameters() - EXTRA_PARAMS_DTYPE = [ - (nsp, d) - for nsp, d in zip( - config.livepoints.non_sampling_parameters, - config.livepoints.non_sampling_dtype, - ) - ] + try: + yield + finally: + # Called after every test + lp.reset_extra_live_points_parameters() + EXTRA_PARAMS_DTYPE = [ + (nsp, d) + for nsp, d in zip( + config.livepoints.non_sampling_parameters, + config.livepoints.non_sampling_dtype, + ) + ] @pytest.fixture(params=["f4", "f16"]) @@ -90,9 +105,10 @@ def change_dtype(request): current_dtype = config.livepoints.default_float_dtype config.livepoints.default_float_dtype = dtype - yield dtype - - config.livepoints.default_float_dtype = current_dtype + try: + yield dtype + finally: + config.livepoints.default_float_dtype = current_dtype @pytest.fixture @@ -190,7 +206,10 @@ def test_empty_structured_array_names(non_sampling_parameters): config.livepoints.non_sampling_defaults, ): if non_sampling_parameters: - np.testing.assert_array_equal(array[nsp], v * np.ones(n)) + if array[nsp].dtype != np.dtype("f8"): + assert (array[nsp] == v).all() + else: + np.testing.assert_array_equal(array[nsp], v * np.ones(n)) else: assert nsp not in array.dtype.names @@ -423,10 +442,10 @@ def test_multiple_live_points_to_dict(live_points): """ Test conversion of multiple_live points to a dictionary """ - d = {"x": [1, 4], "y": [2, 5], "z": [3, 6]} + d = {"x": np.array([1, 4]), "y": np.array([2, 5]), "z": np.array([3, 6])} d.update( { - k: 2 * [v] + k: np.array(2 * [v]) for k, v in zip( config.livepoints.non_sampling_parameters, config.livepoints.non_sampling_defaults, @@ -435,7 +454,15 @@ def test_multiple_live_points_to_dict(live_points): ) d_out = lp.live_points_to_dict(live_points) assert list(d.keys()) == list(d_out.keys()) - np.testing.assert_array_equal(list(d.values()), list(d_out.values())) + for k in d.keys(): + if d[k].dtype != np.dtype("f8"): + assert (d[k] == d_out[k]).all() + else: + np.testing.assert_array_equal(d[k], d_out[k]) + # np.testing.assert_array_equal(list(d.values()), list(d_out.values())) + + +# def test_unstructured_view_dtype(live_points): diff --git a/tests/test_plot.py b/tests/test_plot.py index 7f7447b4..ed2ed502 100644 --- a/tests/test_plot.py +++ b/tests/test_plot.py @@ -35,8 +35,10 @@ def nested_samples(live_points): @pytest.fixture(autouse=True) def auto_close_figures(): """Automatically close all figures after each test""" - yield - plt.close("all") + try: + yield + finally: + plt.close("all") @pytest.mark.parametrize("line_styles", [True, False]) diff --git a/tests/test_proposal/test_importance/test_prob.py b/tests/test_proposal/test_importance/test_prob.py index 8e76d6d8..f58a9710 100644 --- a/tests/test_proposal/test_importance/test_prob.py +++ b/tests/test_proposal/test_importance/test_prob.py @@ -18,17 +18,23 @@ def ifp(ifp): def test_update_proposal_weights(ifp): - ifp._weights = {-1: 0.5, 1: 0.5} - weights = {-1: 1 / 3, 0: 1 / 3, 1: 1 / 3} - IFP.update_proposal_weights(ifp, weights) + ifp._weights = {"-1": 0.5, "1": 0.5} + ifp.weights = ifp._weights + ifp.log_q_dtype = np.dtype([("-1", "f8"), ("0", "f8"), ("1", "f8")]) + ifp.weights_array = IFP.weights_array.__get__(ifp, IFP) + weights = {"-1": 1 / 3, "0": 1 / 3, "1": 1 / 3} + IFP.update_weights(ifp, weights) assert ifp._weights == weights def test_update_proposal_weights_vaild(ifp): - ifp._weights = {-1: 0.5, 1: 0.5} - weights = {-1: 0.33, 0: 0.33, 1: 0.33} + ifp._weights = {"-1": 0.5, "1": 0.5} + ifp.weights = ifp._weights + ifp.log_q_dtype = np.dtype([("-1", "f8"), ("0", "f8"), ("1", "f8")]) + ifp.weights_array = IFP.weights_array.__get__(ifp, IFP) + weights = {"-1": 0.33, "0": 0.33, "1": 0.33} with pytest.raises(RuntimeError, match="Weights must sum to 1!"): - IFP.update_proposal_weights(ifp, weights) + IFP.update_weights(ifp, weights) def test_initial_log_prob(ifp): @@ -42,57 +48,112 @@ def test_get_proposal_log_prob_initial(ifp): assert func is ifp._log_prob_initial -def test_compute_log_Q(ifp, x_prime): +def test_log_prob_meta_proposal(ifp, x_prime): n_flows = 3 - ifp.weights_array = np.array([0.25, 0.25, 0.25, 0.25]) + ifp._weights = {"-1": 0.25, "0": 0.25, "1": 0.25, "2": 0.25} + ifp.weights = ifp._weights + ifp.weights_array = MagicMock( + return_value=np.array([0.25, 0.25, 0.25, 0.25]) + ) ifp.flow.n_models = n_flows + ifp.flow.models = { + "0": MagicMock(training=False), + "1": MagicMock(training=False), + "2": MagicMock(training=False), + } + ifp.log_q_dtype = np.dtype( + [("-1", "f8"), ("0", "f8"), ("1", "f8"), ("2", "f8")] + ) + ifp.get_proposal_log_prob = MagicMock( + side_effect=lambda it, log_j=None: ( + (lambda x: np.zeros(len(x))) + if it == "-1" + else (lambda x: np.log(np.random.rand(len(x)))) + ) + ) + ifp.log_prob_meta_proposal_from_log_q = ( + IFP.log_prob_meta_proposal_from_log_q.__get__(ifp, IFP) + ) ifp.n_proposals = n_flows + 1 log_j = np.log(np.random.rand(len(x_prime))) - def log_prob_all(x): - return np.log(np.random.rand(len(x), n_flows)) + def log_prob_ith(x, it): + return np.log(np.random.rand(len(x))) - ifp.flow.log_prob_all = MagicMock(side_effect=log_prob_all) + ifp.flow.log_prob_ith = MagicMock(side_effect=log_prob_ith) - log_Q, log_q = IFP.compute_log_Q(ifp, x_prime, log_j=log_j) + log_Q, log_q = IFP.log_prob_meta_proposal(ifp, x_prime, log_j=log_j) assert len(log_Q) == len(x_prime) - assert log_q.shape == (len(x_prime), n_flows + 1) - assert all(log_q[:, 0] == 0) + assert log_q.shape == (len(x_prime),) + assert log_q.dtype.names == ("-1", "0", "1", "2") + assert np.all(log_q["-1"] == 0) - expected_log_Q = logsumexp(log_q, b=ifp.weights_array, axis=1) + log_q_values = np.column_stack([log_q[name] for name in log_q.dtype.names]) + expected_log_Q = logsumexp( + log_q_values, b=ifp.weights_array(log_q.dtype.names), axis=1 + ) np.testing.assert_array_equal(log_Q, expected_log_Q) -def test_compute_log_Q_weights_not_set(ifp, x_prime): +def test_log_prob_meta_proposal_weights_not_set(ifp, x_prime): """Assert an error is raised if a flow is in training mode""" n_flows = 3 - ifp.weights = {i - 1: v for i, v in enumerate([0.25, 0.25, 0.25, np.nan])} + ifp._weights = { + "-1": 0.25, + "0": 0.25, + "1": 0.25, + "2": np.nan, + } + ifp.weights = ifp._weights ifp.flow.n_models = n_flows + ifp.log_q_dtype = np.dtype( + [("-1", "f8"), ("0", "f8"), ("1", "f8"), ("2", "f8")] + ) + ifp.get_proposal_log_prob = MagicMock( + side_effect=lambda it, log_j=None: ( + (lambda x: np.zeros(len(x))) + if it == "-1" + else (lambda x: np.log(np.random.rand(len(x)))) + ) + ) ifp.n_proposals = n_flows + 1 log_j = np.log(np.random.rand(len(x_prime))) with pytest.raises(RuntimeError, match="Some weights are not set!"): - IFP.compute_log_Q(ifp, x_prime, log_j=log_j) + IFP.log_prob_meta_proposal(ifp, x_prime, log_j=log_j) -def test_compute_log_Q_flow_training(ifp, x_prime): +def test_log_prob_meta_proposal_flow_training(ifp, x_prime): """Assert an error is raised if a flow is in training mode""" n_flows = 3 - ifp.weights_array = np.array([0.25, 0.25, 0.25, 0.25]) + ifp._weights = {"-1": 0.25, "0": 0.25, "1": 0.25, "2": 0.25} + ifp.weights = ifp._weights + ifp.weights_array = MagicMock( + return_value=np.array([0.25, 0.25, 0.25, 0.25]) + ) ifp.flow.n_models = n_flows - ifp.flow.models = [] - for _ in range(n_flows): - mock_model = MagicMock() - mock_model.training = False - ifp.flow.models.append(mock_model) - ifp.flow.models[-1].training = True + ifp.flow.models = { + "0": MagicMock(training=False), + "1": MagicMock(training=False), + "2": MagicMock(training=True), + } + ifp.log_q_dtype = np.dtype( + [("-1", "f8"), ("0", "f8"), ("1", "f8"), ("2", "f8")] + ) + ifp.get_proposal_log_prob = MagicMock( + side_effect=lambda it, log_j=None: ( + (lambda x: np.zeros(len(x))) + if it == "-1" + else (lambda x: np.log(np.random.rand(len(x)))) + ) + ) ifp.n_proposals = n_flows + 1 log_j = np.log(np.random.rand(len(x_prime))) with pytest.raises( RuntimeError, match="One or more flows are in training mode" ): - IFP.compute_log_Q(ifp, x_prime, log_j=log_j) + IFP.log_prob_meta_proposal(ifp, x_prime, log_j=log_j) @pytest.mark.parametrize("p_it, q_it", [(None, None), (-1, 0), (3, 4)]) @@ -104,7 +165,7 @@ def rescale(x): def get_proposal_log_prob(it): def log_prob(x): - if it == -1: + if it == "-1": return np.zeros(len(x)) else: return np.log(np.random.rand(len(x))) @@ -112,6 +173,7 @@ def log_prob(x): return log_prob ifp.flow.n_models = 15 + ifp.flow.models = {str(i): MagicMock(training=False) for i in range(15)} ifp.rescale = MagicMock(side_effect=rescale) ifp.get_proposal_log_prob = MagicMock(side_effect=get_proposal_log_prob) @@ -123,17 +185,21 @@ def log_prob(x): def test_update_log_q(ifp, model, x): n_proposals = 5 - ifp.level_count = 4 + ifp._proposal_count = 4 + ifp.proposal_id = "4" - log_q = np.log(np.random.rand(len(x), n_proposals - 1)) + names = ["-1", "0", "1", "2"] + log_q = np.empty(len(x), dtype=[(name, "f8") for name in names]) + for name in names: + log_q[name] = np.log(np.random.rand(len(x))) def rescale(x): x = model.to_unit_hypercube(x) x = live_points_to_array(x, model.names) return x, np.zeros(x.shape[0]) - def get_proposal_log_prob(it): - assert it == 4 + def get_proposal_log_prob(it, log_j=None): + assert it == "4" def log_prob(x): return np.log(np.random.rand(len(x))) @@ -147,13 +213,17 @@ def log_prob(x): log_q_out = IFP.update_log_q(ifp, x, log_q) - assert log_q_out.shape == (len(x), n_proposals) + assert log_q_out.shape == (len(x),) + assert log_q_out.dtype.names == ("-1", "0", "1", "2", "4") def test_compute_meta_proposal_from_log_q(ifp): n = 100 n_prop = 10 - log_q = np.log(np.random.rand(n, n_prop)) + names = [str(i - 1) for i in range(n_prop)] + log_q = np.empty(n, dtype=[(name, "f8") for name in names]) + for name in names: + log_q[name] = np.log(np.random.rand(n)) poolsize = np.random.multinomial( n_prop, @@ -161,15 +231,12 @@ def test_compute_meta_proposal_from_log_q(ifp): size=n, ) weights = poolsize / np.sum(poolsize) - ifp.weights_array = weights + ifp.weights_array = MagicMock(return_value=weights) - expected = logsumexp( - log_q, - b=weights, - axis=1, - ) + log_q_values = np.column_stack([log_q[name] for name in log_q.dtype.names]) + expected = logsumexp(log_q_values, b=weights, axis=1) - out = IFP.compute_meta_proposal_from_log_q(ifp, log_q) + out = IFP.log_prob_meta_proposal_from_log_q(ifp, log_q) assert len(out) == len(log_q) np.testing.assert_array_equal(out, expected) @@ -177,33 +244,37 @@ def test_compute_meta_proposal_from_log_q(ifp): @pytest.mark.usefixtures("ins_parameters") def test_compute_meta_proposal_samples(ifp, x, x_prime, log_j): - ifp.level_count = 2 - ifp.weights = {-1: 0.25, 0: 0.25, 1: 0.25, 2: 0.25} + ifp._proposal_count = 2 + ifp._weights = {"-1": 0.25, "0": 0.25, "1": 0.25, "2": 0.25} + ifp.weights = ifp._weights + ifp.proposal_id = "2" x["logQ"] = np.nan x["logW"] = np.nan log_Q = np.log(np.random.rand(len(x))) - log_q = np.log(np.random.rand(len(x), 10)) + log_q = np.empty(len(x), dtype=[(str(i), "f8") for i in range(10)]) + for name in log_q.dtype.names: + log_q[name] = np.log(np.random.rand(len(x))) ifp.rescale = MagicMock(return_value=(x_prime, log_j)) - ifp.compute_log_Q = MagicMock(return_value=(log_Q, log_q)) + ifp.log_prob_meta_proposal = MagicMock(return_value=(log_Q, log_q)) log_Q_out, log_q_out = IFP.compute_meta_proposal_samples(ifp, x) ifp.rescale.assert_called_once_with(x) - ifp.compute_log_Q.assert_called_once_with(x_prime, log_j=log_j) + ifp.log_prob_meta_proposal.assert_called_once_with(x_prime, log_j=log_j) np.testing.assert_array_equal(log_Q_out, log_Q) np.testing.assert_array_equal(log_q_out, log_q) @pytest.mark.parametrize( - "weights", [{-1: 0.5, 0: 0.5}, {-1: 0.5, 0: 0.5, 1: np.nan}] + "weights", [{"-1": 0.5, "0": 0.5}, {"-1": 0.5, "0": 0.5, "1": np.nan}] ) @pytest.mark.usefixtures("ins_parameters") def test_compute_meta_proposal_samples_weights_error(ifp, x, weights): - ifp.level_count = 1 - ifp.weights = weights + ifp._proposal_count = 1 + ifp._weights = weights with pytest.raises(RuntimeError, match=r"Weight\(s\) missing or not set."): IFP.compute_meta_proposal_samples(ifp, x) diff --git a/tests/test_proposal/test_importance/test_resume.py b/tests/test_proposal/test_importance/test_resume.py index 0777e75d..f75937be 100644 --- a/tests/test_proposal/test_importance/test_resume.py +++ b/tests/test_proposal/test_importance/test_resume.py @@ -39,12 +39,11 @@ def test_getstate_integration(tmp_path, model): weighted_kl=False, ) ifp.initialise() - weights = {-1: 1.0} for i in range(4): ifp.train(model.new_point(10), max_epochs=2) - weights = {j - 1: 1 / (i + 2) for j in range(i + 2)} - ifp.update_proposal_weights(weights) + weights = {str(j - 1): 1 / (i + 2) for j in range(i + 2)} + ifp.update_weights(weights) ifp.draw(10) out = pickle.dumps(ifp) diff --git a/tests/test_proposal/test_importance/test_sampling.py b/tests/test_proposal/test_importance/test_sampling.py index f8bb84a8..a276cdcc 100644 --- a/tests/test_proposal/test_importance/test_sampling.py +++ b/tests/test_proposal/test_importance/test_sampling.py @@ -20,16 +20,21 @@ def test_draw_from_prior(ifp, n, model): x = numpy_array_to_live_points(np.random.rand(n, 2), names=model.names) log_j = np.random.rand(n) log_Q = np.random.randn(n) - log_q = np.random.randn(n_proposals, n) + log_q = np.empty(n, dtype=[(str(i - 1), "f8") for i in range(n_proposals)]) + for name in log_q.dtype.names: + log_q[name] = np.random.randn(n) ifp.model.sample_unit_hypercube = MagicMock(return_value=x) ifp.model.batch_evaluate_log_prior_unit_hypercube = MagicMock( return_value=np.zeros(n) ) ifp.rescale = MagicMock(return_value=(x_prime, log_j)) - ifp.compute_log_Q = MagicMock(return_value=(log_Q, log_q)) + ifp.log_prob_meta_proposal = MagicMock(return_value=(log_Q, log_q)) x_out, log_q_out = IFP.draw_from_prior(ifp, n) - assert log_q_out.shape == (n_proposals, n) + assert log_q_out.shape == (n,) + assert log_q_out.dtype.names == tuple( + str(i - 1) for i in range(n_proposals) + ) assert x_out is x @@ -60,23 +65,40 @@ def rescale(x): x = live_points_to_array(x, model.names) return x, np.zeros(x.shape[0]) - def log_prob_all(x): - return np.log(np.random.rand(len(x), n_flows)) + def log_prob_ith(x, it): + return np.log(np.random.rand(len(x))) ifp.model = model + ifp.model.in_unit_hypercube = MagicMock( + side_effect=lambda s: np.ones(s.size, dtype=bool) + ) ifp.flow = create_autospec(ImportanceFlowModel) ifp.flow.sample_ith = MagicMock(side_effect=sample_ith) - ifp.flow.log_prob_all = MagicMock(side_effect=log_prob_all) + ifp.flow.log_prob_ith = MagicMock(side_effect=log_prob_ith) ifp.to_prime = MagicMock(side_effect=to_prime) ifp.inverse_rescale = MagicMock(side_effect=inverse_rescale) ifp.rescale = MagicMock(side_effect=rescale) + ifp._weights = { + str(i - 1): 1.0 / (n_flows + 1) for i in range(n_flows + 1) + } + ifp.log_q_dtype = np.dtype( + [(str(i - 1), "f8") for i in range(n_flows + 1)] + ) + ifp.get_proposal_log_prob = MagicMock( + side_effect=lambda it, log_j=None: ( + (lambda x: np.zeros(len(x))) + if it == "-1" + else (lambda x: np.log(np.random.rand(len(x)))) + ) + ) x, log_q, actual_counts = IFP.draw_from_flows( ifp, n, weights=weights, counts=counts ) assert len(x) == n - assert log_q.shape == (n, n_flows + 1) - assert all(log_q[:, 0] == 0) + assert log_q.shape == (n,) + assert log_q.dtype.names == tuple(str(i - 1) for i in range(n_flows + 1)) + assert np.all(log_q["-1"] == 0) assert np.isfinite(x["logP"]).all() assert np.isnan(x["logL"]).all() @@ -91,39 +113,55 @@ def test_draw(ifp, model): n_proposals = 5 n_draw = 100 ifp.n_proposals = n_proposals - ifp.level_count = n_proposals - 1 + ifp._proposal_count = n_proposals - 1 + ifp.proposal_id = str(n_proposals - 1) ifp.model = model ifp.dtype = model.new_point().dtype - ifp._weights = {-1: 0.2, 0: 0.2, 2: 0.2, 3: 0.4, 4: np.nan} - ifp.weights_array = np.fromiter(ifp._weights.values(), float) + ifp._weights = {"-1": 0.2, "0": 0.2, "2": 0.2, "3": 0.4, "4": np.nan} + ifp.log_q_dtype = np.dtype( + [(str(i - 1), "f8") for i in range(n_proposals)] + ) def inverse_rescale(x): x = numpy_array_to_live_points(x, model.names) - return model.from_unit_hypercube(x), np.zeros(x.size) + return x, np.zeros(x.size) def rescale(x): - x = model.to_unit_hypercube(x) x = live_points_to_array(x, model.names) return x, np.zeros(x.shape[0]) def sample_ith(i, N): - assert i == (n_proposals - 1) + assert i == str(n_proposals - 1) return np.random.rand(N, model.dims) - def compute_log_Q(x_prime, log_j=None, n=None): - log_q = ( - np.log(np.random.rand(len(x_prime), n_proposals)) + log_j[:, None] - ) - log_Q = logsumexp(log_q, b=ifp.weights_array, axis=1) + def log_prob_meta_proposal(x_prime, log_j=None): + names = [str(i - 1) for i in range(n_proposals)] + log_q = np.empty(len(x_prime), dtype=[(n, "f8") for n in names]) + for name in names: + log_q[name] = np.log(np.random.rand(len(x_prime))) + log_q_values = np.column_stack([log_q[name] for name in names]) + weights = np.fromiter(ifp._weights.values(), float) + log_Q = logsumexp(log_q_values, b=weights, axis=1) return log_Q, log_q ifp.rescale = rescale ifp.inverse_rescale = inverse_rescale - ifp.compute_log_Q = compute_log_Q + ifp.log_prob_meta_proposal = log_prob_meta_proposal ifp.flow = create_autospec(ImportanceFlowModel) ifp.flow.sample_ith = sample_ith + ifp.model.in_unit_hypercube = MagicMock( + side_effect=lambda s: np.ones(len(s), dtype=bool) + ) + ifp.model.batch_evaluate_log_prior = MagicMock( + side_effect=lambda s, unit_hypercube=True: np.zeros(len(s)) + ) + ifp.model.batch_evaluate_log_prior_unit_hypercube = MagicMock( + side_effect=lambda s: np.zeros(len(s)) + ) + ifp.qid_dtype = np.dtype("U8") + ifp.cast_qid = IFP.cast_qid.__get__(ifp, IFP) samples_out, log_q_out = IFP.draw(ifp, n_draw) assert len(samples_out) == n_draw - assert log_q_out.shape == (n_draw, n_proposals) + assert log_q_out.shape == (n_draw,) diff --git a/tests/test_proposal/test_importance/test_training.py b/tests/test_proposal/test_importance/test_training.py index 1dbe5598..7c2ae3ba 100644 --- a/tests/test_proposal/test_importance/test_training.py +++ b/tests/test_proposal/test_importance/test_training.py @@ -169,13 +169,13 @@ def test_training_and_prob(model, tmp_path): weighted_kl=False, ) ifp.initialise() - weights = {-1: 1.0} for i in range(4): ifp.train(model.new_point(10), max_epochs=2) - weights = {j - 1: 1 / (i + 2) for j in range(i + 2)} - ifp.update_proposal_weights(weights) + weights = {str(j - 1): 1 / (i + 2) for j in range(i + 2)} + ifp.update_weights(weights) x, _ = ifp.draw(10) log_Q, log_q = ifp.compute_meta_proposal_samples(x) assert len(log_Q) == 10 - assert log_q.shape == (10, 5) + assert log_q.shape == (10,) + assert log_q.dtype.names == tuple(weights.keys()) diff --git a/tests/test_samplers/test_importance_nested_sampler/conftest.py b/tests/test_samplers/test_importance_nested_sampler/conftest.py index 99ea4a9b..b00a90bf 100644 --- a/tests/test_samplers/test_importance_nested_sampler/conftest.py +++ b/tests/test_samplers/test_importance_nested_sampler/conftest.py @@ -13,6 +13,11 @@ ) +@pytest.fixture(autouse=True) +def reset_ins_parameters(reset_ins_parameters): + """Reset the INS parameters before each test.""" + + @pytest.fixture(scope="module", params=[False, True]) def iid(request): return request.param @@ -48,6 +53,8 @@ def n_samples(): def samples(model, n_samples, n_it, log_q, ins_parameters): x = model.sample_unit_hypercube(n_samples) x["it"] = np.random.randint(-1, n_it - 1, size=len(x)) + qids = np.array(log_q.dtype.names) + x["qID"] = np.random.choice(qids, size=len(x)) xx = model.from_unit_hypercube(x) x["logL"] = model.log_likelihood(xx) x["logP"] = model.log_prior(xx) @@ -56,14 +63,19 @@ def samples(model, n_samples, n_it, log_q, ins_parameters): x["it"] + np.abs(x["it"].min()), minlength=n_it ).astype(float) alpha /= alpha.sum() - x["logQ"] = logsumexp(log_q, axis=1, b=alpha) + log_q_values = np.column_stack([log_q[name] for name in log_q.dtype.names]) + x["logQ"] = logsumexp(log_q_values, axis=1, b=alpha) x["logW"] = -x["logQ"].copy() return x @pytest.fixture def log_q(n_samples, n_it): - return np.random.randn(n_samples, n_it) + names = [str(i - 1) for i in range(n_it)] + log_q = np.empty(n_samples, dtype=[(name, "f8") for name in names]) + for name in names: + log_q[name] = np.random.randn(n_samples) + return log_q @pytest.fixture diff --git a/tests/test_samplers/test_importance_nested_sampler/test_ordered_samples.py b/tests/test_samplers/test_importance_nested_sampler/test_ordered_samples.py index 1bd2bd04..0bb78cfa 100644 --- a/tests/test_samplers/test_importance_nested_sampler/test_ordered_samples.py +++ b/tests/test_samplers/test_importance_nested_sampler/test_ordered_samples.py @@ -279,11 +279,21 @@ def add_to_ns(indices): def test_compute_importance(ordered_samples, log_q, samples, ratio): ordered_samples.samples = samples ordered_samples.log_q = log_q + ordered_samples.proposal = MagicMock() + ordered_samples.proposal.cast_qid = lambda qid: qid out = OrderedSamples.compute_importance( ordered_samples, importance_ratio=ratio ) - assert len(set(out.keys()) - {"total", "posterior", "evidence"}) == 0 - assert np.all(np.isfinite(list(out.values()))) + assert ( + len( + set(out.keys()) + - {"total", "posterior", "posterior_indv", "evidence"} + ) + == 0 + ) + assert np.isnan(out["total"]) + assert np.isnan(out["evidence"]) + assert np.all(np.isfinite(list(out["posterior"].values()))) @pytest.mark.parametrize("threshold", [None, -10.0]) @@ -317,7 +327,9 @@ def test_computed_evidence_ratio(ordered_samples, samples, threshold): @pytest.mark.parametrize("save_log_q", [False, True]) def test_getstate(ordered_samples, save_log_q): samples = np.random.randn(20, 4) - log_q = np.random.randn(2, 20) + log_q = np.empty(20, dtype=[("-1", "f8"), ("0", "f8")]) + log_q["-1"] = np.random.randn(20) + log_q["0"] = np.random.randn(20) ordered_samples.save_log_q = save_log_q ordered_samples.log_q = log_q ordered_samples.samples = samples diff --git a/tests/test_samplers/test_importance_nested_sampler/test_proposal.py b/tests/test_samplers/test_importance_nested_sampler/test_proposal.py index 5b9de144..00ec3f3a 100644 --- a/tests/test_samplers/test_importance_nested_sampler/test_proposal.py +++ b/tests/test_samplers/test_importance_nested_sampler/test_proposal.py @@ -86,20 +86,18 @@ def test_draw_n_samples(ins, samples, log_q, history): def test_update_proposal_weights(ins): ins.samples_unit = np.ones(10) - ins.sample_counts = {-1: 2, 0: 4, 1: 4} + ins.sample_counts = {"-1": 2, "0": 4, "1": 4} ins.proposal = MagicMock(spec=ImportanceFlowProposal) INS.update_proposal_weights(ins) - expected_weights = {-1: 0.2, 0: 0.4, 1: 0.4} - ins.proposal.update_proposal_weights.assert_called_once_with( - expected_weights - ) + expected_weights = {"-1": 0.2, "0": 0.4, "1": 0.4} + ins.proposal.update_weights.assert_called_once_with(expected_weights) def test_add_new_proposal_weight(ins): n = 8 n_new = 2 - sample_counts = {-1: 2, 0: 3, 1: 3} - iteration = 2 + sample_counts = {"-1": 2, "0": 3, "1": 3} + iteration = "2" ins.samples_unit = np.ones(n) ins.sample_counts = sample_counts @@ -107,18 +105,16 @@ def test_add_new_proposal_weight(ins): INS.add_new_proposal_weight(ins, iteration, n_new) - assert ins.sample_counts[2] == 2 - expected_weights = {-1: 0.2, 0: 0.3, 1: 0.3, 2: 0.2} - ins.proposal.update_proposal_weights.assert_called_once_with( - expected_weights - ) + assert ins.sample_counts["2"] == 2 + expected_weights = {"-1": 0.2, "0": 0.3, "1": 0.3, "2": 0.2} + ins.proposal.update_weights.assert_called_once_with(expected_weights) def test_add_new_proposal_weight_error(ins): n = 8 n_new = 2 - sample_counts = {-1: 2, 0: 3, 1: 3, 2: 2} - iteration = 2 + sample_counts = {"-1": 2, "0": 3, "1": 3, "2": 2} + iteration = "2" ins.samples_unit = np.ones(n) ins.sample_counts = sample_counts diff --git a/tests/test_samplers/test_importance_nested_sampler/test_resume.py b/tests/test_samplers/test_importance_nested_sampler/test_resume.py index 6a0b3432..73818619 100644 --- a/tests/test_samplers/test_importance_nested_sampler/test_resume.py +++ b/tests/test_samplers/test_importance_nested_sampler/test_resume.py @@ -5,7 +5,9 @@ import numpy as np import pytest -from nessai.samplers.importancesampler import ImportanceNestedSampler as INS +from nessai.samplers.importancesampler import ( + ImportanceNestedSampler as INS, +) def test_getstate_no_model(ins): diff --git a/tests/test_samplers/test_importance_nested_sampler/test_samples.py b/tests/test_samplers/test_importance_nested_sampler/test_samples.py index bdd9a19b..caeeb913 100644 --- a/tests/test_samplers/test_importance_nested_sampler/test_samples.py +++ b/tests/test_samplers/test_importance_nested_sampler/test_samples.py @@ -61,9 +61,9 @@ def test_populate_live_points_no_iid(ins, model): ins.training_samples.add_initial_samples.assert_called_once() assert len(ins.training_samples.add_initial_samples.call_args.args[0]) == n - assert ins.training_samples.add_initial_samples.call_args.args[ - 1 - ].shape == (n, 1) + log_q = ins.training_samples.add_initial_samples.call_args.args[1] + assert log_q.shape == (n,) + assert len(log_q.dtype.names) == 1 @pytest.mark.usefixtures("ins_parameters") @@ -79,9 +79,9 @@ def test_populate_live_points_iid(ins, model): ins.training_samples.add_initial_samples.assert_called_once() assert len(ins.training_samples.add_initial_samples.call_args.args[0]) == n - assert ins.training_samples.add_initial_samples.call_args.args[ - 1 - ].shape == (n, 1) + log_q = ins.training_samples.add_initial_samples.call_args.args[1] + assert log_q.shape == (n,) + assert len(log_q.dtype.names) == 1 assert np.isfinite( ins.training_samples.add_initial_samples.call_args.args[0]["logL"] ).all() @@ -91,10 +91,9 @@ def test_populate_live_points_iid(ins, model): ins.iid_samples.add_initial_samples.assert_called_once() assert len(ins.iid_samples.add_initial_samples.call_args.args[0]) == n - assert ins.iid_samples.add_initial_samples.call_args.args[1].shape == ( - n, - 1, - ) + iid_log_q = ins.iid_samples.add_initial_samples.call_args.args[1] + assert iid_log_q.shape == (n,) + assert len(iid_log_q.dtype.names) == 1 assert np.isfinite( ins.iid_samples.add_initial_samples.call_args.args[0]["logL"] ).all() @@ -122,18 +121,22 @@ def test_remove_samples(ins, iid): def test_adjust_final_samples(ins, proposal, model, samples, log_q): - def draw(n, flow_number=None, update_counts=False): + def draw(n, flow_id=None, update_counts=False): assert update_counts is False x = numpy_array_to_live_points( np.random.randn(n, model.dims), names=model.names, ) - lq = np.random.rand(n, log_q.shape[1]) + lq = np.empty(n, dtype=log_q.dtype) + for name in log_q.dtype.names: + lq[name] = np.random.rand(n) return x, lq def draw_from_prior(n): x = model.new_point(n) - lq = np.random.rand(n, log_q.shape[1]) + lq = np.empty(n, dtype=log_q.dtype) + for name in log_q.dtype.names: + lq[name] = np.random.rand(n) return x, lq proposal.draw = MagicMock(side_effect=draw) @@ -180,8 +183,15 @@ def test_update_evidence(ins, iid): def test_update_sample_counts(ins): - ins.samples_unit = {"it": np.array([-1, 0, 2, 2, 2])} + ins.samples_unit = {"qID": np.array(["-1", "0", "2", "2", "2"])} ins.proposal = MagicMock() - ins.proposal.n_proposals = 5 + ins.proposal.weights = { + "-1": 0.0, + "0": 0.0, + "1": 0.0, + "2": 0.0, + "3": 0.0, + } + ins.proposal.cast_qid = lambda qid: qid INS.update_sample_counts(ins) - assert ins.sample_counts == {-1: 1, 0: 1, 1: 0, 2: 3, 3: 0} + assert ins.sample_counts == {"-1": 1, "0": 1, "1": 0, "2": 3, "3": 0} diff --git a/tests/test_sampling/test_ins_sampling.py b/tests/test_sampling/test_ins_sampling.py index 2919a6f7..a27078e4 100644 --- a/tests/test_sampling/test_ins_sampling.py +++ b/tests/test_sampling/test_ins_sampling.py @@ -9,6 +9,11 @@ from nessai.flowsampler import FlowSampler +@pytest.mark.usefixtures("reset_ins_parameters") +def reset_ins_parameters(): + """Reset the extra live points parameters before each test.""" + + @pytest.mark.slow_integration_test @pytest.mark.flaky(reruns=3) @pytest.mark.parametrize("save_log_q", [False, True]) @@ -48,7 +53,10 @@ def test_ins_resume(tmp_path, integration_model, flow_config, save_log_q): assert fp.ns.max_iteration == 2 assert fp.ns.finalised is True - np.testing.assert_array_almost_equal(new_log_q, original_log_q) + for name in original_log_q.dtype.names: + np.testing.assert_array_almost_equal( + new_log_q[name], original_log_q[name] + ) @pytest.mark.slow_integration_test