from aixplain.modules import Model, Finetune
from aixplain.modules.finetune import Hyperparameters
from aixplain.enums import Function
+from urllib.parse import urljoin

import pytest

COST_ESTIMATION_FILE = "tests/unit/mock_responses/cost_estimation_response.json"
FINETUNE_URL = f"{config.BACKEND_URL}/sdk/finetune"
FINETUNE_FILE = "tests/unit/mock_responses/finetune_response.json"
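+# mocked finetune status payloads used by test_check_finetuner_status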
+FINETUNE_STATUS_FILE = "tests/unit/mock_responses/finetune_status_response.json"
+FINETUNE_STATUS_FILE_2 = "tests/unit/mock_responses/finetune_status_response_2.json"
PERCENTAGE_EXCEPTION_FILE = "tests/unit/data/create_finetune_percentage_exception.json"
MODEL_FILE = "tests/unit/mock_responses/model_response.json"
MODEL_URL = f"{config.BACKEND_URL}/sdk/models"
@@ -106,16 +109,27 @@ def test_start():
    assert fine_tuned_model is not None
    assert fine_tuned_model.id == model_map["id"]

-
-def test_check_finetuner_status():
-    model_map = read_data(MODEL_FILE)
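+# Each case: (mock response file, after_epoch filter, expected training loss, expected validation loss)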
+@pytest.mark.parametrize(
+    "input_path,after_epoch,training_loss,validation_loss",
+    [
+        (FINETUNE_STATUS_FILE, None, 0.4, 0.0217),
+        (FINETUNE_STATUS_FILE, 1, 0.2, 0.0482),
+        (FINETUNE_STATUS_FILE_2, None, 2.657801408034, 2.596168756485),
+        (FINETUNE_STATUS_FILE_2, 0, None, 2.684150457382)
+    ]
+)
+def test_check_finetuner_status(input_path, after_epoch, training_loss, validation_loss):
+    model_map = read_data(input_path)
    asset_id = "test_id"
    with requests_mock.Mocker() as mock:
        test_model = Model(asset_id, "")
-        url = f"{MODEL_URL}/{asset_id}"
+        url = urljoin(config.BACKEND_URL, f"sdk/finetune/{asset_id}/ml-logs")
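+        # GET requests to the ml-logs endpoint are answered with the mocked status payload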
        mock.get(url, headers=FIXED_HEADER, json=model_map)
-        status = test_model.check_finetune_status()
-        assert status == model_map["status"]
+        status = test_model.check_finetune_status(after_epoch=after_epoch)
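+        # the returned status object should mirror the mocked finetuneStatus/modelStatus values and the expected losses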
+        assert status.status.value == model_map["finetuneStatus"]
+        assert status.model_status.value == model_map["modelStatus"]
+        assert status.training_loss == training_loss
+        assert status.validation_loss == validation_loss


@pytest.mark.parametrize("is_finetunable", [True, False])
@@ -132,4 +146,4 @@ def test_list_finetunable_models(is_finetunable):
        model_list = result["results"]
        assert len(model_list) > 0
        for model_index in range(len(model_list)):
-            assert model_list[model_index].id == list_map["items"][model_index]["id"]
+            assert model_list[model_index].id == list_map["items"][model_index]["id"]