ResViTM-Net is a novel hybrid deep learning model for automatic tuberculosis (TB) detection. It combines local features extracted by CNNs, global context captured by Vision Transformers (ViTs), and patients' clinical prior information. On three public datasets, the model achieves an overall accuracy of 96.0%, outperforming a range of baseline models while also training faster.
This repository contains the code implementation of ResViTM-Net.
- OS: Ubuntu 22.04 (WSL2)
- Recommended hardware: GPU with at least 16 GB of VRAM
- Python Version: 3.12.9
We recommend using Anaconda to create an isolated Python environment:
conda create -n ResViTM python=3.12.9
conda activate ResViTM
Install the main dependencies:
pip install -r requirements.txt
The three datasets used in the paper have been consolidated and can be downloaded from the project's Releases page.
Download the data and organize it in the data directory with the following structure:
data
│
├── ChinaSet_AllFiles
│   ├── ClinicalReadings
│   └── CXR_png
│
├── MontgomeryCXR
│   └── MontgomerySet
│       ├── ClinicalReadings
│       └── CXR_png
│
└── TB_Chest_Radiography_Database
    ├── Normal
    └── Tuberculosis
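Before running the preprocessing scripts, you can confirm that your download matches this layout. The following is a minimal sketch and is not part of the repository. `missing_data_dirs` and `EXPECTED_DIRS` are hypothetical names. The directory list is copied from the tree above, and the sketch assumes the data root is `./data`.

```python
from pathlib import Path

# Hypothetical helper (not in the repository): the expected sub-directories
# are copied from the data layout described in this README.
EXPECTED_DIRS = [
    "ChinaSet_AllFiles/ClinicalReadings",
    "ChinaSet_AllFiles/CXR_png",
    "MontgomeryCXR/MontgomerySet/ClinicalReadings",
    "MontgomeryCXR/MontgomerySet/CXR_png",
    "TB_Chest_Radiography_Database/Normal",
    "TB_Chest_Radiography_Database/Tuberculosis",
]


def missing_data_dirs(root="data"):
    """Return the expected sub-directories that do not exist under root."""
    root_path = Path(root)
    return [d for d in EXPECTED_DIRS if not (root_path / d).is_dir()]


if __name__ == "__main__":
    missing = missing_data_dirs("data")
    if missing:
        print("Missing directories:")
        for d in missing:
            print("  data/" + d)
    else:
        print("Data directory layout looks complete.")
```

Run it from the repository root, next to the data directory. An empty result only means that the folders exist. It does not check the image files inside them.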
The data preprocessing scripts are located in the data_process directory.
The following scripts train and apply a CNN-based inversion model:
python data_process/train_INV_CNN.py
python data_process/use_INV_CNN.py
The following scripts train a U-Net-based lung lobe segmentation model and rebuild the dataset from the segmentation results:
python data_process/train_unet.py
python data_process/Rebuild_Dataset.py
Run the following command to train the main model:
python ResViTM-Net/train_ResViTM.py
The training report is saved as a text file in the report directory. Model weights and Loss/Accuracy curves are saved in the model_output directory.
(Figure: example of the output Loss/Accuracy curves.)
To generate and save the ROC curve (as a PDF) for a trained ResViTM model, run:
python ResViTM-Net/ROC/ROC.py
(Figure: sample ROC curve output.)
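An ROC curve plots the true-positive rate against the false-positive rate as the decision threshold on the model's TB score is swept from high to low. The AUC is the area under that curve. The sketch below illustrates the computation in plain Python. It is not the code in ROC.py, and we have not checked which library that script uses. `roc_points` and `auc` are hypothetical names. scikit-learn's `sklearn.metrics.roc_curve` and `sklearn.metrics.auc` perform the same computation.

```python
def roc_points(labels, scores):
    """Return ROC (FPR, TPR) points by sweeping the threshold from the
    highest score to the lowest. labels: 1 = TB, 0 = normal."""
    n_pos = sum(labels)
    n_neg = len(labels) - n_pos
    if n_pos == 0 or n_neg == 0:
        raise ValueError("ROC needs at least one positive and one negative")
    pairs = sorted(zip(scores, labels), key=lambda p: p[0], reverse=True)
    tp = fp = 0
    points = [(0.0, 0.0)]
    i = 0
    while i < len(pairs):
        threshold = pairs[i][0]
        # Consume every sample tied at this score so ties form a single step.
        while i < len(pairs) and pairs[i][0] == threshold:
            if pairs[i][1] == 1:
                tp += 1
            else:
                fp += 1
            i += 1
        points.append((fp / n_neg, tp / n_pos))
    return points


def auc(points):
    """Area under the ROC curve using the trapezoidal rule."""
    area = 0.0
    for (x0, y0), (x1, y1) in zip(points, points[1:]):
        area += (x1 - x0) * (y0 + y1) / 2
    return area


labels = [0, 0, 1, 1]
scores = [0.1, 0.4, 0.35, 0.8]
pts = roc_points(labels, scores)
print(pts)       # [(0.0, 0.0), (0.0, 0.5), (0.5, 0.5), (0.5, 1.0), (1.0, 1.0)]
print(auc(pts))  # 0.75
```

Grouping tied scores into one step keeps the curve independent of the order in which tied samples happen to be sorted.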
To run the data augmentation strategy experiments and save the trained models and reports:
python ResViTM-Net/Augmentation_analysis/Augmentation_analysis.py
Training scripts for ResViTM models with different loss functions are located in the ResViTM-Net/Adjust_Loss directory. For example:
python ResViTM-Net/Adjust_Loss/BCELoss.py
To tune the beta hyperparameter of the SmoothL1Loss function used in ResViTM, use the MoreBeta script:
python ResViTM-Net/MoreBeta/MoreBeta.py -b 0.5
The beta value must lie in the range [0.0, 1.0] and be given to one decimal place (e.g., 0.0, 0.5, 1.0).
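The sketch below shows what beta controls and how the value constraint could be enforced. It assumes the loss follows the definition of PyTorch's `torch.nn.SmoothL1Loss`: quadratic (0.5·d²/beta) when |d| < beta and linear (|d| − 0.5·beta) otherwise. Under that definition, beta = 0 reduces to plain L1 loss. `beta_type` and `smooth_l1` are hypothetical names. The argparse setup is an illustration and is not the actual parser in MoreBeta.py. Only the `-b` flag comes from this README.

```python
import argparse
import re


def beta_type(text):
    """argparse type (hypothetical): accept 0.0, 0.1, ..., 1.0 only."""
    if not re.fullmatch(r"0\.[0-9]|1\.0", text):
        raise argparse.ArgumentTypeError(
            f"beta must be in [0.0, 1.0] with one decimal place, got {text!r}")
    return float(text)


def smooth_l1(pred, target, beta):
    """Mean Smooth L1 loss, following the torch.nn.SmoothL1Loss formula:
    0.5 * d**2 / beta if |d| < beta, else |d| - 0.5 * beta.
    With beta == 0 the quadratic branch is never taken (plain L1 loss)."""
    total = 0.0
    for p, t in zip(pred, target):
        d = abs(p - t)
        if d < beta:
            total += 0.5 * d * d / beta
        else:
            total += d - 0.5 * beta
    return total / len(pred)


parser = argparse.ArgumentParser()
parser.add_argument("-b", dest="beta", type=beta_type, required=True)
args = parser.parse_args(["-b", "0.5"])
print(args.beta)  # 0.5
print(smooth_l1([0.2, 0.9], [0.0, 0.0], args.beta))
```

Matching the string with a regular expression, instead of rounding the parsed float, rejects inputs such as `0.55` or `1` outright rather than silently coercing them.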
Code for the ablation studies and Grad-CAM heatmap visualization is located in the ResViTM-Net/Ablation directory. For example:
python ResViTM-Net/Ablation/woRes.py
python ResViTM-Net/Ablation/woRes_GradCAM.py
Scripts for comparing different models are located in the Model_Compare directory. For example:
python Model_Compare/Resnet18.py
Scripts for t-SNE visualization are located in the T-SNE directory. For example:
python T-SNE/Resnet18.py
The t-SNE visualization results are saved in the t-SNE_output directory.

