Skip to content

Commit 8df2bd0

Browse files
thiago-aixplainThiago Castro FerreiraThiago Castro Ferreira
authored
Saving/Update Pipeline Services (#154)
* Saving/Update Pipeline Services * Docstrings for new functions * Pipeline creation unit test --------- Co-authored-by: Thiago Castro Ferreira <[email protected]> Co-authored-by: Thiago Castro Ferreira <[email protected]>
1 parent 451f309 commit 8df2bd0

File tree

5 files changed

+273
-1
lines changed

5 files changed

+273
-1
lines changed

aixplain/factories/pipeline_factory.py

Lines changed: 39 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
"""
2323
import json
2424
import logging
25+
import os
2526
from typing import Dict, List, Optional, Text, Union
2627
from aixplain.enums.data_type import DataType
2728
from aixplain.enums.function import Function
@@ -207,7 +208,7 @@ def list(
207208
output_data_types = [output_data_types]
208209
payload["inputDataTypes"] = [data_type.value for data_type in output_data_types]
209210

210-
logging.info(f"Start service for POST List Dataset - {url} - {headers} - {json.dumps(payload)}")
211+
logging.info(f"Start service for POST List Pipeline - {url} - {headers} - {json.dumps(payload)}")
211212
r = _request_with_retry("post", url, headers=headers, json=payload)
212213
resp = r.json()
213214

@@ -220,3 +221,40 @@ def list(
220221
for pipeline in results:
221222
pipelines.append(cls.__from_response(pipeline))
222223
return {"results": pipelines, "page_total": page_total, "page_number": page_number, "total": total}
224+
225+
@classmethod
226+
def create(cls, name: Text, pipeline: Union[Text, Dict], status: Text = "draft") -> Pipeline:
227+
"""Pipeline Creation
228+
229+
Args:
230+
name (Text): Pipeline Name
231+
pipeline (Union[Text, Dict]): Pipeline as a Python dictionary or in a JSON file
232+
status (Text, optional): Status of the pipeline. Currently only draft pipelines can be saved. Defaults to "draft".
233+
234+
Raises:
235+
Exception: Currently just the creation of draft pipelines are supported
236+
237+
Returns:
238+
Pipeline: instance of the new pipeline
239+
"""
240+
try:
241+
assert status == "draft", "Pipeline Creation Error: Currently just the creation of draft pipelines are supported."
242+
if isinstance(pipeline, str) is True:
243+
_, ext = os.path.splitext(pipeline)
244+
assert (
245+
os.path.exists(pipeline) and ext == ".json"
246+
), "Pipeline Creation Error: Make sure the pipeline to be save is in a JSON file."
247+
with open(pipeline) as f:
248+
pipeline = json.load(f)
249+
250+
# prepare payload
251+
payload = {"name": name, "status": "draft", "architecture": pipeline}
252+
url = urljoin(cls.backend_url, "sdk/pipelines")
253+
headers = {"Authorization": f"Token {config.TEAM_API_KEY}", "Content-Type": "application/json"}
254+
logging.info(f"Start service for POST Create Pipeline - {url} - {headers} - {json.dumps(payload)}")
255+
r = _request_with_retry("post", url, headers=headers, json=payload)
256+
response = r.json()
257+
258+
return Pipeline(response["id"], name, config.TEAM_API_KEY)
259+
except Exception as e:
260+
raise Exception(e)

aixplain/modules/pipeline.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,11 +23,13 @@
2323

2424
import time
2525
import json
26+
import os
2627
import logging
2728
from aixplain.modules.asset import Asset
2829
from aixplain.utils import config
2930
from aixplain.utils.file_utils import _request_with_retry
3031
from typing import Dict, Optional, Text, Union
32+
from urllib.parse import urljoin
3133

3234

3335
class Pipeline(Asset):
@@ -306,3 +308,32 @@ def run_async(
306308
if resp is not None:
307309
response["error"] = resp
308310
return response
311+
312+
def update(self, pipeline: Union[Text, Dict]):
313+
"""Update Pipeline
314+
315+
Args:
316+
pipeline (Union[Text, Dict]): Pipeline as a Python dictionary or in a JSON file
317+
318+
Raises:
319+
Exception: Make sure the pipeline to be save is in a JSON file.
320+
"""
321+
try:
322+
if isinstance(pipeline, str) is True:
323+
_, ext = os.path.splitext(pipeline)
324+
assert (
325+
os.path.exists(pipeline) and ext == ".json"
326+
), "Pipeline Update Error: Make sure the pipeline to be save is in a JSON file."
327+
with open(pipeline) as f:
328+
pipeline = json.load(f)
329+
330+
# prepare payload
331+
payload = {"name": self.name, "status": "draft", "architecture": pipeline}
332+
url = urljoin(config.BACKEND_URL, f"sdk/pipelines/{self.id}")
333+
headers = {"Authorization": f"Token {config.TEAM_API_KEY}", "Content-Type": "application/json"}
334+
logging.info(f"Start service for PUT Update Pipeline - {url} - {headers} - {json.dumps(payload)}")
335+
r = _request_with_retry("put", url, headers=headers, json=payload)
336+
response = r.json()
337+
logging.info(f"Pipeline {response['id']} Updated.")
338+
except Exception as e:
339+
raise Exception(e)
Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
__author__ = "thiagocastroferreira"
2+
3+
"""
4+
Copyright 2022 The aiXplain SDK authors
5+
6+
Licensed under the Apache License, Version 2.0 (the "License");
7+
you may not use this file except in compliance with the License.
8+
You may obtain a copy of the License at
9+
10+
http://www.apache.org/licenses/LICENSE-2.0
11+
12+
Unless required by applicable law or agreed to in writing, software
13+
distributed under the License is distributed on an "AS IS" BASIS,
14+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15+
See the License for the specific language governing permissions and
16+
limitations under the License.
17+
"""
18+
19+
import json
20+
import pytest
21+
from aixplain.factories import PipelineFactory
22+
from aixplain.modules import Pipeline
23+
from uuid import uuid4
24+
25+
26+
def test_create_pipeline_from_json():
27+
pipeline_json = "tests/functional/pipelines/data/pipeline.json"
28+
pipeline_name = str(uuid4())
29+
pipeline = PipelineFactory.create(name=pipeline_name, pipeline=pipeline_json)
30+
31+
assert isinstance(pipeline, Pipeline)
32+
assert pipeline.id != ""
33+
34+
35+
def test_create_pipeline_from_string():
36+
pipeline_json = "tests/functional/pipelines/data/pipeline.json"
37+
with open(pipeline_json) as f:
38+
pipeline_dict = json.load(f)
39+
40+
pipeline_name = str(uuid4())
41+
pipeline = PipelineFactory.create(name=pipeline_name, pipeline=pipeline_dict)
42+
43+
assert isinstance(pipeline, Pipeline)
44+
assert pipeline.id != ""
45+
46+
47+
def test_update_pipeline():
48+
pipeline_json = "tests/functional/pipelines/data/pipeline.json"
49+
with open(pipeline_json) as f:
50+
pipeline_dict = json.load(f)
51+
52+
pipeline_name = str(uuid4())
53+
pipeline = PipelineFactory.create(name=pipeline_name, pipeline=pipeline_dict)
54+
55+
pipeline.update(pipeline=pipeline_json)
56+
assert isinstance(pipeline, Pipeline)
57+
assert pipeline.id != ""
58+
59+
60+
def test_create_pipeline_wrong_path():
61+
pipeline_name = str(uuid4())
62+
63+
with pytest.raises(Exception):
64+
pipeline = PipelineFactory.create(name=pipeline_name, pipeline="/")
Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,100 @@
1+
{
2+
"links": [
3+
{
4+
"from": 0,
5+
"to": 1,
6+
"paramMapping": [
7+
{
8+
"from": "input",
9+
"to": "text"
10+
}
11+
]
12+
},
13+
{
14+
"from": 1,
15+
"to": 2,
16+
"paramMapping": [
17+
{
18+
"from": "data",
19+
"to": "text"
20+
}
21+
]
22+
},
23+
{
24+
"from": 2,
25+
"to": 3,
26+
"paramMapping": [
27+
{
28+
"from": "data",
29+
"to": "output"
30+
}
31+
]
32+
}
33+
],
34+
"nodes": [
35+
{
36+
"number": 0,
37+
"type": "INPUT"
38+
},
39+
{
40+
"number": 1,
41+
"type": "ASSET",
42+
"function": "sentiment-analysis",
43+
"inputValues": [
44+
{
45+
"code": "language",
46+
"value": "en"
47+
},
48+
{
49+
"code": "text",
50+
"dataType": "text"
51+
}
52+
],
53+
"assetId": "6172874f720b09325cbcdc33",
54+
"assetType": "MODEL",
55+
"autoSelectOptions": [],
56+
"functionType": "AI",
57+
"status": "Exists",
58+
"outputValues": [
59+
{
60+
"code": "data",
61+
"dataType": "label"
62+
}
63+
]
64+
},
65+
{
66+
"number": 2,
67+
"type": "ASSET",
68+
"function": "translation",
69+
"inputValues": [
70+
{
71+
"code": "sourcelanguage",
72+
"value": "en"
73+
},
74+
{
75+
"code": "targetlanguage",
76+
"value": "es"
77+
},
78+
{
79+
"code": "text",
80+
"dataType": "text"
81+
}
82+
],
83+
"assetId": "61b097551efecf30109d3316",
84+
"assetType": "MODEL",
85+
"autoSelectOptions": [],
86+
"functionType": "AI",
87+
"status": "Exists",
88+
"outputValues": [
89+
{
90+
"code": "data",
91+
"dataType": "text"
92+
}
93+
]
94+
},
95+
{
96+
"number": 3,
97+
"type": "OUTPUT"
98+
}
99+
]
100+
}

tests/unit/pipeline_test.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
__author__ = "thiagocastroferreira"
2+
3+
"""
4+
Copyright 2022 The aiXplain SDK authors
5+
6+
Licensed under the Apache License, Version 2.0 (the "License");
7+
you may not use this file except in compliance with the License.
8+
You may obtain a copy of the License at
9+
10+
http://www.apache.org/licenses/LICENSE-2.0
11+
12+
Unless required by applicable law or agreed to in writing, software
13+
distributed under the License is distributed on an "AS IS" BASIS,
14+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15+
See the License for the specific language governing permissions and
16+
limitations under the License.
17+
"""
18+
19+
from dotenv import load_dotenv
20+
21+
load_dotenv()
22+
import requests_mock
23+
from aixplain.utils import config
24+
from aixplain.factories import PipelineFactory
25+
from aixplain.modules import Pipeline
26+
from urllib.parse import urljoin
27+
import pytest
28+
29+
30+
def test_create_pipeline():
31+
with requests_mock.Mocker() as mock:
32+
url = urljoin(config.BACKEND_URL, "sdk/pipelines")
33+
headers = {"x-api-key": config.TEAM_API_KEY, "Content-Type": "application/json"}
34+
ref_response = {"id": "12345"}
35+
mock.post(url, headers=headers, json=ref_response)
36+
ref_pipeline = Pipeline(id="12345", name="Pipeline Test", api_key=config.TEAM_API_KEY)
37+
hyp_pipeline = PipelineFactory.create(pipeline={}, name="Pipeline Test")
38+
assert hyp_pipeline.id == ref_pipeline.id
39+
assert hyp_pipeline.name == ref_pipeline.name

0 commit comments

Comments
 (0)