
Adapting Unsloth's implementation for classification tasks #1461

Open
Gladiator07 opened this issue Dec 21, 2024 · 1 comment

@Gladiator07

Hi @danielhanchen,

I am experimenting with training causal models for classification tasks. HuggingFace's implementation (AutoModelForSequenceClassification) projects the last hidden states through a classification layer instead of the normal language-modelling head of the causal LM architecture. I am trying to see if we can stay in Unsloth and still fine-tune for classification.
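
For reference, here is a minimal sketch of the HuggingFace path I mean (the checkpoint name is just an example): the causal backbone is kept, and a small classification head projects the hidden state of the last non-padding token to `num_labels` logits instead of to the vocabulary.

```python
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer

model_name = "meta-llama/Llama-3.2-1B"  # example checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.pad_token = tokenizer.eos_token  # causal LMs usually ship without a pad token

model = AutoModelForSequenceClassification.from_pretrained(
    model_name,
    num_labels=2,  # classification head replaces the vocab-sized lm_head
)
model.config.pad_token_id = tokenizer.pad_token_id

inputs = tokenizer(["this movie was great"], return_tensors="pt", padding=True)
with torch.no_grad():
    logits = model(**inputs).logits  # shape: (batch_size, num_labels)
```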

So far I've tried this notebook: https://github.com/timothelaborie/text_classification_scripts/blob/main/unsloth_classification.ipynb
and it works great. The approach shown in the notebook effectively matches HuggingFace's AutoModelForSequenceClassification implementation for causal models. But when we want to save the model and actually run inference, Unsloth throws a size-mismatch error because we are trying to copy the weights of the lm_head that we modified in the notebook. There is a monkey-patch that avoids the error and gets it working, but it seems very ugly. Is there a cleaner way to save the model (LoRA adapters) and load it for inference with the modified lm_head? Native support for fine-tuning classification models would also be great if it's in the plans.
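
To make the question concrete, this is roughly the kind of workaround I have in mind, sketched with placeholder paths and a hypothetical two-label head (not Unsloth's official API): save the LoRA adapters and the replaced lm_head separately, then repeat the same head surgery on a fresh base model before attaching the adapters, so the shapes line up again.

```python
import torch
from peft import PeftModel
from unsloth import FastLanguageModel

NUM_LABELS = 2  # hypothetical number of classes

# --- after training: save adapters and the modified head separately ---
# `model` here is the trained model from the notebook, whose lm_head was
# already replaced with a num_labels-sized linear layer.
model.save_pretrained("lora_adapters")                # LoRA weights only
torch.save(model.lm_head.state_dict(), "lm_head.pt")  # the modified head

# --- at inference: rebuild the base model, redo the head surgery, then attach ---
base, tokenizer = FastLanguageModel.from_pretrained(
    model_name="unsloth/llama-3-8b-bnb-4bit",  # example base checkpoint
    max_seq_length=2048,
    load_in_4bit=True,
)
emb = base.get_input_embeddings().weight  # match the base model's device/dtype
base.lm_head = torch.nn.Linear(
    base.config.hidden_size, NUM_LABELS, bias=False
).to(emb.device, emb.dtype)
base.lm_head.load_state_dict(torch.load("lm_head.pt"))

model = PeftModel.from_pretrained(base, "lora_adapters")
FastLanguageModel.for_inference(model)
```

Something like this runs, but it still feels like a hack, hence the question about a cleaner save/load path or native support.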

Again thanks for all the work you do!

@shimmyshimmer
Collaborator

Interesting, we'll take a closer look!
