Skip to content

Commit f01a887

Browse files
authored
Split-up core into Router and RouterPayable (#10)
1 parent e97dd05 commit f01a887

File tree

11 files changed

+267
-106
lines changed

11 files changed

+267
-106
lines changed

src/core/Router.sol

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,14 +8,12 @@ import "../eip/ERC165.sol";
88

99
abstract contract Router is IRouter, ERC165 {
1010

11-
fallback() external payable virtual {
11+
fallback() external virtual {
1212
/// @dev delegate calls the appropriate implementation smart contract for a given function.
1313
address implementation = getImplementationForFunction(msg.sig);
1414
_delegate(implementation);
1515
}
1616

17-
receive() external payable virtual {}
18-
1917
/// @dev See {IERC165-supportsInterface}.
2018
function supportsInterface(bytes4 interfaceId) public view virtual override returns (bool) {
2119
return interfaceId == type(IRouter).interfaceId || super.supportsInterface(interfaceId);

src/core/RouterPayable.sol

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
// SPDX-License-Identifier: MIT
2+
// @author: thirdweb (https://github.com/thirdweb-dev/dynamic-contracts)
3+
4+
pragma solidity ^0.8.0;
5+
6+
import "../interface/IRouterPayable.sol";
7+
import "../eip/ERC165.sol";
8+
9+
abstract contract RouterPayable is IRouterPayable, ERC165 {
10+
11+
fallback() external payable virtual {
12+
/// @dev delegate calls the appropriate implementation smart contract for a given function.
13+
address implementation = getImplementationForFunction(msg.sig);
14+
_delegate(implementation);
15+
}
16+
17+
receive() external payable virtual {}
18+
19+
/// @dev See {IERC165-supportsInterface}.
20+
function supportsInterface(bytes4 interfaceId) public view virtual override returns (bool) {
21+
return interfaceId == type(IRouterPayable).interfaceId || super.supportsInterface(interfaceId);
22+
}
23+
24+
/// @dev delegateCalls an `implementation` smart contract.
25+
function _delegate(address implementation) internal virtual {
26+
assembly {
27+
// Copy msg.data. We take full control of memory in this inline assembly
28+
// block because it will not return to Solidity code. We overwrite the
29+
// Solidity scratch pad at memory position 0.
30+
calldatacopy(0, 0, calldatasize())
31+
32+
// Call the implementation.
33+
// out and outsize are 0 because we don't know the size yet.
34+
let result := delegatecall(gas(), implementation, 0, calldatasize(), 0, 0)
35+
36+
// Copy the returned data.
37+
returndatacopy(0, 0, returndatasize())
38+
39+
switch result
40+
// delegatecall returns 0 on error.
41+
case 0 {
42+
revert(0, returndatasize())
43+
}
44+
default {
45+
return(0, returndatasize())
46+
}
47+
}
48+
}
49+
50+
/// @dev Unimplemented. Returns the implementation contract address for a given function signature.
51+
function getImplementationForFunction(bytes4 _functionSelector) public view virtual returns (address);
52+
}

src/interface/IRouter.sol

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,6 @@
44
pragma solidity ^0.8.0;
55

66
interface IRouter {
7-
fallback() external payable;
8-
receive() external payable;
9-
7+
fallback() external;
108
function getImplementationForFunction(bytes4 _functionSelector) external view returns (address);
119
}

src/interface/IRouterPayable.sol

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
// SPDX-License-Identifier: MIT
2+
// @author: thirdweb (https://github.com/thirdweb-dev/dynamic-contracts)
3+
4+
pragma solidity ^0.8.0;
5+
6+
interface IRouterPayable {
7+
fallback() external payable;
8+
receive() external payable;
9+
10+
function getImplementationForFunction(bytes4 _functionSelector) external view returns (address);
11+
}

src/presets/BaseRouter.sol

Lines changed: 5 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -11,36 +11,11 @@ import "../core/Router.sol";
1111

1212
// Utils
1313
import "./utils/StringSet.sol";
14-
import "./utils/DefaultExtensionSet.sol";
1514
import "./utils/ExtensionState.sol";
1615

1716
abstract contract BaseRouter is IBaseRouter, Router, ExtensionState {
1817
using StringSet for StringSet.Set;
1918

20-
/*///////////////////////////////////////////////////////////////
21-
State variables
22-
//////////////////////////////////////////////////////////////*/
23-
24-
/// @notice The DefaultExtensionSet that stores default extensions of the router.
25-
address public immutable defaultExtensionSet;
26-
27-
/*///////////////////////////////////////////////////////////////
28-
Constructor
29-
//////////////////////////////////////////////////////////////*/
30-
31-
constructor(Extension[] memory _extensions) {
32-
33-
DefaultExtensionSet map = new DefaultExtensionSet();
34-
defaultExtensionSet = address(map);
35-
36-
uint256 len = _extensions.length;
37-
38-
for (uint256 i = 0; i < len; i += 1) {
39-
require(_canSetExtension(_extensions[i]), "BaseRouter: not authorized.");
40-
map.setExtension(_extensions[i]);
41-
}
42-
}
43-
4419
/*///////////////////////////////////////////////////////////////
4520
ERC 165 logic
4621
//////////////////////////////////////////////////////////////*/
@@ -84,45 +59,21 @@ abstract contract BaseRouter is IBaseRouter, Router, ExtensionState {
8459
* given precedence over default extensions in DefaultExtensionSet.
8560
*/
8661
function getAllExtensions() external view returns (Extension[] memory allExtensions) {
87-
Extension[] memory defaultExtensions = IDefaultExtensionSet(defaultExtensionSet).getAllExtensions();
88-
uint256 defaultExtensionsLen = defaultExtensions.length;
89-
90-
ExtensionStateStorage.Data storage data = ExtensionStateStorage.extensionStateStorage();
91-
string[] memory names = data.extensionNames.values();
62+
string[] memory names = _extensionStateStorage().extensionNames.values();
9263
uint256 namesLen = names.length;
9364

94-
uint256 overrides = 0;
95-
for (uint256 i = 0; i < defaultExtensionsLen; i += 1) {
96-
if (data.extensionNames.contains(defaultExtensions[i].metadata.name)) {
97-
overrides += 1;
98-
}
99-
}
100-
101-
uint256 total = (namesLen + defaultExtensionsLen) - overrides;
102-
103-
allExtensions = new Extension[](total);
65+
allExtensions = new Extension[](namesLen);
10466
uint256 idx = 0;
10567

106-
for (uint256 i = 0; i < defaultExtensionsLen; i += 1) {
107-
string memory name = defaultExtensions[i].metadata.name;
108-
if (!data.extensionNames.contains(name)) {
109-
allExtensions[idx] = defaultExtensions[i];
110-
idx += 1;
111-
}
112-
}
113-
11468
for (uint256 i = 0; i < namesLen; i += 1) {
115-
allExtensions[idx] = data.extensions[names[i]];
69+
allExtensions[i] = _extensionStateStorage().extensions[names[i]];
11670
idx += 1;
11771
}
11872
}
11973

12074
/// @dev Returns the extension metadata and functions for a given extension.
12175
function getExtension(string memory _extensionName) public view returns (Extension memory) {
122-
ExtensionStateStorage.Data storage data = ExtensionStateStorage.extensionStateStorage();
123-
bool isLocalExtension = data.extensionNames.contains(_extensionName);
124-
125-
return isLocalExtension ? data.extensions[_extensionName] : IDefaultExtensionSet(defaultExtensionSet).getExtension(_extensionName);
76+
return _extensionStateStorage().extensions[_extensionName];
12677
}
12778

12879
/// @dev Returns the extension's implementation smart contract address.
@@ -137,12 +88,7 @@ abstract contract BaseRouter is IBaseRouter, Router, ExtensionState {
13788

13889
/// @dev Returns the extension metadata for a given function.
13990
function getExtensionForFunction(bytes4 _functionSelector) public view returns (ExtensionMetadata memory) {
140-
ExtensionStateStorage.Data storage data = ExtensionStateStorage.extensionStateStorage();
141-
ExtensionMetadata memory metadata = data.extensionMetadata[_functionSelector];
142-
143-
bool isLocalExtension = metadata.implementation != address(0);
144-
145-
return isLocalExtension ? metadata : IDefaultExtensionSet(defaultExtensionSet).getExtensionForFunction(_functionSelector);
91+
return _extensionStateStorage().extensionMetadata[_functionSelector];
14692
}
14793

14894
/// @dev Returns the extension implementation address stored in router, for the given function.
Lines changed: 161 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,161 @@
1+
// SPDX-License-Identifier: MIT
2+
// @author: thirdweb (https://github.com/thirdweb-dev/dynamic-contracts)
3+
4+
pragma solidity ^0.8.0;
5+
6+
// Interface
7+
import "../interface/IBaseRouter.sol";
8+
9+
// Core
10+
import "../core/Router.sol";
11+
12+
// Utils
13+
import "./utils/StringSet.sol";
14+
import "./utils/DefaultExtensionSet.sol";
15+
import "./utils/ExtensionState.sol";
16+
17+
abstract contract BaseRouterWithDefaults is IBaseRouter, Router, ExtensionState {
18+
using StringSet for StringSet.Set;
19+
20+
/*///////////////////////////////////////////////////////////////
21+
State variables
22+
//////////////////////////////////////////////////////////////*/
23+
24+
/// @notice The DefaultExtensionSet that stores default extensions of the router.
25+
address public immutable defaultExtensionSet;
26+
27+
/*///////////////////////////////////////////////////////////////
28+
Constructor
29+
//////////////////////////////////////////////////////////////*/
30+
31+
constructor(Extension[] memory _extensions) {
32+
33+
DefaultExtensionSet map = new DefaultExtensionSet();
34+
defaultExtensionSet = address(map);
35+
36+
uint256 len = _extensions.length;
37+
38+
for (uint256 i = 0; i < len; i += 1) {
39+
require(_canSetExtension(_extensions[i]), "BaseRouter: not authorized.");
40+
map.setExtension(_extensions[i]);
41+
}
42+
}
43+
44+
/*///////////////////////////////////////////////////////////////
45+
ERC 165 logic
46+
//////////////////////////////////////////////////////////////*/
47+
48+
/// @dev See {IERC165-supportsInterface}.
49+
function supportsInterface(bytes4 interfaceId) public view virtual override returns (bool) {
50+
return interfaceId == type(IBaseRouter).interfaceId || super.supportsInterface(interfaceId);
51+
}
52+
53+
/*///////////////////////////////////////////////////////////////
54+
External functions
55+
//////////////////////////////////////////////////////////////*/
56+
57+
/// @dev Adds a new extension to the router.
58+
function addExtension(Extension memory _extension) external {
59+
require(_canSetExtension(_extension), "BaseRouter: not authorized.");
60+
61+
_addExtension(_extension);
62+
}
63+
64+
/// @dev Updates an existing extension in the router, or overrides a default extension.
65+
function updateExtension(Extension memory _extension) external {
66+
require(_canSetExtension(_extension), "BaseRouter: not authorized.");
67+
68+
_updateExtension(_extension);
69+
}
70+
71+
/// @dev Removes an existing extension from the router.
72+
function removeExtension(Extension memory _extension) external {
73+
require(_canSetExtension(_extension), "BaseRouter: not authorized.");
74+
75+
_removeExtension(_extension.metadata.name);
76+
}
77+
78+
/*///////////////////////////////////////////////////////////////
79+
View functions
80+
//////////////////////////////////////////////////////////////*/
81+
82+
/**
83+
* @notice Returns all extensions stored. Override default lugins stored in router are
84+
* given precedence over default extensions in DefaultExtensionSet.
85+
*/
86+
function getAllExtensions() external view returns (Extension[] memory allExtensions) {
87+
Extension[] memory defaultExtensions = IDefaultExtensionSet(defaultExtensionSet).getAllExtensions();
88+
uint256 defaultExtensionsLen = defaultExtensions.length;
89+
90+
string[] memory names = _extensionStateStorage().extensionNames.values();
91+
uint256 namesLen = names.length;
92+
93+
uint256 overrides = 0;
94+
for (uint256 i = 0; i < defaultExtensionsLen; i += 1) {
95+
if (_extensionStateStorage().extensionNames.contains(defaultExtensions[i].metadata.name)) {
96+
overrides += 1;
97+
}
98+
}
99+
100+
uint256 total = (namesLen + defaultExtensionsLen) - overrides;
101+
102+
allExtensions = new Extension[](total);
103+
uint256 idx = 0;
104+
105+
for (uint256 i = 0; i < defaultExtensionsLen; i += 1) {
106+
string memory name = defaultExtensions[i].metadata.name;
107+
if (!_extensionStateStorage().extensionNames.contains(name)) {
108+
allExtensions[idx] = defaultExtensions[i];
109+
idx += 1;
110+
}
111+
}
112+
113+
for (uint256 i = 0; i < namesLen; i += 1) {
114+
allExtensions[idx] = _extensionStateStorage().extensions[names[i]];
115+
idx += 1;
116+
}
117+
}
118+
119+
/// @dev Returns the extension metadata and functions for a given extension.
120+
function getExtension(string memory _extensionName) public view returns (Extension memory) {
121+
bool isLocalExtension = _extensionStateStorage().extensionNames.contains(_extensionName);
122+
123+
return isLocalExtension ? _extensionStateStorage().extensions[_extensionName] : IDefaultExtensionSet(defaultExtensionSet).getExtension(_extensionName);
124+
}
125+
126+
/// @dev Returns the extension's implementation smart contract address.
127+
function getExtensionImplementation(string memory _extensionName) external view returns (address) {
128+
return getExtension(_extensionName).metadata.implementation;
129+
}
130+
131+
/// @dev Returns all functions that belong to the given extension contract.
132+
function getAllFunctionsOfExtension(string memory _extensionName) external view returns (ExtensionFunction[] memory) {
133+
return getExtension(_extensionName).functions;
134+
}
135+
136+
/// @dev Returns the extension metadata for a given function.
137+
function getExtensionForFunction(bytes4 _functionSelector) public view returns (ExtensionMetadata memory) {
138+
ExtensionMetadata memory metadata = _extensionStateStorage().extensionMetadata[_functionSelector];
139+
140+
bool isLocalExtension = metadata.implementation != address(0);
141+
142+
return isLocalExtension ? metadata : IDefaultExtensionSet(defaultExtensionSet).getExtensionForFunction(_functionSelector);
143+
}
144+
145+
/// @dev Returns the extension implementation address stored in router, for the given function.
146+
function getImplementationForFunction(bytes4 _functionSelector)
147+
public
148+
view
149+
override
150+
returns (address extensionAddress)
151+
{
152+
return getExtensionForFunction(_functionSelector).implementation;
153+
}
154+
155+
/*///////////////////////////////////////////////////////////////
156+
Internal functions
157+
//////////////////////////////////////////////////////////////*/
158+
159+
/// @dev Returns whether a extension can be set in the given execution context.
160+
function _canSetExtension(Extension memory _extension) internal view virtual returns (bool);
161+
}

src/presets/example/RouterImmutable.sol

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,15 +3,15 @@
33

44
pragma solidity ^0.8.0;
55

6-
import "../BaseRouter.sol";
6+
import "../BaseRouterWithDefaults.sol";
77

88
/**
99
* This smart contract is an EXAMPLE, and is not meant for use in production.
1010
*/
1111

12-
contract RouterImmutable is BaseRouter {
12+
contract RouterImmutable is BaseRouterWithDefaults {
1313

14-
constructor(Extension[] memory _extensions) BaseRouter(_extensions) {}
14+
constructor(Extension[] memory _extensions) BaseRouterWithDefaults(_extensions) {}
1515

1616
/*///////////////////////////////////////////////////////////////
1717
Overrides

0 commit comments

Comments
 (0)