Skip to content
1,052 changes: 1,052 additions & 0 deletions notebooks/mkl/06 - Gaussian posteriors copy 2.ipynb

Large diffs are not rendered by default.

1,416 changes: 1,416 additions & 0 deletions notebooks/mkl/07c - Inference - v1.ipynb

Large diffs are not rendered by default.

944 changes: 944 additions & 0 deletions notebooks/mkl/10b - Datasets TUM.ipynb

Large diffs are not rendered by default.

141 changes: 141 additions & 0 deletions notebooks/mkl/Collecting_Clicks.ipynb

Large diffs are not rendered by default.

1,500 changes: 749 additions & 751 deletions pixi.lock

Large diffs are not rendered by default.

37 changes: 24 additions & 13 deletions src/b3d/camera.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ def camera_from_screen_and_depth(


def camera_from_screen(uv: ScreenCoordinates, intrinsics) -> CameraCoordinates:
    """
    Maps screen coordinates `uv` to camera coordinates at unit depth.

    Args:
        uv: (...,2) array of screen coordinates.
        intrinsics: Camera intrinsics, forwarded unchanged to
            `camera_from_screen_and_depth`.

    Returns:
        The result of `camera_from_screen_and_depth` with depth 1 assigned
        to every screen point.
    """
    # One depth value per screen point: drop the trailing coordinate axis.
    # (Building the array from `uv.shape[-1:]` would produce an array shaped
    # like the shape tuple itself rather than like the batch of points.)
    z = jnp.ones(uv.shape[:-1])
    return camera_from_screen_and_depth(uv, z, intrinsics)


Expand Down Expand Up @@ -122,7 +122,9 @@ def camera_from_depth(z: DepthImage, intrinsics) -> CameraCoordinates:
# Alias: `unproject_depth` is bound to the same function as `camera_from_depth`.
unproject_depth = camera_from_depth


def screen_from_camera(
    xyz: CameraCoordinates, intrinsics, culling=False
) -> ScreenCoordinates:
    """
    Maps camera coordinates `xyz` to screen (sensor) coordinates `uv`, given by
    $u = f_x x / z + c_x$ and $v = f_y y / z + c_y$.

    Args:
        xyz: (...,3) array of camera coordinates.
        intrinsics: Camera intrinsics. Unpacked as an 8-tuple whose last six
            entries are `fx, fy, cx, cy, near, far`; the first two entries
            are not used here.
        culling: If True, points whose depth `z` lies outside `[near, far]`
            are mapped to `inf` in both screen coordinates. If False
            (the default), no point is culled.

    Returns:
        (...,2) array of screen coordinates.
    """
    _, _, fx, fy, cx, cy, near, far = intrinsics
    x, y, z = xyz[..., 0], xyz[..., 1], xyz[..., 2]
    u = x * fx / z + cx
    v = y * fy / z + cy

    if culling:
        # The depth must NOT be clipped to [near, far] beforehand, otherwise
        # every point would pass this range test and nothing would be culled.
        # TODO: What is the right way of doing this? Returning infs?
        in_range = (near <= z) & (z <= far)
        u = jnp.where(in_range, u, jnp.inf)
        v = jnp.where(in_range, v, jnp.inf)

    return jnp.stack([u, v], axis=-1)


# Alias: `screen_from_xyz` is bound to the same function as `screen_from_camera`.
screen_from_xyz = screen_from_camera


def screen_from_world(x, cam, intr, culling=False):
    """
    Maps to screen coordinates `uv` from world coordinates `xyz`.

    Args:
        x: World coordinates; passed to `cam.inv().apply`.
        cam: Camera pose; its inverse is applied to `x` to obtain camera
            coordinates.
        intr: Camera intrinsics, forwarded to `screen_from_camera`.
        culling: Forwarded to `screen_from_camera`.

    Returns:
        The result of `screen_from_camera` on the camera-frame points.
    """
    xyz_cam = cam.inv().apply(x)
    return screen_from_camera(xyz_cam, intr, culling=culling)


def world_from_screen(uv, cam, intr):
    """
    Maps to world coordinates `xyz` from screen coordinates `uv`.

    The screen points are first lifted to camera coordinates with
    `camera_from_screen`, and `cam.apply` is then applied to the result.
    """
    xyz_cam = camera_from_screen(uv, intr)
    xyz_world = cam.apply(xyz_cam)
    return xyz_world


def camera_matrix_from_intrinsics(intr: Intrinsics) -> CameraMatrix3x3:
Expand Down Expand Up @@ -216,6 +224,9 @@ def homogeneous_coordinates(xs, z=jnp.array(1.0)):
return jnp.concatenate([xs, jnp.ones_like(xs[..., :1])], axis=-1) * z[..., None]


# Alias: `homogeneous` is bound to the same function as `homogeneous_coordinates`.
homogeneous = homogeneous_coordinates


def planar_coordinates(xs):
"""
Maps homogeneous to planar coordinates, eg.,
Expand Down
2 changes: 2 additions & 0 deletions src/b3d/chisight/sfm/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
from .eight_point import *
from .epipolar import *
51 changes: 51 additions & 0 deletions src/b3d/chisight/sfm/camera_inference.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
"""
Computation of the camera matrix from world points and their projections.

References:
> Terzakis--Lourakis, "A Consistently Fast and Globally Optimal Solution to the Perspective-n-Point Problem"
> Hartley--Zisserman, "Multiple View Geometry in Computer Vision", 2nd ed.
"""

from jax import numpy as jnp

from b3d.types import Matrix3x4, Point3D


def solve_camera_projection_constraints(Xs: Point3D, ys: Point3D) -> Matrix3x4:
    """
    Solve for the camera projection matrix given 3D points and their 2D projections,
    as described in Chapter 7 ("Computation of the Camera Matrix P") of
    > Hartley--Zisserman, "Multiple View Geometry in Computer Vision" (2nd ed).

    Each correspondence contributes the three rows of the linear system
    `A p = 0` (two of which are linearly independent), where `p` is the
    row-major flattening of `P`. The solution is the right-singular vector of
    `A` belonging to its smallest singular value.

    Args:
        Xs: 3D points in world coordinates, shape (N, 3). Points that are
            already homogeneous, shape (N, 4), are used as given.
        ys: Image points in homogeneous coordinates `(x, y, w)`, shape (N, 3).
            Inhomogeneous points of shape (N, 2) are accepted as well and are
            treated as having `w = 1`.

    Returns:
        Camera projection matrix, shape (3, 4). It is only determined up to
        scale; the returned matrix has unit Frobenius norm and arbitrary sign.
    """
    # We change notation from B3D notation
    # to Hartley--Zisserman, for ease of comparison
    X = jnp.asarray(Xs)
    ys = jnp.asarray(ys)

    # The constraint rows act on the 12 entries of P, so the world points
    # have to be homogeneous 4-vectors.
    if X.shape[-1] == 3:
        X = jnp.concatenate([X, jnp.ones_like(X[:, :1])], axis=-1)
    if ys.shape[-1] == 2:
        ys = jnp.concatenate([ys, jnp.ones_like(ys[:, :1])], axis=-1)

    # Keep a trailing axis of length 1 so the scalars broadcast against X.
    x = ys[:, 0:1]
    y = ys[:, 1:2]
    w = ys[:, 2:3]
    n = X.shape[0]
    zeros = jnp.zeros_like(X)

    # Shape (N, 3, 12): the three constraint rows of every correspondence,
    # flattened so that the rows of each point stay consecutive.
    A = jnp.stack(
        [
            jnp.concatenate([zeros, -w * X, y * X], axis=-1),
            jnp.concatenate([w * X, zeros, -x * X], axis=-1),
            jnp.concatenate([-y * X, x * X, zeros], axis=-1),
        ],
        axis=1,
    ).reshape(3 * n, 12)

    # Singular values are sorted in descending order, so the last row of
    # `vt` spans the (approximate) null space of A.
    _, _, vt = jnp.linalg.svd(A)
    P = vt[-1].reshape(3, 4)
    return P
Loading