Commit 0e319cb

Moved the DepthPro implementation to depth-estimation
commit-id:671b4efa
1 parent b77d4d1 commit 0e319cb

File tree

13 files changed: +1516 -0 lines changed

packages/depth-estimation/README.md

Whitespace-only changes.
packages/depth-estimation/pyproject.toml
+19
@@ -0,0 +1,19 @@
[project]
name = "depth-estimation"
version = "0.1.0"
description = "Add your description here"
readme = "README.md"
authors = [
    { name = "Jan Smółka", email = "[email protected]" }
]
requires-python = ">=3.12.7"
dependencies = [
    "jaxtyping>=0.3.0",
    "timm>=1.0.15",
    "torch>=2.6.0",
    "torchvision>=0.21.0",
]

[build-system]
requires = ["hatchling"]
build-backend = "hatchling.build"
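
The dependency pins above define the runtime environment; a quick sanity check could look like the sketch below (module names follow the listed distributions; this snippet is illustrative and not part of the commit):

# Sketch: confirm the environment satisfies the pyproject constraints.
import sys

# requires-python = ">=3.12.7"
assert sys.version_info >= (3, 12, 7)

import jaxtyping    # jaxtyping>=0.3.0
import timm         # timm>=1.0.15
import torch        # torch>=2.6.0
import torchvision  # torchvision>=0.21.0

print(torch.__version__, torchvision.__version__)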
@@ -0,0 +1,3 @@
from . import depth_pro

__all__ = ['depth_pro']
@@ -0,0 +1,3 @@
from .network import Configuration, DepthPro

__all__ = ['DepthPro', 'Configuration']
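
Together, the two `__init__.py` files above re-export the model package. A downstream import could then look like this one-line sketch; the top-level package name `depth_estimation` is an assumption inferred from the project name and is not shown in this commit:

# Hypothetical usage; assumes the distribution installs as `depth_estimation`.
from depth_estimation.depth_pro import Configuration, DepthPro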
@@ -0,0 +1 @@

(a single blank line added)
@@ -0,0 +1,247 @@
from copy import deepcopy
from typing import Self

import torch
from torch import nn

# +---------------------------------------------------------------------+
# | Code adapted from:                                                   |
# | Repository: https://github.com/apple/ml-depth-pro                    |
# | Commit: b2cd0d51daa95e49277a9f642f7fd736b7f9e91d                     |
# | File: `src/depth_pro/network/decoder.py`                             |
# | Acknowledgement: Copyright (C) 2024 Apple Inc. All Rights Reserved.  |
# +---------------------------------------------------------------------+


class Decoder(nn.Module):
    """Decoder for multi-resolution encodings."""

    dims_encoder: list[int]
    dim_decoder: int
    dim_out: int

    convs: nn.ModuleList
    fusions: nn.ModuleList

    def __init__(
        self,
        dims_encoder: list[int],
        dim_decoder: int,
    ) -> None:
        """
        Initialize multiresolution convolutional decoder.

        Parameters:
        ---
        dims_encoder: list[int]
            Expected dimensions at each level from the encoder.

        dim_decoder: int
            Dimension of decoder features.
        """

        super().__init__()

        self.dims_encoder = dims_encoder
        self.dim_decoder = dim_decoder
        self.dim_out = dim_decoder

        n_encoders = len(dims_encoder)

        # At the highest resolution, i.e. level 0, we apply projection w/ 1x1 convolution
        # when the dimensions mismatch. Otherwise we do not do anything, which is
        # the default behavior of monodepth.
        conv0 = (
            nn.Conv2d(dims_encoder[0], dim_decoder, kernel_size=1, bias=False)
            if self.dims_encoder[0] != dim_decoder
            else nn.Identity()
        )

        convs = [conv0] + [
            nn.Conv2d(
                in_channels,
                dim_decoder,
                kernel_size=3,
                stride=1,
                padding=1,
                bias=False,
            )
            for in_channels in dims_encoder[1:]
        ]
        self.convs = nn.ModuleList(convs)

        fusions = [
            FeatureFusionBlock2d(
                features=dim_decoder,
                use_deconv=False,
                batch_norm=False,
            )
        ] + [
            FeatureFusionBlock2d(
                features=dim_decoder,
                use_deconv=True,
                batch_norm=False,
            )
            for _ in range(1, n_encoders)
        ]
        self.fusions = nn.ModuleList(fusions)

    def forward(self, encodings: list[torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
        """Decode the multi-resolution encodings."""

        num_levels = len(encodings)
        num_encoders = len(self.dims_encoder)

        if num_levels != num_encoders:
            raise ValueError(
                f'Got encoder output levels={num_levels}, expected levels={num_encoders}.'
            )

        # Project features of different encoder dims to the same decoder dim.
        # Fuse features from the lowest resolution (num_levels-1)
        # to the highest (0).
        features = self.convs[-1](encodings[-1])
        low_resolution_features = features
        features = self.fusions[-1](features)

        for i in range(num_levels - 2, -1, -1):
            features_i = self.convs[i](encodings[i])
            features = self.fusions[i](features, features_i)

        return features, low_resolution_features
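
A minimal smoke test for the decoder, assuming three encoder levels whose spatial resolution halves from one level to the next (all dimensions below are illustrative, not taken from the commit):

import torch

decoder = Decoder(dims_encoder=[256, 512, 1024], dim_decoder=256)

# One feature map per encoder level, highest resolution first.
encodings = [
    torch.randn(1, 256, 64, 64),   # level 0
    torch.randn(1, 512, 32, 32),   # level 1
    torch.randn(1, 1024, 16, 16),  # level 2 (lowest resolution)
]

features, low_resolution_features = decoder(encodings)
print(features.shape)                 # torch.Size([1, 256, 64, 64])
print(low_resolution_features.shape)  # torch.Size([1, 256, 16, 16])
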
class ResidualBlock(nn.Module):
    """
    Generic implementation of residual blocks.

    This implements a generic residual block from
    He et al. - Identity Mappings in Deep Residual Networks (2016),
    https://arxiv.org/abs/1603.05027
    which can be further customized via factory functions.
    """

    residual: nn.Module
    shortcut: nn.Module | None

    def __init__(self, residual: nn.Module, shortcut: nn.Module | None = None) -> None:
        """Initialize ResidualBlock."""
        super().__init__()
        self.residual = residual
        self.shortcut = shortcut

    @classmethod
    def with_shape(cls, n: int, batch_norm: bool) -> Self:
        """Build a shape-preserving block of two (ReLU, 3x3 conv[, batch norm]) units."""
        layers: list[nn.Module] = [
            nn.ReLU(inplace=False),
            nn.Conv2d(
                n,
                n,
                kernel_size=3,
                stride=1,
                padding=1,
                bias=not batch_norm,
            ),
        ]

        if batch_norm:
            layers.append(nn.BatchNorm2d(n))

        return cls(
            nn.Sequential(
                *deepcopy(layers),
                *layers,
            )
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply residual block."""

        delta_x: torch.Tensor = self.residual(x)

        if self.shortcut is not None:
            x = self.shortcut(x)

        return x + delta_x
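
For reference, `with_shape(n, batch_norm)` stacks two pre-activation units (ReLU, then 3x3 convolution, optionally batch norm) with no shortcut module, so the block preserves its input shape; a quick check with illustrative values:

import torch

block = ResidualBlock.with_shape(64, batch_norm=False)
x = torch.randn(2, 64, 32, 32)

# Output shape equals input shape: forward computes x + residual(x).
assert block(x).shape == x.shape
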
class FeatureFusionBlock2d(nn.Module):
    """Feature fusion for DPT."""

    features: int
    use_deconv: bool

    skip_add: torch.ao.nn.quantized.FloatFunctional

    resnet1: ResidualBlock
    resnet2: ResidualBlock

    deconv: nn.ConvTranspose2d
    out_conv: nn.Conv2d

    def __init__(
        self,
        features: int,
        use_deconv: bool = False,
        batch_norm: bool = False,
    ):
        """
        Initialize feature fusion block.

        Parameters
        ---
        features: int
            Number of input and output dimensions.

        use_deconv: bool
            Whether to use deconvolution before the final output convolution.

        batch_norm: bool
            Whether to use batch normalization in resnet blocks.
        """

        super().__init__()

        self.features = features
        self.use_deconv = use_deconv
        self.skip_add = torch.ao.nn.quantized.FloatFunctional()

        self.resnet1 = ResidualBlock.with_shape(features, batch_norm)
        self.resnet2 = ResidualBlock.with_shape(features, batch_norm)

        if use_deconv:
            self.deconv = nn.ConvTranspose2d(
                in_channels=features,
                out_channels=features,
                kernel_size=2,
                stride=2,
                padding=0,
                bias=False,
            )

        self.out_conv = nn.Conv2d(
            features,
            features,
            kernel_size=1,
            stride=1,
            padding=0,
            bias=True,
        )

    def forward(self, x0: torch.Tensor, x1: torch.Tensor | None = None) -> torch.Tensor:
        """Process and fuse input features."""

        x = x0

        if x1 is not None:
            res = self.resnet1(x1)
            x = self.skip_add.add(x, res)

        x = self.resnet2(x)

        if self.use_deconv:
            x = self.deconv(x)
        x = self.out_conv(x)

        return x
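
Finally, a sketch of the fusion block in isolation: it fuses a decoder feature map with a same-resolution skip path and, with `use_deconv=True`, upsamples the result 2x via the transposed convolution (shapes are illustrative, not from the commit):

import torch

fusion = FeatureFusionBlock2d(features=256, use_deconv=True)

coarse = torch.randn(1, 256, 32, 32)  # decoder path
skip = torch.randn(1, 256, 32, 32)    # projected encoder features

fused = fusion(coarse, skip)
print(fused.shape)  # torch.Size([1, 256, 64, 64]) after the 2x deconv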
