[Paper: Disruption Prediction and Analysis Through Multimodal Deep Learning in KSTAR]
- First, we take image sequence data as input with shape (B, T, C, W, H) and assume that the last frame, in which the plasma image in the tokamak disappears, corresponds to the disruption.
- The second-to-last frame of the image sequence can then be regarded as the current quench.
- Thus, the frame sequences of each experiment that include the second-to-last frame are labeled as disruptive.
- Under this labeling scheme, neural networks trained on the dataset can predict a disruption prior to the current quench (a labeling sketch follows below).
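The sketch below is a minimal, hypothetical illustration of this labeling rule; the repository's actual preprocessing lives in the scripts under ./src/, and `label_windows`, `seq_len`, and `pred_len` here only mirror the training flags:

```python
import numpy as np

def label_windows(n_frames: int, seq_len: int, pred_len: int) -> np.ndarray:
    """Hypothetical labeling sketch. Frame n_frames - 1 is where the plasma
    image disappears (disruption), so frame n_frames - 2 is treated as the
    current quench. A window ending at frame t is labeled disruptive if the
    current quench occurs within pred_len frames after t."""
    quench_idx = n_frames - 2                     # second-to-last frame
    labels = []
    for end in range(seq_len - 1, n_frames):      # last frame index of each window
        labels.append(int(end < quench_idx <= end + pred_len))
    return np.array(labels)

# e.g. a 100-frame shot, 21-frame windows, predicting 5 frames ahead:
# label_windows(100, 21, 5) marks the windows ending at frames 93..97 as disruptive.
```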
Analysis of the models using visualization of hidden vectors
The code was developed with Python 3.9 on Ubuntu 18.04.
GPUs used: 4× NVIDIA GeForce RTX 3090 (24 GB).
Computing resources for training were provided by PLARE at Seoul National University.
Environment
```
conda env create -f environment.yaml
conda activate research-env
```
Video Dataset Generation : old version (inefficient memory usage and scalability)
```
# generate disruptive and normal video data from .avi files
python3 ./src/generate_video_data_fixed.py --fps 210 --duration 21 --distance 5 --save_path './dataset/'

# train-test split, converting videos into image sequences
python3 ./src/preprocessing.py --test_ratio 0.2 --valid_ratio 0.2 \
    --video_data_path './dataset/dur21_dis0' --save_path './dataset/dur21_dis0'
```
Video Dataset Generation : new version, more efficient than the old version
```
# build an additional KSTAR shot log with frame information for the video data
python3 ./src/generate_modified_shot_log.py

# generate the video dataset from the extended KSTAR shot log : no need to split the train-test set for every distance
python3 ./src/generate_video_data.py --fps 210 --raw_video_path "./dataset/raw_videos/raw_videos/" \
    --df_shot_list_path "./dataset/KSTAR_Disruption_Shot_List_extend.csv" --save_path "./dataset/temp" \
    --width 256 --height 256 --overwrite True
```
0D Dataset Generation (Numerical dataset)
```
# interpolate KSTAR data and convert it to a tabular dataframe
python3 ./src/generate_numerical_data.py
```
- Test the code before model training : check for invalid data and issues in the model architecture (an illustrative test is sketched after the commands below)
```
# test everything : data + model
pytest test

# test the data validity
pytest test/test_data.py

# test the model validity
pytest test/test_model.py
```
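For instance, a data-validity test in this style might assert that a sampled batch has the expected (B, T, C, W, H) layout and contains only finite values; the shapes and the random stand-in batch below are illustrative assumptions, not the repository's actual fixtures:

```python
# illustrative sketch of a test in the test/test_data.py style
import torch

def test_batch_is_well_formed():
    # stand-in for a batch drawn from the real video DataLoader
    batch = torch.randn(8, 21, 3, 256, 256)   # (B, T, C, W, H)
    assert batch.ndim == 5                     # image-sequence layout
    assert torch.isfinite(batch).all()         # no NaN / inf values
```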
Models for video data
```
python3 train_vision_nework.py --batch_size {batch size} --gpu_num {gpu num} \
    --use_LDAM {bool : use LDAM loss} --model_type {model name} \
    --tag {name of experiment / info} --use_DRW {bool : use deferred re-weighting} \
    --use_RS {bool : use re-sampling} --seq_len {int : input sequence length} \
    --pred_len {int : prediction time} --image_size {int}
```
Models for 0D data
```
python3 train_0D_nework.py --batch_size {batch size} --gpu_num {gpu num} \
    --use_LDAM {bool : use LDAM loss} --model_type {model name} \
    --tag {name of experiment / info} --use_DRW {bool : use deferred re-weighting} \
    --use_RS {bool : use re-sampling} --seq_len {int : input sequence length} \
    --pred_len {int : prediction time}
```
Models for multimodal data (video + 0D data)
```
python3 train_multi_modal.py --batch_size {batch size} --gpu_num {gpu num} \
    --use_LDAM {bool : use LDAM loss} --use_GB {bool : use Gradient Blending} \
    --tag {name of experiment / info} --use_DRW {bool : use deferred re-weighting} \
    --use_RS {bool : use re-sampling} --seq_len {int : input sequence length} \
    --pred_len {int : prediction time} --tau {int : stride for input sequence}
```
Experiments for each network (vision, 0D, multimodal) with different prediction times
```
# R2Plus1D
sh exp/exp_r1plus1d.sh

# SlowFast
sh exp/exp_slowfast.sh

# ViViT
sh exp/exp_vivit.sh

# Transformer
sh exp/exp_0D_transformer.sh

# CnnLSTM
sh exp/exp_0D_cnnlstm.sh

# MLSTM-FCN
sh exp/exp_0D_mlstm.sh

# Multimodal model
sh exp/exp_multi.sh

# Multimodal model with Gradient Blending
sh exp/exp_multi_gb.sh
```
Experiments with different learning algorithms and models
```
# case : R2Plus1D
sh exp/exp_la_r2plus1d.sh

# case : SlowFast
sh exp/exp_la_slowfast.sh

# case : ViViT
sh exp/exp_la_vivit.sh
```
Model performance visualization for continuous disruption prediction, rendered as a GIF
```
python3 make_continuous_prediction.py
```
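Conceptually, the continuous prediction sweeps a trained model over an entire shot and records the disruption probability at every step; the sketch below is an assumed, simplified version of that loop (a generic `model` and frame tensor, not the script's actual interface):

```python
import torch

@torch.no_grad()
def continuous_prediction(model, frames: torch.Tensor, seq_len: int) -> torch.Tensor:
    """Slide a seq_len window over a full shot (frames: (T, C, W, H)) and
    collect the disruptive-class probability at each step."""
    model.eval()
    probs = []
    for end in range(seq_len, frames.size(0) + 1):
        window = frames[end - seq_len:end].unsqueeze(0)      # (1, seq_len, C, W, H)
        logits = model(window)                               # assumed (1, 2) output
        probs.append(torch.softmax(logits, dim=-1)[0, 1])    # P(disruptive)
    return torch.stack(probs)   # plotted over time and rendered as a GIF
```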
Video encoder
- R2Plus1D : paper(https://arxiv.org/abs/1711.11248)
- SlowFast : paper(https://arxiv.org/abs/1812.03982)
- ViViT : paper(https://arxiv.org/abs/2103.15691)
0D data encoder
- Transformer : paper(https://proceedings.neurips.cc/paper/2017/file/3f5ee243547dee91fbd053c1c4a845aa-Paper.pdf), application code(https://www.kaggle.com/general/200913)
- Conv1D-LSTM using self-attention : https://pseudo-lab.github.io/Tutorial-Book/chapters/time-series/Ch5-CNN-LSTM.html (a minimal sketch follows this list)
- MLSTM-FCN : paper(https://arxiv.org/abs/1801.04503), application code(https://github.com/titu1994/MLSTM-FCN)
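As a rough illustration of the Conv1D-LSTM-with-attention idea referenced above, a minimal PyTorch sketch might look as follows; the layer sizes and the single-layer additive attention are assumptions, not the repository's configuration:

```python
import torch
import torch.nn as nn

class Conv1DLSTMEncoder(nn.Module):
    """Sketch of a Conv1D-LSTM encoder with additive self-attention pooling.

    Input: 0D signals as (B, T, F); output: a fixed-size vector per sample.
    """
    def __init__(self, n_features: int, hidden: int = 64):
        super().__init__()
        # Conv1d expects (B, F, T): extract local temporal patterns first
        self.conv = nn.Sequential(
            nn.Conv1d(n_features, hidden, kernel_size=3, padding=1),
            nn.ReLU(),
        )
        self.lstm = nn.LSTM(hidden, hidden, batch_first=True)
        self.attn = nn.Linear(hidden, 1)   # scores each time step

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        h = self.conv(x.transpose(1, 2)).transpose(1, 2)  # (B, T, hidden)
        out, _ = self.lstm(h)                             # (B, T, hidden)
        w = torch.softmax(self.attn(out), dim=1)          # (B, T, 1) weights
        return (w * out).sum(dim=1)                       # attention-pooled (B, hidden)

# usage: Conv1DLSTMEncoder(n_features=12)(torch.randn(4, 21, 12))  # -> (4, 64)
```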
Multimodal Model
- Multimodal fusion model: video encoder + 0D data encoder
- Tensor Fusion Network (a fusion sketch follows this list)
- Other methods (Future work)
- Multimodal deep representation learning for video classification : https://link.springer.com/content/pdf/10.1007/s11280-018-0548-3.pdf?pdf=button
- Truly Multi-modal YouTube-8M Video Classification with Video, Audio, and Text : https://static.googleusercontent.com/media/research.google.com/ko//youtube8m/workshop2017/c06.pdf
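A minimal sketch of the Tensor Fusion Network idea, assuming two already-computed modality embeddings; the embedding sizes and the classifier head on top of the fused tensor are omitted:

```python
import torch

def tensor_fusion(z_video: torch.Tensor, z_0d: torch.Tensor) -> torch.Tensor:
    """Tensor Fusion Network-style fusion of two modality embeddings.

    Appending a constant 1 to each embedding before the outer product keeps
    the unimodal features as sub-tensors of the fused representation,
    alongside all bimodal interaction terms.
    """
    B = z_video.size(0)
    ones = torch.ones(B, 1, device=z_video.device)
    zv = torch.cat([z_video, ones], dim=1)               # (B, Dv + 1)
    zn = torch.cat([z_0d, ones], dim=1)                  # (B, Dn + 1)
    fused = torch.bmm(zv.unsqueeze(2), zn.unsqueeze(1))  # (B, Dv+1, Dn+1)
    return fused.flatten(1)                              # (B, (Dv+1) * (Dn+1))
```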
Solving the imbalanced classification issue
- Re-Sampling : ImbalancedWeightedSampler, over-sampling for minority classes
- Re-Weighting : use inverse class frequencies as weights in the loss function (CE, Focal Loss, LDAM Loss)
- LDAM with DRW : label-distribution-aware margin loss with a deferred re-weighting schedule (a loss sketch follows this list)
- Multimodal learning : Gradient Blending to avoid the sub-optimal solutions that arise when one modality dominates joint training
- Multimodal learning : CCA learning for further enhancement
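A compact sketch of the LDAM loss with a DRW-style weight switch, following Cao et al. (2019); the hyperparameters (`max_m`, `s`) and the two-class count list in the comment are illustrative:

```python
import torch
import torch.nn.functional as F

class LDAMLoss(torch.nn.Module):
    """Label-Distribution-Aware Margin loss sketch (Cao et al., 2019).

    Class margins scale with n_j^{-1/4}, so the rarer class (here,
    'disruptive') gets a larger margin. For DRW, keep weight=None during
    warm-up epochs, then switch to inverse-frequency class weights.
    """
    def __init__(self, cls_num_list, max_m: float = 0.5, s: float = 30.0, weight=None):
        super().__init__()
        m = 1.0 / torch.tensor(cls_num_list, dtype=torch.float).pow(0.25)
        self.m = m * (max_m / m.max())     # per-class margins, max equals max_m
        self.s = s                         # logit scale
        self.weight = weight               # DRW: None early, class weights later

    def forward(self, logits: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        margins = self.m.to(logits.device)[target]                 # (B,)
        adjusted = logits.clone()
        adjusted[torch.arange(logits.size(0)), target] -= margins  # enforce margin
        return F.cross_entropy(self.s * adjusted, target, weight=self.weight)

# usage with e.g. 9000 normal / 400 disruptive samples:
# criterion = LDAMLoss(cls_num_list=[9000, 400])
```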
Analysis of the physical characteristics of disruptive video data
- CAM : in progress
- Grad-CAM : paper(https://arxiv.org/abs/1610.02391), target models(R2Plus1D, SlowFast)
- Attention rollout : paper(https://arxiv.org/abs/2005.00928), target model(ViViT); a rollout sketch follows this list
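Attention rollout itself is compact enough to sketch: average the heads, add the identity to account for the residual connection, re-normalize rows, and multiply the per-layer attention maps from input to output. The (B, heads, N, N) layout is an assumption about how the attentions are collected:

```python
import torch

def attention_rollout(attentions):
    """Attention rollout (Abnar & Zuidema, 2020) for a ViViT-style model.

    attentions: list of per-layer attention tensors, each (B, heads, N, N),
    ordered from the first layer to the last.
    """
    rollout = None
    for attn in attentions:
        a = attn.mean(dim=1)                             # (B, N, N) head average
        a = a + torch.eye(a.size(-1), device=a.device)   # residual connection
        a = a / a.sum(dim=-1, keepdim=True)              # row-normalize
        rollout = a if rollout is None else a @ rollout  # compose input -> output
    return rollout   # (B, N, N): attribution of output tokens to input tokens
```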
Data augmentation
- Video mixup algorithm for data augmentation (done, not effective; a sketch follows this list)
- Conventional image augmentation (flip, brightness, contrast, blur, shift)
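The video mixup variant is the standard mixup recipe applied to whole clips; the sketch below assumes one-hot labels and a single Beta-distributed mixing coefficient per batch:

```python
import torch

def video_mixup(x: torch.Tensor, y: torch.Tensor, alpha: float = 0.2):
    """Mixup over whole clips: x is (B, T, C, W, H), y is one-hot (B, K).

    A single lambda ~ Beta(alpha, alpha) mixes each clip with a randomly
    permuted partner; labels are mixed with the same coefficient.
    """
    lam = torch.distributions.Beta(alpha, alpha).sample().item()
    perm = torch.randperm(x.size(0), device=x.device)
    x_mix = lam * x + (1.0 - lam) * x[perm]
    y_mix = lam * y + (1.0 - lam) * y[perm]
    return x_mix, y_mix
```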
Training process enhancement
- Multigrid training algorithm : fast training for SlowFast
- Deep CCA : deep canonical correlation analysis to train multi-modal representations
Generalization and Robustness
- Add noise to the image sequences and 0D data for robustness (a noise-injection sketch follows this list)
- Multimodality can also improve robustness to noise in the data
- Gradient Blending to avoid sub-optimal states in multi-modal learning
- Multi-GPU distributed learning : done
- Database construction : tabular dataset (IKSTAR) + video dataset, done
- ML pipeline : TensorBoard, done
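The noise-injection idea reduces to adding small perturbations to both modalities during training; the sigma values below are illustrative, not tuned settings:

```python
import torch

def add_gaussian_noise(x: torch.Tensor, sigma: float = 0.01) -> torch.Tensor:
    """Additive Gaussian noise for robustness; apply during training only."""
    return x + sigma * torch.randn_like(x)

# e.g. perturb both modalities with modality-specific scales (illustrative):
# video_batch  = add_gaussian_noise(video_batch, sigma=0.01)    # normalized frames
# zero_d_batch = add_gaussian_noise(zero_d_batch, sigma=0.05)   # standardized 0D signals
```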
- Disruption : disruptive state at t = tipminf (current-quench)
- Borderline : intermediate region between the normal and disruptive states (not used)
- Normal : non-disruptive state
If you use this repository in your research, please cite the following:
Paper : Kim, Jinsu, et al. "Disruption prediction and analysis through multimodal deep learning in KSTAR." Fusion Engineering and Design 200 (2024): 114204.

Software : Jinsu Kim (2024). Disruption-Prediciton-based-on-Multimodal-Deep-Learning. GitHub. https://github.com/ZINZINBIN/Disruption-Prediciton-based-on-Multimodal-Deep-Learning
```bibtex
@software{Kim_Deep_Multimodal_Learning_2024,
  author  = {Kim, Jinsu},
  doi     = {10.1016/j.fusengdes.2024.114204},
  license = {MIT},
  month   = feb,
  title   = {{Deep Multimodal Learning based KSTAR Disruption Prediction Model}},
  url     = {https://github.com/ZINZINBIN/Disruption-Prediciton-based-on-Multimodal-Deep-Learning},
  version = {1.0.0},
  year    = {2024}
}
```
