Commit: update formatting
teticio committed Jul 15, 2024
1 parent 45f32ad commit daca68b
Showing 5 changed files with 85 additions and 103 deletions.
28 changes: 15 additions & 13 deletions app.py
@@ -7,8 +7,7 @@

def generate_spectrogram_audio_and_loop(model_id):
audio_diffusion = AudioDiffusion(model_id=model_id)
-    image, (sample_rate,
-            audio) = audio_diffusion.generate_spectrogram_and_audio()
+    image, (sample_rate, audio) = audio_diffusion.generate_spectrogram_and_audio()
loop = AudioDiffusion.loop_it(audio, sample_rate)
if loop is None:
loop = audio
@@ -24,23 +23,26 @@ def generate_spectrogram_audio_and_loop(model_id):
[colab](https://colab.research.google.com/github/teticio/audio-diffusion/blob/master/notebooks/gradio_app.ipynb) \
to run this app.",
inputs=[
-        gr.Dropdown(label="Model",
-                    choices=[
-                        "teticio/audio-diffusion-256",
-                        "teticio/audio-diffusion-breaks-256",
-                        "teticio/audio-diffusion-instrumental-hiphop-256",
-                        "teticio/audio-diffusion-ddim-256",
-                        "teticio/latent-audio-diffusion-256",
-                        "teticio/latent-audio-diffusion-ddim-256"
-                    ],
-                    value="teticio/latent-audio-diffusion-ddim-256")
+        gr.Dropdown(
+            label="Model",
+            choices=[
+                "teticio/audio-diffusion-256",
+                "teticio/audio-diffusion-breaks-256",
+                "teticio/audio-diffusion-instrumental-hiphop-256",
+                "teticio/audio-diffusion-ddim-256",
+                "teticio/latent-audio-diffusion-256",
+                "teticio/latent-audio-diffusion-ddim-256",
+            ],
+            value="teticio/latent-audio-diffusion-ddim-256",
+        )
],
outputs=[
gr.Image(label="Mel spectrogram", image_mode="L"),
gr.Audio(label="Audio"),
gr.Audio(label="Loop"),
],
-    allow_flagging="never")
+    allow_flagging="never",
+)

if __name__ == "__main__":
parser = argparse.ArgumentParser()
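For context, the block reformatted above is the argument list of a gr.Interface call. Below is a minimal, self-contained sketch of the same pattern; the predict function, the demo variable name, the Textbox output, and the launch() call are illustrative assumptions, not code from this commit.

import gradio as gr

def predict(model_id: str) -> str:
    # Stand-in for generate_spectrogram_audio_and_loop; just echoes the selection.
    return f"selected {model_id}"

demo = gr.Interface(
    fn=predict,
    inputs=gr.Dropdown(
        label="Model",
        choices=["teticio/audio-diffusion-256", "teticio/latent-audio-diffusion-ddim-256"],
        value="teticio/latent-audio-diffusion-ddim-256",
    ),
    outputs=gr.Textbox(label="Result"),
    allow_flagging="never",
)

if __name__ == "__main__":
    demo.launch()  # the real app.py parses CLI arguments before launching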
1 change: 0 additions & 1 deletion audiodiffusion/mel.py
@@ -26,7 +26,6 @@

import numpy as np # noqa: E402
-

try:
import librosa # noqa: E402

1 change: 1 addition & 0 deletions audiodiffusion/pipeline_audio_diffusion.py
@@ -35,6 +35,7 @@

from .mel import Mel
+

class AudioDiffusionPipeline(DiffusionPipeline):
"""
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
127 changes: 51 additions & 76 deletions scripts/train_unet.py
@@ -13,8 +13,7 @@
from accelerate import Accelerator
from accelerate.logging import get_logger
from datasets import load_dataset, load_from_disk
-from diffusers import (AutoencoderKL, DDIMScheduler, DDPMScheduler,
-                       UNet2DConditionModel, UNet2DModel)
+from diffusers import AutoencoderKL, DDIMScheduler, DDPMScheduler, UNet2DConditionModel, UNet2DModel
from diffusers.optimization import get_scheduler
from diffusers.pipelines.audio_diffusion import Mel
from diffusers.training_utils import EMAModel
@@ -28,9 +27,7 @@
logger = get_logger(__name__)


-def get_full_repo_name(model_id: str,
-                       organization: Optional[str] = None,
-                       token: Optional[str] = None):
+def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None):
if token is None:
token = HfFolder.get_token()
if organization is None:
@@ -52,9 +49,7 @@ def main(args):

if args.dataset_name is not None:
if os.path.exists(args.dataset_name):
-            dataset = load_from_disk(
-                args.dataset_name,
-                storage_options=args.dataset_config_name)["train"]
+            dataset = load_from_disk(args.dataset_name, storage_options=args.dataset_config_name)["train"]
else:
dataset = load_dataset(
args.dataset_name,
@@ -73,17 +68,16 @@ def main(args):
# Determine image resolution
resolution = dataset[0]["image"].height, dataset[0]["image"].width

-    augmentations = Compose([
-        ToTensor(),
-        Normalize([0.5], [0.5]),
-    ])
+    augmentations = Compose(
+        [
+            ToTensor(),
+            Normalize([0.5], [0.5]),
+        ]
+    )

def transforms(examples):
if args.vae is not None and vqvae.config["in_channels"] == 3:
-            images = [
-                augmentations(image.convert("RGB"))
-                for image in examples["image"]
-            ]
+            images = [augmentations(image.convert("RGB")) for image in examples["image"]]
else:
images = [augmentations(image) for image in examples["image"]]
if args.encodings is not None:
@@ -92,8 +86,7 @@ def transforms(examples):
return {"input": images}

dataset.set_transform(transforms)
-    train_dataloader = torch.utils.data.DataLoader(
-        dataset, batch_size=args.train_batch_size, shuffle=True)
+    train_dataloader = torch.utils.data.DataLoader(dataset, batch_size=args.train_batch_size, shuffle=True)

if args.encodings is not None:
encodings = pickle.load(open(args.encodings, "rb"))
@@ -106,9 +99,7 @@ def transforms(examples):
vqvae = AudioDiffusionPipeline.from_pretrained(args.vae).vqvae
# Determine latent resolution
with torch.no_grad():
-            latent_resolution = vqvae.encode(
-                torch.zeros((1, 1) +
-                            resolution)).latent_dist.sample().shape[2:]
+            latent_resolution = vqvae.encode(torch.zeros((1, 1) + resolution)).latent_dist.sample().shape[2:]
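# Editor's note (illustration, not part of the diff): the line above infers the
# UNet sample size by pushing a zero tensor shaped like one mel spectrogram
# through the VAE encoder and reading the spatial dims of the sampled latent.
# With a hypothetical 256x256 input and a VAE that downsamples by a factor f,
# latent_resolution would come out as (256 // f, 256 // f).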

if args.from_pretrained is not None:
pipeline = AudioDiffusionPipeline.from_pretrained(args.from_pretrained)
@@ -121,10 +112,8 @@ def transforms(examples):
if args.encodings is None:
model = UNet2DModel(
sample_size=resolution if vqvae is None else latent_resolution,
-            in_channels=1
-            if vqvae is None else vqvae.config["latent_channels"],
-            out_channels=1
-            if vqvae is None else vqvae.config["latent_channels"],
+            in_channels=1 if vqvae is None else vqvae.config["latent_channels"],
+            out_channels=1 if vqvae is None else vqvae.config["latent_channels"],
layers_per_block=2,
block_out_channels=(128, 128, 256, 256, 512, 512),
down_block_types=(
@@ -148,10 +137,8 @@ def transforms(examples):
else:
model = UNet2DConditionModel(
sample_size=resolution if vqvae is None else latent_resolution,
-            in_channels=1
-            if vqvae is None else vqvae.config["latent_channels"],
-            out_channels=1
-            if vqvae is None else vqvae.config["latent_channels"],
+            in_channels=1 if vqvae is None else vqvae.config["latent_channels"],
+            out_channels=1 if vqvae is None else vqvae.config["latent_channels"],
layers_per_block=2,
block_out_channels=(128, 256, 512, 512),
down_block_types=(
@@ -170,11 +157,9 @@ def transforms(examples):
)

if args.scheduler == "ddpm":
-        noise_scheduler = DDPMScheduler(
-            num_train_timesteps=args.num_train_steps)
+        noise_scheduler = DDPMScheduler(num_train_timesteps=args.num_train_steps)
else:
-        noise_scheduler = DDIMScheduler(
-            num_train_timesteps=args.num_train_steps)
+        noise_scheduler = DDIMScheduler(num_train_timesteps=args.num_train_steps)

optimizer = torch.optim.AdamW(
model.parameters(),
@@ -188,12 +173,12 @@ def transforms(examples):
args.lr_scheduler,
optimizer=optimizer,
num_warmup_steps=args.lr_warmup_steps,
-        num_training_steps=(len(train_dataloader) * args.num_epochs) //
-        args.gradient_accumulation_steps,
+        num_training_steps=(len(train_dataloader) * args.num_epochs) // args.gradient_accumulation_steps,
)
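# Editor's note (illustrative numbers, not from this repo): num_training_steps
# above counts optimizer steps, so gradient accumulation divides the raw batch
# count. E.g. 800 batches/epoch * 100 epochs // 2 accumulation steps = 40000
# steps for the LR schedule.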

model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
-        model, optimizer, train_dataloader, lr_scheduler)
+        model, optimizer, train_dataloader, lr_scheduler
+    )

ema_model = EMAModel(
getattr(model, "module", model),
@@ -204,8 +189,7 @@ def transforms(examples):

if args.push_to_hub:
if args.hub_model_id is None:
-            repo_name = get_full_repo_name(Path(output_dir).name,
-                                           token=args.hub_token)
+            repo_name = get_full_repo_name(Path(output_dir).name, token=args.hub_token)
else:
repo_name = args.hub_model_id
repo = Repository(output_dir, clone_from=repo_name)
@@ -224,8 +208,7 @@ def transforms(examples):

global_step = 0
for epoch in range(args.num_epochs):
-        progress_bar = tqdm(total=len(train_dataloader),
-                            disable=not accelerator.is_local_main_process)
+        progress_bar = tqdm(total=len(train_dataloader), disable=not accelerator.is_local_main_process)
progress_bar.set_description(f"Epoch {epoch}")

if epoch < args.start_epoch:
@@ -245,8 +228,7 @@ def transforms(examples):
if vqvae is not None:
vqvae.to(clean_images.device)
with torch.no_grad():
-                    clean_images = vqvae.encode(
-                        clean_images).latent_dist.sample()
+                    clean_images = vqvae.encode(clean_images).latent_dist.sample()
# Scale latent images to ensure approximately unit variance
clean_images = clean_images * 0.18215
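# Editor's note (background, not part of the diff): 0.18215 is the latent
# scaling constant popularized by Stable Diffusion's VAE; multiplying here
# brings the latents to roughly unit variance, and generation code typically
# divides by the same constant before decoding.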

@@ -257,20 +239,18 @@ def transforms(examples):
timesteps = torch.randint(
0,
noise_scheduler.config.num_train_timesteps,
-                (bsz, ),
+                (bsz,),
device=clean_images.device,
).long()

# Add noise to the clean images according to the noise magnitude at each timestep
# (this is the forward diffusion process)
-            noisy_images = noise_scheduler.add_noise(clean_images, noise,
-                                                     timesteps)
+            noisy_images = noise_scheduler.add_noise(clean_images, noise, timesteps)

with accelerator.accumulate(model):
# Predict the noise residual
if args.encodings is not None:
-                    noise_pred = model(noisy_images, timesteps,
-                                       batch["encoding"])["sample"]
+                    noise_pred = model(noisy_images, timesteps, batch["encoding"])["sample"]
else:
noise_pred = model(noisy_images, timesteps)["sample"]
loss = F.mse_loss(noise_pred, noise)
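# Editor's note (standard DDPM background, not spelled out in this diff):
# add_noise implements the closed-form forward process
#     x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise
# so the MSE above trains the UNet to recover the injected noise from x_t.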
@@ -302,9 +282,11 @@ def transforms(examples):

# Generate sample images for visual inspection
if accelerator.is_main_process:
-            if ((epoch + 1) % args.save_model_epochs == 0
-                    or (epoch + 1) % args.save_images_epochs == 0
-                    or epoch == args.num_epochs - 1):
+            if (
+                (epoch + 1) % args.save_model_epochs == 0
+                or (epoch + 1) % args.save_images_epochs == 0
+                or epoch == args.num_epochs - 1
+            ):
unet = accelerator.unwrap_model(model)
if args.use_ema:
ema_model.copy_to(unet.parameters())
@@ -315,9 +297,7 @@ def transforms(examples):
scheduler=noise_scheduler,
)

-                if (
-                    epoch + 1
-                ) % args.save_model_epochs == 0 or epoch == args.num_epochs - 1:
+                if (epoch + 1) % args.save_model_epochs == 0 or epoch == args.num_epochs - 1:
pipeline.save_pretrained(output_dir)

# save the model
@@ -329,15 +309,13 @@ def transforms(examples):
)

if (epoch + 1) % args.save_images_epochs == 0:
-                    generator = torch.Generator(
-                        device=clean_images.device).manual_seed(42)
+                    generator = torch.Generator(device=clean_images.device).manual_seed(42)

if args.encodings is not None:
random.seed(42)
-                        encoding = torch.stack(
-                            random.sample(list(encodings.values()),
-                                          args.eval_batch_size)).to(
-                                              clean_images.device)
+                        encoding = torch.stack(random.sample(list(encodings.values()), args.eval_batch_size)).to(
+                            clean_images.device
+                        )
else:
encoding = None

@@ -350,13 +328,15 @@ def transforms(examples):
)

# denormalize the images and save to tensorboard
-                    images = np.array([
-                        np.frombuffer(image.tobytes(), dtype="uint8").reshape(
-                            (len(image.getbands()), image.height, image.width))
-                        for image in images
-                    ])
-                    accelerator.trackers[0].writer.add_images(
-                        "test_samples", images, epoch)
+                    images = np.array(
+                        [
+                            np.frombuffer(image.tobytes(), dtype="uint8").reshape(
+                                (len(image.getbands()), image.height, image.width)
+                            )
+                            for image in images
+                        ]
+                    )
+                    accelerator.trackers[0].writer.add_images("test_samples", images, epoch)
for _, audio in enumerate(audios):
accelerator.trackers[0].writer.add_audio(
f"test_audio_{_}",
@@ -370,8 +350,7 @@ def transforms(examples):


if __name__ == "__main__":
-    parser = argparse.ArgumentParser(
-        description="Simple example of a training script.")
+    parser = argparse.ArgumentParser(description="Simple example of a training script.")
parser.add_argument("--local_rank", type=int, default=-1)
parser.add_argument("--dataset_name", type=str, default=None)
parser.add_argument("--dataset_config_name", type=str, default=None)
@@ -415,18 +394,16 @@ def transforms(examples):
        help=(
            "Whether to use mixed precision. Choose"
            "between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10."
-            "and an Nvidia Ampere GPU."),
+            "and an Nvidia Ampere GPU."
+        ),
)
parser.add_argument("--hop_length", type=int, default=512)
parser.add_argument("--sample_rate", type=int, default=22050)
parser.add_argument("--n_fft", type=int, default=2048)
parser.add_argument("--from_pretrained", type=str, default=None)
parser.add_argument("--start_epoch", type=int, default=0)
parser.add_argument("--num_train_steps", type=int, default=1000)
-    parser.add_argument("--scheduler",
-                        type=str,
-                        default="ddpm",
-                        help="ddpm or ddim")
+    parser.add_argument("--scheduler", type=str, default="ddpm", help="ddpm or ddim")
parser.add_argument(
"--vae",
type=str,
@@ -446,8 +423,6 @@ def transforms(examples):
args.local_rank = env_local_rank

if args.dataset_name is None and args.train_data_dir is None:
-        raise ValueError(
-            "You must specify either a dataset name from the hub or a train data directory."
-        )
+        raise ValueError("You must specify either a dataset name from the hub or a train data directory.")

main(args)
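Once trained (and optionally pushed to the hub), a model can be consumed the way app.py does. Below is a minimal sketch reusing only calls that appear in this diff; the model id is a placeholder and the import path is assumed from the repo layout.

from audiodiffusion import AudioDiffusion

audio_diffusion = AudioDiffusion(model_id="teticio/audio-diffusion-256")
image, (sample_rate, audio) = audio_diffusion.generate_spectrogram_and_audio()
loop = AudioDiffusion.loop_it(audio, sample_rate)  # may be None if no loop is found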