Skip to content

Commit

Permalink
Rework the Weighted.compareTo method (#46)
Browse files Browse the repository at this point in the history
The method no longer relies on System.identityHashCode to break ties but lazily generates random state in order to break ties. Two different Weighted references can never be equal and can never compare to zero.
  • Loading branch information
gstamatelat committed Aug 22, 2022
1 parent 1435c8d commit 96d674d
Show file tree
Hide file tree
Showing 2 changed files with 112 additions and 9 deletions.
47 changes: 38 additions & 9 deletions src/main/java/gr/james/sampling/Weighted.java
Original file line number Diff line number Diff line change
@@ -1,9 +1,17 @@
package gr.james.sampling;

import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.ThreadLocalRandom;

/**
* Represents an item with a weight.
* <p>
* This class is immutable and is meant for use in weighted random sampling algorithms.
* <p>
* The {@link #equals(Object)}, {@link #hashCode()} and {@link #compareTo(Weighted)} methods are implemented in such a
* way that two {@code a.equals(b) == (a == b)} and {@code (a.compareTo(b) == 0) == (a == b)}. As a result, two
* different references are always unequal.
*
* @param <T> the object type
* @author Giorgos Stamatelatos
Expand All @@ -19,6 +27,11 @@ class Weighted<T> implements Comparable<Weighted<T>> {
*/
public final double weight;

/**
* A list of ints to break the ties of {@link #compareTo(Weighted)} among elements with the same weight.
*/
private final List<Integer> id;

/**
* Construct a new {@link Weighted} from a given object and weight.
* <p>
Expand All @@ -30,15 +43,16 @@ class Weighted<T> implements Comparable<Weighted<T>> {
public Weighted(T object, double weight) {
this.object = object;
this.weight = weight;
this.id = new ArrayList<>();
}

/**
* Compares this object with the specified object for order. Returns a negative integer, zero, or a positive integer
* as this object is less than, equal to, or greater than the specified object.
* <p>
* The comparison is based on the {@link #weight} values of the two {@link Weighted} objects. If the weights are of
* the same value, the comparison is based on {@link System#identityHashCode(Object)}. This means that
* {@code a.equals(b)} will evaluate to {@code 0} if and only if {@code a == b}.
* the same value, the comparison is performed using a hidden random internal state of the objects that guarantees
* that {@code a.compareTo(b) == 0} if and only if {@code a == b}.
*
* @param o the object to be compared
* @return a negative integer, zero, or a positive integer as the object is less than, equal to, or greater than the
Expand All @@ -47,27 +61,42 @@ public Weighted(T object, double weight) {
*/
@Override
public int compareTo(Weighted<T> o) {
if (this == o) {
return 0;
}
assert !this.equals(o);
final int c = Double.compare(weight, o.weight);
if (c == 0) {
assert (Integer.compare(System.identityHashCode(this), System.identityHashCode(o)) == 0) == (this.equals(o));
return Integer.compare(System.identityHashCode(this), System.identityHashCode(o));
} else {
assert !this.equals(o);
if (c != 0) {
return c;
}
for (int i = 0; ; i++) {
assert this.id.size() >= i;
assert o.id.size() >= i;
if (this.id.size() == i) {
this.id.add(ThreadLocalRandom.current().nextInt());
}
if (o.id.size() == i) {
o.id.add(ThreadLocalRandom.current().nextInt());
}
if (this.id.get(i) > o.id.get(i)) {
return 1;
} else if (this.id.get(i) < o.id.get(i)) {
return -1;
}
}
}

/**
* Indicates whether some other object is "equal to" this one.
* <p>
* The implementation delegates to an invocation of {@link Object#equals(Object)}.
* The implementation delegates to an invocation of {@link Object#equals(Object)} and returns {@code true} if and
* only if {@code this == obj}.
*
* @param obj the reference object with which to compare
* @return {@code true} if this object is the same as the {@code obj} argument; {@code false} otherwise
*/
@Override
public boolean equals(Object obj) {
assert !super.equals(obj) || super.hashCode() == obj.hashCode();
return super.equals(obj);
}

Expand Down
74 changes: 74 additions & 0 deletions src/test/java/gr/james/sampling/WeightedTest.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
package gr.james.sampling;

import org.junit.Assert;
import org.junit.Test;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;

/**
* Tests for the {@link Weighted} class.
*/
public class WeightedTest {
/**
* In this test, we sort a list of Weighted objects, all of which have the same weight.
* <p>
* After the sorting, we make sure that each element is strictly lower than its successor. With this, we test for
* two things: 1. There exist no duplicate elements, 2. The compareTo method is deterministic and consistent.
*/
@Test
public void compareToConsistency() {
// Number of items in the list
final int COUNT = 2000000;
// Create the list
final List<Weighted<Integer>> weightedList = new ArrayList<>();
for (int i = 0; i < COUNT; i++) {
weightedList.add(new Weighted<>(0, 0));
}
// Sort the list
weightedList.sort(null);
// Test the list
for (int i = 0; i < COUNT - 1; i++) {
final Weighted<Integer> x = weightedList.get(i);
final Weighted<Integer> y = weightedList.get(i + 1);
Assert.assertTrue(x.compareTo(y) < 0);
Assert.assertNotSame(x, y);
Assert.assertNotEquals(x, y);
}
}

/**
* In a list of Weighted objects with same weight, sorting them should have the same effect as shuffling the
* original objects.
*/
@Test
public void compareToShuffling() {
// Number of items in the list
final int COUNT = 5;
final int PERMUTATIONS = 120;
final int REPS = 24000000;
// Frequency map
final Map<List<Integer>, Integer> frequencies = new HashMap<>();
// Do the experiments
for (int i = 0; i < REPS; i++) {
// Create the list
final List<Weighted<Integer>> weightedList = new ArrayList<>();
for (int k = 0; k < COUNT; k++) {
weightedList.add(new Weighted<>(k, 0));
}
// Sort the list
weightedList.sort(null);
// Add list of objects to frequencies
final List<Integer> ll = weightedList.stream().map(x -> x.object).collect(Collectors.toList());
frequencies.merge(ll, 1, Integer::sum);
}
// Tests
Assert.assertEquals(PERMUTATIONS, frequencies.size());
for (int v : frequencies.values()) {
Assert.assertEquals(1.0, 1.0 * v * PERMUTATIONS / REPS, 1e-2);
}
}
}

0 comments on commit 96d674d

Please sign in to comment.