Skip to content

Commit

Permalink
state trie fixed
Browse files Browse the repository at this point in the history
  • Loading branch information
xlc committed Nov 8, 2024
1 parent e17f4e2 commit 4d75b2a
Show file tree
Hide file tree
Showing 3 changed files with 134 additions and 41 deletions.
30 changes: 21 additions & 9 deletions Blockchain/Sources/Blockchain/State/InMemoryBackend.swift
Original file line number Diff line number Diff line change
@@ -1,22 +1,25 @@
import Codec
import Foundation
import TracingUtils
import Utils

private struct KVPair: Comparable, Sendable {
var key: Data
var value: Data
private let logger = Logger(label: "InMemoryBackend")

public static func < (lhs: KVPair, rhs: KVPair) -> Bool {
lhs.key.lexicographicallyPrecedes(rhs.key)
public actor InMemoryBackend: StateBackendProtocol {
public struct KVPair: Comparable, Sendable {
var key: Data
var value: Data

public static func < (lhs: KVPair, rhs: KVPair) -> Bool {
lhs.key.lexicographicallyPrecedes(rhs.key)
}
}
}

public actor InMemoryBackend: StateBackendProtocol {
// we really should be using Heap or some other Tree based structure here
// but let's keep it simple for now
private var store: SortedArray<KVPair> = .init([])
public private(set) var store: SortedArray<KVPair> = .init([])
private var rawValues: [Data32: Data] = [:]
private var refCounts: [Data: Int] = [:]
public private(set) var refCounts: [Data: Int] = [:]
private var rawValueRefCounts: [Data32: Int] = [:]

public init() {}
Expand Down Expand Up @@ -92,4 +95,13 @@ public actor InMemoryBackend: StateBackendProtocol {
}
}
}

public func debugPrint() {
for item in store.array {
let refCount = refCounts[item.key, default: 0]
logger.info("key: \(item.key.toHexString())")
logger.info("value: \(item.value.toHexString())")
logger.info("ref count: \(refCount)")
}
}
}
81 changes: 53 additions & 28 deletions Blockchain/Sources/Blockchain/State/StateTrie.swift
Original file line number Diff line number Diff line change
@@ -1,26 +1,22 @@
import Foundation
import TracingUtils
import Utils

private let logger = Logger(label: "StateTrie")

private enum TrieNodeType {
case branch
case embeddedLeaf
case regularLeaf
}

private func toId(hash: Data32) -> Data {
var id = hash.data
id[0] = id[0] & 0b0111_1111 // clear the highest bit
return id
}

