
WIP: Multimodal Audio #272


Draft: wants to merge 92 commits into base branch main.

Commits (92):
7709e65
WIP: multimodal support
sohamparikh Apr 8, 2025
0db2bd2
rough idea for memmap
sohamparikh Apr 9, 2025
0d89f68
faster image size reading
sohamparikh Apr 15, 2025
3866a53
solidify prepare
sohamparikh Apr 21, 2025
8413983
wip
sohamparikh Apr 24, 2025
6521e41
vision model
sohamparikh Apr 24, 2025
daf586f
wip
sohamparikh Apr 24, 2025
ef4488d
wip
sohamparikh Apr 25, 2025
6d9d595
missing files
sohamparikh Apr 28, 2025
6cb8f5d
make it work, barely
sohamparikh Apr 30, 2025
5761a2d
fix
sohamparikh Apr 30, 2025
d45d600
fixes
sohamparikh May 1, 2025
74a99b8
changes
sohamparikh May 6, 2025
99ad5d9
patches and fixes
sohamparikh May 7, 2025
bcb557a
fix dependency
sohamparikh May 7, 2025
a6f5364
remove for testing
sohamparikh May 7, 2025
73b431b
missing
sohamparikh May 7, 2025
ec2c9fb
initial data prep
tobyzl2 May 7, 2025
6d65676
fix
sohamparikh May 8, 2025
46aefc1
Merge branch 'main' into soham/pixtral-support
sohamparikh May 9, 2025
82e4edb
audio dataset changes
tobyzl2 May 9, 2025
9f3e60e
merge
tobyzl2 May 9, 2025
66e7081
fixes
sohamparikh May 9, 2025
0d1cd96
audio token computation
tobyzl2 May 9, 2025
33b138b
Merge branch 'soham/pixtral-support' of https://github.com/ServiceNow…
tobyzl2 May 9, 2025
40f3882
implement mm packing
tobyzl2 May 10, 2025
7f86a7f
fix
sohamparikh May 12, 2025
3a8a99d
more fixes after merge
sohamparikh May 12, 2025
d16284e
conv cleanup
sohamparikh May 12, 2025
b3134aa
more conv cleanup
sohamparikh May 12, 2025
c8aa66e
images + loss-masks
sohamparikh May 13, 2025
0baae59
minor fixes
sohamparikh May 13, 2025
48855be
cleanup
sohamparikh May 13, 2025
f35e003
cleanup
sohamparikh May 13, 2025
4eb34cb
cleanup
sohamparikh May 13, 2025
ebb9e27
cleanup
sohamparikh May 13, 2025
51098ef
fix
sohamparikh May 13, 2025
60b87fa
prepare cleanup
sohamparikh May 13, 2025
f8a5532
slightly better conversion
sohamparikh May 13, 2025
a035d0c
merge
tobyzl2 May 13, 2025
490651e
cleanup, sequence parallelism
sohamparikh May 14, 2025
24e1b83
fix conv
sohamparikh May 14, 2025
0f1612a
wip fixes
sohamparikh May 14, 2025
2e48c5f
fix
sohamparikh May 14, 2025
94e439c
data updates
tobyzl2 May 15, 2025
543fc0d
changes
tobyzl2 May 16, 2025
d529d37
fix image position
sohamparikh May 17, 2025
3c22dda
cleanup
sohamparikh May 17, 2025
f0c8d83
cleanup
sohamparikh May 20, 2025
6bbdc94
merge
tobyzl2 May 20, 2025
1a20913
layer changes
tobyzl2 May 20, 2025
ca33ee8
cleaner, extensible multimodal config
sohamparikh May 21, 2025
f3a4a74
cleanup
sohamparikh May 21, 2025
3b955b1
fixes for pixtral
sohamparikh May 21, 2025
49daf58
model fixes
sohamparikh May 21, 2025
b5ed9f4
more cleanup
sohamparikh May 22, 2025
dc888c8
image break token in sampling
sohamparikh May 22, 2025
5ffacab
merge
tobyzl2 May 22, 2025
c5396bc
Merge branch 'soham/pixtral-support' of https://github.com/ServiceNow…
tobyzl2 May 22, 2025
af3e2db
minor fixes
sohamparikh May 23, 2025
6d56be0
fix img break
sohamparikh May 24, 2025
ce91646
fixes
sohamparikh May 27, 2025
7eea79b
update audio encoder
tobyzl2 May 28, 2025
daf98b3
audio transformer updates
tobyzl2 May 28, 2025
cd167fc
audio conversion
tobyzl2 May 28, 2025
80c0aa2
merge
tobyzl2 May 28, 2025
204b3e9
fix image embeddings offset
sohamparikh May 28, 2025
fd08eac
heterogeneous data fixes
sohamparikh May 29, 2025
1e3652a
convert to rgb
sohamparikh May 29, 2025
e0f7dfd
mm loss masking spans
tobyzl2 May 29, 2025
0ae74d1
add lr scale
tobyzl2 May 29, 2025
28a3808
merge
tobyzl2 May 29, 2025
438ba80
mel spec changes
tobyzl2 May 30, 2025
525543a
updates
tobyzl2 May 30, 2025
2aabf35
fix sequence parallel image patches
sohamparikh May 30, 2025
b6d4858
fixes
sohamparikh May 31, 2025
25a650b
no compile for embeddings
sohamparikh May 31, 2025
c904da5
fix sampling
sohamparikh Jun 1, 2025
7a4701c
sampling and preprocessing bugs
sohamparikh Jun 2, 2025
067f901
speed up sampling
sohamparikh Jun 2, 2025
95526a3
adding audio start and end tokens
tobyzl2 Jun 2, 2025
01dfed7
merge
tobyzl2 Jun 2, 2025
fb23ef8
conversion changes
tobyzl2 Jun 3, 2025
d7d1135
adding data prep sharding
tobyzl2 Jun 3, 2025
012a636
faster mel spec
tobyzl2 Jun 6, 2025
c664444
adding num audio to config
tobyzl2 Jun 12, 2025
ba73939
audio encoder padding updates
tobyzl2 Jun 12, 2025
5667a0a
configurable max pad
tobyzl2 Jun 12, 2025
9f68a5e
small fix
tobyzl2 Jun 16, 2025
c286f8d
debugging updates
tobyzl2 Jun 18, 2025
eb39e7e
working 5b changes
tobyzl2 Jun 23, 2025
a53c89a
small fixes
tobyzl2 Jun 23, 2025
1 change: 1 addition & 0 deletions Dockerfile
@@ -4,6 +4,7 @@ FROM nvcr.io/nvidia/pytorch:24.11-py3
 # Install dependencies.
 RUN apt-get update \
     && apt-get install --no-install-recommends -y acl git-lfs \
+    # && apt-get install --no-install-recommends -y acl git-lfs libtiff5-dev \
     && rm -rf /var/lib/apt/lists/* \
     && git lfs install

42 changes: 41 additions & 1 deletion fast_llm/data/data/gpt/data.py
@@ -32,6 +32,10 @@ class GPTBatch:
     token_ids: torch.Tensor
     loss_masking_spans: list[torch.Tensor] | None = None
     sequence_lengths: list[torch.Tensor] | None = None
+    images: list[torch.Tensor] | None = None
+    image_positions: list[torch.Tensor] | None = None
+    audio: list[torch.Tensor] | None = None
+    audio_positions: list[torch.Tensor] | None = None


 def gpt_data_collate_fn(batch: list[GPTSample], sampling_parameters: GPTSamplingParameters) -> GPTBatch:
@@ -42,8 +46,44 @@ def gpt_data_collate_fn(batch: list[GPTSample], sampling_parameters: GPTSampling
         stacked_spans = [torch.from_numpy(sample.loss_masking_spans) for sample in batch]
     if not sampling_parameters.cross_document_attention:
         sequence_lengths = [torch.tensor(sample.sequence_lengths) for sample in batch]
+    has_images = False
+    batch_images = []
+    for sample in batch:
+        if sample.images is not None:
+            batch_images.append([torch.from_numpy(image) for image in sample.images])
+            has_images = True
+        else:
+            batch_images.append([])
+    batch_image_positions = []
+    for sample in batch:
+        if sample.image_positions is not None and len(sample.image_positions) > 0:
+            batch_image_positions.append(torch.from_numpy(sample.image_positions))
+        else:
+            batch_image_positions.append([])
+
+    has_audio = False
+    batch_audio = []
+    for sample in batch:
+        if sample.audio is not None and sample.audio_positions is not None:
+            batch_audio.append([torch.from_numpy(audio) for audio in sample.audio])
+            has_audio = True
+        else:
+            batch_audio.append(None)
+    batch_audio_positions = []
+    for sample in batch:
+        if sample.audio_positions is not None:
+            batch_audio_positions.append(torch.from_numpy(sample.audio_positions))
+        else:
+            batch_audio_positions.append([])
+
     return GPTBatch(
-        token_ids=torch.from_numpy(stacked_ids), loss_masking_spans=stacked_spans, sequence_lengths=sequence_lengths
+        token_ids=torch.from_numpy(stacked_ids),
+        loss_masking_spans=stacked_spans,
+        sequence_lengths=sequence_lengths,
+        images=batch_images if has_images else None,
+        image_positions=batch_image_positions if has_images else None,
+        audio=batch_audio if has_audio else None,
+        audio_positions=batch_audio_positions if has_audio else None,
+    )


23 changes: 22 additions & 1 deletion fast_llm/data/dataset/gpt/config.py
@@ -74,6 +74,15 @@ class GPTSamplingParameters(SamplingParameters):
     vocab_size: int
     use_loss_masking_spans: bool = False
     cross_document_attention: bool = True
+    patch_size: int | None = None
+    image_size: int | None = None
+    aud_downsampling_k: int | None = None
+    aud_padding_duration: int | None = None
+    aud_sampling_rate: int | None = None
+    image_break_token: int | None = None
+    image_end_token: int | None = None
+    audio_start_token: int | None = None
+    audio_end_token: int | None = None
     # How many extra tokens to add to the sequence length.
     # This is used to provide labels even for the last tokens in the sequence.
     extra_tokens: int = 1
@@ -195,11 +204,23 @@ class GPTMemmapDatasetConfig(GPTIndexedDatasetConfig):
         desc="Expected number of tokens in the dataset.",
         hint=FieldHint.optional,
     )
+    num_pixels: int | None = Field(
+        default=None,
+        desc="Expected number of pixels in the dataset.",
+        hint=FieldHint.optional,
+    )
+    num_audio: int | None = Field(
+        default=None,
+        desc="Expected number of audio in the dataset.",
+        hint=FieldHint.optional,
+    )

     def build(self) -> "GPTMemmapDataset":
         from fast_llm.data.dataset.gpt.memmap import GPTMemmapDataset

-        return GPTMemmapDataset(str(self.path).replace("/", "__"), self.path, self.num_documents, self.num_tokens)
+        return GPTMemmapDataset(
+            str(self.path).replace("/", "__"), self.path, self.num_documents, self.num_tokens, self.num_pixels
+        )


 @config_class()
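The new sampling parameters feed the token accounting for multimodal documents: an image's patch grid and an audio clip's padded mel frames both consume sequence positions. The PR computes these elsewhere; the sketch below only illustrates how such parameters typically combine — the Pixtral-style row-break layout and the hop_length value are assumptions, not taken from this diff:

```python
def image_token_count(image_size: int, patch_size: int) -> int:
    """Sequence positions one square image occupies.

    Assumes a Pixtral-style layout: rows of patches, each row followed by an
    image-break token, with an image-end token closing the last row.
    """
    rows = image_size // patch_size
    cols = image_size // patch_size
    return rows * cols + rows  # one break per row; the last break is the end token


def audio_token_count(
    aud_padding_duration: int,  # seconds, after padding
    aud_sampling_rate: int,  # Hz
    aud_downsampling_k: int,  # encoder downsampling factor
    hop_length: int = 160,  # assumed mel-spectrogram hop, not from this diff
) -> int:
    """Encoder positions one padded audio clip occupies, plus start/end tokens."""
    frames = aud_padding_duration * aud_sampling_rate // hop_length
    return frames // aud_downsampling_k + 2  # + audio_start_token, audio_end_token


# e.g. a 1024x1024 image with 16x16 patches:
print(image_token_count(1024, 16))  # 64*64 patches + 64 row breaks = 4160
```

Under these assumptions a 30 s clip at 16 kHz with k=4 would occupy 752 positions; the point is that the sampler needs these sizes up front to pack documents, which is why they live in GPTSamplingParameters rather than in the model config alone.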
11 changes: 10 additions & 1 deletion fast_llm/data/dataset/gpt/indexed.py
@@ -44,11 +44,20 @@ class GPTDatasetSlice[IndexedDatasetType: GPTIndexedDataset](DatasetSlice[Indexe

     def get_document_sizes(self) -> np.ndarray:
         # TODO: This can be really big.
-        return self._dataset.get_document_sizes()[self._begin : self._end]
+        doc_sizes, im_sizes, aud_sizes = self._dataset.get_document_sizes()
+        return (
+            doc_sizes[self._begin : self._end],
+            im_sizes[self._begin : self._end] if im_sizes else [],
+            aud_sizes[self._begin : self._end] if aud_sizes else [],
+        )

     def get_document_size(self, index: int) -> int:
         return self._dataset.get_document_size(self._begin + index)
+
+    @property
+    def has_images(self) -> bool:
+        return self._dataset.has_images


 class GPTConcatenatedDataset[IndexedDatasetType: GPTIndexedDataset](
     ConcatenatedDataset[IndexedDatasetType], GPTIndexedDataset
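With this change, get_document_sizes on a slice carries three per-document size sequences instead of one, slicing each by the same [begin, end) window and passing an empty placeholder through when a modality is absent from the underlying dataset. A toy stand-in illustrating that behavior (class and field names here are illustrative, not the fast_llm API):

```python
import numpy as np


class SizesSlice:
    """Toy stand-in for GPTDatasetSlice: a [begin, end) view over per-document sizes."""

    def __init__(self, doc_sizes, im_sizes, aud_sizes, begin, end):
        self._doc_sizes = np.asarray(doc_sizes)
        self._im_sizes = im_sizes  # empty list when the dataset has no images
        self._aud_sizes = aud_sizes  # empty list when the dataset has no audio
        self._begin, self._end = begin, end

    def get_document_sizes(self):
        # Slice each component; guard the optional ones so an absent
        # modality stays an empty list instead of raising.
        return (
            self._doc_sizes[self._begin : self._end],
            self._im_sizes[self._begin : self._end] if self._im_sizes else [],
            self._aud_sizes[self._begin : self._end] if self._aud_sizes else [],
        )


# 4 documents: token counts, per-document image sizes, no audio; view docs 1-2.
s = SizesSlice([10, 20, 30, 40], [[64], [], [128], []], [], begin=1, end=3)
docs, ims, auds = s.get_document_sizes()
```

Callers that previously consumed a single array now have to unpack the tuple, which is why the PR touches every get_document_sizes call site rather than just the memmap dataset.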