diff --git a/Dockerfile b/Dockerfile
index 8c2efa85..b7e42d4d 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -4,6 +4,7 @@ FROM nvcr.io/nvidia/pytorch:24.11-py3
 # Install dependencies.
 RUN apt-get update \
     && apt-get install --no-install-recommends -y acl git-lfs \
+    # && apt-get install --no-install-recommends -y acl git-lfs libtiff5-dev \
     && rm -rf /var/lib/apt/lists/* \
     && git lfs install
diff --git a/fast_llm/data/data/gpt/data.py b/fast_llm/data/data/gpt/data.py
index 02c1b6c0..8f700978 100644
--- a/fast_llm/data/data/gpt/data.py
+++ b/fast_llm/data/data/gpt/data.py
@@ -32,6 +32,10 @@ class GPTBatch:
     token_ids: torch.Tensor
     loss_masking_spans: list[torch.Tensor] | None = None
     sequence_lengths: list[torch.Tensor] | None = None
+    images: list[torch.Tensor] | None = None
+    image_positions: list[torch.Tensor] | None = None
+    audio: list[torch.Tensor] | None = None
+    audio_positions: list[torch.Tensor] | None = None


 def gpt_data_collate_fn(batch: list[GPTSample], sampling_parameters: GPTSamplingParameters) -> GPTBatch:
@@ -42,8 +46,44 @@ def gpt_data_collate_fn(batch: list[GPTSample], sampling_parameters: GPTSampling
         stacked_spans = [torch.from_numpy(sample.loss_masking_spans) for sample in batch]
     if not sampling_parameters.cross_document_attention:
         sequence_lengths = [torch.tensor(sample.sequence_lengths) for sample in batch]
+    has_images = False
+    batch_images = []
+    for sample in batch:
+        if sample.images is not None:
+            batch_images.append([torch.from_numpy(image) for image in sample.images])
+            has_images = True
+        else:
+            batch_images.append([])
+    batch_image_positions = []
+    for sample in batch:
+        if sample.image_positions is not None and len(sample.image_positions) > 0:
+            batch_image_positions.append(torch.from_numpy(sample.image_positions))
+        else:
+            batch_image_positions.append([])
+
+    has_audio = False
+    batch_audio = []
+    for sample in batch:
+        if sample.audio is not None and sample.audio_positions is not None:
+            batch_audio.append([torch.from_numpy(audio) for audio in sample.audio])
+            has_audio = True
+        else:
+            batch_audio.append(None)
+    batch_audio_positions = []
+    for sample in batch:
+        if sample.audio_positions is not None:
+            batch_audio_positions.append(torch.from_numpy(sample.audio_positions))
+        else:
+            batch_audio_positions.append([])
+
     return GPTBatch(
-        token_ids=torch.from_numpy(stacked_ids), loss_masking_spans=stacked_spans, sequence_lengths=sequence_lengths
+        token_ids=torch.from_numpy(stacked_ids),
+        loss_masking_spans=stacked_spans,
+        sequence_lengths=sequence_lengths,
+        images=batch_images if has_images else None,
+        image_positions=batch_image_positions if has_images else None,
+        audio=batch_audio if has_audio else None,
+        audio_positions=batch_audio_positions if has_audio else None,
     )
diff --git a/fast_llm/data/dataset/gpt/config.py b/fast_llm/data/dataset/gpt/config.py
index ed9128c6..357623b1 100644
--- a/fast_llm/data/dataset/gpt/config.py
+++ b/fast_llm/data/dataset/gpt/config.py
@@ -74,6 +74,15 @@ class GPTSamplingParameters(SamplingParameters):
     vocab_size: int
     use_loss_masking_spans: bool = False
     cross_document_attention: bool = True
+    patch_size: int | None = None
+    image_size: int | None = None
+    aud_downsampling_k: int | None = None
+    aud_padding_duration: int | None = None
+    aud_sampling_rate: int | None = None
+    image_break_token: int | None = None
+    image_end_token: int | None = None
+    audio_start_token: int | None = None
+    audio_end_token: int | None = None
     # How many extra tokens to add to the sequence length.
# This is used to provide labels even for the last tokens in the sequence. extra_tokens: int = 1 @@ -195,11 +204,23 @@ class GPTMemmapDatasetConfig(GPTIndexedDatasetConfig): desc="Expected number of tokens in the dataset.", hint=FieldHint.optional, ) + num_pixels: int | None = Field( + default=None, + desc="Expected number of pixels in the dataset.", + hint=FieldHint.optional, + ) + num_audio: int | None = Field( + default=None, + desc="Expected number of audio in the dataset.", + hint=FieldHint.optional, + ) def build(self) -> "GPTMemmapDataset": from fast_llm.data.dataset.gpt.memmap import GPTMemmapDataset - return GPTMemmapDataset(str(self.path).replace("/", "__"), self.path, self.num_documents, self.num_tokens) + return GPTMemmapDataset( + str(self.path).replace("/", "__"), self.path, self.num_documents, self.num_tokens, self.num_pixels + ) @config_class() diff --git a/fast_llm/data/dataset/gpt/indexed.py b/fast_llm/data/dataset/gpt/indexed.py index 688ea6a7..a2bd9977 100644 --- a/fast_llm/data/dataset/gpt/indexed.py +++ b/fast_llm/data/dataset/gpt/indexed.py @@ -44,11 +44,20 @@ class GPTDatasetSlice[IndexedDatasetType: GPTIndexedDataset](DatasetSlice[Indexe def get_document_sizes(self) -> np.ndarray: # TODO: This can be really big. - return self._dataset.get_document_sizes()[self._begin : self._end] + doc_sizes, im_sizes, aud_sizes = self._dataset.get_document_sizes() + return ( + doc_sizes[self._begin : self._end], + im_sizes[self._begin : self._end] if im_sizes else [], + aud_sizes[self._begin : self._end] if aud_sizes else [], + ) def get_document_size(self, index: int) -> int: return self._dataset.get_document_size(self._begin + index) + @property + def has_images(self) -> bool: + return self._dataset.has_images + class GPTConcatenatedDataset[IndexedDatasetType: GPTIndexedDataset]( ConcatenatedDataset[IndexedDatasetType], GPTIndexedDataset diff --git a/fast_llm/data/dataset/gpt/memmap.py b/fast_llm/data/dataset/gpt/memmap.py index ef060b00..c47d3cf6 100644 --- a/fast_llm/data/dataset/gpt/memmap.py +++ b/fast_llm/data/dataset/gpt/memmap.py @@ -1,8 +1,12 @@ +import io import pathlib import struct import typing import numpy as np +import PIL.Image +import torchaudio +import soundfile as sf from fast_llm.data.dataset.gpt.indexed import GPTIndexedDataset from fast_llm.data.dataset.gpt.sampled import GPTSample @@ -26,22 +30,40 @@ def __init__( prefix: pathlib.Path | str, num_documents: int | None = None, num_tokens: int | None = None, + num_pixels: int | None = None, ): - self._init(name, prefix, num_documents, num_tokens) + self._init(name, prefix, num_documents, num_tokens, num_pixels) - def _init(self, name: str, prefix: pathlib.Path | str, num_documents: int | None, num_tokens: int | None) -> None: + def _init( + self, + name: str, + prefix: pathlib.Path | str, + num_documents: int | None, + num_tokens: int | None, + num_pixels: int | None, + ) -> None: super().__init__() self._name = name self._prefix = pathlib.Path(prefix) self._has_spans = 0 + self._has_images = 0 with self._prefix.with_suffix(".idx").open("rb") as stream: Assert.eq(stream.read(9), MEMMAP_INDEX_HEADER, msg=f"File: {stream.name}") self._version = struct.unpack("= 2: self._has_spans = struct.unpack("= 3: + self._has_preference_spans = struct.unpack("= 4: + self._has_images = struct.unpack("= 5: + self._has_audio = struct.unpack("= 2: self._spans = [] self._num_spans = np.frombuffer( self._index_bin_buffer, dtype=np.int32, count=self._num_documents, - offset=offset + self._document_sizes.nbytes + 
self._pointers.nbytes, + offset=offset, ) - span_offset = offset + self._document_sizes.nbytes + self._pointers.nbytes + self._num_spans.nbytes self._num_spans_cumsum = np.r_[0, np.cumsum(self._num_spans[:-1], dtype=np.int64)] for idx in range(self._num_documents): self._spans.append( @@ -79,19 +101,93 @@ def _init(self, name: str, prefix: pathlib.Path | str, num_documents: int | None self._index_bin_buffer, dtype=np.int32, count=self._num_spans[idx] * 2, - offset=span_offset + self._num_spans_cumsum[idx] * 2 * np.dtype(np.int32).itemsize, + offset=offset + + self._num_spans.nbytes + + self._num_spans_cumsum[idx] * 2 * np.dtype(np.int32).itemsize, + ).reshape(-1, 2) + ) + offset += ( + self._num_spans.nbytes + + self._num_spans.sum() * 2 * np.dtype(np.int32).itemsize + # + sum([x.nbytes for x in self._spans]) + ) + self._num_pixels = 0 + self._image_lengths = [] + self._image_positions = [] + if self._has_images and self._version >= 4: + self._n_images = np.frombuffer( + self._index_bin_buffer, dtype=np.int32, count=self._num_documents, offset=offset + ) + images_seen = 0 + for n_images in self._n_images: + self._image_lengths.append( + np.frombuffer( + self._index_bin_buffer, + dtype=np.int32, + count=n_images * 2, + offset=offset + self._n_images.nbytes + 2 * images_seen * np.dtype(np.int32).itemsize, ).reshape(-1, 2) ) + self._num_pixels += self._image_lengths[-1].prod(axis=1, initial=3).sum() + self._image_positions.append( + np.frombuffer( + self._index_bin_buffer, + dtype=np.int32, + count=n_images, + offset=offset + + self._n_images.nbytes + + 2 * self._n_images.sum() * np.dtype(np.int32).itemsize + + images_seen * np.dtype(np.int32).itemsize, + ) + ) + images_seen += n_images + offset = offset + self._n_images.nbytes + 3 * self._n_images.sum() * np.dtype(np.int32).itemsize + self._audio_lengths = [] # list of arrays + self._audio_positions = [] # list of arrays + if self._has_audio and self._version >= 5: + self._n_audio = np.frombuffer( + self._index_bin_buffer, dtype=np.int32, count=self._num_documents, offset=offset + ) + audio_seen = 0 + + offset = offset + self._n_audio.nbytes + for n_audio in self._n_audio: + self._audio_lengths.append( + np.frombuffer( + self._index_bin_buffer, + dtype=np.int32, + count=n_audio, + offset=offset + audio_seen * np.dtype(np.int32).itemsize, + ) + ) + # self._num_pixels += self._image_lengths[-1].prod(axis=1, initial=3).sum() + self._audio_positions.append( + np.frombuffer( + self._index_bin_buffer, + dtype=np.int32, + count=n_audio, + offset=offset + + self._n_audio.sum() * np.dtype(np.int32).itemsize + + audio_seen * np.dtype(np.int32).itemsize, + ) + ) + audio_seen += n_audio self._bin_buffer_mmap = np.memmap(self._prefix.with_suffix(".bin"), mode="r", order="C") self._bin_buffer = memoryview(self._bin_buffer_mmap) - self._num_tokens = div(self._bin_buffer_mmap.size, np.dtype(self._dtype).itemsize) - if num_tokens is not None: - assert self._num_tokens == num_tokens + # TODO Soham: fix num_tokens to include images. 
Get total number of image pixels from index file and assign + # self._num_tokens = div(self._bin_buffer_mmap.size - n_pixels, np.dtype(self._dtype).itemsize) + + # TODO Toby: Add audio num tokens check + self._num_tokens = div(self._bin_buffer_mmap.size - self._num_pixels, np.dtype(self._dtype).itemsize) + # if num_pixels is not None: + # assert self._num_pixels == num_pixels + # if num_tokens is not None: + # assert self._num_tokens == num_tokens def __getstate__(self) -> tuple[str, pathlib.Path, int | None, int | None]: - return (self._name, self._prefix, self._num_documents, self._num_tokens) + return (self._name, self._prefix, self._num_documents, self._num_tokens, self._num_pixels) def __setstate__(self, state: tuple[str, pathlib.Path, int | None, int | None]): self._init(*state) @@ -104,8 +200,46 @@ def __del__(self): self._index_bin_buffer_mmap._mmap.close() # noqa del self._index_bin_buffer_mmap + # def get( + # self, + # idx: int, + # offset: int = 0, + # image_offset: int = 0, + # length: int | None = None, + # use_loss_masking_spans: bool = False, + # ): + # token_ids = np.frombuffer( + # self._bin_buffer, + # dtype=self._dtype, + # count=self._document_sizes[idx] - offset if length is None else length, + # offset=self._pointers[idx] + offset * np.dtype(self._dtype).itemsize, + # ) + # if self._has_images: + # image_positions = self._image_positions[idx] + # pixels = np.frombuffer( + # self._bin_buffer, + # dtype=np.dtype(np.uint8), + # count=self._image_lengths[idx].prod(initial=3), + # offset=self._pointers[idx] + self._document_sizes[idx] * np.dtype(self._dtype).itemsize, + # ) + # images = [] + # start = 0 + # for image_length in self._image_lengths[idx]: + # n_pixels = image_length.prod(initial=3) + # images.append(pixels[start : start + n_pixels].reshape(3, image_length[0], image_length[1])) + # start += n_pixels + # return GPTSample(token_ids=token_ids, images=images, image_positions=image_positions) + def get( - self, idx: int, offset: int = 0, length: int | None = None, use_loss_masking_spans: bool = False + self, + idx: int, + offset: int = 0, + length: int | None = None, + use_loss_masking_spans: bool = False, + patch_size: int | None = None, + image_size: int | None = None, + image_break: bool = False, + image_end: bool = False, ) -> GPTSample: token_ids = np.frombuffer( self._bin_buffer, @@ -113,16 +247,95 @@ def get( count=self._document_sizes[idx] - offset if length is None else length, offset=self._pointers[idx] + offset * np.dtype(self._dtype).itemsize, ) + images = None + image_positions = None + if self._has_images: + # Truncations with images are not yet supported + image_positions = self._image_positions[idx] + pixels = np.frombuffer( + self._bin_buffer, + dtype=np.dtype(np.uint8), + count=self._image_lengths[idx].prod(initial=3), + offset=self._pointers[idx] + self._document_sizes[idx] * np.dtype(self._dtype).itemsize, + ) + start = 0 + for image_length in self._image_lengths[idx]: + n_pixels = image_length.prod(initial=3) + images.append(pixels[start : start + n_pixels].reshape(3, image_length[0], image_length[1])) + start += n_pixels + + audio = [] + audio_positions = None + if self._has_audio: + audio_positions = self._audio_positions[idx] + # increment offset by documents and images + aud_offset = ( + self._pointers[idx] + + offset * np.dtype(self._dtype).itemsize + + self._document_sizes[idx] * np.dtype(self._dtype).itemsize + ) + + if self._has_images and len(self._image_lengths) > 0: + aud_offset += self._image_lengths[idx].prod(initial=3) * 
np.dtype(np.uint8).itemsize + all_audio = np.frombuffer( + self._bin_buffer, + dtype=np.dtype(np.float32), + count=self._audio_lengths[idx].sum(), + offset=aud_offset, + ) + start = 0 + for audio_length in self._audio_lengths[idx]: + audio.append(all_audio[start : start + audio_length]) + start += audio_length + + # TODO Soham: return loss_masking_spans sample_spans = None if use_loss_masking_spans and self._spans is not None: sample_spans = self._spans[idx] - # adjust the spans for the offset and length sample_spans = sample_spans[ (sample_spans[:, 0] < offset + len(token_ids)) & (sample_spans[:, 1] >= offset) ] sample_spans[:, 0] = np.maximum(sample_spans[:, 0], offset) - offset sample_spans[:, 1] = np.minimum(sample_spans[:, 1], offset + len(token_ids) - 1) - offset - return GPTSample(token_ids=token_ids, loss_masking_spans=sample_spans) + # if images: + # image_idx = 0 + # for span in sample_spans: + # additional_tokens = 0 + # image_position = image_positions[image_idx] if image_idx < len(image_positions) else float("inf") + # while image_position >= span[0] and image_position <= span[1]: + # image_tokens = get_num_image_tokens( + # get_resize_dims(*self._image_lengths[idx][image_idx], image_size, image_size, patch_size), + # patch_size, + # image_break=image_break, + # ) + # additional_tokens += image_tokens + # image_idx += 1 + # image_position = ( + # image_positions[image_idx] if image_idx < len(image_positions) else float("inf") + # ) + # span[1] += additional_tokens + # if audio: + # audio_idx = 0 + # for span in sample_spans: + # additional_tokens = 0 + # audio_position = audio_positions[audio_idx] if audio_idx < len(audio_positions) else float("inf") + # while audio_position >= span[0] and audio_position <= span[1]: + # audio_tokens = ... + # additional_tokens += audio_tokens + # audio_idx += 1 + # audio_position = ( + # audio_positions[audio_idx] if audio_idx < len(audio_positions) else float("inf") + # ) + # span[1] += additional_tokens + + return GPTSample( + token_ids=token_ids, + images=images, + image_positions=image_positions, + audio=audio, + audio_positions=audio_positions, + loss_masking_spans=sample_spans, + ) @property def name(self) -> str: @@ -135,23 +348,47 @@ def __len__(self) -> int: def num_tokens(self) -> int: return self._num_tokens - def get_document_sizes(self) -> np.ndarray: + @property + def has_images(self) -> bool: + return self._has_images + + @property + def has_audio(self) -> bool: + return self._has_audio + + def get_document_sizes(self) -> tuple[np.ndarray, np.ndarray]: """ The size of each document in the dataset. The resulting array could be very large, so this method should be called cautiously, and derived classes should try to avoid holding the whole array im memory. 
""" - return self._document_sizes + return self._document_sizes, self._image_lengths, self._audio_lengths def get_document_size(self, index: int) -> int: - return self._document_sizes[index].item() + # return self._document_sizes[index].item() + ( + # sum((h // patch_size[0]) * (w // patch_size[1]) for h, w in self._image_lengths[index]) + # if self._has_images + # else 0 + # ) + docsize = self._document_sizes[index].item() + imagesize = self._image_lengths[index] if self._has_images else [] + audiosize = self._audio_lengths[index] if self._has_audio else [] + return docsize, imagesize, audiosize @classmethod def write_dataset(cls, prefix: pathlib.Path | str, documents: typing.Iterable[GPTSample]): # Initialize metadata dtype = None num_documents = 0 - lengths = [] + doc_lengths = [] + n_images = [] + image_lengths = [] + im_positions = [] + total_images = 0 + n_audio = [] + audio_lengths = [] + aud_positions = [] + total_audio = 0 pointers = [] offset = 0 # number of spans for each document @@ -174,19 +411,55 @@ def write_dataset(cls, prefix: pathlib.Path | str, documents: typing.Iterable[GP # Write document to binary file bin_stream.write(document.token_ids.tobytes(order="C")) + total_im_size = 0 + total_aud_size = 0 + if document.images: + n_images.append(len(document.images)) + total_images += len(document.images) + for image in document.images: + # assume 3 channels (RGB) for all images + with PIL.Image.open(io.BytesIO(image["bytes"])) as img: + if img.mode != "RGB": + # Convert all images to RGB + img = img.convert("RGB") + pixels = np.array(img).transpose(2, 0, 1) # HWC to CHW + assert pixels.dtype == np.uint8, f"Expected uint8 pixels, got {pixels.dtype}." + image_lengths.append(np.array(pixels.shape[1:])) + bin_stream.write(pixels.tobytes(order="C")) + total_im_size += pixels.size + im_positions.append(document.image_positions) + if document.audio is not None: + num_audio = 0 + for audio in document.audio: + # audio_arr, _ = torchaudio.load(io.BytesIO(audio["bytes"])) + audio_arr, _ = sf.read(io.BytesIO(audio["bytes"])) + audio_arr = audio_arr.astype(np.float32) + if len(audio_arr) > 0: + num_audio += 1 + audio_lengths.append(len(audio_arr)) + bin_stream.write(audio_arr.tobytes(order="C")) + total_aud_size += audio_arr.size + n_audio.append(num_audio) + total_audio += num_audio + if num_audio > 0: + aud_positions += document.audio_positions # Update metadata doc_length = len(document.token_ids) - lengths.append(doc_length) + doc_lengths.append(doc_length) pointers.append(offset) if document.loss_masking_spans is not None: num_spans.append(len(document.loss_masking_spans)) spans.append(document.loss_masking_spans) - offset += doc_length * np.dtype(dtype).itemsize + offset += ( + doc_length * np.dtype(dtype).itemsize + + total_im_size * np.dtype(np.uint8).itemsize + + total_aud_size * np.dtype(np.float32).itemsize + ) num_documents += 1 # Finalize metadata arrays - lengths = np.array(lengths, dtype=np.int32) + doc_lengths = np.array(doc_lengths, dtype=np.int32) pointers = np.array(pointers, dtype=np.int64) num_spans = np.array(num_spans, dtype=np.int32) if len(spans) > 0: @@ -194,27 +467,64 @@ def write_dataset(cls, prefix: pathlib.Path | str, documents: typing.Iterable[GP else: spans = np.array(spans, dtype=np.int32) + if total_images: + n_images = np.array(n_images, dtype=np.int32) + image_lengths = np.stack(image_lengths, dtype=np.int32) + im_positions = np.array(im_positions, dtype=np.int32) + else: + n_images = np.array([]) + image_lengths = np.array([]) + im_positions = 
np.array([]) + + if total_audio: + n_audio = np.array(n_audio, dtype=np.int32) + audio_lengths = np.array(audio_lengths, dtype=np.int32) + aud_positions = np.array(aud_positions, dtype=np.int32) + else: + n_audio = np.array([]) + audio_lengths = np.array([]) + aud_positions = np.array([]) + # Write the index file (.idx) with prefix.with_suffix(".idx").open("wb") as idx_stream: idx_stream.write(MEMMAP_INDEX_HEADER) # Indicates the version - # Version 2 optionally adds loss-masking spans - idx_stream.write(struct.pack(" 0 else 0)) + # Placeholder flag for preference spans + idx_stream.write(struct.pack(" 0 else 0)) + # Flag to indicate whether audio is present + idx_stream.write(struct.pack(" 0 else 0)) # Data type idx_stream.write(struct.pack(" None: Create a `GPTSampledDataset` with the requested parameters. """ # Get the document sizes, the main information needed for sampling. - document_sizes = torch.from_numpy(self._indexed_dataset.get_document_sizes()).to(self._device) + document_sizes, image_sizes, audio_sizes = self._indexed_dataset.get_document_sizes() + document_sizes = torch.from_numpy(document_sizes).to(self._device) + image_token_sizes = torch.zeros_like(document_sizes).to(self._device) + if image_sizes: + image_token_sizes = [] + for i, sizes in enumerate(image_sizes): + image_token_sizes.append( + sum( + get_num_image_tokens( + *get_resize_dims( + *size, + self._parameters.image_size, + self._parameters.image_size, + self._parameters.patch_size, + ), + self._parameters.patch_size, + image_break=self._parameters.image_break_token is not None, + image_end=self._parameters.image_end_token is not None, + ) + for size in sizes + ) + ) + image_token_sizes = torch.tensor(image_token_sizes).to(self._device) + else: + image_token_sizes = torch.zeros_like(document_sizes) + + audio_token_sizes = torch.zeros_like(document_sizes).to(self._device) + long_audio_filter = torch.zeros_like(document_sizes, dtype=torch.bool) # longer than audio padding + for i, sizes in enumerate(audio_sizes): + audio_token_size_arr, to_filter = get_num_audio_tokens( + sizes, + self._parameters.aud_padding_duration, + self._parameters.aud_sampling_rate, + self._parameters.aud_downsampling_k, + self._parameters.audio_start_token, + self._parameters.audio_end_token, + ) + audio_token_sizes[i] = audio_token_size_arr.sum() + long_audio_filter[i] = to_filter + documents_per_epoch = document_sizes.numel() - tokens_per_epoch = document_sizes.sum().item() + tokens_per_epoch = ( + document_sizes.sum().item() + image_token_sizes.sum().item() + audio_token_sizes.sum().item() + ) # Calculate basic stats. if not self._truncate_documents: @@ -133,14 +189,31 @@ def _sample(self) -> None: "The C++ extension for dataset sampling is missing." " Please make sure Fast-LLM is installed correctly." 
) - long_docs_filter = document_sizes > self._parameters.sequence_length + 1 - ignored_documents = sum(long_docs_filter) + long_docs_filter = ( + document_sizes + image_token_sizes + +audio_token_sizes > self._parameters.sequence_length + 1 + ) + ignored_documents = long_docs_filter.sum() if ignored_documents: log_main_rank( f" > {ignored_documents}/{documents_per_epoch} documents are longer than {self._parameters.sequence_length+1} tokens and will be ignored.", log_fn=logger.warning, ) - tokens_per_epoch = document_sizes[~long_docs_filter].sum().item() + ignored_audio_samples = sum(long_audio_filter) + if ignored_audio_samples: + log_main_rank( + f" > {ignored_audio_samples}/{documents_per_epoch} samples contain audio longer than {self._parameters.aud_padding_duration} seconds and will be ignored.", + log_fn=logger.warning, + ) + long_docs_filter = long_docs_filter | long_audio_filter + tokens_per_epoch = ( + ( + document_sizes[~long_docs_filter] + + image_token_sizes[~long_docs_filter] + + audio_token_sizes[~long_docs_filter] + ) + .sum() + .item() + ) if tokens_per_epoch == 0: raise RuntimeError( f" > No documents shorter than {self._parameters.sequence_length+1} tokens found in dataset {self._indexed_dataset.name}." @@ -170,7 +243,7 @@ def _sample(self) -> None: shuffled_documents = documents_per_epoch * shuffled_epochs unshuffled_epochs = num_epochs - shuffled_epochs - yaml_data = { + yaml_data = { # TODO Toby: add audio "dataset": { "name": self._indexed_dataset.name, "documents_per_epoch": documents_per_epoch, @@ -179,7 +252,10 @@ def _sample(self) -> None: "num_samples": self._parameters.num_samples, "unshuffled_epochs": unshuffled_epochs, "sequence_length": self._parameters.sequence_length, + "patch_size": self._parameters.patch_size, "truncate_documents": self._truncate_documents, + "image_break_token": self._parameters.image_break_token, + "image_end_token": self._parameters.image_end_token, "config": self._config.to_dict(), } if self._truncate_documents: @@ -261,7 +337,7 @@ def _sample(self) -> None: # Equivalent to `torch.hstack((0, document_sizes[all_document_index].cumsum()[::TOKEN_CUMSUM_RATE]))` if unshuffled_epochs > 0: token_cumsum_unshuffled, unshuffled_tokens = self._get_token_cumsum( - document_sizes, + document_sizes + image_token_sizes + audio_token_sizes, offset=0, # TODO: Allowing for max 100% extra tokens for padding, is that enough? dtype=get_unsigned_integer_type((2 - self._truncate_documents) * tokens_per_epoch * num_epochs), @@ -271,7 +347,7 @@ def _sample(self) -> None: unshuffled_tokens = 0 if not self._truncate_documents: - yaml_data["unshuffled_tokens"] = unshuffled_tokens + yaml_data["unshuffled_tokens"] = unshuffled_tokens * unshuffled_epochs self._load_yaml_data(yaml_data) if self._yaml_path is not None: self._yaml_path.parent.mkdir(parents=True, exist_ok=True) @@ -284,6 +360,16 @@ def _sample(self) -> None: document_shuffling.to( dtype=torch.int64 if document_shuffling.dtype == torch.int64 else torch.int32 ) + ] + + image_token_sizes[ + document_shuffling.to( + dtype=torch.int64 if document_shuffling.dtype == torch.int64 else torch.int32 + ) + ] + + audio_token_sizes[ + document_shuffling.to( + dtype=torch.int64 if document_shuffling.dtype == torch.int64 else torch.int32 + ) ], offset=self._unshuffled_tokens, # TODO: Allowing for max 100% extra tokens for padding, is that enough? 
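Note: `get_resize_dims`, `get_num_image_tokens`, and `get_num_audio_tokens` are imported helpers that this diff does not show. The sketch below is only a guess at what the image-side helpers compute, inferred from how `__getitem__` later fills the placeholder arrays (one token per patch, one break token per patch row, with the final break optionally replaced by an end token); the names, signatures, and rounding details are assumptions, not the repository's actual implementation.

```python
import math


def get_resize_dims(height: int, width: int, max_height: int, max_width: int, patch_size: int) -> tuple[int, int]:
    # Assumed behaviour: shrink the image (if needed) to fit the max box while keeping the
    # aspect ratio, then round each side up to a whole number of patches.
    ratio = max(height / max_height, width / max_width, 1.0)
    height = min(max_height, math.ceil(height / ratio))
    width = min(max_width, math.ceil(width / ratio))
    return patch_size * math.ceil(height / patch_size), patch_size * math.ceil(width / patch_size)


def get_num_image_tokens(height: int, width: int, patch_size: int, image_break: bool, image_end: bool) -> int:
    # One placeholder token per patch; with `image_break` set, one extra token per patch row
    # (the sampler later overwrites the last one with the end token if configured);
    # with only `image_end` set, a single trailing end token.
    rows, cols = height // patch_size, width // patch_size
    num_tokens = rows * cols
    if image_break:
        num_tokens += rows
    elif image_end:
        num_tokens += 1
    return num_tokens


# Example: a 300x500 image with image_size=448 and patch_size=16 resizes to 272x448,
# i.e. 17x28 patches, giving 476 patch tokens plus 17 break tokens = 493 placeholders.
print(get_num_image_tokens(*get_resize_dims(300, 500, 448, 448, 16), 16, True, False))  # 493
```

The audio side is analogous: `get_num_audio_tokens` appears to pad each clip toward `aud_padding_duration * aud_sampling_rate` samples, divide by the downsampling factor `aud_downsampling_k`, optionally add start/end tokens, and flag clips that exceed the padding duration so they can be filtered out.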
@@ -375,6 +461,12 @@ def __getitem__(self, index: int) -> typing.Any: token_ids = [] loss_masking_spans = [] + images = [] + audio = [] + image_positions = [] + audio_positions = [] + mm_tokens_added = 0 + text_tokens_added = 0 while token_count < token_end: # Find the document index in the dataset. if document_sampling_index < self._unshuffled_documents: @@ -382,48 +474,252 @@ def __getitem__(self, index: int) -> typing.Any: else: document_index = self._document_shuffling[document_sampling_index - self._unshuffled_documents].item() - document_size = self._indexed_dataset.get_document_size(document_index) + text_size, image_lengths, audio_lengths = self._indexed_dataset.get_document_size(document_index) + + resized_image_lengths = [ + get_resize_dims( + *image_length, + self._parameters.image_size, + self._parameters.image_size, + self._parameters.patch_size, + ) + for image_length in image_lengths + ] + image_sizes = [ + get_num_image_tokens( + *image_length, + self._parameters.patch_size, + image_break=self._parameters.image_break_token is not None, + image_end=self._parameters.image_end_token is not None, + ) + for image_length in resized_image_lengths + ] + image_tokens = sum(image_sizes) + + audio_token_size_arr, _ = get_num_audio_tokens( + audio_lengths, + self._parameters.aud_padding_duration, + self._parameters.aud_sampling_rate, + self._parameters.aud_downsampling_k, + self._parameters.audio_start_token, + self._parameters.audio_end_token, + ) + audio_tokens = int(audio_token_size_arr.sum()) + + document_size = text_size + image_tokens + audio_tokens if not self._truncate_documents: + # Document too long, ignore if document_size > self._parameters.sequence_length + 1: - # Document too long, ignore document_sampling_index += 1 continue + + # Where are we currently in sample? tokens_in_sample = token_count % (self._parameters.sequence_length + 1) - if document_size + tokens_in_sample > self._parameters.sequence_length + 1: + if document_size + tokens_in_sample >= self._parameters.sequence_length + 1: # Document belongs to the next sample, need to account for padding. padding_size = self._parameters.sequence_length + 1 - tokens_in_sample - if token_count > token_start: + if token_count >= token_start: # Add padding tokens to current sample - token_ids.append(np.full((padding_size,), -100, dtype=np.int64)) + try: + token_ids.append(np.full((padding_size,), -100, dtype=np.int64)) + except: + pass Assert.eq(token_count + padding_size, token_end) break else: # Move on to the next sample. token_count += padding_size + continue # Determine if the document belongs to the requested sample. if token_count + document_size >= token_start: # Determine which part of the document belong to the sample, and add it to the list. 
token_start_index_in_document = max(token_start - token_count, 0) - token_end_index_in_document = min(token_end - token_count, document_size) + token_end_index_in_document = min(token_end - token_count, text_size) sample = self._indexed_dataset.get( document_index, offset=token_start_index_in_document, length=token_end_index_in_document - token_start_index_in_document, use_loss_masking_spans=self._parameters.use_loss_masking_spans, + # image_break=self._parameters.image_break_token is not None, + # image_end=self._parameters.image_end_token is not None, ) - token_ids.append(sample.token_ids) + start_pos = 0 + + # add tokens and multi modal padding placeholders + # multimodal_positions = np.concatenate( + # [ + # arr.astype(np.int32) + # for arr in (sample.image_positions, sample.audio_positions) + # if arr is not None + # ] + # ) or np.array([], dtype=np.int32) + # multimodal_positions.sort() + + multimodal_positions = [] + if sample.image_positions is not None: + multimodal_positions.extend( + [(pos, "image", idx) for idx, pos in enumerate(sample.image_positions)] + ) + if sample.audio_positions is not None: + multimodal_positions.extend( + [(pos, "audio", idx) for idx, pos in enumerate(sample.audio_positions)] + ) + + token_ids_per_sample = [] + special_mm_tok_loss_masking_spans = np.empty((0, 2), dtype=np.int32) + multimodal_positions.sort(key=lambda x: x[0]) + for global_idx, (mm_position, mm_type, source_idx) in enumerate(multimodal_positions): + # Add placeholders for image and audio tokens tokens + token_ids_per_sample.append(sample.token_ids[start_pos:mm_position]) + text_tokens_added += len(token_ids_per_sample[-1]) + if mm_type == "image": + # image_positions.append(im_positions + len(token_ids) + image_tokens_added) + # Add placeholders for image tokens + image_positions.append(text_tokens_added + mm_tokens_added) + if self._parameters.image_break_token is not None: + height, width = resized_image_lengths[source_idx] + num_patches_h = div(height, self._parameters.patch_size) + num_patches_w = div(width, self._parameters.patch_size) + + # Create image token placeholder array + image_token_array = np.full((image_sizes[source_idx],), -100, dtype=np.int64) + + # Add break tokens after each row except the last row + for row in range(num_patches_h - 1): + position = (row + 1) * num_patches_w + row + image_token_array[position] = self._parameters.image_break_token + # add end token if specified, else break token + last_row_position = num_patches_h * num_patches_w + num_patches_h - 1 + if self._parameters.image_end_token is not None: + image_token_array[last_row_position] = self._parameters.image_end_token + else: + image_token_array[last_row_position] = self._parameters.image_break_token + else: + image_token_array = np.full((image_sizes[source_idx],), -100, dtype=np.int64) + if self._parameters.image_end_token is not None: + image_token_array[-1] = self._parameters.image_end_token + token_ids_per_sample.append(image_token_array) + mm_tokens_added += image_sizes[source_idx] + elif mm_type == "audio": + # audio_pos = sum(t.size for t in token_ids) # includes mm tokens added already + # compute audio position + start_token_offset = int(self._parameters.audio_start_token is not None) + audio_pos = text_tokens_added + mm_tokens_added + start_token_offset + audio_positions.append(audio_pos) + + # compute number of special tokens + num_audio_special_tokens = int(self._parameters.audio_start_token is not None) + int( + self._parameters.audio_end_token is not None + ) + + # add start tokens + 
if self._parameters.audio_start_token is not None: + token_ids_per_sample.append(np.array([self._parameters.audio_start_token])) + # add to loss masking spans + special_mm_tok_loss_masking_spans = np.append( + special_mm_tok_loss_masking_spans, [[audio_pos - 1, audio_pos - 1]], axis=0 + ) + # sample.loss_masking_spans = np.append(sample.loss_masking_spans, [[audio_pos-1, audio_pos-1]], axis=0) + + # add audio pad tokens + num_audio_pad_tokens = audio_token_size_arr[source_idx] + num_audio_pad_tokens -= num_audio_special_tokens # ignore start/end tokens for padding + audio_padding_tokens = np.full((num_audio_pad_tokens,), -100, dtype=np.int64) + token_ids_per_sample.append(audio_padding_tokens) + + # add end token + if self._parameters.audio_end_token is not None: + token_ids_per_sample.append(np.array([self._parameters.audio_end_token])) + # add to loss masking spans + special_mm_tok_loss_masking_spans = np.append( + special_mm_tok_loss_masking_spans, + [[audio_pos + num_audio_pad_tokens, audio_pos + num_audio_pad_tokens]], + axis=0, + ) + # sample.loss_masking_spans = np.append(sample.loss_masking_spans, [[audio_pos + num_audio_pad_tokens, audio_pos + num_audio_pad_tokens]], axis=0) + + # update mm tokens added + mm_tokens_added += num_audio_special_tokens + num_audio_pad_tokens + start_pos = mm_position + + # add remaining text tokens + token_ids_per_sample.append(sample.token_ids[start_pos:]) + text_tokens_added += len(token_ids_per_sample[-1]) + + token_ids.append(np.concatenate(token_ids_per_sample)) + if sample.images: + images.append(sample.images) + else: + images.append([]) + if sample.audio: + # audio.append(self.apply_audio_padding(sample.audio)) + audio.append(sample.audio) + else: + audio.append([]) + if self._parameters.use_loss_masking_spans: + mm_idx = 0 + mm_tokens_before_span = 0 + + # sort by start of span + sample.loss_masking_spans = sample.loss_masking_spans[sample.loss_masking_spans[:, 0].argsort()] for loss_masking_span in sample.loss_masking_spans: + mm_tokens_within_span = 0 + mm_position, mm_type, source_idx = ( + multimodal_positions[mm_idx] + if mm_idx < len(multimodal_positions) + else (float("inf"), _, _) + ) + + # increment mm_idx until span is reached, track mm tokens before span + while mm_position < loss_masking_span[0]: + if mm_type == "image": + num_mm_tokens = image_sizes[source_idx] + elif mm_type == "audio": + num_mm_tokens = audio_token_size_arr[source_idx] + mm_tokens_before_span += num_mm_tokens + mm_idx += 1 + mm_position, mm_type, source_idx = ( + multimodal_positions[mm_idx] + if mm_idx < len(multimodal_positions) + else (float("inf"), _, _) + ) + + # get all multimodal positions within span + while mm_position >= loss_masking_span[0] and mm_position <= loss_masking_span[1]: + if mm_type == "image": + num_mm_tokens = image_sizes[source_idx] + elif mm_type == "audio": + num_mm_tokens = audio_token_size_arr[source_idx] + mm_tokens_within_span += num_mm_tokens + mm_idx += 1 + mm_position, mm_type, source_idx = ( + multimodal_positions[mm_idx] + if mm_idx < len(multimodal_positions) + else (float("inf"), _, _) + ) + loss_masking_span[0] += mm_tokens_before_span # increment by all mm tokens before span + loss_masking_span[1] += mm_tokens_before_span + mm_tokens_within_span + mm_tokens_before_span += mm_tokens_within_span + span = np.clip( - loss_masking_span + token_count - token_start, + loss_masking_span + int(token_count) - int(token_start), 0, self._parameters.sequence_length + self._parameters.extra_tokens, ) - if span[1] > span[0]: + if 
span[1] >= span[0]: loss_masking_spans.append(span) + for span in special_mm_tok_loss_masking_spans: + # span = np.clip( + # loss_masking_span + token_count - token_start, + # 0, + # self._parameters.sequence_length + self._parameters.extra_tokens, + # ) + if span[1] >= span[0]: + loss_masking_spans.append(span) # Go to the next document. document_sampling_index += 1 token_count += document_size @@ -439,9 +735,25 @@ def __getitem__(self, index: int) -> typing.Any: if self._parameters.use_loss_masking_spans else None ) - Assert.eq(len(token_ids), self._parameters.sequence_length + self._parameters.extra_tokens) - return GPTSample(token_ids=token_ids, loss_masking_spans=loss_masking_spans, sequence_lengths=sequence_lengths) + images = [im for img_list in images for im in img_list] if images else None + image_positions = np.array(image_positions) if image_positions else None + + audio = [aud for aud_list in audio for aud in aud_list] if audio else None # flatten + audio_positions = np.array(audio_positions) if audio_positions else None + # Assert.eq(len(token_ids), self._parameters.sequence_length + self._parameters.extra_tokens) + + # # TODO: Toby remove/comment after testing (for testing only first sample) + # loss_masking_spans = np.append(loss_masking_spans, [[sequence_lengths[0], sequence_lengths[:-1].sum()]], axis=0) + return GPTSample( + token_ids=token_ids, + loss_masking_spans=loss_masking_spans, + sequence_lengths=sequence_lengths, + images=images, + image_positions=image_positions, + audio=audio if audio is not None and len(audio) > 0 else None, + audio_positions=audio_positions, + ) @property def name(self) -> str: @@ -522,7 +834,7 @@ def _sample(self) -> None: Create a `GPTSampledDataset` with the requested parameters. """ logger.info(f" > Sampling dataset {self._indexed_dataset.name} ...") - document_sizes = self._indexed_dataset.get_document_sizes() + document_sizes, _ = self._indexed_dataset.get_document_sizes() num_documents = len(document_sizes) num_tokens = document_sizes.sum() np_rng = np.random.RandomState(seed=self._config.seed) diff --git a/fast_llm/data/dataset/monitor.py b/fast_llm/data/dataset/monitor.py index 86bc080f..53df3add 100644 --- a/fast_llm/data/dataset/monitor.py +++ b/fast_llm/data/dataset/monitor.py @@ -35,18 +35,16 @@ def __len__(self) -> int: def __getitem__(self, idx) -> typing.Any: start_time = time.perf_counter() - try: - sample = self._dataset[idx] - sample_time = (time.perf_counter() - start_time) * 1000 - if sample_time > self._data_sample_warn_time_ms: - logger.warning( - f"Sample {idx} from dataset {self._dataset.name})" f" took {sample_time:,.2f} ms to load" - ) - return sample - - except Exception: - logger.error(f"Failed to get sample {idx} from dataset {self._dataset.name}") - raise + # try: + sample = self._dataset[idx] + sample_time = (time.perf_counter() - start_time) * 1000 + if sample_time > self._data_sample_warn_time_ms: + logger.warning(f"Sample {idx} from dataset {self._dataset.name})" f" took {sample_time:,.2f} ms to load") + return sample + + # except Exception as e: + # logger.error(f"Failed to get sample {idx} from dataset {self._dataset.name}") + # raise @property def name(self) -> str: diff --git a/fast_llm/data/preparator/gpt_memmap/config.py b/fast_llm/data/preparator/gpt_memmap/config.py index 2c4311c3..f4e722dc 100644 --- a/fast_llm/data/preparator/gpt_memmap/config.py +++ b/fast_llm/data/preparator/gpt_memmap/config.py @@ -59,6 +59,21 @@ class GPTHuggingfaceDatasetConfig(Config): loss_masking_spans: None | str = Field( 
default=None, desc="Field containing character spans to mask for loss computation", hint=FieldHint.optional ) + image_paths: None | str = Field( + default=None, desc="Field containing images within the document", hint=FieldHint.optional + ) + image_positions: None | str = Field( + default=None, desc="Field containing image positions within a document", hint=FieldHint.optional + ) + images: None | str = Field( + default=None, desc="Field containing images relevant to a document", hint=FieldHint.optional + ) + audio_positions: None | str = Field( + default=None, desc="Field containing audio positions within a document", hint=FieldHint.optional + ) + audio: None | str = Field( + default=None, desc="Field containing audio relevant to a document", hint=FieldHint.optional + ) data_type: DataType | None = Field( default=None, desc="Data type of the dataset field." @@ -158,6 +173,11 @@ class GPTMemmapDatasetPreparatorConfig(DatasetPreparatorConfig): desc="Configuration for the tokenizer.", hint=FieldHint.feature, ) + image_patch_size: int = Field( + default=16, + desc="Patch size for images. This is used solely for computing the number of tokens in an image to get an even split.", + hint=FieldHint.optional, + ) splits: dict[str, float] | None = Field( default=None, desc="Split the output dataset into multiple ones (ex, train/valid/test) with the specified ratios." diff --git a/fast_llm/data/preparator/gpt_memmap/prepare.py b/fast_llm/data/preparator/gpt_memmap/prepare.py index 23e497bf..888e1b63 100644 --- a/fast_llm/data/preparator/gpt_memmap/prepare.py +++ b/fast_llm/data/preparator/gpt_memmap/prepare.py @@ -1,3 +1,5 @@ +import io +import itertools import json import logging import multiprocessing @@ -8,6 +10,7 @@ import datasets import huggingface_hub import numpy as np +import PIL.Image import requests import torch.distributed import tqdm @@ -38,57 +41,122 @@ class GPTMemmapDatasetPreparator[ConfigType: GPTMemmapDatasetPreparatorConfig](D _tokenizer: Tokenizer _data_type: DataType - def _tokenize_batch(self, batch: dict[str, list[typing.Any]]) -> dict[str, list[typing.Any]]: - input_ids = [ - np.array(self._tokenizer.tokenize(text), dtype=self._data_type.numpy) - for text in batch[self._config.dataset.field] - ] - num_tokens = [len(x) for x in input_ids] - return { - "input_ids": input_ids, - "num_tokens": num_tokens, - } + def _process_batch(self, batch: dict[str, list[typing.Any]]) -> dict[str, list[typing.Any]]: + pass - def _tokenize_batch_with_spans(self, batch: dict[str, list[typing.Any]]) -> dict[str, list[typing.Any]]: - input_ids, token_spans = map( + def _tokenize_batch(self, batch: dict[str, list[typing.Any]]) -> dict[str, list[typing.Any]]: + # input_ids = [ + # np.array(self._tokenizer.tokenize(text), dtype=self._data_type.numpy) + # for text in batch[self._config.dataset.field] + # ] + input_ids, token_spans, image_token_positions, audio_token_positions = map( list, zip( *[ ( np.array(input_ids, dtype=self._data_type.numpy), np.array(token_spans, dtype=np.int32).reshape(-1, 2), + np.array(image_token_positions, dtype=np.int32), + np.array(audio_token_positions, dtype=np.int32), ) - for input_ids, token_spans in [ - self._tokenizer.tokenize_with_spans(text, char_spans) - for text, char_spans in zip( - batch[self._config.dataset.field], batch[self._config.dataset.loss_masking_spans] + for input_ids, token_spans, image_token_positions, audio_token_positions in [ + self._tokenizer.tokenize( + text, + loss_mask_spans, + im_char_positions, + aud_char_positions, + ) + for text, 
loss_mask_spans, im_char_positions, aud_char_positions in zip( + batch[self._config.dataset.field], + batch.get(self._config.dataset.loss_masking_spans, itertools.repeat(None)), + batch.get(self._config.dataset.image_positions, itertools.repeat(None)), + batch.get(self._config.dataset.audio_positions, itertools.repeat(None)), ) ] ] ), ) num_tokens = [len(x) for x in input_ids] + num_pixels = [0] * len(input_ids) + for idx, images in enumerate(batch.get("images", [])): + for bytes_im in images: + with PIL.Image.open(io.BytesIO(bytes_im["bytes"])) as im: + width, height = im.size + num_pixels[idx] += width * height * 3 + + num_audio = [0] * len(input_ids) + for idx, audio_lst in enumerate(batch.get(self._config.dataset.audio, [])): + for audio in audio_lst: + num_audio[idx] += len(audio) + return { "input_ids": input_ids, + "image_positions": image_token_positions, + "audio_positions": audio_token_positions, "token_spans": token_spans, "num_tokens": num_tokens, + "num_pixels": num_pixels, + "num_audio": num_audio, } + # def _tokenize_batch_with_spans(self, batch: dict[str, list[typing.Any]]) -> dict[str, list[typing.Any]]: + # input_ids, token_spans, images, image_token_positions = map( + # list, + # zip( + # *[ + # ( + # np.array(input_ids, dtype=self._data_type.numpy), + # np.array(token_spans, dtype=np.int32).reshape(-1, 2), + # np.array(images, dtype=np.uint8), + # np.array(image_token_positions, dtype=np.int32), + # ) + # for input_ids, token_spans, images, image_token_positions in [ + # self._tokenizer.tokenize_with_spans(text, char_spans) + # for text, char_spans in zip( + # batch[self._config.dataset.field], + # batch.get(self._config.dataset.loss_masking_spans, itertools.repeat(None)), + # batch.get(self._config.dataset.images, itertools.repeat(None)), + # batch.get(self._config.dataset.image_positions, itertools.repeat(None)), + # ) + # ] + # ] + # ), + # ) + # num_tokens = [len(x) for x in input_ids] + # num_pixels = [0] * len(input_ids) + # for idx, images in enumerate(images): + # for bytes_im in images: + # with PIL.Image.open(io.BytesIO(bytes_im["bytes"])) as im: + # width, height = im.size + # num_pixels[idx] += width * height * 3 + # return { + # "input_ids": input_ids, + # "token_spans": token_spans, + # "images": images, + # "image_positions": image_token_positions, + # "num_tokens": num_tokens, + # "num_pixels": num_pixels, + # } + def _save_shard(self, args: tuple[int, datasets.Dataset]) -> GPTMemmapDatasetConfig: shard_idx, shard_dataset = args prefix = f"shard_{self._config.distributed.rank}_{shard_idx}" shard_output_path = self._config.output_path / prefix def _document_generator(): - if "token_spans" in shard_dataset.column_names and self._config.dataset.loss_masking_spans is not None: - for item in tqdm.tqdm(shard_dataset, desc=f"Saving shard {shard_idx}", unit="docs"): - yield GPTSample( - np.array(item["input_ids"], dtype=self._data_type.numpy), - np.array(item["token_spans"], dtype=np.int32).reshape(-1, 2), - ) - else: - for item in tqdm.tqdm(shard_dataset, desc=f"Saving shard {shard_idx}", unit="docs"): - yield GPTSample(np.array(item["input_ids"], dtype=self._data_type.numpy)) + for item in tqdm.tqdm(shard_dataset, desc=f"Saving shard {shard_idx}", unit="docs"): + yield GPTSample( + np.array(item["input_ids"], dtype=self._data_type.numpy), + ( + np.array(item["token_spans"], dtype=np.int32).reshape(-1, 2) + if self._config.dataset.loss_masking_spans + else None + ), + item["images"] if self._config.dataset.images else None, + item["image_positions"] if 
self._config.dataset.image_positions else None, + item[self._config.dataset.audio] if self._config.dataset.audio else None, + item[self._config.dataset.audio_positions] if self._config.dataset.audio_positions else None, + ) GPTMemmapDataset.write_dataset(prefix=shard_output_path, documents=_document_generator()) @@ -98,19 +166,25 @@ def _document_generator(): "path": prefix, "num_documents": len(shard_dataset), # Use the length of the shard dataset directly "num_tokens": sum(len(doc["input_ids"]) for doc in shard_dataset), + "num_pixels": sum(doc["num_pixels"] for doc in shard_dataset), + "num_audio": sum(doc["num_audio"] for doc in shard_dataset), } ) def _load_dataset(self) -> datasets.Dataset: - dataset = datasets.load_dataset( - path=self._config.dataset.path, - name=self._config.dataset.config_name, - data_dir=self._config.dataset.data_directory, - data_files=self._config.dataset.data_files, - split=self._config.dataset.split, - num_proc=self._config.loading_workers, - trust_remote_code=self._config.dataset.trust_remote_code, - ) + try: + dataset = datasets.load_dataset( + path=self._config.dataset.path, + name=self._config.dataset.config_name, + data_dir=self._config.dataset.data_directory, + data_files=self._config.dataset.data_files, + split=self._config.dataset.split, + num_proc=self._config.loading_workers, + trust_remote_code=self._config.dataset.trust_remote_code, + ) + except: + # backup if dataset is saved in arrow format (can we auto-detect this?) + dataset = datasets.load_from_disk(dataset_path=self._config.dataset.data_directory) assert isinstance(dataset, datasets.Dataset) return dataset @@ -214,12 +288,12 @@ def run(self) -> None: ) if self._config.dataset.field not in dataset.column_names: raise ValueError(f"Dataset does not have field '{self._config.dataset.field}'.") - if self._config.dataset.loss_masking_spans is not None: - if self._config.dataset.loss_masking_spans not in dataset.column_names: - raise ValueError(f"Dataset does not have spans field '{self._config.dataset.loss_masking_spans}'.") - tokenize_fn = self._tokenize_batch_with_spans - else: - tokenize_fn = self._tokenize_batch + tokenize_fn = self._tokenize_batch + # decoding bytes to images is slow and should be done only when needed + if self._config.dataset.images is not None: + dataset = dataset.cast_column("images", datasets.Sequence(datasets.Image(decode=False))) + if self._config.dataset.audio is not None: + dataset = dataset.cast_column("audio", datasets.Sequence(datasets.Audio(decode=False))) # Tokenize the dataset in parallel tokenized_dataset = dataset.map( @@ -227,10 +301,25 @@ def run(self) -> None: batched=True, num_proc=self._config.tokenize_workers, desc="Tokenizing batches", + load_from_cache_file=False # TODO Toby: remove ) # Calculate total number of tokens total_tokens = sum(tqdm.tqdm(tokenized_dataset["num_tokens"], desc="Counting tokens", unit="tokens")) + total_pixels = ( + sum(tqdm.tqdm(tokenized_dataset["num_pixels"], desc="Counting pixels", unit="pixels")) + if self._config.dataset.images + else 0 + ) + total_audio = ( + sum(tqdm.tqdm(tokenized_dataset["num_audio"], desc="Counting audio", unit="audio")) + if self._config.dataset.audio + else 0 + ) + total_tokens += total_pixels // np.dtype(self._data_type.numpy).itemsize + total_tokens += total_audio * np.float32().itemsize // np.dtype(self._data_type.numpy).itemsize + + tokenized_dataset = tokenized_dataset.shuffle(seed=42) # Split dataset into shards based on number of tokens num_shards = int(np.ceil(total_tokens / 
self._config.tokens_per_shard)) @@ -259,7 +348,7 @@ def generate_config_yaml_for_sharded_dst(self, dataset_configs: list[GPTMemmapDa # Create the config file(s) on rank 0 if self._config.splits: for split_name, split_config in self._split_and_blend_dataset_configs( - dataset_configs, self._config.splits, self._config.output_path + dataset_configs, self._config.splits, self._config.output_path, self._config.image_patch_size ).items(): self._save_dataset_config( split_config, self._config.output_path / f"fast_llm_config_{split_name}.yaml" @@ -299,7 +388,11 @@ def _blend_dataset_configs(cls, dataset_configs: list[GPTMemmapDatasetConfig]) - @classmethod def _split_and_blend_dataset_configs( - cls, dataset_configs: list[GPTMemmapDatasetConfig], splits: dict[str, int | float], output_path: pathlib.Path + cls, + dataset_configs: list[GPTMemmapDatasetConfig], + splits: dict[str, int | float], + output_path: pathlib.Path, + image_patch_size: int, ) -> dict[str, GPTSampledDatasetConfig]: split_cumsum = padded_cumsum(normalize_probabilities(list(splits.values()), return_array=True)).tolist() dataset_sizes = [dataset_config.num_tokens for dataset_config in dataset_configs] @@ -329,10 +422,20 @@ def _split_and_blend_dataset_configs( # Part of the dataset belongs to the split. # TODO: Somehow getting a segfault when merging two lines below (numpy bug?). dataset = dataset_config.to_copy({"path": output_path / dataset_config.path}).build() - sizes_cumsum = dataset.get_document_sizes().cumsum() - Assert.eq(sizes_cumsum[-1], dataset_config.num_tokens) - begin_index = _get_nearest_split(sizes_cumsum, split_begin_in_dataset * dataset_config.num_tokens) - end_index = _get_nearest_split(sizes_cumsum, split_end_in_dataset * dataset_config.num_tokens) + text_sizes, image_sizes = dataset.get_document_sizes() + tokens_cumsum = text_sizes.cumsum() + Assert.eq(tokens_cumsum[-1], dataset_config.num_tokens) + if image_sizes: + num_pixels_cumsum = np.cumsum([x.prod(axis=1).sum() for x in image_sizes]) + # We use the patch sizes only for the purposes of even splitting and blending weights. + # We can always use a different patch size for training without any significant impact + # Unless the patch size used at training time is significantly different from the one used here + image_tokens_cumsum = num_pixels_cumsum // (image_patch_size**2) + tokens_cumsum += image_tokens_cumsum + num_pixels_cumsum = num_pixels_cumsum * 3 + Assert.eq(num_pixels_cumsum[-1], dataset_config.num_pixels) + begin_index = _get_nearest_split(tokens_cumsum, split_begin_in_dataset * tokens_cumsum[-1]) + end_index = _get_nearest_split(tokens_cumsum, split_end_in_dataset * tokens_cumsum[-1]) if end_index > begin_index: datasets_in_split.append( GPTDatasetSliceConfig.from_dict( @@ -345,8 +448,8 @@ def _split_and_blend_dataset_configs( ) ) dataset_tokens_in_split.append( - sizes_cumsum[end_index - 1].item() - - (sizes_cumsum[begin_index - 1].item() if begin_index > 0 else 0) + tokens_cumsum[end_index - 1].item() + - (tokens_cumsum[begin_index - 1].item() if begin_index > 0 else 0) ) # [else] None of the dataset belongs to the split. 
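For reference, the pixel-to-token conversion in `_split_and_blend_dataset_configs` above only affects how documents are weighted when computing split boundaries; training can use a different patch size. A small worked example with illustrative numbers (not taken from any real dataset):

```python
# One 512x384 RGB image with a 16-pixel patch size.
height, width, patch_size = 512, 384, 16

pixels_per_channel = height * width                        # 196_608, what image_lengths.prod(axis=1) yields
image_token_equiv = pixels_per_channel // patch_size**2    # 768, added to tokens_cumsum for split boundaries
stored_bytes = 3 * pixels_per_channel                      # 589_824 uint8 values actually written to the .bin file

print(image_token_equiv, stored_bytes)                     # 768 589824
```

Since `num_pixels` in the dataset config counts all three channels, the per-channel cumulative pixel counts are multiplied by 3 before being checked against it.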
diff --git a/fast_llm/data/tokenizer.py b/fast_llm/data/tokenizer.py index 28e105ee..e37b0e6d 100644 --- a/fast_llm/data/tokenizer.py +++ b/fast_llm/data/tokenizer.py @@ -35,43 +35,128 @@ def vocab(self) -> dict[str, int]: def inv_vocab(self) -> dict[int, str]: return self._inv_vocab - def tokenize(self, text: str, begin=True, end=True) -> list[int]: + def _tokenize(self, text: str, begin=True, end=True) -> list[int]: return ( ([self.bod_id] if begin else []) + self.tokenizer.encode(text, add_special_tokens=False) + ([self.eod_id] if end else []) ) - def tokenize_with_spans( - self, text: str, char_spans: list[tuple[int, int]] + def tokenize( + self, text: str, char_spans=None, image_positions=None, audio_positions=None ) -> tuple[list[int], list[tuple[int, int]]]: """ - Perform span-aware tokenization and return the tokenized input_ids along with token spans. + Tokenize the input text and return the tokenized input_ids and if provided, token spans and image positions. """ - input_ids = [] - token_spans = [] + image_positions = image_positions or [] + audio_positions = audio_positions or [] + char_spans = char_spans or [] + + if len(set(image_positions).intersection(audio_positions)) > 0: + raise ValueError("Image and audio can not have the same position.") + multimodal_positions = sorted(image_positions + audio_positions) + + mm_idx = 0 char_pos = 0 + token_ids = [] + image_token_positions = [] + audio_token_positions = [] + token_spans = [] beginning_of_text = True + multimodal_position = multimodal_positions[mm_idx] if mm_idx < len(multimodal_positions) else float("inf") for start, end in char_spans: + # tokenize text, compute mm token position before span + while multimodal_position <= start: + # tokenize text before mm position + tokenized_text = self._tokenize(text[char_pos:multimodal_position], begin=beginning_of_text, end=False) + beginning_of_text = False + token_ids.extend(tokenized_text) + + # update mm token positions + multimodal_type = "image" if multimodal_position in image_positions else "audio" + if multimodal_type == "image": + image_token_positions.append(len(token_ids)) + else: + audio_token_positions.append(len(token_ids)) + + # updates + mm_idx += 1 + char_pos = multimodal_position + multimodal_position = ( + multimodal_positions[mm_idx] if mm_idx < len(multimodal_positions) else float("inf") + ) + + # tokenize remaining text before span if char_pos < start: - curr_text = text[char_pos:start] - tokenized_text = self.tokenize(curr_text, begin=beginning_of_text, end=False) + self._tokenize(text[char_pos:start], begin=beginning_of_text, end=False) beginning_of_text = False - input_ids.extend(tokenized_text) - curr_text = text[start : end + 1] - if end >= len(text) - 1: - tokenized_text = self.tokenize(curr_text, begin=beginning_of_text, end=True) - else: - tokenized_text = self.tokenize(curr_text, begin=beginning_of_text, end=False) + token_ids.extend(tokenized_text) + + char_pos = start + span_length = 0 + token_start = len(token_ids) + + # tokenize text, compute mm token position within span + while multimodal_position <= end: + # tokenize text before mm position + tokenized_text = self._tokenize(text[char_pos:multimodal_position], begin=beginning_of_text, end=False) + beginning_of_text = False + token_ids.extend(tokenized_text) + + # update mm token positions + multimodal_type = "image" if multimodal_position in image_positions else "audio" + if multimodal_type == "image": + image_token_positions.append(len(token_ids)) + else: + 
audio_token_positions.append(len(token_ids)) + + # updates + span_length += len(tokenized_text) + char_pos = multimodal_position + mm_idx += 1 + multimodal_position = ( + multimodal_positions[mm_idx] if mm_idx < len(multimodal_positions) else float("inf") + ) + + # tokenize remaining text until end of span + if char_pos < end: + if end >= len(text) - 1: + tokenized_text = self._tokenize(text[char_pos : end + 1], begin=beginning_of_text, end=True) + else: + tokenized_text = self._tokenize(text[char_pos : end + 1], begin=beginning_of_text, end=False) + beginning_of_text = False + token_ids.extend(tokenized_text) + span_length += len(tokenized_text) + char_pos = end + 1 + + # update token spans + token_spans.append((token_start, token_start + span_length - 1)) + + # tokenize text, compute mm token position after all spans + while multimodal_position <= len(text): + # tokenize text before mm position + multimodal_position = multimodal_positions[mm_idx] + tokenized_text = self._tokenize(text[char_pos:multimodal_position], begin=beginning_of_text, end=False) beginning_of_text = False - token_spans.append((len(input_ids), len(input_ids) + len(tokenized_text) - 1)) - input_ids.extend(tokenized_text) - char_pos = end + 1 - if char_pos < len(text): - curr_text = text[char_pos:] - tokenized_text = self.tokenize(curr_text, begin=beginning_of_text, end=True) - input_ids.extend(tokenized_text) - return input_ids, token_spans + token_ids.extend(tokenized_text) + + # update mm token positions + multimodal_type = "image" if multimodal_position in image_positions else "audio" + if multimodal_type == "image": + image_token_positions.append(len(token_ids)) + else: + audio_token_positions.append(len(token_ids)) + + # updates + char_pos = multimodal_position + mm_idx += 1 + multimodal_position = multimodal_positions[mm_idx] if mm_idx < len(multimodal_positions) else float("inf") + + # tokenize text after all spans + tokenized_text = self._tokenize(text[char_pos:], begin=beginning_of_text, end=True) + token_ids.extend(tokenized_text) + + return token_ids, token_spans, image_token_positions, audio_token_positions def detokenize(self, token_ids: int | list[int] | np.ndarray | torch.Tensor) -> str: return self.tokenizer.decode(token_ids) diff --git a/fast_llm/engine/multi_stage/stage.py b/fast_llm/engine/multi_stage/stage.py index 675e878b..b1c7df81 100644 --- a/fast_llm/engine/multi_stage/stage.py +++ b/fast_llm/engine/multi_stage/stage.py @@ -121,7 +121,7 @@ def backward( assert self._mode.support_backward input_, output = grad_context output.backward(output_grad) - return input_.grad + return input_.grad if input_.grad is not None else torch.zeros_like(input_) def restore_parameters(self) -> None: assert self._is_setup diff --git a/fast_llm/engine/schedule/config.py b/fast_llm/engine/schedule/config.py index 141490ac..f2acf9b6 100644 --- a/fast_llm/engine/schedule/config.py +++ b/fast_llm/engine/schedule/config.py @@ -49,6 +49,17 @@ class BatchConfig(Config): desc="Pointer to a distributed configuration, required to know the data-parallel split of the batch.", hint=FieldHint.setup, ) + # Image inputs + image_size: int | None = Field( + default=None, + desc="Maximum image height and width", + hint=FieldHint.optional, + ) + aud_padding_duration: int = Field( + default=-1, + desc="Audio padding duration in seconds.", + hint=FieldHint.feature, + ) def setup(self, distributed_config: DistributedConfig) -> None: self._distributed = distributed_config diff --git a/fast_llm/functional/config.py 
b/fast_llm/functional/config.py index 22f23174..480fa067 100644 --- a/fast_llm/functional/config.py +++ b/fast_llm/functional/config.py @@ -40,6 +40,7 @@ class ActivationType(enum.StrEnum): """ gelu = "gelu" + gelu_pytorch_tanh = "gelu_pytorch_tanh" silu = "silu" relu = "relu" squared_relu = "squared_relu" @@ -67,7 +68,8 @@ def _set_activation_fn_map() -> None: global _ACTIVATION_FN_MAP _ACTIVATION_FN_MAP = { - ActivationType.gelu: lambda x: torch.nn.functional.gelu(x, approximate="tanh"), + ActivationType.gelu: torch.nn.functional.gelu, + ActivationType.gelu_pytorch_tanh: lambda x: torch.nn.functional.gelu(x, approximate="tanh"), ActivationType.silu: torch.nn.functional.silu, ActivationType.relu: torch.nn.functional.relu, ActivationType.squared_relu: lambda x: torch.pow(torch.nn.functional.relu(x), 2), @@ -78,7 +80,8 @@ def _set_activation_fn_map() -> None: _ACTIVATION_FN_MAP: dict[ActivationType, typing.Callable[["torch.Tensor"], "torch.Tensor"]] = {} _ACTIVATION_HF_NAMES = { - ActivationType.gelu: "gelu_pytorch_tanh", + ActivationType.gelu: "gelu", + ActivationType.gelu_pytorch_tanh: "gelu_pytorch_tanh", ActivationType.silu: "silu", ActivationType.relu: "relu", ActivationType.squared_relu: "relu2", @@ -86,6 +89,7 @@ def _set_activation_fn_map() -> None: } _ACTIVATION_HF_NAMES_INV = {value: key for key, value in _ACTIVATION_HF_NAMES.items()} + MAX_DROPLESS_BLOCK_SIZE_ROW = 128 diff --git a/fast_llm/functional/cross_entropy.py b/fast_llm/functional/cross_entropy.py index 513510ec..53b5979e 100644 --- a/fast_llm/functional/cross_entropy.py +++ b/fast_llm/functional/cross_entropy.py @@ -145,7 +145,7 @@ def _fused_cross_entropy_forward_backward( per_sample_loss = sum_exp_logits.log() - predicted_logits if loss_mask is not None: - per_sample_loss = per_sample_loss * loss_mask + per_sample_loss = per_sample_loss[loss_mask] loss = per_sample_loss.mean() if target_format != TargetFormat.labels and group is not None: diff --git a/fast_llm/functional/triton/mlp.py b/fast_llm/functional/triton/mlp.py index ee3ba304..0fb71bd5 100644 --- a/fast_llm/functional/triton/mlp.py +++ b/fast_llm/functional/triton/mlp.py @@ -50,7 +50,7 @@ def triton_mlp_activation_forward_kernel( input_ = tl.load(input_ptr, mask=mask).to(tl.float32) - if activation_type == _TritonActivationType.gelu: + if activation_type == _TritonActivationType.gelu_pytorch_tanh: tanh_input = 0.79788456 * input_ * (1 + 0.044715 * input_ * input_) tanh = 1 - 2 / (1 + tl.exp(2 * tanh_input)) out = input_ * 0.5 * (1.0 + tanh) @@ -100,7 +100,7 @@ def triton_mlp_activation_backward_kernel( input_ = tl.load(input_ptr, mask=mask).to(tl.float32) output_grad = tl.load(grad_output_ptr + output_offsets, mask=mask).to(tl.float32) - if activation_type == _TritonActivationType.gelu: + if activation_type == _TritonActivationType.gelu_pytorch_tanh: tanh_input = 0.79788456 * input_ * (1 + 0.044715 * input_ * input_) tanh = 1 - 2 / (1 + tl.exp(2 * tanh_input)) grad = 0.5 * input_ * ((1 - tanh * tanh) * (0.79788456 + 0.1070322243 * input_ * input_)) + 0.5 * (1 + tanh) diff --git a/fast_llm/layers/audio_encoder/adapter.py b/fast_llm/layers/audio_encoder/adapter.py new file mode 100644 index 00000000..bc4f8f00 --- /dev/null +++ b/fast_llm/layers/audio_encoder/adapter.py @@ -0,0 +1,87 @@ +import typing + +import torch + +from fast_llm.engine.base_model.base_model import Layer +from fast_llm.engine.config_utils.tensor_space import TensorSpace +from fast_llm.functional.triton.mlp import torch_mlp_activation +from fast_llm.layers.audio_encoder.config import 
AudioEncoderConfig, AudioEncoderDimNames
+from fast_llm.layers.common.linear import Linear
+from fast_llm.layers.transformer.config import TransformerDimNames, TransformerKwargs
+from fast_llm.tensor import TensorMeta, init_normal_
+
+
+class AudioAdapter(Layer):
+    """
+    Audio adapter layer for the LLM.
+    """
+
+    def __init__(self, config: AudioEncoderConfig, tensor_space: TensorSpace):
+        super().__init__()
+        audio_hidden_dim = tensor_space.get_tensor_dim(AudioEncoderDimNames.out_channels)
+        input_dim = tensor_space.get_tensor_dim(AudioEncoderDimNames.adapter_input)
+        self._activation_type = config.adapter_activation_type
+        self._use_adapter_bias = config.adapter_bias
+        self.lr_scale = config.adapter_lr_scale
+
+        self.norm_1 = config.transformer.normalization.get_layer(audio_hidden_dim)
+        self.norm_1.lr_scale = self.lr_scale
+        self.norm_2 = config.transformer.normalization.get_layer(
+            tensor_space.get_tensor_dim(AudioEncoderDimNames.adapter_size)
+        )
+        self.norm_2.lr_scale = self.lr_scale
+
+        # TODO Soham: Make them OutputParallelLinear instead? How would this work with parallelism?
+        self.layer_1 = Linear(
+            input_dim,
+            tensor_space.get_tensor_dim(AudioEncoderDimNames.adapter_size),
+            bias=self._use_adapter_bias,
+            weight_init_method=init_normal_(),
+            bias_init_method=init_normal_(),
+            lr_scale=self.lr_scale,
+        )
+        self.layer_2 = Linear(
+            tensor_space.get_tensor_dim(AudioEncoderDimNames.adapter_size),
+            tensor_space.get_tensor_dim(TransformerDimNames.hidden),
+            bias=self._use_adapter_bias,
+            weight_init_method=init_normal_(),
+            bias_init_method=init_normal_(),
+            lr_scale=self.lr_scale,
+        )
+
+        self.aud_downsampling_k = config.aud_downsampling_k
+
+    def forward(
+        self,
+        input_: torch.Tensor,
+        kwargs: dict[str, typing.Any],
+        losses: dict[str, typing.Any] | None = None,
+        metrics: dict[str, typing.Any] | None = None,
+    ) -> torch.Tensor:
+        if isinstance(input_, TensorMeta):
+            return TensorMeta.from_dims(
+                kwargs[TransformerKwargs.hidden_dims],
+                tensor_name="Audio adapter output",
+                dtype=input_.dtype,
+            )
+        input_ = self.norm_1(input_)
+        batch_size, seq_len, dim = input_.size()
+
+        # Check if sequence length is divisible by downsampling rate.
+        if seq_len % self.aud_downsampling_k != 0:
+            # If not divisible, trim the end of the sequence.
+            trimmed_seq_len = seq_len - (seq_len % self.aud_downsampling_k)
+            input_ = input_[:, :trimmed_seq_len, :]
+            seq_len = trimmed_seq_len
+
+        # Reshape: group every k frames together (concatenate along feature dimension).
+ new_seq_len = seq_len // self.aud_downsampling_k + input_ = input_.contiguous().view(batch_size, new_seq_len, dim * self.aud_downsampling_k) + layer1_res = torch_mlp_activation( + input_=self.layer_1(input_), gated=False, activation_type=self._activation_type + ) + torch.manual_seed(0) # TODO Toby: remove after debugging + layer1_res_dropout = torch.nn.functional.dropout(layer1_res, 0.1) + layer1_res_norm = self.norm_2(layer1_res_dropout) + layer2_res = self.layer_2(layer1_res_norm) + return layer2_res diff --git a/fast_llm/layers/audio_encoder/config.py b/fast_llm/layers/audio_encoder/config.py new file mode 100644 index 00000000..95665901 --- /dev/null +++ b/fast_llm/layers/audio_encoder/config.py @@ -0,0 +1,159 @@ +import enum + +from fast_llm.config import Field, FieldHint, check_field, config_class, skip_valid_if_none +from fast_llm.engine.base_model.config import BaseModelConfig +from fast_llm.engine.config_utils.tensor_space import TensorDim, TensorSpace +from fast_llm.functional.config import ActivationType +from fast_llm.layers.transformer.config import AudioTransformerConfig +from fast_llm.utils import Assert + + +class AudioEncoderDimNames: + in_channels = "audio_in_channels" + out_channels = "audio_out_channels" + kernel_size = "audio_kernel_size" + adapter_input = "audio_adapter_input" + adapter_size = "audio_adapter_size" + audio_channels = "audio_kv_channels" + max_source_positions = "audio_max_source_positions" + + +class AudioEncoderKwargs: + audio = "audio" + audio_mel = "audio_mel" + audio_positions = "audio_positions" + + kv_channels = "audio_kv_channels" # TODO: check this + out_channels = "audio_out_channels" + hidden_dims = "audio_hidden_dims" + + # TODO: used for backup attention + sequence_length = "audio_sequence_length" + sequence_k_dim = "audio_sequence_k_dim" + sequence_q_dim = "audio_sequence_q_dim" + + +class AudioEncoderType(str, enum.Enum): + none = "none" + whisper = "whisper" + + +@config_class() +class AudioEncoderConfig(BaseModelConfig): + _abstract = False + + type: AudioEncoderType = Field( + default=AudioEncoderType.none, + desc="Type of the audio encoder. Choices: none, whisper.", + hint=FieldHint.architecture, + ) + transformer: AudioTransformerConfig = Field( + default_factory=AudioTransformerConfig, + desc="Configuration for the audio transformer architecture.", + hint=FieldHint.core, + ) + + # encoder configs + conv_bias: bool = Field( + default=True, + desc="Whether to use bias in the convolutional layer.", + hint=FieldHint.optional, + ) + encoder_dropout: float = Field( + default=0.0, + desc="Dropout for encoder.", + hint=FieldHint.core, + ) + kernel_size: int = Field( + default=3, + desc="Encoder convolution layer kernel size.", + hint=FieldHint.core, + ) + conv_lr_scale: float | None = Field( + default=None, + desc="Custom learning rate scale for the convolutional layer weights.", + hint=FieldHint.feature, + valid=skip_valid_if_none(check_field(Assert.geq, 0)), + ) + pos_emb_lr_scale: float | None = Field( + default=None, + desc="Custom learning rate scale for the position embedding layer weights.", + hint=FieldHint.feature, + valid=skip_valid_if_none(check_field(Assert.geq, 0)), + ) + + # adapter configs + adapter_size: int = Field( + default=5120, + desc="Intermediate size for the adapter linear layers. Assuming 2 linear layers", + hint=FieldHint.core, + ) + adapter_activation_type: ActivationType = Field( + default=ActivationType.gelu, + desc="The intermediate activation type for multi-modal adapter. 
Default: GeLU.", + hint=FieldHint.core, + ) + adapter_bias: bool = Field( + default=True, + desc="Whether to use bias in the adapter layer.", + hint=FieldHint.optional, + ) + adapter_lr_scale: float | None = Field( + default=None, + desc="Custom learning rate scale for the adapter weights.", + hint=FieldHint.feature, + valid=skip_valid_if_none(check_field(Assert.geq, 0)), + ) + + # audio configs + num_mel_bins: int = Field( + default=80, + desc="Number of bins for mel spectogram.", + hint=FieldHint.core, + ) + aud_downsampling_k: int = Field( + default=5, + desc="Audio downsampling k parameter.", + hint=FieldHint.feature, + ) + aud_sampling_rate: int = Field( + default=16000, + desc="Audio sampling rate to use.", + hint=FieldHint.feature, + ) + + # audio start/end tokens + audio_start_token: int | None = Field( + default=None, + desc="Token id for audio start.", + hint=FieldHint.optional, + ) + audio_end_token: int | None = Field( + default=None, + desc="Token id for audio end.", + hint=FieldHint.optional, + ) + + def setup_tensor_space(self, tensor_space: TensorSpace): + tensor_space.add_tensor_dim(TensorDim(AudioEncoderDimNames.in_channels, self.num_mel_bins)) + tensor_space.add_tensor_dim(TensorDim(AudioEncoderDimNames.out_channels, self.transformer.hidden_size)) + tensor_space.add_tensor_dim(TensorDim(AudioEncoderDimNames.kernel_size, self.kernel_size)) + tensor_space.add_tensor_dim( + TensorDim(AudioEncoderDimNames.adapter_input, self.transformer.hidden_size * self.aud_downsampling_k) + ) + tensor_space.add_tensor_dim(TensorDim(AudioEncoderDimNames.adapter_size, self.adapter_size)) + tensor_space.add_tensor_dim( + TensorDim(AudioEncoderDimNames.max_source_positions, 1500) + ) # TODO: configure later + + tensor_space.add_tensor_dim( + TensorDim( + AudioEncoderDimNames.audio_channels, + self.transformer.hidden_size // self.transformer.num_attention_heads, + ) + ) + self.transformer.setup_tensor_space(tensor_space) + + @property + def enabled(self) -> bool: + return self.type != AudioEncoderType.none diff --git a/fast_llm/layers/audio_encoder/encoder.py b/fast_llm/layers/audio_encoder/encoder.py new file mode 100644 index 00000000..b35cc174 --- /dev/null +++ b/fast_llm/layers/audio_encoder/encoder.py @@ -0,0 +1,93 @@ +import typing + +import torch + +from fast_llm.engine.base_model.base_model import Layer +from fast_llm.engine.config_utils.tensor_space import TensorSpace +from fast_llm.layers.audio_encoder.config import AudioEncoderConfig, AudioEncoderDimNames +from fast_llm.layers.transformer.config import AudioTransformerKwargs +from fast_llm.tensor import ParameterMeta, TensorMeta, init_normal_ + + +class AudioConv(Layer): + def __init__(self, config: AudioEncoderConfig, tensor_space: TensorSpace): + super().__init__() + self._tensor_space = tensor_space + self.dropout_p = config.encoder_dropout + self._conv_lr_scale = config.conv_lr_scale + self._pos_emb_lr_scale = config.pos_emb_lr_scale + + self.conv1_weight = ParameterMeta.from_dims( + ( + self._tensor_space.get_tensor_dim(AudioEncoderDimNames.out_channels), + self._tensor_space.get_tensor_dim(AudioEncoderDimNames.in_channels), + self._tensor_space.get_tensor_dim(AudioEncoderDimNames.kernel_size), + ), + init_method=init_normal_(), + lr_scale=self._conv_lr_scale, + ) + self.conv1_stride = 1 # TODO Toby: parameterize? 
+ + self.conv2_weight = ParameterMeta.from_dims( + ( + self._tensor_space.get_tensor_dim(AudioEncoderDimNames.out_channels), + self._tensor_space.get_tensor_dim(AudioEncoderDimNames.out_channels), + self._tensor_space.get_tensor_dim(AudioEncoderDimNames.kernel_size), + ), + init_method=init_normal_(), + lr_scale=self._conv_lr_scale, + ) + self.conv2_stride = 2 # TODO Toby: parameterize? + + if config.conv_bias: + self.conv1_bias = ParameterMeta.from_dims( + (self._tensor_space.get_tensor_dim(AudioEncoderDimNames.out_channels),), + init_method=init_normal_(), + lr_scale=self._conv_lr_scale, + ) + self.conv2_bias = ParameterMeta.from_dims( + (self._tensor_space.get_tensor_dim(AudioEncoderDimNames.out_channels),), + init_method=init_normal_(), + lr_scale=self._conv_lr_scale, + ) + else: + self.conv1_bias = None + self.conv2_bias = None + + self.positional_embeddings = ParameterMeta.from_dims( + ( + self._tensor_space.get_tensor_dim(AudioEncoderDimNames.max_source_positions), + self._tensor_space.get_tensor_dim(AudioEncoderDimNames.out_channels), + ), + init_method=init_normal_(), + lr_scale=self._pos_emb_lr_scale, + ) + + def forward( + self, + input_: torch.Tensor, + kwargs: dict[str, typing.Any], + losses: dict[str, typing.Any] | None = None, + metrics: dict | None = None, + ) -> torch.Tensor: + hidden_dims = kwargs[AudioTransformerKwargs.hidden_dims] # TODO: check seq q + if isinstance(input_, TensorMeta): + return TensorMeta.from_dims(hidden_dims, tensor_name="audio conv output", dtype=input_.dtype) + + # TODO Toby: check how to best cast dtype + input_ = input_.to(self.conv1_weight.dtype) + + input_ = torch.nn.functional.conv1d( + input_, self.conv1_weight, self.conv1_bias, stride=self.conv1_stride, padding=1 + ) + input_ = torch.nn.functional.gelu(input_) + input_ = torch.nn.functional.conv1d( + input_, self.conv2_weight, self.conv2_bias, stride=self.conv2_stride, padding=1 + ) + input_ = torch.nn.functional.gelu(input_) + + audio_embeddings = input_.permute(0, 2, 1) + audio_embeddings = audio_embeddings + self.positional_embeddings + audio_embeddings = torch.nn.functional.dropout(audio_embeddings, p=self.dropout_p, training=self.training) + + return audio_embeddings.contiguous() diff --git a/fast_llm/layers/audio_encoder/preprocessing.py b/fast_llm/layers/audio_encoder/preprocessing.py new file mode 100644 index 00000000..21262fe9 --- /dev/null +++ b/fast_llm/layers/audio_encoder/preprocessing.py @@ -0,0 +1,153 @@ +import math +import typing + +import numpy as np +import torch +from transformers import WhisperFeatureExtractor + +from fast_llm.engine.base_model.config import Preprocessor +from fast_llm.engine.config_utils.tensor_space import TensorSpace +from fast_llm.layers.audio_encoder.config import AudioEncoderConfig, AudioEncoderKwargs + + +def get_num_audio_tokens( + sizes, aud_padding_duration, aud_sampling_rate, aud_downsampling_k, audio_start_token, audio_end_token +): + if len(sizes) == 0: # sample has no audio + return np.array(sizes), False + to_filter = False + # account for padding + if aud_padding_duration > 0: + raw_audio_seq_length = aud_padding_duration * aud_sampling_rate + sizes = sizes.copy() # original is read-only + to_filter = bool(np.any(sizes > raw_audio_seq_length)) # filter sample where any audio is too long + sizes.fill(raw_audio_seq_length) # set all audio sizes to padded amount + + # account for mel spectogram, convolution, downsampling k + audio_token_size_arr = sizes // 160 # default hop length TODO Toby: check divisible? 
+ audio_token_size_arr = audio_token_size_arr // ( + 2 * aud_downsampling_k + ) # convolution (2 stride) * downsampling TODO Toby: make configurable convolution + + if audio_start_token is not None: + audio_token_size_arr += 1 + if audio_end_token is not None: + audio_token_size_arr += 1 + return audio_token_size_arr, to_filter + + +def apply_audio_padding(audio, aud_padding_duration, aud_sampling_rate): + if len(audio) == 0: + return audio + # TODO Toby: check 2d + padded_audio = [] + if aud_padding_duration > 0: + raw_audio_seq_length = aud_padding_duration * aud_sampling_rate + for aud in audio: + padded = np.pad(aud, (0, raw_audio_seq_length - len(aud)), mode="constant", constant_values=0) + padded_audio.append(padded) + return padded_audio + else: + return audio + + +class AudioPreprocessor(Preprocessor): + def __init__(self, config: AudioEncoderConfig, tensor_space: TensorSpace): + self._config = config + self._tensor_space = tensor_space + self._distributed_config = self._tensor_space.distributed_config + + self.feature_extractor = WhisperFeatureExtractor(sampling_rate=self._config.aud_sampling_rate) + + # self.mel_transform = MelSpectrogram( + # sample_rate=self._config.aud_sampling_rate, + # n_fft=400, + # win_length=400, + # hop_length=160, + # n_mels=80, + # f_min=0.0, + # f_max=8000.0, + # mel_scale="slaney", + # norm="slaney", + # center=True, + # power=2.0, + # ) + + def preprocess_meta(self, kwargs: dict[str, typing.Any]) -> None: + # kwargs[AudioEncoderKwargs.audio_mel_meta] = TensorMeta.from_dims( + # ( + # TensorDim( + # VisionTransformerDimNames.batch, + # kwargs[TransformerKwargs.micro_batch_size] * kwargs[TransformerKwargs.sequence_q_dim].size, + # ), + # TensorDim(VisionEncoderDimNames.in_channels, 3), + # TensorDim(VisionEncoderDimNames.patch_size, kwargs[VisionEncoderKwargs.patch_size]), + # TensorDim(VisionEncoderDimNames.patch_size, kwargs[VisionEncoderKwargs.patch_size]), + # ), + # dtype=self._distributed_config.training_dtype.torch, + # ) + pass + + def preprocess(self, tokens, kwargs: dict[str, typing.Any]) -> None: + # check if audio is in batch + audio_mel = [] + if AudioEncoderKwargs.audio in kwargs: + print("Preprocessing Contains Audio") + audio_raw = kwargs[AudioEncoderKwargs.audio] + flattened_audio = [ + audio_arr for sequence in audio_raw for audio_arr in sequence + ] # flatten in the batch dimension + print("Preprocessing Flattened Audio: ", flattened_audio) + + for audio in flattened_audio: + audio_mel.append( + self.feature_extractor( + audio, + sampling_rate=self._config.aud_sampling_rate, + return_tensors="pt", + max_length=30 * self._config.aud_sampling_rate, + device=self._tensor_space.distributed.device, + )["input_features"] + ) + audio_mel = torch.stack(audio_mel, dim=0).squeeze(1) + curr_size = audio_mel.size(0) + else: + print("Preprocessing No Audio") + audio_mel = torch.tensor(audio_mel, dtype=torch.float32) + curr_size = 0 + + print("Preprocessing Audio Mel Raw: ", audio_mel) + + # compute max pad + max_pad = math.ceil( + kwargs["sequence_length"] / (kwargs["audio_encoder_sequence_length"] // self._config.aud_downsampling_k) + ) + max_pad = 1 + max_pad = max(max_pad, curr_size) + + # add padding + padding_size = max_pad - curr_size + if padding_size > 0: + padding = torch.zeros( + padding_size, + self.feature_extractor.feature_size, + self.feature_extractor.nb_max_frames, + dtype=audio_mel.dtype, + device=audio_mel.device, + ) + audio_mel = torch.cat((audio_mel, padding), dim=0) + + print("Preprocessing Audio Mel Final: ", audio_mel) + 
+ # move to device + audio_mel = audio_mel.to(self._tensor_space.distributed.device) + kwargs[AudioEncoderKwargs.audio_mel] = audio_mel + + # # set attention mask # TODO Toby: fix backup attention + # sequence_k = kwargs[self._transformer_kwargs.sequence_k_dim].size + # sequence_q = kwargs[self._transformer_kwargs.sequence_q_dim].size + # kwargs[self._transformer_kwargs.attention_mask] = self._mask[ + # None, None, sequence_k - sequence_q : sequence_k, None, :sequence_k + # ] + # kwargs[self._transformer_kwargs.attention_mask_value] = self._mask_value + # audio_mel = torch.rand(len(flattened_audio), 80, 3000) diff --git a/fast_llm/layers/language_model/config.py b/fast_llm/layers/language_model/config.py index d0f03ccf..8ba066cb 100644 --- a/fast_llm/layers/language_model/config.py +++ b/fast_llm/layers/language_model/config.py @@ -5,7 +5,9 @@ from fast_llm.engine.config_utils.tensor_space import TensorDim, TensorSpace from fast_llm.engine.distributed.config import DistributedDimNames from fast_llm.functional.config import CrossEntropyImpl +from fast_llm.layers.audio_encoder.config import AudioEncoderConfig from fast_llm.layers.transformer.config import TransformerConfig +from fast_llm.layers.vision_encoder.config import VisionEncoderConfig from fast_llm.utils import Assert @@ -33,6 +35,7 @@ class LanguageModelKwargs: position_ids = "position_ids" # TODO: These are generic labels = "labels" + tokens = "tokens" phase = "phase" loss_mask = "loss_mask" @@ -44,6 +47,16 @@ class LanguageModelBaseConfig(BaseModelConfig): desc="Configuration for the transformer architecture.", hint=FieldHint.architecture, ) + vision_encoder: VisionEncoderConfig = Field( + default_factory=VisionEncoderConfig, + desc="Configuration for the vision encoder that transforms images into embeddings.", + hint=FieldHint.optional, + ) + audio_encoder: AudioEncoderConfig = Field( + default_factory=AudioEncoderConfig, + desc="Configuration for the audio encoder that transforms audio into embeddings.", + hint=FieldHint.optional, + ) max_position_embeddings: int = Field( default=2048, desc="Number of absolute position embeddings, if applicable.", @@ -167,6 +180,10 @@ def setup_tensor_space(self, tensor_space: TensorSpace) -> None: # TODO: Need both? 
tensor_space.add_tensor_dim(TensorDim(LanguageModelDimNames.vocab, self.vocab_size)) tensor_space.add_tensor_dim(TensorDim(LanguageModelDimNames.vocab_tp, self.vocab_size, tensor)) + if self.vision_encoder.enabled: + self.vision_encoder.setup_tensor_space(tensor_space) + if self.audio_encoder.enabled: + self.audio_encoder.setup_tensor_space(tensor_space) @property def num_absolute_position_embeddings(self) -> int: diff --git a/fast_llm/layers/language_model/embedding.py b/fast_llm/layers/language_model/embedding.py index 1d9406ed..f51f40df 100644 --- a/fast_llm/layers/language_model/embedding.py +++ b/fast_llm/layers/language_model/embedding.py @@ -99,7 +99,10 @@ def _forward(self, input_: torch.Tensor, position_ids: torch.Tensor | None) -> t input_ = split(input_, group=group, dim=0) if self._use_absolute_position_embeddings: position_ids = split(position_ids, group=group, dim=0) - embeddings = torch.embedding(self.word_embeddings_weight, input_) + # mask padded tokens + input_mask = input_ >= 0 + masked_input = input_ * input_mask + embeddings = torch.embedding(self.word_embeddings_weight, masked_input) * input_mask.unsqueeze(2) if self._use_absolute_position_embeddings: embeddings = embeddings + torch.nn.functional.embedding(position_ids, self.position_embeddings_weight) with set_generator( diff --git a/fast_llm/layers/multi_modal/embedding.py b/fast_llm/layers/multi_modal/embedding.py new file mode 100644 index 00000000..d137de5e --- /dev/null +++ b/fast_llm/layers/multi_modal/embedding.py @@ -0,0 +1,198 @@ +import typing + +import torch + +from fast_llm.core.distributed import set_generator +from fast_llm.core.ops import gather, reduce_forward, split +from fast_llm.engine.config_utils.tensor_space import TensorSpace +from fast_llm.layers.audio_encoder.config import AudioEncoderKwargs +from fast_llm.layers.language_model.config import LanguageModelBaseConfig, LanguageModelKwargs +from fast_llm.layers.language_model.embedding import LanguageModelEmbedding +from fast_llm.layers.transformer.config import TransformerKwargs +from fast_llm.layers.vision_encoder.config import VisionEncoderKwargs +from fast_llm.layers.vision_encoder.preprocessing import get_num_patches +from fast_llm.tensor import TensorMeta +from fast_llm.utils import Assert, div + + +class MultiModalEmbedding(LanguageModelEmbedding): + """ + Multi-modal embedding layer to combine embeddings from text, image and more modalities. + """ + + def __init__( + self, + config: LanguageModelBaseConfig, + tensor_space: TensorSpace, + ): + super().__init__(config, tensor_space) + + # @torch.compile + def _forward( + self, + input_: torch.Tensor, + tokens: torch.Tensor, + position_ids: torch.Tensor | None, + image_positions: list[torch.Tensor] | None, + image_sizes: list[list[tuple[int, int]]] | None, + audio_positions: list[torch.Tensor] | None, + ) -> torch.Tensor: + """ + Forward pass for the multi-modal embedding layer. + Args: + input_: The input tensor (image embeddings). + tokens: The tokenized text input. + position_ids: The position ids for the text input. + image_positions: The positions of the image tokens in the input. + image_sizes: The sizes of the images in the input. + Returns: + The combined embeddings for text and images. 
+ """ + Assert.eq(position_ids is not None, self._use_absolute_position_embeddings) + group = self._tensor_space.distributed.tensor_group + if self._sequence_parallel: + micro_seqlen = input_.size(0) + patch_start_offset = self._distributed_config.tensor_rank * micro_seqlen + patch_end_offset = (self._distributed_config.tensor_rank + 1) * micro_seqlen + else: + patch_start_offset = 0 + patch_end_offset = input_.size(0) + if self._parallel_embeddings: + token_mask = (tokens >= self._vocab_start_index) * (tokens < self._vocab_end_index) + masked_tokens = (tokens - self._vocab_start_index) * token_mask + embeddings = torch.embedding(self.word_embeddings_weight, masked_tokens) * token_mask.unsqueeze(2) # noqa + embeddings = embeddings.clone() + input_ = gather(input_, group, dim=0) + # TODO: Toby implement audio + for sample_idx, (positions, sizes) in enumerate(zip(image_positions, image_sizes)): + image_embedding_offset = 0 + for position, size in zip(positions, sizes): + num_patches = get_num_patches(*size, self._config.vision_encoder.patch_size) + if image_embedding_offset + num_patches < patch_start_offset: + continue + if self._config.vision_encoder.image_break_token is not None: + patch_height = div(size[0], self._config.vision_encoder.patch_size) + patch_width = div(size[1], self._config.vision_encoder.patch_size) + for row in range(patch_height): + row_start_src = image_embedding_offset + row * patch_width + row_start_dst = position + row * (patch_width + 1) + if row_start_src > patch_end_offset: + break + if row_start_src + patch_width <= patch_start_offset: + continue + + input_start_index = max(row_start_src, patch_start_offset) - patch_start_offset + input_end_index = min(row_start_src + patch_width, patch_end_offset) - patch_start_offset + embeddings_start_index = row_start_dst - max(patch_start_offset - row_start_src, 0) + embeddings_end_index = ( + row_start_dst + patch_width - max(row_start_src + patch_width - patch_end_offset, 0) + ) + # row_end_src = min(row_start_src + patch_width, patch_end_offset) + if self._sequence_parallel: + # Copy with dimensions swapped for sequence parallel case + embeddings[embeddings_start_index:embeddings_end_index, sample_idx] = input_[ + input_start_index:input_end_index, sample_idx + ] + tokens[embeddings_start_index:embeddings_end_index, sample_idx] = 10 + else: + # Copy with normal dimension ordering + embeddings[sample_idx, embeddings_start_index:embeddings_end_index] = input_[ + sample_idx, input_start_index:input_end_index + ] + tokens[embeddings_start_index:embeddings_end_index, sample_idx] = 10 + else: + input_start_index = max(image_embedding_offset, patch_start_offset) - patch_start_offset + input_end_index = ( + min(image_embedding_offset + num_patches, patch_end_offset) - patch_start_offset + ) + embedding_start_index = position - max(patch_start_offset - image_embedding_offset, 0) + embedding_end_index = ( + position + num_patches - max(image_embedding_offset + num_patches - patch_end_offset, 0) + ) + embeddings[sample_idx, embedding_start_index:embedding_end_index] = input_[ + input_start_index:input_end_index, sample_idx + ] + # embeddings[sample_idx, position : position + num_patches] = input_[ + # sample_idx, image_embedding_offset : image_embedding_offset + num_patches + # ] + image_embedding_offset += num_patches + if image_embedding_offset > patch_end_offset: + break + embeddings = reduce_forward(embeddings, group) + if self._use_absolute_position_embeddings: + embeddings = embeddings + 
torch.nn.functional.embedding(position_ids, self.position_embeddings_weight) + if self._sequence_parallel: + embeddings = split(embeddings, group=group, dim=0) + else: + if self._sequence_parallel: + tokens = split(tokens, group=group, dim=0) + if self._use_absolute_position_embeddings: + position_ids = split(position_ids, group=group, dim=0) + # TODO Soham: get image positions for current split. Maybe in preprocessing? + # for positions in image_positions: + # if positions > self._distributed_config.tensor_rank + # mask padded tokens + token_mask = tokens >= 0 + masked_tokens = tokens * token_mask + embeddings = torch.embedding(self.word_embeddings_weight, masked_tokens) * token_mask.unsqueeze(2) + embeddings = embeddings.clone() + for sample_idx, (positions, sizes) in enumerate(zip(image_positions, image_sizes)): + image_embedding_offset = 0 + for position, size in zip(positions, sizes): + num_patches = get_num_patches(*size, self._config.vision_encoder.patch_size) + if self._config.vision_encoder.image_break_token is not None: + patch_height = div(size[0], self._config.vision_encoder.patch_size) + patch_width = div(size[1], self._config.vision_encoder.patch_size) + + for row in range(patch_height): + row_start_src = image_embedding_offset + row * patch_width + row_start_dst = position + row * (patch_width + 1) + + embeddings[sample_idx, row_start_dst : row_start_dst + patch_width] = input_[ + sample_idx, row_start_src : row_start_src + patch_width + ] + else: + embeddings[sample_idx, position : position + num_patches] = input_[ + sample_idx, image_embedding_offset : image_embedding_offset + num_patches + ] + # Move to the next image in the input tensor + image_embedding_offset += num_patches + + audio_position_idx = 0 + for sample_idx, positions in enumerate(audio_positions): + for position in positions: + num_audio_tokens = input_.shape[1] # TODO: Toby better way to get this? + embeddings[sample_idx, position : position + num_audio_tokens] = input_[audio_position_idx] + audio_position_idx += 1 + + if self._use_absolute_position_embeddings: + embeddings = embeddings + torch.nn.functional.embedding(position_ids, self.position_embeddings_weight) + with set_generator( + self._tensor_space.distributed.tp_generator + if self._sequence_parallel + else self._tensor_space.distributed.pp_generator + ): + embeddings = torch.dropout(embeddings, self._dropout_p, self.training) + return embeddings.to(dtype=self._residual_dtype) + + def forward( + self, + input_: torch.Tensor, + kwargs: dict[str, typing.Any], + losses: dict[str, typing.Any] | None = None, + metrics: dict | None = None, + ) -> torch.Tensor: + if isinstance(input_, TensorMeta): + return TensorMeta.from_dims( + kwargs[TransformerKwargs.hidden_dims], + tensor_name="Embedding output", + dtype=self._residual_dtype, + ) + # TODO: How do we support both Audio and Vision? 
+ position_ids = kwargs.get(LanguageModelKwargs.position_ids) + image_sizes = kwargs.get(VisionEncoderKwargs.image_sizes, []) + image_positions = kwargs.get(VisionEncoderKwargs.image_positions, []) + audio_positions = kwargs.get(AudioEncoderKwargs.audio_positions, []) + tokens = kwargs.get(LanguageModelKwargs.tokens) + + return self._forward(input_, tokens, position_ids, image_positions, image_sizes, audio_positions) diff --git a/fast_llm/layers/transformer/attention.py b/fast_llm/layers/transformer/attention.py index 0b442f66..e88f64a3 100644 --- a/fast_llm/layers/transformer/attention.py +++ b/fast_llm/layers/transformer/attention.py @@ -9,12 +9,7 @@ from fast_llm.functional.rotary import apply_rotary_embeddings from fast_llm.functional.triton.rotary import triton_rotary_autograd_ from fast_llm.layers.common.linear import InputParallelLinear, OutputParallelLinear -from fast_llm.layers.transformer.config import ( - TransformerConfig, - TransformerDimNames, - TransformerKwargs, - TransformerSubLayerName, -) +from fast_llm.layers.transformer.config import TransformerConfig, TransformerKwargs, TransformerSubLayerName from fast_llm.logging import log_distributed_grad, log_distributed_tensor from fast_llm.tensor import TensorMeta, init_normal_, init_zeros_ from fast_llm.utils import Assert @@ -57,24 +52,6 @@ class Attention(torch.nn.Module): A self-attention layer. """ - _QUERY_DIMS = ( - TransformerDimNames.batch, - TransformerDimNames.sequence_q, - TransformerDimNames.composite_heads, - TransformerDimNames.kv_channels, - ) - _KV_DIMS = ( - TransformerDimNames.batch, - TransformerDimNames.sequence_q, - TransformerDimNames.group_heads, - TransformerDimNames.kv_channels, - ) - _CONTEXT_DIMS = ( - TransformerDimNames.batch, - TransformerDimNames.sequence_q, - TransformerDimNames.composite_dense, - ) - def __init__( self, config: TransformerConfig, @@ -82,12 +59,16 @@ def __init__( layer_index, ): super().__init__() + self._transformer_dim_names = config._transformer_dim_names + self._transformer_kwargs = config._transformer_kwargs self._config = config self._tensor_space = tensor_space - Assert.in_range_incl(layer_index, 1, max(self._config.num_layers, 1)) + # TODO Soham: fix assert + # Assert.in_range_incl(layer_index, 1, max(self._config.num_layers, 1)) self._layer_index = layer_index self._sequence_parallel = self._tensor_space.distributed_config.sequence_tensor_parallel self._debug_transformer = self._config.debug_transformer + self._causal = self._config.causal self._use_flash_attention = self._config.do_use_flash_attention(self._tensor_space.distributed_config) init_method_qkv = init_normal_( @@ -101,19 +82,19 @@ def __init__( max_val=self._config.init_method_max_attn_proj, ) - self._kv_channels = self._tensor_space.get_tensor_dim(TransformerDimNames.kv_channels).size - self._head_groups = self._tensor_space.get_tensor_dim(TransformerDimNames.head_groups).global_size - self._local_head_groups = self._tensor_space.get_tensor_dim(TransformerDimNames.head_groups).size - self._local_heads_per_group = self._tensor_space.get_tensor_dim(TransformerDimNames.group_heads).size + self._kv_channels = self._tensor_space.get_tensor_dim(self._transformer_dim_names.kv_channels).size + self._head_groups = self._tensor_space.get_tensor_dim(self._transformer_dim_names.head_groups).global_size + self._local_head_groups = self._tensor_space.get_tensor_dim(self._transformer_dim_names.head_groups).size + self._local_heads_per_group = self._tensor_space.get_tensor_dim(self._transformer_dim_names.group_heads).size 
self._local_heads = self._local_head_groups * self._local_heads_per_group self._softmax_scale = self._kv_channels ** (-self._config.attention_softmax_scale_power) - hidden_dim = self._tensor_space.get_tensor_dim(TransformerDimNames.hidden) + hidden_dim = self._tensor_space.get_tensor_dim(self._transformer_dim_names.hidden) # TODO: Merge the query and key-value computations? (harder with sequence parallel.) self.query = OutputParallelLinear( hidden_dim, - self._tensor_space.get_tensor_dim(TransformerDimNames.composite_query), + self._tensor_space.get_tensor_dim(self._transformer_dim_names.composite_query), bias=self._config.add_attn_qkv_bias, weight_init_method=init_method_qkv, bias_init_method=init_method_qkv if self._config.random_bias_init else init_zeros_, @@ -122,7 +103,7 @@ def __init__( ) self.key_value = OutputParallelLinear( hidden_dim, - self._tensor_space.get_tensor_dim(TransformerDimNames.composite_key_value), + self._tensor_space.get_tensor_dim(self._transformer_dim_names.composite_key_value), bias=self._config.add_attn_qkv_bias, weight_init_method=init_method_qkv, bias_init_method=init_method_qkv if self._config.random_bias_init else init_zeros_, @@ -133,7 +114,7 @@ def __init__( # Output. self.dense = InputParallelLinear( - self._tensor_space.get_tensor_dim(TransformerDimNames.composite_dense), + self._tensor_space.get_tensor_dim(self._transformer_dim_names.composite_dense), hidden_dim, bias=self._config.add_attn_dense_bias, weight_init_method=init_method_std_attn_proj, @@ -199,7 +180,7 @@ def _attn_fused( def _get_meta( self, input_: torch.Tensor, name: str, dim_names: tuple[str, ...], kwargs: dict[str, typing.Any] ) -> TensorMeta: - hidden_dims = {dim.name: dim for dim in kwargs[TransformerKwargs.hidden_dims]} + hidden_dims = {dim.name: dim for dim in kwargs[self._transformer_kwargs.hidden_dims]} return TensorMeta.from_dims( tuple( hidden_dims[dim_name] if dim_name in hidden_dims else self._tensor_space.get_tensor_dim(dim_name) @@ -209,6 +190,32 @@ def _get_meta( dtype=input_.dtype, ) + @property + def _query_dims(self): + return ( + self._transformer_dim_names.batch, + self._transformer_dim_names.sequence_q, + self._transformer_dim_names.composite_heads, + self._transformer_dim_names.kv_channels, + ) + + @property + def _kv_dims(self): + return ( + self._transformer_dim_names.batch, + self._transformer_dim_names.sequence_q, + self._transformer_dim_names.group_heads, + self._transformer_dim_names.kv_channels, + ) + + @property + def _context_dims(self): + return ( + self._transformer_dim_names.batch, + self._transformer_dim_names.sequence_q, + self._transformer_dim_names.composite_dense, + ) + def _debug_log( self, tensor: torch.Tensor, name: str, dim_names: tuple[str, ...], kwargs: dict[str, typing.Any] ) -> None: @@ -307,12 +314,12 @@ def forward(self, input_: torch.Tensor, kwargs: dict[str, typing.Any]) -> tuple[ # TODO: Move the rest to function. - if (past_key_values := kwargs.get(TransformerKwargs.past_key_values)) is not None: + if (past_key_values := kwargs.get(self._transformer_kwargs.past_key_values)) is not None: assert sequence_first # Clear the lists so tensors can be de-allocated key_value = torch.cat((past_key_values.pop(0), key_value), dim=0) - if (presents := kwargs.get(TransformerKwargs.presents)) is not None: + if (presents := kwargs.get(self._transformer_kwargs.presents)) is not None: # Return the presents as a leaf tensors so the gradients from later micro-sequences # don't propagate to this one. 
presents.append(present := key_value.detach().requires_grad_()) @@ -339,23 +346,23 @@ def forward(self, input_: torch.Tensor, kwargs: dict[str, typing.Any]) -> tuple[ if self._config.rotary.enabled: if self._debug_transformer: - self._debug_log(query, "query_rotary_input", self._QUERY_DIMS, kwargs) + self._debug_log(query, "query_rotary_input", self._query_dims, kwargs) self._debug_log( key, "key_rotary_input", - self._KV_DIMS, + self._kv_dims, kwargs, ) rotary_fn = triton_rotary_autograd_ if self._config.rotary.triton else apply_rotary_embeddings - query = rotary_fn(query, kwargs[TransformerKwargs.rotary_freq_q]) - key = rotary_fn(key, kwargs[TransformerKwargs.rotary_freq_k]) + query = rotary_fn(query, kwargs[self._transformer_kwargs.rotary_freq_q]) + key = rotary_fn(key, kwargs[self._transformer_kwargs.rotary_freq_k]) window_size = self._decide_window_size() if self._use_flash_attention: assert _flash_available with set_generator(self._tensor_space.distributed.tp_generator): - if (cu_seqlens_q := kwargs.get(TransformerKwargs.cu_seqlens_q, None)) is not None: + if (cu_seqlens_q := kwargs.get(self._transformer_kwargs.cu_seqlens_q, None)) is not None: out_dims = query.size() query = query.view(-1, query.size(-2), query.size(-1)) key = key.view(-1, key.size(-2), key.size(-1)) @@ -365,12 +372,12 @@ def forward(self, input_: torch.Tensor, kwargs: dict[str, typing.Any]) -> tuple[ key, value, cu_seqlens_q=cu_seqlens_q, - cu_seqlens_k=kwargs.get(TransformerKwargs.cu_seqlens_k), - max_seqlen_q=kwargs.get(TransformerKwargs.max_seqlen_q), - max_seqlen_k=kwargs.get(TransformerKwargs.max_seqlen_k), + cu_seqlens_k=kwargs.get(self._transformer_kwargs.cu_seqlens_k), + max_seqlen_q=kwargs.get(self._transformer_kwargs.max_seqlen_q), + max_seqlen_k=kwargs.get(self._transformer_kwargs.max_seqlen_k), dropout_p=self._config.attention_dropout if self.training else 0.0, window_size=(-1, -1) if window_size is None else (window_size - 1, 0), - causal=True, + causal=self._causal, softmax_scale=self._softmax_scale, ).view(*out_dims) else: @@ -380,7 +387,7 @@ def forward(self, input_: torch.Tensor, kwargs: dict[str, typing.Any]) -> tuple[ value, window_size=(-1, -1) if window_size is None else (window_size - 1, 0), dropout_p=self._config.attention_dropout if self.training else 0.0, - causal=True, + causal=self._causal, softmax_scale=self._softmax_scale, ) input_ = input_.flatten(-2) @@ -390,25 +397,25 @@ def forward(self, input_: torch.Tensor, kwargs: dict[str, typing.Any]) -> tuple[ query.flatten(-2), key.flatten(-2), value.flatten(-2), - kwargs[TransformerKwargs.attention_mask], - kwargs[TransformerKwargs.attention_mask_value], + kwargs[self._transformer_kwargs.attention_mask], + kwargs[self._transformer_kwargs.attention_mask_value], ) if self._debug_transformer: - self._debug_log(query, "query", self._QUERY_DIMS, kwargs) + self._debug_log(query, "query", self._query_dims, kwargs) self._debug_log( key, "key", - self._KV_DIMS, + self._kv_dims, kwargs, ) self._debug_log( value, "value", - self._KV_DIMS, + self._kv_dims, kwargs, ) - self._debug_log(input_, "context", self._CONTEXT_DIMS, kwargs) + self._debug_log(input_, "context", self._context_dims, kwargs) if sequence_first: # TODO: Optimize (is contiguous avoidable? Transpose dense output?) 
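# The attention changes above stop hard-coding `TransformerDimNames` / `TransformerKwargs`
# and instead resolve names through the config's `_transformer_dim_names` /
# `_transformer_kwargs` properties, so the same `Attention` code can serve the decoder,
# vision and audio transformers. The config changes below implement this with prefixed
# name namespaces built in `__init_subclass__`. A minimal, self-contained sketch of that
# pattern follows; the class names here are hypothetical, and it assumes an empty prefix
# keeps the original (decoder) names.

class _BaseDimNames:
    _attributes = {"hidden": "hidden", "kv_channels": "kv_channels"}

    def __init_subclass__(cls, prefix: str = "", **kwargs):
        super().__init_subclass__(**kwargs)
        cls._prefix = prefix
        for attr, value in _BaseDimNames._attributes.items():
            # Prefixed name, e.g. "audio_encoder_hidden"; an empty prefix keeps "hidden".
            setattr(cls, attr, f"{prefix}_{value}" if prefix else value)


class _DecoderDimNames(_BaseDimNames, prefix=""):
    pass


class _AudioDimNames(_BaseDimNames, prefix="audio_encoder"):
    pass


# Each modality registers and looks up tensor dims under its own namespace,
# so the shared attention/MLP code never collides across encoders.
assert _DecoderDimNames.hidden == "hidden"
assert _AudioDimNames.hidden == "audio_encoder_hidden"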
diff --git a/fast_llm/layers/transformer/audio_transformer.py b/fast_llm/layers/transformer/audio_transformer.py
new file mode 100644
index 00000000..f0fb6d17
--- /dev/null
+++ b/fast_llm/layers/transformer/audio_transformer.py
@@ -0,0 +1,40 @@
+import torch
+
+from fast_llm.engine.config_utils.tensor_space import TensorDim, TensorSpace
+from fast_llm.layers.transformer.config import AudioTransformerDimNames, AudioTransformerKwargs, TransformerConfig
+from fast_llm.layers.transformer.transformer import TransformerLayer
+from fast_llm.tensor import TensorMeta
+
+
+class AudioTransformerLayer(TransformerLayer):
+    """
+    An audio transformer layer to encode audio features.
+    """
+
+    def __init__(
+        self,
+        config: TransformerConfig,
+        tensor_space: TensorSpace,
+        layer_index: int,
+        return_input: bool = False,
+    ):
+        super().__init__(config, tensor_space, layer_index, return_input)
+
+        hidden_dim = self._tensor_space.get_tensor_dim(AudioTransformerDimNames.hidden)
+
+        # use regular layernorm (not rms norm)
+        self.norm_1 = self._config.normalization.get_layer(hidden_dim)
+        self.norm_2 = self._config.normalization.get_layer(hidden_dim)
+
+        self.norm_1 = self._config.peft.apply_other(self.norm_1)
+        self.norm_2 = self._config.peft.apply_other(self.norm_2)
+
+    @property
+    def name(self) -> str:
+        return f"Audio transformer layer {self._layer_index}"
+
+    def _get_meta(self, tensor: torch.Tensor, name: str, kwargs: dict):
+        dims = kwargs[AudioTransformerKwargs.hidden_dims]
+        if self._return_input:
+            dims = (TensorDim("stacked_input_output", 2),) + dims
+        return TensorMeta.from_dims(dims, tensor_name=f"{self.name} {name}", dtype=tensor.dtype)
diff --git a/fast_llm/layers/transformer/config.py b/fast_llm/layers/transformer/config.py
index e69b1841..45d911a6 100644
--- a/fast_llm/layers/transformer/config.py
+++ b/fast_llm/layers/transformer/config.py
@@ -5,7 +5,7 @@ import typing
 import warnings

-from fast_llm.config import Field, FieldHint, check_field, config_class, skip_valid_if_none
+from fast_llm.config import Field, FieldHint, FieldUpdate, check_field, config_class, skip_valid_if_none
 from fast_llm.engine.base_model.config import BaseModelConfig
 from fast_llm.engine.config_utils.data_type import DataType
 from fast_llm.engine.config_utils.tensor_space import CompositeTensorDim, TensorDim, TensorSpace
@@ -28,59 +28,94 @@ class RoutingType(str, enum.Enum):
     sinkhorn = "sinkhorn"

-class TransformerDimNames:
-    # A set of common tensor dim names packed into a namespace.
-    # Input dimensions (variable)
-    # TODO: Does batch belong here?
-    # TODO: Distinguish micro-sequence?
- sequence_q = "sequence_q" - sequence_q_tp = "sequence_q_tp" - sequence_k = "sequence_k" - hidden = "hidden" - # Self-attention dimensions - head_groups = "head_groups" - group_heads = "group_heads" - key_and_value = "key_value" - kv_channels = "kv_channels" - composite_heads = "composite_heads" - composite_query = "composite_query" - composite_key_value = "composite_key_value" - composite_dense = "composite_dense" - # MLP dimensions - mlp = "mlp" - gate_and_up = "gate_and_up" - composite_gated_mlp = "composite_gated_mlp" - experts = "experts" - top_experts = "top_experts" - shared_experts = "shared_experts" - unshared_experts = "unshared_experts" - composite_expert_mlp = "composite_expert_mlp" - composite_gated_expert_mlp = "composite_gated_expert_mlp" - composite_shared_expert_mlp = "composite_shared_expert_mlp" - composite_gated_shared_expert_mlp = "composite_gated_shared_expert_mlp" - - -class TransformerKwargs: - rotary_freq_q = "rotary_freq_q" - rotary_freq_k = "rotary_freq_k" - attention_mask = "attention_mask" - attention_mask_value = "attention_mask_value" - sequence_lengths = "sequence_lengths" - cu_seqlens_q = "cu_seqlens_q" - cu_seqlens_k = "cu_seqlens_k" - max_seqlen_q = "max_seqlen_q" - max_seqlen_k = "max_seqlen_k" - # TODO: Review these - presents = "presents" - past_key_values = "past_key_values" - sequence_first = "sequence_first" - hidden_dims = "hidden_dims" - sequence_q_dim = "sequence_q_dim" - sequence_k_dim = "sequence_k_dim" - sequence_length = "sequence_length" - # TODO: Move - grad_output = "grad_output" +class BaseTransformerDimNames: + _kwargs_attributes = { + "batch": "batch", + "sequence_q": "sequence_q", + "sequence_q_tp": "sequence_q_tp", + "sequence_k": "sequence_k", + "hidden": "hidden", + "head_groups": "head_groups", + "group_heads": "group_heads", + "key_and_value": "key_value", + "kv_channels": "kv_channels", + "composite_heads": "composite_heads", + "composite_query": "composite_query", + "composite_key_value": "composite_key_value", + "composite_dense": "composite_dense", + "mlp": "mlp", + "gate_and_up": "gate_and_up", + "composite_gated_mlp": "composite_gated_mlp", + "experts": "experts", + "top_experts": "top_experts", + "shared_experts": "shared_experts", + "unshared_experts": "unshared_experts", + "composite_expert_mlp": "composite_expert_mlp", + "composite_gated_expert_mlp": "composite_gated_expert_mlp", + "composite_shared_expert_mlp": "composite_shared_expert_mlp", + "composite_gated_shared_expert_mlp": "composite_gated_shared_expert_mlp", + } + + def __init_subclass__(cls, prefix="", **kwargs): + super().__init_subclass__(**kwargs) + cls._prefix = prefix + for attr, value in BaseTransformerDimNames._kwargs_attributes.items(): + setattr(cls, attr, f"{cls._prefix}_{value}") + + +class TransformerDimNames(BaseTransformerDimNames, prefix=""): + pass + + +class VisionTransformerDimNames(BaseTransformerDimNames, prefix="image_encoder"): + pass + + +class AudioTransformerDimNames(BaseTransformerDimNames, prefix="audio_encoder"): + pass + + +class BaseTransformerKwargs: + _kwargs_attributes = { + "rotary_freq_q": "rotary_freq_q", + "rotary_freq_k": "rotary_freq_k", + "attention_mask": "attention_mask", + "attention_mask_value": "attention_mask_value", + "sequence_lengths": "sequence_lengths", + "cu_seqlens_q": "cu_seqlens_q", + "cu_seqlens_k": "cu_seqlens_k", + "max_seqlen_q": "max_seqlen_q", + "max_seqlen_k": "max_seqlen_k", + "presents": "presents", + "past_key_values": "past_key_values", + "sequence_first": "sequence_first", + "hidden_dims": 
"hidden_dims", + "sequence_q_dim": "sequence_q_dim", + "sequence_k_dim": "sequence_k_dim", + "sequence_length": "sequence_length", + "micro_batch_size": "micro_batch_size", + "grad_output": "grad_output", + } + + _prefix = "" + + def __init_subclass__(cls, prefix="", **kwargs): + super().__init_subclass__(**kwargs) + cls._prefix = prefix + for attr, value in BaseTransformerKwargs._kwargs_attributes.items(): + setattr(cls, value, f"{cls._prefix}_{value}" if cls._prefix else value) + + +class TransformerKwargs(BaseTransformerKwargs, prefix=""): + pass + + +class VisionTransformerKwargs(BaseTransformerKwargs, prefix="image_encoder"): + patch_position_ids = "patch_position_ids" + + +class AudioTransformerKwargs(BaseTransformerKwargs, prefix="audio_encoder"): + pass class TransformerLossNames: @@ -93,6 +128,14 @@ class RotaryEmbeddingType(str, enum.Enum): default = "default" llama3 = "llama3" yarn = "yarn" + # TODO Soham: generic name? + pixtral = "pixtral" + + +class TransformerType(str, enum.Enum): + lm_decoder = "lm_decoder" + image_encoder = "image_encoder" + audio_encoder = "audio_encoder" @config_class() @@ -157,6 +200,40 @@ def _validate(self) -> None: if self.triton and not TritonConfig.TRITON_ENABLED: warnings.warn("Triton is disabled, but the triton rotary kernel will be used anyway.") + @property + def _transformer_dim_names(self) -> TransformerDimNames: + return TransformerDimNames + + @property + def _transformer_kwargs(self) -> TransformerKwargs: + return TransformerKwargs + + +@config_class() +class VisionRotaryConfig(RotaryConfig): + type: RotaryEmbeddingType = Field( + default=RotaryEmbeddingType.pixtral, + desc="The type of rotary embedding to use. Choices: none, default, llama3, yarn, pixtral.", + hint=FieldHint.feature, + ) + + @property + def _transformer_dim_names(self) -> VisionTransformerDimNames: + return VisionTransformerDimNames + + @property + def _transformer_kwargs(self) -> VisionTransformerKwargs: + return VisionTransformerKwargs + + +# @config_class() +# class AudioRotaryConfig(RotaryConfig): +# type: RotaryEmbeddingType = Field( +# default=RotaryEmbeddingType.none, +# desc="The type of rotary embedding to use. Choices: none, default, llama3, yarn, pixtral.", +# hint=FieldHint.feature, +# ) + class AddLinearBiasChoices(str, enum.Enum): nowhere = "nowhere" @@ -247,6 +324,11 @@ def _validate(self) -> None: @config_class() class TransformerConfig(BaseModelConfig): _abstract = False + transformer_type: TransformerType = Field( + default=TransformerType.lm_decoder, + desc="Type of the transformer. Choices: lm_decoder, image_encoder, audio_encoder.", + hint=FieldHint.architecture, + ) normalization: NormalizationConfig = Field( default_factory=NormalizationConfig, desc="Configuration for the normalization layers architecture.", @@ -526,6 +608,11 @@ class TransformerConfig(BaseModelConfig): " Reduces memory usage, but increases fragmentation and requires CPU synchronisation. Not recommended.", hint=FieldHint.expert, ) + causal: bool = Field( + default=True, + desc="Use causal attention. 
Turn this off only for bidirectional attention e.g., in Vision Transformer.", + hint=FieldHint.feature, + ) def _validate(self) -> None: with self._set_implicit_default(): @@ -641,63 +728,83 @@ def _from_dict( cls._handle_renamed_field(default, "triton_rotary", ("rotary", "triton")) return super()._from_dict(default, strict, flat) - def setup_tensor_space(self, tensor_space: TensorSpace) -> None: + def setup_tensor_space(self, tensor_space: TensorSpace, type: str | None = None) -> None: + if type == "vision": + # TODO Soham: better way to get around circular imports? Maybe add a type class variable to TransformerConfig? + pass + + elif type == "audio": + pass + + else: + pass + tensor = tensor_space.distributed_config.get_distributed_dim(DistributedDimNames.tensor) # Hidden dimension - tensor_space.add_tensor_dim(TensorDim(TransformerDimNames.hidden, self.hidden_size)) + tensor_space.add_tensor_dim(TensorDim(self._transformer_dim_names.hidden, self.hidden_size)) # Self-attention dimensions tensor_space.add_tensor_dim( head_groups := TensorDim( - TransformerDimNames.head_groups, self.head_groups, tensor if self.head_groups > 1 else None + self._transformer_dim_names.head_groups, self.head_groups, tensor if self.head_groups > 1 else None ) ) tensor_space.add_tensor_dim( group_heads := TensorDim( - TransformerDimNames.group_heads, + self._transformer_dim_names.group_heads, div(self.num_attention_heads, self.head_groups), None if self.head_groups > 1 else tensor, ) ) - tensor_space.add_tensor_dim(key_and_value := TensorDim(TransformerDimNames.key_and_value, 2)) - tensor_space.add_tensor_dim(kv_channels := TensorDim(TransformerDimNames.kv_channels, self.kv_channels)) + tensor_space.add_tensor_dim(key_and_value := TensorDim(self._transformer_dim_names.key_and_value, 2)) tensor_space.add_tensor_dim( - CompositeTensorDim(TransformerDimNames.composite_heads, (head_groups, group_heads)) + kv_channels := TensorDim(self._transformer_dim_names.kv_channels, self.kv_channels) ) tensor_space.add_tensor_dim( - CompositeTensorDim(TransformerDimNames.composite_query, (head_groups, group_heads, kv_channels)) + CompositeTensorDim(self._transformer_dim_names.composite_heads, (head_groups, group_heads)) ) tensor_space.add_tensor_dim( - CompositeTensorDim(TransformerDimNames.composite_key_value, (key_and_value, head_groups, kv_channels)) + CompositeTensorDim(self._transformer_dim_names.composite_query, (head_groups, group_heads, kv_channels)) ) tensor_space.add_tensor_dim( - CompositeTensorDim(TransformerDimNames.composite_dense, (head_groups, group_heads, kv_channels)) + CompositeTensorDim( + self._transformer_dim_names.composite_key_value, (key_and_value, head_groups, kv_channels) + ) + ) + tensor_space.add_tensor_dim( + CompositeTensorDim(self._transformer_dim_names.composite_dense, (head_groups, group_heads, kv_channels)) ) # MLP dimensions - tensor_space.add_tensor_dim(mlp := TensorDim(TransformerDimNames.mlp, self.ffn_hidden_size, tensor)) - tensor_space.add_tensor_dim(gate_and_up := TensorDim(TransformerDimNames.gate_and_up, 2 if self.gated else 1)) - tensor_space.add_tensor_dim(CompositeTensorDim(TransformerDimNames.composite_gated_mlp, (gate_and_up, mlp))) - tensor_space.add_tensor_dim(experts := TensorDim(TransformerDimNames.experts, self.num_experts)) - tensor_space.add_tensor_dim(CompositeTensorDim(TransformerDimNames.composite_expert_mlp, (experts, mlp))) + tensor_space.add_tensor_dim(mlp := TensorDim(self._transformer_dim_names.mlp, self.ffn_hidden_size, tensor)) + tensor_space.add_tensor_dim( + 
gate_and_up := TensorDim(self._transformer_dim_names.gate_and_up, 2 if self.gated else 1) + ) + tensor_space.add_tensor_dim( + CompositeTensorDim(self._transformer_dim_names.composite_gated_mlp, (gate_and_up, mlp)) + ) + tensor_space.add_tensor_dim(experts := TensorDim(self._transformer_dim_names.experts, self.num_experts)) + tensor_space.add_tensor_dim( + CompositeTensorDim(self._transformer_dim_names.composite_expert_mlp, (experts, mlp)) + ) tensor_space.add_tensor_dim( - CompositeTensorDim(TransformerDimNames.composite_gated_expert_mlp, (experts, gate_and_up, mlp)) + CompositeTensorDim(self._transformer_dim_names.composite_gated_expert_mlp, (experts, gate_and_up, mlp)) ) - tensor_space.add_tensor_dim(TensorDim(TransformerDimNames.top_experts, self.num_experts_per_token)) - tensor_space.add_tensor_dim(TensorDim(TransformerDimNames.unshared_experts, self.num_unshared_experts)) + tensor_space.add_tensor_dim(TensorDim(self._transformer_dim_names.top_experts, self.num_experts_per_token)) + tensor_space.add_tensor_dim(TensorDim(self._transformer_dim_names.unshared_experts, self.num_unshared_experts)) # shared_experts if self.num_shared_experts: tensor_space.add_tensor_dim( - shared_experts := TensorDim(TransformerDimNames.shared_experts, self.num_shared_experts) + shared_experts := TensorDim(self._transformer_dim_names.shared_experts, self.num_shared_experts) ) tensor_space.add_tensor_dim( - CompositeTensorDim(TransformerDimNames.composite_shared_expert_mlp, (shared_experts, mlp)) + CompositeTensorDim(self._transformer_dim_names.composite_shared_expert_mlp, (shared_experts, mlp)) ) tensor_space.add_tensor_dim( CompositeTensorDim( - TransformerDimNames.composite_gated_shared_expert_mlp, (shared_experts, gate_and_up, mlp) + self._transformer_dim_names.composite_gated_shared_expert_mlp, (shared_experts, gate_and_up, mlp) ) ) @@ -712,3 +819,78 @@ def do_use_flash_attention(self, distributed_config: DistributedConfig) -> bool: Assert.is_(self.window_size, None) return use_flash_attention + + @property + def _transformer_kwargs(self) -> TransformerKwargs: + return TransformerKwargs + + @property + def _transformer_dim_names(self) -> TransformerDimNames: + return TransformerDimNames + + +@config_class() +class VisionTransformerConfig(TransformerConfig): + """ + Configuration for the Vision Transformer (ViT) model. + """ + + transformer_type: TransformerType = FieldUpdate( + default=TransformerType.image_encoder, + desc="Type of the transformer. Choices: lm_decoder, image_encoder, audio_encoder.", + hint=FieldHint.architecture, + ) + causal: bool = FieldUpdate( + default=False, + desc="Use causal attention. Turn this off only for bidirectional attention e.g., in Vision Transformer.", + hint=FieldHint.feature, + ) + rotary: VisionRotaryConfig = FieldUpdate( + default_factory=VisionRotaryConfig, + desc="Configuration for the rotary positional embeddings.", + hint=FieldHint.feature, + ) + + @property + def _transformer_kwargs(self) -> VisionTransformerKwargs: + return VisionTransformerKwargs + + @property + def _transformer_dim_names(self) -> VisionTransformerDimNames: + return VisionTransformerDimNames + + +@config_class() +class AudioTransformerConfig(TransformerConfig): + """ + Configuration for the Audio Transformer model. + """ + + transformer_type: TransformerType = FieldUpdate( + default=TransformerType.audio_encoder, + desc="Type of the transformer. 
Choices: lm_decoder, image_encoder, audio_encoder.", + hint=FieldHint.architecture, + ) + causal: bool = FieldUpdate( + default=False, + desc="Use causal attention. Turn this off only for bidirectional attention e.g., in Audio Transformer.", + hint=FieldHint.feature, + ) + gated: bool = FieldUpdate( + default=False, + desc="MLP gating.", + hint=FieldHint.feature, + ) + # rotary: AudioRotaryConfig = FieldUpdate( + # default_factory=AudioRotaryConfig, + # desc="Configuration for the rotary positional embeddings.", + # hint=FieldHint.feature, + # ) + + @property + def _transformer_kwargs(self) -> AudioTransformerKwargs: + return AudioTransformerKwargs + + @property + def _transformer_dim_names(self) -> AudioTransformerDimNames: + return AudioTransformerDimNames diff --git a/fast_llm/layers/transformer/mlp.py b/fast_llm/layers/transformer/mlp.py index 1c38705f..42393a41 100644 --- a/fast_llm/layers/transformer/mlp.py +++ b/fast_llm/layers/transformer/mlp.py @@ -8,7 +8,7 @@ from fast_llm.functional.config import TritonConfig from fast_llm.functional.triton.mlp import mlp_autograd, torch_mlp_activation, triton_mlp_activation_autograd from fast_llm.layers.common.linear import LinearBase -from fast_llm.layers.transformer.config import TransformerConfig, TransformerDimNames, TransformerSubLayerName +from fast_llm.layers.transformer.config import TransformerConfig, TransformerSubLayerName from fast_llm.tensor import init_normal_, init_zeros_ from fast_llm.utils import Assert @@ -18,6 +18,9 @@ def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, name: s super().__init__() self._name = name + self._transformer_dim_names = config._transformer_dim_names + self._transformer_kwargs = config._transformer_kwargs + init_method_1 = init_normal_( std=config.init_method_std_mlp_1, min_val=config.init_method_min_mlp_1, @@ -29,8 +32,8 @@ def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, name: s max_val=config.init_method_max_mlp_2, ) - hidden_dim = tensor_space.get_tensor_dim(TransformerDimNames.hidden) - self._intermediate_dim = tensor_space.get_tensor_dim(TransformerDimNames.composite_expert_mlp) + hidden_dim = tensor_space.get_tensor_dim(self._transformer_dim_names.hidden) + self._intermediate_dim = tensor_space.get_tensor_dim(self._transformer_dim_names.composite_expert_mlp) self._sequence_parallel = tensor_space.distributed_config.sequence_tensor_parallel self._recompute_level = config.mlp_recompute_level @@ -41,7 +44,7 @@ def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, name: s # So both layers' weights have shape (num_experts [* gate_up] * ffn, hidden_size) self.layer_1 = LinearBase( hidden_dim, - tensor_space.get_tensor_dim(TransformerDimNames.composite_gated_expert_mlp), + tensor_space.get_tensor_dim(self._transformer_dim_names.composite_gated_expert_mlp), bias=config.add_mlp_bias, weight_init_method=init_method_1, bias_init_method=init_method_1 if config.random_bias_init else init_zeros_, diff --git a/fast_llm/layers/transformer/preprocessing.py b/fast_llm/layers/transformer/preprocessing.py index 2415a2f9..1b436eba 100644 --- a/fast_llm/layers/transformer/preprocessing.py +++ b/fast_llm/layers/transformer/preprocessing.py @@ -7,13 +7,8 @@ from fast_llm.engine.base_model.config import Preprocessor from fast_llm.engine.config_utils.tensor_space import DefaultDimNames, TensorDim, TensorSpace from fast_llm.functional.rotary import convert_rotary_complex_to_real -from fast_llm.layers.transformer.config import ( - RotaryConfig, - 
RotaryEmbeddingType, - TransformerConfig, - TransformerDimNames, - TransformerKwargs, -) +from fast_llm.layers.transformer.config import RotaryConfig, RotaryEmbeddingType, TransformerConfig, TransformerKwargs +from fast_llm.layers.vision_encoder.config import VisionEncoderKwargs from fast_llm.tensor import TensorMeta logger = logging.getLogger(__name__) @@ -130,64 +125,116 @@ def get_rotary_frequencies( return frequencies +def get_2d_rotary_frequencies( + config: RotaryConfig, + height, + width, + kv_channels, + *, + device="cuda", +) -> torch.Tensor: + # Calculate complex frequencies by using alternating channels for width and height + height_positions = torch.arange(height, device=device, dtype=torch.float64) + width_positions = torch.arange(width, device=device, dtype=torch.float64) + frequencies = config.theta ** -torch.arange(0, 1, 2 / kv_channels, device=device, dtype=torch.float64) + # TODO Soham: apply scaling + angles_h = torch.outer(height_positions, frequencies[::2]) + angles_w = torch.outer(width_positions, frequencies[1::2]) + angles = torch.cat( + [ + angles_h[:, None, :].repeat(1, width, 1), + angles_w[None, :, :].repeat(height, 1, 1), + ], + dim=-1, + ).reshape(-1, kv_channels // 2) + + frequencies = torch.polar(torch.ones_like(angles), angles)[None, :, None, :].to(torch.complex64) + if not config.complex_format: + frequencies = convert_rotary_complex_to_real( + torch.view_as_real(frequencies).flatten(-2), kv_channels, 3 + ).contiguous() + + return frequencies + + class RotaryEmbeddingPreprocessor(Preprocessor): _scalar_dim: TensorDim - _kv_channels_dim: TensorDim - _rotary_embedding_frequencies: torch.Tensor _mask: torch.Tensor _mask_value: torch.Tensor - _tensor_cache_max_sequence_length: int = -1 def __init__( self, config: RotaryConfig, tensor_space: TensorSpace, ): + self._transformer_dim_names = config._transformer_dim_names + self._transformer_kwargs = config._transformer_kwargs self._config = config assert self._config.enabled self._tensor_space = tensor_space self._distributed_config = self._tensor_space.distributed_config self._scalar_dim = self._tensor_space.get_tensor_dim(DefaultDimNames.scalar) - self._kv_channels_dim = self._tensor_space.get_tensor_dim(TransformerDimNames.kv_channels) + self._kv_channels_dim = self._tensor_space.get_tensor_dim(self._transformer_dim_names.kv_channels) + self._tensor_cache_max_sequence_length: int = -1 - def _create_tensors(self, sequence_length: int) -> None: + def _create_tensors(self, sequence_length: int, num_patches: None | int = None) -> None: if sequence_length <= self._tensor_cache_max_sequence_length: return self._tensor_cache_max_sequence_length = sequence_length - self._rotary_embedding_frequencies = get_rotary_frequencies( - self._config, - sequence_length, - self._kv_channels_dim.global_size, - device=self._tensor_space.distributed.device, - ) + if self._config.type == RotaryEmbeddingType.pixtral: + self._rotary_embedding_frequencies = get_2d_rotary_frequencies( + self._config, + num_patches, + num_patches, + self._kv_channels_dim.global_size, + device=self._tensor_space.distributed.device, + ) + else: + self._rotary_embedding_frequencies = get_rotary_frequencies( + self._config, + sequence_length, + self._kv_channels_dim.global_size, + device=self._tensor_space.distributed.device, + ) def preprocess(self, batch, kwargs: dict[str, typing.Any]) -> None: - self._create_tensors(kwargs[TransformerKwargs.sequence_length]) + if self._config.type == RotaryEmbeddingType.pixtral: + max_num_patches = 
kwargs[VisionEncoderKwargs.image_size] // kwargs[VisionEncoderKwargs.patch_size] + self._create_tensors(kwargs[TransformerKwargs.sequence_length], max_num_patches) + else: + self._create_tensors(kwargs[TransformerKwargs.sequence_length]) sequence_k = kwargs[TransformerKwargs.sequence_k_dim].size - kwargs[TransformerKwargs.rotary_freq_q] = self._rotary_embedding_frequencies[ - :, sequence_k - kwargs[TransformerKwargs.sequence_q_dim].size : sequence_k - ] - kwargs[TransformerKwargs.rotary_freq_k] = self._rotary_embedding_frequencies[:, :sequence_k] + sequence_q = kwargs[TransformerKwargs.sequence_q_dim].size + if self._config.type == RotaryEmbeddingType.pixtral: + position_ids = kwargs[self._transformer_kwargs.patch_position_ids] + # TODO Soham: use position_ids_q and position_ids_k for sequence_data_parallelism + kwargs[self._transformer_kwargs.rotary_freq_q] = self._rotary_embedding_frequencies[:, position_ids] + kwargs[self._transformer_kwargs.rotary_freq_k] = self._rotary_embedding_frequencies[:, position_ids] + else: + kwargs[self._transformer_kwargs.rotary_freq_q] = self._rotary_embedding_frequencies[ + :, sequence_k - sequence_q : sequence_k + ] + kwargs[self._transformer_kwargs.rotary_freq_k] = self._rotary_embedding_frequencies[:, :sequence_k] def preprocess_meta(self, kwargs: dict[str, typing.Any]) -> None: - kwargs[TransformerKwargs.rotary_freq_q] = TensorMeta.from_dims( + kwargs[self._transformer_kwargs.rotary_freq_q] = TensorMeta.from_dims( ( self._scalar_dim, kwargs[TransformerKwargs.sequence_q_dim], self._scalar_dim, self._kv_channels_dim, ), - tensor_name=TransformerKwargs.rotary_freq_q, + tensor_name=self._transformer_kwargs.rotary_freq_q, ) - kwargs[TransformerKwargs.rotary_freq_k] = TensorMeta.from_dims( + kwargs[self._transformer_kwargs.rotary_freq_k] = TensorMeta.from_dims( ( self._scalar_dim, kwargs[TransformerKwargs.sequence_q_dim], self._scalar_dim, self._kv_channels_dim, ), - tensor_name=TransformerKwargs.rotary_freq_k, + tensor_name=self._transformer_kwargs.rotary_freq_k, ) @@ -204,6 +251,8 @@ def __init__( config: TransformerConfig, tensor_space: TensorSpace, ): + self._transformer_dim_names = config._transformer_dim_names + self._transformer_kwargs = config._transformer_kwargs self._config = config self._tensor_space = tensor_space self._distributed_config = self._tensor_space.distributed_config @@ -231,25 +280,25 @@ def _create_tensors(self, sequence_length: int) -> None: ) def preprocess(self, batch, kwargs: dict[str, typing.Any]) -> None: - self._create_tensors(kwargs[TransformerKwargs.sequence_length]) - sequence_k = kwargs[TransformerKwargs.sequence_k_dim].size - sequence_q = kwargs[TransformerKwargs.sequence_q_dim].size - kwargs[TransformerKwargs.attention_mask] = self._mask[ + self._create_tensors(kwargs[self._transformer_kwargs.sequence_length]) + sequence_k = kwargs[self._transformer_kwargs.sequence_k_dim].size + sequence_q = kwargs[self._transformer_kwargs.sequence_q_dim].size + kwargs[self._transformer_kwargs.attention_mask] = self._mask[ None, None, sequence_k - sequence_q : sequence_k, None, :sequence_k ] - if (sequence_lengths := kwargs.get(TransformerKwargs.sequence_lengths, None)) is not None: + if (sequence_lengths := kwargs.get(self._transformer_kwargs.sequence_lengths, None)) is not None: seq_ids = torch.stack( [torch.cat([torch.arange(x) for x in sample_lens]) for sample_lens in sequence_lengths] ) document_mask = (seq_ids[:, None, :] == seq_ids[:, :, None]).to(self._tensor_space.distributed.device) - kwargs[TransformerKwargs.attention_mask] = 
( - kwargs[TransformerKwargs.attention_mask] + kwargs[self._transformer_kwargs.attention_mask] = ( + kwargs[self._transformer_kwargs.attention_mask] & document_mask[:, None, sequence_k - sequence_q : sequence_k, None, :sequence_k] ) - kwargs[TransformerKwargs.attention_mask_value] = self._mask_value + kwargs[self._transformer_kwargs.attention_mask_value] = self._mask_value def preprocess_meta(self, kwargs: dict[str, typing.Any]) -> None: - kwargs[TransformerKwargs.attention_mask] = TensorMeta.from_dims( + kwargs[self._transformer_kwargs.attention_mask] = TensorMeta.from_dims( ( self._scalar_dim, self._scalar_dim, @@ -257,12 +306,12 @@ def preprocess_meta(self, kwargs: dict[str, typing.Any]) -> None: self._scalar_dim, kwargs[TransformerKwargs.sequence_k_dim], ), - tensor_name=TransformerKwargs.attention_mask, + tensor_name=self._transformer_kwargs.attention_mask, dtype=torch.bool, ) - kwargs[TransformerKwargs.attention_mask_value] = TensorMeta.from_dims( + kwargs[self._transformer_kwargs.attention_mask_value] = TensorMeta.from_dims( (self._scalar_dim,), - tensor_name=TransformerKwargs.attention_mask_value, + tensor_name=self._transformer_kwargs.attention_mask_value, dtype=self._tensor_space.distributed_config.training_dtype.torch, ) @@ -273,6 +322,8 @@ def __init__(self, config: TransformerConfig, tensor_space: TensorSpace): self._tensor_space = tensor_space self._distributed_config = self._tensor_space.distributed_config assert self._config.do_use_flash_attention(self._distributed_config) + self._transformer_dim_names = config._transformer_dim_names + self._transformer_kwargs = config._transformer_kwargs def preprocess(self, batch, kwargs: dict[str, typing.Any]) -> None: """ @@ -323,17 +374,17 @@ def preprocess(self, batch, kwargs: dict[str, typing.Any]) -> None: else: seqlens_q = torch.cat(sequence_lengths) seqlens_k = torch.cat(sequence_lengths) - kwargs[TransformerKwargs.cu_seqlens_q] = torch.cat( + kwargs[self._transformer_kwargs.cu_seqlens_q] = torch.cat( ( torch.zeros(1, dtype=torch.int32, device=self._tensor_space.distributed.device), torch.cumsum(seqlens_q, dim=0, dtype=torch.int32).to(self._tensor_space.distributed.device), ) ) - kwargs[TransformerKwargs.cu_seqlens_k] = torch.cat( + kwargs[self._transformer_kwargs.cu_seqlens_k] = torch.cat( ( torch.zeros(1, dtype=torch.int32, device=self._tensor_space.distributed.device), torch.cumsum(seqlens_k, dim=0, dtype=torch.int32).to(self._tensor_space.distributed.device), ) ) - kwargs[TransformerKwargs.max_seqlen_q] = seqlens_q.max() - kwargs[TransformerKwargs.max_seqlen_k] = seqlens_k.max() + kwargs[self._transformer_kwargs.max_seqlen_q] = seqlens_q.max() + kwargs[self._transformer_kwargs.max_seqlen_k] = seqlens_k.max() diff --git a/fast_llm/layers/transformer/transformer.py b/fast_llm/layers/transformer/transformer.py index 40dd2e00..392ebb88 100644 --- a/fast_llm/layers/transformer/transformer.py +++ b/fast_llm/layers/transformer/transformer.py @@ -9,7 +9,7 @@ from fast_llm.engine.config_utils.run import log_pipeline_parallel_main_rank from fast_llm.engine.config_utils.tensor_space import TensorDim, TensorSpace from fast_llm.layers.transformer.attention import Attention -from fast_llm.layers.transformer.config import TransformerConfig, TransformerDimNames, TransformerKwargs +from fast_llm.layers.transformer.config import TransformerConfig from fast_llm.layers.transformer.mixture_of_experts import MixtureOfExpertMLP from fast_llm.layers.transformer.mlp import MLP from fast_llm.logging import log_distributed_grad, log_distributed_tensor, 
log_memory_usage @@ -29,6 +29,8 @@ def __init__( self, config: TransformerConfig, tensor_space: TensorSpace, layer_index: int, return_input: bool = False ): super().__init__() + self._transformer_dim_names = config._transformer_dim_names + self._transformer_kwargs = config._transformer_kwargs self._config: TransformerConfig = config self._tensor_space: TensorSpace = tensor_space self._dropout_p: float = self._config.hidden_dropout @@ -37,7 +39,8 @@ def __init__( self._layer_index = layer_index self._debug_mode = self._config.debug_transformer or self._config.debug_transformer_memory - hidden_dim = self._tensor_space.get_tensor_dim(TransformerDimNames.hidden) + + hidden_dim = self._tensor_space.get_tensor_dim(self._transformer_dim_names.hidden) self.norm_1 = self._config.normalization.get_layer(hidden_dim) self.norm_2 = self._config.normalization.get_layer(hidden_dim) @@ -68,7 +71,7 @@ def name(self) -> str: return f"{self._name} {self._layer_index}" def _get_meta(self, tensor: torch.Tensor, name: str, kwargs: dict): - dims = kwargs[TransformerKwargs.hidden_dims] + dims = kwargs[self._transformer_kwargs.hidden_dims] if self._return_input: dims = (TensorDim("stacked_input_output", 2),) + dims return TensorMeta.from_dims(dims, tensor_name=f"{self.name} {name}", dtype=tensor.dtype) diff --git a/fast_llm/layers/transformer/vision_transformer.py b/fast_llm/layers/transformer/vision_transformer.py new file mode 100644 index 00000000..7c1be0d1 --- /dev/null +++ b/fast_llm/layers/transformer/vision_transformer.py @@ -0,0 +1,16 @@ +import torch + +from fast_llm.engine.config_utils.tensor_space import TensorDim +from fast_llm.layers.transformer.config import VisionTransformerKwargs +from fast_llm.layers.transformer.transformer import TransformerLayer +from fast_llm.tensor import TensorMeta + + +class VisionTransformerLayer(TransformerLayer): + _name: str = "Vision transformer layer" + + def _get_meta(self, tensor: torch.Tensor, name: str, kwargs: dict): + dims = kwargs[VisionTransformerKwargs.hidden_dims] + if self._return_input: + dims = (TensorDim("stacked_input_output", 2),) + dims + return TensorMeta.from_dims(dims, tensor_name=f"{self.name} {name}", dtype=tensor.dtype) diff --git a/fast_llm/layers/vision_encoder/adapter.py b/fast_llm/layers/vision_encoder/adapter.py new file mode 100644 index 00000000..41ea065d --- /dev/null +++ b/fast_llm/layers/vision_encoder/adapter.py @@ -0,0 +1,53 @@ +import typing + +import torch + +from fast_llm.engine.base_model.base_model import Layer +from fast_llm.engine.config_utils.tensor_space import TensorSpace +from fast_llm.functional.triton.mlp import torch_mlp_activation +from fast_llm.layers.common.linear import Linear +from fast_llm.layers.transformer.config import TransformerDimNames, TransformerKwargs +from fast_llm.layers.vision_encoder.config import VisionEncoderConfig, VisionEncoderDimNames +from fast_llm.tensor import TensorMeta, init_normal_ + + +class VisionAdapter(Layer): + """ + Vision adapter layer for the LLM. 
+ """ + + def __init__(self, config: VisionEncoderConfig, tensor_space: TensorSpace): + super().__init__() + input_dim = tensor_space.get_tensor_dim(VisionEncoderDimNames.out_channels) + self._activation_type = config.adapter_activation_type + self.layer_1 = Linear( + input_dim, + tensor_space.get_tensor_dim(VisionEncoderDimNames.adapter_size), + bias=True, + weight_init_method=init_normal_(), + bias_init_method=init_normal_(), + ) + self.layer_2 = Linear( + tensor_space.get_tensor_dim(VisionEncoderDimNames.adapter_size), + tensor_space.get_tensor_dim(TransformerDimNames.hidden), + bias=True, + weight_init_method=init_normal_(), + bias_init_method=init_normal_(), + ) + + def forward( + self, + input_: torch.Tensor, + kwargs: dict[str, typing.Any], + losses: dict[str, typing.Any] | None = None, + metrics: dict[str, typing.Any] | None = None, + ) -> torch.Tensor: + if isinstance(input_, TensorMeta): + return TensorMeta.from_dims( + kwargs[TransformerKwargs.hidden_dims], + tensor_name="Vision adapter output", + dtype=input_.dtype, + ) + return self.layer_2( + torch_mlp_activation(input_=self.layer_1(input_), gated=False, activation_type=self._activation_type) + ) diff --git a/fast_llm/layers/vision_encoder/config.py b/fast_llm/layers/vision_encoder/config.py new file mode 100644 index 00000000..26794174 --- /dev/null +++ b/fast_llm/layers/vision_encoder/config.py @@ -0,0 +1,165 @@ +import enum + +from fast_llm.config import Config, Field, FieldHint, check_field, config_class, skip_valid_if_none +from fast_llm.engine.base_model.config import BaseModelConfig +from fast_llm.engine.config_utils.tensor_space import TensorDim, TensorSpace +from fast_llm.functional.config import ActivationType +from fast_llm.layers.common.config import NormalizationConfig +from fast_llm.layers.transformer.config import VisionTransformerConfig +from fast_llm.utils import Assert + + +class VisionEncoderDimNames: + in_channels = "vision_in_channels" + out_channels = "vision_out_channels" + adapter_size = "vision_adapter_size" + patch_size = "vision_patch_size" + kv_channels = "vision_kv_channels" + + +class VisionEncoderKwargs: + patch_size = "patch_size" + images = "images" + image_patches = "image_patches" + image_positions = "image_positions" + image_size = "image_size" + image_sizes = "image_sizes" + image_mean = "image_normalization_mean" + image_std = "image_normalization_std" + image_rescale_factor = "image_rescale_factor" + rope_theta = "vit_rope_theta" + rotary_inv_freq = "vit_rotary_inv_freq" + kv_channels = "vit_kv_channels" + max_image_tokens = "max_image_tokens" + patch_embeddings = "patch_embeddings" + hidden_dims = "vit_hidden_dims" + image_patches_meta = "vit_image_patches_meta" + out_channels = "vit_out_channels" + + +@config_class() +class ImageNormalizationConfig(Config): + mean_r: float = Field( + default=0.48145466, + desc="Mean value for the red channel in the image normalization process.", + hint=FieldHint.optional, + ) + mean_g: float = Field( + default=0.4578275, + desc="Mean value for the green channel in the image normalization process.", + hint=FieldHint.optional, + ) + mean_b: float = Field( + default=0.40821073, + desc="Mean value for the blue channel in the image normalization process.", + hint=FieldHint.optional, + ) + std_r: float = Field( + default=0.26862954, + desc="Standard deviation value for the red channel in the image normalization process.", + hint=FieldHint.optional, + ) + std_g: float = Field( + default=0.26130258, + desc="Standard deviation value for the green channel in the 
image normalization process.", + hint=FieldHint.optional, + ) + std_b: float = Field( + default=0.27577711, + desc="Standard deviation value for the blue channel in the image normalization process.", + hint=FieldHint.optional, + ) + rescale_factor: float = Field( + default=255.0, + desc="Rescale factor for the image normalization process.", + hint=FieldHint.optional, + ) + + +class VisionEncoderType(str, enum.Enum): + none = "none" + pixtral = "pixtral" + + +@config_class() +class VisionEncoderConfig(BaseModelConfig): + _abstract = False + + type: VisionEncoderType = Field( + default=VisionEncoderType.none, + desc="Type of the vision encoder. Choices: none, pixtral.", + hint=FieldHint.architecture, + ) + transformer: VisionTransformerConfig = Field( + default_factory=VisionTransformerConfig, + desc="Configuration for the vision transformer architecture.", + hint=FieldHint.core, + ) + patch_size: int = Field( + default=16, + desc="Patch size for the image encoder.", + hint=FieldHint.core, + ) + conv_bias: bool = Field( + default=False, + desc="Whether to use bias in the convolutional layer.", + hint=FieldHint.optional, + ) + patch_norm: NormalizationConfig = Field( + default_factory=NormalizationConfig, + desc="Configuration for the normalization layers applied to the image patches.", + hint=FieldHint.optional, + ) + adapter_size: int = Field( + default=5120, + desc="Intermediate size for the adapter linear layers. Assuming 2 linear layers", + hint=FieldHint.core, + ) + adapter_activation_type: ActivationType = Field( + default=ActivationType.gelu, + desc="The intermediate activation type for multi-modal adapter. Default: GeLU.", + hint=FieldHint.core, + ) + adapter_bias: bool = Field( + default=True, + desc="Whether to use bias in the adapter linear layer.", + hint=FieldHint.optional, + ) + image_normalization: ImageNormalizationConfig = Field( + default_factory=ImageNormalizationConfig, + desc="Configuration for the normalization layers applied to the image patches.", + hint=FieldHint.optional, + ) + image_break_token: int | None = Field( + default=None, + desc="Token id to separate image rows. If None, no token id is applied.", + hint=FieldHint.optional, + ) + image_end_token: int | None = Field( + default=None, + desc="Token id to indicate the end of an image. 
If None, no token id is applied.", + hint=FieldHint.optional, + ) + adapter_lr_scale: float | None = Field( + default=None, + desc="Custom learning rate scale for the adapter weights.", + hint=FieldHint.feature, + valid=skip_valid_if_none(check_field(Assert.geq, 0)), + ) + conv_lr_scale: float | None = Field( + default=None, + desc="Custom learning rate scale for the convolutional layer weights.", + hint=FieldHint.feature, + valid=skip_valid_if_none(check_field(Assert.geq, 0)), + ) + + def setup_tensor_space(self, tensor_space: TensorSpace): + tensor_space.add_tensor_dim(TensorDim(VisionEncoderDimNames.out_channels, self.transformer.hidden_size)) + tensor_space.add_tensor_dim(TensorDim(VisionEncoderDimNames.adapter_size, self.adapter_size)) + tensor_space.add_tensor_dim(TensorDim(VisionEncoderDimNames.patch_size, self.patch_size)) + tensor_space.add_tensor_dim(TensorDim(VisionEncoderDimNames.in_channels, 3)) + self.transformer.setup_tensor_space(tensor_space) + + @property + def enabled(self) -> bool: + return self.type != VisionEncoderType.none diff --git a/fast_llm/layers/vision_encoder/patch_conv.py b/fast_llm/layers/vision_encoder/patch_conv.py new file mode 100644 index 00000000..68f22200 --- /dev/null +++ b/fast_llm/layers/vision_encoder/patch_conv.py @@ -0,0 +1,90 @@ +import typing + +import torch + +from fast_llm.core.ops import split +from fast_llm.engine.base_model.base_model import Layer +from fast_llm.engine.config_utils.tensor_space import TensorSpace +from fast_llm.layers.transformer.config import TransformerKwargs, VisionTransformerKwargs +from fast_llm.layers.vision_encoder.config import VisionEncoderConfig, VisionEncoderDimNames, VisionEncoderKwargs +from fast_llm.tensor import ParameterMeta, TensorMeta, init_normal_ + + +def position_ids_in_meshgrid(patch_embeddings_list, max_size): + positions = [] + for patch in patch_embeddings_list: + height, width = patch.shape[-2:] + mesh = torch.meshgrid(torch.arange(height), torch.arange(width), indexing="ij") + h_grid, v_grid = torch.stack(mesh, dim=-1).reshape(-1, 2).chunk(2, -1) + ids = h_grid * max_size + v_grid + positions.append(ids[:, 0]) + return torch.cat(positions) + + +def generate_block_attention_mask(patch_embeds_list, tensor): + dtype = tensor.dtype + device = tensor.device + seq_len = tensor.shape[1] + d_min = torch.finfo(dtype).min + causal_mask = torch.full((seq_len, seq_len), fill_value=d_min, dtype=dtype, device=device) + + block_end_idx = torch.tensor(patch_embeds_list).cumsum(-1) + block_start_idx = torch.tensor([0] + patch_embeds_list[:-1]).cumsum(-1) + for start, end in zip(block_start_idx, block_end_idx): + causal_mask[start:end, start:end] = 0 + + causal_mask = causal_mask[None, None, :, :].expand(tensor.shape[0], 1, -1, -1) + return causal_mask + + +class PatchConv(Layer): + def __init__(self, config: VisionEncoderConfig, tensor_space: TensorSpace): + super().__init__() + self._tensor_space = tensor_space + self._distributed_config = tensor_space.distributed_config + self._sequence_parallel = self._distributed_config.sequence_tensor_parallel + self._lr_scale = config.adapter_lr_scale + # TODO Soham: lr_scale + self.weight = ParameterMeta.from_dims( + ( + self._tensor_space.get_tensor_dim(VisionEncoderDimNames.out_channels), + self._tensor_space.get_tensor_dim(VisionEncoderDimNames.in_channels), + self._tensor_space.get_tensor_dim(VisionEncoderDimNames.patch_size), + self._tensor_space.get_tensor_dim(VisionEncoderDimNames.patch_size), + ), + init_method=init_normal_(), + lr_scale=self._lr_scale, + ) + if 
config.conv_bias:
+            self.bias = ParameterMeta.from_dims(
+                (self._tensor_space.get_tensor_dim(VisionEncoderDimNames.out_channels),),
+                init_method=init_normal_(),
+                lr_scale=self._lr_scale,
+            )
+        else:
+            self.bias = None
+        self.norm = config.patch_norm.get_layer(tensor_space.get_tensor_dim(VisionEncoderDimNames.out_channels))
+        self.stride = config.patch_size
+
+    def forward(
+        self,
+        input_: torch.Tensor,
+        kwargs: dict[str, typing.Any],
+        losses: dict[str, typing.Any] | None = None,
+        metrics: dict | None = None,
+    ) -> torch.Tensor:
+        hidden_dims = kwargs[VisionTransformerKwargs.hidden_dims]
+        if isinstance(input_, TensorMeta):
+            return TensorMeta.from_dims(hidden_dims, tensor_name="patch conv output", dtype=input_.dtype)
+        micro_batch_size = kwargs[TransformerKwargs.micro_batch_size]
+        sequence_length = kwargs[TransformerKwargs.sequence_length]
+        out_channels = kwargs[VisionEncoderKwargs.out_channels]
+        reshape_dims = (micro_batch_size, sequence_length, out_channels)
+        group = self._tensor_space.distributed.tensor_group
+        input_ = torch.nn.functional.conv2d(input_, self.weight, self.bias, stride=self.stride)
+        patch_embeddings = self.norm(input_.flatten(1))
+        patch_embeddings = patch_embeddings.view(reshape_dims)
+        if self._sequence_parallel:
+            patch_embeddings = patch_embeddings.permute(1, 0, 2).contiguous()
+            patch_embeddings = split(patch_embeddings, group=group, dim=0)
+        return patch_embeddings
diff --git a/fast_llm/layers/vision_encoder/preprocessing.py b/fast_llm/layers/vision_encoder/preprocessing.py
new file mode 100644
index 00000000..76b0aa28
--- /dev/null
+++ b/fast_llm/layers/vision_encoder/preprocessing.py
@@ -0,0 +1,261 @@
+import math
+import typing
+
+import torch
+import torchvision.transforms.v2.functional as F
+
+from fast_llm.engine.base_model.config import Preprocessor
+from fast_llm.engine.config_utils.tensor_space import TensorDim, TensorSpace
+from fast_llm.layers.language_model.config import LanguageModelKwargs
+from fast_llm.layers.transformer.config import TransformerKwargs, VisionTransformerDimNames, VisionTransformerKwargs
+from fast_llm.layers.vision_encoder.config import VisionEncoderConfig, VisionEncoderDimNames, VisionEncoderKwargs
+from fast_llm.tensor import TensorMeta
+from fast_llm.utils import div
+
+
+def get_num_patches(height: int, width: int, patch_size: int) -> int:
+    """
+    Calculate the total number of patches for an image.
+    """
+    return div(height, patch_size) * div(width, patch_size)
+
+
+def get_num_image_tokens(height: int, width: int, patch_size: int, image_break: bool, image_end: bool) -> int:
+    """
+    Calculate the number of image tokens.
+    If image_break is True, we consider 1 additional token after every row of patches.
+    """
+    height_patches = div(height, patch_size)
+    width_patches = div(width, patch_size)
+    num_tokens = height_patches * width_patches
+    if image_break:
+        num_tokens += height_patches
+    elif image_end:
+        num_tokens += 1
+    return num_tokens
+
+
+def get_resize_dims(height: int, width: int, max_height: int, max_width: int, patch_size: int) -> tuple[int, int]:
+    """
+    Calculate the new dimensions for resizing an image while maintaining the aspect ratio.
+    If the image is larger than the max dimensions, it will be resized to fit within them.
+    If the image is smaller, it will be resized to the nearest multiple of the patch size.
+ """ + ratio = max(height / max_height, width / max_width) + if ratio > 1: + # Resize to fit within max dimensions + height = int(height / ratio) + width = int(width / ratio) + return patch_size * math.ceil(height / patch_size), patch_size * math.ceil(width / patch_size) + + +def resize(image: torch.Tensor, max_height: int, max_width: int, patch_size: int) -> tuple[int, int]: + resize_dims = get_resize_dims(image.size(1), image.size(2), max_height, max_width, patch_size=patch_size) + # TODO: options for interpolation mode? + return F.resize(image, size=resize_dims, interpolation=F.InterpolationMode.BICUBIC) + + +def normalize(image: torch.Tensor, mean: list[float], std: list[float]) -> torch.Tensor: + """ + Normalize the image using the specified mean and standard deviation. + """ + return F.normalize(image, mean=mean, std=std) + + +def pad(image: torch.Tensor, max_height, max_width) -> torch.Tensor: + """ + Pad images on the right and bottom with 0s untitl max_height and max_width + """ + width_padding = max(0, max_height - image.size(1)) + depth_padding = max(0, max_width - image.size(2)) + return F.pad(image, (0, 0, depth_padding, width_padding), 0) + + +def create_inv_freqs(rope_theta: int, kv_channels: int, image_size: int, patch_size: int) -> torch.Tensor: + freqs = 1.0 / (rope_theta ** (torch.arange(0, kv_channels, 2).float() / kv_channels)) + max_patches_per_side = image_size // patch_size + + h = torch.arange(max_patches_per_side) + w = torch.arange(max_patches_per_side) + + freqs_h = torch.outer(h, freqs[::2]).float() + freqs_w = torch.outer(w, freqs[1::2]).float() + inv_freq = torch.cat( + [ + freqs_h[:, None, :].repeat(1, max_patches_per_side, 1), + freqs_w[None, :, :].repeat(max_patches_per_side, 1, 1), + ], + dim=-1, + ).reshape(-1, kv_channels // 2) + + return torch.cat((inv_freq, inv_freq), dim=-1) + + +def position_ids_in_meshgrid(image_sizes: list[torch.Tensor], max_size: int, patch_size: int) -> torch.Tensor: + positions = [] + for h, w in image_sizes: + patch_height = h // patch_size + patch_width = w // patch_size + mesh = torch.meshgrid(torch.arange(patch_height), torch.arange(patch_width), indexing="ij") + h_grid, v_grid = torch.stack(mesh, dim=-1).reshape(-1, 2).chunk(2, -1) + ids = h_grid * max_size + v_grid + positions.append(ids[:, 0]) + return positions + + +def position_ids_in_meshgrid(height, width, max_size, patch_size) -> torch.Tensor: + patch_height = height // patch_size + patch_width = width // patch_size + mesh = torch.meshgrid(torch.arange(patch_height), torch.arange(patch_width), indexing="ij") + h_grid, v_grid = torch.stack(mesh, dim=-1).reshape(-1, 2).chunk(2, -1) + ids = h_grid * max_size + v_grid + return ids[:, 0] + + +class VisionPreprocessor(Preprocessor): + def __init__(self, config: VisionEncoderConfig, tensor_space: TensorSpace): + self._config = config + self._tensor_space = tensor_space + self._distributed_config = self._tensor_space.distributed_config + + def preprocess_meta(self, kwargs: dict[str, typing.Any]) -> None: + kwargs[VisionEncoderKwargs.image_patches_meta] = TensorMeta.from_dims( + ( + TensorDim( + VisionTransformerDimNames.batch, + kwargs[TransformerKwargs.micro_batch_size] * kwargs[TransformerKwargs.sequence_q_dim].size, + ), + TensorDim(VisionEncoderDimNames.in_channels, 3), + TensorDim(VisionEncoderDimNames.patch_size, kwargs[VisionEncoderKwargs.patch_size]), + TensorDim(VisionEncoderDimNames.patch_size, kwargs[VisionEncoderKwargs.patch_size]), + ), + dtype=self._distributed_config.training_dtype.torch, + ) + + def 
preprocess(self, tokens, kwargs: dict[str, typing.Any]) -> None: + images = kwargs.get(VisionEncoderKwargs.images) + im_height = kwargs.get(VisionEncoderKwargs.image_size) + im_width = kwargs.get(VisionEncoderKwargs.image_size) + patch_size = kwargs[VisionEncoderKwargs.patch_size] + image_positions = kwargs.get(VisionEncoderKwargs.image_positions) + image_sizes = [ + [get_resize_dims(im.size(1), im.size(2), im_height, im_width, patch_size=patch_size) for im in ims] + for ims in images + ] + kwargs[VisionEncoderKwargs.image_sizes] = image_sizes + images = [ + [ + normalize( + resize(image, im_height, im_width, patch_size).to( + dtype=self._tensor_space.distributed_config.training_dtype.torch + ) + / kwargs[VisionEncoderKwargs.image_rescale_factor], + mean=kwargs[VisionEncoderKwargs.image_mean], + std=kwargs[VisionEncoderKwargs.image_std], + ) + for image in imgs + ] + for imgs in images + ] + + labels = kwargs[LanguageModelKwargs.labels] + if (self._config.image_break_token is not None) or (self._config.image_end_token is not None): + # If image break or end token is present, we need to replace image token ids to -100 in labels + # TODO: avoid double cloning labels in case of loss masking spans? + labels = labels.clone() + + patches = [] + patch_position_ids = [] + cu_seqlens = [0] + max_seqlen = -1 + kwargs.get(TransformerKwargs.sequence_first) + for idx, (imgs, sizes, positions) in enumerate(zip(images, image_sizes, image_positions)): + # add an empty tensor for clean concatenation in case of no images + seq_patches = [ + torch.tensor([]).to( + dtype=self._tensor_space.distributed_config.training_dtype.torch, + device=self._tensor_space.distributed.device, + ) + ] + sample_cu_seqlen = 0 + for image, size, position in zip(imgs, sizes, positions): + seqlen = get_num_patches(*size, patch_size) + num_tokens = get_num_image_tokens( + *size, + patch_size=patch_size, + image_break=self._config.image_break_token is not None, + image_end=self._config.image_end_token is not None, + ) + # set labels for image patches to -100 + labels[idx, max(position - 1, 0) : position + num_tokens - 1] = -100 + if seqlen > max_seqlen: + max_seqlen = seqlen + cu_seqlens.append(cu_seqlens[-1] + seqlen) + sample_cu_seqlen += seqlen + seq_patches.append( + torch.cat( + [ + torch.nn.functional.unfold(image, kernel_size=patch_size, stride=patch_size).T.reshape( + -1, 3, patch_size, patch_size + ), + ] + ) + ) + # TODO Soham: should this be micro_sequence_length? 
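+            # Pad each sample's patch sequence with zero patches up to the full sequence length so samples stack
+            # into a fixed-size batch; the extra cu_seqlens entry keeps the padded block as its own attention segment.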
+ padding_size = kwargs[TransformerKwargs.sequence_length] - sample_cu_seqlen + if padding_size > max_seqlen: + max_seqlen = padding_size + cu_seqlens.append(kwargs[TransformerKwargs.sequence_length] * (idx + 1)) + patches.append( + torch.cat( + [ + *seq_patches, + torch.zeros(padding_size, 3, patch_size, patch_size).to( + dtype=self._tensor_space.distributed_config.training_dtype.torch, + device=self._tensor_space.distributed.device, + ), + ] + ) + ) + if sizes: + position_ids = torch.cat( + [position_ids_in_meshgrid(*size, im_height // patch_size, patch_size) for size in sizes] + ).to(device=self._tensor_space.distributed.device) + else: + position_ids = torch.tensor( + [], + dtype=torch.int64, + device=self._tensor_space.distributed.device, + ) + # We pad at the end instead of padding at the position in meshgrid because flash attention does not support custom attention masks + patch_position_ids.append( + torch.cat( + [ + position_ids, + torch.full((padding_size,), 0).to(device=self._tensor_space.distributed.device), + ] + ) + ) + # TODO Soham: remove + assert patches[-1].size(0) == kwargs[TransformerKwargs.sequence_length] + patches = torch.cat(patches) + patch_position_ids = torch.cat(patch_position_ids) + kwargs[VisionEncoderKwargs.image_patches] = patches + kwargs[VisionTransformerKwargs.patch_position_ids] = patch_position_ids + kwargs[VisionEncoderKwargs.rotary_inv_freq] = create_inv_freqs( + kwargs[VisionEncoderKwargs.rope_theta], + kwargs[VisionEncoderKwargs.kv_channels], + im_height, + patch_size, + ).to(device=self._tensor_space.distributed.device) + kwargs[VisionEncoderKwargs.max_image_tokens] = div(im_height * im_width, patch_size**2) + # TODO Soham: handle sequence data parallel + kwargs[VisionTransformerKwargs.cu_seqlens_q] = torch.tensor( + cu_seqlens, device=self._tensor_space.distributed.device, dtype=torch.int32 + ) + kwargs[VisionTransformerKwargs.cu_seqlens_k] = torch.tensor( + cu_seqlens, device=self._tensor_space.distributed.device, dtype=torch.int32 + ) + kwargs[VisionTransformerKwargs.max_seqlen_q] = max_seqlen + kwargs[VisionTransformerKwargs.max_seqlen_k] = max_seqlen + kwargs[LanguageModelKwargs.labels] = labels diff --git a/fast_llm/models/gpt/config.py b/fast_llm/models/gpt/config.py index 418f948e..ae6fc6ad 100644 --- a/fast_llm/models/gpt/config.py +++ b/fast_llm/models/gpt/config.py @@ -57,6 +57,27 @@ class MTPLlamaGPTHuggingfaceCheckpointFormat(GPTHuggingfaceCheckpointFormat): trust_remote_code: typing.ClassVar[bool] = True +class LlavaGPTHuggingfaceCheckpointFormat(GPTHuggingfaceCheckpointFormat): + name: typing.ClassVar[str] = "llava" + # Using default values for vision and text models. 
Can be overridden in the config + vision_name: typing.ClassVar[str] = "pixtral" + text_name: typing.ClassVar[str] = "mistral" + + +class PixtralGPTHuggingfaceCheckpointFormat(GPTHuggingfaceCheckpointFormat): + name: typing.ClassVar[str] = "pixtral" + + +class WhisperGPTHuggingfaceCheckpointFormat(GPTHuggingfaceCheckpointFormat): + name: typing.ClassVar[str] = "whisper" + + +class AyraAudioModelGPTHuggingfaceCheckpointFormat(GPTHuggingfaceCheckpointFormat): + name: typing.ClassVar[str] = "ayra_audio" + audio_name: typing.ClassVar[str] = "whisper" + text_name: typing.ClassVar[str] = "llama" + + @config_class() class GPTBatchConfig(BatchConfig): sequence_length: int = Field( @@ -138,6 +159,10 @@ class GPTModelConfig(FastLLMModelConfig): MistralGPTHuggingfaceCheckpointFormat, MixtralGPTHuggingfaceCheckpointFormat, MTPLlamaGPTHuggingfaceCheckpointFormat, + LlavaGPTHuggingfaceCheckpointFormat, + WhisperGPTHuggingfaceCheckpointFormat, + PixtralGPTHuggingfaceCheckpointFormat, + AyraAudioModelGPTHuggingfaceCheckpointFormat, ) @classmethod @@ -152,6 +177,25 @@ def get_huggingface_model_class(cls) -> type["HuggingfaceGPTModelForCausalLM"]: return HuggingfaceGPTModelForCausalLM + @classmethod + def get_checkpoint_format(cls, format: type[CheckpointFormat]) -> type[CheckpointFormat]: + if isinstance(format, type) and issubclass(format, CheckpointFormat): + format_ = cls.get_checkpoint_format(format.name) + Assert.is_(format, format_) + return format_ + elif isinstance(format, dict): + for format_ in cls.checkpoint_formats: + if format_.name == format["name"]: + if (vision_name := format.get("vision_name")) is not None: + format_.vision_name = vision_name + if (text_name := format.get("text_name")) is not None: + format_.text_name = text_name + return format_ + for format_ in cls.checkpoint_formats: + if format_.name == format: + return format_ + raise ValueError(f"Checkpoint format {format} not supported for model {cls.model_name}") + @config_class() class PretrainedGPTModelConfig(PretrainedFastLLMModelConfig): diff --git a/fast_llm/models/gpt/conversion.py b/fast_llm/models/gpt/conversion.py index bd733692..568c7808 100644 --- a/fast_llm/models/gpt/conversion.py +++ b/fast_llm/models/gpt/conversion.py @@ -6,8 +6,10 @@ import torch from transformers.configuration_utils import PretrainedConfig -from fast_llm.config import DEFAULT, MISSING -from fast_llm.engine.checkpoint.config import CheckpointFormat +from fast_llm import __version__ +from fast_llm.config import DEFAULT, MISSING, get_nested_dict_value, set_nested_dict_value +from fast_llm.engine.base_model.config import BaseModelConfig +from fast_llm.engine.checkpoint.config import CheckpointFormat, CheckpointLoadMetadataConfig from fast_llm.engine.checkpoint.external import ( AutoStateDictCheckpointHandler, ConstantExportParamConverter, @@ -22,20 +24,26 @@ WeightConverter, ) from fast_llm.engine.checkpoint.huggingface import CustomModelingExportMixin, HuggingfaceStateDictCheckpointHandler -from fast_llm.engine.multi_stage.config import FastLLMModelConfig +from fast_llm.engine.multi_stage.config import CheckpointMetadata, FastLLMModelConfig from fast_llm.functional.config import ActivationType from fast_llm.functional.rotary import convert_rotary_complex_to_real, convert_rotary_real_to_complex +from fast_llm.layers.audio_encoder.config import AudioEncoderType from fast_llm.layers.common.config import NormalizationType from fast_llm.layers.transformer.config import RotaryEmbeddingType, RoutingType, TransformerConfig +from 
fast_llm.layers.vision_encoder.config import VisionEncoderType from fast_llm.models.gpt.config import ( + AyraAudioModelGPTHuggingfaceCheckpointFormat, GPTBaseModelConfig, GPTModelConfig, LlamaGPTHuggingfaceCheckpointFormat, + LlavaGPTHuggingfaceCheckpointFormat, MistralGPTHuggingfaceCheckpointFormat, MixtralGPTHuggingfaceCheckpointFormat, MTPLlamaGPTHuggingfaceCheckpointFormat, + PixtralGPTHuggingfaceCheckpointFormat, Qwen2GPTHuggingfaceCheckpointFormat, Starcoder2GPTHuggingfaceCheckpointFormat, + WhisperGPTHuggingfaceCheckpointFormat, ) from fast_llm.models.gpt.external.mtp_llama.configuration_mtp_llama import MTPLlamaConfig from fast_llm.models.gpt.model import GPTModel @@ -110,7 +118,37 @@ def import_weight( return (merged_weight.t().contiguous(),) -class CommonHuggingfaceCheckpointHandler(HuggingfaceStateDictCheckpointHandler): +class WeightAndBiasConverterMixin: + def _get_weight_and_bias_converters( + self, + fast_llm_prefix: str | tuple[str, ...], + hf_prefix: str | tuple[str, ...], + use_bias: bool, + cls=WeightConverter, + ) -> list[WeightConverter]: + if isinstance(fast_llm_prefix, str): + fast_llm_prefix = (fast_llm_prefix,) + if isinstance(hf_prefix, str): + hf_prefix = (hf_prefix,) + converters = [ + cls( + tuple(f"{prefix}.weight" for prefix in fast_llm_prefix), + tuple(f"{prefix}.weight" for prefix in hf_prefix), + self._model.config.base_model, + ) + ] + if use_bias: + converters.append( + cls( + tuple(f"{prefix}.bias" for prefix in fast_llm_prefix), + tuple(f"{prefix}.bias" for prefix in hf_prefix), + self._model.config.base_model, + ) + ) + return converters + + +class CommonHuggingfaceCheckpointHandler(WeightAndBiasConverterMixin, HuggingfaceStateDictCheckpointHandler): _model: GPTModel _model_class: typing.ClassVar[FastLLMModelConfig] = GPTModelConfig """ @@ -166,17 +204,60 @@ def _get_mlp_converters(self, fast_llm_prefix: str, hf_prefix: str) -> list[Weig def _create_weight_converters( self, + hf_base_prefix: str = "", + offset: int = 0, ) -> list[WeightConverter]: converters = [] num_layers = self._model.config.base_model.transformer.num_layers # Embeddings - converters.append(WeightConverter("layers.0.word_embeddings_weight", "model.embed_tokens.weight")) + converters.append( + WeightConverter(f"layers.{offset}.word_embeddings_weight", f"{hf_base_prefix}model.embed_tokens.weight") + ) - converters += self._create_lm_head_converters() + converters += self._create_lm_head_converters(hf_base_prefix, offset=offset) for i in range(num_layers): - converters += self._create_transformer_layer_converters(f"layers.{i+1}", f"model.layers.{i}") + converters += self._create_transformer_layer_converters( + f"layers.{i+offset+1}", f"{hf_base_prefix}model.layers.{i}" + ) + + return converters + + def _create_lm_head_converters(self, hf_base_prefix: str = "", offset: int = 0) -> list[WeightConverter]: + num_layers = self._model.config.base_model.transformer.num_layers + prediction_heads = self._model.config.base_model.prediction_heads + norm_bias: bool = self._model.config.base_model.transformer.normalization.type == NormalizationType.layer_norm + converters = [] + + # Next-token prediction head + # Final norm + converters += self._get_weight_and_bias_converters( + f"layers.{num_layers + offset + 1}.final_norm", f"{hf_base_prefix}model.norm", norm_bias + ) + # Output weights + if self._model.config.base_model.tie_word_embeddings: + converters.append(IgnoreImportWeightConverter((), f"{hf_base_prefix}lm_head.weight")) + else: + converters.append( + WeightConverter(f"layers.{num_layers 
+ offset + 1}.output_weights", f"{hf_base_prefix}lm_head.weight") + ) + + # MTP-heads > 0 are thrown away + # TODO Soham: handle offset with MTP + for i in range(1, prediction_heads): + logger.warning( + f"The model weights for the multi-token prediction head {i} are discarded during conversion." + ) + mtp_transformer_layer_index = num_layers - 1 + 2 * i + # MTP transformer layer + converters += self._create_transformer_layer_converters( + f"layers.{mtp_transformer_layer_index + 1}", "", ignore_export=True + ) + # MTP output norm + converters += self._get_weight_and_bias_converters( + f"layers.{mtp_transformer_layer_index + 2}.final_norm", (), norm_bias, IgnoreExportWeightConverter + ) return converters @@ -247,68 +328,6 @@ def _create_transformer_layer_converters( converters += self._get_mlp_converters(f"{fast_llm_layer_name}", f"{hf_layer_name}") return converters - def _create_lm_head_converters(self) -> list[WeightConverter]: - num_layers = self._model.config.base_model.transformer.num_layers - prediction_heads = self._model.config.base_model.prediction_heads - norm_bias: bool = self._model.config.base_model.transformer.normalization.type == NormalizationType.layer_norm - converters = [] - - # Next-token prediction head - # Final norm - converters += self._get_weight_and_bias_converters( - f"layers.{num_layers + 1}.final_norm", "model.norm", norm_bias - ) - # Output weights - if self._model.config.base_model.tie_word_embeddings: - converters.append(IgnoreImportWeightConverter((), "lm_head.weight")) - else: - converters.append(WeightConverter(f"layers.{num_layers + 1}.output_weights", "lm_head.weight")) - - # MTP-heads > 0 are thrown away - for i in range(1, prediction_heads): - logger.warning( - f"The model weights for the multi-token prediction head {i} are discarded during conversion." 
- ) - mtp_transformer_layer_index = num_layers - 1 + 2 * i - # MTP transformer layer - converters += self._create_transformer_layer_converters( - f"layers.{mtp_transformer_layer_index + 1}", "", ignore_export=True - ) - # MTP output norm - converters += self._get_weight_and_bias_converters( - f"layers.{mtp_transformer_layer_index + 2}.final_norm", (), norm_bias, IgnoreExportWeightConverter - ) - - return converters - - def _get_weight_and_bias_converters( - self, - fast_llm_prefix: str | tuple[str, ...], - hf_prefix: str | tuple[str, ...], - use_bias: bool, - cls=WeightConverter, - ) -> list[WeightConverter]: - if isinstance(fast_llm_prefix, str): - fast_llm_prefix = (fast_llm_prefix,) - if isinstance(hf_prefix, str): - hf_prefix = (hf_prefix,) - converters = [ - cls( - tuple(f"{prefix}.weight" for prefix in fast_llm_prefix), - tuple(f"{prefix}.weight" for prefix in hf_prefix), - self._model.config.base_model, - ) - ] - if use_bias: - converters.append( - cls( - tuple(f"{prefix}.bias" for prefix in fast_llm_prefix), - tuple(f"{prefix}.bias" for prefix in hf_prefix), - self._model.config.base_model, - ) - ) - return converters - class Starcoder2HuggingfaceCheckpointHandler(CommonHuggingfaceCheckpointHandler): format: typing.ClassVar[type[CheckpointFormat]] = Starcoder2GPTHuggingfaceCheckpointFormat @@ -357,7 +376,7 @@ def _create_config_converters(cls) -> list[ParamConverter]: ), RenameParamConverter( fast_llm_names=(("transformer", "kv_channels"),), - export_names=(("head_dim"),), + export_names=(("head_dim",),), ), ConstantImportParamConverter(fast_llm_names=(("transformer", "gated"),), fast_llm_value=True), ConstantImportParamConverter(fast_llm_names=(("transformer", "add_linear_biases"),), fast_llm_value=False), @@ -548,6 +567,705 @@ def _get_mlp_converters(self, fast_llm_prefix: str, hf_prefix: str) -> list[Weig ] +class WhisperHuggingfaceCheckpointHandler(WeightAndBiasConverterMixin, HuggingfaceStateDictCheckpointHandler): + format: typing.ClassVar[type[CheckpointFormat]] = WhisperGPTHuggingfaceCheckpointFormat + _model_class: typing.ClassVar[FastLLMModelConfig] = GPTModelConfig + + @classmethod + def _create_config_converters(cls) -> list[ParamConverter]: + return super()._create_config_converters() + [ + # set default layernorm + ConstantImportParamConverter( + fast_llm_names=(("transformer", "normalization", "type"),), fast_llm_value=NormalizationType.layer_norm + ), + ConstantExportParamConverter( + export_names=(("architectures",),), export_value=["WhisperForConditionalGeneration"] + ), + ConstantImportParamConverter(fast_llm_names=(("type",),), fast_llm_value=AudioEncoderType.whisper), + # make transformer noncasual + ConstantImportParamConverter(fast_llm_names=(("transformer", "causal"),), fast_llm_value=False), + RenameParamConverter( + fast_llm_names=( + ( + "transformer", + "num_layers", + ), + ), + export_names=(("num_hidden_layers",),), + ), + RenameParamConverter( + fast_llm_names=( + ( + "transformer", + "hidden_size", + ), + ), + export_names=(("d_model",),), + ), + RenameParamConverter( + fast_llm_names=( + ( + "transformer", + "num_attention_heads", + ), + ), + export_names=(("encoder_attention_heads",),), + ), + RenameParamConverter( + fast_llm_names=( + ( + "transformer", + "head_groups", + ), + ), + export_names=(("encoder_attention_heads",),), + ), + RenameParamConverter( + fast_llm_names=( + ( + "transformer", + "ffn_hidden_size", + ), + ), + export_names=(("encoder_ffn_dim",),), + ), + MappedConfigParamConverter( + fast_llm_names=(("transformer", 
"activation_type"),), + export_names=(("activation_function",),), + fast_llm_value=ActivationType.from_hf_name, + export_value=lambda activation_type: activation_type.hf_name, + ), + ConstantImportParamConverter( + fast_llm_names=(("transformer", "rotary", "type"),), fast_llm_value=RotaryEmbeddingType.none + ), + ConstantImportParamConverter(fast_llm_names=(("transformer", "gated"),), fast_llm_value=False), + ConstantImportParamConverter(fast_llm_names=(("transformer", "add_linear_biases"),), fast_llm_value=True), + RenameParamConverter( + fast_llm_names=(("num_mel_bins",),), + export_names=(("num_mel_bins",),), + ), + ] + + def _get_transformer_mlp_converters(self, fast_llm_prefix: str, hf_prefix: str) -> list[WeightConverter]: + # return [ + # WeightConverter(f"{fast_llm_prefix}.mlp.layer_1.weight", f"{hf_prefix}fc1.weight"), + # WeightConverter(f"{fast_llm_prefix}.mlp.layer_1.bias", f"{hf_prefix}fc1.bias"), + # WeightConverter(f"{fast_llm_prefix}.mlp.layer_2.weight", f"{hf_prefix}fc2.weight"), + # WeightConverter(f"{fast_llm_prefix}.mlp.layer_2.bias", f"{hf_prefix}fc2.bias"), + # ] + transformer_config = self._model.config.base_model.audio_encoder.transformer + return [ + *self._get_weight_and_bias_converters( + f"{fast_llm_prefix}.mlp.layer_1", + f"{hf_prefix}fc1", + transformer_config.add_mlp_bias, + WeightConverter, + ), + *self._get_weight_and_bias_converters( + f"{fast_llm_prefix}.mlp.layer_2", + f"{hf_prefix}fc2", + transformer_config.add_mlp_bias, + MLPLayer2Converter, + ), + ] + + def _create_audio_transformer_layer_converters( + self, transformer_layer_index: int, fast_llm_offset: int = 1, hf_base_prefix: str = "" + ) -> list[WeightConverter]: + # Vision transformer layer + transformer_config = self._model.config.base_model.audio_encoder.transformer + norm_bias: bool = transformer_config.normalization.type == NormalizationType.layer_norm + name_bias_cls = [ + # Self-attn + ( + f"layers.{fast_llm_offset + transformer_layer_index}.self_attn.query", + f"{hf_base_prefix}layers.{transformer_layer_index}.self_attn.q_proj", + transformer_config.add_attn_qkv_bias, + QueryWeightConverter, + ), + ( + f"layers.{fast_llm_offset + transformer_layer_index}.self_attn.key_value", + ( + f"{hf_base_prefix}layers.{transformer_layer_index}.self_attn.k_proj", + f"{hf_base_prefix}layers.{transformer_layer_index}.self_attn.v_proj", + ), + transformer_config.add_attn_qkv_bias, # TODO Toby: add permanent fix for key bias + KeyValueWeightConverter, + ), + ( + f"layers.{fast_llm_offset + transformer_layer_index}.self_attn.dense", + f"{hf_base_prefix}layers.{transformer_layer_index}.self_attn.out_proj", + transformer_config.add_attn_dense_bias, + WeightConverter, + ), + # Norm + ( + f"layers.{fast_llm_offset + transformer_layer_index}.norm_1", + f"{hf_base_prefix}layers.{transformer_layer_index}.self_attn_layer_norm", + norm_bias, + WeightConverter, + ), + ( + f"layers.{fast_llm_offset + transformer_layer_index}.norm_2", + f"{hf_base_prefix}layers.{transformer_layer_index}.final_layer_norm", + norm_bias, + WeightConverter, + ), + ] + converters = [] + for fast_llm_prefix, hf_prefix, use_bias, cls in name_bias_cls: + converters += self._get_weight_and_bias_converters( + fast_llm_prefix, + hf_prefix, + use_bias, + cls, + ) + # MLP + converters += self._get_transformer_mlp_converters( + f"layers.{fast_llm_offset + transformer_layer_index}", + f"{hf_base_prefix}layers.{transformer_layer_index}.", + ) + return converters + + def _create_weight_converters(self, offset: int = 0, hf_base_prefix: str = "") -> 
list[WeightConverter]: + converters = [] + + # audio encoder conv + converters += [ + WeightConverter(f"layers.{offset}.conv1_weight", f"{hf_base_prefix}conv1.weight"), + WeightConverter(f"layers.{offset}.conv2_weight", f"{hf_base_prefix}conv2.weight"), + ] + + if self._model.config.base_model.audio_encoder.conv_bias: + converters += [ + WeightConverter(f"layers.{offset}.conv1_bias", f"{hf_base_prefix}conv1.bias"), + WeightConverter(f"layers.{offset}.conv2_bias", f"{hf_base_prefix}conv2.bias"), + ] + + # position embedding + converters.append( + WeightConverter(f"layers.{offset}.positional_embeddings", f"{hf_base_prefix}embed_positions.weight") + ) + + # transformer encoder layers + num_layers = self._model.config.base_model.audio_encoder.transformer.num_layers + for i in range(num_layers): + converters += self._create_audio_transformer_layer_converters(i, offset + 1, hf_base_prefix) + + offset = offset + num_layers + 1 + + # add final layernorm + if self._model.config.base_model.audio_encoder.transformer.normalization.type == NormalizationType.layer_norm: + converters += [ + WeightConverter(f"layers.{offset}.norm_1.weight", f"{hf_base_prefix}layer_norm.weight"), + WeightConverter(f"layers.{offset}.norm_2.weight", "encoder_projector.layer_norm.weight"), + WeightConverter(f"layers.{offset}.norm_1.bias", f"{hf_base_prefix}layer_norm.bias"), + WeightConverter(f"layers.{offset}.norm_2.bias", "encoder_projector.layer_norm.bias"), + ] + + # multimodal projector + converters.extend( + [ + WeightConverter(f"layers.{offset}.layer_1.weight", "encoder_projector.linear1.weight"), + WeightConverter(f"layers.{offset}.layer_2.weight", "encoder_projector.linear2.weight"), + ] + ) + if self._model.config.base_model.audio_encoder.adapter_bias: + converters.extend( + [ + WeightConverter(f"layers.{offset}.layer_1.bias", "encoder_projector.linear1.bias"), + WeightConverter(f"layers.{offset}.layer_2.bias", "encoder_projector.linear2.bias"), + ] + ) + + return converters + + @property + def num_layers(self) -> int: + # +2 for projector and conv layers + return self._model.config.base_model.audio_encoder.transformer.num_layers + 2 + + +class PixtralHuggingfaceCheckpointHandler(WeightAndBiasConverterMixin, HuggingfaceStateDictCheckpointHandler): + format: typing.ClassVar[type[CheckpointFormat]] = PixtralGPTHuggingfaceCheckpointFormat + _model_class: typing.ClassVar[FastLLMModelConfig] = GPTModelConfig + + @classmethod + def _create_config_converters(cls) -> list[ParamConverter]: + return super()._create_config_converters() + [ + ConstantImportParamConverter( + fast_llm_names=(("patch_norm", "type"),), fast_llm_value=NormalizationType.rms_norm + ), + ConstantImportParamConverter( + fast_llm_names=(("transformer", "normalization", "type"),), fast_llm_value=NormalizationType.rms_norm + ), + ConstantExportParamConverter(export_names=(("architectures",),), export_value=["PixtralVisionModel"]), + ConstantImportParamConverter(fast_llm_names=(("type",),), fast_llm_value=VisionEncoderType.pixtral), + ConstantImportParamConverter(fast_llm_names=(("transformer", "causal"),), fast_llm_value=False), + RenameParamConverter( + fast_llm_names=( + ( + "transformer", + "num_layers", + ), + ), + export_names=(("num_hidden_layers",),), + ), + RenameParamConverter( + fast_llm_names=( + ( + "transformer", + "hidden_size", + ), + ), + export_names=(("hidden_size",),), + ), + RenameParamConverter( + fast_llm_names=( + ( + "transformer", + "num_attention_heads", + ), + ), + export_names=(("num_attention_heads",),), + ), + 
RenameParamConverter( + fast_llm_names=( + ( + "transformer", + "head_groups", + ), + ), + export_names=(("num_key_value_heads",),), + ), + RenameParamConverter( + fast_llm_names=( + ( + "transformer", + "ffn_hidden_size", + ), + ), + export_names=(("intermediate_size",),), + ), + MappedConfigParamConverter( + fast_llm_names=(("transformer", "activation_type"),), + export_names=(("hidden_act",),), + fast_llm_value=ActivationType.from_hf_name, + export_value=lambda activation_type: activation_type.hf_name, + ), + RenameParamConverter( + fast_llm_names=( + ( + "transformer", + "kv_channels", + ), + ), + export_names=(("head_dim",),), + ), + RenameParamConverter( + fast_llm_names=( + ( + "transformer", + "rotary", + "theta", + ), + ), + export_names=(("rope_theta",),), + ), + RenameParamConverter(fast_llm_names=(("patch_size",),), export_names=(("patch_size",),)), + ConstantImportParamConverter(fast_llm_names=(("transformer", "gated"),), fast_llm_value=True), + ConstantImportParamConverter(fast_llm_names=(("transformer", "add_linear_biases"),), fast_llm_value=False), + ] + + def _get_transformer_mlp_converters(self, fast_llm_prefix: str, hf_prefix: str) -> list[WeightConverter]: + return [ + SplitWeightConverter( + f"{fast_llm_prefix}.mlp.layer_1.weight", + (f"{hf_prefix}.feed_forward.gate_proj.weight", f"{hf_prefix}.feed_forward.up_proj.weight"), + ), + MLPLayer2Converter( + f"{fast_llm_prefix}.mlp.layer_2.weight", + f"{hf_prefix}.feed_forward.down_proj.weight", + self._model.config.base_model, + ), + ] + + def _create_vision_transformer_layer_converters( + self, transformer_layer_index: int, fast_llm_offset: int = 1, hf_base_prefix: str = "" + ) -> list[WeightConverter]: + # Vision transformer layer + transformer_config = self._model.config.base_model.vision_encoder.transformer + norm_bias: bool = transformer_config.normalization.type == NormalizationType.layer_norm + name_bias_cls = [ + # Self-attn + ( + f"layers.{fast_llm_offset + transformer_layer_index}.self_attn.query", + f"{hf_base_prefix}transformer.layers.{transformer_layer_index}.attention.q_proj", + transformer_config.add_attn_qkv_bias, + QueryWeightConverter, + ), + ( + f"layers.{fast_llm_offset + transformer_layer_index}.self_attn.key_value", + ( + f"{hf_base_prefix}transformer.layers.{transformer_layer_index}.attention.k_proj", + f"{hf_base_prefix}transformer.layers.{transformer_layer_index}.attention.v_proj", + ), + transformer_config.add_attn_qkv_bias, + KeyValueWeightConverter, + ), + ( + f"layers.{fast_llm_offset + transformer_layer_index}.self_attn.dense", + f"{hf_base_prefix}transformer.layers.{transformer_layer_index}.attention.o_proj", + transformer_config.add_attn_dense_bias, + WeightConverter, + ), + # Norm + ( + f"layers.{fast_llm_offset + transformer_layer_index}.norm_1", + f"{hf_base_prefix}transformer.layers.{transformer_layer_index}.attention_norm", + norm_bias, + WeightConverter, + ), + ( + f"layers.{fast_llm_offset + transformer_layer_index}.norm_2", + f"{hf_base_prefix}transformer.layers.{transformer_layer_index}.ffn_norm", + norm_bias, + WeightConverter, + ), + ] + converters = [] + for fast_llm_prefix, hf_prefix, use_bias, cls in name_bias_cls: + converters += self._get_weight_and_bias_converters( + fast_llm_prefix, + hf_prefix, + use_bias, + cls, + ) + # MLP + converters += self._get_transformer_mlp_converters( + f"layers.{fast_llm_offset + transformer_layer_index}", + f"{hf_base_prefix}transformer.layers.{transformer_layer_index}", + ) + return converters + + def _create_weight_converters(self, offset: int 
= 0, hf_base_prefix: str = "") -> list[WeightConverter]: + converters = [] + converters.append(WeightConverter(f"layers.{offset}.weight", f"{hf_base_prefix}patch_conv.weight")) + if self._model.config.base_model.vision_encoder.conv_bias: + converters.append(WeightConverter(f"layers.{offset}.bias", f"{hf_base_prefix}patch_conv.bias")) + converters.append(WeightConverter(f"layers.{offset}.norm.weight", f"{hf_base_prefix}ln_pre.weight")) + if self._model.config.base_model.vision_encoder.patch_norm.type == NormalizationType.layer_norm: + converters.append(WeightConverter(f"layers.{offset}.norm.bias", f"{hf_base_prefix}ln_pre.bias")) + + num_layers = self._model.config.base_model.vision_encoder.transformer.num_layers + for i in range(num_layers): + converters += self._create_vision_transformer_layer_converters(i, offset + 1, hf_base_prefix) + + converters.extend( + [ + WeightConverter( + f"layers.{offset + num_layers + 1}.layer_1.weight", "multi_modal_projector.linear_1.weight" + ), + WeightConverter( + f"layers.{offset + num_layers + 1}.layer_2.weight", "multi_modal_projector.linear_2.weight" + ), + ] + ) + if self._model.config.base_model.vision_encoder.adapter_bias: + converters.extend( + [ + WeightConverter( + f"layers.{offset + num_layers + 1}.layer_1.bias", "multi_modal_projector.linear_1.bias" + ), + WeightConverter( + f"layers.{offset + num_layers + 1}.layer_2.bias", "multi_modal_projector.linear_2.bias" + ), + ] + ) + + return converters + + @property + def num_layers(self) -> int: + # +2 for projector and conv layers + return self._model.config.base_model.vision_encoder.transformer.num_layers + 2 + + +class AyraAudioModelHuggingfaceCheckpointHandler(HuggingfaceStateDictCheckpointHandler): + format: typing.ClassVar[type[CheckpointFormat]] = AyraAudioModelGPTHuggingfaceCheckpointFormat + _model_class: typing.ClassVar[FastLLMModelConfig] = GPTModelConfig + + @classmethod + def _load_metadata(cls, config: CheckpointLoadMetadataConfig) -> CheckpointMetadata: + cfg_dict = cls._load_config(config.path) + kwargs = {} + if "text_config" in cfg_dict: + text_kwargs = cls._import_config(cfg_dict["text_config"]) + kwargs.update(text_kwargs) + if "audio_config" in cfg_dict: + audio_kwargs = cls._import_config(cfg_dict["audio_config"]) + audio_kwargs = {tuple(["audio_encoder"] + list(key)): value for key, value in audio_kwargs.items()} + kwargs.update(audio_kwargs) + kwargs.update( + cls._import_config( + {key: value for key, value in cfg_dict.items() if key not in ("text_config", "audio_config")} + ) + ) + imported_model_config = cls._model_class.get_base_model_config_class().from_dict({}, kwargs) + return CheckpointMetadata( + fast_llm_version=__version__, + model=cls._model_class, + format=config.format, + config=cls._model_class.from_dict({"base_model": imported_model_config.to_dict()}), + shards=["weights"], + ) + + @classmethod + def _create_config_converters(cls) -> list[ParamConverter]: + return super()._create_config_converters() + [ + ConstantExportParamConverter(export_names=(("architectures",),), export_value=["AyraAudioModel"]), + # projector + MappedConfigParamConverter( + fast_llm_names=(("audio_encoder", "adapter_activation_type"),), + export_names=(("activation_function",),), + fast_llm_value=ActivationType.from_hf_name, + export_value=lambda activation_type: activation_type.hf_name, + ), + RenameParamConverter( + fast_llm_names=(("audio_encoder", "adapter_size"),), + export_names=(("adapter_size",),), + ), + RenameParamConverter( + fast_llm_names=( + ( + "audio_encoder", + 
"aud_downsampling_k", + ), + ), + export_names=(("encoder_projector_ds_rate",),), + ), + ] + + @classmethod + def _import_config(cls, config: dict[str, typing.Any]) -> GPTBaseModelConfig: + handler_cls = AutoGPTHuggingfaceCheckpointHandler.get_handler_class(config["model_type"]) + kwargs = {} + for converter in handler_cls._create_config_converters(): + try: + values = () + for export_name in converter.export_names: + try: + value = get_nested_dict_value(config, export_name) + except KeyError: + value = MISSING + values = values + (value,) + values = converter.import_params(values) + for fast_llm_name, value in zip(converter.fast_llm_names, values, strict=True): + if value is MISSING: + raise ValueError(f"Missing converted value for fast-llm parameter {fast_llm_name}") + if fast_llm_name in kwargs: + raise ValueError(f"Duplicate converted value for fast-llm parameter {fast_llm_name}") + kwargs[fast_llm_name] = value + except Exception as e: + raise RuntimeError(f"Config conversion failed for converter {converter}", *e.args) + + return kwargs + + @classmethod + def _export_config(cls, config: BaseModelConfig) -> dict[str, typing.Any]: + # TODO Toby: implement for audio + exported_config = {} + audio_handler_class = AutoGPTHuggingfaceCheckpointHandler.get_handler_class(cls.format.audio_name) + text_handler_cls = AutoGPTHuggingfaceCheckpointHandler.get_handler_class(cls.format.text_name) + for converter in audio_handler_class._create_config_converters(): + try: + values = converter.export_params( + tuple( + cls._get_fast_llm_attribute(config, ("audio_encoder",) + fast_llm_name) + for fast_llm_name in converter.fast_llm_names + ) + ) + for export_name, value in zip(converter.export_names, values, strict=True): + if value is not MISSING: + set_nested_dict_value(exported_config, ("audio_config",) + export_name, value) + except Exception as e: + raise RuntimeError(f"Config conversion failed for converter {converter}", *e.args) + + for converter in text_handler_cls._create_config_converters(): + try: + values = converter.export_params( + tuple( + cls._get_fast_llm_attribute(config, fast_llm_name) + for fast_llm_name in converter.fast_llm_names + ) + ) + for export_name, value in zip(converter.export_names, values, strict=True): + if value is not MISSING: + set_nested_dict_value(exported_config, ("text_config",) + export_name, value) + except Exception as e: + raise RuntimeError(f"Config conversion failed for converter {converter}", *e.args) + + for converter in cls._create_config_converters(): + try: + values = converter.export_params( + tuple( + cls._get_fast_llm_attribute(config, fast_llm_name) + for fast_llm_name in converter.fast_llm_names + ) + ) + for export_name, value in zip(converter.export_names, values, strict=True): + if value is not MISSING: + set_nested_dict_value(exported_config, export_name, value) + except Exception as e: + raise RuntimeError(f"Config conversion failed for converter {converter}", *e.args) + + return exported_config + + def _create_weight_converters(self): + audio_handler_cls = AutoGPTHuggingfaceCheckpointHandler.get_handler_class(self.format.audio_name) + audio_handler = audio_handler_cls(self._model) # TODO Toby: are we calling this twice? 
+ converters = audio_handler._create_weight_converters(hf_base_prefix="encoder.", offset=0) + text_handler_cls = AutoGPTHuggingfaceCheckpointHandler.get_handler_class(self.format.text_name) + text_handler = text_handler_cls(self._model) + converters.extend( + text_handler._create_weight_converters(hf_base_prefix="llm.", offset=audio_handler.num_layers) + ) + return converters + + +class LlavaHuggingfaceCheckpointHandler(HuggingfaceStateDictCheckpointHandler): + format: typing.ClassVar[type[CheckpointFormat]] = LlavaGPTHuggingfaceCheckpointFormat + _model_class: typing.ClassVar[FastLLMModelConfig] = GPTModelConfig + + @classmethod + def _load_metadata(cls, config: CheckpointLoadMetadataConfig) -> CheckpointMetadata: + cfg_dict = cls._load_config(config.path) + kwargs = {} + if "text_config" in cfg_dict: + text_kwargs = cls._import_config(cfg_dict["text_config"]) + kwargs.update(text_kwargs) + if "vision_config" in cfg_dict: + vision_kwargs = cls._import_config(cfg_dict["vision_config"]) + vision_kwargs = {tuple(["vision_encoder"] + list(key)): value for key, value in vision_kwargs.items()} + kwargs.update(vision_kwargs) + kwargs.update( + cls._import_config( + {key: value for key, value in cfg_dict.items() if key not in ("text_config", "vision_config")} + ) + ) + imported_model_config = cls._model_class.get_base_model_config_class().from_dict({}, kwargs) + return CheckpointMetadata( + fast_llm_version=__version__, + model=cls._model_class, + format=config.format, + config=cls._model_class.from_dict({"base_model": imported_model_config.to_dict()}), + shards=["weights"], + ) + + @classmethod + def _create_config_converters(cls) -> list[ParamConverter]: + return super()._create_config_converters() + [ + ConstantExportParamConverter( + export_names=(("architectures",),), export_value=["LlavaForConditionalGeneration"] + ), + MappedConfigParamConverter( + fast_llm_names=(("vision_encoder", "adapter_activation_type"),), + export_names=(("projector_hidden_act",),), + fast_llm_value=ActivationType.from_hf_name, + export_value=lambda activation_type: activation_type.hf_name, + ), + RenameParamConverter( + fast_llm_names=(("vision_encoder", "adapter_size"),), + export_names=(("projector_intermediate_size",),), + ), + ] + + @classmethod + def _import_config(cls, config: dict[str, typing.Any]) -> GPTBaseModelConfig: + handler_cls = AutoGPTHuggingfaceCheckpointHandler.get_handler_class(config["model_type"]) + kwargs = {} + for converter in handler_cls._create_config_converters(): + try: + values = () + for export_name in converter.export_names: + try: + value = get_nested_dict_value(config, export_name) + except KeyError: + value = MISSING + values = values + (value,) + values = converter.import_params(values) + for fast_llm_name, value in zip(converter.fast_llm_names, values, strict=True): + if value is MISSING: + raise ValueError(f"Missing converted value for fast-llm parameter {fast_llm_name}") + if fast_llm_name in kwargs: + raise ValueError(f"Duplicate converted value for fast-llm parameter {fast_llm_name}") + kwargs[fast_llm_name] = value + except Exception as e: + raise RuntimeError(f"Config conversion failed for converter {converter}", *e.args) + + return kwargs + + @classmethod + def _export_config(cls, config: BaseModelConfig) -> dict[str, typing.Any]: + exported_config = {} + vision_handler_cls = AutoGPTHuggingfaceCheckpointHandler.get_handler_class(cls.format.vision_name) + text_handler_cls = AutoGPTHuggingfaceCheckpointHandler.get_handler_class(cls.format.text_name) + for converter in 
vision_handler_cls._create_config_converters(): + try: + values = converter.export_params( + tuple( + cls._get_fast_llm_attribute(config, ("vision_encoder",) + fast_llm_name) + for fast_llm_name in converter.fast_llm_names + ) + ) + for export_name, value in zip(converter.export_names, values, strict=True): + if value is not MISSING: + set_nested_dict_value(exported_config, ("vision_config",) + export_name, value) + except Exception as e: + raise RuntimeError(f"Config conversion failed for converter {converter}", *e.args) + + for converter in text_handler_cls._create_config_converters(): + try: + values = converter.export_params( + tuple( + cls._get_fast_llm_attribute(config, fast_llm_name) + for fast_llm_name in converter.fast_llm_names + ) + ) + for export_name, value in zip(converter.export_names, values, strict=True): + if value is not MISSING: + set_nested_dict_value(exported_config, ("text_config",) + export_name, value) + except Exception as e: + raise RuntimeError(f"Config conversion failed for converter {converter}", *e.args) + + for converter in cls._create_config_converters(): + try: + values = converter.export_params( + tuple( + cls._get_fast_llm_attribute(config, fast_llm_name) + for fast_llm_name in converter.fast_llm_names + ) + ) + for export_name, value in zip(converter.export_names, values, strict=True): + if value is not MISSING: + set_nested_dict_value(exported_config, export_name, value) + except Exception as e: + raise RuntimeError(f"Config conversion failed for converter {converter}", *e.args) + + return exported_config + + def _create_weight_converters(self): + vision_handler_cls = AutoGPTHuggingfaceCheckpointHandler.get_handler_class(self.format.vision_name) + vision_handler = vision_handler_cls(self._model) + converters = vision_handler._create_weight_converters(hf_base_prefix="vision_tower.", offset=0) + text_handler_cls = AutoGPTHuggingfaceCheckpointHandler.get_handler_class(self.format.text_name) + text_handler = text_handler_cls(self._model) + converters.extend( + text_handler._create_weight_converters(hf_base_prefix="language_model.", offset=vision_handler.num_layers) + ) + return converters + + class MixtralHuggingfaceCheckpointHandler(CommonLlamaHuggingfaceCheckpointHandler): format: typing.ClassVar[type[CheckpointFormat]] = MixtralGPTHuggingfaceCheckpointFormat @@ -680,4 +1398,8 @@ class AutoGPTHuggingfaceCheckpointHandler( MistralGPTHuggingfaceCheckpointFormat.name: MistralHuggingfaceCheckpointHandler, MixtralGPTHuggingfaceCheckpointFormat.name: MixtralHuggingfaceCheckpointHandler, MTPLlamaGPTHuggingfaceCheckpointFormat.name: MTPLlamaHuggingfaceCheckpointHandler, + LlavaGPTHuggingfaceCheckpointFormat.name: LlavaHuggingfaceCheckpointHandler, + WhisperGPTHuggingfaceCheckpointFormat.name: WhisperHuggingfaceCheckpointHandler, + PixtralGPTHuggingfaceCheckpointFormat.name: PixtralHuggingfaceCheckpointHandler, + AyraAudioModelGPTHuggingfaceCheckpointFormat.name: AyraAudioModelHuggingfaceCheckpointHandler, } diff --git a/fast_llm/models/gpt/model.py b/fast_llm/models/gpt/model.py index d177a41d..48f5760b 100644 --- a/fast_llm/models/gpt/model.py +++ b/fast_llm/models/gpt/model.py @@ -10,15 +10,25 @@ from fast_llm.engine.distributed.config import DistributedConfig, DistributedDimNames, PhaseType from fast_llm.engine.inference.runner import InferenceRunner from fast_llm.engine.multi_stage.fast_llm_model import FastLLMModel +from fast_llm.layers.audio_encoder.adapter import AudioAdapter +from fast_llm.layers.audio_encoder.config import AudioEncoderKwargs +from 
fast_llm.layers.audio_encoder.encoder import AudioConv +from fast_llm.layers.audio_encoder.preprocessing import AudioPreprocessor from fast_llm.layers.language_model.config import LanguageModelKwargs, LanguageModelLossNames from fast_llm.layers.language_model.embedding import WORD_EMBEDDINGS_WEIGHT, LanguageModelEmbedding from fast_llm.layers.language_model.head import OUTPUT_WEIGHTS, LanguageModelHead from fast_llm.layers.language_model.preprocessing import PositionEmbeddingPreprocessor +from fast_llm.layers.multi_modal.embedding import MultiModalEmbedding +from fast_llm.layers.transformer.audio_transformer import AudioTransformerLayer from fast_llm.layers.transformer.config import ( + AudioTransformerDimNames, + AudioTransformerKwargs, RoutingType, TransformerDimNames, TransformerKwargs, TransformerLossNames, + VisionTransformerDimNames, + VisionTransformerKwargs, ) from fast_llm.layers.transformer.preprocessing import ( BackupAttentionPreprocessor, @@ -26,6 +36,11 @@ RotaryEmbeddingPreprocessor, ) from fast_llm.layers.transformer.transformer import TransformerLayer +from fast_llm.layers.transformer.vision_transformer import VisionTransformerLayer +from fast_llm.layers.vision_encoder.adapter import VisionAdapter +from fast_llm.layers.vision_encoder.config import VisionEncoderDimNames, VisionEncoderKwargs +from fast_llm.layers.vision_encoder.patch_conv import PatchConv +from fast_llm.layers.vision_encoder.preprocessing import VisionPreprocessor from fast_llm.models.gpt.config import GPTBaseModelConfig, GPTBatchConfig, GPTModelConfig from fast_llm.models.gpt.megatron import get_init_megatron from fast_llm.tensor import ParameterMeta, TensorMeta @@ -70,6 +85,15 @@ def __init__( else: self._preprocessors.append(BackupAttentionPreprocessor(self._config.transformer, self._tensor_space)) + if self._config.vision_encoder.enabled: + self._preprocessors.append(VisionPreprocessor(self._config.vision_encoder, self._tensor_space)) + if self._config.vision_encoder.transformer.rotary.enabled: + self._preprocessors.append( + RotaryEmbeddingPreprocessor(self._config.vision_encoder.transformer.rotary, self._tensor_space) + ) + if self._config.audio_encoder.enabled: + self._preprocessors.append(AudioPreprocessor(self._config.audio_encoder, self._tensor_space)) + def get_output_layers(self) -> list[Layer]: layers = [] for i in range(self._config.prediction_heads): @@ -94,9 +118,47 @@ def get_output_layers(self) -> list[Layer]: ) return layers + def get_vision_layers(self) -> list[Layer]: + patch_conv = PatchConv(self._config.vision_encoder, self._tensor_space) + vit_layers = [ + VisionTransformerLayer(self._config.vision_encoder.transformer, self._tensor_space, layer_index=idx + 1) + for idx in range(self._config.vision_encoder.transformer.num_layers) + ] + return [ + patch_conv, + *vit_layers, + VisionAdapter(self._config.vision_encoder, self._tensor_space), + MultiModalEmbedding(self._config, self._tensor_space), + ] + + def get_audio_layers(self) -> list[Layer]: + audio_conv = AudioConv(self._config.audio_encoder, self._tensor_space) + audio_layers = [ + AudioTransformerLayer(self._config.audio_encoder.transformer, self._tensor_space, layer_index=idx + 1) + for idx in range(self._config.audio_encoder.transformer.num_layers) + ] + return [ + audio_conv, + *audio_layers, + AudioAdapter(self._config.audio_encoder, self._tensor_space), + MultiModalEmbedding(self._config, self._tensor_space), + ] + + def get_multimodal_layers(self) -> list[Layer]: + if self._config.vision_encoder.enabled: + return 
self.get_vision_layers() + elif self._config.audio_encoder.enabled: + return self.get_audio_layers() + else: + assert False + def get_layers(self) -> list[Layer]: return [ - LanguageModelEmbedding(self._config, self._tensor_space), + *( + [LanguageModelEmbedding(self._config, self._tensor_space)] + if not self._config.vision_encoder.enabled and not self._config.audio_encoder.enabled + else self.get_multimodal_layers() + ), *[ TransformerLayer( self._config.transformer, @@ -127,6 +189,48 @@ def preprocess_meta( sequence_length -= self._config.prediction_heads micro_sequence_length = sequence_length + if self._config.vision_encoder.enabled: + image_size = batch_meta.image_size + image_mean = [ + self._config.vision_encoder.image_normalization.mean_r, + self._config.vision_encoder.image_normalization.mean_g, + self._config.vision_encoder.image_normalization.mean_b, + ] + image_std = [ + self._config.vision_encoder.image_normalization.std_r, + self._config.vision_encoder.image_normalization.std_g, + self._config.vision_encoder.image_normalization.std_b, + ] + image_rescale_factor = self._config.vision_encoder.image_normalization.rescale_factor + vision_kwargs = { + VisionEncoderKwargs.patch_size: self._config.vision_encoder.patch_size, + VisionEncoderKwargs.image_size: image_size, + VisionEncoderKwargs.image_mean: image_mean, + VisionEncoderKwargs.image_std: image_std, + VisionEncoderKwargs.image_rescale_factor: image_rescale_factor, + VisionEncoderKwargs.rope_theta: self._config.vision_encoder.transformer.rotary.theta, + VisionEncoderKwargs.kv_channels: self._tensor_space.get_tensor_dim( + VisionTransformerDimNames.kv_channels + ).size, + VisionEncoderKwargs.out_channels: self._tensor_space.get_tensor_dim( + VisionEncoderDimNames.out_channels + ).size, + } + else: + vision_kwargs = {} + + if self._config.audio_encoder.enabled: + audio_kwargs = { + AudioEncoderKwargs.kv_channels: self._tensor_space.get_tensor_dim( + AudioTransformerDimNames.kv_channels + ).size, + AudioEncoderKwargs.out_channels: self._tensor_space.get_tensor_dim( + AudioEncoderKwargs.out_channels + ).size, + } + else: + audio_kwargs = {} + batch_data = self._tensor_space.distributed_config.get_distributed_dim(DistributedDimNames.batch_data) batch_dim = TensorDim(TransformerDimNames.batch, micro_batch_size * batch_data.size, batch_data) @@ -169,6 +273,34 @@ def preprocess_meta( if sequence_first else (batch_dim, hidden_sequence_q_dim, hidden_dim) ) + if self._config.vision_encoder.enabled: + vision_hidden_dim = self._tensor_space.get_tensor_dim(VisionTransformerDimNames.hidden) + vision_hidden_dims = ( + (hidden_sequence_q_dim, batch_dim, vision_hidden_dim) + if sequence_first + else (batch_dim, hidden_sequence_q_dim, vision_hidden_dim) + ) + vision_kwargs.update( + { + VisionTransformerKwargs.hidden_dims: vision_hidden_dims, + } + ) + + if self._config.audio_encoder.enabled: + audio_hidden_dim = self._tensor_space.get_tensor_dim(AudioTransformerDimNames.hidden) + audio_hidden_dims = ( + (hidden_sequence_q_dim, batch_dim, audio_hidden_dim) + if sequence_first + else (batch_dim, hidden_sequence_q_dim, audio_hidden_dim) + ) + audio_kwargs.update( + { + AudioTransformerKwargs.hidden_dims: audio_hidden_dims, + AudioTransformerKwargs.sequence_length: 1500, # TODO: Toby Parameterize + AudioTransformerKwargs.sequence_k_dim: 1500, + AudioTransformerKwargs.sequence_q_dim: 1500, + } + ) common_kwargs = { LanguageModelKwargs.phase: phase, @@ -176,7 +308,10 @@ def preprocess_meta( TransformerKwargs.hidden_dims: hidden_dims, 
TransformerKwargs.sequence_length: sequence_length, TransformerKwargs.sequence_q_dim: sequence_q_dim, + TransformerKwargs.micro_batch_size: micro_batch_size, } + common_kwargs.update(vision_kwargs) + common_kwargs.update(audio_kwargs) sequence_k_pasts = range( sequence_q_dim.size * self._tensor_space.distributed_config.sequence_data_rank, @@ -222,7 +357,11 @@ def preprocess_meta( reference_kwargs[name] = reference_kwargs_ kwargs["reference_models"] = reference_kwargs - preprocessed_meta.append((tokens, kwargs)) + if self._config.vision_encoder.enabled: + # patch_dimensions are (batch * sequence_length) x 3 x patch_size x patch_size + preprocessed_meta.append((kwargs[VisionEncoderKwargs.image_patches_meta], kwargs)) + else: + preprocessed_meta.append((tokens, kwargs)) return preprocessed_meta @@ -305,7 +444,7 @@ def preprocess( if batch.loss_masking_spans is not None: # avoid changing input tokens labels = labels.clone() - for i, spans in enumerate(batch.loss_masking_spans): + for idx, spans in enumerate(batch.loss_masking_spans): if not spans.numel(): continue valid_spans = spans[ @@ -318,28 +457,71 @@ def preprocess( loss_mask = torch.ones_like(labels, dtype=torch.bool) for start, end in valid_spans: if sequence_first: - loss_mask[start : end + 1, i] = False + loss_mask[start : end + 1, idx] = False else: - loss_mask[i, start : end + 1] = False + loss_mask[idx, start : end + 1] = False if self._config.distillation_model is not None: kwargs[LanguageModelKwargs.loss_mask] = loss_mask labels = torch.where(loss_mask, labels, -100) kwargs[LanguageModelKwargs.labels] = labels kwargs.update(reference_logits[i]) + if self._config.vision_encoder.enabled: + batch_images = ( + batch.images if batch.images is not None else [[]] * kwargs[TransformerKwargs.micro_batch_size] + ) + kwargs[VisionEncoderKwargs.images] = [ + [ + img.to(device=self._tensor_space.distributed.device, dtype=torch.uint8, non_blocking=True) + for img in images + ] + for images in batch_images + ] + kwargs[VisionEncoderKwargs.image_positions] = ( + batch.image_positions + if batch.image_positions is not None + else [[]] * kwargs[TransformerKwargs.micro_batch_size] + ) + kwargs[LanguageModelKwargs.tokens] = tokens + + if self._config.audio_encoder.enabled: + if batch.audio is not None: + kwargs[AudioEncoderKwargs.audio] = [ + [aud.to(device="cpu", dtype=torch.float32, non_blocking=True) for aud in audio] + for audio in batch.audio + ] + kwargs[AudioEncoderKwargs.audio_positions] = batch.audio_positions + kwargs[LanguageModelKwargs.tokens] = tokens + for preprocessor in self._preprocessors: preprocessor.preprocess(tokens, kwargs) - preprocessed.append((tokens, kwargs)) + image_patches = kwargs.get(VisionEncoderKwargs.image_patches, None) + audio_mel = kwargs.get(AudioEncoderKwargs.audio_mel, None) + if audio_mel is not None: + preprocessed.append((audio_mel, kwargs)) + elif image_patches is not None: + preprocessed.append((image_patches, kwargs)) + else: + preprocessed.append((tokens, kwargs)) return preprocessed @property def embedding(self) -> LanguageModelEmbedding: - return self.layers[0] + return self.layers[self.embedding_layer_index] @property def transformer_layers(self) -> list[TransformerLayer]: - return self.layers[1:-1] + return self.layers[self.embedding_layer_index + 1 : -1] + + @property + def embedding_layer_index(self) -> int: + if self._config.vision_encoder.enabled: + return self._config.vision_encoder.transformer.num_layers + 2 + elif self._config.audio_encoder.enabled: + return 
self._config.audio_encoder.transformer.num_layers + 2 + else: + return 0 @property def model_head(self) -> LanguageModelHead: @@ -354,7 +536,7 @@ def get_tied_weights(self) -> dict[str, tuple[ParameterMeta, tuple[int, ...]]]: return { WORD_EMBEDDINGS_WEIGHT: ( self.embedding.word_embeddings_weight, - (0, *self.model_head_indices), + (self.embedding_layer_index, *self.model_head_indices), ) } elif self._config.prediction_heads > 1: diff --git a/fast_llm/models/gpt/trainer.py b/fast_llm/models/gpt/trainer.py index 3bdb05c3..b4a3036f 100644 --- a/fast_llm/models/gpt/trainer.py +++ b/fast_llm/models/gpt/trainer.py @@ -32,6 +32,25 @@ def _get_sampling_parameters( "extra_tokens": self._config.model.base_model.prediction_heads, } ) + if self._config.model.base_model.vision_encoder.enabled: + parameters.update( + { + "patch_size": self._config.model.base_model.vision_encoder.patch_size, + "image_size": self._config.batch.image_size, + "image_break_token": self._config.model.base_model.vision_encoder.image_break_token, + "image_end_token": self._config.model.base_model.vision_encoder.image_end_token, + } + ) + if self._config.model.base_model.audio_encoder.enabled: + parameters.update( + { + "aud_downsampling_k": self._config.model.base_model.audio_encoder.aud_downsampling_k, + "aud_padding_duration": self._config.batch.aud_padding_duration, + "aud_sampling_rate": self._config.model.base_model.audio_encoder.aud_sampling_rate, + "audio_start_token": self._config.model.base_model.audio_encoder.audio_start_token, + "audio_end_token": self._config.model.base_model.audio_encoder.audio_end_token, + } + ) return parameters if _return_dict else GPTSamplingParameters(**parameters) def get_tflops(self, phase: PhaseType, elapsed_time_per_iteration) -> tuple[int, int]: diff --git a/setup.cfg b/setup.cfg index 9b944b27..25f8af8b 100644 --- a/setup.cfg +++ b/setup.cfg @@ -32,7 +32,7 @@ CORE = # Required for some optional features and tools. OPTIONAL = # Huggingface tools - transformers>=4.44.2 + transformers>=4.48.3 hf-transfer>=0.1.8 datasets>=3.1.0 huggingface-hub>=0.28.1 @@ -44,6 +44,10 @@ OPTIONAL = # Miscellanous requests>=2.32.3 tqdm>=4.66.3 + # Vision Tools + webp>=0.4.0 + pillow-simd>=9.5.0 + torchvision>=0.20.0 DEV = # Pre-commit git hook
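
The layer indexing used by the converters and by embedding_layer_index follows directly from how get_vision_layers builds the stack: the conv layer at index 0, then the encoder transformer blocks, then the adapter, then the multi-modal embedding. A minimal sketch of that arithmetic, assuming a hypothetical vision transformer with num_layers = 2 (the same reasoning applies to the audio path):

# Sketch only (hypothetical num_layers = 2); mirrors the ordering built by get_vision_layers().
vision_stack = [
    "PatchConv",                 # index 0
    "VisionTransformerLayer 1",  # index 1
    "VisionTransformerLayer 2",  # index 2
    "VisionAdapter",             # index 3  -> the "+2" counts this and the conv layer
    "MultiModalEmbedding",       # index 4  == num_layers + 2 == embedding_layer_index
]
assert vision_stack.index("MultiModalEmbedding") == 2 + 2

The same value (transformer.num_layers + 2) is what the Pixtral and Whisper handlers report as num_layers, which is why the composed Llava and AyraAudioModel handlers can pass it straight through as the offset for the text model's converters.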
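
For the composed Llava format, the Pixtral converters run with hf_base_prefix="vision_tower." and the text converters with hf_base_prefix="language_model." at offset vision_handler.num_layers. A sketch of a few resulting name pairs, again assuming a hypothetical 2-layer vision transformer and listing only pairs that appear verbatim in the converters above:

# Sketch only (hypothetical vision transformer with num_layers = 2, so the adapter sits at
# Fast-LLM layer 3 and the text stack starts at offset 4). (Fast-LLM name, HF name) pairs:
example_pairs = [
    ("layers.0.weight", "vision_tower.patch_conv.weight"),
    ("layers.0.norm.weight", "vision_tower.ln_pre.weight"),
    ("layers.3.layer_1.weight", "multi_modal_projector.linear_1.weight"),
    ("layers.3.layer_2.weight", "multi_modal_projector.linear_2.weight"),
]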