diff --git a/.gitignore b/.gitignore index 85198aa..da93f6c 100644 --- a/.gitignore +++ b/.gitignore @@ -12,3 +12,6 @@ docs/ # Dotenv file .env + +# deps +dependencies/ diff --git a/.gitmodules b/.gitmodules index 888d42d..2e7f7c0 100644 --- a/.gitmodules +++ b/.gitmodules @@ -1,3 +1,10 @@ [submodule "lib/forge-std"] path = lib/forge-std url = https://github.com/foundry-rs/forge-std +[submodule "dependencies/prb-math"] + path = dependencies/prb-math + url = https://github.com/PaulRBerg/prb-math +[submodule "lib/prb-math"] + branch = "release-v4" + path = "lib/prb-math" + url = "https://github.com/PaulRBerg/prb-math" diff --git a/README.md b/README.md index 9265b45..d7b4043 100644 --- a/README.md +++ b/README.md @@ -11,6 +11,18 @@ Foundry consists of: ## Documentation +## Core dev commands +```bash +# https://book.getfoundry.sh/projects/soldeer + +# install deps +forge soldeer install forge-std~1.9.2 https://github.com/foundry-rs/forge-std.git +forge remappings > remappings.txt + +# match test +forge test --match-test test_BuyShares -vvv +``` + https://book.getfoundry.sh/ ## Usage diff --git a/dependencies/prb-math b/dependencies/prb-math new file mode 160000 index 0000000..39eec81 --- /dev/null +++ b/dependencies/prb-math @@ -0,0 +1 @@ +Subproject commit 39eec818282a29df7406b8280b29c084c9a3f3b5 diff --git a/foundry.toml b/foundry.toml index 23fcc8b..09714b1 100644 --- a/foundry.toml +++ b/foundry.toml @@ -5,5 +5,6 @@ libs = ["lib", "dependencies"] [dependencies] "@openzeppelin-contracts" = "5.1.0" +forge-std = { version = "1.9.2", git = "https://github.com/foundry-rs/forge-std.git", rev = "0e7097750918380d84dd3cfdef595bee74dabb70" } # See more config options https://github.com/foundry-rs/foundry/blob/master/crates/config/README.md#all-options diff --git a/remappings.txt b/remappings.txt index 218757a..97fdbe9 100644 --- a/remappings.txt +++ b/remappings.txt @@ -1 +1,6 @@ @openzeppelin/=dependencies/@openzeppelin-contracts-5.1.0/ +@forge-std/=dependencies/forge-std-1.9.2/src/ +@openzeppelin-contracts-5.1.0/=dependencies/@openzeppelin-contracts-5.1.0/ +forge-std-1.9.2/=dependencies/forge-std-1.9.2/src/ +forge-std/=lib/forge-std/src/ +@prb-math/=dependencies/prb-math/src/ diff --git a/soldeer.lock b/soldeer.lock index 1b1284c..5f3688c 100644 --- a/soldeer.lock +++ b/soldeer.lock @@ -4,3 +4,9 @@ version = "5.1.0" url = "https://soldeer-revisions.s3.amazonaws.com/@openzeppelin-contracts/5_1_0_19-10-2024_10:28:52_contracts.zip" checksum = "fd3d1ea561cb27897008aee18ada6e85f248eb161c86e4435272fc2b5777574f" integrity = "cb6cf6e878f2943b2291d5636a9d72ac51d43d8135896ceb6cf88d36c386f212" + +[[dependencies]] +name = "forge-std" +version = "1.9.2" +git = "https://github.com/foundry-rs/forge-std.git" +rev = "0e7097750918380d84dd3cfdef595bee74dabb70" diff --git a/src/LMSR.sol b/src/LMSR.sol index 09cb43c..0d13b83 100644 --- a/src/LMSR.sol +++ b/src/LMSR.sol @@ -5,6 +5,8 @@ pragma solidity ^0.8.19; import {IERC20} from "@openzeppelin/token/ERC20/IERC20.sol"; import {ReentrancyGuard} from "@openzeppelin/utils/ReentrancyGuard.sol"; import {Math} from "@openzeppelin/utils/math/Math.sol"; +import { UD60x18, ud, unwrap, ln, exp, div, mul } from "@prb-math/UD60x18.sol"; +import { console } from "@forge-std/console.sol"; interface IUSDC is IERC20 { function decimals() external view returns (uint8); @@ -29,7 +31,7 @@ contract LMSRMarket is ReentrancyGuard { IUSDC public immutable USDC; uint8 public constant USDC_DECIMALS = 6; - uint256 public constant SCALE = 1e18; + uint256 public constant SCALE = 1e6; uint256 public constant MIN_AMOUNT = 1e6; // 1 USDC minimum /*////////////////////////////////////////////////////////////// @@ -94,11 +96,12 @@ contract LMSRMarket is ReentrancyGuard { newQuantities[outcome] += amount; uint256 newCost = calculateLMSRCost(newQuantities); + console.log('newCost: ', newCost); + uint256 currentCost = calculateLMSRCost(quantities); - cost = newCost > currentCost ? newCost - currentCost : 0; + console.log('currenCost: ', currentCost); - // Convert to USDC decimals - cost = cost / (10 ** (18 - USDC_DECIMALS)); + cost = newCost > currentCost ? newCost - currentCost : 0; } function buyShares(uint256 outcome, uint256 amount) external nonReentrant { @@ -107,6 +110,7 @@ contract LMSRMarket is ReentrancyGuard { if (amount < MIN_AMOUNT) revert AmountTooSmall(); uint256 cost = calculateCost(outcome, amount); + console.log(cost); bool success = USDC.transferFrom(msg.sender, address(this), cost); if (!success) revert TransferFailed(); @@ -161,20 +165,32 @@ contract LMSRMarket is ReentrancyGuard { INTERNAL FUNCTIONS //////////////////////////////////////////////////////////////*/ - function calculateLMSRCost(uint256[] memory _quantities) internal view returns (uint256) { - uint256 sum = 0; + function calculateLMSRCost(uint256[] memory _quantities) public view returns (uint256) { + require(_quantities.length == numOutcomes, "Quantities length mismatch"); + + UD60x18 sum = ud(0); + for (uint256 i = 0; i < numOutcomes; i++) { - sum += exp((_quantities[i] * SCALE) / liquidity); + // Convert quantity to UD60x18, ensuring proper scaling + UD60x18 qi = ud(_quantities[i]); + + // Calculate qi / b + UD60x18 qi_div_b = qi.div(ud(liquidity)); + + // Calculate exp(qi / b) + UD60x18 exp_qi_div_b = qi_div_b.exp(); + + // Sum up the exponentials + sum = sum.add(exp_qi_div_b); } - return liquidity * ln(sum); - } - // Replace with proper math library in production - function exp(uint256 x) internal pure returns (uint256) { - return x + SCALE; - } + // Calculate ln(sum) + UD60x18 ln_sum = sum.ln(); + + // Multiply by liquidity (b) + UD60x18 cost = ud(liquidity).mul(ln_sum); - function ln(uint256 x) internal pure returns (uint256) { - return x - SCALE; + // Return the cost as uint256 (unwrap the UD60x18) + return unwrap(cost); } } diff --git a/test/LMSR.t.sol b/test/LMSR.t.sol index 32538e4..312f854 100644 --- a/test/LMSR.t.sol +++ b/test/LMSR.t.sol @@ -1,8 +1,8 @@ -// test/LMSRMarket.t.sol // SPDX-License-Identifier: MIT pragma solidity ^0.8.19; -import "forge-std/Test.sol"; +import { Test } from "@forge-std/Test.sol"; +import { console } from "@forge-std/console.sol"; import "../src/LMSR.sol"; import {MockERC20} from "./ERC20.m.sol"; @@ -12,8 +12,9 @@ contract LMSRMarketTest is Test { address alice = address(0x1); address bob = address(0x2); + address charlie = address(0x3); - uint256 constant INITIAL_BALANCE = 1000000e6; // 1M USDC + uint256 constant INITIAL_BALANCE = 1_000_000e6; uint256 constant LIQUIDITY = 1000e18; uint256 constant NUM_OUTCOMES = 2; @@ -25,15 +26,37 @@ contract LMSRMarketTest is Test { market = new LMSRMarket(address(usdc), LIQUIDITY, NUM_OUTCOMES); // Setup test accounts - vm.startPrank(alice); - usdc.mint(alice, INITIAL_BALANCE); - usdc.approve(address(market), type(uint256).max); - vm.stopPrank(); + address[3] memory users = [alice, bob, charlie]; + for (uint256 i = 0; i < users.length; i++) { + vm.startPrank(users[i]); + usdc.mint(users[i], INITIAL_BALANCE); + usdc.approve(address(market), type(uint256).max); + vm.stopPrank(); + } + } - vm.startPrank(bob); - usdc.mint(bob, INITIAL_BALANCE); - usdc.approve(address(market), type(uint256).max); - vm.stopPrank(); + function exp(uint256 x) internal pure returns (uint256) { + return x + 1e18; // Same as contract's implementation + } + + function getProbabilities() internal view returns (uint256[] memory) { + uint256[] memory probs = new uint256[](market.numOutcomes()); + + // Calculate total exponential sum for denominator + uint256 totalSum = 0; + for (uint256 i = 0; i < market.numOutcomes(); i++) { + uint256 quantity = market.quantities(i); + totalSum += exp((quantity * market.SCALE()) / market.liquidity()); + } + + // Calculate probability for each outcome + for (uint256 i = 0; i < market.numOutcomes(); i++) { + uint256 quantity = market.quantities(i); + uint256 expTerm = exp((quantity * market.SCALE()) / market.liquidity()); + probs[i] = (expTerm * market.SCALE()) / totalSum; + } + + return probs; } function test_InitialState() public view { @@ -41,6 +64,12 @@ contract LMSRMarketTest is Test { assertEq(market.liquidity(), LIQUIDITY); assertEq(address(market.USDC()), address(usdc)); assertEq(market.resolved(), false); + + // Check initial probabilities are equal + uint256[] memory probs = getProbabilities(); + for (uint256 i = 0; i < NUM_OUTCOMES; i++) { + assertApproxEqRel(probs[i], 0.5e18, 0.01e18); // 1% tolerance + } } function test_BuyShares() public { @@ -51,16 +80,40 @@ contract LMSRMarketTest is Test { uint256 cost = market.calculateCost(outcome, amount); uint256 balanceBefore = usdc.balanceOf(alice); + uint256[] memory probsBefore = getProbabilities(); market.buyShares(outcome, amount); + // Check balance and position updates assertEq(usdc.balanceOf(alice), balanceBefore - cost); assertEq(market.getUserPosition(alice, outcome), amount); assertEq(market.quantities(outcome), amount); + + // Check probability changes + uint256[] memory probsAfter = getProbabilities(); + assertTrue(probsAfter[outcome] > probsBefore[outcome], "Probability should increase"); + assertTrue(probsAfter[1 - outcome] < probsBefore[1 - outcome], "Other outcome prob should decrease"); + } + + function test_MultipleBuyers() public { + uint256 amount = 100e6; + + // Alice buys outcome 0 + vm.prank(alice); + market.buyShares(0, amount); + + // Bob buys outcome 1 + vm.prank(bob); + market.buyShares(1, amount); + + assertEq(market.getUserPosition(alice, 0), amount); + assertEq(market.getUserPosition(bob, 1), amount); + + uint256[] memory probs = getProbabilities(); + assertApproxEqRel(probs[0], probs[1], 0.01e18); // Should be roughly equal } function test_RevertWhenBuyingAfterResolution() public { - // Resolve market first market.resolveMarket(0); vm.startPrank(alice); @@ -72,22 +125,30 @@ contract LMSRMarketTest is Test { uint256 amount = 100e6; uint256 outcome = 0; - // Buy shares - vm.startPrank(alice); + // Alice and Bob buy shares + vm.prank(alice); market.buyShares(outcome, amount); - // Resolve market - vm.stopPrank(); + vm.prank(bob); + market.buyShares(1, amount); + + // Resolve market with outcome 0 market.resolveMarket(outcome); - // Claim winnings + // Alice claims winnings (winner) vm.startPrank(alice); - uint256 balanceBefore = usdc.balanceOf(alice); + uint256 aliceBalanceBefore = usdc.balanceOf(alice); market.claimWinnings(amount); uint256 expectedPayout = (amount * market.SCALE()) / (10 ** (18 - market.USDC_DECIMALS())); - assertEq(usdc.balanceOf(alice), balanceBefore + expectedPayout); + assertEq(usdc.balanceOf(alice), aliceBalanceBefore + expectedPayout); assertEq(market.getUserPosition(alice, outcome), 0); + + // Bob tries to claim (loser) + vm.startPrank(bob); + uint256 bobBalanceBefore = usdc.balanceOf(bob); + market.claimWinnings(amount); + assertEq(usdc.balanceOf(bob), bobBalanceBefore, "Losing position should pay nothing"); } function test_RevertWhenClaimingBeforeResolution() public { @@ -97,4 +158,43 @@ contract LMSRMarketTest is Test { vm.expectRevert(MarketNotResolved.selector); market.claimWinnings(100e6); } + + function test_RevertInvalidOutcome() public { + vm.startPrank(alice); + vm.expectRevert(InvalidOutcome.selector); + market.buyShares(NUM_OUTCOMES, 100e6); // Invalid outcome index + } + + function test_RevertInsufficientBalance() public { + vm.startPrank(charlie); + usdc.transfer(alice, INITIAL_BALANCE); // Transfer all balance away + + vm.expectRevert(); + market.buyShares(0, 100e6); + } + + function test_GetProbabilities() public { + vm.startPrank(alice); + usdc.mint(alice, 10_000_000e6); + usdc.approve(address(market), type(uint256).max); + vm.stopPrank(); + + // Buy a large amount of outcome 0 to skew probabilities + vm.prank(alice); + console.log('pre purchase cost: ', market.calculateCost(0, 1e6)); + + market.buyShares(0, 1e6); + + console.log('post purchase cost: ', market.calculateCost(0, 1e6)); + + uint256[] memory probs = getProbabilities(); + console.log("prob: ", probs[0]); + console.log("prob: ", probs[1]); + + assertTrue(probs[0] > 0.75e18, "Probability should be heavily skewed"); + assertTrue(probs[1] < 0.25e18, "Other outcome should have low probability"); + + // Sum of probabilities should equal 1 + assertApproxEqRel(probs[0] + probs[1], 1e18, 0.01e18); + } }