From 4d75b2a9fb96308edf2701831cd9c073fa9ee2a3 Mon Sep 17 00:00:00 2001 From: Bryan Chen Date: Fri, 8 Nov 2024 23:34:12 +1300 Subject: [PATCH] state trie fixed --- .../Blockchain/State/InMemoryBackend.swift | 30 ++++--- .../Sources/Blockchain/State/StateTrie.swift | 81 ++++++++++++------- .../BlockchainTests/StateTrieTests.swift | 64 ++++++++++++++- 3 files changed, 134 insertions(+), 41 deletions(-) diff --git a/Blockchain/Sources/Blockchain/State/InMemoryBackend.swift b/Blockchain/Sources/Blockchain/State/InMemoryBackend.swift index 0452f185..af5dadce 100644 --- a/Blockchain/Sources/Blockchain/State/InMemoryBackend.swift +++ b/Blockchain/Sources/Blockchain/State/InMemoryBackend.swift @@ -1,22 +1,25 @@ import Codec import Foundation +import TracingUtils import Utils -private struct KVPair: Comparable, Sendable { - var key: Data - var value: Data +private let logger = Logger(label: "InMemoryBackend") - public static func < (lhs: KVPair, rhs: KVPair) -> Bool { - lhs.key.lexicographicallyPrecedes(rhs.key) +public actor InMemoryBackend: StateBackendProtocol { + public struct KVPair: Comparable, Sendable { + var key: Data + var value: Data + + public static func < (lhs: KVPair, rhs: KVPair) -> Bool { + lhs.key.lexicographicallyPrecedes(rhs.key) + } } -} -public actor InMemoryBackend: StateBackendProtocol { // we really should be using Heap or some other Tree based structure here // but let's keep it simple for now - private var store: SortedArray = .init([]) + public private(set) var store: SortedArray = .init([]) private var rawValues: [Data32: Data] = [:] - private var refCounts: [Data: Int] = [:] + public private(set) var refCounts: [Data: Int] = [:] private var rawValueRefCounts: [Data32: Int] = [:] public init() {} @@ -92,4 +95,13 @@ public actor InMemoryBackend: StateBackendProtocol { } } } + + public func debugPrint() { + for item in store.array { + let refCount = refCounts[item.key, default: 0] + logger.info("key: \(item.key.toHexString())") + logger.info("value: \(item.value.toHexString())") + logger.info("ref count: \(refCount)") + } + } } diff --git a/Blockchain/Sources/Blockchain/State/StateTrie.swift b/Blockchain/Sources/Blockchain/State/StateTrie.swift index af5696e0..40e26a47 100644 --- a/Blockchain/Sources/Blockchain/State/StateTrie.swift +++ b/Blockchain/Sources/Blockchain/State/StateTrie.swift @@ -1,18 +1,15 @@ import Foundation +import TracingUtils import Utils +private let logger = Logger(label: "StateTrie") + private enum TrieNodeType { case branch case embeddedLeaf case regularLeaf } -private func toId(hash: Data32) -> Data { - var id = hash.data - id[0] = id[0] & 0b0111_1111 // clear the highest bit - return id -} - private struct TrieNode { let hash: Data32 let left: Data32 @@ -20,7 +17,6 @@ private struct TrieNode { let type: TrieNodeType let isNew: Bool let rawValue: Data? - let id: Data init(hash: Data32, data: Data64, isNew: Bool = false) { self.hash = hash @@ -36,7 +32,6 @@ private struct TrieNode { default: type = .branch } - id = toId(hash: hash) } private init(left: Data32, right: Data32, type: TrieNodeType, isNew: Bool, rawValue: Data?) { @@ -46,7 +41,6 @@ private struct TrieNode { self.type = type self.isNew = isNew self.rawValue = rawValue - id = toId(hash: hash) } var encodedData: Data64 { @@ -144,7 +138,7 @@ public actor StateTrie: Sendable { if hash == Data32() { return nil } - let id = toId(hash: hash) + let id = hash.data.suffix(31) if deleted.contains(id) { return nil } @@ -180,34 +174,33 @@ public actor StateTrie: Sendable { var refChanges = [Data: Int]() // process deleted nodes - let deletedCopy = deleted - deleted.removeAll() - for id in deletedCopy { + for id in deleted { guard let node = nodes[id] else { continue } if node.isBranch { // assign -1 to not worry about duplicates - refChanges[node.hash.data] = -1 - refChanges[node.left.data] = -1 - refChanges[node.right.data] = -1 + refChanges[node.hash.data.suffix(31)] = -1 + refChanges[node.left.data.suffix(31)] = -1 + refChanges[node.right.data.suffix(31)] = -1 } nodes.removeValue(forKey: id) } + deleted.removeAll() for node in nodes.values where node.isNew { - ops.append(.write(key: node.id, value: node.encodedData.data)) + ops.append(.write(key: node.hash.data.suffix(31), value: node.encodedData.data)) if node.type == .regularLeaf { try ops.append(.writeRawValue(key: node.right, value: node.rawValue.unwrap())) } if node.isBranch { - refChanges[node.left.data] = (refChanges[node.left.data] ?? 0) + 1 - refChanges[node.right.data] = (refChanges[node.right.data] ?? 0) + 1 + refChanges[node.left.data.suffix(31), default: 0] += 1 + refChanges[node.right.data.suffix(31), default: 0] += 1 } } // pin root node - refChanges[rootHash.data] = (refChanges[rootHash.data] ?? 0) + 1 + refChanges[rootHash.data.suffix(31), default: 0] += 1 nodes.removeAll() @@ -217,9 +210,9 @@ public actor StateTrie: Sendable { continue } if value > 0 { - ops.append(.refIncrement(key: key)) + ops.append(.refIncrement(key: key.suffix(31))) } else if value < 0 { - ops.append(.refDecrement(key: key)) + ops.append(.refDecrement(key: key.suffix(31))) } } @@ -234,9 +227,10 @@ public actor StateTrie: Sendable { saveNode(node: node) return node.hash } - removeNode(hash: hash) if parent.isBranch { + removeNode(hash: hash) + let bitValue = bitAt(key.data, position: depth) var left = parent.left var right = parent.right @@ -262,7 +256,7 @@ public actor StateTrie: Sendable { return newLeaf.hash } - let existingKeyBit = bitAt(existing.left.data, position: depth) + let existingKeyBit = bitAt(existing.left.data[1...], position: depth) let newKeyBit = bitAt(newKey.data, position: depth) if existingKeyBit == newKeyBit { @@ -292,9 +286,10 @@ public actor StateTrie: Sendable { private func delete(hash: Data32, key: Data32, depth: UInt8) async throws -> Data32 { let node = try await get(hash: hash).unwrap(orError: StateTrieError.invalidParent) - removeNode(hash: hash) if node.isBranch { + removeNode(hash: hash) + let bitValue = bitAt(key.data, position: depth) var left = node.left var right = node.right @@ -320,14 +315,44 @@ public actor StateTrie: Sendable { } private func removeNode(hash: Data32) { - let id = toId(hash: hash) + let id = hash.data.suffix(31) deleted.insert(id) nodes.removeValue(forKey: id) } private func saveNode(node: TrieNode) { - nodes[node.id] = node - deleted.remove(node.id) // TODO: maybe this is not needed + let id = node.hash.data.suffix(31) + nodes[id] = node + deleted.remove(id) // TODO: maybe this is not needed + } + + public func debugPrint() async throws { + func printNode(_ hash: Data32, depth: UInt8) async throws { + let prefix = String(repeating: " ", count: Int(depth)) + if hash == Data32() { + logger.info("\(prefix) nil") + return + } + let node = try await get(hash: hash) + guard let node else { + return logger.info("\(prefix) ????") + } + logger.info("\(prefix)\(node.hash.toHexString()) \(node.type)") + if node.isBranch { + logger.info("\(prefix) left:") + try await printNode(node.left, depth: depth + 1) + + logger.info("\(prefix) right:") + try await printNode(node.right, depth: depth + 1) + } else { + logger.info("\(prefix) key: \(node.left.toHexString())") + if let value = node.value { + logger.info("\(prefix) value: \(value.toHexString())") + } + } + } + + try await printNode(rootHash, depth: 0) } } diff --git a/Blockchain/Tests/BlockchainTests/StateTrieTests.swift b/Blockchain/Tests/BlockchainTests/StateTrieTests.swift index 43f520af..60526eaa 100644 --- a/Blockchain/Tests/BlockchainTests/StateTrieTests.swift +++ b/Blockchain/Tests/BlockchainTests/StateTrieTests.swift @@ -1,8 +1,20 @@ -@testable import Blockchain import Foundation import Testing +import TracingUtils import Utils +@testable import Blockchain + +private let logger = Logger(label: "StateTrieTests") + +private func merklize(_ data: some Sequence<(key: Data32, value: Data)>) -> Data32 { + var dict = [Data32: Data]() + for (key, value) in data { + dict[key] = value + } + return try! stateMerklize(kv: dict) +} + struct StateTrieTests { let backend = InMemoryBackend() @@ -29,15 +41,59 @@ struct StateTrieTests { #expect(retrieved == value) } + @Test + func testInsertAndRetrieveSimple() async throws { + let trie = StateTrie(rootHash: Data32(), backend: backend) + let remainKey = Data(repeating: 0, count: 31) + let pairs = [ + (key: Data32(Data([0b0000_0000]) + remainKey)!, value: Data([0])), + (key: Data32(Data([0b1000_0000]) + remainKey)!, value: Data([1])), + (key: Data32(Data([0b0100_0000]) + remainKey)!, value: Data([2])), + (key: Data32(Data([0b1100_0000]) + remainKey)!, value: Data([3])), + ] + + for (i, pair) in pairs.enumerated() { + try await trie.update([(key: pair.key, value: pair.value)]) + + let expectedRoot = merklize(pairs[0 ... i]) + let trieRoot = await trie.rootHash + #expect(expectedRoot == trieRoot) + } + + for (i, (key, value)) in pairs.enumerated() { + let retrieved = try await trie.read(key: key) + #expect(retrieved == value, "Failed at index \(i)") + } + + try await trie.save() + + for (i, (key, value)) in pairs.enumerated() { + let retrieved = try await trie.read(key: key) + #expect(retrieved == value, "Failed at index \(i)") + } + } + @Test func testInsertAndRetrieveMultipleValues() async throws { let trie = StateTrie(rootHash: Data32(), backend: backend) - let pairs = (0 ..< 5).map { i in - let data = Data(String(i).utf8) + let pairs = (0 ..< 50).map { i in + let data = Data([UInt8(i)]) return (key: data.blake2b256hash(), value: data) } - try await trie.update(pairs) + for (i, pair) in pairs.enumerated() { + try await trie.update([(key: pair.key, value: pair.value)]) + + let expectedRoot = merklize(pairs[0 ... i]) + let trieRoot = await trie.rootHash + #expect(expectedRoot == trieRoot) + } + + for (i, (key, value)) in pairs.enumerated() { + let retrieved = try await trie.read(key: key) + #expect(retrieved == value, "Failed at index \(i)") + } + try await trie.save() for (i, (key, value)) in pairs.enumerated() {