Skip to content

Commit

Permalink
feat(synthesizer): Support sampling on existing Synthesizers (#80)
Browse files Browse the repository at this point in the history
---------

Co-authored-by: Azory YData Bot <[email protected]>
  • Loading branch information
portellaa and azory-ydata authored Dec 14, 2023
1 parent 4e970da commit a30d8d7
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 34 deletions.
26 changes: 15 additions & 11 deletions src/ydata/sdk/common/client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,13 +56,12 @@ def __init__(self, credentials: Optional[Union[str, Dict]] = None, project: Opti

self._handshake()

self._project = project if project is not None else self._get_default_project(
credentials)
self.project = project
self._default_project = project or self._get_default_project(credentials)
if set_as_global:
self.__set_global()

def post(self, endpoint: str, data: Optional[Dict] = None, json: Optional[Dict] = None, files: Optional[Dict] = None, raise_for_status: bool = True) -> Response:
def post(self, endpoint: str, data: Optional[Dict] = None, json: Optional[Dict] = None,
project: Project | None = None, files: Optional[Dict] = None, raise_for_status: bool = True) -> Response:
"""POST request to the backend.
Args:
Expand All @@ -75,15 +74,17 @@ def post(self, endpoint: str, data: Optional[Dict] = None, json: Optional[Dict]
Returns:
Response object
"""
url_data = self.__build_url(endpoint, data=data, json=json, files=files)
url_data = self.__build_url(
endpoint, data=data, json=json, files=files, project=project)
response = self._http_client.post(**url_data)

if response.status_code != Client.codes.OK and raise_for_status:
self.__raise_for_status(response)

return response

def get(self, endpoint: str, params: Optional[Dict] = None, cookies: Optional[Dict] = None, raise_for_status: bool = True) -> Response:
def get(self, endpoint: str, params: Optional[Dict] = None,
project: Project | None = None, cookies: Optional[Dict] = None, raise_for_status: bool = True) -> Response:
"""GET request to the backend.
Args:
Expand All @@ -94,15 +95,16 @@ def get(self, endpoint: str, params: Optional[Dict] = None, cookies: Optional[Di
Returns:
Response object
"""
url_data = self.__build_url(endpoint, params=params, cookies=cookies)
url_data = self.__build_url(endpoint, params=params,
cookies=cookies, project=project)
response = self._http_client.get(**url_data)

if response.status_code != Client.codes.OK and raise_for_status:
self.__raise_for_status(response)

return response

def get_static_file(self, endpoint: str, raise_for_status: bool = True) -> Response:
def get_static_file(self, endpoint: str, project: Project | None = None, raise_for_status: bool = True) -> Response:
"""Retrieve a static file from the backend.
Args:
Expand All @@ -112,7 +114,7 @@ def get_static_file(self, endpoint: str, raise_for_status: bool = True) -> Respo
Returns:
Response object
"""
url_data = self.__build_url(endpoint)
url_data = self.__build_url(endpoint, project=project)
url_data['url'] = f'{self._scheme}://{self._base_url}/static-content{endpoint}'
response = self._http_client.get(**url_data)

Expand All @@ -138,7 +140,9 @@ def _get_default_project(self, token: str):
data: Dict = response.json()
return data['myWorkspace']

def __build_url(self, endpoint: str, params: Optional[Dict] = None, data: Optional[Dict] = None, json: Optional[Dict] = None, files: Optional[Dict] = None, cookies: Optional[Dict] = None) -> Dict:
def __build_url(self, endpoint: str, params: Optional[Dict] = None, data: Optional[Dict] = None,
json: Optional[Dict] = None, project: Project | None = None, files: Optional[Dict] = None,
cookies: Optional[Dict] = None) -> Dict:
"""Build a request for the backend.
Args:
Expand All @@ -153,7 +157,7 @@ def __build_url(self, endpoint: str, params: Optional[Dict] = None, data: Option
dictionary containing the information to perform a request
"""
_params = params if params is not None else {
'ns': self._project
'ns': project or self._default_project
}

url_data = {
Expand Down
42 changes: 19 additions & 23 deletions src/ydata/sdk/synthesizers/synthesizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,9 @@
from ydata.sdk.common.client.utils import init_client
from ydata.sdk.common.config import BACKOFF, LOG_LEVEL
from ydata.sdk.common.exceptions import (AlreadyFittedError, DataSourceAttrsError, DataSourceNotAvailableError,
DataTypeMissingError, EmptyDataError, FittingError)
DataTypeMissingError, EmptyDataError, FittingError, InputError)
from ydata.sdk.common.logger import create_logger
from ydata.sdk.common.types import UID
from ydata.sdk.common.types import UID, Project
from ydata.sdk.common.warnings import DataSourceTypeWarning
from ydata.sdk.datasources import DataSource, LocalDataSource
from ydata.sdk.datasources._models.attributes import DataSourceAttrs
Expand Down Expand Up @@ -51,9 +51,11 @@ class BaseSynthesizer(ABC, ModelFactoryMixin):
client (Client): (optional) Client to connect to the backend
"""

def __init__(self, name: str | None = None, client: Client | None = None):
def __init__(self, uid: UID | None = None, name: str | None = None, project: Project | None = None, client: Client | None = None):
self._init_common(client=client)
self._model = mSynthesizer(name=name or str(uuid4()))
self._model = mSynthesizer(uid=uid, name=name or str(
uuid4())) if uid or project else None
self.__project = project

@init_client
def _init_common(self, client: Optional[Client] = None):
Expand Down Expand Up @@ -239,7 +241,8 @@ def _fit_from_datasource(
if condition_on is not None:
payload["extraData"]["condition_on"] = condition_on

response = self._client.post('/synthesizer/', json=payload)
response = self._client.post(
'/synthesizer/', json=payload, project=self.__project)
data: list = response.json()
self._model, _ = self._model_from_api(X.datatype, data)
while self.status not in [Status.READY, Status.FAILED]:
Expand Down Expand Up @@ -271,21 +274,22 @@ def _sample(self, payload: Dict) -> pdDataFrame:
pandas `DataFrame`
"""
response = self._client.post(
f"/synthesizer/{self.uid}/sample", json=payload)
f"/synthesizer/{self.uid}/sample", json=payload, project=self.__project)

data: Dict = response.json()
sample_uid = data.get('uid')
sample_status = None
while sample_status not in ['finished', 'failed']:
self._logger.info('Sampling from the synthesizer...')
response = self._client.get(f'/synthesizer/{self.uid}/history')
response = self._client.get(
f'/synthesizer/{self.uid}/history', project=self.__project)
history: Dict = response.json()
sample_data = next((s for s in history if s.get('uid') == sample_uid), None)
sample_status = sample_data.get('status', {}).get('state')
sleep(BACKOFF)

response = self._client.get_static_file(
f'/synthesizer/{self.uid}/sample/{sample_uid}/sample.csv')
f'/synthesizer/{self.uid}/sample/{sample_uid}/sample.csv', project=self.__project)
data = StringIO(response.content.decode())
return read_csv(data)

Expand Down Expand Up @@ -317,23 +321,15 @@ def status(self) -> Status:
except Exception: # noqa: PIE786
return Status.UNKNOWN

@staticmethod
@init_client
def get(uid: str, client: Optional[Client] = None) -> "BaseSynthesizer":
"""List the synthesizer instances.
def get(self):
assert self._is_initialized() and self._model.uid, InputError(
"Please provide the synthesizer `uid`")

Arguments:
uid (str): synthesizer instance uid
client (Client): (optional) Client to connect to the backend
response = self._client.get(f'/synthesizer/{self.uid}', project=self.__project)
data = filter_dict(mSynthesizer, response.json())
self._model = mSynthesizer(**data)

Returns:
Synthesizer instance
"""
response = client.get(f'/synthesizer/{uid}')
data: list = response.json()
model, synth_cls = BaseSynthesizer._model_from_api(
data['dataSource']['dataType'], data)
return ModelFactoryMixin._init_from_model_data(synth_cls, model)
return self

@staticmethod
@init_client
Expand Down

0 comments on commit a30d8d7

Please sign in to comment.