Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add function to update a leaf in a MerkleTree structure #5453

Draft
wants to merge 9 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions .changeset/good-zebras-ring.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
'openzeppelin-solidity': minor
---

`MerkleTree`: Add an update function that replaces a previously inserted leaf with a new value, updating the tree root along the way.
8 changes: 8 additions & 0 deletions contracts/mocks/MerkleTreeMock.sol
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ contract MerkleTreeMock {
bytes32 public root;

event LeafInserted(bytes32 leaf, uint256 index, bytes32 root);
event LeafUpdated(bytes32 oldLeaf, bytes32 newLeaf, uint256 index, bytes32 root);

function setup(uint8 _depth, bytes32 _zero) public {
root = _tree.setup(_depth, _zero);
Expand All @@ -25,6 +26,13 @@ contract MerkleTreeMock {
root = currentRoot;
}

function update(uint256 index, bytes32 oldValue, bytes32 newValue, bytes32[] memory proof) public {
(bytes32 oldRoot, bytes32 newRoot) = _tree.update(index, oldValue, newValue, proof);
if (oldRoot != root) revert MerkleTree.MerkleTreeUpdateInvalidProof();
emit LeafUpdated(oldValue, newValue, index, newRoot);
root = newRoot;
}

function depth() public view returns (uint256) {
return _tree.depth();
}
Expand Down
92 changes: 92 additions & 0 deletions contracts/utils/structs/MerkleTree.sol
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ pragma solidity ^0.8.20;
import {Hashes} from "../cryptography/Hashes.sol";
import {Arrays} from "../Arrays.sol";
import {Panic} from "../Panic.sol";
import {StorageSlot} from "../StorageSlot.sol";

/**
* @dev Library for managing https://wikipedia.org/wiki/Merkle_Tree[Merkle Tree] data structures.
Expand All @@ -27,6 +28,12 @@ import {Panic} from "../Panic.sol";
* _Available since v5.1._
*/
library MerkleTree {
/// @dev Error emitted when trying to update a leaf that was not previously pushed.
error MerkleTreeUpdateInvalidIndex(uint256 index, uint256 length);

/// @dev Error emitted when the proof used during an update is invalid (could not reproduce the side).
error MerkleTreeUpdateInvalidProof();

/**
* @dev A complete `bytes32` Merkle tree.
*
Expand Down Expand Up @@ -166,6 +173,91 @@ library MerkleTree {
return (index, currentLevelHash);
}

/**
* @dev Change value of the leaf at position `index` from `oldValue` to `newValue`. Returns the recomputed "old"
* root (before the update) and "new" root (after the update). The caller must verify that the reconstructed old
* root is the last known one.
*
* The `proof` must be an up-to-date inclusion proof for the leaf being update. This means that this function is
* vulnerable to front-running. Any {push} or {update} operation (that changes the root of the tree) would render
* all "in flight" updates invalid.
*
* This variant uses {Hashes-commutativeKeccak256} to hash internal nodes. It should only be used on merkle trees
* that were setup using the same (default) hashing function (i.e. by calling
* {xref-MerkleTree-setup-struct-MerkleTree-Bytes32PushTree-uint8-bytes32-}[the default setup] function).
*/
function update(
Bytes32PushTree storage self,
uint256 index,
bytes32 oldValue,
bytes32 newValue,
bytes32[] memory proof
) internal returns (bytes32 oldRoot, bytes32 newRoot) {
return update(self, index, oldValue, newValue, proof, Hashes.commutativeKeccak256);
}

/**
* @dev Change value of the leaf at position `index` from `oldValue` to `newValue`. Returns the recomputed "old"
* root (before the update) and "new" root (after the update). The caller must verify that the reconstructed old
* root is the last known one.
*
* The `proof` must be an up-to-date inclusion proof for the leaf being update. This means that this function is
* vulnerable to front-running. Any {push} or {update} operation (that changes the root of the tree) would render
* all "in flight" updates invalid.
*
* This variant uses a custom hashing function to hash internal nodes. It should only be called with the same
* function as the one used during the initial setup of the merkle tree.
*/
function update(
Bytes32PushTree storage self,
uint256 index,
bytes32 oldValue,
bytes32 newValue,
bytes32[] memory proof,
function(bytes32, bytes32) view returns (bytes32) fnHash
) internal returns (bytes32 oldRoot, bytes32 newRoot) {
unchecked {
// Check index range
uint256 length = self._nextLeafIndex;
if (index >= length) revert MerkleTreeUpdateInvalidIndex(index, length);

// Cache read
uint256 treeDepth = depth(self);

// Workaround stack too deep
bytes32[] storage sides = self._sides;

// This cannot overflow because: 0 <= index < length
uint256 lastIndex = length - 1;
uint256 currentIndex = index;
bytes32 currentLevelHashOld = oldValue;
bytes32 currentLevelHashNew = newValue;
for (uint32 i = 0; i < treeDepth; i++) {
bool isLeft = currentIndex % 2 == 0;

lastIndex >>= 1;
currentIndex >>= 1;

if (isLeft && currentIndex == lastIndex) {
StorageSlot.Bytes32Slot storage side = Arrays.unsafeAccess(sides, i);
if (side.value != currentLevelHashOld) revert MerkleTreeUpdateInvalidProof();
side.value = currentLevelHashNew;
}

bytes32 sibling = proof[i];
currentLevelHashOld = fnHash(
isLeft ? currentLevelHashOld : sibling,
isLeft ? sibling : currentLevelHashOld
);
currentLevelHashNew = fnHash(
isLeft ? currentLevelHashNew : sibling,
isLeft ? sibling : currentLevelHashNew
);
}
return (currentLevelHashOld, currentLevelHashNew);
}
}

/**
* @dev Tree's depth (set at initialization)
*/
Expand Down
136 changes: 108 additions & 28 deletions test/utils/structs/MerkleTree.test.js
Original file line number Diff line number Diff line change
Expand Up @@ -5,18 +5,23 @@ const { PANIC_CODES } = require('@nomicfoundation/hardhat-chai-matchers/panic');
const { StandardMerkleTree } = require('@openzeppelin/merkle-tree');

const { generators } = require('../../helpers/random');
const { range } = require('../../helpers/iterate');

const makeTree = (leaves = [ethers.ZeroHash]) =>
const DEPTH = 4; // 16 slots

const makeTree = (leaves = [], length = 2 ** DEPTH, zero = ethers.ZeroHash) =>
StandardMerkleTree.of(
leaves.map(leaf => [leaf]),
[]
.concat(
leaves,
Array.from({ length: length - leaves.length }, () => zero),
)
.map(leaf => [leaf]),
['bytes32'],
{ sortLeaves: false },
);

const hashLeaf = leaf => makeTree().leafHash([leaf]);

const DEPTH = 4n; // 16 slots
const ZERO = hashLeaf(ethers.ZeroHash);
const ZERO = makeTree().leafHash([ethers.ZeroHash]);

async function fixture() {
const mock = await ethers.deployContract('MerkleTreeMock');
Expand All @@ -30,69 +35,144 @@ describe('MerkleTree', function () {
});

it('sets initial values at setup', async function () {
const merkleTree = makeTree(Array.from({ length: 2 ** Number(DEPTH) }, () => ethers.ZeroHash));
const merkleTree = makeTree();

expect(await this.mock.root()).to.equal(merkleTree.root);
expect(await this.mock.depth()).to.equal(DEPTH);
expect(await this.mock.nextLeafIndex()).to.equal(0n);
await expect(this.mock.root()).to.eventually.equal(merkleTree.root);
await expect(this.mock.depth()).to.eventually.equal(DEPTH);
await expect(this.mock.nextLeafIndex()).to.eventually.equal(0n);
});

describe('push', function () {
it('tree is correctly updated', async function () {
const leaves = Array.from({ length: 2 ** Number(DEPTH) }, () => ethers.ZeroHash);
it('pushing correctly updates the tree', async function () {
const leaves = [];

// for each leaf slot
for (const i in leaves) {
// generate random leaf and hash it
const hashedLeaf = hashLeaf((leaves[i] = generators.bytes32()));
for (const i in range(2 ** DEPTH)) {
// generate random leaf
leaves.push(generators.bytes32());

// update leaf list and rebuild tree.
// rebuild tree.
const tree = makeTree(leaves);
const hash = tree.leafHash(tree.at(i));

// push value to tree
await expect(this.mock.push(hashedLeaf)).to.emit(this.mock, 'LeafInserted').withArgs(hashedLeaf, i, tree.root);
await expect(this.mock.push(hash)).to.emit(this.mock, 'LeafInserted').withArgs(hash, i, tree.root);

// check tree
expect(await this.mock.root()).to.equal(tree.root);
expect(await this.mock.nextLeafIndex()).to.equal(BigInt(i) + 1n);
await expect(this.mock.root()).to.eventually.equal(tree.root);
await expect(this.mock.nextLeafIndex()).to.eventually.equal(BigInt(i) + 1n);
}
});

it('revert when tree is full', async function () {
it('pushing to a full tree reverts', async function () {
await Promise.all(Array.from({ length: 2 ** Number(DEPTH) }).map(() => this.mock.push(ethers.ZeroHash)));

await expect(this.mock.push(ethers.ZeroHash)).to.be.revertedWithPanic(PANIC_CODES.TOO_MUCH_MEMORY_ALLOCATED);
});
});

describe('update', function () {
for (const { leafCount, leafIndex } of range(2 ** DEPTH + 1).flatMap(leafCount =>
range(leafCount).map(leafIndex => ({ leafCount, leafIndex })),
))
it(`updating a leaf correctly updates the tree (leaf #${leafIndex + 1}/${leafCount})`, async function () {
// initial tree
const leaves = Array.from({ length: leafCount }, generators.bytes32);
const oldTree = makeTree(leaves);

// fill tree and verify root
for (const i in leaves) {
await this.mock.push(oldTree.leafHash(oldTree.at(i)));
}
await expect(this.mock.root()).to.eventually.equal(oldTree.root);

// create updated tree
leaves[leafIndex] = generators.bytes32();
const newTree = makeTree(leaves);

const oldLeafHash = oldTree.leafHash(oldTree.at(leafIndex));
const newLeafHash = newTree.leafHash(newTree.at(leafIndex));

// perform update
await expect(this.mock.update(leafIndex, oldLeafHash, newLeafHash, oldTree.getProof(leafIndex)))
.to.emit(this.mock, 'LeafUpdated')
.withArgs(oldLeafHash, newLeafHash, leafIndex, newTree.root);

// verify updated root
await expect(this.mock.root()).to.eventually.equal(newTree.root);

// if there is still room in the tree, fill it
for (const i of range(leafCount, 2 ** DEPTH)) {
// push new value and rebuild tree
leaves.push(generators.bytes32());
const nextTree = makeTree(leaves);

// push and verify root
await this.mock.push(nextTree.leafHash(nextTree.at(i)));
await expect(this.mock.root()).to.eventually.equal(nextTree.root);
}
});

it('replacing a leaf that was not previously pushed reverts', async function () {
// changing leaf 0 on an empty tree
await expect(this.mock.update(1, ZERO, ZERO, []))
.to.be.revertedWithCustomError(this.mock, 'MerkleTreeUpdateInvalidIndex')
.withArgs(1, 0);
});

it('replacing a leaf using an invalid proof reverts', async function () {
const leafCount = 4;
const leafIndex = 2;

const leaves = Array.from({ length: leafCount }, generators.bytes32);
const tree = makeTree(leaves);

// fill tree and verify root
for (const i in leaves) {
await this.mock.push(tree.leafHash(tree.at(i)));
}
await expect(this.mock.root()).to.eventually.equal(tree.root);

const oldLeafHash = tree.leafHash(tree.at(leafIndex));
const newLeafHash = generators.bytes32();
const proof = tree.getProof(leafIndex);
// invalid proof (tamper)
proof[1] = generators.bytes32();

await expect(this.mock.update(leafIndex, oldLeafHash, newLeafHash, proof)).to.be.revertedWithCustomError(
this.mock,
'MerkleTreeUpdateInvalidProof',
);
});
});

it('reset', async function () {
// empty tree
const zeroLeaves = Array.from({ length: 2 ** Number(DEPTH) }, () => ethers.ZeroHash);
const zeroTree = makeTree(zeroLeaves);
const emptyTree = makeTree();

// tree with one element
const leaves = Array.from({ length: 2 ** Number(DEPTH) }, () => ethers.ZeroHash);
const hashedLeaf = hashLeaf((leaves[0] = generators.bytes32())); // fill first leaf and hash it
const leaves = [generators.bytes32()];
const tree = makeTree(leaves);
const hash = tree.leafHash(tree.at(0));

// root should be that of a zero tree
expect(await this.mock.root()).to.equal(zeroTree.root);
expect(await this.mock.root()).to.equal(emptyTree.root);
expect(await this.mock.nextLeafIndex()).to.equal(0n);

// push leaf and check root
await expect(this.mock.push(hashedLeaf)).to.emit(this.mock, 'LeafInserted').withArgs(hashedLeaf, 0, tree.root);
await expect(this.mock.push(hash)).to.emit(this.mock, 'LeafInserted').withArgs(hash, 0, tree.root);

expect(await this.mock.root()).to.equal(tree.root);
expect(await this.mock.nextLeafIndex()).to.equal(1n);

// reset tree
await this.mock.setup(DEPTH, ZERO);

expect(await this.mock.root()).to.equal(zeroTree.root);
expect(await this.mock.root()).to.equal(emptyTree.root);
expect(await this.mock.nextLeafIndex()).to.equal(0n);

// re-push leaf and check root
await expect(this.mock.push(hashedLeaf)).to.emit(this.mock, 'LeafInserted').withArgs(hashedLeaf, 0, tree.root);
await expect(this.mock.push(hash)).to.emit(this.mock, 'LeafInserted').withArgs(hash, 0, tree.root);

expect(await this.mock.root()).to.equal(tree.root);
expect(await this.mock.nextLeafIndex()).to.equal(1n);
Expand Down
Loading