Skip to content

Commit

Permalink
instances complete
Browse files Browse the repository at this point in the history
  • Loading branch information
o1lo01ol1o committed Feb 12, 2018
1 parent 464200f commit 65bcdbd
Show file tree
Hide file tree
Showing 18 changed files with 1,007 additions and 699 deletions.
9 changes: 8 additions & 1 deletion diffhask.cabal
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name: diffhask
version: 0.1.0.0
description: DSL for forward and reverse mode automatic differentiation in haskell. instpired by DiffSharp.
description: DSL for forward and reverse mode automatic differentiation in haskell. Port of DiffSharp.
homepage: https://github.com/o1lo01ol1o/diffhask
bug-reports: https://github.com/o1lo01ol1o/diffhask/issues
license: MIT
Expand Down Expand Up @@ -29,6 +29,7 @@ library
, Internal.NumHask.Algebra.Module
, Internal.NumHask.Algebra.Metric
, Internal.NumHask.Algebra.Singleton
, Internal.NumHask.Algebra.Diff
, Internal.NumHask.Prelude


Expand All @@ -48,6 +49,12 @@ library
, dependent-sum-template
, dependent-sum
, dependent-map
, numhask
, numhask-array







Expand Down
208 changes: 158 additions & 50 deletions src/Core.hs
Original file line number Diff line number Diff line change
Expand Up @@ -19,29 +19,17 @@ import Control.Monad.State.Strict (State, evalState, get, gets,
modify, put, runState, (>=>))
import qualified Data.Map as M (Map, empty, insert, lookup,
update, updateLookupWithKey)
import Internal.Internal hiding (Differentiable(..))
import Internal.Internal hiding (Differentiable (..))
import Internal.NumHask.Prelude hiding (State, diff, evalState,
runState)
import Lens.Micro ((%~), (&), (.~), (^.))
import Prelude (error)
import qualified Protolude as P

-- $setup
-- >>> :set -XDataKinds
-- >>> :set -XOverloadedLists
-- >>> :set -XTypeFamilies
-- >>> :set -XFlexibleContexts
-- >>> :set -XNoImplicitPrelude
-- >>> let b = D 2 :: (D Float)
-- >>> let a = D 3 :: (D Float)

--FIXME: prune redundancy

