
[Feature] Support LLaVA #196

Merged · 218 commits · Dec 26, 2023

Changes shown below are from 1 commit.

Commits (218):
6de3469
v1
LZHgrla Nov 1, 2023
70971d9
add load_image
LZHgrla Nov 1, 2023
9405944
update cfg image url
LZHgrla Nov 1, 2023
c946bd5
del fig
LZHgrla Nov 1, 2023
5b76a56
Merge branch 'main' into lzh/llava
LZHgrla Nov 1, 2023
551bb74
update
LZHgrla Nov 1, 2023
70fa7d9
temp
LZHgrla Nov 1, 2023
5013d3d
update convert
LZHgrla Nov 2, 2023
39f2fb3
update chat_mm
LZHgrla Nov 2, 2023
a65a1ae
add exclude_frozen_parameters for deepspeed
LZHgrla Nov 2, 2023
5dd244a
update chat
LZHgrla Nov 2, 2023
669f282
update xtuner help msg
LZHgrla Nov 2, 2023
b0f9ad0
fix bugs
LZHgrla Nov 2, 2023
dea64fb
revert bf16 deepspeed
LZHgrla Nov 2, 2023
c31ab61
Merge branch 'InternLM:main' into lzh/llava
LZHgrla Nov 2, 2023
1f4a97b
fix bugs
LZHgrla Nov 2, 2023
6ceeaa8
add visual_select_layer for chat
LZHgrla Nov 2, 2023
7502793
improve pth_to_hf
LZHgrla Nov 2, 2023
6f31402
Merge branch 'main' into lzh/llava
LZHgrla Nov 3, 2023
9268dbc
rename projecter_pth to pretrained_pth
LZHgrla Nov 6, 2023
5282b3c
temp
LZHgrla Nov 6, 2023
3e4b425
update requirements
LZHgrla Nov 7, 2023
fe30549
add cfgs
LZHgrla Nov 7, 2023
413111f
update
LZHgrla Nov 7, 2023
fac6cf8
fix pre-commit
LZHgrla Nov 7, 2023
da3f268
optim chat
LZHgrla Nov 7, 2023
b25913b
optim chat
LZHgrla Nov 7, 2023
0a9e480
Delete xtuner/model/unused.py
LZHgrla Nov 7, 2023
c0c4b8b
move dispatch to a deeper folder
LZHgrla Nov 8, 2023
74159ac
add projector
LZHgrla Nov 8, 2023
b92f075
update
LZHgrla Nov 8, 2023
04fca91
del model/projector
LZHgrla Nov 8, 2023
fba83bc
fix bugs
LZHgrla Nov 8, 2023
93f6616
add docs
LZHgrla Nov 8, 2023
d13f71e
update
LZHgrla Nov 8, 2023
e62132f
update
LZHgrla Nov 8, 2023
37b017e
update
LZHgrla Nov 8, 2023
8d67ed1
update
LZHgrla Nov 8, 2023
6b99303
Merge branch 'main' into lzh/llava
LZHgrla Nov 8, 2023
00c5d0c
enhance resume for map_fn
LZHgrla Nov 8, 2023
5496952
update import
LZHgrla Nov 8, 2023
a07a833
add llava_internlm_chat_7b_clip_vit_large_p14
LZHgrla Nov 9, 2023
21e649b
update dispatch
LZHgrla Nov 9, 2023
e0e0275
update dispatch
LZHgrla Nov 9, 2023
1a12477
add link
LZHgrla Nov 9, 2023
fbaf22f
update max_length
LZHgrla Nov 10, 2023
f31366f
update max_length
LZHgrla Nov 10, 2023
9786d3b
update hyp
LZHgrla Nov 10, 2023
db15bfa
align
LZHgrla Nov 11, 2023
1af5c4a
Merge branch 'main' into lzh/llava
LZHgrla Nov 13, 2023
c426c42
move yi flash attn
LZHgrla Nov 13, 2023
001bf8d
fix pre-commit
LZHgrla Nov 13, 2023
7b7b690
update deepspeed requirements
LZHgrla Nov 14, 2023
374f997
add mmbench script
LZHgrla Nov 14, 2023
bb9dcc3
install openpyxl
LZHgrla Nov 14, 2023
f4114ed
add entry_point for mmbench
LZHgrla Nov 14, 2023
041e96f
save args
LZHgrla Nov 14, 2023
c5c8437
update mmbench
LZHgrla Nov 14, 2023
0509cdd
Merge branch 'main' into lzh/llava
LZHgrla Nov 15, 2023
c3e3cf2
update max_length
LZHgrla Nov 15, 2023
20b5d74
add llama2 qlora
LZHgrla Nov 15, 2023
092f7e6
update mmbench
LZHgrla Nov 15, 2023
6ffffef
fix mmbench bugs
LZHgrla Nov 15, 2023
efc178b
Merge branch 'main' into lzh/llava
LZHgrla Nov 16, 2023
d30ba29
use osp instead of os.path
LZHgrla Nov 16, 2023
a3f2435
refactor pth_to_hf
LZHgrla Nov 16, 2023
aafe80c
update chat and mmbench to support --llava
LZHgrla Nov 16, 2023
8836b0b
align to chat
LZHgrla Nov 16, 2023
8393559
update entry_point
LZHgrla Nov 16, 2023
e0274f6
add vicuna template
LZHgrla Nov 16, 2023
927e207
add vicuna_7b_v15
LZHgrla Nov 16, 2023
33d6c76
Merge branch 'main' into lzh/llava
LZHgrla Nov 17, 2023
903e074
fix pre-commit
LZHgrla Nov 17, 2023
b66419e
add vicuna_7b_v1.5 qlora
LZHgrla Nov 20, 2023
6edf769
Merge branch 'main' into lzh/llava
LZHgrla Nov 20, 2023
5fe1e54
skip_special_tokens for decode text
LZHgrla Nov 21, 2023
376abb6
remove do_sample
LZHgrla Nov 22, 2023
82d8df6
Merge branch 'main' into lzh/llava
LZHgrla Nov 22, 2023
4dc6379
add warmup
LZHgrla Nov 22, 2023
67ffd70
fix pre-commit
LZHgrla Nov 22, 2023
76ad9e9
Update dataset_prepare.md
LZHgrla Nov 22, 2023
aeded33
Update dataset_prepare.md
LZHgrla Nov 22, 2023
3c44f94
Add KEEP_STSTEM for template
LZHgrla Nov 22, 2023
b537e9b
remove
LZHgrla Nov 22, 2023
c36c1ff
fix vicuna template
LZHgrla Nov 22, 2023
fb3f7da
clean cfgs
LZHgrla Nov 22, 2023
b0b4f1d
add cfgs
LZHgrla Nov 22, 2023
6c2dec5
Merge branch 'main' into lzh/llava
LZHgrla Nov 22, 2023
1543df7
fix pre-commit
LZHgrla Nov 22, 2023
db434f7
add --language for mmbench
LZHgrla Nov 23, 2023
9f3e44e
Merge branch 'main' into lzh/llava
LZHgrla Nov 23, 2023
eb2ad0d
fix bugs
LZHgrla Nov 23, 2023
349e37c
fix pretrain bug
LZHgrla Nov 23, 2023
bbbc62b
support visual_encoder lora
LZHgrla Nov 23, 2023
0357a93
fix bugs
LZHgrla Nov 23, 2023
f2295d2
add paramwise_cfg
LZHgrla Nov 23, 2023
72a986e
remove print_peft_model_trainable_parameters
LZHgrla Nov 23, 2023
e0583cb
Merge branch 'main' into lzh/llava
LZHgrla Nov 24, 2023
4416c56
fix bugs
LZHgrla Nov 24, 2023
0894e60
add paramwise_cfg for DeepSpeedOptimWrapper
LZHgrla Nov 24, 2023
ff4f15e
fix engine deepspeed paramwise_cfg bug
LZHgrla Nov 24, 2023
aa1dbf1
fix encode_fn bug
LZHgrla Nov 25, 2023
a046e0e
fix
LZHgrla Nov 25, 2023
9080be3
fix pad_image_to_square bugs
LZHgrla Nov 26, 2023
12c212a
Add space for system to avoid mismatch of 'USER' token
LZHgrla Nov 26, 2023
19bde6f
revert to adding bos_token at each conv
LZHgrla Nov 29, 2023
7c01831
revert for paramwise_cfg
LZHgrla Nov 29, 2023
ba9de6d
better cfgs?
LZHgrla Nov 29, 2023
baa1727
fix import bug
LZHgrla Nov 29, 2023
c5e61a5
fix import bug
LZHgrla Nov 29, 2023
fece023
pretrain align
LZHgrla Nov 30, 2023
273b24d
update prepare_inputs_labels_for_multimodal
LZHgrla Nov 30, 2023
b37dd8f
1792
LZHgrla Nov 30, 2023
e307624
support length_grouped_samplers
LZHgrla Dec 1, 2023
e25280b
1792
LZHgrla Dec 1, 2023
0a15676
remove KEEP_SYSTEM
LZHgrla Dec 1, 2023
e3b936a
remove system in cfg
LZHgrla Dec 1, 2023
580136b
update 336 cfg
LZHgrla Dec 1, 2023
683385e
Merge branch 'main' into lzh/llava
LZHgrla Dec 1, 2023
053eb84
add torch_dtype for mmbench and chat
LZHgrla Dec 2, 2023
f362a9f
group 50
LZHgrla Dec 2, 2023
12d7a1e
quant for pretrain
LZHgrla Dec 2, 2023
c4fd8db
update cfgs
LZHgrla Dec 4, 2023
245af61
refactor cfgs
LZHgrla Dec 4, 2023
4087168
add length for concat dataset
LZHgrla Dec 4, 2023
013a930
update requirements
LZHgrla Dec 4, 2023
8721427
Merge branch 'lzh/llava' of github.com:LZHgrla/xtuner into lzh/llava
LZHgrla Dec 4, 2023
491be19
fix typo
LZHgrla Dec 4, 2023
bf5d2da
Merge branch 'main' into lzh/llava
LZHgrla Dec 4, 2023
0e21e51
add template for internlm pretrain
LZHgrla Dec 6, 2023
d429961
no zh
LZHgrla Dec 6, 2023
8ce84c3
remove 20b cfgs
LZHgrla Dec 6, 2023
41a8794
fix pre-commit
LZHgrla Dec 6, 2023
ac80d1a
revert invalid input
LZHgrla Dec 7, 2023
2e94a52
rename
LZHgrla Dec 7, 2023
c5b9e75
Update README.md
LZHgrla Dec 7, 2023
2110f19
Update README_zh-CN.md
LZHgrla Dec 7, 2023
036fd72
fix pre-commit
LZHgrla Dec 7, 2023
a8eecbf
remove llava_zh from docs
LZHgrla Dec 8, 2023
0cc9bf8
qlora 512
LZHgrla Dec 9, 2023
bcaffd4
rename llava map_fn
LZHgrla Dec 10, 2023
0f0250c
update cfgs
LZHgrla Dec 10, 2023
a050926
update model urls
LZHgrla Dec 11, 2023
951b15a
add docs link
LZHgrla Dec 11, 2023
8dc0746
add llava docs
LZHgrla Dec 11, 2023
2e5c77a
Merge branch 'main' into lzh/llava
LZHgrla Dec 11, 2023
3aef652
update docs
LZHgrla Dec 11, 2023
24996d6
update urls
LZHgrla Dec 11, 2023
8787baa
Merge branch 'main' into lzh/llava
LZHgrla Dec 11, 2023
3eac2df
add citation
LZHgrla Dec 11, 2023
f65bf9a
fix README
LZHgrla Dec 14, 2023
aa6d525
move
LZHgrla Dec 14, 2023
ba8facd
update
LZHgrla Dec 14, 2023
b717f8e
vicuna pretrain with prompt
LZHgrla Dec 14, 2023
6dd1e63
rename
LZHgrla Dec 15, 2023
321b351
add results
LZHgrla Dec 15, 2023
c44f71b
fix pre-commit
LZHgrla Dec 15, 2023
cde826f
update
LZHgrla Dec 15, 2023
f2b6e3b
update
LZHgrla Dec 15, 2023
3b6c07a
update
LZHgrla Dec 15, 2023
b383667
update
LZHgrla Dec 15, 2023
b463465
update
LZHgrla Dec 15, 2023
261fc43
update
LZHgrla Dec 15, 2023
da3ed07
update
LZHgrla Dec 15, 2023
8045143
update
LZHgrla Dec 15, 2023
f7d14f8
update
LZHgrla Dec 15, 2023
ee9c026
update
LZHgrla Dec 15, 2023
adb4ba5
update
LZHgrla Dec 15, 2023
fa1ce76
update
LZHgrla Dec 15, 2023
95faa59
Update README.md
LZHgrla Dec 15, 2023
8054690
Update README_zh-CN.md
LZHgrla Dec 15, 2023
56e9507
Update README_zh.md
LZHgrla Dec 15, 2023
99c4e91
Update README_zh.md
LZHgrla Dec 15, 2023
479a5fd
Update README.md
LZHgrla Dec 15, 2023
47f4927
Update README_zh.md
LZHgrla Dec 15, 2023
367225b
Update README.md
LZHgrla Dec 15, 2023
c4007f7
Update README.md
LZHgrla Dec 15, 2023
3d4dee8
fix typo
LZHgrla Dec 15, 2023
eec012c
fix
LZHgrla Dec 15, 2023
b027cb8
Update README.md
LZHgrla Dec 15, 2023
6276d33
Update README_zh-CN.md
LZHgrla Dec 15, 2023
ad65fc8
rename
LZHgrla Dec 16, 2023
77d9809
auto cn_string
LZHgrla Dec 16, 2023
a318133
fix pre-commit
LZHgrla Dec 16, 2023
1dadc4b
rename
LZHgrla Dec 16, 2023
72ca5ee
remove language
LZHgrla Dec 16, 2023
197c292
add VLMEvalKit
LZHgrla Dec 16, 2023
11cbfdc
rename VLLM to VLM
LZHgrla Dec 21, 2023
63ed932
add the download links of MMBench
LZHgrla Dec 21, 2023
99a2b8e
update
LZHgrla Dec 21, 2023
8080a06
update readme
LZHgrla Dec 21, 2023
360b816
update
LZHgrla Dec 21, 2023
4ade82d
update
LZHgrla Dec 21, 2023
885c832
update merge
LZHgrla Dec 21, 2023
990d689
fix cfg bug
LZHgrla Dec 21, 2023
0e5d692
Update README.md
LZHgrla Dec 21, 2023
8225f9f
Update README_zh.md
LZHgrla Dec 21, 2023
648111d
update
LZHgrla Dec 21, 2023
6f06498
fix
LZHgrla Dec 21, 2023
76d1313
Merge branch 'main' into lzh/llava
LZHgrla Dec 21, 2023
b5124b1
update requirements
LZHgrla Dec 22, 2023
5973d6c
Merge branch 'main' into lzh/llava
LZHgrla Dec 22, 2023
311f9d0
Update runtime.txt
LZHgrla Dec 22, 2023
cbb7924
Update runtime.txt
LZHgrla Dec 22, 2023
d9a96af
Update runtime.txt
LZHgrla Dec 22, 2023
8332c2c
Update README.md
LZHgrla Dec 25, 2023
b9efc8a
Update README.md
LZHgrla Dec 25, 2023
7b29f81
Update README_zh.md
LZHgrla Dec 25, 2023
6c80ec7
fix pre-commit
LZHgrla Dec 25, 2023
034b4cb
fix
LZHgrla Dec 25, 2023
f7ec4da
update mmbench prompt
LZHgrla Dec 25, 2023
7231865
fix bugs
LZHgrla Dec 26, 2023
bf384de
fix bugs
LZHgrla Dec 26, 2023
80d7c11
update docs
LZHgrla Dec 26, 2023
7e68e1d
update
LZHgrla Dec 26, 2023
327a122
update
LZHgrla Dec 26, 2023
15c927a
Merge branch 'main' into lzh/llava
LZHgrla Dec 26, 2023
761a7ea
Update README.md
LZHgrla Dec 26, 2023
v1
LZHgrla committed Nov 1, 2023
commit 6de3469fd975619df1551edea94936d8cdedf878
2 changes: 1 addition & 1 deletion xtuner/configs/deepspeed/deepspeed_zero2.json
@@ -11,7 +11,7 @@
         "overlap_comm": true,
         "reduce_scatter": true
     },
-    "fp16": {
+    "bf16": {
         "enabled": true,
         "initial_scale_power": 16
     }
Binary file added xtuner/configs/llava/cloud.png
176 changes: 176 additions & 0 deletions xtuner/configs/llava/llava_llama2_7b_chat_clip_vit_large_p14_e1.py
@@ -0,0 +1,176 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
from mmengine.dataset import DefaultSampler
from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook,
                            LoggerHook, ParamSchedulerHook)
from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR
from torch.optim import AdamW
from transformers import (AutoModelForCausalLM, AutoTokenizer,
                          CLIPImageProcessor, CLIPVisionModel)

from xtuner.dataset import LLaVADataset
from xtuner.dataset.collate_fns import default_collate_fn
from xtuner.dataset.map_fns import llava_map_fn, template_map_fn_factory
from xtuner.engine import DatasetInfoHook, EvaluateChatHook
from xtuner.model import LLaVAModel
from xtuner.utils import PROMPT_TEMPLATE

#######################################################################
#                          PART 1  Settings                           #
#######################################################################
# Model
llm_name_or_path = 'meta-llama/Llama-2-7b-hf'
visual_encoder_name_or_path = 'openai/clip-vit-large-patch14'

# Data
data_path = './data/llava_data/LLaVA-Pretrain/blip_laion_cc_sbu_558k.json'
image_folder = './data/llava_data/LLaVA-Pretrain/images'
prompt_template = PROMPT_TEMPLATE.llama2_chat
max_length = 2048

# Scheduler & Optimizer
batch_size = 32  # per_device
accumulative_counts = 1
dataloader_num_workers = 4
max_epochs = 1
optim_type = AdamW
lr = 1e-3
betas = (0.9, 0.999)
weight_decay = 0
max_norm = 1  # grad clip

# Evaluate the generation performance during training
evaluation_freq = 500
SYSTEM = ''
evaluation_images = './xtuner/configs/llava/cloud.png'
evaluation_inputs = ['请描述一下这张照片', 'Please describe this picture']

#######################################################################
#                PART 2  Model & Tokenizer & Processor                #
#######################################################################
tokenizer = dict(
    type=AutoTokenizer.from_pretrained,
    pretrained_model_name_or_path=llm_name_or_path,
    trust_remote_code=True,
    padding_side='right')

processor = dict(
    type=CLIPImageProcessor.from_pretrained,
    pretrained_model_name_or_path=visual_encoder_name_or_path,
    trust_remote_code=True)

model = dict(
    type=LLaVAModel,
    freeze_llm=True,
    freeze_visual_encoder=True,
    llm=dict(
        type=AutoModelForCausalLM.from_pretrained,
        pretrained_model_name_or_path=llm_name_or_path,
        trust_remote_code=True,
        torch_dtype=torch.float32),
    visual_encoder=dict(
        type=CLIPVisionModel.from_pretrained,
        pretrained_model_name_or_path=visual_encoder_name_or_path))

#######################################################################
#                   PART 3  Dataset & Dataloader                      #
#######################################################################
llava_dataset = dict(
    type=LLaVADataset,
    data_path=data_path,
    image_folder=image_folder,
    tokenizer=tokenizer,
    processor=processor,
    dataset_map_fn=llava_map_fn,
    template_map_fn=dict(
        type=template_map_fn_factory, template=prompt_template),
    max_length=max_length,
    image_aspect_ratio='pad')

train_dataloader = dict(
    batch_size=batch_size,
    num_workers=dataloader_num_workers,
    dataset=llava_dataset,
    sampler=dict(type=DefaultSampler, shuffle=True),
    collate_fn=dict(type=default_collate_fn))

#######################################################################
#                  PART 4  Scheduler & Optimizer                      #
#######################################################################
# optimizer
optim_wrapper = dict(
    type=AmpOptimWrapper,
    optimizer=dict(
        type=optim_type, lr=lr, betas=betas, weight_decay=weight_decay),
    clip_grad=dict(max_norm=max_norm, error_if_nonfinite=False),
    accumulative_counts=accumulative_counts,
    loss_scale='dynamic',
    dtype='float16')

# learning policy
# More information: https://github.com/open-mmlab/mmengine/blob/main/docs/en/tutorials/param_scheduler.md  # noqa: E501
param_scheduler = dict(
    type=CosineAnnealingLR,
    eta_min=lr * 0.1,
    by_epoch=True,
    T_max=max_epochs,
    convert_to_iter_based=True)

# train, val, test setting
train_cfg = dict(by_epoch=True, max_epochs=max_epochs, val_interval=1)

#######################################################################
#                           PART 5  Runtime                           #
#######################################################################
# Log the dialogue periodically during the training process, optional
custom_hooks = [
    dict(type=DatasetInfoHook, tokenizer=tokenizer),
    dict(
        type=EvaluateChatHook,
        tokenizer=tokenizer,
        processor=processor,
        every_n_iters=evaluation_freq,
        evaluation_inputs=evaluation_inputs,
        evaluation_images=evaluation_images,
        system=SYSTEM,
        prompt_template=prompt_template)
]

# configure default hooks
default_hooks = dict(
    # record the time of every iteration.
    timer=dict(type=IterTimerHook),
    # print log every 10 iterations.
    logger=dict(type=LoggerHook, interval=10),
    # enable the parameter scheduler.
    param_scheduler=dict(type=ParamSchedulerHook),
    # save checkpoint per epoch.
    checkpoint=dict(type=CheckpointHook, interval=1),
    # set sampler seed in distributed environment.
    sampler_seed=dict(type=DistSamplerSeedHook),
)

# configure environment
env_cfg = dict(
    # whether to enable cudnn benchmark
    cudnn_benchmark=False,
    # set multi process parameters
    mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
    # set distributed parameters
    dist_cfg=dict(backend='nccl'),
)

# set visualizer
visualizer = None

# set log level
log_level = 'INFO'

# load from which checkpoint
load_from = None

# whether to resume training from the loaded checkpoint
resume = False

# Defaults to use random seed and disable `deterministic`
randomness = dict(seed=None, deterministic=False)
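The `image_aspect_ratio='pad'` option above makes the dataset pad every image to a square before the CLIP processor resizes it, so non-square images are letterboxed rather than distorted or cropped. A minimal sketch of that idea on a plain 2D pixel grid (the function name here is illustrative, not xtuner's actual API, which operates on PIL images):

```python
def pad_to_square(pixels, fill=0):
    """Pad a 2D grid (list of rows) to a square, centering the original content."""
    h = len(pixels)
    w = len(pixels[0]) if h else 0
    side = max(w, h)
    top = (side - h) // 2   # rows of fill above the image
    left = (side - w) // 2  # columns of fill to the left
    grid = [[fill] * side for _ in range(side)]
    for r, row in enumerate(pixels):
        for c, value in enumerate(row):
            grid[top + r][left + c] = value
    return grid

# A 2x4 "image" becomes 4x4, centered vertically between rows of fill.
square = pad_to_square([[1, 2, 3, 4], [5, 6, 7, 8]], fill=0)
```

Padding with the processor's mean color (as the real implementation does) keeps the added border neutral after normalization.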
178 changes: 178 additions & 0 deletions xtuner/configs/llava/llava_llama2_7b_chat_clip_vit_large_p14_e1_ft.py
@@ -0,0 +1,178 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
from mmengine.dataset import DefaultSampler
from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook,
                            LoggerHook, ParamSchedulerHook)
from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR
from torch.optim import AdamW
from transformers import (AutoModelForCausalLM, AutoTokenizer,
                          CLIPImageProcessor, CLIPVisionModel)

from xtuner.dataset import LLaVADataset
from xtuner.dataset.collate_fns import default_collate_fn
from xtuner.dataset.map_fns import llava_map_fn, template_map_fn_factory
from xtuner.engine import DatasetInfoHook, EvaluateChatHook
from xtuner.model import LLaVAModel
from xtuner.utils import PROMPT_TEMPLATE

#######################################################################
#                          PART 1  Settings                           #
#######################################################################
# Model
llm_name_or_path = 'meta-llama/Llama-2-7b-hf'
visual_encoder_name_or_path = 'openai/clip-vit-large-patch14'
projector_pth = './work_dirs/llava_llama2_7b_chat_clip_vit_large_p14_e1/epoch_1.pth'  # noqa: E501

# Data
data_path = './data/llava_data/LLaVA-Instruct-150K/llava_v1_5_mix665k.json'
image_folder = './data/llava_data/llava_images'
prompt_template = PROMPT_TEMPLATE.llama2_chat
max_length = 2048

# Scheduler & Optimizer
batch_size = 16  # per_device
accumulative_counts = 1
dataloader_num_workers = 4
max_epochs = 1
optim_type = AdamW
lr = 2e-5
betas = (0.9, 0.999)
weight_decay = 0
max_norm = 1  # grad clip

# Evaluate the generation performance during training
evaluation_freq = 500
SYSTEM = ''
evaluation_images = './xtuner/configs/llava/cloud.png'
evaluation_inputs = ['请描述一下这张照片', 'Please describe this picture']

#######################################################################
#                PART 2  Model & Tokenizer & Processor                #
#######################################################################
tokenizer = dict(
    type=AutoTokenizer.from_pretrained,
    pretrained_model_name_or_path=llm_name_or_path,
    trust_remote_code=True,
    padding_side='right')

processor = dict(
    type=CLIPImageProcessor.from_pretrained,
    pretrained_model_name_or_path=visual_encoder_name_or_path,
    trust_remote_code=True)

model = dict(
    type=LLaVAModel,
    freeze_llm=False,
    freeze_visual_encoder=True,
    projector_pth=projector_pth,
    llm=dict(
        type=AutoModelForCausalLM.from_pretrained,
        pretrained_model_name_or_path=llm_name_or_path,
        trust_remote_code=True,
        torch_dtype=torch.float32),
    visual_encoder=dict(
        type=CLIPVisionModel.from_pretrained,
        pretrained_model_name_or_path=visual_encoder_name_or_path))

#######################################################################
#                   PART 3  Dataset & Dataloader                      #
#######################################################################
llava_dataset = dict(
    type=LLaVADataset,
    data_path=data_path,
    image_folder=image_folder,
    tokenizer=tokenizer,
    processor=processor,
    dataset_map_fn=llava_map_fn,
    template_map_fn=dict(
        type=template_map_fn_factory, template=prompt_template),
    max_length=max_length,
    image_aspect_ratio='pad')

train_dataloader = dict(
    batch_size=batch_size,
    num_workers=dataloader_num_workers,
    dataset=llava_dataset,
    sampler=dict(type=DefaultSampler, shuffle=True),
    collate_fn=dict(type=default_collate_fn))

#######################################################################
#                  PART 4  Scheduler & Optimizer                      #
#######################################################################
# optimizer
optim_wrapper = dict(
    type=AmpOptimWrapper,
    optimizer=dict(
        type=optim_type, lr=lr, betas=betas, weight_decay=weight_decay),
    clip_grad=dict(max_norm=max_norm, error_if_nonfinite=False),
    accumulative_counts=accumulative_counts,
    loss_scale='dynamic',
    dtype='float16')

# learning policy
# More information: https://github.com/open-mmlab/mmengine/blob/main/docs/en/tutorials/param_scheduler.md  # noqa: E501
param_scheduler = dict(
    type=CosineAnnealingLR,
    eta_min=lr * 0.1,
    by_epoch=True,
    T_max=max_epochs,
    convert_to_iter_based=True)

# train, val, test setting
train_cfg = dict(by_epoch=True, max_epochs=max_epochs, val_interval=1)

#######################################################################
#                           PART 5  Runtime                           #
#######################################################################
# Log the dialogue periodically during the training process, optional
custom_hooks = [
    dict(type=DatasetInfoHook, tokenizer=tokenizer),
    dict(
        type=EvaluateChatHook,
        tokenizer=tokenizer,
        processor=processor,
        every_n_iters=evaluation_freq,
        evaluation_inputs=evaluation_inputs,
        evaluation_images=evaluation_images,
        system=SYSTEM,
        prompt_template=prompt_template)
]

# configure default hooks
default_hooks = dict(
    # record the time of every iteration.
    timer=dict(type=IterTimerHook),
    # print log every 10 iterations.
    logger=dict(type=LoggerHook, interval=10),
    # enable the parameter scheduler.
    param_scheduler=dict(type=ParamSchedulerHook),
    # save checkpoint per epoch.
    checkpoint=dict(type=CheckpointHook, interval=1),
    # set sampler seed in distributed environment.
    sampler_seed=dict(type=DistSamplerSeedHook),
)

# configure environment
env_cfg = dict(
    # whether to enable cudnn benchmark
    cudnn_benchmark=False,
    # set multi process parameters
    mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
    # set distributed parameters
    dist_cfg=dict(backend='nccl'),
)

# set visualizer
visualizer = None

# set log level
log_level = 'INFO'

# load from which checkpoint
load_from = None

# whether to resume training from the loaded checkpoint
resume = False

# Defaults to use random seed and disable `deterministic`
randomness = dict(seed=None, deterministic=False)
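The `param_scheduler` in both configs anneals the learning rate along a cosine curve from `lr` down to `eta_min = lr * 0.1` over the run (converted from epochs to iterations). A stand-alone sketch of the value that schedule produces at a given step, ignoring warmup (the helper name is hypothetical, not part of mmengine):

```python
import math

def cosine_annealing_lr(step, total_steps, base_lr=2e-5, eta_min_ratio=0.1):
    """Learning rate at `step` when annealing from base_lr to base_lr * eta_min_ratio."""
    eta_min = base_lr * eta_min_ratio
    cos_term = 1 + math.cos(math.pi * step / total_steps)  # 2 at step 0, 0 at the end
    return eta_min + 0.5 * (base_lr - eta_min) * cos_term

start = cosine_annealing_lr(0, 1000)    # ≈ 2e-5 (= base_lr)
end = cosine_annealing_lr(1000, 1000)   # ≈ 2e-6 (= eta_min)
```

With `max_epochs = 1` and `convert_to_iter_based=True`, `total_steps` corresponds to the number of optimizer steps in the single epoch.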
4 changes: 3 additions & 1 deletion xtuner/dataset/__init__.py
@@ -1,10 +1,12 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 from .concat_dataset import ConcatDataset
 from .huggingface import process_hf_dataset
+from .llava import LLaVADataset
 from .modelscope import process_ms_dataset
 from .moss_sft import MOSSSFTDataset
+from .utils import expand2square
 
 __all__ = [
     'process_hf_dataset', 'ConcatDataset', 'MOSSSFTDataset',
-    'process_ms_dataset'
+    'process_ms_dataset', 'LLaVADataset', 'expand2square'
 ]