Skip to content
Open
Show file tree
Hide file tree
Changes from 13 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion packages/common/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@
},
"dependencies": {
"@latticexyz/schema-type": "workspace:*",
"@solidity-parser/parser": "^0.16.0",
"@nomicfoundation/slang": "^1.2.1",
"abitype": "1.0.9",
"debug": "^4.3.4",
"execa": "^9.5.2",
Expand Down
48 changes: 48 additions & 0 deletions packages/common/src/codegen/utils/contractToInterface.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
import { describe, expect, it } from "vitest";
import { contractToInterface } from "./contractToInterface";

const source = `
import { Data } from "./lib.sol";

contract World {
constructor() {}
fallback(bytes calldata input) external payable returns (bytes memory) {}
receive() external payable {}
function update(Data.Entity[] memory entities, uint delta) public returns (bool) {}
function visible() view external {}
function invisible() internal {}
error UpdateError(uint entityId);
}
`;

describe("contractToInterface", () => {
it("extracts public functions and errors from contract", () => {
const { functions, errors, symbolImports } = contractToInterface(source, "World");
expect(functions).toStrictEqual([
{
name: "update",
parameters: ["Data.Entity[] memory entities", "uint delta"],
returnParameters: ["bool"],
stateMutability: "",
},
{
name: "visible",
parameters: [],
returnParameters: [],
stateMutability: "view",
},
]);
expect(errors).toStrictEqual([
{
name: "UpdateError",
parameters: ["uint entityId"],
},
]);
expect(symbolImports).toStrictEqual([
{
path: "./lib.sol",
symbol: "Data",
},
]);
});
});
231 changes: 109 additions & 122 deletions packages/common/src/codegen/utils/contractToInterface.ts
Original file line number Diff line number Diff line change
@@ -1,8 +1,18 @@
import { parse, visit } from "@solidity-parser/parser";
import type { SourceUnit, TypeName, VariableDeclaration } from "@solidity-parser/parser/dist/src/ast-types";
import { MUDError } from "../../errors";
import { findContractNode } from "./findContractNode";
import { findContractOrInterfaceNode } from "./findContractOrInterfaceNode";
import { SymbolImport, findSymbolImport } from "./findSymbolImport";
import { Parser } from "@nomicfoundation/slang/parser";
import { assertNonterminalNode, Cursor, Query, TerminalNode } from "@nomicfoundation/slang/cst";
import {
ArrayTypeName,
ErrorDefinition,
ErrorParametersDeclaration,
FunctionDefinition,
IdentifierPath,
MemberAccessExpression,
ParametersDeclaration,
TypeName,
} from "@nomicfoundation/slang/ast";

export interface ContractInterfaceFunction {
name: string;
Expand Down Expand Up @@ -39,13 +49,17 @@ export function contractToInterface(
symbolImports: SymbolImport[];
qualifiedSymbols: Map<string, QualifiedSymbol>;
} {
let ast: SourceUnit;
try {
ast = parse(source);
} catch (error) {
throw new MUDError(`Failed to parse contract ${contractName}: ${error}`);
const parser = Parser.create("0.8.24");
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there a way to avoid hardcoding the version here?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I believe we already parse the foundry config and pass that through in a few places in MUD tools, so we could pass that through here to avoid hardcoding it

Copy link
Member

@frolic frolic Sep 4, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

actually, why wouldn't the parser use the pragma to determine which version to use? 🤔

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wasn't sure what the version policy wrt Solidity was for MUD. Is it bound to the MUD release, or is completely user defined? Otherwise, Slang can infer the version from the source code, or we can extract it as a constant, or pass it through as a parameter. Happy to implement whatever solution works best for you.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I went with the version inference path. It should be the safest and most durable.

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@alvrs @frolic I noticed that the currently used solc version in this repository is updated manually every now and then. Example: aabd307

In that commit, while projects/existing code specifically use 0.8.24, the Solidity files contain a more permissive pragma pragma solidity >=0.8.24;. When using the Slang parser, the version can be hard-coded/updated manually along the rest of the codebase like the commit above, or inferred automatically (as suggested by @ggiraldez's comment here). But with inference, a higher version could be inferred, since it would match >=0.8.24. Up to you on which solution you prefer.

Please let us know if you have any other feedback.

const parserResult = parser.parseFileContents(source);
if (!parserResult.isValid()) {
const errorMessage = parserResult
.errors()
.map((error) => error.message)
.join("\n");
throw new MUDError(`Failed to parse contract ${contractName}: ${errorMessage}`);
}
const contractNode = findContractNode(ast, contractName);
const root = parserResult.createTreeCursor();
const contractNode = findContractOrInterfaceNode(root, contractName);
let symbolImports: SymbolImport[] = [];
const functions: ContractInterfaceFunction[] = [];
const errors: ContractInterfaceError[] = [];
Expand All @@ -55,55 +69,76 @@ export function contractToInterface(
throw new MUDError(`Contract not found: ${contractName}`);
}

visit(contractNode, {
FunctionDefinition({
name,
visibility,
parameters,
stateMutability,
returnParameters,
isConstructor,
isFallback,
isReceiveEther,
}) {
try {
// skip constructor and fallbacks
if (isConstructor || isFallback || isReceiveEther) return;
// forbid default visibility (this check might be unnecessary, modern solidity already disallows this)
if (visibility === "default") throw new MUDError(`Visibility is not specified`);

if (visibility === "external" || visibility === "public") {
functions.push({
name: name === null ? "" : name,
parameters: parameters.map(parseParameter),
stateMutability: stateMutability || "",
returnParameters: returnParameters === null ? [] : returnParameters.map(parseParameter),
});

for (const { typeName } of parameters.concat(returnParameters ?? [])) {
const symbols = typeNameToSymbols(typeName);
symbolImports = symbolImports.concat(symbolsToImports(ast, symbols, findInheritedSymbol, qualifiedSymbols));
}
for (const match of contractNode.query([
Query.create(`
@function [FunctionDefinition [FunctionName [Identifier]]]
`),
Query.create(`
@error [ErrorDefinition]
`),
])) {
if (match.captures["function"]) {
// Functions
const funcNode = match.captures["function"]?.[0].node;
assertNonterminalNode(funcNode);
const funcDef = new FunctionDefinition(funcNode);
const name = funcDef.name.cst.unparse().trim();

let visibility = undefined;
let stateMutability = "";
for (const item of funcDef.attributes.items) {
const attribute = item.cst.unparse().trim();
switch (attribute) {
case "public":
case "private":
case "internal":
case "external":
visibility = attribute;
break;
case "view":
case "pure":
case "payable":
stateMutability = attribute;
break;
}
} catch (error: unknown) {
if (error instanceof MUDError) {
error.message = `Function "${name}" in contract "${contractName}": ${error.message}`;
}
throw error;
}
},
CustomErrorDefinition({ name, parameters }) {
if (visibility === undefined) throw new MUDError(`Visibility is not specified for function '${name}'`);
if (visibility !== "public" && visibility !== "external") {
continue;
}

functions.push({
name,
parameters: splatParameters(funcDef.parameters),
stateMutability,
returnParameters: splatParameters(funcDef.returns?.variables),
});

for (const { typeName } of funcDef.parameters.parameters.items.concat(
funcDef.returns?.variables.parameters.items ?? [],
)) {
const symbols = typeNameToSymbols(typeName);
symbolImports = symbolImports.concat(symbolsToImports(root, symbols, findInheritedSymbol, qualifiedSymbols));
}
} else if (match.captures["error"]) {
// Custom errors
const errorNode = match.captures.error?.[0].node;
assertNonterminalNode(errorNode);
const errorDef = new ErrorDefinition(errorNode);
const name = errorDef.name.unparse().trim();
errors.push({
name,
parameters: parameters.map(parseParameter),
parameters: splatParameters(errorDef.members),
});

for (const parameter of parameters) {
const symbols = typeNameToSymbols(parameter.typeName);
symbolImports = symbolImports.concat(symbolsToImports(ast, symbols, findInheritedSymbol, qualifiedSymbols));
for (const { typeName } of errorDef.members.parameters.items) {
const symbols = typeNameToSymbols(typeName);
symbolImports = symbolImports.concat(symbolsToImports(root, symbols, findInheritedSymbol, qualifiedSymbols));
}
},
});
}
}

symbolImports = deduplicateSymbolImports(symbolImports);

return {
functions,
Expand All @@ -113,85 +148,33 @@ export function contractToInterface(
};
}

function parseParameter({ name, typeName, storageLocation }: VariableDeclaration): string {
let typedNameWithLocation = "";

const { name: flattenedTypeName, stateMutability } = flattenTypeName(typeName);
// type name (e.g. uint256)
typedNameWithLocation += flattenedTypeName;
// optional mutability (e.g. address payable)
if (stateMutability !== null) {
typedNameWithLocation += ` ${stateMutability}`;
}
// location, when relevant (e.g. string memory)
if (storageLocation !== null) {
typedNameWithLocation += ` ${storageLocation}`;
}
// optional variable name
if (name !== null) {
typedNameWithLocation += ` ${name}`;
}

return typedNameWithLocation;
}

function flattenTypeName(typeName: TypeName | null): { name: string; stateMutability: string | null } {
if (typeName === null) {
return {
name: "",
stateMutability: null,
};
}
if (typeName.type === "ElementaryTypeName") {
return {
name: typeName.name,
stateMutability: typeName.stateMutability,
};
} else if (typeName.type === "UserDefinedTypeName") {
return {
name: typeName.namePath,
stateMutability: null,
};
} else if (typeName.type === "ArrayTypeName") {
let length = "";
if (typeName.length?.type === "NumberLiteral") {
length = typeName.length.number;
} else if (typeName.length?.type === "Identifier") {
length = typeName.length.name;
}

const { name, stateMutability } = flattenTypeName(typeName.baseTypeName);
return {
name: `${name}[${length}]`,
stateMutability,
};
} else {
// TODO function types are unsupported but could be useful
throw new MUDError(`Invalid typeName.type ${typeName.type}`);
}
function splatParameters(parameters: ParametersDeclaration | ErrorParametersDeclaration | undefined): string[] {
return parameters?.parameters.items.map((parameter) => parameter.cst.unparse().trim()) ?? [];
}

// Get symbols that need to be imported for given typeName
function typeNameToSymbols(typeName: TypeName | null): string[] {
if (typeName?.type === "UserDefinedTypeName") {
// split is needed to get a library, if types are internal to it
const symbol = typeName.namePath.split(".")[0];
return [symbol];
} else if (typeName?.type === "ArrayTypeName") {
const symbols = typeNameToSymbols(typeName.baseTypeName);
// array types can also use symbols (constants) for length
if (typeName.length?.type === "Identifier") {
const innerTypeName = typeName.length.name;
symbols.push(innerTypeName.split(".")[0]);
function typeNameToSymbols(typeName: TypeName): string[] {
const typeVariant = typeName.variant;
if (typeVariant instanceof IdentifierPath) {
return [typeVariant.items[0].unparse()];
} else if (typeVariant instanceof ArrayTypeName) {
const symbols = typeNameToSymbols(typeVariant.operand);
const indexVariant = typeVariant.index?.variant;
if (indexVariant instanceof TerminalNode) {
symbols.push(indexVariant.unparse());
} else if (indexVariant instanceof MemberAccessExpression) {
const memberOperandVariant = indexVariant.operand.variant;
if (memberOperandVariant instanceof TerminalNode) {
symbols.push(memberOperandVariant.unparse());
}
}
return symbols;
} else {
return [];
}
return [];
}

function symbolsToImports(
ast: SourceUnit,
root: Cursor,
symbols: string[],
findInheritedSymbol?: (symbol: string) => QualifiedSymbol | undefined,
qualifiedSymbols?: Map<string, QualifiedSymbol>,
Expand All @@ -200,7 +183,7 @@ function symbolsToImports(

for (const symbol of symbols) {
// First check explicit imports
const explicitImport = findSymbolImport(ast, symbol);
const explicitImport = findSymbolImport(root, symbol);
if (explicitImport) {
imports.push(explicitImport);
continue;
Expand All @@ -225,6 +208,10 @@ function symbolsToImports(
}
}

return imports;
}

function deduplicateSymbolImports(imports: SymbolImport[]): SymbolImport[] {
// Deduplicate imports
const uniqueImports = new Map<string, SymbolImport>();
for (const imp of imports) {
Expand Down
16 changes: 0 additions & 16 deletions packages/common/src/codegen/utils/findContractNode.ts

This file was deleted.

Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
import { Cursor, Query } from "@nomicfoundation/slang/cst";

export function findContractOrInterfaceNode(root: Cursor, contractOrInterfaceName: string): Cursor | undefined {
for (const result of root.query([
Query.create(`
[ContractDefinition
name: ["${contractOrInterfaceName}"]
]
`),
Query.create(`
[InterfaceDefinition
name: ["${contractOrInterfaceName}"]
]
`),
])) {
return result.root;
}
}
Loading
Loading