Skip to content

Commit 1af98ee

Browse files
Sort only upper tail waits (#22)
* Deprecate sorted keyword * Use partialsortperm * Increment version number * Update src/core.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
1 parent ae98bab commit 1af98ee

File tree

2 files changed

+7
-9
lines changed

2 files changed

+7
-9
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "PSIS"
22
uuid = "ce719bf2-d5d0-4fb9-925d-10a81b42ad04"
33
authors = ["Seth Axen <[email protected]> and contributors"]
4-
version = "0.2.5"
4+
version = "0.2.6"
55

66
[deps]
77
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"

src/core.jl

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -170,8 +170,6 @@ While `psis` computes smoothed log weights out-of-place, `psis!` smooths them in
170170
171171
# Keywords
172172
173-
- `sorted=issorted(vec(log_ratios))`: whether `log_ratios` are already sorted. Only
174-
accepted if `nparams==1`.
175173
- `improved=false`: If `true`, use the adaptive empirical prior of [^Zhang2010].
176174
If `false`, use the simpler prior of [^ZhangStephens2009], which is also used in
177175
[^VehtariSimpson2021].
@@ -207,7 +205,7 @@ end
207205
function psis!(
208206
logw::AbstractVector,
209207
reff=1;
210-
sorted::Bool=issorted(logw),
208+
sorted::Bool=false, # deprecated
211209
improved::Bool=false,
212210
warn::Bool=true,
213211
)
@@ -219,11 +217,11 @@ function psis!(
219217
@warn "$M tail draws is insufficient to fit the generalized Pareto distribution. $MISSING_SHAPE_SUMMARY"
220218
return PSISResult(logw, LogExpFunctions.logsumexp(logw), reff_val, M, missing)
221219
end
222-
perm = sorted ? collect(eachindex(logw)) : sortperm(logw)
223-
icut = S - M
224-
tail_range = (icut + 1):S
225-
@inbounds logw_tail = @views logw[perm[tail_range]]
226-
@inbounds logu = logw[perm[icut]]
220+
perm = partialsortperm(logw, (S - M):S)
221+
cutoff_ind = perm[1]
222+
tail_inds = @view perm[2:(M + 1)]
223+
logu = logw[cutoff_ind]
224+
logw_tail = @views logw[tail_inds]
227225
_, tail_dist = psis_tail!(logw_tail, logu, M, improved)
228226
warn && check_pareto_shape(tail_dist)
229227
return PSISResult(logw, LogExpFunctions.logsumexp(logw), reff_val, M, tail_dist)

0 commit comments

Comments
 (0)