From 97db8edf644c0c67f57f4ee457ddf8f6af7af638 Mon Sep 17 00:00:00 2001 From: Vikram Voleti Date: Tue, 19 Mar 2024 04:53:29 +0000 Subject: [PATCH 1/2] Gradio updates --- scripts/demo/sv3d_helpers.py | 23 +- scripts/demo/sv3d_p_gradio.py | 341 ++++++++++++++++++++++++ scripts/demo/sv3d_u_gradio.py | 295 ++++++++++++++++++++ scripts/sampling/simple_video_sample.py | 2 +- 4 files changed, 656 insertions(+), 5 deletions(-) create mode 100644 scripts/demo/sv3d_p_gradio.py create mode 100644 scripts/demo/sv3d_u_gradio.py diff --git a/scripts/demo/sv3d_helpers.py b/scripts/demo/sv3d_helpers.py index a0cebd197..b32750658 100644 --- a/scripts/demo/sv3d_helpers.py +++ b/scripts/demo/sv3d_helpers.py @@ -2,6 +2,7 @@ import matplotlib.pyplot as plt import numpy as np +from PIL import Image def generate_dynamic_cycle_xy_values( @@ -74,8 +75,9 @@ def gen_dynamic_loop(length=21, elev_deg=0): return np.roll(azim_rad, -1), np.roll(elev_rad, -1) -def plot_3D(azim, polar, save_path, dynamic=True): - os.makedirs(os.path.dirname(save_path), exist_ok=True) +def plot_3D(azim, polar, save_path=None, dynamic=True): + if save_path is not None: + os.makedirs(os.path.dirname(save_path), exist_ok=True) elev = np.deg2rad(90) - polar fig = plt.figure(figsize=(5, 5)) ax = fig.add_subplot(projection="3d") @@ -98,7 +100,20 @@ def plot_3D(azim, polar, save_path, dynamic=True): ax.scatter(xs[i + 1], ys[i + 1], zs[i + 1], s=100, color=col[i + 1]) ax.scatter(xs[:1], ys[:1], zs[:1], s=120, facecolors="none", edgecolors="k") ax.scatter(xs[-1:], ys[-1:], zs[-1:], s=120, facecolors="none", edgecolors="k") - ax.view_init(elev=30, azim=-20, roll=0) - plt.savefig(save_path, bbox_inches="tight") + ax.view_init(elev=40, azim=-20, roll=0) + ax.xaxis.set_ticklabels([]) + ax.yaxis.set_ticklabels([]) + ax.zaxis.set_ticklabels([]) + if save_path is None: + fig.canvas.draw() + lst = list(fig.canvas.get_width_height()) + lst.append(3) + image = Image.fromarray( + np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8).reshape(lst) + ) + else: + plt.savefig(save_path, bbox_inches="tight") plt.clf() plt.close() + if save_path is None: + return image diff --git a/scripts/demo/sv3d_p_gradio.py b/scripts/demo/sv3d_p_gradio.py new file mode 100644 index 000000000..06848865c --- /dev/null +++ b/scripts/demo/sv3d_p_gradio.py @@ -0,0 +1,341 @@ +# Adding this at the very top of app.py to make 'generative-models' directory discoverable +import os +import sys + +sys.path.append(os.path.dirname(__file__)) + +import random +from glob import glob +from pathlib import Path +from typing import List, Optional + +import cv2 +import gradio as gr +import imageio +import numpy as np +import torch +from einops import rearrange, repeat +from huggingface_hub import hf_hub_download +from PIL import Image +from rembg import remove +from scripts.demo.sv3d_helpers import gen_dynamic_loop, plot_3D +from scripts.sampling.simple_video_sample import ( + get_batch, + get_unique_embedder_keys_from_conditioner, + load_model, +) +from sgm.inference.helpers import embed_watermark +from torchvision.transforms import ToTensor + +version = "sv3d_p" # replace with 'sv3d_p' or 'sv3d_u' for other models + +# Define the repo, local directory and filename +repo_id = "stabilityai/sv3d" +filename = f"{version}.safetensors" # replace with "sv3d_u.safetensors" or "sv3d_p.safetensors" +local_dir = "checkpoints" +local_ckpt_path = os.path.join(local_dir, filename) + +# Check if the file already exists +if not os.path.exists(local_ckpt_path): + # If the file doesn't exist, download it + 
hf_hub_download(repo_id=repo_id, filename=filename, local_dir=local_dir) + print("File downloaded.") +else: + print("File already exists. No need to download.") + +device = "cuda" +max_64_bit_int = 2**63 - 1 + +num_frames = 21 +num_steps = 50 +model_config = f"scripts/sampling/configs/{version}.yaml" + +model, filter = load_model( + model_config, + device, + num_frames, + num_steps, +) + + +def gen_orbit(orbit, elev_deg): + global polars_rad + global azimuths_rad + if orbit == "dynamic": + azim_rad, elev_rad = gen_dynamic_loop(length=num_frames, elev_deg=elev_deg) + polars_rad = np.deg2rad(90) - elev_rad + azimuths_rad = azim_rad + else: + polars_rad = np.array([np.deg2rad(90 - elev_deg)] * num_frames) + azimuths_rad = np.linspace(0, 2 * np.pi, num_frames + 1)[1:] + + plot = plot_3D( + azim=azimuths_rad, + polar=polars_rad, + save_path=None, + dynamic=(orbit == "dynamic"), + ) + return plot + + +def sample( + input_path: str = "assets/test_image.png", # Can either be image file or folder with image files + seed: Optional[int] = None, + randomize_seed: bool = True, + orbit: str = "same elevation", + elev_deg: float = 10.0, + decoding_t: int = 7, # Number of frames decoded at a time! This eats most VRAM. Reduce if necessary. + device: str = "cuda", + output_folder: str = None, + image_frame_ratio: Optional[float] = None, +): + """ + Simple script to generate a single sample conditioned on an image `input_path` or multiple images, one for each + image file in folder `input_path`. If you run out of VRAM, try decreasing `decoding_t`. + """ + if randomize_seed: + seed = random.randint(0, max_64_bit_int) + + torch.manual_seed(seed) + + path = Path(input_path) + all_img_paths = [] + if path.is_file(): + if any([input_path.endswith(x) for x in ["jpg", "jpeg", "png"]]): + all_img_paths = [input_path] + else: + raise ValueError("Path is not valid image file.") + elif path.is_dir(): + all_img_paths = sorted( + [ + f + for f in path.iterdir() + if f.is_file() and f.suffix.lower() in [".jpg", ".jpeg", ".png"] + ] + ) + if len(all_img_paths) == 0: + raise ValueError("Folder does not contain any images.") + else: + raise ValueError + + for input_img_path in all_img_paths: + + image = Image.open(input_img_path) + if image.mode == "RGBA": + pass + else: + # remove bg + image.thumbnail([768, 768], Image.Resampling.LANCZOS) + image = remove(image.convert("RGBA"), alpha_matting=True) + + # resize object in frame + image_arr = np.array(image) + in_w, in_h = image_arr.shape[:2] + ret, mask = cv2.threshold( + np.array(image.split()[-1]), 0, 255, cv2.THRESH_BINARY + ) + x, y, w, h = cv2.boundingRect(mask) + max_size = max(w, h) + side_len = ( + int(max_size / image_frame_ratio) if image_frame_ratio is not None else in_w + ) + padded_image = np.zeros((side_len, side_len, 4), dtype=np.uint8) + center = side_len // 2 + padded_image[ + center - h // 2 : center - h // 2 + h, + center - w // 2 : center - w // 2 + w, + ] = image_arr[y : y + h, x : x + w] + # resize frame to 576x576 + rgba = Image.fromarray(padded_image).resize((576, 576), Image.LANCZOS) + # white bg + rgba_arr = np.array(rgba) / 255.0 + rgb = rgba_arr[..., :3] * rgba_arr[..., -1:] + (1 - rgba_arr[..., -1:]) + input_image = Image.fromarray((rgb * 255).astype(np.uint8)) + + image = ToTensor()(input_image) + image = image * 2.0 - 1.0 + + image = image.unsqueeze(0).to(device) + H, W = image.shape[2:] + assert image.shape[1] == 3 + F = 8 + C = 4 + shape = (num_frames, C, H // F, W // F) + if (H, W) != (576, 576) and "sv3d" in version: + print( + "WARNING: The 
conditioning frame you provided is not 576x576. This leads to suboptimal performance as model was only trained on 576x576." + ) + + cond_aug = 1e-5 + + value_dict = {} + value_dict["cond_aug"] = cond_aug + value_dict["cond_frames_without_noise"] = image + value_dict["cond_frames"] = image + cond_aug * torch.randn_like(image) + value_dict["cond_aug"] = cond_aug + + value_dict["polars_rad"] = polars_rad + value_dict["azimuths_rad"] = azimuths_rad + + output_folder = output_folder or f"outputs/gradio/{version}" + cond_aug = 1e-5 + + with torch.no_grad(): + with torch.autocast(device): + batch, batch_uc = get_batch( + get_unique_embedder_keys_from_conditioner(model.conditioner), + value_dict, + [1, num_frames], + T=num_frames, + device=device, + ) + c, uc = model.conditioner.get_unconditional_conditioning( + batch, + batch_uc=batch_uc, + force_uc_zero_embeddings=[ + "cond_frames", + "cond_frames_without_noise", + ], + ) + + for k in ["crossattn", "concat"]: + uc[k] = repeat(uc[k], "b ... -> b t ...", t=num_frames) + uc[k] = rearrange(uc[k], "b t ... -> (b t) ...", t=num_frames) + c[k] = repeat(c[k], "b ... -> b t ...", t=num_frames) + c[k] = rearrange(c[k], "b t ... -> (b t) ...", t=num_frames) + + randn = torch.randn(shape, device=device) + + additional_model_inputs = {} + additional_model_inputs["image_only_indicator"] = torch.zeros( + 2, num_frames + ).to(device) + additional_model_inputs["num_video_frames"] = batch["num_video_frames"] + + def denoiser(input, sigma, c): + return model.denoiser( + model.model, input, sigma, c, **additional_model_inputs + ) + + samples_z = model.sampler(denoiser, randn, cond=c, uc=uc) + model.en_and_decode_n_samples_a_time = decoding_t + samples_x = model.decode_first_stage(samples_z) + samples_x[-1:] = value_dict["cond_frames_without_noise"] + samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0) + + os.makedirs(output_folder, exist_ok=True) + base_count = len(glob(os.path.join(output_folder, "*.mp4"))) + + imageio.imwrite( + os.path.join(output_folder, f"{base_count:06d}.jpg"), input_image + ) + + samples = embed_watermark(samples) + samples = filter(samples) + vid = ( + (rearrange(samples, "t c h w -> t h w c") * 255) + .cpu() + .numpy() + .astype(np.uint8) + ) + video_path = os.path.join(output_folder, f"{base_count:06d}.mp4") + imageio.mimwrite(video_path, vid) + + return video_path, seed + + +def resize_image(image_path, output_size=(576, 576)): + image = Image.open(image_path) + # Calculate aspect ratios + target_aspect = output_size[0] / output_size[1] # Aspect ratio of the desired size + image_aspect = image.width / image.height # Aspect ratio of the original image + + # Resize then crop if the original image is larger + if image_aspect > target_aspect: + # Resize the image to match the target height, maintaining aspect ratio + new_height = output_size[1] + new_width = int(new_height * image_aspect) + resized_image = image.resize((new_width, new_height), Image.LANCZOS) + # Calculate coordinates for cropping + left = (new_width - output_size[0]) / 2 + top = 0 + right = (new_width + output_size[0]) / 2 + bottom = output_size[1] + else: + # Resize the image to match the target width, maintaining aspect ratio + new_width = output_size[0] + new_height = int(new_width / image_aspect) + resized_image = image.resize((new_width, new_height), Image.LANCZOS) + # Calculate coordinates for cropping + left = 0 + top = (new_height - output_size[1]) / 2 + right = output_size[0] + bottom = (new_height + output_size[1]) / 2 + + # Crop the image + 
cropped_image = resized_image.crop((left, top, right, bottom)) + + return cropped_image + + +with gr.Blocks() as demo: + gr.Markdown( + """# Demo for SV3D_p from Stability AI ([model](https://huggingface.co/stabilityai/sv3d), [news](https://stability.ai/news/introducing-stable-video-3d)) +#### Research release ([_non-commercial_](https://huggingface.co/stabilityai/sv3d/blob/main/LICENSE)): generate 21 frames orbital video from a single image, at variable elevation and azimuth. +Generation takes ~40s (for 50 steps) in an A100. + """ + ) + with gr.Row(): + with gr.Column(): + image = gr.Image(label="Upload your image", type="filepath") + generate_btn = gr.Button("Generate") + video = gr.Video() + with gr.Row(): + with gr.Column(): + elev_deg = gr.Slider( + label="Elevation (in degrees)", + info="Elevation of the camera in the conditioning image, in degrees.", + value=10.0, + minimum=-10, + maximum=30, + ) + orbit = gr.Dropdown( + ["same elevation", "dynamic"], + label="Orbit", + info="Choose with orbit to generate", + ) + plot_image = gr.Image() + with gr.Accordion("Advanced options", open=False): + seed = gr.Slider( + label="Seed", + value=23, + randomize=True, + minimum=0, + maximum=max_64_bit_int, + step=1, + ) + randomize_seed = gr.Checkbox(label="Randomize seed", value=True) + decoding_t = gr.Slider( + label="Decode n frames at a time", + info="Number of frames decoded at a time! This eats most VRAM. Reduce if necessary.", + value=7, + minimum=1, + maximum=14, + ) + + image.upload(fn=resize_image, inputs=image, outputs=image, queue=False) + + elev_deg.change(gen_orbit, [orbit, elev_deg], plot_image) + orbit.change(gen_orbit, [orbit, elev_deg], plot_image) + # seed.change(gen_orbit, [orbit, elev_deg], plot_image) + + generate_btn.click( + fn=sample, + inputs=[image, seed, randomize_seed, decoding_t], + outputs=[video, seed], + api_name="video", + ) + +if __name__ == "__main__": + demo.queue(max_size=20) + demo.launch(share=True) diff --git a/scripts/demo/sv3d_u_gradio.py b/scripts/demo/sv3d_u_gradio.py new file mode 100644 index 000000000..1745fed80 --- /dev/null +++ b/scripts/demo/sv3d_u_gradio.py @@ -0,0 +1,295 @@ +# Adding this at the very top of app.py to make 'generative-models' directory discoverable +import os +import sys + +sys.path.insert(0, os.path.abspath(os.path.dirname(__file__))) + +import random +from glob import glob +from pathlib import Path +from typing import Optional + +import cv2 +import gradio as gr +import imageio +import numpy as np +import torch +from einops import rearrange, repeat +from huggingface_hub import hf_hub_download +from PIL import Image +from rembg import remove +from scripts.sampling.simple_video_sample import ( + get_batch, + get_unique_embedder_keys_from_conditioner, + load_model, +) +from sgm.inference.helpers import embed_watermark +from torchvision.transforms import ToTensor + +version = "sv3d_u" # replace with 'sv3d_p' or 'sv3d_u' for other models + +# Define the repo, local directory and filename +repo_id = "stabilityai/sv3d" +filename = f"{version}.safetensors" # replace with "sv3d_u.safetensors" or "sv3d_p.safetensors" +local_dir = "checkpoints" +local_ckpt_path = os.path.join(local_dir, filename) + +# Check if the file already exists +if not os.path.exists(local_ckpt_path): + # If the file doesn't exist, download it + hf_hub_download(repo_id=repo_id, filename=filename, local_dir=local_dir) + print("File downloaded.") +else: + print("File already exists. 
No need to download.") + +device = "cuda" +max_64_bit_int = 2**63 - 1 + +num_frames = 21 +num_steps = 50 +model_config = f"scripts/sampling/configs/{version}.yaml" + +model, filter = load_model( + model_config, + device, + num_frames, + num_steps, +) + + +def sample( + input_path: str = "assets/test_image.png", # Can either be image file or folder with image files + seed: Optional[int] = None, + randomize_seed: bool = True, + decoding_t: int = 7, # Number of frames decoded at a time! This eats most VRAM. Reduce if necessary. + device: str = "cuda", + output_folder: str = None, + image_frame_ratio: Optional[float] = None, +): + """ + Simple script to generate a single sample conditioned on an image `input_path` or multiple images, one for each + image file in folder `input_path`. If you run out of VRAM, try decreasing `decoding_t`. + """ + if randomize_seed: + seed = random.randint(0, max_64_bit_int) + + torch.manual_seed(seed) + + path = Path(input_path) + all_img_paths = [] + if path.is_file(): + if any([input_path.endswith(x) for x in ["jpg", "jpeg", "png"]]): + all_img_paths = [input_path] + else: + raise ValueError("Path is not valid image file.") + elif path.is_dir(): + all_img_paths = sorted( + [ + f + for f in path.iterdir() + if f.is_file() and f.suffix.lower() in [".jpg", ".jpeg", ".png"] + ] + ) + if len(all_img_paths) == 0: + raise ValueError("Folder does not contain any images.") + else: + raise ValueError + + for input_img_path in all_img_paths: + + image = Image.open(input_img_path) + if image.mode == "RGBA": + pass + else: + # remove bg + image.thumbnail([768, 768], Image.Resampling.LANCZOS) + image = remove(image.convert("RGBA"), alpha_matting=True) + + # resize object in frame + image_arr = np.array(image) + in_w, in_h = image_arr.shape[:2] + ret, mask = cv2.threshold( + np.array(image.split()[-1]), 0, 255, cv2.THRESH_BINARY + ) + x, y, w, h = cv2.boundingRect(mask) + max_size = max(w, h) + side_len = ( + int(max_size / image_frame_ratio) if image_frame_ratio is not None else in_w + ) + padded_image = np.zeros((side_len, side_len, 4), dtype=np.uint8) + center = side_len // 2 + padded_image[ + center - h // 2 : center - h // 2 + h, + center - w // 2 : center - w // 2 + w, + ] = image_arr[y : y + h, x : x + w] + # resize frame to 576x576 + rgba = Image.fromarray(padded_image).resize((576, 576), Image.LANCZOS) + # white bg + rgba_arr = np.array(rgba) / 255.0 + rgb = rgba_arr[..., :3] * rgba_arr[..., -1:] + (1 - rgba_arr[..., -1:]) + input_image = Image.fromarray((rgb * 255).astype(np.uint8)) + + image = ToTensor()(input_image) + image = image * 2.0 - 1.0 + + image = image.unsqueeze(0).to(device) + H, W = image.shape[2:] + assert image.shape[1] == 3 + F = 8 + C = 4 + shape = (num_frames, C, H // F, W // F) + if (H, W) != (576, 576) and "sv3d" in version: + print( + "WARNING: The conditioning frame you provided is not 576x576. This leads to suboptimal performance as model was only trained on 576x576." 
+ ) + + cond_aug = 1e-5 + + value_dict = {} + value_dict["cond_aug"] = cond_aug + value_dict["cond_frames_without_noise"] = image + value_dict["cond_frames"] = image + cond_aug * torch.randn_like(image) + value_dict["cond_aug"] = cond_aug + + output_folder = output_folder or f"outputs/gradio/{version}" + cond_aug = 1e-5 + + with torch.no_grad(): + with torch.autocast(device): + batch, batch_uc = get_batch( + get_unique_embedder_keys_from_conditioner(model.conditioner), + value_dict, + [1, num_frames], + T=num_frames, + device=device, + ) + c, uc = model.conditioner.get_unconditional_conditioning( + batch, + batch_uc=batch_uc, + force_uc_zero_embeddings=[ + "cond_frames", + "cond_frames_without_noise", + ], + ) + + for k in ["crossattn", "concat"]: + uc[k] = repeat(uc[k], "b ... -> b t ...", t=num_frames) + uc[k] = rearrange(uc[k], "b t ... -> (b t) ...", t=num_frames) + c[k] = repeat(c[k], "b ... -> b t ...", t=num_frames) + c[k] = rearrange(c[k], "b t ... -> (b t) ...", t=num_frames) + + randn = torch.randn(shape, device=device) + + additional_model_inputs = {} + additional_model_inputs["image_only_indicator"] = torch.zeros( + 2, num_frames + ).to(device) + additional_model_inputs["num_video_frames"] = batch["num_video_frames"] + + def denoiser(input, sigma, c): + return model.denoiser( + model.model, input, sigma, c, **additional_model_inputs + ) + + samples_z = model.sampler(denoiser, randn, cond=c, uc=uc) + model.en_and_decode_n_samples_a_time = decoding_t + samples_x = model.decode_first_stage(samples_z) + samples_x[-1:] = value_dict["cond_frames_without_noise"] + samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0) + + os.makedirs(output_folder, exist_ok=True) + base_count = len(glob(os.path.join(output_folder, "*.mp4"))) + + imageio.imwrite( + os.path.join(output_folder, f"{base_count:06d}.jpg"), input_image + ) + + samples = embed_watermark(samples) + samples = filter(samples) + vid = ( + (rearrange(samples, "t c h w -> t h w c") * 255) + .cpu() + .numpy() + .astype(np.uint8) + ) + video_path = os.path.join(output_folder, f"{base_count:06d}.mp4") + imageio.mimwrite(video_path, vid) + + return video_path, seed + + +def resize_image(image_path, output_size=(576, 576)): + image = Image.open(image_path) + # Calculate aspect ratios + target_aspect = output_size[0] / output_size[1] # Aspect ratio of the desired size + image_aspect = image.width / image.height # Aspect ratio of the original image + + # Resize then crop if the original image is larger + if image_aspect > target_aspect: + # Resize the image to match the target height, maintaining aspect ratio + new_height = output_size[1] + new_width = int(new_height * image_aspect) + resized_image = image.resize((new_width, new_height), Image.LANCZOS) + # Calculate coordinates for cropping + left = (new_width - output_size[0]) / 2 + top = 0 + right = (new_width + output_size[0]) / 2 + bottom = output_size[1] + else: + # Resize the image to match the target width, maintaining aspect ratio + new_width = output_size[0] + new_height = int(new_width / image_aspect) + resized_image = image.resize((new_width, new_height), Image.LANCZOS) + # Calculate coordinates for cropping + left = 0 + top = (new_height - output_size[1]) / 2 + right = output_size[0] + bottom = (new_height + output_size[1]) / 2 + + # Crop the image + cropped_image = resized_image.crop((left, top, right, bottom)) + + return cropped_image + + +with gr.Blocks() as demo: + gr.Markdown( + """# Demo for SV3D_u from Stability AI 
([model](https://huggingface.co/stabilityai/sv3d), [news](https://stability.ai/news/introducing-stable-video-3d)) +#### Research release ([_non-commercial_](https://huggingface.co/stabilityai/sv3d/blob/main/LICENSE)): generate 21 frames orbital video from a single image, at the same elevation. +Generation takes ~40s (for 50 steps) in an A100. + """ + ) + with gr.Row(): + with gr.Column(): + image = gr.Image(label="Upload your image", type="filepath") + generate_btn = gr.Button("Generate") + video = gr.Video() + with gr.Accordion("Advanced options", open=False): + seed = gr.Slider( + label="Seed", + value=23, + randomize=True, + minimum=0, + maximum=max_64_bit_int, + step=1, + ) + randomize_seed = gr.Checkbox(label="Randomize seed", value=True) + decoding_t = gr.Slider( + label="Decode n frames at a time", + info="Number of frames decoded at a time! This eats most VRAM. Reduce if necessary.", + value=7, + minimum=1, + maximum=14, + ) + + image.upload(fn=resize_image, inputs=image, outputs=image, queue=False) + generate_btn.click( + fn=sample, + inputs=[image, seed, randomize_seed, decoding_t], + outputs=[video, seed], + api_name="video", + ) + +if __name__ == "__main__": + demo.queue(max_size=20) + demo.launch(share=True) diff --git a/scripts/sampling/simple_video_sample.py b/scripts/sampling/simple_video_sample.py index 29a8b8581..b96152969 100644 --- a/scripts/sampling/simple_video_sample.py +++ b/scripts/sampling/simple_video_sample.py @@ -100,7 +100,7 @@ def sample( device, num_frames, num_steps, - verbose, + verbose=verbose, ) torch.manual_seed(seed) From c4a9d1f8658835e59f38aab43470d486df030097 Mon Sep 17 00:00:00 2001 From: Vikram Voleti Date: Tue, 23 Jul 2024 13:38:39 +0000 Subject: [PATCH 2/2] Fixes --- scripts/demo/sv3d_p_gradio.py | 19 +++++++++---------- scripts/demo/sv3d_u_gradio.py | 2 +- 2 files changed, 10 insertions(+), 11 deletions(-) diff --git a/scripts/demo/sv3d_p_gradio.py b/scripts/demo/sv3d_p_gradio.py index 06848865c..1cb0c514b 100644 --- a/scripts/demo/sv3d_p_gradio.py +++ b/scripts/demo/sv3d_p_gradio.py @@ -57,10 +57,11 @@ num_steps, ) +polars_rad = np.array([np.deg2rad(90 - 10.0)] * num_frames) +azimuths_rad = np.linspace(0, 2 * np.pi, num_frames + 1)[1:] + def gen_orbit(orbit, elev_deg): - global polars_rad - global azimuths_rad if orbit == "dynamic": azim_rad, elev_rad = gen_dynamic_loop(length=num_frames, elev_deg=elev_deg) polars_rad = np.deg2rad(90) - elev_rad @@ -82,8 +83,6 @@ def sample( input_path: str = "assets/test_image.png", # Can either be image file or folder with image files seed: Optional[int] = None, randomize_seed: bool = True, - orbit: str = "same elevation", - elev_deg: float = 10.0, decoding_t: int = 7, # Number of frames decoded at a time! This eats most VRAM. Reduce if necessary. 
device: str = "cuda", output_folder: str = None, @@ -292,6 +291,11 @@ def resize_image(image_path, output_size=(576, 576)): video = gr.Video() with gr.Row(): with gr.Column(): + orbit = gr.Dropdown( + ["same elevation", "dynamic"], + label="Orbit", + info="Choose with orbit to generate", + ) elev_deg = gr.Slider( label="Elevation (in degrees)", info="Elevation of the camera in the conditioning image, in degrees.", @@ -299,11 +303,6 @@ def resize_image(image_path, output_size=(576, 576)): minimum=-10, maximum=30, ) - orbit = gr.Dropdown( - ["same elevation", "dynamic"], - label="Orbit", - info="Choose with orbit to generate", - ) plot_image = gr.Image() with gr.Accordion("Advanced options", open=False): seed = gr.Slider( @@ -325,8 +324,8 @@ def resize_image(image_path, output_size=(576, 576)): image.upload(fn=resize_image, inputs=image, outputs=image, queue=False) - elev_deg.change(gen_orbit, [orbit, elev_deg], plot_image) orbit.change(gen_orbit, [orbit, elev_deg], plot_image) + elev_deg.change(gen_orbit, [orbit, elev_deg], plot_image) # seed.change(gen_orbit, [orbit, elev_deg], plot_image) generate_btn.click( diff --git a/scripts/demo/sv3d_u_gradio.py b/scripts/demo/sv3d_u_gradio.py index 1745fed80..e298da5c3 100644 --- a/scripts/demo/sv3d_u_gradio.py +++ b/scripts/demo/sv3d_u_gradio.py @@ -2,7 +2,7 @@ import os import sys -sys.path.insert(0, os.path.abspath(os.path.dirname(__file__))) +sys.path.append(os.path.dirname(__file__)) import random from glob import glob
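
A minimal sketch of driving the `api_name="video"` endpoint that both demos register via `generate_btn.click`, using `gradio_client`. This is not part of the patch itself: it assumes the demo is already running locally on Gradio's default port (7860), the server URL, image path, and argument values are placeholder assumptions, and older `gradio_client` releases accept a plain path string in place of `handle_file`.

from gradio_client import Client, handle_file

# Connect to the locally running demo (assumed URL/port).
client = Client("http://127.0.0.1:7860")

# Positional arguments mirror the click() inputs: image, seed, randomize_seed, decoding_t.
video_path, used_seed = client.predict(
    handle_file("assets/test_image.png"),  # conditioning image (placeholder path)
    23,     # seed (ignored when randomize_seed=True)
    True,   # randomize_seed
    7,      # decoding_t: frames decoded per VAE pass; lower this if VRAM is tight
    api_name="/video",
)
print("video written to:", video_path, "seed used:", used_seed)

Since `sample()` returns `(video_path, seed)`, the call surfaces both the rendered orbit video and the seed that was actually used, which makes a run reproducible by passing that seed back with randomize_seed set to False.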