diff --git a/audiotools/core/audio_signal.py b/audiotools/core/audio_signal.py index fb6d751..7adb822 100644 --- a/audiotools/core/audio_signal.py +++ b/audiotools/core/audio_signal.py @@ -17,14 +17,12 @@ from . import util from .display import DisplayMixin from .dsp import DSPMixin -from .effects import EffectMixin -from .effects import ImpulseResponseMixin +from .effects import EffectMixin, ImpulseResponseMixin from .ffmpeg import FFMPEGMixin from .loudness import LoudnessMixin from .playback import PlayMixin from .whisper import WhisperMixin - STFTParams = namedtuple( "STFTParams", ["window_length", "hop_length", "window_type", "match_stride", "padding_type"], @@ -88,7 +86,8 @@ class AudioSignal( duration : float, optional Duration in seconds to read from file, by default None device : str, optional - Device to load audio onto, by default None + Device to load audio onto. The default for files in CPU, and default for + tensors and numpy arrays is the same device that they are on. Examples -------- @@ -155,6 +154,8 @@ def __init__( audio_path, offset=offset, duration=duration, device=device ) elif audio_array is not None: + if device is None: + device = audio_array.device assert sample_rate is not None, "Must set sample rate!" self.load_from_array(audio_array, sample_rate, device=device)