Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
127 changes: 80 additions & 47 deletions src/nessai/flowmodel/importance.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import glob
import logging
import os
import re
from typing import Optional
from warnings import warn

Expand All @@ -24,7 +25,7 @@ class ImportanceFlowModel(FlowModel):
"""Flow Model that contains multiple flows for importance sampler."""

models: torch.nn.ModuleList = None
_resume_n_models: int = None
_resume_model_keys: list = None

def __init__(
self,
Expand All @@ -39,25 +40,41 @@ def __init__(
output=output,
rng=rng,
)
self.weights_files = []
self.models = torch.nn.ModuleList()
self.weights_files = {}
self.models = torch.nn.ModuleDict()
self._current_model = -1
self._model = None

@property
def model(self):
"""The current flow (model).

Returns None if no models have been added.
"""
if self.models:
return self.models[-1]
else:
logger.warning("Model not defined yet!")
return None
return self._model

@model.setter
def model(self, model):
if model is not None:
self.models.append(model)
raise ValueError("Cannot set model directly, use add_model()")
self._model = None

def add_model(self, model, key=None):
    """Add a model to the dictionary of models.

    Sets the current model to the new model.

    Parameters
    ----------
    model
        The flow to store in :code:`self.models`.
    key : Optional[Union[str, int]]
        Identifier under which the model is stored. It must be convertible
        with :code:`int` because it is compared against the running model
        counter. Integer keys are stored in their string form. If not
        specified, the counter is incremented and its string form is used.

    Raises
    ------
    ValueError
        If :code:`key` is a string that cannot be converted to an integer.
    """
    if key is None:
        index = self._current_model + 1
        key = str(index)
    else:
        # Resolve the counter before touching any state so that an invalid
        # key cannot leave the counter and the stored models out of sync.
        index = max(int(key), self._current_model)
        # Auto-generated keys are strings, so store explicit keys as strings
        # too; this lets callers pass a plain integer index.
        key = str(key)
    # Insert first: if the container rejects the model the counter and the
    # current model are left untouched.
    self.models[key] = model
    self._current_model = index
    self._model = model

@property
def current_model_key(self) -> str:
"""Return the current model counter (:code:`_current_model`) as a string."""
return str(self._current_model)

@property
def n_models(self) -> int:
Expand All @@ -81,7 +98,7 @@ def reset_optimiser(self) -> None:
def add_new_flow(self, reset=False):
"""Add a new flow"""
logger.debug("Add a new flow")
if reset or not self.models:
if reset or not self.models or self.model is None:
new_flow = configure_model(self.flow_config)
else:
new_flow = copy.deepcopy(self.model)
Expand All @@ -95,10 +112,17 @@ def add_new_flow(self, reset=False):
)
logger.debug(f"Inference device: {self.inference_device}")
self.models.eval()
self.models.append(new_flow)
self.add_model(new_flow)
self.reset_optimiser()

def log_prob_ith(self, x, i):
def remove_flow(self, i: str) -> None:
    """Remove the i'th flow.

    If the removed flow is the current model, the current model is cleared.

    Parameters
    ----------
    i : str
        Key of the flow to remove from :code:`self.models`. Integer indices
        are converted to their string form before the lookup.
    """
    logger.debug(f"Removing {i}'th flow")
    # Models are stored under string keys, so normalise integer indices
    # before the lookup instead of failing on a key-type mismatch.
    removed = self.models.pop(str(i))
    if self.model is removed:
        # The model setter only accepts None; this clears the current model.
        self.model = None

