dependencies = ["torch", "torchaudio", "numpy", "scipy", "numba", "sklearn"]

URLS = {
    "segmenter-3": "https://github.com/bshall/urhythmic/releases/download/v0.1/segmenter-3-61beaeac.pt",
    "segmenter-8": "https://github.com/bshall/urhythmic/releases/download/v0.1/segmenter-8-b3d14f93.pt",
    "rhythm-model-fine-grained": "https://github.com/bshall/urhythmic/releases/download/v0.1/rhythm-fine-143621e1.pt",
    "rhythm-model-global": "https://github.com/bshall/urhythmic/releases/download/v0.1/rhythm-global-745d52d8.pt",
    "hifigan-p228": "https://github.com/bshall/urhythmic/releases/download/v0.1/hifigan-p228-4ab1748f.pt",
    "hifigan-p268": "https://github.com/bshall/urhythmic/releases/download/v0.1/hifigan-p268-36a1d51a.pt",
    "hifigan-p225": "https://github.com/bshall/urhythmic/releases/download/v0.1/hifigan-p225-cc447edc.pt",
    "hifigan-p232": "https://github.com/bshall/urhythmic/releases/download/v0.1/hifigan-p232-e0efc4c3.pt",
    "hifigan-p257": "https://github.com/bshall/urhythmic/releases/download/v0.1/hifigan-p257-06fd495b.pt",
    "hifigan-p231": "https://github.com/bshall/urhythmic/releases/download/v0.1/hifigan-p231-250198a1.pt",
    "hifigan-LJSpeech": "https://github.com/bshall/urhythmic/releases/download/v0.1/hifigan-LJSpeech-ceb1368d.pt",
}

SPEAKERS = {"p228", "p268", "p225", "p232", "p257", "p231", "LJSpeech"}

from typing import Tuple, Callable

import torch
from torch.nn.modules.utils import consume_prefix_in_state_dict_if_present

from urhythmic.model import UrhythmicFine, UrhythmicGlobal, encode
from urhythmic.segmenter import Segmenter
from urhythmic.rhythm import RhythmModelFineGrained, RhythmModelGlobal
from urhythmic.stretcher import TimeStretcherFineGrained, TimeStretcherGlobal
from urhythmic.vocoder import HifiganGenerator, HifiganDiscriminator


def segmenter(
    num_clusters: int,
    gamma: float = 2,
    pretrained: bool = True,
    progress=True,
) -> Segmenter:
    """Segmentation and clustering block. Groups similar speech units into short segments.

    The segments are then combined into coarser groups approximating sonorants, obstruents, and silences.

    Args:
        num_clusters (int): number of clusters used for agglomerative clustering.
        gamma (float): regularizer weight encouraging longer segments.
        pretrained (bool): load pretrained weights into the model.
        progress (bool): show progress bar when downloading model.

    Returns:
        Segmenter: the segmentation and clustering block (optionally pretrained).
    """
    segmenter = Segmenter(num_clusters=num_clusters, gamma=gamma)
    if pretrained:
        checkpoint = torch.hub.load_state_dict_from_url(
            URLS[f"segmenter-{num_clusters}"],
            progress=progress,
        )
        segmenter.load_state_dict(checkpoint)
    return segmenter
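
# Usage sketch: load the pretrained 3-cluster segmenter, either by calling the
# entry point directly or through torch.hub. The "bshall/urhythmic:main" repo
# string is an assumption inferred from the release URLs above.
#
#   seg = segmenter(num_clusters=3, gamma=2, pretrained=True)
#   seg = torch.hub.load("bshall/urhythmic:main", "segmenter", num_clusters=3)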


def rhythm_model_fine_grained(
    source_speaker: None | str,
    target_speaker: None | str,
    pretrained: bool = True,
    progress=True,
) -> RhythmModelFineGrained:
    """Rhythm modeling block (Fine-Grained). Estimates the duration distribution of each sound type.

    Available speakers:
        VCTK: p228, p268, p225, p232, p257, p231.
        LJSpeech.

    Args:
        source_speaker (None | str): the source speaker. None to fit your own source speaker, or one of the available speakers.
        target_speaker (None | str): the target speaker. None to fit your own target speaker, or one of the available speakers.
        pretrained (bool): load pretrained weights into the model.
        progress (bool): show progress bar when downloading model.

    Returns:
        RhythmModelFineGrained: the fine-grained rhythm modeling block (optionally preloaded with source and target duration models).
    """
    if source_speaker is not None and source_speaker not in SPEAKERS:
        raise ValueError(f"source speaker is not in available set: {SPEAKERS}")
    if target_speaker is not None and target_speaker not in SPEAKERS:
        raise ValueError(f"target speaker is not in available set: {SPEAKERS}")
    rhythm_model = RhythmModelFineGrained()
    if pretrained:
        checkpoint = torch.hub.load_state_dict_from_url(
            URLS["rhythm-model-fine-grained"],
            progress=progress,
        )
        state_dict = {}
        if target_speaker:
            state_dict["target"] = checkpoint[target_speaker]
        if source_speaker:
            state_dict["source"] = checkpoint[source_speaker]
        rhythm_model.load_state_dict(state_dict)
    return rhythm_model
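
# Usage sketch: load the fine-grained rhythm model with a pretrained duration
# model for a VCTK target speaker while leaving the source side unset (pass
# source_speaker=None when you intend to fit your own source duration model).
#
#   rhythm = rhythm_model_fine_grained(source_speaker=None, target_speaker="p228")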


def rhythm_model_global(
    source_speaker: None | str,
    target_speaker: None | str,
    pretrained: bool = True,
    progress=True,
) -> RhythmModelGlobal:
    """Rhythm modeling block (Global). Estimates speaking rate.

    Available speakers:
        VCTK: p228, p268, p225, p232, p257, p231.
        LJSpeech.

    Args:
        source_speaker (None | str): the source speaker. None to fit your own source speaker, or one of the available speakers.
        target_speaker (None | str): the target speaker. None to fit your own target speaker, or one of the available speakers.
        pretrained (bool): load pretrained weights into the model.
        progress (bool): show progress bar when downloading model.

    Returns:
        RhythmModelGlobal: the global rhythm modeling block (optionally preloaded with source and target speaking rates).
    """
    if source_speaker is not None and source_speaker not in SPEAKERS:
        raise ValueError(f"source speaker is not in available set: {SPEAKERS}")
    if target_speaker is not None and target_speaker not in SPEAKERS:
        raise ValueError(f"target speaker is not in available set: {SPEAKERS}")
    rhythm_model = RhythmModelGlobal()
    if pretrained:
        checkpoint = torch.hub.load_state_dict_from_url(
            URLS["rhythm-model-global"],
            progress=progress,
        )
        state_dict = {}
        if target_speaker:
            state_dict["target_rate"] = checkpoint[target_speaker]
        if source_speaker:
            state_dict["source_rate"] = checkpoint[source_speaker]
        rhythm_model.load_state_dict(state_dict)
    return rhythm_model
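
# Usage sketch: load the global rhythm model with pretrained speaking rates for
# both a source and a target speaker from the available set.
#
#   rhythm = rhythm_model_global(source_speaker="p231", target_speaker="LJSpeech")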


def hifigan_generator(
    speaker: None | str,
    pretrained: bool = True,
    progress: bool = True,
    map_location=None,
) -> HifiganGenerator:
    """HifiGAN Generator.

    Available speakers:
        VCTK: p228, p268, p225, p232, p257, p231.
        LJSpeech.

    Args:
        speaker (None | str): the target speaker. None to train your own speaker, or one of the available speakers.
        pretrained (bool): load pretrained weights into the model.
        progress (bool): show progress bar when downloading model.
        map_location: a function or dict specifying how to remap storage locations (see torch.load).

    Returns:
        HifiganGenerator: the HifiGAN Generator (pretrained on LJSpeech or one of the VCTK speakers).
    """
    if speaker is not None and speaker not in SPEAKERS:
        raise ValueError(f"target speaker is not in available set: {SPEAKERS}")
    hifigan = HifiganGenerator()
    if pretrained:
        checkpoint = torch.hub.load_state_dict_from_url(
            URLS[f"hifigan-{speaker}"], map_location=map_location, progress=progress
        )
        consume_prefix_in_state_dict_if_present(
            checkpoint["generator"]["model"], "module."
        )
        hifigan.load_state_dict(checkpoint["generator"]["model"])
        hifigan.eval()
        hifigan.remove_weight_norm()
    return hifigan
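
# Usage sketch: load the HifiGAN generator for a specific speaker and keep the
# weights on CPU via map_location (forwarded to torch.load). Note that the
# checkpoint URLs above are keyed by speaker, so pass pretrained=False when
# speaker is None.
#
#   vocoder = hifigan_generator(speaker="p225", map_location="cpu")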


def hifigan_discriminator(
    pretrained: bool = True, progress: bool = True, map_location=None
) -> HifiganDiscriminator:
    """HifiGAN Discriminator.

    Args:
        pretrained (bool): load pretrained weights into the model.
        progress (bool): show progress bar when downloading model.
        map_location: a function or dict specifying how to remap storage locations (see torch.load).

    Returns:
        HifiganDiscriminator: the HifiGAN Discriminator (pretrained on LJSpeech).
    """
    discriminator = HifiganDiscriminator()
    if pretrained:
        checkpoint = torch.hub.load_state_dict_from_url(
            URLS["hifigan-LJSpeech"], map_location=map_location, progress=progress
        )
        consume_prefix_in_state_dict_if_present(
            checkpoint["discriminator"]["model"], "module."
        )
        discriminator.load_state_dict(checkpoint["discriminator"]["model"])
        discriminator.eval()
    return discriminator
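
# Usage sketch: load the LJSpeech-pretrained discriminator, e.g. as a starting
# point for adversarial fine-tuning of the generator on a new speaker.
#
#   disc = hifigan_discriminator(map_location="cpu")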


def urhythmic_fine(
    source_speaker: str | None,
    target_speaker: str | None,
    pretrained: bool = True,
    progress: bool = True,
    map_location=None,
) -> Tuple[UrhythmicFine, Callable]:
    """Urhythmic (Fine-Grained), a voice and rhythm conversion system that does not require text or parallel data.

    Available speakers:
        VCTK: p228, p268, p225, p232, p257, p231.
        LJSpeech.

    Args:
        source_speaker (None | str): the source speaker. None to fit your own source speaker, or one of the available speakers.
        target_speaker (None | str): the target speaker. None to fit your own target speaker, or one of the available speakers.
        pretrained (bool): load pretrained weights into the model.
        progress (bool): show progress bar when downloading model.
        map_location: a function or dict specifying how to remap storage locations (see torch.load).

    Returns:
        UrhythmicFine: the Fine-Grained Urhythmic model.
        Callable: the encode function to extract soft speech units and log probabilities using HubertSoft.
    """
    seg = segmenter(num_clusters=3, gamma=2, pretrained=pretrained, progress=progress)
    rhythm_model = rhythm_model_fine_grained(
        source_speaker=source_speaker,
        target_speaker=target_speaker,
        pretrained=pretrained,
        progress=progress,
    )
    time_stretcher = TimeStretcherFineGrained()
    vocoder = hifigan_generator(
        speaker=target_speaker,
        pretrained=pretrained,
        progress=progress,
        map_location=map_location,
    )
    return UrhythmicFine(seg, rhythm_model, time_stretcher, vocoder), encode
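
# Usage sketch: load the complete fine-grained system plus the encode helper
# through torch.hub (the "bshall/urhythmic:main" repo string is an assumption
# inferred from the release URLs above).
#
#   model, encode = torch.hub.load(
#       "bshall/urhythmic:main",
#       "urhythmic_fine",
#       source_speaker=None,
#       target_speaker="p228",
#   )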


def urhythmic_global(
    source_speaker: str | None,
    target_speaker: str | None,
    pretrained: bool = True,
    progress: bool = True,
    map_location=None,
) -> Tuple[UrhythmicGlobal, Callable]:
    """Urhythmic (Global), a voice and rhythm conversion system that does not require text or parallel data.

    Available speakers:
        VCTK: p228, p268, p225, p232, p257, p231.
        LJSpeech.

    Args:
        source_speaker (None | str): the source speaker. None to fit your own source speaker, or one of the available speakers.
        target_speaker (None | str): the target speaker. None to fit your own target speaker, or one of the available speakers.
        pretrained (bool): load pretrained weights into the model.
        progress (bool): show progress bar when downloading model.
        map_location: a function or dict specifying how to remap storage locations (see torch.load).

    Returns:
        UrhythmicGlobal: the Global Urhythmic model.
        Callable: the encode function to extract soft speech units and log probabilities using HubertSoft.
    """
    seg = segmenter(num_clusters=3, gamma=2, pretrained=pretrained, progress=progress)
    rhythm_model = rhythm_model_global(
        source_speaker=source_speaker,
        target_speaker=target_speaker,
        pretrained=pretrained,
        progress=progress,
    )
    time_stretcher = TimeStretcherGlobal()
    vocoder = hifigan_generator(
        speaker=target_speaker,
        pretrained=pretrained,
        progress=progress,
        map_location=map_location,
    )
    return UrhythmicGlobal(seg, rhythm_model, time_stretcher, vocoder), encode
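
# End-to-end sketch for the global variant. The exact signatures of encode and
# the model's forward pass live in urhythmic.model, so the last two calls below
# are assumptions: encode is taken to map a HubertSoft content encoder and a
# 16 kHz waveform to soft speech units and log probabilities, and the model to
# map those to the converted waveform. The "bshall/hubert:main" / "hubert_soft"
# hub entry point is likewise assumed from the HubertSoft reference above.
#
#   import torchaudio
#
#   hubert = torch.hub.load("bshall/hubert:main", "hubert_soft")
#   model, encode = torch.hub.load(
#       "bshall/urhythmic:main",
#       "urhythmic_global",
#       source_speaker=None,
#       target_speaker="LJSpeech",
#   )
#   wav, sr = torchaudio.load("path/to/source.wav")
#   wav = torchaudio.functional.resample(wav, sr, 16000).unsqueeze(0)
#   units, log_probs = encode(hubert, wav)
#   converted = model(units, log_probs)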