PyTorch research repo for extending TransUNet on the Synapse multi-organ segmentation benchmark, with the current focus on MSCAF-TransUNet: Multi-Scale CNN Attention Fusion inside the hybrid R50-ViT encoder.
This cleaned version keeps only the research codepath, lightweight utilities, and reproducibility notebooks. AWS/SageMaker and CloudFormation deployment assets were intentionally removed so the repository is easier to read, reproduce, and push to GitHub.
This repo currently centers on MSCAF-TransUNet and its related ablations on top of the hybrid ResNet-50 + ViT-B/16 encoder:
- `pre_hidden`: refine selected CNN scales and fuse them into the hidden feature before patch projection
- `cnn_fusion`: refine selected CNN skip features and fuse multiple CNN scales back into the hidden feature
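The two modes differ in where the fused CNN information re-enters the transformer path. As a rough illustration of the `cnn_fusion` direction, here is a minimal sketch with made-up module names and toy shapes (not the code in `networks/vit_seg_modeling.py`):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class CNNFusionSketch(nn.Module):
    """Toy sketch of the cnn_fusion idea: refine CNN skip features at
    several scales, resample them to the transformer grid, and fuse them
    back into the hidden feature. Names and shapes are illustrative,
    not the repo's actual implementation."""

    def __init__(self, skip_channels, hidden_dim, grid_size):
        super().__init__()
        # one 1x1 "refinement" conv per selected CNN scale
        self.refine = nn.ModuleList(
            [nn.Conv2d(c, hidden_dim, kernel_size=1) for c in skip_channels]
        )
        # fuse the hidden feature plus all resampled scales back to hidden_dim
        self.fuse = nn.Conv2d(hidden_dim * (len(skip_channels) + 1), hidden_dim, 1)
        self.grid_size = grid_size

    def forward(self, hidden, skips):
        # hidden: (B, hidden_dim, grid, grid); skips: one feature map per scale
        pooled = [
            F.adaptive_avg_pool2d(conv(skip), self.grid_size)
            for conv, skip in zip(self.refine, skips)
        ]
        return self.fuse(torch.cat([hidden] + pooled, dim=1))

# Small shapes for illustration; the real hybrid encoder uses hidden_dim=768
# on a 14x14 grid with ResNet-50 skip channels.
model = CNNFusionSketch(skip_channels=[64, 128, 256], hidden_dim=32, grid_size=4)
hidden = torch.randn(2, 32, 4, 4)
skips = [torch.randn(2, 64, 32, 32), torch.randn(2, 128, 16, 16), torch.randn(2, 256, 8, 8)]
fused = model(hidden, skips)
print(fused.shape)  # torch.Size([2, 32, 4, 4])
```

The key design point is that the fused output keeps the hidden feature's shape, so it can replace the hidden feature in the transformer pipeline without touching downstream layers.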
Latest evaluated MSCAF-TransUNet run:
- Method: MSCAF-TransUNet
- Implementation mode: `cnn_fusion`
- Scales: `1/8,1/4,1/2`
- Mean Dice: 76.61%
- Mean HD95: 28.80
- Better than the original TransUNet paper on HD95, Liver, Pancreas, Spleen, and Stomach
Reference baseline from the earlier cleaned reproduction:
- Mean Dice: 77.29%
- Mean HD95: 30.71
| Framework | Encoder | Decoder | Average DSC ↑ | HD ↓ | Pancreas | Liver | Spleen | Stomach | Aorta | Gallbladder | Kidney (L) | Kidney (R) |
|---|---|---|---|---|---|---|---|---|---|---|---|---|
| MSCAF-TransUNet (Ours) | R50-ViT | CUP | 76.61 | **28.80** | **57.36** | **94.40** | **86.54** | **76.20** | 86.72 | 57.15 | 79.33 | 75.14 |
| TransUNet (paper) | R50-ViT | CUP | **77.48** | 31.69 | 55.86 | 94.08 | 85.08 | 75.62 | **87.23** | **63.13** | **81.87** | **77.02** |
Bold values indicate the better score between MSCAF-TransUNet and the original paper row. MSCAF-TransUNet improves HD95 (28.80 vs 31.69) and organ-wise Dice on Pancreas (57.36 vs 55.86), Liver (94.40 vs 94.08), Spleen (86.54 vs 85.08), and Stomach (76.20 vs 75.62), while still trailing the original paper on mean Dice and the remaining organs.
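For reference, the organ-wise Dice scores in such tables are computed per class over the predicted and ground-truth label volumes. A minimal numpy sketch (illustrative only, not the exact metric code in `test.py`, which also computes HD95):

```python
import numpy as np

def dice_per_organ(pred, gt, num_classes):
    """Per-class Dice over integer label maps; class 0 (background) is
    excluded. Returns NaN for organs absent from both prediction and
    ground truth."""
    scores = []
    for c in range(1, num_classes):
        p, g = pred == c, gt == c
        denom = p.sum() + g.sum()
        scores.append(2.0 * np.logical_and(p, g).sum() / denom if denom else np.nan)
    return scores

pred = np.array([[0, 1, 1], [2, 2, 0]])
gt   = np.array([[0, 1, 0], [2, 2, 2]])
print(dice_per_organ(pred, gt, num_classes=3))  # [0.666..., 0.8]
```

The mean Dice in the table is then the average of these per-organ scores across all test volumes.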
Relevant implementation files:
- networks/vit_seg_modeling.py
- networks/vit_seg_modeling_resnet_skip.py
- experiment_utils.py
- train.py
- test.py
- `datasets/`: dataset package and Synapse loader
- `splits/`: explicit train/test split metadata
- `networks/`: TransUNet model + hybrid encoder attention modules
- `notebooks/`: Colab notebooks for Drive bootstrap and end-to-end experiments
- `train.py`: training entrypoint
- `test.py`: evaluation entrypoint
- `trainer.py`: training loop with epoch-level resume checkpointing
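The epoch-level resume pattern used by `trainer.py` can be sketched as follows (function names and checkpoint contents here are illustrative, not the repo's exact implementation):

```python
import os
import tempfile

import torch
import torch.nn as nn

def save_epoch_checkpoint(path, model, optimizer, epoch):
    """Persist everything needed to resume: weights, optimizer state, epoch."""
    torch.save({"epoch": epoch,
                "model": model.state_dict(),
                "optimizer": optimizer.state_dict()}, path)

def resume_if_available(path, model, optimizer):
    """Return the epoch to start from: 0 for a fresh run, or last epoch + 1."""
    if not os.path.exists(path):
        return 0
    state = torch.load(path, map_location="cpu")
    model.load_state_dict(state["model"])
    optimizer.load_state_dict(state["optimizer"])
    return state["epoch"] + 1

# Demo with a tiny model and a temp checkpoint path.
model = nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
ckpt = os.path.join(tempfile.mkdtemp(), "latest_checkpoint.pth")

save_epoch_checkpoint(ckpt, model, optimizer, epoch=5)
start_epoch = resume_if_available(ckpt, model, optimizer)
print(start_epoch)  # 6
```

Saving after every epoch keeps at most one checkpoint file while still letting an interrupted Colab session pick up where it left off.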
Recommended:
- Python 3.10 to 3.12
- CUDA-enabled PyTorch
```
pip install -r requirements.txt
```
Main Python dependencies are tracked in requirements.txt. PyTorch and torchvision should match your CUDA runtime.
The repo expects preprocessed Synapse data in:
```
data/
  Synapse/
    train_npz/
    test_vol_h5/
```
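A minimal sketch of reading the `train_npz` slices, assuming the standard TransUNet convention that each `.npz` file stores 2-D `image` and `label` arrays (see `datasets/` for the actual loader):

```python
import os
import tempfile
from pathlib import Path

import numpy as np

def iter_train_slices(root):
    """Yield (image, label) 2-D arrays from preprocessed .npz slice files.
    Assumes each file stores arrays under the keys 'image' and 'label',
    the usual TransUNet Synapse preprocessing convention."""
    for case in sorted(Path(root).glob("*.npz")):
        with np.load(case) as data:
            yield data["image"], data["label"]

# Self-contained demo with one synthetic slice in a temp directory.
tmp = tempfile.mkdtemp()
np.savez(os.path.join(tmp, "case0001_slice000.npz"),
         image=np.zeros((224, 224), dtype=np.float32),
         label=np.zeros((224, 224), dtype=np.uint8))
image, label = next(iter_train_slices(tmp))
print(image.shape, label.dtype)  # (224, 224) uint8
```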
Recommended workflow:
- use notebooks/transunet-drive-data-setup.ipynb to cache the dataset to Google Drive for Colab
- or prepare the Synapse layout manually under `data/Synapse`
The hybrid encoder expects the R50-ViT-B/16 ImageNet-21k checkpoint under:
```
model/vit_checkpoint/imagenet21k/
  R50+ViT-B_16.npz
  R50-ViT-B_16.npz
```
Recommended workflow:
- use notebooks/transunet-drive-data-setup.ipynb to cache the pretrained weight to Google Drive for Colab
- or place the checkpoint manually under `model/vit_checkpoint/imagenet21k/`
Both filename aliases are supported because different codepaths and notebooks reference both forms.
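Resolving whichever alias is present can be done with a small helper like this (the function name is illustrative, not a utility in the repo):

```python
import tempfile
from pathlib import Path

CHECKPOINT_ALIASES = ("R50+ViT-B_16.npz", "R50-ViT-B_16.npz")

def find_pretrained_checkpoint(root):
    """Return the first checkpoint alias that exists under root, since
    different codepaths and notebooks reference either filename."""
    for name in CHECKPOINT_ALIASES:
        candidate = Path(root) / name
        if candidate.exists():
            return candidate
    raise FileNotFoundError(f"expected one of {CHECKPOINT_ALIASES} under {root}")

# Demo: create one alias in a temp dir and resolve it.
root = Path(tempfile.mkdtemp())
(root / "R50-ViT-B_16.npz").touch()
print(find_pretrained_checkpoint(root).name)  # R50-ViT-B_16.npz
```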
Example: run MSCAF-TransUNet (cnn_fusion on 1/8,1/4,1/2)
```
python train.py ^
--dataset Synapse ^
--vit_name R50-ViT-B_16 ^
--attention_mode cnn_fusion ^
--attention_scales 1/8,1/4,1/2
```

Alternative attention experiment:

```
python train.py ^
--dataset Synapse ^
--vit_name R50-ViT-B_16 ^
--attention_mode pre_hidden ^
--attention_scales 1/8
```

Baseline ablation:

```
python train.py --dataset Synapse --vit_name R50-ViT-B_16 --attention_mode none
```

Evaluate a trained run:

```
python test.py ^
--dataset Synapse ^
--vit_name R50-ViT-B_16 ^
--attention_mode cnn_fusion ^
--attention_scales 1/8,1/4,1/2
```

Save NIfTI predictions:

```
python test.py --dataset Synapse --vit_name R50-ViT-B_16 --is_savenii
```

For reproducibility on Google Colab:
- notebooks/transunet-drive-data-setup.ipynb: prepare the Synapse dataset and pretrained TransUNet weight on Google Drive
- notebooks/transunet-cnn-attention-research-colab.ipynb: run the MSCAF-TransUNet experiment end-to-end on Colab with live logs and checkpoint resume
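The `--attention_scales` values used in the commands above are comma-separated fractions. One way such a flag could be parsed (a sketch, not the repo's actual argparse handling):

```python
from fractions import Fraction

def parse_attention_scales(value):
    """Parse an --attention_scales string like '1/8,1/4,1/2' into fractions.
    Illustrative helper; the repo's actual argument handling may differ."""
    scales = [Fraction(token) for token in value.split(",") if token]
    if any(not (0 < s <= 1) for s in scales):
        raise ValueError(f"scales must lie in (0, 1]: {value!r}")
    return scales

print(parse_attention_scales("1/8,1/4,1/2"))
# [Fraction(1, 8), Fraction(1, 4), Fraction(1, 2)]
```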
- `trainer.py` saves `latest_checkpoint.pth` every epoch and can resume automatically.
- Package markers were added to `datasets/` and `networks/` so Colab does not confuse them with third-party packages.
- The repo intentionally no longer contains AWS deployment code, CloudFormation templates, or SageMaker helpers.
If you use this repo, cite the original TransUNet work and document the attention extension separately in your report or paper.
```bibtex
@article{chen2024transunet,
  title={TransUNet: Rethinking the U-Net architecture design for medical image segmentation through the lens of transformers},
  author={Chen, Jieneng and Mei, Jieru and Li, Xianhang and Lu, Yongyi and Yu, Qihang and Wei, Qingyue and Luo, Xiangde and Xie, Yutong and Adeli, Ehsan and Wang, Yan and others},
  journal={Medical Image Analysis},
  pages={103280},
  year={2024}
}
```

Apache License 2.0. See LICENSE.