diff --git a/src/test/java/gr/james/sampling/WeightedRandomSamplingTest.java b/src/test/java/gr/james/sampling/WeightedRandomSamplingTest.java index 98b98d1..f22b02b 100644 --- a/src/test/java/gr/james/sampling/WeightedRandomSamplingTest.java +++ b/src/test/java/gr/james/sampling/WeightedRandomSamplingTest.java @@ -5,9 +5,7 @@ import org.junit.runner.RunWith; import org.junit.runners.Parameterized; -import java.util.ArrayList; -import java.util.Collection; -import java.util.Random; +import java.util.*; import java.util.function.Supplier; import java.util.stream.Collectors; import java.util.stream.IntStream; @@ -106,4 +104,27 @@ public void cornerWeights() { alg.sample(); } + /** + * Equivalence between {@link WeightedRandomSampling#feed(Object)}, + * {@link WeightedRandomSampling#feed(Iterator, Iterator)} and {@link WeightedRandomSampling#feed(Map)}. + */ + @Test + public void feedAlternative() { + final WeightedRandomSampling rs1 = impl.get(); + final WeightedRandomSampling rs2 = impl.get(); + final WeightedRandomSampling rs3 = impl.get(); + final Map map = new HashMap<>(); + for (int i = 1; i <= SAMPLE; i++) { + map.put(i, (double) i); + rs1.feed(i, (double) i); + } + rs3.feed(map.keySet().iterator(), map.values().iterator()); + rs2.feed(map); + Assert.assertEquals(SAMPLE, rs1.sample().size()); + Assert.assertEquals(SAMPLE, rs2.sample().size()); + Assert.assertEquals(SAMPLE, rs3.sample().size()); + Assert.assertTrue(rs1.sample().containsAll(rs2.sample())); + Assert.assertTrue(rs2.sample().containsAll(rs3.sample())); + } + }