Skip to content

Commit 3503694

Browse files
authored
Merge pull request #420 from roboflow/feature/epochs-configurable
2 parents 4310e56 + 2f8cd11 commit 3503694

File tree

3 files changed

+26
-13
lines changed

3 files changed

+26
-13
lines changed

roboflow/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
from roboflow.models import CLIPModel, GazeModel # noqa: F401
1616
from roboflow.util.general import write_line
1717

18-
__version__ = "1.2.10"
18+
__version__ = "1.2.11"
1919

2020

2121
def check_key(api_key, model, notebook, num_retries=0):

roboflow/adapters/rfapi.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import json
22
import os
33
import urllib
4-
from typing import List, Optional
4+
from typing import Dict, List, Optional, Union
55

66
import requests
77
from requests.exceptions import RequestException
@@ -58,6 +58,7 @@ def start_version_training(
5858
speed: Optional[str] = None,
5959
checkpoint: Optional[str] = None,
6060
model_type: Optional[str] = None,
61+
epochs: Optional[int] = None,
6162
):
6263
"""
6364
Start a training job for a specific version.
@@ -66,14 +67,16 @@ def start_version_training(
6667
"""
6768
url = f"{API_URL}/{workspace_url}/{project_url}/{version}/train?api_key={api_key}&nocache=true"
6869

69-
data = {}
70+
data: Dict[str, Union[str, int]] = {}
7071
if speed is not None:
7172
data["speed"] = speed
7273
if checkpoint is not None:
7374
data["checkpoint"] = checkpoint
7475
if model_type is not None:
7576
# API expects camelCase
7677
data["modelType"] = model_type
78+
if epochs is not None:
79+
data["epochs"] = epochs
7780

7881
response = requests.post(url, json=data)
7982
if not response.ok:

roboflow/core/version.py

Lines changed: 20 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -296,15 +296,18 @@ def export(self, model_format=None) -> bool | None:
296296
else:
297297
raise RuntimeError(f"Unexpected export {export_info}")
298298

299-
def train(self, speed=None, model_type=None, checkpoint=None, plot_in_notebook=False) -> InferenceModel:
299+
def train(
300+
self, speed=None, model_type=None, checkpoint=None, plot_in_notebook=False, epochs=None
301+
) -> InferenceModel:
300302
"""
301303
Ask the Roboflow API to train a previously exported version's dataset.
302304
303305
Args:
304306
speed: Whether to train quickly or accurately. Note: accurate training is a paid feature. Default speed is `fast`.
305307
model_type: The type of model to train. Default depends on kind of project. It takes precedence over speed. You can check the list of model ids by sending an invalid parameter in this argument.
306308
checkpoint: A string representing the checkpoint to use while training
307-
plot: Whether to plot the training results. Default is `False`.
309+
epochs: Number of epochs to train the model
310+
plot_in_notebook: Whether to plot the training results. Default is `False`.
308311
309312
Returns:
310313
An instance of the trained model class
@@ -336,6 +339,7 @@ def train(self, speed=None, model_type=None, checkpoint=None, plot_in_notebook=F
336339
speed=payload_speed,
337340
checkpoint=payload_checkpoint,
338341
model_type=payload_model_type,
342+
epochs=epochs,
339343
)
340344

341345
status = "training"
@@ -385,15 +389,15 @@ def live_plot(epochs, mAP, loss, title=""):
385389
write_line(line="Training failed")
386390
break
387391

388-
epochs: Union[np.ndarray, list]
392+
epoch_ids: Union[np.ndarray, list]
389393
mAP: Union[np.ndarray, list]
390394
loss: Union[np.ndarray, list]
391395

392396
if "roboflow-train" in models.keys():
393397
import numpy as np
394398

395399
# training has started
396-
epochs = np.array([int(epoch["epoch"]) for epoch in models["roboflow-train"]["epochs"]])
400+
epoch_ids = np.array([int(epoch["epoch"]) for epoch in models["roboflow-train"]["epochs"]])
397401
mAP = np.array([float(epoch["mAP"]) for epoch in models["roboflow-train"]["epochs"]])
398402
loss = np.array(
399403
[
@@ -410,23 +414,29 @@ def live_plot(epochs, mAP, loss, title=""):
410414
num_machine_spin_dots = ["."]
411415
title = "Training Machine Spinning Up" + "".join(num_machine_spin_dots)
412416

413-
epochs = []
417+
epoch_ids = []
414418
mAP = []
415419
loss = []
416420

417-
if (len(epochs) > len(previous_epochs)) or (len(epochs) == 0):
421+
if (len(epoch_ids) > len(previous_epochs)) or (len(epoch_ids) == 0):
418422
if plot_in_notebook:
419-
live_plot(epochs, mAP, loss, title)
423+
live_plot(epoch_ids, mAP, loss, title)
420424
else:
421-
if len(epochs) > 0:
425+
if len(epoch_ids) > 0:
422426
title = (
423-
title + ": Epoch: " + str(epochs[-1]) + " mAP: " + str(mAP[-1]) + " loss: " + str(loss[-1])
427+
title
428+
+ ": Epoch: "
429+
+ str(epoch_ids[-1])
430+
+ " mAP: "
431+
+ str(mAP[-1])
432+
+ " loss: "
433+
+ str(loss[-1])
424434
)
425435
if not first_graph_write:
426436
write_line(title)
427437
first_graph_write = True
428438

429-
previous_epochs = copy.deepcopy(epochs)
439+
previous_epochs = copy.deepcopy(epoch_ids)
430440

431441
time.sleep(5)
432442

0 commit comments

Comments
 (0)