Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion packages/common/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@
},
"dependencies": {
"@latticexyz/schema-type": "workspace:*",
"@solidity-parser/parser": "^0.16.0",
"@nomicfoundation/slang": "^1.2.1",
"abitype": "1.0.9",
"debug": "^4.3.4",
"execa": "^9.5.2",
Expand Down
48 changes: 48 additions & 0 deletions packages/common/src/codegen/utils/contractToInterface.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
import { describe, expect, it } from "vitest";
import { contractToInterface } from "./contractToInterface";

const source = `
import { Data } from "./lib.sol";

contract World {
constructor() {}
fallback(bytes calldata input) external payable returns (bytes memory) {}
receive() external payable {}
function update(Data.Entity[] memory entities, uint delta) public returns (bool) {}
function visible() view external {}
function invisible() internal {}
error UpdateError(uint entityId);
}
`;

describe("contractToInterface", () => {
it("extracts public functions and errors from contract", () => {
const { functions, errors, symbolImports } = contractToInterface(source, "World");
expect(functions).toStrictEqual([
{
name: "update",
parameters: ["Data.Entity[] memory entities", "uint delta"],
returnParameters: ["bool"],
stateMutability: "",
},
{
name: "visible",
parameters: [],
returnParameters: [],
stateMutability: "view",
},
]);
expect(errors).toStrictEqual([
{
name: "UpdateError",
parameters: ["uint entityId"],
},
]);
expect(symbolImports).toStrictEqual([
{
path: "./lib.sol",
symbol: "Data",
},
]);
});
});
233 changes: 111 additions & 122 deletions packages/common/src/codegen/utils/contractToInterface.ts
Original file line number Diff line number Diff line change
@@ -1,8 +1,19 @@
import { parse, visit } from "@solidity-parser/parser";
import type { SourceUnit, TypeName, VariableDeclaration } from "@solidity-parser/parser/dist/src/ast-types";
import { MUDError } from "../../errors";
import { findContractNode } from "./findContractNode";
import { findContractOrInterfaceNode } from "./findContractOrInterfaceNode";
import { SymbolImport, findSymbolImport } from "./findSymbolImport";
import { Parser } from "@nomicfoundation/slang/parser";
import { assertNonterminalNode, Cursor, Query, TerminalNode } from "@nomicfoundation/slang/cst";
import {
ArrayTypeName,
ErrorDefinition,
ErrorParametersDeclaration,
FunctionDefinition,
IdentifierPath,
MemberAccessExpression,
ParametersDeclaration,
TypeName,
} from "@nomicfoundation/slang/ast";
import { LanguageFacts } from "@nomicfoundation/slang/utils";

export interface ContractInterfaceFunction {
name: string;
Expand Down Expand Up @@ -39,13 +50,18 @@ export function contractToInterface(
symbolImports: SymbolImport[];
qualifiedSymbols: Map<string, QualifiedSymbol>;
} {
let ast: SourceUnit;
try {
ast = parse(source);
} catch (error) {
throw new MUDError(`Failed to parse contract ${contractName}: ${error}`);
const version = LanguageFacts.inferLanguageVersions(source).at(-1);
const parser = Parser.create(version ?? LanguageFacts.latestVersion());
const parserResult = parser.parseFileContents(source);
if (!parserResult.isValid()) {
const errorMessage = parserResult
.errors()
.map((error) => error.message)
.join("\n");
throw new MUDError(`Failed to parse contract ${contractName}: ${errorMessage}`);
}
const contractNode = findContractNode(ast, contractName);
const root = parserResult.createTreeCursor();
const contractNode = findContractOrInterfaceNode(root, contractName);
let symbolImports: SymbolImport[] = [];
const functions: ContractInterfaceFunction[] = [];
const errors: ContractInterfaceError[] = [];
Expand All @@ -55,55 +71,76 @@ export function contractToInterface(
throw new MUDError(`Contract not found: ${contractName}`);
}

visit(contractNode, {
FunctionDefinition({
name,
visibility,
parameters,
stateMutability,
returnParameters,
isConstructor,
isFallback,
isReceiveEther,
}) {
try {
// skip constructor and fallbacks
if (isConstructor || isFallback || isReceiveEther) return;
// forbid default visibility (this check might be unnecessary, modern solidity already disallows this)
if (visibility === "default") throw new MUDError(`Visibility is not specified`);

if (visibility === "external" || visibility === "public") {
functions.push({
name: name === null ? "" : name,
parameters: parameters.map(parseParameter),
stateMutability: stateMutability || "",
returnParameters: returnParameters === null ? [] : returnParameters.map(parseParameter),
});

for (const { typeName } of parameters.concat(returnParameters ?? [])) {
const symbols = typeNameToSymbols(typeName);
symbolImports = symbolImports.concat(symbolsToImports(ast, symbols, findInheritedSymbol, qualifiedSymbols));
}
for (const match of contractNode.query([
Query.create(`
@function [FunctionDefinition [FunctionName [Identifier]]]
`),
Query.create(`
@error [ErrorDefinition]
`),
])) {
if (match.captures["function"]) {
// Functions
const funcNode = match.captures["function"]?.[0].node;
assertNonterminalNode(funcNode);
const funcDef = new FunctionDefinition(funcNode);
const name = funcDef.name.cst.unparse().trim();

let visibility = undefined;
let stateMutability = "";
for (const item of funcDef.attributes.items) {
const attribute = item.cst.unparse().trim();
switch (attribute) {
case "public":
case "private":
case "internal":
case "external":
visibility = attribute;
break;
case "view":
case "pure":
case "payable":
stateMutability = attribute;
break;
}
} catch (error: unknown) {
if (error instanceof MUDError) {
error.message = `Function "${name}" in contract "${contractName}": ${error.message}`;
}
throw error;
}
},
CustomErrorDefinition({ name, parameters }) {
if (visibility === undefined) throw new MUDError(`Visibility is not specified for function '${name}'`);
if (visibility !== "public" && visibility !== "external") {
continue;
}

functions.push({
name,
parameters: splatParameters(funcDef.parameters),
stateMutability,
returnParameters: splatParameters(funcDef.returns?.variables),
});

for (const { typeName } of funcDef.parameters.parameters.items.concat(
funcDef.returns?.variables.parameters.items ?? [],
)) {
const symbols = typeNameToSymbols(typeName);
symbolImports = symbolImports.concat(symbolsToImports(root, symbols, findInheritedSymbol, qualifiedSymbols));
}
} else if (match.captures["error"]) {
// Custom errors
const errorNode = match.captures.error?.[0].node;
assertNonterminalNode(errorNode);
const errorDef = new ErrorDefinition(errorNode);
const name = errorDef.name.unparse().trim();
errors.push({
name,
parameters: parameters.map(parseParameter),
parameters: splatParameters(errorDef.members),
});

for (const parameter of parameters) {
const symbols = typeNameToSymbols(parameter.typeName);
symbolImports = symbolImports.concat(symbolsToImports(ast, symbols, findInheritedSymbol, qualifiedSymbols));
for (const { typeName } of errorDef.members.parameters.items) {
const symbols = typeNameToSymbols(typeName);
symbolImports = symbolImports.concat(symbolsToImports(root, symbols, findInheritedSymbol, qualifiedSymbols));
}
},
});
}
}

symbolImports = deduplicateSymbolImports(symbolImports);

return {
functions,
Expand All @@ -113,85 +150,33 @@ export function contractToInterface(
};
}

function parseParameter({ name, typeName, storageLocation }: VariableDeclaration): string {
let typedNameWithLocation = "";

const { name: flattenedTypeName, stateMutability } = flattenTypeName(typeName);
// type name (e.g. uint256)
typedNameWithLocation += flattenedTypeName;
// optional mutability (e.g. address payable)
if (stateMutability !== null) {
typedNameWithLocation += ` ${stateMutability}`;
}
// location, when relevant (e.g. string memory)
if (storageLocation !== null) {
typedNameWithLocation += ` ${storageLocation}`;
}
// optional variable name
if (name !== null) {
typedNameWithLocation += ` ${name}`;
}

return typedNameWithLocation;
}

function flattenTypeName(typeName: TypeName | null): { name: string; stateMutability: string | null } {
if (typeName === null) {
return {
name: "",
stateMutability: null,
};
}
if (typeName.type === "ElementaryTypeName") {
return {
name: typeName.name,
stateMutability: typeName.stateMutability,
};
} else if (typeName.type === "UserDefinedTypeName") {
return {
name: typeName.namePath,
stateMutability: null,
};
} else if (typeName.type === "ArrayTypeName") {
let length = "";
if (typeName.length?.type === "NumberLiteral") {
length = typeName.length.number;
} else if (typeName.length?.type === "Identifier") {
length = typeName.length.name;
}

const { name, stateMutability } = flattenTypeName(typeName.baseTypeName);
return {
name: `${name}[${length}]`,
stateMutability,
};
} else {
// TODO function types are unsupported but could be useful
throw new MUDError(`Invalid typeName.type ${typeName.type}`);
}
function splatParameters(parameters: ParametersDeclaration | ErrorParametersDeclaration | undefined): string[] {
return parameters?.parameters.items.map((parameter) => parameter.cst.unparse().trim()) ?? [];
}

// Get symbols that need to be imported for given typeName
function typeNameToSymbols(typeName: TypeName | null): string[] {
if (typeName?.type === "UserDefinedTypeName") {
// split is needed to get a library, if types are internal to it
const symbol = typeName.namePath.split(".")[0];
return [symbol];
} else if (typeName?.type === "ArrayTypeName") {
const symbols = typeNameToSymbols(typeName.baseTypeName);
// array types can also use symbols (constants) for length
if (typeName.length?.type === "Identifier") {
const innerTypeName = typeName.length.name;
symbols.push(innerTypeName.split(".")[0]);
function typeNameToSymbols(typeName: TypeName): string[] {
const typeVariant = typeName.variant;
if (typeVariant instanceof IdentifierPath) {
return [typeVariant.items[0].unparse()];
} else if (typeVariant instanceof ArrayTypeName) {
const symbols = typeNameToSymbols(typeVariant.operand);
const indexVariant = typeVariant.index?.variant;
if (indexVariant instanceof TerminalNode) {
symbols.push(indexVariant.unparse());
} else if (indexVariant instanceof MemberAccessExpression) {
const memberOperandVariant = indexVariant.operand.variant;
if (memberOperandVariant instanceof TerminalNode) {
symbols.push(memberOperandVariant.unparse());
}
}
return symbols;
} else {
return [];
}
return [];
}

function symbolsToImports(
ast: SourceUnit,
root: Cursor,
symbols: string[],
findInheritedSymbol?: (symbol: string) => QualifiedSymbol | undefined,
qualifiedSymbols?: Map<string, QualifiedSymbol>,
Expand All @@ -200,7 +185,7 @@ function symbolsToImports(

for (const symbol of symbols) {
// First check explicit imports
const explicitImport = findSymbolImport(ast, symbol);
const explicitImport = findSymbolImport(root, symbol);
if (explicitImport) {
imports.push(explicitImport);
continue;
Expand All @@ -225,6 +210,10 @@ function symbolsToImports(
}
}

return imports;
}

function deduplicateSymbolImports(imports: SymbolImport[]): SymbolImport[] {
// Deduplicate imports
const uniqueImports = new Map<string, SymbolImport>();
for (const imp of imports) {
Expand Down
16 changes: 0 additions & 16 deletions packages/common/src/codegen/utils/findContractNode.ts

This file was deleted.

Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
import { Cursor, Query } from "@nomicfoundation/slang/cst";

export function findContractOrInterfaceNode(root: Cursor, contractOrInterfaceName: string): Cursor | undefined {
for (const result of root.query([
Query.create(`
[ContractDefinition
name: ["${contractOrInterfaceName}"]
]
`),
Query.create(`
[InterfaceDefinition
name: ["${contractOrInterfaceName}"]
]
`),
])) {
return result.root;
}
}
Loading