Skip to content

Commit

Permalink
add ability to use Connector as argument
Browse files Browse the repository at this point in the history
  • Loading branch information
portellaa committed Jan 12, 2024
1 parent 41d5969 commit bb339b9
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 6 deletions.
2 changes: 1 addition & 1 deletion src/ydata/sdk/connectors/connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def uid(self) -> UID:
return self._model.uid

@property
def type(self) -> str:
def type(self) -> ConnectorType:
return self._model.type

@staticmethod
Expand Down
28 changes: 23 additions & 5 deletions src/ydata/sdk/synthesizers/multitable.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,9 @@
from ydata.datascience.common import PrivacyLevel
from ydata.sdk.common.client import Client
from ydata.sdk.common.config import BACKOFF
from ydata.sdk.common.exceptions import InputError
from ydata.sdk.common.exceptions import InputError, ConnectorError
from ydata.sdk.common.types import UID, Project
from ydata.sdk.connectors.connector import Connector, ConnectorType
from ydata.sdk.datasources import DataSource
from ydata.sdk.datasources._models.datatype import DataSourceType
from ydata.sdk.datasources._models.metadata.data_types import DataType
Expand All @@ -30,10 +31,11 @@ class MultiTableSynthesizer(BaseSynthesizer):
"""

def __init__(
self, write_connector: UID, uid: UID | None = None, name: str | None = None,
self, write_connector: Connector | UID, uid: UID | None = None, name: str | None = None,
project: Project | None = None, client: Client | None = None):

self.__write_connector = write_connector
connector = self._check_or_fetch_connector(write_connector)
self.__write_connector = connector.uid

super().__init__(uid, name, project, client)

Expand All @@ -59,7 +61,7 @@ def fit(self, X: DataSource,

self._fit_from_datasource(X)

def sample(self, frac: int | float = 1, write_connector: UID | None = None) -> None:
def sample(self, frac: int | float = 1, write_connector: Connector | UID | None = None) -> None:
"""Sample from a [`MultiTableSynthesizer`][ydata.sdk.synthesizers.MultiTableSynthesizer]
instance.
The sample is saved in the connector that was provided in the synthesizer initialization
Expand All @@ -79,7 +81,8 @@ def sample(self, frac: int | float = 1, write_connector: UID | None = None) -> N
}

if write_connector is not None:
payload['writeConnector'] = write_connector
connector = self._check_or_fetch_connector(write_connector)
payload['writeConnector'] = connector.uid

response = self._client.post(
f"/synthesizer/{self.uid}/sample", json=payload, project=self._project)
Expand All @@ -104,3 +107,18 @@ def _create_payload(self) -> dict:
payload['writeConnector'] = self.__write_connector

return payload

def _check_or_fetch_connector(self, write_connector: Connector | UID) -> Connector:
self._logger.debug(f'Write connector is {write_connector}')
if isinstance(write_connector, str):
self._logger.debug(f'Write connector is of type `UID` {write_connector}')
write_connector = Connector.get(write_connector)
self._logger.debug(f'Using fetched connector {write_connector}')

if write_connector.uid is None:
raise InputError("Invalid connector provided as input for write")

if write_connector.type not in [ConnectorType.AZURE_SQL, ConnectorType.MYSQL, ConnectorType.SNOWFLAKE]:
raise ConnectorError(f"Invalid type `{write_connector.type}` for the provided connector")

return write_connector

0 comments on commit bb339b9

Please sign in to comment.