
MSCAF-TransUNet for Synapse Segmentation

PyTorch research repo for extending TransUNet on the Synapse multi-organ segmentation benchmark, with the current focus on MSCAF-TransUNet: Multi-Scale CNN Attention Fusion inside the hybrid R50-ViT encoder.

This cleaned version keeps only the research codepath, lightweight utilities, and reproducibility notebooks. AWS/SageMaker and CloudFormation deployment assets were intentionally removed so the repository is easier to read, reproduce, and push to GitHub.

Research focus

This repo currently centers on MSCAF-TransUNet and its related ablations on top of the hybrid ResNet-50 + ViT-B/16 encoder:

  • pre_hidden: refine selected CNN scales and fuse them into the hidden feature before patch projection
  • cnn_fusion: refine selected CNN skip features and fuse multiple CNN scales back into the hidden feature
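To make the cnn_fusion idea concrete, here is a minimal sketch of refining CNN skip features with channel attention and fusing several scales back into the hidden feature. All module and argument names here are illustrative, not the repo's actual API; see networks/ for the real implementation.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class MultiScaleFusion(nn.Module):
    """Illustrative sketch of the cnn_fusion mode: refine CNN skip
    features at selected scales, resize them to the hidden feature's
    resolution, and fuse everything before patch projection."""

    def __init__(self, skip_channels, hidden_channels):
        super().__init__()
        # one lightweight channel-attention block per selected scale
        self.refine = nn.ModuleList(
            nn.Sequential(
                nn.AdaptiveAvgPool2d(1),
                nn.Conv2d(c, c, kernel_size=1),
                nn.Sigmoid(),
            )
            for c in skip_channels
        )
        # project the concatenated multi-scale stack back to hidden width
        self.project = nn.Conv2d(sum(skip_channels) + hidden_channels,
                                 hidden_channels, kernel_size=1)

    def forward(self, hidden, skips):
        h, w = hidden.shape[-2:]
        refined = []
        for feat, attn in zip(skips, self.refine):
            feat = feat * attn(feat)  # channel re-weighting
            refined.append(F.interpolate(feat, size=(h, w),
                                         mode="bilinear", align_corners=False))
        # fuse all refined scales with the hidden feature
        return self.project(torch.cat([hidden] + refined, dim=1))
```

With a 1/16-scale hidden feature of width 1024 and skips at 1/2, 1/4, and 1/8 scale, the output keeps the hidden feature's shape, so it can drop into the encoder unchanged.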

Current result snapshot

Latest evaluated MSCAF-TransUNet run:

  • Method: MSCAF-TransUNet
  • Implementation mode: cnn_fusion
  • Scales: 1/8,1/4,1/2
  • Mean Dice: 76.61%
  • Mean HD95: 28.80
  • Improves on the original TransUNet paper in HD95 and in organ-wise Dice for Pancreas, Liver, Spleen, and Stomach

Reference baseline from the earlier cleaned reproduction:

  • Mean Dice: 77.29%
  • Mean HD95: 30.71

Comparison with the original TransUNet paper

| Framework | Encoder | Decoder | Average DSC ↑ | HD ↓ | Pancreas | Liver | Spleen | Stomach | Aorta | Gallbladder | Kidney (L) | Kidney (R) |
|---|---|---|---|---|---|---|---|---|---|---|---|
| MSCAF-TransUNet (Ours) | R50-ViT | CUP | 76.61 | **28.80** | **57.36** | **94.40** | **86.54** | **76.20** | 86.72 | 57.15 | 79.33 | 75.14 |
| TransUNet (paper) | R50-ViT | CUP | **77.48** | 31.69 | 55.86 | 94.08 | 85.08 | 75.62 | **87.23** | **63.13** | **81.87** | **77.02** |

Bold values indicate the better score in each column. The current method stands out on the metrics it improves: HD95 is lower (28.80 vs 31.69), and organ-wise Dice is higher on Pancreas (57.36 vs 55.86), Liver (94.40 vs 94.08), Spleen (86.54 vs 85.08), and Stomach (76.20 vs 75.62). It still trails the original paper on mean Dice and on Aorta, Gallbladder, and both kidneys.
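As a quick sanity check, the reported mean DSC is the unweighted average of the eight organ-wise Dice scores in the table:

```python
# Sanity check: the reported 76.61 mean DSC equals the average of the
# eight organ-wise Dice scores in the table above (to rounding).
organ_dsc = {
    "Pancreas": 57.36, "Liver": 94.40, "Spleen": 86.54, "Stomach": 76.20,
    "Aorta": 86.72, "Gallbladder": 57.15, "Kidney (L)": 79.33, "Kidney (R)": 75.14,
}
mean_dsc = sum(organ_dsc.values()) / len(organ_dsc)
print(mean_dsc)  # ≈ 76.61
```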


Repository layout

datasets/          dataset package and Synapse loader
splits/            explicit train/test split metadata
networks/          TransUNet model + hybrid encoder attention modules
notebooks/         Colab notebooks for Drive bootstrap and end-to-end experiments
train.py           training entrypoint
test.py            evaluation entrypoint
trainer.py         training loop with epoch-level resume checkpointing

Environment

Recommended:

  • Python 3.10 to 3.12
  • CUDA-enabled PyTorch
  • pip install -r requirements.txt

Main Python dependencies are tracked in requirements.txt. PyTorch and torchvision should match your CUDA runtime.

Data

The repo expects preprocessed Synapse data in:

data/
  Synapse/
    train_npz/
    test_vol_h5/
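Each file under train_npz holds one preprocessed 2D slice. A minimal loader could look like the sketch below; the image/label keys follow the original TransUNet preprocessing convention and may differ if you preprocess the data yourself. The test_vol_h5 directory holds full 3D volumes in HDF5 for evaluation.

```python
import numpy as np

# Minimal sketch of reading one preprocessed training slice. The
# "image"/"label" keys are assumed from the original TransUNet
# preprocessing; adjust them if your preprocessing script differs.
def load_slice(npz_path):
    data = np.load(npz_path)
    image = data["image"].astype(np.float32)  # 2D CT slice
    label = data["label"].astype(np.int64)    # per-pixel organ IDs
    return image, label
```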


Pretrained weights

The hybrid encoder expects the R50-ViT-B/16 ImageNet-21k checkpoint under:

model/vit_checkpoint/imagenet21k/
  R50+ViT-B_16.npz
  R50-ViT-B_16.npz


Both filename aliases are supported because different codepaths and notebooks reference both forms.
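A small helper in the spirit of that alias handling might look like this (the function name is illustrative; the repo's own resolution logic lives in its model-loading code):

```python
from pathlib import Path

# Hypothetical helper mirroring the alias handling described above:
# accept either the '+' or '-' spelling of the checkpoint filename.
def find_r50_vit_checkpoint(root="model/vit_checkpoint/imagenet21k"):
    for name in ("R50+ViT-B_16.npz", "R50-ViT-B_16.npz"):
        candidate = Path(root) / name
        if candidate.exists():
            return candidate
    raise FileNotFoundError(
        f"Place R50+ViT-B_16.npz (or R50-ViT-B_16.npz) under {root}")
```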

Training

Example: run MSCAF-TransUNet (cnn_fusion on 1/8,1/4,1/2)

python train.py ^
  --dataset Synapse ^
  --vit_name R50-ViT-B_16 ^
  --attention_mode cnn_fusion ^
  --attention_scales 1/8,1/4,1/2

Alternative attention experiment:

python train.py ^
  --dataset Synapse ^
  --vit_name R50-ViT-B_16 ^
  --attention_mode pre_hidden ^
  --attention_scales 1/8

Baseline ablation:

python train.py --dataset Synapse --vit_name R50-ViT-B_16 --attention_mode none
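The --attention_scales flag takes a comma-separated list of fractional scales like 1/8,1/4,1/2. A sketch of how such a string could be parsed (the actual CLI parsing lives in train.py; this helper is illustrative):

```python
from fractions import Fraction

# Hypothetical sketch of parsing an '--attention_scales 1/8,1/4,1/2'
# argument into exact fractions; train.py owns the real parsing.
def parse_scales(spec):
    scales = [Fraction(part) for part in spec.split(",")]
    if any(not 0 < s <= 1 for s in scales):
        raise ValueError(f"scales must be in (0, 1]: {spec}")
    return scales

print(parse_scales("1/8,1/4,1/2"))
# [Fraction(1, 8), Fraction(1, 4), Fraction(1, 2)]
```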

Evaluation

python test.py ^
  --dataset Synapse ^
  --vit_name R50-ViT-B_16 ^
  --attention_mode cnn_fusion ^
  --attention_scales 1/8,1/4,1/2

Save NIfTI predictions:

python test.py --dataset Synapse --vit_name R50-ViT-B_16 --is_savenii

Colab notebooks

The notebooks in notebooks/ bootstrap Google Drive storage and run the experiments end to end, so results can be reproduced on Google Colab.

Notes

  • trainer.py saves latest_checkpoint.pth every epoch and can resume automatically.
  • Package markers were added to datasets/ and networks/ so Colab does not confuse them with third-party packages.
  • The repo intentionally no longer contains AWS deployment code, CloudFormation templates, or SageMaker helpers.
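The epoch-level resume pattern in trainer.py can be sketched as follows; the checkpoint field names here are illustrative, not the repo's exact schema:

```python
import os
import torch

# Sketch of saving latest_checkpoint.pth every epoch and resuming
# automatically (field names are illustrative, not trainer.py's schema).
def save_latest(path, epoch, model, optimizer):
    torch.save({
        "epoch": epoch,
        "model_state": model.state_dict(),
        "optimizer_state": optimizer.state_dict(),
    }, path)

def maybe_resume(path, model, optimizer):
    """Return the epoch to start from: 0 when no checkpoint exists."""
    if not os.path.exists(path):
        return 0
    ckpt = torch.load(path, map_location="cpu")
    model.load_state_dict(ckpt["model_state"])
    optimizer.load_state_dict(ckpt["optimizer_state"])
    return ckpt["epoch"] + 1
```

Saving both model and optimizer state means learning-rate schedules and momentum buffers survive a Colab disconnect, not just the weights.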

Citation

If you use this repo, cite the original TransUNet work and document the attention extension separately in your report or paper.

@article{chen2024transunet,
  title={TransUNet: Rethinking the U-Net architecture design for medical image segmentation through the lens of transformers},
  author={Chen, Jieneng and Mei, Jieru and Li, Xianhang and Lu, Yongyi and Yu, Qihang and Wei, Qingyue and Luo, Xiangde and Xie, Yutong and Adeli, Ehsan and Wang, Yan and others},
  journal={Medical Image Analysis},
  pages={103280},
  year={2024}
}

License

Apache License 2.0. See LICENSE.