Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add multisets #1335

Open
wants to merge 7 commits into
base: main
Choose a base branch
from
Open
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Added lots of index sets to prelude.
  • Loading branch information
duvenaud committed Sep 24, 2023
commit 8575c177869aeac67bae13da005f0aa6dfc7beca
10 changes: 9 additions & 1 deletion examples/ctc.dx
Original file line number Diff line number Diff line change
@@ -38,7 +38,7 @@ instance Ix(FenceAndPosts n) given (n|Ix)
False -> Posts $ unsafe_from_ordinal (intdiv2 o)

instance NonEmpty(FenceAndPosts n) given (n|Ix)
first_ix = unsafe_from_ordinal 0
pass

instance Eq(FenceAndPosts a) given (a|Ix|Eq)
def (==)(x, y) = case x of
@@ -220,3 +220,11 @@ or the paper.
sum for i:(Fin 3=>Vocab).
ls_to_f $ ctc blank logits i
> 0.5653746


'One major advantage of Dex is its parallelism-preserving autodiff.
The original CTC paper, and most CUDA implementations, used hand-written
reverse-mode derivatives. Dex should be able to
prodice an efficient one automatically. Let's check:

-- grad (\logits. ls_to_f $ ctc blank logits labels) logits
110 changes: 100 additions & 10 deletions lib/prelude.dx
Original file line number Diff line number Diff line change
@@ -874,27 +874,27 @@ instance Ix(Maybe a) given (a|Ix)
True -> Nothing

interface NonEmpty(n|Ix)
first_ix : n
pass

instance NonEmpty(())
first_ix = unsafe_from_ordinal(0)
pass

instance NonEmpty(Bool)
first_ix = unsafe_from_ordinal 0
pass

instance NonEmpty((a,b)) given (a|NonEmpty, b|NonEmpty)
first_ix = unsafe_from_ordinal 0
pass

instance NonEmpty(Either(a,b)) given (a|NonEmpty, b|Ix)
first_ix = unsafe_from_ordinal 0
pass

-- The below instance is valid, but causes "multiple candidate dictionaries"
-- errors if both Left and Right are NonEmpty.
-- instance NonEmpty (a|b) given {a b} [Ix a, NonEmpty b]
-- first_ix = unsafe_from_ordinal _ 0
-- pass

instance NonEmpty(Maybe a) given (a|Ix)
first_ix = unsafe_from_ordinal 0
pass

'## Fencepost index sets

@@ -924,11 +924,14 @@ def right_fence(p:Post n) -> Maybe n given (n|Ix) =
then Nothing
else Just $ unsafe_from_ordinal ix

def first_ix() ->> n given (n|NonEmpty) =
unsafe_from_ordinal(0)

def last_ix() ->> n given (n|NonEmpty) =
unsafe_from_ordinal(unsafe_i_to_n(n_to_i(size n) - 1))

instance NonEmpty(Post n) given (n|Ix)
first_ix = unsafe_from_ordinal(n=Post n, 0)
pass

def scan(
init:a,
@@ -1704,7 +1707,7 @@ def from_ordinal(i:Nat) -> n given (n|Ix) =
False -> error $ from_ordinal_error(i, size n)

-- TODO: should this be called `from_ordinal`?
def to_ix(i:Nat) -> Maybe n given (n|Ix) =
def to_ix(i:Nat) -> Maybe n given (n|Ix) =
case i < size n of
True -> Just $ unsafe_from_ordinal i
False -> Nothing
@@ -2266,6 +2269,93 @@ instance Subset(b, Either(a,b)) given (a|Data, b|Data)
Left( x) -> error "Can't project Left branch to Right branch"
Right(x) -> x

instance Subset(n=>a, n=>b) given (n|Ix, a|Data, b|Data) (Subset a b)
def inject'(xs) = for i. inject xs[i]
def project'(xs') =
xs = for i. project xs'[i]
case any_sat(is_nothing, xs) of
True -> Nothing
False -> Just $ each xs from_just
def unsafe_project'(xs') =
xs = for i. project xs'[i]
case any_sat(is_nothing, xs) of
True -> error "Couldn't project table."
False -> each xs from_just

-- add instance for subset n=>a m=>a given subset n m

instance Subset(List a, List b) given (a|Data, b|Data) (Subset a b)
def inject'(xs') =
AsList(n, xs) = xs'
AsList(n, inject xs)
def project'(l) =
AsList(n, tab) = l
case project tab of
Nothing -> Nothing
Just xs -> Just AsList(n, xs)
def unsafe_project'(l) =
AsList(n, tab) = l
case project tab of
Nothing -> error "Couldn't project list."
Just xs -> AsList(n, xs)

'### All but Last Index set
All the indices of the original index set except the last one.

struct AllButLast(n:Nat, a|Ix) =
val : a

instance Ix(AllButLast n a) given (n:Nat, a|Ix|Data)
def size'() = (size a) -| n
def ordinal(i) = ordinal i.val
def unsafe_from_ordinal(o) = AllButLast $ unsafe_from_ordinal o

instance Subset(AllButLast n a, a) given (n:Nat, a|Ix)
def inject'(x) = x.val
def project'(x) = case (ordinal x) < ((size a) -| n) of
True -> Just (AllButLast x)
False -> Nothing
def unsafe_project'(x) = (AllButLast x)

instance Eq(AllButLast n a) given (n:Nat, a|Eq|Ix)
def (==)(x, y) = x.val == y.val

def unsafe_increment(i:n) -> n given (n|Ix) = from_ordinal (ordinal i + 1)
def next(i: AllButLast 1 n) -> n given (n|Ix) = unsafe_increment i.val
def get_next_m(tab:n=>a, i:AllButLast m n) -> List a given (n|Ix, m:Nat, a) =
-- The list returned always has size (n - m), but can't spell that yet.
to_list $ for j:(Fin m). tab[unsafe_from_ordinal (ordinal i + ordinal j)]

'### Fence and Inner Posts
A custom datatype and index set
that interleaves the elements of a table with another set
of values representing all the spaces in between those elements,
not including the 2 ends.

data FenceAndInnerPosts(n|Ix) =
Fence(n)
InnerPost(AllButLast 1 n)

instance Ix(FenceAndInnerPosts n) given (n|Ix)
def size'() = 2 * size n -| 1
def ordinal(i) = case i of
Fence j -> 2 * ordinal j
InnerPost j -> 2 * ordinal j + 1
def unsafe_from_ordinal(o) =
case is_odd o of
False -> Fence $ unsafe_from_ordinal (intdiv2 o)
True -> InnerPost $ unsafe_from_ordinal (intdiv2 o)

instance Eq(FenceAndInnerPosts a) given (a|Ix|Eq)
def (==)(x, y) = case x of
Fence x -> case y of
Fence y -> x == y
InnerPost y -> False
InnerPost x -> case y of
Fence y -> False
InnerPost y -> x == y


'### Index set for tables

def int_to_reversed_digits(k:Nat) -> a=>b given (a|Ix, b|Ix) =
@@ -2291,7 +2381,7 @@ instance Ix(a=>b) given (a|Ix, b|Ix)
def unsafe_from_ordinal(i) = int_to_reversed_digits i

instance NonEmpty(a=>b) given (a|Ix, b|NonEmpty)
first_ix = unsafe_from_ordinal 0
pass

'### Stack