-
Notifications
You must be signed in to change notification settings - Fork 109
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
1 changed file
with
256 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,256 @@ | ||
' Adaptive-Step Stochastic Differential Equation Solvers | ||
|
||
' Todo: switch to new while syntax. | ||
|
||
import plot | ||
|
||
-- Types for which a sample can be produced from a PRNG key via `randNormal`. | ||
interface HasStandardNormal a:Type | ||
randNormal : Key -> a | ||
|
||
-- Scalar case: `randNormal` is simply `randn`. | ||
instance HasStandardNormal Float32 | ||
randNormal = randn | ||
-- Table case: element i is drawn with the per-index key `ixkey key i`, | ||
-- so any table whose element type has an instance gets one too. | ||
instance [HasStandardNormal a] HasStandardNormal (n=>a) | ||
randNormal = \key. | ||
for i. randNormal (ixkey key i) | ||
|
||
-- Both aliases are plain Float. UnitInterval is the type of the time argument | ||
-- of bbIter and sampleUnitBM (presumably meant to lie in [0, 1]; not enforced). | ||
UnitInterval = Float | ||
Time = Float | ||
-- def scale (a: n=>Float) (b: n=>Float) : (n=>Float) = for i. a.i * b.i | ||
|
||
-- Should these go in the prelude? | ||
-- Linear interpolation between z0 (the value at time t0) and z1 (the value at | ||
-- time t1), evaluated at time t. When t1 == t0 the result is z0. | ||
def linear_interp | ||
(_: VSpace a) ?=> (z0: a) --o (z1: a) --o (t0: Time) (t1: Time) (t: Time) --o : a = | ||
select (t1 == t0) z0 (z0 + ((t - t0) / (t1 - t0)) .* (z1 - z0)) | ||
|
||
-- Square root of the sum of squared entries of x.
def norm (x: d=>Float) : Float =
  sum_of_squares = sum for i. x.i * x.i
  sqrt sum_of_squares

-- Element-wise division of two tables that share an index set.
def (./) (x: d=>Float) (y: d=>Float) : d=>Float =
  for i. x.i / y.i
