This is an official implementation of "Learning the degradation distribution for medical image super-resolution via sparse Swin Transformer".
Clear medical images are important for aiding diagnosis, but images produced by various medical devices inevitably contain considerable noise. Although many denoising models have been proposed, they ignore the fact that different types of medical images have different noise levels, which leads to unsatisfactory results at test time. In addition, collecting large numbers of medical images for training denoising models is expensive. To address these issues, we formulate a progressive denoising architecture that consists of preliminary and profound denoising. First, we construct a noise level estimation network that estimates the noise level via self-supervised learning and performs preliminary denoising with a dilated blind-spot network. Second, using the learned noise distribution, we synthesize noisy natural images to construct clean-noisy natural image pairs. Finally, we design a novel medical image denoising model for profound denoising by training on these pairs. The proposed three-stage learning scheme and progressive denoising architecture not only solve the problem of a denoising model adapting to only a single noise level but also alleviate the lack of medical image pairs. Moreover, we integrate dense attention and sparse attention into a retractable transformer module in the profound denoising model, which widens the receptive field and enhances the representation ability of the transformer. This allows the denoising model to apply retractable attention to the input features and to capture local and global receptive fields simultaneously. Qualitative and quantitative experiments demonstrate that our method outperforms other denoising methods.
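For intuition about how dense and sparse attention can be mixed in one module, here is a minimal PyTorch sketch. It assumes the module operates on flattened token sequences, and all names (`Attention`, `RetractableBlock`, `top_k`) are illustrative assumptions rather than this repository's actual API; see the model code for the real implementation.

```python
import torch
import torch.nn as nn

class Attention(nn.Module):
    """Multi-head self-attention. Dense when top_k is None; otherwise each
    query keeps only its top_k highest attention scores (sparse attention)."""
    def __init__(self, dim, num_heads=4, top_k=None):
        super().__init__()
        self.num_heads = num_heads
        self.top_k = top_k
        self.qkv = nn.Linear(dim, dim * 3)
        self.proj = nn.Linear(dim, dim)

    def forward(self, x):                                  # x: (B, N, C)
        B, N, C = x.shape
        d = C // self.num_heads
        q, k, v = self.qkv(x).chunk(3, dim=-1)
        # reshape each to (B, heads, N, head_dim)
        q, k, v = (t.view(B, N, self.num_heads, d).transpose(1, 2)
                   for t in (q, k, v))
        attn = (q @ k.transpose(-2, -1)) / d ** 0.5        # (B, heads, N, N)
        if self.top_k is not None:
            # sparse branch: mask out everything below the k-th largest score
            kth = attn.topk(min(self.top_k, N), dim=-1).values[..., -1:]
            attn = attn.masked_fill(attn < kth, float("-inf"))
        out = attn.softmax(dim=-1) @ v                     # (B, heads, N, d)
        return self.proj(out.transpose(1, 2).reshape(B, N, C))

class RetractableBlock(nn.Module):
    """Sums a dense branch (full pairwise attention for local detail) and a
    sparse branch (cheap long-range context), then applies an MLP."""
    def __init__(self, dim, num_heads=4, top_k=16):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.dense = Attention(dim, num_heads, top_k=None)
        self.sparse = Attention(dim, num_heads, top_k=top_k)
        self.norm2 = nn.LayerNorm(dim)
        self.mlp = nn.Sequential(nn.Linear(dim, 2 * dim), nn.GELU(),
                                 nn.Linear(2 * dim, dim))

    def forward(self, x):                                  # x: (B, N, C)
        y = self.norm1(x)
        x = x + self.dense(y) + self.sparse(y)
        return x + self.mlp(self.norm2(x))

# toy usage: a batch of 2 sequences, 64 tokens of dimension 32
tokens = torch.randn(2, 64, 32)
print(RetractableBlock(dim=32)(tokens).shape)              # torch.Size([2, 64, 32])
```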
## Dependencies
- Python 3.8
- PyTorch >= 1.8
- cudatoolkit 10.1
## Train
1. Download the DIV2K training data (800 training + 100 validation images) from the DIV2K dataset.
2. Train the ×2 model:
```bash
python main.py --dir_data ./.. --n_GPUs 0 --rgb_range 1 --chunk_size 128 --n_hashes 4 --lr 1e-4 --decay 200-400-600-800 --epochs 1000 --n_resblocks 32 --n_feats 256 --res_scale 0.1 --batch_size 16 --model SSFormer --scale 2 --patch_size 96 --save ./ --data_train DIV2K/cell
```
3. Train the ×4 model:
```bash
python main.py --dir_data ./.. --n_GPUs 0 --rgb_range 1 --chunk_size 128 --n_hashes 4 --lr 1e-4 --decay 200-400-600-800 --epochs 1000 --n_resblocks 32 --n_feats 256 --res_scale 0.1 --batch_size 16 --model SSFormer --scale 4 --patch_size 96 --save ./ --data_train DIV2K/cell
```
## Test
1. Download the benchmark datasets from https://cv.snu.ac.kr/research/EDSR/benchmark.tar
2. Run the test with a pretrained ×2 model:
```bash
python main.py --dir_data ../ --model SSFormer --chunk_size 144 --data_test Set5 --n_hashes 4 --rgb_range 1 --data_range 801-900 --scale 2 --n_feats 256 --n_resblocks 32 --res_scale 0.1 --pre_train ./model_x2.pt
```
## Acknowledgements
This code is built on [HarukiYqM/Non-Local-Sparse-Attention](https://github.com/HarukiYqM/Non-Local-Sparse-Attention) and [JingyunLiang/SwinIR](https://github.com/JingyunLiang/SwinIR).