def log_prob_ith(self, x: np.ndarray, i: str) -> np.ndarray:
"""Compute the log-prob for the ith flow"""
x = (
torch.from_numpy(x)
Expand All @@ -114,6 +138,10 @@ def log_prob_ith(self, x, i):

def log_prob_all(self, x):
"""Compute the log probability using all of the stored models."""
# if self.model is None:
# if not self.models:
# raise RuntimeError("Models are not initialised yet!")
# self._model = next(iter(self.models.values()))
x = (
torch.from_numpy(x)
.type(torch.get_default_dtype())
Expand All @@ -124,12 +152,12 @@ def log_prob_all(self, x):
n = self.n_models
log_prob = torch.empty(x.shape[0], n)
with torch.no_grad():
for i, m in enumerate(self.models[:n]):
for i, m in enumerate(list(self.models.values())[:n]):
log_prob[:, i] = m.log_prob(x)
log_prob = log_prob.cpu().numpy().astype(np.float64)
return log_prob

def sample_ith(self, i, N=1):
def sample_ith(self, i: str, N: int = 1):
"""Draw samples from the ith flow"""
if self.models is None:
raise RuntimeError("Models are not initialised yet!")
Expand All @@ -145,67 +173,72 @@ def sample_ith(self, i, N=1):
def save_weights(self, weights_file) -> None:
"""Save the weights file."""
super().save_weights(weights_file)
self.weights_files.append(self.weights_file)
self.weights_files[self.current_model_key] = self.weights_file

def load_all_weights(self) -> None:
"""Load all of the weights files for each flow.

Resets any existing models.
"""
self.models = torch.nn.ModuleList()
self.models = torch.nn.ModuleDict()
logger.debug(f"Loading weights from {self.weights_files}")
self.device = torch.device(
self.training_config.get("device_tag", "cpu")
)
for wf in self.weights_files:
self._current_model = -1
for key, wf in self.weights_files.items():
new_flow = configure_model(self.flow_config)
new_flow.device = self.device
new_flow.load_state_dict(torch.load(wf, weights_only=True))
self.models.append(new_flow)
self.add_model(new_flow, key=key)
self.models.eval()

def update_weights_path(
self, weights_path: str, n: Optional[int] = None
self, weights_path: str, keys: Optional[list[str]] = None
) -> None:
"""Update the weights path.

Searches in the specified directory for weights files.

.. versionchanged:: 0.15.0
Replaced the :code:`n` parameter with :code:`keys` to specify
which models to load.

Parameters
----------
weights_path : str
Path to the directory that contains the weights files.
n : Optional[int]
The number of files to load. If not specified, :code:`n_models` is
used instead. Must be specified when resuming since the models list
is not saved.
keys : list[str], optional
The keys (IDs) of the models to load.
"""
all_weights_files = glob.glob(
weights_files = {}
for wf in glob.glob(
os.path.join(weights_path, "", "level_*", "model.pt")
)

if n is None:
if self.n_models:
n = self.n_models
else:
):
# Extract the level number from the path
level = re.search(r"level_(\d+)", wf)
if level is None:
raise RuntimeError(
"n is None and no models are defined, cannot update "
"weights path."
f"Cannot find level number in weights file: {wf}"
)
weights_files[level.group(1)] = wf

logger.debug(f"Loading weights from: {all_weights_files}")
if len(all_weights_files) < n:
raise RuntimeError(
f"Cannot use weights from: {weights_path}. Not enough files."
)
elif len(all_weights_files) > n:
logger.warning(
"More weights files than expected. Some files will be skipped."
if keys is None:
keys = weights_files.keys()
else:
keys = [str(k) for k in keys]

keys = set(keys)
known_keys = set(weights_files.keys())
if not keys.issubset(known_keys):
raise ValueError(
f"Keys {keys - known_keys} not found in weights files."
)
self.weights_files = [
os.path.join(weights_path, f"level_{i}", "model.pt")
for i in range(n)
]

# Only keep the specified keys
self.weights_files = {
k: v for k, v in weights_files.items() if k in keys
}

def resume(
self,
Expand All @@ -224,17 +257,17 @@ def resume(
"Not weights path specified, looking in output directory"
)
weights_path = self.output
self.update_weights_path(weights_path, n=self._resume_n_models)
self.update_weights_path(weights_path, keys=self._resume_model_keys)
self.load_all_weights()
self.initialise()

def __getstate__(self):
d = self.__dict__
# Avoid making a copy because models can be large and this doubles the
# memory usage.
exclude = {"models", "_optimiser", "flow_config"}
exclude = {"models", "_optimiser", "flow_config", "_model"}
state = {k: d[k] for k in d.keys() - exclude}
state["initialised"] = False
state["models"] = None
state["_resume_n_models"] = len(d["models"])
state["_resume_model_keys"] = list(d["models"].keys())
return state
16 changes: 14 additions & 2 deletions src/nessai/livepoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
def add_extra_parameters_to_live_points(
parameters,
default_values=None,
dtypes=None,
):
"""Add extra parameters to the live points dtype.

Expand All @@ -27,6 +28,9 @@ def add_extra_parameters_to_live_points(
----------
parameters: list
List of parameters to add.
dtypes: Optional[Union[List, Tuple]]
List of dtypes for each parameter. If not specified, each dtype
defaults to :code:`config.livepoints.default_float_dtype`.
default_values: Optional[Union[List, Tuple]]
List of default values for each parameter. If not specified, default
values will be set based on :code:`DEFAULT_FLOAT_VALUE` in
Expand All @@ -38,14 +42,22 @@ def add_extra_parameters_to_live_points(
)
else:
default_values = tuple(default_values)
for p, dv in zip(parameters, default_values):
if dtypes is None:
dtypes = len(parameters) * (config.livepoints.default_float_dtype,)
else:
dtypes = tuple(dtypes)
if not (len(parameters) == len(default_values) == len(dtypes)):
raise ValueError(
"Length mismatch between parameters, default_values, and dtypes."
)
for p, dtype, dv in zip(parameters, dtypes, default_values):
if p not in config.livepoints.extra_parameters:
config.livepoints.extra_parameters.append(p)
config.livepoints.extra_parameters_defaults = (
config.livepoints.extra_parameters_defaults + (dv,)
)
config.livepoints.extra_parameters_dtype.append(
config.livepoints.default_float_dtype
dtype,
)
else:
logger.warning(
Expand Down
5 changes: 4 additions & 1 deletion src/nessai/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,10 @@ def plot_live_points(

df = pd.DataFrame(live_points)
df = df.dropna(axis="columns", how="all")
df = df[np.isfinite(df).all(1)]
numeric_df = df.select_dtypes(include=[np.number])
if not numeric_df.empty:
goodmask = np.isfinite(numeric_df).all(axis=1)
df = df[goodmask]

if c is not None:
hue = df[c]
Expand Down
Loading
Loading