-
Notifications
You must be signed in to change notification settings - Fork 7
/
Copy pathsegment.py
68 lines (54 loc) · 1.97 KB
/
segment.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
import argparse
import logging
from pathlib import Path
from concurrent.futures import ProcessPoolExecutor
from tqdm import tqdm
import torch
import numpy as np
import itertools
# Logger for this module; its records propagate to the root logger.
logger = logging.getLogger(__name__)
# Configure the root logger at import time so INFO-level progress messages
# are printed. (basicConfig is a no-op if the root logger already has handlers.)
logging.basicConfig(level=logging.INFO)
def segment_file(segmenter, in_path, out_path):
    """Segment one log-probability file and save the result as ``.npz``.

    Args:
        segmenter: Callable that takes the array loaded from ``in_path`` and
            returns a ``(segments, boundaries)`` pair.
        in_path: Path to a ``.npy`` file, loaded with ``np.load``.
        out_path: Destination path; its suffix is replaced with ``.npz`` and
            ``segments`` / ``boundaries`` are stored under those keys.

    Returns:
        ``(num_frames, mean_gap)``. ``num_frames`` is the length of the first
        axis of the loaded array. ``mean_gap`` is the mean difference between
        consecutive ``boundaries`` entries, or NaN when fewer than two
        boundaries were produced.
    """
    log_probs = np.load(in_path)
    segments, boundaries = segmenter(log_probs)
    np.savez(out_path.with_suffix(".npz"), segments=segments, boundaries=boundaries)

    gaps = np.diff(boundaries)
    # np.mean on an empty array returns NaN *and* emits a "Mean of empty slice"
    # RuntimeWarning; return NaN explicitly so degenerate files stay quiet.
    mean_gap = gaps.mean() if gaps.size else float("nan")
    return log_probs.shape[0], mean_gap
def segment_dataset(args, *, max_workers=4):
    """Segment every ``.npy`` file under ``args.in_dir`` into ``args.out_dir``.

    The relative layout below ``args.in_dir`` is mirrored under
    ``args.out_dir``, and each file is processed by ``segment_file`` in a
    process pool. Summary statistics are logged at the end.

    Args:
        args: Object with ``in_dir`` and ``out_dir`` ``Path`` attributes
            (e.g. the parsed CLI namespace).
        max_workers: Number of worker processes (default 4, the previously
            hard-coded value).
    """
    in_paths = list(args.in_dir.rglob("*.npy"))
    if not in_paths:
        # Without this guard the run ends in an opaque unpacking error at
        # ``zip(*results)``; also skip the checkpoint download for no work.
        logger.warning(f"No .npy files found under {args.in_dir}")
        return
    out_paths = [args.out_dir / path.relative_to(args.in_dir) for path in in_paths]

    logger.info("Loading segmenter checkpoint")
    segmenter = torch.hub.load("bshall/urhythmic:main", "segmenter", num_clusters=3)

    logger.info("Setting up folder structure")
    # Many files share a parent directory; create each distinct one only once.
    for parent in tqdm(sorted({path.parent for path in out_paths})):
        parent.mkdir(exist_ok=True, parents=True)

    logger.info("Segmenting dataset")
    with ProcessPoolExecutor(max_workers=max_workers) as executor:
        results = list(
            tqdm(
                executor.map(
                    segment_file,
                    itertools.repeat(segmenter),
                    in_paths,
                    out_paths,
                ),
                total=len(in_paths),
            )
        )

    frames, boundary_length = zip(*results)
    # Each frame is counted as 0.02 s (the constant the original used).
    frame_seconds = 0.02
    logger.info(f"Segmented {sum(frames) * frame_seconds / 60 / 60:.2f} hours of audio")

    # Files with fewer than two boundaries report NaN; exclude them so one
    # degenerate file does not turn the whole average into NaN.
    # NOTE: this is an unweighted mean of per-file means.
    lengths = np.asarray(boundary_length, dtype=float)
    lengths = lengths[~np.isnan(lengths)]
    mean_length = lengths.mean() if lengths.size else float("nan")
    logger.info(f"Average segment length: {mean_length * frame_seconds:.4f} seconds")
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Segment an audio dataset.")
    parser.add_argument(
        "in_dir",
        metavar="in-dir",
        type=Path,
        help="path to the log probability directory.",
    )
    parser.add_argument(
        "out_dir", metavar="out-dir", type=Path, help="path to the output directory."
    )
    args = parser.parse_args()
    # Path.rglob on a missing directory yields no files rather than raising,
    # so a typo in in-dir would otherwise surface as a confusing later failure.
    if not args.in_dir.is_dir():
        parser.error(f"in-dir {args.in_dir} is not a directory")
    segment_dataset(args)