|
| 1 | +import os |
| 2 | +from contextlib import redirect_stdout |
| 3 | +from io import StringIO |
| 4 | + |
| 5 | +import pytest |
| 6 | +from lightning_sdk.lightning_cloud.rest_client import GridRestClient |
| 7 | +from lightning_sdk.utils.resolve import _resolve_teamspace |
| 8 | +from litmodels import download_model, upload_model |
| 9 | + |
| 10 | +LIT_ORG = "lightning-ai" |
| 11 | +LIT_TEAMSPACE = "LitModels" |
| 12 | + |
| 13 | + |
| 14 | +@pytest.mark.cloud() |
| 15 | +def test_upload_download_model(tmp_path): |
| 16 | + """Verify that the model is uploaded to the teamspace""" |
| 17 | + # create a dummy file |
| 18 | + file_path = tmp_path / "dummy.txt" |
| 19 | + with open(file_path, "w") as f: |
| 20 | + f.write("dummy") |
| 21 | + |
| 22 | + # model name with random hash |
| 23 | + model_name = f"litmodels_test_integrations+{os.urandom(8).hex()}" |
| 24 | + teamspace = _resolve_teamspace(org=LIT_ORG, teamspace=LIT_TEAMSPACE, user=None) |
| 25 | + org_team = f"{teamspace.owner.name}/{teamspace.name}" |
| 26 | + |
| 27 | + out = StringIO() |
| 28 | + with redirect_stdout(out): |
| 29 | + upload_model(name=f"{org_team}/{model_name}", model=file_path) |
| 30 | + |
| 31 | + # validate the output |
| 32 | + assert ( |
| 33 | + f"Model uploaded successfully. Link to the model: 'https://lightning.ai/{org_team}/models/{model_name}'" |
| 34 | + ) in out.getvalue() |
| 35 | + |
| 36 | + os.remove(file_path) |
| 37 | + assert not os.path.isfile(file_path) |
| 38 | + |
| 39 | + model_files = download_model(name=f"{org_team}/{model_name}", download_dir=tmp_path) |
| 40 | + assert model_files == ["dummy.txt"] |
| 41 | + for file in model_files: |
| 42 | + assert os.path.isfile(os.path.join(tmp_path, file)) |
| 43 | + |
| 44 | + client = GridRestClient() |
| 45 | + # cleaning created models with todo: also consider how to delete just this version of the model |
| 46 | + model = client.models_store_get_model_by_name( |
| 47 | + project_owner_name=teamspace.owner.name, |
| 48 | + project_name=teamspace.name, |
| 49 | + model_name=model_name, |
| 50 | + ) |
| 51 | + client.models_store_delete_model(project_id=teamspace.id, model_id=model.id) |
0 commit comments