
Commit ecbcad4

ikxplain, lucas-aixplain, and thiago-aixplain authored
Merge dev to test (#107) (#108)
* Create bounds for FineTune hyperparameters (#103)
* Test bound to hyperparameters
* Update finetune llm hyperparameters
* Remove option to use PEFT; it is now always enabled
* Fixing pipeline general asset test (#106)

---------

Co-authored-by: Lucas Pavanelli <[email protected]>
Co-authored-by: Thiago Castro Ferreira <[email protected]>
1 parent 92d1870 commit ecbcad4

File tree

8 files changed (+48, -42 lines)


aixplain/factories/finetune_factory/__init__.py

Lines changed: 0 additions & 6 deletions
@@ -28,7 +28,6 @@
 from aixplain.modules.finetune import Finetune
 from aixplain.modules.finetune.cost import FinetuneCost
 from aixplain.modules.finetune.hyperparameters import Hyperparameters
-from aixplain.modules.finetune.peft import Peft
 from aixplain.modules.dataset import Dataset
 from aixplain.modules.model import Model
 from aixplain.utils import config
@@ -66,7 +65,6 @@ def create(
         model: Model,
         prompt_template: Optional[Text] = None,
         hyperparameters: Optional[Hyperparameters] = None,
-        peft: Optional[Peft] = None,
         train_percentage: Optional[float] = 100,
         dev_percentage: Optional[float] = 0,
     ) -> Finetune:
@@ -78,7 +76,6 @@ def create(
             model (Model): Model to be fine-tuned.
             prompt_template (Text, optional): Fine-tuning prompt_template. Should reference columns in the dataset using format <<COLUMN_NAME>>. Defaults to None.
             hyperparameters (Hyperparameters, optional): Hyperparameters for fine-tuning. Defaults to None.
-            peft (Peft, optional): PEFT (Parameter-Efficient Fine-Tuning) configuration. Defaults to None.
             train_percentage (float, optional): Percentage of training samples. Defaults to 100.
             dev_percentage (float, optional): Percentage of development samples. Defaults to 0.
         Returns:
@@ -106,8 +103,6 @@ def create(
             parameters["prompt"] = prompt_template
         if hyperparameters is not None:
             parameters["hyperparameters"] = hyperparameters.to_dict()
-        if peft is not None:
-            parameters["peft"] = peft.to_dict()
         payload["parameters"] = parameters
         logging.info(f"Start service for POST Create FineTune - {url} - {headers} - {json.dumps(payload)}")
         r = _request_with_retry("post", url, headers=headers, json=payload)
@@ -123,7 +118,6 @@ def create(
                 dev_percentage=dev_percentage,
                 prompt_template=prompt_template,
                 hyperparameters=hyperparameters,
-                peft=peft,
             )
         except Exception:
             error_message = f"Create FineTune: Error with payload {json.dumps(payload)}"
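
With the peft parameter removed, fine-tuning is requested through create() alone and PEFT is applied by the backend unconditionally. A minimal sketch of the new call path, reusing the "llama 2 7b" fixture id from the test data below; the DatasetFactory.list(query=...) lookup is an assumption modeled on the other factories, not something this diff shows:

# Sketch only: ids are the test fixtures from this commit; the name is hypothetical.
from aixplain.factories.finetune_factory import FinetuneFactory
from aixplain.factories.model_factory import ModelFactory
from aixplain.factories.dataset_factory import DatasetFactory
from aixplain.modules.finetune.hyperparameters import Hyperparameters

model = ModelFactory.get("6543cb991f695e72028e9428")  # "llama 2 7b" test fixture
dataset = DatasetFactory.list(query="Test text generation dataset")["results"][0]  # assumed lookup

finetune = FinetuneFactory.create(
    name="llama2-finetune-demo",  # hypothetical name
    dataset_list=[dataset],
    model=model,
    hyperparameters=Hyperparameters(epochs=2),
    # peft=...  -> passing this now raises TypeError: unexpected keyword argument
)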

aixplain/modules/finetune/__init__.py

Lines changed: 0 additions & 7 deletions
@@ -26,7 +26,6 @@
 from urllib.parse import urljoin
 from aixplain.modules.finetune.cost import FinetuneCost
 from aixplain.modules.finetune.hyperparameters import Hyperparameters
-from aixplain.modules.finetune.peft import Peft
 from aixplain.factories.model_factory import ModelFactory
 from aixplain.modules.asset import Asset
 from aixplain.modules.dataset import Dataset
@@ -52,7 +51,6 @@ class Finetune(Asset):
         dev_percentage (float): Percentage of development samples.
         prompt_template (Text): Fine-tuning prompt_template.
         hyperparameters (Hyperparameters): Hyperparameters for fine-tuning.
-        peft (Peft): PEFT (Parameter-Efficient Fine-Tuning) configuration.
         additional_info (dict): Additional information to be saved with the FineTune.
         backend_url (str): URL of the backend.
         api_key (str): The TEAM API key used for authentication.
@@ -72,7 +70,6 @@ def __init__(
         dev_percentage: Optional[float] = 0,
         prompt_template: Optional[Text] = None,
         hyperparameters: Optional[Hyperparameters] = None,
-        peft: Optional[Peft] = None,
         **additional_info,
     ) -> None:
         """Create a FineTune with the necessary information.
@@ -90,7 +87,6 @@ def __init__(
             dev_percentage (float, optional): Percentage of development samples. Defaults to 0.
             prompt_template (Text, optional): Fine-tuning prompt_template. Should reference columns in the dataset using format <<COLUMN_NAME>>. Defaults to None.
             hyperparameters (Hyperparameters, optional): Hyperparameters for fine-tuning. Defaults to None.
-            peft (Peft, optional): PEFT (Parameter-Efficient Fine-Tuning) configuration. Defaults to None.
             **additional_info: Additional information to be saved with the FineTune.
         """
         super().__init__(id, name, description, supplier, version)
@@ -101,7 +97,6 @@ def __init__(
         self.dev_percentage = dev_percentage
         self.prompt_template = prompt_template
         self.hyperparameters = hyperparameters
-        self.peft = peft
         self.additional_info = additional_info
         self.backend_url = config.BACKEND_URL
         self.api_key = config.TEAM_API_KEY
@@ -134,8 +129,6 @@ def start(self) -> Model:
             parameters["prompt"] = self.prompt_template
         if self.hyperparameters is not None:
             parameters["hyperparameters"] = self.hyperparameters.to_dict()
-        if self.peft is not None:
-            parameters["peft"] = self.peft.to_dict()
         payload["parameters"] = parameters
         logging.info(f"Start service for POST Start FineTune - {url} - {headers} - {json.dumps(payload)}")
         r = _request_with_retry("post", url, headers=headers, json=payload)
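
Finetune.start() mirrors the factory change: the parameters payload now carries at most "prompt" and "hyperparameters". Continuing the create() sketch above (the check_finetune_status polling helper is an assumption about the surrounding Model API, not shown in this diff):

# Sketch: what start() builds and sends, per the diff above.
parameters = {}
if finetune.prompt_template is not None:
    parameters["prompt"] = finetune.prompt_template
if finetune.hyperparameters is not None:
    parameters["hyperparameters"] = finetune.hyperparameters.to_dict()
# No "peft" branch anymore -- the key never appears in the payload.

finetune_model = finetune.start()                # returns a Model handle
status = finetune_model.check_finetune_status()  # assumed polling helper
print(status)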
aixplain/modules/finetune/hyperparameters.py

Lines changed: 41 additions & 9 deletions

@@ -1,8 +1,10 @@
 from dataclasses import dataclass
 from dataclasses_json import dataclass_json
+from enum import Enum
+from typing import Text


-class SchedulerType:
+class SchedulerType(Text, Enum):
     LINEAR = "linear"
     COSINE = "cosine"
     COSINE_WITH_RESTARTS = "cosine_with_restarts"
@@ -13,19 +15,49 @@ class SchedulerType:
     REDUCE_ON_PLATEAU = "reduce_lr_on_plateau"


+EPOCHS_MAX_VALUE = 4
+MAX_SEQ_LENGTH_MAX_VALUE = 4096
+GENERATION_MAX_LENGTH_MAX_VALUE = 225
+
+
 @dataclass_json
 @dataclass
 class Hyperparameters(object):
-    epochs: int = 4
-    train_batch_size: int = 4
-    eval_batch_size: int = 4
-    learning_rate: float = 2e-5
+    epochs: int = 1
+    learning_rate: float = 1e-5
     generation_max_length: int = 225
-    tokenizer_batch_size: int = 256
-    gradient_checkpointing: bool = False
-    gradient_accumulation_steps: int = 1
     max_seq_length: int = 4096
     warmup_ratio: float = 0.0
     warmup_steps: int = 0
-    early_stopping_patience: int = 1
     lr_scheduler_type: SchedulerType = SchedulerType.LINEAR
+
+    def __post_init__(self):
+        if not isinstance(self.epochs, int):
+            raise TypeError("epochs should be of type int")
+
+        if not isinstance(self.learning_rate, float):
+            raise TypeError("learning_rate should be of type float")
+
+        if not isinstance(self.generation_max_length, int):
+            raise TypeError("generation_max_length should be of type int")
+
+        if not isinstance(self.max_seq_length, int):
+            raise TypeError("max_seq_length should be of type int")
+
+        if not isinstance(self.warmup_ratio, float):
+            raise TypeError("warmup_ratio should be of type float")
+
+        if not isinstance(self.warmup_steps, int):
+            raise TypeError("warmup_steps should be of type int")
+
+        if not isinstance(self.lr_scheduler_type, SchedulerType):
+            raise TypeError("lr_scheduler_type should be of type SchedulerType")
+
+        if self.epochs > EPOCHS_MAX_VALUE:
+            raise ValueError(f"epochs must be at most {EPOCHS_MAX_VALUE}")
+
+        if self.max_seq_length > MAX_SEQ_LENGTH_MAX_VALUE:
+            raise ValueError(f"max_seq_length must be at most {MAX_SEQ_LENGTH_MAX_VALUE}")
+
+        if self.generation_max_length > GENERATION_MAX_LENGTH_MAX_VALUE:
+            raise ValueError(f"generation_max_length must be at most {GENERATION_MAX_LENGTH_MAX_VALUE}")
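
The new __post_init__ rejects out-of-range or mistyped values at construction time, so a bad configuration fails on the client before any request reaches the service. A short sketch of the observable behavior; bounds match the constants above:

from aixplain.modules.finetune.hyperparameters import Hyperparameters, SchedulerType

hp = Hyperparameters(epochs=4, lr_scheduler_type=SchedulerType.COSINE)  # ok: 4 == EPOCHS_MAX_VALUE

try:
    Hyperparameters(epochs=5)         # over the bound of 4
except ValueError as err:
    print(err)

try:
    Hyperparameters(learning_rate=1)  # int where a float is required
except TypeError as err:
    print(err)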

aixplain/modules/finetune/peft.py

Lines changed: 0 additions & 10 deletions
This file was deleted.
Lines changed: 0 additions & 6 deletions
@@ -1,11 +1,5 @@
 [
     {"model_name": "gpt2", "model_id": "64e615671567f848804985e1", "dataset_name": "Test text generation dataset"},
-    {"model_name": "falcon 7b instruct", "model_id": "65519d57bf42e6037ab109d5", "dataset_name": "Test text generation dataset"},
-    {"model_name": "bloomz 7b", "model_id": "6551ab17bf42e6037ab109e0", "dataset_name": "Test text generation dataset"},
-    {"model_name": "MPT 7B", "model_id": "6551a72bbf42e6037ab109d9", "dataset_name": "Test text generation dataset"},
-    {"model_name": "falcon 7b", "model_id": "6551bff9bf42e6037ab109e1", "dataset_name": "Test text generation dataset"},
-    {"model_name": "mistral 7b", "model_id": "6551a9e7bf42e6037ab109de", "dataset_name": "Test text generation dataset"},
-    {"model_name": "MPT 7B Storywriter", "model_id": "6551a870bf42e6037ab109db", "dataset_name": "Test text generation dataset"},
     {"model_name": "llama 2 7b", "model_id": "6543cb991f695e72028e9428", "dataset_name": "Test text generation dataset"},
     {"model_name": "Llama 2 7B Chat", "model_id": "65519ee7bf42e6037ab109d8", "dataset_name": "Test text generation dataset"}
 ]

tests/functional/finetune/data/finetune_test_end2end.json

Lines changed: 2 additions & 2 deletions
@@ -1,7 +1,7 @@
 [
     {
-        "model_name": "gpt2",
-        "model_id": "64e615671567f848804985e1",
+        "model_name": "llama2 7b",
+        "model_id": "6543cb991f695e72028e9428",
         "dataset_name": "Test text generation dataset",
         "inference_data": "Hello!",
         "required_dev": true

tests/functional/general_assets/asset_functional_test.py

Lines changed: 4 additions & 1 deletion
@@ -40,7 +40,10 @@ def test_list(asset_name):
 def test_run(inputs, asset_name):
     asset_details = inputs[asset_name]
     AssetFactory = __get_asset_factory(asset_name)
-    asset = AssetFactory.get(asset_details["id"])
+    if asset_name == "pipeline":
+        asset = AssetFactory.list(query=asset_details["name"])["results"][0]
+    else:
+        asset = AssetFactory.get(asset_details["id"])
     payload = asset_details["data"]
     if type(payload) is dict:
         output = asset.run(**payload)

tests/functional/general_assets/data/asset_run_test_data.json

Lines changed: 1 addition & 1 deletion
@@ -4,7 +4,7 @@
         "data": "This is a test sentence."
     },
     "pipeline": {
-        "id" : "64da138fa27cffd5e0c3c30d",
+        "name": "SingleNodePipeline",
         "data": "This is a test sentence."
     },
     "metric": {
