Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions pySEQTarget/SEQopts.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,9 @@ class SEQopts:
:type excused: bool
:param excused_colnames: Column names (at the same length of treatment_level) specifying excused conditions, default ``[]``
:type excused_colnames: List[str]
:param expand_only: If True, ``SEQuential.expand()`` returns the expanded dataset and skips weighting,
modelling, and survival steps
:type expand_only: bool
:param followup_class: Boolean to force followup values to be treated as classes
:type followup_class: bool
:param followup_include: Boolean to force regular followup values into model covariates
Expand Down Expand Up @@ -121,6 +124,7 @@ class SEQopts:
denominator: Optional[str] = None
excused: bool = False
excused_colnames: List[str] = field(default_factory=lambda: [])
expand_only: bool = False
followup_class: bool = False
followup_include: bool = True
followup_max: int = None
Expand Down Expand Up @@ -161,6 +165,7 @@ class SEQopts:
def _validate_bools(self):
bools = [
"excused",
"expand_only",
"followup_class",
"followup_include",
"followup_spline",
Expand Down
11 changes: 8 additions & 3 deletions pySEQTarget/SEQuential.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,9 +107,11 @@ def __init__(
_param_checker(self)
_data_checker(self)

def expand(self) -> None:
def expand(self):
"""
Creates the sequentially nested, emulated target trial structure
Creates the sequentially nested, emulated target trial structure.
If ``expand_only`` is set in parameters, returns the expanded dataset as a
:class:`polars.DataFrame` and skips all subsequent analysis steps.
"""
start = time.perf_counter()
kept = [
Expand Down Expand Up @@ -160,7 +162,7 @@ def expand(self) -> None:
pl.col(self.id_col).cast(pl.Utf8).alias(self.id_col)
)

if self.method != "ITT":
if self.method == "dose-response" or (self.method == "censoring" and not self.expand_only):
_dynamic(self)
if self.selection_random:
_random_selection(self)
Expand All @@ -169,6 +171,9 @@ def expand(self) -> None:
end = time.perf_counter()
self._expansion_time = _format_time(start, end)

if self.expand_only:
return self.DT

def bootstrap(self, **kwargs) -> None:
"""
Internally sets up bootstrapping - creating a list of IDs to use per iteration
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"

[project]
name = "pySEQTarget"
version = "0.13.1"
version = "0.13.2"
description = "Sequentially Nested Target Trial Emulation"
readme = "README.md"
license = {text = "MIT"}
Expand Down
38 changes: 37 additions & 1 deletion tests/test_expansion.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import polars as pl
from polars.testing import assert_frame_equal

from pySEQTarget import SEQuential
from pySEQTarget import SEQopts, SEQuential


def _make_model(data):
Expand Down Expand Up @@ -90,3 +91,38 @@ def test_expansion_truncates_each_trial_independently():

# Trial 1 starts at time=1, outcome at time=3 → followup 0,1,2
assert sorted(trial_1["followup"].to_list()) == [0, 1, 2]


def test_expand_only_returns_expanded_dataframe():
"""expand_only=True should return the expanded DataFrame directly and the
return value should equal self.DT from a standard expand() call."""
data = pl.DataFrame(
{
"ID": [1, 1, 1, 1, 1],
"time": [0, 1, 2, 3, 4],
"eligible": [1, 0, 0, 0, 0],
"treatment": [0, 1, 0, 1, 0],
"outcome": [0, 0, 0, 0, 0],
}
)

model_only = SEQuential(
data,
id_col="ID",
time_col="time",
eligible_col="eligible",
treatment_col="treatment",
outcome_col="outcome",
time_varying_cols=[],
fixed_cols=[],
parameters=SEQopts(expand_only=True),
)
result = model_only.expand()

assert isinstance(result, pl.DataFrame)
assert_frame_equal(result, model_only.DT)

model_full = _make_model(data)
model_full.expand()

assert_frame_equal(result, model_full.DT)