Skip to content

Commit

Permalink
Updated sdeint demo to new syntax.
Browse files Browse the repository at this point in the history
  • Loading branch information
duvenaud committed Apr 26, 2022
1 parent e89b099 commit 1f1b207
Showing 1 changed file with 38 additions and 62 deletions.
100 changes: 38 additions & 62 deletions examples/sdeint.dx
Original file line number Diff line number Diff line change
@@ -1,26 +1,3 @@
instance [Add a] Add (n->a)
add = \f g. \x. (f x) + (g x)
sub = \f g. \x. (f x) - (g x)
zero = \x. zero

instance [Mul a] Mul (n->a)
mul = \f g. \x. (f x) * (g x)
one = \x. one

instance [VSpace a] VSpace (n->a)
scaleVec = \s f. \x. s .* (f x)

--instance [VSpace a] VSpace (n->m->o->a)
-- scaleVec = \s f. \x y z. s .* (f x y z)

instance [Arbitrary a, Hashable n] Arbitrary (n->a)
arb = \key. \i. arb $ hash key i

:p (sin + cos) 5.0
:p (sin * cos) 5.0
:p (2.8 .* sin) 5.0


' Adaptive-Step Stochastic Differential Equation Solvers

import plot
Expand All @@ -30,60 +7,59 @@ interface HasStandardNormal a:Type

instance HasStandardNormal Float32
randNormal = randn
instance [HasStandardNormal a] HasStandardNormal (n=>a)
instance {a n} [HasStandardNormal a] HasStandardNormal (n=>a)
randNormal = \key.
for i. randNormal (ixkey key i)

UnitInterval = Float
Time = Float
-- def scale (a: n=>Float) (b: n=>Float) : (n=>Float) = for i. a.i * b.i

-- Should these go in the prelude?
def linear_interp [VSpace a]
(z0: a) --o (z1: a) --o (t0: Time) (t1: Time) (t: Time) --o : a =
def linear_interp {a} [VSpace a]
(z0: a) (z1: a) (t0: Time) (t1: Time) (t: Time) : a =
select (t1 == t0) z0 (z0 + ((t - t0) / (t1 - t0)) .* (z1 - z0))

def norm (x: d=>Float) : Float = sqrt $ sum for i. sq x.i
def (./) (x: d=>Float) (y: d=>Float) : d=>Float = for i. x.i / y.i
def norm {d} (x: d=>Float) : Float = sqrt $ sum for i. sq x.i
def (./) {d} (x: d=>Float) (y: d=>Float) : d=>Float = for i. x.i / y.i

def prepend (first: v) (seq: m=>v) : ({head:Unit | tail:m }=>v) =
def prepend {v m} (first: v) (seq: m=>v) : (Unit | m )=>v =
-- Concatenates a single element to the beginning of a sequence.
for idx. case idx of
{| head = () |} -> first
{| tail = i |} -> seq.i
Left _ -> first
Right i -> seq.i

