Skip to content

Commit

Permalink
Added WIP SDE Solver demo
Browse files Browse the repository at this point in the history
  • Loading branch information
duvenaud committed Apr 13, 2021
1 parent c8a5296 commit 07c9578
Showing 1 changed file with 256 additions and 0 deletions.
256 changes: 256 additions & 0 deletions examples/sdeint.dx
Original file line number Diff line number Diff line change
@@ -0,0 +1,256 @@
' Adaptive-Step Stochastic Differential Equation Solvers

' Todo: switch to new while syntax.

import plot

interface HasStandardNormal a:Type
randNormal : Key -> a

instance HasStandardNormal Float32
randNormal = randn
instance [HasStandardNormal a] HasStandardNormal (n=>a)
randNormal = \key.
for i. randNormal (ixkey key i)

UnitInterval = Float
Time = Float
-- def scale (a: n=>Float) (b: n=>Float) : (n=>Float) = for i. a.i * b.i

-- Should these go in the prelude?
def linear_interp
(_: VSpace a) ?=> (z0: a) --o (z1: a) --o (t0: Time) (t1: Time) (t: Time) --o : a =
select (t1 == t0) z0 (z0 + ((t - t0) / (t1 - t0)) .* (z1 - z0))

def norm (x: d=>Float) : Float = sqrt $ sum for i. sq x.i
def (./) (x: d=>Float) (y: d=>Float) : d=>Float = for i. x.i / y.i

def prepend (first: v) (seq: m=>v) : ({head:Unit | tail:m }=>v) =
-- Concatenates a single element to the beginning of a sequence.
for idx. case idx of
{| head = () |} -> first
{| tail = i |} -> seq.i

