@@ -15,6 +15,28 @@ commands:
1515 mkdir -p "{output_dir}/cehrgpt_pretrained/dataset_prepared"
1616 meds_reader_convert "{dataset_dir}" "{output_dir}/cehrgpt_pretrained/meds_reader" --num_threads 8
1717
18+ echo "Attempting to install flash-attn (optional)..."
19+ pip install flash-attn || echo "Warning: flash-attn installation failed. Continuing without it."
20+
21+ # Set model configuration based on demo mode
22+ if [ "{demo}" = "true" ]; then
23+ export HIDDEN_SIZE=256
24+ export NUM_LAYERS=4
25+ export MAX_POS_EMB=128
26+ export MAX_TOKENS=512
27+ export NUM_EPOCHS=5
28+ export DATALOADER_WORKERS=2
29+ export PREFETCH_FACTOR=2
30+ else
31+ export HIDDEN_SIZE=768
32+ export NUM_LAYERS=14
33+ export MAX_POS_EMB=8192
34+ export MAX_TOKENS=16384
35+ export NUM_EPOCHS=50
36+ export DATALOADER_WORKERS=8
37+ export PREFETCH_FACTOR=8
38+ fi
39+
1840 export CUDA_VISIBLE_DEVICES="0"; python -u -m cehrgpt.runners.hf_cehrgpt_pretrain_runner \
1941 --model_name_or_path "{output_dir}/cehrgpt_pretrained" \
2042 --tokenizer_name_or_path "{output_dir}/cehrgpt_pretrained" \
@@ -23,18 +45,21 @@ commands:
2345 --tokenized_dataset_name "full_tokenized_dataset" \
2446 --dataset_prepared_path {output_dir}/cehrgpt_pretrained/dataset_prepared \
2547 --do_train true --seed 42 \
26- --dataloader_num_workers 8 --dataloader_prefetch_factor 8 \
27- --hidden_size 768 --num_hidden_layers 14 --max_position_embeddings 8192 \
48+ --dataloader_num_workers $DATALOADER_WORKERS \
49+ --dataloader_prefetch_factor $PREFETCH_FACTOR \
50+ --hidden_size $HIDDEN_SIZE \
51+ --num_hidden_layers $NUM_LAYERS \
52+ --max_position_embeddings $MAX_POS_EMB \
2853 --evaluation_strategy epoch --save_strategy epoch \
29- --sample_packing --max_tokens_per_batch 16384 \
54+ --sample_packing --max_tokens_per_batch $MAX_TOKENS \
3055 --warmup_ratio 0.01 --weight_decay 0.01 \
31- --num_train_epochs 50 --learning_rate 0.0001 \
56+ --num_train_epochs $NUM_EPOCHS \
57+ --learning_rate 0.0001 \
3258 --use_early_stopping --early_stopping_threshold 0.001 \
3359 --load_best_model_at_end \
3460 --is_data_in_meds --inpatient_att_function_type day \
3561 --att_function_type day --include_inpatient_hour_token \
3662 --include_auxiliary_token --include_demographic_prompt \
37- --disconnect_problem_list_events \
3863 --meds_to_cehrbert_conversion_type MedsToCehrbertOMOP \
3964 --report_to "none"
4065
0 commit comments