Skip to content

Commit f9ea8f6

Browse files
committed
Implement update method for MerkleTree to synchronize with source tree and manage unsaved changes
1 parent 23a27f7 commit f9ea8f6

File tree

1 file changed

+130
-13
lines changed

1 file changed

+130
-13
lines changed

src/main/java/io/pwrlabs/database/rocksdb/MerkleTree.java

Lines changed: 130 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -64,22 +64,21 @@ public class MerkleTree {
6464
* Cache of loaded nodes (in-memory for quick access).
6565
*/
6666
private final Map<ByteArrayWrapper, Node> nodesCache = new ConcurrentHashMap<>();
67-
67+
private final Map<Integer /*level*/, byte[]> hangingNodes = new ConcurrentHashMap<>();
6868
private final Map<ByteArrayWrapper /*Key*/, byte[] /*data*/> keyDataCache = new ConcurrentHashMap<>();
6969

70-
/**
71-
* Lock for reading/writing to the tree.
72-
*/
73-
private final ReadWriteLock lock = new ReentrantReadWriteLock();
74-
75-
private final Map<Integer /*level*/, byte[]> hangingNodes = new ConcurrentHashMap<>();
7670
@Getter
7771
private int numLeaves = 0;
7872
@Getter
7973
private int depth = 0;
8074
private byte[] rootHash = null;
8175

8276
private AtomicBoolean closed = new AtomicBoolean(false);
77+
private AtomicBoolean hasUnsavedChanges = new AtomicBoolean(false);
78+
/**
79+
* Lock for reading/writing to the tree.
80+
*/
81+
private final ReadWriteLock lock = new ReentrantReadWriteLock();
8382
//endregion
8483

8584
//region ===================== Constructors =====================
@@ -130,6 +129,18 @@ public byte[] getRootHash() {
130129
}
131130
}
132131

