diff --git a/examples/ctc.dx b/examples/ctc.dx index 34b26998f..95b4bd6fb 100644 --- a/examples/ctc.dx +++ b/examples/ctc.dx @@ -38,7 +38,7 @@ instance Ix(FenceAndPosts n) given (n|Ix) False -> Posts $ unsafe_from_ordinal (intdiv2 o) instance NonEmpty(FenceAndPosts n) given (n|Ix) - first_ix = unsafe_from_ordinal 0 + pass instance Eq(FenceAndPosts a) given (a|Ix|Eq) def (==)(x, y) = case x of @@ -220,3 +220,11 @@ or the paper. sum for i:(Fin 3=>Vocab). ls_to_f $ ctc blank logits i > 0.5653746 + + +'One major advantage of Dex is its parallelism-preserving autodiff. +The original CTC paper, and most CUDA implementations, used hand-written +reverse-mode derivatives. Dex should be able to +prodice an efficient one automatically. Let's check: + +-- grad (\logits. ls_to_f $ ctc blank logits labels) logits diff --git a/lib/parser.dx b/lib/parser.dx index 805cd71dc..4b9c20bfc 100644 --- a/lib/parser.dx +++ b/lib/parser.dx @@ -40,11 +40,12 @@ def run_parser_partial(s:String, parser:Parser a) -> Maybe a given (a) = '## Primitive combinators -def p_char(c:Char) -> Parser () = MkParser \h. +def p_char(c:Char) -> Parser Char = MkParser \h. i = get h.offset c' = index_list h.input i assert (c == c') h.offset := i + 1 + c' def p_eof() ->> Parser () = MkParser \h. assert $ get h.offset >= list_length h.input @@ -99,15 +100,16 @@ def parse_digit() ->> Parser Int = try $ MkParser \h. def optional(p:Parser a) -> Parser (Maybe a) given (a) = (MkParser \h. Just (parse h p)) <|> returnP Nothing -def parse_many(parser:Parser a) -> Parser (List a) given (a|Data) = MkParser \h. - yield_state (AsList _ []) \results. - iter \_. - maybeVal = parse h $ optional parser - case maybeVal of - Nothing -> Done () - Just x -> - push results x - Continue +def parse_many(parser:Parser a) -> Parser (List a) given (a|Data) = + MkParser \h. + yield_state (AsList _ []) \results. + iter \_. + maybeVal = parse h $ optional parser + case maybeVal of + Nothing -> Done () + Just x -> + push results x + Continue def parse_some(parser:Parser a) -> Parser (List a) given (a|Data) = MkParser \h. @@ -125,9 +127,9 @@ def parse_int() ->> Parser Int = MkParser \h. x = parse h $ parse_unsigned_int case negSign of Nothing -> x - Just () -> (-1) * x + Just _ -> (-1) * x -def bracketed(l:Parser (), r:Parser (), body:Parser a) -> Parser a given (a) = +def bracketed(l:Parser Char, r:Parser Char, body:Parser a) -> Parser a given (a) = MkParser \h. _ = parse h l ans = parse h body @@ -137,8 +139,11 @@ def bracketed(l:Parser (), r:Parser (), body:Parser a) -> Parser a given (a) = def parens(parser:Parser a) -> Parser a given (a) = bracketed (p_char '(') (p_char ')') parser + +'## String Utilities + def split(space:Char, s:String) -> List String = - def trailing_spaces(space:Parser (), body:Parser a) -> Parser a given (a) = + def trailing_spaces(space:Parser Char, body:Parser a) -> Parser a given (a) = MkParser \h. ans = parse h body _ = parse h $ parse_many space @@ -149,3 +154,50 @@ def split(space:Char, s:String) -> List String = case run_parser s split_parser of Just l -> l Nothing -> AsList _ [] + +def join(space:a, strings:List(List a)) -> List a given (a|Data) = + AsList(n_string, string_table) = strings + yield_accum (ListMonoid a) \r. + for i:FenceAndInnerPosts(Fin n_string). + case i of + Fence j -> r += string_table[j] + InnerPost _ -> r += AsList(1, [space]) + +def find_first(word:m=>a, text:n=>a) -> Maybe n given (n|Ix, m|Ix, a|Eq) = + -- This implementation has O(nm) complexity, could be O(m + n). + -- Could maybe be nicer using the AllButLast index set. + case (size m > size n) || (size m == 0) of + True -> Nothing + False -> + bounded_iter (unsafe_nat_diff (size n) (size m)) Nothing \i. + next_substring = for j:m. text[unsafe_from_ordinal (i + ordinal j)] + case word == next_substring of + True -> Done $ Just (unsafe_from_ordinal i) + False -> Continue + +-- put in prelude? Nothing particular to strings. +def split_at(xs:n=>a, at:Post n) -> (List a, List a) given (n|Ix, a) = + size_left = ordinal at + size_right = unsafe_nat_diff (size n) size_left + left = AsList(size_left, for i. xs[unsafe_from_ordinal (ordinal i)]) + right = AsList(size_right, for i. xs[unsafe_from_ordinal (ordinal i + size_left)]) + (left, right) + +def unsafe_replace_at(text:n=>a, new:List a, old_len:Nat, at_ix:n) -> List a given (n|Ix, a|Data)= + AsList(m, new_tab) = new + (beginning_and_middle, end) = split_at text (unsafe_from_ordinal (ordinal at_ix + old_len)) + AsList(_, beginning_and_middle_tab) = beginning_and_middle + (beginning, _) = split_at beginning_and_middle_tab (unsafe_from_ordinal (ordinal (left_post at_ix))) + concat [beginning, new, end] + +def replace_all(old:List a, to_replace:List a, new:List a) -> List a given (a|Eq|Data) = + AsList(_, new_table) = new + AsList(to_replace_length, to_replace_table) = to_replace + yield_state old \cur_str. + while \. + AsList(_, cur_str_table) = get cur_str + case find_first to_replace_table cur_str_table of + Nothing -> False + Just i -> + cur_str := unsafe_replace_at cur_str_table new to_replace_length i + True diff --git a/lib/prelude.dx b/lib/prelude.dx index c2baec9e6..f705ad083 100644 --- a/lib/prelude.dx +++ b/lib/prelude.dx @@ -874,27 +874,27 @@ instance Ix(Maybe a) given (a|Ix) True -> Nothing interface NonEmpty(n|Ix) - first_ix : n + pass instance NonEmpty(()) - first_ix = unsafe_from_ordinal(0) + pass instance NonEmpty(Bool) - first_ix = unsafe_from_ordinal 0 + pass instance NonEmpty((a,b)) given (a|NonEmpty, b|NonEmpty) - first_ix = unsafe_from_ordinal 0 + pass instance NonEmpty(Either(a,b)) given (a|NonEmpty, b|Ix) - first_ix = unsafe_from_ordinal 0 + pass -- The below instance is valid, but causes "multiple candidate dictionaries" -- errors if both Left and Right are NonEmpty. -- instance NonEmpty (a|b) given {a b} [Ix a, NonEmpty b] --- first_ix = unsafe_from_ordinal _ 0 +-- pass instance NonEmpty(Maybe a) given (a|Ix) - first_ix = unsafe_from_ordinal 0 + pass '## Fencepost index sets @@ -924,11 +924,14 @@ def right_fence(p:Post n) -> Maybe n given (n|Ix) = then Nothing else Just $ unsafe_from_ordinal ix +def first_ix() ->> n given (n|NonEmpty) = + unsafe_from_ordinal(0) + def last_ix() ->> n given (n|NonEmpty) = unsafe_from_ordinal(unsafe_i_to_n(n_to_i(size n) - 1)) instance NonEmpty(Post n) given (n|Ix) - first_ix = unsafe_from_ordinal(n=Post n, 0) + pass def scan( init:a, @@ -1704,7 +1707,7 @@ def from_ordinal(i:Nat) -> n given (n|Ix) = False -> error $ from_ordinal_error(i, size n) -- TODO: should this be called `from_ordinal`? -def to_ix(i:Nat) -> Maybe n given (n|Ix) = +def to_ix(i:Nat) -> Maybe n given (n|Ix) = case i < size n of True -> Just $ unsafe_from_ordinal i False -> Nothing @@ -2266,6 +2269,93 @@ instance Subset(b, Either(a,b)) given (a|Data, b|Data) Left( x) -> error "Can't project Left branch to Right branch" Right(x) -> x +instance Subset(n=>a, n=>b) given (n|Ix, a|Data, b|Data) (Subset a b) + def inject'(xs) = for i. inject xs[i] + def project'(xs') = + xs = for i. project xs'[i] + case any_sat(is_nothing, xs) of + True -> Nothing + False -> Just $ each xs from_just + def unsafe_project'(xs') = + xs = for i. project xs'[i] + case any_sat(is_nothing, xs) of + True -> error "Couldn't project table." + False -> each xs from_just + +-- add instance for subset n=>a m=>a given subset n m + +instance Subset(List a, List b) given (a|Data, b|Data) (Subset a b) + def inject'(xs') = + AsList(n, xs) = xs' + AsList(n, inject xs) + def project'(l) = + AsList(n, tab) = l + case project tab of + Nothing -> Nothing + Just xs -> Just AsList(n, xs) + def unsafe_project'(l) = + AsList(n, tab) = l + case project tab of + Nothing -> error "Couldn't project list." + Just xs -> AsList(n, xs) + +'### All but Last Index set +All the indices of the original index set except the last one. + +struct AllButLast(n:Nat, a|Ix) = + val : a + +instance Ix(AllButLast n a) given (n:Nat, a|Ix|Data) + def size'() = (size a) -| n + def ordinal(i) = ordinal i.val + def unsafe_from_ordinal(o) = AllButLast $ unsafe_from_ordinal o + +instance Subset(AllButLast n a, a) given (n:Nat, a|Ix) + def inject'(x) = x.val + def project'(x) = case (ordinal x) < ((size a) -| n) of + True -> Just (AllButLast x) + False -> Nothing + def unsafe_project'(x) = (AllButLast x) + +instance Eq(AllButLast n a) given (n:Nat, a|Eq|Ix) + def (==)(x, y) = x.val == y.val + +def unsafe_increment(i:n) -> n given (n|Ix) = from_ordinal (ordinal i + 1) +def next(i: AllButLast 1 n) -> n given (n|Ix) = unsafe_increment i.val +def get_next_m(tab:n=>a, i:AllButLast m n) -> List a given (n|Ix, m:Nat, a) = + -- The list returned always has size (n - m), but can't spell that yet. + to_list $ for j:(Fin m). tab[unsafe_from_ordinal (ordinal i + ordinal j)] + +'### Fence and Inner Posts +A custom datatype and index set +that interleaves the elements of a table with another set +of values representing all the spaces in between those elements, +not including the 2 ends. + +data FenceAndInnerPosts(n|Ix) = + Fence(n) + InnerPost(AllButLast 1 n) + +instance Ix(FenceAndInnerPosts n) given (n|Ix) + def size'() = 2 * size n -| 1 + def ordinal(i) = case i of + Fence j -> 2 * ordinal j + InnerPost j -> 2 * ordinal j + 1 + def unsafe_from_ordinal(o) = + case is_odd o of + False -> Fence $ unsafe_from_ordinal (intdiv2 o) + True -> InnerPost $ unsafe_from_ordinal (intdiv2 o) + +instance Eq(FenceAndInnerPosts a) given (a|Ix|Eq) + def (==)(x, y) = case x of + Fence x -> case y of + Fence y -> x == y + InnerPost y -> False + InnerPost x -> case y of + Fence y -> False + InnerPost y -> x == y + + '### Index set for tables def int_to_reversed_digits(k:Nat) -> a=>b given (a|Ix, b|Ix) = @@ -2291,7 +2381,7 @@ instance Ix(a=>b) given (a|Ix, b|Ix) def unsafe_from_ordinal(i) = int_to_reversed_digits i instance NonEmpty(a=>b) given (a|Ix, b|NonEmpty) - first_ix = unsafe_from_ordinal 0 + pass '### Stack diff --git a/lib/set.dx b/lib/set.dx index 8a67a0177..54130e9a0 100644 --- a/lib/set.dx +++ b/lib/set.dx @@ -1,4 +1,4 @@ -'# Sets and Set-Indexed Arrays +'# Sets, Set-Indexed Arrays, and Multisets import sort @@ -21,25 +21,42 @@ def all_except_last(xs:n=>a) -> List a given (n|Ix, a) = allButLast = for i:shortSize. xs[unsafe_from_ordinal (ordinal i)] AsList _ allButLast -def merge_unique_sorted_lists(xlist:List a, ylist:List a) -> List a given (a|Eq) = +def all_except_first(xs:n=>a) -> List a given (n|Ix, a) = + shortSize = Fin (size n -| 1) + allButFirst = for i:shortSize. xs[unsafe_from_ordinal (1 + ordinal i)] + (AsList _ allButFirst) + +def merge_unique_sorted_lists_with_aux( + combine_side_info : (side, side)->side, + xlist:List (a, side), + ylist:List (a, side)) -> List (a, side) given (a|Eq, side|Data) = -- This function is associative, for use in a monoidal reduction. -- Assumes all xs are <= all ys. -- The element at the end of xs might equal the -- element at the beginning of ys. If so, this -- function removes the duplicate when concatenating the lists. AsList(nx, xs) = xlist - AsList(_ , ys) = ylist + AsList(_, ys) = ylist case last xs of Nothing -> ylist Just last_x -> case first ys of Nothing -> xlist - Just first_y -> case last_x == first_y of - False -> concat [xlist, ylist] - True -> concat [all_except_last xs, ylist] + Just first_y -> + (last_x_inner, last_x_side) = last_x + (first_y_inner, first_y_side) = first_y + case last_x_inner == first_y_inner of + False -> xlist <> ylist + True -> + combined = AsList 1 [(first_y_inner, combine_side_info last_x_side first_y_side)] + all_except_last xs <> combined <> all_except_first ys def remove_duplicates_from_sorted(xs:n=>a) -> List a given (n|Ix, a|Eq) = - xlists = for i:n. (AsList 1 [xs[i]]) - reduce (AsList 0 []) merge_unique_sorted_lists xlists + -- Special case for ordinary sets, which don't have any side information. + xlist = for i. AsList(_, [(xs[i], ())]) + ignore = \a b. () + AsList(_, set_with_aux) = + reduce AsList(0, []) (\x y. merge_unique_sorted_lists_with_aux(ignore, x, y)) xlist + AsList(_, for i. fst set_with_aux[i]) '## Sets @@ -91,31 +108,65 @@ def set_intersect( '## Sets as a type, whose inhabitants can index arrays --- TODO Implicit arguments to data definitions --- (Probably `a` should be implicit) struct Element(set:(Set a)) given (a|Ord) = - val: Nat + val: Fin (set_size set) -- TODO The set argument could be implicit (inferred from the Element -- type), but maybe it's easier to read if it's explicit. def member(x:a, set:(Set a)) -> Maybe (Element set) given (a|Ord) = UnsafeAsSet(_, elts) = set case search_sorted_exact elts x of - Just n -> Just $ Element(ordinal n) + Just i -> Just Element((ordinal i)@_) -- Should be Just Element(i) Nothing -> Nothing def value(x:Element set) -> a given (a|Ord, set:Set a) = UnsafeAsSet(_, elts) = set - elts[unsafe_from_ordinal x.val] + elts[x.val] instance Ix(Element set) given (a|Ord, set:Set a) def size'() = set_size set - def ordinal(n) = n.val - def unsafe_from_ordinal(n) = Element(n) + def ordinal(n) = ordinal n.val + def unsafe_from_ordinal(n) = Element(n@_) instance Eq(Element set) given (a|Ord, set:Set a) - def (==)(ix1, ix2) = ordinal ix1 == ordinal ix2 + def (==)(x, y) = x.val == y.val instance Ord(Element set) given (a|Ord, set:Set a) - def (<)(ix1, ix2) = ordinal ix1 < ordinal ix2 - def (>)(ix1, ix2) = ordinal ix1 > ordinal ix2 + def (<)(x, y) = x.val < y.val + def (>)(x, y) = x.val > y.val + +'## Multisets + +def remove_duplicates_from_sorted_with_counts(xs:n=>(a, Nat)) -> + List (a, Nat) given (n|Ix, a|Eq|Data) = + xlists = for i. AsList(1, [xs[i]]) + reduce AsList(_, []) (\x y. merge_unique_sorted_lists_with_aux(\a b. a + b, x, y)) xlists + +data Multiset(a|Ord) = + -- Guaranteed to be in sorted order, + -- as long as no one else uses this constructor. + -- Instead use the "to_multiset" function below. + UnsafeAsMultiset(n:Nat, elements:(Fin n => (a, Nat))) + +def to_multiset(xs:n=>a) -> Multiset a given (n|Ix, a|Ord) = + sorted_xs = sort xs + sorted_xs_with_1s = for i. (sorted_xs[i], 1) + AsList(n', unique_xs) = remove_duplicates_from_sorted_with_counts sorted_xs_with_1s + UnsafeAsMultiset n' unique_xs + +instance Eq(Multiset a) given (a|Ord) + def (==)(sx, sy) = + UnsafeAsMultiset(_, xs) = sx + UnsafeAsMultiset(_, ys) = sy + (AsList _ xs) == (AsList _ ys) + +def multiset_add(sx:Multiset a, sy:Multiset a) -> Multiset a given (a|Ord) = + UnsafeAsMultiset(nx, xs) = sx + UnsafeAsMultiset(ny, ys) = sy + combined = merge_sorted_tables xs ys + AsList(_, unique_xs) = remove_duplicates_from_sorted_with_counts combined + UnsafeAsMultiset(_, unique_xs) + +instance Add(Multiset a) given (a|Ord) + def (+)(sx, sy) = multiset_add sx sy + zero = to_multiset [] diff --git a/lib/sort.dx b/lib/sort.dx index 1178dd0cc..3e4c6b561 100644 --- a/lib/sort.dx +++ b/lib/sort.dx @@ -63,6 +63,11 @@ def sort(xs: n=>a) -> n=>a given (n|Ix, a|Ord) = AsList(_, r) = reduce mempty mcombine xlists unsafe_cast_table(to=n, r) +def arg_sort(xs: n=>a) -> n=>n given (n|Ix|Ord, a|Ord) = + indexed_table = for i. (xs[i], i) + sorted = sort indexed_table + map snd sorted + def (+|)(i:n, delta:Nat) -> n given (n|Ix) = i' = ordinal i + delta from_ordinal $ select (i' >= size n) (size n -| 1) i' diff --git a/tests/parser-combinator-tests.dx b/tests/parser-combinator-tests.dx index 1d2c7a52a..03d4444d3 100644 --- a/tests/parser-combinator-tests.dx +++ b/tests/parser-combinator-tests.dx @@ -55,3 +55,67 @@ def parserTFTriple() ->> Parser (Fin 3=>Bool) = MkParser \h. split ' ' " This is a sentence. " > (AsList 4 ["This", "is", "a", "sentence."]) + +text = "This is a test." +join ' ' (split ' ' text) == text +> True + +partition ' ' " hello world " +> (AsList 5 [" ", "hello", " ", "world", " "]) + +partition ' ' " " +> (AsList 1 [" "]) + +partition ' ' "" +> (AsList 0 []) + +partition ' ' "hello" +> (AsList 1 ["hello"]) + +-- concat is the inverse of partition +text = " hello world " +AsList(_, textparts) = (partition ' ' text) +concat textparts == text + + +'### Tests for string utilities + +AsList(_, stab) = "woof" +AsList(_, ttab) = "askdwoofljmohjmwwoofofasdfas" +find_first stab ttab +> (Just 4) + +AsList(_, s2) = ""::List Char +AsList(_, t2) = "askd" +find_first s2 t2 +> (Just 0) + +AsList(_, s3) = ""::List Char +AsList(_, t3) = ""::List Char +find_first s3 t3 +> Nothing + +split_at [1, 2, 4, 5] (4@_) +> ((AsList 4 [1, 2, 4, 5]), (AsList 0 [])) + +unsafe_replace_at(ttab, "zazazazaz", 0, 0@_) +> "zazazazazaskdwoofljmohjmwwoofofasdfas" + +replace_all("this is a list with lists of lists", "list", "pot") +> "this is a pot with pots of pots" + +replace_all("this is a list with lists of lists", "", "pot") +> "this is a list with lists of lists" + +replace_all("this is a list with lists of lists", "list", "") +> "this is a with s of s" + +replace_all("this is a list with lists of lists", "asdf", "pot") +> "this is a list with lists of lists" + +replace_all("this is a list with lists of lists", "", "") +> "this is a list with lists of lists" + +replace_all("", "list", "pot") +> "" + diff --git a/tests/set-tests.dx b/tests/set-tests.dx index dd40db9d3..4fa6f4259 100644 --- a/tests/set-tests.dx +++ b/tests/set-tests.dx @@ -78,3 +78,39 @@ setix : Person = from_just $ member "Bob" names2 setix2 : Person = from_just $ member "Charlie" names2 :p setix2 > Element(2) + + +'#### Multiset tests + +-- check order invariance. +:p (to_multiset ["Bob", "Alice", "Bob", "Charlie"]) == (to_multiset ["Charlie", "Bob", "Alice", "Bob"]) +> True + +-- check counts matter for equality. +:p (to_multiset ["Bob", "Alice", "Bob", "Charlie"]) == (to_multiset ["Charlie", "Bob", "Alice"]) +> False + +multiset1 = to_multiset ["Xeno", "Alice", "Bob", "Bob"] +multiset2 = to_multiset ["Bob", "Xeno", "Charlie", "Xeno", "Alice"] + +:p multiset1 == multiset2 +> False + +:p multiset1 + multiset2 +> (UnsafeAsMultiset 4 [("Alice", 2), ("Bob", 3), ("Charlie", 1), ("Xeno", 3)]) + +:p multiset1 == (multiset1 + multiset1) +> False + +'#### Empty multiset tests + +emptymultiset = to_multiset ([]::(Fin 0)=>String) + +:p emptymultiset == emptymultiset +> True + +:p emptymultiset == (emptymultiset + emptymultiset) +> True + +:p multiset1 == (multiset1 + emptymultiset) +> True diff --git a/tests/sort-tests.dx b/tests/sort-tests.dx index 461a1b613..2bfc4422e 100644 --- a/tests/sort-tests.dx +++ b/tests/sort-tests.dx @@ -54,3 +54,9 @@ import sort :p is_sorted $ sort ["Charlie", "Alice", "Bob", "Aaron"] > True + +'### Test argsort +xs = ['1', '2', '3', '4', '9', '8', '7', '6'] +ixs = arg_sort xs +sort xs == for i. xs[ixs[i]] +> True