Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: conn datasource creation #156

Merged
merged 12 commits into from
Feb 7, 2025
40 changes: 40 additions & 0 deletions examples/datasource/mysql_datasource.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
"""
Example to create a MySQL datasource.
"""
import os

from ydata.sdk.connectors import Connector
from ydata.sdk.datasources import DataSource
from ydata.sdk.datasources.datasource import DataSourceType

os.environ["YDATA_TOKEN"] = 'insert-token'

if __name__ == '__main__':

USERNAME = "username"
PASSWORD = "pass"
HOSTNAME = "host"
PORT = "3306"
DATABASE_NAME = "berka"

conn_str = {
"hostname": HOSTNAME,
"username": USERNAME,
"password": PASSWORD,
"port": PORT,
"database": DATABASE_NAME,
}

conn = Connector.get(uid='insert-id')
print(conn)

""" Connector creation example
connector = Connector.create(connector_type=ConnectorType.MYSQL,
credentials=conn_str,
name="MySQL Berka - SDK")
"""

datasource = DataSource(datatype=DataSourceType.TABULAR,
connector=conn,
name="MySQL Berka - SDK")
# query={'query': 'SELECT * FROM trans;'})
5 changes: 2 additions & 3 deletions examples/synthesizers/privacy_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from ydata.sdk.synthesizers import PrivacyLevel, RegularSynthesizer

# Do not forget to add your token as env variables
os.environ["YDATA_TOKEN"] = '<TOKEN>' # Remove if already defined
os.environ["YDATA_TOKEN"] = '{insert-your-token}' # Remove if already defined


def main():
Expand All @@ -16,12 +16,11 @@ def main():

# We initialize a regular synthesizer
# As long as the synthesizer does not call `fit`, it exists only locally
synth = RegularSynthesizer()
synth = RegularSynthesizer(name='Titanic Privacy')

# We train the synthesizer on our dataset setting the privacy level to high
synth.fit(
X,
name="titanic_synthesizer",
privacy_level=PrivacyLevel.HIGH_PRIVACY
)

Expand Down
27 changes: 27 additions & 0 deletions examples/synthesizers/regular_existing_datasource.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
"""Train a RegularSynthesizer from an existing Fabric datasource and sample from it."""
import os

from ydata.sdk.datasources import DataSource
from ydata.sdk.synthesizers import RegularSynthesizer

# Authenticate to Fabric to leverage the SDK - https://docs.sdk.ydata.ai/latest/sdk/installation/
# Make sure to add your token as env variable.
os.environ["YDATA_TOKEN"] = '{insert-token}'  # Remove if already defined


# In this example, we demonstrate how to train a synthesizer from an existing dataset.
# Make sure to follow the step-by-step guide to create a dataset in Fabric's catalog:
# https://docs.sdk.ydata.ai/latest/get-started/create_multitable_dataset/
X = DataSource.get('{insert-datasource-id}')

# Initialize a regular (tabular) synthesizer bound to the datasource above.
# NOTE(review): an earlier comment described a multi-table synthesizer taking a
# `write_connector` argument, but the code uses RegularSynthesizer with no
# connector — confirm which flow this example is meant to demonstrate.
synth = RegularSynthesizer(name='existing_DS')

# Start the training of your synthetic data generator.
synth.fit(X)

# Once training completes, request a synthetic sample.
# NOTE(review): n_samples is an absolute row count here (1000 rows), not a
# percentage of the original dataset as the original comment suggested.
synth.sample(n_samples=1000)
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ dependencies = [
"pandas>=1.5.0",
"prettytable==3.13.*",
"pydantic>=2.0.0",
"typeguard==2.13.3",
"typeguard>=2.13.3",
"ydata-datascience",
"requests==2.*",
]
Expand Down
18 changes: 14 additions & 4 deletions src/ydata/sdk/connectors/connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,14 @@ def __init__(
self, connector_type: Union[ConnectorType, str, None] = None, credentials: Optional[Dict] = None,
name: Optional[str] = None, project: Optional[Project] = None, client: Optional[Client] = None):
self._init_common(client=client)
self._model = _connector_type_to_model(ConnectorType._init_connector_type(connector_type))._create_model(
connector_type, credentials, name, client=client)

self._model = self.create(connector_type=connector_type,
credentials=credentials,
name=name, project=project,
client=client)

# self._model = _connector_type_to_model(ConnectorType._init_connector_type(connector_type))._create_model(
# connector_type, credentials, name, client=client)

self._project = project

