Skip to content

Commit f668332

Browse files
M 5629918191 model sort (#100)
* init work * fix spelling * Model sorting methods * Fix on functional test --------- Co-authored-by: MAlyafeai18 <[email protected]>
1 parent 6d4946a commit f668332

File tree

5 files changed

+94
-2
lines changed

5 files changed

+94
-2
lines changed

aixplain/enums/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,3 +11,5 @@
1111
from .privacy import Privacy
1212
from .storage_type import StorageType
1313
from .supplier import Supplier
14+
from .sort_by import SortBy
15+
from .sort_order import SortOrder

aixplain/enums/sort_by.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
__author__ = "aiXplain"
2+
3+
"""
4+
Copyright 2023 The aiXplain SDK authors
5+
6+
Licensed under the Apache License, Version 2.0 (the "License");
7+
you may not use this file except in compliance with the License.
8+
You may obtain a copy of the License at
9+
10+
http://www.apache.org/licenses/LICENSE-2.0
11+
12+
Unless required by applicable law or agreed to in writing, software
13+
distributed under the License is distributed on an "AS IS" BASIS,
14+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15+
See the License for the specific language governing permissions and
16+
limitations under the License.
17+
18+
Author: aiXplain team
19+
Date: March 20th 2023
20+
Description:
21+
Sort By Enum
22+
"""
23+
24+
from enum import Enum
25+
26+
27+
class SortBy(Enum):
28+
CREATION_DATE = "createdAt"
29+
PRICE = "normalizedPrice"
30+
POPULARITY = "totalSubscribed"

aixplain/enums/sort_order.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
__author__ = "aiXplain"
2+
3+
"""
4+
Copyright 2023 The aiXplain SDK authors
5+
6+
Licensed under the Apache License, Version 2.0 (the "License");
7+
you may not use this file except in compliance with the License.
8+
You may obtain a copy of the License at
9+
10+
http://www.apache.org/licenses/LICENSE-2.0
11+
12+
Unless required by applicable law or agreed to in writing, software
13+
distributed under the License is distributed on an "AS IS" BASIS,
14+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15+
See the License for the specific language governing permissions and
16+
limitations under the License.
17+
18+
Author: aiXplain team
19+
Date: March 20th 2023
20+
Description:
21+
Sort By Enum
22+
"""
23+
24+
from enum import Enum
25+
26+
27+
class SortOrder(Enum):
28+
ASCENDING = 1
29+
DESCENDING = -1

aixplain/factories/model_factory.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
import json
2525
import logging
2626
from aixplain.modules.model import Model
27-
from aixplain.enums import Function, Language, OwnershipType, Supplier
27+
from aixplain.enums import Function, Language, OwnershipType, Supplier, SortBy, SortOrder
2828
from aixplain.utils import config
2929
from aixplain.utils.file_utils import _request_with_retry
3030
from urllib.parse import urljoin
@@ -130,6 +130,8 @@ def _get_assets_from_page(
130130
target_languages: Union[Language, List[Language]],
131131
is_finetunable: bool = None,
132132
ownership: Optional[Tuple[OwnershipType, List[OwnershipType]]] = None,
133+
sort_by: Optional[SortBy] = None,
134+
sort_order: SortOrder = SortOrder.ASCENDING,
133135
) -> List[Model]:
134136
try:
135137
url = urljoin(cls.backend_url, f"sdk/models/paginate")
@@ -146,6 +148,7 @@ def _get_assets_from_page(
146148
if isinstance(ownership, OwnershipType) is True:
147149
ownership = [ownership]
148150
filter_params["ownership"] = [ownership_.value for ownership_ in ownership]
151+
149152
lang_filter_params = []
150153
if source_languages is not None:
151154
if isinstance(source_languages, Language):
@@ -162,6 +165,8 @@ def _get_assets_from_page(
162165
if function == Function.TRANSLATION:
163166
code = "targetlanguage"
164167
lang_filter_params.append({"code": code, "value": target_languages[0].value["language"]})
168+
if sort_by is not None:
169+
filter_params["sort"] = [{"dir": sort_order.value, "field": sort_by.value}]
165170
if len(lang_filter_params) != 0:
166171
filter_params["ioFilter"] = lang_filter_params
167172
if cls.aixplain_key != "":
@@ -191,6 +196,8 @@ def list(
191196
target_languages: Optional[Union[Language, List[Language]]] = None,
192197
is_finetunable: Optional[bool] = None,
193198
ownership: Optional[Tuple[OwnershipType, List[OwnershipType]]] = None,
199+
sort_by: Optional[SortBy] = None,
200+
sort_order: SortOrder = SortOrder.ASCENDING,
194201
page_number: int = 0,
195202
page_size: int = 20,
196203
) -> List[Model]:
@@ -202,6 +209,7 @@ def list(
202209
target_languages (Optional[Union[Language, List[Language]]], optional): language filter of output data. Defaults to None.
203210
is_finetunable (Optional[bool], optional): can be finetuned or not. Defaults to None.
204211
ownership (Optional[Tuple[OwnershipType, List[OwnershipType]]], optional): Ownership filters (e.g. SUBSCRIBED, OWNER). Defaults to None.
212+
sort_by (Optional[SortBy], optional): sort the retrived models by a specific attribute,
205213
page_number (int, optional): page number. Defaults to 0.
206214
page_size (int, optional): page size. Defaults to 20.
207215
@@ -219,6 +227,8 @@ def list(
219227
target_languages,
220228
is_finetunable,
221229
ownership,
230+
sort_by,
231+
sort_order,
222232
)
223233
return {
224234
"results": models,

tests/functional/general_assets/asset_functional_test.py

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
load_dotenv()
55
from aixplain.factories import ModelFactory, DatasetFactory, MetricFactory, PipelineFactory
66
from pathlib import Path
7-
from aixplain.enums import Function, OwnershipType, Supplier
7+
from aixplain.enums import Function, Language, OwnershipType, Supplier, SortBy, SortOrder
88

99
import pytest
1010

@@ -63,6 +63,27 @@ def test_model_supplier():
6363
assert model.supplier.value in [desired_supplier.value for desired_supplier in desired_suppliers]
6464

6565

66+
def test_model_sort():
67+
function = Function.TRANSLATION
68+
src_language = Language.Portuguese
69+
trg_language = Language.English
70+
71+
models = ModelFactory.list(
72+
function=function,
73+
source_languages=src_language,
74+
target_languages=trg_language,
75+
sort_by=SortBy.PRICE,
76+
sort_order=SortOrder.DESCENDING,
77+
)["results"]
78+
for idx in range(1, len(models)):
79+
prev_model = models[idx - 1]
80+
model = models[idx]
81+
82+
prev_model_price = prev_model.additional_info["pricing"]["price"]
83+
model_price = model.additional_info["pricing"]["price"]
84+
assert prev_model_price >= model_price
85+
86+
6687
def test_model_ownership():
6788
models = ModelFactory.list(ownership=OwnershipType.SUBSCRIBED)["results"]
6889
for model in models:

0 commit comments

Comments
 (0)