Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pySEQTarget/SEQopts.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,7 @@ class SEQopts:
subgroup_colname: str = None
treatment_level: List[int] = field(default_factory=lambda: [0, 1])
trial_include: bool = True
visit_colname: str = None
weight_eligible_colnames: List[str] = field(default_factory=lambda: [])
weight_min: float = 0.0
weight_max: float = None
Expand Down
17 changes: 13 additions & 4 deletions pySEQTarget/SEQuential.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,16 @@
from .plot import _survival_plot
from .SEQopts import SEQopts
from .SEQoutput import SEQoutput
from .weighting import (_fit_denominator, _fit_LTFU, _fit_numerator,
_weight_bind, _weight_predict, _weight_setup,
_weight_stats)
from .weighting import (
_fit_denominator,
_fit_LTFU,
_fit_numerator,
_fit_visit,
_weight_bind,
_weight_predict,
_weight_setup,
_weight_stats,
)


class SEQuential:
Expand Down Expand Up @@ -93,7 +100,7 @@ def __init__(
if self.denominator is None:
self.denominator = _denominator(self)

if self.cense_colname is not None:
if self.cense_colname is not None or self.visit_colname is not None:
if self.cense_numerator is None:
self.cense_numerator = _cense_numerator(self)

Expand All @@ -112,6 +119,7 @@ def expand(self) -> None:
self.cense_colname,
self.cense_eligible_colname,
self.compevent_colname,
self.visit_colname,
*self.weight_eligible_colnames,
*self.excused_colnames,
]
Expand Down Expand Up @@ -212,6 +220,7 @@ def fit(self) -> None:
WDT[col] = WDT[col].astype("category")

_fit_LTFU(self, WDT)
_fit_visit(self, WDT)
_fit_numerator(self, WDT)
_fit_denominator(self, WDT)

Expand Down
6 changes: 3 additions & 3 deletions pySEQTarget/analysis/_survival_pred.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,9 +163,9 @@ def _calculate_risk(self, data, idx=None, val=None):
.group_by("followup")
.agg(
[
pl.col("risk").std().alias("SE"),
pl.col("risk").quantile(lci).alias("LCI"),
pl.col("risk").quantile(uci).alias("UCI"),
pl.col("risk").std().cast(pl.Float64).alias("SE"),
pl.col("risk").quantile(lci).cast(pl.Float64).alias("LCI"),
pl.col("risk").quantile(uci).cast(pl.Float64).alias("UCI"),
]
)
.join(TxDT.select(["followup", main_col]), on="followup")
Expand Down
11 changes: 9 additions & 2 deletions pySEQTarget/error/_param_checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,15 @@ def _param_checker(self):
"Only one of followup_class or followup_include can be set to True."
)

if self.weighted and self.method == "ITT" and self.cense_colname is None:
raise ValueError("For weighted ITT analyses, cense_colname must be provided.")
if (
self.weighted
and self.method == "ITT"
and self.cense_colname is None
and self.visit_colname is None
):
raise ValueError(
"For weighted ITT analyses, cense_colname or visit_colname must be provided."
)

if self.excused:
_, self.excused_colnames = _pad(self.treatment_level, self.excused_colnames)
Expand Down
1 change: 1 addition & 0 deletions pySEQTarget/weighting/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,5 +3,6 @@
from ._weight_fit import _fit_denominator as _fit_denominator
from ._weight_fit import _fit_LTFU as _fit_LTFU
from ._weight_fit import _fit_numerator as _fit_numerator
from ._weight_fit import _fit_visit as _fit_visit
from ._weight_pred import _weight_predict as _weight_predict
from ._weight_stats import _weight_stats as _weight_stats
40 changes: 29 additions & 11 deletions pySEQTarget/weighting/_weight_bind.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,11 @@ def _weight_bind(self, WDT):

WDT = self.DT.join(WDT, on=on, how=join)

if self.visit_colname is not None:
visit = pl.col(self.visit_colname) == 0
else:
visit = pl.lit(False)