type AdditiveDifferentiable t r
= ( --AdditiveUnital (D r t) r t
--, AdditiveUnital (Computation r t (D r t)) r t

--,
= (
AdditiveMagma (D r t) (D r t) r t
, AdditiveMagma (Computation r t (D r t)) (D r t) r t
, AdditiveMagma (Computation r t (D r t)) (Computation r t (D r t)) r t
Expand Down Expand Up @@ -181,11 +169,13 @@ type Differentiable t r


-- Get Tangent
t :: forall a r. () => D r a -> Computation r a (Tangent r a)
t :: forall r a. (AdditiveUnital (D r a) r a)
=> D r a
-> Computation r a (Tangent r a)
t =
\case
D _ -> pure (zero :: D r a)
DF _ at _ -> pure at
D _ -> pure (zero :: (Tangent r a))
DF _ at _ -> pure (at :: (Tangent r a))
DR {} -> error "Can't get tangent of a reverse node"


Expand Down Expand Up @@ -214,22 +204,28 @@ addDeltas ::
-> Computation r a (D r a)
addDeltas a b =
case (a, b) of
(D xa, D xb) -> a + b
(Dm ma, D xb) -> a .+ b
(D xa, Dm mb) -> a +. b
(D xa, D xb) -> a + b
(Dm ma, D xb) -> a .+ b
(D xa, Dm mb) -> a +. b
(Dm ma, Dm mb) -> a .+. b

applyDelta :: () => UID
applyDelta ::
( Additive (D r a) (D r a) r a
, AdditiveModule r (D r a) (D r a) a
, AdditiveBasis r (D r a) (D r a) a
)
=> UID
-> D r a
-> Adjoints r a
-> Maybe (Computation r a (D r a))
applyDelta tag dlta adjs=
applyDelta tag dlta adjs =
case M.lookup tag adjs of
Just v -> Just $ do
e <- dlta
r <- addDeltas v e
modify (\st -> st & adjoints .~ M.update (const . Just $ r) tag adjs)
return r
Just v -> Just rval
where rval = do
r <- addDeltas v dlta
modify
(\st -> st & adjoints .~ M.update (const . Just $ r) tag adjs)
return r
Nothing -> Nothing

decrementFanout :: UID -> Fanouts -> (Maybe Tag, Fanouts)
Expand All @@ -250,8 +246,14 @@ incrementFanout u = do
put (st & fanouts %~ const a)
return f)

zeroAdj ::
forall r a. (AdditiveUnital (D r a) r a)
=> UID
-> Computation r a ()
zeroAdj uniq = do
modify (\st -> st & adjoints %~ M.insert uniq ((zero :: D r a)))

reset :: ( Show a) => [D r a] -> Computation r a ()
reset :: (AdditiveUnital (D r a) r a, Show a) => [D r a] -> Computation r a ()
reset l =
case l of
[] -> return ()
Expand All @@ -261,15 +263,23 @@ reset l =
fanout <- incrementFanout uniq
if fanout == Tag 1 then
do
modify (\st -> st & adjoints %~ M.insert uniq (X (zero :: D r a)))
zeroAdj uniq
x <- resetEl o
reset $ x `mappend` xs -- verify this
else reset xs
reset xs
_ -> reset xs

-- recursively pushes nodes onto the reverse mode stack and evaluates partials
push :: () => [(D r a, D r a)] -> Computation r a ()
push ::
( AdditiveUnital (D r a) r a
, Show a
, Additive (D r a) (D r a) r a
, AdditiveModule r (D r a) (D r a) a
, AdditiveBasis r (D r a) (D r a) a
)
=> [(D r a, D r a)]
-> Computation r a ()
push l =
case l of
[] -> return ()
Expand All @@ -292,23 +302,46 @@ push l =
Nothing -> error "key not found in adjoints!"
_ -> push xs

reverseReset :: ( Show a) => D r a -> Computation r a ()
reverseReset ::
( AdditiveUnital (D r a) r a
, Show a
, Additive (D r a) (D r a) r a
, AdditiveModule r (D r a) (D r a) a
, AdditiveBasis r (D r a) (D r a) a
)
=> D r a
-> Computation r a ()
reverseReset d = do
modify (& fanouts .~ M.empty )
reset [ d]

reverseProp :: ( Show a) => D r a -> D r a -> Computation r a ()
reverseProp ::
( AdditiveUnital (D r a) r a
, Show a
, Additive (D r a) (D r a) r a
, AdditiveModule r (D r a) (D r a) a
, AdditiveBasis r (D r a) (D r a) a
)
=> D r a
-> D r a
-> Computation r a ()
reverseProp v d = do
reverseReset d
push [( v, d)]

{-# INLINE primalTanget #-}
primalTanget :: ( Show a) => D r a -> Computation r a (D r a, Tangent r a)
primalTanget ::
(Show a, AdditiveUnital (D r a) r a)
=> D r a
-> Computation r a (D r a, Tangent r a)
primalTanget d = do
ct <- t d
pure (p d, ct)

adjoint :: forall a r. ( Show a) => D r a -> Computation r a (D r a)
adjoint ::
forall a r. (Show a, AdditiveUnital (D r a) r a)
=> D r a
-> Computation r a (D r a)
adjoint d =
case d of
DR _ _ _ uniq -> do
Expand All @@ -327,22 +360,46 @@ compute :: (P.RealFrac a) => Computation r a (b) -> b
compute f = evalState f initComp

{-# INLINE computeAdjoints' #-}
computeAdjoints' :: forall a r. ( Show a) => D r a -> Computation r a ()
computeAdjoints' ::
forall a r.
( Show a
, AdditiveUnital (D r a) r a
, MultiplicativeUnital (D r a) r a
, Additive (D r a) (D r a) r a
, AdditiveModule r (D r a) (D r a) a
, AdditiveBasis r (D r a) (D r a) a
)
=> D r a
-> Computation r a ()
computeAdjoints' d = do
modify (\st -> st & adjoints .~ M.empty)
o <- pure (one :: D r a)
reverseProp o d

{-# INLINE computeAdjoints #-}
computeAdjoints :: ( Show a) => D r a -> Computation r a (Adjoints r a)
computeAdjoints ::
( Show a
, AdditiveUnital (D r a) r a
, MultiplicativeUnital (D r a) r a
, Additive (D r a) (D r a) r a
, AdditiveModule r (D r a) (D r a) a
, AdditiveBasis r (D r a) (D r a) a
)
=> D r a
-> Computation r a (Adjoints r a)
computeAdjoints d = do
computeAdjoints' d
st <- get
return $ st ^. adjoints
{-# INLINE diff' #-}

diff' :: forall a r.
( Show a)
( Show a
, AdditiveUnital (D r a) r a
, MultiplicativeUnital (D r a) r a
, Additive (D r a) (D r a) r a
, AdditiveModule r (D r a) (D r a) a
, AdditiveBasis r (D r a) (D r a) a)
=> (D r a -> Computation r a (D r a))
-> D r a
-> Computation r a (D r a, Tangent r a)
Expand All @@ -354,7 +411,12 @@ diff' f x = do
{-# INLINE diff #-}

diff ::
( Show a)
( Show a
, AdditiveUnital (D r a) r a
, MultiplicativeUnital (D r a) r a
, Additive (D r a) (D r a) r a
, AdditiveModule r (D r a) (D r a) a
, AdditiveBasis r (D r a) (D r a) a)
=> (D r a -> Computation r a (D r a))
-> D r a
-> Computation r a (Tangent r a)
Expand All @@ -363,7 +425,13 @@ diff f x =

{-# INLINE diffn #-}
diffn ::
( Show a)
( Show a
, AdditiveUnital (D r a) r a
, MultiplicativeUnital (D r a) r a
, Additive (D r a) (D r a) r a
, AdditiveModule r (D r a) (D r a) a
, AdditiveBasis r (D r a) (D r a) a
)
=> Int
-> (D r a -> Computation r a (D r a))
-> D r a
Expand All @@ -376,7 +444,13 @@ diffn n f x =
else go n f x
where
go ::
( Show a)
( Show a
, AdditiveUnital (D r a) r a
, MultiplicativeUnital (D r a) r a
, Additive (D r a) (D r a) r a
, AdditiveModule r (D r a) (D r a) a
, AdditiveBasis r (D r a) (D r a) a
)
=> Int
-> (D r a -> Computation r a (D r a))
-> D r a
Expand All @@ -388,7 +462,13 @@ diffn n f x =

{-# INLINE diffn' #-}
diffn' ::
( Show a)
( Show a
, AdditiveUnital (D r a) r a
, MultiplicativeUnital (D r a) r a
, Additive (D r a) (D r a) r a
, AdditiveModule r (D r a) (D r a) a
, AdditiveBasis r (D r a) (D r a) a
)
=> Int
-> (D r a -> Computation r a (D r a))
-> D r a
Expand All @@ -398,12 +478,16 @@ diffn' n f x = do
again <- diffn n f x
return (it, again)

-- | Reverse Multiplication
-- >>> compute $ grad' (\x -> x * a) a
-- (D 9.0,D 1.0)
{-# INLINE grad' #-}
grad' ::
(Trace Noop r a, Show a)
( Trace Noop r a
, Show a
, AdditiveUnital (D r a) r a
, MultiplicativeUnital (D r a) r a
, Additive (D r a) (D r a) r a
, AdditiveModule r (D r a) (D r a) a
, AdditiveBasis r (D r a) (D r a) a
)
=> (D r a -> Computation r a (D r a))
-> D r a
-> Computation r a (D r a, (D r a))
Expand All @@ -417,24 +501,48 @@ grad' f x = do

{-# INLINE grad #-}
grad ::
(Trace Noop r a, Show a) =>
(D r a -> Computation r a (D r a))
( Trace Noop r a
, Show a
, AdditiveUnital (D r a) r a
, MultiplicativeUnital (D r a) r a
, Additive (D r a) (D r a) r a
, AdditiveModule r (D r a) (D r a) a
, AdditiveBasis r (D r a) (D r a) a
)
=> (D r a -> Computation r a (D r a))
-> D r a
-> Computation r a (D r a)
grad f x = do
(_, g)<- grad' f x
return g

-- Original value and Jacobian product of `f`, at point `x`, along `v`. Forward AD.
jacobian' :: ( Show a) =>
(D r a -> Computation r a (D r a)) -> Tangent r a -> Primal r a -> Computation r a (D r a, Tangent r a)
jacobian' ::
( Show a
, Show a
, AdditiveUnital (D r a) r a
, MultiplicativeUnital (D r a) r a
, Additive (D r a) (D r a) r a
, AdditiveModule r (D r a) (D r a) a
, AdditiveBasis r (D r a) (D r a) a
)
=> (D r a -> Computation r a (D r a))
-> Tangent r a
-> Primal r a
-> Computation r a (D r a, Tangent r a)
jacobian' f x v = do
ntg <- getNextTag
fout <- f $ mkForward ntg v x
primalTanget fout

jacobian ::
(Show a)
( Show a
, AdditiveUnital (D r a) r a
, MultiplicativeUnital (D r a) r a
, Additive (D r a) (D r a) r a
, AdditiveModule r (D r a) (D r a) a
, AdditiveBasis r (D r a) (D r a) a
)
=> (D r a -> Computation r a (D r a))
-> Tangent r a
-> Primal r a
Expand Down
Loading

0 comments on commit 65bcdbd

Please sign in to comment.