From 7709e65d05652906936ada2cdb53c31ab4e68663 Mon Sep 17 00:00:00 2001
From: root
Date: Tue, 8 Apr 2025 06:51:19 +0000
Subject: [PATCH 01/82] WIP: multimodal support

---
 fast_llm/data/config.py                       | 39 ++++++++++++++++++
 fast_llm/data/image_processor.py              | 40 +++++++++++++++++++
 .../data/preparator/gpt_memmap/prepare.py     |  3 ++
 fast_llm/data/processor.py                    | 11 +++++
 setup.cfg                                     |  2 +
 5 files changed, 95 insertions(+)
 create mode 100644 fast_llm/data/image_processor.py
 create mode 100644 fast_llm/data/processor.py

diff --git a/fast_llm/data/config.py b/fast_llm/data/config.py
index 1586d370..351dcaae 100644
--- a/fast_llm/data/config.py
+++ b/fast_llm/data/config.py
@@ -34,3 +34,42 @@ class TokenizerConfig(Config):
         desc="Path to the tokenizer file.",
         hint=FieldHint.core,
     )
+
+
+@config_class()
+class ImageProcessorConfig(Config):
+    """
+    Configuration for the image processor
+    """
+
+    # Defaults taken from [pixtral](https://github.com/huggingface/transformers/blob/794fde7b1c3d041519fc28ea3e1461b0cfcad4e7/src/transformers/models/pixtral/image_processing_pixtral.py#L201)
+    patch_size: list[int] = Field(
+        default_factory=lambda: [16, 16],
+        desc="Size of each patch extracted from the image. Each patch corresponds to a token for the transformer",
+        hint=FieldHint.optional,
+    )
+    max_height: int = Field(
+        default=1024,
+        desc="Maximum height of the image. Image will be resized if larger",
+        hint=FieldHint.optional,
+    )
+    max_width: int = Field(
+        default=1024,
+        desc="Maximum width of the image. Image will be resized if larger",
+        hint=FieldHint.optional,
+    )
+    mean: list[float] = Field(
+        default_factory=lambda: [0.48145466, 0.4578275, 0.40821073],
+        desc="Mean RGB values for pixel normalization",
+        hint=FieldHint.optional,
+    )
+    std: list[float] = Field(
+        default_factory=lambda: [0.26862954, 0.26130258, 0.27577711],
+        desc="Standard deviation RGB values for pixel normalization",
+        hint=FieldHint.optional,
+    )
+    rescale_factor: float = Field(
+        default=255.0,
+        desc="Divisor used to rescale pixel values during normalization",
+        hint=FieldHint.optional,
+    )
diff --git a/fast_llm/data/image_processor.py b/fast_llm/data/image_processor.py
new file mode 100644
index 00000000..cf4c6e93
--- /dev/null
+++ b/fast_llm/data/image_processor.py
@@ -0,0 +1,40 @@
+import math
+
+import torch
+from torchvision.transforms.v2 import functional as F
+
+from fast_llm.data.config import ImageProcessorConfig
+
+
+class ImageProcessor:
+    def __init__(self, config: ImageProcessorConfig):
+        self.patch_size = config.patch_size
+        self.mean = config.mean / config.rescale_factor
+        self.std = config.std / config.rescale_factor
+        self.max_height = config.max_height
+        self.max_width = config.max_width
+        assert (
+            self.max_height % self.patch_size[0] == 0
+        ), f"max_height must be divisible by patch_size[0]. Found {self.max_height} and {self.patch_size[0]}"
+        assert (
+            self.max_width % self.patch_size[1] == 0
+        ), "max_width must be divisible by patch_size[1]. 
Found {max_width} and {self.patch_size[1]}" + + def resize(self, image: torch.Tensor) -> torch.Tensor: + # Resize the image to the specified size + height = image.shape[0] + width = image.shape[1] + ratio = max(height / self.max_height, width / self.max_width) + if ratio > 1: + height = math.ceil(height / ratio) + width = math.ceil(width / ratio) + else: + height = self.patch_size[0] * math.ceil(height / self.self.patch_size[0]) + width = self.patch_size[1] * math.ceil(width / self.patch_size[1]) + + # TODO: options for interpolation mode + return F.resize(image, size=(height, width), interpolation=F.InterpolationMode.BICUBIC) + + def normalize(self, image: torch.Tensor) -> torch.Tensor: + # Normalize the image using the mean and std + return F.normalize(image, mean=self.mean, std=self.std) diff --git a/fast_llm/data/preparator/gpt_memmap/prepare.py b/fast_llm/data/preparator/gpt_memmap/prepare.py index b3dae1df..5cfad9ec 100644 --- a/fast_llm/data/preparator/gpt_memmap/prepare.py +++ b/fast_llm/data/preparator/gpt_memmap/prepare.py @@ -38,6 +38,9 @@ class GPTMemmapDatasetPreparator[ConfigType: GPTMemmapDatasetPreparatorConfig](D _tokenizer: Tokenizer _data_type: DataType + def _process_batch(self, batch: dict[str, list[typing.Any]]) -> dict[str, list[typing.Any]]: + pass + def _tokenize_batch(self, batch: dict[str, list[typing.Any]]) -> dict[str, list[typing.Any]]: input_ids = [ np.array(self._tokenizer.tokenize(text), dtype=self._data_type.numpy) diff --git a/fast_llm/data/processor.py b/fast_llm/data/processor.py new file mode 100644 index 00000000..43b1cda8 --- /dev/null +++ b/fast_llm/data/processor.py @@ -0,0 +1,11 @@ +from fast_llm.data.tokenizer import Tokenizer + + +class MultiModalProcessor: + """ + Combines multiple modalities (text and image) and converts to tokens/patches for text and images. 
+ """ + + def __init__(self, tokenizer: Tokenizer, image_processor=None): + self._tokenizer = tokenizer + self._image_processor = image_processor diff --git a/setup.cfg b/setup.cfg index c21f02a7..3c1dad9d 100644 --- a/setup.cfg +++ b/setup.cfg @@ -42,6 +42,8 @@ OPTIONAL = # Miscellanous requests>=2.32.3 tqdm>=4.66.3 + # Vision Tools + torchvision>=0.20.0 DEV = # Pre-commit git hook From 0db2bd21218fa133d4a1e41223552ece8f3044a7 Mon Sep 17 00:00:00 2001 From: root Date: Wed, 9 Apr 2025 06:19:10 +0000 Subject: [PATCH 02/82] rough idea for memmap --- fast_llm/data/config.py | 18 ++++ fast_llm/data/dataset/gpt/memmap.py | 59 ++++++++++-- fast_llm/data/dataset/gpt/sampled.py | 2 + fast_llm/data/image_processor.py | 3 + fast_llm/data/preparator/gpt_memmap/config.py | 8 +- .../data/preparator/gpt_memmap/prepare.py | 92 ++++++++++++++----- fast_llm/data/processor.py | 11 --- 7 files changed, 145 insertions(+), 48 deletions(-) delete mode 100644 fast_llm/data/processor.py diff --git a/fast_llm/data/config.py b/fast_llm/data/config.py index 351dcaae..8c2c3c28 100644 --- a/fast_llm/data/config.py +++ b/fast_llm/data/config.py @@ -73,3 +73,21 @@ class ImageProcessorConfig(Config): desc="Diminisher factor for pixel normalization", hint=FieldHint.optional, ) + + +@config_class() +class MultiModalProcessorConfig(Config): + """ + Wrapper config that stores the `ImageProcessorConfig` and `TokenizerConfig` + """ + + tokenizer: TokenizerConfig = Field( + default_factory=TokenizerConfig, + desc="Configuration for the tokenizer.", + hint=FieldHint.core, + ) + image_processor: ImageProcessorConfig = Field( + default_factory=ImageProcessorConfig, + desc="Configuration for the image processor.", + hint=FieldHint.core, + ) diff --git a/fast_llm/data/dataset/gpt/memmap.py b/fast_llm/data/dataset/gpt/memmap.py index ef060b00..c8b2592f 100644 --- a/fast_llm/data/dataset/gpt/memmap.py +++ b/fast_llm/data/dataset/gpt/memmap.py @@ -38,10 +38,14 @@ def _init(self, name: str, prefix: pathlib.Path | str, num_documents: int | None with self._prefix.with_suffix(".idx").open("rb") as stream: Assert.eq(stream.read(9), MEMMAP_INDEX_HEADER, msg=f"File: {stream.name}") self._version = struct.unpack("= 2: self._has_spans = struct.unpack("= 3: + self._has_images = struct.unpack("= 2: self._spans = [] self._num_spans = np.frombuffer( self._index_bin_buffer, @@ -82,6 +86,8 @@ def _init(self, name: str, prefix: pathlib.Path | str, num_documents: int | None offset=span_offset + self._num_spans_cumsum[idx] * 2 * np.dtype(np.int32).itemsize, ).reshape(-1, 2) ) + if self._has_images and self._version >= 3: + self._image_sizes = np.frombuffer() self._bin_buffer_mmap = np.memmap(self._prefix.with_suffix(".bin"), mode="r", order="C") self._bin_buffer = memoryview(self._bin_buffer_mmap) @@ -151,7 +157,11 @@ def write_dataset(cls, prefix: pathlib.Path | str, documents: typing.Iterable[GP # Initialize metadata dtype = None num_documents = 0 - lengths = [] + doc_lengths = [] + n_images = [] + im_lengths = [] + im_positions = [] + total_images = 0 pointers = [] offset = 0 # number of spans for each document @@ -160,8 +170,10 @@ def write_dataset(cls, prefix: pathlib.Path | str, documents: typing.Iterable[GP prefix = pathlib.Path(prefix) prefix.parent.mkdir(parents=True, exist_ok=True) + pathlib.Path(prefix + "_images") # Write the binary data file (.bin) lazily + # TODO Soham: append image tokens along with text tokens with prefix.with_suffix(".bin").open("wb") as bin_stream: for document in documents: # Infer dtype from the first document @@ 
-174,10 +186,18 @@ def write_dataset(cls, prefix: pathlib.Path | str, documents: typing.Iterable[GP # Write document to binary file bin_stream.write(document.token_ids.tobytes(order="C")) + if document.images: + n_images.append(len(document.images)) + total_images += len(document.images) + for image, image_position in zip(document.images, document.image_positions): + im_lengths.append(image.size) + im_positions.append(document.image_positions) + bin_stream.write(image.tobytes(order="C")) # Update metadata doc_length = len(document.token_ids) - lengths.append(doc_length) + doc_lengths.append(doc_length) + im_lengths.append() pointers.append(offset) if document.loss_masking_spans is not None: num_spans.append(len(document.loss_masking_spans)) @@ -186,7 +206,7 @@ def write_dataset(cls, prefix: pathlib.Path | str, documents: typing.Iterable[GP num_documents += 1 # Finalize metadata arrays - lengths = np.array(lengths, dtype=np.int32) + doc_lengths = np.array(doc_lengths, dtype=np.int32) pointers = np.array(pointers, dtype=np.int64) num_spans = np.array(num_spans, dtype=np.int32) if len(spans) > 0: @@ -194,27 +214,46 @@ def write_dataset(cls, prefix: pathlib.Path | str, documents: typing.Iterable[GP else: spans = np.array(spans, dtype=np.int32) + # TODO Soham: else condition might not be necessary + if total_images: + n_images = np.array(n_images, dtype=np.int32) + im_lengths = np.array(im_lengths, dtype=np.int32) + im_positions = np.array(im_positions, dtype=np.int32) + else: + n_images = np.array([]) + im_lengths = np.array([]) + im_positions = np.array([]) + # Write the index file (.idx) with prefix.with_suffix(".idx").open("wb") as idx_stream: idx_stream.write(MEMMAP_INDEX_HEADER) # Indicates the version - # Version 2 optionally adds loss-masking spans - idx_stream.write(struct.pack(" 0 else 0)) + # Flag to indicate whether images are present + idx_stream.write(struct.pack(" 0 else 0)) # Data type idx_stream.write(struct.pack(" torch.Tensor: def normalize(self, image: torch.Tensor) -> torch.Tensor: # Normalize the image using the mean and std return F.normalize(image, mean=self.mean, std=self.std) + + def get_num_patches(self, image: torch.Tensor) -> torch.Tensor: + return (image.size(0) // self.patch_size[0]) * (image.size(1) // self.patch_size[1]) diff --git a/fast_llm/data/preparator/gpt_memmap/config.py b/fast_llm/data/preparator/gpt_memmap/config.py index 2c4311c3..60262743 100644 --- a/fast_llm/data/preparator/gpt_memmap/config.py +++ b/fast_llm/data/preparator/gpt_memmap/config.py @@ -3,7 +3,7 @@ import typing from fast_llm.config import Config, Field, FieldHint, check_field, config_class -from fast_llm.data.config import TokenizerConfig +from fast_llm.data.config import MultiModalProcessorConfig from fast_llm.data.preparator.config import DatasetPreparatorConfig from fast_llm.engine.config_utils.data_type import DataType from fast_llm.utils import Assert @@ -153,9 +153,9 @@ class GPTMemmapDatasetPreparatorConfig(DatasetPreparatorConfig): desc="Configuration for the dataset.", hint=FieldHint.feature, ) - tokenizer: TokenizerConfig = Field( - default_factory=TokenizerConfig, - desc="Configuration for the tokenizer.", + data_processor: MultiModalProcessorConfig = Field( + default_factory=MultiModalProcessorConfig, + desc="Configuration for data processing. 
Describes the tokenizer and image processor", hint=FieldHint.feature, ) splits: dict[str, float] | None = Field( diff --git a/fast_llm/data/preparator/gpt_memmap/prepare.py b/fast_llm/data/preparator/gpt_memmap/prepare.py index 5cfad9ec..d4180986 100644 --- a/fast_llm/data/preparator/gpt_memmap/prepare.py +++ b/fast_llm/data/preparator/gpt_memmap/prepare.py @@ -23,9 +23,9 @@ ) from fast_llm.data.dataset.gpt.memmap import GPTMemmapDataset from fast_llm.data.dataset.gpt.sampled import GPTSample +from fast_llm.data.multi_modal_processor import MultiModalProcessor from fast_llm.data.preparator.config import DatasetPreparator from fast_llm.data.preparator.gpt_memmap.config import GPTMemmapDatasetPreparatorConfig -from fast_llm.data.tokenizer import Tokenizer from fast_llm.engine.config_utils.data_type import DataType, get_unsigned_integer_type from fast_llm.utils import Assert, normalize_probabilities, padded_cumsum @@ -35,45 +35,79 @@ class GPTMemmapDatasetPreparator[ConfigType: GPTMemmapDatasetPreparatorConfig](DatasetPreparator[ConfigType]): config_class: typing.ClassVar[type[GPTMemmapDatasetPreparatorConfig]] = GPTMemmapDatasetPreparatorConfig - _tokenizer: Tokenizer + # _tokenizer: Tokenizer + _data_processor: MultiModalProcessor _data_type: DataType def _process_batch(self, batch: dict[str, list[typing.Any]]) -> dict[str, list[typing.Any]]: pass def _tokenize_batch(self, batch: dict[str, list[typing.Any]]) -> dict[str, list[typing.Any]]: - input_ids = [ - np.array(self._tokenizer.tokenize(text), dtype=self._data_type.numpy) - for text in batch[self._config.dataset.field] + # input_ids = [ + # np.array(self._tokenizer.tokenize(text), dtype=self._data_type.numpy) + # for text in batch[self._config.dataset.field] + # ] + input_ids, images, image_token_positions = map( + list, + zip( + *[ + ( + np.array(input_ids, dtype=self._data_type.numpy), + np.array(images, dtype=np.uint8), + np.array(image_token_positions, dtype=np.int32), + ) + for input_ids, images, image_token_positions in [ + self._data_processor.tokenize(text, ims, im_char_positions) + for text, ims, im_char_positions in zip( + batch[self._config.dataset.field], + batch[self._config.dataset.images], + batch[self._config.dataset.image_positions], + ) + ] + ] + ), + ) + num_tokens = [ + len(x) + self._data_processor._image_processor.get_num_patches(im) for x, im in zip(input_ids, images) ] - num_tokens = [len(x) for x in input_ids] return { "input_ids": input_ids, + "images": images, + "image_positions": image_token_positions, "num_tokens": num_tokens, } def _tokenize_batch_with_spans(self, batch: dict[str, list[typing.Any]]) -> dict[str, list[typing.Any]]: - input_ids, token_spans = map( + input_ids, token_spans, images, image_token_positions = map( list, zip( *[ ( np.array(input_ids, dtype=self._data_type.numpy), np.array(token_spans, dtype=np.int32).reshape(-1, 2), + np.array(images, dtype=np.uint8), + np.array(image_token_positions, dtype=np.int32), ) - for input_ids, token_spans in [ - self._tokenizer.tokenize_with_spans(text, char_spans) + for input_ids, token_spans, images, image_token_positions in [ + self._data_processor.tokenize_with_spans(text, char_spans) for text, char_spans in zip( - batch[self._config.dataset.field], batch[self._config.dataset.loss_masking_spans] + batch[self._config.dataset.field], + batch[self._config.dataset.loss_masking_spans], + batch[self._config.dataset.images], + batch[self._config.dataset.image_positions], ) ] ] ), ) - num_tokens = [len(x) for x in input_ids] + num_tokens = [ + len(x) + 
self._data_processor._image_processor.get_num_patches(im) for x, im in zip(input_ids, images) + ] return { "input_ids": input_ids, "token_spans": token_spans, + "images": images, + "image_positions": image_token_positions, "num_tokens": num_tokens, } @@ -83,15 +117,27 @@ def _save_shard(self, args: tuple[int, datasets.Dataset]) -> GPTMemmapDatasetCon shard_output_path = self._config.output_path / prefix def _document_generator(): - if "token_spans" in shard_dataset.column_names and self._config.dataset.loss_masking_spans is not None: - for item in tqdm.tqdm(shard_dataset, desc=f"Saving shard {shard_idx}", unit="docs"): - yield GPTSample( - np.array(item["input_ids"], dtype=self._data_type.numpy), - np.array(item["token_spans"], dtype=np.int32).reshape(-1, 2), - ) - else: - for item in tqdm.tqdm(shard_dataset, desc=f"Saving shard {shard_idx}", unit="docs"): - yield GPTSample(np.array(item["input_ids"], dtype=self._data_type.numpy)) + # TODO Soham: simplify this + for item in tqdm.tqdm(shard_dataset, desc=f"Saving shard {shard_idx}", unit="docs"): + yield GPTSample( + np.array(item["input_ids"], dtype=self._data_type.numpy), + ( + np.array(item["token_spans"], dtype=np.int32).reshape(-1, 2) + if self._config.dataset.loss_masking_spans + else None + ), + images if self._config.dataset.images else None, + image_positions if self._config.dataset.image_positions else None, + ) + # if "token_spans" in shard_dataset.column_names and self._config.dataset.loss_masking_spans is not None: + # for item in tqdm.tqdm(shard_dataset, desc=f"Saving shard {shard_idx}", unit="docs"): + # yield GPTSample( + # np.array(item["input_ids"], dtype=self._data_type.numpy), + # np.array(item["token_spans"], dtype=np.int32).reshape(-1, 2), + # ) + # else: + # for item in tqdm.tqdm(shard_dataset, desc=f"Saving shard {shard_idx}", unit="docs"): + # yield GPTSample(np.array(item["input_ids"], dtype=self._data_type.numpy)) GPTMemmapDataset.write_dataset(prefix=shard_output_path, documents=_document_generator()) @@ -169,12 +215,12 @@ def run(self) -> None: if self._config.dataset.disable_disk_space_check: datasets.builder.has_sufficient_disk_space = lambda needed_bytes, directory=".": True - # Load tokenizer - self._tokenizer = Tokenizer(config=self._config.tokenizer) + # Load Processor + self._processor = MultiModalProcessor(config=self._config.data_processor) # Decide the datatype based on the tokenizer vocabulary size self._data_type = ( - get_unsigned_integer_type(self._tokenizer.vocab_size) + get_unsigned_integer_type(self.processor._tokenizer.vocab_size) if self._config.dataset.data_type is None else self._config.dataset.data_type ) diff --git a/fast_llm/data/processor.py b/fast_llm/data/processor.py deleted file mode 100644 index 43b1cda8..00000000 --- a/fast_llm/data/processor.py +++ /dev/null @@ -1,11 +0,0 @@ -from fast_llm.data.tokenizer import Tokenizer - - -class MultiModalProcessor: - """ - Combines multiple modalities (text and image) and converts to tokens/patches for text and images. 
- """ - - def __init__(self, tokenizer: Tokenizer, image_processor=None): - self._tokenizer = tokenizer - self._image_processor = image_processor From 0d89f68d7c4d5a40f5fa7e2651ac61b75da31aa5 Mon Sep 17 00:00:00 2001 From: root Date: Tue, 15 Apr 2025 06:10:33 +0000 Subject: [PATCH 03/82] faster image size reading --- fast_llm/data/dataset/gpt/memmap.py | 54 ++++++++++++------ fast_llm/data/image_processor.py | 17 +++--- fast_llm/data/preparator/gpt_memmap/config.py | 18 +++++- .../data/preparator/gpt_memmap/prepare.py | 55 ++++++++++++------- setup.cfg | 3 + 5 files changed, 101 insertions(+), 46 deletions(-) diff --git a/fast_llm/data/dataset/gpt/memmap.py b/fast_llm/data/dataset/gpt/memmap.py index c8b2592f..06924054 100644 --- a/fast_llm/data/dataset/gpt/memmap.py +++ b/fast_llm/data/dataset/gpt/memmap.py @@ -34,12 +34,12 @@ def _init(self, name: str, prefix: pathlib.Path | str, num_documents: int | None self._name = name self._prefix = pathlib.Path(prefix) self._has_spans = 0 + self._has_images = 0 with self._prefix.with_suffix(".idx").open("rb") as stream: Assert.eq(stream.read(9), MEMMAP_INDEX_HEADER, msg=f"File: {stream.name}") self._version = struct.unpack("= 2: self._has_spans = struct.unpack("= 2: self._spans = [] @@ -73,9 +74,8 @@ def _init(self, name: str, prefix: pathlib.Path | str, num_documents: int | None self._index_bin_buffer, dtype=np.int32, count=self._num_documents, - offset=offset + self._document_sizes.nbytes + self._pointers.nbytes, + offset=offset, ) - span_offset = offset + self._document_sizes.nbytes + self._pointers.nbytes + self._num_spans.nbytes self._num_spans_cumsum = np.r_[0, np.cumsum(self._num_spans[:-1], dtype=np.int64)] for idx in range(self._num_documents): self._spans.append( @@ -83,18 +83,40 @@ def _init(self, name: str, prefix: pathlib.Path | str, num_documents: int | None self._index_bin_buffer, dtype=np.int32, count=self._num_spans[idx] * 2, - offset=span_offset + self._num_spans_cumsum[idx] * 2 * np.dtype(np.int32).itemsize, + offset=offset + + self._num_spans.nbytes + + self._num_spans_cumsum[idx] * 2 * np.dtype(np.int32).itemsize, ).reshape(-1, 2) ) + offset += ( + self._num_spans.nbytes + + self._num_spans.sum() * 2 * np.dtype(np.int32).itemsize + + sum([x.nbytes for x in self._spans]) + ) if self._has_images and self._version >= 3: - self._image_sizes = np.frombuffer() + self._n_images = np.frombuffer( + self._index_bin_buffer, dtype=np.int32, count=self._num_documents, offset=offset + ) + self._im_lengths = np.frombuffer( + self._index_bin_buffer, + dtype=np.int32, + count=self._n_images.sum() * 3, + offset=offset + self._n_images.nbytes, + ) + self._im_positions = np.frombuffer( + self._index_bin_buffer, + dtype=np.int32, + count=self._n_images.sum(), + offset=offset + self._n_images.nbytes + self._im_lengths.nbytes, + ) self._bin_buffer_mmap = np.memmap(self._prefix.with_suffix(".bin"), mode="r", order="C") self._bin_buffer = memoryview(self._bin_buffer_mmap) + # TODO Soham: fix num_tokens to include images self._num_tokens = div(self._bin_buffer_mmap.size, np.dtype(self._dtype).itemsize) - if num_tokens is not None: - assert self._num_tokens == num_tokens + # if num_tokens is not None: + # assert self._num_tokens == num_tokens def __getstate__(self) -> tuple[str, pathlib.Path, int | None, int | None]: return (self._name, self._prefix, self._num_documents, self._num_tokens) @@ -110,6 +132,7 @@ def __del__(self): self._index_bin_buffer_mmap._mmap.close() # noqa del self._index_bin_buffer_mmap + # TODO Soham: get images def get( self, idx: 
int, offset: int = 0, length: int | None = None, use_loss_masking_spans: bool = False ) -> GPTSample: @@ -170,10 +193,8 @@ def write_dataset(cls, prefix: pathlib.Path | str, documents: typing.Iterable[GP prefix = pathlib.Path(prefix) prefix.parent.mkdir(parents=True, exist_ok=True) - pathlib.Path(prefix + "_images") # Write the binary data file (.bin) lazily - # TODO Soham: append image tokens along with text tokens with prefix.with_suffix(".bin").open("wb") as bin_stream: for document in documents: # Infer dtype from the first document @@ -186,23 +207,25 @@ def write_dataset(cls, prefix: pathlib.Path | str, documents: typing.Iterable[GP # Write document to binary file bin_stream.write(document.token_ids.tobytes(order="C")) + total_im_size = 0 if document.images: n_images.append(len(document.images)) total_images += len(document.images) for image, image_position in zip(document.images, document.image_positions): - im_lengths.append(image.size) + # assume 3 channels (RGB) for all images + im_lengths.append(np.array(image.shape[1:])) im_positions.append(document.image_positions) bin_stream.write(image.tobytes(order="C")) + total_im_size += image.size # Update metadata doc_length = len(document.token_ids) doc_lengths.append(doc_length) - im_lengths.append() pointers.append(offset) if document.loss_masking_spans is not None: num_spans.append(len(document.loss_masking_spans)) spans.append(document.loss_masking_spans) - offset += doc_length * np.dtype(dtype).itemsize + offset += doc_length * np.dtype(dtype).itemsize + total_im_size * np.dtype(np.uint8).itemsize num_documents += 1 # Finalize metadata arrays @@ -214,15 +237,12 @@ def write_dataset(cls, prefix: pathlib.Path | str, documents: typing.Iterable[GP else: spans = np.array(spans, dtype=np.int32) - # TODO Soham: else condition might not be necessary if total_images: n_images = np.array(n_images, dtype=np.int32) - im_lengths = np.array(im_lengths, dtype=np.int32) - im_positions = np.array(im_positions, dtype=np.int32) else: n_images = np.array([]) - im_lengths = np.array([]) - im_positions = np.array([]) + im_lengths = np.stack(im_lengths, dtype=np.int32) + im_positions = np.array(im_positions, dtype=np.int32) # Write the index file (.idx) with prefix.with_suffix(".idx").open("wb") as idx_stream: diff --git a/fast_llm/data/image_processor.py b/fast_llm/data/image_processor.py index 473db11a..c5cbe909 100644 --- a/fast_llm/data/image_processor.py +++ b/fast_llm/data/image_processor.py @@ -9,8 +9,8 @@ class ImageProcessor: def __init__(self, config: ImageProcessorConfig): self.patch_size = config.patch_size - self.mean = config.mean / config.rescale_factor - self.std = config.std / config.rescale_factor + self.mean = [x / config.rescale_factor for x in config.mean] + self.std = [x / config.rescale_factor for x in config.std] self.max_height = config.max_height self.max_width = config.max_width assert ( @@ -20,16 +20,19 @@ def __init__(self, config: ImageProcessorConfig): self.max_width % self.patch_size[1] == 0 ), "max_width must be divisible by patch_size[1]. Found {max_width} and {self.patch_size[1]}" - def resize(self, image: torch.Tensor) -> torch.Tensor: + def resize(self, image): # Resize the image to the specified size - height = image.shape[0] - width = image.shape[1] + # TODO Soham: resize for patches only during train? + # TODO Soham: convert all images to tensor? 
+ # height = image.shape[0] + # width = image.shape[1] + height, width = image.size ratio = max(height / self.max_height, width / self.max_width) if ratio > 1: height = math.ceil(height / ratio) width = math.ceil(width / ratio) else: - height = self.patch_size[0] * math.ceil(height / self.self.patch_size[0]) + height = self.patch_size[0] * math.ceil(height / self.patch_size[0]) width = self.patch_size[1] * math.ceil(width / self.patch_size[1]) # TODO: options for interpolation mode @@ -40,4 +43,4 @@ def normalize(self, image: torch.Tensor) -> torch.Tensor: return F.normalize(image, mean=self.mean, std=self.std) def get_num_patches(self, image: torch.Tensor) -> torch.Tensor: - return (image.size(0) // self.patch_size[0]) * (image.size(1) // self.patch_size[1]) + return (image.size[0] // self.patch_size[0]) * (image.size[1] // self.patch_size[1]) diff --git a/fast_llm/data/preparator/gpt_memmap/config.py b/fast_llm/data/preparator/gpt_memmap/config.py index 60262743..8a15d96c 100644 --- a/fast_llm/data/preparator/gpt_memmap/config.py +++ b/fast_llm/data/preparator/gpt_memmap/config.py @@ -59,6 +59,15 @@ class GPTHuggingfaceDatasetConfig(Config): loss_masking_spans: None | str = Field( default=None, desc="Field containing character spans to mask for loss computation", hint=FieldHint.optional ) + image_paths: None | str = Field( + default=None, desc="Field containing images within the document", hint=FieldHint.optional + ) + image_positions: None | str = Field( + default=None, desc="Field containing image positions within a document", hint=FieldHint.optional + ) + images: None | str = Field( + default=None, desc="Field containing images relevant to a document", hint=FieldHint.optional + ) data_type: DataType | None = Field( default=None, desc="Data type of the dataset field." @@ -142,6 +151,12 @@ class GPTMemmapDatasetPreparatorConfig(DatasetPreparatorConfig): hint=FieldHint.optional, valid=check_field(Assert.geq, 1), ) + tokenize_batch_size: int = Field( + default=1000, + desc="Batch size for tokenization.", + hint=FieldHint.optional, + valid=check_field(Assert.geq, 1), + ) saving_workers: int = Field( default=1, desc="Number of processes for saving the data.", @@ -165,8 +180,9 @@ class GPTMemmapDatasetPreparatorConfig(DatasetPreparatorConfig): hint=FieldHint.optional, ) + # TODO Soham: move tokenizer validation to MultiModalDataProcessor def _validate(self) -> None: - assert self.tokenizer.path is not None + assert self.data_processor.tokenizer.path is not None if self.dataset.data_type is not None: Assert.incl(DataType.from_numpy(self.dataset.data_type.numpy), MEMMAP_DTYPES_INV) super()._validate() diff --git a/fast_llm/data/preparator/gpt_memmap/prepare.py b/fast_llm/data/preparator/gpt_memmap/prepare.py index d4180986..0199cb40 100644 --- a/fast_llm/data/preparator/gpt_memmap/prepare.py +++ b/fast_llm/data/preparator/gpt_memmap/prepare.py @@ -1,3 +1,5 @@ +import io +import itertools import json import logging import multiprocessing @@ -13,6 +15,7 @@ import tqdm import transformers import yaml +from PIL import Image from fast_llm.data.dataset.gpt.config import ( GPTBlendedDatasetConfig, @@ -42,37 +45,43 @@ class GPTMemmapDatasetPreparator[ConfigType: GPTMemmapDatasetPreparatorConfig](D def _process_batch(self, batch: dict[str, list[typing.Any]]) -> dict[str, list[typing.Any]]: pass + # TODO Soham: can we merged tokenize_batch and tokenize_batch_with_spans? 
def _tokenize_batch(self, batch: dict[str, list[typing.Any]]) -> dict[str, list[typing.Any]]: # input_ids = [ # np.array(self._tokenizer.tokenize(text), dtype=self._data_type.numpy) # for text in batch[self._config.dataset.field] # ] - input_ids, images, image_token_positions = map( + input_ids, image_token_positions = map( list, zip( *[ ( np.array(input_ids, dtype=self._data_type.numpy), - np.array(images, dtype=np.uint8), np.array(image_token_positions, dtype=np.int32), ) - for input_ids, images, image_token_positions in [ - self._data_processor.tokenize(text, ims, im_char_positions) - for text, ims, im_char_positions in zip( + for input_ids, image_token_positions in [ + self._data_processor.tokenize( + text, + im_char_positions, + ) + for text, im_char_positions in zip( batch[self._config.dataset.field], - batch[self._config.dataset.images], - batch[self._config.dataset.image_positions], + batch.get(self._config.dataset.image_positions, itertools.repeat(None)), ) ] ] ), ) - num_tokens = [ - len(x) + self._data_processor._image_processor.get_num_patches(im) for x, im in zip(input_ids, images) - ] + num_tokens = [len(x) for x in input_ids] + # TODO Soham: is this ok? Should we get num_image_tokens separately? + for idx, images in enumerate(batch.get("images", [])): + for bytes_im in images: + with Image.open(io.BytesIO(bytes_im["bytes"])) as im: + width, height = im.size + num_tokens[idx] += (width * height * 3) // np.dtype(self._dtype).itemsize + return { "input_ids": input_ids, - "images": images, "image_positions": image_token_positions, "num_tokens": num_tokens, } @@ -92,16 +101,17 @@ def _tokenize_batch_with_spans(self, batch: dict[str, list[typing.Any]]) -> dict self._data_processor.tokenize_with_spans(text, char_spans) for text, char_spans in zip( batch[self._config.dataset.field], - batch[self._config.dataset.loss_masking_spans], - batch[self._config.dataset.images], - batch[self._config.dataset.image_positions], + batch.get(self._config.dataset.loss_masking_spans, itertools.repeat(None)), + batch.get(self._config.dataset.images, itertools.repeat(None)), + batch.get(self._config.dataset.image_positions, itertools.repeat(None)), ) ] ] ), ) num_tokens = [ - len(x) + self._data_processor._image_processor.get_num_patches(im) for x, im in zip(input_ids, images) + len(x) + sum([self._data_processor._image_processor.get_num_patches(im) for im in doc_images]) + for x, doc_images in zip(input_ids, images) ] return { "input_ids": input_ids, @@ -117,7 +127,6 @@ def _save_shard(self, args: tuple[int, datasets.Dataset]) -> GPTMemmapDatasetCon shard_output_path = self._config.output_path / prefix def _document_generator(): - # TODO Soham: simplify this for item in tqdm.tqdm(shard_dataset, desc=f"Saving shard {shard_idx}", unit="docs"): yield GPTSample( np.array(item["input_ids"], dtype=self._data_type.numpy), @@ -126,8 +135,9 @@ def _document_generator(): if self._config.dataset.loss_masking_spans else None ), - images if self._config.dataset.images else None, - image_positions if self._config.dataset.image_positions else None, + # [np.array(Image.open(pathlib.Path(self._config.dataset.path) / path)) for path in item["image_paths"]] if self._config.dataset.image_paths else None, + [np.array(im) for im in item["images"]] if self._config.dataset.images else None, + item["image_positions"] if self._config.dataset.image_positions else None, ) # if "token_spans" in shard_dataset.column_names and self._config.dataset.loss_masking_spans is not None: # for item in tqdm.tqdm(shard_dataset, desc=f"Saving 
shard {shard_idx}", unit="docs"): @@ -215,12 +225,12 @@ def run(self) -> None: if self._config.dataset.disable_disk_space_check: datasets.builder.has_sufficient_disk_space = lambda needed_bytes, directory=".": True - # Load Processor - self._processor = MultiModalProcessor(config=self._config.data_processor) + # Load the data processor + self._data_processor = MultiModalProcessor(config=self._config.data_processor) # Decide the datatype based on the tokenizer vocabulary size self._data_type = ( - get_unsigned_integer_type(self.processor._tokenizer.vocab_size) + get_unsigned_integer_type(self._data_processor._tokenizer.vocab_size) if self._config.dataset.data_type is None else self._config.dataset.data_type ) @@ -269,6 +279,9 @@ def run(self) -> None: tokenize_fn = self._tokenize_batch_with_spans else: tokenize_fn = self._tokenize_batch + # Avoid decoding bytes to images unless asked + if self._config.dataset.images is not None: + dataset = dataset.cast_column("images", datasets.Sequence(datasets.Image(decode=False))) # Tokenize the dataset in parallel tokenized_dataset = dataset.map( diff --git a/setup.cfg b/setup.cfg index 3c1dad9d..57913f83 100644 --- a/setup.cfg +++ b/setup.cfg @@ -43,6 +43,9 @@ OPTIONAL = requests>=2.32.3 tqdm>=4.66.3 # Vision Tools + # TODO Soham: use pillow-simd instead of pillow? + webp>=0.4.0 + pillow-simd>=9.5.0 torchvision>=0.20.0 DEV = From 3866a5330fcf299ba8347b8e3aed057b598b5185 Mon Sep 17 00:00:00 2001 From: root Date: Mon, 21 Apr 2025 07:04:41 +0000 Subject: [PATCH 04/82] solidify prepare --- fast_llm/data/config.py | 60 ++++----- fast_llm/data/data/gpt/config.py | 1 + fast_llm/data/data/gpt/data.py | 9 +- fast_llm/data/dataset/gpt/config.py | 15 ++- fast_llm/data/dataset/gpt/memmap.py | 127 ++++++++++++++---- fast_llm/data/dataset/gpt/sampled.py | 52 +++++-- fast_llm/data/image_processor.py | 25 ++-- fast_llm/data/preparator/gpt_memmap/config.py | 10 +- .../data/preparator/gpt_memmap/prepare.py | 49 ++++--- fast_llm/data/tokenizer.py | 30 ++++- fast_llm/layers/language_model/config.py | 14 ++ 11 files changed, 291 insertions(+), 101 deletions(-) diff --git a/fast_llm/data/config.py b/fast_llm/data/config.py index 8c2c3c28..f1a0fd58 100644 --- a/fast_llm/data/config.py +++ b/fast_llm/data/config.py @@ -43,36 +43,36 @@ class ImageProcessorConfig(Config): """ # Defaults taken from [pixtral](https://github.com/huggingface/transformers/blob/794fde7b1c3d041519fc28ea3e1461b0cfcad4e7/src/transformers/models/pixtral/image_processing_pixtral.py#L201) - patch_size: list[int] = Field( - default_factory=lambda: [16, 16], - desc="Size for each path extracted from the image. Each patch corresponds to a token for the transformer", - hint=FieldHint.optional, - ) - max_height: int = Field( - default=1024, - desc="Maximum height of the image. Image will be resized if larger", - hint=FieldHint.optional, - ) - max_width: int = Field( - default=1024, - desc="Maximum width of the image. 
Image will be resized if larger", - hint=FieldHint.optional, - ) - mean: list[float] = Field( - default_factory=lambda: [0.48145466, 0.4578275, 0.40821073], - desc="Mean RGB values for pixel normalization", - hint=FieldHint.optional, - ) - std: list[float] = Field( - default_factory=lambda: [0.26862954, 0.26130258, 0.27577711], - desc="Standard deviation RGB values for pixel normalization", - hint=FieldHint.optional, - ) - rescale_factor: float = Field( - default=255.0, - desc="Diminisher factor for pixel normalization", - hint=FieldHint.optional, - ) + # patch_size: list[int] = Field( + # default_factory=lambda: [16, 16], + # desc="Size for each path extracted from the image. Each patch corresponds to a token for the transformer", + # hint=FieldHint.optional, + # ) + # max_height: int = Field( + # default=1024, + # desc="Maximum height of the image. Image will be resized if larger", + # hint=FieldHint.optional, + # ) + # max_width: int = Field( + # default=1024, + # desc="Maximum width of the image. Image will be resized if larger", + # hint=FieldHint.optional, + # ) + # mean: list[float] = Field( + # default_factory=lambda: [0.48145466, 0.4578275, 0.40821073], + # desc="Mean RGB values for pixel normalization", + # hint=FieldHint.optional, + # ) + # std: list[float] = Field( + # default_factory=lambda: [0.26862954, 0.26130258, 0.27577711], + # desc="Standard deviation RGB values for pixel normalization", + # hint=FieldHint.optional, + # ) + # rescale_factor: float = Field( + # default=255.0, + # desc="Diminisher factor for pixel normalization", + # hint=FieldHint.optional, + # ) @config_class() diff --git a/fast_llm/data/data/gpt/config.py b/fast_llm/data/data/gpt/config.py index c98a781e..652342b5 100644 --- a/fast_llm/data/data/gpt/config.py +++ b/fast_llm/data/data/gpt/config.py @@ -21,6 +21,7 @@ class GPTSamplingDefaultConfig(SamplingDefaultConfig, GPTSamplingConfig): gpu: bool = FieldUpdate(default=True) use_loss_masking_spans: bool = FieldUpdate(default=False) + use_images: bool = FieldUpdate(default=False) shuffle: ShufflingType = FieldUpdate(default=ShufflingType.epoch) diff --git a/fast_llm/data/data/gpt/data.py b/fast_llm/data/data/gpt/data.py index a0940e7c..5bd9d09e 100644 --- a/fast_llm/data/data/gpt/data.py +++ b/fast_llm/data/data/gpt/data.py @@ -32,10 +32,16 @@ class GPTBatch: token_ids: torch.Tensor loss_masking_spans: list[torch.Tensor] | None = None sequence_lengths: list[torch.Tensor] | None = None + images: list[torch.Tensor] | None = None + image_positions: list[torch.Tensor] | None = None +# TODO: do we need a separate use_images? 
def gpt_data_collate_fn( - batch: list[GPTSample], use_loss_masking_spans: bool, cross_document_attention: bool + batch: list[GPTSample], + use_loss_masking_spans: bool, + cross_document_attention: bool, + use_images: bool, ) -> GPTBatch: stacked_ids = np.stack([sample.token_ids for sample in batch]) stacked_spans = None @@ -170,6 +176,7 @@ def get_iterator( gpt_data_collate_fn, use_loss_masking_spans=self._config.sampling.use_loss_masking_spans, cross_document_attention=self._cross_document_attention, + use_images=self._config.sampling.use_images, ), multiprocessing_context=self._config.multiprocessing_context.value if num_workers > 0 else None, ) diff --git a/fast_llm/data/dataset/gpt/config.py b/fast_llm/data/dataset/gpt/config.py index 0f04884b..45d27e7d 100644 --- a/fast_llm/data/dataset/gpt/config.py +++ b/fast_llm/data/dataset/gpt/config.py @@ -57,6 +57,11 @@ class GPTSamplingConfig(SamplingConfig): desc="Read loss masking spans from the dataset.", hint=FieldHint.feature, ) + use_images: bool | None = Field( + default=None, + desc="Use images in the dataset.", + hint=FieldHint.feature, + ) shuffle: ShufflingType | None = Field( default=None, desc="Shuffling strategy.", @@ -73,6 +78,7 @@ class GPTSamplingData(SamplingData): tokenizer: "Tokenizer" truncate_documents: bool = True cross_document_attention: bool = True + patch_size: list[int] | None = None @config_class() @@ -178,11 +184,18 @@ class GPTMemmapDatasetConfig(GPTIndexedDatasetConfig): desc="Expected number of tokens in the dataset.", hint=FieldHint.optional, ) + num_pixels: int | None = Field( + default=None, + desc="Expected number of pixels in the dataset.", + hint=FieldHint.optional, + ) def build(self) -> "GPTMemmapDataset": from fast_llm.data.dataset.gpt.memmap import GPTMemmapDataset - return GPTMemmapDataset(str(self.path).replace("/", "__"), self.path, self.num_documents, self.num_tokens) + return GPTMemmapDataset( + str(self.path).replace("/", "__"), self.path, self.num_documents, self.num_tokens, self.num_pixels + ) @config_class() diff --git a/fast_llm/data/dataset/gpt/memmap.py b/fast_llm/data/dataset/gpt/memmap.py index 06924054..87bd3a8e 100644 --- a/fast_llm/data/dataset/gpt/memmap.py +++ b/fast_llm/data/dataset/gpt/memmap.py @@ -1,8 +1,10 @@ +import io import pathlib import struct import typing import numpy as np +import PIL.Image from fast_llm.data.dataset.gpt.indexed import GPTIndexedDataset from fast_llm.data.dataset.gpt.sampled import GPTSample @@ -26,10 +28,18 @@ def __init__( prefix: pathlib.Path | str, num_documents: int | None = None, num_tokens: int | None = None, + num_pixels: int | None = None, ): - self._init(name, prefix, num_documents, num_tokens) + self._init(name, prefix, num_documents, num_tokens, num_pixels) - def _init(self, name: str, prefix: pathlib.Path | str, num_documents: int | None, num_tokens: int | None) -> None: + def _init( + self, + name: str, + prefix: pathlib.Path | str, + num_documents: int | None, + num_tokens: int | None, + num_pixels: int | None, + ) -> None: super().__init__() self._name = name self._prefix = pathlib.Path(prefix) @@ -93,30 +103,48 @@ def _init(self, name: str, prefix: pathlib.Path | str, num_documents: int | None + self._num_spans.sum() * 2 * np.dtype(np.int32).itemsize + sum([x.nbytes for x in self._spans]) ) + self._n_pixels = 0 if self._has_images and self._version >= 3: self._n_images = np.frombuffer( self._index_bin_buffer, dtype=np.int32, count=self._num_documents, offset=offset ) - self._im_lengths = np.frombuffer( - self._index_bin_buffer, - 
dtype=np.int32, - count=self._n_images.sum() * 3, - offset=offset + self._n_images.nbytes, - ) - self._im_positions = np.frombuffer( - self._index_bin_buffer, - dtype=np.int32, - count=self._n_images.sum(), - offset=offset + self._n_images.nbytes + self._im_lengths.nbytes, - ) + self._im_lengths = [] + self._im_positions = [] + images_seen = 0 + # TODO Soham: verify correctness, reshaping into width, height? + for n_images in self._n_images: + self._im_lengths.append( + np.frombuffer( + self._index_bin_buffer, + dtype=np.int32, + count=n_images * 2, + offset=offset + self._n_images.nbytes + 2 * images_seen * np.dtype(np.int32).itemsize, + ).reshape(-1, 2) + ) + self._n_pixels += self._im_lengths[-1].prod(axis=1, initial=3).sum() + self._im_positions.append( + np.frombuffer( + self._index_bin_buffer, + dtype=np.int32, + count=n_images, + offset=offset + + self._n_images.nbytes + + 2 * self._n_images.sum() * np.dtype(np.int32).itemsize + + images_seen * np.dtype(np.int32).itemsize, + ) + ) + images_seen += n_images self._bin_buffer_mmap = np.memmap(self._prefix.with_suffix(".bin"), mode="r", order="C") self._bin_buffer = memoryview(self._bin_buffer_mmap) - # TODO Soham: fix num_tokens to include images - self._num_tokens = div(self._bin_buffer_mmap.size, np.dtype(self._dtype).itemsize) - # if num_tokens is not None: - # assert self._num_tokens == num_tokens + # TODO Soham: fix num_tokens to include images. Get total number of image pixels from index file and assign + # self._num_tokens = div(self._bin_buffer_mmap.size - n_pixels, np.dtype(self._dtype).itemsize) + self._num_tokens = div(self._bin_buffer_mmap.size - self._n_pixels, np.dtype(self._dtype).itemsize) + if num_pixels is not None: + assert self._n_pixels == num_pixels + if num_tokens is not None: + assert self._num_tokens == num_tokens def __getstate__(self) -> tuple[str, pathlib.Path, int | None, int | None]: return (self._name, self._prefix, self._num_documents, self._num_tokens) @@ -133,6 +161,42 @@ def __del__(self): del self._index_bin_buffer_mmap # TODO Soham: get images + def get( + self, + idx: int, + offset: int = 0, + length: int | None = None, + use_loss_masking_spans: bool = False, + # , patch_size: tuple(int), max_height: int, max_width: int + ): + # TODO Soham: Handle truncations? 
+ # if self._has_images: + # doc_size = self._document_sizes[idx] + # n_images = self._n_images[idx] + # image_positions = self._im_positions[idx] + # image_lengths = self._im_lengths[idx] + # image_tokens_seen = 0 + # for idx in range(n_images): + # height, width = ImageProcessor.get_resize_dims(image_lengths[0], image_lengths[1], max_height, max_width) + # n_image_tokens = (height // patch_size[0]) * (width // patch_size[1]) + # if (image_positions[idx] > offset + length) or (image_positions[idx] + n_tokens < offset): + # continue + token_ids = np.frombuffer( + self._bin_buffer, + dtype=self._dtype, + count=self._document_sizes[idx] - offset if length is None else length, + offset=self._pointers[idx] + offset * np.dtype(self._dtype).itemsize, + ) + if self._has_images: + image_positions = self._im_positions[idx] + images = np.frombuffer( + self._bin_buffer, + dtype=np.dtype(np.uint8).itemsize, + count=self._image_lengths[idx][0] * self._image_lengths[idx][1] * 3, + offset=self._pointers[idx] + self._document_sizes[idx] * np.dtype(self._dtype).itemsize, + ) + return GPTSample(token_ids=token_ids, images=images, image_positions=image_positions) + def get( self, idx: int, offset: int = 0, length: int | None = None, use_loss_masking_spans: bool = False ) -> GPTSample: @@ -164,16 +228,25 @@ def __len__(self) -> int: def num_tokens(self) -> int: return self._num_tokens + @property + def has_images(self) -> bool: + return self._has_images + + # TODO: image sizes def get_document_sizes(self) -> np.ndarray: """ The size of each document in the dataset. The resulting array could be very large, so this method should be called cautiously, and derived classes should try to avoid holding the whole array im memory. """ - return self._document_sizes + return self._document_sizes, self._im_lengths - def get_document_size(self, index: int) -> int: - return self._document_sizes[index].item() + def get_document_size(self, index: int, patch_size: list[int]) -> int: + return self._document_sizes[index].item() + ( + sum((h // patch_size[0]) * (w // patch_size[1]) for h, w in self._image_lengths[index]) + if self._has_images + else 0 + ) @classmethod def write_dataset(cls, prefix: pathlib.Path | str, documents: typing.Iterable[GPTSample]): @@ -211,12 +284,14 @@ def write_dataset(cls, prefix: pathlib.Path | str, documents: typing.Iterable[GP if document.images: n_images.append(len(document.images)) total_images += len(document.images) - for image, image_position in zip(document.images, document.image_positions): + for image in document.images: # assume 3 channels (RGB) for all images - im_lengths.append(np.array(image.shape[1:])) - im_positions.append(document.image_positions) - bin_stream.write(image.tobytes(order="C")) - total_im_size += image.size + with PIL.Image.open(io.BytesIO(image["bytes"])) as img: + pixels = np.array(img) + im_lengths.append(np.array(pixels.shape[:2])) + bin_stream.write(pixels.tobytes(order="C")) + total_im_size += pixels.size + im_positions.append(document.image_positions) # Update metadata doc_length = len(document.token_ids) diff --git a/fast_llm/data/dataset/gpt/sampled.py b/fast_llm/data/dataset/gpt/sampled.py index 22e3396b..288018b1 100644 --- a/fast_llm/data/dataset/gpt/sampled.py +++ b/fast_llm/data/dataset/gpt/sampled.py @@ -12,6 +12,7 @@ from fast_llm.data.dataset.abstract import SampledDataset from fast_llm.data.dataset.gpt.config import GPTSamplingData, ShufflingType from fast_llm.data.dataset.gpt.indexed import GPTIndexedDataset +from fast_llm.data.image_processor import 
ImageProcessor from fast_llm.engine.config_utils.data_type import DataType, get_unsigned_integer_type from fast_llm.engine.config_utils.run import log_main_rank from fast_llm.utils import Assert @@ -89,11 +90,17 @@ def __init__( self._indexed_dataset = indexed_dataset self._num_samples = sampling.num_samples self._sequence_length = sampling.sequence_length + self._patch_size = sampling.patch_size self._cross_document_attention = sampling.cross_document_attention self._config = sampling.config self._truncate_documents = sampling.truncate_documents self._device = torch.device("cuda" if self._config.gpu else "cpu") + if self._indexed_dataset.has_images and self._truncate_documents: + raise RuntimeError( + "Truncating documents with images is not supported. Please turn off truncation to use images." + ) + if sampling.cache_directory is None: self._document_shuffling = MemmapArray() self._token_cumsum_shuffled = MemmapArray() @@ -126,9 +133,15 @@ def _sample(self) -> None: Create a `GPTSampledDataset` with the requested parameters. """ # Get the document sizes, the main information needed for sampling. - document_sizes = torch.from_numpy(self._indexed_dataset.get_document_sizes()).to(self._device) + # TODO Soham: verify numpy correctness + document_sizes, image_sizes = self._indexed_dataset.get_document_sizes() + document_sizes = torch.from_numpy(document_sizes).to(self._device) + image_token_sizes = torch.zeros_like(document_sizes).to(self._device) + for i, sizes in enumerate(image_sizes): + image_token_sizes[i] = sum(sizes[0, :] // self._patch_size[0]) * (sizes[:, 1] // self._patch_size[1]) + documents_per_epoch = document_sizes.numel() - tokens_per_epoch = document_sizes.sum().item() + tokens_per_epoch = document_sizes.sum().item() + image_token_sizes.sum().item() # Calculate basic stats. if not self._truncate_documents: @@ -136,14 +149,14 @@ def _sample(self) -> None: "The C++ extension for dataset sampling is missing." " Please make sure Fast-LLM is installed correctly." ) - long_docs_filter = document_sizes > self._sequence_length + 1 + long_docs_filter = document_sizes + image_token_sizes > self._sequence_length + 1 ignored_documents = sum(long_docs_filter) if ignored_documents: log_main_rank( f" > {ignored_documents}/{documents_per_epoch} documents are longer than {self._sequence_length+1} tokens and will be ignored.", log_fn=logger.warning, ) - tokens_per_epoch = document_sizes[~long_docs_filter].sum().item() + tokens_per_epoch = (document_sizes[~long_docs_filter] + image_token_sizes[~long_docs_filter]).sum().item() if tokens_per_epoch == 0: raise RuntimeError( f" > No documents shorter than {self._sequence_length+1} tokens found in dataset {self._indexed_dataset.name}." @@ -177,6 +190,7 @@ def _sample(self) -> None: "num_samples": self._num_samples, "unshuffled_epochs": unshuffled_epochs, "sequence_length": self._sequence_length, + "patch_size": self._patch_size, "truncate_documents": self._truncate_documents, "config": self._config.to_serialized(), } @@ -258,7 +272,7 @@ def _sample(self) -> None: # Equivalent to `torch.hstack((0, document_sizes[all_document_index].cumsum()[::TOKEN_CUMSUM_RATE]))` if unshuffled_epochs > 0: token_cumsum_unshuffled, num_tokens_unshuffled = self._get_token_cumsum( - document_sizes, + document_sizes + image_token_sizes, offset=0, # TODO: Allowing for max 100% extra tokens for padding, is that enough? 
dtype=get_unsigned_integer_type((2 - self._truncate_documents) * tokens_per_epoch * num_epochs), @@ -282,6 +296,9 @@ def _sample(self) -> None: document_shuffling.to( dtype=torch.int64 if document_shuffling.dtype == torch.int64 else torch.int32 ) + ] + + image_token_sizes[ + document_shuffling.to(torch.int64 if document_shuffling.dtype == torch.int64 else torch.int32) ], offset=num_tokens_unshuffled, # TODO: Allowing for max 100% extra tokens for padding, is that enough? @@ -360,6 +377,9 @@ def __getitem__(self, index: int) -> typing.Any: token_ids = [] loss_masking_spans = [] + images = [] + image_positions = [] + image_tokens_added = 0 while token_count < token_end: # Find the document index in the dataset. if document_sampling_index < self._unshuffled_documents: @@ -367,7 +387,7 @@ def __getitem__(self, index: int) -> typing.Any: else: document_index = self._document_shuffling[document_sampling_index - self._unshuffled_documents].item() - document_size = self._indexed_dataset.get_document_size(document_index) + document_size = self._indexed_dataset.get_document_size(document_index, self._patch_size) if not self._truncate_documents: if document_size > self._sequence_length + 1: @@ -398,6 +418,12 @@ def __getitem__(self, index: int) -> typing.Any: length=token_end_index_in_document - token_start_index_in_document, use_loss_masking_spans=self._config.use_loss_masking_spans, ) + # TODO Soham: handle images with loss masking spans + for idx, im_position in enumerate(sample.image_positions): + # image_positions.append(im_positions + len(token_ids) + image_tokens_added) + image_positions.append(im_position + len(token_ids) + image_tokens_added) + image_tokens_added += ImageProcessor.get_num_patches(sample.images[idx]) + images.append(sample.images) token_ids.append(sample.token_ids) if self._config.use_loss_masking_spans: for loss_masking_span in sample.loss_masking_spans: @@ -411,6 +437,7 @@ def __getitem__(self, index: int) -> typing.Any: sequence_lengths = ( np.array([ids.size - (idx == len(token_ids) - 1) for idx, ids in enumerate(token_ids)], dtype=np.int32) + + np.array([ImageProcessor.get_num_patches(image) for image in images[idx] for idx in range(len(images))]) if not self._cross_document_attention else None ) @@ -420,9 +447,16 @@ def __getitem__(self, index: int) -> typing.Any: if self._config.use_loss_masking_spans else None ) - Assert.eq(len(token_ids), self._sequence_length + 1) - - return GPTSample(token_ids=token_ids, loss_masking_spans=loss_masking_spans, sequence_lengths=sequence_lengths) + images = [im for img_list in images for im in img_list] + Assert.eq(len(token_ids) + image_tokens_added, self._sequence_length + 1) + + return GPTSample( + token_ids=token_ids, + loss_masking_spans=loss_masking_spans, + sequence_lengths=sequence_lengths, + images=images, + image_positions=image_positions, + ) @property def name(self) -> str: diff --git a/fast_llm/data/image_processor.py b/fast_llm/data/image_processor.py index c5cbe909..567c8146 100644 --- a/fast_llm/data/image_processor.py +++ b/fast_llm/data/image_processor.py @@ -26,21 +26,26 @@ def resize(self, image): # TODO Soham: convert all images to tensor? 
# height = image.shape[0] # width = image.shape[1] - height, width = image.size - ratio = max(height / self.max_height, width / self.max_width) - if ratio > 1: - height = math.ceil(height / ratio) - width = math.ceil(width / ratio) - else: - height = self.patch_size[0] * math.ceil(height / self.patch_size[0]) - width = self.patch_size[1] * math.ceil(width / self.patch_size[1]) + height, width = self.get_resize_dims(image.shape[0], image.shape[1], self.max_height, self.max_width) # TODO: options for interpolation mode return F.resize(image, size=(height, width), interpolation=F.InterpolationMode.BICUBIC) + # TODO Soham: move to utils + @classmethod + def get_resize_dims(height, width, max_height, max_width, patch_size: list[int]): + ratio = max(height / max_height, width / max_width) + return ( + (math.ceil(height / ratio), math.ceil(width / ratio)) + if ratio > 1 + else (patch_size[0] * math.ceil(height / patch_size[0]), patch_size[1] * math.ceil(width / patch_size[1])) + ) + def normalize(self, image: torch.Tensor) -> torch.Tensor: # Normalize the image using the mean and std return F.normalize(image, mean=self.mean, std=self.std) - def get_num_patches(self, image: torch.Tensor) -> torch.Tensor: - return (image.size[0] // self.patch_size[0]) * (image.size[1] // self.patch_size[1]) + @classmethod + # TODO Soham: move to utils + def get_num_patches(image: torch.Tensor, patch_size: list[int]) -> torch.Tensor: + return (image.size[0] // patch_size[0]) * (image.size[1] // patch_size[1]) diff --git a/fast_llm/data/preparator/gpt_memmap/config.py b/fast_llm/data/preparator/gpt_memmap/config.py index 8a15d96c..89fe904c 100644 --- a/fast_llm/data/preparator/gpt_memmap/config.py +++ b/fast_llm/data/preparator/gpt_memmap/config.py @@ -3,7 +3,7 @@ import typing from fast_llm.config import Config, Field, FieldHint, check_field, config_class -from fast_llm.data.config import MultiModalProcessorConfig +from fast_llm.data.config import TokenizerConfig from fast_llm.data.preparator.config import DatasetPreparatorConfig from fast_llm.engine.config_utils.data_type import DataType from fast_llm.utils import Assert @@ -168,9 +168,9 @@ class GPTMemmapDatasetPreparatorConfig(DatasetPreparatorConfig): desc="Configuration for the dataset.", hint=FieldHint.feature, ) - data_processor: MultiModalProcessorConfig = Field( - default_factory=MultiModalProcessorConfig, - desc="Configuration for data processing. 
Describes the tokenizer and image processor", + tokenizer: TokenizerConfig = Field( + default_factory=TokenizerConfig, + desc="Tokenizer configuration.", hint=FieldHint.feature, ) splits: dict[str, float] | None = Field( @@ -182,7 +182,7 @@ class GPTMemmapDatasetPreparatorConfig(DatasetPreparatorConfig): # TODO Soham: move tokenizer validation to MultiModalDataProcessor def _validate(self) -> None: - assert self.data_processor.tokenizer.path is not None + assert self.tokenizer.path is not None if self.dataset.data_type is not None: Assert.incl(DataType.from_numpy(self.dataset.data_type.numpy), MEMMAP_DTYPES_INV) super()._validate() diff --git a/fast_llm/data/preparator/gpt_memmap/prepare.py b/fast_llm/data/preparator/gpt_memmap/prepare.py index 0199cb40..4965dfdf 100644 --- a/fast_llm/data/preparator/gpt_memmap/prepare.py +++ b/fast_llm/data/preparator/gpt_memmap/prepare.py @@ -10,12 +10,12 @@ import datasets import huggingface_hub import numpy as np +import PIL.Image import requests import torch.distributed import tqdm import transformers import yaml -from PIL import Image from fast_llm.data.dataset.gpt.config import ( GPTBlendedDatasetConfig, @@ -26,9 +26,9 @@ ) from fast_llm.data.dataset.gpt.memmap import GPTMemmapDataset from fast_llm.data.dataset.gpt.sampled import GPTSample -from fast_llm.data.multi_modal_processor import MultiModalProcessor from fast_llm.data.preparator.config import DatasetPreparator from fast_llm.data.preparator.gpt_memmap.config import GPTMemmapDatasetPreparatorConfig +from fast_llm.data.tokenizer import Tokenizer from fast_llm.engine.config_utils.data_type import DataType, get_unsigned_integer_type from fast_llm.utils import Assert, normalize_probabilities, padded_cumsum @@ -38,8 +38,7 @@ class GPTMemmapDatasetPreparator[ConfigType: GPTMemmapDatasetPreparatorConfig](DatasetPreparator[ConfigType]): config_class: typing.ClassVar[type[GPTMemmapDatasetPreparatorConfig]] = GPTMemmapDatasetPreparatorConfig - # _tokenizer: Tokenizer - _data_processor: MultiModalProcessor + _tokenizer: Tokenizer _data_type: DataType def _process_batch(self, batch: dict[str, list[typing.Any]]) -> dict[str, list[typing.Any]]: @@ -60,7 +59,7 @@ def _tokenize_batch(self, batch: dict[str, list[typing.Any]]) -> dict[str, list[ np.array(image_token_positions, dtype=np.int32), ) for input_ids, image_token_positions in [ - self._data_processor.tokenize( + self._tokenizer.tokenize( text, im_char_positions, ) @@ -73,17 +72,18 @@ def _tokenize_batch(self, batch: dict[str, list[typing.Any]]) -> dict[str, list[ ), ) num_tokens = [len(x) for x in input_ids] - # TODO Soham: is this ok? Should we get num_image_tokens separately? 
+ num_pixels = [0] * len(input_ids) for idx, images in enumerate(batch.get("images", [])): for bytes_im in images: - with Image.open(io.BytesIO(bytes_im["bytes"])) as im: + with PIL.Image.open(io.BytesIO(bytes_im["bytes"])) as im: width, height = im.size - num_tokens[idx] += (width * height * 3) // np.dtype(self._dtype).itemsize + num_pixels[idx] += width * height * 3 return { "input_ids": input_ids, "image_positions": image_token_positions, "num_tokens": num_tokens, + "num_pixels": num_pixels, } def _tokenize_batch_with_spans(self, batch: dict[str, list[typing.Any]]) -> dict[str, list[typing.Any]]: @@ -98,7 +98,7 @@ def _tokenize_batch_with_spans(self, batch: dict[str, list[typing.Any]]) -> dict np.array(image_token_positions, dtype=np.int32), ) for input_ids, token_spans, images, image_token_positions in [ - self._data_processor.tokenize_with_spans(text, char_spans) + self._tokenizer.tokenize_with_spans(text, char_spans) for text, char_spans in zip( batch[self._config.dataset.field], batch.get(self._config.dataset.loss_masking_spans, itertools.repeat(None)), @@ -109,16 +109,20 @@ def _tokenize_batch_with_spans(self, batch: dict[str, list[typing.Any]]) -> dict ] ), ) - num_tokens = [ - len(x) + sum([self._data_processor._image_processor.get_num_patches(im) for im in doc_images]) - for x, doc_images in zip(input_ids, images) - ] + num_tokens = [len(x) for x in input_ids] + num_pixels = [0] * len(input_ids) + for idx, images in enumerate(images): + for bytes_im in images: + with PIL.Image.open(io.BytesIO(bytes_im["bytes"])) as im: + width, height = im.size + num_pixels[idx] += width * height * 3 return { "input_ids": input_ids, "token_spans": token_spans, "images": images, "image_positions": image_token_positions, "num_tokens": num_tokens, + "num_pixels": num_pixels, } def _save_shard(self, args: tuple[int, datasets.Dataset]) -> GPTMemmapDatasetConfig: @@ -136,7 +140,8 @@ def _document_generator(): else None ), # [np.array(Image.open(pathlib.Path(self._config.dataset.path) / path)) for path in item["image_paths"]] if self._config.dataset.image_paths else None, - [np.array(im) for im in item["images"]] if self._config.dataset.images else None, + # [np.array(im) for im in item["images"]] if self._config.dataset.images else None, + item["images"] if self._config.dataset.images else None, item["image_positions"] if self._config.dataset.image_positions else None, ) # if "token_spans" in shard_dataset.column_names and self._config.dataset.loss_masking_spans is not None: @@ -157,6 +162,7 @@ def _document_generator(): "path": prefix, "num_documents": len(shard_dataset), # Use the length of the shard dataset directly "num_tokens": sum(len(doc["input_ids"]) for doc in shard_dataset), + "num_pixels": sum(doc["num_pixels"] for doc in shard_dataset), } ) @@ -225,12 +231,12 @@ def run(self) -> None: if self._config.dataset.disable_disk_space_check: datasets.builder.has_sufficient_disk_space = lambda needed_bytes, directory=".": True - # Load the data processor - self._data_processor = MultiModalProcessor(config=self._config.data_processor) + # Load tokenizer + self._tokenizer = Tokenizer(config=self._config.tokenizer) # Decide the datatype based on the tokenizer vocabulary size self._data_type = ( - get_unsigned_integer_type(self._data_processor._tokenizer.vocab_size) + get_unsigned_integer_type(self._tokenizer.vocab_size) if self._config.dataset.data_type is None else self._config.dataset.data_type ) @@ -293,6 +299,12 @@ def run(self) -> None: # Calculate total number of tokens total_tokens = 
sum(tqdm.tqdm(tokenized_dataset["num_tokens"], desc="Counting tokens", unit="tokens")) + total_pixels = ( + sum(tqdm.tqdm(tokenized_dataset["num_pixels"], desc="Counting pixels", unit="pixels")) + if self._config.dataset.images + else 0 + ) + total_tokens += total_pixels // np.dtype(self._data_type.numpy).itemsize # Split dataset into shards based on number of tokens num_shards = int(np.ceil(total_tokens / self._config.tokens_per_shard)) @@ -391,7 +403,8 @@ def _split_and_blend_dataset_configs( # Part of the dataset belongs to the split. # TODO: Somehow getting a segfault when merging two lines below (numpy bug?). dataset = dataset_config.to_copy({"path": output_path / dataset_config.path}).build() - sizes_cumsum = dataset.get_document_sizes().cumsum() + # TODO Soham: handle pixels (could still work with number of tokens?) + sizes_cumsum = dataset.get_document_sizes()[0].cumsum() Assert.eq(sizes_cumsum[-1], dataset_config.num_tokens) begin_index = _get_nearest_split(sizes_cumsum, split_begin_in_dataset * dataset_config.num_tokens) end_index = _get_nearest_split(sizes_cumsum, split_end_in_dataset * dataset_config.num_tokens) diff --git a/fast_llm/data/tokenizer.py b/fast_llm/data/tokenizer.py index 28e105ee..0e7d5470 100644 --- a/fast_llm/data/tokenizer.py +++ b/fast_llm/data/tokenizer.py @@ -35,13 +35,41 @@ def vocab(self) -> dict[str, int]: def inv_vocab(self) -> dict[int, str]: return self._inv_vocab - def tokenize(self, text: str, begin=True, end=True) -> list[int]: + def _tokenize(self, text: str, begin=True, end=True) -> list[int]: return ( ([self.bod_id] if begin else []) + self.tokenizer.encode(text, add_special_tokens=False) + ([self.eod_id] if end else []) ) + def tokenize(self, text, image_positions=None): + if not image_positions: + return self._tokenize(text), [], [] + image_idx = 0 + char_pos = 0 + token_ids = [] + image_token_positions = [] + beginning_of_text = True + while image_idx < len(image_positions): + if image_positions[image_idx] > len(text): + raise ValueError( + f"Image position {image_positions[image_idx]} is greater than text length {len(text)}" + ) + curr_text = text[char_pos : image_positions[image_idx]] + tokenized_text = self._tokenize( + curr_text, begin=beginning_of_text, end=image_positions[image_idx] >= len(text) + ) + beginning_of_text = False + token_ids.extend(tokenized_text) + image_token_positions = len(token_ids) + char_pos = image_positions[image_idx] + image_idx += 1 + if char_pos < len(text): + curr_text = text[char_pos:] + tokenized_text = self._tokenize(curr_text, begin=beginning_of_text, end=True) + token_ids.extend(tokenized_text) + return token_ids, image_token_positions + def tokenize_with_spans( self, text: str, char_spans: list[tuple[int, int]] ) -> tuple[list[int], list[tuple[int, int]]]: diff --git a/fast_llm/layers/language_model/config.py b/fast_llm/layers/language_model/config.py index 3bd79603..75c5418b 100644 --- a/fast_llm/layers/language_model/config.py +++ b/fast_llm/layers/language_model/config.py @@ -6,6 +6,7 @@ from fast_llm.engine.distributed.config import DistributedDimNames from fast_llm.functional.config import CrossEntropyImpl from fast_llm.layers.transformer.config import TransformerArchitectureConfig, TransformerConfig +from fast_llm.layers.vision_encoder.config import VisionArchitectureConfig from fast_llm.utils import Assert @@ -198,3 +199,16 @@ def _validate(self) -> None: if self.init_method_max_embed is not None and self.init_method_min_embed is not None: Assert.leq(self.init_method_min_embed, 
self.init_method_max_embed) super()._validate() + + +class MultiModalBaseConfig: + language_model: LanguageModelBaseConfig = Field( + default_factory=LanguageModelBaseConfig, + desc="Configuration for the language model.", + hint=FieldHint.core, + ) + vision_model: VisionArchitectureConfig = Field( + default_factory=VisionArchitectureConfig, + desc="Configuration for the vision inputs.", + hint=FieldHint.core, + ) From 841398396714e5c3b346d6d2c46dcb37f532c167 Mon Sep 17 00:00:00 2001 From: root Date: Thu, 24 Apr 2025 07:55:31 +0000 Subject: [PATCH 05/82] wip --- fast_llm/data/data/gpt/config.py | 1 - fast_llm/data/data/gpt/data.py | 31 ++- fast_llm/data/dataset/gpt/config.py | 7 +- fast_llm/data/dataset/gpt/indexed.py | 12 +- fast_llm/data/dataset/gpt/memmap.py | 97 +++++---- fast_llm/data/dataset/gpt/sampled.py | 32 ++- fast_llm/data/image_processor.py | 10 +- fast_llm/engine/schedule/config.py | 15 ++ fast_llm/layers/language_model/config.py | 13 +- fast_llm/models/gpt/config.py | 4 + fast_llm/models/gpt/conversion.py | 258 +++++++++++++++++++++-- fast_llm/models/gpt/model.py | 12 +- fast_llm/models/gpt/trainer.py | 3 + 13 files changed, 400 insertions(+), 95 deletions(-) diff --git a/fast_llm/data/data/gpt/config.py b/fast_llm/data/data/gpt/config.py index 652342b5..c98a781e 100644 --- a/fast_llm/data/data/gpt/config.py +++ b/fast_llm/data/data/gpt/config.py @@ -21,7 +21,6 @@ class GPTSamplingDefaultConfig(SamplingDefaultConfig, GPTSamplingConfig): gpu: bool = FieldUpdate(default=True) use_loss_masking_spans: bool = FieldUpdate(default=False) - use_images: bool = FieldUpdate(default=False) shuffle: ShufflingType = FieldUpdate(default=ShufflingType.epoch) diff --git a/fast_llm/data/data/gpt/data.py b/fast_llm/data/data/gpt/data.py index 5bd9d09e..22e4730c 100644 --- a/fast_llm/data/data/gpt/data.py +++ b/fast_llm/data/data/gpt/data.py @@ -36,12 +36,11 @@ class GPTBatch: image_positions: list[torch.Tensor] | None = None -# TODO: do we need a separate use_images? 
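The collate changes in the next hunk keep images as ragged, per-sample lists rather than stacking them into one tensor. A reduced sketch of that per-sample handling, assuming the same "None when a sample has no images" convention (the function name is illustrative):

import numpy as np
import torch


def collate_optional_images(samples: list[list[np.ndarray] | None]) -> list[list[torch.Tensor] | None] | None:
    # Convert each sample's images to tensors but keep the variable-size per-sample
    # structure; return None when no sample in the batch carries images at all.
    batch_images = [
        [torch.from_numpy(image) for image in images] if images is not None else None
        for images in samples
    ]
    return batch_images if any(images is not None for images in batch_images) else None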
+# TODO: collate images def gpt_data_collate_fn( batch: list[GPTSample], use_loss_masking_spans: bool, cross_document_attention: bool, - use_images: bool, ) -> GPTBatch: stacked_ids = np.stack([sample.token_ids for sample in batch]) stacked_spans = None @@ -50,8 +49,24 @@ def gpt_data_collate_fn( stacked_spans = [torch.from_numpy(sample.loss_masking_spans) for sample in batch] if not cross_document_attention: sequence_lengths = [torch.tensor(sample.sequence_lengths) for sample in batch] + batch_images = [] + for sample in batch: + if sample.images is not None: + batch_images.append([torch.from_numpy(image) for image in sample.images]) + else: + batch_images.append(None) + batch_image_positions = [] + for sample in batch: + if sample.image_positions is not None: + batch_image_positions.append(torch.from_numpy(sample.image_positions)) + else: + batch_image_positions.append(None) return GPTBatch( - token_ids=torch.from_numpy(stacked_ids), loss_masking_spans=stacked_spans, sequence_lengths=sequence_lengths + token_ids=torch.from_numpy(stacked_ids), + loss_masking_spans=stacked_spans, + sequence_lengths=sequence_lengths, + images=batch_images if any(batch_images) else None, + image_positions=batch_image_positions if any(batch_image_positions) else None, ) @@ -73,6 +88,9 @@ def __init__( vocab_size: int, max_sequence_length: int, cross_document_attention: bool = True, + patch_size: list[int] | None = None, + max_image_height: int | None = None, + max_image_width: int | None = None, ): """ Create the data and gather some basic information on the dataset(s). @@ -82,6 +100,9 @@ def __init__( self._vocab_size = vocab_size self._max_sequence_length = max_sequence_length self._cross_document_attention = cross_document_attention + self._patch_size = patch_size + self._max_image_height = max_image_height + self._max_image_width = max_image_width def setup( self, @@ -129,6 +150,9 @@ def setup( tokenizer=self._tokenizer, truncate_documents=self._config.truncate_documents, cross_document_attention=self._cross_document_attention, + patch_size=self._patch_size, + image_height=self._max_image_height, + image_width=self._max_image_width, ) dataset = self._config.datasets[dataset_name].build_and_sample(sampling) self._datasets[dataset_name] = DatasetMonitor(dataset, self._config.data_sample_warn_time_ms) @@ -176,7 +200,6 @@ def get_iterator( gpt_data_collate_fn, use_loss_masking_spans=self._config.sampling.use_loss_masking_spans, cross_document_attention=self._cross_document_attention, - use_images=self._config.sampling.use_images, ), multiprocessing_context=self._config.multiprocessing_context.value if num_workers > 0 else None, ) diff --git a/fast_llm/data/dataset/gpt/config.py b/fast_llm/data/dataset/gpt/config.py index 45d27e7d..8022a05f 100644 --- a/fast_llm/data/dataset/gpt/config.py +++ b/fast_llm/data/dataset/gpt/config.py @@ -57,11 +57,6 @@ class GPTSamplingConfig(SamplingConfig): desc="Read loss masking spans from the dataset.", hint=FieldHint.feature, ) - use_images: bool | None = Field( - default=None, - desc="Use images in the dataset.", - hint=FieldHint.feature, - ) shuffle: ShufflingType | None = Field( default=None, desc="Shuffling strategy.", @@ -79,6 +74,8 @@ class GPTSamplingData(SamplingData): truncate_documents: bool = True cross_document_attention: bool = True patch_size: list[int] | None = None + image_height: int | None = None + image_width: int | None = None @config_class() diff --git a/fast_llm/data/dataset/gpt/indexed.py b/fast_llm/data/dataset/gpt/indexed.py index 688ea6a7..209c6e31 
100644 --- a/fast_llm/data/dataset/gpt/indexed.py +++ b/fast_llm/data/dataset/gpt/indexed.py @@ -11,6 +11,7 @@ class GPTIndexedDataset(IndexedDataset): + # TODO Soham: should we change this to include images? @abc.abstractmethod def get_document_sizes(self) -> np.ndarray: """ @@ -44,10 +45,15 @@ class GPTDatasetSlice[IndexedDatasetType: GPTIndexedDataset](DatasetSlice[Indexe def get_document_sizes(self) -> np.ndarray: # TODO: This can be really big. - return self._dataset.get_document_sizes()[self._begin : self._end] + doc_sizes, im_sizes = self._dataset.get_document_sizes() + return doc_sizes[self._begin : self._end], im_sizes[self._begin : self._end] - def get_document_size(self, index: int) -> int: - return self._dataset.get_document_size(self._begin + index) + def get_document_size(self, index: int, patch_size: list[int]) -> int: + return self._dataset.get_document_size(self._begin + index, patch_size) + + @property + def has_images(self) -> bool: + return self._dataset.has_images class GPTConcatenatedDataset[IndexedDatasetType: GPTIndexedDataset]( diff --git a/fast_llm/data/dataset/gpt/memmap.py b/fast_llm/data/dataset/gpt/memmap.py index 87bd3a8e..43fba843 100644 --- a/fast_llm/data/dataset/gpt/memmap.py +++ b/fast_llm/data/dataset/gpt/memmap.py @@ -103,17 +103,17 @@ def _init( + self._num_spans.sum() * 2 * np.dtype(np.int32).itemsize + sum([x.nbytes for x in self._spans]) ) - self._n_pixels = 0 + self._num_pixels = 0 if self._has_images and self._version >= 3: self._n_images = np.frombuffer( self._index_bin_buffer, dtype=np.int32, count=self._num_documents, offset=offset ) - self._im_lengths = [] - self._im_positions = [] + self._image_lengths = [] + self._image_positions = [] images_seen = 0 # TODO Soham: verify correctness, reshaping into width, height? for n_images in self._n_images: - self._im_lengths.append( + self._image_lengths.append( np.frombuffer( self._index_bin_buffer, dtype=np.int32, @@ -121,8 +121,8 @@ def _init( offset=offset + self._n_images.nbytes + 2 * images_seen * np.dtype(np.int32).itemsize, ).reshape(-1, 2) ) - self._n_pixels += self._im_lengths[-1].prod(axis=1, initial=3).sum() - self._im_positions.append( + self._num_pixels += self._image_lengths[-1].prod(axis=1, initial=3).sum() + self._image_positions.append( np.frombuffer( self._index_bin_buffer, dtype=np.int32, @@ -140,14 +140,14 @@ def _init( # TODO Soham: fix num_tokens to include images. Get total number of image pixels from index file and assign # self._num_tokens = div(self._bin_buffer_mmap.size - n_pixels, np.dtype(self._dtype).itemsize) - self._num_tokens = div(self._bin_buffer_mmap.size - self._n_pixels, np.dtype(self._dtype).itemsize) + self._num_tokens = div(self._bin_buffer_mmap.size - self._num_pixels, np.dtype(self._dtype).itemsize) if num_pixels is not None: - assert self._n_pixels == num_pixels + assert self._num_pixels == num_pixels if num_tokens is not None: assert self._num_tokens == num_tokens def __getstate__(self) -> tuple[str, pathlib.Path, int | None, int | None]: - return (self._name, self._prefix, self._num_documents, self._num_tokens) + return (self._name, self._prefix, self._num_documents, self._num_tokens, self._num_pixels) def __setstate__(self, state: tuple[str, pathlib.Path, int | None, int | None]): self._init(*state) @@ -169,7 +169,7 @@ def get( use_loss_masking_spans: bool = False, # , patch_size: tuple(int), max_height: int, max_width: int ): - # TODO Soham: Handle truncations? 
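The hunk that follows reads all of a document's images from one flat uint8 region of the memmap and slices it back into (height, width, 3) arrays using the stored image lengths. The reconstruction step in isolation, under the same 3-channel assumption (sketch only, not part of the patch):

import numpy as np


def split_pixel_buffer(pixels: np.ndarray, image_lengths: np.ndarray) -> list[np.ndarray]:
    # pixels: flat uint8 buffer with all images of one document concatenated in C order.
    # image_lengths: int32 array of shape (num_images, 2) holding (height, width) per image.
    images = []
    start = 0
    for height, width in image_lengths:
        n_pixels = int(height) * int(width) * 3
        images.append(pixels[start : start + n_pixels].reshape(height, width, 3))
        start += n_pixels
    return images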
+ # TODO Soham: handle spans # if self._has_images: # doc_size = self._document_sizes[idx] # n_images = self._n_images[idx] @@ -188,34 +188,42 @@ def get( offset=self._pointers[idx] + offset * np.dtype(self._dtype).itemsize, ) if self._has_images: - image_positions = self._im_positions[idx] - images = np.frombuffer( + image_positions = self._image_positions[idx] + pixels = np.frombuffer( self._bin_buffer, - dtype=np.dtype(np.uint8).itemsize, - count=self._image_lengths[idx][0] * self._image_lengths[idx][1] * 3, + dtype=np.dtype(np.uint8), + count=self._image_lengths[idx].prod(initial=3), offset=self._pointers[idx] + self._document_sizes[idx] * np.dtype(self._dtype).itemsize, ) + images = [] + start = 0 + for image_length in self._image_lengths[idx]: + # TODO Soham: verify reshape dimension order + n_pixels = image_length.prod(initial=3) + images.append(pixels[start : start + n_pixels].reshape(image_length[0], image_length[1], 3)) + start += n_pixels + # TODO Soham: return loss_masking_spans return GPTSample(token_ids=token_ids, images=images, image_positions=image_positions) - def get( - self, idx: int, offset: int = 0, length: int | None = None, use_loss_masking_spans: bool = False - ) -> GPTSample: - token_ids = np.frombuffer( - self._bin_buffer, - dtype=self._dtype, - count=self._document_sizes[idx] - offset if length is None else length, - offset=self._pointers[idx] + offset * np.dtype(self._dtype).itemsize, - ) - sample_spans = None - if use_loss_masking_spans and self._spans is not None: - sample_spans = self._spans[idx] - # adjust the spans for the offset and length - sample_spans = sample_spans[ - (sample_spans[:, 0] < offset + len(token_ids)) & (sample_spans[:, 1] >= offset) - ] - sample_spans[:, 0] = np.maximum(sample_spans[:, 0], offset) - offset - sample_spans[:, 1] = np.minimum(sample_spans[:, 1], offset + len(token_ids) - 1) - offset - return GPTSample(token_ids=token_ids, loss_masking_spans=sample_spans) + # def get( + # self, idx: int, offset: int = 0, length: int | None = None, use_loss_masking_spans: bool = False + # ) -> GPTSample: + # token_ids = np.frombuffer( + # self._bin_buffer, + # dtype=self._dtype, + # count=self._document_sizes[idx] - offset if length is None else length, + # offset=self._pointers[idx] + offset * np.dtype(self._dtype).itemsize, + # ) + # sample_spans = None + # if use_loss_masking_spans and self._spans is not None: + # sample_spans = self._spans[idx] + # # adjust the spans for the offset and length + # sample_spans = sample_spans[ + # (sample_spans[:, 0] < offset + len(token_ids)) & (sample_spans[:, 1] >= offset) + # ] + # sample_spans[:, 0] = np.maximum(sample_spans[:, 0], offset) - offset + # sample_spans[:, 1] = np.minimum(sample_spans[:, 1], offset + len(token_ids) - 1) - offset + # return GPTSample(token_ids=token_ids, loss_masking_spans=sample_spans) @property def name(self) -> str: @@ -233,20 +241,21 @@ def has_images(self) -> bool: return self._has_images # TODO: image sizes - def get_document_sizes(self) -> np.ndarray: + def get_document_sizes(self) -> tuple[np.ndarray, np.ndarray]: """ The size of each document in the dataset. The resulting array could be very large, so this method should be called cautiously, and derived classes should try to avoid holding the whole array im memory. 
""" - return self._document_sizes, self._im_lengths + return self._document_sizes, self._image_lengths def get_document_size(self, index: int, patch_size: list[int]) -> int: - return self._document_sizes[index].item() + ( - sum((h // patch_size[0]) * (w // patch_size[1]) for h, w in self._image_lengths[index]) - if self._has_images - else 0 - ) + # return self._document_sizes[index].item() + ( + # sum((h // patch_size[0]) * (w // patch_size[1]) for h, w in self._image_lengths[index]) + # if self._has_images + # else 0 + # ) + return self._document_sizes[index].item(), self._image_lengths[index] if self._has_images else [] @classmethod def write_dataset(cls, prefix: pathlib.Path | str, documents: typing.Iterable[GPTSample]): @@ -255,7 +264,7 @@ def write_dataset(cls, prefix: pathlib.Path | str, documents: typing.Iterable[GP num_documents = 0 doc_lengths = [] n_images = [] - im_lengths = [] + image_lengths = [] im_positions = [] total_images = 0 pointers = [] @@ -288,7 +297,7 @@ def write_dataset(cls, prefix: pathlib.Path | str, documents: typing.Iterable[GP # assume 3 channels (RGB) for all images with PIL.Image.open(io.BytesIO(image["bytes"])) as img: pixels = np.array(img) - im_lengths.append(np.array(pixels.shape[:2])) + image_lengths.append(np.array(pixels.shape[:2])) bin_stream.write(pixels.tobytes(order="C")) total_im_size += pixels.size im_positions.append(document.image_positions) @@ -316,7 +325,7 @@ def write_dataset(cls, prefix: pathlib.Path | str, documents: typing.Iterable[GP n_images = np.array(n_images, dtype=np.int32) else: n_images = np.array([]) - im_lengths = np.stack(im_lengths, dtype=np.int32) + image_lengths = np.stack(image_lengths, dtype=np.int32) im_positions = np.array(im_positions, dtype=np.int32) # Write the index file (.idx) @@ -347,7 +356,7 @@ def write_dataset(cls, prefix: pathlib.Path | str, documents: typing.Iterable[GP # Number of images per document idx_stream.write(n_images.tobytes(order="C")) # n_pixels * 3 per image - idx_stream.write(im_lengths.tobytes(order="C")) + idx_stream.write(image_lengths.tobytes(order="C")) # Position of each image in the document idx_stream.write(im_positions.tobytes(order="C")) # Document indices, unused but needed for compatibility with Megatron-LM diff --git a/fast_llm/data/dataset/gpt/sampled.py b/fast_llm/data/dataset/gpt/sampled.py index 288018b1..8acbf9ee 100644 --- a/fast_llm/data/dataset/gpt/sampled.py +++ b/fast_llm/data/dataset/gpt/sampled.py @@ -91,11 +91,14 @@ def __init__( self._num_samples = sampling.num_samples self._sequence_length = sampling.sequence_length self._patch_size = sampling.patch_size + self._image_height = sampling.image_height + self._image_width = sampling.image_width self._cross_document_attention = sampling.cross_document_attention self._config = sampling.config self._truncate_documents = sampling.truncate_documents self._device = torch.device("cuda" if self._config.gpu else "cpu") + # TODO Soham: use something else for this check, introducing has_images for just this check might be unnecessary. if self._indexed_dataset.has_images and self._truncate_documents: raise RuntimeError( "Truncating documents with images is not supported. Please turn off truncation to use images." 
@@ -137,8 +140,9 @@ def _sample(self) -> None: document_sizes, image_sizes = self._indexed_dataset.get_document_sizes() document_sizes = torch.from_numpy(document_sizes).to(self._device) image_token_sizes = torch.zeros_like(document_sizes).to(self._device) + # TODO Soham: handle max image size for i, sizes in enumerate(image_sizes): - image_token_sizes[i] = sum(sizes[0, :] // self._patch_size[0]) * (sizes[:, 1] // self._patch_size[1]) + image_token_sizes[i] = sum((sizes[:, 0] // self._patch_size[0]) * (sizes[:, 1] // self._patch_size[1])) documents_per_epoch = document_sizes.numel() tokens_per_epoch = document_sizes.sum().item() + image_token_sizes.sum().item() @@ -387,15 +391,26 @@ def __getitem__(self, index: int) -> typing.Any: else: document_index = self._document_shuffling[document_sampling_index - self._unshuffled_documents].item() - document_size = self._indexed_dataset.get_document_size(document_index, self._patch_size) + document_size, image_lengths = self._indexed_dataset.get_document_size(document_index, self._patch_size) + + image_sizes = [ + ImageProcessor.get_num_patches_from_dims( + *ImageProcessor.get_resize_dims( + *image_length, self._image_height, self._image_width, self._patch_size + ), + self._patch_size, + ) + for image_length in image_lengths + ] + image_tokens = sum(image_sizes) if not self._truncate_documents: - if document_size > self._sequence_length + 1: + if document_size + image_tokens > self._sequence_length + 1: # Document too long, ignore document_sampling_index += 1 continue tokens_in_sample = token_count % (self._sequence_length + 1) - if document_size + tokens_in_sample > self._sequence_length + 1: + if document_size + image_tokens + tokens_in_sample > self._sequence_length + 1: # Document belongs to the next sample, need to account for padding. padding_size = self._sequence_length + 1 - tokens_in_sample if token_count > token_start: @@ -408,7 +423,7 @@ def __getitem__(self, index: int) -> typing.Any: token_count += padding_size # Determine if the document belongs to the requested sample. - if token_count + document_size >= token_start: + if token_count + document_size + image_tokens >= token_start: # Determine which part of the document belong to the sample, and add it to the list. token_start_index_in_document = max(token_start - token_count, 0) token_end_index_in_document = min(token_end - token_count, document_size) @@ -422,7 +437,7 @@ def __getitem__(self, index: int) -> typing.Any: for idx, im_position in enumerate(sample.image_positions): # image_positions.append(im_positions + len(token_ids) + image_tokens_added) image_positions.append(im_position + len(token_ids) + image_tokens_added) - image_tokens_added += ImageProcessor.get_num_patches(sample.images[idx]) + image_tokens_added += image_tokens images.append(sample.images) token_ids.append(sample.token_ids) if self._config.use_loss_masking_spans: @@ -433,7 +448,7 @@ def __getitem__(self, index: int) -> typing.Any: # Go to the next document. 
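The budget checks in the hunk above now include image-patch tokens alongside text tokens. Without truncation support, the branch logic reduces to a small decision function; a simplified restatement for reference, not part of the patch:

def document_placement(document_size: int, image_tokens: int, tokens_in_sample: int, sequence_length: int) -> str:
    # "skip": the document (text + image patches) can never fit in a single sample.
    # "pad":  it fits, but not in the space left in the current sample, which gets padded out.
    # "fits": it can be appended to the current sample.
    total = document_size + image_tokens
    if total > sequence_length + 1:
        return "skip"
    if total + tokens_in_sample > sequence_length + 1:
        return "pad"
    return "fits"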
document_sampling_index += 1 - token_count += document_size + token_count += document_size + image_tokens sequence_lengths = ( np.array([ids.size - (idx == len(token_ids) - 1) for idx, ids in enumerate(token_ids)], dtype=np.int32) @@ -447,7 +462,8 @@ def __getitem__(self, index: int) -> typing.Any: if self._config.use_loss_masking_spans else None ) - images = [im for img_list in images for im in img_list] + images = [im for img_list in images for im in img_list] if images else None + image_positions = np.array(image_positions) if image_positions else None Assert.eq(len(token_ids) + image_tokens_added, self._sequence_length + 1) return GPTSample( diff --git a/fast_llm/data/image_processor.py b/fast_llm/data/image_processor.py index 567c8146..edfeceb9 100644 --- a/fast_llm/data/image_processor.py +++ b/fast_llm/data/image_processor.py @@ -33,7 +33,7 @@ def resize(self, image): # TODO Soham: move to utils @classmethod - def get_resize_dims(height, width, max_height, max_width, patch_size: list[int]): + def get_resize_dims(self, height, width, max_height, max_width, patch_size: list[int]): ratio = max(height / max_height, width / max_width) return ( (math.ceil(height / ratio), math.ceil(width / ratio)) @@ -47,5 +47,9 @@ def normalize(self, image: torch.Tensor) -> torch.Tensor: @classmethod # TODO Soham: move to utils - def get_num_patches(image: torch.Tensor, patch_size: list[int]) -> torch.Tensor: - return (image.size[0] // patch_size[0]) * (image.size[1] // patch_size[1]) + def get_num_patches(self, image: torch.Tensor, patch_size: list[int]) -> torch.Tensor: + return (image.shape[0] // patch_size[0]) * (image.shape[1] // patch_size[1]) + + @classmethod + def get_num_patches_from_dims(self, height: int, width: int, patch_size: list[int]) -> torch.Tensor: + return (height // patch_size[0]) * (width // patch_size[1]) diff --git a/fast_llm/engine/schedule/config.py b/fast_llm/engine/schedule/config.py index 83d3d51a..16cfaf71 100644 --- a/fast_llm/engine/schedule/config.py +++ b/fast_llm/engine/schedule/config.py @@ -55,6 +55,21 @@ class BatchConfig(Config): hint=FieldHint.performance, valid=check_field(Assert.gt, 0), ) + patch_size: list[int] | None = Field( + default=None, + desc="Patch size for each image token", + hint=FieldHint.optional, + ) + max_image_height: int | None = Field( + default=None, + desc="Maximum image height for each image token", + hint=FieldHint.optional, + ) + max_image_width: int | None = Field( + default=None, + desc="Maximum image width for each image token", + hint=FieldHint.optional, + ) num_micro_sequences: int = Field( init=False, desc="Number of micro-sequences to split each sample (= seqence length / micro-sequence length).", diff --git a/fast_llm/layers/language_model/config.py b/fast_llm/layers/language_model/config.py index 75c5418b..0175296c 100644 --- a/fast_llm/layers/language_model/config.py +++ b/fast_llm/layers/language_model/config.py @@ -125,6 +125,11 @@ class LanguageModelBaseConfig(LanguageModelArchitectureConfig, BaseModelConfig): architecture_class = LanguageModelArchitectureConfig transformer: TransformerConfig = FieldUpdate(default_factory=TransformerConfig) + vision_encoder: VisionArchitectureConfig | None = Field( + default=None, + desc="Configuration for the vision encoder that transforms images into embeddings.", + hint=FieldHint.optional, + ) init_method_std_embed: float = Field( default=None, desc="Initialization scale for the vocabulary embedding and output weights (logits).", @@ -200,8 +205,14 @@ def _validate(self) -> None: 
Assert.leq(self.init_method_min_embed, self.init_method_max_embed) super()._validate() + def setup_tensor_space(self, tensor_space: TensorSpace) -> None: + super().setup_tensor_space(tensor_space) + + if self.vision_encoder is not None: + self.vision_encoder.setup_tensor_space(tensor_space) + -class MultiModalBaseConfig: +class MultiModalBaseConfig(BaseModelConfig): language_model: LanguageModelBaseConfig = Field( default_factory=LanguageModelBaseConfig, desc="Configuration for the language model.", diff --git a/fast_llm/models/gpt/config.py b/fast_llm/models/gpt/config.py index 5a21368f..c90da81b 100644 --- a/fast_llm/models/gpt/config.py +++ b/fast_llm/models/gpt/config.py @@ -48,6 +48,10 @@ class MixtralGPTHuggingfaceCheckpointFormat(GPTHuggingfaceCheckpointFormat): name: typing.ClassVar[str] = "mixtral" +class PixtralGPTHuggingfaceCheckpointFormat(GPTHuggingfaceCheckpointFormat): + name: typing.ClassVar[str] = "pixtral" + + @config_class() class GPTArchitectureConfig(LanguageModelArchitectureConfig): _abstract = False diff --git a/fast_llm/models/gpt/conversion.py b/fast_llm/models/gpt/conversion.py index 30ae8041..30f54f06 100644 --- a/fast_llm/models/gpt/conversion.py +++ b/fast_llm/models/gpt/conversion.py @@ -32,6 +32,7 @@ LlamaGPTHuggingfaceCheckpointFormat, MistralGPTHuggingfaceCheckpointFormat, MixtralGPTHuggingfaceCheckpointFormat, + PixtralGPTHuggingfaceCheckpointFormat, Qwen2GPTHuggingfaceCheckpointFormat, Starcoder2GPTHuggingfaceCheckpointFormat, ) @@ -163,54 +164,65 @@ def _get_mlp_converters(self, fast_llm_prefix: str, hf_prefix: str) -> list[Weig def _create_weight_converters( self, + hf_base_prefix: str = "", + fast_llm_offset: int = 0, ) -> list[WeightConverter]: converters = [] num_layers = self._model.config.base_model.transformer.num_layers # Embeddings - converters.append(WeightConverter("layers.0.word_embeddings_weight", "model.embed_tokens.weight")) + converters.append( + WeightConverter( + f"layers.{fast_llm_offset}.word_embeddings_weight", f"{hf_base_prefix}model.embed_tokens.weight" + ) + ) - converters += self._create_lm_head_converters() + converters += self._create_lm_head_converters(hf_base_prefix, fast_llm_offset) for i in range(num_layers): - converters += self._create_transformer_layer_converters(i) + converters += self._create_transformer_layer_converters(i, hf_base_prefix, fast_llm_offset) return converters - def _create_transformer_layer_converters(self, i: int, ignore_export: bool = False) -> list[WeightConverter]: + def _create_transformer_layer_converters( + self, i: int, ignore_export: bool = False, hf_base_prefix: str = "", fast_llm_offset: int = 1 + ) -> list[WeightConverter]: transformer_config: TransformerConfig = self._model.config.base_model.transformer norm_bias: bool = self._model.config.base_model.transformer.normalization.type == NormalizationType.layer_norm converters = [] names_bias_cls = [ # Self-attn ( - f"layers.{i+1}.self_attn.query", - f"model.layers.{i}.self_attn.q_proj", + f"layers.{i+fast_llm_offset}.self_attn.query", + f"{hf_base_prefix}model.layers.{i}.self_attn.q_proj", transformer_config.add_attn_qkv_bias, QueryWeightConverter, ), ( - f"layers.{i+1}.self_attn.key_value", - (f"model.layers.{i}.self_attn.k_proj", f"model.layers.{i}.self_attn.v_proj"), + f"layers.{i+fast_llm_offset}.self_attn.key_value", + ( + f"{hf_base_prefix}model.layers.{i}.self_attn.k_proj", + f"{hf_base_prefix}model.layers.{i}.self_attn.v_proj", + ), transformer_config.add_attn_qkv_bias, KeyValueWeightConverter, ), ( - f"layers.{i+1}.self_attn.dense", - 
f"model.layers.{i}.self_attn.o_proj", + f"layers.{i+fast_llm_offset}.self_attn.dense", + f"{hf_base_prefix}model.layers.{i}.self_attn.o_proj", transformer_config.add_attn_dense_bias, WeightConverter, ), # Norm ( - f"layers.{i+1}.norm_1", - f"model.layers.{i}.input_layernorm", + f"layers.{i+fast_llm_offset}.norm_1", + f"{hf_base_prefix}model.layers.{i}.input_layernorm", norm_bias, WeightConverter, ), ( - f"layers.{i+1}.norm_2", - f"model.layers.{i}.post_attention_layernorm", + f"layers.{i+fast_llm_offset}.norm_2", + f"{hf_base_prefix}model.layers.{i}.post_attention_layernorm", norm_bias, WeightConverter, ), @@ -226,17 +238,23 @@ def _create_transformer_layer_converters(self, i: int, ignore_export: bool = Fal # MLP if ignore_export: converters += self._get_weight_and_bias_converters( - f"layers.{i+1}.mlp.layer_1", (), transformer_config.add_mlp_bias, cls=IgnoreExportWeightConverter + f"layers.{i+fast_llm_offset}.mlp.layer_1", + (), + transformer_config.add_mlp_bias, + cls=IgnoreExportWeightConverter, ) converters += self._get_weight_and_bias_converters( - f"layers.{i+1}.mlp.layer_2", (), transformer_config.add_mlp_bias, cls=IgnoreExportWeightConverter + f"layers.{i+fast_llm_offset}.mlp.layer_2", + (), + transformer_config.add_mlp_bias, + cls=IgnoreExportWeightConverter, ) - converters += [IgnoreExportWeightConverter(f"layers.{i+1}.mlp.router.weight", ())] + converters += [IgnoreExportWeightConverter(f"layers.{i+fast_llm_offset}.mlp.router.weight", ())] else: - converters += self._get_mlp_converters(f"layers.{i+1}", f"model.layers.{i}") + converters += self._get_mlp_converters(f"layers.{i+fast_llm_offset}", f"{hf_base_prefix}model.layers.{i}") return converters - def _create_lm_head_converters(self) -> list[WeightConverter]: + def _create_lm_head_converters(self, hf_base_prefix: str, fast_llm_offset: int = 1) -> list[WeightConverter]: num_layers = self._model.config.base_model.transformer.num_layers prediction_heads = self._model.config.base_model.prediction_heads norm_bias: bool = self._model.config.base_model.transformer.normalization.type == NormalizationType.layer_norm @@ -245,15 +263,20 @@ def _create_lm_head_converters(self) -> list[WeightConverter]: # Next-token prediction head # Final norm converters += self._get_weight_and_bias_converters( - f"layers.{num_layers + 1}.final_norm", "model.norm", norm_bias + f"layers.{num_layers + fast_llm_offset}.final_norm", f"{hf_base_prefix}model.norm", norm_bias ) # Output weights if self._model.config.base_model.tie_word_embeddings: - converters.append(IgnoreImportWeightConverter((), "lm_head.weight")) + converters.append(IgnoreImportWeightConverter((), f"{hf_base_prefix}lm_head.weight")) else: - converters.append(WeightConverter(f"layers.{num_layers + 1}.output_weights", "lm_head.weight")) + converters.append( + WeightConverter( + f"layers.{num_layers + fast_llm_offset}.output_weights", f"{hf_base_prefix}lm_head.weight" + ) + ) # MTP-heads > 0 are thrown away + # TODO Soham: handle offset with MTP for i in range(1, prediction_heads): logger.warning( f"The model weights for the multi-token prediction head {i} are discarded during conversion." 
@@ -531,6 +554,196 @@ def _get_mlp_converters(self, fast_llm_prefix: str, hf_prefix: str) -> list[Weig ] +class PixtralHuggingfaceCheckpointHandler(MistralHuggingfaceCheckpointHandler): + @classmethod + def _create_config_converters(cls) -> list[ParamConverter]: + lm_converters = super()._create_config_converters() + for converter in lm_converters: + if converter.fast_llm_names[0][0] == "transformer": + converter.export_names[0] = ("text_config", *converter.export_names[0]) + return lm_converters + [ + # Multimodal adapter + RenameParamConverter( + fast_llm_names=(("vision_encoder", "adapter_size"),), + export_names=( + ( + "text_config", + "hidden_size", + ) + ), + ), + # Image processing and conv layer + RenameParamConverter( + fast_llm_names=(("vision_encoder", "encoder", "image_size"),), + export_names=( + ( + "vision_config", + "image_size", + ) + ), + ), + RenameParamConverter( + fast_llm_names=(("vision_encoder", "encoder", "patch_size"),), + export_names=( + ( + "vision_config", + "patch_size", + ) + ), + ), + # Vision Transformer + RenameParamConverter( + fast_llm_names=(("vision_encoder", "encoder", "num_hidden_layers"),), + export_names=( + ( + "vision_config", + "num_hidden_layers", + ) + ), + ), + RenameParamConverter( + fast_llm_names=(("vision_encoder", "encoder", "hidden_size"),), + export_names=( + ( + "vision_config", + "hidden_size", + ) + ), + ), + RenameParamConverter( + fast_llm_names=(("vision_encoder", "encoder", "num_attention_heads"),), + export_names=( + ( + "vision_config", + "num_attention_heads", + ) + ), + ), + RenameParamConverter( + fast_llm_names=(("vision_encoder", "encoder", "intermediate_size"),), + export_names=( + ( + "vision_config", + "intermediate_size", + ) + ), + ), + MappedConfigParamConverter( + fast_llm_names=(("vision_encoder", "encoder", "activation_type"),), + export_names=( + ( + "vision_config", + "hidden_act", + ) + ), + fast_llm_value=ActivationType.from_hf_name, + export_value=lambda activation_type: activation_type.hf_name, + ), + RenameParamConverter( + fast_llm_names=(("vision_encoder", "encoder", "num_channels"),), + export_names=( + ( + "vision_config", + "num_channels", + ) + ), + ), + RenameParamConverter( + fast_llm_names=(("vision_encoder", "encoder", "attention_dropout"),), + export_names=( + ( + "vision_config", + "attention_dropout", + ) + ), + ), + RenameParamConverter( + fast_llm_names=(("vision_encoder", "encoder", "rope_theta"),), + export_names=(("vision_config", "rope_theta"),), + ), + RenameParamConverter( + fast_llm_names=(("vision_encoder", "encoder", "initializer_range"),), + export_names=(("vision_config", "initializer_range"),), + ), + ] + + def _create_vision_transformer_converters(self) -> list[WeightConverter]: + num_layers = self._model.config.base_model.vision_encoder.encoder.num_hidden_layers + vision_transformer_converters = [] + for i in range(num_layers): + vision_transformer_converters += [ + WeightConverter( + f"layers.0._vision_encoder.encoder.transformer.layers.{i}.attention.k_proj.weight", + f"vision_tower.transformer.layers.{i}.attention.k_proj.weight", + ), + WeightConverter( + f"layers.0._vision_encoder.encoder.transformer.layers.{i}.attention.v_proj.weight", + f"vision_tower.transformer.layers.{i}.attention.v_proj.weight", + ), + WeightConverter( + f"layers.0._vision_encoder.encoder.transformer.layers.{i}.attention.q_proj.weight", + f"vision_tower.transformer.layers.{i}.attention.q_proj.weight", + ), + WeightConverter( + 
f"layers.0._vision_encoder.encoder.transformer.layers.{i}.attention.o_proj.weight", + f"vision_tower.transformer.layers.{i}.attention.o_proj.weight", + ), + WeightConverter( + f"layers.0._vision_encoder.encoder.transformer.layers.{i}.attention_norm.weight", + f"vision_tower.transformer.layers.{i}.attention_norm.weight", + ), + WeightConverter( + f"layers.0._vision_encoder.encoder.transformer.layers.{i}.feed_forward.down_proj.weight", + f"vision_tower.transformer.layers.{i}.feed_forward.down_proj.weight", + ), + WeightConverter( + f"layers.0._vision_encoder.encoder.transformer.layers.{i}.feed_forward.gate_proj.weight", + f"vision_tower.transformer.layers.{i}.feed_forward.gate_proj.weight", + ), + WeightConverter( + f"layers.0._vision_encoder.encoder.transformer.layers.{i}.feed_forward.up_proj.weight", + f"vision_tower.transformer.layers.{i}.feed_forward.up_proj.weight", + ), + WeightConverter( + f"layers.0._vision_encoder.encoder.transformer.layers.{i}.ffn_norm.weight", + f"vision_tower.transformer.layers.{i}.ffn_norm.weight", + ), + ] + + return vision_transformer_converters + + def _create_vision_encoder_weight_converters(self) -> list[WeightConverter]: + patch_conv_converter = WeightConverter( + "layers.0._vision_encoder.patch_conv.weight", + "vision_tower.patch_conv.weight", + ) + vision_transformer_converters = self._create_vision_transformer_converters() + adapter_converters = [ + WeightConverter( + "layers.0._vision_encoder._adapter.layer_1.weight", + "multi_modal_projector.linear_1.weight", + ), + WeightConverter( + "layers.0._vision_encoder._adapter.layer_1.bias", + "multi_modal_projector.linear_1.bias", + ), + WeightConverter( + "layers.0._vision_encoder._adapter.layer_2.weight", + "multi_modal_projector.linear_2.weight", + ), + WeightConverter( + "layers.0._vision_encoder._adapter.layer_2.bias", + "multi_modal_projector.linear_2.bias", + ), + ] + return [patch_conv_converter] + vision_transformer_converters + adapter_converters + + def _create_weight_converters(self) -> list[WeightConverter]: + vision_encoder_converter = self._create_vision_encoder_weight_converters() + lm_converters = super()._create_weight_converters(hf_base_prefix="language_model.", fast_llm_offset=1) + return vision_encoder_converter + lm_converters + + class MixtralHuggingfaceCheckpointHandler(CommonLlamaHuggingfaceCheckpointHandler): format: typing.ClassVar[type[CheckpointFormat]] = MixtralGPTHuggingfaceCheckpointFormat @@ -580,4 +793,5 @@ class AutoGPTHuggingfaceCheckpointHandler( Qwen2GPTHuggingfaceCheckpointFormat.name: Qwen2HuggingfaceCheckpointHandler, MistralGPTHuggingfaceCheckpointFormat.name: MistralHuggingfaceCheckpointHandler, MixtralGPTHuggingfaceCheckpointFormat.name: MixtralHuggingfaceCheckpointHandler, + PixtralGPTHuggingfaceCheckpointFormat.name: PixtralHuggingfaceCheckpointHandler, } diff --git a/fast_llm/models/gpt/model.py b/fast_llm/models/gpt/model.py index e878530c..67411641 100644 --- a/fast_llm/models/gpt/model.py +++ b/fast_llm/models/gpt/model.py @@ -26,6 +26,7 @@ RotaryEmbeddingPreprocessor, ) from fast_llm.layers.transformer.transformer import TransformerLayer +from fast_llm.layers.vision_encoder.encoder import VisionEncoder from fast_llm.models.gpt.config import GPTBaseModelConfig, GPTModelConfig from fast_llm.models.gpt.megatron import get_init_megatron from fast_llm.tensor import ParameterMeta, TensorMeta @@ -100,7 +101,10 @@ def get_layers(self) -> list[Layer]: LanguageModelEmbedding(self._config, self._tensor_space), LanguageModelHead(self._config, self._tensor_space, 0), ] - 
return [ + return ( + [VisionEncoder(self._config, self._tensor_space)] if self._config.vision_encoder is not None else [] + ) + [ + # return [ LanguageModelEmbedding(self._config, self._tensor_space), *[ TransformerLayer( @@ -312,11 +316,11 @@ def preprocess( @property def embedding(self) -> LanguageModelEmbedding: - return self.layers[0] + return self.layers[self._config.vision_encoder is not None] @property def transformer_layers(self) -> list[TransformerLayer]: - return self.layers[1:-1] + return self.layers[(self._config.vision_encoder is not None) + 1 : -1] @property def model_head(self) -> LanguageModelHead: @@ -331,7 +335,7 @@ def get_tied_weights(self) -> dict[str, tuple[ParameterMeta, tuple[int, ...]]]: return { WORD_EMBEDDINGS_WEIGHT: ( self.embedding.word_embeddings_weight, - (0, *self.model_head_indices), + (self._config.vision_encoder is not None, *self.model_head_indices), ) } elif self._config.prediction_heads > 1: diff --git a/fast_llm/models/gpt/trainer.py b/fast_llm/models/gpt/trainer.py index 376d8b84..b801fbd3 100644 --- a/fast_llm/models/gpt/trainer.py +++ b/fast_llm/models/gpt/trainer.py @@ -21,6 +21,9 @@ def _get_data(self) -> GPTData: vocab_size=self._config.model.base_model.vocab_size, max_sequence_length=self._config.batch.sequence_length, cross_document_attention=self._config.batch.cross_document_attention, + patch_size=self._config.batch.patch_size, + max_image_height=self._config.batch.max_image_height, + max_image_width=self._config.batch.max_image_width, ) def get_tflops(self, phase: PhaseType, elapsed_time_per_iteration) -> tuple[int, int]: From 6521e41920fe8b17f207b32f58c43978bfcc8a46 Mon Sep 17 00:00:00 2001 From: root Date: Thu, 24 Apr 2025 19:00:23 +0000 Subject: [PATCH 06/82] vision model --- fast_llm/layers/vision_encoder/adapter.py | 44 ++++++++ fast_llm/layers/vision_encoder/config.py | 128 ++++++++++++++++++++++ fast_llm/layers/vision_encoder/encoder.py | 89 +++++++++++++++ 3 files changed, 261 insertions(+) create mode 100644 fast_llm/layers/vision_encoder/adapter.py create mode 100644 fast_llm/layers/vision_encoder/config.py create mode 100644 fast_llm/layers/vision_encoder/encoder.py diff --git a/fast_llm/layers/vision_encoder/adapter.py b/fast_llm/layers/vision_encoder/adapter.py new file mode 100644 index 00000000..234c451a --- /dev/null +++ b/fast_llm/layers/vision_encoder/adapter.py @@ -0,0 +1,44 @@ +import typing + +import torch + +from fast_llm.engine.base_model.base_model import Layer +from fast_llm.engine.config_utils.tensor_space import TensorSpace +from fast_llm.layers.common.linear import LinearBase +from fast_llm.layers.transformer.config import TransformerDimNames +from fast_llm.layers.vision_encoder.config import VisionEncoderDimNames +from fast_llm.tensor import init_normal_ + + +class VisionAdapter(Layer): + """ + Vision adapter layer for the LLM. 
+ """ + + def __init__(self, intermediate_size: int, tensor_space: TensorSpace, name: str = "vision_adapter"): + super().__init__() + self._name = name + input_dim = tensor_space.get_tensor_dim(VisionEncoderDimNames.out_channels) + self.layer_1 = LinearBase( + input_dim, + tensor_space.get_tensor_dim(VisionEncoderDimNames.intermediate_size), + bias=True, + weight_init_method=init_normal_(), + bias_init_method=init_normal_(), + ) + self.layer_2 = LinearBase( + tensor_space.get_tensor_dim(VisionEncoderDimNames.intermediate_size), + tensor_space.get_tensor_dim(TransformerDimNames.hidden), + bias=True, + weight_init_method=init_normal_(), + bias_init_method=init_normal_(), + ) + + def forward( + self, + input_: torch.Tensor, + kwargs: dict[str, typing.Any], + losses: dict[str, typing.Any] | None = None, + metrics: dict[str, typing.Any] | None = None, + ): + return self.layer_2(self.layer_1(input_)) diff --git a/fast_llm/layers/vision_encoder/config.py b/fast_llm/layers/vision_encoder/config.py new file mode 100644 index 00000000..d410f92d --- /dev/null +++ b/fast_llm/layers/vision_encoder/config.py @@ -0,0 +1,128 @@ +from fast_llm.config import Field, FieldHint, config_class +from fast_llm.engine.base_model.config import BaseModelArchitectureConfig +from fast_llm.engine.config_utils.tensor_space import TensorDim, TensorSpace +from fast_llm.functional.config import ActivationType + + +class VisionEncoderDimNames: + out_channels = "vision_out_channels" + intermediate_size = "vision_intermediate_size" + patch_height = "vision_patch_height" + patch_width = "vision_patch_width" + + +@config_class() +class PatchConvConfig(BaseModelArchitectureConfig): + _abstract = False + """ + Configuration class for the convolution layers to apply on the image patches + """ + in_channels: int = Field( + default=3, + desc="Number of input channels for the convolution layers. 
Typically 3 for RGB images.", + hint=FieldHint.optional, + ) + bias: bool = Field( + default=False, desc="Whether to use a bias term in the convolution layers.", hint=FieldHint.optional + ) + height: int = Field( + default=16, + desc="Height of the image patches considered as tokens", + ) + width: int | None = Field( + default=16, + desc="Width of the image patches considered as tokens", + ) + + +@config_class() +class VisionEncoderArchitectureConfig(BaseModelArchitectureConfig): + _abstract = False + """ + Configuration class for the vision encoder, which transforms images into embeddings + """ + path: str | None = Field(default=None, desc="Path to a pretrained vision encoder model.", hint=FieldHint.optional) + hidden_size: int = Field( + default=1024, desc="The size of the hidden layers in the transformer model.", hint=FieldHint.optional + ) + intermediate_size: int = Field( + default=4096, + desc="The size of the intermediate (feed-forward) layers in the transformer model.", + hint=FieldHint.optional, + ) + num_hidden_layers: int = Field( + default=24, desc="The number of hidden layers in the transformer model.", hint=FieldHint.optional + ) + num_attention_heads: int = Field( + default=16, desc="The number of attention heads for the multi-head attention layers.", hint=FieldHint.optional + ) + num_channels: int = Field( + default=3, desc="Number of channels in the input image, typically 3 for RGB.", hint=FieldHint.optional + ) + image_size: int = Field( + default=1024, desc="The size of the input images (assumed square).", hint=FieldHint.optional + ) + patch_size: int = Field(default=16, desc="The size of the image patches to be encoded.", hint=FieldHint.optional) + hidden_act: str = Field( + default="gelu", desc="The activation function used in the hidden layers.", hint=FieldHint.optional + ) + attention_dropout: float = Field( + default=0.0, desc="The dropout probability for attention layers.", hint=FieldHint.optional + ) + rope_theta: float = Field( + default=10000.0, desc="The base value for rotary position embeddings.", hint=FieldHint.optional + ) + initializer_range: float = Field( + default=0.02, desc="The standard deviation of the normal initializer.", hint=FieldHint.optional + ) + + +@config_class() +class VisionArchitectureConfig(BaseModelArchitectureConfig): + _abstract = False + + encoder: VisionEncoderArchitectureConfig = Field( + default_factory=VisionEncoderArchitectureConfig, + desc="Configuration for the vision encoder that transforms images into embeddings.", + hint=FieldHint.optional, + ) + adapter_size: int = Field( + default=5120, + desc="Intermediate size for the adapter linear layers. Assuming 2 linear layers", + hint=FieldHint.optional, + ) + adapter_activation_type: ActivationType = Field( + default=ActivationType.gelu, + desc="The intermediate activation type for multi-modal adapter. Default: GeLU.", + hint=FieldHint.core, + ) + + def setup_tensor_space(self, tensor_space: TensorSpace): + tensor_space.add_tensor_dim(TensorDim(VisionEncoderDimNames.out_channels, self.encoder.hidden_size)) + tensor_space.add_tensor_dim(TensorDim(VisionEncoderDimNames.intermediate_size, self.adapter_size)) + tensor_space.add_tensor_dim(TensorDim(VisionEncoderDimNames.patch_height, self.encoder.patch_size)) + tensor_space.add_tensor_dim(TensorDim(VisionEncoderDimNames.patch_width, self.encoder.patch_size)) + # tensor_space.add_tensor_dim( + # CompositeTensorDim(VisionEncoderDimNames.) 
+ # ) + + # patch_convolution: PatchConvConfig = Field( + # default_factory=PatchConvConfig, + # desc="Configuration for the convolution layers applied to the image patches.", + # hint=FieldHint.optional + # ) + # normalization: NormalizationArchitectureConfig = Field( + # default_factory=NormalizationArchitectureConfig, + # desc="Configuration for the normalization layers applied to the image patches.", + # hint=FieldHint.optional + # ) + # transformer: TransformerArchitectureConfig = Field( + # default_factory=TransformerArchitectureConfig, + # desc="Configuration for the transformer layers applied to the image patches.", + # hint=FieldHint.optional + # ) + # patch_rotary: RotaryArchitectureConfig = Field( + # default_factory=RotaryArchitectureConfig, + # desc="Configuration for the rotary positional embeddings applied to the image patches.", + # hint=FieldHint.optional + # ) diff --git a/fast_llm/layers/vision_encoder/encoder.py b/fast_llm/layers/vision_encoder/encoder.py new file mode 100644 index 00000000..2ea5c1e4 --- /dev/null +++ b/fast_llm/layers/vision_encoder/encoder.py @@ -0,0 +1,89 @@ +import functools +import typing + +import torch +from transformers import PixtralVisionConfig, PixtralVisionModel + +from fast_llm.engine.base_model.base_model import Layer +from fast_llm.engine.config_utils.tensor_space import TensorDim, TensorSpace +from fast_llm.layers.language_model.config import LanguageModelBaseConfig +from fast_llm.layers.transformer.config import TransformerKwargs +from fast_llm.layers.vision_encoder.adapter import VisionAdapter +from fast_llm.layers.vision_encoder.config import VisionEncoderDimNames +from fast_llm.tensor import ParameterMeta, TensorMeta, init_normal_ + + +class VisionEncoder(Layer): + """ + A vision encoder layer for creating token embeddings from vision model + """ + + def __init__(self, config: LanguageModelBaseConfig, tensor_space: TensorSpace): + super().__init__() + + self._config = config.vision_encoder + self._distributed_config = tensor_space.distributed_config + with torch.device("meta"): + if self._config.encoder.path: + self._vision_encoder = PixtralVisionModel.from_pretrained( + self._config.encoder.path, torch_dtype=self._distributed_config.training_dtype.torch + ) + else: + self._vision_encoder = PixtralVisionModel( + PixtralVisionConfig( + hidden_size=self._config.hidden_size, + intermediate_size=self._config.intermediate_size, + num_hidden_layers=self._config.num_hidden_layers, + num_attention_heads=self._config.num_attention_heads, + num_channels=self._config.num_channels, + image_size=self._config.image_size, + patch_size=self._config.patch_size, + hidden_act=self._config.hidden_act, + attention_dropout=self._config.attention_dropout, + rope_theta=self._config.rope_theta, + initializer_range=self._config.initializer_range, + ) + ) + param_names = [] + # gather all names first. 
PyTorch complains if we do it in the loop + for name, param in self._vision_encoder.named_parameters(): + param_names.append(name) + for name in param_names: + # exclude .weight/.bias + *module_path, stem = name.split(".")[:-1] + module = functools.reduce(getattr, module_path, self._vision_encoder) + param = self._vision_encoder.get_parameter(name) + setattr( + module, + stem, + ParameterMeta.from_dims( + tuple(TensorDim(f"{name}_{idx}", size) for idx, size in enumerate(param.shape)), + init_method=init_normal_(), + ), + # ParameterMeta( + # param, + # tensor_name=name, + # dims=(TensorDim(f"{name}_{idx}", size) for idx, size in enumerate(param.shape)), + # init_method=init_normal_(), + # allow_no_grad=True, + # ), + ) + self._adapter = VisionAdapter( + intermediate_size=tensor_space.get_tensor_dim(VisionEncoderDimNames.intermediate_size), + tensor_space=tensor_space, + ) + + def forward( + self, + input_: torch.Tensor, + kwargs: dict[str, typing.Any], + losses: dict[str, typing.Any] | None = None, + metrics: dict | None = None, + ) -> tuple[torch.Tensor, torch.Tensor | None]: + if isinstance(input_, TensorMeta): + return TensorMeta.from_dims( + kwargs[TransformerKwargs.hidden_dims], + tensor_name="Vision Output", + dtype=self._distributed_config.training_dtype.torch, + ) + return self._vision_encoder(input_) From daf586f8d6a428398674771bce71a61a7e32cdbf Mon Sep 17 00:00:00 2001 From: root Date: Thu, 24 Apr 2025 21:49:28 +0000 Subject: [PATCH 07/82] wip --- fast_llm/models/gpt/config.py | 5 ++-- fast_llm/models/gpt/conversion.py | 43 ++++++++++++++++--------------- 2 files changed, 25 insertions(+), 23 deletions(-) diff --git a/fast_llm/models/gpt/config.py b/fast_llm/models/gpt/config.py index c90da81b..ca73b879 100644 --- a/fast_llm/models/gpt/config.py +++ b/fast_llm/models/gpt/config.py @@ -48,8 +48,8 @@ class MixtralGPTHuggingfaceCheckpointFormat(GPTHuggingfaceCheckpointFormat): name: typing.ClassVar[str] = "mixtral" -class PixtralGPTHuggingfaceCheckpointFormat(GPTHuggingfaceCheckpointFormat): - name: typing.ClassVar[str] = "pixtral" +class LlavaGPTHuggingfaceCheckpointFormat(GPTHuggingfaceCheckpointFormat): + name: typing.ClassVar[str] = "llava" @config_class() @@ -109,6 +109,7 @@ class GPTModelConfig(FastLLMModelConfig): Qwen2GPTHuggingfaceCheckpointFormat, MistralGPTHuggingfaceCheckpointFormat, MixtralGPTHuggingfaceCheckpointFormat, + LlavaGPTHuggingfaceCheckpointFormat, ) @classmethod diff --git a/fast_llm/models/gpt/conversion.py b/fast_llm/models/gpt/conversion.py index 30f54f06..ad74ad53 100644 --- a/fast_llm/models/gpt/conversion.py +++ b/fast_llm/models/gpt/conversion.py @@ -30,9 +30,9 @@ GPTArchitectureConfig, GPTModelConfig, LlamaGPTHuggingfaceCheckpointFormat, + LlavaGPTHuggingfaceCheckpointFormat, MistralGPTHuggingfaceCheckpointFormat, MixtralGPTHuggingfaceCheckpointFormat, - PixtralGPTHuggingfaceCheckpointFormat, Qwen2GPTHuggingfaceCheckpointFormat, Starcoder2GPTHuggingfaceCheckpointFormat, ) @@ -367,7 +367,7 @@ def _create_config_converters(cls) -> list[ParamConverter]: ), RenameParamConverter( fast_llm_names=(("transformer", "kv_channels"),), - export_names=(("head_dim"),), + export_names=(("head_dim",),), ), ConstantImportParamConverter(fast_llm_names=(("transformer", "gated"),), fast_llm_value=True), ConstantImportParamConverter(fast_llm_names=(("transformer", "add_linear_biases"),), fast_llm_value=False), @@ -554,23 +554,24 @@ def _get_mlp_converters(self, fast_llm_prefix: str, hf_prefix: str) -> list[Weig ] -class 
PixtralHuggingfaceCheckpointHandler(MistralHuggingfaceCheckpointHandler): +class LlavaHuggingfaceCheckpointHandler(MistralHuggingfaceCheckpointHandler): + format: typing.ClassVar[type[CheckpointFormat]] = LlavaGPTHuggingfaceCheckpointFormat + @classmethod def _create_config_converters(cls) -> list[ParamConverter]: lm_converters = super()._create_config_converters() for converter in lm_converters: - if converter.fast_llm_names[0][0] == "transformer": - converter.export_names[0] = ("text_config", *converter.export_names[0]) + if isinstance(converter, (RenameParamConverter, MappedConfigParamConverter, RopeScalingParamConverter)): + # Llava uses a different name for the text config + # if converter.fast_llm_names[0][0] == "transformer": + converter.export_names = (("text_config", *converter.export_names[0]), *converter.export_names[1:]) + # if converter.fast_llm_names[0][0] == "transformer": + # converter.export_names[0] = ("text_config", *converter.export_names[0]) return lm_converters + [ # Multimodal adapter RenameParamConverter( fast_llm_names=(("vision_encoder", "adapter_size"),), - export_names=( - ( - "text_config", - "hidden_size", - ) - ), + export_names=(("text_config", "hidden_size"),), ), # Image processing and conv layer RenameParamConverter( @@ -579,7 +580,7 @@ def _create_config_converters(cls) -> list[ParamConverter]: ( "vision_config", "image_size", - ) + ), ), ), RenameParamConverter( @@ -588,7 +589,7 @@ def _create_config_converters(cls) -> list[ParamConverter]: ( "vision_config", "patch_size", - ) + ), ), ), # Vision Transformer @@ -598,7 +599,7 @@ def _create_config_converters(cls) -> list[ParamConverter]: ( "vision_config", "num_hidden_layers", - ) + ), ), ), RenameParamConverter( @@ -607,7 +608,7 @@ def _create_config_converters(cls) -> list[ParamConverter]: ( "vision_config", "hidden_size", - ) + ), ), ), RenameParamConverter( @@ -616,7 +617,7 @@ def _create_config_converters(cls) -> list[ParamConverter]: ( "vision_config", "num_attention_heads", - ) + ), ), ), RenameParamConverter( @@ -625,7 +626,7 @@ def _create_config_converters(cls) -> list[ParamConverter]: ( "vision_config", "intermediate_size", - ) + ), ), ), MappedConfigParamConverter( @@ -634,7 +635,7 @@ def _create_config_converters(cls) -> list[ParamConverter]: ( "vision_config", "hidden_act", - ) + ), ), fast_llm_value=ActivationType.from_hf_name, export_value=lambda activation_type: activation_type.hf_name, @@ -645,7 +646,7 @@ def _create_config_converters(cls) -> list[ParamConverter]: ( "vision_config", "num_channels", - ) + ), ), ), RenameParamConverter( @@ -654,7 +655,7 @@ def _create_config_converters(cls) -> list[ParamConverter]: ( "vision_config", "attention_dropout", - ) + ), ), ), RenameParamConverter( @@ -793,5 +794,5 @@ class AutoGPTHuggingfaceCheckpointHandler( Qwen2GPTHuggingfaceCheckpointFormat.name: Qwen2HuggingfaceCheckpointHandler, MistralGPTHuggingfaceCheckpointFormat.name: MistralHuggingfaceCheckpointHandler, MixtralGPTHuggingfaceCheckpointFormat.name: MixtralHuggingfaceCheckpointHandler, - PixtralGPTHuggingfaceCheckpointFormat.name: PixtralHuggingfaceCheckpointHandler, + LlavaGPTHuggingfaceCheckpointFormat.name: LlavaHuggingfaceCheckpointHandler, } From ef4488d4f94b9c19b04f409917b3091b8e8601e8 Mon Sep 17 00:00:00 2001 From: root Date: Fri, 25 Apr 2025 22:47:08 +0000 Subject: [PATCH 08/82] wip --- fast_llm/layers/language_model/config.py | 7 ++- fast_llm/layers/vision_encoder/config.py | 11 +++++ fast_llm/layers/vision_encoder/encoder.py | 35 +++++++-------- 
fast_llm/models/gpt/conversion.py | 54 ++++++++++++++++------- 4 files changed, 69 insertions(+), 38 deletions(-) diff --git a/fast_llm/layers/language_model/config.py b/fast_llm/layers/language_model/config.py index 0175296c..ec80a933 100644 --- a/fast_llm/layers/language_model/config.py +++ b/fast_llm/layers/language_model/config.py @@ -44,6 +44,11 @@ class LanguageModelArchitectureConfig(BaseModelArchitectureConfig): desc="Configuration for the transformer architecture.", hint=FieldHint.core, ) + vision_encoder: None | VisionArchitectureConfig = Field( + default=None, + desc="Configuration for the vision encoder that transforms images into embeddings.", + hint=FieldHint.optional, + ) max_position_embeddings: int = Field( default=2048, desc="Number of absolute position embeddings, if applicable.", @@ -125,7 +130,7 @@ class LanguageModelBaseConfig(LanguageModelArchitectureConfig, BaseModelConfig): architecture_class = LanguageModelArchitectureConfig transformer: TransformerConfig = FieldUpdate(default_factory=TransformerConfig) - vision_encoder: VisionArchitectureConfig | None = Field( + vision_encoder: None | VisionArchitectureConfig = FieldUpdate( default=None, desc="Configuration for the vision encoder that transforms images into embeddings.", hint=FieldHint.optional, diff --git a/fast_llm/layers/vision_encoder/config.py b/fast_llm/layers/vision_encoder/config.py index d410f92d..76af3d37 100644 --- a/fast_llm/layers/vision_encoder/config.py +++ b/fast_llm/layers/vision_encoder/config.py @@ -2,6 +2,7 @@ from fast_llm.engine.base_model.config import BaseModelArchitectureConfig from fast_llm.engine.config_utils.tensor_space import TensorDim, TensorSpace from fast_llm.functional.config import ActivationType +from fast_llm.layers.common.config import NormalizationType class VisionEncoderDimNames: @@ -42,6 +43,11 @@ class VisionEncoderArchitectureConfig(BaseModelArchitectureConfig): Configuration class for the vision encoder, which transforms images into embeddings """ path: str | None = Field(default=None, desc="Path to a pretrained vision encoder model.", hint=FieldHint.optional) + pre_norm: NormalizationType = Field( + default=NormalizationType.rms_norm, + desc="The type of normalization to use before the transformer layers.", + hint=FieldHint.optional, + ) hidden_size: int = Field( default=1024, desc="The size of the hidden layers in the transformer model.", hint=FieldHint.optional ) @@ -75,6 +81,11 @@ class VisionEncoderArchitectureConfig(BaseModelArchitectureConfig): initializer_range: float = Field( default=0.02, desc="The standard deviation of the normal initializer.", hint=FieldHint.optional ) + activation_type: ActivationType = Field( + default=ActivationType.silu, + desc="The activation function used in the hidden layers. 
Default: SiLU.", + hint=FieldHint.optional, + ) @config_class() diff --git a/fast_llm/layers/vision_encoder/encoder.py b/fast_llm/layers/vision_encoder/encoder.py index 2ea5c1e4..88064b51 100644 --- a/fast_llm/layers/vision_encoder/encoder.py +++ b/fast_llm/layers/vision_encoder/encoder.py @@ -31,17 +31,17 @@ def __init__(self, config: LanguageModelBaseConfig, tensor_space: TensorSpace): else: self._vision_encoder = PixtralVisionModel( PixtralVisionConfig( - hidden_size=self._config.hidden_size, - intermediate_size=self._config.intermediate_size, - num_hidden_layers=self._config.num_hidden_layers, - num_attention_heads=self._config.num_attention_heads, - num_channels=self._config.num_channels, - image_size=self._config.image_size, - patch_size=self._config.patch_size, - hidden_act=self._config.hidden_act, - attention_dropout=self._config.attention_dropout, - rope_theta=self._config.rope_theta, - initializer_range=self._config.initializer_range, + hidden_size=self._config.encoder.hidden_size, + intermediate_size=self._config.encoder.intermediate_size, + num_hidden_layers=self._config.encoder.num_hidden_layers, + num_attention_heads=self._config.encoder.num_attention_heads, + num_channels=self._config.encoder.num_channels, + image_size=self._config.encoder.image_size, + patch_size=self._config.encoder.patch_size, + hidden_act=self._config.encoder.hidden_act, + attention_dropout=self._config.encoder.attention_dropout, + rope_theta=self._config.encoder.rope_theta, + initializer_range=self._config.encoder.initializer_range, ) ) param_names = [] @@ -49,8 +49,7 @@ def __init__(self, config: LanguageModelBaseConfig, tensor_space: TensorSpace): for name, param in self._vision_encoder.named_parameters(): param_names.append(name) for name in param_names: - # exclude .weight/.bias - *module_path, stem = name.split(".")[:-1] + *module_path, stem = name.split(".") module = functools.reduce(getattr, module_path, self._vision_encoder) param = self._vision_encoder.get_parameter(name) setattr( @@ -60,14 +59,10 @@ def __init__(self, config: LanguageModelBaseConfig, tensor_space: TensorSpace): tuple(TensorDim(f"{name}_{idx}", size) for idx, size in enumerate(param.shape)), init_method=init_normal_(), ), - # ParameterMeta( - # param, - # tensor_name=name, - # dims=(TensorDim(f"{name}_{idx}", size) for idx, size in enumerate(param.shape)), - # init_method=init_normal_(), - # allow_no_grad=True, - # ), ) + none_params = [key for key, value in module._parameters.items() if value is None] + for key in none_params: + module._parameters.pop(key) self._adapter = VisionAdapter( intermediate_size=tensor_space.get_tensor_dim(VisionEncoderDimNames.intermediate_size), tensor_space=tensor_space, diff --git a/fast_llm/models/gpt/conversion.py b/fast_llm/models/gpt/conversion.py index ad74ad53..f730d79c 100644 --- a/fast_llm/models/gpt/conversion.py +++ b/fast_llm/models/gpt/conversion.py @@ -173,14 +173,16 @@ def _create_weight_converters( # Embeddings converters.append( WeightConverter( - f"layers.{fast_llm_offset}.word_embeddings_weight", f"{hf_base_prefix}model.embed_tokens.weight" + f"layers.{fast_llm_offset - 1}.word_embeddings_weight", f"{hf_base_prefix}model.embed_tokens.weight" ) ) converters += self._create_lm_head_converters(hf_base_prefix, fast_llm_offset) for i in range(num_layers): - converters += self._create_transformer_layer_converters(i, hf_base_prefix, fast_llm_offset) + converters += self._create_transformer_layer_converters( + i, hf_base_prefix=hf_base_prefix, fast_llm_offset=fast_llm_offset + ) return 
converters @@ -560,6 +562,9 @@ class LlavaHuggingfaceCheckpointHandler(MistralHuggingfaceCheckpointHandler): @classmethod def _create_config_converters(cls) -> list[ParamConverter]: lm_converters = super()._create_config_converters() + lm_converters[-2] = ConstantExportParamConverter( + export_names=(("architectures",),), export_value=["LlavaForConditionalGeneration"] + ) for converter in lm_converters: if isinstance(converter, (RenameParamConverter, MappedConfigParamConverter, RopeScalingParamConverter)): # Llava uses a different name for the text config @@ -674,39 +679,39 @@ def _create_vision_transformer_converters(self) -> list[WeightConverter]: for i in range(num_layers): vision_transformer_converters += [ WeightConverter( - f"layers.0._vision_encoder.encoder.transformer.layers.{i}.attention.k_proj.weight", + f"layers.0._vision_encoder.transformer.layers.{i}.attention.k_proj.weight", f"vision_tower.transformer.layers.{i}.attention.k_proj.weight", ), WeightConverter( - f"layers.0._vision_encoder.encoder.transformer.layers.{i}.attention.v_proj.weight", + f"layers.0._vision_encoder.transformer.layers.{i}.attention.v_proj.weight", f"vision_tower.transformer.layers.{i}.attention.v_proj.weight", ), WeightConverter( - f"layers.0._vision_encoder.encoder.transformer.layers.{i}.attention.q_proj.weight", + f"layers.0._vision_encoder.transformer.layers.{i}.attention.q_proj.weight", f"vision_tower.transformer.layers.{i}.attention.q_proj.weight", ), WeightConverter( - f"layers.0._vision_encoder.encoder.transformer.layers.{i}.attention.o_proj.weight", + f"layers.0._vision_encoder.transformer.layers.{i}.attention.o_proj.weight", f"vision_tower.transformer.layers.{i}.attention.o_proj.weight", ), WeightConverter( - f"layers.0._vision_encoder.encoder.transformer.layers.{i}.attention_norm.weight", + f"layers.0._vision_encoder.transformer.layers.{i}.attention_norm.weight", f"vision_tower.transformer.layers.{i}.attention_norm.weight", ), WeightConverter( - f"layers.0._vision_encoder.encoder.transformer.layers.{i}.feed_forward.down_proj.weight", + f"layers.0._vision_encoder.transformer.layers.{i}.feed_forward.down_proj.weight", f"vision_tower.transformer.layers.{i}.feed_forward.down_proj.weight", ), WeightConverter( - f"layers.0._vision_encoder.encoder.transformer.layers.{i}.feed_forward.gate_proj.weight", + f"layers.0._vision_encoder.transformer.layers.{i}.feed_forward.gate_proj.weight", f"vision_tower.transformer.layers.{i}.feed_forward.gate_proj.weight", ), WeightConverter( - f"layers.0._vision_encoder.encoder.transformer.layers.{i}.feed_forward.up_proj.weight", + f"layers.0._vision_encoder.transformer.layers.{i}.feed_forward.up_proj.weight", f"vision_tower.transformer.layers.{i}.feed_forward.up_proj.weight", ), WeightConverter( - f"layers.0._vision_encoder.encoder.transformer.layers.{i}.ffn_norm.weight", + f"layers.0._vision_encoder.transformer.layers.{i}.ffn_norm.weight", f"vision_tower.transformer.layers.{i}.ffn_norm.weight", ), ] @@ -718,30 +723,45 @@ def _create_vision_encoder_weight_converters(self) -> list[WeightConverter]: "layers.0._vision_encoder.patch_conv.weight", "vision_tower.patch_conv.weight", ) + # TODO Soham: use _get_weight_and_bias_converters? 
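The TODO above asks whether these hand-written pairs could go through a _get_weight_and_bias_converters-style helper. A minimal sketch of such a helper, assuming only that a converter pairs one Fast-LLM parameter name with one exported name; SimpleWeightConverter and the helper name are illustrative stand-ins, not the repository's actual API:

    import dataclasses


    @dataclasses.dataclass
    class SimpleWeightConverter:
        # Stand-in for the WeightConverter used throughout conversion.py: one
        # Fast-LLM parameter name mapped to one exported (Hugging Face) name.
        fast_llm_name: str
        export_name: str


    def get_weight_and_bias_converters(fast_llm_prefix: str, hf_prefix: str, use_bias: bool):
        # Emit the ".weight" converter and, only when the layer has a bias
        # (e.g. LayerNorm rather than RMSNorm), the matching ".bias" converter.
        converters = [SimpleWeightConverter(f"{fast_llm_prefix}.weight", f"{hf_prefix}.weight")]
        if use_bias:
            converters.append(SimpleWeightConverter(f"{fast_llm_prefix}.bias", f"{hf_prefix}.bias"))
        return converters


    # Example: the pre-transformer norm of the vision tower.
    ln_pre_converters = get_weight_and_bias_converters(
        "layers.0._vision_encoder.ln_pre", "vision_tower.ln_pre", use_bias=True
    )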
+ layer_norm_converter = WeightConverter( + "layers.0._vision_encoder.ln_pre.weight", + "vision_tower.ln_pre.weight", + ) + if self._model.config.base_model.vision_encoder.encoder.pre_norm == NormalizationType.layer_norm: + layer_norm_bias_converter = WeightConverter( + "layers.0._vision_encoder.ln_pre.bias", + "vision_tower.ln_pre.bias", + ) vision_transformer_converters = self._create_vision_transformer_converters() adapter_converters = [ WeightConverter( - "layers.0._vision_encoder._adapter.layer_1.weight", + "layers.0._adapter.layer_1.weight", "multi_modal_projector.linear_1.weight", ), WeightConverter( - "layers.0._vision_encoder._adapter.layer_1.bias", + "layers.0._adapter.layer_1.bias", "multi_modal_projector.linear_1.bias", ), + # TODO Soham: conditionally add bias WeightConverter( - "layers.0._vision_encoder._adapter.layer_2.weight", + "layers.0._adapter.layer_2.weight", "multi_modal_projector.linear_2.weight", ), WeightConverter( - "layers.0._vision_encoder._adapter.layer_2.bias", + "layers.0._adapter.layer_2.bias", "multi_modal_projector.linear_2.bias", ), ] - return [patch_conv_converter] + vision_transformer_converters + adapter_converters + return ( + [patch_conv_converter, layer_norm_converter, layer_norm_bias_converter] + + vision_transformer_converters + + adapter_converters + ) def _create_weight_converters(self) -> list[WeightConverter]: vision_encoder_converter = self._create_vision_encoder_weight_converters() - lm_converters = super()._create_weight_converters(hf_base_prefix="language_model.", fast_llm_offset=1) + lm_converters = super()._create_weight_converters(hf_base_prefix="language_model.", fast_llm_offset=2) return vision_encoder_converter + lm_converters From 6d9d595b921bc5139a910fe843d6ece3403445fb Mon Sep 17 00:00:00 2001 From: root Date: Mon, 28 Apr 2025 15:23:50 +0000 Subject: [PATCH 09/82] missing files --- fast_llm/data/dataset/gpt/memmap.py | 6 +- fast_llm/engine/multi_stage/stage_base.py | 3 + fast_llm/layers/multi_modal/embedding.py | 83 +++++++++++++++++++ fast_llm/layers/vision_encoder/config.py | 57 ++++++++++++- fast_llm/layers/vision_encoder/encoder.py | 26 +++--- .../layers/vision_encoder/preprocessing.py | 74 +++++++++++++++++ fast_llm/models/gpt/conversion.py | 44 +++++----- fast_llm/models/gpt/model.py | 59 +++++++++++-- setup.cfg | 2 +- 9 files changed, 309 insertions(+), 45 deletions(-) create mode 100644 fast_llm/layers/multi_modal/embedding.py create mode 100644 fast_llm/layers/vision_encoder/preprocessing.py diff --git a/fast_llm/data/dataset/gpt/memmap.py b/fast_llm/data/dataset/gpt/memmap.py index 43fba843..99bfbfa4 100644 --- a/fast_llm/data/dataset/gpt/memmap.py +++ b/fast_llm/data/dataset/gpt/memmap.py @@ -200,7 +200,7 @@ def get( for image_length in self._image_lengths[idx]: # TODO Soham: verify reshape dimension order n_pixels = image_length.prod(initial=3) - images.append(pixels[start : start + n_pixels].reshape(image_length[0], image_length[1], 3)) + images.append(pixels[start : start + n_pixels].reshape(3, image_length[0], image_length[1])) start += n_pixels # TODO Soham: return loss_masking_spans return GPTSample(token_ids=token_ids, images=images, image_positions=image_positions) @@ -296,8 +296,8 @@ def write_dataset(cls, prefix: pathlib.Path | str, documents: typing.Iterable[GP for image in document.images: # assume 3 channels (RGB) for all images with PIL.Image.open(io.BytesIO(image["bytes"])) as img: - pixels = np.array(img) - image_lengths.append(np.array(pixels.shape[:2])) + pixels = np.array(img).transpose(2, 0, 1) # HWC 
to CHW + image_lengths.append(np.array(pixels.shape[1:])) bin_stream.write(pixels.tobytes(order="C")) total_im_size += pixels.size im_positions.append(document.image_positions) diff --git a/fast_llm/engine/multi_stage/stage_base.py b/fast_llm/engine/multi_stage/stage_base.py index 0f83c862..e97ef041 100644 --- a/fast_llm/engine/multi_stage/stage_base.py +++ b/fast_llm/engine/multi_stage/stage_base.py @@ -161,6 +161,9 @@ def _replace(module: torch.nn.Module): nonlocal i for key in module._parameters: meta = typing.cast(ParameterMeta, module._parameters[key]) + # TODO Soham: clean way to get around check? + if meta is None: + continue module._parameters[key] = self.get_parameter_buffer(meta.tensor_name) i += 1 diff --git a/fast_llm/layers/multi_modal/embedding.py b/fast_llm/layers/multi_modal/embedding.py new file mode 100644 index 00000000..a92fdc4e --- /dev/null +++ b/fast_llm/layers/multi_modal/embedding.py @@ -0,0 +1,83 @@ +import typing + +import torch + +from fast_llm.core.distributed import set_generator +from fast_llm.engine.config_utils.tensor_space import TensorSpace +from fast_llm.layers.language_model.config import LanguageModelBaseConfig, LanguageModelKwargs +from fast_llm.layers.language_model.embedding import LanguageModelEmbedding +from fast_llm.layers.transformer.config import TransformerKwargs +from fast_llm.layers.vision_encoder.config import VisionModelKwargs +from fast_llm.layers.vision_encoder.encoder import VisionEncoder +from fast_llm.tensor import TensorMeta + + +class MultiModalEmbedding(LanguageModelEmbedding): + """ + Multi-modal embedding layer to combine embeddings from text, image and more modalities. + """ + + def __init__( + self, + config: LanguageModelBaseConfig, + tensor_space: TensorSpace, + ): + super().__init__(config, tensor_space) + self.vision_encoder = VisionEncoder(config, tensor_space) + + def _forward( + self, + input_: torch.Tensor, + position_ids: torch.Tensor | None, + images: torch.Tensor | None, + image_sizes: torch.Tensor | None, + image_positions: list[torch.Tensor] | None, + ) -> torch.Tensor: + image_embeddings = self.vision_encoder(images, kwargs={VisionModelKwargs.image_sizes: image_sizes}) + # TODO Soham: offset position ids + img_tokens_seen = 0 + image_idx = 0 + text_embeddings = super()._forward(input_, position_ids) + embeddings = [] + for sample_idx, positions in enumerate(image_positions): + embedding_parts = [] + for position in positions: + image = images[image_idx] + image_tokens = (image.size[1] // self._config.vision_encoder.encoder.patch_size) * ( + image.size[2] // self._config.vision_encoder.encoder.patch_size + ) + image_idx += 1 + img_tokens_seen += image_tokens + embedding_parts.append(text_embeddings[sample_idx, :position]) + embedding_parts.append(image_embeddings[img_tokens_seen : img_tokens_seen + image_tokens]) + embedding_parts.append(text_embeddings[sample_idx, position + image_tokens :]) + embeddings.append(torch.cat(embedding_parts, dim=0)) + embeddings = torch.stack(embeddings, dim=0) + with set_generator( + self._tensor_space.distributed.tp_generator + if self._sequence_parallel + else self._tensor_space.distributed.pp_generator + ): + embeddings = torch.dropout(embeddings, self._dropout_p, self.training) + return embeddings.to(self._residual_dtype) + + def forward( + self, + input_: torch.Tensor, + kwargs: dict[str, typing.Any], + losses: dict[str, typing.Any] | None = None, + metrics: dict | None = None, + ) -> torch.Tensor: + if isinstance(input_, TensorMeta): + return TensorMeta.from_dims( + 
kwargs[TransformerKwargs.hidden_dims], + tensor_name="Embedding output", + dtype=self._residual_dtype, + ) + return self._forward( + input_, + kwargs.get(LanguageModelKwargs.position_ids), + kwargs.get(VisionModelKwargs.images), + kwargs.get(VisionModelKwargs.image_sizes), + kwargs.get(VisionModelKwargs.image_positions), + ) diff --git a/fast_llm/layers/vision_encoder/config.py b/fast_llm/layers/vision_encoder/config.py index 76af3d37..5e472251 100644 --- a/fast_llm/layers/vision_encoder/config.py +++ b/fast_llm/layers/vision_encoder/config.py @@ -1,4 +1,4 @@ -from fast_llm.config import Field, FieldHint, config_class +from fast_llm.config import Config, Field, FieldHint, config_class from fast_llm.engine.base_model.config import BaseModelArchitectureConfig from fast_llm.engine.config_utils.tensor_space import TensorDim, TensorSpace from fast_llm.functional.config import ActivationType @@ -12,6 +12,17 @@ class VisionEncoderDimNames: patch_width = "vision_patch_width" +class VisionModelKwargs: + images = "images" + image_positions = "image_positions" + image_height = "image_height" + image_width = "image_width" + image_sizes = "image_sizes" + image_mean = "image_normalization_mean" + image_std = "image_normalization_std" + image_rescale_factor = "image_rescale_factor" + + @config_class() class PatchConvConfig(BaseModelArchitectureConfig): _abstract = False @@ -88,6 +99,45 @@ class VisionEncoderArchitectureConfig(BaseModelArchitectureConfig): ) +@config_class() +class ImageNormalizationConfig(Config): + mean_r: float = Field( + default=0.48145466, + desc="Mean value for the red channel in the image normalization process.", + hint=FieldHint.optional, + ) + mean_g: float = Field( + default=0.4578275, + desc="Mean value for the green channel in the image normalization process.", + hint=FieldHint.optional, + ) + mean_b: float = Field( + default=0.40821073, + desc="Mean value for the blue channel in the image normalization process.", + hint=FieldHint.optional, + ) + std_r: float = Field( + default=0.26862954, + desc="Standard deviation value for the red channel in the image normalization process.", + hint=FieldHint.optional, + ) + std_g: float = Field( + default=0.26130258, + desc="Standard deviation value for the green channel in the image normalization process.", + hint=FieldHint.optional, + ) + std_b: float = Field( + default=0.27577711, + desc="Standard deviation value for the blue channel in the image normalization process.", + hint=FieldHint.optional, + ) + rescale_factor: float = Field( + default=255.0, + desc="Rescale factor for the image normalization process.", + hint=FieldHint.optional, + ) + + @config_class() class VisionArchitectureConfig(BaseModelArchitectureConfig): _abstract = False @@ -107,6 +157,11 @@ class VisionArchitectureConfig(BaseModelArchitectureConfig): desc="The intermediate activation type for multi-modal adapter. 
Default: GeLU.", hint=FieldHint.core, ) + normalization: ImageNormalizationConfig = Field( + default_factory=ImageNormalizationConfig, + desc="Configuration for the normalization layers applied to the image patches.", + hint=FieldHint.optional, + ) def setup_tensor_space(self, tensor_space: TensorSpace): tensor_space.add_tensor_dim(TensorDim(VisionEncoderDimNames.out_channels, self.encoder.hidden_size)) diff --git a/fast_llm/layers/vision_encoder/encoder.py b/fast_llm/layers/vision_encoder/encoder.py index 88064b51..b028fa1f 100644 --- a/fast_llm/layers/vision_encoder/encoder.py +++ b/fast_llm/layers/vision_encoder/encoder.py @@ -9,10 +9,11 @@ from fast_llm.layers.language_model.config import LanguageModelBaseConfig from fast_llm.layers.transformer.config import TransformerKwargs from fast_llm.layers.vision_encoder.adapter import VisionAdapter -from fast_llm.layers.vision_encoder.config import VisionEncoderDimNames +from fast_llm.layers.vision_encoder.config import VisionEncoderDimNames, VisionModelKwargs from fast_llm.tensor import ParameterMeta, TensorMeta, init_normal_ +# TODO Soham: should this just be nn.Module? class VisionEncoder(Layer): """ A vision encoder layer for creating token embeddings from vision model @@ -25,11 +26,14 @@ def __init__(self, config: LanguageModelBaseConfig, tensor_space: TensorSpace): self._distributed_config = tensor_space.distributed_config with torch.device("meta"): if self._config.encoder.path: - self._vision_encoder = PixtralVisionModel.from_pretrained( + self.vision_encoder = PixtralVisionModel.from_pretrained( self._config.encoder.path, torch_dtype=self._distributed_config.training_dtype.torch ) else: - self._vision_encoder = PixtralVisionModel( + # TODO Soham options to fix rotary: + # 1. load PixtralTransformer instead of PixtralVisionModel. Required to implement conv2d, ln_pre separately and store positional embeddings in kwargs_meta + # 2. set self.vision_encoder.position_embeddings = PixtralRotaryEmbedding(config) outside of meta scope + self.vision_encoder = PixtralVisionModel( PixtralVisionConfig( hidden_size=self._config.encoder.hidden_size, intermediate_size=self._config.encoder.intermediate_size, @@ -46,12 +50,12 @@ def __init__(self, config: LanguageModelBaseConfig, tensor_space: TensorSpace): ) param_names = [] # gather all names first. 
PyTorch complains if we do it in the loop - for name, param in self._vision_encoder.named_parameters(): + for name, param in self.vision_encoder.named_parameters(): param_names.append(name) for name in param_names: *module_path, stem = name.split(".") - module = functools.reduce(getattr, module_path, self._vision_encoder) - param = self._vision_encoder.get_parameter(name) + module = functools.reduce(getattr, module_path, self.vision_encoder) + param = self.vision_encoder.get_parameter(name) setattr( module, stem, @@ -60,10 +64,10 @@ def __init__(self, config: LanguageModelBaseConfig, tensor_space: TensorSpace): init_method=init_normal_(), ), ) - none_params = [key for key, value in module._parameters.items() if value is None] - for key in none_params: - module._parameters.pop(key) - self._adapter = VisionAdapter( + # none_params = [key for key, value in module._parameters.items() if value is None] + # for key in none_params: + # module._parameters.pop(key) + self.adapter = VisionAdapter( intermediate_size=tensor_space.get_tensor_dim(VisionEncoderDimNames.intermediate_size), tensor_space=tensor_space, ) @@ -81,4 +85,4 @@ def forward( tensor_name="Vision Output", dtype=self._distributed_config.training_dtype.torch, ) - return self._vision_encoder(input_) + return self.adapter(self.vision_encoder(input_, kwargs[VisionModelKwargs.image_sizes])) diff --git a/fast_llm/layers/vision_encoder/preprocessing.py b/fast_llm/layers/vision_encoder/preprocessing.py new file mode 100644 index 00000000..7ebfd3d7 --- /dev/null +++ b/fast_llm/layers/vision_encoder/preprocessing.py @@ -0,0 +1,74 @@ +import typing + +import torch +import torchvision.transforms.v2.functional as F + +from fast_llm.engine.config_utils.tensor_space import TensorSpace +from fast_llm.layers.vision_encoder.config import VisionArchitectureConfig, VisionModelKwargs + + +def get_resize_dims(height: int, width: int, max_height: int, max_width: int) -> tuple[int, int]: + """ + Calculate the new dimensions for resizing an image while maintaining the aspect ratio. + If the image is larger than the max dimensions, it will be resized to fit within them. + If the image is smaller, it will be resized to the nearest multiple of the patch size. + """ + ratio = max(height / max_height, width / max_width) + return ( + (int(height / ratio), int(width / ratio)) + if ratio > 1 + else (max_height * (height // max_height), max_width * (width // max_width)) + ) + + +def resize(image: torch.Tensor, max_height: int, max_width: int) -> tuple[int, int]: + resize_dims = get_resize_dims(image.size(1), image.size(2), max_height, max_width) + # TODO: options for interpolation mode? + return F.resize(image, size=resize_dims, interpolation=F.InterpolationMode.BICUBIC) + + +def normalize(image: torch.Tensor, mean: list[float], std: list[float]) -> torch.Tensor: + """ + Normalize the image using the specified mean and standard deviation. 
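A rough sketch of how resize, normalize and pad in this module are meant to compose (the mean/std/rescale values below are the Pixtral-style defaults used elsewhere in this branch; exact output shapes depend on the resize and padding settings):

    import torch

    # A fake 3-channel 2048x1536 image with uint8 pixels, channels-first (CHW),
    # matching the layout produced by the memmap dataset.
    image = torch.randint(0, 256, (3, 2048, 1536), dtype=torch.uint8)

    resized = resize(image, max_height=1024, max_width=1024)  # downscale, keeping aspect ratio
    scaled = resized / 255.0                                  # rescale_factor
    normalized = normalize(
        scaled,
        mean=[0.48145466, 0.4578275, 0.40821073],
        std=[0.26862954, 0.26130258, 0.27577711],
    )
    padded = pad(normalized, max_height=1024, max_width=1024)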
+ """ + return F.normalize(image, mean=mean, std=std) + + +def pad(image: torch.Tensor, max_height, max_width) -> torch.Tensor: + """ + Pad images on the right and bottom with 0s untitl max_height and max_width + """ + width_padding = max(0, max_height - image.size(1)) + depth_padding = max(0, max_width - image.size(2)) + return F.pad(image, (0, 0, width_padding, depth_padding), 0) + + +class VisionPreprocessor: + def __init__(self, config: VisionArchitectureConfig, tensor_space: TensorSpace): + self._config = config + self._tensor_space = tensor_space + self._distributed_config = self._tensor_space.distributed_config + + def preprocess(self, kwargs: dict[str, typing.Any]) -> None: + images = kwargs.get("images") + im_height = kwargs.get(VisionModelKwargs.image_height) + im_width = kwargs.get(VisionModelKwargs.image_width) + kwargs[VisionModelKwargs.image_sizes] = [(im.size(1), im.size(2)) for im in images] + images = [ + pad( + normalize( + resize(image, im_height, im_width) / kwargs[VisionModelKwargs.image_rescale_factor], + mean=kwargs[VisionModelKwargs.image_mean], + std=kwargs[VisionModelKwargs.image_std], + ), + max_height=im_height, + max_width=im_width, + ) + for image in images + ] + images = torch.stack(images, dim=0).to( + # TODO Soham: is this needed? + device=self._tensor_space.distributed.device, + dtype=self._distributed_config.training_dtype.torch, + ) + kwargs[VisionModelKwargs.images] = images diff --git a/fast_llm/models/gpt/conversion.py b/fast_llm/models/gpt/conversion.py index f730d79c..3caaee5a 100644 --- a/fast_llm/models/gpt/conversion.py +++ b/fast_llm/models/gpt/conversion.py @@ -679,39 +679,39 @@ def _create_vision_transformer_converters(self) -> list[WeightConverter]: for i in range(num_layers): vision_transformer_converters += [ WeightConverter( - f"layers.0._vision_encoder.transformer.layers.{i}.attention.k_proj.weight", + f"layers.0.vision_encoder.vision_encoder.transformer.layers.{i}.attention.k_proj.weight", f"vision_tower.transformer.layers.{i}.attention.k_proj.weight", ), WeightConverter( - f"layers.0._vision_encoder.transformer.layers.{i}.attention.v_proj.weight", + f"layers.0.vision_encoder.vision_encoder.transformer.layers.{i}.attention.v_proj.weight", f"vision_tower.transformer.layers.{i}.attention.v_proj.weight", ), WeightConverter( - f"layers.0._vision_encoder.transformer.layers.{i}.attention.q_proj.weight", + f"layers.0.vision_encoder.vision_encoder.transformer.layers.{i}.attention.q_proj.weight", f"vision_tower.transformer.layers.{i}.attention.q_proj.weight", ), WeightConverter( - f"layers.0._vision_encoder.transformer.layers.{i}.attention.o_proj.weight", + f"layers.0.vision_encoder.vision_encoder.transformer.layers.{i}.attention.o_proj.weight", f"vision_tower.transformer.layers.{i}.attention.o_proj.weight", ), WeightConverter( - f"layers.0._vision_encoder.transformer.layers.{i}.attention_norm.weight", + f"layers.0.vision_encoder.vision_encoder.transformer.layers.{i}.attention_norm.weight", f"vision_tower.transformer.layers.{i}.attention_norm.weight", ), WeightConverter( - f"layers.0._vision_encoder.transformer.layers.{i}.feed_forward.down_proj.weight", + f"layers.0.vision_encoder.vision_encoder.transformer.layers.{i}.feed_forward.down_proj.weight", f"vision_tower.transformer.layers.{i}.feed_forward.down_proj.weight", ), WeightConverter( - f"layers.0._vision_encoder.transformer.layers.{i}.feed_forward.gate_proj.weight", + f"layers.0.vision_encoder.vision_encoder.transformer.layers.{i}.feed_forward.gate_proj.weight", 
f"vision_tower.transformer.layers.{i}.feed_forward.gate_proj.weight", ), WeightConverter( - f"layers.0._vision_encoder.transformer.layers.{i}.feed_forward.up_proj.weight", + f"layers.0.vision_encoder.vision_encoder.transformer.layers.{i}.feed_forward.up_proj.weight", f"vision_tower.transformer.layers.{i}.feed_forward.up_proj.weight", ), WeightConverter( - f"layers.0._vision_encoder.transformer.layers.{i}.ffn_norm.weight", + f"layers.0.vision_encoder.vision_encoder.transformer.layers.{i}.ffn_norm.weight", f"vision_tower.transformer.layers.{i}.ffn_norm.weight", ), ] @@ -720,48 +720,48 @@ def _create_vision_transformer_converters(self) -> list[WeightConverter]: def _create_vision_encoder_weight_converters(self) -> list[WeightConverter]: patch_conv_converter = WeightConverter( - "layers.0._vision_encoder.patch_conv.weight", + "layers.0.vision_encoder.vision_encoder.patch_conv.weight", "vision_tower.patch_conv.weight", ) # TODO Soham: use _get_weight_and_bias_converters? + layernorm_converters = [] layer_norm_converter = WeightConverter( - "layers.0._vision_encoder.ln_pre.weight", + "layers.0.vision_encoder.vision_encoder.ln_pre.weight", "vision_tower.ln_pre.weight", ) + layernorm_converters.append(layer_norm_converter) + layer_norm_converter if self._model.config.base_model.vision_encoder.encoder.pre_norm == NormalizationType.layer_norm: layer_norm_bias_converter = WeightConverter( - "layers.0._vision_encoder.ln_pre.bias", + "layers.0.vision_encoder.vision_encoder.ln_pre.bias", "vision_tower.ln_pre.bias", ) + layernorm_converters.append(layer_norm_bias_converter) vision_transformer_converters = self._create_vision_transformer_converters() adapter_converters = [ WeightConverter( - "layers.0._adapter.layer_1.weight", + "layers.0.vision_encoder.adapter.layer_1.weight", "multi_modal_projector.linear_1.weight", ), WeightConverter( - "layers.0._adapter.layer_1.bias", + "layers.0.vision_encoder.adapter.layer_1.bias", "multi_modal_projector.linear_1.bias", ), # TODO Soham: conditionally add bias WeightConverter( - "layers.0._adapter.layer_2.weight", + "layers.0.vision_encoder.adapter.layer_2.weight", "multi_modal_projector.linear_2.weight", ), WeightConverter( - "layers.0._adapter.layer_2.bias", + "layers.0.vision_encoder.adapter.layer_2.bias", "multi_modal_projector.linear_2.bias", ), ] - return ( - [patch_conv_converter, layer_norm_converter, layer_norm_bias_converter] - + vision_transformer_converters - + adapter_converters - ) + return [patch_conv_converter] + layernorm_converters + vision_transformer_converters + adapter_converters def _create_weight_converters(self) -> list[WeightConverter]: vision_encoder_converter = self._create_vision_encoder_weight_converters() - lm_converters = super()._create_weight_converters(hf_base_prefix="language_model.", fast_llm_offset=2) + lm_converters = super()._create_weight_converters(hf_base_prefix="language_model.", fast_llm_offset=1) return vision_encoder_converter + lm_converters diff --git a/fast_llm/models/gpt/model.py b/fast_llm/models/gpt/model.py index 67411641..0890051e 100644 --- a/fast_llm/models/gpt/model.py +++ b/fast_llm/models/gpt/model.py @@ -14,6 +14,7 @@ from fast_llm.layers.language_model.embedding import WORD_EMBEDDINGS_WEIGHT, LanguageModelEmbedding from fast_llm.layers.language_model.head import OUTPUT_WEIGHTS, LanguageModelHead from fast_llm.layers.language_model.preprocessing import PositionEmbeddingPreprocessor +from fast_llm.layers.multi_modal.embedding import MultiModalEmbedding from fast_llm.layers.transformer.config import ( 
RoutingType, TransformerDimNames, @@ -26,7 +27,8 @@ RotaryEmbeddingPreprocessor, ) from fast_llm.layers.transformer.transformer import TransformerLayer -from fast_llm.layers.vision_encoder.encoder import VisionEncoder +from fast_llm.layers.vision_encoder.config import VisionModelKwargs +from fast_llm.layers.vision_encoder.preprocessing import VisionPreprocessor from fast_llm.models.gpt.config import GPTBaseModelConfig, GPTModelConfig from fast_llm.models.gpt.megatron import get_init_megatron from fast_llm.tensor import ParameterMeta, TensorMeta @@ -72,6 +74,9 @@ def __init__( else: self._flash_varlen_preprocessor = FlashAttnVarlenPreprocessor(self._config.transformer, self._tensor_space) + if self._config.vision_encoder: + self._vision_preprocessor = VisionPreprocessor(self._config.vision_encoder, self._tensor_space) + def get_output_layers(self) -> list[Layer]: return [ layer @@ -98,14 +103,19 @@ def get_layers(self) -> list[Layer]: if self._config.transformer.num_layers == 0: Assert.eq(self._config.prediction_heads, 1) return [ - LanguageModelEmbedding(self._config, self._tensor_space), + ( + LanguageModelEmbedding(self._config, self._tensor_space) + if self._config.vision_encoder is None + else MultiModalEmbedding(self._config, self._tensor_space) + ), LanguageModelHead(self._config, self._tensor_space, 0), ] - return ( - [VisionEncoder(self._config, self._tensor_space)] if self._config.vision_encoder is not None else [] - ) + [ - # return [ - LanguageModelEmbedding(self._config, self._tensor_space), + return [ + ( + LanguageModelEmbedding(self._config, self._tensor_space) + if self._config.vision_encoder is None + else MultiModalEmbedding(self._config, self._tensor_space) + ), *[ TransformerLayer( self._config.transformer, @@ -139,6 +149,30 @@ def preprocess_meta( sequence_length -= 1 micro_sequence_length = sequence_length + if self._config.vision_encoder: + image_height = batch_meta.max_image_height + image_width = batch_meta.max_image_width + image_mean = [ + self._config.vision_encoder.normalization.mean_r, + self._config.vision_encoder.normalization.mean_g, + self._config.vision_encoder.normalization.mean_b, + ] + image_std = [ + self._config.vision_encoder.normalization.std_r, + self._config.vision_encoder.normalization.std_g, + self._config.vision_encoder.normalization.std_b, + ] + image_rescale_factor = self._config.vision_encoder.normalization.rescale_factor + vision_kwargs = { + VisionModelKwargs.image_height: image_height, + VisionModelKwargs.image_width: image_width, + VisionModelKwargs.image_mean: image_mean, + VisionModelKwargs.image_std: image_std, + VisionModelKwargs.image_rescale_factor: image_rescale_factor, + } + else: + vision_kwargs = {} + batch_data = self._tensor_space.distributed_config.get_distributed_dim(DistributedDimNames.batch_data) batch_dim = TensorDim(TransformerDimNames.batch, micro_batch_size * batch_data.size, batch_data) @@ -189,6 +223,7 @@ def preprocess_meta( TransformerKwargs.sequence_length: sequence_length, TransformerKwargs.sequence_q_dim: sequence_q_dim, } + common_kwargs.update(vision_kwargs) preprocessed_meta = [] for sequence_k_past in range( @@ -271,6 +306,16 @@ def preprocess( if self._use_flash_attention: self._flash_varlen_preprocessor.preprocess(kwargs_meta) + if batch.images is not None: + kwargs_meta[VisionModelKwargs.images] = [ + img.to(device=self._tensor_space.distributed.device, dtype=torch.uint8, non_blocking=True) + for images in batch.images + for img in images + ] + kwargs_meta[VisionModelKwargs.image_positions] = 
batch.image_positions + if self._config.vision_encoder: + self._vision_preprocessor.preprocess(kwargs_meta) + # TODO: Add pasts/presents to meta input? # Use lists as pointers so `past_key_values` is populated during the previous micro_sequence. pasts = presents diff --git a/setup.cfg b/setup.cfg index 57913f83..52676c79 100644 --- a/setup.cfg +++ b/setup.cfg @@ -30,7 +30,7 @@ CORE = # Required for some optional features and tools. OPTIONAL = # Huggingface tools - transformers>=4.44.2 + transformers>=4.48.3 hf-transfer>=0.1.8 datasets>=3.1.0 huggingface-hub>=0.28.1 From 6cb8f5d0e85e8b1bd24470e387b4c0d259124201 Mon Sep 17 00:00:00 2001 From: root Date: Wed, 30 Apr 2025 16:31:46 +0000 Subject: [PATCH 10/82] make it work, barely --- Dockerfile | 1 + fast_llm/data/data/gpt/data.py | 6 +- fast_llm/data/dataset/gpt/sampled.py | 17 ++- fast_llm/layers/multi_modal/embedding.py | 76 +++++----- fast_llm/layers/vision_encoder/adapter.py | 19 +-- fast_llm/layers/vision_encoder/config.py | 17 ++- fast_llm/layers/vision_encoder/encoder.py | 134 ++++++++++++++---- .../layers/vision_encoder/preprocessing.py | 31 +++- fast_llm/models/gpt/conversion.py | 30 ++-- fast_llm/models/gpt/model.py | 26 ++-- 10 files changed, 240 insertions(+), 117 deletions(-) diff --git a/Dockerfile b/Dockerfile index 8c2efa85..b8e1f888 100644 --- a/Dockerfile +++ b/Dockerfile @@ -4,6 +4,7 @@ FROM nvcr.io/nvidia/pytorch:24.11-py3 # Install dependencies. RUN apt-get update \ && apt-get install --no-install-recommends -y acl git-lfs \ + && apt-get install --no-install-recommends -y libtiff5-dev \ && rm -rf /var/lib/apt/lists/* \ && git lfs install diff --git a/fast_llm/data/data/gpt/data.py b/fast_llm/data/data/gpt/data.py index 22e4730c..cffaa734 100644 --- a/fast_llm/data/data/gpt/data.py +++ b/fast_llm/data/data/gpt/data.py @@ -49,10 +49,12 @@ def gpt_data_collate_fn( stacked_spans = [torch.from_numpy(sample.loss_masking_spans) for sample in batch] if not cross_document_attention: sequence_lengths = [torch.tensor(sample.sequence_lengths) for sample in batch] + has_images = False batch_images = [] for sample in batch: if sample.images is not None: batch_images.append([torch.from_numpy(image) for image in sample.images]) + has_images = True else: batch_images.append(None) batch_image_positions = [] @@ -65,8 +67,8 @@ def gpt_data_collate_fn( token_ids=torch.from_numpy(stacked_ids), loss_masking_spans=stacked_spans, sequence_lengths=sequence_lengths, - images=batch_images if any(batch_images) else None, - image_positions=batch_image_positions if any(batch_image_positions) else None, + images=batch_images if has_images else None, + image_positions=batch_image_positions if has_images else None, ) diff --git a/fast_llm/data/dataset/gpt/sampled.py b/fast_llm/data/dataset/gpt/sampled.py index 8acbf9ee..973c1db5 100644 --- a/fast_llm/data/dataset/gpt/sampled.py +++ b/fast_llm/data/dataset/gpt/sampled.py @@ -433,13 +433,22 @@ def __getitem__(self, index: int) -> typing.Any: length=token_end_index_in_document - token_start_index_in_document, use_loss_masking_spans=self._config.use_loss_masking_spans, ) - # TODO Soham: handle images with loss masking spans + start_pos = 0 for idx, im_position in enumerate(sample.image_positions): # image_positions.append(im_positions + len(token_ids) + image_tokens_added) + # Add placeholders for image tokens + token_ids.append(sample.token_ids[start_pos:im_position]) + token_ids.append(np.full((image_sizes[idx],), -100, dtype=np.int64)) image_positions.append(im_position + len(token_ids) + 
image_tokens_added) image_tokens_added += image_tokens + start_pos = im_position + token_ids.append(sample.token_ids[start_pos:]) + # TODO Soham: remove this + # if len(sample.images) == 1: + # sample.images.append(sample.images[0]) + # sample.image_positions = np.concatenate([sample.image_positions, sample.image_positions]) images.append(sample.images) - token_ids.append(sample.token_ids) + # TODO Soham: add offsets for loss masking spans if self._config.use_loss_masking_spans: for loss_masking_span in sample.loss_masking_spans: span = np.clip(loss_masking_span + token_count - token_start, 0, self._sequence_length + 1) @@ -452,7 +461,7 @@ def __getitem__(self, index: int) -> typing.Any: sequence_lengths = ( np.array([ids.size - (idx == len(token_ids) - 1) for idx, ids in enumerate(token_ids)], dtype=np.int32) - + np.array([ImageProcessor.get_num_patches(image) for image in images[idx] for idx in range(len(images))]) + # + np.array([ImageProcessor.get_num_patches(image) for image in images[idx] for idx in range(len(images))]) if not self._cross_document_attention else None ) @@ -464,7 +473,7 @@ def __getitem__(self, index: int) -> typing.Any: ) images = [im for img_list in images for im in img_list] if images else None image_positions = np.array(image_positions) if image_positions else None - Assert.eq(len(token_ids) + image_tokens_added, self._sequence_length + 1) + Assert.eq(len(token_ids), self._sequence_length + 1) return GPTSample( token_ids=token_ids, diff --git a/fast_llm/layers/multi_modal/embedding.py b/fast_llm/layers/multi_modal/embedding.py index a92fdc4e..3b62c60b 100644 --- a/fast_llm/layers/multi_modal/embedding.py +++ b/fast_llm/layers/multi_modal/embedding.py @@ -25,59 +25,59 @@ def __init__( super().__init__(config, tensor_space) self.vision_encoder = VisionEncoder(config, tensor_space) - def _forward( + def forward( self, input_: torch.Tensor, - position_ids: torch.Tensor | None, - images: torch.Tensor | None, - image_sizes: torch.Tensor | None, - image_positions: list[torch.Tensor] | None, + kwargs: dict[str, typing.Any], + losses: dict[str, typing.Any] | None = None, + metrics: dict | None = None, ) -> torch.Tensor: - image_embeddings = self.vision_encoder(images, kwargs={VisionModelKwargs.image_sizes: image_sizes}) + if isinstance(input_, TensorMeta): + return TensorMeta.from_dims( + kwargs[TransformerKwargs.hidden_dims], + tensor_name="Embedding output", + dtype=self._residual_dtype, + ) + # return self._forward( + # input_, + # kwargs.get(LanguageModelKwargs.position_ids), + # kwargs.get(VisionModelKwargs.images), + # kwargs.get(VisionModelKwargs.image_sizes), + # kwargs.get(VisionModelKwargs.image_positions), + # ) # TODO Soham: offset position ids + images = kwargs.pop(VisionModelKwargs.images)[:1] + position_ids = kwargs.get(LanguageModelKwargs.position_ids) + image_positions = kwargs.get(VisionModelKwargs.image_positions)[:1] + image_embeddings = self.vision_encoder(images, kwargs) + embeddings = super()._forward(input_, position_ids) img_tokens_seen = 0 image_idx = 0 - text_embeddings = super()._forward(input_, position_ids) - embeddings = [] for sample_idx, positions in enumerate(image_positions): - embedding_parts = [] - for position in positions: + # embedding_parts = [] + for position in positions[:1]: image = images[image_idx] - image_tokens = (image.size[1] // self._config.vision_encoder.encoder.patch_size) * ( - image.size[2] // self._config.vision_encoder.encoder.patch_size + image_tokens = (image.size(1) // 
self._config.vision_encoder.encoder.patch_size) * ( + image.size(2) // self._config.vision_encoder.encoder.patch_size ) + embeddings[sample_idx, position : position + image_tokens] = image_embeddings[ + sample_idx, img_tokens_seen : img_tokens_seen + image_tokens + ] + # embedding_parts.append(text_embeddings[sample_idx, :position]) + # embedding_parts.append(image_embeddings[sample_idx, img_tokens_seen : img_tokens_seen + image_tokens]) image_idx += 1 img_tokens_seen += image_tokens - embedding_parts.append(text_embeddings[sample_idx, :position]) - embedding_parts.append(image_embeddings[img_tokens_seen : img_tokens_seen + image_tokens]) - embedding_parts.append(text_embeddings[sample_idx, position + image_tokens :]) - embeddings.append(torch.cat(embedding_parts, dim=0)) - embeddings = torch.stack(embeddings, dim=0) + # embedding_parts.append(text_embeddings[sample_idx, position:]) + # TODO Soham: debug from here + # embeddings.append(torch.cat(embedding_parts, dim=0)) + # embeddings = torch.stack(embeddings, dim=0) with set_generator( self._tensor_space.distributed.tp_generator if self._sequence_parallel else self._tensor_space.distributed.pp_generator ): embeddings = torch.dropout(embeddings, self._dropout_p, self.training) + # assert embeddings.size(1) == 8192 + del image_embeddings + del images return embeddings.to(self._residual_dtype) - - def forward( - self, - input_: torch.Tensor, - kwargs: dict[str, typing.Any], - losses: dict[str, typing.Any] | None = None, - metrics: dict | None = None, - ) -> torch.Tensor: - if isinstance(input_, TensorMeta): - return TensorMeta.from_dims( - kwargs[TransformerKwargs.hidden_dims], - tensor_name="Embedding output", - dtype=self._residual_dtype, - ) - return self._forward( - input_, - kwargs.get(LanguageModelKwargs.position_ids), - kwargs.get(VisionModelKwargs.images), - kwargs.get(VisionModelKwargs.image_sizes), - kwargs.get(VisionModelKwargs.image_positions), - ) diff --git a/fast_llm/layers/vision_encoder/adapter.py b/fast_llm/layers/vision_encoder/adapter.py index 234c451a..b8436f72 100644 --- a/fast_llm/layers/vision_encoder/adapter.py +++ b/fast_llm/layers/vision_encoder/adapter.py @@ -1,16 +1,13 @@ -import typing - import torch -from fast_llm.engine.base_model.base_model import Layer from fast_llm.engine.config_utils.tensor_space import TensorSpace -from fast_llm.layers.common.linear import LinearBase +from fast_llm.layers.common.linear import Linear from fast_llm.layers.transformer.config import TransformerDimNames from fast_llm.layers.vision_encoder.config import VisionEncoderDimNames from fast_llm.tensor import init_normal_ -class VisionAdapter(Layer): +class VisionAdapter(torch.nn.Module): """ Vision adapter layer for the LLM. 
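The adapter is a two-layer projection from the vision encoder width to the language model hidden size, with no activation between the layers as written here. A rough shape sketch with plain torch modules standing in for the tensor-space-aware Linear layers; the sizes below are placeholders, not the configured defaults:

    import torch

    vision_hidden = 1024   # VisionEncoderDimNames.out_channels
    adapter_size = 4096    # VisionEncoderDimNames.intermediate_size
    lm_hidden = 4096       # TransformerDimNames.hidden

    layer_1 = torch.nn.Linear(vision_hidden, adapter_size, bias=True)
    layer_2 = torch.nn.Linear(adapter_size, lm_hidden, bias=True)

    patch_embeddings = torch.randn(1, 256, vision_hidden)  # (batch, patches, vision_hidden)
    projected = layer_2(layer_1(patch_embeddings))          # (1, 256, lm_hidden)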
""" @@ -19,14 +16,14 @@ def __init__(self, intermediate_size: int, tensor_space: TensorSpace, name: str super().__init__() self._name = name input_dim = tensor_space.get_tensor_dim(VisionEncoderDimNames.out_channels) - self.layer_1 = LinearBase( + self.layer_1 = Linear( input_dim, tensor_space.get_tensor_dim(VisionEncoderDimNames.intermediate_size), bias=True, weight_init_method=init_normal_(), bias_init_method=init_normal_(), ) - self.layer_2 = LinearBase( + self.layer_2 = Linear( tensor_space.get_tensor_dim(VisionEncoderDimNames.intermediate_size), tensor_space.get_tensor_dim(TransformerDimNames.hidden), bias=True, @@ -34,11 +31,5 @@ def __init__(self, intermediate_size: int, tensor_space: TensorSpace, name: str bias_init_method=init_normal_(), ) - def forward( - self, - input_: torch.Tensor, - kwargs: dict[str, typing.Any], - losses: dict[str, typing.Any] | None = None, - metrics: dict[str, typing.Any] | None = None, - ): + def forward(self, input_: torch.Tensor): return self.layer_2(self.layer_1(input_)) diff --git a/fast_llm/layers/vision_encoder/config.py b/fast_llm/layers/vision_encoder/config.py index 5e472251..65ae8e50 100644 --- a/fast_llm/layers/vision_encoder/config.py +++ b/fast_llm/layers/vision_encoder/config.py @@ -2,7 +2,7 @@ from fast_llm.engine.base_model.config import BaseModelArchitectureConfig from fast_llm.engine.config_utils.tensor_space import TensorDim, TensorSpace from fast_llm.functional.config import ActivationType -from fast_llm.layers.common.config import NormalizationType +from fast_llm.layers.common.config import NormalizationConfig class VisionEncoderDimNames: @@ -10,9 +10,11 @@ class VisionEncoderDimNames: intermediate_size = "vision_intermediate_size" patch_height = "vision_patch_height" patch_width = "vision_patch_width" + kv_channels = "vision_kv_channels" class VisionModelKwargs: + patch_size = "patch_size" images = "images" image_positions = "image_positions" image_height = "image_height" @@ -21,6 +23,9 @@ class VisionModelKwargs: image_mean = "image_normalization_mean" image_std = "image_normalization_std" image_rescale_factor = "image_rescale_factor" + rope_theta = "vit_rope_theta" + rotary_inv_freq = "vit_rotary_inv_freq" + kv_channels = "vit_kv_channels" @config_class() @@ -54,10 +59,8 @@ class VisionEncoderArchitectureConfig(BaseModelArchitectureConfig): Configuration class for the vision encoder, which transforms images into embeddings """ path: str | None = Field(default=None, desc="Path to a pretrained vision encoder model.", hint=FieldHint.optional) - pre_norm: NormalizationType = Field( - default=NormalizationType.rms_norm, - desc="The type of normalization to use before the transformer layers.", - hint=FieldHint.optional, + pre_norm: NormalizationConfig = Field( + default_factory=NormalizationConfig, ) hidden_size: int = Field( default=1024, desc="The size of the hidden layers in the transformer model.", hint=FieldHint.optional @@ -168,6 +171,10 @@ def setup_tensor_space(self, tensor_space: TensorSpace): tensor_space.add_tensor_dim(TensorDim(VisionEncoderDimNames.intermediate_size, self.adapter_size)) tensor_space.add_tensor_dim(TensorDim(VisionEncoderDimNames.patch_height, self.encoder.patch_size)) tensor_space.add_tensor_dim(TensorDim(VisionEncoderDimNames.patch_width, self.encoder.patch_size)) + # TODO Soham: add a check for kv channels + tensor_space.add_tensor_dim( + TensorDim(VisionEncoderDimNames.kv_channels, self.encoder.hidden_size // self.encoder.num_attention_heads) + ) # tensor_space.add_tensor_dim( # 
CompositeTensorDim(VisionEncoderDimNames.) # ) diff --git a/fast_llm/layers/vision_encoder/encoder.py b/fast_llm/layers/vision_encoder/encoder.py index b028fa1f..bbcebf25 100644 --- a/fast_llm/layers/vision_encoder/encoder.py +++ b/fast_llm/layers/vision_encoder/encoder.py @@ -2,7 +2,8 @@ import typing import torch -from transformers import PixtralVisionConfig, PixtralVisionModel +from transformers import PixtralVisionConfig +from transformers.models.pixtral.modeling_pixtral import PixtralTransformer from fast_llm.engine.base_model.base_model import Layer from fast_llm.engine.config_utils.tensor_space import TensorDim, TensorSpace @@ -13,6 +14,33 @@ from fast_llm.tensor import ParameterMeta, TensorMeta, init_normal_ +def position_ids_in_meshgrid(patch_embeddings_list, max_width): + positions = [] + for patch in patch_embeddings_list: + height, width = patch.shape[-2:] + mesh = torch.meshgrid(torch.arange(height), torch.arange(width), indexing="ij") + h_grid, v_grid = torch.stack(mesh, dim=-1).reshape(-1, 2).chunk(2, -1) + ids = h_grid * max_width + v_grid + positions.append(ids[:, 0]) + return torch.cat(positions) + + +def generate_block_attention_mask(patch_embeds_list, tensor): + dtype = tensor.dtype + device = tensor.device + seq_len = tensor.shape[1] + d_min = torch.finfo(dtype).min + causal_mask = torch.full((seq_len, seq_len), fill_value=d_min, dtype=dtype, device=device) + + block_end_idx = torch.tensor(patch_embeds_list).cumsum(-1) + block_start_idx = torch.tensor([0] + patch_embeds_list[:-1]).cumsum(-1) + for start, end in zip(block_start_idx, block_end_idx): + causal_mask[start:end, start:end] = 0 + + causal_mask = causal_mask[None, None, :, :].expand(tensor.shape[0], 1, -1, -1) + return causal_mask + + # TODO Soham: should this just be nn.Module? class VisionEncoder(Layer): """ @@ -25,37 +53,49 @@ def __init__(self, config: LanguageModelBaseConfig, tensor_space: TensorSpace): self._config = config.vision_encoder self._distributed_config = tensor_space.distributed_config with torch.device("meta"): - if self._config.encoder.path: - self.vision_encoder = PixtralVisionModel.from_pretrained( - self._config.encoder.path, torch_dtype=self._distributed_config.training_dtype.torch - ) - else: - # TODO Soham options to fix rotary: - # 1. load PixtralTransformer instead of PixtralVisionModel. Required to implement conv2d, ln_pre separately and store positional embeddings in kwargs_meta - # 2. set self.vision_encoder.position_embeddings = PixtralRotaryEmbedding(config) outside of meta scope - self.vision_encoder = PixtralVisionModel( - PixtralVisionConfig( - hidden_size=self._config.encoder.hidden_size, - intermediate_size=self._config.encoder.intermediate_size, - num_hidden_layers=self._config.encoder.num_hidden_layers, - num_attention_heads=self._config.encoder.num_attention_heads, - num_channels=self._config.encoder.num_channels, - image_size=self._config.encoder.image_size, - patch_size=self._config.encoder.patch_size, - hidden_act=self._config.encoder.hidden_act, - attention_dropout=self._config.encoder.attention_dropout, - rope_theta=self._config.encoder.rope_theta, - initializer_range=self._config.encoder.initializer_range, - ) - ) + # TODO Soham options to fix rotary: + # 1. load PixtralTransformer instead of PixtralVisionModel. Required to implement conv2d, ln_pre separately and store positional embeddings in kwargs_meta + # 2. 
set self.vision_encoder.position_embeddings = PixtralRotaryEmbedding(config) outside of meta scope + config = PixtralVisionConfig( + hidden_size=self._config.encoder.hidden_size, + intermediate_size=self._config.encoder.intermediate_size, + num_hidden_layers=self._config.encoder.num_hidden_layers, + num_attention_heads=self._config.encoder.num_attention_heads, + num_channels=self._config.encoder.num_channels, + image_size=self._config.encoder.image_size, + patch_size=self._config.encoder.patch_size, + hidden_act=self._config.encoder.hidden_act, + attention_dropout=self._config.encoder.attention_dropout, + rope_theta=self._config.encoder.rope_theta, + initializer_range=self._config.encoder.initializer_range, + ) + self.patch_conv = torch.nn.Conv2d( + in_channels=3, + out_channels=self._config.encoder.hidden_size, + kernel_size=self._config.encoder.patch_size, + stride=self._config.encoder.patch_size, + bias=False, + ) + self.patch_conv.weight = ParameterMeta.from_dims( + tuple( + TensorDim(f"patch_conv_weight_{idx}", size) + for idx, size in enumerate(self.patch_conv.weight.shape) + ), + init_method=init_normal_(), + ) + self.norm = self._config.encoder.pre_norm.get_layer( + tensor_space.get_tensor_dim(VisionEncoderDimNames.out_channels) + ) + self.vision_transformer = PixtralTransformer(config) + # self.vision_encoder = PixtralVisionModel(config) param_names = [] # gather all names first. PyTorch complains if we do it in the loop - for name, param in self.vision_encoder.named_parameters(): + for name, param in self.vision_transformer.named_parameters(): param_names.append(name) for name in param_names: *module_path, stem = name.split(".") - module = functools.reduce(getattr, module_path, self.vision_encoder) - param = self.vision_encoder.get_parameter(name) + module = functools.reduce(getattr, module_path, self.vision_transformer) + param = self.vision_transformer.get_parameter(name) setattr( module, stem, @@ -72,6 +112,38 @@ def __init__(self, config: LanguageModelBaseConfig, tensor_space: TensorSpace): tensor_space=tensor_space, ) + def _forward( + self, input_: torch.Tensor, image_sizes: torch.Tensor, inv_freq: torch.Tensor, image_width: int + ) -> torch.Tensor: + patch_embeddings = self.patch_conv(input_) + patch_embeddings_list = [ + embedding[..., : image_size[0], : image_size[1]] + for embedding, image_size in zip(patch_embeddings, image_sizes) + ] + patch_embeddings = torch.cat([p.flatten(1).T for p in patch_embeddings_list], dim=0).unsqueeze(0) + patch_embeddings = self.norm(patch_embeddings) + position_ids = position_ids_in_meshgrid(patch_embeddings_list, image_width // self._config.encoder.patch_size) + freqs = inv_freq[position_ids] + with torch.autocast(device_type=input_.device.type): + cos = freqs.cos() + sin = freqs.sin() + cos = cos.to(dtype=input_.dtype) + sin = sin.to(dtype=input_.dtype) + + attention_mask = generate_block_attention_mask( + [p.shape[-2] * p.shape[-1] for p in patch_embeddings_list], patch_embeddings + ) + + (out,) = self.vision_transformer( + patch_embeddings, + attention_mask=attention_mask, + position_embeddings=(cos, sin), + output_attentions=False, + return_dict=False, + ) + + return self.adapter(out) + def forward( self, input_: torch.Tensor, @@ -85,4 +157,10 @@ def forward( tensor_name="Vision Output", dtype=self._distributed_config.training_dtype.torch, ) - return self.adapter(self.vision_encoder(input_, kwargs[VisionModelKwargs.image_sizes])) + return self._forward( + input_, + kwargs[VisionModelKwargs.image_sizes][:1], + 
kwargs[VisionModelKwargs.rotary_inv_freq], + image_width=kwargs[VisionModelKwargs.image_width], + ) + # return self.adapter(self.vision_encoder(input_, kwargs[VisionModelKwargs.image_sizes])) diff --git a/fast_llm/layers/vision_encoder/preprocessing.py b/fast_llm/layers/vision_encoder/preprocessing.py index 7ebfd3d7..57ee3a0b 100644 --- a/fast_llm/layers/vision_encoder/preprocessing.py +++ b/fast_llm/layers/vision_encoder/preprocessing.py @@ -40,7 +40,27 @@ def pad(image: torch.Tensor, max_height, max_width) -> torch.Tensor: """ width_padding = max(0, max_height - image.size(1)) depth_padding = max(0, max_width - image.size(2)) - return F.pad(image, (0, 0, width_padding, depth_padding), 0) + return F.pad(image, (0, 0, depth_padding, width_padding), 0) + + +def create_inv_freqs(rope_theta: int, kv_channels: int, image_size: int, patch_size: int) -> torch.Tensor: + freqs = 1.0 / (rope_theta ** (torch.arange(0, kv_channels, 2).float() / kv_channels)) + max_patches_per_side = image_size // patch_size + + h = torch.arange(max_patches_per_side) + w = torch.arange(max_patches_per_side) + + freqs_h = torch.outer(h, freqs[::2]).float() + freqs_w = torch.outer(w, freqs[1::2]).float() + inv_freq = torch.cat( + [ + freqs_h[:, None, :].repeat(1, max_patches_per_side, 1), + freqs_w[None, :, :].repeat(max_patches_per_side, 1, 1), + ], + dim=-1, + ).reshape(-1, kv_channels // 2) + + return torch.cat((inv_freq, inv_freq), dim=-1) class VisionPreprocessor: @@ -53,7 +73,8 @@ def preprocess(self, kwargs: dict[str, typing.Any]) -> None: images = kwargs.get("images") im_height = kwargs.get(VisionModelKwargs.image_height) im_width = kwargs.get(VisionModelKwargs.image_width) - kwargs[VisionModelKwargs.image_sizes] = [(im.size(1), im.size(2)) for im in images] + image_sizes = [get_resize_dims(im.size(1), im.size(2), im_height, im_width) for im in images] + kwargs[VisionModelKwargs.image_sizes] = image_sizes images = [ pad( normalize( @@ -72,3 +93,9 @@ def preprocess(self, kwargs: dict[str, typing.Any]) -> None: dtype=self._distributed_config.training_dtype.torch, ) kwargs[VisionModelKwargs.images] = images + kwargs[VisionModelKwargs.rotary_inv_freq] = create_inv_freqs( + kwargs[VisionModelKwargs.rope_theta], + kwargs[VisionModelKwargs.kv_channels], + im_height, + kwargs[VisionModelKwargs.patch_size], + ).to(device=self._tensor_space.distributed.device) diff --git a/fast_llm/models/gpt/conversion.py b/fast_llm/models/gpt/conversion.py index 3caaee5a..bd7da797 100644 --- a/fast_llm/models/gpt/conversion.py +++ b/fast_llm/models/gpt/conversion.py @@ -597,6 +597,10 @@ def _create_config_converters(cls) -> list[ParamConverter]: ), ), ), + ConstantImportParamConverter( + fast_llm_names=(("vision_encoder", "encoder", "pre_norm", "type"),), + fast_llm_value=NormalizationType.rms_norm, + ), # Vision Transformer RenameParamConverter( fast_llm_names=(("vision_encoder", "encoder", "num_hidden_layers"),), @@ -679,39 +683,39 @@ def _create_vision_transformer_converters(self) -> list[WeightConverter]: for i in range(num_layers): vision_transformer_converters += [ WeightConverter( - f"layers.0.vision_encoder.vision_encoder.transformer.layers.{i}.attention.k_proj.weight", + f"layers.0.vision_encoder.vision_transformer.layers.{i}.attention.k_proj.weight", f"vision_tower.transformer.layers.{i}.attention.k_proj.weight", ), WeightConverter( - f"layers.0.vision_encoder.vision_encoder.transformer.layers.{i}.attention.v_proj.weight", + f"layers.0.vision_encoder.vision_transformer.layers.{i}.attention.v_proj.weight", 
f"vision_tower.transformer.layers.{i}.attention.v_proj.weight", ), WeightConverter( - f"layers.0.vision_encoder.vision_encoder.transformer.layers.{i}.attention.q_proj.weight", + f"layers.0.vision_encoder.vision_transformer.layers.{i}.attention.q_proj.weight", f"vision_tower.transformer.layers.{i}.attention.q_proj.weight", ), WeightConverter( - f"layers.0.vision_encoder.vision_encoder.transformer.layers.{i}.attention.o_proj.weight", + f"layers.0.vision_encoder.vision_transformer.layers.{i}.attention.o_proj.weight", f"vision_tower.transformer.layers.{i}.attention.o_proj.weight", ), WeightConverter( - f"layers.0.vision_encoder.vision_encoder.transformer.layers.{i}.attention_norm.weight", + f"layers.0.vision_encoder.vision_transformer.layers.{i}.attention_norm.weight", f"vision_tower.transformer.layers.{i}.attention_norm.weight", ), WeightConverter( - f"layers.0.vision_encoder.vision_encoder.transformer.layers.{i}.feed_forward.down_proj.weight", + f"layers.0.vision_encoder.vision_transformer.layers.{i}.feed_forward.down_proj.weight", f"vision_tower.transformer.layers.{i}.feed_forward.down_proj.weight", ), WeightConverter( - f"layers.0.vision_encoder.vision_encoder.transformer.layers.{i}.feed_forward.gate_proj.weight", + f"layers.0.vision_encoder.vision_transformer.layers.{i}.feed_forward.gate_proj.weight", f"vision_tower.transformer.layers.{i}.feed_forward.gate_proj.weight", ), WeightConverter( - f"layers.0.vision_encoder.vision_encoder.transformer.layers.{i}.feed_forward.up_proj.weight", + f"layers.0.vision_encoder.vision_transformer.layers.{i}.feed_forward.up_proj.weight", f"vision_tower.transformer.layers.{i}.feed_forward.up_proj.weight", ), WeightConverter( - f"layers.0.vision_encoder.vision_encoder.transformer.layers.{i}.ffn_norm.weight", + f"layers.0.vision_encoder.vision_transformer.layers.{i}.ffn_norm.weight", f"vision_tower.transformer.layers.{i}.ffn_norm.weight", ), ] @@ -720,20 +724,20 @@ def _create_vision_transformer_converters(self) -> list[WeightConverter]: def _create_vision_encoder_weight_converters(self) -> list[WeightConverter]: patch_conv_converter = WeightConverter( - "layers.0.vision_encoder.vision_encoder.patch_conv.weight", + "layers.0.vision_encoder.patch_conv.weight", "vision_tower.patch_conv.weight", ) # TODO Soham: use _get_weight_and_bias_converters? 
layernorm_converters = [] layer_norm_converter = WeightConverter( - "layers.0.vision_encoder.vision_encoder.ln_pre.weight", + "layers.0.vision_encoder.norm.weight", "vision_tower.ln_pre.weight", ) layernorm_converters.append(layer_norm_converter) layer_norm_converter - if self._model.config.base_model.vision_encoder.encoder.pre_norm == NormalizationType.layer_norm: + if self._model.config.base_model.vision_encoder.encoder.pre_norm.type == NormalizationType.layer_norm: layer_norm_bias_converter = WeightConverter( - "layers.0.vision_encoder.vision_encoder.ln_pre.bias", + "layers.0.vision_encoder.norm.bias", "vision_tower.ln_pre.bias", ) layernorm_converters.append(layer_norm_bias_converter) diff --git a/fast_llm/models/gpt/model.py b/fast_llm/models/gpt/model.py index 0890051e..ffbd2281 100644 --- a/fast_llm/models/gpt/model.py +++ b/fast_llm/models/gpt/model.py @@ -27,7 +27,7 @@ RotaryEmbeddingPreprocessor, ) from fast_llm.layers.transformer.transformer import TransformerLayer -from fast_llm.layers.vision_encoder.config import VisionModelKwargs +from fast_llm.layers.vision_encoder.config import VisionEncoderDimNames, VisionModelKwargs from fast_llm.layers.vision_encoder.preprocessing import VisionPreprocessor from fast_llm.models.gpt.config import GPTBaseModelConfig, GPTModelConfig from fast_llm.models.gpt.megatron import get_init_megatron @@ -164,11 +164,16 @@ def preprocess_meta( ] image_rescale_factor = self._config.vision_encoder.normalization.rescale_factor vision_kwargs = { + VisionModelKwargs.patch_size: self._config.vision_encoder.encoder.patch_size, VisionModelKwargs.image_height: image_height, VisionModelKwargs.image_width: image_width, VisionModelKwargs.image_mean: image_mean, VisionModelKwargs.image_std: image_std, VisionModelKwargs.image_rescale_factor: image_rescale_factor, + VisionModelKwargs.rope_theta: self._config.vision_encoder.encoder.rope_theta, + VisionModelKwargs.kv_channels: self._tensor_space.get_tensor_dim( + VisionEncoderDimNames.kv_channels + ).size, } else: vision_kwargs = {} @@ -306,16 +311,6 @@ def preprocess( if self._use_flash_attention: self._flash_varlen_preprocessor.preprocess(kwargs_meta) - if batch.images is not None: - kwargs_meta[VisionModelKwargs.images] = [ - img.to(device=self._tensor_space.distributed.device, dtype=torch.uint8, non_blocking=True) - for images in batch.images - for img in images - ] - kwargs_meta[VisionModelKwargs.image_positions] = batch.image_positions - if self._config.vision_encoder: - self._vision_preprocessor.preprocess(kwargs_meta) - # TODO: Add pasts/presents to meta input? # Use lists as pointers so `past_key_values` is populated during the previous micro_sequence. 
pasts = presents @@ -349,6 +344,15 @@ def preprocess( else: labels[i, start : end + 1] = -100 kwargs[LanguageModelKwargs.labels] = labels + if batch.images is not None: + kwargs[VisionModelKwargs.images] = [ + img.to(device=self._tensor_space.distributed.device, dtype=torch.uint8, non_blocking=True) + for images in batch.images + for img in images + ] + kwargs[VisionModelKwargs.image_positions] = batch.image_positions + if self._config.vision_encoder: + self._vision_preprocessor.preprocess(kwargs) if self._config.use_absolute_position_embeddings: self._position_embedding_preprocessor.preprocess(kwargs) if self._config.transformer.rotary.enabled: From 5761a2d52cf4e7e5fcfd38ec19750be48cb06f8e Mon Sep 17 00:00:00 2001 From: root Date: Wed, 30 Apr 2025 18:24:54 +0000 Subject: [PATCH 11/82] fix --- fast_llm/data/dataset/gpt/memmap.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/fast_llm/data/dataset/gpt/memmap.py b/fast_llm/data/dataset/gpt/memmap.py index 99bfbfa4..54bf6826 100644 --- a/fast_llm/data/dataset/gpt/memmap.py +++ b/fast_llm/data/dataset/gpt/memmap.py @@ -49,11 +49,14 @@ def _init( with self._prefix.with_suffix(".idx").open("rb") as stream: Assert.eq(stream.read(9), MEMMAP_INDEX_HEADER, msg=f"File: {stream.name}") self._version = struct.unpack("= 2: self._has_spans = struct.unpack("= 3: + self._has_preference_spans = struct.unpack("= 4: self._has_images = struct.unpack("= 3: + if self._has_images and self._version >= 4: self._n_images = np.frombuffer( self._index_bin_buffer, dtype=np.int32, count=self._num_documents, offset=offset ) @@ -333,10 +336,12 @@ def write_dataset(cls, prefix: pathlib.Path | str, documents: typing.Iterable[GP idx_stream.write(MEMMAP_INDEX_HEADER) # Indicates the version # Version 2 onwards optionally add loss-masking spans - # Version 3 onwards optionally add images - idx_stream.write(struct.pack(" 0 else 0)) + # Placeholder flag for preference spans + idx_stream.write(struct.pack(" 0 else 0)) # Data type From d45d60061068b316c3e49d633ea0e8adbc2d52ef Mon Sep 17 00:00:00 2001 From: root Date: Thu, 1 May 2025 05:43:50 +0000 Subject: [PATCH 12/82] fixes --- fast_llm/data/config.py | 57 ------------------- fast_llm/data/data/gpt/data.py | 9 +-- fast_llm/data/dataset/gpt/config.py | 5 +- fast_llm/data/dataset/gpt/memmap.py | 15 +---- fast_llm/data/dataset/gpt/sampled.py | 18 ++---- fast_llm/data/image_processor.py | 55 ------------------ fast_llm/engine/schedule/config.py | 11 +--- fast_llm/layers/vision_encoder/config.py | 57 +------------------ fast_llm/layers/vision_encoder/encoder.py | 2 +- .../layers/vision_encoder/preprocessing.py | 30 +++++++--- fast_llm/models/gpt/model.py | 6 +- fast_llm/models/gpt/trainer.py | 3 +- 12 files changed, 44 insertions(+), 224 deletions(-) delete mode 100644 fast_llm/data/image_processor.py diff --git a/fast_llm/data/config.py b/fast_llm/data/config.py index f1a0fd58..1586d370 100644 --- a/fast_llm/data/config.py +++ b/fast_llm/data/config.py @@ -34,60 +34,3 @@ class TokenizerConfig(Config): desc="Path to the tokenizer file.", hint=FieldHint.core, ) - - -@config_class() -class ImageProcessorConfig(Config): - """ - Configuration for the image processor - """ - - # Defaults taken from [pixtral](https://github.com/huggingface/transformers/blob/794fde7b1c3d041519fc28ea3e1461b0cfcad4e7/src/transformers/models/pixtral/image_processing_pixtral.py#L201) - # patch_size: list[int] = Field( - # default_factory=lambda: [16, 16], - # desc="Size for each path extracted from the image. 
Each patch corresponds to a token for the transformer", - # hint=FieldHint.optional, - # ) - # max_height: int = Field( - # default=1024, - # desc="Maximum height of the image. Image will be resized if larger", - # hint=FieldHint.optional, - # ) - # max_width: int = Field( - # default=1024, - # desc="Maximum width of the image. Image will be resized if larger", - # hint=FieldHint.optional, - # ) - # mean: list[float] = Field( - # default_factory=lambda: [0.48145466, 0.4578275, 0.40821073], - # desc="Mean RGB values for pixel normalization", - # hint=FieldHint.optional, - # ) - # std: list[float] = Field( - # default_factory=lambda: [0.26862954, 0.26130258, 0.27577711], - # desc="Standard deviation RGB values for pixel normalization", - # hint=FieldHint.optional, - # ) - # rescale_factor: float = Field( - # default=255.0, - # desc="Diminisher factor for pixel normalization", - # hint=FieldHint.optional, - # ) - - -@config_class() -class MultiModalProcessorConfig(Config): - """ - Wrapper config that stores the `ImageProcessorConfig` and `TokenizerConfig` - """ - - tokenizer: TokenizerConfig = Field( - default_factory=TokenizerConfig, - desc="Configuration for the tokenizer.", - hint=FieldHint.core, - ) - image_processor: ImageProcessorConfig = Field( - default_factory=ImageProcessorConfig, - desc="Configuration for the image processor.", - hint=FieldHint.core, - ) diff --git a/fast_llm/data/data/gpt/data.py b/fast_llm/data/data/gpt/data.py index cffaa734..34b86f21 100644 --- a/fast_llm/data/data/gpt/data.py +++ b/fast_llm/data/data/gpt/data.py @@ -91,8 +91,7 @@ def __init__( max_sequence_length: int, cross_document_attention: bool = True, patch_size: list[int] | None = None, - max_image_height: int | None = None, - max_image_width: int | None = None, + max_image_size: int | None = None, ): """ Create the data and gather some basic information on the dataset(s). 
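# A minimal sketch of reading the ".idx" header written by the memmap change above:
# version 4 appends an image flag after the loss-masking-span and preference-span flags.
# The struct formats below ("<Q" for the version, "<B" for the one-byte flags) are an
# assumption based on the surrounding code, since the patch text elides them.
import struct

def read_idx_header(stream) -> dict:
    info = {"magic": stream.read(9)}                     # MEMMAP_INDEX_HEADER
    info["version"] = struct.unpack("<Q", stream.read(8))[0]
    if info["version"] >= 2:
        info["has_spans"] = struct.unpack("<B", stream.read(1))[0]
    if info["version"] >= 3:
        info["has_preference_spans"] = struct.unpack("<B", stream.read(1))[0]
    if info["version"] >= 4:
        info["has_images"] = struct.unpack("<B", stream.read(1))[0]
    return info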
@@ -103,8 +102,7 @@ def __init__( self._max_sequence_length = max_sequence_length self._cross_document_attention = cross_document_attention self._patch_size = patch_size - self._max_image_height = max_image_height - self._max_image_width = max_image_width + self._max_image_size = max_image_size def setup( self, @@ -153,8 +151,7 @@ def setup( truncate_documents=self._config.truncate_documents, cross_document_attention=self._cross_document_attention, patch_size=self._patch_size, - image_height=self._max_image_height, - image_width=self._max_image_width, + image_size=self._max_image_size, ) dataset = self._config.datasets[dataset_name].build_and_sample(sampling) self._datasets[dataset_name] = DatasetMonitor(dataset, self._config.data_sample_warn_time_ms) diff --git a/fast_llm/data/dataset/gpt/config.py b/fast_llm/data/dataset/gpt/config.py index 8022a05f..65adf0bd 100644 --- a/fast_llm/data/dataset/gpt/config.py +++ b/fast_llm/data/dataset/gpt/config.py @@ -73,9 +73,8 @@ class GPTSamplingData(SamplingData): tokenizer: "Tokenizer" truncate_documents: bool = True cross_document_attention: bool = True - patch_size: list[int] | None = None - image_height: int | None = None - image_width: int | None = None + patch_size: int | None = None + image_size: int | None = None @config_class() diff --git a/fast_llm/data/dataset/gpt/memmap.py b/fast_llm/data/dataset/gpt/memmap.py index 54bf6826..8651b8fc 100644 --- a/fast_llm/data/dataset/gpt/memmap.py +++ b/fast_llm/data/dataset/gpt/memmap.py @@ -170,20 +170,8 @@ def get( offset: int = 0, length: int | None = None, use_loss_masking_spans: bool = False, - # , patch_size: tuple(int), max_height: int, max_width: int ): # TODO Soham: handle spans - # if self._has_images: - # doc_size = self._document_sizes[idx] - # n_images = self._n_images[idx] - # image_positions = self._im_positions[idx] - # image_lengths = self._im_lengths[idx] - # image_tokens_seen = 0 - # for idx in range(n_images): - # height, width = ImageProcessor.get_resize_dims(image_lengths[0], image_lengths[1], max_height, max_width) - # n_image_tokens = (height // patch_size[0]) * (width // patch_size[1]) - # if (image_positions[idx] > offset + length) or (image_positions[idx] + n_tokens < offset): - # continue token_ids = np.frombuffer( self._bin_buffer, dtype=self._dtype, @@ -299,6 +287,9 @@ def write_dataset(cls, prefix: pathlib.Path | str, documents: typing.Iterable[GP for image in document.images: # assume 3 channels (RGB) for all images with PIL.Image.open(io.BytesIO(image["bytes"])) as img: + if img.mode == "L": + # Convert grayscale to RGB + img = img.convert("RGB") pixels = np.array(img).transpose(2, 0, 1) # HWC to CHW image_lengths.append(np.array(pixels.shape[1:])) bin_stream.write(pixels.tobytes(order="C")) diff --git a/fast_llm/data/dataset/gpt/sampled.py b/fast_llm/data/dataset/gpt/sampled.py index 973c1db5..0ba3f0e1 100644 --- a/fast_llm/data/dataset/gpt/sampled.py +++ b/fast_llm/data/dataset/gpt/sampled.py @@ -12,9 +12,9 @@ from fast_llm.data.dataset.abstract import SampledDataset from fast_llm.data.dataset.gpt.config import GPTSamplingData, ShufflingType from fast_llm.data.dataset.gpt.indexed import GPTIndexedDataset -from fast_llm.data.image_processor import ImageProcessor from fast_llm.engine.config_utils.data_type import DataType, get_unsigned_integer_type from fast_llm.engine.config_utils.run import log_main_rank +from fast_llm.layers.vision_encoder.preprocessing import get_num_patches, get_resize_dims from fast_llm.utils import Assert try: @@ -91,8 +91,7 @@ def __init__( 
self._num_samples = sampling.num_samples self._sequence_length = sampling.sequence_length self._patch_size = sampling.patch_size - self._image_height = sampling.image_height - self._image_width = sampling.image_width + self._image_size = sampling.image_size self._cross_document_attention = sampling.cross_document_attention self._config = sampling.config self._truncate_documents = sampling.truncate_documents @@ -142,7 +141,7 @@ def _sample(self) -> None: image_token_sizes = torch.zeros_like(document_sizes).to(self._device) # TODO Soham: handle max image size for i, sizes in enumerate(image_sizes): - image_token_sizes[i] = sum((sizes[:, 0] // self._patch_size[0]) * (sizes[:, 1] // self._patch_size[1])) + image_token_sizes[i] = sum((sizes[:, 0] // self._patch_size) * (sizes[:, 1] // self._patch_size)) documents_per_epoch = document_sizes.numel() tokens_per_epoch = document_sizes.sum().item() + image_token_sizes.sum().item() @@ -394,10 +393,8 @@ def __getitem__(self, index: int) -> typing.Any: document_size, image_lengths = self._indexed_dataset.get_document_size(document_index, self._patch_size) image_sizes = [ - ImageProcessor.get_num_patches_from_dims( - *ImageProcessor.get_resize_dims( - *image_length, self._image_height, self._image_width, self._patch_size - ), + get_num_patches( + *get_resize_dims(*image_length, self._image_size, self._image_size, self._patch_size), self._patch_size, ) for image_length in image_lengths @@ -443,10 +440,6 @@ def __getitem__(self, index: int) -> typing.Any: image_tokens_added += image_tokens start_pos = im_position token_ids.append(sample.token_ids[start_pos:]) - # TODO Soham: remove this - # if len(sample.images) == 1: - # sample.images.append(sample.images[0]) - # sample.image_positions = np.concatenate([sample.image_positions, sample.image_positions]) images.append(sample.images) # TODO Soham: add offsets for loss masking spans if self._config.use_loss_masking_spans: @@ -461,7 +454,6 @@ def __getitem__(self, index: int) -> typing.Any: sequence_lengths = ( np.array([ids.size - (idx == len(token_ids) - 1) for idx, ids in enumerate(token_ids)], dtype=np.int32) - # + np.array([ImageProcessor.get_num_patches(image) for image in images[idx] for idx in range(len(images))]) if not self._cross_document_attention else None ) diff --git a/fast_llm/data/image_processor.py b/fast_llm/data/image_processor.py deleted file mode 100644 index edfeceb9..00000000 --- a/fast_llm/data/image_processor.py +++ /dev/null @@ -1,55 +0,0 @@ -import math - -import torch -from torchvision.transforms.v2 import functional as F - -from fast_llm.data.config import ImageProcessorConfig - - -class ImageProcessor: - def __init__(self, config: ImageProcessorConfig): - self.patch_size = config.patch_size - self.mean = [x / config.rescale_factor for x in config.mean] - self.std = [x / config.rescale_factor for x in config.std] - self.max_height = config.max_height - self.max_width = config.max_width - assert ( - self.max_height % self.patch_size[0] == 0 - ), "max_height must be divisible by patch_size[0]. Found {max_height} and {self.patch_size[0]}" - assert ( - self.max_width % self.patch_size[1] == 0 - ), "max_width must be divisible by patch_size[1]. Found {max_width} and {self.patch_size[1]}" - - def resize(self, image): - # Resize the image to the specified size - # TODO Soham: resize for patches only during train? - # TODO Soham: convert all images to tensor? 
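# A worked example (illustrative numbers only) of the per-image token count used by _sample
# and __getitem__ above: the image is fitted to max_image_size, rounded up to whole patches,
# and contributes one token per patch.
import math

patch_size, max_image_size = 16, 1024
height, width = 220, 180                                          # raw image, already within bounds
ratio = max(height / max_image_size, width / max_image_size)      # < 1, so only round up to patches
height = patch_size * math.ceil(height / patch_size)              # 224
width = patch_size * math.ceil(width / patch_size)                # 192
num_image_tokens = (height // patch_size) * (width // patch_size)  # 14 * 12 = 168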
- # height = image.shape[0] - # width = image.shape[1] - height, width = self.get_resize_dims(image.shape[0], image.shape[1], self.max_height, self.max_width) - - # TODO: options for interpolation mode - return F.resize(image, size=(height, width), interpolation=F.InterpolationMode.BICUBIC) - - # TODO Soham: move to utils - @classmethod - def get_resize_dims(self, height, width, max_height, max_width, patch_size: list[int]): - ratio = max(height / max_height, width / max_width) - return ( - (math.ceil(height / ratio), math.ceil(width / ratio)) - if ratio > 1 - else (patch_size[0] * math.ceil(height / patch_size[0]), patch_size[1] * math.ceil(width / patch_size[1])) - ) - - def normalize(self, image: torch.Tensor) -> torch.Tensor: - # Normalize the image using the mean and std - return F.normalize(image, mean=self.mean, std=self.std) - - @classmethod - # TODO Soham: move to utils - def get_num_patches(self, image: torch.Tensor, patch_size: list[int]) -> torch.Tensor: - return (image.shape[0] // patch_size[0]) * (image.shape[1] // patch_size[1]) - - @classmethod - def get_num_patches_from_dims(self, height: int, width: int, patch_size: list[int]) -> torch.Tensor: - return (height // patch_size[0]) * (width // patch_size[1]) diff --git a/fast_llm/engine/schedule/config.py b/fast_llm/engine/schedule/config.py index 16cfaf71..9cf8f8b5 100644 --- a/fast_llm/engine/schedule/config.py +++ b/fast_llm/engine/schedule/config.py @@ -55,19 +55,14 @@ class BatchConfig(Config): hint=FieldHint.performance, valid=check_field(Assert.gt, 0), ) - patch_size: list[int] | None = Field( + patch_size: int | None = Field( default=None, desc="Patch size for each image token", hint=FieldHint.optional, ) - max_image_height: int | None = Field( + max_image_size: int | None = Field( default=None, - desc="Maximum image height for each image token", - hint=FieldHint.optional, - ) - max_image_width: int | None = Field( - default=None, - desc="Maximum image width for each image token", + desc="Maximum image height and width", hint=FieldHint.optional, ) num_micro_sequences: int = Field( diff --git a/fast_llm/layers/vision_encoder/config.py b/fast_llm/layers/vision_encoder/config.py index 65ae8e50..b83a118b 100644 --- a/fast_llm/layers/vision_encoder/config.py +++ b/fast_llm/layers/vision_encoder/config.py @@ -8,8 +8,7 @@ class VisionEncoderDimNames: out_channels = "vision_out_channels" intermediate_size = "vision_intermediate_size" - patch_height = "vision_patch_height" - patch_width = "vision_patch_width" + patch_size = "vision_patch_size" kv_channels = "vision_kv_channels" @@ -17,8 +16,7 @@ class VisionModelKwargs: patch_size = "patch_size" images = "images" image_positions = "image_positions" - image_height = "image_height" - image_width = "image_width" + image_size = "image_size" image_sizes = "image_sizes" image_mean = "image_normalization_mean" image_std = "image_normalization_std" @@ -28,30 +26,6 @@ class VisionModelKwargs: kv_channels = "vit_kv_channels" -@config_class() -class PatchConvConfig(BaseModelArchitectureConfig): - _abstract = False - """ - Configuration class for the convolution layers to apply on the image patches - """ - in_channels: int = Field( - default=3, - desc="Number of input channels for the convolution layers. 
Typically 3 for RGB images.", - hint=FieldHint.optional, - ) - bias: bool = Field( - default=False, desc="Whether to use a bias term in the convolution layers.", hint=FieldHint.optional - ) - height: int = Field( - default=16, - desc="Height of the image patches considered as tokens", - ) - width: int | None = Field( - default=16, - desc="Width of the image patches considered as tokens", - ) - - @config_class() class VisionEncoderArchitectureConfig(BaseModelArchitectureConfig): _abstract = False @@ -169,33 +143,8 @@ class VisionArchitectureConfig(BaseModelArchitectureConfig): def setup_tensor_space(self, tensor_space: TensorSpace): tensor_space.add_tensor_dim(TensorDim(VisionEncoderDimNames.out_channels, self.encoder.hidden_size)) tensor_space.add_tensor_dim(TensorDim(VisionEncoderDimNames.intermediate_size, self.adapter_size)) - tensor_space.add_tensor_dim(TensorDim(VisionEncoderDimNames.patch_height, self.encoder.patch_size)) - tensor_space.add_tensor_dim(TensorDim(VisionEncoderDimNames.patch_width, self.encoder.patch_size)) + tensor_space.add_tensor_dim(TensorDim(VisionEncoderDimNames.patch_size, self.encoder.patch_size)) # TODO Soham: add a check for kv channels tensor_space.add_tensor_dim( TensorDim(VisionEncoderDimNames.kv_channels, self.encoder.hidden_size // self.encoder.num_attention_heads) ) - # tensor_space.add_tensor_dim( - # CompositeTensorDim(VisionEncoderDimNames.) - # ) - - # patch_convolution: PatchConvConfig = Field( - # default_factory=PatchConvConfig, - # desc="Configuration for the convolution layers applied to the image patches.", - # hint=FieldHint.optional - # ) - # normalization: NormalizationArchitectureConfig = Field( - # default_factory=NormalizationArchitectureConfig, - # desc="Configuration for the normalization layers applied to the image patches.", - # hint=FieldHint.optional - # ) - # transformer: TransformerArchitectureConfig = Field( - # default_factory=TransformerArchitectureConfig, - # desc="Configuration for the transformer layers applied to the image patches.", - # hint=FieldHint.optional - # ) - # patch_rotary: RotaryArchitectureConfig = Field( - # default_factory=RotaryArchitectureConfig, - # desc="Configuration for the rotary positional embeddings applied to the image patches.", - # hint=FieldHint.optional - # ) diff --git a/fast_llm/layers/vision_encoder/encoder.py b/fast_llm/layers/vision_encoder/encoder.py index bbcebf25..8c694d28 100644 --- a/fast_llm/layers/vision_encoder/encoder.py +++ b/fast_llm/layers/vision_encoder/encoder.py @@ -161,6 +161,6 @@ def forward( input_, kwargs[VisionModelKwargs.image_sizes][:1], kwargs[VisionModelKwargs.rotary_inv_freq], - image_width=kwargs[VisionModelKwargs.image_width], + image_width=kwargs[VisionModelKwargs.image_size], ) # return self.adapter(self.vision_encoder(input_, kwargs[VisionModelKwargs.image_sizes])) diff --git a/fast_llm/layers/vision_encoder/preprocessing.py b/fast_llm/layers/vision_encoder/preprocessing.py index 57ee3a0b..154c1a16 100644 --- a/fast_llm/layers/vision_encoder/preprocessing.py +++ b/fast_llm/layers/vision_encoder/preprocessing.py @@ -1,3 +1,4 @@ +import math import typing import torch @@ -5,9 +6,17 @@ from fast_llm.engine.config_utils.tensor_space import TensorSpace from fast_llm.layers.vision_encoder.config import VisionArchitectureConfig, VisionModelKwargs +from fast_llm.utils import div -def get_resize_dims(height: int, width: int, max_height: int, max_width: int) -> tuple[int, int]: +def get_num_patches(height: int, width: int, patch_size: int) -> tuple[int, int]: + """ + 
Calculate the number of patches in height and width dimensions. + """ + return div(height, patch_size) * div(width, patch_size) + + +def get_resize_dims(height: int, width: int, max_height: int, max_width: int, patch_size: int) -> tuple[int, int]: """ Calculate the new dimensions for resizing an image while maintaining the aspect ratio. If the image is larger than the max dimensions, it will be resized to fit within them. @@ -17,12 +26,12 @@ def get_resize_dims(height: int, width: int, max_height: int, max_width: int) -> return ( (int(height / ratio), int(width / ratio)) if ratio > 1 - else (max_height * (height // max_height), max_width * (width // max_width)) + else (patch_size * math.ceil(height / patch_size), patch_size * math.ceil(width / patch_size)) ) -def resize(image: torch.Tensor, max_height: int, max_width: int) -> tuple[int, int]: - resize_dims = get_resize_dims(image.size(1), image.size(2), max_height, max_width) +def resize(image: torch.Tensor, max_height: int, max_width: int, patch_size: int) -> tuple[int, int]: + resize_dims = get_resize_dims(image.size(1), image.size(2), max_height, max_width, patch_size=patch_size) # TODO: options for interpolation mode? return F.resize(image, size=resize_dims, interpolation=F.InterpolationMode.BICUBIC) @@ -71,14 +80,17 @@ def __init__(self, config: VisionArchitectureConfig, tensor_space: TensorSpace): def preprocess(self, kwargs: dict[str, typing.Any]) -> None: images = kwargs.get("images") - im_height = kwargs.get(VisionModelKwargs.image_height) - im_width = kwargs.get(VisionModelKwargs.image_width) - image_sizes = [get_resize_dims(im.size(1), im.size(2), im_height, im_width) for im in images] + im_height = kwargs.get(VisionModelKwargs.image_size) + im_width = kwargs.get(VisionModelKwargs.image_size) + patch_size = kwargs[VisionModelKwargs.patch_size] + image_sizes = [ + get_resize_dims(im.size(1), im.size(2), im_height, im_width, patch_size=patch_size) for im in images + ] kwargs[VisionModelKwargs.image_sizes] = image_sizes images = [ pad( normalize( - resize(image, im_height, im_width) / kwargs[VisionModelKwargs.image_rescale_factor], + resize(image, im_height, im_width, patch_size) / kwargs[VisionModelKwargs.image_rescale_factor], mean=kwargs[VisionModelKwargs.image_mean], std=kwargs[VisionModelKwargs.image_std], ), @@ -97,5 +109,5 @@ def preprocess(self, kwargs: dict[str, typing.Any]) -> None: kwargs[VisionModelKwargs.rope_theta], kwargs[VisionModelKwargs.kv_channels], im_height, - kwargs[VisionModelKwargs.patch_size], + patch_size, ).to(device=self._tensor_space.distributed.device) diff --git a/fast_llm/models/gpt/model.py b/fast_llm/models/gpt/model.py index ffbd2281..c273f09b 100644 --- a/fast_llm/models/gpt/model.py +++ b/fast_llm/models/gpt/model.py @@ -150,8 +150,7 @@ def preprocess_meta( micro_sequence_length = sequence_length if self._config.vision_encoder: - image_height = batch_meta.max_image_height - image_width = batch_meta.max_image_width + image_size = batch_meta.max_image_size image_mean = [ self._config.vision_encoder.normalization.mean_r, self._config.vision_encoder.normalization.mean_g, @@ -165,8 +164,7 @@ def preprocess_meta( image_rescale_factor = self._config.vision_encoder.normalization.rescale_factor vision_kwargs = { VisionModelKwargs.patch_size: self._config.vision_encoder.encoder.patch_size, - VisionModelKwargs.image_height: image_height, - VisionModelKwargs.image_width: image_width, + VisionModelKwargs.image_size: image_size, VisionModelKwargs.image_mean: image_mean, VisionModelKwargs.image_std: image_std, 
VisionModelKwargs.image_rescale_factor: image_rescale_factor, diff --git a/fast_llm/models/gpt/trainer.py b/fast_llm/models/gpt/trainer.py index b801fbd3..bc16829b 100644 --- a/fast_llm/models/gpt/trainer.py +++ b/fast_llm/models/gpt/trainer.py @@ -22,8 +22,7 @@ def _get_data(self) -> GPTData: max_sequence_length=self._config.batch.sequence_length, cross_document_attention=self._config.batch.cross_document_attention, patch_size=self._config.batch.patch_size, - max_image_height=self._config.batch.max_image_height, - max_image_width=self._config.batch.max_image_width, + max_image_size=self._config.batch.max_image_size, ) def get_tflops(self, phase: PhaseType, elapsed_time_per_iteration) -> tuple[int, int]: From 74a99b8ec047e31acd514a48237196ed9da761be Mon Sep 17 00:00:00 2001 From: root Date: Tue, 6 May 2025 17:44:50 +0000 Subject: [PATCH 13/82] changes --- fast_llm/engine/schedule/config.py | 21 +- fast_llm/functional/config.py | 2 + fast_llm/layers/language_model/config.py | 20 +- fast_llm/layers/multi_modal/embedding.py | 52 +-- fast_llm/layers/transformer/attention.py | 109 +++--- fast_llm/layers/transformer/config.py | 96 +++-- fast_llm/layers/transformer/mlp.py | 22 +- fast_llm/layers/transformer/preprocessing.py | 139 ++++++-- fast_llm/layers/transformer/transformer.py | 18 +- fast_llm/layers/vision_encoder/adapter.py | 39 ++- fast_llm/layers/vision_encoder/config.py | 178 ++++++---- fast_llm/layers/vision_encoder/encoder.py | 141 ++------ .../layers/vision_encoder/preprocessing.py | 153 ++++++-- fast_llm/models/gpt/conversion.py | 330 +++++++++++------- fast_llm/models/gpt/model.py | 113 ++++-- fast_llm/tools/cli.py | 1 - 16 files changed, 886 insertions(+), 548 deletions(-) diff --git a/fast_llm/engine/schedule/config.py b/fast_llm/engine/schedule/config.py index 9cf8f8b5..517a9cff 100644 --- a/fast_llm/engine/schedule/config.py +++ b/fast_llm/engine/schedule/config.py @@ -55,16 +55,6 @@ class BatchConfig(Config): hint=FieldHint.performance, valid=check_field(Assert.gt, 0), ) - patch_size: int | None = Field( - default=None, - desc="Patch size for each image token", - hint=FieldHint.optional, - ) - max_image_size: int | None = Field( - default=None, - desc="Maximum image height and width", - hint=FieldHint.optional, - ) num_micro_sequences: int = Field( init=False, desc="Number of micro-sequences to split each sample (= seqence length / micro-sequence length).", @@ -81,6 +71,17 @@ class BatchConfig(Config): desc="Pointer to a distributed configuration, required to know the data-parallel split of the batch.", hint=FieldHint.setup, ) + # Image inputs + patch_size: int | None = Field( + default=None, + desc="Patch size for each image token", + hint=FieldHint.optional, + ) + max_image_size: int | None = Field( + default=None, + desc="Maximum image height and width", + hint=FieldHint.optional, + ) def setup(self, distributed_config: DistributedConfig) -> None: self._distributed = distributed_config diff --git a/fast_llm/functional/config.py b/fast_llm/functional/config.py index 9f1fe005..c5da0f9b 100644 --- a/fast_llm/functional/config.py +++ b/fast_llm/functional/config.py @@ -82,6 +82,8 @@ def _set_activation_fn_map() -> None: ActivationType.squared_relu: "relu2", } _ACTIVATION_HF_NAMES_INV = {value: key for key, value in _ACTIVATION_HF_NAMES.items()} +_ACTIVATION_HF_NAMES_INV["gelu"] = ActivationType.gelu + MAX_DROPLESS_BLOCK_SIZE_ROW = 128 diff --git a/fast_llm/layers/language_model/config.py b/fast_llm/layers/language_model/config.py index ec80a933..887952d7 100644 --- 
a/fast_llm/layers/language_model/config.py +++ b/fast_llm/layers/language_model/config.py @@ -6,7 +6,7 @@ from fast_llm.engine.distributed.config import DistributedDimNames from fast_llm.functional.config import CrossEntropyImpl from fast_llm.layers.transformer.config import TransformerArchitectureConfig, TransformerConfig -from fast_llm.layers.vision_encoder.config import VisionArchitectureConfig +from fast_llm.layers.vision_encoder.config import VisionEncoderArchitectureConfig, VisionEncoderConfig from fast_llm.utils import Assert @@ -34,6 +34,7 @@ class LanguageModelKwargs: position_ids = "position_ids" # TODO: These are generic labels = "labels" + tokens = "tokens" phase = "phase" @@ -44,7 +45,7 @@ class LanguageModelArchitectureConfig(BaseModelArchitectureConfig): desc="Configuration for the transformer architecture.", hint=FieldHint.core, ) - vision_encoder: None | VisionArchitectureConfig = Field( + vision_encoder: None | VisionEncoderArchitectureConfig = Field( default=None, desc="Configuration for the vision encoder that transforms images into embeddings.", hint=FieldHint.optional, @@ -130,7 +131,7 @@ class LanguageModelBaseConfig(LanguageModelArchitectureConfig, BaseModelConfig): architecture_class = LanguageModelArchitectureConfig transformer: TransformerConfig = FieldUpdate(default_factory=TransformerConfig) - vision_encoder: None | VisionArchitectureConfig = FieldUpdate( + vision_encoder: None | VisionEncoderConfig = FieldUpdate( default=None, desc="Configuration for the vision encoder that transforms images into embeddings.", hint=FieldHint.optional, @@ -215,16 +216,3 @@ def setup_tensor_space(self, tensor_space: TensorSpace) -> None: if self.vision_encoder is not None: self.vision_encoder.setup_tensor_space(tensor_space) - - -class MultiModalBaseConfig(BaseModelConfig): - language_model: LanguageModelBaseConfig = Field( - default_factory=LanguageModelBaseConfig, - desc="Configuration for the language model.", - hint=FieldHint.core, - ) - vision_model: VisionArchitectureConfig = Field( - default_factory=VisionArchitectureConfig, - desc="Configuration for the vision inputs.", - hint=FieldHint.core, - ) diff --git a/fast_llm/layers/multi_modal/embedding.py b/fast_llm/layers/multi_modal/embedding.py index 3b62c60b..a3abe781 100644 --- a/fast_llm/layers/multi_modal/embedding.py +++ b/fast_llm/layers/multi_modal/embedding.py @@ -7,8 +7,8 @@ from fast_llm.layers.language_model.config import LanguageModelBaseConfig, LanguageModelKwargs from fast_llm.layers.language_model.embedding import LanguageModelEmbedding from fast_llm.layers.transformer.config import TransformerKwargs -from fast_llm.layers.vision_encoder.config import VisionModelKwargs -from fast_llm.layers.vision_encoder.encoder import VisionEncoder +from fast_llm.layers.vision_encoder.config import VisionEncoderKwargs +from fast_llm.layers.vision_encoder.preprocessing import get_num_patches from fast_llm.tensor import TensorMeta @@ -23,7 +23,6 @@ def __init__( tensor_space: TensorSpace, ): super().__init__(config, tensor_space) - self.vision_encoder = VisionEncoder(config, tensor_space) def forward( self, @@ -38,46 +37,29 @@ def forward( tensor_name="Embedding output", dtype=self._residual_dtype, ) - # return self._forward( - # input_, - # kwargs.get(LanguageModelKwargs.position_ids), - # kwargs.get(VisionModelKwargs.images), - # kwargs.get(VisionModelKwargs.image_sizes), - # kwargs.get(VisionModelKwargs.image_positions), - # ) - # TODO Soham: offset position ids - images = kwargs.pop(VisionModelKwargs.images)[:1] + # 
image_embeddings = kwargs.pop(VisionEncoderKwargs.patch_embeddings) position_ids = kwargs.get(LanguageModelKwargs.position_ids) - image_positions = kwargs.get(VisionModelKwargs.image_positions)[:1] - image_embeddings = self.vision_encoder(images, kwargs) - embeddings = super()._forward(input_, position_ids) - img_tokens_seen = 0 + image_sizes = kwargs.get(VisionEncoderKwargs.image_sizes) + image_positions = kwargs.get(VisionEncoderKwargs.image_positions) + tokens = kwargs.get(LanguageModelKwargs.tokens) + # get text embeddings + embeddings = super()._forward(tokens, position_ids) image_idx = 0 - for sample_idx, positions in enumerate(image_positions): - # embedding_parts = [] - for position in positions[:1]: - image = images[image_idx] - image_tokens = (image.size(1) // self._config.vision_encoder.encoder.patch_size) * ( - image.size(2) // self._config.vision_encoder.encoder.patch_size - ) - embeddings[sample_idx, position : position + image_tokens] = image_embeddings[ - sample_idx, img_tokens_seen : img_tokens_seen + image_tokens + for sample_idx, (positions, sizes) in enumerate(zip(image_positions, image_sizes)): + image_embedding_offset = 0 + for position, size in zip(positions, sizes): + num_image_tokens = get_num_patches(*size, self._config.vision_encoder.patch_size) + embeddings[sample_idx, position : position + num_image_tokens] = input_[ + sample_idx, image_embedding_offset : image_embedding_offset + num_image_tokens ] - # embedding_parts.append(text_embeddings[sample_idx, :position]) - # embedding_parts.append(image_embeddings[sample_idx, img_tokens_seen : img_tokens_seen + image_tokens]) + image_embedding_offset += num_image_tokens image_idx += 1 - img_tokens_seen += image_tokens - # embedding_parts.append(text_embeddings[sample_idx, position:]) - # TODO Soham: debug from here - # embeddings.append(torch.cat(embedding_parts, dim=0)) - # embeddings = torch.stack(embeddings, dim=0) + with set_generator( self._tensor_space.distributed.tp_generator if self._sequence_parallel else self._tensor_space.distributed.pp_generator ): embeddings = torch.dropout(embeddings, self._dropout_p, self.training) - # assert embeddings.size(1) == 8192 - del image_embeddings - del images + return embeddings.to(self._residual_dtype) diff --git a/fast_llm/layers/transformer/attention.py b/fast_llm/layers/transformer/attention.py index c7ae55c5..3a3f4023 100644 --- a/fast_llm/layers/transformer/attention.py +++ b/fast_llm/layers/transformer/attention.py @@ -14,7 +14,9 @@ TransformerDimNames, TransformerKwargs, TransformerSubLayerName, + VisionTransformerConfig, ) +from fast_llm.layers.vision_encoder.config import VisionTransformerDimNames, VisionTransformerKwargs from fast_llm.logging import log_distributed_grad, log_distributed_tensor from fast_llm.tensor import TensorMeta, init_normal_, init_zeros_ from fast_llm.utils import Assert @@ -57,24 +59,6 @@ class Attention(torch.nn.Module): A self-attention layer. 
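# A standalone sketch (toy shapes, hypothetical values) of the scatter performed by
# MultiModalEmbedding.forward above: patch embeddings from the vision encoder overwrite the
# placeholder token embeddings at each recorded image position. Sizes are given directly in
# patches here; the real code derives them from pixel sizes via get_num_patches.
import torch

hidden = 8
embeddings = torch.zeros(1, 32, hidden)            # text embeddings for one sample
patch_embeddings = torch.randn(1, 6, hidden)       # vision encoder output for one 2 x 3 image
image_positions = [[5]]                            # the image starts at token position 5
image_sizes_in_patches = [[(2, 3)]]

for sample_idx, (positions, sizes) in enumerate(zip(image_positions, image_sizes_in_patches)):
    offset = 0
    for position, (h, w) in zip(positions, sizes):
        n = h * w
        embeddings[sample_idx, position : position + n] = patch_embeddings[sample_idx, offset : offset + n]
        offset += n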
""" - _QUERY_DIMS = ( - TransformerDimNames.batch, - TransformerDimNames.sequence_q, - TransformerDimNames.composite_heads, - TransformerDimNames.kv_channels, - ) - _KV_DIMS = ( - TransformerDimNames.batch, - TransformerDimNames.sequence_q, - TransformerDimNames.group_heads, - TransformerDimNames.kv_channels, - ) - _CONTEXT_DIMS = ( - TransformerDimNames.batch, - TransformerDimNames.sequence_q, - TransformerDimNames.composite_dense, - ) - def __init__( self, config: TransformerConfig, @@ -82,12 +66,19 @@ def __init__( layer_index, ): super().__init__() + if isinstance(config, VisionTransformerConfig): + self._transformer_dim_names = VisionTransformerDimNames + self._transformer_kwargs = VisionTransformerKwargs + elif isinstance(config, TransformerConfig): + self._transformer_dim_names = TransformerDimNames + self._transformer_kwargs = TransformerKwargs self._config = config self._tensor_space = tensor_space - Assert.in_range_incl(layer_index, 1, self._config.num_layers) + # Assert.in_range_incl(layer_index, 1, self._config.num_layers) self._layer_index = layer_index self._sequence_parallel = self._tensor_space.distributed_config.sequence_tensor_parallel self._debug_transformer = self._config.debug_transformer + self._causal = self._config.causal self._use_flash_attention = self._config.do_use_flash_attention(self._tensor_space.distributed_config) init_method_qkv = init_normal_( @@ -101,19 +92,19 @@ def __init__( max_val=self._config.init_method_max_attn_proj, ) - self._kv_channels = self._tensor_space.get_tensor_dim(TransformerDimNames.kv_channels).size - self._head_groups = self._tensor_space.get_tensor_dim(TransformerDimNames.head_groups).global_size - self._local_head_groups = self._tensor_space.get_tensor_dim(TransformerDimNames.head_groups).size - self._local_heads_per_group = self._tensor_space.get_tensor_dim(TransformerDimNames.group_heads).size + self._kv_channels = self._tensor_space.get_tensor_dim(self._transformer_dim_names.kv_channels).size + self._head_groups = self._tensor_space.get_tensor_dim(self._transformer_dim_names.head_groups).global_size + self._local_head_groups = self._tensor_space.get_tensor_dim(self._transformer_dim_names.head_groups).size + self._local_heads_per_group = self._tensor_space.get_tensor_dim(self._transformer_dim_names.group_heads).size self._local_heads = self._local_head_groups * self._local_heads_per_group self._softmax_scale = self._kv_channels ** (-self._config.attention_softmax_scale_power) - hidden_dim = self._tensor_space.get_tensor_dim(TransformerDimNames.hidden) + hidden_dim = self._tensor_space.get_tensor_dim(self._transformer_dim_names.hidden) # TODO: Merge the query and key-value computations? (harder with sequence parallel.) self.query = OutputParallelLinear( hidden_dim, - self._tensor_space.get_tensor_dim(TransformerDimNames.composite_query), + self._tensor_space.get_tensor_dim(self._transformer_dim_names.composite_query), bias=self._config.add_attn_qkv_bias, weight_init_method=init_method_qkv, bias_init_method=init_method_qkv if self._config.random_bias_init else init_zeros_, @@ -122,7 +113,7 @@ def __init__( ) self.key_value = OutputParallelLinear( hidden_dim, - self._tensor_space.get_tensor_dim(TransformerDimNames.composite_key_value), + self._tensor_space.get_tensor_dim(self._transformer_dim_names.composite_key_value), bias=self._config.add_attn_qkv_bias, weight_init_method=init_method_qkv, bias_init_method=init_method_qkv if self._config.random_bias_init else init_zeros_, @@ -133,7 +124,7 @@ def __init__( # Output. 
self.dense = InputParallelLinear( - self._tensor_space.get_tensor_dim(TransformerDimNames.composite_dense), + self._tensor_space.get_tensor_dim(self._transformer_dim_names.composite_dense), hidden_dim, bias=self._config.add_attn_dense_bias, weight_init_method=init_method_std_attn_proj, @@ -199,7 +190,7 @@ def _attn_fused( def _get_meta( self, input_: torch.Tensor, name: str, dim_names: tuple[str, ...], kwargs: dict[str, typing.Any] ) -> TensorMeta: - hidden_dims = {dim.name: dim for dim in kwargs[TransformerKwargs.hidden_dims]} + hidden_dims = {dim.name: dim for dim in kwargs[self._transformer_kwargs.hidden_dims]} return TensorMeta.from_dims( tuple( hidden_dims[dim_name] if dim_name in hidden_dims else self._tensor_space.get_tensor_dim(dim_name) @@ -209,6 +200,32 @@ def _get_meta( dtype=input_.dtype, ) + @property + def query_dims(self): + return ( + self._transformer_dim_names.batch, + self._transformer_dim_names.sequence_q, + self._transformer_dim_names.composite_heads, + self._transformer_dim_names.kv_channels, + ) + + @property + def kv_dims(self): + return ( + self._transformer_dim_names.batch, + self._transformer_dim_names.sequence_q, + self._transformer_dim_names.group_heads, + self._transformer_dim_names.kv_channels, + ) + + @property + def context_dims(self): + return ( + self._transformer_dim_names.batch, + self._transformer_dim_names.sequence_q, + self._transformer_dim_names.composite_dense, + ) + def _debug_log( self, tensor: torch.Tensor, name: str, dim_names: tuple[str, ...], kwargs: dict[str, typing.Any] ) -> None: @@ -307,12 +324,12 @@ def forward(self, input_: torch.Tensor, kwargs: dict[str, typing.Any]) -> tuple[ # TODO: Move the rest to function. - if (past_key_values := kwargs.get(TransformerKwargs.past_key_values)) is not None: + if (past_key_values := kwargs.get(self._transformer_kwargs.past_key_values)) is not None: assert sequence_first # Clear the lists so tensors can be de-allocated key_value = torch.cat((past_key_values.pop(0), key_value), dim=0) - if (presents := kwargs.get(TransformerKwargs.presents)) is not None: + if (presents := kwargs.get(self._transformer_kwargs.presents)) is not None: # Return the presents as a leaf tensors so the gradients from later micro-sequences # don't propagate to this one. 
presents.append(present := key_value.detach().requires_grad_()) @@ -339,23 +356,23 @@ def forward(self, input_: torch.Tensor, kwargs: dict[str, typing.Any]) -> tuple[ if self._config.rotary.enabled: if self._debug_transformer: - self._debug_log(query, "query_rotary_input", self._QUERY_DIMS, kwargs) + self._debug_log(query, "query_rotary_input", self.query_dims, kwargs) self._debug_log( key, "key_rotary_input", - self._KV_DIMS, + self.kv_dims, kwargs, ) rotary_fn = triton_rotary_autograd_ if self._config.rotary.triton else apply_rotary_embeddings - query = rotary_fn(query, kwargs[TransformerKwargs.rotary_freq_q]) - key = rotary_fn(key, kwargs[TransformerKwargs.rotary_freq_k]) + query = rotary_fn(query, kwargs[self._transformer_kwargs.rotary_freq_q]) + key = rotary_fn(key, kwargs[self._transformer_kwargs.rotary_freq_k]) window_size = self._decide_window_size() if self._use_flash_attention: assert _flash_available with set_generator(self._tensor_space.distributed.tp_generator): - if (cu_seqlens_q := kwargs.get(TransformerKwargs.cu_seqlens_q, None)) is not None: + if (cu_seqlens_q := kwargs.get(self._transformer_kwargs.cu_seqlens_q, None)) is not None: out_dims = query.size() query = query.view(-1, query.size(-2), query.size(-1)) key = key.view(-1, key.size(-2), key.size(-1)) @@ -365,12 +382,12 @@ def forward(self, input_: torch.Tensor, kwargs: dict[str, typing.Any]) -> tuple[ key, value, cu_seqlens_q=cu_seqlens_q, - cu_seqlens_k=kwargs.get(TransformerKwargs.cu_seqlens_k), - max_seqlen_q=kwargs.get(TransformerKwargs.max_seqlen_q), - max_seqlen_k=kwargs.get(TransformerKwargs.max_seqlen_k), + cu_seqlens_k=kwargs.get(self._transformer_kwargs.cu_seqlens_k), + max_seqlen_q=kwargs.get(self._transformer_kwargs.max_seqlen_q), + max_seqlen_k=kwargs.get(self._transformer_kwargs.max_seqlen_k), dropout_p=self._config.attention_dropout if self.training else 0.0, window_size=(-1, -1) if window_size is None else (window_size - 1, 0), - causal=True, + causal=self._causal, softmax_scale=self._softmax_scale, ).view(*out_dims) else: @@ -380,7 +397,7 @@ def forward(self, input_: torch.Tensor, kwargs: dict[str, typing.Any]) -> tuple[ value, window_size=(-1, -1) if window_size is None else (window_size - 1, 0), dropout_p=self._config.attention_dropout if self.training else 0.0, - causal=True, + causal=self._causal, softmax_scale=self._softmax_scale, ) input_ = input_.flatten(-2) @@ -390,25 +407,25 @@ def forward(self, input_: torch.Tensor, kwargs: dict[str, typing.Any]) -> tuple[ query.flatten(-2), key.flatten(-2), value.flatten(-2), - kwargs[TransformerKwargs.attention_mask], - kwargs[TransformerKwargs.attention_mask_value], + kwargs[self._transformer_kwargs.attention_mask], + kwargs[self._transformer_kwargs.attention_mask_value], ) if self._debug_transformer: - self._debug_log(query, "query", self._QUERY_DIMS, kwargs) + self._debug_log(query, "query", self.query_dims, kwargs) self._debug_log( key, "key", - self._KV_DIMS, + self.kv_dims, kwargs, ) self._debug_log( value, "value", - self._KV_DIMS, + self.kv_dims, kwargs, ) - self._debug_log(input_, "context", self._CONTEXT_DIMS, kwargs) + self._debug_log(input_, "context", self.context_dims, kwargs) if sequence_first: # TODO: Optimize (is contiguous avoidable? Transpose dense output?) 
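# An example (illustrative) of the mask consumed by the fused path above when the vision
# transformer runs with causal=False: generate_block_attention_mask, defined in the vision
# encoder earlier in this series, zeroes each image's block and fills everything else with
# the dtype minimum, so patches never attend across image boundaries. The import path
# follows the module layout introduced above.
import torch

from fast_llm.layers.vision_encoder.encoder import generate_block_attention_mask

patch_counts = [2, 3]                               # two packed images with 2 and 3 patches
patches = torch.zeros(1, sum(patch_counts), 8)      # stand-in for the packed patch embeddings
mask = generate_block_attention_mask(patch_counts, patches)
# mask[0, 0] is 5 x 5: rows 0-1 may only attend to columns 0-1, rows 2-4 only to columns 2-4.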
diff --git a/fast_llm/layers/transformer/config.py b/fast_llm/layers/transformer/config.py index 4806e37e..6b0d7ad6 100644 --- a/fast_llm/layers/transformer/config.py +++ b/fast_llm/layers/transformer/config.py @@ -84,6 +84,7 @@ class TransformerKwargs: sequence_q_dim = "sequence_q_dim" sequence_k_dim = "sequence_k_dim" sequence_length = "sequence_length" + micro_batch_size = "micro_batch_size" # TODO: Move grad_output = "grad_output" @@ -98,6 +99,8 @@ class RotaryEmbeddingType(str, enum.Enum): default = "default" llama3 = "llama3" yarn = "yarn" + # TODO Soham: generic name? + pixtral = "pixtral" @config_class() @@ -166,6 +169,15 @@ class RotaryConfig(RotaryArchitectureConfig, BaseModelConfig): pass +@config_class() +class VisionRotaryConfig(RotaryConfig): + type: RotaryEmbeddingType = Field( + default=RotaryEmbeddingType.pixtral, + desc="The type of rotary embedding to use. Choices: none, default, llama3, yarn, pixtral.", + hint=FieldHint.feature, + ) + + class AddLinearBiasChoices(str, enum.Enum): nowhere = "nowhere" everywhere = "everywhere" @@ -398,63 +410,73 @@ def _from_dict( cls._handle_renamed_field(default, "triton_rotary", ("rotary", "triton")) return super()._from_dict(default, strict, flat) - def setup_tensor_space(self, tensor_space: TensorSpace) -> None: + def setup_tensor_space(self, tensor_space: TensorSpace, type: str | None = None) -> None: tensor = tensor_space.distributed_config.get_distributed_dim(DistributedDimNames.tensor) + if type == "vision": + # TODO Soham: better way to get around circular imports? + from fast_llm.layers.vision_encoder.config import VisionTransformerDimNames + + transformer_dim_names = VisionTransformerDimNames + else: + transformer_dim_names = TransformerDimNames + # Hidden dimension - tensor_space.add_tensor_dim(TensorDim(TransformerDimNames.hidden, self.hidden_size)) + tensor_space.add_tensor_dim(TensorDim(transformer_dim_names.hidden, self.hidden_size)) # Self-attention dimensions tensor_space.add_tensor_dim( head_groups := TensorDim( - TransformerDimNames.head_groups, self.head_groups, tensor if self.head_groups > 1 else None + transformer_dim_names.head_groups, self.head_groups, tensor if self.head_groups > 1 else None ) ) tensor_space.add_tensor_dim( group_heads := TensorDim( - TransformerDimNames.group_heads, + transformer_dim_names.group_heads, div(self.num_attention_heads, self.head_groups), None if self.head_groups > 1 else tensor, ) ) - tensor_space.add_tensor_dim(key_and_value := TensorDim(TransformerDimNames.key_and_value, 2)) - tensor_space.add_tensor_dim(kv_channels := TensorDim(TransformerDimNames.kv_channels, self.kv_channels)) + tensor_space.add_tensor_dim(key_and_value := TensorDim(transformer_dim_names.key_and_value, 2)) + tensor_space.add_tensor_dim(kv_channels := TensorDim(transformer_dim_names.kv_channels, self.kv_channels)) tensor_space.add_tensor_dim( - CompositeTensorDim(TransformerDimNames.composite_heads, (head_groups, group_heads)) + CompositeTensorDim(transformer_dim_names.composite_heads, (head_groups, group_heads)) ) tensor_space.add_tensor_dim( - CompositeTensorDim(TransformerDimNames.composite_query, (head_groups, group_heads, kv_channels)) + CompositeTensorDim(transformer_dim_names.composite_query, (head_groups, group_heads, kv_channels)) ) tensor_space.add_tensor_dim( - CompositeTensorDim(TransformerDimNames.composite_key_value, (key_and_value, head_groups, kv_channels)) + CompositeTensorDim(transformer_dim_names.composite_key_value, (key_and_value, head_groups, kv_channels)) ) tensor_space.add_tensor_dim( - 
CompositeTensorDim(TransformerDimNames.composite_dense, (head_groups, group_heads, kv_channels)) + CompositeTensorDim(transformer_dim_names.composite_dense, (head_groups, group_heads, kv_channels)) ) # MLP dimensions - tensor_space.add_tensor_dim(mlp := TensorDim(TransformerDimNames.mlp, self.ffn_hidden_size, tensor)) - tensor_space.add_tensor_dim(gate_and_up := TensorDim(TransformerDimNames.gate_and_up, 2 if self.gated else 1)) - tensor_space.add_tensor_dim(CompositeTensorDim(TransformerDimNames.composite_gated_mlp, (gate_and_up, mlp))) - tensor_space.add_tensor_dim(experts := TensorDim(TransformerDimNames.experts, self.num_experts)) - tensor_space.add_tensor_dim(CompositeTensorDim(TransformerDimNames.composite_expert_mlp, (experts, mlp))) + tensor_space.add_tensor_dim(mlp := TensorDim(transformer_dim_names.mlp, self.ffn_hidden_size, tensor)) tensor_space.add_tensor_dim( - CompositeTensorDim(TransformerDimNames.composite_gated_expert_mlp, (experts, gate_and_up, mlp)) + gate_and_up := TensorDim(transformer_dim_names.gate_and_up, 2 if self.gated else 1) ) - tensor_space.add_tensor_dim(TensorDim(TransformerDimNames.top_experts, self.num_experts_per_token)) - tensor_space.add_tensor_dim(TensorDim(TransformerDimNames.unshared_experts, self.num_unshared_experts)) + tensor_space.add_tensor_dim(CompositeTensorDim(transformer_dim_names.composite_gated_mlp, (gate_and_up, mlp))) + tensor_space.add_tensor_dim(experts := TensorDim(transformer_dim_names.experts, self.num_experts)) + tensor_space.add_tensor_dim(CompositeTensorDim(transformer_dim_names.composite_expert_mlp, (experts, mlp))) + tensor_space.add_tensor_dim( + CompositeTensorDim(transformer_dim_names.composite_gated_expert_mlp, (experts, gate_and_up, mlp)) + ) + tensor_space.add_tensor_dim(TensorDim(transformer_dim_names.top_experts, self.num_experts_per_token)) + tensor_space.add_tensor_dim(TensorDim(transformer_dim_names.unshared_experts, self.num_unshared_experts)) # shared_experts if self.num_shared_experts: tensor_space.add_tensor_dim( - shared_experts := TensorDim(TransformerDimNames.shared_experts, self.num_shared_experts) + shared_experts := TensorDim(transformer_dim_names.shared_experts, self.num_shared_experts) ) tensor_space.add_tensor_dim( - CompositeTensorDim(TransformerDimNames.composite_shared_expert_mlp, (shared_experts, mlp)) + CompositeTensorDim(transformer_dim_names.composite_shared_expert_mlp, (shared_experts, mlp)) ) tensor_space.add_tensor_dim( CompositeTensorDim( - TransformerDimNames.composite_gated_shared_expert_mlp, (shared_experts, gate_and_up, mlp) + transformer_dim_names.composite_gated_shared_expert_mlp, (shared_experts, gate_and_up, mlp) ) ) @@ -656,6 +678,11 @@ class TransformerConfig(TransformerArchitectureConfig, BaseModelConfig): " Reduces memory usage, but increases fragmentation and requires CPU synchronisation. Not recommended.", hint=FieldHint.expert, ) + causal: bool = Field( + default=True, + desc="Use causal attention. Turn this off only for bidirectional attention e.g., in Vision Transformer.", + hint=FieldHint.feature, + ) def _validate(self) -> None: if self.init_method_std is None: @@ -718,3 +745,30 @@ def do_use_flash_attention(self, distributed_config: DistributedConfig) -> bool: Assert.is_(self.window_size, None) return use_flash_attention + + +@config_class() +class VisionRotaryConfig(RotaryConfig): + type: RotaryEmbeddingType = Field( + default=RotaryEmbeddingType.pixtral, + desc="The type of rotary embedding to use. 
Choices: none, default, llama3, yarn, pixtral.", + hint=FieldHint.feature, + ) + + +@config_class() +class VisionTransformerConfig(TransformerConfig): + """ + Configuration for the Vision Transformer (ViT) model. + """ + + causal: bool = FieldUpdate( + default=False, + desc="Use causal attention. Turn this off only for bidirectional attention e.g., in Vision Transformer.", + hint=FieldHint.feature, + ) + rotary: VisionRotaryConfig = FieldUpdate( + default_factory=VisionRotaryConfig, + desc="Configuration for the rotary positional embeddings.", + hint=FieldHint.feature, + ) diff --git a/fast_llm/layers/transformer/mlp.py b/fast_llm/layers/transformer/mlp.py index 9b90beff..1b494fc0 100644 --- a/fast_llm/layers/transformer/mlp.py +++ b/fast_llm/layers/transformer/mlp.py @@ -8,7 +8,14 @@ from fast_llm.functional.config import TritonConfig from fast_llm.functional.triton.mlp import mlp_autograd, torch_mlp_activation, triton_mlp_activation_autograd from fast_llm.layers.common.linear import LinearBase -from fast_llm.layers.transformer.config import TransformerConfig, TransformerDimNames, TransformerSubLayerName +from fast_llm.layers.transformer.config import ( + TransformerConfig, + TransformerDimNames, + TransformerKwargs, + TransformerSubLayerName, + VisionTransformerConfig, +) +from fast_llm.layers.vision_encoder.config import VisionTransformerDimNames, VisionTransformerKwargs from fast_llm.tensor import init_normal_, init_zeros_ from fast_llm.utils import Assert @@ -18,6 +25,13 @@ def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, name: s super().__init__() self._name = name + if isinstance(config, VisionTransformerConfig): + self._transformer_dim_names = VisionTransformerDimNames + self._transformer_kwargs = VisionTransformerKwargs + elif isinstance(config, TransformerConfig): + self._transformer_dim_names = TransformerDimNames + self._transformer_kwargs = TransformerKwargs + init_method_1 = init_normal_( std=config.init_method_std_mlp_1, min_val=config.init_method_min_mlp_1, @@ -29,8 +43,8 @@ def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, name: s max_val=config.init_method_max_mlp_2, ) - hidden_dim = tensor_space.get_tensor_dim(TransformerDimNames.hidden) - self._intermediate_dim = tensor_space.get_tensor_dim(TransformerDimNames.composite_expert_mlp) + hidden_dim = tensor_space.get_tensor_dim(self._transformer_dim_names.hidden) + self._intermediate_dim = tensor_space.get_tensor_dim(self._transformer_dim_names.composite_expert_mlp) self._sequence_parallel = tensor_space.distributed_config.sequence_tensor_parallel self._recompute_level = config.mlp_recompute_level @@ -41,7 +55,7 @@ def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, name: s # So both layers' weights have shape (num_experts [* gate_up] * ffn, hidden_size) self.layer_1 = LinearBase( hidden_dim, - tensor_space.get_tensor_dim(TransformerDimNames.composite_gated_expert_mlp), + tensor_space.get_tensor_dim(self._transformer_dim_names.composite_gated_expert_mlp), bias=config.add_mlp_bias, weight_init_method=init_method_1, bias_init_method=init_method_1 if config.random_bias_init else init_zeros_, diff --git a/fast_llm/layers/transformer/preprocessing.py b/fast_llm/layers/transformer/preprocessing.py index cbafe6c9..542b4d42 100644 --- a/fast_llm/layers/transformer/preprocessing.py +++ b/fast_llm/layers/transformer/preprocessing.py @@ -12,7 +12,9 @@ TransformerConfig, TransformerDimNames, TransformerKwargs, + VisionTransformerConfig, ) +from 
fast_llm.layers.vision_encoder.config import VisionTransformerDimNames, VisionTransformerKwargs from fast_llm.tensor import TensorMeta logger = logging.getLogger(__name__) @@ -129,63 +131,122 @@ def get_rotary_frequencies( return frequencies +def get_2d_rotary_frequencies( + config: RotaryConfig, + height, + width, + kv_channels, + *, + device="cuda", +) -> torch.Tensor: + # Calculate complex frequencies by using alternating channels for width and height + height_positions = torch.arange(height, device=device, dtype=torch.float64) + width_positions = torch.arange(width, device=device, dtype=torch.float64) + frequencies = config.theta ** -torch.arange(0, 1, 2 / kv_channels, device=device, dtype=torch.float64) + # TODO Soham: apply scaling + angles_h = torch.outer(height_positions, frequencies[::2]) + angles_w = torch.outer(width_positions, frequencies[1::2]) + angles = torch.cat( + [ + angles_h[:, None, :].repeat(1, width, 1), + angles_w[None, :, :].repeat(height, 1, 1), + ], + dim=-1, + ).reshape(-1, kv_channels // 2) + + frequencies = torch.polar(torch.ones_like(angles), angles)[None, :, None, :].to(torch.complex64) + if not config.complex_format: + frequencies = convert_rotary_complex_to_real( + torch.view_as_real(frequencies).flatten(-2), kv_channels, 3 + ).contiguous() + + return frequencies + + class RotaryEmbeddingPreprocessor: _scalar_dim: TensorDim - _kv_channels_dim: TensorDim - _rotary_embedding_frequencies: torch.Tensor _mask: torch.Tensor _mask_value: torch.Tensor - _tensor_cache_max_sequence_length: int = -1 def __init__( self, config: RotaryConfig, tensor_space: TensorSpace, ): + # if isinstance(config, TransformerConfig): + # self._transformer_dim_names = TransformerDimNames + # self._transformer_kwargs = TransformerKwargs + # elif isinstance(config, VisionTransformerConfig): + # self._transformer_dim_names = VisionTransformerDimNames + # self._transformer_kwargs = VisionTransformerKwargs + # TODO Soham: better way to do this? 
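# For intuition only: a self-contained sketch (not Fast-LLM API) of the 2D rotary table that
# get_2d_rotary_frequencies above builds for image patches. Even-indexed base frequencies rotate
# with the patch row, odd-indexed ones with the patch column, and the grid is flattened row-major
# so that patch (row, col) sits at index row * width + col; the real code indexes a square
# max-size grid the same way and additionally converts to the non-complex format when needed.
# The helper name and theta default below are illustrative assumptions.
import torch

def toy_2d_rotary(height: int, width: int, kv_channels: int, theta: float = 10000.0) -> torch.Tensor:
    freqs = theta ** -torch.arange(0, 1, 2 / kv_channels, dtype=torch.float64)  # one per channel pair
    angles_h = torch.outer(torch.arange(height, dtype=torch.float64), freqs[::2])  # rotate with rows
    angles_w = torch.outer(torch.arange(width, dtype=torch.float64), freqs[1::2])  # rotate with columns
    angles = torch.cat(
        [angles_h[:, None, :].expand(height, width, -1), angles_w[None, :, :].expand(height, width, -1)],
        dim=-1,
    ).reshape(height * width, kv_channels // 2)
    return torch.polar(torch.ones_like(angles), angles)  # complex frequencies, one row per patch

# e.g. toy_2d_rotary(4, 4, 8) has shape (16, 4): 16 patches, kv_channels // 2 complex rotations each.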
+ if config.type == RotaryEmbeddingType.pixtral: + self._transformer_dim_names = VisionTransformerDimNames + self._transformer_kwargs = VisionTransformerKwargs + else: + self._transformer_dim_names = TransformerDimNames + self._transformer_kwargs = TransformerKwargs self._config = config assert self._config.enabled self._tensor_space = tensor_space self._distributed_config = self._tensor_space.distributed_config self._scalar_dim = self._tensor_space.get_tensor_dim(DefaultDimNames.scalar) - self._kv_channels_dim = self._tensor_space.get_tensor_dim(TransformerDimNames.kv_channels) + self._kv_channels_dim = self._tensor_space.get_tensor_dim(self._transformer_dim_names.kv_channels) + self._tensor_cache_max_sequence_length: int = -1 - def create_tensors(self, sequence_length: int) -> None: + def create_tensors(self, sequence_length: int, num_patches: None | int = None) -> None: if sequence_length <= self._tensor_cache_max_sequence_length: return self._tensor_cache_max_sequence_length = sequence_length - self._rotary_embedding_frequencies = get_rotary_frequencies( - self._config, - sequence_length, - self._kv_channels_dim.global_size, - device=self._tensor_space.distributed.device, - ) + if self._config.type == RotaryEmbeddingType.pixtral: + self._rotary_embedding_frequencies = get_2d_rotary_frequencies( + self._config, + num_patches, + num_patches, + self._kv_channels_dim.global_size, + device=self._tensor_space.distributed.device, + ) + else: + self._rotary_embedding_frequencies = get_rotary_frequencies( + self._config, + sequence_length, + self._kv_channels_dim.global_size, + device=self._tensor_space.distributed.device, + ) def preprocess(self, kwargs: dict[str, typing.Any]) -> None: sequence_k = kwargs[TransformerKwargs.sequence_k_dim].size - kwargs[TransformerKwargs.rotary_freq_q] = self._rotary_embedding_frequencies[ - :, sequence_k - kwargs[TransformerKwargs.sequence_q_dim].size : sequence_k - ] - kwargs[TransformerKwargs.rotary_freq_k] = self._rotary_embedding_frequencies[:, :sequence_k] + sequence_q = kwargs[TransformerKwargs.sequence_q_dim].size + if self._config.type == RotaryEmbeddingType.pixtral: + position_ids = kwargs[self._transformer_kwargs.patch_position_ids] + # TODO Soham: use position_ids_q and position_ids_k for sequence_data_parallelism + kwargs[self._transformer_kwargs.rotary_freq_q] = self._rotary_embedding_frequencies[:, position_ids] + kwargs[self._transformer_kwargs.rotary_freq_k] = self._rotary_embedding_frequencies[:, position_ids] + else: + kwargs[self._transformer_kwargs.rotary_freq_q] = self._rotary_embedding_frequencies[ + :, sequence_k - sequence_q : sequence_k + ] + kwargs[self._transformer_kwargs.rotary_freq_k] = self._rotary_embedding_frequencies[:, :sequence_k] def preprocess_meta(self, kwargs: dict[str, typing.Any]) -> None: - kwargs[TransformerKwargs.rotary_freq_q] = TensorMeta.from_dims( + kwargs[self._transformer_kwargs.rotary_freq_q] = TensorMeta.from_dims( ( self._scalar_dim, kwargs[TransformerKwargs.sequence_q_dim], self._scalar_dim, self._kv_channels_dim, ), - tensor_name=TransformerKwargs.rotary_freq_q, + tensor_name=self._transformer_kwargs.rotary_freq_q, ) - kwargs[TransformerKwargs.rotary_freq_k] = TensorMeta.from_dims( + kwargs[self._transformer_kwargs.rotary_freq_k] = TensorMeta.from_dims( ( self._scalar_dim, kwargs[TransformerKwargs.sequence_q_dim], self._scalar_dim, self._kv_channels_dim, ), - tensor_name=TransformerKwargs.rotary_freq_k, + tensor_name=self._transformer_kwargs.rotary_freq_k, ) @@ -202,6 +263,12 @@ def __init__( config: 
TransformerConfig, tensor_space: TensorSpace, ): + if isinstance(config, VisionTransformerConfig): + self._transformer_dim_names = VisionTransformerDimNames + self._transformer_kwargs = VisionTransformerKwargs + elif isinstance(config, TransformerConfig): + self._transformer_dim_names = TransformerDimNames + self._transformer_kwargs = TransformerKwargs self._config = config self._tensor_space = tensor_space self._distributed_config = self._tensor_space.distributed_config @@ -231,22 +298,22 @@ def create_tensors(self, sequence_length: int) -> None: def preprocess(self, kwargs: dict[str, typing.Any]) -> None: sequence_k = kwargs[TransformerKwargs.sequence_k_dim].size sequence_q = kwargs[TransformerKwargs.sequence_q_dim].size - kwargs[TransformerKwargs.attention_mask] = self._mask[ + kwargs[self._transformer_kwargs.attention_mask] = self._mask[ None, None, sequence_k - sequence_q : sequence_k, None, :sequence_k ] - if (sequence_lengths := kwargs.get(TransformerKwargs.sequence_lengths, None)) is not None: + if (sequence_lengths := kwargs.get(self._transformer_kwargs.sequence_lengths, None)) is not None: seq_ids = torch.stack( [torch.cat([torch.arange(x) for x in sample_lens]) for sample_lens in sequence_lengths] ) document_mask = (seq_ids[:, None, :] == seq_ids[:, :, None]).to(self._tensor_space.distributed.device) - kwargs[TransformerKwargs.attention_mask] = ( - kwargs[TransformerKwargs.attention_mask] + kwargs[self._transformer_kwargs.attention_mask] = ( + kwargs[self._transformer_kwargs.attention_mask] & document_mask[:, None, sequence_k - sequence_q : sequence_k, None, :sequence_k] ) - kwargs[TransformerKwargs.attention_mask_value] = self._mask_value + kwargs[self._transformer_kwargs.attention_mask_value] = self._mask_value def preprocess_meta(self, kwargs: dict[str, typing.Any]) -> None: - kwargs[TransformerKwargs.attention_mask] = TensorMeta.from_dims( + kwargs[self._transformer_kwargs.attention_mask] = TensorMeta.from_dims( ( self._scalar_dim, self._scalar_dim, @@ -254,12 +321,12 @@ def preprocess_meta(self, kwargs: dict[str, typing.Any]) -> None: self._scalar_dim, kwargs[TransformerKwargs.sequence_k_dim], ), - tensor_name=TransformerKwargs.attention_mask, + tensor_name=self._transformer_kwargs.attention_mask, dtype=torch.bool, ) - kwargs[TransformerKwargs.attention_mask_value] = TensorMeta.from_dims( + kwargs[self._transformer_kwargs.attention_mask_value] = TensorMeta.from_dims( (self._scalar_dim,), - tensor_name=TransformerKwargs.attention_mask_value, + tensor_name=self._transformer_kwargs.attention_mask_value, dtype=self._tensor_space.distributed_config.training_dtype.torch, ) @@ -270,6 +337,12 @@ def __init__(self, config: TransformerConfig, tensor_space: TensorSpace): self._tensor_space = tensor_space self._distributed_config = self._tensor_space.distributed_config assert self._config.do_use_flash_attention(self._distributed_config) + if isinstance(config, VisionTransformerConfig): + self._transformer_dim_names = VisionTransformerDimNames + self._transformer_kwargs = VisionTransformerKwargs + elif isinstance(config, TransformerConfig): + self._transformer_dim_names = TransformerDimNames + self._transformer_kwargs = TransformerKwargs def preprocess(self, kwargs: dict[str, typing.Any]) -> None: """ @@ -281,7 +354,7 @@ def preprocess(self, kwargs: dict[str, typing.Any]) -> None: also contain previous tokens from the first document in micro-sequence. 
We use individual sequence lengths of each document to (optionally) find the micro-sequences in the batch and compute the cumulative lengths. """ - sequence_lengths = kwargs.get(TransformerKwargs.sequence_lengths) + sequence_lengths = kwargs.get(self._transformer_kwargs.sequence_lengths) sequence_k = kwargs[TransformerKwargs.sequence_k_dim].size sequence_q = kwargs[TransformerKwargs.sequence_q_dim].size if sequence_q < kwargs[TransformerKwargs.sequence_length]: @@ -318,17 +391,17 @@ def preprocess(self, kwargs: dict[str, typing.Any]) -> None: else: seqlens_q = torch.cat(sequence_lengths) seqlens_k = torch.cat(sequence_lengths) - kwargs[TransformerKwargs.cu_seqlens_q] = torch.cat( + kwargs[self._transformer_kwargs.cu_seqlens_q] = torch.cat( ( torch.zeros(1, dtype=torch.int32, device=self._tensor_space.distributed.device), torch.cumsum(seqlens_q, dim=0, dtype=torch.int32).to(self._tensor_space.distributed.device), ) ) - kwargs[TransformerKwargs.cu_seqlens_k] = torch.cat( + kwargs[self._transformer_kwargs.cu_seqlens_k] = torch.cat( ( torch.zeros(1, dtype=torch.int32, device=self._tensor_space.distributed.device), torch.cumsum(seqlens_k, dim=0, dtype=torch.int32).to(self._tensor_space.distributed.device), ) ) - kwargs[TransformerKwargs.max_seqlen_q] = seqlens_q.max() - kwargs[TransformerKwargs.max_seqlen_k] = seqlens_k.max() + kwargs[self._transformer_kwargs.max_seqlen_q] = seqlens_q.max() + kwargs[self._transformer_kwargs.max_seqlen_k] = seqlens_k.max() diff --git a/fast_llm/layers/transformer/transformer.py b/fast_llm/layers/transformer/transformer.py index 311403fc..ba4e5139 100644 --- a/fast_llm/layers/transformer/transformer.py +++ b/fast_llm/layers/transformer/transformer.py @@ -8,9 +8,15 @@ from fast_llm.engine.config_utils.run import log_pipeline_parallel_main_rank from fast_llm.engine.config_utils.tensor_space import TensorDim, TensorSpace from fast_llm.layers.transformer.attention import Attention -from fast_llm.layers.transformer.config import TransformerConfig, TransformerDimNames, TransformerKwargs +from fast_llm.layers.transformer.config import ( + TransformerConfig, + TransformerDimNames, + TransformerKwargs, + VisionTransformerConfig, +) from fast_llm.layers.transformer.mixture_of_experts import MixtureOfExpertMLP from fast_llm.layers.transformer.mlp import MLP +from fast_llm.layers.vision_encoder.config import VisionTransformerDimNames, VisionTransformerKwargs from fast_llm.logging import log_distributed_grad, log_distributed_tensor, log_memory_usage from fast_llm.tensor import TensorMeta @@ -30,6 +36,12 @@ def __init__( return_input: bool = False, ): super().__init__() + if isinstance(config, VisionTransformerConfig): + self._transformer_dim_names = VisionTransformerDimNames + self._transformer_kwargs = VisionTransformerKwargs + elif isinstance(config, TransformerConfig): + self._transformer_dim_names = TransformerDimNames + self._transformer_kwargs = TransformerKwargs self._config = config self._tensor_space = tensor_space self._dropout_p = self._config.hidden_dropout @@ -39,7 +51,7 @@ def __init__( self._layer_index = layer_index self._debug_mode = self._config.debug_transformer or self._config.debug_transformer_memory - hidden_dim = self._tensor_space.get_tensor_dim(TransformerDimNames.hidden) + hidden_dim = self._tensor_space.get_tensor_dim(self._transformer_dim_names.hidden) self.norm_1 = self._config.normalization.get_layer(hidden_dim) self.norm_2 = self._config.normalization.get_layer(hidden_dim) @@ -66,7 +78,7 @@ def name(self) -> str: return f"Transformer layer 
{self._layer_index}"

    def _get_meta(self, tensor: torch.Tensor, name: str, kwargs: dict):
-        dims = kwargs[TransformerKwargs.hidden_dims]
+        dims = kwargs[self._transformer_kwargs.hidden_dims]
        if self._return_input:
            dims = (TensorDim("stacked_input_output", 2),) + dims
        return TensorMeta.from_dims(dims, tensor_name=f"{self.name} {name}", dtype=tensor.dtype)
diff --git a/fast_llm/layers/vision_encoder/adapter.py b/fast_llm/layers/vision_encoder/adapter.py
index b8436f72..bf5f3f1a 100644
--- a/fast_llm/layers/vision_encoder/adapter.py
+++ b/fast_llm/layers/vision_encoder/adapter.py
@@ -1,35 +1,54 @@
+import typing
+
 import torch
+from fast_llm.engine.base_model.base_model import Layer
 from fast_llm.engine.config_utils.tensor_space import TensorSpace
+from fast_llm.functional.triton.mlp import torch_mlp_activation
 from fast_llm.layers.common.linear import Linear
-from fast_llm.layers.transformer.config import TransformerDimNames
-from fast_llm.layers.vision_encoder.config import VisionEncoderDimNames
-from fast_llm.tensor import init_normal_
+from fast_llm.layers.transformer.config import TransformerDimNames, TransformerKwargs
+from fast_llm.layers.vision_encoder.config import VisionEncoderConfig, VisionEncoderDimNames
+from fast_llm.tensor import TensorMeta, init_normal_

-class VisionAdapter(torch.nn.Module):
+class VisionAdapter(Layer):
    """
    Vision adapter layer for the LLM.
    """

-    def __init__(self, intermediate_size: int, tensor_space: TensorSpace, name: str = "vision_adapter"):
+    def __init__(self, config: VisionEncoderConfig, tensor_space: TensorSpace):
        super().__init__()
-        self._name = name
        input_dim = tensor_space.get_tensor_dim(VisionEncoderDimNames.out_channels)
+        self._activation_type = config.adapter_activation_type
+        # TODO Soham: Make them OutputParallelLinear instead? How would this work with parallelism?
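# For intuition only: a minimal stand-alone sketch of what this adapter computes, using plain
# torch modules instead of Fast-LLM's tensor-space-aware Linear. The sizes below are illustrative
# assumptions; gelu matches the default adapter_activation_type in the config.
import torch

class ToyVisionAdapter(torch.nn.Module):
    def __init__(self, vision_hidden: int = 1024, adapter_size: int = 5120, lm_hidden: int = 4096):
        super().__init__()
        self.layer_1 = torch.nn.Linear(vision_hidden, adapter_size, bias=True)
        self.layer_2 = torch.nn.Linear(adapter_size, lm_hidden, bias=True)

    def forward(self, patch_embeddings: torch.Tensor) -> torch.Tensor:
        # (batch, num_patches, vision_hidden) -> (batch, num_patches, lm_hidden)
        return self.layer_2(torch.nn.functional.gelu(self.layer_1(patch_embeddings)))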
self.layer_1 = Linear( input_dim, - tensor_space.get_tensor_dim(VisionEncoderDimNames.intermediate_size), + tensor_space.get_tensor_dim(VisionEncoderDimNames.adapter_size), bias=True, weight_init_method=init_normal_(), bias_init_method=init_normal_(), ) self.layer_2 = Linear( - tensor_space.get_tensor_dim(VisionEncoderDimNames.intermediate_size), + tensor_space.get_tensor_dim(VisionEncoderDimNames.adapter_size), tensor_space.get_tensor_dim(TransformerDimNames.hidden), bias=True, weight_init_method=init_normal_(), bias_init_method=init_normal_(), ) - def forward(self, input_: torch.Tensor): - return self.layer_2(self.layer_1(input_)) + def forward( + self, + input_: torch.Tensor, + kwargs: dict[str, typing.Any], + losses: dict[str, typing.Any] | None = None, + metrics: dict[str, typing.Any] | None = None, + ) -> torch.Tensor: + if isinstance(input_, TensorMeta): + return TensorMeta.from_dims( + kwargs[TransformerKwargs.hidden_dims], + tensor_name="Vision adapter output", + dtype=input_.dtype, + ) + return self.layer_2( + torch_mlp_activation(input_=self.layer_1(input_), gated=False, activation_type=self._activation_type) + ) diff --git a/fast_llm/layers/vision_encoder/config.py b/fast_llm/layers/vision_encoder/config.py index b83a118b..7c650bf9 100644 --- a/fast_llm/layers/vision_encoder/config.py +++ b/fast_llm/layers/vision_encoder/config.py @@ -1,20 +1,55 @@ -from fast_llm.config import Config, Field, FieldHint, config_class -from fast_llm.engine.base_model.config import BaseModelArchitectureConfig +from fast_llm.config import Config, Field, FieldHint, FieldUpdate, config_class +from fast_llm.engine.base_model.config import BaseModelArchitectureConfig, BaseModelConfig from fast_llm.engine.config_utils.tensor_space import TensorDim, TensorSpace from fast_llm.functional.config import ActivationType from fast_llm.layers.common.config import NormalizationConfig +from fast_llm.layers.transformer.config import TransformerArchitectureConfig, VisionTransformerConfig class VisionEncoderDimNames: out_channels = "vision_out_channels" - intermediate_size = "vision_intermediate_size" + adapter_size = "vision_adapter_size" patch_size = "vision_patch_size" kv_channels = "vision_kv_channels" -class VisionModelKwargs: +class VisionTransformerDimNames: + # A set of common tensor dim names packed into a namespace. + # Input dimensions (variable) + # TODO: Does batch belong here? + batch = "vit_batch" + # TODO: Distinguish micro-sequence? 
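    # Note: these names intentionally mirror TransformerDimNames one-for-one with a "vit_" prefix,
    # so the vision transformer can register its own dimensions in the shared TensorSpace
    # (selected via setup_tensor_space(..., type="vision") earlier in this patch).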
+ sequence_q = "vit_sequence_q" + sequence_q_tp = "vit_sequence_q_tp" + sequence_k = "vit_sequence_k" + hidden = "vit_hidden" + # Self-attention dimensions + head_groups = "vit_head_groups" + group_heads = "vit_group_heads" + key_and_value = "vit_key_value" + kv_channels = "vit_kv_channels" + composite_heads = "vit_composite_heads" + composite_query = "vit_composite_query" + composite_key_value = "vit_composite_key_value" + composite_dense = "vit_composite_dense" + # MLP dimensions + mlp = "vit_mlp" + gate_and_up = "vit_gate_and_up" + composite_gated_mlp = "vit_composite_gated_mlp" + experts = "vit_experts" + top_experts = "vit_top_experts" + shared_experts = "vit_shared_experts" + unshared_experts = "vit_unshared_experts" + composite_expert_mlp = "vit_composite_expert_mlp" + composite_gated_expert_mlp = "vit_composite_gated_expert_mlp" + composite_shared_expert_mlp = "vit_composite_shared_expert_mlp" + composite_gated_shared_expert_mlp = "vit_composite_gated_shared_expert_mlp" + + +class VisionEncoderKwargs: patch_size = "patch_size" images = "images" + image_patches = "image_patches" image_positions = "image_positions" image_size = "image_size" image_sizes = "image_sizes" @@ -24,56 +59,34 @@ class VisionModelKwargs: rope_theta = "vit_rope_theta" rotary_inv_freq = "vit_rotary_inv_freq" kv_channels = "vit_kv_channels" + max_image_tokens = "max_image_tokens" + patch_embeddings = "patch_embeddings" + hidden_dims = "vit_hidden_dims" -@config_class() -class VisionEncoderArchitectureConfig(BaseModelArchitectureConfig): - _abstract = False - """ - Configuration class for the vision encoder, which transforms images into embeddings - """ - path: str | None = Field(default=None, desc="Path to a pretrained vision encoder model.", hint=FieldHint.optional) - pre_norm: NormalizationConfig = Field( - default_factory=NormalizationConfig, - ) - hidden_size: int = Field( - default=1024, desc="The size of the hidden layers in the transformer model.", hint=FieldHint.optional - ) - intermediate_size: int = Field( - default=4096, - desc="The size of the intermediate (feed-forward) layers in the transformer model.", - hint=FieldHint.optional, - ) - num_hidden_layers: int = Field( - default=24, desc="The number of hidden layers in the transformer model.", hint=FieldHint.optional - ) - num_attention_heads: int = Field( - default=16, desc="The number of attention heads for the multi-head attention layers.", hint=FieldHint.optional - ) - num_channels: int = Field( - default=3, desc="Number of channels in the input image, typically 3 for RGB.", hint=FieldHint.optional - ) - image_size: int = Field( - default=1024, desc="The size of the input images (assumed square).", hint=FieldHint.optional - ) - patch_size: int = Field(default=16, desc="The size of the image patches to be encoded.", hint=FieldHint.optional) - hidden_act: str = Field( - default="gelu", desc="The activation function used in the hidden layers.", hint=FieldHint.optional - ) - attention_dropout: float = Field( - default=0.0, desc="The dropout probability for attention layers.", hint=FieldHint.optional - ) - rope_theta: float = Field( - default=10000.0, desc="The base value for rotary position embeddings.", hint=FieldHint.optional - ) - initializer_range: float = Field( - default=0.02, desc="The standard deviation of the normal initializer.", hint=FieldHint.optional - ) - activation_type: ActivationType = Field( - default=ActivationType.silu, - desc="The activation function used in the hidden layers. 
Default: SiLU.", - hint=FieldHint.optional, - ) +# TODO Soham: do we need all of them? +class VisionTransformerKwargs: + rotary_freq_q = "vit_rotary_freq_q" + rotary_freq_k = "vit_rotary_freq_k" + attention_mask = "vit_attention_mask" + attention_mask_value = "vit_attention_mask_value" + sequence_lengths = "vit_sequence_lengths" + cu_seqlens_q = "vit_cu_seqlens_q" + cu_seqlens_k = "vit_cu_seqlens_k" + max_seqlen_q = "vit_max_seqlen_q" + max_seqlen_k = "vit_max_seqlen_k" + # TODO: Review these + presents = "vit_presents" + past_key_values = "vit_past_key_values" + sequence_first = "vit_sequence_first" + hidden_dims = "vit_hidden_dims" + sequence_q_dim = "vit_sequence_q_dim" + sequence_k_dim = "vit_sequence_k_dim" + sequence_length = "vit_sequence_length" + micro_batch_size = "vit_micro_batch_size" + # TODO: Move + grad_output = "vit_grad_output" + patch_position_ids = "patch_position_ids" @config_class() @@ -116,35 +129,70 @@ class ImageNormalizationConfig(Config): @config_class() -class VisionArchitectureConfig(BaseModelArchitectureConfig): +class VisionEncoderArchitectureConfig(BaseModelArchitectureConfig): _abstract = False - encoder: VisionEncoderArchitectureConfig = Field( - default_factory=VisionEncoderArchitectureConfig, - desc="Configuration for the vision encoder that transforms images into embeddings.", + transformer: TransformerArchitectureConfig = Field( + default_factory=TransformerArchitectureConfig, + desc="Configuration for the vision transformer architecture.", + hint=FieldHint.core, + ) + patch_size: int = Field( + default=16, + desc="Patch size for the image encoder.", + hint=FieldHint.core, + ) + patch_norm: NormalizationConfig = Field( + default_factory=NormalizationConfig, + desc="Configuration for the normalization layers applied to the image patches.", hint=FieldHint.optional, ) adapter_size: int = Field( default=5120, desc="Intermediate size for the adapter linear layers. Assuming 2 linear layers", - hint=FieldHint.optional, + hint=FieldHint.core, ) adapter_activation_type: ActivationType = Field( default=ActivationType.gelu, desc="The intermediate activation type for multi-modal adapter. 
Default: GeLU.", hint=FieldHint.core, ) - normalization: ImageNormalizationConfig = Field( + + def setup_tensor_space(self, tensor_space: TensorSpace): + tensor_space.add_tensor_dim(TensorDim(VisionEncoderDimNames.out_channels, self.transformer.hidden_size)) + tensor_space.add_tensor_dim(TensorDim(VisionEncoderDimNames.adapter_size, self.adapter_size)) + tensor_space.add_tensor_dim(TensorDim(VisionEncoderDimNames.patch_size, self.patch_size)) + # TODO Soham: add a check for presence of kv channels parameter (head_dim) + tensor_space.add_tensor_dim( + TensorDim( + VisionEncoderDimNames.kv_channels, self.transformer.hidden_size // self.transformer.num_attention_heads + ) + ) + self.transformer.setup_tensor_space(tensor_space, type="vision") + + +@config_class() +class VisionEncoderConfig(VisionEncoderArchitectureConfig, BaseModelConfig): + transformer: VisionTransformerConfig = FieldUpdate( + default_factory=VisionTransformerConfig, + desc="Configuration for the transformer architecture.", + hint=FieldHint.core, + ) + patch_norm: NormalizationConfig = FieldUpdate( + default_factory=NormalizationConfig, + desc="Configuration for the normalization layers applied to the image patches.", + hint=FieldHint.optional, + ) + image_normalization: ImageNormalizationConfig = Field( default_factory=ImageNormalizationConfig, desc="Configuration for the normalization layers applied to the image patches.", hint=FieldHint.optional, ) + adapter_activation_type: ActivationType = FieldUpdate( + default=ActivationType.gelu, + desc="The intermediate activation type for multi-modal adapter. Default: GeLU.", + hint=FieldHint.core, + ) def setup_tensor_space(self, tensor_space: TensorSpace): - tensor_space.add_tensor_dim(TensorDim(VisionEncoderDimNames.out_channels, self.encoder.hidden_size)) - tensor_space.add_tensor_dim(TensorDim(VisionEncoderDimNames.intermediate_size, self.adapter_size)) - tensor_space.add_tensor_dim(TensorDim(VisionEncoderDimNames.patch_size, self.encoder.patch_size)) - # TODO Soham: add a check for kv channels - tensor_space.add_tensor_dim( - TensorDim(VisionEncoderDimNames.kv_channels, self.encoder.hidden_size // self.encoder.num_attention_heads) - ) + super().setup_tensor_space(tensor_space) diff --git a/fast_llm/layers/vision_encoder/encoder.py b/fast_llm/layers/vision_encoder/encoder.py index 8c694d28..9369037d 100644 --- a/fast_llm/layers/vision_encoder/encoder.py +++ b/fast_llm/layers/vision_encoder/encoder.py @@ -1,26 +1,20 @@ -import functools import typing import torch -from transformers import PixtralVisionConfig -from transformers.models.pixtral.modeling_pixtral import PixtralTransformer from fast_llm.engine.base_model.base_model import Layer from fast_llm.engine.config_utils.tensor_space import TensorDim, TensorSpace -from fast_llm.layers.language_model.config import LanguageModelBaseConfig -from fast_llm.layers.transformer.config import TransformerKwargs -from fast_llm.layers.vision_encoder.adapter import VisionAdapter -from fast_llm.layers.vision_encoder.config import VisionEncoderDimNames, VisionModelKwargs +from fast_llm.layers.vision_encoder.config import VisionEncoderConfig, VisionEncoderDimNames, VisionEncoderKwargs from fast_llm.tensor import ParameterMeta, TensorMeta, init_normal_ -def position_ids_in_meshgrid(patch_embeddings_list, max_width): +def position_ids_in_meshgrid(patch_embeddings_list, max_size): positions = [] for patch in patch_embeddings_list: height, width = patch.shape[-2:] mesh = torch.meshgrid(torch.arange(height), torch.arange(width), indexing="ij") 
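        # The meshgrid enumerates (row, col) patch coordinates; the ids below flatten them
        # row-major as row * max_size + col, so patches from differently sized images index
        # into one shared positional table of side max_size.
        # e.g. with max_size=64, the patch at row 2, column 5 gets position id 2 * 64 + 5 = 133.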
h_grid, v_grid = torch.stack(mesh, dim=-1).reshape(-1, 2).chunk(2, -1) - ids = h_grid * max_width + v_grid + ids = h_grid * max_size + v_grid positions.append(ids[:, 0]) return torch.cat(positions) @@ -41,108 +35,24 @@ def generate_block_attention_mask(patch_embeds_list, tensor): return causal_mask -# TODO Soham: should this just be nn.Module? -class VisionEncoder(Layer): - """ - A vision encoder layer for creating token embeddings from vision model - """ - - def __init__(self, config: LanguageModelBaseConfig, tensor_space: TensorSpace): +class PatchConv(Layer): + def __init__(self, config: VisionEncoderConfig, tensor_space: TensorSpace): super().__init__() - - self._config = config.vision_encoder - self._distributed_config = tensor_space.distributed_config + # TODO Soham: device=meta with torch.device("meta"): - # TODO Soham options to fix rotary: - # 1. load PixtralTransformer instead of PixtralVisionModel. Required to implement conv2d, ln_pre separately and store positional embeddings in kwargs_meta - # 2. set self.vision_encoder.position_embeddings = PixtralRotaryEmbedding(config) outside of meta scope - config = PixtralVisionConfig( - hidden_size=self._config.encoder.hidden_size, - intermediate_size=self._config.encoder.intermediate_size, - num_hidden_layers=self._config.encoder.num_hidden_layers, - num_attention_heads=self._config.encoder.num_attention_heads, - num_channels=self._config.encoder.num_channels, - image_size=self._config.encoder.image_size, - patch_size=self._config.encoder.patch_size, - hidden_act=self._config.encoder.hidden_act, - attention_dropout=self._config.encoder.attention_dropout, - rope_theta=self._config.encoder.rope_theta, - initializer_range=self._config.encoder.initializer_range, - ) - self.patch_conv = torch.nn.Conv2d( + self.conv = torch.nn.Conv2d( in_channels=3, - out_channels=self._config.encoder.hidden_size, - kernel_size=self._config.encoder.patch_size, - stride=self._config.encoder.patch_size, + out_channels=config.transformer.hidden_size, + kernel_size=config.patch_size, + stride=config.patch_size, bias=False, + dtype=tensor_space.distributed_config.training_dtype.torch, ) - self.patch_conv.weight = ParameterMeta.from_dims( - tuple( - TensorDim(f"patch_conv_weight_{idx}", size) - for idx, size in enumerate(self.patch_conv.weight.shape) - ), + self.conv.weight = ParameterMeta.from_dims( + tuple(TensorDim(f"patch_conv_weight_{idx}", size) for idx, size in enumerate(self.conv.weight.shape)), init_method=init_normal_(), ) - self.norm = self._config.encoder.pre_norm.get_layer( - tensor_space.get_tensor_dim(VisionEncoderDimNames.out_channels) - ) - self.vision_transformer = PixtralTransformer(config) - # self.vision_encoder = PixtralVisionModel(config) - param_names = [] - # gather all names first. 
PyTorch complains if we do it in the loop - for name, param in self.vision_transformer.named_parameters(): - param_names.append(name) - for name in param_names: - *module_path, stem = name.split(".") - module = functools.reduce(getattr, module_path, self.vision_transformer) - param = self.vision_transformer.get_parameter(name) - setattr( - module, - stem, - ParameterMeta.from_dims( - tuple(TensorDim(f"{name}_{idx}", size) for idx, size in enumerate(param.shape)), - init_method=init_normal_(), - ), - ) - # none_params = [key for key, value in module._parameters.items() if value is None] - # for key in none_params: - # module._parameters.pop(key) - self.adapter = VisionAdapter( - intermediate_size=tensor_space.get_tensor_dim(VisionEncoderDimNames.intermediate_size), - tensor_space=tensor_space, - ) - - def _forward( - self, input_: torch.Tensor, image_sizes: torch.Tensor, inv_freq: torch.Tensor, image_width: int - ) -> torch.Tensor: - patch_embeddings = self.patch_conv(input_) - patch_embeddings_list = [ - embedding[..., : image_size[0], : image_size[1]] - for embedding, image_size in zip(patch_embeddings, image_sizes) - ] - patch_embeddings = torch.cat([p.flatten(1).T for p in patch_embeddings_list], dim=0).unsqueeze(0) - patch_embeddings = self.norm(patch_embeddings) - position_ids = position_ids_in_meshgrid(patch_embeddings_list, image_width // self._config.encoder.patch_size) - freqs = inv_freq[position_ids] - with torch.autocast(device_type=input_.device.type): - cos = freqs.cos() - sin = freqs.sin() - cos = cos.to(dtype=input_.dtype) - sin = sin.to(dtype=input_.dtype) - - attention_mask = generate_block_attention_mask( - [p.shape[-2] * p.shape[-1] for p in patch_embeddings_list], patch_embeddings - ) - - (out,) = self.vision_transformer( - patch_embeddings, - attention_mask=attention_mask, - position_embeddings=(cos, sin), - output_attentions=False, - return_dict=False, - ) - - return self.adapter(out) + self.norm = config.patch_norm.get_layer(tensor_space.get_tensor_dim(VisionEncoderDimNames.out_channels)) def forward( self, @@ -150,17 +60,14 @@ def forward( kwargs: dict[str, typing.Any], losses: dict[str, typing.Any] | None = None, metrics: dict | None = None, - ) -> tuple[torch.Tensor, torch.Tensor | None]: + ) -> torch.Tensor: + hidden_dims = kwargs[VisionEncoderKwargs.hidden_dims] if isinstance(input_, TensorMeta): - return TensorMeta.from_dims( - kwargs[TransformerKwargs.hidden_dims], - tensor_name="Vision Output", - dtype=self._distributed_config.training_dtype.torch, - ) - return self._forward( - input_, - kwargs[VisionModelKwargs.image_sizes][:1], - kwargs[VisionModelKwargs.rotary_inv_freq], - image_width=kwargs[VisionModelKwargs.image_size], - ) - # return self.adapter(self.vision_encoder(input_, kwargs[VisionModelKwargs.image_sizes])) + return TensorMeta.from_dims(hidden_dims) + # we don't need images after this point + # image_patches = kwargs.pop(VisionEncoderKwargs.image_patches) + patch_embeddings = self.norm(self.conv(input_)) + patch_embeddings = patch_embeddings.reshape(*(x.size for x in hidden_dims)) + # Hack to pass patch embeddings to the next layer + # kwargs[VisionEncoderKwargs.patch_embeddings] = patch_embeddings + return patch_embeddings diff --git a/fast_llm/layers/vision_encoder/preprocessing.py b/fast_llm/layers/vision_encoder/preprocessing.py index 154c1a16..abae6f11 100644 --- a/fast_llm/layers/vision_encoder/preprocessing.py +++ b/fast_llm/layers/vision_encoder/preprocessing.py @@ -5,7 +5,12 @@ import torchvision.transforms.v2.functional as F from 
fast_llm.engine.config_utils.tensor_space import TensorSpace -from fast_llm.layers.vision_encoder.config import VisionArchitectureConfig, VisionModelKwargs +from fast_llm.layers.transformer.config import TransformerKwargs +from fast_llm.layers.vision_encoder.config import ( + VisionEncoderArchitectureConfig, + VisionEncoderKwargs, + VisionTransformerKwargs, +) from fast_llm.utils import div @@ -23,11 +28,11 @@ def get_resize_dims(height: int, width: int, max_height: int, max_width: int, pa If the image is smaller, it will be resized to the nearest multiple of the patch size. """ ratio = max(height / max_height, width / max_width) - return ( - (int(height / ratio), int(width / ratio)) - if ratio > 1 - else (patch_size * math.ceil(height / patch_size), patch_size * math.ceil(width / patch_size)) - ) + if ratio > 1: + # Resize to fit within max dimensions + height = int(height / ratio) + width = int(width / ratio) + return patch_size * math.ceil(height / patch_size), patch_size * math.ceil(width / patch_size) def resize(image: torch.Tensor, max_height: int, max_width: int, patch_size: int) -> tuple[int, int]: @@ -72,42 +77,128 @@ def create_inv_freqs(rope_theta: int, kv_channels: int, image_size: int, patch_s return torch.cat((inv_freq, inv_freq), dim=-1) +def position_ids_in_meshgrid(image_sizes: list[torch.Tensor], max_size: int, patch_size: int) -> torch.Tensor: + positions = [] + for h, w in image_sizes: + patch_height = h // patch_size + patch_width = w // patch_size + mesh = torch.meshgrid(torch.arange(patch_height), torch.arange(patch_width), indexing="ij") + h_grid, v_grid = torch.stack(mesh, dim=-1).reshape(-1, 2).chunk(2, -1) + ids = h_grid * max_size + v_grid + positions.append(ids[:, 0]) + return positions + + +def position_ids_in_meshgrid(height, width, max_size, patch_size) -> torch.Tensor: + patch_height = height // patch_size + patch_width = width // patch_size + mesh = torch.meshgrid(torch.arange(patch_height), torch.arange(patch_width), indexing="ij") + h_grid, v_grid = torch.stack(mesh, dim=-1).reshape(-1, 2).chunk(2, -1) + ids = h_grid * max_size + v_grid + return ids[:, 0] + + class VisionPreprocessor: - def __init__(self, config: VisionArchitectureConfig, tensor_space: TensorSpace): + def __init__(self, config: VisionEncoderArchitectureConfig, tensor_space: TensorSpace): self._config = config self._tensor_space = tensor_space self._distributed_config = self._tensor_space.distributed_config def preprocess(self, kwargs: dict[str, typing.Any]) -> None: - images = kwargs.get("images") - im_height = kwargs.get(VisionModelKwargs.image_size) - im_width = kwargs.get(VisionModelKwargs.image_size) - patch_size = kwargs[VisionModelKwargs.patch_size] + images = kwargs.get(VisionEncoderKwargs.images) + im_height = kwargs.get(VisionEncoderKwargs.image_size) + im_width = kwargs.get(VisionEncoderKwargs.image_size) + patch_size = kwargs[VisionEncoderKwargs.patch_size] image_sizes = [ - get_resize_dims(im.size(1), im.size(2), im_height, im_width, patch_size=patch_size) for im in images + [get_resize_dims(im.size(1), im.size(2), im_height, im_width, patch_size=patch_size) for im in ims] + for ims in images ] - kwargs[VisionModelKwargs.image_sizes] = image_sizes + kwargs[VisionEncoderKwargs.image_sizes] = image_sizes images = [ - pad( + [ normalize( - resize(image, im_height, im_width, patch_size) / kwargs[VisionModelKwargs.image_rescale_factor], - mean=kwargs[VisionModelKwargs.image_mean], - std=kwargs[VisionModelKwargs.image_std], - ), - max_height=im_height, - max_width=im_width, - ) - 
for image in images + resize(image, im_height, im_width, patch_size).to( + dtype=self._tensor_space.distributed_config.training_dtype.torch + ) + / kwargs[VisionEncoderKwargs.image_rescale_factor], + mean=kwargs[VisionEncoderKwargs.image_mean], + std=kwargs[VisionEncoderKwargs.image_std], + ) + for image in imgs + ] + for imgs in images ] - images = torch.stack(images, dim=0).to( - # TODO Soham: is this needed? - device=self._tensor_space.distributed.device, - dtype=self._distributed_config.training_dtype.torch, - ) - kwargs[VisionModelKwargs.images] = images - kwargs[VisionModelKwargs.rotary_inv_freq] = create_inv_freqs( - kwargs[VisionModelKwargs.rope_theta], - kwargs[VisionModelKwargs.kv_channels], + # position_ids = position_ids_in_meshgrid(image_sizes, im_height, patch_size) + patches = [] + patch_position_ids = [] + cu_seqlens = [0] + max_seqlen = -1 + for imgs, sizes in zip(images, image_sizes): + # TODO Soham: should this be micro_sequence_length? + # sum( + # get_num_patches(*size, patch_size) for size in sizes + # ) + seq_patches = [] + for image, size in zip(imgs, sizes): + seqlen = get_num_patches(*size, patch_size) + if seqlen > max_seqlen: + max_seqlen = seqlen + cu_seqlens.append(cu_seqlens[-1] + seqlen) + seq_patches.append( + torch.cat( + [ + torch.nn.functional.unfold(image, kernel_size=patch_size, stride=patch_size).T.reshape( + -1, 3, patch_size, patch_size + ), + ] + ) + ) + padding_size = kwargs[TransformerKwargs.sequence_length] - cu_seqlens[-1] + if padding_size > max_seqlen: + max_seqlen = padding_size + cu_seqlens.append(kwargs[TransformerKwargs.sequence_length]) + patches.append( + torch.cat( + [ + *seq_patches, + torch.zeros(padding_size, 3, patch_size, patch_size).to( + dtype=self._tensor_space.distributed_config.training_dtype.torch, + device=self._tensor_space.distributed.device, + ), + ] + ) + ) + position_ids = torch.cat( + [position_ids_in_meshgrid(*size, im_height // patch_size, patch_size) for size in sizes] + ).to(device=self._tensor_space.distributed.device) + # We pad at the end instead of padding at the position in meshgrid because flash attention does not support custom attention masks + patch_position_ids.append( + torch.cat( + [ + position_ids, + torch.full((padding_size,), 0).to(device=self._tensor_space.distributed.device), + ] + ) + ) + # TODO Soham: remove + assert patches[-1].size(0) == kwargs[TransformerKwargs.sequence_length] + patches = torch.cat(patches) + patch_position_ids = torch.cat(patch_position_ids) + kwargs[VisionEncoderKwargs.image_patches] = patches + kwargs[VisionEncoderKwargs.rotary_inv_freq] = create_inv_freqs( + kwargs[VisionEncoderKwargs.rope_theta], + kwargs[VisionEncoderKwargs.kv_channels], im_height, patch_size, ).to(device=self._tensor_space.distributed.device) + kwargs[VisionEncoderKwargs.max_image_tokens] = div(im_height * im_width, patch_size**2) + kwargs[VisionTransformerKwargs.patch_position_ids] = patch_position_ids + # TODO Soham: handle sequence data parallel + kwargs[VisionTransformerKwargs.cu_seqlens_q] = torch.tensor( + cu_seqlens, device=self._tensor_space.distributed.device, dtype=torch.int32 + ) + kwargs[VisionTransformerKwargs.cu_seqlens_k] = torch.tensor( + cu_seqlens, device=self._tensor_space.distributed.device, dtype=torch.int32 + ) + kwargs[VisionTransformerKwargs.max_seqlen_q] = max_seqlen + kwargs[VisionTransformerKwargs.max_seqlen_k] = max_seqlen diff --git a/fast_llm/models/gpt/conversion.py b/fast_llm/models/gpt/conversion.py index bd7da797..d599a114 100644 --- 
a/fast_llm/models/gpt/conversion.py +++ b/fast_llm/models/gpt/conversion.py @@ -165,7 +165,7 @@ def _get_mlp_converters(self, fast_llm_prefix: str, hf_prefix: str) -> list[Weig def _create_weight_converters( self, hf_base_prefix: str = "", - fast_llm_offset: int = 0, + fast_llm_offset: int = 1, ) -> list[WeightConverter]: converters = [] num_layers = self._model.config.base_model.transformer.num_layers @@ -187,9 +187,18 @@ def _create_weight_converters( return converters def _create_transformer_layer_converters( - self, i: int, ignore_export: bool = False, hf_base_prefix: str = "", fast_llm_offset: int = 1 + self, + i: int, + ignore_export: bool = False, + hf_base_prefix: str = "", + fast_llm_offset: int = 1, + type: str | None = None, ) -> list[WeightConverter]: - transformer_config: TransformerConfig = self._model.config.base_model.transformer + if type is not None: + if type == "vision": + transformer_config: TransformerConfig = self._model.config.base_model.vision_encoder.transformer + else: + transformer_config: TransformerConfig = self._model.config.base_model.transformer norm_bias: bool = self._model.config.base_model.transformer.normalization.type == NormalizationType.layer_norm converters = [] names_bias_cls = [ @@ -565,6 +574,7 @@ def _create_config_converters(cls) -> list[ParamConverter]: lm_converters[-2] = ConstantExportParamConverter( export_names=(("architectures",),), export_value=["LlavaForConditionalGeneration"] ) + # TODO Soham: cleaner way to get language model config converters for converter in lm_converters: if isinstance(converter, (RenameParamConverter, MappedConfigParamConverter, RopeScalingParamConverter)): # Llava uses a different name for the text config @@ -579,31 +589,36 @@ def _create_config_converters(cls) -> list[ParamConverter]: export_names=(("text_config", "hidden_size"),), ), # Image processing and conv layer - RenameParamConverter( - fast_llm_names=(("vision_encoder", "encoder", "image_size"),), - export_names=( - ( - "vision_config", - "image_size", - ), - ), - ), - RenameParamConverter( - fast_llm_names=(("vision_encoder", "encoder", "patch_size"),), - export_names=( - ( - "vision_config", - "patch_size", - ), - ), + # TODO Soham: these options are not in the fast-llm model config. 
They're read from BatchConfig currently + # RenameParamConverter( + # fast_llm_names=(("vision_encoder", "encoder", "image_size"),), + # export_names=( + # ( + # "vision_config", + # "image_size", + # ), + # ), + # ), + # RenameParamConverter( + # fast_llm_names=(("vision_encoder", "encoder", "patch_size"),), + # export_names=( + # ( + # "vision_config", + # "patch_size", + # ), + # ), + # ), + ConstantImportParamConverter( + fast_llm_names=(("vision_encoder", "patch_norm", "type"),), + fast_llm_value=NormalizationType.rms_norm, ), ConstantImportParamConverter( - fast_llm_names=(("vision_encoder", "encoder", "pre_norm", "type"),), + fast_llm_names=(("vision_encoder", "transformer", "normalization", "type"),), fast_llm_value=NormalizationType.rms_norm, ), # Vision Transformer RenameParamConverter( - fast_llm_names=(("vision_encoder", "encoder", "num_hidden_layers"),), + fast_llm_names=(("vision_encoder", "transformer", "num_layers"),), export_names=( ( "vision_config", @@ -612,7 +627,7 @@ def _create_config_converters(cls) -> list[ParamConverter]: ), ), RenameParamConverter( - fast_llm_names=(("vision_encoder", "encoder", "hidden_size"),), + fast_llm_names=(("vision_encoder", "transformer", "hidden_size"),), export_names=( ( "vision_config", @@ -621,7 +636,7 @@ def _create_config_converters(cls) -> list[ParamConverter]: ), ), RenameParamConverter( - fast_llm_names=(("vision_encoder", "encoder", "num_attention_heads"),), + fast_llm_names=(("vision_encoder", "transformer", "num_attention_heads"),), export_names=( ( "vision_config", @@ -630,144 +645,213 @@ def _create_config_converters(cls) -> list[ParamConverter]: ), ), RenameParamConverter( - fast_llm_names=(("vision_encoder", "encoder", "intermediate_size"),), + fast_llm_names=(("vision_encoder", "transformer", "head_groups"),), export_names=( ( "vision_config", - "intermediate_size", + "num_key_value_heads", ), ), ), - MappedConfigParamConverter( - fast_llm_names=(("vision_encoder", "encoder", "activation_type"),), - export_names=( - ( - "vision_config", - "hidden_act", - ), - ), - fast_llm_value=ActivationType.from_hf_name, - export_value=lambda activation_type: activation_type.hf_name, - ), RenameParamConverter( - fast_llm_names=(("vision_encoder", "encoder", "num_channels"),), + fast_llm_names=(("vision_encoder", "transformer", "ffn_hidden_size"),), export_names=( ( "vision_config", - "num_channels", + "intermediate_size", ), ), ), - RenameParamConverter( - fast_llm_names=(("vision_encoder", "encoder", "attention_dropout"),), + MappedConfigParamConverter( + fast_llm_names=(("vision_encoder", "transformer", "activation_type"),), export_names=( ( "vision_config", - "attention_dropout", + "hidden_act", ), ), + fast_llm_value=ActivationType.from_hf_name, + export_value=lambda activation_type: activation_type.hf_name, ), - RenameParamConverter( - fast_llm_names=(("vision_encoder", "encoder", "rope_theta"),), - export_names=(("vision_config", "rope_theta"),), + ConstantImportParamConverter( + fast_llm_names=(("vision_encoder", "transformer", "gated"),), fast_llm_value=True ), + MappedConfigParamConverter( + fast_llm_names=(("vision_encoder", "adapter_activation_type"),), + export_names=(("projector_hidden_act",),), + fast_llm_value=ActivationType.from_hf_name, + export_value=lambda activation_type: activation_type.hf_name, + ), + ConstantImportParamConverter( + fast_llm_names=(("vision_encoder", "transformer", "add_linear_biases"),), fast_llm_value=False + ), + # TODO Soham: add this config param for completeness? 
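            # For reference, the converters in this list map roughly as follows (Fast-LLM <-> HF Llava export):
            #   vision_encoder.transformer.num_layers          <-> vision_config.num_hidden_layers
            #   vision_encoder.transformer.hidden_size         <-> vision_config.hidden_size
            #   vision_encoder.transformer.num_attention_heads <-> vision_config.num_attention_heads
            #   vision_encoder.transformer.head_groups         <-> vision_config.num_key_value_heads
            #   vision_encoder.transformer.ffn_hidden_size     <-> vision_config.intermediate_size
            #   vision_encoder.transformer.activation_type     <-> vision_config.hidden_act
            #   vision_encoder.transformer.rotary.theta        <-> vision_config.rope_theta
            #   vision_encoder.adapter_activation_type         <-> projector_hidden_act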
+ # RenameParamConverter( + # fast_llm_names=(("vision_encoder", "encoder", "num_channels"),), + # export_names=( + # ( + # "vision_config", + # "num_channels", + # ), + # ), + # ), + # RenameParamConverter( + # fast_llm_names=(("vision_encoder", "transformer", "attention_dropout"),), + # export_names=( + # ( + # "vision_config", + # "attention_dropout", + # ), + # ), + # ), RenameParamConverter( - fast_llm_names=(("vision_encoder", "encoder", "initializer_range"),), - export_names=(("vision_config", "initializer_range"),), + fast_llm_names=(("vision_encoder", "transformer", "rotary", "theta"),), + export_names=(("vision_config", "rope_theta"),), ), + # TODO Soham: add this config param in vision encoder for completeness? + # RenameParamConverter( + # fast_llm_names=(("vision_encoder", "transformer", "initializer_range"),), + # export_names=(("vision_config", "initializer_range"),), + # ), ] def _create_vision_transformer_converters(self) -> list[WeightConverter]: - num_layers = self._model.config.base_model.vision_encoder.encoder.num_hidden_layers + num_layers = self._model.config.base_model.vision_encoder.transformer.num_layers vision_transformer_converters = [] - for i in range(num_layers): - vision_transformer_converters += [ - WeightConverter( - f"layers.0.vision_encoder.vision_transformer.layers.{i}.attention.k_proj.weight", - f"vision_tower.transformer.layers.{i}.attention.k_proj.weight", - ), - WeightConverter( - f"layers.0.vision_encoder.vision_transformer.layers.{i}.attention.v_proj.weight", - f"vision_tower.transformer.layers.{i}.attention.v_proj.weight", - ), - WeightConverter( - f"layers.0.vision_encoder.vision_transformer.layers.{i}.attention.q_proj.weight", - f"vision_tower.transformer.layers.{i}.attention.q_proj.weight", - ), - WeightConverter( - f"layers.0.vision_encoder.vision_transformer.layers.{i}.attention.o_proj.weight", - f"vision_tower.transformer.layers.{i}.attention.o_proj.weight", - ), - WeightConverter( - f"layers.0.vision_encoder.vision_transformer.layers.{i}.attention_norm.weight", - f"vision_tower.transformer.layers.{i}.attention_norm.weight", - ), - WeightConverter( - f"layers.0.vision_encoder.vision_transformer.layers.{i}.feed_forward.down_proj.weight", - f"vision_tower.transformer.layers.{i}.feed_forward.down_proj.weight", - ), - WeightConverter( - f"layers.0.vision_encoder.vision_transformer.layers.{i}.feed_forward.gate_proj.weight", - f"vision_tower.transformer.layers.{i}.feed_forward.gate_proj.weight", - ), - WeightConverter( - f"layers.0.vision_encoder.vision_transformer.layers.{i}.feed_forward.up_proj.weight", - f"vision_tower.transformer.layers.{i}.feed_forward.up_proj.weight", - ), - WeightConverter( - f"layers.0.vision_encoder.vision_transformer.layers.{i}.ffn_norm.weight", - f"vision_tower.transformer.layers.{i}.ffn_norm.weight", - ), - ] + for layer in range(num_layers): + # TODO Soham: check if args are correct + vision_transformer_converters.extend( + self._create_vision_transformer_layer_converters( + layer, + ignore_export=False, + hf_base_prefix="vision_tower.transformer.layers.", + fast_llm_offset=1, + type="vision", + ) + ) return vision_transformer_converters def _create_vision_encoder_weight_converters(self) -> list[WeightConverter]: - patch_conv_converter = WeightConverter( - "layers.0.vision_encoder.patch_conv.weight", - "vision_tower.patch_conv.weight", - ) - # TODO Soham: use _get_weight_and_bias_converters? 
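        # Layer index layout assumed by the rewritten converters below (with N = vision_encoder.transformer.num_layers):
        #   layers.0        -> patch convolution + pre-norm  (vision_tower.patch_conv / vision_tower.ln_pre)
        #   layers.1..N     -> vision transformer blocks     (vision_tower.transformer.layers.{i})
        #   layers.N+1      -> vision adapter                (multi_modal_projector.linear_1 / linear_2)
        #   layers.N+2      -> multi-modal embedding
        #   layers.N+3..    -> language model layers         (prefixed with language_model. on export)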
- layernorm_converters = [] - layer_norm_converter = WeightConverter( - "layers.0.vision_encoder.norm.weight", - "vision_tower.ln_pre.weight", - ) - layernorm_converters.append(layer_norm_converter) - layer_norm_converter - if self._model.config.base_model.vision_encoder.encoder.pre_norm.type == NormalizationType.layer_norm: - layer_norm_bias_converter = WeightConverter( - "layers.0.vision_encoder.norm.bias", - "vision_tower.ln_pre.bias", - ) - layernorm_converters.append(layer_norm_bias_converter) + patch_conv_converter = WeightConverter("layers.0.conv.weight", "vision_tower.patch_conv.weight") + layernorm_converters = [ + WeightConverter("layers.0.norm.weight", "vision_tower.ln_pre.weight"), + ] + if self._model.config.base_model.vision_encoder.patch_norm.type == NormalizationType.layer_norm: + layernorm_converters.append(WeightConverter("layers.0.norm.bias", "vision_tower.ln_pre.bias")) + vision_transformer_converters = self._create_vision_transformer_converters() + offset = self._model.config.base_model.vision_encoder.transformer.num_layers + 1 adapter_converters = [ - WeightConverter( - "layers.0.vision_encoder.adapter.layer_1.weight", - "multi_modal_projector.linear_1.weight", - ), - WeightConverter( - "layers.0.vision_encoder.adapter.layer_1.bias", - "multi_modal_projector.linear_1.bias", - ), - # TODO Soham: conditionally add bias - WeightConverter( - "layers.0.vision_encoder.adapter.layer_2.weight", - "multi_modal_projector.linear_2.weight", - ), - WeightConverter( - "layers.0.vision_encoder.adapter.layer_2.bias", - "multi_modal_projector.linear_2.bias", - ), + WeightConverter(f"layers.{offset}.layer_1.weight", "multi_modal_projector.linear_1.weight"), + WeightConverter(f"layers.{offset}.layer_1.bias", "multi_modal_projector.linear_1.bias"), + # TODO Soham: add bias based on config + WeightConverter(f"layers.{offset}.layer_2.weight", "multi_modal_projector.linear_2.weight"), + WeightConverter(f"layers.{offset}.layer_2.bias", "multi_modal_projector.linear_2.bias"), ] + return [patch_conv_converter] + layernorm_converters + vision_transformer_converters + adapter_converters def _create_weight_converters(self) -> list[WeightConverter]: vision_encoder_converter = self._create_vision_encoder_weight_converters() - lm_converters = super()._create_weight_converters(hf_base_prefix="language_model.", fast_llm_offset=1) + offset = self._model.config.base_model.vision_encoder.transformer.num_layers + 3 + lm_converters = super()._create_weight_converters(hf_base_prefix="language_model.", fast_llm_offset=offset) return vision_encoder_converter + lm_converters + def _create_vision_transformer_layer_converters( + self, + i: int, + ignore_export: bool = False, + hf_base_prefix: str = "", + fast_llm_offset: int = 1, + type: str | None = None, + ) -> list[WeightConverter]: + if type is not None: + if type == "vision": + transformer_config: TransformerConfig = self._model.config.base_model.vision_encoder.transformer + else: + transformer_config: TransformerConfig = self._model.config.base_model.transformer + norm_bias: bool = self._model.config.base_model.transformer.normalization.type == NormalizationType.layer_norm + converters = [] + names_bias_cls = [ + # Self-attn + ( + f"layers.{i+fast_llm_offset}.self_attn.query", + f"vision_tower.transformer.layers.{i}.attention.q_proj", + transformer_config.add_attn_qkv_bias, + QueryWeightConverter, + ), + ( + f"layers.{i+fast_llm_offset}.self_attn.key_value", + ( + f"vision_tower.transformer.layers.{i}.attention.k_proj", + 
f"vision_tower.transformer.layers.{i}.attention.v_proj", + ), + transformer_config.add_attn_qkv_bias, + KeyValueWeightConverter, + ), + ( + f"layers.{i+fast_llm_offset}.self_attn.dense", + f"vision_tower.transformer.layers.{i}.attention.o_proj", + transformer_config.add_attn_dense_bias, + WeightConverter, + ), + # Norm + ( + f"layers.{i+fast_llm_offset}.norm_1", + f"vision_tower.transformer.layers.{i}.attention_norm", + norm_bias, + WeightConverter, + ), + ( + f"layers.{i+fast_llm_offset}.norm_2", + f"vision_tower.transformer.layers.{i}.ffn_norm", + norm_bias, + WeightConverter, + ), + ] + for fast_llm_prefix, hf_prefix, use_bias, cls in names_bias_cls: + converters += self._get_weight_and_bias_converters( + fast_llm_prefix, + () if ignore_export else hf_prefix, + use_bias, + cls=IgnoreExportWeightConverter if ignore_export else cls, + ) + + # MLP + if ignore_export: + converters += self._get_weight_and_bias_converters( + f"layers.{i+fast_llm_offset}.mlp.layer_1", + (), + transformer_config.add_mlp_bias, + cls=IgnoreExportWeightConverter, + ) + converters += self._get_weight_and_bias_converters( + f"layers.{i+fast_llm_offset}.mlp.layer_2", + (), + transformer_config.add_mlp_bias, + cls=IgnoreExportWeightConverter, + ) + converters += [IgnoreExportWeightConverter(f"layers.{i+fast_llm_offset}.mlp.router.weight", ())] + else: + converters += self._get_vision_transformer_mlp_converters( + f"layers.{i+fast_llm_offset}", f"vision_tower.transformer.layers.{i}" + ) + return converters + + def _get_vision_transformer_mlp_converters(self, fast_llm_prefix: str, hf_prefix: str) -> list[WeightConverter]: + return [ + SplitWeightConverter( + f"{fast_llm_prefix}.mlp.layer_1.weight", + (f"{hf_prefix}.feed_forward.gate_proj.weight", f"{hf_prefix}.feed_forward.up_proj.weight"), + ), + MLPLayer2Converter( + f"{fast_llm_prefix}.mlp.layer_2.weight", + f"{hf_prefix}.feed_forward.down_proj.weight", + self._model.config.base_model, + ), + ] + class MixtralHuggingfaceCheckpointHandler(CommonLlamaHuggingfaceCheckpointHandler): format: typing.ClassVar[type[CheckpointFormat]] = MixtralGPTHuggingfaceCheckpointFormat diff --git a/fast_llm/models/gpt/model.py b/fast_llm/models/gpt/model.py index c273f09b..6aef273f 100644 --- a/fast_llm/models/gpt/model.py +++ b/fast_llm/models/gpt/model.py @@ -27,7 +27,10 @@ RotaryEmbeddingPreprocessor, ) from fast_llm.layers.transformer.transformer import TransformerLayer -from fast_llm.layers.vision_encoder.config import VisionEncoderDimNames, VisionModelKwargs +from fast_llm.layers.transformer.vision_transformer import VisionTransformerLayer +from fast_llm.layers.vision_encoder.adapter import VisionAdapter +from fast_llm.layers.vision_encoder.config import VisionEncoderDimNames, VisionEncoderKwargs, VisionTransformerDimNames +from fast_llm.layers.vision_encoder.encoder import PatchConv from fast_llm.layers.vision_encoder.preprocessing import VisionPreprocessor from fast_llm.models.gpt.config import GPTBaseModelConfig, GPTModelConfig from fast_llm.models.gpt.megatron import get_init_megatron @@ -76,6 +79,10 @@ def __init__( if self._config.vision_encoder: self._vision_preprocessor = VisionPreprocessor(self._config.vision_encoder, self._tensor_space) + if self._config.vision_encoder.transformer.rotary.enabled: + self._vision_rotary_embedding_preprocessor = RotaryEmbeddingPreprocessor( + self._config.vision_encoder.transformer.rotary, self._tensor_space + ) def get_output_layers(self) -> list[Layer]: return [ @@ -99,22 +106,35 @@ def get_output_layers(self) -> list[Layer]: ] ] + def 
get_vision_layers(self) -> list[Layer]: + patch_conv = PatchConv(self._config.vision_encoder, self._tensor_space) + vit_layers = [ + VisionTransformerLayer(self._config.vision_encoder.transformer, self._tensor_space, layer_index=idx + 1) + for idx in range(self._config.vision_encoder.transformer.num_layers) + ] + return [ + patch_conv, + *vit_layers, + VisionAdapter(self._config.vision_encoder, self._tensor_space), + MultiModalEmbedding(self._config, self._tensor_space), + ] + def get_layers(self) -> list[Layer]: if self._config.transformer.num_layers == 0: Assert.eq(self._config.prediction_heads, 1) return [ - ( - LanguageModelEmbedding(self._config, self._tensor_space) + *( + [LanguageModelEmbedding(self._config, self._tensor_space)] if self._config.vision_encoder is None - else MultiModalEmbedding(self._config, self._tensor_space) + else self.get_vision_layers(self._config, self._tensor_space) ), LanguageModelHead(self._config, self._tensor_space, 0), ] return [ - ( - LanguageModelEmbedding(self._config, self._tensor_space) + *( + [LanguageModelEmbedding(self._config, self._tensor_space)] if self._config.vision_encoder is None - else MultiModalEmbedding(self._config, self._tensor_space) + else self.get_vision_layers() ), *[ TransformerLayer( @@ -152,24 +172,24 @@ def preprocess_meta( if self._config.vision_encoder: image_size = batch_meta.max_image_size image_mean = [ - self._config.vision_encoder.normalization.mean_r, - self._config.vision_encoder.normalization.mean_g, - self._config.vision_encoder.normalization.mean_b, + self._config.vision_encoder.image_normalization.mean_r, + self._config.vision_encoder.image_normalization.mean_g, + self._config.vision_encoder.image_normalization.mean_b, ] image_std = [ - self._config.vision_encoder.normalization.std_r, - self._config.vision_encoder.normalization.std_g, - self._config.vision_encoder.normalization.std_b, + self._config.vision_encoder.image_normalization.std_r, + self._config.vision_encoder.image_normalization.std_g, + self._config.vision_encoder.image_normalization.std_b, ] - image_rescale_factor = self._config.vision_encoder.normalization.rescale_factor + image_rescale_factor = self._config.vision_encoder.image_normalization.rescale_factor vision_kwargs = { - VisionModelKwargs.patch_size: self._config.vision_encoder.encoder.patch_size, - VisionModelKwargs.image_size: image_size, - VisionModelKwargs.image_mean: image_mean, - VisionModelKwargs.image_std: image_std, - VisionModelKwargs.image_rescale_factor: image_rescale_factor, - VisionModelKwargs.rope_theta: self._config.vision_encoder.encoder.rope_theta, - VisionModelKwargs.kv_channels: self._tensor_space.get_tensor_dim( + VisionEncoderKwargs.patch_size: batch_meta.patch_size, + VisionEncoderKwargs.image_size: image_size, + VisionEncoderKwargs.image_mean: image_mean, + VisionEncoderKwargs.image_std: image_std, + VisionEncoderKwargs.image_rescale_factor: image_rescale_factor, + VisionEncoderKwargs.rope_theta: self._config.vision_encoder.transformer.rotary.theta, + VisionEncoderKwargs.kv_channels: self._tensor_space.get_tensor_dim( VisionEncoderDimNames.kv_channels ).size, } @@ -218,6 +238,18 @@ def preprocess_meta( if sequence_first else (batch_dim, hidden_sequence_q_dim, hidden_dim) ) + if self._config.vision_encoder: + vision_hidden_dim = self._tensor_space.get_tensor_dim(VisionTransformerDimNames.hidden) + vision_hidden_dims = ( + (hidden_sequence_q_dim, batch_dim, vision_hidden_dim) + if sequence_first + else (batch_dim, hidden_sequence_q_dim, vision_hidden_dim) + ) + 
vision_kwargs.update( + { + VisionEncoderKwargs.hidden_dims: vision_hidden_dims, + } + ) common_kwargs = { LanguageModelKwargs.phase: phase, @@ -225,6 +257,7 @@ def preprocess_meta( TransformerKwargs.hidden_dims: hidden_dims, TransformerKwargs.sequence_length: sequence_length, TransformerKwargs.sequence_q_dim: sequence_q_dim, + TransformerKwargs.micro_batch_size: micro_batch_size, } common_kwargs.update(vision_kwargs) @@ -253,6 +286,9 @@ def preprocess_meta( self._position_embedding_preprocessor.preprocess_meta(kwargs) if self._config.transformer.rotary.enabled: self._rotary_embedding_preprocessor.preprocess_meta(kwargs) + if self._config.vision_encoder: + if self._config.vision_encoder.transformer.rotary.enabled: + self._vision_rotary_embedding_preprocessor.preprocess_meta(kwargs) if not self._use_flash_attention: self._backup_attention_preprocessor.preprocess_meta(kwargs) preprocessed_meta.append((tokens, kwargs)) @@ -294,6 +330,11 @@ def preprocess( self._rotary_embedding_preprocessor.create_tensors(sequence_length) if not self._use_flash_attention: self._backup_attention_preprocessor.create_tensors(sequence_length) + if self._config.vision_encoder and self._config.vision_encoder.transformer.rotary.enabled: + max_num_patches = ( + common_kwargs[VisionEncoderKwargs.image_size] // common_kwargs[VisionEncoderKwargs.patch_size] + ) + self._vision_rotary_embedding_preprocessor.create_tensors(sequence_length, max_num_patches) preprocessed = [] presents = None @@ -342,32 +383,38 @@ def preprocess( else: labels[i, start : end + 1] = -100 kwargs[LanguageModelKwargs.labels] = labels - if batch.images is not None: - kwargs[VisionModelKwargs.images] = [ - img.to(device=self._tensor_space.distributed.device, dtype=torch.uint8, non_blocking=True) - for images in batch.images - for img in images - ] - kwargs[VisionModelKwargs.image_positions] = batch.image_positions - if self._config.vision_encoder: - self._vision_preprocessor.preprocess(kwargs) if self._config.use_absolute_position_embeddings: self._position_embedding_preprocessor.preprocess(kwargs) if self._config.transformer.rotary.enabled: self._rotary_embedding_preprocessor.preprocess(kwargs) if not self._use_flash_attention: self._backup_attention_preprocessor.preprocess(kwargs) - preprocessed.append((tokens, kwargs)) + if batch.images is not None: + kwargs[VisionEncoderKwargs.images] = [ + [ + img.to(device=self._tensor_space.distributed.device, dtype=torch.uint8, non_blocking=True) + for img in images + ] + for images in batch.images + ] + kwargs[VisionEncoderKwargs.image_positions] = batch.image_positions + if self._config.vision_encoder: + self._vision_preprocessor.preprocess(kwargs) + self._vision_rotary_embedding_preprocessor.preprocess(kwargs) + kwargs[LanguageModelKwargs.tokens] = tokens + preprocessed.append((kwargs[VisionEncoderKwargs.image_patches], kwargs)) + else: + preprocessed.append((tokens, kwargs)) return preprocessed @property def embedding(self) -> LanguageModelEmbedding: - return self.layers[self._config.vision_encoder is not None] + return self.layers[self._config.vision_encoder.transformer.num_layers + 2] @property def transformer_layers(self) -> list[TransformerLayer]: - return self.layers[(self._config.vision_encoder is not None) + 1 : -1] + return self.layers[self._config.vision_encoder.transformer.num_layers + 3 : -1] @property def model_head(self) -> LanguageModelHead: diff --git a/fast_llm/tools/cli.py b/fast_llm/tools/cli.py index e9df18ed..b1f14ccc 100644 --- a/fast_llm/tools/cli.py +++ b/fast_llm/tools/cli.py @@ 
-32,7 +32,6 @@ def fast_llm(args=None): sys.exit(1) except Exception: # noqa logger.critical(traceback.format_exc()) - sys.exit(1) if __name__ == "__main__": From 99ad5d9bda84eea74e377c8cc75f7184bb0dcc76 Mon Sep 17 00:00:00 2001 From: root Date: Wed, 7 May 2025 18:34:24 +0000 Subject: [PATCH 14/82] patches and fixes --- fast_llm/layers/language_model/config.py | 10 ++++++---- fast_llm/layers/vision_encoder/config.py | 2 ++ fast_llm/layers/vision_encoder/encoder.py | 2 +- .../layers/vision_encoder/preprocessing.py | 20 ++++++++++++++++++- fast_llm/models/gpt/model.py | 10 +++++++--- 5 files changed, 35 insertions(+), 9 deletions(-) diff --git a/fast_llm/layers/language_model/config.py b/fast_llm/layers/language_model/config.py index 887952d7..ef0e7a5c 100644 --- a/fast_llm/layers/language_model/config.py +++ b/fast_llm/layers/language_model/config.py @@ -45,8 +45,9 @@ class LanguageModelArchitectureConfig(BaseModelArchitectureConfig): desc="Configuration for the transformer architecture.", hint=FieldHint.core, ) - vision_encoder: None | VisionEncoderArchitectureConfig = Field( - default=None, + # TODO Soham: make this None by default. Need to figure out how to handle this in the config (see ) + vision_encoder: VisionEncoderArchitectureConfig = Field( + default_factory=VisionEncoderArchitectureConfig, desc="Configuration for the vision encoder that transforms images into embeddings.", hint=FieldHint.optional, ) @@ -131,8 +132,9 @@ class LanguageModelBaseConfig(LanguageModelArchitectureConfig, BaseModelConfig): architecture_class = LanguageModelArchitectureConfig transformer: TransformerConfig = FieldUpdate(default_factory=TransformerConfig) - vision_encoder: None | VisionEncoderConfig = FieldUpdate( - default=None, + # TODO Soham: make this None by default. Need to figure out how to handle this in the config + vision_encoder: VisionEncoderConfig = FieldUpdate( + default_factory=VisionEncoderConfig, desc="Configuration for the vision encoder that transforms images into embeddings.", hint=FieldHint.optional, ) diff --git a/fast_llm/layers/vision_encoder/config.py b/fast_llm/layers/vision_encoder/config.py index 7c650bf9..28351372 100644 --- a/fast_llm/layers/vision_encoder/config.py +++ b/fast_llm/layers/vision_encoder/config.py @@ -7,6 +7,7 @@ class VisionEncoderDimNames: + in_channels = "vision_in_channels" out_channels = "vision_out_channels" adapter_size = "vision_adapter_size" patch_size = "vision_patch_size" @@ -62,6 +63,7 @@ class VisionEncoderKwargs: max_image_tokens = "max_image_tokens" patch_embeddings = "patch_embeddings" hidden_dims = "vit_hidden_dims" + image_patches_meta = "vit_image_patches_meta" # TODO Soham: do we need all of them? 
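For orientation before the encoder and preprocessing changes below: the per-image token budget behind these kwargs is plain patch arithmetic (one transformer token per patch-size tile), and `image_patches_meta` describes the flattened batch of patches handed to the patch convolution. A minimal sketch of that bookkeeping, using illustrative helper names that are not part of the patch:

def image_token_count(height: int, width: int, patch_size: int) -> int:
    # One token per non-overlapping patch_size x patch_size tile.
    return (height // patch_size) * (width // patch_size)

def image_patches_shape(micro_batch_size: int, sequence_q: int, patch_size: int, in_channels: int = 3):
    # Shape described by image_patches_meta: one patch slot per token position in the micro-batch.
    return (micro_batch_size * sequence_q, in_channels, patch_size, patch_size)

# Example: a 1024x1024 image with 16x16 patches contributes 64 * 64 = 4096 image tokens.
assert image_token_count(1024, 1024, 16) == 4096
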
diff --git a/fast_llm/layers/vision_encoder/encoder.py b/fast_llm/layers/vision_encoder/encoder.py index 9369037d..ed6fbc92 100644 --- a/fast_llm/layers/vision_encoder/encoder.py +++ b/fast_llm/layers/vision_encoder/encoder.py @@ -63,7 +63,7 @@ def forward( ) -> torch.Tensor: hidden_dims = kwargs[VisionEncoderKwargs.hidden_dims] if isinstance(input_, TensorMeta): - return TensorMeta.from_dims(hidden_dims) + return TensorMeta.from_dims(hidden_dims, tensor_name="patch conv output", dtype=input_.dtype) # we don't need images after this point # image_patches = kwargs.pop(VisionEncoderKwargs.image_patches) patch_embeddings = self.norm(self.conv(input_)) diff --git a/fast_llm/layers/vision_encoder/preprocessing.py b/fast_llm/layers/vision_encoder/preprocessing.py index abae6f11..c087cf6d 100644 --- a/fast_llm/layers/vision_encoder/preprocessing.py +++ b/fast_llm/layers/vision_encoder/preprocessing.py @@ -4,13 +4,16 @@ import torch import torchvision.transforms.v2.functional as F -from fast_llm.engine.config_utils.tensor_space import TensorSpace +from fast_llm.engine.config_utils.tensor_space import TensorDim, TensorSpace from fast_llm.layers.transformer.config import TransformerKwargs from fast_llm.layers.vision_encoder.config import ( VisionEncoderArchitectureConfig, + VisionEncoderDimNames, VisionEncoderKwargs, + VisionTransformerDimNames, VisionTransformerKwargs, ) +from fast_llm.tensor import TensorMeta from fast_llm.utils import div @@ -104,6 +107,21 @@ def __init__(self, config: VisionEncoderArchitectureConfig, tensor_space: Tensor self._tensor_space = tensor_space self._distributed_config = self._tensor_space.distributed_config + def preprocess_meta(self, kwargs: dict[str, typing.Any]) -> None: + # kwargs[VisionEncoderDimNames] + kwargs[VisionEncoderKwargs.image_patches_meta] = TensorMeta.from_dims( + ( + TensorDim( + VisionTransformerDimNames.batch, + kwargs[TransformerKwargs.micro_batch_size] * kwargs[TransformerKwargs.sequence_q_dim].size, + ), + TensorDim(VisionEncoderDimNames.in_channels, 3), + TensorDim(VisionEncoderDimNames.patch_size, kwargs[VisionEncoderKwargs.patch_size]), + TensorDim(VisionEncoderDimNames.patch_size, kwargs[VisionEncoderKwargs.patch_size]), + ), + dtype=self._distributed_config.training_dtype.torch, + ) + def preprocess(self, kwargs: dict[str, typing.Any]) -> None: images = kwargs.get(VisionEncoderKwargs.images) im_height = kwargs.get(VisionEncoderKwargs.image_size) diff --git a/fast_llm/models/gpt/model.py b/fast_llm/models/gpt/model.py index 6aef273f..5425a1e1 100644 --- a/fast_llm/models/gpt/model.py +++ b/fast_llm/models/gpt/model.py @@ -286,12 +286,16 @@ def preprocess_meta( self._position_embedding_preprocessor.preprocess_meta(kwargs) if self._config.transformer.rotary.enabled: self._rotary_embedding_preprocessor.preprocess_meta(kwargs) + if not self._use_flash_attention: + self._backup_attention_preprocessor.preprocess_meta(kwargs) if self._config.vision_encoder: + self._vision_preprocessor.preprocess_meta(kwargs) if self._config.vision_encoder.transformer.rotary.enabled: self._vision_rotary_embedding_preprocessor.preprocess_meta(kwargs) - if not self._use_flash_attention: - self._backup_attention_preprocessor.preprocess_meta(kwargs) - preprocessed_meta.append((tokens, kwargs)) + # patch_dimensions are (batch * sequence_length) x 3 x patch_size x patch_size + preprocessed_meta.append((kwargs[VisionEncoderKwargs.image_patches_meta], kwargs)) + else: + preprocessed_meta.append((tokens, kwargs)) return preprocessed_meta From 
bcb557aca291afcbb2e19969d2e7e1da16a93612 Mon Sep 17 00:00:00 2001 From: root Date: Wed, 7 May 2025 20:44:40 +0000 Subject: [PATCH 15/82] fix dependency --- Dockerfile | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/Dockerfile b/Dockerfile index b8e1f888..149a498e 100644 --- a/Dockerfile +++ b/Dockerfile @@ -3,8 +3,7 @@ FROM nvcr.io/nvidia/pytorch:24.11-py3 # Install dependencies. RUN apt-get update \ - && apt-get install --no-install-recommends -y acl git-lfs \ - && apt-get install --no-install-recommends -y libtiff5-dev \ + && apt-get install --no-install-recommends -y acl git-lfs libtiff5-dev \ && rm -rf /var/lib/apt/lists/* \ && git lfs install From a6f5364d33c8d80ff46ea592612362fd03f85f30 Mon Sep 17 00:00:00 2001 From: root Date: Wed, 7 May 2025 20:49:53 +0000 Subject: [PATCH 16/82] remove for testing --- Dockerfile | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/Dockerfile b/Dockerfile index 149a498e..b7e42d4d 100644 --- a/Dockerfile +++ b/Dockerfile @@ -3,7 +3,8 @@ FROM nvcr.io/nvidia/pytorch:24.11-py3 # Install dependencies. RUN apt-get update \ - && apt-get install --no-install-recommends -y acl git-lfs libtiff5-dev \ + && apt-get install --no-install-recommends -y acl git-lfs \ + # && apt-get install --no-install-recommends -y acl git-lfs libtiff5-dev \ && rm -rf /var/lib/apt/lists/* \ && git lfs install From 73b431b22d0c4b54d41d25a4dcf0738c5a1b1711 Mon Sep 17 00:00:00 2001 From: root Date: Wed, 7 May 2025 21:57:17 +0000 Subject: [PATCH 17/82] mising --- .../layers/transformer/vision_transformer.py | 55 +++++++++++++++++++ 1 file changed, 55 insertions(+) create mode 100644 fast_llm/layers/transformer/vision_transformer.py diff --git a/fast_llm/layers/transformer/vision_transformer.py b/fast_llm/layers/transformer/vision_transformer.py new file mode 100644 index 00000000..94a9c70a --- /dev/null +++ b/fast_llm/layers/transformer/vision_transformer.py @@ -0,0 +1,55 @@ +import torch + +from fast_llm.engine.config_utils.tensor_space import TensorDim, TensorSpace +from fast_llm.layers.transformer.config import TransformerConfig +from fast_llm.layers.transformer.transformer import TransformerLayer +from fast_llm.layers.vision_encoder.config import VisionTransformerDimNames, VisionTransformerKwargs +from fast_llm.tensor import TensorMeta + + +class VisionTransformerLayer(TransformerLayer): + """ + A vision transformer layer to encode image patches + """ + + def __init__( + self, + config: TransformerConfig, + tensor_space: TensorSpace, + layer_index: int, + return_input: bool = False, + ): + super().__init__(config, tensor_space, layer_index, return_input) + + hidden_dim = self._tensor_space.get_tensor_dim(VisionTransformerDimNames.hidden) + self.norm_1 = self._config.normalization.get_layer(hidden_dim) + self.norm_2 = self._config.normalization.get_layer(hidden_dim) + + self.norm_1 = self._config.peft.apply_other(self.norm_1) + self.norm_2 = self._config.peft.apply_other(self.norm_2) + + @property + def name(self) -> str: + return f"Vision transformer layer {self._layer_index}" + + def _get_meta(self, tensor: torch.Tensor, name: str, kwargs: dict): + dims = kwargs[VisionTransformerKwargs.hidden_dims] + if self._return_input: + dims = (TensorDim("stacked_input_output", 2),) + dims + return TensorMeta.from_dims(dims, tensor_name=f"{self.name} {name}", dtype=tensor.dtype) + + # TODO Soham: remove this since we only need to call the parent method + # def forward( + # self, + # input_: torch.Tensor, + # kwargs: dict[str, typing.Any], + # losses: 
dict[str, typing.Any] | None = None, + # metrics: dict[str, typing.Any] | None = None, + # ) -> torch.Tensor: + # if isinstance(input_, TensorMeta): + # return self._get_meta(input_, "output", kwargs) + # # Hack for now to compute the patch embeddings + # kwargs[VisionTransformerKwargs.patch_embeddings] = super().forward( + # kwargs.pop(VisionTransformerKwargs.patch_embeddings), kwargs, losses, metrics + # ) + # return input_ From ec2c9fb5d01304a8b024c663129df7dd63ec431e Mon Sep 17 00:00:00 2001 From: Toby Liang Date: Wed, 7 May 2025 22:37:45 +0000 Subject: [PATCH 18/82] initial data prep --- fast_llm/data/dataset/gpt/memmap.py | 12 ++++++ fast_llm/data/dataset/gpt/sampled.py | 2 + fast_llm/data/preparator/gpt_memmap/config.py | 6 +++ .../data/preparator/gpt_memmap/prepare.py | 34 ++++++++++------ fast_llm/data/tokenizer.py | 40 ++++++++++++------- 5 files changed, 68 insertions(+), 26 deletions(-) diff --git a/fast_llm/data/dataset/gpt/memmap.py b/fast_llm/data/dataset/gpt/memmap.py index 8651b8fc..752c83ce 100644 --- a/fast_llm/data/dataset/gpt/memmap.py +++ b/fast_llm/data/dataset/gpt/memmap.py @@ -258,6 +258,10 @@ def write_dataset(cls, prefix: pathlib.Path | str, documents: typing.Iterable[GP image_lengths = [] im_positions = [] total_images = 0 + n_audio = [] + audio_lengths = [] + aud_positions = [] + total_audio = 0 pointers = [] offset = 0 # number of spans for each document @@ -295,6 +299,14 @@ def write_dataset(cls, prefix: pathlib.Path | str, documents: typing.Iterable[GP bin_stream.write(pixels.tobytes(order="C")) total_im_size += pixels.size im_positions.append(document.image_positions) + if document.audio: + n_audio.append(len(document.audio)) + total_audio += len(document.audio) + for audio in document.audio: + audio_lengths.append(len(audio)) + bin_stream.write(audio.to_bytes(order="C")) + # total_aud_size += + aud_positions.append(document.audio_positions) # Update metadata doc_length = len(document.token_ids) diff --git a/fast_llm/data/dataset/gpt/sampled.py b/fast_llm/data/dataset/gpt/sampled.py index 0ba3f0e1..d700dcc0 100644 --- a/fast_llm/data/dataset/gpt/sampled.py +++ b/fast_llm/data/dataset/gpt/sampled.py @@ -33,6 +33,8 @@ class GPTSample: loss_masking_spans: np.ndarray | None = None images: np.ndarray | None = None image_positions: np.ndarray | None = None + audio: np.ndarray | None = None + audio_positions: np.ndarray | None = None sequence_lengths: np.ndarray | None = None diff --git a/fast_llm/data/preparator/gpt_memmap/config.py b/fast_llm/data/preparator/gpt_memmap/config.py index 89fe904c..a56d766a 100644 --- a/fast_llm/data/preparator/gpt_memmap/config.py +++ b/fast_llm/data/preparator/gpt_memmap/config.py @@ -68,6 +68,12 @@ class GPTHuggingfaceDatasetConfig(Config): images: None | str = Field( default=None, desc="Field containing images relevant to a document", hint=FieldHint.optional ) + audio_positions: None | str = Field( + default=None, desc="Field containing audio positions within a document", hint=FieldHint.optional + ) + audio: None | str = Field( + default=None, desc="Field containing audio relevant to a document", hint=FieldHint.optional + ) data_type: DataType | None = Field( default=None, desc="Data type of the dataset field." 
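The two new fields mirror the existing image columns: `audio` names the dataset column holding raw audio arrays, and `audio_positions` the character offsets where each clip occurs in the text. A hypothetical `dataset` fragment for a preparation config, with placeholder values only (nothing here comes from the patch except the field keys):

dataset = {
    "path": "my_org/my_multimodal_dataset",  # placeholder HF dataset id
    "split": "train",
    "field": "text",                         # text column
    "images": "images",                      # per-document list of images
    "image_positions": "image_positions",    # character offsets of each image in the text
    "audio": "audio",                        # per-document list of raw audio arrays
    "audio_positions": "audio_positions",    # character offsets of each audio clip in the text
}
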
diff --git a/fast_llm/data/preparator/gpt_memmap/prepare.py b/fast_llm/data/preparator/gpt_memmap/prepare.py index 4965dfdf..19858cbc 100644 --- a/fast_llm/data/preparator/gpt_memmap/prepare.py +++ b/fast_llm/data/preparator/gpt_memmap/prepare.py @@ -50,22 +50,25 @@ def _tokenize_batch(self, batch: dict[str, list[typing.Any]]) -> dict[str, list[ # np.array(self._tokenizer.tokenize(text), dtype=self._data_type.numpy) # for text in batch[self._config.dataset.field] # ] - input_ids, image_token_positions = map( + input_ids, image_token_positions, audio_token_positions = map( list, zip( *[ ( np.array(input_ids, dtype=self._data_type.numpy), np.array(image_token_positions, dtype=np.int32), + np.array(audio_token_positions, dtype=np.int32), ) - for input_ids, image_token_positions in [ + for input_ids, image_token_positions, audio_token_positions in [ self._tokenizer.tokenize( text, im_char_positions, + aud_char_positions, ) - for text, im_char_positions in zip( + for text, im_char_positions, aud_char_positions in zip( batch[self._config.dataset.field], batch.get(self._config.dataset.image_positions, itertools.repeat(None)), + batch.get(self._config.dataset.audio_positions, itertools.repeat(None)), ) ] ] @@ -82,6 +85,7 @@ def _tokenize_batch(self, batch: dict[str, list[typing.Any]]) -> dict[str, list[ return { "input_ids": input_ids, "image_positions": image_token_positions, + "audio_token_positions": audio_token_positions, "num_tokens": num_tokens, "num_pixels": num_pixels, } @@ -143,6 +147,8 @@ def _document_generator(): # [np.array(im) for im in item["images"]] if self._config.dataset.images else None, item["images"] if self._config.dataset.images else None, item["image_positions"] if self._config.dataset.image_positions else None, + np.array(item[self._config.dataset.audio]) if self._config.dataset.audio else None, + item[self._config.dataset.audio_positions] if self._config.dataset.audio_positions else None, ) # if "token_spans" in shard_dataset.column_names and self._config.dataset.loss_masking_spans is not None: # for item in tqdm.tqdm(shard_dataset, desc=f"Saving shard {shard_idx}", unit="docs"): @@ -167,15 +173,19 @@ def _document_generator(): ) def _load_dataset(self) -> datasets.Dataset: - dataset = datasets.load_dataset( - path=self._config.dataset.path, - name=self._config.dataset.config_name, - data_dir=self._config.dataset.data_directory, - data_files=self._config.dataset.data_files, - split=self._config.dataset.split, - num_proc=self._config.loading_workers, - trust_remote_code=self._config.dataset.trust_remote_code, - ) + try: + dataset = datasets.load_dataset( + path=self._config.dataset.path, + name=self._config.dataset.config_name, + data_dir=self._config.dataset.data_directory, + data_files=self._config.dataset.data_files, + split=self._config.dataset.split, + num_proc=self._config.loading_workers, + trust_remote_code=self._config.dataset.trust_remote_code, + ) + except: + # backup if dataset is saved in arrow format (can we auto-detect this?) 
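# A possible answer to the auto-detect question above (a sketch, not part of the patch):
# `datasets.Dataset.save_to_disk` writes a `state.json` (and a `DatasetDict` adds a
# `dataset_dict.json`) into the target directory, so the Arrow fallback could be gated on
#     pathlib.Path(self._config.dataset.data_directory, "state.json").is_file()
# rather than catching every exception raised by `load_dataset`.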
+ dataset = datasets.load_from_disk(dataset_path=self._config.dataset.data_directory) assert isinstance(dataset, datasets.Dataset) return dataset diff --git a/fast_llm/data/tokenizer.py b/fast_llm/data/tokenizer.py index 0e7d5470..98cfbb85 100644 --- a/fast_llm/data/tokenizer.py +++ b/fast_llm/data/tokenizer.py @@ -42,33 +42,45 @@ def _tokenize(self, text: str, begin=True, end=True) -> list[int]: + ([self.eod_id] if end else []) ) - def tokenize(self, text, image_positions=None): - if not image_positions: + def tokenize(self, text, image_positions=None, audio_positions=None): + image_positions = image_positions or [] + audio_positions = audio_positions or [] + if len(set(image_positions).intersection(audio_positions)) > 0: + raise ValueError("Image and audio can not have the same position.") + multimodal_positions = sorted(image_positions + audio_positions) + if not multimodal_positions: return self._tokenize(text), [], [] - image_idx = 0 + multimodel_idx = 0 char_pos = 0 token_ids = [] image_token_positions = [] + audio_token_positions = [] beginning_of_text = True - while image_idx < len(image_positions): - if image_positions[image_idx] > len(text): + while multimodel_idx < len(multimodal_positions): + multimodal_char_pos = multimodal_positions[multimodel_idx] + multimodal_type = "image" if multimodal_char_pos in image_positions else "audio" + + if multimodal_char_pos > len(text): raise ValueError( - f"Image position {image_positions[image_idx]} is greater than text length {len(text)}" + f"{multimodal_type.capitalize()} position {multimodal_char_pos} is greater than text length {len(text)}" ) - curr_text = text[char_pos : image_positions[image_idx]] - tokenized_text = self._tokenize( - curr_text, begin=beginning_of_text, end=image_positions[image_idx] >= len(text) - ) + curr_text = text[char_pos:multimodal_char_pos] + tokenized_text = self._tokenize(curr_text, begin=beginning_of_text, end=multimodal_char_pos >= len(text)) beginning_of_text = False token_ids.extend(tokenized_text) - image_token_positions = len(token_ids) - char_pos = image_positions[image_idx] - image_idx += 1 + + # store multimodal token positions + if multimodal_type == "image": + image_token_positions.append(len(token_ids)) + else: + audio_token_positions.append(len(token_ids)) + char_pos = multimodal_char_pos + multimodel_idx += 1 if char_pos < len(text): curr_text = text[char_pos:] tokenized_text = self._tokenize(curr_text, begin=beginning_of_text, end=True) token_ids.extend(tokenized_text) - return token_ids, image_token_positions + return token_ids, image_token_positions, audio_token_positions def tokenize_with_spans( self, text: str, char_spans: list[tuple[int, int]] From 6d6567673450e3e97ae07879957a55875ec80caf Mon Sep 17 00:00:00 2001 From: root Date: Thu, 8 May 2025 06:11:54 +0000 Subject: [PATCH 19/82] fix --- fast_llm/data/dataset/gpt/sampled.py | 2 +- fast_llm/layers/multi_modal/embedding.py | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/fast_llm/data/dataset/gpt/sampled.py b/fast_llm/data/dataset/gpt/sampled.py index 0ba3f0e1..2f80ee77 100644 --- a/fast_llm/data/dataset/gpt/sampled.py +++ b/fast_llm/data/dataset/gpt/sampled.py @@ -485,7 +485,7 @@ def _lazy_load(self): def _load_yaml_data(self, data: dict[str, typing.Any]) -> None: self._documents_per_epoch = data["dataset"]["documents_per_epoch"] - if unshuffled_tokens := data.get("unshuffled_tokens") is not None: + if (unshuffled_tokens := data.get("unshuffled_tokens")) is not None: self._unshuffled_tokens = unshuffled_tokens else: 
self._unshuffled_tokens = data["unshuffled_epochs"] * data["dataset"]["tokens_per_epoch"] diff --git a/fast_llm/layers/multi_modal/embedding.py b/fast_llm/layers/multi_modal/embedding.py index a3abe781..b7d79dd3 100644 --- a/fast_llm/layers/multi_modal/embedding.py +++ b/fast_llm/layers/multi_modal/embedding.py @@ -43,7 +43,8 @@ def forward( image_positions = kwargs.get(VisionEncoderKwargs.image_positions) tokens = kwargs.get(LanguageModelKwargs.tokens) # get text embeddings - embeddings = super()._forward(tokens, position_ids) + # TODO Soham: cloning to avoid pytorch complaint about in-place operation. Can we do better? + embeddings = super()._forward(tokens, position_ids).clone() image_idx = 0 for sample_idx, (positions, sizes) in enumerate(zip(image_positions, image_sizes)): image_embedding_offset = 0 From 82e4edba3558df9531ae6b5226b44bc50b7c4952 Mon Sep 17 00:00:00 2001 From: Toby Liang Date: Fri, 9 May 2025 01:31:42 +0000 Subject: [PATCH 20/82] audio dataset changes --- fast_llm/data/data/gpt/data.py | 9 +++ fast_llm/data/dataset/gpt/config.py | 3 + fast_llm/data/dataset/gpt/indexed.py | 8 +- fast_llm/data/dataset/gpt/memmap.py | 105 ++++++++++++++++++++++++--- fast_llm/data/dataset/gpt/sampled.py | 14 +++- fast_llm/engine/schedule/config.py | 15 ++++ fast_llm/models/gpt/trainer.py | 3 + 7 files changed, 144 insertions(+), 13 deletions(-) diff --git a/fast_llm/data/data/gpt/data.py b/fast_llm/data/data/gpt/data.py index 34b86f21..681d4443 100644 --- a/fast_llm/data/data/gpt/data.py +++ b/fast_llm/data/data/gpt/data.py @@ -92,6 +92,9 @@ def __init__( cross_document_attention: bool = True, patch_size: list[int] | None = None, max_image_size: int | None = None, + aud_downsampling_k: int | None = None, + aud_padding_duration: int | None = None, + aud_sampling_rate: int | None = None, ): """ Create the data and gather some basic information on the dataset(s). @@ -103,6 +106,9 @@ def __init__( self._cross_document_attention = cross_document_attention self._patch_size = patch_size self._max_image_size = max_image_size + self._aud_downsampling_k = aud_downsampling_k + self._aud_padding_duration = aud_padding_duration + self._aud_sampling_rate = aud_sampling_rate def setup( self, @@ -152,6 +158,9 @@ def setup( cross_document_attention=self._cross_document_attention, patch_size=self._patch_size, image_size=self._max_image_size, + aud_downsampling_k=self._aud_downsampling_k, + aud_padding_duration=self._aud_padding_duration, + aud_sampling_rate=self._aud_sampling_rate, ) dataset = self._config.datasets[dataset_name].build_and_sample(sampling) self._datasets[dataset_name] = DatasetMonitor(dataset, self._config.data_sample_warn_time_ms) diff --git a/fast_llm/data/dataset/gpt/config.py b/fast_llm/data/dataset/gpt/config.py index 65adf0bd..aeb57ffe 100644 --- a/fast_llm/data/dataset/gpt/config.py +++ b/fast_llm/data/dataset/gpt/config.py @@ -75,6 +75,9 @@ class GPTSamplingData(SamplingData): cross_document_attention: bool = True patch_size: int | None = None image_size: int | None = None + aud_downsampling_k: int | None = None + aud_padding_duration: int | None = None + aud_sampling_rate: int | None = None @config_class() diff --git a/fast_llm/data/dataset/gpt/indexed.py b/fast_llm/data/dataset/gpt/indexed.py index 209c6e31..1bbd30c7 100644 --- a/fast_llm/data/dataset/gpt/indexed.py +++ b/fast_llm/data/dataset/gpt/indexed.py @@ -45,8 +45,12 @@ class GPTDatasetSlice[IndexedDatasetType: GPTIndexedDataset](DatasetSlice[Indexe def get_document_sizes(self) -> np.ndarray: # TODO: This can be really big. 
- doc_sizes, im_sizes = self._dataset.get_document_sizes() - return doc_sizes[self._begin : self._end], im_sizes[self._begin : self._end] + doc_sizes, im_sizes, aud_sizes = self._dataset.get_document_sizes() + return ( + doc_sizes[self._begin : self._end], + im_sizes[self._begin : self._end], + aud_sizes[self._begin : self._end], + ) def get_document_size(self, index: int, patch_size: list[int]) -> int: return self._dataset.get_document_size(self._begin + index, patch_size) diff --git a/fast_llm/data/dataset/gpt/memmap.py b/fast_llm/data/dataset/gpt/memmap.py index 752c83ce..88e31d78 100644 --- a/fast_llm/data/dataset/gpt/memmap.py +++ b/fast_llm/data/dataset/gpt/memmap.py @@ -59,6 +59,9 @@ def _init( if self._version >= 4: self._has_images = struct.unpack("= 5: + self._has_audio = struct.unpack("= 5: + self._n_audio = np.frombuffer( + self._index_bin_buffer, dtype=np.int32, count=self._num_documents, offset=offset + ) + self._audio_lengths = [] + self._audio_positions = [] + audio_seen = 0 + + offset = offset + self._n_audio.nbytes + for n_audio in self._n_audio: + self._audio_lengths.append( + np.frombuffer( + self._index_bin_buffer, + dtype=np.int32, + count=n_audio, + offset=offset + audio_seen * np.dtype(np.int32).itemsize, + ).reshape(-1, 2) + ) + # self._num_pixels += self._image_lengths[-1].prod(axis=1, initial=3).sum() + self._audio_positions.append( + np.frombuffer( + self._index_bin_buffer, + dtype=np.int32, + count=n_audio, + offset=offset + + self._n_audio.sum() * np.dtype(np.int32).itemsize + + audio_seen * np.dtype(np.int32).itemsize, + ) + ) + audio_seen += n_audio self._bin_buffer_mmap = np.memmap(self._prefix.with_suffix(".bin"), mode="r", order="C") self._bin_buffer = memoryview(self._bin_buffer_mmap) @@ -193,8 +227,30 @@ def get( n_pixels = image_length.prod(initial=3) images.append(pixels[start : start + n_pixels].reshape(3, image_length[0], image_length[1])) start += n_pixels + + if self._has_audio: + audio_positions = self._audio_positions[idx] + all_audio = np.frombuffer( + self._bin_buffer, + dtype=np.dtype(np.float32), + count=self._audio_lengths[idx].sum(), + offset=self._pointers[idx] + + self._document_sizes[idx] * np.dtype(self._dtype).itemsize + + self._image_lengths.prod(initial=3) * np.dtype(np.uint8).itemsize, + ) + audio = [] + start = 0 + for audio_length in self._audio_lengths[idx]: + audio.append(all_audio[start : start + audio_length]) + start += audio_length # TODO Soham: return loss_masking_spans - return GPTSample(token_ids=token_ids, images=images, image_positions=image_positions) + return GPTSample( + token_ids=token_ids, + images=images, + image_positions=image_positions, + audio=audio, + audio_positions=audio_positions, + ) # def get( # self, idx: int, offset: int = 0, length: int | None = None, use_loss_masking_spans: bool = False @@ -231,6 +287,10 @@ def num_tokens(self) -> int: def has_images(self) -> bool: return self._has_images + @property + def has_audio(self) -> bool: + return self._has_audio + # TODO: image sizes def get_document_sizes(self) -> tuple[np.ndarray, np.ndarray]: """ @@ -238,7 +298,7 @@ def get_document_sizes(self) -> tuple[np.ndarray, np.ndarray]: The resulting array could be very large, so this method should be called cautiously, and derived classes should try to avoid holding the whole array im memory. 
""" - return self._document_sizes, self._image_lengths + return self._document_sizes, self._image_lengths, self._audio_lengths def get_document_size(self, index: int, patch_size: list[int]) -> int: # return self._document_sizes[index].item() + ( @@ -246,7 +306,10 @@ def get_document_size(self, index: int, patch_size: list[int]) -> int: # if self._has_images # else 0 # ) - return self._document_sizes[index].item(), self._image_lengths[index] if self._has_images else [] + docsize = self._document_sizes[index].item() + imagesize = self._image_lengths[index] if self._has_images else [] + audiosize = self._audio_lengths if self._has_audio else 0 + return docsize, imagesize, audiosize @classmethod def write_dataset(cls, prefix: pathlib.Path | str, documents: typing.Iterable[GPTSample]): @@ -285,6 +348,7 @@ def write_dataset(cls, prefix: pathlib.Path | str, documents: typing.Iterable[GP # Write document to binary file bin_stream.write(document.token_ids.tobytes(order="C")) total_im_size = 0 + total_aud_size = 0 if document.images: n_images.append(len(document.images)) total_images += len(document.images) @@ -299,13 +363,13 @@ def write_dataset(cls, prefix: pathlib.Path | str, documents: typing.Iterable[GP bin_stream.write(pixels.tobytes(order="C")) total_im_size += pixels.size im_positions.append(document.image_positions) - if document.audio: + if document.audio is not None: n_audio.append(len(document.audio)) total_audio += len(document.audio) for audio in document.audio: audio_lengths.append(len(audio)) - bin_stream.write(audio.to_bytes(order="C")) - # total_aud_size += + bin_stream.write(audio.tobytes(order="C")) + total_aud_size += audio.size aud_positions.append(document.audio_positions) # Update metadata @@ -315,7 +379,11 @@ def write_dataset(cls, prefix: pathlib.Path | str, documents: typing.Iterable[GP if document.loss_masking_spans is not None: num_spans.append(len(document.loss_masking_spans)) spans.append(document.loss_masking_spans) - offset += doc_length * np.dtype(dtype).itemsize + total_im_size * np.dtype(np.uint8).itemsize + offset += ( + doc_length * np.dtype(dtype).itemsize + + total_im_size * np.dtype(np.uint8).itemsize + + total_aud_size * np.dtype(np.float32).itemsize + ) num_documents += 1 # Finalize metadata arrays @@ -329,10 +397,21 @@ def write_dataset(cls, prefix: pathlib.Path | str, documents: typing.Iterable[GP if total_images: n_images = np.array(n_images, dtype=np.int32) + image_lengths = np.stack(image_lengths, dtype=np.int32) + im_positions = np.array(im_positions, dtype=np.int32) else: n_images = np.array([]) - image_lengths = np.stack(image_lengths, dtype=np.int32) - im_positions = np.array(im_positions, dtype=np.int32) + image_lengths = np.array([]) + im_positions = np.array([]) + + if total_audio: + n_audio = np.array(n_audio, dtype=np.int32) + audio_lengths = np.array(audio_lengths, dtype=np.int32) + aud_positions = np.array(aud_positions, dtype=np.int32) + else: + n_audio = np.array([]) + audio_lengths = np.array([]) + aud_positions = np.array([]) # Write the index file (.idx) with prefix.with_suffix(".idx").open("wb") as idx_stream: @@ -340,7 +419,7 @@ def write_dataset(cls, prefix: pathlib.Path | str, documents: typing.Iterable[GP # Indicates the version # Version 2 onwards optionally add loss-masking spans # Version 4 onwards optionally add images - idx_stream.write(struct.pack(" 0 else 0)) # Placeholder flag for preference spans @@ -367,5 +446,11 @@ def write_dataset(cls, prefix: pathlib.Path | str, documents: typing.Iterable[GP 
idx_stream.write(image_lengths.tobytes(order="C")) # Position of each image in the document idx_stream.write(im_positions.tobytes(order="C")) + # Number of audio per document + idx_stream.write(n_audio.tobytes(order="C")) + # Audio lengths + idx_stream.write(audio_lengths.tobytes(order="C")) + # Position of each audio in the document + idx_stream.write(aud_positions.tobytes(order="C")) # Document indices, unused but needed for compatibility with Megatron-LM idx_stream.write(np.arange(num_documents + 1, dtype=np.int64).tobytes(order="C")) diff --git a/fast_llm/data/dataset/gpt/sampled.py b/fast_llm/data/dataset/gpt/sampled.py index d700dcc0..d2a53066 100644 --- a/fast_llm/data/dataset/gpt/sampled.py +++ b/fast_llm/data/dataset/gpt/sampled.py @@ -94,6 +94,9 @@ def __init__( self._sequence_length = sampling.sequence_length self._patch_size = sampling.patch_size self._image_size = sampling.image_size + self._aud_downsampling_k = sampling.aud_downsampling_k + self._aud_padding_duration = sampling.aud_padding_duration + self._aud_sampling_rate = sampling.aud_sampling_rate self._cross_document_attention = sampling.cross_document_attention self._config = sampling.config self._truncate_documents = sampling.truncate_documents @@ -138,13 +141,22 @@ def _sample(self) -> None: """ # Get the document sizes, the main information needed for sampling. # TODO Soham: verify numpy correctness - document_sizes, image_sizes = self._indexed_dataset.get_document_sizes() + document_sizes, image_sizes, audio_sizes = self._indexed_dataset.get_document_sizes() document_sizes = torch.from_numpy(document_sizes).to(self._device) image_token_sizes = torch.zeros_like(document_sizes).to(self._device) # TODO Soham: handle max image size for i, sizes in enumerate(image_sizes): image_token_sizes[i] = sum((sizes[:, 0] // self._patch_size) * (sizes[:, 1] // self._patch_size)) + # compute audio token sizes + if self._aud_padding_duration > 0 and len(audio_sizes) > 0: + self._aud_padding_duration * self._aud_sampling_rate + # 2. mel spectogram + + # 3. convolution + + # 4. downsampling + documents_per_epoch = document_sizes.numel() tokens_per_epoch = document_sizes.sum().item() + image_token_sizes.sum().item() diff --git a/fast_llm/engine/schedule/config.py b/fast_llm/engine/schedule/config.py index 517a9cff..0f692482 100644 --- a/fast_llm/engine/schedule/config.py +++ b/fast_llm/engine/schedule/config.py @@ -66,6 +66,21 @@ class BatchConfig(Config): desc="Applies attention to tokens from other documents in the packed sequence. 
Set to False for masking attention to other documents.", hint=FieldHint.feature, ) + aud_downsampling_k: int = Field( + default=5, + desc="Audio downsampling k parameter.", + hint=FieldHint.feature, + ) + aud_padding_duration: int = Field( + default=-1, + desc="Audio padding duration in seconds.", + hint=FieldHint.feature, + ) + aud_sampling_rate: int = Field( + default=16000, + desc="Audio sampling rate to use.", + hint=FieldHint.feature, + ) _distributed: DistributedConfig = Field( init=False, desc="Pointer to a distributed configuration, required to know the data-parallel split of the batch.", diff --git a/fast_llm/models/gpt/trainer.py b/fast_llm/models/gpt/trainer.py index bc16829b..e2ce3fd9 100644 --- a/fast_llm/models/gpt/trainer.py +++ b/fast_llm/models/gpt/trainer.py @@ -23,6 +23,9 @@ def _get_data(self) -> GPTData: cross_document_attention=self._config.batch.cross_document_attention, patch_size=self._config.batch.patch_size, max_image_size=self._config.batch.max_image_size, + aud_downsampling_k=self._config.batch.aud_downsampling_k, + aud_padding_duration=self._config.batch.aud_padding_duration, + aud_sampling_rate=self._config.batch.aud_sampling_rate, ) def get_tflops(self, phase: PhaseType, elapsed_time_per_iteration) -> tuple[int, int]: From 66e708170d98bd476e679fcbaf6fbf761b284388 Mon Sep 17 00:00:00 2001 From: root Date: Fri, 9 May 2025 18:39:55 +0000 Subject: [PATCH 21/82] fixes --- fast_llm/data/dataset/gpt/sampled.py | 2 +- fast_llm/layers/language_model/config.py | 6 +----- fast_llm/layers/vision_encoder/config.py | 6 +++--- fast_llm/layers/vision_encoder/preprocessing.py | 7 ++++--- 4 files changed, 9 insertions(+), 12 deletions(-) diff --git a/fast_llm/data/dataset/gpt/sampled.py b/fast_llm/data/dataset/gpt/sampled.py index cb6d6c8d..54564a21 100644 --- a/fast_llm/data/dataset/gpt/sampled.py +++ b/fast_llm/data/dataset/gpt/sampled.py @@ -278,7 +278,7 @@ def _sample(self) -> None: # Equivalent to `torch.hstack((0, document_sizes[all_document_index].cumsum()[::TOKEN_CUMSUM_RATE]))` if unshuffled_epochs > 0: token_cumsum_unshuffled, unshuffled_tokens = self._get_token_cumsum( - document_sizes + image_token_sizes + document_sizes + image_token_sizes, offset=0, # TODO: Allowing for max 100% extra tokens for padding, is that enough? 
dtype=get_unsigned_integer_type((2 - self._truncate_documents) * tokens_per_epoch * num_epochs), diff --git a/fast_llm/layers/language_model/config.py b/fast_llm/layers/language_model/config.py index 45104420..ab570780 100644 --- a/fast_llm/layers/language_model/config.py +++ b/fast_llm/layers/language_model/config.py @@ -5,12 +5,8 @@ from fast_llm.engine.config_utils.tensor_space import TensorDim, TensorSpace from fast_llm.engine.distributed.config import DistributedDimNames from fast_llm.functional.config import CrossEntropyImpl -<<<<<<< HEAD -from fast_llm.layers.transformer.config import TransformerArchitectureConfig, TransformerConfig -from fast_llm.layers.vision_encoder.config import VisionEncoderArchitectureConfig, VisionEncoderConfig -======= from fast_llm.layers.transformer.config import TransformerConfig ->>>>>>> main +from fast_llm.layers.vision_encoder.config import VisionEncoderConfig from fast_llm.utils import Assert diff --git a/fast_llm/layers/vision_encoder/config.py b/fast_llm/layers/vision_encoder/config.py index b15f90bd..345b118e 100644 --- a/fast_llm/layers/vision_encoder/config.py +++ b/fast_llm/layers/vision_encoder/config.py @@ -1,9 +1,9 @@ -from fast_llm.config import Config, Field, FieldHint, FieldUpdate, config_class -from fast_llm.engine.base_model.config import BaseModelArchitectureConfig, BaseModelConfig +from fast_llm.config import Config, Field, FieldHint, config_class +from fast_llm.engine.base_model.config import BaseModelConfig from fast_llm.engine.config_utils.tensor_space import TensorDim, TensorSpace from fast_llm.functional.config import ActivationType from fast_llm.layers.common.config import NormalizationConfig -from fast_llm.layers.transformer.config import TransformerArchitectureConfig, VisionTransformerConfig +from fast_llm.layers.transformer.config import VisionTransformerConfig class VisionEncoderDimNames: diff --git a/fast_llm/layers/vision_encoder/preprocessing.py b/fast_llm/layers/vision_encoder/preprocessing.py index c087cf6d..7bd8a2aa 100644 --- a/fast_llm/layers/vision_encoder/preprocessing.py +++ b/fast_llm/layers/vision_encoder/preprocessing.py @@ -4,10 +4,11 @@ import torch import torchvision.transforms.v2.functional as F +from fast_llm.engine.base_model.config import Preprocessor from fast_llm.engine.config_utils.tensor_space import TensorDim, TensorSpace from fast_llm.layers.transformer.config import TransformerKwargs from fast_llm.layers.vision_encoder.config import ( - VisionEncoderArchitectureConfig, + VisionEncoderConfig, VisionEncoderDimNames, VisionEncoderKwargs, VisionTransformerDimNames, @@ -101,8 +102,8 @@ def position_ids_in_meshgrid(height, width, max_size, patch_size) -> torch.Tenso return ids[:, 0] -class VisionPreprocessor: - def __init__(self, config: VisionEncoderArchitectureConfig, tensor_space: TensorSpace): +class VisionPreprocessor(Preprocessor): + def __init__(self, config: VisionEncoderConfig, tensor_space: TensorSpace): self._config = config self._tensor_space = tensor_space self._distributed_config = self._tensor_space.distributed_config From 0d1cd96984612c22e012870ef8db092d65aa8a2a Mon Sep 17 00:00:00 2001 From: Toby Liang Date: Fri, 9 May 2025 18:48:49 +0000 Subject: [PATCH 22/82] audio token computation --- fast_llm/data/dataset/gpt/sampled.py | 19 ++++++++++++++----- 1 file changed, 14 insertions(+), 5 deletions(-) diff --git a/fast_llm/data/dataset/gpt/sampled.py b/fast_llm/data/dataset/gpt/sampled.py index 25fba8b7..7b70d598 100644 --- a/fast_llm/data/dataset/gpt/sampled.py +++ 
b/fast_llm/data/dataset/gpt/sampled.py @@ -138,18 +138,27 @@ def _sample(self) -> None: document_sizes, image_sizes, audio_sizes = self._indexed_dataset.get_document_sizes() document_sizes = torch.from_numpy(document_sizes).to(self._device) image_token_sizes = torch.zeros_like(document_sizes).to(self._device) + audio_token_sizes = torch.zeros_like(document_sizes).to(self._device) # TODO Soham: handle max image size for i, sizes in enumerate(image_sizes): image_token_sizes[i] = sum((sizes[:, 0] // self._patch_size) * (sizes[:, 1] // self._patch_size)) # compute audio token sizes - if self._aud_padding_duration > 0 and len(audio_sizes) > 0: - self._aud_padding_duration * self._aud_sampling_rate - # 2. mel spectogram + audio_sizes = torch.tensor(audio_sizes) - # 3. convolution + # account for padding + if len(audio_sizes) > 0 and self._parameters.aud_padding_duration > 0: + raw_audio_seq_length = self._parameters.aud_padding_duration * self._parameters.aud_sampling_rate + audio_sizes.fill_(raw_audio_seq_length) # set all audio sizes to padded amount + long_audio_filter = audio_sizes > raw_audio_seq_length # filter audio that is too long + else: + audio_sizes > self._parameters.sequence_length + 1 - # 4. downsampling + # account for mel spectogram, convolution, downsampling k + audio_token_sizes = audio_sizes / 160 # default hop length + audio_token_sizes = audio_token_sizes // ( + 2 * self._parameters.aud_downsampling_k + ) # convolution (2) * downsampling documents_per_epoch = document_sizes.numel() tokens_per_epoch = document_sizes.sum().item() + image_token_sizes.sum().item() From 40f3882d6f7738238ad53579edc517e29f17d3f2 Mon Sep 17 00:00:00 2001 From: Toby Liang Date: Sat, 10 May 2025 16:52:12 +0000 Subject: [PATCH 23/82] implement mm packing --- fast_llm/data/dataset/gpt/memmap.py | 42 ++++--- fast_llm/data/dataset/gpt/sampled.py | 165 ++++++++++++++++++++------- 2 files changed, 151 insertions(+), 56 deletions(-) diff --git a/fast_llm/data/dataset/gpt/memmap.py b/fast_llm/data/dataset/gpt/memmap.py index 88e31d78..cd4bf7b7 100644 --- a/fast_llm/data/dataset/gpt/memmap.py +++ b/fast_llm/data/dataset/gpt/memmap.py @@ -49,7 +49,7 @@ def _init( with self._prefix.with_suffix(".idx").open("rb") as stream: Assert.eq(stream.read(9), MEMMAP_INDEX_HEADER, msg=f"File: {stream.name}") self._version = struct.unpack("= 2: self._has_spans = struct.unpack("= 4: self._n_images = np.frombuffer( self._index_bin_buffer, dtype=np.int32, count=self._num_documents, offset=offset ) - self._image_lengths = [] - self._image_positions = [] images_seen = 0 # TODO Soham: verify correctness, reshaping into width, height? for n_images in self._n_images: @@ -141,12 +141,12 @@ def _init( ) images_seen += n_images offset = offset + self._n_images.nbytes + 3 * self._n_images.sum() * np.dtype(np.int32).itemsize + self._audio_lengths = [] + self._audio_positions = [] if self._has_audio and self._version >= 5: self._n_audio = np.frombuffer( self._index_bin_buffer, dtype=np.int32, count=self._num_documents, offset=offset ) - self._audio_lengths = [] - self._audio_positions = [] audio_seen = 0 offset = offset + self._n_audio.nbytes @@ -157,7 +157,7 @@ def _init( dtype=np.int32, count=n_audio, offset=offset + audio_seen * np.dtype(np.int32).itemsize, - ).reshape(-1, 2) + ) ) # self._num_pixels += self._image_lengths[-1].prod(axis=1, initial=3).sum() self._audio_positions.append( @@ -177,11 +177,13 @@ def _init( # TODO Soham: fix num_tokens to include images. 
Get total number of image pixels from index file and assign # self._num_tokens = div(self._bin_buffer_mmap.size - n_pixels, np.dtype(self._dtype).itemsize) + + # TODO Toby: Add audio num tokens check self._num_tokens = div(self._bin_buffer_mmap.size - self._num_pixels, np.dtype(self._dtype).itemsize) - if num_pixels is not None: - assert self._num_pixels == num_pixels - if num_tokens is not None: - assert self._num_tokens == num_tokens + # if num_pixels is not None: + # assert self._num_pixels == num_pixels + # if num_tokens is not None: + # assert self._num_tokens == num_tokens def __getstate__(self) -> tuple[str, pathlib.Path, int | None, int | None]: return (self._name, self._prefix, self._num_documents, self._num_tokens, self._num_pixels) @@ -212,6 +214,8 @@ def get( count=self._document_sizes[idx] - offset if length is None else length, offset=self._pointers[idx] + offset * np.dtype(self._dtype).itemsize, ) + images = [] + image_positions = np.array([]) if self._has_images: image_positions = self._image_positions[idx] pixels = np.frombuffer( @@ -220,7 +224,6 @@ def get( count=self._image_lengths[idx].prod(initial=3), offset=self._pointers[idx] + self._document_sizes[idx] * np.dtype(self._dtype).itemsize, ) - images = [] start = 0 for image_length in self._image_lengths[idx]: # TODO Soham: verify reshape dimension order @@ -228,17 +231,19 @@ def get( images.append(pixels[start : start + n_pixels].reshape(3, image_length[0], image_length[1])) start += n_pixels + audio = [] + audio_positions = np.array([]) if self._has_audio: audio_positions = self._audio_positions[idx] + offset = self._pointers[idx] + self._document_sizes[idx] * np.dtype(self._dtype).itemsize + if len(self._image_lengths) > 0: + offset += self._image_lengths[idx].prod(initial=3) * np.dtype(np.uint8).itemsize all_audio = np.frombuffer( self._bin_buffer, dtype=np.dtype(np.float32), count=self._audio_lengths[idx].sum(), - offset=self._pointers[idx] - + self._document_sizes[idx] * np.dtype(self._dtype).itemsize - + self._image_lengths.prod(initial=3) * np.dtype(np.uint8).itemsize, + offset=offset, ) - audio = [] start = 0 for audio_length in self._audio_lengths[idx]: audio.append(all_audio[start : start + audio_length]) @@ -308,7 +313,7 @@ def get_document_size(self, index: int, patch_size: list[int]) -> int: # ) docsize = self._document_sizes[index].item() imagesize = self._image_lengths[index] if self._has_images else [] - audiosize = self._audio_lengths if self._has_audio else 0 + audiosize = self._audio_lengths[index] if self._has_audio else [] return docsize, imagesize, audiosize @classmethod @@ -370,7 +375,8 @@ def write_dataset(cls, prefix: pathlib.Path | str, documents: typing.Iterable[GP audio_lengths.append(len(audio)) bin_stream.write(audio.tobytes(order="C")) total_aud_size += audio.size - aud_positions.append(document.audio_positions) + if len(document.audio) > 0: + aud_positions.append(document.audio_positions) # Update metadata doc_length = len(document.token_ids) @@ -426,6 +432,8 @@ def write_dataset(cls, prefix: pathlib.Path | str, documents: typing.Iterable[GP idx_stream.write(struct.pack(" 0 else 0)) + # Flag to indicate whether audio is present + idx_stream.write(struct.pack(" 0 else 0)) # Data type idx_stream.write(struct.pack(" 0: + raw_audio_seq_length = self._parameters.aud_padding_duration * self._parameters.aud_sampling_rate + sizes = sizes.copy() # original is read-only + to_filter = bool(np.any(sizes > raw_audio_seq_length)) # filter sample where any audio is too long + 
sizes.fill(raw_audio_seq_length) # set all audio sizes to padded amount + + # account for mel spectogram, convolution, downsampling k + audio_token_size_arr = sizes // 160 # default hop length TODO: check divisible? + audio_token_size_arr = audio_token_size_arr // ( + 2 * self._parameters.aud_downsampling_k + ) # convolution (2) * downsampling + return audio_token_size_arr, to_filter + + def apply_audio_padding(self, audio): + if len(audio) == 0: + return audio + # TODO Toby: check 2d + padded_audio = [] + if self._parameters.aud_padding_duration > 0: + raw_audio_seq_length = self._parameters.aud_padding_duration * self._parameters.aud_sampling_rate + for aud in audio: + padded = np.pad(aud, (0, raw_audio_seq_length - len(aud)), mode="constant", constant_values=0) + padded_audio.append(padded) + return padded_audio + else: + return audio + def _sample(self) -> None: """ Create a `GPTSampledDataset` with the requested parameters. @@ -139,29 +175,22 @@ def _sample(self) -> None: document_sizes = torch.from_numpy(document_sizes).to(self._device) image_token_sizes = torch.zeros_like(document_sizes).to(self._device) audio_token_sizes = torch.zeros_like(document_sizes).to(self._device) + long_audio_filter = torch.zeros_like(document_sizes, dtype=torch.bool) # TODO Soham: handle max image size for i, sizes in enumerate(image_sizes): - image_token_sizes[i] = sum((sizes[:, 0] // self._patch_size) * (sizes[:, 1] // self._patch_size)) - - # compute audio token sizes - audio_sizes = torch.tensor(audio_sizes) - - # account for padding - if len(audio_sizes) > 0 and self._parameters.aud_padding_duration > 0: - raw_audio_seq_length = self._parameters.aud_padding_duration * self._parameters.aud_sampling_rate - audio_sizes.fill_(raw_audio_seq_length) # set all audio sizes to padded amount - long_audio_filter = audio_sizes > raw_audio_seq_length # filter audio that is too long - else: - audio_sizes > self._parameters.sequence_length + 1 + image_token_sizes[i] = sum( + (sizes[:, 0] // self._parameters.patch_size) * (sizes[:, 1] // self._parameters.patch_size) + ) - # account for mel spectogram, convolution, downsampling k - audio_token_sizes = audio_sizes / 160 # default hop length - audio_token_sizes = audio_token_sizes // ( - 2 * self._parameters.aud_downsampling_k - ) # convolution (2) * downsampling + for i, sizes in enumerate(audio_sizes): + audio_token_size_arr, to_filter = self._compute_audio_token_size(sizes) + audio_token_sizes[i] = audio_token_size_arr.sum() + long_audio_filter[i] = to_filter documents_per_epoch = document_sizes.numel() - tokens_per_epoch = document_sizes.sum().item() + image_token_sizes.sum().item() + tokens_per_epoch = ( + document_sizes.sum().item() + image_token_sizes.sum().item() + audio_token_sizes.sum().item() + ) # Calculate basic stats. if not self._truncate_documents: @@ -169,14 +198,31 @@ def _sample(self) -> None: "The C++ extension for dataset sampling is missing." " Please make sure Fast-LLM is installed correctly." 
) - long_docs_filter = document_sizes + image_token_sizes > self._parameters.sequence_length + 1 + long_docs_filter = ( + document_sizes + image_token_sizes + audio_token_sizes > self._parameters.sequence_length + 1 + ) ignored_documents = sum(long_docs_filter) if ignored_documents: log_main_rank( f" > {ignored_documents}/{documents_per_epoch} documents are longer than {self._parameters.sequence_length+1} tokens and will be ignored.", log_fn=logger.warning, ) - tokens_per_epoch = (document_sizes[~long_docs_filter] + image_token_sizes[~long_docs_filter]).sum().item() + ignored_audio_samples = sum(long_audio_filter) + if ignored_audio_samples: + log_main_rank( + f" > {ignored_audio_samples}/{documents_per_epoch} samples contain audio longer than {self._parameters.aud_padding_duration} seconds and will be ignored.", + log_fn=logger.warning, + ) + long_docs_filter = long_docs_filter | long_audio_filter + tokens_per_epoch = ( + ( + document_sizes[~long_docs_filter] + + image_token_sizes[~long_docs_filter] + + audio_token_sizes[~long_docs_filter] + ) + .sum() + .item() + ) if tokens_per_epoch == 0: raise RuntimeError( f" > No documents shorter than {self._parameters.sequence_length+1} tokens found in dataset {self._indexed_dataset.name}." @@ -215,7 +261,7 @@ def _sample(self) -> None: "num_samples": self._parameters.num_samples, "unshuffled_epochs": unshuffled_epochs, "sequence_length": self._parameters.sequence_length, - "patch_size": self._patch_size, + "patch_size": self._parameters.patch_size, "truncate_documents": self._truncate_documents, "config": self._config.to_dict(), } @@ -298,7 +344,7 @@ def _sample(self) -> None: # Equivalent to `torch.hstack((0, document_sizes[all_document_index].cumsum()[::TOKEN_CUMSUM_RATE]))` if unshuffled_epochs > 0: token_cumsum_unshuffled, unshuffled_tokens = self._get_token_cumsum( - document_sizes + image_token_sizes, + document_sizes + image_token_sizes + audio_token_sizes, offset=0, # TODO: Allowing for max 100% extra tokens for padding, is that enough? dtype=get_unsigned_integer_type((2 - self._truncate_documents) * tokens_per_epoch * num_epochs), @@ -308,7 +354,7 @@ def _sample(self) -> None: unshuffled_tokens = 0 if not self._truncate_documents: - yaml_data["unshuffled_tokens"] = unshuffled_tokens + yaml_data["unshuffled_tokens"] = unshuffled_tokens.item() self._load_yaml_data(yaml_data) if self._yaml_path is not None: self._yaml_path.parent.mkdir(parents=True, exist_ok=True) @@ -323,7 +369,14 @@ def _sample(self) -> None: ) ] + image_token_sizes[ - document_shuffling.to(torch.int64 if document_shuffling.dtype == torch.int64 else torch.int32) + document_shuffling.to( + dtype=torch.int64 if document_shuffling.dtype == torch.int64 else torch.int32 + ) + ] + + audio_token_sizes[ + document_shuffling.to( + dtype=torch.int64 if document_shuffling.dtype == torch.int64 else torch.int32 + ) ], offset=self._unshuffled_tokens, # TODO: Allowing for max 100% extra tokens for padding, is that enough? @@ -416,8 +469,10 @@ def __getitem__(self, index: int) -> typing.Any: token_ids = [] loss_masking_spans = [] images = [] + audio = [] image_positions = [] - image_tokens_added = 0 + audio_positions = [] + mm_tokens_added = 0 while token_count < token_end: # Find the document index in the dataset. 
if document_sampling_index < self._unshuffled_documents: @@ -425,29 +480,40 @@ def __getitem__(self, index: int) -> typing.Any: else: document_index = self._document_shuffling[document_sampling_index - self._unshuffled_documents].item() - document_size, image_lengths = self._indexed_dataset.get_document_size(document_index, self._patch_size) + document_size, image_lengths, audio_lengths = self._indexed_dataset.get_document_size( + document_index, self._parameters.patch_size + ) image_sizes = [ get_num_patches( - *get_resize_dims(*image_length, self._image_size, self._image_size, self._patch_size), - self._patch_size, + *get_resize_dims(*image_length, self._image_size, self._image_size, self._parameters.patch_size), + self._parameters.patch_size, ) for image_length in image_lengths ] image_tokens = sum(image_sizes) + audio_token_size_arr, _ = self._compute_audio_token_size(audio_lengths) + audio_tokens = audio_token_size_arr.sum() + if not self._truncate_documents: - if document_size + image_tokens > self._parameters.sequence_length + 1: + if document_size + image_tokens + audio_tokens > self._parameters.sequence_length + 1: # Document too long, ignore document_sampling_index += 1 continue tokens_in_sample = token_count % (self._parameters.sequence_length + 1) - if document_size + image_tokens + tokens_in_sample > self._parameters.sequence_length + 1: + if ( + document_size + image_tokens + audio_tokens + tokens_in_sample + > self._parameters.sequence_length + 1 + ): # Document belongs to the next sample, need to account for padding. padding_size = self._parameters.sequence_length + 1 - tokens_in_sample if token_count > token_start: # Add padding tokens to current sample - token_ids.append(np.full((padding_size,), -100, dtype=np.int64)) + try: + token_ids.append(np.full((padding_size,), -100, dtype=np.int64)) + except: + pass Assert.eq(token_count + padding_size, token_end) break else: @@ -455,7 +521,7 @@ def __getitem__(self, index: int) -> typing.Any: token_count += padding_size # Determine if the document belongs to the requested sample. - if token_count + document_size + image_tokens >= token_start: + if token_count + document_size + image_tokens + audio_tokens >= token_start: # Determine which part of the document belong to the sample, and add it to the list. 
token_start_index_in_document = max(token_start - token_count, 0) token_end_index_in_document = min(token_end - token_count, document_size) @@ -466,16 +532,32 @@ def __getitem__(self, index: int) -> typing.Any: use_loss_masking_spans=self._parameters.use_loss_masking_spans, ) start_pos = 0 - for idx, im_position in enumerate(sample.image_positions): + multimodal_positions = np.concatenate( + [sample.image_positions.astype(np.int32), sample.audio_positions.astype(np.int32)] + ) + multimodal_positions.sort() + for idx, mm_position in enumerate(multimodal_positions): + if mm_position in sample.image_positions: # TODO Toby: use enum + mm_type = "image" + elif mm_position in sample.audio_positions: + mm_type = "audio" + else: + assert False # image_positions.append(im_positions + len(token_ids) + image_tokens_added) # Add placeholders for image tokens - token_ids.append(sample.token_ids[start_pos:im_position]) - token_ids.append(np.full((image_sizes[idx],), -100, dtype=np.int64)) - image_positions.append(im_position + len(token_ids) + image_tokens_added) - image_tokens_added += image_tokens - start_pos = im_position + token_ids.append(sample.token_ids[start_pos:mm_position]) + if mm_type == "image": + token_ids.append(np.full((image_sizes[idx],), -100, dtype=np.int64)) + image_positions.append(mm_position + len(token_ids) + mm_tokens_added) + mm_tokens_added += image_tokens + elif mm_type == "audio": + token_ids.append(np.full((audio_token_size_arr[idx],), -100, dtype=np.int64)) + audio_positions.append(mm_position + mm_tokens_added) + mm_tokens_added += audio_tokens + start_pos = mm_position token_ids.append(sample.token_ids[start_pos:]) images.append(sample.images) + audio.append(self.apply_audio_padding(sample.audio)) # TODO Soham: add offsets for loss masking spans if self._parameters.use_loss_masking_spans: for loss_masking_span in sample.loss_masking_spans: @@ -489,7 +571,7 @@ def __getitem__(self, index: int) -> typing.Any: # Go to the next document. 
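The placeholder-splicing loop above boils down to cutting the token stream at each media position and inserting -100 filler tokens sized to that item. A simplified sketch with a single flat list of positions and sizes; the helper name is illustrative, not part of the patch:

import numpy as np

def splice_placeholders(token_ids: np.ndarray, positions: list[int], sizes: list[int]) -> np.ndarray:
    # Simplified version of the loop above: text chunk, then -100 placeholders for the media item.
    pieces, start = [], 0
    for position, size in zip(positions, sizes):
        pieces.append(token_ids[start:position])
        pieces.append(np.full((size,), -100, dtype=np.int64))
        start = position
    pieces.append(token_ids[start:])
    return np.concatenate(pieces)

The real code additionally keeps separate image and audio position lists and the running `mm_tokens_added` offset so downstream consumers can locate the placeholders again.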
document_sampling_index += 1 - token_count += document_size + image_tokens + token_count += document_size + image_tokens + audio_tokens sequence_lengths = ( np.array([ids.size - (idx == len(token_ids) - 1) for idx, ids in enumerate(token_ids)], dtype=np.int32) @@ -504,6 +586,9 @@ def __getitem__(self, index: int) -> typing.Any: ) images = [im for img_list in images for im in img_list] if images else None image_positions = np.array(image_positions) if image_positions else None + + audio = [aud for aud_list in audio for aud in aud_list] if audio else None + audio_positions = np.array(audio_positions) if audio_positions else None Assert.eq(len(token_ids), self._parameters.sequence_length + self._parameters.extra_tokens) return GPTSample( @@ -512,6 +597,8 @@ def __getitem__(self, index: int) -> typing.Any: sequence_lengths=sequence_lengths, images=images, image_positions=image_positions, + audio=audio, + audio_positions=audio_positions, ) @property From 7f86a7f1889065ca06dade517d0cc69ef8b83215 Mon Sep 17 00:00:00 2001 From: root Date: Mon, 12 May 2025 06:18:03 +0000 Subject: [PATCH 24/82] fix --- fast_llm/data/dataset/gpt/sampled.py | 19 ++- fast_llm/data/tokenizer.py | 153 ++++++++++++------ fast_llm/engine/schedule/config.py | 2 +- fast_llm/layers/language_model/config.py | 2 +- fast_llm/layers/transformer/config.py | 43 ++--- fast_llm/layers/vision_encoder/config.py | 1 - .../layers/vision_encoder/preprocessing.py | 2 +- fast_llm/models/gpt/model.py | 19 ++- 8 files changed, 153 insertions(+), 88 deletions(-) diff --git a/fast_llm/data/dataset/gpt/sampled.py b/fast_llm/data/dataset/gpt/sampled.py index 54564a21..f99a9d3e 100644 --- a/fast_llm/data/dataset/gpt/sampled.py +++ b/fast_llm/data/dataset/gpt/sampled.py @@ -138,7 +138,9 @@ def _sample(self) -> None: image_token_sizes = torch.zeros_like(document_sizes).to(self._device) # TODO Soham: handle max image size for i, sizes in enumerate(image_sizes): - image_token_sizes[i] = sum((sizes[:, 0] // self._patch_size) * (sizes[:, 1] // self._patch_size)) + image_token_sizes[i] = sum( + (sizes[:, 0] // self._parameters.patch_size) * (sizes[:, 1] // self._parameters.patch_size) + ) documents_per_epoch = document_sizes.numel() tokens_per_epoch = document_sizes.sum().item() + image_token_sizes.sum().item() @@ -195,7 +197,7 @@ def _sample(self) -> None: "num_samples": self._parameters.num_samples, "unshuffled_epochs": unshuffled_epochs, "sequence_length": self._parameters.sequence_length, - "patch_size": self._patch_size, + "patch_size": self._parameters.patch_size, "truncate_documents": self._truncate_documents, "config": self._config.to_dict(), } @@ -405,12 +407,19 @@ def __getitem__(self, index: int) -> typing.Any: else: document_index = self._document_shuffling[document_sampling_index - self._unshuffled_documents].item() - document_size, image_lengths = self._indexed_dataset.get_document_size(document_index, self._patch_size) + document_size, image_lengths = self._indexed_dataset.get_document_size( + document_index, self._parameters.patch_size + ) image_sizes = [ get_num_patches( - *get_resize_dims(*image_length, self._image_size, self._image_size, self._patch_size), - self._patch_size, + *get_resize_dims( + *image_length, + self._parameters.image_size, + self._parameters.image_size, + self._parameters.patch_size, + ), + self._parameters.patch_size, ) for image_length in image_lengths ] diff --git a/fast_llm/data/tokenizer.py b/fast_llm/data/tokenizer.py index 0e7d5470..10b8b2c6 100644 --- a/fast_llm/data/tokenizer.py +++ 
b/fast_llm/data/tokenizer.py @@ -42,64 +42,119 @@ def _tokenize(self, text: str, begin=True, end=True) -> list[int]: + ([self.eod_id] if end else []) ) - def tokenize(self, text, image_positions=None): + def tokenize(self, text: str, image_positions=None, char_spans=None) -> tuple[list[int], list[tuple[int, int]]]: + """ + Tokenize the input text and return the tokenized input_ids along with token spans. + """ + # if not image_positions and not char_spans: + # return self._tokenize(text), [], [] if not image_positions: - return self._tokenize(text), [], [] + image_positions = [] + if not char_spans: + char_spans = [] + image_idx = 0 char_pos = 0 token_ids = [] image_token_positions = [] beginning_of_text = True - while image_idx < len(image_positions): - if image_positions[image_idx] > len(text): - raise ValueError( - f"Image position {image_positions[image_idx]} is greater than text length {len(text)}" - ) - curr_text = text[char_pos : image_positions[image_idx]] - tokenized_text = self._tokenize( - curr_text, begin=beginning_of_text, end=image_positions[image_idx] >= len(text) - ) - beginning_of_text = False - token_ids.extend(tokenized_text) - image_token_positions = len(token_ids) - char_pos = image_positions[image_idx] - image_idx += 1 - if char_pos < len(text): - curr_text = text[char_pos:] - tokenized_text = self._tokenize(curr_text, begin=beginning_of_text, end=True) - token_ids.extend(tokenized_text) - return token_ids, image_token_positions - - def tokenize_with_spans( - self, text: str, char_spans: list[tuple[int, int]] - ) -> tuple[list[int], list[tuple[int, int]]]: - """ - Perform span-aware tokenization and return the tokenized input_ids along with token spans. - """ - input_ids = [] - token_spans = [] - char_pos = 0 - beginning_of_text = True for start, end in char_spans: + image_position = image_positions[image_idx] if image_idx < len(image_positions) else float("inf") + while image_position <= start: + tokenized_text = self._tokenize(text[char_pos:image_position], begin=beginning_of_text, end=False) + beginning_of_text = False + image_token_positions.append(len(token_ids)) + token_ids.extend(tokenized_text) + image_idx += 1 + char_pos = image_position + image_position = image_positions[image_idx] if image_idx < len(image_positions) else float("inf") if char_pos < start: - curr_text = text[char_pos:start] - tokenized_text = self.tokenize(curr_text, begin=beginning_of_text, end=False) + self._tokenize(text[char_pos:start], begin=beginning_of_text, end=False) beginning_of_text = False - input_ids.extend(tokenized_text) - curr_text = text[start : end + 1] - if end >= len(text) - 1: - tokenized_text = self.tokenize(curr_text, begin=beginning_of_text, end=True) - else: - tokenized_text = self.tokenize(curr_text, begin=beginning_of_text, end=False) - beginning_of_text = False - token_spans.append((len(input_ids), len(input_ids) + len(tokenized_text) - 1)) - input_ids.extend(tokenized_text) - char_pos = end + 1 - if char_pos < len(text): - curr_text = text[char_pos:] - tokenized_text = self.tokenize(curr_text, begin=beginning_of_text, end=True) - input_ids.extend(tokenized_text) - return input_ids, token_spans + token_ids.extend(tokenized_text) + char_pos = start + len(token_ids) + span_length = 0 + while image_position <= end: + tokenized_text = self._tokenize(text[char_pos:image_position], begin=beginning_of_text, end=False) + beginning_of_text = False + image_token_positions.append(len(token_ids)) + token_ids.extend(tokenized_text) + span_length += len(tokenized_text) + 
char_pos = image_position + image_idx += 1 + image_position = image_positions[image_idx] if image_idx < len(image_positions) else float("inf") + if char_pos < end: + if end >= len(text) - 1: + tokenized_text = self._tokenize(text[char_pos : end + 1], begin=beginning_of_text, end=True) + beginning_of_text = False + token_ids.extend(tokenized_text) + span_length += len(tokenized_text) + char_pos = end + 1 + else: + tokenized_text = self._tokenize(text[char_pos : end + 1], begin=beginning_of_text, end=False) + beginning_of_text = False + token_ids.extend(tokenized_text) + span_length += len(tokenized_text) + + # def tokenize(self, text, image_positions=None): + # if not image_positions: + # return self._tokenize(text), [], [] + # image_idx = 0 + # char_pos = 0 + # token_ids = [] + # image_token_positions = [] + # beginning_of_text = True + # while image_idx < len(image_positions): + # if image_positions[image_idx] > len(text): + # raise ValueError( + # f"Image position {image_positions[image_idx]} is greater than text length {len(text)}" + # ) + # curr_text = text[char_pos : image_positions[image_idx]] + # tokenized_text = self._tokenize( + # curr_text, begin=beginning_of_text, end=image_positions[image_idx] >= len(text) + # ) + # beginning_of_text = False + # token_ids.extend(tokenized_text) + # image_token_positions = len(token_ids) + # char_pos = image_positions[image_idx] + # image_idx += 1 + # if char_pos < len(text): + # curr_text = text[char_pos:] + # tokenized_text = self._tokenize(curr_text, begin=beginning_of_text, end=True) + # token_ids.extend(tokenized_text) + # return token_ids, image_token_positions + + # def tokenize_with_spans( + # self, text: str, char_spans: list[tuple[int, int]] + # ) -> tuple[list[int], list[tuple[int, int]]]: + # """ + # Perform span-aware tokenization and return the tokenized input_ids along with token spans. 
+ # """ + # input_ids = [] + # token_spans = [] + # char_pos = 0 + # beginning_of_text = True + # for start, end in char_spans: + # if char_pos < start: + # curr_text = text[char_pos:start] + # tokenized_text = self.tokenize(curr_text, begin=beginning_of_text, end=False) + # beginning_of_text = False + # input_ids.extend(tokenized_text) + # curr_text = text[start : end + 1] + # if end >= len(text) - 1: + # tokenized_text = self.tokenize(curr_text, begin=beginning_of_text, end=True) + # else: + # tokenized_text = self.tokenize(curr_text, begin=beginning_of_text, end=False) + # beginning_of_text = False + # token_spans.append((len(input_ids), len(input_ids) + len(tokenized_text) - 1)) + # input_ids.extend(tokenized_text) + # char_pos = end + 1 + # if char_pos < len(text): + # curr_text = text[char_pos:] + # tokenized_text = self.tokenize(curr_text, begin=beginning_of_text, end=True) + # input_ids.extend(tokenized_text) + # return input_ids, token_spans def detokenize(self, token_ids: int | list[int] | np.ndarray | torch.Tensor) -> str: return self.tokenizer.decode(token_ids) diff --git a/fast_llm/engine/schedule/config.py b/fast_llm/engine/schedule/config.py index 10f87835..48daf0e6 100644 --- a/fast_llm/engine/schedule/config.py +++ b/fast_llm/engine/schedule/config.py @@ -55,7 +55,7 @@ class BatchConfig(Config): desc="Patch size for each image token", hint=FieldHint.optional, ) - max_image_size: int | None = Field( + image_size: int | None = Field( default=None, desc="Maximum image height and width", hint=FieldHint.optional, diff --git a/fast_llm/layers/language_model/config.py b/fast_llm/layers/language_model/config.py index ab570780..78de218f 100644 --- a/fast_llm/layers/language_model/config.py +++ b/fast_llm/layers/language_model/config.py @@ -167,7 +167,7 @@ def _validate(self) -> None: raise NotImplementedError("Multi-token prediction not supported with distillation.") def setup_tensor_space(self, tensor_space: TensorSpace) -> None: - self.transformer.setup_tensor_space(tensor_space, type="vision" if self.vision_encoder is not None else None) + self.transformer.setup_tensor_space(tensor_space) tensor = tensor_space.distributed_config.get_distributed_dim(DistributedDimNames.tensor) # Embedding dimensions diff --git a/fast_llm/layers/transformer/config.py b/fast_llm/layers/transformer/config.py index 55320a1b..38dc9ec4 100644 --- a/fast_llm/layers/transformer/config.py +++ b/fast_llm/layers/transformer/config.py @@ -169,6 +169,7 @@ class VisionRotaryConfig(RotaryConfig): hint=FieldHint.feature, ) + class AddLinearBiasChoices(str, enum.Enum): nowhere = "nowhere" everywhere = "everywhere" @@ -668,59 +669,61 @@ def setup_tensor_space(self, tensor_space: TensorSpace, type: str | None = None) tensor = tensor_space.distributed_config.get_distributed_dim(DistributedDimNames.tensor) # Hidden dimension - tensor_space.add_tensor_dim(TensorDim(TransformerDimNames.hidden, self.hidden_size)) + tensor_space.add_tensor_dim(TensorDim(transformer_dim_names.hidden, self.hidden_size)) # Self-attention dimensions tensor_space.add_tensor_dim( head_groups := TensorDim( - TransformerDimNames.head_groups, self.head_groups, tensor if self.head_groups > 1 else None + transformer_dim_names.head_groups, self.head_groups, tensor if self.head_groups > 1 else None ) ) tensor_space.add_tensor_dim( group_heads := TensorDim( - TransformerDimNames.group_heads, + transformer_dim_names.group_heads, div(self.num_attention_heads, self.head_groups), None if self.head_groups > 1 else tensor, ) ) - 
tensor_space.add_tensor_dim(key_and_value := TensorDim(TransformerDimNames.key_and_value, 2)) - tensor_space.add_tensor_dim(kv_channels := TensorDim(TransformerDimNames.kv_channels, self.kv_channels)) + tensor_space.add_tensor_dim(key_and_value := TensorDim(transformer_dim_names.key_and_value, 2)) + tensor_space.add_tensor_dim(kv_channels := TensorDim(transformer_dim_names.kv_channels, self.kv_channels)) tensor_space.add_tensor_dim( - CompositeTensorDim(TransformerDimNames.composite_heads, (head_groups, group_heads)) + CompositeTensorDim(transformer_dim_names.composite_heads, (head_groups, group_heads)) ) tensor_space.add_tensor_dim( - CompositeTensorDim(TransformerDimNames.composite_query, (head_groups, group_heads, kv_channels)) + CompositeTensorDim(transformer_dim_names.composite_query, (head_groups, group_heads, kv_channels)) ) tensor_space.add_tensor_dim( - CompositeTensorDim(TransformerDimNames.composite_key_value, (key_and_value, head_groups, kv_channels)) + CompositeTensorDim(transformer_dim_names.composite_key_value, (key_and_value, head_groups, kv_channels)) ) tensor_space.add_tensor_dim( - CompositeTensorDim(TransformerDimNames.composite_dense, (head_groups, group_heads, kv_channels)) + CompositeTensorDim(transformer_dim_names.composite_dense, (head_groups, group_heads, kv_channels)) ) # MLP dimensions - tensor_space.add_tensor_dim(mlp := TensorDim(TransformerDimNames.mlp, self.ffn_hidden_size, tensor)) - tensor_space.add_tensor_dim(gate_and_up := TensorDim(TransformerDimNames.gate_and_up, 2 if self.gated else 1)) - tensor_space.add_tensor_dim(CompositeTensorDim(TransformerDimNames.composite_gated_mlp, (gate_and_up, mlp))) - tensor_space.add_tensor_dim(experts := TensorDim(TransformerDimNames.experts, self.num_experts)) - tensor_space.add_tensor_dim(CompositeTensorDim(TransformerDimNames.composite_expert_mlp, (experts, mlp))) + tensor_space.add_tensor_dim(mlp := TensorDim(transformer_dim_names.mlp, self.ffn_hidden_size, tensor)) + tensor_space.add_tensor_dim( + gate_and_up := TensorDim(transformer_dim_names.gate_and_up, 2 if self.gated else 1) + ) + tensor_space.add_tensor_dim(CompositeTensorDim(transformer_dim_names.composite_gated_mlp, (gate_and_up, mlp))) + tensor_space.add_tensor_dim(experts := TensorDim(transformer_dim_names.experts, self.num_experts)) + tensor_space.add_tensor_dim(CompositeTensorDim(transformer_dim_names.composite_expert_mlp, (experts, mlp))) tensor_space.add_tensor_dim( - CompositeTensorDim(TransformerDimNames.composite_gated_expert_mlp, (experts, gate_and_up, mlp)) + CompositeTensorDim(transformer_dim_names.composite_gated_expert_mlp, (experts, gate_and_up, mlp)) ) - tensor_space.add_tensor_dim(TensorDim(TransformerDimNames.top_experts, self.num_experts_per_token)) - tensor_space.add_tensor_dim(TensorDim(TransformerDimNames.unshared_experts, self.num_unshared_experts)) + tensor_space.add_tensor_dim(TensorDim(transformer_dim_names.top_experts, self.num_experts_per_token)) + tensor_space.add_tensor_dim(TensorDim(transformer_dim_names.unshared_experts, self.num_unshared_experts)) # shared_experts if self.num_shared_experts: tensor_space.add_tensor_dim( - shared_experts := TensorDim(TransformerDimNames.shared_experts, self.num_shared_experts) + shared_experts := TensorDim(transformer_dim_names.shared_experts, self.num_shared_experts) ) tensor_space.add_tensor_dim( - CompositeTensorDim(TransformerDimNames.composite_shared_expert_mlp, (shared_experts, mlp)) + CompositeTensorDim(transformer_dim_names.composite_shared_expert_mlp, (shared_experts, mlp)) ) 
tensor_space.add_tensor_dim( CompositeTensorDim( - TransformerDimNames.composite_gated_shared_expert_mlp, (shared_experts, gate_and_up, mlp) + transformer_dim_names.composite_gated_shared_expert_mlp, (shared_experts, gate_and_up, mlp) ) ) diff --git a/fast_llm/layers/vision_encoder/config.py b/fast_llm/layers/vision_encoder/config.py index 345b118e..4dde28be 100644 --- a/fast_llm/layers/vision_encoder/config.py +++ b/fast_llm/layers/vision_encoder/config.py @@ -176,4 +176,3 @@ def setup_tensor_space(self, tensor_space: TensorSpace): ) ) self.transformer.setup_tensor_space(tensor_space, type="vision") - super().setup_tensor_space(tensor_space) diff --git a/fast_llm/layers/vision_encoder/preprocessing.py b/fast_llm/layers/vision_encoder/preprocessing.py index 7bd8a2aa..46bf0ab3 100644 --- a/fast_llm/layers/vision_encoder/preprocessing.py +++ b/fast_llm/layers/vision_encoder/preprocessing.py @@ -123,7 +123,7 @@ def preprocess_meta(self, kwargs: dict[str, typing.Any]) -> None: dtype=self._distributed_config.training_dtype.torch, ) - def preprocess(self, kwargs: dict[str, typing.Any]) -> None: + def preprocess(self, tokens, kwargs: dict[str, typing.Any]) -> None: images = kwargs.get(VisionEncoderKwargs.images) im_height = kwargs.get(VisionEncoderKwargs.image_size) im_width = kwargs.get(VisionEncoderKwargs.image_size) diff --git a/fast_llm/models/gpt/model.py b/fast_llm/models/gpt/model.py index c80c05f9..b832f1b0 100644 --- a/fast_llm/models/gpt/model.py +++ b/fast_llm/models/gpt/model.py @@ -77,14 +77,10 @@ def __init__( self._preprocessors.append(BackupAttentionPreprocessor(self._config.transformer, self._tensor_space)) if self._config.vision_encoder: - self._preprocessors.append( - VisionPreprocessor(self._config.vision_encoder, self._tensor_space) - ) + self._preprocessors.append(VisionPreprocessor(self._config.vision_encoder, self._tensor_space)) if self._config.vision_encoder.transformer.rotary.enabled: self._preprocessors.append( - RotaryEmbeddingPreprocessor( - self._config.vision_encoder.transformer.rotary, self._tensor_space - ) + RotaryEmbeddingPreprocessor(self._config.vision_encoder.transformer.rotary, self._tensor_space) ) # self._vision_preprocessor = VisionPreprocessor(self._config.vision_encoder, self._tensor_space) # if self._config.vision_encoder.transformer.rotary.enabled: @@ -167,7 +163,7 @@ def preprocess_meta( micro_sequence_length = sequence_length if self._config.vision_encoder: - image_size = batch_meta.max_image_size + image_size = batch_meta.image_size image_mean = [ self._config.vision_encoder.image_normalization.mean_r, self._config.vision_encoder.image_normalization.mean_g, @@ -411,8 +407,6 @@ def preprocess( kwargs[LanguageModelKwargs.labels] = labels kwargs.update(reference_logits[i]) - for preprocessor in self._preprocessors: - preprocessor.preprocess(tokens, kwargs) if batch.images is not None: kwargs[VisionEncoderKwargs.images] = [ [ @@ -423,7 +417,12 @@ def preprocess( ] kwargs[VisionEncoderKwargs.image_positions] = batch.image_positions kwargs[LanguageModelKwargs.tokens] = tokens - preprocessed.append((kwargs[VisionEncoderKwargs.image_patches], kwargs)) + + for preprocessor in self._preprocessors: + preprocessor.preprocess(tokens, kwargs) + image_patches = kwargs.get(VisionEncoderKwargs.image_patches, None) + if image_patches is not None: + preprocessed.append((image_patches, kwargs)) else: preprocessed.append((tokens, kwargs)) From 3a8a99d62c559f97f35d37dc4c2133d5e0a77a73 Mon Sep 17 00:00:00 2001 From: root Date: Mon, 12 May 2025 15:23:33 +0000 Subject: 
[PATCH 25/82] more fixes after merge --- fast_llm/layers/transformer/preprocessing.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/fast_llm/layers/transformer/preprocessing.py b/fast_llm/layers/transformer/preprocessing.py index 01b95397..870463df 100644 --- a/fast_llm/layers/transformer/preprocessing.py +++ b/fast_llm/layers/transformer/preprocessing.py @@ -15,7 +15,11 @@ TransformerKwargs, VisionTransformerConfig, ) -from fast_llm.layers.vision_encoder.config import VisionTransformerDimNames, VisionTransformerKwargs +from fast_llm.layers.vision_encoder.config import ( + VisionEncoderKwargs, + VisionTransformerDimNames, + VisionTransformerKwargs, +) from fast_llm.tensor import TensorMeta logger = logging.getLogger(__name__) @@ -163,6 +167,7 @@ def get_2d_rotary_frequencies( return frequencies + class RotaryEmbeddingPreprocessor(Preprocessor): _scalar_dim: TensorDim _mask: torch.Tensor @@ -216,7 +221,11 @@ def _create_tensors(self, sequence_length: int, num_patches: None | int = None) ) def preprocess(self, batch, kwargs: dict[str, typing.Any]) -> None: - self._create_tensors(kwargs[TransformerKwargs.sequence_length]) + if self._config.type == RotaryEmbeddingType.pixtral: + max_num_patches = kwargs[VisionEncoderKwargs.image_size] // kwargs[VisionEncoderKwargs.patch_size] + self._create_tensors(kwargs[TransformerKwargs.sequence_length], max_num_patches) + else: + self._create_tensors(kwargs[TransformerKwargs.sequence_length]) sequence_k = kwargs[TransformerKwargs.sequence_k_dim].size sequence_q = kwargs[TransformerKwargs.sequence_q_dim].size if self._config.type == RotaryEmbeddingType.pixtral: From d16284ee0b96598e63e74c27b6b09e7e70d9d367 Mon Sep 17 00:00:00 2001 From: root Date: Mon, 12 May 2025 19:32:32 +0000 Subject: [PATCH 26/82] conv cleanup --- fast_llm/data/dataset/gpt/memmap.py | 1 - fast_llm/data/preparator/gpt_memmap/config.py | 1 - fast_llm/layers/vision_encoder/config.py | 6 +++ fast_llm/layers/vision_encoder/encoder.py | 39 ++++++++++--------- fast_llm/models/gpt/conversion.py | 6 ++- setup.cfg | 1 - 6 files changed, 30 insertions(+), 24 deletions(-) diff --git a/fast_llm/data/dataset/gpt/memmap.py b/fast_llm/data/dataset/gpt/memmap.py index 8651b8fc..5d3df598 100644 --- a/fast_llm/data/dataset/gpt/memmap.py +++ b/fast_llm/data/dataset/gpt/memmap.py @@ -163,7 +163,6 @@ def __del__(self): self._index_bin_buffer_mmap._mmap.close() # noqa del self._index_bin_buffer_mmap - # TODO Soham: get images def get( self, idx: int, diff --git a/fast_llm/data/preparator/gpt_memmap/config.py b/fast_llm/data/preparator/gpt_memmap/config.py index 89fe904c..38d90ed4 100644 --- a/fast_llm/data/preparator/gpt_memmap/config.py +++ b/fast_llm/data/preparator/gpt_memmap/config.py @@ -180,7 +180,6 @@ class GPTMemmapDatasetPreparatorConfig(DatasetPreparatorConfig): hint=FieldHint.optional, ) - # TODO Soham: move tokenizer validation to MultiModalDataProcessor def _validate(self) -> None: assert self.tokenizer.path is not None if self.dataset.data_type is not None: diff --git a/fast_llm/layers/vision_encoder/config.py b/fast_llm/layers/vision_encoder/config.py index 4dde28be..be3fb38c 100644 --- a/fast_llm/layers/vision_encoder/config.py +++ b/fast_llm/layers/vision_encoder/config.py @@ -144,6 +144,11 @@ class VisionEncoderConfig(BaseModelConfig): desc="Patch size for the image encoder.", hint=FieldHint.core, ) + conv_bias: bool = Field( + default=False, + desc="Whether to use bias in the convolutional layer.", + hint=FieldHint.optional, + ) patch_norm: 
NormalizationConfig = Field( default_factory=NormalizationConfig, desc="Configuration for the normalization layers applied to the image patches.", @@ -169,6 +174,7 @@ def setup_tensor_space(self, tensor_space: TensorSpace): tensor_space.add_tensor_dim(TensorDim(VisionEncoderDimNames.out_channels, self.transformer.hidden_size)) tensor_space.add_tensor_dim(TensorDim(VisionEncoderDimNames.adapter_size, self.adapter_size)) tensor_space.add_tensor_dim(TensorDim(VisionEncoderDimNames.patch_size, self.patch_size)) + tensor_space.add_tensor_dim(TensorDim(VisionEncoderDimNames.in_channels, 3)) # TODO Soham: add a check for presence of kv channels parameter (head_dim) tensor_space.add_tensor_dim( TensorDim( diff --git a/fast_llm/layers/vision_encoder/encoder.py b/fast_llm/layers/vision_encoder/encoder.py index ed6fbc92..59212c58 100644 --- a/fast_llm/layers/vision_encoder/encoder.py +++ b/fast_llm/layers/vision_encoder/encoder.py @@ -3,7 +3,7 @@ import torch from fast_llm.engine.base_model.base_model import Layer -from fast_llm.engine.config_utils.tensor_space import TensorDim, TensorSpace +from fast_llm.engine.config_utils.tensor_space import TensorSpace from fast_llm.layers.vision_encoder.config import VisionEncoderConfig, VisionEncoderDimNames, VisionEncoderKwargs from fast_llm.tensor import ParameterMeta, TensorMeta, init_normal_ @@ -38,21 +38,25 @@ def generate_block_attention_mask(patch_embeds_list, tensor): class PatchConv(Layer): def __init__(self, config: VisionEncoderConfig, tensor_space: TensorSpace): super().__init__() - # TODO Soham: device=meta - with torch.device("meta"): - self.conv = torch.nn.Conv2d( - in_channels=3, - out_channels=config.transformer.hidden_size, - kernel_size=config.patch_size, - stride=config.patch_size, - bias=False, - dtype=tensor_space.distributed_config.training_dtype.torch, - ) - self.conv.weight = ParameterMeta.from_dims( - tuple(TensorDim(f"patch_conv_weight_{idx}", size) for idx, size in enumerate(self.conv.weight.shape)), - init_method=init_normal_(), + self._tensor_space = tensor_space + # TODO Soham: lr_scale + self.weight = ParameterMeta.from_dims( + ( + self._tensor_space.get_tensor_dim(VisionEncoderDimNames.out_channels), + self._tensor_space.get_tensor_dim(VisionEncoderDimNames.in_channels), + self._tensor_space.get_tensor_dim(VisionEncoderDimNames.patch_size), + self._tensor_space.get_tensor_dim(VisionEncoderDimNames.patch_size), + ), + init_method=init_normal_(), + ) + if config.conv_bias: + self.bias = ParameterMeta.from_dims( + (self._tensor_space.get_tensor_dim(VisionEncoderDimNames.out_channels),) ) + else: + self.bias = None self.norm = config.patch_norm.get_layer(tensor_space.get_tensor_dim(VisionEncoderDimNames.out_channels)) + self.stride = config.patch_size def forward( self, @@ -64,10 +68,7 @@ def forward( hidden_dims = kwargs[VisionEncoderKwargs.hidden_dims] if isinstance(input_, TensorMeta): return TensorMeta.from_dims(hidden_dims, tensor_name="patch conv output", dtype=input_.dtype) - # we don't need images after this point - # image_patches = kwargs.pop(VisionEncoderKwargs.image_patches) - patch_embeddings = self.norm(self.conv(input_)) + input_ = torch.nn.functional.conv2d(input_, self.weight, self.bias, stride=self.stride) + patch_embeddings = self.norm(input_.flatten(1)) patch_embeddings = patch_embeddings.reshape(*(x.size for x in hidden_dims)) - # Hack to pass patch embeddings to the next layer - # kwargs[VisionEncoderKwargs.patch_embeddings] = patch_embeddings return patch_embeddings diff --git 
a/fast_llm/models/gpt/conversion.py b/fast_llm/models/gpt/conversion.py index 4b08d564..6aa3aaf1 100644 --- a/fast_llm/models/gpt/conversion.py +++ b/fast_llm/models/gpt/conversion.py @@ -728,7 +728,9 @@ def _create_vision_transformer_converters(self) -> list[WeightConverter]: return vision_transformer_converters def _create_vision_encoder_weight_converters(self) -> list[WeightConverter]: - patch_conv_converter = WeightConverter("layers.0.conv.weight", "vision_tower.patch_conv.weight") + patch_conv_converters = [WeightConverter("layers.0.weight", "vision_tower.patch_conv.weight")] + if self._model.config.base_model.vision_encoder.conv_bias: + patch_conv_converters.append(WeightConverter("layers.0.bias", "vision_tower.patch_conv.bias")) layernorm_converters = [ WeightConverter("layers.0.norm.weight", "vision_tower.ln_pre.weight"), ] @@ -745,7 +747,7 @@ def _create_vision_encoder_weight_converters(self) -> list[WeightConverter]: WeightConverter(f"layers.{offset}.layer_2.bias", "multi_modal_projector.linear_2.bias"), ] - return [patch_conv_converter] + layernorm_converters + vision_transformer_converters + adapter_converters + return patch_conv_converters + layernorm_converters + vision_transformer_converters + adapter_converters def _create_weight_converters(self) -> list[WeightConverter]: vision_encoder_converter = self._create_vision_encoder_weight_converters() diff --git a/setup.cfg b/setup.cfg index 3b5eea40..25f8af8b 100644 --- a/setup.cfg +++ b/setup.cfg @@ -45,7 +45,6 @@ OPTIONAL = requests>=2.32.3 tqdm>=4.66.3 # Vision Tools - # TODO Soham: use pillow-simd instead of pillow? webp>=0.4.0 pillow-simd>=9.5.0 torchvision>=0.20.0 From b3134aade1428641c47ea831d50701a43ee222ca Mon Sep 17 00:00:00 2001 From: root Date: Mon, 12 May 2025 19:35:17 +0000 Subject: [PATCH 27/82] more conv cleanup --- fast_llm/engine/multi_stage/stage_base.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/fast_llm/engine/multi_stage/stage_base.py b/fast_llm/engine/multi_stage/stage_base.py index 4d9cd848..fd50f55c 100644 --- a/fast_llm/engine/multi_stage/stage_base.py +++ b/fast_llm/engine/multi_stage/stage_base.py @@ -162,9 +162,6 @@ def _replace(module: torch.nn.Module): nonlocal i for key in module._parameters: meta = typing.cast(ParameterMeta, module._parameters[key]) - # TODO Soham: clean way to get around check? 
- if meta is None: - continue module._parameters[key] = self.get_parameter_buffer(meta.tensor_name) i += 1 From c8aa66ec3793e222e0412afa0b142869f513e431 Mon Sep 17 00:00:00 2001 From: root Date: Tue, 13 May 2025 06:08:16 +0000 Subject: [PATCH 28/82] images + loss-masks --- fast_llm/data/dataset/gpt/memmap.py | 94 +++++++++++++------ .../data/preparator/gpt_memmap/prepare.py | 9 +- fast_llm/data/tokenizer.py | 81 ++++------------ 3 files changed, 92 insertions(+), 92 deletions(-) diff --git a/fast_llm/data/dataset/gpt/memmap.py b/fast_llm/data/dataset/gpt/memmap.py index 5d3df598..73fb3903 100644 --- a/fast_llm/data/dataset/gpt/memmap.py +++ b/fast_llm/data/dataset/gpt/memmap.py @@ -10,6 +10,7 @@ from fast_llm.data.dataset.gpt.sampled import GPTSample from fast_llm.data.preparator.gpt_memmap.config import MEMMAP_DTYPES, MEMMAP_DTYPES_INV, MEMMAP_INDEX_HEADER from fast_llm.engine.config_utils.data_type import DataType +from fast_llm.layers.vision_encoder.preprocessing import get_num_patches, get_resize_dims from fast_llm.utils import Assert, div @@ -114,7 +115,6 @@ def _init( self._image_lengths = [] self._image_positions = [] images_seen = 0 - # TODO Soham: verify correctness, reshaping into width, height? for n_images in self._n_images: self._image_lengths.append( np.frombuffer( @@ -141,8 +141,6 @@ def _init( self._bin_buffer_mmap = np.memmap(self._prefix.with_suffix(".bin"), mode="r", order="C") self._bin_buffer = memoryview(self._bin_buffer_mmap) - # TODO Soham: fix num_tokens to include images. Get total number of image pixels from index file and assign - # self._num_tokens = div(self._bin_buffer_mmap.size - n_pixels, np.dtype(self._dtype).itemsize) self._num_tokens = div(self._bin_buffer_mmap.size - self._num_pixels, np.dtype(self._dtype).itemsize) if num_pixels is not None: assert self._num_pixels == num_pixels @@ -163,21 +161,54 @@ def __del__(self): self._index_bin_buffer_mmap._mmap.close() # noqa del self._index_bin_buffer_mmap + # def get( + # self, + # idx: int, + # offset: int = 0, + # image_offset: int = 0, + # length: int | None = None, + # use_loss_masking_spans: bool = False, + # ): + # token_ids = np.frombuffer( + # self._bin_buffer, + # dtype=self._dtype, + # count=self._document_sizes[idx] - offset if length is None else length, + # offset=self._pointers[idx] + offset * np.dtype(self._dtype).itemsize, + # ) + # if self._has_images: + # image_positions = self._image_positions[idx] + # pixels = np.frombuffer( + # self._bin_buffer, + # dtype=np.dtype(np.uint8), + # count=self._image_lengths[idx].prod(initial=3), + # offset=self._pointers[idx] + self._document_sizes[idx] * np.dtype(self._dtype).itemsize, + # ) + # images = [] + # start = 0 + # for image_length in self._image_lengths[idx]: + # n_pixels = image_length.prod(initial=3) + # images.append(pixels[start : start + n_pixels].reshape(3, image_length[0], image_length[1])) + # start += n_pixels + # return GPTSample(token_ids=token_ids, images=images, image_positions=image_positions) + def get( self, idx: int, offset: int = 0, length: int | None = None, use_loss_masking_spans: bool = False, - ): - # TODO Soham: handle spans + patch_size: int | None = None, + image_size: int | None = None, + ) -> GPTSample: token_ids = np.frombuffer( self._bin_buffer, dtype=self._dtype, count=self._document_sizes[idx] - offset if length is None else length, offset=self._pointers[idx] + offset * np.dtype(self._dtype).itemsize, ) + images = None if self._has_images: + # Truncations with images are not yet supported image_positions = 
self._image_positions[idx] pixels = np.frombuffer( self._bin_buffer, @@ -188,32 +219,39 @@ def get( images = [] start = 0 for image_length in self._image_lengths[idx]: - # TODO Soham: verify reshape dimension order n_pixels = image_length.prod(initial=3) images.append(pixels[start : start + n_pixels].reshape(3, image_length[0], image_length[1])) start += n_pixels - # TODO Soham: return loss_masking_spans - return GPTSample(token_ids=token_ids, images=images, image_positions=image_positions) - - # def get( - # self, idx: int, offset: int = 0, length: int | None = None, use_loss_masking_spans: bool = False - # ) -> GPTSample: - # token_ids = np.frombuffer( - # self._bin_buffer, - # dtype=self._dtype, - # count=self._document_sizes[idx] - offset if length is None else length, - # offset=self._pointers[idx] + offset * np.dtype(self._dtype).itemsize, - # ) - # sample_spans = None - # if use_loss_masking_spans and self._spans is not None: - # sample_spans = self._spans[idx] - # # adjust the spans for the offset and length - # sample_spans = sample_spans[ - # (sample_spans[:, 0] < offset + len(token_ids)) & (sample_spans[:, 1] >= offset) - # ] - # sample_spans[:, 0] = np.maximum(sample_spans[:, 0], offset) - offset - # sample_spans[:, 1] = np.minimum(sample_spans[:, 1], offset + len(token_ids) - 1) - offset - # return GPTSample(token_ids=token_ids, loss_masking_spans=sample_spans) + sample_spans = None + if use_loss_masking_spans and self._spans is not None: + sample_spans = self._spans[idx] + sample_spans = sample_spans[ + (sample_spans[:, 0] < offset + len(token_ids)) & (sample_spans[:, 1] >= offset) + ] + sample_spans[:, 0] = np.maximum(sample_spans[:, 0], offset) - offset + sample_spans[:, 1] = np.minimum(sample_spans[:, 1], offset + len(token_ids) - 1) - offset + if images: + image_idx = 0 + for span in sample_spans: + additional_tokens = 0 + image_position = image_positions[image_idx] if image_idx < len(image_positions) else float("inf") + while image_position >= span[0] and image_position <= span[1]: + image_tokens = get_num_patches( + get_resize_dims(*self._image_lengths[idx][image_idx], image_size, image_size, patch_size), + patch_size, + ) + additional_tokens += image_tokens + image_idx += 1 + image_position = ( + image_positions[image_idx] if image_idx < len(image_positions) else float("inf") + ) + span[1] += additional_tokens + return GPTSample( + token_ids=token_ids, + images=images, + image_positions=image_positions, + loss_masking_spans=sample_spans, + ) @property def name(self) -> str: diff --git a/fast_llm/data/preparator/gpt_memmap/prepare.py b/fast_llm/data/preparator/gpt_memmap/prepare.py index 2a3778df..b6d81773 100644 --- a/fast_llm/data/preparator/gpt_memmap/prepare.py +++ b/fast_llm/data/preparator/gpt_memmap/prepare.py @@ -50,21 +50,24 @@ def _tokenize_batch(self, batch: dict[str, list[typing.Any]]) -> dict[str, list[ # np.array(self._tokenizer.tokenize(text), dtype=self._data_type.numpy) # for text in batch[self._config.dataset.field] # ] - input_ids, image_token_positions = map( + input_ids, token_spans, image_token_positions = map( list, zip( *[ ( np.array(input_ids, dtype=self._data_type.numpy), + np.array(token_spans, dtype=np.int32).reshape(-1, 2), np.array(image_token_positions, dtype=np.int32), ) - for input_ids, image_token_positions in [ + for input_ids, token_spans, image_token_positions in [ self._tokenizer.tokenize( text, + loss_mask_spans, im_char_positions, ) - for text, im_char_positions in zip( + for text, loss_mask_spans, im_char_positions in zip( 
batch[self._config.dataset.field], + batch.get(self._config.dataset.loss_masking_spans, itertools.repeat(None)), batch.get(self._config.dataset.image_positions, itertools.repeat(None)), ) ] diff --git a/fast_llm/data/tokenizer.py b/fast_llm/data/tokenizer.py index 10b8b2c6..c44715d8 100644 --- a/fast_llm/data/tokenizer.py +++ b/fast_llm/data/tokenizer.py @@ -42,7 +42,7 @@ def _tokenize(self, text: str, begin=True, end=True) -> list[int]: + ([self.eod_id] if end else []) ) - def tokenize(self, text: str, image_positions=None, char_spans=None) -> tuple[list[int], list[tuple[int, int]]]: + def tokenize(self, text: str, char_spans=None, image_positions=None) -> tuple[list[int], list[tuple[int, int]]]: """ Tokenize the input text and return the tokenized input_ids along with token spans. """ @@ -57,14 +57,15 @@ def tokenize(self, text: str, image_positions=None, char_spans=None) -> tuple[li char_pos = 0 token_ids = [] image_token_positions = [] + token_spans = [] beginning_of_text = True + image_position = image_positions[image_idx] if image_idx < len(image_positions) else float("inf") for start, end in char_spans: - image_position = image_positions[image_idx] if image_idx < len(image_positions) else float("inf") while image_position <= start: tokenized_text = self._tokenize(text[char_pos:image_position], begin=beginning_of_text, end=False) beginning_of_text = False - image_token_positions.append(len(token_ids)) token_ids.extend(tokenized_text) + image_token_positions.append(len(token_ids)) image_idx += 1 char_pos = image_position image_position = image_positions[image_idx] if image_idx < len(image_positions) else float("inf") @@ -75,11 +76,12 @@ def tokenize(self, text: str, image_positions=None, char_spans=None) -> tuple[li char_pos = start len(token_ids) span_length = 0 + token_start = len(token_ids) while image_position <= end: tokenized_text = self._tokenize(text[char_pos:image_position], begin=beginning_of_text, end=False) beginning_of_text = False - image_token_positions.append(len(token_ids)) token_ids.extend(tokenized_text) + image_token_positions.append(len(token_ids)) span_length += len(tokenized_text) char_pos = image_position image_idx += 1 @@ -96,65 +98,22 @@ def tokenize(self, text: str, image_positions=None, char_spans=None) -> tuple[li beginning_of_text = False token_ids.extend(tokenized_text) span_length += len(tokenized_text) + char_pos = end + 1 + token_spans.append((token_start, token_start + span_length - 1)) - # def tokenize(self, text, image_positions=None): - # if not image_positions: - # return self._tokenize(text), [], [] - # image_idx = 0 - # char_pos = 0 - # token_ids = [] - # image_token_positions = [] - # beginning_of_text = True - # while image_idx < len(image_positions): - # if image_positions[image_idx] > len(text): - # raise ValueError( - # f"Image position {image_positions[image_idx]} is greater than text length {len(text)}" - # ) - # curr_text = text[char_pos : image_positions[image_idx]] - # tokenized_text = self._tokenize( - # curr_text, begin=beginning_of_text, end=image_positions[image_idx] >= len(text) - # ) - # beginning_of_text = False - # token_ids.extend(tokenized_text) - # image_token_positions = len(token_ids) - # char_pos = image_positions[image_idx] - # image_idx += 1 - # if char_pos < len(text): - # curr_text = text[char_pos:] - # tokenized_text = self._tokenize(curr_text, begin=beginning_of_text, end=True) - # token_ids.extend(tokenized_text) - # return token_ids, image_token_positions + while image_position <= len(text): + image_position = 
image_positions[image_idx] + tokenized_text = self._tokenize(text[char_pos:image_position], begin=beginning_of_text, end=False) + beginning_of_text = False + token_ids.extend(tokenized_text) + image_token_positions.append(len(token_ids)) + char_pos = image_position + image_idx += 1 + image_position = image_positions[image_idx] if image_idx < len(image_positions) else float("inf") + tokenized_text = self._tokenize(text[char_pos:], begin=beginning_of_text, end=True) + token_ids.extend(tokenized_text) - # def tokenize_with_spans( - # self, text: str, char_spans: list[tuple[int, int]] - # ) -> tuple[list[int], list[tuple[int, int]]]: - # """ - # Perform span-aware tokenization and return the tokenized input_ids along with token spans. - # """ - # input_ids = [] - # token_spans = [] - # char_pos = 0 - # beginning_of_text = True - # for start, end in char_spans: - # if char_pos < start: - # curr_text = text[char_pos:start] - # tokenized_text = self.tokenize(curr_text, begin=beginning_of_text, end=False) - # beginning_of_text = False - # input_ids.extend(tokenized_text) - # curr_text = text[start : end + 1] - # if end >= len(text) - 1: - # tokenized_text = self.tokenize(curr_text, begin=beginning_of_text, end=True) - # else: - # tokenized_text = self.tokenize(curr_text, begin=beginning_of_text, end=False) - # beginning_of_text = False - # token_spans.append((len(input_ids), len(input_ids) + len(tokenized_text) - 1)) - # input_ids.extend(tokenized_text) - # char_pos = end + 1 - # if char_pos < len(text): - # curr_text = text[char_pos:] - # tokenized_text = self.tokenize(curr_text, begin=beginning_of_text, end=True) - # input_ids.extend(tokenized_text) - # return input_ids, token_spans + return token_ids, token_spans, image_token_positions def detokenize(self, token_ids: int | list[int] | np.ndarray | torch.Tensor) -> str: return self.tokenizer.decode(token_ids) From 0baae59dc9c4d7401a98b253b03fb41323219910 Mon Sep 17 00:00:00 2001 From: root Date: Tue, 13 May 2025 06:21:39 +0000 Subject: [PATCH 29/82] minor fixes --- fast_llm/data/dataset/gpt/indexed.py | 4 ++-- fast_llm/data/dataset/gpt/memmap.py | 8 +------- fast_llm/data/dataset/gpt/sampled.py | 6 ++---- 3 files changed, 5 insertions(+), 13 deletions(-) diff --git a/fast_llm/data/dataset/gpt/indexed.py b/fast_llm/data/dataset/gpt/indexed.py index 209c6e31..f8260413 100644 --- a/fast_llm/data/dataset/gpt/indexed.py +++ b/fast_llm/data/dataset/gpt/indexed.py @@ -48,8 +48,8 @@ def get_document_sizes(self) -> np.ndarray: doc_sizes, im_sizes = self._dataset.get_document_sizes() return doc_sizes[self._begin : self._end], im_sizes[self._begin : self._end] - def get_document_size(self, index: int, patch_size: list[int]) -> int: - return self._dataset.get_document_size(self._begin + index, patch_size) + def get_document_size(self, index: int) -> int: + return self._dataset.get_document_size(self._begin + index) @property def has_images(self) -> bool: diff --git a/fast_llm/data/dataset/gpt/memmap.py b/fast_llm/data/dataset/gpt/memmap.py index 73fb3903..af632d5b 100644 --- a/fast_llm/data/dataset/gpt/memmap.py +++ b/fast_llm/data/dataset/gpt/memmap.py @@ -268,7 +268,6 @@ def num_tokens(self) -> int: def has_images(self) -> bool: return self._has_images - # TODO: image sizes def get_document_sizes(self) -> tuple[np.ndarray, np.ndarray]: """ The size of each document in the dataset. 
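For reference, PATCH 28's combined `tokenize()` (further above) returns token ids, inclusive loss-masking token spans, and image token positions in a single pass. A usage sketch with made-up text and character offsets, assuming `tokenizer` is an instantiated `Tokenizer`; only the call shape and return contract come from the patch:

text = "A caption <image> and the answer to mask."
image_positions = [10]            # character offset where the image is spliced in
loss_masking_spans = [(18, 40)]   # inclusive character span to exclude from the loss

token_ids, token_spans, image_token_positions = tokenizer.tokenize(
    text, loss_masking_spans, image_positions
)
# token_spans: inclusive (start, end) indices into token_ids
# image_token_positions: indices in token_ids where image patch tokens are inserted later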
@@ -277,12 +276,7 @@ def get_document_sizes(self) -> tuple[np.ndarray, np.ndarray]: """ return self._document_sizes, self._image_lengths - def get_document_size(self, index: int, patch_size: list[int]) -> int: - # return self._document_sizes[index].item() + ( - # sum((h // patch_size[0]) * (w // patch_size[1]) for h, w in self._image_lengths[index]) - # if self._has_images - # else 0 - # ) + def get_document_size(self, index: int) -> int: return self._document_sizes[index].item(), self._image_lengths[index] if self._has_images else [] @classmethod diff --git a/fast_llm/data/dataset/gpt/sampled.py b/fast_llm/data/dataset/gpt/sampled.py index f99a9d3e..2a1df443 100644 --- a/fast_llm/data/dataset/gpt/sampled.py +++ b/fast_llm/data/dataset/gpt/sampled.py @@ -407,9 +407,7 @@ def __getitem__(self, index: int) -> typing.Any: else: document_index = self._document_shuffling[document_sampling_index - self._unshuffled_documents].item() - document_size, image_lengths = self._indexed_dataset.get_document_size( - document_index, self._parameters.patch_size - ) + document_size, image_lengths = self._indexed_dataset.get_document_size(document_index) image_sizes = [ get_num_patches( @@ -582,7 +580,7 @@ def _sample(self) -> None: Create a `GPTSampledDataset` with the requested parameters. """ logger.info(f" > Sampling dataset {self._indexed_dataset.name} ...") - document_sizes = self._indexed_dataset.get_document_sizes() + document_sizes, _ = self._indexed_dataset.get_document_sizes() num_documents = len(document_sizes) num_tokens = document_sizes.sum() np_rng = np.random.RandomState(seed=self._config.seed) From 48855be3c9413298a38af9a94ee25eb56167815f Mon Sep 17 00:00:00 2001 From: root Date: Tue, 13 May 2025 06:30:55 +0000 Subject: [PATCH 30/82] cleanup --- fast_llm/data/dataset/gpt/sampled.py | 18 +++++++++++++++--- 1 file changed, 15 insertions(+), 3 deletions(-) diff --git a/fast_llm/data/dataset/gpt/sampled.py b/fast_llm/data/dataset/gpt/sampled.py index 2a1df443..a8ad574c 100644 --- a/fast_llm/data/dataset/gpt/sampled.py +++ b/fast_llm/data/dataset/gpt/sampled.py @@ -135,12 +135,24 @@ def _sample(self) -> None: # TODO Soham: verify numpy correctness document_sizes, image_sizes = self._indexed_dataset.get_document_sizes() document_sizes = torch.from_numpy(document_sizes).to(self._device) - image_token_sizes = torch.zeros_like(document_sizes).to(self._device) + image_token_sizes = [] # TODO Soham: handle max image size for i, sizes in enumerate(image_sizes): - image_token_sizes[i] = sum( - (sizes[:, 0] // self._parameters.patch_size) * (sizes[:, 1] // self._parameters.patch_size) + image_token_sizes.append( + sum( + get_num_patches( + *get_resize_dims( + *size, + self._parameters.image_size, + self._parameters.image_size, + self._parameters.patch_size, + ), + self._parameters.patch_size, + ) + for size in sizes + ) ) + image_token_sizes = image_token_sizes.to(self._device) documents_per_epoch = document_sizes.numel() tokens_per_epoch = document_sizes.sum().item() + image_token_sizes.sum().item() From f35e003d82b05e4787bc791928e1955262d4ba6a Mon Sep 17 00:00:00 2001 From: root Date: Tue, 13 May 2025 06:34:37 +0000 Subject: [PATCH 31/82] cleanup --- fast_llm/data/dataset/gpt/sampled.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/fast_llm/data/dataset/gpt/sampled.py b/fast_llm/data/dataset/gpt/sampled.py index a8ad574c..ce92d1c1 100644 --- a/fast_llm/data/dataset/gpt/sampled.py +++ b/fast_llm/data/dataset/gpt/sampled.py @@ -434,14 +434,15 @@ def __getitem__(self, index: int) 
-> typing.Any: for image_length in image_lengths ] image_tokens = sum(image_sizes) + document_size += image_tokens if not self._truncate_documents: - if document_size + image_tokens > self._parameters.sequence_length + 1: + if document_size > self._parameters.sequence_length + 1: # Document too long, ignore document_sampling_index += 1 continue tokens_in_sample = token_count % (self._parameters.sequence_length + 1) - if document_size + image_tokens + tokens_in_sample > self._parameters.sequence_length + 1: + if document_size + tokens_in_sample > self._parameters.sequence_length + 1: # Document belongs to the next sample, need to account for padding. padding_size = self._parameters.sequence_length + 1 - tokens_in_sample if token_count > token_start: @@ -454,7 +455,7 @@ def __getitem__(self, index: int) -> typing.Any: token_count += padding_size # Determine if the document belongs to the requested sample. - if token_count + document_size + image_tokens >= token_start: + if token_count + document_size >= token_start: # Determine which part of the document belong to the sample, and add it to the list. token_start_index_in_document = max(token_start - token_count, 0) token_end_index_in_document = min(token_end - token_count, document_size) @@ -488,7 +489,7 @@ def __getitem__(self, index: int) -> typing.Any: # Go to the next document. document_sampling_index += 1 - token_count += document_size + image_tokens + token_count += document_size sequence_lengths = ( np.array([ids.size - (idx == len(token_ids) - 1) for idx, ids in enumerate(token_ids)], dtype=np.int32) From 4eb34cb0c4a4be901d079aaf0997e048035dbce6 Mon Sep 17 00:00:00 2001 From: root Date: Tue, 13 May 2025 06:41:02 +0000 Subject: [PATCH 32/82] cleanup --- fast_llm/data/dataset/gpt/sampled.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/fast_llm/data/dataset/gpt/sampled.py b/fast_llm/data/dataset/gpt/sampled.py index ce92d1c1..01459fa0 100644 --- a/fast_llm/data/dataset/gpt/sampled.py +++ b/fast_llm/data/dataset/gpt/sampled.py @@ -96,7 +96,7 @@ def __init__( # TODO Soham: use something else for this check, introducing has_images for just this check might be unnecessary. if self._indexed_dataset.has_images and self._truncate_documents: raise RuntimeError( - "Truncating documents with images is not supported. Please turn off truncation to use images." + "Truncating documents with images is not yet supported. Please turn off truncation to use images." ) if sampling.cache_directory is None: @@ -132,11 +132,9 @@ def _sample(self) -> None: Create a `GPTSampledDataset` with the requested parameters. """ # Get the document sizes, the main information needed for sampling. 
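The `_sample` hunk just below switches the per-document image token count to a resize-then-patch computation. A rough sketch of that arithmetic; the rounding behaviour is an assumption based on the vision preprocessing elsewhere in the series, not a copy of `get_resize_dims`/`get_num_patches`:

import math

def resize_dims(height: int, width: int, max_height: int, max_width: int) -> tuple[int, int]:
    # Downscale only, preserving aspect ratio, so the image fits within the maximum size.
    ratio = max(height / max_height, width / max_width)
    if ratio > 1:
        height, width = math.ceil(height / ratio), math.ceil(width / ratio)
    return height, width

def num_patches(height: int, width: int, patch_size: int) -> int:
    return math.ceil(height / patch_size) * math.ceil(width / patch_size)

# e.g. a 512 x 384 image with image_size=1024 and patch_size=16 -> 32 * 24 = 768 image tokens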
- # TODO Soham: verify numpy correctness document_sizes, image_sizes = self._indexed_dataset.get_document_sizes() document_sizes = torch.from_numpy(document_sizes).to(self._device) image_token_sizes = [] - # TODO Soham: handle max image size for i, sizes in enumerate(image_sizes): image_token_sizes.append( sum( @@ -476,7 +474,6 @@ def __getitem__(self, index: int) -> typing.Any: start_pos = im_position token_ids.append(sample.token_ids[start_pos:]) images.append(sample.images) - # TODO Soham: add offsets for loss masking spans if self._parameters.use_loss_masking_spans: for loss_masking_span in sample.loss_masking_spans: span = np.clip( From ebb9e276a3b97b3571e26c346a986be67d8e87cc Mon Sep 17 00:00:00 2001 From: root Date: Tue, 13 May 2025 15:13:09 +0000 Subject: [PATCH 33/82] cleanup --- fast_llm/data/dataset/gpt/indexed.py | 1 - .../layers/transformer/vision_transformer.py | 16 ---------------- 2 files changed, 17 deletions(-) diff --git a/fast_llm/data/dataset/gpt/indexed.py b/fast_llm/data/dataset/gpt/indexed.py index f8260413..6e9bef96 100644 --- a/fast_llm/data/dataset/gpt/indexed.py +++ b/fast_llm/data/dataset/gpt/indexed.py @@ -11,7 +11,6 @@ class GPTIndexedDataset(IndexedDataset): - # TODO Soham: should we change this to include images? @abc.abstractmethod def get_document_sizes(self) -> np.ndarray: """ diff --git a/fast_llm/layers/transformer/vision_transformer.py b/fast_llm/layers/transformer/vision_transformer.py index 94a9c70a..3588956c 100644 --- a/fast_llm/layers/transformer/vision_transformer.py +++ b/fast_llm/layers/transformer/vision_transformer.py @@ -37,19 +37,3 @@ def _get_meta(self, tensor: torch.Tensor, name: str, kwargs: dict): if self._return_input: dims = (TensorDim("stacked_input_output", 2),) + dims return TensorMeta.from_dims(dims, tensor_name=f"{self.name} {name}", dtype=tensor.dtype) - - # TODO Soham: remove this since we only need to call the parent method - # def forward( - # self, - # input_: torch.Tensor, - # kwargs: dict[str, typing.Any], - # losses: dict[str, typing.Any] | None = None, - # metrics: dict[str, typing.Any] | None = None, - # ) -> torch.Tensor: - # if isinstance(input_, TensorMeta): - # return self._get_meta(input_, "output", kwargs) - # # Hack for now to compute the patch embeddings - # kwargs[VisionTransformerKwargs.patch_embeddings] = super().forward( - # kwargs.pop(VisionTransformerKwargs.patch_embeddings), kwargs, losses, metrics - # ) - # return input_ From 51098ef106b72a0528c71558aba1405993d96aa0 Mon Sep 17 00:00:00 2001 From: root Date: Tue, 13 May 2025 15:45:07 +0000 Subject: [PATCH 34/82] fix --- fast_llm/data/dataset/gpt/sampled.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/fast_llm/data/dataset/gpt/sampled.py b/fast_llm/data/dataset/gpt/sampled.py index 01459fa0..fc2ddb6a 100644 --- a/fast_llm/data/dataset/gpt/sampled.py +++ b/fast_llm/data/dataset/gpt/sampled.py @@ -150,7 +150,7 @@ def _sample(self) -> None: for size in sizes ) ) - image_token_sizes = image_token_sizes.to(self._device) + image_token_sizes = torch.tensor(image_token_sizes).to(self._device) documents_per_epoch = document_sizes.numel() tokens_per_epoch = document_sizes.sum().item() + image_token_sizes.sum().item() @@ -417,7 +417,7 @@ def __getitem__(self, index: int) -> typing.Any: else: document_index = self._document_shuffling[document_sampling_index - self._unshuffled_documents].item() - document_size, image_lengths = self._indexed_dataset.get_document_size(document_index) + text_size, image_lengths = 
self._indexed_dataset.get_document_size(document_index) image_sizes = [ get_num_patches( @@ -432,7 +432,7 @@ def __getitem__(self, index: int) -> typing.Any: for image_length in image_lengths ] image_tokens = sum(image_sizes) - document_size += image_tokens + document_size = text_size + image_tokens if not self._truncate_documents: if document_size > self._parameters.sequence_length + 1: @@ -456,7 +456,7 @@ def __getitem__(self, index: int) -> typing.Any: if token_count + document_size >= token_start: # Determine which part of the document belong to the sample, and add it to the list. token_start_index_in_document = max(token_start - token_count, 0) - token_end_index_in_document = min(token_end - token_count, document_size) + token_end_index_in_document = min(token_end - token_count, text_size) sample = self._indexed_dataset.get( document_index, offset=token_start_index_in_document, From 60b87fa766a77a183a4aa998ae914a2d22b1e195 Mon Sep 17 00:00:00 2001 From: root Date: Tue, 13 May 2025 16:39:46 +0000 Subject: [PATCH 35/82] prepare cleanup --- fast_llm/data/dataset/gpt/memmap.py | 2 + fast_llm/data/dataset/gpt/sampled.py | 1 - fast_llm/data/preparator/gpt_memmap/config.py | 5 +++ .../data/preparator/gpt_memmap/prepare.py | 44 ++++++++++--------- fast_llm/data/tokenizer.py | 2 - 5 files changed, 31 insertions(+), 23 deletions(-) diff --git a/fast_llm/data/dataset/gpt/memmap.py b/fast_llm/data/dataset/gpt/memmap.py index af632d5b..e1297b14 100644 --- a/fast_llm/data/dataset/gpt/memmap.py +++ b/fast_llm/data/dataset/gpt/memmap.py @@ -108,6 +108,8 @@ def _init( + sum([x.nbytes for x in self._spans]) ) self._num_pixels = 0 + self._image_lengths = None + self._image_positions = None if self._has_images and self._version >= 4: self._n_images = np.frombuffer( self._index_bin_buffer, dtype=np.int32, count=self._num_documents, offset=offset diff --git a/fast_llm/data/dataset/gpt/sampled.py b/fast_llm/data/dataset/gpt/sampled.py index fc2ddb6a..91f8ca8f 100644 --- a/fast_llm/data/dataset/gpt/sampled.py +++ b/fast_llm/data/dataset/gpt/sampled.py @@ -93,7 +93,6 @@ def __init__( self._truncate_documents = sampling.truncate_documents self._device = torch.device("cuda" if self._config.gpu else "cpu") - # TODO Soham: use something else for this check, introducing has_images for just this check might be unnecessary. if self._indexed_dataset.has_images and self._truncate_documents: raise RuntimeError( "Truncating documents with images is not yet supported. Please turn off truncation to use images." diff --git a/fast_llm/data/preparator/gpt_memmap/config.py b/fast_llm/data/preparator/gpt_memmap/config.py index 38d90ed4..53f8e468 100644 --- a/fast_llm/data/preparator/gpt_memmap/config.py +++ b/fast_llm/data/preparator/gpt_memmap/config.py @@ -173,6 +173,11 @@ class GPTMemmapDatasetPreparatorConfig(DatasetPreparatorConfig): desc="Tokenizer configuration.", hint=FieldHint.feature, ) + image_patch_size: int = Field( + default=16, + desc="Patch size for images. This is used solely for computing the number of tokens in an image to get an even split.", + hint=FieldHint.optional, + ) splits: dict[str, float] | None = Field( default=None, desc="Split the output dataset into multiple ones (ex, train/valid/test) with the specified ratios." 
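A minimal sketch of how `image_patch_size` is meant to be used when splitting: image contributions are approximated from raw pixel counts rather than exact per-image patch counts, which is sufficient for choosing split boundaries. Names and shapes below are assumptions for illustration, not the exact prepare-time code.

import numpy as np

def approx_image_tokens(image_sizes: list[np.ndarray], patch_size: int = 16) -> np.ndarray:
    # image_sizes: one (num_images, 2) array of (height, width) pairs per document.
    # Tokens are approximated as per-channel pixels divided by the patch area; the
    # exact count would use ceil(h / patch_size) * ceil(w / patch_size) per image.
    pixels_per_document = np.array([sizes.prod(axis=1).sum() for sizes in image_sizes])
    return pixels_per_document // (patch_size**2)

# Example: one document with a 512x1024 image and a 48x48 image.
print(approx_image_tokens([np.array([[512, 1024], [48, 48]])]))  # [2057]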
diff --git a/fast_llm/data/preparator/gpt_memmap/prepare.py b/fast_llm/data/preparator/gpt_memmap/prepare.py index b6d81773..c5a1b339 100644 --- a/fast_llm/data/preparator/gpt_memmap/prepare.py +++ b/fast_llm/data/preparator/gpt_memmap/prepare.py @@ -44,12 +44,7 @@ class GPTMemmapDatasetPreparator[ConfigType: GPTMemmapDatasetPreparatorConfig](D def _process_batch(self, batch: dict[str, list[typing.Any]]) -> dict[str, list[typing.Any]]: pass - # TODO Soham: can we merged tokenize_batch and tokenize_batch_with_spans? def _tokenize_batch(self, batch: dict[str, list[typing.Any]]) -> dict[str, list[typing.Any]]: - # input_ids = [ - # np.array(self._tokenizer.tokenize(text), dtype=self._data_type.numpy) - # for text in batch[self._config.dataset.field] - # ] input_ids, token_spans, image_token_positions = map( list, zip( @@ -85,6 +80,7 @@ def _tokenize_batch(self, batch: dict[str, list[typing.Any]]) -> dict[str, list[ return { "input_ids": input_ids, "image_positions": image_token_positions, + "token_spans": token_spans, "num_tokens": num_tokens, "num_pixels": num_pixels, } @@ -282,12 +278,7 @@ def run(self) -> None: ) if self._config.dataset.field not in dataset.column_names: raise ValueError(f"Dataset does not have field '{self._config.dataset.field}'.") - if self._config.dataset.loss_masking_spans is not None: - if self._config.dataset.loss_masking_spans not in dataset.column_names: - raise ValueError(f"Dataset does not have spans field '{self._config.dataset.loss_masking_spans}'.") - tokenize_fn = self._tokenize_batch_with_spans - else: - tokenize_fn = self._tokenize_batch + tokenize_fn = self._tokenize_batch # Avoid decoding bytes to images unless asked if self._config.dataset.images is not None: dataset = dataset.cast_column("images", datasets.Sequence(datasets.Image(decode=False))) @@ -336,7 +327,7 @@ def generate_config_yaml_for_sharded_dst(self, dataset_configs: list[GPTMemmapDa # Create the config file(s) on rank 0 if self._config.splits: for split_name, split_config in self._split_and_blend_dataset_configs( - dataset_configs, self._config.splits, self._config.output_path + dataset_configs, self._config.splits, self._config.output_path, self._config.image_patch_size ).items(): self._save_dataset_config( split_config, self._config.output_path / f"fast_llm_config_{split_name}.yaml" @@ -376,7 +367,11 @@ def _blend_dataset_configs(cls, dataset_configs: list[GPTMemmapDatasetConfig]) - @classmethod def _split_and_blend_dataset_configs( - cls, dataset_configs: list[GPTMemmapDatasetConfig], splits: dict[str, int | float], output_path: pathlib.Path + cls, + dataset_configs: list[GPTMemmapDatasetConfig], + splits: dict[str, int | float], + output_path: pathlib.Path, + image_patch_size: int, ) -> dict[str, GPTSampledDatasetConfig]: split_cumsum = padded_cumsum(normalize_probabilities(list(splits.values()), return_array=True)).tolist() dataset_sizes = [dataset_config.num_tokens for dataset_config in dataset_configs] @@ -406,11 +401,20 @@ def _split_and_blend_dataset_configs( # Part of the dataset belongs to the split. # TODO: Somehow getting a segfault when merging two lines below (numpy bug?). dataset = dataset_config.to_copy({"path": output_path / dataset_config.path}).build() - # TODO Soham: handle pixels (could still work with number of tokens?) 
- sizes_cumsum = dataset.get_document_sizes()[0].cumsum() - Assert.eq(sizes_cumsum[-1], dataset_config.num_tokens) - begin_index = _get_nearest_split(sizes_cumsum, split_begin_in_dataset * dataset_config.num_tokens) - end_index = _get_nearest_split(sizes_cumsum, split_end_in_dataset * dataset_config.num_tokens) + text_sizes, image_sizes = dataset.get_document_sizes() + tokens_cumsum = text_sizes.cumsum() + Assert.eq(tokens_cumsum[-1], dataset_config.num_tokens) + if image_sizes: + num_pixels_cumsum = np.cumsum([x.prod(axis=1).sum() for x in image_sizes]) + # We use the patch sizes only for the purposes of even splitting and blending weights. + # We can always use a different patch size for training without any significant impact + # Unless the patch size used at training time is significantly different from the one used here + image_tokens_cumsum = num_pixels_cumsum // (image_patch_size**2) + tokens_cumsum += image_tokens_cumsum + num_pixels_cumsum = num_pixels_cumsum * 3 + Assert.eq(num_pixels_cumsum[-1], dataset_config.num_pixels) + begin_index = _get_nearest_split(tokens_cumsum, split_begin_in_dataset * tokens_cumsum[-1]) + end_index = _get_nearest_split(tokens_cumsum, split_end_in_dataset * tokens_cumsum[-1]) if end_index > begin_index: datasets_in_split.append( GPTDatasetSliceConfig.from_dict( @@ -423,8 +427,8 @@ def _split_and_blend_dataset_configs( ) ) dataset_tokens_in_split.append( - sizes_cumsum[end_index - 1].item() - - (sizes_cumsum[begin_index - 1].item() if begin_index > 0 else 0) + tokens_cumsum[end_index - 1].item() + - (tokens_cumsum[begin_index - 1].item() if begin_index > 0 else 0) ) # [else] None of the dataset belongs to the split. diff --git a/fast_llm/data/tokenizer.py b/fast_llm/data/tokenizer.py index c44715d8..0acb65e4 100644 --- a/fast_llm/data/tokenizer.py +++ b/fast_llm/data/tokenizer.py @@ -46,8 +46,6 @@ def tokenize(self, text: str, char_spans=None, image_positions=None) -> tuple[li """ Tokenize the input text and return the tokenized input_ids along with token spans. 
""" - # if not image_positions and not char_spans: - # return self._tokenize(text), [], [] if not image_positions: image_positions = [] if not char_spans: From f8a5532f16df73794bbed793721a2b507bb8b280 Mon Sep 17 00:00:00 2001 From: root Date: Tue, 13 May 2025 22:21:27 +0000 Subject: [PATCH 36/82] slightly better conversion --- fast_llm/models/gpt/conversion.py | 328 +++++++++++++----------------- 1 file changed, 146 insertions(+), 182 deletions(-) diff --git a/fast_llm/models/gpt/conversion.py b/fast_llm/models/gpt/conversion.py index 6aa3aaf1..4363c96c 100644 --- a/fast_llm/models/gpt/conversion.py +++ b/fast_llm/models/gpt/conversion.py @@ -167,20 +167,16 @@ def _get_mlp_converters(self, fast_llm_prefix: str, hf_prefix: str) -> list[Weig def _create_weight_converters( self, - hf_base_prefix: str = "", - fast_llm_offset: int = 1, ) -> list[WeightConverter]: converters = [] num_layers = self._model.config.base_model.transformer.num_layers # Embeddings converters.append( - WeightConverter( - f"layers.{fast_llm_offset - 1}.word_embeddings_weight", f"{hf_base_prefix}model.embed_tokens.weight" - ) + WeightConverter(f"layers.{num_layers - 1}.word_embeddings_weight", f"model.embed_tokens.weight") ) - converters += self._create_lm_head_converters(hf_base_prefix, fast_llm_offset) + converters += self._create_lm_head_converters() for i in range(num_layers): converters += self._create_transformer_layer_converters(f"layers.{i+1}", f"model.layers.{i}") @@ -565,196 +561,111 @@ class LlavaHuggingfaceCheckpointHandler(MistralHuggingfaceCheckpointHandler): @classmethod def _create_config_converters(cls) -> list[ParamConverter]: + # lm_converters = super()._create_config_converters() lm_converters = super()._create_config_converters() - lm_converters[-2] = ConstantExportParamConverter( - export_names=(("architectures",),), export_value=["LlavaForConditionalGeneration"] - ) - # TODO Soham: cleaner way to get language model config converters - for converter in lm_converters: - if isinstance(converter, (RenameParamConverter, MappedConfigParamConverter, RopeScalingParamConverter)): - # Llava uses a different name for the text config - # if converter.fast_llm_names[0][0] == "transformer": + for idx, converter in enumerate(lm_converters): + if converter.export_names == (("model_type",),): + continue + elif converter.export_names == (("architectures",),): + ignore_index = idx + if converter.export_names: converter.export_names = (("text_config", *converter.export_names[0]), *converter.export_names[1:]) - # if converter.fast_llm_names[0][0] == "transformer": - # converter.export_names[0] = ("text_config", *converter.export_names[0]) - return lm_converters + [ - # Multimodal adapter - RenameParamConverter( - fast_llm_names=(("vision_encoder", "adapter_size"),), - export_names=(("text_config", "hidden_size"),), - ), - # Image processing and conv layer - # TODO Soham: these options are not in the fast-llm model config. 
They're read from BatchConfig currently - # RenameParamConverter( - # fast_llm_names=(("vision_encoder", "encoder", "image_size"),), - # export_names=( - # ( - # "vision_config", - # "image_size", - # ), - # ), - # ), - # RenameParamConverter( - # fast_llm_names=(("vision_encoder", "encoder", "patch_size"),), - # export_names=( - # ( - # "vision_config", - # "patch_size", - # ), - # ), - # ), - ConstantImportParamConverter( - fast_llm_names=(("vision_encoder", "patch_norm", "type"),), - fast_llm_value=NormalizationType.rms_norm, - ), - ConstantImportParamConverter( - fast_llm_names=(("vision_encoder", "transformer", "normalization", "type"),), - fast_llm_value=NormalizationType.rms_norm, - ), - # Vision Transformer - RenameParamConverter( - fast_llm_names=(("vision_encoder", "transformer", "num_layers"),), - export_names=( - ( - "vision_config", - "num_hidden_layers", + + return ( + lm_converters[:ignore_index] + + lm_converters[ignore_index + 1 :] + + [ + ConstantExportParamConverter( + export_names=(("architectures",),), export_value=["LlavaForConditionalGeneration"] + ), + # Vision Adapter + RenameParamConverter( + fast_llm_names=(("vision_encoder", "adapter_size"),), + export_names=(("text_config", "hidden_size"),), + ), + ConstantImportParamConverter( + fast_llm_names=(("vision_encoder", "patch_norm", "type"),), + fast_llm_value=NormalizationType.rms_norm, + ), + ConstantImportParamConverter( + fast_llm_names=(("vision_encoder", "transformer", "normalization", "type"),), + fast_llm_value=NormalizationType.rms_norm, + ), + # Vision Transformer + RenameParamConverter( + fast_llm_names=(("vision_encoder", "transformer", "num_layers"),), + export_names=( + ( + "vision_config", + "num_hidden_layers", + ), ), ), - ), - RenameParamConverter( - fast_llm_names=(("vision_encoder", "transformer", "hidden_size"),), - export_names=( - ( - "vision_config", - "hidden_size", + RenameParamConverter( + fast_llm_names=(("vision_encoder", "transformer", "hidden_size"),), + export_names=( + ( + "vision_config", + "hidden_size", + ), ), ), - ), - RenameParamConverter( - fast_llm_names=(("vision_encoder", "transformer", "num_attention_heads"),), - export_names=( - ( - "vision_config", - "num_attention_heads", + RenameParamConverter( + fast_llm_names=(("vision_encoder", "transformer", "num_attention_heads"),), + export_names=( + ( + "vision_config", + "num_attention_heads", + ), ), ), - ), - RenameParamConverter( - fast_llm_names=(("vision_encoder", "transformer", "head_groups"),), - export_names=( - ( - "vision_config", - "num_key_value_heads", + RenameParamConverter( + fast_llm_names=(("vision_encoder", "transformer", "head_groups"),), + export_names=( + ( + "vision_config", + "num_key_value_heads", + ), ), ), - ), - RenameParamConverter( - fast_llm_names=(("vision_encoder", "transformer", "ffn_hidden_size"),), - export_names=( - ( - "vision_config", - "intermediate_size", + RenameParamConverter( + fast_llm_names=(("vision_encoder", "transformer", "ffn_hidden_size"),), + export_names=( + ( + "vision_config", + "intermediate_size", + ), ), ), - ), - MappedConfigParamConverter( - fast_llm_names=(("vision_encoder", "transformer", "activation_type"),), - export_names=( - ( - "vision_config", - "hidden_act", + MappedConfigParamConverter( + fast_llm_names=(("vision_encoder", "transformer", "activation_type"),), + export_names=( + ( + "vision_config", + "hidden_act", + ), ), + fast_llm_value=ActivationType.from_hf_name, + export_value=lambda activation_type: activation_type.hf_name, ), - 
fast_llm_value=ActivationType.from_hf_name, - export_value=lambda activation_type: activation_type.hf_name, - ), - ConstantImportParamConverter( - fast_llm_names=(("vision_encoder", "transformer", "gated"),), fast_llm_value=True - ), - MappedConfigParamConverter( - fast_llm_names=(("vision_encoder", "adapter_activation_type"),), - export_names=(("projector_hidden_act",),), - fast_llm_value=ActivationType.from_hf_name, - export_value=lambda activation_type: activation_type.hf_name, - ), - ConstantImportParamConverter( - fast_llm_names=(("vision_encoder", "transformer", "add_linear_biases"),), fast_llm_value=False - ), - # TODO Soham: add this config param for completeness? - # RenameParamConverter( - # fast_llm_names=(("vision_encoder", "encoder", "num_channels"),), - # export_names=( - # ( - # "vision_config", - # "num_channels", - # ), - # ), - # ), - # RenameParamConverter( - # fast_llm_names=(("vision_encoder", "transformer", "attention_dropout"),), - # export_names=( - # ( - # "vision_config", - # "attention_dropout", - # ), - # ), - # ), - RenameParamConverter( - fast_llm_names=(("vision_encoder", "transformer", "rotary", "theta"),), - export_names=(("vision_config", "rope_theta"),), - ), - # TODO Soham: add this config param in vision encoder for completeness? - # RenameParamConverter( - # fast_llm_names=(("vision_encoder", "transformer", "initializer_range"),), - # export_names=(("vision_config", "initializer_range"),), - # ), - ] - - def _create_vision_transformer_converters(self) -> list[WeightConverter]: - num_layers = self._model.config.base_model.vision_encoder.transformer.num_layers - vision_transformer_converters = [] - for layer in range(num_layers): - # TODO Soham: check if args are correct - vision_transformer_converters.extend( - self._create_vision_transformer_layer_converters( - layer, - ignore_export=False, - hf_base_prefix="vision_tower.transformer.layers.", - fast_llm_offset=1, - type="vision", - ) - ) - - return vision_transformer_converters - - def _create_vision_encoder_weight_converters(self) -> list[WeightConverter]: - patch_conv_converters = [WeightConverter("layers.0.weight", "vision_tower.patch_conv.weight")] - if self._model.config.base_model.vision_encoder.conv_bias: - patch_conv_converters.append(WeightConverter("layers.0.bias", "vision_tower.patch_conv.bias")) - layernorm_converters = [ - WeightConverter("layers.0.norm.weight", "vision_tower.ln_pre.weight"), - ] - if self._model.config.base_model.vision_encoder.patch_norm.type == NormalizationType.layer_norm: - layernorm_converters.append(WeightConverter("layers.0.norm.bias", "vision_tower.ln_pre.bias")) - - vision_transformer_converters = self._create_vision_transformer_converters() - offset = self._model.config.base_model.vision_encoder.transformer.num_layers + 1 - adapter_converters = [ - WeightConverter(f"layers.{offset}.layer_1.weight", "multi_modal_projector.linear_1.weight"), - WeightConverter(f"layers.{offset}.layer_1.bias", "multi_modal_projector.linear_1.bias"), - # TODO Soham: add bias based on config - WeightConverter(f"layers.{offset}.layer_2.weight", "multi_modal_projector.linear_2.weight"), - WeightConverter(f"layers.{offset}.layer_2.bias", "multi_modal_projector.linear_2.bias"), - ] - - return patch_conv_converters + layernorm_converters + vision_transformer_converters + adapter_converters - - def _create_weight_converters(self) -> list[WeightConverter]: - vision_encoder_converter = self._create_vision_encoder_weight_converters() - offset = 
self._model.config.base_model.vision_encoder.transformer.num_layers + 3 - # TODO Soham: call _create_transformer_layer_converters with llava's custom offset - lm_converters = super()._create_weight_converters(hf_base_prefix="language_model.", fast_llm_offset=offset) - return vision_encoder_converter + lm_converters + ConstantImportParamConverter( + fast_llm_names=(("vision_encoder", "transformer", "gated"),), fast_llm_value=True + ), + MappedConfigParamConverter( + fast_llm_names=(("vision_encoder", "adapter_activation_type"),), + export_names=(("projector_hidden_act",),), + fast_llm_value=ActivationType.from_hf_name, + export_value=lambda activation_type: activation_type.hf_name, + ), + ConstantImportParamConverter( + fast_llm_names=(("vision_encoder", "transformer", "add_linear_biases"),), fast_llm_value=False + ), + RenameParamConverter( + fast_llm_names=(("vision_encoder", "transformer", "rotary", "theta"),), + export_names=(("vision_config", "rope_theta"),), + ), + ] + ) def _create_vision_transformer_layer_converters( self, @@ -850,6 +761,59 @@ def _get_vision_transformer_mlp_converters(self, fast_llm_prefix: str, hf_prefix ), ] + def _create_vision_transformer_converters(self) -> list[WeightConverter]: + num_layers = self._model.config.base_model.vision_encoder.transformer.num_layers + vision_transformer_converters = [] + for layer in range(num_layers): + # TODO Soham: check if args are correct + vision_transformer_converters.extend( + self._create_vision_transformer_layer_converters( + layer, + ignore_export=False, + hf_base_prefix="vision_tower.transformer.layers.", + fast_llm_offset=1, + type="vision", + ) + ) + + return vision_transformer_converters + + def _create_vision_encoder_weight_converters(self) -> list[WeightConverter]: + patch_conv_converters = [WeightConverter("layers.0.weight", "vision_tower.patch_conv.weight")] + if self._model.config.base_model.vision_encoder.conv_bias: + patch_conv_converters.append(WeightConverter("layers.0.bias", "vision_tower.patch_conv.bias")) + layernorm_converters = [ + WeightConverter("layers.0.norm.weight", "vision_tower.ln_pre.weight"), + ] + if self._model.config.base_model.vision_encoder.patch_norm.type == NormalizationType.layer_norm: + layernorm_converters.append(WeightConverter("layers.0.norm.bias", "vision_tower.ln_pre.bias")) + + vision_transformer_converters = self._create_vision_transformer_converters() + offset = self._model.config.base_model.vision_encoder.transformer.num_layers + 1 + adapter_converters = [ + WeightConverter(f"layers.{offset}.layer_1.weight", "multi_modal_projector.linear_1.weight"), + WeightConverter(f"layers.{offset}.layer_1.bias", "multi_modal_projector.linear_1.bias"), + # TODO Soham: add bias based on config + WeightConverter(f"layers.{offset}.layer_2.weight", "multi_modal_projector.linear_2.weight"), + WeightConverter(f"layers.{offset}.layer_2.bias", "multi_modal_projector.linear_2.bias"), + ] + + return patch_conv_converters + layernorm_converters + vision_transformer_converters + adapter_converters + + def _create_weight_converters(self) -> list[WeightConverter]: + vision_encoder_converter = self._create_vision_encoder_weight_converters() + offset = self._model.config.base_model.vision_encoder.transformer.num_layers + 3 + # Embeddings + lm_converters = [ + WeightConverter(f"layers.{offset - 1}.word_embeddings_weight", f"language_model.model.embed_tokens.weight") + ] + for i in range(self._model.config.base_model.transformer.num_layers): + lm_converters += self._create_transformer_layer_converters( + 
fast_llm_layer_name=f"layers.{i + offset}", hf_layer_name=f"language_model.model.layers.{i}" + ) + lm_converters += self._create_lm_head_converters(hf_base_prefix="language_model.", fast_llm_offset=offset) + return vision_encoder_converter + lm_converters + class MixtralHuggingfaceCheckpointHandler(CommonLlamaHuggingfaceCheckpointHandler): format: typing.ClassVar[type[CheckpointFormat]] = MixtralGPTHuggingfaceCheckpointFormat From 490651e4b074073e60e36910c6d6d0ed1fa46c21 Mon Sep 17 00:00:00 2001 From: root Date: Wed, 14 May 2025 06:51:31 +0000 Subject: [PATCH 37/82] cleanup, sequence parallelism --- fast_llm/data/dataset/gpt/indexed.py | 2 +- fast_llm/data/dataset/gpt/memmap.py | 1 + fast_llm/layers/language_model/config.py | 2 +- fast_llm/layers/multi_modal/embedding.py | 96 +++++++++++++++---- fast_llm/layers/vision_encoder/config.py | 16 ++++ fast_llm/layers/vision_encoder/encoder.py | 8 +- .../layers/vision_encoder/preprocessing.py | 2 +- fast_llm/models/gpt/conversion.py | 10 +- fast_llm/models/gpt/model.py | 26 +++-- 9 files changed, 126 insertions(+), 37 deletions(-) diff --git a/fast_llm/data/dataset/gpt/indexed.py b/fast_llm/data/dataset/gpt/indexed.py index 6e9bef96..cbe77ff0 100644 --- a/fast_llm/data/dataset/gpt/indexed.py +++ b/fast_llm/data/dataset/gpt/indexed.py @@ -45,7 +45,7 @@ class GPTDatasetSlice[IndexedDatasetType: GPTIndexedDataset](DatasetSlice[Indexe def get_document_sizes(self) -> np.ndarray: # TODO: This can be really big. doc_sizes, im_sizes = self._dataset.get_document_sizes() - return doc_sizes[self._begin : self._end], im_sizes[self._begin : self._end] + return doc_sizes[self._begin : self._end], im_sizes[self._begin : self._end] if im_sizes else None def get_document_size(self, index: int) -> int: return self._dataset.get_document_size(self._begin + index) diff --git a/fast_llm/data/dataset/gpt/memmap.py b/fast_llm/data/dataset/gpt/memmap.py index e1297b14..1efc312e 100644 --- a/fast_llm/data/dataset/gpt/memmap.py +++ b/fast_llm/data/dataset/gpt/memmap.py @@ -209,6 +209,7 @@ def get( offset=self._pointers[idx] + offset * np.dtype(self._dtype).itemsize, ) images = None + image_positions = None if self._has_images: # Truncations with images are not yet supported image_positions = self._image_positions[idx] diff --git a/fast_llm/layers/language_model/config.py b/fast_llm/layers/language_model/config.py index 78de218f..e46e104c 100644 --- a/fast_llm/layers/language_model/config.py +++ b/fast_llm/layers/language_model/config.py @@ -175,7 +175,7 @@ def setup_tensor_space(self, tensor_space: TensorSpace) -> None: # TODO: Need both? 
tensor_space.add_tensor_dim(TensorDim(LanguageModelDimNames.vocab, self.vocab_size)) tensor_space.add_tensor_dim(TensorDim(LanguageModelDimNames.vocab_tp, self.vocab_size, tensor)) - if self.vision_encoder is not None: + if self.vision_encoder.enabled: self.vision_encoder.setup_tensor_space(tensor_space) @property diff --git a/fast_llm/layers/multi_modal/embedding.py b/fast_llm/layers/multi_modal/embedding.py index b7d79dd3..52eaaac3 100644 --- a/fast_llm/layers/multi_modal/embedding.py +++ b/fast_llm/layers/multi_modal/embedding.py @@ -3,6 +3,7 @@ import torch from fast_llm.core.distributed import set_generator +from fast_llm.core.ops import gather, reduce_forward, split from fast_llm.engine.config_utils.tensor_space import TensorSpace from fast_llm.layers.language_model.config import LanguageModelBaseConfig, LanguageModelKwargs from fast_llm.layers.language_model.embedding import LanguageModelEmbedding @@ -10,6 +11,7 @@ from fast_llm.layers.vision_encoder.config import VisionEncoderKwargs from fast_llm.layers.vision_encoder.preprocessing import get_num_patches from fast_llm.tensor import TensorMeta +from fast_llm.utils import Assert class MultiModalEmbedding(LanguageModelEmbedding): @@ -24,6 +26,78 @@ def __init__( ): super().__init__(config, tensor_space) + @torch.compile + def _forward( + self, + input_: torch.Tensor, + tokens: torch.Tensor, + position_ids: torch.Tensor | None, + image_positions: list[torch.Tensor] | None, + image_sizes: list[list[tuple[int, int]]] | None, + ) -> torch.Tensor: + """ + Forward pass for the multi-modal embedding layer. + Args: + input_: The input tensor (image embeddings). + tokens: The tokenized text input. + position_ids: The position ids for the text input. + image_positions: The positions of the image tokens in the input. + image_sizes: The sizes of the images in the input. + Returns: + The combined embeddings for text and images. + """ + Assert.eq(position_ids is not None, self._use_absolute_position_embeddings) + group = self._tensor_space.distributed.tensor_group + if self._parallel_embeddings: + token_mask = (tokens >= self._vocab_start_index) * (tokens < self._vocab_end_index) + masked_tokens = (tokens - self._vocab_start_index) * token_mask + embeddings = torch.embedding(self.word_embeddings_weight, masked_tokens) * token_mask.unsqueeze(2) # noqa + embeddings = reduce_forward(embeddings, group) + if self._use_absolute_position_embeddings: + embeddings = embeddings + torch.nn.functional.embedding(position_ids, self.position_embeddings_weight) + # TODO Soham: avoid cloning? + embeddings = embeddings.clone() + input_ = gather(input_, group, dim=0) + for sample_idx, (positions, sizes) in enumerate(zip(image_positions, image_sizes)): + image_embedding_offset = 0 + for position, size in zip(positions, sizes): + num_image_tokens = get_num_patches(*size, self._config.vision_encoder.patch_size) + embeddings[sample_idx, position : position + num_image_tokens] = input_[ + sample_idx, image_embedding_offset : image_embedding_offset + num_image_tokens + ] + image_embedding_offset += num_image_tokens + if self._sequence_parallel: + embeddings = split(embeddings, group=group, dim=0) + else: + if self._sequence_parallel: + tokens = split(tokens, group=group, dim=0) + if self._use_absolute_position_embeddings: + position_ids = split(position_ids, group=group, dim=0) + # TODO Soham: get image positions for current split. Maybe in preprocessing? 
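# Note: the vocab-parallel branch above uses the standard trick of masking token
# ids outside the local shard, embedding locally, and all-reducing the result over
# the tensor group. A single-process sketch of that pattern (the reduce is omitted;
# shard bounds are made up for illustration):
import torch

def vocab_parallel_embed(tokens, weight_shard, vocab_start, vocab_end):
    # weight_shard holds rows [vocab_start, vocab_end) of the full embedding table.
    mask = (tokens >= vocab_start) & (tokens < vocab_end)
    local_ids = (tokens - vocab_start) * mask
    out = torch.nn.functional.embedding(local_ids, weight_shard) * mask.unsqueeze(-1)
    return out  # in the real layer this is followed by reduce_forward over the group

tokens = torch.tensor([[1, 5, 9]])
shard = torch.randn(4, 8)  # rows 4..7 of a 12-token vocabulary
print(vocab_parallel_embed(tokens, shard, 4, 8).shape)  # torch.Size([1, 3, 8])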
+ # for positions in image_positions: + # if positions > self._distributed_config.tensor_rank + embeddings = torch.embedding(self.word_embeddings_weight, tokens) + # TODO Soham: avoid cloning? + embeddings = embeddings.clone() + for sample_idx, (positions, sizes) in enumerate(zip(image_positions, image_sizes)): + image_embedding_offset = 0 + for position, size in zip(positions, sizes): + num_image_tokens = get_num_patches(*size, self._config.vision_encoder.patch_size) + embeddings[sample_idx, position : position + num_image_tokens] = input_[ + sample_idx, image_embedding_offset : image_embedding_offset + num_image_tokens + ] + image_embedding_offset += num_image_tokens + + if self._use_absolute_position_embeddings: + embeddings = embeddings + torch.nn.functional.embedding(position_ids, self.position_embeddings_weight) + with set_generator( + self._tensor_space.distributed.tp_generator + if self._sequence_parallel + else self._tensor_space.distributed.pp_generator + ): + embeddings = torch.dropout(embeddings, self._dropout_p, self.training) + return embeddings.to(dtype=self._residual_dtype) + def forward( self, input_: torch.Tensor, @@ -42,25 +116,5 @@ def forward( image_sizes = kwargs.get(VisionEncoderKwargs.image_sizes) image_positions = kwargs.get(VisionEncoderKwargs.image_positions) tokens = kwargs.get(LanguageModelKwargs.tokens) - # get text embeddings - # TODO Soham: cloning to avoid pytorch complaint about in-place operation. Can we do better? - embeddings = super()._forward(tokens, position_ids).clone() - image_idx = 0 - for sample_idx, (positions, sizes) in enumerate(zip(image_positions, image_sizes)): - image_embedding_offset = 0 - for position, size in zip(positions, sizes): - num_image_tokens = get_num_patches(*size, self._config.vision_encoder.patch_size) - embeddings[sample_idx, position : position + num_image_tokens] = input_[ - sample_idx, image_embedding_offset : image_embedding_offset + num_image_tokens - ] - image_embedding_offset += num_image_tokens - image_idx += 1 - - with set_generator( - self._tensor_space.distributed.tp_generator - if self._sequence_parallel - else self._tensor_space.distributed.pp_generator - ): - embeddings = torch.dropout(embeddings, self._dropout_p, self.training) - return embeddings.to(self._residual_dtype) + return self._forward(input_, tokens, position_ids, image_positions, image_sizes) diff --git a/fast_llm/layers/vision_encoder/config.py b/fast_llm/layers/vision_encoder/config.py index be3fb38c..e9bfd7d1 100644 --- a/fast_llm/layers/vision_encoder/config.py +++ b/fast_llm/layers/vision_encoder/config.py @@ -1,3 +1,5 @@ +import enum + from fast_llm.config import Config, Field, FieldHint, config_class from fast_llm.engine.base_model.config import BaseModelConfig from fast_llm.engine.config_utils.tensor_space import TensorDim, TensorSpace @@ -130,10 +132,20 @@ class ImageNormalizationConfig(Config): ) +class VisionEncoderType(str, enum.Enum): + none = "none" + pixtral = "pixtral" + + @config_class() class VisionEncoderConfig(BaseModelConfig): _abstract = False + type: VisionEncoderType = Field( + default=VisionEncoderType.none, + desc="Type of the vision encoder. 
Choices: none, pixtral.", + hint=FieldHint.architecture, + ) transformer: VisionTransformerConfig = Field( default_factory=VisionTransformerConfig, desc="Configuration for the vision transformer architecture.", @@ -182,3 +194,7 @@ def setup_tensor_space(self, tensor_space: TensorSpace): ) ) self.transformer.setup_tensor_space(tensor_space, type="vision") + + @property + def enabled(self) -> bool: + return self.type != VisionEncoderType.none diff --git a/fast_llm/layers/vision_encoder/encoder.py b/fast_llm/layers/vision_encoder/encoder.py index 59212c58..a67053d5 100644 --- a/fast_llm/layers/vision_encoder/encoder.py +++ b/fast_llm/layers/vision_encoder/encoder.py @@ -2,6 +2,7 @@ import torch +from fast_llm.core.ops import split from fast_llm.engine.base_model.base_model import Layer from fast_llm.engine.config_utils.tensor_space import TensorSpace from fast_llm.layers.vision_encoder.config import VisionEncoderConfig, VisionEncoderDimNames, VisionEncoderKwargs @@ -39,6 +40,8 @@ class PatchConv(Layer): def __init__(self, config: VisionEncoderConfig, tensor_space: TensorSpace): super().__init__() self._tensor_space = tensor_space + self._distributed_config = tensor_space.distributed_config + self._sequence_parallel = self._distributed_config.sequence_tensor_parallel # TODO Soham: lr_scale self.weight = ParameterMeta.from_dims( ( @@ -68,7 +71,10 @@ def forward( hidden_dims = kwargs[VisionEncoderKwargs.hidden_dims] if isinstance(input_, TensorMeta): return TensorMeta.from_dims(hidden_dims, tensor_name="patch conv output", dtype=input_.dtype) + group = self._tensor_space.distributed.tensor_group input_ = torch.nn.functional.conv2d(input_, self.weight, self.bias, stride=self.stride) patch_embeddings = self.norm(input_.flatten(1)) - patch_embeddings = patch_embeddings.reshape(*(x.size for x in hidden_dims)) + patch_embeddings = patch_embeddings.reshape(*(x.global_size for x in hidden_dims)) + if self._sequence_parallel: + patch_embeddings = split(patch_embeddings, group=group, dim=0) return patch_embeddings diff --git a/fast_llm/layers/vision_encoder/preprocessing.py b/fast_llm/layers/vision_encoder/preprocessing.py index 46bf0ab3..db726e24 100644 --- a/fast_llm/layers/vision_encoder/preprocessing.py +++ b/fast_llm/layers/vision_encoder/preprocessing.py @@ -153,7 +153,6 @@ def preprocess(self, tokens, kwargs: dict[str, typing.Any]) -> None: cu_seqlens = [0] max_seqlen = -1 for imgs, sizes in zip(images, image_sizes): - # TODO Soham: should this be micro_sequence_length? # sum( # get_num_patches(*size, patch_size) for size in sizes # ) @@ -172,6 +171,7 @@ def preprocess(self, tokens, kwargs: dict[str, typing.Any]) -> None: ] ) ) + # TODO Soham: should this be micro_sequence_length? 
padding_size = kwargs[TransformerKwargs.sequence_length] - cu_seqlens[-1] if padding_size > max_seqlen: max_seqlen = padding_size diff --git a/fast_llm/models/gpt/conversion.py b/fast_llm/models/gpt/conversion.py index 4363c96c..ad4df737 100644 --- a/fast_llm/models/gpt/conversion.py +++ b/fast_llm/models/gpt/conversion.py @@ -27,6 +27,7 @@ from fast_llm.functional.rotary import convert_rotary_complex_to_real, convert_rotary_real_to_complex from fast_llm.layers.common.config import NormalizationType from fast_llm.layers.transformer.config import RotaryEmbeddingType, RoutingType, TransformerConfig +from fast_llm.layers.vision_encoder.config import VisionEncoderType from fast_llm.models.gpt.config import ( GPTBaseModelConfig, GPTModelConfig, @@ -172,9 +173,7 @@ def _create_weight_converters( num_layers = self._model.config.base_model.transformer.num_layers # Embeddings - converters.append( - WeightConverter(f"layers.{num_layers - 1}.word_embeddings_weight", f"model.embed_tokens.weight") - ) + converters.append(WeightConverter(f"layers.0.word_embeddings_weight", f"model.embed_tokens.weight")) converters += self._create_lm_head_converters() @@ -250,7 +249,7 @@ def _create_transformer_layer_converters( converters += self._get_mlp_converters(f"{fast_llm_layer_name}", f"{hf_layer_name}") return converters - def _create_lm_head_converters(self, hf_base_prefix: str, fast_llm_offset: int = 1) -> list[WeightConverter]: + def _create_lm_head_converters(self, hf_base_prefix: str = "", fast_llm_offset: int = 1) -> list[WeightConverter]: num_layers = self._model.config.base_model.transformer.num_layers prediction_heads = self._model.config.base_model.prediction_heads norm_bias: bool = self._model.config.base_model.transformer.normalization.type == NormalizationType.layer_norm @@ -575,6 +574,9 @@ def _create_config_converters(cls) -> list[ParamConverter]: lm_converters[:ignore_index] + lm_converters[ignore_index + 1 :] + [ + ConstantImportParamConverter( + fast_llm_names=(("vision_encoder", "type"),), fast_llm_value=VisionEncoderType.pixtral + ), ConstantExportParamConverter( export_names=(("architectures",),), export_value=["LlavaForConditionalGeneration"] ), diff --git a/fast_llm/models/gpt/model.py b/fast_llm/models/gpt/model.py index b832f1b0..4219ac32 100644 --- a/fast_llm/models/gpt/model.py +++ b/fast_llm/models/gpt/model.py @@ -76,7 +76,7 @@ def __init__( else: self._preprocessors.append(BackupAttentionPreprocessor(self._config.transformer, self._tensor_space)) - if self._config.vision_encoder: + if self._config.vision_encoder.enabled: self._preprocessors.append(VisionPreprocessor(self._config.vision_encoder, self._tensor_space)) if self._config.vision_encoder.transformer.rotary.enabled: self._preprocessors.append( @@ -129,7 +129,7 @@ def get_layers(self) -> list[Layer]: return [ *( [LanguageModelEmbedding(self._config, self._tensor_space)] - if self._config.vision_encoder is None + if not self._config.vision_encoder.enabled else self.get_vision_layers() ), *[ @@ -162,7 +162,7 @@ def preprocess_meta( sequence_length -= self._config.prediction_heads micro_sequence_length = sequence_length - if self._config.vision_encoder: + if self._config.vision_encoder.enabled: image_size = batch_meta.image_size image_mean = [ self._config.vision_encoder.image_normalization.mean_r, @@ -231,7 +231,7 @@ def preprocess_meta( if sequence_first else (batch_dim, hidden_sequence_q_dim, hidden_dim) ) - if self._config.vision_encoder: + if self._config.vision_encoder.enabled: vision_hidden_dim = 
self._tensor_space.get_tensor_dim(VisionTransformerDimNames.hidden) vision_hidden_dims = ( (hidden_sequence_q_dim, batch_dim, vision_hidden_dim) @@ -298,7 +298,7 @@ def preprocess_meta( reference_kwargs[name] = reference_kwargs_ kwargs["reference_models"] = reference_kwargs - if self._config.vision_encoder: + if self._config.vision_encoder.enabled: # patch_dimensions are (batch * sequence_length) x 3 x patch_size x patch_size preprocessed_meta.append((kwargs[VisionEncoderKwargs.image_patches_meta], kwargs)) else: @@ -430,11 +430,17 @@ def preprocess( @property def embedding(self) -> LanguageModelEmbedding: - return self.layers[self._config.vision_encoder.transformer.num_layers + 2] + if self._config.vision_encoder.enabled: + return self.layers[self._config.vision_encoder.transformer.num_layers + 2] + else: + return self.layers[0] @property def transformer_layers(self) -> list[TransformerLayer]: - return self.layers[self._config.vision_encoder.transformer.num_layers + 3 : -1] + if self._config.vision_encoder.enabled: + return self.layers[self._config.vision_encoder.transformer.num_layers + 3 : -1] + else: + return self.layers[1:-1] @property def model_head(self) -> LanguageModelHead: @@ -449,7 +455,11 @@ def get_tied_weights(self) -> dict[str, tuple[ParameterMeta, tuple[int, ...]]]: return { WORD_EMBEDDINGS_WEIGHT: ( self.embedding.word_embeddings_weight, - (self._config.vision_encoder is not None, *self.model_head_indices), + # TODO Soham: make embedding layer index a property + ( + self._config.vision_encoder.enabled * (self._config.vision_encoder.transformer.num_layers + 2), + *self.model_head_indices, + ), ) } elif self._config.prediction_heads > 1: From 24e1b83f15c0ec89cb866b5438283533218bc005 Mon Sep 17 00:00:00 2001 From: root Date: Wed, 14 May 2025 07:19:49 +0000 Subject: [PATCH 38/82] fix conv --- fast_llm/layers/vision_encoder/encoder.py | 28 +++++++++++++++++------ 1 file changed, 21 insertions(+), 7 deletions(-) diff --git a/fast_llm/layers/vision_encoder/encoder.py b/fast_llm/layers/vision_encoder/encoder.py index a67053d5..cff87479 100644 --- a/fast_llm/layers/vision_encoder/encoder.py +++ b/fast_llm/layers/vision_encoder/encoder.py @@ -61,6 +61,25 @@ def __init__(self, config: VisionEncoderConfig, tensor_space: TensorSpace): self.norm = config.patch_norm.get_layer(tensor_space.get_tensor_dim(VisionEncoderDimNames.out_channels)) self.stride = config.patch_size + @torch.compile + def _forward( + self, + input_: torch.Tensor, + hidden_dims: tuple[TensorMeta, ...], + ): + group = self._tensor_space.distributed.tensor_group + input_ = torch.nn.functional.conv2d(input_, self.weight, self.bias, stride=self.stride) + patch_embeddings = self.norm(input_.flatten(1)) + batch_dim, sequence_q_dim, hidden_dim = hidden_dims + if self._sequence_parallel: + patch_embeddings = patch_embeddings.reshape( + sequence_q_dim.global_size, batch_dim.size, hidden_dim.global_size + ) + patch_embeddings = split(patch_embeddings, group=group, dim=0) + else: + patch_embeddings = patch_embeddings.reshape(batch_dim.size, sequence_q_dim.size, hidden_dim.size) + return patch_embeddings + def forward( self, input_: torch.Tensor, @@ -71,10 +90,5 @@ def forward( hidden_dims = kwargs[VisionEncoderKwargs.hidden_dims] if isinstance(input_, TensorMeta): return TensorMeta.from_dims(hidden_dims, tensor_name="patch conv output", dtype=input_.dtype) - group = self._tensor_space.distributed.tensor_group - input_ = torch.nn.functional.conv2d(input_, self.weight, self.bias, stride=self.stride) - patch_embeddings = 
self.norm(input_.flatten(1)) - patch_embeddings = patch_embeddings.reshape(*(x.global_size for x in hidden_dims)) - if self._sequence_parallel: - patch_embeddings = split(patch_embeddings, group=group, dim=0) - return patch_embeddings + + return self._forward(input_, hidden_dims) From 0f1612a63c84b355c45b282cb10f174c6a9a7da3 Mon Sep 17 00:00:00 2001 From: root Date: Wed, 14 May 2025 14:57:47 +0000 Subject: [PATCH 39/82] wip fixes --- fast_llm/data/dataset/gpt/indexed.py | 2 +- fast_llm/data/dataset/gpt/sampled.py | 53 ++++++++++++++++------------ 2 files changed, 31 insertions(+), 24 deletions(-) diff --git a/fast_llm/data/dataset/gpt/indexed.py b/fast_llm/data/dataset/gpt/indexed.py index cbe77ff0..56c4c892 100644 --- a/fast_llm/data/dataset/gpt/indexed.py +++ b/fast_llm/data/dataset/gpt/indexed.py @@ -45,7 +45,7 @@ class GPTDatasetSlice[IndexedDatasetType: GPTIndexedDataset](DatasetSlice[Indexe def get_document_sizes(self) -> np.ndarray: # TODO: This can be really big. doc_sizes, im_sizes = self._dataset.get_document_sizes() - return doc_sizes[self._begin : self._end], im_sizes[self._begin : self._end] if im_sizes else None + return doc_sizes[self._begin : self._end], im_sizes[self._begin : self._end] if im_sizes else [] def get_document_size(self, index: int) -> int: return self._dataset.get_document_size(self._begin + index) diff --git a/fast_llm/data/dataset/gpt/sampled.py b/fast_llm/data/dataset/gpt/sampled.py index 91f8ca8f..9fbb218e 100644 --- a/fast_llm/data/dataset/gpt/sampled.py +++ b/fast_llm/data/dataset/gpt/sampled.py @@ -133,23 +133,26 @@ def _sample(self) -> None: # Get the document sizes, the main information needed for sampling. document_sizes, image_sizes = self._indexed_dataset.get_document_sizes() document_sizes = torch.from_numpy(document_sizes).to(self._device) - image_token_sizes = [] - for i, sizes in enumerate(image_sizes): - image_token_sizes.append( - sum( - get_num_patches( - *get_resize_dims( - *size, - self._parameters.image_size, - self._parameters.image_size, + if image_sizes: + image_token_sizes = [] + for i, sizes in enumerate(image_sizes): + image_token_sizes.append( + sum( + get_num_patches( + *get_resize_dims( + *size, + self._parameters.image_size, + self._parameters.image_size, + self._parameters.patch_size, + ), self._parameters.patch_size, - ), - self._parameters.patch_size, + ) + for size in sizes ) - for size in sizes ) - ) - image_token_sizes = torch.tensor(image_token_sizes).to(self._device) + image_token_sizes = torch.tensor(image_token_sizes).to(self._device) + else: + image_token_sizes = torch.zeros_like(document_sizes) documents_per_epoch = document_sizes.numel() tokens_per_epoch = document_sizes.sum().item() + image_token_sizes.sum().item() @@ -463,16 +466,20 @@ def __getitem__(self, index: int) -> typing.Any: use_loss_masking_spans=self._parameters.use_loss_masking_spans, ) start_pos = 0 - for idx, im_position in enumerate(sample.image_positions): - # image_positions.append(im_positions + len(token_ids) + image_tokens_added) - # Add placeholders for image tokens - token_ids.append(sample.token_ids[start_pos:im_position]) - token_ids.append(np.full((image_sizes[idx],), -100, dtype=np.int64)) - image_positions.append(im_position + len(token_ids) + image_tokens_added) - image_tokens_added += image_tokens - start_pos = im_position + if sample.image_positions: + for idx, im_position in enumerate(sample.image_positions): + # image_positions.append(im_positions + len(token_ids) + image_tokens_added) + # Add placeholders for image tokens + 
token_ids.append(sample.token_ids[start_pos:im_position]) + token_ids.append(np.full((image_sizes[idx],), -100, dtype=np.int64)) + image_positions.append(im_position + len(token_ids) + image_tokens_added) + image_tokens_added += image_tokens + start_pos = im_position token_ids.append(sample.token_ids[start_pos:]) - images.append(sample.images) + if sample.images: + images.append(sample.images) + else: + images.append([]) if self._parameters.use_loss_masking_spans: for loss_masking_span in sample.loss_masking_spans: span = np.clip( From 2e48c5f282e4e5b1e460e96efdc9e42b2c0743db Mon Sep 17 00:00:00 2001 From: root Date: Wed, 14 May 2025 22:26:10 +0000 Subject: [PATCH 40/82] fix --- fast_llm/layers/multi_modal/embedding.py | 11 ++++-- fast_llm/layers/vision_encoder/config.py | 1 + fast_llm/layers/vision_encoder/encoder.py | 34 +++++++------------ .../layers/vision_encoder/preprocessing.py | 5 ++- fast_llm/models/gpt/model.py | 3 ++ 5 files changed, 29 insertions(+), 25 deletions(-) diff --git a/fast_llm/layers/multi_modal/embedding.py b/fast_llm/layers/multi_modal/embedding.py index 52eaaac3..9a035d8f 100644 --- a/fast_llm/layers/multi_modal/embedding.py +++ b/fast_llm/layers/multi_modal/embedding.py @@ -62,9 +62,14 @@ def _forward( image_embedding_offset = 0 for position, size in zip(positions, sizes): num_image_tokens = get_num_patches(*size, self._config.vision_encoder.patch_size) - embeddings[sample_idx, position : position + num_image_tokens] = input_[ - sample_idx, image_embedding_offset : image_embedding_offset + num_image_tokens - ] + if self._sequence_parallel: + embeddings[position : position + num_image_tokens, sample_idx] = input_[ + image_embedding_offset : image_embedding_offset + num_image_tokens, sample_idx + ] + else: + embeddings[sample_idx, position : position + num_image_tokens] = input_[ + sample_idx, image_embedding_offset : image_embedding_offset + num_image_tokens + ] image_embedding_offset += num_image_tokens if self._sequence_parallel: embeddings = split(embeddings, group=group, dim=0) diff --git a/fast_llm/layers/vision_encoder/config.py b/fast_llm/layers/vision_encoder/config.py index e9bfd7d1..fdbe2726 100644 --- a/fast_llm/layers/vision_encoder/config.py +++ b/fast_llm/layers/vision_encoder/config.py @@ -66,6 +66,7 @@ class VisionEncoderKwargs: patch_embeddings = "patch_embeddings" hidden_dims = "vit_hidden_dims" image_patches_meta = "vit_image_patches_meta" + out_channels = "vit_out_channels" # TODO Soham: do we need all of them? 
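The embedding change above interleaves image-patch embeddings into the token embedding sequence at the recorded image positions. A minimal single-device sketch of that merge, assuming patch counts follow ceil(height / patch_size) * ceil(width / patch_size); the actual layer also handles the tensor- and sequence-parallel layouts.

import math
import torch

def merge_image_embeddings(text_emb, patch_emb, positions, sizes, patch_size=16):
    # text_emb: (seq_len, hidden) token embeddings with placeholder slots at image positions.
    # patch_emb: (total_patches, hidden) image patch embeddings, in document order.
    out = text_emb.clone()
    offset = 0
    for position, (height, width) in zip(positions, sizes):
        n = math.ceil(height / patch_size) * math.ceil(width / patch_size)
        out[position : position + n] = patch_emb[offset : offset + n]
        offset += n
    return out

text = torch.zeros(16, 4)
patches = torch.ones(4, 4)  # one 32x32 image -> 4 patches
print(merge_image_embeddings(text, patches, [5], [(32, 32)]).sum())  # tensor(16.)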
diff --git a/fast_llm/layers/vision_encoder/encoder.py b/fast_llm/layers/vision_encoder/encoder.py index cff87479..1df7f889 100644 --- a/fast_llm/layers/vision_encoder/encoder.py +++ b/fast_llm/layers/vision_encoder/encoder.py @@ -5,6 +5,7 @@ from fast_llm.core.ops import split from fast_llm.engine.base_model.base_model import Layer from fast_llm.engine.config_utils.tensor_space import TensorSpace +from fast_llm.layers.transformer.config import TransformerKwargs from fast_llm.layers.vision_encoder.config import VisionEncoderConfig, VisionEncoderDimNames, VisionEncoderKwargs from fast_llm.tensor import ParameterMeta, TensorMeta, init_normal_ @@ -61,25 +62,6 @@ def __init__(self, config: VisionEncoderConfig, tensor_space: TensorSpace): self.norm = config.patch_norm.get_layer(tensor_space.get_tensor_dim(VisionEncoderDimNames.out_channels)) self.stride = config.patch_size - @torch.compile - def _forward( - self, - input_: torch.Tensor, - hidden_dims: tuple[TensorMeta, ...], - ): - group = self._tensor_space.distributed.tensor_group - input_ = torch.nn.functional.conv2d(input_, self.weight, self.bias, stride=self.stride) - patch_embeddings = self.norm(input_.flatten(1)) - batch_dim, sequence_q_dim, hidden_dim = hidden_dims - if self._sequence_parallel: - patch_embeddings = patch_embeddings.reshape( - sequence_q_dim.global_size, batch_dim.size, hidden_dim.global_size - ) - patch_embeddings = split(patch_embeddings, group=group, dim=0) - else: - patch_embeddings = patch_embeddings.reshape(batch_dim.size, sequence_q_dim.size, hidden_dim.size) - return patch_embeddings - def forward( self, input_: torch.Tensor, @@ -90,5 +72,15 @@ def forward( hidden_dims = kwargs[VisionEncoderKwargs.hidden_dims] if isinstance(input_, TensorMeta): return TensorMeta.from_dims(hidden_dims, tensor_name="patch conv output", dtype=input_.dtype) - - return self._forward(input_, hidden_dims) + micro_batch_size = kwargs[TransformerKwargs.micro_batch_size] + sequence_length = kwargs[TransformerKwargs.sequence_length] + out_channels = kwargs[VisionEncoderKwargs.out_channels] + reshape_dims = (micro_batch_size, sequence_length, out_channels) + group = self._tensor_space.distributed.tensor_group + input_ = torch.nn.functional.conv2d(input_, self.weight, self.bias, stride=self.stride) + patch_embeddings = self.norm(input_.flatten(1)) + patch_embeddings = patch_embeddings.view(reshape_dims) + if self._sequence_parallel: + patch_embeddings = patch_embeddings.permute(1, 0, 2).contiguous() + patch_embeddings = split(patch_embeddings, group=group, dim=0) + return patch_embeddings diff --git a/fast_llm/layers/vision_encoder/preprocessing.py b/fast_llm/layers/vision_encoder/preprocessing.py index db726e24..7ebfb522 100644 --- a/fast_llm/layers/vision_encoder/preprocessing.py +++ b/fast_llm/layers/vision_encoder/preprocessing.py @@ -152,16 +152,19 @@ def preprocess(self, tokens, kwargs: dict[str, typing.Any]) -> None: patch_position_ids = [] cu_seqlens = [0] max_seqlen = -1 + sequence_first = kwargs.get(TransformerKwargs.sequence_first) for imgs, sizes in zip(images, image_sizes): # sum( # get_num_patches(*size, patch_size) for size in sizes # ) seq_patches = [] + sample_cu_seqlen = 0 for image, size in zip(imgs, sizes): seqlen = get_num_patches(*size, patch_size) if seqlen > max_seqlen: max_seqlen = seqlen cu_seqlens.append(cu_seqlens[-1] + seqlen) + sample_cu_seqlen += seqlen seq_patches.append( torch.cat( [ @@ -172,7 +175,7 @@ def preprocess(self, tokens, kwargs: dict[str, typing.Any]) -> None: ) ) # TODO Soham: should this be 
micro_sequence_length? - padding_size = kwargs[TransformerKwargs.sequence_length] - cu_seqlens[-1] + padding_size = kwargs[TransformerKwargs.sequence_length] - sample_cu_seqlen if padding_size > max_seqlen: max_seqlen = padding_size cu_seqlens.append(kwargs[TransformerKwargs.sequence_length]) diff --git a/fast_llm/models/gpt/model.py b/fast_llm/models/gpt/model.py index 4219ac32..9fff50bc 100644 --- a/fast_llm/models/gpt/model.py +++ b/fast_llm/models/gpt/model.py @@ -185,6 +185,9 @@ def preprocess_meta( VisionEncoderKwargs.kv_channels: self._tensor_space.get_tensor_dim( VisionEncoderDimNames.kv_channels ).size, + VisionEncoderKwargs.out_channels: self._tensor_space.get_tensor_dim( + VisionEncoderDimNames.out_channels + ).size, } else: vision_kwargs = {} From 94e439c9c3b5b2a2d486f2a940fc5947c9bd6e22 Mon Sep 17 00:00:00 2001 From: Toby Liang Date: Thu, 15 May 2025 22:53:23 +0000 Subject: [PATCH 41/82] data updates --- fast_llm/data/data/gpt/data.py | 21 +++++++++++- fast_llm/data/dataset/gpt/memmap.py | 8 ++--- fast_llm/data/dataset/gpt/sampled.py | 50 ++++++++++++++++------------ 3 files changed, 53 insertions(+), 26 deletions(-) diff --git a/fast_llm/data/data/gpt/data.py b/fast_llm/data/data/gpt/data.py index 4fcd42ae..0e43ec2b 100644 --- a/fast_llm/data/data/gpt/data.py +++ b/fast_llm/data/data/gpt/data.py @@ -34,6 +34,8 @@ class GPTBatch: sequence_lengths: list[torch.Tensor] | None = None images: list[torch.Tensor] | None = None image_positions: list[torch.Tensor] | None = None + audio: list[torch.Tensor] | None = None + audio_positions: list[torch.Tensor] | None = None def gpt_data_collate_fn(batch: list[GPTSample], sampling_parameters: GPTSamplingParameters) -> GPTBatch: @@ -54,16 +56,33 @@ def gpt_data_collate_fn(batch: list[GPTSample], sampling_parameters: GPTSampling batch_images.append(None) batch_image_positions = [] for sample in batch: - if sample.image_positions is not None: + if sample.image_positions is not None and len(sample.image_positions) > 0: batch_image_positions.append(torch.from_numpy(sample.image_positions)) else: batch_image_positions.append(None) + + has_audio = False + batch_audio = [] + for sample in batch: + if sample.audio is not None and len(sample.audio_positions) > 0: + batch_audio.append([torch.from_numpy(image) for image in sample.audio]) + has_audio = True + else: + batch_audio.append(None) + batch_audio_positions = [] + for sample in batch: + if sample.audio_positions is not None: + batch_audio_positions.append(torch.from_numpy(sample.audio_positions)) + else: + batch_audio_positions.append(None) return GPTBatch( token_ids=torch.from_numpy(stacked_ids), loss_masking_spans=stacked_spans, sequence_lengths=sequence_lengths, images=batch_images if has_images else None, image_positions=batch_image_positions if has_images else None, + audio=batch_audio if has_audio else None, + audio_positions=batch_image_positions if has_audio else None, ) diff --git a/fast_llm/data/dataset/gpt/memmap.py b/fast_llm/data/dataset/gpt/memmap.py index 619c5624..50d4b416 100644 --- a/fast_llm/data/dataset/gpt/memmap.py +++ b/fast_llm/data/dataset/gpt/memmap.py @@ -111,8 +111,8 @@ def _init( + sum([x.nbytes for x in self._spans]) ) self._num_pixels = 0 - self._image_lengths = None - self._image_positions = None + self._image_lengths = [] + self._image_positions = [] if self._has_images and self._version >= 4: self._n_images = np.frombuffer( self._index_bin_buffer, dtype=np.int32, count=self._num_documents, offset=offset @@ -266,7 +266,7 @@ def get( if self._has_audio: 
audio_positions = self._audio_positions[idx] offset = self._pointers[idx] + self._document_sizes[idx] * np.dtype(self._dtype).itemsize - if len(self._image_lengths) > 0: + if self._has_images and len(self._image_lengths) > 0: offset += self._image_lengths[idx].prod(initial=3) * np.dtype(np.uint8).itemsize all_audio = np.frombuffer( self._bin_buffer, @@ -340,7 +340,7 @@ def get_document_sizes(self) -> tuple[np.ndarray, np.ndarray]: """ return self._document_sizes, self._image_lengths, self._audio_lengths - def get_document_size(self, index: int, patch_size: list[int]) -> int: + def get_document_size(self, index: int) -> int: # return self._document_sizes[index].item() + ( # sum((h // patch_size[0]) * (w // patch_size[1]) for h, w in self._image_lengths[index]) # if self._has_images diff --git a/fast_llm/data/dataset/gpt/sampled.py b/fast_llm/data/dataset/gpt/sampled.py index d5c8fc4b..5074f1a0 100644 --- a/fast_llm/data/dataset/gpt/sampled.py +++ b/fast_llm/data/dataset/gpt/sampled.py @@ -169,26 +169,26 @@ def _sample(self) -> None: Create a `GPTSampledDataset` with the requested parameters. """ # Get the document sizes, the main information needed for sampling. - document_sizes, image_sizes = self._indexed_dataset.get_document_sizes() + document_sizes, image_sizes, audio_sizes = self._indexed_dataset.get_document_sizes() document_sizes = torch.from_numpy(document_sizes).to(self._device) - image_token_sizes = [] + image_token_sizes = torch.zeros_like(document_sizes).to(self._device) for i, sizes in enumerate(image_sizes): - image_token_sizes.append( - sum( - get_num_patches( - *get_resize_dims( - *size, - self._parameters.image_size, - self._parameters.image_size, - self._parameters.patch_size, - ), + image_token_sizes[i] = sum( + get_num_patches( + *get_resize_dims( + *size, + self._parameters.image_size, + self._parameters.image_size, self._parameters.patch_size, - ) - for size in sizes + ), + self._parameters.patch_size, ) + for size in sizes ) - image_token_sizes = torch.tensor(image_token_sizes).to(self._device) + # image_token_sizes = torch.tensor(image_token_sizes).to(self._device) + audio_token_sizes = torch.zeros_like(document_sizes).to(self._device) + long_audio_filter = torch.zeros_like(document_sizes, dtype=torch.bool) # longer than audio padding for i, sizes in enumerate(audio_sizes): audio_token_size_arr, to_filter = self._compute_audio_token_size(sizes) audio_token_sizes[i] = audio_token_size_arr.sum() @@ -502,17 +502,22 @@ def __getitem__(self, index: int) -> typing.Any: for image_length in image_lengths ] image_tokens = sum(image_sizes) - document_size = text_size + image_tokens audio_token_size_arr, _ = self._compute_audio_token_size(audio_lengths) audio_tokens = audio_token_size_arr.sum() + document_size = text_size + image_tokens + audio_tokens + if not self._truncate_documents: + # Document too long, ignore if document_size > self._parameters.sequence_length + 1: - # Document too long, ignore document_sampling_index += 1 continue + + # Where are we currently in sample? tokens_in_sample = token_count % (self._parameters.sequence_length + 1) + + # Add padding if document_size + tokens_in_sample > self._parameters.sequence_length + 1: # Document belongs to the next sample, need to account for padding. 
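# A self-contained sketch of the no-truncation padding rule below, with made-up numbers:
seq_plus_one, token_count, document_size = 8, 13, 4
tokens_in_sample = token_count % seq_plus_one        # 5 tokens already used in this sample
if document_size + tokens_in_sample > seq_plus_one:  # 4 + 5 > 8 -> the document does not fit
    token_count += seq_plus_one - tokens_in_sample   # emit 3 padding tokens up to the boundary
assert token_count == 16                             # the document starts at a fresh sample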
padding_size = self._parameters.sequence_length + 1 - tokens_in_sample @@ -540,6 +545,8 @@ def __getitem__(self, index: int) -> typing.Any: use_loss_masking_spans=self._parameters.use_loss_masking_spans, ) start_pos = 0 + + # add tokens and multi modal padding placeholders multimodal_positions = np.concatenate( [sample.image_positions.astype(np.int32), sample.audio_positions.astype(np.int32)] ) @@ -552,7 +559,7 @@ def __getitem__(self, index: int) -> typing.Any: else: assert False # image_positions.append(im_positions + len(token_ids) + image_tokens_added) - # Add placeholders for image tokens + # Add placeholders for image and audio tokens tokens token_ids.append(sample.token_ids[start_pos:mm_position]) if mm_type == "image": token_ids.append(np.full((image_sizes[idx],), -100, dtype=np.int64)) @@ -560,7 +567,7 @@ def __getitem__(self, index: int) -> typing.Any: mm_tokens_added += image_tokens elif mm_type == "audio": token_ids.append(np.full((audio_token_size_arr[idx],), -100, dtype=np.int64)) - audio_positions.append(mm_position + mm_tokens_added) + audio_positions.append(len(token_ids)) mm_tokens_added += audio_tokens start_pos = mm_position token_ids.append(sample.token_ids[start_pos:]) @@ -593,12 +600,13 @@ def __getitem__(self, index: int) -> typing.Any: if self._parameters.use_loss_masking_spans else None ) - images = [im for img_list in images for im in img_list] if images else None - image_positions = np.array(image_positions) if image_positions else None + # images = [im for img_list in images for im in img_list] if images else None + # image_positions = np.array(image_positions) if image_positions else None + images = None audio = [aud for aud_list in audio for aud in aud_list] if audio else None audio_positions = np.array(audio_positions) if audio_positions else None - Assert.eq(len(token_ids), self._parameters.sequence_length + self._parameters.extra_tokens) + # Assert.eq(len(token_ids), self._parameters.sequence_length + self._parameters.extra_tokens) return GPTSample( token_ids=token_ids, From 543fc0d53026f82bd0735f872e4109f5fea8f7fa Mon Sep 17 00:00:00 2001 From: Toby Liang Date: Fri, 16 May 2025 20:52:23 +0000 Subject: [PATCH 42/82] changes --- fast_llm/data/dataset/gpt/memmap.py | 2 +- fast_llm/data/tokenizer.py | 6 +++--- fast_llm/engine/schedule/config.py | 11 ----------- fast_llm/layers/language_model/config.py | 8 +++++++- fast_llm/layers/transformer/config.py | 13 +++++++++++++ fast_llm/models/gpt/model.py | 22 ++++++++++++++++------ fast_llm/models/gpt/trainer.py | 4 ++-- 7 files changed, 42 insertions(+), 24 deletions(-) diff --git a/fast_llm/data/dataset/gpt/memmap.py b/fast_llm/data/dataset/gpt/memmap.py index 50d4b416..d63653e6 100644 --- a/fast_llm/data/dataset/gpt/memmap.py +++ b/fast_llm/data/dataset/gpt/memmap.py @@ -411,7 +411,7 @@ def write_dataset(cls, prefix: pathlib.Path | str, documents: typing.Iterable[GP bin_stream.write(audio.tobytes(order="C")) total_aud_size += audio.size if len(document.audio) > 0: - aud_positions.append(document.audio_positions) + aud_positions += document.audio_positions # Update metadata doc_length = len(document.token_ids) diff --git a/fast_llm/data/tokenizer.py b/fast_llm/data/tokenizer.py index c6d7a51a..cccf5985 100644 --- a/fast_llm/data/tokenizer.py +++ b/fast_llm/data/tokenizer.py @@ -73,7 +73,7 @@ def tokenize( token_ids.extend(tokenized_text) # update mm token positions - multimodal_type = "image" if multimodal_position in multimodal_positions else "audio" + multimodal_type = "image" if multimodal_position in 
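Both the placeholder-insertion loop above and the tokenizer fixes that follow classify each merged multimodal position by membership in the image position list; testing against the merged list, as the old tokenizer code did, would label every clip as an image. A compact sketch of the classification, with hypothetical position arrays:

    import numpy as np


    def classify_positions(image_positions: np.ndarray, audio_positions: np.ndarray) -> list[tuple[int, str]]:
        """Merge the two position lists and tag each entry with its modality."""
        image_set = set(image_positions.tolist())
        merged = np.sort(np.concatenate([image_positions, audio_positions]))
        return [(int(p), "image" if p in image_set else "audio") for p in merged]


    print(classify_positions(np.array([3, 40]), np.array([12])))
    # [(3, 'image'), (12, 'audio'), (40, 'image')]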
image_positions else "audio" if multimodal_type == "image": image_token_positions.append(len(token_ids)) else: @@ -104,7 +104,7 @@ def tokenize( token_ids.extend(tokenized_text) # update mm token positions - multimodal_type = "image" if multimodal_position in multimodal_positions else "audio" + multimodal_type = "image" if multimodal_position in image_positions else "audio" if multimodal_type == "image": image_token_positions.append(len(token_ids)) else: @@ -141,7 +141,7 @@ def tokenize( token_ids.extend(tokenized_text) # update mm token positions - multimodal_type = "image" if multimodal_position in multimodal_positions else "audio" + multimodal_type = "image" if multimodal_position in image_positions else "audio" if multimodal_type == "image": image_token_positions.append(len(token_ids)) else: diff --git a/fast_llm/engine/schedule/config.py b/fast_llm/engine/schedule/config.py index 7412beb0..13aee09a 100644 --- a/fast_llm/engine/schedule/config.py +++ b/fast_llm/engine/schedule/config.py @@ -60,22 +60,11 @@ class BatchConfig(Config): desc="Maximum image height and width", hint=FieldHint.optional, ) - # Audio inputs - aud_downsampling_k: int = Field( - default=5, - desc="Audio downsampling k parameter.", - hint=FieldHint.feature, - ) aud_padding_duration: int = Field( default=-1, desc="Audio padding duration in seconds.", hint=FieldHint.feature, ) - aud_sampling_rate: int = Field( - default=16000, - desc="Audio sampling rate to use.", - hint=FieldHint.feature, - ) def setup(self, distributed_config: DistributedConfig) -> None: self._distributed = distributed_config diff --git a/fast_llm/layers/language_model/config.py b/fast_llm/layers/language_model/config.py index 78de218f..44bc5f30 100644 --- a/fast_llm/layers/language_model/config.py +++ b/fast_llm/layers/language_model/config.py @@ -5,6 +5,7 @@ from fast_llm.engine.config_utils.tensor_space import TensorDim, TensorSpace from fast_llm.engine.distributed.config import DistributedDimNames from fast_llm.functional.config import CrossEntropyImpl +from fast_llm.layers.audio_encoder.config import AudioEncoderConfig from fast_llm.layers.transformer.config import TransformerConfig from fast_llm.layers.vision_encoder.config import VisionEncoderConfig from fast_llm.utils import Assert @@ -47,11 +48,16 @@ class LanguageModelBaseConfig(BaseModelConfig): hint=FieldHint.architecture, ) # TODO Soham: make this None by default. Need to figure out how to handle this in the config (see ) - vision_encoder: VisionEncoderConfig = Field( + vision_encoder: VisionEncoderConfig | None = Field( default_factory=VisionEncoderConfig, desc="Configuration for the vision encoder that transforms images into embeddings.", hint=FieldHint.optional, ) + audio_encoder: AudioEncoderConfig | None = Field( + default_factory=AudioEncoderConfig, + desc="Configuration for the audio encoder that transforms audio into embeddings.", + hint=FieldHint.optional, + ) max_position_embeddings: int = Field( default=2048, desc="Number of absolute position embeddings, if applicable.", diff --git a/fast_llm/layers/transformer/config.py b/fast_llm/layers/transformer/config.py index 38dc9ec4..40a29959 100644 --- a/fast_llm/layers/transformer/config.py +++ b/fast_llm/layers/transformer/config.py @@ -765,3 +765,16 @@ class VisionTransformerConfig(TransformerConfig): desc="Configuration for the rotary positional embeddings.", hint=FieldHint.feature, ) + + +@config_class() +class AudioTransformerConfig(TransformerConfig): + """ + Configuration for the Audio Transformer model. 
+ """ + + causal: bool = FieldUpdate( + default=False, + desc="Use causal attention. Turn this off only for bidirectional attention e.g., in Vision Transformer.", + hint=FieldHint.feature, + ) diff --git a/fast_llm/models/gpt/model.py b/fast_llm/models/gpt/model.py index b832f1b0..b4bbf736 100644 --- a/fast_llm/models/gpt/model.py +++ b/fast_llm/models/gpt/model.py @@ -10,6 +10,8 @@ from fast_llm.engine.distributed.config import DistributedConfig, DistributedDimNames, PhaseType from fast_llm.engine.inference.runner import InferenceRunner from fast_llm.engine.multi_stage.fast_llm_model import FastLLMModel +from fast_llm.layers.audio_encoder.config import AudioEncoderKwargs +from fast_llm.layers.audio_encoder.preprocessing import AudioPreprocessor from fast_llm.layers.language_model.config import LanguageModelKwargs, LanguageModelLossNames from fast_llm.layers.language_model.embedding import WORD_EMBEDDINGS_WEIGHT, LanguageModelEmbedding from fast_llm.layers.language_model.head import OUTPUT_WEIGHTS, LanguageModelHead @@ -82,11 +84,8 @@ def __init__( self._preprocessors.append( RotaryEmbeddingPreprocessor(self._config.vision_encoder.transformer.rotary, self._tensor_space) ) - # self._vision_preprocessor = VisionPreprocessor(self._config.vision_encoder, self._tensor_space) - # if self._config.vision_encoder.transformer.rotary.enabled: - # self._vision_rotary_embedding_preprocessor = RotaryEmbeddingPreprocessor( - # self._config.vision_encoder.transformer.rotary, self._tensor_space - # ) + if self._config.audio_encoder: + self._preprocessors.append(AudioPreprocessor(self._config.audio_encoder, self._tensor_space)) def get_output_layers(self) -> list[Layer]: layers = [] @@ -418,6 +417,17 @@ def preprocess( kwargs[VisionEncoderKwargs.image_positions] = batch.image_positions kwargs[LanguageModelKwargs.tokens] = tokens + if batch.audio is not None: + kwargs[AudioEncoderKwargs.audio] = [ + [ + aud.to(device=self._tensor_space.distributed.device, dtype=torch.uint8, non_blocking=True) + for aud in audio + ] + for audio in batch.audio + ] + kwargs[AudioEncoderKwargs.audio_positions] = batch.audio_positions + kwargs[LanguageModelKwargs.tokens] = tokens + for preprocessor in self._preprocessors: preprocessor.preprocess(tokens, kwargs) image_patches = kwargs.get(VisionEncoderKwargs.image_patches, None) @@ -448,7 +458,7 @@ def get_tied_weights(self) -> dict[str, tuple[ParameterMeta, tuple[int, ...]]]: if self._config.tie_word_embeddings: return { WORD_EMBEDDINGS_WEIGHT: ( - self.embedding.word_embeddings_weight, + self.layers[0].word_embeddings_weight, (self._config.vision_encoder is not None, *self.model_head_indices), ) } diff --git a/fast_llm/models/gpt/trainer.py b/fast_llm/models/gpt/trainer.py index 66587b7c..ed912ec4 100644 --- a/fast_llm/models/gpt/trainer.py +++ b/fast_llm/models/gpt/trainer.py @@ -32,9 +32,9 @@ def _get_sampling_parameters( "extra_tokens": self._config.model.base_model.prediction_heads, "patch_size": self._config.batch.patch_size, "image_size": self._config.batch.image_size, - "aud_downsampling_k": self._config.batch.aud_downsampling_k, + "aud_downsampling_k": self._config.model.base_model.audio_encoder.aud_downsampling_k, "aud_padding_duration": self._config.batch.aud_padding_duration, - "aud_sampling_rate": self._config.batch.aud_sampling_rate, + "aud_sampling_rate": self._config.model.base_model.audio_encoder.aud_sampling_rate, } ) return parameters if _return_dict else GPTSamplingParameters(**parameters) From d529d37d881849afff40e57609ef4d10a916b742 Mon Sep 17 00:00:00 2001 
From: root Date: Sat, 17 May 2025 17:42:24 +0000 Subject: [PATCH 43/82] fix image position --- fast_llm/data/dataset/gpt/sampled.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/fast_llm/data/dataset/gpt/sampled.py b/fast_llm/data/dataset/gpt/sampled.py index 9fbb218e..780b1887 100644 --- a/fast_llm/data/dataset/gpt/sampled.py +++ b/fast_llm/data/dataset/gpt/sampled.py @@ -412,6 +412,7 @@ def __getitem__(self, index: int) -> typing.Any: images = [] image_positions = [] image_tokens_added = 0 + text_tokens_added = 0 while token_count < token_end: # Find the document index in the dataset. if document_sampling_index < self._unshuffled_documents: @@ -471,11 +472,13 @@ def __getitem__(self, index: int) -> typing.Any: # image_positions.append(im_positions + len(token_ids) + image_tokens_added) # Add placeholders for image tokens token_ids.append(sample.token_ids[start_pos:im_position]) + text_tokens_added += len(token_ids[-1]) token_ids.append(np.full((image_sizes[idx],), -100, dtype=np.int64)) image_positions.append(im_position + len(token_ids) + image_tokens_added) image_tokens_added += image_tokens start_pos = im_position token_ids.append(sample.token_ids[start_pos:]) + text_tokens_added += len(token_ids[-1]) if sample.images: images.append(sample.images) else: From 3c22ddafc27e02a6f5af31ad7022a6d315cb3f03 Mon Sep 17 00:00:00 2001 From: root Date: Sat, 17 May 2025 17:45:04 +0000 Subject: [PATCH 44/82] cleanup --- .../layers/transformer/vision_transformer.py | 25 ++----------------- 1 file changed, 2 insertions(+), 23 deletions(-) diff --git a/fast_llm/layers/transformer/vision_transformer.py b/fast_llm/layers/transformer/vision_transformer.py index 3588956c..72bd95dd 100644 --- a/fast_llm/layers/transformer/vision_transformer.py +++ b/fast_llm/layers/transformer/vision_transformer.py @@ -1,33 +1,12 @@ import torch -from fast_llm.engine.config_utils.tensor_space import TensorDim, TensorSpace -from fast_llm.layers.transformer.config import TransformerConfig +from fast_llm.engine.config_utils.tensor_space import TensorDim from fast_llm.layers.transformer.transformer import TransformerLayer -from fast_llm.layers.vision_encoder.config import VisionTransformerDimNames, VisionTransformerKwargs +from fast_llm.layers.vision_encoder.config import VisionTransformerKwargs from fast_llm.tensor import TensorMeta class VisionTransformerLayer(TransformerLayer): - """ - A vision transformer layer to encode image patches - """ - - def __init__( - self, - config: TransformerConfig, - tensor_space: TensorSpace, - layer_index: int, - return_input: bool = False, - ): - super().__init__(config, tensor_space, layer_index, return_input) - - hidden_dim = self._tensor_space.get_tensor_dim(VisionTransformerDimNames.hidden) - self.norm_1 = self._config.normalization.get_layer(hidden_dim) - self.norm_2 = self._config.normalization.get_layer(hidden_dim) - - self.norm_1 = self._config.peft.apply_other(self.norm_1) - self.norm_2 = self._config.peft.apply_other(self.norm_2) - @property def name(self) -> str: return f"Vision transformer layer {self._layer_index}" From f0c8d830da9c4ea43df478a6cafbbb48bf910111 Mon Sep 17 00:00:00 2001 From: root Date: Tue, 20 May 2025 07:05:01 +0000 Subject: [PATCH 45/82] cleanup --- fast_llm/layers/language_model/config.py | 1 - fast_llm/layers/transformer/attention.py | 17 +- fast_llm/layers/transformer/config.py | 259 ++++++++++++------ fast_llm/layers/transformer/mlp.py | 17 +- fast_llm/layers/transformer/preprocessing.py | 58 ++-- fast_llm/layers/transformer/transformer.py | 24 +- 
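Patch 43 above starts tracking text_tokens_added, which points at the underlying issue: len(token_ids) counts appended array segments rather than tokens, so positions recorded as im_position + len(token_ids) + image_tokens_added drift as segments accumulate. A standalone sketch that does the bookkeeping purely in token counts; the helper name and inputs are hypothetical:

    import numpy as np


    def insert_image_placeholders(
        token_ids: np.ndarray, image_positions: list[int], image_sizes: list[int]
    ) -> tuple[np.ndarray, list[int]]:
        """Insert -100 placeholder runs and return each run's start offset in the output."""
        pieces, out_positions = [], []
        start, tokens_out = 0, 0
        for position, size in zip(image_positions, image_sizes):
            pieces.append(token_ids[start:position])
            tokens_out += position - start
            out_positions.append(tokens_out)  # the placeholder run begins here
            pieces.append(np.full(size, -100, dtype=token_ids.dtype))
            tokens_out += size
            start = position
        pieces.append(token_ids[start:])
        return np.concatenate(pieces), out_positions


    ids, positions = insert_image_placeholders(np.arange(10), [4, 7], [3, 2])
    print(positions, ids.size)  # [4, 10] 15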
.../layers/transformer/vision_transformer.py | 12 +- fast_llm/layers/vision_encoder/config.py | 60 +--- fast_llm/models/gpt/model.py | 19 +- fast_llm/utils.py | 7 + 10 files changed, 239 insertions(+), 235 deletions(-) diff --git a/fast_llm/layers/language_model/config.py b/fast_llm/layers/language_model/config.py index e46e104c..cdb27d9e 100644 --- a/fast_llm/layers/language_model/config.py +++ b/fast_llm/layers/language_model/config.py @@ -46,7 +46,6 @@ class LanguageModelBaseConfig(BaseModelConfig): desc="Configuration for the transformer architecture.", hint=FieldHint.architecture, ) - # TODO Soham: make this None by default. Need to figure out how to handle this in the config (see ) vision_encoder: VisionEncoderConfig = Field( default_factory=VisionEncoderConfig, desc="Configuration for the vision encoder that transforms images into embeddings.", diff --git a/fast_llm/layers/transformer/attention.py b/fast_llm/layers/transformer/attention.py index b16f1740..3180b6cb 100644 --- a/fast_llm/layers/transformer/attention.py +++ b/fast_llm/layers/transformer/attention.py @@ -9,14 +9,7 @@ from fast_llm.functional.rotary import apply_rotary_embeddings from fast_llm.functional.triton.rotary import triton_rotary_autograd_ from fast_llm.layers.common.linear import InputParallelLinear, OutputParallelLinear -from fast_llm.layers.transformer.config import ( - TransformerConfig, - TransformerDimNames, - TransformerKwargs, - TransformerSubLayerName, - VisionTransformerConfig, -) -from fast_llm.layers.vision_encoder.config import VisionTransformerDimNames, VisionTransformerKwargs +from fast_llm.layers.transformer.config import TransformerConfig, TransformerKwargs, TransformerSubLayerName from fast_llm.logging import log_distributed_grad, log_distributed_tensor from fast_llm.tensor import TensorMeta, init_normal_, init_zeros_ from fast_llm.utils import Assert @@ -66,12 +59,8 @@ def __init__( layer_index, ): super().__init__() - if isinstance(config, VisionTransformerConfig): - self._transformer_dim_names = VisionTransformerDimNames - self._transformer_kwargs = VisionTransformerKwargs - elif isinstance(config, TransformerConfig): - self._transformer_dim_names = TransformerDimNames - self._transformer_kwargs = TransformerKwargs + self._transformer_dim_names = config._transformer_dim_names + self._transformer_kwargs = config._transformer_kwargs self._config = config self._tensor_space = tensor_space # TODO Soham: fix assert diff --git a/fast_llm/layers/transformer/config.py b/fast_llm/layers/transformer/config.py index 38dc9ec4..9a6bec07 100644 --- a/fast_llm/layers/transformer/config.py +++ b/fast_llm/layers/transformer/config.py @@ -28,60 +28,109 @@ class RoutingType(str, enum.Enum): sinkhorn = "sinkhorn" -class TransformerDimNames: - # A set of common tensor dim names packed into a namespace. - # Input dimensions (variable) - # TODO: Does batch belong here? - batch = "batch" - # TODO: Distinguish micro-sequence? 
- sequence_q = "sequence_q" - sequence_q_tp = "sequence_q_tp" - sequence_k = "sequence_k" - hidden = "hidden" - # Self-attention dimensions - head_groups = "head_groups" - group_heads = "group_heads" - key_and_value = "key_value" - kv_channels = "kv_channels" - composite_heads = "composite_heads" - composite_query = "composite_query" - composite_key_value = "composite_key_value" - composite_dense = "composite_dense" - # MLP dimensions - mlp = "mlp" - gate_and_up = "gate_and_up" - composite_gated_mlp = "composite_gated_mlp" - experts = "experts" - top_experts = "top_experts" - shared_experts = "shared_experts" - unshared_experts = "unshared_experts" - composite_expert_mlp = "composite_expert_mlp" - composite_gated_expert_mlp = "composite_gated_expert_mlp" - composite_shared_expert_mlp = "composite_shared_expert_mlp" - composite_gated_shared_expert_mlp = "composite_gated_shared_expert_mlp" - - -class TransformerKwargs: - rotary_freq_q = "rotary_freq_q" - rotary_freq_k = "rotary_freq_k" - attention_mask = "attention_mask" - attention_mask_value = "attention_mask_value" - sequence_lengths = "sequence_lengths" - cu_seqlens_q = "cu_seqlens_q" - cu_seqlens_k = "cu_seqlens_k" - max_seqlen_q = "max_seqlen_q" - max_seqlen_k = "max_seqlen_k" - # TODO: Review these - presents = "presents" - past_key_values = "past_key_values" - sequence_first = "sequence_first" - hidden_dims = "hidden_dims" - sequence_q_dim = "sequence_q_dim" - sequence_k_dim = "sequence_k_dim" - sequence_length = "sequence_length" - micro_batch_size = "micro_batch_size" - # TODO: Move - grad_output = "grad_output" +class BaseTransformerDimNames: + _kwargs_attributes = { + "batch": "batch", + "sequence_q": "sequence_q", + "sequence_q_tp": "sequence_q_tp", + "sequence_k": "sequence_k", + "hidden": "hidden", + "head_groups": "head_groups", + "group_heads": "group_heads", + "key_and_value": "key_value", + "kv_channels": "kv_channels", + "composite_heads": "composite_heads", + "composite_query": "composite_query", + "composite_key_value": "composite_key_value", + "composite_dense": "composite_dense", + "mlp": "mlp", + "gate_and_up": "gate_and_up", + "composite_gated_mlp": "composite_gated_mlp", + "experts": "experts", + "top_experts": "top_experts", + "shared_experts": "shared_experts", + "unshared_experts": "unshared_experts", + "composite_expert_mlp": "composite_expert_mlp", + "composite_gated_expert_mlp": "composite_gated_expert_mlp", + "composite_shared_expert_mlp": "composite_shared_expert_mlp", + "composite_gated_shared_expert_mlp": "composite_gated_shared_expert_mlp", + } + + def __init_subclass__(cls, prefix="", **kwargs): + super().__init_subclass__(**kwargs) + cls._prefix = prefix + for attr, value in BaseTransformerDimNames._kwargs_attributes.items(): + setattr(cls, value, f"{cls._prefix}_{value}") + + +class TransformerDimNames(BaseTransformerDimNames, prefix=""): + pass + + +class VisionTransformerDimNames(BaseTransformerDimNames, prefix="image_encoder"): + pass + + +class BaseTransformerKwargs: + _kwargs_attributes = { + "rotary_freq_q": "rotary_freq_q", + "rotary_freq_k": "rotary_freq_k", + "attention_mask": "attention_mask", + "attention_mask_value": "attention_mask_value", + "sequence_lengths": "sequence_lengths", + "cu_seqlens_q": "cu_seqlens_q", + "cu_seqlens_k": "cu_seqlens_k", + "max_seqlen_q": "max_seqlen_q", + "max_seqlen_k": "max_seqlen_k", + "presents": "presents", + "past_key_values": "past_key_values", + "sequence_first": "sequence_first", + "hidden_dims": "hidden_dims", + "sequence_q_dim": "sequence_q_dim", + 
"sequence_k_dim": "sequence_k_dim", + "sequence_length": "sequence_length", + "micro_batch_size": "micro_batch_size", + "grad_output": "grad_output", + } + + _prefix = "" + + def __init_subclass__(cls, prefix="", **kwargs): + super().__init_subclass__(**kwargs) + cls._prefix = prefix + for attr, value in BaseTransformerKwargs._kwargs_attributes.items(): + setattr(cls, value, f"{cls._prefix}_{value}") + + +class TransformerKwargs(BaseTransformerKwargs, prefix=""): + pass + + +class VisionTransformerKwargs(BaseTransformerKwargs, prefix="image_encoder"): + patch_position_ids = "patch_position_ids" + + +# class TransformerKwargs: +# rotary_freq_q = "rotary_freq_q" +# rotary_freq_k = "rotary_freq_k" +# attention_mask = "attention_mask" +# attention_mask_value = "attention_mask_value" +# sequence_lengths = "sequence_lengths" +# cu_seqlens_q = "cu_seqlens_q" +# cu_seqlens_k = "cu_seqlens_k" +# max_seqlen_q = "max_seqlen_q" +# max_seqlen_k = "max_seqlen_k" +# # TODO: Review these +# presents = "presents" +# past_key_values = "past_key_values" +# sequence_first = "sequence_first" +# hidden_dims = "hidden_dims" +# sequence_q_dim = "sequence_q_dim" +# sequence_k_dim = "sequence_k_dim" +# sequence_length = "sequence_length" +# micro_batch_size = "micro_batch_size" +# # TODO: Move +# grad_output = "grad_output" class TransformerLossNames: @@ -98,6 +147,11 @@ class RotaryEmbeddingType(str, enum.Enum): pixtral = "pixtral" +class TransformerType(str, enum.Enum): + lm_decoder = "lm_decoder" + image_encoder = "image_encoder" + + @config_class() class RotaryConfig(BaseModelConfig): _abstract = False @@ -160,6 +214,14 @@ def _validate(self) -> None: if self.triton and not TritonConfig.TRITON_ENABLED: warnings.warn("Triton is disabled, but the triton rotary kernel will be used anyway.") + @property + def _transformer_dim_names(self) -> TransformerDimNames: + return TransformerDimNames + + @property + def _transformer_kwargs(self) -> TransformerKwargs: + return TransformerKwargs + @config_class() class VisionRotaryConfig(RotaryConfig): @@ -169,6 +231,14 @@ class VisionRotaryConfig(RotaryConfig): hint=FieldHint.feature, ) + @property + def _transformer_dim_names(self) -> VisionTransformerDimNames: + return VisionTransformerDimNames + + @property + def _transformer_kwargs(self) -> VisionTransformerKwargs: + return VisionTransformerKwargs + class AddLinearBiasChoices(str, enum.Enum): nowhere = "nowhere" @@ -259,6 +329,11 @@ def _validate(self) -> None: @config_class() class TransformerConfig(BaseModelConfig): _abstract = False + transformer_type: TransformerType = Field( + default=TransformerType.lm_decoder, + desc="Type of the transformer. Choices: lm_decoder, image_encoder.", + hint=FieldHint.architecture, + ) normalization: NormalizationConfig = Field( default_factory=NormalizationConfig, desc="Configuration for the normalization layers architecture.", @@ -658,72 +733,71 @@ def _from_dict( cls._handle_renamed_field(default, "triton_rotary", ("rotary", "triton")) return super()._from_dict(default, strict, flat) - def setup_tensor_space(self, tensor_space: TensorSpace, type: str | None = None) -> None: - if type == "vision": - # TODO Soham: better way to get around circular imports? Maybe add a type class variable to TransformerConfig? 
- from fast_llm.layers.vision_encoder.config import VisionTransformerDimNames - - transformer_dim_names = VisionTransformerDimNames - else: - transformer_dim_names = TransformerDimNames + def setup_tensor_space(self, tensor_space: TensorSpace) -> None: tensor = tensor_space.distributed_config.get_distributed_dim(DistributedDimNames.tensor) # Hidden dimension - tensor_space.add_tensor_dim(TensorDim(transformer_dim_names.hidden, self.hidden_size)) + tensor_space.add_tensor_dim(TensorDim(self.transformer_dim_names.hidden, self.hidden_size)) # Self-attention dimensions tensor_space.add_tensor_dim( head_groups := TensorDim( - transformer_dim_names.head_groups, self.head_groups, tensor if self.head_groups > 1 else None + self.transformer_dim_names.head_groups, self.head_groups, tensor if self.head_groups > 1 else None ) ) tensor_space.add_tensor_dim( group_heads := TensorDim( - transformer_dim_names.group_heads, + self.transformer_dim_names.group_heads, div(self.num_attention_heads, self.head_groups), None if self.head_groups > 1 else tensor, ) ) - tensor_space.add_tensor_dim(key_and_value := TensorDim(transformer_dim_names.key_and_value, 2)) - tensor_space.add_tensor_dim(kv_channels := TensorDim(transformer_dim_names.kv_channels, self.kv_channels)) + tensor_space.add_tensor_dim(key_and_value := TensorDim(self.transformer_dim_names.key_and_value, 2)) + tensor_space.add_tensor_dim(kv_channels := TensorDim(self.transformer_dim_names.kv_channels, self.kv_channels)) tensor_space.add_tensor_dim( - CompositeTensorDim(transformer_dim_names.composite_heads, (head_groups, group_heads)) + CompositeTensorDim(self.transformer_dim_names.composite_heads, (head_groups, group_heads)) ) tensor_space.add_tensor_dim( - CompositeTensorDim(transformer_dim_names.composite_query, (head_groups, group_heads, kv_channels)) + CompositeTensorDim(self.transformer_dim_names.composite_query, (head_groups, group_heads, kv_channels)) ) tensor_space.add_tensor_dim( - CompositeTensorDim(transformer_dim_names.composite_key_value, (key_and_value, head_groups, kv_channels)) + CompositeTensorDim( + self.transformer_dim_names.composite_key_value, (key_and_value, head_groups, kv_channels) + ) ) tensor_space.add_tensor_dim( - CompositeTensorDim(transformer_dim_names.composite_dense, (head_groups, group_heads, kv_channels)) + CompositeTensorDim(self.transformer_dim_names.composite_dense, (head_groups, group_heads, kv_channels)) ) # MLP dimensions - tensor_space.add_tensor_dim(mlp := TensorDim(transformer_dim_names.mlp, self.ffn_hidden_size, tensor)) + tensor_space.add_tensor_dim(mlp := TensorDim(self.transformer_dim_names.mlp, self.ffn_hidden_size, tensor)) tensor_space.add_tensor_dim( - gate_and_up := TensorDim(transformer_dim_names.gate_and_up, 2 if self.gated else 1) + gate_and_up := TensorDim(self.transformer_dim_names.gate_and_up, 2 if self.gated else 1) ) - tensor_space.add_tensor_dim(CompositeTensorDim(transformer_dim_names.composite_gated_mlp, (gate_and_up, mlp))) - tensor_space.add_tensor_dim(experts := TensorDim(transformer_dim_names.experts, self.num_experts)) - tensor_space.add_tensor_dim(CompositeTensorDim(transformer_dim_names.composite_expert_mlp, (experts, mlp))) tensor_space.add_tensor_dim( - CompositeTensorDim(transformer_dim_names.composite_gated_expert_mlp, (experts, gate_and_up, mlp)) + CompositeTensorDim(self.transformer_dim_names.composite_gated_mlp, (gate_and_up, mlp)) ) - tensor_space.add_tensor_dim(TensorDim(transformer_dim_names.top_experts, self.num_experts_per_token)) - 
tensor_space.add_tensor_dim(TensorDim(transformer_dim_names.unshared_experts, self.num_unshared_experts)) + tensor_space.add_tensor_dim(experts := TensorDim(self.transformer_dim_names.experts, self.num_experts)) + tensor_space.add_tensor_dim( + CompositeTensorDim(self.transformer_dim_names.composite_expert_mlp, (experts, mlp)) + ) + tensor_space.add_tensor_dim( + CompositeTensorDim(self.transformer_dim_names.composite_gated_expert_mlp, (experts, gate_and_up, mlp)) + ) + tensor_space.add_tensor_dim(TensorDim(self.transformer_dim_names.top_experts, self.num_experts_per_token)) + tensor_space.add_tensor_dim(TensorDim(self.transformer_dim_names.unshared_experts, self.num_unshared_experts)) # shared_experts if self.num_shared_experts: tensor_space.add_tensor_dim( - shared_experts := TensorDim(transformer_dim_names.shared_experts, self.num_shared_experts) + shared_experts := TensorDim(self.transformer_dim_names.shared_experts, self.num_shared_experts) ) tensor_space.add_tensor_dim( - CompositeTensorDim(transformer_dim_names.composite_shared_expert_mlp, (shared_experts, mlp)) + CompositeTensorDim(self.transformer_dim_names.composite_shared_expert_mlp, (shared_experts, mlp)) ) tensor_space.add_tensor_dim( CompositeTensorDim( - transformer_dim_names.composite_gated_shared_expert_mlp, (shared_experts, gate_and_up, mlp) + self.transformer_dim_names.composite_gated_shared_expert_mlp, (shared_experts, gate_and_up, mlp) ) ) @@ -739,6 +813,14 @@ def do_use_flash_attention(self, distributed_config: DistributedConfig) -> bool: return use_flash_attention + @property + def _transformer_kwargs(self) -> TransformerKwargs: + return TransformerKwargs + + @property + def _transformer_dim_names(self) -> TransformerDimNames: + return TransformerDimNames + @config_class() class VisionRotaryConfig(RotaryConfig): @@ -755,6 +837,11 @@ class VisionTransformerConfig(TransformerConfig): Configuration for the Vision Transformer (ViT) model. """ + transformer_type: TransformerType = FieldUpdate( + default=TransformerType.image_encoder, + desc="Type of the transformer. Choices: lm_decoder, image_encoder.", + hint=FieldHint.architecture, + ) causal: bool = FieldUpdate( default=False, desc="Use causal attention. 
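One inconsistency worth flagging in the hunk above: the rewritten setup_tensor_space reads self.transformer_dim_names, while the properties added at the bottom of the class are named _transformer_dim_names and _transformer_kwargs, so one of the two spellings presumably has to change before this runs. The indirection itself is straightforward; a self-contained sketch:

    class DimNames:
        hidden = "hidden"


    class VisionDimNames:
        hidden = "image_encoder_hidden"


    class TransformerConfigSketch:
        # The property name must match what setup_tensor_space actually calls.
        @property
        def transformer_dim_names(self) -> type:
            return DimNames


    class VisionTransformerConfigSketch(TransformerConfigSketch):
        @property
        def transformer_dim_names(self) -> type:
            return VisionDimNames


    print(VisionTransformerConfigSketch().transformer_dim_names.hidden)  # image_encoder_hidden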
Turn this off only for bidirectional attention e.g., in Vision Transformer.", @@ -765,3 +852,11 @@ class VisionTransformerConfig(TransformerConfig): desc="Configuration for the rotary positional embeddings.", hint=FieldHint.feature, ) + + @property + def _transformer_kwargs(self) -> VisionTransformerKwargs: + return VisionTransformerKwargs + + @property + def _transformer_dim_names(self) -> VisionTransformerDimNames: + return VisionTransformerDimNames diff --git a/fast_llm/layers/transformer/mlp.py b/fast_llm/layers/transformer/mlp.py index dcea463a..42393a41 100644 --- a/fast_llm/layers/transformer/mlp.py +++ b/fast_llm/layers/transformer/mlp.py @@ -8,14 +8,7 @@ from fast_llm.functional.config import TritonConfig from fast_llm.functional.triton.mlp import mlp_autograd, torch_mlp_activation, triton_mlp_activation_autograd from fast_llm.layers.common.linear import LinearBase -from fast_llm.layers.transformer.config import ( - TransformerConfig, - TransformerDimNames, - TransformerKwargs, - TransformerSubLayerName, - VisionTransformerConfig, -) -from fast_llm.layers.vision_encoder.config import VisionTransformerDimNames, VisionTransformerKwargs +from fast_llm.layers.transformer.config import TransformerConfig, TransformerSubLayerName from fast_llm.tensor import init_normal_, init_zeros_ from fast_llm.utils import Assert @@ -25,12 +18,8 @@ def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, name: s super().__init__() self._name = name - if isinstance(config, VisionTransformerConfig): - self._transformer_dim_names = VisionTransformerDimNames - self._transformer_kwargs = VisionTransformerKwargs - elif isinstance(config, TransformerConfig): - self._transformer_dim_names = TransformerDimNames - self._transformer_kwargs = TransformerKwargs + self._transformer_dim_names = config._transformer_dim_names + self._transformer_kwargs = config._transformer_kwargs init_method_1 = init_normal_( std=config.init_method_std_mlp_1, diff --git a/fast_llm/layers/transformer/preprocessing.py b/fast_llm/layers/transformer/preprocessing.py index 870463df..97c6c0f3 100644 --- a/fast_llm/layers/transformer/preprocessing.py +++ b/fast_llm/layers/transformer/preprocessing.py @@ -7,19 +7,8 @@ from fast_llm.engine.base_model.config import Preprocessor from fast_llm.engine.config_utils.tensor_space import DefaultDimNames, TensorDim, TensorSpace from fast_llm.functional.rotary import convert_rotary_complex_to_real -from fast_llm.layers.transformer.config import ( - RotaryConfig, - RotaryEmbeddingType, - TransformerConfig, - TransformerDimNames, - TransformerKwargs, - VisionTransformerConfig, -) -from fast_llm.layers.vision_encoder.config import ( - VisionEncoderKwargs, - VisionTransformerDimNames, - VisionTransformerKwargs, -) +from fast_llm.layers.transformer.config import RotaryConfig, RotaryEmbeddingType, TransformerConfig, TransformerKwargs +from fast_llm.layers.vision_encoder.config import VisionEncoderKwargs from fast_llm.tensor import TensorMeta logger = logging.getLogger(__name__) @@ -178,19 +167,8 @@ def __init__( config: RotaryConfig, tensor_space: TensorSpace, ): - # if isinstance(config, TransformerConfig): - # self._transformer_dim_names = TransformerDimNames - # self._transformer_kwargs = TransformerKwargs - # elif isinstance(config, VisionTransformerConfig): - # self._transformer_dim_names = VisionTransformerDimNames - # self._transformer_kwargs = VisionTransformerKwargs - # TODO Soham: better way to do this? 
- if config.type == RotaryEmbeddingType.pixtral: - self._transformer_dim_names = VisionTransformerDimNames - self._transformer_kwargs = VisionTransformerKwargs - else: - self._transformer_dim_names = TransformerDimNames - self._transformer_kwargs = TransformerKwargs + self._transformer_dim_names = config._transformer_dim_names + self._transformer_kwargs = config._transformer_kwargs self._config = config assert self._config.enabled self._tensor_space = tensor_space @@ -273,12 +251,14 @@ def __init__( config: TransformerConfig, tensor_space: TensorSpace, ): - if isinstance(config, VisionTransformerConfig): - self._transformer_dim_names = VisionTransformerDimNames - self._transformer_kwargs = VisionTransformerKwargs - elif isinstance(config, TransformerConfig): - self._transformer_dim_names = TransformerDimNames - self._transformer_kwargs = TransformerKwargs + # if isinstance(config, VisionTransformerConfig): + # self._transformer_dim_names = VisionTransformerDimNames + # self._transformer_kwargs = VisionTransformerKwargs + # elif isinstance(config, TransformerConfig): + # self._transformer_dim_names = TransformerDimNames + # self._transformer_kwargs = TransformerKwargs + self._transformer_dim_names = config._transformer_dim_names + self._transformer_kwargs = config._transformer_kwargs self._config = config self._tensor_space = tensor_space self._distributed_config = self._tensor_space.distributed_config @@ -348,12 +328,14 @@ def __init__(self, config: TransformerConfig, tensor_space: TensorSpace): self._tensor_space = tensor_space self._distributed_config = self._tensor_space.distributed_config assert self._config.do_use_flash_attention(self._distributed_config) - if isinstance(config, VisionTransformerConfig): - self._transformer_dim_names = VisionTransformerDimNames - self._transformer_kwargs = VisionTransformerKwargs - elif isinstance(config, TransformerConfig): - self._transformer_dim_names = TransformerDimNames - self._transformer_kwargs = TransformerKwargs + # if isinstance(config, VisionTransformerConfig): + # self._transformer_dim_names = VisionTransformerDimNames + # self._transformer_kwargs = VisionTransformerKwargs + # elif isinstance(config, TransformerConfig): + # self._transformer_dim_names = TransformerDimNames + # self._transformer_kwargs = TransformerKwargs + self._transformer_dim_names = config._transformer_dim_names + self._transformer_kwargs = config._transformer_kwargs def preprocess(self, batch, kwargs: dict[str, typing.Any]) -> None: """ diff --git a/fast_llm/layers/transformer/transformer.py b/fast_llm/layers/transformer/transformer.py index 5590be32..8bd1394e 100644 --- a/fast_llm/layers/transformer/transformer.py +++ b/fast_llm/layers/transformer/transformer.py @@ -9,15 +9,9 @@ from fast_llm.engine.config_utils.run import log_pipeline_parallel_main_rank from fast_llm.engine.config_utils.tensor_space import TensorDim, TensorSpace from fast_llm.layers.transformer.attention import Attention -from fast_llm.layers.transformer.config import ( - TransformerConfig, - TransformerDimNames, - TransformerKwargs, - VisionTransformerConfig, -) +from fast_llm.layers.transformer.config import TransformerConfig, TransformerDimNames, TransformerKwargs from fast_llm.layers.transformer.mixture_of_experts import MixtureOfExpertMLP from fast_llm.layers.transformer.mlp import MLP -from fast_llm.layers.vision_encoder.config import VisionTransformerDimNames, VisionTransformerKwargs from fast_llm.logging import log_distributed_grad, log_distributed_tensor, log_memory_usage from 
fast_llm.tensor import TensorMeta @@ -35,12 +29,8 @@ def __init__( self, config: TransformerConfig, tensor_space: TensorSpace, layer_index: int, return_input: bool = False ): super().__init__() - if isinstance(config, VisionTransformerConfig): - self._transformer_dim_names = VisionTransformerDimNames - self._transformer_kwargs = VisionTransformerKwargs - elif isinstance(config, TransformerConfig): - self._transformer_dim_names = TransformerDimNames - self._transformer_kwargs = TransformerKwargs + self._transformer_dim_names = config._transformer_dim_names + self._transformer_kwargs = config._transformer_kwargs self._config: TransformerConfig = config self._tensor_space: TensorSpace = tensor_space self._dropout_p: float = self._config.hidden_dropout @@ -80,6 +70,14 @@ def _bias_dropout_add( def name(self) -> str: return f"{self._name} {self._layer_index}" + @property + def _transformer_kwargs(self) -> TransformerKwargs: + return TransformerKwargs + + @property + def _transformer_dim_names(self) -> TransformerDimNames: + return TransformerDimNames + def _get_meta(self, tensor: torch.Tensor, name: str, kwargs: dict): dims = kwargs[self._transformer_kwargs.hidden_dims] if self._return_input: diff --git a/fast_llm/layers/transformer/vision_transformer.py b/fast_llm/layers/transformer/vision_transformer.py index 72bd95dd..7f39f9cf 100644 --- a/fast_llm/layers/transformer/vision_transformer.py +++ b/fast_llm/layers/transformer/vision_transformer.py @@ -2,14 +2,20 @@ from fast_llm.engine.config_utils.tensor_space import TensorDim from fast_llm.layers.transformer.transformer import TransformerLayer -from fast_llm.layers.vision_encoder.config import VisionTransformerKwargs +from fast_llm.layers.vision_encoder.config import VisionTransformerDimNames, VisionTransformerKwargs from fast_llm.tensor import TensorMeta class VisionTransformerLayer(TransformerLayer): + _name: str = "Vision transformer layer" + + @property + def _transformer_kwargs(self) -> VisionTransformerKwargs: + return VisionTransformerKwargs + @property - def name(self) -> str: - return f"Vision transformer layer {self._layer_index}" + def _transformer_dim_names(self) -> VisionTransformerDimNames: + return VisionTransformerDimNames def _get_meta(self, tensor: torch.Tensor, name: str, kwargs: dict): dims = kwargs[VisionTransformerKwargs.hidden_dims] diff --git a/fast_llm/layers/vision_encoder/config.py b/fast_llm/layers/vision_encoder/config.py index fdbe2726..70504901 100644 --- a/fast_llm/layers/vision_encoder/config.py +++ b/fast_llm/layers/vision_encoder/config.py @@ -16,39 +16,6 @@ class VisionEncoderDimNames: kv_channels = "vision_kv_channels" -class VisionTransformerDimNames: - # A set of common tensor dim names packed into a namespace. - # Input dimensions (variable) - # TODO: Does batch belong here? - batch = "vit_batch" - # TODO: Distinguish micro-sequence? 
- sequence_q = "vit_sequence_q" - sequence_q_tp = "vit_sequence_q_tp" - sequence_k = "vit_sequence_k" - hidden = "vit_hidden" - # Self-attention dimensions - head_groups = "vit_head_groups" - group_heads = "vit_group_heads" - key_and_value = "vit_key_value" - kv_channels = "vit_kv_channels" - composite_heads = "vit_composite_heads" - composite_query = "vit_composite_query" - composite_key_value = "vit_composite_key_value" - composite_dense = "vit_composite_dense" - # MLP dimensions - mlp = "vit_mlp" - gate_and_up = "vit_gate_and_up" - composite_gated_mlp = "vit_composite_gated_mlp" - experts = "vit_experts" - top_experts = "vit_top_experts" - shared_experts = "vit_shared_experts" - unshared_experts = "vit_unshared_experts" - composite_expert_mlp = "vit_composite_expert_mlp" - composite_gated_expert_mlp = "vit_composite_gated_expert_mlp" - composite_shared_expert_mlp = "vit_composite_shared_expert_mlp" - composite_gated_shared_expert_mlp = "vit_composite_gated_shared_expert_mlp" - - class VisionEncoderKwargs: patch_size = "patch_size" images = "images" @@ -69,31 +36,6 @@ class VisionEncoderKwargs: out_channels = "vit_out_channels" -# TODO Soham: do we need all of them? -class VisionTransformerKwargs: - rotary_freq_q = "vit_rotary_freq_q" - rotary_freq_k = "vit_rotary_freq_k" - attention_mask = "vit_attention_mask" - attention_mask_value = "vit_attention_mask_value" - sequence_lengths = "vit_sequence_lengths" - cu_seqlens_q = "vit_cu_seqlens_q" - cu_seqlens_k = "vit_cu_seqlens_k" - max_seqlen_q = "vit_max_seqlen_q" - max_seqlen_k = "vit_max_seqlen_k" - # TODO: Review these - presents = "vit_presents" - past_key_values = "vit_past_key_values" - sequence_first = "vit_sequence_first" - hidden_dims = "vit_hidden_dims" - sequence_q_dim = "vit_sequence_q_dim" - sequence_k_dim = "vit_sequence_k_dim" - sequence_length = "vit_sequence_length" - micro_batch_size = "vit_micro_batch_size" - # TODO: Move - grad_output = "vit_grad_output" - patch_position_ids = "patch_position_ids" - - @config_class() class ImageNormalizationConfig(Config): mean_r: float = Field( @@ -194,7 +136,7 @@ def setup_tensor_space(self, tensor_space: TensorSpace): VisionEncoderDimNames.kv_channels, self.transformer.hidden_size // self.transformer.num_attention_heads ) ) - self.transformer.setup_tensor_space(tensor_space, type="vision") + self.transformer.setup_tensor_space(tensor_space) @property def enabled(self) -> bool: diff --git a/fast_llm/models/gpt/model.py b/fast_llm/models/gpt/model.py index 9fff50bc..c1d9df90 100644 --- a/fast_llm/models/gpt/model.py +++ b/fast_llm/models/gpt/model.py @@ -433,17 +433,18 @@ def preprocess( @property def embedding(self) -> LanguageModelEmbedding: - if self._config.vision_encoder.enabled: - return self.layers[self._config.vision_encoder.transformer.num_layers + 2] - else: - return self.layers[0] + return self.layers[self.embedding_layer_index] @property def transformer_layers(self) -> list[TransformerLayer]: + return self.layers[self.embedding_layer_index + 1 : -1] + + @property + def embedding_layer_index(self) -> int: if self._config.vision_encoder.enabled: - return self.layers[self._config.vision_encoder.transformer.num_layers + 3 : -1] + return self._config.vision_encoder.transformer.num_layers + 2 else: - return self.layers[1:-1] + return 0 @property def model_head(self) -> LanguageModelHead: @@ -458,11 +459,7 @@ def get_tied_weights(self) -> dict[str, tuple[ParameterMeta, tuple[int, ...]]]: return { WORD_EMBEDDINGS_WEIGHT: ( self.embedding.word_embeddings_weight, - # TODO Soham: make 
embedding layer index a property - ( - self._config.vision_encoder.enabled * (self._config.vision_encoder.transformer.num_layers + 2), - *self.model_head_indices, - ), + (self.embedding_layer_index, *self.model_head_indices), ) } elif self._config.prediction_heads > 1: diff --git a/fast_llm/utils.py b/fast_llm/utils.py index 51e0eee5..c5b7f07a 100644 --- a/fast_llm/utils.py +++ b/fast_llm/utils.py @@ -336,3 +336,10 @@ def compare_nested(config_a, config_b, errors: list | None = None, prefix: tuple def check_equal_nested(config_a, config_b): if errors := compare_nested(config_a, config_b): raise ValueError("\n".join(errors)) + + +def prefix_class_vars(cls, prefix: str, base_cls: type): + for attr, value in vars(base_cls).items(): + if not attr.startswith("__") and isinstance(value, str) and not hasattr(cls, attr): + setattr(cls, attr, prefix + value) + return cls From 1a20913589ebd48f6d50ad324a7eacd46868a0e4 Mon Sep 17 00:00:00 2001 From: Toby Liang Date: Tue, 20 May 2025 23:13:53 +0000 Subject: [PATCH 46/82] layer changes --- fast_llm/data/data/gpt/data.py | 2 +- fast_llm/data/dataset/gpt/memmap.py | 10 +- fast_llm/data/dataset/gpt/sampled.py | 16 +- fast_llm/data/dataset/monitor.py | 22 +- .../data/preparator/gpt_memmap/prepare.py | 6 +- fast_llm/layers/audio_encoder/adapter.py | 54 ++++ fast_llm/layers/audio_encoder/config.py | 143 +++++++++++ fast_llm/layers/audio_encoder/encoder.py | 61 +++++ .../layers/audio_encoder/preprocessing.py | 47 ++++ fast_llm/layers/language_model/config.py | 4 +- .../layers/transformer/audio_transformer.py | 41 +++ fast_llm/layers/transformer/config.py | 20 +- fast_llm/models/gpt/config.py | 8 + fast_llm/models/gpt/conversion.py | 239 ++++++++++++++++++ fast_llm/models/gpt/model.py | 35 ++- 15 files changed, 680 insertions(+), 28 deletions(-) create mode 100644 fast_llm/layers/audio_encoder/adapter.py create mode 100644 fast_llm/layers/audio_encoder/config.py create mode 100644 fast_llm/layers/audio_encoder/encoder.py create mode 100644 fast_llm/layers/audio_encoder/preprocessing.py create mode 100644 fast_llm/layers/transformer/audio_transformer.py diff --git a/fast_llm/data/data/gpt/data.py b/fast_llm/data/data/gpt/data.py index 0e43ec2b..028f008a 100644 --- a/fast_llm/data/data/gpt/data.py +++ b/fast_llm/data/data/gpt/data.py @@ -65,7 +65,7 @@ def gpt_data_collate_fn(batch: list[GPTSample], sampling_parameters: GPTSampling batch_audio = [] for sample in batch: if sample.audio is not None and len(sample.audio_positions) > 0: - batch_audio.append([torch.from_numpy(image) for image in sample.audio]) + batch_audio.append([torch.from_numpy(audio) for audio in sample.audio]) has_audio = True else: batch_audio.append(None) diff --git a/fast_llm/data/dataset/gpt/memmap.py b/fast_llm/data/dataset/gpt/memmap.py index 51fd4cc2..9ea39707 100644 --- a/fast_llm/data/dataset/gpt/memmap.py +++ b/fast_llm/data/dataset/gpt/memmap.py @@ -261,11 +261,17 @@ def get( images.append(pixels[start : start + n_pixels].reshape(3, image_length[0], image_length[1])) start += n_pixels - audio = None + audio = [] audio_positions = None if self._has_audio: audio_positions = self._audio_positions[idx] - offset = self._pointers[idx] + self._document_sizes[idx] * np.dtype(self._dtype).itemsize + # increment offset by documents and images + offset = ( + self._pointers[idx] + + offset * np.dtype(self._dtype).itemsize + + self._document_sizes[idx] * np.dtype(self._dtype).itemsize + ) + if self._has_images and len(self._image_lengths) > 0: offset += self._image_lengths[idx].prod(initial=3) * 
np.dtype(np.uint8).itemsize all_audio = np.frombuffer( diff --git a/fast_llm/data/dataset/gpt/sampled.py b/fast_llm/data/dataset/gpt/sampled.py index 45ddeb86..b1cdf826 100644 --- a/fast_llm/data/dataset/gpt/sampled.py +++ b/fast_llm/data/dataset/gpt/sampled.py @@ -554,13 +554,19 @@ def __getitem__(self, index: int) -> typing.Any: # add tokens and multi modal padding placeholders multimodal_positions = np.concatenate( - [sample.image_positions.astype(np.int32), sample.audio_positions.astype(np.int32)] - ) + [ + arr.astype(np.int32) + for arr in (sample.image_positions, sample.audio_positions) + if arr is not None + ] + ) or np.array([], dtype=np.int32) multimodal_positions.sort() for idx, mm_position in enumerate(multimodal_positions): - if mm_position in sample.image_positions: # TODO Toby: use enum + if ( + sample.image_positions is not None and mm_position in sample.image_positions + ): # TODO Toby: use enum mm_type = "image" - elif mm_position in sample.audio_positions: + elif sample.audio_positions is not None and mm_position in sample.audio_positions: mm_type = "audio" else: assert False @@ -572,8 +578,8 @@ def __getitem__(self, index: int) -> typing.Any: image_positions.append(mm_position + len(token_ids) + mm_tokens_added) mm_tokens_added += image_tokens elif mm_type == "audio": + audio_positions.append(sum(t.size for t in token_ids)) token_ids.append(np.full((audio_token_size_arr[idx],), -100, dtype=np.int64)) - audio_positions.append(len(token_ids)) mm_tokens_added += audio_tokens start_pos = mm_position token_ids.append(sample.token_ids[start_pos:]) diff --git a/fast_llm/data/dataset/monitor.py b/fast_llm/data/dataset/monitor.py index 86bc080f..53df3add 100644 --- a/fast_llm/data/dataset/monitor.py +++ b/fast_llm/data/dataset/monitor.py @@ -35,18 +35,16 @@ def __len__(self) -> int: def __getitem__(self, idx) -> typing.Any: start_time = time.perf_counter() - try: - sample = self._dataset[idx] - sample_time = (time.perf_counter() - start_time) * 1000 - if sample_time > self._data_sample_warn_time_ms: - logger.warning( - f"Sample {idx} from dataset {self._dataset.name})" f" took {sample_time:,.2f} ms to load" - ) - return sample - - except Exception: - logger.error(f"Failed to get sample {idx} from dataset {self._dataset.name}") - raise + # try: + sample = self._dataset[idx] + sample_time = (time.perf_counter() - start_time) * 1000 + if sample_time > self._data_sample_warn_time_ms: + logger.warning(f"Sample {idx} from dataset {self._dataset.name})" f" took {sample_time:,.2f} ms to load") + return sample + + # except Exception as e: + # logger.error(f"Failed to get sample {idx} from dataset {self._dataset.name}") + # raise @property def name(self) -> str: diff --git a/fast_llm/data/preparator/gpt_memmap/prepare.py b/fast_llm/data/preparator/gpt_memmap/prepare.py index c697d54d..bffa6c83 100644 --- a/fast_llm/data/preparator/gpt_memmap/prepare.py +++ b/fast_llm/data/preparator/gpt_memmap/prepare.py @@ -150,7 +150,11 @@ def _document_generator(): # [np.array(im) for im in item["images"]] if self._config.dataset.images else None, item["images"] if self._config.dataset.images else None, item["image_positions"] if self._config.dataset.image_positions else None, - np.array(item[self._config.dataset.audio]) if self._config.dataset.audio else None, + ( + np.array(item[self._config.dataset.audio], dtype=np.float32) + if self._config.dataset.audio + else None + ), item[self._config.dataset.audio_positions] if self._config.dataset.audio_positions else None, ) # if "token_spans" in 
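The memmap get hunk above advances the byte offset past the token ids and any image pixels before reading a document's audio from the .bin buffer. A small worked example of that arithmetic (the dtype and sizes are made up; the image byte count sums 3 * height * width uint8 pixels over all images in the document):

    import numpy as np

    token_dtype = np.dtype(np.int32)         # dtype used for token ids (assumption)
    pointer = 1_000                          # document start within the .bin file
    document_size = 512                      # number of tokens in the document
    image_lengths = np.array([[224, 336]])   # (height, width) of each image

    offset = pointer + document_size * token_dtype.itemsize
    offset += int(image_lengths.prod(axis=1).sum()) * 3 * np.dtype(np.uint8).itemsize
    print(offset)  # 1000 + 2048 + 225792 = 228840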
shard_dataset.column_names and self._config.dataset.loss_masking_spans is not None: diff --git a/fast_llm/layers/audio_encoder/adapter.py b/fast_llm/layers/audio_encoder/adapter.py new file mode 100644 index 00000000..4f77971e --- /dev/null +++ b/fast_llm/layers/audio_encoder/adapter.py @@ -0,0 +1,54 @@ +import typing + +import torch + +from fast_llm.engine.base_model.base_model import Layer +from fast_llm.engine.config_utils.tensor_space import TensorSpace +from fast_llm.functional.triton.mlp import torch_mlp_activation +from fast_llm.layers.common.linear import Linear +from fast_llm.layers.transformer.config import TransformerDimNames, TransformerKwargs +from fast_llm.layers.vision_encoder.config import VisionEncoderConfig, VisionEncoderDimNames +from fast_llm.tensor import TensorMeta, init_normal_ + + +class AudioAdapter(Layer): + """ + Vision adapter layer for the LLM. + """ + + def __init__(self, config: VisionEncoderConfig, tensor_space: TensorSpace): + super().__init__() + input_dim = tensor_space.get_tensor_dim(VisionEncoderDimNames.out_channels) + self._activation_type = config.adapter_activation_type + # TODO Soham: Make them OutputParallelLinear instead? How would this work with parallelism? + self.layer_1 = Linear( + input_dim, + tensor_space.get_tensor_dim(VisionEncoderDimNames.adapter_size), + bias=True, + weight_init_method=init_normal_(), + bias_init_method=init_normal_(), + ) + self.layer_2 = Linear( + tensor_space.get_tensor_dim(VisionEncoderDimNames.adapter_size), + tensor_space.get_tensor_dim(TransformerDimNames.hidden), + bias=True, + weight_init_method=init_normal_(), + bias_init_method=init_normal_(), + ) + + def forward( + self, + input_: torch.Tensor, + kwargs: dict[str, typing.Any], + losses: dict[str, typing.Any] | None = None, + metrics: dict[str, typing.Any] | None = None, + ) -> torch.Tensor: + if isinstance(input_, TensorMeta): + return TensorMeta.from_dims( + kwargs[TransformerKwargs.hidden_dims], + tensor_name="Vision adapter output", + dtype=input_.dtype, + ) + return self.layer_2( + torch_mlp_activation(input_=self.layer_1(input_), gated=False, activation_type=self._activation_type) + ) diff --git a/fast_llm/layers/audio_encoder/config.py b/fast_llm/layers/audio_encoder/config.py new file mode 100644 index 00000000..52a8673e --- /dev/null +++ b/fast_llm/layers/audio_encoder/config.py @@ -0,0 +1,143 @@ +import enum + +from fast_llm.config import Field, FieldHint, config_class +from fast_llm.engine.base_model.config import BaseModelConfig +from fast_llm.engine.config_utils.tensor_space import TensorDim, TensorSpace +from fast_llm.functional.config import ActivationType +from fast_llm.layers.transformer.config import AudioTransformerConfig + + +class AudioEncoderDimNames: + in_channels = "audio_in_channels" + out_channels = "audio_out_channels" + kernel_size = "audio_kernel_size" + adapter_size = "audio_adapter_size" + audio_channels = "audio_kv_channels" + + +class AudioTransformerDimNames: + # A set of common tensor dim names packed into a namespace. + # Input dimensions (variable) + # TODO: Does batch belong here? + batch = "audio_batch" + # TODO: Distinguish micro-sequence? 
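The new AudioAdapter above still pulls its dimensions from VisionEncoderDimNames and keeps the "Vision adapter" docstring, which reads like a template left half-converted; functionally it is a two-layer MLP that projects encoder features into the LM hidden size. A plain torch sketch of that shape, with made-up dimensions:

    import torch
    from torch import nn


    class TwoLayerAdapter(nn.Module):
        """Project audio-encoder features into the language model's hidden size."""

        def __init__(self, encoder_dim: int, adapter_dim: int, lm_hidden_dim: int):
            super().__init__()
            self.layer_1 = nn.Linear(encoder_dim, adapter_dim)
            self.layer_2 = nn.Linear(adapter_dim, lm_hidden_dim)

        def forward(self, features: torch.Tensor) -> torch.Tensor:
            return self.layer_2(nn.functional.gelu(self.layer_1(features)))


    adapter = TwoLayerAdapter(encoder_dim=1280, adapter_dim=5120, lm_hidden_dim=4096)
    print(adapter(torch.randn(2, 750, 1280)).shape)  # torch.Size([2, 750, 4096])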
+ sequence_q = "audio_sequence_q" + sequence_q_tp = "audio_sequence_q_tp" + sequence_k = "audio_sequence_k" + hidden = "audio_hidden" + # Self-attention dimensions + head_groups = "audio_head_groups" + group_heads = "audio_group_heads" + key_and_value = "audio_key_value" + kv_channels = "audio_kv_channels" + composite_heads = "audio_composite_heads" + composite_query = "audio_composite_query" + composite_key_value = "audio_composite_key_value" + composite_dense = "audio_composite_dense" + # MLP dimensions + mlp = "audio_mlp" + gate_and_up = "audio_gate_and_up" + composite_gated_mlp = "audio_composite_gated_mlp" + experts = "audio_experts" + top_experts = "audio_top_experts" + shared_experts = "audio_shared_experts" + unshared_experts = "audio_unshared_experts" + composite_expert_mlp = "audio_composite_expert_mlp" + composite_gated_expert_mlp = "audio_composite_gated_expert_mlp" + composite_shared_expert_mlp = "audio_composite_shared_expert_mlp" + composite_gated_shared_expert_mlp = "audio_composite_gated_shared_expert_mlp" + + +class AudioEncoderKwargs: + audio = "audio" + audio_mel = "audio_mel" + audio_positions = "audio_positions" + kv_channels = "audio_kv_channels" + hidden_dims = "audio_hidden_dims" + + +class AudioEncoderType(str, enum.Enum): + none = "none" + whisper = "whisper" + + +# # TODO Toby: do we need all of them? +class AudioTransformerKwargs: + rotary_freq_q = "audio_rotary_freq_q" + rotary_freq_k = "audio_rotary_freq_k" + attention_mask = "audio_attention_mask" + attention_mask_value = "audio_attention_mask_value" + sequence_lengths = "audio_sequence_lengths" + cu_seqlens_q = "audio_cu_seqlens_q" + cu_seqlens_k = "audio_cu_seqlens_k" + max_seqlen_q = "audio_max_seqlen_q" + max_seqlen_k = "audio_max_seqlen_k" + # TODO: Review these + presents = "audio_presents" + past_key_values = "audio_past_key_values" + sequence_first = "audio_sequence_first" + hidden_dims = "audio_hidden_dims" + sequence_q_dim = "audio_sequence_q_dim" + sequence_k_dim = "audio_sequence_k_dim" + sequence_length = "audio_sequence_length" + micro_batch_size = "audio_micro_batch_size" + # TODO: Move + grad_output = "audio_grad_output" + patch_position_ids = "patch_position_ids" + + +@config_class() +class AudioEncoderConfig(BaseModelConfig): + _abstract = False + + transformer: AudioTransformerConfig = Field( + default_factory=AudioTransformerConfig, + desc="Configuration for the audio transformer architecture.", + hint=FieldHint.core, + ) + type: AudioEncoderType = Field( + default=AudioEncoderType.none, + desc="Type of the audio encoder. Choices: none, whisper.", + hint=FieldHint.architecture, + ) + conv_bias: bool = Field( + default=False, + desc="Whether to use bias in the convolutional layer.", + hint=FieldHint.optional, + ) + adapter_size: int = Field( + default=5120, + desc="Intermediate size for the adapter linear layers. Assuming 2 linear layers", + hint=FieldHint.core, + ) + adapter_activation_type: ActivationType = Field( + default=ActivationType.gelu, + desc="The intermediate activation type for multi-modal adapter. 
Default: GeLU.", + hint=FieldHint.core, + ) + aud_downsampling_k: int = Field( + default=5, + desc="Audio downsampling k parameter.", + hint=FieldHint.feature, + ) + aud_sampling_rate: int = Field( + default=16000, + desc="Audio sampling rate to use.", + hint=FieldHint.feature, + ) + + def setup_tensor_space(self, tensor_space: TensorSpace): + tensor_space.add_tensor_dim(TensorDim(AudioEncoderDimNames.out_channels, self.transformer.hidden_size)) + tensor_space.add_tensor_dim(TensorDim(AudioEncoderDimNames.adapter_size, self.adapter_size)) + tensor_space.add_tensor_dim(TensorDim(AudioEncoderDimNames.in_channels)) + # TODO Soham: add a check for presence of kv channels parameter (head_dim) + tensor_space.add_tensor_dim( + TensorDim( + AudioEncoderDimNames.kv_channels, self.transformer.hidden_size // self.transformer.num_attention_heads + ) + ) + self.transformer.setup_tensor_space(tensor_space, type="audio") + + @property + def enabled(self) -> bool: + return self.type != AudioEncoderType.none diff --git a/fast_llm/layers/audio_encoder/encoder.py b/fast_llm/layers/audio_encoder/encoder.py new file mode 100644 index 00000000..8cd071bd --- /dev/null +++ b/fast_llm/layers/audio_encoder/encoder.py @@ -0,0 +1,61 @@ +import typing + +import torch + +from fast_llm.engine.base_model.base_model import Layer +from fast_llm.engine.config_utils.tensor_space import TensorSpace +from fast_llm.layers.audio_encoder.config import AudioEncoderConfig, AudioEncoderDimNames, AudioEncoderKwargs +from fast_llm.tensor import ParameterMeta, TensorMeta, init_normal_ + + +class AudioConv(Layer): + def __init__(self, config: AudioEncoderConfig, tensor_space: TensorSpace): + super().__init__() + self._tensor_space = tensor_space + # TODO Toby: lr_scale + self.conv1_weight = ParameterMeta.from_dims( + ( + self._tensor_space.get_tensor_dim(AudioEncoderDimNames.out_channels), + self._tensor_space.get_tensor_dim(AudioEncoderDimNames.in_channels), + self._tensor_space.get_tensor_dim(AudioEncoderDimNames.kernel_size), + ), + init_method=init_normal_(), + ) + self.conv1_stride = 1 + + self.conv2_weight = ParameterMeta.from_dims( + ( + self._tensor_space.get_tensor_dim(AudioEncoderDimNames.in_channels), # in/out channels are the same + self._tensor_space.get_tensor_dim(AudioEncoderDimNames.in_channels), + self._tensor_space.get_tensor_dim(AudioEncoderDimNames.kernel_size), + ), + init_method=init_normal_(), + ) + self.conv2_stride = 2 + + if config.conv_bias: + self.bias = ParameterMeta.from_dims( + (self._tensor_space.get_tensor_dim(AudioEncoderDimNames.out_channels),) + ) + else: + self.bias = None + + def forward( + self, + input_: torch.Tensor, + kwargs: dict[str, typing.Any], + losses: dict[str, typing.Any] | None = None, + metrics: dict | None = None, + ) -> torch.Tensor: + hidden_dims = kwargs[AudioEncoderKwargs.hidden_dims] + if isinstance(input_, TensorMeta): + return TensorMeta.from_dims(hidden_dims, tensor_name="audio conv output", dtype=input_.dtype) + input_ = torch.nn.functional.conv1d(input_, self.conv1_weight, self.bias, stride=self.conv1_stride) + input_ = torch.nn.functional.gelu(input_) + input_ = torch.nn.functional.conv1d(input_, self.conv2_weight, self.bias, stride=self.conv2_stride) + input_ = torch.nn.functional.gelu(input_) + + # TODO Toby: add learned position embeddings and dropout + audio_embeddings = audio_embeddings.reshape(*(x.size for x in hidden_dims)) + + return audio_embeddings diff --git a/fast_llm/layers/audio_encoder/preprocessing.py b/fast_llm/layers/audio_encoder/preprocessing.py new file 
mode 100644 index 00000000..54bfeef6 --- /dev/null +++ b/fast_llm/layers/audio_encoder/preprocessing.py @@ -0,0 +1,47 @@ +import typing + +import torch +from torchaudio.transforms import MelSpectrogram + +from fast_llm.engine.base_model.config import Preprocessor +from fast_llm.engine.config_utils.tensor_space import TensorSpace +from fast_llm.layers.audio_encoder.config import AudioEncoderConfig, AudioEncoderKwargs + +# from transformers import WhisperFeatureExtractor + + +class AudioPreprocessor(Preprocessor): + def __init__(self, config: AudioEncoderConfig, tensor_space: TensorSpace): + self._config = config + self._tensor_space = tensor_space + self._distributed_config = self._tensor_space.distributed_config + + # self.feature_extractor = WhisperFeatureExtractor(sampling_rate=self._config.aud_sampling_rate) + + self.mel_transform = MelSpectrogram( + sample_rate=self._config.aud_sampling_rate, + n_fft=400, + win_length=400, + hop_length=160, + n_mels=80, + f_min=0.0, + f_max=8000.0, + mel_scale="slaney", + norm="slaney", + center=True, + power=2.0, + ) + + def preprocess_meta(self, kwargs: dict[str, typing.Any]) -> None: + pass + + def preprocess(self, tokens, kwargs: dict[str, typing.Any]) -> None: + audio_raw = kwargs[AudioEncoderKwargs.audio] + # audio_inputs = self.feature_extractor(audio_raw, sampling_rate=16000, return_tensors="pt") + self.mel_transform.to(self._tensor_space.distributed.device) + + audio_mel = [] + for batch in audio_raw: + batch_stacked = torch.stack(batch).unsqueeze(1) + audio_mel.append(self.mel_transform(batch_stacked)) + kwargs[AudioEncoderKwargs.audio_mel] = torch.cat(audio_mel) diff --git a/fast_llm/layers/language_model/config.py b/fast_llm/layers/language_model/config.py index ea72de5c..625e5da6 100644 --- a/fast_llm/layers/language_model/config.py +++ b/fast_llm/layers/language_model/config.py @@ -48,12 +48,12 @@ class LanguageModelBaseConfig(BaseModelConfig): hint=FieldHint.architecture, ) # TODO Soham: make this None by default. 
Need to figure out how to handle this in the config (see ) - vision_encoder: VisionEncoderConfig | None = Field( + vision_encoder: VisionEncoderConfig = Field( default_factory=VisionEncoderConfig, desc="Configuration for the vision encoder that transforms images into embeddings.", hint=FieldHint.optional, ) - audio_encoder: AudioEncoderConfig | None = Field( + audio_encoder: AudioEncoderConfig = Field( default_factory=AudioEncoderConfig, desc="Configuration for the audio encoder that transforms audio into embeddings.", hint=FieldHint.optional, diff --git a/fast_llm/layers/transformer/audio_transformer.py b/fast_llm/layers/transformer/audio_transformer.py new file mode 100644 index 00000000..43ee3f46 --- /dev/null +++ b/fast_llm/layers/transformer/audio_transformer.py @@ -0,0 +1,41 @@ +import torch + +from fast_llm.engine.config_utils.tensor_space import TensorDim, TensorSpace +from fast_llm.layers.audio_encoder.config import AudioTransformerDimNames, AudioTransformerKwargs +from fast_llm.layers.transformer.config import TransformerConfig +from fast_llm.layers.transformer.transformer import TransformerLayer +from fast_llm.tensor import TensorMeta + + +class AudioTransformerLayer(TransformerLayer): + """ + An audio transformer layer to encode audio features + """ + + def __init__( + self, + config: TransformerConfig, + tensor_space: TensorSpace, + layer_index: int, + return_input: bool = False, + ): + super().__init__(config, tensor_space, layer_index, return_input) + + hidden_dim = self._tensor_space.get_tensor_dim(AudioTransformerDimNames.hidden) + + # use regular layernorm (not rms norm) + self.norm_1 = self._config.normalization.get_layer(hidden_dim) + self.norm_2 = self._config.normalization.get_layer(hidden_dim) + + self.norm_1 = self._config.peft.apply_other(self.norm_1) + self.norm_2 = self._config.peft.apply_other(self.norm_2) + + @property + def name(self) -> str: + return f"Audio transformer layer {self._layer_index}" + + def _get_meta(self, tensor: torch.Tensor, name: str, kwargs: dict): + dims = kwargs[AudioTransformerKwargs.hidden_dims] + if self._return_input: + dims = (TensorDim("stacked_input_output", 2),) + dims + return TensorMeta.from_dims(dims, tensor_name=f"{self.name} {name}", dtype=tensor.dtype) diff --git a/fast_llm/layers/transformer/config.py b/fast_llm/layers/transformer/config.py index 40a29959..3847fd31 100644 --- a/fast_llm/layers/transformer/config.py +++ b/fast_llm/layers/transformer/config.py @@ -170,6 +170,15 @@ class VisionRotaryConfig(RotaryConfig): ) +# @config_class() +# class AudioRotaryConfig(RotaryConfig): +# type: RotaryEmbeddingType = Field( +# default=RotaryEmbeddingType.none, +# desc="The type of rotary embedding to use. Choices: none, default, llama3, yarn, pixtral.", +# hint=FieldHint.feature, +# ) + + class AddLinearBiasChoices(str, enum.Enum): nowhere = "nowhere" everywhere = "everywhere" @@ -664,6 +673,10 @@ def setup_tensor_space(self, tensor_space: TensorSpace, type: str | None = None) from fast_llm.layers.vision_encoder.config import VisionTransformerDimNames transformer_dim_names = VisionTransformerDimNames + elif type == "audio": + from fast_llm.layers.audio_encoder.config import AudioTransformerDimNames + + transformer_dim_names = AudioTransformerDimNames else: transformer_dim_names = TransformerDimNames tensor = tensor_space.distributed_config.get_distributed_dim(DistributedDimNames.tensor) @@ -775,6 +788,11 @@ class AudioTransformerConfig(TransformerConfig): causal: bool = FieldUpdate( default=False, - desc="Use causal attention.
Turn this off only for bidirectional attention e.g., in Vision Transformer.", + desc="Use causal attention. Turn this off only for bidirectional attention e.g., in Audio Transformer.", hint=FieldHint.feature, ) + # rotary: AudioRotaryConfig = FieldUpdate( + # default_factory=AudioRotaryConfig, + # desc="Configuration for the rotary positional embeddings.", + # hint=FieldHint.feature, + # ) diff --git a/fast_llm/models/gpt/config.py b/fast_llm/models/gpt/config.py index 16201576..f82051ab 100644 --- a/fast_llm/models/gpt/config.py +++ b/fast_llm/models/gpt/config.py @@ -51,13 +51,20 @@ class MistralGPTHuggingfaceCheckpointFormat(GPTHuggingfaceCheckpointFormat): class MixtralGPTHuggingfaceCheckpointFormat(GPTHuggingfaceCheckpointFormat): name: typing.ClassVar[str] = "mixtral" + class MTPLlamaGPTHuggingfaceCheckpointFormat(GPTHuggingfaceCheckpointFormat): name: typing.ClassVar[str] = "mtp_llama" trust_remote_code: typing.ClassVar[bool] = True + class LlavaGPTHuggingfaceCheckpointFormat(GPTHuggingfaceCheckpointFormat): name: typing.ClassVar[str] = "llava" + +class WhisperGPTHuggingfaceCheckpointFormat(GPTHuggingfaceCheckpointFormat): + name: typing.ClassVar[str] = "whisper" + + @config_class() class GPTBatchConfig(BatchConfig): sequence_length: int = Field( @@ -140,6 +147,7 @@ class GPTModelConfig(FastLLMModelConfig): MixtralGPTHuggingfaceCheckpointFormat, MTPLlamaGPTHuggingfaceCheckpointFormat, LlavaGPTHuggingfaceCheckpointFormat, + WhisperGPTHuggingfaceCheckpointFormat, ) @classmethod diff --git a/fast_llm/models/gpt/conversion.py b/fast_llm/models/gpt/conversion.py index ad4df737..ec765e55 100644 --- a/fast_llm/models/gpt/conversion.py +++ b/fast_llm/models/gpt/conversion.py @@ -25,6 +25,7 @@ from fast_llm.engine.multi_stage.config import FastLLMModelConfig from fast_llm.functional.config import ActivationType from fast_llm.functional.rotary import convert_rotary_complex_to_real, convert_rotary_real_to_complex +from fast_llm.layers.audio_encoder.config import AudioEncoderType from fast_llm.layers.common.config import NormalizationType from fast_llm.layers.transformer.config import RotaryEmbeddingType, RoutingType, TransformerConfig from fast_llm.layers.vision_encoder.config import VisionEncoderType @@ -38,6 +39,7 @@ MTPLlamaGPTHuggingfaceCheckpointFormat, Qwen2GPTHuggingfaceCheckpointFormat, Starcoder2GPTHuggingfaceCheckpointFormat, + WhisperGPTHuggingfaceCheckpointFormat, ) from fast_llm.models.gpt.external.mtp_llama.configuration_mtp_llama import MTPLlamaConfig from fast_llm.models.gpt.model import GPTModel @@ -555,6 +557,242 @@ def _get_mlp_converters(self, fast_llm_prefix: str, hf_prefix: str) -> list[Weig ] +class WhisperHuggingfaceCheckpointHandler(MistralHuggingfaceCheckpointHandler): + format: typing.ClassVar[type[CheckpointFormat]] = WhisperGPTHuggingfaceCheckpointFormat + + @classmethod + def _create_config_converters(cls) -> list[ParamConverter]: + # lm_converters = super()._create_config_converters() + lm_converters = super()._create_config_converters() + for idx, converter in enumerate(lm_converters): + if converter.export_names == (("model_type",),): + continue + elif converter.export_names == (("architectures",),): + ignore_index = idx + if converter.export_names: + converter.export_names = (("text_config", *converter.export_names[0]), *converter.export_names[1:]) + + return ( + lm_converters[:ignore_index] + + lm_converters[ignore_index + 1 :] + + [ + ConstantImportParamConverter( + fast_llm_names=(("audio_encoder", "type"),), fast_llm_value=AudioEncoderType.whisper + ), + 
ConstantExportParamConverter( + export_names=(("architectures",),), export_value=["WhisperForConditionalGeneration"] + ), + # Audio Adapter + # RenameParamConverter( + # fast_llm_names=(("vision_encoder", "adapter_size"),), + # export_names=(("text_config", "hidden_size"),), + # ), + # ConstantImportParamConverter( + # fast_llm_names=(("vision_encoder", "patch_norm", "type"),), + # fast_llm_value=NormalizationType.rms_norm, + # ), + # ConstantImportParamConverter( + # fast_llm_names=(("vision_encoder", "transformer", "normalization", "type"),), + # fast_llm_value=NormalizationType.rms_norm, + # ), + # Audio Transformer + RenameParamConverter( + fast_llm_names=(("audio_encoder", "transformer", "num_layers"),), + export_names=(("encoder_layers",),), + ), + RenameParamConverter( + fast_llm_names=(("audio_encoder", "transformer", "hidden_size"),), + export_names=(("d_model",),), + ), + RenameParamConverter( + fast_llm_names=(("audio_encoder", "transformer", "num_attention_heads"),), + export_names=(("encoder_attention_heads",),), + ), + # RenameParamConverter( + # fast_llm_names=(("audio_encoder", "transformer", "head_groups"),), + # export_names=( + # ( + # "encoder_attention_heads", + # ), + # ), + # ), + RenameParamConverter( + fast_llm_names=(("audio_encoder", "transformer", "ffn_hidden_size"),), + export_names=(("encoder_ffn_dim",),), + ), + MappedConfigParamConverter( + fast_llm_names=(("audio_encoder", "transformer", "activation_type"),), + export_names=(("activation_function",),), + fast_llm_value=ActivationType.from_hf_name, + export_value=lambda activation_type: activation_type.hf_name, + ), + # ConstantImportParamConverter( + # fast_llm_names=(("vision_encoder", "transformer", "gated"),), fast_llm_value=True + # ), + # MappedConfigParamConverter( + # fast_llm_names=(("vision_encoder", "adapter_activation_type"),), + # export_names=(("projector_hidden_act",),), + # fast_llm_value=ActivationType.from_hf_name, + # export_value=lambda activation_type: activation_type.hf_name, + # ), + # ConstantImportParamConverter( + # fast_llm_names=(("vision_encoder", "transformer", "add_linear_biases"),), fast_llm_value=False + # ), + # RenameParamConverter( + # fast_llm_names=(("vision_encoder", "transformer", "rotary", "theta"),), + # export_names=(("vision_config", "rope_theta"),), + # ), + ] + ) + + def _create_vision_transformer_layer_converters( + self, + i: int, + ignore_export: bool = False, + hf_base_prefix: str = "", + fast_llm_offset: int = 1, + type: str | None = None, + ) -> list[WeightConverter]: + if type is not None: + if type == "vision": + transformer_config: TransformerConfig = self._model.config.base_model.vision_encoder.transformer + else: + transformer_config: TransformerConfig = self._model.config.base_model.transformer + norm_bias: bool = self._model.config.base_model.transformer.normalization.type == NormalizationType.layer_norm + converters = [] + names_bias_cls = [ + # Self-attn + ( + f"layers.{i+fast_llm_offset}.self_attn.query", + f"vision_tower.transformer.layers.{i}.attention.q_proj", + transformer_config.add_attn_qkv_bias, + QueryWeightConverter, + ), + ( + f"layers.{i+fast_llm_offset}.self_attn.key_value", + ( + f"vision_tower.transformer.layers.{i}.attention.k_proj", + f"vision_tower.transformer.layers.{i}.attention.v_proj", + ), + transformer_config.add_attn_qkv_bias, + KeyValueWeightConverter, + ), + ( + f"layers.{i+fast_llm_offset}.self_attn.dense", + f"vision_tower.transformer.layers.{i}.attention.o_proj", + transformer_config.add_attn_dense_bias, + 
WeightConverter, + ), + # Norm + ( + f"layers.{i+fast_llm_offset}.norm_1", + f"vision_tower.transformer.layers.{i}.attention_norm", + norm_bias, + WeightConverter, + ), + ( + f"layers.{i+fast_llm_offset}.norm_2", + f"vision_tower.transformer.layers.{i}.ffn_norm", + norm_bias, + WeightConverter, + ), + ] + for fast_llm_prefix, hf_prefix, use_bias, cls in names_bias_cls: + converters += self._get_weight_and_bias_converters( + fast_llm_prefix, + () if ignore_export else hf_prefix, + use_bias, + cls=IgnoreExportWeightConverter if ignore_export else cls, + ) + + # MLP + if ignore_export: + converters += self._get_weight_and_bias_converters( + f"layers.{i+fast_llm_offset}.mlp.layer_1", + (), + transformer_config.add_mlp_bias, + cls=IgnoreExportWeightConverter, + ) + converters += self._get_weight_and_bias_converters( + f"layers.{i+fast_llm_offset}.mlp.layer_2", + (), + transformer_config.add_mlp_bias, + cls=IgnoreExportWeightConverter, + ) + converters += [IgnoreExportWeightConverter(f"layers.{i+fast_llm_offset}.mlp.router.weight", ())] + else: + converters += self._get_vision_transformer_mlp_converters( + f"layers.{i+fast_llm_offset}", f"vision_tower.transformer.layers.{i}" + ) + return converters + + def _get_vision_transformer_mlp_converters(self, fast_llm_prefix: str, hf_prefix: str) -> list[WeightConverter]: + return [ + SplitWeightConverter( + f"{fast_llm_prefix}.mlp.layer_1.weight", + (f"{hf_prefix}.feed_forward.gate_proj.weight", f"{hf_prefix}.feed_forward.up_proj.weight"), + ), + MLPLayer2Converter( + f"{fast_llm_prefix}.mlp.layer_2.weight", + f"{hf_prefix}.feed_forward.down_proj.weight", + self._model.config.base_model, + ), + ] + + def _create_vision_transformer_converters(self) -> list[WeightConverter]: + num_layers = self._model.config.base_model.vision_encoder.transformer.num_layers + vision_transformer_converters = [] + for layer in range(num_layers): + # TODO Soham: check if args are correct + vision_transformer_converters.extend( + self._create_vision_transformer_layer_converters( + layer, + ignore_export=False, + hf_base_prefix="vision_tower.transformer.layers.", + fast_llm_offset=1, + type="vision", + ) + ) + + return vision_transformer_converters + + def _create_vision_encoder_weight_converters(self) -> list[WeightConverter]: + patch_conv_converters = [WeightConverter("layers.0.weight", "vision_tower.patch_conv.weight")] + if self._model.config.base_model.vision_encoder.conv_bias: + patch_conv_converters.append(WeightConverter("layers.0.bias", "vision_tower.patch_conv.bias")) + layernorm_converters = [ + WeightConverter("layers.0.norm.weight", "vision_tower.ln_pre.weight"), + ] + if self._model.config.base_model.vision_encoder.patch_norm.type == NormalizationType.layer_norm: + layernorm_converters.append(WeightConverter("layers.0.norm.bias", "vision_tower.ln_pre.bias")) + + vision_transformer_converters = self._create_vision_transformer_converters() + offset = self._model.config.base_model.vision_encoder.transformer.num_layers + 1 + adapter_converters = [ + WeightConverter(f"layers.{offset}.layer_1.weight", "multi_modal_projector.linear_1.weight"), + WeightConverter(f"layers.{offset}.layer_1.bias", "multi_modal_projector.linear_1.bias"), + # TODO Soham: add bias based on config + WeightConverter(f"layers.{offset}.layer_2.weight", "multi_modal_projector.linear_2.weight"), + WeightConverter(f"layers.{offset}.layer_2.bias", "multi_modal_projector.linear_2.bias"), + ] + + return patch_conv_converters + layernorm_converters + vision_transformer_converters + adapter_converters + + 
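# A small illustrative sketch of the layer-index bookkeeping the weight converters in
# this handler rely on; the helper and its name are assumptions, not code from the
# patch. Fast-LLM's flat layer list is [patch/conv layer, N encoder transformer layers,
# adapter, multimodal embedding, M language-model layers, head], so the word embeddings
# sit at index N + 2 and language-model layer i maps to "layers.{i + N + 3}".
def _lm_layer_names(i: int, num_encoder_layers: int) -> tuple[str, str]:
    # conv/patch layer + adapter + multimodal embedding account for the "+ 3"
    offset = num_encoder_layers + 3
    return f"layers.{i + offset}", f"language_model.model.layers.{i}"


# e.g. with a 24-layer encoder, language-model layer 0 is "layers.27" on the Fast-LLM
# side and "language_model.model.layers.0" on the Hugging Face side.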
def _create_weight_converters(self) -> list[WeightConverter]: + vision_encoder_converter = self._create_vision_encoder_weight_converters() + offset = self._model.config.base_model.vision_encoder.transformer.num_layers + 3 + # Embeddings + lm_converters = [ + WeightConverter(f"layers.{offset - 1}.word_embeddings_weight", f"language_model.model.embed_tokens.weight") + ] + for i in range(self._model.config.base_model.transformer.num_layers): + lm_converters += self._create_transformer_layer_converters( + fast_llm_layer_name=f"layers.{i + offset}", hf_layer_name=f"language_model.model.layers.{i}" + ) + lm_converters += self._create_lm_head_converters(hf_base_prefix="language_model.", fast_llm_offset=offset) + return vision_encoder_converter + lm_converters + + class LlavaHuggingfaceCheckpointHandler(MistralHuggingfaceCheckpointHandler): format: typing.ClassVar[type[CheckpointFormat]] = LlavaGPTHuggingfaceCheckpointFormat @@ -950,4 +1188,5 @@ class AutoGPTHuggingfaceCheckpointHandler( MixtralGPTHuggingfaceCheckpointFormat.name: MixtralHuggingfaceCheckpointHandler, MTPLlamaGPTHuggingfaceCheckpointFormat.name: MTPLlamaHuggingfaceCheckpointHandler, LlavaGPTHuggingfaceCheckpointFormat.name: LlavaHuggingfaceCheckpointHandler, + WhisperGPTHuggingfaceCheckpointFormat.name: WhisperHuggingfaceCheckpointHandler, } diff --git a/fast_llm/models/gpt/model.py b/fast_llm/models/gpt/model.py index 1b91f3e6..9b450c73 100644 --- a/fast_llm/models/gpt/model.py +++ b/fast_llm/models/gpt/model.py @@ -10,13 +10,16 @@ from fast_llm.engine.distributed.config import DistributedConfig, DistributedDimNames, PhaseType from fast_llm.engine.inference.runner import InferenceRunner from fast_llm.engine.multi_stage.fast_llm_model import FastLLMModel +from fast_llm.layers.audio_encoder.adapter import AudioAdapter from fast_llm.layers.audio_encoder.config import AudioEncoderKwargs +from fast_llm.layers.audio_encoder.encoder import AudioConv from fast_llm.layers.audio_encoder.preprocessing import AudioPreprocessor from fast_llm.layers.language_model.config import LanguageModelKwargs, LanguageModelLossNames from fast_llm.layers.language_model.embedding import WORD_EMBEDDINGS_WEIGHT, LanguageModelEmbedding from fast_llm.layers.language_model.head import OUTPUT_WEIGHTS, LanguageModelHead from fast_llm.layers.language_model.preprocessing import PositionEmbeddingPreprocessor from fast_llm.layers.multi_modal.embedding import MultiModalEmbedding +from fast_llm.layers.transformer.audio_transformer import AudioTransformerLayer from fast_llm.layers.transformer.config import ( RoutingType, TransformerDimNames, @@ -84,7 +87,7 @@ def __init__( self._preprocessors.append( RotaryEmbeddingPreprocessor(self._config.vision_encoder.transformer.rotary, self._tensor_space) ) - if self._config.audio_encoder: + if self._config.audio_encoder.enabled: self._preprocessors.append(AudioPreprocessor(self._config.audio_encoder, self._tensor_space)) def get_output_layers(self) -> list[Layer]: @@ -124,12 +127,33 @@ def get_vision_layers(self) -> list[Layer]: MultiModalEmbedding(self._config, self._tensor_space), ] + def get_audio_layers(self) -> list[Layer]: + audio_conv = AudioConv(self._config.audio_encoder, self._tensor_space) + audio_layers = [ + AudioTransformerLayer(self._config.audio_encoder.transformer, self._tensor_space, layer_index=idx + 1) + for idx in range(self._config.audio_encoder.transformer.num_layers) + ] + return [ + audio_conv, + *audio_layers, + AudioAdapter(self._config.audio_encoder, self._tensor_space), + 
MultiModalEmbedding(self._config, self._tensor_space), + ] + + def get_multimodal_layers(self) -> list[Layer]: + if self._config.vision_encoder.enabled: + return self.get_vision_layers() + elif self._config.audio_encoder.enabled: + return self.get_audio_layers() + else: + assert False + def get_layers(self) -> list[Layer]: return [ *( [LanguageModelEmbedding(self._config, self._tensor_space)] - if not self._config.vision_encoder.enabled - else self.get_vision_layers() + if not self._config.vision_encoder.enabled and not self._config.audio_encoder.enabled + else self.get_multimodal_layers() ), *[ TransformerLayer( @@ -423,7 +447,7 @@ def preprocess( if batch.audio is not None: kwargs[AudioEncoderKwargs.audio] = [ [ - aud.to(device=self._tensor_space.distributed.device, dtype=torch.uint8, non_blocking=True) + aud.to(device=self._tensor_space.distributed.device, dtype=torch.float32, non_blocking=True) for aud in audio ] for audio in batch.audio @@ -434,8 +458,11 @@ def preprocess( for preprocessor in self._preprocessors: preprocessor.preprocess(tokens, kwargs) image_patches = kwargs.get(VisionEncoderKwargs.image_patches, None) + audio_mel = kwargs.get(AudioEncoderKwargs.audio_mel, None) if image_patches is not None: preprocessed.append((image_patches, kwargs)) + elif audio_mel is not None: + preprocessed.append((audio_mel, kwargs)) else: preprocessed.append((tokens, kwargs)) From ca33ee83b22bea5c45a946a13209572b6aa73680 Mon Sep 17 00:00:00 2001 From: root Date: Wed, 21 May 2025 20:59:14 +0000 Subject: [PATCH 47/82] cleaner, extensible multimodal config --- fast_llm/layers/transformer/config.py | 44 +- fast_llm/layers/transformer/preprocessing.py | 12 - fast_llm/layers/transformer/transformer.py | 18 +- .../layers/transformer/vision_transformer.py | 14 +- fast_llm/layers/vision_encoder/config.py | 5 + .../layers/vision_encoder/preprocessing.py | 12 +- fast_llm/models/gpt/config.py | 30 + fast_llm/models/gpt/conversion.py | 984 ++++++++++++------ fast_llm/models/gpt/model.py | 3 +- 9 files changed, 740 insertions(+), 382 deletions(-) diff --git a/fast_llm/layers/transformer/config.py b/fast_llm/layers/transformer/config.py index 9a6bec07..a634bc3c 100644 --- a/fast_llm/layers/transformer/config.py +++ b/fast_llm/layers/transformer/config.py @@ -60,7 +60,7 @@ def __init_subclass__(cls, prefix="", **kwargs): super().__init_subclass__(**kwargs) cls._prefix = prefix for attr, value in BaseTransformerDimNames._kwargs_attributes.items(): - setattr(cls, value, f"{cls._prefix}_{value}") + setattr(cls, attr, f"{cls._prefix}_{value}") class TransformerDimNames(BaseTransformerDimNames, prefix=""): @@ -737,67 +737,69 @@ def setup_tensor_space(self, tensor_space: TensorSpace) -> None: tensor = tensor_space.distributed_config.get_distributed_dim(DistributedDimNames.tensor) # Hidden dimension - tensor_space.add_tensor_dim(TensorDim(self.transformer_dim_names.hidden, self.hidden_size)) + tensor_space.add_tensor_dim(TensorDim(self._transformer_dim_names.hidden, self.hidden_size)) # Self-attention dimensions tensor_space.add_tensor_dim( head_groups := TensorDim( - self.transformer_dim_names.head_groups, self.head_groups, tensor if self.head_groups > 1 else None + self._transformer_dim_names.head_groups, self.head_groups, tensor if self.head_groups > 1 else None ) ) tensor_space.add_tensor_dim( group_heads := TensorDim( - self.transformer_dim_names.group_heads, + self._transformer_dim_names.group_heads, div(self.num_attention_heads, self.head_groups), None if self.head_groups > 1 else tensor, ) ) - 
tensor_space.add_tensor_dim(key_and_value := TensorDim(self.transformer_dim_names.key_and_value, 2)) - tensor_space.add_tensor_dim(kv_channels := TensorDim(self.transformer_dim_names.kv_channels, self.kv_channels)) + tensor_space.add_tensor_dim(key_and_value := TensorDim(self._transformer_dim_names.key_and_value, 2)) + tensor_space.add_tensor_dim( + kv_channels := TensorDim(self._transformer_dim_names.kv_channels, self.kv_channels) + ) tensor_space.add_tensor_dim( - CompositeTensorDim(self.transformer_dim_names.composite_heads, (head_groups, group_heads)) + CompositeTensorDim(self._transformer_dim_names.composite_heads, (head_groups, group_heads)) ) tensor_space.add_tensor_dim( - CompositeTensorDim(self.transformer_dim_names.composite_query, (head_groups, group_heads, kv_channels)) + CompositeTensorDim(self._transformer_dim_names.composite_query, (head_groups, group_heads, kv_channels)) ) tensor_space.add_tensor_dim( CompositeTensorDim( - self.transformer_dim_names.composite_key_value, (key_and_value, head_groups, kv_channels) + self._transformer_dim_names.composite_key_value, (key_and_value, head_groups, kv_channels) ) ) tensor_space.add_tensor_dim( - CompositeTensorDim(self.transformer_dim_names.composite_dense, (head_groups, group_heads, kv_channels)) + CompositeTensorDim(self._transformer_dim_names.composite_dense, (head_groups, group_heads, kv_channels)) ) # MLP dimensions - tensor_space.add_tensor_dim(mlp := TensorDim(self.transformer_dim_names.mlp, self.ffn_hidden_size, tensor)) + tensor_space.add_tensor_dim(mlp := TensorDim(self._transformer_dim_names.mlp, self.ffn_hidden_size, tensor)) tensor_space.add_tensor_dim( - gate_and_up := TensorDim(self.transformer_dim_names.gate_and_up, 2 if self.gated else 1) + gate_and_up := TensorDim(self._transformer_dim_names.gate_and_up, 2 if self.gated else 1) ) tensor_space.add_tensor_dim( - CompositeTensorDim(self.transformer_dim_names.composite_gated_mlp, (gate_and_up, mlp)) + CompositeTensorDim(self._transformer_dim_names.composite_gated_mlp, (gate_and_up, mlp)) ) - tensor_space.add_tensor_dim(experts := TensorDim(self.transformer_dim_names.experts, self.num_experts)) + tensor_space.add_tensor_dim(experts := TensorDim(self._transformer_dim_names.experts, self.num_experts)) tensor_space.add_tensor_dim( - CompositeTensorDim(self.transformer_dim_names.composite_expert_mlp, (experts, mlp)) + CompositeTensorDim(self._transformer_dim_names.composite_expert_mlp, (experts, mlp)) ) tensor_space.add_tensor_dim( - CompositeTensorDim(self.transformer_dim_names.composite_gated_expert_mlp, (experts, gate_and_up, mlp)) + CompositeTensorDim(self._transformer_dim_names.composite_gated_expert_mlp, (experts, gate_and_up, mlp)) ) - tensor_space.add_tensor_dim(TensorDim(self.transformer_dim_names.top_experts, self.num_experts_per_token)) - tensor_space.add_tensor_dim(TensorDim(self.transformer_dim_names.unshared_experts, self.num_unshared_experts)) + tensor_space.add_tensor_dim(TensorDim(self._transformer_dim_names.top_experts, self.num_experts_per_token)) + tensor_space.add_tensor_dim(TensorDim(self._transformer_dim_names.unshared_experts, self.num_unshared_experts)) # shared_experts if self.num_shared_experts: tensor_space.add_tensor_dim( - shared_experts := TensorDim(self.transformer_dim_names.shared_experts, self.num_shared_experts) + shared_experts := TensorDim(self._transformer_dim_names.shared_experts, self.num_shared_experts) ) tensor_space.add_tensor_dim( - CompositeTensorDim(self.transformer_dim_names.composite_shared_expert_mlp, (shared_experts, mlp)) + 
CompositeTensorDim(self._transformer_dim_names.composite_shared_expert_mlp, (shared_experts, mlp)) ) tensor_space.add_tensor_dim( CompositeTensorDim( - self.transformer_dim_names.composite_gated_shared_expert_mlp, (shared_experts, gate_and_up, mlp) + self._transformer_dim_names.composite_gated_shared_expert_mlp, (shared_experts, gate_and_up, mlp) ) ) diff --git a/fast_llm/layers/transformer/preprocessing.py b/fast_llm/layers/transformer/preprocessing.py index 97c6c0f3..af1a53f6 100644 --- a/fast_llm/layers/transformer/preprocessing.py +++ b/fast_llm/layers/transformer/preprocessing.py @@ -251,12 +251,6 @@ def __init__( config: TransformerConfig, tensor_space: TensorSpace, ): - # if isinstance(config, VisionTransformerConfig): - # self._transformer_dim_names = VisionTransformerDimNames - # self._transformer_kwargs = VisionTransformerKwargs - # elif isinstance(config, TransformerConfig): - # self._transformer_dim_names = TransformerDimNames - # self._transformer_kwargs = TransformerKwargs self._transformer_dim_names = config._transformer_dim_names self._transformer_kwargs = config._transformer_kwargs self._config = config @@ -328,12 +322,6 @@ def __init__(self, config: TransformerConfig, tensor_space: TensorSpace): self._tensor_space = tensor_space self._distributed_config = self._tensor_space.distributed_config assert self._config.do_use_flash_attention(self._distributed_config) - # if isinstance(config, VisionTransformerConfig): - # self._transformer_dim_names = VisionTransformerDimNames - # self._transformer_kwargs = VisionTransformerKwargs - # elif isinstance(config, TransformerConfig): - # self._transformer_dim_names = TransformerDimNames - # self._transformer_kwargs = TransformerKwargs self._transformer_dim_names = config._transformer_dim_names self._transformer_kwargs = config._transformer_kwargs diff --git a/fast_llm/layers/transformer/transformer.py b/fast_llm/layers/transformer/transformer.py index 8bd1394e..2c79883b 100644 --- a/fast_llm/layers/transformer/transformer.py +++ b/fast_llm/layers/transformer/transformer.py @@ -9,7 +9,7 @@ from fast_llm.engine.config_utils.run import log_pipeline_parallel_main_rank from fast_llm.engine.config_utils.tensor_space import TensorDim, TensorSpace from fast_llm.layers.transformer.attention import Attention -from fast_llm.layers.transformer.config import TransformerConfig, TransformerDimNames, TransformerKwargs +from fast_llm.layers.transformer.config import TransformerConfig from fast_llm.layers.transformer.mixture_of_experts import MixtureOfExpertMLP from fast_llm.layers.transformer.mlp import MLP from fast_llm.logging import log_distributed_grad, log_distributed_tensor, log_memory_usage @@ -70,14 +70,6 @@ def _bias_dropout_add( def name(self) -> str: return f"{self._name} {self._layer_index}" - @property - def _transformer_kwargs(self) -> TransformerKwargs: - return TransformerKwargs - - @property - def _transformer_dim_names(self) -> TransformerDimNames: - return TransformerDimNames - def _get_meta(self, tensor: torch.Tensor, name: str, kwargs: dict): dims = kwargs[self._transformer_kwargs.hidden_dims] if self._return_input: @@ -157,3 +149,11 @@ def __init__( def _create_mixer(self): self.self_attn = Attention(self._config, self._tensor_space, self._layer_index) + + # @property + # def _transformer_kwargs(self) -> TransformerKwargs: + # return TransformerKwargs + + # @property + # def _transformer_dim_names(self) -> TransformerDimNames: + # return TransformerDimNames diff --git a/fast_llm/layers/transformer/vision_transformer.py 
b/fast_llm/layers/transformer/vision_transformer.py index 7f39f9cf..c2cfe9f2 100644 --- a/fast_llm/layers/transformer/vision_transformer.py +++ b/fast_llm/layers/transformer/vision_transformer.py @@ -1,21 +1,21 @@ import torch from fast_llm.engine.config_utils.tensor_space import TensorDim +from fast_llm.layers.transformer.config import VisionTransformerKwargs from fast_llm.layers.transformer.transformer import TransformerLayer -from fast_llm.layers.vision_encoder.config import VisionTransformerDimNames, VisionTransformerKwargs from fast_llm.tensor import TensorMeta class VisionTransformerLayer(TransformerLayer): _name: str = "Vision transformer layer" - @property - def _transformer_kwargs(self) -> VisionTransformerKwargs: - return VisionTransformerKwargs + # @property + # def _transformer_kwargs(self) -> VisionTransformerKwargs: + # return VisionTransformerKwargs - @property - def _transformer_dim_names(self) -> VisionTransformerDimNames: - return VisionTransformerDimNames + # @property + # def _transformer_dim_names(self) -> VisionTransformerDimNames: + # return VisionTransformerDimNames def _get_meta(self, tensor: torch.Tensor, name: str, kwargs: dict): dims = kwargs[VisionTransformerKwargs.hidden_dims] diff --git a/fast_llm/layers/vision_encoder/config.py b/fast_llm/layers/vision_encoder/config.py index 70504901..6932c8fc 100644 --- a/fast_llm/layers/vision_encoder/config.py +++ b/fast_llm/layers/vision_encoder/config.py @@ -119,6 +119,11 @@ class VisionEncoderConfig(BaseModelConfig): desc="The intermediate activation type for multi-modal adapter. Default: GeLU.", hint=FieldHint.core, ) + adapter_bias: bool = Field( + default=True, + desc="Whether to use bias in the adapter linear layer.", + hint=FieldHint.optional, + ) image_normalization: ImageNormalizationConfig = Field( default_factory=ImageNormalizationConfig, desc="Configuration for the normalization layers applied to the image patches.", diff --git a/fast_llm/layers/vision_encoder/preprocessing.py b/fast_llm/layers/vision_encoder/preprocessing.py index 7ebfb522..5009123f 100644 --- a/fast_llm/layers/vision_encoder/preprocessing.py +++ b/fast_llm/layers/vision_encoder/preprocessing.py @@ -6,14 +6,8 @@ from fast_llm.engine.base_model.config import Preprocessor from fast_llm.engine.config_utils.tensor_space import TensorDim, TensorSpace -from fast_llm.layers.transformer.config import TransformerKwargs -from fast_llm.layers.vision_encoder.config import ( - VisionEncoderConfig, - VisionEncoderDimNames, - VisionEncoderKwargs, - VisionTransformerDimNames, - VisionTransformerKwargs, -) +from fast_llm.layers.transformer.config import TransformerKwargs, VisionTransformerDimNames, VisionTransformerKwargs +from fast_llm.layers.vision_encoder.config import VisionEncoderConfig, VisionEncoderDimNames, VisionEncoderKwargs from fast_llm.tensor import TensorMeta from fast_llm.utils import div @@ -152,7 +146,7 @@ def preprocess(self, tokens, kwargs: dict[str, typing.Any]) -> None: patch_position_ids = [] cu_seqlens = [0] max_seqlen = -1 - sequence_first = kwargs.get(TransformerKwargs.sequence_first) + kwargs.get(TransformerKwargs.sequence_first) for imgs, sizes in zip(images, image_sizes): # sum( # get_num_patches(*size, patch_size) for size in sizes diff --git a/fast_llm/models/gpt/config.py b/fast_llm/models/gpt/config.py index 16201576..d7d32221 100644 --- a/fast_llm/models/gpt/config.py +++ b/fast_llm/models/gpt/config.py @@ -51,12 +51,22 @@ class MistralGPTHuggingfaceCheckpointFormat(GPTHuggingfaceCheckpointFormat): class 
MixtralGPTHuggingfaceCheckpointFormat(GPTHuggingfaceCheckpointFormat): name: typing.ClassVar[str] = "mixtral" + class MTPLlamaGPTHuggingfaceCheckpointFormat(GPTHuggingfaceCheckpointFormat): name: typing.ClassVar[str] = "mtp_llama" trust_remote_code: typing.ClassVar[bool] = True + class LlavaGPTHuggingfaceCheckpointFormat(GPTHuggingfaceCheckpointFormat): name: typing.ClassVar[str] = "llava" + # Using default values for vision and text models. Can be overridden in the config + vision_name: typing.ClassVar[str] = "pixtral" + text_name: typing.ClassVar[str] = "mistral" + + +class PixtralGPTHuggingfaceCheckpointFormat(GPTHuggingfaceCheckpointFormat): + name: typing.ClassVar[str] = "pixtral" + @config_class() class GPTBatchConfig(BatchConfig): @@ -140,6 +150,7 @@ class GPTModelConfig(FastLLMModelConfig): MixtralGPTHuggingfaceCheckpointFormat, MTPLlamaGPTHuggingfaceCheckpointFormat, LlavaGPTHuggingfaceCheckpointFormat, + PixtralGPTHuggingfaceCheckpointFormat, ) @classmethod @@ -154,6 +165,25 @@ def get_huggingface_model_class(cls) -> type["HuggingfaceGPTModelForCausalLM"]: return HuggingfaceGPTModelForCausalLM + @classmethod + def get_checkpoint_format(cls, format: type[CheckpointFormat]) -> type[CheckpointFormat]: + if isinstance(format, type) and issubclass(format, CheckpointFormat): + format_ = cls.get_checkpoint_format(format.name) + Assert.is_(format, format_) + return format_ + elif isinstance(format, dict): + for format_ in cls.checkpoint_formats: + if format_.name == format["name"]: + if (vision_name := format.get("vision_name")) is not None: + format_.vision_name = vision_name + if (text_name := format.get("text_name")) is not None: + format_.text_name = text_name + return format_ + for format_ in cls.checkpoint_formats: + if format_.name == format: + return format_ + raise ValueError(f"Checkpoint format {format} not supported for model {cls.model_name}") + @config_class() class PretrainedGPTModelConfig(PretrainedFastLLMModelConfig): diff --git a/fast_llm/models/gpt/conversion.py b/fast_llm/models/gpt/conversion.py index ad4df737..0b0796ed 100644 --- a/fast_llm/models/gpt/conversion.py +++ b/fast_llm/models/gpt/conversion.py @@ -6,8 +6,10 @@ import torch from transformers.configuration_utils import PretrainedConfig -from fast_llm.config import DEFAULT, MISSING -from fast_llm.engine.checkpoint.config import CheckpointFormat +from fast_llm import __version__ +from fast_llm.config import DEFAULT, MISSING, get_nested_dict_value, set_nested_dict_value +from fast_llm.engine.base_model.config import BaseModelConfig +from fast_llm.engine.checkpoint.config import CheckpointFormat, CheckpointLoadMetadataConfig from fast_llm.engine.checkpoint.external import ( AutoStateDictCheckpointHandler, ConstantExportParamConverter, @@ -22,7 +24,7 @@ WeightConverter, ) from fast_llm.engine.checkpoint.huggingface import CustomModelingExportMixin, HuggingfaceStateDictCheckpointHandler -from fast_llm.engine.multi_stage.config import FastLLMModelConfig +from fast_llm.engine.multi_stage.config import CheckpointMetadata, FastLLMModelConfig from fast_llm.functional.config import ActivationType from fast_llm.functional.rotary import convert_rotary_complex_to_real, convert_rotary_real_to_complex from fast_llm.layers.common.config import NormalizationType @@ -36,6 +38,7 @@ MistralGPTHuggingfaceCheckpointFormat, MixtralGPTHuggingfaceCheckpointFormat, MTPLlamaGPTHuggingfaceCheckpointFormat, + PixtralGPTHuggingfaceCheckpointFormat, Qwen2GPTHuggingfaceCheckpointFormat, Starcoder2GPTHuggingfaceCheckpointFormat, ) @@ -112,73 
+115,70 @@ def import_weight( return (merged_weight.t().contiguous(),) -class CommonHuggingfaceCheckpointHandler(HuggingfaceStateDictCheckpointHandler): - _model: GPTModel - _model_class: typing.ClassVar[FastLLMModelConfig] = GPTModelConfig - """ - Common converter for llama-based huggingface models (llama, starcoder2, mistral, mixtral) - """ - - @classmethod - def _create_config_converters(cls) -> list[ParamConverter]: - return super()._create_config_converters() + [ - ConstantImportParamConverter(fast_llm_names=(("use_position_embeddings",),), fast_llm_value=False), - RenameParamConverter( - fast_llm_names=(("transformer", "rotary", "theta"),), export_names=(("rope_theta",),) - ), - MappedConfigParamConverter( - fast_llm_names=(("transformer", "activation_type"),), - export_names=(("hidden_act",),), - fast_llm_value=ActivationType.from_hf_name, - export_value=lambda activation_type: activation_type.hf_name, - ), - RenameParamConverter( - fast_llm_names=(("transformer", "num_layers"),), - export_names=(("num_hidden_layers",),), - ), - RenameParamConverter( - fast_llm_names=(("transformer", "hidden_size"),), - export_names=(("hidden_size",),), - ), - RenameParamConverter( - fast_llm_names=(("transformer", "num_attention_heads"),), - export_names=(("num_attention_heads",),), - ), - RenameParamConverter( - fast_llm_names=(("transformer", "head_groups"),), - export_names=(("num_key_value_heads",),), - ), - RenameParamConverter( - fast_llm_names=(("transformer", "ffn_hidden_size"),), - export_names=(("intermediate_size",),), - ), - RenameParamConverter( - fast_llm_names=(("vocab_size",),), - export_names=(("vocab_size",),), - ), - RenameParamConverter( - fast_llm_names=(("tie_word_embeddings",),), - export_names=(("tie_word_embeddings",),), - ), - ] - - @abc.abstractmethod - def _get_mlp_converters(self, fast_llm_prefix: str, hf_prefix: str) -> list[WeightConverter]: - pass +class TransformerWeightConverterMixin: - def _create_weight_converters( + def _get_weight_and_bias_converters( self, + fast_llm_prefix: str | tuple[str, ...], + hf_prefix: str | tuple[str, ...], + use_bias: bool, + cls=WeightConverter, ) -> list[WeightConverter]: - converters = [] - num_layers = self._model.config.base_model.transformer.num_layers + if isinstance(fast_llm_prefix, str): + fast_llm_prefix = (fast_llm_prefix,) + if isinstance(hf_prefix, str): + hf_prefix = (hf_prefix,) + converters = [ + cls( + tuple(f"{prefix}.weight" for prefix in fast_llm_prefix), + tuple(f"{prefix}.weight" for prefix in hf_prefix), + self._model.config.base_model, + ) + ] + if use_bias: + converters.append( + cls( + tuple(f"{prefix}.bias" for prefix in fast_llm_prefix), + tuple(f"{prefix}.bias" for prefix in hf_prefix), + self._model.config.base_model, + ) + ) + return converters - # Embeddings - converters.append(WeightConverter(f"layers.0.word_embeddings_weight", f"model.embed_tokens.weight")) + def _create_lm_head_converters(self, hf_base_prefix: str = "", offset: int = 0) -> list[WeightConverter]: + num_layers = self._model.config.base_model.transformer.num_layers + prediction_heads = self._model.config.base_model.prediction_heads + norm_bias: bool = self._model.config.base_model.transformer.normalization.type == NormalizationType.layer_norm + converters = [] - converters += self._create_lm_head_converters() + # Next-token prediction head + # Final norm + converters += self._get_weight_and_bias_converters( + f"layers.{num_layers + offset + 1}.final_norm", f"{hf_base_prefix}model.norm", norm_bias + ) + # Output weights + if 
self._model.config.base_model.tie_word_embeddings: + converters.append(IgnoreImportWeightConverter((), f"{hf_base_prefix}lm_head.weight")) + else: + converters.append( + WeightConverter(f"layers.{num_layers + offset + 1}.output_weights", f"{hf_base_prefix}lm_head.weight") + ) - for i in range(num_layers): - converters += self._create_transformer_layer_converters(f"layers.{i+1}", f"model.layers.{i}") + # MTP-heads > 0 are thrown away + # TODO Soham: handle offset with MTP + for i in range(1, prediction_heads): + logger.warning( + f"The model weights for the multi-token prediction head {i} are discarded during conversion." + ) + mtp_transformer_layer_index = num_layers - 1 + 2 * i + # MTP transformer layer + converters += self._create_transformer_layer_converters( + f"layers.{mtp_transformer_layer_index + 1}", "", ignore_export=True + ) + # MTP output norm + converters += self._get_weight_and_bias_converters( + f"layers.{mtp_transformer_layer_index + 2}.final_norm", (), norm_bias, IgnoreExportWeightConverter + ) return converters @@ -249,71 +249,81 @@ def _create_transformer_layer_converters( converters += self._get_mlp_converters(f"{fast_llm_layer_name}", f"{hf_layer_name}") return converters - def _create_lm_head_converters(self, hf_base_prefix: str = "", fast_llm_offset: int = 1) -> list[WeightConverter]: - num_layers = self._model.config.base_model.transformer.num_layers - prediction_heads = self._model.config.base_model.prediction_heads - norm_bias: bool = self._model.config.base_model.transformer.normalization.type == NormalizationType.layer_norm - converters = [] - # Next-token prediction head - # Final norm - converters += self._get_weight_and_bias_converters( - f"layers.{num_layers + fast_llm_offset}.final_norm", f"{hf_base_prefix}model.norm", norm_bias - ) - # Output weights - if self._model.config.base_model.tie_word_embeddings: - converters.append(IgnoreImportWeightConverter((), f"{hf_base_prefix}lm_head.weight")) - else: - converters.append( - WeightConverter( - f"layers.{num_layers + fast_llm_offset}.output_weights", f"{hf_base_prefix}lm_head.weight" - ) - ) +class CommonHuggingfaceCheckpointHandler(TransformerWeightConverterMixin, HuggingfaceStateDictCheckpointHandler): + _model: GPTModel + _model_class: typing.ClassVar[FastLLMModelConfig] = GPTModelConfig + """ + Common converter for llama-based huggingface models (llama, starcoder2, mistral, mixtral) + """ - # MTP-heads > 0 are thrown away - # TODO Soham: handle offset with MTP - for i in range(1, prediction_heads): - logger.warning( - f"The model weights for the multi-token prediction head {i} are discarded during conversion." 
- ) - mtp_transformer_layer_index = num_layers - 1 + 2 * i - # MTP transformer layer - converters += self._create_transformer_layer_converters( - f"layers.{mtp_transformer_layer_index + 1}", "", ignore_export=True - ) - # MTP output norm - converters += self._get_weight_and_bias_converters( - f"layers.{mtp_transformer_layer_index + 2}.final_norm", (), norm_bias, IgnoreExportWeightConverter - ) + @classmethod + def _create_config_converters(cls) -> list[ParamConverter]: + return super()._create_config_converters() + [ + ConstantImportParamConverter(fast_llm_names=(("use_position_embeddings",),), fast_llm_value=False), + RenameParamConverter( + fast_llm_names=(("transformer", "rotary", "theta"),), export_names=(("rope_theta",),) + ), + MappedConfigParamConverter( + fast_llm_names=(("transformer", "activation_type"),), + export_names=(("hidden_act",),), + fast_llm_value=ActivationType.from_hf_name, + export_value=lambda activation_type: activation_type.hf_name, + ), + RenameParamConverter( + fast_llm_names=(("transformer", "num_layers"),), + export_names=(("num_hidden_layers",),), + ), + RenameParamConverter( + fast_llm_names=(("transformer", "hidden_size"),), + export_names=(("hidden_size",),), + ), + RenameParamConverter( + fast_llm_names=(("transformer", "num_attention_heads"),), + export_names=(("num_attention_heads",),), + ), + RenameParamConverter( + fast_llm_names=(("transformer", "head_groups"),), + export_names=(("num_key_value_heads",),), + ), + RenameParamConverter( + fast_llm_names=(("transformer", "ffn_hidden_size"),), + export_names=(("intermediate_size",),), + ), + RenameParamConverter( + fast_llm_names=(("vocab_size",),), + export_names=(("vocab_size",),), + ), + RenameParamConverter( + fast_llm_names=(("tie_word_embeddings",),), + export_names=(("tie_word_embeddings",),), + ), + ] - return converters + @abc.abstractmethod + def _get_mlp_converters(self, fast_llm_prefix: str, hf_prefix: str) -> list[WeightConverter]: + pass - def _get_weight_and_bias_converters( + def _create_weight_converters( self, - fast_llm_prefix: str | tuple[str, ...], - hf_prefix: str | tuple[str, ...], - use_bias: bool, - cls=WeightConverter, + hf_base_prefix: str = "", + offset: int = 0, ) -> list[WeightConverter]: - if isinstance(fast_llm_prefix, str): - fast_llm_prefix = (fast_llm_prefix,) - if isinstance(hf_prefix, str): - hf_prefix = (hf_prefix,) - converters = [ - cls( - tuple(f"{prefix}.weight" for prefix in fast_llm_prefix), - tuple(f"{prefix}.weight" for prefix in hf_prefix), - self._model.config.base_model, - ) - ] - if use_bias: - converters.append( - cls( - tuple(f"{prefix}.bias" for prefix in fast_llm_prefix), - tuple(f"{prefix}.bias" for prefix in hf_prefix), - self._model.config.base_model, - ) + converters = [] + num_layers = self._model.config.base_model.transformer.num_layers + + # Embeddings + converters.append( + WeightConverter(f"layers.{offset}.word_embeddings_weight", f"{hf_base_prefix}model.embed_tokens.weight") + ) + + converters += self._create_lm_head_converters(hf_base_prefix, offset=offset) + + for i in range(num_layers): + converters += self._create_transformer_layer_converters( + f"layers.{i+offset+1}", f"{hf_base_prefix}model.layers.{i}" ) + return converters @@ -555,266 +565,592 @@ def _get_mlp_converters(self, fast_llm_prefix: str, hf_prefix: str) -> list[Weig ] -class LlavaHuggingfaceCheckpointHandler(MistralHuggingfaceCheckpointHandler): - format: typing.ClassVar[type[CheckpointFormat]] = LlavaGPTHuggingfaceCheckpointFormat +class 
PixtralHuggingfaceCheckpointHandler(TransformerWeightConverterMixin, HuggingfaceStateDictCheckpointHandler): + format: typing.ClassVar[type[CheckpointFormat]] = PixtralGPTHuggingfaceCheckpointFormat + _model_class: typing.ClassVar[FastLLMModelConfig] = GPTModelConfig @classmethod def _create_config_converters(cls) -> list[ParamConverter]: - # lm_converters = super()._create_config_converters() - lm_converters = super()._create_config_converters() - for idx, converter in enumerate(lm_converters): - if converter.export_names == (("model_type",),): - continue - elif converter.export_names == (("architectures",),): - ignore_index = idx - if converter.export_names: - converter.export_names = (("text_config", *converter.export_names[0]), *converter.export_names[1:]) - - return ( - lm_converters[:ignore_index] - + lm_converters[ignore_index + 1 :] - + [ - ConstantImportParamConverter( - fast_llm_names=(("vision_encoder", "type"),), fast_llm_value=VisionEncoderType.pixtral - ), - ConstantExportParamConverter( - export_names=(("architectures",),), export_value=["LlavaForConditionalGeneration"] - ), - # Vision Adapter - RenameParamConverter( - fast_llm_names=(("vision_encoder", "adapter_size"),), - export_names=(("text_config", "hidden_size"),), - ), - ConstantImportParamConverter( - fast_llm_names=(("vision_encoder", "patch_norm", "type"),), - fast_llm_value=NormalizationType.rms_norm, - ), - ConstantImportParamConverter( - fast_llm_names=(("vision_encoder", "transformer", "normalization", "type"),), - fast_llm_value=NormalizationType.rms_norm, - ), - # Vision Transformer - RenameParamConverter( - fast_llm_names=(("vision_encoder", "transformer", "num_layers"),), - export_names=( - ( - "vision_config", - "num_hidden_layers", - ), + return super()._create_config_converters() + [ + ConstantExportParamConverter(export_names=(("architectures",),), export_value=["PixtralVisionModel"]), + ConstantImportParamConverter(fast_llm_names=(("type",),), fast_llm_value=VisionEncoderType.pixtral), + ConstantImportParamConverter(fast_llm_names=(("transformer", "causal"),), fast_llm_value=False), + RenameParamConverter( + fast_llm_names=( + ( + "transformer", + "num_layers", ), ), - RenameParamConverter( - fast_llm_names=(("vision_encoder", "transformer", "hidden_size"),), - export_names=( - ( - "vision_config", - "hidden_size", - ), + export_names=(("num_hidden_layers",),), + ), + RenameParamConverter( + fast_llm_names=( + ( + "transformer", + "hidden_size", ), ), - RenameParamConverter( - fast_llm_names=(("vision_encoder", "transformer", "num_attention_heads"),), - export_names=( - ( - "vision_config", - "num_attention_heads", - ), + export_names=(("hidden_size",),), + ), + RenameParamConverter( + fast_llm_names=( + ( + "transformer", + "num_attention_heads", ), ), - RenameParamConverter( - fast_llm_names=(("vision_encoder", "transformer", "head_groups"),), - export_names=( - ( - "vision_config", - "num_key_value_heads", - ), + export_names=(("num_attention_heads",),), + ), + RenameParamConverter( + fast_llm_names=( + ( + "transformer", + "head_groups", ), ), - RenameParamConverter( - fast_llm_names=(("vision_encoder", "transformer", "ffn_hidden_size"),), - export_names=( - ( - "vision_config", - "intermediate_size", - ), + export_names=(("num_key_value_heads",),), + ), + RenameParamConverter( + fast_llm_names=( + ( + "transformer", + "ffn_hidden_size", ), ), - MappedConfigParamConverter( - fast_llm_names=(("vision_encoder", "transformer", "activation_type"),), - export_names=( - ( - "vision_config", - 
"hidden_act", - ), + export_names=(("intermediate_size",),), + ), + MappedConfigParamConverter( + fast_llm_names=(("transformer", "activation_type"),), + export_names=(("hidden_act",),), + fast_llm_value=ActivationType.from_hf_name, + export_value=lambda activation_type: activation_type.hf_name, + ), + RenameParamConverter( + fast_llm_names=( + ( + "transformer", + "kv_channels", ), - fast_llm_value=ActivationType.from_hf_name, - export_value=lambda activation_type: activation_type.hf_name, - ), - ConstantImportParamConverter( - fast_llm_names=(("vision_encoder", "transformer", "gated"),), fast_llm_value=True - ), - MappedConfigParamConverter( - fast_llm_names=(("vision_encoder", "adapter_activation_type"),), - export_names=(("projector_hidden_act",),), - fast_llm_value=ActivationType.from_hf_name, - export_value=lambda activation_type: activation_type.hf_name, ), - ConstantImportParamConverter( - fast_llm_names=(("vision_encoder", "transformer", "add_linear_biases"),), fast_llm_value=False - ), - RenameParamConverter( - fast_llm_names=(("vision_encoder", "transformer", "rotary", "theta"),), - export_names=(("vision_config", "rope_theta"),), + export_names=(("head_dim",),), + ), + RenameParamConverter( + fast_llm_names=( + ( + "transformer", + "rotary", + "theta", + ), ), - ] - ) + export_names=(("rope_theta",),), + ), + RenameParamConverter(fast_llm_names=(("patch_size",),), export_names=(("patch_size",),)), + ] + + def _get_transformer_mlp_converters(self, fast_llm_prefix: str, hf_prefix: str) -> list[WeightConverter]: + return [ + SplitWeightConverter( + f"{fast_llm_prefix}.mlp.layer_1.weight", + (f"{hf_prefix}.feed_forward.gate_proj.weight", f"{hf_prefix}.feed_forward.up_proj.weight"), + ), + MLPLayer2Converter( + f"{fast_llm_prefix}.mlp.layer_2.weight", + f"{hf_prefix}.feed_forward.down_proj.weight", + self._model.config.base_model, + ), + ] def _create_vision_transformer_layer_converters( - self, - i: int, - ignore_export: bool = False, - hf_base_prefix: str = "", - fast_llm_offset: int = 1, - type: str | None = None, + self, transformer_layer_index: int, fast_llm_offset: int = 1, hf_base_prefix: str = "" ) -> list[WeightConverter]: - if type is not None: - if type == "vision": - transformer_config: TransformerConfig = self._model.config.base_model.vision_encoder.transformer - else: - transformer_config: TransformerConfig = self._model.config.base_model.transformer - norm_bias: bool = self._model.config.base_model.transformer.normalization.type == NormalizationType.layer_norm - converters = [] - names_bias_cls = [ + # Vision transformer layer + transformer_config = self._model.config.base_model.vision_encoder.transformer + norm_bias: bool = transformer_config.normalization.type == NormalizationType.layer_norm + name_bias_cls = [ # Self-attn ( - f"layers.{i+fast_llm_offset}.self_attn.query", - f"vision_tower.transformer.layers.{i}.attention.q_proj", + f"layers.{fast_llm_offset + transformer_layer_index}.self_attn.query", + f"{hf_base_prefix}transformer.layers.{transformer_layer_index}.attention.q_proj", transformer_config.add_attn_qkv_bias, QueryWeightConverter, ), ( - f"layers.{i+fast_llm_offset}.self_attn.key_value", + f"layers.{fast_llm_offset + transformer_layer_index}.self_attn.key_value", ( - f"vision_tower.transformer.layers.{i}.attention.k_proj", - f"vision_tower.transformer.layers.{i}.attention.v_proj", + f"{hf_base_prefix}transformer.layers.{transformer_layer_index}.attention.k_proj", + f"{hf_base_prefix}transformer.layers.{transformer_layer_index}.attention.v_proj", ), 
transformer_config.add_attn_qkv_bias, KeyValueWeightConverter, ), ( - f"layers.{i+fast_llm_offset}.self_attn.dense", - f"vision_tower.transformer.layers.{i}.attention.o_proj", + f"layers.{fast_llm_offset + transformer_layer_index}.self_attn.dense", + f"{hf_base_prefix}transformer.layers.{transformer_layer_index}.attention.o_proj", transformer_config.add_attn_dense_bias, WeightConverter, ), # Norm ( - f"layers.{i+fast_llm_offset}.norm_1", - f"vision_tower.transformer.layers.{i}.attention_norm", + f"layers.{fast_llm_offset + transformer_layer_index}.norm_1", + f"{hf_base_prefix}transformer.layers.{transformer_layer_index}.attention_norm", norm_bias, WeightConverter, ), ( - f"layers.{i+fast_llm_offset}.norm_2", - f"vision_tower.transformer.layers.{i}.ffn_norm", + f"layers.{fast_llm_offset + transformer_layer_index}.norm_2", + f"{hf_base_prefix}transformer.layers.{transformer_layer_index}.ffn_norm", norm_bias, WeightConverter, ), ] - for fast_llm_prefix, hf_prefix, use_bias, cls in names_bias_cls: + converters = [] + for fast_llm_prefix, hf_prefix, use_bias, cls in name_bias_cls: converters += self._get_weight_and_bias_converters( fast_llm_prefix, - () if ignore_export else hf_prefix, + hf_prefix, use_bias, - cls=IgnoreExportWeightConverter if ignore_export else cls, + cls, ) - # MLP - if ignore_export: - converters += self._get_weight_and_bias_converters( - f"layers.{i+fast_llm_offset}.mlp.layer_1", - (), - transformer_config.add_mlp_bias, - cls=IgnoreExportWeightConverter, - ) - converters += self._get_weight_and_bias_converters( - f"layers.{i+fast_llm_offset}.mlp.layer_2", - (), - transformer_config.add_mlp_bias, - cls=IgnoreExportWeightConverter, - ) - converters += [IgnoreExportWeightConverter(f"layers.{i+fast_llm_offset}.mlp.router.weight", ())] - else: - converters += self._get_vision_transformer_mlp_converters( - f"layers.{i+fast_llm_offset}", f"vision_tower.transformer.layers.{i}" - ) + converters += self._get_transformer_mlp_converters( + f"layers.{fast_llm_offset + transformer_layer_index}", + f"{hf_base_prefix}transformer.layers.{transformer_layer_index}", + ) return converters - def _get_vision_transformer_mlp_converters(self, fast_llm_prefix: str, hf_prefix: str) -> list[WeightConverter]: - return [ - SplitWeightConverter( - f"{fast_llm_prefix}.mlp.layer_1.weight", - (f"{hf_prefix}.feed_forward.gate_proj.weight", f"{hf_prefix}.feed_forward.up_proj.weight"), - ), - MLPLayer2Converter( - f"{fast_llm_prefix}.mlp.layer_2.weight", - f"{hf_prefix}.feed_forward.down_proj.weight", - self._model.config.base_model, - ), - ] + def _create_weight_converters(self, offset: int = 0, hf_base_prefix: str = "") -> list[WeightConverter]: + converters = [] + converters.append(WeightConverter(f"layers.{offset}.weight", f"{hf_base_prefix}patch_conv.weight")) + if self._model.config.base_model.vision_encoder.conv_bias: + converters.append(WeightConverter(f"layers.{offset}.bias", f"{hf_base_prefix}patch_conv.bias")) + converters.append(WeightConverter(f"layers.{offset}.norm.weight", f"{hf_base_prefix}ln_pre.weight")) + if self._model.config.base_model.vision_encoder.patch_norm.type == NormalizationType.layer_norm: + converters.append(WeightConverter(f"layers.{offset}.norm.bias", f"{hf_base_prefix}ln_pre.bias")) - def _create_vision_transformer_converters(self) -> list[WeightConverter]: num_layers = self._model.config.base_model.vision_encoder.transformer.num_layers - vision_transformer_converters = [] - for layer in range(num_layers): - # TODO Soham: check if args are correct - 
vision_transformer_converters.extend( - self._create_vision_transformer_layer_converters( - layer, - ignore_export=False, - hf_base_prefix="vision_tower.transformer.layers.", - fast_llm_offset=1, - type="vision", - ) + for i in range(num_layers): + converters += self._create_vision_transformer_layer_converters(i, offset + 1, hf_base_prefix) + + converters.extend( + [ + WeightConverter( + f"layers.{offset + num_layers + 1}.layer_1.weight", "multi_modal_projector.linear_1.weight" + ), + WeightConverter( + f"layers.{offset + num_layers + 1}.layer_2.weight", "multi_modal_projector.linear_2.weight" + ), + ] + ) + if self._model.config.base_model.vision_encoder.adapter_bias: + converters.extend( + [ + WeightConverter( + f"layers.{offset + num_layers + 1}.layer_1.bias", "multi_modal_projector.linear_1.bias" + ), + WeightConverter( + f"layers.{offset + num_layers + 1}.layer_2.bias", "multi_modal_projector.linear_2.bias" + ), + ] ) - return vision_transformer_converters + return converters - def _create_vision_encoder_weight_converters(self) -> list[WeightConverter]: - patch_conv_converters = [WeightConverter("layers.0.weight", "vision_tower.patch_conv.weight")] - if self._model.config.base_model.vision_encoder.conv_bias: - patch_conv_converters.append(WeightConverter("layers.0.bias", "vision_tower.patch_conv.bias")) - layernorm_converters = [ - WeightConverter("layers.0.norm.weight", "vision_tower.ln_pre.weight"), - ] - if self._model.config.base_model.vision_encoder.patch_norm.type == NormalizationType.layer_norm: - layernorm_converters.append(WeightConverter("layers.0.norm.bias", "vision_tower.ln_pre.bias")) - - vision_transformer_converters = self._create_vision_transformer_converters() - offset = self._model.config.base_model.vision_encoder.transformer.num_layers + 1 - adapter_converters = [ - WeightConverter(f"layers.{offset}.layer_1.weight", "multi_modal_projector.linear_1.weight"), - WeightConverter(f"layers.{offset}.layer_1.bias", "multi_modal_projector.linear_1.bias"), - # TODO Soham: add bias based on config - WeightConverter(f"layers.{offset}.layer_2.weight", "multi_modal_projector.linear_2.weight"), - WeightConverter(f"layers.{offset}.layer_2.bias", "multi_modal_projector.linear_2.bias"), - ] + @property + def num_layers(self) -> int: + # +2 for projector and conv layers + return self._model.config.base_model.vision_encoder.transformer.num_layers + 2 - return patch_conv_converters + layernorm_converters + vision_transformer_converters + adapter_converters - def _create_weight_converters(self) -> list[WeightConverter]: - vision_encoder_converter = self._create_vision_encoder_weight_converters() - offset = self._model.config.base_model.vision_encoder.transformer.num_layers + 3 - # Embeddings - lm_converters = [ - WeightConverter(f"layers.{offset - 1}.word_embeddings_weight", f"language_model.model.embed_tokens.weight") - ] - for i in range(self._model.config.base_model.transformer.num_layers): - lm_converters += self._create_transformer_layer_converters( - fast_llm_layer_name=f"layers.{i + offset}", hf_layer_name=f"language_model.model.layers.{i}" +class LlavaHuggingfaceCheckpointHandler(TransformerWeightConverterMixin, HuggingfaceStateDictCheckpointHandler): + format: typing.ClassVar[type[CheckpointFormat]] = LlavaGPTHuggingfaceCheckpointFormat + _model_class: typing.ClassVar[FastLLMModelConfig] = GPTModelConfig + + @classmethod + def _load_metadata(cls, config: CheckpointLoadMetadataConfig) -> CheckpointMetadata: + cfg_dict = cls._load_config(config.path) + kwargs = {} + if 
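As a quick aside on the converter layout established above: the Pixtral handler places the patch convolution at layer 0, the vision transformer blocks at layers 1..N, and the two-layer adapter at layer N+1, which is why num_layers reports N+2. Below is a minimal, self-contained sketch of that naming scheme; it uses plain strings only, the helper name is made up, and it does not use the real WeightConverter classes.

def pixtral_weight_name_map(num_transformer_layers: int, hf_prefix: str = "vision_tower.") -> dict[str, str]:
    # Layer 0: patch convolution and pre-normalization.
    mapping = {
        "layers.0.weight": f"{hf_prefix}patch_conv.weight",
        "layers.0.norm.weight": f"{hf_prefix}ln_pre.weight",
    }
    # Layers 1..N: vision transformer blocks (only the query projection is shown here).
    for i in range(num_transformer_layers):
        mapping[f"layers.{i + 1}.self_attn.query.weight"] = (
            f"{hf_prefix}transformer.layers.{i}.attention.q_proj.weight"
        )
    # Layer N + 1: the two-layer multimodal adapter.
    adapter = num_transformer_layers + 1
    mapping[f"layers.{adapter}.layer_1.weight"] = "multi_modal_projector.linear_1.weight"
    mapping[f"layers.{adapter}.layer_2.weight"] = "multi_modal_projector.linear_2.weight"
    return mapping

if __name__ == "__main__":
    for fast_llm_name, hf_name in pixtral_weight_name_map(2).items():
        print(f"{fast_llm_name} -> {hf_name}")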
"text_config" in cfg_dict: + text_kwargs = cls._import_config(cfg_dict["text_config"]) + kwargs.update(text_kwargs) + if "vision_config" in cfg_dict: + vision_kwargs = cls._import_config(cfg_dict["vision_config"]) + vision_kwargs = {tuple(["vision_encoder"] + list(key)): value for key, value in vision_kwargs.items()} + kwargs.update(vision_kwargs) + kwargs.update( + cls._import_config( + {key: value for key, value in cfg_dict.items() if key not in ("text_config", "vision_config")} ) - lm_converters += self._create_lm_head_converters(hf_base_prefix="language_model.", fast_llm_offset=offset) - return vision_encoder_converter + lm_converters + ) + imported_model_config = cls._model_class.get_base_model_config_class().from_dict({}, kwargs) + return CheckpointMetadata( + fast_llm_version=__version__, + model=cls._model_class, + format=config.format, + config=cls._model_class.from_dict({"base_model": imported_model_config.to_dict()}), + shards=["weights"], + ) + + @classmethod + def _create_config_converters(cls) -> list[ParamConverter]: + return super()._create_config_converters() + [ + ConstantExportParamConverter( + export_names=(("architectures",),), export_value=["LlavaForConditionalGeneration"] + ), + MappedConfigParamConverter( + fast_llm_names=(("vision_encoder", "adapter_activation_type"),), + export_names=(("projector_hidden_act",),), + fast_llm_value=ActivationType.from_hf_name, + export_value=lambda activation_type: activation_type.hf_name, + ), + ] + + @classmethod + def _import_config(cls, config: dict[str, typing.Any]) -> GPTBaseModelConfig: + handler_cls = AutoGPTHuggingfaceCheckpointHandler.get_handler_class(config["model_type"]) + kwargs = {} + for converter in handler_cls._create_config_converters(): + try: + values = () + for export_name in converter.export_names: + try: + value = get_nested_dict_value(config, export_name) + except KeyError: + value = MISSING + values = values + (value,) + values = converter.import_params(values) + for fast_llm_name, value in zip(converter.fast_llm_names, values, strict=True): + if value is MISSING: + raise ValueError(f"Missing converted value for fast-llm parameter {fast_llm_name}") + if fast_llm_name in kwargs: + raise ValueError(f"Duplicate converted value for fast-llm parameter {fast_llm_name}") + kwargs[fast_llm_name] = value + except Exception as e: + raise RuntimeError(f"Config conversion failed for converter {converter}", *e.args) + + return kwargs + + @classmethod + def _export_config(cls, config: BaseModelConfig) -> dict[str, typing.Any]: + exported_config = {} + vision_handler_cls = AutoGPTHuggingfaceCheckpointHandler.get_handler_class(cls.format.vision_name) + text_handler_cls = AutoGPTHuggingfaceCheckpointHandler.get_handler_class(cls.format.text_name) + for converter in vision_handler_cls._create_config_converters(): + try: + values = converter.export_params( + tuple( + cls._get_fast_llm_attribute(config, ("vision_encoder",) + fast_llm_name) + for fast_llm_name in converter.fast_llm_names + ) + ) + for export_name, value in zip(converter.export_names, values, strict=True): + if value is not MISSING: + set_nested_dict_value(exported_config, ("vision_config",) + export_name, value) + except Exception as e: + raise RuntimeError(f"Config conversion failed for converter {converter}", *e.args) + + for converter in text_handler_cls._create_config_converters(): + try: + values = converter.export_params( + tuple( + cls._get_fast_llm_attribute(config, fast_llm_name) + for fast_llm_name in converter.fast_llm_names + ) + ) + for 
export_name, value in zip(converter.export_names, values, strict=True): + if value is not MISSING: + set_nested_dict_value(exported_config, ("text_config",) + export_name, value) + except Exception as e: + raise RuntimeError(f"Config conversion failed for converter {converter}", *e.args) + + for converter in cls._create_config_converters(): + try: + values = converter.export_params( + tuple( + cls._get_fast_llm_attribute(config, fast_llm_name) + for fast_llm_name in converter.fast_llm_names + ) + ) + for export_name, value in zip(converter.export_names, values, strict=True): + if value is not MISSING: + set_nested_dict_value(exported_config, export_name, value) + except Exception as e: + raise RuntimeError(f"Config conversion failed for converter {converter}", *e.args) + + return exported_config + + def _create_weight_converters(self): + vision_handler_cls = AutoGPTHuggingfaceCheckpointHandler.get_handler_class(self.format.vision_name) + vision_handler = vision_handler_cls(self._model) + converters = vision_handler._create_weight_converters(hf_base_prefix="vision_tower.", offset=0) + text_handler_cls = AutoGPTHuggingfaceCheckpointHandler.get_handler_class(self.format.text_name) + text_handler = text_handler_cls(self._model) + converters.extend( + text_handler._create_weight_converters(hf_base_prefix="language_model.", offset=vision_handler.num_layers) + ) + return converters + + +# class LlavaHuggingfaceCheckpointHandler(MistralHuggingfaceCheckpointHandler): +# format: typing.ClassVar[type[CheckpointFormat]] = LlavaGPTHuggingfaceCheckpointFormat + +# @classmethod +# def _create_config_converters(cls) -> list[ParamConverter]: +# # lm_converters = super()._create_config_converters() +# lm_converters = super()._create_config_converters() +# for idx, converter in enumerate(lm_converters): +# if converter.export_names == (("model_type",),): +# continue +# elif converter.export_names == (("architectures",),): +# ignore_index = idx +# if converter.export_names: +# converter.export_names = (("text_config", *converter.export_names[0]), *converter.export_names[1:]) + +# return ( +# lm_converters[:ignore_index] +# + lm_converters[ignore_index + 1 :] +# + [ +# ConstantImportParamConverter( +# fast_llm_names=(("vision_encoder", "type"),), fast_llm_value=VisionEncoderType.pixtral +# ), +# ConstantExportParamConverter( +# export_names=(("architectures",),), export_value=["LlavaForConditionalGeneration"] +# ), +# # Vision Adapter +# RenameParamConverter( +# fast_llm_names=(("vision_encoder", "adapter_size"),), +# export_names=(("text_config", "hidden_size"),), +# ), +# ConstantImportParamConverter( +# fast_llm_names=(("vision_encoder", "patch_norm", "type"),), +# fast_llm_value=NormalizationType.rms_norm, +# ), +# ConstantImportParamConverter( +# fast_llm_names=(("vision_encoder", "transformer", "normalization", "type"),), +# fast_llm_value=NormalizationType.rms_norm, +# ), +# # Vision Transformer +# RenameParamConverter( +# fast_llm_names=(("vision_encoder", "transformer", "num_layers"),), +# export_names=( +# ( +# "vision_config", +# "num_hidden_layers", +# ), +# ), +# ), +# RenameParamConverter( +# fast_llm_names=(("vision_encoder", "transformer", "hidden_size"),), +# export_names=( +# ( +# "vision_config", +# "hidden_size", +# ), +# ), +# ), +# RenameParamConverter( +# fast_llm_names=(("vision_encoder", "transformer", "num_attention_heads"),), +# export_names=( +# ( +# "vision_config", +# "num_attention_heads", +# ), +# ), +# ), +# RenameParamConverter( +# fast_llm_names=(("vision_encoder", 
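The import/export logic above shuttles parameters between a flat Fast-LLM namespace and the nested vision_config/text_config sections of the Hugging Face Llava config. The rough sketch below shows the nested-dictionary pattern it relies on; get_nested/set_nested are simplified stand-ins for the real get_nested_dict_value/set_nested_dict_value helpers, not their actual implementations.

def get_nested(config: dict, path: tuple[str, ...]):
    value = config
    for key in path:
        value = value[key]
    return value

def set_nested(config: dict, path: tuple[str, ...], value) -> None:
    for key in path[:-1]:
        config = config.setdefault(key, {})
    config[path[-1]] = value

if __name__ == "__main__":
    hf_config = {
        "text_config": {"hidden_size": 4096},
        "vision_config": {"hidden_size": 1024},
        "projector_hidden_act": "gelu",
    }
    exported: dict = {}
    # Vision parameters are re-nested under ("vision_config", ...) on export, mirroring how
    # the handler prefixes imported vision keys with ("vision_encoder",) on the Fast-LLM side.
    set_nested(exported, ("vision_config", "hidden_size"), get_nested(hf_config, ("vision_config", "hidden_size")))
    set_nested(exported, ("text_config", "hidden_size"), get_nested(hf_config, ("text_config", "hidden_size")))
    set_nested(exported, ("projector_hidden_act",), get_nested(hf_config, ("projector_hidden_act",)))
    print(exported)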
"transformer", "head_groups"),), +# export_names=( +# ( +# "vision_config", +# "num_key_value_heads", +# ), +# ), +# ), +# RenameParamConverter( +# fast_llm_names=(("vision_encoder", "transformer", "ffn_hidden_size"),), +# export_names=( +# ( +# "vision_config", +# "intermediate_size", +# ), +# ), +# ), +# MappedConfigParamConverter( +# fast_llm_names=(("vision_encoder", "transformer", "activation_type"),), +# export_names=( +# ( +# "vision_config", +# "hidden_act", +# ), +# ), +# fast_llm_value=ActivationType.from_hf_name, +# export_value=lambda activation_type: activation_type.hf_name, +# ), +# ConstantImportParamConverter( +# fast_llm_names=(("vision_encoder", "transformer", "gated"),), fast_llm_value=True +# ), +# MappedConfigParamConverter( +# fast_llm_names=(("vision_encoder", "adapter_activation_type"),), +# export_names=(("projector_hidden_act",),), +# fast_llm_value=ActivationType.from_hf_name, +# export_value=lambda activation_type: activation_type.hf_name, +# ), +# ConstantImportParamConverter( +# fast_llm_names=(("vision_encoder", "transformer", "add_linear_biases"),), fast_llm_value=False +# ), +# RenameParamConverter( +# fast_llm_names=(("vision_encoder", "transformer", "rotary", "theta"),), +# export_names=(("vision_config", "rope_theta"),), +# ), +# ] +# ) + +# def _create_vision_transformer_layer_converters( +# self, +# i: int, +# ignore_export: bool = False, +# hf_base_prefix: str = "", +# fast_llm_offset: int = 1, +# type: str | None = None, +# ) -> list[WeightConverter]: +# if type is not None: +# if type == "vision": +# transformer_config: TransformerConfig = self._model.config.base_model.vision_encoder.transformer +# else: +# transformer_config: TransformerConfig = self._model.config.base_model.transformer +# norm_bias: bool = self._model.config.base_model.transformer.normalization.type == NormalizationType.layer_norm +# converters = [] +# names_bias_cls = [ +# # Self-attn +# ( +# f"layers.{i+fast_llm_offset}.self_attn.query", +# f"vision_tower.transformer.layers.{i}.attention.q_proj", +# transformer_config.add_attn_qkv_bias, +# QueryWeightConverter, +# ), +# ( +# f"layers.{i+fast_llm_offset}.self_attn.key_value", +# ( +# f"vision_tower.transformer.layers.{i}.attention.k_proj", +# f"vision_tower.transformer.layers.{i}.attention.v_proj", +# ), +# transformer_config.add_attn_qkv_bias, +# KeyValueWeightConverter, +# ), +# ( +# f"layers.{i+fast_llm_offset}.self_attn.dense", +# f"vision_tower.transformer.layers.{i}.attention.o_proj", +# transformer_config.add_attn_dense_bias, +# WeightConverter, +# ), +# # Norm +# ( +# f"layers.{i+fast_llm_offset}.norm_1", +# f"vision_tower.transformer.layers.{i}.attention_norm", +# norm_bias, +# WeightConverter, +# ), +# ( +# f"layers.{i+fast_llm_offset}.norm_2", +# f"vision_tower.transformer.layers.{i}.ffn_norm", +# norm_bias, +# WeightConverter, +# ), +# ] +# for fast_llm_prefix, hf_prefix, use_bias, cls in names_bias_cls: +# converters += self._get_weight_and_bias_converters( +# fast_llm_prefix, +# () if ignore_export else hf_prefix, +# use_bias, +# cls=IgnoreExportWeightConverter if ignore_export else cls, +# ) + +# # MLP +# if ignore_export: +# converters += self._get_weight_and_bias_converters( +# f"layers.{i+fast_llm_offset}.mlp.layer_1", +# (), +# transformer_config.add_mlp_bias, +# cls=IgnoreExportWeightConverter, +# ) +# converters += self._get_weight_and_bias_converters( +# f"layers.{i+fast_llm_offset}.mlp.layer_2", +# (), +# transformer_config.add_mlp_bias, +# cls=IgnoreExportWeightConverter, +# ) +# converters += 
[IgnoreExportWeightConverter(f"layers.{i+fast_llm_offset}.mlp.router.weight", ())] +# else: +# converters += self._get_vision_transformer_mlp_converters( +# f"layers.{i+fast_llm_offset}", f"vision_tower.transformer.layers.{i}" +# ) +# return converters + +# def _get_vision_transformer_mlp_converters(self, fast_llm_prefix: str, hf_prefix: str) -> list[WeightConverter]: +# return [ +# SplitWeightConverter( +# f"{fast_llm_prefix}.mlp.layer_1.weight", +# (f"{hf_prefix}.feed_forward.gate_proj.weight", f"{hf_prefix}.feed_forward.up_proj.weight"), +# ), +# MLPLayer2Converter( +# f"{fast_llm_prefix}.mlp.layer_2.weight", +# f"{hf_prefix}.feed_forward.down_proj.weight", +# self._model.config.base_model, +# ), +# ] + +# def _create_vision_transformer_converters(self) -> list[WeightConverter]: +# num_layers = self._model.config.base_model.vision_encoder.transformer.num_layers +# vision_transformer_converters = [] +# for layer in range(num_layers): +# # TODO Soham: check if args are correct +# vision_transformer_converters.extend( +# self._create_vision_transformer_layer_converters( +# layer, +# ignore_export=False, +# hf_base_prefix="vision_tower.transformer.layers.", +# fast_llm_offset=1, +# type="vision", +# ) +# ) + +# return vision_transformer_converters + +# def _create_vision_encoder_weight_converters(self) -> list[WeightConverter]: +# patch_conv_converters = [WeightConverter("layers.0.weight", "vision_tower.patch_conv.weight")] +# if self._model.config.base_model.vision_encoder.conv_bias: +# patch_conv_converters.append(WeightConverter("layers.0.bias", "vision_tower.patch_conv.bias")) +# layernorm_converters = [ +# WeightConverter("layers.0.norm.weight", "vision_tower.ln_pre.weight"), +# ] +# if self._model.config.base_model.vision_encoder.patch_norm.type == NormalizationType.layer_norm: +# layernorm_converters.append(WeightConverter("layers.0.norm.bias", "vision_tower.ln_pre.bias")) + +# vision_transformer_converters = self._create_vision_transformer_converters() +# offset = self._model.config.base_model.vision_encoder.transformer.num_layers + 1 +# adapter_converters = [ +# WeightConverter(f"layers.{offset}.layer_1.weight", "multi_modal_projector.linear_1.weight"), +# WeightConverter(f"layers.{offset}.layer_1.bias", "multi_modal_projector.linear_1.bias"), +# # TODO Soham: add bias based on config +# WeightConverter(f"layers.{offset}.layer_2.weight", "multi_modal_projector.linear_2.weight"), +# WeightConverter(f"layers.{offset}.layer_2.bias", "multi_modal_projector.linear_2.bias"), +# ] + +# return patch_conv_converters + layernorm_converters + vision_transformer_converters + adapter_converters + +# def _create_weight_converters(self) -> list[WeightConverter]: +# vision_encoder_converter = self._create_vision_encoder_weight_converters() +# offset = self._model.config.base_model.vision_encoder.transformer.num_layers + 3 +# # Embeddings +# lm_converters = [ +# WeightConverter(f"layers.{offset - 1}.word_embeddings_weight", f"language_model.model.embed_tokens.weight") +# ] +# for i in range(self._model.config.base_model.transformer.num_layers): +# lm_converters += self._create_transformer_layer_converters( +# fast_llm_layer_name=f"layers.{i + offset}", hf_layer_name=f"language_model.model.layers.{i}" +# ) +# lm_converters += self._create_lm_head_converters(hf_base_prefix="language_model.", fast_llm_offset=offset) +# return vision_encoder_converter + lm_converters class MixtralHuggingfaceCheckpointHandler(CommonLlamaHuggingfaceCheckpointHandler): @@ -950,4 +1286,6 @@ class 
AutoGPTHuggingfaceCheckpointHandler( MixtralGPTHuggingfaceCheckpointFormat.name: MixtralHuggingfaceCheckpointHandler, MTPLlamaGPTHuggingfaceCheckpointFormat.name: MTPLlamaHuggingfaceCheckpointHandler, LlavaGPTHuggingfaceCheckpointFormat.name: LlavaHuggingfaceCheckpointHandler, + PixtralGPTHuggingfaceCheckpointFormat.name: PixtralHuggingfaceCheckpointHandler, + # MultiModalGPTHuggingfaceCheckpointFormat.name: MultiModalHuggingfaceCheckpointHandler } diff --git a/fast_llm/models/gpt/model.py b/fast_llm/models/gpt/model.py index c1d9df90..72ff1b88 100644 --- a/fast_llm/models/gpt/model.py +++ b/fast_llm/models/gpt/model.py @@ -20,6 +20,7 @@ TransformerDimNames, TransformerKwargs, TransformerLossNames, + VisionTransformerDimNames, ) from fast_llm.layers.transformer.preprocessing import ( BackupAttentionPreprocessor, @@ -29,7 +30,7 @@ from fast_llm.layers.transformer.transformer import TransformerLayer from fast_llm.layers.transformer.vision_transformer import VisionTransformerLayer from fast_llm.layers.vision_encoder.adapter import VisionAdapter -from fast_llm.layers.vision_encoder.config import VisionEncoderDimNames, VisionEncoderKwargs, VisionTransformerDimNames +from fast_llm.layers.vision_encoder.config import VisionEncoderKwargs from fast_llm.layers.vision_encoder.encoder import PatchConv from fast_llm.layers.vision_encoder.preprocessing import VisionPreprocessor from fast_llm.models.gpt.config import GPTBaseModelConfig, GPTBatchConfig, GPTModelConfig From f3a4a74a086f5cb81da86195a00d6549cf66844b Mon Sep 17 00:00:00 2001 From: root Date: Wed, 21 May 2025 21:00:11 +0000 Subject: [PATCH 48/82] cleanup --- fast_llm/models/gpt/conversion.py | 262 ------------------------------ 1 file changed, 262 deletions(-) diff --git a/fast_llm/models/gpt/conversion.py b/fast_llm/models/gpt/conversion.py index 0b0796ed..35652547 100644 --- a/fast_llm/models/gpt/conversion.py +++ b/fast_llm/models/gpt/conversion.py @@ -891,268 +891,6 @@ def _create_weight_converters(self): return converters -# class LlavaHuggingfaceCheckpointHandler(MistralHuggingfaceCheckpointHandler): -# format: typing.ClassVar[type[CheckpointFormat]] = LlavaGPTHuggingfaceCheckpointFormat - -# @classmethod -# def _create_config_converters(cls) -> list[ParamConverter]: -# # lm_converters = super()._create_config_converters() -# lm_converters = super()._create_config_converters() -# for idx, converter in enumerate(lm_converters): -# if converter.export_names == (("model_type",),): -# continue -# elif converter.export_names == (("architectures",),): -# ignore_index = idx -# if converter.export_names: -# converter.export_names = (("text_config", *converter.export_names[0]), *converter.export_names[1:]) - -# return ( -# lm_converters[:ignore_index] -# + lm_converters[ignore_index + 1 :] -# + [ -# ConstantImportParamConverter( -# fast_llm_names=(("vision_encoder", "type"),), fast_llm_value=VisionEncoderType.pixtral -# ), -# ConstantExportParamConverter( -# export_names=(("architectures",),), export_value=["LlavaForConditionalGeneration"] -# ), -# # Vision Adapter -# RenameParamConverter( -# fast_llm_names=(("vision_encoder", "adapter_size"),), -# export_names=(("text_config", "hidden_size"),), -# ), -# ConstantImportParamConverter( -# fast_llm_names=(("vision_encoder", "patch_norm", "type"),), -# fast_llm_value=NormalizationType.rms_norm, -# ), -# ConstantImportParamConverter( -# fast_llm_names=(("vision_encoder", "transformer", "normalization", "type"),), -# fast_llm_value=NormalizationType.rms_norm, -# ), -# # Vision Transformer -# 
RenameParamConverter( -# fast_llm_names=(("vision_encoder", "transformer", "num_layers"),), -# export_names=( -# ( -# "vision_config", -# "num_hidden_layers", -# ), -# ), -# ), -# RenameParamConverter( -# fast_llm_names=(("vision_encoder", "transformer", "hidden_size"),), -# export_names=( -# ( -# "vision_config", -# "hidden_size", -# ), -# ), -# ), -# RenameParamConverter( -# fast_llm_names=(("vision_encoder", "transformer", "num_attention_heads"),), -# export_names=( -# ( -# "vision_config", -# "num_attention_heads", -# ), -# ), -# ), -# RenameParamConverter( -# fast_llm_names=(("vision_encoder", "transformer", "head_groups"),), -# export_names=( -# ( -# "vision_config", -# "num_key_value_heads", -# ), -# ), -# ), -# RenameParamConverter( -# fast_llm_names=(("vision_encoder", "transformer", "ffn_hidden_size"),), -# export_names=( -# ( -# "vision_config", -# "intermediate_size", -# ), -# ), -# ), -# MappedConfigParamConverter( -# fast_llm_names=(("vision_encoder", "transformer", "activation_type"),), -# export_names=( -# ( -# "vision_config", -# "hidden_act", -# ), -# ), -# fast_llm_value=ActivationType.from_hf_name, -# export_value=lambda activation_type: activation_type.hf_name, -# ), -# ConstantImportParamConverter( -# fast_llm_names=(("vision_encoder", "transformer", "gated"),), fast_llm_value=True -# ), -# MappedConfigParamConverter( -# fast_llm_names=(("vision_encoder", "adapter_activation_type"),), -# export_names=(("projector_hidden_act",),), -# fast_llm_value=ActivationType.from_hf_name, -# export_value=lambda activation_type: activation_type.hf_name, -# ), -# ConstantImportParamConverter( -# fast_llm_names=(("vision_encoder", "transformer", "add_linear_biases"),), fast_llm_value=False -# ), -# RenameParamConverter( -# fast_llm_names=(("vision_encoder", "transformer", "rotary", "theta"),), -# export_names=(("vision_config", "rope_theta"),), -# ), -# ] -# ) - -# def _create_vision_transformer_layer_converters( -# self, -# i: int, -# ignore_export: bool = False, -# hf_base_prefix: str = "", -# fast_llm_offset: int = 1, -# type: str | None = None, -# ) -> list[WeightConverter]: -# if type is not None: -# if type == "vision": -# transformer_config: TransformerConfig = self._model.config.base_model.vision_encoder.transformer -# else: -# transformer_config: TransformerConfig = self._model.config.base_model.transformer -# norm_bias: bool = self._model.config.base_model.transformer.normalization.type == NormalizationType.layer_norm -# converters = [] -# names_bias_cls = [ -# # Self-attn -# ( -# f"layers.{i+fast_llm_offset}.self_attn.query", -# f"vision_tower.transformer.layers.{i}.attention.q_proj", -# transformer_config.add_attn_qkv_bias, -# QueryWeightConverter, -# ), -# ( -# f"layers.{i+fast_llm_offset}.self_attn.key_value", -# ( -# f"vision_tower.transformer.layers.{i}.attention.k_proj", -# f"vision_tower.transformer.layers.{i}.attention.v_proj", -# ), -# transformer_config.add_attn_qkv_bias, -# KeyValueWeightConverter, -# ), -# ( -# f"layers.{i+fast_llm_offset}.self_attn.dense", -# f"vision_tower.transformer.layers.{i}.attention.o_proj", -# transformer_config.add_attn_dense_bias, -# WeightConverter, -# ), -# # Norm -# ( -# f"layers.{i+fast_llm_offset}.norm_1", -# f"vision_tower.transformer.layers.{i}.attention_norm", -# norm_bias, -# WeightConverter, -# ), -# ( -# f"layers.{i+fast_llm_offset}.norm_2", -# f"vision_tower.transformer.layers.{i}.ffn_norm", -# norm_bias, -# WeightConverter, -# ), -# ] -# for fast_llm_prefix, hf_prefix, use_bias, cls in names_bias_cls: -# converters += 
self._get_weight_and_bias_converters( -# fast_llm_prefix, -# () if ignore_export else hf_prefix, -# use_bias, -# cls=IgnoreExportWeightConverter if ignore_export else cls, -# ) - -# # MLP -# if ignore_export: -# converters += self._get_weight_and_bias_converters( -# f"layers.{i+fast_llm_offset}.mlp.layer_1", -# (), -# transformer_config.add_mlp_bias, -# cls=IgnoreExportWeightConverter, -# ) -# converters += self._get_weight_and_bias_converters( -# f"layers.{i+fast_llm_offset}.mlp.layer_2", -# (), -# transformer_config.add_mlp_bias, -# cls=IgnoreExportWeightConverter, -# ) -# converters += [IgnoreExportWeightConverter(f"layers.{i+fast_llm_offset}.mlp.router.weight", ())] -# else: -# converters += self._get_vision_transformer_mlp_converters( -# f"layers.{i+fast_llm_offset}", f"vision_tower.transformer.layers.{i}" -# ) -# return converters - -# def _get_vision_transformer_mlp_converters(self, fast_llm_prefix: str, hf_prefix: str) -> list[WeightConverter]: -# return [ -# SplitWeightConverter( -# f"{fast_llm_prefix}.mlp.layer_1.weight", -# (f"{hf_prefix}.feed_forward.gate_proj.weight", f"{hf_prefix}.feed_forward.up_proj.weight"), -# ), -# MLPLayer2Converter( -# f"{fast_llm_prefix}.mlp.layer_2.weight", -# f"{hf_prefix}.feed_forward.down_proj.weight", -# self._model.config.base_model, -# ), -# ] - -# def _create_vision_transformer_converters(self) -> list[WeightConverter]: -# num_layers = self._model.config.base_model.vision_encoder.transformer.num_layers -# vision_transformer_converters = [] -# for layer in range(num_layers): -# # TODO Soham: check if args are correct -# vision_transformer_converters.extend( -# self._create_vision_transformer_layer_converters( -# layer, -# ignore_export=False, -# hf_base_prefix="vision_tower.transformer.layers.", -# fast_llm_offset=1, -# type="vision", -# ) -# ) - -# return vision_transformer_converters - -# def _create_vision_encoder_weight_converters(self) -> list[WeightConverter]: -# patch_conv_converters = [WeightConverter("layers.0.weight", "vision_tower.patch_conv.weight")] -# if self._model.config.base_model.vision_encoder.conv_bias: -# patch_conv_converters.append(WeightConverter("layers.0.bias", "vision_tower.patch_conv.bias")) -# layernorm_converters = [ -# WeightConverter("layers.0.norm.weight", "vision_tower.ln_pre.weight"), -# ] -# if self._model.config.base_model.vision_encoder.patch_norm.type == NormalizationType.layer_norm: -# layernorm_converters.append(WeightConverter("layers.0.norm.bias", "vision_tower.ln_pre.bias")) - -# vision_transformer_converters = self._create_vision_transformer_converters() -# offset = self._model.config.base_model.vision_encoder.transformer.num_layers + 1 -# adapter_converters = [ -# WeightConverter(f"layers.{offset}.layer_1.weight", "multi_modal_projector.linear_1.weight"), -# WeightConverter(f"layers.{offset}.layer_1.bias", "multi_modal_projector.linear_1.bias"), -# # TODO Soham: add bias based on config -# WeightConverter(f"layers.{offset}.layer_2.weight", "multi_modal_projector.linear_2.weight"), -# WeightConverter(f"layers.{offset}.layer_2.bias", "multi_modal_projector.linear_2.bias"), -# ] - -# return patch_conv_converters + layernorm_converters + vision_transformer_converters + adapter_converters - -# def _create_weight_converters(self) -> list[WeightConverter]: -# vision_encoder_converter = self._create_vision_encoder_weight_converters() -# offset = self._model.config.base_model.vision_encoder.transformer.num_layers + 3 -# # Embeddings -# lm_converters = [ -# WeightConverter(f"layers.{offset - 
1}.word_embeddings_weight", f"language_model.model.embed_tokens.weight") -# ] -# for i in range(self._model.config.base_model.transformer.num_layers): -# lm_converters += self._create_transformer_layer_converters( -# fast_llm_layer_name=f"layers.{i + offset}", hf_layer_name=f"language_model.model.layers.{i}" -# ) -# lm_converters += self._create_lm_head_converters(hf_base_prefix="language_model.", fast_llm_offset=offset) -# return vision_encoder_converter + lm_converters - - class MixtralHuggingfaceCheckpointHandler(CommonLlamaHuggingfaceCheckpointHandler): format: typing.ClassVar[type[CheckpointFormat]] = MixtralGPTHuggingfaceCheckpointFormat From 3b955b1600ba09c5b7844113b6fc55ee3916f261 Mon Sep 17 00:00:00 2001 From: root Date: Wed, 21 May 2025 22:23:07 +0000 Subject: [PATCH 49/82] fixes for pixtral --- fast_llm/models/gpt/conversion.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/fast_llm/models/gpt/conversion.py b/fast_llm/models/gpt/conversion.py index 35652547..b7f9f773 100644 --- a/fast_llm/models/gpt/conversion.py +++ b/fast_llm/models/gpt/conversion.py @@ -572,6 +572,12 @@ class PixtralHuggingfaceCheckpointHandler(TransformerWeightConverterMixin, Huggi @classmethod def _create_config_converters(cls) -> list[ParamConverter]: return super()._create_config_converters() + [ + ConstantImportParamConverter( + fast_llm_names=(("patch_norm", "type"),), fast_llm_value=NormalizationType.rms_norm + ), + ConstantImportParamConverter( + fast_llm_names=(("transformer", "normalization", "type"),), fast_llm_value=NormalizationType.rms_norm + ), ConstantExportParamConverter(export_names=(("architectures",),), export_value=["PixtralVisionModel"]), ConstantImportParamConverter(fast_llm_names=(("type",),), fast_llm_value=VisionEncoderType.pixtral), ConstantImportParamConverter(fast_llm_names=(("transformer", "causal"),), fast_llm_value=False), @@ -646,6 +652,8 @@ def _create_config_converters(cls) -> list[ParamConverter]: export_names=(("rope_theta",),), ), RenameParamConverter(fast_llm_names=(("patch_size",),), export_names=(("patch_size",),)), + ConstantImportParamConverter(fast_llm_names=(("transformer", "gated"),), fast_llm_value=True), + ConstantImportParamConverter(fast_llm_names=(("transformer", "add_linear_biases"),), fast_llm_value=False), ] def _get_transformer_mlp_converters(self, fast_llm_prefix: str, hf_prefix: str) -> list[WeightConverter]: @@ -803,6 +811,10 @@ def _create_config_converters(cls) -> list[ParamConverter]: fast_llm_value=ActivationType.from_hf_name, export_value=lambda activation_type: activation_type.hf_name, ), + RenameParamConverter( + fast_llm_names=(("vision_encoder", "adapter_size"),), + export_names=(("projector_intermediate_size",),), + ), ] @classmethod From 49daf581600175c884265c00df4aaf04a9dc0f74 Mon Sep 17 00:00:00 2001 From: root Date: Wed, 21 May 2025 23:27:52 +0000 Subject: [PATCH 50/82] model fixes --- fast_llm/layers/multi_modal/embedding.py | 2 -- fast_llm/layers/transformer/config.py | 11 +---------- fast_llm/layers/vision_encoder/encoder.py | 4 ++-- fast_llm/models/gpt/model.py | 5 +++-- 4 files changed, 6 insertions(+), 16 deletions(-) diff --git a/fast_llm/layers/multi_modal/embedding.py b/fast_llm/layers/multi_modal/embedding.py index 9a035d8f..c67f82b4 100644 --- a/fast_llm/layers/multi_modal/embedding.py +++ b/fast_llm/layers/multi_modal/embedding.py @@ -55,7 +55,6 @@ def _forward( embeddings = reduce_forward(embeddings, group) if self._use_absolute_position_embeddings: embeddings = embeddings + 
torch.nn.functional.embedding(position_ids, self.position_embeddings_weight) - # TODO Soham: avoid cloning? embeddings = embeddings.clone() input_ = gather(input_, group, dim=0) for sample_idx, (positions, sizes) in enumerate(zip(image_positions, image_sizes)): @@ -82,7 +81,6 @@ def _forward( # for positions in image_positions: # if positions > self._distributed_config.tensor_rank embeddings = torch.embedding(self.word_embeddings_weight, tokens) - # TODO Soham: avoid cloning? embeddings = embeddings.clone() for sample_idx, (positions, sizes) in enumerate(zip(image_positions, image_sizes)): image_embedding_offset = 0 diff --git a/fast_llm/layers/transformer/config.py b/fast_llm/layers/transformer/config.py index a634bc3c..49babb06 100644 --- a/fast_llm/layers/transformer/config.py +++ b/fast_llm/layers/transformer/config.py @@ -99,7 +99,7 @@ def __init_subclass__(cls, prefix="", **kwargs): super().__init_subclass__(**kwargs) cls._prefix = prefix for attr, value in BaseTransformerKwargs._kwargs_attributes.items(): - setattr(cls, value, f"{cls._prefix}_{value}") + setattr(cls, value, f"{cls._prefix}_{value}" if cls._prefix else value) class TransformerKwargs(BaseTransformerKwargs, prefix=""): @@ -824,15 +824,6 @@ def _transformer_dim_names(self) -> TransformerDimNames: return TransformerDimNames -@config_class() -class VisionRotaryConfig(RotaryConfig): - type: RotaryEmbeddingType = Field( - default=RotaryEmbeddingType.pixtral, - desc="The type of rotary embedding to use. Choices: none, default, llama3, yarn, pixtral.", - hint=FieldHint.feature, - ) - - @config_class() class VisionTransformerConfig(TransformerConfig): """ diff --git a/fast_llm/layers/vision_encoder/encoder.py b/fast_llm/layers/vision_encoder/encoder.py index 1df7f889..20749af4 100644 --- a/fast_llm/layers/vision_encoder/encoder.py +++ b/fast_llm/layers/vision_encoder/encoder.py @@ -5,7 +5,7 @@ from fast_llm.core.ops import split from fast_llm.engine.base_model.base_model import Layer from fast_llm.engine.config_utils.tensor_space import TensorSpace -from fast_llm.layers.transformer.config import TransformerKwargs +from fast_llm.layers.transformer.config import TransformerKwargs, VisionTransformerKwargs from fast_llm.layers.vision_encoder.config import VisionEncoderConfig, VisionEncoderDimNames, VisionEncoderKwargs from fast_llm.tensor import ParameterMeta, TensorMeta, init_normal_ @@ -69,7 +69,7 @@ def forward( losses: dict[str, typing.Any] | None = None, metrics: dict | None = None, ) -> torch.Tensor: - hidden_dims = kwargs[VisionEncoderKwargs.hidden_dims] + hidden_dims = kwargs[VisionTransformerKwargs.hidden_dims] if isinstance(input_, TensorMeta): return TensorMeta.from_dims(hidden_dims, tensor_name="patch conv output", dtype=input_.dtype) micro_batch_size = kwargs[TransformerKwargs.micro_batch_size] diff --git a/fast_llm/models/gpt/model.py b/fast_llm/models/gpt/model.py index 72ff1b88..cbce66f2 100644 --- a/fast_llm/models/gpt/model.py +++ b/fast_llm/models/gpt/model.py @@ -21,6 +21,7 @@ TransformerKwargs, TransformerLossNames, VisionTransformerDimNames, + VisionTransformerKwargs, ) from fast_llm.layers.transformer.preprocessing import ( BackupAttentionPreprocessor, @@ -30,7 +31,7 @@ from fast_llm.layers.transformer.transformer import TransformerLayer from fast_llm.layers.transformer.vision_transformer import VisionTransformerLayer from fast_llm.layers.vision_encoder.adapter import VisionAdapter -from fast_llm.layers.vision_encoder.config import VisionEncoderKwargs +from fast_llm.layers.vision_encoder.config import 
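The one-line change to __init_subclass__ above is easy to miss: with the empty prefix used by the text transformer, the old code produced attribute values like "_hidden_dims". A minimal, self-contained reproduction of the pattern follows; the class and attribute names are illustrative, not the actual Fast-LLM kwargs classes.

class BaseKwargs:
    _kwargs_attributes = {"hidden_dims": "hidden_dims", "sequence_length": "sequence_length"}

    def __init_subclass__(cls, prefix="", **kwargs):
        super().__init_subclass__(**kwargs)
        cls._prefix = prefix
        for attr, value in BaseKwargs._kwargs_attributes.items():
            # Without the `if cls._prefix` guard, an empty prefix yields "_hidden_dims".
            setattr(cls, attr, f"{cls._prefix}_{value}" if cls._prefix else value)

class TextKwargs(BaseKwargs, prefix=""):
    pass

class VisionKwargs(BaseKwargs, prefix="image_encoder"):
    pass

if __name__ == "__main__":
    assert TextKwargs.hidden_dims == "hidden_dims"
    assert VisionKwargs.hidden_dims == "image_encoder_hidden_dims"
    print("prefixing works for both the empty and the vision-encoder prefix")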
VisionEncoderDimNames, VisionEncoderKwargs from fast_llm.layers.vision_encoder.encoder import PatchConv from fast_llm.layers.vision_encoder.preprocessing import VisionPreprocessor from fast_llm.models.gpt.config import GPTBaseModelConfig, GPTBatchConfig, GPTModelConfig @@ -244,7 +245,7 @@ def preprocess_meta( ) vision_kwargs.update( { - VisionEncoderKwargs.hidden_dims: vision_hidden_dims, + VisionTransformerKwargs.hidden_dims: vision_hidden_dims, } ) From b5ed9f4f6fdd6205225f730a136edb2f211c9f95 Mon Sep 17 00:00:00 2001 From: root Date: Thu, 22 May 2025 19:07:00 +0000 Subject: [PATCH 51/82] more cleanup --- fast_llm/data/data/gpt/data.py | 4 +- fast_llm/data/preparator/gpt_memmap/config.py | 8 +- .../data/preparator/gpt_memmap/prepare.py | 13 +- fast_llm/data/tokenizer.py | 2 +- fast_llm/engine/schedule/config.py | 5 - fast_llm/functional/config.py | 8 +- fast_llm/layers/multi_modal/embedding.py | 1 - fast_llm/layers/transformer/attention.py | 18 +- fast_llm/layers/transformer/config.py | 23 --- fast_llm/layers/transformer/transformer.py | 8 - .../layers/transformer/vision_transformer.py | 8 - fast_llm/layers/vision_encoder/adapter.py | 1 - fast_llm/layers/vision_encoder/config.py | 21 ++- .../{encoder.py => patch_conv.py} | 6 +- .../layers/vision_encoder/preprocessing.py | 5 - fast_llm/models/gpt/conversion.py | 161 +++++++++--------- fast_llm/models/gpt/model.py | 11 +- fast_llm/models/gpt/trainer.py | 9 +- fast_llm/tools/cli.py | 1 + fast_llm/utils.py | 7 - 20 files changed, 129 insertions(+), 191 deletions(-) rename fast_llm/layers/vision_encoder/{encoder.py => patch_conv.py} (95%) diff --git a/fast_llm/data/data/gpt/data.py b/fast_llm/data/data/gpt/data.py index 4fcd42ae..31a19e14 100644 --- a/fast_llm/data/data/gpt/data.py +++ b/fast_llm/data/data/gpt/data.py @@ -51,13 +51,13 @@ def gpt_data_collate_fn(batch: list[GPTSample], sampling_parameters: GPTSampling batch_images.append([torch.from_numpy(image) for image in sample.images]) has_images = True else: - batch_images.append(None) + batch_images.append([]) batch_image_positions = [] for sample in batch: if sample.image_positions is not None: batch_image_positions.append(torch.from_numpy(sample.image_positions)) else: - batch_image_positions.append(None) + batch_image_positions.append([]) return GPTBatch( token_ids=torch.from_numpy(stacked_ids), loss_masking_spans=stacked_spans, diff --git a/fast_llm/data/preparator/gpt_memmap/config.py b/fast_llm/data/preparator/gpt_memmap/config.py index 53f8e468..2e924380 100644 --- a/fast_llm/data/preparator/gpt_memmap/config.py +++ b/fast_llm/data/preparator/gpt_memmap/config.py @@ -151,12 +151,6 @@ class GPTMemmapDatasetPreparatorConfig(DatasetPreparatorConfig): hint=FieldHint.optional, valid=check_field(Assert.geq, 1), ) - tokenize_batch_size: int = Field( - default=1000, - desc="Batch size for tokenization.", - hint=FieldHint.optional, - valid=check_field(Assert.geq, 1), - ) saving_workers: int = Field( default=1, desc="Number of processes for saving the data.", @@ -170,7 +164,7 @@ class GPTMemmapDatasetPreparatorConfig(DatasetPreparatorConfig): ) tokenizer: TokenizerConfig = Field( default_factory=TokenizerConfig, - desc="Tokenizer configuration.", + desc="Configuration for the tokenizer.", hint=FieldHint.feature, ) image_patch_size: int = Field( diff --git a/fast_llm/data/preparator/gpt_memmap/prepare.py b/fast_llm/data/preparator/gpt_memmap/prepare.py index c5a1b339..fa46ee92 100644 --- a/fast_llm/data/preparator/gpt_memmap/prepare.py +++ b/fast_llm/data/preparator/gpt_memmap/prepare.py @@ 
-138,20 +138,9 @@ def _document_generator(): if self._config.dataset.loss_masking_spans else None ), - # [np.array(Image.open(pathlib.Path(self._config.dataset.path) / path)) for path in item["image_paths"]] if self._config.dataset.image_paths else None, - # [np.array(im) for im in item["images"]] if self._config.dataset.images else None, item["images"] if self._config.dataset.images else None, item["image_positions"] if self._config.dataset.image_positions else None, ) - # if "token_spans" in shard_dataset.column_names and self._config.dataset.loss_masking_spans is not None: - # for item in tqdm.tqdm(shard_dataset, desc=f"Saving shard {shard_idx}", unit="docs"): - # yield GPTSample( - # np.array(item["input_ids"], dtype=self._data_type.numpy), - # np.array(item["token_spans"], dtype=np.int32).reshape(-1, 2), - # ) - # else: - # for item in tqdm.tqdm(shard_dataset, desc=f"Saving shard {shard_idx}", unit="docs"): - # yield GPTSample(np.array(item["input_ids"], dtype=self._data_type.numpy)) GPTMemmapDataset.write_dataset(prefix=shard_output_path, documents=_document_generator()) @@ -279,7 +268,7 @@ def run(self) -> None: if self._config.dataset.field not in dataset.column_names: raise ValueError(f"Dataset does not have field '{self._config.dataset.field}'.") tokenize_fn = self._tokenize_batch - # Avoid decoding bytes to images unless asked + # decoding bytes to images is slow and should be done only when needed if self._config.dataset.images is not None: dataset = dataset.cast_column("images", datasets.Sequence(datasets.Image(decode=False))) diff --git a/fast_llm/data/tokenizer.py b/fast_llm/data/tokenizer.py index 0acb65e4..1cbc1ec5 100644 --- a/fast_llm/data/tokenizer.py +++ b/fast_llm/data/tokenizer.py @@ -44,7 +44,7 @@ def _tokenize(self, text: str, begin=True, end=True) -> list[int]: def tokenize(self, text: str, char_spans=None, image_positions=None) -> tuple[list[int], list[tuple[int, int]]]: """ - Tokenize the input text and return the tokenized input_ids along with token spans. + Tokenize the input text and return the tokenized input_ids and if provided, token spans and image positions. 
""" if not image_positions: image_positions = [] diff --git a/fast_llm/engine/schedule/config.py b/fast_llm/engine/schedule/config.py index 48daf0e6..204abdf1 100644 --- a/fast_llm/engine/schedule/config.py +++ b/fast_llm/engine/schedule/config.py @@ -50,11 +50,6 @@ class BatchConfig(Config): hint=FieldHint.setup, ) # Image inputs - patch_size: int | None = Field( - default=None, - desc="Patch size for each image token", - hint=FieldHint.optional, - ) image_size: int | None = Field( default=None, desc="Maximum image height and width", diff --git a/fast_llm/functional/config.py b/fast_llm/functional/config.py index 233ea339..480fa067 100644 --- a/fast_llm/functional/config.py +++ b/fast_llm/functional/config.py @@ -40,6 +40,7 @@ class ActivationType(enum.StrEnum): """ gelu = "gelu" + gelu_pytorch_tanh = "gelu_pytorch_tanh" silu = "silu" relu = "relu" squared_relu = "squared_relu" @@ -67,7 +68,8 @@ def _set_activation_fn_map() -> None: global _ACTIVATION_FN_MAP _ACTIVATION_FN_MAP = { - ActivationType.gelu: lambda x: torch.nn.functional.gelu(x, approximate="tanh"), + ActivationType.gelu: torch.nn.functional.gelu, + ActivationType.gelu_pytorch_tanh: lambda x: torch.nn.functional.gelu(x, approximate="tanh"), ActivationType.silu: torch.nn.functional.silu, ActivationType.relu: torch.nn.functional.relu, ActivationType.squared_relu: lambda x: torch.pow(torch.nn.functional.relu(x), 2), @@ -78,14 +80,14 @@ def _set_activation_fn_map() -> None: _ACTIVATION_FN_MAP: dict[ActivationType, typing.Callable[["torch.Tensor"], "torch.Tensor"]] = {} _ACTIVATION_HF_NAMES = { - ActivationType.gelu: "gelu_pytorch_tanh", + ActivationType.gelu: "gelu", + ActivationType.gelu_pytorch_tanh: "gelu_pytorch_tanh", ActivationType.silu: "silu", ActivationType.relu: "relu", ActivationType.squared_relu: "relu2", ActivationType.identity: "identity", } _ACTIVATION_HF_NAMES_INV = {value: key for key, value in _ACTIVATION_HF_NAMES.items()} -_ACTIVATION_HF_NAMES_INV["gelu"] = ActivationType.gelu MAX_DROPLESS_BLOCK_SIZE_ROW = 128 diff --git a/fast_llm/layers/multi_modal/embedding.py b/fast_llm/layers/multi_modal/embedding.py index c67f82b4..8c541e98 100644 --- a/fast_llm/layers/multi_modal/embedding.py +++ b/fast_llm/layers/multi_modal/embedding.py @@ -114,7 +114,6 @@ def forward( tensor_name="Embedding output", dtype=self._residual_dtype, ) - # image_embeddings = kwargs.pop(VisionEncoderKwargs.patch_embeddings) position_ids = kwargs.get(LanguageModelKwargs.position_ids) image_sizes = kwargs.get(VisionEncoderKwargs.image_sizes) image_positions = kwargs.get(VisionEncoderKwargs.image_positions) diff --git a/fast_llm/layers/transformer/attention.py b/fast_llm/layers/transformer/attention.py index 3180b6cb..e88f64a3 100644 --- a/fast_llm/layers/transformer/attention.py +++ b/fast_llm/layers/transformer/attention.py @@ -191,7 +191,7 @@ def _get_meta( ) @property - def query_dims(self): + def _query_dims(self): return ( self._transformer_dim_names.batch, self._transformer_dim_names.sequence_q, @@ -200,7 +200,7 @@ def query_dims(self): ) @property - def kv_dims(self): + def _kv_dims(self): return ( self._transformer_dim_names.batch, self._transformer_dim_names.sequence_q, @@ -209,7 +209,7 @@ def kv_dims(self): ) @property - def context_dims(self): + def _context_dims(self): return ( self._transformer_dim_names.batch, self._transformer_dim_names.sequence_q, @@ -346,11 +346,11 @@ def forward(self, input_: torch.Tensor, kwargs: dict[str, typing.Any]) -> tuple[ if self._config.rotary.enabled: if self._debug_transformer: - 
self._debug_log(query, "query_rotary_input", self.query_dims, kwargs) + self._debug_log(query, "query_rotary_input", self._query_dims, kwargs) self._debug_log( key, "key_rotary_input", - self.kv_dims, + self._kv_dims, kwargs, ) rotary_fn = triton_rotary_autograd_ if self._config.rotary.triton else apply_rotary_embeddings @@ -402,20 +402,20 @@ def forward(self, input_: torch.Tensor, kwargs: dict[str, typing.Any]) -> tuple[ ) if self._debug_transformer: - self._debug_log(query, "query", self.query_dims, kwargs) + self._debug_log(query, "query", self._query_dims, kwargs) self._debug_log( key, "key", - self.kv_dims, + self._kv_dims, kwargs, ) self._debug_log( value, "value", - self.kv_dims, + self._kv_dims, kwargs, ) - self._debug_log(input_, "context", self.context_dims, kwargs) + self._debug_log(input_, "context", self._context_dims, kwargs) if sequence_first: # TODO: Optimize (is contiguous avoidable? Transpose dense output?) diff --git a/fast_llm/layers/transformer/config.py b/fast_llm/layers/transformer/config.py index 49babb06..b8d15367 100644 --- a/fast_llm/layers/transformer/config.py +++ b/fast_llm/layers/transformer/config.py @@ -110,29 +110,6 @@ class VisionTransformerKwargs(BaseTransformerKwargs, prefix="image_encoder"): patch_position_ids = "patch_position_ids" -# class TransformerKwargs: -# rotary_freq_q = "rotary_freq_q" -# rotary_freq_k = "rotary_freq_k" -# attention_mask = "attention_mask" -# attention_mask_value = "attention_mask_value" -# sequence_lengths = "sequence_lengths" -# cu_seqlens_q = "cu_seqlens_q" -# cu_seqlens_k = "cu_seqlens_k" -# max_seqlen_q = "max_seqlen_q" -# max_seqlen_k = "max_seqlen_k" -# # TODO: Review these -# presents = "presents" -# past_key_values = "past_key_values" -# sequence_first = "sequence_first" -# hidden_dims = "hidden_dims" -# sequence_q_dim = "sequence_q_dim" -# sequence_k_dim = "sequence_k_dim" -# sequence_length = "sequence_length" -# micro_batch_size = "micro_batch_size" -# # TODO: Move -# grad_output = "grad_output" - - class TransformerLossNames: load_balancing_loss = "load_balancing_loss" router_z_loss = "router_z_loss" diff --git a/fast_llm/layers/transformer/transformer.py b/fast_llm/layers/transformer/transformer.py index 2c79883b..392ebb88 100644 --- a/fast_llm/layers/transformer/transformer.py +++ b/fast_llm/layers/transformer/transformer.py @@ -149,11 +149,3 @@ def __init__( def _create_mixer(self): self.self_attn = Attention(self._config, self._tensor_space, self._layer_index) - - # @property - # def _transformer_kwargs(self) -> TransformerKwargs: - # return TransformerKwargs - - # @property - # def _transformer_dim_names(self) -> TransformerDimNames: - # return TransformerDimNames diff --git a/fast_llm/layers/transformer/vision_transformer.py b/fast_llm/layers/transformer/vision_transformer.py index c2cfe9f2..7c1be0d1 100644 --- a/fast_llm/layers/transformer/vision_transformer.py +++ b/fast_llm/layers/transformer/vision_transformer.py @@ -9,14 +9,6 @@ class VisionTransformerLayer(TransformerLayer): _name: str = "Vision transformer layer" - # @property - # def _transformer_kwargs(self) -> VisionTransformerKwargs: - # return VisionTransformerKwargs - - # @property - # def _transformer_dim_names(self) -> VisionTransformerDimNames: - # return VisionTransformerDimNames - def _get_meta(self, tensor: torch.Tensor, name: str, kwargs: dict): dims = kwargs[VisionTransformerKwargs.hidden_dims] if self._return_input: diff --git a/fast_llm/layers/vision_encoder/adapter.py b/fast_llm/layers/vision_encoder/adapter.py index 
bf5f3f1a..41ea065d 100644 --- a/fast_llm/layers/vision_encoder/adapter.py +++ b/fast_llm/layers/vision_encoder/adapter.py @@ -20,7 +20,6 @@ def __init__(self, config: VisionEncoderConfig, tensor_space: TensorSpace): super().__init__() input_dim = tensor_space.get_tensor_dim(VisionEncoderDimNames.out_channels) self._activation_type = config.adapter_activation_type - # TODO Soham: Make them OutputParallelLinear instead? How would this work with parallelism? self.layer_1 = Linear( input_dim, tensor_space.get_tensor_dim(VisionEncoderDimNames.adapter_size), diff --git a/fast_llm/layers/vision_encoder/config.py b/fast_llm/layers/vision_encoder/config.py index 6932c8fc..f788b514 100644 --- a/fast_llm/layers/vision_encoder/config.py +++ b/fast_llm/layers/vision_encoder/config.py @@ -1,11 +1,12 @@ import enum -from fast_llm.config import Config, Field, FieldHint, config_class +from fast_llm.config import Config, Field, FieldHint, check_field, config_class, skip_valid_if_none from fast_llm.engine.base_model.config import BaseModelConfig from fast_llm.engine.config_utils.tensor_space import TensorDim, TensorSpace from fast_llm.functional.config import ActivationType from fast_llm.layers.common.config import NormalizationConfig from fast_llm.layers.transformer.config import VisionTransformerConfig +from fast_llm.utils import Assert class VisionEncoderDimNames: @@ -129,18 +130,24 @@ class VisionEncoderConfig(BaseModelConfig): desc="Configuration for the normalization layers applied to the image patches.", hint=FieldHint.optional, ) + adapter_lr_scale: float | None = Field( + default=None, + desc="Custom learning rate scale for the adapter weights.", + hint=FieldHint.feature, + valid=skip_valid_if_none(check_field(Assert.geq, 0)), + ) + conv_lr_scale: float | None = Field( + default=None, + desc="Custom learning rate scale for the convolutional layer weights.", + hint=FieldHint.feature, + valid=skip_valid_if_none(check_field(Assert.geq, 0)), + ) def setup_tensor_space(self, tensor_space: TensorSpace): tensor_space.add_tensor_dim(TensorDim(VisionEncoderDimNames.out_channels, self.transformer.hidden_size)) tensor_space.add_tensor_dim(TensorDim(VisionEncoderDimNames.adapter_size, self.adapter_size)) tensor_space.add_tensor_dim(TensorDim(VisionEncoderDimNames.patch_size, self.patch_size)) tensor_space.add_tensor_dim(TensorDim(VisionEncoderDimNames.in_channels, 3)) - # TODO Soham: add a check for presence of kv channels parameter (head_dim) - tensor_space.add_tensor_dim( - TensorDim( - VisionEncoderDimNames.kv_channels, self.transformer.hidden_size // self.transformer.num_attention_heads - ) - ) self.transformer.setup_tensor_space(tensor_space) @property diff --git a/fast_llm/layers/vision_encoder/encoder.py b/fast_llm/layers/vision_encoder/patch_conv.py similarity index 95% rename from fast_llm/layers/vision_encoder/encoder.py rename to fast_llm/layers/vision_encoder/patch_conv.py index 20749af4..68f22200 100644 --- a/fast_llm/layers/vision_encoder/encoder.py +++ b/fast_llm/layers/vision_encoder/patch_conv.py @@ -43,6 +43,7 @@ def __init__(self, config: VisionEncoderConfig, tensor_space: TensorSpace): self._tensor_space = tensor_space self._distributed_config = tensor_space.distributed_config self._sequence_parallel = self._distributed_config.sequence_tensor_parallel + self._lr_scale = config.adapter_lr_scale # TODO Soham: lr_scale self.weight = ParameterMeta.from_dims( ( @@ -52,10 +53,13 @@ def __init__(self, config: VisionEncoderConfig, tensor_space: TensorSpace): 
self._tensor_space.get_tensor_dim(VisionEncoderDimNames.patch_size), ), init_method=init_normal_(), + lr_scale=self._lr_scale, ) if config.conv_bias: self.bias = ParameterMeta.from_dims( - (self._tensor_space.get_tensor_dim(VisionEncoderDimNames.out_channels),) + (self._tensor_space.get_tensor_dim(VisionEncoderDimNames.out_channels),), + init_method=init_normal_(), + lr_sclae=self._lr_scale, ) else: self.bias = None diff --git a/fast_llm/layers/vision_encoder/preprocessing.py b/fast_llm/layers/vision_encoder/preprocessing.py index 5009123f..d85442a3 100644 --- a/fast_llm/layers/vision_encoder/preprocessing.py +++ b/fast_llm/layers/vision_encoder/preprocessing.py @@ -103,7 +103,6 @@ def __init__(self, config: VisionEncoderConfig, tensor_space: TensorSpace): self._distributed_config = self._tensor_space.distributed_config def preprocess_meta(self, kwargs: dict[str, typing.Any]) -> None: - # kwargs[VisionEncoderDimNames] kwargs[VisionEncoderKwargs.image_patches_meta] = TensorMeta.from_dims( ( TensorDim( @@ -141,16 +140,12 @@ def preprocess(self, tokens, kwargs: dict[str, typing.Any]) -> None: ] for imgs in images ] - # position_ids = position_ids_in_meshgrid(image_sizes, im_height, patch_size) patches = [] patch_position_ids = [] cu_seqlens = [0] max_seqlen = -1 kwargs.get(TransformerKwargs.sequence_first) for imgs, sizes in zip(images, image_sizes): - # sum( - # get_num_patches(*size, patch_size) for size in sizes - # ) seq_patches = [] sample_cu_seqlen = 0 for image, size in zip(imgs, sizes): diff --git a/fast_llm/models/gpt/conversion.py b/fast_llm/models/gpt/conversion.py index b7f9f773..95bbebde 100644 --- a/fast_llm/models/gpt/conversion.py +++ b/fast_llm/models/gpt/conversion.py @@ -115,8 +115,7 @@ def import_weight( return (merged_weight.t().contiguous(),) -class TransformerWeightConverterMixin: - +class WeightAndBiasConverterMixin: def _get_weight_and_bias_converters( self, fast_llm_prefix: str | tuple[str, ...], @@ -145,6 +144,83 @@ def _get_weight_and_bias_converters( ) return converters + +class CommonHuggingfaceCheckpointHandler(WeightAndBiasConverterMixin, HuggingfaceStateDictCheckpointHandler): + _model: GPTModel + _model_class: typing.ClassVar[FastLLMModelConfig] = GPTModelConfig + """ + Common converter for llama-based huggingface models (llama, starcoder2, mistral, mixtral) + """ + + @classmethod + def _create_config_converters(cls) -> list[ParamConverter]: + return super()._create_config_converters() + [ + ConstantImportParamConverter(fast_llm_names=(("use_position_embeddings",),), fast_llm_value=False), + RenameParamConverter( + fast_llm_names=(("transformer", "rotary", "theta"),), export_names=(("rope_theta",),) + ), + MappedConfigParamConverter( + fast_llm_names=(("transformer", "activation_type"),), + export_names=(("hidden_act",),), + fast_llm_value=ActivationType.from_hf_name, + export_value=lambda activation_type: activation_type.hf_name, + ), + RenameParamConverter( + fast_llm_names=(("transformer", "num_layers"),), + export_names=(("num_hidden_layers",),), + ), + RenameParamConverter( + fast_llm_names=(("transformer", "hidden_size"),), + export_names=(("hidden_size",),), + ), + RenameParamConverter( + fast_llm_names=(("transformer", "num_attention_heads"),), + export_names=(("num_attention_heads",),), + ), + RenameParamConverter( + fast_llm_names=(("transformer", "head_groups"),), + export_names=(("num_key_value_heads",),), + ), + RenameParamConverter( + fast_llm_names=(("transformer", "ffn_hidden_size"),), + export_names=(("intermediate_size",),), + ), + 
RenameParamConverter( + fast_llm_names=(("vocab_size",),), + export_names=(("vocab_size",),), + ), + RenameParamConverter( + fast_llm_names=(("tie_word_embeddings",),), + export_names=(("tie_word_embeddings",),), + ), + ] + + @abc.abstractmethod + def _get_mlp_converters(self, fast_llm_prefix: str, hf_prefix: str) -> list[WeightConverter]: + pass + + def _create_weight_converters( + self, + hf_base_prefix: str = "", + offset: int = 0, + ) -> list[WeightConverter]: + converters = [] + num_layers = self._model.config.base_model.transformer.num_layers + + # Embeddings + converters.append( + WeightConverter(f"layers.{offset}.word_embeddings_weight", f"{hf_base_prefix}model.embed_tokens.weight") + ) + + converters += self._create_lm_head_converters(hf_base_prefix, offset=offset) + + for i in range(num_layers): + converters += self._create_transformer_layer_converters( + f"layers.{i+offset+1}", f"{hf_base_prefix}model.layers.{i}" + ) + + return converters + def _create_lm_head_converters(self, hf_base_prefix: str = "", offset: int = 0) -> list[WeightConverter]: num_layers = self._model.config.base_model.transformer.num_layers prediction_heads = self._model.config.base_model.prediction_heads @@ -250,83 +326,6 @@ def _create_transformer_layer_converters( return converters -class CommonHuggingfaceCheckpointHandler(TransformerWeightConverterMixin, HuggingfaceStateDictCheckpointHandler): - _model: GPTModel - _model_class: typing.ClassVar[FastLLMModelConfig] = GPTModelConfig - """ - Common converter for llama-based huggingface models (llama, starcoder2, mistral, mixtral) - """ - - @classmethod - def _create_config_converters(cls) -> list[ParamConverter]: - return super()._create_config_converters() + [ - ConstantImportParamConverter(fast_llm_names=(("use_position_embeddings",),), fast_llm_value=False), - RenameParamConverter( - fast_llm_names=(("transformer", "rotary", "theta"),), export_names=(("rope_theta",),) - ), - MappedConfigParamConverter( - fast_llm_names=(("transformer", "activation_type"),), - export_names=(("hidden_act",),), - fast_llm_value=ActivationType.from_hf_name, - export_value=lambda activation_type: activation_type.hf_name, - ), - RenameParamConverter( - fast_llm_names=(("transformer", "num_layers"),), - export_names=(("num_hidden_layers",),), - ), - RenameParamConverter( - fast_llm_names=(("transformer", "hidden_size"),), - export_names=(("hidden_size",),), - ), - RenameParamConverter( - fast_llm_names=(("transformer", "num_attention_heads"),), - export_names=(("num_attention_heads",),), - ), - RenameParamConverter( - fast_llm_names=(("transformer", "head_groups"),), - export_names=(("num_key_value_heads",),), - ), - RenameParamConverter( - fast_llm_names=(("transformer", "ffn_hidden_size"),), - export_names=(("intermediate_size",),), - ), - RenameParamConverter( - fast_llm_names=(("vocab_size",),), - export_names=(("vocab_size",),), - ), - RenameParamConverter( - fast_llm_names=(("tie_word_embeddings",),), - export_names=(("tie_word_embeddings",),), - ), - ] - - @abc.abstractmethod - def _get_mlp_converters(self, fast_llm_prefix: str, hf_prefix: str) -> list[WeightConverter]: - pass - - def _create_weight_converters( - self, - hf_base_prefix: str = "", - offset: int = 0, - ) -> list[WeightConverter]: - converters = [] - num_layers = self._model.config.base_model.transformer.num_layers - - # Embeddings - converters.append( - WeightConverter(f"layers.{offset}.word_embeddings_weight", f"{hf_base_prefix}model.embed_tokens.weight") - ) - - converters += 
self._create_lm_head_converters(hf_base_prefix, offset=offset) - - for i in range(num_layers): - converters += self._create_transformer_layer_converters( - f"layers.{i+offset+1}", f"{hf_base_prefix}model.layers.{i}" - ) - - return converters - - class Starcoder2HuggingfaceCheckpointHandler(CommonHuggingfaceCheckpointHandler): format: typing.ClassVar[type[CheckpointFormat]] = Starcoder2GPTHuggingfaceCheckpointFormat @@ -565,7 +564,7 @@ def _get_mlp_converters(self, fast_llm_prefix: str, hf_prefix: str) -> list[Weig ] -class PixtralHuggingfaceCheckpointHandler(TransformerWeightConverterMixin, HuggingfaceStateDictCheckpointHandler): +class PixtralHuggingfaceCheckpointHandler(WeightAndBiasConverterMixin, HuggingfaceStateDictCheckpointHandler): format: typing.ClassVar[type[CheckpointFormat]] = PixtralGPTHuggingfaceCheckpointFormat _model_class: typing.ClassVar[FastLLMModelConfig] = GPTModelConfig @@ -770,7 +769,7 @@ def num_layers(self) -> int: return self._model.config.base_model.vision_encoder.transformer.num_layers + 2 -class LlavaHuggingfaceCheckpointHandler(TransformerWeightConverterMixin, HuggingfaceStateDictCheckpointHandler): +class LlavaHuggingfaceCheckpointHandler(HuggingfaceStateDictCheckpointHandler): format: typing.ClassVar[type[CheckpointFormat]] = LlavaGPTHuggingfaceCheckpointFormat _model_class: typing.ClassVar[FastLLMModelConfig] = GPTModelConfig diff --git a/fast_llm/models/gpt/model.py b/fast_llm/models/gpt/model.py index cbce66f2..586b511b 100644 --- a/fast_llm/models/gpt/model.py +++ b/fast_llm/models/gpt/model.py @@ -32,7 +32,7 @@ from fast_llm.layers.transformer.vision_transformer import VisionTransformerLayer from fast_llm.layers.vision_encoder.adapter import VisionAdapter from fast_llm.layers.vision_encoder.config import VisionEncoderDimNames, VisionEncoderKwargs -from fast_llm.layers.vision_encoder.encoder import PatchConv +from fast_llm.layers.vision_encoder.patch_conv import PatchConv from fast_llm.layers.vision_encoder.preprocessing import VisionPreprocessor from fast_llm.models.gpt.config import GPTBaseModelConfig, GPTBatchConfig, GPTModelConfig from fast_llm.models.gpt.megatron import get_init_megatron @@ -84,11 +84,6 @@ def __init__( self._preprocessors.append( RotaryEmbeddingPreprocessor(self._config.vision_encoder.transformer.rotary, self._tensor_space) ) - # self._vision_preprocessor = VisionPreprocessor(self._config.vision_encoder, self._tensor_space) - # if self._config.vision_encoder.transformer.rotary.enabled: - # self._vision_rotary_embedding_preprocessor = RotaryEmbeddingPreprocessor( - # self._config.vision_encoder.transformer.rotary, self._tensor_space - # ) def get_output_layers(self) -> list[Layer]: layers = [] @@ -178,14 +173,14 @@ def preprocess_meta( ] image_rescale_factor = self._config.vision_encoder.image_normalization.rescale_factor vision_kwargs = { - VisionEncoderKwargs.patch_size: batch_meta.patch_size, + VisionEncoderKwargs.patch_size: self._config.vision_encoder.patch_size, VisionEncoderKwargs.image_size: image_size, VisionEncoderKwargs.image_mean: image_mean, VisionEncoderKwargs.image_std: image_std, VisionEncoderKwargs.image_rescale_factor: image_rescale_factor, VisionEncoderKwargs.rope_theta: self._config.vision_encoder.transformer.rotary.theta, VisionEncoderKwargs.kv_channels: self._tensor_space.get_tensor_dim( - VisionEncoderDimNames.kv_channels + VisionTransformerDimNames.kv_channels ).size, VisionEncoderKwargs.out_channels: self._tensor_space.get_tensor_dim( VisionEncoderDimNames.out_channels diff --git 
a/fast_llm/models/gpt/trainer.py b/fast_llm/models/gpt/trainer.py index 482fea02..840b8092 100644 --- a/fast_llm/models/gpt/trainer.py +++ b/fast_llm/models/gpt/trainer.py @@ -30,10 +30,15 @@ def _get_sampling_parameters( "use_loss_masking_spans": self._config.batch.use_loss_masking_spans, "cross_document_attention": self._config.batch.cross_document_attention, "extra_tokens": self._config.model.base_model.prediction_heads, - "patch_size": self._config.batch.patch_size, - "image_size": self._config.batch.image_size, } ) + if self._config.model.base_model.vision_encoder.enabled: + parameters.update( + { + "patch_size": self._config.model.base_model.vision_encoder.patch_size, + "image_size": self._config.batch.image_size, + } + ) return parameters if _return_dict else GPTSamplingParameters(**parameters) def get_tflops(self, phase: PhaseType, elapsed_time_per_iteration) -> tuple[int, int]: diff --git a/fast_llm/tools/cli.py b/fast_llm/tools/cli.py index 4d218c3f..0cc02f42 100644 --- a/fast_llm/tools/cli.py +++ b/fast_llm/tools/cli.py @@ -36,6 +36,7 @@ def fast_llm(args=None): if sys.gettrace(): raise logger.critical(traceback.format_exc()) + sys.exit(1) if __name__ == "__main__": diff --git a/fast_llm/utils.py b/fast_llm/utils.py index c5b7f07a..51e0eee5 100644 --- a/fast_llm/utils.py +++ b/fast_llm/utils.py @@ -336,10 +336,3 @@ def compare_nested(config_a, config_b, errors: list | None = None, prefix: tuple def check_equal_nested(config_a, config_b): if errors := compare_nested(config_a, config_b): raise ValueError("\n".join(errors)) - - -def prefix_class_vars(cls, prefix: str, base_cls: type): - for attr, value in vars(base_cls).items(): - if not attr.startswith("__") and isinstance(value, str) and not hasattr(cls, attr): - setattr(cls, attr, prefix + value) - return cls From dc888c8fc6596b0ba7483b4eaf184ba7015e2063 Mon Sep 17 00:00:00 2001 From: root Date: Thu, 22 May 2025 23:05:57 +0000 Subject: [PATCH 52/82] image break token in sampling --- fast_llm/data/dataset/gpt/config.py | 1 + fast_llm/data/dataset/gpt/sampled.py | 45 +++++++++++++++++-- fast_llm/layers/vision_encoder/config.py | 5 +++ .../layers/vision_encoder/preprocessing.py | 10 +++++ fast_llm/models/gpt/trainer.py | 1 + 5 files changed, 58 insertions(+), 4 deletions(-) diff --git a/fast_llm/data/dataset/gpt/config.py b/fast_llm/data/dataset/gpt/config.py index 44d1f4cc..004a062c 100644 --- a/fast_llm/data/dataset/gpt/config.py +++ b/fast_llm/data/dataset/gpt/config.py @@ -76,6 +76,7 @@ class GPTSamplingParameters(SamplingParameters): cross_document_attention: bool = True patch_size: int | None = None image_size: int | None = None + image_break_token: int | None = None # How many extra tokens to add to the sequence length. # This is used to provide labels even for the last tokens in the sequence. 
extra_tokens: int = 1 diff --git a/fast_llm/data/dataset/gpt/sampled.py b/fast_llm/data/dataset/gpt/sampled.py index 780b1887..de8e1d75 100644 --- a/fast_llm/data/dataset/gpt/sampled.py +++ b/fast_llm/data/dataset/gpt/sampled.py @@ -14,7 +14,7 @@ from fast_llm.data.dataset.gpt.indexed import GPTIndexedDataset from fast_llm.engine.config_utils.data_type import DataType, get_unsigned_integer_type from fast_llm.engine.config_utils.run import log_main_rank -from fast_llm.layers.vision_encoder.preprocessing import get_num_patches, get_resize_dims +from fast_llm.layers.vision_encoder.preprocessing import get_num_image_tokens, get_resize_dims from fast_llm.utils import Assert try: @@ -138,7 +138,7 @@ def _sample(self) -> None: for i, sizes in enumerate(image_sizes): image_token_sizes.append( sum( - get_num_patches( + get_num_image_tokens( *get_resize_dims( *size, self._parameters.image_size, @@ -146,6 +146,7 @@ def _sample(self) -> None: self._parameters.patch_size, ), self._parameters.patch_size, + break_token=self._parameters.image_break_token is not None, ) for size in sizes ) @@ -211,6 +212,7 @@ def _sample(self) -> None: "sequence_length": self._parameters.sequence_length, "patch_size": self._parameters.patch_size, "truncate_documents": self._truncate_documents, + "image_break_token": self._parameters.image_break_token, "config": self._config.to_dict(), } if self._truncate_documents: @@ -423,7 +425,7 @@ def __getitem__(self, index: int) -> typing.Any: text_size, image_lengths = self._indexed_dataset.get_document_size(document_index) image_sizes = [ - get_num_patches( + get_num_image_tokens( *get_resize_dims( *image_length, self._parameters.image_size, @@ -431,6 +433,7 @@ def __getitem__(self, index: int) -> typing.Any: self._parameters.patch_size, ), self._parameters.patch_size, + break_token=self._parameters.image_break_token is not None, ) for image_length in image_lengths ] @@ -473,7 +476,41 @@ def __getitem__(self, index: int) -> typing.Any: # Add placeholders for image tokens token_ids.append(sample.token_ids[start_pos:im_position]) text_tokens_added += len(token_ids[-1]) - token_ids.append(np.full((image_sizes[idx],), -100, dtype=np.int64)) + # token_ids.append(np.full((image_sizes[idx],), -100, dtype=np.int64)) + if self._parameters.image_break_token is not None: + # Calculate patch dimensions for the image + width, height = get_resize_dims( + image_lengths[idx][0], + image_lengths[idx][1], + self._parameters.image_size, + self._parameters.image_size, + self._parameters.patch_size, + ) + num_patches_w = math.ceil(width / self._parameters.patch_size) + num_patches_h = math.ceil(height / self._parameters.patch_size) + + # Calculate the token count considering break tokens + tokens_per_row = num_patches_w + total_tokens = num_patches_h * tokens_per_row + ( + num_patches_h - 1 + ) # Add break tokens after each row except last + + # Create image token placeholder array + image_token_array = np.full((total_tokens,), -100, dtype=np.int64) + + # Add break tokens after each row except the last row + for row in range(num_patches_h - 1): + position = (row + 1) * tokens_per_row + row + image_token_array[position] = self._parameters.image_break_token + + token_ids.append(image_token_array) + + # Update image_tokens_added to reflect actual number of tokens added + image_tokens_added += total_tokens + else: + # Just add placeholders for all image tokens without break tokens + token_ids.append(np.full((image_sizes[idx],), -100, dtype=np.int64)) + image_tokens_added += image_sizes[idx] 
image_positions.append(im_position + len(token_ids) + image_tokens_added) image_tokens_added += image_tokens start_pos = im_position diff --git a/fast_llm/layers/vision_encoder/config.py b/fast_llm/layers/vision_encoder/config.py index f788b514..5b972f12 100644 --- a/fast_llm/layers/vision_encoder/config.py +++ b/fast_llm/layers/vision_encoder/config.py @@ -130,6 +130,11 @@ class VisionEncoderConfig(BaseModelConfig): desc="Configuration for the normalization layers applied to the image patches.", hint=FieldHint.optional, ) + image_break_token: int | None = Field( + default=None, + desc="Token id to separate image rows. If None, no token id is applied is applied.", + hint=FieldHint.optional, + ) adapter_lr_scale: float | None = Field( default=None, desc="Custom learning rate scale for the adapter weights.", diff --git a/fast_llm/layers/vision_encoder/preprocessing.py b/fast_llm/layers/vision_encoder/preprocessing.py index d85442a3..5cffbff5 100644 --- a/fast_llm/layers/vision_encoder/preprocessing.py +++ b/fast_llm/layers/vision_encoder/preprocessing.py @@ -19,6 +19,16 @@ def get_num_patches(height: int, width: int, patch_size: int) -> tuple[int, int] return div(height, patch_size) * div(width, patch_size) +def get_num_image_tokens(height: int, width: int, patch_size: int, image_break: bool) -> int: + """ + Calculate the number of image tokens. + If image_break is True, we consider 1 additional token after every row of patches. + """ + height_patches = div(height, patch_size) + width_patches = div(width, patch_size) + return height_patches * (width_patches + image_break) + + def get_resize_dims(height: int, width: int, max_height: int, max_width: int, patch_size: int) -> tuple[int, int]: """ Calculate the new dimensions for resizing an image while maintaining the aspect ratio. 
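To make the placeholder layout built in __getitem__ above concrete, here is a standalone sketch (the helper name is hypothetical; it follows the same -100 placeholder convention and break-token placement as the sampling code):

import numpy as np

def image_placeholder_tokens(num_patches_h: int, num_patches_w: int, image_break_token: int | None) -> np.ndarray:
    # One -100 placeholder per patch; with a break token configured, one extra
    # token after every patch row except the last.
    if image_break_token is None:
        return np.full((num_patches_h * num_patches_w,), -100, dtype=np.int64)
    total = num_patches_h * num_patches_w + (num_patches_h - 1)
    tokens = np.full((total,), -100, dtype=np.int64)
    for row in range(num_patches_h - 1):
        tokens[(row + 1) * num_patches_w + row] = image_break_token
    return tokens

# A 3x2 patch grid with break token id 10 contributes 3*2 + 2 = 8 positions:
# [-100, -100, 10, -100, -100, 10, -100, -100]
print(image_placeholder_tokens(3, 2, 10))

This count matches get_num_image_tokens as finalized in the later "fix img break" commit: height_patches * width_patches patch tokens plus height_patches - 1 break tokens.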
diff --git a/fast_llm/models/gpt/trainer.py b/fast_llm/models/gpt/trainer.py index 840b8092..d1b6d19e 100644 --- a/fast_llm/models/gpt/trainer.py +++ b/fast_llm/models/gpt/trainer.py @@ -37,6 +37,7 @@ def _get_sampling_parameters( { "patch_size": self._config.model.base_model.vision_encoder.patch_size, "image_size": self._config.batch.image_size, + "image_break_token": self._config.model.base_model.vision_encoder.image_break_token, } ) return parameters if _return_dict else GPTSamplingParameters(**parameters) From af3e2dbcb19bec618d88dbf1bfb913fe8940caf7 Mon Sep 17 00:00:00 2001 From: root Date: Fri, 23 May 2025 22:47:04 +0000 Subject: [PATCH 53/82] minor fixes --- fast_llm/data/dataset/gpt/memmap.py | 6 +- fast_llm/data/dataset/gpt/sampled.py | 13 ++-- fast_llm/layers/multi_modal/embedding.py | 64 ++++++++++++++----- .../layers/vision_encoder/preprocessing.py | 2 +- 4 files changed, 60 insertions(+), 25 deletions(-) diff --git a/fast_llm/data/dataset/gpt/memmap.py b/fast_llm/data/dataset/gpt/memmap.py index 1efc312e..a202d2e1 100644 --- a/fast_llm/data/dataset/gpt/memmap.py +++ b/fast_llm/data/dataset/gpt/memmap.py @@ -10,7 +10,7 @@ from fast_llm.data.dataset.gpt.sampled import GPTSample from fast_llm.data.preparator.gpt_memmap.config import MEMMAP_DTYPES, MEMMAP_DTYPES_INV, MEMMAP_INDEX_HEADER from fast_llm.engine.config_utils.data_type import DataType -from fast_llm.layers.vision_encoder.preprocessing import get_num_patches, get_resize_dims +from fast_llm.layers.vision_encoder.preprocessing import get_num_image_tokens, get_resize_dims from fast_llm.utils import Assert, div @@ -201,6 +201,7 @@ def get( use_loss_masking_spans: bool = False, patch_size: int | None = None, image_size: int | None = None, + image_break: bool = False, ) -> GPTSample: token_ids = np.frombuffer( self._bin_buffer, @@ -239,9 +240,10 @@ def get( additional_tokens = 0 image_position = image_positions[image_idx] if image_idx < len(image_positions) else float("inf") while image_position >= span[0] and image_position <= span[1]: - image_tokens = get_num_patches( + image_tokens = get_num_image_tokens( get_resize_dims(*self._image_lengths[idx][image_idx], image_size, image_size, patch_size), patch_size, + image_break=image_break, ) additional_tokens += image_tokens image_idx += 1 diff --git a/fast_llm/data/dataset/gpt/sampled.py b/fast_llm/data/dataset/gpt/sampled.py index de8e1d75..f441d9b9 100644 --- a/fast_llm/data/dataset/gpt/sampled.py +++ b/fast_llm/data/dataset/gpt/sampled.py @@ -146,7 +146,7 @@ def _sample(self) -> None: self._parameters.patch_size, ), self._parameters.patch_size, - break_token=self._parameters.image_break_token is not None, + image_break=self._parameters.image_break_token is not None, ) for size in sizes ) @@ -433,7 +433,7 @@ def __getitem__(self, index: int) -> typing.Any: self._parameters.patch_size, ), self._parameters.patch_size, - break_token=self._parameters.image_break_token is not None, + image_break=self._parameters.image_break_token is not None, ) for image_length in image_lengths ] @@ -476,6 +476,7 @@ def __getitem__(self, index: int) -> typing.Any: # Add placeholders for image tokens token_ids.append(sample.token_ids[start_pos:im_position]) text_tokens_added += len(token_ids[-1]) + image_positions.append(text_tokens_added + im_position + image_tokens_added) # token_ids.append(np.full((image_sizes[idx],), -100, dtype=np.int64)) if self._parameters.image_break_token is not None: # Calculate patch dimensions for the image @@ -491,12 +492,12 @@ def __getitem__(self, index: int) -> 
typing.Any: # Calculate the token count considering break tokens tokens_per_row = num_patches_w - total_tokens = num_patches_h * tokens_per_row + ( + resized_image_tokens = num_patches_h * tokens_per_row + ( num_patches_h - 1 ) # Add break tokens after each row except last # Create image token placeholder array - image_token_array = np.full((total_tokens,), -100, dtype=np.int64) + image_token_array = np.full((resized_image_tokens,), -100, dtype=np.int64) # Add break tokens after each row except the last row for row in range(num_patches_h - 1): @@ -506,13 +507,11 @@ def __getitem__(self, index: int) -> typing.Any: token_ids.append(image_token_array) # Update image_tokens_added to reflect actual number of tokens added - image_tokens_added += total_tokens + image_tokens_added += resized_image_tokens else: # Just add placeholders for all image tokens without break tokens token_ids.append(np.full((image_sizes[idx],), -100, dtype=np.int64)) image_tokens_added += image_sizes[idx] - image_positions.append(im_position + len(token_ids) + image_tokens_added) - image_tokens_added += image_tokens start_pos = im_position token_ids.append(sample.token_ids[start_pos:]) text_tokens_added += len(token_ids[-1]) diff --git a/fast_llm/layers/multi_modal/embedding.py b/fast_llm/layers/multi_modal/embedding.py index 8c541e98..12b58a76 100644 --- a/fast_llm/layers/multi_modal/embedding.py +++ b/fast_llm/layers/multi_modal/embedding.py @@ -9,9 +9,9 @@ from fast_llm.layers.language_model.embedding import LanguageModelEmbedding from fast_llm.layers.transformer.config import TransformerKwargs from fast_llm.layers.vision_encoder.config import VisionEncoderKwargs -from fast_llm.layers.vision_encoder.preprocessing import get_num_patches +from fast_llm.layers.vision_encoder.preprocessing import get_num_image_tokens from fast_llm.tensor import TensorMeta -from fast_llm.utils import Assert +from fast_llm.utils import Assert, div class MultiModalEmbedding(LanguageModelEmbedding): @@ -60,15 +60,32 @@ def _forward( for sample_idx, (positions, sizes) in enumerate(zip(image_positions, image_sizes)): image_embedding_offset = 0 for position, size in zip(positions, sizes): - num_image_tokens = get_num_patches(*size, self._config.vision_encoder.patch_size) - if self._sequence_parallel: - embeddings[position : position + num_image_tokens, sample_idx] = input_[ - image_embedding_offset : image_embedding_offset + num_image_tokens, sample_idx - ] - else: - embeddings[sample_idx, position : position + num_image_tokens] = input_[ - sample_idx, image_embedding_offset : image_embedding_offset + num_image_tokens - ] + num_image_tokens = get_num_image_tokens(*size, self._config.vision_encoder.patch_size) + # Calculate the patch dimensions + patch_width = div(size[0], self._config.vision_encoder.patch_size) + patch_height = div(size[1], self._config.vision_encoder.patch_size) + + # Process row by row for both sequence parallel and non-parallel cases + for row in range(patch_height): + # Calculate source and destination starting positions + row_start_src = image_embedding_offset + row * patch_width + row_start_dst = position + row * (patch_width + 1) + + # Always use full patch_width + tokens_in_row = patch_width + + if self._sequence_parallel: + # Copy with dimensions swapped for sequence parallel case + embeddings[row_start_dst : row_start_dst + tokens_in_row, sample_idx] = input_[ + row_start_src : row_start_src + tokens_in_row, sample_idx + ] + else: + # Copy with normal dimension ordering + embeddings[sample_idx, row_start_dst : 
row_start_dst + tokens_in_row] = input_[ + sample_idx, row_start_src : row_start_src + tokens_in_row + ] + + # Move to the next image in the input tensor image_embedding_offset += num_image_tokens if self._sequence_parallel: embeddings = split(embeddings, group=group, dim=0) @@ -85,10 +102,27 @@ def _forward( for sample_idx, (positions, sizes) in enumerate(zip(image_positions, image_sizes)): image_embedding_offset = 0 for position, size in zip(positions, sizes): - num_image_tokens = get_num_patches(*size, self._config.vision_encoder.patch_size) - embeddings[sample_idx, position : position + num_image_tokens] = input_[ - sample_idx, image_embedding_offset : image_embedding_offset + num_image_tokens - ] + num_image_tokens = get_num_image_tokens( + *size, + self._config.vision_encoder.patch_size, + image_break=self._config.vision_encoder.image_break_token is not None, + ) + # Calculate the patch dimensions + patch_width = div(size[0], self._config.vision_encoder.patch_size) + patch_height = div(size[1], self._config.vision_encoder.patch_size) + + # Process row by row + for row in range(patch_height): + # Calculate source and destination starting positions + row_start_src = image_embedding_offset + row * patch_width + row_start_dst = position + row * (patch_width + 1) + + # Copy row by row + embeddings[sample_idx, row_start_dst : row_start_dst + patch_width] = input_[ + sample_idx, row_start_src : row_start_src + patch_width + ] + + # Move to the next image in the input tensor image_embedding_offset += num_image_tokens if self._use_absolute_position_embeddings: diff --git a/fast_llm/layers/vision_encoder/preprocessing.py b/fast_llm/layers/vision_encoder/preprocessing.py index 5cffbff5..c5c14a26 100644 --- a/fast_llm/layers/vision_encoder/preprocessing.py +++ b/fast_llm/layers/vision_encoder/preprocessing.py @@ -26,7 +26,7 @@ def get_num_image_tokens(height: int, width: int, patch_size: int, image_break: """ height_patches = div(height, patch_size) width_patches = div(width, patch_size) - return height_patches * (width_patches + image_break) + return height_patches * (width_patches + image_break) - 1 def get_resize_dims(height: int, width: int, max_height: int, max_width: int, patch_size: int) -> tuple[int, int]: From 6d56be085309a4e0f74c24c5bad4aa8aea442708 Mon Sep 17 00:00:00 2001 From: root Date: Sat, 24 May 2025 19:43:34 +0000 Subject: [PATCH 54/82] fix img break --- fast_llm/data/dataset/gpt/sampled.py | 6 +++--- fast_llm/layers/vision_encoder/preprocessing.py | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/fast_llm/data/dataset/gpt/sampled.py b/fast_llm/data/dataset/gpt/sampled.py index f441d9b9..2c068742 100644 --- a/fast_llm/data/dataset/gpt/sampled.py +++ b/fast_llm/data/dataset/gpt/sampled.py @@ -476,19 +476,19 @@ def __getitem__(self, index: int) -> typing.Any: # Add placeholders for image tokens token_ids.append(sample.token_ids[start_pos:im_position]) text_tokens_added += len(token_ids[-1]) - image_positions.append(text_tokens_added + im_position + image_tokens_added) + image_positions.append(text_tokens_added + image_tokens_added) # token_ids.append(np.full((image_sizes[idx],), -100, dtype=np.int64)) if self._parameters.image_break_token is not None: # Calculate patch dimensions for the image - width, height = get_resize_dims( + height, width = get_resize_dims( image_lengths[idx][0], image_lengths[idx][1], self._parameters.image_size, self._parameters.image_size, self._parameters.patch_size, ) - num_patches_w = math.ceil(width / self._parameters.patch_size) 
num_patches_h = math.ceil(height / self._parameters.patch_size) + num_patches_w = math.ceil(width / self._parameters.patch_size) # Calculate the token count considering break tokens tokens_per_row = num_patches_w diff --git a/fast_llm/layers/vision_encoder/preprocessing.py b/fast_llm/layers/vision_encoder/preprocessing.py index c5c14a26..8404adae 100644 --- a/fast_llm/layers/vision_encoder/preprocessing.py +++ b/fast_llm/layers/vision_encoder/preprocessing.py @@ -26,7 +26,7 @@ def get_num_image_tokens(height: int, width: int, patch_size: int, image_break: """ height_patches = div(height, patch_size) width_patches = div(width, patch_size) - return height_patches * (width_patches + image_break) - 1 + return height_patches * width_patches + (height_patches - 1 if image_break else 0) def get_resize_dims(height: int, width: int, max_height: int, max_width: int, patch_size: int) -> tuple[int, int]: From ce9164647d3a582b8a13fd3646a66f3a019c8966 Mon Sep 17 00:00:00 2001 From: root Date: Tue, 27 May 2025 23:34:57 +0000 Subject: [PATCH 55/82] fixes --- fast_llm/layers/language_model/embedding.py | 5 ++++- fast_llm/layers/multi_modal/embedding.py | 11 +++++++++-- 2 files changed, 13 insertions(+), 3 deletions(-) diff --git a/fast_llm/layers/language_model/embedding.py b/fast_llm/layers/language_model/embedding.py index 1d9406ed..f51f40df 100644 --- a/fast_llm/layers/language_model/embedding.py +++ b/fast_llm/layers/language_model/embedding.py @@ -99,7 +99,10 @@ def _forward(self, input_: torch.Tensor, position_ids: torch.Tensor | None) -> t input_ = split(input_, group=group, dim=0) if self._use_absolute_position_embeddings: position_ids = split(position_ids, group=group, dim=0) - embeddings = torch.embedding(self.word_embeddings_weight, input_) + # mask padded tokens + input_mask = input_ >= 0 + masked_input = input_ * input_mask + embeddings = torch.embedding(self.word_embeddings_weight, masked_input) * input_mask.unsqueeze(2) if self._use_absolute_position_embeddings: embeddings = embeddings + torch.nn.functional.embedding(position_ids, self.position_embeddings_weight) with set_generator( diff --git a/fast_llm/layers/multi_modal/embedding.py b/fast_llm/layers/multi_modal/embedding.py index 12b58a76..f40df3f0 100644 --- a/fast_llm/layers/multi_modal/embedding.py +++ b/fast_llm/layers/multi_modal/embedding.py @@ -60,7 +60,11 @@ def _forward( for sample_idx, (positions, sizes) in enumerate(zip(image_positions, image_sizes)): image_embedding_offset = 0 for position, size in zip(positions, sizes): - num_image_tokens = get_num_image_tokens(*size, self._config.vision_encoder.patch_size) + num_image_tokens = get_num_image_tokens( + *size, + self._config.vision_encoder.patch_size, + image_break=self._config.vision_encoder.image_break_token is not None, + ) # Calculate the patch dimensions patch_width = div(size[0], self._config.vision_encoder.patch_size) patch_height = div(size[1], self._config.vision_encoder.patch_size) @@ -97,7 +101,10 @@ def _forward( # TODO Soham: get image positions for current split. Maybe in preprocessing? 
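Taken together, the embedding-side changes in this commit mask the -100 placeholders before the token-embedding lookup and then overwrite the placeholder slots with patch embeddings row by row, skipping the break-token slots. A condensed single-sample sketch, without tensor or sequence parallelism (names and shapes are illustrative, not the Fast-LLM API):

import torch

def embed_with_image_patches(
    tokens: torch.Tensor,            # (seq_len,) token ids, -100 at image patch slots
    word_embeddings: torch.Tensor,   # (vocab_size, hidden) embedding table
    patch_embeddings: torch.Tensor,  # (num_patches, hidden) vision encoder output for one image
    position: int,                   # sequence index of the image's first patch token
    patch_height: int,
    patch_width: int,
    has_break_token: bool,
) -> torch.Tensor:
    # torch.embedding cannot take negative indices, so clamp the placeholders to 0
    # and zero out their rows afterwards.
    mask = tokens >= 0
    embeddings = torch.embedding(word_embeddings, tokens * mask) * mask.unsqueeze(-1)
    # Copy one patch row at a time; with a break token, destination rows are
    # patch_width + 1 apart so the break-token slot keeps its word embedding.
    dst_stride = patch_width + 1 if has_break_token else patch_width
    for row in range(patch_height):
        src = row * patch_width
        dst = position + row * dst_stride
        embeddings[dst : dst + patch_width] = patch_embeddings[src : src + patch_width]
    return embeddings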
# for positions in image_positions: # if positions > self._distributed_config.tensor_rank - embeddings = torch.embedding(self.word_embeddings_weight, tokens) + # mask padded tokens + token_mask = tokens >= 0 + masked_tokens = tokens * token_mask + embeddings = torch.embedding(self.word_embeddings_weight, masked_tokens) * token_mask.unsqueeze(2) embeddings = embeddings.clone() for sample_idx, (positions, sizes) in enumerate(zip(image_positions, image_sizes)): image_embedding_offset = 0 From 7eea79bdfe1cd0275fd4310d75487a2b78c7a998 Mon Sep 17 00:00:00 2001 From: Toby Liang Date: Wed, 28 May 2025 00:21:31 +0000 Subject: [PATCH 56/82] update audio encoder --- fast_llm/data/data/gpt/data.py | 2 +- fast_llm/data/dataset/gpt/sampled.py | 7 +- fast_llm/layers/audio_encoder/adapter.py | 46 +++++-- fast_llm/layers/audio_encoder/config.py | 123 ++++++++---------- fast_llm/layers/audio_encoder/encoder.py | 52 ++++++-- .../layers/audio_encoder/preprocessing.py | 31 ++++- 6 files changed, 158 insertions(+), 103 deletions(-) diff --git a/fast_llm/data/data/gpt/data.py b/fast_llm/data/data/gpt/data.py index 577ebc79..a4b183ca 100644 --- a/fast_llm/data/data/gpt/data.py +++ b/fast_llm/data/data/gpt/data.py @@ -83,7 +83,7 @@ def gpt_data_collate_fn(batch: list[GPTSample], sampling_parameters: GPTSampling images=batch_images if has_images else None, image_positions=batch_image_positions if has_images else None, audio=batch_audio if has_audio else None, - audio_positions=batch_image_positions if has_audio else None, + audio_positions=batch_audio_positions if has_audio else None, ) diff --git a/fast_llm/data/dataset/gpt/sampled.py b/fast_llm/data/dataset/gpt/sampled.py index 36f8a0e7..c21e0825 100644 --- a/fast_llm/data/dataset/gpt/sampled.py +++ b/fast_llm/data/dataset/gpt/sampled.py @@ -622,11 +622,10 @@ def __getitem__(self, index: int) -> typing.Any: if self._parameters.use_loss_masking_spans else None ) - # images = [im for img_list in images for im in img_list] if images else None - # image_positions = np.array(image_positions) if image_positions else None - images = None + images = [im for img_list in images for im in img_list] if images else None + image_positions = np.array(image_positions) if image_positions else None - audio = [aud for aud_list in audio for aud in aud_list] if audio else None + audio = [aud for aud_list in audio for aud in aud_list] if audio else None # flatten audio_positions = np.array(audio_positions) if audio_positions else None # Assert.eq(len(token_ids), self._parameters.sequence_length + self._parameters.extra_tokens) diff --git a/fast_llm/layers/audio_encoder/adapter.py b/fast_llm/layers/audio_encoder/adapter.py index 4f77971e..8c0c7175 100644 --- a/fast_llm/layers/audio_encoder/adapter.py +++ b/fast_llm/layers/audio_encoder/adapter.py @@ -5,9 +5,9 @@ from fast_llm.engine.base_model.base_model import Layer from fast_llm.engine.config_utils.tensor_space import TensorSpace from fast_llm.functional.triton.mlp import torch_mlp_activation +from fast_llm.layers.audio_encoder.config import AudioEncoderConfig, AudioEncoderDimNames from fast_llm.layers.common.linear import Linear from fast_llm.layers.transformer.config import TransformerDimNames, TransformerKwargs -from fast_llm.layers.vision_encoder.config import VisionEncoderConfig, VisionEncoderDimNames from fast_llm.tensor import TensorMeta, init_normal_ @@ -16,26 +16,36 @@ class AudioAdapter(Layer): Vision adapter layer for the LLM. 
""" - def __init__(self, config: VisionEncoderConfig, tensor_space: TensorSpace): + def __init__(self, config: AudioEncoderConfig, tensor_space: TensorSpace): super().__init__() - input_dim = tensor_space.get_tensor_dim(VisionEncoderDimNames.out_channels) + audio_hidden_dim = tensor_space.get_tensor_dim(AudioEncoderDimNames.out_channels) + input_dim = tensor_space.get_tensor_dim(AudioEncoderDimNames.adapter_input) self._activation_type = config.adapter_activation_type + self._use_adapter_bias = config.adapter_bias + + self.norm_1 = config.transformer.normalization.get_layer(audio_hidden_dim) + self.norm_2 = config.transformer.normalization.get_layer( + tensor_space.get_tensor_dim(AudioEncoderDimNames.adapter_size) + ) + # TODO Soham: Make them OutputParallelLinear instead? How would this work with parallelism? self.layer_1 = Linear( input_dim, - tensor_space.get_tensor_dim(VisionEncoderDimNames.adapter_size), - bias=True, + tensor_space.get_tensor_dim(AudioEncoderDimNames.adapter_size), + bias=self._use_adapter_bias, weight_init_method=init_normal_(), bias_init_method=init_normal_(), ) self.layer_2 = Linear( - tensor_space.get_tensor_dim(VisionEncoderDimNames.adapter_size), + tensor_space.get_tensor_dim(AudioEncoderDimNames.adapter_size), tensor_space.get_tensor_dim(TransformerDimNames.hidden), - bias=True, + bias=self._use_adapter_bias, weight_init_method=init_normal_(), bias_init_method=init_normal_(), ) + self.aud_downsampling_k = config.aud_downsampling_k + def forward( self, input_: torch.Tensor, @@ -46,9 +56,25 @@ def forward( if isinstance(input_, TensorMeta): return TensorMeta.from_dims( kwargs[TransformerKwargs.hidden_dims], - tensor_name="Vision adapter output", + tensor_name="Audio adapter output", dtype=input_.dtype, ) - return self.layer_2( - torch_mlp_activation(input_=self.layer_1(input_), gated=False, activation_type=self._activation_type) + batch_size, seq_len, dim = input_.size() + + # Check if sequence length is divisible by downsampling rate. + if seq_len % self.aud_downsampling_k != 0: + # If not divisible, trim the end of the sequence. + trimmed_seq_len = seq_len - (seq_len % self.aud_downsampling_k) + input_ = input_[:, :trimmed_seq_len, :] + seq_len = trimmed_seq_len + + # Reshape: group every k frames together (concatenate along feature dimension). + new_seq_len = seq_len // self.aud_downsampling_k + input_ = input_.contiguous().view(batch_size, new_seq_len, dim * self.aud_downsampling_k) + + res = self.layer_2( + self.norm_2( + torch_mlp_activation(input_=self.layer_1(input_), gated=False, activation_type=self._activation_type) + ) ) + return res diff --git a/fast_llm/layers/audio_encoder/config.py b/fast_llm/layers/audio_encoder/config.py index 52a8673e..3e09b39f 100644 --- a/fast_llm/layers/audio_encoder/config.py +++ b/fast_llm/layers/audio_encoder/config.py @@ -11,100 +11,65 @@ class AudioEncoderDimNames: in_channels = "audio_in_channels" out_channels = "audio_out_channels" kernel_size = "audio_kernel_size" + adapter_input = "audio_adapter_input" adapter_size = "audio_adapter_size" audio_channels = "audio_kv_channels" - - -class AudioTransformerDimNames: - # A set of common tensor dim names packed into a namespace. - # Input dimensions (variable) - # TODO: Does batch belong here? - batch = "audio_batch" - # TODO: Distinguish micro-sequence? 
- sequence_q = "audio_sequence_q" - sequence_q_tp = "audio_sequence_q_tp" - sequence_k = "audio_sequence_k" - hidden = "audio_hidden" - # Self-attention dimensions - head_groups = "audio_head_groups" - group_heads = "audio_group_heads" - key_and_value = "audio_key_value" - kv_channels = "audio_kv_channels" - composite_heads = "audio_composite_heads" - composite_query = "audio_composite_query" - composite_key_value = "audio_composite_key_value" - composite_dense = "audio_composite_dense" - # MLP dimensions - mlp = "audio_mlp" - gate_and_up = "audio_gate_and_up" - composite_gated_mlp = "audio_composite_gated_mlp" - experts = "audio_experts" - top_experts = "audio_top_experts" - shared_experts = "audio_shared_experts" - unshared_experts = "audio_unshared_experts" - composite_expert_mlp = "audio_composite_expert_mlp" - composite_gated_expert_mlp = "audio_composite_gated_expert_mlp" - composite_shared_expert_mlp = "audio_composite_shared_expert_mlp" - composite_gated_shared_expert_mlp = "audio_composite_gated_shared_expert_mlp" + max_source_positions = "audio_max_source_positions" class AudioEncoderKwargs: audio = "audio" audio_mel = "audio_mel" audio_positions = "audio_positions" - kv_channels = "audio_kv_channels" + + kv_channels = "audio_kv_channels" # TODO: check this + out_channels = "audio_out_channels" hidden_dims = "audio_hidden_dims" + # TODO: used for backup attention + sequence_length = "audio_sequence_length" + sequence_k_dim = "audio_sequence_k_dim" + sequence_q_dim = "audio_sequence_q_dim" + class AudioEncoderType(str, enum.Enum): none = "none" whisper = "whisper" -# # TODO Toby: do we need all of them? -class AudioTransformerKwargs: - rotary_freq_q = "audio_rotary_freq_q" - rotary_freq_k = "audio_rotary_freq_k" - attention_mask = "audio_attention_mask" - attention_mask_value = "audio_attention_mask_value" - sequence_lengths = "audio_sequence_lengths" - cu_seqlens_q = "audio_cu_seqlens_q" - cu_seqlens_k = "audio_cu_seqlens_k" - max_seqlen_q = "audio_max_seqlen_q" - max_seqlen_k = "audio_max_seqlen_k" - # TODO: Review these - presents = "audio_presents" - past_key_values = "audio_past_key_values" - sequence_first = "audio_sequence_first" - hidden_dims = "audio_hidden_dims" - sequence_q_dim = "audio_sequence_q_dim" - sequence_k_dim = "audio_sequence_k_dim" - sequence_length = "audio_sequence_length" - micro_batch_size = "audio_micro_batch_size" - # TODO: Move - grad_output = "audio_grad_output" - patch_position_ids = "patch_position_ids" - - @config_class() class AudioEncoderConfig(BaseModelConfig): _abstract = False - transformer: AudioTransformerConfig = Field( - default_factory=AudioTransformerConfig, - desc="Configuration for the audio transformer architecture.", - hint=FieldHint.core, - ) type: AudioEncoderType = Field( default=AudioEncoderType.none, desc="Type of the audio encoder. 
Choices: none, whisper.", hint=FieldHint.architecture, ) + transformer: AudioTransformerConfig = Field( + default_factory=AudioTransformerConfig, + desc="Configuration for the audio transformer architecture.", + hint=FieldHint.core, + ) + + # encoder configs conv_bias: bool = Field( - default=False, + default=True, desc="Whether to use bias in the convolutional layer.", hint=FieldHint.optional, ) + encoder_dropout: float = Field( + default=0.0, + desc="Dropout for encoder.", + hint=FieldHint.core, + ) + kernel_size: int = Field( + default=3, + desc="Encoder convolution layer kernel size.", + hint=FieldHint.core, + ) + + # adapter configs adapter_size: int = Field( default=5120, desc="Intermediate size for the adapter linear layers. Assuming 2 linear layers", @@ -115,6 +80,18 @@ class AudioEncoderConfig(BaseModelConfig): desc="The intermediate activation type for multi-modal adapter. Default: GeLU.", hint=FieldHint.core, ) + adapter_bias: bool = Field( + default=True, + desc="Whether to use bias in the adapter layer.", + hint=FieldHint.optional, + ) + + # audio configs + num_mel_bins: int = Field( + default=80, + desc="Number of bins for mel spectogram.", + hint=FieldHint.core, + ) aud_downsampling_k: int = Field( default=5, desc="Audio downsampling k parameter.", @@ -127,16 +104,24 @@ class AudioEncoderConfig(BaseModelConfig): ) def setup_tensor_space(self, tensor_space: TensorSpace): + tensor_space.add_tensor_dim(TensorDim(AudioEncoderDimNames.in_channels, self.num_mel_bins)) tensor_space.add_tensor_dim(TensorDim(AudioEncoderDimNames.out_channels, self.transformer.hidden_size)) + tensor_space.add_tensor_dim(TensorDim(AudioEncoderDimNames.kernel_size, self.kernel_size)) + tensor_space.add_tensor_dim( + TensorDim(AudioEncoderDimNames.adapter_input, self.transformer.hidden_size * self.aud_downsampling_k) + ) tensor_space.add_tensor_dim(TensorDim(AudioEncoderDimNames.adapter_size, self.adapter_size)) - tensor_space.add_tensor_dim(TensorDim(AudioEncoderDimNames.in_channels)) - # TODO Soham: add a check for presence of kv channels parameter (head_dim) + tensor_space.add_tensor_dim( + TensorDim(AudioEncoderDimNames.max_source_positions, 1500) + ) # TODO: configure later + tensor_space.add_tensor_dim( TensorDim( - AudioEncoderDimNames.kv_channels, self.transformer.hidden_size // self.transformer.num_attention_heads + AudioEncoderDimNames.audio_channels, + self.transformer.hidden_size // self.transformer.num_attention_heads, ) ) - self.transformer.setup_tensor_space(tensor_space, type="audio") + self.transformer.setup_tensor_space(tensor_space) @property def enabled(self) -> bool: diff --git a/fast_llm/layers/audio_encoder/encoder.py b/fast_llm/layers/audio_encoder/encoder.py index 8cd071bd..20c7d507 100644 --- a/fast_llm/layers/audio_encoder/encoder.py +++ b/fast_llm/layers/audio_encoder/encoder.py @@ -4,7 +4,8 @@ from fast_llm.engine.base_model.base_model import Layer from fast_llm.engine.config_utils.tensor_space import TensorSpace -from fast_llm.layers.audio_encoder.config import AudioEncoderConfig, AudioEncoderDimNames, AudioEncoderKwargs +from fast_llm.layers.audio_encoder.config import AudioEncoderConfig, AudioEncoderDimNames +from fast_llm.layers.transformer.config import AudioTransformerKwargs from fast_llm.tensor import ParameterMeta, TensorMeta, init_normal_ @@ -12,6 +13,8 @@ class AudioConv(Layer): def __init__(self, config: AudioEncoderConfig, tensor_space: TensorSpace): super().__init__() self._tensor_space = tensor_space + self.dropout_p = config.encoder_dropout + # TODO Toby: 
lr_scale self.conv1_weight = ParameterMeta.from_dims( ( @@ -21,24 +24,36 @@ def __init__(self, config: AudioEncoderConfig, tensor_space: TensorSpace): ), init_method=init_normal_(), ) - self.conv1_stride = 1 + self.conv1_stride = 1 # TODO: parameterize? self.conv2_weight = ParameterMeta.from_dims( ( - self._tensor_space.get_tensor_dim(AudioEncoderDimNames.in_channels), # in/out channels are the same - self._tensor_space.get_tensor_dim(AudioEncoderDimNames.in_channels), + self._tensor_space.get_tensor_dim(AudioEncoderDimNames.out_channels), + self._tensor_space.get_tensor_dim(AudioEncoderDimNames.out_channels), self._tensor_space.get_tensor_dim(AudioEncoderDimNames.kernel_size), ), init_method=init_normal_(), ) - self.conv2_stride = 2 + self.conv2_stride = 2 # TODO: parameterize? if config.conv_bias: - self.bias = ParameterMeta.from_dims( - (self._tensor_space.get_tensor_dim(AudioEncoderDimNames.out_channels),) + self.conv1_bias = ParameterMeta.from_dims( + (self._tensor_space.get_tensor_dim(AudioEncoderDimNames.out_channels),), init_method=init_normal_() + ) + self.conv2_bias = ParameterMeta.from_dims( + (self._tensor_space.get_tensor_dim(AudioEncoderDimNames.out_channels),), init_method=init_normal_() ) else: - self.bias = None + self.conv1_bias = None + self.conv2_bias = None + + self.positional_embeddings = ParameterMeta.from_dims( + ( + self._tensor_space.get_tensor_dim(AudioEncoderDimNames.max_source_positions), + self._tensor_space.get_tensor_dim(AudioEncoderDimNames.out_channels), + ), + init_method=init_normal_(), + ) def forward( self, @@ -47,15 +62,24 @@ def forward( losses: dict[str, typing.Any] | None = None, metrics: dict | None = None, ) -> torch.Tensor: - hidden_dims = kwargs[AudioEncoderKwargs.hidden_dims] + hidden_dims = kwargs[AudioTransformerKwargs.hidden_dims] # TODO: check seq q if isinstance(input_, TensorMeta): return TensorMeta.from_dims(hidden_dims, tensor_name="audio conv output", dtype=input_.dtype) - input_ = torch.nn.functional.conv1d(input_, self.conv1_weight, self.bias, stride=self.conv1_stride) + + # TODO: check how to best cast dtype + input_ = input_.to(self.conv1_weight.dtype) + + input_ = torch.nn.functional.conv1d( + input_, self.conv1_weight, self.conv1_bias, stride=self.conv1_stride, padding=1 + ) input_ = torch.nn.functional.gelu(input_) - input_ = torch.nn.functional.conv1d(input_, self.conv2_weight, self.bias, stride=self.conv2_stride) + input_ = torch.nn.functional.conv1d( + input_, self.conv2_weight, self.conv2_bias, stride=self.conv2_stride, padding=1 + ) input_ = torch.nn.functional.gelu(input_) - # TODO Toby: add learned position embeddings and dropout - audio_embeddings = audio_embeddings.reshape(*(x.size for x in hidden_dims)) + audio_embeddings = input_.permute(0, 2, 1) + audio_embeddings = audio_embeddings + self.positional_embeddings + audio_embeddings = torch.nn.functional.dropout(audio_embeddings, p=self.dropout_p, training=self.training) - return audio_embeddings + return audio_embeddings.contiguous() diff --git a/fast_llm/layers/audio_encoder/preprocessing.py b/fast_llm/layers/audio_encoder/preprocessing.py index 54bfeef6..9d0db1b4 100644 --- a/fast_llm/layers/audio_encoder/preprocessing.py +++ b/fast_llm/layers/audio_encoder/preprocessing.py @@ -33,15 +33,36 @@ def __init__(self, config: AudioEncoderConfig, tensor_space: TensorSpace): ) def preprocess_meta(self, kwargs: dict[str, typing.Any]) -> None: + # kwargs[AudioEncoderKwargs.audio_mel_meta] = TensorMeta.from_dims( + # ( + # TensorDim( + # VisionTransformerDimNames.batch, + # 
kwargs[TransformerKwargs.micro_batch_size] * kwargs[TransformerKwargs.sequence_q_dim].size, + # ), + # TensorDim(VisionEncoderDimNames.in_channels, 3), + # TensorDim(VisionEncoderDimNames.patch_size, kwargs[VisionEncoderKwargs.patch_size]), + # TensorDim(VisionEncoderDimNames.patch_size, kwargs[VisionEncoderKwargs.patch_size]), + # ), + # dtype=self._distributed_config.training_dtype.torch, + # ) pass def preprocess(self, tokens, kwargs: dict[str, typing.Any]) -> None: audio_raw = kwargs[AudioEncoderKwargs.audio] + flattened_audio = [audio_arr for sequence in audio_raw for audio_arr in sequence] + flattened_audio_tensor = torch.stack(flattened_audio, dim=0) # audio_inputs = self.feature_extractor(audio_raw, sampling_rate=16000, return_tensors="pt") self.mel_transform.to(self._tensor_space.distributed.device) - audio_mel = [] - for batch in audio_raw: - batch_stacked = torch.stack(batch).unsqueeze(1) - audio_mel.append(self.mel_transform(batch_stacked)) - kwargs[AudioEncoderKwargs.audio_mel] = torch.cat(audio_mel) + audio_mel = self.mel_transform(flattened_audio_tensor) + audio_mel = audio_mel[:, :, :-1] # TODO Toby: check this! + + # # set attention mask # TODO Toby: fix backup attention + # sequence_k = kwargs[self._transformer_kwargs.sequence_k_dim].size + # sequence_q = kwargs[self._transformer_kwargs.sequence_q_dim].size + # kwargs[self._transformer_kwargs.attention_mask] = self._mask[ + # None, None, sequence_k - sequence_q : sequence_k, None, :sequence_k + # ] + # kwargs[self._transformer_kwargs.attention_mask_value] = self._mask_value + + kwargs[AudioEncoderKwargs.audio_mel] = audio_mel From daf98b3c9eac1c49f12a0f34e5d056b9c9d7a351 Mon Sep 17 00:00:00 2001 From: Toby Liang Date: Wed, 28 May 2025 00:23:20 +0000 Subject: [PATCH 57/82] audio transformer updates --- fast_llm/layers/language_model/config.py | 2 ++ fast_llm/layers/multi_modal/embedding.py | 18 +++++++++-- .../layers/transformer/audio_transformer.py | 5 ++- fast_llm/layers/transformer/config.py | 31 +++++++++++++++++-- fast_llm/layers/transformer/preprocessing.py | 6 ++-- 5 files changed, 51 insertions(+), 11 deletions(-) diff --git a/fast_llm/layers/language_model/config.py b/fast_llm/layers/language_model/config.py index 4a425da8..8ba066cb 100644 --- a/fast_llm/layers/language_model/config.py +++ b/fast_llm/layers/language_model/config.py @@ -182,6 +182,8 @@ def setup_tensor_space(self, tensor_space: TensorSpace) -> None: tensor_space.add_tensor_dim(TensorDim(LanguageModelDimNames.vocab_tp, self.vocab_size, tensor)) if self.vision_encoder.enabled: self.vision_encoder.setup_tensor_space(tensor_space) + if self.audio_encoder.enabled: + self.audio_encoder.setup_tensor_space(tensor_space) @property def num_absolute_position_embeddings(self) -> int: diff --git a/fast_llm/layers/multi_modal/embedding.py b/fast_llm/layers/multi_modal/embedding.py index 8c541e98..3bce539d 100644 --- a/fast_llm/layers/multi_modal/embedding.py +++ b/fast_llm/layers/multi_modal/embedding.py @@ -5,6 +5,7 @@ from fast_llm.core.distributed import set_generator from fast_llm.core.ops import gather, reduce_forward, split from fast_llm.engine.config_utils.tensor_space import TensorSpace +from fast_llm.layers.audio_encoder.config import AudioEncoderKwargs from fast_llm.layers.language_model.config import LanguageModelBaseConfig, LanguageModelKwargs from fast_llm.layers.language_model.embedding import LanguageModelEmbedding from fast_llm.layers.transformer.config import TransformerKwargs @@ -34,6 +35,7 @@ def _forward( position_ids: torch.Tensor | None, 
image_positions: list[torch.Tensor] | None, image_sizes: list[list[tuple[int, int]]] | None, + audio_positions: list[torch.Tensor] | None, ) -> torch.Tensor: """ Forward pass for the multi-modal embedding layer. @@ -57,6 +59,7 @@ def _forward( embeddings = embeddings + torch.nn.functional.embedding(position_ids, self.position_embeddings_weight) embeddings = embeddings.clone() input_ = gather(input_, group, dim=0) + # TODO: Toby implement audio for sample_idx, (positions, sizes) in enumerate(zip(image_positions, image_sizes)): image_embedding_offset = 0 for position, size in zip(positions, sizes): @@ -91,6 +94,13 @@ def _forward( ] image_embedding_offset += num_image_tokens + audio_position_idx = 0 + for sample_idx, positions in enumerate(audio_positions): + for position in positions: + num_audio_tokens = input_.shape[1] # TODO: Toby better way to get this? + embeddings[sample_idx, position : position + num_audio_tokens] = input_[audio_position_idx] + audio_position_idx += 1 + if self._use_absolute_position_embeddings: embeddings = embeddings + torch.nn.functional.embedding(position_ids, self.position_embeddings_weight) with set_generator( @@ -114,9 +124,11 @@ def forward( tensor_name="Embedding output", dtype=self._residual_dtype, ) + # TODO: How do we support both Audio and Vision? position_ids = kwargs.get(LanguageModelKwargs.position_ids) - image_sizes = kwargs.get(VisionEncoderKwargs.image_sizes) - image_positions = kwargs.get(VisionEncoderKwargs.image_positions) + image_sizes = kwargs.get(VisionEncoderKwargs.image_sizes, []) + image_positions = kwargs.get(VisionEncoderKwargs.image_positions, []) + audio_positions = kwargs.get(AudioEncoderKwargs.audio_positions, []) tokens = kwargs.get(LanguageModelKwargs.tokens) - return self._forward(input_, tokens, position_ids, image_positions, image_sizes) + return self._forward(input_, tokens, position_ids, image_positions, image_sizes, audio_positions) diff --git a/fast_llm/layers/transformer/audio_transformer.py b/fast_llm/layers/transformer/audio_transformer.py index 43ee3f46..f0fb6d17 100644 --- a/fast_llm/layers/transformer/audio_transformer.py +++ b/fast_llm/layers/transformer/audio_transformer.py @@ -1,15 +1,14 @@ import torch from fast_llm.engine.config_utils.tensor_space import TensorDim, TensorSpace -from fast_llm.layers.audio_encoder.config import AudioTransformerDimNames, AudioTransformerKwargs -from fast_llm.layers.transformer.config import TransformerConfig +from fast_llm.layers.transformer.config import AudioTransformerDimNames, AudioTransformerKwargs, TransformerConfig from fast_llm.layers.transformer.transformer import TransformerLayer from fast_llm.tensor import TensorMeta class AudioTransformerLayer(TransformerLayer): """ - A vision transformer layer to encode image patches + A audio transformer layer to encode image patches """ def __init__( diff --git a/fast_llm/layers/transformer/config.py b/fast_llm/layers/transformer/config.py index c9d379ab..45d911a6 100644 --- a/fast_llm/layers/transformer/config.py +++ b/fast_llm/layers/transformer/config.py @@ -71,6 +71,10 @@ class VisionTransformerDimNames(BaseTransformerDimNames, prefix="image_encoder") pass +class AudioTransformerDimNames(BaseTransformerDimNames, prefix="audio_encoder"): + pass + + class BaseTransformerKwargs: _kwargs_attributes = { "rotary_freq_q": "rotary_freq_q", @@ -110,6 +114,10 @@ class VisionTransformerKwargs(BaseTransformerKwargs, prefix="image_encoder"): patch_position_ids = "patch_position_ids" +class AudioTransformerKwargs(BaseTransformerKwargs, 
prefix="audio_encoder"): + pass + + class TransformerLossNames: load_balancing_loss = "load_balancing_loss" router_z_loss = "router_z_loss" @@ -127,6 +135,7 @@ class RotaryEmbeddingType(str, enum.Enum): class TransformerType(str, enum.Enum): lm_decoder = "lm_decoder" image_encoder = "image_encoder" + audio_encoder = "audio_encoder" @config_class() @@ -317,7 +326,7 @@ class TransformerConfig(BaseModelConfig): _abstract = False transformer_type: TransformerType = Field( default=TransformerType.lm_decoder, - desc="Type of the transformer. Choices: lm_decoder, image_encoder.", + desc="Type of the transformer. Choices: lm_decoder, image_encoder, audio_encoder.", hint=FieldHint.architecture, ) normalization: NormalizationConfig = Field( @@ -828,7 +837,7 @@ class VisionTransformerConfig(TransformerConfig): transformer_type: TransformerType = FieldUpdate( default=TransformerType.image_encoder, - desc="Type of the transformer. Choices: lm_decoder, image_encoder.", + desc="Type of the transformer. Choices: lm_decoder, image_encoder, audio_encoder.", hint=FieldHint.architecture, ) causal: bool = FieldUpdate( @@ -857,13 +866,31 @@ class AudioTransformerConfig(TransformerConfig): Configuration for the Audio Transformer model. """ + transformer_type: TransformerType = FieldUpdate( + default=TransformerType.audio_encoder, + desc="Type of the transformer. Choices: lm_decoder, image_encoder, audio_encoder.", + hint=FieldHint.architecture, + ) causal: bool = FieldUpdate( default=False, desc="Use causal attention. Turn this off only for bidirectional attention e.g., in Audio Transformer.", hint=FieldHint.feature, ) + gated: bool = FieldUpdate( + default=False, + desc="MLP gating.", + hint=FieldHint.feature, + ) # rotary: AudioRotaryConfig = FieldUpdate( # default_factory=AudioRotaryConfig, # desc="Configuration for the rotary positional embeddings.", # hint=FieldHint.feature, # ) + + @property + def _transformer_kwargs(self) -> AudioTransformerKwargs: + return AudioTransformerKwargs + + @property + def _transformer_dim_names(self) -> AudioTransformerDimNames: + return AudioTransformerDimNames diff --git a/fast_llm/layers/transformer/preprocessing.py b/fast_llm/layers/transformer/preprocessing.py index af1a53f6..1b436eba 100644 --- a/fast_llm/layers/transformer/preprocessing.py +++ b/fast_llm/layers/transformer/preprocessing.py @@ -280,9 +280,9 @@ def _create_tensors(self, sequence_length: int) -> None: ) def preprocess(self, batch, kwargs: dict[str, typing.Any]) -> None: - self._create_tensors(kwargs[TransformerKwargs.sequence_length]) - sequence_k = kwargs[TransformerKwargs.sequence_k_dim].size - sequence_q = kwargs[TransformerKwargs.sequence_q_dim].size + self._create_tensors(kwargs[self._transformer_kwargs.sequence_length]) + sequence_k = kwargs[self._transformer_kwargs.sequence_k_dim].size + sequence_q = kwargs[self._transformer_kwargs.sequence_q_dim].size kwargs[self._transformer_kwargs.attention_mask] = self._mask[ None, None, sequence_k - sequence_q : sequence_k, None, :sequence_k ] From cd167fc81c77f0d78d8b9f77bf168c7416718773 Mon Sep 17 00:00:00 2001 From: Toby Liang Date: Wed, 28 May 2025 00:24:47 +0000 Subject: [PATCH 58/82] audio conversion --- fast_llm/models/gpt/config.py | 7 + fast_llm/models/gpt/conversion.py | 478 ++++++++++++++++++------------ fast_llm/models/gpt/model.py | 33 +++ fast_llm/models/gpt/trainer.py | 11 +- 4 files changed, 340 insertions(+), 189 deletions(-) diff --git a/fast_llm/models/gpt/config.py b/fast_llm/models/gpt/config.py index 2bfdf892..ae6fc6ad 100644 --- 
a/fast_llm/models/gpt/config.py +++ b/fast_llm/models/gpt/config.py @@ -72,6 +72,12 @@ class WhisperGPTHuggingfaceCheckpointFormat(GPTHuggingfaceCheckpointFormat): name: typing.ClassVar[str] = "whisper" +class AyraAudioModelGPTHuggingfaceCheckpointFormat(GPTHuggingfaceCheckpointFormat): + name: typing.ClassVar[str] = "ayra_audio" + audio_name: typing.ClassVar[str] = "whisper" + text_name: typing.ClassVar[str] = "llama" + + @config_class() class GPTBatchConfig(BatchConfig): sequence_length: int = Field( @@ -156,6 +162,7 @@ class GPTModelConfig(FastLLMModelConfig): LlavaGPTHuggingfaceCheckpointFormat, WhisperGPTHuggingfaceCheckpointFormat, PixtralGPTHuggingfaceCheckpointFormat, + AyraAudioModelGPTHuggingfaceCheckpointFormat, ) @classmethod diff --git a/fast_llm/models/gpt/conversion.py b/fast_llm/models/gpt/conversion.py index da9e897b..a1d91b2a 100644 --- a/fast_llm/models/gpt/conversion.py +++ b/fast_llm/models/gpt/conversion.py @@ -32,6 +32,7 @@ from fast_llm.layers.transformer.config import RotaryEmbeddingType, RoutingType, TransformerConfig from fast_llm.layers.vision_encoder.config import VisionEncoderType from fast_llm.models.gpt.config import ( + AyraAudioModelGPTHuggingfaceCheckpointFormat, GPTBaseModelConfig, GPTModelConfig, LlamaGPTHuggingfaceCheckpointFormat, @@ -566,240 +567,212 @@ def _get_mlp_converters(self, fast_llm_prefix: str, hf_prefix: str) -> list[Weig ] -class WhisperHuggingfaceCheckpointHandler(MistralHuggingfaceCheckpointHandler): +class WhisperHuggingfaceCheckpointHandler(WeightAndBiasConverterMixin, HuggingfaceStateDictCheckpointHandler): format: typing.ClassVar[type[CheckpointFormat]] = WhisperGPTHuggingfaceCheckpointFormat + _model_class: typing.ClassVar[FastLLMModelConfig] = GPTModelConfig @classmethod def _create_config_converters(cls) -> list[ParamConverter]: - # lm_converters = super()._create_config_converters() - lm_converters = super()._create_config_converters() - for idx, converter in enumerate(lm_converters): - if converter.export_names == (("model_type",),): - continue - elif converter.export_names == (("architectures",),): - ignore_index = idx - if converter.export_names: - converter.export_names = (("text_config", *converter.export_names[0]), *converter.export_names[1:]) - - return ( - lm_converters[:ignore_index] - + lm_converters[ignore_index + 1 :] - + [ - ConstantImportParamConverter( - fast_llm_names=(("audio_encoder", "type"),), fast_llm_value=AudioEncoderType.whisper - ), - ConstantExportParamConverter( - export_names=(("architectures",),), export_value=["WhisperForConditionalGeneration"] - ), - # Audio Adapter - # RenameParamConverter( - # fast_llm_names=(("vision_encoder", "adapter_size"),), - # export_names=(("text_config", "hidden_size"),), - # ), - # ConstantImportParamConverter( - # fast_llm_names=(("vision_encoder", "patch_norm", "type"),), - # fast_llm_value=NormalizationType.rms_norm, - # ), - # ConstantImportParamConverter( - # fast_llm_names=(("vision_encoder", "transformer", "normalization", "type"),), - # fast_llm_value=NormalizationType.rms_norm, - # ), - # Audio Transformer - RenameParamConverter( - fast_llm_names=(("audio_encoder", "transformer", "num_layers"),), - export_names=(("encoder_layers",),), + return super()._create_config_converters() + [ + # set default layernorm + ConstantImportParamConverter( + fast_llm_names=(("transformer", "normalization", "type"),), fast_llm_value=NormalizationType.layer_norm + ), + ConstantExportParamConverter( + export_names=(("architectures",),), 
export_value=["WhisperForConditionalGeneration"] + ), + ConstantImportParamConverter(fast_llm_names=(("type",),), fast_llm_value=AudioEncoderType.whisper), + # make transformer noncasual + ConstantImportParamConverter(fast_llm_names=(("transformer", "causal"),), fast_llm_value=False), + RenameParamConverter( + fast_llm_names=( + ( + "transformer", + "num_layers", + ), ), - RenameParamConverter( - fast_llm_names=(("audio_encoder", "transformer", "hidden_size"),), - export_names=(("d_model",),), + export_names=(("num_hidden_layers",),), + ), + RenameParamConverter( + fast_llm_names=( + ( + "transformer", + "hidden_size", + ), ), - RenameParamConverter( - fast_llm_names=(("audio_encoder", "transformer", "num_attention_heads"),), - export_names=(("encoder_attention_heads",),), + export_names=(("d_model",),), + ), + RenameParamConverter( + fast_llm_names=( + ( + "transformer", + "num_attention_heads", + ), ), - # RenameParamConverter( - # fast_llm_names=(("audio_encoder", "transformer", "head_groups"),), - # export_names=( - # ( - # "encoder_attention_heads", - # ), - # ), - # ), - RenameParamConverter( - fast_llm_names=(("audio_encoder", "transformer", "ffn_hidden_size"),), - export_names=(("encoder_ffn_dim",),), + export_names=(("encoder_attention_heads",),), + ), + RenameParamConverter( + fast_llm_names=( + ( + "transformer", + "head_groups", + ), ), - MappedConfigParamConverter( - fast_llm_names=(("audio_encoder", "transformer", "activation_type"),), - export_names=(("activation_function",),), - fast_llm_value=ActivationType.from_hf_name, - export_value=lambda activation_type: activation_type.hf_name, + export_names=(("encoder_attention_heads",),), + ), + RenameParamConverter( + fast_llm_names=( + ( + "transformer", + "ffn_hidden_size", + ), ), - # ConstantImportParamConverter( - # fast_llm_names=(("vision_encoder", "transformer", "gated"),), fast_llm_value=True - # ), - # MappedConfigParamConverter( - # fast_llm_names=(("vision_encoder", "adapter_activation_type"),), - # export_names=(("projector_hidden_act",),), - # fast_llm_value=ActivationType.from_hf_name, - # export_value=lambda activation_type: activation_type.hf_name, - # ), - # ConstantImportParamConverter( - # fast_llm_names=(("vision_encoder", "transformer", "add_linear_biases"),), fast_llm_value=False - # ), - # RenameParamConverter( - # fast_llm_names=(("vision_encoder", "transformer", "rotary", "theta"),), - # export_names=(("vision_config", "rope_theta"),), - # ), - ] - ) + export_names=(("encoder_ffn_dim",),), + ), + MappedConfigParamConverter( + fast_llm_names=(("transformer", "activation_type"),), + export_names=(("activation_function",),), + fast_llm_value=ActivationType.from_hf_name, + export_value=lambda activation_type: activation_type.hf_name, + ), + ConstantImportParamConverter( + fast_llm_names=(("transformer", "rotary", "type"),), fast_llm_value=RotaryEmbeddingType.none + ), + ConstantImportParamConverter(fast_llm_names=(("transformer", "gated"),), fast_llm_value=False), + ConstantImportParamConverter(fast_llm_names=(("transformer", "add_linear_biases"),), fast_llm_value=True), + RenameParamConverter( + fast_llm_names=(("num_mel_bins",),), + export_names=(("num_mel_bins",),), + ), + RenameParamConverter( + fast_llm_names=(("aud_downsampling_k",),), + export_names=(("encoder_projector_ds_rate",),), + ), + ] - def _create_vision_transformer_layer_converters( - self, - i: int, - ignore_export: bool = False, - hf_base_prefix: str = "", - fast_llm_offset: int = 1, - type: str | None = None, + def 
_get_transformer_mlp_converters(self, fast_llm_prefix: str, hf_prefix: str) -> list[WeightConverter]: + return [ + WeightConverter(f"{fast_llm_prefix}.mlp.layer_1.weight", f"{hf_prefix}fc1.weight"), + WeightConverter(f"{fast_llm_prefix}.mlp.layer_1.bias", f"{hf_prefix}fc1.bias"), + WeightConverter(f"{fast_llm_prefix}.mlp.layer_2.weight", f"{hf_prefix}fc2.weight"), + WeightConverter(f"{fast_llm_prefix}.mlp.layer_2.bias", f"{hf_prefix}fc2.bias"), + ] + + def _create_audio_transformer_layer_converters( + self, transformer_layer_index: int, fast_llm_offset: int = 1, hf_base_prefix: str = "" ) -> list[WeightConverter]: - if type is not None: - if type == "vision": - transformer_config: TransformerConfig = self._model.config.base_model.vision_encoder.transformer - else: - transformer_config: TransformerConfig = self._model.config.base_model.transformer - norm_bias: bool = self._model.config.base_model.transformer.normalization.type == NormalizationType.layer_norm - converters = [] - names_bias_cls = [ + # Vision transformer layer + transformer_config = self._model.config.base_model.audio_encoder.transformer + norm_bias: bool = transformer_config.normalization.type == NormalizationType.layer_norm + name_bias_cls = [ # Self-attn ( - f"layers.{i+fast_llm_offset}.self_attn.query", - f"vision_tower.transformer.layers.{i}.attention.q_proj", + f"layers.{fast_llm_offset + transformer_layer_index}.self_attn.query", + f"{hf_base_prefix}layers.{transformer_layer_index}.self_attn.q_proj", transformer_config.add_attn_qkv_bias, QueryWeightConverter, ), ( - f"layers.{i+fast_llm_offset}.self_attn.key_value", + f"layers.{fast_llm_offset + transformer_layer_index}.self_attn.key_value", ( - f"vision_tower.transformer.layers.{i}.attention.k_proj", - f"vision_tower.transformer.layers.{i}.attention.v_proj", + f"{hf_base_prefix}layers.{transformer_layer_index}.self_attn.k_proj", + f"{hf_base_prefix}layers.{transformer_layer_index}.self_attn.v_proj", ), - transformer_config.add_attn_qkv_bias, + transformer_config.add_attn_qkv_bias, # TODO Toby: add permanent fix for key bias KeyValueWeightConverter, ), ( - f"layers.{i+fast_llm_offset}.self_attn.dense", - f"vision_tower.transformer.layers.{i}.attention.o_proj", + f"layers.{fast_llm_offset + transformer_layer_index}.self_attn.dense", + f"{hf_base_prefix}layers.{transformer_layer_index}.self_attn.out_proj", transformer_config.add_attn_dense_bias, WeightConverter, ), # Norm ( - f"layers.{i+fast_llm_offset}.norm_1", - f"vision_tower.transformer.layers.{i}.attention_norm", + f"layers.{fast_llm_offset + transformer_layer_index}.norm_1", + f"{hf_base_prefix}layers.{transformer_layer_index}.self_attn_layer_norm", norm_bias, WeightConverter, ), ( - f"layers.{i+fast_llm_offset}.norm_2", - f"vision_tower.transformer.layers.{i}.ffn_norm", + f"layers.{fast_llm_offset + transformer_layer_index}.norm_2", + f"{hf_base_prefix}layers.{transformer_layer_index}.final_layer_norm", norm_bias, WeightConverter, ), ] - for fast_llm_prefix, hf_prefix, use_bias, cls in names_bias_cls: + converters = [] + for fast_llm_prefix, hf_prefix, use_bias, cls in name_bias_cls: converters += self._get_weight_and_bias_converters( fast_llm_prefix, - () if ignore_export else hf_prefix, + hf_prefix, use_bias, - cls=IgnoreExportWeightConverter if ignore_export else cls, + cls, ) - # MLP - if ignore_export: - converters += self._get_weight_and_bias_converters( - f"layers.{i+fast_llm_offset}.mlp.layer_1", - (), - transformer_config.add_mlp_bias, - cls=IgnoreExportWeightConverter, - ) - converters += 
self._get_weight_and_bias_converters( - f"layers.{i+fast_llm_offset}.mlp.layer_2", - (), - transformer_config.add_mlp_bias, - cls=IgnoreExportWeightConverter, - ) - converters += [IgnoreExportWeightConverter(f"layers.{i+fast_llm_offset}.mlp.router.weight", ())] - else: - converters += self._get_vision_transformer_mlp_converters( - f"layers.{i+fast_llm_offset}", f"vision_tower.transformer.layers.{i}" - ) + converters += self._get_transformer_mlp_converters( + f"layers.{fast_llm_offset + transformer_layer_index}", + f"{hf_base_prefix}layers.{transformer_layer_index}.", + ) return converters - def _get_vision_transformer_mlp_converters(self, fast_llm_prefix: str, hf_prefix: str) -> list[WeightConverter]: - return [ - SplitWeightConverter( - f"{fast_llm_prefix}.mlp.layer_1.weight", - (f"{hf_prefix}.feed_forward.gate_proj.weight", f"{hf_prefix}.feed_forward.up_proj.weight"), - ), - MLPLayer2Converter( - f"{fast_llm_prefix}.mlp.layer_2.weight", - f"{hf_prefix}.feed_forward.down_proj.weight", - self._model.config.base_model, - ), + def _create_weight_converters(self, offset: int = 0, hf_base_prefix: str = "") -> list[WeightConverter]: + converters = [] + + # audio encoder conv + converters += [ + WeightConverter(f"layers.{offset}.conv1_weight", f"{hf_base_prefix}conv1.weight"), + WeightConverter(f"layers.{offset}.conv2_weight", f"{hf_base_prefix}conv2.weight"), ] - def _create_vision_transformer_converters(self) -> list[WeightConverter]: - num_layers = self._model.config.base_model.vision_encoder.transformer.num_layers - vision_transformer_converters = [] - for layer in range(num_layers): - # TODO Soham: check if args are correct - vision_transformer_converters.extend( - self._create_vision_transformer_layer_converters( - layer, - ignore_export=False, - hf_base_prefix="vision_tower.transformer.layers.", - fast_llm_offset=1, - type="vision", - ) - ) + if self._model.config.base_model.audio_encoder.conv_bias: + converters += [ + WeightConverter(f"layers.{offset}.conv1_bias", f"{hf_base_prefix}conv1.bias"), + WeightConverter(f"layers.{offset}.conv2_bias", f"{hf_base_prefix}conv2.bias"), + ] - return vision_transformer_converters + # position embedding + converters.append( + WeightConverter(f"layers.{offset}.positional_embeddings", f"{hf_base_prefix}embed_positions.weight") + ) - def _create_vision_encoder_weight_converters(self) -> list[WeightConverter]: - patch_conv_converters = [WeightConverter("layers.0.weight", "vision_tower.patch_conv.weight")] - if self._model.config.base_model.vision_encoder.conv_bias: - patch_conv_converters.append(WeightConverter("layers.0.bias", "vision_tower.patch_conv.bias")) - layernorm_converters = [ - WeightConverter("layers.0.norm.weight", "vision_tower.ln_pre.weight"), - ] - if self._model.config.base_model.vision_encoder.patch_norm.type == NormalizationType.layer_norm: - layernorm_converters.append(WeightConverter("layers.0.norm.bias", "vision_tower.ln_pre.bias")) - - vision_transformer_converters = self._create_vision_transformer_converters() - offset = self._model.config.base_model.vision_encoder.transformer.num_layers + 1 - adapter_converters = [ - WeightConverter(f"layers.{offset}.layer_1.weight", "multi_modal_projector.linear_1.weight"), - WeightConverter(f"layers.{offset}.layer_1.bias", "multi_modal_projector.linear_1.bias"), - # TODO Soham: add bias based on config - WeightConverter(f"layers.{offset}.layer_2.weight", "multi_modal_projector.linear_2.weight"), - WeightConverter(f"layers.{offset}.layer_2.bias", "multi_modal_projector.linear_2.bias"), - ] + # 
transformer encoder layers + num_layers = self._model.config.base_model.audio_encoder.transformer.num_layers + for i in range(num_layers): + converters += self._create_audio_transformer_layer_converters(i, offset + 1, hf_base_prefix) - return patch_conv_converters + layernorm_converters + vision_transformer_converters + adapter_converters + offset = offset + num_layers + 1 - def _create_weight_converters(self) -> list[WeightConverter]: - vision_encoder_converter = self._create_vision_encoder_weight_converters() - offset = self._model.config.base_model.vision_encoder.transformer.num_layers + 3 - # Embeddings - lm_converters = [ - WeightConverter(f"layers.{offset - 1}.word_embeddings_weight", f"language_model.model.embed_tokens.weight") - ] - for i in range(self._model.config.base_model.transformer.num_layers): - lm_converters += self._create_transformer_layer_converters( - fast_llm_layer_name=f"layers.{i + offset}", hf_layer_name=f"language_model.model.layers.{i}" + # add final layernorm + if self._model.config.base_model.audio_encoder.transformer.normalization.type == NormalizationType.layer_norm: + converters += [ + WeightConverter(f"layers.{offset}.norm_1.weight", f"{hf_base_prefix}layer_norm.weight"), + WeightConverter(f"layers.{offset}.norm_2.weight", "encoder_projector.layer_norm.weight"), + WeightConverter(f"layers.{offset}.norm_1.bias", f"{hf_base_prefix}layer_norm.bias"), + WeightConverter(f"layers.{offset}.norm_2.bias", "encoder_projector.layer_norm.bias"), + ] + + # multimodal projector + converters.extend( + [ + WeightConverter(f"layers.{offset}.layer_1.weight", "encoder_projector.linear1.weight"), + WeightConverter(f"layers.{offset}.layer_2.weight", "encoder_projector.linear2.weight"), + ] + ) + if self._model.config.base_model.audio_encoder.adapter_bias: + converters.extend( + [ + WeightConverter(f"layers.{offset}.layer_1.bias", "encoder_projector.linear1.bias"), + WeightConverter(f"layers.{offset}.layer_2.bias", "encoder_projector.linear2.bias"), + ] ) - lm_converters += self._create_lm_head_converters(hf_base_prefix="language_model.", fast_llm_offset=offset) - return vision_encoder_converter + lm_converters + + return converters + + @property + def num_layers(self) -> int: + # +2 for projector and conv layers + return self._model.config.base_model.audio_encoder.transformer.num_layers + 2 class PixtralHuggingfaceCheckpointHandler(WeightAndBiasConverterMixin, HuggingfaceStateDictCheckpointHandler): @@ -1007,6 +980,139 @@ def num_layers(self) -> int: return self._model.config.base_model.vision_encoder.transformer.num_layers + 2 +class AyraAudioModelHuggingfaceCheckpointHandler(HuggingfaceStateDictCheckpointHandler): + format: typing.ClassVar[type[CheckpointFormat]] = AyraAudioModelGPTHuggingfaceCheckpointFormat + _model_class: typing.ClassVar[FastLLMModelConfig] = GPTModelConfig + + @classmethod + def _load_metadata(cls, config: CheckpointLoadMetadataConfig) -> CheckpointMetadata: + cfg_dict = cls._load_config(config.path) + kwargs = {} + if "text_config" in cfg_dict: + text_kwargs = cls._import_config(cfg_dict["text_config"]) + kwargs.update(text_kwargs) + if "audio_config" in cfg_dict: + audio_kwargs = cls._import_config(cfg_dict["audio_config"]) + audio_kwargs = {tuple(["audio_encoder"] + list(key)): value for key, value in audio_kwargs.items()} + kwargs.update(audio_kwargs) + kwargs.update( + cls._import_config( + {key: value for key, value in cfg_dict.items() if key not in ("text_config", "audio_config")} + ) + ) + imported_model_config = 
cls._model_class.get_base_model_config_class().from_dict({}, kwargs) + return CheckpointMetadata( + fast_llm_version=__version__, + model=cls._model_class, + format=config.format, + config=cls._model_class.from_dict({"base_model": imported_model_config.to_dict()}), + shards=["weights"], + ) + + @classmethod + def _create_config_converters(cls) -> list[ParamConverter]: + return super()._create_config_converters() + [ + ConstantExportParamConverter(export_names=(("architectures",),), export_value=["AyraAudioModel"]), + # projector + MappedConfigParamConverter( + fast_llm_names=(("audio_encoder", "adapter_activation_type"),), + export_names=(("activation_function",),), + fast_llm_value=ActivationType.from_hf_name, + export_value=lambda activation_type: activation_type.hf_name, + ), + RenameParamConverter( + fast_llm_names=(("audio_encoder", "adapter_size"),), + export_names=(("adapter_size",),), + ), + ] + + @classmethod + def _import_config(cls, config: dict[str, typing.Any]) -> GPTBaseModelConfig: + handler_cls = AutoGPTHuggingfaceCheckpointHandler.get_handler_class(config["model_type"]) + kwargs = {} + for converter in handler_cls._create_config_converters(): + try: + values = () + for export_name in converter.export_names: + try: + value = get_nested_dict_value(config, export_name) + except KeyError: + value = MISSING + values = values + (value,) + values = converter.import_params(values) + for fast_llm_name, value in zip(converter.fast_llm_names, values, strict=True): + if value is MISSING: + raise ValueError(f"Missing converted value for fast-llm parameter {fast_llm_name}") + if fast_llm_name in kwargs: + raise ValueError(f"Duplicate converted value for fast-llm parameter {fast_llm_name}") + kwargs[fast_llm_name] = value + except Exception as e: + raise RuntimeError(f"Config conversion failed for converter {converter}", *e.args) + + return kwargs + + @classmethod + def _export_config(cls, config: BaseModelConfig) -> dict[str, typing.Any]: + # TODO Toby: implement for audio + exported_config = {} + vision_handler_cls = AutoGPTHuggingfaceCheckpointHandler.get_handler_class(cls.format.vision_name) + text_handler_cls = AutoGPTHuggingfaceCheckpointHandler.get_handler_class(cls.format.text_name) + for converter in vision_handler_cls._create_config_converters(): + try: + values = converter.export_params( + tuple( + cls._get_fast_llm_attribute(config, ("vision_encoder",) + fast_llm_name) + for fast_llm_name in converter.fast_llm_names + ) + ) + for export_name, value in zip(converter.export_names, values, strict=True): + if value is not MISSING: + set_nested_dict_value(exported_config, ("vision_config",) + export_name, value) + except Exception as e: + raise RuntimeError(f"Config conversion failed for converter {converter}", *e.args) + + for converter in text_handler_cls._create_config_converters(): + try: + values = converter.export_params( + tuple( + cls._get_fast_llm_attribute(config, fast_llm_name) + for fast_llm_name in converter.fast_llm_names + ) + ) + for export_name, value in zip(converter.export_names, values, strict=True): + if value is not MISSING: + set_nested_dict_value(exported_config, ("text_config",) + export_name, value) + except Exception as e: + raise RuntimeError(f"Config conversion failed for converter {converter}", *e.args) + + for converter in cls._create_config_converters(): + try: + values = converter.export_params( + tuple( + cls._get_fast_llm_attribute(config, fast_llm_name) + for fast_llm_name in converter.fast_llm_names + ) + ) + for export_name, value in 
zip(converter.export_names, values, strict=True): + if value is not MISSING: + set_nested_dict_value(exported_config, export_name, value) + except Exception as e: + raise RuntimeError(f"Config conversion failed for converter {converter}", *e.args) + + return exported_config + + def _create_weight_converters(self): + audio_handler_cls = AutoGPTHuggingfaceCheckpointHandler.get_handler_class(self.format.audio_name) + audio_handler = audio_handler_cls(self._model) # TODO Toby: are we calling this twice? + converters = audio_handler._create_weight_converters(hf_base_prefix="encoder.", offset=0) + text_handler_cls = AutoGPTHuggingfaceCheckpointHandler.get_handler_class(self.format.text_name) + text_handler = text_handler_cls(self._model) + converters.extend( + text_handler._create_weight_converters(hf_base_prefix="llm.", offset=audio_handler.num_layers) + ) + return converters + + class LlavaHuggingfaceCheckpointHandler(HuggingfaceStateDictCheckpointHandler): format: typing.ClassVar[type[CheckpointFormat]] = LlavaGPTHuggingfaceCheckpointFormat _model_class: typing.ClassVar[FastLLMModelConfig] = GPTModelConfig @@ -1275,5 +1381,5 @@ class AutoGPTHuggingfaceCheckpointHandler( LlavaGPTHuggingfaceCheckpointFormat.name: LlavaHuggingfaceCheckpointHandler, WhisperGPTHuggingfaceCheckpointFormat.name: WhisperHuggingfaceCheckpointHandler, PixtralGPTHuggingfaceCheckpointFormat.name: PixtralHuggingfaceCheckpointHandler, - # MultiModalGPTHuggingfaceCheckpointFormat.name: MultiModalHuggingfaceCheckpointHandler + AyraAudioModelGPTHuggingfaceCheckpointFormat.name: AyraAudioModelHuggingfaceCheckpointHandler, } diff --git a/fast_llm/models/gpt/model.py b/fast_llm/models/gpt/model.py index 01cf0ec3..330bd832 100644 --- a/fast_llm/models/gpt/model.py +++ b/fast_llm/models/gpt/model.py @@ -21,6 +21,8 @@ from fast_llm.layers.multi_modal.embedding import MultiModalEmbedding from fast_llm.layers.transformer.audio_transformer import AudioTransformerLayer from fast_llm.layers.transformer.config import ( + AudioTransformerDimNames, + AudioTransformerKwargs, RoutingType, TransformerDimNames, TransformerKwargs, @@ -217,6 +219,18 @@ def preprocess_meta( else: vision_kwargs = {} + if self._config.audio_encoder.enabled: + audio_kwargs = { + AudioEncoderKwargs.kv_channels: self._tensor_space.get_tensor_dim( + AudioTransformerDimNames.kv_channels + ).size, + AudioEncoderKwargs.out_channels: self._tensor_space.get_tensor_dim( + AudioEncoderKwargs.out_channels + ).size, + } + else: + audio_kwargs = {} + batch_data = self._tensor_space.distributed_config.get_distributed_dim(DistributedDimNames.batch_data) batch_dim = TensorDim(TransformerDimNames.batch, micro_batch_size * batch_data.size, batch_data) @@ -272,6 +286,22 @@ def preprocess_meta( } ) + if self._config.audio_encoder.enabled: + audio_hidden_dim = self._tensor_space.get_tensor_dim(AudioTransformerDimNames.hidden) + audio_hidden_dims = ( + (hidden_sequence_q_dim, batch_dim, audio_hidden_dim) + if sequence_first + else (batch_dim, hidden_sequence_q_dim, audio_hidden_dim) + ) + audio_kwargs.update( + { + AudioTransformerKwargs.hidden_dims: audio_hidden_dims, + AudioTransformerKwargs.sequence_length: 1500, # TODO: Toby Parameterize + AudioTransformerKwargs.sequence_k_dim: 1500, + AudioTransformerKwargs.sequence_q_dim: 1500, + } + ) + common_kwargs = { LanguageModelKwargs.phase: phase, TransformerKwargs.sequence_first: sequence_first, @@ -281,6 +311,7 @@ def preprocess_meta( TransformerKwargs.micro_batch_size: micro_batch_size, } common_kwargs.update(vision_kwargs) + 
common_kwargs.update(audio_kwargs) sequence_k_pasts = range( sequence_q_dim.size * self._tensor_space.distributed_config.sequence_data_rank, @@ -482,6 +513,8 @@ def transformer_layers(self) -> list[TransformerLayer]: def embedding_layer_index(self) -> int: if self._config.vision_encoder.enabled: return self._config.vision_encoder.transformer.num_layers + 2 + elif self._config.audio_encoder.enabled: + return self._config.audio_encoder.transformer.num_layers + 2 else: return 0 diff --git a/fast_llm/models/gpt/trainer.py b/fast_llm/models/gpt/trainer.py index 38264d4a..3000e9be 100644 --- a/fast_llm/models/gpt/trainer.py +++ b/fast_llm/models/gpt/trainer.py @@ -30,9 +30,6 @@ def _get_sampling_parameters( "use_loss_masking_spans": self._config.batch.use_loss_masking_spans, "cross_document_attention": self._config.batch.cross_document_attention, "extra_tokens": self._config.model.base_model.prediction_heads, - "aud_downsampling_k": self._config.model.base_model.audio_encoder.aud_downsampling_k, - "aud_padding_duration": self._config.batch.aud_padding_duration, - "aud_sampling_rate": self._config.model.base_model.audio_encoder.aud_sampling_rate, } ) if self._config.model.base_model.vision_encoder.enabled: @@ -42,6 +39,14 @@ def _get_sampling_parameters( "image_size": self._config.batch.image_size, } ) + if self._config.model.base_model.audio_encoder.enabled: + parameters.update( + { + "aud_downsampling_k": self._config.model.base_model.audio_encoder.aud_downsampling_k, + "aud_padding_duration": self._config.batch.aud_padding_duration, + "aud_sampling_rate": self._config.model.base_model.audio_encoder.aud_sampling_rate, + } + ) return parameters if _return_dict else GPTSamplingParameters(**parameters) def get_tflops(self, phase: PhaseType, elapsed_time_per_iteration) -> tuple[int, int]: From 204b3e9f27e6d12168f72a4ae045fc7ab9dbe475 Mon Sep 17 00:00:00 2001 From: root Date: Wed, 28 May 2025 06:04:47 +0000 Subject: [PATCH 59/82] fix image embeddings offset --- fast_llm/data/dataset/gpt/config.py | 1 + fast_llm/data/dataset/gpt/memmap.py | 2 + fast_llm/data/dataset/gpt/sampled.py | 68 +++++++------- fast_llm/layers/multi_modal/embedding.py | 89 ++++++++----------- fast_llm/layers/vision_encoder/config.py | 7 +- .../layers/vision_encoder/preprocessing.py | 31 ++++++- fast_llm/models/gpt/trainer.py | 1 + 7 files changed, 109 insertions(+), 90 deletions(-) diff --git a/fast_llm/data/dataset/gpt/config.py b/fast_llm/data/dataset/gpt/config.py index 004a062c..bb3ff717 100644 --- a/fast_llm/data/dataset/gpt/config.py +++ b/fast_llm/data/dataset/gpt/config.py @@ -77,6 +77,7 @@ class GPTSamplingParameters(SamplingParameters): patch_size: int | None = None image_size: int | None = None image_break_token: int | None = None + image_end_token: int | None = None # How many extra tokens to add to the sequence length. # This is used to provide labels even for the last tokens in the sequence. 
extra_tokens: int = 1 diff --git a/fast_llm/data/dataset/gpt/memmap.py b/fast_llm/data/dataset/gpt/memmap.py index a202d2e1..d83064b1 100644 --- a/fast_llm/data/dataset/gpt/memmap.py +++ b/fast_llm/data/dataset/gpt/memmap.py @@ -202,6 +202,7 @@ def get( patch_size: int | None = None, image_size: int | None = None, image_break: bool = False, + image_end: bool = False, ) -> GPTSample: token_ids = np.frombuffer( self._bin_buffer, @@ -244,6 +245,7 @@ def get( get_resize_dims(*self._image_lengths[idx][image_idx], image_size, image_size, patch_size), patch_size, image_break=image_break, + image_end=image_end, ) additional_tokens += image_tokens image_idx += 1 diff --git a/fast_llm/data/dataset/gpt/sampled.py b/fast_llm/data/dataset/gpt/sampled.py index 2c068742..6c8e9fe7 100644 --- a/fast_llm/data/dataset/gpt/sampled.py +++ b/fast_llm/data/dataset/gpt/sampled.py @@ -15,7 +15,7 @@ from fast_llm.engine.config_utils.data_type import DataType, get_unsigned_integer_type from fast_llm.engine.config_utils.run import log_main_rank from fast_llm.layers.vision_encoder.preprocessing import get_num_image_tokens, get_resize_dims -from fast_llm.utils import Assert +from fast_llm.utils import Assert, div try: from fast_llm.csrc.data import build_padded_token_cumsum, build_sample_idx # noqa @@ -147,6 +147,7 @@ def _sample(self) -> None: ), self._parameters.patch_size, image_break=self._parameters.image_break_token is not None, + image_end=self._parameters.image_end_token is not None, ) for size in sizes ) @@ -213,6 +214,7 @@ def _sample(self) -> None: "patch_size": self._parameters.patch_size, "truncate_documents": self._truncate_documents, "image_break_token": self._parameters.image_break_token, + "image_end_token": self._parameters.image_end_token, "config": self._config.to_dict(), } if self._truncate_documents: @@ -424,18 +426,23 @@ def __getitem__(self, index: int) -> typing.Any: text_size, image_lengths = self._indexed_dataset.get_document_size(document_index) + resized_image_lengths = [ + get_resize_dims( + *image_length, + self._parameters.image_size, + self._parameters.image_size, + self._parameters.patch_size, + ) + for image_length in image_lengths + ] image_sizes = [ get_num_image_tokens( - *get_resize_dims( - *image_length, - self._parameters.image_size, - self._parameters.image_size, - self._parameters.patch_size, - ), + *image_length, self._parameters.patch_size, image_break=self._parameters.image_break_token is not None, + image_end=self._parameters.image_end_token is not None, ) - for image_length in image_lengths + for image_length in resized_image_lengths ] image_tokens = sum(image_sizes) document_size = text_size + image_tokens @@ -468,6 +475,8 @@ def __getitem__(self, index: int) -> typing.Any: offset=token_start_index_in_document, length=token_end_index_in_document - token_start_index_in_document, use_loss_masking_spans=self._parameters.use_loss_masking_spans, + # image_break=self._parameters.image_break_token is not None, + # image_end=self._parameters.image_end_token is not None, ) start_pos = 0 if sample.image_positions: @@ -477,41 +486,30 @@ def __getitem__(self, index: int) -> typing.Any: token_ids.append(sample.token_ids[start_pos:im_position]) text_tokens_added += len(token_ids[-1]) image_positions.append(text_tokens_added + image_tokens_added) - # token_ids.append(np.full((image_sizes[idx],), -100, dtype=np.int64)) if self._parameters.image_break_token is not None: - # Calculate patch dimensions for the image - height, width = get_resize_dims( - image_lengths[idx][0], - 
image_lengths[idx][1], - self._parameters.image_size, - self._parameters.image_size, - self._parameters.patch_size, - ) - num_patches_h = math.ceil(height / self._parameters.patch_size) - num_patches_w = math.ceil(width / self._parameters.patch_size) - - # Calculate the token count considering break tokens - tokens_per_row = num_patches_w - resized_image_tokens = num_patches_h * tokens_per_row + ( - num_patches_h - 1 - ) # Add break tokens after each row except last + height, width = resized_image_lengths[idx] + num_patches_h = div(height, self._parameters.patch_size) + num_patches_w = div(width, self._parameters.patch_size) # Create image token placeholder array - image_token_array = np.full((resized_image_tokens,), -100, dtype=np.int64) + image_token_array = np.full((image_sizes[idx],), -100, dtype=np.int64) # Add break tokens after each row except the last row for row in range(num_patches_h - 1): - position = (row + 1) * tokens_per_row + row + position = (row + 1) * num_patches_w + row image_token_array[position] = self._parameters.image_break_token - - token_ids.append(image_token_array) - - # Update image_tokens_added to reflect actual number of tokens added - image_tokens_added += resized_image_tokens + # add end token if specified, else break token + last_row_position = num_patches_h * num_patches_w + num_patches_h - 1 + if self._parameters.image_end_token is not None: + image_token_array[last_row_position] = self._parameters.image_end_token + else: + image_token_array[last_row_position] = self._parameters.image_break_token else: - # Just add placeholders for all image tokens without break tokens - token_ids.append(np.full((image_sizes[idx],), -100, dtype=np.int64)) - image_tokens_added += image_sizes[idx] + image_token_array = np.full((image_sizes[idx],), -100, dtype=np.int64) + if self._parameters.image_end_token is not None: + image_token_array[-1] = self._parameters.image_end_token + token_ids.append(image_token_array) + image_tokens_added += image_sizes[idx] start_pos = im_position token_ids.append(sample.token_ids[start_pos:]) text_tokens_added += len(token_ids[-1]) diff --git a/fast_llm/layers/multi_modal/embedding.py b/fast_llm/layers/multi_modal/embedding.py index f40df3f0..4dd4a46e 100644 --- a/fast_llm/layers/multi_modal/embedding.py +++ b/fast_llm/layers/multi_modal/embedding.py @@ -9,7 +9,7 @@ from fast_llm.layers.language_model.embedding import LanguageModelEmbedding from fast_llm.layers.transformer.config import TransformerKwargs from fast_llm.layers.vision_encoder.config import VisionEncoderKwargs -from fast_llm.layers.vision_encoder.preprocessing import get_num_image_tokens +from fast_llm.layers.vision_encoder.preprocessing import get_num_patches from fast_llm.tensor import TensorMeta from fast_llm.utils import Assert, div @@ -60,37 +60,30 @@ def _forward( for sample_idx, (positions, sizes) in enumerate(zip(image_positions, image_sizes)): image_embedding_offset = 0 for position, size in zip(positions, sizes): - num_image_tokens = get_num_image_tokens( - *size, - self._config.vision_encoder.patch_size, - image_break=self._config.vision_encoder.image_break_token is not None, - ) - # Calculate the patch dimensions - patch_width = div(size[0], self._config.vision_encoder.patch_size) - patch_height = div(size[1], self._config.vision_encoder.patch_size) - - # Process row by row for both sequence parallel and non-parallel cases - for row in range(patch_height): - # Calculate source and destination starting positions - row_start_src = image_embedding_offset + row * 
patch_width - row_start_dst = position + row * (patch_width + 1) - - # Always use full patch_width - tokens_in_row = patch_width - - if self._sequence_parallel: - # Copy with dimensions swapped for sequence parallel case - embeddings[row_start_dst : row_start_dst + tokens_in_row, sample_idx] = input_[ - row_start_src : row_start_src + tokens_in_row, sample_idx - ] - else: - # Copy with normal dimension ordering - embeddings[sample_idx, row_start_dst : row_start_dst + tokens_in_row] = input_[ - sample_idx, row_start_src : row_start_src + tokens_in_row - ] - - # Move to the next image in the input tensor - image_embedding_offset += num_image_tokens + num_patches = get_num_patches(*size, self._config.vision_encoder.patch_size) + if self._config.vision_encoder.image_break_token is not None: + patch_width = div(size[0], self._config.vision_encoder.patch_size) + patch_height = div(size[1], self._config.vision_encoder.patch_size) + + for row in range(patch_height): + row_start_src = image_embedding_offset + row * patch_width + row_start_dst = position + row * (patch_width + 1) + + if self._sequence_parallel: + # Copy with dimensions swapped for sequence parallel case + embeddings[row_start_dst : row_start_dst + patch_width, sample_idx] = input_[ + row_start_src : row_start_src + patch_width, sample_idx + ] + else: + # Copy with normal dimension ordering + embeddings[sample_idx, row_start_dst : row_start_dst + patch_width] = input_[ + sample_idx, row_start_src : row_start_src + patch_width + ] + else: + embeddings[sample_idx, position : position + num_patches] = input_[ + sample_idx, image_embedding_offset : image_embedding_offset + num_patches + ] + image_embedding_offset += num_patches if self._sequence_parallel: embeddings = split(embeddings, group=group, dim=0) else: @@ -109,28 +102,24 @@ def _forward( for sample_idx, (positions, sizes) in enumerate(zip(image_positions, image_sizes)): image_embedding_offset = 0 for position, size in zip(positions, sizes): - num_image_tokens = get_num_image_tokens( - *size, - self._config.vision_encoder.patch_size, - image_break=self._config.vision_encoder.image_break_token is not None, - ) - # Calculate the patch dimensions - patch_width = div(size[0], self._config.vision_encoder.patch_size) - patch_height = div(size[1], self._config.vision_encoder.patch_size) + num_patches = get_num_patches(*size, self._config.vision_encoder.patch_size) + if self._config.vision_encoder.image_break_token is not None: + patch_width = div(size[0], self._config.vision_encoder.patch_size) + patch_height = div(size[1], self._config.vision_encoder.patch_size) - # Process row by row - for row in range(patch_height): - # Calculate source and destination starting positions - row_start_src = image_embedding_offset + row * patch_width - row_start_dst = position + row * (patch_width + 1) + for row in range(patch_height): + row_start_src = image_embedding_offset + row * patch_width + row_start_dst = position + row * (patch_width + 1) - # Copy row by row - embeddings[sample_idx, row_start_dst : row_start_dst + patch_width] = input_[ - sample_idx, row_start_src : row_start_src + patch_width + embeddings[sample_idx, row_start_dst : row_start_dst + patch_width] = input_[ + sample_idx, row_start_src : row_start_src + patch_width + ] + else: + embeddings[sample_idx, position : position + num_patches] = input_[ + sample_idx, image_embedding_offset : image_embedding_offset + num_patches ] - # Move to the next image in the input tensor - image_embedding_offset += num_image_tokens + 
image_embedding_offset += num_patches if self._use_absolute_position_embeddings: embeddings = embeddings + torch.nn.functional.embedding(position_ids, self.position_embeddings_weight) diff --git a/fast_llm/layers/vision_encoder/config.py b/fast_llm/layers/vision_encoder/config.py index 5b972f12..26794174 100644 --- a/fast_llm/layers/vision_encoder/config.py +++ b/fast_llm/layers/vision_encoder/config.py @@ -132,7 +132,12 @@ class VisionEncoderConfig(BaseModelConfig): ) image_break_token: int | None = Field( default=None, - desc="Token id to separate image rows. If None, no token id is applied is applied.", + desc="Token id to separate image rows. If None, no token id is applied.", + hint=FieldHint.optional, + ) + image_end_token: int | None = Field( + default=None, + desc="Token id to indicate the end of an image. If None, no token id is applied.", hint=FieldHint.optional, ) adapter_lr_scale: float | None = Field( diff --git a/fast_llm/layers/vision_encoder/preprocessing.py b/fast_llm/layers/vision_encoder/preprocessing.py index 8404adae..41da4fb6 100644 --- a/fast_llm/layers/vision_encoder/preprocessing.py +++ b/fast_llm/layers/vision_encoder/preprocessing.py @@ -6,6 +6,7 @@ from fast_llm.engine.base_model.config import Preprocessor from fast_llm.engine.config_utils.tensor_space import TensorDim, TensorSpace +from fast_llm.layers.language_model.config import LanguageModelKwargs from fast_llm.layers.transformer.config import TransformerKwargs, VisionTransformerDimNames, VisionTransformerKwargs from fast_llm.layers.vision_encoder.config import VisionEncoderConfig, VisionEncoderDimNames, VisionEncoderKwargs from fast_llm.tensor import TensorMeta @@ -19,14 +20,19 @@ def get_num_patches(height: int, width: int, patch_size: int) -> tuple[int, int] return div(height, patch_size) * div(width, patch_size) -def get_num_image_tokens(height: int, width: int, patch_size: int, image_break: bool) -> int: +def get_num_image_tokens(height: int, width: int, patch_size: int, image_break: bool, image_end: bool) -> int: """ Calculate the number of image tokens. If image_break is True, we consider 1 additional token after every row of patches. """ height_patches = div(height, patch_size) width_patches = div(width, patch_size) - return height_patches * width_patches + (height_patches - 1 if image_break else 0) + num_tokens = height_patches * width_patches + if image_break: + num_tokens += height_patches + elif image_end: + num_tokens += 1 + return num_tokens def get_resize_dims(height: int, width: int, max_height: int, max_width: int, patch_size: int) -> tuple[int, int]: @@ -150,16 +156,32 @@ def preprocess(self, tokens, kwargs: dict[str, typing.Any]) -> None: ] for imgs in images ] + image_positions = kwargs.get(VisionEncoderKwargs.image_positions) + + labels = kwargs[LanguageModelKwargs.labels] + if (self._config.image_break_token is not None) or (self._config.image_end_token is not None): + # If image break or end token is present, we need to replace image token ids to -100 in labels + # TODO: avoid double cloning labels in case of loss masking spans? 
+ labels = labels.clone() + patches = [] patch_position_ids = [] cu_seqlens = [0] max_seqlen = -1 kwargs.get(TransformerKwargs.sequence_first) - for imgs, sizes in zip(images, image_sizes): + for idx, (imgs, sizes, positions) in enumerate(zip(images, image_sizes, image_positions)): seq_patches = [] sample_cu_seqlen = 0 - for image, size in zip(imgs, sizes): + for image, size, position in zip(imgs, sizes, positions): seqlen = get_num_patches(*size, patch_size) + num_tokens = get_num_image_tokens( + *size, + patch_size=patch_size, + image_break=self._config.image_break_token is not None, + image_end=self._config.image_end_token is not None, + ) + # set labels for image patches to -100 + labels[idx, max(position - 1, 0) : position + num_tokens - 1] = -100 if seqlen > max_seqlen: max_seqlen = seqlen cu_seqlens.append(cu_seqlens[-1] + seqlen) @@ -204,6 +226,7 @@ def preprocess(self, tokens, kwargs: dict[str, typing.Any]) -> None: # TODO Soham: remove assert patches[-1].size(0) == kwargs[TransformerKwargs.sequence_length] patches = torch.cat(patches) + kwargs[LanguageModelKwargs.labels] = labels patch_position_ids = torch.cat(patch_position_ids) kwargs[VisionEncoderKwargs.image_patches] = patches kwargs[VisionEncoderKwargs.rotary_inv_freq] = create_inv_freqs( diff --git a/fast_llm/models/gpt/trainer.py b/fast_llm/models/gpt/trainer.py index d1b6d19e..a4f0b0b4 100644 --- a/fast_llm/models/gpt/trainer.py +++ b/fast_llm/models/gpt/trainer.py @@ -38,6 +38,7 @@ def _get_sampling_parameters( "patch_size": self._config.model.base_model.vision_encoder.patch_size, "image_size": self._config.batch.image_size, "image_break_token": self._config.model.base_model.vision_encoder.image_break_token, + "image_end_token": self._config.model.base_model.vision_encoder.image_end_token, } ) return parameters if _return_dict else GPTSamplingParameters(**parameters) From fd08eac092f508b50219d4314f22a54af8efe768 Mon Sep 17 00:00:00 2001 From: root Date: Thu, 29 May 2025 00:10:38 +0000 Subject: [PATCH 60/82] heterogeneous data fixes --- fast_llm/engine/multi_stage/stage.py | 2 +- fast_llm/functional/cross_entropy.py | 2 +- fast_llm/functional/triton/mlp.py | 4 ++-- .../layers/vision_encoder/preprocessing.py | 21 +++++++++++++++---- 4 files changed, 21 insertions(+), 8 deletions(-) diff --git a/fast_llm/engine/multi_stage/stage.py b/fast_llm/engine/multi_stage/stage.py index 675e878b..b1c7df81 100644 --- a/fast_llm/engine/multi_stage/stage.py +++ b/fast_llm/engine/multi_stage/stage.py @@ -121,7 +121,7 @@ def backward( assert self._mode.support_backward input_, output = grad_context output.backward(output_grad) - return input_.grad + return input_.grad if input_.grad is not None else torch.zeros_like(input_) def restore_parameters(self) -> None: assert self._is_setup diff --git a/fast_llm/functional/cross_entropy.py b/fast_llm/functional/cross_entropy.py index 513510ec..53b5979e 100644 --- a/fast_llm/functional/cross_entropy.py +++ b/fast_llm/functional/cross_entropy.py @@ -145,7 +145,7 @@ def _fused_cross_entropy_forward_backward( per_sample_loss = sum_exp_logits.log() - predicted_logits if loss_mask is not None: - per_sample_loss = per_sample_loss * loss_mask + per_sample_loss = per_sample_loss[loss_mask] loss = per_sample_loss.mean() if target_format != TargetFormat.labels and group is not None: diff --git a/fast_llm/functional/triton/mlp.py b/fast_llm/functional/triton/mlp.py index ee3ba304..0fb71bd5 100644 --- a/fast_llm/functional/triton/mlp.py +++ b/fast_llm/functional/triton/mlp.py @@ -50,7 +50,7 @@ def 
triton_mlp_activation_forward_kernel( input_ = tl.load(input_ptr, mask=mask).to(tl.float32) - if activation_type == _TritonActivationType.gelu: + if activation_type == _TritonActivationType.gelu_pytorch_tanh: tanh_input = 0.79788456 * input_ * (1 + 0.044715 * input_ * input_) tanh = 1 - 2 / (1 + tl.exp(2 * tanh_input)) out = input_ * 0.5 * (1.0 + tanh) @@ -100,7 +100,7 @@ def triton_mlp_activation_backward_kernel( input_ = tl.load(input_ptr, mask=mask).to(tl.float32) output_grad = tl.load(grad_output_ptr + output_offsets, mask=mask).to(tl.float32) - if activation_type == _TritonActivationType.gelu: + if activation_type == _TritonActivationType.gelu_pytorch_tanh: tanh_input = 0.79788456 * input_ * (1 + 0.044715 * input_ * input_) tanh = 1 - 2 / (1 + tl.exp(2 * tanh_input)) grad = 0.5 * input_ * ((1 - tanh * tanh) * (0.79788456 + 0.1070322243 * input_ * input_)) + 0.5 * (1 + tanh) diff --git a/fast_llm/layers/vision_encoder/preprocessing.py b/fast_llm/layers/vision_encoder/preprocessing.py index 41da4fb6..8fad3572 100644 --- a/fast_llm/layers/vision_encoder/preprocessing.py +++ b/fast_llm/layers/vision_encoder/preprocessing.py @@ -170,7 +170,13 @@ def preprocess(self, tokens, kwargs: dict[str, typing.Any]) -> None: max_seqlen = -1 kwargs.get(TransformerKwargs.sequence_first) for idx, (imgs, sizes, positions) in enumerate(zip(images, image_sizes, image_positions)): - seq_patches = [] + # add an empty tensor for clean concatenation in case of no images + seq_patches = [ + torch.tensor([]).to( + dtype=self._tensor_space.distributed_config.training_dtype.torch, + device=self._tensor_space.distributed.device, + ) + ] sample_cu_seqlen = 0 for image, size, position in zip(imgs, sizes, positions): seqlen = get_num_patches(*size, patch_size) @@ -211,9 +217,16 @@ def preprocess(self, tokens, kwargs: dict[str, typing.Any]) -> None: ] ) ) - position_ids = torch.cat( - [position_ids_in_meshgrid(*size, im_height // patch_size, patch_size) for size in sizes] - ).to(device=self._tensor_space.distributed.device) + if sizes: + position_ids = torch.cat( + [position_ids_in_meshgrid(*size, im_height // patch_size, patch_size) for size in sizes] + ).to(device=self._tensor_space.distributed.device) + else: + position_ids = torch.tensor( + [], + dtype=torch.int64, + device=self._tensor_space.distributed.device, + ) # We pad at the end instead of padding at the position in meshgrid because flash attention does not support custom attention masks patch_position_ids.append( torch.cat( From 1e3652aeae78f930fdd1c58d09b45681adec2047 Mon Sep 17 00:00:00 2001 From: root Date: Thu, 29 May 2025 15:25:49 +0000 Subject: [PATCH 61/82] convert to rgb --- fast_llm/data/dataset/gpt/memmap.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/fast_llm/data/dataset/gpt/memmap.py b/fast_llm/data/dataset/gpt/memmap.py index d83064b1..70380941 100644 --- a/fast_llm/data/dataset/gpt/memmap.py +++ b/fast_llm/data/dataset/gpt/memmap.py @@ -325,10 +325,11 @@ def write_dataset(cls, prefix: pathlib.Path | str, documents: typing.Iterable[GP for image in document.images: # assume 3 channels (RGB) for all images with PIL.Image.open(io.BytesIO(image["bytes"])) as img: - if img.mode == "L": - # Convert grayscale to RGB + if img.mode != "RGB": + # Convert all images to RGB img = img.convert("RGB") pixels = np.array(img).transpose(2, 0, 1) # HWC to CHW + assert pixels.dtype == np.uint8, f"Expected uint8 pixels, got {pixels.dtype}." 
image_lengths.append(np.array(pixels.shape[1:])) bin_stream.write(pixels.tobytes(order="C")) total_im_size += pixels.size From e0f7dfd7cd23f4974aa1e19c8bdb39e0b5a8f290 Mon Sep 17 00:00:00 2001 From: Toby Liang Date: Thu, 29 May 2025 16:31:51 +0000 Subject: [PATCH 62/82] mm loss masking spans --- fast_llm/data/dataset/gpt/memmap.py | 61 +++++++----- fast_llm/data/dataset/gpt/sampled.py | 99 ++++++++++++++----- .../data/preparator/gpt_memmap/prepare.py | 76 +++++++------- 3 files changed, 148 insertions(+), 88 deletions(-) diff --git a/fast_llm/data/dataset/gpt/memmap.py b/fast_llm/data/dataset/gpt/memmap.py index 855026cf..eb0a14c8 100644 --- a/fast_llm/data/dataset/gpt/memmap.py +++ b/fast_llm/data/dataset/gpt/memmap.py @@ -10,7 +10,6 @@ from fast_llm.data.dataset.gpt.sampled import GPTSample from fast_llm.data.preparator.gpt_memmap.config import MEMMAP_DTYPES, MEMMAP_DTYPES_INV, MEMMAP_INDEX_HEADER from fast_llm.engine.config_utils.data_type import DataType -from fast_llm.layers.vision_encoder.preprocessing import get_num_image_tokens, get_resize_dims from fast_llm.utils import Assert, div @@ -108,7 +107,7 @@ def _init( offset += ( self._num_spans.nbytes + self._num_spans.sum() * 2 * np.dtype(np.int32).itemsize - + sum([x.nbytes for x in self._spans]) + # + sum([x.nbytes for x in self._spans]) ) self._num_pixels = 0 self._image_lengths = [] @@ -141,8 +140,8 @@ def _init( ) images_seen += n_images offset = offset + self._n_images.nbytes + 3 * self._n_images.sum() * np.dtype(np.int32).itemsize - self._audio_lengths = [] - self._audio_positions = [] + self._audio_lengths = [] # list of arrays + self._audio_positions = [] # list of arrays if self._has_audio and self._version >= 5: self._n_audio = np.frombuffer( self._index_bin_buffer, dtype=np.int32, count=self._num_documents, offset=offset @@ -267,19 +266,19 @@ def get( if self._has_audio: audio_positions = self._audio_positions[idx] # increment offset by documents and images - offset = ( + aud_offset = ( self._pointers[idx] + offset * np.dtype(self._dtype).itemsize + self._document_sizes[idx] * np.dtype(self._dtype).itemsize ) if self._has_images and len(self._image_lengths) > 0: - offset += self._image_lengths[idx].prod(initial=3) * np.dtype(np.uint8).itemsize + aud_offset += self._image_lengths[idx].prod(initial=3) * np.dtype(np.uint8).itemsize all_audio = np.frombuffer( self._bin_buffer, dtype=np.dtype(np.float32), count=self._audio_lengths[idx].sum(), - offset=offset, + offset=aud_offset, ) start = 0 for audio_length in self._audio_lengths[idx]: @@ -295,23 +294,37 @@ def get( ] sample_spans[:, 0] = np.maximum(sample_spans[:, 0], offset) - offset sample_spans[:, 1] = np.minimum(sample_spans[:, 1], offset + len(token_ids) - 1) - offset - if images: - image_idx = 0 - for span in sample_spans: - additional_tokens = 0 - image_position = image_positions[image_idx] if image_idx < len(image_positions) else float("inf") - while image_position >= span[0] and image_position <= span[1]: - image_tokens = get_num_image_tokens( - get_resize_dims(*self._image_lengths[idx][image_idx], image_size, image_size, patch_size), - patch_size, - image_break=image_break, - ) - additional_tokens += image_tokens - image_idx += 1 - image_position = ( - image_positions[image_idx] if image_idx < len(image_positions) else float("inf") - ) - span[1] += additional_tokens + # if images: + # image_idx = 0 + # for span in sample_spans: + # additional_tokens = 0 + # image_position = image_positions[image_idx] if image_idx < len(image_positions) else float("inf") + # while 
image_position >= span[0] and image_position <= span[1]: + # image_tokens = get_num_image_tokens( + # get_resize_dims(*self._image_lengths[idx][image_idx], image_size, image_size, patch_size), + # patch_size, + # image_break=image_break, + # ) + # additional_tokens += image_tokens + # image_idx += 1 + # image_position = ( + # image_positions[image_idx] if image_idx < len(image_positions) else float("inf") + # ) + # span[1] += additional_tokens + # if audio: + # audio_idx = 0 + # for span in sample_spans: + # additional_tokens = 0 + # audio_position = audio_positions[audio_idx] if audio_idx < len(audio_positions) else float("inf") + # while audio_position >= span[0] and audio_position <= span[1]: + # audio_tokens = ... + # additional_tokens += audio_tokens + # audio_idx += 1 + # audio_position = ( + # audio_positions[audio_idx] if audio_idx < len(audio_positions) else float("inf") + # ) + # span[1] += additional_tokens + return GPTSample( token_ids=token_ids, images=images, diff --git a/fast_llm/data/dataset/gpt/sampled.py b/fast_llm/data/dataset/gpt/sampled.py index 87cbebf3..c6a7e2bc 100644 --- a/fast_llm/data/dataset/gpt/sampled.py +++ b/fast_llm/data/dataset/gpt/sampled.py @@ -144,7 +144,7 @@ def _compute_audio_token_size(self, sizes): sizes.fill(raw_audio_seq_length) # set all audio sizes to padded amount # account for mel spectogram, convolution, downsampling k - audio_token_size_arr = sizes // 160 # default hop length TODO: check divisible? + audio_token_size_arr = sizes // 160 # default hop length TODO Toby: check divisible? audio_token_size_arr = audio_token_size_arr // ( 2 * self._parameters.aud_downsampling_k ) # convolution (2) * downsampling @@ -557,24 +557,27 @@ def __getitem__(self, index: int) -> typing.Any: start_pos = 0 # add tokens and multi modal padding placeholders - multimodal_positions = np.concatenate( - [ - arr.astype(np.int32) - for arr in (sample.image_positions, sample.audio_positions) - if arr is not None - ] - ) or np.array([], dtype=np.int32) - multimodal_positions.sort() - for idx, mm_position in enumerate(multimodal_positions): - if ( - sample.image_positions is not None and mm_position in sample.image_positions - ): # TODO Toby: use enum - mm_type = "image" - elif sample.audio_positions is not None and mm_position in sample.audio_positions: - mm_type = "audio" - else: - assert False - # image_positions.append(im_positions + len(token_ids) + image_tokens_added) + # multimodal_positions = np.concatenate( + # [ + # arr.astype(np.int32) + # for arr in (sample.image_positions, sample.audio_positions) + # if arr is not None + # ] + # ) or np.array([], dtype=np.int32) + # multimodal_positions.sort() + + multimodal_positions = [] + if sample.image_positions is not None: + multimodal_positions.extend( + [(pos, "image", idx) for idx, pos in enumerate(sample.image_positions)] + ) + if sample.audio_positions is not None: + multimodal_positions.extend( + [(pos, "audio", idx) for idx, pos in enumerate(sample.audio_positions)] + ) + + multimodal_positions.sort(key=lambda x: x[0]) + for global_idx, (mm_position, mm_type, source_idx) in enumerate(multimodal_positions): # Add placeholders for image and audio tokens tokens token_ids.append(sample.token_ids[start_pos:mm_position]) if mm_type == "image": @@ -584,8 +587,8 @@ def __getitem__(self, index: int) -> typing.Any: if self._parameters.image_break_token is not None: # Calculate patch dimensions for the image height, width = get_resize_dims( - image_lengths[idx][0], - image_lengths[idx][1], + image_lengths[source_idx][0], 
+ image_lengths[source_idx][1], self._parameters.image_size, self._parameters.image_size, self._parameters.patch_size, @@ -613,11 +616,14 @@ def __getitem__(self, index: int) -> typing.Any: mm_tokens_added += resized_image_tokens else: # Just add placeholders for all image tokens without break tokens - token_ids.append(np.full((image_sizes[idx],), -100, dtype=np.int64)) - mm_tokens_added += image_sizes[idx] + token_ids.append(np.full((image_sizes[source_idx],), -100, dtype=np.int64)) + mm_tokens_added += image_sizes[source_idx] elif mm_type == "audio": - audio_positions.append(sum(t.size for t in token_ids)) - token_ids.append(np.full((audio_token_size_arr[idx],), -100, dtype=np.int64)) + audio_pos = sum(t.size for t in token_ids) # includes mm tokens added already + audio_positions.append(audio_pos) + token_ids.append( + np.full((audio_token_size_arr[source_idx],), -100, dtype=np.int64) + ) # TODO Toby: index doesnt work here mm_tokens_added += audio_tokens start_pos = mm_position token_ids.append(sample.token_ids[start_pos:]) @@ -634,7 +640,47 @@ def __getitem__(self, index: int) -> typing.Any: audio.append([]) if self._parameters.use_loss_masking_spans: - for loss_masking_span in sample.loss_masking_spans: + mm_idx = 0 + total_mm_tokens = 0 + for loss_masking_span in sample.loss_masking_spans: # TODO: check these must be sorted + mm_tokens_in_span = 0 + mm_position, mm_type, source_idx = ( + multimodal_positions[mm_idx] + if mm_idx < len(multimodal_positions) + else (float("inf"), _, _) + ) + + # increment mm_idx until span is reached + while mm_position < loss_masking_span[0]: + if mm_type == "image": + num_mm_tokens = image_sizes[source_idx] + elif mm_type == "audio": + num_mm_tokens = audio_token_size_arr[source_idx] + total_mm_tokens += num_mm_tokens + mm_idx += 1 + mm_position, mm_type, source_idx = ( + multimodal_positions[mm_idx] + if mm_idx < len(multimodal_positions) + else (float("inf"), _, _) + ) + + # get all multimodal positions within span + while mm_position >= loss_masking_span[0] and mm_position <= loss_masking_span[1]: + if mm_type == "image": + num_mm_tokens = image_sizes[source_idx] + elif mm_type == "audio": + num_mm_tokens = audio_token_size_arr[source_idx] + mm_tokens_in_span += num_mm_tokens + mm_idx += 1 + mm_position, mm_type, source_idx = ( + multimodal_positions[mm_idx] + if mm_idx < len(multimodal_positions) + else (float("inf"), _, _) + ) + loss_masking_span[0] += total_mm_tokens # increment by all mm tokens before span + loss_masking_span[1] += total_mm_tokens + mm_tokens_in_span + total_mm_tokens += mm_tokens_in_span + span = np.clip( loss_masking_span + token_count - token_start, 0, @@ -658,6 +704,7 @@ def __getitem__(self, index: int) -> typing.Any: if self._parameters.use_loss_masking_spans else None ) + images = [im for img_list in images for im in img_list] if images else None image_positions = np.array(image_positions) if image_positions else None diff --git a/fast_llm/data/preparator/gpt_memmap/prepare.py b/fast_llm/data/preparator/gpt_memmap/prepare.py index aa2481f0..283a6bf8 100644 --- a/fast_llm/data/preparator/gpt_memmap/prepare.py +++ b/fast_llm/data/preparator/gpt_memmap/prepare.py @@ -93,44 +93,44 @@ def _tokenize_batch(self, batch: dict[str, list[typing.Any]]) -> dict[str, list[ "num_pixels": num_pixels, } - def _tokenize_batch_with_spans(self, batch: dict[str, list[typing.Any]]) -> dict[str, list[typing.Any]]: - input_ids, token_spans, images, image_token_positions = map( - list, - zip( - *[ - ( - np.array(input_ids, 
dtype=self._data_type.numpy), - np.array(token_spans, dtype=np.int32).reshape(-1, 2), - np.array(images, dtype=np.uint8), - np.array(image_token_positions, dtype=np.int32), - ) - for input_ids, token_spans, images, image_token_positions in [ - self._tokenizer.tokenize_with_spans(text, char_spans) - for text, char_spans in zip( - batch[self._config.dataset.field], - batch.get(self._config.dataset.loss_masking_spans, itertools.repeat(None)), - batch.get(self._config.dataset.images, itertools.repeat(None)), - batch.get(self._config.dataset.image_positions, itertools.repeat(None)), - ) - ] - ] - ), - ) - num_tokens = [len(x) for x in input_ids] - num_pixels = [0] * len(input_ids) - for idx, images in enumerate(images): - for bytes_im in images: - with PIL.Image.open(io.BytesIO(bytes_im["bytes"])) as im: - width, height = im.size - num_pixels[idx] += width * height * 3 - return { - "input_ids": input_ids, - "token_spans": token_spans, - "images": images, - "image_positions": image_token_positions, - "num_tokens": num_tokens, - "num_pixels": num_pixels, - } + # def _tokenize_batch_with_spans(self, batch: dict[str, list[typing.Any]]) -> dict[str, list[typing.Any]]: + # input_ids, token_spans, images, image_token_positions = map( + # list, + # zip( + # *[ + # ( + # np.array(input_ids, dtype=self._data_type.numpy), + # np.array(token_spans, dtype=np.int32).reshape(-1, 2), + # np.array(images, dtype=np.uint8), + # np.array(image_token_positions, dtype=np.int32), + # ) + # for input_ids, token_spans, images, image_token_positions in [ + # self._tokenizer.tokenize_with_spans(text, char_spans) + # for text, char_spans in zip( + # batch[self._config.dataset.field], + # batch.get(self._config.dataset.loss_masking_spans, itertools.repeat(None)), + # batch.get(self._config.dataset.images, itertools.repeat(None)), + # batch.get(self._config.dataset.image_positions, itertools.repeat(None)), + # ) + # ] + # ] + # ), + # ) + # num_tokens = [len(x) for x in input_ids] + # num_pixels = [0] * len(input_ids) + # for idx, images in enumerate(images): + # for bytes_im in images: + # with PIL.Image.open(io.BytesIO(bytes_im["bytes"])) as im: + # width, height = im.size + # num_pixels[idx] += width * height * 3 + # return { + # "input_ids": input_ids, + # "token_spans": token_spans, + # "images": images, + # "image_positions": image_token_positions, + # "num_tokens": num_tokens, + # "num_pixels": num_pixels, + # } def _save_shard(self, args: tuple[int, datasets.Dataset]) -> GPTMemmapDatasetConfig: shard_idx, shard_dataset = args From 0ae74d191c99d2d6e31232bfb3e1a67015248109 Mon Sep 17 00:00:00 2001 From: Toby Liang Date: Thu, 29 May 2025 16:33:00 +0000 Subject: [PATCH 63/82] add lr scale --- fast_llm/layers/audio_encoder/adapter.py | 6 ++++++ fast_llm/layers/audio_encoder/config.py | 21 ++++++++++++++++++++- fast_llm/layers/audio_encoder/encoder.py | 20 ++++++++++++++------ fast_llm/models/gpt/conversion.py | 13 +++++++++---- 4 files changed, 49 insertions(+), 11 deletions(-) diff --git a/fast_llm/layers/audio_encoder/adapter.py b/fast_llm/layers/audio_encoder/adapter.py index 8c0c7175..b02b4e77 100644 --- a/fast_llm/layers/audio_encoder/adapter.py +++ b/fast_llm/layers/audio_encoder/adapter.py @@ -22,11 +22,14 @@ def __init__(self, config: AudioEncoderConfig, tensor_space: TensorSpace): input_dim = tensor_space.get_tensor_dim(AudioEncoderDimNames.adapter_input) self._activation_type = config.adapter_activation_type self._use_adapter_bias = config.adapter_bias + self.lr_scale = config.adapter_lr_scale self.norm_1 = 
config.transformer.normalization.get_layer(audio_hidden_dim) + self.norm_1.lr_scale = self.lr_scale self.norm_2 = config.transformer.normalization.get_layer( tensor_space.get_tensor_dim(AudioEncoderDimNames.adapter_size) ) + self.norm_2.lr_scale = self.lr_scale # TODO Soham: Make them OutputParallelLinear instead? How would this work with parallelism? self.layer_1 = Linear( @@ -35,6 +38,7 @@ def __init__(self, config: AudioEncoderConfig, tensor_space: TensorSpace): bias=self._use_adapter_bias, weight_init_method=init_normal_(), bias_init_method=init_normal_(), + lr_scale=self.lr_scale, ) self.layer_2 = Linear( tensor_space.get_tensor_dim(AudioEncoderDimNames.adapter_size), @@ -42,6 +46,7 @@ def __init__(self, config: AudioEncoderConfig, tensor_space: TensorSpace): bias=self._use_adapter_bias, weight_init_method=init_normal_(), bias_init_method=init_normal_(), + lr_scale=self.lr_scale, ) self.aud_downsampling_k = config.aud_downsampling_k @@ -59,6 +64,7 @@ def forward( tensor_name="Audio adapter output", dtype=input_.dtype, ) + input_ = self.norm_1(input_) batch_size, seq_len, dim = input_.size() # Check if sequence length is divisible by downsampling rate. diff --git a/fast_llm/layers/audio_encoder/config.py b/fast_llm/layers/audio_encoder/config.py index 3e09b39f..9503c60c 100644 --- a/fast_llm/layers/audio_encoder/config.py +++ b/fast_llm/layers/audio_encoder/config.py @@ -1,10 +1,11 @@ import enum -from fast_llm.config import Field, FieldHint, config_class +from fast_llm.config import Field, FieldHint, check_field, config_class, skip_valid_if_none from fast_llm.engine.base_model.config import BaseModelConfig from fast_llm.engine.config_utils.tensor_space import TensorDim, TensorSpace from fast_llm.functional.config import ActivationType from fast_llm.layers.transformer.config import AudioTransformerConfig +from fast_llm.utils import Assert class AudioEncoderDimNames: @@ -68,6 +69,18 @@ class AudioEncoderConfig(BaseModelConfig): desc="Encoder convolution layer kernel size.", hint=FieldHint.core, ) + conv_lr_scale: float | None = Field( + default=None, + desc="Custom learning rate scale for the convolutional layer weights.", + hint=FieldHint.feature, + valid=skip_valid_if_none(check_field(Assert.geq, 0)), + ) + pos_emb_lr_scale: float | None = Field( + default=None, + desc="Custom learning rate scale for the position embedding layer weights.", + hint=FieldHint.feature, + valid=skip_valid_if_none(check_field(Assert.geq, 0)), + ) # adapter configs adapter_size: int = Field( @@ -85,6 +98,12 @@ class AudioEncoderConfig(BaseModelConfig): desc="Whether to use bias in the adapter layer.", hint=FieldHint.optional, ) + adapter_lr_scale: float | None = Field( + default=None, + desc="Custom learning rate scale for the adapter weights.", + hint=FieldHint.feature, + valid=skip_valid_if_none(check_field(Assert.geq, 0)), + ) # audio configs num_mel_bins: int = Field( diff --git a/fast_llm/layers/audio_encoder/encoder.py b/fast_llm/layers/audio_encoder/encoder.py index 20c7d507..b35cc174 100644 --- a/fast_llm/layers/audio_encoder/encoder.py +++ b/fast_llm/layers/audio_encoder/encoder.py @@ -14,8 +14,9 @@ def __init__(self, config: AudioEncoderConfig, tensor_space: TensorSpace): super().__init__() self._tensor_space = tensor_space self.dropout_p = config.encoder_dropout + self._conv_lr_scale = config.conv_lr_scale + self._pos_emb_lr_scale = config.pos_emb_lr_scale - # TODO Toby: lr_scale self.conv1_weight = ParameterMeta.from_dims( ( self._tensor_space.get_tensor_dim(AudioEncoderDimNames.out_channels), @@ 
-23,8 +24,9 @@ def __init__(self, config: AudioEncoderConfig, tensor_space: TensorSpace): self._tensor_space.get_tensor_dim(AudioEncoderDimNames.kernel_size), ), init_method=init_normal_(), + lr_scale=self._conv_lr_scale, ) - self.conv1_stride = 1 # TODO: parameterize? + self.conv1_stride = 1 # TODO Toby: parameterize? self.conv2_weight = ParameterMeta.from_dims( ( @@ -33,15 +35,20 @@ def __init__(self, config: AudioEncoderConfig, tensor_space: TensorSpace): self._tensor_space.get_tensor_dim(AudioEncoderDimNames.kernel_size), ), init_method=init_normal_(), + lr_scale=self._conv_lr_scale, ) - self.conv2_stride = 2 # TODO: parameterize? + self.conv2_stride = 2 # TODO Toby: parameterize? if config.conv_bias: self.conv1_bias = ParameterMeta.from_dims( - (self._tensor_space.get_tensor_dim(AudioEncoderDimNames.out_channels),), init_method=init_normal_() + (self._tensor_space.get_tensor_dim(AudioEncoderDimNames.out_channels),), + init_method=init_normal_(), + lr_scale=self._conv_lr_scale, ) self.conv2_bias = ParameterMeta.from_dims( - (self._tensor_space.get_tensor_dim(AudioEncoderDimNames.out_channels),), init_method=init_normal_() + (self._tensor_space.get_tensor_dim(AudioEncoderDimNames.out_channels),), + init_method=init_normal_(), + lr_scale=self._conv_lr_scale, ) else: self.conv1_bias = None @@ -53,6 +60,7 @@ def __init__(self, config: AudioEncoderConfig, tensor_space: TensorSpace): self._tensor_space.get_tensor_dim(AudioEncoderDimNames.out_channels), ), init_method=init_normal_(), + lr_scale=self._pos_emb_lr_scale, ) def forward( @@ -66,7 +74,7 @@ def forward( if isinstance(input_, TensorMeta): return TensorMeta.from_dims(hidden_dims, tensor_name="audio conv output", dtype=input_.dtype) - # TODO: check how to best cast dtype + # TODO Toby: check how to best cast dtype input_ = input_.to(self.conv1_weight.dtype) input_ = torch.nn.functional.conv1d( diff --git a/fast_llm/models/gpt/conversion.py b/fast_llm/models/gpt/conversion.py index a1d91b2a..6438ce0f 100644 --- a/fast_llm/models/gpt/conversion.py +++ b/fast_llm/models/gpt/conversion.py @@ -644,10 +644,6 @@ def _create_config_converters(cls) -> list[ParamConverter]: fast_llm_names=(("num_mel_bins",),), export_names=(("num_mel_bins",),), ), - RenameParamConverter( - fast_llm_names=(("aud_downsampling_k",),), - export_names=(("encoder_projector_ds_rate",),), - ), ] def _get_transformer_mlp_converters(self, fast_llm_prefix: str, hf_prefix: str) -> list[WeightConverter]: @@ -1024,6 +1020,15 @@ def _create_config_converters(cls) -> list[ParamConverter]: fast_llm_names=(("audio_encoder", "adapter_size"),), export_names=(("adapter_size",),), ), + RenameParamConverter( + fast_llm_names=( + ( + "audio_encoder", + "aud_downsampling_k", + ), + ), + export_names=(("encoder_projector_ds_rate",),), + ), ] @classmethod From 438ba80062077b9c0c69653229fd1c04a9a88c91 Mon Sep 17 00:00:00 2001 From: Toby Liang Date: Fri, 30 May 2025 16:42:54 +0000 Subject: [PATCH 64/82] mel spec changes --- fast_llm/data/dataset/gpt/sampled.py | 3 +- .../layers/audio_encoder/preprocessing.py | 61 ++++++++++++------- 2 files changed, 41 insertions(+), 23 deletions(-) diff --git a/fast_llm/data/dataset/gpt/sampled.py b/fast_llm/data/dataset/gpt/sampled.py index 66ae6a88..b2d30379 100644 --- a/fast_llm/data/dataset/gpt/sampled.py +++ b/fast_llm/data/dataset/gpt/sampled.py @@ -636,7 +636,8 @@ def __getitem__(self, index: int) -> typing.Any: else: images.append([]) if sample.audio: - audio.append(self.apply_audio_padding(sample.audio)) + # 
audio.append(self.apply_audio_padding(sample.audio)) + audio.append(sample.audio) else: audio.append([]) diff --git a/fast_llm/layers/audio_encoder/preprocessing.py b/fast_llm/layers/audio_encoder/preprocessing.py index 9d0db1b4..506e026e 100644 --- a/fast_llm/layers/audio_encoder/preprocessing.py +++ b/fast_llm/layers/audio_encoder/preprocessing.py @@ -1,14 +1,12 @@ import typing import torch -from torchaudio.transforms import MelSpectrogram +from transformers import WhisperFeatureExtractor from fast_llm.engine.base_model.config import Preprocessor from fast_llm.engine.config_utils.tensor_space import TensorSpace from fast_llm.layers.audio_encoder.config import AudioEncoderConfig, AudioEncoderKwargs -# from transformers import WhisperFeatureExtractor - class AudioPreprocessor(Preprocessor): def __init__(self, config: AudioEncoderConfig, tensor_space: TensorSpace): @@ -16,21 +14,21 @@ def __init__(self, config: AudioEncoderConfig, tensor_space: TensorSpace): self._tensor_space = tensor_space self._distributed_config = self._tensor_space.distributed_config - # self.feature_extractor = WhisperFeatureExtractor(sampling_rate=self._config.aud_sampling_rate) + self.feature_extractor = WhisperFeatureExtractor(sampling_rate=self._config.aud_sampling_rate) - self.mel_transform = MelSpectrogram( - sample_rate=self._config.aud_sampling_rate, - n_fft=400, - win_length=400, - hop_length=160, - n_mels=80, - f_min=0.0, - f_max=8000.0, - mel_scale="slaney", - norm="slaney", - center=True, - power=2.0, - ) + # self.mel_transform = MelSpectrogram( + # sample_rate=self._config.aud_sampling_rate, + # n_fft=400, + # win_length=400, + # hop_length=160, + # n_mels=80, + # f_min=0.0, + # f_max=8000.0, + # mel_scale="slaney", + # norm="slaney", + # center=True, + # power=2.0, + # ) def preprocess_meta(self, kwargs: dict[str, typing.Any]) -> None: # kwargs[AudioEncoderKwargs.audio_mel_meta] = TensorMeta.from_dims( @@ -49,13 +47,31 @@ def preprocess_meta(self, kwargs: dict[str, typing.Any]) -> None: def preprocess(self, tokens, kwargs: dict[str, typing.Any]) -> None: audio_raw = kwargs[AudioEncoderKwargs.audio] - flattened_audio = [audio_arr for sequence in audio_raw for audio_arr in sequence] - flattened_audio_tensor = torch.stack(flattened_audio, dim=0) + flattened_audio = [ + audio_arr for sequence in audio_raw for audio_arr in sequence + ] # flatten in the batch dimension + # flattened_audio_tensor = torch.stack(flattened_audio, dim=0) # audio_inputs = self.feature_extractor(audio_raw, sampling_rate=16000, return_tensors="pt") - self.mel_transform.to(self._tensor_space.distributed.device) + # self.mel_transform.to(self._tensor_space.distributed.device) + + # audio_mel = self.mel_transform(flattened_audio_tensor) + # flattened_audio_tensor = np.stack(flattened_audio, axis=0) + # audio_inputs = self.feature_extractor(flattened_audio_tensor, sampling_rate=16000, return_tensors="pt") + # audio_mel = audio_inputs['input_features'] + + audio_mel = [] + for audio in flattened_audio: + audio_mel.append( + self.feature_extractor( + audio, + sampling_rate=self._config.aud_sampling_rate, + return_tensors="pt", + max_length=30 * self._config.aud_sampling_rate, + )["input_features"] + ) + audio_mel = torch.stack(audio_mel, dim=0).squeeze(1) - audio_mel = self.mel_transform(flattened_audio_tensor) - audio_mel = audio_mel[:, :, :-1] # TODO Toby: check this! + # audio_mel = audio_mel[:, :, :-1] # TODO Toby: check this! 
# # set attention mask # TODO Toby: fix backup attention # sequence_k = kwargs[self._transformer_kwargs.sequence_k_dim].size @@ -65,4 +81,5 @@ def preprocess(self, tokens, kwargs: dict[str, typing.Any]) -> None: # ] # kwargs[self._transformer_kwargs.attention_mask_value] = self._mask_value + audio_mel = audio_mel.to(self._tensor_space.distributed.device) kwargs[AudioEncoderKwargs.audio_mel] = audio_mel From 525543a74bf4d4362f441766f41c07172234d064 Mon Sep 17 00:00:00 2001 From: Toby Liang Date: Fri, 30 May 2025 16:44:10 +0000 Subject: [PATCH 65/82] updates --- fast_llm/layers/audio_encoder/adapter.py | 13 +++++++------ fast_llm/models/gpt/conversion.py | 23 +++++++++++++++++++---- fast_llm/models/gpt/model.py | 5 +---- 3 files changed, 27 insertions(+), 14 deletions(-) diff --git a/fast_llm/layers/audio_encoder/adapter.py b/fast_llm/layers/audio_encoder/adapter.py index b02b4e77..bc4f8f00 100644 --- a/fast_llm/layers/audio_encoder/adapter.py +++ b/fast_llm/layers/audio_encoder/adapter.py @@ -77,10 +77,11 @@ def forward( # Reshape: group every k frames together (concatenate along feature dimension). new_seq_len = seq_len // self.aud_downsampling_k input_ = input_.contiguous().view(batch_size, new_seq_len, dim * self.aud_downsampling_k) - - res = self.layer_2( - self.norm_2( - torch_mlp_activation(input_=self.layer_1(input_), gated=False, activation_type=self._activation_type) - ) + layer1_res = torch_mlp_activation( + input_=self.layer_1(input_), gated=False, activation_type=self._activation_type ) - return res + torch.manual_seed(0) # TODO Toby: remove after debugging + layer1_res_dropout = torch.nn.functional.dropout(layer1_res, 0.1) + layer1_res_norm = self.norm_2(layer1_res_dropout) + layer2_res = self.layer_2(layer1_res_norm) + return layer2_res diff --git a/fast_llm/models/gpt/conversion.py b/fast_llm/models/gpt/conversion.py index 6438ce0f..ad348ce9 100644 --- a/fast_llm/models/gpt/conversion.py +++ b/fast_llm/models/gpt/conversion.py @@ -647,11 +647,26 @@ def _create_config_converters(cls) -> list[ParamConverter]: ] def _get_transformer_mlp_converters(self, fast_llm_prefix: str, hf_prefix: str) -> list[WeightConverter]: + # return [ + # WeightConverter(f"{fast_llm_prefix}.mlp.layer_1.weight", f"{hf_prefix}fc1.weight"), + # WeightConverter(f"{fast_llm_prefix}.mlp.layer_1.bias", f"{hf_prefix}fc1.bias"), + # WeightConverter(f"{fast_llm_prefix}.mlp.layer_2.weight", f"{hf_prefix}fc2.weight"), + # WeightConverter(f"{fast_llm_prefix}.mlp.layer_2.bias", f"{hf_prefix}fc2.bias"), + # ] + transformer_config = self._model.config.base_model.audio_encoder.transformer return [ - WeightConverter(f"{fast_llm_prefix}.mlp.layer_1.weight", f"{hf_prefix}fc1.weight"), - WeightConverter(f"{fast_llm_prefix}.mlp.layer_1.bias", f"{hf_prefix}fc1.bias"), - WeightConverter(f"{fast_llm_prefix}.mlp.layer_2.weight", f"{hf_prefix}fc2.weight"), - WeightConverter(f"{fast_llm_prefix}.mlp.layer_2.bias", f"{hf_prefix}fc2.bias"), + *self._get_weight_and_bias_converters( + f"{fast_llm_prefix}.mlp.layer_1", + f"{hf_prefix}fc1", + transformer_config.add_mlp_bias, + WeightConverter, + ), + *self._get_weight_and_bias_converters( + f"{fast_llm_prefix}.mlp.layer_2", + f"{hf_prefix}fc2", + transformer_config.add_mlp_bias, + MLPLayer2Converter, + ), ] def _create_audio_transformer_layer_converters( diff --git a/fast_llm/models/gpt/model.py b/fast_llm/models/gpt/model.py index 330bd832..57ec951b 100644 --- a/fast_llm/models/gpt/model.py +++ b/fast_llm/models/gpt/model.py @@ -479,10 +479,7 @@ def preprocess( if batch.audio is not 
None: kwargs[AudioEncoderKwargs.audio] = [ - [ - aud.to(device=self._tensor_space.distributed.device, dtype=torch.float32, non_blocking=True) - for aud in audio - ] + [aud.to(device="cpu", dtype=torch.float32, non_blocking=True) for aud in audio] for audio in batch.audio ] kwargs[AudioEncoderKwargs.audio_positions] = batch.audio_positions From 2aabf353752eeb9290f470cd76e44da8482c0456 Mon Sep 17 00:00:00 2001 From: root Date: Fri, 30 May 2025 20:48:27 +0000 Subject: [PATCH 66/82] fix sequence parallel image patches --- fast_llm/layers/multi_modal/embedding.py | 18 ++++++++++++++++-- 1 file changed, 16 insertions(+), 2 deletions(-) diff --git a/fast_llm/layers/multi_modal/embedding.py b/fast_llm/layers/multi_modal/embedding.py index 4dd4a46e..9e11df3f 100644 --- a/fast_llm/layers/multi_modal/embedding.py +++ b/fast_llm/layers/multi_modal/embedding.py @@ -48,11 +48,17 @@ def _forward( """ Assert.eq(position_ids is not None, self._use_absolute_position_embeddings) group = self._tensor_space.distributed.tensor_group + if self._sequence_parallel: + micro_seqlen = input_.size(0) + patch_start_offset = self._distributed_config.tensor_rank * micro_seqlen + patch_end_offset = (self._distributed_config.tensor_rank + 1) * micro_seqlen + else: + patch_start_offset = 0 + patch_end_offset = input_.size(0) if self._parallel_embeddings: token_mask = (tokens >= self._vocab_start_index) * (tokens < self._vocab_end_index) masked_tokens = (tokens - self._vocab_start_index) * token_mask embeddings = torch.embedding(self.word_embeddings_weight, masked_tokens) * token_mask.unsqueeze(2) # noqa - embeddings = reduce_forward(embeddings, group) if self._use_absolute_position_embeddings: embeddings = embeddings + torch.nn.functional.embedding(position_ids, self.position_embeddings_weight) embeddings = embeddings.clone() @@ -61,13 +67,18 @@ def _forward( image_embedding_offset = 0 for position, size in zip(positions, sizes): num_patches = get_num_patches(*size, self._config.vision_encoder.patch_size) + if image_embedding_offset + num_patches < patch_start_offset: + continue if self._config.vision_encoder.image_break_token is not None: patch_width = div(size[0], self._config.vision_encoder.patch_size) patch_height = div(size[1], self._config.vision_encoder.patch_size) - for row in range(patch_height): row_start_src = image_embedding_offset + row * patch_width row_start_dst = position + row * (patch_width + 1) + if row_start_src > patch_end_offset: + break + if row_start_dst < patch_start_offset: + continue if self._sequence_parallel: # Copy with dimensions swapped for sequence parallel case @@ -84,6 +95,9 @@ def _forward( sample_idx, image_embedding_offset : image_embedding_offset + num_patches ] image_embedding_offset += num_patches + if image_embedding_offset > patch_end_offset: + break + embeddings = reduce_forward(embeddings, group) if self._sequence_parallel: embeddings = split(embeddings, group=group, dim=0) else: From b6d48589ad500034efdecb3727a5d163702f60e2 Mon Sep 17 00:00:00 2001 From: root Date: Sat, 31 May 2025 01:50:12 +0000 Subject: [PATCH 67/82] fixes --- fast_llm/layers/multi_modal/embedding.py | 46 +++++++++++++------ .../layers/vision_encoder/preprocessing.py | 2 +- 2 files changed, 34 insertions(+), 14 deletions(-) diff --git a/fast_llm/layers/multi_modal/embedding.py b/fast_llm/layers/multi_modal/embedding.py index 9e11df3f..76060a00 100644 --- a/fast_llm/layers/multi_modal/embedding.py +++ b/fast_llm/layers/multi_modal/embedding.py @@ -59,8 +59,6 @@ def _forward( token_mask = (tokens >= 
self._vocab_start_index) * (tokens < self._vocab_end_index) masked_tokens = (tokens - self._vocab_start_index) * token_mask embeddings = torch.embedding(self.word_embeddings_weight, masked_tokens) * token_mask.unsqueeze(2) # noqa - if self._use_absolute_position_embeddings: - embeddings = embeddings + torch.nn.functional.embedding(position_ids, self.position_embeddings_weight) embeddings = embeddings.clone() input_ = gather(input_, group, dim=0) for sample_idx, (positions, sizes) in enumerate(zip(image_positions, image_sizes)): @@ -70,34 +68,56 @@ def _forward( if image_embedding_offset + num_patches < patch_start_offset: continue if self._config.vision_encoder.image_break_token is not None: - patch_width = div(size[0], self._config.vision_encoder.patch_size) - patch_height = div(size[1], self._config.vision_encoder.patch_size) + patch_height = div(size[0], self._config.vision_encoder.patch_size) + patch_width = div(size[1], self._config.vision_encoder.patch_size) for row in range(patch_height): row_start_src = image_embedding_offset + row * patch_width row_start_dst = position + row * (patch_width + 1) if row_start_src > patch_end_offset: break - if row_start_dst < patch_start_offset: + if row_start_src + patch_width <= patch_start_offset: continue + input_start_index = max(row_start_src, patch_start_offset) - patch_start_offset + input_end_index = min(row_start_src + patch_width, patch_end_offset) - patch_start_offset + embeddings_start_index = row_start_dst - max(patch_start_offset - row_start_src, 0) + embeddings_end_index = ( + row_start_dst + patch_width - max(row_start_src + patch_width - patch_end_offset, 0) + ) + # row_end_src = min(row_start_src + patch_width, patch_end_offset) if self._sequence_parallel: # Copy with dimensions swapped for sequence parallel case - embeddings[row_start_dst : row_start_dst + patch_width, sample_idx] = input_[ - row_start_src : row_start_src + patch_width, sample_idx + embeddings[embeddings_start_index:embeddings_end_index, sample_idx] = input_[ + input_start_index:input_end_index, sample_idx ] + tokens[embeddings_start_index:embeddings_end_index, sample_idx] = 10 else: # Copy with normal dimension ordering - embeddings[sample_idx, row_start_dst : row_start_dst + patch_width] = input_[ - sample_idx, row_start_src : row_start_src + patch_width + embeddings[sample_idx, embeddings_start_index:embeddings_end_index] = input_[ + sample_idx, input_start_index:input_end_index ] + tokens[embeddings_start_index:embeddings_end_index, sample_idx] = 10 else: - embeddings[sample_idx, position : position + num_patches] = input_[ - sample_idx, image_embedding_offset : image_embedding_offset + num_patches + input_start_index = max(image_embedding_offset, patch_start_offset) - patch_start_offset + input_end_index = ( + min(image_embedding_offset + num_patches, patch_end_offset) - patch_start_offset + ) + embedding_start_index = position - max(patch_start_offset - image_embedding_offset, 0) + embedding_end_index = ( + position + num_patches - max(image_embedding_offset + num_patches - patch_end_offset, 0) + ) + embeddings[sample_idx, embedding_start_index:embedding_end_index] = input_[ + input_start_index:input_end_index, sample_idx ] + # embeddings[sample_idx, position : position + num_patches] = input_[ + # sample_idx, image_embedding_offset : image_embedding_offset + num_patches + # ] image_embedding_offset += num_patches if image_embedding_offset > patch_end_offset: break embeddings = reduce_forward(embeddings, group) + if self._use_absolute_position_embeddings: + 
embeddings = embeddings + torch.nn.functional.embedding(position_ids, self.position_embeddings_weight) if self._sequence_parallel: embeddings = split(embeddings, group=group, dim=0) else: @@ -118,8 +138,8 @@ def _forward( for position, size in zip(positions, sizes): num_patches = get_num_patches(*size, self._config.vision_encoder.patch_size) if self._config.vision_encoder.image_break_token is not None: - patch_width = div(size[0], self._config.vision_encoder.patch_size) - patch_height = div(size[1], self._config.vision_encoder.patch_size) + patch_height = div(size[0], self._config.vision_encoder.patch_size) + patch_width = div(size[1], self._config.vision_encoder.patch_size) for row in range(patch_height): row_start_src = image_embedding_offset + row * patch_width diff --git a/fast_llm/layers/vision_encoder/preprocessing.py b/fast_llm/layers/vision_encoder/preprocessing.py index 8fad3572..ab0d2378 100644 --- a/fast_llm/layers/vision_encoder/preprocessing.py +++ b/fast_llm/layers/vision_encoder/preprocessing.py @@ -205,7 +205,7 @@ def preprocess(self, tokens, kwargs: dict[str, typing.Any]) -> None: padding_size = kwargs[TransformerKwargs.sequence_length] - sample_cu_seqlen if padding_size > max_seqlen: max_seqlen = padding_size - cu_seqlens.append(kwargs[TransformerKwargs.sequence_length]) + cu_seqlens.append(kwargs[TransformerKwargs.sequence_length] * (idx + 1)) patches.append( torch.cat( [ From 25a650bf588e8a20b02a4b6f6b991aa42993808b Mon Sep 17 00:00:00 2001 From: root Date: Sat, 31 May 2025 17:10:16 +0000 Subject: [PATCH 68/82] no compile for embeddings --- fast_llm/layers/multi_modal/embedding.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/fast_llm/layers/multi_modal/embedding.py b/fast_llm/layers/multi_modal/embedding.py index 76060a00..7f09347b 100644 --- a/fast_llm/layers/multi_modal/embedding.py +++ b/fast_llm/layers/multi_modal/embedding.py @@ -26,7 +26,7 @@ def __init__( ): super().__init__(config, tensor_space) - @torch.compile + # @torch.compile def _forward( self, input_: torch.Tensor, From c904da5def23c6db1abb775971f6790a4bec8272 Mon Sep 17 00:00:00 2001 From: root Date: Sun, 1 Jun 2025 17:48:59 +0000 Subject: [PATCH 69/82] fix sampling --- fast_llm/data/dataset/gpt/sampled.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/fast_llm/data/dataset/gpt/sampled.py b/fast_llm/data/dataset/gpt/sampled.py index 6c8e9fe7..8d216b3d 100644 --- a/fast_llm/data/dataset/gpt/sampled.py +++ b/fast_llm/data/dataset/gpt/sampled.py @@ -453,7 +453,7 @@ def __getitem__(self, index: int) -> typing.Any: document_sampling_index += 1 continue tokens_in_sample = token_count % (self._parameters.sequence_length + 1) - if document_size + tokens_in_sample > self._parameters.sequence_length + 1: + if document_size + tokens_in_sample >= self._parameters.sequence_length + 1: # Document belongs to the next sample, need to account for padding. padding_size = self._parameters.sequence_length + 1 - tokens_in_sample if token_count > token_start: @@ -464,6 +464,7 @@ def __getitem__(self, index: int) -> typing.Any: else: # Move on to the next sample. token_count += padding_size + continue # Determine if the document belongs to the requested sample. 
if token_count + document_size >= token_start: From 7a4701c522431eb94a873f59a220e13691c007b9 Mon Sep 17 00:00:00 2001 From: sohamparikh Date: Mon, 2 Jun 2025 00:15:54 -0700 Subject: [PATCH 70/82] sampling and preprocessing bugs --- fast_llm/data/dataset/gpt/sampled.py | 2 +- fast_llm/layers/vision_encoder/preprocessing.py | 6 +++--- fast_llm/models/gpt/model.py | 13 ++++++++++--- 3 files changed, 14 insertions(+), 7 deletions(-) diff --git a/fast_llm/data/dataset/gpt/sampled.py b/fast_llm/data/dataset/gpt/sampled.py index 8d216b3d..f58b009a 100644 --- a/fast_llm/data/dataset/gpt/sampled.py +++ b/fast_llm/data/dataset/gpt/sampled.py @@ -456,7 +456,7 @@ def __getitem__(self, index: int) -> typing.Any: if document_size + tokens_in_sample >= self._parameters.sequence_length + 1: # Document belongs to the next sample, need to account for padding. padding_size = self._parameters.sequence_length + 1 - tokens_in_sample - if token_count > token_start: + if token_count >= token_start: # Add padding tokens to current sample token_ids.append(np.full((padding_size,), -100, dtype=np.int64)) Assert.eq(token_count + padding_size, token_end) diff --git a/fast_llm/layers/vision_encoder/preprocessing.py b/fast_llm/layers/vision_encoder/preprocessing.py index ab0d2378..76b0aa28 100644 --- a/fast_llm/layers/vision_encoder/preprocessing.py +++ b/fast_llm/layers/vision_encoder/preprocessing.py @@ -137,6 +137,7 @@ def preprocess(self, tokens, kwargs: dict[str, typing.Any]) -> None: im_height = kwargs.get(VisionEncoderKwargs.image_size) im_width = kwargs.get(VisionEncoderKwargs.image_size) patch_size = kwargs[VisionEncoderKwargs.patch_size] + image_positions = kwargs.get(VisionEncoderKwargs.image_positions) image_sizes = [ [get_resize_dims(im.size(1), im.size(2), im_height, im_width, patch_size=patch_size) for im in ims] for ims in images @@ -156,7 +157,6 @@ def preprocess(self, tokens, kwargs: dict[str, typing.Any]) -> None: ] for imgs in images ] - image_positions = kwargs.get(VisionEncoderKwargs.image_positions) labels = kwargs[LanguageModelKwargs.labels] if (self._config.image_break_token is not None) or (self._config.image_end_token is not None): @@ -239,9 +239,9 @@ def preprocess(self, tokens, kwargs: dict[str, typing.Any]) -> None: # TODO Soham: remove assert patches[-1].size(0) == kwargs[TransformerKwargs.sequence_length] patches = torch.cat(patches) - kwargs[LanguageModelKwargs.labels] = labels patch_position_ids = torch.cat(patch_position_ids) kwargs[VisionEncoderKwargs.image_patches] = patches + kwargs[VisionTransformerKwargs.patch_position_ids] = patch_position_ids kwargs[VisionEncoderKwargs.rotary_inv_freq] = create_inv_freqs( kwargs[VisionEncoderKwargs.rope_theta], kwargs[VisionEncoderKwargs.kv_channels], @@ -249,7 +249,6 @@ def preprocess(self, tokens, kwargs: dict[str, typing.Any]) -> None: patch_size, ).to(device=self._tensor_space.distributed.device) kwargs[VisionEncoderKwargs.max_image_tokens] = div(im_height * im_width, patch_size**2) - kwargs[VisionTransformerKwargs.patch_position_ids] = patch_position_ids # TODO Soham: handle sequence data parallel kwargs[VisionTransformerKwargs.cu_seqlens_q] = torch.tensor( cu_seqlens, device=self._tensor_space.distributed.device, dtype=torch.int32 @@ -259,3 +258,4 @@ def preprocess(self, tokens, kwargs: dict[str, typing.Any]) -> None: ) kwargs[VisionTransformerKwargs.max_seqlen_q] = max_seqlen kwargs[VisionTransformerKwargs.max_seqlen_k] = max_seqlen + kwargs[LanguageModelKwargs.labels] = labels diff --git a/fast_llm/models/gpt/model.py 
b/fast_llm/models/gpt/model.py index 586b511b..45cf4a4f 100644 --- a/fast_llm/models/gpt/model.py +++ b/fast_llm/models/gpt/model.py @@ -407,15 +407,22 @@ def preprocess( kwargs[LanguageModelKwargs.labels] = labels kwargs.update(reference_logits[i]) - if batch.images is not None: + if self._config.vision_encoder.enabled: + batch_images = ( + batch.images if batch.images is not None else [[]] * kwargs[TransformerKwargs.micro_batch_size] + ) kwargs[VisionEncoderKwargs.images] = [ [ img.to(device=self._tensor_space.distributed.device, dtype=torch.uint8, non_blocking=True) for img in images ] - for images in batch.images + for images in batch_images ] - kwargs[VisionEncoderKwargs.image_positions] = batch.image_positions + kwargs[VisionEncoderKwargs.image_positions] = ( + batch.image_positions + if batch.image_positions is not None + else [[]] * kwargs[TransformerKwargs.micro_batch_size] + ) kwargs[LanguageModelKwargs.tokens] = tokens for preprocessor in self._preprocessors: From 067f901bc8bc0b51148f2531d1f929f74b90081a Mon Sep 17 00:00:00 2001 From: root Date: Mon, 2 Jun 2025 18:35:24 +0000 Subject: [PATCH 71/82] speed up sampling --- fast_llm/data/dataset/gpt/sampled.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/fast_llm/data/dataset/gpt/sampled.py b/fast_llm/data/dataset/gpt/sampled.py index f58b009a..2972632c 100644 --- a/fast_llm/data/dataset/gpt/sampled.py +++ b/fast_llm/data/dataset/gpt/sampled.py @@ -166,7 +166,7 @@ def _sample(self) -> None: " Please make sure Fast-LLM is installed correctly." ) long_docs_filter = document_sizes + image_token_sizes > self._parameters.sequence_length + 1 - ignored_documents = sum(long_docs_filter) + ignored_documents = long_docs_filter.sum() if ignored_documents: log_main_rank( f" > {ignored_documents}/{documents_per_epoch} documents are longer than {self._parameters.sequence_length+1} tokens and will be ignored.", From 95526a3d674b89e89c8559757f286d1b48ab89a6 Mon Sep 17 00:00:00 2001 From: Toby Liang Date: Mon, 2 Jun 2025 21:03:45 +0000 Subject: [PATCH 72/82] adding audio start and end tokens --- fast_llm/data/dataset/gpt/config.py | 2 + fast_llm/data/dataset/gpt/sampled.py | 142 +++++++++++------- fast_llm/layers/audio_encoder/config.py | 12 ++ .../layers/audio_encoder/preprocessing.py | 42 ++++++ fast_llm/models/gpt/trainer.py | 2 + 5 files changed, 144 insertions(+), 56 deletions(-) diff --git a/fast_llm/data/dataset/gpt/config.py b/fast_llm/data/dataset/gpt/config.py index 0b72402f..9819f4e8 100644 --- a/fast_llm/data/dataset/gpt/config.py +++ b/fast_llm/data/dataset/gpt/config.py @@ -81,6 +81,8 @@ class GPTSamplingParameters(SamplingParameters): aud_sampling_rate: int | None = None image_break_token: int | None = None image_end_token: int | None = None + audio_start_token: int | None = None + audio_end_token: int | None = None # How many extra tokens to add to the sequence length. # This is used to provide labels even for the last tokens in the sequence. 
extra_tokens: int = 1 diff --git a/fast_llm/data/dataset/gpt/sampled.py b/fast_llm/data/dataset/gpt/sampled.py index b2d30379..63337e67 100644 --- a/fast_llm/data/dataset/gpt/sampled.py +++ b/fast_llm/data/dataset/gpt/sampled.py @@ -14,6 +14,7 @@ from fast_llm.data.dataset.gpt.indexed import GPTIndexedDataset from fast_llm.engine.config_utils.data_type import DataType, get_unsigned_integer_type from fast_llm.engine.config_utils.run import log_main_rank +from fast_llm.layers.audio_encoder.preprocessing import get_num_audio_tokens from fast_llm.layers.vision_encoder.preprocessing import get_num_image_tokens, get_resize_dims from fast_llm.utils import Assert, div @@ -132,38 +133,6 @@ def __init__( # No barrier yet to allow running in parallel. # There needs to be one before calling `__getitem__`, normally handled through `GPTData`. - def _compute_audio_token_size(self, sizes): - if len(sizes) == 0: # sample has no audio - return sizes, False - to_filter = False - # account for padding - if self._parameters.aud_padding_duration > 0: - raw_audio_seq_length = self._parameters.aud_padding_duration * self._parameters.aud_sampling_rate - sizes = sizes.copy() # original is read-only - to_filter = bool(np.any(sizes > raw_audio_seq_length)) # filter sample where any audio is too long - sizes.fill(raw_audio_seq_length) # set all audio sizes to padded amount - - # account for mel spectogram, convolution, downsampling k - audio_token_size_arr = sizes // 160 # default hop length TODO Toby: check divisible? - audio_token_size_arr = audio_token_size_arr // ( - 2 * self._parameters.aud_downsampling_k - ) # convolution (2) * downsampling - return audio_token_size_arr, to_filter - - def apply_audio_padding(self, audio): - if len(audio) == 0: - return audio - # TODO Toby: check 2d - padded_audio = [] - if self._parameters.aud_padding_duration > 0: - raw_audio_seq_length = self._parameters.aud_padding_duration * self._parameters.aud_sampling_rate - for aud in audio: - padded = np.pad(aud, (0, raw_audio_seq_length - len(aud)), mode="constant", constant_values=0) - padded_audio.append(padded) - return padded_audio - else: - return audio - def _sample(self) -> None: """ Create a `GPTSampledDataset` with the requested parameters. 
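The two helpers removed above (`_compute_audio_token_size` and `apply_audio_padding`) reappear later in this patch as module-level functions (`get_num_audio_tokens`, `apply_audio_padding`) in fast_llm/layers/audio_encoder/preprocessing.py. For quick reference, a standalone sketch of the token-count arithmetic they implement; the 16 kHz sampling rate, 30 s padding window and downsampling factor below are illustrative defaults, not values fixed by this patch:

import numpy as np

def estimate_audio_tokens(num_clips: int, pad_seconds: int = 30, sample_rate: int = 16000,
                          downsampling_k: int = 5, add_start: bool = True, add_end: bool = True) -> np.ndarray:
    # Each clip is padded to a fixed duration, the mel spectrogram uses a
    # 160-sample hop, the encoder convolutions halve the frame count (stride 2),
    # and the adapter stacks k frames per language-model token.
    sizes = np.full(num_clips, pad_seconds * sample_rate, dtype=np.int64)
    frames = sizes // 160                    # mel frames per padded clip
    tokens = frames // (2 * downsampling_k)  # conv stride * adapter downsampling
    return tokens + int(add_start) + int(add_end)

# Example: 30 s at 16 kHz with k=5 gives 480000 // 160 // 10 + 2 = 302 placeholder tokens per clip.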
@@ -198,7 +167,14 @@ def _sample(self) -> None: audio_token_sizes = torch.zeros_like(document_sizes).to(self._device) long_audio_filter = torch.zeros_like(document_sizes, dtype=torch.bool) # longer than audio padding for i, sizes in enumerate(audio_sizes): - audio_token_size_arr, to_filter = self._compute_audio_token_size(sizes) + audio_token_size_arr, to_filter = get_num_audio_tokens( + sizes, + self._parameters.aud_padding_duration, + self._parameters.aud_sampling_rate, + self._parameters.aud_downsampling_k, + self._parameters.audio_start_token, + self._parameters.audio_end_token, + ) audio_token_sizes[i] = audio_token_size_arr.sum() long_audio_filter[i] = to_filter @@ -371,7 +347,7 @@ def _sample(self) -> None: unshuffled_tokens = 0 if not self._truncate_documents: - yaml_data["unshuffled_tokens"] = unshuffled_tokens.item() + yaml_data["unshuffled_tokens"] = unshuffled_tokens.item() * unshuffled_epochs self._load_yaml_data(yaml_data) if self._yaml_path is not None: self._yaml_path.parent.mkdir(parents=True, exist_ok=True) @@ -520,7 +496,14 @@ def __getitem__(self, index: int) -> typing.Any: ] image_tokens = sum(image_sizes) - audio_token_size_arr, _ = self._compute_audio_token_size(audio_lengths) + audio_token_size_arr, _ = get_num_audio_tokens( + audio_lengths, + self._parameters.aud_padding_duration, + self._parameters.aud_sampling_rate, + self._parameters.aud_downsampling_k, + self._parameters.audio_start_token, + self._parameters.audio_end_token, + ) audio_tokens = audio_token_size_arr.sum() document_size = text_size + image_tokens + audio_tokens @@ -585,14 +568,16 @@ def __getitem__(self, index: int) -> typing.Any: [(pos, "audio", idx) for idx, pos in enumerate(sample.audio_positions)] ) + token_ids_per_sample = [] + special_mm_tok_loss_masking_spans = np.empty((0, 2), dtype=np.int32) multimodal_positions.sort(key=lambda x: x[0]) for global_idx, (mm_position, mm_type, source_idx) in enumerate(multimodal_positions): # Add placeholders for image and audio tokens tokens - token_ids.append(sample.token_ids[start_pos:mm_position]) + token_ids_per_sample.append(sample.token_ids[start_pos:mm_position]) + text_tokens_added += len(token_ids_per_sample[-1]) if mm_type == "image": # image_positions.append(im_positions + len(token_ids) + image_tokens_added) # Add placeholders for image tokens - text_tokens_added += len(token_ids[-1]) image_positions.append(text_tokens_added + mm_tokens_added) if self._parameters.image_break_token is not None: height, width = resized_image_lengths[source_idx] @@ -616,21 +601,55 @@ def __getitem__(self, index: int) -> typing.Any: image_token_array = np.full((image_sizes[source_idx],), -100, dtype=np.int64) if self._parameters.image_end_token is not None: image_token_array[-1] = self._parameters.image_end_token - token_ids.append(image_token_array) + token_ids_per_sample.append(image_token_array) mm_tokens_added += image_sizes[source_idx] elif mm_type == "audio": - audio_pos = sum(t.size for t in token_ids) # includes mm tokens added already + # audio_pos = sum(t.size for t in token_ids) # includes mm tokens added already + # compute audio position + start_token_offset = int(self._parameters.audio_start_token is not None) + audio_pos = text_tokens_added + mm_tokens_added + start_token_offset audio_positions.append(audio_pos) - token_ids.append( - np.full((audio_token_size_arr[source_idx],), -100, dtype=np.int64) - ) # TODO Toby: index doesnt work here - mm_tokens_added += audio_tokens + + # compute number of special tokens + num_audio_special_tokens = 
int(self._parameters.audio_start_token is not None) + int( + self._parameters.audio_end_token is not None + ) + + # add start tokens + if self._parameters.audio_start_token is not None: + token_ids_per_sample.append(np.array([self._parameters.audio_start_token])) + # add to loss masking spans + special_mm_tok_loss_masking_spans = np.append( + special_mm_tok_loss_masking_spans, [[audio_pos - 1, audio_pos - 1]], axis=0 + ) + # sample.loss_masking_spans = np.append(sample.loss_masking_spans, [[audio_pos-1, audio_pos-1]], axis=0) + + # add audio pad tokens + num_audio_pad_tokens = audio_token_size_arr[source_idx] + num_audio_pad_tokens -= num_audio_special_tokens # ignore start/end tokens for padding + audio_padding_tokens = np.full((num_audio_pad_tokens,), -100, dtype=np.int64) + token_ids_per_sample.append(audio_padding_tokens) + + # add end token + if self._parameters.audio_end_token is not None: + token_ids_per_sample.append(np.array([self._parameters.audio_end_token])) + # add to loss masking spans + special_mm_tok_loss_masking_spans = np.append( + special_mm_tok_loss_masking_spans, + [[audio_pos + num_audio_pad_tokens, audio_pos + num_audio_pad_tokens]], + axis=0, + ) + # sample.loss_masking_spans = np.append(sample.loss_masking_spans, [[audio_pos + num_audio_pad_tokens, audio_pos + num_audio_pad_tokens]], axis=0) + + # update mm tokens added + mm_tokens_added += num_audio_special_tokens + num_audio_pad_tokens start_pos = mm_position - token_ids.append(sample.token_ids[start_pos:]) + # add remaining text tokens + token_ids_per_sample.append(sample.token_ids[start_pos:]) + text_tokens_added += len(token_ids_per_sample[-1]) - # TODO Soham: add offsets for loss masking spans - text_tokens_added += len(token_ids[-1]) + token_ids.append(np.concatenate(token_ids_per_sample)) if sample.images: images.append(sample.images) else: @@ -643,22 +662,25 @@ def __getitem__(self, index: int) -> typing.Any: if self._parameters.use_loss_masking_spans: mm_idx = 0 - total_mm_tokens = 0 - for loss_masking_span in sample.loss_masking_spans: # TODO: check these must be sorted - mm_tokens_in_span = 0 + mm_tokens_before_span = 0 + + # sort by start of span + sample.loss_masking_spans = sample.loss_masking_spans[sample.loss_masking_spans[:, 0].argsort()] + for loss_masking_span in sample.loss_masking_spans: + mm_tokens_within_span = 0 mm_position, mm_type, source_idx = ( multimodal_positions[mm_idx] if mm_idx < len(multimodal_positions) else (float("inf"), _, _) ) - # increment mm_idx until span is reached + # increment mm_idx until span is reached, track mm tokens before span while mm_position < loss_masking_span[0]: if mm_type == "image": num_mm_tokens = image_sizes[source_idx] elif mm_type == "audio": num_mm_tokens = audio_token_size_arr[source_idx] - total_mm_tokens += num_mm_tokens + mm_tokens_before_span += num_mm_tokens mm_idx += 1 mm_position, mm_type, source_idx = ( multimodal_positions[mm_idx] @@ -672,25 +694,33 @@ def __getitem__(self, index: int) -> typing.Any: num_mm_tokens = image_sizes[source_idx] elif mm_type == "audio": num_mm_tokens = audio_token_size_arr[source_idx] - mm_tokens_in_span += num_mm_tokens + mm_tokens_within_span += num_mm_tokens mm_idx += 1 mm_position, mm_type, source_idx = ( multimodal_positions[mm_idx] if mm_idx < len(multimodal_positions) else (float("inf"), _, _) ) - loss_masking_span[0] += total_mm_tokens # increment by all mm tokens before span - loss_masking_span[1] += total_mm_tokens + mm_tokens_in_span - total_mm_tokens += mm_tokens_in_span + loss_masking_span[0] += 
mm_tokens_before_span # increment by all mm tokens before span + loss_masking_span[1] += mm_tokens_before_span + mm_tokens_within_span + mm_tokens_before_span += mm_tokens_within_span span = np.clip( loss_masking_span + token_count - token_start, 0, self._parameters.sequence_length + self._parameters.extra_tokens, ) - if span[1] > span[0]: + if span[1] >= span[0]: loss_masking_spans.append(span) + for span in special_mm_tok_loss_masking_spans: + # span = np.clip( + # loss_masking_span + token_count - token_start, + # 0, + # self._parameters.sequence_length + self._parameters.extra_tokens, + # ) + if span[1] >= span[0]: + loss_masking_spans.append(span) # Go to the next document. document_sampling_index += 1 token_count += document_size diff --git a/fast_llm/layers/audio_encoder/config.py b/fast_llm/layers/audio_encoder/config.py index 9503c60c..95665901 100644 --- a/fast_llm/layers/audio_encoder/config.py +++ b/fast_llm/layers/audio_encoder/config.py @@ -122,6 +122,18 @@ class AudioEncoderConfig(BaseModelConfig): hint=FieldHint.feature, ) + # audio start/end tokens + audio_start_token: int | None = Field( + default=None, + desc="Token id for audio start.", + hint=FieldHint.optional, + ) + audio_end_token: int | None = Field( + default=None, + desc="Token id for audio end.", + hint=FieldHint.optional, + ) + def setup_tensor_space(self, tensor_space: TensorSpace): tensor_space.add_tensor_dim(TensorDim(AudioEncoderDimNames.in_channels, self.num_mel_bins)) tensor_space.add_tensor_dim(TensorDim(AudioEncoderDimNames.out_channels, self.transformer.hidden_size)) diff --git a/fast_llm/layers/audio_encoder/preprocessing.py b/fast_llm/layers/audio_encoder/preprocessing.py index 506e026e..8959837d 100644 --- a/fast_llm/layers/audio_encoder/preprocessing.py +++ b/fast_llm/layers/audio_encoder/preprocessing.py @@ -1,5 +1,6 @@ import typing +import numpy as np import torch from transformers import WhisperFeatureExtractor @@ -8,6 +9,47 @@ from fast_llm.layers.audio_encoder.config import AudioEncoderConfig, AudioEncoderKwargs +def get_num_audio_tokens( + sizes, aud_padding_duration, aud_sampling_rate, aud_downsampling_k, audio_start_token, audio_end_token +): + if len(sizes) == 0: # sample has no audio + return sizes, False + to_filter = False + # account for padding + if aud_padding_duration > 0: + raw_audio_seq_length = aud_padding_duration * aud_sampling_rate + sizes = sizes.copy() # original is read-only + to_filter = bool(np.any(sizes > raw_audio_seq_length)) # filter sample where any audio is too long + sizes.fill(raw_audio_seq_length) # set all audio sizes to padded amount + + # account for mel spectogram, convolution, downsampling k + audio_token_size_arr = sizes // 160 # default hop length TODO Toby: check divisible? 
+ audio_token_size_arr = audio_token_size_arr // ( + 2 * aud_downsampling_k + ) # convolution (2 stride) * downsampling TODO Toby: make configurable convolution + + if audio_start_token is not None: + audio_token_size_arr += 1 + if audio_end_token is not None: + audio_token_size_arr += 1 + return audio_token_size_arr, to_filter + + +def apply_audio_padding(audio, aud_padding_duration, aud_sampling_rate): + if len(audio) == 0: + return audio + # TODO Toby: check 2d + padded_audio = [] + if aud_padding_duration > 0: + raw_audio_seq_length = aud_padding_duration * aud_sampling_rate + for aud in audio: + padded = np.pad(aud, (0, raw_audio_seq_length - len(aud)), mode="constant", constant_values=0) + padded_audio.append(padded) + return padded_audio + else: + return audio + + class AudioPreprocessor(Preprocessor): def __init__(self, config: AudioEncoderConfig, tensor_space: TensorSpace): self._config = config diff --git a/fast_llm/models/gpt/trainer.py b/fast_llm/models/gpt/trainer.py index c9025f7f..b4a3036f 100644 --- a/fast_llm/models/gpt/trainer.py +++ b/fast_llm/models/gpt/trainer.py @@ -47,6 +47,8 @@ def _get_sampling_parameters( "aud_downsampling_k": self._config.model.base_model.audio_encoder.aud_downsampling_k, "aud_padding_duration": self._config.batch.aud_padding_duration, "aud_sampling_rate": self._config.model.base_model.audio_encoder.aud_sampling_rate, + "audio_start_token": self._config.model.base_model.audio_encoder.audio_start_token, + "audio_end_token": self._config.model.base_model.audio_encoder.audio_end_token, } ) return parameters if _return_dict else GPTSamplingParameters(**parameters) From fb23ef8cceb15d43612c042c8bcf43070a68d9ed Mon Sep 17 00:00:00 2001 From: Toby Liang Date: Tue, 3 Jun 2025 17:04:29 +0000 Subject: [PATCH 73/82] conversion changes --- fast_llm/models/gpt/conversion.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/fast_llm/models/gpt/conversion.py b/fast_llm/models/gpt/conversion.py index ad348ce9..568c7808 100644 --- a/fast_llm/models/gpt/conversion.py +++ b/fast_llm/models/gpt/conversion.py @@ -1075,19 +1075,19 @@ def _import_config(cls, config: dict[str, typing.Any]) -> GPTBaseModelConfig: def _export_config(cls, config: BaseModelConfig) -> dict[str, typing.Any]: # TODO Toby: implement for audio exported_config = {} - vision_handler_cls = AutoGPTHuggingfaceCheckpointHandler.get_handler_class(cls.format.vision_name) + audio_handler_class = AutoGPTHuggingfaceCheckpointHandler.get_handler_class(cls.format.audio_name) text_handler_cls = AutoGPTHuggingfaceCheckpointHandler.get_handler_class(cls.format.text_name) - for converter in vision_handler_cls._create_config_converters(): + for converter in audio_handler_class._create_config_converters(): try: values = converter.export_params( tuple( - cls._get_fast_llm_attribute(config, ("vision_encoder",) + fast_llm_name) + cls._get_fast_llm_attribute(config, ("audio_encoder",) + fast_llm_name) for fast_llm_name in converter.fast_llm_names ) ) for export_name, value in zip(converter.export_names, values, strict=True): if value is not MISSING: - set_nested_dict_value(exported_config, ("vision_config",) + export_name, value) + set_nested_dict_value(exported_config, ("audio_config",) + export_name, value) except Exception as e: raise RuntimeError(f"Config conversion failed for converter {converter}", *e.args) From d7d11352bfa828352f2d9d338390d33b95693683 Mon Sep 17 00:00:00 2001 From: Toby Liang Date: Tue, 3 Jun 2025 22:43:32 +0000 Subject: [PATCH 74/82] adding data prep sharding --- 
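For context on the change below: shard boundaries are derived from a token-equivalent byte count, and this patch folds audio into that count the same way pixels already are. A minimal sketch of the arithmetic, assuming a uint16 token dtype and a placeholder tokens_per_shard value:

import numpy as np

def estimate_num_shards(num_tokens: int, num_pixels: int, num_audio_samples: int,
                        tokens_per_shard: int = 10**9, token_dtype=np.uint16) -> int:
    # Pixels are stored as one byte each and audio samples as float32, so both
    # are converted to "token equivalents" through the token dtype's item size.
    token_bytes = np.dtype(token_dtype).itemsize
    total = num_tokens
    total += num_pixels // token_bytes
    total += num_audio_samples * np.dtype(np.float32).itemsize // token_bytes
    return int(np.ceil(total / tokens_per_shard))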
fast_llm/data/preparator/gpt_memmap/prepare.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/fast_llm/data/preparator/gpt_memmap/prepare.py b/fast_llm/data/preparator/gpt_memmap/prepare.py index 283a6bf8..cb34cc91 100644 --- a/fast_llm/data/preparator/gpt_memmap/prepare.py +++ b/fast_llm/data/preparator/gpt_memmap/prepare.py @@ -84,6 +84,11 @@ def _tokenize_batch(self, batch: dict[str, list[typing.Any]]) -> dict[str, list[ width, height = im.size num_pixels[idx] += width * height * 3 + num_audio = [0] * len(input_ids) + for idx, audio_lst in enumerate(batch.get(self._config.dataset.audio, [])): + for audio in audio_lst: + num_audio[idx] += len(audio) + return { "input_ids": input_ids, "image_positions": image_token_positions, @@ -91,6 +96,7 @@ def _tokenize_batch(self, batch: dict[str, list[typing.Any]]) -> dict[str, list[ "token_spans": token_spans, "num_tokens": num_tokens, "num_pixels": num_pixels, + "num_audio": num_audio, } # def _tokenize_batch_with_spans(self, batch: dict[str, list[typing.Any]]) -> dict[str, list[typing.Any]]: @@ -296,6 +302,7 @@ def run(self) -> None: batched=True, num_proc=self._config.tokenize_workers, desc="Tokenizing batches", + # load_from_cache_file=False # TODO Toby: remove ) # Calculate total number of tokens @@ -305,7 +312,13 @@ def run(self) -> None: if self._config.dataset.images else 0 ) + total_audio = ( + sum(tqdm.tqdm(tokenized_dataset["num_audio"], desc="Counting audio", unit="audio")) + if self._config.dataset.audio + else 0 + ) total_tokens += total_pixels // np.dtype(self._data_type.numpy).itemsize + total_tokens += total_audio * np.float32().itemsize // np.dtype(self._data_type.numpy).itemsize # Split dataset into shards based on number of tokens num_shards = int(np.ceil(total_tokens / self._config.tokens_per_shard)) From 012a6364dfa98c54e7424ab748856568a0c21eae Mon Sep 17 00:00:00 2001 From: Toby Liang Date: Fri, 6 Jun 2025 02:24:33 +0000 Subject: [PATCH 75/82] faster mel sepc --- fast_llm/layers/audio_encoder/preprocessing.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/fast_llm/layers/audio_encoder/preprocessing.py b/fast_llm/layers/audio_encoder/preprocessing.py index 8959837d..916c97bb 100644 --- a/fast_llm/layers/audio_encoder/preprocessing.py +++ b/fast_llm/layers/audio_encoder/preprocessing.py @@ -109,6 +109,7 @@ def preprocess(self, tokens, kwargs: dict[str, typing.Any]) -> None: sampling_rate=self._config.aud_sampling_rate, return_tensors="pt", max_length=30 * self._config.aud_sampling_rate, + device=self._tensor_space.distributed.device, )["input_features"] ) audio_mel = torch.stack(audio_mel, dim=0).squeeze(1) @@ -124,4 +125,11 @@ def preprocess(self, tokens, kwargs: dict[str, typing.Any]) -> None: # kwargs[self._transformer_kwargs.attention_mask_value] = self._mask_value audio_mel = audio_mel.to(self._tensor_space.distributed.device) + + # PAD_TO = 100 + # padding_size = PAD_TO - audio_mel.size(0) + # padding = torch.zeros(padding_size, audio_mel.size(1), audio_mel.size(2), dtype=audio_mel.dtype, device=audio_mel.device) + + # audio_mel = torch.cat((audio_mel, padding), dim=0) + kwargs[AudioEncoderKwargs.audio_mel] = audio_mel From c664444bf323740a3d2488956bb9e5a933c2aa9b Mon Sep 17 00:00:00 2001 From: Toby Liang Date: Thu, 12 Jun 2025 21:04:01 +0000 Subject: [PATCH 76/82] adding num audio to config --- fast_llm/data/preparator/gpt_memmap/prepare.py | 1 + 1 file changed, 1 insertion(+) diff --git a/fast_llm/data/preparator/gpt_memmap/prepare.py b/fast_llm/data/preparator/gpt_memmap/prepare.py index 
cb34cc91..832af202 100644 --- a/fast_llm/data/preparator/gpt_memmap/prepare.py +++ b/fast_llm/data/preparator/gpt_memmap/prepare.py @@ -171,6 +171,7 @@ def _document_generator(): "num_documents": len(shard_dataset), # Use the length of the shard dataset directly "num_tokens": sum(len(doc["input_ids"]) for doc in shard_dataset), "num_pixels": sum(doc["num_pixels"] for doc in shard_dataset), + "num_audio": sum(doc["num_audio"] for doc in shard_dataset), } ) From ba7393970f0d9afbcd4fd241beaf7a5ec1ee1834 Mon Sep 17 00:00:00 2001 From: Toby Liang Date: Thu, 12 Jun 2025 22:41:47 +0000 Subject: [PATCH 77/82] audio encoder padding updates --- fast_llm/data/data/gpt/data.py | 2 +- fast_llm/data/dataset/gpt/sampled.py | 10 ++- .../layers/audio_encoder/preprocessing.py | 76 ++++++++++--------- fast_llm/models/gpt/model.py | 19 ++--- 4 files changed, 56 insertions(+), 51 deletions(-) diff --git a/fast_llm/data/data/gpt/data.py b/fast_llm/data/data/gpt/data.py index a4b183ca..8f700978 100644 --- a/fast_llm/data/data/gpt/data.py +++ b/fast_llm/data/data/gpt/data.py @@ -64,7 +64,7 @@ def gpt_data_collate_fn(batch: list[GPTSample], sampling_parameters: GPTSampling has_audio = False batch_audio = [] for sample in batch: - if sample.audio is not None and len(sample.audio_positions) > 0: + if sample.audio is not None and sample.audio_positions is not None: batch_audio.append([torch.from_numpy(audio) for audio in sample.audio]) has_audio = True else: diff --git a/fast_llm/data/dataset/gpt/sampled.py b/fast_llm/data/dataset/gpt/sampled.py index 628de35f..2c64c47e 100644 --- a/fast_llm/data/dataset/gpt/sampled.py +++ b/fast_llm/data/dataset/gpt/sampled.py @@ -243,7 +243,7 @@ def _sample(self) -> None: shuffled_documents = documents_per_epoch * shuffled_epochs unshuffled_epochs = num_epochs - shuffled_epochs - yaml_data = { + yaml_data = { # TODO Toby: add audio "dataset": { "name": self._indexed_dataset.name, "documents_per_epoch": documents_per_epoch, @@ -504,7 +504,7 @@ def __getitem__(self, index: int) -> typing.Any: self._parameters.audio_start_token, self._parameters.audio_end_token, ) - audio_tokens = audio_token_size_arr.sum() + audio_tokens = int(audio_token_size_arr.sum()) document_size = text_size + image_tokens + audio_tokens @@ -705,7 +705,7 @@ def __getitem__(self, index: int) -> typing.Any: mm_tokens_before_span += mm_tokens_within_span span = np.clip( - loss_masking_span + token_count - token_start, + loss_masking_span + int(token_count) - int(token_start), 0, self._parameters.sequence_length + self._parameters.extra_tokens, ) @@ -743,13 +743,15 @@ def __getitem__(self, index: int) -> typing.Any: audio_positions = np.array(audio_positions) if audio_positions else None # Assert.eq(len(token_ids), self._parameters.sequence_length + self._parameters.extra_tokens) + # # TODO: Toby remove/comment after testing (for testing only first sample) + # loss_masking_spans = np.append(loss_masking_spans, [[sequence_lengths[0], sequence_lengths[:-1].sum()]], axis=0) return GPTSample( token_ids=token_ids, loss_masking_spans=loss_masking_spans, sequence_lengths=sequence_lengths, images=images, image_positions=image_positions, - audio=audio, + audio=audio if len(audio) > 0 else None, audio_positions=audio_positions, ) diff --git a/fast_llm/layers/audio_encoder/preprocessing.py b/fast_llm/layers/audio_encoder/preprocessing.py index 916c97bb..f6a696d7 100644 --- a/fast_llm/layers/audio_encoder/preprocessing.py +++ b/fast_llm/layers/audio_encoder/preprocessing.py @@ -1,3 +1,4 @@ +import math import typing import 
numpy as np @@ -13,7 +14,7 @@ def get_num_audio_tokens( sizes, aud_padding_duration, aud_sampling_rate, aud_downsampling_k, audio_start_token, audio_end_token ): if len(sizes) == 0: # sample has no audio - return sizes, False + return np.array(sizes), False to_filter = False # account for padding if aud_padding_duration > 0: @@ -88,33 +89,43 @@ def preprocess_meta(self, kwargs: dict[str, typing.Any]) -> None: pass def preprocess(self, tokens, kwargs: dict[str, typing.Any]) -> None: - audio_raw = kwargs[AudioEncoderKwargs.audio] - flattened_audio = [ - audio_arr for sequence in audio_raw for audio_arr in sequence - ] # flatten in the batch dimension - # flattened_audio_tensor = torch.stack(flattened_audio, dim=0) - # audio_inputs = self.feature_extractor(audio_raw, sampling_rate=16000, return_tensors="pt") - # self.mel_transform.to(self._tensor_space.distributed.device) - - # audio_mel = self.mel_transform(flattened_audio_tensor) - # flattened_audio_tensor = np.stack(flattened_audio, axis=0) - # audio_inputs = self.feature_extractor(flattened_audio_tensor, sampling_rate=16000, return_tensors="pt") - # audio_mel = audio_inputs['input_features'] - + # check if audio is in batch audio_mel = [] - for audio in flattened_audio: - audio_mel.append( - self.feature_extractor( - audio, - sampling_rate=self._config.aud_sampling_rate, - return_tensors="pt", - max_length=30 * self._config.aud_sampling_rate, - device=self._tensor_space.distributed.device, - )["input_features"] - ) - audio_mel = torch.stack(audio_mel, dim=0).squeeze(1) - - # audio_mel = audio_mel[:, :, :-1] # TODO Toby: check this! + if AudioEncoderKwargs.audio in kwargs: + audio_raw = kwargs[AudioEncoderKwargs.audio] + flattened_audio = [ + audio_arr for sequence in audio_raw for audio_arr in sequence + ] # flatten in the batch dimension + + for audio in flattened_audio: + audio_mel.append( + self.feature_extractor( + audio, + sampling_rate=self._config.aud_sampling_rate, + return_tensors="pt", + max_length=30 * self._config.aud_sampling_rate, + device=self._tensor_space.distributed.device, + )["input_features"] + ) + audio_mel = torch.stack(audio_mel, dim=0).squeeze(1) + curr_size = audio_mel.size(0) + else: + audio_mel = torch.tensor(audio_mel, dtype=torch.float32) + curr_size = 0 + + max_pad = math.ceil(kwargs["sequence_length"] / (kwargs["audio_encoder_sequence_length"] // 5)) + padding_size = max_pad - curr_size + padding = torch.zeros( + padding_size, + self.feature_extractor.feature_size, + self.feature_extractor.nb_max_frames, + dtype=audio_mel.dtype, + device=audio_mel.device, + ) + audio_mel = torch.cat((audio_mel, padding), dim=0) + audio_mel = audio_mel.to(self._tensor_space.distributed.device) + + kwargs[AudioEncoderKwargs.audio_mel] = audio_mel # # set attention mask # TODO Toby: fix backup attention # sequence_k = kwargs[self._transformer_kwargs.sequence_k_dim].size @@ -123,13 +134,4 @@ def preprocess(self, tokens, kwargs: dict[str, typing.Any]) -> None: # None, None, sequence_k - sequence_q : sequence_k, None, :sequence_k # ] # kwargs[self._transformer_kwargs.attention_mask_value] = self._mask_value - - audio_mel = audio_mel.to(self._tensor_space.distributed.device) - - # PAD_TO = 100 - # padding_size = PAD_TO - audio_mel.size(0) - # padding = torch.zeros(padding_size, audio_mel.size(1), audio_mel.size(2), dtype=audio_mel.dtype, device=audio_mel.device) - - # audio_mel = torch.cat((audio_mel, padding), dim=0) - - kwargs[AudioEncoderKwargs.audio_mel] = audio_mel + # audio_mel = torch.rand(len(flattened_audio), 80, 3000) 
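The hunk above pads the mel batch to a static number of clips per micro-batch; the hard-coded divisor 5 is the audio downsampling factor, which the "configurable max pad" patch below replaces with `self._config.aud_downsampling_k`. A condensed sketch of the shape logic, assuming the Whisper feature-extractor defaults of 80 mel bins and 3000 frames:

import math

import torch

def pad_mel_batch(audio_mel: torch.Tensor, sequence_length: int, encoder_sequence_length: int,
                  downsampling_k: int = 5, feature_size: int = 80, nb_max_frames: int = 3000) -> torch.Tensor:
    # audio_mel has shape (num_clips, feature_size, nb_max_frames). Pad with
    # all-zero mel features up to the largest number of clips whose placeholder
    # tokens could fit in the text sequence, so tensor shapes stay constant.
    tokens_per_clip = encoder_sequence_length // downsampling_k
    max_clips = max(math.ceil(sequence_length / tokens_per_clip), audio_mel.size(0))
    padding = torch.zeros(max_clips - audio_mel.size(0), feature_size, nb_max_frames,
                          dtype=audio_mel.dtype, device=audio_mel.device)
    return torch.cat((audio_mel, padding), dim=0)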
diff --git a/fast_llm/models/gpt/model.py b/fast_llm/models/gpt/model.py index bc04fc6e..05b15e4d 100644 --- a/fast_llm/models/gpt/model.py +++ b/fast_llm/models/gpt/model.py @@ -484,22 +484,23 @@ def preprocess( ) kwargs[LanguageModelKwargs.tokens] = tokens - if batch.audio is not None: - kwargs[AudioEncoderKwargs.audio] = [ - [aud.to(device="cpu", dtype=torch.float32, non_blocking=True) for aud in audio] - for audio in batch.audio - ] - kwargs[AudioEncoderKwargs.audio_positions] = batch.audio_positions + if self._config.audio_encoder.enabled: + if batch.audio is not None: + kwargs[AudioEncoderKwargs.audio] = [ + [aud.to(device="cpu", dtype=torch.float32, non_blocking=True) for aud in audio] + for audio in batch.audio + ] + kwargs[AudioEncoderKwargs.audio_positions] = batch.audio_positions kwargs[LanguageModelKwargs.tokens] = tokens for preprocessor in self._preprocessors: preprocessor.preprocess(tokens, kwargs) image_patches = kwargs.get(VisionEncoderKwargs.image_patches, None) audio_mel = kwargs.get(AudioEncoderKwargs.audio_mel, None) - if image_patches is not None: - preprocessed.append((image_patches, kwargs)) - elif audio_mel is not None: + if audio_mel is not None: preprocessed.append((audio_mel, kwargs)) + elif image_patches is not None: + preprocessed.append((image_patches, kwargs)) else: preprocessed.append((tokens, kwargs)) From 5667a0a3d1a8cdcf7f36253976c40e0cbc8bf7c6 Mon Sep 17 00:00:00 2001 From: Toby Liang Date: Thu, 12 Jun 2025 23:02:37 +0000 Subject: [PATCH 78/82] configurable max pad --- fast_llm/layers/audio_encoder/preprocessing.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/fast_llm/layers/audio_encoder/preprocessing.py b/fast_llm/layers/audio_encoder/preprocessing.py index f6a696d7..a326dc60 100644 --- a/fast_llm/layers/audio_encoder/preprocessing.py +++ b/fast_llm/layers/audio_encoder/preprocessing.py @@ -113,7 +113,9 @@ def preprocess(self, tokens, kwargs: dict[str, typing.Any]) -> None: audio_mel = torch.tensor(audio_mel, dtype=torch.float32) curr_size = 0 - max_pad = math.ceil(kwargs["sequence_length"] / (kwargs["audio_encoder_sequence_length"] // 5)) + max_pad = math.ceil( + kwargs["sequence_length"] / (kwargs["audio_encoder_sequence_length"] // self._config.aud_downsampling_k) + ) padding_size = max_pad - curr_size padding = torch.zeros( padding_size, From 9f68a5e47361f10bf8429c529d8d0d61b008c142 Mon Sep 17 00:00:00 2001 From: Toby Liang Date: Mon, 16 Jun 2025 20:18:47 +0000 Subject: [PATCH 79/82] small fix --- fast_llm/data/dataset/gpt/config.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/fast_llm/data/dataset/gpt/config.py b/fast_llm/data/dataset/gpt/config.py index 9819f4e8..357623b1 100644 --- a/fast_llm/data/dataset/gpt/config.py +++ b/fast_llm/data/dataset/gpt/config.py @@ -209,6 +209,11 @@ class GPTMemmapDatasetConfig(GPTIndexedDatasetConfig): desc="Expected number of pixels in the dataset.", hint=FieldHint.optional, ) + num_audio: int | None = Field( + default=None, + desc="Expected number of audio in the dataset.", + hint=FieldHint.optional, + ) def build(self) -> "GPTMemmapDataset": from fast_llm.data.dataset.gpt.memmap import GPTMemmapDataset From c286f8d649965128ec284a8324255bb487415fe9 Mon Sep 17 00:00:00 2001 From: Toby Liang Date: Wed, 18 Jun 2025 23:34:07 +0000 Subject: [PATCH 80/82] debugging updates --- fast_llm/data/dataset/gpt/memmap.py | 23 +++++++++---- .../data/preparator/gpt_memmap/prepare.py | 12 +++---- .../layers/audio_encoder/preprocessing.py | 32 +++++++++++++------ 3 files changed, 46 
insertions(+), 21 deletions(-) diff --git a/fast_llm/data/dataset/gpt/memmap.py b/fast_llm/data/dataset/gpt/memmap.py index 74a7b420..c0353a42 100644 --- a/fast_llm/data/dataset/gpt/memmap.py +++ b/fast_llm/data/dataset/gpt/memmap.py @@ -5,6 +5,8 @@ import numpy as np import PIL.Image +import torchaudio +import soundfile as sf from fast_llm.data.dataset.gpt.indexed import GPTIndexedDataset from fast_llm.data.dataset.gpt.sampled import GPTSample @@ -285,6 +287,10 @@ def get( for audio_length in self._audio_lengths[idx]: audio.append(all_audio[start : start + audio_length]) start += audio_length + + print("Memmap audio length: ", self._audio_lengths[idx]) + print("Memmap audio pos: ", self._audio_positions[idx]) + print("Memmap get audio: ", audio) # TODO Soham: return loss_masking_spans sample_spans = None @@ -427,13 +433,18 @@ def write_dataset(cls, prefix: pathlib.Path | str, documents: typing.Iterable[GP total_im_size += pixels.size im_positions.append(document.image_positions) if document.audio is not None: - n_audio.append(len(document.audio)) - total_audio += len(document.audio) + num_audio = 0 for audio in document.audio: - audio_lengths.append(len(audio)) - bin_stream.write(audio.tobytes(order="C")) - total_aud_size += audio.size - if len(document.audio) > 0: + # audio_arr, _ = torchaudio.load(io.BytesIO(audio["bytes"])) + audio_arr, _ = sf.read(io.BytesIO(audio["bytes"])) + if len(audio_arr) > 0: + num_audio += 1 + audio_lengths.append(len(audio_arr)) + bin_stream.write(audio_arr.tobytes(order="C")) + total_aud_size += audio_arr.size + n_audio.append(num_audio) + total_audio += num_audio + if num_audio > 0: aud_positions += document.audio_positions # Update metadata diff --git a/fast_llm/data/preparator/gpt_memmap/prepare.py b/fast_llm/data/preparator/gpt_memmap/prepare.py index 832af202..888e1b63 100644 --- a/fast_llm/data/preparator/gpt_memmap/prepare.py +++ b/fast_llm/data/preparator/gpt_memmap/prepare.py @@ -154,11 +154,7 @@ def _document_generator(): ), item["images"] if self._config.dataset.images else None, item["image_positions"] if self._config.dataset.image_positions else None, - ( - np.array(item[self._config.dataset.audio], dtype=np.float32) - if self._config.dataset.audio - else None - ), + item[self._config.dataset.audio] if self._config.dataset.audio else None, item[self._config.dataset.audio_positions] if self._config.dataset.audio_positions else None, ) @@ -296,6 +292,8 @@ def run(self) -> None: # decoding bytes to images is slow and should be done only when needed if self._config.dataset.images is not None: dataset = dataset.cast_column("images", datasets.Sequence(datasets.Image(decode=False))) + if self._config.dataset.audio is not None: + dataset = dataset.cast_column("audio", datasets.Sequence(datasets.Audio(decode=False))) # Tokenize the dataset in parallel tokenized_dataset = dataset.map( @@ -303,7 +301,7 @@ def run(self) -> None: batched=True, num_proc=self._config.tokenize_workers, desc="Tokenizing batches", - # load_from_cache_file=False # TODO Toby: remove + load_from_cache_file=False # TODO Toby: remove ) # Calculate total number of tokens @@ -321,6 +319,8 @@ def run(self) -> None: total_tokens += total_pixels // np.dtype(self._data_type.numpy).itemsize total_tokens += total_audio * np.float32().itemsize // np.dtype(self._data_type.numpy).itemsize + tokenized_dataset = tokenized_dataset.shuffle(seed=42) + # Split dataset into shards based on number of tokens num_shards = int(np.ceil(total_tokens / self._config.tokens_per_shard)) shards = [ diff --git 
a/fast_llm/layers/audio_encoder/preprocessing.py b/fast_llm/layers/audio_encoder/preprocessing.py index a326dc60..21262fe9 100644 --- a/fast_llm/layers/audio_encoder/preprocessing.py +++ b/fast_llm/layers/audio_encoder/preprocessing.py @@ -92,10 +92,12 @@ def preprocess(self, tokens, kwargs: dict[str, typing.Any]) -> None: # check if audio is in batch audio_mel = [] if AudioEncoderKwargs.audio in kwargs: + print("Preprocessing Contains Audio") audio_raw = kwargs[AudioEncoderKwargs.audio] flattened_audio = [ audio_arr for sequence in audio_raw for audio_arr in sequence ] # flatten in the batch dimension + print("Preprocessing Flattened Audio: ", flattened_audio) for audio in flattened_audio: audio_mel.append( @@ -110,23 +112,35 @@ def preprocess(self, tokens, kwargs: dict[str, typing.Any]) -> None: audio_mel = torch.stack(audio_mel, dim=0).squeeze(1) curr_size = audio_mel.size(0) else: + print("Preprocessing No Audio") audio_mel = torch.tensor(audio_mel, dtype=torch.float32) curr_size = 0 + + print("Preprocessing Audio Mel Raw: ", audio_mel) + # compute max pad max_pad = math.ceil( kwargs["sequence_length"] / (kwargs["audio_encoder_sequence_length"] // self._config.aud_downsampling_k) ) + max_pad = 1 + max_pad = max(max_pad, curr_size) + + # add padding padding_size = max_pad - curr_size - padding = torch.zeros( - padding_size, - self.feature_extractor.feature_size, - self.feature_extractor.nb_max_frames, - dtype=audio_mel.dtype, - device=audio_mel.device, - ) - audio_mel = torch.cat((audio_mel, padding), dim=0) + if padding_size > 0: + padding = torch.zeros( + padding_size, + self.feature_extractor.feature_size, + self.feature_extractor.nb_max_frames, + dtype=audio_mel.dtype, + device=audio_mel.device, + ) + audio_mel = torch.cat((audio_mel, padding), dim=0) + + print("Preprocessing Audio Mel Final: ", audio_mel) + + # move to device audio_mel = audio_mel.to(self._tensor_space.distributed.device) - kwargs[AudioEncoderKwargs.audio_mel] = audio_mel # # set attention mask # TODO Toby: fix backup attention From eb39e7e9a9c4ca78c25a0d072ae513ff35a89cc1 Mon Sep 17 00:00:00 2001 From: Toby Liang Date: Mon, 23 Jun 2025 20:09:52 +0000 Subject: [PATCH 81/82] working 5b changes --- fast_llm/data/dataset/gpt/memmap.py | 5 +---- fast_llm/data/dataset/gpt/sampled.py | 2 +- 2 files changed, 2 insertions(+), 5 deletions(-) diff --git a/fast_llm/data/dataset/gpt/memmap.py b/fast_llm/data/dataset/gpt/memmap.py index c0353a42..c47d3cf6 100644 --- a/fast_llm/data/dataset/gpt/memmap.py +++ b/fast_llm/data/dataset/gpt/memmap.py @@ -287,10 +287,6 @@ def get( for audio_length in self._audio_lengths[idx]: audio.append(all_audio[start : start + audio_length]) start += audio_length - - print("Memmap audio length: ", self._audio_lengths[idx]) - print("Memmap audio pos: ", self._audio_positions[idx]) - print("Memmap get audio: ", audio) # TODO Soham: return loss_masking_spans sample_spans = None @@ -437,6 +433,7 @@ def write_dataset(cls, prefix: pathlib.Path | str, documents: typing.Iterable[GP for audio in document.audio: # audio_arr, _ = torchaudio.load(io.BytesIO(audio["bytes"])) audio_arr, _ = sf.read(io.BytesIO(audio["bytes"])) + audio_arr = audio_arr.astype(np.float32) if len(audio_arr) > 0: num_audio += 1 audio_lengths.append(len(audio_arr)) diff --git a/fast_llm/data/dataset/gpt/sampled.py b/fast_llm/data/dataset/gpt/sampled.py index 2c64c47e..ea8eed40 100644 --- a/fast_llm/data/dataset/gpt/sampled.py +++ b/fast_llm/data/dataset/gpt/sampled.py @@ -347,7 +347,7 @@ def _sample(self) -> None: unshuffled_tokens = 
0 if not self._truncate_documents: - yaml_data["unshuffled_tokens"] = unshuffled_tokens.item() * unshuffled_epochs + yaml_data["unshuffled_tokens"] = unshuffled_tokens * unshuffled_epochs self._load_yaml_data(yaml_data) if self._yaml_path is not None: self._yaml_path.parent.mkdir(parents=True, exist_ok=True) From a53c89a454decdec70356ffac601dc3e4fe6dc09 Mon Sep 17 00:00:00 2001 From: Toby Liang Date: Mon, 23 Jun 2025 23:23:10 +0000 Subject: [PATCH 82/82] small fixes --- fast_llm/data/dataset/gpt/sampled.py | 2 +- fast_llm/models/gpt/model.py | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/fast_llm/data/dataset/gpt/sampled.py b/fast_llm/data/dataset/gpt/sampled.py index ea8eed40..42cc0729 100644 --- a/fast_llm/data/dataset/gpt/sampled.py +++ b/fast_llm/data/dataset/gpt/sampled.py @@ -751,7 +751,7 @@ def __getitem__(self, index: int) -> typing.Any: sequence_lengths=sequence_lengths, images=images, image_positions=image_positions, - audio=audio if len(audio) > 0 else None, + audio=audio if audio is not None and len(audio) > 0 else None, audio_positions=audio_positions, ) diff --git a/fast_llm/models/gpt/model.py b/fast_llm/models/gpt/model.py index 05b15e4d..48f5760b 100644 --- a/fast_llm/models/gpt/model.py +++ b/fast_llm/models/gpt/model.py @@ -444,7 +444,7 @@ def preprocess( if batch.loss_masking_spans is not None: # avoid changing input tokens labels = labels.clone() - for i, spans in enumerate(batch.loss_masking_spans): + for idx, spans in enumerate(batch.loss_masking_spans): if not spans.numel(): continue valid_spans = spans[ @@ -457,9 +457,9 @@ def preprocess( loss_mask = torch.ones_like(labels, dtype=torch.bool) for start, end in valid_spans: if sequence_first: - loss_mask[start : end + 1, i] = False + loss_mask[start : end + 1, idx] = False else: - loss_mask[i, start : end + 1] = False + loss_mask[idx, start : end + 1] = False if self._config.distillation_model is not None: kwargs[LanguageModelKwargs.loss_mask] = loss_mask labels = torch.where(loss_mask, labels, -100)
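
The last hunk only renames the loop variable from i to idx, but the logic around it is the core of span-based loss masking: each sample carries a list of inclusive [start, end] spans whose tokens must not contribute to the loss, a boolean mask is cleared inside those spans, and the corresponding labels are replaced with -100 so the cross-entropy ignores them. A simplified, batch-first sketch of that idea, leaving out the sequence_first layout, the sequence-offset clipping, and the distillation branch handled in the real code (the function name is illustrative):

import torch


def apply_loss_masking_spans(
    labels: torch.Tensor,  # (batch_size, sequence_length) label token ids
    loss_masking_spans: list[torch.Tensor],  # per sample: (num_spans, 2) inclusive [start, end]
    ignore_index: int = -100,
) -> torch.Tensor:
    """Return a copy of labels with every masked span replaced by ignore_index."""
    labels = labels.clone()  # avoid changing the input tokens
    loss_mask = torch.ones_like(labels, dtype=torch.bool)
    for idx, spans in enumerate(loss_masking_spans):
        if not spans.numel():
            continue
        for start, end in spans.tolist():
            loss_mask[idx, start : end + 1] = False  # spans are inclusive on both ends
    return torch.where(loss_mask, labels, ignore_index)


# Example:
# labels = torch.tensor([[11, 12, 13, 14, 15]])
# apply_loss_masking_spans(labels, [torch.tensor([[1, 2]])])
# -> tensor([[  11, -100, -100,   14,   15]])
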