Skip to content

Commit

Permalink
fix: conn datasource creation (#156)
Browse files Browse the repository at this point in the history
* fix: connector and datasource creation

* chore: add code example for SDK

* fix(linting): code formatting

* fix: connector creation and synthesizer get from api request.

* fix(linting): code formatting

* fix: fix linting

* fix: fix status and metrics to create synths from existing DS.

* fix(linting): code formatting

* fix: fix typeguard version

* fix: remove unused code.

---------

Co-authored-by: Azory YData Bot <[email protected]>
  • Loading branch information
fabclmnt and azory-ydata authored Feb 7, 2025
1 parent 91616da commit ec77c29
Show file tree
Hide file tree
Showing 10 changed files with 71 additions and 22 deletions.
5 changes: 2 additions & 3 deletions examples/synthesizers/privacy_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from ydata.sdk.synthesizers import PrivacyLevel, RegularSynthesizer

# Do not forget to add your token as env variables
os.environ["YDATA_TOKEN"] = '<TOKEN>' # Remove if already defined
os.environ["YDATA_TOKEN"] = '{insert-your-token}' # Remove if already defined


def main():
Expand All @@ -16,12 +16,11 @@ def main():

# We initialize a regular synthesizer
# As long as the synthesizer does not call `fit`, it exists only locally
synth = RegularSynthesizer()
synth = RegularSynthesizer(name='Titanic Privacy')

# We train the synthesizer on our dataset setting the privacy level to high
synth.fit(
X,
name="titanic_synthesizer",
privacy_level=PrivacyLevel.HIGH_PRIVACY
)

Expand Down
27 changes: 27 additions & 0 deletions examples/synthesizers/regular_existing_datasource.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
import os

from ydata.sdk.datasources import DataSource
from ydata.sdk.synthesizers import RegularSynthesizer

# Authenticate to Fabric to leverage the SDK - https://docs.sdk.ydata.ai/latest/sdk/installation/
# Make sure to add your token as env variable.
os.environ["YDATA_TOKEN"] = '{insert-token}' # Remove if already defined


# Example: train a synthesizer from a Dataset that already exists in Fabric's catalog.
# Follow the step-by-step guide to create a Dataset in Fabric's catalog: https://docs.sdk.ydata.ai/latest/get-started/create_multitable_dataset/
# Fetch the existing DataSource by its unique identifier (found in Fabric's catalog UI).
X = DataSource.get('{insert-datasource-id}')

# Initialize a regular (single-table) synthesizer. As long as `fit` has not been
# called, the synthesizer exists only locally; `name` labels it in Fabric.
synth = RegularSynthesizer(name='existing_DS')

# Start the training of your synthetic data generator on the existing DataSource.
# NOTE(review): this call blocks/polls until training completes — confirm against SDK docs.
synth.fit(X)

# Once training is complete, draw a synthetic sample.
# `n_samples` is the number of synthetic rows requested (here, 1000 rows).
synth.sample(n_samples=1000)
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ dependencies = [
"pandas>=1.5.0",
"prettytable==3.13.*",
"pydantic>=2.0.0",
"typeguard==2.13.3",
"typeguard>=2.13.3, <2.14.0",
"ydata-datascience",
"requests==2.*",
]
Expand Down
15 changes: 11 additions & 4 deletions src/ydata/sdk/connectors/connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,11 @@ def __init__(
self, connector_type: Union[ConnectorType, str, None] = None, credentials: Optional[Dict] = None,
name: Optional[str] = None, project: Optional[Project] = None, client: Optional[Client] = None):
self._init_common(client=client)
self._model = _connector_type_to_model(ConnectorType._init_connector_type(connector_type))._create_model(
connector_type, credentials, name, client=client)

self._model = self.create(connector_type=connector_type,
credentials=credentials,
name=name, project=project,
client=client)

self._project = project

Expand Down Expand Up @@ -150,9 +153,13 @@ def create(

payload = {
"type": connector_type.value,
"credentials": credentials.dict(by_alias=True)
"credentials": credentials if isinstance(credentials, dict) else credentials.dict(by_alias=True)
}
model = connector_class._create(payload, name, project, client)

if client is None:
model = connector_class._create(payload, name, project)
else:
model = connector_class._create(payload, name, project, client)

connector = connector_class._init_from_model_data(model)
connector._project = project
Expand Down
4 changes: 2 additions & 2 deletions src/ydata/sdk/datasources/_models/datasource.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,10 @@ class DataSource:
connector_type: Optional[str] = None

def __post_init__(self):
if self.metadata is not None:
if self.metadata is not None and not isinstance(self.metadata, Metadata):
self.metadata = Metadata(**self.metadata)

if self.status is not None:
if self.status is not None and not isinstance(self.status, Status):
self.status = Status(**self.status)

def to_payload(self):
Expand Down
4 changes: 2 additions & 2 deletions src/ydata/sdk/datasources/_models/datasources/mysql.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from dataclasses import dataclass
from dataclasses import asdict, dataclass

from ydata.sdk.datasources._models.datasource import DataSource

Expand All @@ -10,4 +10,4 @@ class MySQLDataSource(DataSource):
tables: dict = None

def to_payload(self):
self.dict()
return asdict(self)
3 changes: 3 additions & 0 deletions src/ydata/sdk/datasources/_models/status.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ class MetadataState(StringEnum):
GENERATING = 'generating'
FAILED = 'failed'
AVAILABLE = 'available'
UNAVAILABLE = 'unavailable'


class ProfilingState(StringEnum):
Expand All @@ -29,6 +30,7 @@ class ProfilingState(StringEnum):
GENERATING = 'generating'
FAILED = 'failed'
AVAILABLE = 'available'
UNAVAILABLE = 'unavailable'


class State(StringEnum):
Expand Down Expand Up @@ -73,6 +75,7 @@ class Status(BaseModel):
validation: Optional[ValidationStatus] = Field(None)
metadata: Optional[MetadataStatus] = Field(None)
profiling: Optional[ProfilingStatus] = Field(None)
dependentSynthesizersNumber: Optional[int] = Field(None)

@staticmethod
def unknown() -> "Status":
Expand Down
21 changes: 17 additions & 4 deletions src/ydata/sdk/datasources/datasource.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,10 @@ def _init_common(self, client: Optional[Client] = None):
self._client = client
self._logger = create_logger(__name__, level=LOG_LEVEL)

@property
def client(self):
return self._client

@property
def uid(self) -> UID:
return self._model.uid
Expand Down Expand Up @@ -127,7 +131,8 @@ def get(uid: UID, project: Optional[Project] = None, client: Optional[Client] =
data: list = response.json()
datasource_type = CONNECTOR_TO_DATASOURCE.get(
ConnectorType(data['connector']['type']))
datasource = DataSource._model_from_api(data, datasource_type)
model = DataSource._model_from_api(data, datasource_type)
datasource = DataSource._init_from_model_data(model)
datasource._project = project
return datasource

Expand All @@ -152,6 +157,7 @@ def create(
DataSource
"""
datasource_type = CONNECTOR_TO_DATASOURCE.get(connector.type)

return cls._create(
connector=connector, datasource_type=datasource_type, datatype=datatype, config=config, name=name,
project=project, wait_for_metadata=wait_for_metadata, client=client)
Expand All @@ -163,8 +169,14 @@ def _create(
name: Optional[str] = None, project: Optional[Project] = None, wait_for_metadata: bool = True,
client: Optional[Client] = None
) -> "DataSource":
model = DataSource._create_model(
connector, datasource_type, datatype, config, name, project, client)

if client is None:
model = DataSource._create_model(
connector, datasource_type, datatype, config, name, project)
else:
model = DataSource._create_model(
connector, datasource_type, datatype, config, name, project, client)

datasource = DataSource._init_from_model_data(model)

if wait_for_metadata:
Expand Down Expand Up @@ -201,9 +213,10 @@ def _create_model(
@staticmethod
def _wait_for_metadata(datasource):
logger = create_logger(__name__, level=LOG_LEVEL)
client = datasource._client
while State(datasource.status.state) not in [State.AVAILABLE, State.FAILED, State.UNAVAILABLE]:
logger.info(f'Calculating metadata [{datasource.status}]')
datasource = DataSource.get(uid=datasource.uid, client=datasource._client)
datasource = DataSource.get(uid=datasource.uid, client=client)
sleep(BACKOFF)
return datasource

Expand Down
8 changes: 4 additions & 4 deletions src/ydata/sdk/synthesizers/synthesizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@ def _validate_datasource_attributes(X, dataset_attrs: DataSourceAttrs, datatype:
raise DataTypeMissingError(
"Argument `datatype` is mandatory for pandas.DataFrame training data")
elif datatype == DataSourceType.MULTITABLE:
tables = [t for t in X.tables.keys()] # noqa: F841
tables = [t for t in X._model.tables.keys()] # noqa: F841
# Does it make sense to add more validations here?
else:
columns = [c.name for c in X.metadata.columns]
Expand Down Expand Up @@ -320,14 +320,14 @@ def _sample(self, payload: Dict) -> pdDataFrame:
response = self._client.post(
f"/synthesizer/{self.uid}/sample", json=payload, project=self._project)

data: Dict = response.json()
sample_uid = data.get('uid')
data = response.json()
sample_uid: str = data.get('uid')
sample_status = None
while sample_status not in ['finished', 'failed']:
self._logger.info('Sampling from the synthesizer...')
response = self._client.get(
f'/synthesizer/{self.uid}/history', project=self._project)
history: Dict = response.json()
history: list = response.json()
sample_data = next((s for s in history if s.get('uid') == sample_uid), None)
sample_status = sample_data.get('status', {}).get('state')
sleep(BACKOFF)
Expand Down
4 changes: 2 additions & 2 deletions src/ydata/sdk/utils/logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,15 +31,15 @@ def get_datasource_info(dataframe, datatype):
nrows, ncols = dataframe.shape[0], dataframe.shape[1]
ntables = None # calculate the number of rows and cols
else:
connector = dataframe.connector_type
connector = dataframe._model.connector_type
if DataSourceType(datatype) != DataSourceType.MULTITABLE:
nrows = dataframe.metadata.number_of_rows
ncols = len(dataframe.metadata.columns)
ntables = 1
else:
nrows = 0
ncols = 0
ntables = len(dataframe.tables.keys())
ntables = len(dataframe._model.tables.keys())
return connector, nrows, ncols, ntables


Expand Down

0 comments on commit ec77c29

Please sign in to comment.