diff --git a/src/ydata/sdk/common/client/client.py b/src/ydata/sdk/common/client/client.py index 11574899..a3a3a258 100644 --- a/src/ydata/sdk/common/client/client.py +++ b/src/ydata/sdk/common/client/client.py @@ -56,13 +56,12 @@ def __init__(self, credentials: Optional[Union[str, Dict]] = None, project: Opti self._handshake() - self._project = project if project is not None else self._get_default_project( - credentials) - self.project = project + self._default_project = project or self._get_default_project(credentials) if set_as_global: self.__set_global() - def post(self, endpoint: str, data: Optional[Dict] = None, json: Optional[Dict] = None, files: Optional[Dict] = None, raise_for_status: bool = True) -> Response: + def post(self, endpoint: str, data: Optional[Dict] = None, json: Optional[Dict] = None, + project: Project | None = None, files: Optional[Dict] = None, raise_for_status: bool = True) -> Response: """POST request to the backend. Args: @@ -75,7 +74,8 @@ def post(self, endpoint: str, data: Optional[Dict] = None, json: Optional[Dict] Returns: Response object """ - url_data = self.__build_url(endpoint, data=data, json=json, files=files) + url_data = self.__build_url( + endpoint, data=data, json=json, files=files, project=project) response = self._http_client.post(**url_data) if response.status_code != Client.codes.OK and raise_for_status: @@ -83,7 +83,8 @@ def post(self, endpoint: str, data: Optional[Dict] = None, json: Optional[Dict] return response - def get(self, endpoint: str, params: Optional[Dict] = None, cookies: Optional[Dict] = None, raise_for_status: bool = True) -> Response: + def get(self, endpoint: str, params: Optional[Dict] = None, + project: Project | None = None, cookies: Optional[Dict] = None, raise_for_status: bool = True) -> Response: """GET request to the backend. Args: @@ -94,7 +95,8 @@ def get(self, endpoint: str, params: Optional[Dict] = None, cookies: Optional[Di Returns: Response object """ - url_data = self.__build_url(endpoint, params=params, cookies=cookies) + url_data = self.__build_url(endpoint, params=params, + cookies=cookies, project=project) response = self._http_client.get(**url_data) if response.status_code != Client.codes.OK and raise_for_status: @@ -102,7 +104,7 @@ def get(self, endpoint: str, params: Optional[Dict] = None, cookies: Optional[Di return response - def get_static_file(self, endpoint: str, raise_for_status: bool = True) -> Response: + def get_static_file(self, endpoint: str, project: Project | None = None, raise_for_status: bool = True) -> Response: """Retrieve a static file from the backend. Args: @@ -112,7 +114,7 @@ def get_static_file(self, endpoint: str, raise_for_status: bool = True) -> Respo Returns: Response object """ - url_data = self.__build_url(endpoint) + url_data = self.__build_url(endpoint, project=project) url_data['url'] = f'{self._scheme}://{self._base_url}/static-content{endpoint}' response = self._http_client.get(**url_data) @@ -138,7 +140,9 @@ def _get_default_project(self, token: str): data: Dict = response.json() return data['myWorkspace'] - def __build_url(self, endpoint: str, params: Optional[Dict] = None, data: Optional[Dict] = None, json: Optional[Dict] = None, files: Optional[Dict] = None, cookies: Optional[Dict] = None) -> Dict: + def __build_url(self, endpoint: str, params: Optional[Dict] = None, data: Optional[Dict] = None, + json: Optional[Dict] = None, project: Project | None = None, files: Optional[Dict] = None, + cookies: Optional[Dict] = None) -> Dict: """Build a request for the backend. Args: @@ -153,7 +157,7 @@ def __build_url(self, endpoint: str, params: Optional[Dict] = None, data: Option dictionary containing the information to perform a request """ _params = params if params is not None else { - 'ns': self._project + 'ns': project or self._default_project } url_data = { diff --git a/src/ydata/sdk/synthesizers/synthesizer.py b/src/ydata/sdk/synthesizers/synthesizer.py index 737d1557..14ce5e75 100644 --- a/src/ydata/sdk/synthesizers/synthesizer.py +++ b/src/ydata/sdk/synthesizers/synthesizer.py @@ -14,9 +14,9 @@ from ydata.sdk.common.client.utils import init_client from ydata.sdk.common.config import BACKOFF, LOG_LEVEL from ydata.sdk.common.exceptions import (AlreadyFittedError, DataSourceAttrsError, DataSourceNotAvailableError, - DataTypeMissingError, EmptyDataError, FittingError) + DataTypeMissingError, EmptyDataError, FittingError, InputError) from ydata.sdk.common.logger import create_logger -from ydata.sdk.common.types import UID +from ydata.sdk.common.types import UID, Project from ydata.sdk.common.warnings import DataSourceTypeWarning from ydata.sdk.datasources import DataSource, LocalDataSource from ydata.sdk.datasources._models.attributes import DataSourceAttrs @@ -51,9 +51,11 @@ class BaseSynthesizer(ABC, ModelFactoryMixin): client (Client): (optional) Client to connect to the backend """ - def __init__(self, name: str | None = None, client: Client | None = None): + def __init__(self, uid: UID | None = None, name: str | None = None, project: Project | None = None, client: Client | None = None): self._init_common(client=client) - self._model = mSynthesizer(name=name or str(uuid4())) + self._model = mSynthesizer(uid=uid, name=name or str( + uuid4())) if uid or project else None + self.__project = project @init_client def _init_common(self, client: Optional[Client] = None): @@ -239,7 +241,8 @@ def _fit_from_datasource( if condition_on is not None: payload["extraData"]["condition_on"] = condition_on - response = self._client.post('/synthesizer/', json=payload) + response = self._client.post( + '/synthesizer/', json=payload, project=self.__project) data: list = response.json() self._model, _ = self._model_from_api(X.datatype, data) while self.status not in [Status.READY, Status.FAILED]: @@ -271,21 +274,22 @@ def _sample(self, payload: Dict) -> pdDataFrame: pandas `DataFrame` """ response = self._client.post( - f"/synthesizer/{self.uid}/sample", json=payload) + f"/synthesizer/{self.uid}/sample", json=payload, project=self.__project) data: Dict = response.json() sample_uid = data.get('uid') sample_status = None while sample_status not in ['finished', 'failed']: self._logger.info('Sampling from the synthesizer...') - response = self._client.get(f'/synthesizer/{self.uid}/history') + response = self._client.get( + f'/synthesizer/{self.uid}/history', project=self.__project) history: Dict = response.json() sample_data = next((s for s in history if s.get('uid') == sample_uid), None) sample_status = sample_data.get('status', {}).get('state') sleep(BACKOFF) response = self._client.get_static_file( - f'/synthesizer/{self.uid}/sample/{sample_uid}/sample.csv') + f'/synthesizer/{self.uid}/sample/{sample_uid}/sample.csv', project=self.__project) data = StringIO(response.content.decode()) return read_csv(data) @@ -317,23 +321,15 @@ def status(self) -> Status: except Exception: # noqa: PIE786 return Status.UNKNOWN - @staticmethod - @init_client - def get(uid: str, client: Optional[Client] = None) -> "BaseSynthesizer": - """List the synthesizer instances. + def get(self): + assert self._is_initialized() and self._model.uid, InputError( + "Please provide the synthesizer `uid`") - Arguments: - uid (str): synthesizer instance uid - client (Client): (optional) Client to connect to the backend + response = self._client.get(f'/synthesizer/{self.uid}', project=self.__project) + data = filter_dict(mSynthesizer, response.json()) + self._model = mSynthesizer(**data) - Returns: - Synthesizer instance - """ - response = client.get(f'/synthesizer/{uid}') - data: list = response.json() - model, synth_cls = BaseSynthesizer._model_from_api( - data['dataSource']['dataType'], data) - return ModelFactoryMixin._init_from_model_data(synth_cls, model) + return self @staticmethod @init_client