8 changes: 6 additions & 2 deletions examples/stable-diffusion/main.rs
@@ -69,8 +69,8 @@
// cargo run --release --example tensor-tools cp ./data/vae.npz ./data/vae.ot
// cargo run --release --example tensor-tools cp ./data/unet.npz ./data/unet.ot
use clap::Parser;
use diffusers::pipelines::stable_diffusion;
use diffusers::transformers::clip;
use diffusers::pipelines::stable_diffusion;
use tch::{nn::Module, Device, Kind, Tensor};

const GUIDANCE_SCALE: f64 = 7.5;
@@ -241,7 +241,6 @@ fn run(args: Args) -> anyhow::Result<()> {
let clip_device = cpu_or_cuda("clip");
let vae_device = cpu_or_cuda("vae");
let unet_device = cpu_or_cuda("unet");
let scheduler = sd_config.build_scheduler(n_steps);

let tokenizer = clip::Tokenizer::create(vocab_file, &sd_config.clip)?;
println!("Running with prompt \"{prompt}\".");
@@ -276,6 +275,11 @@ fn run(args: Args) -> anyhow::Result<()> {
// scale the initial noise by the standard deviation required by the scheduler
latents *= scheduler.init_noise_sigma();

let scheduler = sd_config.build_scheduler(n_steps);
// let mut scheduler = schedulers::dpmsolver_singlestep::DPMSolverSinglestepScheduler::new(n_steps, Default::default());
// Using this scheduler requires mutability, so change the iteration to the following:
// scheduler.timesteps().to_owned().iter().enumerate()

for (timestep_index, &timestep) in scheduler.timesteps().iter().enumerate() {
println!("Timestep {timestep_index}/{n_steps}");
let latent_model_input = Tensor::cat(&[&latents, &latents], 0);
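For readers trying the commented-out single-step variant above: the snippet below is a minimal sketch of what the comment describes, assuming a `use diffusers::schedulers;` import and a `DPMSolverSinglestepScheduler` whose `step` method takes `&mut self` (the step call and the UNet/guidance body are elided, since their exact signatures are not part of this diff). Calling `to_owned()` on the timestep list copies it out, so the immutable borrow from `timesteps()` ends before any mutable call inside the loop.

```rust
// Hypothetical sketch, slotting into the existing `run` function in main.rs.
let mut scheduler = schedulers::dpmsolver_singlestep::DPMSolverSinglestepScheduler::new(
    n_steps,
    Default::default(),
);

// Cloning the timesteps releases the immutable borrow of `scheduler`,
// which is what allows a `&mut scheduler` call in the loop body.
for (timestep_index, &timestep) in scheduler.timesteps().to_owned().iter().enumerate() {
    println!("Timestep {timestep_index}/{n_steps}");
    let latent_model_input = Tensor::cat(&[&latents, &latents], 0);
    // ... UNet forward pass and classifier-free guidance as in the existing loop ...
    // latents = scheduler.step(&noise_pred, timestep, &latents);  // assumed signature
}
```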
67 changes: 67 additions & 0 deletions src/schedulers/dpmsolver.rs
@@ -0,0 +1,67 @@
use crate::schedulers::BetaSchedule;
use crate::schedulers::PredictionType;

/// The algorithm type for the solver.
///
#[derive(Default, Debug, Clone, PartialEq, Eq)]
pub enum DPMSolverAlgorithmType {
/// Implements the algorithms defined in <https://arxiv.org/abs/2211.01095>.
#[default]
DPMSolverPlusPlus,
/// Implements the algorithms defined in <https://arxiv.org/abs/2206.00927>.
DPMSolver,
}

/// The solver type for the second-order solver.
/// The solver type slightly affects the sample quality, especially for a
/// small number of steps.
#[derive(Default, Debug, Clone, PartialEq, Eq)]
pub enum DPMSolverType {
#[default]
Midpoint,
Heun,
}

#[derive(Debug, Clone)]
pub struct DPMSolverSchedulerConfig {
/// The value of beta at the beginning of training.
pub beta_start: f64,
/// The value of beta at the end of training.
pub beta_end: f64,
/// How beta evolved during training.
pub beta_schedule: BetaSchedule,
/// The number of diffusion steps used to train the model.
pub train_timesteps: usize,
/// The order of DPM-Solver; can be `1`, `2`, or `3`. We recommend using `solver_order=2` for guided
/// sampling, and `solver_order=3` for unconditional sampling.
pub solver_order: usize,
/// Prediction type of the scheduler function.
pub prediction_type: PredictionType,
/// The threshold value for dynamic thresholding. Valid only when `thresholding: true` and
/// `algorithm_type: DPMSolverAlgorithmType::DPMSolverPlusPlus`.
pub sample_max_value: f32,
/// The algorithm type for the solver
pub algorithm_type: DPMSolverAlgorithmType,
/// The solver type for the second-order solver.
pub solver_type: DPMSolverType,
/// Whether to use lower-order solvers in the final steps. Only valid for < 15 inference steps. We empirically
/// find this can stabilize the sampling of DPM-Solver for `steps < 15`, especially for steps <= 10.
pub lower_order_final: bool,
}

impl Default for DPMSolverSchedulerConfig {
fn default() -> Self {
Self {
beta_start: 0.0001,
beta_end: 0.02,
beta_schedule: BetaSchedule::Linear,
train_timesteps: 1000,
solver_order: 2,
prediction_type: PredictionType::Epsilon,
sample_max_value: 1.0,
algorithm_type: DPMSolverAlgorithmType::DPMSolverPlusPlus,
solver_type: DPMSolverType::Midpoint,
lower_order_final: true,
}
}
}
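For illustration, the new shared config composes with struct-update syntax so only the fields of interest need to be spelled out. A small sketch, assuming the module is exposed as `diffusers::schedulers::dpmsolver`; the overridden values here are arbitrary:

```rust
use diffusers::schedulers::dpmsolver::{
    DPMSolverAlgorithmType, DPMSolverSchedulerConfig, DPMSolverType,
};

// Arbitrary example: third-order solver with the Heun corrector,
// every other field taken from the Default impl above.
let config = DPMSolverSchedulerConfig {
    solver_order: 3,
    solver_type: DPMSolverType::Heun,
    ..Default::default()
};
assert_eq!(config.algorithm_type, DPMSolverAlgorithmType::DPMSolverPlusPlus);
assert_eq!(config.train_timesteps, 1000);
```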
73 changes: 5 additions & 68 deletions src/schedulers/dpmsolver_multistep.rs
@@ -1,86 +1,23 @@
use super::{betas_for_alpha_bar, BetaSchedule, PredictionType};
use super::{betas_for_alpha_bar, BetaSchedule, PredictionType, dpmsolver::{DPMSolverSchedulerConfig, DPMSolverAlgorithmType, DPMSolverType}};
use std::iter;
use tch::{kind, Kind, Tensor};

/// The algorithm type for the solver.
///
#[derive(Default, Debug, Clone, PartialEq, Eq)]
pub enum DPMSolverAlgorithmType {
/// Implements the algorithms defined in <https://arxiv.org/abs/2211.01095>.
#[default]
DPMSolverPlusPlus,
/// Implements the algorithms defined in <https://arxiv.org/abs/2206.00927>.
DPMSolver,
}

/// The solver type for the second-order solver.
/// The solver type slightly affects the sample quality, especially for
/// small number of steps.
#[derive(Default, Debug, Clone, PartialEq, Eq)]
pub enum DPMSolverType {
#[default]
Midpoint,
Heun,
}

#[derive(Debug, Clone)]
pub struct DPMSolverMultistepSchedulerConfig {
/// The value of beta at the beginning of training.
pub beta_start: f64,
/// The value of beta at the end of training.
pub beta_end: f64,
/// How beta evolved during training.
pub beta_schedule: BetaSchedule,
/// number of diffusion steps used to train the model.
pub train_timesteps: usize,
/// the order of DPM-Solver; can be `1` or `2` or `3`. We recommend to use `solver_order=2` for guided
/// sampling, and `solver_order=3` for unconditional sampling.
pub solver_order: usize,
/// prediction type of the scheduler function
pub prediction_type: PredictionType,
/// The threshold value for dynamic thresholding. Valid only when `thresholding: true` and
/// `algorithm_type: DPMSolverAlgorithmType::DPMSolverPlusPlus`.
pub sample_max_value: f32,
/// The algorithm type for the solver
pub algorithm_type: DPMSolverAlgorithmType,
/// The solver type for the second-order solver.
pub solver_type: DPMSolverType,
/// Whether to use lower-order solvers in the final steps. Only valid for < 15 inference steps. We empirically
/// find this can stabilize the sampling of DPM-Solver for `steps < 15`, especially for steps <= 10.
pub lower_order_final: bool,
}

impl Default for DPMSolverMultistepSchedulerConfig {
fn default() -> Self {
Self {
beta_start: 0.00085,
beta_end: 0.012,
beta_schedule: BetaSchedule::ScaledLinear,
train_timesteps: 1000,
solver_order: 2,
prediction_type: PredictionType::Epsilon,
sample_max_value: 1.0,
algorithm_type: DPMSolverAlgorithmType::DPMSolverPlusPlus,
solver_type: DPMSolverType::Midpoint,
lower_order_final: true,
}
}
}

pub struct DPMSolverMultistepScheduler {
alphas_cumprod: Vec<f64>,
alpha_t: Vec<f64>,
sigma_t: Vec<f64>,
lambda_t: Vec<f64>,
init_noise_sigma: f64,
lower_order_nums: usize,
/// Direct outputs from learned diffusion model at current and latter timesteps
model_outputs: Vec<Tensor>,
/// List of current discrete timesteps in the diffusion chain
timesteps: Vec<usize>,
pub config: DPMSolverMultistepSchedulerConfig,
pub config: DPMSolverSchedulerConfig,
}

impl DPMSolverMultistepScheduler {
pub fn new(inference_steps: usize, config: DPMSolverMultistepSchedulerConfig) -> Self {
pub fn new(inference_steps: usize, config: DPMSolverSchedulerConfig) -> Self {
let betas = match config.beta_schedule {
BetaSchedule::ScaledLinear => Tensor::linspace(
config.beta_start.sqrt(),
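One practical note on the move to a shared config: its defaults (`beta_start: 0.0001`, `beta_end: 0.02`, `BetaSchedule::Linear`) differ from the removed `DPMSolverMultistepSchedulerConfig` defaults (`beta_start: 0.00085`, `beta_end: 0.012`, `BetaSchedule::ScaledLinear`). A call site that previously relied on `Default::default()` would have to set those fields explicitly to keep the old multistep behavior. A minimal sketch, with module paths assumed from this PR's layout:

```rust
use diffusers::schedulers::{
    dpmsolver::DPMSolverSchedulerConfig,
    dpmsolver_multistep::DPMSolverMultistepScheduler,
    BetaSchedule,
};

// Reproduce the old multistep defaults on top of the shared config.
let config = DPMSolverSchedulerConfig {
    beta_start: 0.00085,
    beta_end: 0.012,
    beta_schedule: BetaSchedule::ScaledLinear,
    ..Default::default()
};
let n_steps = 30; // number of inference steps, arbitrary here
let scheduler = DPMSolverMultistepScheduler::new(n_steps, config);
```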