Skip to content

Commit

Permalink
[CLRS] Incorporate a collinearity check into the ConvexHull sampler t…
Browse files Browse the repository at this point in the history
…o prevent generating samples that admit multiple valid paths.

PiperOrigin-RevId: 717531004
  • Loading branch information
RerRayne authored and copybara-github committed Jan 20, 2025
1 parent dcf8643 commit 7712b7b
Show file tree
Hide file tree
Showing 3 changed files with 142 additions and 12 deletions.
5 changes: 3 additions & 2 deletions clrs/_src/clrs_text/huggingface_generators.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def clrs_generator(
Example usage for a finite dataset:
algos_and_lengths = {"insertion_sort": [16]}
ds = datasets.Dataset.from_generator(
clrs_gen, gen_kwargs={
clrs_generator, gen_kwargs={
"algos_and_lengths": algos_and_lengths,
"num_samples": 100
}
Expand Down Expand Up @@ -72,7 +72,8 @@ def clrs_generator(
IterableDataset.from_generator.
use_hints: Whether hints should be included in the question and answer.
seed: The random seed for all of the generators.
num_decimals_in_float: The number of decimals to truncate floats to.
num_decimals_in_float: The number of decimals to truncate floats to. Defaults
to 3.
Yields:
A dictionary with the following keys:
Expand Down
104 changes: 94 additions & 10 deletions clrs/_src/samplers.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,10 @@
import collections
import copy
import inspect
import itertools
import types

from typing import Any, Callable, List, Optional, Tuple
from absl import logging

from clrs._src import algorithms
from clrs._src import probing
from clrs._src import specs
Expand Down Expand Up @@ -707,21 +706,106 @@ def intersect(xs, ys):
return [xs, ys]


def _is_collinear(
point_1: np.ndarray,
point_2: np.ndarray,
point_3: np.ndarray,
eps: float,
) -> bool:
"""Checks if three points are collinear.
Args:
point_1: The first point.
point_2: The second point.
point_3: The third point.
eps: The tolerance for collinearity.
Returns:
True if the three points are collinear, False otherwise.
Raises:
ValueError: If any of the points is not a 2D vector.
"""
for point in [point_1, point_2, point_3]:
if point.shape != (2,):
raise ValueError(f'Point {point} is not a 2D vector.')

# Vectors from p1
v_1 = point_2 - point_1
v_2 = point_3 - point_1

cross_val = np.cross(v_1, v_2)

return bool(abs(cross_val) < eps)


class ConvexHullSampler(Sampler):
"""Convex hull sampler of points over a disk of radius r."""
CAN_TRUNCATE_INPUT_DATA = True

def _sample_data(self, length: int, origin_x: float = 0.,
origin_y: float = 0., radius: float = 2.):
def _sample_data(
self,
length: int,
origin_x: float = 0.0,
origin_y: float = 0.0,
radius: float = 2.0,
collinearity_resampling_attempts: int = 1000,
collineararity_eps: float = 1e-12,
):
"""Samples a convex hull of points over a disk of radius r.
thetas = self._random_sequence(length=length, low=0.0, high=2.0 * np.pi)
rs = radius * np.sqrt(
self._random_sequence(length=length, low=0.0, high=1.0))
Args:
length: The number of points to sample.
origin_x: The x-coordinate of the origin of the disk.
origin_y: The y-coordinate of the origin of the disk.
radius: The radius of the disk.
collinearity_resampling_attempts: The number of times to resample if
collinear points are found.
collineararity_eps: The tolerance for collinearity.
xs = rs * np.cos(thetas) + origin_x
ys = rs * np.sin(thetas) + origin_y
Returns:
A list of the sampled points.
return [xs, ys]
Raises:
RuntimeError: If it could not sample stable points within the specified
number of attempts.
"""
for _ in range(collinearity_resampling_attempts):
thetas = self._random_sequence(
length=length,
low=0.0,
high=2.0 * np.pi,
)
rs = radius * np.sqrt(
self._random_sequence(length=length, low=0.0, high=1.0)
)

xs = rs * np.cos(thetas) + origin_x
ys = rs * np.sin(thetas) + origin_y

# Sampler._make_batch may do truncation of the input data after
# calling _sample_data.
# Truncation can lead to collinearity of points, which in turn leads to
# numerous correct traces in the Graham scan algorithm. To prevent this,
# we check for collinearity and resample if collinear points are found.
xs = self._trunc_array(xs)
ys = self._trunc_array(ys)

collinear_found = False
points = np.stack([xs, ys], axis=1)
for point_1, point_2, point_3 in itertools.combinations(points, 3):
if _is_collinear(point_1, point_2, point_3, collineararity_eps):
collinear_found = True
break

if collinear_found:
continue

return [xs, ys]

raise RuntimeError(
f'Could not sample {length} stable points within {10000} tries.'
)


SAMPLERS = {
Expand Down
45 changes: 45 additions & 0 deletions clrs/_src/samplers_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -471,6 +471,51 @@ def test_trunc_array(self, input_data, expected_output, truncate_decimals):
def test_is_float_array(self, input_data, expected_output):
self.assertEqual(samplers._is_float_array(input_data), expected_output)

@parameterized.named_parameters(
dict(
testcase_name='collinear_points',
point_1=np.array([1, 1]),
point_2=np.array([2, 2]),
point_3=np.array([3, 3]),
eps=1e-6,
expected_output=True,
),
dict(
testcase_name='non_collinear_points',
point_1=np.array([1, 1]),
point_2=np.array([2, 2]),
point_3=np.array([1, 2]),
eps=1e-6,
expected_output=False,
),
dict(
testcase_name='points_within_tolerance',
point_1=np.array([1, 1]),
point_2=np.array([2, 2]),
point_3=np.array([2.00, 1.9999]),
eps=1e-4,
expected_output=True,
),
dict(
testcase_name='points_outside_tolerance',
point_1=np.array([1, 1]),
point_2=np.array([2, 2]),
point_3=np.array([2, 1.9998]),
eps=1e-4,
expected_output=False,
),
)
def test_is_collinear(self, point_1, point_2, point_3, eps, expected_output):
self.assertEqual(
samplers._is_collinear(point_1, point_2, point_3, eps), expected_output
)

def test_is_collinear_raise_error(self):
with self.assertRaises(ValueError):
samplers._is_collinear(
np.array([1, 1, 3]), np.array([2,]), np.array([3, 3]), eps=1e-6,
)


class ProcessRandomPosTest(parameterized.TestCase):

Expand Down

0 comments on commit 7712b7b

Please sign in to comment.