From 7d7b5bb0891efb40a8553f33d2eb3f7ca07dc2f6 Mon Sep 17 00:00:00 2001 From: Chris Penner Date: Tue, 30 Sep 2025 12:10:52 -0700 Subject: [PATCH 01/15] Add traversals and lenses for format types --- .../U/Codebase/Sqlite/Branch/Format.hs | 47 +++++++++++++++++++ .../U/Codebase/Sqlite/Causal.hs | 9 ++++ .../U/Codebase/Sqlite/Decl/Format.hs | 16 +++++++ .../U/Codebase/Sqlite/Entity.hs | 33 +++++++++++++ .../U/Codebase/Sqlite/Patch/Format.hs | 37 +++++++++++++++ .../U/Codebase/Sqlite/Term/Format.hs | 18 +++++++ 6 files changed, 160 insertions(+) diff --git a/codebase2/codebase-sqlite/U/Codebase/Sqlite/Branch/Format.hs b/codebase2/codebase-sqlite/U/Codebase/Sqlite/Branch/Format.hs index 2a2300329f..99bba447a9 100644 --- a/codebase2/codebase-sqlite/U/Codebase/Sqlite/Branch/Format.hs +++ b/codebase2/codebase-sqlite/U/Codebase/Sqlite/Branch/Format.hs @@ -4,9 +4,18 @@ module U.Codebase.Sqlite.Branch.Format HashBranchFormat, BranchLocalIds, BranchLocalIds' (..), + branchLocalIdsText_, + branchLocalIdsDefn_, + branchLocalIdsPatch_, + branchLocalIdsChildren_, HashBranchLocalIds, SyncBranchFormat, SyncBranchFormat' (..), + syncBranchFormatTexts_, + syncBranchFormatDefns_, + syncBranchFormatPatches_, + syncBranchFormatChildren_, + syncBranchFormatParents_, LocalBranchBytes (..), localToDbBranch, localToDbDiff, @@ -16,6 +25,7 @@ module U.Codebase.Sqlite.Branch.Format ) where +import Control.Lens import Data.Vector (Vector) import Data.Vector qualified as Vector import U.Codebase.HashTags @@ -103,6 +113,18 @@ data BranchLocalIds' t d p c = LocalIds } deriving (Show, Eq) +branchLocalIdsText_ :: Traversal (BranchLocalIds' t d p c) (BranchLocalIds' t' d p c) t t' +branchLocalIdsText_ f (LocalIds t d p c) = LocalIds <$> traverse f t <*> pure d <*> pure p <*> pure c + +branchLocalIdsDefn_ :: Traversal (BranchLocalIds' t d p c) (BranchLocalIds' t d' p c) d d' +branchLocalIdsDefn_ f (LocalIds t d p c) = LocalIds <$> pure t <*> traverse f d <*> pure p <*> pure c + +branchLocalIdsPatch_ :: Traversal (BranchLocalIds' t d p c) (BranchLocalIds' t d p' c) p p' +branchLocalIdsPatch_ f (LocalIds t d p c) = LocalIds <$> pure t <*> pure d <*> traverse f p <*> pure c + +branchLocalIdsChildren_ :: Traversal (BranchLocalIds' t d p c) (BranchLocalIds' t d p c') c c' +branchLocalIdsChildren_ f (LocalIds t d p c) = LocalIds <$> pure t <*> pure d <*> pure p <*> traverse f c + -- | Bytes encoding a LocalBranch newtype LocalBranchBytes = LocalBranchBytes ByteString deriving (Show, Eq, Ord) @@ -112,6 +134,31 @@ data SyncBranchFormat' parent text defn patch child | SyncDiff parent (BranchLocalIds' text defn patch child) LocalBranchBytes deriving (Eq, Show) +syncBranchFormatTexts_ :: Traversal (SyncBranchFormat' parent text defn patch child) (SyncBranchFormat' parent text' defn patch child) text text' +syncBranchFormatTexts_ f = \case + SyncFull li bytes -> SyncFull <$> (li & branchLocalIdsText_ %%~ f) <*> pure bytes + SyncDiff parent li bytes -> SyncDiff parent <$> (li & branchLocalIdsText_ %%~ f) <*> pure bytes + +syncBranchFormatDefns_ :: Traversal (SyncBranchFormat' parent text defn patch child) (SyncBranchFormat' parent text defn' patch child) defn defn' +syncBranchFormatDefns_ f = \case + SyncFull li bytes -> SyncFull <$> (li & branchLocalIdsDefn_ %%~ f) <*> pure bytes + SyncDiff parent li bytes -> SyncDiff parent <$> (li & branchLocalIdsDefn_ %%~ f) <*> pure bytes + +syncBranchFormatPatches_ :: Traversal (SyncBranchFormat' parent text defn patch child) (SyncBranchFormat' parent text defn patch' child) patch patch' +syncBranchFormatPatches_ f = \case + SyncFull li bytes -> SyncFull <$> (li & branchLocalIdsPatch_ %%~ f) <*> pure bytes + SyncDiff parent li bytes -> SyncDiff parent <$> (li & branchLocalIdsPatch_ %%~ f) <*> pure bytes + +syncBranchFormatChildren_ :: Traversal (SyncBranchFormat' parent text defn patch child) (SyncBranchFormat' parent text defn patch child') child child' +syncBranchFormatChildren_ f = \case + SyncFull li bytes -> SyncFull <$> (li & branchLocalIdsChildren_ %%~ f) <*> pure bytes + SyncDiff parent li bytes -> SyncDiff parent <$> (li & branchLocalIdsChildren_ %%~ f) <*> pure bytes + +syncBranchFormatParents_ :: Traversal (SyncBranchFormat' parent text defn patch child) (SyncBranchFormat' parent' text defn patch child) parent parent' +syncBranchFormatParents_ f = \case + SyncFull li bytes -> pure $ SyncFull li bytes + SyncDiff parent li bytes -> SyncDiff <$> f parent <*> pure li <*> pure bytes + type SyncBranchFormat = SyncBranchFormat' BranchObjectId TextId ObjectId PatchObjectId (BranchObjectId, CausalHashId) localToBranch :: (Ord t, Ord d) => BranchLocalIds' t d p c -> LocalBranch -> (Branch.Full.Branch' t d p c) diff --git a/codebase2/codebase-sqlite/U/Codebase/Sqlite/Causal.hs b/codebase2/codebase-sqlite/U/Codebase/Sqlite/Causal.hs index 87f532bf25..b8b35f1555 100644 --- a/codebase2/codebase-sqlite/U/Codebase/Sqlite/Causal.hs +++ b/codebase2/codebase-sqlite/U/Codebase/Sqlite/Causal.hs @@ -3,9 +3,12 @@ module U.Codebase.Sqlite.Causal GDbCausal (..), SyncCausalFormat, SyncCausalFormat' (..), + syncCausalFormatCausalHash_, + syncCausalFormatValueHash_, ) where +import Control.Lens import Data.Vector (Vector) import U.Codebase.Sqlite.DbId (BranchHashId, CausalHashId) import Unison.Prelude @@ -24,4 +27,10 @@ data SyncCausalFormat' causalHash valueHash = SyncCausalFormat } deriving stock (Eq, Show) +syncCausalFormatCausalHash_ :: Traversal (SyncCausalFormat' causalHash valueHash) (SyncCausalFormat' causalHash' valueHash) causalHash causalHash' +syncCausalFormatCausalHash_ f (SyncCausalFormat v p) = SyncCausalFormat v <$> traverse f p + +syncCausalFormatValueHash_ :: Lens (SyncCausalFormat' causalHash valueHash) (SyncCausalFormat' causalHash valueHash') valueHash valueHash' +syncCausalFormatValueHash_ f (SyncCausalFormat v p) = (\v' -> SyncCausalFormat v' p) <$> f v + type SyncCausalFormat = SyncCausalFormat' CausalHashId BranchHashId diff --git a/codebase2/codebase-sqlite/U/Codebase/Sqlite/Decl/Format.hs b/codebase2/codebase-sqlite/U/Codebase/Sqlite/Decl/Format.hs index 5752d2dd87..6e492f5183 100644 --- a/codebase2/codebase-sqlite/U/Codebase/Sqlite/Decl/Format.hs +++ b/codebase2/codebase-sqlite/U/Codebase/Sqlite/Decl/Format.hs @@ -2,11 +2,13 @@ module U.Codebase.Sqlite.Decl.Format where +import Control.Lens import Data.Vector (Vector) import U.Codebase.Decl (DeclR) import U.Codebase.Reference (Reference') import U.Codebase.Sqlite.DbId (ObjectId, TextId) import U.Codebase.Sqlite.LocalIds (LocalDefnId, LocalIds', LocalTextId) +import U.Codebase.Sqlite.LocalIds qualified as LocalIds import U.Codebase.Sqlite.Symbol (Symbol) import U.Codebase.Type qualified as Type import U.Core.ABT qualified as ABT @@ -38,10 +40,24 @@ data SyncDeclFormat' t d = SyncDecl (SyncLocallyIndexedComponent' t d) deriving stock (Eq, Show) +syncDeclFormatTexts_ :: Traversal (SyncDeclFormat' t d) (SyncDeclFormat' t' d) t t' +syncDeclFormatTexts_ f (SyncDecl c) = SyncDecl <$> syncLocallyIndexedComponentTexts_ f c + +syncDeclFormatDefns_ :: Traversal (SyncDeclFormat' t d) (SyncDeclFormat' t d') d d' +syncDeclFormatDefns_ f (SyncDecl c) = SyncDecl <$> syncLocallyIndexedComponentDefns_ f c + newtype SyncLocallyIndexedComponent' t d = SyncLocallyIndexedComponent (Vector (LocalIds' t d, ByteString)) deriving stock (Eq, Show) +syncLocallyIndexedComponentTexts_ :: Traversal (SyncLocallyIndexedComponent' t d) (SyncLocallyIndexedComponent' t' d) t t' +syncLocallyIndexedComponentTexts_ f (SyncLocallyIndexedComponent v) = + SyncLocallyIndexedComponent <$> (v & traversed . _1 . LocalIds.t_ %%~ f) + +syncLocallyIndexedComponentDefns_ :: Traversal (SyncLocallyIndexedComponent' t d) (SyncLocallyIndexedComponent' t d') d d' +syncLocallyIndexedComponentDefns_ f (SyncLocallyIndexedComponent v) = + SyncLocallyIndexedComponent <$> (v & traversed . _1 . LocalIds.h_ %%~ f) + -- [OldDecl] ==map==> [NewDecl] ==number==> [(NewDecl, Int)] ==sort==> [(NewDecl, Int)] ==> permutation is map snd of that -- type List a = Nil | Cons (List a) diff --git a/codebase2/codebase-sqlite/U/Codebase/Sqlite/Entity.hs b/codebase2/codebase-sqlite/U/Codebase/Sqlite/Entity.hs index 92cbb58828..e0feea15b6 100644 --- a/codebase2/codebase-sqlite/U/Codebase/Sqlite/Entity.hs +++ b/codebase2/codebase-sqlite/U/Codebase/Sqlite/Entity.hs @@ -1,5 +1,6 @@ module U.Codebase.Sqlite.Entity where +import Control.Lens import U.Codebase.Sqlite.Branch.Format qualified as Namespace import U.Codebase.Sqlite.Causal qualified as Causal import U.Codebase.Sqlite.DbId (BranchHashId, BranchObjectId, CausalHashId, HashId, ObjectId, PatchObjectId, TextId) @@ -33,3 +34,35 @@ entityType = \case N _ -> NamespaceType P _ -> PatchType C _ -> CausalType + +texts_ :: Traversal (SyncEntity' text hash defn patch branchh branch causal) (SyncEntity' text' hash defn patch branchh branch causal) text text' +texts_ f = \case + TC tcf -> TC <$> Term.syncTermFormatTexts_ f tcf + DC dcf -> DC <$> Decl.syncDeclFormatTexts_ f dcf + N ncf -> N <$> Namespace.syncBranchFormatTexts_ f ncf + P pcf -> P <$> Patch.syncPatchFormatTexts_ f pcf + C ccf -> pure (C ccf) + +hashes_ :: Traversal (SyncEntity' text hash defn patch branchh branch causal) (SyncEntity' text hash' defn patch branchh branch causal) hash hash' +hashes_ f = \case + TC tcf -> pure (TC tcf) + DC dcf -> pure (DC dcf) + N ncf -> pure (N ncf) + P pcf -> P <$> Patch.syncPatchFormatHashes_ f pcf + C ccf -> pure (C ccf) + +defns_ :: Traversal (SyncEntity' text hash defn patch branchh branch causal) (SyncEntity' text hash defn' patch branchh branch causal) defn defn' +defns_ f = \case + TC tcf -> TC <$> Term.syncTermFormatDefns_ f tcf + DC dcf -> DC <$> Decl.syncDeclFormatDefns_ f dcf + N ncf -> N <$> Namespace.syncBranchFormatDefns_ f ncf + P pcf -> P <$> Patch.syncPatchFormatDefns_ f pcf + C ccf -> pure (C ccf) + +patches_ :: Traversal (SyncEntity' text hash defn patch branchh branch causal) (SyncEntity' text hash defn patch' branchh branch causal) patch patch' +patches_ f = \case + TC tcf -> pure (TC tcf) + DC dcf -> pure (DC dcf) + N ncf -> N <$> Namespace.syncBranchFormatPatches_ f ncf + P pcf -> P <$> Patch.syncPatchFormatParents_ f pcf + C ccf -> pure (C ccf) diff --git a/codebase2/codebase-sqlite/U/Codebase/Sqlite/Patch/Format.hs b/codebase2/codebase-sqlite/U/Codebase/Sqlite/Patch/Format.hs index 452df27904..34ad6d855f 100644 --- a/codebase2/codebase-sqlite/U/Codebase/Sqlite/Patch/Format.hs +++ b/codebase2/codebase-sqlite/U/Codebase/Sqlite/Patch/Format.hs @@ -2,9 +2,16 @@ module U.Codebase.Sqlite.Patch.Format ( PatchFormat (..), PatchLocalIds, PatchLocalIds' (..), + patchLocalIdsTexts_, + patchLocalIdsHashes_, + patchLocalIdsDefns_, HashPatchLocalIds, SyncPatchFormat, SyncPatchFormat' (..), + syncPatchFormatParents_, + syncPatchFormatTexts_, + syncPatchFormatHashes_, + syncPatchFormatDefns_, applyPatchDiffs, localPatchToPatch, localPatchToPatch', @@ -13,6 +20,7 @@ module U.Codebase.Sqlite.Patch.Format ) where +import Control.Lens import Data.Map.Strict qualified as Map import Data.Set qualified as Set import Data.Vector (Vector) @@ -42,6 +50,15 @@ data PatchLocalIds' t h d = LocalIds } deriving stock (Eq, Show) +patchLocalIdsTexts_ :: Traversal (PatchLocalIds' t h d) (PatchLocalIds' t' h d) t t' +patchLocalIdsTexts_ f (LocalIds t h d) = LocalIds <$> traverse f t <*> pure h <*> pure d + +patchLocalIdsHashes_ :: Traversal (PatchLocalIds' t h d) (PatchLocalIds' t h' d) h h' +patchLocalIdsHashes_ f (LocalIds t h d) = LocalIds <$> pure t <*> traverse f h <*> pure d + +patchLocalIdsDefns_ :: Traversal (PatchLocalIds' t h d) (PatchLocalIds' t h d') d d' +patchLocalIdsDefns_ f (LocalIds t h d) = LocalIds <$> pure t <*> pure h <*> traverse f d + type SyncPatchFormat = SyncPatchFormat' PatchObjectId TextId HashId ObjectId data SyncPatchFormat' parent text hash defn @@ -50,6 +67,26 @@ data SyncPatchFormat' parent text hash defn SyncDiff parent (PatchLocalIds' text hash defn) ByteString deriving stock (Eq, Show) +syncPatchFormatParents_ :: Traversal (SyncPatchFormat' p text hash defn) (SyncPatchFormat' p' text hash defn) p p' +syncPatchFormatParents_ f = \case + (SyncDiff p li b) -> SyncDiff <$> f p <*> pure li <*> pure b + (SyncFull li b) -> SyncFull <$> pure li <*> pure b + +syncPatchFormatTexts_ :: Traversal (SyncPatchFormat' p text hash defn) (SyncPatchFormat' p text' hash defn) text text' +syncPatchFormatTexts_ f = \case + (SyncDiff p li b) -> SyncDiff p <$> (li & patchLocalIdsTexts_ %%~ f) <*> pure b + (SyncFull li b) -> SyncFull <$> (li & patchLocalIdsTexts_ %%~ f) <*> pure b + +syncPatchFormatHashes_ :: Traversal (SyncPatchFormat' p text hash defn) (SyncPatchFormat' p text hash' defn) hash hash' +syncPatchFormatHashes_ f = \case + (SyncDiff p li b) -> SyncDiff p <$> (li & patchLocalIdsHashes_ %%~ f) <*> pure b + (SyncFull li b) -> SyncFull <$> (li & patchLocalIdsHashes_ %%~ f) <*> pure b + +syncPatchFormatDefns_ :: Traversal (SyncPatchFormat' p text hash defn) (SyncPatchFormat' p text hash defn') defn defn' +syncPatchFormatDefns_ f = \case + (SyncDiff p li b) -> SyncDiff p <$> (li & patchLocalIdsDefns_ %%~ f) <*> pure b + (SyncFull li b) -> SyncFull <$> (li & patchLocalIdsDefns_ %%~ f) <*> pure b + -- | Apply a list of patch diffs to a patch, left to right. applyPatchDiffs :: Patch -> [PatchDiff] -> Patch applyPatchDiffs = diff --git a/codebase2/codebase-sqlite/U/Codebase/Sqlite/Term/Format.hs b/codebase2/codebase-sqlite/U/Codebase/Sqlite/Term/Format.hs index f06fc70ec3..8e5a722dcc 100644 --- a/codebase2/codebase-sqlite/U/Codebase/Sqlite/Term/Format.hs +++ b/codebase2/codebase-sqlite/U/Codebase/Sqlite/Term/Format.hs @@ -2,6 +2,7 @@ module U.Codebase.Sqlite.Term.Format where +import Control.Lens import Data.ByteString (ByteString) import Data.Text (Text) import Data.Vector (Vector) @@ -9,6 +10,7 @@ import U.Codebase.Reference (Reference') import U.Codebase.Referent (Referent') import U.Codebase.Sqlite.DbId (ObjectId, TextId) import U.Codebase.Sqlite.LocalIds (LocalDefnId, LocalIds', LocalTextId, WatchLocalIds) +import U.Codebase.Sqlite.LocalIds qualified as LocalIds import U.Codebase.Sqlite.Reference qualified as Sqlite import U.Codebase.Sqlite.Symbol (Symbol) import U.Codebase.Term qualified as Term @@ -51,6 +53,14 @@ newtype SyncLocallyIndexedComponent' t d = SyncLocallyIndexedComponent (Vector (LocalIds' t d, ByteString)) deriving stock (Eq, Show) +syncLocallyIndexedComponentTexts_ :: Traversal (SyncLocallyIndexedComponent' t d) (SyncLocallyIndexedComponent' t' d) t t' +syncLocallyIndexedComponentTexts_ f (SyncLocallyIndexedComponent v) = + SyncLocallyIndexedComponent <$> (v & traversed . _1 . LocalIds.t_ %%~ f) + +syncLocallyIndexedComponentDefns :: Traversal (SyncLocallyIndexedComponent' t d) (SyncLocallyIndexedComponent' t d') d d' +syncLocallyIndexedComponentDefns f (SyncLocallyIndexedComponent v) = + SyncLocallyIndexedComponent <$> (v & traversed . _1 . LocalIds.h_ %%~ f) + {- message = "hello, world" -> ABT { ... { Term.F.Text "hello, world" } } -> hashes to (#abc, 0) program = printLine message -> ABT { ... { Term.F.App (ReferenceBuiltin ##io.PrintLine) (Reference #abc 0) } } -> hashes to (#def, 0) @@ -130,6 +140,14 @@ type SyncTermFormat = SyncTermFormat' TextId ObjectId data SyncTermFormat' t d = SyncTerm (SyncLocallyIndexedComponent' t d) deriving stock (Eq, Show) +syncTermFormatTexts_ :: Traversal (SyncTermFormat' t d) (SyncTermFormat' t' d) t t' +syncTermFormatTexts_ f (SyncTerm slic) = + SyncTerm <$> (slic & syncLocallyIndexedComponentTexts_ %%~ f) + +syncTermFormatDefns_ :: Traversal (SyncTermFormat' t d) (SyncTermFormat' t d') d d' +syncTermFormatDefns_ f (SyncTerm slic) = + SyncTerm <$> (slic & syncLocallyIndexedComponentDefns %%~ f) + data WatchResultFormat = WatchResult WatchLocalIds Term From b880ac9380651ad536557bbc728e0af16b173bf0 Mon Sep 17 00:00:00 2001 From: Chris Penner Date: Tue, 30 Sep 2025 12:28:39 -0700 Subject: [PATCH 02/15] More traversals --- .../U/Codebase/Sqlite/Entity.hs | 25 +++++++++++++++++++ 1 file changed, 25 insertions(+) diff --git a/codebase2/codebase-sqlite/U/Codebase/Sqlite/Entity.hs b/codebase2/codebase-sqlite/U/Codebase/Sqlite/Entity.hs index e0feea15b6..b51f1063a1 100644 --- a/codebase2/codebase-sqlite/U/Codebase/Sqlite/Entity.hs +++ b/codebase2/codebase-sqlite/U/Codebase/Sqlite/Entity.hs @@ -66,3 +66,28 @@ patches_ f = \case N ncf -> N <$> Namespace.syncBranchFormatPatches_ f ncf P pcf -> P <$> Patch.syncPatchFormatParents_ f pcf C ccf -> pure (C ccf) + +branchHashes_ :: Traversal (SyncEntity' text hash defn patch branchh branch causal) (SyncEntity' text hash defn patch branchh' branch causal) branchh branchh' +branchHashes_ f = \case + TC tcf -> pure (TC tcf) + DC dcf -> pure (DC dcf) + N ncf -> pure (N ncf) + P pcf -> pure (P pcf) + C ccf -> C <$> Causal.syncCausalFormatValueHash_ f ccf + +branches_ :: Traversal (SyncEntity' text hash defn patch branchh branch causal) (SyncEntity' text hash defn patch branchh branch' causal) branch branch' +branches_ f = \case + TC tcf -> pure (TC tcf) + DC dcf -> pure (DC dcf) + N ncf -> + ( case ncf of + Namespace.SyncFull li bytes -> Namespace.SyncFull <$> (li & Namespace.branchLocalIdsChildren_ . _1 %%~ f) <*> pure bytes + Namespace.SyncDiff parent li bytes -> + Namespace.SyncDiff + <$> (f parent) + <*> (li & Namespace.branchLocalIdsChildren_ . _1 %%~ f) + <*> pure bytes + ) + <&> N + P pcf -> pure (P pcf) + C ccf -> pure (C ccf) From ccfeb0beda2b711915902e1c28e1ffe3a76364fd Mon Sep 17 00:00:00 2001 From: Chris Penner Date: Tue, 30 Sep 2025 12:46:56 -0700 Subject: [PATCH 03/15] Add missing traversal --- .../codebase-sqlite/U/Codebase/Sqlite/Entity.hs | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/codebase2/codebase-sqlite/U/Codebase/Sqlite/Entity.hs b/codebase2/codebase-sqlite/U/Codebase/Sqlite/Entity.hs index b51f1063a1..1d034627bc 100644 --- a/codebase2/codebase-sqlite/U/Codebase/Sqlite/Entity.hs +++ b/codebase2/codebase-sqlite/U/Codebase/Sqlite/Entity.hs @@ -91,3 +91,20 @@ branches_ f = \case <&> N P pcf -> pure (P pcf) C ccf -> pure (C ccf) + +causalHashes_ :: Traversal (SyncEntity' text hash defn patch branchh branch causal) (SyncEntity' text hash defn patch branchh branch causal') causal causal' +causalHashes_ f = \case + TC tcf -> pure (TC tcf) + DC dcf -> pure (DC dcf) + N ncf -> + ( case ncf of + Namespace.SyncFull li bytes -> Namespace.SyncFull <$> (li & Namespace.branchLocalIdsChildren_ . _2 %%~ f) <*> pure bytes + Namespace.SyncDiff parent li bytes -> + Namespace.SyncDiff + <$> (pure parent) + <*> (li & Namespace.branchLocalIdsChildren_ . _2 %%~ f) + <*> pure bytes + ) + <&> N + P pcf -> pure (P pcf) + C ccf -> C <$> Causal.syncCausalFormatCausalHash_ f ccf From 2e657232c2cb1fe33ef9d06ae471a574288e03dd Mon Sep 17 00:00:00 2001 From: Chris Penner Date: Tue, 30 Sep 2025 13:45:33 -0700 Subject: [PATCH 04/15] Add Hashable instance for Hash32 --- lib/unison-hash/package.yaml | 1 + lib/unison-hash/src/Unison/Hash32.hs | 3 ++- lib/unison-hash/unison-hash.cabal | 3 ++- 3 files changed, 5 insertions(+), 2 deletions(-) diff --git a/lib/unison-hash/package.yaml b/lib/unison-hash/package.yaml index 523de61905..92be9dba76 100644 --- a/lib/unison-hash/package.yaml +++ b/lib/unison-hash/package.yaml @@ -10,6 +10,7 @@ dependencies: - deepseq - unison-prelude - unison-util-base32hex + - hashable library: source-dirs: src diff --git a/lib/unison-hash/src/Unison/Hash32.hs b/lib/unison-hash/src/Unison/Hash32.hs index 97e7c201ed..4b74d89631 100644 --- a/lib/unison-hash/src/Unison/Hash32.hs +++ b/lib/unison-hash/src/Unison/Hash32.hs @@ -18,6 +18,7 @@ module Unison.Hash32 ) where +import Data.Hashable (Hashable) import U.Util.Base32Hex (Base32Hex (..)) import Unison.Hash (Hash) import Unison.Hash qualified as Hash @@ -30,7 +31,7 @@ import Unison.Prelude -- * @unison-util-base32hex-orphans-aeson@ -- * @unison-util-base32hex-orphans-sqlite@ newtype Hash32 = UnsafeFromBase32Hex Base32Hex - deriving (Eq, Ord, Show) via (Text) + deriving (Eq, Ord, Show, Hashable) via (Text) instance From Hash32 Text where from = toText diff --git a/lib/unison-hash/unison-hash.cabal b/lib/unison-hash/unison-hash.cabal index a1a32f0b1b..d68ba201ba 100644 --- a/lib/unison-hash/unison-hash.cabal +++ b/lib/unison-hash/unison-hash.cabal @@ -1,6 +1,6 @@ cabal-version: 1.12 --- This file has been generated from package.yaml by hpack version 0.36.0. +-- This file has been generated from package.yaml by hpack version 0.38.1. -- -- see: https://github.com/sol/hpack @@ -54,6 +54,7 @@ library base , bytestring , deepseq + , hashable , unison-prelude , unison-util-base32hex default-language: Haskell2010 From 916dfbad1ad8567ffd95944b5fe5f2ba932d137a Mon Sep 17 00:00:00 2001 From: Chris Penner Date: Wed, 1 Oct 2025 09:20:44 -0700 Subject: [PATCH 05/15] Add SyncV3 Types --- unison-cli/src/Unison/Share/SyncV3.hs | 6 + unison-cli/unison-cli.cabal | 3 +- unison-share-api/package.yaml | 1 + unison-share-api/src/Unison/SyncV3/Types.hs | 342 ++++++++++++++++++++ unison-share-api/unison-share-api.cabal | 4 +- 5 files changed, 354 insertions(+), 2 deletions(-) create mode 100644 unison-cli/src/Unison/Share/SyncV3.hs create mode 100644 unison-share-api/src/Unison/SyncV3/Types.hs diff --git a/unison-cli/src/Unison/Share/SyncV3.hs b/unison-cli/src/Unison/Share/SyncV3.hs new file mode 100644 index 0000000000..adfb08c2bf --- /dev/null +++ b/unison-cli/src/Unison/Share/SyncV3.hs @@ -0,0 +1,6 @@ +module Unison.Share.SyncV2 ( + syncFromCodeserver, + ) +where + + diff --git a/unison-cli/unison-cli.cabal b/unison-cli/unison-cli.cabal index 75daa4c176..6433469e07 100644 --- a/unison-cli/unison-cli.cabal +++ b/unison-cli/unison-cli.cabal @@ -1,6 +1,6 @@ cabal-version: 1.12 --- This file has been generated from package.yaml by hpack version 0.36.0. +-- This file has been generated from package.yaml by hpack version 0.38.1. -- -- see: https://github.com/sol/hpack @@ -163,6 +163,7 @@ library Unison.Share.Sync Unison.Share.Sync.Types Unison.Share.SyncV2 + Unison.Share.SyncV3 Unison.Util.HTTP Unison.Version hs-source-dirs: diff --git a/unison-share-api/package.yaml b/unison-share-api/package.yaml index 879fce940c..4a258ea77e 100644 --- a/unison-share-api/package.yaml +++ b/unison-share-api/package.yaml @@ -65,6 +65,7 @@ library: - wai - wai-cors - warp + - websockets - yaml tests: diff --git a/unison-share-api/src/Unison/SyncV3/Types.hs b/unison-share-api/src/Unison/SyncV3/Types.hs new file mode 100644 index 0000000000..f698c48045 --- /dev/null +++ b/unison-share-api/src/Unison/SyncV3/Types.hs @@ -0,0 +1,342 @@ +module Unison.SyncV3.Types + ( InitMsg (..), + EntityRequestMsg (..), + FromReceiverMessage (..), + FromEmitterMessage (..), + MsgOrError (..), + SyncError (..), + Entity (..), + EntityKind (..), + EntityDepth (..), + HashMappings (..), + HashTag (..), + ) +where + +import Codec.Serialise qualified as CBOR +import Control.Lens hiding ((.=)) +import Data.Aeson +import Data.Aeson qualified as Aeson +import Data.ByteString qualified as BS +import Data.ByteString.Lazy.Char8 qualified as BL +import Data.Int (Int64) +import Data.Map (Map) +import Data.Set (Set) +import Data.Set qualified as Set +import Data.Text (Text) +import Network.WebSockets (WebSocketsData) +import Network.WebSockets qualified as WS +import U.Codebase.Sqlite.Orphans () +import U.Codebase.Sqlite.TempEntity +import Unison.Hash32 (Hash32) +import Unison.Prelude (tShow) +import Unison.Server.Orphans () +import Unison.Util.Servant.CBOR qualified as CBOR + +data InitMsg authedHash = InitMsg + { initMsgClientVersion :: Text, + initMsgProjectId :: Text, + initMsgRootCausal :: authedHash, + initMsgRequestedDepth :: Maybe Int64 + } + deriving (Show, Eq) + +instance (ToJSON authedHash) => ToJSON (InitMsg authedHash) where + toJSON (InitMsg {initMsgClientVersion, initMsgProjectId, initMsgRootCausal, initMsgRequestedDepth}) = + object + [ "clientVersion" .= initMsgClientVersion, + "projectId" .= initMsgProjectId, + "rootCausal" .= initMsgRootCausal, + "requestedDepth" .= initMsgRequestedDepth + ] + +instance (FromJSON authedHash) => FromJSON (InitMsg authedHash) where + parseJSON = withObject "InitMsg" $ \o -> + InitMsg + <$> o .: "clientVersion" + <*> o .: "projectId" + <*> o .: "rootCausal" + <*> o .:? "requestedDepth" + +data EntityRequestMsg hash = EntityRequestMsg + { hashes :: [(EntityKind, hash)] + } + deriving (Show, Eq) + +instance (CBOR.Serialise sh) => CBOR.Serialise (EntityRequestMsg sh) where + encode (EntityRequestMsg {hashes}) = + CBOR.encode hashes + + decode = do + hashes <- CBOR.decode @[(EntityKind, sh)] + pure $ EntityRequestMsg {hashes} + +data FromReceiverMessageTag + = InitStreamTag + | EntityRequestTag + +instance CBOR.Serialise FromReceiverMessageTag where + encode = \case + InitStreamTag -> CBOR.encode (0 :: Int) + EntityRequestTag -> CBOR.encode (1 :: Int) + + decode = do + tag <- CBOR.decode @Int + case tag of + 0 -> pure InitStreamTag + 1 -> pure EntityRequestTag + _ -> fail $ "Unknown FromReceiverMessageTag: " <> show tag + +-- A message sent from the downloader to the emitter. +data FromReceiverMessage ah hash + = InitStream (InitMsg ah) + | EntityRequest (EntityRequestMsg hash) + deriving (Show, Eq) + +instance (ToJSON ah, FromJSON ah) => CBOR.Serialise (InitMsg ah) where + encode msg = do + -- This is dumb, but there's currently no reasonable way to encode a heterogenous Map + -- using Haskell's CBOR library :| + -- + -- See https://github.com/well-typed/cborg/issues/369 + CBOR.encode $ Aeson.encode msg + + decode = do + bs <- CBOR.decode @BL.ByteString + case Aeson.eitherDecode bs of + Left err -> fail $ "Error decoding InitMsg from JSON: " <> err + Right msg -> pure msg + +instance (CBOR.Serialise h, ToJSON ah, FromJSON ah) => CBOR.Serialise (FromReceiverMessage ah h) where + encode = \case + InitStream initMsg -> + CBOR.encode InitStreamTag + <> CBOR.encode initMsg + EntityRequest msg -> + CBOR.encode EntityRequestTag + <> CBOR.encode msg + decode = do + tag <- CBOR.decode @FromReceiverMessageTag + case tag of + InitStreamTag -> InitStream <$> CBOR.decode @(InitMsg ah) + EntityRequestTag -> EntityRequest <$> CBOR.decode @(EntityRequestMsg h) + +data SyncError + = InitializationError Text + | UnexpectedMessage BL.ByteString + | EncodingFailure Text + | -- The caller asked for a Hash they shouldn't have access to. + ForbiddenEntityRequest (Set (EntityKind, Hash32)) + +instance CBOR.Serialise SyncError where + encode = \case + InitializationError msg -> + CBOR.encode (0 :: Int) <> CBOR.encode msg + UnexpectedMessage msg -> + CBOR.encode (1 :: Int) <> CBOR.encode (BL.toStrict msg) + EncodingFailure msg -> + CBOR.encode (2 :: Int) <> CBOR.encode msg + ForbiddenEntityRequest hashes -> + CBOR.encode (3 :: Int) <> CBOR.encode hashes + + decode = do + tag <- CBOR.decode @Int + case tag of + 0 -> InitializationError <$> CBOR.decode + 1 -> do + bs <- CBOR.decode @BS.ByteString + pure $ UnexpectedMessage (BL.fromStrict bs) + 2 -> EncodingFailure <$> CBOR.decode + 3 -> ForbiddenEntityRequest . Set.fromList <$> CBOR.decode + _ -> fail $ "Unknown SyncError tag: " <> show tag + +-- A message sent from the emitter to the downloader. +data FromEmitterMessage hash text + = ErrorMsg SyncError + | -- | HashMappingsMsg (HashMappings hash smallHash) + EntityMsg (Entity hash text) + +instance (CBOR.Serialise hash, CBOR.Serialise text) => WebSocketsData (FromEmitterMessage hash text) where + fromLazyByteString bytes = + CBOR.deserialiseOrFailCBORBytes (CBOR.CBORBytes bytes) + & either (\err -> ErrorMsg . EncodingFailure $ "Error decoding CBOR message from bytes: " <> tShow err) id + + toLazyByteString = CBOR.serialise + + fromDataMessage dm = do + case dm of + WS.Text bytes _ -> WS.fromLazyByteString bytes + WS.Binary bytes -> WS.fromLazyByteString bytes + +data HashMappings hash smallHash = HashMappings + { hashMappings :: Map smallHash hash + } + +data EntityKind + = CausalEntity + | NamespaceEntity + | TermEntity + | TypeEntity + | PatchEntity + deriving stock (Show, Eq, Ord) + +instance CBOR.Serialise EntityKind where + encode = \case + CausalEntity -> CBOR.encode (0 :: Int) + NamespaceEntity -> CBOR.encode (1 :: Int) + TermEntity -> CBOR.encode (2 :: Int) + TypeEntity -> CBOR.encode (3 :: Int) + PatchEntity -> CBOR.encode (4 :: Int) + + decode = do + tag <- CBOR.decode @Int + case tag of + 0 -> pure CausalEntity + 1 -> pure NamespaceEntity + 2 -> pure TermEntity + 3 -> pure TypeEntity + 4 -> pure PatchEntity + _ -> fail $ "Unknown EntityKind tag: " <> show tag + +-- | The number of _levels_ of dependencies an entity has, +-- this has no real semantic meaning on its own, but provides the +-- property that out of a given set of synced entities, if you process +-- them in order of increasing EntityDepth, you will always have +-- processed an entity's dependencies before you see the entity itself. +newtype EntityDepth = EntityDepth {unEntityDepth :: Int64} + deriving (Show, Eq, Ord) + deriving newtype (CBOR.Serialise) + +data Entity hash text = Entity + { entityHash :: hash, + entityKind :: EntityKind, + entityDepth :: EntityDepth, + entityData :: CBOR.CBORBytes TempEntity + } + +-- entityTexts_ :: Traversal (Entity smallHash text) (Entity smallHash text') text text' +-- entityTexts_ f (Entity {entityData, ..}) = +-- (\entityData' -> Entity {entityData = entityData', ..}) <$> Entity.texts_ f entityData + +-- entityHashesSetter_ :: (Monad m) => LensLike m (Entity smallHash text) (Entity smallHash' text) smallHash smallHash' +-- entityHashesSetter_ f (Entity {entityHash, entityData, ..}) = +-- (\entityHash' entityData' -> Entity {entityHash = entityHash', entityData = entityData', ..}) +-- <$> f entityHash +-- <*> ( entityData +-- & Entity.hashes_ f +-- >>= Entity.defns_ f +-- >>= Entity.patches_ f +-- >>= Entity.branchHashes_ f +-- >>= Entity.branches_ f +-- >>= Entity.causalHashes_ f +-- ) + +-- -- | It's technically possible to implement entityHashesGetter_ and entityHashesSetter_ +-- -- as a single Traversal, but it's a ton of extra unpacking/packing that's probably not worth +-- -- it. +-- entityHashesGetter_ :: Fold (Entity smallHash text) smallHash +-- entityHashesGetter_ f (Entity {entityHash, entityData}) = +-- phantom (f entityHash) +-- *> phantom (Entity.hashes_ f entityData) +-- *> phantom (Entity.defns_ f entityData) +-- *> phantom (Entity.patches_ f entityData) +-- *> phantom (Entity.branchHashes_ f entityData) +-- *> phantom (Entity.branches_ f entityData) +-- *> phantom (Entity.causalHashes_ f entityData) + +instance (CBOR.Serialise smallHash, CBOR.Serialise text) => CBOR.Serialise (Entity smallHash text) where + encode (Entity {entityHash, entityKind, entityDepth, entityData}) = + CBOR.encode entityHash + <> CBOR.encode entityKind + <> CBOR.encode entityDepth + <> CBOR.encode entityData + + decode = do + entityHash <- CBOR.decode @smallHash + entityKind <- CBOR.decode @EntityKind + entityDepth <- CBOR.decode @EntityDepth + entityData <- CBOR.decode @(CBOR.CBORBytes TempEntity) + + pure $ Entity {entityHash, entityKind, entityData, entityDepth} + +instance (Ord smallHash, CBOR.Serialise hash, CBOR.Serialise smallHash) => CBOR.Serialise (HashMappings hash smallHash) where + encode (HashMappings {hashMappings}) = + CBOR.encode hashMappings + + decode = do + hashMappings <- CBOR.decode @(Map smallHash hash) + pure $ HashMappings {hashMappings} + +instance (CBOR.Serialise hash, CBOR.Serialise text) => CBOR.Serialise (FromEmitterMessage hash text) where + encode = \case + ErrorMsg err -> CBOR.encode ErrorMsgTag <> CBOR.encode err + -- HashMappingsMsg msg -> CBOR.encode HashMappingsTag <> CBOR.encode msg + EntityMsg msg -> CBOR.encode EntityTag <> CBOR.encode msg + + decode = do + tag <- CBOR.decode @FromEmitterMessageTag + case tag of + ErrorMsgTag -> ErrorMsg <$> CBOR.decode + -- HashMappingsTag -> HashMappingsMsg <$> CBOR.decode + EntityTag -> EntityMsg <$> CBOR.decode + +data FromEmitterMessageTag + = ErrorMsgTag + | -- | HashMappingsTag + EntityTag + +instance CBOR.Serialise FromEmitterMessageTag where + encode = \case + ErrorMsgTag -> CBOR.encode (0 :: Int) + -- HashMappingsTag -> CBOR.encode (1 :: Int) + EntityTag -> CBOR.encode (2 :: Int) + + decode = do + tag <- CBOR.decode @Int + case tag of + 0 -> pure ErrorMsgTag + -- 1 -> pure HashMappingsTag + 2 -> pure EntityTag + _ -> fail $ "Unknown FromEmitterMessageTag: " <> show tag + +data MsgOrError err a + = Msg a + | Err err + +instance (CBOR.Serialise a, CBOR.Serialise err) => CBOR.Serialise (MsgOrError err a) where + encode = \case + Msg a -> CBOR.encode (0 :: Int) <> CBOR.encode a + Err e -> CBOR.encode (1 :: Int) <> CBOR.encode e + + decode = do + tag <- CBOR.decode @Int + case tag of + 0 -> Msg <$> CBOR.decode + 1 -> Err <$> CBOR.decode + _ -> fail $ "Unknown MsgOrError tag: " <> show tag + +instance (CBOR.Serialise sh, ToJSON ah, FromJSON ah) => WebSocketsData (MsgOrError SyncError (FromReceiverMessage ah sh)) where + fromLazyByteString bytes = + CBOR.deserialiseOrFailCBORBytes (CBOR.CBORBytes bytes) + & either (\err -> Err . EncodingFailure $ "Error decoding CBOR message from bytes: " <> tShow err) Msg + + toLazyByteString = CBOR.serialise + + fromDataMessage dm = do + case dm of + WS.Text bytes _ -> WS.fromLazyByteString bytes + WS.Binary bytes -> WS.fromLazyByteString bytes + +-- Application level compression of Hash references. +-- We can send a mapping of Hash <-> HashTag at the start of the stream, +-- and then use the smaller HashTag in all subsequent messages. +data HashTag = HashTag (EntityKind, Int64) + deriving (Show, Eq, Ord) + +instance CBOR.Serialise HashTag where + encode (HashTag (kind, idx)) = + CBOR.encode (kind, idx) + + decode = do + (kind, idx) <- CBOR.decode @(EntityKind, Int64) + pure $ HashTag (kind, idx) diff --git a/unison-share-api/unison-share-api.cabal b/unison-share-api/unison-share-api.cabal index dd8a117685..38bcc5395c 100644 --- a/unison-share-api/unison-share-api.cabal +++ b/unison-share-api/unison-share-api.cabal @@ -1,6 +1,6 @@ cabal-version: 1.12 --- This file has been generated from package.yaml by hpack version 0.37.0. +-- This file has been generated from package.yaml by hpack version 0.38.1. -- -- see: https://github.com/sol/hpack @@ -51,6 +51,7 @@ library Unison.Sync.Types Unison.SyncV2.API Unison.SyncV2.Types + Unison.SyncV3.Types Unison.Util.Find Unison.Util.Servant.CBOR hs-source-dirs: @@ -144,6 +145,7 @@ library , wai , wai-cors , warp + , websockets , yaml default-language: Haskell2010 From c6c38429546f67a0f342019b9145e32888b0b5a0 Mon Sep 17 00:00:00 2001 From: Chris Penner Date: Wed, 1 Oct 2025 15:28:58 -0700 Subject: [PATCH 06/15] Implementing SyncV3 client --- .../U/Codebase/Sqlite/Queries.hs | 26 +++ .../sql/020-add-sync-v3-temp-tables.sql | 14 ++ .../unison-codebase-sqlite.cabal | 3 +- unison-cli/package.yaml | 1 + unison-cli/src/Unison/Share/SyncV3.hs | 163 +++++++++++++++++- unison-cli/unison-cli.cabal | 1 + unison-share-api/package.yaml | 1 + unison-share-api/src/Unison/SyncV3/Types.hs | 99 ++++++----- .../src/Unison/Util/Websockets.hs | 79 +++++++++ unison-share-api/unison-share-api.cabal | 2 + 10 files changed, 344 insertions(+), 45 deletions(-) create mode 100644 codebase2/codebase-sqlite/sql/020-add-sync-v3-temp-tables.sql create mode 100644 unison-share-api/src/Unison/Util/Websockets.hs diff --git a/codebase2/codebase-sqlite/U/Codebase/Sqlite/Queries.hs b/codebase2/codebase-sqlite/U/Codebase/Sqlite/Queries.hs index 30e56300fd..c0a50122fe 100644 --- a/codebase2/codebase-sqlite/U/Codebase/Sqlite/Queries.hs +++ b/codebase2/codebase-sqlite/U/Codebase/Sqlite/Queries.hs @@ -218,9 +218,11 @@ module U.Codebase.Sqlite.Queries EntityLocation (..), entityExists, entityLocation, + entityLocationSyncV3, expectEntity, syncToTempEntity, insertTempEntity, + insertTempEntitySyncV3, saveTempEntityInMain, expectTempEntity, deleteTempEntity, @@ -254,6 +256,7 @@ module U.Codebase.Sqlite.Queries addUpdateBranchTable, addDerivedDependentsByDependencyIndex, addUpgradeBranchTable, + addSyncV3TempTables, -- ** schema version currentSchemaVersion, @@ -499,6 +502,10 @@ addUpgradeBranchTable :: Transaction () addUpgradeBranchTable = executeStatements $(embedProjectStringFile "sql/019-add-upgrade-branch-table.sql") +addSyncV3TempTables :: Transaction () +addSyncV3TempTables = + executeStatements $(embedProjectStringFile "sql/020-add-sync-v3-temp-tables.sql") + schemaVersion :: Transaction SchemaVersion schemaVersion = queryOneCol @@ -2232,6 +2239,16 @@ entityLocation hash = True -> Just EntityInTempStorage False -> Nothing +entityLocationSyncV3 :: Hash32 -> Transaction (Maybe EntityLocation) +entityLocationSyncV3 hash = + entityExists hash >>= \case + True -> pure (Just EntityInMainStorage) + False -> do + let theSql = [sql| SELECT EXISTS (SELECT 1 FROM syncv3_temp_entity WHERE entity_hash = :hash) |] + queryOneCol theSql <&> \case + True -> Just EntityInTempStorage + False -> Nothing + -- | Does this entity already exist in the database, i.e. in the `object` or `causal` table? entityExists :: Hash32 -> Transaction Bool entityExists hash = do @@ -2285,6 +2302,15 @@ insertTempEntity entityHash entity missingDependencies = do entityType = Entity.entityType entity +insertTempEntitySyncV3 :: Hash32 -> Text -> Hash32 -> Int32 -> ByteString -> Transaction () +insertTempEntitySyncV3 rootCausal entityKind entityHash entityDepth entityBlob = do + execute + [sql| + INSERT INTO syncv3_temp_entity (root_causal, entity_hash, entity_kind, entity_data, entity_depth) + VALUES (:rootCausal, :entityHash, :entityKind, :entityBlob, :entityDepth) + ON CONFLICT DO NOTHING + |] + -- | Delete a row from the `temp_entity` table, if it exists. deleteTempEntity :: Hash32 -> Transaction () deleteTempEntity hash = diff --git a/codebase2/codebase-sqlite/sql/020-add-sync-v3-temp-tables.sql b/codebase2/codebase-sqlite/sql/020-add-sync-v3-temp-tables.sql new file mode 100644 index 0000000000..243dc128ab --- /dev/null +++ b/codebase2/codebase-sqlite/sql/020-add-sync-v3-temp-tables.sql @@ -0,0 +1,14 @@ +-- Add a new table for storing entities which are currently being synced + +CREATE TABLE syncv3_temp_entity ( + root_causal INTEGER NOT NULL REFERENCES hash (id) ON DELETE CASCADE, + entity_hash TEXT NOT NULL, + entity_kind TEXT NOT NULL, + entity_data BLOB NOT NULL, + entity_depth INTEGER NOT NULL, + PRIMARY KEY (root_causal, entity_hash) +) WITHOUT ROWID; + +-- We _could_ add an index on (root_causal, entity_depth), since that's how we'll +-- be querying this table, but we only run the query exactly once per sync, so it's +-- probably faster to sort on query rather than maintaining the index. diff --git a/codebase2/codebase-sqlite/unison-codebase-sqlite.cabal b/codebase2/codebase-sqlite/unison-codebase-sqlite.cabal index 0e5780e7f9..923a04bfb1 100644 --- a/codebase2/codebase-sqlite/unison-codebase-sqlite.cabal +++ b/codebase2/codebase-sqlite/unison-codebase-sqlite.cabal @@ -1,6 +1,6 @@ cabal-version: 1.12 --- This file has been generated from package.yaml by hpack version 0.37.0. +-- This file has been generated from package.yaml by hpack version 0.38.1. -- -- see: https://github.com/sol/hpack @@ -29,6 +29,7 @@ extra-source-files: sql/017-add-update-branch-table.sql sql/018-add-derived-dependents-by-dependency-index.sql sql/019-add-upgrade-branch-table.sql + sql/020-add-sync-v3-temp-tables.sql sql/create.sql source-repository head diff --git a/unison-cli/package.yaml b/unison-cli/package.yaml index 4418105d7b..369119746c 100644 --- a/unison-cli/package.yaml +++ b/unison-cli/package.yaml @@ -104,6 +104,7 @@ library: - vector - wai - warp + - websockets - witch - witherable diff --git a/unison-cli/src/Unison/Share/SyncV3.hs b/unison-cli/src/Unison/Share/SyncV3.hs index adfb08c2bf..93ca059dbf 100644 --- a/unison-cli/src/Unison/Share/SyncV3.hs +++ b/unison-cli/src/Unison/Share/SyncV3.hs @@ -1,6 +1,165 @@ -module Unison.Share.SyncV2 ( - syncFromCodeserver, +module Unison.Share.SyncV3 + ( syncFromCodeserver, ) where +import Control.Monad.Reader +import Data.Int (Int32) +import Data.Maybe (fromMaybe) +import Data.Set (Set) +import Data.Set qualified as Set +import Data.Set.Lens qualified as Lens +import Data.Text (Text) +import GHC.Natural +import Ki qualified +import Network.WebSockets.Client qualified as WS +import Servant.Client qualified as Servant +import U.Codebase.HashTags +import U.Codebase.Sqlite.DbId +import U.Codebase.Sqlite.Entity qualified as Entity +import U.Codebase.Sqlite.Queries qualified as Q +import U.Codebase.Sqlite.TempEntity (TempEntity) +import Unison.Cli.Monad +import Unison.Cli.Monad qualified as Cli +import Unison.Codebase (Codebase) +import Unison.Codebase qualified as Codebase +import Unison.Codebase.Editor.UCMVersion (UCMVersion) +import Unison.Hash32 (Hash32) +import Unison.Prelude +import Unison.Share.API.Hash qualified as Share +import Unison.Share.Codeserver qualified as Codeserver +import Unison.Share.Sync.Types qualified as Sync +import Unison.Sqlite qualified as Sqlite +import Unison.SyncV3.Types +import Unison.SyncV3.Types as SyncV3 +import Unison.Util.Servant.CBOR (CBORBytes) +import Unison.Util.Servant.CBOR qualified as CBOR +import Unison.Util.Websockets (Queues (..), withQueues) +import UnliftIO.STM +-- Websocket send/receive buffer sizes +inputBuffer :: Natural +inputBuffer = 1000 + +outputBuffer :: Natural +outputBuffer = 1000 + +syncV3ClientVersion :: Int32 +syncV3ClientVersion = 1 + +syncFromCodeserver :: + Bool -> + -- | The Unison Share URL. + Codeserver.CodeserverURI -> + -- | The branch to download from. + BranchRef -> + -- | The hash to download. + Share.HashJWT -> + Cli (Either (Sync.SyncError SyncV3.SyncError) (CausalHash, CausalHashId)) +syncFromCodeserver shouldValidate shareCodeserver branchRef hashJwt = do + Cli.Env {authHTTPClient, codebase} <- ask + let host = Codeserver.codeserverRegName shareCodeserver + let tlsPort = 443 + let port = fromMaybe tlsPort $ Codeserver.codeserverPort shareCodeserver + let syncV3Path = "/ucm/v3/sync" + Cli.with (WS.runClient host port syncV3Path) \conn -> do + withQueues inputBuffer outputBuffer conn $ \queues@Queues {send} -> do + let initMsg = + InitMsg + { initMsgClientVersion = syncV3ClientVersion, + initMsgBranchRef = branchRef, + initMsgRootCausal = hashJwt, + initMsgRequestedDepth = Nothing + } + atomically $ send $ InitStream initMsg + pendingRequestsVar <- newTVarIO (Set.singleton $ Share.hashJWTHash $ hashJwt) + let initState = SyncState {pendingRequestsVar} + liftIO $ doSync codebase initState queues + causalId <- flushTemp codebase (Share.hashJWTHash hashJwt) + _ + _ -- Get the final causalHashId + +data SyncState = SyncState + { pendingRequestsVar :: TVar (Set (EntityKind, Hash32)), + yetToRequestVar :: TVar (Set (EntityKind, Hash32)), + toIngestQueue :: TBQueue (Entity Hash32 Text), + rootCausalHash :: Hash32 + } + +-- | Given a stream that's already been initialized, receive entities and issue requests as needed. +doSync :: Codebase IO v a -> SyncState -> Queues () (FromEmitterMessage Hash32 Text) -> IO (Either SyncError ()) +doSync codebase SyncState {pendingRequestsVar, yetToRequestVar, toIngestQueue, rootCausalHash} (Queues {send, receive, shutdown}) = Ki.scoped \scope -> do + errorVar <- newEmptyTMVarIO + let onErr err = do + atomically $ putTMVar errorVar err + shutdown + _ <- Ki.fork scope (receiverWorker onErr) + _ <- Ki.fork scope (requesterWorker onErr) + _ <- Ki.fork scope (ingestionWorker onErr) + atomically $ (Right <$> Ki.awaitAll scope) <|> (Left <$> readTMVar errorVar) + where + receiverWorker :: (SyncError -> IO ()) -> IO () + receiverWorker onErr = do + atomically receive >>= \case + EmitterErrorMsg err -> onErr err + EmitterEntityMsg entity -> do + missingDeps <- Codebase.runTransaction codebase $ saveEntity codebase entity + atomically $ do + pending <- readTVar pendingRequestsVar + let newDeps = Set.difference missingDeps pending + modifyTVar' pendingRequestsVar (Set.union newDeps) + receiverWorker onErr + EmitterDoneMsg -> return () + requesterWorker :: (SyncError -> IO ()) -> IO () + requesterWorker onErr = forever do + atomically $ do + requests <- readTVar yetToRequestVar + writeTVar yetToRequestVar Set.empty + modifyTVar' pendingRequestsVar (Set.union requests) + for_ requests $ \h -> send $ EntityRequestMsg h + + ingestionWorker :: (SyncError -> IO ()) -> IO () + ingestionWorker onErr = forever do + newEntities <- atomically $ do + flushTBQueue toIngestQueue + tempEntities <- case for newEntities (CBOR.deserialiseOrFailCBORBytes . entityData) of + -- TODO: proper error handling + Left err -> error $ show err + Right tempEntities -> pure tempEntities + Codebase.runTransaction codebase $ do + for_ newEntities $ \newEntity@(Entity {entityKind, entityHash, entityDepth, entityData}) -> do + case CBOR.deserialiseOrFailCBORBytes entityData of + -- TODO: proper error handling + Left err -> error $ show err + Right tempEntity -> do + Q.insertTempEntitySyncV3 rootCausalHash entityKind entityHash entityDepth entityData + + let allDeps = foldMap tempEntityDependencies tempEntities + missingDeps <- + (Set.toList allDeps) & filterM \(_depKind, depHash) -> do + Codebase.runTransaction codebase (Q.entityLocationSyncV3 depHash) <&> \case + Nothing -> True + _ -> False + -- Request any deps we're missing which also haven't already been requested + atomically $ do + pending <- readTVar pendingRequestsVar + let missingDepsSet = Set.fromList missingDeps + let unRequestedDeps = Set.difference missingDepsSet pending + modifyTVar' yetToRequestVar (Set.union unRequestedDeps) + +flushTemp :: Codebase IO v a -> CausalHash -> IO CausalHashId +flushTemp codebase rootCausalHash = do + _ + +tempEntityDependencies :: TempEntity -> Set (EntityKind, Hash32) +tempEntityDependencies entity = do + let componentDeps = Lens.setOf Entity.defns_ entity + patchDeps = Lens.setOf Entity.patches_ entity + branchHashes = Lens.setOf Entity.branchHashes_ entity <> Lens.setOf Entity.branches_ entity + causalHashes = Lens.setOf Entity.causalHashes_ entity + in Set.unions + [ Set.map (DefnComponentEntity,) componentDeps, + Set.map (PatchEntity,) patchDeps, + Set.map (NamespaceEntity,) branchHashes, + Set.map (CausalEntity,) causalHashes + ] diff --git a/unison-cli/unison-cli.cabal b/unison-cli/unison-cli.cabal index 6433469e07..e8281c7782 100644 --- a/unison-cli/unison-cli.cabal +++ b/unison-cli/unison-cli.cabal @@ -299,6 +299,7 @@ library , vector , wai , warp + , websockets , witch , witherable default-language: Haskell2010 diff --git a/unison-share-api/package.yaml b/unison-share-api/package.yaml index 4a258ea77e..4bba34c9af 100644 --- a/unison-share-api/package.yaml +++ b/unison-share-api/package.yaml @@ -26,6 +26,7 @@ library: - hs-mcp - http-media - http-types + - ki-unlifted - lens - lucid - memory diff --git a/unison-share-api/src/Unison/SyncV3/Types.hs b/unison-share-api/src/Unison/SyncV3/Types.hs index f698c48045..b58a0a20c8 100644 --- a/unison-share-api/src/Unison/SyncV3/Types.hs +++ b/unison-share-api/src/Unison/SyncV3/Types.hs @@ -10,16 +10,18 @@ module Unison.SyncV3.Types EntityDepth (..), HashMappings (..), HashTag (..), + BranchRef (..), ) where +import Codec.Serialise (Serialise) import Codec.Serialise qualified as CBOR import Control.Lens hiding ((.=)) import Data.Aeson import Data.Aeson qualified as Aeson import Data.ByteString qualified as BS import Data.ByteString.Lazy.Char8 qualified as BL -import Data.Int (Int64) +import Data.Int (Int32, Int64) import Data.Map (Map) import Data.Set (Set) import Data.Set qualified as Set @@ -34,18 +36,18 @@ import Unison.Server.Orphans () import Unison.Util.Servant.CBOR qualified as CBOR data InitMsg authedHash = InitMsg - { initMsgClientVersion :: Text, - initMsgProjectId :: Text, + { initMsgClientVersion :: Int32, + initMsgBranchRef :: BranchRef, initMsgRootCausal :: authedHash, initMsgRequestedDepth :: Maybe Int64 } deriving (Show, Eq) instance (ToJSON authedHash) => ToJSON (InitMsg authedHash) where - toJSON (InitMsg {initMsgClientVersion, initMsgProjectId, initMsgRootCausal, initMsgRequestedDepth}) = + toJSON (InitMsg {initMsgClientVersion, initMsgBranchRef, initMsgRootCausal, initMsgRequestedDepth}) = object [ "clientVersion" .= initMsgClientVersion, - "projectId" .= initMsgProjectId, + "branchRef" .= initMsgBranchRef, "rootCausal" .= initMsgRootCausal, "requestedDepth" .= initMsgRequestedDepth ] @@ -54,7 +56,7 @@ instance (FromJSON authedHash) => FromJSON (InitMsg authedHash) where parseJSON = withObject "InitMsg" $ \o -> InitMsg <$> o .: "clientVersion" - <*> o .: "projectId" + <*> o .: "branchRef" <*> o .: "rootCausal" <*> o .:? "requestedDepth" @@ -72,25 +74,32 @@ instance (CBOR.Serialise sh) => CBOR.Serialise (EntityRequestMsg sh) where pure $ EntityRequestMsg {hashes} data FromReceiverMessageTag - = InitStreamTag - | EntityRequestTag + = ReceiverInitStreamTag + | ReceiverEntityRequestTag + | ReceiverDoneTag instance CBOR.Serialise FromReceiverMessageTag where encode = \case - InitStreamTag -> CBOR.encode (0 :: Int) - EntityRequestTag -> CBOR.encode (1 :: Int) + ReceiverInitStreamTag -> CBOR.encode (0 :: Int) + ReceiverEntityRequestTag -> CBOR.encode (1 :: Int) + ReceiverDoneTag -> CBOR.encode (2 :: Int) decode = do tag <- CBOR.decode @Int case tag of - 0 -> pure InitStreamTag - 1 -> pure EntityRequestTag + 0 -> pure ReceiverInitStreamTag + 1 -> pure ReceiverEntityRequestTag + 2 -> pure ReceiverDoneTag _ -> fail $ "Unknown FromReceiverMessageTag: " <> show tag -- A message sent from the downloader to the emitter. data FromReceiverMessage ah hash - = InitStream (InitMsg ah) - | EntityRequest (EntityRequestMsg hash) + = -- Initialize the stream + ReceiverInitStream (InitMsg ah) + | -- Request more entities by hash. + ReceiverEntityRequest (EntityRequestMsg hash) + | -- Sent when the receiver has no outstanding requests. + ReceiverDone deriving (Show, Eq) instance (ToJSON ah, FromJSON ah) => CBOR.Serialise (InitMsg ah) where @@ -109,17 +118,19 @@ instance (ToJSON ah, FromJSON ah) => CBOR.Serialise (InitMsg ah) where instance (CBOR.Serialise h, ToJSON ah, FromJSON ah) => CBOR.Serialise (FromReceiverMessage ah h) where encode = \case - InitStream initMsg -> - CBOR.encode InitStreamTag + ReceiverInitStream initMsg -> + CBOR.encode ReceiverInitStreamTag <> CBOR.encode initMsg - EntityRequest msg -> - CBOR.encode EntityRequestTag + ReceiverEntityRequest msg -> + CBOR.encode ReceiverEntityRequestTag <> CBOR.encode msg + ReceiverDone -> CBOR.encode ReceiverDoneTag decode = do tag <- CBOR.decode @FromReceiverMessageTag case tag of - InitStreamTag -> InitStream <$> CBOR.decode @(InitMsg ah) - EntityRequestTag -> EntityRequest <$> CBOR.decode @(EntityRequestMsg h) + ReceiverInitStreamTag -> ReceiverInitStream <$> CBOR.decode @(InitMsg ah) + ReceiverEntityRequestTag -> ReceiverEntityRequest <$> CBOR.decode @(EntityRequestMsg h) + ReceiverDoneTag -> pure ReceiverDone data SyncError = InitializationError Text @@ -152,14 +163,14 @@ instance CBOR.Serialise SyncError where -- A message sent from the emitter to the downloader. data FromEmitterMessage hash text - = ErrorMsg SyncError - | -- | HashMappingsMsg (HashMappings hash smallHash) - EntityMsg (Entity hash text) + = EmitterErrorMsg SyncError + | EmitterEntityMsg (Entity hash text) + | EmitterDoneMsg instance (CBOR.Serialise hash, CBOR.Serialise text) => WebSocketsData (FromEmitterMessage hash text) where fromLazyByteString bytes = CBOR.deserialiseOrFailCBORBytes (CBOR.CBORBytes bytes) - & either (\err -> ErrorMsg . EncodingFailure $ "Error decoding CBOR message from bytes: " <> tShow err) id + & either (\err -> EmitterErrorMsg . EncodingFailure $ "Error decoding CBOR message from bytes: " <> tShow err) id toLazyByteString = CBOR.serialise @@ -175,8 +186,7 @@ data HashMappings hash smallHash = HashMappings data EntityKind = CausalEntity | NamespaceEntity - | TermEntity - | TypeEntity + | DefnComponentEntity | PatchEntity deriving stock (Show, Eq, Ord) @@ -184,18 +194,16 @@ instance CBOR.Serialise EntityKind where encode = \case CausalEntity -> CBOR.encode (0 :: Int) NamespaceEntity -> CBOR.encode (1 :: Int) - TermEntity -> CBOR.encode (2 :: Int) - TypeEntity -> CBOR.encode (3 :: Int) - PatchEntity -> CBOR.encode (4 :: Int) + DefnComponentEntity -> CBOR.encode (2 :: Int) + PatchEntity -> CBOR.encode (3 :: Int) decode = do tag <- CBOR.decode @Int case tag of 0 -> pure CausalEntity 1 -> pure NamespaceEntity - 2 -> pure TermEntity - 3 -> pure TypeEntity - 4 -> pure PatchEntity + 2 -> pure DefnComponentEntity + 3 -> pure PatchEntity _ -> fail $ "Unknown EntityKind tag: " <> show tag -- | The number of _levels_ of dependencies an entity has, @@ -269,34 +277,38 @@ instance (Ord smallHash, CBOR.Serialise hash, CBOR.Serialise smallHash) => CBOR. instance (CBOR.Serialise hash, CBOR.Serialise text) => CBOR.Serialise (FromEmitterMessage hash text) where encode = \case - ErrorMsg err -> CBOR.encode ErrorMsgTag <> CBOR.encode err + EmitterErrorMsg err -> CBOR.encode EmitterErrorMsgTag <> CBOR.encode err -- HashMappingsMsg msg -> CBOR.encode HashMappingsTag <> CBOR.encode msg - EntityMsg msg -> CBOR.encode EntityTag <> CBOR.encode msg + EmitterEntityMsg msg -> CBOR.encode EmitterEntityTag <> CBOR.encode msg + EmitterDoneMsg -> CBOR.encode EmitterDoneTag decode = do tag <- CBOR.decode @FromEmitterMessageTag case tag of - ErrorMsgTag -> ErrorMsg <$> CBOR.decode + EmitterErrorMsgTag -> EmitterErrorMsg <$> CBOR.decode -- HashMappingsTag -> HashMappingsMsg <$> CBOR.decode - EntityTag -> EntityMsg <$> CBOR.decode + EmitterEntityTag -> EmitterEntityMsg <$> CBOR.decode + EmitterDoneTag -> pure EmitterDoneMsg data FromEmitterMessageTag - = ErrorMsgTag + = EmitterErrorMsgTag | -- | HashMappingsTag - EntityTag + EmitterEntityTag + | EmitterDoneTag instance CBOR.Serialise FromEmitterMessageTag where encode = \case - ErrorMsgTag -> CBOR.encode (0 :: Int) + EmitterErrorMsgTag -> CBOR.encode (0 :: Int) -- HashMappingsTag -> CBOR.encode (1 :: Int) - EntityTag -> CBOR.encode (2 :: Int) + EmitterEntityTag -> CBOR.encode (2 :: Int) + EmitterDoneTag -> CBOR.encode (3 :: Int) decode = do tag <- CBOR.decode @Int case tag of - 0 -> pure ErrorMsgTag + 0 -> pure EmitterErrorMsgTag -- 1 -> pure HashMappingsTag - 2 -> pure EntityTag + 2 -> pure EmitterEntityTag _ -> fail $ "Unknown FromEmitterMessageTag: " <> show tag data MsgOrError err a @@ -340,3 +352,6 @@ instance CBOR.Serialise HashTag where decode = do (kind, idx) <- CBOR.decode @(EntityKind, Int64) pure $ HashTag (kind, idx) + +newtype BranchRef = BranchRef {unBranchRef :: Text} + deriving (Serialise, Eq, Show, Ord, ToJSON, FromJSON) via Text diff --git a/unison-share-api/src/Unison/Util/Websockets.hs b/unison-share-api/src/Unison/Util/Websockets.hs new file mode 100644 index 0000000000..f2e0574641 --- /dev/null +++ b/unison-share-api/src/Unison/Util/Websockets.hs @@ -0,0 +1,79 @@ +{-# LANGUAGE KindSignatures #-} + +module Unison.Util.Websockets + ( withQueues, + Queues (..), + ) +where + +import Control.Applicative +import Control.Lens (Profunctor (..)) +import Control.Monad +import Data.Text (Text) +import GHC.Natural +import Ki.Unlifted qualified as Ki +import Network.WebSockets +import UnliftIO + +-- | Allows interfacing with a websocket as a pair of bounded queues. +data Queues i o = Queues + { -- Receive from the client + receive :: STM o, + -- Send to the client + send :: i -> STM (), + shutdown :: IO (), + isConnectionClosed :: STM Bool + } + +instance Profunctor Queues where + dimap f g (Queues {receive, send, shutdown, isConnectionClosed}) = + Queues + { receive = g <$> receive, + send = send . f, + shutdown, + isConnectionClosed + } + +withQueues :: forall i o m a. (MonadUnliftIO m, WebSocketsData i, WebSocketsData o) => Natural -> Natural -> Connection -> (Queues i o -> m a) -> m a +withQueues inputBuffer outputBuffer conn action = Ki.scoped $ \scope -> do + receiveQ <- liftIO $ newTBQueueIO inputBuffer + sendQ <- liftIO $ newTBQueueIO outputBuffer + isClosedVar <- liftIO $ newTVarIO False + let receive = do readTBQueue receiveQ + let send msg = writeTBQueue sendQ msg + + let triggerClose :: IO () + triggerClose = do + alreadyClosed <- atomically $ do + isClosed <- readTVar isClosedVar + when (not isClosed) $ do + writeTVar isClosedVar True + pure () + pure isClosed + when (not alreadyClosed) $ do + sendClose conn ("Server is shutting down" :: Text) + + let queues = Queues {receive, send, shutdown = triggerClose, isConnectionClosed = readTVar isClosedVar} + _ <- Ki.fork scope $ recvWorker receiveQ + _ <- Ki.fork scope $ sendWorker sendQ + r <- action queues + liftIO $ triggerClose + pure r + where + recvWorker :: TBQueue o -> m () + recvWorker q = UnliftIO.handle handler $ do + msg <- liftIO $ receiveData conn + atomically $ writeTBQueue q msg + recvWorker q + + handler :: ConnectionException -> m () + handler = \case + CloseRequest _ _ -> liftIO $ sendClose conn ("Client closed connection" :: Text) + ConnectionClosed -> pure () + err -> throwIO err + + sendWorker :: TBQueue i -> m () + sendWorker q = UnliftIO.handle handler $ do + outMsgs <- atomically $ some $ readTBQueue q + liftIO $ sendBinaryDatas conn outMsgs + sendWorker q diff --git a/unison-share-api/unison-share-api.cabal b/unison-share-api/unison-share-api.cabal index 38bcc5395c..198a1c5385 100644 --- a/unison-share-api/unison-share-api.cabal +++ b/unison-share-api/unison-share-api.cabal @@ -54,6 +54,7 @@ library Unison.SyncV3.Types Unison.Util.Find Unison.Util.Servant.CBOR + Unison.Util.Websockets hs-source-dirs: src default-extensions: @@ -106,6 +107,7 @@ library , hs-mcp , http-media , http-types + , ki-unlifted , lens , lucid , memory From d0e34aa7b82ddb40f1a7e8e40c02c90bf32ff9cd Mon Sep 17 00:00:00 2001 From: Chris Penner Date: Wed, 1 Oct 2025 15:42:38 -0700 Subject: [PATCH 07/15] WIP --- .../U/Codebase/Sqlite/Queries.hs | 3 +- unison-cli/src/Unison/Share/SyncV3.hs | 43 ++++++++++--------- unison-share-api/src/Unison/SyncV3/Types.hs | 19 ++++++++ 3 files changed, 44 insertions(+), 21 deletions(-) diff --git a/codebase2/codebase-sqlite/U/Codebase/Sqlite/Queries.hs b/codebase2/codebase-sqlite/U/Codebase/Sqlite/Queries.hs index c0a50122fe..14f8b45b4e 100644 --- a/codebase2/codebase-sqlite/U/Codebase/Sqlite/Queries.hs +++ b/codebase2/codebase-sqlite/U/Codebase/Sqlite/Queries.hs @@ -303,6 +303,7 @@ import Data.Aeson qualified as Aeson import Data.Aeson.Text qualified as Aeson import Data.Bitraversable (bitraverse) import Data.ByteString.Lazy (LazyByteString) +import Data.ByteString.Lazy.Char8 qualified as BL import Data.Bytes.Put (runPutS) import Data.Foldable qualified as Foldable import Data.List qualified as List @@ -2302,7 +2303,7 @@ insertTempEntity entityHash entity missingDependencies = do entityType = Entity.entityType entity -insertTempEntitySyncV3 :: Hash32 -> Text -> Hash32 -> Int32 -> ByteString -> Transaction () +insertTempEntitySyncV3 :: Hash32 -> Text -> Hash32 -> Int64 -> BL.ByteString -> Transaction () insertTempEntitySyncV3 rootCausal entityKind entityHash entityDepth entityBlob = do execute [sql| diff --git a/unison-cli/src/Unison/Share/SyncV3.hs b/unison-cli/src/Unison/Share/SyncV3.hs index 93ca059dbf..38b2427452 100644 --- a/unison-cli/src/Unison/Share/SyncV3.hs +++ b/unison-cli/src/Unison/Share/SyncV3.hs @@ -3,17 +3,13 @@ module Unison.Share.SyncV3 ) where +import Control.Arrow ((&&&)) import Control.Monad.Reader -import Data.Int (Int32) -import Data.Maybe (fromMaybe) -import Data.Set (Set) import Data.Set qualified as Set import Data.Set.Lens qualified as Lens -import Data.Text (Text) import GHC.Natural import Ki qualified import Network.WebSockets.Client qualified as WS -import Servant.Client qualified as Servant import U.Codebase.HashTags import U.Codebase.Sqlite.DbId import U.Codebase.Sqlite.Entity qualified as Entity @@ -23,16 +19,13 @@ import Unison.Cli.Monad import Unison.Cli.Monad qualified as Cli import Unison.Codebase (Codebase) import Unison.Codebase qualified as Codebase -import Unison.Codebase.Editor.UCMVersion (UCMVersion) import Unison.Hash32 (Hash32) import Unison.Prelude import Unison.Share.API.Hash qualified as Share import Unison.Share.Codeserver qualified as Codeserver import Unison.Share.Sync.Types qualified as Sync -import Unison.Sqlite qualified as Sqlite import Unison.SyncV3.Types import Unison.SyncV3.Types as SyncV3 -import Unison.Util.Servant.CBOR (CBORBytes) import Unison.Util.Servant.CBOR qualified as CBOR import Unison.Util.Websockets (Queues (..), withQueues) import UnliftIO.STM @@ -74,10 +67,12 @@ syncFromCodeserver shouldValidate shareCodeserver branchRef hashJwt = do atomically $ send $ InitStream initMsg pendingRequestsVar <- newTVarIO (Set.singleton $ Share.hashJWTHash $ hashJwt) let initState = SyncState {pendingRequestsVar} - liftIO $ doSync codebase initState queues - causalId <- flushTemp codebase (Share.hashJWTHash hashJwt) + liftIO (doSync codebase initState queues) >>= \case + -- TODO: proper error handling + Left err -> error $ show err + Right () -> pure () + causalId <- liftIO $ flushTemp codebase (Share.hashJWTHash hashJwt) _ - _ -- Get the final causalHashId data SyncState = SyncState { pendingRequestsVar :: TVar (Set (EntityKind, Hash32)), @@ -103,6 +98,8 @@ doSync codebase SyncState {pendingRequestsVar, yetToRequestVar, toIngestQueue, r atomically receive >>= \case EmitterErrorMsg err -> onErr err EmitterEntityMsg entity -> do + atomically $ do + toIngestQueue missingDeps <- Codebase.runTransaction codebase $ saveEntity codebase entity atomically $ do pending <- readTVar pendingRequestsVar @@ -116,7 +113,7 @@ doSync codebase SyncState {pendingRequestsVar, yetToRequestVar, toIngestQueue, r requests <- readTVar yetToRequestVar writeTVar yetToRequestVar Set.empty modifyTVar' pendingRequestsVar (Set.union requests) - for_ requests $ \h -> send $ EntityRequestMsg h + for_ requests $ \h -> send $ ReceiverEntityRequestMsg h ingestionWorker :: (SyncError -> IO ()) -> IO () ingestionWorker onErr = forever do @@ -127,27 +124,33 @@ doSync codebase SyncState {pendingRequestsVar, yetToRequestVar, toIngestQueue, r Left err -> error $ show err Right tempEntities -> pure tempEntities Codebase.runTransaction codebase $ do - for_ newEntities $ \newEntity@(Entity {entityKind, entityHash, entityDepth, entityData}) -> do - case CBOR.deserialiseOrFailCBORBytes entityData of - -- TODO: proper error handling - Left err -> error $ show err - Right tempEntity -> do - Q.insertTempEntitySyncV3 rootCausalHash entityKind entityHash entityDepth entityData + for_ newEntities $ \newEntity@(Entity {entityKind, entityHash, entityDepth, entityData = CBOR.CBORBytes entityBytes}) -> do + Q.insertTempEntitySyncV3 rootCausalHash (tShow entityKind) entityHash (unEntityDepth entityDepth) entityBytes let allDeps = foldMap tempEntityDependencies tempEntities + alreadyRequestedEntities <- atomically $ do + pending <- readTVar pendingRequestsVar + reqs <- readTVar yetToRequestVar + pure $ Set.union pending reqs + let unrequestedDeps = Set.difference allDeps alreadyRequestedEntities missingDeps <- - (Set.toList allDeps) & filterM \(_depKind, depHash) -> do + (Set.toList unrequestedDeps) & filterA \(_depKind, depHash) -> do Codebase.runTransaction codebase (Q.entityLocationSyncV3 depHash) <&> \case Nothing -> True _ -> False + let newlyInserted = + newEntities + <&> (entityKind &&& entityHash) + & Set.fromList -- Request any deps we're missing which also haven't already been requested atomically $ do pending <- readTVar pendingRequestsVar let missingDepsSet = Set.fromList missingDeps let unRequestedDeps = Set.difference missingDepsSet pending modifyTVar' yetToRequestVar (Set.union unRequestedDeps) + modifyTVar' pendingRequestsVar (\pending -> Set.difference pending newlyInserted) -flushTemp :: Codebase IO v a -> CausalHash -> IO CausalHashId +flushTemp :: Codebase IO v a -> Hash32 -> IO CausalHashId flushTemp codebase rootCausalHash = do _ diff --git a/unison-share-api/src/Unison/SyncV3/Types.hs b/unison-share-api/src/Unison/SyncV3/Types.hs index b58a0a20c8..564a6eae1f 100644 --- a/unison-share-api/src/Unison/SyncV3/Types.hs +++ b/unison-share-api/src/Unison/SyncV3/Types.hs @@ -33,6 +33,7 @@ import U.Codebase.Sqlite.TempEntity import Unison.Hash32 (Hash32) import Unison.Prelude (tShow) import Unison.Server.Orphans () +import Unison.Sqlite qualified as Sqlite import Unison.Util.Servant.CBOR qualified as CBOR data InitMsg authedHash = InitMsg @@ -190,6 +191,24 @@ data EntityKind | PatchEntity deriving stock (Show, Eq, Ord) +instance Sqlite.ToField EntityKind where + toField = + Sqlite.toField . \case + CausalEntity -> (0 :: Int) + NamespaceEntity -> 1 + DefnComponentEntity -> 2 + PatchEntity -> 3 + +instance Sqlite.FromField EntityKind where + fromField field = do + tag <- Sqlite.fromField field + case tag of + (0 :: Int) -> pure CausalEntity + 1 -> pure NamespaceEntity + 2 -> pure DefnComponentEntity + 3 -> pure PatchEntity + _ -> fail $ "Unknown EntityKind tag: " <> show tag + instance CBOR.Serialise EntityKind where encode = \case CausalEntity -> CBOR.encode (0 :: Int) From 7d4eb0ae461a6099eb86c7ebcc2e7666a684e30b Mon Sep 17 00:00:00 2001 From: Chris Penner Date: Wed, 1 Oct 2025 16:30:13 -0700 Subject: [PATCH 08/15] SyncV3 client mostly implemented --- .../U/Codebase/Sqlite/Queries.hs | 11 +++ unison-cli/package.yaml | 1 + unison-cli/src/Unison/Share/SyncV3.hs | 88 +++++++++++++------ unison-cli/unison-cli.cabal | 1 + unison-share-api/src/Unison/SyncV3/Types.hs | 15 +--- 5 files changed, 76 insertions(+), 40 deletions(-) diff --git a/codebase2/codebase-sqlite/U/Codebase/Sqlite/Queries.hs b/codebase2/codebase-sqlite/U/Codebase/Sqlite/Queries.hs index 14f8b45b4e..58f2756540 100644 --- a/codebase2/codebase-sqlite/U/Codebase/Sqlite/Queries.hs +++ b/codebase2/codebase-sqlite/U/Codebase/Sqlite/Queries.hs @@ -227,6 +227,7 @@ module U.Codebase.Sqlite.Queries expectTempEntity, deleteTempEntity, clearTempEntityTables, + streamTempEntitiesSyncV3, -- * elaborate hashes elaborateHashes, @@ -4032,3 +4033,13 @@ saveSquashResult bhId chId = ) ON CONFLICT DO NOTHING |] + +streamTempEntitiesSyncV3 :: Hash32 -> (Transaction (Maybe (Hash32, BL.ByteString)) -> Transaction a) -> Transaction a +streamTempEntitiesSyncV3 rootCausalHash action = do + Sqlite.queryStreamRow @(Hash32, BL.ByteString) + [sql| + SELECT entity_hash, entity_data + WHERE root_causal = :rootCausalHash + ORDER BY entity_depth ASC + |] + action diff --git a/unison-cli/package.yaml b/unison-cli/package.yaml index 369119746c..8aa0135894 100644 --- a/unison-cli/package.yaml +++ b/unison-cli/package.yaml @@ -105,6 +105,7 @@ library: - wai - warp - websockets + - wuss - witch - witherable diff --git a/unison-cli/src/Unison/Share/SyncV3.hs b/unison-cli/src/Unison/Share/SyncV3.hs index 38b2427452..24704a14eb 100644 --- a/unison-cli/src/Unison/Share/SyncV3.hs +++ b/unison-cli/src/Unison/Share/SyncV3.hs @@ -9,26 +9,30 @@ import Data.Set qualified as Set import Data.Set.Lens qualified as Lens import GHC.Natural import Ki qualified -import Network.WebSockets.Client qualified as WS import U.Codebase.HashTags import U.Codebase.Sqlite.DbId import U.Codebase.Sqlite.Entity qualified as Entity import U.Codebase.Sqlite.Queries qualified as Q import U.Codebase.Sqlite.TempEntity (TempEntity) +import U.Codebase.Sqlite.V2.HashHandle (v2HashHandle) import Unison.Cli.Monad import Unison.Cli.Monad qualified as Cli import Unison.Codebase (Codebase) import Unison.Codebase qualified as Codebase import Unison.Hash32 (Hash32) import Unison.Prelude +import Unison.Server.Orphans () import Unison.Share.API.Hash qualified as Share import Unison.Share.Codeserver qualified as Codeserver import Unison.Share.Sync.Types qualified as Sync +import Unison.Sync.Common qualified as Sync import Unison.SyncV3.Types import Unison.SyncV3.Types as SyncV3 import Unison.Util.Servant.CBOR qualified as CBOR import Unison.Util.Websockets (Queues (..), withQueues) import UnliftIO.STM +import Wuss qualified +import Network.WebSockets qualified as WS -- Websocket send/receive buffer sizes inputBuffer :: Natural @@ -37,6 +41,9 @@ inputBuffer = 1000 outputBuffer :: Natural outputBuffer = 1000 +transactionBatchSize :: Natural +transactionBatchSize = 1000 + syncV3ClientVersion :: Int32 syncV3ClientVersion = 1 @@ -49,13 +56,18 @@ syncFromCodeserver :: -- | The hash to download. Share.HashJWT -> Cli (Either (Sync.SyncError SyncV3.SyncError) (CausalHash, CausalHashId)) -syncFromCodeserver shouldValidate shareCodeserver branchRef hashJwt = do - Cli.Env {authHTTPClient, codebase} <- ask +syncFromCodeserver _shouldValidate shareCodeserver branchRef hashJwt = do + Cli.Env {codebase} <- ask let host = Codeserver.codeserverRegName shareCodeserver let tlsPort = 443 - let port = fromMaybe tlsPort $ Codeserver.codeserverPort shareCodeserver + let port = maybe tlsPort fromIntegral $ (Codeserver.codeserverPort) shareCodeserver let syncV3Path = "/ucm/v3/sync" - Cli.with (WS.runClient host port syncV3Path) \conn -> do + let rootCausalHash = Share.hashJWTHash hashJwt + -- Enable compression + let connectionOptions = WS.defaultConnectionOptions {WS.connectionCompressionOptions = WS.PermessageDeflateCompression WS.defaultPermessageDeflate} + -- TODO: Add authentication headers manually. + let headers = [] + liftIO $ (Wuss.runSecureClientWith host port syncV3Path connectionOptions headers) \conn -> do withQueues inputBuffer outputBuffer conn $ \queues@Queues {send} -> do let initMsg = InitMsg @@ -64,15 +76,24 @@ syncFromCodeserver shouldValidate shareCodeserver branchRef hashJwt = do initMsgRootCausal = hashJwt, initMsgRequestedDepth = Nothing } - atomically $ send $ InitStream initMsg - pendingRequestsVar <- newTVarIO (Set.singleton $ Share.hashJWTHash $ hashJwt) - let initState = SyncState {pendingRequestsVar} + atomically $ send $ Msg $ ReceiverInitStream initMsg + pendingRequestsVar <- newTVarIO (Set.singleton (CausalEntity, rootCausalHash)) + yetToRequestVar <- newTVarIO Set.empty + toIngestQueue <- newTBQueueIO transactionBatchSize + let initState = + SyncState + { pendingRequestsVar, + yetToRequestVar, + toIngestQueue, + rootCausalHash + } + liftIO (doSync codebase initState queues) >>= \case -- TODO: proper error handling Left err -> error $ show err Right () -> pure () causalId <- liftIO $ flushTemp codebase (Share.hashJWTHash hashJwt) - _ + pure $ Right (Sync.hash32ToCausalHash rootCausalHash, causalId) data SyncState = SyncState { pendingRequestsVar :: TVar (Set (EntityKind, Hash32)), @@ -82,7 +103,7 @@ data SyncState = SyncState } -- | Given a stream that's already been initialized, receive entities and issue requests as needed. -doSync :: Codebase IO v a -> SyncState -> Queues () (FromEmitterMessage Hash32 Text) -> IO (Either SyncError ()) +doSync :: Codebase IO v a -> SyncState -> Queues (MsgOrError SyncError (FromReceiverMessage Share.HashJWT Hash32)) (MsgOrError SyncError (FromEmitterMessage Hash32 Text)) -> IO (Either SyncError ()) doSync codebase SyncState {pendingRequestsVar, yetToRequestVar, toIngestQueue, rootCausalHash} (Queues {send, receive, shutdown}) = Ki.scoped \scope -> do errorVar <- newEmptyTMVarIO let onErr err = do @@ -96,38 +117,36 @@ doSync codebase SyncState {pendingRequestsVar, yetToRequestVar, toIngestQueue, r receiverWorker :: (SyncError -> IO ()) -> IO () receiverWorker onErr = do atomically receive >>= \case - EmitterErrorMsg err -> onErr err - EmitterEntityMsg entity -> do + Msg (EmitterErrorMsg err) -> onErr err + Msg (EmitterEntityMsg entity) -> do atomically $ do - toIngestQueue - missingDeps <- Codebase.runTransaction codebase $ saveEntity codebase entity - atomically $ do - pending <- readTVar pendingRequestsVar - let newDeps = Set.difference missingDeps pending - modifyTVar' pendingRequestsVar (Set.union newDeps) + writeTBQueue toIngestQueue entity receiverWorker onErr - EmitterDoneMsg -> return () + Msg EmitterDoneMsg -> return () + Err err -> onErr err requesterWorker :: (SyncError -> IO ()) -> IO () - requesterWorker onErr = forever do + requesterWorker _onErr = forever do atomically $ do requests <- readTVar yetToRequestVar writeTVar yetToRequestVar Set.empty modifyTVar' pendingRequestsVar (Set.union requests) - for_ requests $ \h -> send $ ReceiverEntityRequestMsg h + send $ Msg $ ReceiverEntityRequest $ EntityRequestMsg (Set.toList requests) ingestionWorker :: (SyncError -> IO ()) -> IO () - ingestionWorker onErr = forever do + ingestionWorker _onErr = forever do newEntities <- atomically $ do flushTBQueue toIngestQueue + Codebase.runTransaction codebase $ do + -- TODO: do hash validation based on shouldValidate + for_ newEntities $ \(Entity {entityKind, entityHash, entityDepth, entityData = CBOR.CBORBytes entityBytes}) -> do + Q.insertTempEntitySyncV3 rootCausalHash (tShow entityKind) entityHash (unEntityDepth entityDepth) entityBytes + tempEntities <- case for newEntities (CBOR.deserialiseOrFailCBORBytes . entityData) of -- TODO: proper error handling Left err -> error $ show err Right tempEntities -> pure tempEntities - Codebase.runTransaction codebase $ do - for_ newEntities $ \newEntity@(Entity {entityKind, entityHash, entityDepth, entityData = CBOR.CBORBytes entityBytes}) -> do - Q.insertTempEntitySyncV3 rootCausalHash (tShow entityKind) entityHash (unEntityDepth entityDepth) entityBytes - let allDeps = foldMap tempEntityDependencies tempEntities + -- TODO: double-check whether it's okay to have this as a separate atomic block. alreadyRequestedEntities <- atomically $ do pending <- readTVar pendingRequestsVar reqs <- readTVar yetToRequestVar @@ -152,7 +171,22 @@ doSync codebase SyncState {pendingRequestsVar, yetToRequestVar, toIngestQueue, r flushTemp :: Codebase IO v a -> Hash32 -> IO CausalHashId flushTemp codebase rootCausalHash = do - _ + Codebase.runTransaction codebase $ do + Q.streamTempEntitiesSyncV3 rootCausalHash \next -> + do + let loop = do + next >>= \case + Nothing -> pure () + Just (hash, tempEntityBytes) -> + do + tempEntity <- case CBOR.deserialiseOrFailCBORBytes (CBOR.CBORBytes tempEntityBytes) of + -- TODO: proper error handling + Left err -> error $ show err + Right tempEntity -> pure tempEntity + void $ Q.saveTempEntityInMain v2HashHandle hash tempEntity + loop + loop + Q.expectCausalHashIdByCausalHash (Sync.hash32ToCausalHash rootCausalHash) tempEntityDependencies :: TempEntity -> Set (EntityKind, Hash32) tempEntityDependencies entity = do diff --git a/unison-cli/unison-cli.cabal b/unison-cli/unison-cli.cabal index e8281c7782..37651a2157 100644 --- a/unison-cli/unison-cli.cabal +++ b/unison-cli/unison-cli.cabal @@ -302,6 +302,7 @@ library , websockets , witch , witherable + , wuss default-language: Haskell2010 if !os(windows) build-depends: diff --git a/unison-share-api/src/Unison/SyncV3/Types.hs b/unison-share-api/src/Unison/SyncV3/Types.hs index 564a6eae1f..ac13e2e0aa 100644 --- a/unison-share-api/src/Unison/SyncV3/Types.hs +++ b/unison-share-api/src/Unison/SyncV3/Types.hs @@ -139,6 +139,7 @@ data SyncError | EncodingFailure Text | -- The caller asked for a Hash they shouldn't have access to. ForbiddenEntityRequest (Set (EntityKind, Hash32)) + deriving (Show, Eq) instance CBOR.Serialise SyncError where encode = \case @@ -168,18 +169,6 @@ data FromEmitterMessage hash text | EmitterEntityMsg (Entity hash text) | EmitterDoneMsg -instance (CBOR.Serialise hash, CBOR.Serialise text) => WebSocketsData (FromEmitterMessage hash text) where - fromLazyByteString bytes = - CBOR.deserialiseOrFailCBORBytes (CBOR.CBORBytes bytes) - & either (\err -> EmitterErrorMsg . EncodingFailure $ "Error decoding CBOR message from bytes: " <> tShow err) id - - toLazyByteString = CBOR.serialise - - fromDataMessage dm = do - case dm of - WS.Text bytes _ -> WS.fromLazyByteString bytes - WS.Binary bytes -> WS.fromLazyByteString bytes - data HashMappings hash smallHash = HashMappings { hashMappings :: Map smallHash hash } @@ -346,7 +335,7 @@ instance (CBOR.Serialise a, CBOR.Serialise err) => CBOR.Serialise (MsgOrError er 1 -> Err <$> CBOR.decode _ -> fail $ "Unknown MsgOrError tag: " <> show tag -instance (CBOR.Serialise sh, ToJSON ah, FromJSON ah) => WebSocketsData (MsgOrError SyncError (FromReceiverMessage ah sh)) where +instance (Serialise msg) => WebSocketsData (MsgOrError SyncError msg) where fromLazyByteString bytes = CBOR.deserialiseOrFailCBORBytes (CBOR.CBORBytes bytes) & either (\err -> Err . EncodingFailure $ "Error decoding CBOR message from bytes: " <> tShow err) Msg From fc554483ee81bb4e7a7fa08099dcf5df6184bab9 Mon Sep 17 00:00:00 2001 From: Chris Penner Date: Thu, 2 Oct 2025 11:01:16 -0700 Subject: [PATCH 09/15] SyncV3 stuff compiling --- unison-cli/src/Unison/Cli/DownloadUtils.hs | 21 +++++++++++++++++++-- unison-cli/src/Unison/Share/SyncV3.hs | 8 ++++---- 2 files changed, 23 insertions(+), 6 deletions(-) diff --git a/unison-cli/src/Unison/Cli/DownloadUtils.hs b/unison-cli/src/Unison/Cli/DownloadUtils.hs index 936d5f00fb..16a42c8451 100644 --- a/unison-cli/src/Unison/Cli/DownloadUtils.hs +++ b/unison-cli/src/Unison/Cli/DownloadUtils.hs @@ -31,16 +31,19 @@ import Unison.NameSegment.Internal qualified as NameSegment import Unison.Prelude import Unison.Share.API.Hash qualified as Share import Unison.Share.Codeserver qualified as Codeserver +import Unison.Share.Codeserver qualified as Share import Unison.Share.Sync qualified as Share import Unison.Share.Sync.Types qualified as Share import Unison.Share.SyncV2 qualified as SyncV2 +import Unison.Share.SyncV3 qualified as SyncV3 import Unison.Share.Types (codeserverBaseURL) import Unison.Sync.Common qualified as Sync.Common import Unison.Sync.Types qualified as Share import Unison.SyncV2.Types qualified as SyncV2 +import Unison.SyncV3.Types qualified as SyncV3 import UnliftIO.Environment qualified as UnliftIO -data SyncVersion = SyncV1 | SyncV2 +data SyncVersion = SyncV1 | SyncV2 | SyncV3 deriving (Eq, Show) -- | The version of the sync protocol to use. @@ -49,7 +52,8 @@ syncVersion = unsafePerformIO do UnliftIO.lookupEnv "UNISON_SYNC_VERSION" <&> \case Just "1" -> SyncV1 - _ -> SyncV2 + Just "2" -> SyncV2 + _ -> SyncV3 -- | Download a project/branch from Share. downloadProjectBranchFromShare :: @@ -95,6 +99,19 @@ downloadProjectBranchFromShare useSquashed branch isPull = Share.SyncError pullErr -> Output.ShareErrorPullV2 pullErr Share.TransportError err -> Output.ShareErrorTransport err + SyncV3 -> do + let branchRef = SyncV3.BranchRef (into @Text (ProjectAndBranch branch.projectName remoteProjectBranchName)) + let shouldValidate = Codeserver.isCustomCodeserver Codeserver.defaultCodeserver + when isPull $ do + pb <- Cli.getCurrentProjectBranch + currentCausalHash <- Cli.runTransaction $ Ops.expectProjectBranchHead pb.projectId pb.branchId + Cli.respond $ Output.SyncingFromTo currentCausalHash (Sync.Common.hash32ToCausalHash causalHash32) + result <- SyncV3.syncFromCodeserver shouldValidate Share.defaultCodeserver branchRef causalHashJwt + void result & onLeft \err0 -> do + done case err0 of + Share.SyncError _pullErr -> + error "TODO: define SyncV3 pull error and handle it here" + Share.TransportError err -> Output.ShareErrorTransport err pure (Sync.Common.hash32ToCausalHash (Share.hashJWTHash causalHashJwt)) -- | Download loose code from Share. diff --git a/unison-cli/src/Unison/Share/SyncV3.hs b/unison-cli/src/Unison/Share/SyncV3.hs index 24704a14eb..76168362cb 100644 --- a/unison-cli/src/Unison/Share/SyncV3.hs +++ b/unison-cli/src/Unison/Share/SyncV3.hs @@ -9,6 +9,7 @@ import Data.Set qualified as Set import Data.Set.Lens qualified as Lens import GHC.Natural import Ki qualified +import Network.WebSockets qualified as WS import U.Codebase.HashTags import U.Codebase.Sqlite.DbId import U.Codebase.Sqlite.Entity qualified as Entity @@ -32,7 +33,6 @@ import Unison.Util.Servant.CBOR qualified as CBOR import Unison.Util.Websockets (Queues (..), withQueues) import UnliftIO.STM import Wuss qualified -import Network.WebSockets qualified as WS -- Websocket send/receive buffer sizes inputBuffer :: Natural @@ -56,11 +56,11 @@ syncFromCodeserver :: -- | The hash to download. Share.HashJWT -> Cli (Either (Sync.SyncError SyncV3.SyncError) (CausalHash, CausalHashId)) -syncFromCodeserver _shouldValidate shareCodeserver branchRef hashJwt = do +syncFromCodeserver _shouldValidate codeserver branchRef hashJwt = do Cli.Env {codebase} <- ask - let host = Codeserver.codeserverRegName shareCodeserver + let host = Codeserver.codeserverRegName codeserver let tlsPort = 443 - let port = maybe tlsPort fromIntegral $ (Codeserver.codeserverPort) shareCodeserver + let port = maybe tlsPort fromIntegral $ (Codeserver.codeserverPort) codeserver let syncV3Path = "/ucm/v3/sync" let rootCausalHash = Share.hashJWTHash hashJwt -- Enable compression From 5de43b74f6b86c3222f56224484a26440c1ea56d Mon Sep 17 00:00:00 2001 From: Chris Penner Date: Thu, 2 Oct 2025 12:15:50 -0700 Subject: [PATCH 10/15] Upgrade websocket Queues --- unison-cli/src/Unison/Share/SyncV3.hs | 16 +++-- unison-share-api/src/Unison/SyncV3/Types.hs | 67 +++---------------- .../src/Unison/Util/Websockets.hs | 63 +++++++++-------- 3 files changed, 56 insertions(+), 90 deletions(-) diff --git a/unison-cli/src/Unison/Share/SyncV3.hs b/unison-cli/src/Unison/Share/SyncV3.hs index 76168362cb..d30fec9168 100644 --- a/unison-cli/src/Unison/Share/SyncV3.hs +++ b/unison-cli/src/Unison/Share/SyncV3.hs @@ -104,7 +104,7 @@ data SyncState = SyncState -- | Given a stream that's already been initialized, receive entities and issue requests as needed. doSync :: Codebase IO v a -> SyncState -> Queues (MsgOrError SyncError (FromReceiverMessage Share.HashJWT Hash32)) (MsgOrError SyncError (FromEmitterMessage Hash32 Text)) -> IO (Either SyncError ()) -doSync codebase SyncState {pendingRequestsVar, yetToRequestVar, toIngestQueue, rootCausalHash} (Queues {send, receive, shutdown}) = Ki.scoped \scope -> do +doSync codebase SyncState {pendingRequestsVar, yetToRequestVar, toIngestQueue, rootCausalHash} (Queues {send, receive, shutdown, connectionClosed}) = Ki.scoped \scope -> do errorVar <- newEmptyTMVarIO let onErr err = do atomically $ putTMVar errorVar err @@ -112,17 +112,25 @@ doSync codebase SyncState {pendingRequestsVar, yetToRequestVar, toIngestQueue, r _ <- Ki.fork scope (receiverWorker onErr) _ <- Ki.fork scope (requesterWorker onErr) _ <- Ki.fork scope (ingestionWorker onErr) - atomically $ (Right <$> Ki.awaitAll scope) <|> (Left <$> readTMVar errorVar) + result <- + atomically $ + (Right <$> Ki.awaitAll scope) + <|> (Left . Left <$> readTMVar errorVar) + <|> (Left . Right <$> connectionClosed) + case result of + Left (Left syncErr) -> pure $ Left syncErr + Left (Right mayConnErr) -> case mayConnErr of + Nothing -> pure $ Right () + Just connErr -> pure $ Left $ ConnectionError (tShow connErr) + Right () -> pure $ Right () where receiverWorker :: (SyncError -> IO ()) -> IO () receiverWorker onErr = do atomically receive >>= \case - Msg (EmitterErrorMsg err) -> onErr err Msg (EmitterEntityMsg entity) -> do atomically $ do writeTBQueue toIngestQueue entity receiverWorker onErr - Msg EmitterDoneMsg -> return () Err err -> onErr err requesterWorker :: (SyncError -> IO ()) -> IO () requesterWorker _onErr = forever do diff --git a/unison-share-api/src/Unison/SyncV3/Types.hs b/unison-share-api/src/Unison/SyncV3/Types.hs index ac13e2e0aa..909bd7d459 100644 --- a/unison-share-api/src/Unison/SyncV3/Types.hs +++ b/unison-share-api/src/Unison/SyncV3/Types.hs @@ -77,20 +77,17 @@ instance (CBOR.Serialise sh) => CBOR.Serialise (EntityRequestMsg sh) where data FromReceiverMessageTag = ReceiverInitStreamTag | ReceiverEntityRequestTag - | ReceiverDoneTag instance CBOR.Serialise FromReceiverMessageTag where encode = \case ReceiverInitStreamTag -> CBOR.encode (0 :: Int) ReceiverEntityRequestTag -> CBOR.encode (1 :: Int) - ReceiverDoneTag -> CBOR.encode (2 :: Int) decode = do tag <- CBOR.decode @Int case tag of 0 -> pure ReceiverInitStreamTag 1 -> pure ReceiverEntityRequestTag - 2 -> pure ReceiverDoneTag _ -> fail $ "Unknown FromReceiverMessageTag: " <> show tag -- A message sent from the downloader to the emitter. @@ -99,8 +96,6 @@ data FromReceiverMessage ah hash ReceiverInitStream (InitMsg ah) | -- Request more entities by hash. ReceiverEntityRequest (EntityRequestMsg hash) - | -- Sent when the receiver has no outstanding requests. - ReceiverDone deriving (Show, Eq) instance (ToJSON ah, FromJSON ah) => CBOR.Serialise (InitMsg ah) where @@ -125,13 +120,11 @@ instance (CBOR.Serialise h, ToJSON ah, FromJSON ah) => CBOR.Serialise (FromRecei ReceiverEntityRequest msg -> CBOR.encode ReceiverEntityRequestTag <> CBOR.encode msg - ReceiverDone -> CBOR.encode ReceiverDoneTag decode = do tag <- CBOR.decode @FromReceiverMessageTag case tag of ReceiverInitStreamTag -> ReceiverInitStream <$> CBOR.decode @(InitMsg ah) ReceiverEntityRequestTag -> ReceiverEntityRequest <$> CBOR.decode @(EntityRequestMsg h) - ReceiverDoneTag -> pure ReceiverDone data SyncError = InitializationError Text @@ -139,6 +132,7 @@ data SyncError | EncodingFailure Text | -- The caller asked for a Hash they shouldn't have access to. ForbiddenEntityRequest (Set (EntityKind, Hash32)) + | ConnectionError Text deriving (Show, Eq) instance CBOR.Serialise SyncError where @@ -151,6 +145,8 @@ instance CBOR.Serialise SyncError where CBOR.encode (2 :: Int) <> CBOR.encode msg ForbiddenEntityRequest hashes -> CBOR.encode (3 :: Int) <> CBOR.encode hashes + ConnectionError err -> + CBOR.encode (4 :: Int) <> CBOR.encode err decode = do tag <- CBOR.decode @Int @@ -161,13 +157,14 @@ instance CBOR.Serialise SyncError where pure $ UnexpectedMessage (BL.fromStrict bs) 2 -> EncodingFailure <$> CBOR.decode 3 -> ForbiddenEntityRequest . Set.fromList <$> CBOR.decode + 4 -> do + err <- CBOR.decode @Text + pure $ ConnectionError err _ -> fail $ "Unknown SyncError tag: " <> show tag -- A message sent from the emitter to the downloader. data FromEmitterMessage hash text - = EmitterErrorMsg SyncError - | EmitterEntityMsg (Entity hash text) - | EmitterDoneMsg + = EmitterEntityMsg (Entity hash text) data HashMappings hash smallHash = HashMappings { hashMappings :: Map smallHash hash @@ -230,36 +227,6 @@ data Entity hash text = Entity entityData :: CBOR.CBORBytes TempEntity } --- entityTexts_ :: Traversal (Entity smallHash text) (Entity smallHash text') text text' --- entityTexts_ f (Entity {entityData, ..}) = --- (\entityData' -> Entity {entityData = entityData', ..}) <$> Entity.texts_ f entityData - --- entityHashesSetter_ :: (Monad m) => LensLike m (Entity smallHash text) (Entity smallHash' text) smallHash smallHash' --- entityHashesSetter_ f (Entity {entityHash, entityData, ..}) = --- (\entityHash' entityData' -> Entity {entityHash = entityHash', entityData = entityData', ..}) --- <$> f entityHash --- <*> ( entityData --- & Entity.hashes_ f --- >>= Entity.defns_ f --- >>= Entity.patches_ f --- >>= Entity.branchHashes_ f --- >>= Entity.branches_ f --- >>= Entity.causalHashes_ f --- ) - --- -- | It's technically possible to implement entityHashesGetter_ and entityHashesSetter_ --- -- as a single Traversal, but it's a ton of extra unpacking/packing that's probably not worth --- -- it. --- entityHashesGetter_ :: Fold (Entity smallHash text) smallHash --- entityHashesGetter_ f (Entity {entityHash, entityData}) = --- phantom (f entityHash) --- *> phantom (Entity.hashes_ f entityData) --- *> phantom (Entity.defns_ f entityData) --- *> phantom (Entity.patches_ f entityData) --- *> phantom (Entity.branchHashes_ f entityData) --- *> phantom (Entity.branches_ f entityData) --- *> phantom (Entity.causalHashes_ f entityData) - instance (CBOR.Serialise smallHash, CBOR.Serialise text) => CBOR.Serialise (Entity smallHash text) where encode (Entity {entityHash, entityKind, entityDepth, entityData}) = CBOR.encode entityHash @@ -285,38 +252,24 @@ instance (Ord smallHash, CBOR.Serialise hash, CBOR.Serialise smallHash) => CBOR. instance (CBOR.Serialise hash, CBOR.Serialise text) => CBOR.Serialise (FromEmitterMessage hash text) where encode = \case - EmitterErrorMsg err -> CBOR.encode EmitterErrorMsgTag <> CBOR.encode err - -- HashMappingsMsg msg -> CBOR.encode HashMappingsTag <> CBOR.encode msg EmitterEntityMsg msg -> CBOR.encode EmitterEntityTag <> CBOR.encode msg - EmitterDoneMsg -> CBOR.encode EmitterDoneTag decode = do tag <- CBOR.decode @FromEmitterMessageTag case tag of - EmitterErrorMsgTag -> EmitterErrorMsg <$> CBOR.decode - -- HashMappingsTag -> HashMappingsMsg <$> CBOR.decode EmitterEntityTag -> EmitterEntityMsg <$> CBOR.decode - EmitterDoneTag -> pure EmitterDoneMsg data FromEmitterMessageTag - = EmitterErrorMsgTag - | -- | HashMappingsTag - EmitterEntityTag - | EmitterDoneTag + = EmitterEntityTag instance CBOR.Serialise FromEmitterMessageTag where encode = \case - EmitterErrorMsgTag -> CBOR.encode (0 :: Int) - -- HashMappingsTag -> CBOR.encode (1 :: Int) - EmitterEntityTag -> CBOR.encode (2 :: Int) - EmitterDoneTag -> CBOR.encode (3 :: Int) + EmitterEntityTag -> CBOR.encode (0 :: Int) decode = do tag <- CBOR.decode @Int case tag of - 0 -> pure EmitterErrorMsgTag - -- 1 -> pure HashMappingsTag - 2 -> pure EmitterEntityTag + 0 -> pure EmitterEntityTag _ -> fail $ "Unknown FromEmitterMessageTag: " <> show tag data MsgOrError err a diff --git a/unison-share-api/src/Unison/Util/Websockets.hs b/unison-share-api/src/Unison/Util/Websockets.hs index f2e0574641..721034ec3c 100644 --- a/unison-share-api/src/Unison/Util/Websockets.hs +++ b/unison-share-api/src/Unison/Util/Websockets.hs @@ -22,58 +22,63 @@ data Queues i o = Queues -- Send to the client send :: i -> STM (), shutdown :: IO (), - isConnectionClosed :: STM Bool + -- This succeeds with a 'Just' value if the connection was closed due to an exception, + -- 'Nothing' if it was closed normally, or retries if the connection is still open. + connectionClosed :: STM (Maybe ConnectionException) } instance Profunctor Queues where - dimap f g (Queues {receive, send, shutdown, isConnectionClosed}) = + dimap f g (Queues {receive, send, shutdown, connectionClosed}) = Queues { receive = g <$> receive, send = send . f, shutdown, - isConnectionClosed + connectionClosed } withQueues :: forall i o m a. (MonadUnliftIO m, WebSocketsData i, WebSocketsData o) => Natural -> Natural -> Connection -> (Queues i o -> m a) -> m a withQueues inputBuffer outputBuffer conn action = Ki.scoped $ \scope -> do receiveQ <- liftIO $ newTBQueueIO inputBuffer sendQ <- liftIO $ newTBQueueIO outputBuffer - isClosedVar <- liftIO $ newTVarIO False + connectionClosedMVar <- liftIO $ newEmptyTMVarIO let receive = do readTBQueue receiveQ let send msg = writeTBQueue sendQ msg - let triggerClose :: IO () - triggerClose = do - alreadyClosed <- atomically $ do - isClosed <- readTVar isClosedVar - when (not isClosed) $ do - writeTVar isClosedVar True - pure () - pure isClosed - when (not alreadyClosed) $ do - sendClose conn ("Server is shutting down" :: Text) + let triggerClose :: forall n. (MonadIO n) => (Maybe ConnectionException) -> n () + triggerClose mayErr = do + newlyClosed <- atomically $ do + tryPutTMVar connectionClosedMVar mayErr + when newlyClosed $ do + -- If we closed due to a connection error, we don't need to send a close. + -- If we're shutting down normally, we send a close message. + case mayErr of + Nothing -> liftIO $ sendClose conn ("Server is shutting down" :: Text) + _ -> pure () - let queues = Queues {receive, send, shutdown = triggerClose, isConnectionClosed = readTVar isClosedVar} - _ <- Ki.fork scope $ recvWorker receiveQ - _ <- Ki.fork scope $ sendWorker sendQ + let queues = Queues {receive, send, shutdown = (triggerClose Nothing), connectionClosed = readTMVar connectionClosedMVar} + _ <- Ki.fork scope $ recvWorker triggerClose receiveQ + _ <- Ki.fork scope $ sendWorker triggerClose sendQ r <- action queues - liftIO $ triggerClose + -- Ensure the connection is closed when done. + liftIO $ triggerClose Nothing pure r where - recvWorker :: TBQueue o -> m () - recvWorker q = UnliftIO.handle handler $ do + recvWorker :: (Maybe ConnectionException -> m ()) -> TBQueue o -> m () + recvWorker triggerClose q = UnliftIO.handle (handler triggerClose) $ do msg <- liftIO $ receiveData conn atomically $ writeTBQueue q msg - recvWorker q + recvWorker triggerClose q - handler :: ConnectionException -> m () - handler = \case - CloseRequest _ _ -> liftIO $ sendClose conn ("Client closed connection" :: Text) - ConnectionClosed -> pure () - err -> throwIO err + handler :: (Maybe ConnectionException -> m ()) -> ConnectionException -> m () + handler triggerClose = \case + CloseRequest {} -> do + -- The client requested a close, we can just close normally. + triggerClose Nothing + -- Other cases are exceptional + err -> triggerClose (Just err) - sendWorker :: TBQueue i -> m () - sendWorker q = UnliftIO.handle handler $ do + sendWorker :: (Maybe ConnectionException -> m ()) -> TBQueue i -> m () + sendWorker triggerClose q = UnliftIO.handle (handler triggerClose) $ do outMsgs <- atomically $ some $ readTBQueue q liftIO $ sendBinaryDatas conn outMsgs - sendWorker q + sendWorker triggerClose q From e42cfcadf4543106992d8997ed0c71516481553b Mon Sep 17 00:00:00 2001 From: Chris Penner Date: Thu, 2 Oct 2025 14:19:08 -0700 Subject: [PATCH 11/15] Debugging CBOR problems --- unison-cli/src/Unison/Cli/DownloadUtils.hs | 3 +++ unison-cli/src/Unison/Share/Codeserver.hs | 2 ++ unison-cli/src/Unison/Share/SyncV3.hs | 29 ++++++++++++++++++--- unison-share-api/src/Unison/SyncV3/Types.hs | 18 ++++++++++--- 4 files changed, 44 insertions(+), 8 deletions(-) diff --git a/unison-cli/src/Unison/Cli/DownloadUtils.hs b/unison-cli/src/Unison/Cli/DownloadUtils.hs index 16a42c8451..6861803dc1 100644 --- a/unison-cli/src/Unison/Cli/DownloadUtils.hs +++ b/unison-cli/src/Unison/Cli/DownloadUtils.hs @@ -27,6 +27,7 @@ import Unison.Codebase.Editor.RemoteRepo qualified as RemoteRepo import Unison.Codebase.Path qualified as Path import Unison.Codebase.ProjectPath (ProjectBranch (..)) import Unison.Core.Project (ProjectAndBranch (..)) +import Unison.Debug qualified as Debug import Unison.NameSegment.Internal qualified as NameSegment import Unison.Prelude import Unison.Share.API.Hash qualified as Share @@ -74,6 +75,7 @@ downloadProjectBranchFromShare useSquashed branch isPull = (Share.NoSquashedHead, _) -> pure branch.branchHead let causalHash32 = Share.hashJWTHash causalHashJwt exists <- Cli.runTransaction (Queries.causalExistsByHash32 causalHash32) + Debug.debugM Debug.Temp "Downloading using Sync " syncVersion when (not exists) do case syncVersion of SyncV1 -> do @@ -100,6 +102,7 @@ downloadProjectBranchFromShare useSquashed branch isPull = Output.ShareErrorPullV2 pullErr Share.TransportError err -> Output.ShareErrorTransport err SyncV3 -> do + Debug.debugLogM Debug.Temp "Using SyncV3 protocol" let branchRef = SyncV3.BranchRef (into @Text (ProjectAndBranch branch.projectName remoteProjectBranchName)) let shouldValidate = Codeserver.isCustomCodeserver Codeserver.defaultCodeserver when isPull $ do diff --git a/unison-cli/src/Unison/Share/Codeserver.hs b/unison-cli/src/Unison/Share/Codeserver.hs index ea7aee4b73..a569094da2 100644 --- a/unison-cli/src/Unison/Share/Codeserver.hs +++ b/unison-cli/src/Unison/Share/Codeserver.hs @@ -3,6 +3,8 @@ module Unison.Share.Codeserver defaultCodeserver, resolveCodeserver, CodeserverURI (..), + Scheme (..), + CodeserverId (..), ) where diff --git a/unison-cli/src/Unison/Share/SyncV3.hs b/unison-cli/src/Unison/Share/SyncV3.hs index d30fec9168..6a3bd771f6 100644 --- a/unison-cli/src/Unison/Share/SyncV3.hs +++ b/unison-cli/src/Unison/Share/SyncV3.hs @@ -20,6 +20,7 @@ import Unison.Cli.Monad import Unison.Cli.Monad qualified as Cli import Unison.Codebase (Codebase) import Unison.Codebase qualified as Codebase +import Unison.Debug qualified as Debug import Unison.Hash32 (Hash32) import Unison.Prelude import Unison.Server.Orphans () @@ -59,16 +60,26 @@ syncFromCodeserver :: syncFromCodeserver _shouldValidate codeserver branchRef hashJwt = do Cli.Env {codebase} <- ask let host = Codeserver.codeserverRegName codeserver - let tlsPort = 443 - let port = maybe tlsPort fromIntegral $ (Codeserver.codeserverPort) codeserver - let syncV3Path = "/ucm/v3/sync" + let syncV3Path = "/ucm/v3/sync/download" let rootCausalHash = Share.hashJWTHash hashJwt -- Enable compression let connectionOptions = WS.defaultConnectionOptions {WS.connectionCompressionOptions = WS.PermessageDeflateCompression WS.defaultPermessageDeflate} -- TODO: Add authentication headers manually. let headers = [] - liftIO $ (Wuss.runSecureClientWith host port syncV3Path connectionOptions headers) \conn -> do + let runner = case Codeserver.codeserverScheme codeserver of + Codeserver.Https -> + let tlsPort = 443 + port = maybe tlsPort fromIntegral $ (Codeserver.codeserverPort) codeserver + in Wuss.runSecureClientWith host port + Codeserver.Http -> + let tlsPort = 443 :: Int + port = maybe tlsPort id $ (Codeserver.codeserverPort) codeserver + in WS.runClientWith host port + Debug.debugLogM Debug.Temp "Obtaining Connection" + liftIO $ (runner syncV3Path connectionOptions headers) \conn -> do + Debug.debugLogM Debug.Temp "Obtained Connection" withQueues inputBuffer outputBuffer conn $ \queues@Queues {send} -> do + Debug.debugLogM Debug.Temp "Obtained Queues" let initMsg = InitMsg { initMsgClientVersion = syncV3ClientVersion, @@ -76,7 +87,9 @@ syncFromCodeserver _shouldValidate codeserver branchRef hashJwt = do initMsgRootCausal = hashJwt, initMsgRequestedDepth = Nothing } + Debug.debugLogM Debug.Temp "Sending init message" atomically $ send $ Msg $ ReceiverInitStream initMsg + Debug.debugLogM Debug.Temp "Init message sent" pendingRequestsVar <- newTVarIO (Set.singleton (CausalEntity, rootCausalHash)) yetToRequestVar <- newTVarIO Set.empty toIngestQueue <- newTBQueueIO transactionBatchSize @@ -92,6 +105,7 @@ syncFromCodeserver _shouldValidate codeserver branchRef hashJwt = do -- TODO: proper error handling Left err -> error $ show err Right () -> pure () + Debug.debugLogM Debug.Temp "Done sync, flushing temp entities" causalId <- liftIO $ flushTemp codebase (Share.hashJWTHash hashJwt) pure $ Right (Sync.hash32ToCausalHash rootCausalHash, causalId) @@ -112,11 +126,15 @@ doSync codebase SyncState {pendingRequestsVar, yetToRequestVar, toIngestQueue, r _ <- Ki.fork scope (receiverWorker onErr) _ <- Ki.fork scope (requesterWorker onErr) _ <- Ki.fork scope (ingestionWorker onErr) + + Debug.debugLogM Debug.Temp "Awaiting completion" result <- atomically $ (Right <$> Ki.awaitAll scope) <|> (Left . Left <$> readTMVar errorVar) <|> (Left . Right <$> connectionClosed) + + Debug.debugM Debug.Temp "End result" result case result of Left (Left syncErr) -> pure $ Left syncErr Left (Right mayConnErr) -> case mayConnErr of @@ -126,6 +144,7 @@ doSync codebase SyncState {pendingRequestsVar, yetToRequestVar, toIngestQueue, r where receiverWorker :: (SyncError -> IO ()) -> IO () receiverWorker onErr = do + Debug.debugLogM Debug.Temp "Receiver waiting for message" atomically receive >>= \case Msg (EmitterEntityMsg entity) -> do atomically $ do @@ -134,6 +153,7 @@ doSync codebase SyncState {pendingRequestsVar, yetToRequestVar, toIngestQueue, r Err err -> onErr err requesterWorker :: (SyncError -> IO ()) -> IO () requesterWorker _onErr = forever do + Debug.debugLogM Debug.Temp "Requester waiting to send requests" atomically $ do requests <- readTVar yetToRequestVar writeTVar yetToRequestVar Set.empty @@ -142,6 +162,7 @@ doSync codebase SyncState {pendingRequestsVar, yetToRequestVar, toIngestQueue, r ingestionWorker :: (SyncError -> IO ()) -> IO () ingestionWorker _onErr = forever do + Debug.debugLogM Debug.Temp "Ingestion waiting for entities" newEntities <- atomically $ do flushTBQueue toIngestQueue Codebase.runTransaction codebase $ do diff --git a/unison-share-api/src/Unison/SyncV3/Types.hs b/unison-share-api/src/Unison/SyncV3/Types.hs index 909bd7d459..11e34e5d03 100644 --- a/unison-share-api/src/Unison/SyncV3/Types.hs +++ b/unison-share-api/src/Unison/SyncV3/Types.hs @@ -14,6 +14,7 @@ module Unison.SyncV3.Types ) where +import Codec.CBOR.Term (decodeTerm) import Codec.Serialise (Serialise) import Codec.Serialise qualified as CBOR import Control.Lens hiding ((.=)) @@ -30,6 +31,7 @@ import Network.WebSockets (WebSocketsData) import Network.WebSockets qualified as WS import U.Codebase.Sqlite.Orphans () import U.Codebase.Sqlite.TempEntity +import Unison.Debug qualified as Debug import Unison.Hash32 (Hash32) import Unison.Prelude (tShow) import Unison.Server.Orphans () @@ -77,6 +79,7 @@ instance (CBOR.Serialise sh) => CBOR.Serialise (EntityRequestMsg sh) where data FromReceiverMessageTag = ReceiverInitStreamTag | ReceiverEntityRequestTag + deriving (Show, Eq) instance CBOR.Serialise FromReceiverMessageTag where encode = \case @@ -85,6 +88,7 @@ instance CBOR.Serialise FromReceiverMessageTag where decode = do tag <- CBOR.decode @Int + Debug.debugM Debug.Temp "Decoding FromReceiverMessageTag with tag" tag case tag of 0 -> pure ReceiverInitStreamTag 1 -> pure ReceiverEntityRequestTag @@ -104,11 +108,12 @@ instance (ToJSON ah, FromJSON ah) => CBOR.Serialise (InitMsg ah) where -- using Haskell's CBOR library :| -- -- See https://github.com/well-typed/cborg/issues/369 - CBOR.encode $ Aeson.encode msg + CBOR.encode @BS.ByteString $ BL.toStrict $ Aeson.encode msg decode = do - bs <- CBOR.decode @BL.ByteString - case Aeson.eitherDecode bs of + Debug.debugLogM Debug.Temp "Decoding InitMsg from JSON via CBOR" + bs <- CBOR.decode @BS.ByteString + case Aeson.eitherDecode $ BL.fromStrict bs of Left err -> fail $ "Error decoding InitMsg from JSON: " <> err Right msg -> pure msg @@ -122,6 +127,7 @@ instance (CBOR.Serialise h, ToJSON ah, FromJSON ah) => CBOR.Serialise (FromRecei <> CBOR.encode msg decode = do tag <- CBOR.decode @FromReceiverMessageTag + Debug.debugM Debug.Temp "Decoding FromReceiverMessage with tag" tag case tag of ReceiverInitStreamTag -> ReceiverInitStream <$> CBOR.decode @(InitMsg ah) ReceiverEntityRequestTag -> ReceiverEntityRequest <$> CBOR.decode @(EntityRequestMsg h) @@ -165,6 +171,7 @@ instance CBOR.Serialise SyncError where -- A message sent from the emitter to the downloader. data FromEmitterMessage hash text = EmitterEntityMsg (Entity hash text) + deriving (Show, Eq) data HashMappings hash smallHash = HashMappings { hashMappings :: Map smallHash hash @@ -226,6 +233,7 @@ data Entity hash text = Entity entityDepth :: EntityDepth, entityData :: CBOR.CBORBytes TempEntity } + deriving (Show, Eq) instance (CBOR.Serialise smallHash, CBOR.Serialise text) => CBOR.Serialise (Entity smallHash text) where encode (Entity {entityHash, entityKind, entityDepth, entityData}) = @@ -275,6 +283,7 @@ instance CBOR.Serialise FromEmitterMessageTag where data MsgOrError err a = Msg a | Err err + deriving (Show, Eq, Ord) instance (CBOR.Serialise a, CBOR.Serialise err) => CBOR.Serialise (MsgOrError err a) where encode = \case @@ -283,6 +292,7 @@ instance (CBOR.Serialise a, CBOR.Serialise err) => CBOR.Serialise (MsgOrError er decode = do tag <- CBOR.decode @Int + Debug.debugM Debug.Temp "Decoding MsgOrError with tag" tag case tag of 0 -> Msg <$> CBOR.decode 1 -> Err <$> CBOR.decode @@ -290,7 +300,7 @@ instance (CBOR.Serialise a, CBOR.Serialise err) => CBOR.Serialise (MsgOrError er instance (Serialise msg) => WebSocketsData (MsgOrError SyncError msg) where fromLazyByteString bytes = - CBOR.deserialiseOrFailCBORBytes (CBOR.CBORBytes bytes) + CBOR.deserialiseOrFail bytes & either (\err -> Err . EncodingFailure $ "Error decoding CBOR message from bytes: " <> tShow err) Msg toLazyByteString = CBOR.serialise From ab2297c153dc78e7dddeea6abb0b376dd0d7c9d7 Mon Sep 17 00:00:00 2001 From: Chris Penner Date: Fri, 3 Oct 2025 16:33:48 -0700 Subject: [PATCH 12/15] Pull out shared sync types --- .../src/Unison/SyncCommon/Types.hs | 18 +++ unison-share-api/src/Unison/SyncV2/Types.hs | 9 +- unison-share-api/src/Unison/SyncV3/Types.hs | 103 ++++++++++++++---- unison-share-api/unison-share-api.cabal | 1 + 4 files changed, 104 insertions(+), 27 deletions(-) create mode 100644 unison-share-api/src/Unison/SyncCommon/Types.hs diff --git a/unison-share-api/src/Unison/SyncCommon/Types.hs b/unison-share-api/src/Unison/SyncCommon/Types.hs new file mode 100644 index 0000000000..e9900cb855 --- /dev/null +++ b/unison-share-api/src/Unison/SyncCommon/Types.hs @@ -0,0 +1,18 @@ +-- Types common to multiple versions of Sync +module Unison.SyncCommon.Types + ( BranchRef (..), + ) +where + +import Codec.Serialise (Serialise (..)) +import Data.Aeson (FromJSON (..), ToJSON (..)) +import Data.Text (Text) +import Unison.Core.Project (ProjectAndBranch (..), ProjectBranchName, ProjectName) +import Unison.Prelude (From (..)) +import Unison.Server.Orphans () + +newtype BranchRef = BranchRef {unBranchRef :: Text} + deriving (Serialise, Eq, Show, Ord, ToJSON, FromJSON) via Text + +instance From (ProjectAndBranch ProjectName ProjectBranchName) BranchRef where + from pab = BranchRef $ from pab diff --git a/unison-share-api/src/Unison/SyncV2/Types.hs b/unison-share-api/src/Unison/SyncV2/Types.hs index 82f2e95f63..1bfe3e7a6f 100644 --- a/unison-share-api/src/Unison/SyncV2/Types.hs +++ b/unison-share-api/src/Unison/SyncV2/Types.hs @@ -38,20 +38,13 @@ import Data.Text qualified as Text import Data.Word (Word16, Word64) import U.Codebase.HashTags (CausalHash) import U.Codebase.Sqlite.TempEntity (TempEntity) -import Unison.Core.Project (ProjectAndBranch (..), ProjectBranchName, ProjectName) import Unison.Hash32 (Hash32) -import Unison.Prelude (From (..)) import Unison.Server.Orphans () import Unison.Share.API.Hash (HashJWT) import Unison.Sync.Types qualified as SyncV1 +import Unison.SyncCommon.Types import Unison.Util.Servant.CBOR -newtype BranchRef = BranchRef {unBranchRef :: Text} - deriving (Serialise, Eq, Show, Ord, ToJSON, FromJSON) via Text - -instance From (ProjectAndBranch ProjectName ProjectBranchName) BranchRef where - from pab = BranchRef $ from pab - data GetCausalHashErrorTag = GetCausalHashNoReadPermissionTag | GetCausalHashUserNotFoundTag diff --git a/unison-share-api/src/Unison/SyncV3/Types.hs b/unison-share-api/src/Unison/SyncV3/Types.hs index 11e34e5d03..687cceba3a 100644 --- a/unison-share-api/src/Unison/SyncV3/Types.hs +++ b/unison-share-api/src/Unison/SyncV3/Types.hs @@ -8,13 +8,12 @@ module Unison.SyncV3.Types Entity (..), EntityKind (..), EntityDepth (..), - HashMappings (..), HashTag (..), BranchRef (..), ) where -import Codec.CBOR.Term (decodeTerm) +import Unison.SyncCommon.Types import Codec.Serialise (Serialise) import Codec.Serialise qualified as CBOR import Control.Lens hiding ((.=)) @@ -23,7 +22,6 @@ import Data.Aeson qualified as Aeson import Data.ByteString qualified as BS import Data.ByteString.Lazy.Char8 qualified as BL import Data.Int (Int32, Int64) -import Data.Map (Map) import Data.Set (Set) import Data.Set qualified as Set import Data.Text (Text) @@ -68,6 +66,11 @@ data EntityRequestMsg hash = EntityRequestMsg } deriving (Show, Eq) +-- | Roundtrip test: +-- >>> import qualified Codec.Serialise as CBOR +-- >>> let msg = EntityRequestMsg {hashes = [(CausalEntity, "hash1"), (NamespaceEntity, "hash2")]} +-- >>> CBOR.deserialise (CBOR.serialise msg) == msg +-- True instance (CBOR.Serialise sh) => CBOR.Serialise (EntityRequestMsg sh) where encode (EntityRequestMsg {hashes}) = CBOR.encode hashes @@ -81,6 +84,12 @@ data FromReceiverMessageTag | ReceiverEntityRequestTag deriving (Show, Eq) +-- | Roundtrip test: +-- >>> import qualified Codec.Serialise as CBOR +-- >>> CBOR.deserialise (CBOR.serialise ReceiverInitStreamTag) == ReceiverInitStreamTag +-- True +-- >>> CBOR.deserialise (CBOR.serialise ReceiverEntityRequestTag) == ReceiverEntityRequestTag +-- True instance CBOR.Serialise FromReceiverMessageTag where encode = \case ReceiverInitStreamTag -> CBOR.encode (0 :: Int) @@ -117,6 +126,17 @@ instance (ToJSON ah, FromJSON ah) => CBOR.Serialise (InitMsg ah) where Left err -> fail $ "Error decoding InitMsg from JSON: " <> err Right msg -> pure msg +-- | Roundtrip test: +-- >>> import qualified Codec.Serialise as CBOR +-- >>> let msg = InitMsg {initMsgClientVersion = 1, initMsgBranchRef = BranchRef "main", initMsgRootCausal = "hash123", initMsgRequestedDepth = Just 10} +-- >>> CBOR.deserialise (CBOR.serialise msg) == msg +-- True +-- >>> let initMsg :: FromReceiverMessage Text Text = ReceiverInitStream msg +-- >>> CBOR.deserialise (CBOR.serialise initMsg) == initMsg +-- True +-- >>> let entityReq :: FromReceiverMessage Text Text = ReceiverEntityRequest (EntityRequestMsg {hashes = [(CausalEntity, "h1")]}) +-- >>> CBOR.deserialise (CBOR.serialise entityReq) == entityReq +-- True instance (CBOR.Serialise h, ToJSON ah, FromJSON ah) => CBOR.Serialise (FromReceiverMessage ah h) where encode = \case ReceiverInitStream initMsg -> @@ -141,6 +161,16 @@ data SyncError | ConnectionError Text deriving (Show, Eq) +-- | Roundtrip test: +-- >>> import qualified Codec.Serialise as CBOR +-- >>> import qualified Data.Set as Set +-- >>> CBOR.deserialise (CBOR.serialise (InitializationError "test")) == InitializationError "test" +-- True +-- >>> CBOR.deserialise (CBOR.serialise (EncodingFailure "fail")) == EncodingFailure "fail" +-- True +-- >>> let forbidden = ForbiddenEntityRequest (Set.fromList [(CausalEntity, undefined)]) +-- >>> CBOR.deserialise (CBOR.serialise (ConnectionError "err")) == ConnectionError "err" +-- True instance CBOR.Serialise SyncError where encode = \case InitializationError msg -> @@ -173,10 +203,6 @@ data FromEmitterMessage hash text = EmitterEntityMsg (Entity hash text) deriving (Show, Eq) -data HashMappings hash smallHash = HashMappings - { hashMappings :: Map smallHash hash - } - data EntityKind = CausalEntity | NamespaceEntity @@ -202,6 +228,16 @@ instance Sqlite.FromField EntityKind where 3 -> pure PatchEntity _ -> fail $ "Unknown EntityKind tag: " <> show tag +-- | Roundtrip test: +-- >>> import qualified Codec.Serialise as CBOR +-- >>> CBOR.deserialise (CBOR.serialise CausalEntity) == CausalEntity +-- True +-- >>> CBOR.deserialise (CBOR.serialise NamespaceEntity) == NamespaceEntity +-- True +-- >>> CBOR.deserialise (CBOR.serialise DefnComponentEntity) == DefnComponentEntity +-- True +-- >>> CBOR.deserialise (CBOR.serialise PatchEntity) == PatchEntity +-- True instance CBOR.Serialise EntityKind where encode = \case CausalEntity -> CBOR.encode (0 :: Int) @@ -235,6 +271,12 @@ data Entity hash text = Entity } deriving (Show, Eq) +-- | Roundtrip test: +-- >>> import qualified Codec.Serialise as CBOR +-- >>> import U.Codebase.Sqlite.TempEntity (TempEntity(..)) +-- >>> let ent :: Entity Text Text = Entity {entityHash = "hash", entityKind = CausalEntity, entityDepth = EntityDepth 5, entityData = CBOR.CBORBytes "abc"} +-- >>> CBOR.deserialise (CBOR.serialise ent) == ent +-- True instance (CBOR.Serialise smallHash, CBOR.Serialise text) => CBOR.Serialise (Entity smallHash text) where encode (Entity {entityHash, entityKind, entityDepth, entityData}) = CBOR.encode entityHash @@ -250,14 +292,13 @@ instance (CBOR.Serialise smallHash, CBOR.Serialise text) => CBOR.Serialise (Enti pure $ Entity {entityHash, entityKind, entityData, entityDepth} -instance (Ord smallHash, CBOR.Serialise hash, CBOR.Serialise smallHash) => CBOR.Serialise (HashMappings hash smallHash) where - encode (HashMappings {hashMappings}) = - CBOR.encode hashMappings - - decode = do - hashMappings <- CBOR.decode @(Map smallHash hash) - pure $ HashMappings {hashMappings} - +-- | Roundtrip test: +-- >>> import qualified Codec.Serialise as CBOR +-- >>> import U.Codebase.Sqlite.TempEntity (TempEntity(..)) +-- >>> let ent :: Entity Text Text = Entity {entityHash = "hash", entityKind = CausalEntity, entityDepth = EntityDepth 5, entityData = CBOR.CBORBytes "abc"} +-- >>> let msg = EmitterEntityMsg ent +-- >>> CBOR.deserialise (CBOR.serialise msg) == msg +-- True instance (CBOR.Serialise hash, CBOR.Serialise text) => CBOR.Serialise (FromEmitterMessage hash text) where encode = \case EmitterEntityMsg msg -> CBOR.encode EmitterEntityTag <> CBOR.encode msg @@ -269,7 +310,12 @@ instance (CBOR.Serialise hash, CBOR.Serialise text) => CBOR.Serialise (FromEmitt data FromEmitterMessageTag = EmitterEntityTag + deriving (Show, Eq) +-- | Roundtrip test: +-- >>> import qualified Codec.Serialise as CBOR +-- >>> CBOR.deserialise (CBOR.serialise EmitterEntityTag) == EmitterEntityTag +-- True instance CBOR.Serialise FromEmitterMessageTag where encode = \case EmitterEntityTag -> CBOR.encode (0 :: Int) @@ -285,6 +331,12 @@ data MsgOrError err a | Err err deriving (Show, Eq, Ord) +-- | Roundtrip test: +-- >>> import qualified Codec.Serialise as CBOR +-- >>> CBOR.deserialise (CBOR.serialise (Msg "test" :: MsgOrError Text Text)) == Msg "test" +-- True +-- >>> CBOR.deserialise (CBOR.serialise (Err "error" :: MsgOrError Text Text)) == Err "error" +-- True instance (CBOR.Serialise a, CBOR.Serialise err) => CBOR.Serialise (MsgOrError err a) where encode = \case Msg a -> CBOR.encode (0 :: Int) <> CBOR.encode a @@ -298,10 +350,21 @@ instance (CBOR.Serialise a, CBOR.Serialise err) => CBOR.Serialise (MsgOrError er 1 -> Err <$> CBOR.decode _ -> fail $ "Unknown MsgOrError tag: " <> show tag +-- | Roundtrip test: +-- >>> import qualified Network.WebSockets as WS +-- >>> let msgVal = Msg "test" :: MsgOrError SyncError Text +-- >>> WS.fromLazyByteString (WS.toLazyByteString msgVal) == msgVal +-- True +-- >>> let errVal = Err (InitializationError "init error") :: MsgOrError SyncError Text +-- >>> WS.fromLazyByteString (WS.toLazyByteString errVal) == errVal +-- True +-- >>> let dataMsg = WS.Binary (WS.toLazyByteString msgVal) +-- >>> WS.fromDataMessage dataMsg == msgVal +-- True instance (Serialise msg) => WebSocketsData (MsgOrError SyncError msg) where fromLazyByteString bytes = CBOR.deserialiseOrFail bytes - & either (\err -> Err . EncodingFailure $ "Error decoding CBOR message from bytes: " <> tShow err) Msg + & either (\err -> Err . EncodingFailure $ "Error decoding CBOR message from bytes: " <> tShow err) id toLazyByteString = CBOR.serialise @@ -316,6 +379,11 @@ instance (Serialise msg) => WebSocketsData (MsgOrError SyncError msg) where data HashTag = HashTag (EntityKind, Int64) deriving (Show, Eq, Ord) +-- | Roundtrip test: +-- >>> import qualified Codec.Serialise as CBOR +-- >>> let tag = HashTag (CausalEntity, 42) +-- >>> CBOR.deserialise (CBOR.serialise tag) == tag +-- True instance CBOR.Serialise HashTag where encode (HashTag (kind, idx)) = CBOR.encode (kind, idx) @@ -323,6 +391,3 @@ instance CBOR.Serialise HashTag where decode = do (kind, idx) <- CBOR.decode @(EntityKind, Int64) pure $ HashTag (kind, idx) - -newtype BranchRef = BranchRef {unBranchRef :: Text} - deriving (Serialise, Eq, Show, Ord, ToJSON, FromJSON) via Text diff --git a/unison-share-api/unison-share-api.cabal b/unison-share-api/unison-share-api.cabal index 198a1c5385..1018ffdb1a 100644 --- a/unison-share-api/unison-share-api.cabal +++ b/unison-share-api/unison-share-api.cabal @@ -49,6 +49,7 @@ library Unison.Sync.Common Unison.Sync.EntityValidation Unison.Sync.Types + Unison.SyncCommon.Types Unison.SyncV2.API Unison.SyncV2.Types Unison.SyncV3.Types From ca4bb33d83ffcebef400f74bceecbc11e5329962 Mon Sep 17 00:00:00 2001 From: Chris Penner Date: Mon, 6 Oct 2025 12:46:14 -0700 Subject: [PATCH 13/15] More common sync types --- unison-share-api/src/Unison/SyncV3/Types.hs | 25 ++++++++++++++++++++- 1 file changed, 24 insertions(+), 1 deletion(-) diff --git a/unison-share-api/src/Unison/SyncV3/Types.hs b/unison-share-api/src/Unison/SyncV3/Types.hs index 687cceba3a..6618944cf9 100644 --- a/unison-share-api/src/Unison/SyncV3/Types.hs +++ b/unison-share-api/src/Unison/SyncV3/Types.hs @@ -13,7 +13,6 @@ module Unison.SyncV3.Types ) where -import Unison.SyncCommon.Types import Codec.Serialise (Serialise) import Codec.Serialise qualified as CBOR import Control.Lens hiding ((.=)) @@ -34,6 +33,7 @@ import Unison.Hash32 (Hash32) import Unison.Prelude (tShow) import Unison.Server.Orphans () import Unison.Sqlite qualified as Sqlite +import Unison.SyncCommon.Types import Unison.Util.Servant.CBOR qualified as CBOR data InitMsg authedHash = InitMsg @@ -159,6 +159,11 @@ data SyncError | -- The caller asked for a Hash they shouldn't have access to. ForbiddenEntityRequest (Set (EntityKind, Hash32)) | ConnectionError Text + | ProjectNotFound BranchRef + | UserNotFound BranchRef + | NoReadPermission BranchRef + | HashJWTVerificationError Text + | InvalidBranchRef Text BranchRef deriving (Show, Eq) -- | Roundtrip test: @@ -183,6 +188,16 @@ instance CBOR.Serialise SyncError where CBOR.encode (3 :: Int) <> CBOR.encode hashes ConnectionError err -> CBOR.encode (4 :: Int) <> CBOR.encode err + ProjectNotFound branchRef -> + CBOR.encode (5 :: Int) <> CBOR.encode branchRef + UserNotFound branchRef -> + CBOR.encode (6 :: Int) <> CBOR.encode branchRef + NoReadPermission branchRef -> + CBOR.encode (7 :: Int) <> CBOR.encode branchRef + HashJWTVerificationError err -> + CBOR.encode (8 :: Int) <> CBOR.encode err + InvalidBranchRef err branchRef -> + CBOR.encode (9 :: Int) <> CBOR.encode err <> CBOR.encode branchRef decode = do tag <- CBOR.decode @Int @@ -196,6 +211,14 @@ instance CBOR.Serialise SyncError where 4 -> do err <- CBOR.decode @Text pure $ ConnectionError err + 5 -> ProjectNotFound <$> CBOR.decode + 6 -> UserNotFound <$> CBOR.decode + 7 -> NoReadPermission <$> CBOR.decode + 8 -> HashJWTVerificationError <$> CBOR.decode + 9 -> do + err <- CBOR.decode @Text + branchRef <- CBOR.decode @BranchRef + pure $ InvalidBranchRef err branchRef _ -> fail $ "Unknown SyncError tag: " <> show tag -- A message sent from the emitter to the downloader. From cd9ce0e8cd21316ecb8f915d2b877612b5d55466 Mon Sep 17 00:00:00 2001 From: Chris Penner Date: Mon, 6 Oct 2025 17:15:28 -0700 Subject: [PATCH 14/15] Pass Token Provider --- .../Codebase/SqliteCodebase/Migrations.hs | 3 +- .../Codebase/SqliteCodebase/Operations.hs | 1 + unison-cli/src/Unison/Cli/Monad.hs | 4 +++ .../src/Unison/Codebase/Transcript/Runner.hs | 14 ++++----- unison-cli/src/Unison/CommandLine/Main.hs | 5 +++- unison-cli/src/Unison/MCP/Cli.hs | 1 + unison-cli/src/Unison/Main.hs | 17 ++++++++--- unison-cli/src/Unison/Share/SyncV3.hs | 30 +++++++------------ unison-share-api/src/Unison/SyncV3/Utils.hs | 30 +++++++++++++++++++ unison-share-api/unison-share-api.cabal | 1 + 10 files changed, 73 insertions(+), 33 deletions(-) create mode 100644 unison-share-api/src/Unison/SyncV3/Utils.hs diff --git a/parser-typechecker/src/Unison/Codebase/SqliteCodebase/Migrations.hs b/parser-typechecker/src/Unison/Codebase/SqliteCodebase/Migrations.hs index 5b13751a16..63b528166e 100644 --- a/parser-typechecker/src/Unison/Codebase/SqliteCodebase/Migrations.hs +++ b/parser-typechecker/src/Unison/Codebase/SqliteCodebase/Migrations.hs @@ -89,7 +89,8 @@ migrations regionVar getDeclType termBuffer declBuffer rootCodebasePath = sqlMigration 19 Q.addMergeBranchTables, sqlMigration 20 Q.addUpdateBranchTable, sqlMigration 21 Q.addDerivedDependentsByDependencyIndex, - sqlMigration 22 Q.addUpgradeBranchTable + sqlMigration 22 Q.addUpgradeBranchTable, + sqlMigration 23 Q.addSyncV3TempTables ] where runT :: Sqlite.Transaction () -> Sqlite.Connection -> IO () diff --git a/parser-typechecker/src/Unison/Codebase/SqliteCodebase/Operations.hs b/parser-typechecker/src/Unison/Codebase/SqliteCodebase/Operations.hs index e95091a164..b31ffaa2c2 100644 --- a/parser-typechecker/src/Unison/Codebase/SqliteCodebase/Operations.hs +++ b/parser-typechecker/src/Unison/Codebase/SqliteCodebase/Operations.hs @@ -85,6 +85,7 @@ createSchema = do Q.addUpdateBranchTable Q.addDerivedDependentsByDependencyIndex Q.addUpgradeBranchTable + Q.addSyncV3TempTables (_, emptyCausalHashId) <- emptyCausalHash (_, ProjectBranchRow {projectId, branchId}) <- insertProjectAndBranch scratchProjectName scratchBranchName emptyCausalHashId diff --git a/unison-cli/src/Unison/Cli/Monad.hs b/unison-cli/src/Unison/Cli/Monad.hs index 60a9382a1c..66621d4d11 100644 --- a/unison-cli/src/Unison/Cli/Monad.hs +++ b/unison-cli/src/Unison/Cli/Monad.hs @@ -71,6 +71,7 @@ import U.Codebase.Sqlite.DbId (ProjectBranchId, ProjectId) import U.Codebase.Sqlite.Queries qualified as Q import Unison.Auth.CredentialManager (CredentialManager) import Unison.Auth.HTTPClient (AuthenticatedHttpClient) +import Unison.Auth.Tokens (TokenProvider) import Unison.Codebase (Codebase) import Unison.Codebase qualified as Codebase import Unison.Codebase.Editor.Input (Input) @@ -158,6 +159,9 @@ type SourceName = Text -- Get the environment with 'ask'. data Env = Env { authHTTPClient :: AuthenticatedHttpClient, + -- | How to get auth tokens for a given codeserver. + -- Using AuthenticatedHttpClient takes care of this, but websocket connection need to provide auth headers manually. + tokenProvider :: TokenProvider, codebase :: Codebase IO Symbol Ann, credentialManager :: CredentialManager, -- | Generate a unique name. diff --git a/unison-cli/src/Unison/Codebase/Transcript/Runner.hs b/unison-cli/src/Unison/Codebase/Transcript/Runner.hs index b71b895d44..3b87f87d0e 100644 --- a/unison-cli/src/Unison/Codebase/Transcript/Runner.hs +++ b/unison-cli/src/Unison/Codebase/Transcript/Runner.hs @@ -96,7 +96,9 @@ withRunner :: m r withRunner isTest verbosity ucmVersion action = do credMan <- AuthN.newCredentialManager - authenticatedHTTPClient <- initTranscriptAuthenticatedHTTPClient credMan + let tokenProvider :: AuthN.TokenProvider + tokenProvider = AuthN.newTokenProvider credMan + authenticatedHTTPClient <- AuthN.newAuthenticatedHTTPClient tokenProvider ucmVersion -- If we're in a transcript test, configure the environment to use a non-existent fzf binary -- so that errors are consistent. @@ -130,6 +132,7 @@ withRunner isTest verbosity ucmVersion action = do ucmVersion baseUrlText authenticatedHTTPClient + tokenProvider credMan stanzas where @@ -138,11 +141,6 @@ withRunner isTest verbosity ucmVersion action = do RTI.withRuntime False RTI.Persistent ucmVersion \runtime -> RTI.withRuntime True RTI.Persistent ucmVersion \sbRuntime -> action runtime sbRuntime - initTranscriptAuthenticatedHTTPClient :: AuthN.CredentialManager -> m AuthN.AuthenticatedHttpClient - initTranscriptAuthenticatedHTTPClient credMan = liftIO $ do - let tokenProvider :: AuthN.TokenProvider - tokenProvider = AuthN.newTokenProvider credMan - AuthN.newAuthenticatedHTTPClient tokenProvider ucmVersion isGeneratedBlock :: ProcessedBlock -> Bool isGeneratedBlock = generated . getCommonInfoTags @@ -157,10 +155,11 @@ run :: UCMVersion -> Text -> AuthN.AuthenticatedHttpClient -> + AuthN.TokenProvider -> AuthN.CredentialManager -> Transcript -> IO (Either Error Transcript) -run isTest verbosity codebase runtime sbRuntime ucmVersion baseURL authenticatedHTTPClient credMan transcript = UnliftIO.try do +run isTest verbosity codebase runtime sbRuntime ucmVersion baseURL authenticatedHTTPClient tokenProvider credMan transcript = UnliftIO.try do let behaviors = extractBehaviors $ settings transcript let stanzas' = stanzas transcript httpManager <- HTTP.newManager HTTP.defaultManagerSettings @@ -518,6 +517,7 @@ run isTest verbosity codebase runtime sbRuntime ucmVersion baseURL authenticated let env = Cli.Env { authHTTPClient = authenticatedHTTPClient, + tokenProvider, codebase, credentialManager = credMan, generateUniqueName = do diff --git a/unison-cli/src/Unison/CommandLine/Main.hs b/unison-cli/src/Unison/CommandLine/Main.hs index 845abeee05..36c2801f98 100644 --- a/unison-cli/src/Unison/CommandLine/Main.hs +++ b/unison-cli/src/Unison/CommandLine/Main.hs @@ -26,6 +26,7 @@ import U.Codebase.Sqlite.Queries qualified as Queries import Unison.Auth.CredentialManager qualified as AuthN import Unison.Auth.HTTPClient (AuthenticatedHttpClient) import Unison.Auth.HTTPClient qualified as AuthN +import Unison.Auth.Tokens (TokenProvider) import Unison.Cli.Monad qualified as Cli import Unison.Cli.Pretty qualified as P import Unison.Cli.ProjectUtils qualified as ProjectUtils @@ -146,11 +147,12 @@ main :: Maybe Server.BaseUrl -> UCMVersion -> AuthN.AuthenticatedHttpClient -> + TokenProvider -> AuthN.CredentialManager -> (PP.ProjectPathIds -> IO ()) -> ShouldWatchFiles -> IO () -main dir welcome ppIds initialInputs runtime sbRuntime codebase serverBaseUrl ucmVersion authHTTPClient credentialManager lspCheckForChanges shouldWatchFiles = do +main dir welcome ppIds initialInputs runtime sbRuntime codebase serverBaseUrl ucmVersion authHTTPClient tokenProvider credentialManager lspCheckForChanges shouldWatchFiles = do -- we don't like FSNotify's debouncing (it seems to drop later events) -- so we will be doing our own instead let config = FSNotify.defaultConfig @@ -288,6 +290,7 @@ main dir welcome ppIds initialInputs runtime sbRuntime codebase serverBaseUrl uc { authHTTPClient, codebase, credentialManager, + tokenProvider, loadSource = loadSourceFile, lspCheckForChanges, writeSource, diff --git a/unison-cli/src/Unison/MCP/Cli.hs b/unison-cli/src/Unison/MCP/Cli.hs index 403aa7ab09..889aea5180 100644 --- a/unison-cli/src/Unison/MCP/Cli.hs +++ b/unison-cli/src/Unison/MCP/Cli.hs @@ -102,6 +102,7 @@ cliToMCP projCtx cli = do let cliEnv = Cli.Env { authHTTPClient = authenticatedHTTPClient, + tokenProvider, codebase, credentialManager = credMan, generateUniqueName = do diff --git a/unison-cli/src/Unison/Main.hs b/unison-cli/src/Unison/Main.hs index ab03c6c38d..88f94a3c26 100644 --- a/unison-cli/src/Unison/Main.hs +++ b/unison-cli/src/Unison/Main.hs @@ -54,6 +54,7 @@ import Text.Megaparsec qualified as MP import U.Codebase.Sqlite.Queries qualified as Queries import Unison.Auth.CredentialManager qualified as AuthN import Unison.Auth.HTTPClient qualified as AuthN +import Unison.Auth.Tokens (TokenProvider) import Unison.Auth.Tokens qualified as AuthN import Unison.Cli.ProjectUtils qualified as ProjectUtils import Unison.Codebase (Codebase, CodebasePath) @@ -185,7 +186,8 @@ main version = do let serverUrl = Nothing let ucmVersion = Version.gitDescribeWithDate version credMan <- liftIO $ AuthN.newCredentialManager - authenticatedHTTPClient <- initTranscriptAuthenticatedHTTPClient ucmVersion credMan + let tokenProvider = AuthN.newTokenProvider credMan + authenticatedHTTPClient <- AuthN.newAuthenticatedHTTPClient tokenProvider ucmVersion startProjectPath <- Codebase.runTransaction theCodebase Codebase.expectCurrentProjectPath launch version @@ -195,6 +197,7 @@ main version = do theCodebase [Left fileEvent, Right $ Input.ExecuteI NoProf mainName args, Right Input.QuitI] authenticatedHTTPClient + tokenProvider credMan serverUrl (PP.toIds startProjectPath) @@ -213,7 +216,8 @@ main version = do let serverUrl = Nothing let ucmVersion = Version.gitDescribeWithDate version credMan <- liftIO $ AuthN.newCredentialManager - authenticatedHTTPClient <- initTranscriptAuthenticatedHTTPClient ucmVersion credMan + let tokenProvider = AuthN.newTokenProvider credMan + authenticatedHTTPClient <- AuthN.newAuthenticatedHTTPClient tokenProvider ucmVersion startProjectPath <- Codebase.runTransaction theCodebase Codebase.expectCurrentProjectPath launch version @@ -223,6 +227,7 @@ main version = do theCodebase [Left fileEvent, Right $ Input.ExecuteI NoProf mainName args, Right Input.QuitI] authenticatedHTTPClient + tokenProvider credMan serverUrl (PP.toIds startProjectPath) @@ -330,7 +335,8 @@ main version = do let isTest = False let ucmVersion = Version.gitDescribeWithDate version credMan <- liftIO $ AuthN.newCredentialManager - authenticatedHTTPClient <- initTranscriptAuthenticatedHTTPClient ucmVersion credMan + let tokenProvider = AuthN.newTokenProvider credMan + authenticatedHTTPClient <- AuthN.newAuthenticatedHTTPClient tokenProvider ucmVersion mcpServerConfig <- MCP.initServer theCodebase runtime sbRuntime (Just currentDir) ucmVersion authenticatedHTTPClient Server.startServer isTest @@ -374,6 +380,7 @@ main version = do theCodebase [] authenticatedHTTPClient + tokenProvider credMan mayBaseUrl (PP.toIds startingProjectPath) @@ -596,6 +603,7 @@ launch :: Codebase.Codebase IO Symbol Ann -> [Either Input.Event Input.Input] -> AuthN.AuthenticatedHttpClient -> + TokenProvider -> AuthN.CredentialManager -> Maybe Server.BaseUrl -> PP.ProjectPathIds -> @@ -603,7 +611,7 @@ launch :: (PP.ProjectPathIds -> IO ()) -> CommandLine.ShouldWatchFiles -> IO () -launch version dir runtime sbRuntime codebase inputs authenticatedHTTPClient credMan serverBaseUrl startingPath initResult lspCheckForChanges shouldWatchFiles = do +launch version dir runtime sbRuntime codebase inputs authenticatedHTTPClient tokenProvider credMan serverBaseUrl startingPath initResult lspCheckForChanges shouldWatchFiles = do showWelcomeHint <- Codebase.runTransaction codebase Queries.doProjectsExist let isNewCodebase = case initResult of CreatedCodebase -> NewlyCreatedCodebase @@ -621,6 +629,7 @@ launch version dir runtime sbRuntime codebase inputs authenticatedHTTPClient cre serverBaseUrl ucmVersion authenticatedHTTPClient + tokenProvider credMan lspCheckForChanges shouldWatchFiles diff --git a/unison-cli/src/Unison/Share/SyncV3.hs b/unison-cli/src/Unison/Share/SyncV3.hs index 6a3bd771f6..5045da4ab8 100644 --- a/unison-cli/src/Unison/Share/SyncV3.hs +++ b/unison-cli/src/Unison/Share/SyncV3.hs @@ -3,18 +3,17 @@ module Unison.Share.SyncV3 ) where +import Network.Socket (withSocketsDo) import Control.Arrow ((&&&)) import Control.Monad.Reader import Data.Set qualified as Set -import Data.Set.Lens qualified as Lens +import Data.Text.Encoding as Text import GHC.Natural import Ki qualified import Network.WebSockets qualified as WS import U.Codebase.HashTags import U.Codebase.Sqlite.DbId -import U.Codebase.Sqlite.Entity qualified as Entity import U.Codebase.Sqlite.Queries qualified as Q -import U.Codebase.Sqlite.TempEntity (TempEntity) import U.Codebase.Sqlite.V2.HashHandle (v2HashHandle) import Unison.Cli.Monad import Unison.Cli.Monad qualified as Cli @@ -27,9 +26,11 @@ import Unison.Server.Orphans () import Unison.Share.API.Hash qualified as Share import Unison.Share.Codeserver qualified as Codeserver import Unison.Share.Sync.Types qualified as Sync +import Unison.Share.Types import Unison.Sync.Common qualified as Sync import Unison.SyncV3.Types import Unison.SyncV3.Types as SyncV3 +import Unison.SyncV3.Utils (tempEntityDependencies) import Unison.Util.Servant.CBOR qualified as CBOR import Unison.Util.Websockets (Queues (..), withQueues) import UnliftIO.STM @@ -58,14 +59,16 @@ syncFromCodeserver :: Share.HashJWT -> Cli (Either (Sync.SyncError SyncV3.SyncError) (CausalHash, CausalHashId)) syncFromCodeserver _shouldValidate codeserver branchRef hashJwt = do - Cli.Env {codebase} <- ask + Cli.Env {codebase, tokenProvider} <- ask let host = Codeserver.codeserverRegName codeserver let syncV3Path = "/ucm/v3/sync/download" let rootCausalHash = Share.hashJWTHash hashJwt -- Enable compression let connectionOptions = WS.defaultConnectionOptions {WS.connectionCompressionOptions = WS.PermessageDeflateCompression WS.defaultPermessageDeflate} - -- TODO: Add authentication headers manually. - let headers = [] + headers <- + (liftIO (tokenProvider (codeserverIdFromCodeserverURI codeserver))) <&> \case + Left {} -> [] + Right token -> [("Authorization", "Bearer " <> Text.encodeUtf8 token)] let runner = case Codeserver.codeserverScheme codeserver of Codeserver.Https -> let tlsPort = 443 @@ -76,7 +79,7 @@ syncFromCodeserver _shouldValidate codeserver branchRef hashJwt = do port = maybe tlsPort id $ (Codeserver.codeserverPort) codeserver in WS.runClientWith host port Debug.debugLogM Debug.Temp "Obtaining Connection" - liftIO $ (runner syncV3Path connectionOptions headers) \conn -> do + liftIO $ withSocketsDo $ (runner syncV3Path connectionOptions headers) \conn -> do Debug.debugLogM Debug.Temp "Obtained Connection" withQueues inputBuffer outputBuffer conn $ \queues@Queues {send} -> do Debug.debugLogM Debug.Temp "Obtained Queues" @@ -216,16 +219,3 @@ flushTemp codebase rootCausalHash = do loop loop Q.expectCausalHashIdByCausalHash (Sync.hash32ToCausalHash rootCausalHash) - -tempEntityDependencies :: TempEntity -> Set (EntityKind, Hash32) -tempEntityDependencies entity = do - let componentDeps = Lens.setOf Entity.defns_ entity - patchDeps = Lens.setOf Entity.patches_ entity - branchHashes = Lens.setOf Entity.branchHashes_ entity <> Lens.setOf Entity.branches_ entity - causalHashes = Lens.setOf Entity.causalHashes_ entity - in Set.unions - [ Set.map (DefnComponentEntity,) componentDeps, - Set.map (PatchEntity,) patchDeps, - Set.map (NamespaceEntity,) branchHashes, - Set.map (CausalEntity,) causalHashes - ] diff --git a/unison-share-api/src/Unison/SyncV3/Utils.hs b/unison-share-api/src/Unison/SyncV3/Utils.hs new file mode 100644 index 0000000000..513fba3ef7 --- /dev/null +++ b/unison-share-api/src/Unison/SyncV3/Utils.hs @@ -0,0 +1,30 @@ +module Unison.SyncV3.Utils (tempEntityDependencies, entityDependencies) where + +import Data.Set (Set) +import Data.Set qualified as Set +import Data.Set.Lens qualified as Lens +import U.Codebase.Sqlite.Entity qualified as Entity +import U.Codebase.Sqlite.TempEntity +import Unison.Hash32 (Hash32) +import Unison.SyncV3.Types +import Unison.Util.Servant.CBOR qualified as CBOR + +tempEntityDependencies :: TempEntity -> Set (EntityKind, Hash32) +tempEntityDependencies entity = do + let componentDeps = Lens.setOf Entity.defns_ entity + patchDeps = Lens.setOf Entity.patches_ entity + branchHashes = Lens.setOf Entity.branchHashes_ entity <> Lens.setOf Entity.branches_ entity + causalHashes = Lens.setOf Entity.causalHashes_ entity + in Set.unions + [ Set.map (DefnComponentEntity,) componentDeps, + Set.map (PatchEntity,) patchDeps, + Set.map (NamespaceEntity,) branchHashes, + Set.map (CausalEntity,) causalHashes + ] + +entityDependencies :: Entity hash text -> Set (EntityKind, Hash32) +entityDependencies Entity {entityData} = do + case (CBOR.deserialiseOrFailCBORBytes $ entityData) of + -- TODO: proper error handling + Left err -> error $ show err + Right tempEntity -> tempEntityDependencies tempEntity diff --git a/unison-share-api/unison-share-api.cabal b/unison-share-api/unison-share-api.cabal index 1018ffdb1a..e995cadf76 100644 --- a/unison-share-api/unison-share-api.cabal +++ b/unison-share-api/unison-share-api.cabal @@ -53,6 +53,7 @@ library Unison.SyncV2.API Unison.SyncV2.Types Unison.SyncV3.Types + Unison.SyncV3.Utils Unison.Util.Find Unison.Util.Servant.CBOR Unison.Util.Websockets From 31655c33f2bff2c6b23fc366b22bb3bd0ad24a9a Mon Sep 17 00:00:00 2001 From: Chris Penner Date: Tue, 7 Oct 2025 11:40:15 -0700 Subject: [PATCH 15/15] Fix up some queries --- .../codebase-sqlite/U/Codebase/Sqlite/Queries.hs | 3 ++- .../sql/020-add-sync-v3-temp-tables.sql | 2 +- unison-cli/package.yaml | 1 + unison-cli/src/Unison/Share/SyncV3.hs | 13 +++++++++++-- unison-cli/unison-cli.cabal | 1 + unison-share-api/src/Unison/SyncV3/Types.hs | 5 ----- 6 files changed, 16 insertions(+), 9 deletions(-) diff --git a/codebase2/codebase-sqlite/U/Codebase/Sqlite/Queries.hs b/codebase2/codebase-sqlite/U/Codebase/Sqlite/Queries.hs index 58f2756540..6b96067540 100644 --- a/codebase2/codebase-sqlite/U/Codebase/Sqlite/Queries.hs +++ b/codebase2/codebase-sqlite/U/Codebase/Sqlite/Queries.hs @@ -418,7 +418,7 @@ type TextPathSegments = [Text] -- * main squeeze currentSchemaVersion :: SchemaVersion -currentSchemaVersion = 22 +currentSchemaVersion = 23 runCreateSql :: Transaction () runCreateSql = @@ -4039,6 +4039,7 @@ streamTempEntitiesSyncV3 rootCausalHash action = do Sqlite.queryStreamRow @(Hash32, BL.ByteString) [sql| SELECT entity_hash, entity_data + FROM syncv3_temp_entity WHERE root_causal = :rootCausalHash ORDER BY entity_depth ASC |] diff --git a/codebase2/codebase-sqlite/sql/020-add-sync-v3-temp-tables.sql b/codebase2/codebase-sqlite/sql/020-add-sync-v3-temp-tables.sql index 243dc128ab..5bd3eba986 100644 --- a/codebase2/codebase-sqlite/sql/020-add-sync-v3-temp-tables.sql +++ b/codebase2/codebase-sqlite/sql/020-add-sync-v3-temp-tables.sql @@ -1,7 +1,7 @@ -- Add a new table for storing entities which are currently being synced CREATE TABLE syncv3_temp_entity ( - root_causal INTEGER NOT NULL REFERENCES hash (id) ON DELETE CASCADE, + root_causal INTEGER NOT NULL, entity_hash TEXT NOT NULL, entity_kind TEXT NOT NULL, entity_data BLOB NOT NULL, diff --git a/unison-cli/package.yaml b/unison-cli/package.yaml index 8aa0135894..844f92d83d 100644 --- a/unison-cli/package.yaml +++ b/unison-cli/package.yaml @@ -53,6 +53,7 @@ library: - megaparsec - memory - mtl + - network - network-simple - network-uri - nonempty-containers diff --git a/unison-cli/src/Unison/Share/SyncV3.hs b/unison-cli/src/Unison/Share/SyncV3.hs index 5045da4ab8..80386908ec 100644 --- a/unison-cli/src/Unison/Share/SyncV3.hs +++ b/unison-cli/src/Unison/Share/SyncV3.hs @@ -108,7 +108,7 @@ syncFromCodeserver _shouldValidate codeserver branchRef hashJwt = do -- TODO: proper error handling Left err -> error $ show err Right () -> pure () - Debug.debugLogM Debug.Temp "Done sync, flushing temp entities" + Debug.debugLogM Debug.Temp "!Done sync, flushing temp entities" causalId <- liftIO $ flushTemp codebase (Share.hashJWTHash hashJwt) pure $ Right (Sync.hash32ToCausalHash rootCausalHash, causalId) @@ -129,11 +129,16 @@ doSync codebase SyncState {pendingRequestsVar, yetToRequestVar, toIngestQueue, r _ <- Ki.fork scope (receiverWorker onErr) _ <- Ki.fork scope (requesterWorker onErr) _ <- Ki.fork scope (ingestionWorker onErr) + let finished = do + pending <- readTVar pendingRequestsVar + yetToReq <- readTVar yetToRequestVar + guard $ Set.null pending && Set.null yetToReq Debug.debugLogM Debug.Temp "Awaiting completion" result <- atomically $ - (Right <$> Ki.awaitAll scope) + (Right <$> finished) + <|> (Right <$> Ki.awaitAll scope) <|> (Left . Left <$> readTMVar errorVar) <|> (Left . Right <$> connectionClosed) @@ -159,6 +164,7 @@ doSync codebase SyncState {pendingRequestsVar, yetToRequestVar, toIngestQueue, r Debug.debugLogM Debug.Temp "Requester waiting to send requests" atomically $ do requests <- readTVar yetToRequestVar + guard $ not (Set.null requests) writeTVar yetToRequestVar Set.empty modifyTVar' pendingRequestsVar (Set.union requests) send $ Msg $ ReceiverEntityRequest $ EntityRequestMsg (Set.toList requests) @@ -211,11 +217,14 @@ flushTemp codebase rootCausalHash = do Nothing -> pure () Just (hash, tempEntityBytes) -> do + Debug.debugLogM Debug.Temp $ "Flushing temp entity: " <> show hash tempEntity <- case CBOR.deserialiseOrFailCBORBytes (CBOR.CBORBytes tempEntityBytes) of -- TODO: proper error handling Left err -> error $ show err Right tempEntity -> pure tempEntity + Debug.debugLogM Debug.Temp $ "Saving in main" <> show hash void $ Q.saveTempEntityInMain v2HashHandle hash tempEntity loop loop + Debug.debugLogM Debug.Temp "Flushed temp entities, getting causal hash id" Q.expectCausalHashIdByCausalHash (Sync.hash32ToCausalHash rootCausalHash) diff --git a/unison-cli/unison-cli.cabal b/unison-cli/unison-cli.cabal index 37651a2157..dbea97e245 100644 --- a/unison-cli/unison-cli.cabal +++ b/unison-cli/unison-cli.cabal @@ -248,6 +248,7 @@ library , megaparsec , memory , mtl + , network , network-simple , network-uri , nonempty-containers diff --git a/unison-share-api/src/Unison/SyncV3/Types.hs b/unison-share-api/src/Unison/SyncV3/Types.hs index 6618944cf9..8ada257633 100644 --- a/unison-share-api/src/Unison/SyncV3/Types.hs +++ b/unison-share-api/src/Unison/SyncV3/Types.hs @@ -28,7 +28,6 @@ import Network.WebSockets (WebSocketsData) import Network.WebSockets qualified as WS import U.Codebase.Sqlite.Orphans () import U.Codebase.Sqlite.TempEntity -import Unison.Debug qualified as Debug import Unison.Hash32 (Hash32) import Unison.Prelude (tShow) import Unison.Server.Orphans () @@ -97,7 +96,6 @@ instance CBOR.Serialise FromReceiverMessageTag where decode = do tag <- CBOR.decode @Int - Debug.debugM Debug.Temp "Decoding FromReceiverMessageTag with tag" tag case tag of 0 -> pure ReceiverInitStreamTag 1 -> pure ReceiverEntityRequestTag @@ -120,7 +118,6 @@ instance (ToJSON ah, FromJSON ah) => CBOR.Serialise (InitMsg ah) where CBOR.encode @BS.ByteString $ BL.toStrict $ Aeson.encode msg decode = do - Debug.debugLogM Debug.Temp "Decoding InitMsg from JSON via CBOR" bs <- CBOR.decode @BS.ByteString case Aeson.eitherDecode $ BL.fromStrict bs of Left err -> fail $ "Error decoding InitMsg from JSON: " <> err @@ -147,7 +144,6 @@ instance (CBOR.Serialise h, ToJSON ah, FromJSON ah) => CBOR.Serialise (FromRecei <> CBOR.encode msg decode = do tag <- CBOR.decode @FromReceiverMessageTag - Debug.debugM Debug.Temp "Decoding FromReceiverMessage with tag" tag case tag of ReceiverInitStreamTag -> ReceiverInitStream <$> CBOR.decode @(InitMsg ah) ReceiverEntityRequestTag -> ReceiverEntityRequest <$> CBOR.decode @(EntityRequestMsg h) @@ -367,7 +363,6 @@ instance (CBOR.Serialise a, CBOR.Serialise err) => CBOR.Serialise (MsgOrError er decode = do tag <- CBOR.decode @Int - Debug.debugM Debug.Temp "Decoding MsgOrError with tag" tag case tag of 0 -> Msg <$> CBOR.decode 1 -> Err <$> CBOR.decode