Skip to content

Commit b0a2c77

Browse files
authored
adding load_model for pickled models (#38)
* adding `load_model` * linter
1 parent a70865a commit b0a2c77

File tree

5 files changed

+51
-11
lines changed

5 files changed

+51
-11
lines changed

README.md

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -72,16 +72,13 @@ upload_model(model=model, name=MY_MODEL_NAME)
7272
```python
7373
import os
7474
import joblib
75-
from litmodels import download_model
75+
from litmodels import load_model
7676

7777
# Unique model identifier: <organization>/<teamspace>/<model-name>
7878
MY_MODEL_NAME = "your_org/your_team/sklearn-svm-model"
7979

80-
# Download the model file from cloud storage
81-
model_path = download_model(name=MY_MODEL_NAME, download_dir="my_models")
82-
83-
# Load the model for inference using joblib
84-
model = joblib.load(os.path.join("my_models", model_path[0]))
80+
# Download and load the model file from cloud storage
81+
model = load_model(name=MY_MODEL_NAME, download_dir="my_models")
8582

8683
# Example: run inference with the loaded model
8784
sample_input = [[5.1, 3.5, 1.4, 0.2]]

src/litmodels/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,6 @@
77
_PACKAGE_ROOT = os.path.dirname(__file__)
88
_PROJECT_ROOT = os.path.dirname(_PACKAGE_ROOT)
99

10-
from litmodels.io import download_model, upload_model
10+
from litmodels.io import download_model, load_model, upload_model
1111

12-
__all__ = ["download_model", "upload_model"]
12+
__all__ = ["download_model", "upload_model", "load_model"]

src/litmodels/io/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
"""Root package for Input/output."""
22

33
from litmodels.io.cloud import download_model_files, upload_model_files
4-
from litmodels.io.gateway import download_model, upload_model
4+
from litmodels.io.gateway import download_model, load_model, upload_model
55

6-
__all__ = ["download_model", "upload_model", "download_model_files", "upload_model_files"]
6+
__all__ = ["download_model", "upload_model", "download_model_files", "upload_model_files", "load_model"]

src/litmodels/io/gateway.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,3 +87,26 @@ def download_model(
8787
download_dir=download_dir,
8888
progress_bar=progress_bar,
8989
)
90+
91+
92+
def load_model(name: str, download_dir: str = ".") -> Any:
93+
"""Download a model from the model store and load it into memory.
94+
95+
Args:
96+
name: Name of the model to download. Must be in the format 'organization/teamspace/modelname'
97+
where entity is either your username or the name of an organization you are part of.
98+
download_dir: A path to directory where the model should be downloaded. Defaults
99+
to the current working directory.
100+
101+
Returns:
102+
The loaded model.
103+
"""
104+
download_paths = download_model(name=name, download_dir=download_dir)
105+
# filter out all Markdown, TXT and RST files
106+
download_paths = [p for p in download_paths if Path(p).suffix.lower() not in {".md", ".txt", ".rst"}]
107+
if len(download_paths) > 1:
108+
raise NotImplementedError("Downloaded model with multiple files is not supported yet.")
109+
model_path = Path(os.path.join(download_dir, download_paths[0]))
110+
if model_path.suffix.lower() == ".pkl":
111+
return joblib.load(model_path)
112+
raise NotImplementedError(f"Loading model from {model_path.suffix} is not supported yet.")

tests/test_io_cloud.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
import os
22
from unittest import mock
33

4+
import joblib
45
import pytest
5-
from litmodels import download_model, upload_model
6+
from litmodels import download_model, load_model, upload_model
67
from litmodels.io import upload_model_files
78
from sklearn import svm
89
from torch.nn import Module
@@ -56,3 +57,22 @@ def test_download_model(mock_download_model):
5657
mock_download_model.assert_called_once_with(
5758
name="org-name/teamspace/model-name", download_dir="where/to/download", progress_bar=True
5859
)
60+
61+
62+
@mock.patch("litmodels.io.cloud.sdk_download_model")
63+
def test_load_model(mock_download_model, tmp_path):
64+
# create a dummy model file
65+
model_file = tmp_path / "dummy_model.pkl"
66+
test_data = svm.SVC()
67+
joblib.dump(test_data, model_file)
68+
mock_download_model.return_value = [str(model_file.name)]
69+
70+
# The lit-logger function is just a wrapper around the SDK function
71+
model = load_model(
72+
name="org-name/teamspace/model-name",
73+
download_dir=str(tmp_path),
74+
)
75+
mock_download_model.assert_called_once_with(
76+
name="org-name/teamspace/model-name", download_dir=str(tmp_path), progress_bar=True
77+
)
78+
assert isinstance(model, svm.SVC)

0 commit comments

Comments
 (0)