diff --git a/examples/synthesizers/privacy_example.py b/examples/synthesizers/privacy_example.py
index 3299c068..0c71c551 100644
--- a/examples/synthesizers/privacy_example.py
+++ b/examples/synthesizers/privacy_example.py
@@ -4,7 +4,7 @@
 from ydata.sdk.synthesizers import PrivacyLevel, RegularSynthesizer
 
 # Do not forget to add your token as env variables
-os.environ["YDATA_TOKEN"] = ''  # Remove if already defined
+os.environ["YDATA_TOKEN"] = '{insert-your-token}'  # Remove if already defined
 
 
 def main():
@@ -16,12 +16,11 @@ def main():
 
     # We initialize a regular synthesizer
     # As long as the synthesizer does not call `fit`, it exists only locally
-    synth = RegularSynthesizer()
+    synth = RegularSynthesizer(name='Titanic Privacy')
 
     # We train the synthesizer on our dataset setting the privacy level to high
     synth.fit(
         X,
-        name="titanic_synthesizer",
         privacy_level=PrivacyLevel.HIGH_PRIVACY
     )
diff --git a/examples/synthesizers/regular_existing_datasource.py b/examples/synthesizers/regular_existing_datasource.py
new file mode 100644
index 00000000..47196e0f
--- /dev/null
+++ b/examples/synthesizers/regular_existing_datasource.py
@@ -0,0 +1,24 @@
+import os
+
+from ydata.sdk.datasources import DataSource
+from ydata.sdk.synthesizers import RegularSynthesizer
+
+# Authenticate to Fabric to leverage the SDK - https://docs.sdk.ydata.ai/latest/sdk/installation/
+# Make sure to add your token as an env variable.
+os.environ["YDATA_TOKEN"] = '{insert-token}'  # Remove if already defined
+
+
+# In this example, we demonstrate how to train a synthesizer from an existing RDBMS dataset.
+# Make sure to follow the step-by-step guide to create a dataset in Fabric's catalog:
+# https://docs.sdk.ydata.ai/latest/get-started/create_multitable_dataset/
+X = DataSource.get('{insert-datasource-id}')
+
+# Init a regular synthesizer; it will be trained directly on the existing Fabric datasource
+synth = RegularSynthesizer(name='existing_DS')
+
+# Start the training of your synthetic data generator
+synth.fit(X)
+
+# As soon as the training process is completed, you are able to sample synthetic records.
+# `n_samples` sets the number of rows to generate; here we request 1000 rows.
+synth.sample(n_samples=1000)
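For readers following the example above: in the regular flow, `sample` returns a pandas DataFrame (see `_sample(self, payload: Dict) -> pdDataFrame` in synthesizer.py further down), so the result can be captured and persisted. A minimal sketch, not part of the patch; it assumes `synth` is the trained synthesizer from the script above, and the CSV path is hypothetical:

# Sketch only -- continues from regular_existing_datasource.py.
sample = synth.sample(n_samples=1000)  # returns a pandas DataFrame
print(sample.shape)                    # (1000, <number of columns>)
sample.to_csv('synthetic_sample.csv', index=False)  # hypothetical output path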
diff --git a/pyproject.toml b/pyproject.toml
index a93e58c9..a79ae8e0 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -42,7 +42,7 @@ dependencies = [
     "pandas>=1.5.0",
     "prettytable==3.13.*",
     "pydantic>=2.0.0",
-    "typeguard==2.13.3",
+    "typeguard>=2.13.3, <2.14.0",
     "ydata-datascience",
     "requests==2.*",
 ]
diff --git a/src/ydata/sdk/connectors/connector.py b/src/ydata/sdk/connectors/connector.py
index 81081557..2ffa6695 100644
--- a/src/ydata/sdk/connectors/connector.py
+++ b/src/ydata/sdk/connectors/connector.py
@@ -50,8 +50,11 @@ def __init__(
         self, connector_type: Union[ConnectorType, str, None] = None, credentials: Optional[Dict] = None,
         name: Optional[str] = None, project: Optional[Project] = None, client: Optional[Client] = None):
         self._init_common(client=client)
-        self._model = _connector_type_to_model(ConnectorType._init_connector_type(connector_type))._create_model(
-            connector_type, credentials, name, client=client)
+
+        self._model = self.create(connector_type=connector_type,
+                                  credentials=credentials,
+                                  name=name, project=project,
+                                  client=client)
 
         self._project = project
@@ -150,9 +153,13 @@ def create(
 
         payload = {
             "type": connector_type.value,
-            "credentials": credentials.dict(by_alias=True)
+            "credentials": credentials if isinstance(credentials, dict) else credentials.dict(by_alias=True)
         }
-        model = connector_class._create(payload, name, project, client)
+
+        if client is None:
+            model = connector_class._create(payload, name, project)
+        else:
+            model = connector_class._create(payload, name, project, client)
 
         connector = connector_class._init_from_model_data(model)
         connector._project = project
diff --git a/src/ydata/sdk/datasources/_models/datasource.py b/src/ydata/sdk/datasources/_models/datasource.py
index 082efd49..08be2fc5 100644
--- a/src/ydata/sdk/datasources/_models/datasource.py
+++ b/src/ydata/sdk/datasources/_models/datasource.py
@@ -20,10 +20,10 @@ class DataSource:
     connector_type: Optional[str] = None
 
     def __post_init__(self):
-        if self.metadata is not None:
+        if self.metadata is not None and not isinstance(self.metadata, Metadata):
             self.metadata = Metadata(**self.metadata)
 
-        if self.status is not None:
+        if self.status is not None and not isinstance(self.status, Status):
             self.status = Status(**self.status)
 
     def to_payload(self):
diff --git a/src/ydata/sdk/datasources/_models/datasources/mysql.py b/src/ydata/sdk/datasources/_models/datasources/mysql.py
index b144d4ca..af99c21d 100644
--- a/src/ydata/sdk/datasources/_models/datasources/mysql.py
+++ b/src/ydata/sdk/datasources/_models/datasources/mysql.py
@@ -1,4 +1,4 @@
-from dataclasses import dataclass
+from dataclasses import asdict, dataclass
 
 from ydata.sdk.datasources._models.datasource import DataSource
 
@@ -10,4 +10,4 @@ class MySQLDataSource(DataSource):
     tables: dict = None
 
     def to_payload(self):
-        self.dict()
+        return asdict(self)
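Two of the fixes above are easy to verify in isolation: plain dataclasses have no `.dict()` method (the old `to_payload` would raise `AttributeError`, and discarded its result besides), whereas `dataclasses.asdict` returns the payload dict; and the new `isinstance` guards in `__post_init__` make construction idempotent when an already-typed `Metadata`/`Status` is passed in. A minimal sketch of the `asdict` behavior, using a stand-in class rather than the SDK's own:

from dataclasses import asdict, dataclass

@dataclass
class ExamplePayload:        # stand-in class, not the SDK's MySQLDataSource
    database: str = 'mydb'
    tables: dict = None

    def to_payload(self):
        return asdict(self)  # a plain dataclass has no .dict() method

print(ExamplePayload().to_payload())  # {'database': 'mydb', 'tables': None}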
diff --git a/src/ydata/sdk/datasources/_models/status.py b/src/ydata/sdk/datasources/_models/status.py
index c830f511..f19674c7 100644
--- a/src/ydata/sdk/datasources/_models/status.py
+++ b/src/ydata/sdk/datasources/_models/status.py
@@ -21,6 +21,7 @@ class MetadataState(StringEnum):
     GENERATING = 'generating'
     FAILED = 'failed'
     AVAILABLE = 'available'
+    UNAVAILABLE = 'unavailable'
 
 
 class ProfilingState(StringEnum):
@@ -29,6 +30,7 @@ class ProfilingState(StringEnum):
     GENERATING = 'generating'
     FAILED = 'failed'
     AVAILABLE = 'available'
+    UNAVAILABLE = 'unavailable'
 
 
 class State(StringEnum):
@@ -73,6 +75,7 @@ class Status(BaseModel):
     validation: Optional[ValidationStatus] = Field(None)
     metadata: Optional[MetadataStatus] = Field(None)
     profiling: Optional[ProfilingStatus] = Field(None)
+    dependentSynthesizersNumber: Optional[int] = Field(None)
 
     @staticmethod
     def unknown() -> "Status":
diff --git a/src/ydata/sdk/datasources/datasource.py b/src/ydata/sdk/datasources/datasource.py
index 968afd2f..1d1b6817 100644
--- a/src/ydata/sdk/datasources/datasource.py
+++ b/src/ydata/sdk/datasources/datasource.py
@@ -59,6 +59,10 @@ def _init_common(self, client: Optional[Client] = None):
         self._client = client
         self._logger = create_logger(__name__, level=LOG_LEVEL)
 
+    @property
+    def client(self):
+        return self._client
+
     @property
     def uid(self) -> UID:
         return self._model.uid
@@ -127,7 +131,8 @@ def get(uid: UID, project: Optional[Project] = None, client: Optional[Client] =
         data: list = response.json()
         datasource_type = CONNECTOR_TO_DATASOURCE.get(
             ConnectorType(data['connector']['type']))
-        datasource = DataSource._model_from_api(data, datasource_type)
+        model = DataSource._model_from_api(data, datasource_type)
+        datasource = DataSource._init_from_model_data(model)
         datasource._project = project
 
         return datasource
@@ -152,6 +157,7 @@ def create(
             DataSource
         """
         datasource_type = CONNECTOR_TO_DATASOURCE.get(connector.type)
+
         return cls._create(
             connector=connector, datasource_type=datasource_type, datatype=datatype, config=config,
             name=name, project=project, wait_for_metadata=wait_for_metadata, client=client)
@@ -163,8 +169,14 @@ def _create(
             name: Optional[str] = None, project: Optional[Project] = None,
             wait_for_metadata: bool = True, client: Optional[Client] = None
     ) -> "DataSource":
-        model = DataSource._create_model(
-            connector, datasource_type, datatype, config, name, project, client)
+
+        if client is None:
+            model = DataSource._create_model(
+                connector, datasource_type, datatype, config, name, project)
+        else:
+            model = DataSource._create_model(
+                connector, datasource_type, datatype, config, name, project, client)
+
         datasource = DataSource._init_from_model_data(model)
 
         if wait_for_metadata:
@@ -201,9 +213,10 @@ def _create_model(
 
     @staticmethod
     def _wait_for_metadata(datasource):
         logger = create_logger(__name__, level=LOG_LEVEL)
+        client = datasource._client
         while State(datasource.status.state) not in [State.AVAILABLE, State.FAILED, State.UNAVAILABLE]:
             logger.info(f'Calculating metadata [{datasource.status}]')
-            datasource = DataSource.get(uid=datasource.uid, client=datasource._client)
+            datasource = DataSource.get(uid=datasource.uid, client=client)
             sleep(BACKOFF)
         return datasource
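The `_wait_for_metadata` change is subtle: each iteration rebinds `datasource` to a brand-new object returned by `DataSource.get`, whose `_client` may differ from the one the caller started with, so the client is now captured once before the loop. The same capture-before-poll pattern in isolation; a generic sketch with illustrative names, not SDK API:

import time

def poll_until_ready(resource, fetch, backoff=1.0):
    # Capture the client once: `fetch` returns a fresh object each call,
    # and that object's own client attribute may not carry over.
    client = resource.client
    while resource.status not in ('available', 'failed', 'unavailable'):
        resource = fetch(resource.uid, client=client)
        time.sleep(backoff)
    return resource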
diff --git a/src/ydata/sdk/synthesizers/synthesizer.py b/src/ydata/sdk/synthesizers/synthesizer.py
index 7327b3b6..c10e98a5 100644
--- a/src/ydata/sdk/synthesizers/synthesizer.py
+++ b/src/ydata/sdk/synthesizers/synthesizer.py
@@ -172,7 +172,7 @@ def _validate_datasource_attributes(X, dataset_attrs: DataSourceAttrs, datatype:
             raise DataTypeMissingError(
                 "Argument `datatype` is mandatory for pandas.DataFrame training data")
         elif datatype == DataSourceType.MULTITABLE:
-            tables = [t for t in X.tables.keys()]  # noqa: F841
+            tables = [t for t in X._model.tables.keys()]  # noqa: F841
             # Does it make sense to add more validations here?
         else:
             columns = [c.name for c in X.metadata.columns]
@@ -320,14 +320,14 @@ def _sample(self, payload: Dict) -> pdDataFrame:
         response = self._client.post(
             f"/synthesizer/{self.uid}/sample", json=payload, project=self._project)
 
-        data: Dict = response.json()
-        sample_uid = data.get('uid')
+        data = response.json()
+        sample_uid: str = data.get('uid')
         sample_status = None
         while sample_status not in ['finished', 'failed']:
             self._logger.info('Sampling from the synthesizer...')
             response = self._client.get(
                 f'/synthesizer/{self.uid}/history', project=self._project)
-            history: Dict = response.json()
+            history: list = response.json()
             sample_data = next((s for s in history if s.get('uid') == sample_uid), None)
             sample_status = sample_data.get('status', {}).get('state')
             sleep(BACKOFF)
diff --git a/src/ydata/sdk/utils/logger.py b/src/ydata/sdk/utils/logger.py
index 5f3b61a9..4a341744 100644
--- a/src/ydata/sdk/utils/logger.py
+++ b/src/ydata/sdk/utils/logger.py
@@ -31,7 +31,7 @@ def get_datasource_info(dataframe, datatype):
         nrows, ncols = dataframe.shape[0], dataframe.shape[1]
         ntables = None  # calculate the number of rows and cols
     else:
-        connector = dataframe.connector_type
+        connector = dataframe._model.connector_type
         if DataSourceType(datatype) != DataSourceType.MULTITABLE:
             nrows = dataframe.metadata.number_of_rows
             ncols = len(dataframe.metadata.columns)
@@ -39,7 +39,7 @@ def get_datasource_info(dataframe, datatype):
         else:
             nrows = 0
             ncols = 0
-            ntables = len(dataframe.tables.keys())
+            ntables = len(dataframe._model.tables.keys())
 
     return connector, nrows, ncols, ntables
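The `history: list` re-annotation in `_sample` documents what the loop already relied on: the history endpoint returns a list of run records, which is scanned for the entry matching `sample_uid`. The lookup pattern on its own, with fabricated rows for illustration only (note the sketch also guards the `None` case that `next(..., None)` can produce):

# Fabricated history rows, for illustration only.
history = [
    {'uid': 'a1b2', 'status': {'state': 'finished'}},
    {'uid': 'c3d4', 'status': {'state': 'running'}},
]
sample_uid = 'c3d4'
sample_data = next((s for s in history if s.get('uid') == sample_uid), None)
state = sample_data.get('status', {}).get('state') if sample_data else None
print(state)  # -> running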