From 60cbcc9855216051cf091941249553647a5b2f22 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bogumi=C5=82=20Kami=C5=84ski?= Date: Wed, 15 Feb 2017 22:54:25 +0100 Subject: [PATCH 1/3] Fix #238 A proposal to fix #238. The original article assumes positive weights, so I propose to skip zero weights. Additionally, it is now strictly checked that `wv` contains at least as many positive weights as the required sample size. --- src/sampling.jl | 27 +++++++++++++++++++++------ 1 file changed, 21 insertions(+), 6 deletions(-) diff --git a/src/sampling.jl b/src/sampling.jl index 2dffbbacd..772f73029 100644 --- a/src/sampling.jl +++ b/src/sampling.jl @@ -515,15 +515,22 @@ function efraimidis_ares_wsample_norep!(a::AbstractArray, wv::WeightVec, x::Abst # initialize priority queue pq = Vector{Pair{Float64,Int}}(k) - @inbounds for i in 1:k - pq[i] = (wv.values[i]/randexp() => i) + i = 0 + s = 0 + @inbounds for s in 1:n + if wv.values[s] > 0.0 + i += 1 + pq[i] = (wv.values[s]/randexp() => s) + end + i >= k && break end + i < k && throw(DimensionMismatch("wv must have at least $k positive entries (got $i)")) heapify!(pq) # set threshold @inbounds threshold = pq[1].first - @inbounds for i in k+1:n + @inbounds for i in s+1:n key = wv.values[i]/randexp() # if key is larger than the threshold @@ -561,17 +568,25 @@ function efraimidis_aexpj_wsample_norep!(a::AbstractArray, wv::WeightVec, x::Abs # initialize priority queue pq = Vector{Pair{Float64,Int}}(k) - @inbounds for i in 1:k - pq[i] = (wv.values[i]/randexp() => i) + i = 0 + s = 0 + @inbounds for s in 1:n + if wv.values[s] > 0.0 + i += 1 + pq[i] = (wv.values[s]/randexp() => s) + end + i >= k && break end + i < k && throw(DimensionMismatch("wv must have at least $k positive entries (got $i)")) heapify!(pq) # set threshold @inbounds threshold = pq[1].first X = threshold*randexp() - @inbounds for i in k+1:n + @inbounds for i in s+1:n w = wv.values[i] + w > 0.0 || continue X -= w X <= 0 || continue From 5ed66f384db16b151eb68544fc4ecaa6be8ce3e0 Mon Sep 
17 00:00:00 2001 From: =?UTF-8?q?Bogumi=C5=82=20Kami=C5=84ski?= Date: Thu, 16 Feb 2017 20:41:21 +0100 Subject: [PATCH 2/3] weighted sampling without replacement: added tests and small fixes --- src/sampling.jl | 24 ++++++++++++++++-------- test/sampling.jl | 26 ++++++++++++++++++++++++++ 2 files changed, 42 insertions(+), 8 deletions(-) diff --git a/src/sampling.jl b/src/sampling.jl index 772f73029..e6569ba21 100644 --- a/src/sampling.jl +++ b/src/sampling.jl @@ -518,20 +518,25 @@ function efraimidis_ares_wsample_norep!(a::AbstractArray, wv::WeightVec, x::Abst i = 0 s = 0 @inbounds for s in 1:n - if wv.values[s] > 0.0 + w = wv.values[s] + w < 0 && error("Negative weight found in weight vector at index $s") + if w > 0 i += 1 - pq[i] = (wv.values[s]/randexp() => s) + pq[i] = (w/randexp() => s) end i >= k && break end - i < k && throw(DimensionMismatch("wv must have at least $k positive entries (got $i)")) + i < k && throw(DimensionMismatch("wv must have at least $k strictly positive entries (got $i)")) heapify!(pq) # set threshold @inbounds threshold = pq[1].first @inbounds for i in s+1:n - key = wv.values[i]/randexp() + w = wv.values[i] + w < 0 && error("Negative weight found in weight vector at index $i") + w > 0 || continue + key = w/randexp() # if key is larger than the threshold if key > threshold @@ -571,13 +576,15 @@ function efraimidis_aexpj_wsample_norep!(a::AbstractArray, wv::WeightVec, x::Abs i = 0 s = 0 @inbounds for s in 1:n - if wv.values[s] > 0.0 + w = wv.values[s] + w < 0 && error("Negative weight found in weight vector at index $s") + if w > 0 i += 1 - pq[i] = (wv.values[s]/randexp() => s) + pq[i] = (w/randexp() => s) end i >= k && break end - i < k && throw(DimensionMismatch("wv must have at least $k positive entries (got $i)")) + i < k && throw(DimensionMismatch("wv must have at least $k strictly positive entries (got $i)")) heapify!(pq) # set threshold @@ -586,7 +593,8 @@ function efraimidis_aexpj_wsample_norep!(a::AbstractArray, wv::WeightVec, 
x::Abs @inbounds for i in s+1:n w = wv.values[i] - w > 0.0 || continue + w < 0 && error("Negative weight found in weight vector at index $i") + w > 0 || continue X -= w X <= 0 || continue diff --git a/test/sampling.jl b/test/sampling.jl index c19094f99..3f0f74dab 100644 --- a/test/sampling.jl +++ b/test/sampling.jl @@ -149,3 +149,29 @@ check_sample_norep(a, (3, 12), 0; ordered=false) a = sample(3:12, 5; replace=false, ordered=true) check_sample_norep(a, (3, 12), 0; ordered=true) + +# test of weighted sampling without replacement +import StatsBase: sample +a = [1:10;] +wv = WeightVec([zeros(6); 1:4]) +x = vcat([sample(a, wv, 1, replace=false) for j in 1:100000]...) +@test minimum(x) == 7 +@test maximum(x) == 10 +@test maximum(abs(proportions(x) - (1:4)/10)) < 0.01 + +x = vcat([sample(a, wv, 2, replace=false) for j in 1:50000]...) +exact2 = [0.117261905, 0.220634921, 0.304166667, 0.357936508] +@test minimum(x) == 7 +@test maximum(x) == 10 +@test maximum(abs(proportions(x) - exact2)) < 0.01 + +x = vcat([sample(a, wv, 4, replace=false) for j in 1:10000]...) 
+@test minimum(x) == 7 +@test maximum(x) == 10 +@test maximum(abs(proportions(x) - 0.25)) == 0 + +@test_throws DimensionMismatch sample(a, wv, 5, replace=false) + +wv = WeightVec([zeros(5); 1:4; -1]) +@test_throws ErrorException sample(a, wv, 1, replace=false) + From 77ec83a8ba92b897c822209ba5379fdfba51c1d6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bogumi=C5=82=20Kami=C5=84ski?= Date: Wed, 3 May 2017 23:40:04 +0200 Subject: [PATCH 3/3] removed redundant import in tests --- test/sampling.jl | 1 - 1 file changed, 1 deletion(-) diff --git a/test/sampling.jl b/test/sampling.jl index 3f0f74dab..bc6027cd7 100644 --- a/test/sampling.jl +++ b/test/sampling.jl @@ -151,7 +151,6 @@ a = sample(3:12, 5; replace=false, ordered=true) check_sample_norep(a, (3, 12), 0; ordered=true) # test of weighted sampling without replacement -import StatsBase: sample a = [1:10;] wv = WeightVec([zeros(6); 1:4]) x = vcat([sample(a, wv, 1, replace=false) for j in 1:100000]...)