Commit 5af3d95

fix: gliner model (#715)
1 parent 80650cb commit 5af3d95

File tree

5 files changed: +68 -94 lines changed

label_studio_ml/examples/gliner/README.md

Lines changed: 13 additions & 0 deletions
@@ -85,3 +85,16 @@ The following common parameters are available:
 - `THREADS` - Specify the number of threads for the model server.
 - `LABEL_STUDIO_URL` - Specify the URL of your Label Studio instance. Note that this might need to be `http://host.docker.internal:8080` if you are running Label Studio on another Docker container.
 - `LABEL_STUDIO_API_KEY` - Specify the API key for authenticating your Label Studio instance. You can find this by logging into Label Studio and [going to the **Account & Settings** page](https://labelstud.io/guide/user_account#Access-token).
+
+## A Note on Model Training
+
+If you plan to use a webhook to train this model on "Start Training", note that you do
+not need to configure a separate webhook. Instead, go to the three dots next to your model
+on the Model tab in your project settings and click "Start Training".
+
+Additionally, note that this container has been set up for a **VERY SMALL** demo set, with only 1
+non-eval sample (we expect the first 10 data samples to be for evaluation).
+
+If you're working with a larger dataset, be sure to:
+1. Update `num_steps` and `batch_size` to the number of training steps and the batch size that work for your dataset.
+2. Change the model loaded after training (line 239 of `model.py`) to the highest checkpoint you have.
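
The two adjustments called out above correspond to the step/epoch arithmetic and the checkpoint reload that `fit()` in `model.py` now performs (shown later in this diff). A minimal sketch of that arithmetic; the second call uses purely hypothetical larger-dataset values:

```python
from math import floor

def plan_epochs(num_steps: int, batch_size: int, data_size: int) -> int:
    """Mirror the epoch arithmetic in fit(): the HF Trainer counts epochs, not raw steps."""
    num_batches = floor(data_size / batch_size)
    return max(1, floor(num_steps / num_batches))

# Demo values from this commit: 10 steps, batch size 1, one non-eval sample -> 10 epochs,
# which is why fit() reloads "models/checkpoint-10" after training.
print(plan_epochs(num_steps=10, batch_size=1, data_size=1))      # 10

# Hypothetical larger run: 1000 steps, batch size 8, 400 samples -> 20 epochs.
print(plan_epochs(num_steps=1000, batch_size=8, data_size=400))  # 20
```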

label_studio_ml/examples/gliner/docker-compose.yml

Lines changed: 0 additions & 1 deletion
@@ -29,7 +29,6 @@ services:
       # Determine the actual IP using 'ifconfig' (Linux/Mac) or 'ipconfig' (Windows).
       - LABEL_STUDIO_URL=http://host.docker.internal:8080
       - LABEL_STUDIO_API_KEY=
-
     ports:
       - "9090:9090"
     volumes:

label_studio_ml/examples/gliner/model.py

Lines changed: 51 additions & 90 deletions
@@ -1,16 +1,16 @@
 import logging
 import os
-from types import SimpleNamespace
+from math import floor
 from typing import List, Dict, Optional
 
 import label_studio_sdk
-import torch
 from gliner import GLiNER
+from gliner.data_processing.collator import DataCollator
+from gliner.training import Trainer, TrainingArguments
+from label_studio_sdk.label_interface.objects import PredictionValue
+
 from label_studio_ml.model import LabelStudioMLBase
 from label_studio_ml.response import ModelResponse
-from label_studio_sdk.label_interface.objects import PredictionValue
-from tqdm import tqdm
-from transformers import get_cosine_schedule_with_warmup
 
 logger = logging.getLogger(__name__)

@@ -109,7 +109,7 @@ def predict(self, tasks: List[Dict], context: Optional[Dict] = None, **kwargs) -
 
         # make predictions with currently set model
         from_name, to_name, value = self.label_interface.get_first_tag_occurence('Labels', 'Text')
-
+
         # get labels from the labeling configuration
         labels = sorted(self.label_interface.get_tag(from_name).labels)

@@ -141,7 +141,7 @@ def process_training_data(self, task):
             ner.append([start_token, end_token, label])
         return tokens, ner
 
-    def train(self, model, config, train_data, eval_data=None):
+    def train(self, model, training_args, train_data, eval_data=None):
         """
         retrain the GLiNER model. Code adapted from the GLiNER finetuning notebook.
         :param model: the model to train

@@ -150,69 +150,23 @@ def train(self, model, config, train_data, eval_data=None):
         :param eval_data: the eval data
         """
         logger.info("Training Model")
-        model = model.to(config.device)
-
-        # Set sampling parameters from config
-        model.set_sampling_params(
-            max_types=config.max_types,
-            shuffle_types=config.shuffle_types,
-            random_drop=config.random_drop,
-            max_neg_type_ratio=config.max_neg_type_ratio,
-            max_len=config.max_len
-        )
+        if training_args.use_cpu == True:
+            model = model.to('cpu')
+        else:
+            model = model.to("cuda")
 
-        model.train()
-        train_loader = model.create_dataloader(train_data, batch_size=config.train_batch_size, shuffle=True)
-        optimizer = model.get_optimizer(config.lr_encoder, config.lr_others, config.freeze_token_rep)
-        pbar = tqdm(range(config.num_steps))
-        num_warmup_steps = int(config.num_steps * config.warmup_ratio) if config.warmup_ratio < 1 else int(
-            config.warmup_ratio)
-        scheduler = get_cosine_schedule_with_warmup(optimizer, num_warmup_steps, config.num_steps)
-        iter_train_loader = iter(train_loader)
-
-        for step in pbar:
-            try:
-                x = next(iter_train_loader)
-            except StopIteration:
-                iter_train_loader = iter(train_loader)
-                x = next(iter_train_loader)
-
-            for k, v in x.items():
-                if isinstance(v, torch.Tensor):
-                    x[k] = v.to(config.device)
-
-            try:
-                loss = model(x)  # Forward pass
-            except RuntimeError as e:
-                print(f"Error during forward pass at step {step}: {e}")
-                print(f"x: {x}")
-                continue
-
-            if torch.isnan(loss):
-                print("Loss is NaN, skipping...")
-                continue
-
-            loss.backward()  # Compute gradients
-            optimizer.step()  # Update parameters
-            scheduler.step()  # Update learning rate schedule
-            optimizer.zero_grad()  # Reset gradients
-
-            description = f"step: {step} | epoch: {step // len(train_loader)} | loss: {loss.item():.2f}"
-            pbar.set_description(description)
-
-            if (step + 1) % config.eval_every == 0:
-                model.eval()
-                if eval_data:
-                    results, f1 = model.evaluate(eval_data["samples"], flat_ner=True, threshold=0.5, batch_size=12,
-                                                 entity_types=eval_data["entity_types"])
-                    print(f"Step={step}\n{results}")
-
-                if not os.path.exists(config.save_directory):
-                    os.makedirs(config.save_directory)
-
-                model.save_pretrained(f"{config.save_directory}/finetuned_{step}")
-                model.train()
+        data_collator = DataCollator(model.config, data_processor=model.data_processor, prepare_labels=True)
+
+        trainer = Trainer(
+            model=model,
+            args=training_args,
+            train_dataset=train_data,
+            eval_dataset=eval_data,
+            tokenizer=model.data_processor.transformer_tokenizer,
+            data_collator=data_collator,
+        )
 
+        trainer.train()
 
     def fit(self, event, data, **kwargs):
         """
@@ -250,31 +204,38 @@ def fit(self, event, data, **kwargs):
 
             # Define the hyperparameters in a config variable
             # This comes from the pretraining example in the GLiNER repo
-            config = SimpleNamespace(
-                num_steps=10000,  # number of training iteration
-                train_batch_size=2,
-                eval_every=1000,  # evaluation/saving steps
-                save_directory="logs",  # where to save checkpoints
-                warmup_ratio=0.1,  # warmup steps
-                device='cpu',
-                lr_encoder=1e-5,  # learning rate for the backbone
-                lr_others=5e-5,  # learning rate for other parameters
-                freeze_token_rep=False,  # freeze of not the backbone
-
-                # Parameters for set_sampling_params
-                max_types=25,  # maximum number of entity types during training
-                shuffle_types=True,  # if shuffle or not entity types
-                random_drop=True,  # randomly drop entity types
-                max_neg_type_ratio=1,
-                # ratio of positive/negative types, 1 mean 50%/50%, 2 mean 33%/66%, 3 mean 25%/75% ...
-                max_len=384  # maximum sentence length
+            num_steps = 10
+            batch_size = 1
+            data_size = len(training_data)
+            num_batches = floor(data_size / batch_size)
+            num_epochs = max(1, floor(num_steps / num_batches))
+
+            training_args = TrainingArguments(
+                output_dir="models",
+                learning_rate=5e-6,
+                weight_decay=0.01,
+                others_lr=1e-5,
+                others_weight_decay=0.01,
+                lr_scheduler_type="linear",  # cosine
+                warmup_ratio=0.1,
+                per_device_train_batch_size=batch_size,
+                per_device_eval_batch_size=batch_size,
+                focal_loss_alpha=0.75,
+                focal_loss_gamma=2,
+                num_train_epochs=num_epochs,
+                evaluation_strategy="steps",
+                save_steps=100,
+                save_total_limit=10,
+                dataloader_num_workers=0,
+                use_cpu=True,
+                report_to="none",
             )
 
-            self.train(self.model, config, training_data, eval_data)
+            self.train(self.model, training_args, training_data, eval_data)
 
             logger.info("Saving new fine-tuned model as the default model")
-            self.model = GLiNERModel.from_pretrained("finetuned", local_files_only=True)
-            model_version = self.model_version[-1] + 1
+            self.model = GLiNER.from_pretrained(f"models/checkpoint-10", local_files_only=True)
+            model_version = int(self.model_version[-1]) + 1
             self.set("model_version", f'{self.__class__.__name__}-v{model_version}')
         else:
-            logger.info("Model training not triggered")
+            logger.info("Model training not triggered")

label_studio_ml/examples/gliner/requirements.txt

Lines changed: 3 additions & 2 deletions

@@ -1,2 +1,3 @@
-gliner==0.1.12
-torch==2.2.0
+gliner==0.2.16
+torch==2.2.0
+accelerate>=0.26.0

label_studio_ml/examples/gliner/test_api.py

Lines changed: 1 addition & 1 deletion
@@ -56,7 +56,7 @@ def test_predict(client):
     }
 
     expected_response = {"results": [{"model_version": "GLiNERModel-v0.0.1", "result": [
-        {"from_name": "label", "score": 0.9220, "to_name": "text", "type": "labels",
+        {"from_name": "label", "score": 0.922, "to_name": "text", "type": "labels",
          "value": {"end": 11, "labels": ["Medication/Vaccine"], "start": 0, "text": "atomoxetine"}},
         {"from_name": "label", "score": 0.7053, "to_name": "text", "type": "labels",
         "value": {"end": 65, "labels": ["Medication/Vaccine"], "start": 32,
