From 1087df5f55b146403890a74c3d21b26704bfe1c8 Mon Sep 17 00:00:00 2001 From: Michael Osthege Date: Tue, 30 Jan 2024 09:57:49 +0100 Subject: [PATCH] Make pre-commit happy about code style --- mcbackend/backends/clickhouse.py | 27 ++++++++++++++++++--------- mcbackend/backends/numpy.py | 6 ++---- mcbackend/core.py | 6 +++++- 3 files changed, 25 insertions(+), 14 deletions(-) diff --git a/mcbackend/backends/clickhouse.py b/mcbackend/backends/clickhouse.py index ca4709e..31e5169 100644 --- a/mcbackend/backends/clickhouse.py +++ b/mcbackend/backends/clickhouse.py @@ -46,6 +46,10 @@ } +class ClickHouseBackendError(Exception): + """Something bad happened in the ClickHouse backend.""" + + def create_runs_table(client: clickhouse_driver.Client): query = """ CREATE TABLE IF NOT EXISTS runs ( @@ -80,7 +84,7 @@ def create_chain_table(client: clickhouse_driver.Client, meta: ChainMeta, rmeta: # Check that it does not already exist cid = chain_id(meta) if client.execute(f"SHOW TABLES LIKE '{cid}';"): - raise Exception(f"A table for {cid} already exists.") + raise ClickHouseBackendError(f"A table for {cid} already exists.") # Create a table with columns corresponding to the model variables columns = [] @@ -236,7 +240,7 @@ def _get_row_at( query = f"SELECT (`{names}`,) FROM {self.cid} WHERE _draw_idx={idx};" data = self._client.execute(query) if not data: - raise Exception(f"No record found for draw index {idx}.") + raise ClickHouseBackendError(f"No record found for draw index {idx}.") result = dict(zip(var_names, data[0][0])) return result @@ -364,7 +368,10 @@ def __init__( raise ValueError("Either a `client` or a `client_fn` must be provided.") if client_fn is None: - client_fn = lambda: client + + def client_fn(): + return client + if client is None: client = client_fn() @@ -382,11 +389,11 @@ def init_run(self, meta: RunMeta) -> ClickHouseRun: else: created_at = datetime.now().astimezone(timezone.utc) query = "INSERT INTO runs (created_at, rid, proto) VALUES" - params = dict( - created_at=created_at, - rid=meta.rid, - proto=base64.encodebytes(bytes(meta)).decode("ascii"), - ) + params = { + "created_at": created_at, + "rid": meta.rid, + "proto": base64.encodebytes(bytes(meta)).decode("ascii"), + } self._client.execute(query, [params]) return ClickHouseRun(meta, client_fn=self._client_fn, created_at=created_at) @@ -408,7 +415,9 @@ def get_run(self, rid: str) -> ClickHouseRun: {"rid": rid}, ) if len(rows) != 1: - raise Exception(f"Unexpected number of {len(rows)} results for rid='{rid}'.") + raise ClickHouseBackendError( + f"Unexpected number of {len(rows)} results for rid='{rid}'." + ) data = base64.decodebytes(rows[0][2].encode("ascii")) meta = RunMeta().parse(data) return ClickHouseRun( diff --git a/mcbackend/backends/numpy.py b/mcbackend/backends/numpy.py index fb2b570..cb37f06 100644 --- a/mcbackend/backends/numpy.py +++ b/mcbackend/backends/numpy.py @@ -111,7 +111,7 @@ class NumPyRun(Run): """An MCMC run where samples are kept in memory.""" def __init__(self, meta: RunMeta, *, preallocate: int) -> None: - self._settings = dict(preallocate=preallocate) + self._settings = {"preallocate": preallocate} self._chains: List[NumPyChain] = [] super().__init__(meta) @@ -129,9 +129,7 @@ class NumPyBackend(Backend): """An in-memory backend using NumPy.""" def __init__(self, preallocate: int = 1_000) -> None: - self._settings = dict( - preallocate=preallocate, - ) + self._settings = {"preallocate": preallocate} super().__init__() def init_run(self, meta: RunMeta) -> NumPyRun: diff --git a/mcbackend/core.py b/mcbackend/core.py index 02f28e3..6f627e7 100644 --- a/mcbackend/core.py +++ b/mcbackend/core.py @@ -27,6 +27,10 @@ __all__ = ("is_rigid", "chain_id", "Chain", "Run", "Backend") +class ChainError(Exception): + """Something is not right in one chain.""" + + def is_rigid(nshape: Optional[Shape]): """Determines wheather the shape is constant. @@ -119,7 +123,7 @@ def __len__(self) -> int: ]: for var in items: return len(method(var.name)) - raise Exception("This chain has no variables or sample stats.") + raise ChainError("This chain has no variables or sample stats.") @property def cid(self) -> str: