diff --git a/docs/quick-start.md b/docs/quick-start.md index ee228fb9..e5873712 100644 --- a/docs/quick-start.md +++ b/docs/quick-start.md @@ -492,9 +492,10 @@ Save the following as `fast-llm-tutorial/train-config.yaml`: train_iters: 100 # (1)! logs: interval: 10 - validation: - iterations: 25 - interval: 100 + evaluations: + validation: + iterations: 25 + interval: 100 export: # (2)! format: llama interval: 100 @@ -508,10 +509,10 @@ Save the following as `fast-llm-tutorial/train-config.yaml`: batch_size: 480 # (5)! data: datasets: - Training: + training: type: file path: fast-llm-tutorial/dataset/fast_llm_config_training.yaml # (6)! - Validation: + validation: type: file path: fast-llm-tutorial/dataset/fast_llm_config_validation.yaml # (6)! optimizer: @@ -549,9 +550,10 @@ Save the following as `fast-llm-tutorial/train-config.yaml`: train_iters: 100_000 # (1)! logs: interval: 10 - validation: - iterations: 25 - interval: 1000 + evaluations: + validation: + iterations: 25 + interval: 1000 checkpoint: interval: 1000 keep: 5 @@ -569,10 +571,10 @@ Save the following as `fast-llm-tutorial/train-config.yaml`: batch_size: 512 # (5)! data: datasets: - Training: + training: type: file path: fast-llm-tutorial/dataset/fast_llm_config_training.yaml # (6)! - Validation: + validation: type: file path: fast-llm-tutorial/dataset/fast_llm_config_validation.yaml # (6)! optimizer: # (7)! diff --git a/docs/recipes/continue-training.md b/docs/recipes/continue-training.md index 8ea36ebd..a19d965d 100644 --- a/docs/recipes/continue-training.md +++ b/docs/recipes/continue-training.md @@ -33,9 +33,10 @@ This is not much different from a pretraining config. We will: train_iters: 100_000 logs: interval: 10 - validation: - iterations: 25 - interval: 1000 + evaluations: + validation: + iterations: 25 + interval: 1000 checkpoint: interval: 1000 keep: 5 @@ -48,9 +49,13 @@ This is not much different from a pretraining config. We will: sequence_length: 4096 batch_size: 256 data: - format: file - path: fast-llm-tutorial/dataset.json # (2)! - split: [99, 1, 0] + datasets: + training: + type: file + path: fast-llm-tutorial/dataset/fast_llm_config_training.yaml # (2)! + validation: + type: file + path: fast-llm-tutorial/dataset/fast_llm_config_validation.yaml # (2)! optimizer: weight_decay: 0.1 beta_1: 0.9 @@ -84,8 +89,9 @@ This is not much different from a pretraining config. We will: logs: interval: 10 - validation: - iterations: 25 - interval: 1000 + evaluations: + validation: + iterations: 25 + interval: 1000 checkpoint: interval: 1000 keep: 5 @@ -98,9 +104,13 @@ This is not much different from a pretraining config. We will: sequence_length: 8192 batch_size: 256 data: - format: file - path: fast-llm-tutorial/dataset.json # (2)! - split: [99, 1, 0] + datasets: + training: + type: file + path: fast-llm-tutorial/dataset/fast_llm_config_training.yaml # (6)! + validation: + type: file + path: fast-llm-tutorial/dataset/fast_llm_config_validation.yaml # (6)! optimizer: weight_decay: 0.1 beta_1: 0.9 @@ -129,7 +139,7 @@ This is not much different from a pretraining config. We will: ``` 1. A the model will be saved in Hugging Face format to `~/results` directory every 20,000 iterations. -2. Location of the dataset metadata file generated in Step 4. +2. Location of the dataset metadata file generated in Step 4 of the quick-start guide. 3. The learning-rate can be used to trade-off between learning and forgetting. A higher learning-rate will learn quickly on our new dataset but will cause forgetting. 
A lower learning-rate will instead retain more of the pretrained model's knowledge, but will slow down adapting to the new domain. 4. Config of the pretrained model. We load the model downloaded from the repository earlier. 5. This tells Fast-LLM to load the weights of the pretrained model. If we wanted to use the model's configuration, but train from scratch, we could use the same config but set this to `no`. diff --git a/docs/recipes/data-configuration.md b/docs/recipes/data-configuration.md index 8bea3280..23da8cc4 100644 --- a/docs/recipes/data-configuration.md +++ b/docs/recipes/data-configuration.md @@ -13,10 +13,10 @@ We already saw an example dataset configuration in the [quick-start guide](../qu ```yaml data: datasets: - Training: + training: type: file path: fast-llm-tutorial/dataset/fast_llm_config_training.yaml - Validation: + validation: type: file path: fast-llm-tutorial/dataset/fast_llm_config_validation.yaml ``` @@ -25,6 +25,16 @@ We already saw an example dataset configuration in the [quick-start guide](../qu In this section we are interested in generalizing step 3. For more details on steps 1 and 2, please refer to the quick-start guide or [this example](data-configuration.md). +The `data.datasets` section describes the datasets used for training, validation, and testing. + +The training and test phases must use the predetermined dataset names `training` and `test`, respectively. Each of these phases can have only one dataset. + +For validation datasets, the rules are different. There can be as many validation datasets as needed, and their names are arbitrary. In the example above, the name `validation` is chosen for simplicity. The dataset names used for validation, and how often each one is evaluated, are specified in the training config's `evaluations` section. + +Using multiple validation datasets gives you more fine-grained tracking of model quality. For example, when training on a blend of datasets, you can define a separate validation dataset for each component, letting you follow how the model performs on each subset as training progresses. + +Below are examples of how to configure various aspects of training and validation datasets. + ## Example 1: Blending multiple datasets In this example, we have three datasets and want to sample from each of them during training with probabilities 0.70, 0.25 and 0.05. For this, we use the `blended` type which takes other datasets as arguments: @@ -32,7 +42,7 @@ In this example, we have three datasets and want to sample from each of them dur ```yaml data: datasets: - Training: + training: type: blended datasets: - type: file @@ -54,7 +64,7 @@ In this example, we have a large dataset that comes pre-shuffled, so shuffling i ```yaml data: datasets: - Training: + training: type: file path: path/to/dataset.yaml sampling: @@ -68,10 +78,10 @@ In this example, we want to disable shuffling entirely, but only for the validat ```yaml data: datasets: - Training: + training: type: file path: path/to/training_dataset.yaml - Validation: + validation: type: sampled dataset: type: file @@ -91,7 +101,7 @@ In this example, we have a blend of datasets as in example 1, but we wish to set ```yaml data: datasets: - Training: + training: type: blended datasets: - type: sampled @@ -118,7 +128,34 @@ data: !!!
note "Default seed" In the absence of explicit seed, Fast-LLM uses a default seed (`data.sampling`'s default) instead, and uses seed shifts to ensure different seeds for each phase and for the various blended datasets. -## Example 5: Advanced scenario + +## Example 5: Specifying multiple validation datasets + +In this example, we show how to specify multiple validation datasets and how to configure, in the `training.evaluations` section, how often each of them is evaluated. + +Note that the dataset names in `data.datasets` and `training.evaluations` must match: if a validation dataset is specified in `data.datasets` but not in `training.evaluations`, it will not be used for validation. + +```yaml +training: + evaluations: + the_stack: + iterations: 25 + interval: 50 + fineweb: + iterations: 25 + interval: 100 +data: + datasets: + the_stack: + type: file + path: path/to/validation_the_stack_dataset.yaml + fineweb: + type: file + path: path/to/validation_fineweb_dataset.yaml + +``` + +## Example 6: Advanced scenario In this example, we combine everything we learned so far to create a complex scenario, where: @@ -129,7 +166,7 @@ In this example, we combine everything we learned so far to create a complex sce ```yaml data: datasets: - Training: + training: type: blended datasets: - type: sampled @@ -156,7 +193,7 @@ data: # Seed = default + train_shift + 2 * blend_shift, shuffle = skip_first_epoch path: path/to/dataset_3.yaml weights: [0.70, 0.25, 0.05] - Validation: + validation: type: sampled dataset: type: file @@ -174,10 +211,10 @@ data: ```yaml data: datasets: - Training: + training: type: file path: path/to/training_dataset_config.yaml - Validation: + validation: type: file path: path/to/validation_dataset_config.yaml sampling: diff --git a/docs/recipes/instruction-finetuning.md b/docs/recipes/instruction-finetuning.md index 4ffc983c..15a45426 100644 --- a/docs/recipes/instruction-finetuning.md +++ b/docs/recipes/instruction-finetuning.md @@ -114,9 +114,10 @@ training: train_iters: 5_000 logs: interval: 1 - validation: - iterations: 25 - interval: 1000 + evaluations: + validation: + iterations: 25 + interval: 1000 checkpoint: interval: 1000 keep: 5 @@ -131,10 +132,10 @@ batch: cross_document_attention: no # (1)! data: datasets: - Training: + training: type: file path: ./sft-tutorial/tokenized/Llama-3.1-8B/fast_llm_config_training.yaml - Validation: + validation: type: file path: ./sft-tutorial/tokenized/Llama-3.1-8B/fast_llm_config_validation.yaml truncate_documents: no # (2)!
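The renamed dataset keys in these recipes are also checked at runtime: with this change, `GPTData.setup` raises an error if a dataset name requested for training, evaluation, or testing is missing from `data.datasets`, and it only warns about datasets that are defined but never referenced. A minimal sketch of that second situation, with hypothetical dataset names and paths:

```yaml
training:
  evaluations:
    wiki:                 # must match a key under data.datasets, otherwise setup fails
      iterations: 25
      interval: 100
data:
  datasets:
    training:
      type: file
      path: path/to/training_dataset.yaml
    wiki:
      type: file
      path: path/to/wiki_validation.yaml
    books:                # defined but absent from training.evaluations:
      type: file          # only a warning is emitted and this dataset is never evaluated
      path: path/to/books_validation.yaml
```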
diff --git a/docs/recipes/train.md b/docs/recipes/train.md index 4b59ab6c..9c5f92e5 100644 --- a/docs/recipes/train.md +++ b/docs/recipes/train.md @@ -19,9 +19,10 @@ Let's start from the following training configuration: train_iters: 100_000 logs: interval: 10 - validation: - iterations: 25 - interval: 1000 + evaluations: + validation: + iterations: 25 + interval: 1000 checkpoint: interval: 1000 keep: 5 @@ -34,9 +35,13 @@ Let's start from the following training configuration: sequence_length: 4096 batch_size: 256 data: - format: file - path: fast-llm-tutorial/dataset/fast_llm_dataset.json - split: [99, 1, 0] + datasets: + training: + type: file + path: path/to/training_dataset_config.yaml + validation: + type: file + path: path/to/validation_dataset_config.yaml optimizer: weight_decay: 0.1 beta_1: 0.9 @@ -63,9 +68,10 @@ Let's start from the following training configuration: train_iters: 100_000 logs: interval: 10 - validation: - iterations: 25 - interval: 1000 + evaluations: + validation: + iterations: 25 + interval: 1000 checkpoint: interval: 1000 keep: 5 @@ -78,9 +84,13 @@ Let's start from the following training configuration: sequence_length: 8192 batch_size: 256 data: - format: file - path: fast-llm-tutorial/dataset/fast_llm_dataset.json - split: [99, 1, 0] + datasets: + training: + type: file + path: path/to/training_dataset_config.yaml + validation: + type: file + path: path/to/validation_dataset_config.yaml optimizer: weight_decay: 0.1 beta_1: 0.9 diff --git a/examples/mistral.yaml b/examples/mistral.yaml index d60a7802..f1fa8279 100644 --- a/examples/mistral.yaml +++ b/examples/mistral.yaml @@ -3,8 +3,9 @@ training: num_workers: 8 logs: interval: 10 - validation: - iterations: null + evaluations: + validation: + iterations: null test_iters: 0 batch: sequence_length: 4096 @@ -12,7 +13,7 @@ batch: batch_size: 64 data: datasets: - Training: + training: type: random optimizer: learning_rate: diff --git a/fast_llm/data/data/abstract.py b/fast_llm/data/data/abstract.py index 8e691b31..c42301ea 100644 --- a/fast_llm/data/data/abstract.py +++ b/fast_llm/data/data/abstract.py @@ -13,7 +13,7 @@ class Data[ConfigType: DataConfig](Configurable[ConfigType], abc.ABC): _distributed: "Distributed" - _samples_per_phase: dict[PhaseType, int] + _samples_per_dataset: dict[str, int] _cache_directory: pathlib.Path | None def __init__(self, config: DataConfig, distributed_config: DistributedConfig) -> None: @@ -24,12 +24,12 @@ def __init__(self, config: DataConfig, distributed_config: DistributedConfig) -> def setup( self, distributed: "Distributed", - samples_per_phase: dict[PhaseType, int], + samples_per_dataset: dict[str, int], cache_directory: pathlib.Path, timeout: float | None = None, ) -> None: self._distributed = distributed - self._samples_per_phase = samples_per_phase + self._samples_per_dataset = samples_per_dataset self._cache_directory = cache_directory @property @@ -40,7 +40,7 @@ def distributed(self): def get_iterator( self, batch_config: BatchConfig, - phase: PhaseType, + dataset_name: str, *, consumed_samples: int, num_workers: int, diff --git a/fast_llm/data/data/gpt/config.py b/fast_llm/data/data/gpt/config.py index cbbfa036..18b1eaac 100644 --- a/fast_llm/data/data/gpt/config.py +++ b/fast_llm/data/data/gpt/config.py @@ -1,4 +1,5 @@ import logging +import typing from fast_llm.config import Field, FieldHint, FieldUpdate, check_field, config_class from fast_llm.data.config import MultiprocessingContext, TokenizerConfig @@ -39,7 +40,7 @@ class GPTDataConfig(DataConfig, GPTLegacyConfig): 
hint=FieldHint.feature, ) # TODO: Review field. Move closer to phase definition in training config? - datasets: dict[PhaseType, GPTSampledDatasetConfig] = Field( + datasets: dict[str, GPTSampledDatasetConfig] = Field( default_factory=dict, desc="Configuration for the dataset(s).", hint=FieldHint.core, @@ -63,7 +64,26 @@ def _validate(self) -> None: "Using the legacy dataset definition format." " Specify it through `data.datasets` instead." ) self.datasets = { - phase: GPTLegacyDatasetConfig.from_dict(self, strict=False) + phase.value.lower(): GPTLegacyDatasetConfig.from_dict(self, strict=False) for phase in (PhaseType.training, PhaseType.validation, PhaseType.test) } super()._validate() + + @classmethod + def _from_dict( + cls, + default: dict[str, typing.Any], + strict: bool = True, + flat: bool = False, + ) -> typing.Self: + # TODO v0.x: Remove backward compatibility. + if "datasets" in default: + for phase in PhaseType: + if phase.value in default["datasets"]: + rename = phase.value.lower() + logger.warning(f"Renaming dataset {phase.value} to {rename}") + assert rename not in default["datasets"] + default["datasets"][rename] = default["datasets"].pop(phase.value) + + cls._handle_renamed_field(default, "validation", ("evaluations", "validation")) + return super()._from_dict(default, strict, flat) diff --git a/fast_llm/data/data/gpt/data.py b/fast_llm/data/data/gpt/data.py index 37f9950d..8fc33376 100644 --- a/fast_llm/data/data/gpt/data.py +++ b/fast_llm/data/data/gpt/data.py @@ -19,7 +19,7 @@ from fast_llm.data.iterator import SampledDatasetIterator from fast_llm.data.tokenizer import Tokenizer from fast_llm.engine.config_utils.run import log_main_rank -from fast_llm.engine.distributed.config import DistributedConfig, PhaseType +from fast_llm.engine.distributed.config import DistributedConfig from fast_llm.engine.distributed.distributed import Distributed from fast_llm.engine.schedule.config import BatchConfig from fast_llm.utils import Assert @@ -56,7 +56,7 @@ class GPTData[ConfigType: GPTDataConfig](Data[ConfigType]): TODO: Separate generic and GPT classes. """ - _datasets: dict[PhaseType, SampledDataset] + _datasets: dict[str, SampledDataset] _tokenizer: Tokenizer | None _is_setup: bool = False @@ -80,7 +80,7 @@ def __init__( def setup( self, distributed: "Distributed", - samples_per_phase: dict[PhaseType, int], + samples_per_dataset: dict[str, int], cache_directory: pathlib.Path, timeout: float | None = None, ) -> None: @@ -88,7 +88,20 @@ def setup( Load the datasets, and prepare or load the samplings. This may take a while and a significant amount of cpu memory. """ - super().setup(distributed, samples_per_phase, cache_directory) + # Check and raise an error if a used dataset is not defined. + for dataset_name in samples_per_dataset.keys(): + if dataset_name not in self._config.datasets: + raise ValueError(f"Dataset {dataset_name} not found.") + + # Check and warn if there are defined datasets that are not used. + unused_datasets = self._config.datasets.keys() - samples_per_dataset.keys() + if unused_datasets: + warnings.warn( + f"The following datasets are defined but not used: {', '.join(unused_datasets)}. " + "Ensure this is intentional, or update the configuration accordingly." + ) + + super().setup(distributed, samples_per_dataset, cache_directory) log_main_rank(f"Preparing dataset. 
This may take several minutes.") self._tokenizer = None if self._config.tokenizer.path is None else Tokenizer(self._config.tokenizer) @@ -97,23 +110,21 @@ def setup( warnings.warn(f"Using the dataset directory for the index cache.") self._datasets = {} - for phase, num_samples in samples_per_phase.items(): + for dataset_name, num_samples in samples_per_dataset.items(): if num_samples > 0: - # TODO: Do the check earlier. - assert phase in self._config.datasets sampling = GPTSamplingData( - num_samples=samples_per_phase[phase], + num_samples=num_samples, config=self._config.sampling, cache_directory=self._cache_directory, distributed=distributed, - phase=phase, + dataset_name=dataset_name, sequence_length=self._max_sequence_length, vocab_size=self._vocab_size, tokenizer=self._tokenizer, cross_document_attention=self._cross_document_attention, ) - dataset = self._config.datasets[phase].build_and_sample(sampling) - self._datasets[phase] = DatasetMonitor(dataset, self._config.data_sample_warn_time_ms) + dataset = self._config.datasets[dataset_name].build_and_sample(sampling) + self._datasets[dataset_name] = DatasetMonitor(dataset, self._config.data_sample_warn_time_ms) safe_barrier(self._distributed.world_group, "data_preparation", timeout) self._is_setup = True @@ -126,21 +137,26 @@ def tokenizer(self) -> Tokenizer: def get_iterator( self, batch_config: BatchConfig, - phase: PhaseType, + dataset_name: str, *, consumed_samples: int, num_workers: int, prefetch_factor: int | None = None, ) -> typing.Iterator[typing.Any]: assert self._is_setup - Assert.incl(phase, self._datasets) + + # Some dataset names may come from phases and are capitalized, + # so we need to normalize them before use. + dataset_name = dataset_name.lower() + + Assert.incl(dataset_name, self._datasets) Assert.in_range_incl(batch_config.sequence_length, 1, self._max_sequence_length) - log_main_rank(f"Initializing {phase} data iterator from sample {consumed_samples}...") + log_main_rank(f"Initializing {dataset_name} dataset iterator from sample {consumed_samples}...") return iter( torch.utils.data.DataLoader( - self._datasets[phase], # noqa + self._datasets[dataset_name], # noqa batch_sampler=SampledDatasetIterator( - total_samples=len(self._datasets[phase]), + total_samples=len(self._datasets[dataset_name]), begin_index=consumed_samples, micro_batch_size=batch_config.micro_batch_size, data_rank=self._distributed.config.batch_data_rank, diff --git a/fast_llm/data/dataset/config.py b/fast_llm/data/dataset/config.py index 431a28a0..ccc51b4a 100644 --- a/fast_llm/data/dataset/config.py +++ b/fast_llm/data/dataset/config.py @@ -7,7 +7,6 @@ from fast_llm.config import Config, Field, FieldHint, FieldVerboseLevel, check_field, config_class from fast_llm.data.dataset.abstract import SamplableDataset, SampledDataset -from fast_llm.engine.distributed.config import PhaseType from fast_llm.utils import Assert, normalize_probabilities if typing.TYPE_CHECKING: @@ -40,7 +39,7 @@ class SamplingData: cache_directory: pathlib.Path | None # TODO: This prevents the sampling config from being pickled in multiprocessing. distributed: "Distributed" - phase: PhaseType + dataset_name: str # Using a mutable rather than an int so it's shared with all copies made with `update`. 
_rank_counter: typing.Iterator[int] = itertools.count diff --git a/fast_llm/data/dataset/gpt/config.py b/fast_llm/data/dataset/gpt/config.py index 74d8a0c3..f91c537e 100644 --- a/fast_llm/data/dataset/gpt/config.py +++ b/fast_llm/data/dataset/gpt/config.py @@ -472,11 +472,12 @@ def build_and_sample(self, sampling: GPTSamplingData) -> SampledDataset: raise NotImplementedError(self.format) phase_splits = padded_cumsum(normalize_probabilities(self.split)) + phase_index = { - PhaseType.training: 0, - PhaseType.validation: 1, - PhaseType.test: 2, - }[sampling.phase] + PhaseType.training.value.lower(): 0, + PhaseType.validation.value.lower(): 1, + PhaseType.test.value.lower(): 2, + }[sampling.dataset_name] dataset_configs = [ { diff --git a/fast_llm/engine/training/config.py b/fast_llm/engine/training/config.py index 30add2f4..dac4e553 100644 --- a/fast_llm/engine/training/config.py +++ b/fast_llm/engine/training/config.py @@ -140,22 +140,22 @@ class WandbConfig(Config): @config_class() -class ValidationConfig(IntervalConfig): +class EvaluationConfig(IntervalConfig): interval = FieldUpdate( - desc="The number of training iterations between each validation phase." - " Setting to None will disable validation." + desc="The number of training iterations between each evaluation phase." + " Setting to None will disable evaluation." ) - offset = FieldUpdate(desc="Offset for the first validation phase.") + offset = FieldUpdate(desc="Offset for the first evaluation phase.") iterations: int | None = Field( default=None, - desc="Number of iterations for each validation phase. Setting to None will disable.", + desc="Number of iterations for each evaluation phase. Setting to None will disable.", hint=FieldHint.feature, valid=skip_valid_if_none(check_field(Assert.gt, 0)), ) - def get_iteration_count(self, training_iterations: int, extra_validations: int = 0): + def get_iteration_count(self, training_iterations: int, extra_evaluations: int = 0): # Number of completed validation iterations - return (self.get_count(training_iterations) + extra_validations) * self.iterations if self.enabled() else 0 + return (self.get_count(training_iterations) + extra_evaluations) * self.iterations if self.enabled() else 0 @config_class() @@ -267,9 +267,9 @@ class ShutdownConfig(IntervalConfig): @config_class() class TrainingConfig(Config): - validation: ValidationConfig = Field( - default_factory=ValidationConfig, - desc="Configuration for the validation phase", + evaluations: dict[str, EvaluationConfig] = Field( + default_factory=dict, + desc="A dictionary of evaluation dataset names and their configurations for the validation phase.", hint=FieldHint.core, ) logs: MetricsLogsConfig = Field( @@ -315,6 +315,17 @@ class TrainingConfig(Config): valid=skip_valid_if_none(check_field(Assert.gt, 0)), ) + @classmethod + def _from_dict( + cls, + default: dict[str, typing.Any], + strict: bool = True, + flat: bool = False, + ) -> typing.Self: + # TODO v0.x: Remove backward compatibility. 
+ cls._handle_renamed_field(default, "validation", ("evaluations", "validation")) + return super()._from_dict(default, strict, flat) + def _validate(self) -> None: super()._validate() self.shutdown.assert_sub_interval(self.checkpoint) diff --git a/fast_llm/engine/training/trainer.py b/fast_llm/engine/training/trainer.py index d43abe56..f2ed4a38 100644 --- a/fast_llm/engine/training/trainer.py +++ b/fast_llm/engine/training/trainer.py @@ -54,29 +54,43 @@ def __init__(self, config: TrainerConfig): distributed_config=self._config.model.distributed, ) steps_per_split = { - PhaseType.training: self._config.training.train_iters, - PhaseType.validation: self._config.training.validation.get_iteration_count( - self._config.training.train_iters, - # There may be an extra validation after the last training step. - not self._config.training.validation.enabled(self._config.training.train_iters), - ), - PhaseType.test: self._config.training.test_iters, + PhaseType.training: {PhaseType.training.value.lower(): self._config.training.train_iters}, + PhaseType.validation: { + dataset_name: self._config.training.evaluations[dataset_name].get_iteration_count( + self._config.training.train_iters, + # There may be an extra evaluation after the last training step. + not self._config.training.evaluations[dataset_name].enabled(self._config.training.train_iters), + ) + for dataset_name in self._config.training.evaluations.keys() + }, + PhaseType.test: {PhaseType.test.value.lower(): self._config.training.test_iters}, } self._samples_per_split = { - phase: self._config.batch.batch_size * steps for phase, steps in steps_per_split.items() if steps > 0 + phase: { + dataset_name: self._config.batch.batch_size * steps + for dataset_name, steps in datasets.items() + if steps > 0 + } + for phase, datasets in steps_per_split.items() } + # Prune empty phases. 
+ self._samples_per_split = {k: v for k, v in self._samples_per_split.items() if len(v) > 0} + self._loss_defs = self._multi_stage.base_model.loss_defs # Setup the schedules self._schedule = { - phase: Schedule( - multi_stage=self._multi_stage, - batch_config=self._config.batch, - schedule_config=self._config.schedule, - distributed_config=self._config.model.distributed, - phase=phase, - ) - for phase in self._samples_per_split + phase: { + dataset_name: Schedule( + multi_stage=self._multi_stage, + batch_config=self._config.batch, + schedule_config=self._config.schedule, + distributed_config=self._config.model.distributed, + phase=phase, + ) + for dataset_name in datasets + } + for phase, datasets in self._samples_per_split.items() } def setup(self, distributed: Distributed, run: Run) -> None: @@ -107,7 +121,11 @@ def setup(self, distributed: Distributed, run: Run) -> None: log_main_rank("Preparing datasets...") self._data.setup( distributed, - self._samples_per_split, + { + dataset_name: steps + for datasets in self._samples_per_split.values() + for dataset_name, steps in datasets.items() + }, None if run.experiment_directory is None else run.experiment_directory / "dataset_cache", timeout=self._config.training.timeout, ) @@ -127,10 +145,9 @@ def _consumed_tokens(self) -> int: assert self._is_setup return self._consumed_samples * self._config.batch.sequence_length - @property - def _completed_validation_steps(self) -> int: - # Number of validation steps performed before the current step - return self._config.training.validation.get_iteration_count(self._completed_steps - 1) + def _get_completed_evaluation_steps(self, dataset_name) -> int: + # Number of evaluations steps performed before the current step + return self._config.training.evaluations[dataset_name].get_iteration_count(self._completed_steps - 1) def run(self) -> None: assert self._is_setup @@ -155,13 +172,14 @@ def _run_training(self) -> None: if done and PhaseType.test in self._samples_per_split: log_main_rank(lambda: f"Running test phase ...") - test_iterator = self._get_data_iterator(PhaseType.test) - metrics[PhaseType.test] = self._evaluate( + test_iterator = self._get_data_iterator(PhaseType.test.value.lower()) + metrics_key = PhaseType.test.value + metrics[metrics_key] = self._evaluate( data_iterator=test_iterator, phase=PhaseType.test, num_iters=self._config.training.test_iters, ) - formatted_metrics = format_metrics(metrics[PhaseType.test], self._loss_defs, PhaseType.test) + formatted_metrics = format_metrics(metrics[metrics_key], self._loss_defs, PhaseType.test) log_main_rank(formatted_metrics) self._wandb.alert("Testing results", formatted_metrics, "WARN") # TODO: This may erase some metrics. 
@@ -180,11 +198,11 @@ def _train(self) -> tuple[bool, dict[PhaseType, dict[str, typing.Any]]]: ) train_iterator = self._get_data_iterator( - PhaseType.training, + PhaseType.training.value, self._completed_steps, self._config.training.prefetch_factor, ) - valid_iterator = None + evaluation_iterators = {name: None for name in self._config.training.evaluations.keys()} log_main_rank("Training ...") @@ -206,7 +224,7 @@ def _train(self) -> tuple[bool, dict[PhaseType, dict[str, typing.Any]]]: # (Also preprocessing adds overhead) reduced_losses, update_successful, train_metrics = self._runner.run_step( train_iterator, - self._schedule[PhaseType.training], + self._schedule[PhaseType.training][PhaseType.training.value.lower()], iteration=self._completed_steps, return_metrics=is_logging, ) @@ -237,7 +255,8 @@ def _train(self) -> tuple[bool, dict[PhaseType, dict[str, typing.Any]]]: self._config.training.train_iters - self._completed_steps ) model_tflops, hardware_tflops = self.get_tflops(PhaseType.training, time_per_iteration) - metrics[PhaseType.training] = { + metrics_key = PhaseType.training.value + metrics[metrics_key] = { "train_iters": self._config.training.train_iters, "batch_size": self._config.batch.batch_size, "iteration": self._completed_steps, @@ -266,9 +285,7 @@ def _train(self) -> tuple[bool, dict[PhaseType, dict[str, typing.Any]]]: **get_memory_usage_mib(), } - formatted_metrics = format_metrics( - metrics[PhaseType.training], self._loss_defs, PhaseType.training - ) + formatted_metrics = format_metrics(metrics[metrics_key], self._loss_defs, PhaseType.training) logger.info(formatted_metrics) if self._config.training.wandb.alert.enabled(self._completed_steps): self._wandb.alert("Training results", formatted_metrics, "INFO") @@ -288,24 +305,46 @@ def _train(self) -> tuple[bool, dict[PhaseType, dict[str, typing.Any]]]: # Evaluation # TODO: Adjust valid iterator length. if PhaseType.validation in self._samples_per_split and ( - done or self._config.training.validation.enabled(self._completed_steps) + done + or any( + evaluation_conf.enabled(self._completed_steps) + for evaluation_conf in self._config.training.evaluations.values() + ) ): - if valid_iterator is None: - valid_iterator = self._get_data_iterator( - PhaseType.validation, self._completed_validation_steps + formatted_metrics = [] + for dataset_name, evaluation_conf in self._config.training.evaluations.items(): + if not evaluation_conf.enabled(self._completed_steps): + continue + if evaluation_iterators[dataset_name] is None: + evaluation_iterators[dataset_name] = self._get_data_iterator( + dataset_name, self._get_completed_evaluation_steps(dataset_name) + ) + # TODO: formatting metric category as Validation.evaluation_dataset_name + # maybe format each metric with evaluation_dataset_name prefix instead? + # TODO: setting performance metrics per evaluation dataset + # maybe to set aggregate performance metrics for all evaluations datasets? 
+ metric_key = f"{PhaseType.validation.value}.{dataset_name}" + metrics[metric_key] = self._evaluate( + data_iterator=evaluation_iterators[dataset_name], + phase=PhaseType.validation, + num_iters=evaluation_conf.iterations, + begin_iter=self._get_completed_evaluation_steps(dataset_name), + dataset_name=dataset_name, ) - metrics[PhaseType.validation] = self._evaluate( - data_iterator=valid_iterator, - phase=PhaseType.validation, - num_iters=self._config.training.validation.iterations, - begin_iter=self._completed_validation_steps, - ) - formatted_metrics = format_metrics( - metrics[PhaseType.validation], self._loss_defs, PhaseType.validation - ) - log_main_rank(formatted_metrics) - if self._config.training.wandb.alert.enabled(self._completed_steps): - self._wandb.alert("Validation results", formatted_metrics, "INFO") + formatted_metrics.append( + format_metrics( + metrics[metric_key], + self._loss_defs, + PhaseType.validation, + dataset_name=dataset_name, + ) + ) + + if len(formatted_metrics) > 0: + formatted_metrics = "\n".join(formatted_metrics) + log_main_rank(formatted_metrics) + if self._config.training.wandb.alert.enabled(self._completed_steps): + self._wandb.alert("Validation results", formatted_metrics, "INFO") if is_main_rank() and metrics: self._wandb.log_metrics(self._completed_steps, metrics) @@ -326,21 +365,23 @@ def _evaluate( phase: PhaseType, num_iters: int, begin_iter: int = 0, + dataset_name: str | None = None, ) -> dict[str, float | int]: - safe_barrier(self._distributed.world_group, f"{phase.value} begin") + full_phase_name = phase.value if dataset_name is None else f"{phase.value}_{dataset_name}" + safe_barrier(self._distributed.world_group, f"{full_phase_name} begin") begin_time = time.perf_counter() total_losses = {loss_def.name: 0.0 for loss_def in self._loss_defs} for iter_ in range(num_iters): iter_losses, _, _ = self._runner.run_step( - data_iterator, self._schedule[phase], iteration=begin_iter + iter_ + data_iterator, self._schedule[phase][dataset_name], iteration=begin_iter + iter_ ) for name, value in iter_losses.items(): total_losses[name] += value - self._run.save_logged_tensors(f"{phase}_{self._completed_steps}_{iter_}") + self._run.save_logged_tensors(f"{full_phase_name}_{self._completed_steps}_{iter_}") safe_barrier( self._distributed.world_group, - f"{phase.value} end", + f"{full_phase_name} end", ) end_time = time.perf_counter() time_per_iteration = (end_time - begin_time) / num_iters @@ -367,11 +408,11 @@ def _evaluate( return metrics def _get_data_iterator( - self, phase, completed_steps: int = 0, prefetch_factor: int | None = None + self, dataset_name, completed_steps: int = 0, prefetch_factor: int | None = None ) -> typing.Iterator[typing.Any]: return self._data.get_iterator( self._config.batch, - phase, + dataset_name, consumed_samples=completed_steps * self._config.batch.batch_size, num_workers=self._config.training.num_workers, prefetch_factor=prefetch_factor, @@ -400,7 +441,7 @@ def _prepare_training_state(self) -> None: assert self._multi_stage._is_loaded # noqa def _save_checkpoint( - self, config: TrainingCheckpointBaseConfig, metrics: dict[PhaseType, dict[str, float | int]] | None + self, config: TrainingCheckpointBaseConfig, metrics: dict[str, dict[str, float | int]] | None ) -> None: # TODO v0.3: Move barrier, ok file to FastLLMModel checkpoint_base_directory = config.get_save_directory(self._run.experiment_directory) @@ -418,7 +459,7 @@ def _save_checkpoint( "completed_steps": self._completed_steps, } if metrics is not None: - 
metadata["metrics"] = {key.value: value for key, value in metrics.items()} + metadata["metrics"] = metrics self._multi_stage.save_checkpoint( config.get_save_config(checkpoint_directory, timeout=self._config.training.timeout), metadata ) diff --git a/fast_llm/logging.py b/fast_llm/logging.py index 997b426b..ffeb56f6 100644 --- a/fast_llm/logging.py +++ b/fast_llm/logging.py @@ -48,7 +48,7 @@ } _VALIDATION_METRIC_FORMATS = ( - "{phase} @ iteration {iteration:6.0f}/{train_iters:6.0f}" + "{phase}{dataset_name} @ iteration {iteration:6.0f}/{train_iters:6.0f}" " | consumed samples: {consumed_samples:12,.0f}" " | consumed tokens: {consumed_tokens:16,.0f}" " | batch size: {batch_size:3.0f}" @@ -101,13 +101,17 @@ } -def format_metrics(metrics: dict[str, float | int], loss_defs: list[LossDef], phase: PhaseType) -> str: +def format_metrics( + metrics: dict[str, float | int], loss_defs: list[LossDef], phase: PhaseType, dataset_name: str | None = None +) -> str: # TODO: Improve, add flexibility. metrics = {key: _FORMAT_MAP[key](value) if key in _FORMAT_MAP else value for key, value in metrics.items()} outputs = [ _METRIC_FORMATS[phase].format( - phase=phase, **{key: metrics.pop(key, _NAN) for key in _METRIC_FORMATS_KEYS[phase]} + phase=phase, + dataset_name="" if dataset_name is None else f"/{dataset_name}", + **{key: metrics.pop(key, _NAN) for key in _METRIC_FORMATS_KEYS[phase]}, ) ] outputs.extend([f"{loss_def.formatted_name}: {metrics.pop(loss_def.name, _NAN):.5f}" for loss_def in loss_defs]) diff --git a/tests/common.py b/tests/common.py index 8b8e57c3..4dd30971 100644 --- a/tests/common.py +++ b/tests/common.py @@ -68,20 +68,20 @@ "training.timeout=30", "batch.batch_size=8", "batch.sequence_length=512", - "data.datasets.Training.type=slice", - "data.datasets.Training.end=0.969", - "data.datasets.Training.dataset.type=memmap", - f"data.datasets.Training.dataset.path={DATASET_PREFIX}", - "data.datasets.Validation.type=slice", - "data.datasets.Validation.begin=0.969", - "data.datasets.Validation.end=0.999", - "data.datasets.Validation.dataset.type=memmap", - f"data.datasets.Validation.dataset.path={DATASET_PREFIX}", - "data.datasets.Test.type=slice", - "data.datasets.Test.begin=0.999", - "data.datasets.Test.end=1", - "data.datasets.Test.dataset.type=memmap", - f"data.datasets.Test.dataset.path={DATASET_PREFIX}", + "data.datasets.training.type=slice", + "data.datasets.training.end=0.969", + "data.datasets.training.dataset.type=memmap", + f"data.datasets.training.dataset.path={DATASET_PREFIX}", + "data.datasets.validation.type=slice", + "data.datasets.validation.begin=0.969", + "data.datasets.validation.end=0.999", + "data.datasets.validation.dataset.type=memmap", + f"data.datasets.validation.dataset.path={DATASET_PREFIX}", + "data.datasets.test.type=slice", + "data.datasets.test.begin=0.999", + "data.datasets.test.end=1", + "data.datasets.test.dataset.type=memmap", + f"data.datasets.test.dataset.path={DATASET_PREFIX}", "optimizer.learning_rate.base=0.0001", ] CONFIG_BASE_MEGATRON = [ diff --git a/tests/data/common.py b/tests/data/common.py index 917b4914..5177b1f1 100644 --- a/tests/data/common.py +++ b/tests/data/common.py @@ -47,7 +47,7 @@ def get_sampling_data( num_samples=num_samples, cache_directory=cache_directory, distributed=distributed, - phase=phase, + dataset_name=phase.value, sequence_length=sequence_length, vocab_size=vocab_size, tokenizer=tokenizer, @@ -62,7 +62,7 @@ def get_dataset_config[T: GPTSampledDatasetConfig](config: dict[str, typing.Any] def get_test_data_and_compare_samples( 
config: dict, - samples_per_phase: dict[PhaseType, int], + samples_per_dataset: dict[str, int] | int, *, seed: int = 54983, gpu: bool = False, @@ -70,11 +70,16 @@ def get_test_data_and_compare_samples( cache_directory: pathlib.Path | None = None, sequence_length: int = 512, vocab_size=TEST_VOCAB_SIZE, - expected_samples: dict[PhaseType, list[list[int]]], + expected_samples: dict[str, list[list[int]]] | list[list[int]], legacy: bool = False, ) -> GPTData: distributed_config = DistributedConfig(seed=seed if legacy else 87522) distributed = Distributed(distributed_config, use_cpu=True) + if isinstance(samples_per_dataset, int): + samples_per_dataset = {PhaseType.training.value.lower(): samples_per_dataset} + if isinstance(expected_samples, list): + expected_samples = {PhaseType.training.value.lower(): expected_samples} + assert "sampling" not in config config["sampling"] = GPTSamplingDefaultConfig( seed=87522 if legacy else seed, @@ -82,7 +87,7 @@ def get_test_data_and_compare_samples( shuffle=shuffle, ) data = GPTData(GPTDataConfig.from_dict(config), distributed_config, vocab_size, sequence_length) - data.setup(distributed, samples_per_phase, cache_directory) + data.setup(distributed, samples_per_dataset, cache_directory) with NoAutoValidate(): batch_config = BatchConfig(batch_size=1, sequence_length=sequence_length) batch_config.setup(distributed_config) @@ -91,10 +96,9 @@ def get_test_data_and_compare_samples( phase: torch.stack( [batch.token_ids[0] for batch in data.get_iterator(batch_config, phase, consumed_samples=0, num_workers=0)] ) - for phase, samples in samples_per_phase.items() + for phase, samples in samples_per_dataset.items() } for phase, expected_samples_ in expected_samples.items(): - print("jerbn", phase, tokens[phase].tolist()) Assert.all_equal(tokens[phase], expected_samples_) return data diff --git a/tests/data/test_blending.py b/tests/data/test_blending.py index fa1bc2a9..de97eaa2 100644 --- a/tests/data/test_blending.py +++ b/tests/data/test_blending.py @@ -4,7 +4,6 @@ import pytest from fast_llm.data.dataset.gpt.config import GPTBlendedDatasetConfig -from fast_llm.engine.distributed.config import PhaseType from fast_llm.utils import Assert, normalize_probabilities from tests.common import DATASET_CACHE, DATASET_PREFIX, get_test_dataset from tests.data.common import ( @@ -154,9 +153,9 @@ def test_gpt_blended_data(): } } }, - {PhaseType.training: 8}, + 8, sequence_length=5, - expected_samples={PhaseType.training: GPT_BLENDED_SAMPLES}, + expected_samples=GPT_BLENDED_SAMPLES, ) @@ -169,9 +168,9 @@ def test_gpt_blended_data_legacy(): "path": ["0.75", str(DATASET_PREFIX), "0.25", str(_DATASET_PREFIX_MIX_1)], "split": [1, 0, 0], }, - {PhaseType.training: 8}, + 8, sequence_length=5, - expected_samples={PhaseType.training: GPT_BLENDED_LEGACY_SAMPLES}, + expected_samples=GPT_BLENDED_LEGACY_SAMPLES, legacy=True, ) @@ -204,7 +203,7 @@ def test_gpt_blended_mixed_data(): } } }, - {PhaseType.training: 8}, + 8, sequence_length=5, - expected_samples={PhaseType.training: GPT_BLENDED_MIXED_SAMPLES}, + expected_samples=GPT_BLENDED_MIXED_SAMPLES, ) diff --git a/tests/data/test_concatenate.py b/tests/data/test_concatenate.py index c25baa68..1142d536 100644 --- a/tests/data/test_concatenate.py +++ b/tests/data/test_concatenate.py @@ -1,5 +1,4 @@ from fast_llm.data.dataset.gpt.config import GPTConcatenatedDatasetConfig -from fast_llm.engine.distributed.config import PhaseType from tests.common import DATASET_PREFIX, get_test_dataset from tests.data.common import ( compare_indexed_dataset, @@ 
-49,7 +48,7 @@ def test_gpt_concatenate_data(): } } }, - {PhaseType.training: 8}, + 8, sequence_length=5, - expected_samples={PhaseType.training: GPT_CONCATENATED_SAMPLES}, + expected_samples=GPT_CONCATENATED_SAMPLES, ) diff --git a/tests/data/test_concatenated_memmap.py b/tests/data/test_concatenated_memmap.py index e8f7c149..09929040 100644 --- a/tests/data/test_concatenated_memmap.py +++ b/tests/data/test_concatenated_memmap.py @@ -1,5 +1,4 @@ from fast_llm.data.dataset.gpt.config import GPTConcatenatedMemmapConfig -from fast_llm.engine.distributed.config import PhaseType from tests.common import DATASET_CACHE, get_test_concatenated_memmap_dataset from tests.data.common import ( compare_indexed_dataset, @@ -68,7 +67,7 @@ def test_gpt_concatenated_memmap_data(): } } }, - {PhaseType.training: 8}, + 8, sequence_length=5, - expected_samples={PhaseType.training: CONCATENATED_MEMMAP_SAMPLES}, + expected_samples=CONCATENATED_MEMMAP_SAMPLES, ) diff --git a/tests/data/test_fim.py b/tests/data/test_fim.py index 5e4c54c5..65fbf369 100644 --- a/tests/data/test_fim.py +++ b/tests/data/test_fim.py @@ -1,7 +1,6 @@ from fast_llm.data.config import TokenizerConfig from fast_llm.data.dataset.gpt.config import GPTFimSampledDatasetConfig from fast_llm.data.tokenizer import Tokenizer -from fast_llm.engine.distributed.config import PhaseType from tests.common import DATASET_PREFIX, TOKENIZER_PATH, get_test_dataset from tests.data.common import ( compare_sampled_dataset, @@ -71,9 +70,9 @@ def test_gpt_fim_data(): }, "tokenizer": {"path": TOKENIZER_PATH}, }, - {PhaseType.training: 8}, + 8, sequence_length=5, - expected_samples={PhaseType.training: GPT_FIM_SAMPLES}, + expected_samples=GPT_FIM_SAMPLES, ) @@ -86,8 +85,8 @@ def test_gpt_fim_data_legacy(): "tokenizer": {"path": TOKENIZER_PATH}, "split": [1, 0, 0], }, - {PhaseType.training: 8}, + 8, sequence_length=5, - expected_samples={PhaseType.training: GPT_FIM_SAMPLES_LEGACY}, + expected_samples=GPT_FIM_SAMPLES_LEGACY, legacy=True, ) diff --git a/tests/data/test_random.py b/tests/data/test_random.py index 2145eff9..72a6080a 100644 --- a/tests/data/test_random.py +++ b/tests/data/test_random.py @@ -1,5 +1,4 @@ from fast_llm.data.dataset.gpt.config import GPTRandomDatasetConfig -from fast_llm.engine.distributed.config import PhaseType from tests.data.common import ( compare_sampled_dataset, get_dataset_config, @@ -32,17 +31,17 @@ def test_gpt_random_data(): } } }, - {PhaseType.training: 4}, + 4, sequence_length=7, - expected_samples={PhaseType.training: RANDOM_DATASET_EXPECTED_SAMPLES}, + expected_samples=RANDOM_DATASET_EXPECTED_SAMPLES, ) def test_gpt_random_data_legacy(): get_test_data_and_compare_samples( {"format": "random"}, - {PhaseType.training: 4}, + 4, sequence_length=7, - expected_samples={PhaseType.training: RANDOM_DATASET_EXPECTED_SAMPLES}, + expected_samples=RANDOM_DATASET_EXPECTED_SAMPLES, legacy=True, ) diff --git a/tests/data/test_sampling.py b/tests/data/test_sampling.py index da06901b..e622d118 100644 --- a/tests/data/test_sampling.py +++ b/tests/data/test_sampling.py @@ -6,7 +6,6 @@ from fast_llm.data.dataset.gpt.config import GPTMemmapDatasetConfig, ShufflingType from fast_llm.data.dataset.gpt.indexed import GPTIndexedDataset from fast_llm.data.dataset.gpt.sampled import GPTSample -from fast_llm.engine.distributed.config import PhaseType from fast_llm.utils import Assert from tests.common import DATASET_PREFIX, get_test_dataset from tests.data.common import ( @@ -58,18 +57,18 @@ def test_gpt_sampled_data(): } } }, - {PhaseType.training: 8}, + 
8, sequence_length=5, - expected_samples={PhaseType.training: GPT_MEMMAP_SAMPLES}, + expected_samples=GPT_MEMMAP_SAMPLES, ) def test_gpt_sampled_data_legacy(): get_test_data_and_compare_samples( {"format": "list", "path": [str(DATASET_PREFIX)], "split": [1, 0, 0]}, - {PhaseType.training: 8}, + 8, sequence_length=5, - expected_samples={PhaseType.training: GPT_MEMMAP_SAMPLES_LEGACY}, + expected_samples=GPT_MEMMAP_SAMPLES_LEGACY, legacy=True, ) diff --git a/tests/data/test_slice.py b/tests/data/test_slice.py index 70c1f453..299e2054 100644 --- a/tests/data/test_slice.py +++ b/tests/data/test_slice.py @@ -1,5 +1,4 @@ from fast_llm.data.dataset.gpt.config import GPTDatasetSliceConfig -from fast_llm.engine.distributed.config import PhaseType from tests.common import DATASET_PREFIX, get_test_dataset from tests.data.common import ( compare_indexed_dataset, @@ -62,19 +61,19 @@ def test_gpt_slice_data(): get_test_data_and_compare_samples( { "datasets": { - "Training": { + "training": { "type": "slice", "dataset": {"type": "memmap", "path": DATASET_PREFIX}, "begin": 0, "end": 0.0015, }, - "Validation": { + "validation": { "type": "slice", "dataset": {"type": "memmap", "path": DATASET_PREFIX}, "begin": 0.0015, "end": 0.003, }, - "Test": { + "test": { "type": "slice", "dataset": {"type": "memmap", "path": DATASET_PREFIX}, "begin": 0.003, @@ -82,11 +81,11 @@ def test_gpt_slice_data(): }, } }, - {PhaseType.training: 4, PhaseType.validation: 8, PhaseType.test: 5}, + {"training": 4, "validation": 8, "test": 5}, sequence_length=5, expected_samples={ - PhaseType.training: GPT_SLICE_TRAINING_SAMPLES, - PhaseType.validation: GPT_SLICE_VALIDATION_SAMPLES, + "training": GPT_SLICE_TRAINING_SAMPLES, + "validation": GPT_SLICE_VALIDATION_SAMPLES, }, ) @@ -95,11 +94,11 @@ def test_gpt_slice_data_legacy(): get_test_dataset() get_test_data_and_compare_samples( {"format": "list", "path": [str(DATASET_PREFIX)], "split": [0.0015, 0.0015, 0.997]}, - {PhaseType.training: 4, PhaseType.validation: 8, PhaseType.test: 5}, + {"training": 4, "validation": 8, "test": 5}, sequence_length=5, expected_samples={ - PhaseType.training: GPT_SLICE_TRAINING_SAMPLES_LEGACY, - PhaseType.validation: GPT_SLICE_VALIDATION_SAMPLES_LEGACY, + "training": GPT_SLICE_TRAINING_SAMPLES_LEGACY, + "validation": GPT_SLICE_VALIDATION_SAMPLES_LEGACY, }, legacy=True, ) diff --git a/tests/test_checkpoint.py b/tests/test_checkpoint.py index 32f0c3f7..d5685a71 100644 --- a/tests/test_checkpoint.py +++ b/tests/test_checkpoint.py @@ -47,7 +47,11 @@ def test_checkpoint_and_eval(): run_test_script( f"test_{TEST_MODEL}_checkpoint_and_eval", CONFIG_COMMON - + ["training.checkpoint.interval=1", "training.validation.interval=2", "training.validation.iterations=1"], + + [ + "training.checkpoint.interval=1", + "training.evaluations.validation.interval=2", + "training.evaluations.validation.iterations=1", + ], ) @@ -76,7 +80,11 @@ def test_resume(): run_test_script( f"test_{TEST_MODEL}_resume", CONFIG_COMMON - + ["training.checkpoint.interval=1", "training.validation.interval=2", "training.validation.iterations=1"], + + [ + "training.checkpoint.interval=1", + "training.evaluations.validation.interval=2", + "training.evaluations.validation.iterations=1", + ], compare=f"test_{TEST_MODEL}_checkpoint_and_eval", prepare_fn=_prepare_resume_fn, compare_fn=_compare_resume_fn, diff --git a/tests/test_simple.py b/tests/test_simple.py index 85727bca..b98d9ee9 100644 --- a/tests/test_simple.py +++ b/tests/test_simple.py @@ -43,9 +43,9 @@ def test_model_dp2_timeout(): # Use a short 
timeout "model.distributed.timeout=4", # Make a dataset that would timeout under the distributed timeout - 'data.datasets.Training={"type":"test_slow"}', - "data.datasets.Training.type=test_slow", - "data.datasets.Training.sleep=6", + 'data.datasets.training={"type":"test_slow"}', + "data.datasets.training.type=test_slow", + "data.datasets.training.sleep=6", # Use a bigger timeout for the dataset. "training.timeout=10", # Remove testing clutter.
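Taken together, the backward-compatibility hooks added in `TrainingConfig._from_dict` and `GPTDataConfig._from_dict` mean that older configs keep working: capitalized dataset keys such as `Training` and `Validation` are renamed to lowercase (with a warning), and a legacy `training.validation` section is moved to `training.evaluations.validation`. A sketch of the migration, assuming a typical pre-existing config; the second form is what the docs above now recommend:

```yaml
# Legacy config: still accepted, rewritten on load by the _from_dict hooks.
training:
  validation:
    iterations: 25
    interval: 1000
data:
  datasets:
    Training:
      type: file
      path: path/to/training_dataset.yaml
    Validation:
      type: file
      path: path/to/validation_dataset.yaml

# Equivalent new-style config: the preferred spelling going forward.
training:
  evaluations:
    validation:
      iterations: 25
      interval: 1000
data:
  datasets:
    training:
      type: file
      path: path/to/training_dataset.yaml
    validation:
      type: file
      path: path/to/validation_dataset.yaml
```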