-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathf0_disp.py
150 lines (122 loc) · 6.81 KB
/
f0_disp.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
import os
import json
import argparse
import numpy as np
import librosa
import librosa.display
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
def parse_time(time_str):
    """Convert a time string to a number of seconds.

    Accepts either "mm:ss" (e.g. "1:30" -> 90) or, as a generalization, a
    bare seconds value (e.g. "90" -> 90).  Intended for use as an argparse
    ``type=`` callable.

    Args:
        time_str: The time string to parse, or None.

    Returns:
        Total seconds as an int, or None if time_str is None (so optional
        CLI arguments can pass through their default).

    Raises:
        argparse.ArgumentTypeError: If the string is not a valid time
            (non-numeric, negative, seconds field >= 60, or too many ":").
    """
    if time_str is None:
        return None
    try:
        parts = time_str.split(":")
        if len(parts) == 1:
            # Generalization: allow a plain seconds value with no colon.
            seconds = int(parts[0])
            if seconds < 0:
                raise ValueError
            return seconds
        if len(parts) == 2:
            minutes, seconds = (int(p) for p in parts)
            # Reject malformed clock times such as "1:75" or negatives,
            # which the naive arithmetic would silently accept.
            if minutes < 0 or not 0 <= seconds < 60:
                raise ValueError
            return minutes * 60 + seconds
        raise ValueError
    except ValueError:
        # `from None` keeps argparse's error output free of the internal
        # ValueError traceback context.
        raise argparse.ArgumentTypeError(
            f"Invalid time format: {time_str}. Use mm:ss."
        ) from None
def load_json(json_file):
    """Load a JSON file containing f0 analysis results.

    Args:
        json_file: Path to the JSON file produced by the F0 analysis step.

    Returns:
        The parsed JSON content (a dict for the expected analysis files).

    Raises:
        FileNotFoundError: If json_file does not exist.
        json.JSONDecodeError: If the file is not valid JSON.
    """
    # JSON is defined as UTF-8; specify the encoding explicitly so the
    # result does not depend on the platform's default locale encoding
    # (e.g. cp1252 on Windows).
    with open(json_file, "r", encoding="utf-8") as f:
        return json.load(f)
