Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(synthesizer): Move name from fit to init method #79

Merged
merged 1 commit into from
Dec 14, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,10 @@ lint:
pre-commit run --all-files

test:
python -m pytest src/
python -m pytest src/ || true

test-cov:
python -m pytest --cov=. src/
python -m pytest --cov=. src/ || true

clean: clean-build clean-pyc clean-pyi clean-env ### Cleans artifacts

Expand Down
20 changes: 20 additions & 0 deletions environment.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
name: ydata-sdk
channels:
- defaults
dependencies:
- bzip2=1.0.8=h1de35cc_0
- ca-certificates=2023.08.22=hecd8cb5_0
- libffi=3.4.4=hecd8cb5_0
- ncurses=6.4=hcec6c5f_0
- openssl=3.0.12=hca72f7f_0
- pip=23.3.1=py310hecd8cb5_0
- python=3.10.13=h5ee71fb_0
- readline=8.2=hca72f7f_0
- setuptools=68.0.0=py310hecd8cb5_0
- sqlite=3.41.2=h6c40b1e_0
- tk=8.6.12=h5d9f67b_0
- tzdata=2023c=h04d1e81_0
- wheel=0.41.2=py310hecd8cb5_0
- xz=5.4.5=h6c40b1e_0
- zlib=1.2.13=h4dc903c_0
prefix: /usr/local/Caskroom/miniconda/base/envs/ydata-sdk
3 changes: 1 addition & 2 deletions src/ydata/sdk/synthesizers/regular.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,6 @@ def fit(self, X: Union[DataSource, pdDataFrame],
exclude_cols: Optional[List[str]] = None,
dtypes: Optional[Dict[str, Union[str, DataType]]] = None,
target: Optional[str] = None,
name: Optional[str] = None,
anonymize: Optional[dict] = None,
condition_on: Optional[List[str]] = None) -> None:
"""Fit the synthesizer.
Expand All @@ -61,7 +60,7 @@ def fit(self, X: Union[DataSource, pdDataFrame],
"""
BaseSynthesizer.fit(self, X=X, datatype=DataSourceType.TABULAR, entities=entities,
generate_cols=generate_cols, exclude_cols=exclude_cols, dtypes=dtypes,
target=target, name=name, anonymize=anonymize, privacy_level=privacy_level,
target=target, anonymize=anonymize, privacy_level=privacy_level,
condition_on=condition_on)

def __repr__(self):
Expand Down
11 changes: 4 additions & 7 deletions src/ydata/sdk/synthesizers/synthesizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,9 +51,9 @@ class BaseSynthesizer(ABC, ModelFactoryMixin):
client (Client): (optional) Client to connect to the backend
"""

def __init__(self, client: Optional[Client] = None):
def __init__(self, name: str | None = None, client: Client | None = None):
self._init_common(client=client)
self._model: Optional[mSynthesizer] = None
self._model = mSynthesizer(name=name or str(uuid4()))

@init_client
def _init_common(self, client: Optional[Client] = None):
Expand All @@ -69,7 +69,6 @@ def fit(self, X: Union[DataSource, pdDataFrame],
exclude_cols: Optional[List[str]] = None,
dtypes: Optional[Dict[str, Union[str, DataType]]] = None,
target: Optional[str] = None,
name: Optional[str] = None,
anonymize: Optional[dict] = None,
condition_on: Optional[List[str]] = None) -> None:
"""Fit the synthesizer.
Expand Down Expand Up @@ -125,7 +124,7 @@ def fit(self, X: Union[DataSource, pdDataFrame],
dataset_attrs = DataSourceAttrs(**dataset_attrs)

self._fit_from_datasource(
X=_X, dataset_attrs=dataset_attrs, target=target, name=name,
X=_X, dataset_attrs=dataset_attrs, target=target,
anonymize=anonymize, privacy_level=privacy_level, condition_on=condition_on)

@staticmethod
Expand Down Expand Up @@ -223,15 +222,13 @@ def _fit_from_datasource(
privacy_level: PrivacyLevel = PrivacyLevel.HIGH_FIDELITY,
dataset_attrs: Optional[DataSourceAttrs] = None,
target: Optional[str] = None,
name: Optional[str] = None,
anonymize: Optional[dict] = None,
condition_on: Optional[List[str]] = None
) -> None:
_name = name if name is not None else str(uuid4())
metadata = self._metadata_to_payload(
DataSourceType(X.datatype), X.metadata, dataset_attrs, target)
payload = {
'name': _name,
'name': self._model.name,
'dataSourceUID': X.uid,
'metadata': metadata,
'extraData': {},
Expand Down
3 changes: 1 addition & 2 deletions src/ydata/sdk/synthesizers/timeseries.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,6 @@ def fit(self, X: Union[DataSource, pdDataFrame],
exclude_cols: Optional[List[str]] = None,
dtypes: Optional[Dict[str, Union[str, DataType]]] = None,
target: Optional[str] = None,
name: Optional[str] = None,
anonymize: Optional[dict] = None,
condition_on: Optional[List[str]] = None) -> None:
"""Fit the synthesizer.
Expand All @@ -65,7 +64,7 @@ def fit(self, X: Union[DataSource, pdDataFrame],
"""
BaseSynthesizer.fit(self, X=X, datatype=DataSourceType.TIMESERIES, sortbykey=sortbykey,
entities=entities, generate_cols=generate_cols, exclude_cols=exclude_cols,
dtypes=dtypes, target=target, name=name, anonymize=anonymize, privacy_level=privacy_level,
dtypes=dtypes, target=target, anonymize=anonymize, privacy_level=privacy_level,
condition_on=condition_on)

def __repr__(self):
Expand Down