132+
public byte[] getRootHashSavedOnDisk() {
133+
errorIfClosed();
134+
lock.readLock().lock();
135+
try {
136+
return db.get(metaDataHandle, KEY_ROOT_HASH.getBytes());
137+
} catch (RocksDBException e) {
138+
throw new RuntimeException(e);
139+
} finally {
140+
lock.readLock().unlock();
141+
}
142+
}
143+
133144
public int getNumLeaves() {
134145
errorIfClosed();
135146
lock.readLock().lock();
@@ -238,6 +249,7 @@ public void addOrUpdateData(byte[] key, byte[] data) throws RocksDBException {
238249

239250
// Store key-data mapping
240251
keyDataCache.put(new ByteArrayWrapper(key), data);
252+
hasUnsavedChanges.set(true);
241253

242254
if (oldLeafHash == null) {
243255
// Key doesn't exist, add new leaf
@@ -247,14 +259,13 @@ public void addOrUpdateData(byte[] key, byte[] data) throws RocksDBException {
247259
// First get the old leaf hash
248260
updateLeaf(oldLeafHash, newLeafHash);
249261
}
250-
251-
flushToDisk();
252262
} finally {
253263
lock.writeLock().unlock();
254264
}
255265
}
256266

257267
public void revertUnsavedChanges() {
268+
if(!hasUnsavedChanges.get()) return;
258269
errorIfClosed();
259270

260271
lock.writeLock().lock();
@@ -264,6 +275,8 @@ public void revertUnsavedChanges() {
264275
keyDataCache.clear();
265276

266277
loadMetaData();
278+
279+
hasUnsavedChanges.set(false);
267280
} catch (RocksDBException e) {
268281
throw new RuntimeException(e);
269282
} finally {
@@ -391,6 +404,7 @@ public void flushToDisk() throws RocksDBException {
391404

392405
nodesCache.clear();
393406
keyDataCache.clear();
407+
hasUnsavedChanges.set(false);
394408
}
395409
} finally {
396410
lock.writeLock().unlock();
@@ -491,6 +505,64 @@ public MerkleTree clone(String newTreeName) throws RocksDBException, IOException
491505
}
492506
}
493507

508+
public void update(MerkleTree sourceTree) throws RocksDBException, IOException {
509+
errorIfClosed();
510+
lock.writeLock().lock();
511+
sourceTree.lock.writeLock().lock();
512+
try {
513+
if (sourceTree == null) {
514+
throw new IllegalArgumentException("Source tree cannot be null");
515+
}
516+
517+
if(Arrays.equals(getRootHashSavedOnDisk(), sourceTree.getRootHashSavedOnDisk())) {
518+
//This means that this tree is already a copy of the source tree and we only need to replace the cache
519+
copyCache(sourceTree);
520+
} else {
521+
if(metaDataHandle != null) {
522+
metaDataHandle.close();
523+
metaDataHandle = null;
524+
}
525+
if(nodesHandle != null) {
526+
nodesHandle.close();
527+
nodesHandle = null;
528+
}
529+
if(keyDataHandle != null) {
530+
keyDataHandle.close();
531+
keyDataHandle = null;
532+
}
533+
if(db != null && !db.isClosed()) {
534+
db.close();
535+
db = null;
536+
};
537+
538+
sourceTree.flushToDisk();
539+
540+
File thisTreesDirectory = new File(path);
541+
FileUtils.deleteDirectory(thisTreesDirectory);
542+
543+
try (Checkpoint checkpoint = Checkpoint.create(sourceTree.db)) {
544+
checkpoint.createCheckpoint(thisTreesDirectory.getAbsolutePath());
545+
} catch (Exception e) {
546+
e.printStackTrace();
547+
throw new RuntimeException(e);
548+
}
549+
550+
// Reinitialize the database
551+
initializeDb();
552+
loadMetaData();
553+
554+
nodesCache.clear();
555+
keyDataCache.clear();
556+
hangingNodes.clear();
557+
hasUnsavedChanges.set(false);
558+
}
559+
560+
} finally {
561+
sourceTree.lock.writeLock().unlock();
562+
lock.writeLock().unlock();
563+
}
564+
}
565+
494566
/**
495567
* Efficiently clears the entire MerkleTree by closing, deleting and recreating the RocksDB instance.
496568
* This is much faster than iterating through all entries and deleting them individually.
@@ -519,6 +591,7 @@ public void clear() throws RocksDBException {
519591
hangingNodes.clear();
520592
rootHash = null;
521593
numLeaves = depth = 0;
594+
hasUnsavedChanges.set(false);
522595

523596
} finally {
524597
lock.writeLock().unlock();
@@ -815,6 +888,29 @@ private void errorIfClosed() {
815888
}
816889
}
817890

891+
private void copyCache(MerkleTree sourceTree) {
892+
nodesCache.clear();
893+
keyDataCache.clear();
894+
hangingNodes.clear();
895+
896+
for (Map.Entry<ByteArrayWrapper, Node> entry : sourceTree.nodesCache.entrySet()) {
897+
nodesCache.put(entry.getKey(), new Node(entry.getValue()));
898+
}
899+
900+
for (Map.Entry<ByteArrayWrapper, byte[]> entry : sourceTree.keyDataCache.entrySet()) {
901+
keyDataCache.put(entry.getKey(), Arrays.copyOf(entry.getValue(), entry.getValue().length));
902+
}
903+
904+
for (Map.Entry<Integer, byte[]> entry : sourceTree.hangingNodes.entrySet()) {
905+
hangingNodes.put(entry.getKey(), Arrays.copyOf(entry.getValue(), entry.getValue().length));
906+
}
907+
908+
rootHash = Arrays.copyOf(sourceTree.rootHash, sourceTree.rootHash.length);
909+
numLeaves = sourceTree.numLeaves;
910+
depth = sourceTree.depth;
911+
hasUnsavedChanges.set(sourceTree.hasUnsavedChanges.get());
912+
}
913+
818914
//endregion
819915

820916
//region ===================== Nested Classes =====================
@@ -882,6 +978,17 @@ public Node(byte[] left, byte[] right) {
882978
nodesCache.put(new ByteArrayWrapper(hash), this);
883979
}
884980

981+
/**
982+
* Copy constructor for Node.
983+
*/
984+
public Node(Node node) {
985+
this.hash = Arrays.copyOf(node.hash, node.hash.length);
986+
this.left = (node.left != null) ? Arrays.copyOf(node.left, node.left.length) : null;
987+
this.right = (node.right != null) ? Arrays.copyOf(node.right, node.right.length) : null;
988+
this.parent = (node.parent != null) ? Arrays.copyOf(node.parent, node.parent.length) : null;
989+
this.nodeHashToRemoveFromDb = (node.nodeHashToRemoveFromDb != null) ? Arrays.copyOf(node.nodeHashToRemoveFromDb, node.nodeHashToRemoveFromDb.length) : null;
990+
}
991+
885992
/**
886993
* Calculate the hash of this node based on the left and right child hashes.
887994
*/
@@ -1142,11 +1249,21 @@ public boolean equals(Object obj) {
11421249
//endregion
11431250

11441251
public static void main(String[] args) throws Exception {
1145-
MerkleTree tree = new MerkleTree("bro/tree1");
1252+
MerkleTree tree = new MerkleTree("b41230566oo/tree1");
11461253
tree.addOrUpdateData("key1".getBytes(), "value1".getBytes());
11471254

1148-
MerkleTree tree2 = tree.clone("bro/tree2");
1149-
System.out.println(Hex.toHexString(tree2.getData("key1".getBytes())));
1150-
System.out.println("ok");
1255+
MerkleTree tree2 = tree.clone("br2615034oo6/tree2");
1256+
1257+
tree.addOrUpdateData("key2".getBytes(), "value2".getBytes());
1258+
//tree.flushToDisk();
1259+
1260+
System.out.println("u");
1261+
tree2.update(tree);
1262+
System.out.println("ud");
1263+
1264+
System.out.println(Hex.toHexString(tree2.getData(("key2").getBytes())));
1265+
System.out.println(Hex.toHexString(tree.getData(("key2").getBytes())));
1266+
1267+
System.out.println(Arrays.equals(tree.getRootHash(), tree2.getRootHash()));
11511268
}
11521269
}

0 commit comments

Comments
 (0)