diff --git a/check_logits_hidden_layers.ipynb b/check_logits_hidden_layers.ipynb new file mode 100644 index 00000000..ee4e6f5f --- /dev/null +++ b/check_logits_hidden_layers.ipynb @@ -0,0 +1,387 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from fast_llm.data.dataset.gpt.memmap import GPTMemmapDataset\n", + "from pathlib import Path\n", + "import numpy as np\n", + "from transformers import AutoTokenizer\n", + "import torch\n", + "import pickle\n", + "\n", + "import matplotlib.pyplot as plt" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "files_root = Path(\"/mnt/datasets/tests/denis/tensors_f32/\")\n", + "#files_root = Path(\"/mnt/datasets/tests/denis/tensors/\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "fm_files = {int(file.stem.split(\"tensor\")[1]): file for file in (files_root / \"fast_llm/logits/\").glob(\"tensor*.pt\")}\n", + "hf_files = {int(file.stem.split(\"tensor\")[1]): file for file in (files_root / \"hf/logits\").glob(\"tensor*.pt\")}\n", + "assert len(fm_files) == len(hf_files)\n", + "len(fm_files)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "hf_tokens = []\n", + "fm_tokens = []\n", + "max_adiff = []\n", + "mean_adiff = []\n", + "sum_adiff = []\n", + "for i in range(len(fm_files)):\n", + " fm_data = torch.load(fm_files[i])\n", + " hf_data = torch.load(hf_files[i])\n", + " \n", + " hf_tokens.append(hf_data[0, -1, :].argmax().item())\n", + " fm_tokens.append(fm_data[0, -1, :].argmax().item())\n", + "\n", + " adiff = torch.abs(hf_data[0, -1, :] - fm_data[0, -1, :])\n", + " max_adiff.append(adiff.max().item())\n", + " mean_adiff.append(adiff.mean().item())\n", + " sum_adiff.append(adiff.sum().item())\n", + " \n", + "all(a == b for a, b in zip(hf_tokens, fm_tokens))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "min(len(hf_tokens)+1 if ab[0] == ab[1] else i for i, ab in enumerate(zip(hf_tokens, fm_tokens)))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "fig, axes = plt.subplots(1, 2, figsize=(12, 5), sharex=True)\n", + "\n", + "# Left plot: max and mean absolute differences\n", + "axes[0].plot(max_adiff, label='max')\n", + "axes[0].plot(mean_adiff, label='mean')\n", + "axes[0].set_title('Max and Mean Absolute Difference')\n", + "axes[0].set_xlabel('Token Position Index')\n", + "axes[0].set_ylabel('Absolute Difference')\n", + "axes[0].legend()\n", + "axes[0].grid(True)\n", + "\n", + "# Right plot: sum absolute difference\n", + "axes[1].plot(sum_adiff, label='sum', color='tab:orange')\n", + "axes[1].set_title('Sum Absolute Difference')\n", + "axes[1].set_xlabel('Token Position Index')\n", + "axes[1].set_ylabel('Absolute Difference')\n", + "axes[1].legend()\n", + "axes[1].grid(True)\n", + "\n", + "plt.tight_layout()\n", + "plt.show()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "fm_hidden_files = {int(file.stem.split(\"data\")[1]): file for file in (files_root / \"fast_llm/hidden_states/\").glob(\"data*.pickle\")}\n", + "hf_hidden_files = {int(file.stem.split(\"data\")[1]): file for file in (files_root / \"hf/hidden_states\").glob(\"data*.pickle\")}" + ] + }, + { + "cell_type": "code", + 
"execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def mad(new_token_index, fm_hidden_files, hf_hidden_files):\n", + " with fm_hidden_files[new_token_index].open(\"rb\") as f:\n", + " fm_data = pickle.load(f)\n", + " with hf_hidden_files[new_token_index].open(\"rb\") as f:\n", + " hf_data = pickle.load(f)\n", + " max_adiffs_hidden_layers = []\n", + " for i in range(len(hf_data)):\n", + " max_adiff = torch.abs(hf_data[i][0,-1,:]-fm_data[i]['tensor'][0,-1,:]).max().item()\n", + " max_adiffs_hidden_layers.append(max_adiff)\n", + " return max_adiffs_hidden_layers\n", + " " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "new_token_index = 107\n", + "new_token_index1 = 108\n", + "max_adiffs_hidden_layers = mad(0, fm_hidden_files, hf_hidden_files)\n", + "max_adiffs_hidden_layers2 = mad(new_token_index, fm_hidden_files, hf_hidden_files)\n", + "max_adiffs_hidden_layers3 = mad(new_token_index1, fm_hidden_files, hf_hidden_files)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "fig, axes = plt.subplots(1, 2, figsize=(12, 5), sharex=True)\n", + "\n", + "axes[0].plot(max_adiffs_hidden_layers, label='new_token_0', color='blue')\n", + "axes[0].plot(max_adiffs_hidden_layers2, label=f'new_token_{new_token_index}', color='green')\n", + "axes[0].set_title('Max and Mean Absolute Difference')\n", + "axes[0].set_xlabel('Hidden Layer Index')\n", + "axes[0].set_ylabel('Max Absolute Difference')\n", + "axes[0].legend()\n", + "axes[0].grid(True)\n", + "\n", + "axes[1].plot(max_adiffs_hidden_layers, label='new_token_0', color='blue')\n", + "axes[1].plot(max_adiffs_hidden_layers3, label=f'new_token_{new_token_index1}', color='green')\n", + "axes[1].set_title('Max and Mean Absolute Difference')\n", + "axes[1].set_xlabel('Hidden Layer Index')\n", + "axes[1].set_ylabel('Max Absolute Difference')\n", + "axes[1].legend()\n", + "axes[1].grid(True)\n", + "\n", + "\n", + "\n", + "plt.title('Per-layer Max Absolute Differences')\n", + "plt.tight_layout()\n", + "plt.show()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "print(hf_tokens_bf16[106:120])\n", + "print(fm_tokens_b16[106:120])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "print(hf_tokens[106:120])\n", + "print(fm_tokens[106:120])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "hf_tokens_bf16 = hf_tokens\n", + "fm_tokens_b16 = fm_tokens" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "min(len(hf_tokens)+1 if ab[0] == ab[1] else i for i, ab in enumerate(zip(hf_tokens, fm_tokens)))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "min(len(hf_tokens)+1 if ab[0] == ab[1] else i for i, ab in enumerate(zip(hf_tokens, hf_tokens_bf16)))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "min(len(hf_tokens)+1 if ab[0] == ab[1] else i for i, ab in enumerate(zip(fm_tokens, fm_tokens_b16)))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "min(len(hf_tokens)+1 if ab[0] == ab[1] else i for i, ab in enumerate(zip(hf_tokens, fm_tokens)))" + ] + }, + { + "cell_type": 
"code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import safetensors" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "\n", + "# this is just to show possibility\n", + "# assumes no converiosn of key names or tensors or aggregation of tensors is needed\n", + "def load(path, model):\n", + " with safetensors.safe_open(path, 'pt', device=model.distributed.device) as f:\n", + " key = 'model.embed_tokens.weight'\n", + " # this would load only part of the tensor for this tensor parallel, etc rank\n", + " # get_local_slice_ranges would return a multidimensional range object \n", + " tensor = f.get_slice(key)[model.get_local_slice_ranges(key)]\n", + " model.import_tensor(key, tensor)\n", + " " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from fast_llm.engine.distributed.config import DistributedConfig" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "print(\"| rank | local_rank | tensor_rank | pipeline_rank | data_rank | sequence_data_rank | batch_data_rank | | | | | | |\")\n", + "print(\"| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |\")\n", + "for rank in range(16):\n", + " cfg = DistributedConfig(rank=rank, world_size=16, local_world_size=8, tensor_parallel=2, pipeline_parallel=2, sequence_data_parallel=2, pipeline_first=True)\n", + " res = f\"| {cfg.rank} | {cfg.local_rank} | {cfg.tensor_rank} | {cfg.pipeline_rank} | {cfg.data_rank} | {cfg.sequence_data_rank} | {cfg.batch_data_rank} |\"\n", + " for name, dm in cfg.distributed_dims.items():\n", + " if name == 'world':\n", + " continue\n", + " res += f\"{name}_{dm.id} |\"\n", + " print(res)\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "res = '|'\n", + "for name, dm in cfg.distributed_dims.items():\n", + " if name == 'world':\n", + " continue\n", + " res += f\"{name}_{dm.id} |\"" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "res" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import pickle" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "with open(\"/mnt/checkpoints/test/denis/smol_eval_experiment_test/lm_eval/batch_0.pkl\", 'rb') as f:\n", + " data = pickle.load(f)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "data[0]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "data[1:]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": ".venv", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.3" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/examples/qwen_evaluate.yaml b/examples/qwen_evaluate.yaml new file mode 100644 index 00000000..eb6d8752 --- /dev/null +++ b/examples/qwen_evaluate.yaml @@ -0,0 +1,87 
@@ +training: + train_iters: 100_000 + logs: + interval: 10 + evaluations: + gsm8k: + run_interval: + interval: 10 + evaluator: + type: lm_eval + cli_args: + - --tasks + - gsm8k + - --output_path + - /mnt/checkpoints/test/denis/smol_eval_experiment/lm_eval + stack_3b: + run_interval: + interval: 10 + evaluator: + type: loss + iterations: 10 + dataset_name: stack_3b + fineweb: + run_interval: + interval: 10 + evaluator: + iterations: 10 + dataset_name: fineweb + checkpoint: + interval: 1000 + keep: 5 + test_iters: 0 + export: # (1)! + format: llama + interval: 20_000 +batch: + micro_batch_size: 16 + sequence_length: 4096 + batch_size: 32 +data: + tokenizer: + path: /mnt/checkpoints/pretrained_models/Qwen2-1.5B-Instruct + bos_token: "<|endoftext|>" + datasets: + # Bad dataset they are tokenized with different tokenizer, then llama + training: + type: file + path: /mnt/datasets/test/denis/fineweb_the_stack_3b.yaml + stack_3b: + type: memmap + path: /mnt/datasets/data_collections/the_stack_3b/tokens/stack_3b/default/train/99 + fineweb: + type: memmap + path: /mnt/datasets/data_collections/standalone_datasets/tokens/HuggingFaceFW/fineweb/default/train/9_1000 +optimizer: + weight_decay: 0.1 + beta_1: 0.9 + beta_2: 0.95 + learning_rate: + base: 1.0e-04 # (3)! + minimum: 1.0e-05 + decay_style: cosine + decay_iterations: 100_000 + warmup_iterations: 2000 +pretrained: # (4)! + format: qwen2 + path: /mnt/checkpoints/pretrained_models/Qwen2-1.5B-Instruct + model_weights: yes # (5)! +model: + base_model: + transformer: + use_flash_attention: yes + cross_entropy_impl: fused + multi_stage: + zero_stage: 2 + distributed: + training_dtype: bf16 + +run: + experiment_dir: "/mnt/checkpoints/test/denis/qwen_eval_experiment" + +# training: +# logs: +# interval: 10 +# wandb: +# project_name: ${job.project_name} +# group_name: ${job.project_version} \ No newline at end of file diff --git a/examples/smol_evaluate.yaml b/examples/smol_evaluate.yaml new file mode 100644 index 00000000..14d992fe --- /dev/null +++ b/examples/smol_evaluate.yaml @@ -0,0 +1,86 @@ +training: + train_iters: 100_000 + logs: + interval: 10 + evaluations: + gsm8k: + run_interval: + interval: 10 + evaluator: + type: lm_eval + cli_args: + - --tasks + - gsm8k + - --output_path + - /mnt/checkpoints/test/denis/smol_eval_experiment/lm_eval + stack_3b: + run_interval: + interval: 10 + evaluator: + type: loss + iterations: 10 + dataset_name: stack_3b + fineweb: + run_interval: + interval: 10 + evaluator: + iterations: 10 + dataset_name: fineweb + checkpoint: + interval: 1000 + keep: 5 + test_iters: 0 + export: # (1)! + format: llama + interval: 20_000 +batch: + micro_batch_size: 16 + sequence_length: 4096 + batch_size: 32 +data: + tokenizer: + path: /mnt/checkpoints/pretrained_models/SmolLM2-135M-Instruct + datasets: + # Bad dataset they are tokenized with different tokenizer, then llama + training: + type: file + path: /mnt/datasets/test/denis/fineweb_the_stack_3b.yaml + stack_3b: + type: memmap + path: /mnt/datasets/data_collections/the_stack_3b/tokens/stack_3b/default/train/99 + fineweb: + type: memmap + path: /mnt/datasets/data_collections/standalone_datasets/tokens/HuggingFaceFW/fineweb/default/train/9_1000 +optimizer: + weight_decay: 0.1 + beta_1: 0.9 + beta_2: 0.95 + learning_rate: + base: 1.0e-04 # (3)! + minimum: 1.0e-05 + decay_style: cosine + decay_iterations: 100_000 + warmup_iterations: 2000 +pretrained: # (4)! + format: llama + path: /mnt/checkpoints/pretrained_models/SmolLM2-135M-Instruct/ + model_weights: yes # (5)! 
+model: + base_model: + transformer: + use_flash_attention: yes + cross_entropy_impl: fused + multi_stage: + zero_stage: 2 + distributed: + training_dtype: bf16 + +run: + experiment_dir: "/mnt/checkpoints/test/denis/smol_eval_experiment" + +# training: +# logs: +# interval: 10 +# wandb: +# project_name: ${job.project_name} +# group_name: ${job.project_version} \ No newline at end of file diff --git a/fast_llm/core/distributed.py b/fast_llm/core/distributed.py index e82e0801..09daa438 100644 --- a/fast_llm/core/distributed.py +++ b/fast_llm/core/distributed.py @@ -6,12 +6,17 @@ Todo: Move all core methods elsewhere (functional?). """ +import collections import contextlib import datetime +import io +import itertools import logging +import pickle import typing import torch +import torch.monitor from torch._C._distributed_c10d import Work from torch.distributed import ( # noqa ProcessGroup, @@ -26,6 +31,117 @@ logger = logging.getLogger(__name__) +def _as_iterable(obj) -> collections.abc.Iterable: + return obj if isinstance(obj, list) else (obj,) + + +def _check_single_tensor(param, param_name) -> None: + """Check that the parameter ``param_name`` is a single tensor.""" + if not isinstance(param, torch.Tensor): + raise TypeError( + f"""Invalid function argument. Expected parameter `{param_name}` of type torch.Tensor + but got {type(param)} instead.""" + ) + + +def _check_tensor_list(param, param_name) -> None: + """Check that the parameter ``param_name`` is a list of tensors.""" + if not isinstance(param, list): + raise TypeError( + f"""Invalid function argument. Expected parameter `{param_name}` of type List[torch.Tensor] + but got {type(param)} instead.""" + ) + elif not all(isinstance(p, torch.Tensor) for p in param): + raise TypeError( + f"""Invalid function argument. Expected parameter `{param_name}` of type List[torch.Tensor] + but got {type(param)} with elements of type {[type(p) for p in param]}.""" + ) + + +def _ensure_all_tensors_same_dtype(*tensors) -> None: + last_dtype = None + for tensor in itertools.chain.from_iterable(map(_as_iterable, tensors)): + tensor_dtype = tensor.dtype + # Mixing complex and its element type is allowed + if tensor_dtype.is_complex: + tensor_dtype = torch.float32 if tensor_dtype == torch.complex64 else torch.complex128 + + if last_dtype is None: + last_dtype = tensor_dtype + else: + if last_dtype != tensor_dtype: + raise ValueError( + "Invalid usage of tensors with different dtypes" f"Found {last_dtype} and {tensor.dtype}" + ) + + +def _rank_not_in_group(group: typing.Optional[ProcessGroup]) -> bool: + """Check if the current process's rank is not in a given group.""" + if group is None: + return False + return group == torch.distributed.GroupMember.NON_GROUP_MEMBER + + +def _warn_not_in_group(op_name) -> None: + # TODO: get global rank + global_rank = -1 + logger.warning(f"Running {op_name} on global rank {global_rank} which does not " "belong to the given group.") + + +_pickler = pickle.Pickler +_unpickler = pickle.Unpickler + + +def _object_to_tensor(obj, device, group): + with torch.monitor._WaitCounter("pytorch.wait_counter.c10d._object_to_tensor").guard(): + f = io.BytesIO() + _pickler(f).dump(obj) + byte_storage = torch.ByteStorage._from_buffer(f.getvalue()) # type: ignore[attr-defined] + # Do not replace `torch.ByteTensor` or `torch.LongTensor` with torch.tensor and specifying dtype. + # Otherwise, it will casue 100X slowdown. 
+ # See: https://github.com/pytorch/pytorch/issues/65696 + byte_tensor = torch.ByteTensor(byte_storage).to(device) + + # TODO: do we need to log this level of details? + # if get_debug_level() == DebugLevel.DETAIL and is_nccl_available(): + # backend = get_backend(group) + # if backend == Backend.NCCL: + # hash = torch._C._distributed_c10d._hash_tensors([byte_tensor]) + # logger.warning( + # "_object_to_tensor size: %s hash value: %s", + # byte_tensor.numel(), + # hash, + # ) + + local_size = torch.LongTensor([byte_tensor.numel()]).to(device) + return byte_tensor, local_size + + +def _tensor_to_object(tensor, tensor_size, group): + with torch.monitor._WaitCounter("pytorch.wait_counter.c10d._tensor_to_object").guard(): + + # TODO: do we need to log this level of details? + # if get_debug_level() == DebugLevel.DETAIL and is_nccl_available(): + # backend = get_backend(group) + # if backend == Backend.NCCL: + # hash = torch._C._distributed_c10d._hash_tensors([tensor]) + # logger.warning( + # "_tensor_to_object size: %s hash value: %s", tensor.numel(), hash + # ) + + tensor = tensor.cpu() + buf = tensor.numpy().tobytes()[:tensor_size] + return _unpickler(io.BytesIO(buf)).load() + + +def _validate_output_list_for_rank(my_rank, dst, gather_list): + if dst == my_rank: + if not gather_list: + raise ValueError("Argument ``gather_list`` must be specified on destination rank.") + elif gather_list: + raise ValueError("Argument ``gather_list`` must NOT be specified on non-destination ranks.") + + def add_ephemeral_timeout(group: ProcessGroup, timeout: float | None = None) -> None: if group is not None and timeout is not None: # TODO: Only works for nccl? @@ -133,3 +249,406 @@ def set_generator(generator: torch.Generator) -> typing.Generator[None, None, No finally: generator.set_state(default_generator.get_state()) default_generator.set_state(old_state) + + +def gather( + tensor: torch.Tensor, + gather_list: typing.Optional[list[torch.Tensor]] = None, + group: typing.Optional[ProcessGroup] = None, + async_op: bool = False, + group_dst: typing.Optional[int] = None, +): + """ + Gathers a list of tensors in a single process. + + This function requires all tensors to be the same size on each process. + + Args: + tensor (Tensor): Input tensor. + gather_list (list[Tensor], optional): List of appropriately, + same-sized tensors to use for gathered data + (default is None, must be specified on the destination rank) + group (ProcessGroup, optional): The process group to work on. + async_op (bool, optional): Whether this op should be an async op + group_dst (int, optional): Destination rank on ``group``. + + Returns: + Async work handle, if async_op is set to True. + None, if not async_op or if not part of the group + + .. note:: Note that all Tensors in gather_list must have the same size. + + Example:: + >>> # xdoctest: +SKIP("no rank") + >>> # We have 2 process groups, 2 ranks. + >>> tensor_size = 2 + >>> device = torch.device(f'cuda:{rank}') + >>> tensor = torch.ones(tensor_size, device=device) + rank + >>> if dist.get_rank() == 0: + >>> gather_list = [torch.zeros_like(tensor, device=device) for i in range(2)] + >>> else: + >>> gather_list = None + >>> dist.gather(tensor, gather_list, dst=0) + >>> # Rank 0 gets gathered data. + >>> gather_list + [tensor([1., 1.], device='cuda:0'), tensor([2., 2.], device='cuda:0')] # Rank 0 + None # Rank 1 + + """ + _check_single_tensor(tensor, "tensor") + + # Parameter ``gather_list`` may be left unspecified on non-dst ranks. 
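+    # Normalize the unspecified case to an empty list so the dtype check below handles all ranks uniformly.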
+ if gather_list: + _check_tensor_list(gather_list, "gather_list") + else: + gather_list = [] + _ensure_all_tensors_same_dtype(tensor, gather_list) + assert group is not None + if _rank_not_in_group(group): + _warn_not_in_group("gather") + return + if group_dst is None: + group_dst = 0 + my_group_rank = group.rank() + _validate_output_list_for_rank(my_group_rank, group_dst, gather_list) + output_tensors = [gather_list] if group_dst == my_group_rank else [] + input_tensors = [tensor] + + opts = torch.distributed.GatherOptions() + opts.rootRank = group_dst + # Absent in ver 2.6 + #opts.asyncOp = async_op + work = group.gather(output_tensors, input_tensors, opts) + + if async_op: + return work + elif work is not None: # Backward compatible with backends that don't sync at CPP level + work.wait() + # Otherwise, the backend has sync'ed at CPP level + + +def scatter( + tensor: torch.Tensor, + scatter_list: typing.Optional[list[torch.Tensor]] = None, + group: typing.Optional[ProcessGroup] = None, + async_op: bool = False, + group_src: typing.Optional[int] = None, +): + """ + Scatters a list of tensors to all processes in a group. + + Each process will receive exactly one tensor and store its data in the + ``tensor`` argument. + + Complex tensors are supported. + + Args: + tensor (Tensor): Output tensor. + scatter_list (list[Tensor]): List of tensors to scatter (default is + None, must be specified on the source rank) + group (ProcessGroup, optional): The process group to work on. + async_op (bool, optional): Whether this op should be an async op + group_src (int, optional): Source rank on ``group``. + + Returns: + Async work handle, if async_op is set to True. + None, if not async_op or if not part of the group + + .. note:: Note that all Tensors in scatter_list must have the same size. + + Example:: + >>> # xdoctest: +SKIP("need process group init") + >>> # Note: Process group initialization omitted on each rank. + >>> import torch.distributed as dist + >>> tensor_size = 2 + >>> device = torch.device(f'cuda:{rank}') + >>> output_tensor = torch.zeros(tensor_size, device=device) + >>> if dist.get_rank() == 0: + >>> # Assumes world_size of 2. + >>> # Only tensors, all of which must be the same size. + >>> t_ones = torch.ones(tensor_size, device=device) + >>> t_fives = torch.ones(tensor_size, device=device) * 5 + >>> scatter_list = [t_ones, t_fives] + >>> else: + >>> scatter_list = None + >>> dist.scatter(output_tensor, scatter_list, src=0) + >>> # Rank i gets scatter_list[i]. + >>> output_tensor + tensor([1., 1.], device='cuda:0') # Rank 0 + tensor([5., 5.], device='cuda:1') # Rank 1 + + """ + _check_single_tensor(tensor, "tensor") + # Parameter ``scatter_list`` may be left unspecified on non-src ranks. 
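+    # As in gather() above, fall back to an empty list so the dtype check can run on every rank.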
+ if scatter_list: + _check_tensor_list(scatter_list, "scatter_list") + else: + scatter_list = [] + _ensure_all_tensors_same_dtype(tensor, scatter_list) + assert group is not None + if group_src is None: + group_src = 0 + if _rank_not_in_group(group): + _warn_not_in_group("scatter") + return + scatter_list = [t if not t.is_complex() else torch.view_as_real(t) for t in scatter_list] + tensor = tensor if not tensor.is_complex() else torch.view_as_real(tensor) + + my_group_rank = group.rank() + if group_src == my_group_rank: + if not scatter_list: + raise ValueError("Argument ``scatter_list`` must be specified on source rank.") + input_tensors = [scatter_list] + output_tensors = [tensor] + else: + if scatter_list: + raise ValueError("Argument ``scatter_list`` must NOT be specified on non-source ranks.") + input_tensors = [] + output_tensors = [tensor] + + opts = torch.distributed.ScatterOptions() + opts.rootRank = group_src + opts.asyncOp = async_op + work = group.scatter(output_tensors, input_tensors, opts) + + if async_op: + return work + elif work is not None: # Backward compatible with backends that don't sync at CPP level + work.wait() + # Otherwise, the backend has sync'ed at CPP level + + +def gather_object( + current_device: torch.device | str, + obj: typing.Any, + object_gather_list: typing.Optional[list[typing.Any]] = None, + group: typing.Optional[ProcessGroup] = None, + group_dst: typing.Optional[int] = None, +): + """ + Gathers picklable objects from the whole group in a single process. + + Similar to :func:`gather`, but Python objects can be passed in. Note that the + object must be picklable in order to be gathered. + + Args: + current_device: (torch.device | str): device to use for object serialization to + tensor, must be this process assigned gpu for nccl backend. + obj (Any): Input object. Must be picklable. + object_gather_list (list[Any]): Output list. On the ``dst`` rank, it + should be correctly sized as the size of the group for this + collective and will contain the output. Must be ``None`` on non-dst + ranks. (default is ``None``) + dst (int, optional): Destination rank on global process group (regardless of ``group`` argument). + (If both ``dst`` and ``group_dst`` are None, default is global rank 0) + group: (ProcessGroup, optional): The process group to work on. If None, + the default process group will be used. Default is ``None``. + group_dst (int, optional): Destination rank on ``group``. Invalid to specify both ``dst`` and ``group_dst`` + + Returns: + None. On the ``dst`` rank, ``object_gather_list`` will contain the + output of the collective. + + .. note:: Note that this API differs slightly from the gather collective + since it does not provide an async_op handle and thus will be a blocking + call. + + .. note:: For NCCL-based processed groups, internal tensor representations + of objects must be moved to the GPU device before communication takes + place. In this case, the device used is given by + ``torch.cuda.current_device()`` and it is the user's responsiblity to + ensure that this is set so that each rank has an individual GPU, via + ``torch.cuda.set_device()``. + + .. warning:: + Object collectives have a number of serious performance and scalability + limitations. See :ref:`object_collectives` for details. + + .. warning:: + :func:`gather_object` uses ``pickle`` module implicitly, which is + known to be insecure. It is possible to construct malicious pickle data + which will execute arbitrary code during unpickling. 
Only call this + function with data you trust. + + .. warning:: + Calling :func:`gather_object` with GPU tensors is not well supported + and inefficient as it incurs GPU -> CPU transfer since tensors would be + pickled. Please consider using :func:`gather` instead. + + Example:: + >>> # xdoctest: +SKIP("need process group init") + >>> # Note: Process group initialization omitted on each rank. + >>> import torch.distributed as dist + >>> # Assumes world_size of 3. + >>> gather_objects = ["foo", 12, {1: 2}] # any picklable object + >>> output = [None for _ in gather_objects] + >>> dist.gather_object( + ... gather_objects[dist.get_rank()], + ... output if dist.get_rank() == 0 else None, + ... dst=0 + ... ) + >>> # On rank 0 + >>> output + ['foo', 12, {1: 2}] + """ + assert group is not None + if group_dst is None: + group_dst = 0 + if _rank_not_in_group(group): + _warn_not_in_group("gather_object") + return + + # Ensure object_gather_list is specified appropriately. + my_group_rank = group.rank() + _validate_output_list_for_rank(my_group_rank, group_dst, object_gather_list) + input_tensor, local_size = _object_to_tensor(obj, current_device, group) + + # Gather all local sizes. This is so that we can find the max size, and index + # until the correct size when deserializing the tensors. + group_size = group.size() + object_sizes_tensor = torch.zeros(group_size, dtype=torch.long, device=current_device) + object_size_list = [object_sizes_tensor[i].unsqueeze(dim=0) for i in range(group_size)] + # Allgather tensor sizes. An all-gather is needed here despite this being a + # gather, since each rank needs to broadcast a tensor of the same (maximal) + # size. + all_gather(object_size_list, local_size, group=group) + max_object_size = int(max(object_size_list).item()) # type: ignore[type-var] + # Resize tensor to max size across all ranks. + input_tensor.resize_(max_object_size) + # Avoid populating output tensors if the result won't be gathered on this rank. + if my_group_rank == group_dst: + coalesced_output_tensor = torch.empty(max_object_size * group_size, dtype=torch.uint8, device=current_device) + # Output tensors are nonoverlapping views of coalesced_output_tensor + output_tensors = [ + coalesced_output_tensor[max_object_size * i : max_object_size * (i + 1)] for i in range(group_size) + ] + # All ranks call gather with equal-sized tensors. + gather( + input_tensor, + gather_list=output_tensors if my_group_rank == group_dst else None, # type: ignore[possibly-undefined] + group_dst=group_dst, + group=group, + ) + if my_group_rank != group_dst: + return + + assert object_gather_list is not None, "Must provide object_gather_list on dst rank" + for i, tensor in enumerate(output_tensors): + tensor = tensor.type(torch.uint8) + tensor_size = object_size_list[i] + object_gather_list[i] = _tensor_to_object(tensor, tensor_size, group) + + +def scatter_object_list( + pg_device: torch.device | str, + scatter_object_output_list: list[typing.Any], + scatter_object_input_list: typing.Optional[list[typing.Any]] = None, + group: typing.Optional[ProcessGroup] = None, + group_src: typing.Optional[int] = None, +): + """ + Scatters picklable objects in ``scatter_object_input_list`` to the whole group. + + Similar to :func:`scatter`, but Python objects can be passed in. On + each rank, the scattered object will be stored as the first element of + ``scatter_object_output_list``. Note that all objects in + ``scatter_object_input_list`` must be picklable in order to be scattered. 
+ + Args: + pg_device: (torch.device | str): device to use for object serialization to + tensor, must be this process assigned gpu for nccl backend. + scatter_object_output_list (List[Any]): Non-empty list whose first + element will store the object scattered to this rank. + scatter_object_input_list (List[Any], optional): List of input objects to scatter. + Each object must be picklable. Only objects on the ``src`` rank will + be scattered, and the argument can be ``None`` for non-src ranks. + group: (ProcessGroup, optional): The process group to work on. + group_src (int, optional): Source rank on ``group``. + + Returns: + ``None``. If rank is part of the group, ``scatter_object_output_list`` + will have its first element set to the scattered object for this rank. + + .. note:: Note that this API differs slightly from the scatter collective + since it does not provide an ``async_op`` handle and thus will be a + blocking call. + + .. warning:: + Object collectives have a number of serious performance and scalability + limitations. See :ref:`object_collectives` for details. + + .. warning:: + :func:`scatter_object_list` uses ``pickle`` module implicitly, which + is known to be insecure. It is possible to construct malicious pickle + data which will execute arbitrary code during unpickling. Only call this + function with data you trust. + + .. warning:: + Calling :func:`scatter_object_list` with GPU tensors is not well supported + and inefficient as it incurs GPU -> CPU transfer since tensors would be + pickled. Please consider using :func:`scatter` instead. + + Example:: + >>> # xdoctest: +SKIP("need process group init") + >>> # Note: Process group initialization omitted on each rank. + >>> import torch.distributed as dist + >>> if dist.get_rank() == 0: + >>> # Assumes world_size of 3. + >>> objects = ["foo", 12, {1: 2}] # any picklable object + >>> else: + >>> # Can be any list on non-src ranks, elements are not used. + >>> objects = [None, None, None] + >>> output_list = [None] + >>> dist.scatter_object_list(output_list, objects, src=0) + >>> # Rank i gets objects[i]. For example, on rank 2: + >>> output_list + [{1: 2}] + """ + assert group is not None + if group_src is None: + group_src = 0 + if _rank_not_in_group(group): + _warn_not_in_group("scatter_object_list") + return + + if not isinstance(scatter_object_output_list, list) or len(scatter_object_output_list) < 1: + raise ValueError("Expected argument scatter_object_output_list to be a list of size at least 1.") + + my_group_rank = group.rank() + if my_group_rank == group_src: + if scatter_object_input_list is None: + raise ValueError("source rank must provide non-None scatter_object_input_list") + tensor_list, tensor_sizes = zip( + *[_object_to_tensor(obj, pg_device, group) for obj in scatter_object_input_list] + ) + tensor_list, tensor_sizes = list(tensor_list), list(tensor_sizes) + + # Src rank broadcasts the maximum tensor size. This is because all ranks are + # expected to call into scatter() with equal-sized tensors. 
+ max_tensor_size = max(tensor_sizes) # type: ignore[possibly-undefined] + for tensor in tensor_list: # type: ignore[possibly-undefined] + tensor.resize_(max_tensor_size) + else: + max_tensor_size = torch.tensor([0], dtype=torch.long, device=pg_device) + broadcast(max_tensor_size, src=group_src, group=group) + + # Scatter actual serialized objects + output_tensor = torch.empty(max_tensor_size.item(), dtype=torch.uint8, device=pg_device) + scatter( + output_tensor, + scatter_list=None if my_group_rank != group_src else tensor_list, # type: ignore[possibly-undefined] + group_src=group_src, + group=group, + ) + + # Scatter per-object sizes to trim tensors when deserializing back to object + obj_tensor_size = torch.tensor([0], dtype=torch.long, device=pg_device) + scatter( + obj_tensor_size, + scatter_list=None if my_group_rank != group_src else tensor_sizes, # type: ignore[possibly-undefined] + group_src=group_src, + group=group, + ) + + # Deserialize back to object + scatter_object_output_list[0] = _tensor_to_object(output_tensor, obj_tensor_size, group) diff --git a/fast_llm/data/config.py b/fast_llm/data/config.py index 1586d370..4c041945 100644 --- a/fast_llm/data/config.py +++ b/fast_llm/data/config.py @@ -34,3 +34,8 @@ class TokenizerConfig(Config): desc="Path to the tokenizer file.", hint=FieldHint.core, ) + bos_token: str | None = Field( + default=None, + desc="BOS token to use if the tokenizer doesn't define one; must be an existing token.", + hint=FieldHint.core, + ) diff --git a/fast_llm/data/data/gpt/data.py b/fast_llm/data/data/gpt/data.py index 02c1b6c0..2b0be0a0 100644 --- a/fast_llm/data/data/gpt/data.py +++ b/fast_llm/data/data/gpt/data.py @@ -106,8 +106,10 @@ def setup( self._datasets = {} for dataset_name, sampling_parameters in self._sampling_parameters.items(): if self._tokenizer is not None: - # TODO: Too constraining? - Assert.eq(self._tokenizer.vocab_size, sampling_parameters.vocab_size) + # NOTE: Some models like Qwen2-1.5B-Instruct + # have vocab_size bigger in model config than in tokenizer + # TODO: Still, is it too constraining? 
+                Assert.geq(sampling_parameters.vocab_size, self._tokenizer.vocab_size)
             if sampling_parameters.num_samples > 0:
                 sampling = GPTSamplingData(
                     config=self._config.sampling,
diff --git a/fast_llm/data/tokenizer.py b/fast_llm/data/tokenizer.py
index 28e105ee..bc801ed0 100644
--- a/fast_llm/data/tokenizer.py
+++ b/fast_llm/data/tokenizer.py
@@ -1,6 +1,6 @@
 import numpy as np
 import torch
-from transformers import PreTrainedTokenizerFast
+from transformers import PreTrainedTokenizerFast, AutoTokenizer
 
 from fast_llm.data.config import TokenizerConfig
 from fast_llm.engine.config_utils.run import log_main_rank
@@ -13,9 +13,18 @@ class Tokenizer:
 
     def __init__(self, config: TokenizerConfig):
         log_main_rank(f"> loading tokenizer from {config.path} ...")
-        self.tokenizer: PreTrainedTokenizerFast = PreTrainedTokenizerFast.from_pretrained(
-            pretrained_model_name_or_path=config.path, errors="replace", max_len=None
+        # self.tokenizer: PreTrainedTokenizerFast = PreTrainedTokenizerFast.from_pretrained(
+        #     pretrained_model_name_or_path=config.path, errors="replace", max_len=None
+        # )
+        self.tokenizer = AutoTokenizer.from_pretrained(
+            pretrained_model_name_or_path=config.path,
+            errors="replace",
+            max_len=None,
+            trust_remote_code=True,
+            use_fast=True,  # prefer the fast (Rust) tokenizer when available
         )
+        if config.bos_token is not None:
+            self.tokenizer.bos_token = config.bos_token
         if self.tokenizer.eos_token_id is None:
             raise ValueError("Tokenizer does not have an EOS token.")
         if self.tokenizer.bos_token_id is None:
diff --git a/fast_llm/engine/evaluation/config.py b/fast_llm/engine/evaluation/config.py
new file mode 100644
index 00000000..3adf5ad3
--- /dev/null
+++ b/fast_llm/engine/evaluation/config.py
@@ -0,0 +1,163 @@
+import abc
+import typing
+
+from fast_llm.config import (
+    Config,
+    Field,
+    FieldHint,
+    check_field,
+    config_class,
+    skip_valid_if_none,
+)
+from fast_llm.engine.schedule.config import BatchConfig
+from fast_llm.utils import Assert, Registry
+
+if typing.TYPE_CHECKING:
+    from fast_llm.engine.evaluation.evaluator import Evaluator, EvaluatorLoss, EvaluatorLmEval, TrainingEvaluator
+
+@config_class()
+class EvaluatorConfigBase(Config):
+    @abc.abstractmethod
+    def get_evaluator(
+        self,
+        name: str,
+        batch_config: BatchConfig,
+        data_load_num_proc: int,
+        train_iters: int | None = None,
+    ) -> "Evaluator":
+        pass
+
+
+@config_class()
+class EvaluatorConfig(EvaluatorConfigBase):
+    _abstract: typing.ClassVar[bool] = True
+    # TODO: Generalize dynamic types?
+    _registry: typing.ClassVar[Registry[str, type["EvaluatorConfig"]]] = Registry[str, type["EvaluatorConfig"]](
+        "evaluation_class", {}
+    )
+    type_: typing.ClassVar[str | None] = None
+    type: str | None = Field(
+        default=None,
+        desc="The type of evaluation.",
+        hint=FieldHint.core,
+    )
+
+    def _validate(self) -> None:
+        if self.type is None:
+            self.type = self.type_
+        # Should be handled in `from_dict`, but can fail if instantiating directly.
+        Assert.eq(self.type, self.__class__.type_)
+        super()._validate()
+
+    @classmethod
+    def _from_dict(
+        cls,
+        default: dict[str, typing.Any],
+        strict: bool = True,
+        flat: bool = False,
+    ) -> typing.Self:
+        type_ = default.get("type")
+        if type_ is None:
+            # TODO: Remove in version 0.*; kept for backward compatibility.
+            # If 'type' is not provided, it falls back to 'loss'.
+            type_ = "loss"
+            default["type"] = type_
+            actual_cls = EvaluatorLossConfig
+            # actual_cls = cls
+        else:
+            if type_ not in cls._registry:
+                raise ValueError(
+                    f"Unknown {cls._registry.name} type {type_}."
f" Available types: {list(cls._registry.keys())}" + ) + actual_cls = cls._registry[type_] + Assert.custom(issubclass, actual_cls, cls) + if actual_cls == cls: + return super()._from_dict(default, strict=strict, flat=flat) + else: + return actual_cls._from_dict(default, strict=strict, flat=flat) + + def __init_subclass__(cls) -> None: + if cls._abstract and cls.type_ is not None: + # Abstract classes should not have a `type_` + raise ValueError(f"Abstract class {cls.__name__} has type = {cls.type_}, expected None.") + if cls.type_ is not None: + if cls.type_ in cls._registry: + raise ValueError( + f"Registry {cls._registry.name} already contains type {cls.type_}." + f" Make sure all classes either have a unique or `None` type." + ) + EvaluatorConfig._registry[cls.type_] = cls + super().__init_subclass__() + + + +@config_class() +class EvaluatorLossConfig(EvaluatorConfig): + _abstract: typing.ClassVar[bool] = False + type_: typing.ClassVar[str | None] = "loss" + + iterations: int | None = Field( + default=None, + desc="Number of iterations for each evaluation phase. Setting to None will disable.", + hint=FieldHint.feature, + valid=skip_valid_if_none(check_field(Assert.gt, 0)), + ) + + dataset_name: str | None = Field(default=None) + + def get_evaluator( + self, + name: str, + batch_config: BatchConfig, + data_load_num_proc: int, + train_iters: int | None = None, + ) -> "EvaluatorLoss": + from fast_llm.engine.evaluation.evaluator import EvaluatorLoss + + return EvaluatorLoss(name, self, batch_config, data_load_num_proc, train_iters) + + +@config_class() +class EvaluatorLmEvalConfig(EvaluatorConfig): + _abstract: typing.ClassVar[bool] = False + type_: typing.ClassVar[str | None] = "lm_eval" + + cli_args: list[str] = Field( + default_factory=lambda: [], + desc="lm_eval CLI arguments, excluding those related to model, wandb, batch sizes, and device.", + ) + + truncation: bool = Field( + default=False, + desc="Whether to use truncation during tokenization (useful when inputs exceed model's max length);" + " passed to the Fast-LLM lm_eval model wrapper.", + ) + + logits_cache: bool = Field( + default=True, + desc="Whether to enable logits caching for speedup and avoiding recomputation during repeated evaluations;" + " passed to the Fast-LLM lm_eval model wrapper.", + ) + + add_bos_token: bool = Field( + default=False, + desc="Whether to prepend a beginning-of-sequence (BOS) token, required for some models like LLaMA;" + " passed to the Fast-LLM lm_eval model wrapper.", + ) + + prefix_token_id: int | None = Field( + default=None, + desc="Token ID to use as a prefix to the input (e.g., for control codes or prompts);" + " passed to the Fast-LLM lm_eval model wrapper.", + ) + + def get_evaluator( + self, + name: str, + batch_config: BatchConfig, + data_load_num_proc: int, + train_iters: int | None = None, + ) -> "EvaluatorLmEval": + from fast_llm.engine.evaluation.evaluator import EvaluatorLmEval + + return EvaluatorLmEval(name, self, batch_config, data_load_num_proc, train_iters) diff --git a/fast_llm/engine/evaluation/evaluator.py b/fast_llm/engine/evaluation/evaluator.py new file mode 100644 index 00000000..bc7c9335 --- /dev/null +++ b/fast_llm/engine/evaluation/evaluator.py @@ -0,0 +1,408 @@ +import abc +import dataclasses +import logging +import time +import typing + + +from fast_llm.config import Configurable +from fast_llm.core.distributed import safe_barrier +from fast_llm.data.data.abstract import Data +from fast_llm.engine.config_utils.run import Run, log_main_rank +from 
fast_llm.engine.distributed.config import PhaseType +from fast_llm.engine.distributed.distributed import Distributed +from fast_llm.engine.multi_stage.fast_llm_model import FastLLMModel + +from fast_llm.engine.schedule.runner import ScheduleRunner +from fast_llm.engine.schedule.schedule import Schedule +from fast_llm.engine.evaluation.config import ( + EvaluatorConfigBase, + EvaluatorConfig, + EvaluatorLossConfig, + EvaluatorLmEvalConfig, + +) +from fast_llm.engine.training.config import WandbConfig +from fast_llm.engine.training.wandb import Wandb +from fast_llm.logging import format_metrics, get_memory_usage_mib +from fast_llm.engine.evaluation.lm_eval.fast_llm_wrapper import FastLLMLmEvalWrapper +from fast_llm.engine.evaluation.lm_eval.utils import prepare_lm_eval_simple_eval_params, process_lm_eval_results +from fast_llm.engine.schedule.config import BatchConfig + +# from fast_llm.engine.training.lm_eval.evaluator import simple_evaluate as lm_eval_simple_evaluate +from lm_eval.evaluator import simple_evaluate as lm_eval_simple_evaluate + +logger = logging.getLogger(__name__) + + +@dataclasses.dataclass +class TrainingProgress: + done: bool + completed_steps: int + consumed_samples: int + consumed_tokens: int + + +@dataclasses.dataclass +class EvaluationMetrics: + metrics: dict[str, any] = dataclasses.field(default_factory=dict) + formatted_metrics: str | None = None + + +@dataclasses.dataclass +class EvaluatorSamplingParameters: + dataset_name: str + num_samples: int + + +class Evaluator[ConfigType: EvaluatorConfig](Configurable[ConfigType], abc.ABC): + config_class: typing.ClassVar[type[EvaluatorConfig]] = EvaluatorConfig + + _is_setup: bool = False + + def __init__( + self, + name: str, + eval_config: EvaluatorLossConfig, + batch_config: BatchConfig, + data_load_num_proc: int, + train_iters: int | None = None, + ): + super().__init__(eval_config) + self._name = name + self._batch_config = batch_config + self._data_load_num_proc = data_load_num_proc + self._train_iters = train_iters + + def setup( + self, + distributed: Distributed, + run: Run, + multi_stage: FastLLMModel, + runner: ScheduleRunner, + data: Data, + phase: PhaseType, + ) -> None: + # TODO: check if objects passed are actually set up themselves, if appropriate + self._distributed = distributed + self._run = run + self._runner = runner + self._multi_stage = multi_stage + self._data = data + self._phase = phase + + @abc.abstractmethod + def run( + self, + training_progress: TrainingProgress | None = None, + run_index: int | None = None, + ) -> EvaluationMetrics: ... + + @abc.abstractmethod + def get_sampling_parameters(self) -> EvaluatorSamplingParameters | None: + """ + Returns the name and number of required samples in a dataset, + or None if the evaluation does not rely on Fast-LLM data or + if the evaluation is skipped for this run. 
+ """ + + +class EvaluatorLoss[ConfigType: EvaluatorLossConfig](Evaluator[ConfigType]): + config_class: typing.ClassVar[type[EvaluatorLossConfig]] = EvaluatorLossConfig + + def setup( + self, + distributed: Distributed, + run: Run, + multi_stage: FastLLMModel, + runner: ScheduleRunner, + data: Data, + phase: PhaseType, + ) -> None: + super().setup(distributed, run, multi_stage, runner, data, phase) + + # Setup the schedule + self._schedule = Schedule( + multi_stage=self._multi_stage, + batch_config=self._batch_config, + schedule_config=runner.config, + distributed_config=distributed.config, + phase=PhaseType.inference if self._phase == PhaseType.inference else PhaseType.validation, + ) + + self._loss_defs = self._multi_stage.base_model.loss_defs + self._evaluation_iterator = None + self._is_setup = True + + def get_sampling_parameters(self) -> EvaluatorSamplingParameters | None: + return EvaluatorSamplingParameters( + (self._name if self._config.dataset_name is None else self._config.dataset_name), + self._config.iterations * self._batch_config.batch_size, + ) + + def run( + self, + training_progress: TrainingProgress | None = None, + run_index: int | None = None, + ) -> EvaluationMetrics: + assert self._is_setup + if run_index is None: + run_index = 0 + + metrics = {} + formatted_metrics = None + + if self._evaluation_iterator is None: + self._evaluation_iterator = self._get_data_iterator(self._get_completed_evaluation_steps(run_index)) + # TODO: formatting metric category as Validation.evaluation_dataset_name + # maybe format each metric with evaluation_dataset_name prefix instead? + # TODO: setting performance metrics per evaluation dataset + # maybe to set aggregate performance metrics for all evaluations datasets? + phase = PhaseType.inference if self._phase == PhaseType.inference else PhaseType.validation + metric_key = f"{phase.value}.{self._name}" + metrics[metric_key] = self._evaluate_loss( + data_iterator=self._evaluation_iterator, + phase=phase, + num_iters=self._config.iterations, + begin_iter=self._get_completed_evaluation_steps(run_index), + completed_steps=None if training_progress is None else training_progress.completed_steps, + ) + + if self._train_iters is not None: + metrics[metric_key]["train_iters"] = self._train_iters + + if training_progress is not None: + metrics[metric_key]["iteration"] = training_progress.completed_steps + metrics[metric_key]["consumed_samples"] = training_progress.consumed_samples + metrics[metric_key]["consumed_tokens"] = training_progress.consumed_tokens + + formatted_metrics = format_metrics( + metrics[metric_key], + self._loss_defs, + phase, + dataset_name=self._name, + ) + + return EvaluationMetrics(metrics, formatted_metrics) + + def _evaluate_loss( + self, + *, + data_iterator: typing.Iterator, + phase: PhaseType, + num_iters: int, + completed_steps: int | None, + begin_iter: int = 0, + ) -> dict[str, float | int]: + full_phase_name = f"{phase.value}_{self._name}" + safe_barrier(self._distributed.world_group, f"{full_phase_name} begin") + begin_time = time.perf_counter() + total_losses = {loss_def.name: 0.0 for loss_def in self._loss_defs} + for iter_ in range(num_iters): + iter_losses, _, _ = self._runner.run_step(data_iterator, self._schedule, iteration=begin_iter + iter_) + for name, value in iter_losses.items(): + total_losses[name] += value + + tensor_save_name = ( + f"{full_phase_name}_{iter_}" + if completed_steps is None + else f"{full_phase_name}_{completed_steps}_{iter_}" + ) + self._run.save_logged_tensors(tensor_save_name) + + 
safe_barrier(
+            self._distributed.world_group,
+            f"{full_phase_name} end",
+        )
+        end_time = time.perf_counter()
+        time_per_iteration = (end_time - begin_time) / num_iters
+        model_tflops, hardware_tflops = self._multi_stage.get_tflops(
+            phase,
+            time_per_iteration,
+            self._batch_config.batch_size,
+            self._batch_config.sequence_length,
+        )
+        # TODO: add other relevant eval metrics
+        metrics = {
+            "batch_size": self._batch_config.batch_size,
+            **{name: (value / num_iters) for name, value in total_losses.items()},
+            "step_time_ms": time_per_iteration * 1000,
+            "model_tflops": model_tflops,
+            "hardware_tflops": hardware_tflops,
+            "tokens_per_sec_per_gpu": (
+                (self._batch_config.sequence_length * self._batch_config.batch_size)
+                / self._schedule._distributed.world_size
+                / time_per_iteration
+            ),
+            **get_memory_usage_mib(),
+        }
+        return metrics
+
+    def _get_completed_evaluation_steps(self, run_index: int) -> int:
+        # Number of evaluation steps performed before the current step
+        return max(0, run_index - 1) * self.config.iterations
+
+    def _get_data_iterator(
+        self, completed_steps: int = 0, prefetch_factor: int | None = None
+    ) -> typing.Iterator[typing.Any]:
+        return self._data.get_iterator(
+            self._batch_config,
+            self._name,
+            consumed_samples=completed_steps * self._batch_config.batch_size,
+            num_workers=self._data_load_num_proc,
+            prefetch_factor=prefetch_factor,
+        )
+
+
+class EvaluatorLmEval[ConfigType: EvaluatorLmEvalConfig](Evaluator[ConfigType]):
+    config_class: typing.ClassVar[type[EvaluatorLmEvalConfig]] = EvaluatorLmEvalConfig
+
+    def setup(
+        self,
+        distributed: Distributed,
+        run: Run,
+        multi_stage: FastLLMModel,
+        runner: ScheduleRunner,
+        data: Data,
+        phase: PhaseType,
+    ) -> None:
+        super().setup(distributed, run, multi_stage, runner, data, phase)
+
+        # TODO: pass matching micro-batch and batch sizes so lm_eval does not crash during training,
+        # or make fast_llm_wrapper aware of sequential micro-batches for lm_eval.
+        self._hf_model = self._multi_stage.config_class.get_huggingface_model_for_causal_lm_class().from_model(
+            self._multi_stage, self._batch_config.micro_batch_size, self._runner
+        )
+
+        # For reporting purposes only, to indicate the model comes from Fast-LLM;
+        # lm_eval.simple_evaluate uses it for results['config']['model'].
+        self._hf_model.config.name_or_path = type(self._hf_model).__name__
+
+        self._flm_wrapper = FastLLMLmEvalWrapper(
+            model=self._hf_model,
+            tokenizer=self._data.tokenizer.tokenizer,
+            truncation=self._config.truncation,
+            logits_cache=self._config.logits_cache,
+            add_bos_token=self._config.add_bos_token,
+            prefix_token_id=self._config.prefix_token_id,
+        )
+        self._is_setup = True
+
+    def run(
+        self,
+        training_progress: TrainingProgress | None = None,
+        run_index: int | None = None,
+    ) -> EvaluationMetrics:
+        assert self._is_setup
+
+        # TODO: use run_index instead?
+        # completed_steps is appended to output_path as output_path/runs/run_index/completed_steps/
+        completed_steps = 0 if training_progress is None else training_progress.completed_steps
+
+        if self._run.is_main_rank:
+            args, simple_eval_kwargs = prepare_lm_eval_simple_eval_params(
+                self._config.cli_args, completed_steps, self._run.index
+            )
+            simple_eval_kwargs["model"] = self._flm_wrapper
+
+            # Needed for reporting: lm_eval's evaluate() takes batch_size from the CLI args, not from the LM wrapper.
+            simple_eval_kwargs["batch_size"] = self._flm_wrapper.batch_size
+            simple_eval_kwargs["max_batch_size"] = self._flm_wrapper.max_batch_size
+
+            # As of lm_eval commit 758c5ed891b1ca48acd8d3a0d309a827215796b7,
+            # simple_evaluate expects this to be a string (possibly empty), not None.
+            simple_eval_kwargs["model_args"] = ""
+
+            results = lm_eval_simple_evaluate(**simple_eval_kwargs)
+            self._flm_wrapper.stop_workers()
+
+            # The evaluation tracker's save step expects the model to be a string; if a model object were
+            # passed, the LM wrapper would need to be deep-copyable and JSON-serializable.
+            simple_eval_kwargs["evaluation_tracker"].general_config_tracker.model_source = (
+                self._hf_model.config.name_or_path
+            )
+
+            if results is not None:
+                process_lm_eval_results(
+                    args,
+                    results,
+                    simple_eval_kwargs["evaluation_tracker"],
+                    completed_steps,
+                )
+        else:
+            self._flm_wrapper.worker_model_invoke()
+
+        # TODO: is this barrier needed, given that self._flm_wrapper.stop_workers() and
+        # self._flm_wrapper.worker_model_invoke() already synchronize?
+        safe_barrier(self._distributed.world_group, "Evaluation Harness Run end")
+
+        # lm_eval handles its own logging to disk, wandb, and the console.
+        return EvaluationMetrics()
+
+    def get_sampling_parameters(self) -> EvaluatorSamplingParameters | None:
+        return None
+
+
+# NOTE: This is not a standalone runnable; it's a submodule of Trainer used for code encapsulation.
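+# It instantiates one Evaluator per configured evaluation, forwards setup() to each of them,
+# and merges their metrics into the caller's metrics dict in run().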
+class EvaluatorRunner: + _is_setup: bool = False + + def __init__( + self, + evaluator_configs: dict[str, EvaluatorConfigBase], + batch_config: BatchConfig, + data_load_num_proc: int, + train_iters: int | None = None, + wandb_config: WandbConfig | None = None, + ): + self._wandb_config = wandb_config + self._evaluations = [ + eval_config.get_evaluator(name, batch_config, data_load_num_proc, train_iters) + for name, eval_config in evaluator_configs.items() + ] + + def setup( + self, + distributed: Distributed, + run: Run, + multi_stage: FastLLMModel, + runner: ScheduleRunner, + data: Data, + wandb: Wandb, + phase: PhaseType, + ) -> None: + self._wandb = wandb + for evaluation in self._evaluations: + evaluation.setup(distributed, run, multi_stage, runner, data, phase) + self._is_setup = True + + def get_sampling_parameters(self) -> list[EvaluatorSamplingParameters]: + return [ + sampling_params + for sampling_params in (evaluation.get_sampling_parameters() for evaluation in self._evaluations) + if sampling_params is not None + ] + + def run( + self, + metrics: dict[str:any], + training_progress: TrainingProgress | None = None, + ): + assert self._is_setup + formatted_metrics = [] + for evaluation in self._evaluations: + evaluation_metrics = evaluation.run(training_progress) + if len(evaluation_metrics.metrics) == 0: + continue + for k, v in evaluation_metrics.metrics.items(): + metrics[k] = v + if evaluation_metrics.formatted_metrics is not None: + formatted_metrics.append(evaluation_metrics.formatted_metrics) + + if len(formatted_metrics) > 0: + formatted_metrics = "\n".join(formatted_metrics) + log_main_rank(formatted_metrics) + if self._wandb_config is not None and self._wandb_config.alert.enabled( + 0 if training_progress is None else training_progress.completed_steps + ): + self._wandb.alert("Validation results", formatted_metrics, "INFO") diff --git a/fast_llm/engine/evaluation/lm_eval/fast_llm_wrapper.py b/fast_llm/engine/evaluation/lm_eval/fast_llm_wrapper.py new file mode 100644 index 00000000..0c138eda --- /dev/null +++ b/fast_llm/engine/evaluation/lm_eval/fast_llm_wrapper.py @@ -0,0 +1,940 @@ +import logging +import copy +import jinja2 + + +import transformers +from tqdm.auto import tqdm +import torch +import torch.nn.functional as F + + +# make lazy +import lm_eval.api.instance +import lm_eval.models.utils +import lm_eval.api.model +import lm_eval.utils + + +from fast_llm.core.distributed import safe_barrier +from fast_llm.engine.inference.huggingface import HuggingfaceBaseModelForCausalLM +from fast_llm.engine.distributed.distributed import Distributed +from fast_llm.engine.distributed.config import DistributedConfig + +from fast_llm.core.distributed import scatter_object_list, gather_object + + +eval_logger = logging.getLogger(__name__) + + +class FastLLMLmEvalWrapper(lm_eval.api.model.TemplateLM): + _DEFAULT_MAX_LENGTH = 2048 + + def __init__( + self, + model: HuggingfaceBaseModelForCausalLM, + tokenizer: transformers.PreTrainedTokenizer | transformers.PreTrainedTokenizerFast, + truncation: bool | None = False, + logits_cache: bool = True, + add_bos_token: bool | None = False, + prefix_token_id: int | None = None, + ): + super().__init__() + # This is for lm_eval sake, we always run lm_eval on one main rank + self._rank = 0 + self._world_size = 1 + + self._distributed: Distributed = model._inference_runner._fast_llm_model.distributed + dist_config: DistributedConfig = self._distributed.config + # get batch_data_parallel group leaders + if dist_config.sequence_data_rank == 0 
and dist_config.pipeline_rank == 0 and dist_config.tensor_rank == 0: + self.group = self._distributed.batch_data_group + else: + self.group = torch.distributed.GroupMember.NON_GROUP_MEMBER + + # TODO: clean code which does not used parts from HFLM + backend = "causal" + revision = "main" + gguf_file = None + delta = None + peft = None + + # set some inputs which are expected in HFLM but are set by our model config + self.backend = backend + + # set tokenizer object + assert isinstance(tokenizer, transformers.PreTrainedTokenizer) or isinstance( + tokenizer, transformers.PreTrainedTokenizerFast + ) + self.tokenizer = tokenizer + + # initialize model fields + self._model = model + self._device = self._model.device + self._config = self._model.config + + # access self._model through self.model property outside this method + if isinstance(self.model, torch.nn.Module): + self.model.eval() + self.model.tie_weights() + + self.truncation = truncation + self.logits_cache = logits_cache + self.vocab_size = self.tokenizer.vocab_size + # select (or create) a pad token to use + self.tokenizer = lm_eval.models.utils.configure_pad_token(self.tokenizer, model_config=self.config) + + self.add_bos_token = add_bos_token + # TODO: do we support gemma models? + if "gemma" in getattr(self.config, "model_type", ""): + self.add_bos_token = True + eval_logger.info( + f"Model type is '{self.config.model_type}', part of the Gemma family--a BOS" + " token will be used as Gemma underperforms without it." + ) + + self._max_length = model._inference_runner._batch_config.sequence_length + self.pretrained = model + self.delta = delta + self.peft = peft + self.revision = revision + + self.batch_schedule = 1 + self.batch_sizes = {} + self.batch_size_per_gpu = 16 # model._inference_runner._batch_config.micro_batch_size + self.batch_size = self.batch_size_per_gpu * dist_config.batch_data_parallel + self.max_batch_size = self.batch_size + + self.custom_prefix_token_id = prefix_token_id + if prefix_token_id is not None: + eval_logger.info(f"Loglikelihood prefix token id used in evaluation: {self.prefix_token_id}") + + def _model_invoke( + self, + input_ids, + attention_mask, + labels, + max_length, + stop, + generate: bool, + continue_generate: bool, + **generation_kwargs, + ): + if self.group is None or (world_size := self.group.size()) == 1: + # Must not be called with continue_generate false on one process + assert continue_generate + return self._model_invoke_inner( + input_ids, attention_mask, labels, max_length, stop, generate, **generation_kwargs + ) + + rank = self.group.rank() + assert rank == 0 + + if continue_generate: + assert input_ids is not None + if generate: + assert max_length is not None and stop is not None + + # always divide by batch_size, if not full batch, some ranks will get less work or not at all + step = self.batch_size // world_size + + input_ids = [input_ids[i * step : (i + 1) * step] for i in range(world_size)] + attention_mask = [ + attention_mask[i * step : (i + 1) * step] if attention_mask is not None else None + for i in range(world_size) + ] + labels = [labels[i * step : (i + 1) * step] if labels is not None else None for i in range(world_size)] + + scatter_list = [ + [ + input_ids[i], + attention_mask[i], + labels[i], + max_length, + stop, + generate, + continue_generate, + generation_kwargs, + ] + for i in range(world_size) + ] + else: + scatter_list = [[None, None, None, None, None, None, False, None] for _ in range(world_size)] + + obj_list = [None] + scatter_object_list( + 
self._distributed.device, + obj_list, + scatter_list, + group_src=0, + group=self.group, + ) + input_ids, attention_mask, labels, max_length, stop, generate, continue_generate, generation_kwargs = tuple( + obj_list[0] + ) + + if continue_generate == False: + return + + assert len(input_ids) > 0 + + res = self._model_invoke_inner( + input_ids, attention_mask, labels, max_length, stop, generate, **generation_kwargs + ) + + gather_list = [None] * world_size + gather_object( + self._distributed.device, + res, + gather_list, + group_dst=0, + group=self.group, + ) + + # If it was model generate tensors could be of different length + # so we aggregate results to list instead of a tensor + if generate: + res = sum((el.tolist() for el in gather_list), []) + else: + res = torch.cat(gather_list, dim=0) + + return res + + def worker_model_invoke(self): + assert self.group is not None + # if isinstance(self.group, dist.ProcessGroup): + if not isinstance(self.group, int): + assert self.group.size() > 1 and self.group.rank() != 0 + # on worker ranks the function need to wait to be called multiple times + while True: + scatter_list = None + obj_list = [None] + scatter_object_list( + self._distributed.device, + obj_list, + scatter_list, + group_src=0, + group=self.group, + ) + input_ids, attention_mask, labels, max_length, stop, generate, continue_generate, generation_kwargs = ( + tuple(obj_list[0]) + ) + + if continue_generate == False: + break + + # if some data was received, work, otherwise return empty tensor + if len(input_ids) > 0: + res = self._model_invoke_inner( + input_ids, attention_mask, labels, max_length, stop, generate, **generation_kwargs + ) + else: + res = input_ids + + gather_list = None + gather_object( + self._distributed.device, + res, + gather_list, + group_dst=0, + group=self.group, + ) + else: + # TODO: implement distributed model support + assert self.group == torch.distributed.GroupMember.NON_GROUP_MEMBER + safe_barrier(self._distributed.world_group, "lm_eval_end") + + def stop_workers(self): + if self.group is None or (world_size := self.group.size()) == 1: + return + self._model_invoke(None, None, None, None, None, None, continue_generate=False) + safe_barrier(self._distributed.world_group, "lm_eval_end") + + def _model_invoke_inner( + self, input_ids, attention_mask, labels, max_length, stop, generate: bool, **generation_kwargs + ): + if generate: + return self._model_generate_inner(input_ids, attention_mask, max_length, stop, **generation_kwargs) + else: + return self._model_call_inner(input_ids, attention_mask, labels) + + def _model_call(self, input_ids, attention_mask=None, labels=None): + return self._model_invoke( + input_ids, attention_mask, labels, None, None, generate=False, continue_generate=True + ) + + def _model_generate(self, input_ids, attention_mask, max_length, stop, **generation_kwargs): + return self._model_invoke( + input_ids, + attention_mask, + None, + max_length, + stop, + generate=True, + continue_generate=True, + **generation_kwargs, + ) + + def _model_call_inner(self, input_ids, attention_mask=None, labels=None): + """ + :param input_ids: torch.Tensor + A torch tensor of shape [batch, (sequence_ctx + sequence_cont)] or of shape + [batch, sequence_ctx]. the size of sequence may vary from call to call + :param attention_mask: torch.Tensor, optional + A torch tensor of shape [batch, (sequence_ctx + sequence_cont)]. 
Only passed + (and must be passed) if self.AUTO_MODEL_CLASS is transformers.AutoModelForSeq2SeqLM + :param labels: torch.Tensor, optional + A torch tensor of shape [batch, (sequence_ctx + sequence_cont)]. Only passed + (and must be passed) if self.AUTO_MODEL_CLASS is transformers.AutoModelForSeq2SeqLM + :return + A torch tensor of shape [batch, sequence, vocab] with the + logits returned from the model's decoder + """ + # TODO: do we need no_grad for our model? + with torch.no_grad(): + if attention_mask is not None or labels is not None: + assert attention_mask is not None and labels is not None + return self.model( + input_ids=input_ids, + attention_mask=attention_mask, + labels=labels, + position_ids=None, + past_key_values=None, + inputs_embeds=None, + use_cache=False, + output_attentions=False, + output_hidden_states=False, + return_dict=True, + ).logits + else: + return self.model( + input_ids=input_ids, + attention_mask=None, + position_ids=None, + past_key_values=None, + inputs_embeds=None, + labels=None, + use_cache=False, + output_attentions=False, + output_hidden_states=False, + return_dict=True, + ).logits + + def _model_generate_inner(self, input_ids, attention_mask, max_length, stop, **generation_kwargs): + # temperature = 0.0 if not set + # if do_sample is false and temp==0.0: + # remove temperature, as do_sample=False takes care of this + # and we don't want a warning from HF + generation_kwargs["temperature"] = generation_kwargs.get("temperature", 0.0) + do_sample = generation_kwargs.get("do_sample", None) + + # The temperature has to be a strictly positive float -- if it is 0.0, use greedy decoding strategies + if generation_kwargs.get("temperature") == 0.0 and do_sample is None: + generation_kwargs["do_sample"] = do_sample = False + + if do_sample is False and generation_kwargs.get("temperature") == 0.0: + generation_kwargs.pop("temperature") + # build stopping criteria + stopping_criteria = lm_eval.models.utils.stop_sequences_criteria( + self.tokenizer, stop, input_ids.shape[1], input_ids.shape[0] + ) + if attention_mask is None: + return self.model.generate( + input_ids=input_ids, + max_length=max_length, + stopping_criteria=stopping_criteria, + pad_token_id=self.tokenizer.pad_token_id, + use_cache=False, + **generation_kwargs, + ) + else: + return self.model.generate( + input_ids=input_ids, + attention_mask=attention_mask, + max_length=max_length, + stopping_criteria=stopping_criteria, + pad_token_id=self.tokenizer.pad_token_id, + use_cache=False, + **generation_kwargs, + ) + + @property + def config(self): + # return the associated transformers.AutoConfig for the given pretrained model. 
+ return self._config + + @property + def model(self): + # returns the model, unwrapping it if using Accelerate + if hasattr(self, "accelerator"): + return self.accelerator.unwrap_model(self._model) + else: + return self._model + + @property + def eot_token_id(self): + # we use EOT because end of *text* is more accurate for what we're doing than end of *sentence* + return self.tokenizer.eos_token_id + + @property + def prefix_token_id(self): + # it is used as prefix for loglikelihood + if self.custom_prefix_token_id is not None: + return self.custom_prefix_token_id + if self.tokenizer.bos_token_id is not None: + return self.tokenizer.bos_token_id + return self.tokenizer.eos_token_id + + @property + def max_length(self): + if self._max_length: # if max length manually set, return it + return self._max_length + seqlen_config_attrs = ("n_positions", "max_position_embeddings", "n_ctx") + for attr in seqlen_config_attrs: + if hasattr(self.model.config, attr): + return getattr(self.model.config, attr) + if hasattr(self.tokenizer, "model_max_length"): + if self.tokenizer.model_max_length == 1000000000000000019884624838656: + return self._DEFAULT_MAX_LENGTH + return self.tokenizer.model_max_length + return self._DEFAULT_MAX_LENGTH + + @property + def max_gen_toks(self) -> int: + return 256 + + # TODO: check removing this does not affect lm_eval + # @property + # def batch_size(self): + # return self.batch_size_per_gpu + + @property + def device(self): + return self._device + + @property + def rank(self): + return self._rank + + @property + def world_size(self): + return self._world_size + + @property + def tokenizer_name(self) -> str: + return self.tokenizer.name_or_path.replace("/", "__") + + def tok_encode(self, string: str, left_truncate_len=None, add_special_tokens=None) -> list[int]: + """ """ + # default for None - empty dict, use predefined tokenizer param + # used for all models except for CausalLM or predefined value + special_tokens_kwargs = {} + + # by default for CausalLM - false or self.add_bos_token is set + if add_special_tokens is None: + if self.backend == "causal": + special_tokens_kwargs = {"add_special_tokens": False or self.add_bos_token} + # otherwise the method explicitly defines the value + else: + special_tokens_kwargs = {"add_special_tokens": add_special_tokens} + + encoding = self.tokenizer.encode(string, **special_tokens_kwargs) + + # left-truncate the encoded context to be at most `left_truncate_len` tokens long + if left_truncate_len: + encoding = encoding[-left_truncate_len:] + + return encoding + + def tok_batch_encode( + self, + strings: list[str], + padding_side: str = "left", + left_truncate_len: int = None, + truncation: bool = False, + ) -> tuple[torch.Tensor, torch.Tensor]: + # encode a batch of strings. converts to tensors and pads automatically, unlike tok_encode. + old_padding_side = self.tokenizer.padding_side + self.tokenizer.padding_side = padding_side + + add_special_tokens = {} + if self.backend == "causal": + add_special_tokens = {"add_special_tokens": False or self.add_bos_token} + + encoding = self.tokenizer( + strings, + truncation=truncation, + padding="longest", + return_tensors="pt", + **add_special_tokens, + ) + if left_truncate_len: + original_lengths = encoding["input_ids"].size(1) + if original_lengths > left_truncate_len: + eval_logger.warn( + f"Left truncation applied. Original sequence length was {original_lengths}, " + f"truncating to last {left_truncate_len} tokens. 
Some content will be lost.", + ) + encoding["input_ids"] = encoding["input_ids"][:, -left_truncate_len:] + encoding["attention_mask"] = encoding["attention_mask"][:, -left_truncate_len:] + self.tokenizer.padding_side = old_padding_side + + return encoding["input_ids"], encoding["attention_mask"] + + def tok_decode(self, tokens, skip_special_tokens=True): + return self.tokenizer.decode(tokens, skip_special_tokens=skip_special_tokens) + + def _select_cont_toks(self, logits: torch.Tensor, contlen: int = None, inplen: int = None) -> torch.Tensor: + if self.backend == "causal": + assert contlen and inplen, "Must pass input len and cont. len to select scored logits for causal LM" + # discard right-padding. + # also discard the input/context tokens. we'll only score continuations. + logits = logits[inplen - contlen : inplen] + elif self.backend == "seq2seq": + assert contlen and not inplen, "Selecting scored logits for Seq2SeqLM requires only cont. len" + # only discard right-padding. + # the logits input to this fn only contain decoder-side tokens. + logits = logits[:contlen] + + return logits + + def loglikelihood_rolling( + self, requests: list[lm_eval.api.instance.Instance], disable_tqdm: bool = False + ) -> list[float]: + adaptive_batch_size = None + if self.batch_size == "auto": + # using rolling window with maximum context + print("Passed argument batch_size = auto. Detecting largest batch size") + batch_size = self._detect_batch_size() + print(f"Determined Largest batch size: {batch_size}") + adaptive_batch_size = batch_size + + # First, collect all windows from all requests + all_windows = [] # List of (request_idx, window) tuples + request_window_counts = [] # Track number of windows per request + + for req_idx, (string,) in enumerate( + tqdm( + [req.args for req in requests], + disable=(disable_tqdm or (self.rank != 0)), + ) + ): + rolling_token_windows: list[tuple[list[int], list[int]]] = list( + map( + lm_eval.utils.make_disjoint_window, + lm_eval.utils.get_rolling_token_windows( + token_list=self.tok_encode(string), + prefix_token=self.prefix_token_id, + max_seq_len=self.max_length, + context_len=1, + ), + ) + ) + + # TODO: Right now, we pass single EOT token to the Encoder and the full context to the decoder, in seq2seq case + windows = [(None,) + x for x in rolling_token_windows] + + # Store windows with their request index + all_windows.extend((req_idx, window) for window in windows) + request_window_counts.append(len(windows)) + + # Handle distributed case padding + pad_amnt = 0 + if self.world_size > 1: + mytensor = torch.tensor(len(all_windows), device=self.device) + gathered = self.accelerator.gather(mytensor).cpu().detach().numpy().tolist() + pad_amnt = max(gathered) - gathered[self.rank] + if pad_amnt > 0: + all_windows += pad_amnt * [all_windows[0]] + + all_nlls = [] + batch_size = adaptive_batch_size or self.batch_size + for i in range(0, len(all_windows), batch_size): + batch = all_windows[i : i + batch_size] + # Extract just the windows for processing, keeping track of request indices + batch_indices, batch_windows = zip(*batch) + + batch_nlls = self._loglikelihood_tokens( + requests=batch_windows, + disable_tqdm=False, + override_bs=len(batch_windows), + ) + # Store results with their request indices + all_nlls.extend(zip(batch_indices, batch_nlls)) + + # Remove padding if necessary + if (self.world_size > 1) and (pad_amnt > 0): + all_nlls = all_nlls[:-pad_amnt] + + # Reconstruct per-request loglikelihoods + loglikelihoods = [] + current_idx = 0 + for window_count in 
request_window_counts: + # Get all nlls for this request + request_nlls = all_nlls[current_idx : current_idx + window_count] + # Sum up the nlls for this request (discarding is_greedy) + request_total = sum(nll[0] for _, nll in request_nlls) + loglikelihoods.append(request_total) + current_idx += window_count + + string = requests[len(loglikelihoods) - 1].args[0] + self.cache_hook.add_partial("loglikelihood_rolling", (string,), request_total) + + return loglikelihoods + + def _batch_scheduler(self, pos, n_reordered_requests): + sched = pos // int(len(n_reordered_requests) / self.batch_schedule) + if sched in self.batch_sizes: + return self.batch_sizes[sched] + if (len(self.batch_sizes) > 1) and (self.batch_sizes[sched - 1] == self.max_batch_size): + # if previous batch size is already maximal, skip recomputation + self.batch_sizes[sched] = self.max_batch_size + return self.batch_sizes[sched] + print(f"Passed argument batch_size = auto:{self.batch_schedule}. Detecting largest batch size") + self.batch_sizes[sched] = self._detect_batch_size(n_reordered_requests, pos) + print(f"Determined largest batch size: {self.batch_sizes[sched]}") + return self.batch_sizes[sched] + + def _loglikelihood_tokens( + self, + requests: list[tuple[tuple[str, str], list[int], list[int]]], + disable_tqdm: bool = False, + override_bs: int = None, + ) -> list[tuple[float, bool]]: + # TODO: implement some kind of efficient-request-middleware that lumps together requests with the same context + res = [] + + def _collate(req: tuple[tuple[str, str], list[int], list[int]]): + """Defines the key for the sorted method""" + # the negative sign on len(toks) sorts descending - this has a few advantages: + # - time estimates will always be over not underestimates, which is more useful for planning + # - to know the size of a batch when going through the list, you know the first one is always the batch + # padded context length. this is useful to simplify the batching logic and more importantly to make + # automatic adaptive batches much much easier to implement + # - any OOMs will happen right away rather than near the end + + toks = req[1] + req[2] + return -len(toks), tuple(toks) + + def _lookup_one_token_cont(req: tuple[tuple[str, str], list[int], list[int]]): + """Defines the key to group and lookup one-token continuations""" + # Use with group_by="contexts" (optional)" + # allows for the creation of a lookup, so we can reuse logits in case of one-token continuations. + # speeds up some multiple-choice tasks proportionally to the number of choices. + # groups requests by context+continuation[:-1] and infer on one request/group. 
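            # For example (hypothetical token ids): a request with ctx_toks=[12, 7, 99] and the
            # one-token continuation [5] maps to the key [12, 7, 99] (ctx + cont[:-1]); requests
            # for continuations [8] or [23] over the same context share that key, so a single
            # forward pass over the context scores all candidate answers.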
+ return req[-2] + req[-1][:-1] + + re_ord = lm_eval.models.utils.Collator( + requests, + sort_fn=_collate, + group_by="contexts" if self.backend == "causal" and self.logits_cache else None, + group_fn=_lookup_one_token_cont, + ) + + # automatic (variable) batch size detection for vectorization + # pull longest context sample from request + n_reordered_requests = len(re_ord) + batch_size = self.batch_size if self.batch_size != "auto" else override_bs if override_bs is not None else 0 + batch_fn = ( + self._batch_scheduler + if self.batch_size == "auto" and n_reordered_requests > 0 and not override_bs + else None + ) + + chunks = re_ord.get_batched(n=batch_size, batch_fn=batch_fn) + pbar = tqdm( + total=len(requests), + disable=(disable_tqdm or (self.rank != 0)), + desc="Running loglikelihood requests", + ) + for chunk in chunks: + inps = [] + cont_toks_list = [] + inplens = [] + + conts = [] + encoder_attns = [] + + padding_len_inp = None + padding_len_cont = None + # because vectorizing is annoying, we first convert each (context, continuation) pair to padded + # tensors, then we pack them together into a batch, call the model, and then pick it all apart + # again because vectorizing is annoying + + for _, context_enc, continuation_enc in chunk: + # sanity check + assert len(context_enc) > 0 + assert len(continuation_enc) > 0 + assert len(continuation_enc) <= self.max_length + + # how this all works (illustrated on a causal decoder-only setup): + # CTX CONT + # inp 0 1 2 3|4 5 6 7 8 9 <- last token is deleted by inp[:, :-1] + # model \ \ + # logits 1 2 3|4 5 6 7 8 9 <- the ctx half gets tossed out by the + # cont_toks 4 5 6 7 8 9 [:, -len(continuation_enc):, :self.vocab_size] slice + + # when too long to fit in context, truncate from the left + if self.backend == "causal": + total_length = len(context_enc) + len(continuation_enc) + if total_length > self.max_length + 1: + eval_logger.warn( + f"Combined length of context ({len(context_enc)}) and continuation ({len(continuation_enc)}) " + f"exceeds model's maximum length ({self.max_length}). " + f"Truncating {total_length - self.max_length + 1} tokens from the left." + ) + inp = torch.tensor( + (context_enc + continuation_enc)[-(self.max_length + 1) :][:-1], + dtype=torch.long, + device=self.device, + ) + (inplen,) = inp.shape + elif self.backend == "seq2seq": + inp = torch.tensor( + (context_enc)[-self.max_length :], + dtype=torch.long, + device=self.device, + ) + (inplen,) = inp.shape + + # build encoder attn masks + encoder_attns.append(torch.ones_like(inp)) + + cont = torch.tensor( + (continuation_enc)[-self.max_length :], + # TODO: left-shift these? + # TODO: our code assumes we never end up truncating conts for either model type + dtype=torch.long, + device=self.device, + ) + (contlen,) = cont.shape + + conts.append(cont) + + padding_len_cont = max(padding_len_cont, contlen) if padding_len_cont is not None else contlen + + padding_len_inp = max(padding_len_inp, inplen) if padding_len_inp is not None else inplen + + inps.append(inp) # [1, inp_length] + cont_toks_list.append(continuation_enc) + inplens.append(inplen) + + # create encoder attn mask and batched conts, if seq2seq + call_kwargs = {} + if self.backend == "causal": + batched_inps = lm_eval.models.utils.pad_and_concat( + padding_len_inp, inps, padding_side="right" + ) # [batch, padding_len_inp] + elif self.backend == "seq2seq": + # TODO: left-pad encoder inps and mask? 
+ batched_inps = lm_eval.models.utils.pad_and_concat(padding_len_inp, inps) # [batch, padding_len_inp] + batched_conts = lm_eval.models.utils.pad_and_concat( + padding_len_cont, conts + ) # [batch, padding_len_cont] + batched_encoder_mask = lm_eval.models.utils.pad_and_concat( + padding_len_inp, encoder_attns + ) # [batch, padding_len_inp] + call_kwargs = { + "attention_mask": batched_encoder_mask, + "labels": batched_conts, + } + + multi_logits = F.log_softmax( + self._model_call(batched_inps, **call_kwargs), dim=-1 + ) # [batch, padding_length (inp or cont), vocab] + + for (request_str, ctx_tokens, _), logits, inplen, cont_toks in zip( + chunk, multi_logits, inplens, cont_toks_list + ): + # Slice to original seq length + contlen = len(cont_toks) + # take only logits in the continuation + # (discard context toks if decoder-only ; discard right-padding) + # also discards + checks for "virtual tokens" in the causal LM's input window + # from prompt/prefix tuning tokens, if applicable + ctx_len = inplen + (logits.shape[0] - padding_len_inp) if self.backend == "causal" else None + logits = self._select_cont_toks(logits, contlen=contlen, inplen=ctx_len) + logits = logits.unsqueeze(0) # [1, seq, vocab] + + # Check if per-token argmax is exactly equal to continuation + greedy_tokens = logits.argmax(dim=-1) + + # check for one-token continuation cache hits. + # noop in case group_by != "contexts" or no cache hit and returns the + # original args. Otherwise, expands the logits batch dimension and yields each + # batch along with matching continuation tokens and prompt strings. + # logits -> [1, seq, vocab] + for request_str, cont_toks, logits in re_ord.get_cache( + req_str=request_str, + cxt_toks=ctx_tokens, + cont_toks=cont_toks, + logits=logits, + ): + cont_toks = torch.tensor(cont_toks, dtype=torch.long, device=self.device).unsqueeze(0) # [1, seq] + max_equal = (greedy_tokens == cont_toks).all() + + # Obtain log-probs at the corresponding continuation token indices + # last_token_slice = logits[:, -1, :].squeeze(0).tolist() + logits = torch.gather(logits, 2, cont_toks.unsqueeze(-1)).squeeze(-1) # [1, seq] + + # Answer: (log prob, is-exact-match) + answer = (float(logits.sum()), bool(max_equal)) + + res.append(answer) + + if request_str is not None: + # special case: loglikelihood_rolling produces a number of loglikelihood requests + # all with cache key None. instead do add_partial on the per-example level + # in the loglikelihood_rolling() function for those. + self.cache_hook.add_partial("loglikelihood", request_str, answer) + pbar.update(1) + + pbar.close() + + return re_ord.get_original(res) + + def generate_until(self, requests: list[lm_eval.api.instance.Instance], disable_tqdm: bool = False) -> list[str]: + res = [] + + def _collate(req: tuple[str, dict]): + """Defines the key for the sorted method""" + # the negative sign on len(toks) sorts descending - this has a few advantages: + # - time estimates will always be over not underestimates, which is more useful for planning + # - to know the size of a batch when going through the list, you know the first one is always the batch + # padded context length. 
this is useful to simplify the batching logic and more importantly to make + # automatic adaptive batches much much easier to implement + # - any OOMs will happen right away rather than near the end + toks = self.tok_encode(req[0]) + return -len(toks), req[0] + + pbar = tqdm( + total=len(requests), + disable=(disable_tqdm or (self.rank != 0)), + desc="Running generate_until requests", + ) + adaptive_batch_size = None + if self.batch_size == "auto": + # using rolling window with maximum context + print("Passed argument batch_size = auto. Detecting largest batch size") + batch_size = self._detect_batch_size() + print(f"Determined Largest batch size: {batch_size}") + adaptive_batch_size = batch_size + # for each different set of kwargs, we execute all requests, by batch. + batch_size = ( + self.batch_size + if self.batch_size != "auto" + else adaptive_batch_size if adaptive_batch_size is not None else 0 + ) + batch_fn = self._batch_scheduler if self.batch_size == "auto" and not adaptive_batch_size else None + + # we group requests by their generation_kwargs, + # so that we don't try to execute e.g. greedy sampling and temp=0.8 sampling + # in the same batch. + # group_fn=lambda x: x[1] -> x=(context, gen_kwargs) + re_ords = lm_eval.models.utils.Collator( + [reg.args for reg in requests], + sort_fn=_collate, + group_by="gen_kwargs", + group_fn=lambda x: x[1], + ) + chunks = re_ords.get_batched(n=batch_size, batch_fn=batch_fn) + eos = self.tok_decode(self.eot_token_id, skip_special_tokens=False) + + for chunk in chunks: + contexts, all_gen_kwargs = zip(*chunk) + # we assume all gen kwargs in the batch are the same + # this is safe to assume because the `grouper` object ensures it. + gen_kwargs = all_gen_kwargs[0] + # unpack our keyword arguments. + if isinstance(gen_kwargs, dict): + kwargs = copy.deepcopy(gen_kwargs) # edge case for repeats > 1 + # add EOS token to stop sequences + until = lm_eval.models.utils.handle_stop_sequences(kwargs.pop("until", None), eos=eos) + else: + raise ValueError(f"Expected `kwargs` to be of type `dict` but got {type(gen_kwargs)}") + if "max_gen_toks" in kwargs.keys(): + max_gen_toks = kwargs.pop("max_gen_toks") + else: + max_gen_toks = self.max_gen_toks + + # set the max length in tokens of inputs ("context_enc") + if self.backend == "causal": + # max len for inputs = max length, minus room to generate the max new tokens + max_ctx_len = self.max_length - max_gen_toks + assert ( + max_ctx_len > 0 + ), f"Invalid configuration: requested max tokens to generate ({max_gen_toks}) must be less than model's maximum sequence length ({self.max_length})." 
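                # For example, with max_length=2048 and the default max_gen_toks=256, each prompt
                # is left-truncated to at most 2048 - 256 = 1792 tokens before generation.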
+ elif self.backend == "seq2seq": + # max len for inputs = encoder's whole max_length + max_ctx_len = self.max_length + + # encode, pad, and truncate contexts for this batch + input_ids, attention_mask = self.tok_batch_encode( + contexts, + left_truncate_len=max_ctx_len, + truncation=self.truncation, + ) + input_ids = input_ids.to(self.device) + attention_mask = attention_mask.to(self.device) + + if "max_length" not in kwargs: + kwargs["max_length"] = input_ids.shape[1] + max_gen_toks + + # perform batched generation + cont = self._model_generate( + input_ids=input_ids, + attention_mask=attention_mask, + stop=until, + **kwargs, + ) + + # cont_toks_list = cont.tolist() + cont_toks_list = cont + + for cont_toks, context in zip(cont_toks_list, contexts): + # discard context + left-padding toks if using causal decoder-only LM + if self.backend == "causal": + cont_toks = cont_toks[input_ids.shape[1] :] + + s = self.tok_decode(cont_toks) + + # use secondary stop seqs to cut off should-have-been-stopped content post-hoc + for term in until: + if len(term) > 0: + # ignore '' separator, + # for seq2seq case where self.tok_decode(self.eot_token_id) = '' + s = s.split(term)[0] + + res.append(s) + + self.cache_hook.add_partial("generate_until", (context, gen_kwargs), s) + pbar.update(1) + # reorder this group of results back to original unsorted form + res = re_ords.get_original(res) + + pbar.close() + + return res + + def apply_chat_template(self, chat_history: list[dict[str, str]], add_generation_prompt: bool = True) -> str: + """ + Method to apply a chat template to a list of chat history between user and model. + """ + try: + chat_templated = self.tokenizer.apply_chat_template( + chat_history, + tokenize=False, + add_generation_prompt=add_generation_prompt, + continue_final_message=not add_generation_prompt, + ) + except jinja2.exceptions.TemplateError: + eval_logger.warning("Failed to apply chat template. removing the system role in chat history.") + chat_history = [msg for msg in chat_history if msg["role"] != "system"] + chat_templated = self.tokenizer.apply_chat_template( + chat_history, + tokenize=False, + add_generation_prompt=add_generation_prompt, + continue_final_message=not add_generation_prompt, + ) + + return chat_templated diff --git a/fast_llm/engine/evaluation/lm_eval/utils.py b/fast_llm/engine/evaluation/lm_eval/utils.py new file mode 100644 index 00000000..396e73c1 --- /dev/null +++ b/fast_llm/engine/evaluation/lm_eval/utils.py @@ -0,0 +1,248 @@ +import argparse +import json +import logging +import os +import pathlib +import sys +from pathlib import Path + + +import lm_eval.evaluator +import lm_eval.loggers +import lm_eval.__main__ +import lm_eval.tasks +import lm_eval.utils + + +eval_logger = logging.getLogger(__name__) + + +def parse_eval_args(parser: argparse.ArgumentParser, args: list[str]) -> argparse.Namespace: + lm_eval.__main__.check_argument_types(parser) + return parser.parse_args(args) + + +def prepare_lm_eval_simple_eval_params( + cli_args: list[str], + completed_steps: int, + run_index: int, +) -> tuple[argparse.Namespace, dict[str, any]]: + """ + Parses CLI arguments for an LM evaluation run and prepares keyword arguments + for the `evaluate` function. + + This function wraps argument parsing, environment configuration, task resolution, + and metadata setup needed for evaluation with Fast-LLM's `lm_eval` wrapper. It also + handles special cases like hub token injection, dynamic sample loading, and task + listing commands. 
+ + Args: + cli_args (list[str]): Command-line arguments, excluding the program name. + completed_steps (int): Current number of completed training steps, used to + uniquely tag evaluation output paths. + run_index (int): index of the current run of Fast-LLM experiment + + Returns: + tuple: + - argparse.Namespace: Parsed CLI arguments. + - dict: Keyword arguments to pass into `simple_evaluate`, including task list, + tracker, cache settings, random seeds, and generation parameters. + + Raises: + ValueError: If required fields like `--tasks` or `--output_path` are missing + when needed, or if misconfigured combinations are detected. + SystemExit: If special task listing flags are used. + """ + parser = lm_eval.__main__.setup_parser() + args = parse_eval_args(parser, cli_args) + + # NOTE: all this args are set by fast_llm on the model directly or not used here + assert not args.wandb_args # default empty string + assert not args.wandb_config_args # default empty string + assert args.model == "hf" # default value of 'hf' + assert not args.model_args # default empty string + assert args.batch_size == 1 # default value of 1 + assert args.max_batch_size is None + assert args.device is None + # if args.wandb_args: + # wandb_args_dict = simple_parse_args_string(args.wandb_args) + # wandb_config_args_dict = simple_parse_args_string(args.wandb_config_args) + # wandb_logger = WandbLogger(wandb_args_dict, wandb_config_args_dict) + + # TODO: change logging levels from fast_llm to lm_eval and then back? + # utils.setup_logging(args.verbosity) + # eval_logger = logging.getLogger(__name__) + + # update the evaluation tracker args with the output path and the HF token + evaluation_tracker_args = "" + if args.output_path: + args.output_path = str(pathlib.Path(args.output_path) / f"runs/{run_index}/{completed_steps}") + evaluation_tracker_args += f",output_path={args.output_path}" + + evaluation_tracker_args = lm_eval.utils.simple_parse_args_string(evaluation_tracker_args) + evaluation_tracker = lm_eval.loggers.EvaluationTracker(**evaluation_tracker_args) + + if args.predict_only: + args.log_samples = True + if (args.log_samples or args.predict_only) and not args.output_path: + raise ValueError("Specify --output_path if providing --log_samples or --predict_only") + + if args.fewshot_as_multiturn and args.apply_chat_template is False: + raise ValueError( + "When `fewshot_as_multiturn` is selected, `apply_chat_template` must be set (either to `True` or to the chosen template name)." + ) + + if args.include_path is not None: + eval_logger.info(f"Including path: {args.include_path}") + metadata = ( + lm_eval.utils.simple_parse_args_string(args.model_args) + if isinstance(args.model_args, str) + else args.model_args if isinstance(args.model_args, dict) else {} + ) | (args.metadata if isinstance(args.metadata, dict) else lm_eval.utils.simple_parse_args_string(args.metadata)) + + task_manager = lm_eval.tasks.TaskManager(include_path=args.include_path, metadata=metadata) + + if args.limit: + eval_logger.warning( + " --limit SHOULD ONLY BE USED FOR TESTING." "REAL METRICS SHOULD NOT BE COMPUTED USING LIMIT." + ) + if args.samples: + assert args.limit is None, "If --samples is not None, then --limit must be None." 
+ if (samples := Path(args.samples)).is_file(): + args.samples = json.loads(samples.read_text()) + else: + args.samples = json.loads(args.samples) + + if args.tasks is None: + eval_logger.error("Need to specify task to evaluate.") + sys.exit() + elif args.tasks == "list": + print(task_manager.list_all_tasks()) + sys.exit() + elif args.tasks == "list_groups": + print(task_manager.list_all_tasks(list_subtasks=False, list_tags=False)) + sys.exit() + elif args.tasks == "list_tags": + print(task_manager.list_all_tasks(list_groups=False, list_subtasks=False)) + sys.exit() + elif args.tasks == "list_subtasks": + print(task_manager.list_all_tasks(list_groups=False, list_tags=False)) + sys.exit() + else: + if os.path.isdir(args.tasks): + import glob + + task_names = [] + yaml_path = os.path.join(args.tasks, "*.yaml") + for yaml_file in glob.glob(yaml_path): + config = lm_eval.utils.load_yaml_config(yaml_file) + task_names.append(config) + else: + task_list = args.tasks.split(",") + task_names = task_manager.match_tasks(task_list) + for task in [task for task in task_list if task not in task_names]: + if os.path.isfile(task): + config = lm_eval.utils.load_yaml_config(task) + task_names.append(config) + task_missing = [ + task for task in task_list if task not in task_names and "*" not in task + ] # we don't want errors if a wildcard ("*") task name was used + + if task_missing: + missing = ", ".join(task_missing) + eval_logger.error( + f"Tasks were not found: {missing}\n" + f"{lm_eval.utils.SPACING}Try `lm-eval --tasks list` for list of available tasks", + ) + raise ValueError( + f"Tasks not found: {missing}. Try `lm-eval --tasks {{list_groups,list_subtasks,list_tags,list}}` to list out all" + " available names for task groupings; only (sub)tasks; tags; or all of the above, or pass '--verbosity DEBUG'" + " to troubleshoot task registration issues." 
+ ) + + ( + eval_logger.info(f"Selected Tasks: {task_names}") + if eval_logger.getEffectiveLevel() >= logging.INFO + else print(f"Selected Tasks: {task_names}") + ) + + request_caching_args = lm_eval.evaluator.request_caching_arg_to_dict(cache_requests=args.cache_requests) + + eval_kwargs = dict( + tasks=task_names, + num_fewshot=args.num_fewshot, + # batch_size=args.batch_size, + # max_batch_size=args.max_batch_size, + # device=args.device, + use_cache=args.use_cache, + limit=args.limit, + samples=args.samples, + check_integrity=args.check_integrity, + write_out=args.write_out, + log_samples=args.log_samples, + evaluation_tracker=evaluation_tracker, + system_instruction=args.system_instruction, + apply_chat_template=args.apply_chat_template, + fewshot_as_multiturn=args.fewshot_as_multiturn, + gen_kwargs=args.gen_kwargs, + task_manager=task_manager, + predict_only=args.predict_only, + random_seed=args.seed[0], + numpy_random_seed=args.seed[1], + torch_random_seed=args.seed[2], + fewshot_random_seed=args.seed[3], + confirm_run_unsafe_code=args.confirm_run_unsafe_code, + metadata=metadata, + **request_caching_args, + ) + + return args, eval_kwargs + + +def process_lm_eval_results( + args: argparse.Namespace, + results: dict[str, any], + evaluation_tracker: lm_eval.loggers.EvaluationTracker, + completed_steps: int | None, +) -> None: + if results is not None: + completed_steps = 0 if completed_steps is None else completed_steps + import wandb + + if args.log_samples: + samples = results.pop("samples") + dumped = json.dumps(results, indent=2, default=lm_eval.utils.handle_non_serializable, ensure_ascii=False) + if args.show_config: + print(dumped) + + batch_sizes = ",".join(map(str, results["config"]["batch_sizes"])) + + # Add W&B logging if we have the run to log to + # we expect the rest of the fast_llm code will finish the run. + if wandb.run is not None: + try: + wandb_logger = lm_eval.loggers.WandbLogger(init_args={"step": completed_steps}) + wandb_logger.post_init(results) + wandb_logger.log_eval_result() + if args.log_samples: + wandb_logger.log_eval_samples(samples) + except Exception as e: + eval_logger.info(f"Logging to Weights and Biases failed due to {e}") + + evaluation_tracker.save_results_aggregated(results=results, samples=samples if args.log_samples else None) + + if args.log_samples: + for task_name, config in results["configs"].items(): + evaluation_tracker.save_results_samples(task_name=task_name, samples=samples[task_name]) + + if evaluation_tracker.push_results_to_hub or evaluation_tracker.push_samples_to_hub: + evaluation_tracker.recreate_metadata_card() + + # TODO: convert to logging entries instead? 
+ print( + f"{results["config"]["model"]}, gen_kwargs: ({args.gen_kwargs}), limit: {args.limit}, num_fewshot: {args.num_fewshot}, " + f"batch_size: {results["config"]["batch_size"]}{f' ({batch_sizes})' if batch_sizes else ''}" + ) + print(lm_eval.utils.make_table(results)) + if "groups" in results: + print(lm_eval.utils.make_table(results, "groups")) diff --git a/fast_llm/engine/inference/config.py b/fast_llm/engine/inference/config.py index d4b46bcc..c18daa48 100644 --- a/fast_llm/engine/inference/config.py +++ b/fast_llm/engine/inference/config.py @@ -91,7 +91,8 @@ def __eq__(self, other) -> bool: def to_dict(self) -> dict[str, typing.Any]: out = super().to_dict() - out["fast_llm_config"] = self.fast_llm_config.to_dict(verbose=FieldVerboseLevel.everything) + if self.fast_llm_config is not None: + out["fast_llm_config"] = self.fast_llm_config.to_dict(verbose=FieldVerboseLevel.everything) return out def to_diff_dict(self) -> dict[str, typing.Any]: diff --git a/fast_llm/engine/inference/huggingface.py b/fast_llm/engine/inference/huggingface.py index 196310b4..3656354e 100644 --- a/fast_llm/engine/inference/huggingface.py +++ b/fast_llm/engine/inference/huggingface.py @@ -2,16 +2,22 @@ import pathlib import typing +import torch import transformers.modeling_outputs +import transformers.generation.utils from fast_llm.engine.checkpoint.config import CheckpointLoadConfig, FastLLMCheckpointFormat from fast_llm.engine.inference.config import HuggingfaceModelConfig from fast_llm.engine.inference.runner import InferenceRunner from fast_llm.engine.multi_stage.config import StageMode from fast_llm.engine.multi_stage.fast_llm_model import FastLLMModel +from fast_llm.engine.schedule.config import BatchConfig, ScheduleConfig +from fast_llm.engine.schedule.runner import ScheduleRunner +from fast_llm.engine.schedule.schedule import Schedule +from fast_llm.engine.training.config import TrainerConfig -class HuggingfacePreTrainedModel(transformers.PreTrainedModel): +class HuggingfaceBaseModel(transformers.PreTrainedModel): config_class: typing.ClassVar[type[HuggingfaceModelConfig]] = HuggingfaceModelConfig runner_class: typing.ClassVar[type[InferenceRunner]] = InferenceRunner config: HuggingfaceModelConfig @@ -20,31 +26,86 @@ class HuggingfacePreTrainedModel(transformers.PreTrainedModel): # _supports_cache_class = False # _tied_weights_keys = [] - def __init__(self, config: HuggingfaceModelConfig, fast_llm_model: FastLLMModel, **kwargs): + def __init__( + self, + config: HuggingfaceModelConfig, + fast_llm_model: FastLLMModel, + micro_batch_size: int | None = None, + runner: ScheduleRunner | None = None, + **kwargs, + ): assert self.runner_class.model_class.config_class is config.model_config_class assert config.fast_llm_config is fast_llm_model.config assert isinstance(config, self.config_class) + # The HF constructor performs a deep copy of the config, + # but config.fast_llm_config may contain non-picklable items like process groups. + # Temporarily remove it before the call and restore it afterward. 
+ fast_llm_config = config.fast_llm_config + config.fast_llm_config = None super().__init__(config, **kwargs) + config.fast_llm_config = fast_llm_config + + self._inference_runner = self.runner_class(fast_llm_model, micro_batch_size, runner) - self._inference_runner = self.runner_class(fast_llm_model) - if not fast_llm_model.is_setup: - fast_llm_model.setup(mode=StageMode.inference) + # A model can be created from pretrained which setup it in the current HF wrapper api + # or set from training loop and also is setup, so, do not accept not setup model + assert fast_llm_model.is_setup + # if not fast_llm_model.is_setup: + # fast_llm_model.setup(distributed=distributed, mode=StageMode.inference) self._inference_runner.setup() + # Transformers needs to be able to inspect the base model. self.fast_llm_base_model = fast_llm_model.base_model - # TODO: Support distributed models? - assert fast_llm_model.config.distributed.world_size == 1 + # # TODO: Support distributed models? + # assert fast_llm_model.config.distributed.world_size == 1 with transformers.modeling_utils.no_init_weights(): self.post_init() + def forward( + self, + input_ids: torch.Tensor | None = None, + attention_mask: torch.Tensor | None = None, + position_ids: torch.Tensor | None = None, + past_key_values=None, + inputs_embeds: torch.FloatTensor | None = None, + labels: torch.LongTensor | None = None, + use_cache: bool | None = None, + output_attentions: bool | None = None, + output_hidden_states: bool | None = None, + return_dict: bool | None = None, + ) -> tuple | transformers.modeling_outputs.CausalLMOutputWithPast: + # Meant to be overridden in derived classes + raise NotImplementedError() + + @classmethod + def from_model( + cls, + fast_llm_model: FastLLMModel, + micro_batch_size: int | None = None, + runner: ScheduleRunner | None = None, + **kwargs, + ): + config = cls.config_class(fast_llm_model.config) + return cls( + config, + fast_llm_model, + micro_batch_size=micro_batch_size, + runner=runner, + **kwargs, + ) + @classmethod def from_pretrained( cls, pretrained_model_name_or_path: str | os.PathLike | CheckpointLoadConfig, - *, - mode: StageMode = StageMode.inference, + *updates: dict[str | tuple[str, ...], typing.Any], + optimizer_state_names: tuple[str, ...] | None = None, + # setup: bool = True, + mode: StageMode = StageMode.training, + use_cpu: bool = False, + stage_filter: set | None = None, **kwargs, ) -> typing.Self: # Pretrained config. 
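# A self-contained sketch of the detach-and-restore pattern used in __init__ above: the base
# class deep-copies its config, so the non-copyable part is removed for the duration of that
# call and put back afterwards. The classes below are stand-ins, not the Fast-LLM or HF types.
import copy


class _CopyingBase:
    def __init__(self, config):
        # stands in for transformers.PreTrainedModel.__init__, which deep-copies its config
        self.config = copy.deepcopy(config)


class _Config:
    def __init__(self, fast_llm_config):
        # may hold process groups or other non-picklable state
        self.fast_llm_config = fast_llm_config


def _build(config: _Config) -> _CopyingBase:
    fast_llm_config = config.fast_llm_config
    config.fast_llm_config = None  # detach before the deep copy
    instance = _CopyingBase(config)
    config.fast_llm_config = fast_llm_config  # restore on the original config object
    return instance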
@@ -54,18 +115,24 @@ def from_pretrained( format=FastLLMCheckpointFormat, ) - updates = {} - torch_dtype = kwargs.pop("torch_dtype", None) - if torch_dtype is not None: - updates[("distributed", "training_dtype")] = torch_dtype - # Create the model + # always set up model and crate distributed instance internally for now fast_llm_model = cls.runner_class.model_class.from_pretrained( - pretrained_model_name_or_path, updates, mode=mode + pretrained_model_name_or_path, + *updates, + optimizer_state_names=optimizer_state_names, + # setup=setup, + mode=mode, + use_cpu=use_cpu, + stage_filter=stage_filter, ) - config = cls.config_class(fast_llm_model.config) + config = cls.config_class(fast_llm_model.config) return cls(config, fast_llm_model, **kwargs) def _init_weights(self, module) -> None: raise NotImplementedError(module) + + +class HuggingfaceBaseModelForCausalLM(HuggingfaceBaseModel, transformers.generation.utils.GenerationMixin): + pass diff --git a/fast_llm/engine/inference/runner.py b/fast_llm/engine/inference/runner.py index 30f836b7..52eff82b 100644 --- a/fast_llm/engine/inference/runner.py +++ b/fast_llm/engine/inference/runner.py @@ -7,27 +7,43 @@ from fast_llm.engine.schedule.config import BatchConfig, ScheduleConfig from fast_llm.engine.schedule.runner import ScheduleRunner from fast_llm.engine.schedule.schedule import Schedule +from fast_llm.engine.training.config import TrainerConfig class InferenceRunner(abc.ABC): model_class: typing.ClassVar[type[FastLLMModel]] = FastLLMModel batch_config_class: typing.ClassVar[type[BatchConfig]] = BatchConfig - def __init__(self, fast_llm_model: FastLLMModel): + def __init__( + self, + fast_llm_model: FastLLMModel, + micro_batch_size: int | None = None, + runner: ScheduleRunner | None = None, + ): assert isinstance(fast_llm_model, self.model_class) self._fast_llm_model = fast_llm_model - # We only need a basic schedule and don't care about dimensions. - self._schedule_config = ScheduleConfig() - # TODO: Sort things out. + with NoAutoValidate(): - self._batch_config = self.batch_config_class() + self._batch_config = self.batch_config_class(micro_batch_size=micro_batch_size) self._batch_config.setup(self._fast_llm_model.config.distributed) self._batch_config.validate() - self._runner = ScheduleRunner( - config=self._schedule_config, - multi_stage=self._fast_llm_model, - distributed_config=self._fast_llm_model.config.distributed, - ) + + if runner is None: + # We only need a basic schedule and don't care about dimensions. + self._schedule_config = ScheduleConfig() + # TODO: Sort things out. + + self._runner = ScheduleRunner( + config=self._schedule_config, + multi_stage=self._fast_llm_model, + distributed_config=self._fast_llm_model.config.distributed, + ) + else: + self._schedule_config = runner.config + self._runner = runner + # External runner from training loop must be already setup + assert runner._is_setup + # TODO: Random state? 
(Distributed.set_step) self._schedule = Schedule( multi_stage=self._fast_llm_model, @@ -42,7 +58,8 @@ def fast_llm_model(self) -> FastLLMModel: return self._fast_llm_model def setup(self): - self._runner.setup(self._fast_llm_model.distributed) + if not self._runner._is_setup: + self._runner.setup(self._fast_llm_model.distributed) def forward( self, input_, kwargs: dict, *, iteration: int = 1, return_metrics: bool = False diff --git a/fast_llm/engine/multi_stage/config.py b/fast_llm/engine/multi_stage/config.py index e2d04f80..ae3abc70 100644 --- a/fast_llm/engine/multi_stage/config.py +++ b/fast_llm/engine/multi_stage/config.py @@ -30,7 +30,7 @@ from fast_llm.utils import Assert if typing.TYPE_CHECKING: - from fast_llm.engine.inference.huggingface import HuggingfacePreTrainedModel + from fast_llm.engine.inference.huggingface import HuggingfaceBaseModelForCausalLM from fast_llm.engine.multi_stage.fast_llm_model import FastLLMModel logger = logging.getLogger(__name__) @@ -247,7 +247,7 @@ def get_model_class(cls) -> type["FastLLMModel"]: raise NotImplementedError @classmethod - def get_huggingface_model_class(cls) -> type["HuggingfacePreTrainedModel"]: + def get_huggingface_model_for_causal_lm_class(cls) -> type["HuggingfaceBaseModelForCausalLM"]: raise NotImplementedError @classmethod diff --git a/fast_llm/engine/multi_stage/multi_stage.py b/fast_llm/engine/multi_stage/multi_stage.py index 21d0fe55..a4e0ce73 100644 --- a/fast_llm/engine/multi_stage/multi_stage.py +++ b/fast_llm/engine/multi_stage/multi_stage.py @@ -1,3 +1,4 @@ +import abc import dataclasses import logging import typing @@ -12,7 +13,7 @@ from fast_llm.engine.config_utils.data_type import DataType from fast_llm.engine.config_utils.run import log_main_rank, log_model_parallel_main_rank from fast_llm.engine.config_utils.tensor_space import TensorDim -from fast_llm.engine.distributed.config import DistributedDim, DistributedDimNames +from fast_llm.engine.distributed.config import DistributedDim, DistributedDimNames, PhaseType from fast_llm.engine.distributed.distributed import Distributed from fast_llm.engine.multi_stage.config import FastLLMModelConfig, StageMode from fast_llm.engine.multi_stage.fsdp import FSDP @@ -257,6 +258,11 @@ def setup(self, distributed: Distributed | None = None, mode: StageMode = StageM self.train(self._mode.support_backward) + @abc.abstractmethod + def get_tflops(self, phase: PhaseType, elapsed_time_per_iteration, batch_size, sequence_length) -> tuple[int, int]: + # TODO: Do in model, automate/generalize, get other stats + pass + def _allocate_buffers( self, buffer_meta: TensorMeta, sizes: list[int], name: str ) -> tuple[tuple[torch.Tensor, ...], int]: diff --git a/fast_llm/engine/multi_stage/stage.py b/fast_llm/engine/multi_stage/stage.py index 675e878b..a3ac98e8 100644 --- a/fast_llm/engine/multi_stage/stage.py +++ b/fast_llm/engine/multi_stage/stage.py @@ -13,6 +13,9 @@ from fast_llm.tensor import ParameterMeta, TensorMeta, accumulate_gradient from fast_llm.utils import Assert +if typing.TYPE_CHECKING: + from fast_llm.core.distributed import ProcessGroup + logger = logging.getLogger(__name__) @@ -111,6 +114,21 @@ def forward( metrics, ) self._log_layer_forward(output, kwargs, i) + + # TODO: very slow and memory consuming, only use for debugging for now + # TODO: decide if and how we want to return + # HF transformer style details from forward properly + if "output_hidden_states" in kwargs and kwargs["output_hidden_states"]: + # Last layer does not provide output + if output is not None: + meta = 
self._meta_outputs[i] + output_global, _ = meta.local_to_global(output.detach(), distributed=self._distributed) + else: + output_global = None + kwargs["hidden_states"][self._layer_range[i]] = { + "layer_type": type(layer).__name__, + "tensor": output_global, + } return None if output is None else output.detach(), (input_, output) def backward( diff --git a/fast_llm/engine/schedule/runner.py b/fast_llm/engine/schedule/runner.py index 8eca4559..b61607e9 100644 --- a/fast_llm/engine/schedule/runner.py +++ b/fast_llm/engine/schedule/runner.py @@ -396,8 +396,9 @@ def _recv(self, context: BatchContext, step: Step) -> None: self._record_event(context, EventType.compute_wait_pipe, step) def _forward(self, context: BatchContext, step: Step) -> None: + input = self._get_forward_input(context, step) output, grad_context = self._stages[step.stage].forward( - self._get_forward_input(context, step), + input, context.batch[step.data_index], losses=context.losses, metrics=context.metrics, diff --git a/fast_llm/engine/training/config.py b/fast_llm/engine/training/config.py index 1e990e9c..73c4be54 100644 --- a/fast_llm/engine/training/config.py +++ b/fast_llm/engine/training/config.py @@ -26,12 +26,13 @@ from fast_llm.engine.multi_stage.config import PretrainedFastLLMModelConfig from fast_llm.engine.optimizer.config import OptimizerConfig from fast_llm.engine.schedule.config import BatchConfig, ScheduleConfig +from fast_llm.engine.evaluation.config import EvaluatorConfig, EvaluatorConfigBase from fast_llm.profile import ProfilingConfig -from fast_llm.utils import Assert +from fast_llm.utils import Assert, Registry if typing.TYPE_CHECKING: from fast_llm.engine.inference.runner import InferenceRunner - from fast_llm.engine.training.trainer import Trainer + from fast_llm.engine.training.trainer import Trainer, TrainingEvaluator @config_class() @@ -152,22 +153,37 @@ class WandbConfig(Config): @config_class() -class EvaluationConfig(IntervalConfig): - interval = FieldUpdate( - desc="The number of training iterations between each evaluation phase." - " Setting to None will disable evaluation." - ) - offset = FieldUpdate(desc="Offset for the first evaluation phase.") - iterations: int | None = Field( - default=None, - desc="Number of iterations for each evaluation phase. 
Setting to None will disable.", - hint=FieldHint.feature, - valid=skip_valid_if_none(check_field(Assert.gt, 0)), - ) - - def get_iteration_count(self, training_iterations: int, extra_evaluations: int = 0): - # Number of completed validation iterations - return (self.get_count(training_iterations) + extra_evaluations) * self.iterations if self.enabled() else 0 +class TrainingEvaluatorConfig(EvaluatorConfigBase): + run_interval: IntervalConfig = Field(default_factory=IntervalConfig) + evaluator: EvaluatorConfig = Field(default_factory=EvaluatorConfig) + + def get_run_count(self, training_iterations: int, extra_evaluations: int = 0): + # Number of completed evaluation runs + return (self.run_interval.get_count(training_iterations) + extra_evaluations) if self.run_interval.enabled() else 0 + + def get_evaluator( + self, + name: str, + batch_config: BatchConfig, + data_load_num_proc: int, + train_iters: int | None = None, + ) -> "TrainingEvaluator": + from fast_llm.engine.training.trainer import TrainingEvaluator + + return TrainingEvaluator(name, self, batch_config, data_load_num_proc, train_iters) + + @classmethod + def _from_dict( + cls, + default: dict[str, typing.Any], + strict: bool = True, + flat: bool = False, + ) -> typing.Self: + # TODO v0.x: Remove backward compatibility. + cls._handle_renamed_field(default, "interval", ("run_interval", "interval")) + cls._handle_renamed_field(default, "offset", ("run_interval", "offset")) + cls._handle_renamed_field(default, "iterations", ("evaluator", "iterations")) + return super()._from_dict(default, strict, flat) @config_class() @@ -279,7 +295,7 @@ class ShutdownConfig(IntervalConfig): @config_class() class TrainingConfig(Config): - evaluations: dict[str, EvaluationConfig] = Field( + evaluations: dict[str, TrainingEvaluatorConfig] = Field( default_factory=dict, desc="A dictionary of evaluation dataset names and their configurations for the validation phase.", hint=FieldHint.core, diff --git a/fast_llm/engine/training/trainer.py b/fast_llm/engine/training/trainer.py index abd8f9dc..717ccfe2 100644 --- a/fast_llm/engine/training/trainer.py +++ b/fast_llm/engine/training/trainer.py @@ -15,12 +15,26 @@ from fast_llm.engine.config_utils.run import Run, is_main_rank, log_main_rank, log_pipeline_parallel_main_rank from fast_llm.engine.distributed.config import PhaseType from fast_llm.engine.distributed.distributed import Distributed +from fast_llm.engine.evaluation.evaluator import ( + EvaluationMetrics, + Evaluator, + EvaluatorRunner, + EvaluatorSamplingParameters, + TrainingProgress, +) from fast_llm.engine.multi_stage.config import StageMode +from fast_llm.engine.multi_stage.fast_llm_model import FastLLMModel from fast_llm.engine.optimizer.config import ParamGroup from fast_llm.engine.optimizer.optimizer import Optimizer +from fast_llm.engine.schedule.config import BatchConfig from fast_llm.engine.schedule.runner import ScheduleRunner from fast_llm.engine.schedule.schedule import Schedule -from fast_llm.engine.training.config import TrainerConfig, TrainingCheckpointBaseConfig, TrainingCheckpointConfig +from fast_llm.engine.training.config import ( + TrainerConfig, + TrainingCheckpointBaseConfig, + TrainingCheckpointConfig, + TrainingEvaluatorConfig, +) from fast_llm.engine.training.wandb import Wandb from fast_llm.logging import format_metrics, get_memory_usage_mib, log_memory_usage from fast_llm.utils import Assert @@ -28,6 +42,77 @@ logger = logging.getLogger(__name__) +class TrainingEvaluator[ConfigType: 
TrainingEvaluatorConfig](Evaluator[ConfigType]): + config_class: typing.ClassVar[type[TrainingEvaluatorConfig]] = TrainingEvaluatorConfig + + evaluator: Evaluator + + def __init__( + self, + name: str, + eval_config: TrainingEvaluatorConfig, + batch_config: BatchConfig, + data_load_num_proc: int, + train_iters: int | None = None, + ): + super().__init__(name, eval_config, batch_config, data_load_num_proc, train_iters) + + self._train_iters = 0 if self._train_iters is None else self._train_iters + + self.evaluator = eval_config.evaluator.get_evaluator(name, batch_config, data_load_num_proc, train_iters) + + def setup( + self, + distributed: Distributed, + run: Run, + multi_stage: FastLLMModel, + runner: ScheduleRunner, + data: Data, + phase: PhaseType, + ) -> None: + self.evaluator.setup( + distributed, + run, + multi_stage, + runner, + data, + phase, + ) + + def run( + self, + training_progress: TrainingProgress | None = None, + run_index: int | None = None, + ) -> EvaluationMetrics: + # Run index must be None here; it is computed below and passed to the wrapped evaluator. + assert run_index is None + + # Training progress can be None because this may be an evaluation-only + # run with no training steps. + if training_progress is None: + done = True + completed_steps = 0 + else: + done = training_progress.done + completed_steps = training_progress.completed_steps + + if done or self.config.run_interval.enabled(completed_steps): + return self.evaluator.run(training_progress, run_index=self._config.get_run_count(completed_steps - 1)) + else: + return EvaluationMetrics() + + def get_sampling_parameters(self) -> EvaluatorSamplingParameters | None: + name_samples = self.evaluator.get_sampling_parameters() + if name_samples is None: + return None + run_count = self._config.get_run_count( + self._train_iters, + # There may be an extra evaluation after the last training step. + not self._config.run_interval.enabled(self._train_iters), + ) + return EvaluatorSamplingParameters(name_samples.dataset_name, name_samples.num_samples * run_count) + + class Trainer[ConfigType: TrainerConfig](Configurable[ConfigType], abc.ABC): config_class: typing.ClassVar[type[TrainerConfig]] = TrainerConfig # TODO: Generalize data, schedule, logging, etc.
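Note on the sampling arithmetic above: the number of evaluation runs is derived from the run interval, with a possible extra run after the last training step, and the evaluator's per-run sample requirement is multiplied by that run count so the dataset can be provisioned up front. A minimal self-contained sketch of this bookkeeping (the interval logic below is a toy stand-in, not the real IntervalConfig):

def run_count(train_iters: int, interval: int, extra_after_last_step: bool) -> int:
    # One evaluation run every `interval` training steps, plus optionally one after the last step.
    runs_during_training = train_iters // interval if interval else 0
    return runs_during_training + (1 if extra_after_last_step else 0)

def samples_to_provision(train_iters: int, interval: int, samples_per_run: int) -> int:
    # An extra run is needed when the last training step does not fall on the interval.
    extra = bool(interval) and train_iters % interval != 0
    return run_count(train_iters, interval, extra) * samples_per_run

# Example: 100 training steps, evaluation every 30 steps, 80 samples per evaluation run
# -> 3 runs during training + 1 final run = 4 runs, 320 samples to provision.
print(samples_to_provision(100, 30, 80))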
@@ -39,13 +124,20 @@ class Trainer[ConfigType: TrainerConfig](Configurable[ConfigType], abc.ABC): _completed_steps: int + _is_evaluation_only: bool + + _evaluator_runner: EvaluatorRunner + def __init__(self, config: TrainerConfig): super().__init__(config) + + self._is_evaluation_only = config.training.train_iters == 0 + self._data = self._get_data() log_main_rank("Creating model...") self._multi_stage = self._config.model.get_model_class()( self._config.model, - optimizer_state_names=self._config.optimizer.state_names(), + optimizer_state_names=self._config.optimizer.state_names() if not self._is_evaluation_only else (), ) self._reference_models = {} for name, reference_config in self._config.reference_models.items(): @@ -55,51 +147,54 @@ def __init__(self, config: TrainerConfig): ) self._multi_stage.base_model.add_reference_model(name, self._reference_models[name]) - phase: PhaseType self._runner = ScheduleRunner( config=self._config.schedule, multi_stage=self._multi_stage, distributed_config=self._config.model.distributed, ) - steps_per_split = { - PhaseType.training: {PhaseType.training.value.lower(): self._config.training.train_iters}, - PhaseType.validation: { - dataset_name: self._config.training.evaluations[dataset_name].get_iteration_count( - self._config.training.train_iters, - # There may be an extra evaluation after the last training step. - not self._config.training.evaluations[dataset_name].enabled(self._config.training.train_iters), - ) - for dataset_name in self._config.training.evaluations.keys() - }, - PhaseType.test: {PhaseType.test.value.lower(): self._config.training.test_iters}, - } - self._samples_per_split = { - phase: { - dataset_name: self._config.batch.batch_size * steps - for dataset_name, steps in datasets.items() - if steps > 0 - } - for phase, datasets in steps_per_split.items() - } - # Prune empty phases. - self._samples_per_split = {k: v for k, v in self._samples_per_split.items() if len(v) > 0} - self._loss_defs = self._multi_stage.base_model.loss_defs - # Setup the schedules - self._schedule = { - phase: { - dataset_name: Schedule( - multi_stage=self._multi_stage, - batch_config=self._config.batch, - schedule_config=self._config.schedule, - distributed_config=self._config.model.distributed, - phase=phase, - ) - for dataset_name in datasets + if not self._is_evaluation_only: + steps_per_split = { + PhaseType.training: {PhaseType.training.value.lower(): self._config.training.train_iters}, + PhaseType.test: {PhaseType.test.value.lower(): self._config.training.test_iters}, } - for phase, datasets in self._samples_per_split.items() - } + + self._samples_per_split = { + phase: { + dataset_name: self._config.batch.batch_size * steps + for dataset_name, steps in datasets.items() + if steps > 0 + } + for phase, datasets in steps_per_split.items() + } + # Prune empty phases. 
+ self._samples_per_split = {k: v for k, v in self._samples_per_split.items() if len(v) > 0} + + # Setup the schedules + self._schedule = { + phase: { + dataset_name: Schedule( + multi_stage=self._multi_stage, + batch_config=self._config.batch, + schedule_config=self._config.schedule, + distributed_config=self._config.model.distributed, + phase=phase, + ) + for dataset_name in datasets + } + for phase, datasets in self._samples_per_split.items() + } + else: + self._samples_per_split = {} + + self._evaluator_runner = EvaluatorRunner( + evaluator_configs=self._config.training.evaluations, + batch_config=self._config.batch, + data_load_num_proc=self._config.training.num_workers, + train_iters=self._config.training.train_iters, + wandb_config=self._config.training.wandb, + ) def setup(self, distributed: Distributed, run: Run) -> None: assert distributed.config is self._config.model.distributed @@ -117,18 +212,23 @@ def setup(self, distributed: Distributed, run: Run) -> None: reference_model.fast_llm_model.setup(distributed, StageMode.inference) reference_model.setup() + # TODO: Check with Joel whether this is enough to avoid allocating grad buffers. # Setup the optimizer. - param_groups, grads_for_norm = self._multi_stage.get_param_groups(ParamGroup) - self._optimizer = self._config.optimizer.optimizer_cls( - self._config.optimizer, - param_groups=param_groups, - grads_for_norm=grads_for_norm, - distributed=self._distributed, - ) + if self._is_evaluation_only: + self._optimizer = None + else: + param_groups, grads_for_norm = self._multi_stage.get_param_groups(ParamGroup) + self._optimizer = self._config.optimizer.optimizer_cls( + self._config.optimizer, + param_groups=param_groups, + grads_for_norm=grads_for_norm, + distributed=self._distributed, + ) # Setup the schedules. with torch.no_grad(): self._runner.setup(distributed, self._optimizer) + # Setup the datasets.
log_main_rank("Preparing datasets...") self._data.setup( @@ -137,10 +237,28 @@ def setup(self, distributed: Distributed, run: Run) -> None: dataset_name: self._get_sampling_parameters({"num_samples": samples}) for datasets in self._samples_per_split.values() for dataset_name, samples in datasets.items() + } + | { + eval_sampling_params.dataset_name: self._get_sampling_parameters( + {"num_samples": eval_sampling_params.num_samples} + ) + for eval_sampling_params in self._evaluator_runner.get_sampling_parameters() }, None if run.experiment_directory is None else run.experiment_directory / "dataset_cache", timeout=self._config.training.timeout, ) + + # Must be called with all arguments set up + self._evaluator_runner.setup( + distributed=self._distributed, + run=self._run, + multi_stage=self._multi_stage, + runner=self._runner, + data=self._data, + wandb=self._wandb, + phase=PhaseType.inference if self._is_evaluation_only else PhaseType.validation, + ) + self._is_setup = True @abc.abstractmethod @@ -162,21 +280,21 @@ def _consumed_tokens(self) -> int: assert self._is_setup return self._consumed_samples * self._config.batch.sequence_length - def _get_completed_evaluation_steps(self, dataset_name) -> int: - # Number of evaluations steps performed before the current step - return self._config.training.evaluations[dataset_name].get_iteration_count(self._completed_steps - 1) - def run(self) -> None: assert self._is_setup with self._wandb: self._run_training() def _run_training(self) -> None: - self._prepare_training_state() + self._prepare_model_state() + log_main_rank("done with setup ...") log_pipeline_parallel_main_rank(lambda: log_memory_usage(f"After initial setup", str)) self._run.save_logged_tensors("init") + if self._is_evaluation_only: + assert len(self._samples_per_split) == 0 + if PhaseType.training in self._samples_per_split: done = self._completed_steps >= self._config.training.train_iters if done: @@ -185,13 +303,15 @@ def _run_training(self) -> None: else: done, metrics = self._train() else: - done, metrics = True, {} + metrics = {} + done = True + self._evaluator_runner.run(metrics=metrics) if done and PhaseType.test in self._samples_per_split: log_main_rank(lambda: f"Running test phase ...") test_iterator = self._get_data_iterator(PhaseType.test.value.lower()) metrics_key = PhaseType.test.value - metrics[metrics_key] = self._evaluate( + metrics[metrics_key] = self._evaluate_loss( data_iterator=test_iterator, phase=PhaseType.test, num_iters=self._config.training.test_iters, @@ -219,7 +339,6 @@ def _train(self) -> tuple[bool, dict[PhaseType, dict[str, typing.Any]]]: self._completed_steps, self._config.training.prefetch_factor, ) - evaluation_iterators = {name: None for name in self._config.training.evaluations.keys()} log_main_rank("Training ...") @@ -271,7 +390,12 @@ def _train(self) -> tuple[bool, dict[PhaseType, dict[str, typing.Any]]]: remaining_time = average_time_per_iteration * ( self._config.training.train_iters - self._completed_steps ) - model_tflops, hardware_tflops = self.get_tflops(PhaseType.training, time_per_iteration) + model_tflops, hardware_tflops = self._multi_stage.get_tflops( + PhaseType.training, + time_per_iteration, + self._config.batch.batch_size, + self._config.batch.sequence_length, + ) metrics_key = PhaseType.training.value metrics[metrics_key] = { "train_iters": self._config.training.train_iters, @@ -319,49 +443,18 @@ def _train(self) -> tuple[bool, dict[PhaseType, dict[str, typing.Any]]]: done = self._completed_steps >= self._config.training.train_iters 
# TODO: Signal-based stop. stop = done or self._config.training.shutdown.enabled(self._completed_steps) + # Evaluation # TODO: Adjust valid iterator length. - if PhaseType.validation in self._samples_per_split and ( - done - or any( - evaluation_conf.enabled(self._completed_steps) - for evaluation_conf in self._config.training.evaluations.values() - ) - ): - formatted_metrics = [] - for dataset_name, evaluation_conf in self._config.training.evaluations.items(): - if not evaluation_conf.enabled(self._completed_steps): - continue - if evaluation_iterators[dataset_name] is None: - evaluation_iterators[dataset_name] = self._get_data_iterator( - dataset_name, self._get_completed_evaluation_steps(dataset_name) - ) - # TODO: formatting metric category as Validation.evaluation_dataset_name - # maybe format each metric with evaluation_dataset_name prefix instead? - # TODO: setting performance metrics per evaluation dataset - # maybe to set aggregate performance metrics for all evaluations datasets? - metric_key = f"{PhaseType.validation.value}.{dataset_name}" - metrics[metric_key] = self._evaluate( - data_iterator=evaluation_iterators[dataset_name], - phase=PhaseType.validation, - num_iters=evaluation_conf.iterations, - begin_iter=self._get_completed_evaluation_steps(dataset_name), - dataset_name=dataset_name, - ) - formatted_metrics.append( - format_metrics( - metrics[metric_key], - self._loss_defs, - PhaseType.validation, - dataset_name=dataset_name, - ) - ) - - if len(formatted_metrics) > 0: - formatted_metrics = "\n".join(formatted_metrics) - log_main_rank(formatted_metrics) - if self._config.training.wandb.alert.enabled(self._completed_steps): - self._wandb.alert("Validation results", formatted_metrics, "INFO") + self._evaluator_runner.run( + metrics=metrics, + training_progress=TrainingProgress( + done=done, + completed_steps=self._completed_steps, + consumed_samples=self._consumed_samples, + consumed_tokens=self._consumed_tokens, + ), + ) if is_main_rank() and metrics: self._wandb.log_metrics(self._completed_steps, metrics) @@ -375,55 +468,6 @@ def _train(self) -> tuple[bool, dict[PhaseType, dict[str, typing.Any]]]: profiler.step() return done, metrics - def _evaluate( - self, - *, - data_iterator: typing.Iterator, - phase: PhaseType, - num_iters: int, - begin_iter: int = 0, - dataset_name: str | None = None, - ) -> dict[str, float | int]: - full_phase_name = phase.value if dataset_name is None else f"{phase.value}_{dataset_name}" - safe_barrier(self._distributed.world_group, f"{full_phase_name} begin") - begin_time = time.perf_counter() - total_losses = {loss_def.name: 0.0 for loss_def in self._loss_defs} - for iter_ in range(num_iters): - iter_losses, _, _ = self._runner.run_step( - data_iterator, self._schedule[phase][dataset_name], iteration=begin_iter + iter_ - ) - for name, value in iter_losses.items(): - total_losses[name] += value - self._run.save_logged_tensors(f"{full_phase_name}_{self._completed_steps}_{iter_}") - - safe_barrier( - self._distributed.world_group, - f"{full_phase_name} end", - ) - end_time = time.perf_counter() - time_per_iteration = (end_time - begin_time) / num_iters - model_tflops, hardware_tflops = self.get_tflops(phase, time_per_iteration) - # TODO add other relevant eval metrics - metrics = { - "train_iters": self._config.training.train_iters, - "batch_size": self._config.batch.batch_size, - "iteration": self._completed_steps, - **{name: (value / num_iters) for name, value in total_losses.items()}, - "consumed_samples": self._consumed_samples, - 
"consumed_tokens": self._consumed_tokens, - "step_time_ms": time_per_iteration * 1000, - "model_tflops": model_tflops, - "hardware_tflops": hardware_tflops, - "tokens_per_sec_per_gpu": ( - (self._config.batch.sequence_length * self._config.batch.batch_size) - / self._config.model.distributed.world_size - / time_per_iteration - ), - **get_memory_usage_mib(), - } - - return metrics - def _get_data_iterator( self, dataset_name, completed_steps: int = 0, prefetch_factor: int | None = None ) -> typing.Iterator[typing.Any]: @@ -436,7 +480,7 @@ def _get_data_iterator( timeout=self._config.training.timeout, ) - def _prepare_training_state(self) -> None: + def _prepare_model_state(self) -> None: # Setup the training state. if (last_iteration := self._get_last_checkpoint()) is None: if (path := self._config.pretrained.path) is not None and self._config.pretrained.model_weights: @@ -447,9 +491,15 @@ def _prepare_training_state(self) -> None: ) self._multi_stage.load_checkpoint(self._config.pretrained) else: + if self._is_evaluation_only: + raise ValueError( + "Evaluation mode, model need to be trained first or pretrained checkpoint is provided for loading" + ) log_main_rank(f"Initializing training state from scratch...") self._multi_stage.initialize_weights() - self._optimizer.reset_state() + + if not self._is_evaluation_only: + self._optimizer.reset_state() self._completed_steps = 0 else: log_main_rank(lambda: f"Loading checkpoint from iteration {last_iteration}...") @@ -526,7 +576,8 @@ def _load_checkpoint(self, config: TrainingCheckpointConfig, iteration: int) -> config.get_load_config(checkpoint_directory, timeout=self._config.training.timeout) ) assert metadata is not None - self._optimizer.load(metadata["optimizer"]) + if not self._is_evaluation_only: + self._optimizer.load(metadata["optimizer"]) if "schedules" in metadata: # Backward compatibility. self._completed_steps = metadata["schedules"][PhaseType.training.value]["completed_steps"] @@ -553,8 +604,3 @@ def _get_last_checkpoint(self) -> int | None: iteration = -1 iteration = self._run.broadcast_int(iteration) return iteration if iteration >= 0 else None - - @abc.abstractmethod - def get_tflops(self, phase: PhaseType, elapsed_time_per_iteration) -> tuple[int, int]: - # TODO: Do in model, automate/generalize, get other stats - pass diff --git a/fast_llm/layers/language_model/head.py b/fast_llm/layers/language_model/head.py index 813dcc07..79c9f61b 100644 --- a/fast_llm/layers/language_model/head.py +++ b/fast_llm/layers/language_model/head.py @@ -5,7 +5,7 @@ from torch.distributed import all_reduce from fast_llm.config import Configurable -from fast_llm.core.ops import split_op +from fast_llm.core.ops import gather_op, split_op from fast_llm.engine.base_model.base_model import Layer from fast_llm.engine.config_utils.tensor_space import DefaultDimNames, TensorDim, TensorSpace from fast_llm.engine.distributed.config import DistributedDimNames @@ -175,6 +175,14 @@ def _forward_backward( with torch.enable_grad(): ln_output = self.final_norm(input_) + if "output_hidden_states" in kwargs and kwargs["output_hidden_states"]: + # The last hidden layer output is returned normalized in the HF Transformers-style output, at least for LLama style models. + # So, if needed, we gather the data after normalization and set it as the output of the previous layer. 
+ group = self._tensor_space.distributed.tensor_group if self._parallel_embeddings else None + sequence_parallel = self._sequence_parallel and self._parallel_embeddings + hidden_state = gather_op(ln_output.detach(), group, dim=0) if sequence_parallel else ln_output.detach() + kwargs["hidden_states"][len(kwargs["hidden_states"]) - 1]["tensor"] = hidden_state + grad_output = kwargs[TransformerKwargs.grad_output] / ( self._group_size if self._sequence_parallel_logits else 1 ) diff --git a/fast_llm/layers/transformer/preprocessing.py b/fast_llm/layers/transformer/preprocessing.py index 2415a2f9..1e07c1c1 100644 --- a/fast_llm/layers/transformer/preprocessing.py +++ b/fast_llm/layers/transformer/preprocessing.py @@ -239,7 +239,7 @@ def preprocess(self, batch, kwargs: dict[str, typing.Any]) -> None: ] if (sequence_lengths := kwargs.get(TransformerKwargs.sequence_lengths, None)) is not None: seq_ids = torch.stack( - [torch.cat([torch.arange(x) for x in sample_lens]) for sample_lens in sequence_lengths] + [torch.cat([torch.full((x,), i) for i, x in enumerate(sample_lens)]) for sample_lens in sequence_lengths] ) document_mask = (seq_ids[:, None, :] == seq_ids[:, :, None]).to(self._tensor_space.distributed.device) kwargs[TransformerKwargs.attention_mask] = ( diff --git a/fast_llm/logging.py b/fast_llm/logging.py index ffeb56f6..e1414f00 100644 --- a/fast_llm/logging.py +++ b/fast_llm/logging.py @@ -92,11 +92,13 @@ _METRIC_FORMATS_KEYS = { PhaseType.training: _TRAINING_METRIC_FORMAT_KEYS, PhaseType.validation: _VALIDATION_METRIC_FORMAT_KEYS, + PhaseType.inference: _VALIDATION_METRIC_FORMAT_KEYS, PhaseType.test: _VALIDATION_METRIC_FORMAT_KEYS, } _METRIC_FORMATS = { PhaseType.training: _TRAINING_METRIC_FORMATS, PhaseType.validation: _VALIDATION_METRIC_FORMATS, + PhaseType.inference: _VALIDATION_METRIC_FORMATS, PhaseType.test: _VALIDATION_METRIC_FORMATS, } diff --git a/fast_llm/models/custom/config.py b/fast_llm/models/custom/config.py index 8be45e1c..f9805e64 100644 --- a/fast_llm/models/custom/config.py +++ b/fast_llm/models/custom/config.py @@ -35,7 +35,7 @@ def get_model_class(cls) -> type["CustomModel"]: return CustomModel @classmethod - def get_huggingface_model_class(cls) -> type["HuggingfaceCustomModelForCausalLM"]: + def get_huggingface_model_for_causal_lm_class(cls) -> type["HuggingfaceCustomModelForCausalLM"]: from fast_llm.models.custom.huggingface import HuggingfaceCustomModelForCausalLM return HuggingfaceCustomModelForCausalLM diff --git a/fast_llm/models/gpt/config.py b/fast_llm/models/gpt/config.py index 418f948e..3852d83f 100644 --- a/fast_llm/models/gpt/config.py +++ b/fast_llm/models/gpt/config.py @@ -147,7 +147,7 @@ def get_model_class(cls) -> type["GPTModel"]: return GPTModel @classmethod - def get_huggingface_model_class(cls) -> type["HuggingfaceGPTModelForCausalLM"]: + def get_huggingface_model_for_causal_lm_class(cls) -> type["HuggingfaceGPTModelForCausalLM"]: from fast_llm.models.gpt.huggingface import HuggingfaceGPTModelForCausalLM return HuggingfaceGPTModelForCausalLM diff --git a/fast_llm/models/gpt/huggingface.py b/fast_llm/models/gpt/huggingface.py index 0da4acbb..7e668e73 100644 --- a/fast_llm/models/gpt/huggingface.py +++ b/fast_llm/models/gpt/huggingface.py @@ -5,10 +5,11 @@ import torch import transformers.modeling_outputs + from fast_llm.data.data.gpt.data import GPTBatch from fast_llm.engine.distributed.config import PhaseType from fast_llm.engine.inference.config import HuggingfaceModelConfig -from fast_llm.engine.inference.huggingface import 
HuggingfacePreTrainedModel +from fast_llm.engine.inference.huggingface import HuggingfaceBaseModelForCausalLM from fast_llm.layers.transformer.config import TransformerKwargs from fast_llm.models.gpt.config import GPTModelConfig from fast_llm.models.gpt.model import GPTBaseModel, GPTInferenceRunner @@ -22,7 +23,7 @@ class HuggingfaceGPTModelConfig(HuggingfaceModelConfig): fast_llm_config: GPTModelConfig -class HuggingfaceGPTModelForCausalLM(HuggingfacePreTrainedModel): +class HuggingfaceGPTModelForCausalLM(HuggingfaceBaseModelForCausalLM): config_class = HuggingfaceGPTModelConfig config: HuggingfaceGPTModelConfig runner_class: typing.ClassVar[type[GPTInferenceRunner]] = GPTInferenceRunner @@ -55,21 +56,33 @@ def forward( if output_attentions: raise NotImplementedError() - if output_hidden_states: - raise NotImplementedError() - if attention_mask is not None: - raise NotImplementedError() - if position_ids is not None: - raise NotImplementedError() if inputs_embeds is not None: raise NotImplementedError() if labels is not None: raise NotImplementedError() + # NOTE: We are ignoring position_ids as we reconstruct them from attention_mask via sequence_lengths. + if attention_mask is not None: + + # First non-zero index per row, or zero if the row is all zeros (invalid row) + first_non_zero_indexes = attention_mask.argmax(dim=1) + + # Check that each sequence is left-padded and that the remaining entries are a contiguous run of ones + assert (attention_mask.sum(axis=1) == (attention_mask.shape[1] - first_non_zero_indexes)).all() + + sequence_lengths = [ + torch.tensor( + [attention_mask.shape[1]] if el == 0 else [el, attention_mask.shape[1] - el], dtype=torch.int64 + ) + for el in first_non_zero_indexes.tolist() + ] + else: + sequence_lengths = None + + # Iteration serves as a random seed, using random module because it's not seeded by Fast LLM iteration = random.randint(0, 2**32) batch = self.fast_llm_base_model.preprocess( - GPTBatch(input_ids), phase=PhaseType.inference, iteration=iteration + GPTBatch(input_ids, sequence_lengths=sequence_lengths), phase=PhaseType.inference, iteration=iteration ) ((input_, kwargs),) = batch @@ -82,23 +95,40 @@ def forward( # The transformers will save the present keys and values to this list. kwargs[TransformerKwargs.presents] = [] + if output_hidden_states: + kwargs["output_hidden_states"] = True + kwargs["hidden_states"] = {} + else: + kwargs["output_hidden_states"] = False + + self._inference_runner.forward(input_, kwargs, iteration=iteration) + # TODO: Make a proper way of returning the model output.
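# Standalone illustration, not part of the patch: how the attention-mask handling above turns a
# left-padded mask into per-row sequence_lengths, with padding represented as a leading
# "document" of length equal to the pad count.
import torch

attention_mask = torch.tensor(
    [
        [0, 0, 1, 1, 1],  # two padding positions, then a 3-token sequence
        [1, 1, 1, 1, 1],  # no padding
    ]
)
first_non_zero_indexes = attention_mask.argmax(dim=1)  # tensor([2, 0])
# Left-padding check: everything after the first 1 must also be 1.
assert (attention_mask.sum(dim=1) == (attention_mask.shape[1] - first_non_zero_indexes)).all()
sequence_lengths = [
    torch.tensor([attention_mask.shape[1]] if el == 0 else [el, attention_mask.shape[1] - el], dtype=torch.int64)
    for el in first_non_zero_indexes.tolist()
]
# sequence_lengths == [tensor([2, 3]), tensor([5])]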
logits = kwargs["logits"] + # TODO: convert hidden state form dict to list to be the same as with HFs + hidden_states = None + if output_hidden_states: + hidden_states = kwargs["hidden_states"] + if not return_dict: - outputs = (logits,) + # TODO: check hidden state go before past in the tuple + if output_hidden_states: + outputs = (logits, hidden_states) + else: + outputs = (logits,) + if use_cache: outputs += (kwargs[TransformerKwargs.presents],) return outputs return transformers.modeling_outputs.CausalLMOutputWithPast( logits=logits, + hidden_states=hidden_states, past_key_values=kwargs[TransformerKwargs.presents], ) - def prepare_inputs_for_generation( - self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs - ): - raise NotImplementedError() + # def prepare_inputs_for_generation( + # self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs + # ): + # raise NotImplementedError() diff --git a/fast_llm/models/gpt/model.py b/fast_llm/models/gpt/model.py index d177a41d..c9879dfe 100644 --- a/fast_llm/models/gpt/model.py +++ b/fast_llm/models/gpt/model.py @@ -407,6 +407,64 @@ class GPTModel[ConfigType: GPTModelConfig](FastLLMModel[ConfigType]): config_class: typing.ClassVar[type[GPTModelConfig]] = GPTModelConfig base_model_class: typing.ClassVar[type[GPTBaseModel]] = GPTBaseModel + def get_tflops(self, phase: PhaseType, elapsed_time_per_iteration, batch_size, sequence_length) -> tuple[int, int]: + # TODO: Do in model, automate/generalize, get other stats + """Get tflop/s/GPU from global-batch-size and elapsed-time""" + checkpoint_activations_factor = 3 if phase == PhaseType.training else 1 + transformer_config = self._config.base_model.transformer + + consumed_tokens_per_iteration = sequence_length * batch_size + + num_transformer_layers = transformer_config.num_layers + self._config.base_model.prediction_heads - 1 + transformer_flops_base = ( + 2 * checkpoint_activations_factor * consumed_tokens_per_iteration * num_transformer_layers + ) + dense_flops_base = transformer_flops_base * transformer_config.hidden_size + # Query, key, value, dense. 
+ flops_per_iteration = ( + 2 + * (transformer_config.num_attention_heads + transformer_config.head_groups) + * transformer_config.kv_channels + * dense_flops_base + ) + # MLP + flops_per_iteration += ( + (2 + transformer_config.gated) + * transformer_config.ffn_hidden_size + * dense_flops_base + * transformer_config.num_experts_per_token + ) + + # LM-head + flops_per_iteration += ( + 6 + * consumed_tokens_per_iteration + * transformer_config.hidden_size + * self._config.base_model.vocab_size + * self._config.base_model.prediction_heads + ) + + # Attention-matrix computation + attn_flops_base = transformer_flops_base * transformer_config.projection_size + if transformer_config.window_size is None: + # Ignore masked values (s**2/2) + attn_flops = attn_flops_base * sequence_length + model_tflops = flops_per_iteration + attn_flops + else: + # s*w - w**2/2 + attn_flops = ( + 2 + * attn_flops_base + * transformer_config.window_size + * (1 - transformer_config.window_size / 2 / sequence_length) + ) + model_tflops = flops_per_iteration + attn_flops + + # Partial recomputation (normal is 2 ops * ckpt_factor = 6, adding 1 for recomputing Q x K) + hardware_flops = flops_per_iteration + 7 / 6 * attn_flops + ratio = elapsed_time_per_iteration * self._config.distributed.world_size * 1e12 + return model_tflops / ratio, hardware_flops / ratio + class GPTInferenceRunner(InferenceRunner): model_class: typing.ClassVar[type[GPTModel]] = GPTModel diff --git a/fast_llm/models/gpt/trainer.py b/fast_llm/models/gpt/trainer.py index 3bdb05c3..4e3b24a0 100644 --- a/fast_llm/models/gpt/trainer.py +++ b/fast_llm/models/gpt/trainer.py @@ -33,59 +33,3 @@ def _get_sampling_parameters( } ) return parameters if _return_dict else GPTSamplingParameters(**parameters) - - def get_tflops(self, phase: PhaseType, elapsed_time_per_iteration) -> tuple[int, int]: - # TODO: Do in model, automate/generalize, get other stats - """Get tflop/s/GPU from global-batch-size and elapsed-time""" - checkpoint_activations_factor = 3 if phase == PhaseType.training else 1 - transformer_config = self._config.model.base_model.transformer - sequence_length = self._config.batch.sequence_length - - tokens = self._config.batch.batch_size * sequence_length - num_transformer_layers = transformer_config.num_layers + self._config.model.base_model.prediction_heads - 1 - transformer_flops_base = 2 * checkpoint_activations_factor * tokens * num_transformer_layers - dense_flops_base = transformer_flops_base * transformer_config.hidden_size - # Query, key, value, dense. 
- flops_per_iteration = ( - 2 - * (transformer_config.num_attention_heads + transformer_config.head_groups) - * transformer_config.kv_channels - * dense_flops_base - ) - # MLP - flops_per_iteration += ( - (2 + transformer_config.gated) - * transformer_config.ffn_hidden_size - * dense_flops_base - * transformer_config.num_experts_per_token - ) - - # LM-head - flops_per_iteration += ( - 6 - * tokens - * transformer_config.hidden_size - * self._config.model.base_model.vocab_size - * self._config.model.base_model.prediction_heads - ) - - # Attention-matrix computation - attn_flops_base = transformer_flops_base * transformer_config.projection_size - if transformer_config.window_size is None: - # Ignore masked values (s**2/2) - attn_flops = attn_flops_base * sequence_length - model_tflops = flops_per_iteration + attn_flops - else: - # s*w - w**2/2 - attn_flops = ( - 2 - * attn_flops_base - * transformer_config.window_size - * (1 - transformer_config.window_size / 2 / sequence_length) - ) - model_tflops = flops_per_iteration + attn_flops - - # Partial recomputation (normal is 2 ops * ckpt_factor = 6, adding 1 for recomputing Q x K) - hardware_flops = flops_per_iteration + 7 / 6 * attn_flops - ratio = elapsed_time_per_iteration * self._config.model.distributed.world_size * 1e12 - return model_tflops / ratio, hardware_flops / ratio diff --git a/fast_llm/tools/cli.py b/fast_llm/tools/cli.py index 0cc02f42..b36a294d 100644 --- a/fast_llm/tools/cli.py +++ b/fast_llm/tools/cli.py @@ -15,11 +15,13 @@ def fast_llm(args=None): # (Pre-)configure logging configure_logging() parser = argparse.ArgumentParser(add_help=False) - parser.add_argument("subcommand", choices=["train", "convert", "prepare"]) + parser.add_argument("subcommand", choices=["train", "evaluate", "convert", "prepare"]) parsed, unparsed = parser.parse_known_args(args) try: if parsed.subcommand == "train": from fast_llm.tools.train import CliTrainingConfig as Runnable + elif parsed.subcommand == "evaluate": + from fast_llm.tools.evaluate import CliEvaluationConfig as Runnable elif parsed.subcommand == "convert": from fast_llm.tools.convert import ConversionConfig as Runnable elif parsed.subcommand == "prepare": diff --git a/fast_llm/tools/evaluate.py b/fast_llm/tools/evaluate.py new file mode 100644 index 00000000..26a9aa9c --- /dev/null +++ b/fast_llm/tools/evaluate.py @@ -0,0 +1,25 @@ +import argparse + +from fast_llm.engine.config_utils.runnable import RunnableConfig +from fast_llm.models.auto import trainer_registry + + +class CliEvaluationConfig(RunnableConfig): + @classmethod + def _get_parser(cls): + parser = super()._get_parser() + parser.add_argument( + "model_type", + choices=trainer_registry.keys(), + help="The Fast-LLM model type to use. 
Must be defined in the trainer registry in `fast_llm.models.auto`.", + ) + return parser + + @classmethod + def _from_parsed_args(cls, parsed: argparse.Namespace, unparsed: list[str]): + unparsed += ['training.train_iters=0'] + return trainer_registry[parsed.model_type]._from_parsed_args(parsed, unparsed) + + +if __name__ == "__main__": + CliEvaluationConfig.parse_and_run() diff --git a/test.py b/test.py new file mode 100644 index 00000000..e02fb32d --- /dev/null +++ b/test.py @@ -0,0 +1,216 @@ +import torch + +from pathlib import Path +import shutil +import cloudpickle + +from transformers import AutoTokenizer, AutoModelForCausalLM +from transformers.modeling_outputs import CausalLMOutputWithPast +from fast_llm.models.gpt.huggingface import HuggingfaceGPTModelForCausalLM +from fast_llm.engine.checkpoint.config import CheckpointLoadConfig +from fast_llm.models.gpt.config import LlamaGPTHuggingfaceCheckpointFormat + +import torch + + +def generate(model, input_ids, attention_mask, max_new_tokens, tensors_save_path: Path | None = None): + + if tensors_save_path is not None: + if tensors_save_path.is_dir(): + shutil.rmtree(tensors_save_path, ignore_errors=True) + logits_save_path = tensors_save_path / "logits" + hs_save_path = tensors_save_path / "hidden_states" + logits_save_path.mkdir(exist_ok=True, parents=True) + hs_save_path.mkdir(exist_ok=True, parents=True) + + # assume attention mask is left padded with zeroes if any + mask_step = torch.ones((attention_mask.shape[0], 1), dtype=torch.int64).to(attention_mask.device) + for i in range(max_new_tokens): + output: CausalLMOutputWithPast = model.forward( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=None, + past_key_values=None, + inputs_embeds=None, + labels=None, + use_cache=False, + output_attentions=False, + output_hidden_states=True, + return_dict=True, + ) + current_ids = output.logits[:, -1, :].argmax(dim=1, keepdim=True) + input_ids = torch.cat([input_ids, current_ids], dim=1) + attention_mask = torch.cat([attention_mask, mask_step], dim=1) + + if tensors_save_path is not None: + logits_file = logits_save_path / f"tensor{i}.pt" + torch.save(output.logits, logits_file) + + hidden_states_file = hs_save_path / f"data{i}.pickle" + with hidden_states_file.open("wb") as f: + cloudpickle.dump(output.hidden_states, f) + + return input_ids + + +def diff_flm_hf(tokenizer, flm_tokens, hf_tokens): + print("+++++++++++++++fast_llm:+++++++++++++++++++++++++++++++++++++++++++++++++") + fllm_str = tokenizer.decode(flm_tokens) + print(fllm_str) + print("---------------hugging_face:---------------------------------------------") + hf_str = tokenizer.decode(hf_tokens) + print(hf_str) + print( + f"==============================({"Same" if fllm_str==hf_str else "Different"})=====================================" + ) + + +def run_test_fast_llm( + attn_implementation, + torch_dtype, + is_batch_size2, + reverse_samples, + tensors_save_path, + num_new_tokens, +): + checkpoint = "/mnt/checkpoints/pretrained_models/SmolLM2-135M-Instruct" + + device = "cuda" # for GPU usage or "cpu" for CPU usage + tokenizer = AutoTokenizer.from_pretrained(checkpoint) + messages = [ + # {"role": "user", "content": "What is gravity?"}, + {"role": "user", "content": "What is gravity?"}, + {"role": "user", "content": "Who is the president of EU?"}, + # {"role": "user", "content": "Who is the president of EU?"}, + ] + if reverse_samples: + messages = list(reversed(messages)) + if not is_batch_size2: + messages = messages[0:1] + + input_text = 
[tokenizer.apply_chat_template([el], tokenize=False) for el in messages] + + tokenizer.padding_side = "left" + inputs = tokenizer(input_text, padding="longest", return_tensors="pt").to(device) + + fm_kwards = {} + if attn_implementation is not None and attn_implementation == "flash_attention_2": + fm_kwards["attn_implementation"] = "flash_attention_2" + else: + fm_kwards["attn_implementation"] = "fuse" + if torch_dtype is not None and torch_dtype == torch.bfloat16: + fm_kwards["torch_dtype"] = "bf16" + + print("fm_kwards", fm_kwards) + + model_fm = HuggingfaceGPTModelForCausalLM.from_pretrained( + CheckpointLoadConfig( + path=checkpoint, + format=LlamaGPTHuggingfaceCheckpointFormat, + ), + **fm_kwards, + ) + + # outputs_fm = model_fm.generate(**inputs, max_new_tokens=50, use_cache=False) + outputs_fm = generate( + model_fm, **inputs, max_new_tokens=num_new_tokens, tensors_save_path=tensors_save_path / "fast_llm" + ) + + print(tokenizer.decode(outputs_fm[0][inputs["input_ids"].shape[1] :])) + if len(outputs_fm) > 1: + print("--------------------------------------------------------------") + print(tokenizer.decode(outputs_fm[1][inputs["input_ids"].shape[1] :])) + + +def run_test( + attn_implementation, + torch_dtype, + is_batch_size2, + reverse_samples, + tensors_save_path, + num_new_tokens, +): + checkpoint = "/mnt/checkpoints/pretrained_models/SmolLM2-135M-Instruct" + + device = "cuda" # for GPU usage or "cpu" for CPU usage + tokenizer = AutoTokenizer.from_pretrained(checkpoint) + # for multiple GPUs install accelerate and do `model = AutoModelForCausalLM.from_pretrained(checkpoint, device_map="auto")` + hf_kwards = {} + if attn_implementation is not None and attn_implementation == "flash_attention_2": + hf_kwards["attn_implementation"] = "flash_attention_2" + if torch_dtype is not None: + hf_kwards["torch_dtype"] = torch_dtype + + print("hf_kwards", hf_kwards) + model_hf = AutoModelForCausalLM.from_pretrained(checkpoint, **hf_kwards).to(device) + + messages = [ + # {"role": "user", "content": "What is gravity?"}, + {"role": "user", "content": "Who is the president of EU?"}, + {"role": "user", "content": "Who is the president of EU?"}, + ] + if reverse_samples: + messages = list(reversed(messages)) + if not is_batch_size2: + messages = messages[0:1] + + input_text = [tokenizer.apply_chat_template([el], tokenize=False) for el in messages] + + tokenizer.padding_side = "left" + inputs = tokenizer(input_text, padding="longest", return_tensors="pt").to(device) + + # outputs_hf = model_hf.generate(**inputs, max_new_tokens=50, use_cache=False) + outputs_hf = generate( + model_hf, **inputs, max_new_tokens=num_new_tokens, tensors_save_path=tensors_save_path / "hf" + ) + # print(tokenizer.decode(outputs_hf[0])) + + fm_kwards = {} + if attn_implementation is not None and attn_implementation == "flash_attention_2": + fm_kwards["attn_implementation"] = "flash_attention_2" + else: + fm_kwards["attn_implementation"] = "fuse" + if torch_dtype is not None and torch_dtype == torch.bfloat16: + fm_kwards["torch_dtype"] = "bf16" + + print("fm_kwards", fm_kwards) + + model_fm = HuggingfaceGPTModelForCausalLM.from_pretrained( + CheckpointLoadConfig( + path=checkpoint, + format=LlamaGPTHuggingfaceCheckpointFormat, + ), + **fm_kwards, + ) + + # outputs_fm = model_fm.generate(**inputs, max_new_tokens=50, use_cache=False) + outputs_fm = generate( + model_fm, **inputs, max_new_tokens=num_new_tokens, tensors_save_path=tensors_save_path / "fast_llm" + ) + + diff_flm_hf( + tokenizer, 
outputs_fm[0][inputs["input_ids"].shape[1] :], outputs_hf[0][inputs["input_ids"].shape[1] :] + ) + if len(outputs_fm) > 1: + diff_flm_hf( + tokenizer, outputs_fm[1][inputs["input_ids"].shape[1] :], outputs_hf[1][inputs["input_ids"].shape[1] :] + ) + + +def main(): + run_test_fast_llm( + # run_test( + attn_implementation="flash_attention_2", + # attn_implementation=None, + torch_dtype=torch.bfloat16, + # torch_dtype=None, + is_batch_size2=True, + reverse_samples=False, + # tensors_save_path=Path("/mnt/datasets/tests/denis/tensors_bf16_flash_attention_2_batch_size2/"), + tensors_save_path=Path("/mnt/datasets/tests/denis/tmp/"), + num_new_tokens=100, + ) + + +if __name__ == "__main__": + main() diff --git a/test_distributed.py b/test_distributed.py new file mode 100644 index 00000000..1a4f5a5f --- /dev/null +++ b/test_distributed.py @@ -0,0 +1,133 @@ +# distributed_example.py +import os +import torch +import torch.distributed as dist + +from dataclasses import dataclass +import functools +import time + +from transformers import AutoTokenizer +from transformers.modeling_outputs import CausalLMOutputWithPast + +from fast_llm.engine.checkpoint.config import CheckpointLoadConfig +from fast_llm.engine.distributed.config import DistributedConfig +from fast_llm.engine.distributed.distributed import Distributed +from fast_llm.models.gpt.config import LlamaGPTHuggingfaceCheckpointFormat, Qwen2GPTHuggingfaceCheckpointFormat +from fast_llm.models.gpt.huggingface import HuggingfaceGPTModelForCausalLM + +from fast_llm.core.distributed import scatter_object_list, gather_object + + +def run( + attn_implementation="flash_attention_2", + torch_dtype=torch.bfloat16, + checkpoint="/mnt/checkpoints/pretrained_models/Qwen2-1.5B-Instruct/", +): + tokenizer = AutoTokenizer.from_pretrained(checkpoint) + + updates = { + ("base_model", "transformer", "use_flash_attention"): attn_implementation is not None + and attn_implementation == "flash_attention_2", + ("distributed", "tensor_parallel"): 1, + ("distributed", "pipeline_parallel"): 1, + ("distributed", "sequence_data_parallel"): 1, + # ("distributed", "sequence_tensor_parallel"): True, + } + + if torch_dtype is not None and torch_dtype == torch.bfloat16: + updates[("distributed", "training_dtype")] = "bf16" + + print("aupdatesgs", updates) + + model_fm = HuggingfaceGPTModelForCausalLM.from_pretrained( + CheckpointLoadConfig( + path=checkpoint, + format=Qwen2GPTHuggingfaceCheckpointFormat, + model_weights=True, + ), + updates, + ) + + device = model_fm._inference_runner._fast_llm_model.distributed.device + rank = model_fm._inference_runner._fast_llm_model.distributed._config.rank + word_size = model_fm._inference_runner._fast_llm_model.distributed._config.world_size + + if rank == 0: + batch_size = 32 + length = 20 + + num_batches = 10 + t0 = time.time() + for i in range(num_batches): + input_ids = torch.randint( + 1, + tokenizer.vocab_size, + (batch_size, length), + dtype=torch.int64, + generator=torch.Generator().manual_seed(42+i), + ).to(device) + + step = batch_size // word_size + scatter_list = [(input_ids[i * step: (i + 1) * step], True) for i in range(word_size)] + + params = [None] + scatter_object_list(device, params, scatter_list, model_fm._inference_runner._fast_llm_model.distributed.world_group, 0) + input_ids = params[0][0] + + #res = model_fm.generate(input_ids, max_new_tokens=50, use_cache=False) + + res = input_ids + res = res.to("cpu") + + + global_res = [None] * word_size + gather_object(device, res, global_res, 
model_fm._inference_runner._fast_llm_model.distributed.world_group, 0) + + res = torch.cat(global_res, dim=0) + print(res.shape, res.sum().item()) + + scatter_list = [(None, False) for i in range(word_size)] + params = [None] + scatter_object_list(device, params, scatter_list, model_fm._inference_runner._fast_llm_model.distributed.world_group, 0) + + print(time.time() - t0) + + else: + while True: + scatter_list = None + + params = [None] + scatter_object_list(device, params, scatter_list, model_fm._inference_runner._fast_llm_model.distributed.world_group, 0) + input_ids, continue_generate = params[0] + if not continue_generate: + break + + #res = model_fm.generate(input_ids, max_new_tokens=50, use_cache=False) + + res = input_ids + res = res.to("cpu") + + if rank == 0: + global_res = [None] * word_size + else: + global_res = None + gather_object(device, res, global_res, model_fm._inference_runner._fast_llm_model.distributed.world_group, 0) + + + # res = model_fm.forward(input_ids, use_cache=False) + # if res.logits is not None: + # print(res.logits.shape, res.logits.sum().item()) + # print(res.logits.argmax(dim=2, keepdim=False)) + + # else: + # print("None") + + +def main(): + run() + + +if __name__ == "__main__": + main() + print("exiting") diff --git a/tests/test_checkpoint.py b/tests/test_checkpoint.py index 257947e9..91d62942 100644 --- a/tests/test_checkpoint.py +++ b/tests/test_checkpoint.py @@ -32,7 +32,7 @@ from tests.compare_tensor_logs import CompareConfig, compare_logged_tensor TEST_MODEL_CONFIG_CLS = model_registry[TEST_MODEL_TYPE] -TEST_MODEL_HF_CLS = TEST_MODEL_CONFIG_CLS.get_huggingface_model_class() +TEST_MODEL_HF_CLS = TEST_MODEL_CONFIG_CLS.get_huggingface_model_for_causal_lm_class() TEST_MODEL_CLS = TEST_MODEL_CONFIG_CLS.get_model_class() TEST_BASE_MODEL_CONFIG_CLS = TEST_MODEL_CONFIG_CLS.get_base_model_config_class()
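For reference, the new output_hidden_states path added to HuggingfaceGPTModelForCausalLM.forward can be exercised much like test.py does; a minimal sketch (the checkpoint path is a placeholder, a CUDA device and a Llama-style checkpoint are assumed, and the remaining forward keyword arguments are assumed to default to None/False as in the HF API):

from transformers import AutoTokenizer

from fast_llm.engine.checkpoint.config import CheckpointLoadConfig
from fast_llm.models.gpt.config import LlamaGPTHuggingfaceCheckpointFormat
from fast_llm.models.gpt.huggingface import HuggingfaceGPTModelForCausalLM

checkpoint = "/path/to/llama-style-checkpoint"  # placeholder path
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
tokenizer.padding_side = "left"
model = HuggingfaceGPTModelForCausalLM.from_pretrained(
    CheckpointLoadConfig(path=checkpoint, format=LlamaGPTHuggingfaceCheckpointFormat)
)

inputs = tokenizer(["What is gravity?"], return_tensors="pt").to("cuda")
output = model.forward(
    input_ids=inputs["input_ids"],
    attention_mask=inputs["attention_mask"],
    use_cache=False,
    output_hidden_states=True,
    return_dict=True,
)
print(output.logits.shape)
# Per-layer activations captured by the forward hooks; currently a dict keyed by layer index
# (see the TODO about converting it to a list for full HF parity).
print(len(output.hidden_states))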