Skip to content

Commit 873cb6a

Browse files
thiago-aixplainThiago Castro FerreiraThiago Castro FerreiraThiago Castro Ferreira
authored
M 6107719447 check finetuner status (#133)
* Finetune status object * New finetune status checker * Covering some None cases * Finetune status checker updates and new unit tests --------- Co-authored-by: Thiago Castro Ferreira <[email protected]> Co-authored-by: Thiago Castro Ferreira <[email protected]> Co-authored-by: Thiago Castro Ferreira <[email protected]>
1 parent b4e5b67 commit 873cb6a

File tree

8 files changed

+240
-13
lines changed

8 files changed

+240
-13
lines changed

aixplain/enums/asset_status.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
__author__ = "thiagocastroferreira"
2+
3+
"""
4+
Copyright 2024 The aiXplain SDK authors
5+
6+
Licensed under the Apache License, Version 2.0 (the "License");
7+
you may not use this file except in compliance with the License.
8+
You may obtain a copy of the License at
9+
10+
http://www.apache.org/licenses/LICENSE-2.0
11+
12+
Unless required by applicable law or agreed to in writing, software
13+
distributed under the License is distributed on an "AS IS" BASIS,
14+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15+
See the License for the specific language governing permissions and
16+
limitations under the License.
17+
18+
Author: Duraikrishna Selvaraju, Thiago Castro Ferreira, Shreyas Sharma and Lucas Pavanelli
19+
Date: February 21st 2024
20+
Description:
21+
Asset Enum
22+
"""
23+
24+
from enum import Enum
25+
from typing import Text
26+
27+
class AssetStatus(Text, Enum):
    """Lifecycle states shared by aiXplain assets (models, finetune jobs, etc.).

    Mixes in ``Text`` (``str``) so members compare and serialize as their
    plain string values, e.g. ``AssetStatus.ONBOARDED == "onboarded"``.
    """

    HIDDEN = "hidden"
    SCHEDULED = "scheduled"
    ONBOARDING = "onboarding"
    ONBOARDED = "onboarded"
    PENDING = "pending"
    FAILED = "failed"
    TRAINING = "training"
    REJECTED = "rejected"
    ENABLING = "enabling"
    DELETING = "deleting"
    DISABLED = "disabled"
    DELETED = "deleted"
    IN_PROGRESS = "in_progress"
    COMPLETED = "completed"
    CANCELING = "canceling"
    CANCELED = "canceled"

aixplain/modules/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,5 +29,6 @@
2929
from .model import Model
3030
from .pipeline import Pipeline
3131
from .finetune import Finetune, FinetuneCost
32+
from .finetune.status import FinetuneStatus
3233
from .benchmark import Benchmark
3334
from .benchmark_job import BenchmarkJob
Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
__author__ = "thiagocastroferreira"
2+
3+
"""
4+
Copyright 2024 The aiXplain SDK authors
5+
6+
Licensed under the Apache License, Version 2.0 (the "License");
7+
you may not use this file except in compliance with the License.
8+
You may obtain a copy of the License at
9+
10+
http://www.apache.org/licenses/LICENSE-2.0
11+
12+
Unless required by applicable law or agreed to in writing, software
13+
distributed under the License is distributed on an "AS IS" BASIS,
14+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15+
See the License for the specific language governing permissions and
16+
limitations under the License.
17+
18+
Author: Duraikrishna Selvaraju, Thiago Castro Ferreira, Shreyas Sharma and Lucas Pavanelli
19+
Date: February 21st 2024
20+
Description:
21+
FinetuneStatus Class
22+
"""
23+
24+
from aixplain.enums.asset_status import AssetStatus
25+
from dataclasses import dataclass
26+
from dataclasses_json import dataclass_json
27+
from typing import Optional, Text
28+
29+
@dataclass_json
@dataclass
class FinetuneStatus(object):
    """Snapshot of a FineTune job's progress, as returned by
    ``Model.check_finetune_status``.

    Serializable to/from JSON via ``dataclass_json``.
    """

    # Overall status of the finetuning job (an AssetStatus, e.g. IN_PROGRESS, COMPLETED).
    status: "AssetStatus"
    # Status of the resulting model asset (an AssetStatus, e.g. TRAINING, ONBOARDED).
    model_status: "AssetStatus"
    # Epoch the loss figures below refer to; None when no training logs are available yet.
    epoch: Optional[float] = None
    # Training loss reported at `epoch`; None when the backend did not report one.
    training_loss: Optional[float] = None
    # Validation ("eval") loss reported at `epoch`; None when the backend did not report one.
    validation_loss: Optional[float] = None

aixplain/modules/metric.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323

2424
from typing import Optional, Text, List, Union
2525
from aixplain.modules.asset import Asset
26+
2627
from aixplain.utils.file_utils import _request_with_retry
2728
# from aixplain.factories.model_factory import ModelFactory
2829

@@ -92,6 +93,7 @@ def run(self, hypothesis: Optional[Union[str, List[str]]]=None, source: Optional
9293
source (Optional[Union[str, List[str]]], optional): Can give a single source or a list of sources for metric calculation. Defaults to None.
9394
reference (Optional[Union[str, List[str]]], optional): Can give a single reference or a list of references for metric calculation. Defaults to None.
9495
"""
96+
from aixplain.factories.model_factory import ModelFactory
9597
model = ModelFactory.get(self.id)
9698
payload = {
9799
"function": self.function,

aixplain/modules/model.py

Lines changed: 47 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@
2020
Description:
2121
Model Class
2222
"""
23-
2423
import time
2524
import json
2625
import logging
@@ -251,23 +250,65 @@ def run_async(self, data: Union[Text, Dict], name: Text = "model_process", param
251250
response["error"] = msg
252251
return response
253252

254-
def check_finetune_status(self):
253+
def check_finetune_status(self, after_epoch: Optional[int] = None):
255254
"""Check the status of the FineTune model.
256255
256+
Args:
257+
after_epoch (Optional[int], optional): status after a given epoch. Defaults to None.
258+
257259
Raises:
258260
Exception: If the 'TEAM_API_KEY' is not provided.
259261
260262
Returns:
261-
str: The status of the FineTune model.
263+
FinetuneStatus: The status of the FineTune model.
262264
"""
265+
from aixplain.enums.asset_status import AssetStatus
266+
from aixplain.modules.finetune.status import FinetuneStatus
263267
headers = {"x-api-key": self.api_key, "Content-Type": "application/json"}
268+
resp = None
264269
try:
265-
url = urljoin(self.backend_url, f"sdk/models/{self.id}")
270+
url = urljoin(self.backend_url, f"sdk/finetune/{self.id}/ml-logs")
266271
logging.info(f"Start service for GET Check FineTune status Model - {url} - {headers}")
267272
r = _request_with_retry("get", url, headers=headers)
268273
resp = r.json()
269-
status = resp["status"]
270-
logging.info(f"Response for GET Check FineTune status Model - Id {self.id} / Status {status}.")
274+
finetune_status = AssetStatus(resp["finetuneStatus"])
275+
model_status = AssetStatus(resp["modelStatus"])
276+
logs = sorted(resp["logs"], key=lambda x: float(x["epoch"]))
277+
278+
target_epoch = None
279+
if after_epoch is not None:
280+
logs = [log for log in logs if float(log["epoch"]) > after_epoch]
281+
if len(logs) > 0:
282+
target_epoch = float(logs[0]["epoch"])
283+
elif len(logs) > 0:
284+
target_epoch = float(logs[-1]["epoch"])
285+
286+
if target_epoch is not None:
287+
log = None
288+
for log_ in logs:
289+
if int(log_["epoch"]) == target_epoch:
290+
if log is None:
291+
log = log_
292+
else:
293+
if log_["trainLoss"] is not None:
294+
log["trainLoss"] = log_["trainLoss"]
295+
if log_["evalLoss"] is not None:
296+
log["evalLoss"] = log_["evalLoss"]
297+
298+
status = FinetuneStatus(
299+
status=finetune_status,
300+
model_status=model_status,
301+
epoch=float(log["epoch"]) if "epoch" in log and log["epoch"] is not None else None,
302+
training_loss=float(log["trainLoss"]) if "trainLoss" in log and log["trainLoss"] is not None else None,
303+
validation_loss=float(log["evalLoss"]) if "evalLoss" in log and log["evalLoss"] is not None else None,
304+
)
305+
else:
306+
status = FinetuneStatus(
307+
status=finetune_status,
308+
model_status=model_status,
309+
)
310+
311+
logging.info(f"Response for GET Check FineTune status Model - Id {self.id} / Status {status.status.value}.")
271312
return status
272313
except Exception as e:
273314
message = ""

tests/unit/finetune_test.py

Lines changed: 21 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
from aixplain.modules import Model, Finetune
3030
from aixplain.modules.finetune import Hyperparameters
3131
from aixplain.enums import Function
32+
from urllib.parse import urljoin
3233

3334
import pytest
3435

@@ -37,6 +38,8 @@
3738
COST_ESTIMATION_FILE = "tests/unit/mock_responses/cost_estimation_response.json"
3839
FINETUNE_URL = f"{config.BACKEND_URL}/sdk/finetune"
3940
FINETUNE_FILE = "tests/unit/mock_responses/finetune_response.json"
41+
FINETUNE_STATUS_FILE = "tests/unit/mock_responses/finetune_status_response.json"
42+
FINETUNE_STATUS_FILE_2 = "tests/unit/mock_responses/finetune_status_response_2.json"
4043
PERCENTAGE_EXCEPTION_FILE = "tests/unit/data/create_finetune_percentage_exception.json"
4144
MODEL_FILE = "tests/unit/mock_responses/model_response.json"
4245
MODEL_URL = f"{config.BACKEND_URL}/sdk/models"
@@ -106,16 +109,27 @@ def test_start():
106109
assert fine_tuned_model is not None
107110
assert fine_tuned_model.id == model_map["id"]
108111

109-
110-
def test_check_finetuner_status():
111-
model_map = read_data(MODEL_FILE)
112+
@pytest.mark.parametrize(
    "input_path,after_epoch,training_loss,validation_loss",
    [
        # Latest epoch of the simple fixture (epoch 5 of finetune_status_response.json).
        (FINETUNE_STATUS_FILE, None, 0.4, 0.0217),
        # First epoch strictly after epoch 1 in the same fixture (epoch 2).
        (FINETUNE_STATUS_FILE, 1, 0.2, 0.0482),
        # Fixture 2 splits train/eval metrics across duplicate epoch entries; the
        # checker must merge them for the latest epoch (epoch 2).
        (FINETUNE_STATUS_FILE_2, None, 2.657801408034, 2.596168756485),
        # Epoch 1 of fixture 2 has only an eval loss; trainLoss is null -> None.
        (FINETUNE_STATUS_FILE_2, 0, None, 2.684150457382)
    ]
)
def test_check_finetuner_status(input_path, after_epoch, training_loss, validation_loss):
    """check_finetune_status returns a FinetuneStatus built from the ml-logs
    endpoint: job/model statuses plus the losses of the selected epoch."""
    model_map = read_data(input_path)
    asset_id = "test_id"
    with requests_mock.Mocker() as mock:
        test_model = Model(asset_id, "")
        # The status check hits the finetune ml-logs endpoint, not sdk/models.
        url = urljoin(config.BACKEND_URL, f"sdk/finetune/{asset_id}/ml-logs")
        mock.get(url, headers=FIXED_HEADER, json=model_map)
        status = test_model.check_finetune_status(after_epoch=after_epoch)
        assert status.status.value == model_map["finetuneStatus"]
        assert status.model_status.value == model_map["modelStatus"]
        assert status.training_loss == training_loss
        assert status.validation_loss == validation_loss
119133

120134

121135
@pytest.mark.parametrize("is_finetunable", [True, False])
@@ -132,4 +146,4 @@ def test_list_finetunable_models(is_finetunable):
132146
model_list = result["results"]
133147
assert len(model_list) > 0
134148
for model_index in range(len(model_list)):
135-
assert model_list[model_index].id == list_map["items"][model_index]["id"]
149+
assert model_list[model_index].id == list_map["items"][model_index]["id"]
Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
{
2+
"finetuneStatus": "onboarding",
3+
"modelStatus": "onboarded",
4+
"logs": [
5+
{
6+
"epoch": 1,
7+
"learningRate": 9.938725490196079e-05,
8+
"trainLoss": 0.1,
9+
"evalLoss": 0.1106,
10+
"step": 10
11+
},
12+
{
13+
"epoch": 2,
14+
"learningRate": 9.877450980392157e-05,
15+
"trainLoss": 0.2,
16+
"evalLoss": 0.0482,
17+
"step": 20
18+
},
19+
{
20+
"epoch": 3,
21+
"learningRate": 9.816176470588235e-05,
22+
"trainLoss": 0.3,
23+
"evalLoss": 0.0251,
24+
"step": 30
25+
},
26+
{
27+
"epoch": 4,
28+
"learningRate": 9.754901960784314e-05,
29+
"trainLoss": 0.9,
30+
"evalLoss": 0.0228,
31+
"step": 40
32+
},
33+
{
34+
"epoch": 5,
35+
"learningRate": 9.693627450980392e-05,
36+
"trainLoss": 0.4,
37+
"evalLoss": 0.0217,
38+
"step": 50
39+
}
40+
]
41+
}
Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
{
2+
"id": "65fb26268fe9153a6c9c29c4",
3+
"finetuneStatus": "in_progress",
4+
"modelStatus": "training",
5+
"logs": [
6+
{
7+
"epoch": 1,
8+
"learningRate": null,
9+
"trainLoss": null,
10+
"validationLoss": null,
11+
"step": null,
12+
"evalLoss": 2.684150457382,
13+
"totalFlos": null,
14+
"evalRuntime": 12.4129,
15+
"trainRuntime": null,
16+
"evalStepsPerSecond": 0.322,
17+
"trainStepsPerSecond": null,
18+
"evalSamplesPerSecond": 16.112
19+
},
20+
{
21+
"epoch": 2,
22+
"learningRate": null,
23+
"trainLoss": null,
24+
"validationLoss": null,
25+
"step": null,
26+
"evalLoss": 2.596168756485,
27+
"totalFlos": null,
28+
"evalRuntime": 11.8249,
29+
"trainRuntime": null,
30+
"evalStepsPerSecond": 0.338,
31+
"trainStepsPerSecond": null,
32+
"evalSamplesPerSecond": 16.913
33+
},
34+
{
35+
"epoch": 2,
36+
"learningRate": null,
37+
"trainLoss": 2.657801408034,
38+
"validationLoss": null,
39+
"step": null,
40+
"evalLoss": null,
41+
"totalFlos": 11893948284928,
42+
"evalRuntime": null,
43+
"trainRuntime": 221.7946,
44+
"evalStepsPerSecond": null,
45+
"trainStepsPerSecond": 0.117,
46+
"evalSamplesPerSecond": null
47+
}
48+
]
49+
}

0 commit comments

Comments
 (0)