-
Notifications
You must be signed in to change notification settings - Fork 6
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
11 changed files
with
1,128 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,46 @@ | ||
# Swin-Unet | ||
The code for the paper "Swin-Unet: Unet-like Pure Transformer for Medical Image Segmentation" (https://arxiv.org/abs/2105.05537). A validation for U-shaped Swin Transformer. | ||
|
||
## 1. Download pre-trained swin transformer model (Swin-T) | ||
* [Get pre-trained model in this link](https://drive.google.com/drive/folders/1UC3XOoezeum0uck4KBVGa8osahs6rKUY?usp=sharing): Put pretrained Swin-T into folder "pretrained_ckpt/" | ||
|
||
## 2. Prepare data | ||
|
||
- The datasets we used are provided by TransUnet's authors. Please go to ["./datasets/README.md"](datasets/README.md) for details, or please send an Email to jienengchen01 AT gmail.com to request the preprocessed data. If you would like to use the preprocessed data, please use it for research purposes and do not redistribute it (following the TransUnet's License). | ||
|
||
## 3. Environment | ||
|
||
- Please prepare an environment with python=3.7, and then use the command "pip install -r requirements.txt" for the dependencies. | ||
|
||
## 4. Train/Test | ||
|
||
- Run the train script on synapse dataset. The batch size we used is 24. If you do not have enough GPU memory, the batch size can be reduced to 12 or 6 to save memory. | ||
|
||
- Train | ||
|
||
```bash | ||
sh train.sh or python train.py --dataset Synapse --cfg configs/swin_tiny_patch4_window7_224_lite.yaml --root_path your DATA_DIR --max_epochs 150 --output_dir your OUT_DIR --img_size 224 --base_lr 0.05 --batch_size 24 | ||
``` | ||
|
||
- Test | ||
|
||
```bash | ||
sh test.sh or python test.py --dataset Synapse --cfg configs/swin_tiny_patch4_window7_224_lite.yaml --is_saveni --volume_path your DATA_DIR --output_dir your OUT_DIR --max_epoch 150 --base_lr 0.05 --img_size 224 --batch_size 24 | ||
``` | ||
|
||
## References | ||
* [TransUnet](https://github.com/Beckschen/TransUNet) | ||
* [SwinTransformer](https://github.com/microsoft/Swin-Transformer) | ||
|
||
## Citation | ||
|
||
```bibtex | ||
@misc{cao2021swinunet, | ||
title={Swin-Unet: Unet-like Pure Transformer for Medical Image Segmentation}, | ||
author={Hu Cao and Yueyue Wang and Joy Chen and Dongsheng Jiang and Xiaopeng Zhang and Qi Tian and Manning Wang}, | ||
year={2021}, | ||
eprint={2105.05537}, | ||
archivePrefix={arXiv}, | ||
primaryClass={eess.IV} | ||
} | ||
``` |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,230 @@ | ||
# -------------------------------------------------------- | ||
# Swin Transformer | ||
# Copyright (c) 2021 Microsoft | ||
# Licensed under The MIT License [see LICENSE for details] | ||
# Written by Ze Liu | ||
# --------------------------------------------------------' | ||
|
||
import os | ||
import yaml | ||
from yacs.config import CfgNode as CN | ||
|
||
# Root node holding every default value; get_config() hands out clones of it.
_C = CN()

# Base config files. Paths listed under BASE in a YAML config are merged
# before the file itself, resolved relative to that file's directory
# (see _update_config_from_file).
_C.BASE = ['']
|
||
# -----------------------------------------------------------------------------
# Data settings
# -----------------------------------------------------------------------------
_C.DATA = CN()
# Batch size for a single GPU, could be overwritten by command line argument
_C.DATA.BATCH_SIZE = 2
# Path to dataset. update_config() has no dedicated flag for it, so it can
# only be overridden through a YAML file or the generic `opts` list.
_C.DATA.DATA_PATH = ''
# Dataset name
_C.DATA.DATASET = 'imagenet'
# Input image size
_C.DATA.IMG_SIZE = 448
# Interpolation to resize image (random, bilinear, bicubic)
_C.DATA.INTERPOLATION = 'bicubic'
# Use zipped dataset instead of folder dataset
# could be overwritten by command line argument
_C.DATA.ZIP_MODE = False
# Cache Data in Memory, could be overwritten by command line argument
_C.DATA.CACHE_MODE = 'part'
# Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.
_C.DATA.PIN_MEMORY = True
# Number of data loading threads
_C.DATA.NUM_WORKERS = 8
|
||
# -----------------------------------------------------------------------------
# Model settings
# -----------------------------------------------------------------------------
_C.MODEL = CN()
# Model type
_C.MODEL.TYPE = 'swin'
# Model name
_C.MODEL.NAME = 'swin_tiny_patch4_window7_224'
# Path to the pre-trained Swin-T weights (the README asks for them to be
# placed under "pretrained_ckpt/")
_C.MODEL.PRETRAIN_CKPT = './pretrained_ckpt/swin_tiny_patch4_window7_224.pth'
# Checkpoint to resume, could be overwritten by command line argument
_C.MODEL.RESUME = ''
# Number of classes, overwritten in data preparation
_C.MODEL.NUM_CLASSES = 1000
# Dropout rate
_C.MODEL.DROP_RATE = 0.0
# Drop path rate
_C.MODEL.DROP_PATH_RATE = 0.1
# Label Smoothing
_C.MODEL.LABEL_SMOOTHING = 0.1

# Swin Transformer parameters
_C.MODEL.SWIN = CN()
_C.MODEL.SWIN.PATCH_SIZE = 4
_C.MODEL.SWIN.IN_CHANS = 3
_C.MODEL.SWIN.EMBED_DIM = 96
_C.MODEL.SWIN.DEPTHS = [2, 2, 6, 2]
# NOTE(review): presumably the per-stage depths of the decoder path, mirroring
# DEPTHS -- confirm against the model code.
_C.MODEL.SWIN.DECODER_DEPTHS = [2, 2, 6, 2]
_C.MODEL.SWIN.NUM_HEADS = [3, 6, 12, 24]
_C.MODEL.SWIN.WINDOW_SIZE = 7
_C.MODEL.SWIN.MLP_RATIO = 4.
_C.MODEL.SWIN.QKV_BIAS = True
_C.MODEL.SWIN.QK_SCALE = None
_C.MODEL.SWIN.APE = False
_C.MODEL.SWIN.PATCH_NORM = True
# NOTE(review): accepted values are defined by the model code, not here.
_C.MODEL.SWIN.FINAL_UPSAMPLE = "expand_first"
|
||
# -----------------------------------------------------------------------------
# Training settings
# -----------------------------------------------------------------------------
_C.TRAIN = CN()
_C.TRAIN.START_EPOCH = 0
_C.TRAIN.EPOCHS = 300
_C.TRAIN.WARMUP_EPOCHS = 20
_C.TRAIN.WEIGHT_DECAY = 0.05
_C.TRAIN.BASE_LR = 5e-4
_C.TRAIN.WARMUP_LR = 5e-7
_C.TRAIN.MIN_LR = 5e-6
# Clip gradient norm
_C.TRAIN.CLIP_GRAD = 5.0
# Auto resume from latest checkpoint
_C.TRAIN.AUTO_RESUME = True
# Gradient accumulation steps
# could be overwritten by command line argument
_C.TRAIN.ACCUMULATION_STEPS = 0
# Whether to use gradient checkpointing to save memory
# could be overwritten by command line argument
_C.TRAIN.USE_CHECKPOINT = False

# LR scheduler
_C.TRAIN.LR_SCHEDULER = CN()
_C.TRAIN.LR_SCHEDULER.NAME = 'cosine'
# Epoch interval to decay LR, used in StepLRScheduler
_C.TRAIN.LR_SCHEDULER.DECAY_EPOCHS = 30
# LR decay rate, used in StepLRScheduler
_C.TRAIN.LR_SCHEDULER.DECAY_RATE = 0.1

# Optimizer
_C.TRAIN.OPTIMIZER = CN()
_C.TRAIN.OPTIMIZER.NAME = 'adamw'
# Optimizer Epsilon
_C.TRAIN.OPTIMIZER.EPS = 1e-8
# Optimizer Betas
_C.TRAIN.OPTIMIZER.BETAS = (0.9, 0.999)
# SGD momentum
_C.TRAIN.OPTIMIZER.MOMENTUM = 0.9
|
||
# -----------------------------------------------------------------------------
# Augmentation settings
# -----------------------------------------------------------------------------
_C.AUG = CN()
# Color jitter factor
_C.AUG.COLOR_JITTER = 0.4
# Use AutoAugment policy. "v0" or "original"
# NOTE(review): the default below is neither of the two values named above;
# confirm which policy strings the consumer of this option accepts.
_C.AUG.AUTO_AUGMENT = 'rand-m9-mstd0.5-inc1'
# Random erase prob
_C.AUG.REPROB = 0.25
# Random erase mode
_C.AUG.REMODE = 'pixel'
# Random erase count
_C.AUG.RECOUNT = 1
# Mixup alpha, mixup enabled if > 0
_C.AUG.MIXUP = 0.8
# Cutmix alpha, cutmix enabled if > 0
_C.AUG.CUTMIX = 1.0
# Cutmix min/max ratio, overrides alpha and enables cutmix if set
_C.AUG.CUTMIX_MINMAX = None
# Probability of performing mixup or cutmix when either/both is enabled
_C.AUG.MIXUP_PROB = 1.0
# Probability of switching to cutmix when both mixup and cutmix enabled
_C.AUG.MIXUP_SWITCH_PROB = 0.5
# How to apply mixup/cutmix params. Per "batch", "pair", or "elem"
_C.AUG.MIXUP_MODE = 'batch'
|
||
# -----------------------------------------------------------------------------
# Testing settings
# -----------------------------------------------------------------------------
_C.TEST = CN()
# Whether to use center crop when testing
_C.TEST.CROP = True
|
||
# -----------------------------------------------------------------------------
# Misc
# -----------------------------------------------------------------------------
# Mixed precision opt level, if O0, no amp is used ('O0', 'O1', 'O2')
# overwritten by command line argument
_C.AMP_OPT_LEVEL = ''
# Path to output folder, overwritten by command line argument
_C.OUTPUT = ''
# Tag of experiment, overwritten by command line argument
_C.TAG = 'default'
# Frequency to save checkpoint
_C.SAVE_FREQ = 1
# Frequency to logging info
_C.PRINT_FREQ = 10
# Fixed random seed
_C.SEED = 0
# Perform evaluation only, overwritten by command line argument
_C.EVAL_MODE = False
# Test throughput only, overwritten by command line argument
_C.THROUGHPUT_MODE = False
# local rank for DistributedDataParallel, given by command line argument
_C.LOCAL_RANK = 0
|
||
|
||
def _update_config_from_file(config, cfg_file):
    """Merge the YAML file ``cfg_file`` and the files it inherits from into ``config``.

    Files listed under the YAML key ``BASE`` are merged first (recursively,
    with paths resolved relative to the directory of the file naming them),
    so values from ``cfg_file`` itself are applied last and take precedence.

    Args:
        config: yacs ``CfgNode`` updated in place. It is frozen on return.
        cfg_file: Path of the YAML config file to merge.
    """
    with open(cfg_file, 'r') as f:
        # NOTE(review): FullLoader accepts more than plain config data needs;
        # consider yaml.safe_load if config files can come from untrusted
        # sources.
        yaml_cfg = yaml.load(f, Loader=yaml.FullLoader)

    # An empty YAML document loads as None; treat it as "no settings" instead
    # of failing on the attribute lookup.
    base_files = (yaml_cfg or {}).get('BASE', [''])
    for base in base_files:
        if base:
            _update_config_from_file(
                config, os.path.join(os.path.dirname(cfg_file), base)
            )

    print('=> merge config from {}'.format(cfg_file))
    # Every recursive call above leaves the config frozen, so defrost right
    # before merging this file rather than once at the top.
    config.defrost()
    config.merge_from_file(cfg_file)
    config.freeze()
|
||
|
||
def update_config(config, args):
    """Apply the config file and command line overrides from ``args`` to ``config``.

    Order of application (later wins): the YAML file ``args.cfg`` together
    with its ``BASE`` files, then the key/value list ``args.opts``, then the
    dedicated command line options listed below. An option is applied only
    when its parsed value is truthy, so unset options and zero/empty values
    leave the config untouched.

    Args:
        config: yacs ``CfgNode`` updated in place. It is frozen on return.
        args: Parsed command line namespace. Must provide ``cfg``, ``opts``
            and every attribute named in the override table below.
    """
    _update_config_from_file(config, args.cfg)

    config.defrost()
    if args.opts:
        config.merge_from_list(args.opts)

    # merge from specific arguments:
    # (args attribute, config section or None for the root node, config key,
    #  True when the option is an on/off switch that stores True).
    overrides = (
        ('batch_size', 'DATA', 'BATCH_SIZE', False),
        ('zip', 'DATA', 'ZIP_MODE', True),
        ('cache_mode', 'DATA', 'CACHE_MODE', False),
        ('resume', 'MODEL', 'RESUME', False),
        ('accumulation_steps', 'TRAIN', 'ACCUMULATION_STEPS', False),
        ('use_checkpoint', 'TRAIN', 'USE_CHECKPOINT', True),
        ('amp_opt_level', None, 'AMP_OPT_LEVEL', False),
        ('tag', None, 'TAG', False),
        ('eval', None, 'EVAL_MODE', True),
        ('throughput', None, 'THROUGHPUT_MODE', True),
    )
    for arg_name, section, key, is_flag in overrides:
        value = getattr(args, arg_name)
        if not value:
            continue
        node = config if section is None else getattr(config, section)
        setattr(node, key, True if is_flag else value)

    config.freeze()
|
||
|
||
def get_config(args):
    """Get a yacs CfgNode object with default values, updated from ``args``.

    Args:
        args: Parsed command line namespace, forwarded to ``update_config``.

    Returns:
        A frozen clone of the module defaults with the config file and
        command line overrides applied.
    """
    # Return a clone so that the defaults will not be altered
    # This is for the "local variable" use pattern
    config = _C.clone()
    update_config(config, args)

    return config
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,107 @@ | ||
from __future__ import print_function, division | ||
import os | ||
from numpy.core.fromnumeric import transpose | ||
from skimage import io,transform ,filters | ||
import matplotlib.pyplot as plt | ||
import numpy as np | ||
import glob | ||
import skimage | ||
import torch | ||
from torch.utils.data import Dataset,DataLoader | ||
import random | ||
from scipy import ndimage | ||
from scipy.ndimage.interpolation import zoom | ||
from torchvision import transforms | ||
|
||
|
||
def random_rot_flip(image, label):
    """Apply one random 90-degree rotation and one random flip to both arrays.

    The same rotation count and flip axis are used for ``image`` and
    ``label`` so the two stay aligned.

    Args:
        image: Array with at least two dimensions; the first two are the
            ones rotated and flipped.
        label: Array with at least two dimensions, transformed identically.

    Returns:
        Tuple ``(image, label)`` of freshly copied arrays.
    """
    # Number of quarter turns (0-3); np.rot90 rotates in the plane of the
    # first two axes by default.
    k = np.random.randint(0, 4)
    image = np.rot90(image, k)
    label = np.rot90(label, k)
    # Flip along axis 0 or axis 1; copy so the result owns its data instead
    # of being a view of the input.
    axis = np.random.randint(0, 2)
    image = np.flip(image, axis=axis).copy()
    label = np.flip(label, axis=axis).copy()
    return image, label
|
||
|
||
def random_rotate(image, label):
    """Rotate both arrays by the same random angle, keeping their shape.

    The angle is an integer number of degrees drawn from ``[-20, 20)``
    (the upper bound of ``np.random.randint`` is exclusive). Both arrays
    are resampled with nearest-neighbour interpolation (``order=0``), so no
    new values are introduced into ``label``.

    Args:
        image: Array to rotate.
        label: Array rotated by the same angle.

    Returns:
        Tuple ``(image, label)`` with the same shapes as the inputs
        (``reshape=False``).
    """
    angle = np.random.randint(-20, 20)
    image = ndimage.rotate(image, angle, order=0, reshape=False)
    label = ndimage.rotate(label, angle, order=0, reshape=False)
    return image, label
|
||
|
||
class RandomGenerator(object):
    """Randomly augment an image/mask sample and convert it to tensors.

    Calling an instance on ``{'img': HxWxC array, 'mask': HxW array}``:

    1. With probability 0.5 applies ``random_rot_flip``; otherwise, with
       probability 0.5 of the remainder, applies ``random_rotate``;
       otherwise leaves the sample unchanged.
    2. Resizes both arrays to ``output_size`` when the spatial size differs
       (cubic spline for the image, nearest neighbour for the mask).
    3. Returns ``{'img': float32 tensor of shape (C, H, W) divided by 255,
       'mask': int64 tensor with 1 where the mask value exceeds 127, else 0}``.

    Args:
        output_size: ``(height, width)`` of the produced sample.
    """

    def __init__(self, output_size):
        self.output_size = output_size

    def __call__(self, sample):
        image, label = sample['img'], sample['mask']

        if random.random() > 0.5:
            image, label = random_rot_flip(image, label)
        elif random.random() > 0.5:
            image, label = random_rotate(image, label)

        x, y = image.shape[:2]
        if x != self.output_size[0] or y != self.output_size[1]:
            scale = (self.output_size[0] / x, self.output_size[1] / y)
            # zoom() needs one factor per axis; any axis after the two
            # spatial ones (e.g. channels) keeps its size.
            image = zoom(image, scale + (1,) * (image.ndim - 2), order=3)
            label = zoom(label, scale + (1,) * (label.ndim - 2), order=0)

        # HxWxC -> CxHxW
        image = np.transpose(image, (2, 0, 1))
        # Binarise into a new array: the caller's mask is not modified and a
        # value of exactly 127 maps to background instead of surviving as
        # class id 127.
        label = label > 127
        image = torch.from_numpy(image.astype(np.float32)) / 255.0
        label = torch.from_numpy(label.astype(np.int64))
        return {'img': image, 'mask': label}
|
||
|
||
class CrackSegDataset(Dataset):
    """Image/mask pairs read from ``<base_dir>/crack_segmentation_dataset``.

    Expected layout: ``<base_dir>/crack_segmentation_dataset/<partition>/images``
    and ``.../<partition>/masks``. The i-th image is paired with the i-th mask
    after sorting both directory listings by file name.
    NOTE(review): this assumes matching files sort to the same position
    (e.g. identical file names in both folders) -- confirm against the
    dataset layout.

    Args:
        partition: Sub-directory of the dataset to read (default ``"train"``).
        transform: Optional callable applied to each
            ``{"img": ..., "mask": ...}`` sample before it is returned.
        base_dir: Directory containing ``crack_segmentation_dataset``.
            Defaults to ``<current working directory>/datasets``.
    """

    def __init__(self, partition="train", transform=None, base_dir=None):
        self.transform = transform  # using transform in torch!
        if base_dir is None:
            base_dir = os.path.join(str(os.getcwd()), "datasets")
        self.base_dir = base_dir
        self.dataset_dir = os.path.join(self.base_dir, "crack_segmentation_dataset")
        partition_dir = os.path.join(self.dataset_dir, partition)
        self.mask_dir = os.path.join(partition_dir, "masks")
        self.img_dir = os.path.join(partition_dir, "images")
        # os.listdir() returns entries in arbitrary order; sort both listings
        # so index-based pairing is deterministic.
        self.imgs = sorted(os.listdir(self.img_dir))
        self.masks = sorted(os.listdir(self.mask_dir))

    def __len__(self):
        return len(self.imgs)

    def __getitem__(self, idx):
        img = io.imread(os.path.join(self.img_dir, self.imgs[idx]))
        mask = io.imread(os.path.join(self.mask_dir, self.masks[idx]))
        sample = {"img": img, "mask": mask}
        if self.transform:
            sample = self.transform(sample)

        return sample
|
||
|
||
|
||
|
||
if __name__ == "__main__":
    # Smoke test: build the training dataset with augmentation, pull a single
    # batch through a DataLoader and print the tensor shapes.
    db_train = CrackSegDataset(
        partition="train",
        transform=transforms.Compose([RandomGenerator(output_size=[448, 448])]),
    )

    trainloader = DataLoader(
        db_train, batch_size=4, shuffle=True, num_workers=8, pin_memory=True
    )

    for sample in trainloader:
        print(sample["img"].shape)
        print(sample["mask"].shape)
        break
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
Oops, something went wrong.