|
||
def prepend (first: v) (seq: m=>v) : ({head:Unit | tail:m }=>v) = | ||
-- Concatenates a single element to the beginning of a sequence. | ||
-- The result is indexed by a variant type: the `head` index maps to `first`, | ||
-- and each `tail = i` index maps to `seq.i`. | ||
for idx. case idx of | ||
{| head = () |} -> first | ||
{| tail = i |} -> seq.i | ||
|
||
def bbIter (_:VSpace v) ?=> (_: HasStandardNormal v) ?=> | ||
((key, y, sigma, t):(Key & v & Float & UnitInterval)) : | ||
(Key & v & Float & UnitInterval) = | ||
-- Descend one step in a virtual Brownian tree. | ||
-- Algorithm from "Scalable Gradients for Stochastic Differential Equations" | ||
-- by Xuechen Li, Ting-Kam Leonard Wong, Ricky T. Q. Chen, David Duvenaud | ||
-- AISTATS 2020, https://arxiv.org/abs/2001.01328 | ||
-- | ||
-- State is (key, y, sigma, t): the PRNG key for this node, the value | ||
-- accumulated so far, the current noise scale, and the query time. | ||
-- Returns the state for the next level down. | ||
--\(key, y, sigma, t). | ||
-- One key for this node's draw, one for each child. | ||
[kDraw, kL, kR] = splitKey key | ||
-- Distance of t from the midpoint 0.5. | ||
t' = abs (t - 0.5) | ||
-- Increment is scaled by (0.5 - t'), which is zero when t is 0.0 or 1.0. | ||
y' = sigma * (0.5 - t') .* randNormal kDraw | ||
-- Child key is chosen by which side of 0.5 the query time falls on. | ||
key' = select (t > 0.5) kL kR | ||
-- Next level: scale shrinks by sqrt 2 and the time becomes t' * 2.0. | ||
(key', y + y', sigma / sqrt 2.0, t' * 2.0) | ||
|
||
def sampleUnitBM (_:VSpace v) ?=> (_: HasStandardNormal v) ?=>
                 (key:Key) (t:UnitInterval) : v =
  -- Brownian motion on the unit interval with y(0.0) = 0.0, evaluated at time t.
  -- All randomness is derived from `key`, so calls that share a key query one
  -- consistent sample path.
  --
  -- bbIter adds nothing at t = 0.0 and at t = 1.0 (its increment is scaled by
  -- 0.5 - abs (t - 0.5)), so the tree on its own is pinned to zero at both
  -- ends, i.e. it is a Brownian bridge. The endpoint value W(1) is therefore
  -- drawn from a separate key and blended in linearly:
  --   W(t) = t * W(1) + bridge(t)
  [kEnd, kTree] = splitKey key
  wEnd = randNormal kEnd
  -- Descend 14 levels of the tree, starting from the linear endpoint term.
  (_, y, _, _) = fold (kTree, t .* wEnd, 1.0, t) \i:(Fin 14). bbIter
  y
|
||
def scalefunc (_:VSpace v) ?=> (f:Time -> v) (t0:Time) (t1:Time) : (Time -> v) =
  -- Wraps f so that it is queried on (t0, t1): the query time is mapped to
  -- (t - t0) / (t1 - t0) and the result is multiplied by sqrt (t1 - t0).
  span = t1 - t0
  amplitude = sqrt span
  \t. amplitude .* (f ((t - t0) / span))
|
||
def sampleBM (_:VSpace v) ?=> (_: HasStandardNormal v) ?=>
             (key:Key) (t0:Time) (t1:Time) (t:Time) : v =
  -- Evaluates at time t the path obtained by rescaling sampleUnitBM (for this
  -- key) from the unit interval to (t0, t1) via scalefunc.
  unit_path = \s. sampleUnitBM key s
  (scalefunc unit_path t0 t1) t
|
||
-- Demo: evaluate the path for key 0 at the 1000 points of `linspace ... 0.0 1.0` and plot it. | ||
xs = linspace (Fin 1000) 0.0 1.0 | ||
ys:(Fin 1000)=>Float = for i. ((sampleUnitBM (newKey 0)) xs.i) | ||
--:plot zip xs ys | ||
:html showPlot $ xyPlot xs ys | ||
|
||
|
||
--data DiffusionFunc v:Type = | ||
-- HasProd (v->v->Time->v) | ||
-- ManualProd (v->Time->v) | ||
|
||
|
||
-- Two alternative representations of a diffusion term: | ||
--   HasProd wraps a function of type v->v->Time->v, | ||
--   ManualProd wraps a function of type v->Time->v. | ||
-- Not referenced by the solvers shown in this file. | ||
data GenDiffusionProd v:Type = | ||
HasProd (v->v->Time->v) | ||
ManualProd (v->Time->v) | ||
|
||
--def evaldiffprod (f:GenDiffusionProd v) (state:v) (time:Time) (noise:v) : v = | ||
-- case v of | ||
|
||
-- Drift and diffusion product. | ||
-- Diffusion takes state and noise and returns dstate/dtime | ||
-- Drift: state -> time -> dstate/dtime. | ||
def Drift (v:Type) : Type = v->Time->v | ||
-- Diffusion has the same function type as Drift. | ||
def Diffusion (v:Type) : Type = v->Time->v | ||
-- DiffusionProd is called as `diffprod z noise t` (state, noise, time) by ito_euler_step. | ||
def DiffusionProd (v:Type) : Type = v->v->Time->v | ||
-- An SDE is the pair (drift, diffusion product). | ||
def SDE (v:Type) : Type = (Drift v & DiffusionProd v ) | ||
|
||
|
||
|
||
|
||
def radonNikodym (drift1: Drift Float) | ||
(drift2: Drift Float) | ||
(diffusion: Diffusion Float) | ||
(t0: Time) (t1:Time) : Drift Float = | ||
-- Dynamics of a simple Monte Carlo estimator of the KL divergence between | ||
-- two SDEs that share a diffusion function. | ||
-- Returns a function of (state, t): the squared difference of the two drifts | ||
-- divided by the diffusion value at that point. t0 and t1 are accepted but unused. | ||
-- NOTE(review): the usual Girsanov integrand is 0.5 * sq ((drift1 - drift2) / diffusion); | ||
-- this divides by the diffusion only once and has no 0.5 factor. Confirm intent. | ||
-- Todo: generalize to multivariate. | ||
\state t. | ||
(sq ((drift1 state t) - (drift2 state t))) / (diffusion state t) | ||
|
||
def ito_euler_step (_:VSpace v) ?=> (sde: SDE v) (z:v)
                   (t:Time) (dt:Time) (noise:v) : (v & Time) =
  -- One explicit Euler step of the SDE from state z at time t:
  --   next state = z + dt * drift(z, t) + diffprod(z, noise, t)
  -- Returns the next state paired with the advanced time t + dt.
  (drift_fn, diffusion_prod) = sde
  drift_incr = dt .* (drift_fn z t)
  noise_incr = diffusion_prod z noise t
  next_z = z + drift_incr + noise_incr
  (next_z, t + dt)
|
||
def sdeint (sde: SDE (d=>Float)) | ||
(z0: d=>Float) (t0: Time) (times: n=>Time) | ||
(key: Key) : n=>d=>Float = | ||
-- Fixed-step solver: repeatedly applies ito_euler_step with a constant dt. | ||
-- sde: (drift, diffusion product) pair. | ||
-- z0: the initial value for the state. | ||
-- t0: the time at which the state equals z0. | ||
-- times: times for evaluation. values must be strictly increasing. | ||
-- key: PRNG key that determines the Brownian path. | ||
-- Returns: | ||
-- Values of the solution at each time point in times. | ||
dt = 0.0001 | ||
max_iters = 10000 | ||
|
||
lasttime = maximum (prepend t0 times) | ||
-- Todo: think about how to extend the BM beyond this interval later. | ||
-- The Brownian path is defined on (t0, lasttime); a step may query t + dt beyond lasttime. | ||
noisefunc = \t. sampleBM key t0 lasttime t | ||
|
||
-- Scan body: advance the carried state up to times.iter and emit the | ||
-- interpolated solution there. Carry layout: (prev_z, prev_t, cur_z, cur_t, dt). | ||
integrate_to_next_time = \iter init_carry. | ||
target_t = times.iter | ||
|
||
-- Despite the name, this is the condition under which the loop KEEPS stepping. | ||
-- NOTE(review): `ordinal iter` is the index into `times`, not a count of steps | ||
-- taken, so max_iters does not bound the number of loop iterations. | ||
stopping_condition = \(_, _, _, t, dt). | ||
(t < target_t) && (dt > 0.0) && (ordinal iter < max_iters) | ||
|
||
-- One solver step; the previous (z, t) is shifted into the first two slots. | ||
step = \(old_z, old_t, z, t, dt). | ||
-- Brownian increment over (t, t + dt). | ||
dB = (noisefunc (t + dt)) - (noisefunc t) | ||
(new_z, new_t) = ito_euler_step sde z t dt dB | ||
(z, t, new_z, new_t, dt) | ||
|
||
-- Take steps until we pass target_t | ||
new_state = snd $ withState init_carry \state. | ||
while (\(). stopping_condition (get state)) \(). | ||
state := step (get state) | ||
(old_z, old_t, cur_z, cur_t, _) = new_state | ||
|
||
-- Interpolate to the target time. | ||
z_target = linear_interp old_z cur_z old_t cur_t target_t | ||
(new_state, z_target) | ||
|
||
-- Both the "previous" and "current" slots start at the initial condition. | ||
init_carry = (z0, t0, z0, t0, dt) | ||
snd $ scan init_carry integrate_to_next_time | ||
|
||
|
||
-- Returns the larger of x and eps, i.e. x bounded below by eps. | ||
def clamp_min (eps:Float) (x:Float) : Float = max x eps | ||
|
||
def error_ratio (z_full:d=>Float) (z_half:d=>Float) (rtol:Float) (atol:Float) : Float =
  -- Norm of the tolerance-scaled difference between two estimates of one step.
  --   z_full: estimate obtained with one full step.
  --   z_half: estimate obtained with two half steps.
  -- Each per-component tolerance, and the final result, is kept at or above 1.0e-7.
  tiny = 1.0e-7
  scaled_diff = for i.
    magnitude = max (abs z_full.i) (abs z_half.i)
    (z_full.i - z_half.i) / (clamp_min tiny (rtol * magnitude + atol))
  clamp_min tiny (norm scaled_diff)
|
||
def step_size (error_estimate:Float)
              (prev_step_size:Time)
              (maybe_prev_error_ratio: Maybe Float) : (Float & Maybe Float) =
  -- Proposes the next step size from the current error estimate.
  --   error_estimate: scaled error of the step just attempted; values <= 1.0
  --     are treated as an accepted step.
  --   prev_step_size: the step size that was just attempted.
  --   maybe_prev_error_ratio: ratio remembered from the last accepted step,
  --     or Nothing if there has been none yet.
  -- Returns (new step size, Just ratio to remember for the next call).
  safety = 0.9
  facmin = 0.2
  facmax = 1.4
  -- Exponents for the current-ratio term (ifactor) and for the
  -- change-in-ratio term (pfactor); the latter is disabled after a rejection.
  (pfactor, ifactor) = case error_estimate > 1.0 of
    True  -> (0.0,  1.0 / 1.5)
    False -> (0.13, 1.0 / 4.5)

  ratio = safety / error_estimate
  prev_ratio = case maybe_prev_error_ratio of
    Nothing -> ratio
    Just r  -> r

  -- factor = ratio^ifactor * (ratio / prev_ratio)^pfactor.
  -- Both powers are parenthesised explicitly: written with `$`, the whole
  -- product would become the exponent of the first `pow`.
  raw_factor = (pow ratio ifactor) * (pow (ratio / prev_ratio) pfactor)

  -- After an accepted step, remember this ratio and never shrink the step.
  (next_prev_ratio, lower_bound) = case error_estimate <= 1.0 of
    True  -> (ratio, 1.0)
    False -> (prev_ratio, facmin)

  factor = min facmax (max lower_bound raw_factor)
  new_step_size = prev_step_size * factor
  (new_step_size, Just next_prev_ratio)
|
||
def adaptive_sdeint (sde: SDE (d=>Float)) | ||
(z0: d=>Float) (t0: Time) (times: n=>Time) | ||
(key: Key) : n=>d=>Float = | ||
-- Adaptive-step solver: each attempt compares one full step against two half | ||
-- steps, and accepts or rejects the attempt based on error_ratio. | ||
-- sde: (drift, diffusion product) pair. | ||
-- z0: the initial value for the state. | ||
-- t0: the time at which the state equals z0. | ||
-- times: times for evaluation. values must be strictly increasing. | ||
-- key: PRNG key that determines the Brownian path. | ||
-- Returns: | ||
-- Values of the solution at each time point in times. | ||
rtol = 0.0001 --1.4e-7 -- relative local error tolerance for solver. | ||
atol = 0.00001 --1.4e-7 -- absolute local error tolerance for solver. | ||
max_iters = 10000 | ||
init_dt = 1.0e-3 | ||
|
||
lasttime = maximum (prepend t0 times) | ||
-- Todo: think about how to extend the BM beyond this interval later. | ||
-- The Brownian path is defined on (t0, lasttime); a step may query t + dt beyond lasttime. | ||
noisefunc = \t. sampleBM key t0 lasttime t | ||
|
||
-- Scan body: advance the carried state up to times.iter and emit the interpolated | ||
-- solution there. Carry layout: (prev_z, prev_t, cur_z, cur_t, dt, maybe_prev_error_ratio). | ||
integrate_to_next_time = \iter init_carry. | ||
target_t = times.iter | ||
|
||
-- Despite the name, this is the condition under which the loop KEEPS stepping. | ||
-- NOTE(review): `ordinal iter` is the index into `times`, not a count of attempts, | ||
-- so max_iters does not bound the number of loop iterations. | ||
stopping_condition = \(_, _, _, t, dt, _). | ||
(t < target_t) && (dt > 0.0) && (ordinal iter < max_iters) | ||
|
||
-- One Euler step of size dt; the previous (z, t) is shifted into the first two slots. | ||
step = \(old_z, old_t, z, t, dt). | ||
dB = (noisefunc (t + dt)) - (noisefunc t) | ||
(new_z, new_t) = ito_euler_step sde z t dt dB | ||
(z, t, new_z, new_t, dt) | ||
|
||
-- Attempt a step of size dt, then either advance or stay put with a revised dt. | ||
possibly_step = \(old_z, old_t, z, t, dt, maybe_prev_error_ratio). | ||
-- Take 1 full step. | ||
(_, _, new_z_full, new_t, _) = step (old_z, old_t, z, t, dt) | ||
-- Take 2 half steps. | ||
(_, _, new_z_half, new_t_half, _) = step (old_z, old_t, z, t, dt / 2.0) | ||
(_, _, new_z_2xhf, _, _) = step (old_z, old_t, new_z_half, new_t_half, dt / 2.0) | ||
|
||
-- Scaled disagreement between the full-step and two-half-step results. | ||
ratio = error_ratio new_z_full new_z_2xhf rtol atol | ||
|
||
(new_dt, maybe_prev_error_ratio) = step_size ratio dt maybe_prev_error_ratio | ||
|
||
-- Accepted: advance to the two-half-step result. Rejected: keep the state, retry with new_dt. | ||
move_state = (z, t, new_z_2xhf, new_t, new_dt, maybe_prev_error_ratio) | ||
stay_state = (old_z, old_t, z, t, new_dt, maybe_prev_error_ratio) | ||
select (ratio <= 1.0) move_state stay_state | ||
|
||
-- Take steps until we pass target_t | ||
new_state = snd $ withState init_carry \state. | ||
while (\(). stopping_condition (get state)) \(). | ||
state := possibly_step (get state) | ||
(old_z, old_t, cur_z, cur_t, _, _) = new_state | ||
|
||
-- Interpolate to the target time. | ||
z_target = linear_interp old_z cur_z old_t cur_t target_t | ||
(new_state, z_target) | ||
|
||
-- Both the "previous" and "current" slots start at the initial condition; no prior error ratio yet. | ||
init_carry = (z0, t0, z0, t0, init_dt, Nothing) | ||
snd $ scan init_carry integrate_to_next_time | ||
|
||
|
||
|
||
-- Example drift: returns -z and ignores t. | ||
drift : v=>Float -> Time -> v=>Float = \z t. -z | ||
|
||
-- Example diffusion product: returns the noise unchanged, ignoring z and t. | ||
def diffprod (z:v=>Float) (noise:v=>Float) (t:Time) : v=>Float = for i. noise.i | ||
|
||
--sde = (drift, diffprod) | ||
|
||
-- Example problem: one-dimensional state starting at 1.0 at time 0.1, | ||
-- evaluated at the 1000 points of `linspace ... 0.2 1.9`. | ||
z0 = [1.0] | ||
t0 = 0.1 | ||
times = linspace (Fin 1000) 0.2 1.9 | ||
key = newKey 0 | ||
|
||
-- Solve the same SDE with the same key using both solvers. | ||
yout = sdeint (drift, diffprod) z0 t0 times key | ||
yout' = adaptive_sdeint (drift, diffprod) z0 t0 times key | ||
|
||
--:p yout - yout' | ||
|
||
|
||
-- Plot the single state component against `times`: sdeint first, then adaptive_sdeint. | ||
:html showPlot $ xyPlot times for i. yout.i.(0@(Fin 1)) | ||
:html showPlot $ xyPlot times for i. yout'.i.(0@(Fin 1)) | ||
|