Released training code for U-Net and SyncNet
chunyu-li committed Dec 19, 2024
1 parent b142c50 commit 992ed80
Showing 6 changed files with 130 additions and 8 deletions.
28 changes: 24 additions & 4 deletions README.md
@@ -69,7 +69,7 @@ We present LatentSync, an end-to-end lip sync framework based on audio condition

- [x] Inference code and checkpoints
- [x] Data processing pipeline
- [ ] Training code
- [x] Training code

## 🔧 Setting up the Environment

@@ -114,11 +114,11 @@ The complete data processing pipeline includes the following steps:

1. Remove the broken video files.
2. Resample the video FPS to 25, and resample the audio to 16000 Hz.
3. Scene detect.
3. Scene detect via [PySceneDetect](https://github.com/Breakthrough/PySceneDetect).
4. Split each video into 5-10 second segments.
5. Remove videos where the face is smaller than 256 $\times$ 256, as well as videos with more than one face.
6. Affine transform the faces according to landmarks, then resize to 256 $\times$ 256.
7. Remove videos with sync conf lower than 3, and adjust the audio-visual offset to 0.
6. Affine transform the faces according to the landmarks detected by [face-alignment](https://github.com/1adrianb/face-alignment), then resize to 256 $\times$ 256.
7. Remove videos with [sync confidence score](https://www.robots.ox.ac.uk/~vgg/publications/2016/Chung16a/chung16a.pdf) lower than 3, and adjust the audio-visual offset to 0.
8. Calculate [hyperIQA](https://openaccess.thecvf.com/content_CVPR_2020/papers/Su_Blindly_Assess_Image_Quality_in_the_Wild_Guided_by_a_CVPR_2020_paper.pdf) score, and remove videos with scores lower than 40.

Run the script to execute the data processing pipeline:
@@ -128,3 +128,23 @@ Run the script to execute the data processing pipeline:
```

You can change the parameter `input_dir` in the script to specify the data directory to be processed. The processed data will be saved in the same directory. Each step writes its output to a new directory, so you do not need to redo the entire pipeline if the process is interrupted by an unexpected error.
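
For illustration, below is a minimal sketch of step 2 (25 fps video, 16 kHz audio) using ffmpeg from Python; it is not the pipeline's actual implementation, and the paths are placeholders:

```python
# Sketch of step 2 only; the repository's actual pipeline scripts may use
# different tools or flags, and the paths below are placeholders.
import subprocess
from pathlib import Path

def resample(src: Path, dst: Path, fps: int = 25, sample_rate: int = 16000) -> None:
    """Re-encode `src` so the video stream runs at `fps` and the audio at `sample_rate` Hz."""
    cmd = [
        "ffmpeg", "-y", "-loglevel", "error",
        "-i", str(src),
        "-r", str(fps),              # target video frame rate
        "-ar", str(sample_rate),     # target audio sample rate
        "-ac", "1",                  # mono audio (an extra assumption, not stated in the README)
        str(dst),
    ]
    subprocess.run(cmd, check=True)

if __name__ == "__main__":
    resample(Path("raw/clip.mp4"), Path("resampled/clip.mp4"))
```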

## 🏋️‍♂️ Training U-Net

Before training, you must process the data as described above and download all the checkpoints. We release a pretrained SyncNet with 94% accuracy on the VoxCeleb2 dataset to supervise U-Net training. Note that this SyncNet was trained on affine-transformed videos, so when using or evaluating it, you need to apply the affine transformation to the video first (the affine transformation code is included in the data processing pipeline).
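
As a rough illustration of how a frozen SyncNet can supervise the U-Net (the embeddings here are hypothetical inputs, and the repository's actual sync loss may be formulated differently):

```python
# Hypothetical sketch of using a frozen SyncNet to supervise U-Net outputs.
# The tensors are placeholders produced elsewhere; the repository's actual
# sync loss may be formulated differently.
import torch
import torch.nn.functional as F

def sync_supervision_loss(audio_emb: torch.Tensor, visual_emb: torch.Tensor) -> torch.Tensor:
    """audio_emb, visual_emb: (batch, dim) frozen-SyncNet embeddings of the audio
    window and the generated (affine-transformed) face frames."""
    cos = F.cosine_similarity(audio_emb, visual_emb, dim=-1)
    return (1.0 - cos).mean()  # 0 when the frozen SyncNet judges the pair perfectly in sync
```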

If all the preparations are complete, you can train the U-Net with the following script:

```bash
./train_unet.sh
```

You should change the parameters in the U-Net config file to specify the data directory, checkpoint save path, and other training hyperparameters.
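
Before launching, you may want to sanity-check the paths. A sketch using plain PyYAML (the training code itself may load the config differently, e.g. with OmegaConf):

```python
# Sketch: sanity-check the config added in this commit before launching training.
# Plain PyYAML is used purely for illustration; the training code itself may
# load the config differently (e.g. with OmegaConf).
import yaml

with open("configs/unet/first_stage.yaml") as f:
    cfg = yaml.safe_load(f)

print("train fileslist :", cfg["data"]["train_fileslist"])
print("audio cache dir :", cfg["data"]["audio_cache_dir"])
print("resume ckpt     :", cfg["ckpt"]["resume_ckpt_path"])
print("save ckpt every :", cfg["ckpt"]["save_ckpt_steps"], "steps")
```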

## 🏋️‍♂️ Training SyncNet

If you want to train SyncNet on your own dataset, you can run the following script:

```bash
./train_syncnet.sh
```
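
The contents of `train_syncnet.sh` are not part of this commit's diff; as a rough, hypothetical sketch of a SyncNet-style objective (following the general idea of the sync confidence paper linked in the data processing section, not necessarily the exact loss used by the SyncNet training code here):

```python
# Hypothetical sketch of a SyncNet-style training objective: embeddings of
# temporally aligned audio/visual windows get high cosine similarity, shifted
# (off-sync) pairs get low similarity. The objective actually used by the
# SyncNet training code in this repository may differ.
import torch
import torch.nn.functional as F

def syncnet_loss(audio_emb: torch.Tensor, visual_emb: torch.Tensor, in_sync: torch.Tensor) -> torch.Tensor:
    """audio_emb, visual_emb: (batch, dim); in_sync: (batch,) with 1.0 for aligned pairs, 0.0 for shifted ones."""
    cos = F.cosine_similarity(audio_emb, visual_emb, dim=-1)
    prob = ((cos + 1) / 2).clamp(1e-6, 1 - 1e-6)  # map cosine in [-1, 1] to a pseudo-probability
    return F.binary_cross_entropy(prob, in_sync)
```
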
4 changes: 2 additions & 2 deletions configs/syncnet/syncnet_16_pixel.yaml
@@ -15,7 +15,7 @@ model:
ckpt:
resume_ckpt_path: ""
inference_ckpt_path: checkpoints/latentsync_syncnet.pt
save_ckpt_steps: 20
save_ckpt_steps: 2500

data:
train_output_dir: debug/syncnet
@@ -40,6 +40,6 @@

run:
max_train_steps: 10000000
validation_steps: 20
validation_steps: 2500
mixed_precision_training: true
seed: 42
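
Both configs enable `mixed_precision_training` and set `save_ckpt_steps` / `validation_steps`. A minimal sketch of how such flags are commonly wired into a PyTorch training loop (illustrative only; it assumes the YAML is loaded into a nested dict `cfg` and that the model returns its own loss):

```python
# Illustrative only: how flags like these are commonly wired into a PyTorch
# loop, assuming the YAML has been loaded into a nested dict `cfg`. The actual
# SyncNet training code may structure this differently.
import torch

def train(model, optimizer, loader, cfg, device="cuda"):
    use_amp = cfg["run"]["mixed_precision_training"]
    scaler = torch.cuda.amp.GradScaler(enabled=use_amp)
    step = 0
    for batch in loader:
        with torch.cuda.amp.autocast(enabled=use_amp):
            loss = model(batch.to(device))        # assume the model returns its loss
        optimizer.zero_grad(set_to_none=True)
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        step += 1
        if step % cfg["ckpt"]["save_ckpt_steps"] == 0:
            torch.save(model.state_dict(), f"checkpoint_step{step}.pt")
        # validation would typically run every cfg["run"]["validation_steps"] steps
        if step >= cfg["run"]["max_train_steps"]:
            break
```
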
102 changes: 102 additions & 0 deletions configs/unet/first_stage.yaml
@@ -0,0 +1,102 @@
data:
syncnet_config_path: configs/syncnet/syncnet_16_pixel.yaml
train_output_dir: debug/unet
train_fileslist: /mnt/bn/maliva-gen-ai-v2/chunyu.li/fileslist/all_data_v6.txt
train_data_dir: ""
audio_cache_dir: /mnt/bn/maliva-gen-ai-v2/chunyu.li/audio_cache/whisper_new

val_video_path: assets/demo1_video.mp4
val_audio_path: assets/demo1_audio.wav
batch_size: 8 # 8
num_workers: 11 # 11
num_frames: 16
resolution: 256
mask: fix_mask
audio_sample_rate: 16000
video_fps: 25

ckpt:
resume_ckpt_path: checkpoints/latentsync_unet.pt
save_ckpt_steps: 5000

run:
pixel_space_supervise: false
use_syncnet: true
sync_loss_weight: 0.05 # 1/283
perceptual_loss_weight: 0.1 # 0.1
recon_loss_weight: 1 # 1
guidance_scale: 1.0 # 1.5 or 1.0
trepa_loss_weight: 10
inference_steps: 20
seed: 1247
use_mixed_noise: true
mixed_noise_alpha: 1 # 1
mixed_precision_training: true
enable_gradient_checkpointing: false
enable_xformers_memory_efficient_attention: true
max_train_steps: 10000000
max_train_epochs: -1

optimizer:
lr: 1e-5
scale_lr: false
max_grad_norm: 1.0
lr_scheduler: constant
lr_warmup_steps: 0

model:
act_fn: silu
add_audio_layer: true
custom_audio_layer: false
audio_condition_method: cross_attn # Choose between [cross_attn, group_norm]
attention_head_dim: 8
block_out_channels: [320, 640, 1280, 1280]
center_input_sample: false
cross_attention_dim: 384
down_block_types:
[
"CrossAttnDownBlock3D",
"CrossAttnDownBlock3D",
"CrossAttnDownBlock3D",
"DownBlock3D",
]
mid_block_type: UNetMidBlock3DCrossAttn
up_block_types:
[
"UpBlock3D",
"CrossAttnUpBlock3D",
"CrossAttnUpBlock3D",
"CrossAttnUpBlock3D",
]
downsample_padding: 1
flip_sin_to_cos: true
freq_shift: 0
in_channels: 13 # 49
layers_per_block: 2
mid_block_scale_factor: 1
norm_eps: 1e-5
norm_num_groups: 32
out_channels: 4 # 16
sample_size: 64
resnet_time_scale_shift: default # Choose between [default, scale_shift]
unet_use_cross_frame_attention: false
unet_use_temporal_attention: false

# Actually, we do not use the motion module in the final version of LatentSync.
# When we started the project, we built on the AnimateDiff codebase and tried the motion module,
# but the results were poor, and we decided to leave the code here for possible future use.
use_motion_module: false
motion_module_resolutions: [1, 2, 4, 8]
motion_module_mid_block: false
motion_module_decoder_only: false
motion_module_type: Vanilla
motion_module_kwargs:
num_attention_heads: 8
num_transformer_block: 1
attention_block_types:
- Temporal_Self
- Temporal_Self
temporal_position_encoding: true
temporal_position_encoding_max_len: 16
temporal_attention_dim_div: 1
zero_initialize: true
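
The `run` section above defines several loss weights. A hypothetical sketch of how such weights could be combined into a single training loss (the real gating in `scripts/train_unet.py` may differ, e.g. depending on `pixel_space_supervise`):

```python
# Hypothetical combination of the loss weights from the `run` section above.
# The real gating in scripts/train_unet.py may differ (for example depending on
# `pixel_space_supervise`); each argument is a scalar loss computed elsewhere.
import torch

def total_loss(run_cfg: dict, recon: torch.Tensor, perceptual: torch.Tensor,
               sync: torch.Tensor, trepa: torch.Tensor) -> torch.Tensor:
    loss = run_cfg["recon_loss_weight"] * recon
    loss = loss + run_cfg["perceptual_loss_weight"] * perceptual
    loss = loss + run_cfg["trepa_loss_weight"] * trepa
    if run_cfg["use_syncnet"]:
        loss = loss + run_cfg["sync_loss_weight"] * sync
    return loss
```
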
File renamed without changes.
2 changes: 1 addition & 1 deletion inference.sh
@@ -1,7 +1,7 @@
#!/bin/bash

python -m scripts.inference \
--unet_config_path "configs/unet/unet_latent_16_diffusion.yaml" \
--unet_config_path "configs/unet/second_stage.yaml" \
--inference_ckpt_path "checkpoints/latentsync_unet.pt" \
--video_path "assets/demo1_video.mp4" \
--audio_path "assets/demo1_audio.wav" \
2 changes: 1 addition & 1 deletion train_unet.sh
@@ -1,4 +1,4 @@
#!/bin/bash

torchrun --nnodes=1 --nproc_per_node=1 --master_port=25678 -m scripts.train_unet \
--unet_config_path "configs/unet/unet_latent_16_diffusion.yaml"
--unet_config_path "configs/unet/first_stage.yaml"
