From 93a78a91049a7a0b37728fcab21b7d5c96232296 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Lu=C3=ADs=20Portela=20Afonso?= Date: Thu, 11 Jan 2024 20:09:41 +0000 Subject: [PATCH 01/11] feat(synthesizer): Support for MultiTable --- Makefile | 7 +- .../sdk/datasources/_models/datasource.py | 19 ++-- src/ydata/sdk/datasources/_models/status.py | 5 +- src/ydata/sdk/datasources/datasource.py | 22 ++--- src/ydata/sdk/synthesizers/__init__.py | 3 +- src/ydata/sdk/synthesizers/_models/status.py | 64 ++++++++----- .../sdk/synthesizers/_models/synthesizer.py | 15 ++- src/ydata/sdk/synthesizers/multitable.py | 72 ++++++++++++++ src/ydata/sdk/synthesizers/synthesizer.py | 94 +++++++++++-------- 9 files changed, 203 insertions(+), 98 deletions(-) create mode 100644 src/ydata/sdk/synthesizers/multitable.py diff --git a/Makefile b/Makefile index de4c632a..07ce061f 100644 --- a/Makefile +++ b/Makefile @@ -68,13 +68,16 @@ package: ### Builds the package in wheel format echo "$(version)" > src/ydata/sdk/VERSION stubgen src/ydata/sdk -o src --export-less $(PYTHON) -m build --wheel - twine check dist/* + $(PYTHON) -m twine check dist/* wheel: ### Compiles the wheel test -d wheels || mkdir -p wheels cp dist/ydata_sdk-$(version)-py3-none-any.whl wheels/ydata_sdk-$(version)-py$(PYV)-none-any.whl $(PYTHON) -m pyc_wheel wheels/ydata_sdk-$(version)-py$(PYV)-none-any.whl - twine check wheels/* + $(PYTHON) -m twine check wheels/* + +upload: + $(PYTHON) -m twine upload -r ydata wheels/ydata_sdk-$(version)-py310-none-any.whl publish-docs: ### Publishes the documentation mike deploy --push --update-aliases $(version) latest diff --git a/src/ydata/sdk/datasources/_models/datasource.py b/src/ydata/sdk/datasources/_models/datasource.py index d505f597..9f82d261 100644 --- a/src/ydata/sdk/datasources/_models/datasource.py +++ b/src/ydata/sdk/datasources/_models/datasource.py @@ -16,22 +16,21 @@ class DataSource: datatype: Optional[DataSourceType] = None metadata: Optional[Metadata] = None status: 
Optional[Status] = None - state: Optional[State] = None def __post_init__(self): if self.metadata is not None: self.metadata = Metadata(**self.metadata) - if self.state is not None: - data = { - 'validation': self.state.get('validation', {}).get('state', 'unknown'), - 'metadata': self.state.get('metadata', {}).get('state', 'unknown'), - 'profiling': self.state.get('profiling', {}).get('state', 'unknown') - } - self.state = State.parse_obj(data) + # if self.state is not None: + # data = { + # 'validation': self.state.get('validation', {}).get('state', 'unknown'), + # 'metadata': self.state.get('metadata', {}).get('state', 'unknown'), + # 'profiling': self.state.get('profiling', {}).get('state', 'unknown') + # } + # self.state = State.parse_obj(data) - if self.status is not None: - self.status = Status(self.status) + # if self.status is not None: + # self.status = Status(self.status) def to_payload(self): return {} diff --git a/src/ydata/sdk/datasources/_models/status.py b/src/ydata/sdk/datasources/_models/status.py index 59ceb004..07739780 100644 --- a/src/ydata/sdk/datasources/_models/status.py +++ b/src/ydata/sdk/datasources/_models/status.py @@ -27,7 +27,7 @@ class ProfilingState(StringEnum): AVAILABLE = 'available' -class Status(StringEnum): +class State(StringEnum): """Represent the status of a [`DataSource`][ydata.sdk.datasources.datasource.DataSource].""" AVAILABLE = 'available' @@ -59,7 +59,8 @@ class Status(StringEnum): """ -class State(BaseModel): +class Status(BaseModel): + state: State validation: ValidationState metadata: MetadataState profiling: ProfilingState diff --git a/src/ydata/sdk/datasources/datasource.py b/src/ydata/sdk/datasources/datasource.py index db2a4a63..f28cc869 100644 --- a/src/ydata/sdk/datasources/datasource.py +++ b/src/ydata/sdk/datasources/datasource.py @@ -174,20 +174,20 @@ def _wait_for_metadata(datasource): sleep(BACKOFF) return datasource - @staticmethod - def _resolve_api_status(api_status: Dict) -> Status: - status = 
Status(api_status.get('state', Status.UNKNOWN.name)) - validation = ValidationState(api_status.get('validation', {}).get( - 'state', ValidationState.UNKNOWN.name)) - if validation == ValidationState.FAILED: - status = Status.FAILED - return status + # @staticmethod + # def _resolve_api_status(api_status: Dict) -> Status: + # status = Status(api_status.get('state', Status.UNKNOWN.name)) + # validation = ValidationState(api_status.get('validation', {}).get( + # 'state', ValidationState.UNKNOWN.name)) + # if validation == ValidationState.FAILED: + # status = .FAILED + # return status @staticmethod def _model_from_api(data: Dict, datasource_type: Type[mDataSource]) -> mDataSource: - data['datatype'] = data.pop('dataType') - data['state'] = data['status'] - data['status'] = DataSource._resolve_api_status(data['status']) + data['datatype'] = data.pop('dataType', None) + # data['state'] = data['status'] + # data['status'] = DataSource._resolve_api_status(data['status']) data = filter_dict(datasource_type, data) model = datasource_type(**data) return model diff --git a/src/ydata/sdk/synthesizers/__init__.py b/src/ydata/sdk/synthesizers/__init__.py index ac5dcd26..b64c122e 100644 --- a/src/ydata/sdk/synthesizers/__init__.py +++ b/src/ydata/sdk/synthesizers/__init__.py @@ -1,8 +1,9 @@ from ydata.datascience.common import PrivacyLevel from ydata.sdk.synthesizers._models.synthesizers_list import SynthesizersList +from ydata.sdk.synthesizers.multitable import MultiTableSynthesizer from ydata.sdk.synthesizers.regular import RegularSynthesizer from ydata.sdk.synthesizers.synthesizer import BaseSynthesizer as Synthesizer from ydata.sdk.synthesizers.timeseries import TimeSeriesSynthesizer __all__ = ["RegularSynthesizer", "TimeSeriesSynthesizer", - "Synthesizer", "SynthesizersList", "PrivacyLevel"] + "Synthesizer", "SynthesizersList", "PrivacyLevel", "MultiTableSynthesizer"] diff --git a/src/ydata/sdk/synthesizers/_models/status.py b/src/ydata/sdk/synthesizers/_models/status.py 
index 12d888c7..3f2fddc2 100644 --- a/src/ydata/sdk/synthesizers/_models/status.py +++ b/src/ydata/sdk/synthesizers/_models/status.py @@ -1,37 +1,37 @@ from typing import Generic, TypeVar -from pydantic import BaseModel - +from pydantic import BaseModel, Field from ydata.core.enum import StringEnum T = TypeVar("T") class GenericStateErrorStatus(BaseModel, Generic[T]): - state: T + state: T | None = Field(None) + + class Config: + use_enum_values = True class PrepareState(StringEnum): - PREPARING = 'preparing' - DISCOVERING = 'discovering' - FINISHED = 'finished' - FAILED = 'failed' - UNKNOWN = 'unknown' + PREPARING = "preparing" + DISCOVERING = "discovering" + FINISHED = "finished" + FAILED = "failed" class TrainingState(StringEnum): - PREPARING = 'preparing' - RUNNING = 'running' - FINISHED = 'finished' - FAILED = 'failed' - UNKNOWN = 'unknown' + PREPARING = "preparing" + RUNNING = "running" + FINISHED = "finished" + FAILED = "failed" class ReportState(StringEnum): - UNKNOWN = 'unknown' - DISCOVERING = 'discovering' - FINISHED = 'finished' - FAILED = 'failed' + PREPARING = "preparing" + GENERATING = "generating" + AVAILABLE = "available" + FAILED = "failed" PrepareStatus = GenericStateErrorStatus[PrepareState] @@ -39,11 +39,25 @@ class ReportState(StringEnum): ReportStatus = GenericStateErrorStatus[ReportState] -class Status(StringEnum): - NOT_INITIALIZED = 'not initialized' - FAILED = 'failed' - PREPARE = 'prepare' - TRAIN = 'train' - REPORT = 'report' # Should not be here for SDK - READY = 'ready' - UNKNOWN = 'unknown' +class Status(BaseModel): + class State(StringEnum): + NOT_INITIALIZED = 'not initialized' + UNKNOWN = 'unknown' + + PREPARE = "prepare" + TRAIN = "train" + REPORT = "report" + READY = "ready" + + state: State | None = Field(None) + prepare: PrepareStatus | None = Field(None) + training: TrainingStatus | None = Field(None) + report: ReportStatus | None = Field(None) + + @staticmethod + def not_initialized() -> "Status": + return 
Status(state=Status.State.NOT_INITIALIZED) + + @staticmethod + def unknown() -> "Status": + return Status(state=Status.State.UNKNOWN) diff --git a/src/ydata/sdk/synthesizers/_models/synthesizer.py b/src/ydata/sdk/synthesizers/_models/synthesizer.py index 4f1aad42..79242ece 100644 --- a/src/ydata/sdk/synthesizers/_models/synthesizer.py +++ b/src/ydata/sdk/synthesizers/_models/synthesizer.py @@ -1,11 +1,10 @@ -from dataclasses import dataclass, field -from typing import Dict, Optional +from pydantic import BaseModel, Field +from .status import Status -@dataclass -class Synthesizer: - uid: Optional[str] = None - author: Optional[str] = None - name: Optional[str] = None - status: Optional[Dict] = field(default_factory=dict) +class Synthesizer(BaseModel): + uid: str | None = None + author: str | None = None + name: str | None = None + status: Status | None = Field(None) diff --git a/src/ydata/sdk/synthesizers/multitable.py b/src/ydata/sdk/synthesizers/multitable.py new file mode 100644 index 00000000..4e323261 --- /dev/null +++ b/src/ydata/sdk/synthesizers/multitable.py @@ -0,0 +1,72 @@ +from time import sleep + +from ydata.sdk.common.client import Client +from ydata.sdk.common.config import BACKOFF +from ydata.sdk.common.types import UID, Project +from ydata.sdk.common.exceptions import InputError +from ydata.sdk.datasources import DataSource +from ydata.sdk.synthesizers.synthesizer import BaseSynthesizer + +class MultiTableSynthesizer(BaseSynthesizer): + + def __init__( + self, write_connector: UID, uid: UID | None = None, name: str | None = None, + project: Project | None = None, client: Client | None = None): + + self.__write_connector = write_connector + + super().__init__(uid, name, project, client) + + def fit(self, X: DataSource) -> None: + """Fit the synthesizer. + + The synthesizer accepts as training dataset a YData [`DataSource`][ydata.sdk.datasources.DataSource]. 
+ + Arguments: + X (DataSource): DataSource to Train + """ + + self._fit_from_datasource(X) + + def sample(self, frac: int | float = 1, write_connector: UID | None = None) -> None: + """Sample from a [`MultiTableSynthesizer`][ydata.sdk.synthesizers.MultiTableSynthesizer] + instance. + The sample is saved in the connector that was provided in the synthesizer initialization + or in the + + Arguments: + frac (int | float): fraction of the sample to be returned + """ + + assert frac >= 0.1, InputError("It is not possible to generate an empty synthetic data schema. Please validate the input provided. ") + assert frac <= 5, InputError("It is not possible to generate a database that is 5x bigger than the original dataset. Please validate the input provided.") + + payload = { + 'fraction': frac, + } + + if write_connector is not None: + payload['writeConnector'] = write_connector + + response = self._client.post( + f"/synthesizer/{self.uid}/sample", json=payload, project=self._project) + + data = response.json() + sample_uid = data.get('uid') + sample_status = None + while sample_status not in ['finished', 'failed']: + self._logger.info('Sampling from the synthesizer...') + response = self._client.get( + f'/synthesizer/{self.uid}/history', project=self._project) + history = response.json() + sample_data = next((s for s in history if s.get('uid') == sample_uid), None) + sample_status = sample_data.get('status', {}).get('state') + sleep(BACKOFF) + + print(f"Sample created and saved into connector with ID {self.__write_connector or write_connector}") + + def _create_payload(self) -> dict: + payload = super()._create_payload() + payload['writeConnector'] = self.__write_connector + + return payload diff --git a/src/ydata/sdk/synthesizers/synthesizer.py b/src/ydata/sdk/synthesizers/synthesizer.py index 14ce5e75..5a9ec7f8 100644 --- a/src/ydata/sdk/synthesizers/synthesizer.py +++ b/src/ydata/sdk/synthesizers/synthesizer.py @@ -1,7 +1,7 @@ from abc import ABC, abstractmethod from 
io import StringIO from time import sleep -from typing import Dict, List, Optional, Tuple, Type, Union +from typing import Dict, List, Optional, Union from uuid import uuid4 from warnings import warn @@ -53,9 +53,8 @@ class BaseSynthesizer(ABC, ModelFactoryMixin): def __init__(self, uid: UID | None = None, name: str | None = None, project: Project | None = None, client: Client | None = None): self._init_common(client=client) - self._model = mSynthesizer(uid=uid, name=name or str( - uuid4())) if uid or project else None - self.__project = project + self._model = mSynthesizer(uid=uid, name=name or str(uuid4())) + self._project = project @init_client def _init_common(self, client: Optional[Client] = None): @@ -221,44 +220,61 @@ def _metadata_to_payload( def _fit_from_datasource( self, X: DataSource, - privacy_level: PrivacyLevel = PrivacyLevel.HIGH_FIDELITY, - dataset_attrs: Optional[DataSourceAttrs] = None, - target: Optional[str] = None, - anonymize: Optional[dict] = None, - condition_on: Optional[List[str]] = None + privacy_level: PrivacyLevel | None = None, + dataset_attrs: DataSourceAttrs | None = None, + target: str | None = None, + anonymize: dict | None = None, + condition_on: list[str] | None = None ) -> None: - metadata = self._metadata_to_payload( - DataSourceType(X.datatype), X.metadata, dataset_attrs, target) - payload = { - 'name': self._model.name, - 'dataSourceUID': X.uid, - 'metadata': metadata, - 'extraData': {}, - 'privacyLevel': privacy_level.value - } + payload = self._create_payload() + + payload['dataSourceUID'] = X.uid + + if privacy_level: + payload['privacy_level'] = privacy_level.value + + if X.metadata is not None and X.datatype is not None: + payload['metadata'] = self._metadata_to_payload( + DataSourceType(X.datatype), X.metadata, dataset_attrs, target) + if anonymize is not None: payload["extraData"]["anonymize"] = anonymize if condition_on is not None: payload["extraData"]["condition_on"] = condition_on response = self._client.post( - 
'/synthesizer/', json=payload, project=self.__project) - data: list = response.json() - self._model, _ = self._model_from_api(X.datatype, data) - while self.status not in [Status.READY, Status.FAILED]: + '/synthesizer/', json=payload, project=self._project) + data = response.json() + self._model = mSynthesizer(**data) + while self._check_fitting_not_finished(self.status): self._logger.info('Training the synthesizer...') sleep(BACKOFF) - if self.status == Status.FAILED: + def _create_payload(self) -> dict: + payload = { + 'extraData': {} + } + + if self._model and self._model.name: + payload['name'] = self._model.name + + return payload + + def _check_fitting_not_finished(self, status: Status) -> bool: + self._logger.debug(f'checking status {status}') + + if status.state in [Status.State.READY, Status.State.REPORT]: + return False + + self._logger.debug(f'status not ready yet {status.state}') + + if status.prepare and PrepareState(status.prepare.state) == PrepareState.FAILED: raise FittingError('Could not train the synthesizer') - @staticmethod - def _model_from_api(datatype: str, data: Dict) -> Tuple[mSynthesizer, Type["BaseSynthesizer"]]: - from ydata.sdk.synthesizers._models.synthesizer_map import TYPE_TO_CLASS - synth_cls = TYPE_TO_CLASS.get(SynthesizerType(datatype).value) - data['status'] = synth_cls._resolve_api_status(data['status']) - data = filter_dict(mSynthesizer, data) - return mSynthesizer(**data), synth_cls + if status.training and TrainingState(status.training.state) == TrainingState.FAILED: + raise FittingError('Could not train the synthesizer') + + return True @abstractmethod def sample(self) -> pdDataFrame: @@ -274,7 +290,7 @@ def _sample(self, payload: Dict) -> pdDataFrame: pandas `DataFrame` """ response = self._client.post( - f"/synthesizer/{self.uid}/sample", json=payload, project=self.__project) + f"/synthesizer/{self.uid}/sample", json=payload, project=self._project) data: Dict = response.json() sample_uid = data.get('uid') @@ -282,14 
+298,14 @@ def _sample(self, payload: Dict) -> pdDataFrame: while sample_status not in ['finished', 'failed']: self._logger.info('Sampling from the synthesizer...') response = self._client.get( - f'/synthesizer/{self.uid}/history', project=self.__project) + f'/synthesizer/{self.uid}/history', project=self._project) history: Dict = response.json() sample_data = next((s for s in history if s.get('uid') == sample_uid), None) sample_status = sample_data.get('status', {}).get('state') sleep(BACKOFF) response = self._client.get_static_file( - f'/synthesizer/{self.uid}/sample/{sample_uid}/sample.csv', project=self.__project) + f'/synthesizer/{self.uid}/sample/{sample_uid}/sample.csv', project=self._project) data = StringIO(response.content.decode()) return read_csv(data) @@ -301,7 +317,7 @@ def uid(self) -> UID: Synthesizer status """ if not self._is_initialized(): - return Status.NOT_INITIALIZED + return Status.State.NOT_INITIALIZED return self._model.uid @@ -313,20 +329,20 @@ def status(self) -> Status: Synthesizer status """ if not self._is_initialized(): - return Status.NOT_INITIALIZED + return Status.not_initialized() try: - self = self.get(self._model.uid, self._client) + self = self.get() return self._model.status except Exception: # noqa: PIE786 - return Status.UNKNOWN + return Status.unknown() def get(self): assert self._is_initialized() and self._model.uid, InputError( "Please provide the synthesizer `uid`") - response = self._client.get(f'/synthesizer/{self.uid}', project=self.__project) - data = filter_dict(mSynthesizer, response.json()) + response = self._client.get(f'/synthesizer/{self.uid}', project=self._project) + data = response.json() self._model = mSynthesizer(**data) return self From a35b47e4e21f404c569d7398cd93b3a9aa153e84 Mon Sep 17 00:00:00 2001 From: Azory YData Bot Date: Thu, 11 Jan 2024 20:11:37 +0000 Subject: [PATCH 02/11] fix(linting): code formatting --- src/ydata/sdk/datasources/_models/datasource.py | 2 +- 
src/ydata/sdk/datasources/datasource.py | 2 +- src/ydata/sdk/synthesizers/_models/status.py | 1 + src/ydata/sdk/synthesizers/multitable.py | 12 ++++++++---- src/ydata/sdk/synthesizers/synthesizer.py | 4 +--- 5 files changed, 12 insertions(+), 9 deletions(-) diff --git a/src/ydata/sdk/datasources/_models/datasource.py b/src/ydata/sdk/datasources/_models/datasource.py index 9f82d261..68d8c33a 100644 --- a/src/ydata/sdk/datasources/_models/datasource.py +++ b/src/ydata/sdk/datasources/_models/datasource.py @@ -4,7 +4,7 @@ from ydata.sdk.common.types import UID from ydata.sdk.datasources._models.datatype import DataSourceType from ydata.sdk.datasources._models.metadata.metadata import Metadata -from ydata.sdk.datasources._models.status import State, Status +from ydata.sdk.datasources._models.status import Status @dataclass diff --git a/src/ydata/sdk/datasources/datasource.py b/src/ydata/sdk/datasources/datasource.py index f28cc869..2710737b 100644 --- a/src/ydata/sdk/datasources/datasource.py +++ b/src/ydata/sdk/datasources/datasource.py @@ -13,7 +13,7 @@ from ydata.sdk.datasources._models.datasource_list import DataSourceList from ydata.sdk.datasources._models.datatype import DataSourceType from ydata.sdk.datasources._models.metadata.metadata import Metadata -from ydata.sdk.datasources._models.status import Status, ValidationState +from ydata.sdk.datasources._models.status import Status from ydata.sdk.utils.model_mixin import ModelFactoryMixin from ydata.sdk.utils.model_utils import filter_dict diff --git a/src/ydata/sdk/synthesizers/_models/status.py b/src/ydata/sdk/synthesizers/_models/status.py index 3f2fddc2..3655fcc6 100644 --- a/src/ydata/sdk/synthesizers/_models/status.py +++ b/src/ydata/sdk/synthesizers/_models/status.py @@ -1,6 +1,7 @@ from typing import Generic, TypeVar from pydantic import BaseModel, Field + from ydata.core.enum import StringEnum T = TypeVar("T") diff --git a/src/ydata/sdk/synthesizers/multitable.py 
b/src/ydata/sdk/synthesizers/multitable.py index 4e323261..dc4c277a 100644 --- a/src/ydata/sdk/synthesizers/multitable.py +++ b/src/ydata/sdk/synthesizers/multitable.py @@ -2,11 +2,12 @@ from ydata.sdk.common.client import Client from ydata.sdk.common.config import BACKOFF -from ydata.sdk.common.types import UID, Project from ydata.sdk.common.exceptions import InputError +from ydata.sdk.common.types import UID, Project from ydata.sdk.datasources import DataSource from ydata.sdk.synthesizers.synthesizer import BaseSynthesizer + class MultiTableSynthesizer(BaseSynthesizer): def __init__( @@ -38,8 +39,10 @@ def sample(self, frac: int | float = 1, write_connector: UID | None = None) -> N frac (int | float): fraction of the sample to be returned """ - assert frac >= 0.1, InputError("It is not possible to generate an empty synthetic data schema. Please validate the input provided. ") - assert frac <= 5, InputError("It is not possible to generate a database that is 5x bigger than the original dataset. Please validate the input provided.") + assert frac >= 0.1, InputError( + "It is not possible to generate an empty synthetic data schema. Please validate the input provided. ") + assert frac <= 5, InputError( + "It is not possible to generate a database that is 5x bigger than the original dataset. 
Please validate the input provided.") payload = { 'fraction': frac, @@ -63,7 +66,8 @@ def sample(self, frac: int | float = 1, write_connector: UID | None = None) -> N sample_status = sample_data.get('status', {}).get('state') sleep(BACKOFF) - print(f"Sample created and saved into connector with ID {self.__write_connector or write_connector}") + print( + f"Sample created and saved into connector with ID {self.__write_connector or write_connector}") def _create_payload(self) -> dict: payload = super()._create_payload() diff --git a/src/ydata/sdk/synthesizers/synthesizer.py b/src/ydata/sdk/synthesizers/synthesizer.py index 5a9ec7f8..70bfef39 100644 --- a/src/ydata/sdk/synthesizers/synthesizer.py +++ b/src/ydata/sdk/synthesizers/synthesizer.py @@ -26,10 +26,8 @@ from ydata.sdk.datasources._models.status import Status as dsStatus from ydata.sdk.synthesizers._models.status import PrepareState, Status, TrainingState from ydata.sdk.synthesizers._models.synthesizer import Synthesizer as mSynthesizer -from ydata.sdk.synthesizers._models.synthesizer_type import SynthesizerType from ydata.sdk.synthesizers._models.synthesizers_list import SynthesizersList from ydata.sdk.utils.model_mixin import ModelFactoryMixin -from ydata.sdk.utils.model_utils import filter_dict @typechecked @@ -228,7 +226,7 @@ def _fit_from_datasource( ) -> None: payload = self._create_payload() - payload['dataSourceUID'] = X.uid + payload['dataSourceUID'] = X.uid if privacy_level: payload['privacy_level'] = privacy_level.value From a15f3afc7f40121aba19f0801deaaf577483b92f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Lu=C3=ADs=20Portela=20Afonso?= Date: Fri, 12 Jan 2024 12:27:17 +0000 Subject: [PATCH 03/11] class comments --- src/ydata/sdk/synthesizers/multitable.py | 16 ++++++++++++++++ src/ydata/sdk/synthesizers/synthesizer.py | 2 +- 2 files changed, 17 insertions(+), 1 deletion(-) diff --git a/src/ydata/sdk/synthesizers/multitable.py b/src/ydata/sdk/synthesizers/multitable.py index dc4c277a..b9f71b23 100644 
--- a/src/ydata/sdk/synthesizers/multitable.py +++ b/src/ydata/sdk/synthesizers/multitable.py @@ -9,6 +9,22 @@ class MultiTableSynthesizer(BaseSynthesizer): + """MultiTable synthesizer class. + + Methods + ------- + - `fit`: train a synthesizer instance. + - `sample`: request synthetic data. + - `status`: current status of the synthesizer instance. + + Note: + The synthesizer instance is created in the backend only when the `fit` method is called. + + Arguments: + write_connector (UID): Connector of type RDBMS to be used to write the samples + name (str): (optional) Name to be used when creating the synthesizer. Calculated internally if not provided + client (Client): (optional) Client to connect to the backend + """ def __init__( self, write_connector: UID, uid: UID | None = None, name: str | None = None, diff --git a/src/ydata/sdk/synthesizers/synthesizer.py b/src/ydata/sdk/synthesizers/synthesizer.py index 70bfef39..6a3e91f3 100644 --- a/src/ydata/sdk/synthesizers/synthesizer.py +++ b/src/ydata/sdk/synthesizers/synthesizer.py @@ -34,7 +34,7 @@ class BaseSynthesizer(ABC, ModelFactoryMixin): """Main synthesizer class. - This class cannot be directly instanciated because of the specificities between [`RegularSynthesizer`][ydata.sdk.synthesizers.RegularSynthesizer] and [`TimeSeriesSynthesizer`][ydata.sdk.synthesizers.TimeSeriesSynthesizer] `sample` methods. + This class cannot be directly instanciated because of the specificities between [`RegularSynthesizer`][ydata.sdk.synthesizers.RegularSynthesizer], [`TimeSeriesSynthesizer`][ydata.sdk.synthesizers.TimeSeriesSynthesizer] or [`MultiTableSynthesizer`][ydata.sdk.synthesizers.MultiTableSynthesizer] `sample` methods. 
Methods ------- From a1ad5322c826111e10986c16837bb3cb473af0c4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Lu=C3=ADs=20Portela=20Afonso?= Date: Fri, 12 Jan 2024 15:15:26 +0000 Subject: [PATCH 04/11] add arguments to fit method --- src/ydata/sdk/synthesizers/multitable.py | 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 deletion(-) diff --git a/src/ydata/sdk/synthesizers/multitable.py b/src/ydata/sdk/synthesizers/multitable.py index b9f71b23..492de04e 100644 --- a/src/ydata/sdk/synthesizers/multitable.py +++ b/src/ydata/sdk/synthesizers/multitable.py @@ -1,10 +1,13 @@ from time import sleep +from ydata.datascience.common import PrivacyLevel from ydata.sdk.common.client import Client from ydata.sdk.common.config import BACKOFF from ydata.sdk.common.exceptions import InputError from ydata.sdk.common.types import UID, Project from ydata.sdk.datasources import DataSource +from ydata.sdk.datasources._models.datatype import DataSourceType +from ydata.sdk.datasources._models.metadata.data_types import DataType from ydata.sdk.synthesizers.synthesizer import BaseSynthesizer @@ -34,10 +37,21 @@ def __init__( super().__init__(uid, name, project, client) - def fit(self, X: DataSource) -> None: + def fit(self, X: DataSource, + privacy_level: PrivacyLevel = PrivacyLevel.HIGH_FIDELITY, + datatype: DataSourceType | str | None = None, + sortbykey: str | list[str] | None = None, + entities: str | list[str] | None = None, + generate_cols: list[str] | None = None, + exclude_cols: list[str] | None = None, + dtypes: dict[str, str | DataType] | None = None, + target: str | None = None, + anonymize: dict | None = None, + condition_on: list[str] | None = None) -> None: """Fit the synthesizer. The synthesizer accepts as training dataset a YData [`DataSource`][ydata.sdk.datasources.DataSource]. + Except X, all the other arguments are for now ignored until they are supported. 
Arguments: X (DataSource): DataSource to Train From daeab6071c31223fd3987a16c408b3dbfdfc3977 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Lu=C3=ADs=20Portela=20Afonso?= Date: Fri, 12 Jan 2024 15:20:02 +0000 Subject: [PATCH 05/11] remove commented code --- .../sdk/datasources/_models/datasource.py | 28 ++++--------------- 1 file changed, 6 insertions(+), 22 deletions(-) diff --git a/src/ydata/sdk/datasources/_models/datasource.py b/src/ydata/sdk/datasources/_models/datasource.py index 68d8c33a..cd33c162 100644 --- a/src/ydata/sdk/datasources/_models/datasource.py +++ b/src/ydata/sdk/datasources/_models/datasource.py @@ -1,5 +1,4 @@ from dataclasses import dataclass -from typing import Optional from ydata.sdk.common.types import UID from ydata.sdk.datasources._models.datatype import DataSourceType @@ -10,27 +9,12 @@ @dataclass class DataSource: - uid: Optional[UID] = None - author: Optional[str] = None - name: Optional[str] = None - datatype: Optional[DataSourceType] = None - metadata: Optional[Metadata] = None - status: Optional[Status] = None - - def __post_init__(self): - if self.metadata is not None: - self.metadata = Metadata(**self.metadata) - - # if self.state is not None: - # data = { - # 'validation': self.state.get('validation', {}).get('state', 'unknown'), - # 'metadata': self.state.get('metadata', {}).get('state', 'unknown'), - # 'profiling': self.state.get('profiling', {}).get('state', 'unknown') - # } - # self.state = State.parse_obj(data) - - # if self.status is not None: - # self.status = Status(self.status) + uid: UID | None = None + author: str | None = None + name: str | None = None + datatype: DataSourceType | None = None + metadata: Metadata | None = None + status: Status | None = None def to_payload(self): return {} From 41d5969d4f1a55a478a73932e599ad2708238f99 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Lu=C3=ADs=20Portela=20Afonso?= Date: Fri, 12 Jan 2024 15:27:17 +0000 Subject: [PATCH 06/11] remove commented code --- 
src/ydata/sdk/datasources/datasource.py | 11 ----------- 1 file changed, 11 deletions(-) diff --git a/src/ydata/sdk/datasources/datasource.py b/src/ydata/sdk/datasources/datasource.py index 2710737b..ec15a47a 100644 --- a/src/ydata/sdk/datasources/datasource.py +++ b/src/ydata/sdk/datasources/datasource.py @@ -174,20 +174,9 @@ def _wait_for_metadata(datasource): sleep(BACKOFF) return datasource - # @staticmethod - # def _resolve_api_status(api_status: Dict) -> Status: - # status = Status(api_status.get('state', Status.UNKNOWN.name)) - # validation = ValidationState(api_status.get('validation', {}).get( - # 'state', ValidationState.UNKNOWN.name)) - # if validation == ValidationState.FAILED: - # status = .FAILED - # return status - @staticmethod def _model_from_api(data: Dict, datasource_type: Type[mDataSource]) -> mDataSource: data['datatype'] = data.pop('dataType', None) - # data['state'] = data['status'] - # data['status'] = DataSource._resolve_api_status(data['status']) data = filter_dict(datasource_type, data) model = datasource_type(**data) return model From bb339b97b78f50f2e3a1c1afc8d35f819f4a7488 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Lu=C3=ADs=20Portela=20Afonso?= Date: Fri, 12 Jan 2024 18:01:51 +0000 Subject: [PATCH 07/11] add ability to use Connector as argument --- src/ydata/sdk/connectors/connector.py | 2 +- src/ydata/sdk/synthesizers/multitable.py | 28 +++++++++++++++++++----- 2 files changed, 24 insertions(+), 6 deletions(-) diff --git a/src/ydata/sdk/connectors/connector.py b/src/ydata/sdk/connectors/connector.py index 0608edda..bfa6de2b 100644 --- a/src/ydata/sdk/connectors/connector.py +++ b/src/ydata/sdk/connectors/connector.py @@ -47,7 +47,7 @@ def uid(self) -> UID: return self._model.uid @property - def type(self) -> str: + def type(self) -> ConnectorType: return self._model.type @staticmethod diff --git a/src/ydata/sdk/synthesizers/multitable.py b/src/ydata/sdk/synthesizers/multitable.py index 492de04e..960edb1a 100644 --- 
a/src/ydata/sdk/synthesizers/multitable.py +++ b/src/ydata/sdk/synthesizers/multitable.py @@ -3,8 +3,9 @@ from ydata.datascience.common import PrivacyLevel from ydata.sdk.common.client import Client from ydata.sdk.common.config import BACKOFF -from ydata.sdk.common.exceptions import InputError +from ydata.sdk.common.exceptions import InputError, ConnectorError from ydata.sdk.common.types import UID, Project +from ydata.sdk.connectors.connector import Connector, ConnectorType from ydata.sdk.datasources import DataSource from ydata.sdk.datasources._models.datatype import DataSourceType from ydata.sdk.datasources._models.metadata.data_types import DataType @@ -30,10 +31,11 @@ class MultiTableSynthesizer(BaseSynthesizer): """ def __init__( - self, write_connector: UID, uid: UID | None = None, name: str | None = None, + self, write_connector: Connector | UID, uid: UID | None = None, name: str | None = None, project: Project | None = None, client: Client | None = None): - self.__write_connector = write_connector + connector = self._check_or_fetch_connector(write_connector) + self.__write_connector = connector.uid super().__init__(uid, name, project, client) @@ -59,7 +61,7 @@ def fit(self, X: DataSource, self._fit_from_datasource(X) - def sample(self, frac: int | float = 1, write_connector: UID | None = None) -> None: + def sample(self, frac: int | float = 1, write_connector: Connector | UID | None = None) -> None: """Sample from a [`MultiTableSynthesizer`][ydata.sdk.synthesizers.MultiTableSynthesizer] instance. 
The sample is saved in the connector that was provided in the synthesizer initialization @@ -79,7 +81,8 @@ def sample(self, frac: int | float = 1, write_connector: UID | None = None) -> N } if write_connector is not None: - payload['writeConnector'] = write_connector + connector = self._check_or_fetch_connector(write_connector) + payload['writeConnector'] = connector.uid response = self._client.post( f"/synthesizer/{self.uid}/sample", json=payload, project=self._project) @@ -104,3 +107,18 @@ def _create_payload(self) -> dict: payload['writeConnector'] = self.__write_connector return payload + + def _check_or_fetch_connector(self, write_connector: Connector | UID) -> Connector: + self._logger.debug(f'Write connector is {write_connector}') + if isinstance(write_connector, str): + self._logger.debug(f'Write connector is of type `UID` {write_connector}') + write_connector = Connector.get(write_connector) + self._logger.debug(f'Using fetched connector {write_connector}') + + if write_connector.uid is None: + raise InputError("Invalid connector provided as input for write") + + if write_connector.type not in [ConnectorType.AZURE_SQL, ConnectorType.MYSQL, ConnectorType.SNOWFLAKE]: + raise ConnectorError(f"Invalid type `{write_connector.type}` for the provided connector") + + return write_connector From 4de8544e12d85f4f0818f5c4080715c9d0492d69 Mon Sep 17 00:00:00 2001 From: Azory YData Bot Date: Fri, 12 Jan 2024 18:03:08 +0000 Subject: [PATCH 08/11] fix(linting): code formatting --- src/ydata/sdk/synthesizers/multitable.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/ydata/sdk/synthesizers/multitable.py b/src/ydata/sdk/synthesizers/multitable.py index 960edb1a..161071a1 100644 --- a/src/ydata/sdk/synthesizers/multitable.py +++ b/src/ydata/sdk/synthesizers/multitable.py @@ -3,7 +3,7 @@ from ydata.datascience.common import PrivacyLevel from ydata.sdk.common.client import Client from ydata.sdk.common.config import BACKOFF -from 
ydata.sdk.common.exceptions import InputError, ConnectorError +from ydata.sdk.common.exceptions import ConnectorError, InputError from ydata.sdk.common.types import UID, Project from ydata.sdk.connectors.connector import Connector, ConnectorType from ydata.sdk.datasources import DataSource @@ -119,6 +119,7 @@ def _check_or_fetch_connector(self, write_connector: Connector | UID) -> Connect raise InputError("Invalid connector provided as input for write") if write_connector.type not in [ConnectorType.AZURE_SQL, ConnectorType.MYSQL, ConnectorType.SNOWFLAKE]: - raise ConnectorError(f"Invalid type `{write_connector.type}` for the provided connector") + raise ConnectorError( + f"Invalid type `{write_connector.type}` for the provided connector") return write_connector From 4e782390c330e54276070a639d056ef3e8bb6c0a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Lu=C3=ADs=20Portela=20Afonso?= Date: Fri, 12 Jan 2024 18:12:09 +0000 Subject: [PATCH 09/11] sample write connector as Connector --- src/ydata/sdk/synthesizers/multitable.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/ydata/sdk/synthesizers/multitable.py b/src/ydata/sdk/synthesizers/multitable.py index 161071a1..8dd2f851 100644 --- a/src/ydata/sdk/synthesizers/multitable.py +++ b/src/ydata/sdk/synthesizers/multitable.py @@ -34,11 +34,11 @@ def __init__( self, write_connector: Connector | UID, uid: UID | None = None, name: str | None = None, project: Project | None = None, client: Client | None = None): + super().__init__(uid, name, project, client) + connector = self._check_or_fetch_connector(write_connector) self.__write_connector = connector.uid - super().__init__(uid, name, project, client) - def fit(self, X: DataSource, privacy_level: PrivacyLevel = PrivacyLevel.HIGH_FIDELITY, datatype: DataSourceType | str | None = None, From 0945d35784b52150a48a8e72842e68bc7f483975 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Lu=C3=ADs=20Portela=20Afonso?= Date: Fri, 12 Jan 2024 20:30:21 +0000 Subject: [PATCH 
10/11] examples and documentation --- docs/examples/synthesize_timeseries_data.md | 4 +-- docs/examples/synthesizer_multitable.md | 17 ++++++++++ .../reference/api/synthesizers/multitable.md | 1 + .../synthesizers/multi_table_quickstart.py | 25 +++++++++++++++ .../multi_table_sample_write_override.py | 32 +++++++++++++++++++ 5 files changed, 77 insertions(+), 2 deletions(-) create mode 100644 docs/examples/synthesizer_multitable.md create mode 100644 docs/sdk/reference/api/synthesizers/multitable.md create mode 100644 examples/synthesizers/multi_table_quickstart.py create mode 100644 examples/synthesizers/multi_table_sample_write_override.py diff --git a/docs/examples/synthesize_timeseries_data.md b/docs/examples/synthesize_timeseries_data.md index a224a530..5bfd3234 100644 --- a/docs/examples/synthesize_timeseries_data.md +++ b/docs/examples/synthesize_timeseries_data.md @@ -2,9 +2,9 @@ **Use YData's *TimeSeriesSynthesizer* to generate time-series synthetic data** -Tabular data is the most common type of data we encounter in data problems. +Timeseries is the most common type of data we encounter in data problems. -When thinking about tabular data, we assume independence between different records, but this does not happen in reality. Suppose we check events from our day-to-day life, such as room temperature changes, bank account transactions, stock price fluctuations, and air quality measurements in our neighborhood. In that case, we might end up with datasets where measures and records evolve and are related through time. This type of data is known to be sequential or time-series data. +When thinking about timeseries data, we assume independence between different records, but this does not happen in reality. Suppose we check events from our day-to-day life, such as room temperature changes, bank account transactions, stock price fluctuations, and air quality measurements in our neighborhood. 
In that case, we might end up with datasets where measures and records evolve and are related through time. This type of data is known to be sequential or time-series data. Thus, sequential or time-series data refers to any data containing elements ordered into sequences in a structured format. Dissecting any time-series dataset, we see differences in variables' behavior that need to be understood for an effective generation of synthetic data. Typically any time-series dataset is composed of the following: diff --git a/docs/examples/synthesizer_multitable.md b/docs/examples/synthesizer_multitable.md new file mode 100644 index 00000000..e6a9f52b --- /dev/null +++ b/docs/examples/synthesizer_multitable.md @@ -0,0 +1,17 @@ +# Synthesize Multi Table + +**Use YData's *MultiTableSynthesizer* to generate multi table synthetic data from multiple RDBMS tables** + +Multi table is the way to synthesize data from multiple tables of a database, keeping their relational structure in mind. + +Quickstart example: + +```python +--8<-- "examples/synthesizers/multi_table_quickstart.py" +``` + +Sample write connector overriding example: + +```python +--8<-- "examples/synthesizers/multi_table_sample_write_override.py" +``` diff --git a/docs/sdk/reference/api/synthesizers/multitable.md b/docs/sdk/reference/api/synthesizers/multitable.md new file mode 100644 index 00000000..ffba37d1 --- /dev/null +++ b/docs/sdk/reference/api/synthesizers/multitable.md @@ -0,0 +1 @@ +::: ydata.sdk.synthesizers.multitable.MultiTableSynthesizer diff --git a/examples/synthesizers/multi_table_quickstart.py b/examples/synthesizers/multi_table_quickstart.py new file mode 100644 index 00000000..c47c8334 --- /dev/null +++ b/examples/synthesizers/multi_table_quickstart.py @@ -0,0 +1,25 @@ +import os + +from ydata.sdk.datasources import DataSource +from ydata.sdk.synthesizers import MultiTableSynthesizer + +# Do not forget to add your token as env variables +os.environ["YDATA_TOKEN"] = '' # Remove if already defined + +# In
this example, we demonstrate how to train a synthesizer from an existing multi table RDBMS datasource. +# After training a Multi Table Synthesizer, we request a sample. +# In this case, we don't return the Dataset for the sample, it will be saved in the database +# that the connector refers to. + +X = DataSource.get('') + +# Initialize a multi table synthesizer with the connector to write to +# As long as the synthesizer does not call `fit`, it exists only locally +# write_connector can be a UID or a Connector instance +synth = MultiTableSynthesizer(write_connector='') + +# For demonstration purposes, we will use a connector instance, but you can just send the UID + +write_connector = Connector.get('') + +# Initialize a multi table synthesizer with the connector to write to +# As long as the synthesizer does not call `fit`, it exists only locally +# write_connector can be a UID or a Connector instance +synth = MultiTableSynthesizer(write_connector=write_connector) + +# The synthesizer training is requested +synth.fit(X) + +# We request a synthetic dataset with a fraction of 1.5 +# In this case we use a Connector instance. +# You can just pass the connector UID; you don't need to get the connector upfront.
+synth.sample(frac=1.5, write_connector=write_connector) From 0fd00cfaee36c9661b660155e11bdb7e86a96f75 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Lu=C3=ADs=20Portela=20Afonso?= Date: Mon, 15 Jan 2024 11:51:06 +0000 Subject: [PATCH 11/11] support for python 3.8 --- Makefile | 2 +- docs/examples/synthesize_timeseries_data.md | 4 +-- src/ydata/sdk/common/client/client.py | 18 ++++++++---- .../sdk/datasources/_models/datasource.py | 14 ++++----- src/ydata/sdk/synthesizers/_models/status.py | 12 ++++---- .../sdk/synthesizers/_models/synthesizer.py | 10 ++++--- src/ydata/sdk/synthesizers/multitable.py | 29 ++++++++++--------- src/ydata/sdk/synthesizers/synthesizer.py | 16 +++++----- 8 files changed, 59 insertions(+), 46 deletions(-) diff --git a/Makefile b/Makefile index 07ce061f..e847e663 100644 --- a/Makefile +++ b/Makefile @@ -77,7 +77,7 @@ wheel: ### Compiles the wheel $(PYTHON) -m twine check wheels/* upload: - $(PYTHON) -m twine upload -r ydata wheels/ydata_sdk-$(version)-py310-none-any.whl + $(PYTHON) -m twine upload -r ydata wheels/ydata_sdk-$(version)-py$(PYV)-none-any.whl publish-docs: ### Publishes the documentation mike deploy --push --update-aliases $(version) latest diff --git a/docs/examples/synthesize_timeseries_data.md b/docs/examples/synthesize_timeseries_data.md index 5bfd3234..a224a530 100644 --- a/docs/examples/synthesize_timeseries_data.md +++ b/docs/examples/synthesize_timeseries_data.md @@ -2,9 +2,9 @@ **Use YData's *TimeSeriesSynthesizer* to generate time-series synthetic data** -Timeseries is the most common type of data we encounter in data problems. +Tabular data is the most common type of data we encounter in data problems. -When thinking about timeseries data, we assume independence between different records, but this does not happen in reality. Suppose we check events from our day-to-day life, such as room temperature changes, bank account transactions, stock price fluctuations, and air quality measurements in our neighborhood. 
In that case, we might end up with datasets where measures and records evolve and are related through time. This type of data is known to be sequential or time-series data. +When thinking about tabular data, we assume independence between different records, but this does not happen in reality. Suppose we check events from our day-to-day life, such as room temperature changes, bank account transactions, stock price fluctuations, and air quality measurements in our neighborhood. In that case, we might end up with datasets where measures and records evolve and are related through time. This type of data is known to be sequential or time-series data. Thus, sequential or time-series data refers to any data containing elements ordered into sequences in a structured format. Dissecting any time-series dataset, we see differences in variables' behavior that need to be understood for an effective generation of synthetic data. Typically any time-series dataset is composed of the following: diff --git a/src/ydata/sdk/common/client/client.py b/src/ydata/sdk/common/client/client.py index a3a3a258..3040c0c3 100644 --- a/src/ydata/sdk/common/client/client.py +++ b/src/ydata/sdk/common/client/client.py @@ -60,8 +60,10 @@ def __init__(self, credentials: Optional[Union[str, Dict]] = None, project: Opti if set_as_global: self.__set_global() - def post(self, endpoint: str, data: Optional[Dict] = None, json: Optional[Dict] = None, - project: Project | None = None, files: Optional[Dict] = None, raise_for_status: bool = True) -> Response: + def post( + self, endpoint: str, data: Optional[Dict] = None, json: Optional[Dict] = None, + project: Optional[Project] = None, files: Optional[Dict] = None, raise_for_status: bool = True + ) -> Response: """POST request to the backend. 
Args: @@ -83,8 +85,10 @@ def post(self, endpoint: str, data: Optional[Dict] = None, json: Optional[Dict] return response - def get(self, endpoint: str, params: Optional[Dict] = None, - project: Project | None = None, cookies: Optional[Dict] = None, raise_for_status: bool = True) -> Response: + def get( + self, endpoint: str, params: Optional[Dict] = None, project: Optional[Project] = None, + cookies: Optional[Dict] = None, raise_for_status: bool = True + ) -> Response: """GET request to the backend. Args: @@ -104,7 +108,9 @@ def get(self, endpoint: str, params: Optional[Dict] = None, return response - def get_static_file(self, endpoint: str, project: Project | None = None, raise_for_status: bool = True) -> Response: + def get_static_file( + self, endpoint: str, project: Optional[Project] = None, raise_for_status: bool = True + ) -> Response: """Retrieve a static file from the backend. Args: @@ -141,7 +147,7 @@ def _get_default_project(self, token: str): return data['myWorkspace'] def __build_url(self, endpoint: str, params: Optional[Dict] = None, data: Optional[Dict] = None, - json: Optional[Dict] = None, project: Project | None = None, files: Optional[Dict] = None, + json: Optional[Dict] = None, project: Optional[Project] = None, files: Optional[Dict] = None, cookies: Optional[Dict] = None) -> Dict: """Build a request for the backend. 
diff --git a/src/ydata/sdk/datasources/_models/datasource.py b/src/ydata/sdk/datasources/_models/datasource.py index cd33c162..0b3fffca 100644 --- a/src/ydata/sdk/datasources/_models/datasource.py +++ b/src/ydata/sdk/datasources/_models/datasource.py @@ -1,4 +1,5 @@ from dataclasses import dataclass +from typing import Optional from ydata.sdk.common.types import UID from ydata.sdk.datasources._models.datatype import DataSourceType @@ -8,13 +9,12 @@ @dataclass class DataSource: - - uid: UID | None = None - author: str | None = None - name: str | None = None - datatype: DataSourceType | None = None - metadata: Metadata | None = None - status: Status | None = None + uid: Optional[UID] = None + author: Optional[str] = None + name: Optional[str] = None + datatype: Optional[DataSourceType] = None + metadata: Optional[Metadata] = None + status: Optional[Status] = None def to_payload(self): return {} diff --git a/src/ydata/sdk/synthesizers/_models/status.py b/src/ydata/sdk/synthesizers/_models/status.py index 3655fcc6..dd892bd5 100644 --- a/src/ydata/sdk/synthesizers/_models/status.py +++ b/src/ydata/sdk/synthesizers/_models/status.py @@ -1,4 +1,4 @@ -from typing import Generic, TypeVar +from typing import Generic, Optional, TypeVar from pydantic import BaseModel, Field @@ -8,7 +8,7 @@ class GenericStateErrorStatus(BaseModel, Generic[T]): - state: T | None = Field(None) + state: Optional[T] = Field(None) class Config: use_enum_values = True @@ -50,10 +50,10 @@ class State(StringEnum): REPORT = "report" READY = "ready" - state: State | None = Field(None) - prepare: PrepareStatus | None = Field(None) - training: TrainingStatus | None = Field(None) - report: ReportStatus | None = Field(None) + state: Optional[State] = Field(None) + prepare: Optional[PrepareStatus] = Field(None) + training: Optional[TrainingStatus] = Field(None) + report: Optional[ReportStatus] = Field(None) @staticmethod def not_initialized() -> "Status": diff --git 
a/src/ydata/sdk/synthesizers/_models/synthesizer.py b/src/ydata/sdk/synthesizers/_models/synthesizer.py index 79242ece..7928c9a2 100644 --- a/src/ydata/sdk/synthesizers/_models/synthesizer.py +++ b/src/ydata/sdk/synthesizers/_models/synthesizer.py @@ -1,10 +1,12 @@ +from typing import Optional + from pydantic import BaseModel, Field from .status import Status class Synthesizer(BaseModel): - uid: str | None = None - author: str | None = None - name: str | None = None - status: Status | None = Field(None) + uid: Optional[str] = Field(None) + author: Optional[str] = Field(None) + name: Optional[str] = Field(None) + status: Optional[Status] = Field(None) diff --git a/src/ydata/sdk/synthesizers/multitable.py b/src/ydata/sdk/synthesizers/multitable.py index 8dd2f851..faff0a75 100644 --- a/src/ydata/sdk/synthesizers/multitable.py +++ b/src/ydata/sdk/synthesizers/multitable.py @@ -1,4 +1,7 @@ +from __future__ import annotations + from time import sleep +from typing import Dict, List, Optional, Union from ydata.datascience.common import PrivacyLevel from ydata.sdk.common.client import Client @@ -31,8 +34,8 @@ class MultiTableSynthesizer(BaseSynthesizer): """ def __init__( - self, write_connector: Connector | UID, uid: UID | None = None, name: str | None = None, - project: Project | None = None, client: Client | None = None): + self, write_connector: Union[Connector, UID], uid: Optional[UID] = None, name: Optional[str] = None, + project: Optional[Project] = None, client: Optional[Client] = None): super().__init__(uid, name, project, client) @@ -41,15 +44,15 @@ def __init__( def fit(self, X: DataSource, privacy_level: PrivacyLevel = PrivacyLevel.HIGH_FIDELITY, - datatype: DataSourceType | str | None = None, - sortbykey: str | list[str] | None = None, - entities: str | list[str] | None = None, - generate_cols: list[str] | None = None, - exclude_cols: list[str] | None = None, - dtypes: dict[str, str | DataType] | None = None, - target: str | None = None, - anonymize: dict | 
None = None, - condition_on: list[str] | None = None) -> None: + datatype: Optional[Union[DataSourceType, str]] = None, + sortbykey: Optional[Union[str, List[str]]] = None, + entities: Optional[Union[str, List[str]]] = None, + generate_cols: Optional[List[str]] = None, + exclude_cols: Optional[List[str]] = None, + dtypes: Optional[Dict[str, Union[str, DataType]]] = None, + target: Optional[str] = None, + anonymize: Optional[dict] = None, + condition_on: Optional[List[str]] = None) -> None: """Fit the synthesizer. The synthesizer accepts as training dataset a YData [`DataSource`][ydata.sdk.datasources.DataSource]. @@ -61,7 +64,7 @@ def fit(self, X: DataSource, self._fit_from_datasource(X) - def sample(self, frac: int | float = 1, write_connector: Connector | UID | None = None) -> None: + def sample(self, frac: Union[int, float] = 1, write_connector: Optional[Union[Connector, UID]] = None) -> None: """Sample from a [`MultiTableSynthesizer`][ydata.sdk.synthesizers.MultiTableSynthesizer] instance. 
The sample is saved in the connector that was provided in the synthesizer initialization @@ -108,7 +111,7 @@ def _create_payload(self) -> dict: return payload - def _check_or_fetch_connector(self, write_connector: Connector | UID) -> Connector: + def _check_or_fetch_connector(self, write_connector: Union[Connector, UID]) -> Connector: self._logger.debug(f'Write connector is {write_connector}') if isinstance(write_connector, str): self._logger.debug(f'Write connector is of type `UID` {write_connector}') diff --git a/src/ydata/sdk/synthesizers/synthesizer.py b/src/ydata/sdk/synthesizers/synthesizer.py index 6a3e91f3..604c3211 100644 --- a/src/ydata/sdk/synthesizers/synthesizer.py +++ b/src/ydata/sdk/synthesizers/synthesizer.py @@ -49,7 +49,9 @@ class BaseSynthesizer(ABC, ModelFactoryMixin): client (Client): (optional) Client to connect to the backend """ - def __init__(self, uid: UID | None = None, name: str | None = None, project: Project | None = None, client: Client | None = None): + def __init__( + self, uid: Optional[UID] = None, name: Optional[str] = None, + project: Optional[Project] = None, client: Optional[Client] = None): self._init_common(client=client) self._model = mSynthesizer(uid=uid, name=name or str(uuid4())) self._project = project @@ -179,7 +181,7 @@ def _validate_datasource_attributes(X: Union[DataSource, pdDataFrame], dataset_a @staticmethod def _metadata_to_payload( datatype: DataSourceType, ds_metadata: Metadata, - dataset_attrs: Optional[DataSourceAttrs] = None, target: str | None = None + dataset_attrs: Optional[DataSourceAttrs] = None, target: Optional[str] = None ) -> dict: """Transform a the metadata and dataset attributes into a valid payload. 
@@ -218,11 +220,11 @@ def _metadata_to_payload( def _fit_from_datasource( self, X: DataSource, - privacy_level: PrivacyLevel | None = None, - dataset_attrs: DataSourceAttrs | None = None, - target: str | None = None, - anonymize: dict | None = None, - condition_on: list[str] | None = None + privacy_level: Optional[PrivacyLevel] = None, + dataset_attrs: Optional[DataSourceAttrs] = None, + target: Optional[str] = None, + anonymize: Optional[dict] = None, + condition_on: Optional[List[str]] = None ) -> None: payload = self._create_payload()