diff --git a/src/train.rs b/src/train.rs
index 9b546f06dc..cf77a28c1e 100644
--- a/src/train.rs
+++ b/src/train.rs
@@ -269,6 +269,131 @@ impl Optimizer for AdadeltaOptimizer {
     }
 }
 
+/// Optimizer that implements the Adam algorithm.
+#[derive(Debug)]
+pub struct AdamOptimizer {
+    learning_rate: Option<Output>,
+    beta1: Option<Output>,
+    beta2: Option<Output>,
+    beta1_power: Option<Output>,
+    beta2_power: Option<Output>,
+    epsilon: Option<Output>,
+    local_step: Option<Variable>,
+    assign_add_local_step: Option<Operation>,
+}
+
+impl Default for AdamOptimizer {
+    fn default() -> Self {
+        Self::new()
+    }
+}
+
+impl AdamOptimizer {
+    /// Creates a new optimizer with default parameters (learning_rate=0.001, beta1=0.9, beta2=0.999, epsilon=1e-8).
+    pub fn new() -> Self {
+        Self {
+            learning_rate: None,
+            beta1: None,
+            beta2: None,
+            beta1_power: None,
+            beta2_power: None,
+            epsilon: None,
+            local_step: None,
+            assign_add_local_step: None,
+        }
+    }
+
+    /// Sets the learning rate. Default is 0.001.
+    pub fn set_learning_rate<T: Into<Output>>(&mut self, learning_rate: T) {
+        self.learning_rate = Some(learning_rate.into());
+    }
+
+    /// Sets beta1, the exponential decay rate for the 1st moment estimates. Default is 0.9.
+    pub fn set_beta1<T: Into<Output>>(&mut self, beta1: T) {
+        self.beta1 = Some(beta1.into());
+    }
+
+    /// Sets beta2, the exponential decay rate for the 2nd moment estimates. Default is 0.999.
+    pub fn set_beta2<T: Into<Output>>(&mut self, beta2: T) {
+        self.beta2 = Some(beta2.into());
+    }
+
+    /// Sets epsilon, a small constant for numerical stability. Default is 1e-8.
+    pub fn set_epsilon<T: Into<Output>>(&mut self, epsilon: T) {
+        self.epsilon = Some(epsilon.into());
+    }
+
+    /// Sets beta1_power and beta2_power.
+    pub fn set_calculated_parameters(&mut self, scope: &mut Scope) -> Result<()> {
+        self.local_step = Some(Variable::builder()
+            .const_initial_value(1.0f32)
+            .data_type(DataType::Float)
+            .build(&mut scope.with_op_name("local_step"))?);
+
+        if let Some(local_step) = &self.local_step {
+            self.assign_add_local_step = Some(ops::assign_add(local_step.output().clone(), ops::constant(1.0f32, scope)?, &mut scope.with_op_name("add_local_step"))?);
+
+            let beta1 = or_constant(scope, &self.beta1, 0.9f32)?;
+            let beta2 = or_constant(scope, &self.beta2, 0.999f32)?;
+
+            self.beta1_power = Some(ops::pow(beta1, local_step.output().clone(), scope)?.into());
+            self.beta2_power = Some(ops::pow(beta2, local_step.output().clone(), scope)?.into());
+        }
+
+        Ok(())
+    }
+}
+
+impl Optimizer for AdamOptimizer {
+    fn apply_gradients(
+        &self,
+        scope: &mut Scope,
+        opts: ApplyGradientsOptions,
+    ) -> Result<(Vec<Variable>, Operation)> {
+        let learning_rate = or_constant(scope, &self.learning_rate, 0.001f32)?;
+        let beta1 = or_constant(scope, &self.beta1, 0.9f32)?;
+        let beta2 = or_constant(scope, &self.beta2, 0.999f32)?;
+        let beta1_power = or_constant(scope, &self.beta1_power, 0.0f32)?;
+        let beta2_power = or_constant(scope, &self.beta2_power, 0.0f32)?;
+        let epsilon = or_constant(scope, &self.epsilon, 1e-8f32)?;
+        let mut apply_ops = Vec::new();
+        let mut variables = Vec::new();
+        for (grad, var) in opts.grads_and_vars {
+            if let Some(grad) = grad {
+                let mut scope = scope.new_sub_scope(&var.name);
+                let m = create_zeros_slot(&mut scope.new_sub_scope("m"), var, None)?;
+                let v = create_zeros_slot(&mut scope.new_sub_scope("v"), var, None)?;
+                apply_ops.push(ops::apply_adam(
+                    var.output.clone(),
+                    m.output.clone(),
+                    v.output.clone(),
+                    beta1_power.clone(),
+                    beta2_power.clone(),
+                    learning_rate.clone(),
+                    beta1.clone(),
+                    beta2.clone(),
+                    epsilon.clone(),
+                    grad.clone(),
+                    &mut scope,
+                )?);
+                variables.push(m.clone());
+                variables.push(v.clone());
+            }
+        }
+        if let Some(local_step) = &self.local_step {
+            variables.push(local_step.clone());
+        }
+        if let Some(assign_add_local_step) = &self.assign_add_local_step {
+            apply_ops.push(assign_add_local_step.clone());
+        }
+        let mut no_op = ops::NoOp::new();
+        for apply_op in &apply_ops {
+            no_op = no_op.add_control_input(apply_op.clone());
+        }
+        Ok((variables, no_op.build(scope)?))
+    }
+}
+
 #[cfg(test)]
 mod tests {
     use super::*;
@@ -401,6 +526,69 @@ mod tests {
         );
     }
 
+    #[test]
+    fn simple_adam() {
+        let mut scope = Scope::new_root_scope();
+        let x_var = Variable::builder()
+            .const_initial_value(3.0f32)
+            .build(&mut scope.with_op_name("x"))
+            .unwrap();
+        let x_squared = ops::mul(x_var.output.clone(), x_var.output.clone(), &mut scope).unwrap();
+        let mut optimizer = AdamOptimizer::new();
+        optimizer.set_learning_rate(ops::constant(0.00001f32, &mut scope).unwrap());
+        optimizer.set_beta1(ops::constant(0.9f32, &mut scope).unwrap());
+        optimizer.set_beta2(ops::constant(0.999f32, &mut scope).unwrap());
+        optimizer.set_epsilon(ops::constant(1e-8f32, &mut scope).unwrap());
+        optimizer.set_calculated_parameters(&mut scope).unwrap();
+        let (minimizer_vars, minimize) = optimizer
+            .minimize(
+                &mut scope,
+                x_squared.into(),
+                MinimizeOptions::default().with_variables(&[x_var.clone()]),
+            )
+            .unwrap();
+        let options = SessionOptions::new();
+        let session = Session::new(&options, &scope.graph()).unwrap();
+
+        let mut run_args = SessionRunArgs::new();
+        run_args.add_target(&x_var.initializer);
+        for var in &minimizer_vars {
+            run_args.add_target(&var.initializer);
+        }
+        session.run(&mut run_args).unwrap();
+
+        let mut run_args = SessionRunArgs::new();
+        run_args.add_target(&minimize);
+        let x_fetch = run_args.request_fetch(&x_var.output.operation, 0);
+
+        session.run(&mut run_args).unwrap();
+        let x_output = run_args.fetch::<f32>(x_fetch).unwrap();
+        assert_eq!(x_output.len(), 1);
+        assert!(
+            x_output[0] >= 2.99992 && x_output[0] <= 2.999999,
+            "x_output[0] = {}",
+            x_output[0]
+        );
+
+        session.run(&mut run_args).unwrap();
+        let x_output = run_args.fetch::<f32>(x_fetch).unwrap();
+        assert_eq!(x_output.len(), 1);
+        assert!(
+            x_output[0] >= 2.99992 && x_output[0] <= 2.999999,
+            "x_output[0] = {}",
+            x_output[0]
+        );
+
+        session.run(&mut run_args).unwrap();
+        let x_output = run_args.fetch::<f32>(x_fetch).unwrap();
+        assert_eq!(x_output.len(), 1);
+        assert!(
+            x_output[0] >= 2.99992 && x_output[0] <= 2.999999,
+            "x_output[0] = {}",
+            x_output[0]
+        );
+    }
+
     #[test]
     fn xor_nn() {
         let mut scope = Scope::new_root_scope();
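Note for reviewers: each `ops::apply_adam` op wired up in `apply_gradients` invokes TensorFlow's `ApplyAdam` kernel, which per the TensorFlow op documentation performs the standard Adam update per variable (this summary is not part of the patch itself):

    lr_t     := learning_rate * sqrt(1 - beta2_power) / (1 - beta1_power)
    m        := beta1 * m + (1 - beta1) * grad
    v        := beta2 * v + (1 - beta2) * grad * grad
    variable := variable - lr_t * m / (sqrt(v) + epsilon)

Here `m` and `v` are the per-variable slot variables created by `create_zeros_slot`. One consequence of the defaults in `apply_gradients`: if `set_calculated_parameters` is never called, `beta1_power` and `beta2_power` fall back to 0.0, so the bias-correction factor sqrt(1 - beta2_power) / (1 - beta1_power) stays fixed at 1 instead of tracking `local_step`.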