def format_time(x, _):
    """Format an axis tick value (seconds) as an ``mm:ss`` string.

    Matplotlib ``FuncFormatter`` callback: the second positional argument
    is the tick position, which is unused here.

    Args:
        x: Tick value in seconds.
        _: Tick position (ignored).

    Returns:
        The tick label, e.g. ``"2:05"`` for 125 seconds.
    """
    mins, secs = int(x // 60), int(x % 60)
    return f"{mins}:{secs:02d}"
def plot_audio_and_f0(wav_file, start_time=None, clip_length=None, f0_cutoff=1000):
    """Plot waveform and spectrogram with F0 markings for each channel and the mono mix.

    Loads ``<wav_file basename>_f0.json`` (produced by a prior F0 analysis
    run), then renders, for every audio channel plus a mono mixdown: the
    waveform and a mel spectrogram, overlaying detected F0 values that
    exceed ``f0_cutoff``.

    Args:
        wav_file: Path to the input WAV file.
        start_time: Plot window start in seconds (default: start of file).
        clip_length: Plot window length in seconds (default: to end of file).
        f0_cutoff: Only F0 values strictly above this frequency (Hz) are
            marked (default 1000).

    Returns:
        None. Side effects: shows a matplotlib figure; prints an error and
        returns early if the companion JSON file is missing.
    """
    # Load corresponding JSON file (same basename, "_f0.json" suffix).
    json_file = os.path.splitext(wav_file)[0] + "_f0.json"
    if not os.path.exists(json_file):
        print(f"Error: JSON file {json_file} not found. Run the F0 analysis first.")
        return
    data = load_json(json_file)
    # Load WAV file at its native sample rate, preserving channels.
    audio_signal, sample_rate = librosa.load(wav_file, sr=None, mono=False)
    # Ensure multi-channel format: librosa returns 1-D for mono input.
    if audio_signal.ndim == 1:
        audio_signal = np.expand_dims(audio_signal, axis=0)  # Convert mono to shape (1, N)
    num_channels = audio_signal.shape[0]
    total_duration = len(audio_signal[0]) / sample_rate
    # Default values for start_time and clip_length
    if start_time is None:
        start_time = 0
    if clip_length is None:
        clip_length = total_duration
    # Ensure start_time + clip_length does not exceed audio duration
    end_time = min(start_time + clip_length, total_duration)
    # Extract data from JSON.
    # NOTE(review): assumes the JSON has keys "f0_time_steps",
    # "f0_values" (per-channel, string-keyed), and "f0_values_mono",
    # all aligned to the same time grid — confirm against the analysis
    # script that writes the file.
    f0_time_steps = np.array(data["f0_time_steps"])
    f0_values = {int(k): np.array(v) for k, v in data.get("f0_values", {}).items()}
    f0_values_mono = np.array(data["f0_values_mono"])
    # Compute time vectors (one timestamp per audio sample).
    time_audio = np.linspace(0, total_duration, len(audio_signal[0]))
    # Define time window for plotting (sample indices for spectrograms).
    start_sample = int(start_time * sample_rate)
    end_sample = int(end_time * sample_rate)
    # Filter data strictly within start_time and end_time
    mask_audio = (time_audio >= start_time) & (time_audio <= end_time)
    mask_f0 = (f0_time_steps >= start_time) & (f0_time_steps <= end_time)
    time_audio = time_audio[mask_audio]  # Apply mask
    # Determine unified y-axis limits for all signal waveforms
    signal_min = np.min(audio_signal[:, mask_audio]) if num_channels > 1 else np.min(audio_signal[mask_audio])
    signal_max = np.max(audio_signal[:, mask_audio]) if num_channels > 1 else np.max(audio_signal[mask_audio])
    # Create plots: one waveform & spectrogram per channel + one pair for mono.
    # NOTE(review): figsize height scales only with num_channels and does not
    # account for the two extra mono rows — confirm this is intentional.
    fig, axes = plt.subplots(num_channels * 2 + 2, 1, figsize=(14, num_channels * 4), sharex=True, constrained_layout=True)
    for i in range(num_channels):
        # Plot waveform for channel i on the even-indexed axis.
        axes[i * 2].plot(time_audio, audio_signal[i][mask_audio], color="b", alpha=0.7)
        axes[i * 2].set_ylabel(f"Ch {i} Signal")
        axes[i * 2].set_ylim(signal_min, signal_max)  # Standardize y-axis for all signals
        axes[i * 2].grid(True, linestyle='--', alpha=0.5)
        axes[i * 2].set_xlim(start_time, end_time)  # Ensure x-axis limits match selection
        # Overlay detected F0 values as black vertical markers on waveforms (if above cutoff).
        # NOTE(review): f0_values.get(i, []) falls back to a plain list, which
        # cannot be indexed by the boolean array mask_f0 — a missing channel
        # key would raise TypeError here; verify channel keys always exist.
        for t, f0 in zip(f0_time_steps[mask_f0], f0_values.get(i, [])[mask_f0]):
            if f0 > f0_cutoff:
                axes[i * 2].axvspan(t, t + 0.1, color="black", alpha=0.3)  # 100ms width
        # Compute spectrogram and set correct time coordinates
        S = librosa.feature.melspectrogram(y=audio_signal[i][start_sample:end_sample], sr=sample_rate, n_mels=128, fmax=8000)
        S_dB = librosa.power_to_db(S, ref=np.max)
        librosa.display.specshow(S_dB, sr=sample_rate, x_axis='time',
                                 x_coords=np.linspace(start_time, end_time, S.shape[1]),
                                 y_axis='mel', ax=axes[i * 2 + 1], cmap='magma')
        axes[i * 2 + 1].set_ylabel(f"Ch {i} Spec")
        axes[i * 2 + 1].set_xlim(start_time, end_time)
        # Overlay detected F0 values on spectrogram (if above cutoff)
        for t, f0 in zip(f0_time_steps[mask_f0], f0_values.get(i, [])[mask_f0]):
            if f0 > f0_cutoff:
                axes[i * 2 + 1].axvspan(t, t + 0.1, color="white", alpha=0.3)
                axes[i * 2 + 1].scatter(t, f0, color='cyan', s=20, edgecolors='black')  # Mark exact F0 location
    # Compute and plot mono waveform (simple mean across channels).
    mono_signal = np.mean(audio_signal, axis=0)
    axes[-2].plot(time_audio, mono_signal[mask_audio], color="g", alpha=0.7)
    axes[-2].set_ylabel("Mono Signal")
    axes[-2].set_ylim(signal_min, signal_max)  # Standardize y-axis
    axes[-2].grid(True, linestyle='--', alpha=0.5)
    axes[-2].set_xlim(start_time, end_time)
    # Compute and plot mono spectrogram
    S_mono = librosa.feature.melspectrogram(y=mono_signal[start_sample:end_sample], sr=sample_rate, n_mels=128, fmax=8000)
    S_mono_dB = librosa.power_to_db(S_mono, ref=np.max)
    librosa.display.specshow(S_mono_dB, sr=sample_rate, x_axis='time',
                             x_coords=np.linspace(start_time, end_time, S_mono.shape[1]),
                             y_axis='mel', ax=axes[-1], cmap='magma')
    axes[-1].set_ylabel("Mono Spec")
    axes[-1].set_xlim(start_time, end_time)
    # Overlay detected F0 values on mono spectrogram (if above cutoff)
    for t, f0 in zip(f0_time_steps[mask_f0], f0_values_mono[mask_f0]):
        if f0 > f0_cutoff:
            axes[-1].axvspan(t, t + 0.1, color="white", alpha=0.3)
            axes[-1].scatter(t, f0, color='cyan', s=20, edgecolors='black')  # Mark exact F0 location
    # Shared x-axis: label ticks as mm:ss on the bottom axis only.
    axes[-1].set_xlabel("Time (mm:ss)")
    axes[-1].xaxis.set_major_formatter(ticker.FuncFormatter(format_time))
    plt.xticks(rotation=45)  # Rotate tick labels for readability
    plt.show()
if __name__ == "__main__":
    # Command-line entry point: build the argument parser and hand the
    # parsed options to the plotting routine.
    arg_parser = argparse.ArgumentParser(description="Plot F0 values for each channel and the mono mix.")
    arg_parser.add_argument("wav_file", type=str, help="Path to the input WAV file")
    arg_parser.add_argument("-st", "--start_time", type=parse_time, default=None, help="Start time in mm:ss format (default: start of file)")
    arg_parser.add_argument("-t", "--clip_length", type=parse_time, default=None, help="Duration of the clip in mm:ss format (default: full file length)")
    arg_parser.add_argument("-c", "--f0_cutoff", type=int, default=1000, help="Minimum F0 frequency to display (default: 1000 Hz)")
    cli_args = arg_parser.parse_args()
    plot_audio_and_f0(
        cli_args.wav_file,
        start_time=cli_args.start_time,
        clip_length=cli_args.clip_length,
        f0_cutoff=cli_args.f0_cutoff,
    )