def bbIter (_:VSpace v) ?=> (_: HasStandardNormal v) ?=>
((key, y, sigma, t):(Key & v & Float & UnitInterval)) :
(Key & v & Float & UnitInterval) =
-- Descend one step in a virtual Brownian tree.
-- Algorithm from "Scalable Gradients for Stochastic Differential Equations"
-- by Xuechen Li, Ting-Kam Leonard Wong, Ricky T. Q. Chen, David Duvenaud
-- AISTATS 2020, https://arxiv.org/abs/2001.01328
--\(key, y, sigma, t).
[kDraw, kL, kR] = splitKey key
t' = abs (t - 0.5)
y' = sigma * (0.5 - t') .* randNormal kDraw
key' = select (t > 0.5) kL kR
(key', y + y', sigma / sqrt 2.0, t' * 2.0)

def sampleUnitBM (_:VSpace v) ?=> (_: HasStandardNormal v) ?=>
(key:Key) (t:UnitInterval) : v =
-- Brownian motion in interval (0.0, 1.0), where y(0.0) = 0.0
(_, y, _, _) = fold (key, zero, 1.0, t) \i:(Fin 14). bbIter
y

def scalefunc (_:VSpace v) ?=> (f:Time -> v) (t0:Time) (t1:Time) : (Time -> v) =
\t. (sqrt (t1 - t0)) .* (f ((t - t0) / (t1 - t0)))

def sampleBM (_:VSpace v) ?=> (_: HasStandardNormal v) ?=>
(key:Key) (t0:Time) (t1:Time) (t:Time) : v =
curriedBM = sampleUnitBM key
sf = scalefunc curriedBM t0 t1
sf t

xs = linspace (Fin 1000) 0.0 1.0
ys:(Fin 1000)=>Float = for i. ((sampleUnitBM (newKey 0)) xs.i)
--:plot zip xs ys
:html showPlot $ xyPlot xs ys


--data DiffusionFunc v:Type =
-- HasProd (v->v->Time->v)
-- ManualProd (v->Time->v)


data GenDiffusionProd v:Type =
HasProd (v->v->Time->v)
ManualProd (v->Time->v)

--def evaldiffprod (f:GenDiffusionProd v) (state:v) (time:Time) (noise:v) : v =
-- case v of

-- Drift and diffusion product.
-- Diffusion takes state and noise and returns dstate/dtime
def Drift (v:Type) : Type = v->Time->v
def Diffusion (v:Type) : Type = v->Time->v
def DiffusionProd (v:Type) : Type = v->v->Time->v
def SDE (v:Type) : Type = (Drift v & DiffusionProd v )




def radonNikodym (drift1: Drift Float)
(drift2: Drift Float)
(diffusion: Diffusion Float)
(t0: Time) (t1:Time) : Drift Float =
-- Dynamics of simple Monte Carlo estimatr of KL divergence between
-- two SDEs that share a diffusion function.
-- Todo: generalize to multivariate.
\state t.
(sq ((drift1 state t) - (drift2 state t))) / (diffusion state t)

def ito_euler_step (_:VSpace v) ?=> (sde: SDE v) (z:v)
(t:Time) (dt:Time) (noise:v) : (v & Time) =
(drift, diffprod) = sde
new_z = z + dt .* (drift z t) + (diffprod z noise t)
(new_z, t + dt)

def sdeint (sde: SDE (d=>Float))
(z0: d=>Float) (t0: Time) (times: n=>Time)
(key: Key) : n=>d=>Float =
-- z0: the initial value for the state.
-- t: times for evaluation. values must be strictly increasing.
-- Returns:
-- Values of the solution at each time point in times.
dt = 0.0001
max_iters = 10000

lasttime = maximum (prepend t0 times)
-- Todo: think about how to extend the BM beyond this interval later.
noisefunc = \t. sampleBM key t0 lasttime t

integrate_to_next_time = \iter init_carry.
target_t = times.iter

stopping_condition = \(_, _, _, t, dt).
(t < target_t) && (dt > 0.0) && (ordinal iter < max_iters)

step = \(old_z, old_t, z, t, dt).
dB = (noisefunc (t + dt)) - (noisefunc t)
(new_z, new_t) = ito_euler_step sde z t dt dB
(z, t, new_z, new_t, dt)

-- Take steps until we pass target_t
new_state = snd $ withState init_carry \state.
while (\(). stopping_condition (get state)) \().
state := step (get state)
(old_z, old_t, cur_z, cur_t, _) = new_state

-- Interpolate to the target time.
z_target = linear_interp old_z cur_z old_t cur_t target_t
(new_state, z_target)

init_carry = (z0, t0, z0, t0, dt)
snd $ scan init_carry integrate_to_next_time


def clamp_min (eps:Float) (x:Float) : Float = max x eps

def error_ratio (z_full:d=>Float) (z_half:d=>Float) (rtol:Float) (atol:Float) : Float =
-- z_full obtained with one full step.
-- z_half obtained with two half steps.
eps = 1.0e-7
tol = for i. clamp_min eps $ rtol * (max (abs z_full.i) (abs z_half.i)) + atol
clamp_min eps $ norm $ (z_full - z_half) ./ tol

def step_size (error_estimate:Float)
(prev_step_size:Time)
(maybe_prev_error_ratio: Maybe Float) : (Float & Maybe Float) =
safety=0.9
facmin=0.2
facmax=1.4
(pfactor, ifactor) = case error_estimate > 1.0 of
True -> (0.0, 1.0 / 1.5)
False -> (0.13, 1.0 / 4.5)

error_ratio = safety / error_estimate
prev_error_ratio = case maybe_prev_error_ratio of
Nothing -> error_ratio
Just prev_error_ratio -> prev_error_ratio

factor = pow error_ratio $ ifactor * ( pow (error_ratio / prev_error_ratio) pfactor)

(prev_error_ratio, facmin) = case error_estimate <= 1.0 of
True -> (error_ratio, 1.0)
False -> (prev_error_ratio, facmin)

factor = min facmax (max facmin factor)
new_step_size = prev_step_size * factor
(new_step_size, Just prev_error_ratio)

def adaptive_sdeint (sde: SDE (d=>Float))
(z0: d=>Float) (t0: Time) (times: n=>Time)
(key: Key) : n=>d=>Float =
-- z0: the initial value for the state.
-- t: times for evaluation. values must be strictly increasing.
-- Returns:
-- Values of the solution at each time point in times.
rtol = 0.0001 --1.4e-7 -- relative local error tolerance for solver.
atol = 0.00001 --1.4e-7 -- absolute local error tolerance for solver.
max_iters = 10000
init_dt = 1.0e-3

lasttime = maximum (prepend t0 times)
-- Todo: think about how to extend the BM beyond this interval later.
noisefunc = \t. sampleBM key t0 lasttime t

integrate_to_next_time = \iter init_carry.
target_t = times.iter

stopping_condition = \(_, _, _, t, dt, _).
(t < target_t) && (dt > 0.0) && (ordinal iter < max_iters)

step = \(old_z, old_t, z, t, dt).
dB = (noisefunc (t + dt)) - (noisefunc t)
(new_z, new_t) = ito_euler_step sde z t dt dB
(z, t, new_z, new_t, dt)

possibly_step = \(old_z, old_t, z, t, dt, maybe_prev_error_ratio).
-- Take 1 full step.
(_, _, new_z_full, new_t, _) = step (old_z, old_t, z, t, dt)
-- Take 2 half steps.
(_, _, new_z_half, new_t_half, _) = step (old_z, old_t, z, t, dt / 2.0)
(_, _, new_z_2xhf, _, _) = step (old_z, old_t, new_z_half, new_t_half, dt / 2.0)

ratio = error_ratio new_z_full new_z_2xhf rtol atol

(new_dt, maybe_prev_error_ratio) = step_size ratio dt maybe_prev_error_ratio

move_state = (z, t, new_z_2xhf, new_t, new_dt, maybe_prev_error_ratio)
stay_state = (old_z, old_t, z, t, new_dt, maybe_prev_error_ratio)
select (ratio <= 1.0) move_state stay_state

-- Take steps until we pass target_t
new_state = snd $ withState init_carry \state.
while (\(). stopping_condition (get state)) \().
state := possibly_step (get state)
(old_z, old_t, cur_z, cur_t, _, _) = new_state

-- Interpolate to the target time.
z_target = linear_interp old_z cur_z old_t cur_t target_t
(new_state, z_target)

init_carry = (z0, t0, z0, t0, init_dt, Nothing)
snd $ scan init_carry integrate_to_next_time



drift : v=>Float -> Time -> v=>Float = \z t. -z

def diffprod (z:v=>Float) (noise:v=>Float) (t:Time) : v=>Float = for i. noise.i

--sde = (drift, diffprod)

z0 = [1.0]
t0 = 0.1
times = linspace (Fin 1000) 0.2 1.9
key = newKey 0

yout = sdeint (drift, diffprod) z0 t0 times key
yout' = adaptive_sdeint (drift, diffprod) z0 t0 times key

--:p yout - yout'


:html showPlot $ xyPlot times for i. yout.i.(0@(Fin 1))
:html showPlot $ xyPlot times for i. yout'.i.(0@(Fin 1))

0 comments on commit 07c9578

Please sign in to comment.