def bbIter [VSpace v, HasStandardNormal v]
def bbIter {v} [VSpace v, HasStandardNormal v]
((key, y, sigma, t):(Key & v & Float & UnitInterval)) :
(Key & v & Float & UnitInterval) =
-- Descend one step in a virtual Brownian tree.
-- Algorithm from "Scalable Gradients for Stochastic Differential Equations"
-- by Xuechen Li, Ting-Kam Leonard Wong, Ricky T. Q. Chen, David Duvenaud
-- AISTATS 2020, https://arxiv.org/abs/2001.01328
--\(key, y, sigma, t).
[kDraw, key_left, key_right] = splitKey key
[kDraw, key_left, key_right] = split_key key
t' = abs (t - 0.5)
y' = sigma * (0.5 - t') .* randNormal kDraw
key' = select (t > 0.5) key_left key_right
(key', y + y', sigma / sqrt 2.0, t' * 2.0)

def sampleUnitBM [VSpace v, HasStandardNormal v]
def sampleUnitBM {v} [VSpace v, HasStandardNormal v]
(key:Key) (t:UnitInterval) : v =
-- Brownian motion in interval (0.0, 1.0), where y(0.0) = 0.0
(_, y, _, _) = fold (key, zero, 1.0, t) \i:(Fin 14). bbIter
y

def scaleBM [VSpace v] (f:Time -> v) (t0:Time) (t1:Time) (t:Time) : v =
def scaleBM {v} [VSpace v] (f:Time -> v) (t0:Time) (t1:Time) (t:Time) : v =
(sqrt (t1 - t0)) .* (f ((t - t0) / (t1 - t0)))

def sampleBM [VSpace v, HasStandardNormal v]
def sampleBM {v} [VSpace v, HasStandardNormal v]
(key:Key) (t0:Time) (t1:Time) (t:Time) : v =
curriedBM = sampleUnitBM key
scaleBM curriedBM t0 t1 t

xs = linspace (Fin 1000) 0.0 1.0
ys:(Fin 1000)=>Float = for i. ((sampleUnitBM (newKey 0)) xs.i)
ys:(Fin 1000)=>Float = for i. ((sampleUnitBM (new_key 0)) xs.i)
--:plot zip xs ys
:html showPlot $ xyPlot xs ys
:html show_plot $ xy_plot xs ys


--data DiffusionFunc v:Type =
Expand All @@ -100,18 +76,18 @@ data GenDiffusionProd v:Type =

-- Drift and diffusion product.
-- DiffusionProd takes state and noise and returns dstate/dtime
def Drift (v:Type) : Type = v->Time->v
def Drift (v:Type) : Type = v->Time->v
def DiagDiffusion (v:Type) : Type = v->Time->v
def DiffusionProd (v:Type) : Type = v->Time->v->v
def SDE (v:Type) : Type = (Drift v & DiffusionProd v )
def SDE (v:Type) : Type = (Drift v & DiffusionProd v )

-- Idea: use typeclasses to automatically derive products?

def SkewSymmetricProd (v:Type) : Type = v->Time->v->v -- How to express matrix?
def NegEnergyFunc (v:Type) : Type = v->Time->Float
def StationarySDE (v:Type) : Type = (NegEnergyFunc v & SkewSymmetricProd v & DiffusionProd v)

def stationarySDEPartsToSDE [Mul v, VSpace v] (sta:StationarySDE v) : SDE v =
def stationarySDEPartsToSDE {v} [Mul v, VSpace v] (sta:StationarySDE v) : SDE v =
-- From Section 2.1 of "A Complete Recipe for Stochastic Gradient MCMC"
-- https://arxiv.org/pdf/1506.04696.pdf
(negEnergyFunc, skewSymmetricProd, diffusionProd) = sta
Expand All @@ -135,24 +111,24 @@ def radonNikodym (drift1: Drift Float)
-- Todo: generalize to multivariate.
(sq ((drift1 state t) - (drift2 state t))) / (diffusion state t)

def ito_euler_step [VSpace v] (sde: SDE v) (z:v)
def ito_euler_step {v} [VSpace v] (sde: SDE v) (z:v)
(t:Time) (dt:Time) (noise:v) : (v & Time) =
(drift, diffprod) = sde
new_z = z + dt .* (drift z t) + (diffprod z t noise)
(new_z, t + dt)

def ito_euler_step(y, t, args, noise, t_delta, f, g_prod, gdg_prod=None):
return y + t_delta * f(y, t, args) + g_prod(y, t, args, noise)
--def ito_milstein_step(y, t, args, noise, t_delta, f, g_prod, gdg_prod):
-- # Equation 20 from https://infoscience.epfl.ch/record/143450/files/sde_tutorial.pdf
-- return y + t_delta * f(y, t, args) + g_prod(y, t, args, noise) \
-- + 0.5 * gdg_prod(y, t, args, noise**2 - t_delta)

def ito_milstein_step(y, t, args, noise, t_delta, f, g_prod, gdg_prod):
# Equation 20 from https://infoscience.epfl.ch/record/143450/files/sde_tutorial.pdf
return y + t_delta * f(y, t, args) + g_prod(y, t, args, noise) \
+ 0.5 * gdg_prod(y, t, args, noise**2 - t_delta)

def ito_milstein_step [VSpace v] (sde: SDE v) (z:v)
def ito_milstein_step {v} [VSpace v] (sde: SDE v) (z:v)
(t:Time) (dt:Time) (noise:v) : (v & Time) =
-- Equation 20 from https://infoscience.epfl.ch/record/143450/files/sde_tutorial.pdf
(drift, diffprod) = sde
new_z = z + dt .* (drift z t) + (diffprod z t noise)
--y + t_delta * f(y, t, args) + g_prod(y, t, args, noise) \
-- + 0.5 * gdg_prod(y, t, args, noise**2 - t_delta)
new_z = z + dt .* (drift z t) + 0.5 .* (diffprod z t (sq noise))
(new_z, t + dt)


Expand All @@ -164,7 +140,7 @@ def strat_milstein_step(y, t, args, noise, t_delta, f, g_prod, gdg_prod):



def sdeint (sde: SDE (d=>Float))
def sdeint {d n} (sde: SDE (d=>Float))
(initial_state: d=>Float) (t0: Time) (eval_times: n=>Time)
(key: Key) : n=>d=>Float =
-- eval_times must be strictly increasing. Todo: enforce / automate.
Expand All @@ -189,7 +165,7 @@ def sdeint (sde: SDE (d=>Float))
(z, t, new_z, new_t, dt)

-- Take steps until we pass target_t
new_state = yieldState init_carry \state.
new_state = yield_state init_carry \state.
if shouldContinue (get state) then
while do
state := step (get state)
Expand All @@ -203,7 +179,7 @@ def sdeint (sde: SDE (d=>Float))
init_carry = (initial_state, t0, initial_state, t0, dt)
snd $ scan init_carry integrate_to_next_time

def error_ratio (z_full:d=>Float) (z_half:d=>Float) (rtol:Float) (atol:Float) : Float =
def error_ratio {d} (z_full:d=>Float) (z_half:d=>Float) (rtol:Float) (atol:Float) : Float =
-- z_full obtained with one full step.
-- z_half obtained with two half steps.
eps = 1.0e-7
Expand Down Expand Up @@ -235,7 +211,7 @@ def step_size (error_estimate:Float)
new_step_size = prev_step_size * factor
(new_step_size, Just prev_error_ratio)

def adaptive_sdeint (sde: SDE (d=>Float))
def adaptive_sdeint {d n} (sde: SDE (d=>Float))
(z0: d=>Float) (t0: Time) (times: n=>Time)
(key: Key) : n=>d=>Float =
-- z0: the initial value for the state.
Expand Down Expand Up @@ -278,7 +254,7 @@ def adaptive_sdeint (sde: SDE (d=>Float))
select (ratio <= 1.0) move_state stay_state

-- Take steps until we pass target_t
new_state = yieldState init_carry \state.
new_state = yield_state init_carry \state.
if shouldContinue (get state) then
while do
state := possibly_step (get state)
Expand All @@ -294,16 +270,16 @@ def adaptive_sdeint (sde: SDE (d=>Float))



drift : v=>Float -> Time -> v=>Float = \z t. -z
def drift {v} (z: v=>Float) (t:Time): v=>Float = -z

def diffprod (z:v=>Float) (t:Time) (noise:v=>Float): v=>Float = for i. noise.i
def diffprod {v} (z:v=>Float) (t:Time) (noise:v=>Float): v=>Float = for i. noise.i

--sde = (drift, diffprod)

z0 = [1.0]
t0 = 0.1
times = linspace (Fin 1000) 0.2 1.9
key = newKey 0
key = new_key 0

%time
yout = sdeint (drift, diffprod) z0 t0 times key
Expand All @@ -314,7 +290,7 @@ yout' = adaptive_sdeint (drift, diffprod) z0 t0 times key
--:p yout - yout'

%time
:html showPlot $ xyPlot times for i. yout.i.(0@(Fin 1))
:html show_plot $ xy_plot times for i. yout.i.(0@(Fin 1))

%time
:html showPlot $ xyPlot times for i. yout'.i.(0@(Fin 1))
:html show_plot $ xy_plot times for i. yout'.i.(0@(Fin 1))

0 comments on commit 1f1b207

Please sign in to comment.