Expand Down Expand Up @@ -150,9 +156,13 @@ def create(

payload = {
"type": connector_type.value,
"credentials": credentials.dict(by_alias=True)
"credentials": credentials if isinstance(credentials, dict) else credentials.dict(by_alias=True)
}
model = connector_class._create(payload, name, project, client)

if client is None:
model = connector_class._create(payload, name, project)
else:
model = connector_class._create(payload, name, project, client)

connector = connector_class._init_from_model_data(model)
connector._project = project
Expand Down
4 changes: 2 additions & 2 deletions src/ydata/sdk/datasources/_models/datasource.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,10 @@ class DataSource:
connector_type: Optional[str] = None

def __post_init__(self):
if self.metadata is not None:
if self.metadata is not None and not isinstance(self.metadata, Metadata):
self.metadata = Metadata(**self.metadata)

if self.status is not None:
if self.status is not None and not isinstance(self.status, Status):
self.status = Status(**self.status)

def to_payload(self):
Expand Down
4 changes: 2 additions & 2 deletions src/ydata/sdk/datasources/_models/datasources/mysql.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from dataclasses import dataclass
from dataclasses import asdict, dataclass

from ydata.sdk.datasources._models.datasource import DataSource

Expand All @@ -10,4 +10,4 @@ class MySQLDataSource(DataSource):
tables: dict = None

def to_payload(self):
self.dict()
return asdict(self)
3 changes: 3 additions & 0 deletions src/ydata/sdk/datasources/_models/status.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ class MetadataState(StringEnum):
GENERATING = 'generating'
FAILED = 'failed'
AVAILABLE = 'available'
UNAVAILABLE = 'unavailable'


class ProfilingState(StringEnum):
Expand All @@ -29,6 +30,7 @@ class ProfilingState(StringEnum):
GENERATING = 'generating'
FAILED = 'failed'
AVAILABLE = 'available'
UNAVAILABLE = 'unavailable'


class State(StringEnum):
Expand Down Expand Up @@ -73,6 +75,7 @@ class Status(BaseModel):
validation: Optional[ValidationStatus] = Field(None)
metadata: Optional[MetadataStatus] = Field(None)
profiling: Optional[ProfilingStatus] = Field(None)
dependentSynthesizersNumber: Optional[int] = Field(None)

@staticmethod
def unknown() -> "Status":
Expand Down
21 changes: 17 additions & 4 deletions src/ydata/sdk/datasources/datasource.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,10 @@ def _init_common(self, client: Optional[Client] = None):
self._client = client
self._logger = create_logger(__name__, level=LOG_LEVEL)

@property
def client(self):
return self._client

@property
def uid(self) -> UID:
return self._model.uid
Expand Down Expand Up @@ -127,7 +131,8 @@ def get(uid: UID, project: Optional[Project] = None, client: Optional[Client] =
data: list = response.json()
datasource_type = CONNECTOR_TO_DATASOURCE.get(
ConnectorType(data['connector']['type']))
datasource = DataSource._model_from_api(data, datasource_type)
model = DataSource._model_from_api(data, datasource_type)
datasource = DataSource._init_from_model_data(model)
datasource._project = project
return datasource

Expand All @@ -152,6 +157,7 @@ def create(
DataSource
"""
datasource_type = CONNECTOR_TO_DATASOURCE.get(connector.type)

return cls._create(
connector=connector, datasource_type=datasource_type, datatype=datatype, config=config, name=name,
project=project, wait_for_metadata=wait_for_metadata, client=client)
Expand All @@ -163,8 +169,14 @@ def _create(
name: Optional[str] = None, project: Optional[Project] = None, wait_for_metadata: bool = True,
client: Optional[Client] = None
) -> "DataSource":
model = DataSource._create_model(
connector, datasource_type, datatype, config, name, project, client)

if client is None:
model = DataSource._create_model(
connector, datasource_type, datatype, config, name, project)
else:
model = DataSource._create_model(
connector, datasource_type, datatype, config, name, project, client)

datasource = DataSource._init_from_model_data(model)

if wait_for_metadata:
Expand Down Expand Up @@ -201,9 +213,10 @@ def _create_model(
@staticmethod
def _wait_for_metadata(datasource):
logger = create_logger(__name__, level=LOG_LEVEL)
client = datasource._client
while State(datasource.status.state) not in [State.AVAILABLE, State.FAILED, State.UNAVAILABLE]:
logger.info(f'Calculating metadata [{datasource.status}]')
datasource = DataSource.get(uid=datasource.uid, client=datasource._client)
datasource = DataSource.get(uid=datasource.uid, client=client)
sleep(BACKOFF)
return datasource

Expand Down
8 changes: 4 additions & 4 deletions src/ydata/sdk/synthesizers/synthesizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@ def _validate_datasource_attributes(X, dataset_attrs: DataSourceAttrs, datatype:
raise DataTypeMissingError(
"Argument `datatype` is mandatory for pandas.DataFrame training data")
elif datatype == DataSourceType.MULTITABLE:
tables = [t for t in X.tables.keys()] # noqa: F841
tables = [t for t in X._model.tables.keys()] # noqa: F841
# Does it make sense to add more validations here?
else:
columns = [c.name for c in X.metadata.columns]
Expand Down Expand Up @@ -320,14 +320,14 @@ def _sample(self, payload: Dict) -> pdDataFrame:
response = self._client.post(
f"/synthesizer/{self.uid}/sample", json=payload, project=self._project)

data: Dict = response.json()
sample_uid = data.get('uid')
data = response.json()
sample_uid: str = data.get('uid')
sample_status = None
while sample_status not in ['finished', 'failed']:
self._logger.info('Sampling from the synthesizer...')
response = self._client.get(
f'/synthesizer/{self.uid}/history', project=self._project)
history: Dict = response.json()
history: list = response.json()
sample_data = next((s for s in history if s.get('uid') == sample_uid), None)
sample_status = sample_data.get('status', {}).get('state')
sleep(BACKOFF)
Expand Down
4 changes: 2 additions & 2 deletions src/ydata/sdk/utils/logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,15 +31,15 @@ def get_datasource_info(dataframe, datatype):
nrows, ncols = dataframe.shape[0], dataframe.shape[1]
ntables = None # calculate the number of rows and cols
else:
connector = dataframe.connector_type
connector = dataframe._model.connector_type
if DataSourceType(datatype) != DataSourceType.MULTITABLE:
nrows = dataframe.metadata.number_of_rows
ncols = len(dataframe.metadata.columns)
ntables = 1
else:
nrows = 0
ncols = 0
ntables = len(dataframe.tables.keys())
ntables = len(dataframe._model.tables.keys())
return connector, nrows, ncols, ntables


Expand Down