From daca68b99486cf19477d4b2bdce1f6f65e6b69a8 Mon Sep 17 00:00:00 2001
From: Robert Dargavel Smith
Date: Mon, 15 Jul 2024 15:39:39 +0100
Subject: [PATCH] update formatting

---
 app.py                                     |  28 ++---
 audiodiffusion/mel.py                      |   1 -
 audiodiffusion/pipeline_audio_diffusion.py |   1 +
 scripts/train_unet.py                      | 127 +++++++++------
 streamlit_app.py                           |  31 ++---
 5 files changed, 85 insertions(+), 103 deletions(-)

diff --git a/app.py b/app.py
index 772e86b..2b73505 100644
--- a/app.py
+++ b/app.py
@@ -7,8 +7,7 @@
 
 def generate_spectrogram_audio_and_loop(model_id):
     audio_diffusion = AudioDiffusion(model_id=model_id)
-    image, (sample_rate,
-            audio) = audio_diffusion.generate_spectrogram_and_audio()
+    image, (sample_rate, audio) = audio_diffusion.generate_spectrogram_and_audio()
     loop = AudioDiffusion.loop_it(audio, sample_rate)
     if loop is None:
         loop = audio
@@ -24,23 +23,26 @@ def generate_spectrogram_audio_and_loop(model_id):
     [colab](https://colab.research.google.com/github/teticio/audio-diffusion/blob/master/notebooks/gradio_app.ipynb) \
     to run this app.",
     inputs=[
-        gr.Dropdown(label="Model",
-                    choices=[
-                        "teticio/audio-diffusion-256",
-                        "teticio/audio-diffusion-breaks-256",
-                        "teticio/audio-diffusion-instrumental-hiphop-256",
-                        "teticio/audio-diffusion-ddim-256",
-                        "teticio/latent-audio-diffusion-256",
-                        "teticio/latent-audio-diffusion-ddim-256"
-                    ],
-                    value="teticio/latent-audio-diffusion-ddim-256")
+        gr.Dropdown(
+            label="Model",
+            choices=[
+                "teticio/audio-diffusion-256",
+                "teticio/audio-diffusion-breaks-256",
+                "teticio/audio-diffusion-instrumental-hiphop-256",
+                "teticio/audio-diffusion-ddim-256",
+                "teticio/latent-audio-diffusion-256",
+                "teticio/latent-audio-diffusion-ddim-256",
+            ],
+            value="teticio/latent-audio-diffusion-ddim-256",
+        )
     ],
     outputs=[
         gr.Image(label="Mel spectrogram", image_mode="L"),
         gr.Audio(label="Audio"),
         gr.Audio(label="Loop"),
     ],
-    allow_flagging="never")
+    allow_flagging="never",
+)
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
diff --git a/audiodiffusion/mel.py b/audiodiffusion/mel.py
index 78f8e7b..929fb69 100644
--- a/audiodiffusion/mel.py
+++ b/audiodiffusion/mel.py
@@ -26,7 +26,6 @@
 
 import numpy as np  # noqa: E402
 
-
 try:
     import librosa  # noqa: E402
 
diff --git a/audiodiffusion/pipeline_audio_diffusion.py b/audiodiffusion/pipeline_audio_diffusion.py
index dd1ed82..c954b82 100644
--- a/audiodiffusion/pipeline_audio_diffusion.py
+++ b/audiodiffusion/pipeline_audio_diffusion.py
@@ -35,6 +35,7 @@
 
 from .mel import Mel
 