if self.weight_preexpansion and self.excused:
trial = (pl.col("trial") == 0) & (pl.col("period") == 0)
excused = (
Expand All @@ -21,6 +26,7 @@ def _weight_bind(self, WDT):
override = (
trial
| excused
| visit
| pl.col(self.outcome_col).is_null()
| (pl.col("denominator") < 1e-7)
)
Expand All @@ -33,6 +39,7 @@ def _weight_bind(self, WDT):
override = (
trial
| excused
| visit
| pl.col(self.outcome_col).is_null()
| (pl.col("denominator") < 1e-7)
| (pl.col("numerator") < 1e-7)
Expand All @@ -45,24 +52,35 @@ def _weight_bind(self, WDT):
override = (
trial
| excused
| visit
| pl.col(self.outcome_col).is_null()
| (pl.col("denominator") < 1e-15)
| pl.col("numerator").is_null()
)

self.DT = (
WDT.with_columns(
pl.when(override)
.then(pl.lit(1.0))
.otherwise(pl.col("numerator") / pl.col("denominator"))
.alias("wt")
(
WDT.with_columns(
pl.when(override)
.then(pl.lit(1.0))
.otherwise(pl.col("numerator") / pl.col("denominator"))
.alias("wt")
)
.sort([self.id_col, "trial", "followup"])
.with_columns(
pl.col("wt")
.fill_null(1.0)
.cum_prod()
.over([self.id_col, "trial"])
.alias("weight")
)
)
.sort([self.id_col, "trial", "followup"])
.with_columns(
pl.col("wt")
.fill_null(1.0)
.cum_prod()
.over([self.id_col, "trial"])
.alias("weight")
(
pl.col("weight")
* pl.col("_cense").fill_null(1.0)
* pl.col("_visit").fill_null(1.0)
).alias("weight")
)
.drop(["_cense", "_visit"])
)
45 changes: 34 additions & 11 deletions pySEQTarget/weighting/_weight_fit.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,22 +2,45 @@
import statsmodels.formula.api as smf


def _fit_pair(
    self, WDT, outcome_attr, formula_attr, output_attrs, eligible_colname_attr=None
):
    """Fit one binomial GLM per formula and store each fit on ``self``.

    Args:
        self: Object carrying the configuration. The outcome column name is
            read from the attribute named by ``outcome_attr`` and every
            fitted result is written to an attribute named in
            ``output_attrs``.
        WDT: Data handed to ``smf.glm`` (presumably a pandas DataFrame, since
            it is filtered with boolean-mask indexing -- confirm with caller).
        outcome_attr: Name of the attribute on ``self`` that holds the
            left-hand-side column name.
        formula_attr: Iterable of right-hand-side formula strings. Despite
            the name, these are the formulas themselves, not attribute names.
        output_attrs: Attribute names, paired positionally with
            ``formula_attr``, that receive the fitted results.
        eligible_colname_attr: Optional name of an attribute on ``self``
            holding an eligibility column name. When that attribute is not
            ``None`` only rows where the column equals 1 are used.
    """
    lhs = getattr(self, outcome_attr)

    # Resolve the (optional) eligibility column before touching the data.
    eligible_col = (
        None if eligible_colname_attr is None else getattr(self, eligible_colname_attr)
    )
    frame = WDT if eligible_col is None else WDT[WDT[eligible_col] == 1]

    for rhs, target in zip(formula_attr, output_attrs):
        glm = smf.glm(f"{lhs}~{rhs}", frame, family=sm.families.Binomial())
        setattr(self, target, glm.fit(disp=0))


def _fit_LTFU(self, WDT):
    """Fit the censoring numerator and denominator models.

    Does nothing when ``self.cense_colname`` is ``None``. Otherwise the two
    right-hand-side formulas currently held in ``self.cense_numerator`` and
    ``self.cense_denominator`` are fitted by ``_fit_pair`` against the
    ``cense_colname`` outcome, restricted to rows flagged by
    ``cense_eligible_colname`` when that option is set.

    Note:
        ``_fit_pair`` stores the fitted results back onto
        ``self.cense_numerator`` / ``self.cense_denominator``, so after this
        call those attributes hold fitted models rather than formula strings.
        The fit is therefore performed exactly once here; a second manual
        fitting pass would format the fitted objects into the formula.

    Args:
        self: Object carrying the censoring configuration and formulas.
        WDT: Data used to fit the models.
    """
    if self.cense_colname is None:
        return

    _fit_pair(
        self,
        WDT,
        "cense_colname",
        [self.cense_numerator, self.cense_denominator],
        ["cense_numerator", "cense_denominator"],
        "cense_eligible_colname",
    )


def _fit_visit(self, WDT):
    """Fit the visit numerator and denominator models.

    Does nothing when ``self.visit_colname`` is ``None``. Otherwise two
    binomial models are fitted by ``_fit_pair`` with ``visit_colname`` as the
    outcome, reusing the censoring right-hand-side formulas, and the results
    are stored on ``self.visit_numerator`` / ``self.visit_denominator``.

    Args:
        self: Object carrying the visit configuration and censoring formulas.
        WDT: Data used to fit the models.
    """
    if self.visit_colname is None:
        return

    def _rhs(spec):
        # When a censoring column is also configured, _fit_LTFU has already
        # replaced the formula strings on ``self`` with fitted results.
        # Formula-built statsmodels models keep the full "lhs~rhs" string on
        # ``model.formula``; recover the right-hand side from it so the visit
        # model is not built from the repr of a results object.
        formula = getattr(getattr(spec, "model", None), "formula", None)
        if isinstance(formula, str) and "~" in formula:
            return formula.split("~", 1)[1]
        # Still an unfitted formula string (or anything else): pass through.
        return spec

    _fit_pair(
        self,
        WDT,
        "visit_colname",
        [_rhs(self.cense_numerator), _rhs(self.cense_denominator)],
        ["visit_numerator", "visit_denominator"],
    )


def _fit_numerator(self, WDT):
Expand Down
52 changes: 38 additions & 14 deletions pySEQTarget/weighting/_weight_pred.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,20 +149,44 @@ def _weight_predict(self, WDT):
.otherwise(pl.col("numerator"))
.alias("numerator")
)
if self.cense_colname is not None:
p_num = _predict_model(self, self.cense_numerator, WDT).flatten()
p_denom = _predict_model(self, self.cense_denominator, WDT).flatten()
WDT = WDT.with_columns(
[
pl.Series("cense_numerator", p_num),
pl.Series("cense_denominator", p_denom),
]
).with_columns(
(pl.col("cense_numerator") / pl.col("cense_denominator")).alias("cense")
)
else:
WDT = WDT.with_columns(pl.lit(1.0).alias("cense"))
if self.cense_colname is not None:
p_num = _predict_model(self, self.cense_numerator, WDT).flatten()
p_denom = _predict_model(self, self.cense_denominator, WDT).flatten()
WDT = WDT.with_columns(
[
pl.Series("cense_numerator", p_num),
pl.Series("cense_denominator", p_denom),
]
).with_columns(
(pl.col("cense_numerator") / pl.col("cense_denominator")).alias("_cense")
)
else:
WDT = WDT.with_columns(pl.lit(1.0).alias("_cense"))

kept = ["numerator", "denominator", "cense", self.id_col, "trial", time, "tx_lag"]
if self.visit_colname is not None:
p_num = _predict_model(self, self.visit_numerator, WDT).flatten()
p_denom = _predict_model(self, self.visit_denominator, WDT).flatten()

WDT = WDT.with_columns(
[
pl.Series("visit_numerator", p_num),
pl.Series("visit_denominator", p_denom),
]
).with_columns(
(pl.col("visit_numerator") / pl.col("visit_denominator")).alias("_visit")
)
else:
WDT = WDT.with_columns(pl.lit(1.0).alias("_visit"))

kept = [
"numerator",
"denominator",
"_cense",
"_visit",
self.id_col,
"trial",
time,
"tx_lag",
]
exists = [col for col in kept if col in WDT.columns]
return WDT.select(exists).sort(grouping + [time])
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"

[project]
name = "pySEQTarget"
version = "0.9.1"
version = "0.10.0"
description = "Sequentially Nested Target Trial Emulation"
readme = "README.md"
license = {text = "MIT"}
Expand Down
78 changes: 55 additions & 23 deletions tests/test_coefficients.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,3 @@
import sys
from pathlib import Path

from pySEQTarget import SEQopts, SEQuential
from pySEQTarget.data import load_data

Expand Down Expand Up @@ -261,16 +258,16 @@ def test_PreE_LTFU_ITT():
s.fit()
matrix = s.outcome_model[0]["outcome"].summary2().tables[1]["Coef."].to_list()
expected = [
-21.640523091572796,
0.0685235184372898,
-0.19006360662228572,
0.028750950193838918,
-0.0005762057433736666,
0.28554312978583757,
-0.001373044229623057,
0.006589141394458155,
-0.44898959259422394,
1.3875089788036237,
-21.636346991788276,
0.06813705852786496,
-0.1939555961858531,
0.02874152772603635,
-0.0005734047013500563,
0.2854740212699898,
-0.0013729662310668182,
0.006501915963316852,
-0.4467079969655381,
1.3870473474960576,
]
assert [round(x, 3) for x in matrix] == [round(x, 3) for x in expected]
Comment thread
ryan-odea marked this conversation as resolved.
Outdated

Expand All @@ -294,16 +291,16 @@ def test_PostE_LTFU_ITT():
s.fit()
matrix = s.outcome_model[0]["outcome"].summary2().tables[1]["Coef."].to_list()
expected = [
-21.640523091572796,
0.0685235184372898,
-0.19006360662228572,
0.028750950193838918,
-0.0005762057433736666,
0.28554312978583757,
-0.001373044229623057,
0.006589141394458155,
-0.44898959259422394,
1.3875089788036237,
-21.847198431385877,
0.07786703138967718,
-0.15461370944416225,
0.030140057462437704,
-0.0006287338029348562,
0.287393206037481,
-0.0013719595115633126,
0.007295485861066434,
-0.42797049565882755,
1.4082102322835948,
]
assert [round(x, 3) for x in matrix] == [round(x, 3) for x in expected]
Comment thread
ryan-odea marked this conversation as resolved.
Outdated

Expand Down Expand Up @@ -371,3 +368,38 @@ def test_weighted_multinomial():
-0.08478678955657822,
]
assert [round(x, 3) for x in matrix] == [round(x, 3) for x in expected]


def test_ITT_visit():
    """Weighted pre-expansion ITT fit with ``visit_colname`` set.

    Checks the fitted outcome-model coefficients, rounded to three decimals,
    against stored reference values.
    """
    frame = load_data("SEQdata_LTFU")
    options = SEQopts(weighted=True, weight_preexpansion=True, visit_colname="LTFU")

    analysis = SEQuential(
        frame,
        id_col="ID",
        time_col="time",
        eligible_col="eligible",
        treatment_col="tx_init",
        outcome_col="outcome",
        time_varying_cols=["N", "L", "P"],
        fixed_cols=["sex"],
        method="ITT",
        parameters=options,
    )
    analysis.expand()
    analysis.fit()

    coef_table = analysis.outcome_model[0]["outcome"].summary2().tables[1]
    observed = [round(value, 3) for value in coef_table["Coef."].to_list()]

    reference = [
        -21.636346991788276,
        0.06813705852786496,
        -0.1939555961858531,
        0.02874152772603635,
        -0.0005734047013500563,
        0.2854740212699898,
        -0.0013729662310668182,
        0.006501915963316852,
        -0.4467079969655381,
        1.3870473474960576,
    ]
    assert observed == [round(value, 3) for value in reference]
Comment thread
ryan-odea marked this conversation as resolved.
Outdated
Loading