Skip to content

Commit

Permalink
refactor: replace pitch_tags with PiecewiseIntervalDict
Browse files Browse the repository at this point in the history
  • Loading branch information
SoulMelody committed Aug 6, 2024
1 parent b3fd8d7 commit 085a201
Showing 1 changed file with 31 additions and 46 deletions.
77 changes: 31 additions & 46 deletions libresvip/model/pitch_simulator.py
Original file line number Diff line number Diff line change
@@ -1,85 +1,70 @@
import dataclasses
import functools

import portion

from libresvip.core.time_interval import PiecewiseIntervalDict
from libresvip.core.time_sync import TimeSynchronizer
from libresvip.model.base import Note
from libresvip.model.portamento import PortamentoPitch
from libresvip.utils.search import find_last_index


@dataclasses.dataclass
class PitchSimulator:
synchronizer: TimeSynchronizer
note_list: list[Note]
portamento: PortamentoPitch
pitch_tags: list[tuple[float, int]] = dataclasses.field(default_factory=list)
note_list: dataclasses.InitVar[list[Note]]
interval_dict: PiecewiseIntervalDict = dataclasses.field(default_factory=PiecewiseIntervalDict)

def __post_init__(self) -> None:
if len(self.note_list) == 0:
def __post_init__(self, note_list: list[Note]) -> None:
if not note_list:
return
max_portamento_time = self.portamento.max_inter_time_in_secs
max_portamento_percent = self.portamento.max_inter_time_percent

current_note = self.note_list[0]
current_note = note_list[0]
current_head = self.synchronizer.get_actual_secs_from_ticks(current_note.start_pos)
current_dur = self.synchronizer.get_duration_secs_from_ticks(
current_note.start_pos, current_note.end_pos
)
current_portamento = min(current_dur * max_portamento_percent, max_portamento_time)

self.pitch_tags.append((current_head, current_note.key_number))
for i in range(len(self.note_list) - 1):
next_note = self.note_list[i + 1]
self.interval_dict[portion.closedopen(0.0, current_head)] = current_note.key_number
prev_portamento_end = current_head
for next_note in note_list[1:]:
next_head = self.synchronizer.get_actual_secs_from_ticks(next_note.start_pos)
next_dur = self.synchronizer.get_duration_secs_from_ticks(
next_note.start_pos, next_note.end_pos
)
next_portamento = min(next_dur * max_portamento_percent, max_portamento_time)
interval = next_head - current_head - current_dur
if interval <= 2 * max_portamento_time:
self.pitch_tags.append(
(
next_head - interval / 2 - current_portamento,
current_note.key_number,
)
)
self.pitch_tags.append(
(
next_head - interval / 2 + next_portamento,
next_note.key_number,
)
)
current_portamento_start = next_head - interval / 2 - current_portamento
current_portamento_end = next_head - interval / 2 + next_portamento
else:
self.pitch_tags.append(
(
next_head - interval / 2 - max_portamento_time,
current_note.key_number,
)
)
self.pitch_tags.append(
(
next_head - interval / 2 + max_portamento_time,
next_note.key_number,
)
)
current_portamento_start = next_head - interval / 2 - max_portamento_time
current_portamento_end = next_head - interval / 2 + max_portamento_time
self.interval_dict[
portion.closedopen(prev_portamento_end, current_portamento_start)
] = current_note.key_number
self.interval_dict[
portion.closedopen(current_portamento_start, current_portamento_end)
] = functools.partial( # type: ignore[call-arg]
self.portamento.inter_func,
start=(current_portamento_start, current_note.key_number),
end=(current_portamento_end, next_note.key_number),
)
current_note = next_note
current_head = next_head
current_dur = next_dur
current_portamento = next_portamento
self.pitch_tags.append((current_head + current_dur, current_note.key_number))
prev_portamento_end = current_portamento_end
self.interval_dict[portion.closedopen(prev_portamento_end, portion.inf)] = (
current_note.key_number
)

def pitch_at_ticks(self, ticks: int) -> float:
return self.pitch_at_secs(self.synchronizer.get_actual_ticks_from_ticks(ticks))

def pitch_at_secs(self, secs: float) -> float:
index = find_last_index(self.pitch_tags, lambda tag: tag[0] <= secs)
if index == -1:
value = self.pitch_tags[0][1]
elif index == len(self.pitch_tags) - 1:
value = self.pitch_tags[-1][1]
elif self.pitch_tags[index][1] == self.pitch_tags[index + 1][1]:
value = self.pitch_tags[index][1]
else:
value = self.portamento.inter_func( # type: ignore[assignment]
secs, self.pitch_tags[index], self.pitch_tags[index + 1]
)
return value * 100
return self.interval_dict.get(secs, 0.0) * 100

0 comments on commit 085a201

Please sign in to comment.