Skip to content

Commit

Permalink
fix: Error while training a synthetic data generation (#91)
Browse files Browse the repository at this point in the history
  • Loading branch information
portellaa authored Mar 4, 2024
1 parent 8bead75 commit fc39ea9
Show file tree
Hide file tree
Showing 6 changed files with 50 additions and 19 deletions.
14 changes: 14 additions & 0 deletions src/ydata/sdk/common/status.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
from typing import Generic, Optional, TypeVar

from pydantic import Field

from .model import BaseModel

T = TypeVar("T")


class GenericStateErrorStatus(BaseModel, Generic[T]):
state: Optional[T] = Field(None)

class Config:
use_enum_values = True
10 changes: 9 additions & 1 deletion src/ydata/sdk/datasources/_models/datasource.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from dataclasses import dataclass
from typing import Optional

from pydantic.dataclasses import dataclass

from ydata.sdk.common.types import UID
from ydata.sdk.datasources._models.datatype import DataSourceType
from ydata.sdk.datasources._models.metadata.metadata import Metadata
Expand All @@ -16,5 +17,12 @@ class DataSource:
metadata: Optional[Metadata] = None
status: Optional[Status] = None

def __post_init__(self):
if self.metadata is not None:
self.metadata = Metadata(**self.metadata)

if self.status is not None:
self.status = Status(**self.status)

def to_payload(self):
return {}
22 changes: 18 additions & 4 deletions src/ydata/sdk/datasources/_models/status.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,10 @@
from typing import Optional

from pydantic import Field

from ydata.core.enum import StringEnum
from ydata.sdk.common.model import BaseModel
from ydata.sdk.common.status import GenericStateErrorStatus


class ValidationState(StringEnum):
Expand Down Expand Up @@ -58,8 +63,17 @@ class State(StringEnum):
"""


ValidationStatus = GenericStateErrorStatus[ValidationState]
MetadataStatus = GenericStateErrorStatus[MetadataState]
ProfilingStatus = GenericStateErrorStatus[ProfilingState]


class Status(BaseModel):
state: State
validation: ValidationState
metadata: MetadataState
profiling: ProfilingState
state: Optional[State] = Field(None)
validation: Optional[ValidationStatus] = Field(None)
metadata: Optional[MetadataStatus] = Field(None)
profiling: Optional[ProfilingStatus] = Field(None)

@staticmethod
def unknown() -> "Status":
return Status(state=Status.State.UNKNOWN)
5 changes: 3 additions & 2 deletions src/ydata/sdk/datasources/datasource.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,10 +74,11 @@ def project(self) -> Project:
@property
def status(self) -> Status:
try:
self._model = self.get(self._model.uid, self._client)._model
self._model = self.get(uid=self._model.uid,
project=self._project, client=self._client)._model
return self._model.status
except Exception: # noqa: PIE786
return Status.UNKNOWN
return Status.unknown()

@property
def metadata(self) -> Optional[Metadata]:
Expand Down
12 changes: 2 additions & 10 deletions src/ydata/sdk/synthesizers/_models/status.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,10 @@
from typing import Generic, Optional, TypeVar
from typing import Optional

from pydantic import Field

from ydata.core.enum import StringEnum
from ydata.sdk.common.model import BaseModel

T = TypeVar("T")


class GenericStateErrorStatus(BaseModel, Generic[T]):
state: Optional[T] = Field(None)

class Config:
use_enum_values = True
from ydata.sdk.common.status import GenericStateErrorStatus


class PrepareState(StringEnum):
Expand Down
6 changes: 4 additions & 2 deletions src/ydata/sdk/synthesizers/synthesizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -390,8 +390,10 @@ def _already_fitted(self) -> bool:
True if the synthesizer is instanciated
"""

return self._is_initialized() and (self._model.status and self._model.status.training and
self._model.status.training.state is not [TrainingState.PREPARING])
return self._is_initialized() and \
(self._model.status is not None
and self._model.status.training is not None
and self._model.status.training.state is not [TrainingState.PREPARING])

@staticmethod
def _resolve_api_status(api_status: Dict) -> Status:
Expand Down

0 comments on commit fc39ea9

Please sign in to comment.