From 7b22974392bae645bb10827542ce2a9160c9689a Mon Sep 17 00:00:00 2001
From: Fabiana <30911746+fabclmnt@users.noreply.github.com>
Date: Wed, 29 Jan 2025 11:22:09 +0000
Subject: [PATCH] feat: update multitable interface & datasources information
 (#136)

* feat: update multitable interface

* fix(linting): code formatting

* fix: remove typeguard Datasource type validation

This does not work for multitable datasets. Needs to be revisited later on.

* fix(linting): code formatting

* chore: update links

* chore: fix linter messages

* chore: fix linting error.

* fix(linting): code formatting

---------

Co-authored-by: Azory YData Bot
---
 .../sdk/datasources/_models/datasource.py     |   2 +
 .../datasources/_models/datasources/mysql.py  |   1 +
 src/ydata/sdk/datasources/datasource.py       |   5 +-
 src/ydata/sdk/synthesizers/multitable.py      |  18 +--
 src/ydata/sdk/synthesizers/regular.py         |   3 +-
 src/ydata/sdk/synthesizers/synthesizer.py     |  68 +++++++-----
 src/ydata/sdk/utils/logger.py                 | 105 ++++++++++++++++++
 7 files changed, 158 insertions(+), 44 deletions(-)
 create mode 100644 src/ydata/sdk/utils/logger.py

diff --git a/src/ydata/sdk/datasources/_models/datasource.py b/src/ydata/sdk/datasources/_models/datasource.py
index 172aa2e4..082efd49 100644
--- a/src/ydata/sdk/datasources/_models/datasource.py
+++ b/src/ydata/sdk/datasources/_models/datasource.py
@@ -16,6 +16,8 @@ class DataSource:
     datatype: Optional[DataSourceType] = None
     metadata: Optional[Metadata] = None
     status: Optional[Status] = None
+    connector_ref: Optional[str] = None
+    connector_type: Optional[str] = None

     def __post_init__(self):
         if self.metadata is not None:
diff --git a/src/ydata/sdk/datasources/_models/datasources/mysql.py b/src/ydata/sdk/datasources/_models/datasources/mysql.py
index 26e5dc47..b144d4ca 100644
--- a/src/ydata/sdk/datasources/_models/datasources/mysql.py
+++ b/src/ydata/sdk/datasources/_models/datasources/mysql.py
@@ -7,6 +7,7 @@ class MySQLDataSource(DataSource):

     query: str = None
+    tables: dict = None

     def to_payload(self):
         self.dict()

diff --git a/src/ydata/sdk/datasources/datasource.py b/src/ydata/sdk/datasources/datasource.py
index 4d457349..968afd2f 100644
--- a/src/ydata/sdk/datasources/datasource.py
+++ b/src/ydata/sdk/datasources/datasource.py
@@ -127,8 +127,7 @@ def get(uid: UID, project: Optional[Project] = None, client: Optional[Client] =
         data: list = response.json()
         datasource_type = CONNECTOR_TO_DATASOURCE.get(
             ConnectorType(data['connector']['type']))
-        model = DataSource._model_from_api(data, datasource_type)
-        datasource = DataSource._init_from_model_data(model)
+        datasource = DataSource._model_from_api(data, datasource_type)
         datasource._project = project
         return datasource

@@ -211,6 +210,8 @@ def _wait_for_metadata(datasource):
     @staticmethod
     def _model_from_api(data: Dict, datasource_type: Type[mDataSource]) -> mDataSource:
         data['datatype'] = data.pop('dataType', None)
+        data['connector_ref'] = data['connector']['uid']
+        data['connector_type'] = data['connector']['type']
         data = filter_dict(datasource_type, data)
         model = datasource_type(**data)
         return model
diff --git a/src/ydata/sdk/synthesizers/multitable.py b/src/ydata/sdk/synthesizers/multitable.py
index e9021528..fa1749e7 100644
--- a/src/ydata/sdk/synthesizers/multitable.py
+++ b/src/ydata/sdk/synthesizers/multitable.py
@@ -1,15 +1,13 @@
 from __future__ import annotations

 from time import sleep
-from typing import Dict, List, Optional, Union
+from typing import Dict, Optional, Union

-from ydata.datascience.common import PrivacyLevel
 from ydata.sdk.common.client import Client
 from ydata.sdk.common.config import BACKOFF
 from ydata.sdk.common.exceptions import ConnectorError, InputError
 from ydata.sdk.common.types import UID, Project
 from ydata.sdk.connectors.connector import Connector, ConnectorType
-from ydata.sdk.datasources import DataSource
 from ydata.sdk.datasources._models.datatype import DataSourceType
 from ydata.sdk.datasources._models.metadata.data_types import DataType
 from ydata.sdk.synthesizers.synthesizer import BaseSynthesizer
@@ -43,17 +41,10 @@ def __init__(
         connector = self._check_or_fetch_connector(write_connector)
         self.__write_connector = connector.uid

-    def fit(self, X: DataSource,
-            privacy_level: PrivacyLevel = PrivacyLevel.HIGH_FIDELITY,
+    def fit(self, X,
             datatype: Optional[Union[DataSourceType, str]] = None,
-            sortbykey: Optional[Union[str, List[str]]] = None,
-            entities: Optional[Union[str, List[str]]] = None,
-            generate_cols: Optional[List[str]] = None,
-            exclude_cols: Optional[List[str]] = None,
             dtypes: Optional[Dict[str, Union[str, DataType]]] = None,
-            target: Optional[str] = None,
-            anonymize: Optional[dict] = None,
-            condition_on: Optional[List[str]] = None) -> None:
+            anonymize: Optional[dict] = None) -> None:
         """Fit the synthesizer.

         The synthesizer accepts as training dataset a YData [`DataSource`][ydata.sdk.datasources.DataSource].
@@ -62,8 +53,7 @@ def fit(self, X: DataSource,
         Arguments:
             X (DataSource): DataSource to Train
         """
-
-        self._fit_from_datasource(X, datatype=DataSourceType.MULTITABLE)
+        super().fit(X, datatype=DataSourceType.MULTITABLE)

     def sample(self, frac: Union[int, float] = 1, write_connector: Optional[Union[Connector, UID]] = None) -> None:
         """Sample from a [`MultiTableSynthesizer`][ydata.sdk.synthesizers.MultiTableSynthesizer]
diff --git a/src/ydata/sdk/synthesizers/regular.py b/src/ydata/sdk/synthesizers/regular.py
index c0dd4719..9286a5f0 100644
--- a/src/ydata/sdk/synthesizers/regular.py
+++ b/src/ydata/sdk/synthesizers/regular.py
@@ -4,7 +4,6 @@

 from ydata.datascience.common import PrivacyLevel
 from ydata.sdk.common.exceptions import InputError
-from ydata.sdk.datasources import DataSource
 from ydata.sdk.datasources._models.datatype import DataSourceType
 from ydata.sdk.datasources._models.metadata.data_types import DataType
 from ydata.sdk.synthesizers.synthesizer import BaseSynthesizer
@@ -33,7 +32,7 @@ def sample(self, n_samples: int = 1, condition_on: Optional[dict] = None) -> pdD
         }
         return self._sample(payload=payload)

-    def fit(self, X: Union[DataSource, pdDataFrame],
+    def fit(self, X,
             privacy_level: PrivacyLevel = PrivacyLevel.HIGH_FIDELITY,
             entities: Optional[Union[str, List[str]]] = None,
             generate_cols: Optional[List[str]] = None,
diff --git a/src/ydata/sdk/synthesizers/synthesizer.py b/src/ydata/sdk/synthesizers/synthesizer.py
index c2929ec8..7327b3b6 100644
--- a/src/ydata/sdk/synthesizers/synthesizer.py
+++ b/src/ydata/sdk/synthesizers/synthesizer.py
@@ -17,7 +17,6 @@
 from ydata.sdk.common.logger import create_logger
 from ydata.sdk.common.types import UID, Project
 from ydata.sdk.connectors import LocalConnector
-from ydata.sdk.datasources import DataSource, LocalDataSource
 from ydata.sdk.datasources._models.attributes import DataSourceAttrs
 from ydata.sdk.datasources._models.datatype import DataSourceType
 from ydata.sdk.datasources._models.metadata.data_types import DataType
@@ -27,8 +26,11 @@
 from ydata.sdk.synthesizers._models.synthesizer import Synthesizer as mSynthesizer
 from ydata.sdk.synthesizers._models.synthesizers_list import SynthesizersList
 from ydata.sdk.synthesizers.anonymizer import build_and_validate_anonimization
+from ydata.sdk.utils.logger import SDKLogger
 from ydata.sdk.utils.model_mixin import ModelFactoryMixin

+logger = SDKLogger(name="SynthesizersLogger")
+

 @typechecked
 class BaseSynthesizer(ABC, ModelFactoryMixin):
@@ -65,7 +67,7 @@ def _init_common(self, client: Optional[Client] = None):
     def project(self) -> Project:
         return self._project or self._client.project

-    def fit(self, X: Union[DataSource, pdDataFrame],
+    def fit(self, X,
             privacy_level: PrivacyLevel = PrivacyLevel.HIGH_FIDELITY,
             datatype: Optional[Union[DataSourceType, str]] = None,
             sortbykey: Optional[Union[str, List[str]]] = None,
@@ -100,6 +102,11 @@ def fit(self, X: Union[DataSource, pdDataFrame],
             anonymize (Optional[str]): (optional) fields to anonymize and the anonymization strategy
             condition_on: (Optional[List[str]]): (optional) list of features to condition upon
         """
+
+        logger.info(dataframe=X,
+                    datatype=datatype.value,
+                    method='synthesizer')
+
         if self._already_fitted():
             raise AlreadyFittedError()

@@ -107,10 +114,12 @@ def fit(self, X: Union[DataSource, pdDataFrame],

         dataset_attrs = self._init_datasource_attributes(
             sortbykey, entities, generate_cols, exclude_cols, dtypes)
+
         self._validate_datasource_attributes(X, dataset_attrs, datatype, target)

         # If the training data is a pandas dataframe, we first need to create a data source and then the instance
         if isinstance(X, pdDataFrame):
+            from ydata.sdk.datasources import LocalDataSource
             if X.empty:
                 raise EmptyDataError("The DataFrame is empty")
             self._logger.info('creating local connector with pandas dataframe')
@@ -131,9 +140,12 @@ def fit(self, X: Union[DataSource, pdDataFrame],
         if isinstance(dataset_attrs, dict):
             dataset_attrs = DataSourceAttrs(**dataset_attrs)

-        self._fit_from_datasource(
-            X=_X, datatype=datatype, dataset_attrs=dataset_attrs, target=target,
-            anonymize=anonymize, privacy_level=privacy_level, condition_on=condition_on)
+        if datatype == DataSourceType.MULTITABLE:
+            self._fit_from_datasource(_X, datatype=DataSourceType.MULTITABLE)
+        else:
+            self._fit_from_datasource(
+                X=_X, datatype=datatype, dataset_attrs=dataset_attrs, target=target,
+                anonymize=anonymize, privacy_level=privacy_level, condition_on=condition_on)

     @staticmethod
     def _init_datasource_attributes(
@@ -152,37 +164,41 @@ def _init_datasource_attributes(
         return DataSourceAttrs(**dataset_attrs)

     @staticmethod
-    def _validate_datasource_attributes(X: Union[DataSource, pdDataFrame], dataset_attrs: DataSourceAttrs, datatype: DataSourceType, target: Optional[str]):
+    def _validate_datasource_attributes(X, dataset_attrs: DataSourceAttrs, datatype: DataSourceType, target: Optional[str]):
         columns = []
         if isinstance(X, pdDataFrame):
             columns = X.columns
             if datatype is None:
                 raise DataTypeMissingError(
                     "Argument `datatype` is mandatory for pandas.DataFrame training data")
+        elif datatype == DataSourceType.MULTITABLE:
+            tables = [t for t in X.tables.keys()]  # noqa: F841
+            # Does it make sense to add more validations here?
         else:
             columns = [c.name for c in X.metadata.columns]

-        if target is not None and target not in columns:
-            raise DataSourceAttrsError(
-                "Invalid target: column '{target}' does not exist")
-
-        if datatype == DataSourceType.TIMESERIES:
-            if not dataset_attrs.sortbykey:
+        if datatype != DataSourceType.MULTITABLE:
+            if target is not None and target not in columns:
                 raise DataSourceAttrsError(
-                    "The argument `sortbykey` is mandatory for timeseries datasource.")
-
-            invalid_fields = {}
-            for field, v in dataset_attrs.dict().items():
-                field_columns = v if field != 'dtypes' else v.keys()
-                not_in_cols = [c for c in field_columns if c not in columns]
-                if len(not_in_cols) > 0:
-                    invalid_fields[field] = not_in_cols
+                    f"Invalid target: column '{target}' does not exist")

-            if len(invalid_fields) > 0:
-                error_msgs = ["\t- Field '{}': columns {} do not exist".format(
-                    f, ', '.join(v)) for f, v in invalid_fields.items()]
-                raise DataSourceAttrsError(
-                    "The dataset attributes are invalid:\n {}".format('\n'.join(error_msgs)))
+            if datatype == DataSourceType.TIMESERIES:
+                if not dataset_attrs.sortbykey:
+                    raise DataSourceAttrsError(
+                        "The argument `sortbykey` is mandatory for timeseries datasource.")
+
+                invalid_fields = {}
+                for field, v in dataset_attrs.dict().items():
+                    field_columns = v if field != 'dtypes' else v.keys()
+                    not_in_cols = [c for c in field_columns if c not in columns]
+                    if len(not_in_cols) > 0:
+                        invalid_fields[field] = not_in_cols
+
+                if len(invalid_fields) > 0:
+                    error_msgs = ["\t- Field '{}': columns {} do not exist".format(
+                        f, ', '.join(v)) for f, v in invalid_fields.items()]
+                    raise DataSourceAttrsError(
+                        "The dataset attributes are invalid:\n {}".format('\n'.join(error_msgs)))

     @staticmethod
     def _metadata_to_payload(
@@ -225,7 +241,7 @@ def _metadata_to_payload(

     def _fit_from_datasource(
         self,
-        X: DataSource,
+        X,
         datatype: DataSourceType,
         privacy_level: Optional[PrivacyLevel] = None,
         dataset_attrs: Optional[DataSourceAttrs] = None,
diff --git a/src/ydata/sdk/utils/logger.py b/src/ydata/sdk/utils/logger.py
new file mode 100644
index 00000000..5f3b61a9
--- /dev/null
+++ b/src/ydata/sdk/utils/logger.py
@@ -0,0 +1,105 @@
+"""
+    This file contains the logic for the SDK logger and its helper functions.
+"""
+import contextlib
+import logging
+import os
+import platform
+import subprocess
+
+import pandas as pd
+import requests
+
+from ydata.sdk import __version__
+from ydata.sdk.datasources._models.datatype import DataSourceType
+
+
+def is_running_in_databricks():
+    mask = "DATABRICKS_RUNTIME_VERSION" in os.environ
+    if "DATABRICKS_RUNTIME_VERSION" in os.environ:
+        return os.environ["DATABRICKS_RUNTIME_VERSION"]
+    else:
+        return str(mask)
+
+
+def get_datasource_info(dataframe, datatype):
+    """
+        Calculate the datasource information required for analytics.
+    """
+    if isinstance(dataframe, pd.DataFrame):
+        connector = 'csv'
+        nrows, ncols = dataframe.shape[0], dataframe.shape[1]
+        ntables = None  # table count does not apply to a single DataFrame
+    else:
+        connector = dataframe.connector_type
+        if DataSourceType(datatype) != DataSourceType.MULTITABLE:
+            nrows = dataframe.metadata.number_of_rows
+            ncols = len(dataframe.metadata.columns)
+            ntables = 1
+        else:
+            nrows = 0
+            ncols = 0
+            ntables = len(dataframe.tables.keys())
+    return connector, nrows, ncols, ntables
+
+
+def analytics_features(datatype: str, connector: str, nrows: int, ncols: int, ntables: int, method: str, dbx: str) -> None:
+    """
+        Send usage metrics and analytics for ydata-fabric-sdk.
+    """
+    endpoint = "https://packages.ydata.ai/ydata-fabric-sdk?"
+    package_version = __version__
+
+    if (
+        bool(os.getenv("YDATA_FABRIC_SDK_NO_ANALYTICS")
+             ) is not True and package_version != "0.0.dev0"
+    ):
+        try:
+            subprocess.check_output("nvidia-smi")
+            gpu_present = True
+        except Exception:
+            gpu_present = False
+
+        python_version = ".".join(platform.python_version().split(".")[:2])
+
+        with contextlib.suppress(Exception):
+            request_message = (
+                f"{endpoint}python_version={python_version}"
+                f"&datatype={datatype}"
+                f"&connector={connector}"
+                f"&ncols={ncols}"
+                f"&nrows={nrows}"
+                f"&ntables={ntables}"
+                f"&method={method}"
+                f"&os={platform.system()}"
+                f"&gpu={str(gpu_present)}"
+                f"&dbx={dbx}"
+            )
+
+            requests.get(request_message)
+
+
+class SDKLogger(logging.Logger):
+    def __init__(self, name: str, level: int = logging.INFO):
+        super().__init__(name, level)
+
+    def info(self, dataframe, datatype: str, method: str) -> None:  # noqa: ANN001
+
+        dbx = is_running_in_databricks()
+
+        connector, nrows, ncols, ntables = get_datasource_info(dataframe, datatype)
+
+        analytics_features(
+            datatype=datatype,
+            method=method,
+            connector=connector,
+            nrows=nrows,
+            ncols=ncols,
+            ntables=ntables,
+            dbx=dbx
+        )
+
+        super().info(
+            f"[SYNTHESIZER] Creating a synthesizer with the following characteristics "
+            f"- {datatype} | {method} | {connector}."
+        )
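
For reference, a minimal usage sketch of the multi-table interface introduced above, assuming an existing multi-table datasource and a write connector; the two UIDs are placeholders, and YDATA_FABRIC_SDK_NO_ANALYTICS is the opt-out variable read by the new logger module:

import os

from ydata.sdk.datasources import DataSource
from ydata.sdk.synthesizers import MultiTableSynthesizer

# Opt out of the usage analytics collected by src/ydata/sdk/utils/logger.py.
os.environ["YDATA_FABRIC_SDK_NO_ANALYTICS"] = "1"

# Fetch an existing multi-table datasource ('<datasource-uid>' is a placeholder).
datasource = DataSource.get(uid="<datasource-uid>")

# fit() now forwards to BaseSynthesizer.fit with DataSourceType.MULTITABLE,
# so the single-table arguments (privacy_level, sortbykey, target, ...) are gone.
synth = MultiTableSynthesizer(write_connector="<write-connector-uid>")
synth.fit(datasource)

# frac scales the generated database relative to the original; the result is
# written through the write connector rather than returned in memory.
synth.sample(frac=1)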