diff --git a/stream/gtcrn_stream.py b/stream/gtcrn_stream.py index c25a8e6..e9f03bc 100644 --- a/stream/gtcrn_stream.py +++ b/stream/gtcrn_stream.py @@ -371,6 +371,7 @@ def forward(self, spec, conv_cache, tra_cache, inter_cache): x = torch.stft(x, 512, 256, 512, torch.hann_window(512).pow(0.5), return_complex=False)[None] with torch.no_grad(): y = model(x) + y = torch.view_as_complex(y.contiguous()) y = torch.istft(y, 512, 256, 512, torch.hann_window(512).pow(0.5)).detach().cpu().numpy() sf.write('test_wavs/enh.wav', y.squeeze(), 16000) @@ -389,6 +390,7 @@ def forward(self, spec, conv_cache, tra_cache, inter_cache): # times.append((toc-tic)*1000) # ys.append(yi) # ys = torch.cat(ys, dim=2) + # ys = torch.view_as_complex(ys.contiguous()) # ys = torch.istft(ys, 512, 256, 512, torch.hann_window(512).pow(0.5)).detach().cpu().numpy() # sf.write('test_wavs/enh_stream.wav', ys.squeeze(), 16000) @@ -458,3 +460,4 @@ def forward(self, spec, conv_cache, tra_cache, inter_cache): print(">>> inference time: mean: {:.1f}ms, max: {:.1f}ms, min: {:.1f}ms".format(1e3*np.mean(T_list), 1e3*np.max(T_list), 1e3*np.min(T_list))) print(">>> RTF:", 1e3*np.mean(T_list) / 16) +