Skip to content

Commit

Permalink
ds: add note_glide field to model
Browse files Browse the repository at this point in the history
  • Loading branch information
SoulMelody committed Sep 13, 2024
1 parent 1d368d7 commit 279414c
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 16 deletions.
17 changes: 2 additions & 15 deletions libresvip/plugins/acep/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,9 @@

from more_itertools import batched, minmax
from pydantic import (
Discriminator,
Field,
FieldSerializationInfo,
RootModel,
Tag,
ValidationInfo,
field_serializer,
field_validator,
Expand Down Expand Up @@ -338,20 +336,9 @@ class AcepChordTrack(AcepTrackProperties, BaseModel):
patterns: list[AcepChordPattern] = Field(default_factory=list)


def get_discriminator_value(v: Any) -> str:
if isinstance(v, dict):
return v.get("type", None)
return getattr(v, "type_", None)


AcepTrack = Annotated[
Union[
Annotated[AcepAudioTrack, Tag("audio")],
Annotated[AcepEmptyTrack, Tag("empty")],
Annotated[AcepVocalTrack, Tag("sing")],
Annotated[AcepChordTrack, Tag("chord")],
],
Discriminator(get_discriminator_value),
Union[AcepAudioTrack, AcepEmptyTrack, AcepVocalTrack, AcepChordTrack],
Field(discriminator="type_"),
]


Expand Down
25 changes: 24 additions & 1 deletion libresvip/plugins/ds/model.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import enum
from typing import Literal, Optional, Union

from pydantic import (
Expand All @@ -12,13 +13,20 @@
from libresvip.model.base import BaseModel


class GlideStyle(enum.Enum):
NONE = "none"
UP = "up"
DOWN = "down"


class DsItem(BaseModel):
text: list[str]
ph_seq: list[str]
note_seq: list[str]
note_dur: Optional[list[float]] = None
note_dur_seq: Optional[list[float]] = None
note_slur: Optional[list[int]] = None
note_glide: Optional[list[GlideStyle]] = None
is_slur_seq: Optional[list[int]] = None
ph_dur: Optional[list[float]] = None
ph_num: Optional[list[int]] = None
Expand All @@ -36,9 +44,16 @@ class DsItem(BaseModel):

@field_validator("text", "note_seq", "ph_seq", mode="before")
@classmethod
def _validate_str_list(cls, value: Optional[str], _info: ValidationInfo) -> list[str]:
def _validate_str_list(cls, value: Optional[str], _info: ValidationInfo) -> Optional[list[str]]:
return None if value is None else value.split()

@field_validator("note_glide", mode="before")
@classmethod
def _validate_glide_list(
cls, value: Optional[str], _info: ValidationInfo
) -> Optional[list[GlideStyle]]:
return None if value is None else [GlideStyle(x) for x in value.split()]

@field_validator(
"f0_seq",
"ph_dur",
Expand Down Expand Up @@ -76,6 +91,14 @@ def _validate_int_list(cls, value: Optional[str], _info: ValidationInfo) -> list
def _serialize_list(cls, value: list[Union[str, int, float]], _info: SerializationInfo) -> str:
return " ".join(str(x) for x in value)

@field_serializer(
"note_glide",
when_used="json-unless-none",
)
@classmethod
def _serialize_glide(cls, value: list[GlideStyle], _info: SerializationInfo) -> str:
return " ".join(x.value for x in value)

@field_validator("spk_mix", mode="before")
@classmethod
def _validate_nested_dict(
Expand Down

0 comments on commit 279414c

Please sign in to comment.