The config is located in `backend/src/config.py` and sets the parameters for the entire project (e.g. logging, paths, training parameters, API settings). Make sure to check it out before running the pipeline; the important parameters are annotated.
Once we have chosen a baseline model, we will tune its hyperparameters using the dicts specified in the config. (TODO: move the dicts to separate files, one per model.)
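As a rough illustration, the config might look like the sketch below. All names and values are hypothetical except `CREATE_STATIC_SPLITS` and `MAX_AUGMENT`, which appear later in this document; check the actual file for the real parameters.

```python
# backend/src/config.py -- illustrative sketch only, not the actual file.
from pathlib import Path

# Paths (hypothetical names)
DATA_INTERMEDIATE_DIR = Path("data/intermediate")
MODELS_DIR = Path("models")

# Training parameters
CREATE_STATIC_SPLITS = True  # set to False to reuse existing splits
MAX_AUGMENT = 5              # max augmentations per original sample (hypothetical value)

# Hyperparameter search spaces (TODO: move to separate files per model)
VIT_HPARAMS = {
    "lr": [1e-4, 3e-4],
    "batch_size": [16, 32],
}

# API settings (hypothetical names)
FLASK_HOST = "localhost"
FLASK_PORT = 5000
```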
During the training pipeline, we populate the `data/intermediate` directory and write model artifacts to the `models` directory. Once the splits are created, you can disable the creation of new splits by setting `CREATE_STATIC_SPLITS` to `False` in the config; the existing splits in `data/intermediate` will then be reused.
The intermediate data consists of a static train/val/test split in which only the train set is augmented; this prevents data leakage into evaluation and testing.
To reduce class imbalance, we augment underrepresented classes using a dynamic scaling approach. For each class $i$, we calculate the target number of samples after augmentation using:
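$$
n_i^{\text{target}} = \min\!\left(\max_j n_j,\ n_i \cdot \text{MAX\_AUGMENT}\right)
$$

where $n_i$ is the number of original samples in class $i$. (This form follows from the two constraints below; whether $\text{MAX\_AUGMENT}$ bounds the total count or only the number of added copies is an assumption.)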
This approach ensures two key constraints:
- The augmented sample count for any class never exceeds the size of the largest original class
- Each original sample is augmented at most MAX_AUGMENT times to maintain data quality
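A minimal sketch of this target computation (the function and variable names are hypothetical; the actual logic lives in the training pipeline):

```python
from collections import Counter

def augmentation_targets(labels: list[int], max_augment: int) -> dict[int, int]:
    """Compute per-class target sample counts after augmentation.

    Caps each class at the size of the largest original class and at
    max_augment times its own original size, matching the two
    constraints above.
    """
    counts = Counter(labels)
    n_max = max(counts.values())  # size of the largest original class
    return {cls: min(n_max, n * max_augment) for cls, n in counts.items()}

# Example: class 0 has 100 samples, class 1 has 10, MAX_AUGMENT = 5
print(augmentation_targets([0] * 100 + [1] * 10, max_augment=5))
# -> {0: 100, 1: 50}: class 1 grows 5x but never exceeds the largest class
```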
This helps prevent over-augmentation while flattening the class distribution, as shown in the plots below:
The training pipeline can be run from the command line:
```bash
python backend/run_train_pipeline.py
```
Note: The Jupyter notebook version is outdated and should not be used.
During training, all model artifacts are saved in the `models` directory with the following structure:
```
models/
└── datetime_timestamp/
    ├── cnn/
    │   ├── best_model.pth
    │   ├── metrics.json
    │   └── metrics.png
    ├── vgg19/
    │   ├── best_model.pth
    │   ├── metrics.json
    │   └── metrics.png
    └── vit/
        ├── best_model.pth
        ├── metrics.json
        └── metrics.png
```
For inference, we select the best model from these runs and move it to the `models/final_models` directory, from which we then load the model type and weights.
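Loading then amounts to something like the following (a minimal sketch, assuming PyTorch checkpoints, a torchvision ViT as the architecture, and that `best_model.pth` stores a `state_dict`; the actual loading code may differ):

```python
import torch
from torchvision import models

# Rebuild the architecture used in training (assumed here to be a
# torchvision ViT whose classification head matches the trained classes),
# then load the saved weights and switch to inference mode.
model = models.vit_b_16(weights=None)
state = torch.load("models/final_models/best_model.pth", map_location="cpu")
model.load_state_dict(state)
model.eval()
```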
Even the baseline ViT model trained for only 2 epochs achieves solid performance across all relevant metrics:
```json
{
  "val_f1": 82.82044466135217,
  "val_loss": 0.4800497365387034,
  "val_top3": 100.0,
  "val_top5": 100.0,
  "train_f1": 86.72970789630963,
  "train_loss": 0.4314231148026336
}
```
For our use case, top-5 accuracy is the most relevant metric: we display the top-5 predictions to the user, who then makes the final decision.
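Computing the top-5 predictions is straightforward in PyTorch; the sketch below uses a dummy logits tensor and hypothetical class names in place of the real model output:

```python
import torch

# Dummy stand-ins: in the API, logits come from the loaded model run on a
# preprocessed image; here we fake a single-image batch over 10 classes.
class_names = [f"species_{i}" for i in range(10)]
logits = torch.randn(1, len(class_names))    # shape: (batch, num_classes)

probs = torch.softmax(logits, dim=1)
top5_probs, top5_idx = probs.topk(5, dim=1)  # five highest-confidence classes
for i, p in zip(top5_idx[0], top5_probs[0]):
    print(f"{class_names[i]}: {p.item():.3f}")
```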
A few classes (species) show slightly higher misclassification rates, most likely due to visual similarity:
The Flask API can be started by running:
```bash
python backend/run_flask.py
```
By default, the API is accessible at http://localhost:5000. The host address and port can be configured in the config file.
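A client request against the running server might look like this (a sketch; the `/predict` route and the `image` form field are assumptions, not taken from the actual API code):

```python
import requests

# Hypothetical endpoint and field name; check the Flask routes for the
# real contract. Sends an image and prints the JSON response, which is
# expected to contain the top-5 predictions shown to the user.
with open("example.jpg", "rb") as f:
    response = requests.post(
        "http://localhost:5000/predict",
        files={"image": f},
    )
print(response.json())
```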
For the scope of this project, we use the Flask development server, which is not recommended for production.
The docstrings throughout the project were mostly AI-generated but have been manually validated.