+
 class AudioDiffusionPipeline(DiffusionPipeline):
     """
     This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
diff --git a/scripts/train_unet.py b/scripts/train_unet.py
index a29a7b2..319add7 100644
--- a/scripts/train_unet.py
+++ b/scripts/train_unet.py
@@ -13,8 +13,7 @@
 from accelerate import Accelerator
 from accelerate.logging import get_logger
 from datasets import load_dataset, load_from_disk
-from diffusers import (AutoencoderKL, DDIMScheduler, DDPMScheduler,
-                       UNet2DConditionModel, UNet2DModel)
+from diffusers import AutoencoderKL, DDIMScheduler, DDPMScheduler, UNet2DConditionModel, UNet2DModel
 from diffusers.optimization import get_scheduler
 from diffusers.pipelines.audio_diffusion import Mel
 from diffusers.training_utils import EMAModel
@@ -28,9 +27,7 @@
 logger = get_logger(__name__)
 
 
-def get_full_repo_name(model_id: str,
-                       organization: Optional[str] = None,
-                       token: Optional[str] = None):
+def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None):
     if token is None:
         token = HfFolder.get_token()
     if organization is None:
@@ -52,9 +49,7 @@ def main(args):
 
     if args.dataset_name is not None:
         if os.path.exists(args.dataset_name):
-            dataset = load_from_disk(
-                args.dataset_name,
-                storage_options=args.dataset_config_name)["train"]
+            dataset = load_from_disk(args.dataset_name, storage_options=args.dataset_config_name)["train"]
         else:
             dataset = load_dataset(
                 args.dataset_name,
@@ -73,17 +68,16 @@ def main(args):
     # Determine image resolution
     resolution = dataset[0]["image"].height, dataset[0]["image"].width
 
-    augmentations = Compose([
-        ToTensor(),
-        Normalize([0.5], [0.5]),
-    ])
+    augmentations = Compose(
+        [
+            ToTensor(),
+            Normalize([0.5], [0.5]),
+        ]
+    )
 
     def transforms(examples):
         if args.vae is not None and vqvae.config["in_channels"] == 3:
-            images = [
-                augmentations(image.convert("RGB"))
-                for image in examples["image"]
-            ]
+            images = [augmentations(image.convert("RGB")) for image in examples["image"]]
         else:
             images = [augmentations(image) for image in examples["image"]]
         if args.encodings is not None:
@@ -92,8 +86,7 @@ def transforms(examples):
         return {"input": images}
 
     dataset.set_transform(transforms)
-    train_dataloader = torch.utils.data.DataLoader(
-        dataset, batch_size=args.train_batch_size, shuffle=True)
+    train_dataloader = torch.utils.data.DataLoader(dataset, batch_size=args.train_batch_size, shuffle=True)
 
     if args.encodings is not None:
         encodings = pickle.load(open(args.encodings, "rb"))
@@ -106,9 +99,7 @@ def transforms(examples):
         vqvae = AudioDiffusionPipeline.from_pretrained(args.vae).vqvae
         # Determine latent resolution
         with torch.no_grad():
-            latent_resolution = vqvae.encode(
-                torch.zeros((1, 1) +
-                            resolution)).latent_dist.sample().shape[2:]
+            latent_resolution = vqvae.encode(torch.zeros((1, 1) + resolution)).latent_dist.sample().shape[2:]
 
     if args.from_pretrained is not None:
         pipeline = AudioDiffusionPipeline.from_pretrained(args.from_pretrained)
@@ -121,10 +112,8 @@ def transforms(examples):
         if args.encodings is None:
             model = UNet2DModel(
                 sample_size=resolution if vqvae is None else latent_resolution,
-                in_channels=1
-                if vqvae is None else vqvae.config["latent_channels"],
-                out_channels=1
-                if vqvae is None else vqvae.config["latent_channels"],
+                in_channels=1 if vqvae is None else vqvae.config["latent_channels"],
+                out_channels=1 if vqvae is None else vqvae.config["latent_channels"],
                 layers_per_block=2,
                 block_out_channels=(128, 128, 256, 256, 512, 512),
                 down_block_types=(
@@ -148,10 +137,8 @@ def transforms(examples):
         else:
             model = UNet2DConditionModel(
                 sample_size=resolution if vqvae is None else latent_resolution,
-                in_channels=1
-                if vqvae is None else vqvae.config["latent_channels"],
-                out_channels=1
-                if vqvae is None else vqvae.config["latent_channels"],
+                in_channels=1 if vqvae is None else vqvae.config["latent_channels"],
+                out_channels=1 if vqvae is None else vqvae.config["latent_channels"],
                 layers_per_block=2,
                 block_out_channels=(128, 256, 512, 512),
                 down_block_types=(
@@ -170,11 +157,9 @@ def transforms(examples):
             )
 
     if args.scheduler == "ddpm":
-        noise_scheduler = DDPMScheduler(
-            num_train_timesteps=args.num_train_steps)
+        noise_scheduler = DDPMScheduler(num_train_timesteps=args.num_train_steps)
     else:
-        noise_scheduler = DDIMScheduler(
-            num_train_timesteps=args.num_train_steps)
+        noise_scheduler = DDIMScheduler(num_train_timesteps=args.num_train_steps)
 
     optimizer = torch.optim.AdamW(
         model.parameters(),
@@ -188,12 +173,12 @@ def transforms(examples):
         args.lr_scheduler,
         optimizer=optimizer,
         num_warmup_steps=args.lr_warmup_steps,
-        num_training_steps=(len(train_dataloader) * args.num_epochs) //
-        args.gradient_accumulation_steps,
+        num_training_steps=(len(train_dataloader) * args.num_epochs) // args.gradient_accumulation_steps,
     )
 
     model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
-        model, optimizer, train_dataloader, lr_scheduler)
+        model, optimizer, train_dataloader, lr_scheduler
+    )
 
     ema_model = EMAModel(
         getattr(model, "module", model),
@@ -204,8 +189,7 @@ def transforms(examples):
 
     if args.push_to_hub:
         if args.hub_model_id is None:
-            repo_name = get_full_repo_name(Path(output_dir).name,
-                                           token=args.hub_token)
+            repo_name = get_full_repo_name(Path(output_dir).name, token=args.hub_token)
         else:
             repo_name = args.hub_model_id
         repo = Repository(output_dir, clone_from=repo_name)
@@ -224,8 +208,7 @@ def transforms(examples):
 
     global_step = 0
     for epoch in range(args.num_epochs):
-        progress_bar = tqdm(total=len(train_dataloader),
-                            disable=not accelerator.is_local_main_process)
+        progress_bar = tqdm(total=len(train_dataloader), disable=not accelerator.is_local_main_process)
         progress_bar.set_description(f"Epoch {epoch}")
 
         if epoch < args.start_epoch:
@@ -245,8 +228,7 @@ def transforms(examples):
             if vqvae is not None:
                 vqvae.to(clean_images.device)
                 with torch.no_grad():
-                    clean_images = vqvae.encode(
-                        clean_images).latent_dist.sample()
+                    clean_images = vqvae.encode(clean_images).latent_dist.sample()
                 # Scale latent images to ensure approximately unit variance
                 clean_images = clean_images * 0.18215
 
@@ -257,20 +239,18 @@ def transforms(examples):
             timesteps = torch.randint(
                 0,
                 noise_scheduler.config.num_train_timesteps,
-                (bsz, ),
+                (bsz,),
                 device=clean_images.device,
             ).long()
 
             # Add noise to the clean images according to the noise magnitude at each timestep
             # (this is the forward diffusion process)
-            noisy_images = noise_scheduler.add_noise(clean_images, noise,
-                                                     timesteps)
+            noisy_images = noise_scheduler.add_noise(clean_images, noise, timesteps)
 
             with accelerator.accumulate(model):
                 # Predict the noise residual
                 if args.encodings is not None:
-                    noise_pred = model(noisy_images, timesteps,
-                                       batch["encoding"])["sample"]
+                    noise_pred = model(noisy_images, timesteps, batch["encoding"])["sample"]
                 else:
                     noise_pred = model(noisy_images, timesteps)["sample"]
                 loss = F.mse_loss(noise_pred, noise)
@@ -302,9 +282,11 @@ def transforms(examples):
 
         # Generate sample images for visual inspection
         if accelerator.is_main_process:
-            if ((epoch + 1) % args.save_model_epochs == 0
-                    or (epoch + 1) % args.save_images_epochs == 0
-                    or epoch == args.num_epochs - 1):
+            if (
+                (epoch + 1) % args.save_model_epochs == 0
+                or (epoch + 1) % args.save_images_epochs == 0
+                or epoch == args.num_epochs - 1
+            ):
                 unet = accelerator.unwrap_model(model)
                 if args.use_ema:
                     ema_model.copy_to(unet.parameters())
@@ -315,9 +297,7 @@ def transforms(examples):
                     scheduler=noise_scheduler,
                 )
 
-                if (
-                        epoch + 1
-                ) % args.save_model_epochs == 0 or epoch == args.num_epochs - 1:
+                if (epoch + 1) % args.save_model_epochs == 0 or epoch == args.num_epochs - 1:
                     pipeline.save_pretrained(output_dir)
 
                     # save the model
@@ -329,15 +309,13 @@ def transforms(examples):
                     )
 
                 if (epoch + 1) % args.save_images_epochs == 0:
-                    generator = torch.Generator(
-                        device=clean_images.device).manual_seed(42)
+                    generator = torch.Generator(device=clean_images.device).manual_seed(42)
 
                     if args.encodings is not None:
                         random.seed(42)
-                        encoding = torch.stack(
-                            random.sample(list(encodings.values()),
-                                          args.eval_batch_size)).to(
-                                              clean_images.device)
+                        encoding = torch.stack(random.sample(list(encodings.values()), args.eval_batch_size)).to(
+                            clean_images.device
+                        )
                     else:
                         encoding = None
 
@@ -350,13 +328,15 @@ def transforms(examples):
                     )
 
                     # denormalize the images and save to tensorboard
-                    images = np.array([
-                        np.frombuffer(image.tobytes(), dtype="uint8").reshape(
-                            (len(image.getbands()), image.height, image.width))
-                        for image in images
-                    ])
-                    accelerator.trackers[0].writer.add_images(
-                        "test_samples", images, epoch)
+                    images = np.array(
+                        [
+                            np.frombuffer(image.tobytes(), dtype="uint8").reshape(
+                                (len(image.getbands()), image.height, image.width)
+                            )
+                            for image in images
+                        ]
+                    )
+                    accelerator.trackers[0].writer.add_images("test_samples", images, epoch)
                     for _, audio in enumerate(audios):
                         accelerator.trackers[0].writer.add_audio(
                             f"test_audio_{_}",
@@ -370,8 +350,7 @@ def transforms(examples):
 
 
 if __name__ == "__main__":
-    parser = argparse.ArgumentParser(
-        description="Simple example of a training script.")
+    parser = argparse.ArgumentParser(description="Simple example of a training script.")
     parser.add_argument("--local_rank", type=int, default=-1)
     parser.add_argument("--dataset_name", type=str, default=None)
     parser.add_argument("--dataset_config_name", type=str, default=None)
@@ -415,7 +394,8 @@ def transforms(examples):
         help=(
             "Whether to use mixed precision. Choose"
             "between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10."
-            "and an Nvidia Ampere GPU."),
+            "and an Nvidia Ampere GPU."
+        ),
     )
     parser.add_argument("--hop_length", type=int, default=512)
     parser.add_argument("--sample_rate", type=int, default=22050)
@@ -423,10 +403,7 @@ def transforms(examples):
     parser.add_argument("--from_pretrained", type=str, default=None)
     parser.add_argument("--start_epoch", type=int, default=0)
     parser.add_argument("--num_train_steps", type=int, default=1000)
-    parser.add_argument("--scheduler",
-                        type=str,
-                        default="ddpm",
-                        help="ddpm or ddim")
+    parser.add_argument("--scheduler", type=str, default="ddpm", help="ddpm or ddim")
     parser.add_argument(
         "--vae",
         type=str,
@@ -446,8 +423,6 @@ def transforms(examples):
     args.local_rank = env_local_rank
 
     if args.dataset_name is None and args.train_data_dir is None:
-        raise ValueError(
-            "You must specify either a dataset name from the hub or a train data directory."
-        )
+        raise ValueError("You must specify either a dataset name from the hub or a train data directory.")
 
     main(args)
diff --git a/streamlit_app.py b/streamlit_app.py
index 1527a60..2132f2e 100644
--- a/streamlit_app.py
+++ b/streamlit_app.py
@@ -1,8 +1,9 @@
 from io import BytesIO
-import streamlit as st
+
 import soundfile as sf
-from librosa.util import normalize
+import streamlit as st
 from librosa.beat import beat_track
+from librosa.util import normalize
 
 from audiodiffusion import AudioDiffusion
 
@@ -11,22 +12,26 @@
     st.markdown(
         "Generate audio using Huggingface diffusers.\
         The models without 'latent' or 'ddim' give better results but take about \
-        20 minutes without a GPU.", )
+        20 minutes without a GPU.",
+    )
 
-    model_id = st.selectbox("Model", [
-        "teticio/audio-diffusion-256", "teticio/audio-diffusion-breaks-256",
-        "teticio/audio-diffusion-instrumental-hiphop-256",
-        "teticio/audio-diffusion-ddim-256",
-        "teticio/latent-audio-diffusion-256",
-        "teticio/latent-audio-diffusion-ddim-256"
-    ],
-                            index=5)
+    model_id = st.selectbox(
+        "Model",
+        [
+            "teticio/audio-diffusion-256",
+            "teticio/audio-diffusion-breaks-256",
+            "teticio/audio-diffusion-instrumental-hiphop-256",
+            "teticio/audio-diffusion-ddim-256",
+            "teticio/latent-audio-diffusion-256",
+            "teticio/latent-audio-diffusion-ddim-256",
+        ],
+        index=5,
+    )
     audio_diffusion = AudioDiffusion(model_id=model_id)
 
     if st.button("Generate"):
         st.markdown("Generating...")
-        image, (sample_rate,
-                audio) = audio_diffusion.generate_spectrogram_and_audio()
+        image, (sample_rate, audio) = audio_diffusion.generate_spectrogram_and_audio()
         st.image(image, caption="Mel spectrogram")
         buffer = BytesIO()
         sf.write(buffer, normalize(audio), sample_rate, format="WAV")