This repository contains the source code for the paper "HAMIL-QA: Hierarchical Approach to Multiple Instance Learning for Atrial LGE MRI Quality Assessment", to be presented at MICCAI 2024.
This project implements a two-tier attention mechanism for medical image quality assessment:
- Tier 1: Patch-level feature extraction using a ResNet encoder with attention
- Tier 2: Bag-level aggregation with attention for final quality prediction
The model uses pseudo-bags created from extracted patches to perform quality assessment in a multiple instance learning framework.
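The sketch below shows what the two tiers could look like. It uses only the standard library and assumes patch feature vectors have already been extracted (in the real model, the ResNet encoder produces them). Several pieces are hypothetical illustrations and are not taken from `model/qc_model.py`: the random partition into pseudo-bags, the simplified attention score `tanh(w · h)`, the fixed weight vectors `w_patch`/`w_bag` (which would be learned in practice), and all helper names.

```python
import math
import random


def softmax(scores):
    """Numerically stable softmax over a list of floats."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def attention_pool(features, w):
    """Attention-weighted average of feature vectors.

    Hypothetical simplified score: tanh(w . h), normalised with softmax.
    Returns the pooled vector and the attention weights.
    """
    scores = [math.tanh(sum(wi * hi for wi, hi in zip(w, h))) for h in features]
    alphas = softmax(scores)
    dim = len(features[0])
    pooled = [sum(a * h[d] for a, h in zip(alphas, features)) for d in range(dim)]
    return pooled, alphas


def make_pseudo_bags(n_patches, no_of_pseudo_bags, seed=0):
    """Randomly partition patch indices into roughly equal pseudo-bags.

    Assumes n_patches >= no_of_pseudo_bags so that no bag is empty.
    """
    idx = list(range(n_patches))
    random.Random(seed).shuffle(idx)
    return [idx[i::no_of_pseudo_bags] for i in range(no_of_pseudo_bags)]


def two_tier_pool(patch_features, no_of_pseudo_bags, w_patch, w_bag, seed=0):
    """Tier 1 pools patches within each pseudo-bag; tier 2 pools the bags."""
    bags = make_pseudo_bags(len(patch_features), no_of_pseudo_bags, seed)
    # Tier 1: attention over the patches inside each pseudo-bag.
    bag_embeddings = [
        attention_pool([patch_features[i] for i in bag], w_patch)[0] for bag in bags
    ]
    # Tier 2: attention over pseudo-bag embeddings -> one subject-level vector,
    # which a classifier head would map to poor/good quality.
    subject_embedding, bag_weights = attention_pool(bag_embeddings, w_bag)
    return subject_embedding, bag_weights


if __name__ == "__main__":
    rng = random.Random(1)
    feats = [[rng.uniform(-1, 1) for _ in range(8)] for _ in range(20)]
    emb, weights = two_tier_pool(feats, 4, [0.1] * 8, [0.2] * 8)
    print(len(emb), len(weights), round(sum(weights), 6))
```

The bag-level attention weights sum to one, so they can be read as how much each pseudo-bag contributes to the final prediction.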
HAMIL-QA/
├── main.py # Main training script
├── config.py # Configuration settings
├── model/
│ └── qc_model.py # Model architectures
├── scripts/
│ ├── qc_trainer.py # Training logic
│ └── qc_dataset.py # Dataset handling
├── util/
│ ├── data_transforms_config.py
│ ├── transformation_utils.py
│ └── early_stopping.py
└── model/saved_models/ # Saved model checkpoints
Your dataset should be organized as follows:
/path/to/dataset/
├── UTXYZ/
│ ├── data.nrrd # Medical image (LGE MRI)
│ └── shrinkwrap.nrrd # Left atrium segmentation mask
├── UTYZX/
│ ├── data.nrrd
│ └── shrinkwrap.nrrd
├── UTZXY/
│ ├── data.nrrd
│ └── shrinkwrap.nrrd
└── ...
Requirements:
- Each subject folder contains:
  - data.nrrd: The medical image file (NRRD format)
  - shrinkwrap.nrrd: The corresponding left atrium segmentation mask
- Folder names should match the subject IDs in the quality labels JSON
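Before training, you may want to check that the folders and the labels JSON agree. The following is a minimal standard-library sketch; `check_dataset` is a hypothetical helper, not part of this repository. It reports subject folders that are missing `data.nrrd` or `shrinkwrap.nrrd`, folders that have no label, and labels that have no folder.

```python
import json
from pathlib import Path

REQUIRED_FILES = ("data.nrrd", "shrinkwrap.nrrd")


def check_dataset(data_path, qc_label_dict):
    """Return problems found in the dataset layout; empty results mean it looks valid."""
    root = Path(data_path)
    with open(qc_label_dict) as f:
        labels = json.load(f)
    # Every sub-directory of the dataset root is treated as one subject.
    folders = {p.name for p in root.iterdir() if p.is_dir()}
    missing = {
        name: [fn for fn in REQUIRED_FILES if not (root / name / fn).is_file()]
        for name in sorted(folders)
    }
    return {
        "missing_files": {k: v for k, v in missing.items() if v},
        "no_label": sorted(folders - labels.keys()),   # folder without a JSON entry
        "no_folder": sorted(labels.keys() - folders),  # JSON entry without a folder
    }
```

For example, `check_dataset('/path/to/your/dataset', '/path/to/your/quality_labels.json')` returns three empty containers when every subject has both files and a matching label.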
Create a JSON file containing quality assessment labels for each subject:
{
"UTXYZ": {
"label": {
"quality_for_fibrosis_assessment": 5.0
}
},
"UTYZX": {
"label": {
"quality_for_fibrosis_assessment": 3.0
}
},
"UTZXY": {
"label": {
"quality_for_fibrosis_assessment": 1.0
}
}
}
Label Definition:
quality_for_fibrosis_assessment: Quality score ranging from 1 to 5
- 1-2: Poor quality (converted to class 0)
- 3-5: Good quality (converted to class 1)
The model automatically converts these scores to binary labels:
- Scores ≤ 2.0 → Class 0 (Poor)
- Scores > 2.0 → Class 1 (Good)
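The conversion rule above fits in a few lines. The sketch below is an illustration, and `binarize_labels` is a hypothetical name, not the repository's function. It applies the stated threshold (score ≤ 2.0 → 0, score > 2.0 → 1) to a dictionary loaded from the labels JSON.

```python
import json


def binarize_labels(raw, threshold=2.0):
    """Map subject ID -> 0 (poor, score <= threshold) or 1 (good, score > threshold)."""
    return {
        subject: int(entry["label"]["quality_for_fibrosis_assessment"] > threshold)
        for subject, entry in raw.items()
    }


if __name__ == "__main__":
    example = json.loads("""{
        "UTXYZ": {"label": {"quality_for_fibrosis_assessment": 5.0}},
        "UTYZX": {"label": {"quality_for_fibrosis_assessment": 3.0}},
        "UTZXY": {"label": {"quality_for_fibrosis_assessment": 1.0}}
    }""")
    print(binarize_labels(example))  # {'UTXYZ': 1, 'UTYZX': 1, 'UTZXY': 0}
```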
Update config.py with your dataset paths:
'data_path': '/path/to/your/dataset',
'qc_label_dict': '/path/to/your/quality_labels.json',
Install the dependencies:
pip install torch torchvision
pip install monai
pip install timm
pip install pynrrd
pip install Pillow
pip install torchmetrics
pip install scikit-learn
pip install comet-ml # Optional, for experiment tracking
pip install numpy pandas tqdm matplotlib
pip install python-dotenv  # For environment variables
- (Optional, for Comet.ml integration) Create a .env file in the project root:
# .env
COMET_API_KEY=your_comet_api_key_here
PROJECT_NAME=your_project_name
COMET_WORKSPACE=your_workspace_name
- Load the environment variables (already configured in main.py):
import os
from dotenv import load_dotenv
load_dotenv()  # Load variables from the .env file
Edit config.py to customize:
- Model Parameters:
  - n_patches: Number of patches to extract
  - no_of_pseudo_bags: Number of pseudo-bags
  - patch_size_2d: Size of each 2D patch
  - overlap_percentage: Overlap between neighboring patches
  - enlarge_xy: Enlargement of the bounding box (in x and y)
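To show how the patch parameters interact, here is a rough sketch of tiling a 2D region around the atrium. Several points are assumptions and may differ from the actual extraction code in this repository. `enlarge_xy` is treated as a margin in voxels added to each side of the mask bounding box. The overlap is treated as a fraction of the patch size (so `overlap_percentage = 50` would correspond to `0.5` here). Both helper names are hypothetical.

```python
def enlarge_bbox(lo, hi, enlarge, size):
    """Grow the interval [lo, hi) by `enlarge` voxels per side, clipped to [0, size)."""
    return max(0, lo - enlarge), min(size, hi + enlarge)


def patch_starts(lo, hi, patch_size, overlap_fraction):
    """Start indices of patches of length `patch_size` tiling [lo, hi).

    Consecutive patches overlap by roughly `overlap_fraction * patch_size`.
    If the interval is shorter than one patch, a single patch starts at `lo`.
    """
    stride = max(1, int(round(patch_size * (1.0 - overlap_fraction))))
    last = max(lo, hi - patch_size)
    starts = list(range(lo, last + 1, stride))
    if starts[-1] != last:
        starts.append(last)  # add a final patch so the far edge is covered
    return starts


if __name__ == "__main__":
    # Bounding box of a hypothetical mask on a 128 x 128 slice, enlarged by 10 voxels.
    y0, y1 = enlarge_bbox(40, 90, enlarge=10, size=128)   # (30, 100)
    x0, x1 = enlarge_bbox(50, 70, enlarge=10, size=128)   # (40, 80)
    grid = [(y, x)
            for y in patch_starts(y0, y1, 32, 0.5)        # [30, 46, 62, 68]
            for x in patch_starts(x0, x1, 32, 0.5)]       # [40, 48]
    print(len(grid))  # 8 candidate 32 x 32 patch corners
```

How `n_patches` candidates would be chosen from such a grid, for example by keeping patches that intersect the segmentation mask, is not shown here.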
Train the model:
python main.py
The training script always starts fresh and trains a new model for each fold.
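The per-fold loop can be sketched as a subject-level k-fold split. The number of folds and the splitting strategy used by `main.py` are not stated here, so `k=5`, the fixed seed, and the helper name `kfold_subjects` are hypothetical. Splitting by subject rather than by patch keeps all patches from one scan in the same fold.

```python
import random


def kfold_subjects(subject_ids, k=5, seed=42):
    """Yield (fold, train_ids, val_ids); validation sets are disjoint and cover all subjects."""
    ids = sorted(subject_ids)
    random.Random(seed).shuffle(ids)
    folds = [ids[i::k] for i in range(k)]
    for i in range(k):
        val = folds[i]
        train = [s for j, f in enumerate(folds) if j != i for s in f]
        yield i, train, val


if __name__ == "__main__":
    subjects = [f"UT{i:03d}" for i in range(23)]
    for fold, train_ids, val_ids in kfold_subjects(subjects, k=5):
        # A freshly initialised model would be built and trained here for each fold.
        print(fold, len(train_ids), len(val_ids))
```

Because scikit-learn is already in the dependency list, `sklearn.model_selection.StratifiedKFold` is an alternative if each fold should keep the poor/good class ratio.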
To choose the encoder initialization mode, set encoder_name in config.py:
- From scratch (random initialization):
# in config.py
'encoder_name': 'resnet'
You can also use:
'encoder_name': 'resnet_scratch'
Run:
python main.py
- Pretrained backbone (ImageNet-pretrained TIMM resnet10t, with most layers frozen):
# in config.py
'encoder_name': 'resnet_pretrained'
You can also use:
'encoder_name': 'timm_resnet_pretrained'
Run:
python main.py
Hyperparameters can also be overridden from the command line:
python main.py \
--learning_rate 1e-4 \
--batch_size 4 \
--epochs 500 \
--no_of_pseudo_bags 4 \
  --n_patches 20
To run without Comet.ml logging:
# In config.py
disable_comet = True
If you use this code, please cite:
@inproceedings{sultan2024hamil,
title={HAMIL-QA: Hierarchical Approach to Multiple Instance Learning for Atrial LGE MRI Quality Assessment},
author={Sultan, KM Arefeen and Hisham, Md Hasibul Husain and Orkild, Benjamin and Morris, Alan and Kholmovski, Eugene and Bieging, Erik and Kwan, Eugene and Ranjan, Ravi and DiBella, Ed and Elhabian, Shireen},
booktitle={International Conference on Medical Image Computing and Computer-Assisted Intervention},
pages={275--284},
year={2024},
organization={Springer}
}