13B On 24GB go OOM #225

Closed

nivibilla opened this issue Jul 22, 2023 · 16 comments
@nivibilla

Hey, how can I stop the training going OOM when training a 13B in a 24GB GPU?

@nivibilla
Author

Related question: if I had to drop some of the target modules to adapt, which should I drop first out of q, k, v, o, up, down, and gate?

@BugReporterZ

BugReporterZ commented Jul 22, 2023

You should probably first limit your batch size to 1 if you haven't already done so. Decreasing the LoRA r to 8 may also help to a limited extent without affecting model performance.

--per_device_train_batch_size 1
--gradient_accumulation_steps 1
--lora_r 8

If you're still getting OOM after this, then perhaps your training data has excessively long sequences. With my 24GB RTX 3090 I could go up to about 3750 tokens with these settings.
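
If you're launching through the repo's qlora.py, I believe it also exposes flags that cap (and truncate to) the maximum instruction and response lengths, something along the lines of:

--source_max_len 512
--target_max_len 256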

@nivibilla
Author

I've basically stripped everything down: I'm only training q and v, using batch size 1 with gradient accumulation 1 and LoRA r 4, but I still get OOM.

@artidoro
Owner

Do you mind sharing your code/script? 24GB should be sufficient for 13B.

@nivibilla
Author

nivibilla commented Jul 22, 2023

Sure, here is the converted notebook:

# %%
import os
os.environ["LD_LIBRARY_PATH"] = "" # hack to force bitsandbytes to use databricks cuda11.7 lib
os.environ["TORCH_DISTRIBUTED_DEBUG"] = 'DETAIL'
# would like to solve this issue for good at some point..

# %%
!mlflow gc

# %%
# check gpus
!nvidia-smi

# %%
!pip install torch==2.0.1

# %%
!pip install -q tqdm
!pip install -q -U bitsandbytes
!pip install -q -U git+https://github.com/huggingface/transformers.git
!pip install -q -U git+https://github.com/huggingface/peft.git
!pip install -q -U git+https://github.com/huggingface/accelerate.git
!pip install -q -U git+https://github.com/huggingface/datasets.git
!pip install -q -U sentencepiece
# !pip install -q -U einops
# !pip install -q -U trl

# %%
# check bitsandbytes is working correctly
!python -m bitsandbytes

# %%
models = {
    '/dbfs/mnt/llm_model_dump/Llama-2-13b-hf/' : {
        'size' : '13b',
        'folder_name' : 'llama_2_13b_qlora_classifier_7k'
    }
}

list_of_models = list(models.keys())
model_name = '/dbfs/mnt/llm_model_dump/Llama-2-13b-hf/'

# %%
from pyspark.sql.types import StringType, ArrayType, IntegerType
from pyspark.sql.functions import col, lit, udf

@udf(returnType=StringType())
def make_prompt(inputs, group, end_token):
  instruction = f"""Out of these classes: ['good', 'bad', 'neutral'] which class does this text belong to?\n<start_of_text>{inputs}<end_of_text>"""
  
  response = f""" {group}"""

  prompt = f"""Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\n{instruction}\n\n### Response:\nOut of the classes you provided the text belongs to{response}.{end_token}"""
      
  return prompt

from transformers import AutoModelForCausalLM, AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True, use_fast=True)

bc_tokenizer = spark.sparkContext.broadcast(tokenizer)

@udf(returnType=ArrayType(IntegerType()))
def dist_tokenize_input_ids(seq):
  return bc_tokenizer.value(seq)['input_ids']

@udf(returnType=ArrayType(IntegerType()))
def dist_tokenize_attention_mask(seq):
  return bc_tokenizer.value(seq)['attention_mask']

# %%
train_df = (
    spark.read.parquet(
        "/mnt/processed/train_500"
    ).limit(500)
    .withColumn(
        "prompt", make_prompt(col("inputs"), col("group"), lit(tokenizer.eos_token))
    )
    .withColumn("input_ids", dist_tokenize_input_ids(col("prompt")))
    .withColumn("attention_mask", dist_tokenize_attention_mask(col("prompt")))
    # .withColumn('num_tokens',  
)

# test_df = (
#     spark.read.parquet(
#         "/mnt/processed/test_1_percent"
#     )
#     .withColumn(
#         "prompt", make_prompt(col("inputs"), col("group"), lit(tokenizer.eos_token))
#     )
#     .withColumn("input_ids", dist_tokenize(col("prompt")))
# )


# %%
from datasets import Dataset, DatasetDict

train_dataset = Dataset.from_spark(train_df)
# test_dataset = Dataset.from_spark(test_df)

train_dataset = train_dataset.add_column("group_value", train_dataset['group'])
train_dataset = train_dataset.class_encode_column('group').train_test_split(test_size=0.1, stratify_by_column="group", seed=1, shuffle=True)

train_dataset = train_dataset.remove_columns('group').rename_column('group_value', 'group')

dataset = DatasetDict({
    'train' : train_dataset['train'],
    'valid' : train_dataset['test'],
    # 'test' : test_dataset
    })

# %%
dataset

# %%
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig


bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16
)

model = AutoModelForCausalLM.from_pretrained(model_name, quantization_config=bnb_config, device_map="auto")


# %%
def print_trainable_parameters(model):
    """
    Prints the number of trainable parameters in the model.
    """
    trainable_params = 0
    all_param = 0
    for _, param in model.named_parameters():
        all_param += param.numel()
        if param.requires_grad:
            trainable_params += param.numel()
    print(
        f"trainable params: {trainable_params} || all params: {all_param} || trainable%: {100 * trainable_params / all_param}"
    )

# %%
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training

model = prepare_model_for_kbit_training(model, use_gradient_checkpointing = True)

config = LoraConfig(
    r=4, 
    lora_alpha=8, 
    target_modules=['q_proj', 'v_proj' ], 
    lora_dropout=0.05, 
    bias="none", 
    task_type="CAUSAL_LM"
)


model = get_peft_model(model, config)
print_trainable_parameters(model)

# %%
model

# %%
# check gpus
!nvidia-smi

# %%
dataset['train']

# %%
torch.cuda.empty_cache()

# %%
from transformers import Trainer, TrainingArguments, DataCollatorForLanguageModeling, TrainerCallback

class PeftSavingCallback(TrainerCallback):
  def on_save(self, args, state, control, **kwargs):
    checkpoint_path = os.path.join(args.output_dir, f"checkpoint-{state.global_step}")
    kwargs["model"].save_pretrained(checkpoint_path)

    if "pytorch_model.bin" in os.listdir(checkpoint_path):
      os.remove(os.path.join(checkpoint_path, "pytorch_model.bin"))

trainer = Trainer(
    model = model,
    train_dataset = dataset["train"],
    args = TrainingArguments(
        save_steps = 0.1,
        ddp_find_unused_parameters=False,
        gradient_checkpointing=True,
        per_device_train_batch_size = 1,
        gradient_accumulation_steps = 1,
        num_train_epochs = 10,
        learning_rate = 2e-5,
        fp16 = True,
        max_grad_norm=0.03,
        logging_steps = 1,
        output_dir = models[model_name]['folder_name'],
        optim = "paged_adamw_8bit"
    ),
    data_collator = DataCollatorForLanguageModeling(tokenizer, mlm = False),
    callbacks = [PeftSavingCallback]
)
model.config.use_cache = False  # silence the warnings. Please re-enable for inference!

trainer.train()

# %%
!nvidia-smi

@nivibilla
Author

Or here as a notebook in Drive.

@artidoro
Owner

The code looks reasonable. One thing I noticed is that you are not controlling the number of tokens in your data. In our experiments, I think we keep the maximum sequence length to 256 or 512 depending on the dataset and truncate the rest. This could explain your code running fine for a few steps and then hitting OOM on a longer sequence.
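
For example, in your notebook the cap could go directly into the tokenization UDFs; a rough, untested sketch with 512 as the limit (bc_tokenizer is the broadcast tokenizer you already create):

from pyspark.sql.functions import udf
from pyspark.sql.types import ArrayType, IntegerType

# sketch: same UDFs as in the notebook, but truncating every example to 512 tokens
@udf(returnType=ArrayType(IntegerType()))
def dist_tokenize_input_ids(seq):
  return bc_tokenizer.value(seq, truncation=True, max_length=512)['input_ids']

@udf(returnType=ArrayType(IntegerType()))
def dist_tokenize_attention_mask(seq):
  return bc_tokenizer.value(seq, truncation=True, max_length=512)['attention_mask']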

If you do want to train on longer sequences that don't fit in memory, maybe try splitting your model across two GPUs. This should be handled pretty easily with accelerate or by defining a custom device map that puts layers on different GPUs.
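
A rough sketch of that multi-GPU variant, reusing the model_name and bnb_config from your notebook and assuming two cards with some headroom left on each for activations:

from transformers import AutoModelForCausalLM

# sketch: let accelerate spread the 4-bit weights over two GPUs,
# capping how much memory the weights may occupy on each card
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    quantization_config=bnb_config,
    device_map="auto",
    max_memory={0: "20GiB", 1: "20GiB"},
)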

@nivibilla
Author

Ah okay, I didn't know that. I have some samples that are close to 4k tokens.

Does finetuning at a 512 sequence length still perform well when you test on longer sequences?

@artidoro
Owner

I am not sure, to be honest. Maybe if you are using LLaMA 2, its pretraining on a 4k context is sufficient even if finetuning is only on 512 tokens, especially if you use relatively few examples, similar to LIMA or to what we did with Guanaco. But this is still an open question as far as I know.

@nivibilla
Author

Ah okay, cool. I will let you know how it goes. I plan to train on 50k data rows max

@nivibilla
Author

I will confirm whether limiting to 512 fixes this and then close the issue.

@artidoro
Owner

You might want to look into flash attention for long sequences. There was some discussion of this in issue #221. If you have an interest in long-sequence modeling, we would appreciate some help getting an example with flash attention working in QLoRA.
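
(Caveat: the sketch below assumes a newer transformers release that can request FlashAttention-2 at load time, plus the flash-attn package installed; I have not verified it against this exact setup.)

from transformers import AutoModelForCausalLM

# sketch, assuming transformers with FlashAttention-2 support and `pip install flash-attn`,
# reusing model_name and bnb_config from the notebook above
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    quantization_config=bnb_config,
    attn_implementation="flash_attention_2",
    device_map="auto",
)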

@nivibilla
Author

Thanks! Will have a look

@nivibilla
Author

@artidoro I have checked, and limiting the seq length to 512 fixes it. Might be worth mentioning this somewhere in case others have the same issue. Closing this as my problem is fixed.

@nivibilla
Author

@artidoro for reference, I was able to go up to seq len 1024 with an effective batch size of 32 (actual batch size 4 and gradient accumulation 8, with gradient checkpointing enabled), all on a single 24GB GPU.
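
A rough sketch of the arguments that worked for me, with everything else as in my notebook above and the tokenization UDFs capped at max_length=1024:

from transformers import TrainingArguments

# sketch of the settings that fit on a single 24GB card for me:
# effective batch size 32 = 4 per device x 8 accumulation steps
args = TrainingArguments(
    per_device_train_batch_size = 4,
    gradient_accumulation_steps = 8,
    gradient_checkpointing = True,
    num_train_epochs = 10,
    learning_rate = 2e-5,
    fp16 = True,
    logging_steps = 1,
    output_dir = 'llama_2_13b_qlora_classifier_7k',
    optim = "paged_adamw_8bit"
)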

@AaronZLT

Could you please share where to adjust the seq len?
