Skip to content

Commit 31fe758

Browse files
committed
try to install flash-attn and lower the model size for the demo mode
1 parent 7ca7332 commit 31fe758

File tree

2 files changed

+60
-9
lines changed

2 files changed

+60
-9
lines changed

src/MEDS_DEV/models/cehrxgpt/mimiciv/model.yaml

Lines changed: 30 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,28 @@ commands:
1515
mkdir -p "{output_dir}/cehrgpt_pretrained/dataset_prepared"
1616
meds_reader_convert "{dataset_dir}" "{output_dir}/cehrgpt_pretrained/meds_reader" --num_threads 8
1717
18+
echo "Attempting to install flash-attn (optional)..."
19+
pip install flash-attn || echo "Warning: flash-attn installation failed. Continuing without it."
20+
21+
# Set model configuration based on demo mode
22+
if [ "{demo}" = "true" ]; then
23+
export HIDDEN_SIZE=256
24+
export NUM_LAYERS=4
25+
export MAX_POS_EMB=128
26+
export MAX_TOKENS=512
27+
export NUM_EPOCHS=5
28+
export DATALOADER_WORKERS=2
29+
export PREFETCH_FACTOR=2
30+
else
31+
export HIDDEN_SIZE=768
32+
export NUM_LAYERS=14
33+
export MAX_POS_EMB=8192
34+
export MAX_TOKENS=16384
35+
export NUM_EPOCHS=50
36+
export DATALOADER_WORKERS=8
37+
export PREFETCH_FACTOR=8
38+
fi
39+
1840
export CUDA_VISIBLE_DEVICES="0"; python -u -m cehrgpt.runners.hf_cehrgpt_pretrain_runner \
1941
--model_name_or_path "{output_dir}/cehrgpt_pretrained" \
2042
--tokenizer_name_or_path "{output_dir}/cehrgpt_pretrained" \
@@ -23,12 +45,16 @@ commands:
2345
--tokenized_dataset_name "full_tokenized_dataset" \
2446
--dataset_prepared_path {output_dir}/cehrgpt_pretrained/dataset_prepared \
2547
--do_train true --seed 42 \
26-
--dataloader_num_workers 8 --dataloader_prefetch_factor 8 \
27-
--hidden_size 768 --num_hidden_layers 14 --max_position_embeddings 8192 \
48+
--dataloader_num_workers $DATALOADER_WORKERS \
49+
--dataloader_prefetch_factor $PREFETCH_FACTOR \
50+
--hidden_size $HIDDEN_SIZE \
51+
--num_hidden_layers $NUM_LAYERS \
52+
--max_position_embeddings $MAX_POS_EMB \
2853
--evaluation_strategy epoch --save_strategy epoch \
29-
--sample_packing --max_tokens_per_batch 16384 \
54+
--sample_packing --max_tokens_per_batch $MAX_TOKENS \
3055
--warmup_ratio 0.01 --weight_decay 0.01 \
31-
--num_train_epochs 50 --learning_rate 0.0001 \
56+
--num_train_epochs $NUM_EPOCHS \
57+
--learning_rate 0.0001 \
3258
--use_early_stopping --early_stopping_threshold 0.001 \
3359
--load_best_model_at_end \
3460
--is_data_in_meds --inpatient_att_function_type day \

src/MEDS_DEV/models/cehrxgpt/omop/model.yaml

Lines changed: 30 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,28 @@ commands:
1515
mkdir -p "{output_dir}/cehrgpt_pretrained/dataset_prepared"
1616
meds_reader_convert "{dataset_dir}" "{output_dir}/cehrgpt_pretrained/meds_reader" --num_threads 8
1717
18+
echo "Attempting to install flash-attn (optional)..."
19+
pip install flash-attn || echo "Warning: flash-attn installation failed. Continuing without it."
20+
21+
# Set model configuration based on demo mode
22+
if [ "{demo}" = "true" ]; then
23+
export HIDDEN_SIZE=256
24+
export NUM_LAYERS=4
25+
export MAX_POS_EMB=128
26+
export MAX_TOKENS=512
27+
export NUM_EPOCHS=5
28+
export DATALOADER_WORKERS=2
29+
export PREFETCH_FACTOR=2
30+
else
31+
export HIDDEN_SIZE=768
32+
export NUM_LAYERS=14
33+
export MAX_POS_EMB=8192
34+
export MAX_TOKENS=16384
35+
export NUM_EPOCHS=50
36+
export DATALOADER_WORKERS=8
37+
export PREFETCH_FACTOR=8
38+
fi
39+
1840
export CUDA_VISIBLE_DEVICES="0"; python -u -m cehrgpt.runners.hf_cehrgpt_pretrain_runner \
1941
--model_name_or_path "{output_dir}/cehrgpt_pretrained" \
2042
--tokenizer_name_or_path "{output_dir}/cehrgpt_pretrained" \
@@ -23,18 +45,21 @@ commands:
2345
--tokenized_dataset_name "full_tokenized_dataset" \
2446
--dataset_prepared_path {output_dir}/cehrgpt_pretrained/dataset_prepared \
2547
--do_train true --seed 42 \
26-
--dataloader_num_workers 8 --dataloader_prefetch_factor 8 \
27-
--hidden_size 768 --num_hidden_layers 14 --max_position_embeddings 8192 \
48+
--dataloader_num_workers $DATALOADER_WORKERS \
49+
--dataloader_prefetch_factor $PREFETCH_FACTOR \
50+
--hidden_size $HIDDEN_SIZE \
51+
--num_hidden_layers $NUM_LAYERS \
52+
--max_position_embeddings $MAX_POS_EMB \
2853
--evaluation_strategy epoch --save_strategy epoch \
29-
--sample_packing --max_tokens_per_batch 16384 \
54+
--sample_packing --max_tokens_per_batch $MAX_TOKENS \
3055
--warmup_ratio 0.01 --weight_decay 0.01 \
31-
--num_train_epochs 50 --learning_rate 0.0001 \
56+
--num_train_epochs $NUM_EPOCHS \
57+
--learning_rate 0.0001 \
3258
--use_early_stopping --early_stopping_threshold 0.001 \
3359
--load_best_model_at_end \
3460
--is_data_in_meds --inpatient_att_function_type day \
3561
--att_function_type day --include_inpatient_hour_token \
3662
--include_auxiliary_token --include_demographic_prompt \
37-
--disconnect_problem_list_events \
3863
--meds_to_cehrbert_conversion_type MedsToCehrbertOMOP \
3964
--report_to "none"
4065

0 commit comments

Comments
 (0)