private struct TrieNode {
let hash: Data32
let left: Data32
let right: Data32
let type: TrieNodeType
let isNew: Bool
let rawValue: Data?
let id: Data

init(hash: Data32, data: Data64, isNew: Bool = false) {
self.hash = hash
Expand All @@ -36,7 +32,6 @@ private struct TrieNode {
default:
type = .branch
}
id = toId(hash: hash)
}

private init(left: Data32, right: Data32, type: TrieNodeType, isNew: Bool, rawValue: Data?) {
Expand All @@ -46,7 +41,6 @@ private struct TrieNode {
self.type = type
self.isNew = isNew
self.rawValue = rawValue
id = toId(hash: hash)
}

var encodedData: Data64 {
Expand Down Expand Up @@ -144,7 +138,7 @@ public actor StateTrie: Sendable {
if hash == Data32() {
return nil
}
let id = toId(hash: hash)
let id = hash.data.suffix(31)
if deleted.contains(id) {
return nil
}
Expand Down Expand Up @@ -180,34 +174,33 @@ public actor StateTrie: Sendable {
var refChanges = [Data: Int]()

// process deleted nodes
let deletedCopy = deleted
deleted.removeAll()
for id in deletedCopy {
for id in deleted {
guard let node = nodes[id] else {
continue
}
if node.isBranch {
// assign -1 to not worry about duplicates
refChanges[node.hash.data] = -1
refChanges[node.left.data] = -1
refChanges[node.right.data] = -1
refChanges[node.hash.data.suffix(31)] = -1
refChanges[node.left.data.suffix(31)] = -1
refChanges[node.right.data.suffix(31)] = -1
}
nodes.removeValue(forKey: id)
}
deleted.removeAll()

for node in nodes.values where node.isNew {
ops.append(.write(key: node.id, value: node.encodedData.data))
ops.append(.write(key: node.hash.data.suffix(31), value: node.encodedData.data))
if node.type == .regularLeaf {
try ops.append(.writeRawValue(key: node.right, value: node.rawValue.unwrap()))
}
if node.isBranch {
refChanges[node.left.data] = (refChanges[node.left.data] ?? 0) + 1
refChanges[node.right.data] = (refChanges[node.right.data] ?? 0) + 1
refChanges[node.left.data.suffix(31), default: 0] += 1
refChanges[node.right.data.suffix(31), default: 0] += 1
}
}

// pin root node
refChanges[rootHash.data] = (refChanges[rootHash.data] ?? 0) + 1
refChanges[rootHash.data.suffix(31), default: 0] += 1

nodes.removeAll()

Expand All @@ -217,9 +210,9 @@ public actor StateTrie: Sendable {
continue
}
if value > 0 {
ops.append(.refIncrement(key: key))
ops.append(.refIncrement(key: key.suffix(31)))
} else if value < 0 {
ops.append(.refDecrement(key: key))
ops.append(.refDecrement(key: key.suffix(31)))
}
}

Expand All @@ -234,9 +227,10 @@ public actor StateTrie: Sendable {
saveNode(node: node)
return node.hash
}
removeNode(hash: hash)

if parent.isBranch {
removeNode(hash: hash)

let bitValue = bitAt(key.data, position: depth)
var left = parent.left
var right = parent.right
Expand All @@ -262,7 +256,7 @@ public actor StateTrie: Sendable {
return newLeaf.hash
}

let existingKeyBit = bitAt(existing.left.data, position: depth)
let existingKeyBit = bitAt(existing.left.data[1...], position: depth)
let newKeyBit = bitAt(newKey.data, position: depth)

if existingKeyBit == newKeyBit {
Expand Down Expand Up @@ -292,9 +286,10 @@ public actor StateTrie: Sendable {

private func delete(hash: Data32, key: Data32, depth: UInt8) async throws -> Data32 {
let node = try await get(hash: hash).unwrap(orError: StateTrieError.invalidParent)
removeNode(hash: hash)

if node.isBranch {
removeNode(hash: hash)

let bitValue = bitAt(key.data, position: depth)
var left = node.left
var right = node.right
Expand All @@ -320,14 +315,44 @@ public actor StateTrie: Sendable {
}

private func removeNode(hash: Data32) {
let id = toId(hash: hash)
let id = hash.data.suffix(31)
deleted.insert(id)
nodes.removeValue(forKey: id)
}

private func saveNode(node: TrieNode) {
nodes[node.id] = node
deleted.remove(node.id) // TODO: maybe this is not needed
let id = node.hash.data.suffix(31)
nodes[id] = node
deleted.remove(id) // TODO: maybe this is not needed
}

public func debugPrint() async throws {
func printNode(_ hash: Data32, depth: UInt8) async throws {
let prefix = String(repeating: " ", count: Int(depth))
if hash == Data32() {
logger.info("\(prefix) nil")
return
}
let node = try await get(hash: hash)
guard let node else {
return logger.info("\(prefix) ????")
}
logger.info("\(prefix)\(node.hash.toHexString()) \(node.type)")
if node.isBranch {
logger.info("\(prefix) left:")
try await printNode(node.left, depth: depth + 1)

logger.info("\(prefix) right:")
try await printNode(node.right, depth: depth + 1)
} else {
logger.info("\(prefix) key: \(node.left.toHexString())")
if let value = node.value {
logger.info("\(prefix) value: \(value.toHexString())")
}
}
}

try await printNode(rootHash, depth: 0)
}
}

Expand Down
64 changes: 60 additions & 4 deletions Blockchain/Tests/BlockchainTests/StateTrieTests.swift
Original file line number Diff line number Diff line change
@@ -1,8 +1,20 @@
@testable import Blockchain
import Foundation
import Testing
import TracingUtils
import Utils

@testable import Blockchain

private let logger = Logger(label: "StateTrieTests")

private func merklize(_ data: some Sequence<(key: Data32, value: Data)>) -> Data32 {
var dict = [Data32: Data]()
for (key, value) in data {
dict[key] = value
}
return try! stateMerklize(kv: dict)
}

struct StateTrieTests {
let backend = InMemoryBackend()

Expand All @@ -29,15 +41,59 @@ struct StateTrieTests {
#expect(retrieved == value)
}

@Test
func testInsertAndRetrieveSimple() async throws {
let trie = StateTrie(rootHash: Data32(), backend: backend)
let remainKey = Data(repeating: 0, count: 31)
let pairs = [
(key: Data32(Data([0b0000_0000]) + remainKey)!, value: Data([0])),
(key: Data32(Data([0b1000_0000]) + remainKey)!, value: Data([1])),
(key: Data32(Data([0b0100_0000]) + remainKey)!, value: Data([2])),
(key: Data32(Data([0b1100_0000]) + remainKey)!, value: Data([3])),
]

for (i, pair) in pairs.enumerated() {
try await trie.update([(key: pair.key, value: pair.value)])

let expectedRoot = merklize(pairs[0 ... i])
let trieRoot = await trie.rootHash
#expect(expectedRoot == trieRoot)
}

for (i, (key, value)) in pairs.enumerated() {
let retrieved = try await trie.read(key: key)
#expect(retrieved == value, "Failed at index \(i)")
}

try await trie.save()

for (i, (key, value)) in pairs.enumerated() {
let retrieved = try await trie.read(key: key)
#expect(retrieved == value, "Failed at index \(i)")
}
}

@Test
func testInsertAndRetrieveMultipleValues() async throws {
let trie = StateTrie(rootHash: Data32(), backend: backend)
let pairs = (0 ..< 5).map { i in
let data = Data(String(i).utf8)
let pairs = (0 ..< 50).map { i in
let data = Data([UInt8(i)])
return (key: data.blake2b256hash(), value: data)
}

try await trie.update(pairs)
for (i, pair) in pairs.enumerated() {
try await trie.update([(key: pair.key, value: pair.value)])

let expectedRoot = merklize(pairs[0 ... i])
let trieRoot = await trie.rootHash
#expect(expectedRoot == trieRoot)
}

for (i, (key, value)) in pairs.enumerated() {
let retrieved = try await trie.read(key: key)
#expect(retrieved == value, "Failed at index \(i)")
}

try await trie.save()

for (i, (key, value)) in pairs.enumerated() {
Expand Down

0 comments on commit 4d75b2a

Please sign in to comment.