Added configuration management using pydantic #986

Draft
wants to merge 71 commits into base: master

Commits (71)
da9014f
976 add pydantic configuration v1 (#34)
benmalef Feb 12, 2025
79173fd
fix codacy errors
benmalef Feb 12, 2025
c5882f8
fix codacy error eval -> ast.literal_eval()
benmalef Feb 12, 2025
83bbd06
Merge branch 'mlcommons:master' into 976-add-pydantic-configuration
benmalef Feb 25, 2025
be917cb
update the model final_layer parameter with literals
benmalef Feb 25, 2025
a70c88b
change the model amp parameter description
benmalef Feb 25, 2025
71ade3f
add None in the final_layers_options
benmalef Feb 25, 2025
8837b7c
define the pydantic version in the setup
benmalef Feb 25, 2025
6a9dcb4
add grid_aggregator_overlap literals in the default_parameters model
benmalef Feb 26, 2025
343b3c2
set proportional parameter in nested_training model to False
benmalef Feb 26, 2025
382c9a4
add literals for type parameter in the patch_sampler model
benmalef Feb 26, 2025
a1f110a
add literals for type parameter in the patch_sampler model
benmalef Feb 26, 2025
baf25f7
add literals for type parameter in the model_parameters
benmalef Feb 26, 2025
3c85107
changed the version description
benmalef Feb 26, 2025
16eae59
changed the differential_privacy from Any to Union[bool, dict]
benmalef Feb 26, 2025
638b012
blacked .
benmalef Feb 26, 2025
988ba8d
blacked . and remove unnecessary imports
benmalef Feb 26, 2025
f6558d1
fix the patch_sample error
benmalef Feb 26, 2025
1a607d1
refactor: Add dimension literal to check the input dimensions
benmalef Mar 13, 2025
cd98f60
refactor: blacked .
benmalef Mar 13, 2025
e8b4786
Merge branch 'master' into 976-add-pydantic-configuration
benmalef Mar 16, 2025
c74e2ca
feat: Add test for model configuration
benmalef Mar 17, 2025
87226f2
blacked .
benmalef Mar 17, 2025
243b2fb
Merge branch 'master' into 976-add-pydantic-configuration
benmalef Mar 17, 2025
afc0c21
fix: fix model configuration
benmalef Mar 17, 2025
7569112
Merge remote-tracking branch 'origin/976-add-pydantic-configuration' …
benmalef Mar 17, 2025
c00f29d
feat: Update parameters configuration error handling
benmalef Mar 17, 2025
c010381
black .
benmalef Mar 17, 2025
063fabe
black .
benmalef Mar 17, 2025
00c8658
feat: Update tests
benmalef Mar 17, 2025
5724814
feat: Update tests
benmalef Mar 17, 2025
b267f3d
feat: Update tests
benmalef Mar 17, 2025
a902835
Merge remote-tracking branch 'origin/976-add-pydantic-configuration' …
benmalef Mar 17, 2025
c00d979
refactor: Update the base_model names with "config"
benmalef Mar 18, 2025
a1d97e8
black .
benmalef Mar 18, 2025
0b4b0d7
fix: Update the test_model_fail_generic_config
benmalef Mar 18, 2025
36dd469
fix: Update the test_model_fail_generic_config
benmalef Mar 18, 2025
7856768
fix: Update the test_model_fail_generic_config
benmalef Mar 18, 2025
c39e875
blacked .
benmalef Mar 18, 2025
8ccd2c9
fix: Remove unused imports and rename the Model to ModelConfig
benmalef Mar 18, 2025
c1314bd
fix: Remove unused imports
benmalef Mar 18, 2025
24d26c2
feat: Update the tests
benmalef Mar 18, 2025
b51dbf9
feat: Add documentation
benmalef Mar 19, 2025
e5edd80
feat: Update the documentation
benmalef Mar 19, 2025
c149d01
feat: Update the documentation
benmalef Mar 19, 2025
cb7b95a
refactor: change file names in the configuration dir
benmalef Mar 20, 2025
8555096
refactor: update differential_privacy parameter
benmalef Mar 20, 2025
a846fc8
refactor: black .
benmalef Mar 20, 2025
ce3e978
refactor: black .
benmalef Mar 20, 2025
2a11d15
test: Update test
benmalef Mar 20, 2025
73fa67c
test: Update normtype_segmentation_rad_3d test
benmalef Mar 20, 2025
af06473
refactor: Made some minor changes
benmalef Mar 20, 2025
38da826
refactor: blacked .
benmalef Mar 20, 2025
fb736b4
refactor: rename and move the files in parent dir
benmalef Mar 21, 2025
15b3f57
feat: Add configuration for each scheduler
benmalef Mar 22, 2025
dd47c21
refactor: Fix a typo
benmalef Mar 22, 2025
7c1e0b3
refactor: update the scheduler_config
benmalef Mar 22, 2025
047f616
refactor: remove unnecessary comments from the test
benmalef Mar 22, 2025
91c1938
feat: add optimizer config classes
benmalef Mar 22, 2025
79257aa
refactor: Remove unnecessary imports
benmalef Mar 22, 2025
77192cc
refactor: change the tuple
benmalef Mar 22, 2025
52d3caf
test: Update test to solve the !!python/tuple error
benmalef Mar 22, 2025
57d431a
test: Update test to solve the !!python/tuple error
benmalef Mar 22, 2025
9ee0331
test: Update test to solve the !!python/tuple error
benmalef Mar 22, 2025
e01843d
test: Update test to solve the !!python/tuple error
benmalef Mar 22, 2025
dc6563e
feat: update data_postprocessing config
benmalef Mar 23, 2025
e80724b
feat: update data_postprocessing config
benmalef Mar 23, 2025
2defeb8
feat: update the nested_training testing parameter to be less >10
benmalef Mar 24, 2025
427ad94
refactor: Remove unnecessary imports
benmalef Mar 24, 2025
4876d32
refactor: update model configuration architecture
benmalef Mar 24, 2025
f7c481b
refactor: update model configuration
benmalef Mar 24, 2025
742 changes: 26 additions & 716 deletions GANDLF/config_manager.py

Large diffs are not rendered by default.

20 changes: 20 additions & 0 deletions GANDLF/configuration/README.md
@@ -0,0 +1,20 @@
### Parameters Configuration
We use the Pydantic library for parameter configuration. Parameters are organized by context within the base model classes described below.

#### Basic Classes
- **DefaultParameters**: Contains parameters initialized directly from the application.
- **UserDefinedParameters**: Contains parameters that the user must define.
##### Other Subclasses
- **ModelConfig**: Contains parameters specific to the model.
- **OptimizerConfig**: Contains parameters for the optimizer.
- **SchedulerConfig**: Contains parameters for the scheduler.
- **NestedTrainingConfig**: Contains parameters for nested training.
- **PatchSamplerConfig**: Contains parameters for the patch sampler.

#### How to Define New Parameters
To define a new parameter, add it directly to the appropriate class.
To group related parameters, create a new BaseModel subclass and attach it as a field on one of the basic classes (UserDefinedParameters or DefaultParameters).

If validation is required, you can define it in the validators file.
For more details, refer to the [Pydantic documentation](https://docs.pydantic.dev/latest/).
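
As a rough illustration (the class, field, and validator names here are hypothetical, not part of this PR), a new group of parameters could be added like this:

```python
from pydantic import BaseModel, Field, field_validator


class ExampleAugmentationConfig(BaseModel):
    # Hypothetical sub-config; attach an instance of it as a field on
    # UserDefinedParameters or DefaultParameters to expose it to users.
    probability: float = Field(
        default=0.5, description="Probability of applying the augmentation."
    )

    @field_validator("probability")
    @classmethod
    def check_probability(cls, value: float) -> float:
        # Field-level check; project-wide checks belong in the validators file.
        if not 0.0 <= value <= 1.0:
            raise ValueError("probability must be in [0, 1]")
        return value
```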

Empty file.
73 changes: 73 additions & 0 deletions GANDLF/configuration/default_config.py
@@ -0,0 +1,73 @@
from pydantic import BaseModel, Field, AfterValidator
from typing import Dict
from typing_extensions import Literal, Optional, Annotated

from GANDLF.configuration.validators import validate_postprocessing

GRID_AGGREGATOR_OVERLAP_OPTIONS = Literal["crop", "average", "hann"]


class DefaultParameters(BaseModel):
weighted_loss: bool = Field(
default=False, description="Whether weighted loss is to be used or not."
)
verbose: bool = Field(default=False, description="General application verbosity.")
q_verbose: bool = Field(default=False, description="Queue construction verbosity.")
medcam_enabled: bool = Field(
default=False, description="Enable interpretability via medcam."
)
save_training: bool = Field(
default=False, description="Save outputs during training."
)
save_output: bool = Field(
default=False, description="Save outputs during validation/testing."
)
in_memory: bool = Field(default=False, description="Pin data to CPU memory.")
pin_memory_dataloader: bool = Field(
default=False, description="Pin data to GPU memory."
)
scaling_factor: int = Field(
default=1, description="Scaling factor for regression problems."
)
q_max_length: int = Field(default=100, description="The max length of the queue.")
q_samples_per_volume: int = Field(
default=10, description="Number of samples per volume."
)
q_num_workers: int = Field(
default=4, description="Number of worker threads to use."
)
num_epochs: int = Field(default=100, description="Total number of epochs to train.")
patience: int = Field(
default=100, description="Number of epochs to wait for performance improvement."
)
batch_size: int = Field(default=1, description="Default batch size for training.")
learning_rate: float = Field(default=0.001, description="Default learning rate.")
clip_grad: Optional[float] = Field(
default=None, description="Gradient clipping value."
)
track_memory_usage: bool = Field(
default=False, description="Enable memory usage tracking."
)
memory_save_mode: bool = Field(
default=False,
description="Enable memory-saving mode. If enabled, resize/resample will save files to disk.",
)
print_rgb_label_warning: bool = Field(
default=True, description="Print a warning for RGB labels."
)
data_postprocessing: Annotated[
dict,
Field(description="Default data postprocessing configuration.", default={}),
AfterValidator(validate_postprocessing),
]

grid_aggregator_overlap: GRID_AGGREGATOR_OVERLAP_OPTIONS = Field(
default="crop", description="Default grid aggregator overlap strategy."
)
determinism: bool = Field(
default=False, description="Enable deterministic computation."
)
previous_parameters: Optional[Dict] = Field(
default=None,
description="Previous parameters to be used for resuming training and performing sanity checks.",
)
16 changes: 16 additions & 0 deletions GANDLF/configuration/differential_privacy_config.py
@@ -0,0 +1,16 @@
from typing_extensions import Literal

from pydantic import BaseModel, Field, ConfigDict

ACCOUNTANT_OPTIONS = Literal["rdp", "gdp", "prv"]


class DifferentialPrivacyConfig(BaseModel):
model_config = ConfigDict(extra="allow")
noise_multiplier: float = Field(default=10.0)
max_grad_norm: float = Field(default=1.0)
accountant: ACCOUNTANT_OPTIONS = Field(default="rdp")
secure_mode: bool = Field(default=False)
allow_opacus_model_fix: bool = Field(default=True)
delta: float = Field(default=1e-5)
physical_batch_size: int = Field(validate_default=True)
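
A short usage sketch (assuming this PR's module path; `physical_batch_size` has no default, so it must be supplied, and the extra `custom_flag` key is a made-up example of what `extra="allow"` keeps):

```python
from GANDLF.configuration.differential_privacy_config import DifferentialPrivacyConfig

dp = DifferentialPrivacyConfig(physical_batch_size=32, custom_flag=True)
print(dp.accountant)                   # "rdp" (default)
print(dp.model_dump()["custom_flag"])  # True -- unknown keys are kept, not rejected
```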
1 change: 1 addition & 0 deletions GANDLF/configuration/exclude_parameters.py
@@ -0,0 +1 @@
exclude_parameters = {"differential_privacy"}
82 changes: 82 additions & 0 deletions GANDLF/configuration/model_config.py
@@ -0,0 +1,82 @@
from pydantic import BaseModel, model_validator, Field, AliasChoices, ConfigDict
from typing_extensions import Self, Literal, Optional
from typing import Union
from GANDLF.configuration.validators import validate_class_list, validate_norm_type
from GANDLF.models import global_models_dict

# Define model architecture options
ARCHITECTURE_OPTIONS = Literal[tuple(global_models_dict.keys())]
# Define model norm_type options
NORM_TYPE_OPTIONS = Literal["batch", "instance", "none"]
# Define model final_layer options
FINAL_LAYER_OPTIONS = Literal[
"sigmoid",
"softmax",
"logsoftmax",
"tanh",
"identity",
"logits",
"regression",
"None",
"none",
]
TYPE_OPTIONS = Literal["torch", "openvino"]
DIMENSIONS_OPTIONS = Literal[2, 3]


# You can define new parameters for model here. Please read the pydantic documentation.
# It allows extra fields in model dict.
class ModelConfig(BaseModel):
model_config = ConfigDict(
extra="allow"
) # it allows extra fields in the model dict
dimension: Optional[DIMENSIONS_OPTIONS] = Field(
description="model input dimension (2D or 3D)."
)
architecture: ARCHITECTURE_OPTIONS = Field(description="Architecture.")
final_layer: FINAL_LAYER_OPTIONS = Field(description="Final layer.")
norm_type: Optional[NORM_TYPE_OPTIONS] = Field(
description="Normalization type.", default="batch"
) # TODO: check it again
base_filters: Optional[int] = Field(
description="Base filters.", default=None, validate_default=True
) # default is 32
class_list: Union[list, str] = Field(default=[], description="Class list.")
num_channels: Optional[int] = Field(
description="Number of channels.",
validation_alias=AliasChoices(
"num_channels", "n_channels", "channels", "model_channels"
),
default=3,
) # TODO: check it
type: TYPE_OPTIONS = Field(description="Type of model.", default="torch")
data_type: str = Field(description="Data type.", default="FP32")
save_at_every_epoch: bool = Field(default=False, description="Save at every epoch.")
amp: bool = Field(default=False, description="Automatic mixed precision")
ignore_label_validation: Union[int, None] = Field(
default=None, description="Ignore label validation."
) # TODO: To check it
print_summary: bool = Field(default=True, description="Print summary.")

@model_validator(mode="after")
def model_validate(self) -> Self:
# TODO: Change the print to logging.warnings
self.class_list = validate_class_list(
self.class_list
) # init and validate the class_list parameter
self.norm_type = validate_norm_type(
self.norm_type, self.architecture
) # init and validate the norm type
if self.amp is False:
print("NOT using Mixed Precision Training")

if self.save_at_every_epoch:
print(
"WARNING: 'save_at_every_epoch' will result in TREMENDOUS storage usage; use at your own risk."
) # TODO: It is better to use logging.warning

if self.base_filters is None:
self.base_filters = 32
print("Using default 'base_filters' in 'model': ", self.base_filters)

return self
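
A usage sketch under this PR's module layout (it assumes "unet" is one of the keys in `global_models_dict` and that a plain integer class list passes `validate_class_list`):

```python
from GANDLF.configuration.model_config import ModelConfig

# "n_channels" is accepted through the AliasChoices on num_channels;
# base_filters is filled in with 32 by the model validator.
model_cfg = ModelConfig(
    dimension=3,
    architecture="unet",
    final_layer="softmax",
    class_list=[0, 1],
    n_channels=1,
)
print(model_cfg.num_channels)  # 1
print(model_cfg.base_filters)  # 32
```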
25 changes: 25 additions & 0 deletions GANDLF/configuration/nested_training_config.py
@@ -0,0 +1,25 @@
from pydantic import BaseModel, Field, model_validator
from typing_extensions import Self, Optional


class NestedTraining(BaseModel):
stratified: bool = Field(
default=False,
description="this will perform stratified k-fold cross-validation but only with offline data splitting",
)
testing: int = Field(
default=-5,
description="this controls the number of testing data folds for final model evaluation; [NOT recommended] to disable this, use '1'",
le=10,
)
validation: int = Field(
default=-5,
description="this controls the number of validation data folds to be used for model *selection* during training (not used for back-propagation)",
)
proportional: Optional[bool] = Field(default=False)

@model_validator(mode="after")
def validate_nested_training(self) -> Self:
if self.proportional is not None:
self.stratified = self.proportional
return self
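
A small sketch of the behaviour encoded here (assuming this PR's module path; the fold counts are example values):

```python
from GANDLF.configuration.nested_training_config import NestedTraining

nt = NestedTraining(testing=5, validation=5, proportional=True)
print(nt.stratified)  # True -- the validator copies `proportional` onto `stratified`

# `testing` is declared with le=10, so NestedTraining(testing=11) raises a ValidationError.
```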
121 changes: 121 additions & 0 deletions GANDLF/configuration/optimizer_config.py
@@ -0,0 +1,121 @@
from typing import Tuple

from pydantic import BaseModel, Field, ConfigDict
from typing_extensions import Literal

from GANDLF.optimizers import global_optimizer_dict

# takes the keys from global optimizer
OPTIMIZER_OPTIONS = Literal[tuple(global_optimizer_dict.keys())]


class sgd_config(BaseModel):
momentum: float = Field(default=0.99)
weight_decay: float = Field(default=3e-05)
dampening: float = Field(default=0)
nesterov: bool = Field(default=True)


class asgd_config(BaseModel):
alpha: float = Field(default=0.75)
t0: float = Field(default=1e6)
lambd: float = Field(default=1e-4)
weight_decay: float = Field(default=3e-05)


class adam_config(BaseModel):
betas: Tuple[float, float] = Field(default=(0.9, 0.999))
weight_decay: float = Field(default=0.00005)
eps: float = Field(default=1e-8)
amsgrad: bool = Field(default=False)


class adamax_config(BaseModel):
betas: Tuple[float, float] = Field(default=(0.9, 0.999))
weight_decay: float = Field(default=0.00005)
eps: float = Field(default=1e-8)


class rprop_config(BaseModel):
etas: Tuple[float, float] = Field(default=(0.5, 1.2))
step_sizes: Tuple[float, float] = Field(default=(1e-6, 50))


class adadelta_config(BaseModel):
rho: float = Field(default=0.9)
eps: float = Field(default=1e-6)
weight_decay: float = Field(default=3e-05)


class adagrad_config(BaseModel):
lr_decay: float = Field(default=0)
eps: float = Field(default=1e-6)
weight_decay: float = Field(default=3e-05)


class rmsprop_config(BaseModel):
alpha: float = Field(default=0.99)
eps: float = Field(default=1e-8)
centered: bool = Field(default=False)
momentum: float = Field(default=0)
weight_decay: float = Field(default=3e-05)


class radam_config(BaseModel):
betas: Tuple[float, float] = Field(default=(0.9, 0.999))
eps: float = Field(default=1e-8)
weight_decay: float = Field(default=3e-05)
foreach: bool = Field(default=None)


class nadam_config(BaseModel):
betas: Tuple[float, float] = Field(default=(0.9, 0.999))
eps: float = Field(default=1e-8)
weight_decay: float = Field(default=3e-05)
foreach: bool = Field(default=None)


class novograd_config(BaseModel):
betas: Tuple[float, float] = Field(default=(0.9, 0.999))
eps: float = Field(default=1e-8)
weight_decay: float = Field(default=3e-05)
amsgrad: bool = Field(default=False)


class ademamix_config(BaseModel):
pass # TODO: Check it because the default parameters are not in the optimizer dict


class lion_config(BaseModel):
betas: Tuple[float, float] = Field(default=(0.9, 0.999))
weight_decay: float = Field(default=0.0)
decoupled_weight_decay: bool = Field(default=False)


class adopt_config(BaseModel):
pass # TODO: Check it because the default parameters are not in the optimizer dict


class OptimizerConfig(BaseModel):
model_config = ConfigDict(extra="allow")
type: OPTIMIZER_OPTIONS = Field(description="Type of optimizer to use")


optimizer_dict_config = {
"sgd": sgd_config,
"asgd": asgd_config,
"adam": adam_config,
"adamw": adam_config,
"adamax": adamax_config,
# "sparseadam": sparseadam,
"rprop": rprop_config,
"adadelta": adadelta_config,
"adagrad": adagrad_config,
"rmsprop": rmsprop_config,
"radam": radam_config,
"novograd": novograd_config,
"nadam": nadam_config,
"ademamix": ademamix_config,
"lion": lion_config,
"adopt": adopt_config,
}
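
One way this mapping can be used (a sketch, not code from this PR): look up the per-optimizer config class from the validated `type` and merge its defaults with whatever the user supplied.

```python
from GANDLF.configuration.optimizer_config import OptimizerConfig, optimizer_dict_config

user_params = {"type": "sgd", "momentum": 0.9}  # example user input

opt_cfg = OptimizerConfig(**user_params)           # validates `type`, keeps extra keys
per_opt_cls = optimizer_dict_config[opt_cfg.type]  # sgd_config in this example

# Defaults from sgd_config, overridden by the user-supplied values.
merged = {**per_opt_cls().model_dump(), **opt_cfg.model_dump()}
print(merged["momentum"])  # 0.9
print(merged["nesterov"])  # True (sgd_config default)
```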
10 changes: 10 additions & 0 deletions GANDLF/configuration/parameters_config.py
@@ -0,0 +1,10 @@
from pydantic import BaseModel, ConfigDict
from GANDLF.configuration.user_defined_config import UserDefinedParameters


class ParametersConfiguration(BaseModel):
model_config = ConfigDict(extra="allow")


class Parameters(ParametersConfiguration, UserDefinedParameters):
pass
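
Because `Parameters` inherits from both `ParametersConfiguration` and `UserDefinedParameters`, it carries all user-defined fields. A quick way to inspect this without supplying any required values (assuming this PR's module path):

```python
from GANDLF.configuration.parameters_config import Parameters

# model_fields is a class-level attribute, so no instantiation is needed.
print(sorted(Parameters.model_fields.keys()))
```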
11 changes: 11 additions & 0 deletions GANDLF/configuration/patch_sampler_config.py
@@ -0,0 +1,11 @@
from pydantic import BaseModel, Field
from typing_extensions import Literal

TYPE_OPTIONS = Literal["uniform", "label"]


class PatchSamplerConfig(BaseModel):
type: TYPE_OPTIONS = Field(default="uniform")
enable_padding: bool = Field(default=False)
padding_mode: str = Field(default="symmetric")
biased_sampling: bool = Field(default=False)
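
A minimal usage sketch (assuming this PR's module path; the argument values are arbitrary):

```python
from GANDLF.configuration.patch_sampler_config import PatchSamplerConfig

sampler = PatchSamplerConfig(type="label", biased_sampling=True)
print(sampler.padding_mode)  # "symmetric" (default)
```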
10 changes: 10 additions & 0 deletions GANDLF/configuration/post_processing_config.py
@@ -0,0 +1,10 @@
from pydantic import BaseModel, ConfigDict, Field
from typing_extensions import Any


class PostProcessingConfig(BaseModel):
model_config = ConfigDict(extra="forbid", exclude_none=True)
fill_holes: Any = Field(default=None)
mapping: dict = Field(default=None)
morphology: Any = Field(default=None)
cca: Any = Field(default=None)